feat(xtoken): cross-tokenizer off-policy distillation#2508
feat(xtoken): cross-tokenizer off-policy distillation#2508avenkateshha wants to merge 35 commits into
Conversation
| self.ctx_length_teacher, | ||
| self.make_seq_div_by_teacher, | ||
| ) | ||
| alignment = self.aligner.align(student_input_ids, teacher_input_ids) |
There was a problem hiding this comment.
Can we make token alignment an explicit algorithm step instead of doing it inside CrossTokenizerCollator? The collator currently tokenizes and immediately calls aligner.align(...), so alignment is hidden as a dataloader side effect. That makes it harder to reuse the x-token path from other algorithms such as MOPD, where the batch may be produced by generation rather than by this raw-text collator.
A way split would be:
- collator / batch builder: produce student ids and teacher ids
- algorithm loop: call
aligner.align(student_input_ids, teacher_input_ids)explicitly, under a separatex_token_alignmenttimer.
And the align step can be put here.
CC: @yuki-97
| # Rebuild full teacher logits from the IPC handles. Same transport | ||
| # as the gold path consumes; here we additionally compute a | ||
| # microbatch-global top-k inline to match PT. | ||
| teacher_full_logits = self._rebuild_teacher_full_logits(data) # [B, T_t, V_t_model] |
There was a problem hiding this comment.
Would it make sense to move the cross-tokenizer loss input preparation into prepare_loss_input() as well? Right now the regular distillation path prepares teacher/student log-prob inputs in utils.py, while CrossTokenizerDistillationLossFn keeps input-side work such as rebuilding teacher_full_logits from IPC handles inside the loss implementation. Since that rebuild is more about materializing the loss inputs than the loss formula itself, a dedicated input-prep might make the boundary more consistent and keep the loss function focused on the actual reductions.
CC: @yuki-97
There was a problem hiding this comment.
@RayenTian - Do you suggest extending the LOGIT branch or creating a new enum LossInputType.CROSS_TOKENIZER?
There was a problem hiding this comment.
My preference is creating a new enum LossInputType.CROSS_TOKENIZER. What do you think @yuki-97 ?
There was a problem hiding this comment.
a new enum should be better, and wdyt this name LossInputType.DISTILLATION_CROSS_TOKENIZER?
|
Hi, @avenkateshha! Thanks for putting this together, and thanks @yuki-97 for the review help. I have only partly reviewed the PR so far, but I left a few comments. I’ll continue reviewing the remaining files separately. |
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Remove the optional embedding-similarity seed (formerly Step 1, minimal_projection_generator.py) from the cross-tokenizer off-policy distillation guide, leaving the three documented prep steps (2 multi-token mappings, 3 reapply exact map, 4 sort+trim) and the training step (5). Add a "Which prep steps are essential?" subsection noting Steps 2 and 4 are required, Step 3 is optional, and best results on this branch came from running 2 -> 3 -> 4. The minimal_projection_generator.py tool and the build_projection_matrix.sh wrapper are untouched; this change is docs-only. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
After removing the embedding-similarity seed in the previous commit, renumber the remaining prep steps to a clean 1-4 sequence (1 multi-token mappings, 2 optional reapply exact map, 3 sort+trim, 4 train) across the arch diagram, section headers, the "Which prep steps are essential?" subsection (best results path is now Steps 1 -> 2 -> 3), the Quickstart paragraph, and all in-text cross-references. Also audit the rest of the guide against the current code and drop stale lines: - Drop the "Megatron path is intentionally stubbed with NotImplementedError" sentence in Backend and scope (no such stub on the Megatron worker; the actual enforcement is an assertion at xtoken setup) and the "No sequence packing or dynamic batching for the teacher forward in v0" bullet. - Drop the entire "Other (student, teacher) pairs" section. The Llama -> Gemma claim could not be substantiated against the codebase (no CT exemplar for a Gemma teacher), and the Phi-3 / Phi-4 NRL_TRUST_REMOTE_CODE + NRL_SKIP_PHI_ROPE_FIX paragraph it carried referenced env vars that are not consumed by any code on this branch. - Drop the "Related" bullet pointing to the Quantization-Aware RL guide; the QA-Distillation workflow invokes examples/run_distillation.py, not the new run_xtoken_off_policy_distillation.py entrypoint added on this branch, so the claim that they share a training entrypoint was incorrect. - Drop two "PT" / "PT-faithful" references from the loss-mode table and the uncommon_topk default; NeMo-RL users have no context for the upstream PyTorch reference run. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…int claims - "(or vice versa, depending on the loss mode)" implied a teacher -> student projection exists. It does not: every loss mode projects student logits into teacher vocab via M (see Fp32SparseMM and the P-KL / gold paths in loss_functions.py). reverse_kl flips the KL direction, not the projection direction. - "The corpus must be served via the arrow_text dataset" overstated the constraint. The xtoken entrypoint calls the generic setup_response_data, which dispatches through DATASET_REGISTRY; any registered dataset can be used. The "no chat template, loss on every token" property comes from kd_data_processor, not from arrow_text. The exemplar config still demonstrates the typical arrow_text + kd_data_processor pairing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add the trainer module, the cross-tokenizer collator, and the KD data processor to the Related section of the cross-tokenizer distillation guide. These are the next files a reader is likely to want to open after the loss function and token aligner, and they cover the rest of the runtime path: setup + train loop (xtoken_off_policy_distillation.py), per-microbatch teacher tokenization and alignment field construction (CrossTokenizerCollator), and the no-chat-template / loss-on-every-token data processor the exemplar config wires up (kd_data_processor). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The intro said cross-tokenizer distillation works by "routing teacher logits through a precomputed projection matrix" — the projection direction is the other way around. Every loss mode projects student logits through M into the teacher vocab space (see Fp32SparseMM and the docstring at loss_functions.py:1303 "Project full-vocab student probs through M to teacher vocab"); the teacher logits stay in teacher vocab. Rewrite the sentence to make the direction explicit: student logits are projected into the teacher's vocab space via M so the two distributions can be compared. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Reorder the Related section so the file pointers follow the chronology of a training step: config -> trainer setup -> per-sample KD processing -> per-batch cross-tokenizer collation (which internally invokes the token aligner) -> per-microbatch loss. Add a one-line lead-in noting the ordering. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Mirrors the TypedDict -> BaseModel migration applied to the other trainers in NVIDIA-NeMo#2325 (MasterConfig & ClippedPGLossConfig). Strict pydantic validation now runs in the runner before setup(), so the exemplar YAML had to be cleaned of NotRequired fields that were set to null (chat_template, generation) — those keys are typed NotRequired[T] without `| None`, so the validator rejects them. Match the convention used by tests/unit/reference_configs/ distillation_math.yaml: omit the key entirely when there's no value. The two runtime-injected vocab-size fields on CrossTokenizerDistillationLossConfig (already documented as "not a user knob in YAML") are flipped to NotRequired[int] so they pass schema validation before setup() fills them from len(student_tokenizer) / len(teacher_tokenizer). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…logits Three coupled changes to the cross-tokenizer teacher-logits transport: 1. Producer (dtensor_policy_worker_v2): replace per-step stash of per-sample handles with a single persistent ``[B_r, T_t, V_t]`` IPC buffer captured once on first call. Each step ``.copy_()``-s fresh logits into the same backing memory behind the same stable handle. ``release_ipc_buffer`` becomes a no-op kept for the driver contract. 2. Consumer (loss_functions): rebuild the single rank-level handle once and slice ``[mb_start:mb_end]`` for the microbatch — drops the per-sample ``rebuild_cuda_tensor_from_ipc`` loop + ``torch.stack`` and avoids an extra allocation. Asserts contiguous monotonic ``sample_idx_within_rank`` to fail loudly if any future change reorders microbatch samples. 3. Producer (automodel.FullLogitsPostProcessor): drop the ``.to(torch.float32)`` cast — teacher is frozen, downstream log_softmax/KL kernels upcast internally where they need fp32, and shipping native compute dtype (bf16 under autocast) halves the persistent IPC buffer footprint. Together these eliminate the per-step alloc/free cycle the consumer-must-copy-out rule was working around, halve the buffer's memory, and remove the consumer-side stack entirely. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Lift the script body of tools/x_token/reapply_exact_map.py into a reapply_exact_map(args) function so unit tests can drive it without runpy / sys.modules import-time gymnastics. The __main__ block becomes a thin shim over the helper; CLI behavior is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Policy.train forwards skip_keys through the worker dispatch payload regardless of which worker implementation is selected. The v2 worker already accepts the kwarg, but v1 (DTensorPolicyWorkerImpl) did not, so any caller path that landed on v1 raised "TypeError: DTensorPolicyWorkerImpl.train() got an unexpected keyword argument 'skip_keys'" - surfaced by test_vllm_generation_with_hf_training_colocated. Add skip_keys: Optional[Iterable[str]] = None to v1's train() for parity with v2 and honor the filter in the inline sequence-dim check. v1 does not run cross-tokenizer, so skip_keys is typically None on this path; the kwarg is accepted purely so the unified Policy.train call site does not break. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Adds hermetic CPU unit tests for the xtoken module the PR-2508 reviewer
flagged as needing coverage. Style mirrors tests/unit/algorithms/
test_distillation.py: top-level test_* functions plus a single
mock_xtoken_components fixture; no Ray init, no CUDA, no HF downloads.
- tests/unit/algorithms/x_token/test_xtoken_off_policy_distillation.py
setup() dtensor-v2 asserts, vocab-size injection, val_dataloader
gating, exit-on-{max_steps, max_epochs, timeout}, validate() loss-
path branches (P-KL vs gold), IPC buffer release on train failure,
examples/run_xtoken_off_policy_distillation.py scope assert.
- tests/unit/algorithms/x_token/test_loss_utils.py
alignment_from_flat_batch schema-drift guard, chunk_average_log_probs
/ valid_chunk_mask numerical math, parse_projection_file (dense +
sparse + error cases), get_{sparse,topk}_projection cache semantics,
build_exact_token_map cache keying.
- tests/unit/tools/x_token/test_projection_tools.py
per-tool parse_arguments contracts, save_data/load_data round-trip,
generate_projection_map_chunk top-k math, add_multitoken_mappings
with mocked tokenizers, sort_and_cut_projection_matrix function +
main() auto-preserve-from-metadata, _shared.sinkhorn_one_dim /
clean_model_name_for_filename, reapply_exact_map helper roundtrip,
build_projection_matrix.sh --help dry-run.
- tests/unit/data/test_cross_tokenizer_collate.py
CrossTokenizerCollator output keys / shapes / truncation /
seq-divisibility / pad-token fallback / custom text-key.
- tests/unit/data/test_data_processor.py
TestKdDataProcessor: output contract, optional task_name forwarding,
drift-detector against input_ids / token_mask / loss_mask emission,
tokenizer-not-called guard, long-text-not-truncated-in-processor.
- tests/unit/data/datasets/test_arrow_text_dataset.py
TestPackGenerator: characters_per_sample semantics, pack-may-exceed-
threshold, no row truncation, trailing partial pack, zero-row input,
all-empty-text still emits, schema_version cache-key only.
- tests/unit/algorithms/x_token/conftest.py
make_loss_cfg fixture fixup: drop removed project_teacher_to_student,
add now-required student_vocab_size.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The chunk-KL reductions in `CrossTokenizerDistillationLossFn` divided each rank's KL sum by its own local valid-chunk count, producing `mean(rank_local_means)` instead of `sum(global_valid_chunk_kl) / sum(global_valid_chunks)` once gradients are summed across DP ranks. With heterogeneous cross-tokenizer alignment, per-rank chunk counts differ and the effective objective drifts away from the intended one. Add `_dp_all_reduce_sum` on the loss class and use it to compute a DP-global valid-chunk count at three sites: - `_compute_p_kl` chunk-KL denominator. - `_compute_gold` KL-on-common denominator. - `_compute_gold` L1-on-uncommon `.mean()` (same per-rank-mean bug). The all_reduce uses the default process group, which equals the DP group under CT's `cp_size=tp_size=1` setup assert. To keep the collective consistent across ranks, the existing local-only `chunk_mask.any()` / `valid_chunk.any()` early returns are moved to fire after the all_reduce and gated on the global count; otherwise a rank with no local valid chunks would skip the collective and deadlock against ranks that called it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Move `loss` and `grad_norm` into the initial metrics dict literal so
the existing per-step reduction loop collapses them via np.sum, matching
the upstream SFT / distillation pattern. The previous code assigned
both keys after the loop via float(train_results[k].numpy()), which
relied on numpy<2's implicit length-1 array-to-scalar coercion. Under
the v0.6 container's newer numpy this raises:
TypeError: only 0-dimensional arrays can be converted to Python scalars
crashing the driver before the first logged step.
Logged values are unchanged: train_results["loss"] is shape (1,) for
off-policy distillation (one global batch per train() call), so
np.sum(arr).item() returns the same scalar the old code unwrapped.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
|
/ok to test 9e49cfe |
ArrowTextDataset now resolves `data_files` via `load_dataset_from_path` (infers .arrow/.parquet/.json/.txt from the extension, or loads a HF dataset by name) instead of hardcoding the "arrow" builder, and accepts `subset`/`split`. The exemplar recipe defaults `data_files` to the ungated, CC-BY-4.0 nvidia/Nemotron-Pretraining-Specialized-v1.1 parquet shards (over hf://) so it runs with no auth and no extra setup; override `data_files` with a path/glob to train on your own .arrow corpus. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Summary
Add cross-tokenizer off-policy distillation to NeMo-RL.
TokenAligner— DP alignment over canonicalized tokens; projection-matrix loader.NotImplementedError.nemo_rl/utils/x_token/(minimal_projection_generator,minimal_projection_via_multitoken,reapply_exact_map,sort_and_cut_projection_matrix).docs/guides/xtoken-distillation.md.This PR supersedes #2347 (which was force-closed when its head branch was renamed). RayenTian's review feedback on #2347 is addressed in commit
755fb8e4:minimal_projection_via_multitoken.py: restored--output-filenameoverride; extended gemma vocab-size branch to also fire for qwen3.5; documented.ptsave-key schema.minimal_projection_generator.py: matching schema annotation neartorch.save.reapply_exact_map.py: validate loaded projection map is a dict withindices/likelihoods; raiseValueErrorwith the file path on mismatch.sort_and_cut_projection_matrix.py: factored argparse intoparse_arguments(); factored verbose stats intoprint_projection_statistics(); replaced positionalinput_pathwith--initial-projection-pathto match the other tools.docs/guides/xtoken-distillation.md: synced the Step 4 example with the new CLI.Test plan
llama_qwen_best_special_exact_map_remapped.ptbitwise (torch.equalon indices and likelihoods).