80 changes: 80 additions & 0 deletions records/track_10min_16mb/2026-03-21_MemoryTokens/README.md
# Memory Tokens + Mixed Quantization

**val_bpb: 1.1659** (sliding window, stride=128, post int5/int6+zstd quantization roundtrip)
**Artifact size: 15,070,662 bytes** | 8xH100 SXM, 600s

## Novel Contribution: Memory Tokens

64 learnable embedding vectors that overwrite the first K positions of every input sequence. All real tokens can attend to them via the causal mask, giving every position access to learned global context — a shared scratchpad that the model optimizes end-to-end.

- **Cost:** 32,768 parameters (0.12% of model); zero compute overhead, since memory tokens overwrite existing positions rather than extending the sequence
- **A/B tested:** -0.014 BPB improvement vs identical config without memory tokens (1.2787 vs 1.2928 sliding, 1xH100)
- **Implementation:** Memory positions use `ignore_index=-100` so they contribute zero to loss. During sliding window eval, memory tokens are prepended (not overwritten) to preserve all real token context
- Memory tokens are exempt from weight decay — they're a learned scratchpad that needs to hold specific values, not be regularized toward zero
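The mechanism above can be sketched in a few lines. This is a minimal illustration, not the repo's actual code: the class and function names are invented, and initialization details are assumed.

```python
import torch
import torch.nn as nn

class MemoryTokens(nn.Module):
    """Sketch: K learnable vectors that overwrite the first K positions of
    every input sequence, so all later positions can attend to them under
    the causal mask. Names and init scale are illustrative assumptions."""
    def __init__(self, num_tokens: int = 64, dim: int = 512):
        super().__init__()
        self.mem = nn.Parameter(torch.randn(num_tokens, dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim) token embeddings
        k = self.mem.shape[0]
        x = x.clone()
        x[:, :k] = self.mem  # overwrite, not prepend: sequence length unchanged
        return x

def mask_memory_targets(targets: torch.Tensor, k: int = 64) -> torch.Tensor:
    # memory positions contribute zero loss via cross_entropy's ignore_index
    t = targets.clone()
    t[:, :k] = -100
    return t
```

At 64 tokens and 512 dims this is exactly the 32,768 parameters quoted above; excluding `mem` from the weight-decay parameter group is a one-line change in the optimizer setup.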

## Architecture

- 10 transformer layers, 512 dim, 8 heads, 4 KV heads (GQA)
- 3x MLP expansion (hidden=1536), relu^2 activation
- U-Net skip connections, tied embeddings
- **Memory tokens (64):** global context scratchpad prepended to every sequence
- **BigramHashEmbedding (10240):** hash consecutive token pairs for local context
- **SmearGate:** learned blend with previous token at embedding level
- **Partial RoPE (16/64 dims):** position encoding on 25% of head dims, rest is content-only
- **LN Scale:** RMSNorm output damped by 1/sqrt(layer+1) for stability
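Of the components above, the bigram embedding is the least standard; a minimal sketch of the idea, with an assumed (not the repo's) hash function:

```python
import torch
import torch.nn as nn

class BigramHashEmbedding(nn.Module):
    """Sketch: hash each (previous, current) token pair into a fixed-size
    table so the model gets a cheap local-context signal alongside the
    unigram embedding. The mixing constant is an illustrative choice."""
    def __init__(self, num_buckets: int = 10240, dim: int = 512):
        super().__init__()
        self.table = nn.Embedding(num_buckets, dim)
        self.num_buckets = num_buckets

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq) token ids; pair each token with its predecessor
        prev = torch.roll(ids, 1, dims=1)
        prev[:, 0] = 0  # first position has no predecessor
        h = (prev * 1000003 + ids) % self.num_buckets  # cheap mixing hash
        return self.table(h)
```

The returned embedding would typically be added to (or gated with) the unigram embedding before the first transformer block.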

## Training

- Muon optimizer (matrix_lr=0.04, momentum=0.95) + AdamW (embed/scalar, WD=0.04)
- Muon weight decay (0.04), memory tokens exempt from WD
- MTP auxiliary heads (k=2, alpha=0.2, stripped before export)
- EMA (decay=0.997, on-device, every 10 steps)
- Late QAT: fake int6 quantization (STE) when lr_scale < 0.1
- seq_len=2048, batch=524K tokens, warmdown=3000, grad_clip=0.3
- 9,030 steps in 600s (64ms/step)
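The late-QAT step above (fake int6 once the learning rate has decayed below 10% of peak) relies on a straight-through estimator. A minimal sketch, assuming per-tensor symmetric scaling:

```python
import torch

def fake_quant_int6_ste(w: torch.Tensor) -> torch.Tensor:
    """Sketch of fake int6 quantization with a straight-through estimator:
    the forward pass sees weights snapped to the int6 grid [-32, 31] * scale,
    while the backward pass treats the op as identity, so gradients flow
    through unchanged. Per-tensor symmetric scale is an assumption."""
    scale = w.detach().abs().max().clamp(min=1e-8) / 31.0
    q = (w / scale).round().clamp(-32, 31) * scale
    # STE trick: value of q in forward, gradient of w in backward
    return w + (q - w).detach()
```

Training the final steps against the quantized forward pass is what keeps the int6+zstd roundtrip loss close to the full-precision loss (1.1820 vs 1.1842 in the results table).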

## Quantization

- **Int5** [-16,15] for MLP weights (most compressible)
- **Int6** [-32,31] for attention weights (precision-sensitive)
- **FP16** for tied embeddings and small tensors
- **zstd-22** compression (better ratio than zlib)
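The export path can be sketched as a quantize-compress-decompress-dequantize roundtrip. This is an illustration only: it uses stdlib `zlib` as a stand-in for zstd level 22, and per-tensor symmetric scaling is an assumption.

```python
import zlib
import numpy as np

def quantize_roundtrip(w: np.ndarray, bits: int = 5):
    """Sketch: symmetric quantization to a [-2^(b-1), 2^(b-1)-1] grid
    (int5 -> [-16, 15], int6 -> [-32, 31]), stored as int8 bytes and
    compressed. zlib stands in for the record's zstd-22 here."""
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    amax = np.abs(w).max()
    scale = amax / hi if amax > 0 else 1.0
    q = np.clip(np.round(w / scale), lo, hi).astype(np.int8)
    blob = zlib.compress(q.tobytes(), 9)
    # decode path: decompress, reinterpret as int8, rescale to float
    back = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape)
    return back.astype(np.float32) * scale, len(blob)
```

Reconstruction error is bounded by half a quantization step (`scale / 2`), which is why the precision-sensitive attention weights get the finer int6 grid while MLP weights take int5.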

## Evaluation

- Sliding window eval with stride=128, seq_len=1024
- Batched (256 windows per forward) with a `torch.compile`d `forward_logits`
- Memory tokens prepended during sliding window (not overwritten)
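The stride-128 scheme can be sketched as follows. This is a generic sliding-window scorer, not the repo's implementation: batching, memory-token prepending, and the nats-to-bpb conversion are omitted, and `model` is assumed to return logits of shape `(1, T, vocab)`.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, ids, seq_len=1024, stride=128):
    """Sketch: advance a seq_len window by `stride` tokens at a time; each
    target is scored exactly once, with up to seq_len-stride tokens of left
    context. bpb would be nll / (ln(2) * avg bytes per token)."""
    n = ids.numel()
    nll, counted, prev_end = 0.0, 0, 0
    for start in range(0, n - 1, stride):
        end = min(start + seq_len, n - 1)
        y = ids[start + 1:end + 1].clone()
        trg_len = end - prev_end           # targets new to this window
        y[: len(y) - trg_len] = -100       # mask already-scored context
        logits = model(ids[start:end].unsqueeze(0))
        nll += F.cross_entropy(logits.view(-1, logits.size(-1)), y,
                               ignore_index=-100, reduction="sum").item()
        counted += (y != -100).sum().item()
        prev_end = end
        if end == n - 1:
            break
    return nll / counted  # mean nats per scored token
```

A smaller stride gives every token more context at the cost of more forward passes, which is why the batched, compiled `forward_logits` matters: the full stride-128 eval still finishes in about 23s in the log below.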

## Results

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1842 |
| Int6+zstd roundtrip val_bpb | 1.1820 |
| **Sliding window val_bpb (s128)** | **1.1659** |
| Steps completed (600s cap) | 9,030 |
| Step time | 64ms |
| Model params | 25,812,049 |
| Artifact size | 15,070,662 bytes |

## Run Command

```bash
NUM_MEMORY_TOKENS=64 \
NUM_LAYERS=10 \
MTP_NUM_HEADS=2 \
MTP_ALPHA=0.2 \
MTP_ALPHA_DECAY=1 \
MTP_HEAD_LR=0.008 \
TRAIN_SEQ_LEN=2048 \
EVAL_SEQ_LEN=1024 \
EVAL_STRIDE=128 \
FP16_EMBED_EXPORT=1 \
RUN_ID=submission_8xh100 \
DATA_PATH=./data/datasets/fineweb10B_sp1024/ \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
VAL_LOSS_EVERY=1000 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
11 changes: 11 additions & 0 deletions records/track_10min_16mb/2026-03-21_MemoryTokens/submission.json
{
"author": "Austin Tarango",
"github_id": "sp00mm",
"name": "Memory Tokens + Mixed Quantization",
"blurb": "64 learnable memory tokens as global context scratchpad, combined with 10-layer 3x MLP, BigramHashEmbedding, SmearGate, partial RoPE, LN scale, EMA, late QAT, mixed int5/int6+zstd quantization, and sliding window eval (stride=128). Memory tokens provide a -0.014 BPB improvement over the same stack without them (A/B tested).",
"date": "2026-03-21T17:32:00Z",
"val_loss": 1.96862490,
"val_bpb": 1.16593150,
"bytes_total": 15070662,
"bytes_code": 72123
}
120 changes: 120 additions & 0 deletions records/track_10min_16mb/2026-03-21_MemoryTokens/train.log
W0321 16:17:20.557000 57903 torch/distributed/run.py:803]
W0321 16:17:20.557000 57903 torch/distributed/run.py:803] *****************************************
W0321 16:17:20.557000 57903 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0321 16:17:20.557000 57903 torch/distributed/run.py:803] *****************************************
logs/submission_8xh100.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:10
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26598481
memory_tokens:64 memory_params:32768
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:8.3195 train_time:120ms step_avg:119.58ms
step:2/20000 train_loss:11.7980 train_time:167ms step_avg:83.74ms
step:3/20000 train_loss:9.7056 train_time:232ms step_avg:77.26ms
step:4/20000 train_loss:9.0286 train_time:295ms step_avg:73.70ms
step:5/20000 train_loss:8.2414 train_time:358ms step_avg:71.70ms
step:6/20000 train_loss:9.0866 train_time:423ms step_avg:70.44ms
step:7/20000 train_loss:7.7866 train_time:486ms step_avg:69.49ms
step:8/20000 train_loss:7.6803 train_time:550ms step_avg:68.77ms
step:9/20000 train_loss:7.4604 train_time:614ms step_avg:68.21ms
step:10/20000 train_loss:7.1326 train_time:690ms step_avg:69.04ms
step:200/20000 train_loss:3.8336 train_time:12916ms step_avg:64.58ms
step:400/20000 train_loss:3.2400 train_time:25793ms step_avg:64.48ms
step:600/20000 train_loss:3.4516 train_time:38668ms step_avg:64.45ms
step:800/20000 train_loss:3.1412 train_time:51580ms step_avg:64.48ms
step:1000/20000 train_loss:3.2528 train_time:64524ms step_avg:64.52ms
step:1000/20000 val_loss:2.3132 val_bpb:1.3700 train_time:64531ms step_avg:64.53ms
step:1200/20000 train_loss:3.2851 train_time:77464ms step_avg:64.55ms
step:1400/20000 train_loss:3.3197 train_time:90401ms step_avg:64.57ms
step:1600/20000 train_loss:2.9129 train_time:103334ms step_avg:64.58ms
step:1800/20000 train_loss:3.0714 train_time:116258ms step_avg:64.59ms
step:2000/20000 train_loss:3.0818 train_time:129169ms step_avg:64.58ms
step:2000/20000 val_loss:2.2352 val_bpb:1.3238 train_time:129174ms step_avg:64.59ms
step:2200/20000 train_loss:3.2205 train_time:142071ms step_avg:64.58ms
step:2400/20000 train_loss:3.2534 train_time:154964ms step_avg:64.57ms
step:2600/20000 train_loss:3.0882 train_time:167851ms step_avg:64.56ms
step:2800/20000 train_loss:3.0362 train_time:180730ms step_avg:64.55ms
step:3000/20000 train_loss:4.2770 train_time:193597ms step_avg:64.53ms
step:3000/20000 val_loss:2.2228 val_bpb:1.3165 train_time:193602ms step_avg:64.53ms
step:3200/20000 train_loss:3.1657 train_time:206462ms step_avg:64.52ms
step:3400/20000 train_loss:2.9456 train_time:219322ms step_avg:64.51ms
step:3600/20000 train_loss:3.1282 train_time:232174ms step_avg:64.49ms
step:3800/20000 train_loss:3.0457 train_time:245026ms step_avg:64.48ms
step:4000/20000 train_loss:3.1794 train_time:257875ms step_avg:64.47ms
step:4000/20000 val_loss:2.1946 val_bpb:1.2998 train_time:257880ms step_avg:64.47ms
step:4200/20000 train_loss:3.1233 train_time:270785ms step_avg:64.47ms
step:4400/20000 train_loss:3.0505 train_time:283626ms step_avg:64.46ms
step:4600/20000 train_loss:3.0924 train_time:296473ms step_avg:64.45ms
step:4800/20000 train_loss:3.0320 train_time:309318ms step_avg:64.44ms
step:5000/20000 train_loss:3.1377 train_time:322160ms step_avg:64.43ms
step:5000/20000 val_loss:2.1819 val_bpb:1.2923 train_time:322165ms step_avg:64.43ms
step:5200/20000 train_loss:3.1860 train_time:335035ms step_avg:64.43ms
step:5400/20000 train_loss:3.1538 train_time:347881ms step_avg:64.42ms
step:5600/20000 train_loss:3.0167 train_time:360713ms step_avg:64.41ms
step:5800/20000 train_loss:3.0814 train_time:373553ms step_avg:64.41ms
step:6000/20000 train_loss:2.9858 train_time:386390ms step_avg:64.40ms
step:6000/20000 val_loss:2.1732 val_bpb:1.2871 train_time:386395ms step_avg:64.40ms
step:6200/20000 train_loss:2.9651 train_time:399234ms step_avg:64.39ms
step:6400/20000 train_loss:2.6521 train_time:412065ms step_avg:64.39ms
step:6600/20000 train_loss:2.8799 train_time:424898ms step_avg:64.38ms
step:6800/20000 train_loss:2.8737 train_time:437739ms step_avg:64.37ms
step:7000/20000 train_loss:2.7599 train_time:450577ms step_avg:64.37ms
step:7000/20000 val_loss:2.1377 val_bpb:1.2661 train_time:450582ms step_avg:64.37ms
step:7200/20000 train_loss:2.5335 train_time:463407ms step_avg:64.36ms
step:7400/20000 train_loss:2.3985 train_time:476244ms step_avg:64.36ms
step:7600/20000 train_loss:2.6284 train_time:489068ms step_avg:64.35ms
step:7800/20000 train_loss:2.5300 train_time:501897ms step_avg:64.35ms
step:8000/20000 train_loss:2.3737 train_time:514731ms step_avg:64.34ms
step:8000/20000 val_loss:2.0806 val_bpb:1.2323 train_time:514736ms step_avg:64.34ms
step:8200/20000 train_loss:2.4815 train_time:527564ms step_avg:64.34ms
step:8400/20000 train_loss:2.3902 train_time:540456ms step_avg:64.34ms
step:8600/20000 train_loss:2.3443 train_time:553294ms step_avg:64.34ms
step:8800/20000 train_loss:2.1107 train_time:566121ms step_avg:64.33ms
step:9000/20000 train_loss:2.0911 train_time:578943ms step_avg:64.33ms
step:9000/20000 val_loss:2.0015 val_bpb:1.1854 train_time:578949ms step_avg:64.33ms
step:9030/20000 val_loss:1.9995 val_bpb:1.1842 train_time:610471ms step_avg:67.60ms
stopping_early: wallclock_cap train_time:610471ms step:9030/20000
peak memory allocated: 13072 MiB reserved: 13524 MiB
ema: loading averaged weights for export
Serialized model: 101124717 bytes
Code size: 72123 bytes
Total submission size: 101196840 bytes
Serialized model int8+zlib: 17372032 bytes (payload:26466628 raw_torch:26518191 payload_ratio:3.82x)
Total submission size int8+zlib: 17444155 bytes
final_int8_zlib_roundtrip val_loss:1.9872 val_bpb:1.1769 eval_time:10476ms
final_int8_zlib_roundtrip_exact val_loss:1.98718950 val_bpb:1.17692555
Serialized model int6+zstd: 14998539 bytes
Total submission size int6+zstd: 15070662 bytes
final_int6_zstd_roundtrip val_loss:1.9958 val_bpb:1.1820 eval_time:1995ms
final_int6_zstd_roundtrip_exact val_loss:1.99575404 val_bpb:1.18199796
Compiling forward_logits for sliding window eval (stride=128, seq_len=1024)...
Compilation done, starting sliding window eval...
sliding_window_eval val_loss:1.9686 val_bpb:1.1659 stride:128 eval_time:23403ms
sliding_window_eval_exact val_loss:1.96862490 val_bpb:1.16593150