GAN-EM: GAN based EM learning framework

12/02/2018 · Wentian Zhao et al.

The expectation maximization (EM) algorithm finds maximum likelihood solutions for models with latent variables. A typical example is the Gaussian mixture model (GMM), which requires a Gaussian assumption; however, natural images are highly non-Gaussian, so GMM cannot be applied to clustering in pixel space. To overcome this limitation, we propose a GAN based EM learning framework that maximizes the likelihood of images and estimates the latent variables under only an L-Lipschitz continuity constraint. We call this model GAN-EM; it is a framework for image clustering, semi-supervised classification and dimensionality reduction. In the M-step, we design a novel loss function for the discriminator of the GAN to perform maximum likelihood estimation (MLE) on data with soft class label assignments. Specifically, a conditional generator captures the data distribution for K classes, and a discriminator tells whether a sample is real or fake for each class. Since our model is unsupervised, the class label of real data is regarded as a latent variable, which is estimated by an additional network (E-net) in the E-step. The proposed GAN-EM achieves state-of-the-art clustering and semi-supervised classification results on MNIST, SVHN and CelebA, as well as image quality comparable to other recently developed generative models.


1 Introduction

Expectation maximization (EM) [5] is a traditional learning framework with many applications in unsupervised learning. A typical example is the Gaussian mixture model (GMM), where the data distribution is estimated by maximum likelihood estimation (MLE) under a Gaussian assumption in the M-step, and soft class labels are assigned using Bayes rule in the E-step. Although GMM has the nice property that the likelihood increases monotonically, many previous studies [4, 26, 25] have shown that natural image intensities exhibit highly non-Gaussian behavior, so GMM cannot be applied directly to image clustering in pixel space. The motivation of this work is to find an alternative way to achieve the EM mechanism without the Gaussian assumption.

Generative adversarial networks (GAN) [8] have proved powerful at learning data distributions. We propose to apply a GAN in the M-step to maximize the likelihood of data under the soft class label assignments passed from the E-step. It is easy for a GAN to perform MLE, as in [7, 20], but incorporating the soft class assignments into the GAN model at the same time is rather difficult. To address this problem, we design a weighted binary cross entropy loss for the discriminator, where the weights are the soft label assignments. In Sec. 4, we prove that this design lets the GAN optimize the Q function of the EM algorithm. Since neural networks are not invertible in most cases, we cannot use Bayes rule to compute the expectation analytically in the E-step as in GMM. To solve this, we use generated samples to train another network, named E-net, which then predicts the soft class labels of real samples in the E-step.

To evaluate our model, we perform clustering on MNIST and achieve the lowest error rate for all three common cluster settings: 10, 20 and 30 clusters. We also test semi-supervised classification on MNIST and SVHN with partially labeled data; both results are competitive with recently proposed generative models, and on SVHN, GAN-EM outperforms all other models. Beyond these two commonly used datasets, we test our model on CelebA under both unsupervised and semi-supervised settings, a more challenging task because attributes of human faces are rather abstract. Our model still achieves the best results.

We make the following contributions: (1) we are the first to achieve a general EM process using GAN by introducing a novel GAN based EM learning framework (GAN-EM) that can perform clustering, semi-supervised classification and dimensionality reduction; (2) we conduct thorough experiments and show that GAN-EM achieves state-of-the-art clustering results on MNIST and CelebA, and semi-supervised classification results on SVHN and CelebA; (3) we relax the Gaussian assumption of GMM by imposing L-Lipschitz continuity on the generator of the GAN.

2 Related Work

Image Clustering:

Image classification has advanced rapidly thanks to convolutional neural networks (CNN) [13] in recent years. However, this excellent classification performance relies on large amounts of image labels, and deep models remain far from satisfactory when annotations are insufficient. Image clustering is therefore an essential problem in computer vision.

[27] proposed Deep Embedded Clustering (DEC), which learns feature representations and cluster assignments using deep neural networks. [21] is another study on deep clustering, which clusters data into multiple categories by implicitly finding a subspace to fit each class.

Deep EM: A successful combination of neural networks and EM is neural expectation maximization (N-EM) [9]. N-EM trains the parameters of EM using a neural network, which yields a differentiable clustering model and is used for unsupervised segmentation, where N-EM can cluster constituent objects. Banijamali et al. [2] use a generative mixture of networks (GMN) to simulate a GMM: they first run K-means to obtain prior knowledge of the dataset, and then treat each network as a cluster. Variational deep embedding (VaDE) [11] combines GMM with a variational autoencoder (VAE) and keeps the Gaussian assumption. In the M-step, VaDE maximizes the lower bound on the log-likelihood given by Jensen's inequality; in the E-step, a neural network models the mapping from data to class assignments.

GAN Clustering: Generative adversarial networks have evolved over the years. At the outset, the vanilla GAN [8] could not perform clustering or semi-supervised classification. Springenberg [23] proposed categorical generative adversarial networks (CatGAN), which can perform both unsupervised and semi-supervised learning. CatGAN makes the discriminator certain about the classification of real samples and uncertain about that of generated samples, and applies the opposite objective to the generator. Moreover, it assumes that all classes are uniformly distributed, an assumption we relax in our model. InfoGAN [24], the adversarial autoencoder [17], feature matching GAN [22] and the pixelGAN autoencoder [16] are all GAN variants that can perform clustering in an unsupervised manner. Our proposed GAN-EM is quite different from these previous GAN variants: we fit GAN into the EM framework, for which the likelihood of the data is known to increase monotonically. A concurrent work to ours is [28], which fits GANs into a GMM-style mixture (GANMM). In GANMM, the hard label assignment strategy limits the model to K-means, an extreme case of EM for mixture models [3]. Our work differs from theirs in three ways. First, we use soft label assignment rather than the hard assignment in GANMM; to the best of our knowledge, our work is the first to achieve a general EM process using GAN. Second, we use only one GAN, rather than K GANs, where K is the number of clusters; the drawback of using multiple GANs is discussed in Sec. 4.3. Third, we handle the prior distribution assumption by making the generator L-Lipschitz continuous. Experimental results show that GAN-EM outperforms GANMM by a large margin.

3 GAN-EM

The overall architecture of our model is shown in Fig. 1. We first clarify some notation. θ denotes the parameters of the GAN, which include θ_G for the generator and θ_D for the discriminator; π denotes the parameters of the multinomial prior distribution, composed of π_1, ..., π_K. Θ = {θ, π} stands for all the parameters of the EM process. We denote the number of clusters by K and the number of training samples by N.

Figure 1: GAN-EM architecture. G: generator, D: discriminator, E: E-net, z: random noise, c: specified class, x: real images, and x̃: generated images. G takes z and c as input and generates x̃. The inverse function of G is approximated by E, which is trained with input x̃ and label c. D takes in both x and x̃ and outputs the probability of an input being real for each class. E is tasked with making soft class assignments to all x.

3.1 M-step

The goal of the M-step is to update the parameters Θ so as to maximize the lower bound of the log-likelihood given the soft class assignments γ provided by the E-step (γ is an N × K matrix). Θ consists of θ and π. Updating π is simple, since we can compute the analytical solution for each π_k: π_k = N_k / N, where N_k = Σ_n γ_{nk} is the (soft) number of samples in the k-th cluster. Details are in Sec. 4.

To update θ, we extend GAN to a multi-class version that learns the K mixture components p(x | k; θ). The vanilla GAN [8] is modified in several ways as follows.

Generator:  Similar to the conditional GAN [18], apart from the random noise z, class label information c is also added to the input of the generator. With c = k as input, we specify the generator to generate images of the k-th class. This makes our single generator act as K generators.

Discriminator:  Different from the vanilla GAN, which uses only one output unit, we set K units in the output layer.

Loss function:  Here we only give the form of the loss functions; their derivation is discussed in detail in Sec. 4.1. L_G and L_D are the loss functions of the generator and the discriminator, respectively. We have:

L_G = - E_{c ~ U, z ~ N} [ exp( σ^{-1}( D_c( G(z, c) ) ) ) ],   (1)

L_D = - E_{c ~ U, z ~ N} [ Σ_{k=1}^K log( 1 - D_k( G(z, c) ) ) ] - E_{x ~ p_data} [ Σ_{k=1}^K γ_k(x) log D_k(x) ],   (2)

where z is random noise, c is the specified class of images to be generated, U and N are the uniform distribution (over the K classes) and the Gaussian distribution respectively, x ~ p_data samples images from the real data distribution, x̃ = G(z, c) stands for generated images, σ for the sigmoid activation function, and γ_k for the k-th class label assignments, which come from the previous E-step. Here, D_k(x) denotes the k-th output unit's value of the discriminator. Notice that Eq. 1 is a modified form of the generator loss so that GAN can perform likelihood maximization [7]. The first term of Eq. 2 means that all output units are expected to give a low probability to fake images. Conversely, the second term guides the discriminator to output a high probability for all real ones, while the loss on each output unit is weighted by the soft label assignment passed from the E-step.
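To make the two losses concrete, here is a minimal PyTorch sketch under our reading of the description above; the interfaces of G and D, the prior pi, and the assignment matrix gamma_real are assumptions about the surrounding code, and constant factors may differ from the paper's exact formulation.

```python
import torch
import torch.nn.functional as F

def m_step_losses(G, D, x_real, gamma_real, pi, K=10, noise_dim=62):
    """One evaluation of the M-step losses (cf. Eqs. 1 and 2).

    G(z, c_onehot) -> fake images; D(x) -> K pre-sigmoid logits, one per class.
    gamma_real: (B, K) soft assignments of the real batch from the last E-step.
    pi: (K,) multinomial prior.
    """
    B = x_real.size(0)
    c = torch.multinomial(pi, B, replacement=True)           # class labels sampled from the prior
    z = torch.randn(B, noise_dim)                            # noise
    x_fake = G(z, F.one_hot(c, K).float())

    logits_fake = D(x_fake)                                  # (B, K)
    logits_real = D(x_real)                                  # (B, K)

    # Discriminator: every unit should call fakes "fake"; unit k should call
    # reals "real", weighted by the soft assignment gamma_k (weighted BCE).
    d_loss = -(torch.log(1 - torch.sigmoid(logits_fake) + 1e-8).sum(dim=1).mean()
               + (gamma_real * F.logsigmoid(logits_real)).sum(dim=1).mean())

    # Generator: maximum-likelihood surrogate -E[exp(sigmoid^-1(D_c))], i.e.
    # exp of the logit of the unit matching the conditioning class c.
    g_loss = -torch.exp(logits_fake.gather(1, c.view(-1, 1))).mean()
    return d_loss, g_loss                                    # in practice, detach x_fake for the D update
```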

3.2 E-step

In the unsupervised setting, the class label of real data cannot be observed and is regarded as a latent variable. The goal of the E-step is therefore to estimate this latent variable, which is normally obtained using Bayes rule given the parameters learned in the previous M-step:

γ_k(x) = p(k | x; Θ) = π_k p(x | k; θ) / Σ_{j=1}^K π_j p(x | j; θ),   (3)

where p(x | k; θ) represents the distribution of data in class k given by the generator network, and the denominator is the normalization term. However, Eq. 3 is hard to calculate since the generator network is not invertible.

To circumvent this problem, we introduce another neural network, called E-net, to fit the distribution on the left hand side of Eq. 3. To train the E-net, we first generate samples from the generator, where the number of generated samples for cluster k is proportional to π_k, because p(k | x) is proportional to π_k according to Eq. 3 (recall that π_k = N_k / N). Once the E-net is well trained, it approximates the posterior p(k | x; Θ) and acts as an inverse generator. A similar approach is used in BiGAN [6], where the encoder is also proved to invert the generator almost everywhere; however, the goal of BiGAN is feature learning, which is different from ours.

Specifically, as shown in Fig. 1, we take the output of the generator, i.e. x̃ = G(z, c), as the input of the E-net, and take the corresponding class c as the output label. In this way we learn an approximation of the left hand side of Eq. 3 and obtain the soft class assignments

γ_k(x) ≈ E_k(x),   (4)

which are then fed back to the M-step. The E-net takes the following loss:

L_E = E_{c, z} [ CE( y_c, E( G(z, c) ) ) ],   (5)

where CE stands for the cross-entropy function, y_c is a one-hot vector that encodes the class information c, and E(·) is the output of the E-net. The trained E-net is then responsible for giving the soft class assignments of real images.
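A minimal sketch of one E-step, assuming the E-net outputs K logits and reusing the hypothetical G, pi, and dimensions from the previous sketch; batch size and iteration count are illustrative.

```python
import torch
import torch.nn.functional as F

def train_e_net(G, E, pi, optimizer, K=10, noise_dim=62, iters=1000, batch=256):
    """E-step: fit E to recover the conditioning class c from G(z, c).

    Classes are sampled in proportion to pi, so that E approximates the
    posterior p(k | x) of Eq. 3 on the generator's distribution (Eq. 5).
    """
    for _ in range(iters):
        c = torch.multinomial(pi, batch, replacement=True)
        z = torch.randn(batch, noise_dim)
        with torch.no_grad():                    # the generator is frozen during the E-step
            x_fake = G(z, F.one_hot(c, K).float())
        loss = F.cross_entropy(E(x_fake), c)     # cross-entropy against the conditioning class
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def soft_assignments(E, x_real):
    """gamma_{nk} = softmax(E(x_n))_k, fed back to the next M-step (Eq. 4)."""
    with torch.no_grad():
        return torch.softmax(E(x_real), dim=1)
```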

3.3 EM algorithm

The bridge between the M-step and the E-step consists of the generated fake images with their conditional class labels and the soft assignments produced by the E-net. With this, the whole training loop is built up, and we can run the M-step and the E-step alternately to achieve the EM mechanism without the Gaussian assumption. We start the training with the M-step, where π is initialized as a uniform distribution. The pseudo code is in Algorithm 1.

1: Initialization: π_k ← 1/K, γ_{nk} ← 1/K
2: for each EM iteration do
3:     update π:                                        ▷ M-step
4:         π_k ← N_k / N, with N_k = Σ_n γ_{nk}
5:     update θ:
6:         for a fixed number of GAN iterations do
7:             train G: minimize L_G (Eq. 1)
8:             train D: minimize L_D (Eq. 2)
9:         end for
10:     for a fixed number of E-net iterations do        ▷ E-step
11:         Sample a batch of (z, c) and obtain x̃ = G(z, c)
12:         train E-net: minimize L_E (Eq. 5) over w (w: weights of the E-net¹)
13:     end for
14:     update label assignment: γ_{nk} ← E_k(x_n)
15: end for
16: Predict: ĉ_n = argmax_k E_k(x_n)
Algorithm 1 GAN-EM

¹Because the E-net aims to learn an inverse function of the generator, w is not independent of θ_G. Thus, we cannot say that w is a parameter of the EM process; in other words, w is not part of Θ.
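To tie the steps together, the sketch below mirrors Algorithm 1 using the earlier hypothetical helpers (m_step_losses, train_e_net, soft_assignments); make_optimizers, sample_batch, and step are likewise placeholders, and the iteration counts are arbitrary.

```python
import torch

def gan_em(G, D, E, x_all, K=10, em_iters=50, gan_iters=500):
    """Outer GAN-EM loop, mirroring Algorithm 1. Optimizer construction and the
    batching helpers (make_optimizers, sample_batch, step) are hypothetical."""
    N = x_all.size(0)
    pi = torch.full((K,), 1.0 / K)                 # uniform prior
    gamma = torch.full((N, K), 1.0 / K)            # uniform soft assignments

    opt_g, opt_d = make_optimizers(G, D)           # hypothetical helper
    opt_e = torch.optim.RMSprop(E.parameters(), lr=2e-4)

    for _ in range(em_iters):
        # M-step: closed-form update of pi, then adversarial training of G and D.
        pi = gamma.mean(dim=0)                     # pi_k = N_k / N
        for _ in range(gan_iters):
            x_real, gamma_real = sample_batch(x_all, gamma)           # hypothetical helper
            d_loss, g_loss = m_step_losses(G, D, x_real, gamma_real, pi, K=K)
            step(opt_d, d_loss)                    # hypothetical: zero_grad / backward / step
            step(opt_g, g_loss)
        # E-step: re-fit the E-net on fresh generated data, then re-assign labels.
        train_e_net(G, E, pi, opt_e, K=K)
        gamma = soft_assignments(E, x_all)
    return gamma.argmax(dim=1)                     # hard cluster prediction
```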

4 Theoretical Analysis

This section mainly focuses on the theoretical analysis of the GAN in the M-step. We first show how our model works with K GANs by deriving the objective functions. After that, we simplify the model to use only one GAN. Finally, we show how to deal with the prior distribution assumption.

4.1 Background

In the M-step, we aim to maximize the Q function [3], expressed as:

Q(Θ) = E_{x ~ p_data} [ Σ_{k=1}^K γ_k(x) log( π_k p(x | k; θ) ) ].   (6)

Furthermore, we can write Eq. 6 as the sum of two terms Q_1 and Q_2:

Q_1 = E_{x ~ p_data} [ Σ_{k=1}^K γ_k(x) log p(x | k; θ) ],   (7)
Q_2 = E_{x ~ p_data} [ Σ_{k=1}^K γ_k(x) log π_k ].   (8)

Recall that γ is explained in the previous section. These two terms are independent of each other, so they can be optimized separately.

We first optimize Eq. 8. Since Q_2 is irrelevant to θ, we can ignore the expectation over the data, which only introduces a constant factor. With the constraint Σ_k π_k = 1, the optimal solution for π_k is π_k = N_k / N, where N_k = Σ_n γ_{nk} is the sample count of the k-th cluster (not necessarily an integer, because the class label assignment is soft) and N is the size of the whole dataset. The derivation is in Appendix A.2.
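For example, with a toy assignment matrix the closed-form update is just a column average (a sketch, not the paper's code):

```python
import numpy as np

# Toy 4 x 2 soft-assignment matrix gamma from an E-step (rows sum to 1).
gamma = np.array([[0.9, 0.1],
                  [0.8, 0.2],
                  [0.2, 0.8],
                  [0.1, 0.9]])
N_k = gamma.sum(axis=0)        # "soft" sample counts per cluster: [2.0, 2.0]
pi = N_k / gamma.shape[0]      # pi_k = N_k / N -> [0.5, 0.5]
print(pi)
```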

Then we consider Eq. 7. We want to employ K GANs to optimize Q_1. Each term in the summation in Q_1 is independent of the others because we are currently considering K separate GANs. Therefore, we can optimize each term of Q_1 separately, rewriting the k-th term as

Q_1^k = E_{x ~ p_data} [ γ_k(x) log q_k(x; θ_k) ]   (9)

to fit a single GAN, and summing them up in the end. Here, θ_k denotes the parameters of the k-th GAN, p_data stands for the distribution of all real data, and q_k stands for the fake data distribution of cluster k.

For convenience, we introduce a new distribution

p_k(x) = γ_k(x) p_data(x) / Z_k,   (10)

where Z_k is the normalization factor, Z_k = E_{x ~ p_data}[ γ_k(x) ]. Substituting p_k into Eq. 9, we have:

Q_1^k = Z_k · E_{x ~ p_k} [ log q_k(x; θ_k) ].   (11)

We can drop the constant Z_k to obtain

E_{x ~ p_k} [ log q_k(x; θ_k) ],   (12)

which is equivalent to Eq. 11 in terms of optimization.
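The reweighting in Eqs. 10-12 can be sanity-checked numerically: an expectation under p_k equals a gamma_k-weighted expectation under the data distribution, up to the constant Z_k. The toy Monte Carlo check below uses arbitrary choices for p_data, gamma_k, and q_k.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100_000)                 # samples from a toy p_data
gamma_k = 1.0 / (1.0 + np.exp(-x))           # an arbitrary soft assignment in (0, 1)
log_q_k = -0.5 * (x - 1.0) ** 2              # log of a toy (unnormalized) q_k

Z_k = gamma_k.mean()                                  # normalization factor of p_k
obj_eq9 = (gamma_k * log_q_k).mean()                  # E_{p_data}[gamma_k log q_k]            (Eq. 9)
e_pk = (gamma_k * log_q_k).sum() / gamma_k.sum()      # E_{p_k}[log q_k] via importance weights (Eq. 12)
print(np.isclose(obj_eq9, Z_k * e_pk))                # True: Eq. 9 equals Z_k times Eq. 12     (Eq. 11)
```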

4.2 Objectives

In this subsection, we show how our design for the M-step (i.e. Eq. 1 and Eq. 2) maximizes Eq. 12, which takes exactly the form of a likelihood. This is feasible due to the following two facts [7]:

  1. MLE is equivalent to minimizing the KL divergence between the real distribution and the generated distribution;

  2. When the discriminator is optimal, we can modify the loss function of the generator so that the optimization goal of the GAN is to minimize this KL divergence.

Maximizing Eq. 12 is equivalent to minimizing KL(p_k ‖ q_k) according to fact 1. We now show how GANs can be tasked to minimize this KL divergence with minor modifications.

According to fact 2, only when the discriminator is optimal can we modify the GAN to minimize the KL divergence. Therefore, we consider the optimum of the discriminators. With Eq. 10, the loss function of the k-th discriminator, given any generator (represented by its fake data distribution q_k), is:

L_{D_k} = - E_{x ~ p_k} [ log D_k(x) ] - E_{x̃ ~ q_k} [ log( 1 - D_k(x̃) ) ],   (13)

where D_k is the k-th discriminator, as in the vanilla GAN [8]. We show that when the discriminators reach their optimum, this loss is equivalent to the JS divergence. The following corollary is derived from Proposition 1 and Theorem 1 in [8].

Corollary 1.

Equation 13 is equivalent to the JS divergence between the real distribution and the generated distribution of each cluster when the discriminators are optimal, i.e.

L_{D_k^*} = log 4 − 2 · JS( p_k ‖ q_k ),   (14)

and the optimal discriminator D_k^* for each cluster is:

D_k^*(x) = p_k(x) / ( p_k(x) + q_k(x) ).   (15)
Proof.

See Appendix A.1. ∎

If we use the same generator loss as the vanilla GAN, the JS divergence in Eq. 14 will be minimized. However, we want the GAN to minimize the KL divergence KL(p_k ‖ q_k) for each cluster, so as to achieve the goal of maximizing Eq. 12. From Corollary 1 we already have the optimal discriminator for any fixed generator. Therefore, according to fact 2, we modify the loss function of the generator to:

L_{G_k} = - E_{z} [ exp( σ^{-1}( D_k( G_k(z) ) ) ) ],   (16)

where σ is the sigmoid activation function in the last layer of the discriminator, and G_k(z) is the output of the k-th generator.
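The only non-obvious ingredient of Eq. 16 is the exp of the inverse sigmoid; a quick numerical check (independent of the paper's code) confirms it simply recovers the odds D/(1-D), so the loss can be computed directly from the pre-sigmoid logit.

```python
import numpy as np

logit = np.array([-2.0, 0.0, 1.5])                 # pre-sigmoid discriminator outputs
d = 1.0 / (1.0 + np.exp(-logit))                   # D = sigmoid(logit)
print(np.allclose(np.exp(logit), d / (1.0 - d)))   # True: exp(sigmoid^{-1}(D)) = D / (1 - D)
```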

Now that we have the objectives of a single generator and a single discriminator, we assemble them into the whole model. Since we are currently using K GANs, we only need to sum Eq. 16 over k for the loss of the generators:

L_G = - Σ_{k=1}^K E_{z} [ exp( σ^{-1}( D_k( G_k(z) ) ) ) ],   (17)

and sum Eq. 13 over k for the loss of the discriminators:

L_D = - Σ_{k=1}^K ( E_{x ~ p_k} [ log D_k(x) ] + E_{x̃ ~ q_k} [ log( 1 - D_k(x̃) ) ] ).   (18)

Here, Eq. 18 is equivalent to Eq. 2, since x̃ is generated by generator G_k. The derivation from Eq. 17 to Eq. 1 is given in Sec. 4.3.

4.3 Single GAN vs. multiple GANs

We have shown that K GANs can be tasked to perform MLE in the M-step of EM. In practice, however, using that many GANs is intractable, since the complexity of the model grows proportionally with the number of clusters. Moreover, the data is split per cluster and distributed to different GANs, so each individual GAN cannot make full use of the data.

4.3.1 Single generator

For the generator part, we employ a conditional variable c [18] to make a single generator act as K generators. The final loss function of the generator is then exactly Eq. 1.

4.3.2 Single discriminator

In our work, instead of using K discriminators, we use a single discriminator with K output units. Each output has its own weights in the last fully connected layer, while the preceding layers are shared, since the convolutional layers extract features that are largely common across clusters.

To this end, we denote the last fully connected layer of the discriminator by the function f and all other layers by the function h; then we have:

D_k(x) = σ( f_k( h(x) ) ),   (19)

where h(x) stands for the features learned by h. The objective still has the form of Eq. 18, but the meaning of D_k changes from the k-th discriminator to the k-th output unit of D, which stands for the probability that x is a real sample of the k-th cluster. This completes the derivation of the loss functions of our proposed model.

In practice, to speed up the learning, we add an extra output unit for the discriminator, which only distinguishes all fake data from all real data.
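A PyTorch sketch of such a discriminator: a shared trunk h and a (K+1)-unit fully connected head f covering the K cluster units plus the extra real/fake unit mentioned above. The layer sizes are illustrative only (loosely based on the MNIST setup in Appendix B, with batch normalization omitted for brevity).

```python
import torch
import torch.nn as nn

class SharedDiscriminator(nn.Module):
    """Shared trunk h(x) followed by K per-cluster units plus one extra
    global real/fake unit; outputs are pre-sigmoid logits."""
    def __init__(self, K=10, in_ch=1):
        super().__init__()
        self.h = nn.Sequential(                                              # shared feature extractor
            nn.Conv2d(in_ch, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2),  # 28 -> 14
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.2),    # 14 -> 7
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.LeakyReLU(0.2),   # 7 -> 3
            nn.Flatten(),
            nn.Linear(256 * 3 * 3, 1024), nn.LeakyReLU(0.2),
        )
        self.f = nn.Linear(1024, K + 1)          # K cluster units + 1 real/fake unit

    def forward(self, x):
        logits = self.f(self.h(x))
        return logits[:, :-1], logits[:, -1]     # (per-cluster logits, global real/fake logit)
```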

4.4 Prior distribution assumption

A prior distribution assumption is necessary: without one, the MLE in the M-step can exactly reproduce the current reweighted data distribution of every cluster (q_k = p_k), after which the label assignments, and hence the parameters of the model, are no longer updated. GMM makes a strong assumption on the prior distribution, i.e. a Gaussian, while our GAN-EM makes a weaker one, as discussed below.

In our work, we conjecture that the data distribution generated by a well-trained GAN obeys intra-cluster continuity and inter-cluster discontinuity. This assumed property of GAN serves as the prior distribution assumption of GAN-EM, and it is exactly the goal of clustering. Also, compared to GMM, GAN is more powerful at modeling distributions. Therefore, by applying GAN to the EM process, we weaken the Gaussian assumption of GMM. Next, we discuss the intuition behind the conjecture.

For intra-cluster continuity, we rely on the uniform continuity of the generator. We use a CNN for the generator, where each layer can be treated as a uniformly continuous function; their composition, i.e. the generator, is then also uniformly continuous. Therefore, the output space of the generator is continuous, since the input Gaussian noise is continuous.

For inter-cluster discontinuity, the discriminator keeps the generator from violating this rule. For convenience, we call fake data lying in the gap between clusters "gap data". Suppose the generator violates inter-cluster discontinuity, e.g. it treats two clusters as one; then it must generate gap data to maintain intra-cluster continuity. In that case, a well-trained discriminator penalizes the generator harder, because gap data are farther away from real data, so the generator eventually obeys inter-cluster discontinuity.

In practice, however, the generator may produce very sparse gap data that still satisfy intra-cluster continuity, and the penalty from the discriminator is then too small because of the sparsity, which can lead the clustering astray. To solve this, we make the generator L-Lipschitz continuous, a much stronger condition than uniform continuity. We enforce the Lipschitz condition with weight clipping, similar to WGAN [1].
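A minimal sketch of the clipping step, applied to the generator's parameters after each update; the threshold 0.01 is WGAN's common default and only an assumption here.

```python
import torch

def clip_generator_weights(G, c=0.01):
    """Crudely enforce a Lipschitz constraint by clamping every weight of G
    to [-c, c], in the spirit of WGAN-style weight clipping."""
    with torch.no_grad():
        for p in G.parameters():
            p.clamp_(-c, c)
```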

5 Experiments

We perform unsupervised clustering on the MNIST [14] and CelebA [15] datasets, and semi-supervised classification on MNIST, SVHN [19] and CelebA. We also evaluate the capability of dimensionality reduction by adding an additional hidden layer to the E-net. The results show that our model achieves state-of-the-art results on various tasks, while the quality of the generated images is comparable to many other generative models. Training details and network structures are given in Appendix B.

5.1 Implementation details

We apply the RMSprop optimizer to all three networks G, D and E with learning rate 0.0002 (decay rate: 0.98). The random noise of the generator is drawn from a uniform distribution. In each M-step, we run 5 epochs with a minibatch size of 64 for both the generated batch and the real-sample batch, using the same update frequency for the generator and the discriminator. In each E-step, we generate samples with the trained generator using a batch size of 256, and then run 1000 iterations to update the E-net.
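A sketch of the optimizer setup; interpreting "decay rate 0.98" as an exponential learning-rate schedule applied once per EM iteration is an assumption.

```python
import torch

def make_optimizer(net, lr=2e-4, decay=0.98):
    """RMSprop with learning rate 0.0002 and an exponential decay schedule."""
    opt = torch.optim.RMSprop(net.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=decay)  # call sched.step() once per EM iteration
    return opt, sched
```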

5.2 Unsupervised clustering

GAN-EM achieves state-of-the-art results on the MNIST clustering task with 10, 20 and 30 clusters. We evaluate the error rate with the following metric, which is used by most of the other clustering models in Tab. 1:

error = 1 − max_{m ∈ M} (1/N) Σ_{n=1}^N 1{ m(l_n) = c_n },   (20)

where c_n is the predicted cluster of the n-th sample, l_n its ground truth label, and M the set of all one-to-one mappings from ground truth labels to predicted clusters. We compare our method with the most popular clustering models that use generative models such as GANs or autoencoders. The experimental results are shown in column 1 of Tab. 1.
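This metric can be computed by solving an assignment problem over the cluster/label contingency table; the sketch below uses scipy's Hungarian solver and, when there are more clusters than classes, leaves the surplus clusters unmatched, which may differ slightly from the exact protocol used in the paper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_error(y_true, y_pred, n_clusters, n_classes):
    """1 - accuracy of the best one-to-one mapping between labels and clusters (Eq. 20)."""
    counts = np.zeros((n_clusters, n_classes), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        counts[p, t] += 1                                   # contingency table
    rows, cols = linear_sum_assignment(counts, maximize=True)
    return 1.0 - counts[rows, cols].sum() / len(y_true)
```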

Model                  MNIST (unsup.)   MNIST (100 labels)   MNIST (1000 labels)   SVHN (1000 labels)
K-means [16]           46.51            -                    -                     -
GMM [16]               32.61 (0.06)     -                    -                     -
DEC [27]               15.7             -                    -                     -
VAE [12]               -                3.33 (0.14)          2.40 (0.02)           36.02 (0.10)
AAE [17]               4.10 (1.13)      1.90 (0.10)          1.60 (0.08)           17.70 (0.30)
CatGAN [23]            4.27             1.91 (0.10)          1.73 (0.28)           -
InfoGAN [24]           5.00 (not specified)  -               -                     -
Improved GAN [22]      -                0.93 (0.06)          -                     8.11 (1.30)
VaDE [11]              5.54             -                    -                     -
PixelGAN [16]          5.27 (1.81)      1.08 (0.15)          -                     6.96 (0.55)
GANMM [28]             35.70 (0.45)     -                    -                     -
GAN-EM (10 clusters)   4.20 (0.51)      1.09 (0.18)          1.03 (0.15)           6.05 (0.26)
GAN-EM (20 clusters)   4.04 (0.42)      -                    -                     -
GAN-EM (30 clusters)   3.97 (0.37)      -                    -                     -
Table 1: Experiment results of different models on MNIST and SVHN.
(a) MNIST (unsupervised)
(b) SVHN (1000 labels)
Figure 2: Clustering and semi-supervised classification results by GAN-EM.
Model       Unsupervised   Semi-supervised (100 labels)
VAE         -              45.38
AAE         42.88          31.03
CatGAN      44.57          34.78
VaDE        43.64          -
PixelGAN    44.27          32.54
GANMM       49.32          -
GAN-EM      42.09          28.82
Table 2: Unsupervised clustering and semi-supervised classification on all 40 CelebA attributes (average error rates).

Since different models use different experimental setups, there is no uniform standard for the number of clusters in the unsupervised setting. MNIST has 10 different digits, so naturally we would set the number of clusters to K = 10. However, some models such as CatGAN, AAE, and PixelGAN use more than 10 clusters to achieve better performance, since the models might otherwise be confused by different handwriting styles of the digits; in other words, the more clusters used, the better the expected performance. To make fair comparisons, we conduct experiments with K = 10, 20 and 30. Also, all models in Tab. 1 take the 784-dimensional raw pixel space as input. Note that both K-means and GMM have high error rates (46.51 and 32.61) on raw pixels, since MNIST is highly non-Gaussian, being approximately Bernoulli-distributed with high peaks at 0 and 1. The large margin achieved by GAN-EM demonstrates that relaxing the Gaussian assumption is effective for the clustering problem.

Our proposed GAN-EM has the lowest error rate of 3.97 with 30 clusters. Moreover, with 10 or 20 clusters, GAN-EM still yields better results than the other models. With 10 clusters, GAN-EM is outperformed only by AAE, but AAE achieves its error rate of 4.10 using 30 clusters. VaDE also has a low error rate with 10 clusters, yet still higher than that of our model under the same setting. GANMM has a rather high error rate², while GAN-EM achieves state-of-the-art clustering results on MNIST, which shows that the clustering capability of EM is far superior to that of K-means.

²When the feature space is reduced to 10 dimensions using SAE, GANMM achieves an error rate of 10.92 [28].

Then we test our model on the CelebA dataset using the same strategy as above. CelebA is a large-scale human face dataset labeled with 40 binary attributes. Fully unsupervised clustering on such a task is rather challenging because the face attributes are so abstract that it is difficult for a CNN to figure out which features to extract to cluster the samples. Tab. 2 (column 1) lists the average error rate of different models over all 40 CelebA attributes; detailed GAN-EM results on all 40 attributes are given in Appendix A.3. We achieve the best overall result for unsupervised clustering, and we show two representative attributes on which we achieve the lowest error rates, hat (29.41) and glasses (29.89), in Figs. 3(a) and 3(b), where the two rows of images are generated by the generator given the two different conditional labels. The strategy for selecting samples is described in Appendix B.3.

(a) CelebA hat
(b) CelebA glasses
Figure 3: Unsupervised feature learning on CelebA: (a) generated images after clustering learning on ’hat’ attribute; (b) generated images after clustering learning on ’glasses’ attribute.

5.3 Semi-supervised classification

It is easy to extend our model to semi-supervised classification tasks, where only a small part of the samples are labeled and the rest are not. We use almost the same training strategy as for clustering, except that we add supervision to the E-net in every E-step: at the end of the E-net training on generated fake samples, we further train it on the labeled real samples, using the same cross-entropy loss as Eq. 5, with the one-hot vector now encoding the ground truth class label. Once the E-net's error rate on the labeled data falls below ε, we stop this training, where ε is a small number (e.g. 5% or 10%) that can be tuned during training; keeping ε greater than zero avoids over-fitting.
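A sketch of this extra supervision step, assuming a small labeled tensor pair (x_lab, y_lab) and the tolerance eps described above; the names and the step budget are placeholders.

```python
import torch
import torch.nn.functional as F

def supervise_e_net(E, optimizer, x_lab, y_lab, eps=0.05, max_steps=1000):
    """Fine-tune the E-net on the labeled subset at the end of an E-step,
    stopping once its error on that subset drops below eps (to avoid over-fitting)."""
    for _ in range(max_steps):
        logits = E(x_lab)
        err = (logits.argmax(dim=1) != y_lab).float().mean().item()
        if err < eps:
            break
        loss = F.cross_entropy(logits, y_lab)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```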

We evaluate semi-supervised GAN-EM on the MNIST, SVHN and CelebA datasets. As shown in Tab. 1 (columns 2, 3 and 4) and Tab. 2 (column 2), GAN-EM achieves highly competitive semi-supervised results on all three datasets (state-of-the-art on SVHN and CelebA). The images generated by GAN-EM on SVHN are shown in Fig. 2(b). On MNIST with 100 ground truth labels, the error rate is 1.09, only 0.16 higher than the top result from Improved GAN, and with 1000 ground truth labels GAN-EM achieves the lowest error rate of 1.03. On SVHN, using 1000 labels as other models do, we achieve the state-of-the-art error rate of 6.05. For CelebA, with 100 ground truth labels, our model outperforms all other models in average error rate over the 40 attributes.

(a) Supervised 1000D
(b) Unsupervised 2D
(c) Unsupervised 100D
(d) Unsupervised 1000D
Figure 4: Representation of unsupervised dimensionality reduction on MNIST. Each color denotes one class of digit.

5.4 Dimensionality reduction

We can easily modify GAN-EM to perform dimensionality reduction by inserting a new layer with d hidden units into the E-net, where d is the target dimensionality; this layer lies right before the output layer. We then use exactly the same training strategy as in the unsupervised clustering task. Once training converges, we turn the E-net into a feature extractor by removing the output layer: we feed the real samples to the E-net and take the activations of the new layer as the features after dimensionality reduction. Three dimensionalities, 1000, 100 and 2, are tested on the MNIST dataset. We also apply t-SNE [10] to project the dimensionality-reduced data to 2D for visualization.
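A sketch of the feature-extraction step; E.features is a hypothetical method exposing the activations of the inserted d-unit layer.

```python
import torch
from sklearn.manifold import TSNE

def reduce_and_project(E, x, project_2d=True):
    """Use the trained E-net's bottleneck as the reduced representation,
    then optionally project it to 2D with t-SNE for visualization."""
    with torch.no_grad():
        feats = E.features(x).cpu().numpy()    # hypothetical: output of the d-unit layer
    if project_2d and feats.shape[1] > 2:
        feats = TSNE(n_components=2).fit_transform(feats)
    return feats
```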

Fig. 4(a) shows the supervised feature learning result, and Figs. 4(b), 4(c) and 4(d) show the unsupervised dimensionality reduction results for the three dimensionalities. Our model handles all cases well: most digits are well separated from one another, and samples of the same digit are clustered together compactly. In particular, the 100-dimensional result (Fig. 4(c)) is almost on par with supervised feature learning (Fig. 4(a)). It is worth mentioning that in the 2-dimensional case the second-to-last hidden layer of the E-net has only 2 hidden units, yet the clustering error rate is still as low as 11.8 with 10 clusters, which demonstrates the robustness of the GAN-EM model.

6 Conclusion

In this paper, we propose GAN-EM, a novel learning framework that embeds GAN into the EM algorithm to perform clustering, semi-supervised classification and dimensionality reduction. We achieve state-of-the-art results on the MNIST, SVHN and CelebA datasets, and the fidelity of the generated images is comparable to many other generative models. Although all our experiments use the vanilla GAN, many other GAN variants could also be embedded into the GAN-EM framework, and better results can be expected.

References

  • [1] Martín Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein GAN. CoRR, abs/1701.07875, 2017.
  • [2] Ershad Banijamali, Ali Ghodsi, and Pascal Poupart. Generative mixture of networks. In 2017 International Joint Conference on Neural Networks, IJCNN 2017, Anchorage, AK, USA, May 14-19, 2017, pages 3753–3760, 2017.
  • [3] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag, Berlin, Heidelberg, 2006.
  • [4] Issam Dagher and Rabih Nachar. Face recognition using IPCA-ICA algorithm. IEEE Trans. Pattern Anal. Mach. Intell., 28(6):996–1000, 2006.
  • [5] A. P. Dempster, N. M. Laird, and D. B. Rubin. Maximum likelihood from incomplete data via the em algorithm. Journal of the Royal Statistical Society. Series B (Methodological), 39(1):1–38, 1977.
  • [6] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. Adversarial feature learning. CoRR, abs/1605.09782, 2016.
  • [7] Ian J. Goodfellow. NIPS 2016 tutorial: Generative adversarial networks. CoRR, abs/1701.00160, 2017.
  • [8] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron C. Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems 27: Annual Conference on Neural Information Processing Systems 2014, December 8-13 2014, Montreal, Quebec, Canada, pages 2672–2680, 2014.
  • [9] Klaus Greff, Sjoerd van Steenkiste, and Jürgen Schmidhuber. Neural expectation maximization. In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, 4-9 December 2017, Long Beach, CA, USA, pages 6694–6704, 2017.
  • [10] Laurens van der Maaten and Geoffrey E. Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9:2579–2605, 2008.
  • [11] Zhuxi Jiang, Yin Zheng, Huachun Tan, Bangsheng Tang, and Hanning Zhou. Variational deep embedding: An unsupervised and generative approach to clustering. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence, IJCAI 2017, Melbourne, Australia, pages 1965–1972, 2017.
  • [12] Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. CoRR, abs/1312.6114, 2013.
  • [13] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. Imagenet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25: 26th Annual Conference on Neural Information Processing Systems 2012. Proceedings of a meeting held December 3-6, 2012, Lake Tahoe, Nevada, United States., pages 1106–1114, 2012.
  • [14] Yann LeCun, Bernhard E. Boser, John S. Denker, Donnie Henderson, Richard E. Howard, Wayne E. Hubbard, and Lawrence D. Jackel. Backpropagation applied to handwritten zip code recognition. Neural Computation, 1(4):541–551, 1989.
  • [15] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), 2015.
  • [16] Alireza Makhzani and Brendan J. Frey. Pixelgan autoencoders. In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, 4-9 December 2017, Long Beach, CA, USA, pages 1972–1982, 2017.
  • [17] Alireza Makhzani, Jonathon Shlens, Navdeep Jaitly, and Ian J. Goodfellow. Adversarial autoencoders. CoRR, abs/1511.05644, 2015.
  • [18] Mehdi Mirza and Simon Osindero. Conditional generative adversarial nets. CoRR, abs/1411.1784, 2014.
  • [19] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.
  • [20] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-gan: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain, pages 271–279, 2016.
  • [21] Xi Peng, Shijie Xiao, Jiashi Feng, Wei-Yun Yau, and Zhang Yi. Deep subspace clustering with sparsity prior. In IJCAI, pages 1925–1931. IJCAI/AAAI Press, 2016.
  • [22] Tim Salimans, Ian J. Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain, pages 2226–2234, 2016.
  • [23] Jost Tobias Springenberg. Unsupervised and semi-supervised learning with categorical generative adversarial networks. CoRR, abs/1511.06390, 2015.
  • [24] Adrian Spurr, Emre Aksan, and Otmar Hilliges. Guiding infogan with semi-supervision. In Machine Learning and Knowledge Discovery in Databases - European Conference, ECML PKDD 2017, Skopje, Macedonia, September 18-22, 2017, Proceedings, Part I, pages 119–134, 2017.
  • [25] A. Srivastava, A. B. Lee, Eero P. Simoncelli, and S.-C. Zhu. On advances in statistical modeling of natural images. Journal of Mathematical Imaging and Vision, 18(1):17–33, 2003.
  • [26] Martin J. Wainwright and Eero P. Simoncelli. Scale mixtures of gaussians and the statistics of natural images. In NIPS, pages 855–861. The MIT Press, 1999.
  • [27] Junyuan Xie, Ross B. Girshick, and Ali Farhadi. Unsupervised deep embedding for clustering analysis. In Proceedings of the 33rd International Conference on Machine Learning, ICML 2016, New York City, NY, USA, June 19-24, 2016, pages 478–487, 2016.
  • [28] Yu Yang and Wen-Ji Zhou. Mixture of gans for clustering. In Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence, IJCAI, Stockholm, Sweden, 2018.

Appendix A Missing details

A.1 Proof of Corollary 1

Proof.

For every cluster k, we have:

L_{D_k} = - E_{x ~ p_k} [ log D_k(x) ] - E_{x̃ ~ q_k} [ log( 1 - D_k(x̃) ) ]
        = - ∫_x ( p_k(x) log D_k(x) + q_k(x) log( 1 - D_k(x) ) ) dx.

For fixed x, the integrand -a log y - b log(1 - y) with a = p_k(x) and b = q_k(x) attains its minimum over y ∈ (0, 1) at y = a / (a + b), so the optimal discriminator is

D_k^*(x) = p_k(x) / ( p_k(x) + q_k(x) ).

Then we substitute this optimal D_k^* into the summation in Eq. 13:

L_{D_k^*} = - E_{x ~ p_k} [ log ( p_k(x) / ( p_k(x) + q_k(x) ) ) ] - E_{x̃ ~ q_k} [ log ( q_k(x̃) / ( p_k(x̃) + q_k(x̃) ) ) ]
          = log 4 − 2 · JS( p_k ‖ q_k ).

Since we already have the optimal solution for all clusters, as given in Eq. 15, we minimize the JS divergence between real data and fake data for each cluster. When summed up, the optimum remains, since each term is independent of the others. ∎

A.2 Derivation of π_k

This result is obtained by solving the constrained optimization problem:

max_{π} Σ_{k=1}^K N_k log π_k   subject to   Σ_{k=1}^K π_k = 1,   where N_k = Σ_{n=1}^N γ_{nk}.

Introducing a Lagrange multiplier λ and setting the derivative of Σ_k N_k log π_k + λ ( Σ_k π_k − 1 ) with respect to π_k to zero gives π_k = − N_k / λ; the constraint then yields λ = − N, so that

π_k = N_k / N.

A.3 Results on all 40 CelebA attributes

Tab. 3 shows the experimental results of GAN-EM on all 40 CelebA attributes under both the unsupervised and the semi-supervised setting. The top three results for unsupervised clustering and for semi-supervised classification are shown in bold, respectively.

Attribute   Unsupervised   Semi-supervised (100 labels)   Attribute   Unsupervised   Semi-supervised (100 labels)
Goatee 0.4103 0.2218 Bald 0.4855 0.4001
Narrow_Eyes 0.3308 0.1757 Arched_Eyebrows 0.4447 0.2993
Bangs 0.3592 0.1506 Wavy_Hair 0.3376 0.2657
Gray_Hair 0.4538 0.1849 Mouth_Slightly_Open 0.4272 0.1895
Big_Lips 0.4804 0.1767 Young 0.4557 0.3216
Heavy_Makeup 0.2366 0.2218 No_Beard 0.4869 0.4315
Attractive 0.4749 0.3736 Pointy_Nose 0.3426 0.1847
Bags_Under_Eyes 0.4468 0.3545 Bushy_Eyebrows 0.4447 0.3592
High_Cheekbones 0.4455 0.2637 Double_Chin 0.4546 0.3294
Oval_Face 0.4858 0.2583 gender 0.3683 0.1446
Rosy_Cheeks 0.3518 0.2891 hat 0.2941 0.1598
Sideburns 0.4726 0.2507 glass 0.2983 0.1662
Mustache 0.4539 0.2947 Male 0.4763 0.2746
Brown_Hair 0.4463 0.4362 Receding_Hairline 0.4847 0.4138
Pale_Skin 0.4736 0.4457 Wearing_Necklace 0.4758 0.2755
Chubby 0.4478 0.3242 Wearing_Necktie 0.4805 0.4157
Big_Nose 0.4695 0.4337 Wearing_Lipstick 0.4141 0.1601
Blurry 0.3045 0.2188 Straight_Hair 0.4687 0.3102
Black_Hair 0.4555 0.3949 Wearing_Earrings 0.4393 0.3686
Blond_Hair 0.3145 0.3057 5_o_Clock_Shadow 0.4435 0.2822
Table 3: Error rates of GAN-EM on all 40 CelebA attributes

Appendix B Implementation details

B.1 MNIST

The MNIST dataset has 60,000 training samples, of which we hold out 10,000 as validation data. The images are 28 × 28 pixels with values normalized to [0, 1]. The input of the generator is 72-dimensional (62-dimensional random noise plus a 10-dimensional conditional class label). The generator has one fully connected layer with 6272 hidden units (128 × 7 × 7), followed by 2 transposed convolutional layers. The first transposed convolutional layer has 64 filters and uses stride 2 with zero padding 1; the second has the same structure except that the number of filters is 1. The output of the generator is therefore 28 × 28, the same size as the real images. The discriminator has 3 convolutional layers with 64, 128 and 256 filters respectively; all three layers use kernel size 4, stride 2 and zero padding 1. They are followed by 2 fully connected layers with 1024 and 11 hidden units respectively. The E-net shares the same structure as the discriminator except that the last layer has 10 units. Tab. 4 describes the network structure in detail.

Generator Discriminator E-net
Input 72 dim vector  Input 28 × 28 image  Input 28 × 28 image
FC (6272)  Conv, f:64 s:2 d:1  Conv, f:64 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:64 s:2 d:1 Conv, f:128 s:2 d:1 Conv, f:128 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:1 s:2 d:1 Conv, f:256 s:2 d:1 Conv, f:256 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2)
FC(1024), BN lRelu(0.2) FC(1024), BN lRelu(0.2)
FC(11) FC(10)

BN: batch normalization, FC: fully connected layer. f: filter number. s: stride size. d: padding size

Table 4: GAN structure for MNIST.
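A PyTorch sketch of the Table 4 generator; the 4 × 4 kernel for the transposed convolutions and the 128 × 7 × 7 reshape of the fully connected output are inferences consistent with the stride-2, padding-1 settings and the 6272 hidden units, and the final sigmoid reflects the [0, 1] pixel normalization.

```python
import torch
import torch.nn as nn

class MnistGenerator(nn.Module):
    """72-dim input (62-dim noise + 10-dim one-hot class) -> 28x28 image (cf. Table 4)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(72, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.LeakyReLU(0.2))
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),     # 14x14 -> 28x28
            nn.Sigmoid())                                          # pixel values in [0, 1]

    def forward(self, z, c_onehot):
        h = self.fc(torch.cat([z, c_onehot], dim=1))
        return self.deconv(h.view(-1, 128, 7, 7))
```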

We apply the RMSprop optimizer to all three networks with learning rate 0.0002 (decay rate: 0.98). The random noise of the generator is drawn from a uniform distribution. In each M-step, we run 5 epochs with a minibatch size of 70 for both the generated batch and the real-sample batch, using the same update frequency for the generator and the discriminator. In each E-step, we generate samples with the trained generator using a batch size of 200, and then run 1000 iterations to update the E-net.

For semi-supervised classification, we add supervision to the E-net: at the end of each E-step, we train the E-net on the labeled data until its prediction accuracy on them reaches 100%. We then feed the real data to the E-net to obtain γ, as in unsupervised clustering.

B.2 SVHN

Similar to MNIST, SVHN is a digit recognition dataset with about 53,000 training images (32 × 32 pixels). Since the SVHN images come from the real world and many of them are blurry, the recognition problem is much harder than MNIST, which is preprocessed into grey-scale images. We normalize the images to [-1, 1] and use almost the same network structure as for MNIST, with only slight differences. The details of the GAN are in Tab. 5.

Generator Discriminator E-net
Input 72 dim vector  Input 32 × 32 × 3 image  Input 32 × 32 × 3 image
FC () Conv, f:64 s:2 d:1 Conv, f:64 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:64 s:2 d:1 Conv, f:128 s:2 d:1 Conv, f:128 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:32 s:2 d:1 Conv, f:256 s:2 d:1 Conv, f:256 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:3 s:2 d:1 FC(1024), BN lRelu(0.2) FC(1024), BN lRelu(0.2)
FC(11) FC(10)
BN: batch normalization, FC: fully connected layer. f: filter number. s: stride size. d: padding size
Table 5: GAN structure for SVHN.

For semi-supervised classification, we use the same strategy as for MNIST, discussed above.

B.3 CelebA

The CelebA dataset is a large-scale human face dataset with more than 200k images, each annotated with 40 binary attributes. For example, for the gender attribute, we first select all 12k male images, and then select the same number of female images. The female images are selected so that all attributes other than gender stay as pure as possible; since we cannot guarantee that the other attributes are 100% pure, we regard the impure attributes as noise. The selected images are cropped and normalized to [-1, 1]. We apply the same strategy to the other attributes. The network details are in Tab. 6.

Generator Discriminator E-net
Input 300 dim vector Input Input
Deconv, f:1024 s:1 d:0 Conv, f:128 s:2 d:1 Conv, f:128 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:512 s:1 d:0 Conv, f:256 s:2 d:1 Conv, f:256 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:256 s:1 d:0 Conv, f:512 s:2 d:1 Conv, f:512 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:128 s:1 d:0 Conv, f:1024 s:2 d:1 Conv, f:1024 s:2 d:1
BN lRelu(0.2) BN lRelu(0.2) BN lRelu(0.2)
Deconv, f:3 s:1 d:0 Conv, f:3 s:1 d:0 Conv, f:3 s:1 d:0
BN lRelu(0.2) BN lRelu(0.2)
FC(3) FC(2)
BN: batch normalization, FC: fully connected layer. f: filter number. s: stride size. d: padding size
Table 6: GAN structure for CelebA.