-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
85 lines (69 loc) · 3.27 KB
/
Copy pathmodel.py
File metadata and controls
85 lines (69 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
# A simplified 'Baby' small language model built around one Transformer block
class BabySLM(nn.Module):
    """A tiny next-token language model with a single Transformer block.

    Token ids are embedded, summed with learned absolute position embeddings,
    passed through one ``nn.TransformerEncoderLayer`` under a *causal*
    attention mask (each position sees only itself and earlier positions),
    and projected back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; ids must lie in
            ``[0, vocab_size)``.
        embed_dim: Width of the token/position embeddings and of the
            Transformer block. PyTorch requires it to be divisible by
            ``nhead``.
        context_length: Maximum sequence length; one learned position
            embedding is allocated per position.
        nhead: Number of attention heads (default 4, the original value).
    """

    def __init__(self, vocab_size, embed_dim, context_length, nhead=4):
        super().__init__()
        # Kept so forward() can reject sequences with no position embedding.
        self.context_length = context_length
        # 1. Embedding: converts token indexes into vectors, plus one learned
        #    vector per position so the model knows token order.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(context_length, embed_dim)
        # 2. The "Brain": a single Transformer block
        #    (real LLMs stack 30+ of these blocks).
        #    batch_first=True so inputs can be (B, T, E).
        self.transformer_block = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,  # several "heads" attend to different parts of the sequence
            batch_first=True,
        )
        # 3. The output: project back to the vocabulary to predict the next token.
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx):
        """Return next-token logits for every position of ``idx``.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``, where
                ``T <= context_length``.

        Returns:
            Tensor of shape ``(B, T, vocab_size)``. The logits at position
            ``t`` depend only on tokens ``0..t`` (causal masking).

        Raises:
            ValueError: if ``idx`` is not 2-D or ``T`` exceeds
                ``context_length``.
        """
        if idx.dim() != 2:
            raise ValueError(
                f"idx must have shape (B, T); got {tuple(idx.shape)}"
            )
        T = idx.size(1)  # Time (sequence length)
        if T > self.context_length:
            raise ValueError(
                f"sequence length {T} exceeds context_length "
                f"{self.context_length}"
            )

        # Embeddings for tokens plus their positions.
        tok_emb = self.token_embedding(idx)  # (B, T, E)
        pos_emb = self.position_embedding(
            torch.arange(T, device=idx.device)
        )  # (T, E)
        x = tok_emb + pos_emb  # broadcasts to (B, T, E)

        # Causal mask: True marks a (query, key) pair that may NOT attend,
        # i.e. every key strictly after the query. Without it the encoder
        # layer lets each position see future tokens, leaking the very token
        # it is supposed to predict.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1
        )
        x = self.transformer_block(x, src_mask=causal_mask)

        # Predict the next token at every position.
        logits = self.lm_head(x)
        return logits
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar parameters held by *model*.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count, regardless of whether it requires gradients.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
if __name__ == "__main__":
    # Smoke-test BabySLM on several configurations and verify that each
    # forward pass returns logits of shape (batch, sequence, vocab_size).
    # Each row: (label, vocab_size, embed_dim, context_length, batch_size).
    configs = [
        ("Tiny model", 1000, 32, 16, 2),
        ("Larger vocabulary (5000 words)", 5000, 32, 16, 3),
        # Larger embed_dim makes the model wider, not deeper; depth would
        # mean stacking more Transformer blocks.
        ("Wider embeddings (128-dim)", 1000, 128, 16, 2),
        ("Longer context (64 tokens)", 1000, 32, 64, 1),
    ]

    print("=" * 60)
    print("Testing BabySLM with different configurations")
    print("=" * 60)

    for config_no, (label, vocab_size, embed_dim, context_length, batch_size) in enumerate(
        configs, start=1
    ):
        print(f"\n[Config {config_no}] {label}")
        model = BabySLM(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            context_length=context_length,
        )
        print(f"Parameters: {count_parameters(model):,}")

        idx = torch.randint(0, vocab_size, (batch_size, context_length))
        # Inference only: eval mode disables dropout, no_grad skips autograd.
        model.eval()
        with torch.no_grad():
            logits = model(idx)

        # Actually check the result so the final success message is earned.
        expected = (batch_size, context_length, vocab_size)
        if logits.shape != expected:
            raise RuntimeError(
                f"[Config {config_no}] expected output shape {expected}, "
                f"got {tuple(logits.shape)}"
            )
        print(f"Input shape: {idx.shape} -> Output shape: {logits.shape}")

    print("\n" + "=" * 60)
    print("All configurations tested successfully!")
    print("=" * 60)