diff --git a/beta-vae/README.md b/beta-vae/README.md
new file mode 100644
index 0000000000..c1a462f7a9
--- /dev/null
+++ b/beta-vae/README.md
@@ -0,0 +1,20 @@
+# Basic beta-VAE Example
+
+This is an implementation of the paper [beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://openreview.net/pdf?id=Sy2fzU9gl).
+
+We train on the FashionMNIST dataset, using a very simple convolutional neural network architecture.
+
+```bash
+pip install -r requirements.txt
+python main.py
+```
+The main.py script accepts the following arguments:
+
+```bash
+optional arguments:
+  --batch-size    input batch size for training (default: 128)
+  --epochs        number of epochs to train (default: 10)
+  --no-cuda       disables CUDA training
+  --seed          random seed (default: 1)
+  --log-interval  how many batches to wait before logging training status
+```
\ No newline at end of file
diff --git a/beta-vae/main.py b/beta-vae/main.py
new file mode 100644
index 0000000000..5ca9953c61
--- /dev/null
+++ b/beta-vae/main.py
@@ -0,0 +1,203 @@
+from __future__ import print_function
+import argparse
+import torch
+import torch.utils.data
+from torch import nn, optim
+from torch.nn import functional as F
+from torchvision import datasets, transforms
+from torchvision.utils import save_image
+
+
+from typing import List, Tuple
+from torch import Tensor
+
+
+
+parser = argparse.ArgumentParser(description='beta-VAE FashionMNIST Example')
+parser.add_argument('--batch-size', type=int, default=128, metavar='N',
+                    help='input batch size for training (default: 128)')
+parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                    help='number of epochs to train (default: 10)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S',
+                    help='random seed (default: 1)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                    help='how many batches to wait before logging training status')
+args = parser.parse_args()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+torch.manual_seed(args.seed)
+
+device = torch.device("cuda" if args.cuda else "cpu")
+
+kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+
+train_loader = torch.utils.data.DataLoader(
+    datasets.FashionMNIST('../data', train=True, download=True,
+                          transform=transforms.ToTensor()),
+    batch_size=args.batch_size, shuffle=True, **kwargs)
+test_loader = torch.utils.data.DataLoader(
+    datasets.FashionMNIST('../data', train=False, transform=transforms.ToTensor()),
+    batch_size=args.batch_size, shuffle=True, **kwargs)
+
+
+class BetaVAE(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 latent_dim: int,
+                 hidden_dims: List = None,
+                 beta: int = 4):
+        super(BetaVAE, self).__init__()
+        self.latent_dim = latent_dim
+        self.beta = beta
+
+        modules = []
+        if hidden_dims is None:
+            hidden_dims = [16, 4]
+
+        # Build Encoder: each block halves the spatial size (28 -> 14 -> 7)
+        for h_dim in hidden_dims:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(in_channels, out_channels=h_dim,
+                              kernel_size=3, padding=1),
+                    nn.ReLU(),
+                    nn.MaxPool2d(2, 2))
+            )
+            in_channels = h_dim
+
+        self.encoder = nn.Sequential(*modules)
+        self.fc_mu = nn.Linear(hidden_dims[-1] * 7 * 7, latent_dim)
+        self.fc_var = nn.Linear(hidden_dims[-1] * 7 * 7, latent_dim)
+
+        # Build Decoder
+        modules = []
+
+        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 7 * 7)
+
+        hidden_dims.reverse()
+
+        for i in range(len(hidden_dims) - 1):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(hidden_dims[i],
+                                       hidden_dims[i + 1],
+                                       kernel_size=2,
+                                       stride=2),
+                    nn.ReLU())
+            )
+
+        self.decoder = nn.Sequential(*modules)
+
+        self.final_layer = nn.Sequential(
+            nn.ConvTranspose2d(hidden_dims[-1],
+                               1,
+                               kernel_size=2,
+                               stride=2),
+            nn.Sigmoid())
+
+
+    def encode(self, input: Tensor) -> List[Tensor]:
+        """
+        Encodes the input by passing it through the encoder network
+        and returns the latent codes.
+        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
+        :return: (List[Tensor]) Mean and log variance of the latent Gaussian
+        """
+        result = self.encoder(input)
+        result = torch.flatten(result, start_dim=1)
+
+        # Split the result into mu and var components
+        # of the latent Gaussian distribution
+        mu = self.fc_mu(result)
+        log_var = self.fc_var(result)
+
+        return [mu, log_var]
+
+    def decode(self, z: Tensor) -> Tensor:
+        result = self.decoder_input(z)
+        # Reshape to the encoder's final feature map size (4 channels, 7x7 by default)
+        result = result.view(-1, 4, 7, 7)
+        result = self.decoder(result)
+        result = self.final_layer(result)
+        return result
+
+    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
+        """
+        Reparameterization trick: sample z = mu + std * eps with eps ~ N(0, I),
+        using a single sample as a Monte Carlo estimate of the expectation in the loss.
+        :param mu: (Tensor) Mean of the latent Gaussian
+        :param logvar: (Tensor) Log variance of the latent Gaussian
+        :return: (Tensor) Sampled latent code [N x latent_dim]
+        """
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def forward(self, input: Tensor, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
+        mu, log_var = self.encode(input)
+        z = self.reparameterize(mu, log_var)
+        return self.decode(z), mu, log_var
+
+    def loss_function(self,
+                      *args,
+                      **kwargs) -> Tensor:
+        recons = args[0]
+        input = args[1]
+        mu = args[2]
+        log_var = args[3]
+        # Reconstruction term: summed binary cross-entropy (inputs and outputs are in [0, 1])
+        criterion = nn.BCELoss(reduction='sum')
+        recons_loss = criterion(recons, input)
+        # KL divergence to the standard normal prior, summed over latent dims, averaged over the batch
+        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
+
+        return recons_loss + self.beta * kld_loss
+
+model = BetaVAE(1, 100, beta=5).to(device)
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+
+def train(epoch):
+    model.train()
+    train_loss = 0
+    for batch_idx, (data, _) in enumerate(train_loader):
+        data = data.to(device)
+        optimizer.zero_grad()
+        recon_batch, mu, logvar = model(data)
+        loss = model.loss_function(recon_batch, data, mu, logvar)
+        loss.backward()
+        train_loss += loss.item()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader),
+                loss.item() / len(data)))
+
+    print('====> Epoch: {} Average loss: {:.4f}'.format(
+        epoch, train_loss / len(train_loader.dataset)))
+
+
+def test(epoch):
+    model.eval()
+    test_loss = 0
+    with torch.no_grad():
+        for i, (data, _) in enumerate(test_loader):
+            data = data.to(device)
+            recon_batch, mu, logvar = model(data)
+            test_loss += model.loss_function(recon_batch, data, mu, logvar).item()
+            if i == 0:
+                n = min(data.size(0), 8)
+                # Save the original images and their reconstructions side by side
+                comparison = torch.cat([data[:n],
+                                        recon_batch.view(-1, 1, 28, 28)[:n]])
+                save_image(comparison.cpu(),
+                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)
+
+    test_loss /= len(test_loader.dataset)
+    print('====> Test set loss: {:.4f}'.format(test_loss))
+
+if __name__ == "__main__":
+    for epoch in range(1, args.epochs + 1):
+        train(epoch)
+        test(epoch)
\ No newline at end of file
diff --git a/beta-vae/requirements.txt b/beta-vae/requirements.txt
new file mode 100644
index 0000000000..73348074bf
--- /dev/null
+++ b/beta-vae/requirements.txt
@@ -0,0 +1,4 @@
+torch
+torchvision
+tqdm
+six
diff --git a/beta-vae/results/reconstruction_1.png b/beta-vae/results/reconstruction_1.png
new file mode 100644
index 0000000000..d54339c0a3
Binary files /dev/null and b/beta-vae/results/reconstruction_1.png differ
diff --git a/beta-vae/results/reconstruction_10.png b/beta-vae/results/reconstruction_10.png
new file mode 100644
index 0000000000..e69241a67b
Binary files /dev/null and b/beta-vae/results/reconstruction_10.png differ
diff --git a/beta-vae/results/reconstruction_2.png b/beta-vae/results/reconstruction_2.png
new file mode 100644
index 0000000000..688f3af1ba
Binary files /dev/null and b/beta-vae/results/reconstruction_2.png differ
diff --git a/beta-vae/results/reconstruction_3.png b/beta-vae/results/reconstruction_3.png
new file mode 100644
index 0000000000..0683be46c9
Binary files /dev/null and b/beta-vae/results/reconstruction_3.png differ
diff --git a/beta-vae/results/reconstruction_4.png b/beta-vae/results/reconstruction_4.png
new file mode 100644
index 0000000000..08d9c6a214
Binary files /dev/null and b/beta-vae/results/reconstruction_4.png differ
diff --git a/beta-vae/results/reconstruction_5.png b/beta-vae/results/reconstruction_5.png
new file mode 100644
index 0000000000..6508174dab
Binary files /dev/null and b/beta-vae/results/reconstruction_5.png differ
diff --git a/beta-vae/results/reconstruction_6.png b/beta-vae/results/reconstruction_6.png
new file mode 100644
index 0000000000..7127133ddc
Binary files /dev/null and b/beta-vae/results/reconstruction_6.png differ
diff --git a/beta-vae/results/reconstruction_7.png b/beta-vae/results/reconstruction_7.png
new file mode 100644
index 0000000000..c70e93b675
Binary files /dev/null and b/beta-vae/results/reconstruction_7.png differ
diff --git a/beta-vae/results/reconstruction_8.png b/beta-vae/results/reconstruction_8.png
new file mode 100644
index 0000000000..edf2c205b4
Binary files /dev/null and b/beta-vae/results/reconstruction_8.png differ
diff --git a/beta-vae/results/reconstruction_9.png b/beta-vae/results/reconstruction_9.png
new file mode 100644
index 0000000000..169b2f77bc
Binary files /dev/null and b/beta-vae/results/reconstruction_9.png differ
diff --git a/vae/.vscode/settings.json b/vae/.vscode/settings.json
new file mode 100644
index 0000000000..fee5461855
--- /dev/null
+++ b/vae/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.pythonPath": "/home/aims/anaconda3/envs/aims/bin/python"
+}
\ No newline at end of file
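
For readers who want to see the objective in isolation, the snippet below is a small self-contained sketch of what `reparameterize` and `loss_function` in `main.py` compute, using random tensors in place of real encoder and decoder outputs; the shapes and `beta = 5` simply mirror the defaults used above.

```python
import torch
import torch.nn.functional as F

# Stand-ins for the encoder outputs: a batch of 8 latent Gaussians of dimension 100.
mu = torch.randn(8, 100)
log_var = torch.randn(8, 100)

# Reparameterization trick (as in BetaVAE.reparameterize): z = mu + std * eps, eps ~ N(0, I).
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)

# KL divergence between N(mu, sigma^2) and the standard normal prior,
# summed over latent dimensions and averaged over the batch (as in loss_function).
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

# Reconstruction term: summed binary cross-entropy between a dummy decoder
# output and a dummy input batch, both with values in [0, 1].
recons = torch.rand(8, 1, 28, 28)
target = torch.rand(8, 1, 28, 28)
recons_loss = F.binary_cross_entropy(recons, target, reduction='sum')

# beta > 1 up-weights the KL term, trading reconstruction quality for disentanglement.
beta = 5
loss = recons_loss + beta * kld_loss
print(loss.item())
```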