Skip to content

feat(xtoken): cross-tokenizer off-policy distillation#2507

Closed
avenkateshha wants to merge 6 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation-combined
Closed

feat(xtoken): cross-tokenizer off-policy distillation#2507
avenkateshha wants to merge 6 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation-combined

Conversation

@avenkateshha
Copy link
Copy Markdown

Summary

Add cross-tokenizer off-policy distillation to NeMo-RL.

  • TokenAligner — DP alignment over canonicalized tokens; projection-matrix loader.
  • Loss modes: P-KL (full-vocab teacher logits via IPC + microbatch-global top-k), gold-loss (exact-map common + L1 uncommon tail), xtoken-loss (relaxed exact-map collisions).
  • DTensor V2 worker support; Megatron path stubbed with NotImplementedError.
  • Projection-prep CLI utilities under nemo_rl/utils/x_token/ (minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix).
  • Docs guide at docs/guides/xtoken-distillation.md.

This branch is the combined form of the cross-tokenizer feature stack (independent of the per-piece review on #2347).

Test plan

  • CI green
  • Projection-prep reproducibility: re-running the CLI tools on the Llama-3.2-3B ↔ Qwen3-4B-Base pair reproduces the canonical llama_qwen_best_special_exact_map_remapped.pt bitwise.
  • Smoke run: 2-node Llama-3.2-1B ← Qwen3-4B, 100-step P-KL on climb data.

avenkateshha and others added 6 commits May 11, 2026 10:44
Adds a single-teacher cross-tokenizer distillation pipeline running on
DTensor V2 with FSDP2. The student and teacher use different tokenizers;
a learned projection matrix maps student-vocab probabilities into the
teacher's vocab so chunk-averaged P-KL can be computed across the
alignment between the two token streams.

New entrypoints
- examples/run_xtoken_distillation.py and
  examples/configs/xtoken_distillation.yaml exemplar.
- nemo_rl/algorithms/xtoken_distillation.py training loop.

New data path
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
  ArrowTextDataset for raw arrow shards.
- nemo_rl/data/processors.py: kd_data_processor (raw-text path).
- nemo_rl/data/cross_tokenizer_collate.py: collator that tokenizes the
  same source twice and runs TokenAligner.align inside DataLoader
  workers.

New algorithm package
- nemo_rl/algorithms/x_token/{__init__.py, tokenalign.py}: TokenAligner
  with projection-matrix loader and chunk alignment.

Loss
- nemo_rl/algorithms/loss/loss_functions.py: CrossTokenizerDistillationLossFn
  implements chunk-mean P-KL with PT-matched dynamic loss scaling.
  Sparse projection wrapped in _Fp32SparseMM (custom autograd.Function
  with amp custom_fwd/custom_bwd) to bypass the missing CUDA BF16
  addmm_sparse_cuda kernel on both forward and backward. _load_M
  filters -1 sentinels and sizes V_t to the teacher's full vocab.
  Per-step accuracy / projection-accuracy diagnostics and an
  env-gated per-microbatch loss dump (NRL_XTOKEN_LOSS_DUMP_DIR) for
  PT-vs-NRL parity comparison.

Worker / dispatcher plumbing
- nemo_rl/models/automodel/data.py: check_sequence_dim(skip_keys=…) so
  CT auxiliary tensors that don't follow [B, student_seq, …] can ride
  on the same BatchedDataDict without tripping the seq-dim guard.
- nemo_rl/models/automodel/train.py: register GlobalTopkLogitsPostProcessor
  in forward_with_post_processing_fn and add
  XTokenTeacherIPC{Loss,Export}PostProcessor subclasses.
- nemo_rl/models/policy/lm_policy.py + workers/dtensor_policy_worker_v2.py:
  plumb skip_keys through Policy.train and the worker.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
- xtoken_distillation runner now calls teacher_policy.get_global_topk_logits_ipc
  to receive teacher per-sample top-k logits via CUDA IPC handles
  instead of materializing them on the driver. Driver passes the
  handles into student_policy.train via teacher_topk_logits_ipc, and
  workers rebuild the tensors locally with rebuild_cuda_tensor_from_ipc.
- lm_policy: new dispatchers (get_global_topk_logits_ipc and
  release_ipc_buffer) plus an ipc-handle-to-batch pairing assert. Fix
  release_ipc_buffer to use ray.get(futures) (matches every other
  run_all_workers_single_data caller).
- dtensor_policy_worker_v2: worker-side get_global_topk_logits_ipc,
  IPC tensor stash, release_ipc_buffer, and matching train-side
  unpack of teacher_topk_logits_ipc.
- Loss fn unpacks the IPC handle list lazily inside _compute_p_kl.
- YAML default: topk_logits_k / vocab_topk lifted to 8192 (was 64).

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Ports the PT reference `compute_KL_loss_optimized` gold path
(tokenalign.py:3494-3829) into CrossTokenizerDistillationLossFn:

- _build_exact_token_map: vectorized port of PT 3493-3594. Partitions
  the student/teacher vocabs into common (exact-1-1 mapped) and uncommon
  sets. Strict mode keeps the lowest-s_idx winner per teacher id;
  xtoken_loss=true relaxes the exact-map threshold from `== 1.0` to
  `>= 0.6` and elects the highest-projection-weight student per teacher
  id via scatter_reduce(amax) with a min-s_idx tiebreaker. Cached
  per-device for the run since the map depends only on the immutable
  projection file + xtoken_loss flag.
- _compute_gold: log_softmax + chunk-average on full teacher vocab both
  sides, KL on the common slice (forward or reverse), L1 on sorted
  uncommon probs (topk-capped at cfg["uncommon_topk"], default 8192,
  matching PT line 3727). Combined as (kl_common + l1_uncommon) * T**2.
  Top-1 accuracy on the common slice over valid chunks.
- Shared helpers _rebuild_teacher_full_logits / _chunk_average_log_probs
  / _valid_chunk_mask reused by both _compute_p_kl and _compute_gold.
- __init__ now requires gold_loss=True when xtoken_loss=True; the older
  "<=1 True" mutex was wrong (PT's canonical CT run sets both, and the
  flag is a modifier inside the gold path).

Unifies the teacher-logits transport. Replaces the per-sample top-k
GlobalTopkLogitsPostProcessor + get_global_topk_logits[_ipc] pair with
a single FullLogitsPostProcessor + get_full_logits_ipc that ships the
entire [T_t, V_t] tensor per sample via CUDA IPC. The legacy P-KL path
stays correct by computing its microbatch-global top-k inline (max over
flat (B*T_t) -> topk over V_t -- matches PT `global_top_indices` at
tokenalign.py:3877) instead of at the worker. Closer to PT shape
(full vocab in-process) and removes worker-side top-k plumbing.

Handles the common HF case where lm_head pads out_features beyond
len(tokenizer) (Qwen3-4B: 151936 vs 151669): both _compute_p_kl and
_compute_gold slice teacher_full_logits to v_t (projection's V_t =
real tokenizer vocab) before any downstream math. Padded columns are
non-real tokens; dropping them matches the de-facto NRL contract.

Config: adds CrossTokenizerDistillationLossConfig.uncommon_topk; drops
XTokenDistillationConfig.topk_logits_k (derived locally from
loss_fn.vocab_topk). examples/configs/xtoken_distillation.yaml updated.

Runner: XTOKEN_NON_STUDENT_SEQ_KEYS swaps the top-k keys for
teacher_full_logits_ipc. Both training and validate() switch transport
call sites. Per-MB metric reduction + logging handle both metric shapes
(kl_loss/ce_loss vs kl_common/l1_uncommon). _maybe_dump_loss records
whichever keys are present so the same dump file format serves both
paths -- preserves the NRL_XTOKEN_LOSS_DUMP_DIR parity-check protocol.

Verified end-to-end with 100-step 2-node interactive smokes:
- P-KL (gold_loss=false): job 11752927 COMPLETED in 11:43,
  train:loss=3.10, ckpt at step 100.
- Gold (gold_loss=true, xtoken_loss=true): job 11754078 COMPLETED in
  12:39, train:loss=45.69, KL(common)~0.47, L1(uncommon)~0.002, top-1
  accuracy on common ~76.7%, ckpt at step 100.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Adds nemo_rl/utils/x_token/ — the four offline CLI tools that produce
the projection matrix consumed at training time by
CrossTokenizerDistillationLossFn:

  - minimal_projection_generator.py: embedding-similarity top-k between
    student and teacher vocabularies.
  - minimal_projection_via_multitoken.py: augments the projection with
    multi-token mappings (e.g. "12" -> "1", "2") in both directions.
  - reapply_exact_map.py: pins literally-identical tokens to 1-to-1
    mappings with weight 1.0.
  - sort_and_cut_projection_matrix.py: trims to runtime top_k; reads
    enable_scale_trick metadata to auto-enable --preserve_last.

All four scripts take --student-model and --teacher-model as required
arguments; the projection direction follows the CLI args exactly (no
alphabetical-order swap). Common helpers (add_multitoken_mappings,
print_projection_map_examples, print_projection_statistics,
clean_model_name_for_filename) live at module scope. Cross-script dict
keys renamed model_A_id -> student_model_id (legacy keys still read).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Adds docs/guides/xtoken-distillation.md walking through the four-step
offline projection-prep pipeline (generate -> add multi-token -> reapply
exact map -> sort and trim) followed by the distillation training run
via examples/run_xtoken_distillation.py. Covers backend constraints
(DTensor V2 only, CUDA-IPC teacher logits, arrow_text corpus), the three
loss modes (P-KL, gold, gold+xtoken), and worked examples for Qwen,
Gemma, and Phi-3/Phi-4 teacher pairs. Registered in docs/index.md under
Guides.

Also nulls out the hardcoded /lustre user-specific projection_matrix_path
in examples/configs/xtoken_distillation.yaml and replaces it with an
inline comment pointing at the new tutorial; the path is now expected
to be supplied via Hydra CLI per run, matching the data.train.arrow_files
pattern already used in the same file.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Apply the post-review changes for PR NVIDIA-NeMo#2347 on the cross-tokenizer
projection-prep utilities under nemo_rl/utils/x_token/:

minimal_projection_via_multitoken.py
- Re-add optional --output-filename stem (default None falls back to the
  auto-derived name) so recipe-driven runs can pin the filename, matching
  the contract used by the reference recipes.
- Extend the gemma vocab-size branch to also fire for qwen3.5, which
  nests vocab_size under config.text_config on both student and teacher
  sides.
- Document the .pt save-key schema near torch.save (student_model_id /
  teacher_model_id / enable_scale_trick) with a pointer to the
  legacy-key fallback in the load path.

minimal_projection_generator.py
- Add the same schema annotation near torch.save.

reapply_exact_map.py
- Validate the loaded projection map is a dict containing 'indices' and
  'likelihoods'; raise ValueError with the file path on mismatch instead
  of surfacing a confusing KeyError mid-loop.

sort_and_cut_projection_matrix.py
- Lift the argparse block out of main() into a module-level
  parse_arguments() helper, matching the shape used by the other
  x_token CLI scripts.
- Lift the verbose stats block out of sort_and_cut_projection_matrix()
  into a print_projection_statistics() helper.
- Replace the positional input_path with a named
  --initial-projection-path flag for naming consistency with the rest
  of the x_token CLI tools.

docs/guides/xtoken-distillation.md
- Update Step 4 usage example to the new --initial-projection-path flag.

Behaviour preserved: the Llama-3.2-3B / Qwen3-4B-Base recipe still
produces bitwise-equal indices and likelihoods against the canonical
llama_qwen_best_special_exact_map_remapped.pt artifact.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha requested review from a team as code owners May 16, 2026 01:40
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added Documentation Improvements or additions to documentation community-request labels May 16, 2026
@avenkateshha avenkateshha deleted the avenkateshha/xtoken-off-policy-distillation-combined branch May 16, 2026 01:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request Documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants