diff --git a/.gitignore b/.gitignore
index fd0217b..221d07c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,8 @@
 __pycache__/
 *.pyc
 checkpoints/
-*.pt
 .DS_Store
-data/
+data/raw/
+data/tokenized/
 *.cache
+.pytest_cache/
diff --git a/README.md b/README.md
index 1c9bb8a..31b619d 100644
--- a/README.md
+++ b/README.md
@@ -23,30 +23,68 @@ pip install -r requirements.txt

 ## Quick start

-**Verify your setup** (no data download, runs in seconds):
+**1. Verify your setup** (no data download, runs in seconds):
 ```bash
 python scripts/smoke_test.py
 ```

-This checks imports, model construction, a forward/backward pass on dummy data, and a
-short generation loop. Run it after setup or whenever you change the architecture.
+**2. Download and tokenize the dataset:**
+```bash
+python scripts/prepare_data.py
+```
+
+This downloads [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
+(~500 MB) and tokenizes it. The raw stories are saved as JSONL files you can inspect,
+and the tokenized data is saved as `.pt` tensors:
+
+```
+data/
+├── raw/
+│   ├── train.jsonl       # 2.1M stories as readable JSON
+│   └── validation.jsonl  # 22K stories
+└── tokenized/
+    ├── train.pt          # ~472M tokens as a flat tensor
+    └── validation.pt     # ~4.7M tokens
+```
+
+You can inspect the raw data with:
+```bash
+head -5 data/raw/train.jsonl | python -m json.tool --json-lines
+```
+
+Or the tokenized data in Python:
+```python
+import torch
+from minillm.tokenizer import Tokenizer
+data = torch.load("data/tokenized/train.pt", weights_only=True)
+tok = Tokenizer(10_000)
+print(tok.decode(data[:100].tolist()))  # first 100 tokens
+```
+
+**3. Train the model:**
+```bash
+python scripts/train.py                  # full run (10k steps)
+python scripts/train.py --max-steps 500  # train 500 steps then stop
+```
+
+Every 500 steps the script prints validation loss, perplexity, and a generated text
+sample so you can watch the model improve in real time.

-**Train the model:**
+**4. Resume training from a checkpoint:**
 ```bash
-python scripts/train.py
+python scripts/train.py --resume checkpoints/step_000500.pt                   # resume, train to step 10000
+python scripts/train.py --resume checkpoints/step_000500.pt --max-steps 2000  # resume, train to step 2000
 ```

-Training downloads the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
-dataset on first run (~500 MB). You'll see a progress bar with the current loss. Every 500
-steps the script prints validation loss, perplexity, and a generated text sample so you can
-watch the model improve in real time.
+Checkpoints save the full model state, optimizer state, and training progress so
+training picks up exactly where it left off.
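+
+A checkpoint is an ordinary `torch.save`d dict (see `save_checkpoint` in
+`minillm/utils.py`), so you can peek inside one directly:
+
+```python
+import torch
+# weights_only=False: the dict holds a pickled MiniLLMConfig, not just tensors
+ckpt = torch.load("checkpoints/step_000500.pt", weights_only=False)
+print(ckpt["step"], ckpt["loss"])
+# keys include: step, loss, best_val_loss, model_state_dict, optimizer_state_dict, config
+```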

-**Generate text from a checkpoint:**
+**5. Generate text from a checkpoint:**
 ```bash
 python scripts/generate.py --checkpoint checkpoints/step_010000.pt --prompt "Once upon a time"
 ```

-**Interactive mode:**
+**6. Interactive mode:**
 ```bash
 python scripts/generate.py --checkpoint checkpoints/step_010000.pt --interactive
 ```
@@ -96,8 +134,8 @@ This project is designed to be read bottom-up. Here's the suggested order:

 6. **`minillm/tokenizer.py`** -- How text becomes numbers. Wraps tiktoken's BPE
    tokenizer with a vocab cap.
-7. **`minillm/dataset.py`** -- How training data is prepared: tokenize everything,
-   concatenate into one long tensor, serve random windows with input/target shifted by 1.
+7. **`minillm/dataset.py`** -- How pre-tokenized data is loaded from disk and
+   served as random sliding windows with input/target shifted by 1.
 8. **`minillm/generate.py`** -- Autoregressive generation. Follow how the model
    predicts one token at a time, and how temperature, top-k, and top-p shape the
@@ -158,15 +196,19 @@ miniLLM/
 │   ├── test_model.py        # Full model: shapes, loss, backward, weight tying
 │   ├── test_generate.py     # Generation loop: length, determinism
 │   └── test_utils.py        # LR schedule, perplexity
+├── data/                    # Created by prepare_data.py (git-ignored)
+│   ├── raw/                 # Human-readable JSONL files
+│   └── tokenized/           # Pre-tokenized .pt tensors
 ├── scripts/
-│   ├── train.py             # Training entry point
-│   ├── generate.py          # Generation entry point
-│   └── smoke_test.py        # Quick sanity check (no data needed)
+│   ├── prepare_data.py      # Download + tokenize data
+│   ├── train.py             # Training (supports --resume, --max-steps)
+│   ├── generate.py          # Generation entry point
+│   └── smoke_test.py        # Quick sanity check (no data needed)
 └── minillm/
     ├── __init__.py
     ├── config.py            # All hyperparameters in one dataclass
     ├── tokenizer.py         # BPE tokenizer (wraps tiktoken)
-    ├── dataset.py           # Data loading and batching
+    ├── dataset.py           # Loads pre-tokenized data, serves sliding windows
     ├── generate.py          # Generation logic (temperature, top-k, top-p)
     ├── utils.py             # LR scheduling, checkpointing, evaluation
     └── model/
diff --git a/minillm/dataset.py b/minillm/dataset.py
index 178788e..f5de84d 100644
--- a/minillm/dataset.py
+++ b/minillm/dataset.py
@@ -1,34 +1,35 @@
+from pathlib import Path
+
 import torch
 from torch.utils.data import Dataset, DataLoader
-from datasets import load_dataset

 from minillm.config import MiniLLMConfig
-from minillm.tokenizer import Tokenizer
+
+TOKENIZED_DIR = Path("data/tokenized")


-class TinyStoriesDataset(Dataset):
+class TokenizedDataset(Dataset):
     """
-    Loads TinyStories, tokenizes the full split, concatenates all tokens into
-    one flat tensor, then serves random (context_length)-sized windows.
+    Loads pre-tokenized data from a .pt file and serves sliding-window
+    (context_length)-sized chunks.

     Input:  tokens[i : i + context_length]
     Target: tokens[i + 1 : i + context_length + 1]   (shifted by 1)
     """

-    def __init__(self, split: str, config: MiniLLMConfig, tokenizer: Tokenizer):
+    def __init__(self, split: str, config: MiniLLMConfig):
         super().__init__()
         self.context_length = config.context_length

-        print(f"Loading {split} split of {config.dataset_name}...")
-        ds = load_dataset(config.dataset_name, split=split)
-
-        print(f"Tokenizing {len(ds)} examples...")
-        all_tokens: list[int] = []
-        for example in ds:
-            all_tokens.extend(tokenizer.encode(example["text"]))
+        path = TOKENIZED_DIR / f"{split}.pt"
+        if not path.exists():
+            raise FileNotFoundError(
+                f"{path} not found. Run 'python scripts/prepare_data.py' first."
+            )

-        self.data = torch.tensor(all_tokens, dtype=torch.long)
-        print(f"Total tokens ({split}): {len(self.data):,}")
+        print(f"Loading pre-tokenized {split} data from {path}...")
+        self.data = torch.load(path, weights_only=True)
+        print(f"  {len(self.data):,} tokens loaded")

     def __len__(self) -> int:
         return len(self.data) - self.context_length
@@ -40,12 +41,9 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
         return x, y


-def create_dataloaders(
-    config: MiniLLMConfig,
-    tokenizer: Tokenizer,
-) -> tuple[DataLoader, DataLoader]:
-    train_ds = TinyStoriesDataset("train", config, tokenizer)
-    val_ds = TinyStoriesDataset("validation", config, tokenizer)
+def create_dataloaders(config: MiniLLMConfig) -> tuple[DataLoader, DataLoader]:
+    train_ds = TokenizedDataset("train", config)
+    val_ds = TokenizedDataset("validation", config)

     train_loader = DataLoader(
         train_ds,
diff --git a/minillm/utils.py b/minillm/utils.py
index a031003..2f9d73d 100644
--- a/minillm/utils.py
+++ b/minillm/utils.py
@@ -27,6 +27,7 @@ def save_checkpoint(
     step: int,
     loss: float,
     config: MiniLLMConfig,
+    best_val_loss: float = float("inf"),
 ):
     Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
     path = os.path.join(config.checkpoint_dir, f"step_{step:06d}.pt")
@@ -36,6 +37,7 @@
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict(),
             "loss": loss,
+            "best_val_loss": best_val_loss,
             "config": config,
         },
         path,
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
new file mode 100644
index 0000000..69e5d3c
--- /dev/null
+++ b/scripts/prepare_data.py
@@ -0,0 +1,128 @@
+"""
+Data preparation pipeline -- run this before training.
+
+Step 1 (--download): Download TinyStories to data/raw/ as JSONL files
+Step 2 (--tokenize): Tokenize raw data and save to data/tokenized/ as .pt files
+Default (no flags):  Run both steps
+
+Usage:
+    python scripts/prepare_data.py             # full pipeline
+    python scripts/prepare_data.py --download  # download only
+    python scripts/prepare_data.py --tokenize  # tokenize only (requires download first)
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+import torch
+from datasets import load_dataset
+
+from minillm.config import MiniLLMConfig
+from minillm.tokenizer import Tokenizer
+
+RAW_DIR = Path("data/raw")
+TOKENIZED_DIR = Path("data/tokenized")
+
+
+def download(config: MiniLLMConfig):
+    """Download TinyStories and save as human-readable JSONL files."""
+    RAW_DIR.mkdir(parents=True, exist_ok=True)
+
+    for split in ("train", "validation"):
+        out_path = RAW_DIR / f"{split}.jsonl"
+        if out_path.exists():
+            with open(out_path) as f:
+                n_lines = sum(1 for _ in f)
+            print(f"[download] {out_path} already exists ({n_lines:,} examples), skipping")
+            continue
+
+        print(f"[download] Downloading {split} split of {config.dataset_name}...")
+        ds = load_dataset(config.dataset_name, split=split)
+
+        print(f"[download] Saving {len(ds):,} examples to {out_path}...")
+        with open(out_path, "w") as f:
+            for example in ds:
+                f.write(json.dumps(example) + "\n")
+
+        print(f"[download] Done: {out_path} ({len(ds):,} examples)")
+
+    print()
+    print("Raw data saved to data/raw/. You can inspect the stories with:")
+    print("  head -5 data/raw/train.jsonl | python -m json.tool --json-lines")
+    print()
+
+
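+# Each line of the raw JSONL holds one example as a JSON object; the only
+# field tokenize() relies on is "text", e.g. (abbreviated):
+#   {"text": "Once upon a time, ..."}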
+def tokenize(config: MiniLLMConfig):
+    """Tokenize raw JSONL files and save as .pt tensors."""
+    TOKENIZED_DIR.mkdir(parents=True, exist_ok=True)
+    tokenizer = Tokenizer(config.vocab_size)
+
+    for split in ("train", "validation"):
+        raw_path = RAW_DIR / f"{split}.jsonl"
+        out_path = TOKENIZED_DIR / f"{split}.pt"
+
+        if out_path.exists():
+            data = torch.load(out_path, weights_only=True)
+            print(f"[tokenize] {out_path} already exists ({len(data):,} tokens), skipping")
+            continue
+
+        if not raw_path.exists():
+            print(f"[tokenize] ERROR: {raw_path} not found. Run 'python scripts/prepare_data.py --download' first.")
+            sys.exit(1)
+
+        print(f"[tokenize] Tokenizing {raw_path}...")
+        all_tokens: list[int] = []
+        n_examples = 0
+
+        with open(raw_path) as f:
+            for line in f:
+                example = json.loads(line)
+                all_tokens.extend(tokenizer.encode(example["text"]))
+                n_examples += 1
+                if n_examples % 200_000 == 0:
+                    print(f"  ...processed {n_examples:,} examples ({len(all_tokens):,} tokens)")
+
+        data = torch.tensor(all_tokens, dtype=torch.long)
+        torch.save(data, out_path)
+        print(f"[tokenize] Done: {out_path} ({len(data):,} tokens from {n_examples:,} examples)")
+
+    # Print token stats
+    print()
+    print("Tokenized data saved to data/tokenized/. Stats:")
+    for split in ("train", "validation"):
+        path = TOKENIZED_DIR / f"{split}.pt"
+        if path.exists():
+            data = torch.load(path, weights_only=True)
+            print(f"  {split}: {len(data):,} tokens ({path.stat().st_size / 1e6:.1f} MB on disk)")
+
+    print()
+    print("You can inspect the tokenized data in Python:")
+    print("  import torch")
+    print("  from minillm.tokenizer import Tokenizer")
+    print("  data = torch.load('data/tokenized/train.pt', weights_only=True)")
+    print("  tok = Tokenizer(10_000)")
+    print("  print(tok.decode(data[:100].tolist()))  # first 100 tokens")
+    print()
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Prepare data for miniLLM training")
+    parser.add_argument("--download", action="store_true", help="Download raw data only")
+    parser.add_argument("--tokenize", action="store_true", help="Tokenize raw data only")
+    args = parser.parse_args()
+
+    config = MiniLLMConfig()
+
+    run_both = not args.download and not args.tokenize
+
+    if args.download or run_both:
+        download(config)
+    if args.tokenize or run_both:
+        tokenize(config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train.py b/scripts/train.py
index 249c188..826d867 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,7 +1,17 @@
+"""
+Training script for miniLLM.
+
+Supports incremental training with checkpoint resumption:
+    python scripts/train.py                                      # train from scratch (full run)
+    python scripts/train.py --max-steps 500                      # train 500 steps then stop
+    python scripts/train.py --resume checkpoints/step_000500.pt  # resume from step 500
+    python scripts/train.py --resume checkpoints/step_000500.pt --max-steps 2000  # resume, train to step 2000
+"""
+
+import argparse
 import sys
 from pathlib import Path

-# Allow running as `python scripts/train.py` from the project root
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

 import torch
@@ -20,10 +30,7 @@
 )


-def train(config: MiniLLMConfig | None = None):
-    if config is None:
-        config = MiniLLMConfig()
-
+def train(config: MiniLLMConfig, resume_path: str | None = None, max_steps: int | None = None):
     device = config.resolve_device()
     print(f"Device: {device}")

@@ -38,15 +45,36 @@
         betas=(0.9, 0.95),
     )

-    train_loader, val_loader = create_dataloaders(config, tokenizer)
+    start_step = 0
+    best_val_loss = float("inf")
+
+    if resume_path is not None:
+        print(f"\nResuming from {resume_path}...")
+        checkpoint = torch.load(resume_path, weights_only=False, map_location=device)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        start_step = checkpoint["step"]
+        best_val_loss = checkpoint.get("best_val_loss", checkpoint.get("loss", float("inf")))
+        print(f"Resumed at step {start_step} (loss={checkpoint['loss']:.4f})")
+
+    total_steps = max_steps if max_steps is not None else config.max_steps
+    if start_step >= total_steps:
+        print(f"Already at step {start_step} >= max_steps {total_steps}. Nothing to do.")
+        return
+
+    print(f"\nTraining from step {start_step + 1} to {total_steps}")
+    print(f"  Eval every {config.eval_interval} steps")
+    print(f"  Checkpoint every {config.checkpoint_interval} steps")
+    print()
+
+    train_loader, val_loader = create_dataloaders(config)
     train_iter = iter(train_loader)

     model.train()
-    best_val_loss = float("inf")
+    steps_to_run = total_steps - start_step

-    pbar = tqdm(range(1, config.max_steps + 1), desc="Training")
+    pbar = tqdm(range(start_step + 1, total_steps + 1), desc="Training", total=steps_to_run)
     for step in pbar:
-        # Get next batch (cycle through data)
         try:
             x, y = next(train_iter)
         except StopIteration:
@@ -55,29 +83,24 @@

         x, y = x.to(device), y.to(device)

-        # Update learning rate
         lr = get_lr(step, config)
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr

-        # Forward + backward
         _, loss = model(x, y)
         optimizer.zero_grad(set_to_none=True)
         loss.backward()
         torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
         optimizer.step()

-        # Logging
         if step % config.log_interval == 0:
             pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")

-        # Evaluation
         if step % config.eval_interval == 0:
             val_loss = estimate_loss(model, val_loader, config)
             ppl = compute_perplexity(val_loss)
             print(f"\n[Step {step}] val_loss={val_loss:.4f} perplexity={ppl:.2f}")

-            # Generate a sample
             sample = generate(
                 model, tokenizer, "Once upon a time",
                 max_new_tokens=100, temperature=0.8,
@@ -87,16 +110,34 @@
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
-                save_checkpoint(model, optimizer, step, val_loss, config)
+                save_checkpoint(model, optimizer, step, val_loss, config,
+                                best_val_loss=best_val_loss)

-        # Periodic checkpoint
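+        # Saved on a fixed cadence (independent of validation improvement), so
+        # --resume always has a recent checkpoint to pick up from.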
         if step % config.checkpoint_interval == 0:
-            save_checkpoint(model, optimizer, step, loss.item(), config)
+            save_checkpoint(model, optimizer, step, loss.item(), config,
+                            best_val_loss=best_val_loss)
+
+    save_checkpoint(model, optimizer, total_steps, loss.item(), config,
+                    best_val_loss=best_val_loss)
+    print(f"\nTraining complete! Trained steps {start_step + 1} to {total_steps}.")
+    print(f"Best validation loss: {best_val_loss:.4f}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train miniLLM")
+    parser.add_argument(
+        "--resume", type=str, default=None,
+        help="Path to checkpoint to resume from",
+    )
+    parser.add_argument(
+        "--max-steps", type=int, default=None,
+        help="Stop training at this step (default: config.max_steps = 10000)",
+    )
+    args = parser.parse_args()

-    # Final checkpoint
-    save_checkpoint(model, optimizer, config.max_steps, loss.item(), config)
-    print("Training complete!")
+    config = MiniLLMConfig()
+    train(config, resume_path=args.resume, max_steps=args.max_steps)


 if __name__ == "__main__":
-    train()
+    main()
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 6232496..74f79fb 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -23,11 +23,28 @@ def test_greedy_deterministic(self, config):
         assert out1 == out2

     def test_respects_max_new_tokens(self, config):
+        """The internal token sequence should be exactly prompt_len + max_new_tokens."""
         model = MiniLLM(config).eval()
         tok = Tokenizer(config.vocab_size)
-        prompt = "A"
-        prompt_tokens = len(tok.encode(prompt))
+        prompt = "The"
+        prompt_len = len(tok.encode(prompt))
         max_new = 5
-        output = generate(model, tok, prompt, max_new_tokens=max_new, temperature=1.0)
-        output_tokens = len(tok.encode(output))
-        assert output_tokens == prompt_tokens + max_new
+
+        # Count raw token ids instead of re-encoding the decoded text
+        # (re-encoding can change the token count due to BPE merges), by
+        # replicating the core sampling loop from minillm.generate inline
+        # (plain sampling, no temperature scaling or top-k/top-p filtering).
+        import torch
+
+        model_device = next(model.parameters()).device
+        tokens = tok.encode(prompt)
+        idx = torch.tensor([tokens], dtype=torch.long, device=model_device)
+
+        for _ in range(max_new):
+            idx_cond = idx[:, -config.context_length:]
+            logits, _ = model(idx_cond)
+            logits = logits[:, -1, :]
+            probs = torch.nn.functional.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            idx = torch.cat([idx, next_token], dim=1)
+
+        assert idx.shape[1] == prompt_len + max_new
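+
+        # Example of the pitfall: two sampled ids whose decoded strings are
+        # adjacent fragments of one word (say "hel" + "lo") may re-encode as a
+        # single merged id for "hello", shrinking the re-encoded count.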