From 34302988ff67b7b740502ed03e25be326b3804ad Mon Sep 17 00:00:00 2001 From: Austin Tarango Date: Sat, 21 Mar 2026 10:37:45 -0600 Subject: [PATCH] Add Memory Tokens submission (val_bpb 1.1659, 8xH100 SXM) --- .../2026-03-21_MemoryTokens/README.md | 80 + .../2026-03-21_MemoryTokens/submission.json | 11 + .../2026-03-21_MemoryTokens/train.log | 120 ++ .../2026-03-21_MemoryTokens/train_gpt.py | 1676 +++++++++++++++++ 4 files changed, 1887 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-21_MemoryTokens/README.md create mode 100644 records/track_10min_16mb/2026-03-21_MemoryTokens/submission.json create mode 100644 records/track_10min_16mb/2026-03-21_MemoryTokens/train.log create mode 100644 records/track_10min_16mb/2026-03-21_MemoryTokens/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-21_MemoryTokens/README.md b/records/track_10min_16mb/2026-03-21_MemoryTokens/README.md new file mode 100644 index 000000000..69c8eae4b --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_MemoryTokens/README.md @@ -0,0 +1,80 @@ +# Memory Tokens + Mixed Quantization + +**val_bpb: 1.1659** (sliding window, stride=128, post int5/int6+zstd quantization roundtrip) +**Artifact size: 15,070,662 bytes** | 8xH100 SXM, 600s + +## Novel Contribution: Memory Tokens + +64 learnable embedding vectors that overwrite the first K positions of every input sequence. All real tokens can attend to them via the causal mask, giving every position access to learned global context — a shared scratchpad that the model optimizes end-to-end. + +- **Cost:** 32,768 parameters (0.12% of model), zero compute overhead +- **A/B tested:** -0.014 BPB improvement vs identical config without memory tokens (1.2787 vs 1.2928 sliding, 1xH100) +- **Implementation:** Memory positions use `ignore_index=-100` so they contribute zero to loss. During sliding window eval, memory tokens are prepended (not overwritten) to preserve all real token context +- Memory tokens are exempt from weight decay — they're a learned scratchpad that needs to hold specific values, not be regularized toward zero + +## Architecture + +- 10 transformer layers, 512 dim, 8 heads, 4 KV heads (GQA) +- 3x MLP expansion (hidden=1536), relu^2 activation +- U-Net skip connections, tied embeddings +- **Memory tokens (64):** global context scratchpad prepended to every sequence +- **BigramHashEmbedding (10240):** hash consecutive token pairs for local context +- **SmearGate:** learned blend with previous token at embedding level +- **Partial RoPE (16/64 dims):** position encoding on 25% of head dims, rest is content-only +- **LN Scale:** RMSNorm output damped by 1/sqrt(layer+1) for stability + +## Training + +- Muon optimizer (matrix_lr=0.04, momentum=0.95) + AdamW (embed/scalar, WD=0.04) +- Muon weight decay (0.04), memory tokens exempt from WD +- MTP auxiliary heads (k=2, alpha=0.2, stripped before export) +- EMA (decay=0.997, on-device, every 10 steps) +- Late QAT: fake int6 quantization (STE) when lr_scale < 0.1 +- seq_len=2048, batch=524K tokens, warmdown=3000, grad_clip=0.3 +- 9,030 steps in 600s (64ms/step) + +## Quantization + +- **Int5** [-16,15] for MLP weights (most compressible) +- **Int6** [-32,31] for attention weights (precision-sensitive) +- **FP16** for tied embeddings and small tensors +- **zstd-22** compression (better ratio than zlib) + +## Evaluation + +- Sliding window eval with stride=128, seq_len=1024 +- Batched (256 windows) + torch.compiled forward_logits +- Memory tokens prepended during sliding window (not overwritten) + +## Results + +| Metric | Value | +|--------|-------| +| Pre-quant val_bpb | 1.1842 | +| Int6+zstd roundtrip val_bpb | 1.1820 | +| **Sliding window val_bpb (s128)** | **1.1659** | +| Steps completed (600s cap) | 9,030 | +| Step time | 64ms | +| Model params | 25,812,049 | +| Artifact size | 15,070,662 bytes | + +## Run Command + +```bash +NUM_MEMORY_TOKENS=64 \ +NUM_LAYERS=10 \ +MTP_NUM_HEADS=2 \ +MTP_ALPHA=0.2 \ +MTP_ALPHA_DECAY=1 \ +MTP_HEAD_LR=0.008 \ +TRAIN_SEQ_LEN=2048 \ +EVAL_SEQ_LEN=1024 \ +EVAL_STRIDE=128 \ +FP16_EMBED_EXPORT=1 \ +RUN_ID=submission_8xh100 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +VAL_LOSS_EVERY=1000 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-21_MemoryTokens/submission.json b/records/track_10min_16mb/2026-03-21_MemoryTokens/submission.json new file mode 100644 index 000000000..1a3193916 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_MemoryTokens/submission.json @@ -0,0 +1,11 @@ +{ + "author": "Austin Tarango", + "github_id": "sp00mm", + "name": "Memory Tokens + Mixed Quantization", + "blurb": "64 learnable memory tokens as global context scratchpad, combined with 10-layer 3x MLP, BigramHashEmbedding, SmearGate, partial RoPE, LN scale, EMA, late QAT, mixed int5/int6+zstd quantization, and sliding window eval (stride=128). Memory tokens provide a -0.014 BPB improvement over the same stack without them (A/B tested).", + "date": "2026-03-21T17:32:00Z", + "val_loss": 1.96862490, + "val_bpb": 1.16593150, + "bytes_total": 15070662, + "bytes_code": 72123 +} diff --git a/records/track_10min_16mb/2026-03-21_MemoryTokens/train.log b/records/track_10min_16mb/2026-03-21_MemoryTokens/train.log new file mode 100644 index 000000000..24b62f1a8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_MemoryTokens/train.log @@ -0,0 +1,120 @@ +W0321 16:17:20.557000 57903 torch/distributed/run.py:803] +W0321 16:17:20.557000 57903 torch/distributed/run.py:803] ***************************************** +W0321 16:17:20.557000 57903 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0321 16:17:20.557000 57903 torch/distributed/run.py:803] ***************************************** +logs/submission_8xh100.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26598481 +memory_tokens:64 memory_params:32768 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:8.3195 train_time:120ms step_avg:119.58ms +step:2/20000 train_loss:11.7980 train_time:167ms step_avg:83.74ms +step:3/20000 train_loss:9.7056 train_time:232ms step_avg:77.26ms +step:4/20000 train_loss:9.0286 train_time:295ms step_avg:73.70ms +step:5/20000 train_loss:8.2414 train_time:358ms step_avg:71.70ms +step:6/20000 train_loss:9.0866 train_time:423ms step_avg:70.44ms +step:7/20000 train_loss:7.7866 train_time:486ms step_avg:69.49ms +step:8/20000 train_loss:7.6803 train_time:550ms step_avg:68.77ms +step:9/20000 train_loss:7.4604 train_time:614ms step_avg:68.21ms +step:10/20000 train_loss:7.1326 train_time:690ms step_avg:69.04ms +step:200/20000 train_loss:3.8336 train_time:12916ms step_avg:64.58ms +step:400/20000 train_loss:3.2400 train_time:25793ms step_avg:64.48ms +step:600/20000 train_loss:3.4516 train_time:38668ms step_avg:64.45ms +step:800/20000 train_loss:3.1412 train_time:51580ms step_avg:64.48ms +step:1000/20000 train_loss:3.2528 train_time:64524ms step_avg:64.52ms +step:1000/20000 val_loss:2.3132 val_bpb:1.3700 train_time:64531ms step_avg:64.53ms +step:1200/20000 train_loss:3.2851 train_time:77464ms step_avg:64.55ms +step:1400/20000 train_loss:3.3197 train_time:90401ms step_avg:64.57ms +step:1600/20000 train_loss:2.9129 train_time:103334ms step_avg:64.58ms +step:1800/20000 train_loss:3.0714 train_time:116258ms step_avg:64.59ms +step:2000/20000 train_loss:3.0818 train_time:129169ms step_avg:64.58ms +step:2000/20000 val_loss:2.2352 val_bpb:1.3238 train_time:129174ms step_avg:64.59ms +step:2200/20000 train_loss:3.2205 train_time:142071ms step_avg:64.58ms +step:2400/20000 train_loss:3.2534 train_time:154964ms step_avg:64.57ms +step:2600/20000 train_loss:3.0882 train_time:167851ms step_avg:64.56ms +step:2800/20000 train_loss:3.0362 train_time:180730ms step_avg:64.55ms +step:3000/20000 train_loss:4.2770 train_time:193597ms step_avg:64.53ms +step:3000/20000 val_loss:2.2228 val_bpb:1.3165 train_time:193602ms step_avg:64.53ms +step:3200/20000 train_loss:3.1657 train_time:206462ms step_avg:64.52ms +step:3400/20000 train_loss:2.9456 train_time:219322ms step_avg:64.51ms +step:3600/20000 train_loss:3.1282 train_time:232174ms step_avg:64.49ms +step:3800/20000 train_loss:3.0457 train_time:245026ms step_avg:64.48ms +step:4000/20000 train_loss:3.1794 train_time:257875ms step_avg:64.47ms +step:4000/20000 val_loss:2.1946 val_bpb:1.2998 train_time:257880ms step_avg:64.47ms +step:4200/20000 train_loss:3.1233 train_time:270785ms step_avg:64.47ms +step:4400/20000 train_loss:3.0505 train_time:283626ms step_avg:64.46ms +step:4600/20000 train_loss:3.0924 train_time:296473ms step_avg:64.45ms +step:4800/20000 train_loss:3.0320 train_time:309318ms step_avg:64.44ms +step:5000/20000 train_loss:3.1377 train_time:322160ms step_avg:64.43ms +step:5000/20000 val_loss:2.1819 val_bpb:1.2923 train_time:322165ms step_avg:64.43ms +step:5200/20000 train_loss:3.1860 train_time:335035ms step_avg:64.43ms +step:5400/20000 train_loss:3.1538 train_time:347881ms step_avg:64.42ms +step:5600/20000 train_loss:3.0167 train_time:360713ms step_avg:64.41ms +step:5800/20000 train_loss:3.0814 train_time:373553ms step_avg:64.41ms +step:6000/20000 train_loss:2.9858 train_time:386390ms step_avg:64.40ms +step:6000/20000 val_loss:2.1732 val_bpb:1.2871 train_time:386395ms step_avg:64.40ms +step:6200/20000 train_loss:2.9651 train_time:399234ms step_avg:64.39ms +step:6400/20000 train_loss:2.6521 train_time:412065ms step_avg:64.39ms +step:6600/20000 train_loss:2.8799 train_time:424898ms step_avg:64.38ms +step:6800/20000 train_loss:2.8737 train_time:437739ms step_avg:64.37ms +step:7000/20000 train_loss:2.7599 train_time:450577ms step_avg:64.37ms +step:7000/20000 val_loss:2.1377 val_bpb:1.2661 train_time:450582ms step_avg:64.37ms +step:7200/20000 train_loss:2.5335 train_time:463407ms step_avg:64.36ms +step:7400/20000 train_loss:2.3985 train_time:476244ms step_avg:64.36ms +step:7600/20000 train_loss:2.6284 train_time:489068ms step_avg:64.35ms +step:7800/20000 train_loss:2.5300 train_time:501897ms step_avg:64.35ms +step:8000/20000 train_loss:2.3737 train_time:514731ms step_avg:64.34ms +step:8000/20000 val_loss:2.0806 val_bpb:1.2323 train_time:514736ms step_avg:64.34ms +step:8200/20000 train_loss:2.4815 train_time:527564ms step_avg:64.34ms +step:8400/20000 train_loss:2.3902 train_time:540456ms step_avg:64.34ms +step:8600/20000 train_loss:2.3443 train_time:553294ms step_avg:64.34ms +step:8800/20000 train_loss:2.1107 train_time:566121ms step_avg:64.33ms +step:9000/20000 train_loss:2.0911 train_time:578943ms step_avg:64.33ms +step:9000/20000 val_loss:2.0015 val_bpb:1.1854 train_time:578949ms step_avg:64.33ms +step:9030/20000 val_loss:1.9995 val_bpb:1.1842 train_time:610471ms step_avg:67.60ms +stopping_early: wallclock_cap train_time:610471ms step:9030/20000 +peak memory allocated: 13072 MiB reserved: 13524 MiB +ema: loading averaged weights for export +Serialized model: 101124717 bytes +Code size: 72123 bytes +Total submission size: 101196840 bytes +Serialized model int8+zlib: 17372032 bytes (payload:26466628 raw_torch:26518191 payload_ratio:3.82x) +Total submission size int8+zlib: 17444155 bytes +final_int8_zlib_roundtrip val_loss:1.9872 val_bpb:1.1769 eval_time:10476ms +final_int8_zlib_roundtrip_exact val_loss:1.98718950 val_bpb:1.17692555 +Serialized model int6+zstd: 14998539 bytes +Total submission size int6+zstd: 15070662 bytes +final_int6_zstd_roundtrip val_loss:1.9958 val_bpb:1.1820 eval_time:1995ms +final_int6_zstd_roundtrip_exact val_loss:1.99575404 val_bpb:1.18199796 +Compiling forward_logits for sliding window eval (stride=128, seq_len=1024)... +Compilation done, starting sliding window eval... +sliding_window_eval val_loss:1.9686 val_bpb:1.1659 stride:128 eval_time:23403ms +sliding_window_eval_exact val_loss:1.96862490 val_bpb:1.16593150 diff --git a/records/track_10min_16mb/2026-03-21_MemoryTokens/train_gpt.py b/records/track_10min_16mb/2026-03-21_MemoryTokens/train_gpt.py new file mode 100644 index 000000000..930935d46 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_MemoryTokens/train_gpt.py @@ -0,0 +1,1676 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + + # Memory tokens: learnable tokens prepended to every sequence. + num_memory_tokens = int(os.environ.get("NUM_MEMORY_TOKENS", 0)) + + # Bigram hash embedding: inject token-pair info into the residual stream. + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # EMA: exponential moving average of weights for smoother final model. + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + # Partial RoPE: apply rotary embeddings to only a fraction of head dims. + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) # out of head_dim (64) + + # LN Scale: damp deeper layers by 1/sqrt(layer_idx+1). + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # Late QAT: enable fake int6 quantization (STE) in the final phase of training. + late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) # lr_scale below this triggers QAT + + # Sliding window evaluation. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", train_seq_len)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) # 0 = disabled + + # Multi-token prediction auxiliary heads (training only, stripped before export). + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_alpha = float(os.environ.get("MTP_ALPHA", 0.2)) + mtp_alpha_decay = bool(int(os.environ.get("MTP_ALPHA_DECAY", 1))) + mtp_head_lr = float(os.environ.get("MTP_HEAD_LR", 0.008)) + + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + logits_fn, + num_memory_tokens: int, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + seq_len: int, + stride: int, + eval_batch_seqs: int = 256, +) -> tuple[float, float]: + """Batched sliding window eval: each token scored with near-full context.""" + total = val_tokens.numel() - 1 + score_offset = seq_len - stride + K = num_memory_tokens + + # Build windows + starts: list[int] = [] + pos = 0 + while pos + seq_len <= total: + starts.append(pos) + pos += stride + + # Distribute across ranks + n = len(starts) + per_rank = (n + world_size - 1) // world_size + my_start = rank * per_rank + my_end = min(my_start + per_rank, n) + my_windows = starts[my_start:my_end] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for i in range(0, len(my_windows), eval_batch_seqs): + batch_starts = my_windows[i : i + eval_batch_seqs] + bs = len(batch_starts) + + # Build batch, pad to eval_batch_seqs to avoid recompilation + x_list = [val_tokens[s : s + seq_len] for s in batch_starts] + y_list = [val_tokens[s + 1 : s + seq_len + 1] for s in batch_starts] + pad = eval_batch_seqs - bs + if pad > 0: + x_list.extend([x_list[-1]] * pad) + y_list.extend([y_list[-1]] * pad) + + x = torch.stack(x_list).to(device=device, dtype=torch.int64) + y = torch.stack(y_list).to(device=device, dtype=torch.int64) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = logits_fn(x) + + # forward_logits prepends K memory tokens, so logits are [B, K+seq_len, V]. + for b in range(bs): + tail_logits = logits[b, K + score_offset:, :].float() + tail_targets = y[b, score_offset:] + loss = F.cross_entropy(tail_logits, tail_targets, reduction="sum") + loss_sum += loss.to(torch.float64) + tok_count += float(stride) + + tail_prev = x[b, score_offset:] + tail_tgt = y[b, score_offset:] + tb = base_bytes_lut[tail_tgt].to(dtype=torch.int16) + tb += (has_leading_space_lut[tail_tgt] & ~is_boundary_token_lut[tail_prev]).to(dtype=torch.int16) + byte_count += tb.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / tok_count).item() + bpb = val_loss / math.log(2.0) * (tok_count.item() / byte_count.item()) + return val_loss, bpb + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +FP16_EMBED_EXPORT = bool(int(os.environ.get("FP16_EMBED_EXPORT", "1"))) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Keep tied embedding in fp16 — int8 errors compound through input + output paths. + if FP16_EMBED_EXPORT and "tok_emb" in name: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# MIXED-PRECISION QUANTIZATION (int5/int6/int8 per layer category) +# ----------------------------- +# MLP weights are most compressible → int5 [-16,15] saves ~1.8MB vs int6. +# Attention weights need more precision → int6 [-32,31]. +# Embeddings/small tensors stay fp16. + +# Per-category bit width: maps category → (clip_max, label). +QUANT_BITS = {"int5": (15, "int5"), "int6": (31, "int6"), "int8": (127, "int8")} + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name: + return "attn" + return "other" + +def _quantize_per_row(t: Tensor, clip_max: int) -> tuple[Tensor, Tensor]: + """Symmetric per-row quantization to [-clip_max, clip_max].""" + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_max).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -clip_max, clip_max).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_max, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_max, clip_max).to(torch.int8) + return q, scale + +def mixed_quantize(state_dict: dict[str, Tensor], cat_bits: dict[str, str]): + """Quantize state dict with per-category bit widths. + cat_bits maps category ('mlp','attn','other') to bit label ('int5','int6','int8').""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if FP16_EMBED_EXPORT and "tok_emb" in name: + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + bit_label = cat_bits.get(cat, "int8") + clip_max = QUANT_BITS[bit_label][0] + q, s = _quantize_per_row(t, clip_max) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": bit_label} + return result, meta + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class _FakeQuantInt6STE(torch.autograd.Function): + """Fake int6 quantization with straight-through estimator for Late QAT.""" + @staticmethod + def forward(ctx, w: Tensor) -> Tensor: + w32 = w.float() + abs_max = w32.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) + scale = abs_max / 31.0 + q = torch.clamp(torch.round(w32 / scale), -32, 31) + return (q * scale).to(w.dtype) + @staticmethod + def backward(ctx, grad_output: Tensor) -> Tensor: + return grad_output + +_fake_quant_int6 = _FakeQuantInt6STE.apply + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + qat: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if self.qat and self.training and w.ndim == 2: + w = _fake_quant_int6(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # Partial RoPE: only apply rotary to first rope_dims of each head. + # Remaining dims use position-free attention (content-only). + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + rd = self.rope_dims + if rd < self.head_dim: + # Apply RoPE only to first rd dims, leave the rest position-free. + q_rope = apply_rotary_emb(q[..., :rd], cos, sin) + k_rope = apply_rotary_emb(k[..., :rd], cos, sin) + q = torch.cat([q_rope, q[..., rd:]], dim=-1) + k = torch.cat([k_rope, k[..., rd:]], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's — cheap local context.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table. + Injects bigram-level info into the residual stream at negligible compute cost.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + ln_scale: float = 1.0, + ): + super().__init__() + self.ln_scale = ln_scale + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale + attn_out = self.attn(self.attn_norm(x) * s) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +class MTPHeads(nn.Module): + """K auxiliary linear heads for multi-token prediction. Training only.""" + + def __init__(self, k: int, model_dim: int, vocab_size: int, logit_softcap: float): + super().__init__() + self.logit_softcap = logit_softcap + self.heads = nn.ModuleList([CastedLinear(model_dim, vocab_size, bias=False) for _ in range(k)]) + for h in self.heads: + nn.init.zeros_(h.weight) + + def forward(self, hidden: Tensor, shifted_targets: list[Tensor]) -> Tensor: + total_loss = torch.zeros((), device=hidden.device, dtype=torch.float32) + cap = self.logit_softcap + for head, targets in zip(self.heads, shifted_targets): + logits = cap * torch.tanh(head(hidden) / cap) + total_loss = total_loss + F.cross_entropy(logits.float(), targets, ignore_index=-100, reduction="mean") + return total_loss / len(self.heads) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_memory_tokens: int = 0, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_memory_tokens = num_memory_tokens + # Learnable memory tokens prepended to every sequence as global context. + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(1, num_memory_tokens, model_dim) * 0.02) + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + ln_scale=1.0 / math.sqrt(i + 1) if ln_scale else 1.0, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads: MTPHeads | None = None + self.register_buffer("_mtp_alpha", torch.zeros((), dtype=torch.float32), persistent=False) + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + bsz, seqlen = input_ids.shape + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + + # Overwrite first K positions with learnable memory tokens. + # All later real tokens attend to them via causal mask. + # Target positions for memory slots are set to -100 (ignored by cross_entropy). + K = self.num_memory_tokens + if K > 0: + mem = self.memory_tokens.expand(bsz, -1, -1).to(dtype=x.dtype) + mem = F.rms_norm(mem, (mem.size(-1),)) + x = x.clone() + x[:, :K, :] = mem + target_ids = target_ids.clone() + target_ids[:, :K] = -100 + + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x) + hidden_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(hidden_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(hidden_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + primary_loss = F.cross_entropy(logits.float(), targets, ignore_index=-100, reduction="mean") + + if self.mtp_heads is not None and self.training: + num_mtp = len(self.mtp_heads.heads) + trunc_len = seqlen - num_mtp + hidden_trunc = x[:, :trunc_len, :].reshape(-1, x.size(-1)) + shifted_targets = [] + for k in range(1, num_mtp + 1): + t = target_ids[:, k : trunc_len + k].reshape(-1) + shifted_targets.append(t) + aux_loss = self.mtp_heads(hidden_trunc, shifted_targets) + return primary_loss + self._mtp_alpha * aux_loss + + return primary_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return per-token logits [B, K+T, V] with softcap. Used for sliding window eval. + When memory tokens are active, they are prepended (not overwritten) so that + all real tokens retain their context and can attend to the learned scratchpad.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + K = self.num_memory_tokens + if K > 0: + bsz = input_ids.shape[0] + mem = self.memory_tokens.expand(bsz, -1, -1).to(dtype=x.dtype) + mem = F.rms_norm(mem, (mem.size(-1),)) + x = torch.cat([mem, x], dim=1) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_memory_tokens=args.num_memory_tokens, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ).to(device).bfloat16() + if args.mtp_num_heads > 0: + base_model.mtp_heads = MTPHeads( + k=args.mtp_num_heads, + model_dim=args.model_dim, + vocab_size=args.vocab_size, + logit_softcap=args.logit_softcap, + ).to(device).bfloat16() + base_model._mtp_alpha.fill_(args.mtp_alpha) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params: list[dict] = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + # Memory tokens get their own optimizer with NO weight decay — + # they're a learned scratchpad that needs to hold specific values. + if args.num_memory_tokens > 0: + optimizer_mem = torch.optim.Adam( + [{"params": [base_model.memory_tokens], "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_mem) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + if base_model.mtp_heads is not None: + optimizer_mtp = torch.optim.Adam( + [{"params": list(base_model.mtp_heads.parameters()), "lr": args.mtp_head_lr, "base_lr": args.mtp_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_mtp) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + if args.num_memory_tokens > 0: + mem_params = base_model.memory_tokens.numel() + log0(f"memory_tokens:{args.num_memory_tokens} memory_params:{mem_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + + # EMA: keep an exponential moving average of weights for a smoother final model. + # Stored on-device to avoid CPU copy overhead; updated every 10 steps. + ema_sd: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_sd = {k: v.detach().clone().float() for k, v in base_model.state_dict().items()} + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.mtp_num_heads > 0 and args.mtp_alpha_decay: + base_model._mtp_alpha.fill_(args.mtp_alpha * scale) + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + # Decoupled weight decay for Muon-optimized matrix params (Muon doesn't have built-in WD) + if args.muon_weight_decay > 0: + with torch.no_grad(): + for p in matrix_params: + p.mul_(1.0 - args.muon_weight_decay * optimizer_muon.param_groups[0]["lr"]) + zero_grad_all() + + step += 1 + + # EMA: update running average every 10 steps (on-device, no CPU copy). + if ema_sd is not None and step % 10 == 0: + d = args.ema_decay ** 10 + with torch.no_grad(): + for k, v in base_model.state_dict().items(): + ema_sd[k].mul_(d).add_(v.detach().float(), alpha=1.0 - d) + + # Late QAT: enable fake int6 quantization when LR is low enough. + if args.late_qat and scale < args.qat_threshold: + for m in base_model.modules(): + if isinstance(m, CastedLinear): + m.qat = True + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Apply EMA weights for export — smoother than final training weights. + if ema_sd is not None: + log0("ema: loading averaged weights for export") + avg_sd = {k: v.to(dtype=base_model.state_dict()[k].dtype, device=base_model.state_dict()[k].device) + for k, v in ema_sd.items()} + base_model.load_state_dict(avg_sd, strict=True) + + # Disable QAT for eval. + for m in base_model.modules(): + if isinstance(m, CastedLinear): + m.qat = False + + # Strip MTP auxiliary heads before export — training only. + if base_model.mtp_heads is not None: + base_model.mtp_heads = None + base_model._mtp_alpha.fill_(0.0) + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + # Allow recompilation for variable batch sizes during post-training eval. + torch._dynamo.config.cache_size_limit = 64 + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # --- MIXED INT5/INT6 QUANTIZATION + ZSTD EXPORT --- + # MLP → int5 (most compressible), attention → int6 (needs precision) + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result6, quant_meta6 = mixed_quantize( + sd_cpu, {"mlp": "int5", "attn": "int6", "other": "int8"} + ) + quant_buf6 = io.BytesIO() + torch.save({"w": quant_result6, "m": quant_meta6}, quant_buf6) + quant_raw6 = quant_buf6.getvalue() + if _COMPRESSOR == "zstd": + quant_blob6 = zstandard.ZstdCompressor(level=22).compress(quant_raw6) + else: + quant_blob6 = zlib.compress(quant_raw6, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob6) + quant_file_bytes6 = os.path.getsize("final_model.int6.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes6} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes6 + code_bytes} bytes") + if distributed: + dist.barrier() + # Roundtrip int6 model for sliding window eval + with open("final_model.int6.ptz", "rb") as f: + quant_blob6_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed6 = zstandard.ZstdDecompressor().decompress(quant_blob6_disk) + else: + decompressed6 = zlib.decompress(quant_blob6_disk) + quant_state6 = torch.load(io.BytesIO(decompressed6), map_location="cpu") + base_model.load_state_dict(dequantize_mixed(quant_state6["w"], quant_state6["m"], sd_cpu), strict=True) + torch.cuda.synchronize() + t_q6eval = time.perf_counter() + q6_val_loss, q6_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_{_COMPRESSOR}_roundtrip val_loss:{q6_val_loss:.4f} val_bpb:{q6_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_q6eval):.0f}ms" + ) + log0(f"final_int6_{_COMPRESSOR}_roundtrip_exact val_loss:{q6_val_loss:.8f} val_bpb:{q6_val_bpb:.8f}") + + if args.eval_stride > 0: + eval_batch_seqs = 256 + log0(f"Compiling forward_logits for sliding window eval (stride={args.eval_stride}, seq_len={args.eval_seq_len})...") + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) + # Warmup compilation + base_model.eval() + warmup_x = torch.zeros(eval_batch_seqs, args.eval_seq_len, dtype=torch.int64, device=device) + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + _ = compiled_logits(warmup_x) + log0("Compilation done, starting sliding window eval...") + torch.cuda.synchronize() + t_slide = time.perf_counter() + slide_val_loss, slide_val_bpb = eval_val_sliding( + compiled_logits, + base_model.num_memory_tokens, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + args.eval_seq_len, + args.eval_stride, + eval_batch_seqs=eval_batch_seqs, + ) + base_model.train() + torch.cuda.synchronize() + log0( + f"sliding_window_eval val_loss:{slide_val_loss:.4f} val_bpb:{slide_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"sliding_window_eval_exact val_loss:{slide_val_loss:.8f} val_bpb:{slide_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()