[1/5] feat: add TokenAligner and cross-tokenizer projection utilities #2347
Foundational library code for cross-tokenizer distillation. No algorithm
or training-loop integration yet — those follow in subsequent PRs.
- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with
Numba-accelerated DP alignment, projection-matrix loading
(dense and sparse COO), and the project_token_likelihoods_instance
forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator,
minimal_projection_via_multitoken,reapply_exact_map,
sort_and_cut_projection_matrix}.py: standalone CLI scripts
(argparse-driven, __main__ entrypoints) for one-time projection-matrix
preparation. Not on the training import path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
a556396 to c492ac1
…gner
tokenalign.py: 2194 → 1877 lines (-317).
align_fast removal
- Committing to use_align_fast=false everywhere; the fast-lookup path is no
longer needed. Callers (CrossTokenizerCollator, off_policy_distillation
setup) will be updated on the dependent branches in this stack to take
the unconditional aligner.align(...) path.
- TokenAligner: dropped align_fast(), precompute_canonical_maps(), and the
_student_canon_map / _teacher_canon_map / _canon_id_to_str attrs that
backed them.
Behavior note: align_fast skipped _canonicalize_sequence's
_merge_encoding_artifacts / _merge_consecutive_bytes by reusing per-token
canonical maps. align() runs the full sequence-level canonicalization,
so alignment may differ slightly for sequences that contain
encoding-artifact or byte tokens. For pairs without those, the result is
identical. The cost is a small per-batch CPU bump in the collator
workers since we no longer cache canonical strings per id.
numba removal
- The training container ships without numba, so at runtime
_NUMBA_AVAILABLE was always False, _dp_core_numba was never defined,
and align_tokens_with_combinations_numpy_jit was a one-line forwarder
to align_tokens_with_combinations_numpy on its first if-not-numba
branch.
- Dropped the top-level numba try/except + _NUMBA_AVAILABLE flag and the
@njit-decorated _dp_core_numba kernel.
- Deleted align_tokens_with_combinations_numpy_jit and retargeted callers
at align_tokens_with_combinations_numpy directly:
- _perform_dp_alignment (chunk_size==0 branch)
- align_tokens_combinations_chunked (small-sequence base case +
divide-and-conquer fallback)
- Behavior: zero runtime change for this container; the numpy DP kernel
is what was always executing.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithyakrishna Hanasoge <avenkateshha@nvidia.com>
| tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) | ||
| model_A_config = AutoConfig.from_pretrained(student_model_name) | ||
| model_B_config = AutoConfig.from_pretrained(teacher_model_name) | ||
| # pdb.set_trace() |
Could we clean up some of the debug/development scaffolding before merging? I noticed import pdb, commented pdb.set_trace() /exit() calls, and constant conditionals like if 1: / if 0:. The script may be working as intended, but removing these would make the offline tool easier to maintain and more consistent with the rest of the repo.
| from nemo_rl.algorithms.x_token.tokenalign import TokenAligner | ||
| import gc | ||
| from collections import defaultdict | ||
| from datasets import load_dataset, get_dataset_config_names |
Some imports appear to be unused in this file, such as AutoModelForCausalLM, torch.nn as nn, gc, load_dataset, get_dataset_config_names, random, time, numpy as np, and pdb. Could we remove the unused imports?
| decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) | ||
| teacher_tokens_decoded[token_id] = decoded | ||
| except: | ||
| # Skip tokens that can't be processed |
For this tokenizer scan, could we avoid the bare except and keep the skip behavior with a small summary instead? For example, tracking a count plus a few examples would make unexpected decode issues visible without spamming output.
decode_failure_count = 0
decode_failure_examples = []
...
    except (IndexError, KeyError, ValueError) as exc:
        decode_failure_count += 1
        if len(decode_failure_examples) < 5:
            decode_failure_examples.append((token_id, repr(exc)))
if decode_failure_count:
    print(
        f"Skipped {decode_failure_count} teacher tokens that could not be decoded. "
        f"First examples: {decode_failure_examples}"
    )
| # Apply canonicalization if enabled | ||
| decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) | ||
| student_tokens_decoded[token_id] = decoded | ||
| except: |
Same comment here as above: could we replace the bare except with explicit failure tracking?
| decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) | ||
| teacher_tokens_decoded[token_id] = decoded | ||
| except: | ||
| # Skip tokens that can't be processed |
Same comment here as above: could we replace the bare except with explicit failure tracking?
| teacher_vocab = tokenizer_teacher.get_vocab() | ||
| teacher_tokens_decoded = {} | ||
|
|
||
| print("Decoding teacher tokens...") |
These tokenizer decode loops are very similar in these three places:
- nemo_rl/utils/x_token/minimal_projection_via_multitoken.py:423
- nemo_rl/utils/x_token/minimal_projection_via_multitoken.py:460
- nemo_rl/utils/x_token/minimal_projection_via_multitoken.py:525
Could we factor them into a small helper that takes the tokenizer, vocab size, ignored IDs, raw-token/canonicalization flags, and whether to skip <|...|> special tokens?
That would also give us one place to handle decode failures with a counter plus a few examples instead of repeating bare except blocks in each loop.
| transformation_counts[(student_id, teacher_id)] = 1.0 | ||
|
|
||
|
|
||
| def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50): |
debug_projection_map() seems to only format and print projection examples, and it is called as part of the normal script flow. Could we rename it to something like print_projection_map_examples() and consider gating it behind a CLI option such as --num-examples? That would make the intent clearer and avoid treating normal output as debug-only behavior.
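One way the rename and gating could look, as a sketch only. The `--num-examples` flag and the `print_projection_map_examples` name come from the suggestion above; the default of `0`, the output format, and the omission of the tokenizer arguments are assumptions.

```python
import argparse
from typing import Dict, List, Tuple


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Projection-map example printer (sketch).")
    # Assumed default: print nothing unless the user asks for examples.
    parser.add_argument("--num-examples", type=int, default=0,
                        help="Number of projection examples to print (0 disables).")
    return parser


def print_projection_map_examples(
    transformation_counts: Dict[Tuple[int, int], float], num_examples: int
) -> List[str]:
    """Print up to num_examples (source_id, target_id) -> weight entries."""
    lines = []
    for (source_id, target_id), weight in list(transformation_counts.items())[:num_examples]:
        lines.append(f"{source_id} -> {target_id}: {weight:.3f}")
    for line in lines:
        print(line)
    return lines


args = build_parser().parse_args(["--num-examples", "1"])
shown = print_projection_map_examples({(0, 5): 1.0, (1, 7): 0.5}, args.num_examples)
```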
| print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}") | ||
| print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}") | ||
|
|
||
| if 0: |
Could we remove the disabled/debug-only code before merging? I noticed blocks like if 0: and several commented-out pdb.set_trace() / exit() / debug snippets. Since these paths are not part of the active offline tool flow, deleting them would make the script easier to read and keep it consistent with the repo’s guidance around commented-out code.
| ) | ||
|
|
||
| # Model selection arguments | ||
| parser.add_argument( |
Could we make the required inputs explicit in the CLI instead of relying on model defaults? In particular, --student-model and --teacher-model currently default to specific HuggingFace models, which makes the script easy to run accidentally against the wrong tokenizer pair.
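A small sketch of what explicit inputs could look like, assuming the flag names `--student-model` and `--teacher-model` from the comment; the help text and the example model ids are placeholders, not values from the script.

```python
import argparse


def parse_arguments(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate a projection map between two tokenizers (sketch)."
    )
    # No defaults: argparse exits with an error if either model is omitted.
    parser.add_argument("--student-model", required=True,
                        help="HuggingFace id of the student (source) model.")
    parser.add_argument("--teacher-model", required=True,
                        help="HuggingFace id of the teacher (target) model.")
    return parser.parse_args(argv)


args = parse_arguments(["--student-model", "org/student", "--teacher-model", "org/teacher"])
print(args.student_model, args.teacher_model)
```

With `required=True`, running the tool without both flags fails immediately instead of silently falling back to a default tokenizer pair.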
|
|
||
| reverse_multi_token_examples = [] | ||
| print("Finding reverse multi-token mappings...") | ||
| for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"): |
The two multi-token mapping passes look very similar: this one and the pass at minimal_projection_via_multitoken.py:550 differ only in the source/target tokenizer direction. Could we factor this into a small helper, for example add_multitoken_mappings(...), that takes the source decoded tokens, target tokenizer, ignore IDs, and output direction? That would reduce duplication and make the weighting, truncation, ignored-token handling, and example collection easier to keep consistent across both passes.
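A rough sketch of the shared helper, under stated assumptions: the uniform `1 / n` weighting, the `max_target_tokens` truncation, and the `target_encode` callable (standing in for the target tokenizer's encode call) are illustrative and may not match the script's actual rules. Only the `(student_id, teacher_id)` key order is taken from the diff.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def add_multitoken_mappings(
    source_decoded: Dict[int, str],
    target_encode: Callable[[str], List[int]],
    ignored_target_ids: Iterable[int],
    counts: Dict[Tuple[int, int], float],
    source_is_student: bool,
    max_target_tokens: int = 4,
    max_examples: int = 3,
) -> List[Tuple[int, List[int]]]:
    """Add source->multi-target mappings into counts, keyed (student_id, teacher_id)."""
    ignored = set(ignored_target_ids)
    examples: List[Tuple[int, List[int]]] = []
    for source_id, text in source_decoded.items():
        target_ids = [t for t in target_encode(text) if t not in ignored]
        target_ids = target_ids[:max_target_tokens]
        if len(target_ids) < 2:
            continue  # Single-token matches are assumed to be handled elsewhere.
        weight = 1.0 / len(target_ids)  # ASSUMPTION: uniform weighting.
        for target_id in target_ids:
            key = (source_id, target_id) if source_is_student else (target_id, source_id)
            counts[key] = counts.get(key, 0.0) + weight
        if len(examples) < max_examples:
            examples.append((source_id, target_ids))
    return examples


# Toy encoder: one "token" per character, using its code point as the id.
toy_encode = lambda text: [ord(ch) for ch in text]
toy_decoded = {0: "ab", 1: "c"}

forward: Dict[Tuple[int, int], float] = {}
forward_examples = add_multitoken_mappings(toy_decoded, toy_encode, [], forward, source_is_student=True)

reverse: Dict[Tuple[int, int], float] = {}
add_multitoken_mappings(toy_decoded, toy_encode, [], reverse, source_is_student=False)
```

Both passes then share one implementation, and the direction is a single boolean rather than a copied loop.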
| print("Projection test successful - format is fully compatible!") | ||
|
|
||
| # pdb.set_trace() | ||
|
No newline at end of file |
NIT: Could we remove the extra blank/trailing whitespace at the end of the file?
| #set indices to -1 where likelihood is 0 | ||
|
|
||
| # Create filename in same format as minimal_projection_generator.py | ||
| def clean_model_name_for_filename(name: str) -> str: |
Could we move the nested helper functions out to module scope? Functions like clean_model_name_for_filename() and project_token_likelihoods() do not appear to depend on local state from the main script body, so defining them inside if __name__ == "__main__": makes the file harder to scan and harder to test. A main() entrypoint plus module-level private helpers would make the offline tool easier to maintain.
| # Test projection function compatibility (same as minimal_projection_generator.py) | ||
| print("\n--- Testing projection function compatibility ---") | ||
|
|
||
| def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): |
| import torch | ||
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig | ||
| import torch.nn as nn | ||
| from tokenalign import TokenAligner |
from nemo_rl.algorithms.x_token.tokenalign import TokenAligner
| ) | ||
|
|
||
| # File paths | ||
| parser.add_argument( |
--initial-projection-path is optional in argparse, but the script crashes immediately if it is omitted because initial_projection_map["likelihoods"] is used later. This should be required=True, and the loaded object should be validated to contain indices and likelihoods.
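A minimal sketch of the validation step, assuming the loaded object should be a dict with `indices` and `likelihoods` keys as described above. The function name `validate_projection_map` is hypothetical, and plain lists stand in for the tensors that `torch.load` would return.

```python
from typing import Any, Dict


def validate_projection_map(obj: Any, path: str) -> Dict[str, Any]:
    """Fail early, naming the file, if the loaded object is not a usable projection map."""
    if not isinstance(obj, dict):
        raise ValueError(f"{path}: expected a dict, got {type(obj).__name__}")
    missing = [key for key in ("indices", "likelihoods") if key not in obj]
    if missing:
        raise ValueError(f"{path}: projection map is missing keys {missing}")
    return obj


# Plain-Python stand-ins for loaded projection maps.
good = validate_projection_map({"indices": [[0, -1]], "likelihoods": [[1.0, 0.0]]}, "map.pt")

try:
    validate_projection_map({"indices": [[0]]}, "broken.pt")
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```

Combined with `required=True` on `--initial-projection-path`, a missing or malformed file is reported before any processing starts.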
| A = A / safe_row_sums | ||
| return A | ||
|
|
||
| def clean_model_name_for_filename(name: str) -> str: |
This function is defined but never used.
| print(f"Original top_k: {original_top_k}") | ||
| print(f"New top_k: {new_top_k}") | ||
| print(f"Preserve last column: {preserve_last}") | ||
| # pdb.set_trace() |
|
|
||
| new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take] | ||
| new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take] | ||
| # pdb.set_trace() |
| parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") | ||
| parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") | ||
| parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") | ||
| # python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last |
Could we remove the local absolute-path usage example from the script and move the x-token offline tool usage into a docs page instead? A good place would be a new docs/guides/cross-tokenizer-distillation.md, linked from docs/index.md under Guides, since these helpers are part of the pre-distillation projection-map workflow. It would also be useful to document the other related tools there as well, so users have one consistent place to find commands for generating projection maps, reapplying exact matches, and sorting/cutting projection maps before running distillation.
| # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0 | ||
|
|
||
| # Apply Sinkhorn normalization to the final matrix | ||
| print(f"last element trick count: {last_element_trick_count}") |
--quiet currently still prints last element trick count because this line is outside the if verbose: guard. Could we move it under the same verbose check as the other progress/statistics output so --quiet consistently suppresses non-essential logging?
| # Save the new projection matrix | ||
| torch.save(output_data, output_path) | ||
|
|
||
| if verbose: |
The verbose statistics block is quite long and separate from the core sort/cut transformation. Could we move the output/statistics reporting into a helper like print_projection_statistics(...)? That would make sort_and_cut_projection_matrix() easier to follow and keep the transformation logic separate from diagnostic printing.
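A sketch of how the reporting could be separated from the transformation, assuming the helper name `print_projection_statistics` from the comment. The specific statistics and the `maybe_report` wrapper are illustrative, and nested lists stand in for the tensors; the `-1` padding convention is the one mentioned in the diff.

```python
from typing import List


def print_projection_statistics(
    indices: List[List[int]],
    likelihoods: List[List[float]],
    last_element_trick_count: int,
) -> List[str]:
    """Print diagnostic statistics; entries with index -1 are treated as padding."""
    valid_per_row = [sum(1 for idx in row if idx != -1) for row in indices]
    lines = [
        f"rows: {len(indices)}",
        f"valid entries: {sum(valid_per_row)}",
        f"max row sum: {max(sum(row) for row in likelihoods):.3f}",
        f"last element trick count: {last_element_trick_count}",
    ]
    for line in lines:
        print(line)
    return lines


def maybe_report(indices, likelihoods, last_element_trick_count, verbose: bool) -> List[str]:
    # All non-essential output sits behind the same verbose check.
    if not verbose:
        return []
    return print_projection_statistics(indices, likelihoods, last_element_trick_count)


toy_indices = [[3, 7], [2, -1]]
toy_likelihoods = [[0.5, 0.5], [1.0, 0.0]]
quiet_lines = maybe_report(toy_indices, toy_likelihoods, 0, verbose=False)
verbose_lines = maybe_report(toy_indices, toy_likelihoods, 0, verbose=True)
```

Routing the trick count through the same helper also means `--quiet` suppresses it along with the other statistics.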
| parser.add_argument("input_path", help="Path to input projection matrix file") | ||
| parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") | ||
| parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") | ||
| parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") |
Just curious: why do we need a switch to control whether the last column is preserved as the final element?
Could you please confirm whether --preserve_last is intended to be used whenever the projection map was generated with --enable-scale-trick in minimal_projection_via_multitoken.py?
If that is the case, could we encode enable_scale_trick into the projection map metadata when the map is generated, and have sort_and_cut_projection_matrix.py read that metadata to decide whether the last column should be preserved automatically? That would avoid requiring users to manually remember the coupling between these two tools.
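If the coupling is confirmed, the decision could be read from the map itself. The sketch below assumes the generator stores a boolean under an `enable_scale_trick` key in the saved dict and that older maps simply lack the key; `resolve_preserve_last` is a hypothetical name.

```python
from typing import Any, Dict


def resolve_preserve_last(projection_map: Dict[str, Any], cli_preserve_last: bool = False) -> bool:
    """Preserve the last column if the CLI flag is set or the map metadata requests it."""
    # ASSUMPTION: maps without the key were generated without the scale trick.
    return bool(cli_preserve_last or projection_map.get("enable_scale_trick", False))


with_metadata = resolve_preserve_last({"enable_scale_trick": True})
legacy_map = resolve_preserve_last({})
legacy_with_flag = resolve_preserve_last({}, cli_preserve_last=True)
print(with_metadata, legacy_map, legacy_with_flag)
```

Keeping the CLI flag as an override lets legacy maps without metadata still be processed correctly.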
| EXACT_MATCH_ONLY = False | ||
|
|
||
| # --- Configuration and Setup --- | ||
| parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") |
Could we move the argument parser setup into a dedicated parse_arguments() helper, similar to the other x-token offline tools?
| parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") | ||
| parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).") | ||
| parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).") | ||
| parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for source model (Model A / Student). If provided, overrides model_a_index.") |
Could we rename the CLI arguments and internal variables from model_a / model_b to student-model / teacher-model? This can avoid confusion around which tokenizer owns the projection rows versus columns.
| # --- Configuration and Setup --- | ||
| parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.") | ||
| parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).") | ||
| parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).") |
Could we consider simplifying the model selection interface here? If I understand correctly, model_a_index / model_b_index are shortcuts into the local MODEL_LIST, while model_a_name / model_b_name override them. Could we remove the model_a_index / model_b_index shortcuts and require explicit model names instead? These indices only refer to the local hard-coded MODEL_LIST, which seems useful for local experimentation but can be confusing in a repo-level offline tool. Using explicit --student-model and --teacher-model arguments would make the projection direction much clearer.
| model_2['tokenizer'] = load_tokenizer(model_2['id']) | ||
|
|
||
| # Deterministically assign model_A and model_B based on alphabetical order of names | ||
| if model_1['name'] > model_2['name']: |
Why are model_A / model_B reassigned based on alphabetical order here? The CLI help describes model_a as the source/student side and model_b as the target/teacher side, so I’m worried this swap could make the generated projection map go in the opposite direction from what the user requested.
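To make the concern concrete, here is a toy reproduction. It assumes the branch swaps the pair whenever the first name sorts after the second; the function names and model names are made up for the illustration.

```python
from typing import Dict, Tuple

Model = Dict[str, str]


def assign_by_name(model_1: Model, model_2: Model) -> Tuple[Model, Model]:
    """ASSUMED shape of the reviewed logic: order the pair by name, ignoring roles."""
    if model_1["name"] > model_2["name"]:
        return model_2, model_1
    return model_1, model_2


def assign_by_role(student: Model, teacher: Model) -> Tuple[Model, Model]:
    """Alternative: model_A is always the student, model_B always the teacher."""
    return student, teacher


student = {"name": "qwen-student"}
teacher = {"name": "llama-teacher"}

sorted_a, sorted_b = assign_by_name(student, teacher)
role_a, role_b = assign_by_role(student, teacher)
print(sorted_a["name"], role_a["name"])
```

With these names the alphabetical rule puts the teacher in the model_A slot, so the projection direction depends on how the models happen to be named rather than on what the user requested.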
| import torch | ||
| import os | ||
| import argparse | ||
| from transformers import AutoTokenizer, AutoModel, AutoConfig |
from nemo_rl.algorithms.x_token.tokenalign import TokenAligner
| # from sentence_transformers import SentenceTransformer | ||
| from tqdm.auto import tqdm | ||
| import re | ||
| import pdb |
Could we clean up some of the debug/development scaffolding before merging? Same as the earlier comment.
|
|
||
| def sinkhorn_one_dim(A, n_iters=1): | ||
| for _ in range(n_iters): | ||
|
|
Would you please remove blank lines like these?
| parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).") | ||
| args = parser.parse_args() | ||
|
|
||
| args.skip_exact_enforcement = True |
Could we remove or wire up --skip_exact_enforcement here? It looks like args.skip_exact_enforcement is forced to True after parsing, and I don’t see the argument being used later to control any exact-match behavior. If exact-match enforcement is intentionally out of scope for this chunked generator, removing the flag may be clearer so users do not expect it to affect the output.
| #set indices to -1 where threshold is not met | ||
| top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) | ||
|
|
||
| # # Count how many values per row are above threshold |
Can we remove the debug code?
| target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len) | ||
| projected_likelihoods.scatter_add_(2, target_expanded, weighted_k) | ||
|
|
||
| # else: # For larger top_k, use optimized sequential processing |
Can we remove the debug code?
|
Hi, @avenkateshha, thanks for the hard work on this! I left some review comments. Besides the comments, could we do a first cleanup pass on this file?
|
| self.student_tokenizer = None | ||
|
|
||
| self.max_combination_len = max_comb_len | ||
| self.sparse_transformation_matrix = None |
Could we avoid assigning self.sparse_transformation_matrix = None before later registering a buffer with the same name? My understanding is that nn.Module.register_buffer() expects the name not to already exist as a regular attribute, so the non-learnable sparse load path may fail when it tries to register sparse_transformation_matrix. Registering an initial None buffer, or only setting the attribute in the load path, would avoid that conflict.
|
|
||
| # Get tokenizer vocab sizes | ||
| teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback | ||
| student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback |
Could we avoid the hard-coded fallback vocab sizes here? These look specific to one Qwen/Llama-style setup, and they could silently create a projection matrix with the wrong shape for other model pairs when tokenizers are not initialized.
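One possible replacement for the fallbacks, sketched with a hypothetical `require_vocab_size` helper: fail loudly when a tokenizer is missing instead of guessing a size. A plain list stands in for a tokenizer here, since only `len()` is needed.

```python
from typing import Optional, Sized


def require_vocab_size(tokenizer: Optional[Sized], role: str) -> int:
    """Return len(tokenizer), or raise if the tokenizer has not been initialized."""
    if tokenizer is None:
        raise ValueError(f"{role} tokenizer is not initialized; cannot infer vocab size")
    return len(tokenizer)


# Stand-in "tokenizer" with three entries.
teacher_vocab_size = require_vocab_size(["a", "b", "c"], "teacher")

try:
    require_vocab_size(None, "student")
    missing_error = None
except ValueError as exc:
    missing_error = str(exc)
print(teacher_vocab_size, missing_error)
```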
| import torch.nn as nn | ||
| from transformers import AutoConfig, AutoTokenizer | ||
|
|
||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
Could we avoid overriding TOKENIZERS_PARALLELISM at module import time here? This changes a process-wide setting for any code that imports TokenAligner, and it can override a value the user or launcher already configured. If we still want a default to suppress tokenizer parallelism warnings, using os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") or moving this to the training entrypoint/config would be less surprising.
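The difference between the two approaches can be shown with a plain dict standing in for the process environment. The helper name is hypothetical; `setdefault` on `os.environ` behaves the same way as on a dict.

```python
import os
from typing import MutableMapping


def apply_tokenizer_parallelism_default(env: MutableMapping[str, str] = os.environ) -> str:
    """Set TOKENIZERS_PARALLELISM to "false" only if nothing has configured it yet."""
    env.setdefault("TOKENIZERS_PARALLELISM", "false")
    return env["TOKENIZERS_PARALLELISM"]


# Dicts are used here so the example does not touch the real environment.
unset_env = {}
user_env = {"TOKENIZERS_PARALLELISM": "true"}

value_when_unset = apply_tokenizer_parallelism_default(unset_env)
value_when_user_set = apply_tokenizer_parallelism_default(user_env)
print(value_when_unset, value_when_user_set)
```

A direct assignment would have replaced the user's `"true"` with `"false"`; `setdefault` leaves it alone.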
|
|
||
| return s1_canonical == s2_canonical | ||
|
|
||
| def align_tokens_with_combinations_numpy(seq1, seq2, |
Could we make align_tokens_with_combinations_numpy() a @staticmethod as well? It does not appear to use instance state, and most of the similar helper methods in this class are already static. More importantly, calling it through self can bind the instance as the first positional argument and shift the user-provided arguments, so marking it static would make the call behavior consistent with the rest of the helpers.
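The argument shift described above can be reproduced in isolation. The class and the `gap_penalty` parameter below are hypothetical stand-ins; they only mirror the shape of a method that takes `seq1, seq2` plus defaulted arguments and no `self`.

```python
class ToyAligner:
    def align_unbound(seq1, seq2, gap_penalty=-1.0):
        # No @staticmethod: when called via an instance, seq1 receives the instance.
        return (seq1, seq2, gap_penalty)

    @staticmethod
    def align_static(seq1, seq2, gap_penalty=-1.0):
        # With @staticmethod, arguments land where the caller intended.
        return (seq1, seq2, gap_penalty)


aligner = ToyAligner()

# The instance is bound as seq1, so the user's arguments shift one slot to the right.
shifted = aligner.align_unbound(["a"], ["b"])

# The static version receives exactly what was passed.
correct = aligner.align_static(["a"], ["b"])
print(shifted[1:], correct)
```

Because the trailing parameter has a default, the shifted call does not raise; it silently uses the second sequence as the penalty, which is why the decorator matters.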
Apply the post-review changes for PR NVIDIA-NeMo#2347 on the cross-tokenizer projection-prep utilities under nemo_rl/utils/x_token/:
minimal_projection_via_multitoken.py
- Re-add optional --output-filename stem (default None falls back to the auto-derived name) so recipe-driven runs can pin the filename, matching the contract used by the reference recipes.
- Extend the gemma vocab-size branch to also fire for qwen3.5, which nests vocab_size under config.text_config on both student and teacher sides.
- Document the .pt save-key schema near torch.save (student_model_id / teacher_model_id / enable_scale_trick) with a pointer to the legacy-key fallback in the load path.
minimal_projection_generator.py
- Add the same schema annotation near torch.save.
reapply_exact_map.py
- Validate the loaded projection map is a dict containing 'indices' and 'likelihoods'; raise ValueError with the file path on mismatch instead of surfacing a confusing KeyError mid-loop.
sort_and_cut_projection_matrix.py
- Lift the argparse block out of main() into a module-level parse_arguments() helper, matching the shape used by the other x_token CLI scripts.
- Lift the verbose stats block out of sort_and_cut_projection_matrix() into a print_projection_statistics() helper.
- Replace the positional input_path with a named --initial-projection-path flag for naming consistency with the rest of the x_token CLI tools.
docs/guides/xtoken-distillation.md
- Update Step 4 usage example to the new --initial-projection-path flag.
Behaviour preserved: the Llama-3.2-3B / Qwen3-4B-Base recipe still produces bitwise-equal indices and likelihoods against the canonical llama_qwen_best_special_exact_map_remapped.pt artifact.
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
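For the nested vocab_size case, a sketch of one way to read the value. Note that this is attribute-based, whereas the commit describes a model-name branch (gemma, qwen3.5), so it is an alternative illustration rather than the PR's code. `resolve_vocab_size` is a hypothetical name, `SimpleNamespace` objects stand in for HuggingFace configs, and the sizes are made up.

```python
from types import SimpleNamespace


def resolve_vocab_size(config) -> int:
    """Read vocab_size from the top-level config, else from config.text_config."""
    size = getattr(config, "vocab_size", None)
    if size is None:
        text_config = getattr(config, "text_config", None)
        size = getattr(text_config, "vocab_size", None)
    if size is None:
        raise ValueError("config exposes no vocab_size at the top level or under text_config")
    return size


# Illustrative stand-in configs.
flat_config = SimpleNamespace(vocab_size=100)
nested_config = SimpleNamespace(text_config=SimpleNamespace(vocab_size=200))

flat_size = resolve_vocab_size(flat_config)
nested_size = resolve_vocab_size(nested_config)
print(flat_size, nested_size)
```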
|
Work continues in #2508 (the head branch of this PR was renamed off the |
Foundational library code for cross-tokenizer distillation. No algorithm or training-loop integration yet — those follow in subsequent PRs.
- nemo_rl/algorithms/x_token/tokenalign.py: TokenAligner(nn.Module) with Numba-accelerated DP alignment, projection-matrix loading (dense and sparse COO), and the project_token_likelihoods_instance forward path used by the cross-tokenizer loss.
- nemo_rl/algorithms/x_token/__init__.py: package init.
- nemo_rl/utils/x_token/{minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix}.py: standalone CLI scripts (argparse-driven, __main__ entrypoints) for one-time projection-matrix preparation. Not on the training import path.
What does this PR do?
Introduces the TokenAligner class and projection-matrix preparation tools that subsequent PRs build on for cross-tokenizer off-policy distillation.
Issues
None linked yet.
Usage
Before your PR is "Ready for review"
- TokenAligner is exercised by the loss tests landing with PR 3 and by the off-policy distillation recipe in PR 5.
- py_compile confirmed clean. CI on this PR will be the first full run.
- docs/index.md will be updated alongside the algorithm PR (PR 5).
Additional Information
First in a 5-PR stack splitting the cross-tokenizer off-policy distillation feature for review: