-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDNA_module.py
More file actions
126 lines (95 loc) · 4.89 KB
/
DNA_module.py
File metadata and controls
126 lines (95 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from models import MLPModel
from flow import expand_simplex, sample_conditional_path
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from collections import defaultdict
from dequantizer import Dequantizer
class DNAModule(pl.LightningModule):
    """Lightning module for MeanFlow-style flow matching on toy DNA sequences.

    Trains ``self.model`` (an ``MLPModel``) to predict the average velocity
    field ``u(x_t, t, r)`` via the MeanFlow identity
    ``u = v + (t - r) * du/dt`` (with ``du/dt`` obtained by a single
    forward-mode JVP and treated as a constant target).

    Args:
        args: experiment namespace; fields read here include ``flow_type``,
            ``toy_simplex_dim``, ``prior_pseudocount``,
            ``num_integration_steps``, ``cls_ckpt`` (and optionally ``lr``).
        alphabet_size: number of symbols in the sequence alphabet.
        num_cls: number of sequence classes (stored; used by classifier eval).
        toy_data: toy dataset object; ``toy_data.probs[0]`` is the ground-truth
            per-position categorical distribution used for validation KL.
    """

    def __init__(self, args, alphabet_size, num_cls, toy_data):
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.alphabet_size = alphabet_size
        self.num_cls = num_cls
        self.toy_data = toy_data
        self.load_model()
        self.automatic_optimization = True
        self.val_outputs = defaultdict(list)
        # Normally set by training_step/validation_step before step() runs;
        # initialized here so a direct call to step() cannot raise
        # AttributeError.
        self.stage = 'train'
        # The dequantizer only exists for the argmax flow; step() guards on
        # flow_type before touching it (lazy conditional, so no AttributeError
        # for other flow types).
        if args.flow_type == 'argmax':
            self.dequantizer = Dequantizer(K=args.toy_simplex_dim)

    def step(self, batch):
        """Compute the MeanFlow matching loss for one batch.

        In the ``'val'`` stage this additionally runs full inference and logs
        batch-level KL metrics against the known toy distribution.

        Args:
            batch: ``(seq, cls)`` — integer sequences ``(B, L)`` and labels.

        Returns:
            Scalar MSE loss between predicted and target mean velocities.
        """
        seq, cls = batch
        B = seq.size(0)
        # 1. Sample a conditional path: noise x0, interpolant xt, data x1,
        #    and the two time values (t, r) used by the MeanFlow target.
        x0, xt, x1, t, r = sample_conditional_path(
            self.args,
            seq,
            self.alphabet_size,
            self.device,
            dequantizer=self.dequantizer if self.args.flow_type == 'argmax' else None)
        # 2. Lift xt into the model's input space.
        if self.args.flow_type == 'dirichlet':
            # Project the interpolated sample onto the expanded simplex.
            xt, prior_weights = expand_simplex(xt, t, self.args.prior_pseudocount)
        else:
            # Zero-pad so the input width matches the expanded-simplex case.
            xt = torch.cat([xt, torch.zeros_like(xt)], dim=-1)
        # Conditional velocity of the linear interpolation path.
        v = x1 - x0

        # 3. Forward-mode JVP: one pass yields both u and du/dt (tangent is
        #    d/dt only — zeros for xt and r, ones for t).
        def model_fn(xt, t, r):
            return self.model(xt, t, r)

        primal = (xt, t, r)
        tangent = (torch.zeros_like(xt), torch.ones_like(t), torch.zeros_like(r))
        u, du_dt = torch.func.jvp(model_fn, primal, tangent)
        # MeanFlow target; du/dt is detached so gradients do not flow through
        # the target (stop-gradient as in the MeanFlow formulation).
        u_target = v + (t - r)[:, None, None] * du_dt.detach()
        loss = torch.nn.functional.mse_loss(u, u_target)

        if self.stage == "val":
            logits = self.mean_flow_inference(seq)
            predicted_sequence = torch.argmax(logits, dim=-1)
            self.val_outputs['seqs'].append(predicted_sequence.cpu())
            # Empirical symbol distribution for THIS batch (averaged over the
            # batch dimension only).
            batch_one_hot = torch.nn.functional.one_hot(
                predicted_sequence, num_classes=self.args.toy_simplex_dim)
            batch_empirical_dist = batch_one_hot.float().mean(dim=0)
            # Clamp to avoid log(0).
            eps = 1e-10
            batch_empirical_dist = batch_empirical_dist.clamp(min=eps)
            true_probs = self.toy_data.probs[0].clamp(min=eps).to(batch_empirical_dist.device)
            # kl  = KL(model || true); rkl = KL(true || model).
            # (The original comment had the two directions swapped.)
            kl = (batch_empirical_dist
                  * (torch.log(batch_empirical_dist) - torch.log(true_probs))).sum()
            rkl = (true_probs
                   * (torch.log(true_probs) - torch.log(batch_empirical_dist))).sum()
            # KL(p || p): exactly 0 by construction — logged purely as a
            # sanity check that the KL plumbing is wired correctly.
            sanity_self_kl = (batch_empirical_dist
                              * (torch.log(batch_empirical_dist)
                                 - torch.log(batch_empirical_dist))).sum(-1).mean()
            self.log("self_rkl", sanity_self_kl, on_step=True, on_epoch=True, prog_bar=True)
            self.log("val_kl", kl, on_step=True, on_epoch=True, prog_bar=True)
            self.log("val_rkl", rkl, on_step=True, on_epoch=True)
            if self.args.cls_ckpt is not None:
                self.run_cls_model(predicted_sequence, cls, clean_data=False, postfix='_generated')
        return loss

    def training_step(self, batch, batch_idx):
        self.stage = 'train'
        loss = self.step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None,
                                    gradient_clip_algorithm=None):
        """Clip gradients to a max norm of 1.0.

        BUG FIX: the original called ``clip_grad_norm_`` inside
        ``training_step``; with automatic optimization the backward pass has
        not yet run at that point, so the call was ineffective. Lightning
        invokes this hook after backward and before ``optimizer.step()``,
        which is the correct place to clip.
        """
        self.clip_gradients(optimizer, gradient_clip_val=1.0,
                            gradient_clip_algorithm="norm")

    def validation_step(self, batch, batch_idx):
        self.stage = 'val'
        loss = self.step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Use args.lr when provided; default preserves the original 5e-4.
        return torch.optim.Adam(self.parameters(),
                                lr=getattr(self.args, 'lr', 5e-4))

    def mean_flow_inference(self, seq):
        """Generate sequences by integrating the learned mean flow from 0 to 1.

        Args:
            seq: ``(B, L)`` tensor — only its shape is used, to size the noise.

        Returns:
            ``(B, L, K)`` tensor of simplex points after integration; callers
            argmax over the last dim to get sequences.
        """
        B, L = seq.shape
        K = self.model.alphabet_size
        # Dirichlet(1) noise = uniform over the simplex; sampled directly on
        # the target device to avoid a CPU round-trip.
        zt = torch.distributions.Dirichlet(
            torch.ones(B, L, K, device=self.device)).sample()
        t_span = torch.linspace(0, 1, self.args.num_integration_steps, device=self.device)
        for n in range(self.args.num_integration_steps - 1):
            r = t_span[n]
            t = t_span[n + 1]
            zt_inp, _ = expand_simplex(zt, r.expand(B), self.args.prior_pseudocount)
            # NOTE(review): training calls self.model(xt, t, r) while here the
            # call is (zt_inp, r, t) — confirm the intended (t, r) argument
            # ordering is consistent between training and inference.
            u = self.model(zt_inp, r.expand(B), t.expand(B))
            dt = t - r
            zt = zt + dt * u
        return zt

    def load_model(self, checkpoint=None):
        """Build the MLP backbone.

        NOTE(review): ``checkpoint`` is accepted but never used — presumably
        intended for restoring weights; confirm against call sites.
        """
        self.model = MLPModel(self.args, self.alphabet_size)