feat(xtoken): cross-tokenizer off-policy distillation #2507
Closed
avenkateshha wants to merge 6 commits into
Adds a single-teacher cross-tokenizer distillation pipeline running on
DTensor V2 with FSDP2. The student and teacher use different tokenizers;
a learned projection matrix maps student-vocab probabilities into the
teacher's vocab so chunk-averaged P-KL can be computed across the
alignment between the two token streams.
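The projection and chunk-averaging idea above can be sketched in plain Python. This is a minimal illustration, not the code in this PR: the names `project`, `chunk_mean`, `kl`, and `chunk_kl`, the dict-of-pairs projection layout, and the KL direction (teacher against projected student) are all assumptions made for the example.

```python
import math

def project(student_probs, proj):
    """Map one student-vocab distribution into the teacher vocab.

    proj[s_id] is a list of (teacher_id, weight) pairs (hypothetical layout).
    """
    out = {}
    for s_id, p in enumerate(student_probs):
        for t_id, w in proj.get(s_id, []):
            out[t_id] = out.get(t_id, 0.0) + p * w
    return out

def chunk_mean(dists, positions, vocab_size):
    """Average the sparse distributions at `positions` into one dense vector."""
    mean = [0.0] * vocab_size
    for pos in positions:
        for t_id, p in dists[pos].items():
            mean[t_id] += p / len(positions)
    return mean

def kl(p, q, eps=1e-12):
    """KL(p || q) over dense vectors, skipping zero-mass entries of p."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

def chunk_kl(student_probs, teacher_probs, proj, chunks, teacher_vocab_size):
    """Mean over aligned chunks of KL(teacher chunk mean || projected student chunk mean)."""
    projected = [project(row, proj) for row in student_probs]
    teacher = [dict(enumerate(row)) for row in teacher_probs]
    losses = []
    for s_pos, t_pos in chunks:
        q = chunk_mean(projected, s_pos, teacher_vocab_size)
        p = chunk_mean(teacher, t_pos, teacher_vocab_size)
        losses.append(kl(p, q))
    return sum(losses) / len(losses)

# Student vocab of 3 tokens, teacher vocab of 2 tokens.
proj = {0: [(0, 1.0)], 1: [(1, 1.0)], 2: [(0, 0.5), (1, 0.5)]}
student = [[0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]  # two student positions
teacher = [[0.5, 0.5]]                        # one teacher position
loss = chunk_kl(student, teacher, proj, [([0, 1], [0])], 2)
```

Here the two student positions and the single teacher position form one chunk, and both chunk means come out uniform, so the loss is zero.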
New entrypoints
- examples/run_xtoken_distillation.py and
examples/configs/xtoken_distillation.yaml exemplar.
- nemo_rl/algorithms/xtoken_distillation.py training loop.
New data path
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
ArrowTextDataset for raw arrow shards.
- nemo_rl/data/processors.py: kd_data_processor (raw-text path).
- nemo_rl/data/cross_tokenizer_collate.py: collator that tokenizes the
same source twice and runs TokenAligner.align inside DataLoader
workers.
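To make "tokenize twice, then align" concrete, here is a small sketch that groups two tokenizations of the same text into chunks covering equal character spans. It is a greedy character-length alignment, which is a simpler stand-in and not the TokenAligner.align algorithm; the function name `align_chunks` is hypothetical, and the sketch assumes both token lists concatenate to the same string.

```python
def align_chunks(student_tokens, teacher_tokens):
    """Greedy alignment of two tokenizations of the same text.

    Returns a list of (student_positions, teacher_positions) pairs, where each
    pair covers the same character span of the source text.
    """
    chunks = []
    i = j = 0            # next unread token in each stream
    s_len = t_len = 0    # characters consumed so far in each stream
    s_pos, t_pos = [], []
    while i < len(student_tokens) or j < len(teacher_tokens):
        # Advance whichever stream has covered fewer characters so far.
        if s_len <= t_len and i < len(student_tokens):
            s_len += len(student_tokens[i])
            s_pos.append(i)
            i += 1
        elif j < len(teacher_tokens):
            t_len += len(teacher_tokens[j])
            t_pos.append(j)
            j += 1
        else:
            s_len += len(student_tokens[i])
            s_pos.append(i)
            i += 1
        # A chunk closes when both streams reach the same character offset.
        if s_len == t_len and s_pos and t_pos:
            chunks.append((s_pos, t_pos))
            s_pos, t_pos = [], []
    return chunks

chunks = align_chunks(["12", "3"], ["1", "2", "3"])
```

In this example the student token "12" lines up with the teacher tokens "1" and "2", and the trailing "3" forms its own one-to-one chunk.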
New algorithm package
- nemo_rl/algorithms/x_token/{__init__.py, tokenalign.py}: TokenAligner
with projection-matrix loader and chunk alignment.
Loss
- nemo_rl/algorithms/loss/loss_functions.py: CrossTokenizerDistillationLossFn
implements chunk-mean P-KL with PT-matched dynamic loss scaling.
Sparse projection wrapped in _Fp32SparseMM (custom autograd.Function
with amp custom_fwd/custom_bwd) to bypass the missing CUDA BF16
addmm_sparse_cuda kernel on both forward and backward. _load_M
filters -1 sentinels and sizes V_t to the teacher's full vocab.
Per-step accuracy / projection-accuracy diagnostics and an
env-gated per-microbatch loss dump (NRL_XTOKEN_LOSS_DUMP_DIR) for
PT-vs-NRL parity comparison.
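The sentinel filtering described for _load_M can be illustrated as follows. This is a sketch under assumptions, not the PR's loader: `build_coo` is a hypothetical name, and the input layout (one row of top-k teacher ids and one row of weights per student token, padded with -1) is inferred from the description above.

```python
def build_coo(indices, likelihoods, teacher_vocab_size):
    """Turn padded top-k rows into COO triplets for a [V_s, V_t] sparse matrix.

    indices[s] holds teacher ids for student token s; -1 marks an unused slot.
    """
    rows, cols, vals = [], [], []
    for s_id, (idx_row, w_row) in enumerate(zip(indices, likelihoods)):
        for t_id, w in zip(idx_row, w_row):
            if t_id < 0:
                continue  # drop the -1 sentinel instead of indexing with it
            if t_id >= teacher_vocab_size:
                raise ValueError(f"teacher id {t_id} outside vocab of {teacher_vocab_size}")
            rows.append(s_id)
            cols.append(t_id)
            vals.append(float(w))
    # The column count is the teacher's full vocab, not max(cols) + 1.
    shape = (len(indices), teacher_vocab_size)
    return rows, cols, vals, shape

rows, cols, vals, shape = build_coo(
    indices=[[0, -1], [2, 1]],
    likelihoods=[[1.0, 0.0], [0.7, 0.3]],
    teacher_vocab_size=5,
)
```

Sizing the second dimension from the teacher vocab rather than from the largest index seen keeps the matrix shape compatible with teacher logits even when the highest teacher ids never appear in the projection.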
Worker / dispatcher plumbing
- nemo_rl/models/automodel/data.py: check_sequence_dim(skip_keys=…) so
CT auxiliary tensors that don't follow [B, student_seq, …] can ride
on the same BatchedDataDict without tripping the seq-dim guard.
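A rough sketch of what a sequence-dimension guard with a skip list might look like, operating on shapes only. The function name `check_seq_dim_sketch` and its signature are hypothetical; the actual check_sequence_dim signature is not shown in this message.

```python
def check_seq_dim_sketch(shapes, seq_len, skip_keys=()):
    """Raise if any non-skipped tensor shape disagrees on dimension 1.

    shapes maps a batch key to its tensor shape as a tuple.
    """
    for key, shape in shapes.items():
        if key in skip_keys:
            continue  # auxiliary tensors with their own sequence length
        if len(shape) >= 2 and shape[1] != seq_len:
            raise ValueError(f"{key}: expected seq dim {seq_len}, got {shape[1]}")

batch_shapes = {"input_ids": (4, 128), "teacher_input_ids": (4, 96)}
check_seq_dim_sketch(batch_shapes, 128, skip_keys=("teacher_input_ids",))
```

Without the skip list, the teacher-length tensor in this example would trip the guard, because its second dimension follows the teacher tokenization rather than the student's.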
- nemo_rl/models/automodel/train.py: register GlobalTopkLogitsPostProcessor
in forward_with_post_processing_fn and add
XTokenTeacherIPC{Loss,Export}PostProcessor subclasses.
- nemo_rl/models/policy/lm_policy.py + workers/dtensor_policy_worker_v2.py:
plumb skip_keys through Policy.train and the worker.
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- xtoken_distillation runner now calls teacher_policy.get_global_topk_logits_ipc to receive teacher per-sample top-k logits via CUDA IPC handles instead of materializing them on the driver. The driver passes the handles into student_policy.train via teacher_topk_logits_ipc, and workers rebuild the tensors locally with rebuild_cuda_tensor_from_ipc.
- lm_policy: new dispatchers (get_global_topk_logits_ipc and release_ipc_buffer) plus an ipc-handle-to-batch pairing assert. Fix release_ipc_buffer to use ray.get(futures) (matches every other run_all_workers_single_data caller).
- dtensor_policy_worker_v2: worker-side get_global_topk_logits_ipc, IPC tensor stash, release_ipc_buffer, and matching train-side unpack of teacher_topk_logits_ipc.
- Loss fn unpacks the IPC handle list lazily inside _compute_p_kl.
- YAML default: topk_logits_k / vocab_topk lifted to 8192 (was 64).
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Ports the PT reference `compute_KL_loss_optimized` gold path (tokenalign.py:3494-3829) into CrossTokenizerDistillationLossFn:
- _build_exact_token_map: vectorized port of PT 3493-3594. Partitions the student/teacher vocabs into common (exact-1-1 mapped) and uncommon sets. Strict mode keeps the lowest-s_idx winner per teacher id; xtoken_loss=true relaxes the exact-map threshold from `== 1.0` to `>= 0.6` and elects the highest-projection-weight student per teacher id via scatter_reduce(amax) with a min-s_idx tiebreaker. Cached per-device for the run since the map depends only on the immutable projection file + xtoken_loss flag.
- _compute_gold: log_softmax + chunk-average on full teacher vocab both sides, KL on the common slice (forward or reverse), L1 on sorted uncommon probs (topk-capped at cfg["uncommon_topk"], default 8192, matching PT line 3727). Combined as (kl_common + l1_uncommon) * T**2. Top-1 accuracy on the common slice over valid chunks.
- Shared helpers _rebuild_teacher_full_logits / _chunk_average_log_probs / _valid_chunk_mask reused by both _compute_p_kl and _compute_gold.
- __init__ now requires gold_loss=True when xtoken_loss=True; the older "<=1 True" mutex was wrong (PT's canonical CT run sets both, and the flag is a modifier inside the gold path).
Unifies the teacher-logits transport. Replaces the per-sample top-k GlobalTopkLogitsPostProcessor + get_global_topk_logits[_ipc] pair with a single FullLogitsPostProcessor + get_full_logits_ipc that ships the entire [T_t, V_t] tensor per sample via CUDA IPC. The legacy P-KL path stays correct by computing its microbatch-global top-k inline (max over flat (B*T_t) -> topk over V_t -- matches PT `global_top_indices` at tokenalign.py:3877) instead of at the worker. Closer to PT shape (full vocab in-process) and removes worker-side top-k plumbing.
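The strict versus relaxed election rule described for _build_exact_token_map can be illustrated with a non-vectorized sketch. This is not the PR's implementation (which is described as using scatter_reduce); `build_exact_map` is a hypothetical name, the padded top-k input layout is an assumption, and only the per-teacher-id election is modeled.

```python
def build_exact_map(indices, likelihoods, relaxed=False, threshold=0.6):
    """Elect one student id per teacher id.

    Strict mode accepts only weight == 1.0 and keeps the lowest student id.
    Relaxed mode accepts weight >= threshold and keeps the highest weight,
    breaking ties by the lowest student id.
    """
    best = {}  # teacher_id -> (weight, student_id)
    for s_id, (idx_row, w_row) in enumerate(zip(indices, likelihoods)):
        for t_id, w in zip(idx_row, w_row):
            if t_id < 0:
                continue  # skip padding slots
            eligible = (w >= threshold) if relaxed else (w == 1.0)
            if not eligible:
                continue
            current = best.get(t_id)
            # Student ids are visited in ascending order, so keeping the first
            # entry on ties yields the lowest student id.
            if current is None or (relaxed and w > current[0]):
                best[t_id] = (w, s_id)
    return {t_id: s_id for t_id, (_, s_id) in best.items()}

indices = [[0, -1], [0, 1], [2, 1]]
likelihoods = [[1.0, 0.0], [1.0, 0.0], [0.7, 0.3]]
strict_map = build_exact_map(indices, likelihoods)
relaxed_map = build_exact_map(indices, likelihoods, relaxed=True)
```

In strict mode only teacher id 0 is common, won by student 0 over student 1. Relaxed mode additionally admits teacher id 2 through the 0.7 weight from student 2.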
Handles the common HF case where lm_head pads out_features beyond len(tokenizer) (Qwen3-4B: 151936 vs 151669): both _compute_p_kl and _compute_gold slice teacher_full_logits to v_t (projection's V_t = real tokenizer vocab) before any downstream math. Padded columns are non-real tokens; dropping them matches the de-facto NRL contract.
Config: adds CrossTokenizerDistillationLossConfig.uncommon_topk; drops XTokenDistillationConfig.topk_logits_k (derived locally from loss_fn.vocab_topk). examples/configs/xtoken_distillation.yaml updated.
Runner: XTOKEN_NON_STUDENT_SEQ_KEYS swaps the top-k keys for teacher_full_logits_ipc. Both training and validate() switch transport call sites. Per-MB metric reduction + logging handle both metric shapes (kl_loss/ce_loss vs kl_common/l1_uncommon). _maybe_dump_loss records whichever keys are present so the same dump file format serves both paths -- preserves the NRL_XTOKEN_LOSS_DUMP_DIR parity-check protocol.
Verified end-to-end with 100-step 2-node interactive smokes:
- P-KL (gold_loss=false): job 11752927 COMPLETED in 11:43, train:loss=3.10, ckpt at step 100.
- Gold (gold_loss=true, xtoken_loss=true): job 11754078 COMPLETED in 12:39, train:loss=45.69, KL(common)~0.47, L1(uncommon)~0.002, top-1 accuracy on common ~76.7%, ckpt at step 100.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
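The vocab slicing and the inline microbatch-global top-k described in this commit can be sketched with nested lists. Both function names are hypothetical, and the rows here stand for the flattened (B*T_t) positions; the real code operates on tensors.

```python
def slice_to_vocab(logits, v_t):
    """Drop lm_head padding columns so each row has exactly v_t entries."""
    return [row[:v_t] for row in logits]

def global_topk_indices(logits, k):
    """Pick k vocab ids by their maximum logit over all flattened positions."""
    vocab = len(logits[0])
    col_max = [max(row[c] for row in logits) for c in range(vocab)]
    # Sort by descending column max; break ties by the lower vocab id.
    order = sorted(range(vocab), key=lambda c: (-col_max[c], c))
    return sorted(order[:k])

# Column 3 plays the role of a padded lm_head column with a large logit.
padded = [[0.1, 2.0, -1.0, 9.0],
          [3.0, 0.5, -2.0, 9.0]]
real = slice_to_vocab(padded, v_t=3)
top2 = global_topk_indices(real, k=2)
```

Slicing before the top-k matters in this toy case: taken on the padded rows, the padding column would win a top-k slot even though it corresponds to no real token.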
Adds nemo_rl/utils/x_token/ — the four offline CLI tools that produce
the projection matrix consumed at training time by
CrossTokenizerDistillationLossFn:
- minimal_projection_generator.py: embedding-similarity top-k between
student and teacher vocabularies.
- minimal_projection_via_multitoken.py: augments the projection with
multi-token mappings (e.g. "12" -> "1", "2") in both directions.
- reapply_exact_map.py: pins literally-identical tokens to 1-to-1
mappings with weight 1.0.
- sort_and_cut_projection_matrix.py: trims to runtime top_k; reads
enable_scale_trick metadata to auto-enable --preserve_last.
All four scripts take --student-model and --teacher-model as required
arguments; the projection direction follows the CLI args exactly (no
alphabetical-order swap). Common helpers (add_multitoken_mappings,
print_projection_map_examples, print_projection_statistics,
clean_model_name_for_filename) live at module scope. Cross-script dict
keys renamed model_A_id -> student_model_id (legacy keys still read).
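As a rough picture of what the exact-map and trim steps do to the projection rows, consider the sketch below. It is not the code of reapply_exact_map.py or sort_and_cut_projection_matrix.py: the names `pin_exact` and `sort_and_cut`, the list-based layout, the 0.0 padding weight, and the absence of renormalization are assumptions for illustration, and --preserve_last is not modeled.

```python
def pin_exact(indices, likelihoods, student_vocab, teacher_vocab):
    """Pin tokens with identical strings to a single mapping of weight 1.0.

    student_vocab: token strings indexed by student id.
    teacher_vocab: dict from token string to teacher id.
    """
    for s_id, token in enumerate(student_vocab):
        t_id = teacher_vocab.get(token)
        if t_id is not None:
            width = len(indices[s_id])
            indices[s_id] = [t_id] + [-1] * (width - 1)
            likelihoods[s_id] = [1.0] + [0.0] * (width - 1)

def sort_and_cut(indices, likelihoods, top_k):
    """Keep the top_k heaviest mappings per student token, padded with -1."""
    out_idx, out_w = [], []
    for idx_row, w_row in zip(indices, likelihoods):
        pairs = [(w, t) for t, w in zip(idx_row, w_row) if t >= 0]
        pairs.sort(key=lambda p: (-p[0], p[1]))  # heaviest first
        pairs = pairs[:top_k]
        pad = top_k - len(pairs)
        out_idx.append([t for _, t in pairs] + [-1] * pad)
        out_w.append([w for w, _ in pairs] + [0.0] * pad)
    return out_idx, out_w

indices = [[5, 2, 7], [4, 1, 3]]
likelihoods = [[0.1, 0.6, 0.3], [0.5, 0.3, 0.2]]
pin_exact(indices, likelihoods, student_vocab=["cat", "the"], teacher_vocab={"the": 9})
cut_idx, cut_w = sort_and_cut(indices, likelihoods, top_k=2)
```

The student token "the" exists verbatim in the teacher vocab, so its row collapses to a single weight-1.0 entry, while the row for "cat" keeps its two heaviest candidates.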
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Adds docs/guides/xtoken-distillation.md walking through the four-step offline projection-prep pipeline (generate -> add multi-token -> reapply exact map -> sort and trim) followed by the distillation training run via examples/run_xtoken_distillation.py. Covers backend constraints (DTensor V2 only, CUDA-IPC teacher logits, arrow_text corpus), the three loss modes (P-KL, gold, gold+xtoken), and worked examples for Qwen, Gemma, and Phi-3/Phi-4 teacher pairs. Registered in docs/index.md under Guides.
Also nulls out the hardcoded /lustre user-specific projection_matrix_path in examples/configs/xtoken_distillation.yaml and replaces it with an inline comment pointing at the new tutorial; the path is now expected to be supplied via Hydra CLI per run, matching the data.train.arrow_files pattern already used in the same file.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Apply the post-review changes for PR NVIDIA-NeMo#2347 on the cross-tokenizer projection-prep utilities under nemo_rl/utils/x_token/:
minimal_projection_via_multitoken.py
- Re-add optional --output-filename stem (default None falls back to the auto-derived name) so recipe-driven runs can pin the filename, matching the contract used by the reference recipes.
- Extend the gemma vocab-size branch to also fire for qwen3.5, which nests vocab_size under config.text_config on both student and teacher sides.
- Document the .pt save-key schema near torch.save (student_model_id / teacher_model_id / enable_scale_trick) with a pointer to the legacy-key fallback in the load path.
minimal_projection_generator.py
- Add the same schema annotation near torch.save.
reapply_exact_map.py
- Validate the loaded projection map is a dict containing 'indices' and 'likelihoods'; raise ValueError with the file path on mismatch instead of surfacing a confusing KeyError mid-loop.
sort_and_cut_projection_matrix.py
- Lift the argparse block out of main() into a module-level parse_arguments() helper, matching the shape used by the other x_token CLI scripts.
- Lift the verbose stats block out of sort_and_cut_projection_matrix() into a print_projection_statistics() helper.
- Replace the positional input_path with a named --initial-projection-path flag for naming consistency with the rest of the x_token CLI tools.
docs/guides/xtoken-distillation.md
- Update Step 4 usage example to the new --initial-projection-path flag.
Behaviour preserved: the Llama-3.2-3B / Qwen3-4B-Base recipe still produces bitwise-equal indices and likelihoods against the canonical llama_qwen_best_special_exact_map_remapped.pt artifact.
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
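The up-front validation and the legacy-key fallback mentioned for the projection files could look roughly like this. It is a sketch with hypothetical function names, not the code in reapply_exact_map.py; only the keys named in these commit messages ('indices', 'likelihoods', student_model_id, model_A_id) are used.

```python
def validate_projection_map(obj, path):
    """Fail early with the file path instead of a KeyError deep in a loop."""
    if not isinstance(obj, dict):
        raise ValueError(f"{path}: expected a dict, got {type(obj).__name__}")
    missing = [key for key in ("indices", "likelihoods") if key not in obj]
    if missing:
        raise ValueError(f"{path}: projection map is missing keys {missing}")
    return obj

def read_student_model_id(obj):
    """Prefer the new key and fall back to the legacy model_A_id key."""
    return obj.get("student_model_id", obj.get("model_A_id"))

loaded = {"indices": [[0]], "likelihoods": [[1.0]], "model_A_id": "student/model"}
validate_projection_map(loaded, "projection.pt")
student_id = read_student_model_id(loaded)
```

Checking the container once at load time keeps the error message tied to the offending file, which is the motivation given in the commit message above.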
Summary
Add cross-tokenizer off-policy distillation to NeMo-RL.
- TokenAligner: DP alignment over canonicalized tokens; projection-matrix loader.
- NotImplementedError.
- nemo_rl/utils/x_token/ (minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix).
- docs/guides/xtoken-distillation.md.
This branch is the combined form of the cross-tokenizer feature stack (independent of the per-piece review on #2347).
Test plan
llama_qwen_best_special_exact_map_remapped.pt bitwise.