From c492ac13eb540ab139556eeacad7aaa7a91aa13c Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Sun, 26 Apr 2026 19:55:32 -0700 Subject: [PATCH 1/2] feat: add TokenAligner and cross-tokenizer projection utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundational library code for cross-tokenizer distillation. No algorithm or training-loop integration yet — those follow in subsequent PRs. - nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with Numba-accelerated DP alignment, projection-matrix loading (dense and sparse COO), and the project_token_likelihoods_instance forward path used by the cross-tokenizer loss. - nemo_rl/algorithms/x_token/__init__.py: package init. - nemo_rl/utils/x_token/{minimal_projection_generator, minimal_projection_via_multitoken,reapply_exact_map, sort_and_cut_projection_matrix}.py: standalone CLI scripts (argparse-driven, __main__ entrypoints) for one-time projection-matrix preparation. Not on the training import path. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/algorithms/x_token/__init__.py | 13 + nemo_rl/algorithms/x_token/tokenalign.py | 2194 +++++++++++++++++ nemo_rl/utils/x_token/__init__.py | 13 + .../x_token/minimal_projection_generator.py | 585 +++++ .../minimal_projection_via_multitoken.py | 942 +++++++ nemo_rl/utils/x_token/reapply_exact_map.py | 246 ++ .../x_token/sort_and_cut_projection_matrix.py | 451 ++++ 7 files changed, 4444 insertions(+) create mode 100644 nemo_rl/algorithms/x_token/__init__.py create mode 100644 nemo_rl/algorithms/x_token/tokenalign.py create mode 100644 nemo_rl/utils/x_token/__init__.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_generator.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_via_multitoken.py create mode 100644 nemo_rl/utils/x_token/reapply_exact_map.py create mode 100644 nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py diff --git a/nemo_rl/algorithms/x_token/__init__.py b/nemo_rl/algorithms/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/algorithms/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py new file mode 100644 index 0000000000..1d55f858f7 --- /dev/null +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -0,0 +1,2194 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoTokenizer + +try: + from numba import njit + _NUMBA_AVAILABLE = True +except ImportError: + _NUMBA_AVAILABLE = False + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +if _NUMBA_AVAILABLE: + @njit(cache=True) + def _dp_core_numba(ids1, ids2, joined1, joined2, n1, n2, + exact_match_score, gap_penalty, comb_mul, max_comb_len): + """Numba-accelerated DP core for token alignment. + + Uses the same algorithm as align_tokens_with_combinations_numpy but + with integer ID comparisons instead of Python string operations. + + Trace codes: 0=start, 1=diag, 2=up, 3=left, + 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k + """ + INVALID = np.int64(-1) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) + + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 2 + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 3 + + for i in range(1, n1 + 1): + id_i = ids1[i - 1] + for j in range(1, n2 + 1): + id_j = ids2[j - 1] + + if id_i == id_j: + best = dp[i - 1, j - 1] + exact_match_score + else: + best = dp[i - 1, j - 1] - exact_match_score + best_m = np.int32(1) + + s = dp[i - 1, j] + gap_penalty + if s > best: + best = s + best_m = np.int32(2) + + s = dp[i, j - 1] + gap_penalty + if s > best: + best = s + best_m = np.int32(3) + + k_max_s2 = min(j, max_comb_len) + for k in range(2, k_max_s2 + 1): + jid = joined2[j, k] + if jid != INVALID and id_i == jid: + s = dp[i - 1, j - k] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(10 + k) + + k_max_s1 = min(i, max_comb_len) + for k in range(2, k_max_s1 + 1): + jid = joined1[i, k] + if jid != INVALID and id_j == jid: + s = dp[i - k, j - 1] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(20 + k) + + dp[i, j] = best + trace[i, j] = best_m + + return dp, trace +else: + _dp_core_numba = None + + +@dataclass(frozen=True) +class VocabPartition: + """Projection-matrix-derived vocab partition for the gold-loss path. + + Built once per (xtoken_loss, teacher_vocab_size) by + `TokenAligner.build_vocab_partition`. All tensors are long, 1-D, and + live on the aligner's projection-matrix device. + """ + + common_student_indices: torch.Tensor + common_teacher_indices: torch.Tensor + uncommon_student_indices: torch.Tensor + uncommon_teacher_indices: torch.Tensor + + +class TokenAligner(nn.Module): + def __init__(self, max_comb_len=4, teacher_tokenizer_name=None, student_tokenizer_name=None, init_hf_tokenizers=True, projection_matrix_multiplier=1.0, enable_scale_trick=None): + super().__init__() + self.teacher_tokenizer_name = teacher_tokenizer_name + self.student_tokenizer_name = student_tokenizer_name + self.projection_matrix_multiplier = projection_matrix_multiplier + self.enable_scale_trick = enable_scale_trick + + if init_hf_tokenizers: + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_tokenizer_name) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_tokenizer_name) + if self.teacher_tokenizer.pad_token is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + if self.student_tokenizer.pad_token is None: + self.student_tokenizer.pad_token = self.student_tokenizer.eos_token + else: + self.teacher_tokenizer = None + self.student_tokenizer = None + + self.max_combination_len = max_comb_len + self.sparse_transformation_matrix = None + # Cached CSR for dense top-k projection (built from indices/values) to avoid scatter path + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + # Precomputed canonical ID maps (built by precompute_canonical_maps) + self._student_canon_map = None + self._teacher_canon_map = None + self._canon_id_to_str = None + + # Cached gold-loss vocab partitions keyed by (xtoken_loss, teacher_vocab_size). + # Populated lazily by build_vocab_partition(); bypassed when learnable=True. + self._vocab_partition_cache: dict[tuple[bool, int], VocabPartition] = {} + + @torch.no_grad() + def build_vocab_partition( + self, xtoken_loss: bool, teacher_vocab_size: int + ) -> VocabPartition: + """Derive (and cache) the gold-loss vocab partition. + + Computes exactly the state that CrossTokenizerDistillationLossFn._compute_gold_loss + previously rebuilt every microbatch: the set of student/teacher tokens that have an + exact 1:1 projection (common) and the complement (uncommon). Cached by + (xtoken_loss, teacher_vocab_size) so repeat calls cost nothing. Caching is skipped + when the projection matrix is learnable, since the partition then depends on + gradient-updated values. + """ + if ( + not hasattr(self, "likelihood_projection_indices") + or self.likelihood_projection_indices is None + ): + raise ValueError( + "build_vocab_partition requires likelihood_projection_indices to be loaded" + ) + + learnable = bool(getattr(self, "learnable", False)) + key = (bool(xtoken_loss), int(teacher_vocab_size)) + if not learnable: + cached = self._vocab_partition_cache.get(key) + if cached is not None: + return cached + + projection_indices = self.likelihood_projection_indices + projection_matrix = ( + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) + if learnable + else self.likelihood_projection_matrix + ) + device = projection_matrix.device + student_vocab_size = int(projection_matrix.shape[0]) + + sorted_values, sorted_indices_in_topk = torch.sort( + projection_matrix, dim=-1, descending=True + ) + + if xtoken_loss: + has_exact_map = sorted_values[:, 0] >= 0.6 + else: + has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1) + + student_indices_with_exact_map = torch.where(has_exact_map)[0] + teacher_indices_for_exact_map = projection_indices[ + student_indices_with_exact_map, + sorted_indices_in_topk[student_indices_with_exact_map, 0], + ] + + student_to_teacher_exact_map: dict[int, int] = {} + teacher_to_student_exact_map: dict[int, int] = {} + for s_idx, t_idx in zip( + student_indices_with_exact_map.tolist(), + teacher_indices_for_exact_map.tolist(), + ): + if 0 <= t_idx < teacher_vocab_size: + if t_idx not in teacher_to_student_exact_map or xtoken_loss: + if t_idx in teacher_to_student_exact_map: + prev_student_token = teacher_to_student_exact_map[t_idx] + if sorted_values[prev_student_token, 0] >= sorted_values[s_idx, 0]: + continue + del student_to_teacher_exact_map[prev_student_token] + student_to_teacher_exact_map[s_idx] = t_idx + teacher_to_student_exact_map[t_idx] = s_idx + + common_student = sorted(student_to_teacher_exact_map.keys()) + common_teacher = [student_to_teacher_exact_map[s] for s in common_student] + uncommon_student = sorted( + set(range(student_vocab_size)) - set(common_student) + ) + uncommon_teacher = sorted( + set(range(teacher_vocab_size)) - set(common_teacher) + ) + + partition = VocabPartition( + common_student_indices=torch.tensor( + common_student, dtype=torch.long, device=device + ), + common_teacher_indices=torch.tensor( + common_teacher, dtype=torch.long, device=device + ), + uncommon_student_indices=torch.tensor( + uncommon_student, dtype=torch.long, device=device + ), + uncommon_teacher_indices=torch.tensor( + uncommon_teacher, dtype=torch.long, device=device + ), + ) + + if not learnable: + self._vocab_partition_cache[key] = partition + return partition + + def precompute_canonical_maps(self): + """Build token_id → canonical_string lookup tables for both tokenizers. + + Call once at startup. After this, align_fast() can skip + convert_ids_to_tokens and _canonicalize_sequence entirely. + """ + import time as _time + _t0 = _time.time() + + canon_str_to_id: dict[str, int] = {} + next_id = [0] + + def _get_canon_id(s: str) -> int: + cid = canon_str_to_id.get(s) + if cid is None: + cid = next_id[0] + canon_str_to_id[s] = cid + next_id[0] += 1 + return cid + + student_vocab_size = len(self.student_tokenizer) + teacher_vocab_size = len(self.teacher_tokenizer) + + student_map = np.zeros(student_vocab_size, dtype=np.int64) + for tid in range(student_vocab_size): + tok = self.student_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + student_map[tid] = _get_canon_id(canon) + + teacher_map = np.zeros(teacher_vocab_size, dtype=np.int64) + for tid in range(teacher_vocab_size): + tok = self.teacher_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + teacher_map[tid] = _get_canon_id(canon) + + self._student_canon_map = student_map + self._teacher_canon_map = teacher_map + self._canon_id_to_str = {v: k for k, v in canon_str_to_id.items()} + + _t1 = _time.time() + print(f" [TokenAligner] Precomputed canonical maps in {_t1-_t0:.2f}s " + f"(student_vocab={student_vocab_size}, teacher_vocab={teacher_vocab_size}, " + f"unique_canonical={len(canon_str_to_id)})", flush=True) + + def align_fast(self, student_ids, teacher_ids, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + chunk_size=128, + post_process=True, + anchor_lengths=[3,], + ignore_leading_char_diff=False): + """Fast alignment using precomputed canonical ID maps. + + Skips convert_ids_to_tokens and _canonicalize_sequence by looking up + canonical strings directly from token IDs via precomputed numpy arrays. + Falls back to regular align() if precomputed maps are not available. + """ + if self._student_canon_map is None: + return self.align(student_ids, teacher_ids, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + chunk_size=chunk_size, + post_process=post_process, + anchor_lengths=anchor_lengths, + ignore_leading_char_diff=ignore_leading_char_diff) + + if isinstance(student_ids, torch.Tensor): + student_ids = student_ids.cpu().numpy() + if isinstance(teacher_ids, torch.Tensor): + teacher_ids = teacher_ids.cpu().numpy() + + if student_ids.ndim == 1: + student_ids = student_ids[np.newaxis, :] + teacher_ids = teacher_ids[np.newaxis, :] + + import time as _time + _t_lookup_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + all_aligned_pairs = [] + for i in range(student_ids.shape[0]): + s_ids = student_ids[i] + t_ids = teacher_ids[i] + + _tl0 = _time.time() + s_canon_strs = [self._canon_id_to_str[self._student_canon_map[tid]] for tid in s_ids] + t_canon_strs = [self._canon_id_to_str[self._teacher_canon_map[tid]] for tid in t_ids] + _tl1 = _time.time() + _t_lookup_total += _tl1 - _tl0 + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(s_canon_strs, t_canon_strs, **align_kwargs) + _tl2 = _time.time() + _t_anchors_dp_total += _tl2 - _tl1 + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tl3 = _time.time() + _t_postprocess_total += _tl3 - _tl2 + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, + ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value + in zip(aligned_pairs, mask) + ] + _tl4 = _time.time() + _t_mask_total += _tl4 - _tl3 + + all_aligned_pairs.append(aligned_pairs) + + n = student_ids.shape[0] + _t_total = _t_lookup_total + _t_anchors_dp_total + _t_postprocess_total + _t_mask_total + if _t_total > 0.5 or n > 1: + print(f" [align_fast timing] lookup={_t_lookup_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s, " + f"total={_t_total:.3f}s (n={n})", flush=True) + + return all_aligned_pairs + + def _load_logits_projection_map( + self, + folder_location: str = "cross_tokenizer_data", + file_path: str = None, + top_k: int = 100, + device: str = "cuda", + use_sparse_format: bool = False, + learnable: bool = False, + ): + """ + Load projection map for cross-tokenizer likelihood projection. + Always creates student→teacher mapping. + + Args: + folder_location: Directory containing the projection files + file_path: Specific file path (overrides folder_location) + top_k: Number of top entries per row (only used for old format) + device: Device to load tensors on + use_sparse_format: If True, load sparse transformation matrix format (from multi-token mapping) + If False, load old dense indices/values format + learnable: If True, make the transformation matrix learnable + """ + self.learnable = learnable + if use_sparse_format: + # Load sparse transformation matrix format + if file_path is None: + file_path = f"{folder_location}/transformation_counts_via_multitoken.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Sparse transformation matrix file not found: {file_path}. Please generate it first.") + + # Load transformation counts dictionary + transformation_counts = torch.load(file_path, map_location='cpu', weights_only=False) + + # Get tokenizer vocab sizes + teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback + student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback + if 1: + # get vocab sizes from autoconfig + if "gemma" not in self.teacher_tokenizer_name.lower() and "qwen3.5" not in self.teacher_tokenizer_name.lower(): + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + else: + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).text_config.vocab_size + if "gemma" not in self.student_tokenizer_name.lower() and "qwen3.5" not in self.student_tokenizer_name.lower(): + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + else: + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).text_config.vocab_size + # teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + # student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + + + # Debug vocab sizes + print(f"Teacher vocab size: {teacher_vocab_size}, Student vocab size: {student_vocab_size}") + + # Convert dictionary to sparse tensor + if transformation_counts: + + + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + student_indices = [idx[0] for idx in indices] + teacher_indices = [idx[1] for idx in indices] + + # Always create student→teacher mapping: rows = student vocab, cols = teacher vocab + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values)/self.projection_matrix_multiplier + matrix_shape = (student_vocab_size, teacher_vocab_size) + + print(f"Creating sparse matrix: student→teacher ({student_vocab_size} x {teacher_vocab_size})") + + sparse_transformation_matrix = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (student_vocab_size, teacher_vocab_size), # student_vocab × teacher_vocab + device=device, + dtype=torch.float32 + ) + + # Optionally make the sparse matrix learnable (values only) + if learnable: + self.sparse_transformation_matrix = nn.Parameter( + sparse_transformation_matrix.coalesce(), requires_grad=True + ) + else: + # Register as buffer for non-learnable parameters (ensures proper device handling) + self.register_buffer('sparse_transformation_matrix', + sparse_transformation_matrix.coalesce(), + persistent=True) + + # Store a flag for downstream code + self.is_sparse_learnable = learnable + print(f"Loaded sparse transformation matrix with {len(transformation_counts)} entries") + else: + # Empty transformation matrix (student→teacher) + matrix_shape = (student_vocab_size, teacher_vocab_size) + + empty_sparse = torch.sparse_coo_tensor( + torch.zeros(2, 0, dtype=torch.long), + torch.zeros(0, dtype=torch.float32), + matrix_shape, + device=device, + ) + + if learnable: + self.sparse_transformation_matrix = nn.Parameter(empty_sparse, requires_grad=True) + else: + # Register as buffer for non-learnable parameters + self.register_buffer('sparse_transformation_matrix', empty_sparse, persistent=True) + + self.is_sparse_learnable = learnable + print("Warning: Empty transformation matrix loaded") + else: + # Load old dense indices/values format + if file_path is None: + file_path = f"{folder_location}/projection_map_Llama-3.1_to_Qwen3_bidirectional_top_10.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Projection map file not found: {file_path}. Please generate it first.") + + projection_data = torch.load(file_path, map_location='cpu', weights_only=False) + # Always use B_to_A direction for student->teacher projection + # projection_data = projection_data["B_to_A"] + # projection_data = projection_data["A_to_B"] + + indices = projection_data["indices"] + likelihoods = projection_data["likelihoods"]/self.projection_matrix_multiplier + + # Register indices as buffer (always non-learnable) + self.register_buffer('likelihood_projection_indices', indices.to(device), persistent=True) + if learnable: + if 1: + likelihoods = (likelihoods+1e-10).log() + + # Use instance variable if set, otherwise use default (False) + # scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + + # if scale_trick_enabled: + # #trick with last column being multiplier - set to -4.0 + # likelihoods[:,-1] = likelihoods[:,-1]*0.0 - 4.0 + #lets introduce some noise to encourage training. will remove later. + if 0: + likelihoods = likelihoods + torch.randn_like(likelihoods) * 1e-1 + likelihoods = likelihoods/2.0 + + self.likelihood_projection_matrix = nn.Parameter(likelihoods.to(device), requires_grad=True) + # print(self.likelihood_projection_matrix[0]) + # print(self.likelihood_projection_matrix[:,-1]) + # exit() + #add small gaussian noise to the projection matrix + #use log form + else: + # Register as buffer for non-learnable parameters + self.register_buffer('likelihood_projection_matrix', likelihoods.to(device), persistent=True) + + + print(f"Loaded dense projection map with shape {indices.shape}") + # Invalidate cached CSR; will rebuild on first use + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + def create_reverse_projection_matrix(self, device="cuda"): + """ + Create a reverse (transposed) projection matrix for teacher→student projection. + + For sparse format: Transposes the sparse_transformation_matrix from [student_vocab, teacher_vocab] + to [teacher_vocab, student_vocab] + For dense format: Builds a reverse index mapping from teacher tokens to student tokens + + This enables projecting teacher logits into student vocabulary space. + """ + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Transpose sparse matrix + print("Creating reverse projection matrix (sparse format): teacher→student") + sparse_matrix = self.sparse_transformation_matrix.coalesce() + indices = sparse_matrix.indices() + values = sparse_matrix.values() + + # Swap student and teacher indices (transpose) + transposed_indices = torch.stack([indices[1], indices[0]], dim=0) # Swap rows: [teacher, student] + teacher_vocab_size, student_vocab_size = sparse_matrix.shape[1], sparse_matrix.shape[0] + + reverse_sparse = torch.sparse_coo_tensor( + transposed_indices, + values, + (teacher_vocab_size, student_vocab_size), + device=device, + dtype=torch.float32 + ).coalesce() + + # Store as buffer or parameter based on learnability + if self.is_sparse_learnable: + self.reverse_sparse_transformation_matrix = nn.Parameter(reverse_sparse, requires_grad=True) + else: + self.register_buffer('reverse_sparse_transformation_matrix', reverse_sparse, persistent=True) + + print(f"Created reverse sparse matrix: teacher→student ({teacher_vocab_size} x {student_vocab_size})") + print(f"Reverse matrix has {len(values)} non-zero entries") + + elif hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + # Build reverse index for dense format + print("Creating reverse projection matrix (dense format): teacher→student") + + # Current: likelihood_projection_indices is [student_vocab, topk] + # We need to build: [teacher_vocab, variable_k] where variable_k depends on how many students map to each teacher token + + student_vocab_size = self.likelihood_projection_indices.shape[0] + topk = self.likelihood_projection_indices.shape[1] + + # Infer teacher vocab size from the max index + teacher_vocab_size = self.likelihood_projection_indices.max().item() + 1 + + # Build reverse mapping: for each teacher token, collect all (student_token, value) pairs + from collections import defaultdict + teacher_to_students = defaultdict(list) + + for student_idx in range(student_vocab_size): + for k in range(topk): + teacher_idx = self.likelihood_projection_indices[student_idx, k].item() + if hasattr(self, 'likelihood_projection_matrix'): + value = self.likelihood_projection_matrix[student_idx, k].item() + else: + value = 1.0 # Default value if no matrix + + # Check for valid entries: teacher_idx must be valid, and value must be finite (not -inf) + # If matrix is in log-space, valid log-probs are finite negative values + # Threshold at -20 to filter out padding values like -22.3197 + if teacher_idx >= 0 and value > -20.0: # Skip invalid or padding entries + teacher_to_students[teacher_idx].append((student_idx, value)) + + # Find max number of students mapping to any teacher token + raw_max_students = max([len(v) for v in teacher_to_students.values()]) if teacher_to_students else 1 + print(f"Max students mapping to any teacher token (before filtering): {raw_max_students}") + + # Limit to top-K students per teacher token to avoid explosion + # Keep only the top-K highest probability mappings per teacher + max_students_per_teacher = min(topk, raw_max_students) # Use same topk as forward direction + print(f"Limiting to top-{max_students_per_teacher} students per teacher token") + + # Sort each teacher's student list by value (descending) and keep only top-K + for teacher_idx in teacher_to_students: + student_list = teacher_to_students[teacher_idx] + # Sort by value (descending - higher log-prob = less negative) + student_list_sorted = sorted(student_list, key=lambda x: x[1], reverse=True) + teacher_to_students[teacher_idx] = student_list_sorted[:max_students_per_teacher] + + # Create dense reverse index [teacher_vocab, max_students_per_teacher] + # Use 0 instead of -1 for padding (valid index), with very negative values to nullify contribution + reverse_indices = torch.zeros((teacher_vocab_size, max_students_per_teacher), + dtype=torch.long, device=device) + # Initialize with very negative values (padding sentinel, similar to forward direction) + reverse_values = torch.full((teacher_vocab_size, max_students_per_teacher), -22.3197, + dtype=torch.float32, device=device) + + for teacher_idx, student_list in teacher_to_students.items(): + for k, (student_idx, value) in enumerate(student_list): + reverse_indices[teacher_idx, k] = student_idx + reverse_values[teacher_idx, k] = value + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + + # Store as buffer or parameter + self.register_buffer('reverse_likelihood_projection_indices', reverse_indices, persistent=True) + if self.learnable: + self.reverse_likelihood_projection_matrix = nn.Parameter(reverse_values, requires_grad=True) + else: + self.register_buffer('reverse_likelihood_projection_matrix', reverse_values, persistent=True) + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + else: + raise ValueError("No projection matrix loaded. Cannot create reverse projection.") + + def project_token_likelihoods_instance(self, input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_sparse_format=False, sparse_matrix=None, use_vectorized=True, gpu_optimized_scatter=True, global_top_indices=None): + """ + Instance method wrapper for project_token_likelihoods that can access instance variables. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + """ + if use_sparse_format: + if sparse_matrix is None: + raise ValueError("sparse_matrix must be provided when use_sparse_format=True") + + if global_top_indices is not None: + # For sparse format with global_top_indices, project to full vocab then slice + full_projection = TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + return full_projection[:, :, global_top_indices] + else: + return TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + else: + # If projection map is learnable, fall back to dense scatter path to preserve gradients + if getattr(projection_map_values, "requires_grad", False): + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.project_token_likelihoods_dense( + input_likelihoods, + projection_map_indices, + projection_map_values * self.projection_matrix_multiplier, + target_vocab_size, + device, + use_vectorized=True, + gpu_optimized_scatter=gpu_optimized_scatter, + enable_scale_trick=scale_trick_enabled, + global_top_indices=global_top_indices, + ) + + # Otherwise, use stateless CSR matmul (no caching) for memory efficiency + vs = projection_map_indices.shape[0] + top_k = projection_map_indices.shape[1] + # Ensure device/dtype for indices/values + idx = projection_map_indices.to(device) + val = (projection_map_values * self.projection_matrix_multiplier).to(device) + if val.dtype != input_likelihoods.dtype: + val = val.to(input_likelihoods.dtype) + # Build CSR once per call outside autograd to keep checkpoint recomputation identical + with torch.no_grad(): + crow_indices = torch.arange(0, (vs + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = idx.reshape(-1) + values = val.reshape(-1) + # _exact_map_remapped matrices use -1 as a padding sentinel for missing + # entries. A -1 column index is illegal in a CSR tensor and causes a + # CUDA illegal-memory-access during the subsequent matmul. Clamp those + # entries to column 0 and zero their value so they contribute nothing. + pad_mask = col_indices < 0 + if pad_mask.any(): + col_indices = col_indices.clone() + col_indices[pad_mask] = 0 + values = values.clone() + values[pad_mask] = 0.0 + proj_csr = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(vs, target_vocab_size), device=device + ) + # Matmul: [B, S, Vs] -> [B*S, Vs] @ [Vs, Vt] -> [B*S, Vt] -> [B, S, Vt] + bsz, seqlen, vs_in = input_likelihoods.shape + if vs_in != vs: + # In case logits have extra vocab tail, slice to match + x = input_likelihoods[:, :, :vs] + else: + x = input_likelihoods + x2d = x.reshape(bsz * seqlen, vs) + out2d = torch.matmul(x2d.to(torch.float32), proj_csr.to(torch.float32)) + out = out2d.reshape(bsz, seqlen, target_vocab_size).to(input_likelihoods.dtype) + return out + + @staticmethod + def project_token_likelihoods_dense(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_vectorized=True, gpu_optimized_scatter=True, enable_scale_trick=None, global_top_indices=None): + """ + Projects token likelihoods from a source to a target vocabulary using dense indices/values format. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of target tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + MAJOR SPEEDUP: Reduces both memory and compute significantly. + """ + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if abs(source_vocab_size - projection_map_indices.shape[0]) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + if projection_map_indices.device != device: + projection_map_indices = projection_map_indices.to(device) + if projection_map_values.device != device: + projection_map_values = projection_map_values.to(device) + #do for dtype + if projection_map_values.dtype != input_likelihoods.dtype: + projection_map_values = projection_map_values.to(input_likelihoods.dtype) + + # else: + # projection_map_values = projection_map_values.to(device) + + if use_vectorized: + # Solution 1: Efficient dense implementation using vectorized operations for small top_k + source_vocab_size_fixed = projection_map_indices.shape[0] + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + + # OPTIMIZATION: Use reduced vocabulary if global_top_indices provided + if global_top_indices is not None: + k_indices = len(global_top_indices) + global_top_indices = global_top_indices.to(device) + + # Create mapping from full target indices to reduced indices [0, 1, 2, ..., k-1] + full_to_reduced_map = torch.full((target_vocab_size,), -1, device=device, dtype=torch.long) + full_to_reduced_map[global_top_indices] = torch.arange(k_indices, device=device) + + # Initialize smaller output tensor - MAJOR MEMORY SAVINGS + projected_likelihoods = torch.zeros(batch_size, seq_len, k_indices, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = k_indices + + # Filter projection matrices to only include mappings to global_top_indices + # This will be used in the scatter operations below + use_reduced_projection = True + else: + # Initialize full output tensor + projected_likelihoods = torch.zeros(batch_size, seq_len, target_vocab_size, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = target_vocab_size + use_reduced_projection = False + + # Optimized chunked processing with multiple speedup techniques + # Use larger chunks for better amortization of fixed costs + max_memory_mb = 200 # Increased for better performance + # max_memory_mb = 500 # Increased for better performance + elements_per_chunk = max_memory_mb * 1024 * 1024 // 4 # 4 bytes per float32 + chunk_size = max(512, min(source_vocab_size_fixed, elements_per_chunk // (batch_size * seq_len))) + + + use_masking = False + # Process vocabulary in optimized chunks + for chunk_start in range(0, source_vocab_size_fixed, chunk_size): + chunk_end = min(chunk_start + chunk_size, source_vocab_size_fixed) + chunk_len = chunk_end - chunk_start + + + input_chunk = input_likelihoods_fixed[:, :, chunk_start:chunk_end] # (B, S, chunk_len) + indices_chunk = projection_map_indices[chunk_start:chunk_end, :] # (chunk_len, top_k) + values_chunk = projection_map_values[chunk_start:chunk_end, :] # (chunk_len, top_k) + + # Extract input chunk once per chunk (not per k) - major speedup + # Determine effective top_k (exclude last column if scale trick is enabled) + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + effective_top_k = top_k - 1 if scale_trick_enabled else top_k + # effective_top_k = 1 + + if gpu_optimized_scatter: + if use_masking: + # Process one k at a time to reduce peak memory usage + for k in range(effective_top_k): + values_k = values_chunk[:, k] + valid_mask_k = values_k > 1e-4 + if not valid_mask_k.any(): + continue + + source_indices_k = torch.nonzero(valid_mask_k, as_tuple=True)[0] + + input_subset_k = input_chunk[:, :, source_indices_k] + values_subset_k = values_k[source_indices_k] + + indices_k = indices_chunk[:, k] + target_indices_subset_k = indices_k[source_indices_k] + + weighted_inputs_k = input_subset_k * values_subset_k.view(1, 1, -1) + expanded_target_indices_k = target_indices_subset_k.view(1, 1, -1).expand(batch_size, seq_len, -1) + + projected_likelihoods.scatter_add_(2, expanded_target_indices_k, weighted_inputs_k) + else: + # Compact, un-masked implementation + # Process only effective columns without creating intermediate tensors + input_expanded = input_chunk.unsqueeze(-1) # (B, S, chunk_len, 1) + + for k in range(effective_top_k): + values_k = values_chunk[:, k:k+1] # (chunk_len, 1) - view, no copy + indices_k = indices_chunk[:, k] # (chunk_len,) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + # Map full indices to reduced indices and filter out invalid ones + reduced_indices_k = full_to_reduced_map[indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 # Only keep indices in global_top_indices + + if not valid_mask.any(): + continue # Skip if no valid indices in this chunk + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + values_filtered = values_k.squeeze(-1)[valid_indices] # (valid_count,) + input_filtered = input_chunk[:, :, valid_indices] # (B, S, valid_count) + + weighted_k = input_filtered * values_filtered.unsqueeze(0).unsqueeze(0) + indices_expanded = reduced_indices_filtered.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Standard full projection + weighted_k = input_expanded * values_k.unsqueeze(0).unsqueeze(0) # (B, S, chunk_len, 1) + weighted_k = weighted_k.squeeze(-1) # (B, S, chunk_len) + + indices_expanded = indices_k.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Original implementation with a loop over top_k + if True: # For small top_k, process all k together + # Broadcast input: (B, S, chunk_len, 1) * (1, 1, chunk_len, top_k) -> (B, S, chunk_len, top_k) + weighted_inputs = input_chunk.unsqueeze(-1) * values_chunk.unsqueeze(0).unsqueeze(0) + + # Process all k simultaneously using advanced indexing + for k in range(effective_top_k): + target_indices_k = indices_chunk[:, k] # (chunk_len,) + weighted_k = weighted_inputs[:, :, :, k] # (B, S, chunk_len) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + reduced_indices_k = full_to_reduced_map[target_indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 + + if not valid_mask.any(): + continue # Skip if no valid indices + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + weighted_filtered = weighted_k[:, :, valid_indices] # (B, S, valid_count) + + target_expanded = reduced_indices_filtered.view(1, 1, -1).expand(batch_size, seq_len, len(valid_indices)) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_filtered) + else: + # Use optimized scatter with pre-expanded indices (avoid .expand() in loop) + target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_k) + + # else: # For larger top_k, use optimized sequential processing + # for k in range(top_k): + # target_indices_k = indices_chunk[:, k] # (chunk_len,) + # target_values_k = values_chunk[:, k] # (chunk_len,) + + # # Skip projections marked with -1 + # valid_mask = target_values_k > -0.00001 + # if not valid_mask.any(): + # continue + + # # Only process valid projections + # valid_target_indices = target_indices_k[valid_mask] + # valid_target_values = target_values_k[valid_mask] + # valid_input = input_chunk[valid_mask] + + # weighted_input = valid_input * valid_target_values.view(-1, 1, 1) + + # # Direct scatter (simpler and often faster than index caching) + # target_expanded = valid_target_indices.view(1, 1, -1).expand(batch_size, seq_len, valid_target_indices.size(0)) + # projected_likelihoods.scatter_add_(2, target_expanded, weighted_input) + + return projected_likelihoods + else: + # Solution 2: Sparse matrix approach (original implementation) + source_vocab_size_fixed = projection_map_indices.shape[0] + + # Create sparse CSR matrix + crow_indices = torch.arange(0, (source_vocab_size_fixed + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size_fixed, target_vocab_size), device=device + ) + + # Apply sparse matrix multiplication + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_projection_matrix.to(torch.float32)) + + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size).to(input_likelihoods.dtype) + + @staticmethod + def project_token_likelihoods_sparse(input_likelihoods, sparse_matrix, device): + """Projects token likelihoods using a sparse transformation matrix.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + + # Get dimensions from sparse matrix + matrix_input_size, matrix_output_size = sparse_matrix.shape + + if abs(source_vocab_size - matrix_input_size) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches sparse matrix input size ({matrix_input_size})") + + # Move to correct device and dtype + # input_likelihoods = input_likelihoods.to(device) + # sparse_matrix = sparse_matrix.to(device) + + # Adjust input size to match matrix dimensions + # next 2 lines required when we used vocab length from tokenizer, now we use the size of logits + # source_vocab_size_fixed = min(source_vocab_size, matrix_input_size) + # input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + input_likelihoods_fixed = input_likelihoods + + # Reshape for matrix multiplication + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + # Project using sparse matrix multiplication + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_matrix.to(torch.float32)) + + # Reshape back to original format + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, matrix_output_size).to(input_likelihoods.dtype) + + def align(self, student_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + teacher_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=False, + chunk_size=128, + post_process=True, + convert_ids_to_tokens=True, + anchor_lengths=[3,], + _debug_timing=False): + """Align two sequences of tokens (or batches of sequences).""" + import time as _time + + seq1 = student_seq + seq2 = teacher_seq + + _t_convert = 0.0 + if isinstance(seq1, torch.Tensor): + seq1 = seq1.cpu().tolist() + seq2 = seq2.cpu().tolist() + if convert_ids_to_tokens: + _tc0 = _time.time() + seq1 = [self.student_tokenizer.convert_ids_to_tokens(seq1_single) for seq1_single in seq1] + seq2 = [self.teacher_tokenizer.convert_ids_to_tokens(seq2_single) for seq2_single in seq2] + _t_convert = _time.time() - _tc0 + + is_batched = isinstance(seq1, list) and len(seq1) > 0 and isinstance(seq1[0], list) + + _t_canon_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + if is_batched: + if not (isinstance(seq2, list) and len(seq2) == len(seq1) and (len(seq2) == 0 or isinstance(seq2[0], list))): + raise ValueError("For batched input, seq1 and seq2 must be lists of lists with the same length.") + + all_aligned_pairs = [] + for s1, s2 in zip(seq1, seq2): + aligned_pairs, timings = self._align_single(s1, s2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs.append(aligned_pairs) + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + else: + aligned_pairs, timings = self._align_single(seq1, seq2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs = [aligned_pairs] + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + + if _debug_timing: + n = len(all_aligned_pairs) + print(f" [align timing] convert_ids={_t_convert:.3f}s, " + f"canonicalize={_t_canon_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s " + f"(n={n})", flush=True) + + return all_aligned_pairs + + def _align_single(self, seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=True, + chunk_size=0, + post_process=True, + anchor_lengths=None, + _return_timings=False): + """Align two sequences of tokens.""" + import time as _time + + _tc0 = _time.time() + seq1_canon = TokenAligner._canonicalize_sequence(seq1) + seq2_canon = TokenAligner._canonicalize_sequence(seq2) + _tc1 = _time.time() + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(seq1_canon, seq2_canon, **align_kwargs) + _tc2 = _time.time() + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tc3 = _time.time() + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value in zip(aligned_pairs, mask) + ] + _tc4 = _time.time() + + timings = { + "canon": _tc1 - _tc0, + "anchors_dp": _tc2 - _tc1, + "postprocess": _tc3 - _tc2, + "mask": _tc4 - _tc3, + } + + if _return_timings: + return aligned_pairs, timings + return aligned_pairs + + + def _align_with_anchors(self, seq1, seq2, anchor_lengths=[3,], **kwargs): + """ + Optimized alignment using unique 1-to-1 matches as anchors. + """ + # CRITICAL FIX: If anchor_lengths is empty, disable anchor optimization completely + if not anchor_lengths: + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + if anchor_lengths is None: + anchor_lengths = [3, 2] # Default: check 3-token, then 2-token sequences + + # Debug output + debug = kwargs.get('debug', False) + + # 1. Find high-confidence anchor points using unique token matches. + s1_counts = {} + for i, t in enumerate(seq1): + if t not in s1_counts: s1_counts[t] = [] + s1_counts[t].append(i) + + s2_counts = {} + for i, t in enumerate(seq2): + if t not in s2_counts: s2_counts[t] = [] + s2_counts[t].append(i) + + # Find potential anchors using consecutive token sequences + potential_anchors = [] + + # FIXED: Don't break early - collect anchors from all lengths and then choose the best + all_potential_anchors = [] + + # Check for anchors of different lengths + for anchor_len in anchor_lengths: + anchors_for_this_len = [] + + if anchor_len == 1: + # Handle single token anchors + common_tokens = s1_counts.keys() & s2_counts.keys() + for token in common_tokens: + if len(s1_counts[token]) == 1 and len(s2_counts[token]) == 1: + i = s1_counts[token][0] + j = s2_counts[token][0] + anchors_for_this_len.append((i, j, anchor_len)) + else: + # Handle multi-token anchors + s1_ngram_counts = {} + for i in range(len(seq1) - anchor_len + 1): + ngram = tuple(seq1[i:i + anchor_len]) + if ngram not in s1_ngram_counts: + s1_ngram_counts[ngram] = [] + s1_ngram_counts[ngram].append(i) + + s2_ngram_counts = {} + for i in range(len(seq2) - anchor_len + 1): + ngram = tuple(seq2[i:i + anchor_len]) + if ngram not in s2_ngram_counts: + s2_ngram_counts[ngram] = [] + s2_ngram_counts[ngram].append(i) + + # Find n-grams that appear exactly once in both sequences + common_ngrams = s1_ngram_counts.keys() & s2_ngram_counts.keys() + for ngram in common_ngrams: + if len(s1_ngram_counts[ngram]) == 1 and len(s2_ngram_counts[ngram]) == 1: + i = s1_ngram_counts[ngram][0] + j = s2_ngram_counts[ngram][0] + # ADDED: Verify the anchor is actually correct + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + anchors_for_this_len.append((i, j, anchor_len)) + + all_potential_anchors.extend(anchors_for_this_len) + + # IMPROVED: Choose the best set of anchors + # Prefer longer anchors, but if shorter anchors give better coverage, use them + + # Sort by position and filter for monotonic ordering + all_potential_anchors.sort() + + # IMPROVED: Better anchor selection - use greedy approach to maximize coverage + selected_anchors = [] + used_positions_seq1 = set() + used_positions_seq2 = set() + + # Sort by anchor length (descending) then by position + all_potential_anchors.sort(key=lambda x: (-x[2], x[0], x[1])) + + for i, j, anchor_len in all_potential_anchors: + # Check if this anchor conflicts with already selected ones + seq1_range = set(range(i, i + anchor_len)) + seq2_range = set(range(j, j + anchor_len)) + + if not (seq1_range & used_positions_seq1) and not (seq2_range & used_positions_seq2): + # This anchor doesn't conflict - we can use it + selected_anchors.append((i, j, anchor_len)) + used_positions_seq1.update(seq1_range) + used_positions_seq2.update(seq2_range) + + # Re-sort selected anchors by position for processing + selected_anchors.sort() + + # IMPROVED: Additional validation of selected anchors + validated_anchors = [] + last_j = -1 + for i, j, anchor_len in selected_anchors: + # Ensure monotonic ordering and no overlaps + if j > last_j: + # Double-check the anchor is valid + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + validated_anchors.append((i, j, anchor_len)) + last_j = j + anchor_len - 1 + + anchors = validated_anchors + + if not anchors: + # If no anchors are found, fall back to the standard alignment. + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + # 2. Align segments between anchors. + full_alignment = [] + last_i, last_j = 0, 0 + + for anchor_idx, (i, j, anchor_len) in enumerate(anchors): + + # Align segment before the current anchor. + seg1, seg2 = seq1[last_i:i], seq2[last_j:j] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + # Adjust indices to be relative to the full sequence and split exact matches. + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Add the anchor itself (consecutive tokens), also split if needed. + anchor_seq1 = seq1[i:i + anchor_len] + anchor_seq2 = seq2[j:j + anchor_len] + + # Split anchor into individual matches since they should be identical + for k in range(anchor_len): + full_alignment.append(( + [anchor_seq1[k]], [anchor_seq2[k]], + i + k, i + k + 1, + j + k, j + k + 1 + )) + + last_i, last_j = i + anchor_len, j + anchor_len + + # 3. Align the final segment after the last anchor. + seg1, seg2 = seq1[last_i:], seq2[last_j:] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + return full_alignment, 0 # Return 0 for score as it's not well-defined here + + def _perform_dp_alignment(self, seq1, seq2, **kwargs): + """ + Helper function to run the core DP-based alignment. + """ + chunk_size = kwargs.get('chunk_size', 0) + kwargs.pop('chunk_size', None) + kwargs.pop('anchor_lengths', None) + + if chunk_size > 0: + return self.align_tokens_combinations_chunked(seq1, seq2, chunk_size=chunk_size, **kwargs) + else: + return self.align_tokens_with_combinations_numpy_jit(seq1, seq2, **kwargs) + + @staticmethod + def _canonical_token(token: str) -> str: + """Return a canonical representation of a tokenizer token.""" + if not token: + return token + + # 1. Normalize space prefixes first + if token.startswith(' '): + token = 'Ġ' + token[1:] + elif token.startswith('_'): + token = 'Ġ' + token[1:] + elif token.startswith('▁'): # SentencePiece-style space prefix + token = 'Ġ' + token[1:] + + # 1.5. Normalize newline and whitespace representations + if token == 'Ċ': # GPT-style newline (used by Llama) + token = '\n' + elif token == '\\n': # Escaped newline representation + token = '\n' + elif token == 'ĉ': # Alternative newline representation + token = '\n' + elif token == 'Ġ\n': # Space + newline combination + token = '\n' + elif 'Ċ' in token: # Handle Ċ embedded in other tokens + token = token.replace('Ċ', '\n') + elif '\\n' in token: # Handle escaped newlines in compound tokens + token = token.replace('\\n', '\n') + + # 1.6. Handle space-separated punctuation normalization + if token == 'Ġ,': # Space + comma + token = ',' + elif token == 'Ġ.': # Space + period + token = '.' + elif token == 'Ġ;': # Space + semicolon + token = ';' + elif token == 'Ġ:': # Space + colon + token = ':' + + # 2. Handle SentencePiece byte fallback tokens like <0x20> + if token.startswith('<0x') and token.endswith('>') and len(token) == 6: + try: + byte_val = int(token[3:5], 16) + if 0 <= byte_val <= 255: + return chr(byte_val) + except ValueError: + pass + + # 3. Normalize common Unicode encoding issues + unicode_fixes = { + # Spanish + 'ñ': 'ñ', 'á': 'á', 'é': 'é', 'í': 'í', 'ó': 'ó', 'ú': 'ú', + 'Ã': 'À', 'â': 'â', 'ç': 'ç', + # French + 'ç': 'ç', 'è': 'è', 'é': 'é', 'ë': 'ë', 'î': 'î', 'ô': 'ô', + 'ù': 'ù', 'û': 'û', 'ÿ': 'ÿ', + # Chinese (common encoding artifacts) + 'ä¸Ń': '中', 'æĸĩ': '文', 'æĹ¥æľ¬': '日本', 'èªŀ': '語', + # Russian + 'ÐłÑĥÑģ': 'Рус', 'Ñģкий': 'ский', + # Arabic + 'اÙĦعربÙĬØ©': 'العربية', + # Hindi + 'ह': 'ह', 'िà¤Ĥ': 'हिं', 'दà¥Ģ': 'दी', + # Mathematical symbols (common artifacts) + 'âĪij': '∑', 'âĪı': '∏', 'âĪĤ': '∂', 'âĪĩ': '∇', + 'âĪŀ': '∞', 'âĪļ': '√', 'âĪ«': '∫', 'âīĪ': '≈', + 'âīł': '≠', 'âī¤': '≤', 'âī¥': '≥', + } + + # Apply Unicode fixes + for broken, fixed in unicode_fixes.items(): + if broken in token: + token = token.replace(broken, fixed) + + # 4. Normalize special tokens + special_token_map = { + '<|begin_of_text|>': '', # Llama-style BOS token + '': '', # Standard BOS token + '': '', # Padding tokens → empty (will be handled by alignment) + '': ' ', # End tokens + '': ' ', # End tokens + } + + if token in special_token_map: + return special_token_map[token] + + return token + + @staticmethod + def _canonicalize_sequence(seq: List[str]) -> List[str]: + """Canonicalize every token in a sequence (list of str).""" + # First, handle multi-token encoding artifacts (before individual canonicalization) + merged_artifacts = TokenAligner._merge_encoding_artifacts(seq) + + # Then, canonicalize individual tokens + canon_tokens = [TokenAligner._canonical_token(tok) for tok in merged_artifacts] + + # Finally, merge consecutive byte tokens into proper Unicode characters + return TokenAligner._merge_consecutive_bytes(canon_tokens) + + @staticmethod + def _merge_encoding_artifacts(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent multi-token encoding artifacts.""" + if not tokens: + return tokens + + # Common multi-token encoding artifacts that should be merged + multi_token_fixes = [ + # Mathematical symbols split across tokens + (['ĠâĪ', 'ij'], ['Ġ∑']), # Sum symbol + (['âĪ', 'ij'], ['∑']), # Sum symbol (no space) + (['ĠâĪ', 'ı'], ['Ġ∏']), # Product symbol + (['âĪ', 'ı'], ['∏']), # Product symbol (no space) + (['ĠâĪ', 'Ĥ'], ['Ġ∂']), # Partial derivative + (['âĪ', 'Ĥ'], ['∂']), # Partial derivative (no space) + (['ĠâĪ', 'ĩ'], ['Ġ∇']), # Nabla/gradient + (['âĪ', 'ĩ'], ['∇']), # Nabla/gradient (no space) + (['ĠâĪ', 'ŀ'], ['Ġ∞']), # Infinity + (['âĪ', 'ŀ'], ['∞']), # Infinity (no space) + (['ĠâĪ', 'ļ'], ['Ġ√']), # Square root + (['âĪ', 'ļ'], ['√']), # Square root (no space) + (['ĠâĪ', '«'], ['Ġ∫']), # Integral + (['âĪ', '«'], ['∫']), # Integral (no space) + (['Ġâī', 'ł'], ['Ġ≠']), # Not equal + (['âī', 'ł'], ['≠']), # Not equal (no space) + # Other common multi-token artifacts + (['Ġä¸', 'Ń'], ['Ġ中']), # Chinese character + (['ä¸', 'Ń'], ['中']), # Chinese character (no space) + (['æĸ', 'ĩ'], ['文']), # Chinese character + (['Ġæĸ', 'ĩ'], ['Ġ文']), # Chinese character (with space) + ] + + result = [] + i = 0 + + while i < len(tokens): + # Check if current position matches any multi-token pattern + matched = False + + for pattern, replacement in multi_token_fixes: + pattern_len = len(pattern) + if i + pattern_len <= len(tokens): + # Check if the tokens match the pattern + if tokens[i:i+pattern_len] == pattern: + # Replace with the fixed version + result.extend(replacement) + i += pattern_len + matched = True + break + + if not matched: + # No pattern matched, keep the original token + result.append(tokens[i]) + i += 1 + + return result + + @staticmethod + def _merge_consecutive_bytes(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent UTF-8 byte sequences.""" + if not tokens: + return tokens + + result = [] + byte_buffer = [] + + for token in tokens: + # Check if this token represents byte(s) + clean_token = token.lstrip('Ġ') + + # Check if all characters in the token are visual bytes + all_chars_are_bytes = True + if len(clean_token) == 0: + all_chars_are_bytes = False + else: + for char in clean_token: + if TokenAligner._get_byte_value(char) is None: + all_chars_are_bytes = False + break + + if all_chars_are_bytes: + byte_buffer.append(token) + else: + # Not a byte token, flush buffer first + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + byte_buffer = [] + result.append(token) + + # Flush any remaining bytes + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + + return result + + @staticmethod + def _try_merge_byte_buffer(byte_tokens: List[str]) -> List[str]: + """Try to merge a buffer of potential byte tokens into a Unicode character.""" + if not byte_tokens: + return [] + + # If only one token, just return it unless it's a multi-character byte token + if len(byte_tokens) == 1: + token = byte_tokens[0] + clean_token = token.lstrip('Ġ') + if len(clean_token) <= 1: + return byte_tokens + # Continue processing multi-character token + + # Extract space prefix from first token + first_token = byte_tokens[0] + space_prefix = 'Ġ' if first_token.startswith('Ġ') else '' + + # Extract raw bytes from all characters in all tokens + raw_bytes = [] + for token in byte_tokens: + clean_token = token.lstrip('Ġ') + for char in clean_token: + byte_value = TokenAligner._get_byte_value(char) + if byte_value is not None: + raw_bytes.append(byte_value) + else: + # If any character is not a byte, return original tokens + return byte_tokens + + # Only try to merge if we have 2-4 bytes (typical for emoji/multi-byte chars) + if len(raw_bytes) < 2 or len(raw_bytes) > 4: + return byte_tokens + + # Try to decode as UTF-8 + try: + decoded_text = bytes(raw_bytes).decode('utf-8') + # Only merge if the result is a single Unicode character (like an emoji) + if len(decoded_text) == 1 and ord(decoded_text) > 127: + return [space_prefix + decoded_text] + else: + # If it's not a single special character, keep original tokens + return byte_tokens + except UnicodeDecodeError: + # If decoding fails, return original tokens + return byte_tokens + + # Common visual byte representations used by some tokenizers (especially for emojis) + VISUAL_BYTE_MAP = { + # Common emoji byte range (240-255) + 'ð': 240, 'Ɩ': 241, 'Ɨ': 242, 'Ƙ': 243, 'ƙ': 244, 'ƚ': 245, 'ƛ': 246, 'Ɯ': 247, + 'Ɲ': 248, 'ƞ': 249, 'Ɵ': 250, 'Ơ': 251, 'ơ': 252, 'Ƣ': 253, 'ƣ': 254, 'Ƥ': 255, + # Other common byte representations (0-255 only) + 'Ł': 156, 'ł': 157, 'Ń': 158, 'ń': 159, 'ĺ': 149, 'Ļ': 150, 'ļ': 151, 'Ľ': 152, + 'ľ': 153, 'Ŀ': 154, 'ŀ': 155, 'Ĭ': 135, 'ĭ': 136, 'Į': 137, 'į': 138, 'İ': 139, + 'ı': 140, 'IJ': 141, 'ij': 142, 'Ĵ': 143, 'ĵ': 144, 'Ķ': 145, 'ķ': 146, 'ĸ': 147, + 'Ĺ': 148, 'ĥ': 128, 'Ħ': 129, 'ħ': 130, 'Ĩ': 131, 'ĩ': 132, 'Ī': 133, 'ī': 134, + 'Ģ': 162, 'ģ': 163, 'Ĝ': 28, 'ĝ': 29, 'Ğ': 30, 'ğ': 31, + } + + @staticmethod + def _get_byte_value(token_char: str) -> int: + """Get the byte value for a character, handling both direct bytes and visual representations.""" + if len(token_char) != 1: + return None + + char_ord = ord(token_char) + + # Direct byte (0-255) + if char_ord < 256: + return char_ord + + # Visual byte representation + if token_char in TokenAligner.VISUAL_BYTE_MAP: + return TokenAligner.VISUAL_BYTE_MAP[token_char] + + return None + + @staticmethod + def _strings_equal_flexible(s1, s2, ignore_leading_char_diff): + if not ignore_leading_char_diff: + return s1 == s2 + + # Use our comprehensive canonicalization for robust comparison + s1_canonical = TokenAligner._canonical_token(s1) + s2_canonical = TokenAligner._canonical_token(s2) + + return s1_canonical == s2_canonical + + def align_tokens_with_combinations_numpy(seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False): + n1, n2 = len(seq1), len(seq2) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.full((n1 + 1, n2 + 1), '', dtype=object) + + # Initialize DP edges with gap penalties + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 'up' + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 'left' + + # Precompute joined substrings for all valid k-length spans + joined_seq1 = {(i - k, i): ''.join(seq1[i - k:i]) + for i in range(n1 + 1) + for k in range(1, min(i, max_combination_len) + 1)} + joined_seq2 = {(j - k, j): ''.join(seq2[j - k:j]) + for j in range(n2 + 1) + for k in range(1, min(j, max_combination_len) + 1)} + + # Fill DP table + for i in range(1, n1 + 1): + for j in range(1, n2 + 1): + s1_val, s2_val = seq1[i - 1], seq2[j - 1] + match_score = exact_match_score if TokenAligner._strings_equal_flexible(s1_val, s2_val, ignore_leading_char_diff) else -exact_match_score + score_diag = dp[i - 1, j - 1] + match_score + score_up = dp[i - 1, j] + gap_penalty + score_left = dp[i, j - 1] + gap_penalty + + max_score = score_diag + best_move = 'diag' + if score_up > max_score: + max_score = score_up + best_move = 'up' + if score_left > max_score: + max_score = score_left + best_move = 'left' + + # Check for seq1[i-1] == join(seq2[j-k:j]) + for k in range(2, min(j + 1, max_combination_len + 1)): + if (j - k, j) in joined_seq2 and TokenAligner._strings_equal_flexible(s1_val, joined_seq2[(j - k, j)], ignore_leading_char_diff): + comb_score = dp[i - 1, j - k] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s1_over_s2_{k}' + + # Check for seq2[j-1] vs seq1[i-k:i] + for k in range(2, min(i + 1, max_combination_len + 1)): + if (i - k, i) in joined_seq1 and TokenAligner._strings_equal_flexible(s2_val, joined_seq1[(i - k, i)], ignore_leading_char_diff): + comb_score = dp[i - k, j - 1] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s2_over_s1_{k}' + + dp[i, j] = max_score + trace[i, j] = best_move + + # Backtrack to extract alignment + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + move = trace[i, j] + if move == 'diag': + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == 'up': + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == 'left': + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif move.startswith('comb_s1_over_s2_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif move.startswith('comb_s2_over_s1_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, dp[n1, n2] + + @staticmethod + def align_tokens_with_combinations_numpy_jit( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + ): + """Numba-accelerated version of align_tokens_with_combinations_numpy. + + Pre-converts string tokens to integer IDs, runs the DP in a Numba + @njit kernel, then backtracks using the original string tokens. + Falls back to the pure-Python original when Numba is unavailable or + when ignore_leading_char_diff is True (requires Python string logic). + """ + if not _NUMBA_AVAILABLE or ignore_leading_char_diff: + return TokenAligner.align_tokens_with_combinations_numpy( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + + n1, n2 = len(seq1), len(seq2) + if n1 == 0 and n2 == 0: + return [], 0.0 + if n1 == 0: + return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty + if n2 == 0: + return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty + + token_to_id: dict[str, int] = {} + _next_id = [0] + + def _get_id(s: str) -> int: + tid = token_to_id.get(s) + if tid is None: + tid = _next_id[0] + token_to_id[s] = tid + _next_id[0] += 1 + return tid + + ids1 = np.array([_get_id(t) for t in seq1], dtype=np.int64) + ids2 = np.array([_get_id(t) for t in seq2], dtype=np.int64) + + INVALID = np.int64(-1) + joined1 = np.full((n1 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for i in range(n1 + 1): + for k in range(2, min(i, max_combination_len) + 1): + joined1[i, k] = _get_id(''.join(seq1[i - k:i])) + + joined2 = np.full((n2 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for j in range(n2 + 1): + for k in range(2, min(j, max_combination_len) + 1): + joined2[j, k] = _get_id(''.join(seq2[j - k:j])) + + dp, trace = _dp_core_numba( + ids1, ids2, joined1, joined2, n1, n2, + np.float32(exact_match_score), + np.float32(gap_penalty), + np.float32(combination_score_multiplier), + max_combination_len, + ) + + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + m = trace[i, j] + if m == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif m == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif m == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= m < 20: + k = m - 10 + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif 20 <= m < 30: + k = m - 20 + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, float(dp[n1, n2]) + + @staticmethod + def align_tokens_combinations_chunked( + seq1: List[str], + seq2: List[str], + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + ignore_leading_char_diff: bool = False, + chunk_size: int = 256, + ): + """ + Chunked processing for very large sequences. + """ + n1, n2 = len(seq1), len(seq2) + + # If sequences are small enough, use regular algorithm + if n1 <= chunk_size and n2 <= chunk_size: + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + # For very large sequences, use divide-and-conquer approach + if n1 > chunk_size or n2 > chunk_size: + # Find approximate midpoint alignment using simplified algorithm + mid1, mid2 = n1 // 2, n2 // 2 + + # Recursively align left and right parts + left_aligned, left_score = TokenAligner.align_tokens_combinations_chunked( + seq1[:mid1], seq2[:mid2], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + right_aligned, right_score = TokenAligner.align_tokens_combinations_chunked( + seq1[mid1:], seq2[mid2:], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + # Adjust indices for right part + adjusted_right = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned: + new_s1_start = s1_start + mid1 if s1_start >= 0 else -1 + new_s1_end = s1_end + mid1 if s1_end >= 0 else -1 + new_s2_start = s2_start + mid2 if s2_start >= 0 else -1 + new_s2_end = s2_end + mid2 if s2_end >= 0 else -1 + adjusted_right.append((s1_tokens, s2_tokens, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Combine results + combined_aligned = left_aligned + adjusted_right + combined_score = left_score + right_score + + return combined_aligned, combined_score + + # Fallback to regular algorithm + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + @staticmethod + def _combine_consecutive_misaligned_tokens( + aligned_pairs: List, + pair_strings: List, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Combine consecutive misaligned tokens into single chunks to improve alignment. + + This addresses cases where multiple tokens are individually misaligned but + collectively represent the same content. Avoids combining tokens near the + end of sequences that might be misaligned due to length differences. + + Args: + aligned_pairs: List of alignment pairs + pair_strings: Precomputed string representations and match status + end_mismatch_threshold: Fraction of sequence from end to avoid chunking + + Returns: + Modified aligned_pairs with consecutive misaligned tokens combined + """ + if not aligned_pairs or len(aligned_pairs) < 2: + return aligned_pairs + + # Calculate the boundary for avoiding end mismatches + sequence_length = len(aligned_pairs) + end_boundary = int(sequence_length * (1 - end_mismatch_threshold)) + + processed_pairs = [] + i = 0 + + while i < len(aligned_pairs): + # Check if current pair is misaligned and not near the end + if (i < end_boundary and + not pair_strings[i][2] and # Current pair is misaligned + i + 1 < len(aligned_pairs)): # Not the last pair + + # Find consecutive misaligned pairs + consecutive_misaligned = [i] + j = i + 1 + + # Look ahead for more consecutive misaligned pairs (up to end boundary) + while (j < end_boundary and + j < len(aligned_pairs) and + not pair_strings[j][2]): # Next pair is also misaligned + consecutive_misaligned.append(j) + j += 1 + + # Only combine if we have multiple consecutive misaligned pairs + if len(consecutive_misaligned) >= 2: + # Combine all consecutive misaligned pairs into one chunk + combined_s1_tokens = [] + combined_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for idx in consecutive_misaligned: + s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest = aligned_pairs[idx] + combined_s1_tokens.extend(s1_tokens) + combined_s2_tokens.extend(s2_tokens) + + if s1_tokens and s1_start != -1: + s1_indices.extend([s1_start, s1_end]) + if s2_tokens and s2_start != -1: + s2_indices.extend([s2_start, s2_end]) + + # Calculate combined indices + combined_s1_start = min(s1_indices[::2]) if s1_indices else -1 + combined_s1_end = max(s1_indices[1::2]) if s1_indices else -1 + combined_s2_start = min(s2_indices[::2]) if s2_indices else -1 + combined_s2_end = max(s2_indices[1::2]) if s2_indices else -1 + + # Create combined pair + combined_pair = ( + combined_s1_tokens, + combined_s2_tokens, + combined_s1_start, + combined_s1_end, + combined_s2_start, + combined_s2_end + ) + + processed_pairs.append(combined_pair) + i = j # Skip to after the combined region + else: + # Only one misaligned pair, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + else: + # Current pair is aligned or near the end, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + + return processed_pairs + + + @staticmethod + def post_process_alignment_optimized( + aligned_pairs: List, + ignore_leading_char_diff: bool = False, + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + combine_misaligned_chunks: bool = True, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Optimized version of post_process_alignment with better performance. + + Key optimizations: + 1. Precompute string concatenations to avoid repeated joins + 2. Early termination when no bad regions are found + 3. Cache alignment results for repeated chunk patterns + 4. Vectorized index calculations + 5. Reduced nested loop complexity + 6. Combine multiple consecutive misaligned tokens into single chunks + + Args: + combine_misaligned_chunks: If True, combine consecutive misaligned tokens into chunks + end_mismatch_threshold: Fraction of sequence length from end to avoid chunking (0.2 = last 20%) + """ + if not aligned_pairs: + return [] + + # Precompute joined strings for all pairs to avoid repeated concatenation + # Use canonicalization for robust comparison + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + # Canonicalize individual tokens before joining for better matching + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + # Step 1: Handle consecutive misaligned chunks if enabled + if combine_misaligned_chunks: + aligned_pairs = TokenAligner._combine_consecutive_misaligned_tokens( + aligned_pairs, pair_strings, end_mismatch_threshold + ) + + # Recompute pair_strings after combining misaligned chunks + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + processed_pairs = [] + alignment_cache = {} # Cache for repeated alignment patterns + i = 0 + + while i < len(aligned_pairs): + s1_tokens, s2_tokens, *_ = aligned_pairs[i] + + # Handle coarse alignments that can be split (optimized) + if len(s1_tokens) > 1 and len(s1_tokens) == len(s2_tokens) and s1_tokens == s2_tokens: + s1_start, s1_end, s2_start, s2_end = aligned_pairs[i][2:6] + # Vectorized creation of split pairs + for k in range(len(s1_tokens)): + processed_pairs.append( + ([s1_tokens[k]], [s2_tokens[k]], + s1_start + k, s1_start + k + 1, + s2_start + k, s2_start + k + 1) + ) + i += 1 + continue + + # Find bad regions more efficiently using precomputed strings + start_bad_region = -1 + for j in range(i, len(aligned_pairs)): + if not pair_strings[j][2]: # is_match is False + start_bad_region = j + break + + if start_bad_region == -1: + # No more bad regions - add remaining pairs and exit + processed_pairs.extend(aligned_pairs[i:]) + break + + # Add good pairs before bad region + processed_pairs.extend(aligned_pairs[i:start_bad_region]) + + # Optimized chunk processing with early termination + found_fix = False + max_chunk_size = min(10, len(aligned_pairs) - start_bad_region) # Limit search space + + for chunk_size in range(2, max_chunk_size + 1): + chunk = aligned_pairs[start_bad_region : start_bad_region + chunk_size] + + # Efficient token extraction using list comprehension + chunk_s1_tokens = [] + chunk_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *rest in chunk: + chunk_s1_tokens.extend(s1_toks) + chunk_s2_tokens.extend(s2_toks) + if s1_toks: + s1_indices.extend([s1_start, s1_end]) + if s2_toks: + s2_indices.extend([s2_start, s2_end]) + + # Quick string comparison using canonicalization + chunk_s1_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s1_tokens] + chunk_s2_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s2_tokens] + chunk_s1_str = "".join(chunk_s1_canonical) + chunk_s2_str = "".join(chunk_s2_canonical) + + if not TokenAligner._strings_equal_flexible(chunk_s1_str, chunk_s2_str, ignore_leading_char_diff): + continue + + # Create cache key for alignment + cache_key = (tuple(chunk_s1_tokens), tuple(chunk_s2_tokens)) + + if cache_key in alignment_cache: + sub_aligned_pairs, realign_is_perfect = alignment_cache[cache_key] + else: + # Perform alignment + sub_aligned_pairs, _ = TokenAligner.align_tokens_with_combinations_numpy( + chunk_s1_tokens, + chunk_s2_tokens, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + ignore_leading_char_diff=ignore_leading_char_diff + ) + + # Check if re-alignment was successful using canonicalization + realign_is_perfect = all( + TokenAligner._strings_equal_flexible( + "".join([TokenAligner._canonical_token(tok) for tok in p[0]]), + "".join([TokenAligner._canonical_token(tok) for tok in p[1]]), + ignore_leading_char_diff + ) + for p in sub_aligned_pairs + ) + + # Cache the result + alignment_cache[cache_key] = (sub_aligned_pairs, realign_is_perfect) + + # Vectorized index calculations + s1_chunk_start = min(s1_indices[::2]) if s1_indices else -1 + s2_chunk_start = min(s2_indices[::2]) if s2_indices else -1 + + if realign_is_perfect: + # Add granular aligned pairs + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *_ in sub_aligned_pairs: + new_s1_start = s1_chunk_start + s1_start if s1_start != -1 else -1 + new_s1_end = s1_chunk_start + s1_end if s1_end != -1 else -1 + new_s2_start = s2_chunk_start + s2_start if s2_start != -1 else -1 + new_s2_end = s2_chunk_start + s2_end if s2_end != -1 else -1 + processed_pairs.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + else: + # Create merged pair + s1_chunk_end = max(s1_indices[1::2]) if s1_indices else -1 + s2_chunk_end = max(s2_indices[1::2]) if s2_indices else -1 + merged_pair = (chunk_s1_tokens, chunk_s2_tokens, s1_chunk_start, s1_chunk_end, s2_chunk_start, s2_chunk_end) + processed_pairs.append(merged_pair) + + i = start_bad_region + chunk_size + found_fix = True + break + + if not found_fix: + processed_pairs.append(aligned_pairs[start_bad_region]) + i = start_bad_region + 1 + + return processed_pairs + + @staticmethod + def get_alignment_mask(aligned_pairs: List, use_canonicalization: bool = True, + ignore_leading_char_diff: bool = False) -> List[bool]: + """ + Get a boolean mask indicating which alignments are correct. + """ + if not aligned_pairs: + return [] + + # Handle batch case - take first batch + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + pairs_to_verify = aligned_pairs[0] + else: + pairs_to_verify = aligned_pairs + + mask = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest in pairs_to_verify: + # Concatenate tokens into strings + s1_str = "".join(s1_tokens) if s1_tokens else "" + s2_str = "".join(s2_tokens) if s2_tokens else "" + + # Apply canonicalization if requested + if use_canonicalization: + s1_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s1_tokens]) if s1_tokens else "" + s2_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s2_tokens]) if s2_tokens else "" + is_correct = TokenAligner._strings_equal_flexible(s1_canonical, s2_canonical, ignore_leading_char_diff) + else: + if ignore_leading_char_diff: + is_correct = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + else: + is_correct = s1_str == s2_str + + mask.append(is_correct) + + return mask + + + def transform_learned_matrix_instance(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Instance method version that uses instance variables. + """ + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.transform_learned_matrix(x, dim, enable_scale_trick=scale_trick_enabled) + + @staticmethod + def transform_learned_matrix(x: torch.Tensor, dim: int = -1, enable_scale_trick=None) -> torch.Tensor: + """ + Compute Quite Attention over tensor x along specified dimension. + + Args: + x: Input tensor. + dim: Dimension to apply attention over (default: -1). + + Returns: + Tensor of same shape with quite attention applied. + """ + if 0: + exp_x = torch.exp(x) + denom = 1 + torch.sum(exp_x, dim=dim, keepdim=True) + return exp_x / denom + # write as a single lambda function + # return lambda x: torch.exp(x) / (1 + torch.sum(torch.exp(x), dim=dim, keepdim=True)) + else: + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + if scale_trick_enabled: + #trick with last column being multiplier of 0..1, or try with c instead of 1 in qa. + scores = torch.nn.functional.softmax(x, dim=dim) + # Create a mask to zero out the last column while preserving gradients + # mask = torch.ones_like(scores) + # mask[:, -1] = 0.0 + # scores = scores * mask + # Alternative approach using sigmoid (commented out): + # scores = scores * torch.sigmoid(x[:, -1].unsqueeze(-1)) + return scores + else: + #normal softmax + return torch.nn.functional.softmax(x, dim=dim) + return torch.nn.functional.softmax(x, dim=dim) diff --git a/nemo_rl/utils/x_token/__init__.py b/nemo_rl/utils/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/utils/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/utils/x_token/minimal_projection_generator.py b/nemo_rl/utils/x_token/minimal_projection_generator.py new file mode 100644 index 0000000000..4c71dd1490 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_generator.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +from transformers import AutoTokenizer, AutoModel, AutoConfig +# from sentence_transformers import SentenceTransformer +from tqdm.auto import tqdm +import re +import pdb + + +##### verify KL and top5 with this matrix + + +###### use config vocab size, not tokenizer + +EXACT_MATCH_ONLY = False + +# --- Configuration and Setup --- +parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") +parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).") +parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).") +parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for source model (Model A / Student). If provided, overrides model_a_index.") +parser.add_argument("--model_b_name", type=str, default=None, help="HuggingFace model name for target model (Model B / Teacher). If provided, overrides model_b_index.") +parser.add_argument("--keep_top_tokens", type=int, default=-1, help="Number of top tokens to keep for each vocabulary. -1 means all.") +parser.add_argument("--data_dir", type=str, default="cross_tokenizer_data/", help="Directory for importance scores and cached data.") +parser.add_argument("--top_k", type=int, default=10, help="Number of top projections to keep for each token.") +parser.add_argument("--weight_threshold", type=float, default=0.0, help="Minimum weight threshold to keep a projection. Values below this will be filtered out.") +parser.add_argument("--force_recompute", action='store_true', help="Force recomputation of embeddings even if cached files exist.") +parser.add_argument("--skip_exact_enforcement", action='store_true', help="Skip enforcing exact matches between tokens.") +parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).") +args = parser.parse_args() + +args.skip_exact_enforcement = True + +MODEL_LIST = [ + "nvidia/Mistral-NeMo-Minitron-8B-Base", + "Qwen/Qwen3-8B-Base", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.1-8B", + "google/gemma-3-4b-it", + "google/gemma-2b", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "openai/gpt-oss-20b", + "microsoft/phi-4", + "google/gemma-3-12b-pt", +] +EMBEDDING_MODEL_CHOICES = [ + {"name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-mpnet-base-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-MiniLM-L6-v2", "type": "sbert"}, + {"name": "Qwen/Qwen3-Embedding-4B", "type": "llm_first_layer"}, + {"name": "Qwen/Qwen3-Embedding-0.6B", "type": "llm_first_layer"}, +] + +MAX_SEQ_LENGTH_EMBEDDING = 64 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def sinkhorn(A, n_iters=10): + for _ in range(n_iters): + if _ % 2 == 0: + # A = A / (A.sum(dim=0, keepdim=True) + 1e-6) + col_sums = A.sum(dim=0, keepdim=True) + safe_col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums) + A = A / safe_col_sums + else: + #0, 2, 4, 6 + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +# --- Helper Functions --- + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + if 'mini' in name: + cleaned_name += "_mini" + return cleaned_name + +def load_tokenizer(model_id_or_path): + """Loads a HuggingFace tokenizer, setting a pad token if necessary.""" + try: + tok = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + except Exception as e: + print(f"Error loading tokenizer for model '{model_id_or_path}': {e}") + print(f"Available models in MODEL_LIST (indices 0-{len(MODEL_LIST)-1}):") + for i, model in enumerate(MODEL_LIST): + print(f" {i}: {model}") + raise + +def validate_model_selection(args): + """Validates that the model selection arguments are valid.""" + # Check if both name and index are provided for the same model + if args.model_a_name is not None and args.model_a_index != 1: # 1 is the default + print("Warning: Both --model_a_name and --model_a_index provided. Using --model_a_name.") + + if args.model_b_name is not None and args.model_b_index != 0: # 0 is the default + print("Warning: Both --model_b_name and --model_b_index provided. Using --model_b_name.") + + # Validate indices if names are not provided + if args.model_a_name is None: + if args.model_a_index < 0 or args.model_a_index >= len(MODEL_LIST): + raise ValueError(f"model_a_index {args.model_a_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + if args.model_b_name is None: + if args.model_b_index < 0 or args.model_b_index >= len(MODEL_LIST): + raise ValueError(f"model_b_index {args.model_b_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + # Check if the same model is selected for both A and B + model_a_id = args.model_a_name if args.model_a_name is not None else MODEL_LIST[args.model_a_index] + model_b_id = args.model_b_name if args.model_b_name is not None else MODEL_LIST[args.model_b_index] + + if model_a_id == model_b_id: + raise ValueError(f"Cannot use the same model for both A and B: {model_a_id}") + +def save_data(data, filename): + """Saves data to a torch file.""" + os.makedirs(os.path.dirname(filename), exist_ok=True) + torch.save(data.cpu(), filename) + print(f"Data saved to {filename}") + +def load_data(filename): + """Loads data from a torch file.""" + return torch.load(filename) + +def get_llm_first_layer_embeddings(decoded_tokens_list, llm_embedding_tokenizer, llm_embedding_model, max_seq_length_embedding, device, batch_size=32): + """Generates embeddings using the first layer of a given LLM.""" + all_embeddings = [] + llm_embedding_model.eval() + embedding_dim = llm_embedding_model.config.hidden_size + + for i in tqdm(range(0, len(decoded_tokens_list), batch_size), desc="Encoding tokens with LLM"): + batch_tokens = decoded_tokens_list[i:i + batch_size] + inputs = llm_embedding_tokenizer( + batch_tokens, return_tensors="pt", padding=True, truncation=True, + max_length=max_seq_length_embedding, add_special_tokens=False, + ).to(device) + + with torch.no_grad(): + outputs = llm_embedding_model(**inputs, output_hidden_states=True) + first_layer_output = outputs.hidden_states[0] + + for k in range(first_layer_output.shape[0]): + valid_token_mask = inputs['attention_mask'][k] == 1 + if valid_token_mask.sum() > 0: + pooled_embedding = first_layer_output[k, valid_token_mask].mean(dim=0) + all_embeddings.append(pooled_embedding) + else: + all_embeddings.append(torch.zeros(embedding_dim, device=device)) + + return torch.stack(all_embeddings).to(device) + + +def compute_chunked_projection_map(embeddings_query, embeddings_corpus, args, device, chunk_size=1000): + """Computes projection map in chunks to save memory.""" + num_queries = embeddings_query.shape[0] + target_vocab_size = embeddings_corpus.shape[0] + + # Pre-allocate result tensors + all_top_k_indices = torch.zeros((num_queries, args.top_k), dtype=torch.long) + all_top_k_likelihoods = torch.zeros((num_queries, args.top_k), dtype=torch.float32) + + # Normalize corpus embeddings once + embeddings_corpus_norm = torch.nn.functional.normalize(embeddings_corpus.to(device).float(), p=2, dim=1) + + for chunk_start in tqdm(range(0, num_queries, chunk_size), desc="Processing chunks"): + chunk_end = min(chunk_start + chunk_size, num_queries) + chunk_query = embeddings_query[chunk_start:chunk_end].to(device).float() + + with torch.no_grad(): + # Compute similarities for this chunk + chunk_query_norm = torch.nn.functional.normalize(chunk_query, p=2, dim=1) + similarities = torch.matmul(chunk_query_norm, embeddings_corpus_norm.t()) + + # Generate projection map for this chunk + chunk_top_k_indices, chunk_top_k_likelihoods = generate_projection_map_chunk(similarities, args) + + # Store results + all_top_k_indices[chunk_start:chunk_end] = chunk_top_k_indices.cpu() + all_top_k_likelihoods[chunk_start:chunk_end] = chunk_top_k_likelihoods.cpu() + + # Clear GPU memory + del similarities, chunk_query_norm, chunk_top_k_indices, chunk_top_k_likelihoods + torch.cuda.empty_cache() + + return all_top_k_indices, all_top_k_likelihoods + +def generate_projection_map_chunk(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix chunk.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Normalize rows + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + threshold_mask = top_k_likelihood >= args.weight_threshold + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + return top_k_indices, top_k_likelihood + +def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + +def debug_projection_map(top_k_indices, top_k_likelihood, source_tokenizer, target_tokenizer, direction="", N=2000): + """Debug function to show first N rows with decoded tokens and weights.""" + N = min(N, top_k_indices.shape[0]) # Show first N rows or less + print(f"\n--- Debugging projection map {direction} (first {N} rows) ---") + + for row_idx in range(N): + # for row_idx in range(-N,-1): + # Decode source token + try: + token_id = row_idx if row_idx >= 0 else top_k_indices.shape[0] + row_idx + source_token = source_tokenizer.decode([token_id]) + # source_token = source_tokenizer.convert_ids_to_tokens([token_id])[0] + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Build the target tokens with weights string + row_indices = top_k_indices[row_idx].cpu().numpy() + row_weights = top_k_likelihood[row_idx].float().cpu().numpy() + + weight_total = 0 + target_parts = [] + + if row_weights.max() != row_weights[-1]: + continue + + + for target_idx, weight in zip(row_indices, row_weights): + try: + target_token = target_tokenizer.decode([target_idx]) + target_token_str = repr(target_token) + except: + target_token_str = f"" + + + target_parts.append(f"{target_token_str}({weight:.4f})") + weight_total += weight + + target_string = " ".join(target_parts) + # print(f"Weight total: {weight_total:.4f}") + print(f"{source_token_str} -> {target_string}") + +def generate_projection_map(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Create a sparse representation by keeping only top-k values + # top_k_likelihood_pre_norm, _ = likelihood.topk(args.top_k, dim=1) + # likelihood = likelihood.where(likelihood >= top_k_likelihood_pre_norm[:, -1:], torch.zeros_like(likelihood)) + + # Normalize the row to sum to 1, handling rows that are all zero + # row_sums = likelihood.sum(dim=1, keepdim=True) + # safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + # likelihood = likelihood / safe_row_sums + # pdb.set_trace() + # likelihood = sinkhorn_one_dim(likelihood) + + # Get the final top-k values and their indices from the sparse, normalized likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Store top-k values before zeroing (to avoid losing them) + row_indices = torch.arange(likelihood.shape[0]).unsqueeze(1).expand(-1, args.top_k) + top_k_values = likelihood[row_indices, top_k_indices].clone() + + # Zero out entire likelihood matrix in-place, then restore only top-k elements + likelihood.zero_() + likelihood[row_indices, top_k_indices] = top_k_values + + # likelihood = sinkhorn(likelihood, n_iters=1) + # likelihood = sinkhorn(likelihood, n_iters=1) works the best + + + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + print(f"Applying weight threshold filter: {args.weight_threshold}") + # Create mask for values above threshold + # pdb.set_trace() + threshold_mask = top_k_likelihood >= args.weight_threshold + + #set indices to -1 where threshold is not met + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # # Count how many values per row are above threshold + # valid_counts = threshold_mask.sum(dim=1) + # total_filtered = (valid_counts == 0).sum().item() + # total_kept = threshold_mask.sum().item() + # total_possible = top_k_likelihood.numel() + + # print(f"Kept {total_kept}/{total_possible} ({100*total_kept/total_possible:.1f}%) projections above threshold") + + # if total_filtered > 0: + # print(f"Warning: {total_filtered} tokens have no projections above threshold {args.weight_threshold}") + + # # Zero out values below threshold + # filtered_likelihood = top_k_likelihood * threshold_mask.to(top_k_likelihood.dtype) + # filtered_indices = top_k_indices.clone() + + # # For rows with no values above threshold, keep the top value to avoid empty rows + # empty_rows = valid_counts == 0 + # if empty_rows.any(): + # print(f"Keeping top projection for {empty_rows.sum().item()} tokens with no values above threshold") + # filtered_likelihood[empty_rows, 0] = top_k_likelihood[empty_rows, 0] + + # top_k_likelihood = filtered_likelihood + # top_k_indices = filtered_indices + + # pdb.set_trace() + + return top_k_indices, top_k_likelihood + +# --- Main Execution --- +if __name__ == "__main__": + # Validate model selection arguments + validate_model_selection(args) + + # 1. Load Tokenizers and deterministically assign A and B + # Use model names if provided, otherwise use indices + if args.model_a_name is not None: + model_1 = {'id': args.model_a_name} + print(f"Using provided model A name: {args.model_a_name}") + else: + model_1 = {'id': MODEL_LIST[args.model_a_index]} + print(f"Using model A from index {args.model_a_index}: {model_1['id']}") + + model_1['name'] = model_1['id'].split("/")[-1] + print(f"Loading first tokenizer: {model_1['name']}") + model_1['tokenizer'] = load_tokenizer(model_1['id']) + + if args.model_b_name is not None: + model_2 = {'id': args.model_b_name} + print(f"Using provided model B name: {args.model_b_name}") + else: + model_2 = {'id': MODEL_LIST[args.model_b_index]} + print(f"Using model B from index {args.model_b_index}: {model_2['id']}") + + model_2['name'] = model_2['id'].split("/")[-1] + print(f"Loading second tokenizer: {model_2['name']}") + model_2['tokenizer'] = load_tokenizer(model_2['id']) + + # Deterministically assign model_A and model_B based on alphabetical order of names + if model_1['name'] > model_2['name']: + model_A, model_B = model_2, model_1 + else: + model_A, model_B = model_1, model_2 + + print(f"\nAssigned Source (A): {model_A['name']}") + print(f"Assigned Target (B): {model_B['name']}") + + source_vocab_size = model_A['tokenizer'].vocab_size + target_vocab_size = model_B['tokenizer'].vocab_size + # get the top k tokens from the source and target vocab from model config file + model_A_config = AutoConfig.from_pretrained(model_A['id'], trust_remote_code=True if 'nvidia' in model_A['id'] else False) + model_B_config = AutoConfig.from_pretrained(model_B['id'], trust_remote_code=True if 'nvidia' in model_B['id'] else False) + # pdb.set_trace() + source_vocab_size = model_A_config.vocab_size + if "gemma" not in model_B['id']: + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Source vocab size (full): {source_vocab_size}") + print(f"Target vocab size (full): {target_vocab_size}") + # exit() + + + + if 0: + # just debugging learned projection map + # learned_projection_map = torch.load("models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_transformation_matrices/learned_projection_map_latest.pt") + # learned_projection_map = torch.load("cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_64_double.pt") + learned_projection_map = torch.load("cross_tokenizer_data/projection_matrix_learned_llama_qwen_top5.pt") + top_k_indices_A_to_B = learned_projection_map["indices"] + top_k_likelihood_A_to_B = learned_projection_map["likelihoods"] + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B", N=150000) + exit() + + + + + + + + # 2. Select and Load Embedding Model + embedding_model_index = 3 # Default to a good LLM embedder + selected_model_info = EMBEDDING_MODEL_CHOICES[embedding_model_index] + embedding_model_name = selected_model_info["name"] + embedding_model_type = selected_model_info["type"] + print(f"\nUsing embedding model: {embedding_model_name} ({embedding_model_type})") + + # 3. Generate or Load Embeddings + canonicalization_suffix = "_canonical" if args.use_canonicalization else "_raw" + embeddings_path_A = os.path.join(args.data_dir, f"embeddings_{model_A['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + embeddings_path_B = os.path.join(args.data_dir, f"embeddings_{model_B['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + + if not args.force_recompute and os.path.exists(embeddings_path_A) and os.path.exists(embeddings_path_B): + print("Loading cached embeddings...") + model_A['embeddings'] = load_data(embeddings_path_A).to(DEVICE) + model_B['embeddings'] = load_data(embeddings_path_B).to(DEVICE) + else: + print("Generating new embeddings...") + + # Generate raw decoded tokens + raw_tokens_A = [model_A['tokenizer'].decode([idx]) for idx in range(model_A['tokenizer'].vocab_size)] + raw_tokens_B = [model_B['tokenizer'].decode([idx]) for idx in range(model_B['tokenizer'].vocab_size)] + + # Apply canonicalization if requested + if args.use_canonicalization: + # Import canonicalization function + import sys + sys.path.append('.') + from tokenalign import TokenAligner + + print("Applying token canonicalization before embedding generation...") + decoded_tokens_A = [TokenAligner._canonical_token(token) for token in raw_tokens_A] + decoded_tokens_B = [TokenAligner._canonical_token(token) for token in raw_tokens_B] + + # Show some examples of canonicalization + print("Canonicalization examples:") + for i in range(min(10, len(raw_tokens_A))): + if raw_tokens_A[i] != decoded_tokens_A[i]: + print(f" Model A: '{raw_tokens_A[i]}' -> '{decoded_tokens_A[i]}'") + for i in range(min(10, len(raw_tokens_B))): + if raw_tokens_B[i] != decoded_tokens_B[i]: + print(f" Model B: '{raw_tokens_B[i]}' -> '{decoded_tokens_B[i]}'") + + print(f"Applied canonicalization to {len(decoded_tokens_A)} tokens for model A and {len(decoded_tokens_B)} tokens for model B") + else: + print("Using raw decoded tokens without canonicalization") + decoded_tokens_A = raw_tokens_A + decoded_tokens_B = raw_tokens_B + + if embedding_model_type == "sbert": + sbert_model = SentenceTransformer(embedding_model_name, device=DEVICE) + model_A['embeddings'] = sbert_model.encode(decoded_tokens_A, convert_to_tensor=True, show_progress_bar=True) + model_B['embeddings'] = sbert_model.encode(decoded_tokens_B, convert_to_tensor=True, show_progress_bar=True) + elif embedding_model_type == "llm_first_layer": + llm_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + if llm_tokenizer.pad_token is None: llm_tokenizer.pad_token = llm_tokenizer.eos_token + llm_model = AutoModel.from_pretrained( + embedding_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ).to(DEVICE) + model_A['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_A, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + model_B['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_B, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + + save_data(model_A['embeddings'], embeddings_path_A) + save_data(model_B['embeddings'], embeddings_path_B) + + # 4. Compute Similarity and Generate Projection Maps (chunked to save memory) + print("\nComputing projection map in chunks to save memory...") + chunk_size = 500 # Process 500 tokens at a time to avoid OOM + top_k_indices_A_to_B, top_k_likelihood_A_to_B = compute_chunked_projection_map( + model_A['embeddings'], model_B['embeddings'], args, DEVICE, chunk_size=chunk_size) + + # Note: Exact match enforcement is skipped in chunked mode for simplicity + # The chunked approach processes similarities in small batches to avoid OOM + if 0: + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B") + + # print("Generating B -> A projection map...") + # top_k_indices_B_to_A, top_k_likelihood_B_to_A = generate_projection_map(similarities.T, args) + # debug_projection_map(top_k_indices_B_to_A, top_k_likelihood_B_to_A, model_B['tokenizer'], model_A['tokenizer'], "B -> A") + + # 5. Save the Combined Projection Map + print("\nSaving combined projection map...") + model_a_clean_name = clean_model_name_for_filename(model_A['name']) + model_b_clean_name = clean_model_name_for_filename(model_B['name']) + # output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_bidirectional_top_{args.top_k}.pt" + output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_top_{args.top_k}" + # if args.skip_exact_enforcement: + # output_filename += "_no_exact" + output_filename += f".pt" + if args.weight_threshold > 0.0: + output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") + output_path = os.path.join(args.data_dir, output_filename) + + torch.save({ + "indices": top_k_indices_A_to_B.cpu(), + "likelihoods": top_k_likelihood_A_to_B.cpu(), + "model_A_id": model_A['id'], + "model_B_id": model_B['id'], + }, output_path) + + # torch.save({ + # "A_to_B": { + # "indices": top_k_indices_A_to_B.cpu(), + # "likelihoods": top_k_likelihood_A_to_B.cpu() + # }, + # "B_to_A": { + # "indices": top_k_indices_B_to_A.cpu(), + # "likelihoods": top_k_likelihood_B_to_A.cpu() + # }, + # "model_A_id": model_A['id'], + # "model_B_id": model_B['id'], + # }, output_path) + print(f"Saved combined projection map to: {output_path}") + + # 6. Example Usage of the Projection Function + print("\n--- Testing projection function (A -> B) ---") + # Create a dummy likelihood tensor: [BATCH, SEQ, vocab_size_A] + source_vocab_size_A = model_A['embeddings'].shape[0] + target_vocab_size_B = model_B['embeddings'].shape[0] + dummy_tensor = torch.randn(1, 4096, source_vocab_size_A, device=DEVICE, dtype=torch.bfloat16) + + # Transform this tensor using the projection map (convert to float32 for compatibility) + projected_tensor = project_token_likelihoods(dummy_tensor.float(), top_k_indices_A_to_B, top_k_likelihood_A_to_B, target_vocab_size_B, DEVICE) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful.") diff --git a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py new file mode 100644 index 0000000000..ec67e4cc57 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py @@ -0,0 +1,942 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + + + +###### save as dense format and set indices to -1 where not used + + +#remove all special tokens that start with <| and end with |> + + +# compare 3 ways to estimate likelihood matrix: +# 1. using embeddings from another model, like was done in minimal_projection_generator.py +# 2. using text analysis like in tokenalign_likelihood_estimate.py +# 3. use one token to multiple and assign those as transformation matrix + +# this file implements 3rd way + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def create_weight_distribution(num_tokens): + """Create weight distribution for multi-token mappings""" + # if num_tokens == 1: + # return [1.0] + # elif num_tokens == 2: + # return [0.7, 0.3] + # elif num_tokens == 3: + # return [0.6, 0.3, 0.1] + # else: + if 1: + # For more tokens, use exponential decay + weights = [] + base = 0.9 + for i in range(num_tokens): + if i == 0: + weights.append(base) + else: + weights.append(base * (0.1 ** i)) + + # Normalize to sum to 1 + total = sum(weights) + weights = [w / total for w in weights] + return weights + +def find_similar_special_tokens(tokenizer_a, tokenizer_b, similarity_threshold=0.4, top_k_matches=3): + """Find similar special tokens between two tokenizers using string similarity.""" + + def is_special_token(token_str): + """Check if a token looks like a special token""" + return (token_str.startswith('<|') and token_str.endswith('|>')) or \ + (token_str.startswith('<') and token_str.endswith('>')) or \ + token_str in ['', '', '', '', '', ''] + + def extract_special_tokens(tokenizer): + """Extract all special tokens from a tokenizer with their IDs""" + special_tokens = {} + vocab = tokenizer.get_vocab() + for token_str, token_id in vocab.items(): + if is_special_token(token_str): + special_tokens[token_id] = token_str + return special_tokens + + def calculate_similarity(token_a, token_b): + """Calculate similarity between two token strings""" + # Use difflib for sequence similarity + seq_similarity = difflib.SequenceMatcher(None, token_a, token_b).ratio() + + # Extract key words from special tokens for semantic matching + def extract_keywords(token): + # Remove special token markers and split by common separators + cleaned = re.sub(r'[<>|_]', ' ', token.lower()) + words = [w for w in cleaned.split() if len(w) > 2] # Filter short words + return set(words) + + keywords_a = extract_keywords(token_a) + keywords_b = extract_keywords(token_b) + + # Jaccard similarity for keywords + if keywords_a or keywords_b: + keyword_similarity = len(keywords_a.intersection(keywords_b)) / len(keywords_a.union(keywords_b)) + else: + keyword_similarity = 0.0 + + # Combined similarity (weighted average) + return 0.6 * seq_similarity + 0.4 * keyword_similarity + + print("Extracting special tokens...") + special_tokens_a = extract_special_tokens(tokenizer_a) # student + special_tokens_b = extract_special_tokens(tokenizer_b) # teacher + + print(f"Found {len(special_tokens_a)} special tokens in student tokenizer") + print(f"Found {len(special_tokens_b)} special tokens in teacher tokenizer") + + # Find matches + special_token_mappings = [] + + print("Finding similar special tokens...") + for token_id_a, token_str_a in special_tokens_a.items(): + similarities = [] + for token_id_b, token_str_b in special_tokens_b.items(): + similarity = calculate_similarity(token_str_a, token_str_b) + if similarity >= similarity_threshold: + similarities.append((token_id_b, token_str_b, similarity)) + + # Sort by similarity and take top-k + similarities.sort(key=lambda x: x[2], reverse=True) + for token_id_b, token_str_b, similarity in similarities[:top_k_matches]: + special_token_mappings.append({ + 'student_id': token_id_a, + 'student_token': token_str_a, + 'teacher_id': token_id_b, + 'teacher_token': token_str_b, + 'similarity': similarity + }) + + return special_token_mappings + + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + + # Configuration from arguments + ENABLE_SCALE_TRICK = args.enable_scale_trick + ENABLE_REVERSE_PASS = args.enable_reverse_pass + ENABLE_EXACT_MATCH = args.enable_exact_match + + TOKENS_TO_CUT = args.tokens_to_cut + TOP_K = args.top_k + USE_RAW_TOKENS = args.use_raw_tokens + INITIAL_PROJECTION_PATH = args.initial_projection_path + ENABLE_SPECIAL_TOKEN_MAPPING = args.enable_special_token_mapping + SPECIAL_TOKEN_SIMILARITY_THRESHOLD = args.special_token_similarity_threshold + SPECIAL_TOKEN_TOP_K = args.special_token_top_k if args.special_token_top_k is not None else TOP_K + USE_CANONICALIZATION = args.use_canonicalization + + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + + # Print configuration + print("=== Configuration ===") + print(f"Student model: {student_model_name}") + print(f"Teacher model: {teacher_model_name}") + print(f"Enable scale trick: {ENABLE_SCALE_TRICK}") + print(f"Enable reverse pass: {ENABLE_REVERSE_PASS}") + print(f"Enable exact match: {ENABLE_EXACT_MATCH}") + print(f"Use raw tokens: {USE_RAW_TOKENS}") + print(f"Use canonicalization: {USE_CANONICALIZATION}") + print(f"Tokens to cut: {TOKENS_TO_CUT}") + print(f"Top K: {TOP_K}") + print(f"Enable special token mapping: {ENABLE_SPECIAL_TOKEN_MAPPING}") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"Special token similarity threshold: {SPECIAL_TOKEN_SIMILARITY_THRESHOLD}") + print(f"Special token top K: {SPECIAL_TOKEN_TOP_K}") + print(f"Initial projection path: {INITIAL_PROJECTION_PATH}") + print(f"Output directory: {args.output_dir}") + print("=" * 25) + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + # pdb.set_trace() + if "gemma" not in student_model_name.lower(): + source_vocab_size = model_A_config.vocab_size + else: + source_vocab_size = model_A_config.text_config.vocab_size + + if "gemma" not in teacher_model_name.lower(): + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + + tokenizer_student_total_vocab_size = source_vocab_size + tokenizer_teacher_total_vocab_size = target_vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Student tokenizer total vocab size: {tokenizer_student_total_vocab_size}") + print(f"Teacher tokenizer total vocab size: {tokenizer_teacher_total_vocab_size}") + + # Print token processing mode + if USE_RAW_TOKENS: + print("Using raw token representation (convert_ids_to_tokens)") + else: + print("Using decoded token representation (decode)") + + + transformation_counts = defaultdict(float) + import os + if INITIAL_PROJECTION_PATH and os.path.exists(INITIAL_PROJECTION_PATH): + print(f"Loading initial projection from: {INITIAL_PROJECTION_PATH}") + initial_projection_map = torch.load(INITIAL_PROJECTION_PATH, map_location='cpu') + + if isinstance(initial_projection_map, dict) and 'indices' in initial_projection_map and 'likelihoods' in initial_projection_map: + print("Loading from sparse top-k format and converting to transformation_counts.") + indices = initial_projection_map['indices'] + likelihoods = initial_projection_map['likelihoods'] + + loaded_student_model = initial_projection_map.get('model_A_id') + loaded_teacher_model = initial_projection_map.get('model_B_id') + + if loaded_student_model and loaded_student_model != student_model_name: + print(f"Warning: Student model mismatch. Loaded: {loaded_student_model}, Current: {student_model_name}") + if loaded_teacher_model and loaded_teacher_model != teacher_model_name: + print(f"Warning: Teacher model mismatch. Loaded: {loaded_teacher_model}, Current: {teacher_model_name}") + + num_student_tokens = indices.shape[0] + top_k = indices.shape[1] + + for student_id in tqdm.tqdm(range(num_student_tokens), desc="Converting initial projection to counts"): + for k in range(top_k): + teacher_id = indices[student_id, k].item() + if teacher_id != -1: + likelihood = likelihoods[student_id, k].item() + if likelihood > 0: + transformation_counts[(student_id, teacher_id)] = likelihood + + elif torch.is_tensor(initial_projection_map): + if initial_projection_map.is_sparse: + print("Loading from sparse tensor and converting to transformation_counts.") + sparse_matrix = initial_projection_map.coalesce() + map_indices = sparse_matrix.indices() + map_values = sparse_matrix.values() + for i in tqdm.tqdm(range(map_indices.shape[1]), desc="Converting sparse tensor to counts"): + student_id = map_indices[0, i].item() + teacher_id = map_indices[1, i].item() + weight = map_values[i].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print("Loading from dense matrix and converting to transformation_counts.") + dense_matrix = initial_projection_map + non_zero_indices = torch.nonzero(dense_matrix, as_tuple=False) + for idx in tqdm.tqdm(range(non_zero_indices.shape[0]), desc="Converting dense projection to counts"): + student_id = non_zero_indices[idx, 0].item() + teacher_id = non_zero_indices[idx, 1].item() + weight = dense_matrix[student_id, teacher_id].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print(f"Warning: Unrecognized format for initial projection map at {INITIAL_PROJECTION_PATH}. Skipping.") + + print(f"Initialized transformation_counts with {len(transformation_counts)} entries.") + + # pdb.set_trace() + + ignore_tokens = ['<|endoftext|>', '', ] + ignore_student_ids = {tokenizer_student.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_student.get_vocab()} + ignore_teacher_ids = {tokenizer_teacher.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_teacher.get_vocab()} + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + # Find multi-token mappings + multi_token_examples = [] + + + print("=== FIRST PASS: Teacher tokens -> Student tokens ===") + print("Finding multi-token mappings...") + + + # First pass: Teacher tokens -> Student tokens (reverse direction) + if 1: + print("\n=== First PASS: Student tokens -> Teacher tokens ===") + + # Get all student tokens and decode them + student_vocab = tokenizer_student.get_vocab() + + student_tokens_decoded = {} + + print("Decoding student tokens...") + for token_id in tqdm.tqdm(range(tokenizer_student_total_vocab_size), desc="Decoding student tokens"): + if token_id in ignore_student_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_student.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_student.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + student_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(student_tokens_decoded)} student tokens") + + reverse_multi_token_examples = [] + print("Finding reverse multi-token mappings...") + for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"): + # Tokenize the student token string using teacher tokenizer + teacher_encoding = tokenizer_teacher(student_token_str, add_special_tokens=False, return_attention_mask=False) + teacher_token_ids = teacher_encoding['input_ids'] + + # Skip if any teacher token is in ignore list + if any(tid in ignore_teacher_ids for tid in teacher_token_ids): + continue + + # Cut to only first 4 tokens + teacher_token_ids = teacher_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of teacher tokens + weights = create_weight_distribution(len(teacher_token_ids)) + + # Add to transformation matrix (reverse direction: teacher_token_id -> student_token_id) + if 1: + for teacher_token_id, weight in zip(teacher_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(teacher_token_ids) >= 2: + teacher_tokens_decoded_reverse = [tokenizer_teacher.decode([tid]) for tid in teacher_token_ids] + reverse_multi_token_examples.append({ + 'student_token': student_token_str, + 'student_id': student_token_id, + 'teacher_tokens': teacher_tokens_decoded_reverse, + 'teacher_ids': teacher_token_ids, + 'weights': weights + }) + + # second pass: Teacher tokens -> Student tokens (opposite direction) + if ENABLE_REVERSE_PASS: + print("\n=== secod PASS: Teacher tokens -> Student tokens ===") + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + teacher_to_student_multi_token_examples = [] + print("Finding teacher->student multi-token mappings...") + for teacher_token_id, teacher_token_str in tqdm.tqdm(teacher_tokens_decoded.items(), desc="Processing teacher tokens"): + # Tokenize the teacher token string using student tokenizer + student_encoding = tokenizer_student(teacher_token_str, add_special_tokens=False, return_attention_mask=False) + student_token_ids = student_encoding['input_ids'] + + # Skip if any student token is in ignore list + if any(sid in ignore_student_ids for sid in student_token_ids): + continue + + # Cut to only first 4 tokens + student_token_ids = student_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of student tokens + weights = create_weight_distribution(len(student_token_ids)) + + # Add to transformation matrix (student_token_id -> teacher_token_id mapping) + if 1: + for student_token_id, weight in zip(student_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(student_token_ids) >= 2: + student_tokens_decoded_reverse = [tokenizer_student.decode([sid]) for sid in student_token_ids] + teacher_to_student_multi_token_examples.append({ + 'teacher_token': teacher_token_str, + 'teacher_id': teacher_token_id, + 'student_tokens': student_tokens_decoded_reverse, + 'student_ids': student_token_ids, + 'weights': weights + }) + + print(f"\n=== ADDING SPECIAL TOKEN MAPPINGS ===") + + # Find and add special token mappings (if enabled) + special_token_mappings = [] + if ENABLE_SPECIAL_TOKEN_MAPPING: + special_token_mappings = find_similar_special_tokens( + tokenizer_student, + tokenizer_teacher, + similarity_threshold=SPECIAL_TOKEN_SIMILARITY_THRESHOLD, + top_k_matches=SPECIAL_TOKEN_TOP_K + ) + else: + print("Special token mapping disabled") + + if special_token_mappings: + print(f"\nFound {len(special_token_mappings)} special token mappings:") + initial_transformation_count = len(transformation_counts) + + # Add ALL mappings to transformation matrix + for mapping in special_token_mappings: + student_id = mapping['student_id'] + teacher_id = mapping['teacher_id'] + similarity = mapping['similarity'] + + # Add mapping with weight based on similarity + weight = similarity * 0.8 # Scale similarity to reasonable weight + transformation_counts[(student_id, teacher_id)] += weight + + # Group mappings by student token and show top 2 matches per student token + from collections import defaultdict + student_mappings = defaultdict(list) + for mapping in special_token_mappings: + student_mappings[mapping['student_id']].append(mapping) + + # Sort each student's mappings by similarity and show top 2 + print("Top 2 matches per student special token:") + shown_count = 0 + for student_id, mappings in student_mappings.items(): + # Sort by similarity (highest first) + sorted_mappings = sorted(mappings, key=lambda x: x['similarity'], reverse=True) + + # Show top 2 for this student token + student_token = sorted_mappings[0]['student_token'] # Get student token name + print(f" {student_token}:") + + for mapping in sorted_mappings[:2]: + similarity = mapping['similarity'] + weight = similarity * 0.8 + print(f" -> '{mapping['teacher_token']}' (similarity: {similarity:.3f}, weight: {weight:.3f})") + shown_count += 1 + + if len(sorted_mappings) > 2: + print(f" ... and {len(sorted_mappings) - 2} more matches") + + total_hidden = len(special_token_mappings) - shown_count + if total_hidden > 0: + print(f"Total mappings not shown: {total_hidden}") + + added_count = len(transformation_counts) - initial_transformation_count + print(f"Added {added_count} new special token transformation entries") + else: + print("No similar special tokens found") + + print(f"\n=== SUMMARY ===") + print(f"Found {len(multi_token_examples)} teacher tokens that map to multiple student tokens") + # exit() + # Show some examples + if multi_token_examples: + print("\nExamples of multi-token mappings:") + for i, example in enumerate(multi_token_examples[:10]): + print(f" Teacher '{example['teacher_token']}' -> Student {example['student_tokens']} (weights: {example['weights']})") + if len(multi_token_examples) > 10: + print(f" ... and {len(multi_token_examples) - 10} more.") + + if ENABLE_REVERSE_PASS: + print(f"\nReverse pass enabled - added bidirectional mappings") + + print(f"\nTotal transformation entries: {len(transformation_counts)}") + + if ENABLE_EXACT_MATCH: + + print("Checking for exact token matches and setting exact mappings...") + # check exact match between student and teacher tokens and set those as perfect 1-to-1 mappings + # Convert all tokens to strings at once for vectorized comparison + # pdb.set_trace() + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # For tokens that match exactly, we want their mapping to be 1.0 + # and they should not be mapped to any other token. + # First, remove all existing mappings for these student tokens + match_indices_student_set = set(match_indices_student) + keys_to_remove = [] + for key in transformation_counts.keys(): + student_id, teacher_id = key + if student_id in match_indices_student_set: + keys_to_remove.append(key) + for key in keys_to_remove: + del transformation_counts[key] + # Then, set the perfect 1-to-1 mappings for exact matches + for student_id, teacher_id in zip(match_indices_student, match_indices_teacher): + transformation_counts[(student_id, teacher_id)] = 1.0 + + + def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50): + """Debug function to show projection mappings with decoded tokens and weights.""" + print(f"\n--- Debugging projection map {direction} (showing {N} examples) ---") + + # Group transformation_counts by source token (student token) + source_to_targets = defaultdict(list) + for (source_id, target_id), weight in transformation_counts.items(): + source_to_targets[source_id].append((target_id, weight)) + + # Sort by source token ID and take first N + # sorted_sources = sorted(source_to_targets.keys())[:N] + sorted_sources = sorted(source_to_targets.keys())[-N:] + + for source_id in sorted_sources: + # Decode source token + try: + if USE_RAW_TOKENS: + source_token = source_tokenizer.convert_ids_to_tokens([source_id])[0] + else: + source_token = source_tokenizer.decode([source_id]) + source_token = apply_canonicalization_if_enabled(source_token, USE_CANONICALIZATION) + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Sort targets by weight (descending) and build target string + targets_weights = sorted(source_to_targets[source_id], key=lambda x: x[1], reverse=True) + + target_parts = [] + for target_id, weight in targets_weights: + try: + if USE_RAW_TOKENS: + target_token = target_tokenizer.convert_ids_to_tokens([target_id])[0] + else: + target_token = target_tokenizer.decode([target_id]) + target_token = apply_canonicalization_if_enabled(target_token, USE_CANONICALIZATION) + target_token_str = repr(target_token) + except: + target_token_str = f"" + target_parts.append(f"{target_token_str}({weight:.4f})") + + target_string = " ".join(target_parts) + print(f"{source_token_str} -> {target_string}") + + # debug_projection_map(transformation_counts, tokenizer_student, tokenizer_teacher, + # direction="student->teacher", N=1000) + + # Create transformation matrix (student -> teacher projection) + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + teacher_indices = [idx[1] for idx in indices] + student_indices = [idx[0] for idx in indices] + + # Create sparse tensor with student tokens as rows, teacher tokens as columns + # This creates a student -> teacher projection matrix + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values) + + transformation_matrix_sparse = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (tokenizer_student_total_vocab_size, tokenizer_teacher_total_vocab_size), + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.bfloat16 + ) + + # indices, values = torch.topk(transformation_matrix_sparse, k=1000, dim=1) + + print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}") + print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}") + + if 0: + # cant fit to the memory + # Calculate mapping statistics from sparse matrix + print("\nCalculating mapping statistics from projection matrix...") + + # Count non-zero elements in each row (each row = student token) + dense_matrix = transformation_matrix_sparse.to_dense() + non_zero_counts_per_row = (dense_matrix != 0).sum(dim=1) # Count non-zeros per row + + # Create statistics + mapping_stats = defaultdict(int) + for count in non_zero_counts_per_row: + mapping_stats[count.item()] += 1 + + # Print mapping statistics + print("\nMapping statistics (student tokens -> teacher tokens):") + for i in range(1, 5): # 1, 2, 3, 4 teacher tokens + count = mapping_stats.get(i, 0) + print(f"Student tokens mapping to {i} teacher tokens: {count}") + + total_mapped = sum(mapping_stats.values()) + print(f"Total student tokens mapped: {total_mapped}") + + # Convert sparse matrix to same format as minimal_projection_generator.py + os.makedirs(args.output_dir, exist_ok=True) + + # Convert defaultdict to regular dict for saving + transformation_counts_dict = dict(transformation_counts) + + # Show some examples of the projection mappings + debug_projection_map(transformation_counts_dict, tokenizer_student, tokenizer_teacher, + direction="student->teacher", N=1000) + + # exit() + + print(f"\nConverting sparse matrix to top-{TOP_K} dense format...") + + # Convert sparse matrix to dense and get top-k values per row + print("Converting to dense matrix on CPU to avoid memory issues...") + dense_matrix = transformation_matrix_sparse.cpu().to_dense() # Move to CPU to handle memory + print(f"Dense matrix shape: {dense_matrix.shape}") + + # Get top-k values and indices for each row (each source token) + print(f"Extracting top-{TOP_K} values per token...") + + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + dense_matrix = sinkhorn_one_dim(dense_matrix, n_iters=1) + + # Extract top-k on CPU + top_k_likelihoods, top_k_indices = torch.topk(dense_matrix, k=min(TOP_K, dense_matrix.shape[1]), dim=1) + # exit() + # Handle case where vocabulary has fewer tokens than TOP_K + actual_k = top_k_indices.shape[1] + if actual_k < TOP_K: + print(f"Warning: Target vocabulary size ({dense_matrix.shape[1]}) is smaller than TOP_K ({TOP_K}). Using k={actual_k}") + # Pad with -1 indices and 0.0 likelihoods to maintain consistent shape + pad_size = TOP_K - actual_k + top_k_indices = torch.cat([top_k_indices, torch.full((top_k_indices.shape[0], pad_size), -1, dtype=top_k_indices.dtype)], dim=1) + top_k_likelihoods = torch.cat([top_k_likelihoods, torch.zeros((top_k_likelihoods.shape[0], pad_size), dtype=top_k_likelihoods.dtype)], dim=1) + + if 0: + threshold_mask = top_k_likelihoods >= 0.0000000000000000001 + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # Apply SCALE_TRICK: set last column to -4 if enabled + if ENABLE_SCALE_TRICK: + print("ENABLE_SCALE_TRICK is True: Setting last column of likelihoods to -4.0") + top_k_likelihoods[:, -1] = 0.2 + if ENABLE_EXACT_MATCH: + for indices in match_indices_student: + top_k_likelihoods[indices, -1] = 0.0 + print(f"Set last column of likelihoods to 0.0 for {len(match_indices_student)} exact matches as exact match is enabled") + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + top_k_likelihoods = sinkhorn_one_dim(top_k_likelihoods, n_iters=1) + + # pdb.set_trace() + #set indices to -1 where likelihood is 0 + + # Create filename in same format as minimal_projection_generator.py + def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + + output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" + # if USE_RAW_TOKENS: + # output_filename += "_raw_tokens" + if ENABLE_SPECIAL_TOKEN_MAPPING: + output_filename += "_special" + output_filename += ".pt" + # if ENABLE_REVERSE_PASS: + # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + output_path = os.path.join(args.output_dir, output_filename) + + # Save in same format as minimal_projection_generator.py + torch.save({ + "indices": top_k_indices, + "likelihoods": top_k_likelihoods, + "model_A_id": student_model_name, # source model (student) + "model_B_id": teacher_model_name, # target model (teacher) + }, output_path) + + print(f"Saved projection map to: {output_path}") + print(f"Format: indices shape {top_k_indices.shape}, likelihoods shape {top_k_likelihoods.shape}") + print(f"Compatible with minimal_projection_generator.py format") + print(f"Token processing mode: {'Raw tokens (convert_ids_to_tokens)' if USE_RAW_TOKENS else 'Decoded tokens (decode)'}") + if ENABLE_REVERSE_PASS: + print("File includes bidirectional mappings (teacher->student and student->teacher)") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"File includes special token mappings (similarity_threshold={SPECIAL_TOKEN_SIMILARITY_THRESHOLD}, top_k={SPECIAL_TOKEN_TOP_K})") + # exit() + + + + + # Test projection function compatibility (same as minimal_projection_generator.py) + print("\n--- Testing projection function compatibility ---") + + def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + + # Create a dummy likelihood tensor: [BATCH, SEQ, source_vocab_size] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_tensor = torch.randn(1, 4096, tokenizer_student_total_vocab_size, device=device, dtype=torch.bfloat16) + + # Transform this tensor using the projection map + projected_tensor = project_token_likelihoods( + dummy_tensor, + top_k_indices.to(device), + top_k_likelihoods.to(device), + tokenizer_teacher_total_vocab_size, + device + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful - format is fully compatible!") + + # pdb.set_trace() + \ No newline at end of file diff --git a/nemo_rl/utils/x_token/reapply_exact_map.py b/nemo_rl/utils/x_token/reapply_exact_map.py new file mode 100644 index 0000000000..b37f994466 --- /dev/null +++ b/nemo_rl/utils/x_token/reapply_exact_map.py @@ -0,0 +1,246 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + USE_CANONICALIZATION = args.use_canonicalization + + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # load intial projection map + initial_projection_path = args.initial_projection_path + if initial_projection_path is not None: + initial_projection_map = torch.load(initial_projection_path) + else: + initial_projection_map = None + + # go through token in projection map. For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token + non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) + all_student_token_ids = list(range(len(initial_projection_map["likelihoods"]))) + + show_remapping = 5 + if show_remapping > 0: + print(f"Showing remapping for the last {show_remapping} exact matches.") + else: + print(f"Not showing remapping.") + + for i, exact_token_student in enumerate(match_indices_student): + exact_token_teacher = match_indices_teacher[i] + + index_ = all_student_token_ids.index(exact_token_student) + likelihoods = initial_projection_map["likelihoods"][index_] + indices = initial_projection_map["indices"][index_] + + if len(match_indices_student) - i <= show_remapping: + print(f"prior to remapping: likelihoods {likelihoods} indices {indices}") + + topk = indices.shape[0] + + remapped_indices = torch.ones_like(indices) * -1 + remapped_likelihoods = torch.zeros_like(likelihoods) + + remapped_likelihoods[0] = 1.0 + remapped_indices[0] = exact_token_teacher + + # if exact_token_student == 5159: + # import pdb + # pdb.set_trace() + + initial_projection_map["likelihoods"][index_] = remapped_likelihoods + initial_projection_map["indices"][index_] = remapped_indices + + + if len(match_indices_student) - i <= show_remapping: + print(f'after remapping {tokens_student[exact_token_student]}:{exact_token_student} -> {tokens_teacher[exact_token_teacher]}:{exact_token_teacher}: likelihoods {initial_projection_map["likelihoods"][index_]} indices {initial_projection_map["indices"][index_]}') + non_exact_map_tokens.remove(index_) + + + # import pdb + # pdb.set_trace() + # print(f"non exact map tokens: {non_exact_map_tokens}") + # pdb.set_trace() + save_path = args.initial_projection_path.split(".")[0] + "_exact_map_remapped.pt" + torch.save(initial_projection_map, save_path) + print(f"Saved remapped projection map to: {save_path}") + print(f"remapped {len(match_indices_student)} tokens. Retained remaining {len(non_exact_map_tokens)} tokens as is.") \ No newline at end of file diff --git a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py new file mode 100644 index 0000000000..caba5b2738 --- /dev/null +++ b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py @@ -0,0 +1,451 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +import tqdm +from transformers import AutoTokenizer, AutoConfig +import pdb; + +def sinkhorn_one_dim(A, n_iters=1): + """Apply Sinkhorn normalization to make each row sum to 1.""" + for _ in range(n_iters): + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + return A + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + +def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): + """ + Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. + + Args: + input_path: Path to input projection matrix file + output_path: Path to save the new projection matrix + new_top_k: New top_k value for cutoff + preserve_last: If True, always preserve the last column as the final element + verbose: Whether to print progress information + """ + if verbose: + print(f"Loading projection matrix from: {input_path}") + + # Load the projection matrix + projection_data = torch.load(input_path, map_location='cpu', weights_only=False) + + if not isinstance(projection_data, dict) or 'indices' not in projection_data or 'likelihoods' not in projection_data: + raise ValueError("Input file must contain a dictionary with 'indices' and 'likelihoods' keys") + + original_indices = projection_data['indices'] # Shape: [vocab_size, original_top_k] + original_likelihoods = projection_data['likelihoods'] # Shape: [vocab_size, original_top_k] + + vocab_size, original_top_k = original_indices.shape + + if verbose: + print(f"Original matrix shape: {original_indices.shape}") + print(f"Original top_k: {original_top_k}") + print(f"New top_k: {new_top_k}") + print(f"Preserve last column: {preserve_last}") + # pdb.set_trace() + + if new_top_k > original_top_k: + print(f"Warning: New top_k ({new_top_k}) is larger than original top_k ({original_top_k})") + print(f"Will pad with -1 indices and 0.0 likelihoods") + effective_top_k = original_top_k + else: + effective_top_k = new_top_k + + # Initialize new tensors + new_indices = torch.full((vocab_size, new_top_k), -1, dtype=original_indices.dtype) + new_likelihoods = torch.zeros((vocab_size, new_top_k), dtype=original_likelihoods.dtype) + + # Statistics tracking + rows_with_order_change = 0 + significant_components_count = [0] * min(new_top_k, 10) # Track up to 10 components + threshold_for_significance = 0.2 # Threshold for considering a component "significant" + # Track position of maximum element in original ordering + max_element_positions = {} # position -> count + # Track preserve_last statistics + rows_with_preserved_last = 0 + # Track specifically when max element is in the last column + rows_with_max_in_last_column = 0 + # Track position of maximum element in final sorted and trimmed matrix + final_max_element_positions = {} # position -> count + + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + + if verbose: + print("Sorting and cutting each row...") + + # Process each row (each source token) + last_element_trick_count = 0 + for row_idx in tqdm.tqdm(range(vocab_size), desc="Processing rows", disable=not verbose): + row_indices = original_indices[row_idx] # [original_top_k] + row_likelihoods = original_likelihoods[row_idx] # [original_top_k] + + # Filter out invalid indices (-1) and zero likelihoods + valid_mask = (row_indices != -1) & (row_likelihoods > 0) + + if valid_mask.any(): + valid_indices = row_indices[valid_mask] + valid_likelihoods = row_likelihoods[valid_mask] + + # Track position of maximum element in original ordering + max_pos = torch.argmax(valid_likelihoods).item() + if max_pos not in max_element_positions: + max_element_positions[max_pos] = 0 + max_element_positions[max_pos] += 1 + + # Check if max element is specifically in the last column + # Find the actual maximum value in the original row (including invalid entries) + original_max_pos = torch.argmax(row_likelihoods).item() + if original_max_pos == original_top_k - 1: + # Only count if the last position actually has valid data + last_index = row_indices[original_top_k - 1] + last_likelihood = row_likelihoods[original_top_k - 1] + if last_index != -1 and last_likelihood > 0: + rows_with_max_in_last_column += 1 + + if preserve_last and new_top_k >= 1: + # Handle preserve_last case + last_index = original_indices[row_idx, original_top_k - 1] + last_likelihood = original_likelihoods[row_idx, original_top_k - 1] + + if new_top_k == 1: + # Special case: only keep the last element + if last_index != -1 and last_likelihood > 0: + new_indices[row_idx, 0] = last_index + new_likelihoods[row_idx, 0] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant components + if last_likelihood >= threshold_for_significance: + significant_components_count[0] += 1 + else: + # General case: sort first (original_top_k-1) elements, then add last element + elements_to_sort = min(len(valid_likelihoods), original_top_k - 1) + if elements_to_sort > 0: + # Get elements excluding the last position in original matrix + sort_mask = torch.arange(len(valid_likelihoods)) < elements_to_sort + if sort_mask.any(): + sortable_indices = valid_indices[sort_mask] + sortable_likelihoods = valid_likelihoods[sort_mask] + + # Sort the non-last elements + sorted_likelihoods, sort_order = torch.sort(sortable_likelihoods, descending=True) + sorted_indices = sortable_indices[sort_order] + + # Check if order changed in the sortable portion + original_order = torch.arange(len(sortable_likelihoods)) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + # Take top (new_top_k - 1) elements from sorted portion + num_from_sorted = min(len(sorted_indices), new_top_k - 1) + + new_indices[row_idx, :num_from_sorted] = sorted_indices[:num_from_sorted] + new_likelihoods[row_idx, :num_from_sorted] = sorted_likelihoods[:num_from_sorted] + + # Count significant components from sorted portion + for comp_idx in range(min(num_from_sorted, len(significant_components_count) - 1)): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + + # Always put the last element at the end (if valid) + + if last_index != -1 and last_likelihood > 0: + last_element_trick_count += 1 + new_indices[row_idx, new_top_k - 1] = last_index + new_likelihoods[row_idx, new_top_k - 1] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant component for the preserved last element + if new_top_k - 1 < len(significant_components_count): + if last_likelihood >= threshold_for_significance: + significant_components_count[new_top_k - 1] += 1 + + else: + # Original logic: sort all elements normally + # Check if order changed by comparing original vs sorted order + original_order = torch.arange(len(valid_likelihoods)) + sorted_likelihoods, sort_order = torch.sort(valid_likelihoods, descending=True) + + # Check if the order changed (not just sorted, but actually different) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + sorted_indices = valid_indices[sort_order] + + # Take top effective_top_k elements + num_to_take = min(len(sorted_indices), effective_top_k) + + new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take] + new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take] + # pdb.set_trace() + # Count significant components (components above threshold) + for comp_idx in range(min(num_to_take, len(significant_components_count))): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + # if significant_components_count[1] > 0.0: + # pdb.set_trace() + + # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 + + # Apply Sinkhorn normalization to the final matrix + print(f"last element trick count: {last_element_trick_count}") + if verbose: + print("Applying Sinkhorn normalization...") + + # Apply normalization only to non-zero values to preserve sparsity structure + normalized_likelihoods = sinkhorn_one_dim(new_likelihoods.clone(), n_iters=1) + + # Calculate final maximum element position statistics after sorting and normalization + if verbose: + print("Calculating final maximum element position statistics...") + + for row_idx in range(vocab_size): + row_likelihoods = normalized_likelihoods[row_idx] + # Filter out zero likelihoods + valid_mask = row_likelihoods > 0 + if valid_mask.any(): + valid_likelihoods = row_likelihoods[valid_mask] + # Find position of maximum element in the final matrix + max_pos_in_valid = torch.argmax(valid_likelihoods).item() + # Convert back to original position in the row + valid_positions = torch.nonzero(valid_mask).squeeze(-1) + actual_max_pos = valid_positions[max_pos_in_valid].item() + + if actual_max_pos not in final_max_element_positions: + final_max_element_positions[actual_max_pos] = 0 + final_max_element_positions[actual_max_pos] += 1 + + # Create output dictionary with same format as input + output_data = { + 'indices': new_indices, + 'likelihoods': normalized_likelihoods, + } + + # Copy over any additional metadata + for key in projection_data: + if key not in ['indices', 'likelihoods']: + output_data[key] = projection_data[key] + + # Save the new projection matrix + torch.save(output_data, output_path) + + if verbose: + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() + for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + +def main(): + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument("input_path", help="Path to input projection matrix file") + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") + parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + # python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last + #s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices + args = parser.parse_args() + + # Auto-generate output path if not specified + if args.output_path is None: + input_dir = os.path.dirname(args.input_path) + input_filename = os.path.basename(args.input_path) + + # Extract base name and extension + base_name, ext = os.path.splitext(input_filename) + + # Remove old top_k info if present + import re + base_name = re.sub(r'_top_\d+', '', base_name) + + # Add new top_k info and preserve_last flag + suffix = "_sorted" + if args.preserve_last: + suffix += "_preservelast" + output_filename = f"{base_name}_top_{args.top_k}{suffix}{ext}" + args.output_path = os.path.join(input_dir, output_filename) + + # Ensure output directory exists + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Process the matrix + sort_and_cut_projection_matrix( + args.input_path, + args.output_path, + args.top_k, + preserve_last=args.preserve_last, + verbose=not args.quiet + ) + +if __name__ == "__main__": + main() \ No newline at end of file From dcfda98be716a59a19bcbec57b628c772cb65eea Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Mon, 27 Apr 2026 19:05:57 -0700 Subject: [PATCH 2/2] refactor(x_token): drop align_fast and numba code paths from TokenAligner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tokenalign.py: 2194 → 1877 lines (-317). align_fast removal - Committing to use_align_fast=false everywhere; the fast-lookup path is no longer needed. Callers (CrossTokenizerCollator, off_policy_distillation setup) will be updated on the dependent branches in this stack to take the unconditional aligner.align(...) path. - TokenAligner: dropped align_fast(), precompute_canonical_maps(), and the _student_canon_map / _teacher_canon_map / _canon_id_to_str attrs that backed them. Behavior note: align_fast skipped _canonicalize_sequence's _merge_encoding_artifacts / _merge_consecutive_bytes by reusing per-token canonical maps. align() runs the full sequence-level canonicalization, so alignment may differ slightly for sequences that contain encoding-artifact or byte tokens. For pairs without those, the result is identical. The cost is a small per-batch CPU bump in the collator workers since we no longer cache canonical strings per id. numba removal - The training container ships without numba, so at runtime _NUMBA_AVAILABLE was always False, _dp_core_numba was never defined, and align_tokens_with_combinations_numpy_jit was a one-line forwarder to align_tokens_with_combinations_numpy on its first if-not-numba branch. - Dropped the top-level numba try/except + _NUMBA_AVAILABLE flag and the @njit-decorated _dp_core_numba kernel. - Deleted align_tokens_with_combinations_numpy_jit and retargeted callers at align_tokens_with_combinations_numpy directly: - _perform_dp_alignment (chunk_size==0 branch) - align_tokens_combinations_chunked (small-sequence base case + divide-and-conquer fallback) - Behavior: zero runtime change for this container; the numpy DP kernel is what was always executing. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/algorithms/x_token/tokenalign.py | 323 +---------------------- 1 file changed, 3 insertions(+), 320 deletions(-) diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py index 1d55f858f7..0fdbdb4a79 100644 --- a/nemo_rl/algorithms/x_token/tokenalign.py +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -20,85 +20,9 @@ import torch.nn as nn from transformers import AutoConfig, AutoTokenizer -try: - from numba import njit - _NUMBA_AVAILABLE = True -except ImportError: - _NUMBA_AVAILABLE = False - os.environ["TOKENIZERS_PARALLELISM"] = "false" -if _NUMBA_AVAILABLE: - @njit(cache=True) - def _dp_core_numba(ids1, ids2, joined1, joined2, n1, n2, - exact_match_score, gap_penalty, comb_mul, max_comb_len): - """Numba-accelerated DP core for token alignment. - - Uses the same algorithm as align_tokens_with_combinations_numpy but - with integer ID comparisons instead of Python string operations. - - Trace codes: 0=start, 1=diag, 2=up, 3=left, - 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k - """ - INVALID = np.int64(-1) - dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) - trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) - - for i in range(1, n1 + 1): - dp[i, 0] = dp[i - 1, 0] + gap_penalty - trace[i, 0] = 2 - for j in range(1, n2 + 1): - dp[0, j] = dp[0, j - 1] + gap_penalty - trace[0, j] = 3 - - for i in range(1, n1 + 1): - id_i = ids1[i - 1] - for j in range(1, n2 + 1): - id_j = ids2[j - 1] - - if id_i == id_j: - best = dp[i - 1, j - 1] + exact_match_score - else: - best = dp[i - 1, j - 1] - exact_match_score - best_m = np.int32(1) - - s = dp[i - 1, j] + gap_penalty - if s > best: - best = s - best_m = np.int32(2) - - s = dp[i, j - 1] + gap_penalty - if s > best: - best = s - best_m = np.int32(3) - - k_max_s2 = min(j, max_comb_len) - for k in range(2, k_max_s2 + 1): - jid = joined2[j, k] - if jid != INVALID and id_i == jid: - s = dp[i - 1, j - k] + comb_mul * np.float32(k) - if s > best: - best = s - best_m = np.int32(10 + k) - - k_max_s1 = min(i, max_comb_len) - for k in range(2, k_max_s1 + 1): - jid = joined1[i, k] - if jid != INVALID and id_j == jid: - s = dp[i - k, j - 1] + comb_mul * np.float32(k) - if s > best: - best = s - best_m = np.int32(20 + k) - - dp[i, j] = best - trace[i, j] = best_m - - return dp, trace -else: - _dp_core_numba = None - - @dataclass(frozen=True) class VocabPartition: """Projection-matrix-derived vocab partition for the gold-loss path. @@ -139,11 +63,6 @@ def __init__(self, max_comb_len=4, teacher_tokenizer_name=None, student_tokenize self._dense_proj_csr = None self._dense_proj_csr_device = None - # Precomputed canonical ID maps (built by precompute_canonical_maps) - self._student_canon_map = None - self._teacher_canon_map = None - self._canon_id_to_str = None - # Cached gold-loss vocab partitions keyed by (xtoken_loss, teacher_vocab_size). # Populated lazily by build_vocab_partition(); bypassed when learnable=True. self._vocab_partition_cache: dict[tuple[bool, int], VocabPartition] = {} @@ -244,149 +163,6 @@ def build_vocab_partition( self._vocab_partition_cache[key] = partition return partition - def precompute_canonical_maps(self): - """Build token_id → canonical_string lookup tables for both tokenizers. - - Call once at startup. After this, align_fast() can skip - convert_ids_to_tokens and _canonicalize_sequence entirely. - """ - import time as _time - _t0 = _time.time() - - canon_str_to_id: dict[str, int] = {} - next_id = [0] - - def _get_canon_id(s: str) -> int: - cid = canon_str_to_id.get(s) - if cid is None: - cid = next_id[0] - canon_str_to_id[s] = cid - next_id[0] += 1 - return cid - - student_vocab_size = len(self.student_tokenizer) - teacher_vocab_size = len(self.teacher_tokenizer) - - student_map = np.zeros(student_vocab_size, dtype=np.int64) - for tid in range(student_vocab_size): - tok = self.student_tokenizer.convert_ids_to_tokens(tid) - canon = self._canonical_token(tok) - student_map[tid] = _get_canon_id(canon) - - teacher_map = np.zeros(teacher_vocab_size, dtype=np.int64) - for tid in range(teacher_vocab_size): - tok = self.teacher_tokenizer.convert_ids_to_tokens(tid) - canon = self._canonical_token(tok) - teacher_map[tid] = _get_canon_id(canon) - - self._student_canon_map = student_map - self._teacher_canon_map = teacher_map - self._canon_id_to_str = {v: k for k, v in canon_str_to_id.items()} - - _t1 = _time.time() - print(f" [TokenAligner] Precomputed canonical maps in {_t1-_t0:.2f}s " - f"(student_vocab={student_vocab_size}, teacher_vocab={teacher_vocab_size}, " - f"unique_canonical={len(canon_str_to_id)})", flush=True) - - def align_fast(self, student_ids, teacher_ids, - exact_match_score=3, - combination_score_multiplier=1.5, - gap_penalty=-1.5, - chunk_size=128, - post_process=True, - anchor_lengths=[3,], - ignore_leading_char_diff=False): - """Fast alignment using precomputed canonical ID maps. - - Skips convert_ids_to_tokens and _canonicalize_sequence by looking up - canonical strings directly from token IDs via precomputed numpy arrays. - Falls back to regular align() if precomputed maps are not available. - """ - if self._student_canon_map is None: - return self.align(student_ids, teacher_ids, - exact_match_score=exact_match_score, - combination_score_multiplier=combination_score_multiplier, - gap_penalty=gap_penalty, - chunk_size=chunk_size, - post_process=post_process, - anchor_lengths=anchor_lengths, - ignore_leading_char_diff=ignore_leading_char_diff) - - if isinstance(student_ids, torch.Tensor): - student_ids = student_ids.cpu().numpy() - if isinstance(teacher_ids, torch.Tensor): - teacher_ids = teacher_ids.cpu().numpy() - - if student_ids.ndim == 1: - student_ids = student_ids[np.newaxis, :] - teacher_ids = teacher_ids[np.newaxis, :] - - import time as _time - _t_lookup_total = 0.0 - _t_anchors_dp_total = 0.0 - _t_postprocess_total = 0.0 - _t_mask_total = 0.0 - - all_aligned_pairs = [] - for i in range(student_ids.shape[0]): - s_ids = student_ids[i] - t_ids = teacher_ids[i] - - _tl0 = _time.time() - s_canon_strs = [self._canon_id_to_str[self._student_canon_map[tid]] for tid in s_ids] - t_canon_strs = [self._canon_id_to_str[self._teacher_canon_map[tid]] for tid in t_ids] - _tl1 = _time.time() - _t_lookup_total += _tl1 - _tl0 - - align_kwargs = { - 'exact_match_score': exact_match_score, - 'combination_score_multiplier': combination_score_multiplier, - 'gap_penalty': gap_penalty, - 'max_combination_len': self.max_combination_len, - 'ignore_leading_char_diff': False, - 'chunk_size': chunk_size, - 'anchor_lengths': anchor_lengths, - } - - aligned_pairs, _ = self._align_with_anchors(s_canon_strs, t_canon_strs, **align_kwargs) - _tl2 = _time.time() - _t_anchors_dp_total += _tl2 - _tl1 - - if post_process: - aligned_pairs = self.post_process_alignment_optimized( - aligned_pairs, - ignore_leading_char_diff=ignore_leading_char_diff, - exact_match_score=exact_match_score, - combination_score_multiplier=combination_score_multiplier, - gap_penalty=gap_penalty, - max_combination_len=self.max_combination_len - ) - _tl3 = _time.time() - _t_postprocess_total += _tl3 - _tl2 - - mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, - ignore_leading_char_diff=ignore_leading_char_diff) - aligned_pairs = [ - (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) - for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value - in zip(aligned_pairs, mask) - ] - _tl4 = _time.time() - _t_mask_total += _tl4 - _tl3 - - all_aligned_pairs.append(aligned_pairs) - - n = student_ids.shape[0] - _t_total = _t_lookup_total + _t_anchors_dp_total + _t_postprocess_total + _t_mask_total - if _t_total > 0.5 or n > 1: - print(f" [align_fast timing] lookup={_t_lookup_total:.3f}s, " - f"anchors+DP={_t_anchors_dp_total:.3f}s, " - f"postprocess={_t_postprocess_total:.3f}s, " - f"mask={_t_mask_total:.3f}s, " - f"total={_t_total:.3f}s (n={n})", flush=True) - - return all_aligned_pairs - def _load_logits_projection_map( self, folder_location: str = "cross_tokenizer_data", @@ -1300,7 +1076,7 @@ def _perform_dp_alignment(self, seq1, seq2, **kwargs): if chunk_size > 0: return self.align_tokens_combinations_chunked(seq1, seq2, chunk_size=chunk_size, **kwargs) else: - return self.align_tokens_with_combinations_numpy_jit(seq1, seq2, **kwargs) + return self.align_tokens_with_combinations_numpy(seq1, seq2, **kwargs) @staticmethod def _canonical_token(token: str) -> str: @@ -1682,99 +1458,6 @@ def align_tokens_with_combinations_numpy(seq1, seq2, aligned.reverse() return aligned, dp[n1, n2] - @staticmethod - def align_tokens_with_combinations_numpy_jit( - seq1, seq2, - exact_match_score=3, - combination_score_multiplier=1.5, - gap_penalty=-1.5, - max_combination_len=4, - ignore_leading_char_diff=False, - ): - """Numba-accelerated version of align_tokens_with_combinations_numpy. - - Pre-converts string tokens to integer IDs, runs the DP in a Numba - @njit kernel, then backtracks using the original string tokens. - Falls back to the pure-Python original when Numba is unavailable or - when ignore_leading_char_diff is True (requires Python string logic). - """ - if not _NUMBA_AVAILABLE or ignore_leading_char_diff: - return TokenAligner.align_tokens_with_combinations_numpy( - seq1, seq2, exact_match_score, combination_score_multiplier, - gap_penalty, max_combination_len, ignore_leading_char_diff, - ) - - n1, n2 = len(seq1), len(seq2) - if n1 == 0 and n2 == 0: - return [], 0.0 - if n1 == 0: - return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty - if n2 == 0: - return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty - - token_to_id: dict[str, int] = {} - _next_id = [0] - - def _get_id(s: str) -> int: - tid = token_to_id.get(s) - if tid is None: - tid = _next_id[0] - token_to_id[s] = tid - _next_id[0] += 1 - return tid - - ids1 = np.array([_get_id(t) for t in seq1], dtype=np.int64) - ids2 = np.array([_get_id(t) for t in seq2], dtype=np.int64) - - INVALID = np.int64(-1) - joined1 = np.full((n1 + 1, max_combination_len + 1), INVALID, dtype=np.int64) - for i in range(n1 + 1): - for k in range(2, min(i, max_combination_len) + 1): - joined1[i, k] = _get_id(''.join(seq1[i - k:i])) - - joined2 = np.full((n2 + 1, max_combination_len + 1), INVALID, dtype=np.int64) - for j in range(n2 + 1): - for k in range(2, min(j, max_combination_len) + 1): - joined2[j, k] = _get_id(''.join(seq2[j - k:j])) - - dp, trace = _dp_core_numba( - ids1, ids2, joined1, joined2, n1, n2, - np.float32(exact_match_score), - np.float32(gap_penalty), - np.float32(combination_score_multiplier), - max_combination_len, - ) - - aligned = [] - i, j = n1, n2 - while i > 0 or j > 0: - m = trace[i, j] - if m == 1: - aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) - i -= 1 - j -= 1 - elif m == 2: - aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) - i -= 1 - elif m == 3: - aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) - j -= 1 - elif 10 <= m < 20: - k = m - 10 - aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) - i -= 1 - j -= k - elif 20 <= m < 30: - k = m - 20 - aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) - i -= k - j -= 1 - else: - break - - aligned.reverse() - return aligned, float(dp[n1, n2]) - @staticmethod def align_tokens_combinations_chunked( seq1: List[str], @@ -1793,7 +1476,7 @@ def align_tokens_combinations_chunked( # If sequences are small enough, use regular algorithm if n1 <= chunk_size and n2 <= chunk_size: - return TokenAligner.align_tokens_with_combinations_numpy_jit( + return TokenAligner.align_tokens_with_combinations_numpy( seq1, seq2, exact_match_score, combination_score_multiplier, gap_penalty, max_combination_len, ignore_leading_char_diff ) @@ -1832,7 +1515,7 @@ def align_tokens_combinations_chunked( return combined_aligned, combined_score # Fallback to regular algorithm - return TokenAligner.align_tokens_with_combinations_numpy_jit( + return TokenAligner.align_tokens_with_combinations_numpy( seq1, seq2, exact_match_score, combination_score_multiplier, gap_penalty, max_combination_len, ignore_leading_char_diff )