Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 310 additions & 0 deletions examples/configs/off_policy_distillation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
token_aligner:
enabled: false

teachers:
# Teacher 0: Phi-4-mini-instruct (weight 0.5)
# Mirrors tokenalign/teacher_configs/multi_teacher_config_phi-4B_llama-3.1-4b_best-proj_not_learned.json
- weight: 0.5
teacher:
model_name: "microsoft/Phi-4-mini-instruct"
tokenizer:
name: "microsoft/Phi-4-mini-instruct"
chat_template: null
precision: "bfloat16"
train_global_batch_size: 768
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
max_grad_norm: 1.0
logprob_chunk_size: null
offload_optimizer_for_logprob: false
dtensor_cfg:
enabled: true
_v2: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
sequence_length_round: 64
sequence_packing:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64
optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-5
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
foreach: false
fused: false
generation: null
token_aligner:
enabled: true
projection_matrix_path: "cross_tokenizer_data/llama_phi-mini_best_special_exact_map_remapped.pt"
use_sparse_format: false
loss_type: "KL"
exact_token_match_only: false
temperature: 1.0
vocab_topk: 8192
reverse_kl: false
projection_matrix_multiplier: 1.0
max_comb_len: 4
learnable: false
project_teacher_to_student: false
use_char_offset: false
use_cuda_dp: false
dp_chunk_size: 128

# Teacher 1: Llama-3.2-3B (weight 0.5)
# Same tokenizer family as the Llama-3.2-1B student; uses the identity/exact-map
# remapped projection so the cross-tokenizer path reduces to exact-token alignment.
- weight: 0.5
teacher:
model_name: "meta-llama/Llama-3.2-3B"
tokenizer:
name: "meta-llama/Llama-3.2-3B"
chat_template: null
precision: "bfloat16"
train_global_batch_size: 768
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
max_grad_norm: 1.0
logprob_chunk_size: null
offload_optimizer_for_logprob: false
dtensor_cfg:
enabled: true
_v2: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
sequence_length_round: 64
sequence_packing:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64
optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-5
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
foreach: false
fused: false
generation: null
# No token_aligner: same tokenizer as student → _compute_same_tokenizer_kl path

distillation:
num_prompts_per_step: 768
max_num_steps: 80000
max_num_epochs: 1
val_period: 1000
val_at_start: false
val_micro_batch_size: 1
topk_logits_k: 8192
use_ipc: true
seed: 42
# Skip per-step model/optimizer offloads in the off-policy distillation loop.
# Requires student + all teachers + student optimizer state to fit resident on
# each GPU. With Llama-3.2-1B student + Phi-4-mini + Llama-3.2-3B teachers on
# 80GB cards this fits comfortably (~35-45GB peak).
keep_models_resident: false

loss_fn:
loss_type: "KL"
temperature: 1.0
vocab_topk: 8192
exact_token_match_only: false
reverse_kl: false
project_teacher_to_student: false
gold_loss: true
xtoken_loss: true
ce_loss_scale: 0.1
dynamic_loss_scaling: true
normalize_by_vocab: true
teacher_aggregation_mode: "weighted"

checkpointing:
enabled: true
checkpoint_dir: "checkpoints/multi-teacher-distillation-llama1b"
metric_name: "train:loss"
higher_is_better: false
keep_top_k: 3
save_period: 10
save_optimizer: true
model_save_format: "safetensors"
save_consolidated: false

policy:
model_name: "meta-llama/Llama-3.2-1B"
tokenizer:
name: "meta-llama/Llama-3.2-1B"
chat_template: null
train_global_batch_size: 768
train_micro_batch_size: 1
max_total_sequence_length: 4096
precision: "bfloat16"
offload_optimizer_for_logprob: false
dtensor_cfg:
enabled: true
_v2: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
max_grad_norm: 1.0
dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
sequence_length_round: 64
sequence_packing:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64
optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-5
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
foreach: false
fused: false
# Matches PyTorch reference: 5% linear warmup + cosine decay to min_lr=0.
# Tuned for a 1000-step production run (warmup=50, cosine_T_max=950).
# For different total steps, scale both counts ∝ distillation.max_num_steps.
scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
# Recent PyTorch enforces 0 < start_factor <= 1 (older versions allowed 0.0).
# 1e-8 is effectively zero-warmup but satisfies the constraint.
start_factor: 1.0e-8
end_factor: 1.0
total_iters: 50
- name: "torch.optim.lr_scheduler.CosineAnnealingLR"
kwargs:
T_max: 950
eta_min: 0.0
- milestones: [50]
sequence_packing:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64
dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
sequence_length_round: 64
generation: null

teacher:
model_name: "microsoft/Phi-4-mini-instruct"

data:
max_input_seq_length: 4096
shuffle: true
# DataLoader workers that run teacher tokenize + DP alignment in parallel
# via CrossTokenizerCollator. 8 workers × 4 prefetch = up to 32 batches in
# flight, fully hiding CT behind teacher forward.
num_workers: 8
prefetch_factor: 4
train:
dataset_name: "arrow_text"
processor: "kd_data_processor"
arrow_files: null # Override via Hydra CLI: data.train.arrow_files=/path/to/file.arrow
prompt_file: null
characters_per_sample: 32768 # 4096 tokens × 8 chars/token (lazy packing)
default:
dataset_path: "allenai/c4"
hf_dataset_name: "allenai/c4"
hf_dataset_subset: "en"
hf_split: "train"
text_key: "text"

eval:
val_period: 50
val_at_start: false
max_val_samples: 512
val_batch_size: 64
max_rollout_turns: 1
benchmarks:
math:
dataset_name: "math"
prompt_file: "examples/prompts/cot.txt"
env:
num_workers: 8
mmlu:
dataset_name: "mmlu"
prompt_file: "examples/prompts/mmlu.txt"
env:
num_workers: 8
verifier_type: "multilingual_multichoice"
mmlu_5shot:
dataset_name: "mmlu"
prompt_file: "examples/prompts/mmlu.txt"
num_few_shot: 5
env:
num_workers: 8
verifier_type: "multilingual_multichoice"
mbpp_plus:
dataset_name: "mbpp_plus"
# Optional override:
# dataset_path: "evalplus/mbppplus"
split: "test"
env:
num_workers: 8
timeout_seconds: 5
humaneval_plus:
dataset_name: "humaneval_plus"
# Optional override:
# dataset_path: "evalplus/humanevalplus"
split: "test"
env:
num_workers: 8
timeout_seconds: 5

logger:
log_dir: "logs/multi-teacher-distillation-llama1b"
num_val_samples_to_print: 5
wandb_enabled: true
swanlab_enabled: false
mlflow_enabled: false
tensorboard_enabled: false
monitor_gpus: true
wandb:
project: "nemo-multi-teacher-distillation"
name: "multi-teacher-llama1b"
gpu_monitoring:
collection_interval: 10
flush_interval: 10

cluster:
gpus_per_node: 8
num_nodes: 16
Loading
Loading