Semi-supervised or unsupervised representation learning makes it possible to use all available data when tackling problems with few or no labeled examples. This is a common scenario in many applications of machine learning, such as medical image analysis, where the scarcity of labels is compounded by the expense of obtaining expert annotations. Moreover, machine-learned representations are more likely to be used for subsequent tasks if they are interpretable and meaningful. Deep generative modelling is well suited to this problem: such models have been shown to learn from both labeled and unlabeled examples, to embed data according to desired latent variable distributions, and to produce realistic data examples generated from samples of those latent variables.
The Generative Adversarial Network (GAN) has recently emerged as a powerful framework for modeling complex data distributions without having to approximate intractable likelihoods. In the formulation by Goodfellow et al. (2014), a GAN consists of two networks: a generator that is trained to produce samples from the data distribution, and a discriminator that is trained to distinguish between generated and true data samples.
Dumoulin et al. (2016) and Donahue et al. (2017) proposed the ALI and BiGAN models, which add an inference process, i.e., the ability to map data samples to points in the latent space, to the GAN framework. A second generator for inference, or encoder, is added alongside the original GAN generator, and the discriminator is adapted to operate on the joint space of data inputs and latent representations. Dumoulin et al. (2016) also introduce a variant of the resulting model for conditional data generation, but it still assumes that the class of the data is always observed, as inference of categorical variables is not included.
Adversarial approaches for the inference of both continuous and categorical variables are an active area of research. Chen et al. (2016) introduce a hybrid adversarial method that is capable of modelling both continuous and categorical latent variables for unsupervised clustering and feature disentanglement. Another hybrid adversarial method is introduced by Makhzani et al. (2016), in which adversarial objectives on continuous and categorical latent variables are optimized for unlabeled examples, and categorical cross entropy on categorical variables is optimized for labeled examples. Li et al. (2017) and Deng et al. (2017) point toward fully adversarial semi-supervised classification using inferred categorical variables by introducing a “three player” adversarial game, but stop short of it by adding auxiliary “collaborative” objectives. In each of these methods, categorical and continuous latent variables are assumed to be independently distributed. This independence assumption results in discontinuities in the latent space between categories, which removes the notion of inter-categorical proximity.
Another notable family of generative models, Variational Autoencoders (VAEs), is trained by maximizing a variational lower bound on the data log-likelihood, using an approximate posterior over latent representations given the data, instead of an adversarial objective. As VAEs integrate inference, semi-supervised classification can be performed by conditioning the continuous latent variable of the VAE on the class label (Kingma et al., 2014; Dilokthanakul et al., 2016; Maaløe et al., 2017). However, the quality of VAE results depends on the expressiveness of the inference distribution, and every time the assumptions about the inference or data distributions are changed, a new objective function needs to be derived. In this respect, variational optimization is less versatile than adversarial training.
We present the Adversarially Learned Mixture Model (AMM). The AMM is, to our knowledge, the first generative model that infers both continuous and categorical latent variables to perform either unsupervised or semi-supervised clustering of data using a single adversarial objective. This is enabled, in part, by explicitly modelling the dependence between continuous and categorical latent variables, which eliminates discontinuities between categories in the latent space. Semi-supervised clustering and classification are enabled by a simplified formulation of the “three player game” presented by Li et al. (2017). In this paper we show that the AMM achieves a state-of-the-art unsupervised clustering error rate on the MNIST dataset (LeCun & Cortes, 2010), and that it achieves competitive results for semi-supervised classification on the SVHN dataset (Netzer et al., 2011).
The ALI and BiGAN models are trained by matching two joint distributions of images $x$ and their latent code $z$. The two distributions to be matched are the inference distribution $q(x, z)$ and the synthesis distribution $p(x, z)$, where

$$q(x, z) = q(x)\, q(z \mid x), \tag{1}$$

$$p(x, z) = p(z)\, p(x \mid z). \tag{2}$$

Samples of $x$ are drawn from the training data and samples of $z$ are drawn from a prior distribution, usually $\mathcal{N}(0, I)$. Samples from $q(z \mid x)$ and $p(x \mid z)$ are drawn from neural networks that are optimized during training. Dumoulin et al. (2016) show that sampling from $q(z \mid x)$ is possible by employing the reparametrization trick (Kingma & Welling, 2013), i.e., computing

$$z = \mu(x) + \sigma(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I), \tag{3}$$

where $\odot$ is element-wise vector multiplication.
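As a concrete illustration of the reparametrization trick, the following Python sketch draws one latent sample from a diagonal Gaussian encoder output. It is not the ALI implementation: the names `reparameterize`, `mu`, and `log_sigma` are hypothetical, and parameterizing the scale through its logarithm is an assumption (a common choice that keeps $\sigma$ positive).

```python
import math
import random


def reparameterize(mu, log_sigma, rng=None):
    """Return z = mu + sigma * eps with eps ~ N(0, I), computed element-wise.

    mu and log_sigma are hypothetical stand-ins for the encoder outputs
    mu(x) and log sigma(x); in a real model a neural network produces them.
    """
    rng = rng or random.Random()
    # All randomness lives in eps, so z is a deterministic, differentiable
    # function of mu and log_sigma.
    return [m + math.exp(ls) * rng.gauss(0.0, 1.0) for m, ls in zip(mu, log_sigma)]


z = reparameterize([0.5, -1.0, 2.0], [-1.0, 0.0, 0.5], random.Random(0))
print(z)
```

Isolating the noise in $\epsilon$ is what lets gradients flow from $z$ back into the encoder parameters.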
A conditional variant of ALI has also been explored by Dumoulin et al. (2016), in which an observed class-conditional categorical variable $y$ is introduced. The factorizations of the two distributions to be matched are:

$$q(x, y, z) = q(x, y)\, q(z \mid x, y), \tag{4}$$

$$p(x, y, z) = p(y)\, p(z)\, p(x \mid y, z). \tag{5}$$

Samples of $(x, y)$ are drawn from the data. Samples of $z$ are drawn from a continuous prior on $z$, and samples of $y$ are drawn from a categorical prior on $y$; the two priors are marginally independent. Samples from $q(z \mid x, y)$ and $p(x \mid y, z)$ are drawn from neural networks that are optimized during training.
In the following sections we present graphical models for $q(x, y, z)$ and $p(x, y, z)$ that build on conditional ALI. Whereas conditional ALI requires the categorical variable to be fully observed, the models we present account for both unobserved and partially observed categorical variables. Finally, we show how they can be optimized using a single adversarial objective.
2.2 Adversarially Learned Mixture Model
The AMM is an adversarial generative model for deep unsupervised clustering of data. Figure 1 presents an overview of the model.
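Like conditional ALI, a categorical variable $y$ is introduced to model the labels. However, the unsupervised setting requires a different factorization of the inference distribution in order to enable inference of the categorical variable $y$, namely:

$$\begin{align}
q(x, y, z) &= q(x)\, q(y, z \mid x) \tag{6}\\
&= q(x)\, q(y \mid x)\, q(z \mid x, y). \tag{7}
\end{align}$$

Samples of $x$ are drawn from the training data, and samples from $q(y \mid x)$ and $q(z \mid x, y)$ are generated by neural networks. We follow Kendall & Gal (2017) and sample the class logits $\hat{y}$ by computing

$$\hat{y} = \mu_y(x) + \sigma_y(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I). \tag{8}$$

Then, we can sample from $q(y \mid x)$ by computing

$$y = \operatorname{softmax}(\hat{y}). \tag{9}$$

A similar sampling strategy can be used to sample from $q(z \mid x, y)$ in (7).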
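The logit-perturbation step can be sketched in plain Python as follows, assuming that $q(y \mid x)$ is sampled by adding Gaussian noise to class logits and applying a softmax, in the spirit of Kendall & Gal (2017). The helpers `softmax` and `sample_y` are hypothetical, the logit mean and log-scale would in practice come from the encoder network, and the sketch returns the relaxed (soft) class vector rather than a hard one-hot label.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_y(logit_mu, logit_log_sigma, rng=None):
    """Corrupt the logits with Gaussian noise, y_hat = mu + sigma * eps, then softmax."""
    rng = rng or random.Random()
    y_hat = [m + math.exp(ls) * rng.gauss(0.0, 1.0)
             for m, ls in zip(logit_mu, logit_log_sigma)]
    return softmax(y_hat)


y = sample_y([2.0, 0.5, -1.0], [-2.0, -2.0, -2.0], random.Random(0))
print(y)  # a point on the probability simplex, concentrated on class 0
```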
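The factorization of the synthesis distribution also differs from conditional ALI:

$$\begin{align}
p(x, y, z) &= p(y, z)\, p(x \mid y, z) \tag{10}\\
&= p(y)\, p(z \mid y)\, p(x \mid y, z). \tag{11}
\end{align}$$

The product $p(y)\, p(z \mid y)$ can conveniently be given by a mixture model. Samples from $p(y)$ are drawn from a multinomial prior, and samples from $p(z \mid y)$ are drawn from a continuous prior, for example, $\mathcal{N}(\mu_y, \Sigma_y)$. Alternatively, samples from $p(z \mid y)$ can be generated by a neural network, again employing the reparameterization trick, namely

$$z = \mu(y) + \sigma(y) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I). \tag{12}$$

This approach effectively learns the parameters of $p(z \mid y)$.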
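Ancestral sampling from the mixture prior $p(y)\,p(z \mid y)$ can be sketched as below. This is an illustrative sketch under stated assumptions: `sample_prior` is a hypothetical helper, the component means are stored in a lookup table (one way to realize a learned $\mu(y)$ for a one-hot $y$), and $\sigma(y)$ is fixed to 1, as in the MNIST experiment.

```python
import random


def sample_prior(class_probs, means, rng=None):
    """Ancestral sample from p(y) p(z | y) with p(z | y) = N(mu_y, I).

    means[k] is a hypothetical lookup-table stand-in for mu(y) with y = k.
    """
    rng = rng or random.Random()
    # y ~ Cat(class_probs)
    y = rng.choices(range(len(class_probs)), weights=class_probs, k=1)[0]
    # z = mu_y + eps, eps ~ N(0, I): the reparameterized form with sigma(y) = 1
    z = [m + rng.gauss(0.0, 1.0) for m in means[y]]
    return y, z


means = [[-3.0, 0.0], [3.0, 0.0]]  # hypothetical 2-component, 2-D mixture
y, z = sample_prior([0.5, 0.5], means, random.Random(0))
print(y, z)
```

Because $z$ is a deterministic function of $\mu_y$ plus noise, the component means can be trained by gradient descent, which is the point of (12).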
2.2.1 Adversarial Value Function
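We follow Dumoulin et al. (2016) and define the value function that describes the unsupervised game between the discriminator $D$ and the generator $G$ as:

$$\min_G \max_D V(D, G) = \mathbb{E}_{q(x)\, q(y \mid x)\, q(z \mid x, y)}\big[\log D(x, y, z)\big] + \mathbb{E}_{p(y)\, p(z \mid y)\, p(x \mid y, z)}\big[\log\big(1 - D(x, y, z)\big)\big].$$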
There are four generators in total: two for the encoder, $G_y(x)$ and $G_z(x, y)$, which map the data samples to the latent space; and two for the decoder, $G_z(y)$ and $G_x(y, z)$, which map samples from the prior to the input space. $G_z(y)$ can either be a learned function or be specified by a known prior. See Algorithm 1 for a detailed description of the optimization procedure.
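For intuition about the adversarial objective, the sketch below computes a Monte Carlo estimate of $V(D, G)$ from discriminator outputs, assuming the value function has the standard ALI form, with $\log D$ on inference tuples and $\log(1 - D)$ on synthesis tuples. `value_function` is a hypothetical helper operating on probabilities already produced by a discriminator; a real training loop would compute these terms on tensors and backpropagate through them.

```python
import math


def value_function(d_inference, d_synthesis, eps=1e-12):
    """Monte Carlo estimate of V(D, G) from discriminator probabilities.

    d_inference: D(x, y, z) on tuples from q(x) q(y|x) q(z|x, y) (real x, inferred y, z)
    d_synthesis: D(x, y, z) on tuples from p(y) p(z|y) p(x|y, z) (prior y, z, generated x)
    eps clips probabilities away from 0 so the logarithms stay finite.
    """
    term_q = sum(math.log(max(d, eps)) for d in d_inference) / len(d_inference)
    term_p = sum(math.log(max(1.0 - d, eps)) for d in d_synthesis) / len(d_synthesis)
    return term_q + term_p


# The discriminator ascends this value; the generators descend it.
print(value_function([0.5, 0.5], [0.5]))  # equals -log(4) at the chance-level equilibrium
```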
2.3 Semi-Supervised Adversarially Learned Mixture Model
The Semi-Supervised Adversarially Learned Mixture Model (SAMM) is an adversarial generative model for supervised or semi-supervised clustering and classification of data. The objective for training SAMM involves two adversarial games to match pairs of joint distributions. The supervised game matches inference distribution (4) to synthesis distribution (11) and is described by the following value function:

$$\min_G \max_D V_{\text{sup}}(D, G) = \mathbb{E}_{q(x, y)\, q(z \mid x, y)}\big[\log D(x, y, z)\big] + \mathbb{E}_{p(y)\, p(z \mid y)\, p(x \mid y, z)}\big[\log\big(1 - D(x, y, z)\big)\big].$$
The generator for semi-supervised learning has three components: two encoders, which map the labeled and unlabeled data samples, respectively, to the latent space, and a decoder, which maps samples of $y$ and $z$ to the input space, where $p(z \mid y)$ can either be a learned function or be specified by a prior. The encoder for labeled data again consists of two generators (Figure 1). A detailed description of the training algorithm is given in Algorithm 2 of the appendix. In practice, each of the generators and the discriminator can be optimized simultaneously for both the unsupervised and semi-supervised updates.
3 Related Works
Unsupervised clustering using hybrid adversarial approaches has been proposed by both Makhzani et al. (2016) (AAE) and Chen et al. (2016) (InfoGAN). For AAE, the synthesis generator is optimized by minimizing the per-example L2 loss between input data and their reconstructions, while the inference generator is optimized using both the L2 objective and an adversarial objective. For InfoGAN, the inference generator is optimized by maximizing the per-example mutual information (MI) between samples of the categorical and continuous latent variables and their “reconstructions”, while the synthesis generator is optimized using both the MI objective and an adversarial objective.
At the other end of the generative spectrum, Dilokthanakul et al. (2016) and Jiang et al. (2016) offer non-adversarial, VAE-based approaches to unsupervised clustering. As in the AMM, the combination of priors for the latent variables $y$ and $z$ is modeled as a Gaussian mixture model, where $y$ indexes the mixture components.
Multiple adversarial methods have been proposed for supervised or semi-supervised learning (Springenberg, 2015; Salimans et al., 2016; Miyato et al., 2017), but they share the limitation of the original GAN: they do not provide inference. Gan et al. (2017), Li et al. (2017) and Deng et al. (2017) introduce a third player to the adversarial game. Although this extra player makes it possible to infer categorical variables, these approaches are not fully adversarial, as auxiliary “collaborative” terms are added to the objective function. Moreover, categorical and continuous latent variables are modeled independently.
The adversarial and hybrid-adversarial approaches discussed so far all model $y$ and $z$ as conditionally independent of each other. This may be an ideal prior structure for inference, for example, when learning disentangled representations of $x$ sampled from a limited domain (Chen et al., 2016). However, the independence assumption cannot account for proximity between categories, because $z$ is identically distributed for each category in $y$. Therefore, the distance between categories is equal and indeterminate. AMM and SAMM are presented as adversarial approaches that model the conditional dependence between $y$ and $z$, but they do not preclude the independence assumption. The proposed methods can model $y$ and $z$ as conditionally independent with inference distribution

$$q(x, y, z) = q(x)\, q(y \mid x)\, q(z \mid x)$$

and synthesis distribution

$$p(x, y, z) = p(y)\, p(z)\, p(x \mid y, z);$$

however, analysis of this graphical model is left for future work.
AMM and SAMM are evaluated using two image datasets: MNIST (LeCun & Cortes, 2010) and SVHN (Netzer et al., 2011). The provided training and testing splits are used for MNIST experiments with 5000 randomly selected examples left out of the training set for validation. The same training, testing, and validation splits as Dumoulin et al. (2016) are used for SVHN. Preprocessing is limited to scaling image intensities on the range . Detailed architectures for each experiment are shown in figure 6 of the appendix. We optimize all networks using Adam (Kingma & Ba, 2014) with and
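The MNIST validation split described above can be sketched as a seeded random hold-out over the 60,000 MNIST training indices. The function name `split_train_val` and the fixed seed are assumptions for illustration; the text only states that the 5000 held-out examples were selected at random.

```python
import random


def split_train_val(num_train, num_val=5000, seed=0):
    """Randomly hold out num_val training indices for validation (seeded for reproducibility)."""
    indices = list(range(num_train))
    random.Random(seed).shuffle(indices)
    val_idx = sorted(indices[:num_val])
    train_idx = sorted(indices[num_val:])
    return train_idx, val_idx


train_idx, val_idx = split_train_val(60000)
print(len(train_idx), len(val_idx))  # 55000 5000
```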
4.1 Gradient Penalty
The gradient penalty introduced by Gulrajani et al. (2017) is added to the discriminator loss to help stabilize training of AMM and SAMM models. This penalty keeps the gradients of the discriminator with respect to the inputs $x$, $y$, and $z$ on the same order of magnitude. The penalty applied to the discriminator loss is

$$\lambda\, \mathbb{E}_{(\tilde{x}, \tilde{y}, \tilde{z})}\Big[\big(\lVert \nabla_{(\tilde{x}, \tilde{y}, \tilde{z})} D(\tilde{x}, \tilde{y}, \tilde{z}) \rVert_2 - 1\big)^2\Big],$$

where points $(\tilde{x}, \tilde{y}, \tilde{z})$ are drawn at random on straight lines between real or prior samples and synthesized or inferred samples. The gradient penalty for Jensen-Shannon GANs introduced by Roth et al. (2017) was also explored but did not produce better results. The regularization term is set to , and for MNIST and SVHN experiments, respectively.
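A minimal sketch of the gradient penalty for a single pair of tuples is given below. To stay dependency-free it approximates the gradient with central finite differences instead of automatic differentiation, and it treats each $(x, y, z)$ tuple as one flattened list of floats; both are simplifications for illustration. `gradient_penalty` and `linear_disc` are hypothetical names, and a real implementation would average the penalty over a batch of interpolates.

```python
import random


def gradient_penalty(disc, real, fake, lam, h=1e-5, rng=None):
    """lam * (||grad disc(v_hat)||_2 - 1)^2 at one random interpolate v_hat.

    real and fake are flattened (x, y, z) tuples given as equal-length lists of
    floats. The gradient is approximated with central finite differences.
    """
    rng = rng or random.Random()
    t = rng.random()  # interpolation coefficient, uniform on [0, 1)
    v_hat = [t * r + (1.0 - t) * f for r, f in zip(real, fake)]
    sq_norm = 0.0
    for i in range(len(v_hat)):
        plus = list(v_hat)
        plus[i] += h
        minus = list(v_hat)
        minus[i] -= h
        grad_i = (disc(plus) - disc(minus)) / (2.0 * h)
        sq_norm += grad_i ** 2
    return lam * (sq_norm ** 0.5 - 1.0) ** 2


def linear_disc(v):
    """Hypothetical linear critic whose gradient has norm 2 everywhere."""
    return 2.0 * v[0]


print(gradient_penalty(linear_disc, [1.0, 0.0], [0.0, 1.0], lam=10.0, rng=random.Random(0)))
# approximately lam * (2 - 1)^2 = 10
```

A critic whose gradient already has unit norm incurs (numerically) zero penalty, which is the regime the penalty pushes the discriminator toward.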
In this section, the AMM is evaluated on the task of unsupervised clustering of handwritten digits using the MNIST dataset. To model $p(z \mid y)$, a 10-component, 64-dimensional mixture of Gaussians is used. A multinomial prior with uniform probability for each class is used for $p(y)$. The means of the component distributions are learned using the reparameterization trick via (12), and the variance of each component is fixed to unit value. Table 1 reports the test-set clustering error-rate mean and variance over 5 trials. The AMM achieves percent error rate, which is an improvement over the state of the art. Figure 2 shows visualizations of results from 1 of the 5 trials.
4.3.1 Unsupervised Clustering
In this section, unsupervised clustering is revisited. The SVHN dataset is used to investigate how the introduction of confounding attributes, such as color and contrast, affects the semantic separation of digits. To model $p(z \mid y)$, a 32-dimensional mixture of 18 spherical, unit-variance Gaussians is used. A multinomial prior with uniform probability for each class is used for $p(y)$. The means of the components are regularly spaced at intervals of 6 units, from -6 to 6 along the first two dimensions and from -3 to 3 along the third dimension, giving 3 × 3 × 2 = 18 means. The trailing 29 dimensions of each mean are set to 0.
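The 18 fixed component means can be generated programmatically. This sketch assumes the grid values {-6, 0, 6} for the first two dimensions and {-3, 3} for the third, which is our reading of the spacing described above; the assignment of grid points to component indices is arbitrary here, and `svhn_unsupervised_means` is a hypothetical name.

```python
import itertools


def svhn_unsupervised_means(dim=32):
    """18 component means: a {-6, 0, 6} x {-6, 0, 6} x {-3, 3} grid, zero-padded to dim."""
    grid = itertools.product([-6.0, 0.0, 6.0], [-6.0, 0.0, 6.0], [-3.0, 3.0])
    return [list(point) + [0.0] * (dim - 3) for point in grid]


means = svhn_unsupervised_means()
print(len(means), len(means[0]))  # 18 32
```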
Figure 2(a) shows random samples generated by the decoder from each component distribution. We can see four distinct groupings based on the global features of SVHN examples. The top row and the last three columns of the bottom row show images with dark backgrounds and light numbers. The middle row and the first three columns of the bottom row show images with light backgrounds and dark numbers. Looking more closely at the top two rows, we see a nearly symmetric clustering by digit. For example, the first column contains clusters corresponding to zero, the second column contains clusters corresponding to one, and every main grouping contains clusters that mix twos and sevens. These mixed clusters are reflected by the orange and green groupings in Figure 2(b), a t-SNE projection of test samples drawn from $q(z \mid x)$ onto two dimensions. Figure 2(c) shows that the AMM learns a smooth latent manifold as we interpolate between SVHN examples.
4.3.2 Semi-supervised clustering and classification
It is evident from the last experiment that the confounders introduced by the SVHN dataset make unsupervised semantic clustering more difficult. In this section we show how SAMM can be used to guide clustering along predefined categories using only a small amount of labeled data. To this end, we limit the samples drawn from $q(x, y)$ to a random selection of 1000 examples from the training set. To model $p(z \mid y)$, we use a 64-dimensional mixture of 10 spherical Gaussians, each with unit variance. In placing the means of each component, we take advantage of our prior knowledge of the task. For example, from figure 1(e), we can see that nines are closer to fours than they are to zeros, and we reflect these assumptions in designing $p(z \mid y)$. There is considerable class imbalance in the SVHN dataset, so a multinomial prior is used for $p(y)$, with each class probability set to the frequency observed in the training data. The placement of each mean within the continuous latent manifold is shown in Table 3 of the appendix. We also run this experiment allowing the means $\mu_y$ to be learned using equation (12).
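The class-frequency prior for $p(y)$ amounts to normalized label counts. The sketch below assumes integer labels 0 to 9 and counts over whatever label list is passed in; the text says only "the training data", without specifying whether the 1000 labeled examples or all training labels were counted. `class_frequency_prior` is a hypothetical name.

```python
from collections import Counter


def class_frequency_prior(labels, num_classes=10):
    """Multinomial prior p(y = k) equal to the empirical frequency of label k."""
    counts = Counter(labels)
    n = len(labels)
    return [counts.get(k, 0) / n for k in range(num_classes)]


print(class_frequency_prior([0, 0, 1, 3], num_classes=4))  # [0.5, 0.25, 0.0, 0.25]
```

A class that never appears receives probability zero; if such a prior were plugged into a Bayes classifier, it could never be predicted, so in practice every class should be represented in the counted labels.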
Table 2 reports the test-set error-rate mean and variance over 10 trials. SAMM achieves percent error rate with the fixed means, and when the means are learned, which is an improvement over the ALI baseline. Figure 4 shows visualizations of results from 1 of the 10 trials. Finally, given that we have defined $p(y)$ and $p(z \mid y)$, we can use Bayes’ theorem to derive

$$p(y \mid z) = \frac{p(z \mid y)\, p(y)}{\sum_{y'} p(z \mid y')\, p(y')}$$

and obtain a classifier given an image embedding $z \sim q(z \mid x)$:

$$y^{*} = \arg\max_{y}\, p(y \mid z).$$
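For the unit-variance spherical mixture used in this experiment, the Bayes classifier reduces to comparing $\log p(y) - \tfrac{1}{2}\lVert z - \mu_y \rVert^2$ across classes, because the Gaussian normalizing constant is shared by all components. The sketch below implements that computation; `posterior` and `classify` are hypothetical names, and it assumes every class probability is strictly positive.

```python
import math


def posterior(z, means, class_probs):
    """p(y | z) by Bayes' theorem when p(z | y) = N(mu_y, I) and p(y) = class_probs."""
    log_scores = []
    for mu, prior in zip(means, class_probs):
        sq_dist = sum((a - b) ** 2 for a, b in zip(z, mu))
        # log p(z | y) + log p(y), dropping the constant shared by all components
        log_scores.append(-0.5 * sq_dist + math.log(prior))
    # Normalize in log space (log-sum-exp) for numerical stability.
    top = max(log_scores)
    weights = [math.exp(s - top) for s in log_scores]
    total = sum(weights)
    return [w / total for w in weights]


def classify(z, means, class_probs):
    """Return the class index maximizing p(y | z)."""
    probs = posterior(z, means, class_probs)
    return max(range(len(probs)), key=lambda k: probs[k])


means = [[0.0, 0.0], [4.0, 0.0]]  # hypothetical 2-component, 2-D example
print(classify([3.0, 0.0], means, [0.5, 0.5]))  # 1
```

With a non-uniform prior, such as the SVHN class frequencies, the decision boundary shifts toward the rarer class.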
Figures 3(e) and 3(f) compare the confusion matrices of predictions given by the inference generator $q(y \mid x)$ and those given by $p(y \mid z)$ for embeddings $z \sim q(z \mid x)$. The similarity between the two is further evidence that the inference network has learned to embed data according to the desired distribution.
Figure: t-SNE projection of test samples, color-coded by SVHN class label (0 to 9); (e) confusion matrix for predictions given an image embedding; (f) confusion matrix for predictions given the generator.
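The confusion-matrix comparison can be sketched with two small helpers. `confusion_matrix` and `agreement` are hypothetical names; the agreement rate (the fraction of examples on which two classifiers, for instance $\arg\max q(y \mid x)$ and $\arg\max p(y \mid z)$, predict the same class) is an extra summary added for illustration and is not reported in the text.

```python
def confusion_matrix(true_labels, predicted, num_classes=10):
    """Count matrix with rows indexed by true class and columns by predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted):
        matrix[t][p] += 1
    return matrix


def agreement(pred_a, pred_b):
    """Fraction of examples on which two classifiers predict the same class."""
    if len(pred_a) != len(pred_b) or not pred_a:
        raise ValueError("prediction lists must be non-empty and of equal length")
    return sum(a == b for a, b in zip(pred_a, pred_b)) / len(pred_a)


print(confusion_matrix([0, 1, 1], [0, 1, 0], num_classes=2))  # [[1, 0], [1, 1]]
```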
The AMM is presented as a generative model for unsupervised or semi-supervised data clustering. It is the first adversarially optimized method to model the conditional dependence between categorical and continuous latent variables. The AMM achieves state-of-the-art unsupervised clustering results and competitive semi-supervised classification results on benchmark datasets.
- Chen et al. (2016) Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, and Pieter Abbeel. InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems 29, pp. 2172–2180. 2016.
- Deng et al. (2017) Zhijie Deng, Hao Zhang, Xiaodan Liang, Luona Yang, Shizhen Xu, Jun Zhu, and Eric P. Xing. Structured generative adversarial networks. In Advances in Neural Information Processing Systems 30, pp. 3902–3912. 2017.
- Dilokthanakul et al. (2016) Nat Dilokthanakul, Pedro A. M. Mediano, Marta Garnelo, Matthew C. H. Lee, Hugh Salimbeni, Kai Arulkumaran, and Murray Shanahan. Deep unsupervised clustering with Gaussian mixture variational autoencoders. arXiv preprint arXiv:1611.02648, 2016.
- Donahue et al. (2017) Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. Adversarial feature learning. In International Conference on Learning Representations, 2017.
- Dumoulin et al. (2016) Vincent Dumoulin, Ishmael Belghazi, Ben Poole, Alex Lamb, Martin Arjovsky, Olivier Mastropietro, and Aaron Courville. Adversarially learned inference. In International Conference on Learning Representations, 2016.
- Gan et al. (2017) Zhe Gan, Liqun Chen, Weiyao Wang, Yuchen Pu, Yizhe Zhang, and Lawrence Carin. Triangle generative adversarial networks. In Advances in Neural Information Processing Systems 30, pp. 5251–5260. 2017.
- Goodfellow et al. (2014) Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems 27, pp. 2672–2680. 2014.
- Gulrajani et al. (2017) Ishaan Gulrajani, Faruk Ahmed, Martín Arjovsky, Vincent Dumoulin, and Aaron C. Courville. Improved training of Wasserstein GANs. arXiv preprint arXiv:1704.00028, 2017.
- Jiang et al. (2016) Zhuxi Jiang, Yin Zheng, et al. Variational deep embedding: An unsupervised and generative approach to clustering. arXiv preprint arXiv:1611.05148, 2016.
- Kendall & Gal (2017) Alex Kendall and Yarin Gal. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems 30, pp. 5580–5590. 2017.
- Kingma & Welling (2013) Diederik Kingma and Max Welling. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2013.
- Kingma & Ba (2014) Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
- Kingma et al. (2014) Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems 27, pp. 3581–3589. 2014.
- LeCun & Cortes (2010) Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010. URL http://yann.lecun.com/exdb/mnist/.
- Li et al. (2017) Chongxuan Li, Kun Xu, Jun Zhu, and Bo Zhang. Triple generative adversarial nets. arXiv preprint arXiv:1703.02291, 2017.
- Maaløe et al. (2017) Lars Maaløe, Marco Fraccaro, and Ole Winther. Semi-supervised generation with cluster-aware generative models. arXiv preprint arXiv:1704.00637, 2017.
- Makhzani et al. (2016) Alireza Makhzani, Jonathon Shlens, Navdeep Jaitly, and Ian Goodfellow. Adversarial autoencoders. In International Conference on Learning Representations, 2016.
- Miyato et al. (2017) Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual adversarial training: a regularization method for supervised and semi-supervised learning. arXiv preprint arXiv:1704.03976, 2017.
- Netzer et al. (2011) Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.
- Roth et al. (2017) Kevin Roth, Aurelien Lucchi, Sebastian Nowozin, and Thomas Hofmann. Stabilizing training of generative adversarial networks through regularization. In Advances in Neural Information Processing Systems 30, pp. 2015–2025. 2017.
- Salimans et al. (2016) Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training GANs. In Advances in Neural Information Processing Systems 29, pp. 2234–2242. 2016.
- Springenberg (2015) Jost Tobias Springenberg. Unsupervised and semi-supervised learning with categorical generative adversarial networks. arXiv preprint arXiv:1511.06390, 2015.
Appendix A SAMM Algorithm
Algorithm 2 outlines the SAMM training procedure.
Appendix B Experiment Information
B.1 Model Architectures
B.2 Mean Placement
The placement of each mean for the fixed-mean semi-supervised SVHN experiment is shown in Table 3.