From 51a01d9ecedb157d326ce68a834b3d6f9f92f0ef Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Mon, 11 May 2026 10:44:37 -0700 Subject: [PATCH 1/6] feat(xtoken): add cross-tokenizer off-policy distillation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a single-teacher cross-tokenizer distillation pipeline running on DTensor V2 with FSDP2. The student and teacher use different tokenizers; a learned projection matrix maps student-vocab probabilities into the teacher's vocab so chunk-averaged P-KL can be computed across the alignment between the two token streams. New entrypoints - examples/run_xtoken_distillation.py and examples/configs/xtoken_distillation.yaml exemplar. - nemo_rl/algorithms/xtoken_distillation.py training loop. New data path - nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py: ArrowTextDataset for raw arrow shards. - nemo_rl/data/processors.py: kd_data_processor (raw-text path). - nemo_rl/data/cross_tokenizer_collate.py: collator that tokenizes the same source twice and runs TokenAligner.align inside DataLoader workers. New algorithm package - nemo_rl/algorithms/x_token/{__init__.py, tokenalign.py}: TokenAligner with projection-matrix loader and chunk alignment. Loss - nemo_rl/algorithms/loss/loss_functions.py: CrossTokenizerDistillationLossFn implements chunk-mean P-KL with PT-matched dynamic loss scaling. Sparse projection wrapped in _Fp32SparseMM (custom autograd.Function with amp custom_fwd/custom_bwd) to bypass the missing CUDA BF16 addmm_sparse_cuda kernel on both forward and backward. _load_M filters -1 sentinels and sizes V_t to the teacher's full vocab. Per-step accuracy / projection-accuracy diagnostics and an env-gated per-microbatch loss dump (NRL_XTOKEN_LOSS_DUMP_DIR) for PT-vs-NRL parity comparison. Worker / dispatcher plumbing - nemo_rl/models/automodel/data.py: check_sequence_dim(skip_keys=โ€ฆ) so CT auxiliary tensors that don't follow [B, student_seq, โ€ฆ] can ride on the same BatchedDataDict without tripping the seq-dim guard. - nemo_rl/models/automodel/train.py: register GlobalTopkLogitsPostProcessor in forward_with_post_processing_fn and add XTokenTeacherIPC{Loss,Export}PostProcessor subclasses. - nemo_rl/models/policy/lm_policy.py + workers/dtensor_policy_worker_v2.py: plumb skip_keys through Policy.train and the worker. Signed-off-by: Adithya Hanasoge --- examples/configs/xtoken_distillation.yaml | 176 +++ examples/run_xtoken_distillation.py | 130 +++ nemo_rl/algorithms/loss/loss_functions.py | 519 ++++++++ nemo_rl/algorithms/x_token/__init__.py | 17 + nemo_rl/algorithms/x_token/tokenalign.py | 1039 +++++++++++++++++ nemo_rl/algorithms/xtoken_distillation.py | 714 +++++++++++ nemo_rl/data/cross_tokenizer_collate.py | 166 +++ .../datasets/response_datasets/__init__.py | 3 + .../response_datasets/arrow_text_dataset.py | 101 ++ nemo_rl/data/processors.py | 30 + nemo_rl/models/automodel/data.py | 17 +- nemo_rl/models/automodel/train.py | 93 +- nemo_rl/models/policy/lm_policy.py | 70 +- .../workers/dtensor_policy_worker_v2.py | 106 +- 14 files changed, 3172 insertions(+), 9 deletions(-) create mode 100644 examples/configs/xtoken_distillation.yaml create mode 100644 examples/run_xtoken_distillation.py create mode 100644 nemo_rl/algorithms/x_token/__init__.py create mode 100644 nemo_rl/algorithms/x_token/tokenalign.py create mode 100644 nemo_rl/algorithms/xtoken_distillation.py create mode 100644 nemo_rl/data/cross_tokenizer_collate.py create mode 100644 nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py diff --git a/examples/configs/xtoken_distillation.yaml b/examples/configs/xtoken_distillation.yaml new file mode 100644 index 0000000000..d6a74f0b44 --- /dev/null +++ b/examples/configs/xtoken_distillation.yaml @@ -0,0 +1,176 @@ +# Single-teacher cross-tokenizer off-policy distillation. +# +# Defaults: Llama-3.2-1B (student) <- Qwen3-4B (teacher), arrow-text corpus, +# P-KL loss mode (gold_loss=false, xtoken_loss=false). Override the model +# names, tokenizers, and arrow_files via Hydra CLI for other pairs. + +distillation: + num_prompts_per_step: 64 + max_num_steps: 5000 + max_num_epochs: 1 + topk_logits_k: 64 # must equal loss_fn.vocab_topk + seed: 42 + val_period: 0 # validation disabled by default + val_at_start: false + val_at_end: false + +loss_fn: + projection_matrix_path: "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_genai/users/avenkateshha/nemo_rl/RL/cross_tokenizer_data/llama_qwen_best_special_exact_map_remapped.pt" + # Loss-mode selection (mutually exclusive): + # gold_loss=false, xtoken_loss=false -> P-KL (full-vocab projection KL) + # gold_loss=true, xtoken_loss=false -> gold-loss (NotImplementedError in v0) + # gold_loss=false, xtoken_loss=true -> xtoken-loss (NotImplementedError in v0) + gold_loss: false + xtoken_loss: false + temperature: 1.0 + vocab_topk: 64 + reverse_kl: false + exact_token_match_only: false + project_teacher_to_student: false + kl_loss_weight: 1.0 + ce_loss_scale: 0.1 + dynamic_loss_scaling: true + +checkpointing: + enabled: false + checkpoint_dir: "checkpoints/xtoken-distillation" + metric_name: "train:loss" + higher_is_better: false + keep_top_k: 3 + save_period: 600 + checkpoint_must_save_by: null + save_optimizer: true + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: "meta-llama/Llama-3.2-1B" + chat_template: null + train_global_batch_size: 64 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + make_sequence_length_divisible_by: 1 + precision: "bfloat16" + logprob_chunk_size: null + offload_optimizer_for_logprob: true + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + max_grad_norm: 1.0 + dynamic_batching: + enabled: false + train_mb_tokens: 2048 + logprob_mb_tokens: 2048 + sequence_length_round: 64 + sequence_packing: + enabled: false + train_mb_tokens: 2048 + logprob_mb_tokens: 2048 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.001 + betas: [0.9, 0.95] + eps: 1e-8 + foreach: false + fused: false + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 250 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 4750 + eta_min: 0.0 + - milestones: [250] + generation: null + +teacher: + model_name: "Qwen/Qwen3-4B" + tokenizer: + name: "Qwen/Qwen3-4B" + chat_template: null + train_global_batch_size: 64 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 2048 + make_sequence_length_divisible_by: 1 + precision: "bfloat16" + logprob_chunk_size: null + offload_optimizer_for_logprob: false + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + max_grad_norm: 1.0 + dynamic_batching: + enabled: false + train_mb_tokens: 2048 + logprob_mb_tokens: 2048 + sequence_length_round: 64 + sequence_packing: + enabled: false + train_mb_tokens: 2048 + logprob_mb_tokens: 2048 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.001 + betas: [0.9, 0.95] + eps: 1e-8 + foreach: false + fused: false + generation: null + +data: + max_input_seq_length: 2048 + shuffle: true + num_workers: 4 + train: + dataset_name: "arrow_text" + processor: "kd_data_processor" + arrow_files: null # set via Hydra CLI, e.g. data.train.arrow_files=/path/glob/*.arrow + text_key: "text" + characters_per_sample: 16384 + validation: null + +logger: + log_dir: "logs/xtoken-distillation" + num_val_samples_to_print: 0 + wandb_enabled: false + swanlab_enabled: false + mlflow_enabled: false + tensorboard_enabled: false + monitor_gpus: true + wandb: + project: "x_token" + name: "xtoken-distillation" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/run_xtoken_distillation.py b/examples/run_xtoken_distillation.py new file mode 100644 index 0000000000..c3432a4c54 --- /dev/null +++ b/examples/run_xtoken_distillation.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Single-teacher cross-tokenizer off-policy distillation entrypoint.""" + +from __future__ import annotations + +import argparse +import os +import pprint + +from omegaconf import OmegaConf + +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.algorithms.xtoken_distillation import ( + MasterConfig, + setup, + xtoken_distillation_train, +) +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import ( + load_config, + parse_hydra_overrides, + register_omegaconf_resolvers, +) +from nemo_rl.utils.logger import get_next_experiment_dir + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse CLI args; unknown args become Hydra overrides.""" + parser = argparse.ArgumentParser( + description="Run single-teacher cross-tokenizer off-policy distillation" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def main() -> None: + """Main entry point.""" + register_omegaconf_resolvers() + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "xtoken_distillation.yaml" + ) + + config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + if config["checkpointing"]["enabled"]: + print( + f"๐Ÿ“Š Using checkpoint directory: " + f"{config['checkpointing']['checkpoint_dir']}", + flush=True, + ) + + init_ray() + + # Two tokenizers โ€” one each for student and teacher. + student_tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + teacher_tokenizer = get_tokenizer(config["teacher"]["tokenizer"]) + + # Load arrow_text dataset directly (no env / no rollout path). + train_data = load_response_dataset(config["data"]["train"]) + train_dataset = AllTaskProcessedDataset( + train_data.dataset, + student_tokenizer, + train_data.task_spec, + train_data.processor, + max_seq_length=config["data"]["max_input_seq_length"], + ) + val_dataset = None + if config["data"].get("validation") is not None: + val_data = load_response_dataset(config["data"]["validation"]) + val_dataset = AllTaskProcessedDataset( + val_data.dataset, + student_tokenizer, + val_data.task_spec, + val_data.processor, + max_seq_length=config["data"]["max_input_seq_length"], + ) + + ( + student_policy, + teacher_policy, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + ) = setup(config, student_tokenizer, teacher_tokenizer, train_dataset, val_dataset) + + xtoken_distillation_train( + student_policy, + teacher_policy, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index df6ff6bc54..a7def7ee13 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch @@ -1035,3 +1036,521 @@ def __call__( } return kl_loss, metrics + + +# ===================================================================== +# Cross-tokenizer distillation +# ===================================================================== + + +class CrossTokenizerDistillationLossConfig(TypedDict): + """Config for cross-tokenizer distillation loss. + + Attributes: + projection_matrix_path: Filesystem path to the .pt file containing + either the dense top-k projection (dict with 'indices' and + 'likelihoods' tensors of shape [V_student, top_k]) or the sparse + multi-token format (dict[(student_id, teacher_id)] -> count). + Loaded lazily on first call by each worker process. + gold_loss: If True, switch to gold-loss formulation (1-1 exact-match + partition uses CE; rest uses ULD). v0 stub: raises + NotImplementedError. + xtoken_loss: If True, switch to x-token (multi-token chunk) + formulation. v0 stub: raises NotImplementedError. + temperature: Softmax temperature applied symmetrically to student + and teacher logits before KL. + vocab_topk: Top-k size used for teacher logits. Should equal + distillation.topk_logits_k. + reverse_kl: If True, compute KL(student || teacher) instead of + KL(teacher || student). + exact_token_match_only: If True, only aligned pairs flagged as + 'is_correct' contribute to KL; mismatched pairs are masked out. + project_teacher_to_student: If True, project the teacher distribution + into student vocab via M.T instead of projecting student into + teacher vocab via M. + kl_loss_weight: Scalar multiplier on the KL term. + ce_loss_scale: Scalar multiplier on the CE term. + dynamic_loss_scaling: If True, rescale KL each step so its detached + magnitude matches CE. + """ + + projection_matrix_path: str + gold_loss: bool + xtoken_loss: bool + temperature: float + vocab_topk: int + reverse_kl: bool + exact_token_match_only: bool + project_teacher_to_student: bool + kl_loss_weight: float + ce_loss_scale: float + dynamic_loss_scaling: bool + teacher_vocab_size: int + + +class CrossTokenizerDistillationLossDataDict(TypedDict): + input_ids: torch.Tensor + input_lengths: torch.Tensor + token_mask: torch.Tensor + sample_mask: torch.Tensor + # Per-sample global top-k teacher logits (same vocab columns at every + # teacher position) so chunk-averaged KL has a stable vocab axis. + teacher_topk_logits: torch.Tensor # [B, T_t, k] + teacher_topk_indices: torch.Tensor # [B, k] in teacher vocab + alignment_student_spans: torch.Tensor # [B, max_pairs, 2] + alignment_teacher_spans: torch.Tensor # [B, max_pairs, 2] + alignment_pair_valid: torch.Tensor # [B, max_pairs] + alignment_pair_is_correct: torch.Tensor # [B, max_pairs] + alignment_student_exact_partition_mask: torch.Tensor + alignment_teacher_exact_partition_mask: torch.Tensor + alignment_student_chunk_id: torch.Tensor # [B, T_s], -1 = no chunk + alignment_teacher_chunk_id: torch.Tensor # [B, T_t] + alignment_num_chunks: torch.Tensor + + +class _Fp32SparseMM(torch.autograd.Function): + """FP32 ``M.t() @ dense`` (sparse-dense matmul) ignoring surrounding autocast. + + ``addmm_sparse_cuda`` has no BF16 kernel on either forward or backward. + The worker wraps forward + loss + backward in ``autocast(BF16)``, so a + plain ``with autocast(enabled=False):`` around the forward call is not + enough โ€” ``loss.backward()`` runs inside the outer autocast and the + sparse-mm backward kernel is still dispatched as BF16. The + ``custom_fwd(cast_inputs=torch.float32)`` / ``custom_bwd`` decorators + are PyTorch's official escape: they force FP32 inputs on forward and + run the backward as if autocast were disabled. + + Math matches PT reference ``project_token_likelihoods_ultra_fast``: + autograd's builtin sparse-mm backward computes the same + ``M @ grad_out``. The gradient w.r.t. the sparse argument isn't + needed (the projection matrix is frozen), so it's returned as ``None``. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward( + ctx: Any, sparse_M: torch.Tensor, dense: torch.Tensor + ) -> torch.Tensor: + ctx.sparse_M = sparse_M + return torch.sparse.mm(sparse_M.t(), dense) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx: Any, grad_out: torch.Tensor + ) -> tuple[None, torch.Tensor]: + sparse_M = ctx.sparse_M + # out = sparse_M.t() @ dense, so d/d_dense = sparse_M @ grad_out. + grad_dense = torch.sparse.mm(sparse_M, grad_out) + return None, grad_dense + + +class CrossTokenizerDistillationLossFn(LossFunction): + """Cross-tokenizer distillation loss with three configurable modes. + + Mode is selected by ``(gold_loss, xtoken_loss)`` flags: + + - ``(False, False)`` -> P-KL: full-vocab projection KL using projection + matrix M. Implemented in v0. + - ``(True, False)`` -> gold-loss: exact-match partition uses CE on + paired tokens; non-partition uses ULD. **NotImplementedError in v0.** + - ``(False, True)`` -> xtoken-loss: chunk-aggregated KL using + multi-token spans. **NotImplementedError in v0.** + + Inputs (via ``LossInputType.LOGIT``): + logits: ``[B, T_s, V_s]`` raw student logits from the worker forward. + + Inputs (via ``data: BatchedDataDict``): + See :class:`CrossTokenizerDistillationLossDataDict`. + + Returns: + ``(loss, metrics)`` where ``metrics`` contains ``loss``, ``kl_loss``, + ``ce_loss``, ``kl_loss_scale``, ``num_valid_samples``, + ``num_valid_pairs``. + """ + + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGIT + + def __init__(self, cfg: CrossTokenizerDistillationLossConfig): + if cfg["gold_loss"] and cfg["xtoken_loss"]: + raise ValueError( + "gold_loss and xtoken_loss are mutually exclusive; set at " + "most one to True." + ) + self.cfg = cfg + self.projection_matrix_path = cfg["projection_matrix_path"] + # Lazy projection-matrix caches; populated on the first call inside + # each worker process. Keyed by device because the worker may run on + # multiple CUDA devices over its lifetime (rare but possible). + self._M_per_device: dict[torch.device, torch.Tensor] = {} + # Optional per-microbatch loss dump for PT-vs-NRL parity comparison. + # Activated by setting NRL_XTOKEN_LOSS_DUMP_DIR. Each rank appends a + # record per call to {dir}/rank{R}.pt. Records are raw floats from + # the loss-compute site, no scaling/aggregation โ€” matches the dump + # protocol in feedback_sanity_loss_dump. + self._loss_dump_dir = os.environ.get("NRL_XTOKEN_LOSS_DUMP_DIR") + self._loss_dump_records: list[dict[str, Any]] = [] + self._loss_dump_call_idx = 0 + + def _load_M(self, device: torch.device) -> torch.Tensor: + """Load and cache the sparse projection matrix on ``device``. + + File format detection is delegated to :class:`TokenAligner` โ€” + importing the loader directly here would couple the loss to the + aligner; instead we re-implement the small loader. The tokenizers + are not needed since we only require the matrix tensor. + """ + if device in self._M_per_device: + return self._M_per_device[device] + + if not os.path.exists(self.projection_matrix_path): + raise FileNotFoundError( + f"Projection matrix file not found: {self.projection_matrix_path}" + ) + data = torch.load( + self.projection_matrix_path, map_location="cpu", weights_only=False + ) + if isinstance(data, dict) and "indices" in data and "likelihoods" in data: + top_indices = data["indices"].long() + top_likelihoods = data["likelihoods"].float() + v_student, top_k = top_indices.shape + student_idx = ( + torch.arange(v_student).unsqueeze(1).expand(-1, top_k).reshape(-1) + ) + teacher_idx = top_indices.reshape(-1) + values = top_likelihoods.reshape(-1) + # `_exact_map_remapped` projection files use -1 as a padding + # sentinel for student rows that have fewer than top_k teacher + # mappings. A negative column index is illegal in a sparse + # tensor and causes CUDA illegal-memory-access in sparse.mm + # (forward and backward). PT's tokenalign clamps to col 0 and + # zeros the value; we drop those entries entirely (COO can + # carry a variable nnz, no need to keep them). + valid_mask = teacher_idx >= 0 + student_idx = student_idx[valid_mask] + teacher_idx = teacher_idx[valid_mask] + values = values[valid_mask] + # Use the teacher's full vocab size as V_t โ€” not max(teacher_idx)+1. + # GlobalTopkLogitsPostProcessor picks top-k over the teacher's + # full vocab, including ids the projection doesn't cover. Sizing + # projected_full to the full teacher vocab makes those columns + # all-zero (correct semantics: unmapped teacher tokens get zero + # projected probability) and keeps the gather in bounds. + projection_max_teacher = int(teacher_idx.max().item()) + 1 + v_teacher = max(self.cfg["teacher_vocab_size"], projection_max_teacher) + indices = torch.stack([student_idx, teacher_idx], dim=0) + shape = (v_student, v_teacher) + elif isinstance(data, dict) and all( + isinstance(k, tuple) and len(k) == 2 for k in data.keys() + ): + keys = list(data.keys()) + values_list = list(data.values()) + student_idx = torch.tensor([k[0] for k in keys], dtype=torch.long) + teacher_idx = torch.tensor([k[1] for k in keys], dtype=torch.long) + indices = torch.stack([student_idx, teacher_idx], dim=0) + values = torch.tensor(values_list, dtype=torch.float32) + v_student = int(student_idx.max().item()) + 1 + v_teacher = int(teacher_idx.max().item()) + 1 + shape = (v_student, v_teacher) + else: + raise ValueError( + f"Unrecognized projection matrix format at " + f"{self.projection_matrix_path}" + ) + + # Coalesced COO. CSR was tried and PyTorch's beta CSR/CSC backward + # path on CUDA hits unrelated illegal-memory-access errors in this + # torch version. COO sparse-dense mm has a stable FP32 backward + # kernel (the original error here was specifically that BF16 was + # missing, which `_Fp32SparseMM.custom_fwd/custom_bwd` now handles). + sparse = torch.sparse_coo_tensor( + indices, values, shape, device=device, dtype=torch.float32 + ).coalesce() + self._M_per_device[device] = sparse + return sparse + + def __call__( + self, + data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + logits: torch.Tensor, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute the cross-tokenizer distillation loss for one microbatch.""" + cfg = self.cfg + + if cfg["gold_loss"]: + raise NotImplementedError( + "gold_loss mode is not implemented in v0. The exact-match " + "partition CE + ULD math from the PT reference still needs " + "to be ported. Run with gold_loss=false in the meantime." + ) + if cfg["xtoken_loss"]: + raise NotImplementedError( + "xtoken_loss mode is not implemented in v0. The chunk-" + "aggregated multi-token KL from the PT reference still " + "needs to be ported. Run with xtoken_loss=false in the " + "meantime." + ) + if cfg["project_teacher_to_student"]: + raise NotImplementedError( + "project_teacher_to_student=True is not implemented in v0. " + "It would invert the projection direction (teacher distribution " + "projected into student vocab via M.T) and isn't on the " + "smoke-test path." + ) + + kl_loss, num_valid_pairs, proj_acc = self._compute_p_kl(logits, data) + ce_loss = self._compute_ce(logits, data, global_valid_toks) + + # Next-token accuracy on the student side, masked to valid tokens. + # Mirrors PT reference at train_distillation_ddp.py:1956 โ€” gives a + # quick per-step signal that's directly comparable to PT's `Acc:` + # log column. + with torch.no_grad(): + student_argmax = logits[:, :-1].argmax(dim=-1) + shift_labels = data["input_ids"][:, 1:] + acc_mask = ( + data["token_mask"][:, 1:].float() + * data["sample_mask"].unsqueeze(-1).float() + ) + denom = acc_mask.sum().clamp(min=1.0) + accuracy = ( + ((student_argmax == shift_labels).float() * acc_mask).sum() + / denom + ) + + if cfg["dynamic_loss_scaling"]: + # Match PT reference exactly (train_distillation_ddp.py:1745-1747): + # dls_scale = ce_loss.item() / kl_loss.item() + # loss = kl_loss * dls_scale + ce_loss + # User-supplied `kl_loss_weight` / `ce_loss_scale` are + # intentionally ignored in this branch โ€” PT does the same. + kl_detached = kl_loss.detach().abs() + ce_detached = ce_loss.detach().abs() + kl_scale = torch.where( + kl_detached > 0, + ce_detached / kl_detached, + torch.ones_like(kl_detached), + ) + loss = kl_scale * kl_loss + ce_loss + else: + kl_scale = torch.tensor( + 1.0, device=kl_loss.device, dtype=kl_loss.dtype + ) + loss = ( + cfg["kl_loss_weight"] * kl_loss + + cfg["ce_loss_scale"] * ce_loss + ) + + metrics = { + "loss": loss.item(), + "kl_loss": kl_loss.item(), + "ce_loss": ce_loss.item(), + "kl_loss_scale": kl_scale.item(), + "accuracy": accuracy.item(), + "proj_accuracy": proj_acc.item(), + "num_valid_samples": data["input_ids"].shape[0], + "num_valid_pairs": int(num_valid_pairs.item()), + } + self._maybe_dump_loss(metrics) + return loss, metrics + + def _maybe_dump_loss(self, metrics: dict[str, Any]) -> None: + """Append per-call raw loss values to a per-rank dump file. + + Activated by ``NRL_XTOKEN_LOSS_DUMP_DIR``. One file per rank, + rewritten on each call with the full record list. Records are raw + ``loss.item()`` values from the loss-compute site โ€” not scaled, + aggregated, or DP-summed โ€” matching the dump protocol used for + PT-vs-NRL parity comparisons (cf. ``feedback_sanity_loss_dump``). + """ + if not self._loss_dump_dir: + return + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) + self._loss_dump_records.append( + { + "call_idx": self._loss_dump_call_idx, + "loss": metrics["loss"], + "kl_loss": metrics["kl_loss"], + "ce_loss": metrics["ce_loss"], + "kl_loss_scale": metrics["kl_loss_scale"], + "num_valid_pairs": metrics["num_valid_pairs"], + } + ) + self._loss_dump_call_idx += 1 + os.makedirs(self._loss_dump_dir, exist_ok=True) + torch.save( + self._loss_dump_records, + os.path.join(self._loss_dump_dir, f"rank{rank}.pt"), + ) + + # ------------------------------------------------------------------ # + # Loss-mode implementations + # ------------------------------------------------------------------ # + def _compute_p_kl( + self, + logits: torch.Tensor, + data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], + ) -> tuple[torch.Tensor, torch.Tensor]: + """P-KL: chunk-averaged KL over the projected teacher-vocab subset. + + Mirrors the PT reference ``compute_KL_loss_optimized`` non-exact-match + branch: chunk-averages student-projected probs over each aligned + student span, chunk-averages teacher log-probs over the paired + teacher span, and KLs the resulting chunk distributions. + + Steps: + + 1. Compute student log-probs at ``T``, exponentiate to probs. + 2. Project full-vocab student probs through ``M`` to teacher vocab. + 3. Slice projection to the per-sample global top-k teacher columns + (carried in ``teacher_topk_indices [B, k]``). + 4. Build per-token chunk masks (one-hot from ``chunk_id``) for both + sides, then ``bmm`` to chunk-sum and divide by chunk size. + 5. Renormalize student chunk distributions inside the top-k subset + (PT convention: avg-then-renormalize, log). + 6. Compute teacher chunk log-probs by chunk-averaging + ``log_softmax(teacher_topk_logits / T)`` directly (same as PT). + 7. Forward (or reverse) KL between chunk distributions. + """ + cfg = self.cfg + T = cfg["temperature"] + device = logits.device + eps = 1e-10 + + b, t_s, v_s = logits.shape + student_log_probs = torch.log_softmax(logits.float() / T, dim=-1) + student_probs = student_log_probs.exp() # [B, T_s, V_s] + + # Project to full teacher vocab. Sparse matmul via M.T trick. + # `_Fp32SparseMM` keeps the op in FP32 on both forward and backward; + # `torch.sparse.mm` has no BF16 kernel and the worker's autocast(BF16) + # context wraps loss.backward(), so a plain `.float()` cast isn't + # enough โ€” the backward kernel is still dispatched as BF16. + M = self._load_M(device) # [V_s, V_t] sparse CSR, fp32 + flat = student_probs.reshape(b * t_s, v_s) + # _Fp32SparseMM internally computes M.t() @ dense; passing M (not + # M.t()) avoids a sparse `.t()` on a saved tensor in backward. + projected_full = _Fp32SparseMM.apply(M, flat.t()).t() # [B*T_s, V_t] + v_t = projected_full.shape[-1] + projected_full = projected_full.reshape(b, t_s, v_t) # [B, T_s, V_t] + + # Per-sample slice to global top-k teacher columns. + teacher_topk_indices = data["teacher_topk_indices"] # [B, k] + teacher_topk_logits = data["teacher_topk_logits"].float() # [B, T_t, k] + _, k = teacher_topk_indices.shape + t_t = teacher_topk_logits.shape[1] + idx_for_proj = teacher_topk_indices.unsqueeze(1).expand(-1, t_s, -1) + projected_topk = torch.gather( + projected_full, dim=-1, index=idx_for_proj + ) # [B, T_s, k] + + # Teacher target log-probs over the top-k subset. PT renormalizes + # softmax over only the kept columns; we follow the same convention. + target_log_probs = torch.log_softmax( + teacher_topk_logits / T, dim=-1 + ) # [B, T_t, k] + + # Build chunk masks via one-hot from the chunk_id tensors. -1 + # entries (no chunk) compare false everywhere and stay out. + student_chunk_id = data["alignment_student_chunk_id"] # [B, T_s] long + teacher_chunk_id = data["alignment_teacher_chunk_id"] # [B, T_t] long + pair_valid = data["alignment_pair_valid"] # [B, max_pairs] + if cfg["exact_token_match_only"]: + pair_valid = pair_valid & data["alignment_pair_is_correct"] + max_chunks = pair_valid.shape[1] + chunk_arange = torch.arange(max_chunks, device=device).view(1, 1, -1) + proj_mask = student_chunk_id.unsqueeze(-1) == chunk_arange # [B, T_s, C] + tgt_mask = teacher_chunk_id.unsqueeze(-1) == chunk_arange # [B, T_t, C] + + # Chunk-aggregate via bmm: sum over positions in each chunk. + proj_mask_f = proj_mask.transpose(1, 2).to(projected_topk.dtype) + tgt_mask_f = tgt_mask.transpose(1, 2).to(target_log_probs.dtype) + proj_chunks = torch.bmm(proj_mask_f, projected_topk) # [B, C, k] + tgt_log_chunks = torch.bmm(tgt_mask_f, target_log_probs) # [B, C, k] + + proj_sizes = proj_mask.sum(dim=1).float() # [B, C] + tgt_sizes = tgt_mask.sum(dim=1).float() # [B, C] + proj_chunks = proj_chunks / (proj_sizes.unsqueeze(-1) + eps) + tgt_log_chunks = tgt_log_chunks / (tgt_sizes.unsqueeze(-1) + eps) + + # PT: renormalize projected chunk distribution within the top-k + # subset, then take log. Teacher side is already log-probs (avg of + # log_softmaxes is what PT computes; not a true log of mean). + proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + eps) + proj_log_chunks = (proj_chunks + eps).log() + + chunk_mask = ( + (proj_sizes > 0) & (tgt_sizes > 0) & pair_valid + ) # [B, C] + if not chunk_mask.any(): + zero = torch.zeros((), device=device, dtype=proj_log_chunks.dtype) + return ( + zero, + torch.zeros((), device=device, dtype=torch.long), + zero.detach(), + ) + + # Projection top-1 accuracy: per-chunk argmax of the student-side + # projected distribution vs the teacher's argmax over the same + # top-k subset. Mirrors PT reference at + # tokenalign.py:4097-4104 โ€” gives a KD-specific accuracy signal. + with torch.no_grad(): + proj_top1 = proj_chunks.argmax(dim=-1) # [B, C] + tgt_top1 = torch.exp(tgt_log_chunks).argmax(dim=-1) # [B, C] + proj_matches = (proj_top1 == tgt_top1) & chunk_mask + proj_acc = proj_matches.sum().float() / chunk_mask.sum().float().clamp( + min=1.0 + ) + + # KL between chunk-averaged distributions. + if cfg["reverse_kl"]: + # KL(student || teacher) + per_chunk_kl = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True + ).sum(dim=-1) + else: + # Forward KL(teacher || student) + per_chunk_kl = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True + ).sum(dim=-1) + + sample_mask = data["sample_mask"].to(per_chunk_kl.dtype) # [B] + valid = chunk_mask.to(per_chunk_kl.dtype) * sample_mask.unsqueeze(-1) + denom = valid.sum().clamp(min=1.0) + kl_loss = (per_chunk_kl * valid).sum() / denom * (T * T) + return kl_loss, valid.sum().detach(), proj_acc.detach() + + def _compute_ce( + self, + logits: torch.Tensor, + data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], + global_valid_toks: torch.Tensor, + ) -> torch.Tensor: + """Standard next-token CE on the student side. + + Uses ``token_mask[:, 1:]`` so padded tokens don't contribute. + """ + input_ids = data["input_ids"] + token_mask = data["token_mask"][:, 1:] + sample_mask = data["sample_mask"] + + shift_logits = logits[:, :-1].contiguous() + shift_labels = input_ids[:, 1:].contiguous() + + per_token_ce = torch.nn.functional.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]).float(), + shift_labels.reshape(-1), + reduction="none", + ).reshape(shift_labels.shape) + + mask = token_mask.float() * sample_mask.unsqueeze(-1).float() + return masked_mean( + per_token_ce, mask, global_normalization_factor=global_valid_toks + ) diff --git a/nemo_rl/algorithms/x_token/__init__.py b/nemo_rl/algorithms/x_token/__init__.py new file mode 100644 index 0000000000..620414c843 --- /dev/null +++ b/nemo_rl/algorithms/x_token/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.algorithms.x_token.tokenalign import AlignmentBatch, TokenAligner + +__all__ = ["AlignmentBatch", "TokenAligner"] diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py new file mode 100644 index 0000000000..7ce1813561 --- /dev/null +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -0,0 +1,1039 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Cross-tokenizer token alignment. + +Ports the alignment algorithm from the PyTorch tokenalign reference +(``train_distillation_ddp.py`` companion ``tokenalign.py``). The DP kernel, +canonicalization helpers, anchor optimization, and post-processing are kept +faithful so loss-side parity with that reference is preserved. Anything +unrelated to running cross-tokenizer alignment for off-policy distillation +(rule tracking, learnable projection, MSE loss, multiple compute_loss +variants, accuracy, translation) is dropped. + +Public surface: + - :class:`AlignmentBatch` โ€” dense-padded per-batch alignment payload that + covers all three loss modes (P-KL, gold_loss, xtoken_loss). + - :class:`TokenAligner` โ€” owns the two tokenizers and the projection + matrix, exposes :meth:`align` for the collator. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, List, Tuple + +import numpy as np +import torch + +# Visual byte representations used by some BPE tokenizers (especially for +# emojis / non-ASCII bytes). Copied from the PT reference verbatim โ€” these +# constants are content-coupled to the tokenizers we align across. +VISUAL_BYTE_MAP = { + "รฐ": 240, "ฦ–": 241, "ฦ—": 242, "ฦ˜": 243, "ฦ™": 244, "ฦš": 245, "ฦ›": 246, "ฦœ": 247, + "ฦ": 248, "ฦž": 249, "ฦŸ": 250, "ฦ ": 251, "ฦก": 252, "ฦข": 253, "ฦฃ": 254, "ฦค": 255, + "ล": 156, "ล‚": 157, "ลƒ": 158, "ล„": 159, "ฤบ": 149, "ฤป": 150, "ฤผ": 151, "ฤฝ": 152, + "ฤพ": 153, "ฤฟ": 154, "ล€": 155, "ฤฌ": 135, "ฤญ": 136, "ฤฎ": 137, "ฤฏ": 138, "ฤฐ": 139, + "ฤฑ": 140, "ฤฒ": 141, "ฤณ": 142, "ฤด": 143, "ฤต": 144, "ฤถ": 145, "ฤท": 146, "ฤธ": 147, + "ฤน": 148, "ฤฅ": 128, "ฤฆ": 129, "ฤง": 130, "ฤจ": 131, "ฤฉ": 132, "ฤช": 133, "ฤซ": 134, + "ฤข": 162, "ฤฃ": 163, "ฤœ": 28, "ฤ": 29, "ฤž": 30, "ฤŸ": 31, +} + +# Multi-token encoding artifacts (mojibake patterns) where the broken byte +# sequence spans tokens. Patterns are checked left-to-right with the first +# match wins. From PT reference; trimmed to the high-frequency entries. +_MULTI_TOKEN_ARTIFACT_FIXES = [ + (["ฤ รขฤช", "ฤณ"], ["ฤ โˆ‘"]), (["รขฤช", "ฤณ"], ["โˆ‘"]), + (["ฤ รขฤช", "ฤฑ"], ["ฤ โˆ"]), (["รขฤช", "ฤฑ"], ["โˆ"]), + (["ฤ รขฤช", "ฤค"], ["ฤ โˆ‚"]), (["รขฤช", "ฤค"], ["โˆ‚"]), + (["ฤ รขฤช", "ฤฉ"], ["ฤ โˆ‡"]), (["รขฤช", "ฤฉ"], ["โˆ‡"]), + (["ฤ รขฤช", "ล€"], ["ฤ โˆž"]), (["รขฤช", "ล€"], ["โˆž"]), + (["ฤ รขฤช", "ฤผ"], ["ฤ โˆš"]), (["รขฤช", "ฤผ"], ["โˆš"]), + (["ฤ รขฤช", "ยซ"], ["ฤ โˆซ"]), (["รขฤช", "ยซ"], ["โˆซ"]), + (["ฤ รขฤซ", "ล‚"], ["ฤ โ‰ "]), (["รขฤซ", "ล‚"], ["โ‰ "]), + (["ฤ รคยธ", "ลƒ"], ["ฤ ไธญ"]), (["รคยธ", "ลƒ"], ["ไธญ"]), + (["รฆฤธ", "ฤฉ"], ["ๆ–‡"]), (["ฤ รฆฤธ", "ฤฉ"], ["ฤ ๆ–‡"]), +] + +# Per-token canonicalizations applied after multi-token artifact fixes. +_UNICODE_FIXES = { + "รƒยฑ": "รฑ", "รƒยก": "รก", "รƒยฉ": "รฉ", "รƒยญ": "รญ", "รƒยณ": "รณ", "รƒยบ": "รบ", + "รƒ": "ร€", "รƒยข": "รข", "รƒยง": "รง", + "รƒยจ": "รจ", "รƒยซ": "รซ", "รƒยฎ": "รฎ", "รƒยด": "รด", + "รƒยน": "รน", "รƒยป": "รป", "รƒยฟ": "รฟ", + "รคยธลƒ": "ไธญ", "รฆฤธฤฉ": "ๆ–‡", "รฆฤนยฅรฆฤพยฌ": "ๆ—ฅๆœฌ", "รจยชล€": "่ชž", + "รล‚ร‘ฤฅร‘ฤฃ": "ะ ัƒั", "ร‘ฤฃรยบรยธรยน": "ัะบะธะน", + "ร˜ยงร™ฤฆร˜ยนร˜ยฑร˜ยจร™ฤฌร˜ยฉ": "ุงู„ุนุฑุจูŠุฉ", + "ร ยคยน": "เคน", "ร ยคยฟร ยคฤค": "เคนเคฟเค‚", "ร ยคยฆร ยฅฤข": "เคฆเฅ€", + "รขฤชฤณ": "โˆ‘", "รขฤชฤฑ": "โˆ", "รขฤชฤค": "โˆ‚", "รขฤชฤฉ": "โˆ‡", + "รขฤชล€": "โˆž", "รขฤชฤผ": "โˆš", "รขฤชยซ": "โˆซ", "รขฤซฤช": "โ‰ˆ", + "รขฤซล‚": "โ‰ ", "รขฤซยค": "โ‰ค", "รขฤซยฅ": "โ‰ฅ", +} + +_SPECIAL_TOKEN_MAP = { + "<|begin_of_text|>": "", + "": "", + "": "", +} + + +@dataclass +class AlignmentBatch: + """Per-batch alignment payload covering all three loss modes. + + The collator hands this dataclass directly to the loss fn alongside the + tokenized batch. Tensors are dense-padded to the batch maximum so DTensor + V2 can shard on dim 0 without knowing about cross-tokenizer specifics. + + Attributes: + student_spans: ``[B, max_pairs, 2]`` long tensor with + ``(start, end)`` indices into the student-tokenized sequence. + Empty side of an alignment pair (insertion/deletion) gets + ``(-1, -1)``. + teacher_spans: Like ``student_spans`` but for teacher positions. + pair_valid: ``[B, max_pairs]`` bool. False on padding entries. + pair_is_correct: ``[B, max_pairs]`` bool. True when canonicalized + student span text matches canonicalized teacher span text. + student_exact_partition_mask: ``[B, T_s]`` bool. True at student + tokens that sit on a 1-1 exact-match pair (gold_loss partition). + teacher_exact_partition_mask: ``[B, T_t]`` bool. Counterpart. + student_chunk_id: ``[B, T_s]`` long. Chunk index (= pair index) the + student token belongs to; ``-1`` if not in any chunk + (insertion-only pair on student side). + teacher_chunk_id: ``[B, T_t]`` long. Counterpart. + num_chunks: ``[B]`` long. Number of valid chunks in each sample. + """ + + student_spans: torch.Tensor + teacher_spans: torch.Tensor + pair_valid: torch.Tensor + pair_is_correct: torch.Tensor + student_exact_partition_mask: torch.Tensor + teacher_exact_partition_mask: torch.Tensor + student_chunk_id: torch.Tensor + teacher_chunk_id: torch.Tensor + num_chunks: torch.Tensor + + +class TokenAligner: + """Aligns student and teacher tokenizations of the same source text. + + The alignment algorithm is a Needleman-Wunsch DP over canonicalized token + strings, augmented with multi-token combination scoring (one student + token can match a span of teacher tokens and vice versa, up to + ``max_comb_len``) and anchor-based segmentation for long sequences. + + Construction loads the projection matrix from disk into memory. The + raw COO components (indices, values, shape) are stored picklable so the + aligner can be pickled into DataLoader worker processes; the loss fn is + expected to materialize the actual ``torch.sparse_coo_tensor`` on the + training device on first use. + + Args: + student_tokenizer: HF tokenizer for the student model. + teacher_tokenizer: HF tokenizer for the teacher model. + projection_matrix_path: Path to a ``.pt`` file holding either the + sparse multi-token format (dict[(student_id, teacher_id)] -> count) + or the dense top-k format (dict with ``indices`` and + ``likelihoods`` tensors, shape ``[V_student, top_k]``). + max_comb_len: Maximum span length considered when matching one token + on one side against multiple tokens on the other. + """ + + def __init__( + self, + student_tokenizer, + teacher_tokenizer, + projection_matrix_path: str, + max_comb_len: int = 4, + ): + self.student_tokenizer = student_tokenizer + self.teacher_tokenizer = teacher_tokenizer + self.max_combination_len = max_comb_len + self.projection_matrix_path = projection_matrix_path + + # Loaded lazily by load_projection_matrix(); the loss fn calls that + # explicitly on its training device. We store the raw COO components + # so this aligner remains picklable across DataLoader workers (no + # CUDA tensors, no nn.Parameter). + self._projection_indices: torch.Tensor | None = None + self._projection_values: torch.Tensor | None = None + self._projection_shape: Tuple[int, int] | None = None + + # ------------------------------------------------------------------ # + # Projection matrix loading + # ------------------------------------------------------------------ # + def load_projection_matrix( + self, device: torch.device | str = "cpu" + ) -> torch.Tensor: + """Materialize the projection matrix as a sparse COO tensor. + + Args: + device: Device to place the tensor on. + + Returns: + A ``torch.sparse_coo_tensor`` of shape + ``(V_student, V_teacher)``. + """ + if self._projection_indices is None: + self._load_projection_components() + assert self._projection_indices is not None + assert self._projection_values is not None + assert self._projection_shape is not None + sparse = torch.sparse_coo_tensor( + self._projection_indices, + self._projection_values, + self._projection_shape, + device=device, + dtype=torch.float32, + ).coalesce() + return sparse + + def _load_projection_components(self) -> None: + """Load and normalize the projection matrix into COO components.""" + if not os.path.exists(self.projection_matrix_path): + raise FileNotFoundError( + f"Projection matrix file not found: {self.projection_matrix_path}" + ) + data = torch.load( + self.projection_matrix_path, map_location="cpu", weights_only=False + ) + v_student = len(self.student_tokenizer) + v_teacher = len(self.teacher_tokenizer) + + if isinstance(data, dict) and "indices" in data and "likelihoods" in data: + # Dense top-k format: indices [V_student, top_k] holds teacher + # token ids; likelihoods [V_student, top_k] holds the projection + # weights. We unfold to COO so the loss fn can use a uniform + # sparse-matmul path regardless of file format. + top_indices: torch.Tensor = data["indices"].long() + top_likelihoods: torch.Tensor = data["likelihoods"].float() + assert top_indices.shape == top_likelihoods.shape, ( + f"indices/likelihoods shape mismatch: " + f"{top_indices.shape} vs {top_likelihoods.shape}" + ) + v_student_file, top_k = top_indices.shape + assert v_student_file == v_student, ( + f"projection rows ({v_student_file}) != student vocab " + f"({v_student}) for tokenizer " + f"{getattr(self.student_tokenizer, 'name_or_path', '?')}" + ) + student_idx = ( + torch.arange(v_student).unsqueeze(1).expand(-1, top_k).reshape(-1) + ) + teacher_idx = top_indices.reshape(-1) + values = top_likelihoods.reshape(-1) + # Drop entries that point at out-of-range teacher ids; the dense + # format pads with 0 in some checkpoints, so we keep only valid + # rows. + keep = teacher_idx < v_teacher + self._projection_indices = torch.stack( + [student_idx[keep], teacher_idx[keep]], dim=0 + ) + self._projection_values = values[keep] + self._projection_shape = (v_student, v_teacher) + elif isinstance(data, dict) and all( + isinstance(k, tuple) and len(k) == 2 for k in data.keys() + ): + # Sparse multi-token format: dict[(student_id, teacher_id)] -> count. + keys = list(data.keys()) + values = list(data.values()) + student_idx = torch.tensor([k[0] for k in keys], dtype=torch.long) + teacher_idx = torch.tensor([k[1] for k in keys], dtype=torch.long) + self._projection_indices = torch.stack([student_idx, teacher_idx], dim=0) + self._projection_values = torch.tensor(values, dtype=torch.float32) + self._projection_shape = (v_student, v_teacher) + else: + raise ValueError( + f"Unrecognized projection matrix format at " + f"{self.projection_matrix_path}; expected dict with " + f"'indices'/'likelihoods' tensors or " + f"dict[(student_id, teacher_id)] -> count." + ) + + # ------------------------------------------------------------------ # + # Public alignment API + # ------------------------------------------------------------------ # + def align( + self, + student_ids: torch.Tensor, + teacher_ids: torch.Tensor, + ) -> AlignmentBatch: + """Align a batch of student/teacher token id tensors. + + Args: + student_ids: ``[B, T_s]`` long tensor. + teacher_ids: ``[B, T_t]`` long tensor. + + Returns: + An :class:`AlignmentBatch` with all fields populated for the + three loss modes. + """ + assert student_ids.dim() == 2 and teacher_ids.dim() == 2 + assert student_ids.shape[0] == teacher_ids.shape[0], ( + f"student/teacher batch size mismatch: " + f"{student_ids.shape[0]} vs {teacher_ids.shape[0]}" + ) + b, t_s = student_ids.shape + _, t_t = teacher_ids.shape + + student_token_lists: List[List[str]] = [ + self.student_tokenizer.convert_ids_to_tokens(student_ids[i].tolist()) + for i in range(b) + ] + teacher_token_lists: List[List[str]] = [ + self.teacher_tokenizer.convert_ids_to_tokens(teacher_ids[i].tolist()) + for i in range(b) + ] + + per_sample_pairs: List[List[Tuple[Any, ...]]] = [] + for s_toks, t_toks in zip(student_token_lists, teacher_token_lists): + pairs = self._align_single(s_toks, t_toks) + per_sample_pairs.append(pairs) + + return self._pairs_to_batch(per_sample_pairs, b=b, t_s=t_s, t_t=t_t) + + @staticmethod + def _pairs_to_batch( + per_sample_pairs: List[List[Tuple[Any, ...]]], + *, + b: int, + t_s: int, + t_t: int, + ) -> AlignmentBatch: + """Pack per-sample alignment lists into dense-padded tensors.""" + max_pairs = max((len(p) for p in per_sample_pairs), default=0) + # Guarantee at least one slot so downstream tensor shapes stay sane. + max_pairs = max(max_pairs, 1) + + student_spans = torch.full((b, max_pairs, 2), -1, dtype=torch.long) + teacher_spans = torch.full((b, max_pairs, 2), -1, dtype=torch.long) + pair_valid = torch.zeros((b, max_pairs), dtype=torch.bool) + pair_is_correct = torch.zeros((b, max_pairs), dtype=torch.bool) + student_partition = torch.zeros((b, t_s), dtype=torch.bool) + teacher_partition = torch.zeros((b, t_t), dtype=torch.bool) + student_chunk_id = torch.full((b, t_s), -1, dtype=torch.long) + teacher_chunk_id = torch.full((b, t_t), -1, dtype=torch.long) + num_chunks = torch.zeros((b,), dtype=torch.long) + + for batch_i, pairs in enumerate(per_sample_pairs): + num_chunks[batch_i] = len(pairs) + for pair_i, pair in enumerate(pairs): + # Pair shape: (s_tokens, t_tokens, s_start, s_end, t_start, + # t_end, is_correct). + _, _, s_start, s_end, t_start, t_end, is_correct = pair + if s_start != -1 and s_end != -1: + student_spans[batch_i, pair_i, 0] = s_start + student_spans[batch_i, pair_i, 1] = s_end + if 0 <= s_start < t_s and 0 < s_end <= t_s: + student_chunk_id[batch_i, s_start:s_end] = pair_i + if t_start != -1 and t_end != -1: + teacher_spans[batch_i, pair_i, 0] = t_start + teacher_spans[batch_i, pair_i, 1] = t_end + if 0 <= t_start < t_t and 0 < t_end <= t_t: + teacher_chunk_id[batch_i, t_start:t_end] = pair_i + pair_valid[batch_i, pair_i] = True + pair_is_correct[batch_i, pair_i] = bool(is_correct) + # gold_loss partition: tokens on a 1-1 exact-match pair. + if ( + is_correct + and s_start != -1 + and t_start != -1 + and (s_end - s_start) == 1 + and (t_end - t_start) == 1 + ): + if 0 <= s_start < t_s: + student_partition[batch_i, s_start] = True + if 0 <= t_start < t_t: + teacher_partition[batch_i, t_start] = True + + return AlignmentBatch( + student_spans=student_spans, + teacher_spans=teacher_spans, + pair_valid=pair_valid, + pair_is_correct=pair_is_correct, + student_exact_partition_mask=student_partition, + teacher_exact_partition_mask=teacher_partition, + student_chunk_id=student_chunk_id, + teacher_chunk_id=teacher_chunk_id, + num_chunks=num_chunks, + ) + + # ------------------------------------------------------------------ # + # Per-sample alignment pipeline + # ------------------------------------------------------------------ # + def _align_single( + self, + seq1: List[str], + seq2: List[str], + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + anchor_lengths: Tuple[int, ...] = (3,), + ) -> List[Tuple[Any, ...]]: + """Run canonicalize -> anchor-DP -> post-process for one sample. + + Returns: + A list of pairs ``(s_tokens, t_tokens, s_start, s_end, t_start, + t_end, is_correct)``. Insertions/deletions use ``-1`` for the + empty side's start/end. + """ + seq1_canon = _canonicalize_sequence(seq1) + seq2_canon = _canonicalize_sequence(seq2) + + kwargs = dict( + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len, + ignore_leading_char_diff=False, + ) + + aligned, _ = self._align_with_anchors( + seq1_canon, seq2_canon, anchor_lengths=anchor_lengths, **kwargs + ) + aligned = _post_process_alignment( + aligned, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len, + ) + # Attach is_correct mask using canonicalized comparison so that + # ignore_leading_char_diff=True semantics are baked in upstream. + is_correct = _alignment_mask(aligned) + return [(s_t, t_t, s0, s1, t0, t1, m) for (s_t, t_t, s0, s1, t0, t1), m in zip(aligned, is_correct)] + + # ------------------------------------------------------------------ # + # Anchor-based segmentation + # ------------------------------------------------------------------ # + def _align_with_anchors( + self, + seq1: List[str], + seq2: List[str], + anchor_lengths: Tuple[int, ...] = (3,), + **kwargs, + ) -> Tuple[List[Tuple[Any, ...]], float]: + """Optimize long alignments by pinning unique n-gram matches as anchors. + + Falls back to plain DP when no anchors exist or when + ``anchor_lengths`` is empty. + """ + if not anchor_lengths: + return _align_dp(seq1, seq2, **kwargs) + + # Find unique n-gram matches in both sequences. + all_potential_anchors: List[Tuple[int, int, int]] = [] + for anchor_len in anchor_lengths: + if anchor_len == 1: + s1_counts: dict[str, List[int]] = {} + s2_counts: dict[str, List[int]] = {} + for i, t in enumerate(seq1): + s1_counts.setdefault(t, []).append(i) + for j, t in enumerate(seq2): + s2_counts.setdefault(t, []).append(j) + for token in s1_counts.keys() & s2_counts.keys(): + if len(s1_counts[token]) == 1 and len(s2_counts[token]) == 1: + all_potential_anchors.append( + (s1_counts[token][0], s2_counts[token][0], 1) + ) + else: + s1_ngrams: dict[Tuple[str, ...], List[int]] = {} + s2_ngrams: dict[Tuple[str, ...], List[int]] = {} + for i in range(len(seq1) - anchor_len + 1): + s1_ngrams.setdefault(tuple(seq1[i : i + anchor_len]), []).append(i) + for j in range(len(seq2) - anchor_len + 1): + s2_ngrams.setdefault(tuple(seq2[j : j + anchor_len]), []).append(j) + for ngram in s1_ngrams.keys() & s2_ngrams.keys(): + if len(s1_ngrams[ngram]) == 1 and len(s2_ngrams[ngram]) == 1: + i = s1_ngrams[ngram][0] + j = s2_ngrams[ngram][0] + if ( + i + anchor_len <= len(seq1) + and j + anchor_len <= len(seq2) + and seq1[i : i + anchor_len] == seq2[j : j + anchor_len] + ): + all_potential_anchors.append((i, j, anchor_len)) + + # Greedy non-conflicting selection, preferring longer anchors. + all_potential_anchors.sort(key=lambda x: (-x[2], x[0], x[1])) + used_seq1: set[int] = set() + used_seq2: set[int] = set() + selected: List[Tuple[int, int, int]] = [] + for i, j, k in all_potential_anchors: + r1 = set(range(i, i + k)) + r2 = set(range(j, j + k)) + if not (r1 & used_seq1) and not (r2 & used_seq2): + selected.append((i, j, k)) + used_seq1.update(r1) + used_seq2.update(r2) + selected.sort() + + # Validate monotonic ordering. + validated: List[Tuple[int, int, int]] = [] + last_j = -1 + for i, j, k in selected: + if j > last_j and seq1[i : i + k] == seq2[j : j + k]: + validated.append((i, j, k)) + last_j = j + k - 1 + + if not validated: + return _align_dp(seq1, seq2, **kwargs) + + full_alignment: List[Tuple[Any, ...]] = [] + last_i, last_j = 0, 0 + for i, j, k in validated: + seg1, seg2 = seq1[last_i:i], seq2[last_j:j] + if seg1 or seg2: + aligned_segment, _ = _align_dp(seg1, seg2, **kwargs) + full_alignment.extend( + _shift_pairs(aligned_segment, last_i, last_j) + ) + # Anchor itself splits to 1-1 matches. + for kk in range(k): + full_alignment.append( + ( + [seq1[i + kk]], + [seq2[j + kk]], + i + kk, + i + kk + 1, + j + kk, + j + kk + 1, + ) + ) + last_i, last_j = i + k, j + k + + seg1, seg2 = seq1[last_i:], seq2[last_j:] + if seg1 or seg2: + aligned_segment, _ = _align_dp(seg1, seg2, **kwargs) + full_alignment.extend(_shift_pairs(aligned_segment, last_i, last_j)) + + return full_alignment, 0.0 + + +# ===================================================================== +# Module-level helpers (canonicalization, DP kernel, post-process). +# Kept module-level so they pickle cleanly into DataLoader worker processes. +# ===================================================================== + + +def _canonical_token(token: str) -> str: + """Return a canonical representation of a tokenizer token.""" + if not token: + return token + + # Normalize space prefixes. + if token.startswith(" "): + token = "ฤ " + token[1:] + elif token.startswith("_"): + token = "ฤ " + token[1:] + elif token.startswith("โ–"): + token = "ฤ " + token[1:] + + # Newline and whitespace normalization. + if token == "ฤŠ": + token = "\n" + elif token == "\\n": + token = "\n" + elif token == "ฤ‰": + token = "\n" + elif token == "ฤ \n": + token = "\n" + elif "ฤŠ" in token: + token = token.replace("ฤŠ", "\n") + elif "\\n" in token: + token = token.replace("\\n", "\n") + + if token == "ฤ ,": + token = "," + elif token == "ฤ .": + token = "." + elif token == "ฤ ;": + token = ";" + elif token == "ฤ :": + token = ":" + + # SentencePiece byte fallback like <0x20>. + if token.startswith("<0x") and token.endswith(">") and len(token) == 6: + try: + byte_val = int(token[3:5], 16) + if 0 <= byte_val <= 255: + return chr(byte_val) + except ValueError: + pass + + for broken, fixed in _UNICODE_FIXES.items(): + if broken in token: + token = token.replace(broken, fixed) + + if token in _SPECIAL_TOKEN_MAP: + return _SPECIAL_TOKEN_MAP[token] + + return token + + +def _canonicalize_sequence(seq: List[str]) -> List[str]: + """Canonicalize every token in a sequence, including byte-merging.""" + merged = _merge_encoding_artifacts(seq) + canon = [_canonical_token(t) for t in merged] + return _merge_consecutive_bytes(canon) + + +def _merge_encoding_artifacts(tokens: List[str]) -> List[str]: + """Merge known multi-token mojibake patterns into single tokens.""" + if not tokens: + return tokens + result: List[str] = [] + i = 0 + while i < len(tokens): + matched = False + for pattern, replacement in _MULTI_TOKEN_ARTIFACT_FIXES: + pl = len(pattern) + if i + pl <= len(tokens) and tokens[i : i + pl] == pattern: + result.extend(replacement) + i += pl + matched = True + break + if not matched: + result.append(tokens[i]) + i += 1 + return result + + +def _get_byte_value(token_char: str) -> int | None: + """Return the byte value (0..255) for a single character, or None.""" + if len(token_char) != 1: + return None + char_ord = ord(token_char) + if char_ord < 256: + return char_ord + return VISUAL_BYTE_MAP.get(token_char) + + +def _merge_consecutive_bytes(tokens: List[str]) -> List[str]: + """Merge consecutive byte-fallback tokens back into Unicode characters.""" + if not tokens: + return tokens + result: List[str] = [] + byte_buffer: List[str] = [] + for token in tokens: + clean = token.lstrip("ฤ ") + if not clean: + all_bytes = False + else: + all_bytes = all(_get_byte_value(c) is not None for c in clean) + if all_bytes: + byte_buffer.append(token) + else: + if byte_buffer: + result.extend(_try_merge_byte_buffer(byte_buffer)) + byte_buffer = [] + result.append(token) + if byte_buffer: + result.extend(_try_merge_byte_buffer(byte_buffer)) + return result + + +def _try_merge_byte_buffer(byte_tokens: List[str]) -> List[str]: + """Decode 2-4 buffered byte tokens as a single UTF-8 character.""" + if not byte_tokens: + return [] + if len(byte_tokens) == 1: + token = byte_tokens[0] + clean = token.lstrip("ฤ ") + if len(clean) <= 1: + return byte_tokens + + space_prefix = "ฤ " if byte_tokens[0].startswith("ฤ ") else "" + raw_bytes: List[int] = [] + for token in byte_tokens: + clean = token.lstrip("ฤ ") + for c in clean: + v = _get_byte_value(c) + if v is None: + return byte_tokens + raw_bytes.append(v) + + if len(raw_bytes) < 2 or len(raw_bytes) > 4: + return byte_tokens + try: + decoded = bytes(raw_bytes).decode("utf-8") + if len(decoded) == 1 and ord(decoded) > 127: + return [space_prefix + decoded] + return byte_tokens + except UnicodeDecodeError: + return byte_tokens + + +def _strings_equal_flexible(s1: str, s2: str, ignore_leading_char_diff: bool) -> bool: + """Compare two strings, optionally after canonicalization.""" + if not ignore_leading_char_diff: + return s1 == s2 + return _canonical_token(s1) == _canonical_token(s2) + + +def _align_dp( + seq1: List[str], + seq2: List[str], + *, + exact_match_score: float, + combination_score_multiplier: float, + gap_penalty: float, + max_combination_len: int, + ignore_leading_char_diff: bool, +) -> Tuple[List[Tuple[Any, ...]], float]: + """Needleman-Wunsch DP with up-to-``max_combination_len`` token spans.""" + n1, n2 = len(seq1), len(seq2) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.full((n1 + 1, n2 + 1), "", dtype=object) + + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = "up" + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = "left" + + joined_seq1 = { + (i - k, i): "".join(seq1[i - k : i]) + for i in range(n1 + 1) + for k in range(1, min(i, max_combination_len) + 1) + } + joined_seq2 = { + (j - k, j): "".join(seq2[j - k : j]) + for j in range(n2 + 1) + for k in range(1, min(j, max_combination_len) + 1) + } + + for i in range(1, n1 + 1): + for j in range(1, n2 + 1): + s1_val, s2_val = seq1[i - 1], seq2[j - 1] + match_score = ( + exact_match_score + if _strings_equal_flexible(s1_val, s2_val, ignore_leading_char_diff) + else -exact_match_score + ) + score_diag = dp[i - 1, j - 1] + match_score + score_up = dp[i - 1, j] + gap_penalty + score_left = dp[i, j - 1] + gap_penalty + + max_score = score_diag + best_move = "diag" + if score_up > max_score: + max_score = score_up + best_move = "up" + if score_left > max_score: + max_score = score_left + best_move = "left" + + for k in range(2, min(j + 1, max_combination_len + 1)): + key = (j - k, j) + if key in joined_seq2 and _strings_equal_flexible( + s1_val, joined_seq2[key], ignore_leading_char_diff + ): + cand = dp[i - 1, j - k] + combination_score_multiplier * k + if cand > max_score: + max_score = cand + best_move = f"comb_s1_over_s2_{k}" + + for k in range(2, min(i + 1, max_combination_len + 1)): + key = (i - k, i) + if key in joined_seq1 and _strings_equal_flexible( + s2_val, joined_seq1[key], ignore_leading_char_diff + ): + cand = dp[i - k, j - 1] + combination_score_multiplier * k + if cand > max_score: + max_score = cand + best_move = f"comb_s2_over_s1_{k}" + + dp[i, j] = max_score + trace[i, j] = best_move + + aligned: List[Tuple[Any, ...]] = [] + i, j = n1, n2 + while i > 0 or j > 0: + move = trace[i, j] + if move == "diag": + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == "up": + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == "left": + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif move.startswith("comb_s1_over_s2_"): + k = int(move.rsplit("_", 1)[-1]) + aligned.append(([seq1[i - 1]], seq2[j - k : j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif move.startswith("comb_s2_over_s1_"): + k = int(move.rsplit("_", 1)[-1]) + aligned.append((seq1[i - k : i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + aligned.reverse() + return aligned, float(dp[n1, n2]) + + +def _shift_pairs( + pairs: List[Tuple[Any, ...]], shift_s: int, shift_t: int +) -> List[Tuple[Any, ...]]: + """Offset start/end indices of pairs after segment-level alignment.""" + out = [] + for s_toks, t_toks, s_start, s_end, t_start, t_end in pairs: + ns = s_start + shift_s if s_start != -1 else -1 + ne = s_end + shift_s if s_end != -1 else -1 + nts = t_start + shift_t if t_start != -1 else -1 + nte = t_end + shift_t if t_end != -1 else -1 + # Split coarse same-token spans into 1-1 matches. + if ( + len(s_toks) > 1 + and len(s_toks) == len(t_toks) + and s_toks == t_toks + and ns >= 0 + and nts >= 0 + ): + for k in range(len(s_toks)): + out.append( + ( + [s_toks[k]], + [t_toks[k]], + ns + k, + ns + k + 1, + nts + k, + nts + k + 1, + ) + ) + else: + out.append((s_toks, t_toks, ns, ne, nts, nte)) + return out + + +def _alignment_mask(aligned_pairs: List[Tuple[Any, ...]]) -> List[bool]: + """Compute is_correct for each pair using canonicalized text comparison.""" + out: List[bool] = [] + for s_toks, t_toks, *_rest in aligned_pairs: + s_canon = "".join(_canonical_token(tk) for tk in s_toks) if s_toks else "" + t_canon = "".join(_canonical_token(tk) for tk in t_toks) if t_toks else "" + out.append(_strings_equal_flexible(s_canon, t_canon, ignore_leading_char_diff=False)) + return out + + +def _post_process_alignment( + aligned_pairs: List[Tuple[Any, ...]], + *, + exact_match_score: float, + combination_score_multiplier: float, + gap_penalty: float, + max_combination_len: int, + end_mismatch_threshold: float = 0.2, +) -> List[Tuple[Any, ...]]: + """Post-process: combine misaligned consecutive pairs and re-align bad spans. + + Mirrors ``post_process_alignment_optimized`` from the PT reference but + inlined as a module-level function so the aligner stays simple. + """ + if not aligned_pairs: + return [] + + # Step 1: combine consecutive misaligned pairs (away from sequence end). + pair_strings = _build_pair_strings(aligned_pairs) + aligned_pairs = _combine_consecutive_misaligned( + aligned_pairs, pair_strings, end_mismatch_threshold + ) + pair_strings = _build_pair_strings(aligned_pairs) + + # Step 2: split exact-token coarse alignments and re-align small bad spans. + processed: List[Tuple[Any, ...]] = [] + align_cache: dict[Tuple[Tuple[str, ...], Tuple[str, ...]], Tuple[List[Tuple[Any, ...]], bool]] = {} + i = 0 + while i < len(aligned_pairs): + s1_toks, s2_toks, *_rest = aligned_pairs[i] + if len(s1_toks) > 1 and len(s1_toks) == len(s2_toks) and s1_toks == s2_toks: + s1_start, s1_end, s2_start, s2_end = aligned_pairs[i][2:6] + for k in range(len(s1_toks)): + processed.append( + ( + [s1_toks[k]], + [s2_toks[k]], + s1_start + k, + s1_start + k + 1, + s2_start + k, + s2_start + k + 1, + ) + ) + i += 1 + continue + + bad_start = -1 + for j in range(i, len(aligned_pairs)): + if not pair_strings[j][2]: + bad_start = j + break + if bad_start == -1: + processed.extend(aligned_pairs[i:]) + break + processed.extend(aligned_pairs[i:bad_start]) + + found = False + max_chunk = min(10, len(aligned_pairs) - bad_start) + for chunk_size in range(2, max_chunk + 1): + chunk = aligned_pairs[bad_start : bad_start + chunk_size] + chunk_s1, chunk_s2, s1_idx, s2_idx = _flatten_chunk(chunk) + chunk_s1_str = "".join(_canonical_token(t) for t in chunk_s1) + chunk_s2_str = "".join(_canonical_token(t) for t in chunk_s2) + if not _strings_equal_flexible( + chunk_s1_str, chunk_s2_str, ignore_leading_char_diff=False + ): + continue + cache_key = (tuple(chunk_s1), tuple(chunk_s2)) + if cache_key in align_cache: + sub_pairs, perfect = align_cache[cache_key] + else: + sub_pairs, _ = _align_dp( + chunk_s1, + chunk_s2, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + ignore_leading_char_diff=False, + ) + perfect = all( + _strings_equal_flexible( + "".join(_canonical_token(t) for t in p[0]), + "".join(_canonical_token(t) for t in p[1]), + ignore_leading_char_diff=False, + ) + for p in sub_pairs + ) + align_cache[cache_key] = (sub_pairs, perfect) + + s1_chunk_start = min(s1_idx[::2]) if s1_idx else -1 + s2_chunk_start = min(s2_idx[::2]) if s2_idx else -1 + if perfect: + for s1_t, s2_t, ss, se, ts, te in sub_pairs: + ns = s1_chunk_start + ss if ss != -1 else -1 + ne = s1_chunk_start + se if se != -1 else -1 + nts = s2_chunk_start + ts if ts != -1 else -1 + nte = s2_chunk_start + te if te != -1 else -1 + processed.append((s1_t, s2_t, ns, ne, nts, nte)) + else: + s1_chunk_end = max(s1_idx[1::2]) if s1_idx else -1 + s2_chunk_end = max(s2_idx[1::2]) if s2_idx else -1 + processed.append( + ( + chunk_s1, + chunk_s2, + s1_chunk_start, + s1_chunk_end, + s2_chunk_start, + s2_chunk_end, + ) + ) + i = bad_start + chunk_size + found = True + break + if not found: + processed.append(aligned_pairs[bad_start]) + i = bad_start + 1 + return processed + + +def _build_pair_strings( + aligned_pairs: List[Tuple[Any, ...]], +) -> List[Tuple[str, str, bool]]: + """Precompute (s_str, t_str, is_match) for each pair.""" + out: List[Tuple[str, str, bool]] = [] + for s_toks, t_toks, *_rest in aligned_pairs: + s_canon = "".join(_canonical_token(t) for t in s_toks) if s_toks else "" + t_canon = "".join(_canonical_token(t) for t in t_toks) if t_toks else "" + is_match = _strings_equal_flexible(s_canon, t_canon, ignore_leading_char_diff=False) + out.append((s_canon, t_canon, is_match)) + return out + + +def _combine_consecutive_misaligned( + aligned_pairs: List[Tuple[Any, ...]], + pair_strings: List[Tuple[str, str, bool]], + end_mismatch_threshold: float, +) -> List[Tuple[Any, ...]]: + """Combine consecutive misaligned pairs into single multi-token chunks.""" + if not aligned_pairs or len(aligned_pairs) < 2: + return aligned_pairs + end_boundary = int(len(aligned_pairs) * (1 - end_mismatch_threshold)) + out: List[Tuple[Any, ...]] = [] + i = 0 + while i < len(aligned_pairs): + if ( + i < end_boundary + and not pair_strings[i][2] + and i + 1 < len(aligned_pairs) + ): + run = [i] + j = i + 1 + while ( + j < end_boundary + and j < len(aligned_pairs) + and not pair_strings[j][2] + ): + run.append(j) + j += 1 + if len(run) >= 2: + combined_s1: List[str] = [] + combined_s2: List[str] = [] + s1_idx: List[int] = [] + s2_idx: List[int] = [] + for idx in run: + s1_t, s2_t, s1s, s1e, s2s, s2e, *_rest = aligned_pairs[idx] + combined_s1.extend(s1_t) + combined_s2.extend(s2_t) + if s1_t and s1s != -1: + s1_idx.extend([s1s, s1e]) + if s2_t and s2s != -1: + s2_idx.extend([s2s, s2e]) + cs1_start = min(s1_idx[::2]) if s1_idx else -1 + cs1_end = max(s1_idx[1::2]) if s1_idx else -1 + cs2_start = min(s2_idx[::2]) if s2_idx else -1 + cs2_end = max(s2_idx[1::2]) if s2_idx else -1 + out.append( + (combined_s1, combined_s2, cs1_start, cs1_end, cs2_start, cs2_end) + ) + i = j + continue + out.append(aligned_pairs[i]) + i += 1 + return out + + +def _flatten_chunk( + chunk: List[Tuple[Any, ...]], +) -> Tuple[List[str], List[str], List[int], List[int]]: + """Concatenate tokens and collect span indices across a chunk of pairs.""" + chunk_s1: List[str] = [] + chunk_s2: List[str] = [] + s1_idx: List[int] = [] + s2_idx: List[int] = [] + for s1_t, s2_t, s1s, s1e, s2s, s2e, *_rest in chunk: + chunk_s1.extend(s1_t) + chunk_s2.extend(s2_t) + if s1_t: + s1_idx.extend([s1s, s1e]) + if s2_t: + s2_idx.extend([s2s, s2e]) + return chunk_s1, chunk_s2, s1_idx, s2_idx diff --git a/nemo_rl/algorithms/xtoken_distillation.py b/nemo_rl/algorithms/xtoken_distillation.py new file mode 100644 index 0000000000..281d00eb70 --- /dev/null +++ b/nemo_rl/algorithms/xtoken_distillation.py @@ -0,0 +1,714 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Single-teacher cross-tokenizer off-policy distillation. + +Training-loop layout mirrors ``run_distillation.py`` / +``nemo_rl/algorithms/distillation.py`` minus the on-policy bits (no env, no +rollout, no generation). Per step: + + 1. Pull a collated batch (student & teacher token ids + alignment). + 2. Run teacher forward via ``Policy.get_topk_logits`` on TEACHER token + ids โ€” gives top-k teacher logits at teacher positions. + 3. Pack alignment payload + teacher topk into a student-side + ``train_data`` dict. + 4. ``student_policy.train(train_data, loss_fn)`` โ€” student forward + + loss + backward + optimizer step happens inside the dtensor v2 + worker. + +The collator and aligner do all the CPU-side cross-tokenizer work; the +loss function does only loss math; this module is just plumbing. +""" + +from __future__ import annotations + +import os +from typing import Any, NotRequired, Optional, TypedDict, cast + +import numpy as np +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.loss.loss_functions import ( + CrossTokenizerDistillationLossConfig, + CrossTokenizerDistillationLossFn, +) +from nemo_rl.algorithms.utils import set_seed +from nemo_rl.algorithms.x_token import TokenAligner +from nemo_rl.data import DataConfig +from nemo_rl.data.cross_tokenizer_collate import CrossTokenizerCollator +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import Logger, LoggerConfig +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer + +# Keys packed into the student-side `train_data` BatchedDataDict whose dim 1 +# is NOT the student sequence axis. They ride along on the dict so the loss +# fn can index them per-microbatch, but the worker's `check_sequence_dim` +# pre-flight (which assumes [B, student_seq, ...] for every 2+D tensor) must +# skip them. Sources: +# - teacher_topk_*: produced by GlobalTopkLogitsPostProcessor in +# dtensor_policy_worker_v2.get_global_topk_logits. +# - teacher_input_ids/teacher_token_mask + alignment_*: produced by +# CrossTokenizerCollator (in DataLoader workers). +# alignment_student_chunk_id and alignment_student_exact_partition_mask are +# [B, T_s] and DO follow the student-seq invariant, so they are NOT listed. +XTOKEN_NON_STUDENT_SEQ_KEYS: frozenset[str] = frozenset( + { + "teacher_topk_logits", + "teacher_topk_indices", + "teacher_input_ids", + "teacher_token_mask", + "alignment_student_spans", + "alignment_teacher_spans", + "alignment_pair_valid", + "alignment_pair_is_correct", + "alignment_teacher_exact_partition_mask", + "alignment_teacher_chunk_id", + } +) + +# =============================================================================== +# Configuration +# =============================================================================== + + +class XTokenDistillationConfig(TypedDict): + """Top-level distillation algo config. + + Attributes: + num_prompts_per_step: Global batch size at the dataloader level. + max_num_steps: Max training steps before early stop. + max_num_epochs: Max passes over the training dataset. + topk_logits_k: ``k`` passed to ``teacher_policy.get_topk_logits``. + Should equal ``loss_fn.vocab_topk``. + seed: RNG seed. + val_period: Validation cadence in steps. ``0`` disables validation. + val_at_start: Run validation before training begins. + val_at_end: Run validation on the final step. + """ + + num_prompts_per_step: int + max_num_steps: int + max_num_epochs: int + topk_logits_k: int + seed: int + val_period: int + val_at_start: bool + val_at_end: bool + + +class XTokenDistillationSaveState(TypedDict): + current_epoch: int + current_step: int + total_steps: int + consumed_samples: int + total_valid_tokens: int + val_loss: NotRequired[float] + + +def _default_save_state() -> XTokenDistillationSaveState: + return { + "current_epoch": 0, + "current_step": 0, + "total_steps": 0, + "consumed_samples": 0, + "total_valid_tokens": 0, + } + + +class MasterConfig(TypedDict): + policy: PolicyConfig # student + teacher: PolicyConfig + loss_fn: CrossTokenizerDistillationLossConfig + data: DataConfig + distillation: XTokenDistillationConfig + logger: LoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + + +# =============================================================================== +# Setup +# =============================================================================== + + +def setup( + master_config: MasterConfig, + student_tokenizer: PreTrainedTokenizerBase, + teacher_tokenizer: PreTrainedTokenizerBase, + train_dataset: AllTaskProcessedDataset, + val_dataset: Optional[AllTaskProcessedDataset], +) -> tuple[ + Policy, # student + Policy, # teacher + StatefulDataLoader, + Optional[StatefulDataLoader], + CrossTokenizerDistillationLossFn, + Logger, + CheckpointManager, + XTokenDistillationSaveState, + MasterConfig, +]: + """Construct cluster, dataloaders, policies, and loss fn for the run.""" + policy_config = master_config["policy"] + teacher_config = master_config["teacher"] + loss_config = master_config["loss_fn"] + distill_config = master_config["distillation"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + # Parity check that catches misconfigured topk values early. + assert loss_config["vocab_topk"] == distill_config["topk_logits_k"], ( + f"loss_fn.vocab_topk ({loss_config['vocab_topk']}) must equal " + f"distillation.topk_logits_k ({distill_config['topk_logits_k']}) โ€” " + f"the teacher returns top-k in teacher vocab and the loss fn must " + f"use the same k." + ) + + # Backend gate: this code path is DTensor V2 only by design. + assert policy_config["dtensor_cfg"]["enabled"] and policy_config["dtensor_cfg"].get( + "_v2", False + ), "xtoken distillation requires policy.dtensor_cfg.enabled=true and _v2=true." + assert teacher_config["dtensor_cfg"]["enabled"] and teacher_config["dtensor_cfg"].get( + "_v2", False + ), "xtoken distillation requires teacher.dtensor_cfg.enabled=true and _v2=true." + + set_seed(distill_config["seed"]) + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + save_state: Optional[XTokenDistillationSaveState] = cast( + Optional[XTokenDistillationSaveState], + checkpointer.load_training_info(last_checkpoint_path), + ) + if save_state is None: + save_state = _default_save_state() + + # ========================== + # Aligner + Collator + # ========================== + print("\nโ–ถ Building token aligner and cross-tokenizer collator...", flush=True) + aligner = TokenAligner( + student_tokenizer=student_tokenizer, + teacher_tokenizer=teacher_tokenizer, + projection_matrix_path=loss_config["projection_matrix_path"], + ) + + collator = CrossTokenizerCollator( + student_tokenizer=student_tokenizer, + teacher_tokenizer=teacher_tokenizer, + aligner=aligner, + ctx_length_student=policy_config["max_total_sequence_length"], + ctx_length_teacher=teacher_config["max_total_sequence_length"], + make_seq_div_by_student=policy_config["make_sequence_length_divisible_by"], + make_seq_div_by_teacher=teacher_config["make_sequence_length_divisible_by"], + ) + + # ========================== + # Data + # ========================== + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=distill_config["num_prompts_per_step"], + shuffle=data_config["shuffle"], + collate_fn=collator, + drop_last=True, + num_workers=data_config["num_workers"], + ) + if last_checkpoint_path is not None: + dataloader_state = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + train_dataloader.load_state_dict(dataloader_state) + print( + f" โœ“ Training dataloader loaded with {len(train_dataset)} samples", + flush=True, + ) + + val_dataloader: Optional[StatefulDataLoader] = None + if val_dataset is not None and ( + distill_config["val_period"] > 0 + or distill_config["val_at_start"] + or distill_config["val_at_end"] + ): + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=distill_config["num_prompts_per_step"], + shuffle=False, + collate_fn=collator, + drop_last=False, + num_workers=data_config["num_workers"], + ) + print( + f" โœ“ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) + + # ========================== + # Cluster + # ========================== + print("\nโ–ถ Setting up compute cluster...", flush=True) + cluster = RayVirtualCluster( + name="xtoken_distillation_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=2, + ) + + # ========================== + # Teacher Policy + # ========================== + print("\nโ–ถ Setting up teacher policy...", flush=True) + teacher_policy = Policy( + name_prefix="teacher", + cluster=cluster, + config=teacher_config, + tokenizer=teacher_tokenizer, + weights_path=None, + optimizer_path=None, + init_optimizer=False, + init_reference_model=False, + ) + teacher_policy.offload_after_refit() + + # ========================== + # Student Policy + # ========================== + print("\nโ–ถ Setting up student policy...", flush=True) + weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path) + student_policy = Policy( + name_prefix="student", + cluster=cluster, + config=policy_config, + tokenizer=student_tokenizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + init_reference_model=False, + ) + + # ========================== + # Loss + # ========================== + # Inject the teacher's full vocab size so the projection matrix's V_t + # axis covers every teacher id GlobalTopkLogitsPostProcessor can pick. + # `len(tokenizer)` is what HF treats as the embedding/lm_head dim. + loss_config = {**loss_config, "teacher_vocab_size": len(teacher_tokenizer)} + loss_fn = CrossTokenizerDistillationLossFn(loss_config) + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n", flush=True) + + return ( + student_policy, + teacher_policy, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + ) + + +# =============================================================================== +# Train loop +# =============================================================================== + + +def xtoken_distillation_train( + student_policy: Policy, + teacher_policy: Policy, + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + loss_fn: CrossTokenizerDistillationLossFn, + logger: Logger, + checkpointer: CheckpointManager, + save_state: XTokenDistillationSaveState, + master_config: MasterConfig, +) -> None: + """Off-policy CT distillation training loop.""" + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + + distill_cfg = master_config["distillation"] + current_epoch = save_state["current_epoch"] + current_step = save_state["current_step"] + total_steps = save_state["total_steps"] + consumed_samples = save_state["consumed_samples"] + total_valid_tokens = save_state["total_valid_tokens"] + val_period = distill_cfg["val_period"] + val_at_start = distill_cfg["val_at_start"] + val_at_end = distill_cfg["val_at_end"] + max_epochs = distill_cfg["max_num_epochs"] + max_steps = distill_cfg["max_num_steps"] + topk_logits_k = distill_cfg["topk_logits_k"] + + if val_at_start and total_steps == 0 and val_dataloader is not None: + val_metrics, val_timings = validate( + student_policy, + teacher_policy, + val_dataloader, + loss_fn, + master_config, + timer=timer, + ) + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(val_timings, total_steps, prefix="timing/validation") + + while total_steps < max_steps and current_epoch < max_epochs: + print( + f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_epochs} {'=' * 25}", + flush=True, + ) + for batch in dataloader: + print( + f"\n{'=' * 25} Step {current_step + 1}/" + f"{min(len(dataloader), max_steps)} {'=' * 25}", + flush=True, + ) + maybe_gpu_profile_step(student_policy, total_steps + 1) + + with timer.time("total_step_time"): + with timer.time("teacher_forward_prep"): + teacher_policy.prepare_for_lp_inference() + + with timer.time("teacher_forward"): + teacher_data = BatchedDataDict( + input_ids=batch["teacher_input_ids"], + input_lengths=batch["teacher_input_lengths"], + token_mask=batch["teacher_token_mask"], + sample_mask=batch["sample_mask"], + ) + # Per-sample global top-k: same vocab columns at every + # teacher position, so chunk-averaged KL has a stable + # vocab axis. teacher_topk_logits: [B, T_t, k]; + # teacher_topk_indices: [B, k]. + teacher_topk = teacher_policy.get_global_topk_logits( + teacher_data, k=topk_logits_k, timer=timer + ) + teacher_policy.offload_after_refit() + + # Pack student-side training data with teacher topk and the + # alignment payload the loss fn will index into. + train_data: BatchedDataDict[Any] = BatchedDataDict( + input_ids=batch["input_ids"], + input_lengths=batch["input_lengths"], + token_mask=batch["token_mask"], + sample_mask=batch["sample_mask"], + teacher_topk_logits=teacher_topk["topk_logits"], + teacher_topk_indices=teacher_topk["topk_indices"], + alignment_student_spans=batch["alignment_student_spans"], + alignment_teacher_spans=batch["alignment_teacher_spans"], + alignment_pair_valid=batch["alignment_pair_valid"], + alignment_pair_is_correct=batch["alignment_pair_is_correct"], + alignment_student_exact_partition_mask=( + batch["alignment_student_exact_partition_mask"] + ), + alignment_teacher_exact_partition_mask=( + batch["alignment_teacher_exact_partition_mask"] + ), + alignment_student_chunk_id=batch["alignment_student_chunk_id"], + alignment_teacher_chunk_id=batch["alignment_teacher_chunk_id"], + alignment_num_chunks=batch["alignment_num_chunks"], + ) + train_data.to("cpu") + + with timer.time("training_prep"): + student_policy.prepare_for_training() + + with timer.time("policy_training"): + train_results = student_policy.train( + train_data, + loss_fn, + timer=timer, + skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, + ) + + is_last_step = (total_steps + 1 >= max_steps) or ( + (current_epoch + 1 == max_epochs) + and (current_step + 1 == len(dataloader)) + ) + + val_metrics: dict[str, Any] | None = None + if val_dataloader is not None and ( + (val_period > 0 and (total_steps + 1) % val_period == 0) + or (val_at_end and is_last_step) + ): + val_metrics, val_timings = validate( + student_policy, + teacher_policy, + val_dataloader, + loss_fn, + master_config, + timer=timer, + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + logger.log_metrics( + val_timings, total_steps + 1, prefix="timing/validation" + ) + + metrics: dict[str, Any] = {} + metrics.update(train_results["all_mb_metrics"]) + # Reduce per-microbatch metrics to per-step scalars. + for k, v in metrics.items(): + if k in { + "lr", + "wd", + "global_valid_seqs", + "global_valid_toks", + "accuracy", + "proj_accuracy", + "kl_loss_scale", + }: + metrics[k] = float(np.mean(v)) + else: + metrics[k] = float(np.sum(v)) + metrics["loss"] = float(train_results["loss"].numpy()) + metrics["grad_norm"] = float(train_results["grad_norm"].numpy()) + if "global_valid_toks" in metrics: + total_valid_tokens += int(metrics["global_valid_toks"]) + + consumed_samples += distill_cfg["num_prompts_per_step"] + timeout.mark_iteration() + + # ===== Checkpointing ===== + should_save_by_step = ( + is_last_step + or (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 + ) + should_save_by_timeout = timeout.check_save() + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + student_policy.prepare_for_training() + save_state["current_epoch"] = current_epoch + save_state["current_step"] = current_step + 1 + save_state["total_steps"] = total_steps + 1 + save_state["total_valid_tokens"] = total_valid_tokens + save_state["consumed_samples"] = consumed_samples + if val_metrics is not None and "loss" in val_metrics: + save_state["val_loss"] = float(val_metrics["loss"]) + elif "val_loss" in save_state: + del save_state["val_loss"] + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + prefix, metric_name = full_metric_name.split(":", 1) + source = metrics if prefix == "train" else (val_metrics or {}) + if metric_name in source: + save_state[full_metric_name] = float(source[metric_name]) + + with timer.time("checkpointing"): + ckpt_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, save_state, master_config + ) + student_policy.save_checkpoint( + weights_path=os.path.join( + ckpt_path, "policy", "weights" + ), + optimizer_path=os.path.join( + ckpt_path, "policy", "optimizer" + ) + if checkpointer.save_optimizer + else None, + tokenizer_path=os.path.join( + ckpt_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + torch.save( + dataloader.state_dict(), + os.path.join(ckpt_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(ckpt_path) + + # ===== Logging ===== + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + # `metrics["loss"]/kl_loss/ce_loss` are SUM across all DP ranks + # AND microbatches (= dp_size * local_mbs values summed). PT + # logs rank-0 per-microbatch raw, so to compare apples-to-apples + # to PT, also print per-MB-mean. n_mb = len of the flat list of + # per-MB metrics across all ranks. + n_mb = max(len(train_results["all_mb_metrics"].get("loss", [])), 1) + print( + f" โ€ข Loss: {metrics['loss']:.4f} " + f"(per-MB-mean: {metrics['loss'] / n_mb:.4f})", + flush=True, + ) + kl_sum = float(metrics.get("kl_loss", 0.0)) + ce_sum = float(metrics.get("ce_loss", 0.0)) + print( + f" โ€ข KL: {kl_sum:.4f} " + f"(per-MB-mean: {kl_sum / n_mb:.4f})", + flush=True, + ) + print( + f" โ€ข CE: {ce_sum:.4f} " + f"(per-MB-mean: {ce_sum / n_mb:.4f})", + flush=True, + ) + # Next-token accuracy is already a mean-across-MB-ranks (we + # put it in the np.mean branch above), directly comparable to + # PT's `Acc:` log column. + if "accuracy" in metrics: + print( + f" โ€ข Acc: {metrics['accuracy'] * 100:.2f}%", + flush=True, + ) + if "proj_accuracy" in metrics: + print( + f" โ€ข ProjAcc: {metrics['proj_accuracy'] * 100:.2f}%", + flush=True, + ) + print(f" โ€ข Total step time: {timing_metrics.get('total_step_time', 0):.2f}s", flush=True) + for k, v in sorted( + timing_metrics.items(), key=lambda kv: kv[1], reverse=True + ): + if k != "total_step_time": + print(f" โ€ข {k}: {v:.2f}s", flush=True) + + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + + if should_save_by_timeout: + print("Timeout reached, stopping training early.", flush=True) + return + if total_steps >= max_steps: + print("Max steps reached, stopping training.", flush=True) + return + + current_epoch += 1 + current_step = 0 + + +# =============================================================================== +# Validation +# =============================================================================== + + +def validate( + student_policy: Policy, + teacher_policy: Policy, + val_dataloader: StatefulDataLoader, + loss_fn: CrossTokenizerDistillationLossFn, + master_config: MasterConfig, + timer: Optional[Timer] = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Held-out KL/CE on a validation dataloader. + + Reuses the same per-step path as training, but in eval mode so no + backward / optimizer step runs. Returns mean train-style metrics. + """ + distill_cfg = master_config["distillation"] + topk_logits_k = distill_cfg["topk_logits_k"] + timer = timer if timer is not None else Timer() + + losses: list[float] = [] + kl_losses: list[float] = [] + ce_losses: list[float] = [] + + with timer.time("validation_total"): + teacher_policy.prepare_for_lp_inference() + for batch in val_dataloader: + teacher_data = BatchedDataDict( + input_ids=batch["teacher_input_ids"], + input_lengths=batch["teacher_input_lengths"], + token_mask=batch["teacher_token_mask"], + sample_mask=batch["sample_mask"], + ) + teacher_topk = teacher_policy.get_global_topk_logits( + teacher_data, k=topk_logits_k + ) + + train_data: BatchedDataDict[Any] = BatchedDataDict( + input_ids=batch["input_ids"], + input_lengths=batch["input_lengths"], + token_mask=batch["token_mask"], + sample_mask=batch["sample_mask"], + teacher_topk_logits=teacher_topk["topk_logits"], + teacher_topk_indices=teacher_topk["topk_indices"], + alignment_student_spans=batch["alignment_student_spans"], + alignment_teacher_spans=batch["alignment_teacher_spans"], + alignment_pair_valid=batch["alignment_pair_valid"], + alignment_pair_is_correct=batch["alignment_pair_is_correct"], + alignment_student_exact_partition_mask=( + batch["alignment_student_exact_partition_mask"] + ), + alignment_teacher_exact_partition_mask=( + batch["alignment_teacher_exact_partition_mask"] + ), + alignment_student_chunk_id=batch["alignment_student_chunk_id"], + alignment_teacher_chunk_id=batch["alignment_teacher_chunk_id"], + alignment_num_chunks=batch["alignment_num_chunks"], + ) + train_data.to("cpu") + student_policy.prepare_for_training() + results = student_policy.train( + train_data, + loss_fn, + eval_mode=True, + skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, + ) + losses.append(float(results["loss"].numpy())) + mb_metrics = results.get("all_mb_metrics", {}) + if "kl_loss" in mb_metrics: + kl_losses.append(float(np.mean(mb_metrics["kl_loss"]))) + if "ce_loss" in mb_metrics: + ce_losses.append(float(np.mean(mb_metrics["ce_loss"]))) + teacher_policy.offload_after_refit() + + metrics: dict[str, Any] = { + "loss": float(np.mean(losses)) if losses else 0.0, + } + if kl_losses: + metrics["kl_loss"] = float(np.mean(kl_losses)) + if ce_losses: + metrics["ce_loss"] = float(np.mean(ce_losses)) + + return metrics, timer.get_timing_metrics(reduction_op="sum") # type: ignore diff --git a/nemo_rl/data/cross_tokenizer_collate.py b/nemo_rl/data/cross_tokenizer_collate.py new file mode 100644 index 0000000000..7be382c311 --- /dev/null +++ b/nemo_rl/data/cross_tokenizer_collate.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Collator that tokenizes raw text twice (student + teacher) and aligns. + +The collator runs inside DataLoader worker processes. It does: + +1. Tokenizes the same source text once with the student tokenizer and once + with the teacher tokenizer (no chat template, no special handling). +2. Calls :class:`TokenAligner.align` to produce a dense-padded + :class:`AlignmentBatch` covering all three loss modes (P-KL, gold_loss, + xtoken_loss). +3. Returns a :class:`BatchedDataDict` with the keys + :class:`Policy.train` expects (``input_ids``, ``input_lengths``, + ``token_mask``, ``sample_mask``) plus teacher tensors and alignment + tensors. + +Loss-side projection-matrix work happens inside the loss fn; nothing related +to KL/CE math runs here. +""" + +from __future__ import annotations + +from typing import Any, List + +import torch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + +class CrossTokenizerCollator: + """Tokenize twice, align once, return a flat tensor batch. + + Args: + student_tokenizer: HF tokenizer matching the student model. + teacher_tokenizer: HF tokenizer matching the teacher model. + aligner: Pre-constructed :class:`TokenAligner`. + ctx_length_student: Hard tokenization length cap on the student + side (also the padded sequence length of the student tensor). + ctx_length_teacher: Same on the teacher side. + make_seq_div_by_student: Round student sequence length up to a + multiple of this value (typically TP * CP * 2 for DTensor V2). + make_seq_div_by_teacher: Same for the teacher side. + text_key: Field on :class:`DatumSpec` that holds the raw text. Set + by :func:`kd_data_processor` to ``"raw_text"``. + """ + + def __init__( + self, + *, + student_tokenizer: PreTrainedTokenizerBase, + teacher_tokenizer: PreTrainedTokenizerBase, + aligner: TokenAligner, + ctx_length_student: int, + ctx_length_teacher: int, + make_seq_div_by_student: int = 1, + make_seq_div_by_teacher: int = 1, + text_key: str = "raw_text", + ): + self.student_tokenizer = student_tokenizer + self.teacher_tokenizer = teacher_tokenizer + self.aligner = aligner + self.ctx_length_student = ctx_length_student + self.ctx_length_teacher = ctx_length_teacher + self.make_seq_div_by_student = make_seq_div_by_student + self.make_seq_div_by_teacher = make_seq_div_by_teacher + self.text_key = text_key + # Defensive: HF tokenizers without a pad token can't pad batches. + if self.student_tokenizer.pad_token_id is None: + self.student_tokenizer.pad_token = self.student_tokenizer.eos_token + if self.teacher_tokenizer.pad_token_id is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + + def __call__(self, batch: List[DatumSpec]) -> BatchedDataDict[Any]: + texts = [datum[self.text_key] for datum in batch] + student_input_ids, student_attention_mask = self._tokenize_batch( + texts, + self.student_tokenizer, + self.ctx_length_student, + self.make_seq_div_by_student, + ) + teacher_input_ids, teacher_attention_mask = self._tokenize_batch( + texts, + self.teacher_tokenizer, + self.ctx_length_teacher, + self.make_seq_div_by_teacher, + ) + alignment = self.aligner.align(student_input_ids, teacher_input_ids) + + sample_mask = torch.tensor( + [datum["loss_multiplier"] for datum in batch], dtype=torch.float32 + ) + idx = [datum["idx"] for datum in batch] + + return BatchedDataDict( + # Student-side keys map onto Policy.train's expected names. + input_ids=student_input_ids, + input_lengths=student_attention_mask.sum(dim=-1).long(), + token_mask=student_attention_mask.long(), + sample_mask=sample_mask, + # Teacher-side keys travel with the batch for the teacher + # forward pass in the trainer. + teacher_input_ids=teacher_input_ids, + teacher_input_lengths=teacher_attention_mask.sum(dim=-1).long(), + teacher_token_mask=teacher_attention_mask.long(), + # Alignment payload, dense-padded so DTensor V2 can shard on dim 0. + alignment_student_spans=alignment.student_spans, + alignment_teacher_spans=alignment.teacher_spans, + alignment_pair_valid=alignment.pair_valid, + alignment_pair_is_correct=alignment.pair_is_correct, + alignment_student_exact_partition_mask=( + alignment.student_exact_partition_mask + ), + alignment_teacher_exact_partition_mask=( + alignment.teacher_exact_partition_mask + ), + alignment_student_chunk_id=alignment.student_chunk_id, + alignment_teacher_chunk_id=alignment.teacher_chunk_id, + alignment_num_chunks=alignment.num_chunks, + idx=idx, + ) + + @staticmethod + def _tokenize_batch( + texts: List[str], + tokenizer: PreTrainedTokenizerBase, + ctx_length: int, + make_seq_div_by: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Tokenize a batch and pad to a multiple of ``make_seq_div_by``.""" + encoded = tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=ctx_length, + return_tensors="pt", + ) + input_ids: torch.Tensor = encoded["input_ids"] + attention_mask: torch.Tensor = encoded["attention_mask"] + + b, t = input_ids.shape + pad = (make_seq_div_by - (t % make_seq_div_by)) % make_seq_div_by + if pad > 0: + pad_ids = torch.full( + (b, pad), + tokenizer.pad_token_id, + dtype=input_ids.dtype, + ) + pad_mask = torch.zeros((b, pad), dtype=attention_mask.dtype) + input_ids = torch.cat([input_ids, pad_ids], dim=1) + attention_mask = torch.cat([attention_mask, pad_mask], dim=1) + + return input_ids, attention_mask diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index a8928a16a9..b53301ab7a 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -14,6 +14,7 @@ from nemo_rl.data import ResponseDatasetConfig from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset +from nemo_rl.data.datasets.response_datasets.arrow_text_dataset import ArrowTextDataset from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset @@ -46,6 +47,7 @@ DATASET_REGISTRY = { # built-in datasets + "arrow_text": ArrowTextDataset, "avqa": AVQADataset, "AIME2024": AIME2024Dataset, "clevr-cogent": CLEVRCoGenTDataset, @@ -96,6 +98,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): __all__ = [ + "ArrowTextDataset", "AVQADataset", "AIME2024Dataset", "CLEVRCoGenTDataset", diff --git a/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py new file mode 100644 index 0000000000..c7382b1481 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py @@ -0,0 +1,101 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Arrow-shard raw-text dataset for cross-tokenizer distillation. + +A minimal dataset that loads an arrow file (or a glob of arrow files), takes +one column of raw text, and optionally packs consecutive rows together into +larger samples by character count. Tokenization is intentionally NOT done +here โ€” the cross-tokenizer collator tokenizes both student and teacher +copies of each text on the fly. +""" + +from __future__ import annotations + +from typing import Any, Iterable + +from datasets import Dataset, load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +class ArrowTextDataset(RawDataset): + """Load arrow shards as a stream of raw text strings. + + Args: + arrow_files: Path or glob to one or more ``.arrow`` files. Forwarded + to ``datasets.load_dataset("arrow", data_files=...)``. + text_key: Column on the loaded dataset that contains the raw text. + characters_per_sample: If set, pack consecutive rows together until + the running character count reaches this threshold; emit a packed + sample and start a fresh one. If ``None``, every input row is + one sample. + split_validation_size: Optional held-out fraction. + seed: Seed for the train/validation split. + """ + + def __init__( + self, + arrow_files: str | list[str], + text_key: str = "text", + characters_per_sample: int | None = None, + split_validation_size: float = 0.0, + seed: int = 42, + **kwargs: Any, + ): + self.text_key = text_key + self.task_name = "x_token" + + raw = load_dataset("arrow", data_files=arrow_files, split="train") + + if characters_per_sample is None or characters_per_sample <= 0: + self.dataset = raw.map( + lambda d: {"text": d[text_key], "task_name": self.task_name}, + remove_columns=raw.column_names, + ) + else: + self.dataset = Dataset.from_generator( + _pack_generator, + gen_kwargs={ + "raw": raw, + "text_key": text_key, + "characters_per_sample": characters_per_sample, + "task_name": self.task_name, + }, + ) + + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) + + +def _pack_generator( + raw: Dataset, + text_key: str, + characters_per_sample: int, + task_name: str, +) -> Iterable[dict[str, Any]]: + """Pack consecutive rows until each pack hits ``characters_per_sample``.""" + buf: list[str] = [] + n = 0 + for row in raw: + text = row[text_key] + if not isinstance(text, str) or not text: + continue + buf.append(text) + n += len(text) + if n >= characters_per_sample: + yield {"text": "\n".join(buf), "task_name": task_name} + buf = [] + n = 0 + if buf: + yield {"text": "\n".join(buf), "task_name": task_name} diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 091c5ad1c5..f3289acb5a 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -707,6 +707,35 @@ def nemo_gym_data_processor( return output +def kd_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int | None, + idx: int, +) -> DatumSpec: + """Process a raw-text datum for cross-tokenizer distillation. + + The cross-tokenizer collator does the actual tokenization (twice โ€” once + with the student tokenizer, once with the teacher tokenizer), so this + processor only carries the raw text forward. ``message_log`` is left + empty for compatibility with :class:`DatumSpec`; the collator reads + ``raw_text`` instead. + """ + text = datum_dict["text"] + output: DatumSpec = { + "message_log": [], + "length": len(text), + "extra_env_info": None, + "loss_multiplier": 1.0, + "idx": idx, + "raw_text": text, # consumed by CrossTokenizerCollator + } + if "task_name" in datum_dict: + output["task_name"] = datum_dict["task_name"] + return output + + # Processor registry. Key is the processor name, value is the processor function. # Note: We cast the literal dict to Dict[str, TaskDataProcessFnCallable] because # type checkers see each concrete function's signature as a distinct callable type. @@ -718,6 +747,7 @@ def nemo_gym_data_processor( { "default": math_hf_data_processor, "helpsteer3_data_processor": helpsteer3_data_processor, + "kd_data_processor": kd_data_processor, "math_data_processor": math_data_processor, "math_hf_data_processor": math_hf_data_processor, "multichoice_qa_processor": multichoice_qa_processor, diff --git a/nemo_rl/models/automodel/data.py b/nemo_rl/models/automodel/data.py index 98eed48d4f..58c2ff1b01 100644 --- a/nemo_rl/models/automodel/data.py +++ b/nemo_rl/models/automodel/data.py @@ -16,7 +16,7 @@ import itertools from dataclasses import dataclass, field -from typing import Any, Iterator, Optional, Tuple +from typing import Any, Iterable, Iterator, Optional, Tuple import torch from transformers import AutoTokenizer @@ -348,7 +348,11 @@ def process_global_batch( } -def check_sequence_dim(data: BatchedDataDict[Any]) -> Tuple[int, int]: +def check_sequence_dim( + data: BatchedDataDict[Any], + *, + skip_keys: Optional[Iterable[str]] = None, +) -> Tuple[int, int]: """Check and validate sequence dimension across all tensors. Verifies that dimension 1 is the sequence dimension for all tensors @@ -356,6 +360,10 @@ def check_sequence_dim(data: BatchedDataDict[Any]) -> Tuple[int, int]: Args: data: BatchedDataDict to validate + skip_keys: Keys whose tensors are not student-sequence-aligned at + dim 1 and should be excluded from the check (e.g. cross-tokenizer + teacher and alignment auxiliaries that ride along on the same + BatchedDataDict). Returns: Tuple of (sequence_dim, seq_dim_size) @@ -363,9 +371,12 @@ def check_sequence_dim(data: BatchedDataDict[Any]) -> Tuple[int, int]: Raises: AssertionError: If any tensor has inconsistent sequence dimension """ + skip_set = set(skip_keys) if skip_keys is not None else set() sequence_dim = 1 seq_dim_size = data.get("input_ids").shape[sequence_dim] - for _, v in data.items(): + for k, v in data.items(): + if k in skip_set: + continue if torch.is_tensor(v) and len(v.shape) > 1: assert v.shape[sequence_dim] == seq_dim_size, ( f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index d2b5979400..5aedd9b8bd 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -58,6 +58,7 @@ "LossPostProcessor", "LogprobsPostProcessor", "TopkLogitsPostProcessor", + "GlobalTopkLogitsPostProcessor", "ScorePostProcessor", ] @@ -323,7 +324,12 @@ def forward_with_post_processing_fn( # Score computations should use unscaled logits if isinstance( post_processing_fn, - (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor), + ( + LossPostProcessor, + LogprobsPostProcessor, + TopkLogitsPostProcessor, + GlobalTopkLogitsPostProcessor, + ), ): # Temperature scaling is element-wise, directly applying it here. # Other sampling parameters like top-k and top-p need the logits from whole vocabulary, @@ -341,7 +347,8 @@ def forward_with_post_processing_fn( sequence_dim=sequence_dim, ) elif isinstance( - post_processing_fn, (LogprobsPostProcessor, TopkLogitsPostProcessor) + post_processing_fn, + (LogprobsPostProcessor, TopkLogitsPostProcessor, GlobalTopkLogitsPostProcessor), ): result = post_processing_fn( logits=logits, @@ -929,6 +936,88 @@ def __call__( return vals, idx +class GlobalTopkLogitsPostProcessor: + """Post-processor for global top-k logits (single vocab subset per sample). + + Used by cross-tokenizer distillation: the teacher's vocab axis is reduced + to one ``vocab_topk`` subset *per sample* (max over positions, then top-k + over vocab) so the same vocab columns are kept across every position. + This makes chunk-averaged KL well-defined when student and teacher chunks + span multiple positions โ€” every position carries the same vocab axis. + + Output: + vals: ``[B, S, k]`` teacher logits at the selected ``k`` columns. + idx: ``[B, k]`` teacher vocab ids selected per sample. + + v0 limitation: only the no-TP, no-CP, no-seq-packing path is implemented. + The post-processor asserts on the unsupported configurations because + distributed global top-k requires TP-aware reduction that isn't on the + smoke-test path. + """ + + def __init__( + self, + cfg: PolicyConfig, + device_mesh: Any, + cp_mesh: Any, + tp_mesh: Any, + cp_size: int, + k: int, + enable_seq_packing: bool = False, + ): + self.cfg = cfg + self.device_mesh = device_mesh + self.cp_mesh = cp_mesh + self.tp_mesh = tp_mesh + self.cp_size = cp_size + self.k = k + self.enable_seq_packing = enable_seq_packing + + def __call__( + self, + logits: torch.Tensor, + data_dict: BatchedDataDict[Any], + processed_inputs: Any, + original_batch_size: int, + original_seq_len: int, + sequence_dim: int = 1, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.cp_size > 1: + raise NotImplementedError( + "GlobalTopkLogitsPostProcessor: context_parallel_size > 1 is " + "not supported in v0." + ) + if self.enable_seq_packing: + raise NotImplementedError( + "GlobalTopkLogitsPostProcessor: sequence packing is not " + "supported in v0." + ) + if isinstance(logits, DTensor): + tp_group = self.tp_mesh.get_group() if self.tp_mesh is not None else None + tp_size = ( + torch.distributed.get_world_size(tp_group) + if tp_group is not None + else 1 + ) + if tp_size > 1: + raise NotImplementedError( + "GlobalTopkLogitsPostProcessor: tensor_parallel_size > 1 " + "is not supported in v0." + ) + logits = logits.to_local() + + full_logits = logits.to(torch.float32) # [B, S, V] + # Per-sample max over positions, then top-k over vocab. + per_vocab_max = full_logits.max(dim=1).values # [B, V] + _, idx = torch.topk(per_vocab_max, k=self.k, dim=-1) # [B, k] + # Gather the selected k columns at every position for each sample. + # full_logits: [B, S, V], idx: [B, k] -> vals: [B, S, k]. + b, s, _ = full_logits.shape + idx_expanded = idx.unsqueeze(1).expand(-1, s, -1) # [B, S, k] + vals = torch.gather(full_logits, dim=-1, index=idx_expanded) + return vals, idx + + class ScorePostProcessor: """Post-processor for computing reward model scores from model outputs.""" diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index c3f7772c42..8d027d1e57 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -15,7 +15,7 @@ import warnings from collections import defaultdict from contextlib import nullcontext -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Union import numpy as np import ray @@ -591,6 +591,62 @@ def get_topk_logits( return stacked + def get_global_topk_logits( + self, + data: BatchedDataDict[GenerationDatumSpec], + k: int, + micro_batch_size: Optional[int] = None, + timer: Optional[Timer] = None, + ) -> BatchedDataDict[Any]: + """Dispatch get_global_topk_logits to workers. + + Returns a BatchedDataDict with: + - ``topk_logits``: ``[B, S, k]`` + - ``topk_indices``: ``[B, k]`` per-sample global vocab subset. + + Used by cross-tokenizer distillation; intentionally doesn't support + dynamic batching or sequence packing in v0. + """ + if self.use_dynamic_batches or self.use_sequence_packing: + raise NotImplementedError( + "get_global_topk_logits does not support dynamic batching or " + "sequence packing in v0." + ) + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + with timer.time("get_global_topk_logits/shard_data") if timer else nullcontext(): + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + with ( + timer.time("get_global_topk_logits/submit") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "get_global_topk_logits", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + ) + worker_batches = self.worker_group.get_all_worker_results(futures) + all_vals = [wb["topk_logits"] for wb in worker_batches] + all_idx = [wb["topk_indices"] for wb in worker_batches] + stacked: BatchedDataDict[Any] = BatchedDataDict() + stacked["topk_logits"] = torch.cat(all_vals, dim=0) + stacked["topk_indices"] = torch.cat(all_idx, dim=0) + return stacked + def train( self, data: BatchedDataDict[Any], @@ -599,8 +655,17 @@ def train( gbs: Optional[int] = None, mbs: Optional[int] = None, timer: Optional[Timer] = None, + skip_keys: Optional[Iterable[str]] = None, ) -> dict[str, Any]: - """Train the policy on a batch of data with a given loss function.""" + """Train the policy on a batch of data with a given loss function. + + Args: + skip_keys: Keys whose tensors are not student-sequence-aligned at + dim 1 and must be excluded from the worker's sequence-dim + pre-flight check. Used by cross-tokenizer distillation to + pass through teacher / alignment auxiliaries that ride on + the same data dict. + """ batch_size = gbs or self.cfg["train_global_batch_size"] micro_batch_size = mbs or self.cfg["train_micro_batch_size"] # Shard and replicate the batch @@ -661,6 +726,7 @@ def train( "eval_mode": eval_mode, "gbs": batch_size, "mbs": micro_batch_size, + "skip_keys": skip_keys, }, ) results = self.worker_group.get_all_worker_results(futures) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2fa8a8e604..9d7769cbf3 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -16,7 +16,7 @@ import gc import warnings from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any, Generator, Optional +from typing import Any, Generator, Iterable, Optional import ray import torch @@ -54,6 +54,7 @@ LogprobsPostProcessor, LossPostProcessor, ScorePostProcessor, + GlobalTopkLogitsPostProcessor, TopkLogitsPostProcessor, aggregate_training_statistics, automodel_forward_backward, @@ -332,6 +333,7 @@ def train( eval_mode: bool = False, gbs: Optional[int] = None, mbs: Optional[int] = None, + skip_keys: Optional[Iterable[str]] = None, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" if gbs is None: @@ -348,7 +350,7 @@ def train( num_global_batches = int(total_dataset_size.item()) // gbs # Validate sequence dimension - sequence_dim, _ = check_sequence_dim(data) + sequence_dim, _ = check_sequence_dim(data, skip_keys=skip_keys) if eval_mode: ctx: AbstractContextManager[Any] = torch.no_grad() @@ -778,6 +780,106 @@ def get_topk_logits( ).cpu() return ret + def get_global_topk_logits( + self, + data: BatchedDataDict[Any], + k: int, + micro_batch_size: Optional[int] = None, + ) -> BatchedDataDict[Any]: + """Return per-sample global top-k teacher logits. + + Computes one ``vocab_topk`` subset per sample (max over positions, + then top-k over vocab) and returns the full sequence's logits at + those columns. Used by cross-tokenizer distillation to enable + chunk-averaged KL with a stable vocab axis. + + Returns: + BatchedDataDict with: + - ``topk_logits``: [B, S, k] + - ``topk_indices``: [B, k] (per-sample, position-independent) + + v0: TP=1, CP=1, no sequence packing. + """ + topk_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + sequence_dim, seq_dim_size = check_sequence_dim(data) + + out_topk_vals: list[torch.Tensor] = [] + out_topk_idx: list[torch.Tensor] = [] + self.model.eval() + + post_processor = GlobalTopkLogitsPostProcessor( + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + k=k, + enable_seq_packing=self.enable_seq_packing, + ) + + with torch.no_grad(): + data.to("cuda") + processed_iterator, iterator_len = get_microbatch_iterator( + data, + self.cfg, + topk_batch_size, + self.dp_mesh, + tokenizer=self.tokenizer, + cp_size=self.cp_size, + ) + for batch_idx, processed_mb in enumerate(processed_iterator): + processed_inputs = processed_mb.processed_inputs + with get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ): + (vals, idx), _metrics, _ = forward_with_post_processing_fn( + model=self.model, + post_processing_fn=post_processor, + processed_mb=processed_mb, + is_reward_model=False, + allow_flash_attn_args=self.allow_flash_attn_args, + sampling_params=self.sampling_params, + sequence_dim=sequence_dim, + ) + if batch_idx >= iterator_len: + continue + out_topk_vals.append(vals.cpu()) + out_topk_idx.append(idx.cpu()) + + ret = BatchedDataDict[Any]() + # Pad each microbatch's vals along the sequence dim to the common + # length so concatenation on dim 0 matches the input batch shape. + # idx has no sequence dim and concatenates directly. + all_vals_padded = [] + target_seq_len = seq_dim_size + for vals in out_topk_vals: + pad_needed = target_seq_len - vals.shape[1] + if pad_needed > 0: + vals = torch.nn.functional.pad( + vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 + ) + all_vals_padded.append(vals) + ret["topk_logits"] = ( + torch.cat(all_vals_padded, dim=0) + if len(all_vals_padded) > 1 + else all_vals_padded[0] + ).cpu() + ret["topk_indices"] = ( + torch.cat(out_topk_idx, dim=0) + if len(out_topk_idx) > 1 + else out_topk_idx[0] + ).cpu() + return ret + @contextmanager def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. From 40dd5305818f64b7b0ce5f069d3d99465451498c Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Mon, 11 May 2026 12:09:30 -0700 Subject: [PATCH 2/6] feat(xtoken): IPC teacher-logits transport + topk=8192 - xtoken_distillation runner now calls teacher_policy.get_global_topk_logits_ipc to receive teacher per-sample top-k logits via CUDA IPC handles instead of materializing them on the driver. Driver passes the handles into student_policy.train via teacher_topk_logits_ipc, and workers rebuild the tensors locally with rebuild_cuda_tensor_from_ipc. - lm_policy: new dispatchers (get_global_topk_logits_ipc and release_ipc_buffer) plus an ipc-handle-to-batch pairing assert. Fix release_ipc_buffer to use ray.get(futures) (matches every other run_all_workers_single_data caller). - dtensor_policy_worker_v2: worker-side get_global_topk_logits_ipc, IPC tensor stash, release_ipc_buffer, and matching train-side unpack of teacher_topk_logits_ipc. - Loss fn unpacks the IPC handle list lazily inside _compute_p_kl. - YAML default: topk_logits_k / vocab_topk lifted to 8192 (was 64). Signed-off-by: Adithya Hanasoge --- examples/configs/xtoken_distillation.yaml | 4 +- nemo_rl/algorithms/loss/loss_functions.py | 30 +++- nemo_rl/algorithms/xtoken_distillation.py | 68 ++++++--- nemo_rl/models/policy/lm_policy.py | 76 ++++++++++ .../workers/dtensor_policy_worker_v2.py | 139 +++++++++++++++++- 5 files changed, 287 insertions(+), 30 deletions(-) diff --git a/examples/configs/xtoken_distillation.yaml b/examples/configs/xtoken_distillation.yaml index d6a74f0b44..055b1d9656 100644 --- a/examples/configs/xtoken_distillation.yaml +++ b/examples/configs/xtoken_distillation.yaml @@ -8,7 +8,7 @@ distillation: num_prompts_per_step: 64 max_num_steps: 5000 max_num_epochs: 1 - topk_logits_k: 64 # must equal loss_fn.vocab_topk + topk_logits_k: 8192 # must equal loss_fn.vocab_topk seed: 42 val_period: 0 # validation disabled by default val_at_start: false @@ -23,7 +23,7 @@ loss_fn: gold_loss: false xtoken_loss: false temperature: 1.0 - vocab_topk: 64 + vocab_topk: 8192 reverse_kl: false exact_token_match_only: false project_teacher_to_student: false diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index a7def7ee13..003cfffd84 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -1095,7 +1095,12 @@ class CrossTokenizerDistillationLossDataDict(TypedDict): sample_mask: torch.Tensor # Per-sample global top-k teacher logits (same vocab columns at every # teacher position) so chunk-averaged KL has a stable vocab axis. - teacher_topk_logits: torch.Tensor # [B, T_t, k] + # Either teacher_topk_logits OR teacher_topk_logits_ipc must be present + # (the trainer chooses which transport by calling get_global_topk_logits + # vs get_global_topk_logits_ipc); indices always travel on the CPU/Ray + # path since they're tiny. + teacher_topk_logits: NotRequired[torch.Tensor] # [B, T_t, k] + teacher_topk_logits_ipc: NotRequired[list[dict[str, Any]]] # list[B] of handle dicts teacher_topk_indices: torch.Tensor # [B, k] in teacher vocab alignment_student_spans: torch.Tensor # [B, max_pairs, 2] alignment_teacher_spans: torch.Tensor # [B, max_pairs, 2] @@ -1441,9 +1446,28 @@ def _compute_p_kl( v_t = projected_full.shape[-1] projected_full = projected_full.reshape(b, t_s, v_t) # [B, T_s, V_t] - # Per-sample slice to global top-k teacher columns. + # Per-sample slice to global top-k teacher columns. Teacher logits + # either arrive as a dense [B, T_t, k] tensor (CPU/Ray transport, + # for k=64) or as a list[B] of per-sample CUDA IPC handles (for + # k=8192 where the CPU round-trip would be ~6 GB/step). teacher_topk_indices = data["teacher_topk_indices"] # [B, k] - teacher_topk_logits = data["teacher_topk_logits"].float() # [B, T_t, k] + if "teacher_topk_logits_ipc" in data: + handles = data["teacher_topk_logits_ipc"] # list[mbs] of dicts + assert len(handles) == teacher_topk_indices.shape[0], ( + f"IPC handle list length ({len(handles)}) must match " + f"teacher_topk_indices batch dim " + f"({teacher_topk_indices.shape[0]}). Sharding pairing has " + f"diverged โ€” investigate before trusting the loss." + ) + from nemo_rl.models.policy.utils import rebuild_cuda_tensor_from_ipc + consumer_device = torch.cuda.current_device() + vals_per_sample = [ + rebuild_cuda_tensor_from_ipc(h["logits_ipc"], consumer_device) + for h in handles + ] + teacher_topk_logits = torch.stack(vals_per_sample, dim=0).float() + else: + teacher_topk_logits = data["teacher_topk_logits"].float() # [B, T_t, k] _, k = teacher_topk_indices.shape t_t = teacher_topk_logits.shape[1] idx_for_proj = teacher_topk_indices.unsqueeze(1).expand(-1, t_s, -1) diff --git a/nemo_rl/algorithms/xtoken_distillation.py b/nemo_rl/algorithms/xtoken_distillation.py index 281d00eb70..6bb7eb6de5 100644 --- a/nemo_rl/algorithms/xtoken_distillation.py +++ b/nemo_rl/algorithms/xtoken_distillation.py @@ -72,6 +72,7 @@ XTOKEN_NON_STUDENT_SEQ_KEYS: frozenset[str] = frozenset( { "teacher_topk_logits", + "teacher_topk_logits_ipc", "teacher_topk_indices", "teacher_input_ids", "teacher_token_mask", @@ -415,13 +416,18 @@ def xtoken_distillation_train( token_mask=batch["teacher_token_mask"], sample_mask=batch["sample_mask"], ) - # Per-sample global top-k: same vocab columns at every - # teacher position, so chunk-averaged KL has a stable - # vocab axis. teacher_topk_logits: [B, T_t, k]; - # teacher_topk_indices: [B, k]. - teacher_topk = teacher_policy.get_global_topk_logits( - teacher_data, k=topk_logits_k, timer=timer + # IPC transport: per-sample [T_t, k] logit views are + # exported as CUDA IPC handles and consumed by the + # student in-process. Indices are tiny and go through + # the standard CPU/Ray path. + teacher_handles, teacher_indices = ( + teacher_policy.get_global_topk_logits_ipc( + teacher_data, k=topk_logits_k, timer=timer + ) ) + # Model offload frees the teacher's PARAMS to CPU; the + # IPC-stashed logit tensors live in worker Python state + # and survive this call. teacher_policy.offload_after_refit() # Pack student-side training data with teacher topk and the @@ -431,8 +437,8 @@ def xtoken_distillation_train( input_lengths=batch["input_lengths"], token_mask=batch["token_mask"], sample_mask=batch["sample_mask"], - teacher_topk_logits=teacher_topk["topk_logits"], - teacher_topk_indices=teacher_topk["topk_indices"], + teacher_topk_logits_ipc=teacher_handles, + teacher_topk_indices=teacher_indices, alignment_student_spans=batch["alignment_student_spans"], alignment_teacher_spans=batch["alignment_teacher_spans"], alignment_pair_valid=batch["alignment_pair_valid"], @@ -447,18 +453,27 @@ def xtoken_distillation_train( alignment_teacher_chunk_id=batch["alignment_teacher_chunk_id"], alignment_num_chunks=batch["alignment_num_chunks"], ) + # `.to("cpu")` is a no-op on the IPC handle list (lists are + # not tensors) and on the CPU indices tensor. train_data.to("cpu") with timer.time("training_prep"): student_policy.prepare_for_training() with timer.time("policy_training"): - train_results = student_policy.train( - train_data, - loss_fn, - timer=timer, - skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, - ) + try: + train_results = student_policy.train( + train_data, + loss_fn, + timer=timer, + skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, + ) + finally: + # Producer-side CUDA tensors must be freed before + # the next teacher forward โ€” otherwise memory grows + # unboundedly. Always release, even on student + # failure. + teacher_policy.release_ipc_buffer() is_last_step = (total_steps + 1 >= max_steps) or ( (current_epoch + 1 == max_epochs) @@ -662,8 +677,10 @@ def validate( token_mask=batch["teacher_token_mask"], sample_mask=batch["sample_mask"], ) - teacher_topk = teacher_policy.get_global_topk_logits( - teacher_data, k=topk_logits_k + teacher_handles, teacher_indices = ( + teacher_policy.get_global_topk_logits_ipc( + teacher_data, k=topk_logits_k + ) ) train_data: BatchedDataDict[Any] = BatchedDataDict( @@ -671,8 +688,8 @@ def validate( input_lengths=batch["input_lengths"], token_mask=batch["token_mask"], sample_mask=batch["sample_mask"], - teacher_topk_logits=teacher_topk["topk_logits"], - teacher_topk_indices=teacher_topk["topk_indices"], + teacher_topk_logits_ipc=teacher_handles, + teacher_topk_indices=teacher_indices, alignment_student_spans=batch["alignment_student_spans"], alignment_teacher_spans=batch["alignment_teacher_spans"], alignment_pair_valid=batch["alignment_pair_valid"], @@ -689,12 +706,15 @@ def validate( ) train_data.to("cpu") student_policy.prepare_for_training() - results = student_policy.train( - train_data, - loss_fn, - eval_mode=True, - skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, - ) + try: + results = student_policy.train( + train_data, + loss_fn, + eval_mode=True, + skip_keys=XTOKEN_NON_STUDENT_SEQ_KEYS, + ) + finally: + teacher_policy.release_ipc_buffer() losses.append(float(results["loss"].numpy())) mb_metrics = results.get("all_mb_metrics", {}) if "kl_loss" in mb_metrics: diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 8d027d1e57..9f802d227a 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -647,6 +647,82 @@ def get_global_topk_logits( stacked["topk_indices"] = torch.cat(all_idx, dim=0) return stacked + def get_global_topk_logits_ipc( + self, + data: BatchedDataDict[GenerationDatumSpec], + k: int, + micro_batch_size: Optional[int] = None, + timer: Optional[Timer] = None, + ) -> tuple[list[dict[str, Any]], torch.Tensor]: + """Like get_global_topk_logits but ships the logits via CUDA IPC. + + Logits are returned as a flat ``list[B]`` of per-sample IPC handle + dicts. Each entry's ``logits_ipc`` reconstructs a ``[T_t, k]`` + CUDA tensor on the consumer device. Top-k *indices* are tiny (a + few MB at k=8192) so they're shipped via the standard CPU/Ray + path as a single ``[B, k]`` tensor. + + Caller must invoke :meth:`release_ipc_buffer` after the student + finishes consuming the handles โ€” otherwise the producer-side + CUDA tensors leak. + """ + if self.use_dynamic_batches or self.use_sequence_packing: + raise NotImplementedError( + "get_global_topk_logits_ipc does not support dynamic batching " + "or sequence packing in v0." + ) + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + with ( + timer.time("get_global_topk_logits_ipc/shard_data") + if timer + else nullcontext() + ): + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + with ( + timer.time("get_global_topk_logits_ipc/submit") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "get_global_topk_logits_ipc", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + ) + worker_results = self.worker_group.get_all_worker_results(futures) + all_handles: list[dict[str, Any]] = [] + all_idx: list[torch.Tensor] = [] + for wr in worker_results: + all_handles.extend(wr["per_sample_handles"]) + all_idx.append(wr["topk_indices"]) + indices = torch.cat(all_idx, dim=0) + assert len(all_handles) == indices.shape[0], ( + f"IPC handle list length ({len(all_handles)}) must match the " + f"top-k indices batch dim ({indices.shape[0]}). One of the two " + f"transports lost a sample โ€” sharding pairing is broken." + ) + return all_handles, indices + + def release_ipc_buffer(self) -> None: + """Tell all workers to drop their stashed IPC tensors.""" + futures = self.worker_group.run_all_workers_single_data( + "release_ipc_buffer" + ) + ray.get(futures) + def train( self, data: BatchedDataDict[Any], diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 9d7769cbf3..02ada237ce 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -66,7 +66,10 @@ LogprobOutputSpec, ScoreOutputSpec, ) -from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker +from nemo_rl.models.policy.utils import ( + get_handle_from_tensor, + get_runtime_env_for_policy_worker, +) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import ( apply_transformer_engine_patch, @@ -240,6 +243,12 @@ def __init__( # Initialize checkpoint manager self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + # Per-step stash for cross-tokenizer teacher logits exported via + # CUDA IPC. Holds references to the [B_r, T_t, k] tensors so they + # survive across the student's train call. Released by + # release_ipc_buffer(). + self._teacher_ipc_buffer: Optional[list[torch.Tensor]] = None + # Validate configuration and prepare runtime settings runtime_config = validate_and_prepare_config( config=config, @@ -880,6 +889,134 @@ def get_global_topk_logits( ).cpu() return ret + def get_global_topk_logits_ipc( + self, + data: BatchedDataDict[Any], + k: int, + micro_batch_size: Optional[int] = None, + ) -> dict[str, Any]: + """Cross-tokenizer teacher forward; logits leave via CUDA IPC. + + Same per-sample global top-k math as :meth:`get_global_topk_logits`, + but the per-sample ``[T_t, k]`` logits views are exported as CUDA IPC + handles instead of moved to CPU. Top-k *indices* are tiny + (``[B_r, k]`` int64 ~= 6 MB for k=8192, B_r=96/dp_size) and go back + through Ray on CPU like before โ€” cheap, simpler, and bookkeeping + between handles and indices is preserved by returning both from this + single call. + + Lifetime: the source ``[B_r, T_t, k]`` CUDA tensor is stashed in + ``self._teacher_ipc_buffer`` so it outlives the consumer's + ``rebuild_cuda_tensor_from_ipc`` call. The driver releases it via + :meth:`release_ipc_buffer` after the student training step finishes. + + Returns: + dict with: + - ``per_sample_handles``: ``list[B_r]`` of dicts each carrying + a single ``logits_ipc`` handle tuple plus shape/dtype. + - ``topk_indices``: CPU tensor ``[B_r, k]``. + """ + topk_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + sequence_dim, seq_dim_size = check_sequence_dim(data) + + out_vals: list[torch.Tensor] = [] + out_idx: list[torch.Tensor] = [] + self.model.eval() + + post_processor = GlobalTopkLogitsPostProcessor( + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + k=k, + enable_seq_packing=self.enable_seq_packing, + ) + + with torch.no_grad(): + data.to("cuda") + processed_iterator, iterator_len = get_microbatch_iterator( + data, + self.cfg, + topk_batch_size, + self.dp_mesh, + tokenizer=self.tokenizer, + cp_size=self.cp_size, + ) + for batch_idx, processed_mb in enumerate(processed_iterator): + processed_inputs = processed_mb.processed_inputs + with get_train_context( + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + cp_buffers=processed_inputs.cp_buffers, + sequence_dim=sequence_dim, + dtype=self.dtype, + autocast_enabled=self.autocast_enabled, + ): + (vals, idx), _metrics, _ = forward_with_post_processing_fn( + model=self.model, + post_processing_fn=post_processor, + processed_mb=processed_mb, + is_reward_model=False, + allow_flash_attn_args=self.allow_flash_attn_args, + sampling_params=self.sampling_params, + sequence_dim=sequence_dim, + ) + if batch_idx >= iterator_len: + continue + # Keep vals on CUDA for IPC; pad seq dim now so the stash + # tensor matches the canonical [B_r, T_t, k] shape. + pad_needed = seq_dim_size - vals.shape[1] + if pad_needed > 0: + vals = torch.nn.functional.pad( + vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 + ) + out_vals.append(vals.contiguous()) + out_idx.append(idx.cpu()) + + final_vals = ( + torch.cat(out_vals, dim=0) if len(out_vals) > 1 else out_vals[0] + ) # CUDA [B_r, T_t, k] + final_idx = ( + torch.cat(out_idx, dim=0) if len(out_idx) > 1 else out_idx[0] + ) # CPU [B_r, k] + + # Stash the full tensor so per-sample views remain valid across the + # student's train call. Cleared by release_ipc_buffer(). + if self._teacher_ipc_buffer is None: + self._teacher_ipc_buffer = [] + self._teacher_ipc_buffer.append(final_vals) + + per_sample_handles: list[dict[str, Any]] = [] + for i in range(final_vals.shape[0]): + view_i = final_vals[i] # [T_t, k] view; dim-0 slice of row-major + per_sample_handles.append( + { + "logits_ipc": get_handle_from_tensor(view_i), + "shape": tuple(view_i.shape), + "dtype": view_i.dtype, + } + ) + return { + "per_sample_handles": per_sample_handles, + "topk_indices": final_idx, + } + + def release_ipc_buffer(self) -> None: + """Drop the stashed teacher logits and reclaim GPU memory. + + Called by the driver after the student finishes consuming the + previous step's IPC handles. Must precede the next IPC step or + memory grows unboundedly. + """ + self._teacher_ipc_buffer = None + gc.collect() + torch.cuda.empty_cache() + @contextmanager def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. From a5f07865590cd8c23f265c60a4e16c94884a2530 Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Wed, 13 May 2026 23:12:29 -0700 Subject: [PATCH 3/6] feat(xtoken): port PT gold-loss + xtoken-loss; unify full-vocab IPC Ports the PT reference `compute_KL_loss_optimized` gold path (tokenalign.py:3494-3829) into CrossTokenizerDistillationLossFn: - _build_exact_token_map: vectorized port of PT 3493-3594. Partitions the student/teacher vocabs into common (exact-1-1 mapped) and uncommon sets. Strict mode keeps the lowest-s_idx winner per teacher id; xtoken_loss=true relaxes the exact-map threshold from `== 1.0` to `>= 0.6` and elects the highest-projection-weight student per teacher id via scatter_reduce(amax) with a min-s_idx tiebreaker. Cached per-device for the run since the map depends only on the immutable projection file + xtoken_loss flag. - _compute_gold: log_softmax + chunk-average on full teacher vocab both sides, KL on the common slice (forward or reverse), L1 on sorted uncommon probs (topk-capped at cfg["uncommon_topk"], default 8192, matching PT line 3727). Combined as (kl_common + l1_uncommon) * T**2. Top-1 accuracy on the common slice over valid chunks. - Shared helpers _rebuild_teacher_full_logits / _chunk_average_log_probs / _valid_chunk_mask reused by both _compute_p_kl and _compute_gold. - __init__ now requires gold_loss=True when xtoken_loss=True; the older "<=1 True" mutex was wrong (PT's canonical CT run sets both, and the flag is a modifier inside the gold path). Unifies the teacher-logits transport. Replaces the per-sample top-k GlobalTopkLogitsPostProcessor + get_global_topk_logits[_ipc] pair with a single FullLogitsPostProcessor + get_full_logits_ipc that ships the entire [T_t, V_t] tensor per sample via CUDA IPC. The legacy P-KL path stays correct by computing its microbatch-global top-k inline (max over flat (B*T_t) -> topk over V_t -- matches PT `global_top_indices` at tokenalign.py:3877) instead of at the worker. Closer to PT shape (full vocab in-process) and removes worker-side top-k plumbing. Handles the common HF case where lm_head pads out_features beyond len(tokenizer) (Qwen3-4B: 151936 vs 151669): both _compute_p_kl and _compute_gold slice teacher_full_logits to v_t (projection's V_t = real tokenizer vocab) before any downstream math. Padded columns are non-real tokens; dropping them matches the de-facto NRL contract. Config: adds CrossTokenizerDistillationLossConfig.uncommon_topk; drops XTokenDistillationConfig.topk_logits_k (derived locally from loss_fn.vocab_topk). examples/configs/xtoken_distillation.yaml updated. Runner: XTOKEN_NON_STUDENT_SEQ_KEYS swaps the top-k keys for teacher_full_logits_ipc. Both training and validate() switch transport call sites. Per-MB metric reduction + logging handle both metric shapes (kl_loss/ce_loss vs kl_common/l1_uncommon). _maybe_dump_loss records whichever keys are present so the same dump file format serves both paths -- preserves the NRL_XTOKEN_LOSS_DUMP_DIR parity-check protocol. Verified end-to-end with 100-step 2-node interactive smokes: - P-KL (gold_loss=false): job 11752927 COMPLETED in 11:43, train:loss=3.10, ckpt at step 100. - Gold (gold_loss=true, xtoken_loss=true): job 11754078 COMPLETED in 12:39, train:loss=45.69, KL(common)~0.47, L1(uncommon)~0.002, top-1 accuracy on common ~76.7%, ckpt at step 100. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithya Hanasoge --- examples/configs/xtoken_distillation.yaml | 29 +- nemo_rl/algorithms/loss/loss_functions.py | 661 ++++++++++++++---- nemo_rl/algorithms/xtoken_distillation.py | 138 ++-- nemo_rl/models/automodel/train.py | 63 +- nemo_rl/models/policy/lm_policy.py | 96 +-- .../workers/dtensor_policy_worker_v2.py | 162 +---- 6 files changed, 707 insertions(+), 442 deletions(-) diff --git a/examples/configs/xtoken_distillation.yaml b/examples/configs/xtoken_distillation.yaml index 055b1d9656..ea043202d2 100644 --- a/examples/configs/xtoken_distillation.yaml +++ b/examples/configs/xtoken_distillation.yaml @@ -8,7 +8,6 @@ distillation: num_prompts_per_step: 64 max_num_steps: 5000 max_num_epochs: 1 - topk_logits_k: 8192 # must equal loss_fn.vocab_topk seed: 42 val_period: 0 # validation disabled by default val_at_start: false @@ -16,20 +15,32 @@ distillation: loss_fn: projection_matrix_path: "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_genai/users/avenkateshha/nemo_rl/RL/cross_tokenizer_data/llama_qwen_best_special_exact_map_remapped.pt" - # Loss-mode selection (mutually exclusive): - # gold_loss=false, xtoken_loss=false -> P-KL (full-vocab projection KL) - # gold_loss=true, xtoken_loss=false -> gold-loss (NotImplementedError in v0) - # gold_loss=false, xtoken_loss=true -> xtoken-loss (NotImplementedError in v0) + # Loss-mode selection: + # gold_loss=false -> P-KL: full-vocab teacher logits + # transported via IPC; the loss derives a microbatch-global top-k + # (size `vocab_topk`) inside, projects student to teacher vocab via + # M, and chunk-averages KL on the [k] subset. CE term is added. + # gold_loss=true, xtoken_loss=false -> gold-loss (PT-faithful): + # KL on the exact-mapped common vocab (top-1 weight == 1.0) + L1 + # on the uncommon vocab tail (sorted, capped at `uncommon_topk`). + # gold_loss=true, xtoken_loss=true -> relaxed gold: same partition + # math but exact-map threshold is >= 0.6 and collisions are resolved + # by keeping the higher-weight student. Matches the PT canonical CT + # run. + # `xtoken_loss=true` requires `gold_loss=true`. gold_loss: false xtoken_loss: false temperature: 1.0 - vocab_topk: 8192 + vocab_topk: 8192 # P-KL only; size of the microbatch-global + # top-k derived inside the loss + uncommon_topk: 8192 # gold-loss only; caps the L1 uncommon tail + # (matches PT default at tokenalign.py:3727) reverse_kl: false exact_token_match_only: false project_teacher_to_student: false - kl_loss_weight: 1.0 - ce_loss_scale: 0.1 - dynamic_loss_scaling: true + kl_loss_weight: 1.0 # P-KL only + ce_loss_scale: 0.1 # P-KL only + dynamic_loss_scaling: true # P-KL only checkpointing: enabled: false diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 003cfffd84..14172c7fd9 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -1052,26 +1052,34 @@ class CrossTokenizerDistillationLossConfig(TypedDict): 'likelihoods' tensors of shape [V_student, top_k]) or the sparse multi-token format (dict[(student_id, teacher_id)] -> count). Loaded lazily on first call by each worker process. - gold_loss: If True, switch to gold-loss formulation (1-1 exact-match - partition uses CE; rest uses ULD). v0 stub: raises - NotImplementedError. - xtoken_loss: If True, switch to x-token (multi-token chunk) - formulation. v0 stub: raises NotImplementedError. + gold_loss: If True, switch to the gold-loss formulation: split the + vocab into an exact-token-mapped *common* set (KL) and an + *uncommon* set (sorted L1). Matches PT + ``compute_KL_loss_optimized`` lines 3494โ€“3829. + xtoken_loss: Modifier inside the gold-loss path. If True, relaxes + the exact-map threshold to ``>= 0.6`` (vs ``== 1.0``) and adds + a collision-replacement rule so multi-token projections can + still contribute exact maps. Requires ``gold_loss=True``. temperature: Softmax temperature applied symmetrically to student and teacher logits before KL. - vocab_topk: Top-k size used for teacher logits. Should equal - distillation.topk_logits_k. + vocab_topk: Microbatch-global top-k size used by the P-KL path + (``gold_loss=False``). Computed inside the loss fn from full + teacher logits, mirroring PT ``global_top_indices``. Inert when + ``gold_loss=True``. + uncommon_topk: Cap on the L1 uncommon-tail sort in the gold path. + Matches PT's hardcoded 8192. Inert when ``gold_loss=False``. reverse_kl: If True, compute KL(student || teacher) instead of KL(teacher || student). exact_token_match_only: If True, only aligned pairs flagged as 'is_correct' contribute to KL; mismatched pairs are masked out. + Used by the P-KL path only. project_teacher_to_student: If True, project the teacher distribution into student vocab via M.T instead of projecting student into teacher vocab via M. - kl_loss_weight: Scalar multiplier on the KL term. - ce_loss_scale: Scalar multiplier on the CE term. + kl_loss_weight: Scalar multiplier on the KL term (P-KL path). + ce_loss_scale: Scalar multiplier on the CE term (P-KL path). dynamic_loss_scaling: If True, rescale KL each step so its detached - magnitude matches CE. + magnitude matches CE (P-KL path). """ projection_matrix_path: str @@ -1079,6 +1087,7 @@ class CrossTokenizerDistillationLossConfig(TypedDict): xtoken_loss: bool temperature: float vocab_topk: int + uncommon_topk: int reverse_kl: bool exact_token_match_only: bool project_teacher_to_student: bool @@ -1093,15 +1102,13 @@ class CrossTokenizerDistillationLossDataDict(TypedDict): input_lengths: torch.Tensor token_mask: torch.Tensor sample_mask: torch.Tensor - # Per-sample global top-k teacher logits (same vocab columns at every - # teacher position) so chunk-averaged KL has a stable vocab axis. - # Either teacher_topk_logits OR teacher_topk_logits_ipc must be present - # (the trainer chooses which transport by calling get_global_topk_logits - # vs get_global_topk_logits_ipc); indices always travel on the CPU/Ray - # path since they're tiny. - teacher_topk_logits: NotRequired[torch.Tensor] # [B, T_t, k] - teacher_topk_logits_ipc: NotRequired[list[dict[str, Any]]] # list[B] of handle dicts - teacher_topk_indices: torch.Tensor # [B, k] in teacher vocab + # Full-vocab teacher logits shipped via CUDA IPC. List[B] of dicts each + # carrying a single ``logits_ipc`` handle (rebuilt to ``[T_t, V_t]`` on + # the consumer side) plus shape/dtype. Produced by + # ``Policy.get_full_logits_ipc``. The loss fn either derives a + # microbatch-global top-k subset internally (P-KL path) or uses full + # vocab end-to-end (gold-loss path). + teacher_full_logits_ipc: list[dict[str, Any]] alignment_student_spans: torch.Tensor # [B, max_pairs, 2] alignment_teacher_spans: torch.Tensor # [B, max_pairs, 2] alignment_pair_valid: torch.Tensor # [B, max_pairs] @@ -1178,10 +1185,12 @@ class CrossTokenizerDistillationLossFn(LossFunction): input_type = LossInputType.LOGIT def __init__(self, cfg: CrossTokenizerDistillationLossConfig): - if cfg["gold_loss"] and cfg["xtoken_loss"]: + if cfg["xtoken_loss"] and not cfg["gold_loss"]: raise ValueError( - "gold_loss and xtoken_loss are mutually exclusive; set at " - "most one to True." + "xtoken_loss=True requires gold_loss=True; xtoken_loss is " + "a modifier inside the gold path (relaxes the exact-map " + "threshold and adds collision resolution) and is undefined " + "in the P-KL path." ) self.cfg = cfg self.projection_matrix_path = cfg["projection_matrix_path"] @@ -1189,6 +1198,17 @@ def __init__(self, cfg: CrossTokenizerDistillationLossConfig): # each worker process. Keyed by device because the worker may run on # multiple CUDA devices over its lifetime (rare but possible). self._M_per_device: dict[torch.device, torch.Tensor] = {} + # Lazy cache for the gold-path dense projection (indices, + # likelihoods) and the derived common/uncommon vocab partition. + # Both depend only on the immutable projection file and the static + # xtoken_loss flag, so we build once per device and reuse for the + # rest of the run. + self._dense_proj_per_device: dict[ + torch.device, tuple[torch.Tensor, torch.Tensor] + ] = {} + self._exact_map_cache: dict[ + torch.device, dict[str, torch.Tensor] + ] = {} # Optional per-microbatch loss dump for PT-vs-NRL parity comparison. # Activated by setting NRL_XTOKEN_LOSS_DUMP_DIR. Each rank appends a # record per call to {dir}/rank{R}.pt. Records are raw floats from @@ -1237,11 +1257,11 @@ def _load_M(self, device: torch.device) -> torch.Tensor: teacher_idx = teacher_idx[valid_mask] values = values[valid_mask] # Use the teacher's full vocab size as V_t โ€” not max(teacher_idx)+1. - # GlobalTopkLogitsPostProcessor picks top-k over the teacher's - # full vocab, including ids the projection doesn't cover. Sizing - # projected_full to the full teacher vocab makes those columns - # all-zero (correct semantics: unmapped teacher tokens get zero - # projected probability) and keeps the gather in bounds. + # The P-KL global top-k pick happens over the teacher's full vocab, + # including ids the projection doesn't cover. Sizing projected_full + # to the full teacher vocab makes those columns all-zero (correct + # semantics: unmapped teacher tokens get zero projected probability) + # and keeps the gather in bounds. projection_max_teacher = int(teacher_idx.max().item()) + 1 v_teacher = max(self.cfg["teacher_vocab_size"], projection_max_teacher) indices = torch.stack([student_idx, teacher_idx], dim=0) @@ -1275,6 +1295,232 @@ def _load_M(self, device: torch.device) -> torch.Tensor: self._M_per_device[device] = sparse return sparse + def _load_dense_projection( + self, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: + """Load the dense ``(indices, likelihoods)`` projection on ``device``. + + Returns the raw ``[V_s, top_k]`` arrays the gold path needs (vs + :meth:`_load_M` which builds a sparse COO over the projected + teacher vocab). Cached per device for the run. + + Only the dense file format is supported here โ€” the sparse + ``dict[(s, t)] -> count`` format used by some legacy projection + files doesn't carry the per-row top-k weights the gold-path + exact-map builder reads. + """ + if device in self._dense_proj_per_device: + return self._dense_proj_per_device[device] + + if not os.path.exists(self.projection_matrix_path): + raise FileNotFoundError( + f"Projection matrix file not found: {self.projection_matrix_path}" + ) + data = torch.load( + self.projection_matrix_path, map_location="cpu", weights_only=False + ) + if not ( + isinstance(data, dict) + and "indices" in data + and "likelihoods" in data + ): + raise ValueError( + f"gold_loss requires the dense projection-matrix format " + f"(dict with 'indices' and 'likelihoods' tensors). File " + f"{self.projection_matrix_path} uses an unsupported format." + ) + indices = data["indices"].long().to(device) + likelihoods = data["likelihoods"].float().to(device) + self._dense_proj_per_device[device] = (indices, likelihoods) + return indices, likelihoods + + def _build_exact_token_map( + self, device: torch.device + ) -> dict[str, torch.Tensor]: + """Build the common/uncommon vocab partition for the gold path. + + Ports PT ``compute_KL_loss_optimized`` lines 3493โ€“3594. Reads the + dense projection arrays, sorts each student row's projection + weights descending, then picks an exact-token map per the + ``xtoken_loss`` flag: + + - ``xtoken_loss=False`` (strict): ``has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1)``. + On collision (multiple students mapping to the same teacher id), + the earliest (lowest) student index wins โ€” matches PT's + first-come-first-served loop. + - ``xtoken_loss=True`` (relaxed): ``has_exact_map = sorted_values[:, 0] >= 0.6``. + On collision, the student with the highest first-projection + weight wins; ties are broken by lowest student index (matches + PT's ``prev_prob >= new_prob`` skip rule under iteration order). + + Both branches are vectorized via ``scatter_reduce`` so the build + is O(V_s) and happens once per device for the run. + + Returns: + Dict with keys ``common_student``, ``common_teacher`` (paired), + ``uncommon_student``, ``uncommon_teacher`` (each independently + sorted). All ``[long]`` tensors on ``device``. + """ + if device in self._exact_map_cache: + return self._exact_map_cache[device] + + indices, likelihoods = self._load_dense_projection(device) + v_student = indices.shape[0] + v_teacher = self.cfg["teacher_vocab_size"] + xtoken_loss = self.cfg["xtoken_loss"] + + sorted_values, sorted_in_topk = torch.sort( + likelihoods, dim=-1, descending=True + ) + if xtoken_loss: + has_exact_map = sorted_values[:, 0] >= 0.6 + else: + # Strict: exactly one top-k entry with weight 1.0, no second + # mapping. `indices[:, 1] == -1` is the sentinel used by the + # `_exact_map_remapped` projection files for "no second + # mapping" โ€” matches the PT check at tokenalign.py:3517. + has_exact_map = (sorted_values[:, 0] == 1.0) & ( + indices[:, 1] == -1 + ) + + # Gather (s_idx, t_idx, prob) for each exact-map candidate. + s_candidates = torch.where(has_exact_map)[0] + if s_candidates.numel() == 0: + empty = torch.empty(0, dtype=torch.long, device=device) + cache_entry = { + "common_student": empty, + "common_teacher": empty, + "uncommon_student": torch.arange(v_student, device=device), + "uncommon_teacher": torch.arange(v_teacher, device=device), + } + self._exact_map_cache[device] = cache_entry + return cache_entry + + t_candidates = indices[ + s_candidates, sorted_in_topk[s_candidates, 0] + ] + prob_candidates = sorted_values[s_candidates, 0] + + in_bounds = (t_candidates >= 0) & (t_candidates < v_teacher) + s_vec = s_candidates[in_bounds] + t_vec = t_candidates[in_bounds] + prob_vec = prob_candidates[in_bounds] + + # Strict mode: any candidate is eligible (first one wins). + # Relaxed mode: only candidates whose prob ties the per-teacher max. + if xtoken_loss: + max_prob_per_t = torch.full( + (v_teacher,), + float("-inf"), + device=device, + dtype=prob_vec.dtype, + ) + max_prob_per_t.scatter_reduce_( + 0, t_vec, prob_vec, reduce="amax", include_self=True + ) + eligible = prob_vec >= max_prob_per_t[t_vec] + else: + eligible = torch.ones_like(t_vec, dtype=torch.bool) + + # For each teacher id, pick the smallest student index among the + # eligible candidates. Sentinel = v_student so non-eligible rows + # lose the amin reduction. + sentinel = torch.tensor(v_student, dtype=s_vec.dtype, device=device) + eligible_s = torch.where(eligible, s_vec, sentinel.expand_as(s_vec)) + min_s_per_t = torch.full( + (v_teacher,), v_student, device=device, dtype=s_vec.dtype + ) + min_s_per_t.scatter_reduce_( + 0, t_vec, eligible_s, reduce="amin", include_self=True + ) + winner_mask = eligible & (s_vec == min_s_per_t[t_vec]) + + common_student = s_vec[winner_mask] + common_teacher = t_vec[winner_mask] + # Sort by student index so the paired arrays match. + sort_perm = torch.argsort(common_student) + common_student = common_student[sort_perm] + common_teacher = common_teacher[sort_perm] + + common_s_mask = torch.zeros(v_student, dtype=torch.bool, device=device) + common_s_mask[common_student] = True + common_t_mask = torch.zeros(v_teacher, dtype=torch.bool, device=device) + common_t_mask[common_teacher] = True + uncommon_student = (~common_s_mask).nonzero(as_tuple=True)[0] + uncommon_teacher = (~common_t_mask).nonzero(as_tuple=True)[0] + + cache_entry = { + "common_student": common_student, + "common_teacher": common_teacher, + "uncommon_student": uncommon_student, + "uncommon_teacher": uncommon_teacher, + } + self._exact_map_cache[device] = cache_entry + return cache_entry + + @staticmethod + def _rebuild_teacher_full_logits( + data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], + ) -> torch.Tensor: + """Unpack ``teacher_full_logits_ipc`` to a stacked ``[B, T_t, V_t]`` CUDA tensor. + + The IPC handles point at views the teacher worker stashed in its + ``_teacher_ipc_buffer``; rebuilding does not allocate new memory + on the producer side. Casts to ``float32`` to match the loss math + (the producer also writes FP32 via :class:`FullLogitsPostProcessor`). + """ + from nemo_rl.models.policy.utils import rebuild_cuda_tensor_from_ipc + + handles = data["teacher_full_logits_ipc"] + consumer_device = torch.cuda.current_device() + per_sample = [ + rebuild_cuda_tensor_from_ipc(h["logits_ipc"], consumer_device) + for h in handles + ] + return torch.stack(per_sample, dim=0).float() + + @staticmethod + def _chunk_average_log_probs( + log_probs: torch.Tensor, + chunk_id: torch.Tensor, + max_chunks: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Average ``log_probs`` over the chunks defined by ``chunk_id``. + + Builds a one-hot chunk mask from ``chunk_id`` (``-1`` means "no + chunk", contributes to no bucket), then ``bmm``-aggregates and + divides by chunk sizes. Both inputs and outputs match PT's + chunk-averaging math at ``tokenalign.py:3617โ€“3637``. + + Args: + log_probs: ``[B, T, V]`` log-probabilities. + chunk_id: ``[B, T]`` long tensor, values in ``[-1, max_chunks)``. + max_chunks: number of chunk buckets. + + Returns: + chunk_log_probs: ``[B, max_chunks, V]`` averaged log-probs. + chunk_sizes: ``[B, max_chunks]`` float tensor of bucket sizes. + """ + eps = 1e-10 + device = log_probs.device + chunk_arange = torch.arange(max_chunks, device=device).view(1, 1, -1) + # [B, T, max_chunks] โ€” -1 entries compare false everywhere. + chunk_mask = chunk_id.unsqueeze(-1) == chunk_arange + chunk_mask_f = chunk_mask.transpose(1, 2).to(log_probs.dtype) + chunk_sums = torch.bmm(chunk_mask_f, log_probs) # [B, C, V] + chunk_sizes = chunk_mask.sum(dim=1).float() # [B, C] + chunk_log_probs = chunk_sums / (chunk_sizes.unsqueeze(-1) + eps) + return chunk_log_probs, chunk_sizes + + @staticmethod + def _valid_chunk_mask( + s_sizes: torch.Tensor, + t_sizes: torch.Tensor, + pair_valid: torch.Tensor, + ) -> torch.Tensor: + """Per-chunk validity gate: both sides non-empty and pair is valid.""" + return (s_sizes > 0) & (t_sizes > 0) & pair_valid + def __call__( self, data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], @@ -1285,19 +1531,6 @@ def __call__( """Compute the cross-tokenizer distillation loss for one microbatch.""" cfg = self.cfg - if cfg["gold_loss"]: - raise NotImplementedError( - "gold_loss mode is not implemented in v0. The exact-match " - "partition CE + ULD math from the PT reference still needs " - "to be ported. Run with gold_loss=false in the meantime." - ) - if cfg["xtoken_loss"]: - raise NotImplementedError( - "xtoken_loss mode is not implemented in v0. The chunk-" - "aggregated multi-token KL from the PT reference still " - "needs to be ported. Run with xtoken_loss=false in the " - "meantime." - ) if cfg["project_teacher_to_student"]: raise NotImplementedError( "project_teacher_to_student=True is not implemented in v0. " @@ -1306,6 +1539,21 @@ def __call__( "smoke-test path." ) + if cfg["gold_loss"]: + loss, kl_common, l1_uncommon, num_valid_chunks, top1_acc = ( + self._compute_gold(logits, data) + ) + metrics = { + "loss": loss.item(), + "kl_common": kl_common.item(), + "l1_uncommon": l1_uncommon.item(), + "accuracy": top1_acc.item(), + "num_valid_samples": data["input_ids"].shape[0], + "num_valid_chunks": int(num_valid_chunks.item()), + } + self._maybe_dump_loss(metrics) + return loss, metrics + kl_loss, num_valid_pairs, proj_acc = self._compute_p_kl(logits, data) ce_loss = self._compute_ce(logits, data, global_valid_toks) @@ -1378,16 +1626,26 @@ def _maybe_dump_loss(self, metrics: dict[str, Any]) -> None: if torch.distributed.is_initialized() else 0 ) - self._loss_dump_records.append( - { - "call_idx": self._loss_dump_call_idx, - "loss": metrics["loss"], - "kl_loss": metrics["kl_loss"], - "ce_loss": metrics["ce_loss"], - "kl_loss_scale": metrics["kl_loss_scale"], - "num_valid_pairs": metrics["num_valid_pairs"], - } - ) + # The P-KL path emits kl_loss/ce_loss/kl_loss_scale/num_valid_pairs; + # the gold-loss path emits kl_common/l1_uncommon/num_valid_chunks. + # Record everything that's present so the same dump file format + # serves both โ€” downstream comparison scripts read by key. + record: dict[str, Any] = { + "call_idx": self._loss_dump_call_idx, + "loss": metrics["loss"], + } + for k in ( + "kl_loss", + "ce_loss", + "kl_loss_scale", + "num_valid_pairs", + "kl_common", + "l1_uncommon", + "num_valid_chunks", + ): + if k in metrics: + record[k] = metrics[k] + self._loss_dump_records.append(record) self._loss_dump_call_idx += 1 os.makedirs(self._loss_dump_dir, exist_ok=True) torch.save( @@ -1402,26 +1660,24 @@ def _compute_p_kl( self, logits: torch.Tensor, data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], - ) -> tuple[torch.Tensor, torch.Tensor]: - """P-KL: chunk-averaged KL over the projected teacher-vocab subset. - - Mirrors the PT reference ``compute_KL_loss_optimized`` non-exact-match - branch: chunk-averages student-projected probs over each aligned - student span, chunk-averages teacher log-probs over the paired - teacher span, and KLs the resulting chunk distributions. - - Steps: - - 1. Compute student log-probs at ``T``, exponentiate to probs. - 2. Project full-vocab student probs through ``M`` to teacher vocab. - 3. Slice projection to the per-sample global top-k teacher columns - (carried in ``teacher_topk_indices [B, k]``). - 4. Build per-token chunk masks (one-hot from ``chunk_id``) for both - sides, then ``bmm`` to chunk-sum and divide by chunk size. - 5. Renormalize student chunk distributions inside the top-k subset + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """P-KL: chunk-averaged KL over a microbatch-global top-k teacher subset. + + Mirrors the PT non-gold forward-projection path at + ``tokenalign.py:3901โ€“4100``: + + 1. Project full-vocab student probs through ``M`` to teacher vocab. + 2. Rebuild full teacher logits from the IPC handles. + 3. Compute one ``global_top_indices [k]`` per microbatch from the + teacher's importance: ``max`` over flat ``(B*T_t)``, ``topk`` + over ``V_t``. Same vocab subset across every sample/position โ€” + keeps chunk-averaged KL well-defined. + 4. Slice both the projected student probs and the teacher logits + to those ``k`` columns. + 5. Build per-token chunk masks from ``alignment_*_chunk_id`` and + chunk-average via ``bmm`` (shared helper). + 6. Renormalize student chunk distributions inside the top-k subset (PT convention: avg-then-renormalize, log). - 6. Compute teacher chunk log-probs by chunk-averaging - ``log_softmax(teacher_topk_logits / T)`` directly (same as PT). 7. Forward (or reverse) KL between chunk distributions. """ cfg = self.cfg @@ -1438,7 +1694,7 @@ def _compute_p_kl( # `torch.sparse.mm` has no BF16 kernel and the worker's autocast(BF16) # context wraps loss.backward(), so a plain `.float()` cast isn't # enough โ€” the backward kernel is still dispatched as BF16. - M = self._load_M(device) # [V_s, V_t] sparse CSR, fp32 + M = self._load_M(device) # [V_s, V_t] sparse COO, fp32 flat = student_probs.reshape(b * t_s, v_s) # _Fp32SparseMM internally computes M.t() @ dense; passing M (not # M.t()) avoids a sparse `.t()` on a saved tensor in backward. @@ -1446,73 +1702,58 @@ def _compute_p_kl( v_t = projected_full.shape[-1] projected_full = projected_full.reshape(b, t_s, v_t) # [B, T_s, V_t] - # Per-sample slice to global top-k teacher columns. Teacher logits - # either arrive as a dense [B, T_t, k] tensor (CPU/Ray transport, - # for k=64) or as a list[B] of per-sample CUDA IPC handles (for - # k=8192 where the CPU round-trip would be ~6 GB/step). - teacher_topk_indices = data["teacher_topk_indices"] # [B, k] - if "teacher_topk_logits_ipc" in data: - handles = data["teacher_topk_logits_ipc"] # list[mbs] of dicts - assert len(handles) == teacher_topk_indices.shape[0], ( - f"IPC handle list length ({len(handles)}) must match " - f"teacher_topk_indices batch dim " - f"({teacher_topk_indices.shape[0]}). Sharding pairing has " - f"diverged โ€” investigate before trusting the loss." - ) - from nemo_rl.models.policy.utils import rebuild_cuda_tensor_from_ipc - consumer_device = torch.cuda.current_device() - vals_per_sample = [ - rebuild_cuda_tensor_from_ipc(h["logits_ipc"], consumer_device) - for h in handles - ] - teacher_topk_logits = torch.stack(vals_per_sample, dim=0).float() - else: - teacher_topk_logits = data["teacher_topk_logits"].float() # [B, T_t, k] - _, k = teacher_topk_indices.shape - t_t = teacher_topk_logits.shape[1] - idx_for_proj = teacher_topk_indices.unsqueeze(1).expand(-1, t_s, -1) - projected_topk = torch.gather( - projected_full, dim=-1, index=idx_for_proj - ) # [B, T_s, k] - - # Teacher target log-probs over the top-k subset. PT renormalizes - # softmax over only the kept columns; we follow the same convention. + # Rebuild full teacher logits from the IPC handles. Same transport + # as the gold path consumes; here we additionally compute a + # microbatch-global top-k inline to match PT. + teacher_full_logits = self._rebuild_teacher_full_logits(data) # [B, T_t, V_t_model] + # HF models commonly pad lm_head out_features beyond len(tokenizer) + # for embedding/FFN alignment (e.g. Qwen3: tokenizer 151669, + # lm_head 151936). The projection matrix is sized to the real + # tokenizer vocab (`cfg["teacher_vocab_size"]`); the padded + # columns aren't real tokens and the projection has no entries + # there. Slice to the projection's V_t to keep the projected + # student probs and the teacher logits on the same vocab axis. + if teacher_full_logits.shape[-1] > v_t: + teacher_full_logits = teacher_full_logits[..., :v_t] + + # PT global_top_indices: max over flat (B*T_t) โ†’ [V_t] โ†’ topk โ†’ [k]. + vocab_topk = min(cfg["vocab_topk"], v_t) + with torch.no_grad(): + teacher_flat = teacher_full_logits.view(-1, v_t) + global_importance = teacher_flat.max(dim=0).values + global_top_indices = torch.topk( + global_importance, k=vocab_topk, dim=-1 + ).indices + global_top_indices = global_top_indices.sort().values # [k] + + # Slice both sides to the shared [k] columns. + projected_topk = projected_full[..., global_top_indices] # [B, T_s, k] + teacher_topk_logits = teacher_full_logits[..., global_top_indices] # [B, T_t, k] target_log_probs = torch.log_softmax( teacher_topk_logits / T, dim=-1 - ) # [B, T_t, k] + ) # [B, T_t, k] (renormalized within the [k] subset, matching PT). - # Build chunk masks via one-hot from the chunk_id tensors. -1 - # entries (no chunk) compare false everywhere and stay out. + # Chunk-average both sides via the shared helper. student_chunk_id = data["alignment_student_chunk_id"] # [B, T_s] long teacher_chunk_id = data["alignment_teacher_chunk_id"] # [B, T_t] long pair_valid = data["alignment_pair_valid"] # [B, max_pairs] if cfg["exact_token_match_only"]: pair_valid = pair_valid & data["alignment_pair_is_correct"] max_chunks = pair_valid.shape[1] - chunk_arange = torch.arange(max_chunks, device=device).view(1, 1, -1) - proj_mask = student_chunk_id.unsqueeze(-1) == chunk_arange # [B, T_s, C] - tgt_mask = teacher_chunk_id.unsqueeze(-1) == chunk_arange # [B, T_t, C] - - # Chunk-aggregate via bmm: sum over positions in each chunk. - proj_mask_f = proj_mask.transpose(1, 2).to(projected_topk.dtype) - tgt_mask_f = tgt_mask.transpose(1, 2).to(target_log_probs.dtype) - proj_chunks = torch.bmm(proj_mask_f, projected_topk) # [B, C, k] - tgt_log_chunks = torch.bmm(tgt_mask_f, target_log_probs) # [B, C, k] - - proj_sizes = proj_mask.sum(dim=1).float() # [B, C] - tgt_sizes = tgt_mask.sum(dim=1).float() # [B, C] - proj_chunks = proj_chunks / (proj_sizes.unsqueeze(-1) + eps) - tgt_log_chunks = tgt_log_chunks / (tgt_sizes.unsqueeze(-1) + eps) - - # PT: renormalize projected chunk distribution within the top-k + proj_chunks, proj_sizes = self._chunk_average_log_probs( + projected_topk, student_chunk_id, max_chunks + ) # [B, C, k] / [B, C] + tgt_log_chunks, tgt_sizes = self._chunk_average_log_probs( + target_log_probs, teacher_chunk_id, max_chunks + ) # [B, C, k] / [B, C] + + # PT: renormalize the projected chunk distribution within the top-k # subset, then take log. Teacher side is already log-probs (avg of - # log_softmaxes is what PT computes; not a true log of mean). + # log_softmaxes; not a true log of mean โ€” matches PT). proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + eps) proj_log_chunks = (proj_chunks + eps).log() - chunk_mask = ( - (proj_sizes > 0) & (tgt_sizes > 0) & pair_valid - ) # [B, C] + chunk_mask = self._valid_chunk_mask(proj_sizes, tgt_sizes, pair_valid) if not chunk_mask.any(): zero = torch.zeros((), device=device, dtype=proj_log_chunks.dtype) return ( @@ -1523,8 +1764,7 @@ def _compute_p_kl( # Projection top-1 accuracy: per-chunk argmax of the student-side # projected distribution vs the teacher's argmax over the same - # top-k subset. Mirrors PT reference at - # tokenalign.py:4097-4104 โ€” gives a KD-specific accuracy signal. + # top-k subset. Mirrors PT reference at tokenalign.py:4097โ€“4104. with torch.no_grad(): proj_top1 = proj_chunks.argmax(dim=-1) # [B, C] tgt_top1 = torch.exp(tgt_log_chunks).argmax(dim=-1) # [B, C] @@ -1551,6 +1791,175 @@ def _compute_p_kl( kl_loss = (per_chunk_kl * valid).sum() / denom * (T * T) return kl_loss, valid.sum().detach(), proj_acc.detach() + def _compute_gold( + self, + logits: torch.Tensor, + data: BatchedDataDict[CrossTokenizerDistillationLossDataDict], + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + ]: + """Gold-loss path: KL on common (exact-mapped) vocab + L1 on uncommon. + + Ports PT ``compute_KL_loss_optimized`` lines 3494โ€“3829. + + 1. Lazy-build the exact-token map (cached per device). + 2. Rebuild full teacher logits from the IPC handles. + 3. ``log_softmax`` on full vocab both sides; chunk-average via the + shared helper. + 4. Slice each chunk-averaged tensor to ``common_*`` indices and + compute (forward or reverse) KL, reduced as + ``sum / chunk_mask.sum()``. + 5. Slice to ``uncommon_*`` indices, ``.exp()`` to probs, sort/topk + descending (capped at ``cfg['uncommon_topk']``), truncate to + ``min(student_len, teacher_len)``, L1 with ``reduction="none"`` + summed over vocab and meaned across valid chunks. + 6. Combine: ``loss = (kl_common + l1_uncommon) * T**2``. + 7. Top-1 accuracy on the common slice over valid chunks. + + Returns ``(loss, kl_common, l1_uncommon, num_valid_chunks, top1_acc)``. + Components other than ``loss`` are detached. + """ + cfg = self.cfg + T = cfg["temperature"] + device = logits.device + + exact_map = self._build_exact_token_map(device) + common_s = exact_map["common_student"] + common_t = exact_map["common_teacher"] + uncommon_s = exact_map["uncommon_student"] + uncommon_t = exact_map["uncommon_teacher"] + v_teacher = self.cfg["teacher_vocab_size"] + + teacher_full_logits = self._rebuild_teacher_full_logits(data) # [B, T_t, V_t_model] + # Drop any padded lm_head vocab beyond the real tokenizer vocab โ€” + # the exact-token map's t-axis is bounded by `teacher_vocab_size`, + # so chunked teacher log-probs must use the same axis. See the + # matching note in `_compute_p_kl` for why the model vocab can + # exceed `len(tokenizer)`. + if teacher_full_logits.shape[-1] > v_teacher: + teacher_full_logits = teacher_full_logits[..., :v_teacher] + + student_log_probs = torch.log_softmax(logits.float() / T, dim=-1) # [B, T_s, V_s] + teacher_log_probs = torch.log_softmax(teacher_full_logits / T, dim=-1) # [B, T_t, V_t] + + student_chunk_id = data["alignment_student_chunk_id"] + teacher_chunk_id = data["alignment_teacher_chunk_id"] + pair_valid = data["alignment_pair_valid"] + max_chunks = pair_valid.shape[1] + student_chunks, s_sizes = self._chunk_average_log_probs( + student_log_probs, student_chunk_id, max_chunks + ) # [B, C, V_s] / [B, C] + teacher_chunks, t_sizes = self._chunk_average_log_probs( + teacher_log_probs, teacher_chunk_id, max_chunks + ) # [B, C, V_t] / [B, C] + + chunk_mask = self._valid_chunk_mask(s_sizes, t_sizes, pair_valid) + zero_dtype = student_log_probs.dtype + if not chunk_mask.any(): + zero = torch.zeros((), device=device, dtype=zero_dtype) + return ( + zero, + zero.detach(), + zero.detach(), + torch.zeros((), device=device, dtype=torch.long), + zero.detach(), + ) + + # ---------------------- KL on common ---------------------- + if common_s.numel() > 0: + student_common = student_chunks[:, :, common_s] # [B, C, N_common] + teacher_common = teacher_chunks[:, :, common_t] # [B, C, N_common] + if cfg["reverse_kl"]: + kl_per_elem = torch.nn.functional.kl_div( + teacher_common, student_common, + reduction="none", log_target=True, + ) + else: + kl_per_elem = torch.nn.functional.kl_div( + student_common, teacher_common, + reduction="none", log_target=True, + ) + kl_per_chunk = kl_per_elem.sum(dim=-1) * chunk_mask # [B, C] + kl_common = kl_per_chunk.sum() / chunk_mask.sum().float().clamp( + min=1.0 + ) + else: + kl_common = torch.zeros( + (), device=device, dtype=zero_dtype, requires_grad=True + ) + student_common = None + teacher_common = None + + # -------------------- L1 on uncommon ---------------------- + uncommon_topk = cfg["uncommon_topk"] + if uncommon_s.numel() > 0 or uncommon_t.numel() > 0: + student_unc = student_chunks[:, :, uncommon_s][chunk_mask] # [N_valid, N_u_s] + teacher_unc = teacher_chunks[:, :, uncommon_t][chunk_mask] # [N_valid, N_u_t] + n_valid = student_unc.shape[0] + max_uncommon = min( + student_unc.shape[-1], + teacher_unc.shape[-1], + uncommon_topk, + ) + if n_valid > 0 and max_uncommon > 0: + student_unc_probs = student_unc.exp() + teacher_unc_probs = teacher_unc.exp() + if student_unc_probs.shape[-1] > max_uncommon: + student_sorted = torch.topk( + student_unc_probs, k=max_uncommon, dim=-1, largest=True + ).values + else: + student_sorted = student_unc_probs.sort( + dim=-1, descending=True + ).values + if teacher_unc_probs.shape[-1] > max_uncommon: + teacher_sorted = torch.topk( + teacher_unc_probs, k=max_uncommon, dim=-1, largest=True + ).values + else: + teacher_sorted = teacher_unc_probs.sort( + dim=-1, descending=True + ).values + min_len = min( + student_sorted.shape[-1], teacher_sorted.shape[-1] + ) + student_sorted = student_sorted[:, :min_len] + teacher_sorted = teacher_sorted[:, :min_len] + l1_per_chunk = torch.nn.functional.l1_loss( + student_sorted, teacher_sorted, reduction="none" + ).sum(dim=-1) + l1_uncommon = l1_per_chunk.mean() + else: + l1_uncommon = torch.zeros( + (), device=device, dtype=zero_dtype, requires_grad=True + ) + else: + l1_uncommon = torch.zeros( + (), device=device, dtype=zero_dtype, requires_grad=True + ) + + # -------------------- Top-1 accuracy ---------------------- + with torch.no_grad(): + if student_common is not None: + s_common_valid = student_common[chunk_mask] + t_common_valid = teacher_common[chunk_mask] + matches = ( + s_common_valid.argmax(dim=-1) + == t_common_valid.argmax(dim=-1) + ).sum().float() + top1_acc = matches / chunk_mask.sum().float().clamp(min=1.0) + else: + top1_acc = torch.zeros((), device=device, dtype=zero_dtype) + + loss = (kl_common + l1_uncommon) * (T * T) + return ( + loss, + kl_common.detach(), + l1_uncommon.detach(), + chunk_mask.sum().detach(), + top1_acc.detach(), + ) + def _compute_ce( self, logits: torch.Tensor, diff --git a/nemo_rl/algorithms/xtoken_distillation.py b/nemo_rl/algorithms/xtoken_distillation.py index 6bb7eb6de5..022a832cb2 100644 --- a/nemo_rl/algorithms/xtoken_distillation.py +++ b/nemo_rl/algorithms/xtoken_distillation.py @@ -63,17 +63,17 @@ # fn can index them per-microbatch, but the worker's `check_sequence_dim` # pre-flight (which assumes [B, student_seq, ...] for every 2+D tensor) must # skip them. Sources: -# - teacher_topk_*: produced by GlobalTopkLogitsPostProcessor in -# dtensor_policy_worker_v2.get_global_topk_logits. +# - teacher_full_logits_ipc: list[B] of CUDA IPC handle dicts produced by +# FullLogitsPostProcessor in dtensor_policy_worker_v2.get_full_logits_ipc. +# Not a tensor at all โ€” list of dicts โ€” but listed here so the worker's +# dict-level dim check skips it. # - teacher_input_ids/teacher_token_mask + alignment_*: produced by # CrossTokenizerCollator (in DataLoader workers). # alignment_student_chunk_id and alignment_student_exact_partition_mask are # [B, T_s] and DO follow the student-seq invariant, so they are NOT listed. XTOKEN_NON_STUDENT_SEQ_KEYS: frozenset[str] = frozenset( { - "teacher_topk_logits", - "teacher_topk_logits_ipc", - "teacher_topk_indices", + "teacher_full_logits_ipc", "teacher_input_ids", "teacher_token_mask", "alignment_student_spans", @@ -97,8 +97,6 @@ class XTokenDistillationConfig(TypedDict): num_prompts_per_step: Global batch size at the dataloader level. max_num_steps: Max training steps before early stop. max_num_epochs: Max passes over the training dataset. - topk_logits_k: ``k`` passed to ``teacher_policy.get_topk_logits``. - Should equal ``loss_fn.vocab_topk``. seed: RNG seed. val_period: Validation cadence in steps. ``0`` disables validation. val_at_start: Run validation before training begins. @@ -108,7 +106,6 @@ class XTokenDistillationConfig(TypedDict): num_prompts_per_step: int max_num_steps: int max_num_epochs: int - topk_logits_k: int seed: int val_period: int val_at_start: bool @@ -176,14 +173,6 @@ def setup( logger_config = master_config["logger"] cluster_config = master_config["cluster"] - # Parity check that catches misconfigured topk values early. - assert loss_config["vocab_topk"] == distill_config["topk_logits_k"], ( - f"loss_fn.vocab_topk ({loss_config['vocab_topk']}) must equal " - f"distillation.topk_logits_k ({distill_config['topk_logits_k']}) โ€” " - f"the teacher returns top-k in teacher vocab and the loss fn must " - f"use the same k." - ) - # Backend gate: this code path is DTensor V2 only by design. assert policy_config["dtensor_cfg"]["enabled"] and policy_config["dtensor_cfg"].get( "_v2", False @@ -321,8 +310,9 @@ def setup( # Loss # ========================== # Inject the teacher's full vocab size so the projection matrix's V_t - # axis covers every teacher id GlobalTopkLogitsPostProcessor can pick. - # `len(tokenizer)` is what HF treats as the embedding/lm_head dim. + # axis covers every teacher id the loss fn's exact-token map / P-KL + # global top-k can pick. `len(tokenizer)` is what HF treats as the + # embedding/lm_head dim. loss_config = {**loss_config, "teacher_vocab_size": len(teacher_tokenizer)} loss_fn = CrossTokenizerDistillationLossFn(loss_config) @@ -378,7 +368,6 @@ def xtoken_distillation_train( val_at_end = distill_cfg["val_at_end"] max_epochs = distill_cfg["max_num_epochs"] max_steps = distill_cfg["max_num_steps"] - topk_logits_k = distill_cfg["topk_logits_k"] if val_at_start and total_steps == 0 and val_dataloader is not None: val_metrics, val_timings = validate( @@ -416,29 +405,27 @@ def xtoken_distillation_train( token_mask=batch["teacher_token_mask"], sample_mask=batch["sample_mask"], ) - # IPC transport: per-sample [T_t, k] logit views are - # exported as CUDA IPC handles and consumed by the - # student in-process. Indices are tiny and go through - # the standard CPU/Ray path. - teacher_handles, teacher_indices = ( - teacher_policy.get_global_topk_logits_ipc( - teacher_data, k=topk_logits_k, timer=timer - ) + # IPC transport: per-sample [T_t, V_t] full-vocab logit + # views are exported as CUDA IPC handles and consumed + # by the student in-process. The loss fn either uses + # full vocab (gold path) or derives a microbatch-global + # top-k from this inline (P-KL path). + teacher_handles = teacher_policy.get_full_logits_ipc( + teacher_data, timer=timer ) # Model offload frees the teacher's PARAMS to CPU; the # IPC-stashed logit tensors live in worker Python state # and survive this call. teacher_policy.offload_after_refit() - # Pack student-side training data with teacher topk and the - # alignment payload the loss fn will index into. + # Pack student-side training data with teacher logits and + # the alignment payload the loss fn will index into. train_data: BatchedDataDict[Any] = BatchedDataDict( input_ids=batch["input_ids"], input_lengths=batch["input_lengths"], token_mask=batch["token_mask"], sample_mask=batch["sample_mask"], - teacher_topk_logits_ipc=teacher_handles, - teacher_topk_indices=teacher_indices, + teacher_full_logits_ipc=teacher_handles, alignment_student_spans=batch["alignment_student_spans"], alignment_teacher_spans=batch["alignment_teacher_spans"], alignment_pair_valid=batch["alignment_pair_valid"], @@ -454,7 +441,7 @@ def xtoken_distillation_train( alignment_num_chunks=batch["alignment_num_chunks"], ) # `.to("cpu")` is a no-op on the IPC handle list (lists are - # not tensors) and on the CPU indices tensor. + # not tensors). train_data.to("cpu") with timer.time("training_prep"): @@ -502,7 +489,10 @@ def xtoken_distillation_train( metrics: dict[str, Any] = {} metrics.update(train_results["all_mb_metrics"]) - # Reduce per-microbatch metrics to per-step scalars. + # Reduce per-microbatch metrics to per-step scalars. The + # P-KL path emits kl_loss/ce_loss/kl_loss_scale/proj_accuracy; + # the gold-loss path emits kl_common/l1_uncommon. Either set + # may be present โ€” reduce both via the same rules. for k, v in metrics.items(): if k in { "lr", @@ -512,6 +502,8 @@ def xtoken_distillation_train( "accuracy", "proj_accuracy", "kl_loss_scale", + "kl_common", + "l1_uncommon", }: metrics[k] = float(np.mean(v)) else: @@ -581,32 +573,47 @@ def xtoken_distillation_train( timing_metrics: dict[str, float] = timer.get_timing_metrics( reduction_op="sum" ) # type: ignore - # `metrics["loss"]/kl_loss/ce_loss` are SUM across all DP ranks - # AND microbatches (= dp_size * local_mbs values summed). PT - # logs rank-0 per-microbatch raw, so to compare apples-to-apples - # to PT, also print per-MB-mean. n_mb = len of the flat list of - # per-MB metrics across all ranks. + # `metrics["loss"]` and the SUM-reduced terms (kl_loss, ce_loss + # for the P-KL path) are SUM across all DP ranks AND microbatches + # (= dp_size * local_mbs values summed). PT logs rank-0 + # per-microbatch raw, so for apples-to-apples we also print + # per-MB-mean. n_mb = len of the flat list of per-MB metrics. n_mb = max(len(train_results["all_mb_metrics"].get("loss", [])), 1) print( f" โ€ข Loss: {metrics['loss']:.4f} " f"(per-MB-mean: {metrics['loss'] / n_mb:.4f})", flush=True, ) - kl_sum = float(metrics.get("kl_loss", 0.0)) - ce_sum = float(metrics.get("ce_loss", 0.0)) - print( - f" โ€ข KL: {kl_sum:.4f} " - f"(per-MB-mean: {kl_sum / n_mb:.4f})", - flush=True, - ) - print( - f" โ€ข CE: {ce_sum:.4f} " - f"(per-MB-mean: {ce_sum / n_mb:.4f})", - flush=True, - ) - # Next-token accuracy is already a mean-across-MB-ranks (we - # put it in the np.mean branch above), directly comparable to - # PT's `Acc:` log column. + # P-KL path metrics โ€” only printed when they're present. + if "kl_loss" in metrics: + kl_sum = float(metrics["kl_loss"]) + print( + f" โ€ข KL: {kl_sum:.4f} " + f"(per-MB-mean: {kl_sum / n_mb:.4f})", + flush=True, + ) + if "ce_loss" in metrics: + ce_sum = float(metrics["ce_loss"]) + print( + f" โ€ข CE: {ce_sum:.4f} " + f"(per-MB-mean: {ce_sum / n_mb:.4f})", + flush=True, + ) + # Gold-loss path metrics โ€” kl_common/l1_uncommon are already + # per-MB means (np.mean branch above), so no /n_mb division. + if "kl_common" in metrics: + print( + f" โ€ข KL(common): {metrics['kl_common']:.4f}", + flush=True, + ) + if "l1_uncommon" in metrics: + print( + f" โ€ข L1(uncommon): {metrics['l1_uncommon']:.4f}", + flush=True, + ) + # Accuracy: P-KL emits next-token student accuracy + projection + # top-1; gold emits top-1 common-vocab accuracy. Both arrive + # under "accuracy" so the same line works. if "accuracy" in metrics: print( f" โ€ข Acc: {metrics['accuracy'] * 100:.2f}%", @@ -661,12 +668,16 @@ def validate( backward / optimizer step runs. Returns mean train-style metrics. """ distill_cfg = master_config["distillation"] - topk_logits_k = distill_cfg["topk_logits_k"] timer = timer if timer is not None else Timer() losses: list[float] = [] + # The P-KL path emits kl_loss/ce_loss; the gold path emits + # kl_common/l1_uncommon. Track both, only the ones the active loss + # populates will end up in the returned metrics. kl_losses: list[float] = [] ce_losses: list[float] = [] + kl_common_losses: list[float] = [] + l1_uncommon_losses: list[float] = [] with timer.time("validation_total"): teacher_policy.prepare_for_lp_inference() @@ -677,19 +688,14 @@ def validate( token_mask=batch["teacher_token_mask"], sample_mask=batch["sample_mask"], ) - teacher_handles, teacher_indices = ( - teacher_policy.get_global_topk_logits_ipc( - teacher_data, k=topk_logits_k - ) - ) + teacher_handles = teacher_policy.get_full_logits_ipc(teacher_data) train_data: BatchedDataDict[Any] = BatchedDataDict( input_ids=batch["input_ids"], input_lengths=batch["input_lengths"], token_mask=batch["token_mask"], sample_mask=batch["sample_mask"], - teacher_topk_logits_ipc=teacher_handles, - teacher_topk_indices=teacher_indices, + teacher_full_logits_ipc=teacher_handles, alignment_student_spans=batch["alignment_student_spans"], alignment_teacher_spans=batch["alignment_teacher_spans"], alignment_pair_valid=batch["alignment_pair_valid"], @@ -721,6 +727,12 @@ def validate( kl_losses.append(float(np.mean(mb_metrics["kl_loss"]))) if "ce_loss" in mb_metrics: ce_losses.append(float(np.mean(mb_metrics["ce_loss"]))) + if "kl_common" in mb_metrics: + kl_common_losses.append(float(np.mean(mb_metrics["kl_common"]))) + if "l1_uncommon" in mb_metrics: + l1_uncommon_losses.append( + float(np.mean(mb_metrics["l1_uncommon"])) + ) teacher_policy.offload_after_refit() metrics: dict[str, Any] = { @@ -730,5 +742,9 @@ def validate( metrics["kl_loss"] = float(np.mean(kl_losses)) if ce_losses: metrics["ce_loss"] = float(np.mean(ce_losses)) + if kl_common_losses: + metrics["kl_common"] = float(np.mean(kl_common_losses)) + if l1_uncommon_losses: + metrics["l1_uncommon"] = float(np.mean(l1_uncommon_losses)) return metrics, timer.get_timing_metrics(reduction_op="sum") # type: ignore diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 5aedd9b8bd..23ec86cae6 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -58,7 +58,7 @@ "LossPostProcessor", "LogprobsPostProcessor", "TopkLogitsPostProcessor", - "GlobalTopkLogitsPostProcessor", + "FullLogitsPostProcessor", "ScorePostProcessor", ] @@ -328,7 +328,7 @@ def forward_with_post_processing_fn( LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor, - GlobalTopkLogitsPostProcessor, + FullLogitsPostProcessor, ), ): # Temperature scaling is element-wise, directly applying it here. @@ -348,7 +348,7 @@ def forward_with_post_processing_fn( ) elif isinstance( post_processing_fn, - (LogprobsPostProcessor, TopkLogitsPostProcessor, GlobalTopkLogitsPostProcessor), + (LogprobsPostProcessor, TopkLogitsPostProcessor), ): result = post_processing_fn( logits=logits, @@ -363,6 +363,16 @@ def forward_with_post_processing_fn( else: vals, idx = result metrics = {"topk_logits": vals, "topk_indices": idx} + elif isinstance(post_processing_fn, FullLogitsPostProcessor): + result = post_processing_fn( + logits=logits, + data_dict=data_dict, + processed_inputs=processed_inputs, + original_batch_size=processed_mb.original_batch_size, + original_seq_len=processed_mb.original_seq_len, + sequence_dim=sequence_dim, + ) + metrics = {"full_logits": result} elif isinstance(post_processing_fn, ScorePostProcessor): result = post_processing_fn(logits=logits) metrics = {"scores": result} @@ -936,23 +946,23 @@ def __call__( return vals, idx -class GlobalTopkLogitsPostProcessor: - """Post-processor for global top-k logits (single vocab subset per sample). +class FullLogitsPostProcessor: + """Post-processor that returns the full teacher vocab logits unchanged. - Used by cross-tokenizer distillation: the teacher's vocab axis is reduced - to one ``vocab_topk`` subset *per sample* (max over positions, then top-k - over vocab) so the same vocab columns are kept across every position. - This makes chunk-averaged KL well-defined when student and teacher chunks - span multiple positions โ€” every position carries the same vocab axis. + Used by cross-tokenizer distillation: the loss fn needs the entire + ``[B, S, V_t]`` teacher logits tensor โ€” no vocab reduction is done at + the worker. The loss fn either (a) derives a microbatch-global top-k + subset internally (``gold_loss=False`` path, matching PT + ``global_top_indices`` math) or (b) operates on full vocab directly + (``gold_loss=True`` path, matching PT gold). Doing the reduction in + the loss fn (not here) keeps transport faithful to the PT reference. Output: - vals: ``[B, S, k]`` teacher logits at the selected ``k`` columns. - idx: ``[B, k]`` teacher vocab ids selected per sample. + logits: ``[B, S, V_t]`` raw teacher logits cast to ``float32``. - v0 limitation: only the no-TP, no-CP, no-seq-packing path is implemented. - The post-processor asserts on the unsupported configurations because - distributed global top-k requires TP-aware reduction that isn't on the - smoke-test path. + v0 limitation: only the no-TP, no-CP, no-seq-packing path is + implemented. Asserts on the unsupported configurations โ€” distributed + full-vocab gather requires TP-aware reduction not on the smoke path. """ def __init__( @@ -962,7 +972,6 @@ def __init__( cp_mesh: Any, tp_mesh: Any, cp_size: int, - k: int, enable_seq_packing: bool = False, ): self.cfg = cfg @@ -970,7 +979,6 @@ def __init__( self.cp_mesh = cp_mesh self.tp_mesh = tp_mesh self.cp_size = cp_size - self.k = k self.enable_seq_packing = enable_seq_packing def __call__( @@ -981,15 +989,15 @@ def __call__( original_batch_size: int, original_seq_len: int, sequence_dim: int = 1, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: if self.cp_size > 1: raise NotImplementedError( - "GlobalTopkLogitsPostProcessor: context_parallel_size > 1 is " + "FullLogitsPostProcessor: context_parallel_size > 1 is " "not supported in v0." ) if self.enable_seq_packing: raise NotImplementedError( - "GlobalTopkLogitsPostProcessor: sequence packing is not " + "FullLogitsPostProcessor: sequence packing is not " "supported in v0." ) if isinstance(logits, DTensor): @@ -1001,21 +1009,12 @@ def __call__( ) if tp_size > 1: raise NotImplementedError( - "GlobalTopkLogitsPostProcessor: tensor_parallel_size > 1 " + "FullLogitsPostProcessor: tensor_parallel_size > 1 " "is not supported in v0." ) logits = logits.to_local() - full_logits = logits.to(torch.float32) # [B, S, V] - # Per-sample max over positions, then top-k over vocab. - per_vocab_max = full_logits.max(dim=1).values # [B, V] - _, idx = torch.topk(per_vocab_max, k=self.k, dim=-1) # [B, k] - # Gather the selected k columns at every position for each sample. - # full_logits: [B, S, V], idx: [B, k] -> vals: [B, S, k]. - b, s, _ = full_logits.shape - idx_expanded = idx.unsqueeze(1).expand(-1, s, -1) # [B, S, k] - vals = torch.gather(full_logits, dim=-1, index=idx_expanded) - return vals, idx + return logits.to(torch.float32) # [B, S, V_t] class ScorePostProcessor: diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 9f802d227a..78ac0d8376 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -591,89 +591,35 @@ def get_topk_logits( return stacked - def get_global_topk_logits( + def get_full_logits_ipc( self, data: BatchedDataDict[GenerationDatumSpec], - k: int, micro_batch_size: Optional[int] = None, timer: Optional[Timer] = None, - ) -> BatchedDataDict[Any]: - """Dispatch get_global_topk_logits to workers. + ) -> list[dict[str, Any]]: + """Ship the teacher's full-vocab logits to the student via CUDA IPC. - Returns a BatchedDataDict with: - - ``topk_logits``: ``[B, S, k]`` - - ``topk_indices``: ``[B, k]`` per-sample global vocab subset. - - Used by cross-tokenizer distillation; intentionally doesn't support - dynamic batching or sequence packing in v0. - """ - if self.use_dynamic_batches or self.use_sequence_packing: - raise NotImplementedError( - "get_global_topk_logits does not support dynamic batching or " - "sequence packing in v0." - ) - dp_size = self.sharding_annotations.get_axis_size("data_parallel") - with timer.time("get_global_topk_logits/shard_data") if timer else nullcontext(): - sharded_data = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - ) - with ( - timer.time("get_global_topk_logits/submit") - if timer - else nullcontext() - ): - futures = self.worker_group.run_all_workers_sharded_data( - "get_global_topk_logits", - data=sharded_data, - in_sharded_axes=["data_parallel"], - replicate_on_axes=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - output_is_replicated=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, - ) - worker_batches = self.worker_group.get_all_worker_results(futures) - all_vals = [wb["topk_logits"] for wb in worker_batches] - all_idx = [wb["topk_indices"] for wb in worker_batches] - stacked: BatchedDataDict[Any] = BatchedDataDict() - stacked["topk_logits"] = torch.cat(all_vals, dim=0) - stacked["topk_indices"] = torch.cat(all_idx, dim=0) - return stacked - - def get_global_topk_logits_ipc( - self, - data: BatchedDataDict[GenerationDatumSpec], - k: int, - micro_batch_size: Optional[int] = None, - timer: Optional[Timer] = None, - ) -> tuple[list[dict[str, Any]], torch.Tensor]: - """Like get_global_topk_logits but ships the logits via CUDA IPC. - - Logits are returned as a flat ``list[B]`` of per-sample IPC handle - dicts. Each entry's ``logits_ipc`` reconstructs a ``[T_t, k]`` - CUDA tensor on the consumer device. Top-k *indices* are tiny (a - few MB at k=8192) so they're shipped via the standard CPU/Ray - path as a single ``[B, k]`` tensor. + Used by cross-tokenizer distillation. Returns a flat ``list[B]`` of + per-sample IPC handle dicts; each entry's ``logits_ipc`` + reconstructs a ``[T_t, V_t]`` CUDA tensor on the consumer device. + The loss fn then either (a) derives a microbatch-global top-k + subset internally to match the PT non-gold path, or (b) uses the + full vocab end-to-end to match the PT gold path. Caller must invoke :meth:`release_ipc_buffer` after the student finishes consuming the handles โ€” otherwise the producer-side CUDA tensors leak. + + v0 limitation: no dynamic batching, no sequence packing. """ if self.use_dynamic_batches or self.use_sequence_packing: raise NotImplementedError( - "get_global_topk_logits_ipc does not support dynamic batching " + "get_full_logits_ipc does not support dynamic batching " "or sequence packing in v0." ) dp_size = self.sharding_annotations.get_axis_size("data_parallel") with ( - timer.time("get_global_topk_logits_ipc/shard_data") + timer.time("get_full_logits_ipc/shard_data") if timer else nullcontext() ): @@ -682,12 +628,12 @@ def get_global_topk_logits_ipc( batch_size=None, ) with ( - timer.time("get_global_topk_logits_ipc/submit") + timer.time("get_full_logits_ipc/submit") if timer else nullcontext() ): futures = self.worker_group.run_all_workers_sharded_data( - "get_global_topk_logits_ipc", + "get_full_logits_ipc", data=sharded_data, in_sharded_axes=["data_parallel"], replicate_on_axes=[ @@ -700,21 +646,13 @@ def get_global_topk_logits_ipc( "tensor_parallel", "pipeline_parallel", ], - common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + common_kwargs={"micro_batch_size": micro_batch_size}, ) worker_results = self.worker_group.get_all_worker_results(futures) all_handles: list[dict[str, Any]] = [] - all_idx: list[torch.Tensor] = [] for wr in worker_results: all_handles.extend(wr["per_sample_handles"]) - all_idx.append(wr["topk_indices"]) - indices = torch.cat(all_idx, dim=0) - assert len(all_handles) == indices.shape[0], ( - f"IPC handle list length ({len(all_handles)}) must match the " - f"top-k indices batch dim ({indices.shape[0]}). One of the two " - f"transports lost a sample โ€” sharding pairing is broken." - ) - return all_handles, indices + return all_handles def release_ipc_buffer(self) -> None: """Tell all workers to drop their stashed IPC tensors.""" diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 02ada237ce..ad8dc3762a 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -51,10 +51,10 @@ validate_and_prepare_config, ) from nemo_rl.models.automodel.train import ( + FullLogitsPostProcessor, LogprobsPostProcessor, LossPostProcessor, ScorePostProcessor, - GlobalTopkLogitsPostProcessor, TopkLogitsPostProcessor, aggregate_training_statistics, automodel_forward_backward, @@ -789,134 +789,35 @@ def get_topk_logits( ).cpu() return ret - def get_global_topk_logits( + def get_full_logits_ipc( self, data: BatchedDataDict[Any], - k: int, - micro_batch_size: Optional[int] = None, - ) -> BatchedDataDict[Any]: - """Return per-sample global top-k teacher logits. - - Computes one ``vocab_topk`` subset per sample (max over positions, - then top-k over vocab) and returns the full sequence's logits at - those columns. Used by cross-tokenizer distillation to enable - chunk-averaged KL with a stable vocab axis. - - Returns: - BatchedDataDict with: - - ``topk_logits``: [B, S, k] - - ``topk_indices``: [B, k] (per-sample, position-independent) - - v0: TP=1, CP=1, no sequence packing. - """ - topk_batch_size = ( - micro_batch_size - if micro_batch_size is not None - else self.cfg["logprob_batch_size"] - ) - sequence_dim, seq_dim_size = check_sequence_dim(data) - - out_topk_vals: list[torch.Tensor] = [] - out_topk_idx: list[torch.Tensor] = [] - self.model.eval() - - post_processor = GlobalTopkLogitsPostProcessor( - cfg=self.cfg, - device_mesh=self.device_mesh, - cp_mesh=self.cp_mesh, - tp_mesh=self.tp_mesh, - cp_size=self.cp_size, - k=k, - enable_seq_packing=self.enable_seq_packing, - ) - - with torch.no_grad(): - data.to("cuda") - processed_iterator, iterator_len = get_microbatch_iterator( - data, - self.cfg, - topk_batch_size, - self.dp_mesh, - tokenizer=self.tokenizer, - cp_size=self.cp_size, - ) - for batch_idx, processed_mb in enumerate(processed_iterator): - processed_inputs = processed_mb.processed_inputs - with get_train_context( - cp_size=self.cp_size, - cp_mesh=self.cp_mesh, - cp_buffers=processed_inputs.cp_buffers, - sequence_dim=sequence_dim, - dtype=self.dtype, - autocast_enabled=self.autocast_enabled, - ): - (vals, idx), _metrics, _ = forward_with_post_processing_fn( - model=self.model, - post_processing_fn=post_processor, - processed_mb=processed_mb, - is_reward_model=False, - allow_flash_attn_args=self.allow_flash_attn_args, - sampling_params=self.sampling_params, - sequence_dim=sequence_dim, - ) - if batch_idx >= iterator_len: - continue - out_topk_vals.append(vals.cpu()) - out_topk_idx.append(idx.cpu()) - - ret = BatchedDataDict[Any]() - # Pad each microbatch's vals along the sequence dim to the common - # length so concatenation on dim 0 matches the input batch shape. - # idx has no sequence dim and concatenates directly. - all_vals_padded = [] - target_seq_len = seq_dim_size - for vals in out_topk_vals: - pad_needed = target_seq_len - vals.shape[1] - if pad_needed > 0: - vals = torch.nn.functional.pad( - vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 - ) - all_vals_padded.append(vals) - ret["topk_logits"] = ( - torch.cat(all_vals_padded, dim=0) - if len(all_vals_padded) > 1 - else all_vals_padded[0] - ).cpu() - ret["topk_indices"] = ( - torch.cat(out_topk_idx, dim=0) - if len(out_topk_idx) > 1 - else out_topk_idx[0] - ).cpu() - return ret - - def get_global_topk_logits_ipc( - self, - data: BatchedDataDict[Any], - k: int, micro_batch_size: Optional[int] = None, ) -> dict[str, Any]: - """Cross-tokenizer teacher forward; logits leave via CUDA IPC. - - Same per-sample global top-k math as :meth:`get_global_topk_logits`, - but the per-sample ``[T_t, k]`` logits views are exported as CUDA IPC - handles instead of moved to CPU. Top-k *indices* are tiny - (``[B_r, k]`` int64 ~= 6 MB for k=8192, B_r=96/dp_size) and go back - through Ray on CPU like before โ€” cheap, simpler, and bookkeeping - between handles and indices is preserved by returning both from this - single call. - - Lifetime: the source ``[B_r, T_t, k]`` CUDA tensor is stashed in - ``self._teacher_ipc_buffer`` so it outlives the consumer's - ``rebuild_cuda_tensor_from_ipc`` call. The driver releases it via - :meth:`release_ipc_buffer` after the student training step finishes. + """Cross-tokenizer teacher forward; full-vocab logits leave via CUDA IPC. + + Used by cross-tokenizer distillation. Returns the full teacher + vocab logits ``[T_t, V_t]`` per sample as a CUDA IPC handle so the + student worker (a separate Ray actor on the same node) can rebuild + the tensor without a CPU round-trip. Both gold-loss and projection- + KL paths consume full vocab: PT gold operates on full vocab end to + end; PT non-gold computes its ``global_top_indices`` reduction + *inside the loss*, not at the worker. + + Lifetime: the source ``[B_r, T_t, V_t]`` CUDA tensor is stashed in + ``self._teacher_ipc_buffer`` so per-sample views remain valid + across the consumer's :func:`rebuild_cuda_tensor_from_ipc` call. + The driver releases it via :meth:`release_ipc_buffer` after the + student training step finishes. Returns: dict with: - ``per_sample_handles``: ``list[B_r]`` of dicts each carrying a single ``logits_ipc`` handle tuple plus shape/dtype. - - ``topk_indices``: CPU tensor ``[B_r, k]``. + + v0 limitation: TP=1, CP=1, no sequence packing. """ - topk_batch_size = ( + forward_batch_size = ( micro_batch_size if micro_batch_size is not None else self.cfg["logprob_batch_size"] @@ -924,16 +825,14 @@ def get_global_topk_logits_ipc( sequence_dim, seq_dim_size = check_sequence_dim(data) out_vals: list[torch.Tensor] = [] - out_idx: list[torch.Tensor] = [] self.model.eval() - post_processor = GlobalTopkLogitsPostProcessor( + post_processor = FullLogitsPostProcessor( cfg=self.cfg, device_mesh=self.device_mesh, cp_mesh=self.cp_mesh, tp_mesh=self.tp_mesh, cp_size=self.cp_size, - k=k, enable_seq_packing=self.enable_seq_packing, ) @@ -942,7 +841,7 @@ def get_global_topk_logits_ipc( processed_iterator, iterator_len = get_microbatch_iterator( data, self.cfg, - topk_batch_size, + forward_batch_size, self.dp_mesh, tokenizer=self.tokenizer, cp_size=self.cp_size, @@ -957,7 +856,7 @@ def get_global_topk_logits_ipc( dtype=self.dtype, autocast_enabled=self.autocast_enabled, ): - (vals, idx), _metrics, _ = forward_with_post_processing_fn( + vals, _metrics, _ = forward_with_post_processing_fn( model=self.model, post_processing_fn=post_processor, processed_mb=processed_mb, @@ -969,21 +868,17 @@ def get_global_topk_logits_ipc( if batch_idx >= iterator_len: continue # Keep vals on CUDA for IPC; pad seq dim now so the stash - # tensor matches the canonical [B_r, T_t, k] shape. + # tensor matches the canonical [B_r, T_t, V_t] shape. pad_needed = seq_dim_size - vals.shape[1] if pad_needed > 0: vals = torch.nn.functional.pad( vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 ) out_vals.append(vals.contiguous()) - out_idx.append(idx.cpu()) final_vals = ( torch.cat(out_vals, dim=0) if len(out_vals) > 1 else out_vals[0] - ) # CUDA [B_r, T_t, k] - final_idx = ( - torch.cat(out_idx, dim=0) if len(out_idx) > 1 else out_idx[0] - ) # CPU [B_r, k] + ) # CUDA [B_r, T_t, V_t] # Stash the full tensor so per-sample views remain valid across the # student's train call. Cleared by release_ipc_buffer(). @@ -993,7 +888,7 @@ def get_global_topk_logits_ipc( per_sample_handles: list[dict[str, Any]] = [] for i in range(final_vals.shape[0]): - view_i = final_vals[i] # [T_t, k] view; dim-0 slice of row-major + view_i = final_vals[i] # [T_t, V_t] view; dim-0 slice of row-major per_sample_handles.append( { "logits_ipc": get_handle_from_tensor(view_i), @@ -1001,10 +896,7 @@ def get_global_topk_logits_ipc( "dtype": view_i.dtype, } ) - return { - "per_sample_handles": per_sample_handles, - "topk_indices": final_idx, - } + return {"per_sample_handles": per_sample_handles} def release_ipc_buffer(self) -> None: """Drop the stashed teacher logits and reclaim GPU memory. From 7bfad34d8bf884e57b767d98d2ad9b3797c84d06 Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Thu, 14 May 2026 11:36:43 -0700 Subject: [PATCH 4/6] feat(xtoken): add projection-prep CLI utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds nemo_rl/utils/x_token/ โ€” the four offline CLI tools that produce the projection matrix consumed at training time by CrossTokenizerDistillationLossFn: - minimal_projection_generator.py: embedding-similarity top-k between student and teacher vocabularies. - minimal_projection_via_multitoken.py: augments the projection with multi-token mappings (e.g. "12" -> "1", "2") in both directions. - reapply_exact_map.py: pins literally-identical tokens to 1-to-1 mappings with weight 1.0. - sort_and_cut_projection_matrix.py: trims to runtime top_k; reads enable_scale_trick metadata to auto-enable --preserve_last. All four scripts take --student-model and --teacher-model as required arguments; the projection direction follows the CLI args exactly (no alphabetical-order swap). Common helpers (add_multitoken_mappings, print_projection_map_examples, print_projection_statistics, clean_model_name_for_filename) live at module scope. Cross-script dict keys renamed model_A_id -> student_model_id (legacy keys still read). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithya Hanasoge --- nemo_rl/utils/x_token/__init__.py | 13 + .../x_token/minimal_projection_generator.py | 553 +++++++++++ .../minimal_projection_via_multitoken.py | 936 ++++++++++++++++++ nemo_rl/utils/x_token/reapply_exact_map.py | 228 +++++ .../x_token/sort_and_cut_projection_matrix.py | 462 +++++++++ 5 files changed, 2192 insertions(+) create mode 100644 nemo_rl/utils/x_token/__init__.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_generator.py create mode 100644 nemo_rl/utils/x_token/minimal_projection_via_multitoken.py create mode 100644 nemo_rl/utils/x_token/reapply_exact_map.py create mode 100644 nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py diff --git a/nemo_rl/utils/x_token/__init__.py b/nemo_rl/utils/x_token/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/nemo_rl/utils/x_token/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_rl/utils/x_token/minimal_projection_generator.py b/nemo_rl/utils/x_token/minimal_projection_generator.py new file mode 100644 index 0000000000..877a1669c4 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_generator.py @@ -0,0 +1,553 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import re + +import torch +from tqdm.auto import tqdm +from transformers import AutoConfig, AutoModel, AutoTokenizer + + +EXACT_MATCH_ONLY = False + + +def parse_arguments() -> argparse.Namespace: + """Parse CLI arguments for projection-matrix generation.""" + parser = argparse.ArgumentParser( + description="Generate a sparse projection map between a student and a teacher tokenizer.", + ) + parser.add_argument( + "--student-model", + type=str, + required=True, + help="HuggingFace model name for the student tokenizer (source vocabulary).", + ) + parser.add_argument( + "--teacher-model", + type=str, + required=True, + help="HuggingFace model name for the teacher tokenizer (target vocabulary).", + ) + parser.add_argument( + "--keep_top_tokens", + type=int, + default=-1, + help="Number of top tokens to keep for each vocabulary. -1 means all.", + ) + parser.add_argument( + "--data_dir", + type=str, + default="cross_tokenizer_data/", + help="Directory for importance scores and cached data.", + ) + parser.add_argument( + "--top_k", + type=int, + default=10, + help="Number of top projections to keep for each token.", + ) + parser.add_argument( + "--weight_threshold", + type=float, + default=0.0, + help="Minimum weight threshold to keep a projection. Values below this will be filtered out.", + ) + parser.add_argument( + "--force_recompute", + action="store_true", + help="Force recomputation of embeddings even if cached files exist.", + ) + parser.add_argument( + "--use_canonicalization", + action="store_true", + help=( + "Apply token canonicalization before generating embeddings to normalize " + "different tokenizer representations (e.g. ฤ  vs โ– prefixes, ฤŠ vs \\n)." + ), + ) + return parser.parse_args() + + +args = parse_arguments() + +if args.student_model == args.teacher_model: + raise ValueError( + f"Cannot use the same model for both student and teacher: {args.student_model}" + ) + +EMBEDDING_MODEL_CHOICES = [ + {"name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-mpnet-base-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-MiniLM-L6-v2", "type": "sbert"}, + {"name": "Qwen/Qwen3-Embedding-4B", "type": "llm_first_layer"}, + {"name": "Qwen/Qwen3-Embedding-0.6B", "type": "llm_first_layer"}, +] + +MAX_SEQ_LENGTH_EMBEDDING = 64 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def sinkhorn(A, n_iters=10): + for _ in range(n_iters): + if _ % 2 == 0: + # A = A / (A.sum(dim=0, keepdim=True) + 1e-6) + col_sums = A.sum(dim=0, keepdim=True) + safe_col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums) + A = A / safe_col_sums + else: + #0, 2, 4, 6 + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +# --- Helper Functions --- + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + if 'mini' in name: + cleaned_name += "_mini" + return cleaned_name + +def load_tokenizer(model_id_or_path): + """Loads a HuggingFace tokenizer, setting a pad token if necessary.""" + tok = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +def save_data(data, filename): + """Saves data to a torch file.""" + os.makedirs(os.path.dirname(filename), exist_ok=True) + torch.save(data.cpu(), filename) + print(f"Data saved to {filename}") + +def load_data(filename): + """Loads data from a torch file.""" + return torch.load(filename) + +def get_llm_first_layer_embeddings(decoded_tokens_list, llm_embedding_tokenizer, llm_embedding_model, max_seq_length_embedding, device, batch_size=32): + """Generates embeddings using the first layer of a given LLM.""" + all_embeddings = [] + llm_embedding_model.eval() + embedding_dim = llm_embedding_model.config.hidden_size + + for i in tqdm(range(0, len(decoded_tokens_list), batch_size), desc="Encoding tokens with LLM"): + batch_tokens = decoded_tokens_list[i:i + batch_size] + inputs = llm_embedding_tokenizer( + batch_tokens, return_tensors="pt", padding=True, truncation=True, + max_length=max_seq_length_embedding, add_special_tokens=False, + ).to(device) + + with torch.no_grad(): + outputs = llm_embedding_model(**inputs, output_hidden_states=True) + first_layer_output = outputs.hidden_states[0] + + for k in range(first_layer_output.shape[0]): + valid_token_mask = inputs['attention_mask'][k] == 1 + if valid_token_mask.sum() > 0: + pooled_embedding = first_layer_output[k, valid_token_mask].mean(dim=0) + all_embeddings.append(pooled_embedding) + else: + all_embeddings.append(torch.zeros(embedding_dim, device=device)) + + return torch.stack(all_embeddings).to(device) + + +def compute_chunked_projection_map(embeddings_query, embeddings_corpus, args, device, chunk_size=1000): + """Computes projection map in chunks to save memory.""" + num_queries = embeddings_query.shape[0] + target_vocab_size = embeddings_corpus.shape[0] + + # Pre-allocate result tensors + all_top_k_indices = torch.zeros((num_queries, args.top_k), dtype=torch.long) + all_top_k_likelihoods = torch.zeros((num_queries, args.top_k), dtype=torch.float32) + + # Normalize corpus embeddings once + embeddings_corpus_norm = torch.nn.functional.normalize(embeddings_corpus.to(device).float(), p=2, dim=1) + + for chunk_start in tqdm(range(0, num_queries, chunk_size), desc="Processing chunks"): + chunk_end = min(chunk_start + chunk_size, num_queries) + chunk_query = embeddings_query[chunk_start:chunk_end].to(device).float() + + with torch.no_grad(): + # Compute similarities for this chunk + chunk_query_norm = torch.nn.functional.normalize(chunk_query, p=2, dim=1) + similarities = torch.matmul(chunk_query_norm, embeddings_corpus_norm.t()) + + # Generate projection map for this chunk + chunk_top_k_indices, chunk_top_k_likelihoods = generate_projection_map_chunk(similarities, args) + + # Store results + all_top_k_indices[chunk_start:chunk_end] = chunk_top_k_indices.cpu() + all_top_k_likelihoods[chunk_start:chunk_end] = chunk_top_k_likelihoods.cpu() + + # Clear GPU memory + del similarities, chunk_query_norm, chunk_top_k_indices, chunk_top_k_likelihoods + torch.cuda.empty_cache() + + return all_top_k_indices, all_top_k_likelihoods + +def generate_projection_map_chunk(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix chunk.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Normalize rows + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + threshold_mask = top_k_likelihood >= args.weight_threshold + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + return top_k_indices, top_k_likelihood + +def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + +def debug_projection_map(top_k_indices, top_k_likelihood, source_tokenizer, target_tokenizer, direction="", N=2000): + """Debug function to show first N rows with decoded tokens and weights.""" + N = min(N, top_k_indices.shape[0]) # Show first N rows or less + print(f"\n--- Debugging projection map {direction} (first {N} rows) ---") + + for row_idx in range(N): + # for row_idx in range(-N,-1): + # Decode source token + try: + token_id = row_idx if row_idx >= 0 else top_k_indices.shape[0] + row_idx + source_token = source_tokenizer.decode([token_id]) + # source_token = source_tokenizer.convert_ids_to_tokens([token_id])[0] + source_token_str = repr(source_token) # Use repr to show special chars + except Exception: + source_token_str = f"" + + # Build the target tokens with weights string + row_indices = top_k_indices[row_idx].cpu().numpy() + row_weights = top_k_likelihood[row_idx].float().cpu().numpy() + + weight_total = 0 + target_parts = [] + + if row_weights.max() != row_weights[-1]: + continue + + + for target_idx, weight in zip(row_indices, row_weights): + try: + target_token = target_tokenizer.decode([target_idx]) + target_token_str = repr(target_token) + except Exception: + target_token_str = f"" + + + target_parts.append(f"{target_token_str}({weight:.4f})") + weight_total += weight + + target_string = " ".join(target_parts) + # print(f"Weight total: {weight_total:.4f}") + print(f"{source_token_str} -> {target_string}") + +def generate_projection_map(similarities, args): + """Calculates the sparse likelihood map from a similarity matrix.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Create a sparse representation by keeping only top-k values + # top_k_likelihood_pre_norm, _ = likelihood.topk(args.top_k, dim=1) + # likelihood = likelihood.where(likelihood >= top_k_likelihood_pre_norm[:, -1:], torch.zeros_like(likelihood)) + + # Normalize the row to sum to 1, handling rows that are all zero + # row_sums = likelihood.sum(dim=1, keepdim=True) + # safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + # likelihood = likelihood / safe_row_sums + # likelihood = sinkhorn_one_dim(likelihood) + + # Get the final top-k values and their indices from the sparse, normalized likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Store top-k values before zeroing (to avoid losing them) + row_indices = torch.arange(likelihood.shape[0]).unsqueeze(1).expand(-1, args.top_k) + top_k_values = likelihood[row_indices, top_k_indices].clone() + + # Zero out entire likelihood matrix in-place, then restore only top-k elements + likelihood.zero_() + likelihood[row_indices, top_k_indices] = top_k_values + + # likelihood = sinkhorn(likelihood, n_iters=1) + # likelihood = sinkhorn(likelihood, n_iters=1) works the best + + + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + print(f"Applying weight threshold filter: {args.weight_threshold}") + # Create mask for values above threshold + threshold_mask = top_k_likelihood >= args.weight_threshold + + #set indices to -1 where threshold is not met + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # # Count how many values per row are above threshold + # valid_counts = threshold_mask.sum(dim=1) + # total_filtered = (valid_counts == 0).sum().item() + # total_kept = threshold_mask.sum().item() + # total_possible = top_k_likelihood.numel() + + # print(f"Kept {total_kept}/{total_possible} ({100*total_kept/total_possible:.1f}%) projections above threshold") + + # if total_filtered > 0: + # print(f"Warning: {total_filtered} tokens have no projections above threshold {args.weight_threshold}") + + # # Zero out values below threshold + # filtered_likelihood = top_k_likelihood * threshold_mask.to(top_k_likelihood.dtype) + # filtered_indices = top_k_indices.clone() + + # # For rows with no values above threshold, keep the top value to avoid empty rows + # empty_rows = valid_counts == 0 + # if empty_rows.any(): + # print(f"Keeping top projection for {empty_rows.sum().item()} tokens with no values above threshold") + # filtered_likelihood[empty_rows, 0] = top_k_likelihood[empty_rows, 0] + + # top_k_likelihood = filtered_likelihood + # top_k_indices = filtered_indices + + + return top_k_indices, top_k_likelihood + +# --- Main Execution --- +if __name__ == "__main__": + # 1. Load student and teacher tokenizers directly from --student-model / --teacher-model. + # No alphabetical swap โ€” the projection direction follows the CLI args. + student = {"id": args.student_model} + student["name"] = student["id"].split("/")[-1] + print(f"Loading student tokenizer: {student['name']}") + student["tokenizer"] = load_tokenizer(student["id"]) + + teacher = {"id": args.teacher_model} + teacher["name"] = teacher["id"].split("/")[-1] + print(f"Loading teacher tokenizer: {teacher['name']}") + teacher["tokenizer"] = load_tokenizer(teacher["id"]) + + print(f"\nSource (student): {student['name']}") + print(f"Target (teacher): {teacher['name']}") + + student_config = AutoConfig.from_pretrained( + student["id"], trust_remote_code="nvidia" in student["id"] + ) + teacher_config = AutoConfig.from_pretrained( + teacher["id"], trust_remote_code="nvidia" in teacher["id"] + ) + source_vocab_size = student_config.vocab_size + if "gemma" in teacher["id"]: + target_vocab_size = teacher_config.text_config.vocab_size + else: + target_vocab_size = teacher_config.vocab_size + + print(f"Source vocab size (full): {source_vocab_size}") + print(f"Target vocab size (full): {target_vocab_size}") + + # 2. Select and Load Embedding Model + embedding_model_index = 3 # Default to a good LLM embedder + selected_model_info = EMBEDDING_MODEL_CHOICES[embedding_model_index] + embedding_model_name = selected_model_info["name"] + embedding_model_type = selected_model_info["type"] + print(f"\nUsing embedding model: {embedding_model_name} ({embedding_model_type})") + + # 3. Generate or Load Embeddings + canonicalization_suffix = "_canonical" if args.use_canonicalization else "_raw" + embeddings_path_student = os.path.join( + args.data_dir, + f"embeddings_{student['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt", + ) + embeddings_path_teacher = os.path.join( + args.data_dir, + f"embeddings_{teacher['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt", + ) + + if ( + not args.force_recompute + and os.path.exists(embeddings_path_student) + and os.path.exists(embeddings_path_teacher) + ): + print("Loading cached embeddings...") + student["embeddings"] = load_data(embeddings_path_student).to(DEVICE) + teacher["embeddings"] = load_data(embeddings_path_teacher).to(DEVICE) + else: + print("Generating new embeddings...") + + # Generate raw decoded tokens + raw_tokens_student = [ + student["tokenizer"].decode([idx]) + for idx in range(student["tokenizer"].vocab_size) + ] + raw_tokens_teacher = [ + teacher["tokenizer"].decode([idx]) + for idx in range(teacher["tokenizer"].vocab_size) + ] + + # Apply canonicalization if requested + if args.use_canonicalization: + from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + print("Applying token canonicalization before embedding generation...") + decoded_tokens_student = [TokenAligner._canonical_token(token) for token in raw_tokens_student] + decoded_tokens_teacher = [TokenAligner._canonical_token(token) for token in raw_tokens_teacher] + + # Show some examples of canonicalization + print("Canonicalization examples:") + for i in range(min(10, len(raw_tokens_student))): + if raw_tokens_student[i] != decoded_tokens_student[i]: + print(f" student: '{raw_tokens_student[i]}' -> '{decoded_tokens_student[i]}'") + for i in range(min(10, len(raw_tokens_teacher))): + if raw_tokens_teacher[i] != decoded_tokens_teacher[i]: + print(f" teacher: '{raw_tokens_teacher[i]}' -> '{decoded_tokens_teacher[i]}'") + + print( + f"Applied canonicalization to {len(decoded_tokens_student)} student tokens " + f"and {len(decoded_tokens_teacher)} teacher tokens" + ) + else: + print("Using raw decoded tokens without canonicalization") + decoded_tokens_student = raw_tokens_student + decoded_tokens_teacher = raw_tokens_teacher + + if embedding_model_type == "sbert": + try: + from sentence_transformers import SentenceTransformer + except ImportError as e: + raise ImportError( + "The sbert embedding path requires `sentence-transformers` to be installed. " + "Install it with `uv pip install sentence-transformers`, " + "or pick an embedding model with type `llm_first_layer` instead." + ) from e + sbert_model = SentenceTransformer(embedding_model_name, device=DEVICE) + student["embeddings"] = sbert_model.encode(decoded_tokens_student, convert_to_tensor=True, show_progress_bar=True) + teacher["embeddings"] = sbert_model.encode(decoded_tokens_teacher, convert_to_tensor=True, show_progress_bar=True) + elif embedding_model_type == "llm_first_layer": + llm_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + if llm_tokenizer.pad_token is None: + llm_tokenizer.pad_token = llm_tokenizer.eos_token + llm_model = AutoModel.from_pretrained( + embedding_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ).to(DEVICE) + student["embeddings"] = get_llm_first_layer_embeddings(decoded_tokens_student, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + teacher["embeddings"] = get_llm_first_layer_embeddings(decoded_tokens_teacher, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + + save_data(student["embeddings"], embeddings_path_student) + save_data(teacher["embeddings"], embeddings_path_teacher) + + # 4. Compute Similarity and Generate Projection Maps (chunked to save memory) + print("\nComputing projection map in chunks to save memory...") + chunk_size = 500 # Process 500 tokens at a time to avoid OOM + top_k_indices_student_to_teacher, top_k_likelihood_student_to_teacher = compute_chunked_projection_map( + student["embeddings"], teacher["embeddings"], args, DEVICE, chunk_size=chunk_size + ) + + # Note: Exact match enforcement is skipped in chunked mode for simplicity. + # The chunked approach processes similarities in small batches to avoid OOM. + + # 5. Save the Combined Projection Map + print("\nSaving combined projection map...") + student_clean_name = clean_model_name_for_filename(student["name"]) + teacher_clean_name = clean_model_name_for_filename(teacher["name"]) + output_filename = f"temp_projection_map_{student_clean_name}_to_{teacher_clean_name}_top_{args.top_k}.pt" + if args.weight_threshold > 0.0: + output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") + output_path = os.path.join(args.data_dir, output_filename) + + torch.save( + { + "indices": top_k_indices_student_to_teacher.cpu(), + "likelihoods": top_k_likelihood_student_to_teacher.cpu(), + "student_model_id": student["id"], + "teacher_model_id": teacher["id"], + }, + output_path, + ) + + print(f"Saved combined projection map to: {output_path}") + + # 6. Example Usage of the Projection Function + print("\n--- Testing projection function (student -> teacher) ---") + source_vocab_size_student = student["embeddings"].shape[0] + target_vocab_size_teacher = teacher["embeddings"].shape[0] + dummy_tensor = torch.randn( + 1, 4096, source_vocab_size_student, device=DEVICE, dtype=torch.bfloat16 + ) + + # Transform this tensor using the projection map (convert to float32 for compatibility) + projected_tensor = project_token_likelihoods( + dummy_tensor.float(), + top_k_indices_student_to_teacher, + top_k_likelihood_student_to_teacher, + target_vocab_size_teacher, + DEVICE, + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful.") diff --git a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py new file mode 100644 index 0000000000..3e3c3a1c72 --- /dev/null +++ b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py @@ -0,0 +1,936 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import difflib +import os +import re +from collections import defaultdict + +import torch +import tqdm +from transformers import AutoConfig, AutoTokenizer + +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + + +#remove all special tokens that start with <| and end with |> + + +# compare 3 ways to estimate likelihood matrix: +# 1. using embeddings from another model, like was done in minimal_projection_generator.py +# 2. using text analysis like in tokenalign_likelihood_estimate.py +# 3. use one token to multiple and assign those as transformation matrix + +# this file implements 3rd way + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def create_weight_distribution(num_tokens): + """Create weight distribution for multi-token mappings using exponential decay.""" + weights = [] + base = 0.9 + for i in range(num_tokens): + if i == 0: + weights.append(base) + else: + weights.append(base * (0.1 ** i)) + + # Normalize to sum to 1 + total = sum(weights) + weights = [w / total for w in weights] + return weights + + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r"-?[0-9\.]+[bBmB]", "", name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace("-Base", "").replace("-it", "").replace("-Instruct", "") + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip("-_") + return cleaned_name + + +def print_projection_map_examples( + transformation_counts, + source_tokenizer, + target_tokenizer, + *, + direction: str = "", + num_examples: int = 50, + use_raw_tokens: bool = False, + use_canonicalization: bool = False, +) -> None: + """Print a sample of projection mappings with decoded tokens and weights. + + Args: + transformation_counts: Dict mapping (source_id, target_id) -> weight. + source_tokenizer: HuggingFace tokenizer for source-side decoding. + target_tokenizer: HuggingFace tokenizer for target-side decoding. + direction: Human-readable direction label, e.g. "student->teacher". + num_examples: Number of source tokens to print (highest-id first). + use_raw_tokens: If True, use ``convert_ids_to_tokens`` instead of ``decode``. + use_canonicalization: If True, apply ``TokenAligner._canonical_token`` after decoding. + """ + print( + f"\n--- Projection map examples {direction} (showing {num_examples} examples) ---" + ) + + # Group transformation_counts by source token (student token) + source_to_targets = defaultdict(list) + for (source_id, target_id), weight in transformation_counts.items(): + source_to_targets[source_id].append((target_id, weight)) + + # Take the highest-id `num_examples` source tokens for inspection. + sorted_sources = sorted(source_to_targets.keys())[-num_examples:] + + for source_id in sorted_sources: + try: + if use_raw_tokens: + source_token = source_tokenizer.convert_ids_to_tokens([source_id])[0] + else: + source_token = source_tokenizer.decode([source_id]) + source_token = apply_canonicalization_if_enabled(source_token, use_canonicalization) + source_token_str = repr(source_token) + except Exception: + source_token_str = f"" + + targets_weights = sorted(source_to_targets[source_id], key=lambda x: x[1], reverse=True) + + target_parts = [] + for target_id, weight in targets_weights: + try: + if use_raw_tokens: + target_token = target_tokenizer.convert_ids_to_tokens([target_id])[0] + else: + target_token = target_tokenizer.decode([target_id]) + target_token = apply_canonicalization_if_enabled(target_token, use_canonicalization) + target_token_str = repr(target_token) + except Exception: + target_token_str = f"" + target_parts.append(f"{target_token_str}({weight:.4f})") + + print(f"{source_token_str} -> {' '.join(target_parts)}") + + +def add_multitoken_mappings( + *, + source_tokenizer, + target_tokenizer, + source_total_vocab_size: int, + source_ignore_ids, + target_ignore_ids, + source_role: str, + transformation_counts, + tokens_to_cut: int, + use_raw_tokens: bool, + use_canonicalization: bool, +) -> tuple[dict, list]: + """Re-tokenize every source-vocab token with the target tokenizer and accumulate weighted mappings. + + The two passes of the multi-token projection (student->teacher and + teacher->student) share this same logic โ€” only the source/target + tokenizer pair and the role labels change. + + Args: + source_tokenizer: Tokenizer to decode source-vocab tokens. + target_tokenizer: Tokenizer used to re-encode each decoded source token. + source_total_vocab_size: Total source-vocab size to iterate. + source_ignore_ids: Source token ids to skip entirely. + target_ignore_ids: Target token ids that, if present in any encoding, drop the whole mapping. + source_role: Either ``"student"`` or ``"teacher"``. Determines which + position of the ``transformation_counts`` key the source id fills + (the key is always ``(student_id, teacher_id)``). + transformation_counts: Mutable mapping ``(student_id, teacher_id) -> float``; + weights are accumulated in place. + tokens_to_cut: Cap the re-encoding at this many target tokens. + use_raw_tokens: If True, decode via ``convert_ids_to_tokens`` instead of ``decode``. + use_canonicalization: If True, apply ``TokenAligner._canonical_token`` after decoding. + + Returns: + ``(decoded_source_tokens, examples)`` where ``decoded_source_tokens`` is a + ``{source_id: decoded_str}`` dict and ``examples`` is a list of multi-token + examples (only those with ``>= 2`` target tokens), keyed by + ``f"{source_role}_token"`` / ``f"{target_role}_tokens"`` etc. + + Raises: + ValueError: If ``source_role`` is not "student" or "teacher". + """ + if source_role not in ("student", "teacher"): + raise ValueError(f"source_role must be 'student' or 'teacher', got {source_role!r}") + target_role = "teacher" if source_role == "student" else "student" + + decoded_source: dict[int, str] = {} + print(f"Decoding {source_role} tokens...") + for token_id in tqdm.tqdm( + range(source_total_vocab_size), desc=f"Decoding {source_role} tokens" + ): + if token_id in source_ignore_ids: + continue + try: + if use_raw_tokens: + decoded = source_tokenizer.convert_ids_to_tokens([token_id])[0] + else: + decoded = source_tokenizer.decode([token_id]) + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + decoded = apply_canonicalization_if_enabled(decoded, use_canonicalization) + decoded_source[token_id] = decoded + except Exception: + continue + print(f"Successfully decoded {len(decoded_source)} {source_role} tokens") + + examples: list[dict] = [] + print(f"Finding {source_role}->{target_role} multi-token mappings...") + for source_token_id, source_token_str in tqdm.tqdm( + decoded_source.items(), desc=f"Processing {source_role} tokens" + ): + target_encoding = target_tokenizer( + source_token_str, add_special_tokens=False, return_attention_mask=False + ) + target_token_ids = target_encoding["input_ids"] + if any(tid in target_ignore_ids for tid in target_token_ids): + continue + + target_token_ids = target_token_ids[:tokens_to_cut] + weights = create_weight_distribution(len(target_token_ids)) + + for target_token_id, weight in zip(target_token_ids, weights): + if source_role == "student": + key = (source_token_id, target_token_id) + else: + key = (target_token_id, source_token_id) + transformation_counts[key] += weight + + if len(target_token_ids) >= 2: + decoded_targets = [target_tokenizer.decode([tid]) for tid in target_token_ids] + examples.append( + { + f"{source_role}_token": source_token_str, + f"{source_role}_id": source_token_id, + f"{target_role}_tokens": decoded_targets, + f"{target_role}_ids": target_token_ids, + "weights": weights, + } + ) + + return decoded_source, examples + + +def print_projection_statistics( + *, + transformation_counts, + student_to_teacher_examples: list, + teacher_to_student_examples: list, + enable_reverse_pass: bool, + max_examples_shown: int = 10, +) -> None: + """Print a summary of the multi-token mappings collected so far.""" + print("\n=== SUMMARY ===") + print( + f"Found {len(student_to_teacher_examples)} student tokens that map " + f"to multiple teacher tokens" + ) + + if student_to_teacher_examples: + print("\nExamples of student->teacher multi-token mappings:") + for example in student_to_teacher_examples[:max_examples_shown]: + print( + f" Student '{example['student_token']}' -> Teacher " + f"{example['teacher_tokens']} (weights: {example['weights']})" + ) + if len(student_to_teacher_examples) > max_examples_shown: + print( + f" ... and {len(student_to_teacher_examples) - max_examples_shown} more." + ) + + if enable_reverse_pass: + print( + f"\nReverse pass enabled โ€” added " + f"{len(teacher_to_student_examples)} teacher->student bidirectional examples" + ) + + print(f"\nTotal transformation entries: {len(transformation_counts)}") + + +def find_similar_special_tokens(tokenizer_a, tokenizer_b, similarity_threshold=0.4, top_k_matches=3): + """Find similar special tokens between two tokenizers using string similarity.""" + + def is_special_token(token_str): + """Check if a token looks like a special token""" + return (token_str.startswith('<|') and token_str.endswith('|>')) or \ + (token_str.startswith('<') and token_str.endswith('>')) or \ + token_str in ['', '', '', '', '', ''] + + def extract_special_tokens(tokenizer): + """Extract all special tokens from a tokenizer with their IDs""" + special_tokens = {} + vocab = tokenizer.get_vocab() + for token_str, token_id in vocab.items(): + if is_special_token(token_str): + special_tokens[token_id] = token_str + return special_tokens + + def calculate_similarity(token_a, token_b): + """Calculate similarity between two token strings""" + # Use difflib for sequence similarity + seq_similarity = difflib.SequenceMatcher(None, token_a, token_b).ratio() + + # Extract key words from special tokens for semantic matching + def extract_keywords(token): + # Remove special token markers and split by common separators + cleaned = re.sub(r'[<>|_]', ' ', token.lower()) + words = [w for w in cleaned.split() if len(w) > 2] # Filter short words + return set(words) + + keywords_a = extract_keywords(token_a) + keywords_b = extract_keywords(token_b) + + # Jaccard similarity for keywords + if keywords_a or keywords_b: + keyword_similarity = len(keywords_a.intersection(keywords_b)) / len(keywords_a.union(keywords_b)) + else: + keyword_similarity = 0.0 + + # Combined similarity (weighted average) + return 0.6 * seq_similarity + 0.4 * keyword_similarity + + print("Extracting special tokens...") + special_tokens_a = extract_special_tokens(tokenizer_a) # student + special_tokens_b = extract_special_tokens(tokenizer_b) # teacher + + print(f"Found {len(special_tokens_a)} special tokens in student tokenizer") + print(f"Found {len(special_tokens_b)} special tokens in teacher tokenizer") + + # Find matches + special_token_mappings = [] + + print("Finding similar special tokens...") + for token_id_a, token_str_a in special_tokens_a.items(): + similarities = [] + for token_id_b, token_str_b in special_tokens_b.items(): + similarity = calculate_similarity(token_str_a, token_str_b) + if similarity >= similarity_threshold: + similarities.append((token_id_b, token_str_b, similarity)) + + # Sort by similarity and take top-k + similarities.sort(key=lambda x: x[2], reverse=True) + for token_id_b, token_str_b, similarity in similarities[:top_k_matches]: + special_token_mappings.append({ + 'student_id': token_id_a, + 'student_token': token_str_a, + 'teacher_id': token_id_b, + 'teacher_token': token_str_b, + 'similarity': similarity + }) + + return special_token_mappings + + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + required=True, + help="Student model name or path", + ) + parser.add_argument( + "--teacher-model", + type=str, + required=True, + help="Teacher model name or path", + ) + parser.add_argument( + "--num-examples", + type=int, + default=0, + help=( + "Print this many projection-map examples after building the " + "transformation matrix (0 disables; useful for spot-checking)." + ), + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., ฤ  vs โ– prefixes, ฤŠ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + + # Configuration from arguments + ENABLE_SCALE_TRICK = args.enable_scale_trick + ENABLE_REVERSE_PASS = args.enable_reverse_pass + ENABLE_EXACT_MATCH = args.enable_exact_match + + TOKENS_TO_CUT = args.tokens_to_cut + TOP_K = args.top_k + USE_RAW_TOKENS = args.use_raw_tokens + INITIAL_PROJECTION_PATH = args.initial_projection_path + ENABLE_SPECIAL_TOKEN_MAPPING = args.enable_special_token_mapping + SPECIAL_TOKEN_SIMILARITY_THRESHOLD = args.special_token_similarity_threshold + SPECIAL_TOKEN_TOP_K = args.special_token_top_k if args.special_token_top_k is not None else TOP_K + USE_CANONICALIZATION = args.use_canonicalization + + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + + # Print configuration + print("=== Configuration ===") + print(f"Student model: {student_model_name}") + print(f"Teacher model: {teacher_model_name}") + print(f"Enable scale trick: {ENABLE_SCALE_TRICK}") + print(f"Enable reverse pass: {ENABLE_REVERSE_PASS}") + print(f"Enable exact match: {ENABLE_EXACT_MATCH}") + print(f"Use raw tokens: {USE_RAW_TOKENS}") + print(f"Use canonicalization: {USE_CANONICALIZATION}") + print(f"Tokens to cut: {TOKENS_TO_CUT}") + print(f"Top K: {TOP_K}") + print(f"Enable special token mapping: {ENABLE_SPECIAL_TOKEN_MAPPING}") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"Special token similarity threshold: {SPECIAL_TOKEN_SIMILARITY_THRESHOLD}") + print(f"Special token top K: {SPECIAL_TOKEN_TOP_K}") + print(f"Initial projection path: {INITIAL_PROJECTION_PATH}") + print(f"Output directory: {args.output_dir}") + print("=" * 25) + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + if "gemma" not in student_model_name.lower(): + source_vocab_size = model_A_config.vocab_size + else: + source_vocab_size = model_A_config.text_config.vocab_size + + if "gemma" not in teacher_model_name.lower(): + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + + tokenizer_student_total_vocab_size = source_vocab_size + tokenizer_teacher_total_vocab_size = target_vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Student tokenizer total vocab size: {tokenizer_student_total_vocab_size}") + print(f"Teacher tokenizer total vocab size: {tokenizer_teacher_total_vocab_size}") + + # Print token processing mode + if USE_RAW_TOKENS: + print("Using raw token representation (convert_ids_to_tokens)") + else: + print("Using decoded token representation (decode)") + + + transformation_counts = defaultdict(float) + import os + if INITIAL_PROJECTION_PATH and os.path.exists(INITIAL_PROJECTION_PATH): + print(f"Loading initial projection from: {INITIAL_PROJECTION_PATH}") + initial_projection_map = torch.load(INITIAL_PROJECTION_PATH, map_location='cpu') + + if isinstance(initial_projection_map, dict) and 'indices' in initial_projection_map and 'likelihoods' in initial_projection_map: + print("Loading from sparse top-k format and converting to transformation_counts.") + indices = initial_projection_map['indices'] + likelihoods = initial_projection_map['likelihoods'] + + # Accept the new `student_model_id`/`teacher_model_id` keys and fall back to the + # legacy `model_A_id`/`model_B_id` keys produced by older generator runs. + loaded_student_model = initial_projection_map.get( + "student_model_id", initial_projection_map.get("model_A_id") + ) + loaded_teacher_model = initial_projection_map.get( + "teacher_model_id", initial_projection_map.get("model_B_id") + ) + + if loaded_student_model and loaded_student_model != student_model_name: + print(f"Warning: Student model mismatch. Loaded: {loaded_student_model}, Current: {student_model_name}") + if loaded_teacher_model and loaded_teacher_model != teacher_model_name: + print(f"Warning: Teacher model mismatch. Loaded: {loaded_teacher_model}, Current: {teacher_model_name}") + + num_student_tokens = indices.shape[0] + top_k = indices.shape[1] + + for student_id in tqdm.tqdm(range(num_student_tokens), desc="Converting initial projection to counts"): + for k in range(top_k): + teacher_id = indices[student_id, k].item() + if teacher_id != -1: + likelihood = likelihoods[student_id, k].item() + if likelihood > 0: + transformation_counts[(student_id, teacher_id)] = likelihood + + elif torch.is_tensor(initial_projection_map): + if initial_projection_map.is_sparse: + print("Loading from sparse tensor and converting to transformation_counts.") + sparse_matrix = initial_projection_map.coalesce() + map_indices = sparse_matrix.indices() + map_values = sparse_matrix.values() + for i in tqdm.tqdm(range(map_indices.shape[1]), desc="Converting sparse tensor to counts"): + student_id = map_indices[0, i].item() + teacher_id = map_indices[1, i].item() + weight = map_values[i].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print("Loading from dense matrix and converting to transformation_counts.") + dense_matrix = initial_projection_map + non_zero_indices = torch.nonzero(dense_matrix, as_tuple=False) + for idx in tqdm.tqdm(range(non_zero_indices.shape[0]), desc="Converting dense projection to counts"): + student_id = non_zero_indices[idx, 0].item() + teacher_id = non_zero_indices[idx, 1].item() + weight = dense_matrix[student_id, teacher_id].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print(f"Warning: Unrecognized format for initial projection map at {INITIAL_PROJECTION_PATH}. Skipping.") + + print(f"Initialized transformation_counts with {len(transformation_counts)} entries.") + + + ignore_tokens = ["<|endoftext|>", ""] + ignore_student_ids = { + tokenizer_student.convert_tokens_to_ids(token) + for token in ignore_tokens + if token in tokenizer_student.get_vocab() + } + ignore_teacher_ids = { + tokenizer_teacher.convert_tokens_to_ids(token) + for token in ignore_tokens + if token in tokenizer_teacher.get_vocab() + } + + # First pass: student tokens -> teacher tokens. + print("\n=== FIRST PASS: Student tokens -> Teacher tokens ===") + _, student_to_teacher_examples = add_multitoken_mappings( + source_tokenizer=tokenizer_student, + target_tokenizer=tokenizer_teacher, + source_total_vocab_size=tokenizer_student_total_vocab_size, + source_ignore_ids=ignore_student_ids, + target_ignore_ids=ignore_teacher_ids, + source_role="student", + transformation_counts=transformation_counts, + tokens_to_cut=TOKENS_TO_CUT, + use_raw_tokens=USE_RAW_TOKENS, + use_canonicalization=USE_CANONICALIZATION, + ) + + # Second pass: teacher tokens -> student tokens (opposite direction). + teacher_to_student_examples: list[dict] = [] + if ENABLE_REVERSE_PASS: + print("\n=== SECOND PASS: Teacher tokens -> Student tokens ===") + _, teacher_to_student_examples = add_multitoken_mappings( + source_tokenizer=tokenizer_teacher, + target_tokenizer=tokenizer_student, + source_total_vocab_size=tokenizer_teacher_total_vocab_size, + source_ignore_ids=ignore_teacher_ids, + target_ignore_ids=ignore_student_ids, + source_role="teacher", + transformation_counts=transformation_counts, + tokens_to_cut=TOKENS_TO_CUT, + use_raw_tokens=USE_RAW_TOKENS, + use_canonicalization=USE_CANONICALIZATION, + ) + + print("\n=== ADDING SPECIAL TOKEN MAPPINGS ===") + + # Find and add special token mappings (if enabled) + special_token_mappings = [] + if ENABLE_SPECIAL_TOKEN_MAPPING: + special_token_mappings = find_similar_special_tokens( + tokenizer_student, + tokenizer_teacher, + similarity_threshold=SPECIAL_TOKEN_SIMILARITY_THRESHOLD, + top_k_matches=SPECIAL_TOKEN_TOP_K + ) + else: + print("Special token mapping disabled") + + if special_token_mappings: + print(f"\nFound {len(special_token_mappings)} special token mappings:") + initial_transformation_count = len(transformation_counts) + + # Add ALL mappings to transformation matrix + for mapping in special_token_mappings: + student_id = mapping['student_id'] + teacher_id = mapping['teacher_id'] + similarity = mapping['similarity'] + + # Add mapping with weight based on similarity + weight = similarity * 0.8 # Scale similarity to reasonable weight + transformation_counts[(student_id, teacher_id)] += weight + + # Group mappings by student token and show top 2 matches per student token + from collections import defaultdict + student_mappings = defaultdict(list) + for mapping in special_token_mappings: + student_mappings[mapping['student_id']].append(mapping) + + # Sort each student's mappings by similarity and show top 2 + print("Top 2 matches per student special token:") + shown_count = 0 + for student_id, mappings in student_mappings.items(): + # Sort by similarity (highest first) + sorted_mappings = sorted(mappings, key=lambda x: x['similarity'], reverse=True) + + # Show top 2 for this student token + student_token = sorted_mappings[0]['student_token'] # Get student token name + print(f" {student_token}:") + + for mapping in sorted_mappings[:2]: + similarity = mapping['similarity'] + weight = similarity * 0.8 + print(f" -> '{mapping['teacher_token']}' (similarity: {similarity:.3f}, weight: {weight:.3f})") + shown_count += 1 + + if len(sorted_mappings) > 2: + print(f" ... and {len(sorted_mappings) - 2} more matches") + + total_hidden = len(special_token_mappings) - shown_count + if total_hidden > 0: + print(f"Total mappings not shown: {total_hidden}") + + added_count = len(transformation_counts) - initial_transformation_count + print(f"Added {added_count} new special token transformation entries") + else: + print("No similar special tokens found") + + print_projection_statistics( + transformation_counts=transformation_counts, + student_to_teacher_examples=student_to_teacher_examples, + teacher_to_student_examples=teacher_to_student_examples, + enable_reverse_pass=ENABLE_REVERSE_PASS, + ) + + if ENABLE_EXACT_MATCH: + + print("Checking for exact token matches and setting exact mappings...") + # check exact match between student and teacher tokens and set those as perfect 1-to-1 mappings + # Convert all tokens to strings at once for vectorized comparison + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # For tokens that match exactly, we want their mapping to be 1.0 + # and they should not be mapped to any other token. + # First, remove all existing mappings for these student tokens + match_indices_student_set = set(match_indices_student) + keys_to_remove = [] + for key in transformation_counts.keys(): + student_id, teacher_id = key + if student_id in match_indices_student_set: + keys_to_remove.append(key) + for key in keys_to_remove: + del transformation_counts[key] + # Then, set the perfect 1-to-1 mappings for exact matches + for student_id, teacher_id in zip(match_indices_student, match_indices_teacher): + transformation_counts[(student_id, teacher_id)] = 1.0 + + + # Create transformation matrix (student -> teacher projection) + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + teacher_indices = [idx[1] for idx in indices] + student_indices = [idx[0] for idx in indices] + + # Create sparse tensor with student tokens as rows, teacher tokens as columns + # This creates a student -> teacher projection matrix + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values) + + transformation_matrix_sparse = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (tokenizer_student_total_vocab_size, tokenizer_teacher_total_vocab_size), + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.bfloat16 + ) + + # indices, values = torch.topk(transformation_matrix_sparse, k=1000, dim=1) + + print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}") + print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}") + + # Convert sparse matrix to same format as minimal_projection_generator.py + os.makedirs(args.output_dir, exist_ok=True) + + # Convert defaultdict to regular dict for saving + transformation_counts_dict = dict(transformation_counts) + + # Show some examples of the projection mappings + if args.num_examples > 0: + print_projection_map_examples( + transformation_counts_dict, + tokenizer_student, + tokenizer_teacher, + direction="student->teacher", + num_examples=args.num_examples, + use_raw_tokens=USE_RAW_TOKENS, + use_canonicalization=USE_CANONICALIZATION, + ) + + print(f"\nConverting sparse matrix to top-{TOP_K} dense format...") + + # Convert sparse matrix to dense and get top-k values per row + print("Converting to dense matrix on CPU to avoid memory issues...") + dense_matrix = transformation_matrix_sparse.cpu().to_dense() # Move to CPU to handle memory + print(f"Dense matrix shape: {dense_matrix.shape}") + + # Get top-k values and indices for each row (each source token) + print(f"Extracting top-{TOP_K} values per token...") + + # Apply sinkhorn normalization on CPU + print("Applying Sinkhorn normalization on CPU...") + dense_matrix = sinkhorn_one_dim(dense_matrix, n_iters=1) + + # Extract top-k on CPU + top_k_likelihoods, top_k_indices = torch.topk(dense_matrix, k=min(TOP_K, dense_matrix.shape[1]), dim=1) + # exit() + # Handle case where vocabulary has fewer tokens than TOP_K + actual_k = top_k_indices.shape[1] + if actual_k < TOP_K: + print(f"Warning: Target vocabulary size ({dense_matrix.shape[1]}) is smaller than TOP_K ({TOP_K}). Using k={actual_k}") + # Pad with -1 indices and 0.0 likelihoods to maintain consistent shape + pad_size = TOP_K - actual_k + top_k_indices = torch.cat([top_k_indices, torch.full((top_k_indices.shape[0], pad_size), -1, dtype=top_k_indices.dtype)], dim=1) + top_k_likelihoods = torch.cat([top_k_likelihoods, torch.zeros((top_k_likelihoods.shape[0], pad_size), dtype=top_k_likelihoods.dtype)], dim=1) + + # Apply SCALE_TRICK: set last column to -4 if enabled + if ENABLE_SCALE_TRICK: + print("ENABLE_SCALE_TRICK is True: Setting last column of likelihoods to -4.0") + top_k_likelihoods[:, -1] = 0.2 + if ENABLE_EXACT_MATCH: + for indices in match_indices_student: + top_k_likelihoods[indices, -1] = 0.0 + print(f"Set last column of likelihoods to 0.0 for {len(match_indices_student)} exact matches as exact match is enabled") + # Apply sinkhorn normalization on CPU + print("Applying Sinkhorn normalization on CPU...") + top_k_likelihoods = sinkhorn_one_dim(top_k_likelihoods, n_iters=1) + + #set indices to -1 where likelihood is 0 + + # Create filename in same format as minimal_projection_generator.py + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + + output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" + # if USE_RAW_TOKENS: + # output_filename += "_raw_tokens" + if ENABLE_SPECIAL_TOKEN_MAPPING: + output_filename += "_special" + output_filename += ".pt" + # if ENABLE_REVERSE_PASS: + # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + output_path = os.path.join(args.output_dir, output_filename) + + # Save in same format as minimal_projection_generator.py. + # `enable_scale_trick` is persisted so downstream tools (e.g. + # `sort_and_cut_projection_matrix.py`) can decide whether the last + # column carries a tunable scale slot without the user re-specifying. + torch.save({ + "indices": top_k_indices, + "likelihoods": top_k_likelihoods, + "student_model_id": student_model_name, + "teacher_model_id": teacher_model_name, + "enable_scale_trick": ENABLE_SCALE_TRICK, + }, output_path) + + print(f"Saved projection map to: {output_path}") + print(f"Format: indices shape {top_k_indices.shape}, likelihoods shape {top_k_likelihoods.shape}") + print(f"Compatible with minimal_projection_generator.py format") + print(f"Token processing mode: {'Raw tokens (convert_ids_to_tokens)' if USE_RAW_TOKENS else 'Decoded tokens (decode)'}") + if ENABLE_REVERSE_PASS: + print("File includes bidirectional mappings (teacher->student and student->teacher)") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"File includes special token mappings (similarity_threshold={SPECIAL_TOKEN_SIMILARITY_THRESHOLD}, top_k={SPECIAL_TOKEN_TOP_K})") + # exit() + + + # Test projection function compatibility (same as minimal_projection_generator.py) + print("\n--- Testing projection function compatibility ---") + + def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + + # Create a dummy likelihood tensor: [BATCH, SEQ, source_vocab_size] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_tensor = torch.randn(1, 4096, tokenizer_student_total_vocab_size, device=device, dtype=torch.bfloat16) + + # Transform this tensor using the projection map + projected_tensor = project_token_likelihoods( + dummy_tensor, + top_k_indices.to(device), + top_k_likelihoods.to(device), + tokenizer_teacher_total_vocab_size, + device + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful - format is fully compatible!") diff --git a/nemo_rl/utils/x_token/reapply_exact_map.py b/nemo_rl/utils/x_token/reapply_exact_map.py new file mode 100644 index 0000000000..bd48103d3a --- /dev/null +++ b/nemo_rl/utils/x_token/reapply_exact_map.py @@ -0,0 +1,228 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os + +import torch +from transformers import AutoConfig, AutoTokenizer + +from nemo_rl.algorithms.x_token.tokenalign import TokenAligner + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., ฤ  vs โ– prefixes, ฤŠ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + required=True, + help="Path to initial projection map to load and extend", + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + USE_CANONICALIZATION = args.use_canonicalization + + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # load initial projection map + initial_projection_path = args.initial_projection_path + initial_projection_map = torch.load(initial_projection_path) + + # go through token in projection map. For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token + non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) + all_student_token_ids = list(range(len(initial_projection_map["likelihoods"]))) + + show_remapping = 5 + if show_remapping > 0: + print(f"Showing remapping for the last {show_remapping} exact matches.") + else: + print(f"Not showing remapping.") + + for i, exact_token_student in enumerate(match_indices_student): + exact_token_teacher = match_indices_teacher[i] + + index_ = all_student_token_ids.index(exact_token_student) + likelihoods = initial_projection_map["likelihoods"][index_] + indices = initial_projection_map["indices"][index_] + + if len(match_indices_student) - i <= show_remapping: + print(f"prior to remapping: likelihoods {likelihoods} indices {indices}") + + topk = indices.shape[0] + + remapped_indices = torch.ones_like(indices) * -1 + remapped_likelihoods = torch.zeros_like(likelihoods) + + remapped_likelihoods[0] = 1.0 + remapped_indices[0] = exact_token_teacher + + + initial_projection_map["likelihoods"][index_] = remapped_likelihoods + initial_projection_map["indices"][index_] = remapped_indices + + + if len(match_indices_student) - i <= show_remapping: + print(f'after remapping {tokens_student[exact_token_student]}:{exact_token_student} -> {tokens_teacher[exact_token_teacher]}:{exact_token_teacher}: likelihoods {initial_projection_map["likelihoods"][index_]} indices {initial_projection_map["indices"][index_]}') + non_exact_map_tokens.remove(index_) + + + base, ext = os.path.splitext(args.initial_projection_path) + save_path = base + "_exact_map_remapped" + (ext or ".pt") + torch.save(initial_projection_map, save_path) + print(f"Saved remapped projection map to: {save_path}") + print(f"remapped {len(match_indices_student)} tokens. Retained remaining {len(non_exact_map_tokens)} tokens as is.") diff --git a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py new file mode 100644 index 0000000000..b925732a18 --- /dev/null +++ b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py @@ -0,0 +1,462 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os + +import torch +import tqdm + +def sinkhorn_one_dim(A, n_iters=1): + """Apply Sinkhorn normalization to make each row sum to 1.""" + for _ in range(n_iters): + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + return A + +def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): + """ + Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. + + Args: + input_path: Path to input projection matrix file + output_path: Path to save the new projection matrix + new_top_k: New top_k value for cutoff + preserve_last: If True, always preserve the last column as the final element + verbose: Whether to print progress information + """ + if verbose: + print(f"Loading projection matrix from: {input_path}") + + # Load the projection matrix + projection_data = torch.load(input_path, map_location='cpu', weights_only=False) + + if not isinstance(projection_data, dict) or 'indices' not in projection_data or 'likelihoods' not in projection_data: + raise ValueError("Input file must contain a dictionary with 'indices' and 'likelihoods' keys") + + original_indices = projection_data['indices'] # Shape: [vocab_size, original_top_k] + original_likelihoods = projection_data['likelihoods'] # Shape: [vocab_size, original_top_k] + + vocab_size, original_top_k = original_indices.shape + + if verbose: + print(f"Original matrix shape: {original_indices.shape}") + print(f"Original top_k: {original_top_k}") + print(f"New top_k: {new_top_k}") + print(f"Preserve last column: {preserve_last}") + + if new_top_k > original_top_k: + print(f"Warning: New top_k ({new_top_k}) is larger than original top_k ({original_top_k})") + print(f"Will pad with -1 indices and 0.0 likelihoods") + effective_top_k = original_top_k + else: + effective_top_k = new_top_k + + # Initialize new tensors + new_indices = torch.full((vocab_size, new_top_k), -1, dtype=original_indices.dtype) + new_likelihoods = torch.zeros((vocab_size, new_top_k), dtype=original_likelihoods.dtype) + + # Statistics tracking + rows_with_order_change = 0 + significant_components_count = [0] * min(new_top_k, 10) # Track up to 10 components + threshold_for_significance = 0.2 # Threshold for considering a component "significant" + # Track position of maximum element in original ordering + max_element_positions = {} # position -> count + # Track preserve_last statistics + rows_with_preserved_last = 0 + # Track specifically when max element is in the last column + rows_with_max_in_last_column = 0 + # Track position of maximum element in final sorted and trimmed matrix + final_max_element_positions = {} # position -> count + + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + + if verbose: + print("Sorting and cutting each row...") + + # Process each row (each source token) + last_element_trick_count = 0 + for row_idx in tqdm.tqdm(range(vocab_size), desc="Processing rows", disable=not verbose): + row_indices = original_indices[row_idx] # [original_top_k] + row_likelihoods = original_likelihoods[row_idx] # [original_top_k] + + # Filter out invalid indices (-1) and zero likelihoods + valid_mask = (row_indices != -1) & (row_likelihoods > 0) + + if valid_mask.any(): + valid_indices = row_indices[valid_mask] + valid_likelihoods = row_likelihoods[valid_mask] + + # Track position of maximum element in original ordering + max_pos = torch.argmax(valid_likelihoods).item() + if max_pos not in max_element_positions: + max_element_positions[max_pos] = 0 + max_element_positions[max_pos] += 1 + + # Check if max element is specifically in the last column + # Find the actual maximum value in the original row (including invalid entries) + original_max_pos = torch.argmax(row_likelihoods).item() + if original_max_pos == original_top_k - 1: + # Only count if the last position actually has valid data + last_index = row_indices[original_top_k - 1] + last_likelihood = row_likelihoods[original_top_k - 1] + if last_index != -1 and last_likelihood > 0: + rows_with_max_in_last_column += 1 + + if preserve_last and new_top_k >= 1: + # Handle preserve_last case + last_index = original_indices[row_idx, original_top_k - 1] + last_likelihood = original_likelihoods[row_idx, original_top_k - 1] + + if new_top_k == 1: + # Special case: only keep the last element + if last_index != -1 and last_likelihood > 0: + new_indices[row_idx, 0] = last_index + new_likelihoods[row_idx, 0] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant components + if last_likelihood >= threshold_for_significance: + significant_components_count[0] += 1 + else: + # General case: sort first (original_top_k-1) elements, then add last element + elements_to_sort = min(len(valid_likelihoods), original_top_k - 1) + if elements_to_sort > 0: + # Get elements excluding the last position in original matrix + sort_mask = torch.arange(len(valid_likelihoods)) < elements_to_sort + if sort_mask.any(): + sortable_indices = valid_indices[sort_mask] + sortable_likelihoods = valid_likelihoods[sort_mask] + + # Sort the non-last elements + sorted_likelihoods, sort_order = torch.sort(sortable_likelihoods, descending=True) + sorted_indices = sortable_indices[sort_order] + + # Check if order changed in the sortable portion + original_order = torch.arange(len(sortable_likelihoods)) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + # Take top (new_top_k - 1) elements from sorted portion + num_from_sorted = min(len(sorted_indices), new_top_k - 1) + + new_indices[row_idx, :num_from_sorted] = sorted_indices[:num_from_sorted] + new_likelihoods[row_idx, :num_from_sorted] = sorted_likelihoods[:num_from_sorted] + + # Count significant components from sorted portion + for comp_idx in range(min(num_from_sorted, len(significant_components_count) - 1)): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + + # Always put the last element at the end (if valid) + + if last_index != -1 and last_likelihood > 0: + last_element_trick_count += 1 + new_indices[row_idx, new_top_k - 1] = last_index + new_likelihoods[row_idx, new_top_k - 1] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant component for the preserved last element + if new_top_k - 1 < len(significant_components_count): + if last_likelihood >= threshold_for_significance: + significant_components_count[new_top_k - 1] += 1 + + else: + # Original logic: sort all elements normally + # Check if order changed by comparing original vs sorted order + original_order = torch.arange(len(valid_likelihoods)) + sorted_likelihoods, sort_order = torch.sort(valid_likelihoods, descending=True) + + # Check if the order changed (not just sorted, but actually different) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + sorted_indices = valid_indices[sort_order] + + # Take top effective_top_k elements + num_to_take = min(len(sorted_indices), effective_top_k) + + new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take] + new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take] + # Count significant components (components above threshold) + for comp_idx in range(min(num_to_take, len(significant_components_count))): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + # if significant_components_count[1] > 0.0: + + # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 + + # Apply Sinkhorn normalization to the final matrix + if verbose: + print(f"last element trick count: {last_element_trick_count}") + print("Applying Sinkhorn normalization...") + + # Apply normalization only to non-zero values to preserve sparsity structure + normalized_likelihoods = sinkhorn_one_dim(new_likelihoods.clone(), n_iters=1) + + # Calculate final maximum element position statistics after sorting and normalization + if verbose: + print("Calculating final maximum element position statistics...") + + for row_idx in range(vocab_size): + row_likelihoods = normalized_likelihoods[row_idx] + # Filter out zero likelihoods + valid_mask = row_likelihoods > 0 + if valid_mask.any(): + valid_likelihoods = row_likelihoods[valid_mask] + # Find position of maximum element in the final matrix + max_pos_in_valid = torch.argmax(valid_likelihoods).item() + # Convert back to original position in the row + valid_positions = torch.nonzero(valid_mask).squeeze(-1) + actual_max_pos = valid_positions[max_pos_in_valid].item() + + if actual_max_pos not in final_max_element_positions: + final_max_element_positions[actual_max_pos] = 0 + final_max_element_positions[actual_max_pos] += 1 + + # Create output dictionary with same format as input + output_data = { + 'indices': new_indices, + 'likelihoods': normalized_likelihoods, + } + + # Copy over any additional metadata + for key in projection_data: + if key not in ['indices', 'likelihoods']: + output_data[key] = projection_data[key] + + # Save the new projection matrix + torch.save(output_data, output_path) + + if verbose: + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() + for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + +def main(): + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument("input_path", help="Path to input projection matrix file") + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") + parser.add_argument( + "--preserve_last", + action=argparse.BooleanOptionalAction, + default=None, + help=( + "Force-enable or force-disable preserving the last column. " + "If unspecified, the value is read from the input projection map's " + "`enable_scale_trick` metadata (and defaults to False if absent)." + ), + ) + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + args = parser.parse_args() + + # Resolve preserve_last from CLI override -> projection-map metadata -> default False. + if args.preserve_last is None: + try: + meta = torch.load(args.input_path, map_location="cpu", weights_only=False) + except (FileNotFoundError, RuntimeError): + meta = {} + preserve_last = bool(meta.get("enable_scale_trick", False)) if isinstance(meta, dict) else False + if preserve_last and not args.quiet: + print( + "Auto-enabling --preserve_last because projection map was generated with " + "enable_scale_trick=True." + ) + else: + preserve_last = args.preserve_last + + # Auto-generate output path if not specified + if args.output_path is None: + input_dir = os.path.dirname(args.input_path) + input_filename = os.path.basename(args.input_path) + + # Extract base name and extension + base_name, ext = os.path.splitext(input_filename) + + # Remove old top_k info if present + import re + base_name = re.sub(r"_top_\d+", "", base_name) + + # Add new top_k info and preserve_last flag + suffix = "_sorted" + if preserve_last: + suffix += "_preservelast" + output_filename = f"{base_name}_top_{args.top_k}{suffix}{ext}" + args.output_path = os.path.join(input_dir, output_filename) + + # Ensure output directory exists (skip when output_path has no directory component) + output_dir = os.path.dirname(args.output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + # Process the matrix + sort_and_cut_projection_matrix( + args.input_path, + args.output_path, + args.top_k, + preserve_last=preserve_last, + verbose=not args.quiet, + ) + if not args.quiet: + print(f"Output written to: {args.output_path}") + +if __name__ == "__main__": + main() From 441fb80cf48e4d224a694b01918b3665930b74f8 Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Thu, 14 May 2026 11:37:02 -0700 Subject: [PATCH 5/6] docs(xtoken): add cross-tokenizer distillation tutorial Adds docs/guides/xtoken-distillation.md walking through the four-step offline projection-prep pipeline (generate -> add multi-token -> reapply exact map -> sort and trim) followed by the distillation training run via examples/run_xtoken_distillation.py. Covers backend constraints (DTensor V2 only, CUDA-IPC teacher logits, arrow_text corpus), the three loss modes (P-KL, gold, gold+xtoken), and worked examples for Qwen, Gemma, and Phi-3/Phi-4 teacher pairs. Registered in docs/index.md under Guides. Also nulls out the hardcoded /lustre user-specific projection_matrix_path in examples/configs/xtoken_distillation.yaml and replaces it with an inline comment pointing at the new tutorial; the path is now expected to be supplied via Hydra CLI per run, matching the data.train.arrow_files pattern already used in the same file. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Adithya Hanasoge --- docs/guides/xtoken-distillation.md | 218 ++++++++++++++++++++++ docs/index.md | 8 + examples/configs/xtoken_distillation.yaml | 6 +- 3 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 docs/guides/xtoken-distillation.md diff --git a/docs/guides/xtoken-distillation.md b/docs/guides/xtoken-distillation.md new file mode 100644 index 0000000000..fc03cc7725 --- /dev/null +++ b/docs/guides/xtoken-distillation.md @@ -0,0 +1,218 @@ +# Cross-Tokenizer (X-Token) Off-Policy Distillation + +NeMo RL supports off-policy distillation between a student and a teacher that +**do not share a tokenizer** โ€” for example, distilling a Qwen3-4B teacher into +a Llama-3.2-1B student. Cross-tokenizer ("x-token") distillation handles the +vocabulary mismatch by routing teacher logits through a precomputed +**projection matrix** that maps each student token to the teacher tokens it +most plausibly corresponds to. + +This guide explains how to: + +1. Produce the projection matrix from a (student, teacher) tokenizer pair +2. Launch a distillation run that consumes it + +## How it works + +A full run has two phases. The first three steps are *offline data prep* โ€” +small CLI tools you run once per (student, teacher) pair โ€” and the result is a +single `.pt` file. The fourth step is the actual distillation training loop. + +``` + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ Offline projection-matrix preparation โ”‚ + โ”‚ โ”‚ + (student, teacher) โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ + tokenizers + base โ”€โ”€โ”€โ–ถโ”‚ โ”‚ 1. minimal_projection_generator.py โ”‚ โ”‚ + embedding model โ”‚ โ”‚ โ€” embedding-similarity top-k โ”‚ โ”‚ + โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ + โ”‚ โ”‚ โ”‚ + โ”‚ โ–ผ โ”‚ + โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ + โ”‚ โ”‚ 2. minimal_projection_via_ โ”‚ โ”‚ + โ”‚ โ”‚ multitoken.py โ”‚ โ”‚ + โ”‚ โ”‚ โ€” add multi-token mappings โ”‚ โ”‚ + โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ + โ”‚ โ”‚ โ”‚ + โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ + โ”‚ โ”‚ 3. (optional) reapply_exact_map.py โ”‚ โ”‚ + โ”‚ โ”‚ โ€” pin exact 1-to-1 matches โ”‚ โ”‚ + โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ + โ”‚ โ”‚ โ”‚ + โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ + โ”‚ โ”‚ 4. sort_and_cut_projection_matrix โ”‚ โ”‚ + โ”‚ โ”‚ .py โ€” trim to runtime top_k โ”‚ โ”‚ + โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ projection_matrix.pt + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ 5. examples/run_xtoken_distillation.py โ”‚ + โ”‚ โ€” student forward + teacher forward โ”‚ + โ”‚ (via CUDA-IPC), x-token KD loss โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +The projection matrix is a sparse `[V_student, top_k]` tensor that the +training-time loss multiplies against the student logits to project them into +the teacher's vocab space (or vice versa, depending on the loss mode). + +## Backend and scope + +- **DTensor V2 only.** Set `policy.dtensor_cfg.enabled=true` and + `policy.dtensor_cfg._v2=true`. The Megatron path is intentionally stubbed + with `NotImplementedError`. +- **Teacher logits travel via CUDA IPC**, so student and teacher policies must + be colocated on the same node. No remote-Ray transport for x-token logits. +- **No sequence packing or dynamic batching for the teacher forward** in v0. +- The corpus must be served via the `arrow_text` dataset (no chat template, + loss on every token โ€” see `examples/configs/xtoken_distillation.yaml`). + +## Step 1 โ€” Generate the base projection matrix + +`minimal_projection_generator.py` walks both vocabularies, embeds every token +with a small embedding LLM (or a sentence-transformers model), and stores the +top-`k` teacher tokens by cosine similarity for each student token. + +```bash +uv run python -m nemo_rl.utils.x_token.minimal_projection_generator \ + --student-model "meta-llama/Llama-3.2-1B" \ + --teacher-model "Qwen/Qwen3-4B" \ + --top_k 32 \ + --force_recompute \ + --data_dir cross_tokenizer_data/ +``` + +Both `--student-model` and `--teacher-model` are required and **not swapped** +โ€” the projection direction follows the CLI args exactly. Output lands at +`cross_tokenizer_data/temp_projection_map_Llama-3.2_to_Qwen3_top_32.pt`. + +If you pick an `embedding_model_type == "sbert"` choice from +`EMBEDDING_MODEL_CHOICES`, install `sentence-transformers` first; the script +falls back to a clear `ImportError` otherwise. The default +`embedding_model_index = 3` uses `Qwen/Qwen3-Embedding-4B` and does not need +`sentence-transformers`. + +## Step 2 โ€” Add multi-token mappings + +Many student tokens (e.g., `"12"`) tokenize into multiple teacher tokens +(e.g., `"1"`, `"2"`). `minimal_projection_via_multitoken.py` walks the +student vocab, re-tokenizes each token with the teacher tokenizer, and adds +weighted entries to the projection. With `--enable-reverse-pass` it also +does the symmetric teacher โ†’ student walk. + +```bash +uv run python -m nemo_rl.utils.x_token.minimal_projection_via_multitoken \ + --student-model "meta-llama/Llama-3.2-1B" \ + --teacher-model "Qwen/Qwen3-4B" \ + --initial-projection-path cross_tokenizer_data/temp_projection_map_Llama-3.2_to_Qwen3_top_32.pt \ + --top-k 32 \ + --enable-scale-trick \ + --enable-reverse-pass \ + --enable-special-token-mapping +``` + +Output: `cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt`. + +Pass `--num-examples 50` to print a sample of studentโ†’teacher mappings after +the matrix is built โ€” useful for spot-checking that special tokens, numerals, +and punctuation map to sensible teacher tokens. + +When `--enable-scale-trick` is set, the script records `enable_scale_trick=True` +in the saved `.pt` so Step 4 can auto-enable `--preserve_last`. + +## Step 3 (optional) โ€” Reapply exact-token map + +Some token pairs are *literally identical* (e.g., common punctuation, single +ASCII characters). `reapply_exact_map.py` pins those to 1-to-1 mappings with +weight 1.0, overwriting whatever Steps 1โ€“2 produced for them. + +```bash +uv run python -m nemo_rl.utils.x_token.reapply_exact_map \ + --student-model "meta-llama/Llama-3.2-1B" \ + --teacher-model "Qwen/Qwen3-4B" \ + --initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt +``` + +Output is written next to the input as `_exact_map_remapped.pt`. + +## Step 4 โ€” Sort and trim to runtime `top_k` + +The training loss only needs a small `top_k` per row (typical: 4โ€“8). This +step sorts each row by weight and trims to the chosen runtime cap. + +```bash +uv run python -m nemo_rl.utils.x_token.sort_and_cut_projection_matrix \ + cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special_exact_map_remapped.pt \ + --top_k 4 \ + --output_path cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt +``` + +`--preserve_last` is `argparse.BooleanOptionalAction` with default `None`. When +unspecified, the script reads `enable_scale_trick` from the input matrix's +metadata (set in Step 2) and auto-enables preservation of the last column +slot. Pass `--preserve_last` or `--no-preserve_last` to override. + +## Step 5 โ€” Launch x-token distillation + +The training entrypoint is `examples/run_xtoken_distillation.py` with the +exemplar config at `examples/configs/xtoken_distillation.yaml`. The exemplar +defaults to Llama-3.2-1B (student) โ† Qwen3-4B (teacher), an arrow-text +corpus, and the P-KL loss mode. Override paths via Hydra CLI: + +```bash +uv run python examples/run_xtoken_distillation.py \ + --config examples/configs/xtoken_distillation.yaml \ + loss_fn.projection_matrix_path=cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt \ + data.train.arrow_files=/path/to/corpus/*.arrow \ + cluster.gpus_per_node=8 \ + cluster.num_nodes=1 +``` + +The exemplar config keeps `loss_fn.projection_matrix_path` and +`data.train.arrow_files` as `null` so they must be supplied at the CLI โ€” this +makes the config reusable across (student, teacher) pairs. + +### Loss-mode knobs + +`loss_fn` has two flags that pick between three behaviors: + +| `gold_loss` | `xtoken_loss` | Behavior | +|---|---|---| +| `false` | (inert) | **P-KL** โ€” full-vocab teacher logits via CUDA IPC; the loss derives a microbatch-global top-k inside, projects the student into teacher vocab via the projection matrix, and chunk-averages KL on the top-k subset. CE term is added. | +| `true` | `false` | **Gold loss** (PT-faithful) โ€” split the vocab into an *exact-token-mapped* common set (KL) and an *uncommon* tail (sorted L1). | +| `true` | `true` | **Gold + x-token loss** โ€” same as gold, but relax the exact-map threshold to `>= 0.6` and allow multi-token projections to count as exact maps via a collision-replacement rule. | + +Other relevant fields: + +- `loss_fn.temperature` โ€” softmax temperature applied symmetrically to student and teacher logits before KL. +- `loss_fn.vocab_topk` โ€” microbatch-global top-k size for the P-KL path (inert when `gold_loss=true`). +- `loss_fn.uncommon_topk` โ€” cap on the L1 uncommon-tail sort in the gold path (defaults to PT's hardcoded 8192). +- `loss_fn.reverse_kl` โ€” compute `KL(student || teacher)` instead of `KL(teacher || student)`. + +## Other (student, teacher) pairs + +The same pipeline works for any HuggingFace tokenizer pair. Two worked +examples โ€” Llama โ†’ Gemma and Llama โ†’ Phi โ€” only differ in the +`--student-model` / `--teacher-model` arguments to Steps 1 and 2. + +For Phi-3 / Phi-4 family teachers, also export +`NRL_TRUST_REMOTE_CODE=false` and `NRL_SKIP_PHI_ROPE_FIX=1` in the +training environment so the in-tree HuggingFace implementation is used. + +## Where files live + +| Stage | Tool | Default output | +|---|---|---| +| Generate base | `nemo_rl/utils/x_token/minimal_projection_generator.py` | `/temp_projection_map__to__top_.pt` | +| Add multi-token | `nemo_rl/utils/x_token/minimal_projection_via_multitoken.py` | `/projection_map__to__multitoken_top__double[_special].pt` | +| Reapply exact map | `nemo_rl/utils/x_token/reapply_exact_map.py` | `_exact_map_remapped.pt` | +| Sort and trim | `nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py` | `/_top__sorted[_preservelast].pt` (or `--output_path`) | +| Train | `examples/run_xtoken_distillation.py` | per the run's `logger.log_dir` and `checkpointing.checkpoint_dir` | + +## Related + +- Config exemplar: [`examples/configs/xtoken_distillation.yaml`](../../examples/configs/xtoken_distillation.yaml) +- Loss implementation: `nemo_rl/algorithms/loss/loss_functions.py::CrossTokenizerDistillationLossFn` +- Token alignment: `nemo_rl/algorithms/x_token/tokenalign.py::TokenAligner` +- Same-tokenizer distillation: [Quantization-Aware RL](quantization-aware-rl.md) (the QA-Distillation workflow uses the same training entrypoint with a same-tokenizer teacher). diff --git a/docs/index.md b/docs/index.md index e20b1745f5..d07b4ec50e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -135,6 +135,13 @@ Learn how to add support for new model architectures in NeMo RL. Extend a model's context window with YaRN RoPE scaling on the Megatron backend for SFT, GRPO, and other workflows. ::: +:::{grid-item-card} {octicon}`git-compare` Cross-Tokenizer Distillation +:link: guides/xtoken-distillation +:link-type: doc + +Off-policy distillation across mismatched tokenizers โ€” build a (student, teacher) projection matrix and run x-token KD via CUDA-IPC teacher logits. +::: + :::: ## Advanced Topics @@ -251,6 +258,7 @@ guides/async-grpo.md guides/quantization-aware-rl.md guides/eagle3-speculative-decoding.md guides/yarn-long-context.md +guides/xtoken-distillation.md guides/muon-optimizer.md guides/dtensor-tp-accuracy.md guides/ft-launcher-guide.md diff --git a/examples/configs/xtoken_distillation.yaml b/examples/configs/xtoken_distillation.yaml index ea043202d2..24869f9a38 100644 --- a/examples/configs/xtoken_distillation.yaml +++ b/examples/configs/xtoken_distillation.yaml @@ -14,7 +14,11 @@ distillation: val_at_end: false loss_fn: - projection_matrix_path: "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_genai/users/avenkateshha/nemo_rl/RL/cross_tokenizer_data/llama_qwen_best_special_exact_map_remapped.pt" + # Path to the (student, teacher) projection matrix .pt file. Override via + # Hydra CLI per run, e.g. + # loss_fn.projection_matrix_path=/path/to/projection_matrix.pt + # See docs/guides/xtoken-distillation.md for how to produce this file. + projection_matrix_path: null # Loss-mode selection: # gold_loss=false -> P-KL: full-vocab teacher logits # transported via IPC; the loss derives a microbatch-global top-k From 755fb8e4afc0319472201bd84c0030495a66c51b Mon Sep 17 00:00:00 2001 From: Adithya Hanasoge Date: Fri, 15 May 2026 18:37:13 -0700 Subject: [PATCH 6/6] refactor(xtoken): address PR review on projection-prep CLI tools Apply the post-review changes for PR #2347 on the cross-tokenizer projection-prep utilities under nemo_rl/utils/x_token/: minimal_projection_via_multitoken.py - Re-add optional --output-filename stem (default None falls back to the auto-derived name) so recipe-driven runs can pin the filename, matching the contract used by the reference recipes. - Extend the gemma vocab-size branch to also fire for qwen3.5, which nests vocab_size under config.text_config on both student and teacher sides. - Document the .pt save-key schema near torch.save (student_model_id / teacher_model_id / enable_scale_trick) with a pointer to the legacy-key fallback in the load path. minimal_projection_generator.py - Add the same schema annotation near torch.save. reapply_exact_map.py - Validate the loaded projection map is a dict containing 'indices' and 'likelihoods'; raise ValueError with the file path on mismatch instead of surfacing a confusing KeyError mid-loop. sort_and_cut_projection_matrix.py - Lift the argparse block out of main() into a module-level parse_arguments() helper, matching the shape used by the other x_token CLI scripts. - Lift the verbose stats block out of sort_and_cut_projection_matrix() into a print_projection_statistics() helper. - Replace the positional input_path with a named --initial-projection-path flag for naming consistency with the rest of the x_token CLI tools. docs/guides/xtoken-distillation.md - Update Step 4 usage example to the new --initial-projection-path flag. Behaviour preserved: the Llama-3.2-3B / Qwen3-4B-Base recipe still produces bitwise-equal indices and likelihoods against the canonical llama_qwen_best_special_exact_map_remapped.pt artifact. Signed-off-by: Adithya Hanasoge --- docs/guides/xtoken-distillation.md | 2 +- .../x_token/minimal_projection_generator.py | 3 + .../minimal_projection_via_multitoken.py | 57 ++- nemo_rl/utils/x_token/reapply_exact_map.py | 10 + .../x_token/sort_and_cut_projection_matrix.py | 405 ++++++++++-------- 5 files changed, 289 insertions(+), 188 deletions(-) diff --git a/docs/guides/xtoken-distillation.md b/docs/guides/xtoken-distillation.md index fc03cc7725..6d100f6629 100644 --- a/docs/guides/xtoken-distillation.md +++ b/docs/guides/xtoken-distillation.md @@ -143,7 +143,7 @@ step sorts each row by weight and trims to the chosen runtime cap. ```bash uv run python -m nemo_rl.utils.x_token.sort_and_cut_projection_matrix \ - cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special_exact_map_remapped.pt \ + --initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special_exact_map_remapped.pt \ --top_k 4 \ --output_path cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt ``` diff --git a/nemo_rl/utils/x_token/minimal_projection_generator.py b/nemo_rl/utils/x_token/minimal_projection_generator.py index 877a1669c4..4ea6baf6ed 100644 --- a/nemo_rl/utils/x_token/minimal_projection_generator.py +++ b/nemo_rl/utils/x_token/minimal_projection_generator.py @@ -520,6 +520,9 @@ def generate_projection_map(similarities, args): output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") output_path = os.path.join(args.data_dir, output_filename) + # Metadata keys use the student/teacher framing; legacy `model_A_id`/ + # `model_B_id` from older PT-reference artifacts are accepted on the + # load side in minimal_projection_via_multitoken.py. torch.save( { "indices": top_k_indices_student_to_teacher.cpu(), diff --git a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py index 3e3c3a1c72..0af22d9155 100644 --- a/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py +++ b/nemo_rl/utils/x_token/minimal_projection_via_multitoken.py @@ -480,6 +480,17 @@ def parse_arguments(): default="cross_tokenizer_data", help="Output directory for saving projection maps" ) + parser.add_argument( + "--output-filename", + type=str, + default=None, + help=( + "Optional output filename stem (without extension). When unset, " + "the stem is auto-derived from cleaned student and teacher model " + "names; recipe-driven runs (e.g. tokenalign/commands.txt) pass " + "an explicit stem like 'llama_qwen_best' to lock the filename." + ), + ) return parser.parse_args() @@ -532,15 +543,20 @@ def parse_arguments(): tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) model_A_config = AutoConfig.from_pretrained(student_model_name) model_B_config = AutoConfig.from_pretrained(teacher_model_name) - if "gemma" not in student_model_name.lower(): - source_vocab_size = model_A_config.vocab_size - else: + # Gemma and Qwen3.5 nest `vocab_size` under `config.text_config`; the + # rest of the supported families expose it directly on the top-level + # config. Mirrors the PT reference. + student_name_lower = student_model_name.lower() + if "gemma" in student_name_lower or "qwen3.5" in student_name_lower: source_vocab_size = model_A_config.text_config.vocab_size - - if "gemma" not in teacher_model_name.lower(): - target_vocab_size = model_B_config.vocab_size else: + source_vocab_size = model_A_config.vocab_size + + teacher_name_lower = teacher_model_name.lower() + if "gemma" in teacher_name_lower or "qwen3.5" in teacher_name_lower: target_vocab_size = model_B_config.text_config.vocab_size + else: + target_vocab_size = model_B_config.vocab_size tokenizer_student_total_vocab_size = source_vocab_size tokenizer_teacher_total_vocab_size = target_vocab_size @@ -856,24 +872,33 @@ def parse_arguments(): #set indices to -1 where likelihood is 0 - # Create filename in same format as minimal_projection_generator.py - student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) - teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) - - output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" - # if USE_RAW_TOKENS: - # output_filename += "_raw_tokens" + # Build the output filename. If the caller provided an explicit + # `--output-filename` stem, honor it (recipe-driven runs lock the name); + # otherwise auto-derive from cleaned student/teacher names so ad-hoc + # runs produce a self-describing default that matches the format used by + # minimal_projection_generator.py. + if args.output_filename is not None: + output_filename = args.output_filename + else: + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + output_filename = ( + f"projection_map_{student_clean_name}_to_{teacher_clean_name}" + f"_multitoken_top_{TOP_K}_double" + ) if ENABLE_SPECIAL_TOKEN_MAPPING: output_filename += "_special" - output_filename += ".pt" - # if ENABLE_REVERSE_PASS: - # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + if not output_filename.endswith(".pt"): + output_filename += ".pt" output_path = os.path.join(args.output_dir, output_filename) # Save in same format as minimal_projection_generator.py. # `enable_scale_trick` is persisted so downstream tools (e.g. # `sort_and_cut_projection_matrix.py`) can decide whether the last # column carries a tunable scale slot without the user re-specifying. + # Metadata keys use the student/teacher framing; legacy `model_A_id`/ + # `model_B_id` produced by older PT-reference artifacts are accepted on + # the load side (see the fallback near line 574). torch.save({ "indices": top_k_indices, "likelihoods": top_k_likelihoods, diff --git a/nemo_rl/utils/x_token/reapply_exact_map.py b/nemo_rl/utils/x_token/reapply_exact_map.py index bd48103d3a..310dd7d2f8 100644 --- a/nemo_rl/utils/x_token/reapply_exact_map.py +++ b/nemo_rl/utils/x_token/reapply_exact_map.py @@ -182,6 +182,16 @@ def parse_arguments(): # load initial projection map initial_projection_path = args.initial_projection_path initial_projection_map = torch.load(initial_projection_path) + if not ( + isinstance(initial_projection_map, dict) + and "indices" in initial_projection_map + and "likelihoods" in initial_projection_map + ): + raise ValueError( + f"Projection map at {initial_projection_path} is not a dict with " + f"'indices' and 'likelihoods' tensors; got " + f"{type(initial_projection_map).__name__}." + ) # go through token in projection map. For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) diff --git a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py index b925732a18..9d08579a4b 100644 --- a/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py +++ b/nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py @@ -13,10 +13,37 @@ # limitations under the License. import argparse import os +import re import torch import tqdm + +def parse_arguments() -> argparse.Namespace: + """Parse CLI arguments for the sort-and-cut script.""" + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument( + "--initial-projection-path", + type=str, + required=True, + help="Path to the input projection matrix .pt file.", + ) + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", type=str, default=None, help="Output path (auto-generated if not specified)") + parser.add_argument( + "--preserve_last", + action=argparse.BooleanOptionalAction, + default=None, + help=( + "Force-enable or force-disable preserving the last column. " + "If unspecified, the value is read from the input projection map's " + "`enable_scale_trick` metadata (and defaults to False if absent)." + ), + ) + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + return parser.parse_args() + + def sinkhorn_one_dim(A, n_iters=1): """Apply Sinkhorn normalization to make each row sum to 1.""" for _ in range(n_iters): @@ -26,6 +53,191 @@ def sinkhorn_one_dim(A, n_iters=1): A = A / safe_row_sums return A + +def print_projection_statistics( + *, + output_path, + new_indices, + new_likelihoods, + normalized_likelihoods, + vocab_size, + new_top_k, + preserve_last, + rows_with_order_change, + rows_with_preserved_last, + rows_with_max_in_last_column, + max_element_positions, + final_max_element_positions, + significant_components_count, + threshold_for_significance, +): + """Print verbose statistics about the trimmed projection matrix. + + Args: + output_path: Path the new projection matrix was saved to. + new_indices: Trimmed [vocab_size, new_top_k] indices tensor. + new_likelihoods: Trimmed [vocab_size, new_top_k] likelihoods (pre-normalization). + normalized_likelihoods: Sinkhorn-normalized likelihoods tensor. + vocab_size: Number of rows in the projection matrix. + new_top_k: Top-k cutoff applied during trimming. + preserve_last: Whether the last column was preserved during trimming. + rows_with_order_change: Count of rows whose top-k order changed after sorting. + rows_with_preserved_last: Count of rows that received a preserved last element. + rows_with_max_in_last_column: Count of rows whose original max sat in the last column. + max_element_positions: Histogram of max-element position in the original matrix. + final_max_element_positions: Histogram of max-element position after sorting. + significant_components_count: Per-component counts above `threshold_for_significance`. + threshold_for_significance: Likelihood threshold for the significance histogram. + """ + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() + for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + + def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): """ Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. @@ -81,9 +293,6 @@ def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_ # Track position of maximum element in final sorted and trimmed matrix final_max_element_positions = {} # position -> count - # threshold_for_significance = 0.05 # Threshold for considering a component "significant" - # threshold_for_significance = 0.05 # Threshold for considering a component "significant" - if verbose: print("Sorting and cutting each row...") @@ -195,7 +404,6 @@ def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_ for comp_idx in range(min(num_to_take, len(significant_components_count))): if sorted_likelihoods[comp_idx] >= threshold_for_significance: significant_components_count[comp_idx] += 1 - # if significant_components_count[1] > 0.0: # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 @@ -242,176 +450,31 @@ def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_ torch.save(output_data, output_path) if verbose: - print(f"Saved sorted and cut projection matrix to: {output_path}") - print(f"New matrix shape: {new_indices.shape}") - - # Show basic statistics - non_zero_counts = (new_likelihoods > 0).sum(dim=1) - avg_non_zero = non_zero_counts.float().mean().item() - print(f"Average non-zero entries per row: {avg_non_zero:.2f}") - print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") - - # Show ordering statistics - print(f"\n=== Ordering Statistics ===") - print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") - if preserve_last: - print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") - - # Show last column maximum element statistics - print(f"\n=== Last Column Maximum Element Statistics ===") - total_rows_with_data = sum(max_element_positions.values()) - if total_rows_with_data > 0: - percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data - print(f"Rows with maximum element in LAST column: {rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") - print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") - else: - print(f"No valid data found to analyze last column statistics") + print_projection_statistics( + output_path=output_path, + new_indices=new_indices, + new_likelihoods=new_likelihoods, + normalized_likelihoods=normalized_likelihoods, + vocab_size=vocab_size, + new_top_k=new_top_k, + preserve_last=preserve_last, + rows_with_order_change=rows_with_order_change, + rows_with_preserved_last=rows_with_preserved_last, + rows_with_max_in_last_column=rows_with_max_in_last_column, + max_element_positions=max_element_positions, + final_max_element_positions=final_max_element_positions, + significant_components_count=significant_components_count, + threshold_for_significance=threshold_for_significance, + ) - # Show maximum element position distribution - print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") - total_rows_with_data = sum(max_element_positions.values()) - print(f"Total rows with valid data: {total_rows_with_data:,}") - - # Sort positions for ordered display - sorted_positions = sorted(max_element_positions.keys()) - for pos in sorted_positions[:20]: # Show up to first 20 positions - count = max_element_positions[pos] - percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 - ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" - print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") - - if len(sorted_positions) > 20: - remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) - remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 - print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") - - # Show final maximum element position distribution (after sorting and normalization) - print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") - total_final_rows_with_data = sum(final_max_element_positions.values()) - print(f"Total rows with valid data: {total_final_rows_with_data:,}") - - if total_final_rows_with_data > 0: - # Sort positions for ordered display - sorted_final_positions = sorted(final_max_element_positions.keys()) - for pos in sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions - count = final_max_element_positions[pos] - percentage = 100 * count / total_final_rows_with_data - ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" - print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") - - if len(sorted_final_positions) > min(new_top_k, 20): - remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) - remaining_percentage = 100 * remaining_count / total_final_rows_with_data - print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") - - # Show significant components statistics - print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") - component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] - for i, count in enumerate(significant_components_count): - percentage = 100 * count / vocab_size if vocab_size > 0 else 0 - print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") - - # Additional analysis: distribution of likelihood values (after normalization) - all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] - if len(all_likelihoods) > 0: - print(f"\n=== Likelihood Distribution ===") - print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") - print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") - print(f"Median likelihood: {all_likelihoods.median().item():.4f}") - print(f"Min likelihood: {all_likelihoods.min().item():.4f}") - print(f"Max likelihood: {all_likelihoods.max().item():.4f}") - - # Show percentiles - convert to float for quantile calculation - percentiles = [90, 95, 99] - all_likelihoods_float = all_likelihoods.float() - for p in percentiles: - val = torch.quantile(all_likelihoods_float, p/100.0).item() - print(f"{p}th percentile: {val:.4f}") - - # Show how many rows have multiple significant components - print(f"\n=== Multi-Component Analysis ===") - rows_with_multiple_significant = 0 - for row_idx in range(vocab_size): - significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() - if significant_in_row >= 2: - rows_with_multiple_significant += 1 - - percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 - print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") - - # Show normalization effect - print(f"\n=== Normalization Effect ===") - # Calculate row sums for ALL rows (including zero rows) - all_row_sums = normalized_likelihoods.sum(dim=1) - non_zero_rows = (normalized_likelihoods > 0).any(dim=1) - zero_rows = ~non_zero_rows - - print(f"Total rows: {vocab_size:,}") - print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") - print(f"Rows with all zeros: {zero_rows.sum().item():,}") - - if non_zero_rows.any(): - row_sums_nonzero = all_row_sums[non_zero_rows] - print(f"\nNon-zero rows statistics:") - print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") - print(f" Std sum: {row_sums_nonzero.std().item():.6f}") - print(f" Min sum: {row_sums_nonzero.min().item():.6f}") - print(f" Max sum: {row_sums_nonzero.max().item():.6f}") - - # Check how many rows don't sum to 1 (with different tolerance levels) - tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] - for tol in tolerances: - perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() - imperfect_rows = len(row_sums_nonzero) - perfect_rows - percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) - print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") - - # Show distribution of row sums that deviate from 1.0 - if non_zero_rows.any(): - row_sums_nonzero = all_row_sums[non_zero_rows] - deviations = torch.abs(row_sums_nonzero - 1.0) - significant_deviations = deviations > 1e-3 - - if significant_deviations.any(): - print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") - worst_deviations = deviations[significant_deviations] - print(f" Mean deviation: {worst_deviations.mean().item():.6f}") - print(f" Max deviation: {worst_deviations.max().item():.6f}") - - # Show some examples of problematic rows - worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] - print(f" Worst {min(5, len(worst_indices))} row examples:") - for i, idx in enumerate(worst_indices): - actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() - sum_val = row_sums_nonzero[idx].item() - deviation = deviations[idx].item() - non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() - print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") - else: - print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") def main(): - parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") - parser.add_argument("input_path", help="Path to input projection matrix file") - parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") - parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") - parser.add_argument( - "--preserve_last", - action=argparse.BooleanOptionalAction, - default=None, - help=( - "Force-enable or force-disable preserving the last column. " - "If unspecified, the value is read from the input projection map's " - "`enable_scale_trick` metadata (and defaults to False if absent)." - ), - ) - parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") - args = parser.parse_args() + args = parse_arguments() # Resolve preserve_last from CLI override -> projection-map metadata -> default False. if args.preserve_last is None: try: - meta = torch.load(args.input_path, map_location="cpu", weights_only=False) + meta = torch.load(args.initial_projection_path, map_location="cpu", weights_only=False) except (FileNotFoundError, RuntimeError): meta = {} preserve_last = bool(meta.get("enable_scale_trick", False)) if isinstance(meta, dict) else False @@ -425,14 +488,13 @@ def main(): # Auto-generate output path if not specified if args.output_path is None: - input_dir = os.path.dirname(args.input_path) - input_filename = os.path.basename(args.input_path) + input_dir = os.path.dirname(args.initial_projection_path) + input_filename = os.path.basename(args.initial_projection_path) # Extract base name and extension base_name, ext = os.path.splitext(input_filename) # Remove old top_k info if present - import re base_name = re.sub(r"_top_\d+", "", base_name) # Add new top_k info and preserve_last flag @@ -449,7 +511,7 @@ def main(): # Process the matrix sort_and_cut_projection_matrix( - args.input_path, + args.initial_projection_path, args.output_path, args.top_k, preserve_last=preserve_last, @@ -458,5 +520,6 @@ def main(): if not args.quiet: print(f"Output written to: {args.output_path}") + if __name__ == "__main__": main()