Skip to content

[3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator#2349

Closed
avenkateshha wants to merge 4 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation/03-loss
Closed

[3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator#2349
avenkateshha wants to merge 4 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation/03-loss

Conversation

@avenkateshha
Copy link
Copy Markdown

@avenkateshha avenkateshha commented Apr 27, 2026

Adds the loss-fn layer for cross-tokenizer distillation. Builds on the TokenAligner package (PR 1).

  • CrossTokenizerDistillationLossFn: per-token KL/CE loss over 1:1 aligned positions, with optional gold-loss path. Holds a reference to a TokenAligner; teacher data (input ids, aligned pairs, optional chunked COO masks) is set per-step via set_cross_tokenizer_data.
  • CrossTokenizerDistillationLossConfig and CrossTokenizerDistillationLossDataDict TypedDicts.
  • MultiTeacherLossAggregator: wraps a list of optional CrossTokenizerDistillationLossFn instances with per-teacher weights. N=1 is a degenerate case used by the unified single-/multi-teacher worker path; the algorithm-layer multi-teacher orchestration comes in PR 5.
  • _scatter_chunk_mask_from_coo helper for the chunked-CE path.

What does this PR do?

Adds CrossTokenizerDistillationLossFn and MultiTeacherLossAggregator — the per-token KL loss and weighted aggregator used by cross-tokenizer off-policy distillation.

Issues

None linked yet.

Usage

from nemo_rl.algorithms.x_token.tokenalign import TokenAligner
from nemo_rl.algorithms.loss.loss_functions import (
    CrossTokenizerDistillationLossFn,
    MultiTeacherLossAggregator,
)

aligner = TokenAligner(...)  # see PR 1
loss_cfg = {"loss_type": "KL", "temperature": 1.0, "vocab_topk": 8192,
            "exact_token_match_only": False, "reverse_kl": False}
loss_fn = CrossTokenizerDistillationLossFn(loss_cfg, aligner)

# Per training step (after collator produces per-teacher CT data):
loss_fn.set_cross_tokenizer_data(teacher_input_ids, aligned_pairs)

# Or for multi-teacher:
agg = MultiTeacherLossAggregator(
    [loss_fn_t0, loss_fn_t1], weights=[0.5, 0.5], cfg=loss_cfg,
)
agg.set_cross_tokenizer_data(teacher_input_ids_t0, aligned_pairs_t0, teacher_idx=0)
agg.set_cross_tokenizer_data(teacher_input_ids_t1, aligned_pairs_t1, teacher_idx=1)

Before your PR is "Ready for review"

  • Read Contributor guidelines
  • No new unit tests in this PR (next-step follow-up). Existing tests/unit/algorithms/test_loss_functions.py continues to pass for stock loss fns.
  • Static py_compile confirmed clean. Functional smoke covered by PR 5.
  • No docs entry — added alongside PR 5.

Additional Information

Draft. Stacked on PR 2 (collator) — #2348. The aggregator class is included here as a building block; no algorithm-level multi-teacher orchestration ships in this PR.

Full chain:

  1. TokenAligner + projection utilities — [1/5] feat: add TokenAligner and cross-tokenizer projection utilities #2347
  2. Collator + Arrow dataset + eval datasets — [2/5] feat: cross-tokenizer collator, Arrow dataset, and eval datasets #2348
  3. (this PR) CT distillation loss + multi-teacher aggregator
  4. CUDA IPC for teacher logits — [4/5] feat: CUDA IPC for teacher logits transfer #2350
  5. Algorithm + worker integration — [5/5] feat: off-policy distillation algorithm and worker integration #2351

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@avenkateshha avenkateshha changed the title feat: cross-tokenizer distillation loss and multi-teacher aggregator [3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator Apr 27, 2026
Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.

- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
  Numba-accelerated DP alignment, projection-matrix loading
  (dense and sparse COO), and the project_token_likelihoods_instance
  forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
  minimal_projection_via_multitoken,reapply_exact_map,
  sort_and_cut_projection_matrix}.py: standalone CLI scripts
  (argparse-driven, __main__ entrypoints) for one-time projection-matrix
  preparation. Not on the training import path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/03-loss branch from 4d61ed1 to 775898c Compare April 27, 2026 10:21
avenkateshha and others added 2 commits April 27, 2026 03:25
Data-layer plumbing for cross-tokenizer off-policy distillation, plus
in-training eval datasets. Builds on the TokenAligner package from the
prior PR.

- nemo_rl/data/cross_tokenizer_collate.py: CrossTokenizerCollator and
  TeacherCTSpec. Runs in StatefulDataLoader worker processes — does
  per-teacher tokenize + DP alignment up front so the train loop only
  consumes pre-built per_teacher_ct_data. Lazy-imports TokenAligner so
  workers that don't need cross-tokenizer never touch x_token.
- nemo_rl/data/__init__.py: add NotRequired prefetch_factor to DataConfig.
- nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py:
  ArrowTextDataset with lazy packing, registered as "arrow_text" in
  DATASET_REGISTRY.
- nemo_rl/data/datasets/eval_datasets/{humaneval_plus,mbpp_plus,mmlu}.py
  and registry entries: in-training eval datasets. mmlu.py adds an
  optional num_few_shot argument with a static _build_few_shot_prefixes
  helper; default of 0 preserves existing behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
Adds the loss-fn layer for cross-tokenizer distillation. Builds on the
TokenAligner package (PR 1).

- CrossTokenizerDistillationLossFn: per-token KL/CE loss over 1:1 aligned
  positions, with optional gold-loss path. Holds a reference to a
  TokenAligner; teacher data (input_ids, aligned_pairs, optional chunked
  COO masks) is set per-step via set_cross_tokenizer_data.
- CrossTokenizerDistillationLossConfig and
  CrossTokenizerDistillationLossDataDict TypedDicts.
- MultiTeacherLossAggregator: wraps a list of optional
  CrossTokenizerDistillationLossFn instances with per-teacher weights.
  N=1 is a degenerate case used by the unified single-/multi-teacher
  worker path; the algorithm-layer multi-teacher orchestration comes in
  a later PR.
- _scatter_chunk_mask_from_coo helper for the chunked-CE path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/03-loss branch from 775898c to 70a14eb Compare April 27, 2026 10:25
The cross-tokenizer loss path uses ``math.log`` (vocab-size
normalization), ``time.perf_counter`` (per-step timing), and
``typing.Union`` (type annotation) but none were imported. Module
load passed because all three are referenced inside function bodies;
the NameError only fires at training time.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha deleted the avenkateshha/xtoken-off-policy-distillation/03-loss branch May 16, 2026 01:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants