This repository reproduces the grokking phenomenon described in Power et al., "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets" (arXiv:2201.02177) using modular division modulo 97.
The repo includes:
- training and validation accuracy/loss curves
- long-horizon grokking dynamics
We study the task of modular division: given x and y (with y ≠ 0), predict c = x · y⁻¹ mod 97, i.e. the unique c with c · y ≡ x (mod 97).

Each example is represented as a token sequence:

`[x, ÷, y, =, c]`
The model is trained in a decoder-only, causal language modeling setup, with loss and accuracy computed only on the answer token.
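As a concrete illustration, below is a minimal sketch of how such a dataset can be built. The function name, the token ids chosen for ÷ and =, and the split logic are illustrative assumptions, not the actual main.py API.

```python
import torch

def build_division_dataset(p: int = 97, train_frac: float = 0.5, seed: int = 0):
    """Sketch: enumerate all x ÷ y = c (mod p) examples and split them."""
    DIV, EQ = p, p + 1  # assumed token ids for '÷' and '='
    examples = []
    for x in range(p):
        for y in range(1, p):  # y != 0 so the inverse exists
            c = (x * pow(y, p - 2, p)) % p  # y**(p-2) ≡ y**-1 (mod p), by Fermat
            examples.append([x, DIV, y, EQ, c])
    data = torch.tensor(examples)
    # Random train/validation split at the given fraction.
    perm = torch.randperm(len(data), generator=torch.Generator().manual_seed(seed))
    n_train = int(train_frac * len(data))
    return data[perm[:n_train]], data[perm[n_train:]]
```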
The key phenomenon of interest is grokking:
- training accuracy reaches ~100% early,
- validation accuracy remains near chance for a long time,
- then suddenly rises in a sharp transition after many more optimization steps.
The implementation follows Appendix A.1 of the paper (a training sketch follows this list):

- Architecture: decoder-only Transformer
  - 2 layers
  - width 128
  - 4 attention heads
  - causal self-attention
- Objective: cross-entropy loss on the answer token only
- Optimizer:
  - Adam (no weight decay) for late-grokking runs
  - AdamW (weight decay = 1) for faster, data-efficient variants
- Learning rate: 1e-3
- Warmup: 10 steps
- Batch size: min(512, |train| / 2), as in the paper
- Training budget: up to 1e6 steps for long-horizon grokking experiments
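Putting these hyperparameters together, here is a hypothetical sketch of the model and training setup (class and function names are illustrative, not the actual main.py API; the AdamW variant is shown):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GrokTransformer(nn.Module):
    """Sketch of the Appendix A.1 model: 2 layers, width 128, 4 heads."""

    def __init__(self, vocab_size: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, seq_len: int = 5):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):  # tokens: (batch, seq)
        seq = tokens.size(1)
        h = self.tok_emb(tokens) + self.pos_emb(torch.arange(seq, device=tokens.device))
        # Causal self-attention: -inf above the diagonal blocks future positions.
        causal = torch.full((seq, seq), float("-inf"), device=tokens.device).triu(1)
        return self.head(self.blocks(h, mask=causal))

def answer_token_loss(model, batch):
    """Cross-entropy on the answer position only: predict c from [x, ÷, y, =]."""
    logits = model(batch[:, :-1])
    return F.cross_entropy(logits[:, -1, :], batch[:, -1])

model = GrokTransformer(vocab_size=97 + 2)  # residues plus '÷' and '=' tokens
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
# Linear warmup over the first 10 steps, constant learning rate afterwards.
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, (step + 1) / 10))
```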
### Requirements

- Python 3.13
- PyTorch (CPU, CUDA, or MPS build as appropriate)

Install dependencies with:

```bash
uv sync
```

The following invocation is the recommended setting if your goal is to reproduce the qualitative behavior shown in the paper (delayed generalization with a sharp transition):
```bash
uv run python main.py train \
    --p 97 \
    --train-frac 0.5 \
    --steps 1000000 \
    --eval-every 1000
```

What to expect:
- Training accuracy reaches ~1.0 very early.
- Validation accuracy stays near chance (~1/97) for a long time.
- After tens or hundreds of thousands of steps (seed-dependent), validation accuracy rises rapidly.
- Embeddings gradually develop a strong periodic/Fourier-like structure (a probe sketch follows this list).
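One quick way to look for that periodic structure, continuing the hypothetical GrokTransformer sketch above (the `tok_emb` attribute name is an assumption), is to take a real FFT of the learned residue embeddings:

```python
import torch

# Magnitude spectrum of each embedding dimension across the 97 residue tokens.
emb = model.tok_emb.weight[:97].detach()     # (97, d_model)
spectrum = torch.fft.rfft(emb, dim=0).abs()  # (49, d_model)
# After grokking, a few frequencies should dominate most dimensions.
top_freqs = spectrum[1:].argmax(dim=0) + 1   # skip the DC component
print(torch.bincount(top_freqs, minlength=spectrum.size(0)))
```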
Outputs are written to `runs/division_mod_97/` (or a custom `--out-dir`).
---
### Quick Start
```bash
python3 main.py train \
    --preset paper_late \
    --p 97 \
    --train-frac 0.5 \
    --seed 0 \
    --eval-every 1000 \
    --out-dir runs/division_mod_97_pair_paper_late
```
After training, the following files are produced under `runs/division_mod_97/` (or a custom `--out-dir`):

- `metrics.jsonl`: training and validation loss/accuracy over time.
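Assuming `metrics.jsonl` holds one JSON object per evaluation (the field names below are illustrative guesses, not a documented schema), a run can be inspected like this:

```python
import json

with open("runs/division_mod_97/metrics.jsonl") as f:
    records = [json.loads(line) for line in f]

# Grokking shows up as a late, sharp jump in validation accuracy toward 1.0.
steps = [r["step"] for r in records]       # assumed field name
val_acc = [r["val_acc"] for r in records]  # assumed field name
```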
To plot a finished run:

```bash
uv run python main.py plot --run-dir runs/division_mod_97
```

This generates:

- `plots/accuracy.png`
- `plots/loss.png`
Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv:2201.02177.