Variational Autoencoder (VAE)
Theory
Here, we summarize the basics of the variational autoencoder (VAE).
Let's assume we want to model a complicated distribution $p(x)$. To do that we can resort to a latent variable model, which allows us to utilize simple distributions $p(x \mid z)$ and $p(z)$ such that:

$$p(x) = \int p(x \mid z)\, p(z)\, dz.$$
In a machine learning setting we are given a dataset $\mathcal{X} = \{x_1, \ldots, x_N\}$ and we want to maximize the log-likelihood with respect to the parameters $\theta$:

$$\log p_\theta(\mathcal{X}) = \sum_{i=1}^{N} \log p_\theta(x_i) = \sum_{i=1}^{N} \log \int p_\theta(x_i \mid z_i)\, p(z_i)\, dz_i.$$
However, to optimize this objective we need to compute the integral above, which is intractable, and so is its gradient.
The idea is that, instead of maximizing $\log p_\theta(\mathcal{X})$ directly, we try to find a lower bound that is easier to optimise.
Therefore, we find the expression of the evidence lower bound (ELBO) to be equal to

$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) =: \mathcal{L}(\theta, \phi),$$

where $q_\phi(z \mid x)$ is the so-called variational distribution.
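For reference, one standard way to obtain this bound is to introduce $q_\phi(z \mid x)$ inside the marginal likelihood and apply Jensen's inequality:

$$\log p_\theta(x) = \log \mathbb{E}_{q_\phi(z \mid x)}\!\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] \geq \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big).$$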
Now we have to make some choices:
- we choose a simple prior for $z$, that is $p(z) = \mathcal{N}(z \mid 0, I)$.
- for $p_\theta(x \mid z)$ we can pick a parametric distribution, such that its parameters are produced by a neural network (the decoder of our VAE)
- for example, for a binary input $x$, we can choose our decoder to output the parameter $\theta$ of the Bernoulli distribution
- in other cases (e.g., for real-valued inputs) we can choose $p_\theta(x \mid z)$ to be a Normal distribution with diagonal covariance, whose parameters are produced by the decoder
Additionally, we assume that there are no dependencies between the latent variables $z_i$ that correspond to the different observations $x_i$, for $i = 1, \ldots, N$ (mean-field assumption).
Therefore we have $q_\phi(z_1, \ldots, z_N \mid x_1, \ldots, x_N) = \prod_{i=1}^{N} q_\phi(z_i \mid x_i)$, with each factor chosen as $q_\phi(z_i \mid x_i) = \mathcal{N}\big(z_i \mid \mu_i, \mathrm{diag}(\sigma_i^2)\big)$. A nice thing about this choice is that we can compute the $\mathrm{KL}\big(q_\phi(z_i \mid x_i) \,\|\, p(z)\big)$ term in closed form, because we have two multivariate Normal distributions.
The parameters $\mu_i$ and $\sigma_i$ for each sample are also learned with a neural network (the encoder) that takes the observation $x_i$ as input and outputs the parameters of the Normal distribution.
Furthermore, we can simplify the KL term to the closed-form expression

$$\mathrm{KL}\big(q_\phi(z_i \mid x_i) \,\|\, p(z)\big) = \frac{1}{2} \sum_{d=1}^{L} \left(\sigma_{i,d}^2 + \mu_{i,d}^2 - 2 \log \sigma_{i,d} - 1\right),$$

where $L$ is the dimensionality of the latent space.
To optimise the ELBO with respect to the parameters of the encoder and decoder neural networks, we have to resort to the reparametrization trick. In particular, we draw the samples $z_i \sim q_\phi(z_i \mid x_i)$ as follows (a small gradient-check sketch is given after this list):
- sample $\epsilon_i \sim \mathcal{N}(0, I)$
- set $z_i = \mu_i + \sigma_i \odot \epsilon_i$.
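The snippet below is a stand-alone illustration (with toy values, not part of the VAE implementation) showing that a reparametrized sample is differentiable with respect to $\mu$ and $\log \sigma$, which is what makes gradient-based optimisation of the ELBO possible:

```python
import torch

mu = torch.zeros(4, 2, requires_grad=True)        # means of q(z|x) (toy values)
logsigma = torch.zeros(4, 2, requires_grad=True)  # log-sigmas of q(z|x) (toy values)

eps = torch.randn(4, 2)                 # eps ~ N(0, I): all randomness lives here
z = mu + torch.exp(logsigma) * eps      # reparametrized sample z = mu + sigma * eps

z.sum().backward()                         # gradients flow through the sample
print(mu.grad.shape, logsigma.grad.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```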
Coding the VAE
Here, we consider a simple example with the MNIST dataset. We will implement the VAE class step by step.
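All methods below belong to the VAE class. A possible constructor is sketched here; the hidden and latent sizes (400 and 32) are assumptions, but the attribute names (linear1, linear21, linear22, linear3, linear4, latent_dim) are the ones used by the methods that follow.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, obs_dim=784, hidden_dim=400, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: x -> hidden -> (mu, logsigma)
        self.linear1 = nn.Linear(obs_dim, hidden_dim)
        self.linear21 = nn.Linear(hidden_dim, latent_dim)  # mu head
        self.linear22 = nn.Linear(hidden_dim, latent_dim)  # logsigma head
        # Decoder: z -> hidden -> theta
        self.linear3 = nn.Linear(latent_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, obs_dim)
```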
Encoder
We implement an encoder that outputs the parameters $\mu_i$ and $\log \sigma_i$ of $q_\phi(z_i \mid x_i)$ for a given datapoint $x_i$.

Encoder structure.
```python
def encoder(self, x):
    # Obtain the parameters of q(z) for a batch of data points.
    # Args:
    #     x: Batch of data points, shape [batch_size, obs_dim]
    # Returns:
    #     mu: Means of q(z), shape [batch_size, latent_dim]
    #     logsigma: Log-sigmas of q(z), shape [batch_size, latent_dim]
    h_relu = torch.relu(self.linear1(x))
    mu = self.linear21(h_relu)
    logsigma = self.linear22(h_relu)
    return mu, logsigma
```
Sampling with reparametrization
The next step is to implement the sampling with reparametrization, to obtain the latent variables $z_i$.
```python
def sample_with_reparam(self, mu, logsigma):
    # Draw a sample from q(z) with reparametrization.
    # We draw a single sample z_i for each data point x_i.
    # Args:
    #     mu: Means of q(z) for the batch, shape [batch_size, latent_dim]
    #     logsigma: Log-sigmas of q(z) for the batch, shape [batch_size, latent_dim]
    # Returns:
    #     z: Latent variables sampled from q(z), shape [batch_size, latent_dim]
    batch_size, latent_dim = mu.shape
    eps = torch.normal(0, 1, size=(batch_size, latent_dim)).to(device)
    sigma = torch.exp(logsigma)
    z = sigma * eps + mu
    return z
```
Decoder
The decoder takes the samples $z_i$ and produces the parameters of the data likelihood $p_\theta(x_i \mid z_i)$.
Our data is binary, so we use a Bernoulli likelihood:

$$p_\theta(x_i \mid z_i) = \prod_{d} \theta_{i,d}^{\,x_{i,d}} \left(1 - \theta_{i,d}\right)^{1 - x_{i,d}},$$

where $\theta_i$ is the output of the decoder and the product runs over the dimensions (pixels) of $x_i$. The parameters $\theta_{i,d}$ must lie in the interval $(0, 1)$; therefore, we use a sigmoid activation function in the last layer of the decoder.
The decoder has the following structure:

Decoder structure.
```python
def decoder(self, z):
    # Convert sampled latent variables z into observations x.
    # Args:
    #     z: Sampled latent variables, shape [batch_size, latent_dim]
    # Returns:
    #     theta: Parameters of the conditional likelihood, shape [batch_size, obs_dim]
    h_relu = torch.relu(self.linear3(z))
    theta = torch.sigmoid(self.linear4(h_relu))
    return theta
```
KL divergence
To compute the ELBO, we will need to compute the KL divergence $\mathrm{KL}\big(q_\phi(z_i \mid x_i) \,\|\, p(z)\big)$, where $p(z) = \mathcal{N}(0, I)$ is the standard multivariate Normal distribution (zero mean, identity covariance).
The KL divergence can be computed in closed form.
```python
def kl_divergence(self, mu, logsigma):
    # Compute KL divergence KL(q_i(z) || p(z)) for each q_i in the batch.
    # Args:
    #     mu: Means of the q_i distributions, shape [batch_size, latent_dim]
    #     logsigma: Logarithm of standard deviations of the q_i distributions, shape [batch_size, latent_dim]
    # Returns:
    #     kl: KL divergence for each of the q_i distributions, shape [batch_size]
    sigma = torch.exp(logsigma)
    pre_kl = sigma**2 + mu**2 - 2 * logsigma - 1
    kl = 0.5 * torch.sum(pre_kl, dim=1)
    return kl
```
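As a sanity check (an illustrative, self-contained snippet rather than part of the model code), the closed-form expression can be compared against torch.distributions for diagonal Gaussians:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 3)          # toy means
logsigma = torch.randn(8, 3)    # toy log-sigmas
sigma = torch.exp(logsigma)

closed_form = 0.5 * torch.sum(sigma**2 + mu**2 - 2 * logsigma - 1, dim=1)
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).sum(dim=1)

print(torch.allclose(closed_form, reference, atol=1e-5))  # True
```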
ELBO
Finally, we can compute the ELBO using all the methods that we implemented above.
The ELBO for a single sample $x_i$ reads as:

$$\mathcal{L}_i(\theta, \phi) = \mathbb{E}_{q_\phi(z_i \mid x_i)}\big[\log p_\theta(x_i \mid z_i)\big] - \mathrm{KL}\big(q_\phi(z_i \mid x_i) \,\|\, p(z)\big),$$

where $\phi$ and $\theta$ are the parameters of the encoder and the decoder neural networks, respectively. In the code, the expectation is estimated with a single Monte Carlo sample drawn with the reparametrization trick.
```python
def elbo(self, x):
    # Estimate the ELBO for the mini-batch of data.
    # Args:
    #     x: Mini-batch of the observations, shape [batch_size, obs_dim]
    # Returns:
    #     elbo_mc: MC estimate of ELBO for each sample in the mini-batch, shape [batch_size]
    mu, logsigma = self.encoder(x)
    z = self.sample_with_reparam(mu, logsigma)
    theta = self.decoder(z)
    kl = self.kl_divergence(mu, logsigma)
    # Bernoulli log-likelihood log p(x | z), summed over the observation dimensions.
    log_px_ifz = torch.sum(x * torch.log(theta) + (1 - x) * torch.log(1 - theta), dim=1)
    elbo_mc = log_px_ifz - kl
    return elbo_mc
```
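A side note on the reconstruction term: the Bernoulli log-likelihood above is exactly the negative binary cross-entropy, so an equivalent (and typically more numerically robust) formulation could use torch.nn.functional.binary_cross_entropy. A small self-contained check with random toy tensors:

```python
import torch
import torch.nn.functional as F

x = torch.bernoulli(torch.full((5, 10), 0.5))  # toy binary data
theta = torch.sigmoid(torch.randn(5, 10))      # toy Bernoulli parameters

log_px = torch.sum(x * torch.log(theta) + (1 - x) * torch.log(1 - theta), dim=1)
log_px_bce = -F.binary_cross_entropy(theta, x, reduction='none').sum(dim=1)

print(torch.allclose(log_px, log_px_bce, atol=1e-5))  # True
```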
Generating new data
We can then implement a method for generating new data points, by sampling $z \sim p(z)$ from the prior and then passing the samples through the decoder neural network.
```python
def sample(self, num_samples):
    # Generate samples from the model.
    # Args:
    #     num_samples: Number of samples to generate.
    # Returns:
    #     x: Samples generated by the model, shape [num_samples, obs_dim]
    zp = torch.normal(0, 1, size=(num_samples, self.latent_dim)).to(device)
    theta = self.decoder(zp)
    x = torch.bernoulli(theta)
    return x
```
Now we can train our model using the negative ELBO as the loss function, that is,

```python
loss = -vae.elbo(x).mean(-1)
```
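A minimal end-to-end training sketch is given below; the layer sizes, the Adam optimizer, the learning rate, the batch size, and the binarization step are assumptions made for illustration rather than prescribed choices.

```python
import torch
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST loader; images are binarized below to match the Bernoulli likelihood.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

vae = VAE(obs_dim=784, hidden_dim=400, latent_dim=32).to(device)  # sizes are assumptions
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(5):
    for x, _ in train_loader:
        x = (x.view(x.shape[0], -1) > 0.5).float().to(device)  # flatten and binarize
        loss = -vae.elbo(x).mean(-1)  # negative ELBO, averaged over the batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# After training, new digits can be generated and reshaped for plotting:
new_digits = vae.sample(num_samples=16).view(-1, 28, 28).cpu()
```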
Results
After training for five epochs we observe the following results.
- Sampling new data points using the sample method:
- Visualizing the embeddings by taking the means $\mu_i$ at the encoder output and running the t-SNE algorithm.
Indeed, we can observe that the encoder learned to assign similar means to images that belong to the same class; therefore, these images are close to each other in the latent space.