The variational auto-encoder (VAE) Kingma and Welling (2013) is a class of hierarchical Bayesian models based on transforming data from a simple latent prior $P(z)$ into a complicated observation distribution using a neural network. The model density is represented by the marginal distribution $p_\theta(x) = \int p_\theta(x|z)\,P(z)\,dz$, where $\theta$ denotes the model parameters. The data log-likelihood is in general intractable to compute, so VAEs maximize the so-called evidence lower bound (ELBO),

$$\log p_\theta(x) \;\geq\; \mathbb{E}_{q_\theta(z|x)}\big[\log p_\theta(x|z)\big] - \mathrm{KL}\big(q_\theta(z|x)\,\|\,P(z)\big), \tag{1}$$

where $q_\theta(z|x)$ denotes an approximate posterior that maps observations to a distribution over latent variables, and $p_\theta(x|z)$ denotes the mapping from latent codes to observations. While VAEs have been successfully applied in a number of settings Germain et al. (2015); Klys et al. (2018); Kingma et al. (2016); van den Berg et al. (2018), they suffer from a shortcoming known as posterior collapse Bowman et al. (2015); Chen et al. (2016); Razavi et al. (2019); van den Oord et al. (2016, 2017). In effect, the KL penalty toward the prior in Eq. (1) causes the encoder to lose information in some dimensions of the latent code, leading the decoder to ignore those dimensions. The result is a model that generates good samples but has a poor latent representation. This is an issue for any downstream application that relies on $q_\theta(z|x)$.
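To make (1) concrete, the sketch below evaluates the ELBO exactly for a toy model whose latent variable takes only four values. With a discrete latent, every expectation is a finite sum. The discrete latent, the random tables, and the function names are illustrative assumptions and do not come from the paper. The script checks two things: the ELBO never exceeds $\log p_\theta(x)$, and the bound is tight when $q_\theta(z|x)$ equals the exact posterior.

```python
import math
import random

def normalize(weights):
    """Scale non-negative weights so that they sum to one."""
    total = sum(weights)
    return [w / total for w in weights]

def log_marginal(x, prior, lik):
    """log p(x) = log sum_z p(x|z) P(z) for a discrete latent z."""
    return math.log(sum(prior[z] * lik[z][x] for z in range(len(prior))))

def elbo(x, prior, lik, q):
    """E_q[log p(x|z)] - KL(q || P(z)), written as one sum over z."""
    return sum(q[z] * (math.log(lik[z][x]) + math.log(prior[z]) - math.log(q[z]))
               for z in range(len(prior)) if q[z] > 0)

rng = random.Random(0)
num_z, num_x = 4, 6
prior = normalize([rng.random() + 0.1 for _ in range(num_z)])                         # P(z)
lik = [normalize([rng.random() + 0.1 for _ in range(num_x)]) for _ in range(num_z)]   # p(x|z), lik[z][x]
x = 2

q_rand = normalize([rng.random() + 0.1 for _ in range(num_z)])        # an arbitrary encoder q(z|x)
q_post = normalize([prior[z] * lik[z][x] for z in range(num_z)])      # exact posterior p(z|x)

print(log_marginal(x, prior, lik), elbo(x, prior, lik, q_rand), elbo(x, prior, lik, q_post))
```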
Here we present a related but distinct framework called the Mutual Information Machine (MIM). Our objectives are twofold: to promote symmetry in the encoding and decoding distributions (i.e., consistency), and to encourage high mutual information between $x$ and $z$ (i.e., good representation). By symmetry, we mean that the encoding and decoding distributions should represent two equivalent factorizations of the same underlying joint distribution. The framework can be seen as a symmetric analogue of VAEs, in which we optimize a symmetric divergence instead of the asymmetric KL divergence. In preliminary experiments, we find that MIM produces highly informative representations, with sample quality comparable to that of VAEs.
2 Mutual Information Machine
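To begin, it is helpful to write the VAE loss function as the KL divergence between two joint distributions over $x$ and $z$, as in Pu et al. (2017). Let $P(x)$ denote the data distribution, and define the anchored encoding distribution $\mathcal{M}_E(x,z) = q_\theta(z|x)\,P(x)$ and the anchored decoding distribution $\mathcal{M}_D(x,z) = p_\theta(x|z)\,P(z)$. Then, up to an additive constant, the following holds:

$$\mathbb{E}_{P(x)}\big[-\mathrm{ELBO}\big] = \mathrm{KL}\big(\mathcal{M}_E(x,z)\,\|\,\mathcal{M}_D(x,z)\big) + \text{const}. \tag{2}$$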
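As a sanity check of (2), the following sketch builds random discrete tables for $P(x)$, $P(z)$, $q_\theta(z|x)$ and $p_\theta(x|z)$. Discrete tables are a toy assumption; the paper uses neural networks. The script confirms that the average negative ELBO equals $\mathrm{KL}(\mathcal{M}_E\|\mathcal{M}_D)$ plus a constant. In this calculation the constant turns out to be the data entropy $H(P(x))$, which does not depend on $\theta$.

```python
import math
import random

def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]

rng = random.Random(1)
nx, nz = 5, 3
p_data = normalize([rng.random() + 0.1 for _ in range(nx)])                     # P(x)
p_z = normalize([rng.random() + 0.1 for _ in range(nz)])                        # P(z)
enc = [normalize([rng.random() + 0.1 for _ in range(nz)]) for _ in range(nx)]   # q(z|x), enc[x][z]
dec = [normalize([rng.random() + 0.1 for _ in range(nx)]) for _ in range(nz)]   # p(x|z), dec[z][x]

# KL between M_E = P(x) q(z|x) and M_D = P(z) p(x|z), summed over all (x, z)
kl_joint = sum(p_data[x] * enc[x][z]
               * math.log((p_data[x] * enc[x][z]) / (p_z[z] * dec[z][x]))
               for x in range(nx) for z in range(nz))

def elbo(x):
    """Exact ELBO for one observation x (a finite sum over z)."""
    return sum(enc[x][z] * (math.log(dec[z][x]) + math.log(p_z[z]) - math.log(enc[x][z]))
               for z in range(nz))

neg_elbo = -sum(p_data[x] * elbo(x) for x in range(nx))       # E_{P(x)}[-ELBO]
data_entropy = -sum(p * math.log(p) for p in p_data)          # H(P(x)), independent of theta
print(neg_elbo, kl_joint + data_entropy)
```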
This loss could be symmetrized by using the symmetric KL divergence. However, the reverse term $\mathrm{KL}(\mathcal{M}_D\|\mathcal{M}_E)$ contains $\mathbb{E}_{\mathcal{M}_D}[\log P(x)]$, which cannot be evaluated, as the data distribution $P(x)$ does not typically have an analytical form. In Pu et al. (2017), the symmetric KL loss is instead minimized by finding a stationary point of an adversarial min-max objective.
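Here we propose to optimize a bound on a regularized version of the Jensen-Shannon divergence (JSD). Defining the mixture distribution $\mathcal{M}_S = \frac{1}{2}(\mathcal{M}_E + \mathcal{M}_D)$, the JSD is defined as

$$\mathrm{JSD}(\mathcal{M}_E, \mathcal{M}_D) = \tfrac{1}{2}\,\mathrm{KL}\big(\mathcal{M}_E\,\|\,\mathcal{M}_S\big) + \tfrac{1}{2}\,\mathrm{KL}\big(\mathcal{M}_D\,\|\,\mathcal{M}_S\big).$$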
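A second principle of our formulation involves encouraging a representation with high mutual information. Mutual information is related to joint entropy by the identity $I(x;z) = H(x) + H(z) - H(x,z)$. The joint entropy $H(x,z)$ is tractable up to a constant under both the encoding and decoding distributions. For example, $H_{\mathcal{M}_E}(x,z) = H(P(x)) + \mathbb{E}_{P(x)}\big[H(q_\theta(z|x))\big]$, whose first term does not depend on $\theta$. We therefore add the regularizer¹ $R_H(\theta) = \frac{1}{2}\big(H_{\mathcal{M}_E}(x,z) + H_{\mathcal{M}_D}(x,z)\big)$ to the JSD. The sum can be shown to equal $H_{\mathcal{M}_S}(x,z)$, the entropy of the mixture distribution $\mathcal{M}_S$:

$$\mathcal{L}_{\text{RJSD}}(\theta) = \mathrm{JSD}(\mathcal{M}_E, \mathcal{M}_D) + R_H(\theta) = H_{\mathcal{M}_S}(x,z). \tag{3}$$

¹We use $H_{\mathcal{M}}$ and $I_{\mathcal{M}}$ to denote entropy and mutual information under a distribution $\mathcal{M}$.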
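The identity $\mathrm{JSD}(\mathcal{M}_E,\mathcal{M}_D) + R_H = H_{\mathcal{M}_S}(x,z)$ follows from writing each KL term as a cross entropy minus an entropy. The sketch below checks it numerically on two random distributions over a small, hypothetical grid of $(x, z)$ cells. The factorization of $\mathcal{M}_E$ and $\mathcal{M}_D$ plays no role in this identity, so flat probability vectors are enough.

```python
import math
import random

def random_dist(n, rng):
    """A random, strictly positive probability vector of length n."""
    w = [rng.random() + 0.1 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p)

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

rng = random.Random(2)
m_enc = random_dist(12, rng)   # stands in for M_E on a hypothetical 3 x 4 grid of (x, z)
m_dec = random_dist(12, rng)   # stands in for M_D
m_s = [0.5 * (e + d) for e, d in zip(m_enc, m_dec)]   # mixture M_S

jsd = 0.5 * kl(m_enc, m_s) + 0.5 * kl(m_dec, m_s)
r_h = 0.5 * (entropy(m_enc) + entropy(m_dec))
print(jsd + r_h, entropy(m_s))   # the two numbers agree up to rounding
```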
This formulation allows us to derive a bound on the objective that is tractable, in the sense of being amenable to optimization via reparameterization. First we define the parameterized joint distributions $q_\theta(x,z) = q_\theta(z|x)\,q_\theta(x)$ and $p_\theta(x,z) = p_\theta(x|z)\,p_\theta(z)$. Here $p_\theta(z)$ can be parameterized for added flexibility, or set to $P(z)$. The factor $q_\theta(x)$ is an observation model, which could be, e.g., a normalizing flow Dinh et al. (2017); Kingma and Dhariwal (2018) or an auto-regressive model Van den Oord et al. (2016). For a fair comparison with VAEs, we build $q_\theta(x)$ and $p_\theta(z)$ from the encoder and decoder distributions $q_\theta(z|x)$ and $p_\theta(x|z)$ to avoid adding parameters. Details are given in the supplementary material.
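Defining $\mathcal{M}_\theta = \frac{1}{2}\big(q_\theta(x,z) + p_\theta(x,z)\big)$, the bound can be expressed as

$$\mathcal{L}_{\text{RJSD}}(\theta) = H_{\mathcal{M}_S}(x,z) \;\leq\; \mathrm{CE}\big(\mathcal{M}_S, \mathcal{M}_\theta\big) \;\leq\; \tfrac{1}{2}\Big(\mathrm{CE}\big(\mathcal{M}_S, q_\theta(x,z)\big) + \mathrm{CE}\big(\mathcal{M}_S, p_\theta(x,z)\big)\Big) \equiv \mathcal{L}_{\text{MIM}}(\theta), \tag{4}$$

where $\mathrm{CE}(\mathcal{M}_1, \mathcal{M}_2) = -\mathbb{E}_{\mathcal{M}_1}[\log \mathcal{M}_2]$ is the cross entropy between distributions $\mathcal{M}_1$ and $\mathcal{M}_2$. The first inequality holds because a cross entropy upper-bounds the entropy. Its right-hand side no longer requires evaluating $\log \mathcal{M}_S$ (and hence $\log P(x)$), and is therefore tractable. The second inequality follows from Jensen's inequality and can be shown to additionally encourage consistency in the model (i.e., $q_\theta(x,z) = p_\theta(x,z)$). Further, when $q_\theta(x) = P(x)$ and $p_\theta(z) = P(z)$, it can be shown that $\mathcal{L}_{\text{MIM}}$ is equal to the symmetric KL loss plus the regularizer $R_H$. See the supplementary material for more details.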
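The bound chain $H_{\mathcal{M}_S} \leq \mathrm{CE}(\mathcal{M}_S,\mathcal{M}_\theta) \leq \frac12\big(\mathrm{CE}(\mathcal{M}_S,q_\theta) + \mathrm{CE}(\mathcal{M}_S,p_\theta)\big)$ can be checked on small discrete distributions. The same check covers the special case in which the model joints equal $\mathcal{M}_E$ and $\mathcal{M}_D$. In the sketch below, random tables stand in for the model and anchored joints, and all names are hypothetical. The special case is represented by setting the model joints to $\mathcal{M}_E$ and $\mathcal{M}_D$. The factor $\frac14$ in front of the summed KL divergences is our own calculation under these definitions: in that case $\mathcal{L}_{\text{MIM}}$ equals $R_H$ plus a quarter of $\mathrm{KL}(\mathcal{M}_E\|\mathcal{M}_D) + \mathrm{KL}(\mathcal{M}_D\|\mathcal{M}_E)$.

```python
import math
import random

def random_dist(n, rng):
    """A random, strictly positive probability vector of length n."""
    w = [rng.random() + 0.1 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p)

def cross_entropy(p, q):
    """CE(p, q) = -E_p[log q]."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def kl(p, q):
    return cross_entropy(p, q) - entropy(p)

def mix(p, q):
    return [0.5 * (a + b) for a, b in zip(p, q)]

rng = random.Random(3)
n_cells = 12                                                       # hypothetical flattened (x, z) grid
m_enc, m_dec = random_dist(n_cells, rng), random_dist(n_cells, rng)      # M_E, M_D
q_model, p_model = random_dist(n_cells, rng), random_dist(n_cells, rng)  # q_theta(x,z), p_theta(x,z)
m_s = mix(m_enc, m_dec)                                            # M_S
m_theta = mix(q_model, p_model)                                    # M_theta

h_s = entropy(m_s)
l_ce = cross_entropy(m_s, m_theta)
l_mim = 0.5 * (cross_entropy(m_s, q_model) + cross_entropy(m_s, p_model))
print(h_s <= l_ce <= l_mim)  # True

# Special case: model joints equal to the anchored joints M_E and M_D
l_mim_anchor = 0.5 * (cross_entropy(m_s, m_enc) + cross_entropy(m_s, m_dec))
r_h = 0.5 * (entropy(m_enc) + entropy(m_dec))
skl_sum = kl(m_enc, m_dec) + kl(m_dec, m_enc)
print(abs(l_mim_anchor - (r_h + 0.25 * skl_sum)) < 1e-9)  # True
```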
We have therefore derived a loss function that encourages both symmetry between the encoding and decoding distributions and high mutual information in the learned representation. Further, this loss function can be minimized directly, without requiring an adversarial reformulation. In the experiments below, we present preliminary results exploring the properties of MIM.
3 Experiment: 2D Mixture Model Data
Figure 1: VAE (odd columns) and MIM (even columns) with 2D inputs, a 2D latent space, and 5, 20 and 500 hidden units. Top: Black contours are level sets of the data distribution $P(x)$. Red points are reconstructed samples drawn from the decoder $p_\theta(x|z)$, where $z \sim q_\theta(z|x)$ for a data point $x$ drawn from $P(x)$. Bottom: The dashed black circle depicts one standard deviation of $P(z)$. Each green curve depicts a one-standard-deviation ellipse of the encoder posterior $q_\theta(z|x)$. (a) For weak architectures, both MIM and VAE exhibit high posterior variance. (b,c) For more expressive architectures, the VAE predictive variance remains high, an indication of posterior collapse. MIM generally produces lower predictive variance and lower reconstruction errors, consistent with high mutual information.
We begin with a dataset of 2D observations $x \in \mathbb{R}^2$ drawn from a Gaussian mixture model, and a 2D latent space, $z \in \mathbb{R}^2$. In 2D we can easily visualize the model and measure quantitative properties of interest (e.g., mutual information). Complete experimental details are given in the supplementary material.
Figure 1 depicts results for the VAE (odd columns) and MIM (even columns) using single-layer encoder and decoder networks. The number of hidden units increases from left to right to control model expressiveness. The top row depicts the observation space for VAE and MIM. For each case we also report the mutual information and the root-mean-squared observation reconstruction error when sampling the predictive encoder/decoder distributions, with MIM showing superior performance. See additional results in the supplementary material (Fig. 3).
The bottom row of Fig. 1 depicts the latent space behavior. For the weakest architecture, with only 5 hidden units, both MIM and VAE posteriors have large variances. As the number of hidden units increases, however, the VAE posterior variance remains very large in one dimension, a common sign of posterior collapse. The MIM encoder, by contrast, produces much tighter posterior densities, which capture the global (i.e., aggregated) structure of the observations.
In addition, with a more expressive architecture (i.e., more hidden units), the MIM encoding variance is extremely small and the reconstruction error approaches 0. When the latent representation and the observations have the same dimensionality, the encoder and decoder thus learn an approximately invertible mapping with an unconstrained architecture (demonstrated here in 2D).
The VAE, by comparison, is prone to posterior collapse, reflected in relatively low mutual information. Several papers have described ways to mitigate posterior collapse in VAE learning. These include lower bounding or annealing the KL divergence term in the VAE objective (e.g., Alemi et al., 2017; Razavi et al., 2019) and limiting the expressiveness of the decoder (e.g., Chen et al., 2016). We posit that MIM avoids this problem because its objective is designed to encourage high mutual information between observations and the latent representation.
4 Experiment: MIM Representations with High Dimensional Image Data
Here we explore learning on higher dimensional image data, where we cannot accurately estimate mutual information Belghazi et al. (2018). Instead, following Hjelm et al. (2019), we use an auxiliary classification task as a proxy for the quality of the learned representation, and we also visualize that representation qualitatively. We experiment with MNIST LeCun et al. (1998) and Fashion MNIST Xiao et al. (2017). We also explore multiple VAE architectures from Tomczak and Welling (2017) and the corresponding MIM models (see Algorithm 1 in the supplementary material). As before, the VAE serves as the baseline. See the supplementary material for details.
For the auxiliary transfer-learning classification task we opted for K-NN classification. It is a non-parametric method that reflects the clustering in the latent representation without requiring any additional training. Table 1 shows quantitative results for K-NN classification. Fig. 2 presents the corresponding qualitative clustering results (i.e., projection to 2D using t-SNE van der Maaten and Hinton (2008)). These results show that, for an identical model parameterization, MIM learning tends to cluster classes in the latent representation better than the VAE.
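A minimal K-NN probe of the kind described above might look as follows. This is a sketch under our own assumptions: Euclidean distance, a majority vote with ties going to the label that appears first among the nearest neighbours, and latent codes given as tuples (e.g., encoder means). This section does not state the paper's exact protocol or its value of K.

```python
import math
from collections import Counter

def knn_predict(train_codes, train_labels, query, k=5):
    """Majority label among the k training codes nearest (Euclidean) to `query`."""
    order = sorted(range(len(train_codes)), key=lambda i: math.dist(train_codes[i], query))
    votes = Counter(train_labels[i] for i in order[:k])
    # most_common orders equal counts by first occurrence, i.e. nearest first
    return votes.most_common(1)[0][0]

def knn_accuracy(train_codes, train_labels, test_codes, test_labels, k=5):
    """Fraction of test codes whose k-NN prediction matches their label."""
    hits = sum(knn_predict(train_codes, train_labels, c, k) == y
               for c, y in zip(test_codes, test_labels))
    return hits / len(test_codes)

# Toy latent codes: two well-separated clusters (hypothetical data)
train = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
labels = [0, 0, 0, 1, 1, 1]
print(knn_accuracy(train, labels, [(0.5, 0.5), (10.5, 10.5)], [0, 1], k=3))  # 1.0
```

Because the probe has no trainable parameters, differences in its accuracy reflect how well classes are separated in the latent codes themselves.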
- Alemi et al. (2017). An information-theoretic analysis of deep latent-variable models. CoRR abs/1711.00464.
- Belghazi et al. (2018). MINE: Mutual information neural estimation. In ICML.
- Bornschein et al. (2015). Bidirectional Helmholtz machines. CoRR abs/1506.03877.
- Bowman et al. (2015). Generating sentences from a continuous space. CoRR abs/1511.06349.
- Chen et al. (2016). Variational lossy autoencoder. CoRR abs/1611.02731.
- Dinh et al. (2017). Density estimation using Real NVP. International Conference on Learning Representations.
- Gao et al. (2016). Demystifying fixed k-nearest neighbor information estimators. CoRR abs/1604.03006.
- Germain et al. (2015). MADE: Masked autoencoder for distribution estimation. In ICML.
- Hjelm et al. (2019). Learning deep representations by mutual information estimation and maximization. In International Conference on Learning Representations.
- Kingma and Welling (2013). Auto-Encoding Variational Bayes. In ICLR.
- Kingma and Ba (2014). Adam: A method for stochastic optimization. arXiv:1412.6980.
- Kingma et al. (2016). Improving variational inference with inverse autoregressive flow. In NIPS.
- Kingma and Dhariwal (2018). Glow: Generative flow with invertible 1x1 convolutions. In Advances in Neural Information Processing Systems, pp. 10215–10224.
- Klys et al. (2018). Learning latent subspaces in variational autoencoders. In NIPS.
- Kraskov et al. (2004). Estimating mutual information. Phys. Rev. E 69, p. 066138.
- LeCun et al. (1998). Gradient-based learning applied to document recognition. Proc. IEEE 86 (11), pp. 2278–2324.
- Oord et al. (2017). Parallel WaveNet: Fast high-fidelity speech synthesis. arXiv:1711.10433.
- Pu et al. (2017). Adversarial symmetric variational autoencoder. In Advances in Neural Information Processing Systems, pp. 4330–4339.
- Razavi et al. (2019). Preventing posterior collapse with δ-VAEs. CoRR abs/1901.03416.
- Rezende et al. (2014). Stochastic backpropagation and approximate inference in deep generative models. In ICML.
- Tomczak and Welling (2017). VAE with a VampPrior. CoRR abs/1705.07120.
- van den Berg et al. (2018). Sylvester normalizing flows for variational inference. arXiv:1803.05649.
- van den Oord et al. (2016). Conditional image generation with PixelCNN decoders. In Advances in Neural Information Processing Systems, pp. 4790–4798.
- van den Oord et al. (2016). Pixel recurrent neural networks. CoRR abs/1601.06759.
- van den Oord et al. (2017). Neural discrete representation learning. CoRR abs/1711.00937.
- van der Maaten and Hinton (2008). Visualizing high-dimensional data using t-SNE. Journal of Machine Learning Research 9, pp. 2579–2605.
- Xiao et al. (2017). Fashion-MNIST: A novel image dataset for benchmarking machine learning algorithms. CoRR abs/1708.07747.
Appendix A Detailed Derivation of MIM Learning
Here we provide a detailed derivation of the MIM loss defined in (4). We would like a loss function that includes the regularized JSD (3), which reflects our desire for model symmetry and high mutual information. This objective is difficult to optimize directly, since we do not know how to evaluate $\log \mathcal{M}_S(x,z)$ in the general case (i.e., we do not have an exact closed-form expression for $P(x)$). As a consequence, we introduce parameterized approximate priors, $q_\theta(x)$ and $p_\theta(z)$, to derive tractable bounds on the penalized Jensen-Shannon divergence. This is similar in spirit to VAEs, which introduce a parameterized approximate posterior. These parameterized priors, together with the conditional encoder and decoder, $q_\theta(z|x)$ and $p_\theta(x|z)$, comprise a new pair of joint distributions,

$$q_\theta(x,z) = q_\theta(z|x)\,q_\theta(x), \qquad p_\theta(x,z) = p_\theta(x|z)\,p_\theta(z). \tag{5}$$
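These new joint distributions allow us to formulate a new, tractable loss that bounds $\mathcal{L}_{\text{RJSD}}$:

$$\mathcal{L}_{\text{RJSD}}(\theta) = H_{\mathcal{M}_S}(x,z) \;\leq\; \mathrm{CE}\big(\mathcal{M}_S, \mathcal{M}_\theta\big) \equiv \mathcal{L}_{\text{CE}}(\theta), \tag{6}$$

where $\mathrm{CE}(\mathcal{M}_S, \mathcal{M}_\theta)$ denotes the cross-entropy between $\mathcal{M}_S$ and $\mathcal{M}_\theta$, and

$$\mathcal{M}_\theta(x,z) = \tfrac{1}{2}\big(p_\theta(x,z) + q_\theta(x,z)\big).$$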
In what follows we refer to $\mathcal{L}_{\text{CE}}$ as the cross-entropy loss. It aims to match the model prior distributions to the anchors, while also minimizing $H_{\mathcal{M}_S}(x,z)$. The main advantage of this formulation is that the cross-entropy loss can be trained by Monte Carlo sampling from the anchor distributions with the reparameterization trick [Kingma and Welling, 2013, Rezende et al., 2014].
At this stage it might seem odd to introduce a parametric prior $p_\theta(z)$ for $z$. Indeed, setting $p_\theta(z) = P(z)$ directly is certainly an option. Nevertheless, to achieve consistency between $q_\theta(x,z)$ and $p_\theta(x,z)$, it can be advantageous to allow $p_\theta(z)$ to vary. Essentially, we trade off latent prior fidelity for increased model consistency.
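One issue with $\mathcal{L}_{\text{CE}}$ is that it does not directly try to achieve model consistency, $q_\theta(x,z) = p_\theta(x,z)$. It will try to enforce consistency between the model and the anchored distributions (i.e., $q_\theta(x) \approx P(x)$ and $p_\theta(z) \approx P(z)$), but not between the two model joints. To remedy this, we bound $\mathcal{L}_{\text{CE}}$ using Jensen's inequality, i.e.,

$$\mathcal{L}_{\text{CE}}(\theta) = \mathrm{CE}\big(\mathcal{M}_S, \mathcal{M}_\theta\big) \;\leq\; \tfrac{1}{2}\Big(\mathrm{CE}\big(\mathcal{M}_S, q_\theta(x,z)\big) + \mathrm{CE}\big(\mathcal{M}_S, p_\theta(x,z)\big)\Big) \equiv \mathcal{L}_{\text{MIM}}(\theta). \tag{7}$$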
Equation (7) gives us the loss function for the Mutual Information Machine (MIM). It is an average of two cross-entropy terms between the mixture distribution $\mathcal{M}_S$ and the model encoding and decoding distributions, respectively. To see that this encourages model consistency, it can be shown that $\mathcal{L}_{\text{MIM}}$ is equivalent to $\mathcal{L}_{\text{CE}}$ plus a non-negative model consistency regularizer; i.e.,

$$\mathcal{L}_{\text{MIM}}(\theta) = \mathcal{L}_{\text{CE}}(\theta) + R_{\text{MIM}}(\theta). \tag{8}$$
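The non-negativity of $R_{\text{MIM}}$ is a simple consequence of the inequality $\mathcal{L}_{\text{CE}} \leq \mathcal{L}_{\text{MIM}}$ in (7), combined with (8).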
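In what follows we derive the form of the MIM consistency regularizer $R_{\text{MIM}}$ implicit in Eq. (4). Recall that we define $\mathcal{M}_\theta = \frac{1}{2}\big(p_\theta(x,z) + q_\theta(x,z)\big)$. We can show that $\mathcal{L}_{\text{MIM}}$ is equivalent to $\mathcal{L}_{\text{CE}}$ plus a regularizer by taking their difference:

$$R_{\text{MIM}}(\theta) \equiv \mathcal{L}_{\text{MIM}}(\theta) - \mathcal{L}_{\text{CE}}(\theta) \tag{9}$$

$$= \tfrac{1}{2}\,\mathrm{CE}\big(\mathcal{M}_S, q_\theta(x,z)\big) + \tfrac{1}{2}\,\mathrm{CE}\big(\mathcal{M}_S, p_\theta(x,z)\big) - \mathrm{CE}\big(\mathcal{M}_S, \mathcal{M}_\theta\big), \tag{10}$$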
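where $R_{\text{MIM}}$ is non-negative. It is zero only when the two joint model distributions, $q_\theta(x,z)$ and $p_\theta(x,z)$, are identical under fair samples from the joint sample distribution $\mathcal{M}_S(x,z)$. To prove that $R_{\text{MIM}} \geq 0$, we now write Equation (10) as an expectation over a joint distribution, which yields

$$R_{\text{MIM}}(\theta) = \mathbb{E}_{\mathcal{M}_S}\Big[\log \mathcal{M}_\theta(x,z) - \tfrac{1}{2}\big(\log q_\theta(x,z) + \log p_\theta(x,z)\big)\Big] \;\geq\; 0,$$

where the inequality follows from Jensen's inequality (the concavity of the logarithm, applied pointwise). Equality holds only when $q_\theta(x,z) = p_\theta(x,z)$ (i.e., the encoding and decoding distributions are consistent). In practice we find that encouraging model consistency also helps stabilize learning.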
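The closed form of the consistency regularizer, $\mathcal{L}_{\text{MIM}} - \mathcal{L}_{\text{CE}} = \mathbb{E}_{\mathcal{M}_S}[\log \mathcal{M}_\theta - \frac12(\log q_\theta + \log p_\theta)]$, can be checked numerically. The sketch below compares the cross-entropy difference with the pointwise expectation and confirms that it vanishes when the two model joints coincide. The identity holds for any sampling distribution, so a random vector stands in for $\mathcal{M}_S$. Tables and names are hypothetical.

```python
import math
import random

def random_dist(n, rng):
    """A random, strictly positive probability vector of length n."""
    w = [rng.random() + 0.1 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]

def cross_entropy(p, q):
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def r_mim_from_losses(m_s, q, p):
    """R_MIM as L_MIM - L_CE, both written as cross entropies against M_S."""
    m_theta = [0.5 * (a + b) for a, b in zip(q, p)]
    l_mim = 0.5 * (cross_entropy(m_s, q) + cross_entropy(m_s, p))
    return l_mim - cross_entropy(m_s, m_theta)

def r_mim_pointwise(m_s, q, p):
    """R_MIM as E_{M_S}[log M_theta - (log q + log p) / 2]; every summand is >= 0."""
    return sum(w * (math.log(0.5 * (a + b)) - 0.5 * (math.log(a) + math.log(b)))
               for w, a, b in zip(m_s, q, p))

rng = random.Random(5)
m_s = random_dist(12, rng)                                       # stands in for M_S
q_model, p_model = random_dist(12, rng), random_dist(12, rng)    # q_theta(x,z), p_theta(x,z)
print(r_mim_from_losses(m_s, q_model, p_model), r_mim_pointwise(m_s, q_model, p_model))
print(r_mim_pointwise(m_s, q_model, q_model))                    # 0.0: consistent model
```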
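To understand the MIM objective in greater depth, we find it helpful to express $\mathcal{L}_{\text{MIM}}$ as a sum of fundamental terms that provide some intuition for its expected behavior. In particular, as derived in the supplementary material:

$$\mathcal{L}_{\text{MIM}}(\theta) = \tfrac{1}{2}\big(H_{\mathcal{M}_E}(x,z) + H_{\mathcal{M}_D}(x,z)\big) + \tfrac{1}{4}\Big(\mathrm{KL}\big(P(x)\,\|\,q_\theta(x)\big) + \mathrm{KL}\big(P(z)\,\|\,p_\theta(z)\big)\Big) + \tfrac{1}{4}\Big(\mathrm{KL}\big(\mathcal{M}_E\,\|\,p_\theta(x,z)\big) + \mathrm{KL}\big(\mathcal{M}_D\,\|\,q_\theta(x,z)\big)\Big). \tag{11}$$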
The first term in (11), as discussed above, encourages high mutual information between observations and latent states. The second term shows that MIM directly encourages the model prior distributions to match the anchor distributions. Indeed, the KL term between the data anchor and the model prior is the maximum likelihood objective. The third term encourages consistency between the model distributions and the anchored distributions, in effect fitting the model decoder to samples drawn from the anchored encoder (cf. VAE), and, via symmetry, fitting the model encoder to samples drawn from the anchored decoder (both with reparameterization). In this view, MIM can be seen as simultaneously training and distilling a model distribution over the data into a latent variable model. The idea of distilling density models has been used in other domains, e.g., for parallelizing auto-regressive models [Oord et al., 2017].
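One way to write the decomposition of $\mathcal{L}_{\text{MIM}}$ into entropy, prior-matching, and consistency terms is

$$\mathcal{L}_{\text{MIM}} = \tfrac12\big(H_{\mathcal{M}_E} + H_{\mathcal{M}_D}\big) + \tfrac14\big(\mathrm{KL}(P(x)\|q_\theta(x)) + \mathrm{KL}(P(z)\|p_\theta(z))\big) + \tfrac14\big(\mathrm{KL}(\mathcal{M}_E\|p_\theta(x,z)) + \mathrm{KL}(\mathcal{M}_D\|q_\theta(x,z))\big).$$

The sketch below verifies this on a toy discrete model. The encoder table is shared between $\mathcal{M}_E$ and $q_\theta(x,z)$, and the decoder table between $\mathcal{M}_D$ and $p_\theta(x,z)$. This sharing is what makes $\mathrm{KL}(\mathcal{M}_E\|q_\theta(x,z))$ collapse to the prior-matching term $\mathrm{KL}(P(x)\|q_\theta(x))$, and likewise for $z$. Grid sizes, random tables, and names are illustrative assumptions.

```python
import math
import random

def random_dist(n, rng):
    """A random, strictly positive probability vector of length n."""
    w = [rng.random() + 0.1 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p)

def cross_entropy(p, q):
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def kl(p, q):
    return cross_entropy(p, q) - entropy(p)

rng = random.Random(6)
nx, nz = 4, 3
p_data, p_latent = random_dist(nx, rng), random_dist(nz, rng)   # anchors P(x), P(z)
q_x, p_z = random_dist(nx, rng), random_dist(nz, rng)           # model priors q_theta(x), p_theta(z)
enc = [random_dist(nz, rng) for _ in range(nx)]                 # q_theta(z|x), indexed enc[x][z]
dec = [random_dist(nx, rng) for _ in range(nz)]                 # p_theta(x|z), indexed dec[z][x]

cells = [(x, z) for x in range(nx) for z in range(nz)]
m_enc = [p_data[x] * enc[x][z] for x, z in cells]    # anchored encoding M_E
m_dec = [p_latent[z] * dec[z][x] for x, z in cells]  # anchored decoding M_D
q_joint = [q_x[x] * enc[x][z] for x, z in cells]     # q_theta(x, z)
p_joint = [p_z[z] * dec[z][x] for x, z in cells]     # p_theta(x, z)
m_s = [0.5 * (a + b) for a, b in zip(m_enc, m_dec)]  # M_S

l_mim = 0.5 * (cross_entropy(m_s, q_joint) + cross_entropy(m_s, p_joint))
entropy_term = 0.5 * (entropy(m_enc) + entropy(m_dec))               # R_H
prior_term = 0.25 * (kl(p_data, q_x) + kl(p_latent, p_z))            # prior matching
consistency_term = 0.25 * (kl(m_enc, p_joint) + kl(m_dec, q_joint))  # cross-fitting
print(l_mim, entropy_term + prior_term + consistency_term)
```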
In summary, the MIM loss can be viewed as an upper bound on the entropy of a particular mixture distribution $\mathcal{M}_S$:

$$\mathcal{L}_{\text{MIM}}(\theta) \;\geq\; \mathcal{L}_{\text{CE}}(\theta) \;\geq\; H_{\mathcal{M}_S}(x,z) = H_{\mathcal{M}_S}(x) + H_{\mathcal{M}_S}(z) - I_{\mathcal{M}_S}(x;z).$$

Through the MIM loss and the introduction of the parameterized model distribution $\mathcal{M}_\theta$, we push down on the entropy of the anchored mixture distribution $\mathcal{M}_S$. That entropy is the sum of the marginal entropies minus the mutual information. Minimizing the MIM bound yields consistency of the model encoder and decoder, and high mutual information between observations and latent states.
Appendix B MIM in terms of Symmetric KL Divergence
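As given in Equation (2), the VAE objective can be expressed as minimizing the KL divergence between the joint anchored encoding and anchored decoding distributions, $\mathcal{M}_E$ and $\mathcal{M}_D$, which jointly define the sample distribution $\mathcal{M}_S$. Here we refer to $P(x)$ and $P(z)$ as anchors, which are given externally and are not learned. Below we consider a model formulation using the symmetric KL divergence (SKL),

$$\mathcal{L}_{\text{SKL}}(\theta) = \tfrac{1}{2}\Big(\mathrm{KL}\big(\mathcal{M}_D\,\|\,\mathcal{M}_E\big) + \mathrm{KL}\big(\mathcal{M}_E\,\|\,\mathcal{M}_D\big)\Big),$$

the second term of which is the VAE objective.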
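In what follows we explore the relation between SKL, JSD, and MIM. Recall that the JSD is written as

$$\mathrm{JSD}(\mathcal{M}_E, \mathcal{M}_D) = \tfrac{1}{2}\,\mathrm{KL}\big(\mathcal{M}_E\,\|\,\mathcal{M}_S\big) + \tfrac{1}{2}\,\mathrm{KL}\big(\mathcal{M}_D\,\|\,\mathcal{M}_S\big).$$

Using the identity $\mathrm{KL}(\mathcal{M}_1\|\mathcal{M}_2) = \mathrm{CE}(\mathcal{M}_1,\mathcal{M}_2) - H_{\mathcal{M}_1}$, we can express the JSD in terms of entropy and cross entropy:

$$\mathrm{JSD}(\mathcal{M}_E, \mathcal{M}_D) = \tfrac{1}{2}\Big(\mathrm{CE}\big(\mathcal{M}_E,\mathcal{M}_S\big) + \mathrm{CE}\big(\mathcal{M}_D,\mathcal{M}_S\big)\Big) - \tfrac{1}{2}\Big(H_{\mathcal{M}_E}(x,z) + H_{\mathcal{M}_D}(x,z)\Big).$$

Using Jensen's inequality (the concavity of the logarithm), we can bound both cross-entropy terms against $\mathcal{M}_S$ from above:

$$\mathrm{CE}\big(\mathcal{M}_E,\mathcal{M}_S\big) \leq \tfrac{1}{2}\Big(H_{\mathcal{M}_E}(x,z) + \mathrm{CE}\big(\mathcal{M}_E,\mathcal{M}_D\big)\Big), \qquad \mathrm{CE}\big(\mathcal{M}_D,\mathcal{M}_S\big) \leq \tfrac{1}{2}\Big(H_{\mathcal{M}_D}(x,z) + \mathrm{CE}\big(\mathcal{M}_D,\mathcal{M}_E\big)\Big),$$

which gives $\mathrm{JSD} \leq \frac{1}{2}\mathcal{L}_{\text{SKL}}$. If we add the regularizer $R_H$ and combine terms, we get

$$\mathcal{L}_{\text{RJSD}}(\theta) = \mathrm{JSD}(\mathcal{M}_E,\mathcal{M}_D) + R_H(\theta) \;\leq\; \tfrac{1}{2}\mathcal{L}_{\text{SKL}}(\theta) + R_H(\theta) = \tfrac{1}{2}\Big(\mathrm{CE}\big(\mathcal{M}_S,\mathcal{M}_E\big) + \mathrm{CE}\big(\mathcal{M}_S,\mathcal{M}_D\big)\Big).$$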
When the model priors $p_\theta(z)$ and $q_\theta(x)$ are equal to the fixed priors $P(z)$ and $P(x)$, the regularized SKL loss and the MIM loss are equivalent. In general, however, the MIM loss is not a bound on the regularized SKL loss.
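In what follows, we derive the exact relationship between JSD and SKL. Subtracting the two definitions term by term gives

$$\tfrac{1}{2}\mathcal{L}_{\text{SKL}} - \mathrm{JSD} = \tfrac{1}{2}\,\mathbb{E}_{\mathcal{M}_E}\Big[\log\frac{\mathcal{M}_S}{\sqrt{\mathcal{M}_E\,\mathcal{M}_D}}\Big] + \tfrac{1}{2}\,\mathbb{E}_{\mathcal{M}_D}\Big[\log\frac{\mathcal{M}_S}{\sqrt{\mathcal{M}_E\,\mathcal{M}_D}}\Big] = \mathbb{E}_{\mathcal{M}_S}\Big[\log\frac{\mathcal{M}_S(x,z)}{\sqrt{\mathcal{M}_E(x,z)\,\mathcal{M}_D(x,z)}}\Big],$$

so that

$$\mathrm{JSD}(\mathcal{M}_E, \mathcal{M}_D) = \tfrac{1}{2}\mathcal{L}_{\text{SKL}}(\theta) - \mathbb{E}_{\mathcal{M}_S}\Big[\log\frac{\mathcal{M}_S(x,z)}{\sqrt{\mathcal{M}_E(x,z)\,\mathcal{M}_D(x,z)}}\Big],$$

which gives the exact relation between JSD and SKL. The expectation is non-negative because the arithmetic mean $\mathcal{M}_S$ is pointwise at least the geometric mean $\sqrt{\mathcal{M}_E\,\mathcal{M}_D}$, so this also recovers the bound $\mathrm{JSD} \leq \frac{1}{2}\mathcal{L}_{\text{SKL}}$.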
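The exact relation $\mathrm{JSD} = \frac12\mathcal{L}_{\text{SKL}} - \mathbb{E}_{\mathcal{M}_S}[\log(\mathcal{M}_S/\sqrt{\mathcal{M}_E\mathcal{M}_D})]$ can also be confirmed numerically. Here $\mathcal{L}_{\text{SKL}}$ is the averaged symmetric KL, $\frac12(\mathrm{KL}(\mathcal{M}_D\|\mathcal{M}_E) + \mathrm{KL}(\mathcal{M}_E\|\mathcal{M}_D))$. The sketch below uses two random distributions over a hypothetical grid of $(x,z)$ cells.

```python
import math
import random

def random_dist(n, rng):
    """A random, strictly positive probability vector of length n."""
    w = [rng.random() + 0.1 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

rng = random.Random(7)
m_enc, m_dec = random_dist(12, rng), random_dist(12, rng)   # stand-ins for M_E, M_D
m_s = [0.5 * (e + d) for e, d in zip(m_enc, m_dec)]         # mixture M_S

jsd = 0.5 * kl(m_enc, m_s) + 0.5 * kl(m_dec, m_s)
l_skl = 0.5 * (kl(m_dec, m_enc) + kl(m_enc, m_dec))
# E_{M_S}[log(arithmetic mean / geometric mean)], pointwise non-negative (AM-GM)
gap = sum(s * math.log(s / math.sqrt(e * d)) for s, e, d in zip(m_s, m_enc, m_dec))
print(jsd, 0.5 * l_skl - gap)
```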
Appendix C Parameterizing $q_\theta(x)$ and $p_\theta(z)$ for fair comparison with VAEs
In the MIM framework, there is flexibility in the choice of $q_\theta(x)$ and $p_\theta(z)$. To facilitate a direct comparison with VAEs, we must be careful to keep the architectures consistent and not introduce additional model parameters. For simplicity and a fair comparison, we set $p_\theta(z) = P(z)$ and leave other choices for future work. For $q_\theta(x)$, we consider two different approaches that leverage the decoder distribution $p_\theta(x|z)$. The first is to consider the marginal decoding distribution,

$$q_\theta(x) = \mathbb{E}_{P(z)}\big[p_\theta(x|z)\big], \tag{14}$$

which we approximate by drawing a single sample from $P(z)$ whenever we need to evaluate $q_\theta(x)$. This can suffer from high variance if the prior is far from the true posterior. The other is to consider an importance sampling estimate, for which we use the encoder distribution $q_\theta(z|x)$,

$$q_\theta(x) = \mathbb{E}_{q_\theta(z|x)}\Big[p_\theta(x|z)\,\frac{P(z)}{q_\theta(z|x)}\Big]. \tag{15}$$

Once again, we approximate the expectation using a single sample, this time from the encoder distribution, and multiply by the importance weight $P(z)/q_\theta(z|x)$. Samples of $z$ are drawn using the reparameterization trick [Kingma and Welling, 2013, Rezende et al., 2014] to allow for gradient-based training. We use (14) when sampling from the decoding distribution during the training of a MIM model, and (15) when sampling from the encoding distribution.
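The two single-sample estimators (14) and (15) can be illustrated on a one-dimensional linear-Gaussian toy model, where the exact marginal is known in closed form. The model parameters, the deliberately mismatched encoder, and the function names are our assumptions. Each single-sample estimate is unbiased for $q_\theta(x)$, so averaging many of them should approach the exact value. In actual training one sample is used per data point and the estimate enters a log-likelihood; this sketch does not attempt to reproduce that.

```python
import math
import random

A, S2 = 1.0, 0.25   # hypothetical decoder p(x|z) = N(A z, S2); prior P(z) = N(0, 1)
C, V = 0.7, 0.3     # hypothetical encoder q(z|x) = N(C x, V), deliberately not the exact posterior

def normal_pdf(v, mean, var):
    """Density of N(mean, var) at v."""
    return math.exp(-0.5 * (v - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def q_x_marginal(x, rng):
    """Eq. (14), one sample: z ~ P(z), return p(x|z)."""
    z = rng.gauss(0.0, 1.0)
    return normal_pdf(x, A * z, S2)

def q_x_importance(x, rng):
    """Eq. (15), one sample: reparameterized z ~ q(z|x), weighted by P(z) / q(z|x)."""
    z = C * x + math.sqrt(V) * rng.gauss(0.0, 1.0)
    weight = normal_pdf(z, 0.0, 1.0) / normal_pdf(z, C * x, V)
    return normal_pdf(x, A * z, S2) * weight

rng = random.Random(0)
x = 0.3
exact = normal_pdf(x, 0.0, A * A + S2)   # closed-form marginal N(0, A^2 + S2)
n = 20000
est_marginal = sum(q_x_marginal(x, rng) for _ in range(n)) / n
est_importance = sum(q_x_importance(x, rng) for _ in range(n)) / n
print(exact, est_marginal, est_importance)
```

Here the encoder is wider than the exact posterior, which keeps the importance weights bounded. An encoder narrower than the posterior would make (15) much noisier.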
Appendix D Experimentation Details
Following Belghazi et al. [2018], we estimate mutual information using the KSG mutual information estimator [Kraskov et al., 2004, Gao et al., 2016], which is based on K-NN neighborhoods. We measure the quality of the representation with an auxiliary classification task.
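For reference, the sketch below is a brute-force implementation of the KSG estimator, using the first of the two estimators in Kraskov et al. (2004). This section does not state which KSG variant or which K the experiments used, so both are assumptions. The estimate is in nats and costs $O(N^2)$ time, which is fine for a few hundred 2D points but not for large datasets. For integer arguments the digamma function satisfies $\psi(n) = -\gamma + \sum_{j=1}^{n-1} 1/j$, which avoids any special-function dependency.

```python
import math
import random

EULER_GAMMA = 0.5772156649015329

def digamma_table(n_max):
    """psi(n) for integers 1..n_max, using psi(n) = -gamma + sum_{j<n} 1/j."""
    table = [0.0] * (n_max + 1)
    table[1] = -EULER_GAMMA
    for n in range(2, n_max + 1):
        table[n] = table[n - 1] + 1.0 / (n - 1)
    return table

def max_norm(a, b):
    """Chebyshev (max-norm) distance between two equal-length tuples."""
    return max(abs(u - v) for u, v in zip(a, b))

def ksg_mutual_information(xs, zs, k=3):
    """KSG estimate (Kraskov et al., 2004, estimator 1) of I(x; z) in nats.

    xs, zs: equal-length lists of tuples (one observation / latent code each).
    """
    n = len(xs)
    psi = digamma_table(n)
    dx = [[max_norm(xs[i], xs[j]) for j in range(n)] for i in range(n)]
    dz = [[max_norm(zs[i], zs[j]) for j in range(n)] for i in range(n)]
    acc = 0.0
    for i in range(n):
        # distance to the k-th nearest neighbour in the joint (x, z) space
        joint = sorted(max(dx[i][j], dz[i][j]) for j in range(n) if j != i)
        eps = joint[k - 1]
        # neighbours strictly within eps in each marginal space
        n_x = sum(1 for j in range(n) if j != i and dx[i][j] < eps)
        n_z = sum(1 for j in range(n) if j != i and dz[i][j] < eps)
        acc += psi[n_x + 1] + psi[n_z + 1]
    return psi[k] + psi[n] - acc / n

# Usage: correlated Gaussians, where the true value is -0.5 * log(1 - rho^2)
rng = random.Random(0)
rho = 0.8
xs, zs = [], []
for _ in range(300):
    a, b = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    xs.append((a,))
    zs.append((rho * a + math.sqrt(1 - rho ** 2) * b,))
estimate = ksg_mutual_information(xs, zs, k=3)
exact = -0.5 * math.log(1 - rho ** 2)   # about 0.51 nats
print(estimate, exact)
```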
The learning algorithm is described in Algorithm 1. In what follows we describe in detail the experimental setup, architecture, and training procedure for the experiments presented in the paper.
D.1 2D Mixture Model Data
In all experiments we use the Adam optimizer [Kingma and Ba, 2014] and mini-batches of size 128. We stopped training in all experiments when the validation loss had not improved for 10 epochs.
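The stopping rule above (stop once the validation loss has not improved for 10 epochs) can be written as a small helper. The function name is hypothetical, and so is the convention that "improved" means strictly below the best value so far.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Index of the epoch at which training stops, or None if it never stops.

    Training stops at the first epoch that completes `patience` consecutive
    epochs without a strict improvement over the best validation loss so far.
    """
    best = float("inf")
    epochs_since_best = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                return epoch
    return None

history = [5.0, 4.0, 3.0] + [3.5] * 10
print(early_stopping_epoch(history))  # 12
```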
Data are drawn from a Gaussian mixture model with five isotropic components with standard deviation 0.25, and the latent anchor, $P(z)$, is an isotropic standard normal distribution. The encoder and decoder are conditional Gaussian distributions whose means and variances are regressed from the input using two fully connected layers and a tanh activation function. Following Bornschein et al. [2015], the parameterized data prior, $q_\theta(x)$, is defined to be the marginal of the decoding distribution, or explicitly $q_\theta(x) = \mathbb{E}_{P(z)}\big[p_\theta(x|z)\big]$. The only model parameters are therefore those of the encoder and decoder, and the encoding distribution is defined to be consistent with the decoding distribution. As such, we can learn models with the MIM and VAE objectives that share the same architectures and parameterizations.
D.2 Representation Learning with MIM in High Dimensional Image Data
We experiment with the convHVAE (L = 2) model from Tomczak and Welling [2017], with two priors. The Standard (S) prior is the usual Normal distribution. The VampPrior (VP) prior defines the prior as a mixture model of the encoder conditioned on learnable pseudo-inputs $u_k$, or explicitly $p_\theta(z) = \frac{1}{K}\sum_{k=1}^{K} q_\theta(z|u_k)$. In all the experiments we used the same setup and the same latent dimensionality as Tomczak and Welling [2017]. By doing so we aim to highlight that MIM learning is architecture-independent, and to provide examples of training existing VAE architectures with MIM learning.
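The VampPrior density $p_\theta(z) = \frac1K\sum_k q_\theta(z|u_k)$ is usually evaluated in log space with a log-sum-exp for numerical stability. The sketch below does this for a hypothetical one-dimensional Gaussian encoder. Its weight and variance are made up for illustration; convHVAE uses a convolutional encoder over images.

```python
import math

def log_normal(z, mean, var):
    """Log density of N(mean, var) at z."""
    return -0.5 * (math.log(2 * math.pi * var) + (z - mean) ** 2 / var)

def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def encoder_log_density(z, u, weight=2.0, var=0.5):
    """Hypothetical 1D Gaussian encoder: log q(z|u) with q(z|u) = N(weight * u, var)."""
    return log_normal(z, weight * u, var)

def vamp_log_prior(z, pseudo_inputs):
    """log p(z) for p(z) = (1/K) * sum_k q(z|u_k), computed in log space."""
    logs = [encoder_log_density(z, u) for u in pseudo_inputs]
    return logsumexp(logs) - math.log(len(pseudo_inputs))

pseudo_inputs = [-1.0, 0.0, 1.5]   # hypothetical learnable pseudo-inputs u_k
print(vamp_log_prior(0.3, pseudo_inputs))
```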
Appendix E Additional Results
E.1 2D Mixture Model Data
Figure 3: Test performance for MIM (blue) and VAE (red) for the 2D GMM experiment, all as functions of the number of hidden units (x-axis), based on 10 learned models in each case. Panels, from left to right: (a) MI, mutual information; (b) NLL, log marginal probability of test points; (c) Recon. Error, reconstruction error; (d) Classif. (5-NN), k-NN classification performance.
Here we quantify the complete experimental results that were presented in Fig. 1. We plot the mutual information, the average log marginal of test points under the model $q_\theta(x)$, the reconstruction error of test points, and 5-NN classification accuracy (predicting which of the five GMM components the test points were drawn from).
E.2 Representation Learning with MIM in High Dimensional Image Data
Figure 4 (Fashion-MNIST): (a) VAE (VP), (b) A-MIM (VP).
Figure 5 (MNIST): (a) VAE (VP), (b) A-MIM (VP).
Training times of MIM models are comparable to those of VAEs with comparable architectures. The principal difference is the time required for sampling from the decoder during training. For certain models, such as auto-regressive decoders [Kingma et al., 2016], this can be significant. In such cases (here, PixelHVAE), we find that we can also learn the model by changing the sampling distribution to include only samples from the encoding distribution. With this asymmetric sampling, where we sample from the encoding distribution only (similar to the VAE), training time is comparable to that of the VAE. We name this model A-MIM.
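The difference between MIM and A-MIM lies only in where the Monte Carlo samples come from. The sketch below estimates the loss $-\frac12\mathbb{E}[\log q_\theta(x,z) + \log p_\theta(x,z)]$ for a one-dimensional linear-Gaussian toy model. MIM draws samples alternately from the anchored encoding and decoding distributions, while A-MIM draws from the encoding distribution only, so it never samples the decoder. The toy model, the data anchor $P(x) = \mathcal{N}(0, A^2 + S^2)$, and the use of the exact marginal for $q_\theta(x)$ in place of the single-sample estimators are all simplifying assumptions. When the encoder equals the exact posterior, the two model joints coincide. Both estimates should then approach the entropy of that joint Gaussian.

```python
import math
import random

A, S2 = 1.0, 0.25             # hypothetical decoder p(x|z) = N(A z, S2); latent anchor P(z) = N(0, 1)
MARGINAL_VAR = A * A + S2     # variance of the exact decoder marginal (also the toy data anchor)

def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)

def log_q_joint(x, z, enc):
    """log q_theta(x, z) = log q_theta(x) + log q_theta(z|x), with q(z|x) = N(c x, v)."""
    c, v = enc
    return log_normal(x, 0.0, MARGINAL_VAR) + log_normal(z, c * x, v)

def log_p_joint(x, z):
    """log p_theta(x, z) = log P(z) + log p_theta(x|z)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, A * z, S2)

def sample_encoding(enc, rng):
    """(x, z) from the anchored encoding distribution M_E."""
    c, v = enc
    x = rng.gauss(0.0, math.sqrt(MARGINAL_VAR))         # x ~ P(x), toy data anchor
    z = c * x + math.sqrt(v) * rng.gauss(0.0, 1.0)      # reparameterized z ~ q(z|x)
    return x, z

def sample_decoding(rng):
    """(x, z) from the anchored decoding distribution M_D."""
    z = rng.gauss(0.0, 1.0)                             # z ~ P(z)
    x = A * z + math.sqrt(S2) * rng.gauss(0.0, 1.0)     # reparameterized x ~ p(x|z)
    return x, z

def mim_loss_estimate(enc, n, rng, asymmetric=False):
    """Monte Carlo estimate of -1/2 E[log q_theta(x,z) + log p_theta(x,z)].

    MIM: alternate samples from M_E and M_D (i.e., from M_S).
    A-MIM (asymmetric=True): samples from M_E only, so the decoder is never sampled.
    """
    total = 0.0
    for i in range(n):
        if asymmetric or i % 2 == 0:
            x, z = sample_encoding(enc, rng)
        else:
            x, z = sample_decoding(rng)
        total += -0.5 * (log_q_joint(x, z, enc) + log_p_joint(x, z))
    return total / n

rng = random.Random(0)
posterior = (A / MARGINAL_VAR, S2 / MARGINAL_VAR)          # exact posterior q(z|x) = N(0.8 x, 0.2)
exact = 1.0 + math.log(2 * math.pi) + 0.5 * math.log(S2)   # joint Gaussian entropy; covariance det = S2
mim = mim_loss_estimate(posterior, 20000, rng)
a_mim = mim_loss_estimate(posterior, 20000, rng, asymmetric=True)
print(exact, mim, a_mim)
```

For an encoder other than the exact posterior, the two estimates generally differ, because A-MIM replaces $\mathcal{M}_S$ with $\mathcal{M}_E$ as the sampling distribution and therefore optimizes a different objective.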
Here we show qualitative results for the most expressive model, PixelHVAE (VP). Figures 4 and 5 depict reconstruction and sampling for Fashion-MNIST and MNIST, respectively. In each plot, the top three rows depict data samples, VAE reconstructions, and A-MIM reconstructions, respectively, and the bottom row depicts samples. The results show comparable samples and reconstructions for MIM and VAE.