218 changes: 218 additions & 0 deletions docs/guides/xtoken-distillation.md
@@ -0,0 +1,218 @@
# Cross-Tokenizer (X-Token) Off-Policy Distillation

NeMo RL supports off-policy distillation between a student and a teacher that
**do not share a tokenizer** — for example, distilling a Qwen3-4B teacher into
a Llama-3.2-1B student. Cross-tokenizer ("x-token") distillation handles the
vocabulary mismatch by routing teacher logits through a precomputed
**projection matrix** that maps each student token to the teacher tokens it
most plausibly corresponds to.

This guide explains how to:

1. Produce the projection matrix from a (student, teacher) tokenizer pair
2. Launch a distillation run that consumes it

## How it works

A full run has two phases. The first four steps are *offline data prep*:
small CLI tools you run once per (student, teacher) pair (Step 3 is optional),
and the result is a single `.pt` file. The fifth step is the actual
distillation training loop.

```
                      ┌──────────────────────────────────────────────┐
                      │    Offline projection-matrix preparation     │
                      │                                              │
(student, teacher)    │  ┌────────────────────────────────────┐      │
tokenizers + base ───▶│  │ 1. minimal_projection_generator.py │      │
embedding model       │  │    - embedding-similarity top-k    │      │
                      │  └─────────────────┬──────────────────┘      │
                      │                    │                         │
                      │                    ▼                         │
                      │  ┌────────────────────────────────────┐      │
                      │  │ 2. minimal_projection_via_         │      │
                      │  │    multitoken.py                   │      │
                      │  │    - add multi-token mappings      │      │
                      │  └─────────────────┬──────────────────┘      │
                      │                    │                         │
                      │  ┌─────────────────▼──────────────────┐      │
                      │  │ 3. (optional) reapply_exact_map.py │      │
                      │  │    - pin exact 1-to-1 matches      │      │
                      │  └─────────────────┬──────────────────┘      │
                      │                    │                         │
                      │  ┌─────────────────▼──────────────────┐      │
                      │  │ 4. sort_and_cut_projection_matrix  │      │
                      │  │    .py - trim to runtime top_k     │      │
                      │  └─────────────────┬──────────────────┘      │
                      └────────────────────│─────────────────────────┘
                                           ▼ projection_matrix.pt
                      ┌──────────────────────────────────────────────┐
                      │ 5. examples/run_xtoken_distillation.py       │
                      │    - student forward + teacher forward       │
                      │      (via CUDA-IPC), x-token KD loss         │
                      └──────────────────────────────────────────────┘
```

The projection matrix is a sparse `[V_student, top_k]` tensor that the
training-time loss multiplies against the student logits to project them into
the teacher's vocab space (or vice versa, depending on the loss mode).
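
As a rough picture of what "projecting into the teacher's vocab space" means, the sketch below applies a sparse `[V_student, top_k]` map to a student probability vector. It is plain Python under stated assumptions: the parallel `indices` / `weights` rows, the function name `project_to_teacher`, and the use of probabilities rather than logits are all illustrative and are not taken from the saved `.pt` layout or the real loss code.

```python
def project_to_teacher(student_probs, indices, weights, teacher_vocab_size):
    """Scatter student probability mass into teacher vocab slots.

    ASSUMPTION: indices[s][j] is the j-th teacher token id mapped to student
    token s, and weights[s][j] is the share of token s's mass sent there.
    """
    teacher_probs = [0.0] * teacher_vocab_size
    for s, p in enumerate(student_probs):
        for t, w in zip(indices[s], weights[s]):
            teacher_probs[t] += p * w
    return teacher_probs


# Toy example: 3 student tokens, 4 teacher tokens, top_k = 2.
indices = [[0, 1], [2, 3], [1, 2]]
weights = [[1.0, 0.0], [0.5, 0.5], [0.25, 0.75]]
student_probs = [0.5, 0.25, 0.25]

projected = project_to_teacher(student_probs, indices, weights, teacher_vocab_size=4)
print(projected)
```

Because each toy row's weights sum to 1, the projected vector is still a probability distribution over the teacher vocab.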

## Backend and scope

- **DTensor V2 only.** Set `policy.dtensor_cfg.enabled=true` and
`policy.dtensor_cfg._v2=true`. The Megatron path is intentionally stubbed
with `NotImplementedError`.
- **Teacher logits travel via CUDA IPC**, so student and teacher policies must
be colocated on the same node. There is no remote-Ray transport for x-token logits.
- **No sequence packing or dynamic batching for the teacher forward** in v0.
- The corpus must be served via the `arrow_text` dataset (no chat template,
loss on every token — see `examples/configs/xtoken_distillation.yaml`).

## Step 1 — Generate the base projection matrix

`minimal_projection_generator.py` walks both vocabularies, embeds every token
with a small embedding LLM (or a sentence-transformers model), and stores the
top-`k` teacher tokens by cosine similarity for each student token.

```bash
uv run python -m nemo_rl.utils.x_token.minimal_projection_generator \
--student-model "meta-llama/Llama-3.2-1B" \
--teacher-model "Qwen/Qwen3-4B" \
--top_k 32 \
--force_recompute \
--data_dir cross_tokenizer_data/
```

Both `--student-model` and `--teacher-model` are required and are **never
swapped**: the projection direction follows the CLI args exactly. Output lands at
`cross_tokenizer_data/temp_projection_map_Llama-3.2_to_Qwen3_top_32.pt`.

If you pick an entry from `EMBEDDING_MODEL_CHOICES` whose
`embedding_model_type` is `"sbert"`, install `sentence-transformers` first;
otherwise the script raises a clear `ImportError`. The default
`embedding_model_index = 3` uses `Qwen/Qwen3-Embedding-4B` and does not need
`sentence-transformers`.
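
The similarity search this step performs can be pictured with a few lines of plain Python. This is only a sketch: the helper names `cosine` and `top_k_teacher_tokens` and the two-dimensional toy embeddings are assumptions for illustration, and the real script works on full vocabularies with model-produced embeddings.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_teacher_tokens(student_vec, teacher_vecs, k):
    """Return the k (teacher_id, similarity) pairs with highest cosine similarity."""
    scored = [(cosine(student_vec, tv), tid) for tid, tv in enumerate(teacher_vecs)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [(tid, score) for score, tid in scored[:k]]


# Toy embeddings: one student token and four teacher tokens.
teacher_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
student_vec = [1.0, 0.0]

top2 = top_k_teacher_tokens(student_vec, teacher_vecs, k=2)
print(top2)  # teacher ids 0 and 2 rank highest for this toy input
```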

## Step 2 — Add multi-token mappings

Many student tokens (e.g., `"12"`) tokenize into multiple teacher tokens
(e.g., `"1"`, `"2"`). `minimal_projection_via_multitoken.py` walks the
student vocab, re-tokenizes each token with the teacher tokenizer, and adds
weighted entries to the projection. With `--enable-reverse-pass` it also
does the symmetric teacher → student walk.

```bash
uv run python -m nemo_rl.utils.x_token.minimal_projection_via_multitoken \
--student-model "meta-llama/Llama-3.2-1B" \
--teacher-model "Qwen/Qwen3-4B" \
--initial-projection-path cross_tokenizer_data/temp_projection_map_Llama-3.2_to_Qwen3_top_32.pt \
--top-k 32 \
--enable-scale-trick \
--enable-reverse-pass \
--enable-special-token-mapping
```

Output: `cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt`.

Pass `--num-examples 50` to print a sample of student→teacher mappings after
the matrix is built — useful for spot-checking that special tokens, numerals,
and punctuation map to sensible teacher tokens.

When `--enable-scale-trick` is set, the script records `enable_scale_trick=True`
in the saved `.pt` so Step 4 can auto-enable `--preserve_last`.
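
The student-to-teacher walk described in this step can be sketched as follows. Everything here is hypothetical: `toy_teacher_tokenize` is a greedy longest-match stand-in for a real tokenizer, and the uniform `1/n` weighting is an assumption, since this guide does not specify how the script weights the pieces or what the scale trick changes.

```python
def toy_teacher_tokenize(text, teacher_vocab):
    """Greedy longest-match tokenizer over a toy vocab (stand-in for a real tokenizer)."""
    pieces, i = [], 0
    while i < len(text):
        for j in range(len(text), i, -1):
            if text[i:j] in teacher_vocab:
                pieces.append(text[i:j])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return pieces


def multitoken_entries(student_token, teacher_vocab):
    """Map one student token string to weighted (teacher_id, weight) entries."""
    pieces = toy_teacher_tokenize(student_token, teacher_vocab)
    if not pieces:
        return []
    weight = 1.0 / len(pieces)  # ASSUMPTION: uniform weights, not the script's scheme
    return [(teacher_vocab[p], weight) for p in pieces]


teacher_vocab = {"1": 0, "2": 1, "a": 2, "ab": 3}

# The student token "12" splits into teacher tokens "1" and "2".
entries = multitoken_entries("12", teacher_vocab)
print(entries)
```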

## Step 3 (optional) — Reapply exact-token map

Some token pairs are *literally identical* (e.g., common punctuation, single
ASCII characters). `reapply_exact_map.py` pins those to 1-to-1 mappings with
weight 1.0, overwriting whatever Steps 1–2 produced for them.

```bash
uv run python -m nemo_rl.utils.x_token.reapply_exact_map \
--student-model "meta-llama/Llama-3.2-1B" \
--teacher-model "Qwen/Qwen3-4B" \
--initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt
```

Output is written next to the input as `<basename>_exact_map_remapped.pt`.
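
Conceptually, the pinning pass looks like the sketch below. The dict-based projection format and the name `pin_exact_matches` are assumptions made for illustration, and exact string equality is used here as the matching rule; the real script may normalize tokens differently.

```python
def pin_exact_matches(student_vocab, teacher_vocab, projection):
    """Overwrite rows whose token string exists verbatim in both vocabularies.

    student_vocab / teacher_vocab: dict mapping token string -> token id.
    projection: dict mapping student id -> list of (teacher_id, weight).
    """
    pinned = dict(projection)
    for token, sid in student_vocab.items():
        tid = teacher_vocab.get(token)
        if tid is not None:
            pinned[sid] = [(tid, 1.0)]  # 1-to-1 mapping with weight 1.0
    return pinned


student_vocab = {",": 0, "12": 1}
teacher_vocab = {",": 5, "1": 6, "2": 7}
projection = {0: [(6, 0.7), (5, 0.3)], 1: [(6, 0.5), (7, 0.5)]}

pinned = pin_exact_matches(student_vocab, teacher_vocab, projection)
print(pinned)
```

Only the comma row is overwritten in this toy case; `"12"` has no identical teacher token, so its multi-token entries survive.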

## Step 4 — Sort and trim to runtime `top_k`

The training loss only needs a small `top_k` per row (typical: 4–8). This
step sorts each row by weight and trims to the chosen runtime cap.

```bash
uv run python -m nemo_rl.utils.x_token.sort_and_cut_projection_matrix \
--initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special_exact_map_remapped.pt \
--top_k 4 \
--output_path cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt
```

`--preserve_last` is an `argparse.BooleanOptionalAction` flag with default
`None`. When it is unspecified, the script reads `enable_scale_trick` from the
input matrix's metadata (set in Step 2) and auto-enables preservation of the
last column slot. Pass `--preserve_last` or `--no-preserve_last` to override.
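
The tri-state flag and the trim itself can be sketched together. The `argparse` usage is standard library behavior (Python 3.9+); the rest is assumption: `resolve_preserve_last`, `sort_and_cut`, the dict-shaped `metadata`, and in particular the reading of "preserve the last column slot" as "keep the row's final entry and fill the other `top_k - 1` slots by weight" are illustrative guesses, not the script's confirmed behavior.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # Tri-state: True (--preserve_last), False (--no-preserve_last), None (unset).
    parser.add_argument(
        "--preserve_last", action=argparse.BooleanOptionalAction, default=None
    )
    return parser


def resolve_preserve_last(argv, metadata):
    """An explicit CLI value wins; otherwise fall back to the saved metadata."""
    args = build_parser().parse_args(argv)
    if args.preserve_last is None:
        return bool(metadata.get("enable_scale_trick", False))
    return args.preserve_last


def sort_and_cut(row, top_k, preserve_last=False):
    """Sort (teacher_id, weight) entries by weight and keep top_k of them."""
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    if preserve_last and row:
        # ASSUMPTION: the final entry is reserved and always kept in the last slot.
        head, last = row[:-1], row[-1]
        kept = sorted(head, key=lambda e: e[1], reverse=True)[: top_k - 1]
        return kept + [last]
    return sorted(row, key=lambda e: e[1], reverse=True)[:top_k]


row = [(7, 0.1), (3, 0.6), (9, 0.25), (4, 0.05)]
trimmed = sort_and_cut(row, top_k=2)
trimmed_keep_last = sort_and_cut(row, top_k=2, preserve_last=True)
print(trimmed, trimmed_keep_last)
```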

## Step 5 — Launch x-token distillation

The training entrypoint is `examples/run_xtoken_distillation.py` with the
exemplar config at `examples/configs/xtoken_distillation.yaml`. The exemplar
defaults to Llama-3.2-1B (student) ← Qwen3-4B (teacher), an arrow-text
corpus, and the P-KL loss mode. Override paths via Hydra CLI:

```bash
uv run python examples/run_xtoken_distillation.py \
--config examples/configs/xtoken_distillation.yaml \
loss_fn.projection_matrix_path=cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt \
data.train.arrow_files=/path/to/corpus/*.arrow \
cluster.gpus_per_node=8 \
cluster.num_nodes=1
```

The exemplar config keeps `loss_fn.projection_matrix_path` and
`data.train.arrow_files` as `null` so they must be supplied at the CLI — this
makes the config reusable across (student, teacher) pairs.

### Loss-mode knobs

`loss_fn` has two flags that pick between three behaviors:

| `gold_loss` | `xtoken_loss` | Behavior |
|---|---|---|
| `false` | (inert) | **P-KL** — full-vocab teacher logits arrive via CUDA IPC; the loss derives a microbatch-global top-k internally, projects the student into teacher vocab via the projection matrix, and chunk-averages KL on the top-k subset. A CE term is added. |
| `true` | `false` | **Gold loss** (PT-faithful) — split the vocab into an *exact-token-mapped* common set (KL) and an *uncommon* tail (sorted L1). |
| `true` | `true` | **Gold + x-token loss** — same as gold, but relax the exact-map threshold to `>= 0.6` and allow multi-token projections to count as exact maps via a collision-replacement rule. |
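
Read as code, the table reduces to a two-flag dispatch. The function name `loss_mode` and the returned labels are purely illustrative; how the real loss class branches internally is not shown in this guide.

```python
def loss_mode(gold_loss: bool, xtoken_loss: bool) -> str:
    """Name the behavior selected by the two loss_fn flags in the table above."""
    if not gold_loss:
        # xtoken_loss is inert on this path.
        return "P-KL"
    return "gold + x-token" if xtoken_loss else "gold"


for gold in (False, True):
    for xtoken in (False, True):
        print(gold, xtoken, loss_mode(gold, xtoken))
```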

Other relevant fields:

- `loss_fn.temperature` — softmax temperature applied symmetrically to student and teacher logits before KL.
- `loss_fn.vocab_topk` — microbatch-global top-k size for the P-KL path (inert when `gold_loss=true`).
- `loss_fn.uncommon_topk` — cap on the L1 uncommon-tail sort in the gold path (defaults to PT's hardcoded 8192).
- `loss_fn.reverse_kl` — compute `KL(student || teacher)` instead of `KL(teacher || student)`.
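
To make the `temperature` and `reverse_kl` knobs concrete, here is a minimal sketch in plain Python. It assumes the student and teacher logits already live in the same vocab space (that is, after projection), and the names `softmax`, `kl`, and `distill_kl` are illustrative rather than taken from the loss implementation.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax with max-subtraction for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distill_kl(student_logits, teacher_logits, temperature=1.0, reverse_kl=False):
    # The same temperature is applied to both sides before the divergence.
    s = softmax(student_logits, temperature)
    t = softmax(teacher_logits, temperature)
    # Forward: KL(teacher || student); reverse: KL(student || teacher).
    return kl(s, t) if reverse_kl else kl(t, s)


student_logits = [2.0, 0.0, -1.0]
teacher_logits = [0.0, 1.0, 0.0]

forward = distill_kl(student_logits, teacher_logits)
reverse = distill_kl(student_logits, teacher_logits, reverse_kl=True)
print(forward, reverse)
```

The two directions generally give different values for the same pair of distributions, which is why the flag matters.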

## Other (student, teacher) pairs

The same pipeline works for any HuggingFace tokenizer pair. Two worked
examples, Llama → Gemma and Llama → Phi, differ only in the
`--student-model` / `--teacher-model` arguments passed to Steps 1–3 and in
the output filenames derived from them.

For Phi-3 / Phi-4 family teachers, also export
`NRL_TRUST_REMOTE_CODE=false` and `NRL_SKIP_PHI_ROPE_FIX=1` in the
training environment so the in-tree HuggingFace implementation is used.

## Where files live

| Stage | Tool | Default output |
|---|---|---|
| Generate base | `nemo_rl/utils/x_token/minimal_projection_generator.py` | `<data_dir>/temp_projection_map_<student>_to_<teacher>_top_<N>.pt` |
| Add multi-token | `nemo_rl/utils/x_token/minimal_projection_via_multitoken.py` | `<output_dir>/projection_map_<student>_to_<teacher>_multitoken_top_<N>_double[_special].pt` |
| Reapply exact map | `nemo_rl/utils/x_token/reapply_exact_map.py` | `<input>_exact_map_remapped.pt` |
| Sort and trim | `nemo_rl/utils/x_token/sort_and_cut_projection_matrix.py` | `<input_dir>/<basename>_top_<N>_sorted[_preservelast].pt` (or `--output_path`) |
| Train | `examples/run_xtoken_distillation.py` | per the run's `logger.log_dir` and `checkpointing.checkpoint_dir` |

## Related

- Config exemplar: [`examples/configs/xtoken_distillation.yaml`](../../examples/configs/xtoken_distillation.yaml)
- Loss implementation: `nemo_rl/algorithms/loss/loss_functions.py::CrossTokenizerDistillationLossFn`
- Token alignment: `nemo_rl/algorithms/x_token/tokenalign.py::TokenAligner`
- Same-tokenizer distillation: [Quantization-Aware RL](quantization-aware-rl.md) (the QA-Distillation workflow uses the same training entrypoint with a same-tokenizer teacher).
8 changes: 8 additions & 0 deletions docs/index.md
@@ -135,6 +135,13 @@ Learn how to add support for new model architectures in NeMo RL.
Extend a model's context window with YaRN RoPE scaling on the Megatron backend for SFT, GRPO, and other workflows.
:::

:::{grid-item-card} {octicon}`git-compare` Cross-Tokenizer Distillation
:link: guides/xtoken-distillation
:link-type: doc

Off-policy distillation across mismatched tokenizers — build a (student, teacher) projection matrix and run x-token KD via CUDA-IPC teacher logits.
:::

::::

## Advanced Topics
@@ -251,6 +258,7 @@ guides/async-grpo.md
guides/quantization-aware-rl.md
guides/eagle3-speculative-decoding.md
guides/yarn-long-context.md
guides/xtoken-distillation.md
guides/muon-optimizer.md
guides/dtensor-tp-accuracy.md
guides/ft-launcher-guide.md