[5/5] feat: off-policy distillation algorithm and worker integration#2351

Closed
avenkateshha wants to merge 11 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation/05-off-policy-distillation

@avenkateshha avenkateshha commented Apr 27, 2026

Adds the full cross-tokenizer off-policy distillation training loop with single- and multi-teacher support, plus the worker-level plumbing. Builds on TokenAligner (PR 1), data layer (PR 2), loss (PR 3), and IPC (PR 4).

Algorithm and entrypoint mirror the on-policy distillation patterns in nemo_rl/algorithms/distillation.py and examples/run_distillation.py:

  • nemo_rl/algorithms/off_policy_distillation.py: setup, validate, off_policy_distillation_train. No env/rollout machinery; teacher logits arrive via CUDA IPC. Supports single-teacher (master_config["teacher"]) and multi-teacher (master_config["teachers"] list with TeacherSpec entries and per-teacher weights). Multi-teacher requires use_ipc=true; same-tokenizer + cross-tokenizer teachers can be mixed in one run.
  • examples/run_off_policy_distillation.py: thin entrypoint.
  • examples/configs/off_policy_distillation.yaml: production multi-teacher config (Phi-4-mini + Llama-3.2-3B teachers, Llama-3.2-1B student). data.train.arrow_files left null; override via Hydra CLI.

Worker integration is purely additive — existing train() / get_topk_logits / get_logprobs paths are not modified, so GRPO / SFT / on-policy distillation are unaffected:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py: new init_cross_tokenizer_loss_fn, update_cross_tokenizer_data, train_off_policy_distillation, and compute_teacher_logits_ipc methods on DTensorPolicyWorkerV2Impl.
  • nemo_rl/models/policy/workers/megatron_policy_worker.py: same four method names as NotImplementedError stubs that point users to DTensorPolicyWorkerV2 for off-policy distillation.
  • nemo_rl/models/policy/lm_policy.py: matching dispatcher methods on Policy that fan out to the worker group.

What does this PR do?

Adds the cross-tokenizer off-policy distillation algorithm (single- and multi-teacher), the example entrypoint and YAML, and the worker dispatcher methods that wire it through DTensorPolicyWorkerV2. This is the user-facing capstone of the 5-PR stack.

Issues

None linked yet.

Usage

uv run examples/run_off_policy_distillation.py \
    --config examples/configs/off_policy_distillation.yaml \
    data.train.arrow_files=/path/to/your_dataset.arrow \
    cluster.num_nodes=16

The shipped YAML uses two teachers (Phi-4-mini + Llama-3.2-3B) with a Llama-3.2-1B student. For a single-teacher run, replace the teachers: list with a single teacher: block plus an optional token_aligner: block — the algorithm normalizes both shapes through _normalize_teacher_specs.
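
As a rough illustration of how the two config shapes could collapse into one list, here is a minimal sketch. It is not the PR's `_normalize_teacher_specs`: the helper name `normalize_teacher_specs_sketch`, the `weight` and `model_name` keys, the default weight of 1.0, and the assumption that `token_aligner` sits next to `teacher` are all hypothetical.

```python
from typing import Any, Dict, List


def normalize_teacher_specs_sketch(master_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return a list of teacher specs for either config shape (hypothetical helper)."""
    if master_config.get("teachers"):
        # Multi-teacher shape: copy each entry, defaulting the (assumed) weight key.
        return [{"weight": 1.0, **spec} for spec in master_config["teachers"]]
    if "teacher" in master_config:
        # Single-teacher shape: wrap the lone block in a one-element list.
        spec = {"weight": 1.0, **master_config["teacher"]}
        # The optional aligner block is assumed to live beside `teacher`.
        if master_config.get("token_aligner") is not None:
            spec["token_aligner"] = master_config["token_aligner"]
        return [spec]
    raise ValueError("config needs either a `teacher` block or a `teachers` list")


single = normalize_teacher_specs_sketch({"teacher": {"model_name": "teacher-a"}})
multi = normalize_teacher_specs_sketch(
    {"teachers": [{"model_name": "teacher-a", "weight": 0.7}, {"model_name": "teacher-b"}]}
)
```

With this shape, downstream code can always iterate over a list, and the single-teacher run is just the N=1 case.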

Before your PR is "Ready for review"

  • Read Contributor guidelines
  • No new unit tests added — algorithm-level coverage is a follow-up.
  • Static py_compile is clean across all five branches in the stack. The in-container import smoke test passed for the cross-tokenizer stack (TokenAligner, off-policy distillation entry, loss fn, collator, Megatron stubs). The full functional/CI run happens on this PR.
  • No docs/index.md entry yet for off-policy distillation — happy to add as a doc-only follow-up PR.

Additional Information

Draft. Final PR in the 5-PR stack. Stacked on PR 4 — #2350.

Megatron worker stubs (raise NotImplementedError(... use DTensorPolicyWorkerV2 ...)) match the convention in ruit/xtoken_rafactor so users hitting policy.megatron_cfg.enabled=true get a clear error rather than an opaque AttributeError from Ray.
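
The stub convention can be pictured roughly as below. This is a sketch, not the code in megatron_policy_worker.py: the class name `MegatronWorkerStubSketch`, the `_unsupported` helper, and the exact message wording are assumptions; only the four method names and the pointer to DTensorPolicyWorkerV2 come from the description above.

```python
class MegatronWorkerStubSketch:
    """Toy worker whose off-policy distillation methods are explicit stubs."""

    def _unsupported(self, method_name: str):
        # Assumed wording; the PR only says the error points users to
        # DTensorPolicyWorkerV2.
        raise NotImplementedError(
            f"{method_name} is not supported by this worker; "
            "use DTensorPolicyWorkerV2 for off-policy distillation."
        )

    def init_cross_tokenizer_loss_fn(self, *args, **kwargs):
        self._unsupported("init_cross_tokenizer_loss_fn")

    def update_cross_tokenizer_data(self, *args, **kwargs):
        self._unsupported("update_cross_tokenizer_data")

    def train_off_policy_distillation(self, *args, **kwargs):
        self._unsupported("train_off_policy_distillation")

    def compute_teacher_logits_ipc(self, *args, **kwargs):
        self._unsupported("compute_teacher_logits_ipc")


def call_and_capture(worker, method_name: str) -> str:
    """Call a method by name and return the text of the error it raises."""
    try:
        getattr(worker, method_name)()
    except NotImplementedError as err:
        return str(err)
    return ""


message = call_and_capture(MegatronWorkerStubSketch(), "train_off_policy_distillation")
```

Because the method exists on the class, the remote call resolves and the user sees the explanatory message instead of a missing-attribute failure.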

Full chain:

  1. TokenAligner + projection utilities — [1/5] feat: add TokenAligner and cross-tokenizer projection utilities #2347
  2. Collator + Arrow dataset + eval datasets — [2/5] feat: cross-tokenizer collator, Arrow dataset, and eval datasets #2348
  3. CT distillation loss + multi-teacher aggregator — [3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator #2349
  4. CUDA IPC for teacher logits — [4/5] feat: CUDA IPC for teacher logits transfer #2350
  5. (this PR) Algorithm + worker integration

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@avenkateshha avenkateshha changed the title from "feat: off-policy distillation algorithm and worker integration" to "[5/5] feat: off-policy distillation algorithm and worker integration" Apr 27, 2026
Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.

- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
  Numba-accelerated DP alignment, projection-matrix loading
  (dense and sparse COO), and the project_token_likelihoods_instance
  forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
  minimal_projection_via_multitoken,reapply_exact_map,
  sort_and_cut_projection_matrix}.py: standalone CLI scripts
  (argparse-driven, __main__ entrypoints) for one-time projection-matrix
  preparation. Not on the training import path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
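
For readers unfamiliar with dynamic-programming alignment between two tokenizations, here is a minimal sketch. It is a plain longest-common-subsequence alignment over token strings in pure Python, swapped in for illustration; the commit does not spell out TokenAligner's actual scoring or its Numba kernel, and the function name `align_tokens_lcs` and the sample tokens are hypothetical.

```python
from typing import List, Tuple


def align_tokens_lcs(a: List[str], b: List[str]) -> List[Tuple[int, int]]:
    """Return index pairs (i, j) with a[i] == b[j] along a longest common subsequence."""
    n, m = len(a), len(b)
    # dp[i][j] holds the LCS length of the suffixes a[i:] and b[j:].
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        for j in range(m - 1, -1, -1):
            if a[i] == b[j]:
                dp[i][j] = dp[i + 1][j + 1] + 1
            else:
                dp[i][j] = max(dp[i + 1][j], dp[i][j + 1])
    # Walk forward through the table to recover the matched positions.
    pairs = []
    i = j = 0
    while i < n and j < m:
        if a[i] == b[j]:
            pairs.append((i, j))
            i += 1
            j += 1
        elif dp[i + 1][j] >= dp[i][j + 1]:
            i += 1
        else:
            j += 1
    return pairs


student_tokens = ["Hello", ",", " wor", "ld", "!"]
teacher_tokens = ["Hello", ",", " world", "!"]
pairs = align_tokens_lcs(student_tokens, teacher_tokens)
```

Only the tokens that match exactly are paired here; the span where one tokenizer splits " world" into two pieces is left unaligned, which is the kind of position a projection step would have to handle separately.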
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/05-off-policy-distillation branch from 9fe4b94 to 65d573f Compare April 27, 2026 10:21
avenkateshha and others added 4 commits April 27, 2026 03:25
Data-layer plumbing for cross-tokenizer off-policy distillation, plus
in-training eval datasets. Builds on the TokenAligner package from the
prior PR.

- nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and
  TeacherCTSpec. Runs in StatefulDataLoader worker processes — does
  per-teacher tokenize + DP alignment up front so the train loop only
  consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so
  workers that don't need cross-tokenizer never touch x_token.
- nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig.
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
  ArrowTextDataset with lazy packing, registered as "arrow_text" in
  DATASET_REGISTRY.
- nemo_rl/data/datasets/eval_datasets/{humaneval_plus,mbpp_plus,mmlu}.py
  and registry entries: in-training eval datasets. mmlu.py adds an
  optional num_few_shot argument with a static _build_few_shot_prefixes
  helper; default of 0 preserves existing behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
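
The idea of doing per-teacher work inside the collator can be sketched as follows. This is a toy under stated assumptions, not `CrossTokenizerCollator`: the class name `CollatorSketch`, the callable tokenizers, and the inner `teacher_input_ids` layout are hypothetical, and the DP alignment step is left out. Only the `per_teacher_ct_data` key is taken from the commit message.

```python
from typing import Callable, Dict, List

Tokenize = Callable[[str], List[int]]


class CollatorSketch:
    """Toy collator that tokenizes once for the student and once per teacher."""

    def __init__(self, student_tokenize: Tokenize, teacher_tokenizers: Dict[str, Tokenize]):
        self.student_tokenize = student_tokenize
        self.teacher_tokenizers = teacher_tokenizers

    def __call__(self, texts: List[str]) -> dict:
        batch = {"input_ids": [self.student_tokenize(t) for t in texts]}
        # Per-teacher work happens here, in the data-loading worker, so the
        # train loop only reads the pre-built entries.
        batch["per_teacher_ct_data"] = {
            name: {"teacher_input_ids": [tokenize(t) for t in texts]}
            for name, tokenize in self.teacher_tokenizers.items()
        }
        return batch


def word_lengths(text: str) -> List[int]:
    """Stand-in student tokenizer: one id per whitespace-separated word."""
    return [len(word) for word in text.split()]


def char_codes(text: str) -> List[int]:
    """Stand-in teacher tokenizer: one id per character."""
    return [ord(ch) for ch in text]


collate = CollatorSketch(word_lengths, {"teacher_0": char_codes})
batch = collate(["a bc"])
```

Passing an object like this as a data loader's collate function moves the tokenization cost off the training process, which is the motivation the commit gives.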
Adds the loss-fn layer for cross-tokenizer distillation. Builds on the
TokenAligner package (PR 1).

- CrossTokenizerDistillationLossFn: per-token KL/CE loss over 1:1 aligned
  positions, with optional gold-loss path. Holds a reference to a
  TokenAligner; teacher data (input_ids, aligned_pairs, optional chunked
  COO masks) is set per-step via set_cross_tokenizer_data.
- CrossTokenizerDistillationLossConfig and
  CrossTokenizerDistillationLossDataDict TypedDicts.
- MultiTeacherLossAggregator: wraps a list of optional
  CrossTokenizerDistillationLossFn instances with per-teacher weights.
  N=1 is a degenerate case used by the unified single-/multi-teacher
  worker path; the algorithm-layer multi-teacher orchestration comes in
  a later PR.
- _scatter_chunk_mask_from_coo helper for the chunked-CE path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
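
A minimal numeric sketch of the two ideas in this commit (a per-token KL over 1:1 aligned positions, and a per-teacher weighted combination) is given below. It assumes the teacher distributions have already been projected into the student vocabulary, that aligned pairs are `(student_pos, teacher_pos)` tuples, and that the aggregator is a plain weighted sum; none of these details are confirmed by the commit, and all function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def aligned_kl_loss(
    student_probs: List[Sequence[float]],
    teacher_probs: List[Sequence[float]],
    aligned_pairs: List[Tuple[int, int]],
) -> float:
    """Mean KL(teacher || student) over aligned (student_pos, teacher_pos) pairs."""
    if not aligned_pairs:
        return 0.0
    total = sum(kl_divergence(teacher_probs[t], student_probs[s]) for s, t in aligned_pairs)
    return total / len(aligned_pairs)


def aggregate_teacher_losses(
    losses: Sequence[Optional[float]], weights: Sequence[float]
) -> float:
    """Weighted sum over teachers; a None entry contributes nothing."""
    return sum(w * loss for loss, w in zip(losses, weights) if loss is not None)


student = [[0.5, 0.5], [0.9, 0.1]]
teacher = [[0.5, 0.5]]
matched = aligned_kl_loss(student, teacher, [(0, 0)])      # identical distributions
mismatched = aligned_kl_loss(student, teacher, [(1, 0)])   # different distributions
combined = aggregate_teacher_losses([1.0, None, 3.0], [0.5, 0.2, 0.5])
```

With a single teacher the aggregation reduces to one weighted term, which is the N=1 degenerate case the commit mentions.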
Adds the IPC plumbing that lets a teacher policy worker hand its logits
to the student worker without going through Ray's serialization path —
required for cross-tokenizer distillation where teacher full-vocab logits
are too big to pickle per step.

- nemo_rl/distributed/ipc_utils.py: get_handle_from_tensor and
  rebuild_cuda_tensor_from_ipc helpers wrapping CUDA IPC handles.
- nemo_rl/models/automodel/train.py: two new post-processors —
  XTokenTeacherIPCExportPostProcessor (teacher side, allocates a
  pre-sized CUDA buffer and exports the IPC handle per microbatch) and
  XTokenTeacherIPCLossPostProcessor (student side, rebuilds the tensor
  from the handle and feeds it to the loss fn). Existing post-processors
  are untouched.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Adds the full cross-tokenizer off-policy distillation training loop with
single- and multi-teacher support, plus the worker-level plumbing.
Builds on TokenAligner (PR 1), data layer (PR 2), loss (PR 3), and IPC
(PR 4).

Algorithm and entrypoint mirror the on-policy distillation patterns in
nemo_rl/algorithms/distillation.py and examples/run_distillation.py:

- nemo_rl/algorithms/off_policy_distillation.py: setup, validate,
  off_policy_distillation_train. No env/rollout machinery; teacher
  logits arrive via CUDA IPC. Supports single-teacher (master_config
  ["teacher"]) and multi-teacher (master_config["teachers"] list with
  TeacherSpec entries and per-teacher weights). Multi-teacher requires
  use_ipc=true; same-tokenizer + cross-tokenizer teachers can be mixed
  in one run.
- examples/run_off_policy_distillation.py: thin entrypoint.
- examples/configs/off_policy_distillation.yaml: production
  multi-teacher config (Phi-4-mini + Llama-3.2-3B teachers, Llama-3.2-1B
  student). data.train.arrow_files left null; override via Hydra CLI.

Worker integration is purely additive — existing train() / get_topk_logits
/ get_logprobs paths are not modified, so GRPO / SFT / on-policy
distillation are unaffected:

- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py: new
  init_cross_tokenizer_loss_fn, update_cross_tokenizer_data,
  train_off_policy_distillation, and compute_teacher_logits_ipc methods
  on DTensorPolicyWorkerV2Impl.
- nemo_rl/models/policy/workers/megatron_policy_worker.py: same four
  method names as NotImplementedError stubs that point users to
  DTensorPolicyWorkerV2 for off-policy distillation.
- nemo_rl/models/policy/lm_policy.py: matching dispatcher methods on
  Policy that fan out to the worker group.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/05-off-policy-distillation branch 2 times, most recently from caab0e3 to f9d931b Compare April 28, 2026 02:22
avenkateshha and others added 6 commits April 27, 2026 19:22
Off-policy distillation needs two driver-side methods that fan out
to all policy workers:

* ``init_cross_tokenizer_loss_fn(loss_config, token_aligner_config)``
  — each worker materializes its own ``MultiTeacherLossAggregator``
  (with N=1 for single-teacher) from the shared filesystem and
  caches it as ``self._cached_loss_fn``.
* ``update_cross_tokenizer_data(teacher_input_ids, aligned_pairs,
  teacher_idx, chunk_indices)`` — push per-step CT data to the
  cached loss fn on every worker, sharded by DP axis so each
  worker only receives its slice. Uses
  ``run_all_workers_multiple_data`` for the per-rank sharded
  dispatch.

The corresponding worker-side methods already exist in
``DTensorPolicyWorkerV2``; what was missing was the driver-side
wrapper. Without these, ``off_policy_distillation_train`` calling
``student_policy.init_cross_tokenizer_loss_fn(...)`` /
``student_policy.update_cross_tokenizer_data(...)`` would
AttributeError. Pattern matches existing wrappers like
``init_collective`` and ``prepare_for_training`` in the same file.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
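
The per-rank sharded dispatch can be pictured with the sketch below. It only shows the slicing idea; it does not reproduce `run_all_workers_multiple_data` or its signature. The helper name `shard_by_dp_rank` and the choice of contiguous, equally sized slices are assumptions.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_by_dp_rank(items: Sequence[T], dp_world_size: int) -> List[List[T]]:
    """Split a batch into one contiguous slice per data-parallel rank."""
    if dp_world_size <= 0:
        raise ValueError("dp_world_size must be positive")
    if len(items) % dp_world_size != 0:
        raise ValueError("batch size must be divisible by the DP world size")
    per_rank = len(items) // dp_world_size
    return [list(items[r * per_rank:(r + 1) * per_rank]) for r in range(dp_world_size)]


# Eight samples of per-step data spread over four ranks.
teacher_input_ids = [[i, i + 1] for i in range(8)]
shards = shard_by_dp_rank(teacher_input_ids, dp_world_size=4)
# Each worker would receive only shards[rank], not the full batch.
rank_2_payload = shards[2]
```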
…spatch

Two refactor regressions in the cross-tokenizer off-policy worker
surfaced once the smoke run cleared earlier setup-side bugs:

1. Phi-4 / Phi-3 ``trust_remote_code`` revisions register
   ``inv_freq`` and ``original_inv_freq`` as buffers that get left
   on the ``meta`` device after FSDP/DTensor materialization. Their
   forward later calls ``.to(device)`` on those buffers and crashes
   with ``NotImplementedError: Cannot copy out of meta tensor``. Add
   ``_fix_phi_rope_meta_buffers()`` (called once after model setup
   in ``__init__``) which re-runs each module's ``rope_init_fn``
   against CUDA and replaces the meta buffer in place.

   Scoped to Phi-style models via the model_name / architecture
   string so other model families short-circuit immediately. The
   except is narrowed to ``(TypeError, RuntimeError, AttributeError)``
   per the error-handling skill (no bare ``except: pass``); a skipped
   module logs at rank 0.

2. The ``teacher_worker_result`` dispatch had grown a multi-arm
   condition that misclassified the single-teacher payload shape
   ``list[rank]`` of dicts as a multi-teacher list and passed all
   16 entries to the loss postprocessor, which then KeyError'd on
   ``handle[rank]``. Off-policy distillation always pre-shards by
   rank before the worker sees the payload (single-teacher:
   ``list[rank]``; multi-teacher: ``dict[rank, list[T]]``), so plain
   ``teacher_logits[rank]`` gives the right per-rank entry in both
   cases. Revert to the simpler form.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
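
The point of item 2 is that one indexing expression serves both payload shapes. A small sketch, with placeholder strings standing in for real IPC handles and a hypothetical helper name `select_for_rank`:

```python
def select_for_rank(teacher_logits, rank: int):
    """Return this rank's entry; works for list[rank] and dict[rank, list[T]] alike."""
    return teacher_logits[rank]


world_size = 4

# Single-teacher shape: list[rank] of dicts.
single_teacher = [{"handle": f"h{r}"} for r in range(world_size)]

# Multi-teacher shape: dict[rank, list[T]] with one entry per teacher.
multi_teacher = {
    r: [{"handle": f"t0-h{r}"}, {"handle": f"t1-h{r}"}] for r in range(world_size)
}

single_entry = select_for_rank(single_teacher, 2)
multi_entry = select_for_rank(multi_teacher, 2)
```

Since the payload is already keyed by rank in both cases, no shape detection is needed before indexing.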
…hed loss fn fallback

Three off-policy distillation regressions surfaced by the 100-step
single-teacher Phi-4 cross-tokenizer smoke run:

1. ``setup()`` builds ``token_aligners`` and ``teacher_tokenizers``
   per teacher but never returned them. The runner unpacked the
   9-tuple, never received the aligners, and ``cross_tokenizer_enabled``
   in ``off_policy_distillation_train`` always evaluated to False —
   skipping ``init_cross_tokenizer_loss_fn`` /
   ``update_cross_tokenizer_data``. Extend the return tuple by two
   entries and have the runner pass them through as keyword args.

2. The runner passed ``config.get("env", {})`` to
   ``setup_response_data``. ``setup_response_data`` treats an empty
   dict as "envs are present" (only ``None`` skips them) and then
   crashes on ``cfg["env_name"]`` for arrow_text data. Pass
   ``None`` when the config has no ``env`` key and unpack the
   2-tuple return path; mirrors the SFT runner.

3. During validation, the call passed the driver-side ``loss_fn``
   instance directly. For cross-tokenizer runs that's a
   ``MultiTeacherLossAggregator`` whose ``set_cross_tokenizer_data``
   was never invoked — the per-step CT data was pushed to each
   worker's ``_cached_loss_fn`` via ``update_cross_tokenizer_data``,
   not the driver-side fn. Validation crashed with
   ``Cross-tokenizer data not set``. Pass ``loss_fn=None`` for
   ``MultiTeacherLossAggregator`` so workers fall back to
   ``_cached_loss_fn`` (which has the alignment data set from the
   preceding training step). Plain (non-CT) runs unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
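
Item 2 comes down to the difference between `dict.get(key, {})` and `dict.get(key)`. The sketch below reproduces the pitfall with a toy stand-in; `setup_response_data_sketch`, its return values, and the sample config are hypothetical and only mimic the behavior the commit describes.

```python
def setup_response_data_sketch(env_configs):
    """Toy stand-in: only None skips environment setup."""
    datasets = ("train_dataset", "val_dataset")
    if env_configs is None:
        return datasets  # 2-tuple return path
    # Any non-None value, including {}, is treated as "envs are present".
    env_name = env_configs["env_name"]  # raises KeyError for {}
    return datasets + (env_name,)


def try_setup(env_configs):
    """Run the stand-in and report a KeyError as text instead of raising."""
    try:
        return setup_response_data_sketch(env_configs)
    except KeyError as err:
        return f"KeyError: {err}"


config = {"data": {}}  # no "env" key, as in an arrow_text run

buggy = try_setup(config.get("env", {}))  # passes {} and hits the env branch
fixed = try_setup(config.get("env"))      # passes None and skips env setup
```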
In off-policy distillation with keep_models_resident=False, the
algorithm calls student_policy.offload_after_refit() to free GPU
memory before each step's teacher inference. That moves the optimizer
state to CPU. The next student training step calls
student_policy.prepare_for_training(), but its existing onload gate
only triggers for the logprob/colocated-generation offload paths
(self.offload_optimizer_for_logprob or self.is_generation_colocated).
Off-policy distillation is neither, so the optimizer state stays on
CPU and the next optimizer.step() crashes with:

  RuntimeError: Expected all tensors to be on the same device,
                but found at least two devices, cuda:0 and cpu!

Step 1 happens to work because Adam's exp_avg/exp_avg_sq are not yet
materialized — move_optimizer_to_device is a no-op for fresh state.
Step 2 always crashes once those buffers exist.

Add a new worker method DTensorPolicyWorkerV2Impl.move_optimizer_to_cuda
that wraps the existing move_optimizer_to_device("cuda"), with a
matching NotImplementedError stub on MegatronPolicyWorkerImpl and a
dispatcher on Policy. Off-policy distillation calls it after
prepare_for_training only when keep_models_resident=False; it is a
no-op if the optimizer is already on CUDA. No change to the shared
prepare_for_training contract used by GRPO / SFT / on-policy
distillation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
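
Why step 1 survives and step 2 crashes can be seen in a toy model of optimizer state. The sketch below uses a `FakeTensor` stand-in instead of real torch tensors so it runs anywhere; `FakeTensor` and `move_optimizer_state` are hypothetical names, and the real `move_optimizer_to_device` is not shown in this PR description.

```python
class FakeTensor:
    """Stand-in for a tensor that only tracks which device it lives on."""

    def __init__(self, device: str):
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def move_optimizer_state(state: dict, device: str) -> int:
    """Move every tensor-like entry to `device`; return how many were moved."""
    moved = 0
    for param_state in state.values():
        for key, value in param_state.items():
            if isinstance(value, FakeTensor) and value.device != device:
                param_state[key] = value.to(device)
                moved += 1
    return moved


# Before the first optimizer step, Adam-style state is empty: nothing to move.
fresh_state = {}
moved_before_step_1 = move_optimizer_state(fresh_state, "cuda")

# After step 1 and an offload, the moment buffers exist and sit on CPU.
offloaded_state = {
    "weight": {"step": 1, "exp_avg": FakeTensor("cpu"), "exp_avg_sq": FakeTensor("cpu")}
}
moved_before_step_2 = move_optimizer_state(offloaded_state, "cuda")

# Calling again is a no-op once everything is already on the target device.
moved_again = move_optimizer_state(offloaded_state, "cuda")
```

If the second call site is skipped, the moment buffers stay on CPU while the parameters are on the GPU, which matches the device-mismatch error quoted above.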
Automodel's load-before-shard path materializes the full unsharded
model on every rank before FSDP sharding (its
`_should_load_before_shard` returns True for tp=ep=1, no PP/PEFT,
regardless of FSDP DP world size). For Phi-4 14B that's ~56 GB of FP32
master weights per GPU, plus a ~28 GB BF16 state-dict stage in
`_distribute_state_dict`, which OOMs an 80 GB H100 with `78.35 GB
allocated by PyTorch` before the load completes.

`validate_and_prepare_config` always pins `model_config.torch_dtype =
float32` because FSDP2 mixed-precision needs FP32 master weights for
AdamW state and grad-norm reductions on trainable policies. Inference-
only paths (`init_optimizer=False`, e.g. teacher policies in off-policy
distillation) never touch the optimizer code paths, so the FP32 master
weights are pure load-time waste.

Add `_apply_phi_inference_dtype_override`, a pre-load hook that runs
between `validate_and_prepare_config` and `setup_distributed`. When
both gates hold (model is Phi-style AND `init_optimizer=False`) it
rewrites `runtime_config.model_config.torch_dtype` from FP32 to
`runtime_config.dtype` (BF16 for the Phi-4 teacher YAML). Halves the
load-time per-rank footprint to ~28 GB and clears the OOM. HF's
`rope_init_fn` produces `inv_freq` in FP32 regardless of model dtype,
so the existing post-load `_fix_phi_rope_meta_buffers` is unaffected.

Refactor `_fix_phi_rope_meta_buffers` to share a `_is_phi_style_model`
gate with the new hook. The shared gate accepts an optional
`architectures` list so callers running before `self.model_config` is
populated (the new pre-load hook) can pass it from
`runtime_config.model_config`.

Scoped to Phi-style models so SFT / GRPO / on-policy distillation
paths and other model families keep the existing FP32-master-weights
behavior; the override is also a no-op when an optimizer is requested
(student policies stay on FP32).

Verified on the Phi-4 (teacher, 14B) → Llama-3.1-8B (student) smoke
run that previously OOM'd at "GPU 0 has 79.11 GB total ... 350 MiB to
allocate". Same 2-node config now completes 100 training steps with
the breadcrumb "[Phi inference dtype] loading microsoft/Phi-4 in
torch.bfloat16 (no optimizer requested)" firing across all 16 teacher
workers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
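
The memory figures and the two-gate rule can be checked with a few lines of arithmetic. This is a sketch: the helper names `load_footprint_gb` and `pick_load_dtype` are hypothetical, dtypes are plain strings rather than torch dtypes, and 14e9 parameters with 1 GB = 1e9 bytes are the rounding assumptions behind the ~56 GB and ~28 GB numbers.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def load_footprint_gb(num_params: int, dtype: str) -> float:
    """Per-rank size of a fully materialized, unsharded model in GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def pick_load_dtype(is_phi_style: bool, init_optimizer: bool, runtime_dtype: str) -> str:
    """Use the runtime dtype only for Phi-style, inference-only policies."""
    if is_phi_style and not init_optimizer:
        return runtime_dtype
    # Everything else keeps FP32 master weights.
    return "float32"


phi4_params = 14_000_000_000

teacher_dtype = pick_load_dtype(is_phi_style=True, init_optimizer=False, runtime_dtype="bfloat16")
student_dtype = pick_load_dtype(is_phi_style=True, init_optimizer=True, runtime_dtype="bfloat16")
other_dtype = pick_load_dtype(is_phi_style=False, init_optimizer=False, runtime_dtype="bfloat16")

fp32_gb = load_footprint_gb(phi4_params, "float32")
bf16_gb = load_footprint_gb(phi4_params, teacher_dtype)
```

An FP32 load plus a BF16 state-dict stage adds up to roughly 84 GB, which exceeds an 80 GB device; loading directly in BF16 avoids the FP32 copy.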
Follow-up to the TokenAligner refactor on 01-tokenaligner that dropped
align_fast() / precompute_canonical_maps(). This commit picks up the
remaining call sites in the off-policy distillation orchestration.

- TokenAlignerConfig: dropped use_align_fast typed-dict field.
- setup(): removed the use_align_fast= pass-through into TeacherCTSpec
  and the token_aligner.precompute_canonical_maps() call (the helper no
  longer exists on TokenAligner; align() does the full sequence-level
  canonicalization on each call).
- examples/configs/off_policy_distillation.yaml: removed the
  `use_align_fast: true` knob from each teacher's token_aligner block.

Behavior: see the TokenAligner refactor commit for the alignment-result
note on encoding-artifact / byte tokens. Same trade-off applies here.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha deleted the avenkateshha/xtoken-off-policy-distillation/05-off-policy-distillation branch May 16, 2026 01:47