From c492ac13eb540ab139556eeacad7aaa7a91aa13c Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Sun, 26 Apr 2026 19:55:32 -0700 Subject: [PATCH 1/4] feat: add TokenAligner and cross-tokenizer projection utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundational library code for cross-tokenizer distillation. No algorithm or training-loop integration yet — those follow in subsequent PRs. - nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with Numba-accelerated DP alignment, projection-matrix loading (dense and sparse COO), and the project_token_likelihoods_instance forward path used by the cross-tokenizer loss. - nemo_rl/algorithms/x_token/__init__.py: package init. - nemo_rl/utils/x_token/{minimal_projection_generator, minimal_projection_via_multitoken,reapply_exact_map, sort_and_cut_projection_matrix}.py: standalone CLI scripts (argparse-driven, __main__ entrypoints) for one-time projection-matrix preparation. Not on the training import path. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/algorithms/x_token/__init__.py | 13 + nemo_rl/algorithms/x_token/tokenalign.py | 2194 +++++++++++++++++ nemo_rl/utils/x_token/__init__.py | 13 + .../x_token/minimal_projection_generator.py | 585 +++++ .../minimal_projection_via_multitoken.py | 942 +++++++ nemo_rl/utils/x_token/reapply_exact_map.py | 246 ++ .../x_token/sort_and_cut_projection_matrix.py | 451 ++++ 7 files changed, 4444 insertions(+) create mode 100644 nemo_rl/algorithms/x_token/__init__.py create mode 100644 nemo_rl/algorithms/x_token/tokenalign.py create mode 100644 nemo_rl/utils/x_token/__init__.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_generator.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_via_multitoken.py create mode 100644 nemo_rl/utils/x_token/reapply_exact_map.py create mode 100644 nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py diff --git a/nemo_rl/algorithms/x_token/__init__.py b/nemo_rl/algorithms/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/algorithms/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py new file mode 100644 index 0000000000..1d55f858f7 --- /dev/null +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -0,0 +1,2194 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoTokenizer + +try: + from numba import njit + _NUMBA_AVAILABLE = True +except ImportError: + _NUMBA_AVAILABLE = False + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +if _NUMBA_AVAILABLE: + @njit(cache=True) + def _dp_core_numba(ids1, ids2, joined1, joined2, n1, n2, + exact_match_score, gap_penalty, comb_mul, max_comb_len): + """Numba-accelerated DP core for token alignment. + + Uses the same algorithm as align_tokens_with_combinations_numpy but + with integer ID comparisons instead of Python string operations. + + Trace codes: 0=start, 1=diag, 2=up, 3=left, + 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k + """ + INVALID = np.int64(-1) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) + + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 2 + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 3 + + for i in range(1, n1 + 1): + id_i = ids1[i - 1] + for j in range(1, n2 + 1): + id_j = ids2[j - 1] + + if id_i == id_j: + best = dp[i - 1, j - 1] + exact_match_score + else: + best = dp[i - 1, j - 1] - exact_match_score + best_m = np.int32(1) + + s = dp[i - 1, j] + gap_penalty + if s > best: + best = s + best_m = np.int32(2) + + s = dp[i, j - 1] + gap_penalty + if s > best: + best = s + best_m = np.int32(3) + + k_max_s2 = min(j, max_comb_len) + for k in range(2, k_max_s2 + 1): + jid = joined2[j, k] + if jid != INVALID and id_i == jid: + s = dp[i - 1, j - k] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(10 + k) + + k_max_s1 = min(i, max_comb_len) + for k in range(2, k_max_s1 + 1): + jid = joined1[i, k] + if jid != INVALID and id_j == jid: + s = dp[i - k, j - 1] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(20 + k) + + dp[i, j] = best + trace[i, j] = best_m + + return dp, trace +else: + _dp_core_numba = None + + +@dataclass(frozen=True) +class VocabPartition: + """Projection-matrix-derived vocab partition for the gold-loss path. + + Built once per (xtoken_loss, teacher_vocab_size) by + `TokenAligner.build_vocab_partition`. All tensors are long, 1-D, and + live on the aligner's projection-matrix device. + """ + + common_student_indices: torch.Tensor + common_teacher_indices: torch.Tensor + uncommon_student_indices: torch.Tensor + uncommon_teacher_indices: torch.Tensor + + +class TokenAligner(nn.Module): + def __init__(self, max_comb_len=4, teacher_tokenizer_name=None, student_tokenizer_name=None, init_hf_tokenizers=True, projection_matrix_multiplier=1.0, enable_scale_trick=None): + super().__init__() + self.teacher_tokenizer_name = teacher_tokenizer_name + self.student_tokenizer_name = student_tokenizer_name + self.projection_matrix_multiplier = projection_matrix_multiplier + self.enable_scale_trick = enable_scale_trick + + if init_hf_tokenizers: + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_tokenizer_name) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_tokenizer_name) + if self.teacher_tokenizer.pad_token is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + if self.student_tokenizer.pad_token is None: + self.student_tokenizer.pad_token = self.student_tokenizer.eos_token + else: + self.teacher_tokenizer = None + self.student_tokenizer = None + + self.max_combination_len = max_comb_len + self.sparse_transformation_matrix = None + # Cached CSR for dense top-k projection (built from indices/values) to avoid scatter path + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + # Precomputed canonical ID maps (built by precompute_canonical_maps) + self._student_canon_map = None + self._teacher_canon_map = None + self._canon_id_to_str = None + + # Cached gold-loss vocab partitions keyed by (xtoken_loss, teacher_vocab_size). + # Populated lazily by build_vocab_partition(); bypassed when learnable=True. + self._vocab_partition_cache: dict[tuple[bool, int], VocabPartition] = {} + + @torch.no_grad() + def build_vocab_partition( + self, xtoken_loss: bool, teacher_vocab_size: int + ) -> VocabPartition: + """Derive (and cache) the gold-loss vocab partition. + + Computes exactly the state that CrossTokenizerDistillationLossFn._compute_gold_loss + previously rebuilt every microbatch: the set of student/teacher tokens that have an + exact 1:1 projection (common) and the complement (uncommon). Cached by + (xtoken_loss, teacher_vocab_size) so repeat calls cost nothing. Caching is skipped + when the projection matrix is learnable, since the partition then depends on + gradient-updated values. + """ + if ( + not hasattr(self, "likelihood_projection_indices") + or self.likelihood_projection_indices is None + ): + raise ValueError( + "build_vocab_partition requires likelihood_projection_indices to be loaded" + ) + + learnable = bool(getattr(self, "learnable", False)) + key = (bool(xtoken_loss), int(teacher_vocab_size)) + if not learnable: + cached = self._vocab_partition_cache.get(key) + if cached is not None: + return cached + + projection_indices = self.likelihood_projection_indices + projection_matrix = ( + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) + if learnable + else self.likelihood_projection_matrix + ) + device = projection_matrix.device + student_vocab_size = int(projection_matrix.shape[0]) + + sorted_values, sorted_indices_in_topk = torch.sort( + projection_matrix, dim=-1, descending=True + ) + + if xtoken_loss: + has_exact_map = sorted_values[:, 0] >= 0.6 + else: + has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1) + + student_indices_with_exact_map = torch.where(has_exact_map)[0] + teacher_indices_for_exact_map = projection_indices[ + student_indices_with_exact_map, + sorted_indices_in_topk[student_indices_with_exact_map, 0], + ] + + student_to_teacher_exact_map: dict[int, int] = {} + teacher_to_student_exact_map: dict[int, int] = {} + for s_idx, t_idx in zip( + student_indices_with_exact_map.tolist(), + teacher_indices_for_exact_map.tolist(), + ): + if 0 <= t_idx < teacher_vocab_size: + if t_idx not in teacher_to_student_exact_map or xtoken_loss: + if t_idx in teacher_to_student_exact_map: + prev_student_token = teacher_to_student_exact_map[t_idx] + if sorted_values[prev_student_token, 0] >= sorted_values[s_idx, 0]: + continue + del student_to_teacher_exact_map[prev_student_token] + student_to_teacher_exact_map[s_idx] = t_idx + teacher_to_student_exact_map[t_idx] = s_idx + + common_student = sorted(student_to_teacher_exact_map.keys()) + common_teacher = [student_to_teacher_exact_map[s] for s in common_student] + uncommon_student = sorted( + set(range(student_vocab_size)) - set(common_student) + ) + uncommon_teacher = sorted( + set(range(teacher_vocab_size)) - set(common_teacher) + ) + + partition = VocabPartition( + common_student_indices=torch.tensor( + common_student, dtype=torch.long, device=device + ), + common_teacher_indices=torch.tensor( + common_teacher, dtype=torch.long, device=device + ), + uncommon_student_indices=torch.tensor( + uncommon_student, dtype=torch.long, device=device + ), + uncommon_teacher_indices=torch.tensor( + uncommon_teacher, dtype=torch.long, device=device + ), + ) + + if not learnable: + self._vocab_partition_cache[key] = partition + return partition + + def precompute_canonical_maps(self): + """Build token_id → canonical_string lookup tables for both tokenizers. + + Call once at startup. After this, align_fast() can skip + convert_ids_to_tokens and _canonicalize_sequence entirely. + """ + import time as _time + _t0 = _time.time() + + canon_str_to_id: dict[str, int] = {} + next_id = [0] + + def _get_canon_id(s: str) -> int: + cid = canon_str_to_id.get(s) + if cid is None: + cid = next_id[0] + canon_str_to_id[s] = cid + next_id[0] += 1 + return cid + + student_vocab_size = len(self.student_tokenizer) + teacher_vocab_size = len(self.teacher_tokenizer) + + student_map = np.zeros(student_vocab_size, dtype=np.int64) + for tid in range(student_vocab_size): + tok = self.student_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + student_map[tid] = _get_canon_id(canon) + + teacher_map = np.zeros(teacher_vocab_size, dtype=np.int64) + for tid in range(teacher_vocab_size): + tok = self.teacher_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + teacher_map[tid] = _get_canon_id(canon) + + self._student_canon_map = student_map + self._teacher_canon_map = teacher_map + self._canon_id_to_str = {v: k for k, v in canon_str_to_id.items()} + + _t1 = _time.time() + print(f" [TokenAligner] Precomputed canonical maps in {_t1-_t0:.2f}s " + f"(student_vocab={student_vocab_size}, teacher_vocab={teacher_vocab_size}, " + f"unique_canonical={len(canon_str_to_id)})", flush=True) + + def align_fast(self, student_ids, teacher_ids, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + chunk_size=128, + post_process=True, + anchor_lengths=[3,], + ignore_leading_char_diff=False): + """Fast alignment using precomputed canonical ID maps. + + Skips convert_ids_to_tokens and _canonicalize_sequence by looking up + canonical strings directly from token IDs via precomputed numpy arrays. + Falls back to regular align() if precomputed maps are not available. + """ + if self._student_canon_map is None: + return self.align(student_ids, teacher_ids, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + chunk_size=chunk_size, + post_process=post_process, + anchor_lengths=anchor_lengths, + ignore_leading_char_diff=ignore_leading_char_diff) + + if isinstance(student_ids, torch.Tensor): + student_ids = student_ids.cpu().numpy() + if isinstance(teacher_ids, torch.Tensor): + teacher_ids = teacher_ids.cpu().numpy() + + if student_ids.ndim == 1: + student_ids = student_ids[np.newaxis, :] + teacher_ids = teacher_ids[np.newaxis, :] + + import time as _time + _t_lookup_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + all_aligned_pairs = [] + for i in range(student_ids.shape[0]): + s_ids = student_ids[i] + t_ids = teacher_ids[i] + + _tl0 = _time.time() + s_canon_strs = [self._canon_id_to_str[self._student_canon_map[tid]] for tid in s_ids] + t_canon_strs = [self._canon_id_to_str[self._teacher_canon_map[tid]] for tid in t_ids] + _tl1 = _time.time() + _t_lookup_total += _tl1 - _tl0 + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(s_canon_strs, t_canon_strs, **align_kwargs) + _tl2 = _time.time() + _t_anchors_dp_total += _tl2 - _tl1 + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tl3 = _time.time() + _t_postprocess_total += _tl3 - _tl2 + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, + ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value + in zip(aligned_pairs, mask) + ] + _tl4 = _time.time() + _t_mask_total += _tl4 - _tl3 + + all_aligned_pairs.append(aligned_pairs) + + n = student_ids.shape[0] + _t_total = _t_lookup_total + _t_anchors_dp_total + _t_postprocess_total + _t_mask_total + if _t_total > 0.5 or n > 1: + print(f" [align_fast timing] lookup={_t_lookup_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s, " + f"total={_t_total:.3f}s (n={n})", flush=True) + + return all_aligned_pairs + + def _load_logits_projection_map( + self, + folder_location: str = "cross_tokenizer_data", + file_path: str = None, + top_k: int = 100, + device: str = "cuda", + use_sparse_format: bool = False, + learnable: bool = False, + ): + """ + Load projection map for cross-tokenizer likelihood projection. + Always creates student→teacher mapping. + + Args: + folder_location: Directory containing the projection files + file_path: Specific file path (overrides folder_location) + top_k: Number of top entries per row (only used for old format) + device: Device to load tensors on + use_sparse_format: If True, load sparse transformation matrix format (from multi-token mapping) + If False, load old dense indices/values format + learnable: If True, make the transformation matrix learnable + """ + self.learnable = learnable + if use_sparse_format: + # Load sparse transformation matrix format + if file_path is None: + file_path = f"{folder_location}/transformation_counts_via_multitoken.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Sparse transformation matrix file not found: {file_path}. Please generate it first.") + + # Load transformation counts dictionary + transformation_counts = torch.load(file_path, map_location='cpu', weights_only=False) + + # Get tokenizer vocab sizes + teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback + student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback + if 1: + # get vocab sizes from autoconfig + if "gemma" not in self.teacher_tokenizer_name.lower() and "qwen3.5" not in self.teacher_tokenizer_name.lower(): + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + else: + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).text_config.vocab_size + if "gemma" not in self.student_tokenizer_name.lower() and "qwen3.5" not in self.student_tokenizer_name.lower(): + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + else: + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).text_config.vocab_size + # teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + # student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + + + # Debug vocab sizes + print(f"Teacher vocab size: {teacher_vocab_size}, Student vocab size: {student_vocab_size}") + + # Convert dictionary to sparse tensor + if transformation_counts: + + + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + student_indices = [idx[0] for idx in indices] + teacher_indices = [idx[1] for idx in indices] + + # Always create student→teacher mapping: rows = student vocab, cols = teacher vocab + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values)/self.projection_matrix_multiplier + matrix_shape = (student_vocab_size, teacher_vocab_size) + + print(f"Creating sparse matrix: student→teacher ({student_vocab_size} x {teacher_vocab_size})") + + sparse_transformation_matrix = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (student_vocab_size, teacher_vocab_size), # student_vocab × teacher_vocab + device=device, + dtype=torch.float32 + ) + + # Optionally make the sparse matrix learnable (values only) + if learnable: + self.sparse_transformation_matrix = nn.Parameter( + sparse_transformation_matrix.coalesce(), requires_grad=True + ) + else: + # Register as buffer for non-learnable parameters (ensures proper device handling) + self.register_buffer('sparse_transformation_matrix', + sparse_transformation_matrix.coalesce(), + persistent=True) + + # Store a flag for downstream code + self.is_sparse_learnable = learnable + print(f"Loaded sparse transformation matrix with {len(transformation_counts)} entries") + else: + # Empty transformation matrix (student→teacher) + matrix_shape = (student_vocab_size, teacher_vocab_size) + + empty_sparse = torch.sparse_coo_tensor( + torch.zeros(2, 0, dtype=torch.long), + torch.zeros(0, dtype=torch.float32), + matrix_shape, + device=device, + ) + + if learnable: + self.sparse_transformation_matrix = nn.Parameter(empty_sparse, requires_grad=True) + else: + # Register as buffer for non-learnable parameters + self.register_buffer('sparse_transformation_matrix', empty_sparse, persistent=True) + + self.is_sparse_learnable = learnable + print("Warning: Empty transformation matrix loaded") + else: + # Load old dense indices/values format + if file_path is None: + file_path = f"{folder_location}/projection_map_Llama-3.1_to_Qwen3_bidirectional_top_10.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Projection map file not found: {file_path}. Please generate it first.") + + projection_data = torch.load(file_path, map_location='cpu', weights_only=False) + # Always use B_to_A direction for student->teacher projection + # projection_data = projection_data["B_to_A"] + # projection_data = projection_data["A_to_B"] + + indices = projection_data["indices"] + likelihoods = projection_data["likelihoods"]/self.projection_matrix_multiplier + + # Register indices as buffer (always non-learnable) + self.register_buffer('likelihood_projection_indices', indices.to(device), persistent=True) + if learnable: + if 1: + likelihoods = (likelihoods+1e-10).log() + + # Use instance variable if set, otherwise use default (False) + # scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + + # if scale_trick_enabled: + # #trick with last column being multiplier - set to -4.0 + # likelihoods[:,-1] = likelihoods[:,-1]*0.0 - 4.0 + #lets introduce some noise to encourage training. will remove later. + if 0: + likelihoods = likelihoods + torch.randn_like(likelihoods) * 1e-1 + likelihoods = likelihoods/2.0 + + self.likelihood_projection_matrix = nn.Parameter(likelihoods.to(device), requires_grad=True) + # print(self.likelihood_projection_matrix[0]) + # print(self.likelihood_projection_matrix[:,-1]) + # exit() + #add small gaussian noise to the projection matrix + #use log form + else: + # Register as buffer for non-learnable parameters + self.register_buffer('likelihood_projection_matrix', likelihoods.to(device), persistent=True) + + + print(f"Loaded dense projection map with shape {indices.shape}") + # Invalidate cached CSR; will rebuild on first use + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + def create_reverse_projection_matrix(self, device="cuda"): + """ + Create a reverse (transposed) projection matrix for teacher→student projection. + + For sparse format: Transposes the sparse_transformation_matrix from [student_vocab, teacher_vocab] + to [teacher_vocab, student_vocab] + For dense format: Builds a reverse index mapping from teacher tokens to student tokens + + This enables projecting teacher logits into student vocabulary space. + """ + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Transpose sparse matrix + print("Creating reverse projection matrix (sparse format): teacher→student") + sparse_matrix = self.sparse_transformation_matrix.coalesce() + indices = sparse_matrix.indices() + values = sparse_matrix.values() + + # Swap student and teacher indices (transpose) + transposed_indices = torch.stack([indices[1], indices[0]], dim=0) # Swap rows: [teacher, student] + teacher_vocab_size, student_vocab_size = sparse_matrix.shape[1], sparse_matrix.shape[0] + + reverse_sparse = torch.sparse_coo_tensor( + transposed_indices, + values, + (teacher_vocab_size, student_vocab_size), + device=device, + dtype=torch.float32 + ).coalesce() + + # Store as buffer or parameter based on learnability + if self.is_sparse_learnable: + self.reverse_sparse_transformation_matrix = nn.Parameter(reverse_sparse, requires_grad=True) + else: + self.register_buffer('reverse_sparse_transformation_matrix', reverse_sparse, persistent=True) + + print(f"Created reverse sparse matrix: teacher→student ({teacher_vocab_size} x {student_vocab_size})") + print(f"Reverse matrix has {len(values)} non-zero entries") + + elif hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + # Build reverse index for dense format + print("Creating reverse projection matrix (dense format): teacher→student") + + # Current: likelihood_projection_indices is [student_vocab, topk] + # We need to build: [teacher_vocab, variable_k] where variable_k depends on how many students map to each teacher token + + student_vocab_size = self.likelihood_projection_indices.shape[0] + topk = self.likelihood_projection_indices.shape[1] + + # Infer teacher vocab size from the max index + teacher_vocab_size = self.likelihood_projection_indices.max().item() + 1 + + # Build reverse mapping: for each teacher token, collect all (student_token, value) pairs + from collections import defaultdict + teacher_to_students = defaultdict(list) + + for student_idx in range(student_vocab_size): + for k in range(topk): + teacher_idx = self.likelihood_projection_indices[student_idx, k].item() + if hasattr(self, 'likelihood_projection_matrix'): + value = self.likelihood_projection_matrix[student_idx, k].item() + else: + value = 1.0 # Default value if no matrix + + # Check for valid entries: teacher_idx must be valid, and value must be finite (not -inf) + # If matrix is in log-space, valid log-probs are finite negative values + # Threshold at -20 to filter out padding values like -22.3197 + if teacher_idx >= 0 and value > -20.0: # Skip invalid or padding entries + teacher_to_students[teacher_idx].append((student_idx, value)) + + # Find max number of students mapping to any teacher token + raw_max_students = max([len(v) for v in teacher_to_students.values()]) if teacher_to_students else 1 + print(f"Max students mapping to any teacher token (before filtering): {raw_max_students}") + + # Limit to top-K students per teacher token to avoid explosion + # Keep only the top-K highest probability mappings per teacher + max_students_per_teacher = min(topk, raw_max_students) # Use same topk as forward direction + print(f"Limiting to top-{max_students_per_teacher} students per teacher token") + + # Sort each teacher's student list by value (descending) and keep only top-K + for teacher_idx in teacher_to_students: + student_list = teacher_to_students[teacher_idx] + # Sort by value (descending - higher log-prob = less negative) + student_list_sorted = sorted(student_list, key=lambda x: x[1], reverse=True) + teacher_to_students[teacher_idx] = student_list_sorted[:max_students_per_teacher] + + # Create dense reverse index [teacher_vocab, max_students_per_teacher] + # Use 0 instead of -1 for padding (valid index), with very negative values to nullify contribution + reverse_indices = torch.zeros((teacher_vocab_size, max_students_per_teacher), + dtype=torch.long, device=device) + # Initialize with very negative values (padding sentinel, similar to forward direction) + reverse_values = torch.full((teacher_vocab_size, max_students_per_teacher), -22.3197, + dtype=torch.float32, device=device) + + for teacher_idx, student_list in teacher_to_students.items(): + for k, (student_idx, value) in enumerate(student_list): + reverse_indices[teacher_idx, k] = student_idx + reverse_values[teacher_idx, k] = value + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + + # Store as buffer or parameter + self.register_buffer('reverse_likelihood_projection_indices', reverse_indices, persistent=True) + if self.learnable: + self.reverse_likelihood_projection_matrix = nn.Parameter(reverse_values, requires_grad=True) + else: + self.register_buffer('reverse_likelihood_projection_matrix', reverse_values, persistent=True) + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + else: + raise ValueError("No projection matrix loaded. Cannot create reverse projection.") + + def project_token_likelihoods_instance(self, input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_sparse_format=False, sparse_matrix=None, use_vectorized=True, gpu_optimized_scatter=True, global_top_indices=None): + """ + Instance method wrapper for project_token_likelihoods that can access instance variables. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + """ + if use_sparse_format: + if sparse_matrix is None: + raise ValueError("sparse_matrix must be provided when use_sparse_format=True") + + if global_top_indices is not None: + # For sparse format with global_top_indices, project to full vocab then slice + full_projection = TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + return full_projection[:, :, global_top_indices] + else: + return TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + else: + # If projection map is learnable, fall back to dense scatter path to preserve gradients + if getattr(projection_map_values, "requires_grad", False): + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.project_token_likelihoods_dense( + input_likelihoods, + projection_map_indices, + projection_map_values * self.projection_matrix_multiplier, + target_vocab_size, + device, + use_vectorized=True, + gpu_optimized_scatter=gpu_optimized_scatter, + enable_scale_trick=scale_trick_enabled, + global_top_indices=global_top_indices, + ) + + # Otherwise, use stateless CSR matmul (no caching) for memory efficiency + vs = projection_map_indices.shape[0] + top_k = projection_map_indices.shape[1] + # Ensure device/dtype for indices/values + idx = projection_map_indices.to(device) + val = (projection_map_values * self.projection_matrix_multiplier).to(device) + if val.dtype != input_likelihoods.dtype: + val = val.to(input_likelihoods.dtype) + # Build CSR once per call outside autograd to keep checkpoint recomputation identical + with torch.no_grad(): + crow_indices = torch.arange(0, (vs + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = idx.reshape(-1) + values = val.reshape(-1) + # _exact_map_remapped matrices use -1 as a padding sentinel for missing + # entries. A -1 column index is illegal in a CSR tensor and causes a + # CUDA illegal-memory-access during the subsequent matmul. Clamp those + # entries to column 0 and zero their value so they contribute nothing. + pad_mask = col_indices < 0 + if pad_mask.any(): + col_indices = col_indices.clone() + col_indices[pad_mask] = 0 + values = values.clone() + values[pad_mask] = 0.0 + proj_csr = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(vs, target_vocab_size), device=device + ) + # Matmul: [B, S, Vs] -> [B*S, Vs] @ [Vs, Vt] -> [B*S, Vt] -> [B, S, Vt] + bsz, seqlen, vs_in = input_likelihoods.shape + if vs_in != vs: + # In case logits have extra vocab tail, slice to match + x = input_likelihoods[:, :, :vs] + else: + x = input_likelihoods + x2d = x.reshape(bsz * seqlen, vs) + out2d = torch.matmul(x2d.to(torch.float32), proj_csr.to(torch.float32)) + out = out2d.reshape(bsz, seqlen, target_vocab_size).to(input_likelihoods.dtype) + return out + + @staticmethod + def project_token_likelihoods_dense(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_vectorized=True, gpu_optimized_scatter=True, enable_scale_trick=None, global_top_indices=None): + """ + Projects token likelihoods from a source to a target vocabulary using dense indices/values format. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of target tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + MAJOR SPEEDUP: Reduces both memory and compute significantly. + """ + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if abs(source_vocab_size - projection_map_indices.shape[0]) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + if projection_map_indices.device != device: + projection_map_indices = projection_map_indices.to(device) + if projection_map_values.device != device: + projection_map_values = projection_map_values.to(device) + #do for dtype + if projection_map_values.dtype != input_likelihoods.dtype: + projection_map_values = projection_map_values.to(input_likelihoods.dtype) + + # else: + # projection_map_values = projection_map_values.to(device) + + if use_vectorized: + # Solution 1: Efficient dense implementation using vectorized operations for small top_k + source_vocab_size_fixed = projection_map_indices.shape[0] + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + + # OPTIMIZATION: Use reduced vocabulary if global_top_indices provided + if global_top_indices is not None: + k_indices = len(global_top_indices) + global_top_indices = global_top_indices.to(device) + + # Create mapping from full target indices to reduced indices [0, 1, 2, ..., k-1] + full_to_reduced_map = torch.full((target_vocab_size,), -1, device=device, dtype=torch.long) + full_to_reduced_map[global_top_indices] = torch.arange(k_indices, device=device) + + # Initialize smaller output tensor - MAJOR MEMORY SAVINGS + projected_likelihoods = torch.zeros(batch_size, seq_len, k_indices, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = k_indices + + # Filter projection matrices to only include mappings to global_top_indices + # This will be used in the scatter operations below + use_reduced_projection = True + else: + # Initialize full output tensor + projected_likelihoods = torch.zeros(batch_size, seq_len, target_vocab_size, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = target_vocab_size + use_reduced_projection = False + + # Optimized chunked processing with multiple speedup techniques + # Use larger chunks for better amortization of fixed costs + max_memory_mb = 200 # Increased for better performance + # max_memory_mb = 500 # Increased for better performance + elements_per_chunk = max_memory_mb * 1024 * 1024 // 4 # 4 bytes per float32 + chunk_size = max(512, min(source_vocab_size_fixed, elements_per_chunk // (batch_size * seq_len))) + + + use_masking = False + # Process vocabulary in optimized chunks + for chunk_start in range(0, source_vocab_size_fixed, chunk_size): + chunk_end = min(chunk_start + chunk_size, source_vocab_size_fixed) + chunk_len = chunk_end - chunk_start + + + input_chunk = input_likelihoods_fixed[:, :, chunk_start:chunk_end] # (B, S, chunk_len) + indices_chunk = projection_map_indices[chunk_start:chunk_end, :] # (chunk_len, top_k) + values_chunk = projection_map_values[chunk_start:chunk_end, :] # (chunk_len, top_k) + + # Extract input chunk once per chunk (not per k) - major speedup + # Determine effective top_k (exclude last column if scale trick is enabled) + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + effective_top_k = top_k - 1 if scale_trick_enabled else top_k + # effective_top_k = 1 + + if gpu_optimized_scatter: + if use_masking: + # Process one k at a time to reduce peak memory usage + for k in range(effective_top_k): + values_k = values_chunk[:, k] + valid_mask_k = values_k > 1e-4 + if not valid_mask_k.any(): + continue + + source_indices_k = torch.nonzero(valid_mask_k, as_tuple=True)[0] + + input_subset_k = input_chunk[:, :, source_indices_k] + values_subset_k = values_k[source_indices_k] + + indices_k = indices_chunk[:, k] + target_indices_subset_k = indices_k[source_indices_k] + + weighted_inputs_k = input_subset_k * values_subset_k.view(1, 1, -1) + expanded_target_indices_k = target_indices_subset_k.view(1, 1, -1).expand(batch_size, seq_len, -1) + + projected_likelihoods.scatter_add_(2, expanded_target_indices_k, weighted_inputs_k) + else: + # Compact, un-masked implementation + # Process only effective columns without creating intermediate tensors + input_expanded = input_chunk.unsqueeze(-1) # (B, S, chunk_len, 1) + + for k in range(effective_top_k): + values_k = values_chunk[:, k:k+1] # (chunk_len, 1) - view, no copy + indices_k = indices_chunk[:, k] # (chunk_len,) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + # Map full indices to reduced indices and filter out invalid ones + reduced_indices_k = full_to_reduced_map[indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 # Only keep indices in global_top_indices + + if not valid_mask.any(): + continue # Skip if no valid indices in this chunk + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + values_filtered = values_k.squeeze(-1)[valid_indices] # (valid_count,) + input_filtered = input_chunk[:, :, valid_indices] # (B, S, valid_count) + + weighted_k = input_filtered * values_filtered.unsqueeze(0).unsqueeze(0) + indices_expanded = reduced_indices_filtered.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Standard full projection + weighted_k = input_expanded * values_k.unsqueeze(0).unsqueeze(0) # (B, S, chunk_len, 1) + weighted_k = weighted_k.squeeze(-1) # (B, S, chunk_len) + + indices_expanded = indices_k.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Original implementation with a loop over top_k + if True: # For small top_k, process all k together + # Broadcast input: (B, S, chunk_len, 1) * (1, 1, chunk_len, top_k) -> (B, S, chunk_len, top_k) + weighted_inputs = input_chunk.unsqueeze(-1) * values_chunk.unsqueeze(0).unsqueeze(0) + + # Process all k simultaneously using advanced indexing + for k in range(effective_top_k): + target_indices_k = indices_chunk[:, k] # (chunk_len,) + weighted_k = weighted_inputs[:, :, :, k] # (B, S, chunk_len) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + reduced_indices_k = full_to_reduced_map[target_indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 + + if not valid_mask.any(): + continue # Skip if no valid indices + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + weighted_filtered = weighted_k[:, :, valid_indices] # (B, S, valid_count) + + target_expanded = reduced_indices_filtered.view(1, 1, -1).expand(batch_size, seq_len, len(valid_indices)) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_filtered) + else: + # Use optimized scatter with pre-expanded indices (avoid .expand() in loop) + target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_k) + + # else: # For larger top_k, use optimized sequential processing + # for k in range(top_k): + # target_indices_k = indices_chunk[:, k] # (chunk_len,) + # target_values_k = values_chunk[:, k] # (chunk_len,) + + # # Skip projections marked with -1 + # valid_mask = target_values_k > -0.00001 + # if not valid_mask.any(): + # continue + + # # Only process valid projections + # valid_target_indices = target_indices_k[valid_mask] + # valid_target_values = target_values_k[valid_mask] + # valid_input = input_chunk[valid_mask] + + # weighted_input = valid_input * valid_target_values.view(-1, 1, 1) + + # # Direct scatter (simpler and often faster than index caching) + # target_expanded = valid_target_indices.view(1, 1, -1).expand(batch_size, seq_len, valid_target_indices.size(0)) + # projected_likelihoods.scatter_add_(2, target_expanded, weighted_input) + + return projected_likelihoods + else: + # Solution 2: Sparse matrix approach (original implementation) + source_vocab_size_fixed = projection_map_indices.shape[0] + + # Create sparse CSR matrix + crow_indices = torch.arange(0, (source_vocab_size_fixed + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size_fixed, target_vocab_size), device=device + ) + + # Apply sparse matrix multiplication + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_projection_matrix.to(torch.float32)) + + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size).to(input_likelihoods.dtype) + + @staticmethod + def project_token_likelihoods_sparse(input_likelihoods, sparse_matrix, device): + """Projects token likelihoods using a sparse transformation matrix.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + + # Get dimensions from sparse matrix + matrix_input_size, matrix_output_size = sparse_matrix.shape + + if abs(source_vocab_size - matrix_input_size) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches sparse matrix input size ({matrix_input_size})") + + # Move to correct device and dtype + # input_likelihoods = input_likelihoods.to(device) + # sparse_matrix = sparse_matrix.to(device) + + # Adjust input size to match matrix dimensions + # next 2 lines required when we used vocab length from tokenizer, now we use the size of logits + # source_vocab_size_fixed = min(source_vocab_size, matrix_input_size) + # input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + input_likelihoods_fixed = input_likelihoods + + # Reshape for matrix multiplication + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + # Project using sparse matrix multiplication + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_matrix.to(torch.float32)) + + # Reshape back to original format + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, matrix_output_size).to(input_likelihoods.dtype) + + def align(self, student_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + teacher_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=False, + chunk_size=128, + post_process=True, + convert_ids_to_tokens=True, + anchor_lengths=[3,], + _debug_timing=False): + """Align two sequences of tokens (or batches of sequences).""" + import time as _time + + seq1 = student_seq + seq2 = teacher_seq + + _t_convert = 0.0 + if isinstance(seq1, torch.Tensor): + seq1 = seq1.cpu().tolist() + seq2 = seq2.cpu().tolist() + if convert_ids_to_tokens: + _tc0 = _time.time() + seq1 = [self.student_tokenizer.convert_ids_to_tokens(seq1_single) for seq1_single in seq1] + seq2 = [self.teacher_tokenizer.convert_ids_to_tokens(seq2_single) for seq2_single in seq2] + _t_convert = _time.time() - _tc0 + + is_batched = isinstance(seq1, list) and len(seq1) > 0 and isinstance(seq1[0], list) + + _t_canon_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + if is_batched: + if not (isinstance(seq2, list) and len(seq2) == len(seq1) and (len(seq2) == 0 or isinstance(seq2[0], list))): + raise ValueError("For batched input, seq1 and seq2 must be lists of lists with the same length.") + + all_aligned_pairs = [] + for s1, s2 in zip(seq1, seq2): + aligned_pairs, timings = self._align_single(s1, s2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs.append(aligned_pairs) + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + else: + aligned_pairs, timings = self._align_single(seq1, seq2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs = [aligned_pairs] + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + + if _debug_timing: + n = len(all_aligned_pairs) + print(f" [align timing] convert_ids={_t_convert:.3f}s, " + f"canonicalize={_t_canon_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s " + f"(n={n})", flush=True) + + return all_aligned_pairs + + def _align_single(self, seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=True, + chunk_size=0, + post_process=True, + anchor_lengths=None, + _return_timings=False): + """Align two sequences of tokens.""" + import time as _time + + _tc0 = _time.time() + seq1_canon = TokenAligner._canonicalize_sequence(seq1) + seq2_canon = TokenAligner._canonicalize_sequence(seq2) + _tc1 = _time.time() + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(seq1_canon, seq2_canon, **align_kwargs) + _tc2 = _time.time() + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tc3 = _time.time() + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value in zip(aligned_pairs, mask) + ] + _tc4 = _time.time() + + timings = { + "canon": _tc1 - _tc0, + "anchors_dp": _tc2 - _tc1, + "postprocess": _tc3 - _tc2, + "mask": _tc4 - _tc3, + } + + if _return_timings: + return aligned_pairs, timings + return aligned_pairs + + + def _align_with_anchors(self, seq1, seq2, anchor_lengths=[3,], **kwargs): + """ + Optimized alignment using unique 1-to-1 matches as anchors. + """ + # CRITICAL FIX: If anchor_lengths is empty, disable anchor optimization completely + if not anchor_lengths: + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + if anchor_lengths is None: + anchor_lengths = [3, 2] # Default: check 3-token, then 2-token sequences + + # Debug output + debug = kwargs.get('debug', False) + + # 1. Find high-confidence anchor points using unique token matches. + s1_counts = {} + for i, t in enumerate(seq1): + if t not in s1_counts: s1_counts[t] = [] + s1_counts[t].append(i) + + s2_counts = {} + for i, t in enumerate(seq2): + if t not in s2_counts: s2_counts[t] = [] + s2_counts[t].append(i) + + # Find potential anchors using consecutive token sequences + potential_anchors = [] + + # FIXED: Don't break early - collect anchors from all lengths and then choose the best + all_potential_anchors = [] + + # Check for anchors of different lengths + for anchor_len in anchor_lengths: + anchors_for_this_len = [] + + if anchor_len == 1: + # Handle single token anchors + common_tokens = s1_counts.keys() & s2_counts.keys() + for token in common_tokens: + if len(s1_counts[token]) == 1 and len(s2_counts[token]) == 1: + i = s1_counts[token][0] + j = s2_counts[token][0] + anchors_for_this_len.append((i, j, anchor_len)) + else: + # Handle multi-token anchors + s1_ngram_counts = {} + for i in range(len(seq1) - anchor_len + 1): + ngram = tuple(seq1[i:i + anchor_len]) + if ngram not in s1_ngram_counts: + s1_ngram_counts[ngram] = [] + s1_ngram_counts[ngram].append(i) + + s2_ngram_counts = {} + for i in range(len(seq2) - anchor_len + 1): + ngram = tuple(seq2[i:i + anchor_len]) + if ngram not in s2_ngram_counts: + s2_ngram_counts[ngram] = [] + s2_ngram_counts[ngram].append(i) + + # Find n-grams that appear exactly once in both sequences + common_ngrams = s1_ngram_counts.keys() & s2_ngram_counts.keys() + for ngram in common_ngrams: + if len(s1_ngram_counts[ngram]) == 1 and len(s2_ngram_counts[ngram]) == 1: + i = s1_ngram_counts[ngram][0] + j = s2_ngram_counts[ngram][0] + # ADDED: Verify the anchor is actually correct + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + anchors_for_this_len.append((i, j, anchor_len)) + + all_potential_anchors.extend(anchors_for_this_len) + + # IMPROVED: Choose the best set of anchors + # Prefer longer anchors, but if shorter anchors give better coverage, use them + + # Sort by position and filter for monotonic ordering + all_potential_anchors.sort() + + # IMPROVED: Better anchor selection - use greedy approach to maximize coverage + selected_anchors = [] + used_positions_seq1 = set() + used_positions_seq2 = set() + + # Sort by anchor length (descending) then by position + all_potential_anchors.sort(key=lambda x: (-x[2], x[0], x[1])) + + for i, j, anchor_len in all_potential_anchors: + # Check if this anchor conflicts with already selected ones + seq1_range = set(range(i, i + anchor_len)) + seq2_range = set(range(j, j + anchor_len)) + + if not (seq1_range & used_positions_seq1) and not (seq2_range & used_positions_seq2): + # This anchor doesn't conflict - we can use it + selected_anchors.append((i, j, anchor_len)) + used_positions_seq1.update(seq1_range) + used_positions_seq2.update(seq2_range) + + # Re-sort selected anchors by position for processing + selected_anchors.sort() + + # IMPROVED: Additional validation of selected anchors + validated_anchors = [] + last_j = -1 + for i, j, anchor_len in selected_anchors: + # Ensure monotonic ordering and no overlaps + if j > last_j: + # Double-check the anchor is valid + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + validated_anchors.append((i, j, anchor_len)) + last_j = j + anchor_len - 1 + + anchors = validated_anchors + + if not anchors: + # If no anchors are found, fall back to the standard alignment. + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + # 2. Align segments between anchors. + full_alignment = [] + last_i, last_j = 0, 0 + + for anchor_idx, (i, j, anchor_len) in enumerate(anchors): + + # Align segment before the current anchor. + seg1, seg2 = seq1[last_i:i], seq2[last_j:j] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + # Adjust indices to be relative to the full sequence and split exact matches. + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Add the anchor itself (consecutive tokens), also split if needed. + anchor_seq1 = seq1[i:i + anchor_len] + anchor_seq2 = seq2[j:j + anchor_len] + + # Split anchor into individual matches since they should be identical + for k in range(anchor_len): + full_alignment.append(( + [anchor_seq1[k]], [anchor_seq2[k]], + i + k, i + k + 1, + j + k, j + k + 1 + )) + + last_i, last_j = i + anchor_len, j + anchor_len + + # 3. Align the final segment after the last anchor. + seg1, seg2 = seq1[last_i:], seq2[last_j:] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + return full_alignment, 0 # Return 0 for score as it's not well-defined here + + def _perform_dp_alignment(self, seq1, seq2, **kwargs): + """ + Helper function to run the core DP-based alignment. + """ + chunk_size = kwargs.get('chunk_size', 0) + kwargs.pop('chunk_size', None) + kwargs.pop('anchor_lengths', None) + + if chunk_size > 0: + return self.align_tokens_combinations_chunked(seq1, seq2, chunk_size=chunk_size, **kwargs) + else: + return self.align_tokens_with_combinations_numpy_jit(seq1, seq2, **kwargs) + + @staticmethod + def _canonical_token(token: str) -> str: + """Return a canonical representation of a tokenizer token.""" + if not token: + return token + + # 1. Normalize space prefixes first + if token.startswith(' '): + token = 'Ġ' + token[1:] + elif token.startswith('_'): + token = 'Ġ' + token[1:] + elif token.startswith('▁'): # SentencePiece-style space prefix + token = 'Ġ' + token[1:] + + # 1.5. Normalize newline and whitespace representations + if token == 'Ċ': # GPT-style newline (used by Llama) + token = '\n' + elif token == '\\n': # Escaped newline representation + token = '\n' + elif token == 'ĉ': # Alternative newline representation + token = '\n' + elif token == 'Ġ\n': # Space + newline combination + token = '\n' + elif 'Ċ' in token: # Handle Ċ embedded in other tokens + token = token.replace('Ċ', '\n') + elif '\\n' in token: # Handle escaped newlines in compound tokens + token = token.replace('\\n', '\n') + + # 1.6. Handle space-separated punctuation normalization + if token == 'Ġ,': # Space + comma + token = ',' + elif token == 'Ġ.': # Space + period + token = '.' + elif token == 'Ġ;': # Space + semicolon + token = ';' + elif token == 'Ġ:': # Space + colon + token = ':' + + # 2. Handle SentencePiece byte fallback tokens like <0x20> + if token.startswith('<0x') and token.endswith('>') and len(token) == 6: + try: + byte_val = int(token[3:5], 16) + if 0 <= byte_val <= 255: + return chr(byte_val) + except ValueError: + pass + + # 3. Normalize common Unicode encoding issues + unicode_fixes = { + # Spanish + 'ñ': 'ñ', 'á': 'á', 'é': 'é', 'í': 'í', 'ó': 'ó', 'ú': 'ú', + 'Ã': 'À', 'â': 'â', 'ç': 'ç', + # French + 'ç': 'ç', 'è': 'è', 'é': 'é', 'ë': 'ë', 'î': 'î', 'ô': 'ô', + 'ù': 'ù', 'û': 'û', 'ÿ': 'ÿ', + # Chinese (common encoding artifacts) + 'ä¸Ń': '中', 'æĸĩ': '文', 'æĹ¥æľ¬': '日本', 'èªŀ': '語', + # Russian + 'ÐłÑĥÑģ': 'Рус', 'Ñģкий': 'ский', + # Arabic + 'اÙĦعربÙĬØ©': 'العربية', + # Hindi + 'ह': 'ह', 'िà¤Ĥ': 'हिं', 'दà¥Ģ': 'दी', + # Mathematical symbols (common artifacts) + 'âĪij': '∑', 'âĪı': '∏', 'âĪĤ': '∂', 'âĪĩ': '∇', + 'âĪŀ': '∞', 'âĪļ': '√', 'âĪ«': '∫', 'âīĪ': '≈', + 'âīł': '≠', 'âī¤': '≤', 'âī¥': '≥', + } + + # Apply Unicode fixes + for broken, fixed in unicode_fixes.items(): + if broken in token: + token = token.replace(broken, fixed) + + # 4. Normalize special tokens + special_token_map = { + '<|begin_of_text|>': '', # Llama-style BOS token + '': '', # Standard BOS token + '': '', # Padding tokens → empty (will be handled by alignment) + '': ' ', # End tokens + '': ' ', # End tokens + } + + if token in special_token_map: + return special_token_map[token] + + return token + + @staticmethod + def _canonicalize_sequence(seq: List[str]) -> List[str]: + """Canonicalize every token in a sequence (list of str).""" + # First, handle multi-token encoding artifacts (before individual canonicalization) + merged_artifacts = TokenAligner._merge_encoding_artifacts(seq) + + # Then, canonicalize individual tokens + canon_tokens = [TokenAligner._canonical_token(tok) for tok in merged_artifacts] + + # Finally, merge consecutive byte tokens into proper Unicode characters + return TokenAligner._merge_consecutive_bytes(canon_tokens) + + @staticmethod + def _merge_encoding_artifacts(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent multi-token encoding artifacts.""" + if not tokens: + return tokens + + # Common multi-token encoding artifacts that should be merged + multi_token_fixes = [ + # Mathematical symbols split across tokens + (['ĠâĪ', 'ij'], ['Ġ∑']), # Sum symbol + (['âĪ', 'ij'], ['∑']), # Sum symbol (no space) + (['ĠâĪ', 'ı'], ['Ġ∏']), # Product symbol + (['âĪ', 'ı'], ['∏']), # Product symbol (no space) + (['ĠâĪ', 'Ĥ'], ['Ġ∂']), # Partial derivative + (['âĪ', 'Ĥ'], ['∂']), # Partial derivative (no space) + (['ĠâĪ', 'ĩ'], ['Ġ∇']), # Nabla/gradient + (['âĪ', 'ĩ'], ['∇']), # Nabla/gradient (no space) + (['ĠâĪ', 'ŀ'], ['Ġ∞']), # Infinity + (['âĪ', 'ŀ'], ['∞']), # Infinity (no space) + (['ĠâĪ', 'ļ'], ['Ġ√']), # Square root + (['âĪ', 'ļ'], ['√']), # Square root (no space) + (['ĠâĪ', '«'], ['Ġ∫']), # Integral + (['âĪ', '«'], ['∫']), # Integral (no space) + (['Ġâī', 'ł'], ['Ġ≠']), # Not equal + (['âī', 'ł'], ['≠']), # Not equal (no space) + # Other common multi-token artifacts + (['Ġä¸', 'Ń'], ['Ġ中']), # Chinese character + (['ä¸', 'Ń'], ['中']), # Chinese character (no space) + (['æĸ', 'ĩ'], ['文']), # Chinese character + (['Ġæĸ', 'ĩ'], ['Ġ文']), # Chinese character (with space) + ] + + result = [] + i = 0 + + while i < len(tokens): + # Check if current position matches any multi-token pattern + matched = False + + for pattern, replacement in multi_token_fixes: + pattern_len = len(pattern) + if i + pattern_len <= len(tokens): + # Check if the tokens match the pattern + if tokens[i:i+pattern_len] == pattern: + # Replace with the fixed version + result.extend(replacement) + i += pattern_len + matched = True + break + + if not matched: + # No pattern matched, keep the original token + result.append(tokens[i]) + i += 1 + + return result + + @staticmethod + def _merge_consecutive_bytes(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent UTF-8 byte sequences.""" + if not tokens: + return tokens + + result = [] + byte_buffer = [] + + for token in tokens: + # Check if this token represents byte(s) + clean_token = token.lstrip('Ġ') + + # Check if all characters in the token are visual bytes + all_chars_are_bytes = True + if len(clean_token) == 0: + all_chars_are_bytes = False + else: + for char in clean_token: + if TokenAligner._get_byte_value(char) is None: + all_chars_are_bytes = False + break + + if all_chars_are_bytes: + byte_buffer.append(token) + else: + # Not a byte token, flush buffer first + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + byte_buffer = [] + result.append(token) + + # Flush any remaining bytes + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + + return result + + @staticmethod + def _try_merge_byte_buffer(byte_tokens: List[str]) -> List[str]: + """Try to merge a buffer of potential byte tokens into a Unicode character.""" + if not byte_tokens: + return [] + + # If only one token, just return it unless it's a multi-character byte token + if len(byte_tokens) == 1: + token = byte_tokens[0] + clean_token = token.lstrip('Ġ') + if len(clean_token) <= 1: + return byte_tokens + # Continue processing multi-character token + + # Extract space prefix from first token + first_token = byte_tokens[0] + space_prefix = 'Ġ' if first_token.startswith('Ġ') else '' + + # Extract raw bytes from all characters in all tokens + raw_bytes = [] + for token in byte_tokens: + clean_token = token.lstrip('Ġ') + for char in clean_token: + byte_value = TokenAligner._get_byte_value(char) + if byte_value is not None: + raw_bytes.append(byte_value) + else: + # If any character is not a byte, return original tokens + return byte_tokens + + # Only try to merge if we have 2-4 bytes (typical for emoji/multi-byte chars) + if len(raw_bytes) < 2 or len(raw_bytes) > 4: + return byte_tokens + + # Try to decode as UTF-8 + try: + decoded_text = bytes(raw_bytes).decode('utf-8') + # Only merge if the result is a single Unicode character (like an emoji) + if len(decoded_text) == 1 and ord(decoded_text) > 127: + return [space_prefix + decoded_text] + else: + # If it's not a single special character, keep original tokens + return byte_tokens + except UnicodeDecodeError: + # If decoding fails, return original tokens + return byte_tokens + + # Common visual byte representations used by some tokenizers (especially for emojis) + VISUAL_BYTE_MAP = { + # Common emoji byte range (240-255) + 'ð': 240, 'Ɩ': 241, 'Ɨ': 242, 'Ƙ': 243, 'ƙ': 244, 'ƚ': 245, 'ƛ': 246, 'Ɯ': 247, + 'Ɲ': 248, 'ƞ': 249, 'Ɵ': 250, 'Ơ': 251, 'ơ': 252, 'Ƣ': 253, 'ƣ': 254, 'Ƥ': 255, + # Other common byte representations (0-255 only) + 'Ł': 156, 'ł': 157, 'Ń': 158, 'ń': 159, 'ĺ': 149, 'Ļ': 150, 'ļ': 151, 'Ľ': 152, + 'ľ': 153, 'Ŀ': 154, 'ŀ': 155, 'Ĭ': 135, 'ĭ': 136, 'Į': 137, 'į': 138, 'İ': 139, + 'ı': 140, 'IJ': 141, 'ij': 142, 'Ĵ': 143, 'ĵ': 144, 'Ķ': 145, 'ķ': 146, 'ĸ': 147, + 'Ĺ': 148, 'ĥ': 128, 'Ħ': 129, 'ħ': 130, 'Ĩ': 131, 'ĩ': 132, 'Ī': 133, 'ī': 134, + 'Ģ': 162, 'ģ': 163, 'Ĝ': 28, 'ĝ': 29, 'Ğ': 30, 'ğ': 31, + } + + @staticmethod + def _get_byte_value(token_char: str) -> int: + """Get the byte value for a character, handling both direct bytes and visual representations.""" + if len(token_char) != 1: + return None + + char_ord = ord(token_char) + + # Direct byte (0-255) + if char_ord < 256: + return char_ord + + # Visual byte representation + if token_char in TokenAligner.VISUAL_BYTE_MAP: + return TokenAligner.VISUAL_BYTE_MAP[token_char] + + return None + + @staticmethod + def _strings_equal_flexible(s1, s2, ignore_leading_char_diff): + if not ignore_leading_char_diff: + return s1 == s2 + + # Use our comprehensive canonicalization for robust comparison + s1_canonical = TokenAligner._canonical_token(s1) + s2_canonical = TokenAligner._canonical_token(s2) + + return s1_canonical == s2_canonical + + def align_tokens_with_combinations_numpy(seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False): + n1, n2 = len(seq1), len(seq2) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.full((n1 + 1, n2 + 1), '', dtype=object) + + # Initialize DP edges with gap penalties + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 'up' + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 'left' + + # Precompute joined substrings for all valid k-length spans + joined_seq1 = {(i - k, i): ''.join(seq1[i - k:i]) + for i in range(n1 + 1) + for k in range(1, min(i, max_combination_len) + 1)} + joined_seq2 = {(j - k, j): ''.join(seq2[j - k:j]) + for j in range(n2 + 1) + for k in range(1, min(j, max_combination_len) + 1)} + + # Fill DP table + for i in range(1, n1 + 1): + for j in range(1, n2 + 1): + s1_val, s2_val = seq1[i - 1], seq2[j - 1] + match_score = exact_match_score if TokenAligner._strings_equal_flexible(s1_val, s2_val, ignore_leading_char_diff) else -exact_match_score + score_diag = dp[i - 1, j - 1] + match_score + score_up = dp[i - 1, j] + gap_penalty + score_left = dp[i, j - 1] + gap_penalty + + max_score = score_diag + best_move = 'diag' + if score_up > max_score: + max_score = score_up + best_move = 'up' + if score_left > max_score: + max_score = score_left + best_move = 'left' + + # Check for seq1[i-1] == join(seq2[j-k:j]) + for k in range(2, min(j + 1, max_combination_len + 1)): + if (j - k, j) in joined_seq2 and TokenAligner._strings_equal_flexible(s1_val, joined_seq2[(j - k, j)], ignore_leading_char_diff): + comb_score = dp[i - 1, j - k] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s1_over_s2_{k}' + + # Check for seq2[j-1] vs seq1[i-k:i] + for k in range(2, min(i + 1, max_combination_len + 1)): + if (i - k, i) in joined_seq1 and TokenAligner._strings_equal_flexible(s2_val, joined_seq1[(i - k, i)], ignore_leading_char_diff): + comb_score = dp[i - k, j - 1] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s2_over_s1_{k}' + + dp[i, j] = max_score + trace[i, j] = best_move + + # Backtrack to extract alignment + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + move = trace[i, j] + if move == 'diag': + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == 'up': + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == 'left': + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif move.startswith('comb_s1_over_s2_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif move.startswith('comb_s2_over_s1_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, dp[n1, n2] + + @staticmethod + def align_tokens_with_combinations_numpy_jit( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + ): + """Numba-accelerated version of align_tokens_with_combinations_numpy. + + Pre-converts string tokens to integer IDs, runs the DP in a Numba + @njit kernel, then backtracks using the original string tokens. + Falls back to the pure-Python original when Numba is unavailable or + when ignore_leading_char_diff is True (requires Python string logic). + """ + if not _NUMBA_AVAILABLE or ignore_leading_char_diff: + return TokenAligner.align_tokens_with_combinations_numpy( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + + n1, n2 = len(seq1), len(seq2) + if n1 == 0 and n2 == 0: + return [], 0.0 + if n1 == 0: + return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty + if n2 == 0: + return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty + + token_to_id: dict[str, int] = {} + _next_id = [0] + + def _get_id(s: str) -> int: + tid = token_to_id.get(s) + if tid is None: + tid = _next_id[0] + token_to_id[s] = tid + _next_id[0] += 1 + return tid + + ids1 = np.array([_get_id(t) for t in seq1], dtype=np.int64) + ids2 = np.array([_get_id(t) for t in seq2], dtype=np.int64) + + INVALID = np.int64(-1) + joined1 = np.full((n1 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for i in range(n1 + 1): + for k in range(2, min(i, max_combination_len) + 1): + joined1[i, k] = _get_id(''.join(seq1[i - k:i])) + + joined2 = np.full((n2 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for j in range(n2 + 1): + for k in range(2, min(j, max_combination_len) + 1): + joined2[j, k] = _get_id(''.join(seq2[j - k:j])) + + dp, trace = _dp_core_numba( + ids1, ids2, joined1, joined2, n1, n2, + np.float32(exact_match_score), + np.float32(gap_penalty), + np.float32(combination_score_multiplier), + max_combination_len, + ) + + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + m = trace[i, j] + if m == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif m == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif m == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= m < 20: + k = m - 10 + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif 20 <= m < 30: + k = m - 20 + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, float(dp[n1, n2]) + + @staticmethod + def align_tokens_combinations_chunked( + seq1: List[str], + seq2: List[str], + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + ignore_leading_char_diff: bool = False, + chunk_size: int = 256, + ): + """ + Chunked processing for very large sequences. + """ + n1, n2 = len(seq1), len(seq2) + + # If sequences are small enough, use regular algorithm + if n1 <= chunk_size and n2 <= chunk_size: + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + # For very large sequences, use divide-and-conquer approach + if n1 > chunk_size or n2 > chunk_size: + # Find approximate midpoint alignment using simplified algorithm + mid1, mid2 = n1 // 2, n2 // 2 + + # Recursively align left and right parts + left_aligned, left_score = TokenAligner.align_tokens_combinations_chunked( + seq1[:mid1], seq2[:mid2], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + right_aligned, right_score = TokenAligner.align_tokens_combinations_chunked( + seq1[mid1:], seq2[mid2:], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + # Adjust indices for right part + adjusted_right = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned: + new_s1_start = s1_start + mid1 if s1_start >= 0 else -1 + new_s1_end = s1_end + mid1 if s1_end >= 0 else -1 + new_s2_start = s2_start + mid2 if s2_start >= 0 else -1 + new_s2_end = s2_end + mid2 if s2_end >= 0 else -1 + adjusted_right.append((s1_tokens, s2_tokens, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Combine results + combined_aligned = left_aligned + adjusted_right + combined_score = left_score + right_score + + return combined_aligned, combined_score + + # Fallback to regular algorithm + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + @staticmethod + def _combine_consecutive_misaligned_tokens( + aligned_pairs: List, + pair_strings: List, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Combine consecutive misaligned tokens into single chunks to improve alignment. + + This addresses cases where multiple tokens are individually misaligned but + collectively represent the same content. Avoids combining tokens near the + end of sequences that might be misaligned due to length differences. + + Args: + aligned_pairs: List of alignment pairs + pair_strings: Precomputed string representations and match status + end_mismatch_threshold: Fraction of sequence from end to avoid chunking + + Returns: + Modified aligned_pairs with consecutive misaligned tokens combined + """ + if not aligned_pairs or len(aligned_pairs) < 2: + return aligned_pairs + + # Calculate the boundary for avoiding end mismatches + sequence_length = len(aligned_pairs) + end_boundary = int(sequence_length * (1 - end_mismatch_threshold)) + + processed_pairs = [] + i = 0 + + while i < len(aligned_pairs): + # Check if current pair is misaligned and not near the end + if (i < end_boundary and + not pair_strings[i][2] and # Current pair is misaligned + i + 1 < len(aligned_pairs)): # Not the last pair + + # Find consecutive misaligned pairs + consecutive_misaligned = [i] + j = i + 1 + + # Look ahead for more consecutive misaligned pairs (up to end boundary) + while (j < end_boundary and + j < len(aligned_pairs) and + not pair_strings[j][2]): # Next pair is also misaligned + consecutive_misaligned.append(j) + j += 1 + + # Only combine if we have multiple consecutive misaligned pairs + if len(consecutive_misaligned) >= 2: + # Combine all consecutive misaligned pairs into one chunk + combined_s1_tokens = [] + combined_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for idx in consecutive_misaligned: + s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest = aligned_pairs[idx] + combined_s1_tokens.extend(s1_tokens) + combined_s2_tokens.extend(s2_tokens) + + if s1_tokens and s1_start != -1: + s1_indices.extend([s1_start, s1_end]) + if s2_tokens and s2_start != -1: + s2_indices.extend([s2_start, s2_end]) + + # Calculate combined indices + combined_s1_start = min(s1_indices[::2]) if s1_indices else -1 + combined_s1_end = max(s1_indices[1::2]) if s1_indices else -1 + combined_s2_start = min(s2_indices[::2]) if s2_indices else -1 + combined_s2_end = max(s2_indices[1::2]) if s2_indices else -1 + + # Create combined pair + combined_pair = ( + combined_s1_tokens, + combined_s2_tokens, + combined_s1_start, + combined_s1_end, + combined_s2_start, + combined_s2_end + ) + + processed_pairs.append(combined_pair) + i = j # Skip to after the combined region + else: + # Only one misaligned pair, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + else: + # Current pair is aligned or near the end, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + + return processed_pairs + + + @staticmethod + def post_process_alignment_optimized( + aligned_pairs: List, + ignore_leading_char_diff: bool = False, + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + combine_misaligned_chunks: bool = True, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Optimized version of post_process_alignment with better performance. + + Key optimizations: + 1. Precompute string concatenations to avoid repeated joins + 2. Early termination when no bad regions are found + 3. Cache alignment results for repeated chunk patterns + 4. Vectorized index calculations + 5. Reduced nested loop complexity + 6. Combine multiple consecutive misaligned tokens into single chunks + + Args: + combine_misaligned_chunks: If True, combine consecutive misaligned tokens into chunks + end_mismatch_threshold: Fraction of sequence length from end to avoid chunking (0.2 = last 20%) + """ + if not aligned_pairs: + return [] + + # Precompute joined strings for all pairs to avoid repeated concatenation + # Use canonicalization for robust comparison + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + # Canonicalize individual tokens before joining for better matching + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + # Step 1: Handle consecutive misaligned chunks if enabled + if combine_misaligned_chunks: + aligned_pairs = TokenAligner._combine_consecutive_misaligned_tokens( + aligned_pairs, pair_strings, end_mismatch_threshold + ) + + # Recompute pair_strings after combining misaligned chunks + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + processed_pairs = [] + alignment_cache = {} # Cache for repeated alignment patterns + i = 0 + + while i < len(aligned_pairs): + s1_tokens, s2_tokens, *_ = aligned_pairs[i] + + # Handle coarse alignments that can be split (optimized) + if len(s1_tokens) > 1 and len(s1_tokens) == len(s2_tokens) and s1_tokens == s2_tokens: + s1_start, s1_end, s2_start, s2_end = aligned_pairs[i][2:6] + # Vectorized creation of split pairs + for k in range(len(s1_tokens)): + processed_pairs.append( + ([s1_tokens[k]], [s2_tokens[k]], + s1_start + k, s1_start + k + 1, + s2_start + k, s2_start + k + 1) + ) + i += 1 + continue + + # Find bad regions more efficiently using precomputed strings + start_bad_region = -1 + for j in range(i, len(aligned_pairs)): + if not pair_strings[j][2]: # is_match is False + start_bad_region = j + break + + if start_bad_region == -1: + # No more bad regions - add remaining pairs and exit + processed_pairs.extend(aligned_pairs[i:]) + break + + # Add good pairs before bad region + processed_pairs.extend(aligned_pairs[i:start_bad_region]) + + # Optimized chunk processing with early termination + found_fix = False + max_chunk_size = min(10, len(aligned_pairs) - start_bad_region) # Limit search space + + for chunk_size in range(2, max_chunk_size + 1): + chunk = aligned_pairs[start_bad_region : start_bad_region + chunk_size] + + # Efficient token extraction using list comprehension + chunk_s1_tokens = [] + chunk_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *rest in chunk: + chunk_s1_tokens.extend(s1_toks) + chunk_s2_tokens.extend(s2_toks) + if s1_toks: + s1_indices.extend([s1_start, s1_end]) + if s2_toks: + s2_indices.extend([s2_start, s2_end]) + + # Quick string comparison using canonicalization + chunk_s1_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s1_tokens] + chunk_s2_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s2_tokens] + chunk_s1_str = "".join(chunk_s1_canonical) + chunk_s2_str = "".join(chunk_s2_canonical) + + if not TokenAligner._strings_equal_flexible(chunk_s1_str, chunk_s2_str, ignore_leading_char_diff): + continue + + # Create cache key for alignment + cache_key = (tuple(chunk_s1_tokens), tuple(chunk_s2_tokens)) + + if cache_key in alignment_cache: + sub_aligned_pairs, realign_is_perfect = alignment_cache[cache_key] + else: + # Perform alignment + sub_aligned_pairs, _ = TokenAligner.align_tokens_with_combinations_numpy( + chunk_s1_tokens, + chunk_s2_tokens, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + ignore_leading_char_diff=ignore_leading_char_diff + ) + + # Check if re-alignment was successful using canonicalization + realign_is_perfect = all( + TokenAligner._strings_equal_flexible( + "".join([TokenAligner._canonical_token(tok) for tok in p[0]]), + "".join([TokenAligner._canonical_token(tok) for tok in p[1]]), + ignore_leading_char_diff + ) + for p in sub_aligned_pairs + ) + + # Cache the result + alignment_cache[cache_key] = (sub_aligned_pairs, realign_is_perfect) + + # Vectorized index calculations + s1_chunk_start = min(s1_indices[::2]) if s1_indices else -1 + s2_chunk_start = min(s2_indices[::2]) if s2_indices else -1 + + if realign_is_perfect: + # Add granular aligned pairs + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *_ in sub_aligned_pairs: + new_s1_start = s1_chunk_start + s1_start if s1_start != -1 else -1 + new_s1_end = s1_chunk_start + s1_end if s1_end != -1 else -1 + new_s2_start = s2_chunk_start + s2_start if s2_start != -1 else -1 + new_s2_end = s2_chunk_start + s2_end if s2_end != -1 else -1 + processed_pairs.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + else: + # Create merged pair + s1_chunk_end = max(s1_indices[1::2]) if s1_indices else -1 + s2_chunk_end = max(s2_indices[1::2]) if s2_indices else -1 + merged_pair = (chunk_s1_tokens, chunk_s2_tokens, s1_chunk_start, s1_chunk_end, s2_chunk_start, s2_chunk_end) + processed_pairs.append(merged_pair) + + i = start_bad_region + chunk_size + found_fix = True + break + + if not found_fix: + processed_pairs.append(aligned_pairs[start_bad_region]) + i = start_bad_region + 1 + + return processed_pairs + + @staticmethod + def get_alignment_mask(aligned_pairs: List, use_canonicalization: bool = True, + ignore_leading_char_diff: bool = False) -> List[bool]: + """ + Get a boolean mask indicating which alignments are correct. + """ + if not aligned_pairs: + return [] + + # Handle batch case - take first batch + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + pairs_to_verify = aligned_pairs[0] + else: + pairs_to_verify = aligned_pairs + + mask = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest in pairs_to_verify: + # Concatenate tokens into strings + s1_str = "".join(s1_tokens) if s1_tokens else "" + s2_str = "".join(s2_tokens) if s2_tokens else "" + + # Apply canonicalization if requested + if use_canonicalization: + s1_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s1_tokens]) if s1_tokens else "" + s2_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s2_tokens]) if s2_tokens else "" + is_correct = TokenAligner._strings_equal_flexible(s1_canonical, s2_canonical, ignore_leading_char_diff) + else: + if ignore_leading_char_diff: + is_correct = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + else: + is_correct = s1_str == s2_str + + mask.append(is_correct) + + return mask + + + def transform_learned_matrix_instance(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Instance method version that uses instance variables. + """ + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.transform_learned_matrix(x, dim, enable_scale_trick=scale_trick_enabled) + + @staticmethod + def transform_learned_matrix(x: torch.Tensor, dim: int = -1, enable_scale_trick=None) -> torch.Tensor: + """ + Compute Quite Attention over tensor x along specified dimension. + + Args: + x: Input tensor. + dim: Dimension to apply attention over (default: -1). + + Returns: + Tensor of same shape with quite attention applied. + """ + if 0: + exp_x = torch.exp(x) + denom = 1 + torch.sum(exp_x, dim=dim, keepdim=True) + return exp_x / denom + # write as a single lambda function + # return lambda x: torch.exp(x) / (1 + torch.sum(torch.exp(x), dim=dim, keepdim=True)) + else: + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + if scale_trick_enabled: + #trick with last column being multiplier of 0..1, or try with c instead of 1 in qa. + scores = torch.nn.functional.softmax(x, dim=dim) + # Create a mask to zero out the last column while preserving gradients + # mask = torch.ones_like(scores) + # mask[:, -1] = 0.0 + # scores = scores * mask + # Alternative approach using sigmoid (commented out): + # scores = scores * torch.sigmoid(x[:, -1].unsqueeze(-1)) + return scores + else: + #normal softmax + return torch.nn.functional.softmax(x, dim=dim) + return torch.nn.functional.softmax(x, dim=dim) diff --git a/nemo_rl/utils/x_token/__init__.py b/nemo_rl/utils/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/utils/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/utils/x_token/minimal_projection_generator.py b/nemo_rl/utils/x_token/minimal_projection_generator.py new file mode 100644 index 0000000000..4c71dd1490 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_generator.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +from transformers import AutoTokenizer, AutoModel, AutoConfig +# from sentence_transformers import SentenceTransformer +from tqdm.auto import tqdm +import re +import pdb + + +##### verify KL and top5 with this matrix + + +###### use config vocab size, not tokenizer + +EXACT_MATCH_ONLY = False + +# --- Configuration and Setup --- +parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") +parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).") +parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).") +parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for source model (Model A / Student). If provided, overrides model_a_index.") +parser.add_argument("--model_b_name", type=str, default=None, help="HuggingFace model name for target model (Model B / Teacher). If provided, overrides model_b_index.") +parser.add_argument("--keep_top_tokens", type=int, default=-1, help="Number of top tokens to keep for each vocabulary. -1 means all.") +parser.add_argument("--data_dir", type=str, default="cross_tokenizer_data/", help="Directory for importance scores and cached data.") +parser.add_argument("--top_k", type=int, default=10, help="Number of top projections to keep for each token.") +parser.add_argument("--weight_threshold", type=float, default=0.0, help="Minimum weight threshold to keep a projection. Values below this will be filtered out.") +parser.add_argument("--force_recompute", action='store_true', help="Force recomputation of embeddings even if cached files exist.") +parser.add_argument("--skip_exact_enforcement", action='store_true', help="Skip enforcing exact matches between tokens.") +parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).") +args = parser.parse_args() + +args.skip_exact_enforcement = True + +MODEL_LIST = [ + "nvidia/Mistral-NeMo-Minitron-8B-Base", + "Qwen/Qwen3-8B-Base", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.1-8B", + "google/gemma-3-4b-it", + "google/gemma-2b", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "openai/gpt-oss-20b", + "microsoft/phi-4", + "google/gemma-3-12b-pt", +] +EMBEDDING_MODEL_CHOICES = [ + {"name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-mpnet-base-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-MiniLM-L6-v2", "type": "sbert"}, + {"name": "Qwen/Qwen3-Embedding-4B", "type": "llm_first_layer"}, + {"name": "Qwen/Qwen3-Embedding-0.6B", "type": "llm_first_layer"}, +] + +MAX_SEQ_LENGTH_EMBEDDING = 64 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def sinkhorn(A, n_iters=10): + for _ in range(n_iters): + if _ % 2 == 0: + # A = A / (A.sum(dim=0, keepdim=True) + 1e-6) + col_sums = A.sum(dim=0, keepdim=True) + safe_col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums) + A = A / safe_col_sums + else: + #0, 2, 4, 6 + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +# --- Helper Functions --- + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + if 'mini' in name: + cleaned_name += "_mini" + return cleaned_name + +def load_tokenizer(model_id_or_path): + """Loads a HuggingFace tokenizer, setting a pad token if necessary.""" + try: + tok = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + except Exception as e: + print(f"Error loading tokenizer for model '{model_id_or_path}': {e}") + print(f"Available models in MODEL_LIST (indices 0-{len(MODEL_LIST)-1}):") + for i, model in enumerate(MODEL_LIST): + print(f" {i}: {model}") + raise + +def validate_model_selection(args): + """Validates that the model selection arguments are valid.""" + # Check if both name and index are provided for the same model + if args.model_a_name is not None and args.model_a_index != 1: # 1 is the default + print("Warning: Both --model_a_name and --model_a_index provided. Using --model_a_name.") + + if args.model_b_name is not None and args.model_b_index != 0: # 0 is the default + print("Warning: Both --model_b_name and --model_b_index provided. Using --model_b_name.") + + # Validate indices if names are not provided + if args.model_a_name is None: + if args.model_a_index < 0 or args.model_a_index >= len(MODEL_LIST): + raise ValueError(f"model_a_index {args.model_a_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + if args.model_b_name is None: + if args.model_b_index < 0 or args.model_b_index >= len(MODEL_LIST): + raise ValueError(f"model_b_index {args.model_b_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + # Check if the same model is selected for both A and B + model_a_id = args.model_a_name if args.model_a_name is not None else MODEL_LIST[args.model_a_index] + model_b_id = args.model_b_name if args.model_b_name is not None else MODEL_LIST[args.model_b_index] + + if model_a_id == model_b_id: + raise ValueError(f"Cannot use the same model for both A and B: {model_a_id}") + +def save_data(data, filename): + """Saves data to a torch file.""" + os.makedirs(os.path.dirname(filename), exist_ok=True) + torch.save(data.cpu(), filename) + print(f"Data saved to {filename}") + +def load_data(filename): + """Loads data from a torch file.""" + return torch.load(filename) + +def get_llm_first_layer_embeddings(decoded_tokens_list, llm_embedding_tokenizer, llm_embedding_model, max_seq_length_embedding, device, batch_size=32): + """Generates embeddings using the first layer of a given LLM.""" + all_embeddings = [] + llm_embedding_model.eval() + embedding_dim = llm_embedding_model.config.hidden_size + + for i in tqdm(range(0, len(decoded_tokens_list), batch_size), desc="Encoding tokens with LLM"): + batch_tokens = decoded_tokens_list[i:i + batch_size] + inputs = llm_embedding_tokenizer( + batch_tokens, return_tensors="pt", padding=True, truncation=True, + max_length=max_seq_length_embedding, add_special_tokens=False, + ).to(device) + + with torch.no_grad(): + outputs = llm_embedding_model(**inputs, output_hidden_states=True) + first_layer_output = outputs.hidden_states[0] + + for k in range(first_layer_output.shape[0]): + valid_token_mask = inputs['attention_mask'][k] == 1 + if valid_token_mask.sum() > 0: + pooled_embedding = first_layer_output[k, valid_token_mask].mean(dim=0) + all_embeddings.append(pooled_embedding) + else: + all_embeddings.append(torch.zeros(embedding_dim, device=device)) + + return torch.stack(all_embeddings).to(device) + + +def compute_chunked_projection_map(embeddings_query, embeddings_corpus, args, device, chunk_size=1000): + """Computes projection map in chunks to save memory.""" + num_queries = embeddings_query.shape[0] + target_vocab_size = embeddings_corpus.shape[0] + + # Pre-allocate result tensors + all_top_k_indices = torch.zeros((num_queries, args.top_k), dtype=torch.long) + all_top_k_likelihoods = torch.zeros((num_queries, args.top_k), dtype=torch.float32) + + # Normalize corpus embeddings once + embeddings_corpus_norm = torch.nn.functional.normalize(embeddings_corpus.to(device).float(), p=2, dim=1) + + for chunk_start in tqdm(range(0, num_queries, chunk_size), desc="Processing chunks"): + chunk_end = min(chunk_start + chunk_size, num_queries) + chunk_query = embeddings_query[chunk_start:chunk_end].to(device).float() + + with torch.no_grad(): + # Compute similarities for this chunk + chunk_query_norm = torch.nn.functional.normalize(chunk_query, p=2, dim=1) + similarities = torch.matmul(chunk_query_norm, embeddings_corpus_norm.t()) + + # Generate projection map for this chunk + chunk_top_k_indices, chunk_top_k_likelihoods = generate_projection_map_chunk(similarities, args) + + # Store results + all_top_k_indices[chunk_start:chunk_end] = chunk_top_k_indices.cpu() + all_top_k_likelihoods[chunk_start:chunk_end] = chunk_top_k_likelihoods.cpu() + + # Clear GPU memory + del similarities, chunk_query_norm, chunk_top_k_indices, chunk_top_k_likelihoods + torch.cuda.empty_cache() + + return all_top_k_indices, all_top_k_likelihoods + +def generate_projection_map_chunk(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix chunk.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Normalize rows + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + threshold_mask = top_k_likelihood >= args.weight_threshold + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + return top_k_indices, top_k_likelihood + +def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + +def debug_projection_map(top_k_indices, top_k_likelihood, source_tokenizer, target_tokenizer, direction="", N=2000): + """Debug function to show first N rows with decoded tokens and weights.""" + N = min(N, top_k_indices.shape[0]) # Show first N rows or less + print(f"\n--- Debugging projection map {direction} (first {N} rows) ---") + + for row_idx in range(N): + # for row_idx in range(-N,-1): + # Decode source token + try: + token_id = row_idx if row_idx >= 0 else top_k_indices.shape[0] + row_idx + source_token = source_tokenizer.decode([token_id]) + # source_token = source_tokenizer.convert_ids_to_tokens([token_id])[0] + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Build the target tokens with weights string + row_indices = top_k_indices[row_idx].cpu().numpy() + row_weights = top_k_likelihood[row_idx].float().cpu().numpy() + + weight_total = 0 + target_parts = [] + + if row_weights.max() != row_weights[-1]: + continue + + + for target_idx, weight in zip(row_indices, row_weights): + try: + target_token = target_tokenizer.decode([target_idx]) + target_token_str = repr(target_token) + except: + target_token_str = f"" + + + target_parts.append(f"{target_token_str}({weight:.4f})") + weight_total += weight + + target_string = " ".join(target_parts) + # print(f"Weight total: {weight_total:.4f}") + print(f"{source_token_str} -> {target_string}") + +def generate_projection_map(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Create a sparse representation by keeping only top-k values + # top_k_likelihood_pre_norm, _ = likelihood.topk(args.top_k, dim=1) + # likelihood = likelihood.where(likelihood >= top_k_likelihood_pre_norm[:, -1:], torch.zeros_like(likelihood)) + + # Normalize the row to sum to 1, handling rows that are all zero + # row_sums = likelihood.sum(dim=1, keepdim=True) + # safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + # likelihood = likelihood / safe_row_sums + # pdb.set_trace() + # likelihood = sinkhorn_one_dim(likelihood) + + # Get the final top-k values and their indices from the sparse, normalized likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Store top-k values before zeroing (to avoid losing them) + row_indices = torch.arange(likelihood.shape[0]).unsqueeze(1).expand(-1, args.top_k) + top_k_values = likelihood[row_indices, top_k_indices].clone() + + # Zero out entire likelihood matrix in-place, then restore only top-k elements + likelihood.zero_() + likelihood[row_indices, top_k_indices] = top_k_values + + # likelihood = sinkhorn(likelihood, n_iters=1) + # likelihood = sinkhorn(likelihood, n_iters=1) works the best + + + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + print(f"Applying weight threshold filter: {args.weight_threshold}") + # Create mask for values above threshold + # pdb.set_trace() + threshold_mask = top_k_likelihood >= args.weight_threshold + + #set indices to -1 where threshold is not met + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # # Count how many values per row are above threshold + # valid_counts = threshold_mask.sum(dim=1) + # total_filtered = (valid_counts == 0).sum().item() + # total_kept = threshold_mask.sum().item() + # total_possible = top_k_likelihood.numel() + + # print(f"Kept {total_kept}/{total_possible} ({100*total_kept/total_possible:.1f}%) projections above threshold") + + # if total_filtered > 0: + # print(f"Warning: {total_filtered} tokens have no projections above threshold {args.weight_threshold}") + + # # Zero out values below threshold + # filtered_likelihood = top_k_likelihood * threshold_mask.to(top_k_likelihood.dtype) + # filtered_indices = top_k_indices.clone() + + # # For rows with no values above threshold, keep the top value to avoid empty rows + # empty_rows = valid_counts == 0 + # if empty_rows.any(): + # print(f"Keeping top projection for {empty_rows.sum().item()} tokens with no values above threshold") + # filtered_likelihood[empty_rows, 0] = top_k_likelihood[empty_rows, 0] + + # top_k_likelihood = filtered_likelihood + # top_k_indices = filtered_indices + + # pdb.set_trace() + + return top_k_indices, top_k_likelihood + +# --- Main Execution --- +if __name__ == "__main__": + # Validate model selection arguments + validate_model_selection(args) + + # 1. Load Tokenizers and deterministically assign A and B + # Use model names if provided, otherwise use indices + if args.model_a_name is not None: + model_1 = {'id': args.model_a_name} + print(f"Using provided model A name: {args.model_a_name}") + else: + model_1 = {'id': MODEL_LIST[args.model_a_index]} + print(f"Using model A from index {args.model_a_index}: {model_1['id']}") + + model_1['name'] = model_1['id'].split("/")[-1] + print(f"Loading first tokenizer: {model_1['name']}") + model_1['tokenizer'] = load_tokenizer(model_1['id']) + + if args.model_b_name is not None: + model_2 = {'id': args.model_b_name} + print(f"Using provided model B name: {args.model_b_name}") + else: + model_2 = {'id': MODEL_LIST[args.model_b_index]} + print(f"Using model B from index {args.model_b_index}: {model_2['id']}") + + model_2['name'] = model_2['id'].split("/")[-1] + print(f"Loading second tokenizer: {model_2['name']}") + model_2['tokenizer'] = load_tokenizer(model_2['id']) + + # Deterministically assign model_A and model_B based on alphabetical order of names + if model_1['name'] > model_2['name']: + model_A, model_B = model_2, model_1 + else: + model_A, model_B = model_1, model_2 + + print(f"\nAssigned Source (A): {model_A['name']}") + print(f"Assigned Target (B): {model_B['name']}") + + source_vocab_size = model_A['tokenizer'].vocab_size + target_vocab_size = model_B['tokenizer'].vocab_size + # get the top k tokens from the source and target vocab from model config file + model_A_config = AutoConfig.from_pretrained(model_A['id'], trust_remote_code=True if 'nvidia' in model_A['id'] else False) + model_B_config = AutoConfig.from_pretrained(model_B['id'], trust_remote_code=True if 'nvidia' in model_B['id'] else False) + # pdb.set_trace() + source_vocab_size = model_A_config.vocab_size + if "gemma" not in model_B['id']: + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Source vocab size (full): {source_vocab_size}") + print(f"Target vocab size (full): {target_vocab_size}") + # exit() + + + + if 0: + # just debugging learned projection map + # learned_projection_map = torch.load("models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_transformation_matrices/learned_projection_map_latest.pt") + # learned_projection_map = torch.load("cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_64_double.pt") + learned_projection_map = torch.load("cross_tokenizer_data/projection_matrix_learned_llama_qwen_top5.pt") + top_k_indices_A_to_B = learned_projection_map["indices"] + top_k_likelihood_A_to_B = learned_projection_map["likelihoods"] + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B", N=150000) + exit() + + + + + + + + # 2. Select and Load Embedding Model + embedding_model_index = 3 # Default to a good LLM embedder + selected_model_info = EMBEDDING_MODEL_CHOICES[embedding_model_index] + embedding_model_name = selected_model_info["name"] + embedding_model_type = selected_model_info["type"] + print(f"\nUsing embedding model: {embedding_model_name} ({embedding_model_type})") + + # 3. Generate or Load Embeddings + canonicalization_suffix = "_canonical" if args.use_canonicalization else "_raw" + embeddings_path_A = os.path.join(args.data_dir, f"embeddings_{model_A['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + embeddings_path_B = os.path.join(args.data_dir, f"embeddings_{model_B['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + + if not args.force_recompute and os.path.exists(embeddings_path_A) and os.path.exists(embeddings_path_B): + print("Loading cached embeddings...") + model_A['embeddings'] = load_data(embeddings_path_A).to(DEVICE) + model_B['embeddings'] = load_data(embeddings_path_B).to(DEVICE) + else: + print("Generating new embeddings...") + + # Generate raw decoded tokens + raw_tokens_A = [model_A['tokenizer'].decode([idx]) for idx in range(model_A['tokenizer'].vocab_size)] + raw_tokens_B = [model_B['tokenizer'].decode([idx]) for idx in range(model_B['tokenizer'].vocab_size)] + + # Apply canonicalization if requested + if args.use_canonicalization: + # Import canonicalization function + import sys + sys.path.append('.') + from tokenalign import TokenAligner + + print("Applying token canonicalization before embedding generation...") + decoded_tokens_A = [TokenAligner._canonical_token(token) for token in raw_tokens_A] + decoded_tokens_B = [TokenAligner._canonical_token(token) for token in raw_tokens_B] + + # Show some examples of canonicalization + print("Canonicalization examples:") + for i in range(min(10, len(raw_tokens_A))): + if raw_tokens_A[i] != decoded_tokens_A[i]: + print(f" Model A: '{raw_tokens_A[i]}' -> '{decoded_tokens_A[i]}'") + for i in range(min(10, len(raw_tokens_B))): + if raw_tokens_B[i] != decoded_tokens_B[i]: + print(f" Model B: '{raw_tokens_B[i]}' -> '{decoded_tokens_B[i]}'") + + print(f"Applied canonicalization to {len(decoded_tokens_A)} tokens for model A and {len(decoded_tokens_B)} tokens for model B") + else: + print("Using raw decoded tokens without canonicalization") + decoded_tokens_A = raw_tokens_A + decoded_tokens_B = raw_tokens_B + + if embedding_model_type == "sbert": + sbert_model = SentenceTransformer(embedding_model_name, device=DEVICE) + model_A['embeddings'] = sbert_model.encode(decoded_tokens_A, convert_to_tensor=True, show_progress_bar=True) + model_B['embeddings'] = sbert_model.encode(decoded_tokens_B, convert_to_tensor=True, show_progress_bar=True) + elif embedding_model_type == "llm_first_layer": + llm_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + if llm_tokenizer.pad_token is None: llm_tokenizer.pad_token = llm_tokenizer.eos_token + llm_model = AutoModel.from_pretrained( + embedding_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ).to(DEVICE) + model_A['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_A, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + model_B['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_B, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + + save_data(model_A['embeddings'], embeddings_path_A) + save_data(model_B['embeddings'], embeddings_path_B) + + # 4. Compute Similarity and Generate Projection Maps (chunked to save memory) + print("\nComputing projection map in chunks to save memory...") + chunk_size = 500 # Process 500 tokens at a time to avoid OOM + top_k_indices_A_to_B, top_k_likelihood_A_to_B = compute_chunked_projection_map( + model_A['embeddings'], model_B['embeddings'], args, DEVICE, chunk_size=chunk_size) + + # Note: Exact match enforcement is skipped in chunked mode for simplicity + # The chunked approach processes similarities in small batches to avoid OOM + if 0: + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B") + + # print("Generating B -> A projection map...") + # top_k_indices_B_to_A, top_k_likelihood_B_to_A = generate_projection_map(similarities.T, args) + # debug_projection_map(top_k_indices_B_to_A, top_k_likelihood_B_to_A, model_B['tokenizer'], model_A['tokenizer'], "B -> A") + + # 5. Save the Combined Projection Map + print("\nSaving combined projection map...") + model_a_clean_name = clean_model_name_for_filename(model_A['name']) + model_b_clean_name = clean_model_name_for_filename(model_B['name']) + # output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_bidirectional_top_{args.top_k}.pt" + output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_top_{args.top_k}" + # if args.skip_exact_enforcement: + # output_filename += "_no_exact" + output_filename += f".pt" + if args.weight_threshold > 0.0: + output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") + output_path = os.path.join(args.data_dir, output_filename) + + torch.save({ + "indices": top_k_indices_A_to_B.cpu(), + "likelihoods": top_k_likelihood_A_to_B.cpu(), + "model_A_id": model_A['id'], + "model_B_id": model_B['id'], + }, output_path) + + # torch.save({ + # "A_to_B": { + # "indices": top_k_indices_A_to_B.cpu(), + # "likelihoods": top_k_likelihood_A_to_B.cpu() + # }, + # "B_to_A": { + # "indices": top_k_indices_B_to_A.cpu(), + # "likelihoods": top_k_likelihood_B_to_A.cpu() + # }, + # "model_A_id": model_A['id'], + # "model_B_id": model_B['id'], + # }, output_path) + print(f"Saved combined projection map to: {output_path}") + + # 6. Example Usage of the Projection Function + print("\n--- Testing projection function (A -> B) ---") + # Create a dummy likelihood tensor: [BATCH, SEQ, vocab_size_A] + source_vocab_size_A = model_A['embeddings'].shape[0] + target_vocab_size_B = model_B['embeddings'].shape[0] + dummy_tensor = torch.randn(1, 4096, source_vocab_size_A, device=DEVICE, dtype=torch.bfloat16) + + # Transform this tensor using the projection map (convert to float32 for compatibility) + projected_tensor = project_token_likelihoods(dummy_tensor.float(), top_k_indices_A_to_B, top_k_likelihood_A_to_B, target_vocab_size_B, DEVICE) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful.") diff --git a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py new file mode 100644 index 0000000000..ec67e4cc57 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py @@ -0,0 +1,942 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + + + +###### save as dense format and set indices to -1 where not used + + +#remove all special tokens that start with <| and end with |> + + +# compare 3 ways to estimate likelihood matrix: +# 1. using embeddings from another model, like was done in minimal_projection_generator.py +# 2. using text analysis like in tokenalign_likelihood_estimate.py +# 3. use one token to multiple and assign those as transformation matrix + +# this file implements 3rd way + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def create_weight_distribution(num_tokens): + """Create weight distribution for multi-token mappings""" + # if num_tokens == 1: + # return [1.0] + # elif num_tokens == 2: + # return [0.7, 0.3] + # elif num_tokens == 3: + # return [0.6, 0.3, 0.1] + # else: + if 1: + # For more tokens, use exponential decay + weights = [] + base = 0.9 + for i in range(num_tokens): + if i == 0: + weights.append(base) + else: + weights.append(base * (0.1 ** i)) + + # Normalize to sum to 1 + total = sum(weights) + weights = [w / total for w in weights] + return weights + +def find_similar_special_tokens(tokenizer_a, tokenizer_b, similarity_threshold=0.4, top_k_matches=3): + """Find similar special tokens between two tokenizers using string similarity.""" + + def is_special_token(token_str): + """Check if a token looks like a special token""" + return (token_str.startswith('<|') and token_str.endswith('|>')) or \ + (token_str.startswith('<') and token_str.endswith('>')) or \ + token_str in ['', '', '', '', '', ''] + + def extract_special_tokens(tokenizer): + """Extract all special tokens from a tokenizer with their IDs""" + special_tokens = {} + vocab = tokenizer.get_vocab() + for token_str, token_id in vocab.items(): + if is_special_token(token_str): + special_tokens[token_id] = token_str + return special_tokens + + def calculate_similarity(token_a, token_b): + """Calculate similarity between two token strings""" + # Use difflib for sequence similarity + seq_similarity = difflib.SequenceMatcher(None, token_a, token_b).ratio() + + # Extract key words from special tokens for semantic matching + def extract_keywords(token): + # Remove special token markers and split by common separators + cleaned = re.sub(r'[<>|_]', ' ', token.lower()) + words = [w for w in cleaned.split() if len(w) > 2] # Filter short words + return set(words) + + keywords_a = extract_keywords(token_a) + keywords_b = extract_keywords(token_b) + + # Jaccard similarity for keywords + if keywords_a or keywords_b: + keyword_similarity = len(keywords_a.intersection(keywords_b)) / len(keywords_a.union(keywords_b)) + else: + keyword_similarity = 0.0 + + # Combined similarity (weighted average) + return 0.6 * seq_similarity + 0.4 * keyword_similarity + + print("Extracting special tokens...") + special_tokens_a = extract_special_tokens(tokenizer_a) # student + special_tokens_b = extract_special_tokens(tokenizer_b) # teacher + + print(f"Found {len(special_tokens_a)} special tokens in student tokenizer") + print(f"Found {len(special_tokens_b)} special tokens in teacher tokenizer") + + # Find matches + special_token_mappings = [] + + print("Finding similar special tokens...") + for token_id_a, token_str_a in special_tokens_a.items(): + similarities = [] + for token_id_b, token_str_b in special_tokens_b.items(): + similarity = calculate_similarity(token_str_a, token_str_b) + if similarity >= similarity_threshold: + similarities.append((token_id_b, token_str_b, similarity)) + + # Sort by similarity and take top-k + similarities.sort(key=lambda x: x[2], reverse=True) + for token_id_b, token_str_b, similarity in similarities[:top_k_matches]: + special_token_mappings.append({ + 'student_id': token_id_a, + 'student_token': token_str_a, + 'teacher_id': token_id_b, + 'teacher_token': token_str_b, + 'similarity': similarity + }) + + return special_token_mappings + + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + + # Configuration from arguments + ENABLE_SCALE_TRICK = args.enable_scale_trick + ENABLE_REVERSE_PASS = args.enable_reverse_pass + ENABLE_EXACT_MATCH = args.enable_exact_match + + TOKENS_TO_CUT = args.tokens_to_cut + TOP_K = args.top_k + USE_RAW_TOKENS = args.use_raw_tokens + INITIAL_PROJECTION_PATH = args.initial_projection_path + ENABLE_SPECIAL_TOKEN_MAPPING = args.enable_special_token_mapping + SPECIAL_TOKEN_SIMILARITY_THRESHOLD = args.special_token_similarity_threshold + SPECIAL_TOKEN_TOP_K = args.special_token_top_k if args.special_token_top_k is not None else TOP_K + USE_CANONICALIZATION = args.use_canonicalization + + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + + # Print configuration + print("=== Configuration ===") + print(f"Student model: {student_model_name}") + print(f"Teacher model: {teacher_model_name}") + print(f"Enable scale trick: {ENABLE_SCALE_TRICK}") + print(f"Enable reverse pass: {ENABLE_REVERSE_PASS}") + print(f"Enable exact match: {ENABLE_EXACT_MATCH}") + print(f"Use raw tokens: {USE_RAW_TOKENS}") + print(f"Use canonicalization: {USE_CANONICALIZATION}") + print(f"Tokens to cut: {TOKENS_TO_CUT}") + print(f"Top K: {TOP_K}") + print(f"Enable special token mapping: {ENABLE_SPECIAL_TOKEN_MAPPING}") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"Special token similarity threshold: {SPECIAL_TOKEN_SIMILARITY_THRESHOLD}") + print(f"Special token top K: {SPECIAL_TOKEN_TOP_K}") + print(f"Initial projection path: {INITIAL_PROJECTION_PATH}") + print(f"Output directory: {args.output_dir}") + print("=" * 25) + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + # pdb.set_trace() + if "gemma" not in student_model_name.lower(): + source_vocab_size = model_A_config.vocab_size + else: + source_vocab_size = model_A_config.text_config.vocab_size + + if "gemma" not in teacher_model_name.lower(): + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + + tokenizer_student_total_vocab_size = source_vocab_size + tokenizer_teacher_total_vocab_size = target_vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Student tokenizer total vocab size: {tokenizer_student_total_vocab_size}") + print(f"Teacher tokenizer total vocab size: {tokenizer_teacher_total_vocab_size}") + + # Print token processing mode + if USE_RAW_TOKENS: + print("Using raw token representation (convert_ids_to_tokens)") + else: + print("Using decoded token representation (decode)") + + + transformation_counts = defaultdict(float) + import os + if INITIAL_PROJECTION_PATH and os.path.exists(INITIAL_PROJECTION_PATH): + print(f"Loading initial projection from: {INITIAL_PROJECTION_PATH}") + initial_projection_map = torch.load(INITIAL_PROJECTION_PATH, map_location='cpu') + + if isinstance(initial_projection_map, dict) and 'indices' in initial_projection_map and 'likelihoods' in initial_projection_map: + print("Loading from sparse top-k format and converting to transformation_counts.") + indices = initial_projection_map['indices'] + likelihoods = initial_projection_map['likelihoods'] + + loaded_student_model = initial_projection_map.get('model_A_id') + loaded_teacher_model = initial_projection_map.get('model_B_id') + + if loaded_student_model and loaded_student_model != student_model_name: + print(f"Warning: Student model mismatch. Loaded: {loaded_student_model}, Current: {student_model_name}") + if loaded_teacher_model and loaded_teacher_model != teacher_model_name: + print(f"Warning: Teacher model mismatch. Loaded: {loaded_teacher_model}, Current: {teacher_model_name}") + + num_student_tokens = indices.shape[0] + top_k = indices.shape[1] + + for student_id in tqdm.tqdm(range(num_student_tokens), desc="Converting initial projection to counts"): + for k in range(top_k): + teacher_id = indices[student_id, k].item() + if teacher_id != -1: + likelihood = likelihoods[student_id, k].item() + if likelihood > 0: + transformation_counts[(student_id, teacher_id)] = likelihood + + elif torch.is_tensor(initial_projection_map): + if initial_projection_map.is_sparse: + print("Loading from sparse tensor and converting to transformation_counts.") + sparse_matrix = initial_projection_map.coalesce() + map_indices = sparse_matrix.indices() + map_values = sparse_matrix.values() + for i in tqdm.tqdm(range(map_indices.shape[1]), desc="Converting sparse tensor to counts"): + student_id = map_indices[0, i].item() + teacher_id = map_indices[1, i].item() + weight = map_values[i].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print("Loading from dense matrix and converting to transformation_counts.") + dense_matrix = initial_projection_map + non_zero_indices = torch.nonzero(dense_matrix, as_tuple=False) + for idx in tqdm.tqdm(range(non_zero_indices.shape[0]), desc="Converting dense projection to counts"): + student_id = non_zero_indices[idx, 0].item() + teacher_id = non_zero_indices[idx, 1].item() + weight = dense_matrix[student_id, teacher_id].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print(f"Warning: Unrecognized format for initial projection map at {INITIAL_PROJECTION_PATH}. Skipping.") + + print(f"Initialized transformation_counts with {len(transformation_counts)} entries.") + + # pdb.set_trace() + + ignore_tokens = ['<|endoftext|>', '', ] + ignore_student_ids = {tokenizer_student.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_student.get_vocab()} + ignore_teacher_ids = {tokenizer_teacher.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_teacher.get_vocab()} + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + # Find multi-token mappings + multi_token_examples = [] + + + print("=== FIRST PASS: Teacher tokens -> Student tokens ===") + print("Finding multi-token mappings...") + + + # First pass: Teacher tokens -> Student tokens (reverse direction) + if 1: + print("\n=== First PASS: Student tokens -> Teacher tokens ===") + + # Get all student tokens and decode them + student_vocab = tokenizer_student.get_vocab() + + student_tokens_decoded = {} + + print("Decoding student tokens...") + for token_id in tqdm.tqdm(range(tokenizer_student_total_vocab_size), desc="Decoding student tokens"): + if token_id in ignore_student_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_student.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_student.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + student_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(student_tokens_decoded)} student tokens") + + reverse_multi_token_examples = [] + print("Finding reverse multi-token mappings...") + for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"): + # Tokenize the student token string using teacher tokenizer + teacher_encoding = tokenizer_teacher(student_token_str, add_special_tokens=False, return_attention_mask=False) + teacher_token_ids = teacher_encoding['input_ids'] + + # Skip if any teacher token is in ignore list + if any(tid in ignore_teacher_ids for tid in teacher_token_ids): + continue + + # Cut to only first 4 tokens + teacher_token_ids = teacher_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of teacher tokens + weights = create_weight_distribution(len(teacher_token_ids)) + + # Add to transformation matrix (reverse direction: teacher_token_id -> student_token_id) + if 1: + for teacher_token_id, weight in zip(teacher_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(teacher_token_ids) >= 2: + teacher_tokens_decoded_reverse = [tokenizer_teacher.decode([tid]) for tid in teacher_token_ids] + reverse_multi_token_examples.append({ + 'student_token': student_token_str, + 'student_id': student_token_id, + 'teacher_tokens': teacher_tokens_decoded_reverse, + 'teacher_ids': teacher_token_ids, + 'weights': weights + }) + + # second pass: Teacher tokens -> Student tokens (opposite direction) + if ENABLE_REVERSE_PASS: + print("\n=== secod PASS: Teacher tokens -> Student tokens ===") + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + teacher_to_student_multi_token_examples = [] + print("Finding teacher->student multi-token mappings...") + for teacher_token_id, teacher_token_str in tqdm.tqdm(teacher_tokens_decoded.items(), desc="Processing teacher tokens"): + # Tokenize the teacher token string using student tokenizer + student_encoding = tokenizer_student(teacher_token_str, add_special_tokens=False, return_attention_mask=False) + student_token_ids = student_encoding['input_ids'] + + # Skip if any student token is in ignore list + if any(sid in ignore_student_ids for sid in student_token_ids): + continue + + # Cut to only first 4 tokens + student_token_ids = student_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of student tokens + weights = create_weight_distribution(len(student_token_ids)) + + # Add to transformation matrix (student_token_id -> teacher_token_id mapping) + if 1: + for student_token_id, weight in zip(student_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(student_token_ids) >= 2: + student_tokens_decoded_reverse = [tokenizer_student.decode([sid]) for sid in student_token_ids] + teacher_to_student_multi_token_examples.append({ + 'teacher_token': teacher_token_str, + 'teacher_id': teacher_token_id, + 'student_tokens': student_tokens_decoded_reverse, + 'student_ids': student_token_ids, + 'weights': weights + }) + + print(f"\n=== ADDING SPECIAL TOKEN MAPPINGS ===") + + # Find and add special token mappings (if enabled) + special_token_mappings = [] + if ENABLE_SPECIAL_TOKEN_MAPPING: + special_token_mappings = find_similar_special_tokens( + tokenizer_student, + tokenizer_teacher, + similarity_threshold=SPECIAL_TOKEN_SIMILARITY_THRESHOLD, + top_k_matches=SPECIAL_TOKEN_TOP_K + ) + else: + print("Special token mapping disabled") + + if special_token_mappings: + print(f"\nFound {len(special_token_mappings)} special token mappings:") + initial_transformation_count = len(transformation_counts) + + # Add ALL mappings to transformation matrix + for mapping in special_token_mappings: + student_id = mapping['student_id'] + teacher_id = mapping['teacher_id'] + similarity = mapping['similarity'] + + # Add mapping with weight based on similarity + weight = similarity * 0.8 # Scale similarity to reasonable weight + transformation_counts[(student_id, teacher_id)] += weight + + # Group mappings by student token and show top 2 matches per student token + from collections import defaultdict + student_mappings = defaultdict(list) + for mapping in special_token_mappings: + student_mappings[mapping['student_id']].append(mapping) + + # Sort each student's mappings by similarity and show top 2 + print("Top 2 matches per student special token:") + shown_count = 0 + for student_id, mappings in student_mappings.items(): + # Sort by similarity (highest first) + sorted_mappings = sorted(mappings, key=lambda x: x['similarity'], reverse=True) + + # Show top 2 for this student token + student_token = sorted_mappings[0]['student_token'] # Get student token name + print(f" {student_token}:") + + for mapping in sorted_mappings[:2]: + similarity = mapping['similarity'] + weight = similarity * 0.8 + print(f" -> '{mapping['teacher_token']}' (similarity: {similarity:.3f}, weight: {weight:.3f})") + shown_count += 1 + + if len(sorted_mappings) > 2: + print(f" ... and {len(sorted_mappings) - 2} more matches") + + total_hidden = len(special_token_mappings) - shown_count + if total_hidden > 0: + print(f"Total mappings not shown: {total_hidden}") + + added_count = len(transformation_counts) - initial_transformation_count + print(f"Added {added_count} new special token transformation entries") + else: + print("No similar special tokens found") + + print(f"\n=== SUMMARY ===") + print(f"Found {len(multi_token_examples)} teacher tokens that map to multiple student tokens") + # exit() + # Show some examples + if multi_token_examples: + print("\nExamples of multi-token mappings:") + for i, example in enumerate(multi_token_examples[:10]): + print(f" Teacher '{example['teacher_token']}' -> Student {example['student_tokens']} (weights: {example['weights']})") + if len(multi_token_examples) > 10: + print(f" ... and {len(multi_token_examples) - 10} more.") + + if ENABLE_REVERSE_PASS: + print(f"\nReverse pass enabled - added bidirectional mappings") + + print(f"\nTotal transformation entries: {len(transformation_counts)}") + + if ENABLE_EXACT_MATCH: + + print("Checking for exact token matches and setting exact mappings...") + # check exact match between student and teacher tokens and set those as perfect 1-to-1 mappings + # Convert all tokens to strings at once for vectorized comparison + # pdb.set_trace() + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # For tokens that match exactly, we want their mapping to be 1.0 + # and they should not be mapped to any other token. + # First, remove all existing mappings for these student tokens + match_indices_student_set = set(match_indices_student) + keys_to_remove = [] + for key in transformation_counts.keys(): + student_id, teacher_id = key + if student_id in match_indices_student_set: + keys_to_remove.append(key) + for key in keys_to_remove: + del transformation_counts[key] + # Then, set the perfect 1-to-1 mappings for exact matches + for student_id, teacher_id in zip(match_indices_student, match_indices_teacher): + transformation_counts[(student_id, teacher_id)] = 1.0 + + + def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50): + """Debug function to show projection mappings with decoded tokens and weights.""" + print(f"\n--- Debugging projection map {direction} (showing {N} examples) ---") + + # Group transformation_counts by source token (student token) + source_to_targets = defaultdict(list) + for (source_id, target_id), weight in transformation_counts.items(): + source_to_targets[source_id].append((target_id, weight)) + + # Sort by source token ID and take first N + # sorted_sources = sorted(source_to_targets.keys())[:N] + sorted_sources = sorted(source_to_targets.keys())[-N:] + + for source_id in sorted_sources: + # Decode source token + try: + if USE_RAW_TOKENS: + source_token = source_tokenizer.convert_ids_to_tokens([source_id])[0] + else: + source_token = source_tokenizer.decode([source_id]) + source_token = apply_canonicalization_if_enabled(source_token, USE_CANONICALIZATION) + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Sort targets by weight (descending) and build target string + targets_weights = sorted(source_to_targets[source_id], key=lambda x: x[1], reverse=True) + + target_parts = [] + for target_id, weight in targets_weights: + try: + if USE_RAW_TOKENS: + target_token = target_tokenizer.convert_ids_to_tokens([target_id])[0] + else: + target_token = target_tokenizer.decode([target_id]) + target_token = apply_canonicalization_if_enabled(target_token, USE_CANONICALIZATION) + target_token_str = repr(target_token) + except: + target_token_str = f"" + target_parts.append(f"{target_token_str}({weight:.4f})") + + target_string = " ".join(target_parts) + print(f"{source_token_str} -> {target_string}") + + # debug_projection_map(transformation_counts, tokenizer_student, tokenizer_teacher, + # direction="student->teacher", N=1000) + + # Create transformation matrix (student -> teacher projection) + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + teacher_indices = [idx[1] for idx in indices] + student_indices = [idx[0] for idx in indices] + + # Create sparse tensor with student tokens as rows, teacher tokens as columns + # This creates a student -> teacher projection matrix + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values) + + transformation_matrix_sparse = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (tokenizer_student_total_vocab_size, tokenizer_teacher_total_vocab_size), + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.bfloat16 + ) + + # indices, values = torch.topk(transformation_matrix_sparse, k=1000, dim=1) + + print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}") + print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}") + + if 0: + # cant fit to the memory + # Calculate mapping statistics from sparse matrix + print("\nCalculating mapping statistics from projection matrix...") + + # Count non-zero elements in each row (each row = student token) + dense_matrix = transformation_matrix_sparse.to_dense() + non_zero_counts_per_row = (dense_matrix != 0).sum(dim=1) # Count non-zeros per row + + # Create statistics + mapping_stats = defaultdict(int) + for count in non_zero_counts_per_row: + mapping_stats[count.item()] += 1 + + # Print mapping statistics + print("\nMapping statistics (student tokens -> teacher tokens):") + for i in range(1, 5): # 1, 2, 3, 4 teacher tokens + count = mapping_stats.get(i, 0) + print(f"Student tokens mapping to {i} teacher tokens: {count}") + + total_mapped = sum(mapping_stats.values()) + print(f"Total student tokens mapped: {total_mapped}") + + # Convert sparse matrix to same format as minimal_projection_generator.py + os.makedirs(args.output_dir, exist_ok=True) + + # Convert defaultdict to regular dict for saving + transformation_counts_dict = dict(transformation_counts) + + # Show some examples of the projection mappings + debug_projection_map(transformation_counts_dict, tokenizer_student, tokenizer_teacher, + direction="student->teacher", N=1000) + + # exit() + + print(f"\nConverting sparse matrix to top-{TOP_K} dense format...") + + # Convert sparse matrix to dense and get top-k values per row + print("Converting to dense matrix on CPU to avoid memory issues...") + dense_matrix = transformation_matrix_sparse.cpu().to_dense() # Move to CPU to handle memory + print(f"Dense matrix shape: {dense_matrix.shape}") + + # Get top-k values and indices for each row (each source token) + print(f"Extracting top-{TOP_K} values per token...") + + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + dense_matrix = sinkhorn_one_dim(dense_matrix, n_iters=1) + + # Extract top-k on CPU + top_k_likelihoods, top_k_indices = torch.topk(dense_matrix, k=min(TOP_K, dense_matrix.shape[1]), dim=1) + # exit() + # Handle case where vocabulary has fewer tokens than TOP_K + actual_k = top_k_indices.shape[1] + if actual_k < TOP_K: + print(f"Warning: Target vocabulary size ({dense_matrix.shape[1]}) is smaller than TOP_K ({TOP_K}). Using k={actual_k}") + # Pad with -1 indices and 0.0 likelihoods to maintain consistent shape + pad_size = TOP_K - actual_k + top_k_indices = torch.cat([top_k_indices, torch.full((top_k_indices.shape[0], pad_size), -1, dtype=top_k_indices.dtype)], dim=1) + top_k_likelihoods = torch.cat([top_k_likelihoods, torch.zeros((top_k_likelihoods.shape[0], pad_size), dtype=top_k_likelihoods.dtype)], dim=1) + + if 0: + threshold_mask = top_k_likelihoods >= 0.0000000000000000001 + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # Apply SCALE_TRICK: set last column to -4 if enabled + if ENABLE_SCALE_TRICK: + print("ENABLE_SCALE_TRICK is True: Setting last column of likelihoods to -4.0") + top_k_likelihoods[:, -1] = 0.2 + if ENABLE_EXACT_MATCH: + for indices in match_indices_student: + top_k_likelihoods[indices, -1] = 0.0 + print(f"Set last column of likelihoods to 0.0 for {len(match_indices_student)} exact matches as exact match is enabled") + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + top_k_likelihoods = sinkhorn_one_dim(top_k_likelihoods, n_iters=1) + + # pdb.set_trace() + #set indices to -1 where likelihood is 0 + + # Create filename in same format as minimal_projection_generator.py + def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + + output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" + # if USE_RAW_TOKENS: + # output_filename += "_raw_tokens" + if ENABLE_SPECIAL_TOKEN_MAPPING: + output_filename += "_special" + output_filename += ".pt" + # if ENABLE_REVERSE_PASS: + # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + output_path = os.path.join(args.output_dir, output_filename) + + # Save in same format as minimal_projection_generator.py + torch.save({ + "indices": top_k_indices, + "likelihoods": top_k_likelihoods, + "model_A_id": student_model_name, # source model (student) + "model_B_id": teacher_model_name, # target model (teacher) + }, output_path) + + print(f"Saved projection map to: {output_path}") + print(f"Format: indices shape {top_k_indices.shape}, likelihoods shape {top_k_likelihoods.shape}") + print(f"Compatible with minimal_projection_generator.py format") + print(f"Token processing mode: {'Raw tokens (convert_ids_to_tokens)' if USE_RAW_TOKENS else 'Decoded tokens (decode)'}") + if ENABLE_REVERSE_PASS: + print("File includes bidirectional mappings (teacher->student and student->teacher)") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"File includes special token mappings (similarity_threshold={SPECIAL_TOKEN_SIMILARITY_THRESHOLD}, top_k={SPECIAL_TOKEN_TOP_K})") + # exit() + + + + + # Test projection function compatibility (same as minimal_projection_generator.py) + print("\n--- Testing projection function compatibility ---") + + def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + + # Create a dummy likelihood tensor: [BATCH, SEQ, source_vocab_size] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_tensor = torch.randn(1, 4096, tokenizer_student_total_vocab_size, device=device, dtype=torch.bfloat16) + + # Transform this tensor using the projection map + projected_tensor = project_token_likelihoods( + dummy_tensor, + top_k_indices.to(device), + top_k_likelihoods.to(device), + tokenizer_teacher_total_vocab_size, + device + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful - format is fully compatible!") + + # pdb.set_trace() + \ No newline at end of file diff --git a/nemo_rl/utils/x_token/reapply_exact_map.py b/nemo_rl/utils/x_token/reapply_exact_map.py new file mode 100644 index 0000000000..b37f994466 --- /dev/null +++ b/nemo_rl/utils/x_token/reapply_exact_map.py @@ -0,0 +1,246 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + USE_CANONICALIZATION = args.use_canonicalization + + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # load intial projection map + initial_projection_path = args.initial_projection_path + if initial_projection_path is not None: + initial_projection_map = torch.load(initial_projection_path) + else: + initial_projection_map = None + + # go through token in projection map. For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token + non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) + all_student_token_ids = list(range(len(initial_projection_map["likelihoods"]))) + + show_remapping = 5 + if show_remapping > 0: + print(f"Showing remapping for the last {show_remapping} exact matches.") + else: + print(f"Not showing remapping.") + + for i, exact_token_student in enumerate(match_indices_student): + exact_token_teacher = match_indices_teacher[i] + + index_ = all_student_token_ids.index(exact_token_student) + likelihoods = initial_projection_map["likelihoods"][index_] + indices = initial_projection_map["indices"][index_] + + if len(match_indices_student) - i <= show_remapping: + print(f"prior to remapping: likelihoods {likelihoods} indices {indices}") + + topk = indices.shape[0] + + remapped_indices = torch.ones_like(indices) * -1 + remapped_likelihoods = torch.zeros_like(likelihoods) + + remapped_likelihoods[0] = 1.0 + remapped_indices[0] = exact_token_teacher + + # if exact_token_student == 5159: + # import pdb + # pdb.set_trace() + + initial_projection_map["likelihoods"][index_] = remapped_likelihoods + initial_projection_map["indices"][index_] = remapped_indices + + + if len(match_indices_student) - i <= show_remapping: + print(f'after remapping {tokens_student[exact_token_student]}:{exact_token_student} -> {tokens_teacher[exact_token_teacher]}:{exact_token_teacher}: likelihoods {initial_projection_map["likelihoods"][index_]} indices {initial_projection_map["indices"][index_]}') + non_exact_map_tokens.remove(index_) + + + # import pdb + # pdb.set_trace() + # print(f"non exact map tokens: {non_exact_map_tokens}") + # pdb.set_trace() + save_path = args.initial_projection_path.split(".")[0] + "_exact_map_remapped.pt" + torch.save(initial_projection_map, save_path) + print(f"Saved remapped projection map to: {save_path}") + print(f"remapped {len(match_indices_student)} tokens. Retained remaining {len(non_exact_map_tokens)} tokens as is.") \ No newline at end of file diff --git a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py new file mode 100644 index 0000000000..caba5b2738 --- /dev/null +++ b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py @@ -0,0 +1,451 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +import tqdm +from transformers import AutoTokenizer, AutoConfig +import pdb; + +def sinkhorn_one_dim(A, n_iters=1): + """Apply Sinkhorn normalization to make each row sum to 1.""" + for _ in range(n_iters): + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + return A + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + +def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): + """ + Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. + + Args: + input_path: Path to input projection matrix file + output_path: Path to save the new projection matrix + new_top_k: New top_k value for cutoff + preserve_last: If True, always preserve the last column as the final element + verbose: Whether to print progress information + """ + if verbose: + print(f"Loading projection matrix from: {input_path}") + + # Load the projection matrix + projection_data = torch.load(input_path, map_location='cpu', weights_only=False) + + if not isinstance(projection_data, dict) or 'indices' not in projection_data or 'likelihoods' not in projection_data: + raise ValueError("Input file must contain a dictionary with 'indices' and 'likelihoods' keys") + + original_indices = projection_data['indices'] # Shape: [vocab_size, original_top_k] + original_likelihoods = projection_data['likelihoods'] # Shape: [vocab_size, original_top_k] + + vocab_size, original_top_k = original_indices.shape + + if verbose: + print(f"Original matrix shape: {original_indices.shape}") + print(f"Original top_k: {original_top_k}") + print(f"New top_k: {new_top_k}") + print(f"Preserve last column: {preserve_last}") + # pdb.set_trace() + + if new_top_k > original_top_k: + print(f"Warning: New top_k ({new_top_k}) is larger than original top_k ({original_top_k})") + print(f"Will pad with -1 indices and 0.0 likelihoods") + effective_top_k = original_top_k + else: + effective_top_k = new_top_k + + # Initialize new tensors + new_indices = torch.full((vocab_size, new_top_k), -1, dtype=original_indices.dtype) + new_likelihoods = torch.zeros((vocab_size, new_top_k), dtype=original_likelihoods.dtype) + + # Statistics tracking + rows_with_order_change = 0 + significant_components_count = [0] * min(new_top_k, 10) # Track up to 10 components + threshold_for_significance = 0.2 # Threshold for considering a component "significant" + # Track position of maximum element in original ordering + max_element_positions = {} # position -> count + # Track preserve_last statistics + rows_with_preserved_last = 0 + # Track specifically when max element is in the last column + rows_with_max_in_last_column = 0 + # Track position of maximum element in final sorted and trimmed matrix + final_max_element_positions = {} # position -> count + + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + + if verbose: + print("Sorting and cutting each row...") + + # Process each row (each source token) + last_element_trick_count = 0 + for row_idx in tqdm.tqdm(range(vocab_size), desc="Processing rows", disable=not verbose): + row_indices = original_indices[row_idx] # [original_top_k] + row_likelihoods = original_likelihoods[row_idx] # [original_top_k] + + # Filter out invalid indices (-1) and zero likelihoods + valid_mask = (row_indices != -1) & (row_likelihoods > 0) + + if valid_mask.any(): + valid_indices = row_indices[valid_mask] + valid_likelihoods = row_likelihoods[valid_mask] + + # Track position of maximum element in original ordering + max_pos = torch.argmax(valid_likelihoods).item() + if max_pos not in max_element_positions: + max_element_positions[max_pos] = 0 + max_element_positions[max_pos] += 1 + + # Check if max element is specifically in the last column + # Find the actual maximum value in the original row (including invalid entries) + original_max_pos = torch.argmax(row_likelihoods).item() + if original_max_pos == original_top_k - 1: + # Only count if the last position actually has valid data + last_index = row_indices[original_top_k - 1] + last_likelihood = row_likelihoods[original_top_k - 1] + if last_index != -1 and last_likelihood > 0: + rows_with_max_in_last_column += 1 + + if preserve_last and new_top_k >= 1: + # Handle preserve_last case + last_index = original_indices[row_idx, original_top_k - 1] + last_likelihood = original_likelihoods[row_idx, original_top_k - 1] + + if new_top_k == 1: + # Special case: only keep the last element + if last_index != -1 and last_likelihood > 0: + new_indices[row_idx, 0] = last_index + new_likelihoods[row_idx, 0] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant components + if last_likelihood >= threshold_for_significance: + significant_components_count[0] += 1 + else: + # General case: sort first (original_top_k-1) elements, then add last element + elements_to_sort = min(len(valid_likelihoods), original_top_k - 1) + if elements_to_sort > 0: + # Get elements excluding the last position in original matrix + sort_mask = torch.arange(len(valid_likelihoods)) < elements_to_sort + if sort_mask.any(): + sortable_indices = valid_indices[sort_mask] + sortable_likelihoods = valid_likelihoods[sort_mask] + + # Sort the non-last elements + sorted_likelihoods, sort_order = torch.sort(sortable_likelihoods, descending=True) + sorted_indices = sortable_indices[sort_order] + + # Check if order changed in the sortable portion + original_order = torch.arange(len(sortable_likelihoods)) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + # Take top (new_top_k - 1) elements from sorted portion + num_from_sorted = min(len(sorted_indices), new_top_k - 1) + + new_indices[row_idx, :num_from_sorted] = sorted_indices[:num_from_sorted] + new_likelihoods[row_idx, :num_from_sorted] = sorted_likelihoods[:num_from_sorted] + + # Count significant components from sorted portion + for comp_idx in range(min(num_from_sorted, len(significant_components_count) - 1)): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + + # Always put the last element at the end (if valid) + + if last_index != -1 and last_likelihood > 0: + last_element_trick_count += 1 + new_indices[row_idx, new_top_k - 1] = last_index + new_likelihoods[row_idx, new_top_k - 1] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant component for the preserved last element + if new_top_k - 1 < len(significant_components_count): + if last_likelihood >= threshold_for_significance: + significant_components_count[new_top_k - 1] += 1 + + else: + # Original logic: sort all elements normally + # Check if order changed by comparing original vs sorted order + original_order = torch.arange(len(valid_likelihoods)) + sorted_likelihoods, sort_order = torch.sort(valid_likelihoods, descending=True) + + # Check if the order changed (not just sorted, but actually different) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + sorted_indices = valid_indices[sort_order] + + # Take top effective_top_k elements + num_to_take = min(len(sorted_indices), effective_top_k) + + new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take] + new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take] + # pdb.set_trace() + # Count significant components (components above threshold) + for comp_idx in range(min(num_to_take, len(significant_components_count))): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + # if significant_components_count[1] > 0.0: + # pdb.set_trace() + + # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 + + # Apply Sinkhorn normalization to the final matrix + print(f"last element trick count: {last_element_trick_count}") + if verbose: + print("Applying Sinkhorn normalization...") + + # Apply normalization only to non-zero values to preserve sparsity structure + normalized_likelihoods = sinkhorn_one_dim(new_likelihoods.clone(), n_iters=1) + + # Calculate final maximum element position statistics after sorting and normalization + if verbose: + print("Calculating final maximum element position statistics...") + + for row_idx in range(vocab_size): + row_likelihoods = normalized_likelihoods[row_idx] + # Filter out zero likelihoods + valid_mask = row_likelihoods > 0 + if valid_mask.any(): + valid_likelihoods = row_likelihoods[valid_mask] + # Find position of maximum element in the final matrix + max_pos_in_valid = torch.argmax(valid_likelihoods).item() + # Convert back to original position in the row + valid_positions = torch.nonzero(valid_mask).squeeze(-1) + actual_max_pos = valid_positions[max_pos_in_valid].item() + + if actual_max_pos not in final_max_element_positions: + final_max_element_positions[actual_max_pos] = 0 + final_max_element_positions[actual_max_pos] += 1 + + # Create output dictionary with same format as input + output_data = { + 'indices': new_indices, + 'likelihoods': normalized_likelihoods, + } + + # Copy over any additional metadata + for key in projection_data: + if key not in ['indices', 'likelihoods']: + output_data[key] = projection_data[key] + + # Save the new projection matrix + torch.save(output_data, output_path) + + if verbose: + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() + for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + +def main(): + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument("input_path", help="Path to input projection matrix file") + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") + parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + # python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last + #s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices + args = parser.parse_args() + + # Auto-generate output path if not specified + if args.output_path is None: + input_dir = os.path.dirname(args.input_path) + input_filename = os.path.basename(args.input_path) + + # Extract base name and extension + base_name, ext = os.path.splitext(input_filename) + + # Remove old top_k info if present + import re + base_name = re.sub(r'_top_\d+', '', base_name) + + # Add new top_k info and preserve_last flag + suffix = "_sorted" + if args.preserve_last: + suffix += "_preservelast" + output_filename = f"{base_name}_top_{args.top_k}{suffix}{ext}" + args.output_path = os.path.join(input_dir, output_filename) + + # Ensure output directory exists + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Process the matrix + sort_and_cut_projection_matrix( + args.input_path, + args.output_path, + args.top_k, + preserve_last=args.preserve_last, + verbose=not args.quiet + ) + +if __name__ == "__main__": + main() \ No newline at end of file From 59346997c46b9306bc63490b95ce635e4ea174aa Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Sun, 26 Apr 2026 19:56:27 -0700 Subject: [PATCH 2/4] feat: cross-tokenizer collator, Arrow dataset, and eval datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Data-layer plumbing for cross-tokenizer off-policy distillation, plus in-training eval datasets. Builds on the TokenAligner package from the prior PR. - nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and TeacherCTSpec. Runs in StatefulDataLoader worker processes — does per-teacher tokenize + DP alignment up front so the train loop only consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so workers that don't need cross-tokenizer never touch x_token. - nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig. - nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py: ArrowTextDataset with lazy packing, registered as "arrow_text" in DATASET_REGISTRY. - nemo_rl/data/datasets/eval_datasets/{humaneval_plus,mbpp_plus,mmlu}.py and registry entries: in-training eval datasets. mmlu.py adds an optional num_few_shot argument with a static _build_few_shot_prefixes helper; default of 0 preserves existing behavior. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/data/__init__.py | 4 + nemo_rl/data/cross_tokenizer_collate.py | 352 ++++++++++++++++++ .../data/datasets/eval_datasets/__init__.py | 22 ++ .../datasets/eval_datasets/humaneval_plus.py | 74 ++++ .../data/datasets/eval_datasets/mbpp_plus.py | 96 +++++ nemo_rl/data/datasets/eval_datasets/mmlu.py | 39 ++ .../datasets/response_datasets/__init__.py | 3 + .../response_datasets/arrow_text_dataset.py | 308 +++++++++++++++ nemo_rl/data/processors.py | 67 ++++ nemo_rl/data/utils.py | 10 +- tests/unit/data/test_data_processor.py | 107 +++++- 11 files changed, 1079 insertions(+), 3 deletions(-) create mode 100644 nemo_rl/data/cross_tokenizer_collate.py create mode 100644 nemo_rl/data/datasets/eval_datasets/humaneval_plus.py create mode 100644 nemo_rl/data/datasets/eval_datasets/mbpp_plus.py create mode 100644 nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index e77e97a2b0..fbced245f7 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -57,6 +57,10 @@ class DataConfig(TypedDict): # This saturates CPU threads without consuming too much memory # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] + # PyTorch DataLoader prefetch_factor: number of batches each worker pre-fetches. + # Combined with num_workers, controls how much cross-tokenizer alignment work + # is run ahead of the training step in dataloader processes. + prefetch_factor: NotRequired[int] # multiple dataloader configs # currently only supported for GRPO use_multiple_dataloader: NotRequired[bool] diff --git a/nemo_rl/data/cross_tokenizer_collate.py b/nemo_rl/data/cross_tokenizer_collate.py new file mode 100644 index 0000000000..a658712b5e --- /dev/null +++ b/nemo_rl/data/cross_tokenizer_collate.py @@ -0,0 +1,352 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cross-tokenizer collate function for off-policy distillation. + +Moves teacher tokenize + DP alignment off the training critical path and into +``StatefulDataLoader`` worker processes. With ``num_workers=N, prefetch_factor=P`` +there are up to ``N*P`` batches of CT work in flight, so the consumer pulls +already-processed batches and CT is hidden behind teacher forward. + +Mirrors the train_distillation_ddp / TokenizeAndAlignCollator shape in +tokenalign/src/pytorch_data_loader.py. +""" + +from __future__ import annotations + +import sys +from typing import Any, Optional + +if sys.version_info >= (3, 11): + from typing import NotRequired, TypedDict +else: + from typing import TypedDict + from typing_extensions import NotRequired + +import torch +from transformers import AutoTokenizer + +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + +class TeacherCTSpec(TypedDict): + """Per-teacher spec passed to CrossTokenizerCollator. + + All fields are pickle-cheap primitives so the collator itself ships + cheaply to DataLoader workers. Tokenizers and aligners are built lazily + in each worker. + """ + + teacher_tokenizer_name: str + student_tokenizer_name: str + projection_matrix_path: str + use_sparse_format: bool + learnable: bool + max_comb_len: int + projection_matrix_multiplier: float + project_teacher_to_student: bool + max_teacher_len: int + dp_chunk_size: int + use_align_fast: bool + exact_token_match_only: bool + + +def _build_chunk_coo( + aligned_pairs: list, + student_seq_len: int, + teacher_seq_len: int, + exact_match_only: bool, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Pre-filter alignment pairs and emit per-sample chunk mask indices. + + Applies the exact_match_only, ``-1`` sentinel, and padded-length bounds + filters that the loss fn used to run per-microbatch, then flattens each + surviving chunk's student/teacher spans into COO rows ``[pos, chunk_id]`` + that the loss fn can ``index_put_`` into the dense ``proj_mask``/``tgt_mask``. + + Returns ``(student_chunk_coo, teacher_chunk_coo, num_chunks)``. COO tensors + are empty ``(0, 2)`` when a sample has no surviving chunks. + """ + student_rows: list[tuple[int, int]] = [] + teacher_rows: list[tuple[int, int]] = [] + chunk_id = 0 + for pair in aligned_pairs: + s1_start, s1_end, s2_start, s2_end = pair[2], pair[3], pair[4], pair[5] + if exact_match_only and ( + s1_end - s1_start != 1 or s2_end - s2_start != 1 + ): + continue + if s1_start == -1 or s2_start == -1: + continue + if s1_end > student_seq_len or s2_end > teacher_seq_len: + continue + for pos in range(s1_start, s1_end): + student_rows.append((pos, chunk_id)) + for pos in range(s2_start, s2_end): + teacher_rows.append((pos, chunk_id)) + chunk_id += 1 + + student_coo = ( + torch.tensor(student_rows, dtype=torch.int64) + if student_rows + else torch.empty((0, 2), dtype=torch.int64) + ) + teacher_coo = ( + torch.tensor(teacher_rows, dtype=torch.int64) + if teacher_rows + else torch.empty((0, 2), dtype=torch.int64) + ) + return student_coo, teacher_coo, chunk_id + + +class CrossTokenizerCollator: + """Collator that does base collate + message flatten + per-teacher CT. + + Designed to be pickled into ``torch.utils.data.DataLoader`` worker + processes. On first call inside a worker, it constructs its own + ``TokenAligner`` instances and teacher tokenizers from the specs, then + reuses them for every subsequent batch in that worker. + + DP-only: does not run the char-offset alignment fast path (consumers + that want char-offset should restore it here). + """ + + def __init__( + self, + pad_token_id: int, + make_sequence_length_divisible_by: int, + teacher_ct_specs: list[Optional[TeacherCTSpec]], + fallback_student_tokenizer_name: Optional[str] = None, + ) -> None: + self.pad_token_id = pad_token_id + self.make_seq_div_by = int(make_sequence_length_divisible_by) + self.teacher_ct_specs = list(teacher_ct_specs) + self.fallback_student_tokenizer_name = fallback_student_tokenizer_name + + # Lazy per-worker state — excluded from __getstate__ below. + self._initialized: bool = False + self._aligners: list[Optional[Any]] = [] + self._teacher_tokenizers: list[Optional[Any]] = [] + self._student_tokenizer: Optional[Any] = None + + def __getstate__(self) -> dict[str, Any]: + # Only ship pickle-cheap primitives across the fork/spawn boundary. + return { + "pad_token_id": self.pad_token_id, + "make_seq_div_by": self.make_seq_div_by, + "teacher_ct_specs": self.teacher_ct_specs, + "fallback_student_tokenizer_name": self.fallback_student_tokenizer_name, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.pad_token_id = state["pad_token_id"] + self.make_seq_div_by = state["make_seq_div_by"] + self.teacher_ct_specs = state["teacher_ct_specs"] + self.fallback_student_tokenizer_name = state["fallback_student_tokenizer_name"] + self._initialized = False + self._aligners = [] + self._teacher_tokenizers = [] + self._student_tokenizer = None + + def _lazy_init(self) -> None: + if self._initialized: + return + + # Import TokenAligner lazily so module import stays cheap and so + # workers that don't need CT never touch x_token. + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + for spec in self.teacher_ct_specs: + if spec is None: + self._aligners.append(None) + self._teacher_tokenizers.append(None) + continue + + teacher_tokenizer = AutoTokenizer.from_pretrained( + spec["teacher_tokenizer_name"] + ) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + aligner = TokenAligner( + teacher_tokenizer_name=spec["teacher_tokenizer_name"], + student_tokenizer_name=spec["student_tokenizer_name"], + max_comb_len=int(spec["max_comb_len"]), + projection_matrix_multiplier=float( + spec["projection_matrix_multiplier"] + ), + ) + aligner._load_logits_projection_map( + file_path=spec["projection_matrix_path"], + use_sparse_format=bool(spec["use_sparse_format"]), + learnable=bool(spec["learnable"]), + device="cpu", + ) + if bool(spec["project_teacher_to_student"]): + aligner.create_reverse_projection_matrix(device="cpu") + aligner.precompute_canonical_maps() + + self._teacher_tokenizers.append(teacher_tokenizer) + self._aligners.append(aligner) + + self._initialized = True + + def _get_student_tokenizer(self) -> Any: + if self._student_tokenizer is not None: + return self._student_tokenizer + name = self.fallback_student_tokenizer_name + if name is None: + # Best-effort: reuse any CT spec's student name. + for spec in self.teacher_ct_specs: + if spec is not None: + name = spec["student_tokenizer_name"] + break + if name is None: + raise RuntimeError( + "CrossTokenizerCollator needs a student tokenizer for the decode " + "fallback, but no name was provided and no CT spec supplied one." + ) + self._student_tokenizer = AutoTokenizer.from_pretrained(name) + if self._student_tokenizer.pad_token is None: + self._student_tokenizer.pad_token = self._student_tokenizer.eos_token + return self._student_tokenizer + + def __call__(self, data_batch: list[DatumSpec]) -> BatchedDataDict[Any]: + self._lazy_init() + + base = rl_collate_fn(data_batch) + + # --- Message-flatten (ported from _prepare_train_batch_data) --- + for message_log in base["message_log"]: + for m in message_log: + if "token_loss_mask" not in m: + m["token_loss_mask"] = ( + torch.ones_like(m["token_ids"]) + if m["role"] == "assistant" + else torch.zeros_like(m["token_ids"]) + ) + flat_messages, input_lengths = batched_message_log_to_flat_message( + base["message_log"], + pad_value_dict={"token_ids": self.pad_token_id}, + make_sequence_length_divisible_by=self.make_seq_div_by, + ) + base["input_ids"] = flat_messages["token_ids"] + base["input_lengths"] = input_lengths + base["token_mask"] = flat_messages["token_loss_mask"] + base["sample_mask"] = base["loss_multiplier"] + base["flat_messages"] = flat_messages + mm_dict = flat_messages.get_multimodal_dict(as_tensors=False) + if mm_dict: + for k, v in mm_dict.items(): + base[k] = v + + # --- Per-teacher CT (DP-only) --- + student_ids = base["input_ids"] + extra_env = base.get("extra_env_info") + batch_size = student_ids.shape[0] + + has_raw_text = ( + extra_env is not None + and len(extra_env) == batch_size + and all( + isinstance(e, dict) and "raw_text" in e for e in extra_env + ) + ) + texts_cache: Optional[list[str]] = None + + per_teacher_ct_data: list[Optional[dict[str, Any]]] = [] + any_ct = any(spec is not None for spec in self.teacher_ct_specs) + if any_ct: + if has_raw_text: + texts_cache = [e["raw_text"] for e in extra_env] + else: + texts_cache = self._get_student_tokenizer().batch_decode( + student_ids.tolist(), skip_special_tokens=True + ) + + for t_idx, spec in enumerate(self.teacher_ct_specs): + if spec is None: + per_teacher_ct_data.append(None) + continue + + aligner = self._aligners[t_idx] + teacher_tokenizer = self._teacher_tokenizers[t_idx] + + enc = teacher_tokenizer( + texts_cache, + max_length=int(spec["max_teacher_len"]), + padding="max_length", + truncation=True, + return_tensors="pt", + ) + teacher_input_ids = enc["input_ids"] + teacher_attention_mask = enc["attention_mask"] + teacher_input_lengths = teacher_attention_mask.sum(dim=1) + teacher_token_mask = (teacher_attention_mask > 0).to(torch.float32) + + dp_chunk_size = int(spec["dp_chunk_size"]) + use_align_fast = bool(spec["use_align_fast"]) + exact_match_only = bool(spec.get("exact_token_match_only", False)) + student_seq_len = int(student_ids.shape[1]) + teacher_seq_len = int(teacher_input_ids.shape[1]) + + aligned_pairs: list[Any] = [] + student_chunk_coo: list[torch.Tensor] = [] + teacher_chunk_coo: list[torch.Tensor] = [] + num_chunks_per_sample: list[int] = [] + for b in range(batch_size): + s_t = student_ids[b : b + 1] + t_t = teacher_input_ids[b : b + 1] + if use_align_fast and aligner._student_canon_map is not None: + result = aligner.align_fast( + s_t, t_t, chunk_size=dp_chunk_size + ) + else: + result = aligner.align(s_t, t_t, chunk_size=dp_chunk_size) + pairs = result[0] + aligned_pairs.append(pairs) + + s_coo, t_coo, n_chunks = _build_chunk_coo( + pairs, student_seq_len, teacher_seq_len, exact_match_only, + ) + student_chunk_coo.append(s_coo) + teacher_chunk_coo.append(t_coo) + num_chunks_per_sample.append(n_chunks) + + teacher_data: BatchedDataDict[Any] = BatchedDataDict( + { + "input_ids": teacher_input_ids, + "input_lengths": teacher_input_lengths, + "token_mask": teacher_token_mask, + "sample_mask": base["loss_multiplier"], + } + ) + teacher_data.to("cpu") + + per_teacher_ct_data.append( + { + "teacher_input_ids": teacher_input_ids, + "aligned_pairs": aligned_pairs, + "teacher_data": teacher_data, + "student_chunk_coo": student_chunk_coo, + "teacher_chunk_coo": teacher_chunk_coo, + "num_chunks": num_chunks_per_sample, + } + ) + + base["per_teacher_ct_data"] = per_teacher_ct_data + return base diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py index d813ed040c..f0b5d15a81 100644 --- a/nemo_rl/data/datasets/eval_datasets/__init__.py +++ b/nemo_rl/data/datasets/eval_datasets/__init__.py @@ -14,8 +14,10 @@ from nemo_rl.data.datasets.eval_datasets.aime import AIMEDataset from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset +from nemo_rl.data.datasets.eval_datasets.humaneval_plus import HumanEvalPlusDataset from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset from nemo_rl.data.datasets.eval_datasets.math import MathDataset +from nemo_rl.data.datasets.eval_datasets.mbpp_plus import MBPPPlusDataset from nemo_rl.data.datasets.eval_datasets.mmau import MMAUDataset from nemo_rl.data.datasets.eval_datasets.mmlu import MMLUDataset from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset @@ -35,10 +37,12 @@ def load_eval_dataset(data_config): # mmlu if dataset_name.startswith("mmlu") and dataset_name != "mmlu_pro": + num_few_shot = data_config.get("num_few_shot", 0) if dataset_name == "mmlu": base_dataset = MMLUDataset( prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) else: language = dataset_name.split("_")[1] @@ -46,6 +50,7 @@ def load_eval_dataset(data_config): language=language, prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) elif dataset_name == "mmlu_pro": base_dataset = MMLUProDataset( @@ -98,6 +103,21 @@ def load_eval_dataset(data_config): dataset_name="TwinkStart/MMAU", split=split, ) + # code-execution benchmarks + elif dataset_name == "mbpp_plus": + base_dataset = MBPPPlusDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + dataset_path=data_config.get("dataset_path") or "evalplus/mbppplus", + split=data_config.get("split") or "test", + ) + elif dataset_name in ("humaneval_plus", "human_eval_plus"): + base_dataset = HumanEvalPlusDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + dataset_path=data_config.get("dataset_path") or "evalplus/humanevalplus", + split=data_config.get("split") or "test", + ) # fall back to local dataset else: print(f"Loading dataset from {dataset_name}...") @@ -117,8 +137,10 @@ def load_eval_dataset(data_config): __all__ = [ "AIMEDataset", "GPQADataset", + "HumanEvalPlusDataset", "LocalMathDataset", "MathDataset", + "MBPPPlusDataset", "MMAUDataset", "MMLUDataset", "MMLUProDataset", diff --git a/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py b/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py new file mode 100644 index 0000000000..8efa4f968b --- /dev/null +++ b/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HumanEval+ dataset for code-execution eval. + +Loads the upstream `evalplus/humanevalplus` HuggingFace dataset and shapes each +example into the columns expected by +:func:`nemo_rl.data.processors.code_data_processor` and graded by +:class:`nemo_rl.environments.code_environment.CodeUnitTestEnvironment`. +""" + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class HumanEvalPlusDataset: + """Code-execution eval over the HumanEval+ test split. + + The ``env_name`` attribute lets the runner pick the right environment + (``code_unit_test`` instead of ``math``/``multichoice``). + """ + + env_name: str = "code_unit_test" + + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + dataset_path: str = "evalplus/humanevalplus", + split: str = "test", + ): + ds = load_dataset(dataset_path, split=split) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name="humaneval_plus", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.code_data_processor + + @staticmethod + def _rekey(data: dict[str, Any]) -> dict[str, Any]: + prompt = data.get("prompt") or "" + test = data.get("test") or "" + entry_point = data.get("entry_point") or "" + return { + "problem": str(prompt), + # HumanEval+ ships the full check() harness in `test`; the grader + # will exec it and call check(). + "test_code": str(test), + "test_list": [], + "test_imports": [], + "entry_point": str(entry_point), + # HumanEval+ uses a continuation-style prompt: the model is + # expected to emit only the *body* of `entry_point`. The grader + # must prepend this stub before exec'ing so that `entry_point` + # actually gets defined in the namespace. + "code_prefix": str(prompt), + } diff --git a/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py b/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py new file mode 100644 index 0000000000..2d0492e4aa --- /dev/null +++ b/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MBPP+ (mostly-basic Python problems, plus) dataset for code-execution eval. + +Loads the upstream `evalplus/mbppplus` HuggingFace dataset and shapes each +example into the columns expected by +:func:`nemo_rl.data.processors.code_data_processor` and graded by +:class:`nemo_rl.environments.code_environment.CodeUnitTestEnvironment`. +""" + +import re +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +_FUNCTION_DEF_RE = re.compile(r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(") + + +def _infer_entry_point(reference_code: str, test_list: list[str]) -> str: + """Best-effort inference of the candidate function name.""" + if reference_code: + match = _FUNCTION_DEF_RE.search(reference_code) + if match: + return match.group(1) + for assertion in test_list: + match = re.search(r"assert\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", str(assertion)) + if match: + return match.group(1) + return "" + + +class MBPPPlusDataset: + """Code-execution eval over the MBPP+ test split. + + The ``env_name`` attribute lets the runner pick the right environment + (``code_unit_test`` instead of ``math``/``multichoice``). + """ + + env_name: str = "code_unit_test" + + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + dataset_path: str = "evalplus/mbppplus", + split: str = "test", + ): + ds = load_dataset(dataset_path, split=split) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name="mbpp_plus", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.code_data_processor + + @staticmethod + def _rekey(data: dict[str, Any]) -> dict[str, Any]: + prompt = data.get("prompt") or data.get("text") or "" + test_list = data.get("test_list") or [] + if not isinstance(test_list, list): + test_list = [test_list] + test_imports = data.get("test_imports") or [] + if not isinstance(test_imports, list): + test_imports = [test_imports] + # MBPP+ does not ship an explicit entry_point; infer it so that + # HumanEval-style ``check(candidate)`` harnesses keep working when + # the same processor is reused. + entry_point = _infer_entry_point( + str(data.get("code", "")), [str(x) for x in test_list] + ) + return { + "problem": str(prompt), + # MBPP+'s ``test`` is the canonical-solution test script; the + # grader uses ``test_list`` (assertions) instead. + "test_code": "", + "test_list": [str(x) for x in test_list], + "test_imports": [str(x) for x in test_imports], + "entry_point": entry_point, + } diff --git a/nemo_rl/data/datasets/eval_datasets/mmlu.py b/nemo_rl/data/datasets/eval_datasets/mmlu.py index c9a373fc10..12953166ec 100644 --- a/nemo_rl/data/datasets/eval_datasets/mmlu.py +++ b/nemo_rl/data/datasets/eval_datasets/mmlu.py @@ -14,6 +14,7 @@ """MMLU dataset and its variants.""" +from collections import defaultdict from typing import Any, Literal, Optional from datasets import load_dataset @@ -21,6 +22,8 @@ from nemo_rl.data import processors from nemo_rl.data.interfaces import TaskDataSpec +ANSWER_INDEX_TO_LETTER = {0: "A", 1: "B", 2: "C", 3: "D"} + class MMLUDataset: def __init__( @@ -44,6 +47,7 @@ def __init__( ] = "EN-US", prompt_file: Optional[str] = None, system_prompt_file: Optional[str] = None, + num_few_shot: int = 0, ): if language != "EN-US": data_files = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" @@ -58,6 +62,14 @@ def __init__( ) self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + if num_few_shot > 0: + few_shot_prefixes = self._build_few_shot_prefixes(num_few_shot) + self.rekeyed_ds = self.rekeyed_ds.map( + lambda ex: { + "few_shot_prefix": few_shot_prefixes.get(ex["subject"], "") + } + ) + self.task_spec = TaskDataSpec( task_name=f"MMLU_{language}", prompt_file=prompt_file, @@ -65,6 +77,33 @@ def __init__( ) self.processor = processors.multichoice_qa_processor + @staticmethod + def _build_few_shot_prefixes(num_few_shot: int) -> dict[str, str]: + """Build per-subject few-shot prefixes from MMLU's dev (validation) split.""" + dev_ds = load_dataset("cais/mmlu", "all", split="validation") + + dev_by_subject: dict[str, list[dict[str, Any]]] = defaultdict(list) + for ex in dev_ds: + dev_by_subject[ex["subject"]].append(ex) + + prefixes: dict[str, str] = {} + for subject, examples in dev_by_subject.items(): + parts = [] + for fs_ex in examples[:num_few_shot]: + choices = fs_ex["choices"] + options_str = "\n".join( + f"{letter}) {choices[i]}" + for i, letter in enumerate(["A", "B", "C", "D"]) + ) + answer_letter = ANSWER_INDEX_TO_LETTER[fs_ex["answer"]] + parts.append( + f"Question: {fs_ex['question']}\nOptions:\n{options_str}\n" + f"Answer: {answer_letter}" + ) + prefixes[subject] = "\n\n".join(parts) + + return prefixes + def _rekey(self, data: dict[str, Any]): return { "question": data["Question"], diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index a8928a16a9..bfb98e350c 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -14,6 +14,7 @@ from nemo_rl.data import ResponseDatasetConfig from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset +from nemo_rl.data.datasets.response_datasets.arrow_text_dataset import ArrowTextDataset from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset @@ -48,6 +49,7 @@ # built-in datasets "avqa": AVQADataset, "AIME2024": AIME2024Dataset, + "arrow_text": ArrowTextDataset, "clevr-cogent": CLEVRCoGenTDataset, "daily-omni": DailyOmniDataset, "general-conversation-jsonl": GeneralConversationsJsonlDataset, @@ -98,6 +100,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): __all__ = [ "AVQADataset", "AIME2024Dataset", + "ArrowTextDataset", "CLEVRCoGenTDataset", "DailyOmniDataset", "GeneralConversationsJsonlDataset", diff --git a/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py new file mode 100644 index 0000000000..e7f7d89145 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py @@ -0,0 +1,308 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Arrow Text Dataset for loading arrow files with 'text' column.""" + +import glob +import hashlib +import json +import os +import time +from typing import Any, Optional + +import numpy as np +from datasets import Dataset, load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +class _LazyPackedDataset: + """Map-style dataset that packs text on demand via precomputed boundaries. + + Each ``__getitem__`` call loads only the raw rows needed for one packed + sample, joins them with newlines, and returns the messages dict. The + underlying arrow dataset stays memory-mapped — no bulk text loading. + """ + + def __init__( + self, + arrow_dataset: Dataset, + pack_ranges: list[tuple[int, int]], + text_key: str = "text", + ): + self.arrow_dataset = arrow_dataset + self.pack_ranges = pack_ranges + self.text_key = text_key + + def __len__(self) -> int: + return len(self.pack_ranges) + + def __getitem__(self, idx: int) -> dict[str, Any]: + start, end = self.pack_ranges[idx] + texts = self.arrow_dataset[start:end][self.text_key] + packed = "\n".join(t for t in texts if isinstance(t, str) and t) + return { + "messages": [{"role": "assistant", "content": packed}], + "task_name": "arrow_text_dataset", + } + + +class ArrowTextDataset(RawDataset): + """Dataset class for loading arrow files containing raw text. + + This class loads arrow files with a 'text' column and converts them to + the messages format expected by SFT training. + + The text is wrapped as an assistant message: + {"messages": [{"role": "assistant", "content": }]} + + This format allows training on all tokens (language modeling style). + + Optionally, multiple short samples can be concatenated (packed) into a + single sample up to ``characters_per_sample`` characters, separated by + newlines. This avoids padding waste when individual samples are short + relative to the model's context window. Works correctly with + cross-tokenizer distillation because packing happens at the raw-text + level before any tokenizer sees the data. + + Packing uses **lazy loading**: only character lengths are scanned at + init time (fast) to build pack boundaries; actual text is loaded + on-demand in ``__getitem__``. + + Args: + arrow_files: Path pattern (glob) or list of arrow file paths + val_split: Fraction of data to use for validation (default: 0.05) + seed: Random seed for train/val split + text_key: Key for text column in arrow files (default: "text") + characters_per_sample: If set, concatenate multiple texts (separated + by "\\n") until the accumulated character count reaches this + threshold before yielding one packed sample. To guarantee that + every packed sample tokenizes to at least ``max_input_seq_length`` + tokens (so truncation always fires and every training sequence is + exactly the context length), use ``max_input_seq_length * 8`` + (≈ 8 chars per token, matching tokenalign_upstream's default + ``characters_multiplier``). Using a smaller multiplier (e.g. 4) + may leave some samples shorter than the context window for dense + text such as code or CJK. Set to None or 0 to disable packing. + + Example config: + data: + dataset_name: "arrow_text" + arrow_files: "/path/to/data/*.arrow" + val_split: 0.05 + max_input_seq_length: 4096 + characters_per_sample: 32768 # 4096 tokens * 8 chars/token (guarantees full context) + """ + + def __init__( + self, + arrow_files: str | list[str], + val_split: float = 0.05, + seed: int = 42, + text_key: str = "text", + characters_per_sample: Optional[int] = None, + pack_cache_dir: Optional[str] = None, + **kwargs, + ): + # Don't call super().__init__() since RawDataset raises NotImplementedError + self.seed = seed + self.text_key = text_key + self.task_name = "arrow_text_dataset" + + # Resolve glob pattern if string + if isinstance(arrow_files, str): + file_list = glob.glob(arrow_files) + if not file_list: + raise ValueError(f"No arrow files found matching pattern: {arrow_files}") + else: + file_list = arrow_files + + print(f"Loading {len(file_list)} arrow files...") + dataset = load_dataset("arrow", data_files=file_list, split="train") + original_count = len(dataset) + print(f" ✓ Loaded {original_count} total samples") + + # Verify text column exists + if self.text_key not in dataset.column_names: + raise ValueError( + f"Column '{self.text_key}' not found in arrow files. " + f"Available columns: {dataset.column_names}" + ) + + if characters_per_sample is not None and characters_per_sample > 0: + # Lazy packing: scan character lengths to build pack boundaries, + # then load+concatenate text on demand in __getitem__. + # + # The scan is deterministic in (file set, text_key, chars/sample) + # and dominates dataset init time on large arrow corpora + # (~6 minutes for 70M rows). When ``pack_cache_dir`` is set we + # fingerprint the inputs and store the resulting boundaries as + # an ``int64[N, 2]`` ``.npy`` file so subsequent runs skip the + # scan entirely. + pack_ranges = _load_or_build_pack_ranges( + dataset=dataset, + file_list=file_list, + text_key=text_key, + characters_per_sample=characters_per_sample, + pack_cache_dir=pack_cache_dir, + ) + + # Split pack_ranges into train/val + if val_split > 0: + import random + rng = random.Random(seed) + indices = list(range(len(pack_ranges))) + rng.shuffle(indices) + val_count = max(1, int(len(pack_ranges) * val_split)) + val_indices = sorted(indices[:val_count]) + train_indices = sorted(indices[val_count:]) + train_ranges = [pack_ranges[i] for i in train_indices] + val_ranges = [pack_ranges[i] for i in val_indices] + else: + train_ranges = pack_ranges + val_ranges = pack_ranges[:min(100, len(pack_ranges))] + + train_dataset = _LazyPackedDataset(dataset, train_ranges, text_key) + val_dataset = _LazyPackedDataset(dataset, val_ranges, text_key) + else: + # No packing: convert text to messages format directly + def text_to_messages(example: dict[str, Any]) -> dict[str, Any]: + text = example[self.text_key] + return { + "messages": [{"role": "assistant", "content": text}], + "task_name": "arrow_text_dataset", + } + + formatted_dataset = dataset.map(text_to_messages, remove_columns=dataset.column_names) + + if val_split > 0: + split = formatted_dataset.train_test_split(test_size=val_split, seed=seed) + train_dataset = split["train"] + val_dataset = split["test"] + else: + train_dataset = formatted_dataset + val_dataset = formatted_dataset.select(range(min(100, len(formatted_dataset)))) + + print(f" ✓ Train: {len(train_dataset)}, Validation: {len(val_dataset)}") + + self.dataset = train_dataset + self.val_dataset = val_dataset + + +def _pack_cache_key( + file_list: list[str], text_key: str, characters_per_sample: int +) -> str: + """Fingerprint the inputs that determine the pack boundary list. + + Uses (sorted resolved path, size, mtime-as-int) per arrow file so the + cache invalidates automatically when any shard is rewritten or when the + file set changes. ``text_key`` and ``characters_per_sample`` are part of + the key because changing either yields different boundaries. + """ + parts = [] + for path in sorted(file_list): + st = os.stat(path) + parts.append((path, st.st_size, int(st.st_mtime))) + blob = json.dumps( + { + "files": parts, + "text_key": text_key, + "chars_per_sample": int(characters_per_sample), + }, + sort_keys=True, + ) + return hashlib.sha1(blob.encode()).hexdigest()[:16] + + +def _load_or_build_pack_ranges( + dataset: Dataset, + file_list: list[str], + text_key: str, + characters_per_sample: int, + pack_cache_dir: Optional[str], +) -> list[tuple[int, int]]: + """Scan + cache pack boundaries, or load from disk if a cache hit exists.""" + cache_path: Optional[str] = None + if pack_cache_dir: + os.makedirs(pack_cache_dir, exist_ok=True) + key = _pack_cache_key(file_list, text_key, characters_per_sample) + cache_path = os.path.join(pack_cache_dir, f"{key}.npy") + if os.path.exists(cache_path): + print(f" ↪ Loading cached pack boundaries from {cache_path}") + arr = np.load(cache_path) + pack_ranges = [(int(s), int(e)) for s, e in arr] + print(f" ✓ Loaded {len(pack_ranges)} pack boundaries (cache hit)") + return pack_ranges + + print( + f" Scanning character lengths for lazy packing (target ~{characters_per_sample} chars)..." + ) + t0 = time.time() + pack_ranges = _build_pack_ranges(dataset, text_key, characters_per_sample) + print( + f" ✓ Built {len(pack_ranges)} pack boundaries from {len(dataset)} samples in {time.time() - t0:.1f}s" + ) + + if cache_path is not None: + # Atomic write so concurrent jobs don't corrupt the cache file. + # Pass a file handle to np.save (not a path) — np.save auto-appends + # ".npy" when given a path that doesn't already end in .npy, which + # breaks the os.replace(tmp -> final) rename below. + tmp_path = f"{cache_path}.tmp.{os.getpid()}" + arr = np.asarray(pack_ranges, dtype=np.int64) + with open(tmp_path, "wb") as f: + np.save(f, arr) + os.replace(tmp_path, cache_path) + print(f" ↳ Cached pack boundaries to {cache_path}") + + return pack_ranges + + +def _build_pack_ranges( + dataset: Dataset, + text_key: str, + characters_per_sample: int, + scan_batch_size: int = 10000, +) -> list[tuple[int, int]]: + """Scan character lengths in batches and build pack boundary indices. + + Returns a list of (start_row, end_row) tuples. Each packed sample + covers rows [start_row, end_row) whose combined character length + (plus newline separators) meets or exceeds ``characters_per_sample``. + """ + n = len(dataset) + pack_ranges: list[tuple[int, int]] = [] + start = 0 + accum = 0 + + for batch_start in range(0, n, scan_batch_size): + batch_end = min(batch_start + scan_batch_size, n) + batch_texts = dataset[batch_start:batch_end][text_key] + for i, text in enumerate(batch_texts): + row_idx = batch_start + i + text_len = len(text) if isinstance(text, str) and text else 0 + if text_len == 0: + continue + # +1 for newline separator (except the first text in a pack) + accum += text_len + (1 if accum > 0 else 0) + if accum >= characters_per_sample: + pack_ranges.append((start, row_idx + 1)) + start = row_idx + 1 + accum = 0 + + # Flush remaining rows as a partial pack + if start < n and accum > 0: + pack_ranges.append((start, n)) + + return pack_ranges diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 091c5ad1c5..ed3ac29465 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -187,6 +187,72 @@ def sft_processor( return output +def kd_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum for knowledge-distillation training on raw text. + + Cross-tokenizer KD requires both student and teacher to see the same + raw text (no chat template, no instruction formatting), with loss + applied to every token. The raw text is preserved in + ``extra_env_info`` so the teacher worker can retokenize it with its + own tokenizer. + + Args: + datum_dict: Sample dict with a ``messages`` list whose ``content`` + fields are concatenated with newlines into the raw text. + task_data_spec: Task spec (unused; present for protocol parity). + tokenizer: Student tokenizer used for the student-side token ids. + max_seq_length: Truncation cap; samples longer than this are + zero-weighted via ``loss_multiplier``. + idx: Index of the sample within the dataset. + + Returns: + DatumSpec with a single assistant-role ``message_log`` entry, + ``token_loss_mask`` set to all-ones (loss on every token), and + ``extra_env_info["raw_text"]`` carrying the joined text. + """ + raw_text = "\n".join( + msg["content"] + for msg in datum_dict["messages"] + if isinstance(msg.get("content"), str) + ) + + token_ids = tokenizer( + raw_text, + return_tensors="pt", + add_special_tokens=True, + max_length=max_seq_length, + truncation=True, + )["input_ids"][0] + + length = len(token_ids) + loss_multiplier = 1.0 + if length > max_seq_length: + loss_multiplier = 0.0 + + message_log = [ + { + "role": "assistant", + "content": raw_text, + "token_ids": token_ids, + "token_loss_mask": torch.ones_like(token_ids), + } + ] + + return { + "message_log": message_log, + "length": length, + "extra_env_info": {"raw_text": raw_text}, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + def preference_preprocessor( datum_dict: dict[str, Any], task_data_spec: TaskDataSpec, @@ -718,6 +784,7 @@ def nemo_gym_data_processor( { "default": math_hf_data_processor, "helpsteer3_data_processor": helpsteer3_data_processor, + "kd_data_processor": kd_data_processor, "math_data_processor": math_data_processor, "math_hf_data_processor": math_hf_data_processor, "multichoice_qa_processor": multichoice_qa_processor, diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 2819e27582..b53adc6522 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -134,7 +134,10 @@ def setup_response_data( } else: # merge datasets into a single dataset - merged_data = concatenate_datasets([data.dataset for data in data_list]) + if len(data_list) == 1: + merged_data = data_list[0].dataset + else: + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( merged_data, tokenizer, @@ -199,7 +202,10 @@ def setup_response_data( # merge datasets val_dataset = None if len(val_data_list) > 0: - merged_val_data = concatenate_datasets(val_data_list) + if len(val_data_list) == 1: + merged_val_data = val_data_list[0] + else: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( merged_val_data, tokenizer, diff --git a/tests/unit/data/test_data_processor.py b/tests/unit/data/test_data_processor.py index 4e20788aaa..fe77d7926d 100644 --- a/tests/unit/data/test_data_processor.py +++ b/tests/unit/data/test_data_processor.py @@ -39,6 +39,7 @@ from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec from nemo_rl.data.processors import ( helpsteer3_data_processor, + kd_data_processor, math_data_processor, math_hf_data_processor, ) @@ -60,10 +61,19 @@ def apply_chat_template( content += "assistant:" return content - def __call__(self, text, return_tensors=None, add_special_tokens=False): + def __call__( + self, + text, + return_tensors=None, + add_special_tokens=False, + max_length=None, + truncation=False, + ): if isinstance(text, list): text = "".join(text) encoded = list(range(len(text))) + if truncation and max_length is not None: + encoded = encoded[:max_length] if return_tensors == "pt": return {"input_ids": torch.tensor([encoded], dtype=torch.long)} return {"input_ids": encoded} @@ -348,3 +358,98 @@ def test_helpsteer3_data_processor(): # Length equals sum of token lengths assert out["length"] == sum(int(m["token_ids"].numel()) for m in msg_log) + + +def test_kd_data_processor(): + datum_dict = { + "messages": [ + {"role": "assistant", "content": "Hello world"}, + {"role": "assistant", "content": "Goodbye"}, + ], + } + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=128, + idx=7, + ) + + expected_raw_text = "Hello world\nGoodbye" + assert out["extra_env_info"]["raw_text"] == expected_raw_text + assert out["loss_multiplier"] == 1.0 + assert out["idx"] == 7 + + # Single assistant-role entry; raw text is mirrored back into the message. + assert len(out["message_log"]) == 1 + msg = out["message_log"][0] + assert msg["role"] == "assistant" + assert msg["content"] == expected_raw_text + + # KD invariant: loss applies to every token. + assert isinstance(msg["token_ids"], torch.Tensor) + assert isinstance(msg["token_loss_mask"], torch.Tensor) + assert msg["token_loss_mask"].shape == msg["token_ids"].shape + assert torch.all(msg["token_loss_mask"] == 1) + + # Reported length matches the tokenized output. + assert out["length"] == int(msg["token_ids"].numel()) + + +def test_kd_data_processor_truncates_to_max_seq_length(): + long_text = "x" * 200 + datum_dict = {"messages": [{"role": "assistant", "content": long_text}]} + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=64, + idx=0, + ) + + msg = out["message_log"][0] + assert int(msg["token_ids"].numel()) == 64 + assert out["length"] == 64 + # Truncated length equals the cap (not strictly greater) — loss stays on. + assert out["loss_multiplier"] == 1.0 + + +def test_kd_data_processor_skips_non_string_content(): + datum_dict = { + "messages": [ + {"role": "assistant", "content": "kept"}, + {"role": "assistant", "content": None}, + {"role": "assistant"}, # no content key + {"role": "assistant", "content": "also kept"}, + ], + } + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=128, + idx=0, + ) + + assert out["extra_env_info"]["raw_text"] == "kept\nalso kept" From 70a14eb36bcee4aba655fd86969c84695bb0f717 Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Sun, 26 Apr 2026 19:57:48 -0700 Subject: [PATCH 3/4] feat: cross-tokenizer distillation loss and multi-teacher aggregator Adds the loss-fn layer for cross-tokenizer distillation. Builds on the TokenAligner package (PR 1). - CrossTokenizerDistillationLossFn: per-token KL/CE loss over 1:1 aligned positions, with optional gold-loss path. Holds a reference to a TokenAligner; teacher data (input_ids, aligned_pairs, optional chunked COO masks) is set per-step via set_cross_tokenizer_data. - CrossTokenizerDistillationLossConfig and CrossTokenizerDistillationLossDataDict TypedDicts. - MultiTeacherLossAggregator: wraps a list of optional CrossTokenizerDistillationLossFn instances with per-teacher weights. N=1 is a degenerate case used by the unified single-/multi-teacher worker path; the algorithm-layer multi-teacher orchestration comes in a later PR. - _scatter_chunk_mask_from_coo helper for the chunked-CE path. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/algorithms/loss/loss_functions.py | 970 ++++++++++++++++++++++ 1 file changed, 970 insertions(+) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index df6ff6bc54..3a1cbdd671 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -1035,3 +1035,973 @@ def __call__( } return kl_loss, metrics + + +class CrossTokenizerDistillationLossConfig(TypedDict): + """Configuration for cross-tokenizer distillation loss.""" + loss_type: str # 'KL', 'cross_entropy', or 'chunked_ce' + temperature: float # Softmax temperature + vocab_topk: int # Reduce teacher vocab to top-k (0 = all) + exact_token_match_only: bool # Only use 1:1 aligned positions + reverse_kl: bool # Reverse KL direction + project_teacher_to_student: NotRequired[bool] + gold_loss: NotRequired[bool] # Use gold loss (common KL + uncommon L1, no projection) + xtoken_loss: NotRequired[bool] # Relaxed exact-map threshold (>=0.6 instead of ==1.0) + ce_loss_scale: NotRequired[float] # Scale for additional CE (next-token) loss (0.0 = disabled) + dynamic_loss_scaling: NotRequired[bool] # Scale KL loss to match CE magnitude + + +class CrossTokenizerDistillationLossDataDict(TypedDict): + """Data dict for cross-tokenizer distillation. + + Only contains student-side tensors (same sequence dimension). + Teacher-side data (teacher_input_ids, aligned_pairs) is stored on the + loss function instance via set_cross_tokenizer_data() to avoid + sequence-length mismatches in the worker's shape validation. + """ + input_ids: torch.Tensor # Student token IDs (B, S_student) + input_lengths: torch.Tensor + token_mask: torch.Tensor # (B, S_student) + sample_mask: torch.Tensor # (B,) + + +def _scatter_chunk_mask_from_coo( + coo_list: list[torch.Tensor], + batch_size: int, + seq_len: int, + total_chunks: int, + device: torch.device, +) -> torch.Tensor: + """Materialize a dense ``(batch_size, seq_len, total_chunks)`` bool mask. + + ``coo_list[b]`` is a CPU ``LongTensor (N_b, 2)`` with rows ``[pos, chunk_id]`` + precomputed by ``CrossTokenizerCollator._build_chunk_coo`` (already + filtered for ``exact_match_only`` / ``-1`` sentinels / padded-length bounds). + + ``chunk_id`` in each sample's COO is in ``[0, num_chunks_per_sample)``; + samples with fewer chunks than ``total_chunks`` leave the padded columns + all-False, which the downstream ``chunk_valid = (chunk_sizes > 0)`` gate + already drops. + """ + parts: list[torch.Tensor] = [] + for b, coo in enumerate(coo_list): + if coo.shape[0] == 0: + continue + bcol = torch.full((coo.shape[0], 1), b, dtype=torch.int64) + parts.append(torch.cat([bcol, coo], dim=1)) + mask = torch.zeros( + batch_size, seq_len, total_chunks, dtype=torch.bool, device=device, + ) + if parts: + idx = torch.cat(parts, dim=0).to(device) + mask[idx[:, 0], idx[:, 1], idx[:, 2]] = True + return mask + + +class CrossTokenizerDistillationLossFn(LossFunction): + """Cross-tokenizer distillation loss using TokenAligner's projection matrix. + + Computes per-token KL divergence between projected student probabilities + (in teacher vocab space) and teacher probabilities, only at positions where + the two tokenizations have 1:1 aligned tokens. Uses NeMo RL's standard + masked_mean normalization so loss magnitude is comparable to same-tokenizer + distillation. + + Teacher-specific data (teacher_input_ids, aligned_pairs) is stored on + this object via set_cross_tokenizer_data() before each training step, + rather than in the data dict, because teacher and student sequences + have different lengths and the worker validates that all tensors in + the data dict share the same sequence dimension. + """ + + def __init__(self, cfg: CrossTokenizerDistillationLossConfig, token_aligner): + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + assert isinstance(token_aligner, TokenAligner) + self.token_aligner = token_aligner + self.cfg = cfg + self.loss_type = LossType.TOKEN_LEVEL + self._teacher_input_ids = None + self._aligned_pairs = None + self._chunk_indices: Optional[dict[str, list]] = None + + def set_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: list, + chunk_indices: Optional[dict[str, list]] = None, + ): + """Store teacher-side data before each training step. + + Called from the training loop before student_policy.train(). + The worker never sees these tensors in shape validation. + + ``chunk_indices`` carries the per-sample COO chunk-mask indices that + used to be rebuilt inside ``__call__`` every microbatch. When set, it + is a dict with keys ``student_chunk_coo``, ``teacher_chunk_coo``, + ``num_chunks``, each a DP-sharded list of length ``batch_size``. + """ + self._teacher_input_ids = teacher_input_ids + self._aligned_pairs = aligned_pairs + self._chunk_indices = chunk_indices + + def _project_student_to_teacher( + self, + student_logits: torch.Tensor, + teacher_vocab_size: int, + temperature: float, + global_top_indices: torch.Tensor, + device: torch.device, + precomputed_student_probs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Project student logits into the reduced teacher vocabulary space. + + Returns projected student probabilities of shape (B, S_student, K) + where K = len(global_top_indices). + + If `precomputed_student_probs` is provided, it is used directly instead + of recomputing softmax(student_logits / temperature). The caller is + responsible for ensuring it was computed with the same temperature. + """ + if precomputed_student_probs is not None: + student_probs = precomputed_student_probs + else: + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + has_sparse = ( + hasattr(self.token_aligner, 'sparse_transformation_matrix') + and self.token_aligner.sparse_transformation_matrix is not None + ) + if has_sparse: + sparse_mat = self.token_aligner.sparse_transformation_matrix + reduced_sparse = sparse_mat.index_select(1, global_top_indices).coalesce() + projected = self.token_aligner.project_token_likelihoods_instance( + student_probs, None, None, None, device, + use_sparse_format=True, + sparse_matrix=reduced_sparse, + ) + return projected + + proj_values = self.token_aligner.likelihood_projection_matrix + if getattr(self.token_aligner, 'learnable', False): + proj_values = self.token_aligner.transform_learned_matrix_instance(proj_values) + projected_full = self.token_aligner.project_token_likelihoods_instance( + student_probs, self.token_aligner.likelihood_projection_indices, + proj_values, teacher_vocab_size, device, + use_sparse_format=False, + ) + return projected_full[:, :, global_top_indices] + + def _compute_gold_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_chunk_mask: torch.Tensor, + teacher_chunk_mask: torch.Tensor, + batch_size: int, + student_seq_len: int, + teacher_seq_len: int, + teacher_vocab_size: int, + temperature: float, + reverse_kl: bool, + xtoken_loss: bool, + device: torch.device, + precomputed_student_log_probs: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, float]: + """Gold loss: common-vocab KL + uncommon-vocab sorted L1. + + Splits the vocabulary into tokens with exact 1:1 projection mappings + ("common") and the rest ("uncommon"). Common tokens are compared + directly via KL on their native log-probs (no projection needed). + Uncommon tokens are compared via L1 on sorted probability vectors + (Universal Likelihood Distillation). + + Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + """ + partition = self.token_aligner.build_vocab_partition( + xtoken_loss=xtoken_loss, + teacher_vocab_size=teacher_vocab_size, + ) + common_student_indices = partition.common_student_indices.to(device) + common_teacher_indices = partition.common_teacher_indices.to(device) + uncommon_student_indices = partition.uncommon_student_indices.to(device) + uncommon_teacher_indices = partition.uncommon_teacher_indices.to(device) + + # student_chunk_mask / teacher_chunk_mask are precomputed by the + # caller (from collator-emitted per-sample COO) and shared with the + # non-gold path — shape (B, seq_len, total_chunks), dtype bool. + total_chunks = student_chunk_mask.shape[-1] + + # log_softmax on full original logits BEFORE chunk averaging + if precomputed_student_log_probs is not None: + student_log_probs = precomputed_student_log_probs + else: + student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1) + + # Chunk-average log-probs over full vocabularies + student_chunk_lp = torch.bmm( + student_chunk_mask.transpose(1, 2).to(student_log_probs.dtype), student_log_probs, + ) + teacher_chunk_lp = torch.bmm( + teacher_chunk_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del student_log_probs, teacher_log_probs + + student_chunk_sizes = student_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + teacher_chunk_sizes = teacher_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + + student_chunk_lp = student_chunk_lp / (student_chunk_sizes + 1e-10) + teacher_chunk_lp = teacher_chunk_lp / (teacher_chunk_sizes + 1e-10) + + chunk_valid = (student_chunk_sizes.squeeze(-1) > 0) & (teacher_chunk_sizes.squeeze(-1) > 0) + + if not chunk_valid.any(): + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + # --- Part 1: KL on common (exactly-mapped) vocab --- + loss_kl_common = torch.tensor(0.0, device=device, requires_grad=True) + if common_student_indices.numel() > 0: + s_common = student_chunk_lp[:, :, common_student_indices] + t_common = teacher_chunk_lp[:, :, common_teacher_indices] + + if not reverse_kl: + kl_elem = torch.nn.functional.kl_div( + s_common, t_common, reduction="none", log_target=True, + ) + else: + kl_elem = torch.nn.functional.kl_div( + t_common, s_common, reduction="none", log_target=True, + ) + kl_per_chunk = kl_elem.sum(dim=-1) * chunk_valid + if chunk_valid.sum() > 0: + loss_kl_common = kl_per_chunk.sum() / chunk_valid.sum() + + # --- Part 2: L1 on uncommon (unaligned) vocab --- + loss_l1_uncommon = torch.tensor(0.0, device=device, requires_grad=True) + if uncommon_student_indices.numel() > 0 or uncommon_teacher_indices.numel() > 0: + s_uncommon = ( + student_chunk_lp[:, :, uncommon_student_indices] + if uncommon_student_indices.numel() > 0 + else torch.empty(batch_size, total_chunks, 0, device=device) + ) + t_uncommon = ( + teacher_chunk_lp[:, :, uncommon_teacher_indices] + if uncommon_teacher_indices.numel() > 0 + else torch.empty(batch_size, total_chunks, 0, device=device) + ) + + s_valid = s_uncommon[chunk_valid] + t_valid = t_uncommon[chunk_valid] + + if s_valid.shape[0] > 0: + with torch.no_grad(): + max_uncommon_vocab = min(s_valid.shape[-1], t_valid.shape[-1], 8192) + + if max_uncommon_vocab > 0: + s_probs = torch.exp(s_valid) + t_probs = torch.exp(t_valid) + + if s_probs.shape[-1] > max_uncommon_vocab: + s_sorted, _ = torch.topk(s_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + s_sorted = torch.sort(s_probs, dim=-1, descending=True)[0] + + if t_probs.shape[-1] > max_uncommon_vocab: + t_sorted, _ = torch.topk(t_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + t_sorted = torch.sort(t_probs, dim=-1, descending=True)[0] + + del s_probs, t_probs + min_len = min(s_sorted.shape[-1], t_sorted.shape[-1]) + if min_len > 0: + loss_l1_per_chunk = torch.nn.functional.l1_loss( + s_sorted[:, :min_len], t_sorted[:, :min_len], reduction='none', + ).sum(dim=-1) + loss_l1_uncommon = loss_l1_per_chunk.mean() + del loss_l1_per_chunk + del s_sorted, t_sorted + + loss_total = (loss_kl_common + loss_l1_uncommon) * (temperature ** 2) + + # Top-1 accuracy on common vocab. Computed as a 0-d GPU tensor so the + # caller can include it in the batched .cpu().tolist() sync at the end + # of the step (instead of forcing two CPU<->GPU stalls per teacher per + # microbatch here). + top1_accuracy: Union[float, torch.Tensor] = 0.0 + with torch.no_grad(): + if common_student_indices.numel() > 0 and chunk_valid.any(): + s_valid_lp = student_chunk_lp[chunk_valid][:, common_student_indices] + t_valid_lp = teacher_chunk_lp[chunk_valid][:, common_teacher_indices] + matches = ( + s_valid_lp.argmax(dim=-1) == t_valid_lp.argmax(dim=-1) + ).sum() + denom = chunk_valid.sum().clamp(min=1) + top1_accuracy = matches.to(torch.float32) / denom.to(torch.float32) + + del student_chunk_lp, teacher_chunk_lp + return loss_total, top1_accuracy + + def __call__( + self, + next_token_logits: torch.Tensor, + data: CrossTokenizerDistillationLossDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional[torch.Tensor] = None, + mb_idx: Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, + _return_raw_kl: bool = False, + precomputed_student_logits_f32: Optional[torch.Tensor] = None, + precomputed_student_probs: Optional[torch.Tensor] = None, + precomputed_student_log_probs: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute cross-tokenizer distillation loss via chunk-averaged KL. + + For each alignment chunk (1:1, 1:many, many:1, or many:many), the + projected student and teacher distributions are averaged over their + respective spans, renormalized, and compared via KL divergence. + The per-chunk KL is then distributed back to student positions + and normalized with the standard NeMo RL masked_mean. + + The three ``precomputed_*`` kwargs let an outer aggregator hoist + student-side work that does not depend on the teacher (fp32 cast, + softmax, log_softmax) out of the per-teacher loop and share it + across multiple teachers with the same temperature. + """ + input_ids_student = data["input_ids"] + batch_size = input_ids_student.shape[0] + + # Keep logits in their native dtype (typically bf16). The downstream + # log_softmax / softmax / bmm ops on CUDA upcast to fp32 internally for + # numerics while storing activations in bf16, which roughly halves the + # working-set memory of the gold-loss / projection paths. + if precomputed_student_logits_f32 is not None: + student_logits = precomputed_student_logits_f32 + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + student_logits = next_token_logits.full_tensor() + else: + student_logits = next_token_logits + + if teacher_logits is None: + raise ValueError( + "CrossTokenizerDistillationLossFn requires teacher_logits via IPC. " + "Set use_ipc=True in the distillation config." + ) + if self._aligned_pairs is None or self._teacher_input_ids is None: + raise ValueError( + "Cross-tokenizer data not set. " + "Call loss_fn.set_cross_tokenizer_data() before training." + ) + + # IPC contract (XTokenTeacherIPCExportPostProcessor, full-vocab branch): + # raw logits in their native dtype (bf16 in practice). The log_softmax + # calls below in _compute_gold_loss / the projection path are the only + # ones in the teacher pipeline. Variable name kept for diff stability. + if isinstance(teacher_logits, torch.distributed.tensor.DTensor): + teacher_logits_f32 = teacher_logits.full_tensor() + else: + teacher_logits_f32 = teacher_logits + + if teacher_logits_f32.shape[-1] == 0: + raise ValueError( + f"Teacher logits have vocab dimension 0 (shape={teacher_logits_f32.shape}). " + "This typically means topk_logits=0 was passed instead of None " + "for the teacher forward pass. Cross-tokenizer distillation " + "requires full teacher logits (topk_logits=None)." + ) + + if mb_idx is not None and mbs is not None: + mb_start = mb_idx * mbs + mb_end = mb_start + batch_size + else: + mb_start = 0 + mb_end = batch_size + + self.token_aligner = self.token_aligner.to(student_logits.device) + device = student_logits.device + + temperature = self.cfg.get("temperature", 1.0) + vocab_topk = self.cfg.get("vocab_topk", 8192) + reverse_kl = self.cfg.get("reverse_kl", False) + use_gold_loss = self.cfg.get("gold_loss", False) + use_xtoken_loss = self.cfg.get("xtoken_loss", False) + student_seq_len = student_logits.shape[1] + teacher_seq_len = teacher_logits_f32.shape[1] + teacher_vocab_size = teacher_logits_f32.shape[-1] + + # The collator pre-applies exact_match_only, sentinel (-1), and + # padded-length bounds filters, and emits the chunk mask in COO form + # per sample. The loss fn just slices the MB range and scatters. + if self._chunk_indices is None: + raise ValueError( + "CrossTokenizerDistillationLossFn requires chunk_indices. " + "CrossTokenizerCollator should have precomputed them; verify " + "the training loop forwards them through update_cross_tokenizer_data()." + ) + student_coo_list = self._chunk_indices["student_chunk_coo"][mb_start:mb_end] + teacher_coo_list = self._chunk_indices["teacher_chunk_coo"][mb_start:mb_end] + num_chunks_list = self._chunk_indices["num_chunks"][mb_start:mb_end] + total_chunks = max(num_chunks_list) if num_chunks_list else 0 + + if total_chunks == 0: + loss = torch.tensor(0.0, device=device, requires_grad=True) + return loss, {"loss": 0.0, "topk_accuracy": 0.0, "num_chunks": 0} + + proj_mask = _scatter_chunk_mask_from_coo( + student_coo_list, batch_size, student_seq_len, total_chunks, device, + ) + tgt_mask = _scatter_chunk_mask_from_coo( + teacher_coo_list, batch_size, teacher_seq_len, total_chunks, device, + ) + num_valid_chunks_total = int(sum(num_chunks_list)) + + # ================================================================ + # Gold loss path: common-vocab KL + uncommon-vocab sorted L1. + # Bypasses the projection matrix for tokens with exact 1:1 mappings. + # Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + # ================================================================ + if use_gold_loss: + loss, top1_accuracy = self._compute_gold_loss( + student_logits, teacher_logits_f32, proj_mask, tgt_mask, + batch_size, student_seq_len, teacher_seq_len, + teacher_vocab_size, + temperature, reverse_kl, use_xtoken_loss, device, + precomputed_student_log_probs=precomputed_student_log_probs, + ) + else: + # ================================================================ + # Standard projection-based path + # ================================================================ + + # -- 3. Global vocabulary filtering (top-k teacher tokens) -- + with torch.no_grad(): + if vocab_topk == 0 or vocab_topk >= teacher_vocab_size: + global_top_indices = torch.arange(teacher_vocab_size, device=device) + else: + teacher_flat = teacher_logits_f32.view(-1, teacher_vocab_size) + importance = teacher_flat.max(dim=0)[0] + _, global_top_indices = torch.topk( + importance, k=min(vocab_topk, teacher_vocab_size), dim=-1, + ) + global_top_indices = global_top_indices.sort()[0] + + # -- 4. Project student probs to teacher vocab -- + projected_student = self._project_student_to_teacher( + student_logits, teacher_vocab_size, temperature, global_top_indices, device, + precomputed_student_probs=precomputed_student_probs, + ) + + # -- 5. Teacher log-probs in reduced vocab -- + teacher_logits_reduced = teacher_logits_f32[:, :, global_top_indices] + teacher_log_probs = torch.log_softmax(teacher_logits_reduced / temperature, dim=-1) + del teacher_logits_reduced + + # -- 6. Chunk-averaged distributions -- + proj_chunks = torch.bmm( + proj_mask.transpose(1, 2).to(projected_student.dtype), projected_student, + ) + tgt_log_chunks = torch.bmm( + tgt_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del projected_student, teacher_log_probs + + proj_sizes = proj_mask.sum(dim=1).unsqueeze(-1).to(proj_chunks.dtype) + tgt_sizes = tgt_mask.sum(dim=1).unsqueeze(-1).to(tgt_log_chunks.dtype) + + proj_chunks = proj_chunks / (proj_sizes + 1e-10) + tgt_log_chunks = tgt_log_chunks / (tgt_sizes + 1e-10) + + proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + 1e-10) + proj_log_chunks = torch.log(proj_chunks + 1e-10) + + chunk_valid = (proj_sizes.squeeze(-1) > 0) & (tgt_sizes.squeeze(-1) > 0) + + # -- 7. KL divergence per chunk -- + if reverse_kl: + kl_per_elem = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True, + ) + else: + kl_per_elem = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True, + ) + kl_per_chunk = kl_per_elem.sum(dim=-1) * (temperature ** 2) + kl_per_chunk = kl_per_chunk * chunk_valid + del proj_chunks, tgt_log_chunks, proj_log_chunks, kl_per_elem + + # -- 8. Scalar loss -- + num_valid_chunks = chunk_valid.sum() + if num_valid_chunks > 0: + loss = kl_per_chunk.sum() / num_valid_chunks + else: + loss = torch.tensor(0.0, device=device, requires_grad=True) + top1_accuracy = 0.0 + + # Keep a stable alias for raw-KL return path and optional CE fusion path. + kl_loss = loss + if _return_raw_kl: + # Return scalar metrics as GPU tensors so the outer aggregator can + # batch the CPU sync for all teachers / metrics into a single + # .cpu().tolist() call. The aggregator unwraps tensor entries. + raw_metrics: dict[str, Any] = { + "loss": kl_loss if isinstance(kl_loss, torch.Tensor) else kl_loss, + "kl_loss": kl_loss if isinstance(kl_loss, torch.Tensor) else kl_loss, + "topk_accuracy": top1_accuracy, + "num_valid_samples": int(batch_size), + "num_chunks": num_valid_chunks_total, + } + return kl_loss, raw_metrics + + # ================================================================ + # Optional CE (next-token prediction) loss, matching the DDP + # train_distillation_ddp.py logic: + # without dynamic scaling: loss = kl * kl_weight + ce * ce_scale + # with dynamic scaling: loss = kl * (ce/kl) + ce + # ================================================================ + ce_loss_scale = self.cfg.get("ce_loss_scale", 0.0) + dynamic_loss_scaling = self.cfg.get("dynamic_loss_scaling", False) + ce_loss_value = 0.0 + + if ce_loss_scale > 0.0 or dynamic_loss_scaling: + # Mask padding positions so CE loss only covers real tokens. + # token_mask[:, 1:] marks valid next-token targets (shifted by 1). + token_mask = data["token_mask"] + ce_mask = token_mask[:, 1 : student_seq_len].to(torch.bool) + ce_targets = input_ids_student[:, 1:student_seq_len].clone() + ce_targets[~ce_mask] = -100 + ce_loss = torch.nn.functional.cross_entropy( + student_logits[:, :student_seq_len - 1].reshape(-1, student_logits.shape[-1]), + ce_targets.reshape(-1), + ignore_index=-100, + ) + ce_loss_value = float(ce_loss.item()) + + if dynamic_loss_scaling and kl_loss.item() > 0: + dls_scale = ce_loss.item() / kl_loss.item() + loss = kl_loss * dls_scale + ce_loss + else: + loss = kl_loss + ce_loss * ce_loss_scale + + # Scale for NeMo RL distributed training + token_mask = data["token_mask"] + sample_mask = data["sample_mask"] + max_len = min(token_mask.shape[1] - 1, student_seq_len) + local_mask = token_mask[:, 1 : max_len + 1] * sample_mask.unsqueeze(-1) + local_valid_toks = local_mask.sum() + + if local_valid_toks > 0 and global_valid_toks > 0: + tok_scale = float(local_valid_toks / global_valid_toks) + loss = loss * tok_scale + else: + tok_scale = 0.0 + loss = loss * 0.0 + + num_valid = num_valid_chunks_total + metrics = { + "loss": float(loss.item()) if loss.ndim == 0 else loss, + "kl_loss": float(kl_loss.item()) * tok_scale, + "ce_loss": ce_loss_value * tok_scale, + "topk_accuracy": top1_accuracy, + "num_valid_samples": int(batch_size), + "num_chunks": num_valid, + "alignment_density": num_valid / max(1, batch_size * student_seq_len), + } + + return loss, metrics + + +class MultiTeacherLossAggregator(LossFunction): + """Aggregate weighted losses from multiple teachers in a unified path.""" + + def __init__( + self, + loss_fns: list[Optional[CrossTokenizerDistillationLossFn]], + weights: list[float], + normalize_by_vocab: bool = False, + cfg: Optional[dict[str, Any]] = None, + ): + assert len(loss_fns) == len(weights), ( + f"loss_fns ({len(loss_fns)}) and weights ({len(weights)}) length mismatch" + ) + self.loss_fns = loss_fns + self.weights = weights + self.normalize_by_vocab = normalize_by_vocab + self.cfg = cfg or {} + self.teacher_aggregation_mode = self.cfg.get("teacher_aggregation_mode", "weighted") + if self.teacher_aggregation_mode not in {"weighted", "routing", "average"}: + raise ValueError( + "teacher_aggregation_mode must be one of {'weighted', 'routing', 'average'}, " + f"got '{self.teacher_aggregation_mode}'" + ) + self.loss_type = LossType.TOKEN_LEVEL + # Marks this loss fn as expecting student logits as its primary input. + # Required by `prepare_loss_input` so the unified single-/multi-teacher + # cross-tokenizer path can flow through this aggregator. + self.input_type = LossInputType.LOGIT + + def set_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: list, + teacher_idx: Optional[int] = None, + chunk_indices: Optional[dict[str, list]] = None, + ) -> None: + # When called from the single-teacher dispatch (no explicit teacher_idx), + # default to the only teacher slot so the unified worker path keeps + # working without per-call branching upstream. + if teacher_idx is None: + if len(self.loss_fns) != 1: + raise ValueError( + "set_cross_tokenizer_data requires teacher_idx when " + f"len(loss_fns) > 1 (got {len(self.loss_fns)})" + ) + teacher_idx = 0 + fn = self.loss_fns[teacher_idx] + if fn is not None: + fn.set_cross_tokenizer_data( + teacher_input_ids, aligned_pairs, chunk_indices=chunk_indices, + ) + + def _compute_same_tokenizer_kl( + self, + next_token_logits: torch.Tensor, + teacher_logits: torch.Tensor, + data: CrossTokenizerDistillationLossDataDict, + teacher_topk_indices_ipc: Optional[torch.Tensor], + ) -> torch.Tensor: + student_logits = next_token_logits.to(torch.float32) + t_logits = teacher_logits.to(student_logits.device, dtype=torch.float32) + + seq_len = student_logits.shape[1] - 1 + student_shifted = student_logits[:, :-1] + + if teacher_topk_indices_ipc is None: + teacher_logprobs = torch.nn.functional.log_softmax(t_logits[:, :seq_len], dim=-1) + student_logprobs = torch.nn.functional.log_softmax(student_shifted, dim=-1) + per_token_kl = ( + teacher_logprobs.exp() * (teacher_logprobs - student_logprobs) + ).sum(dim=-1) + else: + topk_idx = teacher_topk_indices_ipc[:, :seq_len].to(student_shifted.device) + teacher_topk = t_logits[:, :seq_len] + student_logprobs = torch.nn.functional.log_softmax(student_shifted, dim=-1) + student_topk = torch.gather(student_logprobs, dim=-1, index=topk_idx) + teacher_topk_probs = teacher_topk.exp() + teacher_rest = (1.0 - teacher_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + teacher_probs_full = torch.cat([teacher_topk_probs, teacher_rest], dim=-1) + teacher_logprobs_full = torch.cat([teacher_topk, teacher_rest.log()], dim=-1) + student_topk_probs = student_topk.exp() + student_rest = (1.0 - student_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + student_logprobs_full = torch.cat([student_topk, student_rest.log()], dim=-1) + per_token_kl = ( + teacher_probs_full * (teacher_logprobs_full - student_logprobs_full) + ).sum(dim=-1) + + token_mask = data["token_mask"][:, 1 : seq_len + 1] + sample_mask = data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + valid_toks = mask.sum().clamp(min=1.0) + return (per_token_kl * mask).sum() / valid_toks + + def __call__( + self, + next_token_logits: Optional[torch.Tensor] = None, + data: Optional[CrossTokenizerDistillationLossDataDict] = None, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional[torch.Tensor] = None, + mb_idx: Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, + teacher_logits_list: Optional[list[torch.Tensor]] = None, + teacher_topk_indices_list: Optional[list[Optional[torch.Tensor]]] = None, + teacher_routing_indices: Optional[torch.Tensor] = None, + logits: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + # Accept `logits=` as an alias for `next_token_logits=` so this aggregator + # is a drop-in for `prepare_loss_input`, which emits {"logits": ...} for + # LossInputType.LOGIT. This removes the need for a separate Compat shim. + if next_token_logits is None: + next_token_logits = logits + if next_token_logits is None: + raise ValueError( + "MultiTeacherLossAggregator requires either `next_token_logits` " + "or `logits` to be provided." + ) + + if teacher_logits_list is None: + teacher_logits_list = [teacher_logits] if teacher_logits is not None else [] + if teacher_topk_indices_list is None: + teacher_topk_indices_list = [teacher_topk_indices_ipc] * len(teacher_logits_list) + + if len(teacher_logits_list) == 0: + zero = torch.tensor(0.0, device=next_token_logits.device, requires_grad=True) + return zero, {"loss": 0.0, "num_valid_samples": 0} + + if len(teacher_logits_list) != len(self.loss_fns): + raise ValueError( + "teacher_logits_list length must match number of configured teachers: " + f"{len(teacher_logits_list)} != {len(self.loss_fns)}" + ) + if len(teacher_topk_indices_list) != len(teacher_logits_list): + raise ValueError( + "teacher_topk_indices_list length must match teacher_logits_list length: " + f"{len(teacher_topk_indices_list)} != {len(teacher_logits_list)}" + ) + + vocab_sizes = [int(t.shape[-1]) for t in teacher_logits_list] + min_log_vocab = math.log(max(2, min(vocab_sizes))) if self.normalize_by_vocab else 1.0 + + total_kl = torch.tensor(0.0, device=next_token_logits.device, requires_grad=True) + metrics: dict[str, Any] = {} + # GPU-side scalar tensors accumulated during the loop, synced once + # (as a single batched D2H copy) at the end of the call. This keeps the + # CPU from stalling between teachers and between micro-steps, which is + # critical for letting the GPU pipeline kernels back-to-back. + gpu_scalar_metrics: dict[str, torch.Tensor] = {} + routing_indices = teacher_routing_indices + if routing_indices is None and isinstance(data, dict): + routing_indices = data.get("teacher_routing_indices", None) + if self.teacher_aggregation_mode == "routing": + if routing_indices is None: + raise ValueError( + "teacher_aggregation_mode='routing' requires teacher_routing_indices " + "either as an argument or in data['teacher_routing_indices']" + ) + if routing_indices.ndim != 1: + raise ValueError( + "teacher_routing_indices must be a rank-1 tensor with one teacher index per sample" + ) + if routing_indices.shape[0] != data["sample_mask"].shape[0]: + raise ValueError( + "teacher_routing_indices length must match batch size: " + f"{routing_indices.shape[0]} != {data['sample_mask'].shape[0]}" + ) + + original_sample_mask = data["sample_mask"] + active_teachers = sum(1 for t_logits in teacher_logits_list if t_logits is not None) + average_weight = 1.0 / max(1, active_teachers) + + # ===== Hoist student-only work out of the per-teacher loop ===== + # The fp32 cast and softmax/log_softmax of student_logits depend only + # on student_logits and (optionally) the per-teacher temperature. + # When two or more teachers share the same temperature and code path + # (gold_loss vs. projection), we can compute the corresponding + # student tensor exactly once and reuse it across those teachers. + per_teacher_share_keys: list[Optional[tuple[float, str]]] = [] + share_key_counts: dict[tuple[float, str], int] = {} + for teacher_idx, (loss_fn, t_logits) in enumerate( + zip(self.loss_fns, teacher_logits_list) + ): + if t_logits is None or loss_fn is None: + per_teacher_share_keys.append(None) + continue + cfg = getattr(loss_fn, "cfg", {}) or {} + temp = float(cfg.get("temperature", 1.0)) + kind = "log_probs" if cfg.get("gold_loss", False) else "probs" + key = (temp, kind) + per_teacher_share_keys.append(key) + share_key_counts[key] = share_key_counts.get(key, 0) + 1 + + has_sharing = any(c >= 2 for c in share_key_counts.values()) + shared_softmax_cache: dict[tuple[float, str], torch.Tensor] = {} + if has_sharing: + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + full_logits = next_token_logits.full_tensor() + else: + full_logits = next_token_logits + # Keep the cached softmax / log_softmax in the input dtype + # (typically bf16). The CUDA softmax / log_softmax kernels already + # accumulate in fp32 internally regardless of input dtype, so + # numerics match the previous explicit fp32 cast while halving + # the cache memory and avoiding two transient full-vocab fp32 + # buffers (the .to(fp32) result and the / temp intermediate) + # that were causing OOM at mbs > 1. + for key, count in share_key_counts.items(): + if count < 2: + continue + temp, kind = key + # Skip the divide-by-temperature kernel (and its transient + # buffer) when temperature is the no-op default. + scaled_logits = full_logits if temp == 1.0 else full_logits / temp + if kind == "probs": + shared_softmax_cache[key] = torch.softmax(scaled_logits, dim=-1) + else: + shared_softmax_cache[key] = torch.log_softmax(scaled_logits, dim=-1) + if scaled_logits is not full_logits: + del scaled_logits + del full_logits + # =============================================================== + + for teacher_idx, (loss_fn, weight, t_logits, t_topk_idx) in enumerate( + zip(self.loss_fns, self.weights, teacher_logits_list, teacher_topk_indices_list) + ): + if t_logits is None: + continue + if self.teacher_aggregation_mode == "routing": + routed_sample_mask = original_sample_mask * ( + routing_indices.to(original_sample_mask.device) == teacher_idx + ).to(original_sample_mask.dtype) + routed_samples = int((routed_sample_mask > 0).sum().item()) + metrics[f"teacher_{teacher_idx}/routed_samples"] = routed_samples + if routed_samples == 0: + continue + data["sample_mask"] = routed_sample_mask + teacher_compute_start = time.perf_counter() + if loss_fn is not None: + share_key = per_teacher_share_keys[teacher_idx] + shared_softmax = ( + shared_softmax_cache.get(share_key) + if share_key is not None + else None + ) + shared_kwargs: dict[str, Any] = {} + if shared_softmax is not None and share_key is not None: + if share_key[1] == "probs": + shared_kwargs["precomputed_student_probs"] = shared_softmax + else: + shared_kwargs["precomputed_student_log_probs"] = shared_softmax + teacher_kl, teacher_metrics = loss_fn( + next_token_logits=next_token_logits, + data=data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + teacher_logits=t_logits, + mb_idx=mb_idx, + mbs=mbs, + teacher_topk_indices_ipc=t_topk_idx, + _return_raw_kl=True, + **shared_kwargs, + ) + else: + teacher_kl = self._compute_same_tokenizer_kl( + next_token_logits, t_logits, data, t_topk_idx + ) + # Defer this scalar's sync to the batched .cpu().tolist() below. + teacher_metrics = {"kl_loss": teacher_kl} + data["sample_mask"] = original_sample_mask + teacher_compute_elapsed = time.perf_counter() - teacher_compute_start + + vocab_scale = 1.0 + if self.normalize_by_vocab: + vocab_scale = math.log(max(2, int(t_logits.shape[-1]))) / min_log_vocab + + if self.teacher_aggregation_mode == "weighted": + effective_weight = weight + elif self.teacher_aggregation_mode == "average": + effective_weight = average_weight + else: + effective_weight = 1.0 + weighted_teacher_kl = teacher_kl * effective_weight * vocab_scale + total_kl = total_kl + weighted_teacher_kl + + # Stash GPU scalars; sync once at the end. Non-tensor metadata + # (weights, vocab_scale, elapsed time) goes straight into metrics. + gpu_scalar_metrics[f"teacher_{teacher_idx}/raw_kl"] = teacher_kl.detach() + metrics[f"teacher_{teacher_idx}/weight"] = float(effective_weight) + metrics[f"teacher_{teacher_idx}/vocab_scale"] = float(vocab_scale) + gpu_scalar_metrics[f"teacher_{teacher_idx}/weighted_kl"] = ( + weighted_teacher_kl.detach() + ) + metrics[f"teacher_{teacher_idx}/loss_compute"] = float(teacher_compute_elapsed) + for key, value in teacher_metrics.items(): + metric_key = f"teacher_{teacher_idx}/{key}" + if isinstance(value, torch.Tensor) and value.ndim == 0: + gpu_scalar_metrics[metric_key] = value.detach() + else: + metrics[metric_key] = value + + # Release shared student tensors before CE loss / masking work to keep + # peak memory close to the original per-teacher implementation. + shared_softmax_cache.clear() + del shared_softmax_cache + + ce_loss_scale = self.cfg.get("ce_loss_scale", 0.0) + dynamic_loss_scaling = self.cfg.get("dynamic_loss_scaling", False) + loss = total_kl + ce_loss_tensor: Optional[torch.Tensor] = None + if ce_loss_scale > 0.0 or dynamic_loss_scaling: + # Pass logits in their native dtype; cross_entropy internally + # promotes to fp32 for the log_softmax/NLL reduction. Avoids + # materializing a second full-vocab fp32 tensor here. + ce_logits = next_token_logits + student_seq_len = ce_logits.shape[1] + # Mask padding positions so CE loss only covers real tokens. + token_mask_ce = data["token_mask"][:, 1:student_seq_len].to(torch.bool) + ce_targets = data["input_ids"][:, 1:student_seq_len].clone() + ce_targets[~token_mask_ce] = -100 + ce_loss = torch.nn.functional.cross_entropy( + ce_logits[:, :student_seq_len - 1].reshape(-1, ce_logits.shape[-1]), + ce_targets.reshape(-1), + ignore_index=-100, + ) + ce_loss_tensor = ce_loss + if dynamic_loss_scaling: + # Avoid the per-microbatch CPU<->GPU sync by computing the + # scaling factor entirely on-device. The scale is detached so + # gradient flow matches the original (where dls_scale was a + # Python scalar treated as a constant). The clamp guards + # against the degenerate total_kl==0 case; in practice KL is + # strictly positive, so this matches the original branch. + dls_scale = ( + ce_loss.detach() / total_kl.detach().clamp(min=1e-10) + ) + loss = total_kl * dls_scale + ce_loss + else: + loss = total_kl + ce_loss * ce_loss_scale + + token_mask = data["token_mask"] + sample_mask = data["sample_mask"] + student_seq_len = next_token_logits.shape[1] + max_len = min(token_mask.shape[1] - 1, student_seq_len) + local_mask = token_mask[:, 1 : max_len + 1] * sample_mask.unsqueeze(-1) + local_valid_toks = local_mask.sum() + if local_valid_toks > 0 and global_valid_toks > 0: + tok_scale = float(local_valid_toks / global_valid_toks) + loss = loss * tok_scale + else: + tok_scale = 0.0 + loss = loss * 0.0 + + # Defer the loss / kl_loss / ce_loss sync to the single batched + # transfer below. Apply tok_scale to kl_loss and ce_loss so wandb + # reports the correct global mean (not N_ranks × mean). + if isinstance(loss, torch.Tensor) and loss.ndim == 0: + gpu_scalar_metrics["loss"] = loss.detach() + else: + metrics["loss"] = loss + if isinstance(total_kl, torch.Tensor) and total_kl.ndim == 0: + gpu_scalar_metrics["kl_loss"] = total_kl.detach() * tok_scale + else: + metrics["kl_loss"] = total_kl * tok_scale if total_kl else 0.0 + if ce_loss_tensor is not None: + gpu_scalar_metrics["ce_loss"] = ce_loss_tensor.detach() * tok_scale + else: + metrics["ce_loss"] = 0.0 + metrics["num_valid_samples"] = int(data["input_ids"].shape[0]) + + # Single batched D2H copy for every scalar metric we collected above. + # This replaces the 8-14 individual .item() syncs that were previously + # interspersed throughout the per-teacher loop and the CE / dynamic + # loss scaling branches, which forced the CPU to wait between teachers + # and prevented the GPU from pipelining their kernels. + if gpu_scalar_metrics: + keys = list(gpu_scalar_metrics.keys()) + stacked = torch.stack( + [t.reshape(()) for t in gpu_scalar_metrics.values()] + ) + values = stacked.cpu().tolist() + for k, v in zip(keys, values): + metrics[k] = float(v) + + return loss, metrics From a14695360acfcf67432f2f300dc5b430891b1835 Mon Sep 17 00:00:00 2001 From: Adithyakrishna Hanasoge Date: Sun, 26 Apr 2026 20:01:56 -0700 Subject: [PATCH 4/4] feat: CUDA IPC for teacher logits transfer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the IPC plumbing that lets a teacher policy worker hand its logits to the student worker without going through Ray's serialization path — required for cross-tokenizer distillation where teacher full-vocab logits are too big to pickle per step. - nemo_rl/distributed/ipc_utils.py: get_handle_from_tensor and rebuild_cuda_tensor_from_ipc helpers wrapping CUDA IPC handles. - nemo_rl/models/automodel/train.py: two new post-processors — XTokenTeacherIPCExportPostProcessor (teacher side, allocates a pre-sized CUDA buffer and exports the IPC handle per microbatch) and XTokenTeacherIPCLossPostProcessor (student side, rebuilds the tensor from the handle and feeds it to the loss fn). Existing post-processors are untouched. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithyakrishna Hanasoge --- nemo_rl/distributed/ipc_utils.py | 42 ++++ nemo_rl/models/automodel/train.py | 308 ++++++++++++++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 nemo_rl/distributed/ipc_utils.py diff --git a/nemo_rl/distributed/ipc_utils.py b/nemo_rl/distributed/ipc_utils.py new file mode 100644 index 0000000000..46e9b206c3 --- /dev/null +++ b/nemo_rl/distributed/ipc_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""IPC helpers for sharing CUDA tensors across processes. + +Used by cross-tokenizer off-policy distillation to ship teacher logits +from the teacher policy worker to the student worker without an extra +host roundtrip. +""" + +from typing import Any + +import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor + + +def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]: + """Get IPC handle from a tensor.""" + from torch.multiprocessing.reductions import reduce_tensor + + # Skip serializing the function for better refit performance. + return reduce_tensor(tensor.detach())[1:] + + +def rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, device_id: int +) -> torch.Tensor: + """Rebuild a CUDA tensor from an IPC handle on ``device_id``.""" + args = cuda_ipc_handle[0] + list_args = list(args) + list_args[6] = device_id + return rebuild_cuda_tensor(*list_args) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index d2b5979400..f885a8611a 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -45,7 +45,12 @@ from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.ipc_utils import ( + get_handle_from_tensor, + rebuild_cuda_tensor_from_ipc, +) from nemo_rl.distributed.model_utils import ( + _compute_distributed_log_softmax, allgather_cp_sharded_tensor, distributed_vocab_topk, get_logprobs_from_vocab_parallel_logits, @@ -59,6 +64,8 @@ "LogprobsPostProcessor", "TopkLogitsPostProcessor", "ScorePostProcessor", + "XTokenTeacherIPCLossPostProcessor", + "XTokenTeacherIPCExportPostProcessor", ] @@ -588,6 +595,307 @@ def __call__( return loss, loss_metrics + + +class XTokenTeacherIPCLossPostProcessor(LossPostProcessor): + """Loss post-processor that injects teacher logits via CUDA IPC handles.""" + + def __init__( + self, *args: Any, teacher_result: Optional[dict[str, Any]] = None, **kwargs: Any + ): + super().__init__(*args, **kwargs) + self._teacher_result = teacher_result + self._microbatch_idx = 0 + + def set_microbatch_index(self, mb_idx: int) -> None: + self._microbatch_idx = mb_idx + + def __call__( + self, + logits: torch.Tensor, + data_dict: BatchedDataDict[Any], + processed_inputs: ProcessedInputs, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + sequence_dim: int = 1, + ) -> tuple[torch.Tensor, dict[str, Any]]: + if self.cp_size > 1: + _, data_dict = prepare_data_for_cp( + data_dict, processed_inputs, self.cp_mesh, sequence_dim + ) + logits = redistribute_logits_for_cp( + logits, self.device_mesh, self.cp_mesh, sequence_dim + ) + + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=self.loss_fn, + cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = self.loss_fn + + def _extract_teacher_ipc_payload(teacher_result_obj: dict[str, Any]) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if "microbatch_handles" not in teacher_result_obj: + return None, None + handles = teacher_result_obj["microbatch_handles"] + if self._microbatch_idx >= len(handles): + return None, None + rank = torch.distributed.get_rank() + current_device_id = torch.cuda.current_device() + handle = handles[self._microbatch_idx] + aB, aS, aK = handle["actual_shape"] + + # View into the teacher's pre-allocated IPC buffer — no clone. + # The teacher's per-microbatch buffer slot is only reused on the + # next step's teacher forward, which is strictly ordered AFTER + # this student forward+backward by the driver-side ray.get + # sequencing. Teacher tensor is detached and doesn't require + # grad, so autograd never captures it past the loss op. Avoids a + # full-vocab fp32-or-bf16 memcpy per microbatch (~3.3 GiB at + # mbs=2, V≈200K, bf16 — was the dominant peak after fix A). + teacher_logits_tensor = rebuild_cuda_tensor_from_ipc( + handle[rank], current_device_id + ).detach()[:aB, :aS, :aK] + + teacher_topk_indices_tensor = None + is_topk = teacher_result_obj.get("is_topk", False) + if is_topk and "topk_indices_ipc" in handle: + teacher_topk_indices_tensor = rebuild_cuda_tensor_from_ipc( + handle["topk_indices_ipc"], current_device_id + ).detach()[:aB, :aS, :aK] + return teacher_logits_tensor, teacher_topk_indices_tensor + + loss_kwargs = {} + if self._teacher_result is not None and not self.enable_seq_packing: + if isinstance(self._teacher_result, list): + teacher_logits_list: list[torch.Tensor] = [] + teacher_topk_indices_list: list[Optional[torch.Tensor]] = [] + for teacher_result_obj in self._teacher_result: + t_logits, t_topk = _extract_teacher_ipc_payload(teacher_result_obj) + if t_logits is None: + continue + teacher_logits_list.append(t_logits) + teacher_topk_indices_list.append(t_topk) + if teacher_logits_list: + loss_kwargs["teacher_logits_list"] = teacher_logits_list + loss_kwargs["teacher_topk_indices_list"] = teacher_topk_indices_list + else: + t_logits, t_topk = _extract_teacher_ipc_payload(self._teacher_result) + if t_logits is not None: + loss_kwargs["teacher_logits"] = t_logits + if t_topk is not None: + loss_kwargs["teacher_topk_indices_ipc"] = t_topk + + loss, loss_metrics = loss_fn_( + logits, + data_dict, + global_valid_seqs, + global_valid_toks, + **loss_kwargs, + ) + return loss, loss_metrics + + +class XTokenTeacherIPCExportPostProcessor(LossPostProcessor): + """Teacher-side post-processor that exports per-microbatch logits via CUDA IPC.""" + + def __init__( + self, + *args: Any, + tp_mesh: Any, + topk_logits: Optional[int], + is_mdlm: bool = False, + **kwargs: Any, + ): + # Keep explicit tp_mesh for local use and also forward it to base init. + super().__init__(*args, tp_mesh=tp_mesh, **kwargs) + self.tp_mesh = tp_mesh + self.topk_logits = topk_logits + self.is_mdlm = is_mdlm + self.microbatch_handles: list[dict[str, Any]] = [] + self._microbatch_idx = 0 + self._mb_vals_buffers: list[torch.Tensor] = [] + self._mb_vals_ipcs: list[tuple[Any]] = [] + self._mb_idx_buffers: list[torch.Tensor] = [] + self._mb_idx_ipcs: list[tuple[Any]] = [] + self._mb_logits_buffers: list[torch.Tensor] = [] + self._mb_logits_ipcs: list[tuple[Any]] = [] + + def set_microbatch_index(self, mb_idx: int) -> None: + self._microbatch_idx = mb_idx + + def _ensure_topk_buffer( + self, + buf_idx: int, + B: int, + S: int, + K: int, + vals_dtype: torch.dtype, + idx_dtype: torch.dtype, + device: torch.device, + ) -> None: + while len(self._mb_vals_buffers) <= buf_idx: + vals_buf = torch.empty((B, S, K), dtype=vals_dtype, device=device) + idx_buf = torch.empty((B, S, K), dtype=idx_dtype, device=device) + self._mb_vals_buffers.append(vals_buf) + self._mb_vals_ipcs.append(get_handle_from_tensor(vals_buf)) + self._mb_idx_buffers.append(idx_buf) + self._mb_idx_ipcs.append(get_handle_from_tensor(idx_buf)) + vals_buf = self._mb_vals_buffers[buf_idx] + idx_buf = self._mb_idx_buffers[buf_idx] + needs_realloc = ( + vals_buf.shape[0] < B + or vals_buf.shape[1] < S + or vals_buf.shape[2] < K + or vals_buf.dtype != vals_dtype + or vals_buf.device != device + or idx_buf.shape[0] < B + or idx_buf.shape[1] < S + or idx_buf.shape[2] < K + or idx_buf.dtype != idx_dtype + or idx_buf.device != device + ) + if needs_realloc: + vals_buf = torch.empty((B, S, K), dtype=vals_dtype, device=device) + idx_buf = torch.empty((B, S, K), dtype=idx_dtype, device=device) + self._mb_vals_buffers[buf_idx] = vals_buf + self._mb_vals_ipcs[buf_idx] = get_handle_from_tensor(vals_buf) + self._mb_idx_buffers[buf_idx] = idx_buf + self._mb_idx_ipcs[buf_idx] = get_handle_from_tensor(idx_buf) + + def _ensure_logits_buffer( + self, + buf_idx: int, + B: int, + S: int, + V: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + while len(self._mb_logits_buffers) <= buf_idx: + buf = torch.empty((B, S, V), dtype=dtype, device=device) + self._mb_logits_buffers.append(buf) + self._mb_logits_ipcs.append(get_handle_from_tensor(buf)) + buf = self._mb_logits_buffers[buf_idx] + needs_realloc = ( + buf.shape[0] < B + or buf.shape[1] < S + or buf.shape[2] < V + or buf.dtype != dtype + or buf.device != device + ) + if needs_realloc: + buf = torch.empty((B, S, V), dtype=dtype, device=device) + self._mb_logits_buffers[buf_idx] = buf + self._mb_logits_ipcs[buf_idx] = get_handle_from_tensor(buf) + + def __call__( + self, + logits: torch.Tensor, + data_dict: BatchedDataDict[Any], # noqa: ARG002 + processed_inputs: ProcessedInputs, # noqa: ARG002 + global_valid_seqs: torch.Tensor, # noqa: ARG002 + global_valid_toks: torch.Tensor, # noqa: ARG002 + sequence_dim: int = 1, # noqa: ARG002 + ) -> tuple[torch.Tensor, dict[str, Any]]: + if isinstance(logits, DTensor): + mb_logits_local = logits.to_local() + else: + mb_logits_local = logits + + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(mb_logits_local.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + if self.topk_logits is not None: + # Top-k path (same-tokenizer distillation): the student loss + # consumes IPC payload entries as log-probs (it `.exp()`s them + # to recover probs), so do the fp32 cast + distributed + # log-softmax here before extracting the top-k. + mb_logits_fp32 = mb_logits_local.to(torch.float32) + mb_log_prob = _compute_distributed_log_softmax( + mb_logits_fp32, group=tp_group + ) + del mb_logits_fp32, mb_logits_local + + if isinstance(mb_log_prob, DTensor): + mb_log_prob = mb_log_prob.to_local() + + if self.is_mdlm: + shared_seq_len = int(mb_log_prob.shape[1] / 2) + mb_log_prob = mb_log_prob[:, shared_seq_len:, :] + + mb_topk_vals, mb_topk_idx = distributed_vocab_topk( + mb_log_prob, + k=self.topk_logits, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del mb_log_prob + + B_mb, S_mb, K_mb = mb_topk_vals.shape + buf_idx = self._microbatch_idx + self._ensure_topk_buffer( + buf_idx, + B_mb, + S_mb, + K_mb, + mb_topk_vals.dtype, + mb_topk_idx.dtype, + mb_topk_vals.device, + ) + self._mb_vals_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_vals) + self._mb_idx_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_idx) + del mb_topk_vals, mb_topk_idx + + rank = torch.distributed.get_rank() + handle = { + rank: self._mb_vals_ipcs[buf_idx], + "actual_shape": (B_mb, S_mb, K_mb), + "topk_indices_ipc": self._mb_idx_ipcs[buf_idx], + } + else: + # Full-vocab path (cross-tokenizer distillation): ship raw logits + # in their native dtype (typically bf16). The student-side loss + # applies log_softmax exactly once, on raw logits, matching the + # reference train_distillation_ddp.py. This avoids the teacher- + # side fp32 cast + full-vocab log_softmax kernel and removes the + # redundant student-side log_softmax of already-log-softmaxed + # values (idempotent at T=1, mathematically wrong at T!=1). + if self.is_mdlm: + shared_seq_len = int(mb_logits_local.shape[1] / 2) + mb_logits_local = mb_logits_local[:, shared_seq_len:, :] + + B_mb, S_mb, V_mb = mb_logits_local.shape + buf_idx = self._microbatch_idx + self._ensure_logits_buffer( + buf_idx, + B_mb, + S_mb, + V_mb, + mb_logits_local.dtype, + mb_logits_local.device, + ) + self._mb_logits_buffers[buf_idx][:B_mb, :S_mb, :V_mb].copy_(mb_logits_local) + del mb_logits_local + + rank = torch.distributed.get_rank() + handle = { + rank: self._mb_logits_ipcs[buf_idx], + "actual_shape": (B_mb, S_mb, V_mb), + } + + self.microbatch_handles.append(handle) + + dummy_loss = torch.zeros((), device="cuda", dtype=torch.float32) + return dummy_loss, {"num_valid_samples": 1.0} + + class LogprobsPostProcessor: """Post-processor for computing log probabilities from model outputs."""