# --- causal_demon/Models/VAE/Funcs.py (new file in the original diff) ---
import torch as th
from collections import OrderedDict
import torch.nn as nn


def summary(input_size, model):
    """Return an OrderedDict describing every leaf module of *model*.

    Runs one dummy forward pass through *model* and, via forward hooks,
    records for each leaf module its input/output shapes (batch dim
    reported as -1), whether its weight is trainable, and its parameter
    count.

    Args:
        input_size: shape of a single sample, e.g. ``[784]``; a list/tuple
            of such shapes for multi-input networks.
        model: the ``nn.Module`` to inspect.

    Returns:
        OrderedDict mapping ``'<ClassName>-<idx>'`` to a per-layer
        OrderedDict with keys ``input_shape``, ``output_shape``,
        (optionally) ``trainable`` and ``nb_params``.
    """
    layers = OrderedDict()
    hooks = []

    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            m_key = '%s-%i' % (class_name, len(layers) + 1)
            entry = OrderedDict()
            # Replace the batch dimension with -1: shapes are per-sample.
            entry['input_shape'] = [-1] + list(input[0].size())[1:]
            entry['output_shape'] = [-1] + list(output.size())[1:]

            params = 0
            # Guard against parameters that exist as attributes but are None
            # (e.g. nn.Linear(..., bias=False)) — .size() on None would raise.
            if getattr(module, 'weight', None) is not None:
                params += th.prod(th.LongTensor(list(module.weight.size())))
                entry['trainable'] = bool(module.weight.requires_grad)
            if getattr(module, 'bias', None) is not None:
                params += th.prod(th.LongTensor(list(module.bias.size())))
            entry['nb_params'] = params
            layers[m_key] = entry

        # Hook only leaf-ish modules: skip containers and the root model.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) \
                and module is not model:
            hooks.append(module.register_forward_hook(hook))

    # Build the dummy input on the same device as the model so the summary
    # also works after model.to('cuda'). (Variable wrappers are obsolete —
    # plain tensors carry autograd state since PyTorch 0.4.)
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model
        device = th.device('cpu')
    if isinstance(input_size[0], (list, tuple)):
        x = [th.rand(1, *in_size, device=device) for in_size in input_size]
    else:
        x = th.rand(1, *input_size, device=device)

    model.apply(register_hook)
    try:
        model(x)
    finally:
        # Always detach the hooks, even if the forward pass raises,
        # so the model is left unmodified.
        for h in hooks:
            h.remove()

    return layers

# --- (original diff continues with causal_demon/Models/VAE/VAE_MNIST.py) ---
import print_function +import argparse +import torch +import torch.utils.data +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from Funcs import summary + + +parser = argparse.ArgumentParser(description='VAE MNIST Example') +parser.add_argument('--batch-size', type=int, default=128, metavar='N', + help='input batch size for training (default: 128)') +parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='enables CUDA training') +parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') +parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +torch.manual_seed(args.seed) + +device = torch.device("cuda" if args.cuda else "cpu") + +kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} +train_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=True, download=True, + transform=transforms.ToTensor()), + batch_size=args.batch_size, shuffle=True, **kwargs) +test_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=False, transform=transforms.ToTensor()), + batch_size=args.batch_size, shuffle=True, **kwargs) + + +class VAE(nn.Module): + def __init__(self): + super(VAE, self).__init__() + + self.fc1 = nn.Linear(784, 400) + self.fc21 = nn.Linear(400, 20) + self.fc22 = nn.Linear(400, 20) + self.fc3 = nn.Linear(20, 400) + self.fc4 = nn.Linear(400, 784) + + def encode(self, x): + h1 = F.relu(self.fc1(x)) + return self.fc21(h1), self.fc22(h1) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5*logvar) + eps = torch.randn_like(std) + return eps.mul(std).add_(mu) 
+ + def decode(self, z): + h3 = F.relu(self.fc3(z)) + return torch.sigmoid(self.fc4(h3)) + + def forward(self, x): + mu, logvar = self.encode(x.view(-1, 784)) + z = self.reparameterize(mu, logvar) + return self.decode(z), mu, logvar + + +model = VAE().to(device) +o = summary([784],model) +print(dict(o)) +optimizer = optim.Adam(model.parameters(), lr=1e-3) + + +# Reconstruction + KL divergence losses summed over all elements and batch +def loss_function(recon_x, x, mu, logvar): + BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') + + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + + return BCE + KLD + + +def train(epoch): + model.train() + train_loss = 0 + for batch_idx, (data, _) in enumerate(train_loader): + data = data.to(device) + optimizer.zero_grad() + recon_batch, mu, logvar = model(data) + loss = loss_function(recon_batch, data, mu, logvar) + loss.backward() + train_loss += loss.item() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. 
* batch_idx / len(train_loader), + loss.item() / len(data))) + + print('====> Epoch: {} Average loss: {:.4f}'.format( + epoch, train_loss / len(train_loader.dataset))) + + +def test(epoch): + model.eval() + test_loss = 0 + with torch.no_grad(): + for i, (data, _) in enumerate(test_loader): + data = data.to(device) + recon_batch, mu, logvar = model(data) + test_loss += loss_function(recon_batch, data, mu, logvar).item() + if i == 0: + n = min(data.size(0), 8) + comparison = torch.cat([data[:n], + recon_batch.view(args.batch_size, 1, 28, 28)[:n]]) + save_image(comparison.cpu(), + 'results/reconstruction_' + str(epoch) + '.png', nrow=n) + + test_loss /= len(test_loader.dataset) + print('====> Test set loss: {:.4f}'.format(test_loss)) + +if __name__ == "__main__": + for epoch in range(1, args.epochs + 1): + train(epoch) + test(epoch) + with torch.no_grad(): + sample = torch.randn(64, 20).to(device) + sample = model.decode(sample).cpu() + save_image(sample.view(64, 1, 28, 28), + 'results/sample_' + str(epoch) + '.png') diff --git a/causal_demon/Models/VAE/results/reconstruction_1.png b/causal_demon/Models/VAE/results/reconstruction_1.png new file mode 100644 index 0000000..4f5f61a Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_1.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_10.png b/causal_demon/Models/VAE/results/reconstruction_10.png new file mode 100644 index 0000000..4920c01 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_10.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_2.png b/causal_demon/Models/VAE/results/reconstruction_2.png new file mode 100644 index 0000000..8580de7 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_2.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_3.png b/causal_demon/Models/VAE/results/reconstruction_3.png new file mode 100644 index 0000000..183cb48 Binary files /dev/null and 
b/causal_demon/Models/VAE/results/reconstruction_3.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_4.png b/causal_demon/Models/VAE/results/reconstruction_4.png new file mode 100644 index 0000000..0698768 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_4.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_5.png b/causal_demon/Models/VAE/results/reconstruction_5.png new file mode 100644 index 0000000..7385d87 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_5.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_6.png b/causal_demon/Models/VAE/results/reconstruction_6.png new file mode 100644 index 0000000..df0e4b8 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_6.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_7.png b/causal_demon/Models/VAE/results/reconstruction_7.png new file mode 100644 index 0000000..7e5157d Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_7.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_8.png b/causal_demon/Models/VAE/results/reconstruction_8.png new file mode 100644 index 0000000..d4b3bbb Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_8.png differ diff --git a/causal_demon/Models/VAE/results/reconstruction_9.png b/causal_demon/Models/VAE/results/reconstruction_9.png new file mode 100644 index 0000000..5761998 Binary files /dev/null and b/causal_demon/Models/VAE/results/reconstruction_9.png differ diff --git a/causal_demon/data/processed/test.pt b/causal_demon/data/processed/test.pt new file mode 100644 index 0000000..be74312 Binary files /dev/null and b/causal_demon/data/processed/test.pt differ diff --git a/causal_demon/data/processed/training.pt b/causal_demon/data/processed/training.pt new file mode 100644 index 0000000..4fcb128 Binary files /dev/null and b/causal_demon/data/processed/training.pt 
differ diff --git a/causal_demon/data/raw/t10k-images-idx3-ubyte b/causal_demon/data/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/causal_demon/data/raw/t10k-images-idx3-ubyte differ diff --git a/causal_demon/data/raw/t10k-labels-idx1-ubyte b/causal_demon/data/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/causal_demon/data/raw/t10k-labels-idx1-ubyte differ diff --git a/causal_demon/data/raw/train-images-idx3-ubyte b/causal_demon/data/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/causal_demon/data/raw/train-images-idx3-ubyte differ diff --git a/causal_demon/data/raw/train-labels-idx1-ubyte b/causal_demon/data/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/causal_demon/data/raw/train-labels-idx1-ubyte differ