[5/5] feat: off-policy distillation algorithm and worker integration #2351
Closed
avenkateshha wants to merge 11 commits into
This was referenced Apr 27, 2026
Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.
- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
Numba-accelerated DP alignment, projection-matrix loading
(dense and sparse COO), and the project_token_likelihoods_instance
forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
minimal_projection_via_multitoken,reapply_exact_map,
sort_and_cut_projection_matrix}.py: standalone CLI scripts
(argparse-driven, __main__ entrypoints) for one-time projection-matrix
preparation. Not on the training import path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
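To make the idea of DP alignment between two tokenizations concrete, here is a minimal sketch. It is not the TokenAligner implementation (which is Numba-accelerated and handles more than exact matches); it assumes tokens are plain strings and uses a longest-common-subsequence dynamic program to recover 1:1 aligned positions. The function name `align_one_to_one` is hypothetical.

```python
from typing import List, Tuple


def align_one_to_one(student_tokens: List[str], teacher_tokens: List[str]) -> List[Tuple[int, int]]:
    """Return (student_idx, teacher_idx) pairs of identical tokens chosen by an LCS dynamic program."""
    n, m = len(student_tokens), len(teacher_tokens)
    # lcs[i][j] = length of the longest common subsequence of the two suffixes
    # student_tokens[i:] and teacher_tokens[j:].
    lcs = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        for j in range(m - 1, -1, -1):
            if student_tokens[i] == teacher_tokens[j]:
                lcs[i][j] = lcs[i + 1][j + 1] + 1
            else:
                lcs[i][j] = max(lcs[i + 1][j], lcs[i][j + 1])

    # Walk the table forward, emitting a pair whenever the tokens match.
    pairs = []
    i = j = 0
    while i < n and j < m:
        if student_tokens[i] == teacher_tokens[j]:
            pairs.append((i, j))
            i += 1
            j += 1
        elif lcs[i + 1][j] >= lcs[i][j + 1]:
            i += 1  # skip a student token that has no 1:1 partner
        else:
            j += 1  # skip a teacher token that has no 1:1 partner
    return pairs


# Two hypothetical tokenizations of the same text "The cat sat down".
student = ["The", " cat", " s", "at", " down"]
teacher = ["The", " cat", " sat", " down"]
pairs = align_one_to_one(student, teacher)
```

Positions where one tokenizer splits a word that the other keeps whole (here " s" + "at" versus " sat") simply drop out of the 1:1 pair list.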
9fe4b94 to 65d573f
Data-layer plumbing for cross-tokenizer off-policy distillation, plus
in-training eval datasets. Builds on the TokenAligner package from the
prior PR.
- nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and
TeacherCTSpec. Runs in StatefulDataLoader worker processes — does
per-teacher tokenize + DP alignment up front so the train loop only
consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so
workers that don't need cross-tokenizer never touch x_token.
- nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig.
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
ArrowTextDataset with lazy packing, registered as "arrow_text" in
DATASET_REGISTRY.
- nemo_rl/data/datasets/eval_datasets/{humaneval_plus,mbpp_plus,mmlu}.py
and registry entries: in-training eval datasets. mmlu.py adds an
optional num_few_shot argument with a static _build_few_shot_prefixes
helper; default of 0 preserves existing behavior.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
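The collator pattern described in this commit (tokenize per teacher and align up front, so the train loop only reads pre-built data) can be sketched roughly as follows. Everything here is a toy under stated assumptions: `ToyTeacherSpec`, `ToyCrossTokenizerCollator`, and `positional_align` are hypothetical stand-ins, tokens are strings rather than ids, and the real CrossTokenizerCollator uses TokenAligner's DP alignment instead of the positional match shown.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Tokens = List[str]
Pairs = List[Tuple[int, int]]


@dataclass
class ToyTeacherSpec:
    """Hypothetical stand-in for TeacherCTSpec: a teacher name plus its tokenizer."""
    name: str
    tokenize: Callable[[str], Tokens]


def positional_align(student: Tokens, teacher: Tokens) -> Pairs:
    """Toy aligner: pair indices whose tokens are identical at the same position."""
    return [(i, i) for i, (s, t) in enumerate(zip(student, teacher)) if s == t]


class ToyCrossTokenizerCollator:
    """Builds per-teacher tokenization and alignment inside the collate call."""

    def __init__(self, student_tokenize, teachers, align=positional_align):
        self.student_tokenize = student_tokenize
        self.teachers = teachers
        self.align = align

    def __call__(self, texts: List[str]) -> Dict[str, object]:
        student_tokens = [self.student_tokenize(text) for text in texts]
        per_teacher_ct_data = []
        for spec in self.teachers:
            teacher_tokens = [spec.tokenize(text) for text in texts]
            aligned_pairs = [
                self.align(s, t) for s, t in zip(student_tokens, teacher_tokens)
            ]
            per_teacher_ct_data.append(
                {
                    "teacher": spec.name,
                    "teacher_tokens": teacher_tokens,
                    "aligned_pairs": aligned_pairs,
                }
            )
        return {
            "student_tokens": student_tokens,
            "per_teacher_ct_data": per_teacher_ct_data,
        }


collator = ToyCrossTokenizerCollator(
    student_tokenize=str.split,  # whitespace "student" tokenizer
    teachers=[
        ToyTeacherSpec("char_teacher", lambda text: [c for c in text if c != " "])
    ],
)
batch = collator(["a bc"])
```

Because the work happens inside the callable, a data loader that runs it in worker processes would pay the alignment cost off the training loop's critical path.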
Adds the loss-fn layer for cross-tokenizer distillation. Builds on the
TokenAligner package (PR 1).
- CrossTokenizerDistillationLossFn: per-token KL/CE loss over 1:1 aligned
positions, with optional gold-loss path. Holds a reference to a
TokenAligner; teacher data (input_ids, aligned_pairs, optional chunked
COO masks) is set per-step via set_cross_tokenizer_data.
- CrossTokenizerDistillationLossConfig and
CrossTokenizerDistillationLossDataDict TypedDicts.
- MultiTeacherLossAggregator: wraps a list of optional
CrossTokenizerDistillationLossFn instances with per-teacher weights. N=1
is a degenerate case used by the unified single-/multi-teacher worker
path; the algorithm-layer multi-teacher orchestration comes in a later PR.
- _scatter_chunk_mask_from_coo helper for the chunked-CE path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
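A rough sketch of the two ideas in this commit, a per-token KL over 1:1 aligned positions and a per-teacher weighted aggregation, might look like the following. It assumes both distributions are already expressed over a shared vocabulary (in the PR that is the job of the projection step), that the KL direction is teacher-to-student, and that aggregation is a plain weighted sum; all three are assumptions, and the function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def aligned_kl_loss(
    student_probs: List[List[float]],
    teacher_probs: List[List[float]],
    aligned_pairs: List[Tuple[int, int]],
) -> float:
    """Mean KL(teacher || student) over (student_pos, teacher_pos) pairs."""
    if not aligned_pairs:
        return 0.0
    total = sum(
        kl_divergence(teacher_probs[t], student_probs[s]) for s, t in aligned_pairs
    )
    return total / len(aligned_pairs)


def aggregate(losses: List[Optional[float]], weights: List[float]) -> float:
    """Weighted sum of per-teacher losses; None marks a teacher slot with no loss fn."""
    return sum(w * loss for loss, w in zip(losses, weights) if loss is not None)


student_probs = [[0.5, 0.5], [0.9, 0.1]]
teacher_probs = [[0.5, 0.5], [0.25, 0.75], [0.9, 0.1]]

matching = aligned_kl_loss(student_probs, teacher_probs, [(0, 0), (1, 2)])
mismatched = aligned_kl_loss(student_probs, teacher_probs, [(1, 1)])
combined = aggregate([0.2, None, 0.4], [0.5, 1.0, 0.25])
```

With a single teacher the aggregation reduces to one weighted term, which corresponds to the N=1 degenerate case the commit mentions.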
Adds the IPC plumbing that lets a teacher policy worker hand its logits to
the student worker without going through Ray's serialization path. This is
required for cross-tokenizer distillation, where teacher full-vocab logits
are too big to pickle per step.
- nemo_rl/distributed/ipc_utils.py: get_handle_from_tensor and
rebuild_cuda_tensor_from_ipc helpers wrapping CUDA IPC handles.
- nemo_rl/models/automodel/train.py: two new post-processors:
XTokenTeacherIPCExportPostProcessor (teacher side, allocates a pre-sized
CUDA buffer and exports the IPC handle per microbatch) and
XTokenTeacherIPCLossPostProcessor (student side, rebuilds the tensor from
the handle and feeds it to the loss fn). Existing post-processors are
untouched.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Adds the full cross-tokenizer off-policy distillation training loop with
single- and multi-teacher support, plus the worker-level plumbing. Builds
on TokenAligner (PR 1), data layer (PR 2), loss (PR 3), and IPC (PR 4).
Algorithm and entrypoint mirror the on-policy distillation patterns in
nemo_rl/algorithms/distillation.py and examples/run_distillation.py:
- nemo_rl/algorithms/off_policy_distillation.py: setup, validate,
off_policy_distillation_train. No env/rollout machinery; teacher logits
arrive via CUDA IPC. Supports single-teacher (master_config["teacher"])
and multi-teacher (master_config["teachers"] list with TeacherSpec
entries and per-teacher weights). Multi-teacher requires use_ipc=true;
same-tokenizer + cross-tokenizer teachers can be mixed in one run.
- examples/run_off_policy_distillation.py: thin entrypoint.
- examples/configs/off_policy_distillation.yaml: production multi-teacher
config (Phi-4-mini + Llama-3.2-3B teachers, Llama-3.2-1B student).
data.train.arrow_files left null; override via Hydra CLI.
Worker integration is purely additive (existing train() /
get_topk_logits / get_logprobs paths are not modified), so GRPO / SFT /
on-policy distillation are unaffected:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py: new
init_cross_tokenizer_loss_fn, update_cross_tokenizer_data,
train_off_policy_distillation, and compute_teacher_logits_ipc methods on
DTensorPolicyWorkerV2Impl.
- nemo_rl/models/policy/workers/megatron_policy_worker.py: same four
method names as NotImplementedError stubs that point users to
DTensorPolicyWorkerV2 for off-policy distillation.
- nemo_rl/models/policy/lm_policy.py: matching dispatcher methods on
Policy that fan out to the worker group.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
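One way to picture how a single `teacher` block and a `teachers` list could be folded into one internal shape is sketched below. This is only an illustration: the function name, the `policy` / `weight` / `token_aligner` keys, the default weight of 1.0, and the top-level placement of `use_ipc` are all assumptions made for the example, not the PR's actual schema.

```python
from typing import Any, Dict, List


def normalize_teacher_specs(master_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return a list of teacher spec dicts regardless of which config shape was used."""
    has_single = "teacher" in master_config
    has_multi = "teachers" in master_config
    if has_single == has_multi:
        raise ValueError("set exactly one of 'teacher' or 'teachers'")

    if has_single:
        # Single-teacher shape: wrap it so downstream code always sees a list.
        spec: Dict[str, Any] = {"policy": master_config["teacher"], "weight": 1.0}
        if "token_aligner" in master_config:
            spec["token_aligner"] = master_config["token_aligner"]
        return [spec]

    specs = [dict(entry) for entry in master_config["teachers"]]
    # Assumed location of the flag; the source only states that
    # multi-teacher runs require use_ipc=true.
    if len(specs) > 1 and not master_config.get("use_ipc", False):
        raise ValueError("multi-teacher runs require use_ipc=true")
    for spec in specs:
        spec.setdefault("weight", 1.0)
    return specs


single = normalize_teacher_specs(
    {
        "teacher": {"model_name": "teacher-a"},
        "token_aligner": {"projection_matrix": "a.pt"},
    }
)
multi = normalize_teacher_specs(
    {
        "use_ipc": True,
        "teachers": [
            {"policy": {"model_name": "teacher-a"}, "weight": 0.7},
            {"policy": {"model_name": "teacher-b"}},
        ],
    }
)
```

Normalizing once at setup time lets the training loop iterate over a list in both cases, which fits the unified single-/multi-teacher worker path described in the stack.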
caab0e3 to f9d931b
Off-policy distillation needs two driver-side methods that fan out to all
policy workers:
* ``init_cross_tokenizer_loss_fn(loss_config, token_aligner_config)``:
each worker materializes its own ``MultiTeacherLossAggregator`` (with
N=1 for single-teacher) from the shared filesystem and caches it as
``self._cached_loss_fn``.
* ``update_cross_tokenizer_data(teacher_input_ids, aligned_pairs,
teacher_idx, chunk_indices)``: push per-step CT data to the cached loss
fn on every worker, sharded by DP axis so each worker only receives its
slice. Uses ``run_all_workers_multiple_data`` for the per-rank sharded
dispatch.
The corresponding worker-side methods already exist in
``DTensorPolicyWorkerV2``; what was missing was the driver-side wrapper.
Without these, ``off_policy_distillation_train`` calling
``student_policy.init_cross_tokenizer_loss_fn(...)`` /
``student_policy.update_cross_tokenizer_data(...)`` would AttributeError.
Pattern matches existing wrappers like ``init_collective`` and
``prepare_for_training`` in the same file.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
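The per-rank sharding mentioned above can be illustrated with a small sketch. It assumes a contiguous, evenly divisible split along the batch dimension, which may differ from how the real dispatcher slices data; `shard_by_dp_rank` and the example values are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_by_dp_rank(items: Sequence[T], dp_size: int) -> List[List[T]]:
    """Split a batch into dp_size contiguous, equally sized slices (one per DP rank)."""
    if dp_size <= 0 or len(items) % dp_size != 0:
        raise ValueError("batch size must be a positive multiple of dp_size")
    per_rank = len(items) // dp_size
    return [list(items[r * per_rank : (r + 1) * per_rank]) for r in range(dp_size)]


# Hypothetical per-step cross-tokenizer data for a batch of four samples.
teacher_input_ids = [[11, 12], [21, 22], [31, 32], [41, 42]]
aligned_pairs = [[(0, 0)], [(1, 1)], [(0, 1)], [(1, 0)]]

# One kwargs dict per DP rank, so each worker only receives its own slice.
per_rank_kwargs = [
    {"teacher_input_ids": ids, "aligned_pairs": pairs, "teacher_idx": 0}
    for ids, pairs in zip(
        shard_by_dp_rank(teacher_input_ids, dp_size=2),
        shard_by_dp_rank(aligned_pairs, dp_size=2),
    )
]
```

A driver-side wrapper would then hand entry `r` of this list to the worker at DP rank `r`.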
…spatch
Two refactor regressions in the cross-tokenizer off-policy worker surfaced
once the smoke run cleared earlier setup-side bugs:
1. Phi-4 / Phi-3 ``trust_remote_code`` revisions register ``inv_freq`` and
``original_inv_freq`` as buffers that get left on the ``meta`` device
after FSDP/DTensor materialization. Their forward later calls
``.to(device)`` on those buffers and crashes with
``NotImplementedError: Cannot copy out of meta tensor``. Add
``_fix_phi_rope_meta_buffers()`` (called once after model setup in
``__init__``) which re-runs each module's ``rope_init_fn`` against CUDA
and replaces the meta buffer in place. Scoped to Phi-style models via
the model_name / architecture string so other model families
short-circuit immediately. The except is narrowed to
``(TypeError, RuntimeError, AttributeError)`` per the error-handling
skill (no bare ``except: pass``); a skipped module logs at rank 0.
2. The ``teacher_worker_result`` dispatch had grown a multi-arm condition
that misclassified the single-teacher payload shape ``list[rank]`` of
dicts as a multi-teacher list and passed all 16 entries to the loss
postprocessor, which then KeyError'd on ``handle[rank]``. Off-policy
distillation always pre-shards by rank before the worker sees the
payload (single-teacher: ``list[rank]``; multi-teacher:
``dict[rank, list[T]]``), so plain ``teacher_logits[rank]`` gives the
right per-rank entry in both cases. Revert to the simpler form.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
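The reasoning behind the second fix is that both payload shapes are already keyed by rank, so a single index expression serves both. A small sketch with made-up payloads (the `handle` strings and the helper name are hypothetical):

```python
from typing import Any, Dict, List, Union

Payload = Union[List[Any], Dict[int, Any]]


def select_rank_payload(teacher_logits: Payload, rank: int) -> Any:
    """Pick this rank's entry; works for list[rank] and dict[rank, list[T]] alike."""
    return teacher_logits[rank]


world_size = 4

# Single-teacher shape: a list with one dict per rank.
single_teacher = [{"handle": f"teacher0-rank{r}"} for r in range(world_size)]

# Multi-teacher shape: a dict keyed by rank, each value listing one entry per teacher.
multi_teacher = {
    r: [{"handle": f"teacher0-rank{r}"}, {"handle": f"teacher1-rank{r}"}]
    for r in range(world_size)
}

rank2_single = select_rank_payload(single_teacher, rank=2)
rank2_multi = select_rank_payload(multi_teacher, rank=2)
```

Any branching on "is this a list, so it must be multi-teacher" is what misfires here, because the single-teacher payload is also a list.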
…hed loss fn fallback
Three off-policy distillation regressions surfaced by the 100-step
single-teacher Phi-4 cross-tokenizer smoke run:
1. ``setup()`` builds ``token_aligners`` and ``teacher_tokenizers``
per teacher but never returned them. The runner unpacked the
9-tuple, never received the aligners, and ``cross_tokenizer_enabled``
in ``off_policy_distillation_train`` always evaluated to False —
skipping ``init_cross_tokenizer_loss_fn`` /
``update_cross_tokenizer_data``. Extend the return tuple by two
entries and have the runner pass them through as keyword args.
2. The runner passed ``config.get("env", {})`` to
``setup_response_data``. ``setup_response_data`` treats empty
dict as "envs are present" (only ``None`` skips them) and then
crashes on ``cfg["env_name"]`` for arrow_text data. Pass
``None`` when the config has no ``env`` key and unpack the
2-tuple return path; mirrors the SFT runner.
3. During validation, the call passed the driver-side ``loss_fn``
instance directly. For cross-tokenizer runs that's a
``MultiTeacherLossAggregator`` whose ``set_cross_tokenizer_data``
was never invoked — the per-step CT data was pushed to each
worker's ``_cached_loss_fn`` via ``update_cross_tokenizer_data``,
not the driver-side fn. Validation crashed with
``Cross-tokenizer data not set``. Pass ``loss_fn=None`` for
``MultiTeacherLossAggregator`` so workers fall back to
``_cached_loss_fn`` (which has the alignment data set from the
preceding training step). Plain (non-CT) runs unchanged.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
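The second regression comes down to the difference between `config.get("env", {})` and `config.get("env")`. The stub below reproduces only the behavior described in this commit message (an empty dict counts as "envs present", `None` skips them); it is a hypothetical stand-in, not the real `setup_response_data`, and the return values are placeholders.

```python
from typing import Any, Dict, Optional, Tuple


def setup_response_data_stub(env_configs: Optional[Dict[str, Any]]) -> Tuple[Any, ...]:
    """Hypothetical stub: only None skips env handling, as the commit message describes."""
    if env_configs is None:
        # No-env path: returns just the two datasets.
        return ("train_dataset", "val_dataset")
    # Any dict, even an empty one, is treated as "envs are present".
    env_name = env_configs["env_name"]  # raises KeyError for {}
    return ("train_dataset", "val_dataset", env_name)


config = {"data": {"dataset_name": "arrow_text"}}  # no "env" key at all

# Old call: the {} default turns a missing key into "envs are present".
try:
    setup_response_data_stub(config.get("env", {}))
    crashed = False
except KeyError:
    crashed = True

# Fixed call: a missing key becomes None, and the 2-tuple path is unpacked.
train_dataset, val_dataset = setup_response_data_stub(config.get("env"))
```

The fix is small, but it shows why a "safe" default value can change which code path a callee takes.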
In off-policy distillation with keep_models_resident=False, the
algorithm calls student_policy.offload_after_refit() to free GPU
memory before each step's teacher inference. That moves the optimizer
state to CPU. The next student training step calls
student_policy.prepare_for_training(), but its existing onload gate
only triggers for the logprob/colocated-generation offload paths
(self.offload_optimizer_for_logprob or self.is_generation_colocated).
Off-policy distillation is neither, so the optimizer state stays on
CPU and the next optimizer.step() crashes with:
RuntimeError: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
Step 1 happens to work because Adam's exp_avg/exp_avg_sq are not yet
materialized — move_optimizer_to_device is a no-op for fresh state.
Step 2 always crashes once those buffers exist.
Add a new worker method DTensorPolicyWorkerV2Impl.move_optimizer_to_cuda
that wraps the existing move_optimizer_to_device("cuda"), with a
matching NotImplementedError stub on MegatronPolicyWorkerImpl and a
dispatcher on Policy. Off-policy distillation calls it after
prepare_for_training only when keep_models_resident=False; it is a
no-op if the optimizer is already on CUDA. No change to the shared
prepare_for_training contract used by GRPO / SFT / on-policy
distillation.
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Automodel's load-before-shard path materializes the full unsharded model on
every rank before FSDP sharding (its `_should_load_before_shard` returns
True for tp=ep=1, no PP/PEFT, regardless of FSDP DP world size). For Phi-4
14B that's ~56 GB of FP32 master weights per GPU, plus a ~28 GB BF16
state-dict stage in `_distribute_state_dict`, which OOMs an 80 GB H100 with
`78.35 GB allocated by PyTorch` before the load completes.

`validate_and_prepare_config` always pins
`model_config.torch_dtype = float32` because FSDP2 mixed-precision needs
FP32 master weights for AdamW state and grad-norm reductions on trainable
policies. Inference-only paths (`init_optimizer=False`, e.g. teacher
policies in off-policy distillation) never touch the optimizer code paths,
so the FP32 master weights are pure load-time waste.

Add `_apply_phi_inference_dtype_override`, a pre-load hook that runs
between `validate_and_prepare_config` and `setup_distributed`. When both
gates hold (model is Phi-style AND `init_optimizer=False`) it rewrites
`runtime_config.model_config.torch_dtype` from FP32 to
`runtime_config.dtype` (BF16 for the Phi-4 teacher YAML). Halves the
load-time per-rank footprint to ~28 GB and clears the OOM. HF's
`rope_init_fn` produces `inv_freq` in FP32 regardless of model dtype, so
the existing post-load `_fix_phi_rope_meta_buffers` is unaffected.

Refactor `_fix_phi_rope_meta_buffers` to share a `_is_phi_style_model` gate
with the new hook. The shared gate accepts an optional `architectures` list
so callers running before `self.model_config` is populated (the new
pre-load hook) can pass it from `runtime_config.model_config`.

Scoped to Phi-style models so SFT / GRPO / on-policy distillation paths and
other model families keep the existing FP32-master-weights behavior; the
override is also a no-op when an optimizer is requested (student policies
stay on FP32).

Verified on the Phi-4 (teacher, 14B) → Llama-3.1-8B (student) smoke run
that previously OOM'd at "GPU 0 has 79.11 GB total ... 350 MiB to
allocate". Same 2-node config now completes 100 training steps with the
breadcrumb "[Phi inference dtype] loading microsoft/Phi-4 in torch.bfloat16
(no optimizer requested)" firing across all 16 teacher workers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
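The memory figures quoted in this commit follow from simple bytes-per-parameter arithmetic, sketched below together with the two-gate condition. It assumes the round "14B" parameter count from the commit message and decimal gigabytes, and it ignores activations and allocator overhead; the helper names are hypothetical.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weights_gb(num_params: int, dtype: str) -> float:
    """Size of the raw weights in decimal gigabytes for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def should_load_in_runtime_dtype(is_phi_style: bool, init_optimizer: bool) -> bool:
    """Both gates from the commit message: Phi-style model AND no optimizer requested."""
    return is_phi_style and not init_optimizer


NUM_PARAMS = 14_000_000_000  # "14B" as stated in the commit message
GPU_CAPACITY_GB = 80         # H100 80 GB

fp32_weights = weights_gb(NUM_PARAMS, "float32")   # FP32 master weights
bf16_stage = weights_gb(NUM_PARAMS, "bfloat16")    # BF16 state-dict stage
peak_before = fp32_weights + bf16_stage            # both resident during the load

teacher_overridden = should_load_in_runtime_dtype(is_phi_style=True, init_optimizer=False)
student_overridden = should_load_in_runtime_dtype(is_phi_style=True, init_optimizer=True)
```

Under these assumptions the FP32 weights plus the BF16 staging copy exceed the card's capacity on their own, which is consistent with the OOM occurring before the load completes.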
Follow-up to the TokenAligner refactor on 01-tokenaligner that dropped
align_fast() / precompute_canonical_maps(). This commit picks up the
remaining call sites in the off-policy distillation orchestration.
- TokenAlignerConfig: dropped use_align_fast typed-dict field.
- setup(): removed the use_align_fast= pass-through into TeacherCTSpec and
the token_aligner.precompute_canonical_maps() call (the helper no longer
exists on TokenAligner; align() does the full sequence-level
canonicalization on each call).
- examples/configs/off_policy_distillation.yaml: removed the
`use_align_fast: true` knob from each teacher's token_aligner block.
Behavior: see the TokenAligner refactor commit for the alignment-result
note on encoding-artifact / byte tokens. Same trade-off applies here.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Adds the full cross-tokenizer off-policy distillation training loop with
single- and multi-teacher support, plus the worker-level plumbing. Builds
on TokenAligner (PR 1), data layer (PR 2), loss (PR 3), and IPC (PR 4).
Algorithm and entrypoint mirror the on-policy distillation patterns in
nemo_rl/algorithms/distillation.py and examples/run_distillation.py:
- nemo_rl/algorithms/off_policy_distillation.py: setup, validate,
off_policy_distillation_train. No env/rollout machinery; teacher logits
arrive via CUDA IPC. Supports single-teacher (master_config["teacher"])
and multi-teacher (master_config["teachers"] list with TeacherSpec
entries and per-teacher weights). Multi-teacher requires use_ipc=true;
same-tokenizer + cross-tokenizer teachers can be mixed in one run.
- examples/run_off_policy_distillation.py: thin entrypoint.
- examples/configs/off_policy_distillation.yaml: production multi-teacher
config (Phi-4-mini + Llama-3.2-3B teachers, Llama-3.2-1B student).
data.train.arrow_files left null; override via Hydra CLI.
Worker integration is purely additive (existing train() /
get_topk_logits / get_logprobs paths are not modified), so GRPO / SFT /
on-policy distillation are unaffected:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py: new
init_cross_tokenizer_loss_fn, update_cross_tokenizer_data,
train_off_policy_distillation, and compute_teacher_logits_ipc methods on
DTensorPolicyWorkerV2Impl.
- nemo_rl/models/policy/workers/megatron_policy_worker.py: same four
method names as NotImplementedError stubs that point users to
DTensorPolicyWorkerV2 for off-policy distillation.
- nemo_rl/models/policy/lm_policy.py: matching dispatcher methods on
Policy that fan out to the worker group.
What does this PR do?
Adds the cross-tokenizer off-policy distillation algorithm (single- and multi-teacher), the example entrypoint and YAML, and the worker dispatcher methods that wire it through DTensorPolicyWorkerV2. This is the user-facing capstone of the 5-PR stack.
Issues
None linked yet.
Usage
    uv run examples/run_off_policy_distillation.py \
        --config examples/configs/off_policy_distillation.yaml \
        data.train.arrow_files=/path/to/your_dataset.arrow \
        cluster.num_nodes=16
The shipped YAML uses two teachers (Phi-4-mini + Llama-3.2-3B) with a Llama-3.2-1B student. For a single-teacher run, replace the teachers: list with a single teacher: block plus an optional token_aligner: block; the algorithm normalizes both shapes through _normalize_teacher_specs.
Before your PR is "Ready for review"
py_compile clean across all five branches in the stack. In-container import smoke test passed for the cross-tokenizer stack (TokenAligner, off-policy distillation entry, loss fn, collator, Megatron stubs). Full functional/CI run on this PR. docs/index.md entry yet for off-policy distillation; happy to add as a doc-only follow-up PR.
Additional Information
Draft. Final PR in the 5-PR stack. Stacked on PR 4 (#2350).
Megatron worker stubs (raise NotImplementedError(... use DTensorPolicyWorkerV2 ...)) match the convention in ruit/xtoken_rafactor so users hitting policy.megatron_cfg.enabled=true get a clear error rather than an opaque AttributeError from Ray.
Full chain: