-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcodebook.py
More file actions
211 lines (181 loc) · 9.79 KB
/
codebook.py
File metadata and controls
211 lines (181 loc) · 9.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
"""
Code for class Codebook adapted from https://github.com/dome272/VQGAN-pytorch/blob/main/codebook.py with augmentations for no_vq option
"""
class Codebook(nn.Module):
    """
    Vector-quantisation codebook with an optional `no_vq` pass-through mode
    that skips quantisation entirely (adapted from the VQGAN-pytorch repo).

    forward() returns: (quantised latents, nearest-code indices, vq loss).
    """

    def __init__(self, args):
        super().__init__()
        # `is_c` selects between two sets of codebook hyper-parameters on args.
        self.num_codebook_vectors = args.c_num_codebook_vectors if args.is_c else args.num_codebook_vectors
        self.latent_dim = args.c_latent_dim if args.is_c else args.latent_dim
        self.beta = args.beta
        self.no_vq = getattr(args, "no_vq", False)
        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        bound = 1.0 / self.num_codebook_vectors
        self.embedding.weight.data.uniform_(-bound, bound)

    def _nearest_code(self, flat):
        # Squared L2 distance ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e,
        # computed against every codebook row at once.
        dist = torch.sum(flat ** 2, dim=1, keepdim=True) \
            + torch.sum(self.embedding.weight ** 2, dim=1) \
            - 2 * torch.matmul(flat, self.embedding.weight.t())
        return torch.argmin(dist, dim=1)

    def forward(self, z):
        """Quantise z of shape (B, C, H, W); in no_vq mode z passes through untouched."""
        if self.no_vq:
            # Pass-through: no quantisation, z is returned as-is (still B,C,H,W).
            z_q = z
            with torch.no_grad():
                # Indices are still computed, purely for logging.
                flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.latent_dim)
                indices = self._nearest_code(flat)
            # Dummy scalar loss keeps the return signature uniform for callers.
            loss = torch.tensor(0.0, device=z.device)
        else:
            channels_last = z.permute(0, 2, 3, 1).contiguous()
            flat = channels_last.view(-1, self.latent_dim)
            indices = self._nearest_code(flat)
            z_q = self.embedding(indices).view(channels_last.shape)
            # Codebook term pulls codes toward encoder outputs; the beta-weighted
            # commitment term pulls encoder outputs toward their assigned codes.
            loss = torch.mean((z_q.detach() - channels_last) ** 2) \
                + self.beta * torch.mean((z_q - channels_last.detach()) ** 2)
            # Straight-through estimator: forward value is z_q, gradient flows to z.
            z_q = channels_last + (z_q - channels_last).detach()
            z_q = z_q.permute(0, 3, 1, 2)
        return z_q, indices, loss
"""
Code for class Online_Codebook taken from https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py for reproducibility
From the paper "Online Clustered Codebook" https://arxiv.org/abs/2307.15139
"""
class Online_Codebook(nn.Module):
    """
    Improved version over vector quantiser, with the dynamic initialisation
    for these unoptimised "dead" points.

    num_embed: number of codebook entries
    embed_dim: dimensionality of each codebook entry
    beta: weight for the commitment loss
    distance: distance for looking up the closest code ('l2' or 'cos')
    anchor: anchor sampling method ('closest', 'random' or 'probrandom')
    first_batch: if true, the offline version of the model (anchors only once)
    contras_loss: if true, add a contrastive loss to further improve performance
    """
    def __init__(self, args):
        super().__init__()
        # `is_c` selects between two sets of codebook hyper-parameters on args.
        self.num_embed = args.c_num_codebook_vectors if args.is_c else args.num_codebook_vectors
        self.embed_dim = args.c_latent_dim if args.is_c else args.latent_dim
        self.beta = getattr(args, "online_beta", 0.25)
        self.distance = getattr(args, "online_distance", "cos")
        self.anchor = getattr(args, "online_anchor", "probrandom")
        self.first_batch = getattr(args, "online_first_batch", False)
        self.contras_loss = getattr(args, "online_contras_loss", False)
        self.decay = getattr(args, "online_decay", 0.99)
        # Set to True after the first training batch when `first_batch` is on,
        # which disables further anchor-based re-initialisation.
        self.init = False
        # NOTE(review): read from args but never used in forward() — unlike
        # Codebook, this class always quantises. Confirm whether no_vq support
        # was intended here as well.
        self.no_vq = getattr(args, "no_vq", False)
        # History buffer of encoder features, used by the 'random' anchor mode.
        self.pool = FeaturePool(self.num_embed, self.embed_dim)
        self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
        # EMA of per-code usage probability; low values flag "dead" codes.
        self.register_buffer("embed_prob", torch.zeros(self.num_embed))

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantise z (B, C, H, W); returns (z_q, encoding_indices, loss)."""
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.embed_dim)
        # calculate the distance; note: both variants are similarities (larger
        # means closer), so the closest code is the LAST index after an
        # ascending sort. z is detached here, so gradients reach the encoder
        # only through the loss terms below, not through the lookup.
        if self.distance == 'l2':
            # negated l2 distance from z to embeddings e_j: -(z - e)^2 = -z^2 - e^2 + 2 e * z
            d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
                torch.sum(self.embedding.weight ** 2, dim=1) + \
                2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
        elif self.distance == 'cos':
            # cosine similarity from z to embeddings e_j
            normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
            normed_codebook = F.normalize(self.embedding.weight, dim=1)
            d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
        # encoding: ascending sort, so the best (most similar) code is last
        sort_distance, indices = d.sort(dim=1)
        # look up the closest point for the indices
        encoding_indices = indices[:,-1]
        # one-hot matrix of assignments, shape (B*H*W, num_embed)
        encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
        encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
        # quantise and unflatten
        z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding: beta-weighted commitment term + codebook term
        loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()
        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        # count per-code usage in this batch
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        min_encodings = encodings
        # online clustered reinitialisation for unoptimised ("dead") points
        if self.training:
            # EMA of the average usage of code entries
            self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay)
            # running average updates
            if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
                # closest sampling: for each code, take the feature most similar to it
                if self.anchor == 'closest':
                    sort_distance, indices = d.sort(dim=0)
                    random_feat = z_flattened.detach()[indices[-1,:]]
                # feature pool based random sampling
                elif self.anchor == 'random':
                    random_feat = self.pool.query(z_flattened.detach())
                # probabilistic random sampling: draw one feature per code,
                # weighted by similarity
                elif self.anchor == 'probrandom':
                    norm_distance = F.softmax(d.t(), dim=1)
                    prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
                    random_feat = z_flattened.detach()[prob]
                # decay parameter based on the average usage: rarely-used codes
                # get decay near 1 (snapped toward the anchor feature), while
                # frequently-used codes get decay near 0 (left in place);
                # embedding weights are mutated in place, bypassing autograd
                decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim)
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
                if self.first_batch:
                    self.init = True
            # contrastive loss: per code, the most similar features act as the
            # positive and the least similar half as negatives (temperature 0.07)
            if self.contras_loss:
                sort_distance, indices = d.sort(dim=0)
                dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True)
                dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:]
                dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
                contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
                loss += contra_loss
        # upstream implementation returned (z_q, loss, (perplexity, min_encodings, encoding_indices))
        # return z_q, loss, (perplexity, min_encodings, encoding_indices)
        return z_q, encoding_indices, loss
class FeaturePool():
    """
    This class implements a feature buffer that stores previously encoded features.

    The buffer lets the codebook be initialised from a history of generated
    features rather than only the ones produced by the latest encoder step.
    """
    def __init__(self, pool_size, dim=64):
        """
        Initialize the FeaturePool class.

        Parameters:
            pool_size(int) -- the size of the feature buffer
            dim(int)       -- dimensionality of each stored feature vector
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            # number of valid (stored) rows; the rest remain random init
            self.nums_features = 0
            # random init in [-1/pool_size, 1/pool_size)
            self.features = (torch.rand((pool_size, dim)) * 2 - 1) / pool_size

    def query(self, features):
        """
        Update the pool with the incoming batch and return the full pool.

        Parameters:
            features(Tensor) -- batch of feature vectors, shape (batch, dim)

        Returns the pool tensor of shape (pool_size, dim).
        """
        self.features = self.features.to(features.device)
        if self.nums_features < self.pool_size:
            if features.size(0) > self.pool_size:
                # the batch is large enough to replace the whole pool at once
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
                self.nums_features = self.pool_size
            else:
                # the mini-batch does not cover the pool: store as many rows as
                # still fit. BUGFIX: the original wrote the entire batch into
                # the slice, which raised a shape-mismatch error whenever
                # nums_features + batch > pool_size; clamp to remaining space.
                num = min(self.nums_features + features.size(0), self.pool_size)
                self.features[self.nums_features:num] = features[:num - self.nums_features]
                self.nums_features = num
        else:
            if features.size(0) > int(self.pool_size):
                # pool already full and batch is bigger: resample the whole pool
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
            else:
                # pool full, small batch: overwrite a random subset of rows
                random_id = torch.randperm(self.pool_size)
                self.features[random_id[:features.size(0)]] = features
        return self.features