Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ silnlp-nmt-preprocess = "silnlp.nmt.preprocess:main"
silnlp-nmt-train = "silnlp.nmt.train:main"
silnlp-nmt-test = "silnlp.nmt.test:main"
silnlp-nmt-translate = "silnlp.nmt.translate:main"
silnlp-nmt-score-webapp = "silnlp.nmt.translation_scorer_webapp:main"

silnlp-alignment-preprocess = "silnlp.alignment.preprocess:main"
silnlp-alignment-align = "silnlp.alignment.align:main"
Expand Down Expand Up @@ -122,4 +123,3 @@ eflomal = ["eflomal"]
name = "torch"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"

17 changes: 16 additions & 1 deletion silnlp/nmt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from enum import Enum, auto
from pathlib import Path
from statistics import mean, median, stdev
from typing import Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast

if TYPE_CHECKING:
from .translation_scorer import ScoredTranslation

import pandas as pd
from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
Expand Down Expand Up @@ -99,6 +102,18 @@ def clear_cache(self) -> None: ...
@abstractmethod
def get_num_drafts(self) -> int: ...

@abstractmethod
def score_translation(
    self,
    source: str,
    translation: str,
    src_iso: str,
    trg_iso: str,
    ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
    low_prob_threshold: float = -3.0,
    top_k_suggestions: int = 5,
) -> "ScoredTranslation":
    """Score an existing translation of a source sentence with this model.

    Abstract; concrete model classes provide the implementation.

    Args:
        source: The source sentence.
        translation: The translation to score.
        src_iso: Source language ISO code.
        trg_iso: Target language ISO code.
        ckpt: Checkpoint to use for the model.
        low_prob_threshold: Log-probability threshold below which a word is flagged.
        top_k_suggestions: Number of alternative suggestions per low-probability word.

    Returns:
        A ScoredTranslation describing the scored translation.
    """
    ...


class Config(ABC):
def __init__(self, exp_dir: Path, config: dict) -> None:
Expand Down
51 changes: 50 additions & 1 deletion silnlp/nmt/hugging_face_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from itertools import repeat
from math import prod
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast

if TYPE_CHECKING:
from .translation_scorer import ScoredTranslation

import datasets.utils.logging as datasets_logging
import evaluate
Expand Down Expand Up @@ -1406,6 +1409,52 @@ def clear_cache(self) -> None:
self._cached_inference_model = None
self._inference_model_params = None

def score_translation(
    self,
    source: str,
    translation: str,
    src_iso: str,
    trg_iso: str,
    ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
    low_prob_threshold: float = -3.0,
    top_k_suggestions: int = 5,
) -> "ScoredTranslation":
    """Score a translation against its source via forced decoding.

    Each target token of the translation is assigned its conditional probability
    P(y_t | y_1, ..., y_{t-1}, x). Subword tokens are grouped into words, words with
    a low probability are flagged, and top-k alternatives are suggested for them.

    The inference model is cached on this instance and is only rebuilt when the
    requested checkpoint or language pair differs from the cached one.

    Args:
        source: The source sentence.
        translation: The translation to score.
        src_iso: Source language ISO code.
        trg_iso: Target language ISO code.
        ckpt: Checkpoint to use for the model.
        low_prob_threshold: Log-probability threshold below which a word is flagged.
        top_k_suggestions: Number of alternative suggestions per low-probability word.

    Returns:
        A ScoredTranslation with per-word scores and suggestions.
    """
    from .translation_scorer import ScoredTranslation, TranslationScorer

    # ISO codes without an entry in the configured mapping are used unchanged.
    lang_codes = self._config.data["lang_codes"]
    src_lang = lang_codes.get(src_iso, src_iso)
    trg_lang = lang_codes.get(trg_iso, trg_iso)

    requested_params = InferenceModelParams(ckpt, src_lang, trg_lang)
    tokenizer = self._config.get_tokenizer()

    cache_hit = self._inference_model_params == requested_params and self._cached_inference_model is not None
    if not cache_hit:
        # The model is built with the unwrapped tokenizer; wrapping happens afterwards.
        self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
        self._inference_model_params = requested_params
    model = self._cached_inference_model

    if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
        tokenizer = PunctuationNormalizingTokenizer(tokenizer)

    return TranslationScorer(model, tokenizer, low_prob_threshold, top_k_suggestions).score(source, translation)

def _create_training_arguments(self) -> Seq2SeqTrainingArguments:
parser = HfArgumentParser(Seq2SeqTrainingArguments)
args: dict = {}
Expand Down
140 changes: 140 additions & 0 deletions silnlp/nmt/score_translations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import argparse
import logging
from typing import Optional

from .config_utils import load_config
from .translation_scorer import DEFAULT_LOW_PROB_THRESHOLD, DEFAULT_TOP_K_SUGGESTIONS, PhraseScore, ScoredTranslation

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)


def _format_suggestions(phrase_score: PhraseScore) -> str:
    """Render the suggestions of a phrase score as one ``"; "``-separated string.

    Each suggestion is shown as its phrase followed by its normalized log-probability
    and its improvement, both with four decimals. Returns an empty string when the
    phrase score has no suggestions.
    """
    candidates = phrase_score.suggestions
    rendered = []
    if candidates:
        for candidate in candidates:
            rendered.append(
                f"{candidate.phrase} (ctx/word={candidate.normalized_log_prob:.4f}, Δ={candidate.improvement:.4f})"
            )
    return "; ".join(rendered)


def format_scored_translation(scored: ScoredTranslation) -> str:
    """Format a ScoredTranslation as a human-readable, multi-line report.

    The report contains, in order:

    * the source, the translation and the sequence log-probability;
    * a table with one row per entry of ``scored.word_scores`` showing the forward,
      right-context and normalized log-probabilities plus the suggested phrases
      (rows whose ``is_low_probability`` is set are marked with ``*``);
    * the entries of ``scored.low_probability_phrases`` with their spans, scores and
      detailed suggestions, or ``None`` when there are no such phrases.

    Args:
        scored: The scored translation to render.

    Returns:
        The report as a single string with lines joined by newlines.
    """
    lines = [
        f"Source: {scored.source}",
        f"Translation: {scored.translation}",
        f"Sequence log-probability: {scored.sequence_log_prob:.4f}",
        "",
        "Word-level contextual scores:",
    ]

    # Every table line starts with a gutter of the same width: the flag marks a
    # low-probability word, the blank gutter keeps all other lines aligned with it.
    flagged_gutter = "* "
    blank_gutter = " " * len(flagged_gutter)

    # The word column must be at least as wide as its heading, otherwise the heading
    # overflows and pushes the remaining header columns out of alignment.
    col_word = max(len("Word"), max((len(score.word) for score in scored.word_scores), default=0))
    header = (
        f"{blank_gutter}{'Word':<{col_word}} {'Forward':>10} {'Right Ctx':>10} {'Ctx/Word':>10} Suggestions"
    )
    lines.append(header)
    # The rule spans the header text only, so gutter + rule is exactly as long as the header.
    lines.append(blank_gutter + "-" * (len(header) - len(blank_gutter)))
    for score in scored.word_scores:
        gutter = flagged_gutter if score.is_low_probability else blank_gutter
        suggestion_text = ", ".join(suggestion.phrase for suggestion in score.suggestions)
        lines.append(
            f"{gutter}{score.word:<{col_word}} {score.forward_log_prob:>10.4f} "
            f"{score.right_context_log_prob:>10.4f} {score.normalized_log_prob:>10.4f} {suggestion_text}"
        )

    lines.append("")
    lines.append("Low-probability phrases:")
    low_phrases = scored.low_probability_phrases
    if not low_phrases:
        lines.append(" None")
    else:
        for score in low_phrases:
            lines.append(
                f" '{score.phrase}' [{score.span_start}:{score.span_end}] "
                f"forward={score.forward_log_prob:.4f}, right_ctx={score.right_context_log_prob:.4f}, "
                f"ctx/word={score.normalized_log_prob:.4f}"
            )
            suggestions = _format_suggestions(score)
            # Phrases without suggestions get no extra line.
            if suggestions:
                lines.append(f" Suggestions: {suggestions}")

    return "\n".join(lines)


def score_translation(
    experiment: str,
    source: str,
    translation: str,
    src_iso: Optional[str],
    trg_iso: Optional[str],
    checkpoint: str = "last",
    low_prob_threshold: float = DEFAULT_LOW_PROB_THRESHOLD,
    top_k_suggestions: int = DEFAULT_TOP_K_SUGGESTIONS,
) -> ScoredTranslation:
    """Score a translation with the model of the given experiment.

    Loads the experiment configuration, creates its model and delegates to the
    model's ``score_translation``.

    Args:
        experiment: Experiment name passed to ``load_config``.
        source: The source sentence.
        translation: The translation to score.
        src_iso: Source language ISO code; when empty or None, the configuration's
            ``default_test_src_iso`` is used.
        trg_iso: Target language ISO code; when empty or None, the configuration's
            ``default_test_trg_iso`` is used.
        checkpoint: Checkpoint forwarded to the model as ``ckpt``.
        low_prob_threshold: Threshold forwarded to the model.
        top_k_suggestions: Number of suggestions forwarded to the model.

    Returns:
        The ScoredTranslation produced by the model.
    """
    experiment_config = load_config(experiment)
    nmt_model = experiment_config.create_model()

    # Fall back to the experiment's default test languages when none were given.
    if not src_iso:
        src_iso = experiment_config.default_test_src_iso
    if not trg_iso:
        trg_iso = experiment_config.default_test_trg_iso

    scored = nmt_model.score_translation(
        source=source,
        translation=translation,
        src_iso=src_iso,
        trg_iso=trg_iso,
        ckpt=checkpoint,
        low_prob_threshold=low_prob_threshold,
        top_k_suggestions=top_k_suggestions,
    )
    return scored


def main() -> None:
    """Command-line entry point: score one translation and print the formatted report.

    Parses the experiment name and the scoring options from the command line, calls
    ``score_translation`` and prints the result of ``format_scored_translation``.
    """
    description = (
        "Score a translation by computing contextual phrase probabilities. "
        "Low-probability words and phrases are flagged and paired with rescored replacement suggestions."
    )
    threshold_help = (
        f"Threshold on contextual log-probability per word for flagging low-probability spans "
        f"(default: {DEFAULT_LOW_PROB_THRESHOLD})"
    )
    suggestions_help = f"Number of replacement suggestions per flagged span (default: {DEFAULT_TOP_K_SUGGESTIONS})"

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("experiment", help="Experiment name")

    # Mandatory text inputs.
    for option, help_text in (
        ("--source", "Source sentence to score against"),
        ("--translation", "Translation to evaluate"),
    ):
        parser.add_argument(option, type=str, required=True, help=help_text)

    # Optional language codes.
    for option, help_text in (
        ("--src-iso", "Source language ISO code"),
        ("--trg-iso", "Target language ISO code"),
    ):
        parser.add_argument(option, type=str, default=None, help=help_text)

    parser.add_argument(
        "--checkpoint",
        type=str,
        default="last",
        help="Checkpoint to use: 'last', 'best', 'avg', or a checkpoint step number",
    )
    parser.add_argument("--low-prob-threshold", type=float, default=DEFAULT_LOW_PROB_THRESHOLD, help=threshold_help)
    parser.add_argument("--top-k-suggestions", type=int, default=DEFAULT_TOP_K_SUGGESTIONS, help=suggestions_help)

    cli_args = parser.parse_args()

    result = score_translation(
        experiment=cli_args.experiment,
        source=cli_args.source,
        translation=cli_args.translation,
        src_iso=cli_args.src_iso,
        trg_iso=cli_args.trg_iso,
        checkpoint=cli_args.checkpoint,
        low_prob_threshold=cli_args.low_prob_threshold,
        top_k_suggestions=cli_args.top_k_suggestions,
    )
    report = format_scored_translation(result)
    print(report)


# Run the command-line entry point when this module is executed as a script.
if __name__ == "__main__":
    main()
Loading