Generated Loss and Augmented Training of MNIST VAE

04/24/2019 · by Jason Chou, et al.

The variational autoencoder (VAE) framework is a popular option for training unsupervised generative models, featuring ease of training and a latent representation of the data. The VAE objective does not guarantee the latter, however, and failure to achieve it leads to a frequent failure mode called posterior collapse. Even in successful cases, VAEs often produce low-precision reconstructions and generated samples. The introduction of the KL-divergence weight β can help steer the model clear of posterior collapse, but its tuning is often a trial-and-error process with no guiding metrics. Here we test the idea of using the total VAE loss of generated samples (generated loss) as a proxy metric for generation quality, the related hypothesis that VAE reconstruction from the mean latent vector tends to be a more typical example of its class than the original, and the idea of exploiting this property by augmenting training data with generated variants (augmented training). The results are mixed, but repeated encoding and decoding indeed result in qualitatively and quantitatively more typical examples from both convolutional and fully-connected MNIST VAEs, suggesting that it may be an inherent property of the VAE framework.


1 Introduction

The variational autoencoder (VAE) framework (Kingma & Welling, 2013) has been a popular option for training deep generative models. It is a maximum likelihood model that establishes a lower bound (the evidence lower bound, ELBO) on the likelihood of the observed data by factoring out a latent representation. The VAE encoder is trained to encode training examples into posterior distributions in the latent space, and the VAE decoder is trained to reconstruct the training examples from latent vectors sampled from these posterior distributions. To maximize the ELBO, the model needs to minimize both mistakes in reconstruction (‘reconstruction loss’) and the difference between the posterior distributions and an assumed prior (‘latent loss’, measured by KL-divergence). If the latent loss dominates, the posterior distributions of the latent vectors collapse to the assumed prior and the latent vector no longer carries any information about the training example (posterior collapse). Models of continuous data types in posterior collapse often hedge their bets by always outputting a weighted average of the training examples. In addition, models capable of autoregression fall back on it in posterior collapse.

A less severe but more pervasive issue of VAEs is low-precision reconstructions and generated samples (Sajjadi et al., 2018). It has been theorized that this issue still comes down to the strength of the KL-divergence term. With a strong KL-divergence term, the posterior distributions of the sampled latent vectors of two distinct training examples overlap and result in ambiguous and therefore low-precision generated samples. On the other hand, when the KL-divergence term is weak, there may be holes in the latent space that are not covered by any posterior distribution, and such holes result in samples unconstrained by the training data (Rezende & Viola, 2018). In order to steer clear of posterior collapse and optimize the posterior distributions for generation, the KL-divergence weight β is often introduced as a hyperparameter (Higgins et al., 2017) and sometimes used in conjunction with manual annealing over training steps (Bowman et al., 2015; Sønderby et al., 2016; Akuzawa et al., 2018). While tuning of β has been shown to be effective, a hyperparameter sweep over it is a laborious process, and auto-tuning of it w.r.t. reconstruction error constraints has recently been proposed (Rezende & Viola, 2018). It is perhaps worth pointing out a possible disconnect between the theoretical motivation and the use in practice: VAEs are designed to maximize the likelihood of training examples but are often used as generative models. It is natural to wonder whether VAEs themselves find their generated samples to be likely observations.

In this short paper we follow up on Chou & Hathi (2019) and test the following ideas with a simple convolutional MNIST VAE: 1. Measuring the generated loss, the total VAE loss of generated samples as if they were training or testing data, as a guiding metric. 2. Repeated encoding and decoding using the mean latent vector result in more typical examples of their class than the original. 3. Exploiting 2. by augmenting training data with generated variants (augmented training), as a variational method for training VAEs to achieve improved generation quality. The overall results are mixed, but 2. does hold up for both convolutional and fully-connected MNIST VAEs. We will draw comparisons and contrasts with the original results, which focus on discrete data.

2 Background

2.1 β-VAE

To recap, a VAE is trained to maximize the evidence lower bound (ELBO) of the log-likelihood of the training examples

$\log p_\theta(x) \;\ge\; \mathcal{L}_{\mathrm{ELBO}}(x) \;=\; \mathbb{E}_{q_\phi(z|x)}\!\left[\log p_\theta(x|z)\right] - D_{KL}\!\left(q_\phi(z|x)\,\|\,p(z)\right)$    (1)

where $D_{KL}$ is the KL-divergence between two distributions and $z$ is the latent vector, whose prior distribution $p(z)$ is most commonly assumed to be a multivariate unit Gaussian. $p_\theta(x|z)$ is given by the decoder, and $q_\phi(z|x)$ is the posterior distribution of the latent vector given by the stochastic encoder, whose operation can be made differentiable through the reparameterization trick $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$, if $q_\phi(z|x)$ is assumed to be a diagonal-covariance Gaussian.

A common modification to the ELBO of VAE is to add a hyperparameter β to the KL-divergence term and use the following objective function:

$\mathcal{L}_\beta(x) \;=\; \mathbb{E}_{q_\phi(z|x)}\!\left[\log p_\theta(x|z)\right] - \beta\, D_{KL}\!\left(q_\phi(z|x)\,\|\,p(z)\right)$    (2)

where β controls the strength of the information bottleneck on the latent vector. For higher values of β, we accept lossier reconstruction in exchange for a higher effective compression ratio. In both cases, the generator samples from the probability distribution $p_\theta(x|z)$ given by the decoder, where $z$ is a random latent vector at generation time:

$\hat{x} \sim p_\theta(x \mid z), \qquad z \sim p(z)$    (3)

In practice, explicit sampling is most frequently associated with discrete data like characters in a string. For continuous data types, the decoder’s output is often interpreted as the mean and used directly.
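To make the objective concrete, the following minimal sketch shows how Eq. 2 and the reparameterization trick might be computed for 4-D image batches with a Bernoulli (cross-entropy) decoder. The function and argument names (`beta_vae_loss`, `decoder_logits`, etc.) are our own illustration rather than the experiment code, and β is applied here as a plain multiplier on the summed KL term rather than the per-dimension/per-pixel relative weight reported later in Sec 2.2.

```python
import tensorflow as tf

def reparameterize(z_mean, z_logvar):
    """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(0.5 * z_logvar) * eps

def beta_vae_loss(x, decoder_logits, z_mean, z_logvar, beta):
    """Per-example negative beta-VAE objective (Eq. 2): cross-entropy
    reconstruction loss plus beta times KL(q(z|x) || N(0, I))."""
    # Bernoulli cross-entropy, summed over all pixels of each example.
    recon = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=decoder_logits),
        axis=[1, 2, 3])
    # Closed-form KL divergence for a diagonal-covariance Gaussian posterior.
    kl = -0.5 * tf.reduce_sum(
        1.0 + z_logvar - tf.square(z_mean) - tf.exp(z_logvar), axis=1)
    return recon + beta * kl
```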

2.2 Convolutional MNIST VAE setup

In intentional contrast with Chou & Hathi (2019), we apply its ideas to a simple case with off-the-shelf models. We use the original 60000:10000 training:testing split of the MNIST dataset (LeCun & Cortes, 2010) with only the minimum preprocessing of converting the integers to floats and scaling them down by a factor of 255 so that they fall in the range [0, 1]. The architecture of the convolutional VAE is taken from the TensorFlow 2.0 convolutional VAE example (accessed 2019-04-16), with 2 convolutional layers + 1 fully-connected layer for the encoder and 1 fully-connected layer + 3 convolutional layers for the decoder. Cross-entropy loss is used for the reconstruction loss since both the input and the reconstruction give pixel values in the range [0, 1], and the latent loss is the usual KL-divergence. The model is trained with the Adam optimizer (Kingma & Ba, 2014) with a constant learning rate. To make sure that epochs line up, we use batch size = 50 and run a testing step every 6 training steps. We also run a ‘generation’ step every 6 training steps that generates 50 samples and measures their total VAE loss as if they were testing data. For generation, we simply sample from the 32-dim unit Gaussian prior and feed the random latent vectors to the decoder. The training budget is fixed at 5 epochs = 6000 steps, and we report the average losses over the last epoch for each experiment. The value of β is again reported as the relative weight of the average KL-divergence loss per latent dimension to the average cross-entropy loss per pixel value.

2.3 Fréchet Inception Distance and p-value

In order to quantify generation quality of the MNIST VAE, we use the widely adopted Fréchet Inception Distance (FID) (Heusel et al., 2017) computed with the TensorFlow-GAN library’s default graph (accessed 2019-04-17). With the simplifying assumption that the logit activations of a trained classifier follow a multivariate Gaussian distribution $\mathcal{N}(\mu_t, \Sigma_t)$ when given training examples and $\mathcal{N}(\mu_g, \Sigma_g)$ when given generated samples, FID is the Fréchet distance (also known as Wasserstein-2 distance) between these two distributions:

$\mathrm{FID} = \lVert \mu_t - \mu_g \rVert^2 + \mathrm{Tr}\!\left(\Sigma_t + \Sigma_g - 2\left(\Sigma_t \Sigma_g\right)^{1/2}\right)$    (4)
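The experiments compute FID through the TensorFlow-GAN library; the snippet below is only a plain numpy/scipy illustration of Eq. 4 itself, given Gaussian moments already estimated from classifier activations.

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu_t, sigma_t, mu_g, sigma_g):
    """Frechet (Wasserstein-2) distance between two Gaussians, Eq. 4."""
    diff = mu_t - mu_g
    # Matrix square root of the product of the two covariance matrices.
    covmean, _ = linalg.sqrtm(sigma_t @ sigma_g, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma_t + sigma_g - 2.0 * covmean))
```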

With the same assumption, we can also quantify how ‘typical’ a generated sample is by its p-value, which can be obtained by applying the multivariate version of the two-tailed t test to its logit activations $a$ (Chou & Hathi, 2019):

$p(a) = 1 - F_k\!\left(d_M^2(a)\right), \qquad d_M^2(a) = (a - \mu)^{\mathsf{T}} \Sigma^{-1} (a - \mu)$

where $d_M^2(a)$ is the squared Mahalanobis distance and $F_k$ is the cumulative distribution function (CDF) of the chi-squared distribution with $k$ degrees of freedom. Intuitively, the p-value is the probability that the given $a$ is more likely than a sample drawn from the distribution, and it measures how close $a$ is to the mode of the distribution. As we will see later, however, our simplifying assumption breaks down when we try to quantify how typical a generated sample is by p-value due to a fundamental discrepancy: the logit activations of an MNIST classifier follow a 10-modal distribution instead of a unimodal multivariate Gaussian distribution. Therefore, the distance from $a$ to the mean/mode of the distribution bears no relation to how typical the generated sample is as a handwritten digit. To remedy this, we make the weaker assumption that the logit activations of a trained classifier follow a multivariate Gaussian distribution $\mathcal{N}(\mu_c, \Sigma_c)$ when given training examples of a given class $c$, and calculate the conditional p-value of a sample’s logit activations instead:

$p(a \mid c) = 1 - F_k\!\left((a - \mu_c)^{\mathsf{T}} \Sigma_c^{-1} (a - \mu_c)\right)$

where $c = \arg\max_i a_i$, i.e. the label given by the classifier. To gather meaningful statistics, we always evaluate FID and aggregate properties of p-values over 10000 samples.
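The conditional p-value follows directly from the chi-squared CDF. In the sketch below, `class_means` and `class_covs` are assumed to be Gaussian moments fitted beforehand to the training-example logits of each class; the function names are ours.

```python
import numpy as np
from scipy.stats import chi2

def conditional_p_values(logits, class_means, class_covs):
    """Conditional p-value of each sample's logit activations under the
    Gaussian fitted to its predicted class c = argmax_i a_i (Sec 2.3)."""
    labels = np.argmax(logits, axis=1)
    k = logits.shape[1]                       # degrees of freedom (10 for MNIST)
    p_values = np.empty(len(logits))
    for c in np.unique(labels):
        idx = labels == c
        diff = logits[idx] - class_means[c]
        inv_cov = np.linalg.inv(class_covs[c])
        d2 = np.einsum('ij,jk,ik->i', diff, inv_cov, diff)  # squared Mahalanobis distance
        p_values[idx] = 1.0 - chi2.cdf(d2, df=k)             # p = 1 - F_k(d_M^2)
    return p_values
```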

3 Convolutional β-VAE results

With the setup described in Sec 2.2, here we report the losses of the baseline β-VAE models over the full hyperparameter sweep of β (Fig 1), and 10 generated samples for each model annotated with its β value (Fig 2).

Figure 1: Training, testing, and generated loss over the value of β, averaged over the last epoch. Values of FID were calculated over 10000 samples generated by the final model.
Figure 2: The value of β and 10 generated samples for each model.

We again observe no overfitting, with the training loss practically identical to the testing loss. Generated loss lags behind in the β-VAE regime of the hyperparameter space, but not as severely as in Chou & Hathi (2019). In the case of an image VAE, both the encoder and decoder are typically continuous functions that map the pixel values to the latent vector and back. Perhaps the continuity and the magnitudes of the derivatives of these functions limit how different generated loss can be from the training/testing loss. As expected, training/testing loss decreases monotonically as we lower the weight of the latent loss, but generated loss starts rising again as the β-VAE transitions towards a plain autoencoder. In the low-β regime, generated samples are increasingly unconstrained by the training examples, and the model is simply not trained to autoencode them. On the other end, as the value of β increases past the point of posterior collapse, the model starts to hedge its bets and always outputs a weighted average of the training examples, so the values of training/testing loss and generated loss converge. In this sense, the value of generated loss does delimit different regimes of the hyperparameter space. However, it is not as sensitive to the model’s generation quality as FID: the value of β where FID is at its minimum is a plausible optimum in terms of subjective assessment of Fig 2, but the value of β where generated loss is at its minimum is clearly too low. We notice that generated loss is far more sensitive to the value of β for a shallower MNIST VAE built with only fully-connected layers (Appendix C), so the sensitivity of generated loss may depend on the absence of built-in assumptions about the data (translational invariance in this case) in the model.

With generated loss measured, we then test the related hypothesis that VAE reconstruction from the mean latent vector tends to be a more typical example of its class. Intuitively, if this hypothesis is true, we expect reconstruction loss of generated samples (which are more likely to be atypical) to stay elevated and contribute to elevated generated loss. We again test this hypothesis by repeated encoding and decoding using the mean latent vector. Subjectively, both generated samples and training examples do seem to converge to ‘typical’, textbook examples of handwritten digits (Fig 3 and 4). To quantify how ‘typical’ generated samples and training examples are over repeated encoding and decoding, we turn to p-values as described in Sec 2.3. The unimodal, unconditional p-value results are puzzling (Appendix A), until one realizes that classifier logit activations actually follow a 10-modal distribution and computes the conditional p-values instead (Fig 5). The conditional p-values over repeated encoding and decoding follow the familiar pattern that both generated samples and training examples converge to the mode of the distribution and that training examples converge faster than the generated samples. Disappointingly, augmented training does not seem to be able to exploit this and improve the generation quality of the described convolutional VAE (Appendix B).

Figure 3: 10 samples generated by the model over repeated encoding and decoding using the mean latent vector, annotated with the number of repetitions.
Figure 4: 10 training examples over repeated encoding and decoding by the model using the mean latent vector, annotated with the number of repetitions.
Figure 5: Box plot of conditional p-values over repeated encoding and decoding. The ‘generated’ sequence at repetition = 0 corresponds to the p-value distribution of 10000 generated samples, and the ‘reconstructed’ sequence at repetition = 0 corresponds to that of 10000 randomly selected training examples.
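For reference, the repeated encoding and decoding procedure behind Fig 3-5 can be sketched as follows, again assuming an `encoder` that returns the mean and log-variance and a `decoder` that returns pixel logits; only the mean latent vector is used at every repetition.

```python
import tensorflow as tf

def repeated_reencode(x, encoder, decoder, repetitions):
    """Return [x, x_1, ..., x_n]: the image after each round of encoding to the
    mean latent vector and decoding back to pixel space."""
    images = [x]
    for _ in range(repetitions):
        z_mean, _ = encoder(images[-1])             # mean latent vector, no sampling
        images.append(tf.sigmoid(decoder(z_mean)))  # decoder output used as pixel means
    return images
```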

4 Conclusion and discussion

We applied the idea of generated loss and augmented training to MNIST VAEs and obtained mixed results. More specifically,

  1. The total VAE loss of generated samples (generated loss) lags behind the training/testing loss but does not seem to reflect the model’s generation quality, especially for convolutional models.

  2. For both convolutional and fully-connected MNIST VAEs (Appendix C), repeated encoding and decoding using the mean latent vector again lead to more typical examples of their class.

  3. Disappointingly, augmented training is not able to exploit 2. and improve the generation quality of the given MNIST VAEs.

In contrast to Chou & Hathi (2019), generated loss does not seem to reflect the model’s generation quality and augmented training does not seem to help. For continuous VAEs such as the MNIST VAE, the encoder and decoder are both continuous functions carefully designed such that their derivatives are bounded. In particular, models like the convolutional VAE already have built-in assumptions about the training data (translational invariance in this case). Both factors seem to allow the convolutional VAE to generalize better and limit the difference between training/testing loss and generated loss. With less difference to exploit, augmented training is no longer able to improve the model’s generation quality, and perhaps a more sophisticated approach is necessary. The observation that repeated encoding and decoding using the mean latent vector lead to more typical examples of the class still holds, however, both qualitatively and quantitatively. It may well be an inherent property of the VAE framework, useful for clustering purposes à la mean-shift clustering, given that it is tied to the variational method of minimizing reconstruction loss to the given training example from the surrounding latent space.

References

Appendix A Unconditional p-values over repeated encoding and decoding

Figure 6: Box plot of (unconditional) p-values over repeated encoding and decoding. The ‘generated’ sequence at repetition = 0 corresponds to the p-value distribution of 10000 generated samples, and the ‘reconstructed’ sequence at repetition = 0 corresponds to that of 10000 randomly selected training examples. In retrospect, maybe we should have expected this result: compared to the modes of the conditional distributions, the logit activations of the training examples follow the true distribution and tend to be further away from the unconditional global mean, whereas the logit activations of the samples generated by drawing latent vectors from the assumed prior tend to be closer to the unconditional global mean. Over repeated encoding and decoding, both converge to the modes of the conditional distributions, but from opposite directions.

Appendix B Augmented training

Augmented training is motivated by the persistent gap between generated loss and training/testing loss (Chou & Hathi, 2019). By augmenting the training data with generated variants, it makes sure that information from the training examples propagates beyond their respective posterior distributions. To recap, it features the following training scheme:

  1. After gen_start_step steps: initialize the augmented latent vectors with the sampled latent vectors of the current training batch.

  2. Augment the next training batch with variants generated from the augmented latent vectors.

  3. After the training step, replace each augmented latent vector with either the sampled latent vector of an example from the current training batch (selected without replacement), with a fixed re-initialization probability, or the sampled latent vector of the variant generated from it otherwise.

  4. Repeat from 2.

Augmented training extends the standard VAE training scheme. Right after training to minimize reconstruction loss from the sampled latent vector, we actually generate a reconstruction from it and augment the next training batch with that reconstruction. The next training step then encodes this reconstruction into its own posterior distribution in the latent space and minimizes the reconstruction loss from the augmented latent vector to this reconstruction, and so on. The similarity of the successive reconstructions will decay over repeated encoding / decoding due to the model’s capacity limit and the noise introduced by latent vector sampling, so we re-initialize each augmented latent vector with a probability chosen such that its average lifetime is a fixed number of steps. We only start augmented training after gen_start_step steps to make sure that the model is already trained to generate reasonable reconstructions, and a separate hyperparameter controls the number of augmented latent vectors we use. Formally, in addition to the training examples $x = x^{(0)}$, we train the model with reconstructions from the sequence

$z^{(t)} \sim q_\phi\!\left(z \mid x^{(t)}\right), \qquad x^{(t+1)} \sim p_\theta\!\left(x \mid z^{(t)}\right), \qquad t = 0, 1, 2, \ldots$
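A rough sketch of this scheme is given below. Here `reset_prob` stands in for the re-initialization probability (whose exact value we do not restate here), the pool of augmented latent vectors is taken to equal the batch size as in our experiments, and the helper names are illustrative rather than the actual training code.

```python
import tensorflow as tf

def augmented_batches(dataset, encoder, decoder, gen_start_step, reset_prob):
    """Yield training batches, augmented (after gen_start_step steps) with
    reconstructions generated from a pool of augmented latent vectors."""
    aug_z = None
    for step, x in enumerate(dataset):
        if step < gen_start_step:
            yield x
            continue
        if aug_z is None:
            z_mean, z_logvar = encoder(x)
            aug_z = reparameterize(z_mean, z_logvar)   # initialize from current batch
        x_aug = tf.sigmoid(decoder(aug_z))             # variants generated from the pool
        yield tf.concat([x, x_aug], axis=0)            # train on examples + variants
        # Update the pool: either reset to the sampled latent vector of the batch
        # example at the same index, or follow the variant's own sampled latent vector.
        var_mean, var_logvar = encoder(x_aug)
        variant_z = reparameterize(var_mean, var_logvar)
        batch_mean, batch_logvar = encoder(x)
        batch_z = reparameterize(batch_mean, batch_logvar)
        reset = tf.random.uniform([tf.shape(x)[0], 1]) < reset_prob
        aug_z = tf.where(reset, batch_z, variant_z)
```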

In terms of the objective function, we have

$\mathcal{L}_{\mathrm{aug}} = \sum_{i=1}^{B} \mathcal{L}_\beta\!\left(x_i\right) + \sum_{i=1}^{B} \mathcal{L}_\beta\!\left(x_i^{(t_i)}\right)$    (5)

where $x_i^{(t_i)}$ is the generated variant currently paired with the $i$-th augmented latent vector, assuming that the number of augmented latent vectors is equal to the training batch size $B$, as is the case for our experiments. Disappointingly, augmented training does not seem to improve the generation quality of the described convolutional VAE, either in terms of FID or subjective evaluation (Fig 7 and 8, both taken from the best experiments so far). Conditional p-value statistics indicate that augmented training does not help the model generate more ‘typical’ examples of the class, either (Fig 9 and 10). One may notice that generated loss is not significantly lowered by augmented training, which may serve as another indication that augmented training does not help.

Figure 7: Training, testing, and generated loss over the value of β, averaged over the last epoch. Values of FID were calculated over 10000 samples generated by the final model (gen_start_step = 2400).
Figure 8: The value of β and 10 generated samples for each model (gen_start_step = 2400).
Figure 9: Mean conditional p-values calculated over 10000 samples generated by the baseline convolutional models vs. the augmented training convolutional models (gen_start_step = 2400).
Figure 10: Median conditional p-values calculated over 10000 samples generated by the baseline convolutional models vs. the augmented training convolutional models (gen_start_step = 2400).

Appendix C Fully-connected MNIST VAE

The fully-connected MNIST VAE architecture is taken from the TensorFlow 2.0 Keras example (accessed 2019-04-16). Both the encoder and the decoder consist of two fully-connected layers, with 64 intermediate dimensions and 32 latent dimensions. The first layer of the encoder is a ReLU layer, which leads to a split second layer that uses one linear network to generate the mean vector and another linear network to generate the log-variance vector. The usual reparameterization trick is then used to generate the sampled latent vector from the diagonal-covariance Gaussian. The decoder takes the sampled latent vector and applies a ReLU layer, followed by a sigmoid layer to generate floats in the range [0, 1]. The rest of the setup remains the same as described in Sec 2.2.
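For concreteness, the described fully-connected architecture corresponds to a Keras model along the following lines; this is a sketch of the description above (assuming flattened 28×28 inputs, i.e. 784 dimensions), not the example’s exact code.

```python
import tensorflow as tf
from tensorflow.keras import layers

ORIGINAL_DIM, INTERMEDIATE_DIM, LATENT_DIM = 784, 64, 32

# Encoder: one ReLU layer, then two parallel linear heads for mean and log-variance.
enc_in = tf.keras.Input(shape=(ORIGINAL_DIM,))
h = layers.Dense(INTERMEDIATE_DIM, activation='relu')(enc_in)
z_mean = layers.Dense(LATENT_DIM)(h)
z_logvar = layers.Dense(LATENT_DIM)(h)
fc_encoder = tf.keras.Model(enc_in, [z_mean, z_logvar], name='fc_encoder')

# Decoder: one ReLU layer followed by a sigmoid layer back to pixel space.
dec_in = tf.keras.Input(shape=(LATENT_DIM,))
h = layers.Dense(INTERMEDIATE_DIM, activation='relu')(dec_in)
x_out = layers.Dense(ORIGINAL_DIM, activation='sigmoid')(h)
fc_decoder = tf.keras.Model(dec_in, x_out, name='fc_decoder')
```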

With the setup described above, here we report the losses of the fully-connected β-VAE models over the full hyperparameter sweep of β (Fig 11), and 10 generated samples for each model annotated with its β value (Fig 12).

Figure 11: Training, testing, and generated loss over the value of β, averaged over the last epoch. Regimes of the hyperparameter space that approach the plain autoencoder, β-VAE, and posterior collapse are distinct and clearly visible.
Figure 12: The value of β and 10 generated samples for each model.

We can see that generated loss decreases as the value of β decreases over a much narrower range for the fully-connected VAE before it rises again as the model approaches a plain autoencoder, in comparison to the deeper convolutional equivalent (Fig 1). Subjectively, however, the model with the lowest generated loss still does not seem to be the best generative model (Fig 12). The fully-connected MNIST VAE is overall a worse generative model, so the effect is not as obvious subjectively, but the observation that repeated encoding and decoding lead to more typical examples of the class still holds (Fig 13, 14, and 15). The results of applying augmented training to this fully-connected MNIST VAE are again negative in terms of quantitative measures (Fig 16, 17, 18, and 19).

Figure 13: 10 samples generated by the model over repeated encoding and decoding using the mean latent vector, annotated with the number of repetitions.
Figure 14: 10 training examples over repeated encoding and decoding by the model using the mean latent vector, annotated with the number of repetitions.
Figure 15: Box plot of conditional p-values over repeated encoding and decoding using the model.
Figure 16: Training, testing, and generated loss over the value of β, averaged over the last epoch. Values of FID were calculated over 10000 samples generated by the final model (gen_start_step = 2400).
Figure 17: The value of β and 10 generated samples for each model (gen_start_step = 2400).
Figure 18: Mean conditional p-values calculated over 10000 samples generated by the baseline fully-connected models vs. the augmented training fully-connected models (gen_start_step = 2400).
Figure 19: Median conditional p-values calculated over 10000 samples generated by the baseline fully-connected models vs. the augmented training fully-connected models (gen_start_step = 2400).