-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample_batch.py
More file actions
198 lines (176 loc) · 7.73 KB
/
Copy pathsample_batch.py
File metadata and controls
198 lines (176 loc) · 7.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Sample from a trained model
"""
import os
import pickle
import json
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
import re as _re
def _enforce_five_sentences(text: str, prompt: str = "", max_sentences: int = 5) -> str:
    """Trim generated story text to at most ``max_sentences`` complete sentences.

    Rules:
    - Strip the echoed prompt from the beginning if present (whitespace in
      front of the echo is ignored).
    - Split on whitespace that follows sentence-ending punctuation (. ! ?).
    - Keep only COMPLETE sentences: those ending with . ! ? optionally
      followed by closing quotes/brackets, e.g. ``She yelled "Stop!"``.
    - A trailing fragment (no ending punctuation) means the model was cut
      mid-sentence, so it is dropped entirely (a fragment is worse than
      fewer complete sentences).
    - If more than ``max_sentences`` complete sentences exist, keep the first
      ``max_sentences``; otherwise keep all complete ones.
    - If there is no complete sentence at all, return the text (minus the
      echoed prompt) unchanged as a fallback.

    Args:
        text: Decoded model output, possibly starting with the prompt.
        prompt: Prompt that may be echoed at the start of ``text``.
        max_sentences: Upper bound on the number of sentences kept.

    Returns:
        The kept sentences joined by single spaces, or the fallback text.

    Raises:
        ValueError: If ``max_sentences`` is smaller than 1.
    """
    if max_sentences < 1:
        raise ValueError(f"max_sentences must be >= 1, got {max_sentences}")

    # Strip the echoed prompt from the beginning if present. Compare against
    # the left-stripped text so leading whitespace does not hide the echo.
    echoed = prompt.strip()
    if echoed:
        body = text.lstrip()
        if body.startswith(echoed):
            text = body[len(echoed):].strip()
    if not text:
        return text

    # Split into candidate sentences on . ! ? boundaries
    parts = [p.strip() for p in _re.split(r'(?<=[.!?])\s+', text.strip())]
    parts = [p for p in parts if p]
    if not parts:
        return text

    # Closing quotes/brackets may legitimately follow the final punctuation.
    closers = "\"'\u201d\u2019)]"
    terminators = ('.', '!', '?')

    # Keep only COMPLETE sentences; stop at the first fragment and drop it
    # together with everything after it.
    complete = []
    for part in parts:
        # endswith() on an empty remainder is False, so a lone quote is a fragment.
        if not part.rstrip(closers).endswith(terminators):
            break
        complete.append(part)

    if not complete:
        # Nothing complete — return raw text as fallback
        return text

    # Trim to the sentence budget if we have more
    return ' '.join(complete[:max_sentences])
# -----------------------------------------------------------------------------
# Default settings. configurator.py is executed below in this module's
# namespace, so it can override any of these from the command line or a
# config file.
init_from = 'gpt2' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
start = "FILE:data/rocstories/eval_prompts.txt" # Prompt. Can also specify a file, use as: "FILE:prompt.txt"
batch_prompts = True # if True, read multiple prompts from the file (one per line)
output_file = 'samples.jsonl' # file to save generated samples in JSONL format (set to None to disable)
num_samples = 1 # number of samples to generate for each prompt
max_new_tokens = 100 # 5-sentence ROCStory needs ~45-95 tokens to generate
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# NOTE: exec() runs configurator.py as code with full access to this module's
# globals; only run this script with a configurator.py you trust.
# The with-block closes the file handle that a bare open().read() would leak.
with open('configurator.py') as _configurator_file:
    exec(_configurator_file.read()) # overrides from command line or config file
# -----------------------------------------------------------------------------
# Seed the torch CPU and CUDA random number generators from `seed`.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# allow tf32 on matmul and on cudnn
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Any device string mentioning 'cuda' (e.g. 'cuda:1') counts as CUDA;
# everything else is treated as CPU. Used for torch.autocast below.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Map the dtype name from the config to the torch dtype object.
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# Autocast is only used off-CPU; on CPU generation runs under a no-op context.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# model
# Build the model and pick the sampling parameters according to `init_from`.
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
    # Only resume from checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    import inspect as _inspect
    # Keep only the saved model args that GPTConfig.__init__ accepts, so a
    # checkpoint carrying extra keys still loads.
    _valid_keys = set(_inspect.signature(GPTConfig.__init__).parameters) - {'self'}
    _filtered_args = {k: v for k, v in checkpoint['model_args'].items() if k in _valid_keys}
    gptconf = GPTConfig(**_filtered_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Rename keys carrying this prefix (presumably added by torch.compile —
    # TODO confirm) so they match the uncompiled model's parameter names.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):  # iterate over a snapshot: the dict is mutated below
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    # Try out_dir/sample_params.json first, then repo-root sample_params.json
    for _sp_path in [os.path.join(out_dir, 'sample_params.json'), 'sample_params.json']:
        if os.path.exists(_sp_path):
            with open(_sp_path, 'r') as f:
                sample_params = json.load(f)
            break
    else:
        # Neither file exists: fall back to built-in defaults.
        sample_params = {'temperature': 0.8, 'top_k': 50}
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
    sample_params = {
        'temperature': 0.8,
        'top_k': 200,
    }
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        f"Unsupported init_from={init_from!r}: expected 'resume' or a value starting with 'gpt2'"
    )
model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)
# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)

if not load_meta:
    # No dataset-specific vocabulary available: use the GPT-2 BPE tokenizer.
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")

    def encode(s):
        """Tokenize `s` with the GPT-2 encoding, allowing the <|endoftext|> special token."""
        return enc.encode(s, allowed_special={"<|endoftext|>"})

    def decode(l):
        """Turn a list of GPT-2 token ids back into text."""
        return enc.decode(l)
else:
    # Dataset ships its own vocabulary tables (stoi / itos) in meta.pkl.
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # meta.pkl files from a trusted dataset directory.
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']

    def encode(s):
        """Map each element of `s` to its id via the `stoi` table."""
        return [stoi[c] for c in s]

    def decode(l):
        """Map each id in `l` back through `itos` and concatenate the pieces."""
        return ''.join([itos[i] for i in l])
# encode the beginning of the prompt
# `start` is either the literal prompt or "FILE:<path>" pointing at a prompt file.
if start.startswith('FILE:'):
    with open(start[len('FILE:'):], 'r', encoding='utf-8') as f:
        if batch_prompts:
            # Read multiple prompts from file (one per line). Blank lines are
            # skipped: they would otherwise become empty prompts and
            # zero-length input tensors for the model.
            prompts = [line.rstrip() for line in f if line.strip()]
        else:
            # Read single prompt from file (the whole file content, unmodified)
            prompts = [f.read()]
else:
    prompts = [start]
# Encode all prompts
start_ids_list = [encode(prompt) for prompt in prompts]
# Create one 1-D long tensor per prompt on the target device
x_list = [torch.tensor(ids, dtype=torch.long, device=device) for ids in start_ids_list]
# Token id used both as the generation stop token and as the trim marker
# (the original code comments call it the EOT token).
_EOT_TOKEN_ID = 50256
# Parameters recorded with every sample; they do not change during the run.
_record_params = {
    'max_new_tokens': max_new_tokens,
    **sample_params
}
# Open output file if specified. nullcontext() yields None when output is
# disabled, and the with-block guarantees the file is closed even if
# generation raises part-way through.
with (open(output_file, 'w', encoding='utf-8') if output_file else nullcontext()) as output_f:
    # run generation
    with torch.no_grad(), ctx:
        for prompt_idx, (prompt_text, x_single) in enumerate(zip(prompts, x_list)):
            x = x_single[None, ...]  # add a leading batch dimension of size 1
            if batch_prompts and len(prompts) > 1:
                prompt_header = f"\n=== Prompt {prompt_idx + 1}: {prompt_text} ==="
                print(prompt_header)
            for k in range(num_samples):
                y = model.generate(x, max_new_tokens, stop_token=_EOT_TOKEN_ID,
                                   **sample_params)
                tokens = y[0].tolist()
                # Trim at EOT token (50256) — prevents second-story bleed
                if _EOT_TOKEN_ID in tokens:
                    tokens = tokens[:tokens.index(_EOT_TOKEN_ID)]
                sample_text = decode(tokens)
                # Enforce at most 5 complete sentences and strip the echoed prompt
                sample_text = _enforce_five_sentences(sample_text, prompt_text)
                print(sample_text)
                print('---------------')
                # Save to JSONL file if specified (one JSON object per line)
                if output_f:
                    record = {
                        # 'prompt_idx': prompt_idx,
                        'prompt': prompt_text,
                        # 'sample_idx': k,
                        'generated_text': sample_text,
                        'params': _record_params,
                    }
                    output_f.write(json.dumps(record) + '\n')
# The output file (if any) has been closed by the with-block above
if output_file:
    print(f"\nResults saved to {output_file}")