-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodels.py
More file actions
148 lines (130 loc) · 4.91 KB
/
models.py
File metadata and controls
148 lines (130 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch as t
from torch import nn
from torch.distributions.log_normal import Normal
import funcs
import pdb
class RBFNetwork(nn.Module):
    """
    Implements a 'normalized RBF network'
    https://arxiv.org/abs/1503.03585 (Sohl-Dickstein, Ganguli 2015) App D.1.1
    https://en.wikipedia.org/wiki/Radial_basis_function_network

    Top-layer weights (mu_alphas, sigma_alphas) are specific to each timestep;
    all other weights (basis centers and widths) are shared across mu, sigma,
    and across all timesteps.
    """
    def __init__(self, D=2, T=40, Hsqrt=4):
        """
        D: dimensionality of the data points
        T: number of diffusion timesteps
        Hsqrt: square root of the number of basis functions; centers are laid
               out on an Hsqrt x Hsqrt grid over [-2, 2]^2
        """
        super().__init__()
        self.T = T
        self.D = D
        self.H = Hsqrt ** 2
        ls = t.linspace(-2, 2, Hsqrt)
        # H, 2 grid of basis-function centers (16, 2 for the default Hsqrt=4)
        pts = t.cartesian_prod(ls, ls)
        self.basis_centers = nn.Parameter(pts)
        # widths are parameterized through a sigmoid so they stay in (0, 1)
        self.basis_sigma_logits = nn.Parameter(t.full((self.H, self.D), -0.5))
        # per-timestep top-layer weights: T, H, D
        self.mu_alphas = nn.Parameter(t.full((self.T, self.H, self.D), 0.0))
        self.sigma_alphas = nn.Parameter(t.full((self.T, self.H, self.D), -2.0))

    def xdim(self):
        """Dimensionality of the data points this network models."""
        return self.D

    @property
    def basis_sigmas(self):
        """Basis-function widths, squashed into (0, 1) via sigmoid."""
        return t.sigmoid(self.basis_sigma_logits)

    def forward(self, xcond, ts=None):
        """
        Runs in two modes:
        If ts is None:
            xcond: B,T,D
            returns mu, sigma: B,T,D
        Else:
            xcond: B,D
            returns mu, sigma: B,D
        """
        if ts is None:
            if xcond.ndim != 3:
                raise RuntimeError(f'If ts is None, xcond.ndim must be 3. Got {xcond.ndim}')
            Tdim = self.T
            mu_alphas = self.mu_alphas
            sigma_alphas = self.sigma_alphas
        else:
            if xcond.ndim != 2:
                raise RuntimeError(
                    f'If ts is not None, xcond.ndim must be 2. Got {xcond.ndim}')
            # add a singleton time dimension and select one timestep's weights
            xcond = xcond.unsqueeze(1)
            Tdim = 1
            mu_alphas = self.mu_alphas[ts].unsqueeze(0)
            sigma_alphas = self.sigma_alphas[ts].unsqueeze(0)
        # pdf: B, Tdim, H (reuse the basis_sigmas property instead of
        # recomputing the sigmoid inline)
        pdf = funcs.combo_isonormal_pdf(xcond, self.basis_centers, self.basis_sigmas)
        # B, Tdim, 1
        normalizer = pdf.sum(-1, True) ** -1
        # B, Tdim, H, D -> B, Tdim, D
        mu = (pdf.unsqueeze(-1) * mu_alphas).sum(dim=2) * normalizer
        sigma_logits = (pdf.unsqueeze(-1) * sigma_alphas).sum(dim=2) * normalizer
        sigma = t.sigmoid(sigma_logits)
        if ts is not None:
            mu = mu.squeeze(1)
            sigma = sigma.squeeze(1)
        # was: pdb.set_trace() — an interactive breakpoint left in library
        # code; fail loudly instead so callers get a stack trace rather than
        # a hung process.
        if t.any(t.isnan(pdf)) or t.any(t.isnan(normalizer)) \
                or t.any(t.isnan(mu)) or t.any(t.isnan(sigma)):
            raise RuntimeError('NaN encountered in RBFNetwork.forward')
        return mu, sigma
class QDist:
    """
    Generate samples from q(x^{0:T}). In the context of diffusion models, q(x^{0:T})
    represents the 'forward trajectory' (Section 2.1, Eq. 3 of
    https://arxiv.org/pdf/1503.03585.pdf). The variable x^0 represents the data
    domain, while x^{>0} are related to x^0 through the diffusion process.
    A 'forward trajectory' is simply a single 'sample' from this joint distribution.
    """
    def __init__(self, sample_size, betas):
        """
        sample_size: number of trajectories sampled per data point
        betas: per-timestep noise variances beta_t (a 1-D tensor of length T)
        """
        self.sample_size = sample_size
        self.betas = betas
        # mean-shrinkage factor sqrt(1 - beta_t) of the forward kernel
        self.alpha = (1.0 - self.betas) ** 0.5
        # Forward kernel q(x^t | x^{t-1}) = N(sqrt(1 - beta_t) x^{t-1}, beta_t I).
        # torch Normal takes a *standard deviation*, so the scale must be
        # sqrt(beta_t); the original passed beta_t directly, which made the
        # noise variance beta_t^2 instead of beta_t.
        self.cond = [Normal(0, beta ** 0.5) for beta in self.betas]

    def sample(self, x0):
        """
        x0: P, D
        returns: B * P, T+1, D  (sample_size * num_points, num_timesteps+1, dim)
        """
        # sample `sample_size = B` different trajectories from each of the P x0 points
        xi_pre = x0.repeat(self.sample_size, 1)
        x = [xi_pre]
        # Appendix B.4, Table App.1 row 'Forward Diffusion Kernel' for Gaussian case
        for cond, alpha in zip(self.cond, self.alpha):
            xi = cond.sample(xi_pre.shape) + alpha * xi_pre
            xi_pre = xi
            x.append(xi)
        res = t.stack(x, dim=1)
        return res
class PDist(nn.Module):
    """
    Reverse-process ('generative') distribution built around a mu_sigma_model.

    The wrapped model is an nn.Module mapping conditioning points to the
    per-step Gaussian parameters (mu, sigma). forward scores trajectories;
    sample draws new trajectories by walking the timesteps in reverse.
    """
    def __init__(self, mu_sigma_model):
        super().__init__()
        self.model = mu_sigma_model

    def forward(self, x):
        """
        x: B, T, D
        returns: B, T
        """
        # each step is scored against its successor (conditioning) step
        cond = x[:, 1:, :]
        target = x[:, :-1, :]
        mu, sigma = self.model(cond)
        return funcs.isonormal_logpdf(target, cond + mu, sigma)

    def sample(self, n):
        """
        returns: n, T, D
        """
        # start from the unit-Gaussian prior and denoise step by step
        cur = Normal(0, 1).sample((n, self.model.xdim()))
        trajectory = [cur]
        with t.no_grad():
            for step in range(self.model.T - 1, -1, -1):
                mu, sigma = self.model(cur, step)
                cur = Normal(cur + mu, sigma).sample()
                trajectory.append(cur)
        # trajectory was built from t=T down to t=0; flip to time order
        trajectory.reverse()
        return t.stack(trajectory, dim=1)