feat(xtoken): cross-tokenizer off-policy distillation#2508

Open
avenkateshha wants to merge 35 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation

Conversation

@avenkateshha

Summary

Add cross-tokenizer off-policy distillation to NeMo-RL.

  • TokenAligner — DP alignment over canonicalized tokens; projection-matrix loader.
  • Loss modes: P-KL (full-vocab teacher logits via IPC + microbatch-global top-k), gold-loss (exact-map common + L1 uncommon tail), xtoken-loss (relaxed exact-map collisions).
  • DTensor V2 worker support; Megatron path stubbed with NotImplementedError.
  • Projection-prep CLI utilities under nemo_rl/utils/x_token/ (minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix).
  • Docs guide at docs/guides/xtoken-distillation.md.

This PR supersedes #2347 (which was force-closed when its head branch was renamed). RayenTian's review feedback on #2347 is addressed in commit 755fb8e4:

  • minimal_projection_via_multitoken.py: restored --output-filename override; extended gemma vocab-size branch to also fire for qwen3.5; documented .pt save-key schema.
  • minimal_projection_generator.py: matching schema annotation near torch.save.
  • reapply_exact_map.py: validate loaded projection map is a dict with indices/likelihoods; raise ValueError with the file path on mismatch.
  • sort_and_cut_projection_matrix.py: factored argparse into parse_arguments(); factored verbose stats into print_projection_statistics(); replaced positional input_path with --initial-projection-path to match the other tools.
  • docs/guides/xtoken-distillation.md: synced the Step 4 example with the new CLI.
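The load-time validation described for reapply_exact_map.py could look roughly like the sketch below. Only the checked keys (`indices`, `likelihoods`) and the "raise ValueError with the file path" behaviour come from the list above; the function name `validate_projection_map` and the message wording are assumptions made for illustration.

```python
from typing import Any, Dict

REQUIRED_KEYS = ("indices", "likelihoods")


def validate_projection_map(obj: Any, path: str) -> Dict[str, Any]:
    """Return obj unchanged if it looks like a projection map, else raise.

    Hypothetical helper: the real tool may structure this check differently.
    """
    if not isinstance(obj, dict):
        raise ValueError(
            f"{path}: expected a dict projection map, got {type(obj).__name__}"
        )
    missing = [key for key in REQUIRED_KEYS if key not in obj]
    if missing:
        # Include the path so the user can tell which file is malformed.
        raise ValueError(f"{path}: projection map is missing keys {missing}")
    return obj
```

Putting the path in the message matters mostly when several projection files are chained through the prep tools, since the failing file is otherwise hard to identify.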

Test plan

  • CI green
  • Projection-prep reproducibility: re-running the CLI tools on the Llama-3.2-3B ↔ Qwen3-4B-Base pair reproduces the canonical llama_qwen_best_special_exact_map_remapped.pt bitwise (torch.equal on indices and likelihoods).
  • Smoke run: 2-node Llama-3.2-1B ← Qwen3-4B, 100-step P-KL on climb data.

@avenkateshha avenkateshha requested review from a team as code owners May 16, 2026 02:05
@copy-pr-bot

copy-pr-bot Bot commented May 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added Documentation Improvements or additions to documentation community-request labels May 16, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label May 18, 2026
Comment thread nemo_rl/algorithms/loss/loss_functions.py
@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 20, 2026
Comment thread nemo_rl/algorithms/loss/loss_functions.py
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
self.ctx_length_teacher,
self.make_seq_div_by_teacher,
)
alignment = self.aligner.align(student_input_ids, teacher_input_ids)
Contributor

@RayenTian RayenTian May 20, 2026

Can we make token alignment an explicit algorithm step instead of doing it inside CrossTokenizerCollator? The collator currently tokenizes and immediately calls aligner.align(...), so alignment is hidden as a dataloader side effect. That makes it harder to reuse the x-token path from other algorithms such as MOPD, where the batch may be produced by generation rather than by this raw-text collator.

One way to split it would be:

  • collator / batch builder: produce student ids and teacher ids
  • algorithm loop: call aligner.align(student_input_ids, teacher_input_ids) explicitly, under a separate x_token_alignment timer.

The align step could then be placed here.

CC: @yuki-97
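A minimal sketch of the proposed split, assuming plain callables in place of the real tokenizers and aligner. `build_batch`, `train_step`, and `timed` are hypothetical names, not NeMo-RL APIs; only the `x_token_alignment` timer name and the `align(student_input_ids, teacher_input_ids)` call shape come from the comment above.

```python
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List


@contextmanager
def timed(timings: Dict[str, float], name: str) -> Iterator[None]:
    """Accumulate wall-clock seconds spent inside the block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = timings.get(name, 0.0) + time.perf_counter() - start


def build_batch(
    texts: List[str],
    student_tokenize: Callable[[str], List[int]],
    teacher_tokenize: Callable[[str], List[int]],
) -> Dict[str, Any]:
    """Collator role: tokenize only, with no alignment side effect."""
    return {
        "student_input_ids": [student_tokenize(t) for t in texts],
        "teacher_input_ids": [teacher_tokenize(t) for t in texts],
    }


def train_step(
    batch: Dict[str, Any],
    align: Callable[[Any, Any], Any],
    timings: Dict[str, float],
) -> Dict[str, Any]:
    """Algorithm-loop role: alignment is an explicit, separately timed step."""
    with timed(timings, "x_token_alignment"):
        batch["alignment"] = align(
            batch["student_input_ids"], batch["teacher_input_ids"]
        )
    return batch
```

With this shape, a batch produced by generation rather than by the raw-text collator could be passed to `train_step` unchanged, as long as it carries both id fields.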

# Rebuild full teacher logits from the IPC handles. Same transport
# as the gold path consumes; here we additionally compute a
# microbatch-global top-k inline to match PT.
teacher_full_logits = self._rebuild_teacher_full_logits(data) # [B, T_t, V_t_model]
Contributor

@RayenTian RayenTian May 20, 2026

Would it make sense to move the cross-tokenizer loss input preparation into prepare_loss_input() as well? Right now the regular distillation path prepares teacher/student log-prob inputs in utils.py, while CrossTokenizerDistillationLossFn keeps input-side work such as rebuilding teacher_full_logits from IPC handles inside the loss implementation. Since that rebuild is more about materializing the loss inputs than the loss formula itself, a dedicated input-prep might make the boundary more consistent and keep the loss function focused on the actual reductions.

CC: @yuki-97

Author

@avenkateshha avenkateshha May 24, 2026

@RayenTian - Do you suggest extending the LOGIT branch or creating a new enum LossInputType.CROSS_TOKENIZER?

Contributor

My preference is to create a new enum, LossInputType.CROSS_TOKENIZER. What do you think, @yuki-97?

Contributor

A new enum would be better. What do you think of the name LossInputType.DISTILLATION_CROSS_TOKENIZER?
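For reference, the dispatch being discussed could be sketched as follows. This redefines a toy `LossInputType` with only the two members named in this thread, and `prepare_loss_input` here has a simplified, hypothetical signature. The dict keys and the `rebuild_teacher_logits` callable are stand-ins for the IPC rebuild, not the actual code in utils.py.

```python
from enum import Enum, auto
from typing import Any, Callable, Dict


class LossInputType(Enum):
    # Toy enum: the real one presumably has more members.
    LOGIT = auto()
    DISTILLATION_CROSS_TOKENIZER = auto()  # name proposed in this thread


def _prepare_logit(data: Dict[str, Any]) -> Dict[str, Any]:
    return {"student_logits": data["student_logits"]}


def _prepare_cross_tokenizer(data: Dict[str, Any]) -> Dict[str, Any]:
    # Input materialization (e.g. rebuilding teacher logits) happens here,
    # so the loss function only sees ready-to-reduce tensors.
    return {
        "student_logits": data["student_logits"],
        "teacher_full_logits": data["rebuild_teacher_logits"](),
    }


_PREPARERS: Dict[LossInputType, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
    LossInputType.LOGIT: _prepare_logit,
    LossInputType.DISTILLATION_CROSS_TOKENIZER: _prepare_cross_tokenizer,
}


def prepare_loss_input(
    input_type: LossInputType, data: Dict[str, Any]
) -> Dict[str, Any]:
    """Dispatch on the enum instead of branching inside the loss function."""
    return _PREPARERS[input_type](data)
```

A separate member keeps the existing LOGIT branch untouched, which seems to be the main argument for it over extending that branch.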

@RayenTian
Contributor

Hi, @avenkateshha! Thanks for putting this together, and thanks @yuki-97 for the review help. I have only partly reviewed the PR so far, but I left a few comments. I’ll continue reviewing the remaining files separately.

avenkateshha added a commit to avenkateshha/RL that referenced this pull request May 20, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

- Rename _load_M -> _get_sparse_projection_matrix and
  _load_dense_projection -> _get_topk_projection (later removed in
  favor of module-level cache helpers below).
- Drop unused alignment_student_spans / alignment_teacher_spans
  from the cross-tokenizer batch payload.
- Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect.
- Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a
  new shared module nemo_rl/algorithms/x_token/utils.py.
- Extract projection-file parsing into utils.parse_projection_file;
  tokenalign.py and loss_functions.py both go through it.
- Move per-instance projection-matrix caches to process-local caches
  in utils.get_sparse_projection_matrix / get_topk_projection. The
  driver no longer holds large CUDA tensors; each Ray worker fills
  its own cache on first loss call.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
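The process-local cache pattern from the last bullet can be sketched as below. `get_cached_projection`, the `(path, device)` cache key, and the injected `loader` are assumptions for illustration; the commit only states that the caches live at module level in utils and are filled on the first loss call in each Ray worker.

```python
from typing import Any, Callable, Dict, Tuple

# Module-level, so each worker process owns an independent cache and the
# driver never has to hold (or pickle) the loaded matrix.
_PROJECTION_CACHE: Dict[Tuple[str, str], Any] = {}


def get_cached_projection(
    path: str, device: str, loader: Callable[[str, str], Any]
) -> Any:
    """Load the projection for (path, device) once per process, then reuse it."""
    key = (path, device)
    if key not in _PROJECTION_CACHE:
        _PROJECTION_CACHE[key] = loader(path, device)
    return _PROJECTION_CACHE[key]
```

Because the cache is keyed per process, the first loss call on every worker pays the load cost once, and later calls return the same object.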
avenkateshha and others added 6 commits May 28, 2026 00:49
Remove the optional embedding-similarity seed (formerly Step 1,
minimal_projection_generator.py) from the cross-tokenizer
off-policy distillation guide, leaving the three documented prep
steps (2 multi-token mappings, 3 reapply exact map, 4 sort+trim)
and the training step (5). Add a "Which prep steps are essential?"
subsection noting Steps 2 and 4 are required, Step 3 is optional,
and best results on this branch came from running 2 -> 3 -> 4.

The minimal_projection_generator.py tool and the
build_projection_matrix.sh wrapper are untouched; this change is
docs-only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
After removing the embedding-similarity seed in the previous
commit, renumber the remaining prep steps to a clean 1-4 sequence
(1 multi-token mappings, 2 optional reapply exact map, 3 sort+trim,
4 train) across the arch diagram, section headers, the "Which
prep steps are essential?" subsection (best results path is now
Steps 1 -> 2 -> 3), the Quickstart paragraph, and all in-text
cross-references.

Also audit the rest of the guide against the current code and
drop stale lines:

- Drop the "Megatron path is intentionally stubbed with
  NotImplementedError" sentence in Backend and scope (no such
  stub on the Megatron worker; the actual enforcement is an
  assertion at xtoken setup) and the "No sequence packing or
  dynamic batching for the teacher forward in v0" bullet.
- Drop the entire "Other (student, teacher) pairs" section.
  The Llama -> Gemma claim could not be substantiated against
  the codebase (no CT exemplar for a Gemma teacher), and the
  Phi-3 / Phi-4 NRL_TRUST_REMOTE_CODE + NRL_SKIP_PHI_ROPE_FIX
  paragraph it carried referenced env vars that are not consumed
  by any code on this branch.
- Drop the "Related" bullet pointing to the Quantization-Aware
  RL guide; the QA-Distillation workflow invokes
  examples/run_distillation.py, not the new
  run_xtoken_off_policy_distillation.py entrypoint added on this
  branch, so the claim that they share a training entrypoint was
  incorrect.
- Drop two "PT" / "PT-faithful" references from the loss-mode
  table and the uncommon_topk default; NeMo-RL users have no
  context for the upstream PyTorch reference run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…int claims

- "(or vice versa, depending on the loss mode)" implied a teacher
  -> student projection exists. It does not: every loss mode
  projects student logits into teacher vocab via M (see
  Fp32SparseMM and the P-KL / gold paths in loss_functions.py).
  reverse_kl flips the KL direction, not the projection direction.
- "The corpus must be served via the arrow_text dataset" overstated
  the constraint. The xtoken entrypoint calls the generic
  setup_response_data, which dispatches through DATASET_REGISTRY;
  any registered dataset can be used. The "no chat template, loss
  on every token" property comes from kd_data_processor, not from
  arrow_text. The exemplar config still demonstrates the typical
  arrow_text + kd_data_processor pairing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add the trainer module, the cross-tokenizer collator, and the KD
data processor to the Related section of the cross-tokenizer
distillation guide. These are the next files a reader is likely
to want to open after the loss function and token aligner, and
they cover the rest of the runtime path: setup + train loop
(xtoken_off_policy_distillation.py), per-microbatch teacher
tokenization and alignment field construction
(CrossTokenizerCollator), and the no-chat-template /
loss-on-every-token data processor the exemplar config wires up
(kd_data_processor).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The intro said cross-tokenizer distillation works by "routing
teacher logits through a precomputed projection matrix" — the
projection direction is the other way around. Every loss mode
projects student logits through M into the teacher vocab space
(see Fp32SparseMM and the docstring at
loss_functions.py:1303 "Project full-vocab student probs through
M to teacher vocab"); the teacher logits stay in teacher vocab.

Rewrite the sentence to make the direction explicit: student
logits are projected into the teacher's vocab space via M so the
two distributions can be compared.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
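To make the direction concrete, here is a small pure-Python sketch: student probabilities are pushed through a sparse M into the teacher vocabulary, and only the KL direction changes between the two calls. The toy matrix, the 3-token and 2-token vocabularies, and the convention that the non-reversed KL means KL(teacher || projected student) are all assumptions for illustration; the real code works on tensors via Fp32SparseMM.

```python
import math
from typing import Dict, List, Tuple

# Sparse M: student token id -> list of (teacher token id, weight).
SparseM = Dict[int, List[Tuple[int, float]]]


def project_student_to_teacher(
    student_probs: List[float], m: SparseM, teacher_vocab_size: int
) -> List[float]:
    """Compute q[j] = sum_i p_student[i] * M[i, j] in teacher vocab space."""
    projected = [0.0] * teacher_vocab_size
    for student_id, prob in enumerate(student_probs):
        for teacher_id, weight in m.get(student_id, []):
            projected[teacher_id] += prob * weight
    return projected


def kl(p: List[float], q: List[float]) -> float:
    """KL(p || q) over a shared vocabulary, skipping zero-probability terms."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


toy_m: SparseM = {0: [(0, 1.0)], 1: [(0, 0.5), (1, 0.5)], 2: [(1, 1.0)]}
student_probs = [0.2, 0.4, 0.4]
teacher_probs = [0.5, 0.5]

projected = project_student_to_teacher(student_probs, toy_m, teacher_vocab_size=2)
forward = kl(teacher_probs, projected)  # teacher stays in its own vocab
reverse = kl(projected, teacher_probs)  # same projection, flipped KL
```

In both calls the teacher distribution is never projected; flipping the KL swaps only the argument order.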
Reorder the Related section so the file pointers follow the
chronology of a training step: config -> trainer setup ->
per-sample KD processing -> per-batch cross-tokenizer collation
(which internally invokes the token aligner) -> per-microbatch
loss. Add a one-line lead-in noting the ordering.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label May 29, 2026
avenkateshha and others added 7 commits May 29, 2026 17:29
Mirrors the TypedDict -> BaseModel migration applied to the other
trainers in NVIDIA-NeMo#2325 (MasterConfig & ClippedPGLossConfig). Strict
pydantic validation now runs in the runner before setup(), so the
exemplar YAML had to be cleaned of NotRequired fields that were
set to null (chat_template, generation) — those keys are typed
NotRequired[T] without `| None`, so the validator rejects them.
Match the convention used by tests/unit/reference_configs/
distillation_math.yaml: omit the key entirely when there's no
value. The two runtime-injected vocab-size fields on
CrossTokenizerDistillationLossConfig (already documented as "not a
user knob in YAML") are flipped to NotRequired[int] so they pass
schema validation before setup() fills them from
len(student_tokenizer) / len(teacher_tokenizer).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…logits

Three coupled changes to the cross-tokenizer teacher-logits transport:

1. Producer (dtensor_policy_worker_v2): replace per-step stash of
   per-sample handles with a single persistent ``[B_r, T_t, V_t]``
   IPC buffer captured once on first call. Each step ``.copy_()``-s
   fresh logits into the same backing memory behind the same stable
   handle. ``release_ipc_buffer`` becomes a no-op kept for the driver
   contract.

2. Consumer (loss_functions): rebuild the single rank-level handle
   once and slice ``[mb_start:mb_end]`` for the microbatch — drops
   the per-sample ``rebuild_cuda_tensor_from_ipc`` loop + ``torch.stack``
   and avoids an extra allocation. Asserts contiguous monotonic
   ``sample_idx_within_rank`` to fail loudly if any future change
   reorders microbatch samples.

3. Producer (automodel.FullLogitsPostProcessor): drop the
   ``.to(torch.float32)`` cast — teacher is frozen, downstream
   log_softmax/KL kernels upcast internally where they need fp32, and
   shipping native compute dtype (bf16 under autocast) halves the
   persistent IPC buffer footprint.

Together these eliminate the per-step alloc/free cycle the
consumer-must-copy-out rule was working around, halve the buffer's
memory, and remove the consumer-side stack entirely.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
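The producer and consumer sides of this change can be illustrated without CUDA. The sketch below uses nested Python lists as a stand-in for the persistent `[B_r, T_t, V_t]` IPC tensor, so it shows only the control flow: allocate once, overwrite in place, slice per microbatch, and fail loudly on non-contiguous sample indices. The class and function names are hypothetical.

```python
from typing import List, Optional, Sequence

Rows = List[List[float]]


class PersistentLogitsBuffer:
    """Allocate on first publish, then overwrite in place on every later step."""

    def __init__(self) -> None:
        self._buf: Optional[Rows] = None

    def publish(self, logits: Sequence[Sequence[float]]) -> Rows:
        if self._buf is None:
            self._buf = [list(row) for row in logits]  # one-time allocation
        else:
            if len(logits) != len(self._buf):
                raise ValueError("batch size changed; buffer shape is fixed")
            for dst, src in zip(self._buf, logits):
                dst[:] = src  # in-place copy, analogous to tensor.copy_()
        return self._buf  # same object every step, like a stable handle


def slice_microbatch(buf: Rows, sample_idx_within_rank: Sequence[int]) -> Rows:
    """Return buf[mb_start:mb_end], requiring contiguous increasing indices."""
    if not sample_idx_within_rank:
        return []
    start = sample_idx_within_rank[0]
    end = start + len(sample_idx_within_rank)
    if list(sample_idx_within_rank) != list(range(start, end)):
        raise AssertionError(
            f"non-contiguous microbatch indices: {list(sample_idx_within_rank)}"
        )
    return buf[start:end]
```

The contiguity check is what allows a plain slice to replace the per-sample rebuild and stack, since a slice is only correct when the microbatch samples are adjacent in the rank-level buffer.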
Lift the script body of tools/x_token/reapply_exact_map.py into a
reapply_exact_map(args) function so unit tests can drive it without
runpy / sys.modules import-time gymnastics. The __main__ block becomes
a thin shim over the helper; CLI behavior is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
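The refactor pattern, with script logic in an importable function and `__main__` as a thin shim, looks roughly like this. The toy tool and its behaviour are invented for illustration; the two flag names are borrowed from the PR summary only to keep the example familiar, and `run_tool` is a hypothetical stand-in for `reapply_exact_map(args)`.

```python
import argparse
from typing import List, Optional


def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Toy projection tool (illustrative)")
    parser.add_argument("--initial-projection-path", required=True)
    parser.add_argument("--output-filename", default=None)
    return parser.parse_args(argv)


def run_tool(args: argparse.Namespace) -> str:
    """All script logic lives here so tests can call it with a Namespace."""
    # Invented behaviour: pick the output name, defaulting next to the input.
    return args.output_filename or args.initial_projection_path + ".out"


def main(argv: Optional[List[str]] = None) -> None:
    print(run_tool(parse_arguments(argv)))


# In a real script the entrypoint would be the usual guard:
# if __name__ == "__main__":
#     main()
```

A unit test can then build `argparse.Namespace(...)` directly, or call `parse_arguments([...])` with an explicit list, without touching `sys.argv` or re-importing the module.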
Policy.train forwards skip_keys through the worker dispatch payload
regardless of which worker implementation is selected. The v2 worker
already accepts the kwarg, but v1 (DTensorPolicyWorkerImpl) did not,
so any caller path that landed on v1 raised
"TypeError: DTensorPolicyWorkerImpl.train() got an unexpected keyword
argument 'skip_keys'" - surfaced by
test_vllm_generation_with_hf_training_colocated.

Add skip_keys: Optional[Iterable[str]] = None to v1's train() for
parity with v2 and honor the filter in the inline sequence-dim check.
v1 does not run cross-tokenizer, so skip_keys is typically None on
this path; the kwarg is accepted purely so the unified Policy.train
call site does not break.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
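A sketch of what honoring `skip_keys` in a sequence-dimension check might look like. The function name, the batch layout (key to list of rows), and the error message are assumptions; only the `skip_keys: Optional[Iterable[str]] = None` parameter shape comes from the commit message.

```python
from typing import Dict, Iterable, Optional, Sequence


def check_sequence_dim(
    batch: Dict[str, Sequence[Sequence[int]]],
    seq_len: int,
    skip_keys: Optional[Iterable[str]] = None,
) -> None:
    """Raise if any non-skipped field has a row whose length differs from seq_len."""
    skipped = set(skip_keys or ())
    for key, rows in batch.items():
        if key in skipped:
            continue  # e.g. teacher-side fields with their own sequence length
        for row in rows:
            if len(row) != seq_len:
                raise ValueError(
                    f"{key}: expected sequence length {seq_len}, got {len(row)}"
                )
```

Defaulting to `None` keeps every existing caller working, which matches the stated goal of accepting the kwarg purely for call-site parity.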
Adds hermetic CPU unit tests for the xtoken module the PR-2508 reviewer
flagged as needing coverage. Style mirrors tests/unit/algorithms/
test_distillation.py: top-level test_* functions plus a single
mock_xtoken_components fixture; no Ray init, no CUDA, no HF downloads.

- tests/unit/algorithms/x_token/test_xtoken_off_policy_distillation.py
  setup() dtensor-v2 asserts, vocab-size injection, val_dataloader
  gating, exit-on-{max_steps, max_epochs, timeout}, validate() loss-
  path branches (P-KL vs gold), IPC buffer release on train failure,
  examples/run_xtoken_off_policy_distillation.py scope assert.

- tests/unit/algorithms/x_token/test_loss_utils.py
  alignment_from_flat_batch schema-drift guard, chunk_average_log_probs
  / valid_chunk_mask numerical math, parse_projection_file (dense +
  sparse + error cases), get_{sparse,topk}_projection cache semantics,
  build_exact_token_map cache keying.

- tests/unit/tools/x_token/test_projection_tools.py
  per-tool parse_arguments contracts, save_data/load_data round-trip,
  generate_projection_map_chunk top-k math, add_multitoken_mappings
  with mocked tokenizers, sort_and_cut_projection_matrix function +
  main() auto-preserve-from-metadata, _shared.sinkhorn_one_dim /
  clean_model_name_for_filename, reapply_exact_map helper roundtrip,
  build_projection_matrix.sh --help dry-run.

- tests/unit/data/test_cross_tokenizer_collate.py
  CrossTokenizerCollator output keys / shapes / truncation /
  seq-divisibility / pad-token fallback / custom text-key.

- tests/unit/data/test_data_processor.py
  TestKdDataProcessor: output contract, optional task_name forwarding,
  drift-detector against input_ids / token_mask / loss_mask emission,
  tokenizer-not-called guard, long-text-not-truncated-in-processor.

- tests/unit/data/datasets/test_arrow_text_dataset.py
  TestPackGenerator: characters_per_sample semantics, pack-may-exceed-
  threshold, no row truncation, trailing partial pack, zero-row input,
  all-empty-text still emits, schema_version cache-key only.

- tests/unit/algorithms/x_token/conftest.py
  make_loss_cfg fixture fixup: drop removed project_teacher_to_student,
  add now-required student_vocab_size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The chunk-KL reductions in `CrossTokenizerDistillationLossFn` divided
each rank's KL sum by its own local valid-chunk count, producing
`mean(rank_local_means)` instead of
`sum(global_valid_chunk_kl) / sum(global_valid_chunks)` once gradients
are summed across DP ranks. With heterogeneous cross-tokenizer
alignment, per-rank chunk counts differ and the effective objective
drifts away from the intended one.

Add `_dp_all_reduce_sum` on the loss class and use it to compute a
DP-global valid-chunk count at three sites:

- `_compute_p_kl` chunk-KL denominator.
- `_compute_gold` KL-on-common denominator.
- `_compute_gold` L1-on-uncommon `.mean()` (same per-rank-mean bug).

The all_reduce uses the default process group, which equals the DP
group under CT's `cp_size=tp_size=1` setup assert.

To keep the collective consistent across ranks, the existing
local-only `chunk_mask.any()` / `valid_chunk.any()` early returns are
moved to fire after the all_reduce and gated on the global count;
otherwise a rank with no local valid chunks would skip the collective
and deadlock against ranks that called it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
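A small numeric illustration of the bug and the fix, using plain Python in place of `torch.distributed`. `all_reduce_sum` is a single-process stand-in for the collective, and the two-rank numbers are made up; the "old" function assumes per-rank results end up averaged across ranks, which is how the commit message characterizes the effective objective.

```python
from typing import List


def all_reduce_sum(values: List[float]) -> float:
    """Single-process stand-in for a SUM all-reduce over DP ranks."""
    return sum(values)


def rank_local_objective(kl_sums: List[float], counts: List[int]) -> float:
    """Old behaviour: each rank divides by its own count, giving mean(rank means)."""
    local_means = [s / c if c > 0 else 0.0 for s, c in zip(kl_sums, counts)]
    return sum(local_means) / len(local_means)


def dp_global_objective(kl_sums: List[float], counts: List[int]) -> float:
    """Fixed behaviour: every rank divides by the DP-global valid-chunk count."""
    global_count = all_reduce_sum([float(c) for c in counts])
    # The early return is gated on the global count, after the collective,
    # so a rank with no local valid chunks still takes part in it.
    if global_count == 0:
        return 0.0
    return sum(s / global_count for s in kl_sums)


# Rank 0: KL sum 6.0 over 2 valid chunks; rank 1: KL sum 6.0 over 6 valid chunks.
old = rank_local_objective([6.0, 6.0], [2, 6])
new = dp_global_objective([6.0, 6.0], [2, 6])
```

Here `old` is (3.0 + 1.0) / 2 = 2.0 while `new` is 12.0 / 8 = 1.5, so the rank with fewer valid chunks is over-weighted under the old reduction.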
Move `loss` and `grad_norm` into the initial metrics dict literal so
the existing per-step reduction loop collapses them via np.sum, matching
the upstream SFT / distillation pattern. The previous code assigned
both keys after the loop via float(train_results[k].numpy()), which
relied on numpy<2's implicit length-1 array-to-scalar coercion. Under
the v0.6 container's newer numpy this raises:
    TypeError: only 0-dimensional arrays can be converted to Python scalars
crashing the driver before the first logged step.

Logged values are unchanged: train_results["loss"] is shape (1,) for
off-policy distillation (one global batch per train() call), so
np.sum(arr).item() returns the same scalar the old code unwrapped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 31, 2026
@RayenTian
Contributor

/ok to test 9e49cfe

ArrowTextDataset now resolves `data_files` via `load_dataset_from_path`
(infers .arrow/.parquet/.json/.txt from the extension, or loads a HF
dataset by name) instead of hardcoding the "arrow" builder, and accepts
`subset`/`split`. The exemplar recipe defaults `data_files` to the
ungated, CC-BY-4.0 nvidia/Nemotron-Pretraining-Specialized-v1.1 parquet
shards (over hf://) so it runs with no auth and no extra setup; override
`data_files` with a path/glob to train on your own .arrow corpus.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
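The extension-based inference could be sketched as follows. This is not the `load_dataset_from_path` implementation; `infer_builder` and the mapping table are hypothetical, and returning `None` is used here to mean "treat the string as a Hugging Face dataset name".

```python
from pathlib import PurePosixPath
from typing import Optional

# Extensions listed in the commit message, mapped to builder names.
_EXT_TO_BUILDER = {
    ".arrow": "arrow",
    ".parquet": "parquet",
    ".json": "json",
    ".txt": "text",
}


def infer_builder(data_files: str) -> Optional[str]:
    """Guess a dataset builder from the file extension, or None if unknown."""
    suffix = PurePosixPath(data_files).suffix.lower()
    return _EXT_TO_BUILDER.get(suffix)
```

Globs such as `shards/train-*.parquet` still resolve, because only the final suffix is inspected. A dataset name whose last component contains a dot has no entry in the table and therefore also returns `None`.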
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label Jun 1, 2026

Labels

  • CI:Lfast: Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version)
  • community-request
  • Documentation: Improvements or additions to documentation
  • waiting-on-customer: Waiting on the original author to respond


4 participants