Adversarially Learned Mixture Model

07/14/2018 ∙ by Andrew Jesson, et al. ∙ Imagia Cybernetics Inc.

The Adversarially Learned Mixture Model (AMM) is a generative model for unsupervised or semi-supervised data clustering. The AMM is the first adversarially optimized method to model the conditional dependence between inferred continuous and categorical latent variables. Experiments on the MNIST and SVHN datasets show that the AMM allows for semantic separation of complex data when little or no labeled data is available. The AMM achieves a state-of-the-art unsupervised clustering error rate of 2.86% on the MNIST dataset. A semi-supervised extension of the AMM yields competitive results on the SVHN dataset.


1 Introduction

Semi-supervised or unsupervised representation learning enables the use of all available data when tackling problems for which few or no labeled examples exist. This is a common scenario in many applications of machine learning, such as medical image analysis, where it is compounded by the expense of obtaining expert-labeled examples. Moreover, machine-learned representations are more likely to be used for subsequent tasks if they are interpretable and meaningful. Deep generative modelling is a suitable approach to this problem, as derived models have been shown capable of learning from both labeled and unlabeled examples, embedding data according to desired latent variable distributions, and producing realistic data examples generated from samples of those latent variables.

The Generative Adversarial Network (GAN) has recently emerged as a powerful framework for modeling complex data distributions without having to approximate intractable likelihoods. In the formulation by Goodfellow et al. (2014), a GAN consists of two networks: a generator that is trained to synthesize samples resembling the data distribution, and a discriminator that is trained to distinguish between generated and true data samples.

Dumoulin et al. (2016) and Donahue et al. (2017) have proposed the ALI and BiGAN models, which add an inference process, i.e., the ability to map data samples to points in the latent space, to the GAN framework. A second generator for inference, or encoder, is added to the original GAN generator, and the discriminator is adapted to the joint space of data inputs and latent representations. A variant of the resulting model is also introduced by Dumoulin et al. (2016) for conditional data generation, but it still assumes that the class of the data is always observed, as inference of categorical variables is not included.

Adversarial approaches for the inference of both continuous and categorical variables are actively researched. Chen et al. (2016) introduce a hybrid adversarial method that is capable of modelling both continuous and categorical latent variables for unsupervised clustering and feature disentanglement. Another hybrid adversarial method is introduced by Makhzani et al. (2016) where adversarial objectives on continuous and categorical latent variables are optimized for unlabeled examples and categorical cross entropy on categorical variables is optimized for labeled examples. Li et al. (2017) and Deng et al. (2017) point toward fully adversarial semi-supervised classification using inferred categorical variables by introducing a “three player” adversarial game, but stop short by adding auxiliary “collaborative” objectives. In each of these methods, it is assumed that categorical and continuous latent variables are independently distributed. This independence assumption results in discontinuities in the latent space between categories, which removes the notion of inter-categorical proximity.

Another notable family of generative models, Variational Autoencoders (VAEs), maximize the posterior distribution of latent representations given the data instead of using an adversarial approach. As VAEs integrate inference, semi-supervised classification can be performed by conditioning the continuous latent variable of the VAE on the class label (Kingma et al., 2014; Dilokthanakul et al., 2016; Maaløe et al., 2017). However, the quality of VAE results depends on the expressiveness of the inference distribution, and every time the assumptions about the inference or data distributions are changed, a new objective function needs to be derived. In this way, variational optimization is not as versatile as adversarial training.

We present the Adversarially Learned Mixture Model (AMM). The AMM is, to our knowledge, the first generative model inferring both continuous and categorical latent variables to perform either unsupervised or semi-supervised clustering of data using a single adversarial objective. This is enabled, in part, by explicitly modelling the dependence between continuous and categorical latent variables, which eliminates discontinuities between categories in the latent space. Semi-supervised clustering and classification are enabled by a simplified formulation of the “three player game” presented by Li et al. (2017). In this paper we show that the AMM achieves a state-of-the-art unsupervised clustering error rate on the MNIST dataset (LeCun & Cortes, 2010), and that it achieves competitive results for semi-supervised classification on the SVHN dataset (Netzer et al., 2011).

2 Method

2.1 Preliminaries

The ALI and BiGAN models are trained by matching two joint distributions of images x and their latent codes z. The two distributions to be matched are the inference distribution q(x, z) and the synthesis distribution p(x, z), where

q(x, z) = q(x) q(z|x),    (1)
p(x, z) = p(z) p(x|z).    (2)

Samples of x are drawn from the training data and samples of z are drawn from a prior distribution, usually N(0, I). Samples from q(z|x) and p(x|z) are drawn from neural networks that are optimized during training.

Dumoulin et al. (2016) show that sampling from q(z|x) is possible by employing the reparametrization trick (Kingma & Welling, 2013), i.e. computing

z = μ(x) + σ(x) ⊙ ε,  ε ~ N(0, I),    (3)

where ⊙ is element-wise vector multiplication.
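As a concrete illustration, a minimal PyTorch-style sketch of the reparametrized sampling in (3) might look as follows; the encoder architecture, layer sizes, and names are placeholders rather than the networks used in the paper.

```python
import torch
import torch.nn as nn


class GaussianEncoder(nn.Module):
    """Hypothetical encoder: maps x to the mean and log-variance of q(z|x)
    and returns a reparameterized sample, as in Eq. (3)."""

    def __init__(self, x_dim: int, z_dim: int, hidden: int = 256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.body(x)
        mu, log_var = self.mu(h), self.log_var(h)
        eps = torch.randn_like(mu)                  # eps ~ N(0, I)
        return mu + torch.exp(0.5 * log_var) * eps  # element-wise product


# Example usage: z = GaussianEncoder(x_dim=784, z_dim=64)(torch.randn(16, 784))
```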

A conditional variant of ALI has also been explored by Dumoulin et al. (2016), where an observed class-conditional categorical variable y is introduced. The joint factorizations of the two distributions to be matched are:

q(x, y, z) = q(x, y) q(z|x, y),    (4)
p(x, y, z) = p(y) p(z) p(x|y, z).    (5)

Samples of (x, y) are drawn from the data. Samples of z are drawn from a continuous prior on z, and samples of y are drawn from a categorical prior on y, both of which are marginally independent. Samples from q(z|x, y) and p(x|y, z) are drawn from neural networks that are optimized during training.

In the following sections we present graphical models for q(x, y, z) and p(x, y, z) that build on conditional ALI. Where conditional ALI requires full observation of the categorical variable, the models we present account for both unobserved and partially observed categorical variables. We finally show how they can be optimized using a single adversarial objective.

2.2 Adversarially Learned Mixture Model

The AMM is an adversarial generative model for deep unsupervised clustering of data. Figure 1 presents an overview of the model.

Figure 1: Overview of the unsupervised (AMM) and semi-supervised (SAMM) models with the first option (Equation (6)) for the inference distribution. AMM consists of two generators, an encoder and a decoder, and a discriminator. SAMM includes an additional encoder for labeled data.

Like conditional ALI, a categorical variable y is introduced to model the labels. However, the unsupervised setting now requires a different factorization of the inference distribution in order to enable inference of the categorical variable y, namely:

q(x, y, z) = q(x) q(y|x) q(z|x, y),    (6)

or

q(x, y, z) = q(x) q(z|x) q(y|x, z).    (7)

Samples of x are drawn from the training data, and samples from the conditionals q(y|x), q(z|x, y), q(z|x), and q(y|x, z) are generated by neural networks. We follow Kendall & Gal (2017) and sample from q(y|x) by computing

ŷ = μ_y(x) + σ_y(x) ⊙ ε,  ε ~ N(0, I),    (8)
y = softmax(ŷ).    (9)

Then, we can sample from q(z|x, y) by computing

z = μ_z(x, y) + σ_z(x, y) ⊙ ε,  ε ~ N(0, I).    (10)

A similar sampling strategy can be used to sample from the conditionals in (7).
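A possible PyTorch sketch of the sampling in (8)-(10) is shown below: Gaussian noise is injected on the class logits before the softmax, and z is then sampled conditioned on both x and the sampled y. The module names, architectures, and dimensions are illustrative assumptions, not the paper's networks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CategoricalHead(nn.Module):
    """Samples y ~ q(y|x) by perturbing class logits with Gaussian noise and
    applying a softmax, as in Eqs. (8)-(9)."""

    def __init__(self, x_dim: int, n_classes: int, hidden: int = 256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, n_classes)
        self.log_var = nn.Linear(hidden, n_classes)

    def forward(self, x):
        h = self.body(x)
        mu, log_var = self.mu(h), self.log_var(h)
        logits = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        return F.softmax(logits, dim=-1)  # soft (relaxed) one-hot sample of y


class ConditionalGaussianHead(nn.Module):
    """Samples z ~ q(z|x, y) with the reparameterization trick, as in Eq. (10)."""

    def __init__(self, x_dim: int, n_classes: int, z_dim: int, hidden: int = 256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim + n_classes, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)

    def forward(self, x, y):
        h = self.body(torch.cat([x, y], dim=-1))
        mu, log_var = self.mu(h), self.log_var(h)
        return mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)


# Example usage:
# y = CategoricalHead(784, 10)(x)
# z = ConditionalGaussianHead(784, 10, 64)(x, y)
```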

The factorization of the synthesis distribution also differs from conditional ALI:

(11)

The product can be conveniently given by a mixture model. Samples from are drawn from a multinomial prior, and samples from are drawn from a continuous prior, for example, . Samples from can alternatively be generated by a neural network by again employing the reparameterization trick. Namely,

(12)

This approach effectively learns the parameters of .
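A minimal sketch of such a class-conditional prior is given below, assuming unit-variance components and a uniform p(y) (the paper also uses a class-frequency prior for SVHN); the means can be fixed buffers or learnable parameters, with sampling done by reparameterization as in (12). Names and defaults are illustrative.

```python
import torch
import torch.nn as nn


class MixturePrior(nn.Module):
    """p(y) p(z|y): a multinomial prior over components and one unit-variance
    Gaussian per component, with means either fixed or learned (Eq. (12))."""

    def __init__(self, n_components: int, z_dim: int, learn_means: bool = True):
        super().__init__()
        means = torch.randn(n_components, z_dim)
        if learn_means:
            self.means = nn.Parameter(means)         # learned component means
        else:
            self.register_buffer("means", means)     # fixed, user-specified means
        self.register_buffer("probs", torch.full((n_components,), 1.0 / n_components))

    def sample(self, batch_size: int):
        y = torch.multinomial(self.probs, batch_size, replacement=True)  # y ~ p(y)
        eps = torch.randn(batch_size, self.means.size(1), device=self.means.device)
        z = self.means[y] + eps                      # z ~ N(mu_y, I)
        return y, z


# Example usage: y, z = MixturePrior(n_components=10, z_dim=64).sample(32)
```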

2.2.1 Adversarial Value Function

We follow Dumoulin et al. (2016) and define the value function that describes the unsupervised game between the discriminator D and the generators as:

min_G max_D V(D, G) = E_{q(x) q(y|x) q(z|x,y)}[ log D(x, y, z) ] + E_{p(y) p(z|y) p(x|y,z)}[ log(1 − D(x, y, z)) ].    (13)

There are four generators in total: two for the encoder, one inferring y from x and one inferring z from x and y, which map the data samples to the latent space; and two for the decoder, one producing z from y and one producing x from y and z, which map samples from the priors to the input space. The generator producing z from y can either be a learned function or be specified by a known prior. See Algorithm 1 for a detailed description of the optimization procedure.

Initialize AMM parameters
while not done do
      Sample x from the training data; sample y from p(y) and z from p(z|y)
      Sample ŷ from q(y|x) and ẑ from q(z|x, ŷ); sample x̂ from p(x|y, z)
      Compute discriminator predictions for the inference tuple (x, ŷ, ẑ) and the synthesis tuple (x̂, y, z)
      Compute the discriminator losses
      Compute the x generator losses
      Compute the y and z generator losses
      Update the discriminator parameters
      Update the generator parameters
end while
Algorithm 1: AMM training procedure using distributions (6) and (11).
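Pulling these pieces together, the following PyTorch sketch shows one plausible implementation of a single AMM iteration under the value function in (13), using the non-saturating form of the GAN losses. It assumes component modules like those sketched above (a head for q(y|x), one for q(z|x, y), a mixture prior for p(y)p(z|y), a decoder for p(x|y, z), and a joint discriminator over (x, y, z)); it omits the gradient penalty of Section 4.1 and is not the authors' exact procedure.

```python
import torch
import torch.nn.functional as F


def amm_training_step(x_real, enc_y, enc_z, prior, dec_x, disc, opt_d, opt_g):
    """One simplified AMM update: match the inference joint q(x)q(y|x)q(z|x,y)
    against the synthesis joint p(y)p(z|y)p(x|y,z) through a joint discriminator.
    opt_g is assumed to hold the parameters of enc_y, enc_z, dec_x, and (if
    learned) the prior means; opt_d holds the discriminator parameters."""
    n = x_real.size(0)

    # Sample from the inference (encoder) joint.
    y_q = enc_y(x_real)                       # y ~ q(y|x), soft one-hot
    z_q = enc_z(x_real, y_q)                  # z ~ q(z|x, y)

    # Sample from the synthesis (decoder) joint.
    y_idx, z_p = prior.sample(n)              # y ~ p(y), z ~ p(z|y)
    y_p = F.one_hot(y_idx, num_classes=y_q.size(1)).float()
    x_p = dec_x(y_p, z_p)                     # x ~ p(x|y, z)

    # Discriminator update: inference tuples labeled 1, synthesis tuples labeled 0.
    d_inf = disc(x_real, y_q.detach(), z_q.detach())
    d_syn = disc(x_p.detach(), y_p, z_p.detach())
    loss_d = (F.binary_cross_entropy_with_logits(d_inf, torch.ones_like(d_inf))
              + F.binary_cross_entropy_with_logits(d_syn, torch.zeros_like(d_syn)))
    opt_d.zero_grad(); loss_d.backward(); opt_d.step()

    # Generator update: encoders and decoder try to swap the discriminator's labels.
    d_inf = disc(x_real, y_q, z_q)
    d_syn = disc(dec_x(y_p, z_p), y_p, z_p)
    loss_g = (F.binary_cross_entropy_with_logits(d_inf, torch.zeros_like(d_inf))
              + F.binary_cross_entropy_with_logits(d_syn, torch.ones_like(d_syn)))
    opt_g.zero_grad(); loss_g.backward(); opt_g.step()

    return loss_d.item(), loss_g.item()
```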

2.3 Semi-Supervised Adversarially Learned Mixture Model

The Semi-Supervised Adversarially Learned Mixture Model (SAMM) is an adversarial generative model for supervised or semi-supervised clustering and classification of data. The objective for training SAMM involves two adversarial games to match pairs of joint distributions. The supervised game matches inference distribution (4) to synthesis distribution (11) and is described by the following value function:

min_G max_D V_sup(D, G) = E_{q(x,y) q(z|x,y)}[ log D(x, y, z) ] + E_{p(y) p(z|y) p(x|y,z)}[ log(1 − D(x, y, z)) ].    (14)

The unsupervised game matches either of the inference distributions, (6) or (7), to the synthesis distribution (11). In the case using distribution (6), the unsupervised game is described by (13). The generator for semi-supervised learning has three components: an encoder that maps labeled data samples to the latent space, an encoder that maps unlabeled data samples to the latent space, and a decoder that maps samples of y and z to the input space, where the samples of z given y can either come from a learned function or be specified by a prior. The encoder for labeled data again consists of two generators (Figure 1). A detailed description of the training algorithm is given in Algorithm 2 of the appendix. In practice, optimization of each of the generators and the discriminator can be done simultaneously for both the unsupervised and semi-supervised updates.

3 Related Works

Unsupervised clustering using hybrid adversarial approaches is proposed by both Makhzani et al. (2016) (AAE) and Chen et al. (2016) (InfoGAN). For AAE, the synthesis generator is optimized by minimizing the per-example L2 loss between input data and their reconstructions, while the inference generator is optimized using both the L2 objective and an adversarial objective. For InfoGAN, the inference generator is optimized by maximizing the per-example mutual information (MI) between samples of categorical and continuous latent variables and their “reconstructions”, while the synthesis generator is optimized using both the MI objective and an adversarial objective.

On the other end of the generative spectrum, Dilokthanakul et al. (2016) and Jiang et al. (2016) offer non-adversarial, VAE-based approaches for unsupervised clustering. As in the AMM, the combination of priors for the latent variables y and z is modeled as a Gaussian mixture model, where y corresponds to the mixture components.

Multiple adversarial methodologies have been proposed for supervised or semi-supervised learning (Springenberg, 2015; Salimans et al., 2016; Miyato et al., 2017), but they suffer from the same limitation as the original GAN: they do not provide inference. Gan et al. (2017), Li et al. (2017) and Deng et al. (2017) introduce a third player to the adversarial game. Although this extra player allows inference of categorical variables, these approaches are not fully adversarial, as auxiliary “collaborative” terms are added to the objective function. Moreover, categorical and continuous latent variables are modeled independently.

The adversarial and hybrid-adversarial approaches discussed thus far all model y and z as being conditionally independent of each other. This may be an ideal prior structure for inference, for example, when learning disentangled representations of data sampled from a limited domain (Chen et al., 2016). However, the independence assumption cannot account for the notion of proximity between categories, because z is identically distributed for each category in y. Therefore, the distance between categories is equal and indeterminate. AMM and SAMM are presented as adversarial approaches to model conditional dependencies between y and z, but they do not preclude the independence assumption. The proposed methods can model y and z as conditionally independent with inference distribution

q(x, y, z) = q(x) q(y|x) q(z|x)    (15)

and synthesis distribution

p(x, y, z) = p(y) p(z) p(x|y, z);    (16)

however, analysis of this graphical model is left for future work.

4 Evaluation

AMM and SAMM are evaluated using two image datasets: MNIST (LeCun & Cortes, 2010) and SVHN (Netzer et al., 2011). The provided training and testing splits are used for the MNIST experiments, with 5000 randomly selected examples left out of the training set for validation. The same training, testing, and validation splits as Dumoulin et al. (2016) are used for SVHN. Preprocessing is limited to scaling image intensities to a fixed range. Detailed architectures for each experiment are shown in figures 5 and 6 of the appendix. We optimize all networks using Adam (Kingma & Ba, 2014). All kernel weights are initialized using a Gaussian distribution with standard deviation 0.02, and all biases are initialized to 0.0.
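As an illustration of the stated initialization scheme (Gaussian kernel weights with standard deviation 0.02 and zero biases), a small PyTorch helper might look as follows; this is a sketch, not the authors' code.

```python
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Gaussian(std=0.02) kernel initialization and zero biases for conv/linear layers."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Example usage: model.apply(init_weights)
```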

4.1 Gradient Penalty

The gradient penalty introduced by Gulrajani et al. (2017) is added to the discriminator loss to help stabilize training of the AMM and SAMM models. This penalty keeps the gradients of the discriminator with respect to the inputs x, y, and z on the same order of magnitude. The penalty applied to the discriminator loss is

λ E_{(x̂, ŷ, ẑ)}[ ( ||∇_{(x̂, ŷ, ẑ)} D(x̂, ŷ, ẑ)||_2 − 1 )^2 ],    (17)

where the points (x̂, ŷ, ẑ) are drawn at random on straight lines between real or prior samples and synthesized or inferred samples. The gradient penalty for the Jensen-Shannon GAN introduced by Roth et al. (2017) has also been explored, but did not produce better results. The regularization weight λ is set to different values for the MNIST and SVHN experiments.
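A sketch of how such a penalty could be implemented in PyTorch for a joint discriminator over (x, y, z), following the standard Gulrajani et al. (2017) recipe of interpolating between the two joints, is given below; the function name and default weight are placeholders rather than the paper's settings.

```python
import torch


def gradient_penalty(disc, real_triplet, fake_triplet, lam=10.0):
    """Penalize deviation of the discriminator's input-gradient norm from 1,
    evaluated at random points on the lines between real/prior samples and
    inferred/synthesized samples."""
    interpolated = []
    for real, fake in zip(real_triplet, fake_triplet):
        alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real.device)
        point = (alpha * real + (1.0 - alpha) * fake).detach().requires_grad_(True)
        interpolated.append(point)

    scores = disc(*interpolated)
    grads = torch.autograd.grad(outputs=scores.sum(), inputs=interpolated,
                                create_graph=True)
    # Flatten and concatenate the gradients w.r.t. x, y, and z for each example.
    flat = torch.cat([g.reshape(g.size(0), -1) for g in grads], dim=1)
    return lam * ((flat.norm(2, dim=1) - 1.0) ** 2).mean()
```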

4.2 MNIST

In this section, the AMM is evaluated on the task of unsupervised clustering of hand-drawn digits using the MNIST dataset. To model p(z|y), a 10-component, 64-dimensional mixture of Gaussians is used. A multinomial prior is used for y with uniform probability for each class. The means of the component distributions are learned using the reparameterization trick via (12), and the variance of each distribution is fixed to unit value. Table 1 reports the test-set clustering error-rate mean and variance over 5 trials. The AMM achieves a 2.86 percent error rate, which is an improvement over the state of the art. Figure 2 shows visualizations of results from 1 of the 5 trials.

Figure 2: Unsupervised clustering of MNIST data with 10 mixture components. (a) Comparing test image membership and randomly generated digits for each mixture component. (b) Cluster matrix: rows correspond to true test labels, and columns correspond to component membership. (c) Reconstructions of input images: original data on the left of each pair. (d) Interpolation between examples: original data samples are shown in the first and last columns with linearly interpolated generations between. (e) t-SNE projection of testing samples, color-coded for the MNIST class labels (0 to 9).
Model | MNIST error rate (%)
CatGAN (Springenberg, 2015)
VaDE (Jiang et al., 2016)
InfoGAN (Chen et al., 2016)
AAE (Makhzani et al., 2016)
AMM
Table 1: Test set clustering error rate and standard deviation for MNIST data.

4.3 SVHN

4.3.1 Unsupervised Clustering

In this section, unsupervised clustering is revisited. The SVHN dataset is used to investigate how the introduction of confounding attributes, such as color and contrast, affects the semantic separation of digits. To model p(z|y), a 32-dimensional mixture of 18 spherical, unit-variance Gaussians is used. A multinomial prior is used for y with uniform probability for each class. The means of each distribution are regularly spaced at intervals of 6 units, from -6 to 6 along the first two dimensions and from -3 to 3 along the third dimension. The trailing 29 dimensions are set to 0 for each mean.
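For concreteness, this placement amounts to a 3 × 3 × 2 grid ({-6, 0, 6} along the first two dimensions and {-3, 3} along the third), giving the 18 component means. A short sketch of how such a grid could be constructed is shown below; the helper and its names are illustrative, not taken from the paper.

```python
import itertools
import torch


def svhn_prior_means(z_dim: int = 32) -> torch.Tensor:
    """Build the 18 fixed component means: a 3 x 3 x 2 grid over the first
    three latent dimensions, zeros in the trailing dimensions."""
    grid = itertools.product([-6.0, 0.0, 6.0], [-6.0, 0.0, 6.0], [-3.0, 3.0])
    means = torch.zeros(18, z_dim)
    for k, (a, b, c) in enumerate(grid):
        means[k, 0], means[k, 1], means[k, 2] = a, b, c
    return means
```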

Figure 3(a) shows random samples drawn from each component distribution and decoded by p(x|y, z). We can see four distinct groupings based on the global features of SVHN examples. The top row and last three columns of the bottom row show images with dark backgrounds and light numbers. The middle row and first three columns of the last row show images with light backgrounds and dark numbers. Looking closer at the top two rows, we see a nearly symmetric clustering based on number. For example, in the first column we see clusters corresponding to zero, in the second column we see clusters corresponding to one, and in all of the main groupings we see clusters with numbers two and seven together. The clusters that combine twos and sevens are reflected by the orange and green groupings in Figure 3(b), which is a t-SNE projection of the inferred latent codes of test samples onto a 2D manifold. We show in Figure 3(c) that the AMM learns a smooth latent manifold as we interpolate between examples from SVHN.

Figure 3: Unsupervised clustering of SVHN data with 18 mixture components. (a) Randomly generated images for each mixture component. The color-boxes delineate four groups of clusters (Rows 1, 2, 3 (Left) and 3 (Right)) with shared global characteristics. (b) t-SNE projection of testing samples, color-coded for the SVHN class label (0 to 9). (c) Interpolation between examples: original data samples (Columns i) and v)), associated reconstructions (Columns ii) and iv)), linearly interpolated reconstructions (Columns iii)).

4.3.2 Semi-supervised clustering and classification

It is evident from the last experiment that the confounders introduced by the SVHN dataset made unsupervised semantic clustering more difficult. In this section we show how SAMM can be used to guide clustering along predefined categories using only a small amount of labeled data. To this end we limit the labeled samples to a random selection of 1000 examples from the training set. To model p(z|y) we use a 64-dimensional mixture of 10 spherical Gaussians, each with unit variance. In placing the means of each distribution, we take advantage of our prior knowledge of the task. For example, from Figure 2(e), we can see that nines are closer to fours than they are to zeros, and we reflect these assumptions in designing p(z|y). There is considerable class imbalance in the SVHN dataset, so a multinomial prior is used for y with each class probability set to the frequency observed in the training data. The placement of each mean within the continuous latent manifold is shown in Table 3 of the appendix. We also run this experiment allowing the means to be learned using equation (12).

Table 2 reports the test-set error-rate mean and variance over 10 trials. SAMM improves over the ALI baseline both with the fixed means and when the means are learned. Figure 4 shows visualizations of results from 1 of the 10 trials. Finally, given that we have defined p(z|y), we can use Bayes' theorem to derive p(y|z) and obtain a classifier given an image embedding z:

p(y|z) = p(z|y) p(y) / Σ_{y'} p(z|y') p(y').    (18)
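With unit-variance Gaussian components, (18) reduces to a softmax over negative squared distances to the component means plus the class log-prior. A minimal PyTorch sketch, with names assumed rather than taken from the paper:

```python
import torch


def classify_from_embedding(z, means, class_log_prior):
    """p(y|z) proportional to p(z|y) p(y) for unit-variance Gaussian components (Eq. 18).

    z: (batch, z_dim) inferred embeddings
    means: (n_classes, z_dim) component means of p(z|y)
    class_log_prior: (n_classes,) log p(y)
    """
    # Log-density of a unit-variance Gaussian, up to a constant shared by all classes.
    sq_dist = torch.cdist(z, means, p=2) ** 2        # (batch, n_classes)
    log_joint = -0.5 * sq_dist + class_log_prior     # log p(z|y) + log p(y) + const
    return torch.softmax(log_joint, dim=1)           # normalized p(y|z)
```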

Figures 4(e) and 4(f) compare the confusion matrices for predictions obtained from p(y|z) in (18) and from the inferred categorical variable. The similarity between the two is further evidence that the inference network has learned to embed data according to the desired distribution.

Figure 4: Semi-supervised clustering and classification of SVHN data with 10 mixture components. (a) Test image predictions: each row corresponds to the predicted class. (b) Randomly generated images for each mixture component. (c) Interpolation between examples: original data samples in the first and last columns. (d) t-SNE projection of testing samples, color-coded for the SVHN class label (0 to 9). (e, f) Confusion matrices for predictions given an image embedding (e) and given the generator (f).
Model | SVHN error rate (%) (1000 labels)
AAE (Makhzani et al., 2016)
ImprovedGAN (Salimans et al., 2016)
ALI (Dumoulin et al., 2016)
TripleGAN (Li et al., 2017)
SGAN (Deng et al., 2017)
SAMM
SAMM Learned
Table 2: Semi-supervised test set misclassification rate and standard deviation for SVHN data.

5 Conclusion

The AMM is presented as a generative model for unsupervised or semi-supervised data clustering. It is the first adversarially optimized method to model the conditional dependence between categorical and continuous latent variables. The AMM achieves state-of-the-art unsupervised clustering results and competitive semi-supervised classification results on benchmark datasets.


Appendix A SAMM Algorithm

Algorithm 2 outlines the SAMM training procedure.

Initialize SAMM parameters
while not done do
      Sample x from the unlabeled training data; sample y from p(y) and z from p(z|y)
      Sample ŷ from q(y|x) and ẑ from q(z|x, ŷ); sample x̂ from p(x|y, z)
      Sample (x, y) from the labeled training data; sample z from p(z|y)
      Sample ẑ from q(z|x, y); sample x̂ from p(x|y, z)
      Compute discriminator predictions for the unlabeled inference and synthesis tuples
      Compute discriminator predictions for the labeled inference and synthesis tuples
      Compute the discriminator losses
      Compute the inference (encoder) losses
      Compute the synthesis (decoder) losses
      Update the discriminator parameters
      Update the inference parameters for the unlabeled and labeled encoders
      Update the synthesis parameters for the unlabeled and labeled pathways
end while
Algorithm 2: SAMM training procedure using distributions (4), (6), and (11).

Appendix B Experiment Information

B.1 Model Architectures

Figures 5 and 6 detail the model architectures for the SVHN and MNIST experiments, respectively.

Figure 5: Model architecture for SVHN.
Figure 6: Model architecture for MNIST.

B.2 Mean Placement

The placement of each mean for the fixed-mean semi-supervised SVHN experiment is shown in Table 3.

Mean
-3 3 -3 -3 0
-3 -3 3 3 0
-3 3 3 -3 0
3 -3 -3 -3 0
-3 -3 3 -3 0
3 -3 3 -3 0
3 3 3 -3 0
-3 3 3 3 0
3 3 -3 -3 0
-3 -3 -3 -3 0
Table 3: SVHN semi-supervised: placement of means for the fixed-mean experiment.