# sample.py (forked from Hzfinfdu/Diffusion-BERT)
import abc

import torch


class SampleClassBase(abc.ABC):
    """Interface for drawing token samples from per-position logits."""

    @abc.abstractmethod
    def sample(self, logits, x_0):
        """Return sampled token ids for `logits`, given the reference `x_0`."""
        raise NotImplementedError

    def post_process_sample_in_prediction(self, sample, x_0):
        """Optionally repair a sample using `x_0`; identity by default."""
        return sample


class Categorical(SampleClassBase):
    def __init__(self, device=None):
        self.device = device

    def sample(self, logits, x_0):
        logits = logits.to(self.device) if self.device is not None else logits
        x_0 = x_0.to(self.device) if self.device is not None else x_0
        # Validate logits: constant logits yield a uniform distribution.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            print("Categorical.sample: WARNING: NaN or inf in logits, falling back to a uniform distribution")
            logits = torch.zeros_like(logits)
        # Debug: log the logits range.
        print(f"Categorical.sample: logits min={logits.min().item()}, max={logits.max().item()}")
        try:
            sample = torch.distributions.Categorical(logits=logits).sample()
        except RuntimeError as e:
            print(f"Categorical.sample: ERROR: {e}, falling back to uniform random token ids")
            sample = torch.randint(0, logits.shape[-1], logits.shape[:-1], device=logits.device)
        # Debug: count unique sequences in the batch as a diversity signal.
        unique_samples = torch.unique(sample, dim=0).shape[0]
        print(f"Categorical.sample: unique_samples={unique_samples}/{sample.shape[0]}")
        return sample
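

# Example usage of Categorical (a minimal sketch; the batch size, sequence
# length, and vocabulary size below are illustrative assumptions, not values
# taken from this file):
#
#   sampler = Categorical()
#   logits = torch.randn(4, 16, 30522)           # (batch, seq_len, vocab)
#   x_0 = torch.zeros(4, 16, dtype=torch.long)   # unused by Categorical.sample
#   tokens = sampler.sample(logits, x_0)         # LongTensor of shape (4, 16)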


class WholeWordMasking(SampleClassBase):
    def __init__(self, tokenizer, device=None):
        self.dim = tokenizer.vocab_size
        self.mask_id = tokenizer.mask_token_id
        self.device = device
        # Mark every '##' continuation subword in the vocabulary.
        self.post_tokens = torch.zeros(size=(tokenizer.vocab_size,), device=self.device, dtype=torch.long)
        for token, token_id in tokenizer.vocab.items():
            if token.startswith('##'):
                self.post_tokens[token_id] = 1
    def sample(self, logits, x_0):
        # Here x_0 is expected as one-hot vectors over the vocabulary.
        logits = logits.to(self.device) if self.device is not None else logits
        x_0 = x_0.to(self.device) if self.device is not None else x_0
        # Indices of positions whose x_0 token is a '##' continuation subword.
        is_post = (self.post_tokens.to(x_0.device) * x_0).sum(-1).nonzero()
        # Validate logits: constant logits yield a uniform distribution.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            print("WholeWordMasking.sample: WARNING: NaN or inf in logits, falling back to a uniform distribution")
            logits = torch.zeros_like(logits)
        # Debug: log the logits range.
        print(f"WholeWordMasking.sample: logits min={logits.min().item()}, max={logits.max().item()}")
        try:
            samp = torch.distributions.Categorical(logits=logits).sample()
        except RuntimeError as e:
            print(f"WholeWordMasking.sample: ERROR: {e}, falling back to uniform random token ids")
            samp = torch.randint(0, logits.shape[-1], logits.shape[:-1], device=logits.device)
        # Keep continuation subwords consistent with the token before them:
        # if the previous position is masked, mask this one too; otherwise
        # restore the original subword from x_0.
        for index in is_post:
            samp[index[0], index[1]] = self.mask_id if samp[index[0], index[1] - 1] == self.mask_id else x_0[index[0], index[1]].argmax()
        # Debug: count unique sequences in the batch as a diversity signal.
        unique_samples = torch.unique(samp, dim=0).shape[0]
        print(f"WholeWordMasking.sample: unique_samples={unique_samples}/{samp.shape[0]}")
        return samp
    def post_process_sample_in_prediction(self, sample, x_0):
        # Here x_0 is expected as token ids, not one-hot vectors.
        sample = sample.to(self.device) if self.device is not None else sample
        x_0 = x_0.to(self.device) if self.device is not None else x_0
        x_0_one_hot = torch.nn.functional.one_hot(x_0, num_classes=self.dim)
        is_post = (self.post_tokens.to(x_0.device) * x_0_one_hot).sum(-1).nonzero()
        # Apply the same continuation-subword repair as in sample().
        for index in is_post:
            sample[index[0], index[1]] = self.mask_id if sample[index[0], index[1] - 1] == self.mask_id else x_0_one_hot[index[0], index[1]].argmax()
        return sample
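

if __name__ == "__main__":
    # Smoke test (a minimal sketch, assuming the HuggingFace `transformers`
    # package is available; 'bert-base-uncased' is an illustrative checkpoint,
    # not a choice taken from this file).
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    sampler = WholeWordMasking(tokenizer)

    batch, seq_len = 2, 8
    logits = torch.randn(batch, seq_len, tokenizer.vocab_size)
    ids = torch.randint(0, tokenizer.vocab_size, (batch, seq_len))
    # sample() expects x_0 as one-hot vectors over the vocabulary.
    x_0_one_hot = torch.nn.functional.one_hot(ids, num_classes=tokenizer.vocab_size)
    samp = sampler.sample(logits, x_0_one_hot)
    print('sampled ids shape:', tuple(samp.shape))

    # post_process_sample_in_prediction() expects x_0 as token ids instead.
    repaired = sampler.post_process_sample_in_prediction(samp, ids)
    print('repaired ids shape:', tuple(repaired.shape))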