
Beta-VAE Reproducibility: Challenges and Extensions

β-VAE is a follow-up technique to variational autoencoders that proposes a special weighting of the KL divergence term in the VAE loss to obtain disentangled representations. Unsupervised learning is known to be brittle even on toy datasets, and a meaningful, mathematically precise definition of disentanglement remains difficult to find. Here we investigate the original β-VAE paper and add to the previously obtained evidence of its lack of reproducibility. We also extend the experiments and include more complex datasets in the analysis. We also implement an FID scoring metric for the β-VAE model and conduct a qualitative analysis of the results obtained. We end with a brief discussion of possible future investigations that could add more robustness to the claims.


1 Introduction

Variational autoencoders (DBLP:journals/corr/KingmaW13) are a class of unsupervised representation learning models with a principled probabilistic interpretation that extend the standard autoencoders first described by hinton2006reducing. β-VAE is a follow-up technique that proposes a special weighting of the KL divergence term in the VAE loss to obtain disentangled representations. However, unsupervised learning is notoriously brittle even on toy datasets, and a meaningful, mathematically precise definition of disentanglement remains difficult to find.

It is thus not obvious to what extent β-VAEs can robustly obtain disentangled representations in different settings. The main contributions of our reproducibility report can be summarised as follows:

  1. We add to the evidence provided by follow-up work (pmlr-v80-kim18b; locatello2020sober) that the almost perfect performance presented by higgins2016beta is very difficult to reproduce.

  2. We demonstrate that a larger β does not continue to yield the best quantitative disentanglement results for very complex datasets.

  3. We show that high disentanglement metric scores do not imply a qualitative disentanglement.

  4. We quantitatively assess how lower β values give better reconstructions of the original images.

1.1 VAE framework

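VAEs are a special class of deep generative models optimised via variational inference, which approximates intractable distributions in Bayesian inference by solving an optimisation problem. Assume we have a directed latent variable model

$$p_\theta(x, z) = p_\theta(x \mid z)\, p(z) \qquad (1)$$

and we observed a dataset $\mathcal{D} = \{x^{(i)}\}_{i=1}^{N}$. Many standard techniques such as Expectation-Maximization do not scale to the large-scale deep learning setting because they require computing the posterior $p_\theta(z \mid x) = p_\theta(x, z) / p_\theta(x)$, for which the normalization constant $p_\theta(x) = \int p_\theta(x, z)\, dz$ is not available. We can avoid the need for precise normalization constants by introducing a variational approximation $q_\phi(z \mid x)$ and instead optimising the evidence lower bound (ELBO) as in Equation (2), where $p_\theta(x)$ is the marginal likelihood or model evidence:

$$\mathrm{ELBO}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right) \qquad (2)$$
$$= \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z) + \log p(z) - \log q_\phi(z \mid x)\right] \qquad (3)$$
$$= \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z \mid x)\right] \qquad (4)$$
$$= \log p_\theta(x) - D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\right) \le \log p_\theta(x) \qquad (5)$$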

However, obtaining gradients of the ELBO with respect to the variational parameters $\phi$ is difficult, because the expectation is taken over $q_\phi$ itself, so we cannot simply exchange derivatives and integrals. Instead, VAEs crucially rely on the reparameterization trick, writing $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, to compute Monte Carlo estimates of the gradient, typically from a single sample per data point in each minibatch. The reparameterization trick is crucial, since it produces an estimator with much lower variance than more general-purpose Monte Carlo estimators such as the score function estimator. However, reparameterization requires working with continuous distributions, which makes VAEs difficult, albeit not impossible, to apply in the discrete setting (DBLP:conf/iclr/MaddisonMT17; NIPS2017_7a98af17). Standard VAEs use an isotropic Gaussian prior $p(z) = \mathcal{N}(0, I)$ together with a diagonal Gaussian approximate posterior, which allows the KL divergence to be computed analytically.
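As a minimal sketch of these two ingredients, assuming a diagonal Gaussian encoder that outputs a mean `mu` and a log-variance `log_var` (our own naming, not taken from any particular codebase), the reparameterised sample and the closed-form KL divergence to the standard normal prior can be written in NumPy as follows. The Monte Carlo estimate at the end is only a sanity check that the analytic formula agrees with a sampled estimate.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, diag(exp(log_var))) as a deterministic function of (mu, log_var) and noise."""
    eps = rng.standard_normal(mu.shape)        # noise that does not depend on the parameters
    return mu + np.exp(0.5 * log_var) * eps     # z = mu + sigma * eps

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over the last axis."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

def log_normal(z, mu, log_var):
    """Log density of a diagonal Gaussian, summed over the last axis."""
    return -0.5 * np.sum(np.log(2 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var), axis=-1)

rng = np.random.default_rng(0)
mu = np.array([1.0, -0.5])
log_var = np.array([0.0, np.log(0.25)])
n = 200_000
z = reparameterize(np.tile(mu, (n, 1)), np.tile(log_var, (n, 1)), rng)

# Monte Carlo estimate of E_q[log q(z) - log p(z)] versus the analytic value.
mc_kl = np.mean(log_normal(z, mu, log_var) - log_normal(z, 0.0, 0.0))
analytic_kl = kl_to_standard_normal(mu, log_var)
```

In a deep learning framework the same expressions would be written with differentiable tensors, so that gradients flow through `mu` and `log_var` while `eps` is treated as a constant.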

1.2 β-VAE improvements

higgins2016beta propose to multiply the KL divergence term in Equation (2) by a factor β > 1, i.e. to optimise $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \beta\, D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p(z))$. This should in turn enforce a greater similarity between the approximate posterior $q_\phi(z \mid x)$ and the prior $p(z)$, which leads to greater disentanglement. bengio2013representation define a disentangled representation as one in which each latent variable is sensitive to changes in a single ground truth generative factor while being relatively invariant to changes in the others. Because the VAE prior is typically chosen to be a Gaussian with a diagonal covariance matrix, its dimensions are independent, which can be seen as disentangled.
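A minimal NumPy sketch of the resulting β-weighted loss is given below, assuming the Bernoulli decoder listed in Table 4 (parameterised by logits) and a diagonal Gaussian encoder; the function names are our own. Setting `beta=1.0` recovers the negative of Equation (2), and the loss is averaged over the batch.

```python
import numpy as np

def bernoulli_nll_from_logits(x, logits):
    """-log p(x|z) for a Bernoulli decoder, summed over pixels.

    Uses the stable identity -[x log s(l) + (1 - x) log(1 - s(l))] = log(1 + e^l) - x * l.
    """
    return np.sum(np.logaddexp(0.0, logits) - x * logits, axis=-1)

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

def beta_vae_loss(x, logits, mu, log_var, beta):
    """Negative beta-weighted ELBO, averaged over the batch (beta = 1 gives the standard VAE)."""
    recon = bernoulli_nll_from_logits(x, logits)
    kl = kl_to_standard_normal(mu, log_var)
    return float(np.mean(recon + beta * kl))

# Toy batch: 2 "images" with 4 pixels each and 3 latent dimensions.
x = np.zeros((2, 4))
logits = np.zeros((2, 4))      # decoder predicts p = 0.5 for every pixel
mu = np.ones((2, 3))
log_var = np.zeros((2, 3))
loss_vae = beta_vae_loss(x, logits, mu, log_var, beta=1.0)
loss_beta4 = beta_vae_loss(x, logits, mu, log_var, beta=4.0)
```

Because only the KL term is scaled, the difference between the two losses above is exactly three times the average KL divergence of the batch.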

While the β-weighted KL loss can be seen merely as a heuristic addition to the standard VAE, follow-up work exploiting the information-theoretic nature of the KL divergence has led to further improvements of the algorithm (NEURIPS2018_1ee3dfcd; DBLP:journals/corr/abs-1804-03599). Regardless, large-scale reproduction studies show that the influence of random seeds and hyperparameter settings is at least comparable to that of model choice (locatello2020sober).

2 Methodology

2.1 Datasets

higgins2016beta use a number of standard image datasets to evaluate the image generation properties of β-VAEs, as well as specially designed datasets to evaluate the disentanglement of ground truth generative factors.

2.1.1 Disentanglement evaluation - 2Dshapes, 3Dshapes, MPI3DToy

To assess disentanglement quantitatively, we use three synthetic datasets that come with ground truth generative factors; samples from each are shown in Figure 1. 2Dshapes was originally created by higgins2016beta and consists of 737,280 images of 2D shapes generated from the Cartesian product of five independent ground truth latent factors.

We hypothesise that the 2Dshapes dataset is too easy to solve, as evidenced by the fairly high scores of PCA and ICA, and therefore further check the robustness of β-VAEs on harder, more recent RGB datasets: 3Dshapes (3Dshapes18) and MPI3DToy (gondal2019transfer). 3Dshapes consists of 480,000 images generated from six data generative factors. MPI3DToy is the most complex of the three datasets, containing 1,036,800 images generated from seven data generative factors of a real-world robotics setup. Due to computational constraints, we only use the toy version of MPI3D, which contains simplified renders of the scenes rather than real camera recordings.
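The dataset sizes above follow directly from the Cartesian product of the factor values. The sketch below assumes the per-factor cardinalities given in the public descriptions of dSprites, 3Dshapes and MPI3D (these numbers are our addition, not stated in this report), and a row-major storage order for mapping a factor combination to an image index; both should be checked against the actual dataset files.

```python
from math import prod

# Number of values per ground truth factor (assumed from the public dataset descriptions).
# dSprites also lists a colour factor with a single value, which is omitted here.
FACTOR_SIZES = {
    "2Dshapes (dSprites)": {"shape": 3, "scale": 6, "orientation": 40, "pos_x": 32, "pos_y": 32},
    "3Dshapes": {"floor_hue": 10, "wall_hue": 10, "object_hue": 10,
                 "scale": 8, "shape": 4, "orientation": 15},
    "MPI3DToy": {"object_color": 6, "object_shape": 6, "object_size": 2, "camera_height": 3,
                 "background_color": 3, "horizontal_axis": 40, "vertical_axis": 40},
}

def dataset_size(factors):
    """Size of the full Cartesian product of factor values."""
    return prod(factors.values())

def flat_index(factor_values, factor_sizes):
    """Row-major index of one factor combination (assumes images are stored in this order)."""
    index = 0
    for value, size in zip(factor_values, factor_sizes):
        if not 0 <= value < size:
            raise ValueError(f"factor value {value} out of range for size {size}")
        index = index * size + value
    return index

for name, factors in FACTOR_SIZES.items():
    print(name, dataset_size(factors))
```

Such an index lets a disentanglement metric fix one factor and resample the others by looking up images directly instead of re-rendering them.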

2.1.2 Image generation - CelebA, Chairs, CIFAR10, and CIFAR100

CelebA and Chairs are the two datasets utilised by higgins2016beta for qualitative evaluation. Due to hardware constraints, we instead qualitatively inspect models trained on CIFAR10 and CIFAR100, as they are standard benchmark datasets for evaluating generative models (shmelkov2018good). These datasets are non-synthetic, and samples are shown in Figure 2.

Figure 1: Samples from the datasets 2Dshapes, 3Dshapes and MPI3DToy
Figure 2: Samples from the CIFAR10 and CIFAR100 datasets

2.2 Metrics

To evaluate the disentanglement of a representation, we adopt the approach presented in higgins2016beta. This can be used to evaluate disentanglement directly without having to rely on qualitative inspection.

higgins2016beta suggest that there exists a trade-off between generated image quality and the level of disentanglement. This is at odds with the notion that disentangled representations should lead to superior performance in downstream tasks. To quantitatively investigate the extent to which a higher β harms generative model quality while enhancing disentanglement, we adopt the Fréchet Inception Distance, a widely used metric for evaluating the quality of images produced by generative models (heusel2017gans), to measure reconstruction quality. higgins2016beta evaluate image quality only qualitatively, by inspection.

2.2.1 Fréchet Inception Distance

Fréchet Inception Distance (FID) is a metric used to assess the quality of images produced by generative models, usually generative adversarial networks (GANs). Rather than comparing generated and real images on a pixel-by-pixel basis, FID compares the distributions of the activations of the final pooling layer of a pretrained InceptionV3 model (szegedy2016rethinking). This layer corresponds to high-level features of objects (such as airplanes) and is therefore argued to capture the human notion of similarity between images (heusel2017gans). FID is the Fréchet (Wasserstein-2) distance between two Gaussians and can be computed as:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right) \qquad (6)$$

where $\mathcal{N}(\mu_r, \Sigma_r)$ and $\mathcal{N}(\mu_g, \Sigma_g)$ are Gaussians fit to the 2048-dimensional activations of the last InceptionV3 pooling layer for real and generated samples, respectively (FrechetI11). Because the Inception network is pretrained, images have to be resized to its expected input resolution (299 × 299 pixels), and greyscale images have to be converted to three colour channels.
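Equation (6) can be sketched in NumPy as follows, assuming the Inception activations have already been extracted (feature extraction is not shown; our experiments relied on the implementation by Seitzer2020FID). Instead of a general matrix square root, this sketch uses the identity $\mathrm{Tr}((\Sigma_r \Sigma_g)^{1/2}) = \mathrm{Tr}((\Sigma_r^{1/2} \Sigma_g \Sigma_r^{1/2})^{1/2})$, which holds for positive semi-definite covariances and only needs symmetric eigendecompositions.

```python
import numpy as np

def _psd_sqrt(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)             # remove tiny negative values from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}) for two Gaussians."""
    diff = mu1 - mu2
    s1_half = _psd_sqrt(sigma1)
    middle = s1_half @ sigma2 @ s1_half
    eig = np.clip(np.linalg.eigvalsh((middle + middle.T) / 2.0), 0.0, None)
    tr_covmean = np.sum(np.sqrt(eig))            # Tr((S1 S2)^{1/2})
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def fid_from_features(feats_real, feats_gen):
    """FID between two sets of feature vectors of shape (num_samples, feature_dim)."""
    mu1, mu2 = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_gen, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)
```

For example, two one-dimensional Gaussians N(0, 1) and N(3, 4) give 9 + 1 + 4 - 2·2 = 10. Real FID computations use 2048-dimensional features and many thousands of samples, since the covariance estimate is noisy for small sample sizes.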

We use FID to examine the trade-off between disentanglement and reconstruction quality, since the best-performing image generation models (e.g. GANs) are evaluated with FID (dai2019diagnosing; lucic2017gans). A lower FID score means that real and generated samples are more similar. WGAN-GP (NIPS2017_892c3b1c), a GAN of similar age to the β-VAE, is reported to achieve an FID of 29.3 on CIFAR10 (heusel2017gans). Early VAEs are generally understood to have poor image generation capabilities compared to contemporary GANs (shmelkov2018good; makhzani2015adversarial). However, recent VAE-based models (parmar2020dual) achieve a CIFAR10 FID of 17.9 and are competitive with GAN-based approaches.

shmelkov2018good point out limitations of FID as a metric: it cannot separate image quality from image diversity. For example, a poor FID score can be caused either by the reconstructed images being unrealistic (low image quality) or by them being too similar to each other (low diversity), and the score alone does not reveal which is the case.

2.2.2 Disentanglement Metric

The disentanglement metric refers to the framework proposed by higgins2016beta to quantify the level of disentanglement in deep generative models by measuring the independence and interpretability of their latent representation. The idea is that, for a disentangled representation, images generated by fixing one factor of variation and randomly sampling all others should result in a relatively low variance in the latents corresponding to the fixed factor. The lower this variance is, the easier it is to predict which data generative factor was fixed. Therefore, we can measure disentanglement by reporting the accuracy of a classifier that identifies the fixed data generative factor from the latent representation.

An assumption on the dataset is that its elements are generated by a true world simulator

$$p(x \mid v, w) = \mathrm{Sim}(v, w) \qquad (7)$$

where $v \in \mathbb{R}^K$, with $\log p(v \mid x) = \sum_{k=1}^{K} \log p(v_k \mid x)$, are conditionally independent factors and $w \in \mathbb{R}^H$ are conditionally dependent factors. In particular, these ground truth factors need to be known in order to compute the metric. In all of the datasets that we consider in this report for computing the metric, the data is generated by independent factors alone. The full procedure is outlined in Algorithm 1.

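Result: Classification accuracy
for $b = 1, \dots, B$ do
       Uniformly sample one of the $K$ data generative factors, $y_b \in \{1, \dots, K\}$;
       for $l = 1, \dots, L$ do
             Sample a pair $(v_{1,l}, v_{2,l})$ such that they agree on their $y_b$-th value, $[v_{1,l}]_{y_b} = [v_{2,l}]_{y_b}$;
             Simulate images $x_{1,l} \sim \mathrm{Sim}(v_{1,l})$, $x_{2,l} \sim \mathrm{Sim}(v_{2,l})$;
             Infer the expectation values of the latent Gaussians $z_{1,l} = \mu(x_{1,l})$, $z_{2,l} = \mu(x_{2,l})$;
             Compute elementwise differences and absolute values $z_{\mathrm{diff}}^{l} = |z_{1,l} - z_{2,l}|$;
       end for
       Average over the pairs: $z_{\mathrm{diff}}^{b} = \frac{1}{L} \sum_{l=1}^{L} z_{\mathrm{diff}}^{l}$;
end for
Train a classifier using the inputs $z_{\mathrm{diff}}^{b}$ to predict $y_b$;
Report classification accuracy on a test set.
Algorithm 1 The Disentanglement Metric (higgins2016beta)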

Note that rather than using the expectation values of the latent Gaussian for one simulated image, a pair of images is sampled and the absolute difference of their latent representations is computed, to reduce the variance of the classifier's inputs and to lower their conditional dependence on the input images $x$. The classifier is taken to be linear to ensure the interpretability of the inferred latents, so that it cannot learn any nonlinear disentanglement itself.
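Algorithm 1 can be sketched in NumPy as below. This is a minimal illustration under several assumptions of ours: the simulator and the encoder mean are passed in as plain functions (`simulate` and `encode_mean` are hypothetical names), factors are indexed from 0, B is the number of classifier training points and L the number of pairs averaged per point, and the linear classifier is a hand-written softmax regression rather than the PyTorch or scikit-learn classifiers used in our experiments.

```python
import numpy as np

def sample_pairs(rng, factor_sizes, k, L):
    """Sample L pairs of factor vectors that share the value of factor k."""
    v1 = np.stack([rng.integers(0, n, size=L) for n in factor_sizes], axis=1)
    v2 = np.stack([rng.integers(0, n, size=L) for n in factor_sizes], axis=1)
    v2[:, k] = v1[:, k]
    return v1, v2

def metric_dataset(rng, simulate, encode_mean, factor_sizes, B, L):
    """Build B classifier inputs z_diff^b (each averaged over L pairs) with their labels y_b."""
    feats, labels = [], []
    for _ in range(B):
        k = int(rng.integers(len(factor_sizes)))         # y_b sampled uniformly
        v1, v2 = sample_pairs(rng, factor_sizes, k, L)
        z1, z2 = encode_mean(simulate(v1)), encode_mean(simulate(v2))
        feats.append(np.abs(z1 - z2).mean(axis=0))       # average of |z1 - z2| over the L pairs
        labels.append(k)
    return np.array(feats), np.array(labels)

def fit_softmax_regression(X, y, n_classes, lr=0.5, steps=500):
    """Multinomial logistic regression (a linear classifier) trained by gradient descent."""
    W, b = np.zeros((X.shape[1], n_classes)), np.zeros(n_classes)
    Y = np.eye(n_classes)[y]
    for _ in range(steps):
        logits = X @ W + b
        P = np.exp(logits - logits.max(axis=1, keepdims=True))
        P /= P.sum(axis=1, keepdims=True)
        G = (P - Y) / len(X)                             # gradient of mean cross-entropy w.r.t. logits
        W -= lr * X.T @ G
        b -= lr * G.sum(axis=0)
    return W, b

def disentanglement_metric(rng, simulate, encode_mean, factor_sizes, B=1000, L=64):
    K = len(factor_sizes)
    X_train, y_train = metric_dataset(rng, simulate, encode_mean, factor_sizes, B, L)
    X_test, y_test = metric_dataset(rng, simulate, encode_mean, factor_sizes, B // 2, L)
    W, b = fit_softmax_regression(X_train, y_train, K)
    return float(np.mean(np.argmax(X_test @ W + b, axis=1) == y_test))

# Toy check with a hypothetical "perfect" model in which each latent is one normalised factor.
def simulate(v):
    return v / 9.0        # stand-in for Sim(v): the "image" is just the scaled factor vector

def encode_mean(x):
    return x              # stand-in for the encoder mean mu(x)

rng = np.random.default_rng(0)
score = disentanglement_metric(rng, simulate, encode_mean, factor_sizes=(10, 10, 10))
```

For this toy model the fixed factor always produces an exactly zero coordinate in $z_{\mathrm{diff}}$, so the linear classifier reaches near-perfect accuracy; a trained VAE encoder would take the place of `encode_mean`.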

One drawback of this metric is its dependence on hyperparameters, such as the choice of the classifier, the optimiser and, most significantly, the sample size $L$, as already noted by pmlr-v80-kim18b. We investigate this further by reporting the accuracy scores of a linear and a nonlinear MLP, as well as of a logistic regression classifier and a random forest classifier from the scikit-learn (SKlearn) library. Furthermore, we compare the scores for different values of the sample size $L$ used to generate training data for the metric.

In addition, pmlr-v80-kim18b show that there is a failure mode in which the classifier reports 100% accuracy even though only $K-1$ of the $K$ factors are disentangled. Issues like this could account for the discrepancy between quantitative scores and qualitative observations described in Section 3.1.3.
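The failure mode noted by pmlr-v80-kim18b can be reproduced in a toy setting. In the sketch below (our own construction, not taken from that paper), a hypothetical encoder represents only the first K - 1 of K factors and ignores the last one entirely. When the last factor is the fixed one, no coordinate of the averaged difference is close to zero, and this is itself a distinctive, linearly separable pattern, so a hand-set linear classifier still labels every batch correctly.

```python
import numpy as np

rng = np.random.default_rng(0)
K, n_values, L, B = 3, 10, 64, 600     # three factors; the encoder only represents the first two

def z_diff_batch(k):
    """Averaged |z1 - z2| over L pairs agreeing on factor k, for an encoder that drops factor K-1."""
    v1 = rng.integers(0, n_values, size=(L, K))
    v2 = rng.integers(0, n_values, size=(L, K))
    v2[:, k] = v1[:, k]
    z1 = v1[:, : K - 1] / (n_values - 1)   # latents encode only factors 0 .. K-2
    z2 = v2[:, : K - 1] / (n_values - 1)
    return np.abs(z1 - z2).mean(axis=0)

labels = rng.integers(0, K, size=B)
X = np.stack([z_diff_batch(k) for k in labels])

# Hand-set linear classifier: class k < K-1 scores -x_k, the last class scores a constant threshold.
W = np.zeros((K - 1, K))
W[np.arange(K - 1), np.arange(K - 1)] = -1.0
bias = np.zeros(K)
bias[K - 1] = -0.18                    # halfway between 0 and the typical value of about 0.37
pred = np.argmax(X @ W + bias, axis=1)
accuracy = float(np.mean(pred == labels))
```

The classifier reaches perfect accuracy although the last factor is not represented at all, which is why a high score under Algorithm 1 does not guarantee that every factor has been disentangled.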

2.3 Models

higgins2016beta originally use a VAE where both the encoder and decoder are MLPs for the 2Dshapes dataset and a Convolutional VAE for their remaining experiments. Follow-up work (DBLP:journals/corr/abs-1804-03599; locatello2020sober; pmlr-v80-kim18b) instead uses a Convolutional VAE across all experiments, including 2Dshapes.

We report results for both the MLP and Convolutional VAEs on 2Dshapes and only use the Convolutional VAE for other experiments as we found it to give better results. As optimisers, we use Adagrad (lr = 1e-2) for the MLP and Adam (lr = 5e-4) for the Convolutional VAE as done in the original papers. We also use PCA and ICA as non-deep baselines, following higgins2016beta. All models are trained with multiple seeds. Full details on the model, hyperparameters, and training protocols are available in Appendix A.

3 Results

In this section, we present results on the datasets described in Section 2.1, using the metrics from Section 2.2 to evaluate β-VAEs in terms of their ability to learn a disentangled representation and to reconstruct images.

3.1 Disentanglement

First, we present the scores of the disentanglement metric on the dataset used by higgins2016beta and compare them to the baseline methods PCA and ICA. Next, we investigate the behaviour of β-VAE on the more complex three-dimensional datasets 3Dshapes and MPI3DToy. Finally, we draw a connection between quantitative and qualitative evaluation of disentanglement to see whether these two notions coincide.

3.1.1 Disentanglement on 2Dshapes

Table 1 presents the disentanglement scores achieved on the 2Dshapes dataset. It is important to note that, among the five ground truth factors, higgins2016beta disregard 'shape' when sampling data generative factors for the disentanglement metric. This is presumably because β-VAE struggles to learn a disentangled representation for this data generative factor, as evidenced by the traversals in Figure 7 of higgins2016beta (reprinted in Figure 4), which all have a very similar shape.

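Model | 5 factors: Mean ± std | 5 factors: Median | 4 factors: Mean ± std | 4 factors: Median | higgins2016beta (4 factors)
β-VAE (β = …) | 65.13 ± 4.98% | 64.79% | 81.89 ± 2.33% | 81.82% | (–)
β-VAE (β = …) | 64.31 ± 6.56% | 63.79% | 82.39 ± 5.79% | 81.37% | 61.58 ± 0.5%
β-VAE (β = …) | 74.53 ± 15.06% | 76.62% | 89.09 ± 15.82% | 95.15% | 99.23 ± 0.1%
β-VAE (β = …) | 76.77 ± 11.44% | 79.67% | 86.86 ± 15.5% | 88.81% | (–)
PCA | 69.79 ± 4.22% | 71.53% | 88.95 ± 3.89% | 89.91% | 84.9 ± 0.4%
ICA | 68.99 ± 1.68% | 69.38% | 83.94 ± 2.16% | 84.12% | 42.03 ± 10.6%
Table 1: Disentanglement metric scores on 2Dshapes for β-VAE with different values of β, PCA and ICA, using all five or only four data generative factors. The last column reprints the results of higgins2016beta.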

In Table 1 we compare the metric scores obtained using all five data generative factors to those obtained when disregarding the 'shape' factor, and we reprint the results presented in higgins2016beta for comparison. The scores using all five data generative factors are lower across all models, supporting the conjecture above. While our results show the superiority of β-VAE over the regular VAE for some values of β, we do not obtain a difference as large as that in higgins2016beta. A potential reason is that higgins2016beta discard the worst-performing 50% of their training runs due to training instabilities. We also observe this instability for some values of β, as suggested by the high standard deviations. In the rest of this report, we report the median by default as a measure of performance that is robust to some training runs diverging. Furthermore, we note that ICA performs much more strongly in our results; we found that fine-tuning the ICA parameters was crucial for improving its scores.

3.1.2 Beyond 2D datasets

Next, we investigate the disentanglement scores on the more complex datasets 3Dshapes and MPI3DToy. Because images in these datasets have three colour channels, a convolutional architecture for the encoder and decoder is more suitable. Therefore, we trained β-VAE using the architecture proposed in burgess2018understanding, which is also used by follow-up work on disentangling in VAEs, on all three datasets with ground truth variation factors. For 2Dshapes in particular, this architecture leads to significantly improved disentanglement scores, which often reach the almost 100% accuracy reported by higgins2016beta when using only four factors, and around 90% when using all five. These Convolutional VAE experiments, including 2Dshapes with all five factors, are illustrated in Figure 3.

Figure 3: Final median disentanglement scores using the Conv. VAE architecture from burgess2018understanding for different values of β, PCA and ICA on the datasets 2Dshapes, 3Dshapes and MPI3DToy.

β-VAE manages to reach very high scores on 3Dshapes as well. However, PCA and ICA experience a significant drop in the disentanglement metric, which may be because we flattened the 3-channel images into vectors to make them suitable for these methods. On the most complex dataset, MPI3DToy, the best scores are actually reached by the lowest value of β that we tested.

Compared to MPI3DToy, the 3Dshapes dataset has much higher contrast and more regular shapes, as illustrated in Figure 1. Furthermore, it is generated from six factors of variation, compared to the seven of MPI3DToy. We hypothesise that the lowest β performs best because MPI3DToy, being the hardest dataset, has the highest magnitude of reconstruction loss relative to the KL loss. This implicitly creates a different weighting of the two loss terms, suggesting that β needs to be scaled according to the difficulty of the dataset.

3.1.3 Quantitative vs qualitative evaluation of disentanglement

The Higgins disentanglement metric tends to assign very high scores that do not correlate well with human judgement of the level of disentanglement. In Figure 4, we first encode a true data sample from 2Dshapes and then vary each latent dimension individually. The true generative factors are position X, position Y, scale, rotation and shape. All models achieve almost perfect disentanglement scores quantitatively, but upon visual inspection, generally only two out of five latents are learnt well, and even those might be entangled.
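A latent traversal such as the ones in Figure 4 starts from the encoded posterior mean of one image and varies one latent dimension at a time while keeping the others fixed. The sketch below only builds the grid of latent codes; the traversal range of ±3 is a common choice that we assume here, not necessarily the one used for Figure 4, and decoding each code with the trained decoder (a hypothetical `decode` function) would give one image per grid cell.

```python
import numpy as np

def latent_traversal_codes(z0, limit=3.0, steps=7):
    """Return codes of shape (latent_dim, steps, latent_dim).

    Row d contains copies of z0 in which only dimension d is swept from -limit to +limit.
    """
    z0 = np.asarray(z0, dtype=float)
    values = np.linspace(-limit, limit, steps)
    codes = np.repeat(z0[None, None, :], z0.size, axis=0).repeat(steps, axis=1)
    for d in range(z0.size):
        codes[d, :, d] = values
    return codes

codes = latent_traversal_codes([0.1, -0.2, 0.3], limit=2.0, steps=5)
# images = decode(codes.reshape(-1, 3))   # hypothetical decoder call, one image per code
```

In a disentangled model, each row of the resulting image grid should change only one ground truth factor, which is the qualitative criterion applied to Figure 4.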

Figure 4: Example latent traversals on 2Dshapes from our MLP VAE (left, higgins2016beta), our Convolutional VAE (middle, DBLP:journals/corr/abs-1804-03599) and the original MLP VAE (right, reprinted from higgins2016beta). Latent traversals are done across columns. While the disentanglement scores are very high, qualitative evaluation suggests that the level of disentanglement is quite low.

3.1.4 Latent space visualization

Figure 5 shows posterior latents on the 2Dshapes dataset, embedded using the first two PCA components. To generate the data, we first fed the VAE an image with the median value of each of the five ground truth generative factors. We then performed a traversal across each dimension of the ground truth factor space (rather than of the model's posterior latents, as in Figure 4).

When β is set low, the embeddings of the ground truth traversal are sparse and appear to have a much more regular structure than for higher values of β. This is likely because the KL divergence to a Normal prior penalizes embeddings that are far from the mean, since Normal distributions have thin tails. As a result, for higher β the model posterior latents become much more concentrated near the mean and do not extend very far into the latent space.

spinner_towards_2018 explain how a low β makes the embeddings more similar to those of a standard autoencoder. This suggests that the posterior latent distribution of standard autoencoders (or, for that matter, of VAEs with low β) is strongly concentrated on the training samples. Sampling new images is then difficult: a randomly sampled latent code is unlikely to decode to a meaningful image, because it almost certainly does not fall into the part of the latent space that the model learnt to decode during training. In contrast, the KL divergence regularization for larger β causes the model to utilise the whole latent space rather than a small part of it, which makes the model better suited to generating new data. In Section 4.2, we formalise the intuition of controlling the posterior latent variance when assuming a Gaussian likelihood over the decoded pixels.

Figure 5: Dimensionality reduction of the latent space to the first two PCA components. Each colour corresponds to the variation of one ground truth factor. Higher β results in denser embeddings of the latent factors.

3.1.5 Sensitivity analysis of the disentanglement metric

While the metric introduced by higgins2016beta has a number of parameters, no guidance is provided on what values to choose. Follow-up work (pmlr-v80-kim18b) notes the particular importance of the parameter L, the sample size used to generate each training sample for the classifier. In Figure 6, we train the linear classifier for several values of L and find that higher values of L improve the average disentanglement score significantly, in particular for PCA and ICA. Moreover, higher L makes the performance of β-VAEs with different β increasingly similar, which makes it harder to use the metric for model selection.
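One plausible mechanism behind this sensitivity: each classifier input is an average of L absolute differences, so for latent dimensions that do not encode the fixed factor its standard deviation shrinks roughly as 1/√L, and the classes become easier to separate even for weak representations. The toy simulation below assumes independent standard normal latents rather than a trained model and only illustrates this scaling.

```python
import numpy as np

rng = np.random.default_rng(0)

def averaged_abs_diff(L, n_batches=4000):
    """Average |z1 - z2| over L pairs of i.i.d. standard normal latents, for many batches."""
    z1 = rng.standard_normal((n_batches, L))
    z2 = rng.standard_normal((n_batches, L))
    return np.abs(z1 - z2).mean(axis=1)

# Standard deviation of the classifier input for several sample sizes L.
stds = {L: float(averaged_abs_diff(L).std()) for L in (16, 64, 256)}
for L, s in stds.items():
    print(f"L = {L:3d}: std of averaged |z1 - z2| = {s:.3f}")
```

Quadrupling L roughly halves the spread of the inputs, which is consistent with larger L making the classification task easier for all models and compressing the differences between them.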

As seen in Table 2, we also note that the choice of classifier does not have a significant effect on the disentanglement metric score. This contradicts the assertion of higgins2016beta, who claim that using a non-linear classifier itself may disentangle the results.

Figure 6: Final disentanglement scores for various parameters of the disentanglement metric evaluated with the MLP VAE on 2Dshapes.

Dataset | Linear (PyTorch) | MLP | Random Forest | Linear (SKlearn)
2Dshapes | 77.62 ± 13.04% | 78.8 ± 12.73% | 80.16 ± 12.29% | 80.73 ± 12.03%
3Dshapes | 91.98 ± 6.82% | 92.61 ± 6.82% | 94.45 ± 6.52% | 94.89 ± 5.67%
MPI3DToy | 33.21 ± 14.38% | 34.78 ± 15.72% | 36.25 ± 16.91% | 35.86 ± 16.82%
Table 2: Disentanglement scores (mean ± standard deviation) using different classifiers, aggregated across … values, on 2Dshapes, 3Dshapes and MPI3DToy.

3.2 Reconstruction

Dataset | β-VAE (β = …) | β-VAE (β = …) | β-VAE (β = …) | β-VAE (β = …) | WGAN-GP
CIFAR10 | 197.8 | 171.1 | 221.1 | 266.5 | 29.3
CIFAR100 | 132.1 | 163.5 | 204.6 | 252.9 | N/A
Table 3: FID (lower is better) for different values of β on CIFAR10 and CIFAR100, compared with WGAN-GP (NIPS2017_892c3b1c).

Table 3 shows the computed FID values for various values of the parameter β, and Figures 9 and 10 show the corresponding reconstructions. The already very high FID of the standard VAE (β = 1) increases steeply as we make β larger. Remarkably, CIFAR10 and CIFAR100 achieve their best FID at different values of β (the second and the first β-VAE column of Table 3, respectively). rybkin2020simple observe similar behaviour across datasets, where a small but nonzero β tends to work best, suggesting that β cannot be seen as just a regularization hyperparameter, given that it can improve model performance even on training data. Interestingly, CIFAR100 consistently achieves lower FID scores than CIFAR10, even though CIFAR100 is the more complex dataset.

We have already shown in Figure 3 that for MPI3DToy, the dataset with ground truth variation factors whose complexity is closest to CIFAR10, the best disentanglement scores were also obtained for the lowest β. We also found the reconstructions for MPI3DToy to be of significantly lower quality than those for 2Dshapes or 3Dshapes. This suggests that the model first needs to be able to produce good reconstructions before it can specialise in producing disentangled representations. There is likely to be a range of β values for which both reconstruction and disentanglement improve simultaneously, and only beyond it does a trade-off between these two notions of representation quality appear. Based on this, we hypothesise that scaling disentanglement beyond toy datasets such as 2Dshapes requires models that are simultaneously very strong at image generation.

In Figures 7 and 8, we show the FID and the reconstruction loss during training on CIFAR100. Even though VAEs tend to be evaluated by the negative log-likelihood (NLL), we found almost no qualitative difference between the behaviour of FID and NLL in our models (dai2019diagnosing; lucic2017gans). The reconstructions for the CIFAR datasets can be seen in Figures 9 and 10. Lower values of β clearly enable reconstructions of far higher quality, albeit still inferior to those produced by state-of-the-art GANs, as measured by FID (heusel2017gans; NIPS2017_892c3b1c). Finally, we note that these qualitative inspections coincide with the quantitative evaluation, unlike in the case of the disentanglement metric.

Figure 7: FID during training on CIFAR100 (for varying β)
Figure 8: Reconstruction loss during CIFAR100 training (for varying β)
Figure 9: Reconstructions for CIFAR100 with β = …
Figure 10: Reconstructions for CIFAR100 with β = …

4 Discussion

4.1 Broad reproducibility

Neither we nor follow-up work have been able to reproduce the original results of higgins2016beta. pmlr-v80-kim18b and locatello2020sober were unable to significantly exceed 80% accuracy on 2Dshapes, regardless of the hyperparameter setting. This contrasts with the 99% accuracy reported in the original paper. locatello2020sober extensively show that other metrics proposed in the literature also exhibit high variance and inconsistency across datasets for the same model. In our results, it was crucial both to ignore the 'shape' generative factor in 2Dshapes and to consider only median performance. These choices are only mentioned in the Appendix of the original paper, but they are vital for achieving performance close to the reported one.

4.2 VAE with MSE loss

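β-VAE has a particular interpretation when assuming a Gaussian likelihood over the decoded pixels. Using an MSE loss with a β-weighted KL term is effectively equivalent to the standard VAE formulation with a calibrated decoder variance. In the Gaussian likelihood case, we normally have $p_\theta(x \mid z) = \mathcal{N}(x; \hat{x}, I)$, and the log-likelihood of the decoder is then

$$\log p_\theta(x \mid z) = -\frac{1}{2}\lVert x - \hat{x} \rVert_2^2 - \frac{D}{2}\log 2\pi \qquad (8)$$

where $\hat{x} = \hat{x}_\theta(z)$ is the prediction of the decoder and D is the dimensionality of the dataset. rybkin2020simple show that if we instead assume $p_\theta(x \mid z) = \mathcal{N}(x; \hat{x}, \sigma^2 I)$, the log-likelihood becomes

$$\log p_\theta(x \mid z) = -\frac{1}{2\sigma^2}\lVert x - \hat{x} \rVert_2^2 - D \log \sigma - \frac{D}{2}\log 2\pi \qquad (9)$$

and the full VAE objective is

$$\mathbb{E}_{q_\phi(z \mid x)}\left[-\frac{1}{2\sigma^2}\lVert x - \hat{x} \rVert_2^2 - D \log \sigma\right] - D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right) + \text{const.} \qquad (10)$$

This is now very similar to the β-VAE objective: multiplying Equation (10) by $2\sigma^2$ shows that $2\sigma^2$ plays the same weighting role as β, and if σ is assumed to be constant, the $D \log \sigma$ term disappears during optimisation. This shows that, in this case, the β-VAE is equivalent to a standard VAE with decoder variance $\sigma^2 = \beta / 2$.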
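The equivalence can be checked numerically. The sketch below assumes a squared error summed over pixels and arbitrary placeholder values for the KL term (in a real model it would come from the encoder). It evaluates the negative of Equation (10) including constants, and a β-weighted MSE objective, and confirms that with β = 2σ² the two differ only by the positive factor 2σ² and an additive constant, so they share the same minimisers.

```python
import numpy as np

def gaussian_vae_loss(x, x_hat, kl, sigma):
    """Negative ELBO with decoder N(x_hat, sigma^2 I): the negation of Eq. (10), constants included."""
    D = x.shape[-1]
    sq_err = np.sum((x - x_hat) ** 2, axis=-1)
    return sq_err / (2.0 * sigma ** 2) + D * np.log(sigma) + 0.5 * D * np.log(2.0 * np.pi) + kl

def beta_vae_mse_loss(x, x_hat, kl, beta):
    """Summed squared error plus beta-weighted KL, as in a beta-VAE trained with an MSE loss."""
    return np.sum((x - x_hat) ** 2, axis=-1) + beta * kl

rng = np.random.default_rng(0)
x, x_hat = rng.random((5, 12)), rng.random((5, 12))   # 5 toy "images" with D = 12 pixels
kl = rng.random(5)                                     # placeholder KL values, one per image
sigma = 0.3
beta = 2.0 * sigma ** 2

# After rescaling by 2 sigma^2, the per-sample difference is the same constant for every sample.
gap = 2.0 * sigma ** 2 * gaussian_vae_loss(x, x_hat, kl, sigma) - beta_vae_mse_loss(x, x_hat, kl, beta)
expected_gap = 2.0 * sigma ** 2 * (12 * np.log(sigma) + 6 * np.log(2.0 * np.pi))
```

If σ is instead learned, as proposed by rybkin2020simple, the $D \log \sigma$ term no longer drops out and the effective β is adapted during training.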

Some of the experiments by higgins2016beta use an MSE loss and show results superior to the standard VAE. However, such examples cannot be used to show the superiority of β-VAE as a model class, since in this setting it is equivalent to a standard VAE with a different decoder variance.

4.3 Implementation details

We initially adopted an online implementation of various VAE models (yanndubs2019), to which we added the disentanglement metric, more datasets, the FID score and more model options. To reproduce the results more faithfully, we later re-implemented the core model and training code, cross-checking it against other sources. The disentanglement metric was initially implemented from scratch; we later also tried another freely available implementation of the metric (noauthor_google-researchdisentanglement_lib_2021) and obtained the same results. To ensure that our implementation of FID was accurate, we based our code on the standard implementation by Seitzer2020FID.

Our final code used to produce the results in this report is mostly custom-made and is available at https://github.com/Mandelbrot99/BetaVAE.

5 Conclusion

In this report, we studied the performance of β-VAE in terms of disentanglement and reconstruction across a variety of datasets. First, we observed that the results originally reported by higgins2016beta are difficult to reproduce. They rely heavily on discarding the worst-performing 50% of random seeds as well as on simplifying the learning task, even though both of these aspects are only mentioned in the Appendix of the paper. The disentanglement metric proposed by higgins2016beta fails to fully capture human-interpretable disentanglement, as evidenced by qualitative evaluation. We also show that the hyperparameters of the disentanglement metric itself can be used to artificially boost the scores.

On more complex datasets, we noted that the intuition suggested by higgins2016beta, namely that a larger β yields better disentanglement scores, breaks down. Rather, the lowest β gives the best results on MPI3DToy. We hypothesise that this can be attributed to the different relative magnitudes of the reconstruction loss and the KL divergence on more complex datasets. Finally, we confirmed that FID is strongly correlated with the reconstruction loss and that both notions coincide with a visual inspection of the reconstructed images. However, the best FID scores are not necessarily achieved for the lowest β value, which is somewhat counter-intuitive, similar to how the best disentanglement can be achieved for a low β.

Our work could be further improved by running the experiments for many more random seeds than we did, but we were constrained by limited hardware. It is known that the variance of performance across random seeds and hyperparameters is generally even greater than across model choices (locatello2020sober). This weakens the robustness of our claims, given that we were not able to run very large-scale experiments. We would also be very interested in seeing more results on the disentanglement-reconstruction trade-off described in Section 3.2, which would again require very extensive experiments with modern VAE-based models. For example, we would like to use the real MPI3D instead of the toy version to see if our results generalise to a truly real-world setting.

References

Appendix

A Model and Hyperparameters Details

A summary of the model architectures we used is shown in Table 4. The MLP model for 2Dshapes follows higgins2016beta but the Convolutional VAE is adopted from burgess2018understanding because other follow-up work typically uses the same architecture.

The optimiser settings are likewise the same across all experiments, except for 3Dshapes, where we found it necessary to decrease the learning rate to 1e-4. The 2Dshapes experiments use a batch size of 256, while the remaining datasets use 64. We also reduced the learning rate by a factor of 5 for the last 25 epochs of training in all experiments.
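The learning rate schedule described above amounts to a single step decay. A minimal sketch is given below; the total number of training epochs is not stated here, so `total_epochs` is a free parameter of this hypothetical helper, and epochs are counted from 0.

```python
def learning_rate(epoch, total_epochs, base_lr, final_epochs=25, factor=5.0):
    """Constant base_lr, divided by `factor` during the last `final_epochs` epochs."""
    if epoch >= total_epochs - final_epochs:
        return base_lr / factor
    return base_lr

# Example with the Adam learning rate used for most Convolutional VAE runs (5e-4)
# and an assumed 100-epoch budget: epochs 0-74 use 5e-4, epochs 75-99 use 1e-4.
schedule = [learning_rate(e, 100, 5e-4) for e in range(100)]
```

In PyTorch, the same behaviour could be obtained by calling this function at the start of each epoch and writing the result into each optimiser parameter group's `lr` entry.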

All 2Dshapes experiments were run using four seeds: 123, 427, 235 and 921. MPI3DToy and 3Dshapes used 123, 427 and 235. CIFAR10 and CIFAR100 experiments used 4723 and 1263.

B PCA and ICA implementation

Because the common implementations of ICA and PCA that we used cannot train from minibatches and instead apply expensive linear algebra routines to the full dataset, we had to significantly limit the training data size in order to run the training. For PCA, we use 25,000 samples for black-and-white datasets and 3,500 for RGB datasets, while for ICA, we use 2,500 for black-and-white and 1,000 for RGB datasets. Additionally, RGB images are flattened, since standard PCA/ICA implementations do not support images with multiple channels.
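A minimal sketch of the PCA baseline on flattened RGB images is shown below, using a full (thin) SVD of the centred data matrix; it is an illustration with our own function names, not the library implementation we used. It also shows where the cost comes from: for RGB data the matrix has 64 × 64 × 3 = 12,288 columns, and the SVD operates on all samples at once.

```python
import numpy as np

def pca_fit(images, n_components):
    """Fit PCA on images of shape (N, H, W, C) by flattening each image into one vector."""
    X = images.reshape(len(images), -1).astype(float)   # (N, H*W*C)
    mean = X.mean(axis=0)
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:n_components]                       # rows are orthonormal principal axes

def pca_encode(images, mean, components):
    """Project flattened images onto the principal axes."""
    X = images.reshape(len(images), -1).astype(float)
    return (X - mean) @ components.T

rng = np.random.default_rng(0)
images = rng.random((40, 64, 64, 3))     # toy stand-in for a subsample of an RGB dataset
mean, components = pca_fit(images, n_components=10)
codes = pca_encode(images, mean, components)
```

The resulting codes play the role of the latent representation when computing the disentanglement metric for the PCA baseline.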

 

Dataset | Optimiser | Learning rate | Batch size | Input | Encoder | Latents | Decoder
2Dshapes (MLP) | Adagrad | 1e-2 | 256 | 4096 (flattened 64x64x1) | FC 1200, 1200; ReLU activation | 10 | FC 1200, 1200, 1200, 4096; Tanh activation; Bernoulli
2Dshapes (Conv VAE) | Adam | 5e-4 | 256 | 4096 (flattened 64x64x1) | Conv 32x4x4 (stride 2) 3x, FC 256 2x; ReLU activation | 10 | Deconv reverse of encoder; ReLU activation; Bernoulli
3Dshapes (Conv VAE) | Adam | 1e-4 | 64 | 12288 (flattened 64x64x3) | Conv 32x4x4 (stride 2) 3x, FC 256 2x; ReLU activation | 10 | Deconv reverse of encoder; ReLU activation; Bernoulli
MPI3DToy (Conv VAE) | Adam | 5e-4 | 64 | 12288 (flattened 64x64x3) | Conv 32x4x4 (stride 2) 3x, FC 256 2x; ReLU activation | 10 | Deconv reverse of encoder; ReLU activation; Bernoulli
CIFAR10 (Conv VAE) | Adam | 5e-4 | 64 | 12288 (flattened 64x64x3) | Conv 32x4x4 (stride 2) 3x, FC 256 2x; ReLU activation | 128 | Deconv reverse of encoder; ReLU activation; Bernoulli
CIFAR100 (Conv VAE) | Adam | 5e-4 | 64 | 12288 (flattened 64x64x3) | Conv 32x4x4 (stride 2) 3x, FC 256 2x; ReLU activation | 128 | Deconv reverse of encoder; ReLU activation; Bernoulli
Table 4: Model Architecture Details
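To make the convolutional encoder in Table 4 more concrete, the sketch below computes the feature map shapes after the three "Conv 32x4x4 (stride 2)" layers using the standard output size formula. The padding of 1 is our assumption (Table 4 does not list it); with it, a 64 × 64 input is halved by each layer, giving a 32 × 8 × 8 feature map that is flattened before the FC 256 layers.

```python
def conv_output_size(n, kernel=4, stride=2, padding=1):
    """Spatial output size of a 2D convolution: floor((n + 2 * padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

def encoder_feature_shapes(image_size=64, channels=32, n_conv=3, padding=1):
    """(channels, height, width) after each 'Conv 32x4x4 (stride 2)' layer of Table 4."""
    shapes, size = [], image_size
    for _ in range(n_conv):
        size = conv_output_size(size, padding=padding)
        shapes.append((channels, size, size))
    return shapes

c, h, w = encoder_feature_shapes()[-1]
flat_features = c * h * w   # input size of the first "FC 256" layer under these assumptions
```

A different padding choice would change these sizes, so the numbers should be checked against the actual model code before being relied upon.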