From 4919d6eeab8b1c1d56e206a5fb38bef76bfff368 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 18:52:48 +0000 Subject: [PATCH 1/5] Initial plan From 65e1d035e5b43a9286532a781d919aecd36c5702 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 19:08:20 +0000 Subject: [PATCH 2/5] Add translation scoring feature: score_translations.py, translation_scorer.py, NMTModel.score_translation() Agent-Logs-Url: https://github.com/sillsdev/silnlp/sessions/c9d942dc-1d4a-4f20-b396-c522dcb080bc Co-authored-by: benjaminking <1214233+benjaminking@users.noreply.github.com> --- silnlp/nmt/config.py | 17 +- silnlp/nmt/hugging_face_config.py | 51 +++- silnlp/nmt/score_translations.py | 143 ++++++++++ silnlp/nmt/translation_scorer.py | 261 +++++++++++++++++++ tests/smoke_tests/test_translation_scorer.py | 230 ++++++++++++++++ 5 files changed, 700 insertions(+), 2 deletions(-) create mode 100644 silnlp/nmt/score_translations.py create mode 100644 silnlp/nmt/translation_scorer.py create mode 100644 tests/smoke_tests/test_translation_scorer.py diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index 751f19ff..19d440ed 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -8,7 +8,10 @@ from enum import Enum, auto from pathlib import Path from statistics import mean, median, stdev -from typing import Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast +from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast + +if TYPE_CHECKING: + from .translation_scorer import ScoredTranslation import pandas as pd from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books @@ -99,6 +102,18 @@ def clear_cache(self) -> None: ... @abstractmethod def get_num_drafts(self) -> int: ... + @abstractmethod + def score_translation( + self, + source: str, + translation: str, + src_iso: str, + trg_iso: str, + ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, + low_prob_threshold: float = -3.0, + top_k_suggestions: int = 5, + ) -> "ScoredTranslation": ... + class Config(ABC): def __init__(self, exp_dir: Path, config: dict) -> None: diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 0399de9c..1457adfb 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -11,7 +11,10 @@ from itertools import repeat from math import prod from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast + +if TYPE_CHECKING: + from .translation_scorer import ScoredTranslation import datasets.utils.logging as datasets_logging import evaluate @@ -1406,6 +1409,52 @@ def clear_cache(self) -> None: self._cached_inference_model = None self._inference_model_params = None + def score_translation( + self, + source: str, + translation: str, + src_iso: str, + trg_iso: str, + ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, + low_prob_threshold: float = -3.0, + top_k_suggestions: int = 5, + ) -> "ScoredTranslation": + """Score a translation using forced decoding. + + Computes the conditional probability P(y_t | y_1, ..., y_{t-1}, x) for each + target token in the translation, groups subword tokens into words, flags + low-probability words, and provides top-k alternative suggestions for them. + + Args: + source: The source sentence. + translation: The translation to score. + src_iso: Source language ISO code. + trg_iso: Target language ISO code. + ckpt: Checkpoint to use for the model. + low_prob_threshold: Log-probability threshold below which a word is flagged. + top_k_suggestions: Number of alternative suggestions per low-probability word. + + Returns: + A ScoredTranslation with per-word scores and suggestions. + """ + from .translation_scorer import ScoredTranslation, TranslationScorer + + src_lang = self._config.data["lang_codes"].get(src_iso, src_iso) + trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso) + inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang) + tokenizer = self._config.get_tokenizer() + if self._inference_model_params == inference_model_params and self._cached_inference_model is not None: + model = self._cached_inference_model + else: + model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang) + self._inference_model_params = inference_model_params + + if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)): + tokenizer = PunctuationNormalizingTokenizer(tokenizer) + + scorer = TranslationScorer(model, tokenizer, low_prob_threshold, top_k_suggestions) + return scorer.score(source, translation) + def _create_training_arguments(self) -> Seq2SeqTrainingArguments: parser = HfArgumentParser(Seq2SeqTrainingArguments) args: dict = {} diff --git a/silnlp/nmt/score_translations.py b/silnlp/nmt/score_translations.py new file mode 100644 index 00000000..3b6d341d --- /dev/null +++ b/silnlp/nmt/score_translations.py @@ -0,0 +1,143 @@ +import argparse +import logging +from typing import Optional + +from .config_utils import load_config +from .translation_scorer import DEFAULT_LOW_PROB_THRESHOLD, DEFAULT_TOP_K_SUGGESTIONS, ScoredTranslation + +LOGGER = logging.getLogger(__name__) + + +def format_scored_translation(scored: ScoredTranslation) -> str: + """Format a ScoredTranslation as a human-readable string.""" + lines = [] + lines.append(f"Source: {scored.source}") + lines.append(f"Translation: {scored.translation}") + lines.append(f"Overall log-probability: {scored.sequence_log_prob:.4f}") + lines.append("") + + # Per-word table + col_word = max(len(w.word) for w in scored.word_scores) if scored.word_scores else 10 + col_word = max(col_word, 4) + header = f" {'Word':<{col_word}} {'Log Prob':>10} {'Prob':>10} Suggestions" + lines.append(header) + lines.append(" " + "-" * (len(header) - 2)) + for ws in scored.word_scores: + flag = "* " if ws.is_low_probability else " " + suggestions_str = ", ".join(ws.suggestions) if ws.suggestions else "" + lines.append( + f"{flag}{ws.word:<{col_word}} {ws.log_prob:>10.4f} {ws.prob:>10.6f} {suggestions_str}" + ) + + lines.append("") + low_prob = scored.low_probability_words + if low_prob: + lines.append("Low-probability words and suggested alternatives:") + for ws in low_prob: + if ws.suggestions: + suggestions_str = ", ".join(f"'{s}'" for s in ws.suggestions) + lines.append(f" '{ws.word}' (log prob {ws.log_prob:.4f}) → {suggestions_str}") + else: + lines.append(f" '{ws.word}' (log prob {ws.log_prob:.4f}) → (no suggestions available)") + else: + lines.append("No low-probability words found.") + + return "\n".join(lines) + + +def score_translation( + experiment: str, + source: str, + translation: str, + src_iso: Optional[str], + trg_iso: Optional[str], + checkpoint: str = "last", + low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD, + top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS, +) -> ScoredTranslation: + """Score a translation against a source sentence using a trained NMT model. + + Loads the experiment's model, runs forced decoding on the translation, and returns + a ScoredTranslation with per-word probabilities and suggestions for flagged words. + + Args: + experiment: Name of the experiment (relative to the MT experiments directory). + source: The source sentence to score against. + translation: The translation to evaluate. + src_iso: Source language ISO code. Defaults to the experiment's test source. + trg_iso: Target language ISO code. Defaults to the experiment's test target. + checkpoint: Checkpoint to load ("last", "best", "avg", or a step number). + low_prob_threshold: Log-probability threshold for flagging low-probability words. + top_k_suggestions: Number of alternative suggestions per flagged word. + + Returns: + A ScoredTranslation with per-word scores and suggestions. + """ + config = load_config(experiment) + model = config.create_model() + + effective_src_iso = src_iso or config.default_test_src_iso + effective_trg_iso = trg_iso or config.default_test_trg_iso + + return model.score_translation( + source=source, + translation=translation, + src_iso=effective_src_iso, + trg_iso=effective_trg_iso, + ckpt=checkpoint, + low_prob_threshold=low_prob_threshold, + top_k_suggestions=top_k_suggestions, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Score a translation by computing the model's token-level conditional probabilities. " + "Low-probability words are flagged and paired with suggested alternatives from the model." + ) + ) + parser.add_argument("experiment", help="Experiment name") + parser.add_argument("--source", type=str, required=True, help="Source sentence to score against") + parser.add_argument("--translation", type=str, required=True, help="Translation to evaluate") + parser.add_argument("--src-iso", type=str, default=None, help="Source language ISO code") + parser.add_argument("--trg-iso", type=str, default=None, help="Target language ISO code") + parser.add_argument( + "--checkpoint", + type=str, + default="last", + help="Checkpoint to use: 'last', 'best', 'avg', or a checkpoint step number", + ) + parser.add_argument( + "--low-prob-threshold", + type=float, + default=DEFAULT_LOW_PROB_THRESHOLD, + help=( + f"Log-probability threshold below which a word is considered low-probability " + f"(default: {DEFAULT_LOW_PROB_THRESHOLD})" + ), + ) + parser.add_argument( + "--top-k-suggestions", + type=int, + default=DEFAULT_TOP_K_SUGGESTIONS, + help=f"Number of alternative suggestions per low-probability word (default: {DEFAULT_TOP_K_SUGGESTIONS})", + ) + + args = parser.parse_args() + + scored = score_translation( + experiment=args.experiment, + source=args.source, + translation=args.translation, + src_iso=args.src_iso, + trg_iso=args.trg_iso, + checkpoint=args.checkpoint, + low_prob_threshold=args.low_prob_threshold, + top_k_suggestions=args.top_k_suggestions, + ) + print(format_scored_translation(scored)) + + +if __name__ == "__main__": + main() diff --git a/silnlp/nmt/translation_scorer.py b/silnlp/nmt/translation_scorer.py new file mode 100644 index 00000000..2a8c712a --- /dev/null +++ b/silnlp/nmt/translation_scorer.py @@ -0,0 +1,261 @@ +import logging +from dataclasses import dataclass, field +from math import exp +from typing import List, Optional, Set + +import torch +import torch.nn.functional as F +from transformers import PreTrainedModel, PreTrainedTokenizer + +LOGGER = logging.getLogger(__name__) + +DEFAULT_LOW_PROB_THRESHOLD = -3.0 +DEFAULT_TOP_K_SUGGESTIONS = 5 + + +@dataclass +class TokenScore: + """Score information for a single subword token.""" + + token: str + log_prob: float + + @property + def prob(self) -> float: + return exp(self.log_prob) + + +@dataclass +class WordScore: + """Score information for a single word (may consist of multiple subword tokens).""" + + word: str + tokens: List[TokenScore] + suggestions: List[str] = field(default_factory=list) + low_prob_threshold: float = field(default=DEFAULT_LOW_PROB_THRESHOLD, compare=False, repr=False) + + @property + def log_prob(self) -> float: + return sum(t.log_prob for t in self.tokens) + + @property + def prob(self) -> float: + return exp(self.log_prob) + + @property + def is_low_probability(self) -> bool: + return self.log_prob < self.low_prob_threshold + + +@dataclass +class ScoredTranslation: + """Result of scoring a translation against a source sentence.""" + + source: str + translation: str + word_scores: List[WordScore] + + @property + def sequence_log_prob(self) -> float: + return sum(w.log_prob for w in self.word_scores) + + @property + def low_probability_words(self) -> List[WordScore]: + return [w for w in self.word_scores if w.is_low_probability] + + +class TranslationScorer: + """Scores a translation using forced decoding and identifies low-probability words. + + For each target token y_t in the translation, this class computes the conditional + probability P(y_t | y_1, ..., y_{t-1}, x) where x is the source sentence. It then + groups subword tokens into words, flags words that fall below a probability threshold, + and provides top-k alternative suggestions from the model for each flagged word. + """ + + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD, + top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS, + ): + self._model = model + self._tokenizer = tokenizer + self._low_prob_threshold = low_prob_threshold + self._top_k_suggestions = top_k_suggestions + self._special_token_ids: Set[int] = set(tokenizer.all_special_ids) + + def score(self, source: str, translation: str) -> ScoredTranslation: + """Score each token in the translation using forced decoding. + + For each target token y_t, computes P(y_t | y_1, ..., y_{t-1}, x). Low-probability + words are flagged and paired with top-k alternative suggestions from the model. + + Args: + source: The source sentence. + translation: The translation to score. + + Returns: + A ScoredTranslation with per-word scores and suggestions for low-probability words. + """ + # Tokenize source + source_encoding = self._tokenizer(source, return_tensors="pt", truncation=True) + + # Tokenize target as labels + target_encoding = self._tokenizer(text_target=translation, return_tensors="pt", truncation=True) + labels = target_encoding["input_ids"] + + # If the model forces a BOS token (e.g., a language code for NLLB/M2M100), + # prepend it to the labels so the decoder has the correct context for scoring. + forced_bos_token_id: Optional[int] = None + if hasattr(self._model, "generation_config") and self._model.generation_config is not None: + forced_bos_token_id = self._model.generation_config.forced_bos_token_id + if forced_bos_token_id is None and hasattr(self._model.config, "forced_bos_token_id"): + forced_bos_token_id = self._model.config.forced_bos_token_id + + if forced_bos_token_id is not None and labels[0, 0].item() != forced_bos_token_id: + forced_bos = torch.tensor([[forced_bos_token_id]], dtype=labels.dtype) + labels = torch.cat([forced_bos, labels], dim=1) + + # Move tensors to the model's device + device = next(self._model.parameters()).device + input_ids = source_encoding["input_ids"].to(device) + attention_mask = source_encoding["attention_mask"].to(device) + labels_on_device = labels.to(device) + + # Run forward pass with teacher forcing to get logits at each position + self._model.eval() + with torch.no_grad(): + outputs = self._model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels_on_device, + ) + + # Compute log probabilities from logits (use float32 for numerical precision) + logits = outputs.logits # shape: [1, seq_len, vocab_size] + log_probs = F.log_softmax(logits.float(), dim=-1) + + label_ids = labels[0].tolist() + tokens = [self._tokenizer.convert_ids_to_tokens(id) for id in label_ids] + token_log_probs = [log_probs[0, i, id].item() for i, id in enumerate(label_ids)] + + # Get top-k alternative token IDs at each position for suggestions + k = min(self._top_k_suggestions + 1, log_probs.shape[-1]) + top_k_ids = torch.topk(log_probs[0], k=k, dim=-1).indices.tolist() + + word_scores = self._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) + + return ScoredTranslation( + source=source, + translation=translation, + word_scores=word_scores, + ) + + def _is_word_initial(self, token: str) -> bool: + """Return True if this subword token begins a new word.""" + return token.startswith("▁") or token.startswith("Ġ") + + def _decode_suggestion(self, token_id: int) -> Optional[str]: + """Decode a token ID to a display string, returning None for special tokens.""" + if token_id in self._special_token_ids: + return None + token = self._tokenizer.convert_ids_to_tokens(token_id) + if not token: + return None + # Strip SentencePiece (▁) or BPE (Ġ) word-initial markers + if token.startswith("▁"): + stripped = token[len("▁") :] + return stripped if stripped else None + if token.startswith("Ġ"): + stripped = token[len("Ġ") :] + return stripped if stripped else None + # Strip BERT-style continuation marker + if token.startswith("##"): + return None + return token + + def _group_tokens_into_words( + self, + tokens: List[str], + token_log_probs: List[float], + top_k_ids: List[List[int]], + label_ids: List[int], + ) -> List[WordScore]: + """Group subword tokens into words and compute per-word scores.""" + word_scores: List[WordScore] = [] + current_token_scores: List[TokenScore] = [] + current_word_chars = "" + current_first_top_k: List[int] = [] + + for token, log_prob, position_top_k, label_id in zip(tokens, token_log_probs, top_k_ids, label_ids): + # Skip special tokens (language codes, EOS, BOS, pad) + if label_id in self._special_token_ids: + if current_token_scores: + word_scores.append( + self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) + ) + current_token_scores = [] + current_word_chars = "" + current_first_top_k = [] + continue + + # A word-initial token (starts with ▁ or Ġ) begins a new word + starts_new_word = self._is_word_initial(token) + if starts_new_word and current_token_scores: + word_scores.append( + self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) + ) + current_token_scores = [] + current_word_chars = "" + current_first_top_k = [] + + # Record top-k alternatives only for the first subword of each word + if not current_token_scores: + current_first_top_k = position_top_k + + current_token_scores.append(TokenScore(token=token, log_prob=log_prob)) + + # Append the token's characters to the current word (strip the word-initial marker) + if token.startswith("▁"): + current_word_chars += token[len("▁") :] + elif token.startswith("Ġ"): + current_word_chars += token[len("Ġ") :] + elif token.startswith("##"): + current_word_chars += token[len("##") :] + else: + current_word_chars += token + + # Finalize the last word + if current_token_scores: + word_scores.append( + self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) + ) + + return word_scores + + def _create_word_score( + self, + word: str, + token_scores: List[TokenScore], + first_token_top_k_ids: List[int], + ) -> WordScore: + """Build a WordScore, adding suggestions when the word is low-probability.""" + word_log_prob = sum(t.log_prob for t in token_scores) + suggestions: List[str] = [] + + if word_log_prob < self._low_prob_threshold: + for top_id in first_token_top_k_ids: + suggestion = self._decode_suggestion(top_id) + if suggestion and suggestion.lower() != word.lower(): + suggestions.append(suggestion) + if len(suggestions) >= self._top_k_suggestions: + break + + return WordScore( + word=word, + tokens=token_scores, + suggestions=suggestions, + low_prob_threshold=self._low_prob_threshold, + ) diff --git a/tests/smoke_tests/test_translation_scorer.py b/tests/smoke_tests/test_translation_scorer.py new file mode 100644 index 00000000..e6d40f8e --- /dev/null +++ b/tests/smoke_tests/test_translation_scorer.py @@ -0,0 +1,230 @@ +"""Tests for the TranslationScorer class and related data structures.""" +from unittest.mock import MagicMock, patch + +import torch + +from silnlp.nmt.translation_scorer import ( + DEFAULT_LOW_PROB_THRESHOLD, + DEFAULT_TOP_K_SUGGESTIONS, + ScoredTranslation, + TokenScore, + TranslationScorer, + WordScore, +) + +_TINY_MODEL_NAME = "hf-internal-testing/tiny-random-nllb" + + +def _make_mock_model_and_tokenizer(): + """Create a mock model and tokenizer for testing.""" + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + from transformers.modeling_outputs import Seq2SeqLMOutput + + tokenizer = AutoTokenizer.from_pretrained(_TINY_MODEL_NAME) + tokenizer.src_lang = "eng_Latn" + tokenizer.tgt_lang = "fra_Latn" + + # Create a minimal mock model that returns deterministic logits + model = MagicMock() + vocab_size = len(tokenizer) + + # Return logits shaped [1, seq_len, vocab_size] with deterministic values + def mock_forward(input_ids=None, attention_mask=None, labels=None, **kwargs): + seq_len = labels.shape[1] if labels is not None else 1 + logits = torch.zeros(1, seq_len, vocab_size) + # Make the first token of the vocabulary highly probable except at position 0 + logits[:, :, 0] = -10.0 + logits[:, :, 1] = -1.0 # slightly low probability + return Seq2SeqLMOutput(logits=logits) + + model.side_effect = None + model.__call__ = MagicMock(side_effect=mock_forward) + model.parameters = MagicMock(return_value=iter([torch.zeros(1)])) + model.eval = MagicMock() + model.generation_config = MagicMock() + model.generation_config.forced_bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn") + model.config = MagicMock() + model.config.forced_bos_token_id = None + + return model, tokenizer + + +class TestTokenScore: + def test_prob_property(self): + ts = TokenScore(token="▁hello", log_prob=-1.0) + assert abs(ts.prob - 0.3679) < 1e-3 + + def test_zero_log_prob(self): + ts = TokenScore(token="▁hello", log_prob=0.0) + assert ts.prob == 1.0 + + +class TestWordScore: + def test_log_prob_is_sum_of_token_log_probs(self): + ws = WordScore( + word="hello", + tokens=[ + TokenScore(token="▁hel", log_prob=-0.5), + TokenScore(token="lo", log_prob=-0.3), + ], + ) + assert abs(ws.log_prob - (-0.8)) < 1e-6 + + def test_is_low_probability_default_threshold(self): + low = WordScore(word="low", tokens=[TokenScore(token="▁low", log_prob=-5.0)]) + high = WordScore(word="high", tokens=[TokenScore(token="▁high", log_prob=-1.0)]) + assert low.is_low_probability + assert not high.is_low_probability + + def test_is_low_probability_custom_threshold(self): + ws = WordScore( + word="test", + tokens=[TokenScore(token="▁test", log_prob=-2.0)], + low_prob_threshold=-1.0, + ) + assert ws.is_low_probability + + +class TestScoredTranslation: + def test_sequence_log_prob(self): + st = ScoredTranslation( + source="hello", + translation="bonjour", + word_scores=[ + WordScore(word="bonjour", tokens=[TokenScore(token="▁bonjour", log_prob=-0.5)]), + ], + ) + assert abs(st.sequence_log_prob - (-0.5)) < 1e-6 + + def test_low_probability_words(self): + st = ScoredTranslation( + source="hello world", + translation="bonjour monde", + word_scores=[ + WordScore(word="bonjour", tokens=[TokenScore(token="▁bonjour", log_prob=-1.0)]), + WordScore(word="monde", tokens=[TokenScore(token="▁monde", log_prob=-5.0)]), + ], + ) + low = st.low_probability_words + assert len(low) == 1 + assert low[0].word == "monde" + + +class TestTranslationScorer: + def test_score_returns_scored_translation(self): + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + result = scorer.score("hello world", "bonjour monde") + assert isinstance(result, ScoredTranslation) + assert result.source == "hello world" + assert result.translation == "bonjour monde" + assert isinstance(result.word_scores, list) + assert len(result.word_scores) > 0 + + def test_each_word_score_has_tokens(self): + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + result = scorer.score("hello world", "bonjour monde") + for ws in result.word_scores: + assert isinstance(ws, WordScore) + assert len(ws.tokens) > 0 + for ts in ws.tokens: + assert isinstance(ts, TokenScore) + + def test_low_prob_words_get_suggestions(self): + model, tokenizer = _make_mock_model_and_tokenizer() + # Use a threshold of 0.0 so all words are flagged as low-probability + scorer = TranslationScorer(model, tokenizer, low_prob_threshold=0.0) + result = scorer.score("hello world", "bonjour monde") + for ws in result.word_scores: + assert ws.is_low_probability + # Every flagged word should have suggestions (up to top_k) + assert len(ws.suggestions) <= DEFAULT_TOP_K_SUGGESTIONS + + def test_custom_top_k(self): + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer, low_prob_threshold=0.0, top_k_suggestions=2) + result = scorer.score("hello", "bonjour") + for ws in result.word_scores: + assert len(ws.suggestions) <= 2 + + def test_forced_bos_prepended_to_labels(self): + """Verify that forced_bos_token_id is prepended to labels when not already present.""" + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + + # Capture the labels passed to the model + captured_labels = [] + original_call = model.__call__.side_effect + + def capturing_call(**kwargs): + captured_labels.append(kwargs.get("labels")) + return original_call(**kwargs) + + model.__call__.side_effect = capturing_call + scorer.score("hello", "bonjour") + + assert len(captured_labels) == 1 + labels = captured_labels[0] + # The first token should be the forced BOS (language code) + forced_bos = tokenizer.convert_tokens_to_ids("fra_Latn") + assert labels[0, 0].item() == forced_bos + + +class TestWordGrouping: + def test_sentencepiece_tokens_grouped_correctly(self): + """Test that SentencePiece tokens (▁ prefix) are grouped into words.""" + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + + # Mock _group_tokens_into_words directly to test grouping logic + tokens = ["▁hello", "▁world", "!"] + token_log_probs = [-1.0, -2.0, -0.5] + label_ids = [101, 102, 103] # fake IDs, not in special_token_ids + top_k_ids = [[1, 2, 3, 4, 5, 6]] * 3 + + # Temporarily clear special token IDs to prevent any skipping + scorer._special_token_ids = set() + words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) + + # "hello" is one word, "world" + "!" are two words + # "!" does not start with ▁, so it continues "world" → "world!" + assert len(words) == 2 + assert words[0].word == "hello" + assert words[1].word == "world!" + + def test_continuation_tokens_merged(self): + """Test that tokens without ▁ are merged with the previous word.""" + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + scorer._special_token_ids = set() + + tokens = ["▁walk", "ing"] + token_log_probs = [-0.5, -0.3] + label_ids = [101, 102] + top_k_ids = [[1, 2, 3, 4, 5, 6]] * 2 + + words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) + assert len(words) == 1 + assert words[0].word == "walking" + assert len(words[0].tokens) == 2 + + def test_special_tokens_skipped(self): + """Special tokens (EOS, BOS, lang codes) should not appear as words.""" + model, tokenizer = _make_mock_model_and_tokenizer() + scorer = TranslationScorer(model, tokenizer) + + # Get some real special token IDs + eos_id = tokenizer.eos_token_id + special_ids = {eos_id} + scorer._special_token_ids = special_ids + + tokens = ["▁hello", tokenizer.eos_token] + token_log_probs = [-1.0, 0.0] + label_ids = [101, eos_id] + top_k_ids = [[1, 2, 3, 4, 5, 6]] * 2 + + words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) + # EOS token should be skipped; only "hello" remains + assert len(words) == 1 + assert words[0].word == "hello" From bc6c9b45e100800c561ba22e7706e82a85bbc9f1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 19:10:07 +0000 Subject: [PATCH 3/5] Apply code review: use str.removeprefix() for cleaner token prefix stripping Agent-Logs-Url: https://github.com/sillsdev/silnlp/sessions/c9d942dc-1d4a-4f20-b396-c522dcb080bc Co-authored-by: benjaminking <1214233+benjaminking@users.noreply.github.com> --- silnlp/nmt/translation_scorer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/silnlp/nmt/translation_scorer.py b/silnlp/nmt/translation_scorer.py index 2a8c712a..801ef818 100644 --- a/silnlp/nmt/translation_scorer.py +++ b/silnlp/nmt/translation_scorer.py @@ -166,10 +166,10 @@ def _decode_suggestion(self, token_id: int) -> Optional[str]: return None # Strip SentencePiece (▁) or BPE (Ġ) word-initial markers if token.startswith("▁"): - stripped = token[len("▁") :] + stripped = token.removeprefix("▁") return stripped if stripped else None if token.startswith("Ġ"): - stripped = token[len("Ġ") :] + stripped = token.removeprefix("Ġ") return stripped if stripped else None # Strip BERT-style continuation marker if token.startswith("##"): @@ -219,11 +219,11 @@ def _group_tokens_into_words( # Append the token's characters to the current word (strip the word-initial marker) if token.startswith("▁"): - current_word_chars += token[len("▁") :] + current_word_chars += token.removeprefix("▁") elif token.startswith("Ġ"): - current_word_chars += token[len("Ġ") :] + current_word_chars += token.removeprefix("Ġ") elif token.startswith("##"): - current_word_chars += token[len("##") :] + current_word_chars += token.removeprefix("##") else: current_word_chars += token From e64e39871adae12bbb72544a42c021c02aca1f03 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 27 May 2026 19:43:58 +0000 Subject: [PATCH 4/5] Add contextual phrase scoring and rescored replacement suggestions --- silnlp/nmt/score_translations.py | 91 ++- silnlp/nmt/translation_scorer.py | 566 ++++++++++++++----- tests/smoke_tests/test_translation_scorer.py | 368 +++++------- 3 files changed, 599 insertions(+), 426 deletions(-) diff --git a/silnlp/nmt/score_translations.py b/silnlp/nmt/score_translations.py index 3b6d341d..419201ac 100644 --- a/silnlp/nmt/score_translations.py +++ b/silnlp/nmt/score_translations.py @@ -3,44 +3,59 @@ from typing import Optional from .config_utils import load_config -from .translation_scorer import DEFAULT_LOW_PROB_THRESHOLD, DEFAULT_TOP_K_SUGGESTIONS, ScoredTranslation +from .translation_scorer import DEFAULT_LOW_PROB_THRESHOLD, DEFAULT_TOP_K_SUGGESTIONS, PhraseScore, ScoredTranslation LOGGER = logging.getLogger(__name__) +def _format_suggestions(phrase_score: PhraseScore) -> str: + if not phrase_score.suggestions: + return "" + return "; ".join( + f"{suggestion.phrase} (ctx/word={suggestion.normalized_log_prob:.4f}, Δ={suggestion.improvement:.4f})" + for suggestion in phrase_score.suggestions + ) + + def format_scored_translation(scored: ScoredTranslation) -> str: """Format a ScoredTranslation as a human-readable string.""" - lines = [] - lines.append(f"Source: {scored.source}") - lines.append(f"Translation: {scored.translation}") - lines.append(f"Overall log-probability: {scored.sequence_log_prob:.4f}") - lines.append("") - - # Per-word table - col_word = max(len(w.word) for w in scored.word_scores) if scored.word_scores else 10 - col_word = max(col_word, 4) - header = f" {'Word':<{col_word}} {'Log Prob':>10} {'Prob':>10} Suggestions" + lines = [ + f"Source: {scored.source}", + f"Translation: {scored.translation}", + f"Sequence log-probability: {scored.sequence_log_prob:.4f}", + "", + "Word-level contextual scores:", + ] + + col_word = max((len(score.word) for score in scored.word_scores), default=4) + header = ( + f" {'Word':<{col_word}} {'Forward':>10} {'Right Ctx':>10} {'Ctx/Word':>10} Suggestions" + ) lines.append(header) lines.append(" " + "-" * (len(header) - 2)) - for ws in scored.word_scores: - flag = "* " if ws.is_low_probability else " " - suggestions_str = ", ".join(ws.suggestions) if ws.suggestions else "" + for score in scored.word_scores: + flag = "* " if score.is_low_probability else " " + suggestion_text = ", ".join(suggestion.phrase for suggestion in score.suggestions) lines.append( - f"{flag}{ws.word:<{col_word}} {ws.log_prob:>10.4f} {ws.prob:>10.6f} {suggestions_str}" + f"{flag}{score.word:<{col_word}} {score.forward_log_prob:>10.4f} " + f"{score.right_context_log_prob:>10.4f} {score.normalized_log_prob:>10.4f} {suggestion_text}" ) lines.append("") - low_prob = scored.low_probability_words - if low_prob: - lines.append("Low-probability words and suggested alternatives:") - for ws in low_prob: - if ws.suggestions: - suggestions_str = ", ".join(f"'{s}'" for s in ws.suggestions) - lines.append(f" '{ws.word}' (log prob {ws.log_prob:.4f}) → {suggestions_str}") - else: - lines.append(f" '{ws.word}' (log prob {ws.log_prob:.4f}) → (no suggestions available)") + lines.append("Low-probability phrases:") + low_phrases = scored.low_probability_phrases + if not low_phrases: + lines.append(" None") else: - lines.append("No low-probability words found.") + for score in low_phrases: + lines.append( + f" '{score.phrase}' [{score.span_start}:{score.span_end}] " + f"forward={score.forward_log_prob:.4f}, right_ctx={score.right_context_log_prob:.4f}, " + f"ctx/word={score.normalized_log_prob:.4f}" + ) + suggestions = _format_suggestions(score) + if suggestions: + lines.append(f" Suggestions: {suggestions}") return "\n".join(lines) @@ -55,24 +70,6 @@ def score_translation( low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD, top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS, ) -> ScoredTranslation: - """Score a translation against a source sentence using a trained NMT model. - - Loads the experiment's model, runs forced decoding on the translation, and returns - a ScoredTranslation with per-word probabilities and suggestions for flagged words. - - Args: - experiment: Name of the experiment (relative to the MT experiments directory). - source: The source sentence to score against. - translation: The translation to evaluate. - src_iso: Source language ISO code. Defaults to the experiment's test source. - trg_iso: Target language ISO code. Defaults to the experiment's test target. - checkpoint: Checkpoint to load ("last", "best", "avg", or a step number). - low_prob_threshold: Log-probability threshold for flagging low-probability words. - top_k_suggestions: Number of alternative suggestions per flagged word. - - Returns: - A ScoredTranslation with per-word scores and suggestions. - """ config = load_config(experiment) model = config.create_model() @@ -93,8 +90,8 @@ def score_translation( def main() -> None: parser = argparse.ArgumentParser( description=( - "Score a translation by computing the model's token-level conditional probabilities. " - "Low-probability words are flagged and paired with suggested alternatives from the model." + "Score a translation by computing contextual phrase probabilities. " + "Low-probability words and phrases are flagged and paired with rescored replacement suggestions." ) ) parser.add_argument("experiment", help="Experiment name") @@ -113,7 +110,7 @@ def main() -> None: type=float, default=DEFAULT_LOW_PROB_THRESHOLD, help=( - f"Log-probability threshold below which a word is considered low-probability " + f"Threshold on contextual log-probability per word for flagging low-probability spans " f"(default: {DEFAULT_LOW_PROB_THRESHOLD})" ), ) @@ -121,7 +118,7 @@ def main() -> None: "--top-k-suggestions", type=int, default=DEFAULT_TOP_K_SUGGESTIONS, - help=f"Number of alternative suggestions per low-probability word (default: {DEFAULT_TOP_K_SUGGESTIONS})", + help=f"Number of replacement suggestions per flagged span (default: {DEFAULT_TOP_K_SUGGESTIONS})", ) args = parser.parse_args() diff --git a/silnlp/nmt/translation_scorer.py b/silnlp/nmt/translation_scorer.py index 801ef818..06aca129 100644 --- a/silnlp/nmt/translation_scorer.py +++ b/silnlp/nmt/translation_scorer.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass, field from math import exp -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import torch import torch.nn.functional as F @@ -11,6 +11,10 @@ DEFAULT_LOW_PROB_THRESHOLD = -3.0 DEFAULT_TOP_K_SUGGESTIONS = 5 +DEFAULT_MAX_PHRASE_WORDS = 3 +DEFAULT_MIN_SUGGESTION_IMPROVEMENT = 0.25 +DEFAULT_GENERATION_BEAMS = 8 +DEFAULT_MAX_CANDIDATE_TOKENS = 12 @dataclass @@ -25,26 +29,91 @@ def prob(self) -> float: return exp(self.log_prob) +@dataclass +class PhraseSuggestion: + """A replacement phrase suggestion scored with both left and right context.""" + + phrase: str + forward_log_prob: float + right_context_log_prob: float + contextual_log_prob: float + improvement: float + + @property + def word_count(self) -> int: + return max(len(self.phrase.split()), 1) + + @property + def normalized_log_prob(self) -> float: + return self.contextual_log_prob / self.word_count + + @dataclass class WordScore: - """Score information for a single word (may consist of multiple subword tokens).""" + """Context-aware score information for a single word.""" word: str tokens: List[TokenScore] - suggestions: List[str] = field(default_factory=list) + forward_log_prob: float + right_context_log_prob: float + contextual_log_prob: float + span_start: int + span_end: int + suggestions: List[PhraseSuggestion] = field(default_factory=list) low_prob_threshold: float = field(default=DEFAULT_LOW_PROB_THRESHOLD, compare=False, repr=False) + @property + def word_count(self) -> int: + return 1 + @property def log_prob(self) -> float: - return sum(t.log_prob for t in self.tokens) + return self.contextual_log_prob + + @property + def normalized_log_prob(self) -> float: + return self.contextual_log_prob @property def prob(self) -> float: - return exp(self.log_prob) + return exp(self.contextual_log_prob) + + @property + def is_low_probability(self) -> bool: + return self.normalized_log_prob < self.low_prob_threshold + + +@dataclass +class PhraseScore: + """Context-aware score for a variable-length phrase span.""" + + phrase: str + word_scores: List[WordScore] + forward_log_prob: float + right_context_log_prob: float + contextual_log_prob: float + suggestions: List[PhraseSuggestion] = field(default_factory=list) + low_prob_threshold: float = field(default=DEFAULT_LOW_PROB_THRESHOLD, compare=False, repr=False) + + @property + def span_start(self) -> int: + return self.word_scores[0].span_start + + @property + def span_end(self) -> int: + return self.word_scores[-1].span_end + + @property + def word_count(self) -> int: + return len(self.word_scores) + + @property + def normalized_log_prob(self) -> float: + return self.contextual_log_prob / max(self.word_count, 1) @property def is_low_probability(self) -> bool: - return self.log_prob < self.low_prob_threshold + return self.normalized_log_prob < self.low_prob_threshold @dataclass @@ -54,23 +123,72 @@ class ScoredTranslation: source: str translation: str word_scores: List[WordScore] + phrase_scores: List[PhraseScore] @property def sequence_log_prob(self) -> float: - return sum(w.log_prob for w in self.word_scores) + return sum(w.forward_log_prob for w in self.word_scores) @property def low_probability_words(self) -> List[WordScore]: return [w for w in self.word_scores if w.is_low_probability] + @property + def low_probability_phrases(self) -> List[PhraseScore]: + return [p for p in self.phrase_scores if p.is_low_probability] + + +@dataclass +class WordSpan: + text: str + token_start: int + token_end: int + + +@dataclass +class SequenceScore: + text: str + label_ids: List[int] + tokens: List[str] + token_log_probs: List[float] + words: List[WordSpan] + + def __post_init__(self) -> None: + self._word_forward_log_probs = [ + sum(self.token_log_probs[word.token_start : word.token_end]) for word in self.words + ] + cumulative = [0.0] + total = 0.0 + for value in self._word_forward_log_probs: + total += value + cumulative.append(total) + self._cumulative_word_forward_log_probs = cumulative + + @property + def word_texts(self) -> List[str]: + return [word.text for word in self.words] + + def phrase_text(self, start: int, end: int) -> str: + return " ".join(self.word_texts[start:end]) + + def word_forward_log_prob(self, start: int, end: int) -> float: + return self._cumulative_word_forward_log_probs[end] - self._cumulative_word_forward_log_probs[start] + class TranslationScorer: - """Scores a translation using forced decoding and identifies low-probability words. + """Scores translations with contextual phrase rescoring and replacement suggestions. - For each target token y_t in the translation, this class computes the conditional - probability P(y_t | y_1, ..., y_{t-1}, x) where x is the source sentence. It then - groups subword tokens into words, flags words that fall below a probability threshold, - and provides top-k alternative suggestions from the model for each flagged word. + The left-to-right decoder supplies token log probabilities for the observed phrase. + To incorporate right-context evidence, each phrase is rescored against the observed + suffix using a contextual span objective: + + contextual(span) = log P(span | prefix, x) + + log P(suffix | prefix + span, x) + - log P(suffix | prefix, x) + + This acts like a pointwise contextual compatibility score. The same score, normalized + by phrase length in words, is used for both anomaly detection and for ranking + replacement candidates of different lengths. """ def __init__( @@ -79,183 +197,333 @@ def __init__( tokenizer: PreTrainedTokenizer, low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD, top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS, + max_phrase_words: int = DEFAULT_MAX_PHRASE_WORDS, + min_suggestion_improvement: float = DEFAULT_MIN_SUGGESTION_IMPROVEMENT, ): self._model = model self._tokenizer = tokenizer self._low_prob_threshold = low_prob_threshold self._top_k_suggestions = top_k_suggestions + self._max_phrase_words = max_phrase_words + self._min_suggestion_improvement = min_suggestion_improvement self._special_token_ids: Set[int] = set(tokenizer.all_special_ids) def score(self, source: str, translation: str) -> ScoredTranslation: - """Score each token in the translation using forced decoding. + source_context = self._encode_source(source) + cache: Dict[str, SequenceScore] = {} + base = self._score_translation_text(source_context, translation, cache) - For each target token y_t, computes P(y_t | y_1, ..., y_{t-1}, x). Low-probability - words are flagged and paired with top-k alternative suggestions from the model. + phrase_scores = self._build_phrase_scores(source_context, base, cache) + word_scores = [self._to_word_score(phrase_score) for phrase_score in phrase_scores if phrase_score.word_count == 1] + multi_word_phrase_scores = [phrase_score for phrase_score in phrase_scores if phrase_score.word_count > 1] - Args: - source: The source sentence. - translation: The translation to score. + return ScoredTranslation( + source=source, + translation=translation, + word_scores=word_scores, + phrase_scores=multi_word_phrase_scores, + ) - Returns: - A ScoredTranslation with per-word scores and suggestions for low-probability words. - """ - # Tokenize source - source_encoding = self._tokenizer(source, return_tensors="pt", truncation=True) + def _build_phrase_scores( + self, + source_context: Dict[str, torch.Tensor], + base: SequenceScore, + cache: Dict[str, SequenceScore], + ) -> List[PhraseScore]: + phrase_scores: List[PhraseScore] = [] + + for phrase_len in range(1, min(self._max_phrase_words, len(base.words)) + 1): + for start in range(0, len(base.words) - phrase_len + 1): + end = start + phrase_len + phrase_text = base.phrase_text(start, end) + forward_log_prob = base.word_forward_log_prob(start, end) + right_context_log_prob, ablated = self._score_right_context( + source_context, + base, + start, + end, + cache, + ) + contextual_log_prob = forward_log_prob + right_context_log_prob + span_word_scores = self._create_span_word_scores(base, start, end, right_context_log_prob) + phrase_score = PhraseScore( + phrase=phrase_text, + word_scores=span_word_scores, + forward_log_prob=forward_log_prob, + right_context_log_prob=right_context_log_prob, + contextual_log_prob=contextual_log_prob, + low_prob_threshold=self._low_prob_threshold, + ) + if phrase_score.is_low_probability: + phrase_score.suggestions = self._generate_scored_suggestions( + source_context, + base, + ablated, + phrase_score, + cache, + ) + phrase_scores.append(phrase_score) + # Reuse phrase-level suggestions for the singleton word scores. + for phrase_score in phrase_scores: + if phrase_score.word_count == 1: + continue + for word_score in phrase_score.word_scores: + if not word_score.suggestions and phrase_score.suggestions: + word_score.suggestions = phrase_score.suggestions + + return phrase_scores + + def _create_span_word_scores(self, base: SequenceScore, start: int, end: int, right_context_log_prob: float) -> List[WordScore]: + span = base.words[start:end] + span_forward = base.word_forward_log_prob(start, end) + span_contextual = span_forward + right_context_log_prob + word_count = max(end - start, 1) + span_share = right_context_log_prob / word_count + word_scores: List[WordScore] = [] + for index, word in enumerate(span, start): + tokens = [ + TokenScore(token=base.tokens[token_index], log_prob=base.token_log_probs[token_index]) + for token_index in range(word.token_start, word.token_end) + ] + word_forward = base.word_forward_log_prob(index, index + 1) + word_scores.append( + WordScore( + word=word.text, + tokens=tokens, + forward_log_prob=word_forward, + right_context_log_prob=span_share, + contextual_log_prob=(span_contextual / word_count if word_count > 1 else word_forward + right_context_log_prob), + span_start=index, + span_end=index + 1, + low_prob_threshold=self._low_prob_threshold, + ) + ) + return word_scores - # Tokenize target as labels - target_encoding = self._tokenizer(text_target=translation, return_tensors="pt", truncation=True) - labels = target_encoding["input_ids"] + def _to_word_score(self, phrase_score: PhraseScore) -> WordScore: + word_score = phrase_score.word_scores[0] + word_score.suggestions = phrase_score.suggestions + return word_score - # If the model forces a BOS token (e.g., a language code for NLLB/M2M100), - # prepend it to the labels so the decoder has the correct context for scoring. - forced_bos_token_id: Optional[int] = None - if hasattr(self._model, "generation_config") and self._model.generation_config is not None: - forced_bos_token_id = self._model.generation_config.forced_bos_token_id - if forced_bos_token_id is None and hasattr(self._model.config, "forced_bos_token_id"): - forced_bos_token_id = self._model.config.forced_bos_token_id + def _score_right_context( + self, + source_context: Dict[str, torch.Tensor], + base: SequenceScore, + start: int, + end: int, + cache: Dict[str, SequenceScore], + ) -> Tuple[float, SequenceScore]: + suffix_with_phrase = base.word_forward_log_prob(end, len(base.words)) + ablated_words = base.word_texts[:start] + base.word_texts[end:] + ablated_text = self._join_words(ablated_words) + ablated = self._score_translation_text(source_context, ablated_text, cache) + if end >= len(base.words): + return 0.0, ablated + suffix_without_phrase = ablated.word_forward_log_prob(start, len(ablated.words)) + return suffix_with_phrase - suffix_without_phrase, ablated + + def _generate_scored_suggestions( + self, + source_context: Dict[str, torch.Tensor], + base: SequenceScore, + ablated: SequenceScore, + phrase_score: PhraseScore, + cache: Dict[str, SequenceScore], + ) -> List[PhraseSuggestion]: + prefix_words = base.word_texts[: phrase_score.span_start] + suffix_words = base.word_texts[phrase_score.span_end :] + prefix_text = self._join_words(prefix_words) + candidate_phrases = self._generate_candidate_phrases(source_context, prefix_text) + suggestions: List[PhraseSuggestion] = [] + seen: Set[str] = set() + + for candidate_phrase in candidate_phrases: + normalized_candidate = candidate_phrase.strip() + if normalized_candidate == "" or normalized_candidate == phrase_score.phrase or normalized_candidate in seen: + continue + seen.add(normalized_candidate) + candidate_words = normalized_candidate.split() + variant_words = prefix_words + candidate_words + suffix_words + variant_text = self._join_words(variant_words) + variant = self._score_translation_text(source_context, variant_text, cache) + candidate_end = phrase_score.span_start + len(candidate_words) + forward_log_prob = variant.word_forward_log_prob(phrase_score.span_start, candidate_end) + suffix_with_candidate = variant.word_forward_log_prob(candidate_end, len(variant.words)) + suffix_without_phrase = ablated.word_forward_log_prob(phrase_score.span_start, len(ablated.words)) + right_context_log_prob = suffix_with_candidate - suffix_without_phrase + contextual_log_prob = forward_log_prob + right_context_log_prob + suggestion = PhraseSuggestion( + phrase=normalized_candidate, + forward_log_prob=forward_log_prob, + right_context_log_prob=right_context_log_prob, + contextual_log_prob=contextual_log_prob, + improvement=(contextual_log_prob / max(len(candidate_words), 1)) - phrase_score.normalized_log_prob, + ) + if suggestion.improvement >= self._min_suggestion_improvement: + suggestions.append(suggestion) - if forced_bos_token_id is not None and labels[0, 0].item() != forced_bos_token_id: - forced_bos = torch.tensor([[forced_bos_token_id]], dtype=labels.dtype) - labels = torch.cat([forced_bos, labels], dim=1) + suggestions.sort(key=lambda suggestion: (suggestion.normalized_log_prob, suggestion.improvement), reverse=True) + return suggestions[: self._top_k_suggestions] - # Move tensors to the model's device + def _generate_candidate_phrases(self, source_context: Dict[str, torch.Tensor], prefix_text: str) -> List[str]: + self._model.eval() + prefix_ids = self._build_decoder_prefix_ids(prefix_text).to(source_context["input_ids"].device) + num_beams = max(DEFAULT_GENERATION_BEAMS, self._top_k_suggestions * 2) + with torch.no_grad(): + outputs = self._model.generate( + input_ids=source_context["input_ids"], + attention_mask=source_context["attention_mask"], + decoder_input_ids=prefix_ids, + num_beams=num_beams, + num_return_sequences=num_beams, + max_new_tokens=DEFAULT_MAX_CANDIDATE_TOKENS, + early_stopping=True, + return_dict_in_generate=True, + forced_bos_token_id=None, + ) + + candidates: List[str] = [] + prefix_len = prefix_ids.shape[1] + for sequence in outputs.sequences: + continuation_ids = sequence[prefix_len:] + candidate_text = self._tokenizer.decode(continuation_ids, skip_special_tokens=True).strip() + if candidate_text == "": + continue + words = candidate_text.split() + max_words = min(self._max_phrase_words, len(words)) + for phrase_len in range(1, max_words + 1): + candidates.append(" ".join(words[:phrase_len])) + return candidates + + def _encode_source(self, source: str) -> Dict[str, torch.Tensor]: + source_encoding = self._tokenizer(source, return_tensors="pt", truncation=True) device = next(self._model.parameters()).device - input_ids = source_encoding["input_ids"].to(device) - attention_mask = source_encoding["attention_mask"].to(device) - labels_on_device = labels.to(device) + return { + "input_ids": source_encoding["input_ids"].to(device), + "attention_mask": source_encoding["attention_mask"].to(device), + } + + def _score_translation_text( + self, + source_context: Dict[str, torch.Tensor], + translation: str, + cache: Dict[str, SequenceScore], + ) -> SequenceScore: + cached = cache.get(translation) + if cached is not None: + return cached + + labels = self._prepare_target_labels(translation) + labels_on_device = labels.to(source_context["input_ids"].device) - # Run forward pass with teacher forcing to get logits at each position self._model.eval() with torch.no_grad(): outputs = self._model( - input_ids=input_ids, - attention_mask=attention_mask, + input_ids=source_context["input_ids"], + attention_mask=source_context["attention_mask"], labels=labels_on_device, ) - # Compute log probabilities from logits (use float32 for numerical precision) - logits = outputs.logits # shape: [1, seq_len, vocab_size] + logits = outputs.logits log_probs = F.log_softmax(logits.float(), dim=-1) - label_ids = labels[0].tolist() - tokens = [self._tokenizer.convert_ids_to_tokens(id) for id in label_ids] - token_log_probs = [log_probs[0, i, id].item() for i, id in enumerate(label_ids)] - - # Get top-k alternative token IDs at each position for suggestions - k = min(self._top_k_suggestions + 1, log_probs.shape[-1]) - top_k_ids = torch.topk(log_probs[0], k=k, dim=-1).indices.tolist() - - word_scores = self._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) - - return ScoredTranslation( - source=source, - translation=translation, - word_scores=word_scores, + tokens = [self._tokenizer.convert_ids_to_tokens(token_id) for token_id in label_ids] + token_log_probs = [log_probs[0, index, token_id].item() for index, token_id in enumerate(label_ids)] + words = self._group_tokens_into_words(tokens, token_log_probs, label_ids) + sequence = SequenceScore( + text=translation, + label_ids=label_ids, + tokens=tokens, + token_log_probs=token_log_probs, + words=words, ) + cache[translation] = sequence + return sequence - def _is_word_initial(self, token: str) -> bool: - """Return True if this subword token begins a new word.""" - return token.startswith("▁") or token.startswith("Ġ") - - def _decode_suggestion(self, token_id: int) -> Optional[str]: - """Decode a token ID to a display string, returning None for special tokens.""" - if token_id in self._special_token_ids: - return None - token = self._tokenizer.convert_ids_to_tokens(token_id) - if not token: - return None - # Strip SentencePiece (▁) or BPE (Ġ) word-initial markers - if token.startswith("▁"): - stripped = token.removeprefix("▁") - return stripped if stripped else None - if token.startswith("Ġ"): - stripped = token.removeprefix("Ġ") - return stripped if stripped else None - # Strip BERT-style continuation marker - if token.startswith("##"): - return None - return token + def _prepare_target_labels(self, translation: str) -> torch.Tensor: + target_encoding = self._tokenizer(text_target=translation, return_tensors="pt", truncation=True) + labels = target_encoding["input_ids"] + forced_bos_token_id = self._get_forced_bos_token_id() + if forced_bos_token_id is not None and (labels.shape[1] == 0 or labels[0, 0].item() != forced_bos_token_id): + forced_bos = torch.tensor([[forced_bos_token_id]], dtype=labels.dtype) + labels = torch.cat([forced_bos, labels], dim=1) + return labels + + def _build_decoder_prefix_ids(self, prefix_text: str) -> torch.Tensor: + if prefix_text.strip() == "": + return torch.tensor([[self._get_initial_decoder_token_id()]], dtype=torch.long) + prefix_labels = self._prepare_target_labels(prefix_text) + if prefix_labels.shape[1] > 0 and prefix_labels[0, -1].item() == self._tokenizer.eos_token_id: + prefix_labels = prefix_labels[:, :-1] + return prefix_labels + + def _get_initial_decoder_token_id(self) -> int: + forced_bos_token_id = self._get_forced_bos_token_id() + if forced_bos_token_id is not None: + return forced_bos_token_id + decoder_start_token_id = getattr(self._model.config, "decoder_start_token_id", None) + if decoder_start_token_id is not None: + return decoder_start_token_id + bos_token_id = getattr(self._tokenizer, "bos_token_id", None) + if bos_token_id is not None: + return bos_token_id + raise RuntimeError("Unable to determine the initial decoder token id for suggestion generation.") + + def _get_forced_bos_token_id(self) -> Optional[int]: + if hasattr(self._model, "generation_config") and self._model.generation_config is not None: + forced_bos_token_id = getattr(self._model.generation_config, "forced_bos_token_id", None) + if forced_bos_token_id is not None: + return forced_bos_token_id + return getattr(self._model.config, "forced_bos_token_id", None) def _group_tokens_into_words( self, tokens: List[str], token_log_probs: List[float], - top_k_ids: List[List[int]], label_ids: List[int], - ) -> List[WordScore]: - """Group subword tokens into words and compute per-word scores.""" - word_scores: List[WordScore] = [] - current_token_scores: List[TokenScore] = [] + ) -> List[WordSpan]: + words: List[WordSpan] = [] current_word_chars = "" - current_first_top_k: List[int] = [] + current_token_start: Optional[int] = None + current_token_end: Optional[int] = None - for token, log_prob, position_top_k, label_id in zip(tokens, token_log_probs, top_k_ids, label_ids): - # Skip special tokens (language codes, EOS, BOS, pad) + for token_index, (token, _log_prob, label_id) in enumerate(zip(tokens, token_log_probs, label_ids)): if label_id in self._special_token_ids: - if current_token_scores: - word_scores.append( - self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) - ) - current_token_scores = [] + if current_token_start is not None and current_token_end is not None: + words.append(WordSpan(current_word_chars, current_token_start, current_token_end)) current_word_chars = "" - current_first_top_k = [] + current_token_start = None + current_token_end = None continue - # A word-initial token (starts with ▁ or Ġ) begins a new word starts_new_word = self._is_word_initial(token) - if starts_new_word and current_token_scores: - word_scores.append( - self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) - ) - current_token_scores = [] + if starts_new_word and current_token_start is not None and current_token_end is not None: + words.append(WordSpan(current_word_chars, current_token_start, current_token_end)) current_word_chars = "" - current_first_top_k = [] - - # Record top-k alternatives only for the first subword of each word - if not current_token_scores: - current_first_top_k = position_top_k - - current_token_scores.append(TokenScore(token=token, log_prob=log_prob)) - - # Append the token's characters to the current word (strip the word-initial marker) - if token.startswith("▁"): - current_word_chars += token.removeprefix("▁") - elif token.startswith("Ġ"): - current_word_chars += token.removeprefix("Ġ") - elif token.startswith("##"): - current_word_chars += token.removeprefix("##") - else: - current_word_chars += token - - # Finalize the last word - if current_token_scores: - word_scores.append( - self._create_word_score(current_word_chars, current_token_scores, current_first_top_k) - ) + current_token_start = None + current_token_end = None - return word_scores + if current_token_start is None: + current_token_start = token_index + current_token_end = token_index + 1 + current_word_chars += self._token_to_word_piece(token) - def _create_word_score( - self, - word: str, - token_scores: List[TokenScore], - first_token_top_k_ids: List[int], - ) -> WordScore: - """Build a WordScore, adding suggestions when the word is low-probability.""" - word_log_prob = sum(t.log_prob for t in token_scores) - suggestions: List[str] = [] - - if word_log_prob < self._low_prob_threshold: - for top_id in first_token_top_k_ids: - suggestion = self._decode_suggestion(top_id) - if suggestion and suggestion.lower() != word.lower(): - suggestions.append(suggestion) - if len(suggestions) >= self._top_k_suggestions: - break - - return WordScore( - word=word, - tokens=token_scores, - suggestions=suggestions, - low_prob_threshold=self._low_prob_threshold, - ) + if current_token_start is not None and current_token_end is not None: + words.append(WordSpan(current_word_chars, current_token_start, current_token_end)) + + return words + + def _is_word_initial(self, token: str) -> bool: + return token.startswith("▁") or token.startswith("Ġ") + + def _token_to_word_piece(self, token: str) -> str: + if token.startswith("▁"): + return token.removeprefix("▁") + if token.startswith("Ġ"): + return token.removeprefix("Ġ") + if token.startswith("##"): + return token.removeprefix("##") + return token + + def _join_words(self, words: List[str]) -> str: + return " ".join(word for word in words if word != "") diff --git a/tests/smoke_tests/test_translation_scorer.py b/tests/smoke_tests/test_translation_scorer.py index e6d40f8e..559675dc 100644 --- a/tests/smoke_tests/test_translation_scorer.py +++ b/tests/smoke_tests/test_translation_scorer.py @@ -1,230 +1,138 @@ -"""Tests for the TranslationScorer class and related data structures.""" -from unittest.mock import MagicMock, patch - -import torch - -from silnlp.nmt.translation_scorer import ( - DEFAULT_LOW_PROB_THRESHOLD, - DEFAULT_TOP_K_SUGGESTIONS, - ScoredTranslation, - TokenScore, - TranslationScorer, - WordScore, -) - -_TINY_MODEL_NAME = "hf-internal-testing/tiny-random-nllb" - - -def _make_mock_model_and_tokenizer(): - """Create a mock model and tokenizer for testing.""" - from transformers import AutoModelForSeq2SeqLM, AutoTokenizer - from transformers.modeling_outputs import Seq2SeqLMOutput - - tokenizer = AutoTokenizer.from_pretrained(_TINY_MODEL_NAME) - tokenizer.src_lang = "eng_Latn" - tokenizer.tgt_lang = "fra_Latn" - - # Create a minimal mock model that returns deterministic logits - model = MagicMock() - vocab_size = len(tokenizer) - - # Return logits shaped [1, seq_len, vocab_size] with deterministic values - def mock_forward(input_ids=None, attention_mask=None, labels=None, **kwargs): - seq_len = labels.shape[1] if labels is not None else 1 - logits = torch.zeros(1, seq_len, vocab_size) - # Make the first token of the vocabulary highly probable except at position 0 - logits[:, :, 0] = -10.0 - logits[:, :, 1] = -1.0 # slightly low probability - return Seq2SeqLMOutput(logits=logits) - - model.side_effect = None - model.__call__ = MagicMock(side_effect=mock_forward) - model.parameters = MagicMock(return_value=iter([torch.zeros(1)])) - model.eval = MagicMock() - model.generation_config = MagicMock() - model.generation_config.forced_bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn") - model.config = MagicMock() - model.config.forced_bos_token_id = None - - return model, tokenizer - - -class TestTokenScore: - def test_prob_property(self): - ts = TokenScore(token="▁hello", log_prob=-1.0) - assert abs(ts.prob - 0.3679) < 1e-3 - - def test_zero_log_prob(self): - ts = TokenScore(token="▁hello", log_prob=0.0) - assert ts.prob == 1.0 - - -class TestWordScore: - def test_log_prob_is_sum_of_token_log_probs(self): - ws = WordScore( - word="hello", - tokens=[ - TokenScore(token="▁hel", log_prob=-0.5), - TokenScore(token="lo", log_prob=-0.3), - ], - ) - assert abs(ws.log_prob - (-0.8)) < 1e-6 - - def test_is_low_probability_default_threshold(self): - low = WordScore(word="low", tokens=[TokenScore(token="▁low", log_prob=-5.0)]) - high = WordScore(word="high", tokens=[TokenScore(token="▁high", log_prob=-1.0)]) - assert low.is_low_probability - assert not high.is_low_probability - - def test_is_low_probability_custom_threshold(self): - ws = WordScore( - word="test", - tokens=[TokenScore(token="▁test", log_prob=-2.0)], - low_prob_threshold=-1.0, - ) - assert ws.is_low_probability - - -class TestScoredTranslation: - def test_sequence_log_prob(self): - st = ScoredTranslation( - source="hello", - translation="bonjour", - word_scores=[ - WordScore(word="bonjour", tokens=[TokenScore(token="▁bonjour", log_prob=-0.5)]), - ], - ) - assert abs(st.sequence_log_prob - (-0.5)) < 1e-6 - - def test_low_probability_words(self): - st = ScoredTranslation( - source="hello world", - translation="bonjour monde", - word_scores=[ - WordScore(word="bonjour", tokens=[TokenScore(token="▁bonjour", log_prob=-1.0)]), - WordScore(word="monde", tokens=[TokenScore(token="▁monde", log_prob=-5.0)]), - ], - ) - low = st.low_probability_words - assert len(low) == 1 - assert low[0].word == "monde" - - -class TestTranslationScorer: - def test_score_returns_scored_translation(self): - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - result = scorer.score("hello world", "bonjour monde") - assert isinstance(result, ScoredTranslation) - assert result.source == "hello world" - assert result.translation == "bonjour monde" - assert isinstance(result.word_scores, list) - assert len(result.word_scores) > 0 - - def test_each_word_score_has_tokens(self): - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - result = scorer.score("hello world", "bonjour monde") - for ws in result.word_scores: - assert isinstance(ws, WordScore) - assert len(ws.tokens) > 0 - for ts in ws.tokens: - assert isinstance(ts, TokenScore) - - def test_low_prob_words_get_suggestions(self): - model, tokenizer = _make_mock_model_and_tokenizer() - # Use a threshold of 0.0 so all words are flagged as low-probability - scorer = TranslationScorer(model, tokenizer, low_prob_threshold=0.0) - result = scorer.score("hello world", "bonjour monde") - for ws in result.word_scores: - assert ws.is_low_probability - # Every flagged word should have suggestions (up to top_k) - assert len(ws.suggestions) <= DEFAULT_TOP_K_SUGGESTIONS - - def test_custom_top_k(self): - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer, low_prob_threshold=0.0, top_k_suggestions=2) - result = scorer.score("hello", "bonjour") - for ws in result.word_scores: - assert len(ws.suggestions) <= 2 - - def test_forced_bos_prepended_to_labels(self): - """Verify that forced_bos_token_id is prepended to labels when not already present.""" - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - - # Capture the labels passed to the model - captured_labels = [] - original_call = model.__call__.side_effect - - def capturing_call(**kwargs): - captured_labels.append(kwargs.get("labels")) - return original_call(**kwargs) - - model.__call__.side_effect = capturing_call - scorer.score("hello", "bonjour") - - assert len(captured_labels) == 1 - labels = captured_labels[0] - # The first token should be the forced BOS (language code) - forced_bos = tokenizer.convert_tokens_to_ids("fra_Latn") - assert labels[0, 0].item() == forced_bos - - -class TestWordGrouping: - def test_sentencepiece_tokens_grouped_correctly(self): - """Test that SentencePiece tokens (▁ prefix) are grouped into words.""" - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - - # Mock _group_tokens_into_words directly to test grouping logic - tokens = ["▁hello", "▁world", "!"] - token_log_probs = [-1.0, -2.0, -0.5] - label_ids = [101, 102, 103] # fake IDs, not in special_token_ids - top_k_ids = [[1, 2, 3, 4, 5, 6]] * 3 - - # Temporarily clear special token IDs to prevent any skipping - scorer._special_token_ids = set() - words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) - - # "hello" is one word, "world" + "!" are two words - # "!" does not start with ▁, so it continues "world" → "world!" - assert len(words) == 2 - assert words[0].word == "hello" - assert words[1].word == "world!" - - def test_continuation_tokens_merged(self): - """Test that tokens without ▁ are merged with the previous word.""" - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - scorer._special_token_ids = set() - - tokens = ["▁walk", "ing"] - token_log_probs = [-0.5, -0.3] - label_ids = [101, 102] - top_k_ids = [[1, 2, 3, 4, 5, 6]] * 2 - - words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) - assert len(words) == 1 - assert words[0].word == "walking" - assert len(words[0].tokens) == 2 - - def test_special_tokens_skipped(self): - """Special tokens (EOS, BOS, lang codes) should not appear as words.""" - model, tokenizer = _make_mock_model_and_tokenizer() - scorer = TranslationScorer(model, tokenizer) - - # Get some real special token IDs - eos_id = tokenizer.eos_token_id - special_ids = {eos_id} - scorer._special_token_ids = special_ids - - tokens = ["▁hello", tokenizer.eos_token] - token_log_probs = [-1.0, 0.0] - label_ids = [101, eos_id] - top_k_ids = [[1, 2, 3, 4, 5, 6]] * 2 - - words = scorer._group_tokens_into_words(tokens, token_log_probs, top_k_ids, label_ids) - # EOS token should be skipped; only "hello" remains - assert len(words) == 1 - assert words[0].word == "hello" +from types import SimpleNamespace +from unittest.mock import patch + +from silnlp.nmt.translation_scorer import PhraseScore, SequenceScore, TokenScore, TranslationScorer, WordSpan + + +class FakeTokenizer: + all_special_ids = [0, 1, 2] + eos_token_id = 1 + bos_token_id = 2 + + +class FakeModel: + def __init__(self): + self.generation_config = SimpleNamespace(forced_bos_token_id=2) + self.config = SimpleNamespace(forced_bos_token_id=None, decoder_start_token_id=2) + + def eval(self): + return None + + +def build_sequence(text: str, word_log_probs: list[float]) -> SequenceScore: + words = text.split() if text else [] + return SequenceScore( + text=text, + label_ids=list(range(10, 10 + len(words))), + tokens=[f"▁{word}" for word in words], + token_log_probs=word_log_probs, + words=[WordSpan(word, index, index + 1) for index, word in enumerate(words)], + ) + + +def make_scorer(**kwargs) -> TranslationScorer: + return TranslationScorer(FakeModel(), FakeTokenizer(), min_suggestion_improvement=0.0, **kwargs) + + +def test_contextual_word_scoring_uses_right_context_delta(): + scorer = make_scorer(low_prob_threshold=-3.0, max_phrase_words=2) + scores = { + "bleu maison vite": build_sequence("bleu maison vite", [-1.0, -5.0, -0.5]), + "maison vite": build_sequence("maison vite", [-0.5, -0.5]), + "bleu vite": build_sequence("bleu vite", [-1.0, -0.2]), + "vite": build_sequence("vite", [-0.2]), + } + + with ( + patch.object(scorer, "_encode_source", return_value={}), + patch.object(scorer, "_score_translation_text", side_effect=lambda _ctx, text, _cache: scores[text]), + patch.object(scorer, "_generate_candidate_phrases", return_value=[]), + ): + result = scorer.score("blue house quickly", "bleu maison vite") + + bleu = next(word for word in result.word_scores if word.word == "bleu") + maison = next(word for word in result.word_scores if word.word == "maison") + + assert bleu.forward_log_prob == -1.0 + assert bleu.right_context_log_prob == -4.5 + assert bleu.contextual_log_prob == -5.5 + assert bleu.is_low_probability + assert maison.contextual_log_prob < maison.forward_log_prob + + +def test_phrase_scores_are_length_normalized_for_variable_length_phrases(): + scorer = make_scorer(low_prob_threshold=-3.0, max_phrase_words=3) + scores = { + "bleu maison vite": build_sequence("bleu maison vite", [-1.0, -5.0, -0.5]), + "maison vite": build_sequence("maison vite", [-0.5, -0.5]), + "bleu vite": build_sequence("bleu vite", [-1.0, -0.2]), + "vite": build_sequence("vite", [-0.2]), + "bleu maison": build_sequence("bleu maison", [-1.0, -0.3]), + } + + with ( + patch.object(scorer, "_encode_source", return_value={}), + patch.object(scorer, "_score_translation_text", side_effect=lambda _ctx, text, _cache: scores[text]), + patch.object(scorer, "_generate_candidate_phrases", return_value=[]), + ): + result = scorer.score("blue house quickly", "bleu maison vite") + + phrase = next(score for score in result.phrase_scores if score.phrase == "bleu maison") + assert isinstance(phrase, PhraseScore) + assert phrase.forward_log_prob == -6.0 + assert phrase.right_context_log_prob == -0.3 + assert phrase.contextual_log_prob == -6.3 + assert phrase.normalized_log_prob == -3.15 + assert phrase.is_low_probability + + +def test_replacement_suggestions_are_rescored_with_right_context_and_can_change_length(): + scorer = make_scorer(low_prob_threshold=-3.0, top_k_suggestions=3, max_phrase_words=2) + scores = { + "bleu maison vite": build_sequence("bleu maison vite", [-1.0, -5.0, -0.5]), + "maison vite": build_sequence("maison vite", [-0.5, -0.5]), + "blue maison vite": build_sequence("blue maison vite", [-0.8, -0.2, -0.5]), + "blue house maison vite": build_sequence("blue house maison vite", [-0.7, -0.3, -0.3, -0.4]), + "azure maison vite": build_sequence("azure maison vite", [-1.5, -0.6, -0.8]), + "bleu vite": build_sequence("bleu vite", [-1.0, -0.2]), + "blue vite": build_sequence("blue vite", [-0.7, -0.2]), + "blue house vite": build_sequence("blue house vite", [-0.7, -0.4, -0.1]), + "azure vite": build_sequence("azure vite", [-1.3, -0.3]), + "vite": build_sequence("vite", [-0.2]), + } + + with ( + patch.object(scorer, "_encode_source", return_value={}), + patch.object(scorer, "_score_translation_text", side_effect=lambda _ctx, text, _cache: scores[text]), + patch.object(scorer, "_generate_candidate_phrases", return_value=["blue", "blue house", "azure"]), + ): + result = scorer.score("blue house quickly", "bleu maison vite") + + bleu = next(word for word in result.word_scores if word.word == "bleu") + phrase = next(score for score in result.phrase_scores if score.phrase == "bleu maison") + + assert [suggestion.phrase for suggestion in bleu.suggestions] == ["blue house", "blue", "azure"] + assert bleu.suggestions[0].normalized_log_prob > bleu.suggestions[1].normalized_log_prob + assert bleu.suggestions[0].improvement > 0 + + assert [suggestion.phrase for suggestion in phrase.suggestions] == ["blue house", "blue", "azure"] + assert phrase.suggestions[0].phrase == "blue house" + assert phrase.suggestions[0].normalized_log_prob > phrase.normalized_log_prob + + +def test_sequence_log_probability_still_tracks_forward_sentence_score(): + scorer = make_scorer(low_prob_threshold=-10.0, max_phrase_words=1) + scores = { + "bleu maison": build_sequence("bleu maison", [-1.0, -2.0]), + "maison": build_sequence("maison", [-0.5]), + } + + with ( + patch.object(scorer, "_encode_source", return_value={}), + patch.object(scorer, "_score_translation_text", side_effect=lambda _ctx, text, _cache: scores[text]), + patch.object(scorer, "_generate_candidate_phrases", return_value=[]), + ): + result = scorer.score("blue house", "bleu maison") + + assert result.sequence_log_prob == -3.0 + assert all(isinstance(token, TokenScore) for word in result.word_scores for token in word.tokens) From 59314491be56bc009c5d14ccfb33baa43f86af08 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 28 May 2026 16:16:24 +0000 Subject: [PATCH 5/5] Add NLLB translation scorer prototype web app --- pyproject.toml | 2 +- silnlp/nmt/translation_scorer_webapp.py | 479 ++++++++++++++++++ .../test_translation_scorer_webapp.py | 21 + 3 files changed, 501 insertions(+), 1 deletion(-) create mode 100644 silnlp/nmt/translation_scorer_webapp.py create mode 100644 tests/smoke_tests/test_translation_scorer_webapp.py diff --git a/pyproject.toml b/pyproject.toml index bf0f4987..092000d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ silnlp-nmt-preprocess = "silnlp.nmt.preprocess:main" silnlp-nmt-train = "silnlp.nmt.train:main" silnlp-nmt-test = "silnlp.nmt.test:main" silnlp-nmt-translate = "silnlp.nmt.translate:main" +silnlp-nmt-score-webapp = "silnlp.nmt.translation_scorer_webapp:main" silnlp-alignment-preprocess = "silnlp.alignment.preprocess:main" silnlp-alignment-align = "silnlp.alignment.align:main" @@ -122,4 +123,3 @@ eflomal = ["eflomal"] name = "torch" url = "https://download.pytorch.org/whl/cu121" priority = "explicit" - diff --git a/silnlp/nmt/translation_scorer_webapp.py b/silnlp/nmt/translation_scorer_webapp.py new file mode 100644 index 00000000..faf45ade --- /dev/null +++ b/silnlp/nmt/translation_scorer_webapp.py @@ -0,0 +1,479 @@ +import argparse +import json +import logging +import re +import threading +from dataclasses import dataclass +from html import escape +from http import HTTPStatus +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Dict, List, Sequence + +from silnlp.common.iso_info import NLLB_TAGS + +LOGGER = logging.getLogger(__name__) + +DEFAULT_MODEL_NAME = "facebook/nllb-200-distilled-600M" +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8080 +DEFAULT_LOW_PROB_THRESHOLD = -3.0 +DEFAULT_TOP_K_SUGGESTIONS = 5 +DEFAULT_SOURCE_LANG = "eng_Latn" +DEFAULT_TARGET_LANG = "fra_Latn" +SUPPORTED_LANGUAGES = sorted(set(NLLB_TAGS)) + + +@dataclass(frozen=True) +class _FlagCandidate: + text: str + span_start: int + span_end: int + normalized_log_prob: float + suggestions: List[Dict[str, float | str]] + kind: str + + @property + def length(self) -> int: + return self.span_end - self.span_start + + +def _word_char_spans(text: str) -> List[tuple[int, int]]: + return [(match.start(), match.end()) for match in re.finditer(r"\S+", text)] + + +def _build_suggestions(score: Any) -> List[Dict[str, float | str]]: + return [ + { + "phrase": suggestion.phrase, + "normalized_log_prob": suggestion.normalized_log_prob, + "improvement": suggestion.improvement, + } + for suggestion in score.suggestions + ] + + +def _collect_flag_candidates(scored: Any) -> List[_FlagCandidate]: + candidates: List[_FlagCandidate] = [] + for score in scored.low_probability_phrases: + candidates.append( + _FlagCandidate( + text=score.phrase, + span_start=score.span_start, + span_end=score.span_end, + normalized_log_prob=score.normalized_log_prob, + suggestions=_build_suggestions(score), + kind="phrase", + ) + ) + + for score in scored.low_probability_words: + candidates.append( + _FlagCandidate( + text=score.word, + span_start=score.span_start, + span_end=score.span_end, + normalized_log_prob=score.normalized_log_prob, + suggestions=_build_suggestions(score), + kind="word", + ) + ) + + return candidates + + +def _select_non_overlapping_flags(candidates: Sequence[_FlagCandidate]) -> List[_FlagCandidate]: + selected: List[_FlagCandidate] = [] + for candidate in sorted(candidates, key=lambda c: (c.span_start, -c.length, c.normalized_log_prob)): + overlaps = any( + candidate.span_start < selected_candidate.span_end and selected_candidate.span_start < candidate.span_end + for selected_candidate in selected + ) + if not overlaps: + selected.append(candidate) + return selected + + +def _format_flags(scored: Any) -> List[Dict[str, Any]]: + char_spans = _word_char_spans(scored.translation) + flags: List[Dict[str, Any]] = [] + + for index, candidate in enumerate(_select_non_overlapping_flags(_collect_flag_candidates(scored))): + if candidate.span_end <= 0 or candidate.span_end > len(char_spans): + continue + char_start = char_spans[candidate.span_start][0] + char_end = char_spans[candidate.span_end - 1][1] + flags.append( + { + "id": f"flag-{index}", + "text": candidate.text, + "kind": candidate.kind, + "span_start": candidate.span_start, + "span_end": candidate.span_end, + "char_start": char_start, + "char_end": char_end, + "normalized_log_prob": candidate.normalized_log_prob, + "suggestions": candidate.suggestions, + } + ) + + return flags + + +class NllbScoringService: + def __init__( + self, + model_name: str = DEFAULT_MODEL_NAME, + low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD, + top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS, + ): + self._model_name = model_name + self._low_prob_threshold = low_prob_threshold + self._top_k_suggestions = top_k_suggestions + self._lock = threading.Lock() + self._tokenizer = None + self._model = None + + def _ensure_model(self) -> None: + if self._model is not None and self._tokenizer is not None: + return + import torch + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + LOGGER.info("Loading model %s", self._model_name) + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(self._model_name) + model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + model.eval() + self._tokenizer = tokenizer + self._model = model + + def _configure_languages(self, source_lang: str, target_lang: str) -> None: + assert self._tokenizer is not None + assert self._model is not None + + if source_lang not in self._tokenizer.lang_code_to_id: + raise ValueError(f"Unsupported source language: {source_lang}") + if target_lang not in self._tokenizer.lang_code_to_id: + raise ValueError(f"Unsupported target language: {target_lang}") + + self._tokenizer.src_lang = source_lang + self._tokenizer.tgt_lang = target_lang + forced_bos_token_id = self._tokenizer.lang_code_to_id[target_lang] + self._model.config.forced_bos_token_id = forced_bos_token_id + if getattr(self._model, "generation_config", None) is not None: + self._model.generation_config.forced_bos_token_id = forced_bos_token_id + + def score(self, source: str, translation: str, source_lang: str, target_lang: str) -> Dict[str, Any]: + source = source.strip() + translation = translation.strip() + if source == "" or translation == "": + return {"source": source, "translation": translation, "flags": []} + + with self._lock: + self._ensure_model() + self._configure_languages(source_lang, target_lang) + assert self._model is not None + assert self._tokenizer is not None + from silnlp.nmt.translation_scorer import TranslationScorer + + scorer = TranslationScorer( + self._model, + self._tokenizer, + low_prob_threshold=self._low_prob_threshold, + top_k_suggestions=self._top_k_suggestions, + ) + scored = scorer.score(source, translation) + return { + "source": source, + "translation": scored.translation, + "flags": _format_flags(scored), + } + + +def _html_template() -> str: + languages_json = json.dumps(SUPPORTED_LANGUAGES) + default_source = escape(DEFAULT_SOURCE_LANG) + default_target = escape(DEFAULT_TARGET_LANG) + return f""" + + + + NLLB Translation Scorer Prototype + + + +

NLLB 600M Translation Scorer Prototype

+
+
+ + + + +
+
+ + + + + +
+
Click a highlighted phrase to view alternatives.
+
+
+
Waiting for input.
+ + + +""" + + +class TranslationScorerHttpHandler(BaseHTTPRequestHandler): + scorer_service: NllbScoringService = NllbScoringService() + + def _send_json(self, body: Dict[str, Any], status: int = HTTPStatus.OK) -> None: + response = json.dumps(body).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json; charset=utf-8") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def _send_html(self, body: str, status: int = HTTPStatus.OK) -> None: + response = body.encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def do_GET(self) -> None: # noqa: N802 + if self.path in ("/", "/index.html"): + self._send_html(_html_template()) + return + if self.path == "/api/languages": + self._send_json({"languages": SUPPORTED_LANGUAGES}) + return + self._send_json({"error": "Not found"}, status=HTTPStatus.NOT_FOUND) + + def do_POST(self) -> None: # noqa: N802 + if self.path != "/api/score": + self._send_json({"error": "Not found"}, status=HTTPStatus.NOT_FOUND) + return + + content_length = int(self.headers.get("Content-Length", "0")) + if content_length <= 0: + self._send_json({"error": "Request body is required"}, status=HTTPStatus.BAD_REQUEST) + return + + payload = json.loads(self.rfile.read(content_length).decode("utf-8")) + source = payload.get("source", "") + translation = payload.get("translation", "") + source_lang = payload.get("source_lang", DEFAULT_SOURCE_LANG) + target_lang = payload.get("target_lang", DEFAULT_TARGET_LANG) + + try: + scored = self.scorer_service.score(source, translation, source_lang, target_lang) + except ValueError as error: + self._send_json({"error": str(error)}, status=HTTPStatus.BAD_REQUEST) + return + except Exception: + LOGGER.exception("Failed to score translation") + self._send_json({"error": "Scoring failed"}, status=HTTPStatus.INTERNAL_SERVER_ERROR) + return + + self._send_json(scored) + + def log_message(self, format: str, *args: Any) -> None: + LOGGER.info("%s - %s", self.address_string(), format % args) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run a prototype web app for NLLB translation scoring.") + parser.add_argument("--host", type=str, default=DEFAULT_HOST, help=f"Host to bind (default: {DEFAULT_HOST})") + parser.add_argument("--port", type=int, default=DEFAULT_PORT, help=f"Port to bind (default: {DEFAULT_PORT})") + parser.add_argument( + "--low-prob-threshold", + type=float, + default=DEFAULT_LOW_PROB_THRESHOLD, + help=f"Contextual log-probability threshold for highlights (default: {DEFAULT_LOW_PROB_THRESHOLD})", + ) + parser.add_argument( + "--top-k-suggestions", + type=int, + default=DEFAULT_TOP_K_SUGGESTIONS, + help=f"Number of alternatives to return per highlight (default: {DEFAULT_TOP_K_SUGGESTIONS})", + ) + parser.add_argument( + "--model-name", + type=str, + default=DEFAULT_MODEL_NAME, + help=f"Model name or path (default: {DEFAULT_MODEL_NAME})", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + TranslationScorerHttpHandler.scorer_service = NllbScoringService( + model_name=args.model_name, + low_prob_threshold=args.low_prob_threshold, + top_k_suggestions=args.top_k_suggestions, + ) + + server = ThreadingHTTPServer((args.host, args.port), TranslationScorerHttpHandler) + LOGGER.info("Serving translation scorer prototype at http://%s:%s", args.host, args.port) + try: + server.serve_forever() + except KeyboardInterrupt: + LOGGER.info("Stopping server") + finally: + server.server_close() + + +if __name__ == "__main__": + main() diff --git a/tests/smoke_tests/test_translation_scorer_webapp.py b/tests/smoke_tests/test_translation_scorer_webapp.py new file mode 100644 index 00000000..940c3c38 --- /dev/null +++ b/tests/smoke_tests/test_translation_scorer_webapp.py @@ -0,0 +1,21 @@ +from silnlp.nmt.translation_scorer_webapp import _FlagCandidate, _select_non_overlapping_flags, _word_char_spans + + +def test_word_char_spans_returns_non_whitespace_ranges(): + assert _word_char_spans("bleu maison vite") == [(0, 4), (7, 13), (14, 18)] + + +def test_select_non_overlapping_flags_prefers_longer_phrases_at_same_start(): + candidates = [ + _FlagCandidate("bleu", 0, 1, -4.0, [], "word"), + _FlagCandidate("bleu maison", 0, 2, -3.8, [], "phrase"), + _FlagCandidate("maison", 1, 2, -3.6, [], "word"), + _FlagCandidate("vite", 2, 3, -4.5, [], "word"), + ] + + selected = _select_non_overlapping_flags(candidates) + + assert [(candidate.text, candidate.span_start, candidate.span_end) for candidate in selected] == [ + ("bleu maison", 0, 2), + ("vite", 2, 3), + ]