diff --git a/examples/configs/off_policy_distillation.yaml b/examples/configs/off_policy_distillation.yaml new file mode 100644 index 0000000000..96a475d9d7 --- /dev/null +++ b/examples/configs/off_policy_distillation.yaml @@ -0,0 +1,310 @@ +token_aligner: + enabled: false + +teachers: + # Teacher 0: Phi-4-mini-instruct (weight 0.5) + # Mirrors tokenalign/teacher_configs/multi_teacher_config_phi-4B_llama-3.1-4b_best-proj_not_learned.json + - weight: 0.5 + teacher: + model_name: "microsoft/Phi-4-mini-instruct" + tokenizer: + name: "microsoft/Phi-4-mini-instruct" + chat_template: null + precision: "bfloat16" + train_global_batch_size: 768 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + max_grad_norm: 1.0 + logprob_chunk_size: null + offload_optimizer_for_logprob: false + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + generation: null + token_aligner: + enabled: true + projection_matrix_path: "cross_tokenizer_data/llama_phi-mini_best_special_exact_map_remapped.pt" + use_sparse_format: false + loss_type: "KL" + exact_token_match_only: false + temperature: 1.0 + vocab_topk: 8192 + reverse_kl: false + projection_matrix_multiplier: 1.0 + max_comb_len: 4 + learnable: false + project_teacher_to_student: false + use_char_offset: false + use_cuda_dp: false + dp_chunk_size: 128 + + # Teacher 1: Llama-3.2-3B (weight 0.5) + # Same tokenizer family as the Llama-3.2-1B student; uses the identity/exact-map + # remapped projection so the cross-tokenizer path reduces to exact-token alignment. + - weight: 0.5 + teacher: + model_name: "meta-llama/Llama-3.2-3B" + tokenizer: + name: "meta-llama/Llama-3.2-3B" + chat_template: null + precision: "bfloat16" + train_global_batch_size: 768 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + max_grad_norm: 1.0 + logprob_chunk_size: null + offload_optimizer_for_logprob: false + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + generation: null + # No token_aligner: same tokenizer as student → _compute_same_tokenizer_kl path + +distillation: + num_prompts_per_step: 768 + max_num_steps: 80000 + max_num_epochs: 1 + val_period: 1000 + val_at_start: false + val_micro_batch_size: 1 + topk_logits_k: 8192 + use_ipc: true + seed: 42 + # Skip per-step model/optimizer offloads in the off-policy distillation loop. + # Requires student + all teachers + student optimizer state to fit resident on + # each GPU. With Llama-3.2-1B student + Phi-4-mini + Llama-3.2-3B teachers on + # 80GB cards this fits comfortably (~35-45GB peak). + keep_models_resident: false + +loss_fn: + loss_type: "KL" + temperature: 1.0 + vocab_topk: 8192 + exact_token_match_only: false + reverse_kl: false + project_teacher_to_student: false + gold_loss: true + xtoken_loss: true + ce_loss_scale: 0.1 + dynamic_loss_scaling: true + normalize_by_vocab: true + teacher_aggregation_mode: "weighted" + +checkpointing: + enabled: true + checkpoint_dir: "checkpoints/multi-teacher-distillation-llama1b" + metric_name: "train:loss" + higher_is_better: false + keep_top_k: 3 + save_period: 10 + save_optimizer: true + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: "meta-llama/Llama-3.2-1B" + chat_template: null + train_global_batch_size: 768 + train_micro_batch_size: 1 + max_total_sequence_length: 4096 + precision: "bfloat16" + offload_optimizer_for_logprob: false + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + max_grad_norm: 1.0 + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + # Matches PyTorch reference: 5% linear warmup + cosine decay to min_lr=0. + # Tuned for a 1000-step production run (warmup=50, cosine_T_max=950). + # For different total steps, scale both counts ∝ distillation.max_num_steps. + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + # Recent PyTorch enforces 0 < start_factor <= 1 (older versions allowed 0.0). + # 1e-8 is effectively zero-warmup but satisfies the constraint. + start_factor: 1.0e-8 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 950 + eta_min: 0.0 + - milestones: [50] + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + generation: null + +teacher: + model_name: "microsoft/Phi-4-mini-instruct" + +data: + max_input_seq_length: 4096 + shuffle: true + # DataLoader workers that run teacher tokenize + DP alignment in parallel + # via CrossTokenizerCollator. 8 workers × 4 prefetch = up to 32 batches in + # flight, fully hiding CT behind teacher forward. + num_workers: 8 + prefetch_factor: 4 + train: + dataset_name: "arrow_text" + processor: "kd_data_processor" + arrow_files: null # Override via Hydra CLI: data.train.arrow_files=/path/to/file.arrow + prompt_file: null + characters_per_sample: 32768 # 4096 tokens × 8 chars/token (lazy packing) + default: + dataset_path: "allenai/c4" + hf_dataset_name: "allenai/c4" + hf_dataset_subset: "en" + hf_split: "train" + text_key: "text" + +eval: + val_period: 50 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + mmlu_5shot: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + num_few_shot: 5 + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + mbpp_plus: + dataset_name: "mbpp_plus" + # Optional override: + # dataset_path: "evalplus/mbppplus" + split: "test" + env: + num_workers: 8 + timeout_seconds: 5 + humaneval_plus: + dataset_name: "humaneval_plus" + # Optional override: + # dataset_path: "evalplus/humanevalplus" + split: "test" + env: + num_workers: 8 + timeout_seconds: 5 + +logger: + log_dir: "logs/multi-teacher-distillation-llama1b" + num_val_samples_to_print: 5 + wandb_enabled: true + swanlab_enabled: false + mlflow_enabled: false + tensorboard_enabled: false + monitor_gpus: true + wandb: + project: "nemo-multi-teacher-distillation" + name: "multi-teacher-llama1b" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 16 diff --git a/examples/run_off_policy_distillation.py b/examples/run_off_policy_distillation.py new file mode 100644 index 0000000000..62dc707261 --- /dev/null +++ b/examples/run_off_policy_distillation.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run cross-tokenizer off-policy distillation. + +Mirrors examples/run_distillation.py: load a YAML config, build dataset and +policies, then call off_policy_distillation_train. Off-policy means student +training consumes fixed (prompt, response) pairs from the dataset rather +than newly generated rollouts; teacher logits are computed once per batch +and shipped to the student via CUDA IPC. +""" + +import argparse +import os + +from omegaconf import OmegaConf + +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyMasterConfig, + off_policy_distillation_train, + setup, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data.utils import setup_response_data +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import ( + load_config, + parse_hydra_overrides, + register_omegaconf_resolvers, +) +from nemo_rl.utils.logger import get_next_experiment_dir + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run off-policy distillation training with configuration" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def main() -> None: + """Main entry point.""" + register_omegaconf_resolvers() + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "off_policy_distillation.yaml" + ) + + config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + + config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True) + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + env_configs = config.get("env") or None + if env_configs is None: + dataset, val_dataset = setup_response_data( + tokenizer, config["data"], None + ) + else: + ( + dataset, + val_dataset, + _task_to_env, + _val_task_to_env, + ) = setup_response_data(tokenizer, config["data"], env_configs) + + ( + student_policy, + teacher_policies, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + token_aligners, + teacher_tokenizers, + ) = setup(config, tokenizer, dataset, val_dataset) + + off_policy_distillation_train( + student_policy, + teacher_policies, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + token_aligners=token_aligners, + teacher_tokenizers=teacher_tokenizers, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index df6ff6bc54..3a1cbdd671 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -1035,3 +1035,973 @@ def __call__( } return kl_loss, metrics + + +class CrossTokenizerDistillationLossConfig(TypedDict): + """Configuration for cross-tokenizer distillation loss.""" + loss_type: str # 'KL', 'cross_entropy', or 'chunked_ce' + temperature: float # Softmax temperature + vocab_topk: int # Reduce teacher vocab to top-k (0 = all) + exact_token_match_only: bool # Only use 1:1 aligned positions + reverse_kl: bool # Reverse KL direction + project_teacher_to_student: NotRequired[bool] + gold_loss: NotRequired[bool] # Use gold loss (common KL + uncommon L1, no projection) + xtoken_loss: NotRequired[bool] # Relaxed exact-map threshold (>=0.6 instead of ==1.0) + ce_loss_scale: NotRequired[float] # Scale for additional CE (next-token) loss (0.0 = disabled) + dynamic_loss_scaling: NotRequired[bool] # Scale KL loss to match CE magnitude + + +class CrossTokenizerDistillationLossDataDict(TypedDict): + """Data dict for cross-tokenizer distillation. + + Only contains student-side tensors (same sequence dimension). + Teacher-side data (teacher_input_ids, aligned_pairs) is stored on the + loss function instance via set_cross_tokenizer_data() to avoid + sequence-length mismatches in the worker's shape validation. + """ + input_ids: torch.Tensor # Student token IDs (B, S_student) + input_lengths: torch.Tensor + token_mask: torch.Tensor # (B, S_student) + sample_mask: torch.Tensor # (B,) + + +def _scatter_chunk_mask_from_coo( + coo_list: list[torch.Tensor], + batch_size: int, + seq_len: int, + total_chunks: int, + device: torch.device, +) -> torch.Tensor: + """Materialize a dense ``(batch_size, seq_len, total_chunks)`` bool mask. + + ``coo_list[b]`` is a CPU ``LongTensor (N_b, 2)`` with rows ``[pos, chunk_id]`` + precomputed by ``CrossTokenizerCollator._build_chunk_coo`` (already + filtered for ``exact_match_only`` / ``-1`` sentinels / padded-length bounds). + + ``chunk_id`` in each sample's COO is in ``[0, num_chunks_per_sample)``; + samples with fewer chunks than ``total_chunks`` leave the padded columns + all-False, which the downstream ``chunk_valid = (chunk_sizes > 0)`` gate + already drops. + """ + parts: list[torch.Tensor] = [] + for b, coo in enumerate(coo_list): + if coo.shape[0] == 0: + continue + bcol = torch.full((coo.shape[0], 1), b, dtype=torch.int64) + parts.append(torch.cat([bcol, coo], dim=1)) + mask = torch.zeros( + batch_size, seq_len, total_chunks, dtype=torch.bool, device=device, + ) + if parts: + idx = torch.cat(parts, dim=0).to(device) + mask[idx[:, 0], idx[:, 1], idx[:, 2]] = True + return mask + + +class CrossTokenizerDistillationLossFn(LossFunction): + """Cross-tokenizer distillation loss using TokenAligner's projection matrix. + + Computes per-token KL divergence between projected student probabilities + (in teacher vocab space) and teacher probabilities, only at positions where + the two tokenizations have 1:1 aligned tokens. Uses NeMo RL's standard + masked_mean normalization so loss magnitude is comparable to same-tokenizer + distillation. + + Teacher-specific data (teacher_input_ids, aligned_pairs) is stored on + this object via set_cross_tokenizer_data() before each training step, + rather than in the data dict, because teacher and student sequences + have different lengths and the worker validates that all tensors in + the data dict share the same sequence dimension. + """ + + def __init__(self, cfg: CrossTokenizerDistillationLossConfig, token_aligner): + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + assert isinstance(token_aligner, TokenAligner) + self.token_aligner = token_aligner + self.cfg = cfg + self.loss_type = LossType.TOKEN_LEVEL + self._teacher_input_ids = None + self._aligned_pairs = None + self._chunk_indices: Optional[dict[str, list]] = None + + def set_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: list, + chunk_indices: Optional[dict[str, list]] = None, + ): + """Store teacher-side data before each training step. + + Called from the training loop before student_policy.train(). + The worker never sees these tensors in shape validation. + + ``chunk_indices`` carries the per-sample COO chunk-mask indices that + used to be rebuilt inside ``__call__`` every microbatch. When set, it + is a dict with keys ``student_chunk_coo``, ``teacher_chunk_coo``, + ``num_chunks``, each a DP-sharded list of length ``batch_size``. + """ + self._teacher_input_ids = teacher_input_ids + self._aligned_pairs = aligned_pairs + self._chunk_indices = chunk_indices + + def _project_student_to_teacher( + self, + student_logits: torch.Tensor, + teacher_vocab_size: int, + temperature: float, + global_top_indices: torch.Tensor, + device: torch.device, + precomputed_student_probs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Project student logits into the reduced teacher vocabulary space. + + Returns projected student probabilities of shape (B, S_student, K) + where K = len(global_top_indices). + + If `precomputed_student_probs` is provided, it is used directly instead + of recomputing softmax(student_logits / temperature). The caller is + responsible for ensuring it was computed with the same temperature. + """ + if precomputed_student_probs is not None: + student_probs = precomputed_student_probs + else: + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + has_sparse = ( + hasattr(self.token_aligner, 'sparse_transformation_matrix') + and self.token_aligner.sparse_transformation_matrix is not None + ) + if has_sparse: + sparse_mat = self.token_aligner.sparse_transformation_matrix + reduced_sparse = sparse_mat.index_select(1, global_top_indices).coalesce() + projected = self.token_aligner.project_token_likelihoods_instance( + student_probs, None, None, None, device, + use_sparse_format=True, + sparse_matrix=reduced_sparse, + ) + return projected + + proj_values = self.token_aligner.likelihood_projection_matrix + if getattr(self.token_aligner, 'learnable', False): + proj_values = self.token_aligner.transform_learned_matrix_instance(proj_values) + projected_full = self.token_aligner.project_token_likelihoods_instance( + student_probs, self.token_aligner.likelihood_projection_indices, + proj_values, teacher_vocab_size, device, + use_sparse_format=False, + ) + return projected_full[:, :, global_top_indices] + + def _compute_gold_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + student_chunk_mask: torch.Tensor, + teacher_chunk_mask: torch.Tensor, + batch_size: int, + student_seq_len: int, + teacher_seq_len: int, + teacher_vocab_size: int, + temperature: float, + reverse_kl: bool, + xtoken_loss: bool, + device: torch.device, + precomputed_student_log_probs: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, float]: + """Gold loss: common-vocab KL + uncommon-vocab sorted L1. + + Splits the vocabulary into tokens with exact 1:1 projection mappings + ("common") and the rest ("uncommon"). Common tokens are compared + directly via KL on their native log-probs (no projection needed). + Uncommon tokens are compared via L1 on sorted probability vectors + (Universal Likelihood Distillation). + + Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + """ + partition = self.token_aligner.build_vocab_partition( + xtoken_loss=xtoken_loss, + teacher_vocab_size=teacher_vocab_size, + ) + common_student_indices = partition.common_student_indices.to(device) + common_teacher_indices = partition.common_teacher_indices.to(device) + uncommon_student_indices = partition.uncommon_student_indices.to(device) + uncommon_teacher_indices = partition.uncommon_teacher_indices.to(device) + + # student_chunk_mask / teacher_chunk_mask are precomputed by the + # caller (from collator-emitted per-sample COO) and shared with the + # non-gold path — shape (B, seq_len, total_chunks), dtype bool. + total_chunks = student_chunk_mask.shape[-1] + + # log_softmax on full original logits BEFORE chunk averaging + if precomputed_student_log_probs is not None: + student_log_probs = precomputed_student_log_probs + else: + student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1) + + # Chunk-average log-probs over full vocabularies + student_chunk_lp = torch.bmm( + student_chunk_mask.transpose(1, 2).to(student_log_probs.dtype), student_log_probs, + ) + teacher_chunk_lp = torch.bmm( + teacher_chunk_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del student_log_probs, teacher_log_probs + + student_chunk_sizes = student_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + teacher_chunk_sizes = teacher_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + + student_chunk_lp = student_chunk_lp / (student_chunk_sizes + 1e-10) + teacher_chunk_lp = teacher_chunk_lp / (teacher_chunk_sizes + 1e-10) + + chunk_valid = (student_chunk_sizes.squeeze(-1) > 0) & (teacher_chunk_sizes.squeeze(-1) > 0) + + if not chunk_valid.any(): + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + # --- Part 1: KL on common (exactly-mapped) vocab --- + loss_kl_common = torch.tensor(0.0, device=device, requires_grad=True) + if common_student_indices.numel() > 0: + s_common = student_chunk_lp[:, :, common_student_indices] + t_common = teacher_chunk_lp[:, :, common_teacher_indices] + + if not reverse_kl: + kl_elem = torch.nn.functional.kl_div( + s_common, t_common, reduction="none", log_target=True, + ) + else: + kl_elem = torch.nn.functional.kl_div( + t_common, s_common, reduction="none", log_target=True, + ) + kl_per_chunk = kl_elem.sum(dim=-1) * chunk_valid + if chunk_valid.sum() > 0: + loss_kl_common = kl_per_chunk.sum() / chunk_valid.sum() + + # --- Part 2: L1 on uncommon (unaligned) vocab --- + loss_l1_uncommon = torch.tensor(0.0, device=device, requires_grad=True) + if uncommon_student_indices.numel() > 0 or uncommon_teacher_indices.numel() > 0: + s_uncommon = ( + student_chunk_lp[:, :, uncommon_student_indices] + if uncommon_student_indices.numel() > 0 + else torch.empty(batch_size, total_chunks, 0, device=device) + ) + t_uncommon = ( + teacher_chunk_lp[:, :, uncommon_teacher_indices] + if uncommon_teacher_indices.numel() > 0 + else torch.empty(batch_size, total_chunks, 0, device=device) + ) + + s_valid = s_uncommon[chunk_valid] + t_valid = t_uncommon[chunk_valid] + + if s_valid.shape[0] > 0: + with torch.no_grad(): + max_uncommon_vocab = min(s_valid.shape[-1], t_valid.shape[-1], 8192) + + if max_uncommon_vocab > 0: + s_probs = torch.exp(s_valid) + t_probs = torch.exp(t_valid) + + if s_probs.shape[-1] > max_uncommon_vocab: + s_sorted, _ = torch.topk(s_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + s_sorted = torch.sort(s_probs, dim=-1, descending=True)[0] + + if t_probs.shape[-1] > max_uncommon_vocab: + t_sorted, _ = torch.topk(t_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + t_sorted = torch.sort(t_probs, dim=-1, descending=True)[0] + + del s_probs, t_probs + min_len = min(s_sorted.shape[-1], t_sorted.shape[-1]) + if min_len > 0: + loss_l1_per_chunk = torch.nn.functional.l1_loss( + s_sorted[:, :min_len], t_sorted[:, :min_len], reduction='none', + ).sum(dim=-1) + loss_l1_uncommon = loss_l1_per_chunk.mean() + del loss_l1_per_chunk + del s_sorted, t_sorted + + loss_total = (loss_kl_common + loss_l1_uncommon) * (temperature ** 2) + + # Top-1 accuracy on common vocab. Computed as a 0-d GPU tensor so the + # caller can include it in the batched .cpu().tolist() sync at the end + # of the step (instead of forcing two CPU<->GPU stalls per teacher per + # microbatch here). + top1_accuracy: Union[float, torch.Tensor] = 0.0 + with torch.no_grad(): + if common_student_indices.numel() > 0 and chunk_valid.any(): + s_valid_lp = student_chunk_lp[chunk_valid][:, common_student_indices] + t_valid_lp = teacher_chunk_lp[chunk_valid][:, common_teacher_indices] + matches = ( + s_valid_lp.argmax(dim=-1) == t_valid_lp.argmax(dim=-1) + ).sum() + denom = chunk_valid.sum().clamp(min=1) + top1_accuracy = matches.to(torch.float32) / denom.to(torch.float32) + + del student_chunk_lp, teacher_chunk_lp + return loss_total, top1_accuracy + + def __call__( + self, + next_token_logits: torch.Tensor, + data: CrossTokenizerDistillationLossDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional[torch.Tensor] = None, + mb_idx: Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, + _return_raw_kl: bool = False, + precomputed_student_logits_f32: Optional[torch.Tensor] = None, + precomputed_student_probs: Optional[torch.Tensor] = None, + precomputed_student_log_probs: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute cross-tokenizer distillation loss via chunk-averaged KL. + + For each alignment chunk (1:1, 1:many, many:1, or many:many), the + projected student and teacher distributions are averaged over their + respective spans, renormalized, and compared via KL divergence. + The per-chunk KL is then distributed back to student positions + and normalized with the standard NeMo RL masked_mean. + + The three ``precomputed_*`` kwargs let an outer aggregator hoist + student-side work that does not depend on the teacher (fp32 cast, + softmax, log_softmax) out of the per-teacher loop and share it + across multiple teachers with the same temperature. + """ + input_ids_student = data["input_ids"] + batch_size = input_ids_student.shape[0] + + # Keep logits in their native dtype (typically bf16). The downstream + # log_softmax / softmax / bmm ops on CUDA upcast to fp32 internally for + # numerics while storing activations in bf16, which roughly halves the + # working-set memory of the gold-loss / projection paths. + if precomputed_student_logits_f32 is not None: + student_logits = precomputed_student_logits_f32 + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + student_logits = next_token_logits.full_tensor() + else: + student_logits = next_token_logits + + if teacher_logits is None: + raise ValueError( + "CrossTokenizerDistillationLossFn requires teacher_logits via IPC. " + "Set use_ipc=True in the distillation config." + ) + if self._aligned_pairs is None or self._teacher_input_ids is None: + raise ValueError( + "Cross-tokenizer data not set. " + "Call loss_fn.set_cross_tokenizer_data() before training." + ) + + # IPC contract (XTokenTeacherIPCExportPostProcessor, full-vocab branch): + # raw logits in their native dtype (bf16 in practice). The log_softmax + # calls below in _compute_gold_loss / the projection path are the only + # ones in the teacher pipeline. Variable name kept for diff stability. + if isinstance(teacher_logits, torch.distributed.tensor.DTensor): + teacher_logits_f32 = teacher_logits.full_tensor() + else: + teacher_logits_f32 = teacher_logits + + if teacher_logits_f32.shape[-1] == 0: + raise ValueError( + f"Teacher logits have vocab dimension 0 (shape={teacher_logits_f32.shape}). " + "This typically means topk_logits=0 was passed instead of None " + "for the teacher forward pass. Cross-tokenizer distillation " + "requires full teacher logits (topk_logits=None)." + ) + + if mb_idx is not None and mbs is not None: + mb_start = mb_idx * mbs + mb_end = mb_start + batch_size + else: + mb_start = 0 + mb_end = batch_size + + self.token_aligner = self.token_aligner.to(student_logits.device) + device = student_logits.device + + temperature = self.cfg.get("temperature", 1.0) + vocab_topk = self.cfg.get("vocab_topk", 8192) + reverse_kl = self.cfg.get("reverse_kl", False) + use_gold_loss = self.cfg.get("gold_loss", False) + use_xtoken_loss = self.cfg.get("xtoken_loss", False) + student_seq_len = student_logits.shape[1] + teacher_seq_len = teacher_logits_f32.shape[1] + teacher_vocab_size = teacher_logits_f32.shape[-1] + + # The collator pre-applies exact_match_only, sentinel (-1), and + # padded-length bounds filters, and emits the chunk mask in COO form + # per sample. The loss fn just slices the MB range and scatters. + if self._chunk_indices is None: + raise ValueError( + "CrossTokenizerDistillationLossFn requires chunk_indices. " + "CrossTokenizerCollator should have precomputed them; verify " + "the training loop forwards them through update_cross_tokenizer_data()." + ) + student_coo_list = self._chunk_indices["student_chunk_coo"][mb_start:mb_end] + teacher_coo_list = self._chunk_indices["teacher_chunk_coo"][mb_start:mb_end] + num_chunks_list = self._chunk_indices["num_chunks"][mb_start:mb_end] + total_chunks = max(num_chunks_list) if num_chunks_list else 0 + + if total_chunks == 0: + loss = torch.tensor(0.0, device=device, requires_grad=True) + return loss, {"loss": 0.0, "topk_accuracy": 0.0, "num_chunks": 0} + + proj_mask = _scatter_chunk_mask_from_coo( + student_coo_list, batch_size, student_seq_len, total_chunks, device, + ) + tgt_mask = _scatter_chunk_mask_from_coo( + teacher_coo_list, batch_size, teacher_seq_len, total_chunks, device, + ) + num_valid_chunks_total = int(sum(num_chunks_list)) + + # ================================================================ + # Gold loss path: common-vocab KL + uncommon-vocab sorted L1. + # Bypasses the projection matrix for tokens with exact 1:1 mappings. + # Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + # ================================================================ + if use_gold_loss: + loss, top1_accuracy = self._compute_gold_loss( + student_logits, teacher_logits_f32, proj_mask, tgt_mask, + batch_size, student_seq_len, teacher_seq_len, + teacher_vocab_size, + temperature, reverse_kl, use_xtoken_loss, device, + precomputed_student_log_probs=precomputed_student_log_probs, + ) + else: + # ================================================================ + # Standard projection-based path + # ================================================================ + + # -- 3. Global vocabulary filtering (top-k teacher tokens) -- + with torch.no_grad(): + if vocab_topk == 0 or vocab_topk >= teacher_vocab_size: + global_top_indices = torch.arange(teacher_vocab_size, device=device) + else: + teacher_flat = teacher_logits_f32.view(-1, teacher_vocab_size) + importance = teacher_flat.max(dim=0)[0] + _, global_top_indices = torch.topk( + importance, k=min(vocab_topk, teacher_vocab_size), dim=-1, + ) + global_top_indices = global_top_indices.sort()[0] + + # -- 4. Project student probs to teacher vocab -- + projected_student = self._project_student_to_teacher( + student_logits, teacher_vocab_size, temperature, global_top_indices, device, + precomputed_student_probs=precomputed_student_probs, + ) + + # -- 5. Teacher log-probs in reduced vocab -- + teacher_logits_reduced = teacher_logits_f32[:, :, global_top_indices] + teacher_log_probs = torch.log_softmax(teacher_logits_reduced / temperature, dim=-1) + del teacher_logits_reduced + + # -- 6. Chunk-averaged distributions -- + proj_chunks = torch.bmm( + proj_mask.transpose(1, 2).to(projected_student.dtype), projected_student, + ) + tgt_log_chunks = torch.bmm( + tgt_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del projected_student, teacher_log_probs + + proj_sizes = proj_mask.sum(dim=1).unsqueeze(-1).to(proj_chunks.dtype) + tgt_sizes = tgt_mask.sum(dim=1).unsqueeze(-1).to(tgt_log_chunks.dtype) + + proj_chunks = proj_chunks / (proj_sizes + 1e-10) + tgt_log_chunks = tgt_log_chunks / (tgt_sizes + 1e-10) + + proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + 1e-10) + proj_log_chunks = torch.log(proj_chunks + 1e-10) + + chunk_valid = (proj_sizes.squeeze(-1) > 0) & (tgt_sizes.squeeze(-1) > 0) + + # -- 7. KL divergence per chunk -- + if reverse_kl: + kl_per_elem = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True, + ) + else: + kl_per_elem = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True, + ) + kl_per_chunk = kl_per_elem.sum(dim=-1) * (temperature ** 2) + kl_per_chunk = kl_per_chunk * chunk_valid + del proj_chunks, tgt_log_chunks, proj_log_chunks, kl_per_elem + + # -- 8. Scalar loss -- + num_valid_chunks = chunk_valid.sum() + if num_valid_chunks > 0: + loss = kl_per_chunk.sum() / num_valid_chunks + else: + loss = torch.tensor(0.0, device=device, requires_grad=True) + top1_accuracy = 0.0 + + # Keep a stable alias for raw-KL return path and optional CE fusion path. + kl_loss = loss + if _return_raw_kl: + # Return scalar metrics as GPU tensors so the outer aggregator can + # batch the CPU sync for all teachers / metrics into a single + # .cpu().tolist() call. The aggregator unwraps tensor entries. + raw_metrics: dict[str, Any] = { + "loss": kl_loss if isinstance(kl_loss, torch.Tensor) else kl_loss, + "kl_loss": kl_loss if isinstance(kl_loss, torch.Tensor) else kl_loss, + "topk_accuracy": top1_accuracy, + "num_valid_samples": int(batch_size), + "num_chunks": num_valid_chunks_total, + } + return kl_loss, raw_metrics + + # ================================================================ + # Optional CE (next-token prediction) loss, matching the DDP + # train_distillation_ddp.py logic: + # without dynamic scaling: loss = kl * kl_weight + ce * ce_scale + # with dynamic scaling: loss = kl * (ce/kl) + ce + # ================================================================ + ce_loss_scale = self.cfg.get("ce_loss_scale", 0.0) + dynamic_loss_scaling = self.cfg.get("dynamic_loss_scaling", False) + ce_loss_value = 0.0 + + if ce_loss_scale > 0.0 or dynamic_loss_scaling: + # Mask padding positions so CE loss only covers real tokens. + # token_mask[:, 1:] marks valid next-token targets (shifted by 1). + token_mask = data["token_mask"] + ce_mask = token_mask[:, 1 : student_seq_len].to(torch.bool) + ce_targets = input_ids_student[:, 1:student_seq_len].clone() + ce_targets[~ce_mask] = -100 + ce_loss = torch.nn.functional.cross_entropy( + student_logits[:, :student_seq_len - 1].reshape(-1, student_logits.shape[-1]), + ce_targets.reshape(-1), + ignore_index=-100, + ) + ce_loss_value = float(ce_loss.item()) + + if dynamic_loss_scaling and kl_loss.item() > 0: + dls_scale = ce_loss.item() / kl_loss.item() + loss = kl_loss * dls_scale + ce_loss + else: + loss = kl_loss + ce_loss * ce_loss_scale + + # Scale for NeMo RL distributed training + token_mask = data["token_mask"] + sample_mask = data["sample_mask"] + max_len = min(token_mask.shape[1] - 1, student_seq_len) + local_mask = token_mask[:, 1 : max_len + 1] * sample_mask.unsqueeze(-1) + local_valid_toks = local_mask.sum() + + if local_valid_toks > 0 and global_valid_toks > 0: + tok_scale = float(local_valid_toks / global_valid_toks) + loss = loss * tok_scale + else: + tok_scale = 0.0 + loss = loss * 0.0 + + num_valid = num_valid_chunks_total + metrics = { + "loss": float(loss.item()) if loss.ndim == 0 else loss, + "kl_loss": float(kl_loss.item()) * tok_scale, + "ce_loss": ce_loss_value * tok_scale, + "topk_accuracy": top1_accuracy, + "num_valid_samples": int(batch_size), + "num_chunks": num_valid, + "alignment_density": num_valid / max(1, batch_size * student_seq_len), + } + + return loss, metrics + + +class MultiTeacherLossAggregator(LossFunction): + """Aggregate weighted losses from multiple teachers in a unified path.""" + + def __init__( + self, + loss_fns: list[Optional[CrossTokenizerDistillationLossFn]], + weights: list[float], + normalize_by_vocab: bool = False, + cfg: Optional[dict[str, Any]] = None, + ): + assert len(loss_fns) == len(weights), ( + f"loss_fns ({len(loss_fns)}) and weights ({len(weights)}) length mismatch" + ) + self.loss_fns = loss_fns + self.weights = weights + self.normalize_by_vocab = normalize_by_vocab + self.cfg = cfg or {} + self.teacher_aggregation_mode = self.cfg.get("teacher_aggregation_mode", "weighted") + if self.teacher_aggregation_mode not in {"weighted", "routing", "average"}: + raise ValueError( + "teacher_aggregation_mode must be one of {'weighted', 'routing', 'average'}, " + f"got '{self.teacher_aggregation_mode}'" + ) + self.loss_type = LossType.TOKEN_LEVEL + # Marks this loss fn as expecting student logits as its primary input. + # Required by `prepare_loss_input` so the unified single-/multi-teacher + # cross-tokenizer path can flow through this aggregator. + self.input_type = LossInputType.LOGIT + + def set_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: list, + teacher_idx: Optional[int] = None, + chunk_indices: Optional[dict[str, list]] = None, + ) -> None: + # When called from the single-teacher dispatch (no explicit teacher_idx), + # default to the only teacher slot so the unified worker path keeps + # working without per-call branching upstream. + if teacher_idx is None: + if len(self.loss_fns) != 1: + raise ValueError( + "set_cross_tokenizer_data requires teacher_idx when " + f"len(loss_fns) > 1 (got {len(self.loss_fns)})" + ) + teacher_idx = 0 + fn = self.loss_fns[teacher_idx] + if fn is not None: + fn.set_cross_tokenizer_data( + teacher_input_ids, aligned_pairs, chunk_indices=chunk_indices, + ) + + def _compute_same_tokenizer_kl( + self, + next_token_logits: torch.Tensor, + teacher_logits: torch.Tensor, + data: CrossTokenizerDistillationLossDataDict, + teacher_topk_indices_ipc: Optional[torch.Tensor], + ) -> torch.Tensor: + student_logits = next_token_logits.to(torch.float32) + t_logits = teacher_logits.to(student_logits.device, dtype=torch.float32) + + seq_len = student_logits.shape[1] - 1 + student_shifted = student_logits[:, :-1] + + if teacher_topk_indices_ipc is None: + teacher_logprobs = torch.nn.functional.log_softmax(t_logits[:, :seq_len], dim=-1) + student_logprobs = torch.nn.functional.log_softmax(student_shifted, dim=-1) + per_token_kl = ( + teacher_logprobs.exp() * (teacher_logprobs - student_logprobs) + ).sum(dim=-1) + else: + topk_idx = teacher_topk_indices_ipc[:, :seq_len].to(student_shifted.device) + teacher_topk = t_logits[:, :seq_len] + student_logprobs = torch.nn.functional.log_softmax(student_shifted, dim=-1) + student_topk = torch.gather(student_logprobs, dim=-1, index=topk_idx) + teacher_topk_probs = teacher_topk.exp() + teacher_rest = (1.0 - teacher_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + teacher_probs_full = torch.cat([teacher_topk_probs, teacher_rest], dim=-1) + teacher_logprobs_full = torch.cat([teacher_topk, teacher_rest.log()], dim=-1) + student_topk_probs = student_topk.exp() + student_rest = (1.0 - student_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + student_logprobs_full = torch.cat([student_topk, student_rest.log()], dim=-1) + per_token_kl = ( + teacher_probs_full * (teacher_logprobs_full - student_logprobs_full) + ).sum(dim=-1) + + token_mask = data["token_mask"][:, 1 : seq_len + 1] + sample_mask = data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + valid_toks = mask.sum().clamp(min=1.0) + return (per_token_kl * mask).sum() / valid_toks + + def __call__( + self, + next_token_logits: Optional[torch.Tensor] = None, + data: Optional[CrossTokenizerDistillationLossDataDict] = None, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional[torch.Tensor] = None, + mb_idx: Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, + teacher_logits_list: Optional[list[torch.Tensor]] = None, + teacher_topk_indices_list: Optional[list[Optional[torch.Tensor]]] = None, + teacher_routing_indices: Optional[torch.Tensor] = None, + logits: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + # Accept `logits=` as an alias for `next_token_logits=` so this aggregator + # is a drop-in for `prepare_loss_input`, which emits {"logits": ...} for + # LossInputType.LOGIT. This removes the need for a separate Compat shim. + if next_token_logits is None: + next_token_logits = logits + if next_token_logits is None: + raise ValueError( + "MultiTeacherLossAggregator requires either `next_token_logits` " + "or `logits` to be provided." + ) + + if teacher_logits_list is None: + teacher_logits_list = [teacher_logits] if teacher_logits is not None else [] + if teacher_topk_indices_list is None: + teacher_topk_indices_list = [teacher_topk_indices_ipc] * len(teacher_logits_list) + + if len(teacher_logits_list) == 0: + zero = torch.tensor(0.0, device=next_token_logits.device, requires_grad=True) + return zero, {"loss": 0.0, "num_valid_samples": 0} + + if len(teacher_logits_list) != len(self.loss_fns): + raise ValueError( + "teacher_logits_list length must match number of configured teachers: " + f"{len(teacher_logits_list)} != {len(self.loss_fns)}" + ) + if len(teacher_topk_indices_list) != len(teacher_logits_list): + raise ValueError( + "teacher_topk_indices_list length must match teacher_logits_list length: " + f"{len(teacher_topk_indices_list)} != {len(teacher_logits_list)}" + ) + + vocab_sizes = [int(t.shape[-1]) for t in teacher_logits_list] + min_log_vocab = math.log(max(2, min(vocab_sizes))) if self.normalize_by_vocab else 1.0 + + total_kl = torch.tensor(0.0, device=next_token_logits.device, requires_grad=True) + metrics: dict[str, Any] = {} + # GPU-side scalar tensors accumulated during the loop, synced once + # (as a single batched D2H copy) at the end of the call. This keeps the + # CPU from stalling between teachers and between micro-steps, which is + # critical for letting the GPU pipeline kernels back-to-back. + gpu_scalar_metrics: dict[str, torch.Tensor] = {} + routing_indices = teacher_routing_indices + if routing_indices is None and isinstance(data, dict): + routing_indices = data.get("teacher_routing_indices", None) + if self.teacher_aggregation_mode == "routing": + if routing_indices is None: + raise ValueError( + "teacher_aggregation_mode='routing' requires teacher_routing_indices " + "either as an argument or in data['teacher_routing_indices']" + ) + if routing_indices.ndim != 1: + raise ValueError( + "teacher_routing_indices must be a rank-1 tensor with one teacher index per sample" + ) + if routing_indices.shape[0] != data["sample_mask"].shape[0]: + raise ValueError( + "teacher_routing_indices length must match batch size: " + f"{routing_indices.shape[0]} != {data['sample_mask'].shape[0]}" + ) + + original_sample_mask = data["sample_mask"] + active_teachers = sum(1 for t_logits in teacher_logits_list if t_logits is not None) + average_weight = 1.0 / max(1, active_teachers) + + # ===== Hoist student-only work out of the per-teacher loop ===== + # The fp32 cast and softmax/log_softmax of student_logits depend only + # on student_logits and (optionally) the per-teacher temperature. + # When two or more teachers share the same temperature and code path + # (gold_loss vs. projection), we can compute the corresponding + # student tensor exactly once and reuse it across those teachers. + per_teacher_share_keys: list[Optional[tuple[float, str]]] = [] + share_key_counts: dict[tuple[float, str], int] = {} + for teacher_idx, (loss_fn, t_logits) in enumerate( + zip(self.loss_fns, teacher_logits_list) + ): + if t_logits is None or loss_fn is None: + per_teacher_share_keys.append(None) + continue + cfg = getattr(loss_fn, "cfg", {}) or {} + temp = float(cfg.get("temperature", 1.0)) + kind = "log_probs" if cfg.get("gold_loss", False) else "probs" + key = (temp, kind) + per_teacher_share_keys.append(key) + share_key_counts[key] = share_key_counts.get(key, 0) + 1 + + has_sharing = any(c >= 2 for c in share_key_counts.values()) + shared_softmax_cache: dict[tuple[float, str], torch.Tensor] = {} + if has_sharing: + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + full_logits = next_token_logits.full_tensor() + else: + full_logits = next_token_logits + # Keep the cached softmax / log_softmax in the input dtype + # (typically bf16). The CUDA softmax / log_softmax kernels already + # accumulate in fp32 internally regardless of input dtype, so + # numerics match the previous explicit fp32 cast while halving + # the cache memory and avoiding two transient full-vocab fp32 + # buffers (the .to(fp32) result and the / temp intermediate) + # that were causing OOM at mbs > 1. + for key, count in share_key_counts.items(): + if count < 2: + continue + temp, kind = key + # Skip the divide-by-temperature kernel (and its transient + # buffer) when temperature is the no-op default. + scaled_logits = full_logits if temp == 1.0 else full_logits / temp + if kind == "probs": + shared_softmax_cache[key] = torch.softmax(scaled_logits, dim=-1) + else: + shared_softmax_cache[key] = torch.log_softmax(scaled_logits, dim=-1) + if scaled_logits is not full_logits: + del scaled_logits + del full_logits + # =============================================================== + + for teacher_idx, (loss_fn, weight, t_logits, t_topk_idx) in enumerate( + zip(self.loss_fns, self.weights, teacher_logits_list, teacher_topk_indices_list) + ): + if t_logits is None: + continue + if self.teacher_aggregation_mode == "routing": + routed_sample_mask = original_sample_mask * ( + routing_indices.to(original_sample_mask.device) == teacher_idx + ).to(original_sample_mask.dtype) + routed_samples = int((routed_sample_mask > 0).sum().item()) + metrics[f"teacher_{teacher_idx}/routed_samples"] = routed_samples + if routed_samples == 0: + continue + data["sample_mask"] = routed_sample_mask + teacher_compute_start = time.perf_counter() + if loss_fn is not None: + share_key = per_teacher_share_keys[teacher_idx] + shared_softmax = ( + shared_softmax_cache.get(share_key) + if share_key is not None + else None + ) + shared_kwargs: dict[str, Any] = {} + if shared_softmax is not None and share_key is not None: + if share_key[1] == "probs": + shared_kwargs["precomputed_student_probs"] = shared_softmax + else: + shared_kwargs["precomputed_student_log_probs"] = shared_softmax + teacher_kl, teacher_metrics = loss_fn( + next_token_logits=next_token_logits, + data=data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + teacher_logits=t_logits, + mb_idx=mb_idx, + mbs=mbs, + teacher_topk_indices_ipc=t_topk_idx, + _return_raw_kl=True, + **shared_kwargs, + ) + else: + teacher_kl = self._compute_same_tokenizer_kl( + next_token_logits, t_logits, data, t_topk_idx + ) + # Defer this scalar's sync to the batched .cpu().tolist() below. + teacher_metrics = {"kl_loss": teacher_kl} + data["sample_mask"] = original_sample_mask + teacher_compute_elapsed = time.perf_counter() - teacher_compute_start + + vocab_scale = 1.0 + if self.normalize_by_vocab: + vocab_scale = math.log(max(2, int(t_logits.shape[-1]))) / min_log_vocab + + if self.teacher_aggregation_mode == "weighted": + effective_weight = weight + elif self.teacher_aggregation_mode == "average": + effective_weight = average_weight + else: + effective_weight = 1.0 + weighted_teacher_kl = teacher_kl * effective_weight * vocab_scale + total_kl = total_kl + weighted_teacher_kl + + # Stash GPU scalars; sync once at the end. Non-tensor metadata + # (weights, vocab_scale, elapsed time) goes straight into metrics. + gpu_scalar_metrics[f"teacher_{teacher_idx}/raw_kl"] = teacher_kl.detach() + metrics[f"teacher_{teacher_idx}/weight"] = float(effective_weight) + metrics[f"teacher_{teacher_idx}/vocab_scale"] = float(vocab_scale) + gpu_scalar_metrics[f"teacher_{teacher_idx}/weighted_kl"] = ( + weighted_teacher_kl.detach() + ) + metrics[f"teacher_{teacher_idx}/loss_compute"] = float(teacher_compute_elapsed) + for key, value in teacher_metrics.items(): + metric_key = f"teacher_{teacher_idx}/{key}" + if isinstance(value, torch.Tensor) and value.ndim == 0: + gpu_scalar_metrics[metric_key] = value.detach() + else: + metrics[metric_key] = value + + # Release shared student tensors before CE loss / masking work to keep + # peak memory close to the original per-teacher implementation. + shared_softmax_cache.clear() + del shared_softmax_cache + + ce_loss_scale = self.cfg.get("ce_loss_scale", 0.0) + dynamic_loss_scaling = self.cfg.get("dynamic_loss_scaling", False) + loss = total_kl + ce_loss_tensor: Optional[torch.Tensor] = None + if ce_loss_scale > 0.0 or dynamic_loss_scaling: + # Pass logits in their native dtype; cross_entropy internally + # promotes to fp32 for the log_softmax/NLL reduction. Avoids + # materializing a second full-vocab fp32 tensor here. + ce_logits = next_token_logits + student_seq_len = ce_logits.shape[1] + # Mask padding positions so CE loss only covers real tokens. + token_mask_ce = data["token_mask"][:, 1:student_seq_len].to(torch.bool) + ce_targets = data["input_ids"][:, 1:student_seq_len].clone() + ce_targets[~token_mask_ce] = -100 + ce_loss = torch.nn.functional.cross_entropy( + ce_logits[:, :student_seq_len - 1].reshape(-1, ce_logits.shape[-1]), + ce_targets.reshape(-1), + ignore_index=-100, + ) + ce_loss_tensor = ce_loss + if dynamic_loss_scaling: + # Avoid the per-microbatch CPU<->GPU sync by computing the + # scaling factor entirely on-device. The scale is detached so + # gradient flow matches the original (where dls_scale was a + # Python scalar treated as a constant). The clamp guards + # against the degenerate total_kl==0 case; in practice KL is + # strictly positive, so this matches the original branch. + dls_scale = ( + ce_loss.detach() / total_kl.detach().clamp(min=1e-10) + ) + loss = total_kl * dls_scale + ce_loss + else: + loss = total_kl + ce_loss * ce_loss_scale + + token_mask = data["token_mask"] + sample_mask = data["sample_mask"] + student_seq_len = next_token_logits.shape[1] + max_len = min(token_mask.shape[1] - 1, student_seq_len) + local_mask = token_mask[:, 1 : max_len + 1] * sample_mask.unsqueeze(-1) + local_valid_toks = local_mask.sum() + if local_valid_toks > 0 and global_valid_toks > 0: + tok_scale = float(local_valid_toks / global_valid_toks) + loss = loss * tok_scale + else: + tok_scale = 0.0 + loss = loss * 0.0 + + # Defer the loss / kl_loss / ce_loss sync to the single batched + # transfer below. Apply tok_scale to kl_loss and ce_loss so wandb + # reports the correct global mean (not N_ranks × mean). + if isinstance(loss, torch.Tensor) and loss.ndim == 0: + gpu_scalar_metrics["loss"] = loss.detach() + else: + metrics["loss"] = loss + if isinstance(total_kl, torch.Tensor) and total_kl.ndim == 0: + gpu_scalar_metrics["kl_loss"] = total_kl.detach() * tok_scale + else: + metrics["kl_loss"] = total_kl * tok_scale if total_kl else 0.0 + if ce_loss_tensor is not None: + gpu_scalar_metrics["ce_loss"] = ce_loss_tensor.detach() * tok_scale + else: + metrics["ce_loss"] = 0.0 + metrics["num_valid_samples"] = int(data["input_ids"].shape[0]) + + # Single batched D2H copy for every scalar metric we collected above. + # This replaces the 8-14 individual .item() syncs that were previously + # interspersed throughout the per-teacher loop and the CE / dynamic + # loss scaling branches, which forced the CPU to wait between teachers + # and prevented the GPU from pipelining their kernels. + if gpu_scalar_metrics: + keys = list(gpu_scalar_metrics.keys()) + stacked = torch.stack( + [t.reshape(()) for t in gpu_scalar_metrics.values()] + ) + values = stacked.cpu().tolist() + for k, v in zip(keys, values): + metrics[k] = float(v) + + return loss, metrics diff --git a/nemo_rl/algorithms/off_policy_distillation.py b/nemo_rl/algorithms/off_policy_distillation.py new file mode 100644 index 0000000000..6b5556141c --- /dev/null +++ b/nemo_rl/algorithms/off_policy_distillation.py @@ -0,0 +1,1412 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and limitations. +# limitations under the License. + +""" +Off-Policy Distillation Algorithm + +This module implements off-policy distillation where: +- A fixed dataset of prompt-response pairs is used (no student generation) +- Teacher provides logits for the fixed responses +- Student aligns with teacher using KL divergence loss + +Key difference from on-policy distillation (in distillation.py): +- No student generation step - uses pre-existing responses from dataset +- No environment needed for reward computation +- Simpler training loop without rollout generation +""" + +import importlib.util +import os +import warnings +from pathlib import Path +import sys +if sys.version_info >= (3, 11): + from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar, cast +else: + from typing import Any, Callable, Optional, TypedDict, TypeVar, cast + from typing_extensions import NotRequired + +import numpy as np +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoConfig, AutoTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.loss.loss_functions import ( + CrossTokenizerDistillationLossFn, + DistillationLossConfig, + DistillationLossDataDict, + DistillationLossFn, + MultiTeacherLossAggregator, +) +from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.cross_tokenizer_collate import ( + CrossTokenizerCollator, + TeacherCTSpec, +) +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayVirtualCluster, +) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import Logger, LoggerConfig +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer + +# =============================================================================== +# Configuration +# =============================================================================== +TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) + + +class TokenAlignerConfig(TypedDict, total=False): + """Configuration for cross-tokenizer distillation via TokenAligner. + + When enabled, teacher and student may use different tokenizers/vocabularies. + A precomputed projection matrix maps between the two vocabulary spaces. + """ + enabled: bool # Master switch for cross-tokenizer mode + projection_matrix_path: str # Path to .pt projection matrix file + use_sparse_format: bool # True = sparse COO format, False = dense indices/values + loss_type: str # 'KL', 'cross_entropy', or 'chunked_ce' + exact_token_match_only: bool # Only use 1:1 aligned token positions for loss + temperature: float # Softmax temperature for KL computation + vocab_topk: int # Reduce teacher vocab to top-k for speed (0 = all) + reverse_kl: bool # If True, use reverse KL direction + projection_matrix_multiplier: float # Scaling factor for projection matrix + max_comb_len: int # Max combination length for token alignment DP + learnable: bool # If True, projection matrix is trainable + project_teacher_to_student: bool # If True, project teacher->student instead of student->teacher + use_char_offset: bool # If True, try char-offset alignment before DP fallback + force_dp_only: bool # If True, disable char-offset path and run DP for all samples + use_cuda_dp: bool # If True, patch TokenAligner chunked DP base case with CUDA kernel + dp_chunk_size: int # Chunk size used by DP chunked solver + + +class OffPolicyDistillationConfig(TypedDict): + """Configuration for off-policy distillation training. + + Simplified compared to on-policy: + - No num_generations_per_prompt (we use fixed responses) + - No max_rollout_turns (no generation) + """ + num_prompts_per_step: int # Batch size + max_num_steps: int # Maximum number of steps to train for + max_num_epochs: int # Maximum number of epochs to train for + topk_logits_k: int # Top-k logits for sparse KL loss + seed: int + # Validation settings + val_period: NotRequired[int] # Run validation every N steps (0 = disabled) + val_batches: NotRequired[int] # Number of validation batches (0 = all) + val_global_batch_size: NotRequired[int] # Validation batch size + val_micro_batch_size: NotRequired[int] # Validation micro batch size + val_at_start: NotRequired[bool] # Run validation before training starts + + +class OffPolicyDistillationSaveState(TypedDict): + """State to save for checkpointing.""" + total_steps: int # Track total number of steps across all epochs + current_epoch: int # Track current epoch + current_step: int # Track step within current epoch + consumed_samples: int + total_valid_tokens: int # Track total number of non-padding tokens during training + + +def _default_distillation_save_state() -> OffPolicyDistillationSaveState: + return { + "current_epoch": 0, + "current_step": 0, + "total_steps": 0, + "consumed_samples": 0, + "total_valid_tokens": 0, + } + + +class OffPolicyMasterConfig(TypedDict): + """Main configuration structure for off-policy distillation. + + Key difference from on-policy MasterConfig: + - No 'env' config (no environment needed) + """ + policy: PolicyConfig # Student model configuration + teacher: PolicyConfig # Teacher model configuration (single-teacher compatibility) + loss_fn: DistillationLossConfig # Loss function configuration + data: DataConfig # Data configuration + distillation: OffPolicyDistillationConfig # Distillation configuration + logger: LoggerConfig # Logger configuration + cluster: ClusterConfig # Cluster configuration + checkpointing: CheckpointingConfig # Checkpointing configuration + token_aligner: NotRequired[TokenAlignerConfig] # Cross-tokenizer config (single-teacher compatibility) + teachers: NotRequired[list["TeacherSpec"]] # Multi-teacher configuration + + +class TeacherSpec(TypedDict, total=False): + """Per-teacher configuration for multi-teacher distillation.""" + + teacher: PolicyConfig + token_aligner: TokenAlignerConfig + loss_fn: DistillationLossConfig + weight: float + + +# =============================================================================== +# Cross-Tokenizer Processing +# =============================================================================== +# CT (teacher tokenize + DP alignment) runs inside the StatefulDataLoader's +# worker processes via ``CrossTokenizerCollator`` (nemo_rl/data/cross_tokenizer_collate.py). +# The training loop just consumes pre-processed batches. + + +# =============================================================================== +# Setup & Initialization +# =============================================================================== +def check_vocab_equality( + tokenizer: TokenizerType, student_model_name: str, teacher_model_name: str +) -> None: + """Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal.""" + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name) + + skip_hint = "Set NRL_SKIP_DISTILLATION_TOKENIZER_CHECK=true to skip this check." + + # 1) Exact token->id mapping equality + vocab_a = tokenizer.get_vocab() + vocab_b = teacher_tokenizer.get_vocab() + assert vocab_a == vocab_b, ( + f"Token->ID mapping differs between student and teacher. {skip_hint}" + ) + + # 2) Size consistency (sanity checks) + assert len(tokenizer) == len(teacher_tokenizer), ( + f"Effective vocab sizes differ between student and teacher. {skip_hint}" + ) + + # 3) Check model.config.vocab_size to guarantee the last dimension of the logits is the same + student_config = AutoConfig.from_pretrained(student_model_name) + teacher_config = AutoConfig.from_pretrained(teacher_model_name) + assert student_config.vocab_size == teacher_config.vocab_size, ( + f"Model config vocab sizes differ between student and teacher. {skip_hint}" + ) + + +def _ensure_topk_logprobs_for_non_ipc( + teacher_topk_logits: torch.Tensor, +) -> tuple[torch.Tensor, bool]: + """Normalize teacher top-k values to log-probs for non-IPC distillation. + + Depending on worker/backend path, `get_topk_logits` may return either: + - top-k log-probabilities, or + - raw top-k logits. + Distillation loss expects log-probs in this non-IPC data-dict path. + """ + teacher_topk_logits = teacher_topk_logits.to(torch.float32) + topk_mass = teacher_topk_logits.exp().sum(dim=-1) + looks_like_logprobs = bool( + (teacher_topk_logits.max() <= 1e-6).item() + and (topk_mass.max() <= 1.0001).item() + ) + if looks_like_logprobs: + return teacher_topk_logits, False + return torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1), True + + +def _normalize_teacher_specs(master_config: OffPolicyMasterConfig) -> list[TeacherSpec]: + """Return a normalized teacher spec list for unified single/multi path.""" + teachers_cfg = master_config.get("teachers", []) + if teachers_cfg: + return teachers_cfg + + single_spec: TeacherSpec = { + "teacher": master_config["teacher"], + "weight": 1.0, + } + token_aligner_cfg = master_config.get("token_aligner", {}) + if token_aligner_cfg.get("enabled", False): + single_spec["token_aligner"] = token_aligner_cfg + return [single_spec] + + +def _group_teacher_logits_by_rank(all_teacher_logits: list[Any]) -> dict[int, list[Any]]: + """Repack ``[teacher][rank]`` payloads into ``{rank: [teacher_payloads...]}``.""" + teacher_logits_by_rank: dict[int, list[Any]] = {} + for teacher_result in all_teacher_logits: + for rank, payload in enumerate(teacher_result): + teacher_logits_by_rank.setdefault(rank, []).append(payload) + return teacher_logits_by_rank + + +def setup( + master_config: OffPolicyMasterConfig, + tokenizer: TokenizerType, + train_dataset: AllTaskProcessedDataset, + val_dataset: Optional[AllTaskProcessedDataset] = None, +) -> tuple[ + ColocatablePolicyInterface, # student_policy + list[ColocatablePolicyInterface], # teacher_policies + StatefulDataLoader, # train_dataloader + Optional[StatefulDataLoader], # val_dataloader + DistillationLossFn, + Logger, + CheckpointManager, + OffPolicyDistillationSaveState, + OffPolicyMasterConfig, + list[Any], # token_aligners (per-teacher, with None for same-tokenizer teachers) + list[Optional[PreTrainedTokenizerBase]], # teacher_tokenizers +]: + """Setup for off-policy distillation algorithm. + + Key differences from on-policy setup(): + - No student_generation interface (we don't generate responses) + - Simpler cluster setup (training only, no inference cluster needed) + + Returns: + tuple of student_policy, teacher_policy, train_dataloader, val_dataloader, + loss_fn, logger, checkpointer, distillation_save_state, master_config + """ + # Extract configuration + policy_config = master_config["policy"] + loss_config = master_config["loss_fn"] + distillation_config = master_config["distillation"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + teacher_specs = _normalize_teacher_specs(master_config) + + # Disallow SP + packing for dtensor path + for cfg, who in ((policy_config, "student"),): + dtensor_enabled = cfg["dtensor_cfg"]["enabled"] + sequence_packing_enabled = ( + "sequence_packing" in cfg and cfg["sequence_packing"]["enabled"] + ) + sequence_parallel_enabled = ( + "sequence_parallel" in cfg["dtensor_cfg"] + and cfg["dtensor_cfg"]["sequence_parallel"] + ) + + if dtensor_enabled and sequence_packing_enabled and sequence_parallel_enabled: + raise AssertionError( + f"Distillation does not support DTensor sequence parallel + sequence packing ({who} policy). " + "Please refer to https://github.com/NVIDIA-NeMo/RL/issues/1178 for more details." + ) + for teacher_idx, spec in enumerate(teacher_specs): + teacher_cfg = spec["teacher"] + dtensor_enabled = teacher_cfg["dtensor_cfg"]["enabled"] + sequence_packing_enabled = ( + "sequence_packing" in teacher_cfg + and teacher_cfg["sequence_packing"]["enabled"] + ) + sequence_parallel_enabled = ( + "sequence_parallel" in teacher_cfg["dtensor_cfg"] + and teacher_cfg["dtensor_cfg"]["sequence_parallel"] + ) + if dtensor_enabled and sequence_packing_enabled and sequence_parallel_enabled: + raise AssertionError( + "Distillation does not support DTensor sequence parallel + sequence packing " + f"(teacher_{teacher_idx} policy). " + "Please refer to https://github.com/NVIDIA-NeMo/RL/issues/1178 for more details." + ) + + # Set random seed + set_seed(distillation_config["seed"]) + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + distillation_save_state: Optional[OffPolicyDistillationSaveState] = cast( + Optional[OffPolicyDistillationSaveState], + checkpointer.load_training_info(last_checkpoint_path), + ) + if distillation_save_state is None: + distillation_save_state = _default_distillation_save_state() + + # ========================== + # Data + # ========================== + teacher_ct_specs: list[Optional[TeacherCTSpec]] = [] + for spec_cfg in teacher_specs: + ta_cfg = spec_cfg.get("token_aligner", {}) + if not ta_cfg.get("enabled", False): + teacher_ct_specs.append(None) + continue + teacher_model_name = spec_cfg["teacher"]["model_name"] + per_teacher_loss_cfg = spec_cfg.get("loss_fn", loss_config) + teacher_ct_specs.append( + TeacherCTSpec( + teacher_tokenizer_name=teacher_model_name, + student_tokenizer_name=policy_config["model_name"], + projection_matrix_path=ta_cfg["projection_matrix_path"], + use_sparse_format=bool(ta_cfg.get("use_sparse_format", False)), + learnable=bool(ta_cfg.get("learnable", False)), + max_comb_len=int(ta_cfg.get("max_comb_len", 4)), + projection_matrix_multiplier=float( + ta_cfg.get("projection_matrix_multiplier", 1.0) + ), + project_teacher_to_student=bool( + ta_cfg.get("project_teacher_to_student", False) + ), + max_teacher_len=int( + spec_cfg["teacher"].get( + "max_total_sequence_length", + policy_config["max_total_sequence_length"], + ) + ), + dp_chunk_size=int(ta_cfg.get("dp_chunk_size", 128)), + exact_token_match_only=bool( + per_teacher_loss_cfg.get("exact_token_match_only", False) + ), + ) + ) + + train_collator = CrossTokenizerCollator( + pad_token_id=tokenizer.pad_token_id, + make_sequence_length_divisible_by=policy_config.get( + "make_sequence_length_divisible_by", 1 + ), + teacher_ct_specs=teacher_ct_specs, + fallback_student_tokenizer_name=policy_config["model_name"], + ) + + nw = int(data_config.get("num_workers", 8)) + pf = int(data_config.get("prefetch_factor", 4)) + dataloader_kwargs: dict[str, Any] = dict( + batch_size=distillation_config["num_prompts_per_step"], + shuffle=data_config.get("shuffle", True), + collate_fn=train_collator, + drop_last=True, + num_workers=nw, + persistent_workers=nw > 0, + ) + if nw > 0: + dataloader_kwargs["prefetch_factor"] = pf + dataloader = StatefulDataLoader(train_dataset, **dataloader_kwargs) + + if last_checkpoint_path: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + + print( + f" ✓ Training dataloader loaded with {len(train_dataset)} samples", flush=True + ) + + # Load validation dataloader if provided + val_dataloader: Optional[StatefulDataLoader] = None + val_period = distillation_config.get("val_period", 0) + val_at_start = distillation_config.get("val_at_start", False) + if val_period > 0 or val_at_start: + assert val_dataset is not None, ( + "Validation dataset is required if validation is enabled " + "(val_period > 0 or val_at_start = True)" + ) + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=distillation_config.get( + "val_global_batch_size", distillation_config["num_prompts_per_step"] + ), + shuffle=False, + collate_fn=rl_collate_fn, + drop_last=False, + ) + print( + f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) + + # ========================== + # Cluster + # ========================== + # For off-policy distillation, we only need a training cluster + # No inference cluster needed since we don't generate responses + print("\n▶ Setting up compute cluster...", flush=True) + + # Need one colocated worker-group slot per policy (all teachers + student). + # Keep historical minimum of 3 for existing two-teacher setups. + required_worker_groups = max(3, len(teacher_specs) + 1) + cluster = RayVirtualCluster( + name="off_policy_distillation_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=required_worker_groups, + ) + print( + f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes", + flush=True, + ) + + # ========================== + # Teacher Policies + # ========================== + token_aligners: list[Any] = [] + teacher_tokenizers: list[Optional[TokenizerType]] = [] + teacher_policies: list[ColocatablePolicyInterface] = [] + for teacher_idx, spec_cfg in enumerate(teacher_specs): + teacher_cfg = spec_cfg["teacher"] + token_aligner_cfg = spec_cfg.get("token_aligner", {}) + cross_tokenizer_enabled = token_aligner_cfg.get("enabled", False) + token_aligner = None + teacher_tokenizer = None + + if cross_tokenizer_enabled: + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + print( + f"\n▶ Setting up cross-tokenizer distillation for teacher {teacher_idx}...", + flush=True, + ) + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg["model_name"]) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + token_aligner = TokenAligner( + teacher_tokenizer_name=teacher_cfg["model_name"], + student_tokenizer_name=policy_config["model_name"], + max_comb_len=token_aligner_cfg.get("max_comb_len", 4), + projection_matrix_multiplier=token_aligner_cfg.get( + "projection_matrix_multiplier", 1.0 + ), + ) + token_aligner._load_logits_projection_map( + file_path=token_aligner_cfg["projection_matrix_path"], + use_sparse_format=token_aligner_cfg.get("use_sparse_format", True), + learnable=token_aligner_cfg.get("learnable", False), + device="cpu", + ) + if token_aligner_cfg.get("project_teacher_to_student", False): + token_aligner.create_reverse_projection_matrix(device="cpu") + if token_aligner_cfg.get("use_cuda_dp", False): + cuda_dp_path = ( + Path(__file__).resolve().parents[2] + / "x_token" + / "cuda_tokenalign_dp.py" + ) + if not cuda_dp_path.exists(): + raise FileNotFoundError( + "Requested token_aligner.use_cuda_dp=true but file not found: " + f"{cuda_dp_path}" + ) + spec_obj = importlib.util.spec_from_file_location( + "x_token_cuda_dp", str(cuda_dp_path) + ) + if spec_obj is None or spec_obj.loader is None: + raise ImportError( + f"Failed to load CUDA DP module from: {cuda_dp_path}" + ) + mod = importlib.util.module_from_spec(spec_obj) + spec_obj.loader.exec_module(mod) + mod.monkeypatch_tokenaligner_cuda_basecase() + token_aligner._use_cuda_dp = True + token_aligner._cuda_dp_module_path = str(cuda_dp_path) + print( + f" ✓ Teacher {teacher_idx} CUDA DP monkeypatch enabled", + flush=True, + ) + else: + if not bool(os.getenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", False)): + check_vocab_equality( + tokenizer, policy_config["model_name"], teacher_cfg["model_name"] + ) + + if "megatron_cfg" in teacher_cfg and teacher_cfg["megatron_cfg"]["enabled"]: + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(dataloader), + ) + teacher_cfg["megatron_cfg"]["train_iters"] = total_train_iters + + print( + f"\n▶ Setting up teacher policy {teacher_idx} ({teacher_cfg['model_name']})...", + flush=True, + ) + teacher_policy = Policy( + name_prefix=f"teacher_{teacher_idx}" + if len(teacher_specs) > 1 + else "teacher", + cluster=cluster, + config=teacher_cfg, + tokenizer=teacher_tokenizer if cross_tokenizer_enabled else tokenizer, + weights_path=None, + optimizer_path=None, + init_optimizer=False, + init_reference_model=False, + ) + if not bool(distillation_config.get("keep_models_resident", False)): + teacher_policy.offload_after_refit() + + token_aligners.append(token_aligner) + teacher_tokenizers.append(teacher_tokenizer) + teacher_policies.append(teacher_policy) + + # ========================== + # Student Policy + # ========================== + # Note: No student_generation interface for off-policy distillation + print("\n▶ Setting up student policy...", flush=True) + + # Checkpoint paths + weights_path = None + optimizer_path = None + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + + if "megatron_cfg" in policy_config and policy_config["megatron_cfg"]["enabled"]: + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(dataloader), + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + student_policy = Policy( + name_prefix="student", + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + init_reference_model=False, + ) + + if any(ta is not None for ta in token_aligners): + # Unified single-/multi-teacher path: always go through the aggregator + # so per-microbatch metrics are batched + .cpu()-synced once and the + # downstream code never sees stray GPU tensors in `all_mb_metrics`. + per_teacher_loss_fns: list[Optional[CrossTokenizerDistillationLossFn]] = [] + per_teacher_weights: list[float] = [] + for t_idx, spec_cfg in enumerate(teacher_specs): + teacher_loss_cfg = spec_cfg.get("loss_fn", loss_config) + if token_aligners[t_idx] is None: + per_teacher_loss_fns.append(None) + else: + per_teacher_loss_fns.append( + CrossTokenizerDistillationLossFn( + teacher_loss_cfg, token_aligners[t_idx] + ) + ) + per_teacher_weights.append(spec_cfg.get("weight", 1.0)) + loss_fn = MultiTeacherLossAggregator( + per_teacher_loss_fns, + per_teacher_weights, + normalize_by_vocab=loss_config.get("normalize_by_vocab", False), + cfg=loss_config, + ) + else: + loss_fn = DistillationLossFn(loss_config) + + print("\n" + "=" * 60) + print(" " * 12 + "OFF-POLICY DISTILLATION SETUP COMPLETE") + print("=" * 60 + "\n", flush=True) + + return ( + student_policy, + teacher_policies, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + distillation_save_state, + master_config, + token_aligners, + teacher_tokenizers, + ) + + +# =============================================================================== +# Training +# =============================================================================== + + +def validate( + student_policy: ColocatablePolicyInterface, + teacher_policies: list[ColocatablePolicyInterface], + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: DistillationLossFn, + step: int, + master_config: OffPolicyMasterConfig, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run validation on the validation dataset for off-policy distillation. + + Computes teacher top-k logits and student distillation loss on validation data + in eval mode (no gradient updates). + + Args: + student_policy: The student policy to evaluate. + teacher_policies: Teacher policy list; first teacher used for validation. + val_dataloader: Validation dataloader. + tokenizer: Tokenizer for processing text. + loss_fn: Distillation loss function. + step: Current training step (for logging). + master_config: Master configuration dictionary. + + Returns: + Tuple of (val_metrics, timing_metrics). + """ + if val_dataloader is None: + print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) + return {}, {} + + timer = Timer() + + with timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...", flush=True) + + val_metrics: dict[str, Any] = {"val_loss": 0.0} + sum_num_valid_tokens = 0 + + val_batches = master_config["distillation"].get("val_batches", 0) + val_batch_size = master_config["distillation"].get( + "val_global_batch_size", + master_config["distillation"]["num_prompts_per_step"], + ) + val_mbs = master_config["distillation"].get( + "val_micro_batch_size", val_batch_size + ) + + for batch_idx, val_batch in enumerate(val_dataloader): + # Add loss masks for assistant tokens + for message_log in val_batch["message_log"]: + for message in message_log: + if "token_loss_mask" not in message: + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + + # Flatten messages + flat_messages, input_lengths = batched_message_log_to_flat_message( + val_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"].get( + "make_sequence_length_divisible_by", 1 + ), + ) + + val_data = BatchedDataDict[DistillationLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": val_batch["loss_multiplier"], + } + ) + val_data.update(flat_messages.get_multimodal_dict(as_tensors=False)) + val_data.to("cpu") + + # Pad partial batch if needed (drop_last=False for val) + # Must pad BEFORE teacher logits to avoid size mismatch: + # teacher.get_topk_logits internally pads for its own DP sharding + # and returns padded-size outputs, so all inputs must be + # uniformly padded first. + if val_data.size < val_batch_size: + dp_size = student_policy.sharding_annotations.get_axis_size( + "data_parallel" + ) + val_data = maybe_pad_last_batch(val_data, dp_size, val_mbs) + + # Get teacher top-k logits + use_ipc = master_config["distillation"].get("use_ipc", True) + topk_k = master_config["distillation"]["topk_logits_k"] + + teacher_policy = teacher_policies[0] + teacher_policy.prepare_for_lp_inference() + if use_ipc: + teacher_logits = teacher_policy.compute_teacher_logits_ipc( + val_data, + topk_logits=topk_k, + gbs=val_data.size, + mbs=master_config["distillation"].get( + "val_micro_batch_size", + master_config["distillation"].get( + "val_global_batch_size", + master_config["distillation"]["num_prompts_per_step"], + ), + ), + ) + else: + teacher_topk = teacher_policy.get_topk_logits(val_data, k=topk_k) + teacher_topk_logprobs, _ = _ensure_topk_logprobs_for_non_ipc( + teacher_topk["topk_logits"] + ) + val_data["teacher_topk_logits"] = teacher_topk_logprobs + val_data["teacher_topk_indices"] = teacher_topk["topk_indices"] + del teacher_topk + if not bool(master_config["distillation"].get("keep_models_resident", False)): + teacher_policy.offload_after_refit() + + # Compute student validation loss (eval mode, no gradient updates). + # When the run uses cross-tokenizer KD (loss_fn is a + # MultiTeacherLossAggregator), the loss state lives on each + # worker as ``_cached_loss_fn`` and was populated during the + # preceding training step's update_cross_tokenizer_data fan-out. + # Pass loss_fn=None so workers reuse that cached fn instead of + # the driver-side instance (which was never given CT data). + student_policy.prepare_for_training() + val_loss_fn = ( + None + if isinstance(loss_fn, MultiTeacherLossAggregator) + else loss_fn + ) + if use_ipc: + val_results = student_policy.train_off_policy_distillation( + val_data, + teacher_logits=teacher_logits, + loss_fn=val_loss_fn, + eval_mode=True, + gbs=val_data.size, + mbs=val_mbs, + ) + del teacher_logits + else: + val_results = student_policy.train( + val_data, + loss_fn, + eval_mode=True, + gbs=val_data.size, + mbs=val_mbs, + ) + + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected for this batch." + " This is likely because there were no valid samples." + ) + else: + num_valid_tokens = ( + val_data["sample_mask"].unsqueeze(-1) * val_data["token_mask"] + ).sum() + val_metrics["val_loss"] += float(val_results["loss"]) * num_valid_tokens + sum_num_valid_tokens += num_valid_tokens + + if val_batches > 0 and batch_idx >= val_batches - 1: + break + + if sum_num_valid_tokens > 0: + val_metrics["val_loss"] /= sum_num_valid_tokens + else: + warnings.warn( + "No validation metrics were collected." + " This is likely because there were no valid samples in the validation set." + ) + + student_policy.prepare_for_training() + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + if sum_num_valid_tokens > 0: + # Print summary of validation results + print("\n📊 Validation Results:") + print(f" • Validation loss: {val_metrics['val_loss']:.4f}") + + # Print timing information + print("\n ⏱️ Validation Timing:") + print(f" • Total validation time: {validation_time:.2f}s") + + # Make sure to reset the timer after validation + timer.reset() + + return val_metrics, timing_metrics + + +def off_policy_distillation_train( + student_policy: ColocatablePolicyInterface, + teacher_policies: list[ColocatablePolicyInterface], + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: DistillationLossFn, + logger: Logger, + checkpointer: CheckpointManager, + distillation_save_state: OffPolicyDistillationSaveState, + master_config: OffPolicyMasterConfig, + eval_hook: Optional[Callable] = None, + eval_hook_period: int = 0, + eval_hook_at_start: bool = False, + token_aligners: Optional[list[Any]] = None, + teacher_tokenizers: Optional[list[Optional[PreTrainedTokenizerBase]]] = None, +) -> None: + """Run off-policy distillation training algorithm. + + Key differences from on-policy distillation train(): + - No student_generation parameter (we don't generate responses) + - No task_to_env / val_task_to_env (no environment scoring) + - No rollout generation step - uses fixed responses from dataset directly + + Training loop: + 1. Load batch with prompt-response pairs (responses already in dataset) + 2. Add loss masks (train on assistant tokens only) + 3. Get teacher top-k logits for the fixed responses + 4. Train student with KL divergence loss + + Args: + eval_hook: Optional callback ``(step, student_policy, teacher_policy, logger) -> dict`` + called every *eval_hook_period* steps. Return value (if dict) is + logged under ``prefix="eval_hook"`` and used for checkpoint metric lookup. + eval_hook_period: How often (in steps) to call *eval_hook*. 0 = disabled. + eval_hook_at_start: If True, call eval_hook before the first training step. + """ + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"].get("checkpoint_must_save_by", None), + fit_last_save_time=True, + ) + timeout.start_iterations() + + # common config/state items + current_epoch = distillation_save_state["current_epoch"] # current epoch + current_step = distillation_save_state[ + "current_step" + ] # current step within current epoch + total_steps = distillation_save_state[ + "total_steps" + ] # total number of steps across all epochs + consumed_samples = distillation_save_state["consumed_samples"] + total_valid_tokens = distillation_save_state["total_valid_tokens"] + max_epochs = master_config["distillation"][ + "max_num_epochs" + ] # max number of epochs to train for + max_steps = master_config["distillation"][ + "max_num_steps" + ] # max number of steps to train for + + # Validation configuration + val_period = master_config["distillation"].get("val_period", 0) + val_at_start = master_config["distillation"].get("val_at_start", False) + + # Per-step model/optimizer offload control (off-policy distillation only). + # When True, skip the `offload_after_refit` calls between teacher and student + # phases. Requires that student + all teachers + student optimizer state fit + # resident on each GPU. Default False preserves the original eviction behavior. + keep_models_resident = bool( + master_config["distillation"].get("keep_models_resident", False) + ) + if keep_models_resident: + print( + "▶ keep_models_resident=True — skipping per-step model/optimizer " + "offloads in off-policy distillation loop", + flush=True, + ) + + # Run validation at the start if configured + if val_at_start and total_steps == 0: + print("\n🔍 Running initial validation...", flush=True) + val_metrics, validation_timings = validate( + student_policy, + teacher_policies, + val_dataloader, + tokenizer, + loss_fn, + step=0, + master_config=master_config, + ) + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") + + # Run eval hook at start if configured + eval_hook_metrics = None + if eval_hook and eval_hook_at_start and total_steps == 0: + print("\n🔍 Running initial eval hook...", flush=True) + eval_hook_metrics = eval_hook( + step=0, + student_policy=student_policy, + teacher_policy=teacher_policies[0], + logger=logger, + ) + if isinstance(eval_hook_metrics, dict): + logger.log_metrics(eval_hook_metrics, 0, prefix="eval_hook") + + # Run off-policy distillation training + batch: BatchedDataDict[DatumSpec] + + teacher_specs = _normalize_teacher_specs(master_config) + num_teachers = len(teacher_specs) + token_aligners = token_aligners or [None] * num_teachers + teacher_tokenizers = teacher_tokenizers or [None] * num_teachers + cross_tokenizer_enabled = any(a is not None for a in token_aligners) + + while total_steps < max_steps and current_epoch < max_epochs: + print( + f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_epochs} {'=' * 25}", + flush=True, + ) + + dataloader_iter = iter(dataloader) + while total_steps < max_steps: + try: + batch = next(dataloader_iter) + except StopIteration: + break + + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_steps)} {'=' * 25}", + flush=True, + ) + maybe_gpu_profile_step(student_policy, total_steps + 1) + val_metrics, validation_timings = None, None + + with timer.time("total_step_time"): + # ==== Data Processing ==== + # CrossTokenizerCollator in the StatefulDataLoader's worker + # processes already did the message flatten + per-teacher CT + # (teacher tokenize + DP alignment). The batch carries + # input_ids / input_lengths / token_mask / sample_mask / + # flat_messages / per_teacher_ct_data. + print("▶ Processing batch data (off-policy - using fixed responses)...", flush=True) + with timer.time("data_processing"): + flat_messages = batch["flat_messages"] + input_lengths = batch["input_lengths"] + train_data = BatchedDataDict[DistillationLossDataDict]( + { + "input_ids": batch["input_ids"], + "input_lengths": input_lengths, + "token_mask": batch["token_mask"], + "sample_mask": batch["sample_mask"], + } + ) + mm_dict = flat_messages.get_multimodal_dict(as_tensors=False) + if mm_dict: + train_data.update(mm_dict) + train_data.to("cpu") + + # ==== Teacher Logprob Inference ==== + use_ipc = bool(master_config["distillation"].get("use_ipc", True)) + topk_k = master_config["distillation"]["topk_logits_k"] + if num_teachers > 1 and not use_ipc: + raise NotImplementedError( + "Multi-teacher distillation currently requires use_ipc=True." + ) + all_teacher_logits: list[Any] = [] + per_teacher_ct_data: list[tuple[torch.Tensor, list[Any], Optional[dict[str, list]]]] = [] + + print( + f"▶ Preparing for teacher logprob inference ({num_teachers} teacher(s))...", + flush=True, + ) + if not keep_models_resident: + student_policy.offload_after_refit() + + batch_ct_data = batch.get("per_teacher_ct_data", [None] * num_teachers) + for teacher_idx, teacher_policy in enumerate(teacher_policies): + ct = batch_ct_data[teacher_idx] + if ct is not None: + teacher_data = ct["teacher_data"] + chunk_indices = None + if "student_chunk_coo" in ct: + chunk_indices = { + "student_chunk_coo": ct["student_chunk_coo"], + "teacher_chunk_coo": ct["teacher_chunk_coo"], + "num_chunks": ct["num_chunks"], + } + per_teacher_ct_data.append( + (ct["teacher_input_ids"], ct["aligned_pairs"], chunk_indices) + ) + else: + teacher_data = None + per_teacher_ct_data.append((torch.empty(0), [], None)) + + teacher_fwd_data = teacher_data if teacher_data is not None else train_data + # Always send full logits: cross-tokenizer teachers need them + # for projection, and same-tokenizer teachers need them for + # exact full-vocab KL (topk approximation inflates KL by ~30%). + teacher_topk_k = None + + teacher_policy.prepare_for_lp_inference() + if use_ipc: + with timer.time(f"teacher_{teacher_idx}_logprob_inference"): + teacher_logits = teacher_policy.compute_teacher_logits_ipc( + teacher_fwd_data, + topk_logits=teacher_topk_k, + gbs=master_config["policy"]["train_global_batch_size"], + mbs=master_config["policy"]["train_micro_batch_size"], + ) + all_teacher_logits.append(teacher_logits) + else: + if token_aligners[teacher_idx] is not None: + raise NotImplementedError( + "Cross-tokenizer distillation requires use_ipc=True. " + "Set distillation.use_ipc: true in the config." + ) + with timer.time(f"teacher_{teacher_idx}_logprob_inference"): + teacher_topk = teacher_policy.get_topk_logits(train_data, k=topk_k) + teacher_topk_logprobs, converted_to_logprobs = _ensure_topk_logprobs_for_non_ipc( + teacher_topk["topk_logits"] + ) + train_data["teacher_topk_logits"] = teacher_topk_logprobs + train_data["teacher_topk_indices"] = teacher_topk["topk_indices"] + if converted_to_logprobs and total_steps == 0 and current_step == 0: + print( + "⚠️ teacher.get_topk_logits returned raw logits in non-IPC mode; " + "normalizing with log_softmax before distillation loss.", + flush=True, + ) + del teacher_topk + all_teacher_logits.append(None) + if not keep_models_resident: + teacher_policy.offload_after_refit() + + # ==== Student Training ==== + print("▶ Preparing for training...", flush=True) + with timer.time("training_prep"): + student_policy.prepare_for_training() + if not keep_models_resident: + # offload_after_refit above moved the student optimizer + # state to CPU; prepare_for_training only re-onloads it + # for the logprob/colocated-generation gate. Restore it + # explicitly so the next optimizer.step finds tensors + # on the same device as the gradients. + student_policy.move_optimizer_to_cuda() + + if cross_tokenizer_enabled: + if not getattr(student_policy, "_loss_fn_initialized", False): + student_policy._loss_fn_initialized = True + # Unified single-/multi-teacher path: always send the + # list-shape worker spec so each worker builds a + # MultiTeacherLossAggregator (with N=1 for single + # teacher). This eliminates the diverging code path + # that previously left single-teacher metrics on GPU. + teacher_worker_specs: list[tuple[DistillationLossConfig, Optional[dict[str, Any]], float]] = [] + for teacher_idx, spec_cfg in enumerate(teacher_specs): + aligner_cfg = spec_cfg.get("token_aligner", {}) + teacher_loss_cfg = spec_cfg.get("loss_fn", master_config["loss_fn"]) + if token_aligners[teacher_idx] is None: + teacher_worker_specs.append((teacher_loss_cfg, None, spec_cfg.get("weight", 1.0))) + else: + teacher_worker_specs.append( + ( + teacher_loss_cfg, + { + "teacher_model": spec_cfg["teacher"]["model_name"], + "student_model": master_config["policy"]["model_name"], + "projection_matrix_path": aligner_cfg["projection_matrix_path"], + "use_sparse_format": aligner_cfg.get("use_sparse_format", True), + "learnable": aligner_cfg.get("learnable", False), + "max_comb_len": aligner_cfg.get("max_comb_len", 4), + "projection_matrix_multiplier": aligner_cfg.get( + "projection_matrix_multiplier", 1.0 + ), + "project_teacher_to_student": aligner_cfg.get( + "project_teacher_to_student", False + ), + }, + spec_cfg.get("weight", 1.0), + ) + ) + student_policy.init_cross_tokenizer_loss_fn( + loss_config=teacher_worker_specs, + token_aligner_config=None, + ) + for teacher_idx, (teacher_input_ids, aligned_pairs, chunk_indices) in enumerate(per_teacher_ct_data): + if teacher_input_ids.numel() == 0: + continue + # Always pass the teacher index now that every cached + # loss fn is a MultiTeacherLossAggregator (N>=1). + student_policy.update_cross_tokenizer_data( + teacher_input_ids=teacher_input_ids, + aligned_pairs=aligned_pairs, + teacher_idx=teacher_idx, + chunk_indices=chunk_indices, + ) + + student_loss_fn = None if cross_tokenizer_enabled else loss_fn + print("▶ Training policy...", flush=True) + with timer.time("policy_training"): + if use_ipc: + if num_teachers > 1: + teacher_logits_arg = _group_teacher_logits_by_rank( + all_teacher_logits + ) + else: + teacher_logits_arg = all_teacher_logits[0] + train_results = student_policy.train_off_policy_distillation( + train_data, + teacher_logits=teacher_logits_arg, + loss_fn=student_loss_fn, + ) + del all_teacher_logits + else: + train_results = student_policy.train( + train_data, + student_loss_fn, + ) + + is_last_step = (total_steps + 1 >= max_steps) or ( + (current_epoch + 1 == max_epochs) + and (current_step + 1 == len(dataloader)) + ) + + # ==== Validation ==== + if val_period > 0 and (total_steps + 1) % val_period == 0: + val_metrics, validation_timings = validate( + student_policy, + teacher_policies, + val_dataloader, + tokenizer, + loss_fn, + step=total_steps + 1, + master_config=master_config, + ) + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + # ==== Eval Hook (e.g., generation-based MATH/MMLU eval) ==== + if eval_hook and eval_hook_period > 0 and (total_steps + 1) % eval_hook_period == 0: + print(f"\n🔍 Running eval hook at step {total_steps + 1}...", flush=True) + with timer.time("eval_hook"): + eval_hook_metrics = eval_hook( + step=total_steps + 1, + student_policy=student_policy, + teacher_policy=teacher_policies[0], + logger=logger, + ) + if isinstance(eval_hook_metrics, dict): + logger.log_metrics(eval_hook_metrics, total_steps + 1, prefix="eval_hook") + student_policy.prepare_for_training() + + # ==== Metrics ==== + metrics = { + "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + "mean_seq_length": batch["length"].numpy().mean(), + "total_num_tokens": input_lengths.numpy().sum(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in { + "lr", + "wd", + "global_valid_seqs", + "global_valid_toks", + "mean_seq_length", + }: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + total_valid_tokens += metrics["global_valid_toks"] + + ## Checkpointing + consumed_samples += master_config["distillation"][ + "num_prompts_per_step" + ] + timeout.mark_iteration() + + should_save_by_step = ( + is_last_step + or (total_steps + 1) % master_config["checkpointing"]["save_period"] + == 0 + ) + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + student_policy.prepare_for_training() + + distillation_save_state["current_epoch"] = current_epoch + distillation_save_state["current_step"] = current_step + 1 + distillation_save_state["total_steps"] = total_steps + 1 + distillation_save_state["total_valid_tokens"] = total_valid_tokens + distillation_save_state["consumed_samples"] = consumed_samples + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary. ' + f"Example: 'train:loss' or 'val:val_loss'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. " + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in distillation_save_state: + del distillation_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + distillation_save_state[full_metric_name] = metrics_source[ + metric_name + ] + + with timer.time("checkpointing"): + print( + f"Saving checkpoint for step {total_steps + 1}...", + flush=True, + ) + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, distillation_save_state, master_config + ) + student_policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # Logging + # Log training data + log_data = {"content": flat_messages["content"]} + log_data["input_lengths"] = input_lengths.tolist() + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps + 1}.jsonl" + ) + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + for teacher_idx in range(num_teachers): + ct_key = f"teacher_{teacher_idx}_ct_processing" + lp_key = f"teacher_{teacher_idx}_logprob_inference" + if ct_key in timing_metrics: + timing_metrics[f"teacher_{teacher_idx}/ct_processing"] = timing_metrics[ct_key] + if lp_key in timing_metrics: + timing_metrics[f"teacher_{teacher_idx}/logprob_inference"] = timing_metrics[lp_key] + loss_compute_key = f"teacher_{teacher_idx}/loss_compute" + if loss_compute_key in metrics: + timing_metrics[loss_compute_key] = float(metrics[loss_compute_key]) + + teacher_total = 0.0 + for teacher_idx in range(num_teachers): + teacher_total += timing_metrics.get(f"teacher_{teacher_idx}/ct_processing", 0.0) + teacher_total += timing_metrics.get(f"teacher_{teacher_idx}/logprob_inference", 0.0) + teacher_total += timing_metrics.get(f"teacher_{teacher_idx}/loss_compute", 0.0) + timing_metrics["multi_teacher_total"] = teacher_total + # policy_training is the only worker-side timing exposed at this layer. + timing_metrics["student_forward"] = timing_metrics.get("policy_training", 0.0) + timing_metrics["student_backward"] = timing_metrics.get("policy_training", 0.0) + + print("\n📊 Training Results:") + + print(f" • Loss: {metrics['loss']:.4f}") + print(f" • Grad Norm: {metrics['grad_norm']:.4f}") + print(f" • Mean Sequence Length: {metrics['mean_seq_length']:.1f}") + + if "total_flops" in train_results: + total_time = timing_metrics.get("total_step_time", 0) + total_tflops = ( + train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", + flush=True, + ) + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + flush=True, + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops + + print("\n⏱️ Timing:", flush=True) + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + metrics.update( + { + "tokens_per_sec_per_gpu": metrics["total_num_tokens"] + / total_time + / total_num_gpus + } + ) + + print(f" • Total step time: {total_time:.2f}s", flush=True) + + # Display all other timing metrics + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) + + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) + return + if total_steps >= max_steps: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return + + # End of epoch + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch diff --git a/nemo_rl/algorithms/x_token/__init__.py b/nemo_rl/algorithms/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/algorithms/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py new file mode 100644 index 0000000000..1d55f858f7 --- /dev/null +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -0,0 +1,2194 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoTokenizer + +try: + from numba import njit + _NUMBA_AVAILABLE = True +except ImportError: + _NUMBA_AVAILABLE = False + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +if _NUMBA_AVAILABLE: + @njit(cache=True) + def _dp_core_numba(ids1, ids2, joined1, joined2, n1, n2, + exact_match_score, gap_penalty, comb_mul, max_comb_len): + """Numba-accelerated DP core for token alignment. + + Uses the same algorithm as align_tokens_with_combinations_numpy but + with integer ID comparisons instead of Python string operations. + + Trace codes: 0=start, 1=diag, 2=up, 3=left, + 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k + """ + INVALID = np.int64(-1) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) + + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 2 + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 3 + + for i in range(1, n1 + 1): + id_i = ids1[i - 1] + for j in range(1, n2 + 1): + id_j = ids2[j - 1] + + if id_i == id_j: + best = dp[i - 1, j - 1] + exact_match_score + else: + best = dp[i - 1, j - 1] - exact_match_score + best_m = np.int32(1) + + s = dp[i - 1, j] + gap_penalty + if s > best: + best = s + best_m = np.int32(2) + + s = dp[i, j - 1] + gap_penalty + if s > best: + best = s + best_m = np.int32(3) + + k_max_s2 = min(j, max_comb_len) + for k in range(2, k_max_s2 + 1): + jid = joined2[j, k] + if jid != INVALID and id_i == jid: + s = dp[i - 1, j - k] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(10 + k) + + k_max_s1 = min(i, max_comb_len) + for k in range(2, k_max_s1 + 1): + jid = joined1[i, k] + if jid != INVALID and id_j == jid: + s = dp[i - k, j - 1] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(20 + k) + + dp[i, j] = best + trace[i, j] = best_m + + return dp, trace +else: + _dp_core_numba = None + + +@dataclass(frozen=True) +class VocabPartition: + """Projection-matrix-derived vocab partition for the gold-loss path. + + Built once per (xtoken_loss, teacher_vocab_size) by + `TokenAligner.build_vocab_partition`. All tensors are long, 1-D, and + live on the aligner's projection-matrix device. + """ + + common_student_indices: torch.Tensor + common_teacher_indices: torch.Tensor + uncommon_student_indices: torch.Tensor + uncommon_teacher_indices: torch.Tensor + + +class TokenAligner(nn.Module): + def __init__(self, max_comb_len=4, teacher_tokenizer_name=None, student_tokenizer_name=None, init_hf_tokenizers=True, projection_matrix_multiplier=1.0, enable_scale_trick=None): + super().__init__() + self.teacher_tokenizer_name = teacher_tokenizer_name + self.student_tokenizer_name = student_tokenizer_name + self.projection_matrix_multiplier = projection_matrix_multiplier + self.enable_scale_trick = enable_scale_trick + + if init_hf_tokenizers: + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_tokenizer_name) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_tokenizer_name) + if self.teacher_tokenizer.pad_token is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + if self.student_tokenizer.pad_token is None: + self.student_tokenizer.pad_token = self.student_tokenizer.eos_token + else: + self.teacher_tokenizer = None + self.student_tokenizer = None + + self.max_combination_len = max_comb_len + self.sparse_transformation_matrix = None + # Cached CSR for dense top-k projection (built from indices/values) to avoid scatter path + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + # Precomputed canonical ID maps (built by precompute_canonical_maps) + self._student_canon_map = None + self._teacher_canon_map = None + self._canon_id_to_str = None + + # Cached gold-loss vocab partitions keyed by (xtoken_loss, teacher_vocab_size). + # Populated lazily by build_vocab_partition(); bypassed when learnable=True. + self._vocab_partition_cache: dict[tuple[bool, int], VocabPartition] = {} + + @torch.no_grad() + def build_vocab_partition( + self, xtoken_loss: bool, teacher_vocab_size: int + ) -> VocabPartition: + """Derive (and cache) the gold-loss vocab partition. + + Computes exactly the state that CrossTokenizerDistillationLossFn._compute_gold_loss + previously rebuilt every microbatch: the set of student/teacher tokens that have an + exact 1:1 projection (common) and the complement (uncommon). Cached by + (xtoken_loss, teacher_vocab_size) so repeat calls cost nothing. Caching is skipped + when the projection matrix is learnable, since the partition then depends on + gradient-updated values. + """ + if ( + not hasattr(self, "likelihood_projection_indices") + or self.likelihood_projection_indices is None + ): + raise ValueError( + "build_vocab_partition requires likelihood_projection_indices to be loaded" + ) + + learnable = bool(getattr(self, "learnable", False)) + key = (bool(xtoken_loss), int(teacher_vocab_size)) + if not learnable: + cached = self._vocab_partition_cache.get(key) + if cached is not None: + return cached + + projection_indices = self.likelihood_projection_indices + projection_matrix = ( + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) + if learnable + else self.likelihood_projection_matrix + ) + device = projection_matrix.device + student_vocab_size = int(projection_matrix.shape[0]) + + sorted_values, sorted_indices_in_topk = torch.sort( + projection_matrix, dim=-1, descending=True + ) + + if xtoken_loss: + has_exact_map = sorted_values[:, 0] >= 0.6 + else: + has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1) + + student_indices_with_exact_map = torch.where(has_exact_map)[0] + teacher_indices_for_exact_map = projection_indices[ + student_indices_with_exact_map, + sorted_indices_in_topk[student_indices_with_exact_map, 0], + ] + + student_to_teacher_exact_map: dict[int, int] = {} + teacher_to_student_exact_map: dict[int, int] = {} + for s_idx, t_idx in zip( + student_indices_with_exact_map.tolist(), + teacher_indices_for_exact_map.tolist(), + ): + if 0 <= t_idx < teacher_vocab_size: + if t_idx not in teacher_to_student_exact_map or xtoken_loss: + if t_idx in teacher_to_student_exact_map: + prev_student_token = teacher_to_student_exact_map[t_idx] + if sorted_values[prev_student_token, 0] >= sorted_values[s_idx, 0]: + continue + del student_to_teacher_exact_map[prev_student_token] + student_to_teacher_exact_map[s_idx] = t_idx + teacher_to_student_exact_map[t_idx] = s_idx + + common_student = sorted(student_to_teacher_exact_map.keys()) + common_teacher = [student_to_teacher_exact_map[s] for s in common_student] + uncommon_student = sorted( + set(range(student_vocab_size)) - set(common_student) + ) + uncommon_teacher = sorted( + set(range(teacher_vocab_size)) - set(common_teacher) + ) + + partition = VocabPartition( + common_student_indices=torch.tensor( + common_student, dtype=torch.long, device=device + ), + common_teacher_indices=torch.tensor( + common_teacher, dtype=torch.long, device=device + ), + uncommon_student_indices=torch.tensor( + uncommon_student, dtype=torch.long, device=device + ), + uncommon_teacher_indices=torch.tensor( + uncommon_teacher, dtype=torch.long, device=device + ), + ) + + if not learnable: + self._vocab_partition_cache[key] = partition + return partition + + def precompute_canonical_maps(self): + """Build token_id → canonical_string lookup tables for both tokenizers. + + Call once at startup. After this, align_fast() can skip + convert_ids_to_tokens and _canonicalize_sequence entirely. + """ + import time as _time + _t0 = _time.time() + + canon_str_to_id: dict[str, int] = {} + next_id = [0] + + def _get_canon_id(s: str) -> int: + cid = canon_str_to_id.get(s) + if cid is None: + cid = next_id[0] + canon_str_to_id[s] = cid + next_id[0] += 1 + return cid + + student_vocab_size = len(self.student_tokenizer) + teacher_vocab_size = len(self.teacher_tokenizer) + + student_map = np.zeros(student_vocab_size, dtype=np.int64) + for tid in range(student_vocab_size): + tok = self.student_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + student_map[tid] = _get_canon_id(canon) + + teacher_map = np.zeros(teacher_vocab_size, dtype=np.int64) + for tid in range(teacher_vocab_size): + tok = self.teacher_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + teacher_map[tid] = _get_canon_id(canon) + + self._student_canon_map = student_map + self._teacher_canon_map = teacher_map + self._canon_id_to_str = {v: k for k, v in canon_str_to_id.items()} + + _t1 = _time.time() + print(f" [TokenAligner] Precomputed canonical maps in {_t1-_t0:.2f}s " + f"(student_vocab={student_vocab_size}, teacher_vocab={teacher_vocab_size}, " + f"unique_canonical={len(canon_str_to_id)})", flush=True) + + def align_fast(self, student_ids, teacher_ids, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + chunk_size=128, + post_process=True, + anchor_lengths=[3,], + ignore_leading_char_diff=False): + """Fast alignment using precomputed canonical ID maps. + + Skips convert_ids_to_tokens and _canonicalize_sequence by looking up + canonical strings directly from token IDs via precomputed numpy arrays. + Falls back to regular align() if precomputed maps are not available. + """ + if self._student_canon_map is None: + return self.align(student_ids, teacher_ids, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + chunk_size=chunk_size, + post_process=post_process, + anchor_lengths=anchor_lengths, + ignore_leading_char_diff=ignore_leading_char_diff) + + if isinstance(student_ids, torch.Tensor): + student_ids = student_ids.cpu().numpy() + if isinstance(teacher_ids, torch.Tensor): + teacher_ids = teacher_ids.cpu().numpy() + + if student_ids.ndim == 1: + student_ids = student_ids[np.newaxis, :] + teacher_ids = teacher_ids[np.newaxis, :] + + import time as _time + _t_lookup_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + all_aligned_pairs = [] + for i in range(student_ids.shape[0]): + s_ids = student_ids[i] + t_ids = teacher_ids[i] + + _tl0 = _time.time() + s_canon_strs = [self._canon_id_to_str[self._student_canon_map[tid]] for tid in s_ids] + t_canon_strs = [self._canon_id_to_str[self._teacher_canon_map[tid]] for tid in t_ids] + _tl1 = _time.time() + _t_lookup_total += _tl1 - _tl0 + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(s_canon_strs, t_canon_strs, **align_kwargs) + _tl2 = _time.time() + _t_anchors_dp_total += _tl2 - _tl1 + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tl3 = _time.time() + _t_postprocess_total += _tl3 - _tl2 + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, + ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value + in zip(aligned_pairs, mask) + ] + _tl4 = _time.time() + _t_mask_total += _tl4 - _tl3 + + all_aligned_pairs.append(aligned_pairs) + + n = student_ids.shape[0] + _t_total = _t_lookup_total + _t_anchors_dp_total + _t_postprocess_total + _t_mask_total + if _t_total > 0.5 or n > 1: + print(f" [align_fast timing] lookup={_t_lookup_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s, " + f"total={_t_total:.3f}s (n={n})", flush=True) + + return all_aligned_pairs + + def _load_logits_projection_map( + self, + folder_location: str = "cross_tokenizer_data", + file_path: str = None, + top_k: int = 100, + device: str = "cuda", + use_sparse_format: bool = False, + learnable: bool = False, + ): + """ + Load projection map for cross-tokenizer likelihood projection. + Always creates student→teacher mapping. + + Args: + folder_location: Directory containing the projection files + file_path: Specific file path (overrides folder_location) + top_k: Number of top entries per row (only used for old format) + device: Device to load tensors on + use_sparse_format: If True, load sparse transformation matrix format (from multi-token mapping) + If False, load old dense indices/values format + learnable: If True, make the transformation matrix learnable + """ + self.learnable = learnable + if use_sparse_format: + # Load sparse transformation matrix format + if file_path is None: + file_path = f"{folder_location}/transformation_counts_via_multitoken.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Sparse transformation matrix file not found: {file_path}. Please generate it first.") + + # Load transformation counts dictionary + transformation_counts = torch.load(file_path, map_location='cpu', weights_only=False) + + # Get tokenizer vocab sizes + teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback + student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback + if 1: + # get vocab sizes from autoconfig + if "gemma" not in self.teacher_tokenizer_name.lower() and "qwen3.5" not in self.teacher_tokenizer_name.lower(): + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + else: + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).text_config.vocab_size + if "gemma" not in self.student_tokenizer_name.lower() and "qwen3.5" not in self.student_tokenizer_name.lower(): + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + else: + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).text_config.vocab_size + # teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + # student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + + + # Debug vocab sizes + print(f"Teacher vocab size: {teacher_vocab_size}, Student vocab size: {student_vocab_size}") + + # Convert dictionary to sparse tensor + if transformation_counts: + + + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + student_indices = [idx[0] for idx in indices] + teacher_indices = [idx[1] for idx in indices] + + # Always create student→teacher mapping: rows = student vocab, cols = teacher vocab + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values)/self.projection_matrix_multiplier + matrix_shape = (student_vocab_size, teacher_vocab_size) + + print(f"Creating sparse matrix: student→teacher ({student_vocab_size} x {teacher_vocab_size})") + + sparse_transformation_matrix = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (student_vocab_size, teacher_vocab_size), # student_vocab × teacher_vocab + device=device, + dtype=torch.float32 + ) + + # Optionally make the sparse matrix learnable (values only) + if learnable: + self.sparse_transformation_matrix = nn.Parameter( + sparse_transformation_matrix.coalesce(), requires_grad=True + ) + else: + # Register as buffer for non-learnable parameters (ensures proper device handling) + self.register_buffer('sparse_transformation_matrix', + sparse_transformation_matrix.coalesce(), + persistent=True) + + # Store a flag for downstream code + self.is_sparse_learnable = learnable + print(f"Loaded sparse transformation matrix with {len(transformation_counts)} entries") + else: + # Empty transformation matrix (student→teacher) + matrix_shape = (student_vocab_size, teacher_vocab_size) + + empty_sparse = torch.sparse_coo_tensor( + torch.zeros(2, 0, dtype=torch.long), + torch.zeros(0, dtype=torch.float32), + matrix_shape, + device=device, + ) + + if learnable: + self.sparse_transformation_matrix = nn.Parameter(empty_sparse, requires_grad=True) + else: + # Register as buffer for non-learnable parameters + self.register_buffer('sparse_transformation_matrix', empty_sparse, persistent=True) + + self.is_sparse_learnable = learnable + print("Warning: Empty transformation matrix loaded") + else: + # Load old dense indices/values format + if file_path is None: + file_path = f"{folder_location}/projection_map_Llama-3.1_to_Qwen3_bidirectional_top_10.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Projection map file not found: {file_path}. Please generate it first.") + + projection_data = torch.load(file_path, map_location='cpu', weights_only=False) + # Always use B_to_A direction for student->teacher projection + # projection_data = projection_data["B_to_A"] + # projection_data = projection_data["A_to_B"] + + indices = projection_data["indices"] + likelihoods = projection_data["likelihoods"]/self.projection_matrix_multiplier + + # Register indices as buffer (always non-learnable) + self.register_buffer('likelihood_projection_indices', indices.to(device), persistent=True) + if learnable: + if 1: + likelihoods = (likelihoods+1e-10).log() + + # Use instance variable if set, otherwise use default (False) + # scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + + # if scale_trick_enabled: + # #trick with last column being multiplier - set to -4.0 + # likelihoods[:,-1] = likelihoods[:,-1]*0.0 - 4.0 + #lets introduce some noise to encourage training. will remove later. + if 0: + likelihoods = likelihoods + torch.randn_like(likelihoods) * 1e-1 + likelihoods = likelihoods/2.0 + + self.likelihood_projection_matrix = nn.Parameter(likelihoods.to(device), requires_grad=True) + # print(self.likelihood_projection_matrix[0]) + # print(self.likelihood_projection_matrix[:,-1]) + # exit() + #add small gaussian noise to the projection matrix + #use log form + else: + # Register as buffer for non-learnable parameters + self.register_buffer('likelihood_projection_matrix', likelihoods.to(device), persistent=True) + + + print(f"Loaded dense projection map with shape {indices.shape}") + # Invalidate cached CSR; will rebuild on first use + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + def create_reverse_projection_matrix(self, device="cuda"): + """ + Create a reverse (transposed) projection matrix for teacher→student projection. + + For sparse format: Transposes the sparse_transformation_matrix from [student_vocab, teacher_vocab] + to [teacher_vocab, student_vocab] + For dense format: Builds a reverse index mapping from teacher tokens to student tokens + + This enables projecting teacher logits into student vocabulary space. + """ + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Transpose sparse matrix + print("Creating reverse projection matrix (sparse format): teacher→student") + sparse_matrix = self.sparse_transformation_matrix.coalesce() + indices = sparse_matrix.indices() + values = sparse_matrix.values() + + # Swap student and teacher indices (transpose) + transposed_indices = torch.stack([indices[1], indices[0]], dim=0) # Swap rows: [teacher, student] + teacher_vocab_size, student_vocab_size = sparse_matrix.shape[1], sparse_matrix.shape[0] + + reverse_sparse = torch.sparse_coo_tensor( + transposed_indices, + values, + (teacher_vocab_size, student_vocab_size), + device=device, + dtype=torch.float32 + ).coalesce() + + # Store as buffer or parameter based on learnability + if self.is_sparse_learnable: + self.reverse_sparse_transformation_matrix = nn.Parameter(reverse_sparse, requires_grad=True) + else: + self.register_buffer('reverse_sparse_transformation_matrix', reverse_sparse, persistent=True) + + print(f"Created reverse sparse matrix: teacher→student ({teacher_vocab_size} x {student_vocab_size})") + print(f"Reverse matrix has {len(values)} non-zero entries") + + elif hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + # Build reverse index for dense format + print("Creating reverse projection matrix (dense format): teacher→student") + + # Current: likelihood_projection_indices is [student_vocab, topk] + # We need to build: [teacher_vocab, variable_k] where variable_k depends on how many students map to each teacher token + + student_vocab_size = self.likelihood_projection_indices.shape[0] + topk = self.likelihood_projection_indices.shape[1] + + # Infer teacher vocab size from the max index + teacher_vocab_size = self.likelihood_projection_indices.max().item() + 1 + + # Build reverse mapping: for each teacher token, collect all (student_token, value) pairs + from collections import defaultdict + teacher_to_students = defaultdict(list) + + for student_idx in range(student_vocab_size): + for k in range(topk): + teacher_idx = self.likelihood_projection_indices[student_idx, k].item() + if hasattr(self, 'likelihood_projection_matrix'): + value = self.likelihood_projection_matrix[student_idx, k].item() + else: + value = 1.0 # Default value if no matrix + + # Check for valid entries: teacher_idx must be valid, and value must be finite (not -inf) + # If matrix is in log-space, valid log-probs are finite negative values + # Threshold at -20 to filter out padding values like -22.3197 + if teacher_idx >= 0 and value > -20.0: # Skip invalid or padding entries + teacher_to_students[teacher_idx].append((student_idx, value)) + + # Find max number of students mapping to any teacher token + raw_max_students = max([len(v) for v in teacher_to_students.values()]) if teacher_to_students else 1 + print(f"Max students mapping to any teacher token (before filtering): {raw_max_students}") + + # Limit to top-K students per teacher token to avoid explosion + # Keep only the top-K highest probability mappings per teacher + max_students_per_teacher = min(topk, raw_max_students) # Use same topk as forward direction + print(f"Limiting to top-{max_students_per_teacher} students per teacher token") + + # Sort each teacher's student list by value (descending) and keep only top-K + for teacher_idx in teacher_to_students: + student_list = teacher_to_students[teacher_idx] + # Sort by value (descending - higher log-prob = less negative) + student_list_sorted = sorted(student_list, key=lambda x: x[1], reverse=True) + teacher_to_students[teacher_idx] = student_list_sorted[:max_students_per_teacher] + + # Create dense reverse index [teacher_vocab, max_students_per_teacher] + # Use 0 instead of -1 for padding (valid index), with very negative values to nullify contribution + reverse_indices = torch.zeros((teacher_vocab_size, max_students_per_teacher), + dtype=torch.long, device=device) + # Initialize with very negative values (padding sentinel, similar to forward direction) + reverse_values = torch.full((teacher_vocab_size, max_students_per_teacher), -22.3197, + dtype=torch.float32, device=device) + + for teacher_idx, student_list in teacher_to_students.items(): + for k, (student_idx, value) in enumerate(student_list): + reverse_indices[teacher_idx, k] = student_idx + reverse_values[teacher_idx, k] = value + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + + # Store as buffer or parameter + self.register_buffer('reverse_likelihood_projection_indices', reverse_indices, persistent=True) + if self.learnable: + self.reverse_likelihood_projection_matrix = nn.Parameter(reverse_values, requires_grad=True) + else: + self.register_buffer('reverse_likelihood_projection_matrix', reverse_values, persistent=True) + + print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})") + else: + raise ValueError("No projection matrix loaded. Cannot create reverse projection.") + + def project_token_likelihoods_instance(self, input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_sparse_format=False, sparse_matrix=None, use_vectorized=True, gpu_optimized_scatter=True, global_top_indices=None): + """ + Instance method wrapper for project_token_likelihoods that can access instance variables. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + """ + if use_sparse_format: + if sparse_matrix is None: + raise ValueError("sparse_matrix must be provided when use_sparse_format=True") + + if global_top_indices is not None: + # For sparse format with global_top_indices, project to full vocab then slice + full_projection = TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + return full_projection[:, :, global_top_indices] + else: + return TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + else: + # If projection map is learnable, fall back to dense scatter path to preserve gradients + if getattr(projection_map_values, "requires_grad", False): + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.project_token_likelihoods_dense( + input_likelihoods, + projection_map_indices, + projection_map_values * self.projection_matrix_multiplier, + target_vocab_size, + device, + use_vectorized=True, + gpu_optimized_scatter=gpu_optimized_scatter, + enable_scale_trick=scale_trick_enabled, + global_top_indices=global_top_indices, + ) + + # Otherwise, use stateless CSR matmul (no caching) for memory efficiency + vs = projection_map_indices.shape[0] + top_k = projection_map_indices.shape[1] + # Ensure device/dtype for indices/values + idx = projection_map_indices.to(device) + val = (projection_map_values * self.projection_matrix_multiplier).to(device) + if val.dtype != input_likelihoods.dtype: + val = val.to(input_likelihoods.dtype) + # Build CSR once per call outside autograd to keep checkpoint recomputation identical + with torch.no_grad(): + crow_indices = torch.arange(0, (vs + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = idx.reshape(-1) + values = val.reshape(-1) + # _exact_map_remapped matrices use -1 as a padding sentinel for missing + # entries. A -1 column index is illegal in a CSR tensor and causes a + # CUDA illegal-memory-access during the subsequent matmul. Clamp those + # entries to column 0 and zero their value so they contribute nothing. + pad_mask = col_indices < 0 + if pad_mask.any(): + col_indices = col_indices.clone() + col_indices[pad_mask] = 0 + values = values.clone() + values[pad_mask] = 0.0 + proj_csr = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(vs, target_vocab_size), device=device + ) + # Matmul: [B, S, Vs] -> [B*S, Vs] @ [Vs, Vt] -> [B*S, Vt] -> [B, S, Vt] + bsz, seqlen, vs_in = input_likelihoods.shape + if vs_in != vs: + # In case logits have extra vocab tail, slice to match + x = input_likelihoods[:, :, :vs] + else: + x = input_likelihoods + x2d = x.reshape(bsz * seqlen, vs) + out2d = torch.matmul(x2d.to(torch.float32), proj_csr.to(torch.float32)) + out = out2d.reshape(bsz, seqlen, target_vocab_size).to(input_likelihoods.dtype) + return out + + @staticmethod + def project_token_likelihoods_dense(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_vectorized=True, gpu_optimized_scatter=True, enable_scale_trick=None, global_top_indices=None): + """ + Projects token likelihoods from a source to a target vocabulary using dense indices/values format. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of target tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + MAJOR SPEEDUP: Reduces both memory and compute significantly. + """ + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if abs(source_vocab_size - projection_map_indices.shape[0]) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + if projection_map_indices.device != device: + projection_map_indices = projection_map_indices.to(device) + if projection_map_values.device != device: + projection_map_values = projection_map_values.to(device) + #do for dtype + if projection_map_values.dtype != input_likelihoods.dtype: + projection_map_values = projection_map_values.to(input_likelihoods.dtype) + + # else: + # projection_map_values = projection_map_values.to(device) + + if use_vectorized: + # Solution 1: Efficient dense implementation using vectorized operations for small top_k + source_vocab_size_fixed = projection_map_indices.shape[0] + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + + # OPTIMIZATION: Use reduced vocabulary if global_top_indices provided + if global_top_indices is not None: + k_indices = len(global_top_indices) + global_top_indices = global_top_indices.to(device) + + # Create mapping from full target indices to reduced indices [0, 1, 2, ..., k-1] + full_to_reduced_map = torch.full((target_vocab_size,), -1, device=device, dtype=torch.long) + full_to_reduced_map[global_top_indices] = torch.arange(k_indices, device=device) + + # Initialize smaller output tensor - MAJOR MEMORY SAVINGS + projected_likelihoods = torch.zeros(batch_size, seq_len, k_indices, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = k_indices + + # Filter projection matrices to only include mappings to global_top_indices + # This will be used in the scatter operations below + use_reduced_projection = True + else: + # Initialize full output tensor + projected_likelihoods = torch.zeros(batch_size, seq_len, target_vocab_size, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = target_vocab_size + use_reduced_projection = False + + # Optimized chunked processing with multiple speedup techniques + # Use larger chunks for better amortization of fixed costs + max_memory_mb = 200 # Increased for better performance + # max_memory_mb = 500 # Increased for better performance + elements_per_chunk = max_memory_mb * 1024 * 1024 // 4 # 4 bytes per float32 + chunk_size = max(512, min(source_vocab_size_fixed, elements_per_chunk // (batch_size * seq_len))) + + + use_masking = False + # Process vocabulary in optimized chunks + for chunk_start in range(0, source_vocab_size_fixed, chunk_size): + chunk_end = min(chunk_start + chunk_size, source_vocab_size_fixed) + chunk_len = chunk_end - chunk_start + + + input_chunk = input_likelihoods_fixed[:, :, chunk_start:chunk_end] # (B, S, chunk_len) + indices_chunk = projection_map_indices[chunk_start:chunk_end, :] # (chunk_len, top_k) + values_chunk = projection_map_values[chunk_start:chunk_end, :] # (chunk_len, top_k) + + # Extract input chunk once per chunk (not per k) - major speedup + # Determine effective top_k (exclude last column if scale trick is enabled) + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + effective_top_k = top_k - 1 if scale_trick_enabled else top_k + # effective_top_k = 1 + + if gpu_optimized_scatter: + if use_masking: + # Process one k at a time to reduce peak memory usage + for k in range(effective_top_k): + values_k = values_chunk[:, k] + valid_mask_k = values_k > 1e-4 + if not valid_mask_k.any(): + continue + + source_indices_k = torch.nonzero(valid_mask_k, as_tuple=True)[0] + + input_subset_k = input_chunk[:, :, source_indices_k] + values_subset_k = values_k[source_indices_k] + + indices_k = indices_chunk[:, k] + target_indices_subset_k = indices_k[source_indices_k] + + weighted_inputs_k = input_subset_k * values_subset_k.view(1, 1, -1) + expanded_target_indices_k = target_indices_subset_k.view(1, 1, -1).expand(batch_size, seq_len, -1) + + projected_likelihoods.scatter_add_(2, expanded_target_indices_k, weighted_inputs_k) + else: + # Compact, un-masked implementation + # Process only effective columns without creating intermediate tensors + input_expanded = input_chunk.unsqueeze(-1) # (B, S, chunk_len, 1) + + for k in range(effective_top_k): + values_k = values_chunk[:, k:k+1] # (chunk_len, 1) - view, no copy + indices_k = indices_chunk[:, k] # (chunk_len,) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + # Map full indices to reduced indices and filter out invalid ones + reduced_indices_k = full_to_reduced_map[indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 # Only keep indices in global_top_indices + + if not valid_mask.any(): + continue # Skip if no valid indices in this chunk + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + values_filtered = values_k.squeeze(-1)[valid_indices] # (valid_count,) + input_filtered = input_chunk[:, :, valid_indices] # (B, S, valid_count) + + weighted_k = input_filtered * values_filtered.unsqueeze(0).unsqueeze(0) + indices_expanded = reduced_indices_filtered.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Standard full projection + weighted_k = input_expanded * values_k.unsqueeze(0).unsqueeze(0) # (B, S, chunk_len, 1) + weighted_k = weighted_k.squeeze(-1) # (B, S, chunk_len) + + indices_expanded = indices_k.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Original implementation with a loop over top_k + if True: # For small top_k, process all k together + # Broadcast input: (B, S, chunk_len, 1) * (1, 1, chunk_len, top_k) -> (B, S, chunk_len, top_k) + weighted_inputs = input_chunk.unsqueeze(-1) * values_chunk.unsqueeze(0).unsqueeze(0) + + # Process all k simultaneously using advanced indexing + for k in range(effective_top_k): + target_indices_k = indices_chunk[:, k] # (chunk_len,) + weighted_k = weighted_inputs[:, :, :, k] # (B, S, chunk_len) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + reduced_indices_k = full_to_reduced_map[target_indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 + + if not valid_mask.any(): + continue # Skip if no valid indices + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + weighted_filtered = weighted_k[:, :, valid_indices] # (B, S, valid_count) + + target_expanded = reduced_indices_filtered.view(1, 1, -1).expand(batch_size, seq_len, len(valid_indices)) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_filtered) + else: + # Use optimized scatter with pre-expanded indices (avoid .expand() in loop) + target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_k) + + # else: # For larger top_k, use optimized sequential processing + # for k in range(top_k): + # target_indices_k = indices_chunk[:, k] # (chunk_len,) + # target_values_k = values_chunk[:, k] # (chunk_len,) + + # # Skip projections marked with -1 + # valid_mask = target_values_k > -0.00001 + # if not valid_mask.any(): + # continue + + # # Only process valid projections + # valid_target_indices = target_indices_k[valid_mask] + # valid_target_values = target_values_k[valid_mask] + # valid_input = input_chunk[valid_mask] + + # weighted_input = valid_input * valid_target_values.view(-1, 1, 1) + + # # Direct scatter (simpler and often faster than index caching) + # target_expanded = valid_target_indices.view(1, 1, -1).expand(batch_size, seq_len, valid_target_indices.size(0)) + # projected_likelihoods.scatter_add_(2, target_expanded, weighted_input) + + return projected_likelihoods + else: + # Solution 2: Sparse matrix approach (original implementation) + source_vocab_size_fixed = projection_map_indices.shape[0] + + # Create sparse CSR matrix + crow_indices = torch.arange(0, (source_vocab_size_fixed + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size_fixed, target_vocab_size), device=device + ) + + # Apply sparse matrix multiplication + input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_projection_matrix.to(torch.float32)) + + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size).to(input_likelihoods.dtype) + + @staticmethod + def project_token_likelihoods_sparse(input_likelihoods, sparse_matrix, device): + """Projects token likelihoods using a sparse transformation matrix.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + + # Get dimensions from sparse matrix + matrix_input_size, matrix_output_size = sparse_matrix.shape + + if abs(source_vocab_size - matrix_input_size) > 1000: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches sparse matrix input size ({matrix_input_size})") + + # Move to correct device and dtype + # input_likelihoods = input_likelihoods.to(device) + # sparse_matrix = sparse_matrix.to(device) + + # Adjust input size to match matrix dimensions + # next 2 lines required when we used vocab length from tokenizer, now we use the size of logits + # source_vocab_size_fixed = min(source_vocab_size, matrix_input_size) + # input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + input_likelihoods_fixed = input_likelihoods + + # Reshape for matrix multiplication + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + # Project using sparse matrix multiplication + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_matrix.to(torch.float32)) + + # Reshape back to original format + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, matrix_output_size).to(input_likelihoods.dtype) + + def align(self, student_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + teacher_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=False, + chunk_size=128, + post_process=True, + convert_ids_to_tokens=True, + anchor_lengths=[3,], + _debug_timing=False): + """Align two sequences of tokens (or batches of sequences).""" + import time as _time + + seq1 = student_seq + seq2 = teacher_seq + + _t_convert = 0.0 + if isinstance(seq1, torch.Tensor): + seq1 = seq1.cpu().tolist() + seq2 = seq2.cpu().tolist() + if convert_ids_to_tokens: + _tc0 = _time.time() + seq1 = [self.student_tokenizer.convert_ids_to_tokens(seq1_single) for seq1_single in seq1] + seq2 = [self.teacher_tokenizer.convert_ids_to_tokens(seq2_single) for seq2_single in seq2] + _t_convert = _time.time() - _tc0 + + is_batched = isinstance(seq1, list) and len(seq1) > 0 and isinstance(seq1[0], list) + + _t_canon_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + if is_batched: + if not (isinstance(seq2, list) and len(seq2) == len(seq1) and (len(seq2) == 0 or isinstance(seq2[0], list))): + raise ValueError("For batched input, seq1 and seq2 must be lists of lists with the same length.") + + all_aligned_pairs = [] + for s1, s2 in zip(seq1, seq2): + aligned_pairs, timings = self._align_single(s1, s2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs.append(aligned_pairs) + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + else: + aligned_pairs, timings = self._align_single(seq1, seq2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, _return_timings=True) + all_aligned_pairs = [aligned_pairs] + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + + if _debug_timing: + n = len(all_aligned_pairs) + print(f" [align timing] convert_ids={_t_convert:.3f}s, " + f"canonicalize={_t_canon_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s " + f"(n={n})", flush=True) + + return all_aligned_pairs + + def _align_single(self, seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=True, + chunk_size=0, + post_process=True, + anchor_lengths=None, + _return_timings=False): + """Align two sequences of tokens.""" + import time as _time + + _tc0 = _time.time() + seq1_canon = TokenAligner._canonicalize_sequence(seq1) + seq2_canon = TokenAligner._canonicalize_sequence(seq2) + _tc1 = _time.time() + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(seq1_canon, seq2_canon, **align_kwargs) + _tc2 = _time.time() + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tc3 = _time.time() + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value in zip(aligned_pairs, mask) + ] + _tc4 = _time.time() + + timings = { + "canon": _tc1 - _tc0, + "anchors_dp": _tc2 - _tc1, + "postprocess": _tc3 - _tc2, + "mask": _tc4 - _tc3, + } + + if _return_timings: + return aligned_pairs, timings + return aligned_pairs + + + def _align_with_anchors(self, seq1, seq2, anchor_lengths=[3,], **kwargs): + """ + Optimized alignment using unique 1-to-1 matches as anchors. + """ + # CRITICAL FIX: If anchor_lengths is empty, disable anchor optimization completely + if not anchor_lengths: + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + if anchor_lengths is None: + anchor_lengths = [3, 2] # Default: check 3-token, then 2-token sequences + + # Debug output + debug = kwargs.get('debug', False) + + # 1. Find high-confidence anchor points using unique token matches. + s1_counts = {} + for i, t in enumerate(seq1): + if t not in s1_counts: s1_counts[t] = [] + s1_counts[t].append(i) + + s2_counts = {} + for i, t in enumerate(seq2): + if t not in s2_counts: s2_counts[t] = [] + s2_counts[t].append(i) + + # Find potential anchors using consecutive token sequences + potential_anchors = [] + + # FIXED: Don't break early - collect anchors from all lengths and then choose the best + all_potential_anchors = [] + + # Check for anchors of different lengths + for anchor_len in anchor_lengths: + anchors_for_this_len = [] + + if anchor_len == 1: + # Handle single token anchors + common_tokens = s1_counts.keys() & s2_counts.keys() + for token in common_tokens: + if len(s1_counts[token]) == 1 and len(s2_counts[token]) == 1: + i = s1_counts[token][0] + j = s2_counts[token][0] + anchors_for_this_len.append((i, j, anchor_len)) + else: + # Handle multi-token anchors + s1_ngram_counts = {} + for i in range(len(seq1) - anchor_len + 1): + ngram = tuple(seq1[i:i + anchor_len]) + if ngram not in s1_ngram_counts: + s1_ngram_counts[ngram] = [] + s1_ngram_counts[ngram].append(i) + + s2_ngram_counts = {} + for i in range(len(seq2) - anchor_len + 1): + ngram = tuple(seq2[i:i + anchor_len]) + if ngram not in s2_ngram_counts: + s2_ngram_counts[ngram] = [] + s2_ngram_counts[ngram].append(i) + + # Find n-grams that appear exactly once in both sequences + common_ngrams = s1_ngram_counts.keys() & s2_ngram_counts.keys() + for ngram in common_ngrams: + if len(s1_ngram_counts[ngram]) == 1 and len(s2_ngram_counts[ngram]) == 1: + i = s1_ngram_counts[ngram][0] + j = s2_ngram_counts[ngram][0] + # ADDED: Verify the anchor is actually correct + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + anchors_for_this_len.append((i, j, anchor_len)) + + all_potential_anchors.extend(anchors_for_this_len) + + # IMPROVED: Choose the best set of anchors + # Prefer longer anchors, but if shorter anchors give better coverage, use them + + # Sort by position and filter for monotonic ordering + all_potential_anchors.sort() + + # IMPROVED: Better anchor selection - use greedy approach to maximize coverage + selected_anchors = [] + used_positions_seq1 = set() + used_positions_seq2 = set() + + # Sort by anchor length (descending) then by position + all_potential_anchors.sort(key=lambda x: (-x[2], x[0], x[1])) + + for i, j, anchor_len in all_potential_anchors: + # Check if this anchor conflicts with already selected ones + seq1_range = set(range(i, i + anchor_len)) + seq2_range = set(range(j, j + anchor_len)) + + if not (seq1_range & used_positions_seq1) and not (seq2_range & used_positions_seq2): + # This anchor doesn't conflict - we can use it + selected_anchors.append((i, j, anchor_len)) + used_positions_seq1.update(seq1_range) + used_positions_seq2.update(seq2_range) + + # Re-sort selected anchors by position for processing + selected_anchors.sort() + + # IMPROVED: Additional validation of selected anchors + validated_anchors = [] + last_j = -1 + for i, j, anchor_len in selected_anchors: + # Ensure monotonic ordering and no overlaps + if j > last_j: + # Double-check the anchor is valid + if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and + seq1[i:i + anchor_len] == seq2[j:j + anchor_len]): + validated_anchors.append((i, j, anchor_len)) + last_j = j + anchor_len - 1 + + anchors = validated_anchors + + if not anchors: + # If no anchors are found, fall back to the standard alignment. + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + # 2. Align segments between anchors. + full_alignment = [] + last_i, last_j = 0, 0 + + for anchor_idx, (i, j, anchor_len) in enumerate(anchors): + + # Align segment before the current anchor. + seg1, seg2 = seq1[last_i:i], seq2[last_j:j] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + # Adjust indices to be relative to the full sequence and split exact matches. + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Add the anchor itself (consecutive tokens), also split if needed. + anchor_seq1 = seq1[i:i + anchor_len] + anchor_seq2 = seq2[j:j + anchor_len] + + # Split anchor into individual matches since they should be identical + for k in range(anchor_len): + full_alignment.append(( + [anchor_seq1[k]], [anchor_seq2[k]], + i + k, i + k + 1, + j + k, j + k + 1 + )) + + last_i, last_j = i + anchor_len, j + anchor_len + + # 3. Align the final segment after the last anchor. + seg1, seg2 = seq1[last_i:], seq2[last_j:] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + return full_alignment, 0 # Return 0 for score as it's not well-defined here + + def _perform_dp_alignment(self, seq1, seq2, **kwargs): + """ + Helper function to run the core DP-based alignment. + """ + chunk_size = kwargs.get('chunk_size', 0) + kwargs.pop('chunk_size', None) + kwargs.pop('anchor_lengths', None) + + if chunk_size > 0: + return self.align_tokens_combinations_chunked(seq1, seq2, chunk_size=chunk_size, **kwargs) + else: + return self.align_tokens_with_combinations_numpy_jit(seq1, seq2, **kwargs) + + @staticmethod + def _canonical_token(token: str) -> str: + """Return a canonical representation of a tokenizer token.""" + if not token: + return token + + # 1. Normalize space prefixes first + if token.startswith(' '): + token = 'Ġ' + token[1:] + elif token.startswith('_'): + token = 'Ġ' + token[1:] + elif token.startswith('▁'): # SentencePiece-style space prefix + token = 'Ġ' + token[1:] + + # 1.5. Normalize newline and whitespace representations + if token == 'Ċ': # GPT-style newline (used by Llama) + token = '\n' + elif token == '\\n': # Escaped newline representation + token = '\n' + elif token == 'ĉ': # Alternative newline representation + token = '\n' + elif token == 'Ġ\n': # Space + newline combination + token = '\n' + elif 'Ċ' in token: # Handle Ċ embedded in other tokens + token = token.replace('Ċ', '\n') + elif '\\n' in token: # Handle escaped newlines in compound tokens + token = token.replace('\\n', '\n') + + # 1.6. Handle space-separated punctuation normalization + if token == 'Ġ,': # Space + comma + token = ',' + elif token == 'Ġ.': # Space + period + token = '.' + elif token == 'Ġ;': # Space + semicolon + token = ';' + elif token == 'Ġ:': # Space + colon + token = ':' + + # 2. Handle SentencePiece byte fallback tokens like <0x20> + if token.startswith('<0x') and token.endswith('>') and len(token) == 6: + try: + byte_val = int(token[3:5], 16) + if 0 <= byte_val <= 255: + return chr(byte_val) + except ValueError: + pass + + # 3. Normalize common Unicode encoding issues + unicode_fixes = { + # Spanish + 'ñ': 'ñ', 'á': 'á', 'é': 'é', 'í': 'í', 'ó': 'ó', 'ú': 'ú', + 'Ã': 'À', 'â': 'â', 'ç': 'ç', + # French + 'ç': 'ç', 'è': 'è', 'é': 'é', 'ë': 'ë', 'î': 'î', 'ô': 'ô', + 'ù': 'ù', 'û': 'û', 'ÿ': 'ÿ', + # Chinese (common encoding artifacts) + 'ä¸Ń': '中', 'æĸĩ': '文', 'æĹ¥æľ¬': '日本', 'èªŀ': '語', + # Russian + 'ÐłÑĥÑģ': 'Рус', 'Ñģкий': 'ский', + # Arabic + 'اÙĦعربÙĬØ©': 'العربية', + # Hindi + 'ह': 'ह', 'िà¤Ĥ': 'हिं', 'दà¥Ģ': 'दी', + # Mathematical symbols (common artifacts) + 'âĪij': '∑', 'âĪı': '∏', 'âĪĤ': '∂', 'âĪĩ': '∇', + 'âĪŀ': '∞', 'âĪļ': '√', 'âĪ«': '∫', 'âīĪ': '≈', + 'âīł': '≠', 'âī¤': '≤', 'âī¥': '≥', + } + + # Apply Unicode fixes + for broken, fixed in unicode_fixes.items(): + if broken in token: + token = token.replace(broken, fixed) + + # 4. Normalize special tokens + special_token_map = { + '<|begin_of_text|>': '', # Llama-style BOS token + '': '', # Standard BOS token + '': '', # Padding tokens → empty (will be handled by alignment) + '': ' ', # End tokens + '': ' ', # End tokens + } + + if token in special_token_map: + return special_token_map[token] + + return token + + @staticmethod + def _canonicalize_sequence(seq: List[str]) -> List[str]: + """Canonicalize every token in a sequence (list of str).""" + # First, handle multi-token encoding artifacts (before individual canonicalization) + merged_artifacts = TokenAligner._merge_encoding_artifacts(seq) + + # Then, canonicalize individual tokens + canon_tokens = [TokenAligner._canonical_token(tok) for tok in merged_artifacts] + + # Finally, merge consecutive byte tokens into proper Unicode characters + return TokenAligner._merge_consecutive_bytes(canon_tokens) + + @staticmethod + def _merge_encoding_artifacts(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent multi-token encoding artifacts.""" + if not tokens: + return tokens + + # Common multi-token encoding artifacts that should be merged + multi_token_fixes = [ + # Mathematical symbols split across tokens + (['ĠâĪ', 'ij'], ['Ġ∑']), # Sum symbol + (['âĪ', 'ij'], ['∑']), # Sum symbol (no space) + (['ĠâĪ', 'ı'], ['Ġ∏']), # Product symbol + (['âĪ', 'ı'], ['∏']), # Product symbol (no space) + (['ĠâĪ', 'Ĥ'], ['Ġ∂']), # Partial derivative + (['âĪ', 'Ĥ'], ['∂']), # Partial derivative (no space) + (['ĠâĪ', 'ĩ'], ['Ġ∇']), # Nabla/gradient + (['âĪ', 'ĩ'], ['∇']), # Nabla/gradient (no space) + (['ĠâĪ', 'ŀ'], ['Ġ∞']), # Infinity + (['âĪ', 'ŀ'], ['∞']), # Infinity (no space) + (['ĠâĪ', 'ļ'], ['Ġ√']), # Square root + (['âĪ', 'ļ'], ['√']), # Square root (no space) + (['ĠâĪ', '«'], ['Ġ∫']), # Integral + (['âĪ', '«'], ['∫']), # Integral (no space) + (['Ġâī', 'ł'], ['Ġ≠']), # Not equal + (['âī', 'ł'], ['≠']), # Not equal (no space) + # Other common multi-token artifacts + (['Ġä¸', 'Ń'], ['Ġ中']), # Chinese character + (['ä¸', 'Ń'], ['中']), # Chinese character (no space) + (['æĸ', 'ĩ'], ['文']), # Chinese character + (['Ġæĸ', 'ĩ'], ['Ġ文']), # Chinese character (with space) + ] + + result = [] + i = 0 + + while i < len(tokens): + # Check if current position matches any multi-token pattern + matched = False + + for pattern, replacement in multi_token_fixes: + pattern_len = len(pattern) + if i + pattern_len <= len(tokens): + # Check if the tokens match the pattern + if tokens[i:i+pattern_len] == pattern: + # Replace with the fixed version + result.extend(replacement) + i += pattern_len + matched = True + break + + if not matched: + # No pattern matched, keep the original token + result.append(tokens[i]) + i += 1 + + return result + + @staticmethod + def _merge_consecutive_bytes(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent UTF-8 byte sequences.""" + if not tokens: + return tokens + + result = [] + byte_buffer = [] + + for token in tokens: + # Check if this token represents byte(s) + clean_token = token.lstrip('Ġ') + + # Check if all characters in the token are visual bytes + all_chars_are_bytes = True + if len(clean_token) == 0: + all_chars_are_bytes = False + else: + for char in clean_token: + if TokenAligner._get_byte_value(char) is None: + all_chars_are_bytes = False + break + + if all_chars_are_bytes: + byte_buffer.append(token) + else: + # Not a byte token, flush buffer first + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + byte_buffer = [] + result.append(token) + + # Flush any remaining bytes + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + + return result + + @staticmethod + def _try_merge_byte_buffer(byte_tokens: List[str]) -> List[str]: + """Try to merge a buffer of potential byte tokens into a Unicode character.""" + if not byte_tokens: + return [] + + # If only one token, just return it unless it's a multi-character byte token + if len(byte_tokens) == 1: + token = byte_tokens[0] + clean_token = token.lstrip('Ġ') + if len(clean_token) <= 1: + return byte_tokens + # Continue processing multi-character token + + # Extract space prefix from first token + first_token = byte_tokens[0] + space_prefix = 'Ġ' if first_token.startswith('Ġ') else '' + + # Extract raw bytes from all characters in all tokens + raw_bytes = [] + for token in byte_tokens: + clean_token = token.lstrip('Ġ') + for char in clean_token: + byte_value = TokenAligner._get_byte_value(char) + if byte_value is not None: + raw_bytes.append(byte_value) + else: + # If any character is not a byte, return original tokens + return byte_tokens + + # Only try to merge if we have 2-4 bytes (typical for emoji/multi-byte chars) + if len(raw_bytes) < 2 or len(raw_bytes) > 4: + return byte_tokens + + # Try to decode as UTF-8 + try: + decoded_text = bytes(raw_bytes).decode('utf-8') + # Only merge if the result is a single Unicode character (like an emoji) + if len(decoded_text) == 1 and ord(decoded_text) > 127: + return [space_prefix + decoded_text] + else: + # If it's not a single special character, keep original tokens + return byte_tokens + except UnicodeDecodeError: + # If decoding fails, return original tokens + return byte_tokens + + # Common visual byte representations used by some tokenizers (especially for emojis) + VISUAL_BYTE_MAP = { + # Common emoji byte range (240-255) + 'ð': 240, 'Ɩ': 241, 'Ɨ': 242, 'Ƙ': 243, 'ƙ': 244, 'ƚ': 245, 'ƛ': 246, 'Ɯ': 247, + 'Ɲ': 248, 'ƞ': 249, 'Ɵ': 250, 'Ơ': 251, 'ơ': 252, 'Ƣ': 253, 'ƣ': 254, 'Ƥ': 255, + # Other common byte representations (0-255 only) + 'Ł': 156, 'ł': 157, 'Ń': 158, 'ń': 159, 'ĺ': 149, 'Ļ': 150, 'ļ': 151, 'Ľ': 152, + 'ľ': 153, 'Ŀ': 154, 'ŀ': 155, 'Ĭ': 135, 'ĭ': 136, 'Į': 137, 'į': 138, 'İ': 139, + 'ı': 140, 'IJ': 141, 'ij': 142, 'Ĵ': 143, 'ĵ': 144, 'Ķ': 145, 'ķ': 146, 'ĸ': 147, + 'Ĺ': 148, 'ĥ': 128, 'Ħ': 129, 'ħ': 130, 'Ĩ': 131, 'ĩ': 132, 'Ī': 133, 'ī': 134, + 'Ģ': 162, 'ģ': 163, 'Ĝ': 28, 'ĝ': 29, 'Ğ': 30, 'ğ': 31, + } + + @staticmethod + def _get_byte_value(token_char: str) -> int: + """Get the byte value for a character, handling both direct bytes and visual representations.""" + if len(token_char) != 1: + return None + + char_ord = ord(token_char) + + # Direct byte (0-255) + if char_ord < 256: + return char_ord + + # Visual byte representation + if token_char in TokenAligner.VISUAL_BYTE_MAP: + return TokenAligner.VISUAL_BYTE_MAP[token_char] + + return None + + @staticmethod + def _strings_equal_flexible(s1, s2, ignore_leading_char_diff): + if not ignore_leading_char_diff: + return s1 == s2 + + # Use our comprehensive canonicalization for robust comparison + s1_canonical = TokenAligner._canonical_token(s1) + s2_canonical = TokenAligner._canonical_token(s2) + + return s1_canonical == s2_canonical + + def align_tokens_with_combinations_numpy(seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False): + n1, n2 = len(seq1), len(seq2) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.full((n1 + 1, n2 + 1), '', dtype=object) + + # Initialize DP edges with gap penalties + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 'up' + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 'left' + + # Precompute joined substrings for all valid k-length spans + joined_seq1 = {(i - k, i): ''.join(seq1[i - k:i]) + for i in range(n1 + 1) + for k in range(1, min(i, max_combination_len) + 1)} + joined_seq2 = {(j - k, j): ''.join(seq2[j - k:j]) + for j in range(n2 + 1) + for k in range(1, min(j, max_combination_len) + 1)} + + # Fill DP table + for i in range(1, n1 + 1): + for j in range(1, n2 + 1): + s1_val, s2_val = seq1[i - 1], seq2[j - 1] + match_score = exact_match_score if TokenAligner._strings_equal_flexible(s1_val, s2_val, ignore_leading_char_diff) else -exact_match_score + score_diag = dp[i - 1, j - 1] + match_score + score_up = dp[i - 1, j] + gap_penalty + score_left = dp[i, j - 1] + gap_penalty + + max_score = score_diag + best_move = 'diag' + if score_up > max_score: + max_score = score_up + best_move = 'up' + if score_left > max_score: + max_score = score_left + best_move = 'left' + + # Check for seq1[i-1] == join(seq2[j-k:j]) + for k in range(2, min(j + 1, max_combination_len + 1)): + if (j - k, j) in joined_seq2 and TokenAligner._strings_equal_flexible(s1_val, joined_seq2[(j - k, j)], ignore_leading_char_diff): + comb_score = dp[i - 1, j - k] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s1_over_s2_{k}' + + # Check for seq2[j-1] vs seq1[i-k:i] + for k in range(2, min(i + 1, max_combination_len + 1)): + if (i - k, i) in joined_seq1 and TokenAligner._strings_equal_flexible(s2_val, joined_seq1[(i - k, i)], ignore_leading_char_diff): + comb_score = dp[i - k, j - 1] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s2_over_s1_{k}' + + dp[i, j] = max_score + trace[i, j] = best_move + + # Backtrack to extract alignment + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + move = trace[i, j] + if move == 'diag': + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == 'up': + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == 'left': + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif move.startswith('comb_s1_over_s2_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif move.startswith('comb_s2_over_s1_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, dp[n1, n2] + + @staticmethod + def align_tokens_with_combinations_numpy_jit( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + ): + """Numba-accelerated version of align_tokens_with_combinations_numpy. + + Pre-converts string tokens to integer IDs, runs the DP in a Numba + @njit kernel, then backtracks using the original string tokens. + Falls back to the pure-Python original when Numba is unavailable or + when ignore_leading_char_diff is True (requires Python string logic). + """ + if not _NUMBA_AVAILABLE or ignore_leading_char_diff: + return TokenAligner.align_tokens_with_combinations_numpy( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + + n1, n2 = len(seq1), len(seq2) + if n1 == 0 and n2 == 0: + return [], 0.0 + if n1 == 0: + return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty + if n2 == 0: + return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty + + token_to_id: dict[str, int] = {} + _next_id = [0] + + def _get_id(s: str) -> int: + tid = token_to_id.get(s) + if tid is None: + tid = _next_id[0] + token_to_id[s] = tid + _next_id[0] += 1 + return tid + + ids1 = np.array([_get_id(t) for t in seq1], dtype=np.int64) + ids2 = np.array([_get_id(t) for t in seq2], dtype=np.int64) + + INVALID = np.int64(-1) + joined1 = np.full((n1 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for i in range(n1 + 1): + for k in range(2, min(i, max_combination_len) + 1): + joined1[i, k] = _get_id(''.join(seq1[i - k:i])) + + joined2 = np.full((n2 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for j in range(n2 + 1): + for k in range(2, min(j, max_combination_len) + 1): + joined2[j, k] = _get_id(''.join(seq2[j - k:j])) + + dp, trace = _dp_core_numba( + ids1, ids2, joined1, joined2, n1, n2, + np.float32(exact_match_score), + np.float32(gap_penalty), + np.float32(combination_score_multiplier), + max_combination_len, + ) + + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + m = trace[i, j] + if m == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif m == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif m == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= m < 20: + k = m - 10 + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif 20 <= m < 30: + k = m - 20 + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, float(dp[n1, n2]) + + @staticmethod + def align_tokens_combinations_chunked( + seq1: List[str], + seq2: List[str], + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + ignore_leading_char_diff: bool = False, + chunk_size: int = 256, + ): + """ + Chunked processing for very large sequences. + """ + n1, n2 = len(seq1), len(seq2) + + # If sequences are small enough, use regular algorithm + if n1 <= chunk_size and n2 <= chunk_size: + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + # For very large sequences, use divide-and-conquer approach + if n1 > chunk_size or n2 > chunk_size: + # Find approximate midpoint alignment using simplified algorithm + mid1, mid2 = n1 // 2, n2 // 2 + + # Recursively align left and right parts + left_aligned, left_score = TokenAligner.align_tokens_combinations_chunked( + seq1[:mid1], seq2[:mid2], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + right_aligned, right_score = TokenAligner.align_tokens_combinations_chunked( + seq1[mid1:], seq2[mid2:], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + chunk_size=chunk_size + ) + + # Adjust indices for right part + adjusted_right = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned: + new_s1_start = s1_start + mid1 if s1_start >= 0 else -1 + new_s1_end = s1_end + mid1 if s1_end >= 0 else -1 + new_s2_start = s2_start + mid2 if s2_start >= 0 else -1 + new_s2_end = s2_end + mid2 if s2_end >= 0 else -1 + adjusted_right.append((s1_tokens, s2_tokens, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Combine results + combined_aligned = left_aligned + adjusted_right + combined_score = left_score + right_score + + return combined_aligned, combined_score + + # Fallback to regular algorithm + return TokenAligner.align_tokens_with_combinations_numpy_jit( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff + ) + + @staticmethod + def _combine_consecutive_misaligned_tokens( + aligned_pairs: List, + pair_strings: List, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Combine consecutive misaligned tokens into single chunks to improve alignment. + + This addresses cases where multiple tokens are individually misaligned but + collectively represent the same content. Avoids combining tokens near the + end of sequences that might be misaligned due to length differences. + + Args: + aligned_pairs: List of alignment pairs + pair_strings: Precomputed string representations and match status + end_mismatch_threshold: Fraction of sequence from end to avoid chunking + + Returns: + Modified aligned_pairs with consecutive misaligned tokens combined + """ + if not aligned_pairs or len(aligned_pairs) < 2: + return aligned_pairs + + # Calculate the boundary for avoiding end mismatches + sequence_length = len(aligned_pairs) + end_boundary = int(sequence_length * (1 - end_mismatch_threshold)) + + processed_pairs = [] + i = 0 + + while i < len(aligned_pairs): + # Check if current pair is misaligned and not near the end + if (i < end_boundary and + not pair_strings[i][2] and # Current pair is misaligned + i + 1 < len(aligned_pairs)): # Not the last pair + + # Find consecutive misaligned pairs + consecutive_misaligned = [i] + j = i + 1 + + # Look ahead for more consecutive misaligned pairs (up to end boundary) + while (j < end_boundary and + j < len(aligned_pairs) and + not pair_strings[j][2]): # Next pair is also misaligned + consecutive_misaligned.append(j) + j += 1 + + # Only combine if we have multiple consecutive misaligned pairs + if len(consecutive_misaligned) >= 2: + # Combine all consecutive misaligned pairs into one chunk + combined_s1_tokens = [] + combined_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for idx in consecutive_misaligned: + s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest = aligned_pairs[idx] + combined_s1_tokens.extend(s1_tokens) + combined_s2_tokens.extend(s2_tokens) + + if s1_tokens and s1_start != -1: + s1_indices.extend([s1_start, s1_end]) + if s2_tokens and s2_start != -1: + s2_indices.extend([s2_start, s2_end]) + + # Calculate combined indices + combined_s1_start = min(s1_indices[::2]) if s1_indices else -1 + combined_s1_end = max(s1_indices[1::2]) if s1_indices else -1 + combined_s2_start = min(s2_indices[::2]) if s2_indices else -1 + combined_s2_end = max(s2_indices[1::2]) if s2_indices else -1 + + # Create combined pair + combined_pair = ( + combined_s1_tokens, + combined_s2_tokens, + combined_s1_start, + combined_s1_end, + combined_s2_start, + combined_s2_end + ) + + processed_pairs.append(combined_pair) + i = j # Skip to after the combined region + else: + # Only one misaligned pair, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + else: + # Current pair is aligned or near the end, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + + return processed_pairs + + + @staticmethod + def post_process_alignment_optimized( + aligned_pairs: List, + ignore_leading_char_diff: bool = False, + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + combine_misaligned_chunks: bool = True, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Optimized version of post_process_alignment with better performance. + + Key optimizations: + 1. Precompute string concatenations to avoid repeated joins + 2. Early termination when no bad regions are found + 3. Cache alignment results for repeated chunk patterns + 4. Vectorized index calculations + 5. Reduced nested loop complexity + 6. Combine multiple consecutive misaligned tokens into single chunks + + Args: + combine_misaligned_chunks: If True, combine consecutive misaligned tokens into chunks + end_mismatch_threshold: Fraction of sequence length from end to avoid chunking (0.2 = last 20%) + """ + if not aligned_pairs: + return [] + + # Precompute joined strings for all pairs to avoid repeated concatenation + # Use canonicalization for robust comparison + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + # Canonicalize individual tokens before joining for better matching + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + # Step 1: Handle consecutive misaligned chunks if enabled + if combine_misaligned_chunks: + aligned_pairs = TokenAligner._combine_consecutive_misaligned_tokens( + aligned_pairs, pair_strings, end_mismatch_threshold + ) + + # Recompute pair_strings after combining misaligned chunks + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + processed_pairs = [] + alignment_cache = {} # Cache for repeated alignment patterns + i = 0 + + while i < len(aligned_pairs): + s1_tokens, s2_tokens, *_ = aligned_pairs[i] + + # Handle coarse alignments that can be split (optimized) + if len(s1_tokens) > 1 and len(s1_tokens) == len(s2_tokens) and s1_tokens == s2_tokens: + s1_start, s1_end, s2_start, s2_end = aligned_pairs[i][2:6] + # Vectorized creation of split pairs + for k in range(len(s1_tokens)): + processed_pairs.append( + ([s1_tokens[k]], [s2_tokens[k]], + s1_start + k, s1_start + k + 1, + s2_start + k, s2_start + k + 1) + ) + i += 1 + continue + + # Find bad regions more efficiently using precomputed strings + start_bad_region = -1 + for j in range(i, len(aligned_pairs)): + if not pair_strings[j][2]: # is_match is False + start_bad_region = j + break + + if start_bad_region == -1: + # No more bad regions - add remaining pairs and exit + processed_pairs.extend(aligned_pairs[i:]) + break + + # Add good pairs before bad region + processed_pairs.extend(aligned_pairs[i:start_bad_region]) + + # Optimized chunk processing with early termination + found_fix = False + max_chunk_size = min(10, len(aligned_pairs) - start_bad_region) # Limit search space + + for chunk_size in range(2, max_chunk_size + 1): + chunk = aligned_pairs[start_bad_region : start_bad_region + chunk_size] + + # Efficient token extraction using list comprehension + chunk_s1_tokens = [] + chunk_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *rest in chunk: + chunk_s1_tokens.extend(s1_toks) + chunk_s2_tokens.extend(s2_toks) + if s1_toks: + s1_indices.extend([s1_start, s1_end]) + if s2_toks: + s2_indices.extend([s2_start, s2_end]) + + # Quick string comparison using canonicalization + chunk_s1_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s1_tokens] + chunk_s2_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s2_tokens] + chunk_s1_str = "".join(chunk_s1_canonical) + chunk_s2_str = "".join(chunk_s2_canonical) + + if not TokenAligner._strings_equal_flexible(chunk_s1_str, chunk_s2_str, ignore_leading_char_diff): + continue + + # Create cache key for alignment + cache_key = (tuple(chunk_s1_tokens), tuple(chunk_s2_tokens)) + + if cache_key in alignment_cache: + sub_aligned_pairs, realign_is_perfect = alignment_cache[cache_key] + else: + # Perform alignment + sub_aligned_pairs, _ = TokenAligner.align_tokens_with_combinations_numpy( + chunk_s1_tokens, + chunk_s2_tokens, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + ignore_leading_char_diff=ignore_leading_char_diff + ) + + # Check if re-alignment was successful using canonicalization + realign_is_perfect = all( + TokenAligner._strings_equal_flexible( + "".join([TokenAligner._canonical_token(tok) for tok in p[0]]), + "".join([TokenAligner._canonical_token(tok) for tok in p[1]]), + ignore_leading_char_diff + ) + for p in sub_aligned_pairs + ) + + # Cache the result + alignment_cache[cache_key] = (sub_aligned_pairs, realign_is_perfect) + + # Vectorized index calculations + s1_chunk_start = min(s1_indices[::2]) if s1_indices else -1 + s2_chunk_start = min(s2_indices[::2]) if s2_indices else -1 + + if realign_is_perfect: + # Add granular aligned pairs + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *_ in sub_aligned_pairs: + new_s1_start = s1_chunk_start + s1_start if s1_start != -1 else -1 + new_s1_end = s1_chunk_start + s1_end if s1_end != -1 else -1 + new_s2_start = s2_chunk_start + s2_start if s2_start != -1 else -1 + new_s2_end = s2_chunk_start + s2_end if s2_end != -1 else -1 + processed_pairs.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + else: + # Create merged pair + s1_chunk_end = max(s1_indices[1::2]) if s1_indices else -1 + s2_chunk_end = max(s2_indices[1::2]) if s2_indices else -1 + merged_pair = (chunk_s1_tokens, chunk_s2_tokens, s1_chunk_start, s1_chunk_end, s2_chunk_start, s2_chunk_end) + processed_pairs.append(merged_pair) + + i = start_bad_region + chunk_size + found_fix = True + break + + if not found_fix: + processed_pairs.append(aligned_pairs[start_bad_region]) + i = start_bad_region + 1 + + return processed_pairs + + @staticmethod + def get_alignment_mask(aligned_pairs: List, use_canonicalization: bool = True, + ignore_leading_char_diff: bool = False) -> List[bool]: + """ + Get a boolean mask indicating which alignments are correct. + """ + if not aligned_pairs: + return [] + + # Handle batch case - take first batch + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + pairs_to_verify = aligned_pairs[0] + else: + pairs_to_verify = aligned_pairs + + mask = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest in pairs_to_verify: + # Concatenate tokens into strings + s1_str = "".join(s1_tokens) if s1_tokens else "" + s2_str = "".join(s2_tokens) if s2_tokens else "" + + # Apply canonicalization if requested + if use_canonicalization: + s1_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s1_tokens]) if s1_tokens else "" + s2_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s2_tokens]) if s2_tokens else "" + is_correct = TokenAligner._strings_equal_flexible(s1_canonical, s2_canonical, ignore_leading_char_diff) + else: + if ignore_leading_char_diff: + is_correct = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + else: + is_correct = s1_str == s2_str + + mask.append(is_correct) + + return mask + + + def transform_learned_matrix_instance(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Instance method version that uses instance variables. + """ + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.transform_learned_matrix(x, dim, enable_scale_trick=scale_trick_enabled) + + @staticmethod + def transform_learned_matrix(x: torch.Tensor, dim: int = -1, enable_scale_trick=None) -> torch.Tensor: + """ + Compute Quite Attention over tensor x along specified dimension. + + Args: + x: Input tensor. + dim: Dimension to apply attention over (default: -1). + + Returns: + Tensor of same shape with quite attention applied. + """ + if 0: + exp_x = torch.exp(x) + denom = 1 + torch.sum(exp_x, dim=dim, keepdim=True) + return exp_x / denom + # write as a single lambda function + # return lambda x: torch.exp(x) / (1 + torch.sum(torch.exp(x), dim=dim, keepdim=True)) + else: + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + if scale_trick_enabled: + #trick with last column being multiplier of 0..1, or try with c instead of 1 in qa. + scores = torch.nn.functional.softmax(x, dim=dim) + # Create a mask to zero out the last column while preserving gradients + # mask = torch.ones_like(scores) + # mask[:, -1] = 0.0 + # scores = scores * mask + # Alternative approach using sigmoid (commented out): + # scores = scores * torch.sigmoid(x[:, -1].unsqueeze(-1)) + return scores + else: + #normal softmax + return torch.nn.functional.softmax(x, dim=dim) + return torch.nn.functional.softmax(x, dim=dim) diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index e77e97a2b0..fbced245f7 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -57,6 +57,10 @@ class DataConfig(TypedDict): # This saturates CPU threads without consuming too much memory # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] + # PyTorch DataLoader prefetch_factor: number of batches each worker pre-fetches. + # Combined with num_workers, controls how much cross-tokenizer alignment work + # is run ahead of the training step in dataloader processes. + prefetch_factor: NotRequired[int] # multiple dataloader configs # currently only supported for GRPO use_multiple_dataloader: NotRequired[bool] diff --git a/nemo_rl/data/cross_tokenizer_collate.py b/nemo_rl/data/cross_tokenizer_collate.py new file mode 100644 index 0000000000..a658712b5e --- /dev/null +++ b/nemo_rl/data/cross_tokenizer_collate.py @@ -0,0 +1,352 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cross-tokenizer collate function for off-policy distillation. + +Moves teacher tokenize + DP alignment off the training critical path and into +``StatefulDataLoader`` worker processes. With ``num_workers=N, prefetch_factor=P`` +there are up to ``N*P`` batches of CT work in flight, so the consumer pulls +already-processed batches and CT is hidden behind teacher forward. + +Mirrors the train_distillation_ddp / TokenizeAndAlignCollator shape in +tokenalign/src/pytorch_data_loader.py. +""" + +from __future__ import annotations + +import sys +from typing import Any, Optional + +if sys.version_info >= (3, 11): + from typing import NotRequired, TypedDict +else: + from typing import TypedDict + from typing_extensions import NotRequired + +import torch +from transformers import AutoTokenizer + +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + +class TeacherCTSpec(TypedDict): + """Per-teacher spec passed to CrossTokenizerCollator. + + All fields are pickle-cheap primitives so the collator itself ships + cheaply to DataLoader workers. Tokenizers and aligners are built lazily + in each worker. + """ + + teacher_tokenizer_name: str + student_tokenizer_name: str + projection_matrix_path: str + use_sparse_format: bool + learnable: bool + max_comb_len: int + projection_matrix_multiplier: float + project_teacher_to_student: bool + max_teacher_len: int + dp_chunk_size: int + use_align_fast: bool + exact_token_match_only: bool + + +def _build_chunk_coo( + aligned_pairs: list, + student_seq_len: int, + teacher_seq_len: int, + exact_match_only: bool, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Pre-filter alignment pairs and emit per-sample chunk mask indices. + + Applies the exact_match_only, ``-1`` sentinel, and padded-length bounds + filters that the loss fn used to run per-microbatch, then flattens each + surviving chunk's student/teacher spans into COO rows ``[pos, chunk_id]`` + that the loss fn can ``index_put_`` into the dense ``proj_mask``/``tgt_mask``. + + Returns ``(student_chunk_coo, teacher_chunk_coo, num_chunks)``. COO tensors + are empty ``(0, 2)`` when a sample has no surviving chunks. + """ + student_rows: list[tuple[int, int]] = [] + teacher_rows: list[tuple[int, int]] = [] + chunk_id = 0 + for pair in aligned_pairs: + s1_start, s1_end, s2_start, s2_end = pair[2], pair[3], pair[4], pair[5] + if exact_match_only and ( + s1_end - s1_start != 1 or s2_end - s2_start != 1 + ): + continue + if s1_start == -1 or s2_start == -1: + continue + if s1_end > student_seq_len or s2_end > teacher_seq_len: + continue + for pos in range(s1_start, s1_end): + student_rows.append((pos, chunk_id)) + for pos in range(s2_start, s2_end): + teacher_rows.append((pos, chunk_id)) + chunk_id += 1 + + student_coo = ( + torch.tensor(student_rows, dtype=torch.int64) + if student_rows + else torch.empty((0, 2), dtype=torch.int64) + ) + teacher_coo = ( + torch.tensor(teacher_rows, dtype=torch.int64) + if teacher_rows + else torch.empty((0, 2), dtype=torch.int64) + ) + return student_coo, teacher_coo, chunk_id + + +class CrossTokenizerCollator: + """Collator that does base collate + message flatten + per-teacher CT. + + Designed to be pickled into ``torch.utils.data.DataLoader`` worker + processes. On first call inside a worker, it constructs its own + ``TokenAligner`` instances and teacher tokenizers from the specs, then + reuses them for every subsequent batch in that worker. + + DP-only: does not run the char-offset alignment fast path (consumers + that want char-offset should restore it here). + """ + + def __init__( + self, + pad_token_id: int, + make_sequence_length_divisible_by: int, + teacher_ct_specs: list[Optional[TeacherCTSpec]], + fallback_student_tokenizer_name: Optional[str] = None, + ) -> None: + self.pad_token_id = pad_token_id + self.make_seq_div_by = int(make_sequence_length_divisible_by) + self.teacher_ct_specs = list(teacher_ct_specs) + self.fallback_student_tokenizer_name = fallback_student_tokenizer_name + + # Lazy per-worker state — excluded from __getstate__ below. + self._initialized: bool = False + self._aligners: list[Optional[Any]] = [] + self._teacher_tokenizers: list[Optional[Any]] = [] + self._student_tokenizer: Optional[Any] = None + + def __getstate__(self) -> dict[str, Any]: + # Only ship pickle-cheap primitives across the fork/spawn boundary. + return { + "pad_token_id": self.pad_token_id, + "make_seq_div_by": self.make_seq_div_by, + "teacher_ct_specs": self.teacher_ct_specs, + "fallback_student_tokenizer_name": self.fallback_student_tokenizer_name, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.pad_token_id = state["pad_token_id"] + self.make_seq_div_by = state["make_seq_div_by"] + self.teacher_ct_specs = state["teacher_ct_specs"] + self.fallback_student_tokenizer_name = state["fallback_student_tokenizer_name"] + self._initialized = False + self._aligners = [] + self._teacher_tokenizers = [] + self._student_tokenizer = None + + def _lazy_init(self) -> None: + if self._initialized: + return + + # Import TokenAligner lazily so module import stays cheap and so + # workers that don't need CT never touch x_token. + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + for spec in self.teacher_ct_specs: + if spec is None: + self._aligners.append(None) + self._teacher_tokenizers.append(None) + continue + + teacher_tokenizer = AutoTokenizer.from_pretrained( + spec["teacher_tokenizer_name"] + ) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + aligner = TokenAligner( + teacher_tokenizer_name=spec["teacher_tokenizer_name"], + student_tokenizer_name=spec["student_tokenizer_name"], + max_comb_len=int(spec["max_comb_len"]), + projection_matrix_multiplier=float( + spec["projection_matrix_multiplier"] + ), + ) + aligner._load_logits_projection_map( + file_path=spec["projection_matrix_path"], + use_sparse_format=bool(spec["use_sparse_format"]), + learnable=bool(spec["learnable"]), + device="cpu", + ) + if bool(spec["project_teacher_to_student"]): + aligner.create_reverse_projection_matrix(device="cpu") + aligner.precompute_canonical_maps() + + self._teacher_tokenizers.append(teacher_tokenizer) + self._aligners.append(aligner) + + self._initialized = True + + def _get_student_tokenizer(self) -> Any: + if self._student_tokenizer is not None: + return self._student_tokenizer + name = self.fallback_student_tokenizer_name + if name is None: + # Best-effort: reuse any CT spec's student name. + for spec in self.teacher_ct_specs: + if spec is not None: + name = spec["student_tokenizer_name"] + break + if name is None: + raise RuntimeError( + "CrossTokenizerCollator needs a student tokenizer for the decode " + "fallback, but no name was provided and no CT spec supplied one." + ) + self._student_tokenizer = AutoTokenizer.from_pretrained(name) + if self._student_tokenizer.pad_token is None: + self._student_tokenizer.pad_token = self._student_tokenizer.eos_token + return self._student_tokenizer + + def __call__(self, data_batch: list[DatumSpec]) -> BatchedDataDict[Any]: + self._lazy_init() + + base = rl_collate_fn(data_batch) + + # --- Message-flatten (ported from _prepare_train_batch_data) --- + for message_log in base["message_log"]: + for m in message_log: + if "token_loss_mask" not in m: + m["token_loss_mask"] = ( + torch.ones_like(m["token_ids"]) + if m["role"] == "assistant" + else torch.zeros_like(m["token_ids"]) + ) + flat_messages, input_lengths = batched_message_log_to_flat_message( + base["message_log"], + pad_value_dict={"token_ids": self.pad_token_id}, + make_sequence_length_divisible_by=self.make_seq_div_by, + ) + base["input_ids"] = flat_messages["token_ids"] + base["input_lengths"] = input_lengths + base["token_mask"] = flat_messages["token_loss_mask"] + base["sample_mask"] = base["loss_multiplier"] + base["flat_messages"] = flat_messages + mm_dict = flat_messages.get_multimodal_dict(as_tensors=False) + if mm_dict: + for k, v in mm_dict.items(): + base[k] = v + + # --- Per-teacher CT (DP-only) --- + student_ids = base["input_ids"] + extra_env = base.get("extra_env_info") + batch_size = student_ids.shape[0] + + has_raw_text = ( + extra_env is not None + and len(extra_env) == batch_size + and all( + isinstance(e, dict) and "raw_text" in e for e in extra_env + ) + ) + texts_cache: Optional[list[str]] = None + + per_teacher_ct_data: list[Optional[dict[str, Any]]] = [] + any_ct = any(spec is not None for spec in self.teacher_ct_specs) + if any_ct: + if has_raw_text: + texts_cache = [e["raw_text"] for e in extra_env] + else: + texts_cache = self._get_student_tokenizer().batch_decode( + student_ids.tolist(), skip_special_tokens=True + ) + + for t_idx, spec in enumerate(self.teacher_ct_specs): + if spec is None: + per_teacher_ct_data.append(None) + continue + + aligner = self._aligners[t_idx] + teacher_tokenizer = self._teacher_tokenizers[t_idx] + + enc = teacher_tokenizer( + texts_cache, + max_length=int(spec["max_teacher_len"]), + padding="max_length", + truncation=True, + return_tensors="pt", + ) + teacher_input_ids = enc["input_ids"] + teacher_attention_mask = enc["attention_mask"] + teacher_input_lengths = teacher_attention_mask.sum(dim=1) + teacher_token_mask = (teacher_attention_mask > 0).to(torch.float32) + + dp_chunk_size = int(spec["dp_chunk_size"]) + use_align_fast = bool(spec["use_align_fast"]) + exact_match_only = bool(spec.get("exact_token_match_only", False)) + student_seq_len = int(student_ids.shape[1]) + teacher_seq_len = int(teacher_input_ids.shape[1]) + + aligned_pairs: list[Any] = [] + student_chunk_coo: list[torch.Tensor] = [] + teacher_chunk_coo: list[torch.Tensor] = [] + num_chunks_per_sample: list[int] = [] + for b in range(batch_size): + s_t = student_ids[b : b + 1] + t_t = teacher_input_ids[b : b + 1] + if use_align_fast and aligner._student_canon_map is not None: + result = aligner.align_fast( + s_t, t_t, chunk_size=dp_chunk_size + ) + else: + result = aligner.align(s_t, t_t, chunk_size=dp_chunk_size) + pairs = result[0] + aligned_pairs.append(pairs) + + s_coo, t_coo, n_chunks = _build_chunk_coo( + pairs, student_seq_len, teacher_seq_len, exact_match_only, + ) + student_chunk_coo.append(s_coo) + teacher_chunk_coo.append(t_coo) + num_chunks_per_sample.append(n_chunks) + + teacher_data: BatchedDataDict[Any] = BatchedDataDict( + { + "input_ids": teacher_input_ids, + "input_lengths": teacher_input_lengths, + "token_mask": teacher_token_mask, + "sample_mask": base["loss_multiplier"], + } + ) + teacher_data.to("cpu") + + per_teacher_ct_data.append( + { + "teacher_input_ids": teacher_input_ids, + "aligned_pairs": aligned_pairs, + "teacher_data": teacher_data, + "student_chunk_coo": student_chunk_coo, + "teacher_chunk_coo": teacher_chunk_coo, + "num_chunks": num_chunks_per_sample, + } + ) + + base["per_teacher_ct_data"] = per_teacher_ct_data + return base diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py index d813ed040c..f0b5d15a81 100644 --- a/nemo_rl/data/datasets/eval_datasets/__init__.py +++ b/nemo_rl/data/datasets/eval_datasets/__init__.py @@ -14,8 +14,10 @@ from nemo_rl.data.datasets.eval_datasets.aime import AIMEDataset from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset +from nemo_rl.data.datasets.eval_datasets.humaneval_plus import HumanEvalPlusDataset from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset from nemo_rl.data.datasets.eval_datasets.math import MathDataset +from nemo_rl.data.datasets.eval_datasets.mbpp_plus import MBPPPlusDataset from nemo_rl.data.datasets.eval_datasets.mmau import MMAUDataset from nemo_rl.data.datasets.eval_datasets.mmlu import MMLUDataset from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset @@ -35,10 +37,12 @@ def load_eval_dataset(data_config): # mmlu if dataset_name.startswith("mmlu") and dataset_name != "mmlu_pro": + num_few_shot = data_config.get("num_few_shot", 0) if dataset_name == "mmlu": base_dataset = MMLUDataset( prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) else: language = dataset_name.split("_")[1] @@ -46,6 +50,7 @@ def load_eval_dataset(data_config): language=language, prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) elif dataset_name == "mmlu_pro": base_dataset = MMLUProDataset( @@ -98,6 +103,21 @@ def load_eval_dataset(data_config): dataset_name="TwinkStart/MMAU", split=split, ) + # code-execution benchmarks + elif dataset_name == "mbpp_plus": + base_dataset = MBPPPlusDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + dataset_path=data_config.get("dataset_path") or "evalplus/mbppplus", + split=data_config.get("split") or "test", + ) + elif dataset_name in ("humaneval_plus", "human_eval_plus"): + base_dataset = HumanEvalPlusDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + dataset_path=data_config.get("dataset_path") or "evalplus/humanevalplus", + split=data_config.get("split") or "test", + ) # fall back to local dataset else: print(f"Loading dataset from {dataset_name}...") @@ -117,8 +137,10 @@ def load_eval_dataset(data_config): __all__ = [ "AIMEDataset", "GPQADataset", + "HumanEvalPlusDataset", "LocalMathDataset", "MathDataset", + "MBPPPlusDataset", "MMAUDataset", "MMLUDataset", "MMLUProDataset", diff --git a/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py b/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py new file mode 100644 index 0000000000..8efa4f968b --- /dev/null +++ b/nemo_rl/data/datasets/eval_datasets/humaneval_plus.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HumanEval+ dataset for code-execution eval. + +Loads the upstream `evalplus/humanevalplus` HuggingFace dataset and shapes each +example into the columns expected by +:func:`nemo_rl.data.processors.code_data_processor` and graded by +:class:`nemo_rl.environments.code_environment.CodeUnitTestEnvironment`. +""" + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class HumanEvalPlusDataset: + """Code-execution eval over the HumanEval+ test split. + + The ``env_name`` attribute lets the runner pick the right environment + (``code_unit_test`` instead of ``math``/``multichoice``). + """ + + env_name: str = "code_unit_test" + + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + dataset_path: str = "evalplus/humanevalplus", + split: str = "test", + ): + ds = load_dataset(dataset_path, split=split) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name="humaneval_plus", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.code_data_processor + + @staticmethod + def _rekey(data: dict[str, Any]) -> dict[str, Any]: + prompt = data.get("prompt") or "" + test = data.get("test") or "" + entry_point = data.get("entry_point") or "" + return { + "problem": str(prompt), + # HumanEval+ ships the full check() harness in `test`; the grader + # will exec it and call check(). + "test_code": str(test), + "test_list": [], + "test_imports": [], + "entry_point": str(entry_point), + # HumanEval+ uses a continuation-style prompt: the model is + # expected to emit only the *body* of `entry_point`. The grader + # must prepend this stub before exec'ing so that `entry_point` + # actually gets defined in the namespace. + "code_prefix": str(prompt), + } diff --git a/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py b/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py new file mode 100644 index 0000000000..2d0492e4aa --- /dev/null +++ b/nemo_rl/data/datasets/eval_datasets/mbpp_plus.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MBPP+ (mostly-basic Python problems, plus) dataset for code-execution eval. + +Loads the upstream `evalplus/mbppplus` HuggingFace dataset and shapes each +example into the columns expected by +:func:`nemo_rl.data.processors.code_data_processor` and graded by +:class:`nemo_rl.environments.code_environment.CodeUnitTestEnvironment`. +""" + +import re +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +_FUNCTION_DEF_RE = re.compile(r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(") + + +def _infer_entry_point(reference_code: str, test_list: list[str]) -> str: + """Best-effort inference of the candidate function name.""" + if reference_code: + match = _FUNCTION_DEF_RE.search(reference_code) + if match: + return match.group(1) + for assertion in test_list: + match = re.search(r"assert\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", str(assertion)) + if match: + return match.group(1) + return "" + + +class MBPPPlusDataset: + """Code-execution eval over the MBPP+ test split. + + The ``env_name`` attribute lets the runner pick the right environment + (``code_unit_test`` instead of ``math``/``multichoice``). + """ + + env_name: str = "code_unit_test" + + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + dataset_path: str = "evalplus/mbppplus", + split: str = "test", + ): + ds = load_dataset(dataset_path, split=split) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name="mbpp_plus", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.code_data_processor + + @staticmethod + def _rekey(data: dict[str, Any]) -> dict[str, Any]: + prompt = data.get("prompt") or data.get("text") or "" + test_list = data.get("test_list") or [] + if not isinstance(test_list, list): + test_list = [test_list] + test_imports = data.get("test_imports") or [] + if not isinstance(test_imports, list): + test_imports = [test_imports] + # MBPP+ does not ship an explicit entry_point; infer it so that + # HumanEval-style ``check(candidate)`` harnesses keep working when + # the same processor is reused. + entry_point = _infer_entry_point( + str(data.get("code", "")), [str(x) for x in test_list] + ) + return { + "problem": str(prompt), + # MBPP+'s ``test`` is the canonical-solution test script; the + # grader uses ``test_list`` (assertions) instead. + "test_code": "", + "test_list": [str(x) for x in test_list], + "test_imports": [str(x) for x in test_imports], + "entry_point": entry_point, + } diff --git a/nemo_rl/data/datasets/eval_datasets/mmlu.py b/nemo_rl/data/datasets/eval_datasets/mmlu.py index c9a373fc10..12953166ec 100644 --- a/nemo_rl/data/datasets/eval_datasets/mmlu.py +++ b/nemo_rl/data/datasets/eval_datasets/mmlu.py @@ -14,6 +14,7 @@ """MMLU dataset and its variants.""" +from collections import defaultdict from typing import Any, Literal, Optional from datasets import load_dataset @@ -21,6 +22,8 @@ from nemo_rl.data import processors from nemo_rl.data.interfaces import TaskDataSpec +ANSWER_INDEX_TO_LETTER = {0: "A", 1: "B", 2: "C", 3: "D"} + class MMLUDataset: def __init__( @@ -44,6 +47,7 @@ def __init__( ] = "EN-US", prompt_file: Optional[str] = None, system_prompt_file: Optional[str] = None, + num_few_shot: int = 0, ): if language != "EN-US": data_files = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" @@ -58,6 +62,14 @@ def __init__( ) self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + if num_few_shot > 0: + few_shot_prefixes = self._build_few_shot_prefixes(num_few_shot) + self.rekeyed_ds = self.rekeyed_ds.map( + lambda ex: { + "few_shot_prefix": few_shot_prefixes.get(ex["subject"], "") + } + ) + self.task_spec = TaskDataSpec( task_name=f"MMLU_{language}", prompt_file=prompt_file, @@ -65,6 +77,33 @@ def __init__( ) self.processor = processors.multichoice_qa_processor + @staticmethod + def _build_few_shot_prefixes(num_few_shot: int) -> dict[str, str]: + """Build per-subject few-shot prefixes from MMLU's dev (validation) split.""" + dev_ds = load_dataset("cais/mmlu", "all", split="validation") + + dev_by_subject: dict[str, list[dict[str, Any]]] = defaultdict(list) + for ex in dev_ds: + dev_by_subject[ex["subject"]].append(ex) + + prefixes: dict[str, str] = {} + for subject, examples in dev_by_subject.items(): + parts = [] + for fs_ex in examples[:num_few_shot]: + choices = fs_ex["choices"] + options_str = "\n".join( + f"{letter}) {choices[i]}" + for i, letter in enumerate(["A", "B", "C", "D"]) + ) + answer_letter = ANSWER_INDEX_TO_LETTER[fs_ex["answer"]] + parts.append( + f"Question: {fs_ex['question']}\nOptions:\n{options_str}\n" + f"Answer: {answer_letter}" + ) + prefixes[subject] = "\n\n".join(parts) + + return prefixes + def _rekey(self, data: dict[str, Any]): return { "question": data["Question"], diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index a8928a16a9..bfb98e350c 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -14,6 +14,7 @@ from nemo_rl.data import ResponseDatasetConfig from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset +from nemo_rl.data.datasets.response_datasets.arrow_text_dataset import ArrowTextDataset from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset @@ -48,6 +49,7 @@ # built-in datasets "avqa": AVQADataset, "AIME2024": AIME2024Dataset, + "arrow_text": ArrowTextDataset, "clevr-cogent": CLEVRCoGenTDataset, "daily-omni": DailyOmniDataset, "general-conversation-jsonl": GeneralConversationsJsonlDataset, @@ -98,6 +100,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): __all__ = [ "AVQADataset", "AIME2024Dataset", + "ArrowTextDataset", "CLEVRCoGenTDataset", "DailyOmniDataset", "GeneralConversationsJsonlDataset", diff --git a/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py new file mode 100644 index 0000000000..e7f7d89145 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py @@ -0,0 +1,308 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Arrow Text Dataset for loading arrow files with 'text' column.""" + +import glob +import hashlib +import json +import os +import time +from typing import Any, Optional + +import numpy as np +from datasets import Dataset, load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +class _LazyPackedDataset: + """Map-style dataset that packs text on demand via precomputed boundaries. + + Each ``__getitem__`` call loads only the raw rows needed for one packed + sample, joins them with newlines, and returns the messages dict. The + underlying arrow dataset stays memory-mapped — no bulk text loading. + """ + + def __init__( + self, + arrow_dataset: Dataset, + pack_ranges: list[tuple[int, int]], + text_key: str = "text", + ): + self.arrow_dataset = arrow_dataset + self.pack_ranges = pack_ranges + self.text_key = text_key + + def __len__(self) -> int: + return len(self.pack_ranges) + + def __getitem__(self, idx: int) -> dict[str, Any]: + start, end = self.pack_ranges[idx] + texts = self.arrow_dataset[start:end][self.text_key] + packed = "\n".join(t for t in texts if isinstance(t, str) and t) + return { + "messages": [{"role": "assistant", "content": packed}], + "task_name": "arrow_text_dataset", + } + + +class ArrowTextDataset(RawDataset): + """Dataset class for loading arrow files containing raw text. + + This class loads arrow files with a 'text' column and converts them to + the messages format expected by SFT training. + + The text is wrapped as an assistant message: + {"messages": [{"role": "assistant", "content": }]} + + This format allows training on all tokens (language modeling style). + + Optionally, multiple short samples can be concatenated (packed) into a + single sample up to ``characters_per_sample`` characters, separated by + newlines. This avoids padding waste when individual samples are short + relative to the model's context window. Works correctly with + cross-tokenizer distillation because packing happens at the raw-text + level before any tokenizer sees the data. + + Packing uses **lazy loading**: only character lengths are scanned at + init time (fast) to build pack boundaries; actual text is loaded + on-demand in ``__getitem__``. + + Args: + arrow_files: Path pattern (glob) or list of arrow file paths + val_split: Fraction of data to use for validation (default: 0.05) + seed: Random seed for train/val split + text_key: Key for text column in arrow files (default: "text") + characters_per_sample: If set, concatenate multiple texts (separated + by "\\n") until the accumulated character count reaches this + threshold before yielding one packed sample. To guarantee that + every packed sample tokenizes to at least ``max_input_seq_length`` + tokens (so truncation always fires and every training sequence is + exactly the context length), use ``max_input_seq_length * 8`` + (≈ 8 chars per token, matching tokenalign_upstream's default + ``characters_multiplier``). Using a smaller multiplier (e.g. 4) + may leave some samples shorter than the context window for dense + text such as code or CJK. Set to None or 0 to disable packing. + + Example config: + data: + dataset_name: "arrow_text" + arrow_files: "/path/to/data/*.arrow" + val_split: 0.05 + max_input_seq_length: 4096 + characters_per_sample: 32768 # 4096 tokens * 8 chars/token (guarantees full context) + """ + + def __init__( + self, + arrow_files: str | list[str], + val_split: float = 0.05, + seed: int = 42, + text_key: str = "text", + characters_per_sample: Optional[int] = None, + pack_cache_dir: Optional[str] = None, + **kwargs, + ): + # Don't call super().__init__() since RawDataset raises NotImplementedError + self.seed = seed + self.text_key = text_key + self.task_name = "arrow_text_dataset" + + # Resolve glob pattern if string + if isinstance(arrow_files, str): + file_list = glob.glob(arrow_files) + if not file_list: + raise ValueError(f"No arrow files found matching pattern: {arrow_files}") + else: + file_list = arrow_files + + print(f"Loading {len(file_list)} arrow files...") + dataset = load_dataset("arrow", data_files=file_list, split="train") + original_count = len(dataset) + print(f" ✓ Loaded {original_count} total samples") + + # Verify text column exists + if self.text_key not in dataset.column_names: + raise ValueError( + f"Column '{self.text_key}' not found in arrow files. " + f"Available columns: {dataset.column_names}" + ) + + if characters_per_sample is not None and characters_per_sample > 0: + # Lazy packing: scan character lengths to build pack boundaries, + # then load+concatenate text on demand in __getitem__. + # + # The scan is deterministic in (file set, text_key, chars/sample) + # and dominates dataset init time on large arrow corpora + # (~6 minutes for 70M rows). When ``pack_cache_dir`` is set we + # fingerprint the inputs and store the resulting boundaries as + # an ``int64[N, 2]`` ``.npy`` file so subsequent runs skip the + # scan entirely. + pack_ranges = _load_or_build_pack_ranges( + dataset=dataset, + file_list=file_list, + text_key=text_key, + characters_per_sample=characters_per_sample, + pack_cache_dir=pack_cache_dir, + ) + + # Split pack_ranges into train/val + if val_split > 0: + import random + rng = random.Random(seed) + indices = list(range(len(pack_ranges))) + rng.shuffle(indices) + val_count = max(1, int(len(pack_ranges) * val_split)) + val_indices = sorted(indices[:val_count]) + train_indices = sorted(indices[val_count:]) + train_ranges = [pack_ranges[i] for i in train_indices] + val_ranges = [pack_ranges[i] for i in val_indices] + else: + train_ranges = pack_ranges + val_ranges = pack_ranges[:min(100, len(pack_ranges))] + + train_dataset = _LazyPackedDataset(dataset, train_ranges, text_key) + val_dataset = _LazyPackedDataset(dataset, val_ranges, text_key) + else: + # No packing: convert text to messages format directly + def text_to_messages(example: dict[str, Any]) -> dict[str, Any]: + text = example[self.text_key] + return { + "messages": [{"role": "assistant", "content": text}], + "task_name": "arrow_text_dataset", + } + + formatted_dataset = dataset.map(text_to_messages, remove_columns=dataset.column_names) + + if val_split > 0: + split = formatted_dataset.train_test_split(test_size=val_split, seed=seed) + train_dataset = split["train"] + val_dataset = split["test"] + else: + train_dataset = formatted_dataset + val_dataset = formatted_dataset.select(range(min(100, len(formatted_dataset)))) + + print(f" ✓ Train: {len(train_dataset)}, Validation: {len(val_dataset)}") + + self.dataset = train_dataset + self.val_dataset = val_dataset + + +def _pack_cache_key( + file_list: list[str], text_key: str, characters_per_sample: int +) -> str: + """Fingerprint the inputs that determine the pack boundary list. + + Uses (sorted resolved path, size, mtime-as-int) per arrow file so the + cache invalidates automatically when any shard is rewritten or when the + file set changes. ``text_key`` and ``characters_per_sample`` are part of + the key because changing either yields different boundaries. + """ + parts = [] + for path in sorted(file_list): + st = os.stat(path) + parts.append((path, st.st_size, int(st.st_mtime))) + blob = json.dumps( + { + "files": parts, + "text_key": text_key, + "chars_per_sample": int(characters_per_sample), + }, + sort_keys=True, + ) + return hashlib.sha1(blob.encode()).hexdigest()[:16] + + +def _load_or_build_pack_ranges( + dataset: Dataset, + file_list: list[str], + text_key: str, + characters_per_sample: int, + pack_cache_dir: Optional[str], +) -> list[tuple[int, int]]: + """Scan + cache pack boundaries, or load from disk if a cache hit exists.""" + cache_path: Optional[str] = None + if pack_cache_dir: + os.makedirs(pack_cache_dir, exist_ok=True) + key = _pack_cache_key(file_list, text_key, characters_per_sample) + cache_path = os.path.join(pack_cache_dir, f"{key}.npy") + if os.path.exists(cache_path): + print(f" ↪ Loading cached pack boundaries from {cache_path}") + arr = np.load(cache_path) + pack_ranges = [(int(s), int(e)) for s, e in arr] + print(f" ✓ Loaded {len(pack_ranges)} pack boundaries (cache hit)") + return pack_ranges + + print( + f" Scanning character lengths for lazy packing (target ~{characters_per_sample} chars)..." + ) + t0 = time.time() + pack_ranges = _build_pack_ranges(dataset, text_key, characters_per_sample) + print( + f" ✓ Built {len(pack_ranges)} pack boundaries from {len(dataset)} samples in {time.time() - t0:.1f}s" + ) + + if cache_path is not None: + # Atomic write so concurrent jobs don't corrupt the cache file. + # Pass a file handle to np.save (not a path) — np.save auto-appends + # ".npy" when given a path that doesn't already end in .npy, which + # breaks the os.replace(tmp -> final) rename below. + tmp_path = f"{cache_path}.tmp.{os.getpid()}" + arr = np.asarray(pack_ranges, dtype=np.int64) + with open(tmp_path, "wb") as f: + np.save(f, arr) + os.replace(tmp_path, cache_path) + print(f" ↳ Cached pack boundaries to {cache_path}") + + return pack_ranges + + +def _build_pack_ranges( + dataset: Dataset, + text_key: str, + characters_per_sample: int, + scan_batch_size: int = 10000, +) -> list[tuple[int, int]]: + """Scan character lengths in batches and build pack boundary indices. + + Returns a list of (start_row, end_row) tuples. Each packed sample + covers rows [start_row, end_row) whose combined character length + (plus newline separators) meets or exceeds ``characters_per_sample``. + """ + n = len(dataset) + pack_ranges: list[tuple[int, int]] = [] + start = 0 + accum = 0 + + for batch_start in range(0, n, scan_batch_size): + batch_end = min(batch_start + scan_batch_size, n) + batch_texts = dataset[batch_start:batch_end][text_key] + for i, text in enumerate(batch_texts): + row_idx = batch_start + i + text_len = len(text) if isinstance(text, str) and text else 0 + if text_len == 0: + continue + # +1 for newline separator (except the first text in a pack) + accum += text_len + (1 if accum > 0 else 0) + if accum >= characters_per_sample: + pack_ranges.append((start, row_idx + 1)) + start = row_idx + 1 + accum = 0 + + # Flush remaining rows as a partial pack + if start < n and accum > 0: + pack_ranges.append((start, n)) + + return pack_ranges diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 091c5ad1c5..ed3ac29465 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -187,6 +187,72 @@ def sft_processor( return output +def kd_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum for knowledge-distillation training on raw text. + + Cross-tokenizer KD requires both student and teacher to see the same + raw text (no chat template, no instruction formatting), with loss + applied to every token. The raw text is preserved in + ``extra_env_info`` so the teacher worker can retokenize it with its + own tokenizer. + + Args: + datum_dict: Sample dict with a ``messages`` list whose ``content`` + fields are concatenated with newlines into the raw text. + task_data_spec: Task spec (unused; present for protocol parity). + tokenizer: Student tokenizer used for the student-side token ids. + max_seq_length: Truncation cap; samples longer than this are + zero-weighted via ``loss_multiplier``. + idx: Index of the sample within the dataset. + + Returns: + DatumSpec with a single assistant-role ``message_log`` entry, + ``token_loss_mask`` set to all-ones (loss on every token), and + ``extra_env_info["raw_text"]`` carrying the joined text. + """ + raw_text = "\n".join( + msg["content"] + for msg in datum_dict["messages"] + if isinstance(msg.get("content"), str) + ) + + token_ids = tokenizer( + raw_text, + return_tensors="pt", + add_special_tokens=True, + max_length=max_seq_length, + truncation=True, + )["input_ids"][0] + + length = len(token_ids) + loss_multiplier = 1.0 + if length > max_seq_length: + loss_multiplier = 0.0 + + message_log = [ + { + "role": "assistant", + "content": raw_text, + "token_ids": token_ids, + "token_loss_mask": torch.ones_like(token_ids), + } + ] + + return { + "message_log": message_log, + "length": length, + "extra_env_info": {"raw_text": raw_text}, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + def preference_preprocessor( datum_dict: dict[str, Any], task_data_spec: TaskDataSpec, @@ -718,6 +784,7 @@ def nemo_gym_data_processor( { "default": math_hf_data_processor, "helpsteer3_data_processor": helpsteer3_data_processor, + "kd_data_processor": kd_data_processor, "math_data_processor": math_data_processor, "math_hf_data_processor": math_hf_data_processor, "multichoice_qa_processor": multichoice_qa_processor, diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 2819e27582..b53adc6522 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -134,7 +134,10 @@ def setup_response_data( } else: # merge datasets into a single dataset - merged_data = concatenate_datasets([data.dataset for data in data_list]) + if len(data_list) == 1: + merged_data = data_list[0].dataset + else: + merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( merged_data, tokenizer, @@ -199,7 +202,10 @@ def setup_response_data( # merge datasets val_dataset = None if len(val_data_list) > 0: - merged_val_data = concatenate_datasets(val_data_list) + if len(val_data_list) == 1: + merged_val_data = val_data_list[0] + else: + merged_val_data = concatenate_datasets(val_data_list) val_dataset = AllTaskProcessedDataset( merged_val_data, tokenizer, diff --git a/nemo_rl/distributed/ipc_utils.py b/nemo_rl/distributed/ipc_utils.py new file mode 100644 index 0000000000..46e9b206c3 --- /dev/null +++ b/nemo_rl/distributed/ipc_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""IPC helpers for sharing CUDA tensors across processes. + +Used by cross-tokenizer off-policy distillation to ship teacher logits +from the teacher policy worker to the student worker without an extra +host roundtrip. +""" + +from typing import Any + +import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor + + +def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]: + """Get IPC handle from a tensor.""" + from torch.multiprocessing.reductions import reduce_tensor + + # Skip serializing the function for better refit performance. + return reduce_tensor(tensor.detach())[1:] + + +def rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, device_id: int +) -> torch.Tensor: + """Rebuild a CUDA tensor from an IPC handle on ``device_id``.""" + args = cuda_ipc_handle[0] + list_args = list(args) + list_args[6] = device_id + return rebuild_cuda_tensor(*list_args) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index d2b5979400..f885a8611a 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -45,7 +45,12 @@ from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.ipc_utils import ( + get_handle_from_tensor, + rebuild_cuda_tensor_from_ipc, +) from nemo_rl.distributed.model_utils import ( + _compute_distributed_log_softmax, allgather_cp_sharded_tensor, distributed_vocab_topk, get_logprobs_from_vocab_parallel_logits, @@ -59,6 +64,8 @@ "LogprobsPostProcessor", "TopkLogitsPostProcessor", "ScorePostProcessor", + "XTokenTeacherIPCLossPostProcessor", + "XTokenTeacherIPCExportPostProcessor", ] @@ -588,6 +595,307 @@ def __call__( return loss, loss_metrics + + +class XTokenTeacherIPCLossPostProcessor(LossPostProcessor): + """Loss post-processor that injects teacher logits via CUDA IPC handles.""" + + def __init__( + self, *args: Any, teacher_result: Optional[dict[str, Any]] = None, **kwargs: Any + ): + super().__init__(*args, **kwargs) + self._teacher_result = teacher_result + self._microbatch_idx = 0 + + def set_microbatch_index(self, mb_idx: int) -> None: + self._microbatch_idx = mb_idx + + def __call__( + self, + logits: torch.Tensor, + data_dict: BatchedDataDict[Any], + processed_inputs: ProcessedInputs, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + sequence_dim: int = 1, + ) -> tuple[torch.Tensor, dict[str, Any]]: + if self.cp_size > 1: + _, data_dict = prepare_data_for_cp( + data_dict, processed_inputs, self.cp_mesh, sequence_dim + ) + logits = redistribute_logits_for_cp( + logits, self.device_mesh, self.cp_mesh, sequence_dim + ) + + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=self.loss_fn, + cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = self.loss_fn + + def _extract_teacher_ipc_payload(teacher_result_obj: dict[str, Any]) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if "microbatch_handles" not in teacher_result_obj: + return None, None + handles = teacher_result_obj["microbatch_handles"] + if self._microbatch_idx >= len(handles): + return None, None + rank = torch.distributed.get_rank() + current_device_id = torch.cuda.current_device() + handle = handles[self._microbatch_idx] + aB, aS, aK = handle["actual_shape"] + + # View into the teacher's pre-allocated IPC buffer — no clone. + # The teacher's per-microbatch buffer slot is only reused on the + # next step's teacher forward, which is strictly ordered AFTER + # this student forward+backward by the driver-side ray.get + # sequencing. Teacher tensor is detached and doesn't require + # grad, so autograd never captures it past the loss op. Avoids a + # full-vocab fp32-or-bf16 memcpy per microbatch (~3.3 GiB at + # mbs=2, V≈200K, bf16 — was the dominant peak after fix A). + teacher_logits_tensor = rebuild_cuda_tensor_from_ipc( + handle[rank], current_device_id + ).detach()[:aB, :aS, :aK] + + teacher_topk_indices_tensor = None + is_topk = teacher_result_obj.get("is_topk", False) + if is_topk and "topk_indices_ipc" in handle: + teacher_topk_indices_tensor = rebuild_cuda_tensor_from_ipc( + handle["topk_indices_ipc"], current_device_id + ).detach()[:aB, :aS, :aK] + return teacher_logits_tensor, teacher_topk_indices_tensor + + loss_kwargs = {} + if self._teacher_result is not None and not self.enable_seq_packing: + if isinstance(self._teacher_result, list): + teacher_logits_list: list[torch.Tensor] = [] + teacher_topk_indices_list: list[Optional[torch.Tensor]] = [] + for teacher_result_obj in self._teacher_result: + t_logits, t_topk = _extract_teacher_ipc_payload(teacher_result_obj) + if t_logits is None: + continue + teacher_logits_list.append(t_logits) + teacher_topk_indices_list.append(t_topk) + if teacher_logits_list: + loss_kwargs["teacher_logits_list"] = teacher_logits_list + loss_kwargs["teacher_topk_indices_list"] = teacher_topk_indices_list + else: + t_logits, t_topk = _extract_teacher_ipc_payload(self._teacher_result) + if t_logits is not None: + loss_kwargs["teacher_logits"] = t_logits + if t_topk is not None: + loss_kwargs["teacher_topk_indices_ipc"] = t_topk + + loss, loss_metrics = loss_fn_( + logits, + data_dict, + global_valid_seqs, + global_valid_toks, + **loss_kwargs, + ) + return loss, loss_metrics + + +class XTokenTeacherIPCExportPostProcessor(LossPostProcessor): + """Teacher-side post-processor that exports per-microbatch logits via CUDA IPC.""" + + def __init__( + self, + *args: Any, + tp_mesh: Any, + topk_logits: Optional[int], + is_mdlm: bool = False, + **kwargs: Any, + ): + # Keep explicit tp_mesh for local use and also forward it to base init. + super().__init__(*args, tp_mesh=tp_mesh, **kwargs) + self.tp_mesh = tp_mesh + self.topk_logits = topk_logits + self.is_mdlm = is_mdlm + self.microbatch_handles: list[dict[str, Any]] = [] + self._microbatch_idx = 0 + self._mb_vals_buffers: list[torch.Tensor] = [] + self._mb_vals_ipcs: list[tuple[Any]] = [] + self._mb_idx_buffers: list[torch.Tensor] = [] + self._mb_idx_ipcs: list[tuple[Any]] = [] + self._mb_logits_buffers: list[torch.Tensor] = [] + self._mb_logits_ipcs: list[tuple[Any]] = [] + + def set_microbatch_index(self, mb_idx: int) -> None: + self._microbatch_idx = mb_idx + + def _ensure_topk_buffer( + self, + buf_idx: int, + B: int, + S: int, + K: int, + vals_dtype: torch.dtype, + idx_dtype: torch.dtype, + device: torch.device, + ) -> None: + while len(self._mb_vals_buffers) <= buf_idx: + vals_buf = torch.empty((B, S, K), dtype=vals_dtype, device=device) + idx_buf = torch.empty((B, S, K), dtype=idx_dtype, device=device) + self._mb_vals_buffers.append(vals_buf) + self._mb_vals_ipcs.append(get_handle_from_tensor(vals_buf)) + self._mb_idx_buffers.append(idx_buf) + self._mb_idx_ipcs.append(get_handle_from_tensor(idx_buf)) + vals_buf = self._mb_vals_buffers[buf_idx] + idx_buf = self._mb_idx_buffers[buf_idx] + needs_realloc = ( + vals_buf.shape[0] < B + or vals_buf.shape[1] < S + or vals_buf.shape[2] < K + or vals_buf.dtype != vals_dtype + or vals_buf.device != device + or idx_buf.shape[0] < B + or idx_buf.shape[1] < S + or idx_buf.shape[2] < K + or idx_buf.dtype != idx_dtype + or idx_buf.device != device + ) + if needs_realloc: + vals_buf = torch.empty((B, S, K), dtype=vals_dtype, device=device) + idx_buf = torch.empty((B, S, K), dtype=idx_dtype, device=device) + self._mb_vals_buffers[buf_idx] = vals_buf + self._mb_vals_ipcs[buf_idx] = get_handle_from_tensor(vals_buf) + self._mb_idx_buffers[buf_idx] = idx_buf + self._mb_idx_ipcs[buf_idx] = get_handle_from_tensor(idx_buf) + + def _ensure_logits_buffer( + self, + buf_idx: int, + B: int, + S: int, + V: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + while len(self._mb_logits_buffers) <= buf_idx: + buf = torch.empty((B, S, V), dtype=dtype, device=device) + self._mb_logits_buffers.append(buf) + self._mb_logits_ipcs.append(get_handle_from_tensor(buf)) + buf = self._mb_logits_buffers[buf_idx] + needs_realloc = ( + buf.shape[0] < B + or buf.shape[1] < S + or buf.shape[2] < V + or buf.dtype != dtype + or buf.device != device + ) + if needs_realloc: + buf = torch.empty((B, S, V), dtype=dtype, device=device) + self._mb_logits_buffers[buf_idx] = buf + self._mb_logits_ipcs[buf_idx] = get_handle_from_tensor(buf) + + def __call__( + self, + logits: torch.Tensor, + data_dict: BatchedDataDict[Any], # noqa: ARG002 + processed_inputs: ProcessedInputs, # noqa: ARG002 + global_valid_seqs: torch.Tensor, # noqa: ARG002 + global_valid_toks: torch.Tensor, # noqa: ARG002 + sequence_dim: int = 1, # noqa: ARG002 + ) -> tuple[torch.Tensor, dict[str, Any]]: + if isinstance(logits, DTensor): + mb_logits_local = logits.to_local() + else: + mb_logits_local = logits + + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(mb_logits_local.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + if self.topk_logits is not None: + # Top-k path (same-tokenizer distillation): the student loss + # consumes IPC payload entries as log-probs (it `.exp()`s them + # to recover probs), so do the fp32 cast + distributed + # log-softmax here before extracting the top-k. + mb_logits_fp32 = mb_logits_local.to(torch.float32) + mb_log_prob = _compute_distributed_log_softmax( + mb_logits_fp32, group=tp_group + ) + del mb_logits_fp32, mb_logits_local + + if isinstance(mb_log_prob, DTensor): + mb_log_prob = mb_log_prob.to_local() + + if self.is_mdlm: + shared_seq_len = int(mb_log_prob.shape[1] / 2) + mb_log_prob = mb_log_prob[:, shared_seq_len:, :] + + mb_topk_vals, mb_topk_idx = distributed_vocab_topk( + mb_log_prob, + k=self.topk_logits, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del mb_log_prob + + B_mb, S_mb, K_mb = mb_topk_vals.shape + buf_idx = self._microbatch_idx + self._ensure_topk_buffer( + buf_idx, + B_mb, + S_mb, + K_mb, + mb_topk_vals.dtype, + mb_topk_idx.dtype, + mb_topk_vals.device, + ) + self._mb_vals_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_vals) + self._mb_idx_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_idx) + del mb_topk_vals, mb_topk_idx + + rank = torch.distributed.get_rank() + handle = { + rank: self._mb_vals_ipcs[buf_idx], + "actual_shape": (B_mb, S_mb, K_mb), + "topk_indices_ipc": self._mb_idx_ipcs[buf_idx], + } + else: + # Full-vocab path (cross-tokenizer distillation): ship raw logits + # in their native dtype (typically bf16). The student-side loss + # applies log_softmax exactly once, on raw logits, matching the + # reference train_distillation_ddp.py. This avoids the teacher- + # side fp32 cast + full-vocab log_softmax kernel and removes the + # redundant student-side log_softmax of already-log-softmaxed + # values (idempotent at T=1, mathematically wrong at T!=1). + if self.is_mdlm: + shared_seq_len = int(mb_logits_local.shape[1] / 2) + mb_logits_local = mb_logits_local[:, shared_seq_len:, :] + + B_mb, S_mb, V_mb = mb_logits_local.shape + buf_idx = self._microbatch_idx + self._ensure_logits_buffer( + buf_idx, + B_mb, + S_mb, + V_mb, + mb_logits_local.dtype, + mb_logits_local.device, + ) + self._mb_logits_buffers[buf_idx][:B_mb, :S_mb, :V_mb].copy_(mb_logits_local) + del mb_logits_local + + rank = torch.distributed.get_rank() + handle = { + rank: self._mb_logits_ipcs[buf_idx], + "actual_shape": (B_mb, S_mb, V_mb), + } + + self.microbatch_handles.append(handle) + + dummy_loss = torch.zeros((), device="cuda", dtype=torch.float32) + return dummy_loss, {"num_valid_samples": 1.0} + + class LogprobsPostProcessor: """Post-processor for computing log probabilities from model outputs.""" diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 8acd808b11..513e4b6fae 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -582,6 +582,69 @@ def get_topk_logits( return stacked + def init_cross_tokenizer_loss_fn( + self, loss_config: Any, token_aligner_config: Any + ) -> None: + """Have each worker build its own cross-tokenizer loss function. + + Each worker materializes a ``MultiTeacherLossAggregator`` (with + ``N=1`` for the single-teacher case) from ``loss_config`` plus + the shared filesystem (projection matrices, embeddings) and + caches it as ``self._cached_loss_fn`` so the per-step + ``update_cross_tokenizer_data`` and the training forward path + can find it. + """ + futures = self.worker_group.run_all_workers_single_data( + "init_cross_tokenizer_loss_fn", + loss_config=loss_config, + token_aligner_config=token_aligner_config, + ) + ray.get(futures) + + def update_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: Any, + teacher_idx: Optional[int] = None, + chunk_indices: Optional[dict[str, Any]] = None, + ) -> None: + """Push per-step cross-tokenizer data to all workers' cached loss functions. + + Shards ``teacher_input_ids``, ``aligned_pairs``, and the optional + ``chunk_indices`` along the DP axis so each worker only receives + its own slice. ``chunk_indices`` carries the per-sample COO chunk + masks precomputed by ``CrossTokenizerCollator``. + """ + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + batch_size = teacher_input_ids.shape[0] + shard_size = batch_size // dp_size + + if chunk_indices is not None: + chunk_indices_shards: list[Optional[dict[str, Any]]] = [ + { + k: chunk_indices[k][i * shard_size : (i + 1) * shard_size] + for k in chunk_indices + } + for i in range(dp_size) + ] + else: + chunk_indices_shards = [None for _ in range(dp_size)] + + futures = self.worker_group.run_all_workers_multiple_data( + "update_cross_tokenizer_data", + teacher_input_ids=[ + teacher_input_ids[i * shard_size : (i + 1) * shard_size] + for i in range(dp_size) + ], + aligned_pairs=[ + aligned_pairs[i * shard_size : (i + 1) * shard_size] + for i in range(dp_size) + ], + teacher_idx=[teacher_idx for _ in range(dp_size)], + chunk_indices=chunk_indices_shards, + ) + ray.get(futures) + def train( self, data: BatchedDataDict[Any], @@ -729,6 +792,98 @@ def generate( return result + def train_off_policy_distillation( + self, + data: BatchedDataDict[Any], + teacher_logits: Optional[Any] = None, + loss_fn: Optional[LossFunction] = None, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + """Run cross-tokenizer off-policy distillation on the student worker group. + + Sibling of :meth:`train` that dispatches to + ``DTensorPolicyWorkerV2Impl.train_off_policy_distillation`` instead of + the regular ``train`` so the on-policy / GRPO / SFT path stays + unaffected. + """ + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + sharded_data = data.shard_by_batch_size(dp_size, batch_size=batch_size) + + futures = self.worker_group.run_all_workers_sharded_data( + "train_off_policy_distillation", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={ + "teacher_logits": teacher_logits, + "loss_fn": loss_fn, + "eval_mode": eval_mode, + "gbs": batch_size, + "mbs": micro_batch_size, + }, + ) + results = self.worker_group.get_all_worker_results(futures) + return { + "loss": results[0]["global_loss"], + "grad_norm": results[0]["grad_norm"], + **{ + k: v for k, v in results[0].items() + if k not in ("global_loss", "grad_norm") + }, + } + + def compute_teacher_logits_ipc( + self, + data: BatchedDataDict[Any], + topk_logits: Optional[int] = None, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> list[dict[str, Any]]: + """Run the teacher forward pass and return per-rank IPC handle dicts. + + The teacher is always run forward-only; the student consumes the + returned per-rank IPC handles in :meth:`train_off_policy_distillation`. + """ + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + sharded_data = data.shard_by_batch_size(dp_size, batch_size=batch_size) + + futures = self.worker_group.run_all_workers_sharded_data( + "compute_teacher_logits_ipc", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={ + "topk_logits": topk_logits, + "gbs": batch_size, + "mbs": micro_batch_size, + }, + ) + return self.worker_group.get_all_worker_results(futures) + def score( self, data: BatchedDataDict[GenerationDatumSpec] ) -> BatchedDataDict[ScoreOutputSpec]: @@ -783,6 +938,18 @@ def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: futures = self.worker_group.run_all_workers_single_data("prepare_for_training") ray.get(futures) + def move_optimizer_to_cuda(self) -> None: + """Move optimizer state to CUDA on all workers. + + Used by off-policy distillation when ``keep_models_resident=False`` to + restore the optimizer state that ``offload_after_refit`` moves to CPU + between teacher inference and student training. + """ + futures = self.worker_group.run_all_workers_single_data( + "move_optimizer_to_cuda" + ) + ray.get(futures) + def prepare_for_lp_inference(self, *args: Any, **kwargs: Any) -> None: futures = self.worker_group.run_all_workers_single_data( "prepare_for_lp_inference" diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2fa8a8e604..c20390a853 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -55,6 +55,8 @@ LossPostProcessor, ScorePostProcessor, TopkLogitsPostProcessor, + XTokenTeacherIPCExportPostProcessor, + XTokenTeacherIPCLossPostProcessor, aggregate_training_statistics, automodel_forward_backward, forward_with_post_processing_fn, @@ -246,6 +248,10 @@ def __init__( rank=0, # Temporary, will be updated after distributed init ) + self._apply_phi_inference_dtype_override( + runtime_config=runtime_config, init_optimizer=init_optimizer + ) + # Set up distributed environment (returns DistributedContext) distributed_context = setup_distributed( config=config, @@ -302,6 +308,8 @@ def __init__( self.autocast_enabled, ) = model_and_optimizer_state + self._fix_phi_rope_meta_buffers() + # Initialize reference model if requested self.reference_model_state_dict = None if init_reference_model: @@ -324,6 +332,214 @@ def __init__( _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = runtime_config + def _is_phi_style_model( + self, architectures: Optional[list[Any]] = None + ) -> bool: + """Return True for Phi-4 / Phi-3 model families. + + Shared gate for the Phi-specific patches in this worker. Accepts an + optional ``architectures`` list so callers running before + ``self.model_config`` is populated (pre-load hooks) can pass the + list straight from ``runtime_config.model_config``. + """ + model_name = str(self.cfg.get("model_name", "")).lower() + if architectures is None: + architectures = getattr( + getattr(self, "model_config", None), "architectures", [] + ) + arch_blob = " ".join( + a.lower() for a in architectures if isinstance(a, str) + ) + return ( + "phi-4" in model_name + or "phi4" in model_name + or "phi3" in arch_blob + ) + + def _apply_phi_inference_dtype_override( + self, runtime_config: Any, init_optimizer: bool + ) -> None: + """Drop FP32 master weights for inference-only Phi loads. + + Automodel's load-before-shard path materializes the full model on + every rank before FSDP sharding. ``validate_and_prepare_config`` + always pins ``model_config.torch_dtype = float32`` for FSDP + mixed-precision master weights, but FP32 master weights are only + consumed by the optimizer (AdamW state, grad-norm reductions). For + Phi-4 14B that's an extra ~28 GB on each GPU before the BF16 state + dict is even staged, which OOMs an 80 GB H100. + + Inference-only paths (``init_optimizer=False``, e.g. teacher + policies) never touch the optimizer code paths, so we can safely + load directly in the compute dtype and halve the load-time + footprint. Scoped to Phi-style models so other families keep the + existing behavior; no-op when an optimizer is requested. + """ + if init_optimizer: + return + if not self._is_phi_style_model( + architectures=getattr(runtime_config.model_config, "architectures", []) + ): + return + runtime_config.model_config.torch_dtype = runtime_config.dtype + if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: + return + print( + f"[Phi inference dtype] loading {self.cfg.get('model_name')} in " + f"{runtime_config.dtype} (no optimizer requested) instead of FP32 " + f"master weights to halve load-before-shard memory." + ) + + def _fix_phi_rope_meta_buffers(self) -> None: + """Repair Phi RoPE inv_freq buffers when loaded from meta init. + + Some Phi-4 / Phi-3 ``trust_remote_code`` revisions register + ``inv_freq`` (and ``original_inv_freq``) as buffers that get left + on the ``meta`` device after FSDP/DTensor materialization. Their + forward later calls ``.to(device)`` on those buffers and crashes + with ``NotImplementedError: Cannot copy out of meta tensor``. We + re-run each module's ``rope_init_fn`` against CUDA to produce a + healthy buffer and overwrite the meta one in place. + + Scoped to Phi-style models via the model_name / architecture + string so other model families short-circuit immediately. + """ + if not self._is_phi_style_model(): + return + + fixed_count = 0 + for module in self.model.modules(): + if not hasattr(module, "inv_freq"): + continue + inv_freq = getattr(module, "inv_freq") + original_inv_freq = getattr(module, "original_inv_freq", None) + + inv_needs_repair = torch.is_tensor(inv_freq) and ( + getattr(inv_freq, "is_meta", False) + or (not torch.isfinite(inv_freq).all()) + ) + original_needs_repair = torch.is_tensor(original_inv_freq) and ( + getattr(original_inv_freq, "is_meta", False) + or (not torch.isfinite(original_inv_freq).all()) + ) + + if not inv_needs_repair and not original_needs_repair: + continue + + if not (hasattr(module, "rope_init_fn") and hasattr(module, "config")): + continue + + try: + repaired_inv_freq, _ = module.rope_init_fn( + module.config, torch.device("cuda") + ) + except (TypeError, RuntimeError, AttributeError) as exc: + if self.rank == 0: + print( + f"[Phi RoPE repair] skipping module " + f"{type(module).__name__}: rope_init_fn raised " + f"{type(exc).__name__}: {exc}" + ) + continue + + if hasattr(module, "_buffers") and "inv_freq" in module._buffers: + module._buffers["inv_freq"] = repaired_inv_freq + else: + module.inv_freq = repaired_inv_freq + if hasattr(module, "original_inv_freq"): + module.original_inv_freq = repaired_inv_freq.detach().clone() + fixed_count += 1 + + if fixed_count > 0 and self.rank == 0: + print( + f"[Phi RoPE repair] re-materialized {fixed_count} meta " + f"buffer(s) after model setup." + ) + + def init_cross_tokenizer_loss_fn( + self, + loss_config: Any, + token_aligner_config: Any, + ) -> None: + """Build and cache a cross-tokenizer loss function on this worker. + + Always materializes a ``MultiTeacherLossAggregator`` so the rest of + the off-policy distillation path is uniform between single- and + multi-teacher setups. Accepts either: + + * ``loss_config = list[(teacher_loss_cfg, aligner_cfg|None, weight)]`` + and ``token_aligner_config = None`` (multi-teacher shape), or + * ``loss_config = dict`` (single teacher's loss cfg) and + ``token_aligner_config = dict`` (single teacher's aligner cfg). + """ + from nemo_rl.algorithms.loss.loss_functions import ( + CrossTokenizerDistillationLossFn, + MultiTeacherLossAggregator, + ) + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + if isinstance(loss_config, list) and token_aligner_config is None: + entries = loss_config + else: + assert token_aligner_config is not None, ( + "single-teacher init_cross_tokenizer_loss_fn requires " + "token_aligner_config to be provided" + ) + entries = [(loss_config, token_aligner_config, 1.0)] + + loss_fns: list[Optional[CrossTokenizerDistillationLossFn]] = [] + weights: list[float] = [] + cfg_override: Optional[dict[str, Any]] = None + for teacher_loss_cfg, aligner_cfg, teacher_weight in entries: + cfg_override = teacher_loss_cfg + weights.append(float(teacher_weight)) + if aligner_cfg is None: + loss_fns.append(None) + continue + aligner = TokenAligner( + teacher_tokenizer_name=aligner_cfg["teacher_model"], + student_tokenizer_name=aligner_cfg["student_model"], + max_comb_len=aligner_cfg.get("max_comb_len", 4), + projection_matrix_multiplier=aligner_cfg.get( + "projection_matrix_multiplier", 1.0 + ), + ) + aligner._load_logits_projection_map( + file_path=aligner_cfg["projection_matrix_path"], + use_sparse_format=aligner_cfg.get("use_sparse_format", True), + learnable=aligner_cfg.get("learnable", False), + device="cpu", + ) + if aligner_cfg.get("project_teacher_to_student", False): + aligner.create_reverse_projection_matrix(device="cpu") + loss_fns.append(CrossTokenizerDistillationLossFn(teacher_loss_cfg, aligner)) + + self._cached_loss_fn = MultiTeacherLossAggregator( + loss_fns, + weights, + normalize_by_vocab=bool( + (cfg_override or {}).get("normalize_by_vocab", False) + ), + cfg=cfg_override, + ) + + def update_cross_tokenizer_data( + self, + teacher_input_ids: Any, + aligned_pairs: Any, + teacher_idx: Optional[int] = None, + chunk_indices: Optional[dict] = None, + ) -> None: + """Update per-step cross-tokenizer data on the cached loss function.""" + cached = getattr(self, "_cached_loss_fn", None) + if cached is not None: + cached.set_cross_tokenizer_data( + teacher_input_ids=teacher_input_ids, + aligned_pairs=aligned_pairs, + teacher_idx=teacher_idx, + chunk_indices=chunk_indices, + ) + @wrap_with_nvtx_name("dtensor_policy_worker_v2/train") def train( self, @@ -507,6 +723,316 @@ def on_microbatch_start(mb_idx): return metrics + @wrap_with_nvtx_name("dtensor_policy_worker_v2/train_off_policy_distillation") + def train_off_policy_distillation( + self, + data: BatchedDataDict[Any], + teacher_logits: Optional[Any] = None, + loss_fn: Optional[LossFunction] = None, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + """Train the student with cross-tokenizer off-policy distillation. + + Self-contained sibling of ``train()`` — does not modify the shared + train path used by GRPO / SFT / on-policy distillation. + + Teacher logits arrive as CUDA-IPC handles produced by + ``compute_teacher_logits_ipc`` on the teacher worker (one entry per + teacher rank); the student rebuilds the per-rank handle locally and + feeds it to ``XTokenTeacherIPCLossPostProcessor``. The cross-tokenizer + loss function is the one cached via ``init_cross_tokenizer_loss_fn``; + per-step teacher data must be set via ``update_cross_tokenizer_data`` + before this call. + """ + if loss_fn is None: + loss_fn = getattr(self, "_cached_loss_fn", None) + assert loss_fn is not None, ( + "train_off_policy_distillation requires either an explicit loss_fn " + "or a cached one set via init_cross_tokenizer_loss_fn" + ) + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + sequence_dim, _ = check_sequence_dim(data) + if eval_mode: + ctx: AbstractContextManager[Any] = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + self.model.train() + + teacher_worker_result = None + if teacher_logits is not None: + # Both shapes pre-shard the per-rank payload before reaching here: + # - single teacher: ``list[rank]`` of per-rank dicts + # - multi-teacher: ``dict[rank, list[T_payloads]]`` (built by + # ``_group_teacher_logits_by_rank``) + # Indexing by rank yields the right per-rank entry in both cases. + rank = torch.distributed.get_rank() + teacher_worker_result = teacher_logits[rank] + + def train_context_fn(processed_inputs): + return get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ) + + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( + "clear_cache_every_n_steps" + ) + if empty_cache_steps: + warnings.warn( + f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead.", + ) + + loss_post_processor = XTokenTeacherIPCLossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + dp_size=self.dp_size, + enable_seq_packing=self.enable_seq_packing, + sampling_params=None, + teacher_result=teacher_worker_result, + ) + + def on_microbatch_start(mb_idx): + loss_post_processor.set_microbatch_index(mb_idx) + if empty_cache_steps and mb_idx % empty_cache_steps == 0: + torch.cuda.empty_cache() + + with ctx: + data = data.to("cuda") + losses: list[float] = [] + all_mb_metrics: list[dict[str, Any]] = [] + grad_norm: Optional[float | torch.Tensor] = None + + for gb_idx in range(num_global_batches): + gb_result = process_global_batch( + data, + loss_fn, + self.dp_mesh.get_group(), + batch_idx=gb_idx, + batch_size=local_gbs, + ) + batch = gb_result["batch"] + global_valid_seqs = gb_result["global_valid_seqs"] + global_valid_toks = gb_result["global_valid_toks"] + + self.optimizer.zero_grad() + processed_iterator, iterator_len = get_microbatch_iterator( + batch, + self.cfg, + mbs, + self.dp_mesh, + tokenizer=self.tokenizer, + cp_size=self.cp_size, + ) + + mb_results = automodel_forward_backward( + model=self.model, + data_iterator=processed_iterator, + post_processing_fn=loss_post_processor, + forward_only=eval_mode, + is_reward_model=self._is_reward_model, + allow_flash_attn_args=self.allow_flash_attn_args, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + sampling_params=None, + sequence_dim=sequence_dim, + dp_size=self.dp_size, + cp_size=self.cp_size, + num_global_batches=num_global_batches, + train_context_fn=train_context_fn, + num_valid_microbatches=iterator_len, + on_microbatch_start=on_microbatch_start, + ) + + mb_losses: list[float] = [] + for mb_idx, (loss, loss_metrics) in enumerate(mb_results): + if mb_idx < iterator_len: + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + if not eval_mode: + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + [self.model], + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=1, + dp_group_size=self.dp_size * self.cp_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) + self.optimizer.step() + + losses.append(torch.tensor(mb_losses).sum().item()) + + self.optimizer.zero_grad() + if not eval_mode: + self.scheduler.step() + torch.cuda.empty_cache() + + return aggregate_training_statistics( + losses=losses, + all_mb_metrics=all_mb_metrics, + grad_norm=grad_norm, + dp_group=self.dp_mesh.get_group(), + dtype=self.dtype, + ) + + @wrap_with_nvtx_name("dtensor_policy_worker_v2/compute_teacher_logits_ipc") + def compute_teacher_logits_ipc( + self, + data: BatchedDataDict[Any], + topk_logits: Optional[int] = None, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + """Run the teacher forward and export per-microbatch logits via CUDA IPC. + + Returns a dict with ``microbatch_handles`` (list of per-microbatch + IPC handle dicts) and ``is_topk`` (bool, indicates the handles carry + top-k values+indices instead of full vocab logits). Consumed by the + student's ``train_off_policy_distillation``. + """ + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + sequence_dim, _ = check_sequence_dim(data) + ctx: AbstractContextManager[Any] = torch.no_grad() + self.model.eval() + + def train_context_fn(processed_inputs): + return get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ) + + empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( + "clear_cache_every_n_steps" + ) + if empty_cache_steps: + warnings.warn( + f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead.", + ) + + teacher_post_processor = XTokenTeacherIPCExportPostProcessor( + loss_fn=getattr(self, "_cached_loss_fn", None), + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + dp_size=self.dp_size, + enable_seq_packing=self.enable_seq_packing, + sampling_params=None, + topk_logits=topk_logits, + is_mdlm=self.cfg.get("is_mdlm", False), + ) + + def on_microbatch_start(mb_idx): + teacher_post_processor.set_microbatch_index(mb_idx) + if empty_cache_steps and mb_idx % empty_cache_steps == 0: + torch.cuda.empty_cache() + + with ctx: + data = data.to("cuda") + for gb_idx in range(num_global_batches): + gb_result = process_global_batch( + data, + getattr(self, "_cached_loss_fn", None), + self.dp_mesh.get_group(), + batch_idx=gb_idx, + batch_size=local_gbs, + ) + batch = gb_result["batch"] + global_valid_seqs = gb_result["global_valid_seqs"] + global_valid_toks = gb_result["global_valid_toks"] + + processed_iterator, iterator_len = get_microbatch_iterator( + batch, + self.cfg, + mbs, + self.dp_mesh, + tokenizer=self.tokenizer, + cp_size=self.cp_size, + ) + + automodel_forward_backward( + model=self.model, + data_iterator=processed_iterator, + post_processing_fn=teacher_post_processor, + forward_only=True, + is_reward_model=self._is_reward_model, + allow_flash_attn_args=self.allow_flash_attn_args, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + sampling_params=None, + sequence_dim=sequence_dim, + dp_size=self.dp_size, + cp_size=self.cp_size, + num_global_batches=num_global_batches, + train_context_fn=train_context_fn, + num_valid_microbatches=iterator_len, + on_microbatch_start=on_microbatch_start, + ) + break + + # Ensure writes to IPC-exported buffers are complete before returning handles. + torch.cuda.current_stream().synchronize() + return { + "microbatch_handles": teacher_post_processor.microbatch_handles, + "is_topk": topk_logits is not None, + } + @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_logprobs") def get_logprobs( self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None @@ -1005,6 +1531,21 @@ def prepare_for_lp_inference(self) -> None: gc.collect() torch.cuda.empty_cache() + @wrap_with_nvtx_name("dtensor_policy_worker_v2/move_optimizer_to_cuda") + def move_optimizer_to_cuda(self) -> None: + """Move optimizer state back to CUDA. + + Used by off-policy distillation when ``keep_models_resident=False``: + ``offload_after_refit`` moves optimizer state to CPU between teacher + inference and student training, but ``prepare_for_training`` only + re-onloads the optimizer for the logprob/colocated-generation + offload cases. This is the explicit onload for the off-policy + distillation offload path. No-op if no optimizer or if cpu offload + is in effect. + """ + if self.optimizer is not None and not self.cpu_offload: + self.move_optimizer_to_device("cuda") + @wrap_with_nvtx_name("dtensor_policy_worker_v2/prepare_for_training") def prepare_for_training(self, *args, **kwargs) -> None: # onload models and optimizer state to cuda diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 8df5e1f15c..1b04931ecb 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -445,6 +445,58 @@ def train( metrics["moe_metrics"] = moe_metrics return metrics + def train_off_policy_distillation( + self, + data: BatchedDataDict[Any], + teacher_logits: Optional[Any] = None, + loss_fn: Optional[LossFunction] = None, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + raise NotImplementedError( + "MegatronPolicyWorker does not support off-policy distillation; " + "use DTensorPolicyWorkerV2 (set policy.dtensor_cfg._v2=true)." + ) + + def compute_teacher_logits_ipc( + self, + data: BatchedDataDict[Any], + topk_logits: Optional[int] = None, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> dict[str, Any]: + raise NotImplementedError( + "MegatronPolicyWorker does not support compute_teacher_logits_ipc; " + "use DTensorPolicyWorkerV2 (set policy.dtensor_cfg._v2=true)." + ) + + def init_cross_tokenizer_loss_fn( + self, loss_config: Any, token_aligner_config: Any + ) -> None: + raise NotImplementedError( + "MegatronPolicyWorker does not support cross-tokenizer distillation; " + "use DTensorPolicyWorkerV2 (set policy.dtensor_cfg._v2=true)." + ) + + def update_cross_tokenizer_data( + self, + teacher_input_ids: Any, + aligned_pairs: Any, + teacher_idx: Optional[int] = None, + chunk_indices: Optional[dict[str, Any]] = None, + ) -> None: + raise NotImplementedError( + "MegatronPolicyWorker does not support cross-tokenizer distillation; " + "use DTensorPolicyWorkerV2 (set policy.dtensor_cfg._v2=true)." + ) + + def move_optimizer_to_cuda(self) -> None: + raise NotImplementedError( + "MegatronPolicyWorker does not support move_optimizer_to_cuda; " + "use DTensorPolicyWorkerV2 (set policy.dtensor_cfg._v2=true)." + ) + @wrap_with_nvtx_name("megatron_policy_worker/get_logprobs") def get_logprobs( self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None diff --git a/nemo_rl/utils/x_token/__init__.py b/nemo_rl/utils/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/utils/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/utils/x_token/minimal_projection_generator.py b/nemo_rl/utils/x_token/minimal_projection_generator.py new file mode 100644 index 0000000000..4c71dd1490 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_generator.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +from transformers import AutoTokenizer, AutoModel, AutoConfig +# from sentence_transformers import SentenceTransformer +from tqdm.auto import tqdm +import re +import pdb + + +##### verify KL and top5 with this matrix + + +###### use config vocab size, not tokenizer + +EXACT_MATCH_ONLY = False + +# --- Configuration and Setup --- +parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") +parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).") +parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).") +parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for source model (Model A / Student). If provided, overrides model_a_index.") +parser.add_argument("--model_b_name", type=str, default=None, help="HuggingFace model name for target model (Model B / Teacher). If provided, overrides model_b_index.") +parser.add_argument("--keep_top_tokens", type=int, default=-1, help="Number of top tokens to keep for each vocabulary. -1 means all.") +parser.add_argument("--data_dir", type=str, default="cross_tokenizer_data/", help="Directory for importance scores and cached data.") +parser.add_argument("--top_k", type=int, default=10, help="Number of top projections to keep for each token.") +parser.add_argument("--weight_threshold", type=float, default=0.0, help="Minimum weight threshold to keep a projection. Values below this will be filtered out.") +parser.add_argument("--force_recompute", action='store_true', help="Force recomputation of embeddings even if cached files exist.") +parser.add_argument("--skip_exact_enforcement", action='store_true', help="Skip enforcing exact matches between tokens.") +parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).") +args = parser.parse_args() + +args.skip_exact_enforcement = True + +MODEL_LIST = [ + "nvidia/Mistral-NeMo-Minitron-8B-Base", + "Qwen/Qwen3-8B-Base", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.1-8B", + "google/gemma-3-4b-it", + "google/gemma-2b", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "openai/gpt-oss-20b", + "microsoft/phi-4", + "google/gemma-3-12b-pt", +] +EMBEDDING_MODEL_CHOICES = [ + {"name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-mpnet-base-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-MiniLM-L6-v2", "type": "sbert"}, + {"name": "Qwen/Qwen3-Embedding-4B", "type": "llm_first_layer"}, + {"name": "Qwen/Qwen3-Embedding-0.6B", "type": "llm_first_layer"}, +] + +MAX_SEQ_LENGTH_EMBEDDING = 64 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def sinkhorn(A, n_iters=10): + for _ in range(n_iters): + if _ % 2 == 0: + # A = A / (A.sum(dim=0, keepdim=True) + 1e-6) + col_sums = A.sum(dim=0, keepdim=True) + safe_col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums) + A = A / safe_col_sums + else: + #0, 2, 4, 6 + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +# --- Helper Functions --- + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + if 'mini' in name: + cleaned_name += "_mini" + return cleaned_name + +def load_tokenizer(model_id_or_path): + """Loads a HuggingFace tokenizer, setting a pad token if necessary.""" + try: + tok = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + except Exception as e: + print(f"Error loading tokenizer for model '{model_id_or_path}': {e}") + print(f"Available models in MODEL_LIST (indices 0-{len(MODEL_LIST)-1}):") + for i, model in enumerate(MODEL_LIST): + print(f" {i}: {model}") + raise + +def validate_model_selection(args): + """Validates that the model selection arguments are valid.""" + # Check if both name and index are provided for the same model + if args.model_a_name is not None and args.model_a_index != 1: # 1 is the default + print("Warning: Both --model_a_name and --model_a_index provided. Using --model_a_name.") + + if args.model_b_name is not None and args.model_b_index != 0: # 0 is the default + print("Warning: Both --model_b_name and --model_b_index provided. Using --model_b_name.") + + # Validate indices if names are not provided + if args.model_a_name is None: + if args.model_a_index < 0 or args.model_a_index >= len(MODEL_LIST): + raise ValueError(f"model_a_index {args.model_a_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + if args.model_b_name is None: + if args.model_b_index < 0 or args.model_b_index >= len(MODEL_LIST): + raise ValueError(f"model_b_index {args.model_b_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}") + + # Check if the same model is selected for both A and B + model_a_id = args.model_a_name if args.model_a_name is not None else MODEL_LIST[args.model_a_index] + model_b_id = args.model_b_name if args.model_b_name is not None else MODEL_LIST[args.model_b_index] + + if model_a_id == model_b_id: + raise ValueError(f"Cannot use the same model for both A and B: {model_a_id}") + +def save_data(data, filename): + """Saves data to a torch file.""" + os.makedirs(os.path.dirname(filename), exist_ok=True) + torch.save(data.cpu(), filename) + print(f"Data saved to {filename}") + +def load_data(filename): + """Loads data from a torch file.""" + return torch.load(filename) + +def get_llm_first_layer_embeddings(decoded_tokens_list, llm_embedding_tokenizer, llm_embedding_model, max_seq_length_embedding, device, batch_size=32): + """Generates embeddings using the first layer of a given LLM.""" + all_embeddings = [] + llm_embedding_model.eval() + embedding_dim = llm_embedding_model.config.hidden_size + + for i in tqdm(range(0, len(decoded_tokens_list), batch_size), desc="Encoding tokens with LLM"): + batch_tokens = decoded_tokens_list[i:i + batch_size] + inputs = llm_embedding_tokenizer( + batch_tokens, return_tensors="pt", padding=True, truncation=True, + max_length=max_seq_length_embedding, add_special_tokens=False, + ).to(device) + + with torch.no_grad(): + outputs = llm_embedding_model(**inputs, output_hidden_states=True) + first_layer_output = outputs.hidden_states[0] + + for k in range(first_layer_output.shape[0]): + valid_token_mask = inputs['attention_mask'][k] == 1 + if valid_token_mask.sum() > 0: + pooled_embedding = first_layer_output[k, valid_token_mask].mean(dim=0) + all_embeddings.append(pooled_embedding) + else: + all_embeddings.append(torch.zeros(embedding_dim, device=device)) + + return torch.stack(all_embeddings).to(device) + + +def compute_chunked_projection_map(embeddings_query, embeddings_corpus, args, device, chunk_size=1000): + """Computes projection map in chunks to save memory.""" + num_queries = embeddings_query.shape[0] + target_vocab_size = embeddings_corpus.shape[0] + + # Pre-allocate result tensors + all_top_k_indices = torch.zeros((num_queries, args.top_k), dtype=torch.long) + all_top_k_likelihoods = torch.zeros((num_queries, args.top_k), dtype=torch.float32) + + # Normalize corpus embeddings once + embeddings_corpus_norm = torch.nn.functional.normalize(embeddings_corpus.to(device).float(), p=2, dim=1) + + for chunk_start in tqdm(range(0, num_queries, chunk_size), desc="Processing chunks"): + chunk_end = min(chunk_start + chunk_size, num_queries) + chunk_query = embeddings_query[chunk_start:chunk_end].to(device).float() + + with torch.no_grad(): + # Compute similarities for this chunk + chunk_query_norm = torch.nn.functional.normalize(chunk_query, p=2, dim=1) + similarities = torch.matmul(chunk_query_norm, embeddings_corpus_norm.t()) + + # Generate projection map for this chunk + chunk_top_k_indices, chunk_top_k_likelihoods = generate_projection_map_chunk(similarities, args) + + # Store results + all_top_k_indices[chunk_start:chunk_end] = chunk_top_k_indices.cpu() + all_top_k_likelihoods[chunk_start:chunk_end] = chunk_top_k_likelihoods.cpu() + + # Clear GPU memory + del similarities, chunk_query_norm, chunk_top_k_indices, chunk_top_k_likelihoods + torch.cuda.empty_cache() + + return all_top_k_indices, all_top_k_likelihoods + +def generate_projection_map_chunk(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix chunk.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Normalize rows + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + threshold_mask = top_k_likelihood >= args.weight_threshold + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + return top_k_indices, top_k_likelihood + +def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + +def debug_projection_map(top_k_indices, top_k_likelihood, source_tokenizer, target_tokenizer, direction="", N=2000): + """Debug function to show first N rows with decoded tokens and weights.""" + N = min(N, top_k_indices.shape[0]) # Show first N rows or less + print(f"\n--- Debugging projection map {direction} (first {N} rows) ---") + + for row_idx in range(N): + # for row_idx in range(-N,-1): + # Decode source token + try: + token_id = row_idx if row_idx >= 0 else top_k_indices.shape[0] + row_idx + source_token = source_tokenizer.decode([token_id]) + # source_token = source_tokenizer.convert_ids_to_tokens([token_id])[0] + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Build the target tokens with weights string + row_indices = top_k_indices[row_idx].cpu().numpy() + row_weights = top_k_likelihood[row_idx].float().cpu().numpy() + + weight_total = 0 + target_parts = [] + + if row_weights.max() != row_weights[-1]: + continue + + + for target_idx, weight in zip(row_indices, row_weights): + try: + target_token = target_tokenizer.decode([target_idx]) + target_token_str = repr(target_token) + except: + target_token_str = f"" + + + target_parts.append(f"{target_token_str}({weight:.4f})") + weight_total += weight + + target_string = " ".join(target_parts) + # print(f"Weight total: {weight_total:.4f}") + print(f"{source_token_str} -> {target_string}") + +def generate_projection_map(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Create a sparse representation by keeping only top-k values + # top_k_likelihood_pre_norm, _ = likelihood.topk(args.top_k, dim=1) + # likelihood = likelihood.where(likelihood >= top_k_likelihood_pre_norm[:, -1:], torch.zeros_like(likelihood)) + + # Normalize the row to sum to 1, handling rows that are all zero + # row_sums = likelihood.sum(dim=1, keepdim=True) + # safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + # likelihood = likelihood / safe_row_sums + # pdb.set_trace() + # likelihood = sinkhorn_one_dim(likelihood) + + # Get the final top-k values and their indices from the sparse, normalized likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Store top-k values before zeroing (to avoid losing them) + row_indices = torch.arange(likelihood.shape[0]).unsqueeze(1).expand(-1, args.top_k) + top_k_values = likelihood[row_indices, top_k_indices].clone() + + # Zero out entire likelihood matrix in-place, then restore only top-k elements + likelihood.zero_() + likelihood[row_indices, top_k_indices] = top_k_values + + # likelihood = sinkhorn(likelihood, n_iters=1) + # likelihood = sinkhorn(likelihood, n_iters=1) works the best + + + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + print(f"Applying weight threshold filter: {args.weight_threshold}") + # Create mask for values above threshold + # pdb.set_trace() + threshold_mask = top_k_likelihood >= args.weight_threshold + + #set indices to -1 where threshold is not met + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # # Count how many values per row are above threshold + # valid_counts = threshold_mask.sum(dim=1) + # total_filtered = (valid_counts == 0).sum().item() + # total_kept = threshold_mask.sum().item() + # total_possible = top_k_likelihood.numel() + + # print(f"Kept {total_kept}/{total_possible} ({100*total_kept/total_possible:.1f}%) projections above threshold") + + # if total_filtered > 0: + # print(f"Warning: {total_filtered} tokens have no projections above threshold {args.weight_threshold}") + + # # Zero out values below threshold + # filtered_likelihood = top_k_likelihood * threshold_mask.to(top_k_likelihood.dtype) + # filtered_indices = top_k_indices.clone() + + # # For rows with no values above threshold, keep the top value to avoid empty rows + # empty_rows = valid_counts == 0 + # if empty_rows.any(): + # print(f"Keeping top projection for {empty_rows.sum().item()} tokens with no values above threshold") + # filtered_likelihood[empty_rows, 0] = top_k_likelihood[empty_rows, 0] + + # top_k_likelihood = filtered_likelihood + # top_k_indices = filtered_indices + + # pdb.set_trace() + + return top_k_indices, top_k_likelihood + +# --- Main Execution --- +if __name__ == "__main__": + # Validate model selection arguments + validate_model_selection(args) + + # 1. Load Tokenizers and deterministically assign A and B + # Use model names if provided, otherwise use indices + if args.model_a_name is not None: + model_1 = {'id': args.model_a_name} + print(f"Using provided model A name: {args.model_a_name}") + else: + model_1 = {'id': MODEL_LIST[args.model_a_index]} + print(f"Using model A from index {args.model_a_index}: {model_1['id']}") + + model_1['name'] = model_1['id'].split("/")[-1] + print(f"Loading first tokenizer: {model_1['name']}") + model_1['tokenizer'] = load_tokenizer(model_1['id']) + + if args.model_b_name is not None: + model_2 = {'id': args.model_b_name} + print(f"Using provided model B name: {args.model_b_name}") + else: + model_2 = {'id': MODEL_LIST[args.model_b_index]} + print(f"Using model B from index {args.model_b_index}: {model_2['id']}") + + model_2['name'] = model_2['id'].split("/")[-1] + print(f"Loading second tokenizer: {model_2['name']}") + model_2['tokenizer'] = load_tokenizer(model_2['id']) + + # Deterministically assign model_A and model_B based on alphabetical order of names + if model_1['name'] > model_2['name']: + model_A, model_B = model_2, model_1 + else: + model_A, model_B = model_1, model_2 + + print(f"\nAssigned Source (A): {model_A['name']}") + print(f"Assigned Target (B): {model_B['name']}") + + source_vocab_size = model_A['tokenizer'].vocab_size + target_vocab_size = model_B['tokenizer'].vocab_size + # get the top k tokens from the source and target vocab from model config file + model_A_config = AutoConfig.from_pretrained(model_A['id'], trust_remote_code=True if 'nvidia' in model_A['id'] else False) + model_B_config = AutoConfig.from_pretrained(model_B['id'], trust_remote_code=True if 'nvidia' in model_B['id'] else False) + # pdb.set_trace() + source_vocab_size = model_A_config.vocab_size + if "gemma" not in model_B['id']: + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Source vocab size (full): {source_vocab_size}") + print(f"Target vocab size (full): {target_vocab_size}") + # exit() + + + + if 0: + # just debugging learned projection map + # learned_projection_map = torch.load("models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_transformation_matrices/learned_projection_map_latest.pt") + # learned_projection_map = torch.load("cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_64_double.pt") + learned_projection_map = torch.load("cross_tokenizer_data/projection_matrix_learned_llama_qwen_top5.pt") + top_k_indices_A_to_B = learned_projection_map["indices"] + top_k_likelihood_A_to_B = learned_projection_map["likelihoods"] + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B", N=150000) + exit() + + + + + + + + # 2. Select and Load Embedding Model + embedding_model_index = 3 # Default to a good LLM embedder + selected_model_info = EMBEDDING_MODEL_CHOICES[embedding_model_index] + embedding_model_name = selected_model_info["name"] + embedding_model_type = selected_model_info["type"] + print(f"\nUsing embedding model: {embedding_model_name} ({embedding_model_type})") + + # 3. Generate or Load Embeddings + canonicalization_suffix = "_canonical" if args.use_canonicalization else "_raw" + embeddings_path_A = os.path.join(args.data_dir, f"embeddings_{model_A['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + embeddings_path_B = os.path.join(args.data_dir, f"embeddings_{model_B['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + + if not args.force_recompute and os.path.exists(embeddings_path_A) and os.path.exists(embeddings_path_B): + print("Loading cached embeddings...") + model_A['embeddings'] = load_data(embeddings_path_A).to(DEVICE) + model_B['embeddings'] = load_data(embeddings_path_B).to(DEVICE) + else: + print("Generating new embeddings...") + + # Generate raw decoded tokens + raw_tokens_A = [model_A['tokenizer'].decode([idx]) for idx in range(model_A['tokenizer'].vocab_size)] + raw_tokens_B = [model_B['tokenizer'].decode([idx]) for idx in range(model_B['tokenizer'].vocab_size)] + + # Apply canonicalization if requested + if args.use_canonicalization: + # Import canonicalization function + import sys + sys.path.append('.') + from tokenalign import TokenAligner + + print("Applying token canonicalization before embedding generation...") + decoded_tokens_A = [TokenAligner._canonical_token(token) for token in raw_tokens_A] + decoded_tokens_B = [TokenAligner._canonical_token(token) for token in raw_tokens_B] + + # Show some examples of canonicalization + print("Canonicalization examples:") + for i in range(min(10, len(raw_tokens_A))): + if raw_tokens_A[i] != decoded_tokens_A[i]: + print(f" Model A: '{raw_tokens_A[i]}' -> '{decoded_tokens_A[i]}'") + for i in range(min(10, len(raw_tokens_B))): + if raw_tokens_B[i] != decoded_tokens_B[i]: + print(f" Model B: '{raw_tokens_B[i]}' -> '{decoded_tokens_B[i]}'") + + print(f"Applied canonicalization to {len(decoded_tokens_A)} tokens for model A and {len(decoded_tokens_B)} tokens for model B") + else: + print("Using raw decoded tokens without canonicalization") + decoded_tokens_A = raw_tokens_A + decoded_tokens_B = raw_tokens_B + + if embedding_model_type == "sbert": + sbert_model = SentenceTransformer(embedding_model_name, device=DEVICE) + model_A['embeddings'] = sbert_model.encode(decoded_tokens_A, convert_to_tensor=True, show_progress_bar=True) + model_B['embeddings'] = sbert_model.encode(decoded_tokens_B, convert_to_tensor=True, show_progress_bar=True) + elif embedding_model_type == "llm_first_layer": + llm_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + if llm_tokenizer.pad_token is None: llm_tokenizer.pad_token = llm_tokenizer.eos_token + llm_model = AutoModel.from_pretrained( + embedding_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ).to(DEVICE) + model_A['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_A, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + model_B['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_B, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + + save_data(model_A['embeddings'], embeddings_path_A) + save_data(model_B['embeddings'], embeddings_path_B) + + # 4. Compute Similarity and Generate Projection Maps (chunked to save memory) + print("\nComputing projection map in chunks to save memory...") + chunk_size = 500 # Process 500 tokens at a time to avoid OOM + top_k_indices_A_to_B, top_k_likelihood_A_to_B = compute_chunked_projection_map( + model_A['embeddings'], model_B['embeddings'], args, DEVICE, chunk_size=chunk_size) + + # Note: Exact match enforcement is skipped in chunked mode for simplicity + # The chunked approach processes similarities in small batches to avoid OOM + if 0: + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B") + + # print("Generating B -> A projection map...") + # top_k_indices_B_to_A, top_k_likelihood_B_to_A = generate_projection_map(similarities.T, args) + # debug_projection_map(top_k_indices_B_to_A, top_k_likelihood_B_to_A, model_B['tokenizer'], model_A['tokenizer'], "B -> A") + + # 5. Save the Combined Projection Map + print("\nSaving combined projection map...") + model_a_clean_name = clean_model_name_for_filename(model_A['name']) + model_b_clean_name = clean_model_name_for_filename(model_B['name']) + # output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_bidirectional_top_{args.top_k}.pt" + output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_top_{args.top_k}" + # if args.skip_exact_enforcement: + # output_filename += "_no_exact" + output_filename += f".pt" + if args.weight_threshold > 0.0: + output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") + output_path = os.path.join(args.data_dir, output_filename) + + torch.save({ + "indices": top_k_indices_A_to_B.cpu(), + "likelihoods": top_k_likelihood_A_to_B.cpu(), + "model_A_id": model_A['id'], + "model_B_id": model_B['id'], + }, output_path) + + # torch.save({ + # "A_to_B": { + # "indices": top_k_indices_A_to_B.cpu(), + # "likelihoods": top_k_likelihood_A_to_B.cpu() + # }, + # "B_to_A": { + # "indices": top_k_indices_B_to_A.cpu(), + # "likelihoods": top_k_likelihood_B_to_A.cpu() + # }, + # "model_A_id": model_A['id'], + # "model_B_id": model_B['id'], + # }, output_path) + print(f"Saved combined projection map to: {output_path}") + + # 6. Example Usage of the Projection Function + print("\n--- Testing projection function (A -> B) ---") + # Create a dummy likelihood tensor: [BATCH, SEQ, vocab_size_A] + source_vocab_size_A = model_A['embeddings'].shape[0] + target_vocab_size_B = model_B['embeddings'].shape[0] + dummy_tensor = torch.randn(1, 4096, source_vocab_size_A, device=DEVICE, dtype=torch.bfloat16) + + # Transform this tensor using the projection map (convert to float32 for compatibility) + projected_tensor = project_token_likelihoods(dummy_tensor.float(), top_k_indices_A_to_B, top_k_likelihood_A_to_B, target_vocab_size_B, DEVICE) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful.") diff --git a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py new file mode 100644 index 0000000000..ec67e4cc57 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py @@ -0,0 +1,942 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + + + +###### save as dense format and set indices to -1 where not used + + +#remove all special tokens that start with <| and end with |> + + +# compare 3 ways to estimate likelihood matrix: +# 1. using embeddings from another model, like was done in minimal_projection_generator.py +# 2. using text analysis like in tokenalign_likelihood_estimate.py +# 3. use one token to multiple and assign those as transformation matrix + +# this file implements 3rd way + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def create_weight_distribution(num_tokens): + """Create weight distribution for multi-token mappings""" + # if num_tokens == 1: + # return [1.0] + # elif num_tokens == 2: + # return [0.7, 0.3] + # elif num_tokens == 3: + # return [0.6, 0.3, 0.1] + # else: + if 1: + # For more tokens, use exponential decay + weights = [] + base = 0.9 + for i in range(num_tokens): + if i == 0: + weights.append(base) + else: + weights.append(base * (0.1 ** i)) + + # Normalize to sum to 1 + total = sum(weights) + weights = [w / total for w in weights] + return weights + +def find_similar_special_tokens(tokenizer_a, tokenizer_b, similarity_threshold=0.4, top_k_matches=3): + """Find similar special tokens between two tokenizers using string similarity.""" + + def is_special_token(token_str): + """Check if a token looks like a special token""" + return (token_str.startswith('<|') and token_str.endswith('|>')) or \ + (token_str.startswith('<') and token_str.endswith('>')) or \ + token_str in ['', '', '', '', '', ''] + + def extract_special_tokens(tokenizer): + """Extract all special tokens from a tokenizer with their IDs""" + special_tokens = {} + vocab = tokenizer.get_vocab() + for token_str, token_id in vocab.items(): + if is_special_token(token_str): + special_tokens[token_id] = token_str + return special_tokens + + def calculate_similarity(token_a, token_b): + """Calculate similarity between two token strings""" + # Use difflib for sequence similarity + seq_similarity = difflib.SequenceMatcher(None, token_a, token_b).ratio() + + # Extract key words from special tokens for semantic matching + def extract_keywords(token): + # Remove special token markers and split by common separators + cleaned = re.sub(r'[<>|_]', ' ', token.lower()) + words = [w for w in cleaned.split() if len(w) > 2] # Filter short words + return set(words) + + keywords_a = extract_keywords(token_a) + keywords_b = extract_keywords(token_b) + + # Jaccard similarity for keywords + if keywords_a or keywords_b: + keyword_similarity = len(keywords_a.intersection(keywords_b)) / len(keywords_a.union(keywords_b)) + else: + keyword_similarity = 0.0 + + # Combined similarity (weighted average) + return 0.6 * seq_similarity + 0.4 * keyword_similarity + + print("Extracting special tokens...") + special_tokens_a = extract_special_tokens(tokenizer_a) # student + special_tokens_b = extract_special_tokens(tokenizer_b) # teacher + + print(f"Found {len(special_tokens_a)} special tokens in student tokenizer") + print(f"Found {len(special_tokens_b)} special tokens in teacher tokenizer") + + # Find matches + special_token_mappings = [] + + print("Finding similar special tokens...") + for token_id_a, token_str_a in special_tokens_a.items(): + similarities = [] + for token_id_b, token_str_b in special_tokens_b.items(): + similarity = calculate_similarity(token_str_a, token_str_b) + if similarity >= similarity_threshold: + similarities.append((token_id_b, token_str_b, similarity)) + + # Sort by similarity and take top-k + similarities.sort(key=lambda x: x[2], reverse=True) + for token_id_b, token_str_b, similarity in similarities[:top_k_matches]: + special_token_mappings.append({ + 'student_id': token_id_a, + 'student_token': token_str_a, + 'teacher_id': token_id_b, + 'teacher_token': token_str_b, + 'similarity': similarity + }) + + return special_token_mappings + + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + + # Configuration from arguments + ENABLE_SCALE_TRICK = args.enable_scale_trick + ENABLE_REVERSE_PASS = args.enable_reverse_pass + ENABLE_EXACT_MATCH = args.enable_exact_match + + TOKENS_TO_CUT = args.tokens_to_cut + TOP_K = args.top_k + USE_RAW_TOKENS = args.use_raw_tokens + INITIAL_PROJECTION_PATH = args.initial_projection_path + ENABLE_SPECIAL_TOKEN_MAPPING = args.enable_special_token_mapping + SPECIAL_TOKEN_SIMILARITY_THRESHOLD = args.special_token_similarity_threshold + SPECIAL_TOKEN_TOP_K = args.special_token_top_k if args.special_token_top_k is not None else TOP_K + USE_CANONICALIZATION = args.use_canonicalization + + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + + # Print configuration + print("=== Configuration ===") + print(f"Student model: {student_model_name}") + print(f"Teacher model: {teacher_model_name}") + print(f"Enable scale trick: {ENABLE_SCALE_TRICK}") + print(f"Enable reverse pass: {ENABLE_REVERSE_PASS}") + print(f"Enable exact match: {ENABLE_EXACT_MATCH}") + print(f"Use raw tokens: {USE_RAW_TOKENS}") + print(f"Use canonicalization: {USE_CANONICALIZATION}") + print(f"Tokens to cut: {TOKENS_TO_CUT}") + print(f"Top K: {TOP_K}") + print(f"Enable special token mapping: {ENABLE_SPECIAL_TOKEN_MAPPING}") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"Special token similarity threshold: {SPECIAL_TOKEN_SIMILARITY_THRESHOLD}") + print(f"Special token top K: {SPECIAL_TOKEN_TOP_K}") + print(f"Initial projection path: {INITIAL_PROJECTION_PATH}") + print(f"Output directory: {args.output_dir}") + print("=" * 25) + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + # pdb.set_trace() + if "gemma" not in student_model_name.lower(): + source_vocab_size = model_A_config.vocab_size + else: + source_vocab_size = model_A_config.text_config.vocab_size + + if "gemma" not in teacher_model_name.lower(): + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + + tokenizer_student_total_vocab_size = source_vocab_size + tokenizer_teacher_total_vocab_size = target_vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Student tokenizer total vocab size: {tokenizer_student_total_vocab_size}") + print(f"Teacher tokenizer total vocab size: {tokenizer_teacher_total_vocab_size}") + + # Print token processing mode + if USE_RAW_TOKENS: + print("Using raw token representation (convert_ids_to_tokens)") + else: + print("Using decoded token representation (decode)") + + + transformation_counts = defaultdict(float) + import os + if INITIAL_PROJECTION_PATH and os.path.exists(INITIAL_PROJECTION_PATH): + print(f"Loading initial projection from: {INITIAL_PROJECTION_PATH}") + initial_projection_map = torch.load(INITIAL_PROJECTION_PATH, map_location='cpu') + + if isinstance(initial_projection_map, dict) and 'indices' in initial_projection_map and 'likelihoods' in initial_projection_map: + print("Loading from sparse top-k format and converting to transformation_counts.") + indices = initial_projection_map['indices'] + likelihoods = initial_projection_map['likelihoods'] + + loaded_student_model = initial_projection_map.get('model_A_id') + loaded_teacher_model = initial_projection_map.get('model_B_id') + + if loaded_student_model and loaded_student_model != student_model_name: + print(f"Warning: Student model mismatch. Loaded: {loaded_student_model}, Current: {student_model_name}") + if loaded_teacher_model and loaded_teacher_model != teacher_model_name: + print(f"Warning: Teacher model mismatch. Loaded: {loaded_teacher_model}, Current: {teacher_model_name}") + + num_student_tokens = indices.shape[0] + top_k = indices.shape[1] + + for student_id in tqdm.tqdm(range(num_student_tokens), desc="Converting initial projection to counts"): + for k in range(top_k): + teacher_id = indices[student_id, k].item() + if teacher_id != -1: + likelihood = likelihoods[student_id, k].item() + if likelihood > 0: + transformation_counts[(student_id, teacher_id)] = likelihood + + elif torch.is_tensor(initial_projection_map): + if initial_projection_map.is_sparse: + print("Loading from sparse tensor and converting to transformation_counts.") + sparse_matrix = initial_projection_map.coalesce() + map_indices = sparse_matrix.indices() + map_values = sparse_matrix.values() + for i in tqdm.tqdm(range(map_indices.shape[1]), desc="Converting sparse tensor to counts"): + student_id = map_indices[0, i].item() + teacher_id = map_indices[1, i].item() + weight = map_values[i].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print("Loading from dense matrix and converting to transformation_counts.") + dense_matrix = initial_projection_map + non_zero_indices = torch.nonzero(dense_matrix, as_tuple=False) + for idx in tqdm.tqdm(range(non_zero_indices.shape[0]), desc="Converting dense projection to counts"): + student_id = non_zero_indices[idx, 0].item() + teacher_id = non_zero_indices[idx, 1].item() + weight = dense_matrix[student_id, teacher_id].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print(f"Warning: Unrecognized format for initial projection map at {INITIAL_PROJECTION_PATH}. Skipping.") + + print(f"Initialized transformation_counts with {len(transformation_counts)} entries.") + + # pdb.set_trace() + + ignore_tokens = ['<|endoftext|>', '', ] + ignore_student_ids = {tokenizer_student.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_student.get_vocab()} + ignore_teacher_ids = {tokenizer_teacher.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_teacher.get_vocab()} + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + # Find multi-token mappings + multi_token_examples = [] + + + print("=== FIRST PASS: Teacher tokens -> Student tokens ===") + print("Finding multi-token mappings...") + + + # First pass: Teacher tokens -> Student tokens (reverse direction) + if 1: + print("\n=== First PASS: Student tokens -> Teacher tokens ===") + + # Get all student tokens and decode them + student_vocab = tokenizer_student.get_vocab() + + student_tokens_decoded = {} + + print("Decoding student tokens...") + for token_id in tqdm.tqdm(range(tokenizer_student_total_vocab_size), desc="Decoding student tokens"): + if token_id in ignore_student_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_student.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_student.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + student_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(student_tokens_decoded)} student tokens") + + reverse_multi_token_examples = [] + print("Finding reverse multi-token mappings...") + for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"): + # Tokenize the student token string using teacher tokenizer + teacher_encoding = tokenizer_teacher(student_token_str, add_special_tokens=False, return_attention_mask=False) + teacher_token_ids = teacher_encoding['input_ids'] + + # Skip if any teacher token is in ignore list + if any(tid in ignore_teacher_ids for tid in teacher_token_ids): + continue + + # Cut to only first 4 tokens + teacher_token_ids = teacher_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of teacher tokens + weights = create_weight_distribution(len(teacher_token_ids)) + + # Add to transformation matrix (reverse direction: teacher_token_id -> student_token_id) + if 1: + for teacher_token_id, weight in zip(teacher_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(teacher_token_ids) >= 2: + teacher_tokens_decoded_reverse = [tokenizer_teacher.decode([tid]) for tid in teacher_token_ids] + reverse_multi_token_examples.append({ + 'student_token': student_token_str, + 'student_id': student_token_id, + 'teacher_tokens': teacher_tokens_decoded_reverse, + 'teacher_ids': teacher_token_ids, + 'weights': weights + }) + + # second pass: Teacher tokens -> Student tokens (opposite direction) + if ENABLE_REVERSE_PASS: + print("\n=== secod PASS: Teacher tokens -> Student tokens ===") + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + teacher_to_student_multi_token_examples = [] + print("Finding teacher->student multi-token mappings...") + for teacher_token_id, teacher_token_str in tqdm.tqdm(teacher_tokens_decoded.items(), desc="Processing teacher tokens"): + # Tokenize the teacher token string using student tokenizer + student_encoding = tokenizer_student(teacher_token_str, add_special_tokens=False, return_attention_mask=False) + student_token_ids = student_encoding['input_ids'] + + # Skip if any student token is in ignore list + if any(sid in ignore_student_ids for sid in student_token_ids): + continue + + # Cut to only first 4 tokens + student_token_ids = student_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of student tokens + weights = create_weight_distribution(len(student_token_ids)) + + # Add to transformation matrix (student_token_id -> teacher_token_id mapping) + if 1: + for student_token_id, weight in zip(student_token_ids, weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(student_token_ids) >= 2: + student_tokens_decoded_reverse = [tokenizer_student.decode([sid]) for sid in student_token_ids] + teacher_to_student_multi_token_examples.append({ + 'teacher_token': teacher_token_str, + 'teacher_id': teacher_token_id, + 'student_tokens': student_tokens_decoded_reverse, + 'student_ids': student_token_ids, + 'weights': weights + }) + + print(f"\n=== ADDING SPECIAL TOKEN MAPPINGS ===") + + # Find and add special token mappings (if enabled) + special_token_mappings = [] + if ENABLE_SPECIAL_TOKEN_MAPPING: + special_token_mappings = find_similar_special_tokens( + tokenizer_student, + tokenizer_teacher, + similarity_threshold=SPECIAL_TOKEN_SIMILARITY_THRESHOLD, + top_k_matches=SPECIAL_TOKEN_TOP_K + ) + else: + print("Special token mapping disabled") + + if special_token_mappings: + print(f"\nFound {len(special_token_mappings)} special token mappings:") + initial_transformation_count = len(transformation_counts) + + # Add ALL mappings to transformation matrix + for mapping in special_token_mappings: + student_id = mapping['student_id'] + teacher_id = mapping['teacher_id'] + similarity = mapping['similarity'] + + # Add mapping with weight based on similarity + weight = similarity * 0.8 # Scale similarity to reasonable weight + transformation_counts[(student_id, teacher_id)] += weight + + # Group mappings by student token and show top 2 matches per student token + from collections import defaultdict + student_mappings = defaultdict(list) + for mapping in special_token_mappings: + student_mappings[mapping['student_id']].append(mapping) + + # Sort each student's mappings by similarity and show top 2 + print("Top 2 matches per student special token:") + shown_count = 0 + for student_id, mappings in student_mappings.items(): + # Sort by similarity (highest first) + sorted_mappings = sorted(mappings, key=lambda x: x['similarity'], reverse=True) + + # Show top 2 for this student token + student_token = sorted_mappings[0]['student_token'] # Get student token name + print(f" {student_token}:") + + for mapping in sorted_mappings[:2]: + similarity = mapping['similarity'] + weight = similarity * 0.8 + print(f" -> '{mapping['teacher_token']}' (similarity: {similarity:.3f}, weight: {weight:.3f})") + shown_count += 1 + + if len(sorted_mappings) > 2: + print(f" ... and {len(sorted_mappings) - 2} more matches") + + total_hidden = len(special_token_mappings) - shown_count + if total_hidden > 0: + print(f"Total mappings not shown: {total_hidden}") + + added_count = len(transformation_counts) - initial_transformation_count + print(f"Added {added_count} new special token transformation entries") + else: + print("No similar special tokens found") + + print(f"\n=== SUMMARY ===") + print(f"Found {len(multi_token_examples)} teacher tokens that map to multiple student tokens") + # exit() + # Show some examples + if multi_token_examples: + print("\nExamples of multi-token mappings:") + for i, example in enumerate(multi_token_examples[:10]): + print(f" Teacher '{example['teacher_token']}' -> Student {example['student_tokens']} (weights: {example['weights']})") + if len(multi_token_examples) > 10: + print(f" ... and {len(multi_token_examples) - 10} more.") + + if ENABLE_REVERSE_PASS: + print(f"\nReverse pass enabled - added bidirectional mappings") + + print(f"\nTotal transformation entries: {len(transformation_counts)}") + + if ENABLE_EXACT_MATCH: + + print("Checking for exact token matches and setting exact mappings...") + # check exact match between student and teacher tokens and set those as perfect 1-to-1 mappings + # Convert all tokens to strings at once for vectorized comparison + # pdb.set_trace() + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # For tokens that match exactly, we want their mapping to be 1.0 + # and they should not be mapped to any other token. + # First, remove all existing mappings for these student tokens + match_indices_student_set = set(match_indices_student) + keys_to_remove = [] + for key in transformation_counts.keys(): + student_id, teacher_id = key + if student_id in match_indices_student_set: + keys_to_remove.append(key) + for key in keys_to_remove: + del transformation_counts[key] + # Then, set the perfect 1-to-1 mappings for exact matches + for student_id, teacher_id in zip(match_indices_student, match_indices_teacher): + transformation_counts[(student_id, teacher_id)] = 1.0 + + + def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50): + """Debug function to show projection mappings with decoded tokens and weights.""" + print(f"\n--- Debugging projection map {direction} (showing {N} examples) ---") + + # Group transformation_counts by source token (student token) + source_to_targets = defaultdict(list) + for (source_id, target_id), weight in transformation_counts.items(): + source_to_targets[source_id].append((target_id, weight)) + + # Sort by source token ID and take first N + # sorted_sources = sorted(source_to_targets.keys())[:N] + sorted_sources = sorted(source_to_targets.keys())[-N:] + + for source_id in sorted_sources: + # Decode source token + try: + if USE_RAW_TOKENS: + source_token = source_tokenizer.convert_ids_to_tokens([source_id])[0] + else: + source_token = source_tokenizer.decode([source_id]) + source_token = apply_canonicalization_if_enabled(source_token, USE_CANONICALIZATION) + source_token_str = repr(source_token) # Use repr to show special chars + except: + source_token_str = f"" + + # Sort targets by weight (descending) and build target string + targets_weights = sorted(source_to_targets[source_id], key=lambda x: x[1], reverse=True) + + target_parts = [] + for target_id, weight in targets_weights: + try: + if USE_RAW_TOKENS: + target_token = target_tokenizer.convert_ids_to_tokens([target_id])[0] + else: + target_token = target_tokenizer.decode([target_id]) + target_token = apply_canonicalization_if_enabled(target_token, USE_CANONICALIZATION) + target_token_str = repr(target_token) + except: + target_token_str = f"" + target_parts.append(f"{target_token_str}({weight:.4f})") + + target_string = " ".join(target_parts) + print(f"{source_token_str} -> {target_string}") + + # debug_projection_map(transformation_counts, tokenizer_student, tokenizer_teacher, + # direction="student->teacher", N=1000) + + # Create transformation matrix (student -> teacher projection) + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + teacher_indices = [idx[1] for idx in indices] + student_indices = [idx[0] for idx in indices] + + # Create sparse tensor with student tokens as rows, teacher tokens as columns + # This creates a student -> teacher projection matrix + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values) + + transformation_matrix_sparse = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (tokenizer_student_total_vocab_size, tokenizer_teacher_total_vocab_size), + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.bfloat16 + ) + + # indices, values = torch.topk(transformation_matrix_sparse, k=1000, dim=1) + + print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}") + print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}") + + if 0: + # cant fit to the memory + # Calculate mapping statistics from sparse matrix + print("\nCalculating mapping statistics from projection matrix...") + + # Count non-zero elements in each row (each row = student token) + dense_matrix = transformation_matrix_sparse.to_dense() + non_zero_counts_per_row = (dense_matrix != 0).sum(dim=1) # Count non-zeros per row + + # Create statistics + mapping_stats = defaultdict(int) + for count in non_zero_counts_per_row: + mapping_stats[count.item()] += 1 + + # Print mapping statistics + print("\nMapping statistics (student tokens -> teacher tokens):") + for i in range(1, 5): # 1, 2, 3, 4 teacher tokens + count = mapping_stats.get(i, 0) + print(f"Student tokens mapping to {i} teacher tokens: {count}") + + total_mapped = sum(mapping_stats.values()) + print(f"Total student tokens mapped: {total_mapped}") + + # Convert sparse matrix to same format as minimal_projection_generator.py + os.makedirs(args.output_dir, exist_ok=True) + + # Convert defaultdict to regular dict for saving + transformation_counts_dict = dict(transformation_counts) + + # Show some examples of the projection mappings + debug_projection_map(transformation_counts_dict, tokenizer_student, tokenizer_teacher, + direction="student->teacher", N=1000) + + # exit() + + print(f"\nConverting sparse matrix to top-{TOP_K} dense format...") + + # Convert sparse matrix to dense and get top-k values per row + print("Converting to dense matrix on CPU to avoid memory issues...") + dense_matrix = transformation_matrix_sparse.cpu().to_dense() # Move to CPU to handle memory + print(f"Dense matrix shape: {dense_matrix.shape}") + + # Get top-k values and indices for each row (each source token) + print(f"Extracting top-{TOP_K} values per token...") + + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + dense_matrix = sinkhorn_one_dim(dense_matrix, n_iters=1) + + # Extract top-k on CPU + top_k_likelihoods, top_k_indices = torch.topk(dense_matrix, k=min(TOP_K, dense_matrix.shape[1]), dim=1) + # exit() + # Handle case where vocabulary has fewer tokens than TOP_K + actual_k = top_k_indices.shape[1] + if actual_k < TOP_K: + print(f"Warning: Target vocabulary size ({dense_matrix.shape[1]}) is smaller than TOP_K ({TOP_K}). Using k={actual_k}") + # Pad with -1 indices and 0.0 likelihoods to maintain consistent shape + pad_size = TOP_K - actual_k + top_k_indices = torch.cat([top_k_indices, torch.full((top_k_indices.shape[0], pad_size), -1, dtype=top_k_indices.dtype)], dim=1) + top_k_likelihoods = torch.cat([top_k_likelihoods, torch.zeros((top_k_likelihoods.shape[0], pad_size), dtype=top_k_likelihoods.dtype)], dim=1) + + if 0: + threshold_mask = top_k_likelihoods >= 0.0000000000000000001 + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # Apply SCALE_TRICK: set last column to -4 if enabled + if ENABLE_SCALE_TRICK: + print("ENABLE_SCALE_TRICK is True: Setting last column of likelihoods to -4.0") + top_k_likelihoods[:, -1] = 0.2 + if ENABLE_EXACT_MATCH: + for indices in match_indices_student: + top_k_likelihoods[indices, -1] = 0.0 + print(f"Set last column of likelihoods to 0.0 for {len(match_indices_student)} exact matches as exact match is enabled") + # Apply sinkhorn normalization on CPU + if 1: + print("Applying Sinkhorn normalization on CPU...") + top_k_likelihoods = sinkhorn_one_dim(top_k_likelihoods, n_iters=1) + + # pdb.set_trace() + #set indices to -1 where likelihood is 0 + + # Create filename in same format as minimal_projection_generator.py + def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + + output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" + # if USE_RAW_TOKENS: + # output_filename += "_raw_tokens" + if ENABLE_SPECIAL_TOKEN_MAPPING: + output_filename += "_special" + output_filename += ".pt" + # if ENABLE_REVERSE_PASS: + # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + output_path = os.path.join(args.output_dir, output_filename) + + # Save in same format as minimal_projection_generator.py + torch.save({ + "indices": top_k_indices, + "likelihoods": top_k_likelihoods, + "model_A_id": student_model_name, # source model (student) + "model_B_id": teacher_model_name, # target model (teacher) + }, output_path) + + print(f"Saved projection map to: {output_path}") + print(f"Format: indices shape {top_k_indices.shape}, likelihoods shape {top_k_likelihoods.shape}") + print(f"Compatible with minimal_projection_generator.py format") + print(f"Token processing mode: {'Raw tokens (convert_ids_to_tokens)' if USE_RAW_TOKENS else 'Decoded tokens (decode)'}") + if ENABLE_REVERSE_PASS: + print("File includes bidirectional mappings (teacher->student and student->teacher)") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"File includes special token mappings (similarity_threshold={SPECIAL_TOKEN_SIMILARITY_THRESHOLD}, top_k={SPECIAL_TOKEN_TOP_K})") + # exit() + + + + + # Test projection function compatibility (same as minimal_projection_generator.py) + print("\n--- Testing projection function compatibility ---") + + def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + + # Create a dummy likelihood tensor: [BATCH, SEQ, source_vocab_size] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_tensor = torch.randn(1, 4096, tokenizer_student_total_vocab_size, device=device, dtype=torch.bfloat16) + + # Transform this tensor using the projection map + projected_tensor = project_token_likelihoods( + dummy_tensor, + top_k_indices.to(device), + top_k_likelihoods.to(device), + tokenizer_teacher_total_vocab_size, + device + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful - format is fully compatible!") + + # pdb.set_trace() + \ No newline at end of file diff --git a/nemo_rl/utils/x_token/reapply_exact_map.py b/nemo_rl/utils/x_token/reapply_exact_map.py new file mode 100644 index 0000000000..b37f994466 --- /dev/null +++ b/nemo_rl/utils/x_token/reapply_exact_map.py @@ -0,0 +1,246 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + USE_CANONICALIZATION = args.use_canonicalization + + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # load intial projection map + initial_projection_path = args.initial_projection_path + if initial_projection_path is not None: + initial_projection_map = torch.load(initial_projection_path) + else: + initial_projection_map = None + + # go through token in projection map. For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token + non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) + all_student_token_ids = list(range(len(initial_projection_map["likelihoods"]))) + + show_remapping = 5 + if show_remapping > 0: + print(f"Showing remapping for the last {show_remapping} exact matches.") + else: + print(f"Not showing remapping.") + + for i, exact_token_student in enumerate(match_indices_student): + exact_token_teacher = match_indices_teacher[i] + + index_ = all_student_token_ids.index(exact_token_student) + likelihoods = initial_projection_map["likelihoods"][index_] + indices = initial_projection_map["indices"][index_] + + if len(match_indices_student) - i <= show_remapping: + print(f"prior to remapping: likelihoods {likelihoods} indices {indices}") + + topk = indices.shape[0] + + remapped_indices = torch.ones_like(indices) * -1 + remapped_likelihoods = torch.zeros_like(likelihoods) + + remapped_likelihoods[0] = 1.0 + remapped_indices[0] = exact_token_teacher + + # if exact_token_student == 5159: + # import pdb + # pdb.set_trace() + + initial_projection_map["likelihoods"][index_] = remapped_likelihoods + initial_projection_map["indices"][index_] = remapped_indices + + + if len(match_indices_student) - i <= show_remapping: + print(f'after remapping {tokens_student[exact_token_student]}:{exact_token_student} -> {tokens_teacher[exact_token_teacher]}:{exact_token_teacher}: likelihoods {initial_projection_map["likelihoods"][index_]} indices {initial_projection_map["indices"][index_]}') + non_exact_map_tokens.remove(index_) + + + # import pdb + # pdb.set_trace() + # print(f"non exact map tokens: {non_exact_map_tokens}") + # pdb.set_trace() + save_path = args.initial_projection_path.split(".")[0] + "_exact_map_remapped.pt" + torch.save(initial_projection_map, save_path) + print(f"Saved remapped projection map to: {save_path}") + print(f"remapped {len(match_indices_student)} tokens. Retained remaining {len(non_exact_map_tokens)} tokens as is.") \ No newline at end of file diff --git a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py new file mode 100644 index 0000000000..caba5b2738 --- /dev/null +++ b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py @@ -0,0 +1,451 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import os +import argparse +import tqdm +from transformers import AutoTokenizer, AutoConfig +import pdb; + +def sinkhorn_one_dim(A, n_iters=1): + """Apply Sinkhorn normalization to make each row sum to 1.""" + for _ in range(n_iters): + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + return A + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + +def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): + """ + Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. + + Args: + input_path: Path to input projection matrix file + output_path: Path to save the new projection matrix + new_top_k: New top_k value for cutoff + preserve_last: If True, always preserve the last column as the final element + verbose: Whether to print progress information + """ + if verbose: + print(f"Loading projection matrix from: {input_path}") + + # Load the projection matrix + projection_data = torch.load(input_path, map_location='cpu', weights_only=False) + + if not isinstance(projection_data, dict) or 'indices' not in projection_data or 'likelihoods' not in projection_data: + raise ValueError("Input file must contain a dictionary with 'indices' and 'likelihoods' keys") + + original_indices = projection_data['indices'] # Shape: [vocab_size, original_top_k] + original_likelihoods = projection_data['likelihoods'] # Shape: [vocab_size, original_top_k] + + vocab_size, original_top_k = original_indices.shape + + if verbose: + print(f"Original matrix shape: {original_indices.shape}") + print(f"Original top_k: {original_top_k}") + print(f"New top_k: {new_top_k}") + print(f"Preserve last column: {preserve_last}") + # pdb.set_trace() + + if new_top_k > original_top_k: + print(f"Warning: New top_k ({new_top_k}) is larger than original top_k ({original_top_k})") + print(f"Will pad with -1 indices and 0.0 likelihoods") + effective_top_k = original_top_k + else: + effective_top_k = new_top_k + + # Initialize new tensors + new_indices = torch.full((vocab_size, new_top_k), -1, dtype=original_indices.dtype) + new_likelihoods = torch.zeros((vocab_size, new_top_k), dtype=original_likelihoods.dtype) + + # Statistics tracking + rows_with_order_change = 0 + significant_components_count = [0] * min(new_top_k, 10) # Track up to 10 components + threshold_for_significance = 0.2 # Threshold for considering a component "significant" + # Track position of maximum element in original ordering + max_element_positions = {} # position -> count + # Track preserve_last statistics + rows_with_preserved_last = 0 + # Track specifically when max element is in the last column + rows_with_max_in_last_column = 0 + # Track position of maximum element in final sorted and trimmed matrix + final_max_element_positions = {} # position -> count + + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + + if verbose: + print("Sorting and cutting each row...") + + # Process each row (each source token) + last_element_trick_count = 0 + for row_idx in tqdm.tqdm(range(vocab_size), desc="Processing rows", disable=not verbose): + row_indices = original_indices[row_idx] # [original_top_k] + row_likelihoods = original_likelihoods[row_idx] # [original_top_k] + + # Filter out invalid indices (-1) and zero likelihoods + valid_mask = (row_indices != -1) & (row_likelihoods > 0) + + if valid_mask.any(): + valid_indices = row_indices[valid_mask] + valid_likelihoods = row_likelihoods[valid_mask] + + # Track position of maximum element in original ordering + max_pos = torch.argmax(valid_likelihoods).item() + if max_pos not in max_element_positions: + max_element_positions[max_pos] = 0 + max_element_positions[max_pos] += 1 + + # Check if max element is specifically in the last column + # Find the actual maximum value in the original row (including invalid entries) + original_max_pos = torch.argmax(row_likelihoods).item() + if original_max_pos == original_top_k - 1: + # Only count if the last position actually has valid data + last_index = row_indices[original_top_k - 1] + last_likelihood = row_likelihoods[original_top_k - 1] + if last_index != -1 and last_likelihood > 0: + rows_with_max_in_last_column += 1 + + if preserve_last and new_top_k >= 1: + # Handle preserve_last case + last_index = original_indices[row_idx, original_top_k - 1] + last_likelihood = original_likelihoods[row_idx, original_top_k - 1] + + if new_top_k == 1: + # Special case: only keep the last element + if last_index != -1 and last_likelihood > 0: + new_indices[row_idx, 0] = last_index + new_likelihoods[row_idx, 0] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant components + if last_likelihood >= threshold_for_significance: + significant_components_count[0] += 1 + else: + # General case: sort first (original_top_k-1) elements, then add last element + elements_to_sort = min(len(valid_likelihoods), original_top_k - 1) + if elements_to_sort > 0: + # Get elements excluding the last position in original matrix + sort_mask = torch.arange(len(valid_likelihoods)) < elements_to_sort + if sort_mask.any(): + sortable_indices = valid_indices[sort_mask] + sortable_likelihoods = valid_likelihoods[sort_mask] + + # Sort the non-last elements + sorted_likelihoods, sort_order = torch.sort(sortable_likelihoods, descending=True) + sorted_indices = sortable_indices[sort_order] + + # Check if order changed in the sortable portion + original_order = torch.arange(len(sortable_likelihoods)) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + # Take top (new_top_k - 1) elements from sorted portion + num_from_sorted = min(len(sorted_indices), new_top_k - 1) + + new_indices[row_idx, :num_from_sorted] = sorted_indices[:num_from_sorted] + new_likelihoods[row_idx, :num_from_sorted] = sorted_likelihoods[:num_from_sorted] + + # Count significant components from sorted portion + for comp_idx in range(min(num_from_sorted, len(significant_components_count) - 1)): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + + # Always put the last element at the end (if valid) + + if last_index != -1 and last_likelihood > 0: + last_element_trick_count += 1 + new_indices[row_idx, new_top_k - 1] = last_index + new_likelihoods[row_idx, new_top_k - 1] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant component for the preserved last element + if new_top_k - 1 < len(significant_components_count): + if last_likelihood >= threshold_for_significance: + significant_components_count[new_top_k - 1] += 1 + + else: + # Original logic: sort all elements normally + # Check if order changed by comparing original vs sorted order + original_order = torch.arange(len(valid_likelihoods)) + sorted_likelihoods, sort_order = torch.sort(valid_likelihoods, descending=True) + + # Check if the order changed (not just sorted, but actually different) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + sorted_indices = valid_indices[sort_order] + + # Take top effective_top_k elements + num_to_take = min(len(sorted_indices), effective_top_k) + + new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take] + new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take] + # pdb.set_trace() + # Count significant components (components above threshold) + for comp_idx in range(min(num_to_take, len(significant_components_count))): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + # if significant_components_count[1] > 0.0: + # pdb.set_trace() + + # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 + + # Apply Sinkhorn normalization to the final matrix + print(f"last element trick count: {last_element_trick_count}") + if verbose: + print("Applying Sinkhorn normalization...") + + # Apply normalization only to non-zero values to preserve sparsity structure + normalized_likelihoods = sinkhorn_one_dim(new_likelihoods.clone(), n_iters=1) + + # Calculate final maximum element position statistics after sorting and normalization + if verbose: + print("Calculating final maximum element position statistics...") + + for row_idx in range(vocab_size): + row_likelihoods = normalized_likelihoods[row_idx] + # Filter out zero likelihoods + valid_mask = row_likelihoods > 0 + if valid_mask.any(): + valid_likelihoods = row_likelihoods[valid_mask] + # Find position of maximum element in the final matrix + max_pos_in_valid = torch.argmax(valid_likelihoods).item() + # Convert back to original position in the row + valid_positions = torch.nonzero(valid_mask).squeeze(-1) + actual_max_pos = valid_positions[max_pos_in_valid].item() + + if actual_max_pos not in final_max_element_positions: + final_max_element_positions[actual_max_pos] = 0 + final_max_element_positions[actual_max_pos] += 1 + + # Create output dictionary with same format as input + output_data = { + 'indices': new_indices, + 'likelihoods': normalized_likelihoods, + } + + # Copy over any additional metadata + for key in projection_data: + if key not in ['indices', 'likelihoods']: + output_data[key] = projection_data[key] + + # Save the new projection matrix + torch.save(output_data, output_path) + + if verbose: + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() + for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + +def main(): + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument("input_path", help="Path to input projection matrix file") + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") + parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + # python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last + #s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices + args = parser.parse_args() + + # Auto-generate output path if not specified + if args.output_path is None: + input_dir = os.path.dirname(args.input_path) + input_filename = os.path.basename(args.input_path) + + # Extract base name and extension + base_name, ext = os.path.splitext(input_filename) + + # Remove old top_k info if present + import re + base_name = re.sub(r'_top_\d+', '', base_name) + + # Add new top_k info and preserve_last flag + suffix = "_sorted" + if args.preserve_last: + suffix += "_preservelast" + output_filename = f"{base_name}_top_{args.top_k}{suffix}{ext}" + args.output_path = os.path.join(input_dir, output_filename) + + # Ensure output directory exists + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Process the matrix + sort_and_cut_projection_matrix( + args.input_path, + args.output_path, + args.top_k, + preserve_last=args.preserve_last, + verbose=not args.quiet + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/unit/data/test_data_processor.py b/tests/unit/data/test_data_processor.py index 4e20788aaa..fe77d7926d 100644 --- a/tests/unit/data/test_data_processor.py +++ b/tests/unit/data/test_data_processor.py @@ -39,6 +39,7 @@ from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec from nemo_rl.data.processors import ( helpsteer3_data_processor, + kd_data_processor, math_data_processor, math_hf_data_processor, ) @@ -60,10 +61,19 @@ def apply_chat_template( content += "assistant:" return content - def __call__(self, text, return_tensors=None, add_special_tokens=False): + def __call__( + self, + text, + return_tensors=None, + add_special_tokens=False, + max_length=None, + truncation=False, + ): if isinstance(text, list): text = "".join(text) encoded = list(range(len(text))) + if truncation and max_length is not None: + encoded = encoded[:max_length] if return_tensors == "pt": return {"input_ids": torch.tensor([encoded], dtype=torch.long)} return {"input_ids": encoded} @@ -348,3 +358,98 @@ def test_helpsteer3_data_processor(): # Length equals sum of token lengths assert out["length"] == sum(int(m["token_ids"].numel()) for m in msg_log) + + +def test_kd_data_processor(): + datum_dict = { + "messages": [ + {"role": "assistant", "content": "Hello world"}, + {"role": "assistant", "content": "Goodbye"}, + ], + } + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=128, + idx=7, + ) + + expected_raw_text = "Hello world\nGoodbye" + assert out["extra_env_info"]["raw_text"] == expected_raw_text + assert out["loss_multiplier"] == 1.0 + assert out["idx"] == 7 + + # Single assistant-role entry; raw text is mirrored back into the message. + assert len(out["message_log"]) == 1 + msg = out["message_log"][0] + assert msg["role"] == "assistant" + assert msg["content"] == expected_raw_text + + # KD invariant: loss applies to every token. + assert isinstance(msg["token_ids"], torch.Tensor) + assert isinstance(msg["token_loss_mask"], torch.Tensor) + assert msg["token_loss_mask"].shape == msg["token_ids"].shape + assert torch.all(msg["token_loss_mask"] == 1) + + # Reported length matches the tokenized output. + assert out["length"] == int(msg["token_ids"].numel()) + + +def test_kd_data_processor_truncates_to_max_seq_length(): + long_text = "x" * 200 + datum_dict = {"messages": [{"role": "assistant", "content": long_text}]} + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=64, + idx=0, + ) + + msg = out["message_log"][0] + assert int(msg["token_ids"].numel()) == 64 + assert out["length"] == 64 + # Truncated length equals the cap (not strictly greater) — loss stays on. + assert out["loss_multiplier"] == 1.0 + + +def test_kd_data_processor_skips_non_string_content(): + datum_dict = { + "messages": [ + {"role": "assistant", "content": "kept"}, + {"role": "assistant", "content": None}, + {"role": "assistant"}, # no content key + {"role": "assistant", "content": "also kept"}, + ], + } + tokenizer = DummyTokenizer() + task_spec = TaskDataSpec( + task_name="arrow_text_dataset", + prompt_file=None, + system_prompt_file=None, + ) + + out = kd_data_processor( + datum_dict=datum_dict, + task_data_spec=task_spec, + tokenizer=tokenizer, + max_seq_length=128, + idx=0, + ) + + assert out["extra_env_info"]["raw_text"] == "kept\nalso kept"