[1/5] feat: add TokenAligner and cross-tokenizer projection utilities #2347

Closed
avenkateshha wants to merge 2 commits into NVIDIA-NeMo:main from avenkateshha:avenkateshha/xtoken-off-policy-distillation/01-tokenaligner

@avenkateshha avenkateshha commented Apr 27, 2026

Foundational library code for cross-tokenizer distillation. No algorithm or training-loop integration yet — those follow in subsequent PRs.

  • nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with Numba-accelerated DP alignment, projection-matrix loading (dense and sparse COO), and the project_token_likelihoods_instance forward path used by the cross-tokenizer loss.
  • nemo_rl/algorithms/x_token/__init__.py: package init.
  • nemo_rl/utils/x_token/{minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix}.py: standalone CLI scripts (argparse-driven, __main__ entrypoints) for one-time projection-matrix preparation. Not on the training import path.

What does this PR do?

Introduces the TokenAligner class and projection-matrix preparation tools that subsequent PRs build on for cross-tokenizer off-policy distillation.

Issues

None linked yet.

Usage

from nemo_rl.algorithms.x_token.tokenalign import TokenAligner

aligner = TokenAligner(
    teacher_tokenizer_name="microsoft/Phi-4-mini-instruct",
    student_tokenizer_name="meta-llama/Llama-3.2-1B",
    max_comb_len=4,
)
aligner._load_logits_projection_map(
    file_path="cross_tokenizer_data/llama_phi-mini_proj.pt",
    use_sparse_format=True,
    learnable=False,
    device="cpu",
)
aligner.precompute_canonical_maps()
# `aligner` now exposes `align_fast`, `align`, and
# `project_token_likelihoods_instance` for use by the loss fn (PR 3).

Before your PR is "Ready for review"

  • Read Contributor guidelines
  • No new tests in this PR. TokenAligner is exercised by the loss tests landing with PR 3 and by the off-policy distillation recipe in PR 5.
  • Functional/unit suite not run end-to-end against this PR in isolation; static py_compile confirmed clean. CI on this PR will be the first full run.
  • No docs entry yet — docs/index.md will be updated alongside the algorithm PR (PR 5).

Additional Information

First in a 5-PR stack splitting the cross-tokenizer off-policy distillation feature for review:

  1. (this PR) TokenAligner + projection utilities
  2. Collator + Arrow dataset + eval datasets — [2/5] feat: cross-tokenizer collator, Arrow dataset, and eval datasets #2348
  3. Cross-tokenizer distillation loss + multi-teacher aggregator — [3/5] feat: cross-tokenizer distillation loss and multi-teacher aggregator #2349
  4. CUDA IPC for teacher logits transfer — [4/5] feat: CUDA IPC for teacher logits transfer #2350
  5. Algorithm + worker integration (single + multi-teacher) — [5/5] feat: off-policy distillation algorithm and worker integration #2351

@avenkateshha avenkateshha requested review from a team as code owners April 27, 2026 03:56

copy-pr-bot Bot commented Apr 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.

- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
  Numba-accelerated DP alignment, projection-matrix loading
  (dense and sparse COO), and the project_token_likelihoods_instance
  forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
  minimal_projection_via_multitoken,reapply_exact_map,
  sort_and_cut_projection_matrix}.py: standalone CLI scripts
  (argparse-driven, __main__ entrypoints) for one-time projection-matrix
  preparation. Not on the training import path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha force-pushed the avenkateshha/xtoken-off-policy-distillation/01-tokenaligner branch from a556396 to c492ac1 Compare April 27, 2026 10:20
…gner

tokenalign.py: 2194 → 1877 lines (-317).

align_fast removal
- Committing to use_align_fast=false everywhere; the fast-lookup path is no
  longer needed. Callers (CrossTokenizerCollator, off_policy_distillation
  setup) will be updated on the dependent branches in this stack to take
  the unconditional aligner.align(...) path.
- TokenAligner: dropped align_fast(), precompute_canonical_maps(), and the
  _student_canon_map / _teacher_canon_map / _canon_id_to_str attrs that
  backed them.

Behavior note: align_fast skipped _canonicalize_sequence's
_merge_encoding_artifacts / _merge_consecutive_bytes by reusing per-token
canonical maps. align() runs the full sequence-level canonicalization,
so alignment may differ slightly for sequences that contain
encoding-artifact or byte tokens. For pairs without those, the result is
identical. The cost is a small per-batch CPU bump in the collator
workers since we no longer cache canonical strings per id.

numba removal
- The training container ships without numba, so at runtime
  _NUMBA_AVAILABLE was always False, _dp_core_numba was never defined,
  and align_tokens_with_combinations_numpy_jit was a one-line forwarder
  to align_tokens_with_combinations_numpy on its first if-not-numba
  branch.
- Dropped the top-level numba try/except + _NUMBA_AVAILABLE flag and the
  @njit-decorated _dp_core_numba kernel.
- Deleted align_tokens_with_combinations_numpy_jit and retargeted callers
  at align_tokens_with_combinations_numpy directly:
    - _perform_dp_alignment (chunk_size==0 branch)
    - align_tokens_combinations_chunked (small-sequence base case +
      divide-and-conquer fallback)
- Behavior: zero runtime change for this container; the numpy DP kernel
  is what was always executing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label Apr 29, 2026
tokenizer_teacher_total_vocab_size = len(tokenizer_teacher)
model_A_config = AutoConfig.from_pretrained(student_model_name)
model_B_config = AutoConfig.from_pretrained(teacher_model_name)
# pdb.set_trace()

Could we clean up some of the debug/development scaffolding before merging? I noticed import pdb, commented-out pdb.set_trace() / exit() calls, and constant conditionals like if 1: / if 0:. The script may be working as intended, but removing these would make the offline tool easier to maintain and more consistent with the rest of the repo.

from nemo_rl.algorithms.x_token.tokenalign import TokenAligner
import gc
from collections import defaultdict
from datasets import load_dataset, get_dataset_config_names

Some imports appear to be unused in this file, such as AutoModelForCausalLM, torch.nn as nn, gc, load_dataset, get_dataset_config_names, random, time, numpy as np, and pdb. Could we remove the unused imports?

decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION)
teacher_tokens_decoded[token_id] = decoded
except:
# Skip tokens that can't be processed

For this tokenizer scan, could we avoid the bare except and keep the skip behavior with a small summary instead? For example, tracking a count plus a few examples would make unexpected decode issues visible without spamming output.

decode_failure_count = 0
decode_failure_examples = []
...
except (IndexError, KeyError, ValueError) as exc:
    decode_failure_count += 1
    if len(decode_failure_examples) < 5:
        decode_failure_examples.append((token_id, repr(exc)))

if decode_failure_count:
    print(
        f"Skipped {decode_failure_count} teacher tokens that could not be decoded. "
        f"First examples: {decode_failure_examples}"
    )

# Apply canonicalization if enabled
decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION)
student_tokens_decoded[token_id] = decoded
except:

Same comment here as above: could we replace the bare except with explicit failure tracking?

decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION)
teacher_tokens_decoded[token_id] = decoded
except:
# Skip tokens that can't be processed

Same comment here as above: could we replace the bare except with explicit failure tracking?

teacher_vocab = tokenizer_teacher.get_vocab()
teacher_tokens_decoded = {}

print("Decoding teacher tokens...")

@RayenTian commented May 5, 2026


These tokenizer decode loops are very similar in three places.

Could we factor them into a small helper that takes the tokenizer, vocab size, ignored IDs, raw-token/canonicalization flags, and whether to skip <|...|> special tokens?
That would also give us one place to handle decode failures with a counter plus a few examples instead of repeating bare except blocks in each loop.
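
A minimal sketch of what such a shared helper could look like. The name `decode_vocab`, its parameters, and the toy decode function are assumptions for illustration, not code from this PR; a real version would call the tokenizer's decode method and apply canonicalization where the sketch calls `decode_one`.

```python
import re
from typing import Callable, Dict, Iterable, List, Tuple

SPECIAL_TOKEN_RE = re.compile(r"^<\|.*\|>$")  # matches tokens shaped like <|...|>


def decode_vocab(
    decode_one: Callable[[int], str],
    vocab_size: int,
    ignored_ids: Iterable[int] = (),
    skip_special: bool = False,
    max_examples: int = 5,
) -> Tuple[Dict[int, str], int, List[Tuple[int, str]]]:
    """Decode every id in range(vocab_size), counting failures instead of hiding them."""
    ignored = set(ignored_ids)
    decoded: Dict[int, str] = {}
    failure_count = 0
    failure_examples: List[Tuple[int, str]] = []
    for token_id in range(vocab_size):
        if token_id in ignored:
            continue
        try:
            text = decode_one(token_id)
        except (IndexError, KeyError, ValueError) as exc:
            # Keep the skip behavior, but remember how often it happened.
            failure_count += 1
            if len(failure_examples) < max_examples:
                failure_examples.append((token_id, repr(exc)))
            continue
        if skip_special and SPECIAL_TOKEN_RE.match(text):
            continue
        decoded[token_id] = text
    return decoded, failure_count, failure_examples


# Toy stand-in for a tokenizer: id 2 is missing, id 3 is a special token.
toy_vocab = {0: "hello", 1: " world", 3: "<|eot|>"}
decoded, n_failed, examples = decode_vocab(
    lambda i: toy_vocab[i], vocab_size=4, ignored_ids=[1], skip_special=True
)
print(decoded)   # {0: 'hello'}
print(n_failed)  # 1
```

Each of the three loops would then reduce to one call plus a single summary print when the failure count is non-zero.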

transformation_counts[(student_id, teacher_id)] = 1.0


def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50):

debug_projection_map() seems to only format and print projection examples, and it is called as part of the normal script flow. Could we rename it to something like print_projection_map_examples() and consider gating it behind a CLI option such as --num-examples? That would make the intent clearer and avoid treating normal output as debug-only behavior.
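
As a rough sketch of the suggested gating, assuming a hypothetical `--num-examples` flag whose default of 0 prints nothing. The function below only prints raw id pairs; the PR's version also takes the two tokenizers, which is left out here to keep the example self-contained.

```python
import argparse


def print_projection_map_examples(transformation_counts, num_examples):
    """Print up to num_examples entries of a {(source_id, target_id): weight} map."""
    shown = 0
    for (source_id, target_id), weight in sorted(transformation_counts.items()):
        if shown >= num_examples:
            break
        print(f"{source_id} -> {target_id}: {weight}")
        shown += 1
    return shown


parser = argparse.ArgumentParser()
# Hypothetical flag: 0 (the default) disables example printing entirely.
parser.add_argument("--num-examples", type=int, default=0)
args = parser.parse_args(["--num-examples", "2"])

counts = {(3, 7): 1.0, (1, 5): 0.5, (2, 9): 0.25}
if args.num_examples > 0:
    print_projection_map_examples(counts, args.num_examples)
```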

print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}")
print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}")

if 0:

Could we remove the disabled/debug-only code before merging? I noticed blocks like if 0: and several commented-out pdb.set_trace() / exit() / debug snippets. Since these paths are not part of the active offline tool flow, deleting them would make the script easier to read and keep it consistent with the repo’s guidance around commented-out code.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 5, 2026
)

# Model selection arguments
parser.add_argument(

Could we make the required inputs explicit in the CLI instead of relying on model defaults? In particular, --student-model and --teacher-model currently default to specific HuggingFace models, which makes the script easy to run accidentally against the wrong tokenizer pair.
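
One way to express this with argparse, as a sketch: the flag names follow the comment above, while the help strings and the `parse_arguments` wrapper are assumptions. With `required=True` and no default, omitting either flag makes argparse exit with a usage error instead of silently picking a model.

```python
import argparse


def parse_arguments(argv=None):
    parser = argparse.ArgumentParser(
        description="Sketch: require an explicit tokenizer pair."
    )
    # No defaults: argparse exits with an error if either flag is missing.
    parser.add_argument(
        "--student-model", required=True, help="HuggingFace id of the student model"
    )
    parser.add_argument(
        "--teacher-model", required=True, help="HuggingFace id of the teacher model"
    )
    return parser.parse_args(argv)


args = parse_arguments(
    [
        "--student-model", "meta-llama/Llama-3.2-1B",
        "--teacher-model", "microsoft/Phi-4-mini-instruct",
    ]
)
print(args.student_model, args.teacher_model)
```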


reverse_multi_token_examples = []
print("Finding reverse multi-token mappings...")
for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"):

The two multi-token mapping passes here look very similar to the one at minimal_projection_via_multitoken.py:550, with the source/target tokenizer direction swapped. Could we factor this into a small helper, for example add_multitoken_mappings(...), that takes the source decoded tokens, target tokenizer, ignore IDs, and output direction? That would reduce duplication and make the weighting, truncation, ignored-token handling, and example collection easier to keep consistent across both passes.

print("Projection test successful - format is fully compatible!")

# pdb.set_trace()

No newline at end of file

NIT: Could we remove the extra blank/trailing whitespace at the end of the file?

#set indices to -1 where likelihood is 0

# Create filename in same format as minimal_projection_generator.py
def clean_model_name_for_filename(name: str) -> str:

Could we move the nested helper functions out to module scope? Functions like clean_model_name_for_filename() and project_token_likelihoods() do not appear to depend on local state from the main script body, so defining them inside if __name__ == "__main__": makes the file harder to scan and harder to test. A main() entrypoint plus module-level private helpers would make the offline tool easier to maintain.

# Test projection function compatibility (same as minimal_projection_generator.py)
print("\n--- Testing projection function compatibility ---")

def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch.nn as nn
from tokenalign import TokenAligner

from nemo_rl.algorithms.x_token.tokenalign import TokenAligner

)

# File paths
parser.add_argument(

--initial-projection-path is optional in argparse, but the script crashes immediately if it is omitted because initial_projection_map["likelihoods"] is used later. This should be required=True, and the loaded object should be validated to contain indices and likelihoods.
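
A minimal sketch of both fixes, assuming a hypothetical `validate_projection_map` helper. Plain lists stand in for the tensors that `torch.load` would return, so the example runs without PyTorch.

```python
import argparse


def validate_projection_map(obj, file_path):
    """Raise ValueError unless obj is a dict with 'indices' and 'likelihoods'."""
    if not isinstance(obj, dict):
        raise ValueError(f"{file_path}: expected a dict, got {type(obj).__name__}")
    missing = [key for key in ("indices", "likelihoods") if key not in obj]
    if missing:
        raise ValueError(f"{file_path}: projection map is missing keys {missing}")
    return obj


parser = argparse.ArgumentParser()
# required=True turns a late KeyError/TypeError into an immediate usage error.
parser.add_argument("--initial-projection-path", required=True)
args = parser.parse_args(["--initial-projection-path", "proj.pt"])

loaded = {"indices": [[0, -1]], "likelihoods": [[1.0, 0.0]]}  # stand-in for torch.load(...)
projection_map = validate_projection_map(loaded, args.initial_projection_path)
print(sorted(projection_map))  # ['indices', 'likelihoods']
```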

A = A / safe_row_sums
return A

def clean_model_name_for_filename(name: str) -> str:

This function is defined but never used.

print(f"Original top_k: {original_top_k}")
print(f"New top_k: {new_top_k}")
print(f"Preserve last column: {preserve_last}")
# pdb.set_trace()


new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take]
new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take]
# pdb.set_trace()

parser.add_argument("--output_path", help="Output path (auto-generated if not specified)")
parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element")
parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")
# python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last

Could we remove the local absolute-path usage example from the script and move the x-token offline tool usage into a docs page instead? A good place would be a new docs/guides/cross-tokenizer-distillation.md, linked from docs/index.md under Guides, since these helpers are part of the pre-distillation projection-map workflow. It would also be useful to document the other related tools there as well, so users have one consistent place to find commands for generating projection maps, reapplying exact matches, and sorting/cutting projection maps before running distillation.

# If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0

# Apply Sinkhorn normalization to the final matrix
print(f"last element trick count: {last_element_trick_count}")

--quiet currently still prints last element trick count because this line is outside the if verbose: guard. Could we move it under the same verbose check as the other progress/statistics output so --quiet consistently suppresses non-essential logging?

# Save the new projection matrix
torch.save(output_data, output_path)

if verbose:

The verbose statistics block is quite long and separate from the core sort/cut transformation. Could we move the output/statistics reporting into a helper like print_projection_statistics(...)? That would make sort_and_cut_projection_matrix() easier to follow and keep the transformation logic separate from diagnostic printing.

parser.add_argument("input_path", help="Path to input projection matrix file")
parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff")
parser.add_argument("--output_path", help="Output path (auto-generated if not specified)")
parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element")

Just curious: why do we need a switch to control whether the last column is preserved as the final element?
Could you please confirm whether --preserve_last is intended to be used whenever the projection map was generated with --enable-scale-trick in minimal_projection_via_multitoken.py?

If that is the case, could we encode enable_scale_trick into the projection map metadata when the map is generated, and have sort_and_cut_projection_matrix.py read that metadata to decide whether the last column should be preserved automatically? That would avoid requiring users to manually remember the coupling between these two tools.
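
If the coupling asked about here does hold, the metadata approach could look roughly like this. Both function names are hypothetical, the payload uses plain lists instead of tensors, and the choice to let the CLI flag still force preservation is an assumption of the sketch.

```python
def build_projection_payload(indices, likelihoods, enable_scale_trick):
    """Bundle the projection map with the flag that produced it."""
    return {
        "indices": indices,
        "likelihoods": likelihoods,
        "enable_scale_trick": bool(enable_scale_trick),
    }


def should_preserve_last(payload, cli_preserve_last=False):
    """Preserve the last column if the CLI asks for it or the metadata says so."""
    # Files written before the metadata existed lack the key; treat that as False.
    return cli_preserve_last or bool(payload.get("enable_scale_trick", False))


payload = build_projection_payload([[0, 1]], [[0.9, 0.1]], enable_scale_trick=True)
legacy_payload = {"indices": [[0, 1]], "likelihoods": [[0.9, 0.1]]}

print(should_preserve_last(payload))                # True
print(should_preserve_last(legacy_payload))         # False
print(should_preserve_last(legacy_payload, True))   # True
```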

EXACT_MATCH_ONLY = False

# --- Configuration and Setup ---
parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.")

Could we move the argument parser setup into a dedicated parse_arguments() helper, similar to the other x-token offline tools?

parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.")
parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).")
parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).")
parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for source model (Model A / Student). If provided, overrides model_a_index.")

Could we rename the CLI arguments and internal variables from model_a / model_b to student-model / teacher-model? This can avoid confusion around which tokenizer owns the projection rows versus columns.

# --- Configuration and Setup ---
parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.")
parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).")
parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).")

Could we consider simplifying the model selection interface here? If I understand correctly, model_a_index / model_b_index are shortcuts into the local MODEL_LIST, while model_a_name / model_b_name override them. Could we remove the model_a_index / model_b_index shortcuts and require explicit model names instead? These indices only refer to the local hard-coded MODEL_LIST, which seems useful for local experimentation but can be confusing in a repo-level offline tool. Using explicit --student-model and --teacher-model arguments would make the projection direction much clearer.

model_2['tokenizer'] = load_tokenizer(model_2['id'])

# Deterministically assign model_A and model_B based on alphabetical order of names
if model_1['name'] > model_2['name']:

Why are model_A / model_B reassigned based on alphabetical order here? The CLI help describes model_a as the source/student side and model_b as the target/teacher side, so I’m worried this swap could make the generated projection map go in the opposite direction from what the user requested.
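
A toy reproduction of the concern, assuming the branch body swaps the two models (the diff context above shows only the condition). The dictionaries and the `assign_by_name` function are illustrative, not code from the PR.

```python
def assign_by_name(model_1, model_2):
    """Mimic an alphabetical reassignment: the smaller name always becomes A."""
    if model_1["name"] > model_2["name"]:
        return model_2, model_1
    return model_1, model_2


# The user asks for qwen as the source (student) and llama as the target (teacher).
student = {"name": "qwen", "role": "student"}
teacher = {"name": "llama", "role": "teacher"}

model_a, model_b = assign_by_name(student, teacher)
print(model_a["role"], model_b["role"])  # teacher student
```

Because "qwen" sorts after "llama", the requested student ends up as model B, so any code that treats model A as the source would build the map in the reverse direction.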

import torch
import os
import argparse
from transformers import AutoTokenizer, AutoModel, AutoConfig

from nemo_rl.algorithms.x_token.tokenalign import TokenAligner

# from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import re
import pdb

Same as the earlier comment: could we clean up some of the debug/development scaffolding before merging?


def sinkhorn_one_dim(A, n_iters=1):
for _ in range(n_iters):


Would you please remove blank lines like this?

parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).")
args = parser.parse_args()

args.skip_exact_enforcement = True

Could we remove or wire up --skip_exact_enforcement here? It looks like args.skip_exact_enforcement is forced to True after parsing, and I don’t see the argument being used later to control any exact-match behavior. If exact-match enforcement is intentionally out of scope for this chunked generator, removing the flag may be clearer so users do not expect it to affect the output.

#set indices to -1 where threshold is not met
top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1))

# # Count how many values per row are above threshold

Can we remove the debug code?

target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len)
projected_likelihoods.scatter_add_(2, target_expanded, weighted_k)

# else: # For larger top_k, use optimized sequential processing

Can we remove the debug code?

@RayenTian

Hi, @avenkateshha, thanks for the hard work on this! I left some review comments. Besides the comments, could we do a first cleanup pass on this file?

  • Would you mind running the repo pre-commit hooks on this file? You can refer to this doc to install pre-commit. The next time you commit, the lint check will run locally. Simply re-add the automatically fixed files and commit again. If any errors remain, the command output will show the specific files and rules that require manual adjustment.
  • Could you please add focused unit tests for the main TokenAligner behavior? In particular, tests for small projection maps, -1 padding entries, exact-map rows, enable_scale_trick, and a small canonicalization/alignment example would help cover the core runtime paths.
  • Could we also remove the remaining development scaffolding, such as if 0: / if 1: blocks, commented-out debug statements, and “will remove later” comments?

self.student_tokenizer = None

self.max_combination_len = max_comb_len
self.sparse_transformation_matrix = None

Could we avoid assigning self.sparse_transformation_matrix = None before later registering a buffer with the same name? My understanding is that nn.Module.register_buffer() expects the name not to already exist as a regular attribute, so the non-learnable sparse load path may fail when it tries to register sparse_transformation_matrix. Registering an initial None buffer, or only setting the attribute in the load path, would avoid that conflict.


# Get tokenizer vocab sizes
teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback
student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback

Could we avoid the hard-coded fallback vocab sizes here? These look specific to one Qwen/Llama-style setup, and they could silently create a projection matrix with the wrong shape for other model pairs when tokenizers are not initialized.
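
One possible shape for this, as a sketch: fail loudly when a tokenizer is missing rather than guessing a size. `resolve_vocab_size` is a hypothetical helper, and a plain list stands in for a tokenizer, since both support `len()`.

```python
def resolve_vocab_size(tokenizer, role):
    """Return len(tokenizer), or raise if the tokenizer was never initialized."""
    if tokenizer is None:
        raise ValueError(
            f"{role} tokenizer is not initialized; cannot infer its vocab size"
        )
    return len(tokenizer)


toy_tokenizer = ["<s>", "hello", "world"]  # stand-in: only len() is used
print(resolve_vocab_size(toy_tokenizer, "teacher"))  # 3

try:
    resolve_vocab_size(None, "student")
except ValueError as exc:
    print(exc)
```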

import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Could we avoid overriding TOKENIZERS_PARALLELISM at module import time here? This changes a process-wide setting for any code that imports TokenAligner, and it can override a value the user or launcher already configured. If we still want a default to suppress tokenizer parallelism warnings, using os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") or moving this to the training entrypoint/config would be less surprising.
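
The difference between plain assignment and `setdefault` can be seen in a few lines. This is only an illustration of standard `os.environ` behavior; the `KEY` constant is introduced for the example.

```python
import os

KEY = "TOKENIZERS_PARALLELISM"

# Case 1: the launcher already chose a value; setdefault leaves it alone.
os.environ[KEY] = "true"
os.environ.setdefault(KEY, "false")
kept = os.environ[KEY]

# Case 2: nothing was configured; setdefault supplies the default.
del os.environ[KEY]
os.environ.setdefault(KEY, "false")
defaulted = os.environ[KEY]

print(kept, defaulted)  # true false
```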


return s1_canonical == s2_canonical

def align_tokens_with_combinations_numpy(seq1, seq2,

Could we make align_tokens_with_combinations_numpy() a @staticmethod as well? It does not appear to use instance state, and most of the similar helper methods in this class are already static. More importantly, calling it through self can bind the instance as the first positional argument and shift the user-provided arguments, so marking it static would make the call behavior consistent with the rest of the helpers.
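
The binding issue can be reproduced with a toy class. The method bodies and the `max_comb_len` default below are placeholders chosen for the illustration; only the calling convention matters.

```python
class WithoutStatic:
    # No self parameter and no decorator: works via the class, not via an instance.
    def align(seq1, seq2, max_comb_len=4):
        return seq1, seq2, max_comb_len


class WithStatic:
    @staticmethod
    def align(seq1, seq2, max_comb_len=4):
        return seq1, seq2, max_comb_len


broken = WithoutStatic()
shifted = broken.align(["a"], ["b"])
# The instance is bound as seq1, so every user argument moves one slot right.
print(shifted[0] is broken, shifted[1], shifted[2])  # True ['a'] ['b']

fixed = WithStatic().align(["a"], ["b"])
print(fixed)  # (['a'], ['b'], 4)
```

Note that the shifted call raises no error here because the extra positional argument lands in a parameter with a default, which makes the bug easy to miss.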

avenkateshha added a commit to avenkateshha/RL that referenced this pull request May 16, 2026
Apply the post-review changes for PR NVIDIA-NeMo#2347 on the cross-tokenizer
projection-prep utilities under nemo_rl/utils/x_token/:

minimal_projection_via_multitoken.py
- Re-add optional --output-filename stem (default None falls back to the
  auto-derived name) so recipe-driven runs can pin the filename, matching
  the contract used by the reference recipes.
- Extend the gemma vocab-size branch to also fire for qwen3.5, which
  nests vocab_size under config.text_config on both student and teacher
  sides.
- Document the .pt save-key schema near torch.save (student_model_id /
  teacher_model_id / enable_scale_trick) with a pointer to the
  legacy-key fallback in the load path.

minimal_projection_generator.py
- Add the same schema annotation near torch.save.

reapply_exact_map.py
- Validate the loaded projection map is a dict containing 'indices' and
  'likelihoods'; raise ValueError with the file path on mismatch instead
  of surfacing a confusing KeyError mid-loop.

sort_and_cut_projection_matrix.py
- Lift the argparse block out of main() into a module-level
  parse_arguments() helper, matching the shape used by the other
  x_token CLI scripts.
- Lift the verbose stats block out of sort_and_cut_projection_matrix()
  into a print_projection_statistics() helper.
- Replace the positional input_path with a named
  --initial-projection-path flag for naming consistency with the rest
  of the x_token CLI tools.

docs/guides/xtoken-distillation.md
- Update Step 4 usage example to the new --initial-projection-path flag.

Behaviour preserved: the Llama-3.2-3B / Qwen3-4B-Base recipe still
produces bitwise-equal indices and likelihoods against the canonical
llama_qwen_best_special_exact_map_remapped.pt artifact.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha avenkateshha deleted the avenkateshha/xtoken-off-policy-distillation/01-tokenaligner branch May 16, 2026 01:47
@avenkateshha

Work continues in #2508 (the head branch of this PR was renamed off the avenkateshha/xtoken-off-policy-distillation/ namespace, which auto-closed this cross-repo PR). All the review feedback from this thread is addressed in commit 755fb8e4 on the new PR.

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 16, 2026
avenkateshha added a commit to avenkateshha/RL that referenced this pull request May 27, 2026