diff --git a/pyproject.toml b/pyproject.toml
index bf0f4987..092000d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,8 @@ silnlp-nmt-preprocess = "silnlp.nmt.preprocess:main"
 silnlp-nmt-train = "silnlp.nmt.train:main"
 silnlp-nmt-test = "silnlp.nmt.test:main"
 silnlp-nmt-translate = "silnlp.nmt.translate:main"
+silnlp-nmt-score-translations = "silnlp.nmt.score_translations:main"
+silnlp-nmt-score-webapp = "silnlp.nmt.translation_scorer_webapp:main"
 silnlp-alignment-preprocess = "silnlp.alignment.preprocess:main"
 silnlp-alignment-align = "silnlp.alignment.align:main"
 
@@ -122,4 +124,3 @@ eflomal = ["eflomal"]
 name = "torch"
 url = "https://download.pytorch.org/whl/cu121"
 priority = "explicit"
-
diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py
index 751f19ff..19d440ed 100644
--- a/silnlp/nmt/config.py
+++ b/silnlp/nmt/config.py
@@ -8,7 +8,10 @@
 from enum import Enum, auto
 from pathlib import Path
 from statistics import mean, median, stdev
-from typing import Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
+
+if TYPE_CHECKING:
+    from .translation_scorer import ScoredTranslation
 
 import pandas as pd
 from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
@@ -99,6 +102,26 @@ def clear_cache(self) -> None: ...
     @abstractmethod
     def get_num_drafts(self) -> int: ...
 
+    @abstractmethod
+    def score_translation(
+        self,
+        source: str,
+        translation: str,
+        src_iso: str,
+        trg_iso: str,
+        ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
+        low_prob_threshold: float = -3.0,
+        top_k_suggestions: int = 5,
+    ) -> "ScoredTranslation":
+        """Score ``translation`` as a rendering of ``source`` using the model at ``ckpt``.
+
+        Spans whose per-word contextual log-probability falls below ``low_prob_threshold``
+        are flagged and receive at most ``top_k_suggestions`` replacement suggestions.
+        The literal defaults mirror ``DEFAULT_LOW_PROB_THRESHOLD`` and
+        ``DEFAULT_TOP_K_SUGGESTIONS`` in ``translation_scorer``; keep them in sync.
+        """
+        ...
+
 
 class Config(ABC):
     def __init__(self, exp_dir: Path, config: dict) -> None:
diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index 0399de9c..1457adfb 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -11,7 +11,10 @@
 from itertools import repeat
 from math import prod
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
+
+if TYPE_CHECKING:
+    from .translation_scorer import ScoredTranslation
 
 import datasets.utils.logging as datasets_logging
 import evaluate
@@ -1406,6 +1409,54 @@ def clear_cache(self) -> None:
         self._cached_inference_model = None
         self._inference_model_params = None
 
+    def score_translation(
+        self,
+        source: str,
+        translation: str,
+        src_iso: str,
+        trg_iso: str,
+        ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
+        low_prob_threshold: float = -3.0,
+        top_k_suggestions: int = 5,
+    ) -> "ScoredTranslation":
+        """Score a translation using forced decoding.
+
+        Computes the conditional probability P(y_t | y_1, ..., y_{t-1}, x) for each
+        target token in the translation, groups subword tokens into words, flags
+        low-probability words, and provides top-k alternative suggestions for them.
+
+        Args:
+            source: The source sentence.
+            translation: The translation to score.
+            src_iso: Source language ISO code.
+            trg_iso: Target language ISO code.
+            ckpt: Checkpoint to use for the model.
+            low_prob_threshold: Log-probability threshold below which a word is flagged.
+            top_k_suggestions: Number of alternative suggestions per low-probability word.
+
+        Returns:
+            A ScoredTranslation with per-word scores and suggestions.
+        """
+        # Imported lazily; the return annotation is resolved through the TYPE_CHECKING import.
+        from .translation_scorer import TranslationScorer
+
+        src_lang = self._config.data["lang_codes"].get(src_iso, src_iso)
+        trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso)
+        inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang)
+        tokenizer = self._config.get_tokenizer()
+        # Reuse the cached inference model when the checkpoint and language pair are unchanged.
+        if self._inference_model_params == inference_model_params and self._cached_inference_model is not None:
+            model = self._cached_inference_model
+        else:
+            model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
+            self._inference_model_params = inference_model_params
+
+        if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+            tokenizer = PunctuationNormalizingTokenizer(tokenizer)
+
+        scorer = TranslationScorer(model, tokenizer, low_prob_threshold, top_k_suggestions)
+        return scorer.score(source, translation)
+
     def _create_training_arguments(self) -> Seq2SeqTrainingArguments:
         parser = HfArgumentParser(Seq2SeqTrainingArguments)
         args: dict = {}
diff --git a/silnlp/nmt/score_translations.py b/silnlp/nmt/score_translations.py
new file mode 100644
index 00000000..419201ac
--- /dev/null
+++ b/silnlp/nmt/score_translations.py
@@ -0,0 +1,147 @@
+import argparse
+import logging
+from typing import Optional
+
+from .config_utils import load_config
+from .translation_scorer import DEFAULT_LOW_PROB_THRESHOLD, DEFAULT_TOP_K_SUGGESTIONS, PhraseScore, ScoredTranslation
+
+LOGGER = logging.getLogger(__name__)
+
+
+def _format_suggestions(phrase_score: PhraseScore) -> str:
+    """Render a phrase's suggestions as one "; "-separated line (empty string when there are none)."""
+    if not phrase_score.suggestions:
+        return ""
+    return "; ".join(
+        f"{suggestion.phrase} (ctx/word={suggestion.normalized_log_prob:.4f}, Δ={suggestion.improvement:.4f})"
+        for suggestion in phrase_score.suggestions
+    )
+
+
+def format_scored_translation(scored: ScoredTranslation) -> str:
+    """Format a ScoredTranslation as a human-readable string."""
+    lines = [
+        f"Source: {scored.source}",
+        f"Translation: {scored.translation}",
+        f"Sequence log-probability: {scored.sequence_log_prob:.4f}",
+        "",
+        "Word-level contextual scores:",
+    ]
+
+    # Keep the column at least as wide as its "Word" heading so short words cannot misalign the table.
+    col_word = max([len("Word"), *(len(score.word) for score in scored.word_scores)])
+    header = (
+        f"  {'Word':<{col_word}}  {'Forward':>10}  {'Right Ctx':>10}  {'Ctx/Word':>10}  Suggestions"
+    )
+    lines.append(header)
+    lines.append("  " + "-" * (len(header) - 2))
+    for score in scored.word_scores:
+        flag = "* " if score.is_low_probability else "  "
+        suggestion_text = ", ".join(suggestion.phrase for suggestion in score.suggestions)
+        lines.append(
+            f"{flag}{score.word:<{col_word}}  {score.forward_log_prob:>10.4f}  "
+            f"{score.right_context_log_prob:>10.4f}  {score.normalized_log_prob:>10.4f}  {suggestion_text}"
+        )
+
+    lines.append("")
+    lines.append("Low-probability phrases:")
+    low_phrases = scored.low_probability_phrases
+    if not low_phrases:
+        lines.append("  None")
+    else:
+        for score in low_phrases:
+            lines.append(
+                f"  '{score.phrase}' [{score.span_start}:{score.span_end}] "
+                f"forward={score.forward_log_prob:.4f}, right_ctx={score.right_context_log_prob:.4f}, "
+                f"ctx/word={score.normalized_log_prob:.4f}"
+            )
+            suggestions = _format_suggestions(score)
+            if suggestions:
+                lines.append(f"    Suggestions: {suggestions}")
+
+    return "\n".join(lines)
+
+
+def score_translation(
+    experiment: str,
+    source: str,
+    translation: str,
+    src_iso: Optional[str],
+    trg_iso: Optional[str],
+    checkpoint: str = "last",
+    low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD,
+    top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS,
+) -> ScoredTranslation:
+    """Score ``translation`` against ``source`` with the model of ``experiment``.
+
+    ``src_iso`` and ``trg_iso`` fall back to the experiment's default test languages when
+    omitted; ``checkpoint`` is handed to the model unchanged.
+    """
+    config = load_config(experiment)
+    model = config.create_model()
+
+    effective_src_iso = src_iso or config.default_test_src_iso
+    effective_trg_iso = trg_iso or config.default_test_trg_iso
+
+    return model.score_translation(
+        source=source,
+        translation=translation,
+        src_iso=effective_src_iso,
+        trg_iso=effective_trg_iso,
+        ckpt=checkpoint,
+        low_prob_threshold=low_prob_threshold,
+        top_k_suggestions=top_k_suggestions,
+    )
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description=(
+            "Score a translation by computing contextual phrase probabilities. "
+            "Low-probability words and phrases are flagged and paired with rescored replacement suggestions."
+        )
+    )
+    parser.add_argument("experiment", help="Experiment name")
+    parser.add_argument("--source", type=str, required=True, help="Source sentence to score against")
+    parser.add_argument("--translation", type=str, required=True, help="Translation to evaluate")
+    parser.add_argument("--src-iso", type=str, default=None, help="Source language ISO code")
+    parser.add_argument("--trg-iso", type=str, default=None, help="Target language ISO code")
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default="last",
+        help="Checkpoint to use: 'last', 'best', 'avg', or a checkpoint step number",
+    )
+    parser.add_argument(
+        "--low-prob-threshold",
+        type=float,
+        default=DEFAULT_LOW_PROB_THRESHOLD,
+        help=(
+            f"Threshold on contextual log-probability per word for flagging low-probability spans "
+            f"(default: {DEFAULT_LOW_PROB_THRESHOLD})"
+        ),
+    )
+    parser.add_argument(
+        "--top-k-suggestions",
+        type=int,
+        default=DEFAULT_TOP_K_SUGGESTIONS,
+        help=f"Number of replacement suggestions per flagged span (default: {DEFAULT_TOP_K_SUGGESTIONS})",
+    )
+
+    args = parser.parse_args()
+
+    scored = score_translation(
+        experiment=args.experiment,
+        source=args.source,
+        translation=args.translation,
+        src_iso=args.src_iso,
+        trg_iso=args.trg_iso,
+        checkpoint=args.checkpoint,
+        low_prob_threshold=args.low_prob_threshold,
+        top_k_suggestions=args.top_k_suggestions,
+    )
+    print(format_scored_translation(scored))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/silnlp/nmt/translation_scorer.py b/silnlp/nmt/translation_scorer.py
new file mode 100644
index 00000000..06aca129
--- /dev/null
+++ b/silnlp/nmt/translation_scorer.py
@@ -0,0 +1,552 @@
+import logging
+from dataclasses import dataclass, field
+from math import exp
+from typing import Dict, List, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+LOGGER = logging.getLogger(__name__)
+
+DEFAULT_LOW_PROB_THRESHOLD = -3.0
+DEFAULT_TOP_K_SUGGESTIONS = 5
+DEFAULT_MAX_PHRASE_WORDS = 3
+DEFAULT_MIN_SUGGESTION_IMPROVEMENT = 0.25
+DEFAULT_GENERATION_BEAMS = 8
+DEFAULT_MAX_CANDIDATE_TOKENS = 12
+
+
+@dataclass
+class TokenScore:
+    """Score information for a single subword token."""
+
+    token: str
+    log_prob: float
+
+    @property
+    def prob(self) -> float:
+        return exp(self.log_prob)
+
+
+@dataclass
+class PhraseSuggestion:
+    """A replacement phrase suggestion scored with both left and right context."""
+
+    phrase: str
+    forward_log_prob: float
+    right_context_log_prob: float
+    contextual_log_prob: float
+    improvement: float
+
+    @property
+    def word_count(self) -> int:
+        return max(len(self.phrase.split()), 1)
+
+    @property
+    def normalized_log_prob(self) -> float:
+        return self.contextual_log_prob / self.word_count
+
+
+@dataclass
+class WordScore:
+    """Context-aware score information for a single word."""
+
+    word: str
+    tokens: List[TokenScore]
+    forward_log_prob: float
+    right_context_log_prob: float
+    contextual_log_prob: float
+    span_start: int
+    span_end: int
+    suggestions: List[PhraseSuggestion] = field(default_factory=list)
+    low_prob_threshold: float = field(default=DEFAULT_LOW_PROB_THRESHOLD, compare=False, repr=False)
+
+    @property
+    def word_count(self) -> int:
+        return 1
+
+    @property
+    def log_prob(self) -> float:
+        return self.contextual_log_prob
+
+    @property
+    def normalized_log_prob(self) -> float:
+        return self.contextual_log_prob
+
+    @property
+    def prob(self) -> float:
+        return exp(self.contextual_log_prob)
+
+    @property
+    def is_low_probability(self) -> bool:
+        return self.normalized_log_prob < self.low_prob_threshold
+
+
+@dataclass
+class PhraseScore:
+    """Context-aware score for a variable-length phrase span."""
+
+    phrase: str
+    word_scores: List[WordScore]
+    forward_log_prob: float
+    right_context_log_prob: float
+    contextual_log_prob: float
+    suggestions: List[PhraseSuggestion] = field(default_factory=list)
+    low_prob_threshold: float = field(default=DEFAULT_LOW_PROB_THRESHOLD, compare=False, repr=False)
+
+    @property
+    def span_start(self) -> int:
+        return self.word_scores[0].span_start
+
+    @property
+    def span_end(self) -> int:
+        return self.word_scores[-1].span_end
+
+    @property
+    def word_count(self) -> int:
+        return len(self.word_scores)
+
+    @property
+    def normalized_log_prob(self) -> float:
+        return self.contextual_log_prob / max(self.word_count, 1)
+
+    @property
+    def is_low_probability(self) -> bool:
+        return self.normalized_log_prob < self.low_prob_threshold
+
+
+@dataclass
+class ScoredTranslation:
+    """Result of scoring a translation against a source sentence."""
+
+    source: str
+    translation: str
+    word_scores: List[WordScore]
+    phrase_scores: List[PhraseScore]
+
+    @property
+    def sequence_log_prob(self) -> float:
+        return sum(w.forward_log_prob for w in self.word_scores)
+
+    @property
+    def low_probability_words(self) -> List[WordScore]:
+        return [w for w in self.word_scores if w.is_low_probability]
+
+    @property
+    def low_probability_phrases(self) -> List[PhraseScore]:
+        return [p for p in self.phrase_scores if p.is_low_probability]
+
+
+@dataclass
+class WordSpan:
+    """A word and the half-open range of token positions that spell it."""
+
+    text: str
+    token_start: int
+    token_end: int
+
+
+@dataclass
+class SequenceScore:
+    """Token- and word-level forward log-probabilities of one force-decoded target text."""
+
+    text: str
+    label_ids: List[int]
+    tokens: List[str]
+    token_log_probs: List[float]
+    words: List[WordSpan]
+
+    def __post_init__(self) -> None:
+        self._word_forward_log_probs = [
+            sum(self.token_log_probs[word.token_start : word.token_end]) for word in self.words
+        ]
+        cumulative = [0.0]
+        total = 0.0
+        for value in self._word_forward_log_probs:
+            total += value
+            cumulative.append(total)
+        self._cumulative_word_forward_log_probs = cumulative
+
+    @property
+    def word_texts(self) -> List[str]:
+        return [word.text for word in self.words]
+
+    def phrase_text(self, start: int, end: int) -> str:
+        return " ".join(self.word_texts[start:end])
+
+    def word_forward_log_prob(self, start: int, end: int) -> float:
+        return self._cumulative_word_forward_log_probs[end] - self._cumulative_word_forward_log_probs[start]
+
+
+class TranslationScorer:
+    """Scores translations with contextual phrase rescoring and replacement suggestions.
+
+    The left-to-right decoder supplies token log probabilities for the observed phrase.
+    To incorporate right-context evidence, each phrase is rescored against the observed
+    suffix using a contextual span objective:
+
+        contextual(span) = log P(span | prefix, x)
+                         + log P(suffix | prefix + span, x)
+                         - log P(suffix | prefix, x)
+
+    This acts like a pointwise contextual compatibility score. The same score, normalized
+    by phrase length in words, is used for both anomaly detection and for ranking
+    replacement candidates of different lengths.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD,
+        top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS,
+        max_phrase_words: int = DEFAULT_MAX_PHRASE_WORDS,
+        min_suggestion_improvement: float = DEFAULT_MIN_SUGGESTION_IMPROVEMENT,
+    ):
+        self._model = model
+        self._tokenizer = tokenizer
+        self._low_prob_threshold = low_prob_threshold
+        self._top_k_suggestions = top_k_suggestions
+        self._max_phrase_words = max_phrase_words
+        self._min_suggestion_improvement = min_suggestion_improvement
+        self._special_token_ids: Set[int] = set(tokenizer.all_special_ids)
+
+    def score(self, source: str, translation: str) -> ScoredTranslation:
+        """Score ``translation`` against ``source``; one-word spans become word scores, longer ones phrase scores."""
+        source_context = self._encode_source(source)
+        cache: Dict[str, SequenceScore] = {}
+        base = self._score_translation_text(source_context, translation, cache)
+
+        phrase_scores = self._build_phrase_scores(source_context, base, cache)
+        word_scores = [self._to_word_score(phrase_score) for phrase_score in phrase_scores if phrase_score.word_count == 1]
+        multi_word_phrase_scores = [phrase_score for phrase_score in phrase_scores if phrase_score.word_count > 1]
+
+        return ScoredTranslation(
+            source=source,
+            translation=translation,
+            word_scores=word_scores,
+            phrase_scores=multi_word_phrase_scores,
+        )
+
+    def _build_phrase_scores(
+        self,
+        source_context: Dict[str, torch.Tensor],
+        base: SequenceScore,
+        cache: Dict[str, SequenceScore],
+    ) -> List[PhraseScore]:
+        """Score every span of 1..max_phrase_words consecutive words; low-probability spans get suggestions."""
+        phrase_scores: List[PhraseScore] = []
+
+        for phrase_len in range(1, min(self._max_phrase_words, len(base.words)) + 1):
+            for start in range(0, len(base.words) - phrase_len + 1):
+                end = start + phrase_len
+                phrase_text = base.phrase_text(start, end)
+                forward_log_prob = base.word_forward_log_prob(start, end)
+                right_context_log_prob, ablated = self._score_right_context(
+                    source_context,
+                    base,
+                    start,
+                    end,
+                    cache,
+                )
+                contextual_log_prob = forward_log_prob + right_context_log_prob
+                span_word_scores = self._create_span_word_scores(base, start, end, right_context_log_prob)
+                phrase_score = PhraseScore(
+                    phrase=phrase_text,
+                    word_scores=span_word_scores,
+                    forward_log_prob=forward_log_prob,
+                    right_context_log_prob=right_context_log_prob,
+                    contextual_log_prob=contextual_log_prob,
+                    low_prob_threshold=self._low_prob_threshold,
+                )
+                if phrase_score.is_low_probability:
+                    phrase_score.suggestions = self._generate_scored_suggestions(
+                        source_context,
+                        base,
+                        ablated,
+                        phrase_score,
+                        cache,
+                    )
+                phrase_scores.append(phrase_score)
+        # Reuse phrase-level suggestions for the singleton word scores.
+        for phrase_score in phrase_scores:
+            if phrase_score.word_count == 1:
+                continue
+            for word_score in phrase_score.word_scores:
+                if not word_score.suggestions and phrase_score.suggestions:
+                    word_score.suggestions = phrase_score.suggestions
+
+        return phrase_scores
+
+    def _create_span_word_scores(self, base: SequenceScore, start: int, end: int, right_context_log_prob: float) -> List[WordScore]:
+        """Build per-word scores for words[start:end], splitting the span's right-context score evenly."""
+        span = base.words[start:end]
+        span_forward = base.word_forward_log_prob(start, end)
+        span_contextual = span_forward + right_context_log_prob
+        word_count = max(end - start, 1)
+        span_share = right_context_log_prob / word_count
+        word_scores: List[WordScore] = []
+        for index, word in enumerate(span, start):
+            tokens = [
+                TokenScore(token=base.tokens[token_index], log_prob=base.token_log_probs[token_index])
+                for token_index in range(word.token_start, word.token_end)
+            ]
+            word_forward = base.word_forward_log_prob(index, index + 1)
+            word_scores.append(
+                WordScore(
+                    word=word.text,
+                    tokens=tokens,
+                    forward_log_prob=word_forward,
+                    right_context_log_prob=span_share,
+                    contextual_log_prob=(span_contextual / word_count if word_count > 1 else word_forward + right_context_log_prob),
+                    span_start=index,
+                    span_end=index + 1,
+                    low_prob_threshold=self._low_prob_threshold,
+                )
+            )
+        return word_scores
+
+    def _to_word_score(self, phrase_score: PhraseScore) -> WordScore:
+        word_score = phrase_score.word_scores[0]
+        word_score.suggestions = phrase_score.suggestions
+        return word_score
+
+    def _score_right_context(
+        self,
+        source_context: Dict[str, torch.Tensor],
+        base: SequenceScore,
+        start: int,
+        end: int,
+        cache: Dict[str, SequenceScore],
+    ) -> Tuple[float, SequenceScore]:
+        """Return (suffix log-prob with the span minus without it, the rescored span-removed sequence)."""
+        suffix_with_phrase = base.word_forward_log_prob(end, len(base.words))
+        ablated_words = base.word_texts[:start] + base.word_texts[end:]
+        ablated_text = self._join_words(ablated_words)
+        ablated = self._score_translation_text(source_context, ablated_text, cache)
+        if end >= len(base.words):
+            return 0.0, ablated
+        suffix_without_phrase = ablated.word_forward_log_prob(start, len(ablated.words))
+        return suffix_with_phrase - suffix_without_phrase, ablated
+
+    def _generate_scored_suggestions(
+        self,
+        source_context: Dict[str, torch.Tensor],
+        base: SequenceScore,
+        ablated: SequenceScore,
+        phrase_score: PhraseScore,
+        cache: Dict[str, SequenceScore],
+    ) -> List[PhraseSuggestion]:
+        """Rescore generated replacements for the span and keep the best ones that improve on it enough."""
+        prefix_words = base.word_texts[: phrase_score.span_start]
+        suffix_words = base.word_texts[phrase_score.span_end :]
+        prefix_text = self._join_words(prefix_words)
+        candidate_phrases = self._generate_candidate_phrases(source_context, prefix_text)
+        suggestions: List[PhraseSuggestion] = []
+        seen: Set[str] = set()
+
+        for candidate_phrase in candidate_phrases:
+            normalized_candidate = candidate_phrase.strip()
+            if normalized_candidate == "" or normalized_candidate == phrase_score.phrase or normalized_candidate in seen:
+                continue
+            seen.add(normalized_candidate)
+            candidate_words = normalized_candidate.split()
+            variant_words = prefix_words + candidate_words + suffix_words
+            variant_text = self._join_words(variant_words)
+            variant = self._score_translation_text(source_context, variant_text, cache)
+            candidate_end = phrase_score.span_start + len(candidate_words)
+            forward_log_prob = variant.word_forward_log_prob(phrase_score.span_start, candidate_end)
+            suffix_with_candidate = variant.word_forward_log_prob(candidate_end, len(variant.words))
+            suffix_without_phrase = ablated.word_forward_log_prob(phrase_score.span_start, len(ablated.words))
+            right_context_log_prob = suffix_with_candidate - suffix_without_phrase
+            contextual_log_prob = forward_log_prob + right_context_log_prob
+            suggestion = PhraseSuggestion(
+                phrase=normalized_candidate,
+                forward_log_prob=forward_log_prob,
+                right_context_log_prob=right_context_log_prob,
+                contextual_log_prob=contextual_log_prob,
+                improvement=(contextual_log_prob / max(len(candidate_words), 1)) - phrase_score.normalized_log_prob,
+            )
+            if suggestion.improvement >= self._min_suggestion_improvement:
+                suggestions.append(suggestion)
+
+        suggestions.sort(key=lambda suggestion: (suggestion.normalized_log_prob, suggestion.improvement), reverse=True)
+        return suggestions[: self._top_k_suggestions]
+
+    def _generate_candidate_phrases(self, source_context: Dict[str, torch.Tensor], prefix_text: str) -> List[str]:
+        """Beam-search continuations of ``prefix_text`` and return their leading 1..max_phrase_words words."""
+        self._model.eval()
+        prefix_ids = self._build_decoder_prefix_ids(prefix_text).to(source_context["input_ids"].device)
+        num_beams = max(DEFAULT_GENERATION_BEAMS, self._top_k_suggestions * 2)
+        with torch.no_grad():
+            outputs = self._model.generate(
+                input_ids=source_context["input_ids"],
+                attention_mask=source_context["attention_mask"],
+                decoder_input_ids=prefix_ids,
+                num_beams=num_beams,
+                num_return_sequences=num_beams,
+                max_new_tokens=DEFAULT_MAX_CANDIDATE_TOKENS,
+                early_stopping=True,
+                return_dict_in_generate=True,
+                forced_bos_token_id=None,
+            )
+
+        candidates: List[str] = []
+        prefix_len = prefix_ids.shape[1]
+        sequences = outputs.sequences
+        # generate() prepends decoder_start_token_id when the supplied decoder prefix does not begin
+        # with it, which moves the prefix one position to the right in the returned sequences.
+        if (
+            sequences.shape[1] > prefix_len
+            and not torch.equal(sequences[0, :prefix_len], prefix_ids[0])
+            and torch.equal(sequences[0, 1 : prefix_len + 1], prefix_ids[0])
+        ):
+            prefix_len += 1
+        for sequence in sequences:
+            continuation_ids = sequence[prefix_len:]
+            candidate_text = self._tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
+            if candidate_text == "":
+                continue
+            words = candidate_text.split()
+            max_words = min(self._max_phrase_words, len(words))
+            for phrase_len in range(1, max_words + 1):
+                candidates.append(" ".join(words[:phrase_len]))
+        return candidates
+
+    def _encode_source(self, source: str) -> Dict[str, torch.Tensor]:
+        source_encoding = self._tokenizer(source, return_tensors="pt", truncation=True)
+        device = next(self._model.parameters()).device
+        return {
+            "input_ids": source_encoding["input_ids"].to(device),
+            "attention_mask": source_encoding["attention_mask"].to(device),
+        }
+
+    def _score_translation_text(
+        self,
+        source_context: Dict[str, torch.Tensor],
+        translation: str,
+        cache: Dict[str, SequenceScore],
+    ) -> SequenceScore:
+        """Force-decode ``translation`` and return its per-token log-probabilities, memoised in ``cache``."""
+        cached = cache.get(translation)
+        if cached is not None:
+            return cached
+
+        labels = self._prepare_target_labels(translation)
+        labels_on_device = labels.to(source_context["input_ids"].device)
+
+        self._model.eval()
+        with torch.no_grad():
+            outputs = self._model(
+                input_ids=source_context["input_ids"],
+                attention_mask=source_context["attention_mask"],
+                labels=labels_on_device,
+            )
+
+        logits = outputs.logits
+        log_probs = F.log_softmax(logits.float(), dim=-1)
+        label_ids = labels[0].tolist()
+        tokens = [self._tokenizer.convert_ids_to_tokens(token_id) for token_id in label_ids]
+        token_log_probs = [log_probs[0, index, token_id].item() for index, token_id in enumerate(label_ids)]
+        words = self._group_tokens_into_words(tokens, token_log_probs, label_ids)
+        sequence = SequenceScore(
+            text=translation,
+            label_ids=label_ids,
+            tokens=tokens,
+            token_log_probs=token_log_probs,
+            words=words,
+        )
+        cache[translation] = sequence
+        return sequence
+
+    def _prepare_target_labels(self, translation: str) -> torch.Tensor:
+        """Tokenize ``translation`` as target text, prepending the forced BOS token when configured but absent."""
+        target_encoding = self._tokenizer(text_target=translation, return_tensors="pt", truncation=True)
+        labels = target_encoding["input_ids"]
+        forced_bos_token_id = self._get_forced_bos_token_id()
+        if forced_bos_token_id is not None and (labels.shape[1] == 0 or labels[0, 0].item() != forced_bos_token_id):
+            forced_bos = torch.tensor([[forced_bos_token_id]], dtype=labels.dtype)
+            labels = torch.cat([forced_bos, labels], dim=1)
+        return labels
+
+    def _build_decoder_prefix_ids(self, prefix_text: str) -> torch.Tensor:
+        """Return decoder ids for ``prefix_text`` without a trailing EOS, or only the initial token if blank."""
+        if prefix_text.strip() == "":
+            return torch.tensor([[self._get_initial_decoder_token_id()]], dtype=torch.long)
+        prefix_labels = self._prepare_target_labels(prefix_text)
+        if prefix_labels.shape[1] > 0 and prefix_labels[0, -1].item() == self._tokenizer.eos_token_id:
+            prefix_labels = prefix_labels[:, :-1]
+        return prefix_labels
+
+    def _get_initial_decoder_token_id(self) -> int:
+        forced_bos_token_id = self._get_forced_bos_token_id()
+        if forced_bos_token_id is not None:
+            return forced_bos_token_id
+        decoder_start_token_id = getattr(self._model.config, "decoder_start_token_id", None)
+        if decoder_start_token_id is not None:
+            return decoder_start_token_id
+        bos_token_id = getattr(self._tokenizer, "bos_token_id", None)
+        if bos_token_id is not None:
+            return bos_token_id
+        raise RuntimeError("Unable to determine the initial decoder token id for suggestion generation.")
+
+    def _get_forced_bos_token_id(self) -> Optional[int]:
+        if hasattr(self._model, "generation_config") and self._model.generation_config is not None:
+            forced_bos_token_id = getattr(self._model.generation_config, "forced_bos_token_id", None)
+            if forced_bos_token_id is not None:
+                return forced_bos_token_id
+        return getattr(self._model.config, "forced_bos_token_id", None)
+
+    def _group_tokens_into_words(
+        self,
+        tokens: List[str],
+        token_log_probs: List[float],
+        label_ids: List[int],
+    ) -> List[WordSpan]:
+        """Group consecutive non-special tokens into words, starting a new word at each word-initial marker."""
+        words: List[WordSpan] = []
+        current_word_chars = ""
+        current_token_start: Optional[int] = None
+        current_token_end: Optional[int] = None
+
+        for token_index, (token, _log_prob, label_id) in enumerate(zip(tokens, token_log_probs, label_ids)):
+            if label_id in self._special_token_ids:
+                if current_token_start is not None and current_token_end is not None:
+                    words.append(WordSpan(current_word_chars, current_token_start, current_token_end))
+                    current_word_chars = ""
+                    current_token_start = None
+                    current_token_end = None
+                continue
+
+            starts_new_word = self._is_word_initial(token)
+            if starts_new_word and current_token_start is not None and current_token_end is not None:
+                words.append(WordSpan(current_word_chars, current_token_start, current_token_end))
+                current_word_chars = ""
+                current_token_start = None
+                current_token_end = None
+
+            if current_token_start is None:
+                current_token_start = token_index
+            current_token_end = token_index + 1
+            current_word_chars += self._token_to_word_piece(token)
+
+        if current_token_start is not None and current_token_end is not None:
+            words.append(WordSpan(current_word_chars, current_token_start, current_token_end))
+
+        return words
+
+    def _is_word_initial(self, token: str) -> bool:
+        return token.startswith("▁") or token.startswith("Ġ")
+
+    def _token_to_word_piece(self, token: str) -> str:
+        if token.startswith("▁"):
+            return token.removeprefix("▁")
+        if token.startswith("Ġ"):
+            return token.removeprefix("Ġ")
+        if token.startswith("##"):
+            return token.removeprefix("##")
+        return token
+
+    def _join_words(self, words: List[str]) -> str:
+        return " ".join(word for word in words if word != "")
diff --git a/silnlp/nmt/translation_scorer_webapp.py b/silnlp/nmt/translation_scorer_webapp.py
new file mode 100644
index 00000000..faf45ade
--- /dev/null
+++ b/silnlp/nmt/translation_scorer_webapp.py
@@ -0,0 +1,501 @@
+import argparse
+import json
+import logging
+import re
+import threading
+from dataclasses import dataclass
+from html import escape
+from http import HTTPStatus
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from typing import Any, Dict, List, Sequence
+
+from silnlp.common.iso_info import NLLB_TAGS
+
+LOGGER = logging.getLogger(__name__)
+
+DEFAULT_MODEL_NAME = "facebook/nllb-200-distilled-600M"
+DEFAULT_HOST = "127.0.0.1"
+DEFAULT_PORT = 8080
+DEFAULT_LOW_PROB_THRESHOLD = -3.0
+DEFAULT_TOP_K_SUGGESTIONS = 5
+DEFAULT_SOURCE_LANG = "eng_Latn"
+DEFAULT_TARGET_LANG = "fra_Latn"
+SUPPORTED_LANGUAGES = sorted(set(NLLB_TAGS))
+
+
+@dataclass(frozen=True)
+class _FlagCandidate:
+    text: str
+    span_start: int
+    span_end: int
+    normalized_log_prob: float
+    suggestions: List[Dict[str, float | str]]
+    kind: str
+
+    @property
+    def length(self) -> int:
+        return self.span_end - self.span_start
+
+
+def _word_char_spans(text: str) -> List[tuple[int, int]]:
+    """Return the (start, end) character offsets of each whitespace-delimited word in ``text``."""
+    return [(match.start(), match.end()) for match in re.finditer(r"\S+", text)]
+
+
+def _build_suggestions(score: Any) -> List[Dict[str, float | str]]:
+    return [
+        {
+            "phrase": suggestion.phrase,
+            "normalized_log_prob": suggestion.normalized_log_prob,
+            "improvement": suggestion.improvement,
+        }
+        for suggestion in score.suggestions
+    ]
+
+
+def _collect_flag_candidates(scored: Any) -> List[_FlagCandidate]:
+    candidates: List[_FlagCandidate] = []
+    for score in scored.low_probability_phrases:
+        candidates.append(
+            _FlagCandidate(
+                text=score.phrase,
+                span_start=score.span_start,
+                span_end=score.span_end,
+                normalized_log_prob=score.normalized_log_prob,
+                suggestions=_build_suggestions(score),
+                kind="phrase",
+            )
+        )
+
+    for score in scored.low_probability_words:
+        candidates.append(
+            _FlagCandidate(
+                text=score.word,
+                span_start=score.span_start,
+                span_end=score.span_end,
+                normalized_log_prob=score.normalized_log_prob,
+                suggestions=_build_suggestions(score),
+                kind="word",
+            )
+        )
+
+    return candidates
+
+
+def _select_non_overlapping_flags(candidates: Sequence[_FlagCandidate]) -> List[_FlagCandidate]:
+    """Greedily keep flags by start, then longer span, then lower score; drop any that overlap a kept flag."""
+    selected: List[_FlagCandidate] = []
+    for candidate in sorted(candidates, key=lambda c: (c.span_start, -c.length, c.normalized_log_prob)):
+        overlaps = any(
+            candidate.span_start < selected_candidate.span_end and selected_candidate.span_start < candidate.span_end
+            for selected_candidate in selected
+        )
+        if not overlaps:
+            selected.append(candidate)
+    return selected
+
+
+def _format_flags(scored: Any) -> List[Dict[str, Any]]:
+    """Convert the selected flags to JSON-ready dicts carrying both word and character offsets."""
+    char_spans = _word_char_spans(scored.translation)
+    flags: List[Dict[str, Any]] = []
+
+    for index, candidate in enumerate(_select_non_overlapping_flags(_collect_flag_candidates(scored))):
+        if candidate.span_end <= 0 or candidate.span_end > len(char_spans):
+            continue
+        char_start = char_spans[candidate.span_start][0]
+        char_end = char_spans[candidate.span_end - 1][1]
+        flags.append(
+            {
+                "id": f"flag-{index}",
+                "text": candidate.text,
+                "kind": candidate.kind,
+                "span_start": candidate.span_start,
+                "span_end": candidate.span_end,
+                "char_start": char_start,
+                "char_end": char_end,
+                "normalized_log_prob": candidate.normalized_log_prob,
+                "suggestions": candidate.suggestions,
+            }
+        )
+
+    return flags
+
+
+class NllbScoringService:
+    """Lazily loads a seq2seq model and scores translations with it, one request at a time."""
+
+    def __init__(
+        self,
+        model_name: str = DEFAULT_MODEL_NAME,
+        low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD,
+        top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS,
+    ):
+        self._model_name = model_name
+        self._low_prob_threshold = low_prob_threshold
+        self._top_k_suggestions = top_k_suggestions
+        self._lock = threading.Lock()
+        self._tokenizer = None
+        self._model = None
+
+    def _ensure_model(self) -> None:
+        """Load the tokenizer and model on first use, placing the model on CUDA when available."""
+        if self._model is not None and self._tokenizer is not None:
+            return
+        import torch
+        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+        LOGGER.info("Loading model %s", self._model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self._model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(self._model_name)
+        model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        model.eval()
+        self._tokenizer = tokenizer
+        self._model = model
+
+    def _configure_languages(self, source_lang: str, target_lang: str) -> None:
+        """Point the tokenizer and the model's forced BOS token at the requested language pair.
+
+        Raises:
+            ValueError: If either language code is unknown to the tokenizer.
+        """
+        assert self._tokenizer is not None
+        assert self._model is not None
+
+        # Prefer the legacy ``lang_code_to_id`` table when the tokenizer still has one; otherwise
+        # resolve the language tag through the vocabulary, where an unknown tag maps to <unk>.
+        lang_code_to_id = getattr(self._tokenizer, "lang_code_to_id", None)
+        lang_ids: Dict[str, int] = {}
+        for role, lang in (("source", source_lang), ("target", target_lang)):
+            if lang_code_to_id is not None:
+                lang_id = lang_code_to_id.get(lang)
+            else:
+                lang_id = self._tokenizer.convert_tokens_to_ids(lang)
+                if lang_id == self._tokenizer.unk_token_id:
+                    lang_id = None
+            if lang_id is None:
+                raise ValueError(f"Unsupported {role} language: {lang}")
+            lang_ids[role] = lang_id
+
+        self._tokenizer.src_lang = source_lang
+        self._tokenizer.tgt_lang = target_lang
+        forced_bos_token_id = lang_ids["target"]
+        self._model.config.forced_bos_token_id = forced_bos_token_id
+        if getattr(self._model, "generation_config", None) is not None:
+            self._model.generation_config.forced_bos_token_id = forced_bos_token_id
+
+    def score(self, source: str, translation: str, source_lang: str, target_lang: str) -> Dict[str, Any]:
+        """Score ``translation`` against ``source`` and return the flagged spans; blank input yields no flags."""
+        source = source.strip()
+        translation = translation.strip()
+        if source == "" or translation == "":
+            return {"source": source, "translation": translation, "flags": []}
+
+        with self._lock:
+            self._ensure_model()
+            self._configure_languages(source_lang, target_lang)
+            assert self._model is not None
+            assert self._tokenizer is not None
+            from silnlp.nmt.translation_scorer import TranslationScorer
+
+            scorer = TranslationScorer(
+                self._model,
+                self._tokenizer,
+                low_prob_threshold=self._low_prob_threshold,
+                top_k_suggestions=self._top_k_suggestions,
+            )
+            scored = scorer.score(source, translation)
+        return {
+            "source": source,
+            "translation": scored.translation,
+            "flags": _format_flags(scored),
+        }
+
+
+def _html_template() -> str:
+    languages_json = json.dumps(SUPPORTED_LANGUAGES)
+    default_source = escape(DEFAULT_SOURCE_LANG)
+    default_target = escape(DEFAULT_TARGET_LANG)
+    return f"""
+
+
+ +