Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
__pycache__/
*.pyc
checkpoints/
*.pt
.DS_Store
data/
data/raw/
data/tokenized/
*.cache
.pytest_cache/
76 changes: 59 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,68 @@ pip install -r requirements.txt

## Quick start

**Verify your setup** (no data download, runs in seconds):
**1. Verify your setup** (no data download, runs in seconds):
```bash
python scripts/smoke_test.py
```

This checks imports, model construction, a forward/backward pass on dummy data, and a
short generation loop. Run it after setup or whenever you change the architecture.
**2. Download and tokenize the dataset:**
```bash
python scripts/prepare_data.py
```

This downloads [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
(~500 MB) and tokenizes it. The raw stories are saved as JSONL files you can inspect,
and the tokenized data is saved as `.pt` tensors:

```
data/
├── raw/
│ ├── train.jsonl # 2.1M stories as readable JSON
│ └── validation.jsonl # 22K stories
└── tokenized/
├── train.pt # ~472M tokens as a flat tensor
└── validation.pt # ~4.7M tokens
```

You can inspect the raw data with:
```bash
head -5 data/raw/train.jsonl | python -m json.tool
```

Or the tokenized data in Python:
```python
import torch
from minillm.tokenizer import Tokenizer
data = torch.load("data/tokenized/train.pt", weights_only=True)
tok = Tokenizer(10_000)
print(tok.decode(data[:100].tolist())) # first 100 tokens
```

**3. Train the model:**
```bash
python scripts/train.py # full run (10k steps)
python scripts/train.py --max-steps 500 # train 500 steps then stop
```

Every 500 steps the script prints validation loss, perplexity, and a generated text
sample so you can watch the model improve in real time.

**Train the model:**
**4. Resume training from a checkpoint:**
```bash
python scripts/train.py
python scripts/train.py --resume checkpoints/step_000500.pt # resume, continue to the default 10k steps
python scripts/train.py --resume checkpoints/step_000500.pt --max-steps 2000 # resume, stop at step 2000
```

Training downloads the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
dataset on first run (~500 MB). You'll see a progress bar with the current loss. Every 500
steps the script prints validation loss, perplexity, and a generated text sample so you can
watch the model improve in real time.
Checkpoints save the full model state, optimizer state, and training progress so
training picks up exactly where it left off.

**Generate text from a checkpoint:**
**5. Generate text from a checkpoint:**
```bash
python scripts/generate.py --checkpoint checkpoints/step_010000.pt --prompt "Once upon a time"
```

**Interactive mode:**
**6. Interactive mode:**
```bash
python scripts/generate.py --checkpoint checkpoints/step_010000.pt --interactive
```
Expand Down Expand Up @@ -96,8 +134,8 @@ This project is designed to be read bottom-up. Here's the suggested order:
6. **`minillm/tokenizer.py`** -- How text becomes numbers. Wraps tiktoken's BPE
tokenizer with a vocab cap.

7. **`minillm/dataset.py`** -- How training data is prepared: tokenize everything,
concatenate into one long tensor, serve random windows with input/target shifted by 1.
7. **`minillm/dataset.py`** -- How pre-tokenized data is loaded from disk and
served as random sliding windows with input/target shifted by 1.

8. **`minillm/generate.py`** -- Autoregressive generation. Follow how the model
predicts one token at a time, and how temperature, top-k, and top-p shape the
Expand Down Expand Up @@ -158,15 +196,19 @@ miniLLM/
│ ├── test_model.py # Full model: shapes, loss, backward, weight tying
│ ├── test_generate.py # Generation loop: length, determinism
│ └── test_utils.py # LR schedule, perplexity
├── data/ # Created by prepare_data.py (git-ignored)
│ ├── raw/ # Human-readable JSONL files
│ └── tokenized/ # Pre-tokenized .pt tensors
├── scripts/
│ ├── train.py # Training entry point
│ ├── generate.py # Generation entry point
│ └── smoke_test.py # Quick sanity check (no data needed)
│ ├── prepare_data.py # Download + tokenize data
│ ├── train.py # Training (supports --resume, --max-steps)
│ ├── generate.py # Generation entry point
│ └── smoke_test.py # Quick sanity check (no data needed)
└── minillm/
├── __init__.py
├── config.py # All hyperparameters in one dataclass
├── tokenizer.py # BPE tokenizer (wraps tiktoken)
├── dataset.py # Data loading and batching
├── dataset.py # Loads pre-tokenized data, serves sliding windows
├── generate.py # Generation logic (temperature, top-k, top-p)
├── utils.py # LR scheduling, checkpointing, evaluation
└── model/
Expand Down
40 changes: 19 additions & 21 deletions minillm/dataset.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

from minillm.config import MiniLLMConfig
from minillm.tokenizer import Tokenizer

TOKENIZED_DIR = Path("data/tokenized")


class TinyStoriesDataset(Dataset):
class TokenizedDataset(Dataset):
"""
Loads TinyStories, tokenizes the full split, concatenates all tokens into
one flat tensor, then serves random (context_length)-sized windows.
Loads pre-tokenized data from a .pt file and serves sliding-window
(context_length)-sized chunks.

Input: tokens[i : i + context_length]
Target: tokens[i + 1 : i + context_length + 1] (shifted by 1)
"""

def __init__(self, split: str, config: MiniLLMConfig, tokenizer: Tokenizer):
def __init__(self, split: str, config: MiniLLMConfig):
super().__init__()
self.context_length = config.context_length

print(f"Loading {split} split of {config.dataset_name}...")
ds = load_dataset(config.dataset_name, split=split)

print(f"Tokenizing {len(ds)} examples...")
all_tokens: list[int] = []
for example in ds:
all_tokens.extend(tokenizer.encode(example["text"]))
path = TOKENIZED_DIR / f"{split}.pt"
if not path.exists():
raise FileNotFoundError(
f"{path} not found. Run 'python scripts/prepare_data.py' first."
)

self.data = torch.tensor(all_tokens, dtype=torch.long)
print(f"Total tokens ({split}): {len(self.data):,}")
print(f"Loading pre-tokenized {split} data from {path}...")
self.data = torch.load(path, weights_only=True)
print(f" {len(self.data):,} tokens loaded")

def __len__(self) -> int:
return len(self.data) - self.context_length
Expand All @@ -40,12 +41,9 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
return x, y


def create_dataloaders(
config: MiniLLMConfig,
tokenizer: Tokenizer,
) -> tuple[DataLoader, DataLoader]:
train_ds = TinyStoriesDataset("train", config, tokenizer)
val_ds = TinyStoriesDataset("validation", config, tokenizer)
def create_dataloaders(config: MiniLLMConfig) -> tuple[DataLoader, DataLoader]:
train_ds = TokenizedDataset("train", config)
val_ds = TokenizedDataset("validation", config)

train_loader = DataLoader(
train_ds,
Expand Down
2 changes: 2 additions & 0 deletions minillm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def save_checkpoint(
step: int,
loss: float,
config: MiniLLMConfig,
best_val_loss: float = float("inf"),
):
Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
path = os.path.join(config.checkpoint_dir, f"step_{step:06d}.pt")
Expand All @@ -36,6 +37,7 @@ def save_checkpoint(
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
"best_val_loss": best_val_loss,
"config": config,
},
path,
Expand Down
128 changes: 128 additions & 0 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
Data preparation pipeline -- run this before training.

Step 1 (--download): Download TinyStories to data/raw/ as JSONL files
Step 2 (--tokenize): Tokenize raw data and save to data/tokenized/ as .pt files
Step 3 (no flags): Run both steps

Usage:
python scripts/prepare_data.py # full pipeline
python scripts/prepare_data.py --download # download only
python scripts/prepare_data.py --tokenize # tokenize only (requires download first)
"""

import argparse
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from datasets import load_dataset

from minillm.config import MiniLLMConfig
from minillm.tokenizer import Tokenizer

RAW_DIR = Path("data/raw")
TOKENIZED_DIR = Path("data/tokenized")


def download(config: MiniLLMConfig):
    """Download TinyStories and save as human-readable JSONL files.

    Writes one ``{split}.jsonl`` file per split under ``data/raw/``. Splits
    whose output file already exists are skipped (idempotent re-runs).

    Args:
        config: Supplies ``dataset_name``, the Hugging Face dataset id.
    """
    RAW_DIR.mkdir(parents=True, exist_ok=True)

    for split in ("train", "validation"):
        out_path = RAW_DIR / f"{split}.jsonl"
        if out_path.exists():
            # Count lines inside a context manager so the handle is closed
            # deterministically (the bare open() generator leaked it).
            with open(out_path, encoding="utf-8") as f:
                n_lines = sum(1 for _ in f)
            print(f"[download] {out_path} already exists ({n_lines:,} examples), skipping")
            continue

        print(f"[download] Downloading {split} split of {config.dataset_name}...")
        ds = load_dataset(config.dataset_name, split=split)

        print(f"[download] Saving {len(ds):,} examples to {out_path}...")
        # Explicit UTF-8 so the output does not depend on the platform locale
        # (json.dumps escapes non-ASCII by default, but be explicit anyway).
        with open(out_path, "w", encoding="utf-8") as f:
            for example in ds:
                f.write(json.dumps(example) + "\n")

        print(f"[download] Done: {out_path} ({len(ds):,} examples)")

    print()
    print("Raw data saved to data/raw/. You can inspect the stories with:")
    print(" head -5 data/raw/train.jsonl | python -m json.tool")
    print()


def tokenize(config: MiniLLMConfig):
    """Tokenize raw JSONL files and save as flat ``.pt`` token tensors.

    Reads ``data/raw/{split}.jsonl`` (one JSON object with a ``"text"`` key
    per line), concatenates all token ids into a single 1-D ``torch.long``
    tensor, and saves it to ``data/tokenized/{split}.pt``. Existing outputs
    are skipped, so re-runs are cheap after the first.

    Exits with status 1 if a required raw file is missing.

    Args:
        config: Supplies ``vocab_size`` for the tokenizer.
    """
    TOKENIZED_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer = Tokenizer(config.vocab_size)

    for split in ("train", "validation"):
        raw_path = RAW_DIR / f"{split}.jsonl"
        out_path = TOKENIZED_DIR / f"{split}.pt"

        if out_path.exists():
            data = torch.load(out_path, weights_only=True)
            print(f"[tokenize] {out_path} already exists ({len(data):,} tokens), skipping")
            continue

        if not raw_path.exists():
            print(f"[tokenize] ERROR: {raw_path} not found. Run with --download first.")
            sys.exit(1)

        print(f"[tokenize] Tokenizing {raw_path}...")
        all_tokens: list[int] = []
        n_examples = 0

        # Explicit UTF-8: download() writes UTF-8, so don't rely on the
        # platform's locale encoding when reading it back.
        with open(raw_path, encoding="utf-8") as f:
            for line in f:
                example = json.loads(line)
                all_tokens.extend(tokenizer.encode(example["text"]))
                n_examples += 1
                if n_examples % 200_000 == 0:
                    print(f" ...processed {n_examples:,} examples ({len(all_tokens):,} tokens)")

        # torch.long matches what TokenizedDataset expects at load time.
        data = torch.tensor(all_tokens, dtype=torch.long)
        torch.save(data, out_path)
        print(f"[tokenize] Done: {out_path} ({len(data):,} tokens from {n_examples:,} examples)")

    # Print token stats
    print()
    print("Tokenized data saved to data/tokenized/. Stats:")
    for split in ("train", "validation"):
        path = TOKENIZED_DIR / f"{split}.pt"
        if path.exists():
            data = torch.load(path, weights_only=True)
            print(f" {split}: {len(data):,} tokens ({path.stat().st_size / 1e6:.1f} MB on disk)")

    print()
    print("You can inspect the tokenized data in Python:")
    print(" import torch")
    print(" from minillm.tokenizer import Tokenizer")
    print(" data = torch.load('data/tokenized/train.pt', weights_only=True)")
    print(" tok = Tokenizer(10_000)")
    print(" print(tok.decode(data[:100].tolist())) # first 100 tokens")
    print()


def main():
    """CLI entry point: run the download and/or tokenize steps.

    With no flags, both steps run in order; ``--download`` or ``--tokenize``
    restricts the run to that single step.
    """
    parser = argparse.ArgumentParser(description="Prepare data for miniLLM training")
    parser.add_argument("--download", action="store_true", help="Download raw data only")
    parser.add_argument("--tokenize", action="store_true", help="Tokenize raw data only")
    args = parser.parse_args()

    config = MiniLLMConfig()

    # No flags at all means "run the full pipeline".
    run_both = not (args.download or args.tokenize)

    if run_both or args.download:
        download(config)
    if run_both or args.tokenize:
        tokenize(config)


# Standard script guard: run the pipeline only when executed directly,
# not when imported.
if __name__ == "__main__":
    main()
Loading