-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVariationalAutoencoder.py
More file actions
165 lines (140 loc) · 7.23 KB
/
VariationalAutoencoder.py
File metadata and controls
165 lines (140 loc) · 7.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from gymnasium import spaces
from CustomCNN import CustomCNN
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image
# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
""" these classes were taken from https://avandekleut.github.io/vae/"""
class VariationalEncoder(nn.Module):
    """Encoder half of a VAE: flattens images, maps them to a latent sample
    via the reparameterization trick, and stores the KL term in ``self.kl``.

    Adapted from https://avandekleut.github.io/vae/
    """

    def __init__(self, input_dims, latent_dims):
        """
        :param input_dims: iterable of input dimensions (e.g. (C, H, W));
            their product is the flattened input size.
        :param latent_dims: size of the latent space.
        """
        super(VariationalEncoder, self).__init__()
        product = 1
        for dim in input_dims:
            product *= dim
        self.linear1 = nn.Linear(product, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, latent_dims)  # predicts the mean
        self.linear4 = nn.Linear(1024, latent_dims)  # predicts log std dev
        self.N = th.distributions.Normal(0, 1)
        # Fix: only move the sampling distribution to the GPU when one exists.
        # The unconditional .cuda() calls crashed on CPU-only machines.
        if th.cuda.is_available():
            self.N.loc = self.N.loc.cuda()  # hack to get sampling on the GPU
            self.N.scale = self.N.scale.cuda()
        self.kl = 0  # KL divergence of the most recent forward pass

    def forward(self, x):
        x = th.flatten(x, start_dim=1)  # why is max sometimes a decimal less than 1?
        x = th.sigmoid(self.linear1(x))  # th.sigmoid: F.sigmoid is deprecated
        x = th.sigmoid(self.linear2(x))
        mu = self.linear3(x)
        sigma = th.exp(self.linear4(x))  # exp keeps the std dev positive
        # Reparameterization trick: sample eps ~ N(0, 1), then shift/scale it.
        z = mu + sigma*self.N.sample(mu.shape)
        # KL term as in the reference implementation. NOTE(review): the exact
        # KL(N(mu, sigma) || N(0, 1)) halves the sigma**2 + mu**2 terms;
        # kept as-is to preserve training behavior.
        self.kl = (sigma**2 + mu**2 - th.log(sigma) - 1/2).sum()
        return z
class Decoder(nn.Module):
    """Decoder half of the VAE: expands a latent vector back into an image
    tensor with pixel values scaled into [0, 255].

    Adapted from https://avandekleut.github.io/vae/
    """

    def __init__(self, latent_dims, output_dims):
        """
        :param latent_dims: size of the latent space.
        :param output_dims: (channels, height, width) of the reconstruction.
        """
        # Record the target geometry before nn.Module initialisation.
        self.n_channels, self.height, self.width = output_dims
        self.unrolled_dim = self.n_channels * self.height * self.width
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 1024)
        self.linear2 = nn.Linear(1024, 2048)
        self.linear3 = nn.Linear(2048, self.unrolled_dim)

    def forward(self, z):
        hidden = F.relu(self.linear1(z))
        hidden = F.relu(self.linear2(hidden))
        # Sigmoid squashes to (0, 1); scaling by 255 gives pixel intensities.
        flat_image = th.sigmoid(self.linear3(hidden)) * 255
        image_shape = (-1, self.n_channels, self.height, self.width)
        return flat_image.reshape(image_shape)
class VariationalAutoencoder(nn.Module):
    """Pairs a VariationalEncoder with a Decoder and owns an Adam optimizer
    over both, so callers can train the whole model through one object."""

    def __init__(self, input_dims, latent_dims, kl_divergence_weight=1):
        """
        :param input_dims: image dimensions fed to the encoder and produced
            by the decoder.
        :param latent_dims: size of the shared latent space.
        :param kl_divergence_weight: scaling applied to the KL loss term by
            the training loop (stored here, not used in forward).
        """
        super(VariationalAutoencoder, self).__init__()
        self.kl_divergence_weight = kl_divergence_weight
        self.latent_dims = latent_dims
        self.encoder = VariationalEncoder(input_dims, latent_dims)
        self.decoder = Decoder(latent_dims, input_dims)
        # Created after the submodules so it captures all their parameters.
        self.optimizer = th.optim.Adam(self.parameters())

    def forward(self, x):
        # Encode to a sampled latent code, then reconstruct the input.
        latent = self.encoder(x)
        return self.decoder(latent)
class VariationalAutoencoderFeaturesExtractor(CustomCNN):
    """Features extractor whose network is the encoder half of a trainable
    variational autoencoder built to match the observation space."""

    # Fallback latent size. NOTE(review): not referenced within this class —
    # confirm callers actually use it.
    DEFAULT_LATENT_DIMS = 128

    @CustomCNN.model.setter
    def model(self, encoder):
        # When no pretrained encoder is supplied, build a fresh VAE sized to
        # the observation space and use its encoder as the feature network.
        if encoder is None:
            print("Defaulting to basic trainable autoencoder.")
            print(f"Observation shape: {self._observation_space.shape}")
            self.n_input_channels, self.height, self.width = self._observation_space.shape
            self.variational_autoencoder = VariationalAutoencoder((self.n_input_channels, self.height, self.width), self._features_dim, self.kl_divergence_weight)
            encoder = self.variational_autoencoder.encoder
        self._model = encoder
        for param in self._model.parameters(): # freeze weights when not training the full autoencoder
            param.requires_grad = False

    #@CustomCNN.preprocessing_function.setter
    #def preprocessing_function(self, preprocessing_function):
        #if preprocessing_function is None:
            ## unroll based on dimensions

    @property
    def variational_autoencoder(self):
        # Backing VAE; its encoder is the same object as self._model when the
        # default path in the `model` setter is taken.
        return self._variational_autoencoder

    @variational_autoencoder.setter
    def variational_autoencoder(self, variational_autoencoder):
        self._variational_autoencoder = variational_autoencoder

    def __init__(self, observation_space: spaces.Box, features_dim: int, base_model = None, weights = None,
                 preprocessing_function = None, kl_divergence_weight = 1):
        """
        :param observation_space: Box space whose shape sizes the VAE.
        :param features_dim: latent dimension / extracted feature size.
        :param base_model: NOTE(review): accepted but unused here — presumably
            consumed by CustomCNN.__init__; confirm.
        :param weights: NOTE(review): accepted but unused in this chunk.
        :param preprocessing_function: NOTE(review): accepted but unused in
            this chunk.
        :param kl_divergence_weight: scaling of the KL loss during training.
        """
        # Must be set before super().__init__(), which presumably triggers the
        # `model` setter above that reads kl_divergence_weight — TODO confirm
        # against CustomCNN.
        self.kl_divergence_weight = kl_divergence_weight
        super().__init__(observation_space, features_dim)
        # Observations collected by VAETrainingCallback for periodic training.
        self.training_buffer = []
def train(autoencoder, data, epochs=1):
    """Run a short training pass over ``data`` and return the autoencoder.

    All weights are unfrozen for the duration of training and re-frozen
    afterwards (the features extractor keeps them frozen during RL updates).
    The mean loss over all batches is printed and appended to a per-model
    losses file.

    :param autoencoder: VariationalAutoencoder-like module exposing
        ``optimizer``, ``encoder.kl``, ``kl_divergence_weight`` and
        ``latent_dims``.
    :param data: iterable of batches whose first element is the image tensor
        (e.g. a DataLoader over a TensorDataset).
    :param epochs: number of passes over ``data``.
    :return: the same autoencoder instance, with gradients frozen.
    """
    for param in autoencoder.parameters():
        param.requires_grad = True
    # Fix: send batches to wherever the model's weights actually live instead
    # of relying on the module-level global ``device`` (which crashed or
    # device-mismatched when the model was not on that device).
    model_device = next(autoencoder.parameters()).device
    training_loss = []
    for _ in range(epochs):
        for x in data:
            x = x[0]
            x = x.to(model_device, dtype=th.float32)
            autoencoder.optimizer.zero_grad()
            x_hat = autoencoder(x)
            # this was the original loss, which I've swapped out for MSE
            #loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            # NOTE: mse_loss averages over elements while encoder.kl is a sum,
            # so kl_divergence_weight also absorbs that scale difference.
            loss = nn.functional.mse_loss(x_hat, x) + autoencoder.kl_divergence_weight * autoencoder.encoder.kl
            loss.backward()
            autoencoder.optimizer.step()
            training_loss.append(loss.item())  # .item(): cheap scalar extraction
    # Guard against empty data: np.mean([]) would log NaN with a warning.
    if training_loss:
        average_training_loss_over_epochs = np.mean(training_loss)
        print("Training loss:", average_training_loss_over_epochs)
        with open(f"vae_{autoencoder.latent_dims}_kl_weight_{autoencoder.kl_divergence_weight}_losses.txt", 'a') as f:
            f.write(f"{average_training_loss_over_epochs}\n")
    for param in autoencoder.parameters():
        param.requires_grad = False
    return autoencoder
class VAETrainingCallback(BaseCallback):
    """Stable-Baselines3 callback that buffers observations during rollouts,
    periodically retrains the policy's VAE features extractor, and saves an
    original/reconstruction image pair for visual inspection."""

    def _on_step(self):
        """Buffer a batch of observations every 10 timesteps and retrain the
        VAE once at least 1024 frames have accumulated.

        :return: True, so rollout collection always continues (SB3 aborts
            when a callback's step returns a falsy value; the original
            returned None).
        """
        # is this the right way to access and train the encoder?
        features_extractor = self.model.policy.features_extractor
        if len(features_extractor.training_buffer) >= 1024:
            print("Training VAE...")
            tensor_x = th.from_numpy(np.array(features_extractor.training_buffer))
            my_dataset = TensorDataset(tensor_x)
            my_dataloader = DataLoader(my_dataset, batch_size=32)
            features_extractor.variational_autoencoder = train(features_extractor.variational_autoencoder, my_dataloader)
            features_extractor.training_buffer = []
        # is this the right way to get the observations? I tried get_images
        # and that gave me raw data, not the preprocessed images
        # NOTE(review): step_wait() advances the vectorized env an extra step
        # outside the learner's control loop — confirm this is intended.
        if self.num_timesteps % 10 == 0:
            image, _, _, _ = self.training_env.step_wait()
            num_concurrent = len(image)
            for i in range(num_concurrent):
                features_extractor.training_buffer.append(image[i])
        return True  # fix: _on_step must return a bool to keep training alive

    def _on_rollout_end(self):
        """Save the first buffered observation and its VAE reconstruction as
        PNGs so reconstruction quality can be eyeballed during training."""
        features_extractor = self.model.policy.features_extractor
        if len(features_extractor.training_buffer) > 0:
            input_image = features_extractor.training_buffer[0]
            im = Image.fromarray(np.squeeze(input_image))
            im.save(f"original_img{features_extractor._features_dim}_{features_extractor.kl_divergence_weight}.png")
            # Drop the leading (channel) axis so the flat VAE output can be
            # reshaped for PIL. NOTE(review): feeding a single (C, H, W) frame
            # only matches the encoder's flatten(start_dim=1) when C == 1 —
            # confirm observations are single-channel.
            input_image_shape = input_image.shape[1:]
            input_tensor = th.tensor(input_image, dtype=th.float32).to(device)
            reconstructed_image = features_extractor.variational_autoencoder(input_tensor).detach().cpu().numpy()
            reconstructed_image = np.reshape(reconstructed_image, input_image_shape)
            reconstructed_image = reconstructed_image.astype(np.uint8)
            reconstructed_image = Image.fromarray(reconstructed_image)
            reconstructed_image.save(f"reconstructed_img_{features_extractor._features_dim}_{features_extractor.kl_divergence_weight}.png")