InfoCatVAE
Github page for the preprint paper "InfoCatVAE: Representation Learning with Categorical Variational Autoencoders"
This paper describes InfoCatVAE, an extension of the variational autoencoder that enables unsupervised disentangled representation learning. InfoCatVAE uses multimodal distributions for the prior and the inference network and then maximizes the evidence lower bound objective (ELBO). We connect the new ELBO derived for our model with a natural soft clustering objective which explains the robustness of our approach. We then adapt the InfoGAN method to our setting in order to maximize the mutual information between the categorical code and the generated inputs and obtain an improved model.
Neural networks are today a state-of-the-art solution for many machine learning tasks. In particular, they show impressive results on tasks that require capturing complex features in data. Nevertheless, most of the success of neural networks relies on large labelled datasets, which are generally unavailable. To be more generally applicable, machine learning needs new unsupervised methods that leverage the abundant unlabelled training data. One role of unsupervised learning is to capture rich distributions of complex data. More specifically, among unsupervised tasks, representation learning seeks to expose semantically meaningful factors hidden in data. Recently, generative neural network models have become a highly successful framework for this problem. In particular, variational autoencoders (VAEs) kingma2013auto and generative adversarial networks (GANs) goodfellow2014generative are major representation learning frameworks.
To be fully relevant for representation learning, unsupervised generative models have to encode the observable data in an informative space bengio2013representation . For example, VAEs are traditionally built with an isotropic Gaussian latent distribution, each dimension of which should, in theory, learn one specific and exclusive semantically meaningful factor of variation. This idea is part of the generic definition of disentangled representations, which assumes that any structured visible object has a factorial representation ridgeway2016survey . Recently, several extensions of VAEs with richer prior distributions have been proposed to obtain more powerful generative representation models tomczak2017vae ; dilokthanakul2016deep ; nalisnick2016approximate ; van2014factoring ; goyal2017nonparametric . As shown in hoffman2016elbo , one main difficulty associated with these models is the choice of the prior.
We propose an extension of the standard VAE framework with multimodal distributions for the prior as well as for the inference network. We derive the evidence lower bound objective (ELBO) for this new model with a categorical prior, which defines the Categorical VAE (CatVAE). We show that, thanks to these modifications, even with a simple fixed prior, CatVAE is able to find salient attributes within the data, leading to readable representations. Moreover, CatVAE can be used as a conditional generative model. Finally, we present an improved extension of the model that enhances disentanglement and the quality of the generated samples. More generally, this work aims at showing that simple and principled modifications can have interesting disentangling power without the need for specific heuristics or highly fine-tuned neural network architectures.
Variational autoencoding is today the core framework in which disentanglement is investigated zhao2017infovae ; higgins2016beta ; chen2016variational ; makhzani2015adversarial ; esmaeili2018hierarchical ; chen2018isolating . A well-known drawback of the VAE framework comes from its objective function, the evidence lower bound (ELBO), whose form can lead to uninformative and unstable representations.
Among the numerous studies on disentangled representation learning, most add a regularization term to the existing objective functions of different generative models makhzani2015adversarial ; zhao2017infovae ; higgins2016beta ; chen2016infogan ; gao2018auto ; kim2018disentangling ; burgess2018understanding ; esmaeili2018hierarchical ; chen2018isolating . Other studies focus on the clustering aspect of disentangled latent representations by exploring the effect of mixture distributions for the latent representation in variational autoencoders tomczak2017vae ; dilokthanakul2016deep ; nalisnick2016approximate ; van2014factoring . A mixture of distributions has the advantage of enabling the latent distribution to represent more complex variations in the data. Moreover, in terms of disentangling, a mixture implicitly assumes a hierarchy of importance among meaningful factors of variation: each mode represents a major latent factor, and the local isotropic distributions hold minor disentangling power. Nevertheless, mixture models have several limitations that depend on the chosen form of the prior and of the objective function.
Our work follows in the footsteps of the Gaussian-mixture-based models presented in nalisnick2016approximate ; dilokthanakul2016deep . Rather than parameterizing a free prior through neural network transformations, we fix the prior parameters as in the standard VAE framework (section 7); we also modify the inference network and rederive the ELBO of our model. Finally, we improve our model with explicit information maximization.
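Consider a latent variable model $p_\theta(x, z) = p(z)\,p_\theta(x\mid z)$ with a data variable $x \in \mathcal{X}$ and a latent variable $z \in \mathcal{Z}$. Given the data $x_1, \dots, x_N$, we want to train the model by maximizing the marginal log-likelihood:
$$\mathbb{E}_{p_D}\big[\log p_\theta(x)\big] = \mathbb{E}_{p_D}\Big[\log \int p_\theta(x\mid z)\,p(z)\,dz\Big], \tag{1}$$
where $p_D$ denotes the empirical distribution of $x$: $p_D(x) = \frac{1}{N}\sum_{i=1}^{N}\delta_{x_i}(x)$.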
To avoid the (often) difficult computation of the integral in (1), the idea behind variational methods is to instead maximize a lower bound to the log-likelihood (called ELBO):
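$$\mathbb{E}_{p_D}\big[\log p_\theta(x)\big] \geq \mathbb{E}_{p_D}\Big[\mathbb{E}_{q(z\mid x)}\big[\log p_\theta(x\mid z)\big] - \mathrm{KL}\big(q(z\mid x)\,\|\,p(z)\big)\Big] \tag{2}$$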
Any choice of $q(z\mid x)$ gives a valid lower bound. Variational autoencoders replace the variational posterior by an inference network $q_\phi(z\mid x)$ that is trained together with $p_\theta(x\mid z)$ to jointly maximize the ELBO. The variational posterior $q_\phi(z\mid x)$ is also called the encoder, and the generative model $p_\theta(x\mid z)$ the decoder or generator.
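It will be convenient to decompose the optimization done in VAEs into two steps:
$$\max_{\theta,\phi}\; \mathbb{E}_{p_D}\Big[\mathbb{E}_{q_\phi(z\mid x)}\big[\log p_\theta(x\mid z)\big]\Big] \tag{3}$$
$$\max_{\phi}\; -\mathbb{E}_{p_D}\Big[\mathrm{KL}\big(q_\phi(z\mid x)\,\|\,p(z)\big)\Big] \tag{4}$$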
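To make the two steps concrete, here is a minimal one-sample sketch of the quantity maximized in (3) and in (4), assuming (these are not stated in the text) a Bernoulli decoder over binary pixels, a diagonal Gaussian encoder $q_\phi(z\mid x) = \mathcal{N}(\mu, \mathrm{diag}(e^{\text{logvar}}))$ and a standard normal prior $p(z)$. The function names and the toy decoder are hypothetical; a real model would compute `mu`, `logvar` and the decoder output with neural networks.

```python
import math
import random


def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p_theta(x|z) for binary pixels x, where probs are the decoder's Bernoulli means."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += xi * math.log(p) + (1 - xi) * math.log(1.0 - p)
    return total


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


def reparameterize(mu, logvar, rng):
    """z = mu + sigma * eps with eps ~ N(0, I): randomness is moved out of mu and logvar."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def vae_terms(x, mu, logvar, decoder, rng):
    """One-sample estimates of the quantities maximized in step (3) and step (4)."""
    z = reparameterize(mu, logvar, rng)
    reconstruction = bernoulli_log_likelihood(x, decoder(z))  # term of (3)
    regularization = -kl_to_standard_normal(mu, logvar)       # term of (4)
    return reconstruction, regularization
```

Note that the term of (4) reaches its maximum of 0 exactly when $\mu = 0$ and $\text{logvar} = 0$ for every input, i.e. when the encoder ignores $x$; this is the failure mode discussed next.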
As already noted in the literature, the VAE objective is insufficient for representation learning. This is particularly true in the so-called high-capacity regime, where a decoder reproducing the data distribution on its own, $p_\theta(x\mid z) \approx p_D(x)$, is close to achievable. In this case, one strategy for the VAE is for the encoder to match the prior, $q_\phi(z\mid x) = p(z)$, while the decoder outputs the sample distribution without using the latent code $z$. The latent code is then independent of the data, hence useless for representation learning.
This is clearly due to the fact that only the marginal $p_\theta(x)$ appears in (1), and the only way VAEs enforce information in the latent code representation is by limiting the optimization in (3) to a constrained class of decoders $p_\theta(x\mid z)$. In higgins2016beta , to solve this problem, the authors modify the ELBO by putting more weight on the term (4) through the introduction of a new parameter $\beta$. This approach is extended in dupont2018jointvae , where this parameter is modified during training (like an annealing procedure).
We propose to reuse the base block of the semi-supervised method of kingma2014semi to enforce information in the latent code representation. We call the resulting model CatVAE (for Categorical VAE). This architecture will be improved in the next section to obtain InfoCatVAE. CatVAE consists of the following three modifications to the VAE:
We modify the latent variable model as follows: $p_\theta(x, z, c) = p(c)\,p(z\mid c)\,p_\theta(x\mid z)$, where $c \in \{1, \dots, K\}$ is a discrete latent variable. In other words, the prior $p(z)$ of VAEs is replaced by a prior $p(c)\,p(z\mid c)$. For instance, if the data are images from the MNIST dataset, then $c$ would encode the numerical identity of the digit (0-9) and, similarly to standard VAEs, we would take a prior $p(z\mid c) = \mathcal{N}(\mu_c, I)$ for a well-chosen $\mu_c$.
We define the inference network by $q_\phi(c, z\mid x) = q_\phi(c\mid x)\,q_\phi(z\mid x, c)$.
We modify the objective (4) as follows:
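$$\max_{\phi}\; \mathbb{E}_{p_D}\Big[-\mathrm{KL}\big(q_\phi(c\mid x)\,\|\,p(c)\big) - \sum_{c=1}^{K} q_\phi(c\mid x)\,\mathrm{KL}\big(q_\phi(z\mid x,c)\,\|\,p(z\mid c)\big)\Big] \tag{5}$$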
A representation of the CatVAE architecture is illustrated in figure 1. Note that our proposal is not simply a VAE whose prior is a mixture of Gaussians; nevertheless, CatVAE still maximizes a variational lower bound of the marginal log-likelihood, as shown by the following proposition:
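Proposition. For all $\theta, \phi$:
$$\mathbb{E}_{p_D}\big[\log p_\theta(x)\big] \geq \mathbb{E}_{p_D}\Big[\mathbb{E}_{q_\phi(c,z\mid x)}\big[\log p_\theta(x\mid z)\big] - \mathrm{KL}\big(q_\phi(c\mid x)\,\|\,p(c)\big) - \sum_{c=1}^{K} q_\phi(c\mid x)\,\mathrm{KL}\big(q_\phi(z\mid x,c)\,\|\,p(z\mid c)\big)\Big].$$
Proof. See appendix A. ∎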
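Although CatVAE still optimizes a lower bound on the marginal log-likelihood $\log p_\theta(x)$, it enforces information in the latent code thanks to the fixed prior $p(z\mid c)$, which should be chosen appropriately (see Section 7). The modifications proposed for CatVAE raise some difficulties due to the categorical variable $c$ (it is well known that the reparametrization trick used for VAEs cannot be directly applied to discrete variables). We explain how we overcome these difficulties in Section 6.

In order to build more intuition about the term (5), we consider the case where $p(c)$ is the uniform distribution over $\{1, \dots, K\}$, $p(z\mid c) = \mathcal{N}(\mu_c, I_d)$ and $q_\phi(z\mid x,c) = \mathcal{N}\big(\mu_\phi(x,c), \mathrm{diag}(\sigma^2_\phi(x,c))\big)$. In this case, the first term is
$$-\mathrm{KL}\big(q_\phi(c\mid x)\,\|\,p(c)\big) = \mathcal{H}\big(q_\phi(c\mid x)\big) - \log K,$$
hence maximizing it is the same as maximizing the entropy of the distribution $q_\phi(c\mid x)$, i.e. each category should be evenly represented. Then the second term becomes
$$-\frac{1}{2}\sum_{c=1}^{K} q_\phi(c\mid x)\sum_{j=1}^{d}\Big(\sigma^2_{\phi,j}(x,c) - \log\sigma^2_{\phi,j}(x,c) - 1 + \big(\mu_{\phi,j}(x,c) - \mu_{c,j}\big)^2\Big).$$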
For a given $c$, if $q_\phi(c\mid x)$ is close to one, i.e. if the category associated with $x$ is $c$ with high probability, then the variance $\sigma^2_\phi(x,c)$ should be close to one and the mean $\mu_\phi(x,c)$ close to $\mu_c$ in order to maximize this term. In summary, the optimization step (5) can be interpreted as a soft clustering step in $\mathbb{R}^d$, where $d$ is the dimension of the code, with $K$ clusters. Each data point $x$ is mapped to $\mu_\phi(x,c)$ thanks to the function $\mu_\phi$. The centers of the clusters are given by the means $\mu_c$ for $c = 1, \dots, K$; the classifier $q_\phi(c\mid x)$ is then a soft allocation of the data to the clusters, and the functions $\mu_\phi(\cdot, c)$ are updated to get closer to the center of the cluster. The entropy term prevents the trivial solution where all points are mapped to one cluster and ensures that data points are evenly distributed among the different clusters. Note that step (5) tends to concentrate each cluster, and this effect is mitigated by the reconstruction step (3). It is also important to note a main difference compared with the other approaches in the literature described below. To the best of our knowledge, CatVAE is the only architecture where the distances between the code and all clusters are explicitly computed and used in the backpropagation algorithm. This is in contrast with standard approaches, where typically a classifier first determines the cluster and then only the distance between the code and the selected cluster is computed. We believe that this specificity explains the robustness of CatVAE. Of course, our approach becomes problematic when the number of clusters is very large.
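The soft clustering reading of (5) can be checked numerically. The sketch below evaluates the per-sample value of (5) in the special case just discussed (uniform $p(c)$, $p(z\mid c) = \mathcal{N}(\mu_c, I)$, diagonal Gaussian $q_\phi(z\mid x,c)$). It is an illustration under those assumptions, not the authors' code; the argument names are hypothetical and stand for the network outputs for one input $x$. Every cluster contributes, weighted by $q_\phi(c\mid x)$, which is the "all distances are computed" property emphasized above.

```python
import math


def kl_categorical_to_uniform(q_c):
    """KL(q(c|x) || Uniform{1..K}) = log K - H(q(c|x))."""
    entropy = -sum(p * math.log(p) for p in q_c if p > 0)
    return math.log(len(q_c)) - entropy


def kl_diag_gaussian(mu, logvar, center):
    """KL(N(mu, diag(exp(logvar))) || N(center, I))."""
    return 0.5 * sum(math.exp(lv) - lv - 1.0 + (m - c) ** 2
                     for m, lv, c in zip(mu, logvar, center))


def catvae_regularizer(q_c, mus, logvars, centers):
    """Per-sample value of objective (5), to be maximized.

    q_c[c]     : classifier output q_phi(c|x)
    mus[c]     : mean mu_phi(x, c) of q_phi(z|x, c)
    logvars[c] : log-variance of q_phi(z|x, c)
    centers[c] : prior mean mu_c of p(z|c) = N(mu_c, I)
    """
    weighted = sum(w * kl_diag_gaussian(m, lv, ctr)
                   for w, m, lv, ctr in zip(q_c, mus, logvars, centers))
    return -kl_categorical_to_uniform(q_c) - weighted
```

For example, with two centers $(1,0)$ and $(0,1)$ and a code lying on the first center, assigning most of the mass of $q_\phi(c\mid x)$ to cluster 1 gives a higher value than assigning it to cluster 2, as the clustering interpretation predicts.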
Note that once CatVAE has been trained, it can be used as a generative model: a data point of category $c$ can be generated by first sampling a code $z$ from $p(z\mid c)$ and then passing it through the decoder $p_\theta(x\mid z)$. The center of the cluster $\mu_c$ is mapped to $p_\theta(x\mid \mu_c)$.
We noted that in InfoGANs, the classifier is trained on generated inputs only, and this already provides impressive results. In our CatVAE model, the classifier $q_\phi(c\mid x)$ is only trained on real data. But since CatVAE can be turned into a generative model, we can also use the generated inputs to improve our classifier. By analogy with InfoGAN, this is easily done by modifying step (5) as follows:
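$$\max_{\theta,\phi}\; \mathbb{E}_{p_D}\Big[-\mathrm{KL}\big(q_\phi(c\mid x)\,\|\,p(c)\big) - \sum_{c=1}^{K} q_\phi(c\mid x)\,\mathrm{KL}\big(q_\phi(z\mid x,c)\,\|\,p(z\mid c)\big)\Big] \tag{6}$$
$$\qquad +\; \mathbb{E}_{c\sim p(c),\; z\sim p(z\mid c),\; \tilde x\sim p_\theta(x\mid z)}\big[\log q_\phi(c\mid \tilde x)\big] \tag{7}$$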
We call this model InfoCatVAE. Indeed, the same computation as in chen2016infogan can be carried out here, and the additional term (7) is a variational lower bound on the mutual information between the category $c$ and the generated input $\tilde x \sim p_\theta(x\mid z)$, where $z$ has been sampled from the prior $p(z\mid c)$.
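The bound invoked above can be made explicit. The following derivation is the standard InfoGAN argument (chen2016infogan) rewritten in this section's notation; it only uses the non-negativity of the KL divergence. Here $\mathcal{H}$ denotes entropy and $p(c\mid\tilde x)$ is the true (intractable) posterior of the category given a generated input.

$$I(c;\tilde x) = \mathcal{H}(c) - \mathcal{H}(c\mid\tilde x) = \mathcal{H}(c) + \mathbb{E}_{\tilde x}\,\mathbb{E}_{c\sim p(c\mid\tilde x)}\big[\log p(c\mid\tilde x)\big]$$
$$= \mathcal{H}(c) + \mathbb{E}_{\tilde x}\Big[\mathrm{KL}\big(p(c\mid\tilde x)\,\|\,q_\phi(c\mid\tilde x)\big) + \mathbb{E}_{c\sim p(c\mid\tilde x)}\big[\log q_\phi(c\mid\tilde x)\big]\Big]$$
$$\geq \mathcal{H}(c) + \mathbb{E}_{c\sim p(c),\; z\sim p(z\mid c),\; \tilde x\sim p_\theta(x\mid z)}\big[\log q_\phi(c\mid\tilde x)\big].$$

Since $\mathcal{H}(c)$ depends on neither $\theta$ nor $\phi$ (it equals $\log K$ for a uniform $p(c)$), maximizing the expectation in the last line maximizes a lower bound on $I(c;\tilde x)$, and the bound is tight when $q_\phi(c\mid\tilde x)$ matches the true posterior.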
Section 10 gives illustrative experiments showing the advantage of information maximization. The additional part of InfoCatVAE over CatVAE is illustrated in figure 2.
In practice, the three terms of the objective are each multiplied by a scalar factor, as in higgins2016beta . These parameters act as trade-offs between reconstruction, latent distribution matching and preservation of salient information within the latent space.
The reparametrization trick developed in kingma2013auto has no natural equivalent for multinomial sampling. A method called the Gumbel-max trick gumbel2012statistics is widely used in machine learning to overcome this problem jang2016categorical ; dupont2018jointvae . For InfoCatVAE learning, we propose to bypass this problem not with the Gumbel-max trick but with an alternative two-step method that is naturally induced by our model.
First, as seen above, the inference network computes $q_\phi(z\mid x,c)$ for all $c$'s and all data points $x$. Each $x$ is then represented $K$ times in the latent space, and each representation is weighted by the probability $q_\phi(c\mid x)$. This first step enables back-propagation while keeping the spirit of a uniform discrete sampling conditioned on the data.
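The first step can be illustrated on the reconstruction term: instead of drawing $c$ (which is not differentiable), every branch $c = 1, \dots, K$ is evaluated with a reparameterized Gaussian code and weighted by $q_\phi(c\mid x)$, which gives the exact expectation over $c$. This is a minimal sketch under stated assumptions: Gaussian $q_\phi(z\mid x,c)$, one Gaussian sample per branch, and a hypothetical callable `decoder_log_lik(x, z)` standing in for $\log p_\theta(x\mid z)$.

```python
import math
import random


def expected_reconstruction(x, q_c, mus, logvars, decoder_log_lik, rng):
    """Estimate E_{q(c|x)} E_{q(z|x,c)}[log p_theta(x|z)] without sampling c.

    The sum over the K categories is computed exactly (each branch weighted
    by q(c|x)); only z is sampled, with the Gaussian reparametrization trick,
    so gradients can flow into both q(c|x) and the branch parameters.
    """
    total = 0.0
    for weight, mu, logvar in zip(q_c, mus, logvars):
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
        total += weight * decoder_log_lik(x, z)
    return total
```

The cost grows linearly with $K$, since each input is decoded once per category, which is consistent with the remark above that the approach becomes problematic for a very large number of clusters.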
Second, the information maximization step presented in section 5 optimizes the same building blocks as inference learning, but in a free generative setting. This second step thus provides a way to optimize the InfoCatVAE network with real categorical sampling, which does not block back-propagation since the sampling happens at the input layer.
Therefore, by construction, InfoCatVAE naturally enables optimization with categorical sampling.
As presented in Section 4, CatVAE requires choosing the parameters of the prior distribution $p(z\mid c)$. Our intuition is that the prior should be fixed so that it serves the same purpose as the isotropic structure of the standard VAE. Recall that each dimension of the isotropic Gaussian latent distribution should, in theory, learn one specific and exclusive semantically meaningful factor of variation.
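We choose the dimension of the latent space such that $d = K\delta$. We consider that data with $K$ categories should be encoded with a $K$-modal distribution, approximated by Gaussians whose means live in $\mathbb{R}^d$ and are mutually orthogonal. We therefore propose:
$$p(z\mid c) = \mathcal{N}(\mu_c, I_d), \qquad \mathrm{supp}(\mu_c) \subseteq \{(c-1)\delta + 1, \dots, c\delta\}, \qquad \langle \mu_c, \mu_{c'}\rangle = 0 \ \text{for}\ c \neq c'. \tag{8}$$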
The main idea behind this choice of prior is that each major category within the data should be mainly represented within a $\delta$-dimensional subspace of the latent space. This framework forces the model to learn each fundamental class structure quasi-independently. This way, the interpretability of the latent representation should be improved.
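A minimal sketch of such a prior, assuming each mean is a constant `scale` on its own block of $\delta$ coordinates and zero elsewhere (the actual scale and $\delta$ used in the experiments are not recoverable from this text, so `scale` and `block_dim` are hypothetical parameters). Sampling $z \sim p(z\mid c)$ and decoding it is exactly the conditional generation procedure described in Section 4.

```python
import random


def block_orthogonal_means(num_classes, block_dim, scale=1.0):
    """Means mu_c in R^(K * delta): mu_c equals `scale` on the c-th block of
    `block_dim` coordinates and 0 elsewhere, so distinct means are orthogonal."""
    d = num_classes * block_dim
    means = []
    for c in range(num_classes):
        mu = [0.0] * d
        for j in range(c * block_dim, (c + 1) * block_dim):
            mu[j] = scale
        means.append(mu)
    return means


def sample_conditional_code(means, c, rng):
    """Draw z ~ p(z|c) = N(mu_c, I); passing z through the decoder yields a sample of class c."""
    return [m + rng.gauss(0.0, 1.0) for m in means[c]]
```

For MNIST with $K = 10$ and, say, $\delta = 2$, this gives a 20-dimensional latent space in which the ten class means are pairwise orthogonal.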
Generative adversarial networks (GANs) train generative models through an objective function that implements a two-player zero-sum game between a discriminator $D$ and a generator $G$. That is, $G$ maps random vectors $z$ to generated inputs $G(z)$, and $D$ predicts the probability that an example $x$ comes from the dataset: $D(x) \in [0, 1]$. But GANs share with standard VAEs the absence of restrictions on the way the generator should use the noise. Hence there is no guarantee that the latent representation will be disentangled.

InfoGANs propose to decompose the input noise vector into two parts: $z$, which is treated as a source of incompressible noise, and $c$, which represents salient semantic features of the data distribution. InfoGANs also introduce a variational posterior $Q(c\mid x)$. To highlight the similarities between our framework and InfoGANs, we use the parameter $\theta$ for the generator $G_\theta$, $\phi$ for the variational posterior $Q_\phi$ and $\psi$ for the discriminator $D_\psi$. With these notations, the minimax game with a variational regularization of mutual information and hyperparameter $\lambda$ solved by InfoGANs can be written as:
$$\min_{\theta,\phi}\; \mathbb{E}_{z\sim p(z),\,c\sim p(c)}\Big[\log\big(1 - D_\psi(G_\theta(z,c))\big) - \lambda \log Q_\phi\big(c\mid G_\theta(z,c)\big)\Big] \tag{9}$$
$$\max_{\psi}\; \mathbb{E}_{x\sim p_D}\big[\log D_\psi(x)\big] + \mathbb{E}_{z\sim p(z),\,c\sim p(c)}\Big[\log\big(1 - D_\psi(G_\theta(z,c))\big)\Big] \tag{10}$$
Step (9) updates the generator $G_\theta$ as well as the classifier $Q_\phi$, while step (10) updates the discriminator $D_\psi$.
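The two InfoGAN steps can be written as batch losses computed from discriminator and classifier outputs. This sketch follows the minimax (saturating) form of (9) and (10); the argument names are hypothetical placeholders for network outputs, and practical InfoGAN implementations often replace the generator term with the non-saturating variant $-\log D_\psi(G_\theta(z,c))$.

```python
import math


def _mean(values):
    return sum(values) / len(values)


def infogan_losses(d_real, d_fake, log_q_c, lam=1.0, eps=1e-7):
    """Batch estimates of the objectives in (9) and (10).

    d_real  : D_psi(x) for real examples x
    d_fake  : D_psi(G_theta(z, c)) for generated examples
    log_q_c : log Q_phi(c | G_theta(z, c)) for the code c used to generate each example
    lam     : the mutual-information weight lambda
    Returns (loss minimized over theta and phi, objective maximized over psi).
    """
    fake_term = _mean([math.log(max(1.0 - p, eps)) for p in d_fake])
    gen_and_q_loss = fake_term - lam * _mean(log_q_c)                      # step (9)
    disc_objective = _mean([math.log(max(p, eps)) for p in d_real]) + fake_term  # step (10)
    return gen_and_q_loss, disc_objective
```

The $-\lambda \log Q_\phi$ term rewards the generator and the classifier jointly when the code $c$ can be recovered from the generated sample, which is the mechanism InfoCatVAE transposes in (7).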
We refer to hu2017unifying for an in-depth comparison of VAEs and GANs. As explained above, VAEs can be seen as generative models, and similarly GANs can produce embeddings from the data. Indeed, the discriminator is a neural network whose final fully connected layer outputs the probability that the input is real. A natural encoder is then given by the discriminator with its last fully connected layer removed. Note, however, that the discriminator is trained to distinguish generated samples from real data, so the features it keeps are those that help with this discrimination task. It is not clear a priori that these features will be the most readable ones, as expected in representation learning. In practice, for InfoGANs, $D_\psi$ and $Q_\phi$ share all layers except the last one, so the encoder described above for GANs still works for InfoGANs and keeps information about the categories.
Note also that the classifier $Q_\phi$ of InfoGAN is trained only on generated inputs, whereas the classifier $q_\phi(c\mid x)$ of CatVAE is trained on real data but with 'noisy' labels. Section 5 presents an extension of CatVAE that builds on this remark in order to improve the performance of the classifier.
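Adversarial autoencoders (AAEs) build on standard (i.e. non-variational) autoencoders with a deterministic encoding function, denoted here $E_\phi$, mapping each input $x$ to a code $z = E_\phi(x)$, and a generative process $G_\theta$. The regularization in AAEs is done by an adversarial network matching the prior $p(z)$ we want to impose on the code with the aggregated posterior distribution $q_\phi(z) = \mathbb{E}_{x\sim p_D}\big[\delta_{E_\phi(x)}(z)\big]$. We denote by $D_\psi$ the discriminator of the adversarial network, so that the updates can be written as:
$$\min_{\theta,\phi}\; \mathbb{E}_{x\sim p_D}\Big[\ell\big(x, G_\theta(E_\phi(x))\big)\Big] + \mathbb{E}_{x\sim p_D}\Big[\log\big(1 - D_\psi(E_\phi(x))\big)\Big] \tag{11}$$
$$\max_{\psi}\; \mathbb{E}_{z\sim p(z)}\big[\log D_\psi(z)\big] + \mathbb{E}_{x\sim p_D}\Big[\log\big(1 - D_\psi(E_\phi(x))\big)\Big] \tag{12}$$
where $\ell$ is a reconstruction loss.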
Step (11) updates the encoder and the decoder in order to minimize the reconstruction loss (first term in (11)), as well as the encoder in order to confuse the discriminator (second term in (11)). Step (12) is the classical update of the discriminator.
As in the modification from GAN to InfoGAN, structure can be imposed on the prior by using a code $(c, z)$ with distribution $p(c)\,p(z)$, where typically $p(c)$ is a categorical distribution. The encoder then generates both a category $c$ and a continuous code $z$. The first component of $E_\phi$ thus plays the role of a classifier, which can be used for unsupervised clustering.
Although the losses differ from ours, AAEs are very similar to our CatVAE. However, AAEs enforce only a certain type of prior on the code. Indeed, the category $c$ is one-hot encoded, which corresponds in CatVAE to orthogonal means $\mu_c$ for the distributions $p(z\mid c)$. CatVAE allows more flexibility in the choice of the priors.

In this section, we aim to illustrate that InfoCatVAE enables readable representations and controlled generation. We first illustrate our work on MNIST and FashionMNIST data with a simple multilayer perceptron architecture (see table 2 in appendix B). Figure 3 and figure 4 illustrate the performance of InfoCatVAE on the readable representation task.
Discrete interpolation between the prior centroids when InfoCatVAE is trained on MNIST (left) and FashionMNIST (right). The left columns show the ten centroids represented in the observable space. Each row shows the reconstructions of nine latent values along the path between the ten centroids.
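A sketch of how latent values along a path between two centroids can be produced before decoding them. The caption does not say how the nine values were chosen; evenly spaced linear interpolation, excluding the endpoints, is an assumption made here for illustration, and `interpolate` is a hypothetical helper.

```python
def interpolate(a, b, num_points):
    """Return num_points codes evenly spaced strictly between codes a and b (endpoints excluded)."""
    return [[ai + (bi - ai) * t / (num_points + 1) for ai, bi in zip(a, b)]
            for t in range(1, num_points + 1)]
```

Applied to consecutive prior means $\mu_c$ and $\mu_{c+1}$ with `num_points=9`, each returned code would then be passed through the decoder to obtain one image of a row.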
We reimplemented the Adversarial Autoencoder makhzani2015adversarial with the same multilayer perceptron encoder and decoder as presented in appendix B (table 2) to compare our sampling capacity with that of a comparable state-of-the-art autoencoder-based generative model.
Model | MNIST LL score |
---|---|
Adversarial autoencoder | 95.5 |
CatVAE | 111.2 |
InfoCatVAE | 113.8 |
Table 1: Log-likelihood of 10K generated samples from different generative models trained on MNIST for 600 epochs. Higher values are better. The density function is estimated with the KernelDensity estimator of scikit-learn pedregosa2011scikit , whose bandwidth parameter is selected by grid search with 5-fold cross-validation over the 60K training examples. Each model has the same encoder and decoder architecture as presented in appendix B.

Table 1 shows that CatVAE and InfoCatVAE are better suited to the sampling task than AAE when the architecture is reduced to its simplest form, even though CatVAE and InfoCatVAE must additionally learn a disentangled representation.
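The evaluation protocol of Table 1 can be sketched without scikit-learn. The following pure-Python Gaussian kernel density estimate with k-fold grid search over the bandwidth is an illustrative stand-in for scikit-learn's KernelDensity, not the evaluation script used for the table. Two points are assumptions: the KDE is fitted on reference (training) data and then scores the generated samples, and the mean log density per sample is returned (the caption does not say whether the reported score is a sum or a mean).

```python
import math
import random


def _log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def kde_log_density(x, data, bandwidth):
    """log of an isotropic Gaussian KDE fitted on `data`, evaluated at point x."""
    d = len(x)
    log_norm = -0.5 * d * math.log(2 * math.pi * bandwidth ** 2) - math.log(len(data))
    exps = [-sum((xi - yi) ** 2 for xi, yi in zip(x, y)) / (2 * bandwidth ** 2) for y in data]
    return log_norm + _log_sum_exp(exps)


def select_bandwidth(data, candidates, folds=5, seed=0):
    """Grid search: keep the bandwidth with the highest held-out log-likelihood over k folds."""
    idx = list(range(len(data)))
    random.Random(seed).shuffle(idx)
    best, best_score = None, -math.inf
    for h in candidates:
        score = 0.0
        for f in range(folds):
            held = set(idx[f::folds])
            train = [data[i] for i in idx if i not in held]
            score += sum(kde_log_density(data[i], train, h) for i in held)
        if score > best_score:
            best, best_score = h, score
    return best


def mean_log_likelihood(samples, data, bandwidth):
    """Average KDE log density of generated samples under the reference data."""
    return sum(kde_log_density(s, data, bandwidth) for s in samples) / len(samples)
```

This naive version costs one kernel evaluation per (sample, reference point) pair, so at the scale of Table 1 (784-dimensional images, 60K training examples, 10K samples) a vectorized library implementation is the practical choice.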
Finally, as in InfoGAN, (7) can be approximated via Monte Carlo simulation. We generate 10K samples from 10K discrete labels drawn from a multinomial distribution, and we compute the cross-entropy between the labels and the inferred classes of the samples. This framework enables easy optimization of the mutual information by gradient back-propagation. In terms of results, we find a cross-entropy of 2.03 for CatVAE and 1.62 for InfoCatVAE. Therefore, the mutual information between generated samples and categories is improved with InfoCatVAE.
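The Monte Carlo estimate described above can be sketched as follows. `generate` and `classify` are hypothetical callables standing in for "sample $z \sim p(z\mid c)$ and decode it" and for the classifier $q_\phi(\cdot\mid x)$; labels are drawn uniformly here, which is an assumption about the multinomial used.

```python
import math
import random


def info_cross_entropy(num_samples, num_classes, generate, classify, rng):
    """Monte Carlo estimate of -E[log q_phi(c | x_tilde)] for c ~ Uniform, x_tilde generated from c.

    generate(c, rng) -> x_tilde               (hypothetical: sample z ~ p(z|c), then decode)
    classify(x)      -> list of probabilities  (hypothetical: q_phi(.|x))
    """
    total = 0.0
    for _ in range(num_samples):
        c = rng.randrange(num_classes)           # categorical label
        probs = classify(generate(c, rng))
        total -= math.log(max(probs[c], 1e-12))  # cross-entropy with the true label
    return total / num_samples
```

By the bound on (7), $\log K$ minus this cross-entropy is a lower bound on $I(c;\tilde x)$, so a lower cross-entropy means a higher guaranteed mutual information. An uninformative classifier gives exactly $\log K$, and a perfect one gives 0.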
This section aims to show that our choice of prior brings robustness in the high-capacity regime.
In addition to generative information maximization, our work proposes fixing a simple multimodal prior distribution (see sections 4 and 7). This framework stands between multimodal free-prior learning frameworks nalisnick2016approximate ; dilokthanakul2016deep and unimodal fixed priors dupont2018jointvae . In particular, our intuition is that for high-capacity tasks, relaxing the fixed prior makes learning unstable.
An illustrative task is representation learning of multivariate time series using a recurrent encoder and a recurrent decoder. The particularity of sequential data representation is the need to encode the temporal dimension in a non-temporal space. For unsupervised representation learning in particular, the complexity of the task combined with the shortcomings of recurrent modeling generally demands more exotic models to achieve acceptable representational and generative power boulanger2012modeling ; graves2013generating ; bowman2015generating . The rich multi-scale structure of sequential data naturally calls for a hierarchical representation. Readable hierarchical approaches for sequential data have been studied recently chung2016hierarchical ; hsu2017unsupervised ; li2018deep ; liang2017recurrent . All these particularities make sequential data an example of a complex task that can reveal the limits of models that work well on simple static data.
For the illustrative experiment, we keep using MNIST data but consider each image as a 28-dimensional time series of length 28. This way, the trivial structure shared by similar digits collapses. We simply take our InfoCatVAE model with fixed prior (as described in section 7) with a recurrent encoder and decoder, and compare it to the same model with free priors as in dilokthanakul2016deep .
Figure 5 illustrates the kind of result we obtain on multivariate sequential MNIST. It shows that combining a fixed multimodal prior with generative information maximization brings robustness and interpretability to our model. The implementation details are given in appendix C.
In this paper, we have introduced CatVAE, a multimodal variational autoencoder with a fixed multimodal prior. We have shown that this model can learn disentangled representations of data and provides a highly credible conditional generation framework. Moreover, we have shown that, contrary to more complex and flexible models, CatVAE handles complex tasks without the need for specifically fine-tuned architectures. Finally, we have shown that CatVAE can be extended to an information-optimized version, InfoCatVAE, which enhances both generation and categorical information learning.
Despite these encouraging results, we acknowledge that fixing the prior can cause information transfer and representation problems. The free-prior framework comes in many variants, and for specific tasks a well-controlled free prior with a fine-tuned architecture might be more powerful than InfoCatVAE. For robust generalization, however, our method seems more appropriate.
As future work, we could constrain the prior parameters with a Bayesian hyperparameterization nalisnick2016approximate ; nalisnick2016deep . This way, we would not have to completely free the prior parameters through neural networks during learning, and would therefore keep the stability of InfoCatVAE while still improving the structure of the latent space.
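Workshop in Advances in Approximate Bayesian Inference, NIPS, 2016.
NIPS Workshop on Bayesian Deep Learning, volume 2, 2016.
Factoring variations in natural images with deep Gaussian mixture models. In Advances in Neural Information Processing Systems, pages 3518–3526, 2014.

For every $x$,
$$\log p_\theta(x) = \log \sum_{c=1}^{K}\int p_\theta(x\mid z)\,p(z\mid c)\,p(c)\,dz \;\geq\; \sum_{c=1}^{K}\int q_\phi(c,z\mid x)\,\log\frac{p_\theta(x\mid z)\,p(z\mid c)\,p(c)}{q_\phi(c,z\mid x)}\,dz$$
from Jensen's inequality. Then with $q_\phi(c,z\mid x) = q_\phi(c\mid x)\,q_\phi(z\mid x,c)$ we get:
$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(c,z\mid x)}\big[\log p_\theta(x\mid z)\big] + \sum_{c=1}^{K} q_\phi(c\mid x)\int q_\phi(z\mid x,c)\log\frac{p(z\mid c)}{q_\phi(z\mid x,c)}\,dz + \sum_{c=1}^{K} q_\phi(c\mid x)\log\frac{p(c)}{q_\phi(c\mid x)} \tag{13}$$
Finally:
$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(c,z\mid x)}\big[\log p_\theta(x\mid z)\big] - \sum_{c=1}^{K} q_\phi(c\mid x)\,\mathrm{KL}\big(q_\phi(z\mid x,c)\,\|\,p(z\mid c)\big) - \mathrm{KL}\big(q_\phi(c\mid x)\,\|\,p(c)\big), \tag{14}$$
and taking the expectation over $p_D$ gives the proposition.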
For all experiments, the Adam optimizer is used with a learning rate of 1e-4. We chose and
discriminator D / encoder Q | decoder G |
---|---|
Input flattened Gray image | Input |
FC. + Dropout(0.25) + ReLU |
FC. + Dropout(0.25) + ReLU |
D: FC. + Softmax | |
: FC. / : FC. | FC. + Dropout(0.25) + Sigmoid |
For all experiments, the Adam optimizer is used with a learning rate of 1e-4. We chose and
discriminator D / encoder Q | decoder G |
---|---|
Input | Input |
bidirectional GRU. | FC. |
D: FC. + Softmax | GRU. + ReLU |
: FC. / : FC. | FC. + Sigmoid |
The 28-dimensional MNIST sequences are treated as real-valued time series. Therefore, the loss function is the sum over time steps of the mean squared error.
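A minimal sketch of the sequential view and of this loss. Treating image rows as time steps is an assumption (the text only says each image is a 28-dimensional time series of length 28), and the helper names are hypothetical.

```python
def image_to_sequence(pixels, height=28, width=28):
    """Row-major flattened image -> list of `height` time steps, each with `width` features."""
    if len(pixels) != height * width:
        raise ValueError("expected a flattened height x width image")
    return [pixels[t * width:(t + 1) * width] for t in range(height)]


def sequence_mse_loss(target, recon):
    """Sum over time steps of the per-step mean squared error."""
    return sum(sum((a - b) ** 2 for a, b in zip(x_t, y_t)) / len(x_t)
               for x_t, y_t in zip(target, recon))
```

Summing rather than averaging over time steps keeps each step's error on the same scale while making longer sequences weigh proportionally more.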