Disentangling Variational Autoencoders
Figure 1. A circular interpolation over the latent variable $\mathbf{z}$ with $\beta = 1$.
1. Introduction
A variational autoencoder (VAE) is a non-linear latent variable model with two exciting properties: assisted by a hidden variable $\mathbf{z}$, it approximates the underlying structure of datapoints $\mathbf{x}$, and it creates a compact, lower-dimensional latent representation of the data capable of generating new data points $\mathbf{y}$ that nevertheless resemble attributes found in $\mathbf{x}$. Powered by (deep) neural networks, VAEs have gained prominence since their introduction by Kingma & Welling (2014), with state-of-the-art results in representation learning for image processing (Vahdat & Kautz, 2020), music generation (Roberts et al., 2019), and molecular design (Gómez-Bombarelli et al., 2018).
However, the latent variable $\mathbf{z}$ that a VAE produces is typically entangled and lacks two properties that matter for human interpretation: independence and interpretability (Higgins et al., 2016). Disentanglement, on the other hand, implies finding low-dimensional representations of data where single latent directions are sensitive to changes in single generative factors while being relatively invariant to changes in others (Bengio et al., 2013). For a VAE, effectively learning disentangled latents would allow the manipulation of a digit's line-weight or scale, for example, by tweaking a single latent "knob".
In this project, we study VAE disentanglement in three stages using the MNIST dataset. First, we familiarize ourselves with the VAE underpinnings by replicating a portion of the results in Kingma & Welling (2014). Second, we explore $\beta$-VAE (Higgins et al., 2016), a VAE extension that introduces a hyperparameter $\beta$ intended to enforce a more efficient latent representation of the data. Lastly, we investigate to what extent conditioning the dataset $\mathbf{x}$ on class labels $\mathbf{u}$ facilitates disentanglement. We perform cross-validation to analyze the impact of $\beta$ on the ELBO and assess the disentanglement quality of our experiments through visual inspection heuristics. For reproducibility, a PyTorch implementation of our work is available at https://github.com/arpastrana/neu_vae.
2. Background
A VAE aims to maximize the log-likelihood of observed data $\mathbf{x}$:
$$\arg \max \log p(\mathbf{x})$$
Introducing an unobserved latent variable $\mathbf{z}$, this is equivalent to:
$$\arg \max \log \int p(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})\, d\mathbf{z}$$
As this integral is typically intractable, a variational lower bound, the Evidence Lower Bound (ELBO), is maximized instead, resorting to an auxiliary recognition distribution $q_{\phi}(\mathbf{z} \mid \mathbf{x})$ on the basis of Jensen's inequality (Hoffman & Johnson, 2016).
$$\log p(\mathbf{x}) \geq \mathrm{ELBO}(\theta, \phi) = \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right] - D_{KL}\left(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right)$$
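The bound itself follows from a short, standard derivation: write the marginal likelihood as an expectation over the recognition distribution and apply Jensen's inequality to move the logarithm inside the expectation.

$$\log p(\mathbf{x}) = \log \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\!\left[\frac{p_{\theta}(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\right] \geq \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\!\left[\log \frac{p_{\theta}(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})}{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\right]$$

Expanding the logarithm on the right-hand side into $\log p_{\theta}(\mathbf{x} \mid \mathbf{z}) + \log p(\mathbf{z}) - \log q_{\phi}(\mathbf{z} \mid \mathbf{x})$ recovers the reconstruction and KL terms above.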
where $\phi, \theta$ are the parameters of a VAE's encoder and decoder, respectively, computed by non-linear function approximators such as neural networks.
2.1 Autoencoding Variational Bayes - AEVB (Kingma & Welling, 2014)
Monte Carlo sampling is used to numerically estimate the expectation $\mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z})\right]$. In practice, a single sample per datapoint suffices, which leads to:
$$\mathrm{ELBO}(\theta, \phi) \approx \log p_{\theta}(\mathbf{x} \mid \mathbf{z}) - D_{KL}\left(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right)$$
The first term of this approximation can be understood as the reconstruction error of the Bernoulli- or Gaussian-decoded data points, while the second corresponds to a regularization term which penalizes the Kullback-Leibler (KL) divergence between the encoding variational posterior $q_{\phi}(\mathbf{z} \mid \mathbf{x})$ and the choice of prior $p(\mathbf{z})$ over the unobserved latent variable.
To make the optimization of the ELBO tractable and differentiable, both the posterior $q_{\phi}(\mathbf{z} \mid \mathbf{x})$ and the prior $p(\mathbf{z})$ are assumed to be Gaussian. This permits the KL divergence to be computed in closed form:
$$-D_{KL}\left(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right) = \frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$
where $J$ is the dimensionality of $\mathbf{z}$. Additionally, the Gaussian assumption enables back-propagation of gradients through the sampling step via the reparametrization trick, where the latent variable $\mathbf{z} \sim q_{\phi}(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$ is calculated as an affine transformation of an auxiliary noise variable $\epsilon \sim \mathcal{N}(0, 1)$:
$$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \epsilon$$
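Concretely, the single-sample ELBO estimate, the closed-form KL term, and the reparametrization trick amount to only a few lines of PyTorch. The sketch below is illustrative rather than taken from our implementation; tensor names such as `x_hat`, `mu`, and `log_var` are our own, and we parametrize the log-variance rather than the variance for numerical stability.

```python
import torch
import torch.nn.functional as F

def reparametrize(mu, log_var):
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + std * eps

def elbo_terms(x, x_hat, mu, log_var):
    """Single-sample estimate of the two ELBO terms for a Bernoulli decoder."""
    # log p_theta(x | z) for Bernoulli outputs is the negative binary
    # cross-entropy, summed over pixels and the mini-batch.
    log_px_given_z = -F.binary_cross_entropy(x_hat, x, reduction="sum")
    # Closed-form KL between N(mu, sigma^2) and the unit Gaussian prior.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return log_px_given_z, kl
```

The single-sample ELBO estimate is then `log_px_given_z - kl`.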
2.2 $\beta$-VAE (Higgins et al., 2016)
To focus learning on statistically independent latent factors, Higgins et al. (2016) add a hyperparameter $\beta$ to the original VAE formulation of Kingma & Welling (2014), increasing the weight of the KL divergence term in the ELBO:
$$\mathrm{ELBO}(\theta, \phi) \approx \log p_{\theta}(\mathbf{x} \mid \mathbf{z}) - \beta\, D_{KL}\left(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\right)$$
Values of $\beta > 1$ hypothetically put pressure on the VAE bottleneck to match the prior $p(\mathbf{z})$ and thus enforce learning a more efficient latent representation of the data. A $\beta$-VAE with $\beta = 1$ corresponds to the original VAE formulation (Kingma & Welling, 2014). It should be noted, however, that disentanglement comes at the cost of diminished reconstruction quality of $\mathbf{x}$, and that values of $\beta$ that are too low or too high may not necessarily lead to disentangled outcomes. Therefore, $\beta$ needs to be calibrated using visual inspection heuristics, or with a disentanglement metric produced, for instance, by mounting a simple linear classifier atop a trained VAE in a supervised manner.
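In code, the only change relative to the earlier sketch is a scalar weight on the KL term. The loss function below is a hypothetical illustration (the function name and signature are ours):

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_hat, mu, log_var, beta=1.0):
    """Negative beta-ELBO: reconstruction error plus beta-weighted KL."""
    bce = F.binary_cross_entropy(x_hat, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + beta * kl  # beta = 1 recovers the original VAE objective
```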
2.3 Conditional $\beta$-VAE
Recent work in representation learning suggests that well-disentangled models cannot be identified without supervision (Locatello et al., 2019). One assumption that leads to identifiability is a conditionally factorized prior over the latent variables $p(\mathbf{z} \mid \mathbf{u})$, which is possible by concurrently observing an auxiliary variable $\mathbf{u}$ that corresponds, e.g., to the time index in a time series, previous data points, or class labels (Hyvärinen et al., 2019; Khemakhem et al., 2020). More formally, this implies working with a dataset of observation pairs $\mathcal{D} = \{(\mathbf{x}_1, \mathbf{u}_1), \ldots, (\mathbf{x}_n, \mathbf{u}_n)\}$ instead of $\mathcal{D} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$. The total data log-likelihood in this case is expressed by:
$$\arg \max \log \int p(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z} \mid \mathbf{u})\, d\mathbf{z}$$
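Assuming, as in a standard conditional VAE, that both the encoder and the decoder also observe $\mathbf{u}$, the corresponding lower bound takes the same form as before with the prior replaced by $p(\mathbf{z} \mid \mathbf{u})$:

$$\log p(\mathbf{x} \mid \mathbf{u}) \geq \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x}, \mathbf{u})}\left[\log p_{\theta}(\mathbf{x} \mid \mathbf{z}, \mathbf{u})\right] - D_{KL}\left(q_{\phi}(\mathbf{z} \mid \mathbf{x}, \mathbf{u}) \,\|\, p(\mathbf{z} \mid \mathbf{u})\right)$$

In the experiments of Section 4.3, we condition the encoder and the decoder on $\mathbf{u}$ in this spirit, while retaining the Gaussian prior and closed-form KL of Section 2.1.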
3. Experimental Setup
3.1 VAE Architecture
For consistency, we use the same neural architecture as in Kingma & Welling (2014) in all our experiments unless otherwise noted. The size of the mini-batches is fixed at 100, and training is capped at 200 epochs due to computational limitations; the optimization objective saturates well within this budget. An epoch corresponds to one full pass over the training set.
3.2 Data
We use the MNIST dataset (LeCun et al., 2010), which consists of 60,000 images of hand-written digits from 0 to 9 at a 28x28 pixel resolution. We use randomly shuffled data loaders to manage the training and validation sets employed to tune the VAE parameters.
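A minimal sketch of such data loaders using `torchvision`; the 50k/10k train/validation split shown here is an illustrative assumption rather than the exact split used in our runs:

```python
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Images are loaded as [0, 1] tensors; flattening to 784 happens in the model.
mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())

# Illustrative 50k/10k split with shuffled mini-batches of 100 (Section 3.1).
train_set, val_set = random_split(mnist, [50_000, 10_000])
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
val_loader = DataLoader(val_set, batch_size=100, shuffle=False)
```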
3.3 Gaussian Encoder
The encoder consists of an input layer with 784 units (corresponding to a flattened 28x28 image), followed by a dense multi-layer perceptron (MLP) layer with 500 units. Two additional linear layers, computing the mean and the variance of the latent variable $\mathbf{z}$ respectively, are attached next, each with as many units as the user-defined dimensionality of $\mathbf{z}$.
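A minimal PyTorch sketch of such an encoder; the class and attribute names are ours, and we output the log-variance rather than the variance as a common numerical convenience:

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """784 -> 500 (tanh) -> mean and log-variance of dimension J."""
    def __init__(self, latent_dim):
        super().__init__()
        self.hidden = nn.Linear(784, 500)
        self.mu = nn.Linear(500, latent_dim)
        self.log_var = nn.Linear(500, latent_dim)

    def forward(self, x):
        h = torch.tanh(self.hidden(x.view(-1, 784)))
        return self.mu(h), self.log_var(h)
```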
3.4 Bernoulli Decoder
After applying the reparametrization trick, samples from $\mathbf{z}$ are fed forward into the decoder's input layer, which has the same number of units as the encoder's last layer. An intermediate 500-unit layer and a 784-unit output layer complete the decoder. All the MLP layers in the VAE use the hyperbolic tangent nonlinearity, except for the output layer of the decoder, where the sigmoid activation is applied.
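The matching decoder sketch, under the same naming assumptions as the encoder above:

```python
import torch
import torch.nn as nn

class BernoulliDecoder(nn.Module):
    """J -> 500 (tanh) -> 784 (sigmoid), producing per-pixel Bernoulli means."""
    def __init__(self, latent_dim):
        super().__init__()
        self.hidden = nn.Linear(latent_dim, 500)
        self.out = nn.Linear(500, 784)

    def forward(self, z):
        h = torch.tanh(self.hidden(z))
        return torch.sigmoid(self.out(h))
```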
3.5 Optimization
The optimization objective is to minimize the negative of the ELBO. The reconstruction error is computed using the binary cross-entropy, and the KL divergence is calculated analytically, since we assume Gaussians for both the variational posterior and the latent prior. Adagrad with a fixed step size of 0.01 is used throughout all experiments.
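Putting the pieces together, a single training step might look as follows. This is an illustrative sketch that reuses the `GaussianEncoder`, `BernoulliDecoder`, `reparametrize`, and `beta_vae_loss` sketches from earlier sections, not our actual implementation:

```python
import torch

encoder, decoder = GaussianEncoder(10), BernoulliDecoder(10)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adagrad(params, lr=0.01)  # fixed step size of 0.01

def training_step(x, beta=1.0):
    """Minimize the negative ELBO on one mini-batch of images."""
    optimizer.zero_grad()
    mu, log_var = encoder(x)
    z = reparametrize(mu, log_var)
    x_hat = decoder(z)
    loss = beta_vae_loss(x.view(-1, 784), x_hat, mu, log_var, beta)
    loss.backward()
    optimizer.step()
    return loss.item()
```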
3.6 Disentanglement Evaluation
To visually examine the disentanglement quality of the learned representations, $\mathbf{z}$ is traversed in a fashion similar to Higgins et al. (2016). Every dimension $z_i$ is first zeroed and then traversed over five standard deviations around a unit Gaussian, while keeping all the others $z_{j \neq i}$ fixed to the values obtained by running inference on an image from the dataset. The updated $\mathbf{z}$ is then passed to the decoder. The output images are plotted on a grid for inspection, where the columns display the traversals of every latent dimension $z_i$.
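A sketch of this traversal procedure, assuming the encoder and decoder sketches above; the number of steps and the exact sweep range are illustrative choices:

```python
import torch

@torch.no_grad()
def traverse_latents(encoder, decoder, x, steps=7, radius=2.5):
    """Decode traversals of each latent dimension of a single image.

    Returns a [J, steps, 784] tensor: row i sweeps z_i over roughly five
    standard deviations of a unit Gaussian (-radius to +radius) while the
    remaining dimensions keep the values inferred from the input image.
    """
    mu, _ = encoder(x.view(1, -1))           # inferred latent for one image
    sweep = torch.linspace(-radius, radius, steps)
    rows = []
    for i in range(mu.shape[1]):
        z = mu.repeat(steps, 1)
        z[:, i] = sweep                      # zero-centred traversal of z_i
        rows.append(decoder(z))
    return torch.stack(rows)
```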
4. Results
4.1 Autoencoding Variational Bayes
In order to gain a better understanding of how VAEs work, we started off by reproducing a portion of the experiments on the likelihood lower bound presented in Kingma & Welling (2014). We gradually increased the size of the latent variable $\mathbf{z}$, experimenting with dimensionalities $J \in \{3, 5, 10, 20, 200\}$, and observed their impact on the ELBO. Beyond 10 latent dimensions, only minimal improvements to the magnitude of the lower bound were observed.
Moreover, little evidence of overfitting was found between the training and testing sets within the allocated training time, despite increasing the dimensionality of $\mathbf{z}$. This is in accordance with Kingma & Welling (2014), who attributed this phenomenon to the regularizing effect of the variational bound. However, larger latent variables did increase the portion that the KL divergence contributes to the ELBO, ranging from about 6% when $J = 3$ to over 25% when $J = 200$. This can be explained as a larger penalty for simultaneously fitting more dimensions of $\mathbf{z}$ to a multivariate Gaussian prior.
4.2 $\beta$-VAE
Next, we explored the sensitivity of the hyperparameter $\beta$ by running experiments with $\beta \in \{1, 3, 5, 10, 20\}$, fixing the number of latent dimensions to $J = 10$ in all $\beta$-VAE trials. The ELBO decreased by almost 70% when $\beta = 10$, which suggests that higher values of $\beta$ lead to lower ELBO estimates. Similarly, the total reconstruction error deteriorates as the value of $\beta$ increases. However, this relationship is reversed when looking at the contribution of the KL divergence to the ELBO.
In contrast to the original VAE, the image reconstructions are overall blurrier but preserve the digit features. Nevertheless, when $\beta = 20$ the reconstructions lose visual meaning: digits stop looking like digits. When $\beta = 10$, for example, some disentanglement patterns begin to emerge. Especially for digits $4$, $6$, and $8$, we could observe that latent dimensions $z_{4-5}$ and $z_{7-9}$ concentrate the majority of the changes in the latent traversal, neutralizing the impact of the others. However, the visual correspondence between single generative factors and single latent dimensions remains somewhat unclear, especially as the traversals rather expose inter-digit transformations: for example, digit $6$ transitions to a $3$ along $z_9$, instead of modifying any lower-level generative factor.
4.3 Conditional $\beta$-VAE
We investigated the impact of using the digit labels on the quality of disentanglement of a $\beta$-VAE. As per Hyvärinen et al. (2019), we concurrently observe an additional variable $\mathbf{u}$, which corresponds to the ten class labels of the MNIST dataset. In practice, we embed these as 10-dimensional one-hot vectors and append them both to the flattened 784-dimensional representation of a digit and to the samples of the latent variable $\mathbf{z}$. We adjust the VAE architecture to account for these changes: the input layers of the encoder and the decoder grow by ten units (to $794$ and $J + 10$ units, respectively).
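A sketch of this conditioning step, assuming the `reparametrize` helper from Section 2.1 and encoder/decoder modules whose input layers have been widened to 794 and $J + 10$ units; the function name is ours:

```python
import torch
import torch.nn.functional as F

def conditional_forward(encoder, decoder, x, y, num_classes=10):
    """One pass of the label-conditioned beta-VAE."""
    u = F.one_hot(y, num_classes).float()          # [batch, 10] one-hot labels
    x_u = torch.cat([x.view(-1, 784), u], dim=1)   # 794-unit encoder input
    mu, log_var = encoder(x_u)
    z = reparametrize(mu, log_var)
    x_hat = decoder(torch.cat([z, u], dim=1))      # (J + 10)-unit decoder input
    return x_hat, mu, log_var
```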
For these experiments, the dimensionality of $\mathbf{z}$ was fixed to $10$, and we tested $\beta \in \{1, 3, 5, 10\}$. Similar to the unconditioned version, higher $\beta$ values led to lower ELBO estimates. What the trend shows, however, is that for $\beta \in \{5, 10\}$ the ELBO is higher for the conditioned version. Moreover, for $\beta \in \{1, 3, 5\}$ the participation of the KL divergence in the ELBO monotonically increased from about 15% to over 20%, but it interestingly saw a sharp decline when $\beta = 10$.
After training the $\beta$-VAE on the labeled dataset, disentanglement was far clearer than in the previously unconditioned approach when $\beta = 10$. The most salient finding was that, after running inference on images of digits $4$, $6$, and $8$, traversing seven of the ten latent dimensions $z_i$ left the output virtually unchanged. By contrast, dimensions $z_6$, $z_9$, and $z_{10}$ concentrated all the information related to the digit's generative factors, which we associate with the latent directions that control the line-weight, lateral tilt, and scale of a digit.
4.4 VAEs Comparison
Finally, we take a closer look at the behavior of $\beta$-VAE on both the conditioned and unconditioned datasets, using control values of $\beta = 1$ and $\beta = 10$. In terms of the ELBO, the conditioned dataset with $\beta = 1$ exhibits the highest value and the unconditioned dataset with $\beta = 10$ the lowest. This concurs with our previous experimental results showing that higher values of $\beta$ lead to lower ELBOs.
The reconstruction error shows a parallel trend that aligns with the choice of $\beta$: whether conditioned or not, the negative binary cross-entropy differs on average by $4$ units when $\beta = 1$, while this gap increases to $8$ units when $\beta = 10$. One of the most interesting findings concerns the ratio of the KL divergence to the ELBO, where conditioning the dataset does make a difference: when $\beta = 10$, this proportion is almost halved, from circa 20% down to slightly over 10%.
Since the $\mathbf{u}$-conditioned dataset with $\beta = 10$ has produced the best disentanglement results so far, we leave to future work the hypothesis that minimizing the participation of the KL divergence in the ELBO, while preserving a reasonable reconstruction error, leads to good disentanglement.
5. Closing Comments
In this project, we investigated how to disentangle the latent variable $\mathbf{z}$ that VAEs learn by analyzing the effect of the hyperparameter $\beta$ (Higgins et al., 2016) and of conditioning the data on class labels $\mathbf{u}$. We found that using $\beta = 10$ and a conditioned dataset led to the clearest disentanglement across our experiments, revealing single latent directions that independently control the line-weight, tilt, and scale of the MNIST digits. Our finding supports the idea that good disentanglement is contingent upon some level of supervision, which we provided by utilizing the class labels of the MNIST dataset. However, our experiments also suggest that using a moderate value of $\beta$ to scale the KL divergence term in the ELBO was instrumental in obtaining a disentangled latent representation.
In fact, disentangled representation learning is an active area of research with plenty of interesting challenges ahead. Several routes to extend our work are thus outlined. First, adopting a consistent and robust disentanglement metric is necessary, since resorting to visual inspection is cumbersome; utilizing a setup similar to the prediction-based measurements proposed by Higgins et al. (2016) would be a natural next step. An alternative approach we would like to explore later is to directly perform Independent Component Analysis (ICA) on a VAE's learned $\mathbf{z}$ and to evaluate whether this post-processing step facilitates finding disentangled directions without having to use the hyperparameter $\beta$ or any class labels. Working with better-suited VAE architectures (e.g. using convolutional layers for image data instead of linear MLPs) may impact the quality of the learned latent representations as well. Lastly, we foresee working with richer, non-Gaussian latent priors, which may ultimately better capture the hidden and disentangled generative structure of a dataset.
6. References
Bengio, Y., Courville, A., & Vincent, P. (2013). Representation Learning: A Review and New Perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8), 1798–1828. https://doi.org/10.1109/TPAMI.2013.50
Gómez-Bombarelli, R., Wei, J. N., Duvenaud, D., Hernández-Lobato, J. M., Sánchez-Lengeling, B., Sheberla, D., Aguilera-Iparraguirre, J., Hirzel, T. D., Adams, R. P., & Aspuru-Guzik, A. (2018). Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules. ACS Central Science, 4(2), 268–276. https://doi.org/10.1021/acscentsci.7b00572
Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., & Lerchner, A. (2016). beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. https://openreview.net/forum?id=Sy2fzU9gl
Hoffman, M. D., & Johnson, M. J. (2016, December). Elbo surgery: yet another way to carve up the variational evidence lower bound. In Workshop in Advances in Approximate Bayesian Inference, NIPS (Vol. 1, p. 2). http://approximateinference.org/2016/accepted/HoffmanJohnson2016.pdf
Khemakhem, I., Kingma, D., Monti, R., & Hyvarinen, A. (2020). Variational Autoencoders and Nonlinear ICA: A Unifying Framework. In S. Chiappa & R. Calandra (Eds.), Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics (Vol. 108, pp. 2207–2217). PMLR. http://proceedings.mlr.press/v108/khemakhem20a.html
Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. ArXiv:1312.6114 [Cs, Stat]. http://arxiv.org/abs/1312.6114
LeCun, Y., Cortes, C., & Burges, C. (2010). MNIST handwritten digit database. AT&T Labs. http://yann.lecun.com/exdb/mnist/
Locatello, F., Bauer, S., Lucic, M., Raetsch, G., Gelly, S., Schölkopf, B., & Bachem, O. (2019). Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations. In K. Chaudhuri & R. Salakhutdinov (Eds.), Proceedings of the 36th International Conference on Machine Learning (Vol. 97, pp. 4114–4124). PMLR. http://proceedings.mlr.press/v97/locatello19a.html
Roberts, A., Engel, J., Raffel, C., Hawthorne, C., & Eck, D. (2019). A Hierarchical Latent Vector Model for Learning Long-Term Structure in Music. ArXiv:1803.05428 [Cs, Stat]. http://arxiv.org/abs/1803.05428
Vahdat, A., & Kautz, J. (2020). NVAE: A Deep Hierarchical Variational Autoencoder. ArXiv:2007.03898 [Cs, Stat]. http://arxiv.org/abs/2007.03898