Decomposed Adversarial Learned Inference

04/21/2020 ∙ by Alexander Hanbo Li, et al. ∙ University at Buffalo

Effective inference for a generative adversarial model remains an important and challenging problem. We propose a novel approach, Decomposed Adversarial Learned Inference (DALI), which explicitly matches prior and conditional distributions in both data and code spaces, and puts a direct constraint on the dependency structure of the generative model. We derive an equivalent form of the prior and conditional matching objective that can be optimized efficiently without any parametric assumption on the data. We validate the effectiveness of DALI on the MNIST, CIFAR-10, and CelebA datasets by conducting quantitative and qualitative evaluations. Results demonstrate that DALI significantly improves both reconstruction and generation as compared to other adversarial inference models.


1 Introduction

Deep directed generative models such as the variational autoencoder (VAE) [kingma2013auto, rezende2014stochastic] and the generative adversarial network (GAN) [goodfellow2014generative] have proven powerful for modeling complex high-dimensional distributions. While both VAE and GAN can learn to generate realistic images, their underlying mechanisms are fundamentally different. VAE maps the data into low-dimensional codes using an encoder, and then reconstructs the original data with a decoder. This allows it to perform both generation and inference. GAN, on the other hand, trains a generator and a discriminator adversarially. The generator learns to fool the discriminator by mapping low-dimensional noise vectors to the data space; at the same time, the discriminator evolves to distinguish the generated fake samples from the real ones. These two methods have complementary strengths and weaknesses. VAE can learn a bidirectional mapping between data and code spaces, but it relies on over-simplified parametric assumptions on the complex data distribution, which causes it to generate blurry images [donahue2016adversarial, goodfellow2014generative, larsen2015autoencoding]. GAN generates more realistic samples than VAE [radford2015unsupervised, larsen2015autoencoding] because the adversarial regime allows it to learn more complex distributions. However, GAN only learns a unidirectional mapping for data generation and does not allow inferring the latent codes of given samples. This is limiting because the ability to perform inference is crucial for several downstream applications, such as classification, clustering, similarity search, and interpretation. Furthermore, GAN also suffers from the mode collapse problem [che2016mode, salimans2016improved]: many modes of the data distribution are not represented in the generated samples.

One may therefore ask whether we can develop a generative model that enjoys the strengths of both GAN and VAE without their inherent weaknesses. Such a model should generate samples as good as those of GAN, have an inference mechanism as effective as that of VAE, and avoid the mode collapse issue. Many recent efforts have been devoted to combining VAE with adversarial discriminator(s) [brock2016neural, che2016mode, larsen2015autoencoding, makhzani2015adversarial, mescheder2017adversarial]. However, VAE-GAN hybrids tend to manifest a compromise between the strengths and weaknesses of both approaches. The main reason is that all of them retain the VAE structure, which requires an explicit metric to measure the data reconstruction and assumes over-simplified parametric data distributions. To overcome such limitations, adversarially learned inference (ALI) [donahue2016adversarial, dumoulin2016adversarially] was recently proposed, wherein the discriminator is trained on the joint distribution of data and latent codes. In this way, under a perfect discriminator, one can match the joint distributions of the decoder and encoder, and thereby perform inference by sampling from the encoder's conditional, which then also matches the decoder's posterior. In practice, however, the equilibrium of the joint adversarial game is hard to attain because the dependency structure between data and codes is not explicitly specified. The reconstructions of ALI are thus not always faithful [dumoulin2016adversarially, li2017alice], implying that its inference is not always effective.

To overcome the aforementioned issues, in this paper we propose a novel approach, decomposed adversarial learned inference (DALI), that integrates efficient inference into GAN and overcomes the limitations of prior approaches. The approach keeps the structure simple, involving only one generator, one encoder, and one discriminator. Furthermore, DALI's objective is derived directly from our goal of matching both the prior and conditional distributions of the generator and encoder, instead of being a heuristic combination with norm regularization. Compared to regular GANs, DALI can conduct inference and also does not suffer from the mode collapse problem. Moreover, DALI abandons the unrealistic parametric assumption on the conditional data distribution and does not require any reconstruction in the data space. This is fundamentally different from VAE or VAE-GAN hybrids, in which a norm is used to measure the data reconstruction; using such simple data-fitting metrics on a complex data distribution leads to worse generation performance. Different from ALI, DALI decomposes the hard problem of matching the joint distributions into two sub-tasks: explicitly matching the priors on the latent codes and the conditionals on the data. As a consequence of this more restrictive constraint, it achieves better generation and more faithful reconstruction than ALI. Note that GAN variants with an inference mechanism usually have worse generation performance than regular GANs [rosca2018distribution]. To the best of our knowledge, as demonstrated in the experiments, DALI is the first framework that further improves generation performance over GANs with the same architecture, while providing consistent inference even on complicated distributions.

2 Background

We consider the generative model $p_{\theta^*}(x, z) = p(z)\,p_{\theta^*}(x \mid z)$, where a latent variable $z$ is first generated from the prior distribution $p(z)$, and then the data $x$ is sampled from the conditional distribution $p_{\theta^*}(x \mid z)$. The parameter $\theta^*$ stands for the ground-truth parameter of the underlying distribution. The prior $p(z)$ is always assumed to be a simple parametric distribution (e.g. $\mathcal{N}(0, I)$), but the generative conditional $p_{\theta^*}(x \mid z)$ is much more complicated and unknown to us. Moreover, the posterior distribution $p_{\theta^*}(z \mid x)$ is intractable but represents an important inference procedure: given data $x$, it allows us to infer its latent variable $z$.

3 Methodology

In our method, we model the generating process by a neural network called the generator, and the inference process by another neural network called the encoder. Consider the following two distributions of the generator and encoder, and their corresponding sampling procedures:

  • the generator distribution: $p_\theta(x, z) = p(z)\,p_\theta(x \mid z)$; $z \sim p(z)$, $x \sim p_\theta(x \mid z)$.

  • the encoder distribution: $q_\phi(x, z) = q(x)\,q_\phi(z \mid x)$; $x \sim q(x)$, $z \sim q_\phi(z \mid x)$.

The generator's conditional $p_\theta(x \mid z)$ approximates the generating distribution $p_{\theta^*}(x \mid z)$. The encoder's conditional $q_\phi(z \mid x)$ approximates the posterior distribution $p_{\theta^*}(z \mid x)$, which is what we need for inference. The marginal distribution $q(x)$ stands for the empirical data distribution, and the other marginal is taken to be the prior $p(z)$, which is always a known distribution such as the standard Gaussian.
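The two sampling procedures can be sketched as follows. This is a minimal NumPy illustration, not the authors' code. The linear `generator` and `encoder`, the fixed encoder variance, and the toy dimensions are hypothetical stand-ins for the neural networks. The generator is deterministic, as in the experiments, so $x \sim p_\theta(x \mid z)$ reduces to $x = G(z)$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D = 2, 4                      # hypothetical latent and data dimensions

# Hypothetical stand-ins for the networks: a deterministic linear generator
# G(z) = W_g z and a linear-Gaussian encoder with mean W_e x.
W_g = rng.normal(size=(D, d))
W_e = rng.normal(size=(d, D))

def generator(z):
    return z @ W_g.T             # deterministic generator, x = G(z)

def encoder(x):
    mu = x @ W_e.T
    var = np.full_like(mu, 0.1)  # assumed constant variance, for illustration only
    return mu, var

def sample_generator_joint(n):
    """Draw (x, z) from p_theta(x, z) = p(z) p_theta(x | z)."""
    z = rng.standard_normal((n, d))   # z ~ p(z) = N(0, I)
    x = generator(z)                  # x ~ p_theta(x | z), here a point mass at G(z)
    return x, z

def sample_encoder_joint(data, n):
    """Draw (x, z) from q_phi(x, z) = q(x) q_phi(z | x)."""
    idx = rng.integers(0, len(data), size=n)
    x = data[idx]                     # x ~ q(x), the empirical data distribution
    mu, var = encoder(x)
    z = mu + np.sqrt(var) * rng.standard_normal(mu.shape)  # z ~ q_phi(z | x)
    return x, z

data = rng.normal(size=(100, D))      # toy "dataset"
xg, zg = sample_generator_joint(8)
xe, ze = sample_encoder_joint(data, 8)
```

Ancestral sampling in the generator runs from $z$ to $x$, while the encoder runs from $x$ to $z$. This asymmetry is why some of the encoder-side quantities below cannot be sampled directly.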

3.1 Decomposition of the Joint Distribution

The ultimate goal is to match the joint distributions $p_\theta(x, z)$ and $q_\phi(x, z)$. If this is achieved, all marginals and all conditionals are guaranteed to match as well. In particular, the conditional $q_\phi(z \mid x)$ matches the posterior $p_\theta(z \mid x)$. We propose to decompose this goal into two sub-tasks: matching the priors $p(z)$ and $q_\phi(z)$, and matching the conditionals $p_\theta(x \mid z)$ and $q_\phi(x \mid z)$. Here $q_\phi(z) = \int q(x)\,q_\phi(z \mid x)\,dx$ and $q_\phi(x \mid z)$ are the marginal and the conditional implied by the encoder's joint distribution. This has two advantages. First, we explicitly define the dependency structure $z \to x$. Second, explicit constraints on both priors and conditionals are stronger than a single constraint on the joint distributions.

More formally, we decompose the problem of minimizing the divergence between $p_\theta(x, z)$ and $q_\phi(x, z)$ into matching both the prior and conditional distributions, that is, into minimizing

(1)

Note that (1) is not identical to ALI's objective, but their minima are attained at the same point. By the properties of KL-divergence, when the minimum of (1) is attained, we have $p(z) = q_\phi(z)$ and $p_\theta(x \mid z) = q_\phi(x \mid z)$ for all $x$ and $z$, and hence $p_\theta(x, z) = q_\phi(x, z)$.
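The chain rule of the KL-divergence shows why matching priors and conditionals is enough. Write the encoder joint the other way round, as $q_\phi(x, z) = q_\phi(z)\,q_\phi(x \mid z)$. Assume that (1) places the generator-side distributions in the first argument (the direction is our assumption). Then the two terms add up exactly to the KL-divergence between the joints:

```latex
\begin{aligned}
\mathrm{KL}\big(p_\theta(x,z)\,\|\,q_\phi(x,z)\big)
&= \mathbb{E}_{p_\theta(x,z)}\left[\log \frac{p(z)\,p_\theta(x \mid z)}{q_\phi(z)\,q_\phi(x \mid z)}\right] \\
&= \mathrm{KL}\big(p(z)\,\|\,q_\phi(z)\big)
 + \mathbb{E}_{p(z)}\Big[\mathrm{KL}\big(p_\theta(x \mid z)\,\|\,q_\phi(x \mid z)\big)\Big].
\end{aligned}
```

Both terms are non-negative. The sum is therefore zero exactly when the priors agree and the conditionals agree for almost every $z$.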

3.2 Objective Function

The objective (1) cannot be optimized directly because both $q_\phi(z)$ and $q_\phi(x \mid z)$ are intractable, as the flow in the encoder is from $x$ to $z$. However, we prove that the intractable (1) can be rephrased as the combination of a KL-divergence term and a reconstruction term. Both terms involve only distributions that can either be sampled from or evaluated directly.

First, by the definition of KL-divergence, for any fixed $z$,

(2)

Then, by Bayes' theorem, we have $q_\phi(x \mid z) = q(x)\,q_\phi(z \mid x) / q_\phi(z)$. Plugging this identity into (2) and doing some algebra, we get
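Written out in logs, the Bayes step and its substitution into the conditional KL-divergence look as follows. This is our own expansion and assumes that (2) is $\mathrm{KL}(p_\theta(x \mid z)\,\|\,q_\phi(x \mid z))$; the arrangement may differ from the exact form of (3).

```latex
\begin{aligned}
\log q_\phi(x \mid z) &= \log q(x) + \log q_\phi(z \mid x) - \log q_\phi(z), \\
\mathrm{KL}\big(p_\theta(x \mid z)\,\|\,q_\phi(x \mid z)\big)
&= \mathbb{E}_{p_\theta(x \mid z)}\big[\log p_\theta(x \mid z) - \log q(x) - \log q_\phi(z \mid x)\big]
 + \log q_\phi(z).
\end{aligned}
```

The last term depends only on $z$. Under the same assumption, averaging it over $p(z)$ gives $\mathbb{E}_{p(z)}[\log q_\phi(z)]$, which cancels the matching term of opposite sign inside $\mathrm{KL}(p(z)\,\|\,q_\phi(z))$. That cancellation is how the intractable $q_\phi(z)$ can drop out of the combined objective.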

(3)

Next, for the second term of (1), we again write out the definition of the KL-divergence:

Then we have

(4)

where $C$ is a constant because the prior $p(z)$ is a fixed parametric distribution; for example, when $p(z) = \mathcal{N}(0, I)$, $C$ depends only on the dimension of $z$. Therefore, minimizing the objective (1) is transformed into minimizing the new objective
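For a standard Gaussian prior, the constant that follows (4) (called $C$ here) has a closed form if it collects $\mathbb{E}_{p(z)}[\log p(z)]$, the negative differential entropy of the prior. That identification is our assumption. The sketch compares the closed form $-\tfrac{d}{2}(1 + \log 2\pi)$ with a Monte Carlo estimate.

```python
import numpy as np

def gaussian_prior_constant(d):
    """Closed form of E_{z ~ N(0, I_d)}[log N(z; 0, I_d)] = -(d/2)(1 + log 2*pi)."""
    return -0.5 * d * (1.0 + np.log(2.0 * np.pi))

def mc_prior_constant(d, n=200_000, seed=0):
    """Monte Carlo estimate of the same expectation from n prior samples."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n, d))
    log_p = -0.5 * np.sum(z ** 2, axis=1) - 0.5 * d * np.log(2.0 * np.pi)
    return log_p.mean()

for d in (1, 8, 64):
    print(d, gaussian_prior_constant(d), mc_prior_constant(d))
```

The constant depends only on $d$, not on $\theta$ or $\phi$. It therefore has no effect on the optimization.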

(5)

Intuitively, term I measures the difference between the generated and real samples, and term II measures the reconstruction of the latent codes. We summarize the above procedure in a proposition.

Proposition 1.

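The final objective function (5) is minimized when $p(z) = q_\phi(z)$ and $p_\theta(x \mid z) = q_\phi(x \mid z)$ for all $x$, $z$. Hence, at the minimum, the joint distributions satisfy $p_\theta(x, z) = q_\phi(x, z)$.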

3.3 Relation to the Variational Autoencoder

Written in our notation, the VAE [kingma2013auto] method relies on the following identity:

(6)

Then, by the non-negativity of KL-divergence, we have $\log p_\theta(x) \ge \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x)\,\|\,p(z))$. Hence maximizing the log-likelihood of the observations can be replaced by maximizing this evidence lower bound (ELBO). A closer look at (6), compared with (3), shows that (6) is a decomposition of the KL-divergence between two conditionals, $q_\phi(z \mid x)$ and $p_\theta(z \mid x)$. We can therefore follow the same approach as after (3) and obtain the following identity:
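For reference, the standard identity behind VAE reads as follows in this notation, with $p(z)$ the prior, $p_\theta(x \mid z)$ the decoder and $q_\phi(z \mid x)$ the encoder. We assume this is what (6) expresses. It follows from writing $\log p_\theta(x) = \log p_\theta(x, z) - \log p_\theta(z \mid x)$ and taking the expectation over $q_\phi(z \mid x)$:

```latex
\log p_\theta(x)
= \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
  - \mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big)}_{\mathrm{ELBO}(x)}
  + \mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\big).
```

The final KL-divergence is the conditional-matching term in the latent space. Dropping it yields the lower bound.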

(7)

We denote and . Since the marginal $q(x)$ stands for the empirical data distribution, the right-hand side of (7) is the empirical expectation of the negative ELBO, which is what VAE minimizes. We then conclude from (7) that VAE performs marginal distribution matching in the data space and conditional distribution matching in the latent space. Rosca et al. [rosca2018distribution] make the same distribution-matching observation about VAE.

However, the marginal distributions in the data space are very complex. In addition, the conditional distributions matched in the latent space run in the direction opposite to the generating process $z \to x$. Hence, in order to match these distributions, VAE's objective has a reconstruction term on $x$ and a regularization term on the latent $z$. But to evaluate both terms, we need to make parametric assumptions on both conditionals $p_\theta(x \mid z)$ and $q_\phi(z \mid x)$. The assumption on $q_\phi(z \mid x)$ can be loosened using GANs [makhzani2015adversarial], but the assumption on $p_\theta(x \mid z)$ is critical and limits the performance of VAE-GAN hybrids.

Our model, DALI, instead performs marginal distribution matching in the latent space and conditional distribution matching in the data space. From (5), since term I will be replaced with an adversarial game (see Section 3.4), the only assumption we need to make is on term II, that is, on the conditional $q_\phi(z \mid x)$. Our model remains very flexible in its dependence on $x$. This assumption is much weaker than one on $p_\theta(x \mid z)$ and does not lead to the problems of VAE or VAE-GANs (e.g. blurriness).

3.4 DALI framework

The KL-divergence part (I) can be replaced by an adversarial game using the $f$-divergence theory [nowozin2016f]. The reconstruction term (II) is a log-likelihood and can be evaluated easily if we assume a parametric $q_\phi(z \mid x)$. Therefore, our framework requires exactly one generator G, one discriminator D, and one encoder E. We now discuss in detail how to play the adversarial game and how to measure the reconstruction.

Adversarial game

Because we do not want to make any parametric assumption on the generated data distribution $p_\theta(x)$, an adversarial game is played to distinguish $p_\theta(x)$ from the data distribution $q(x)$. Following the theory of $f$-GAN [nowozin2016f], we construct an adversarial game with the value function

(8)

Under a perfect discriminator, finding the optimal generator of (8) is then equivalent to minimizing the KL-divergence. The activation function for the discriminator in (8) is simply the identity mapping, instead of the sigmoid function used in the original GAN. But, just as in the original GAN, the generator of (8) also suffers from the vanishing gradient problem [goodfellow2016nips]. Therefore, in our experiments, we maximize (8) for the discriminator D, but minimize a modified objective for the generator G. We call the algorithm using this value function DALI-.
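The identity-activation game can be made concrete with a NumPy check of the $f$-GAN bound for the KL-divergence [nowozin2016f]. With conjugate $f^*(t) = e^{t-1}$ and identity output activation, $V(T) = \mathbb{E}_P[T(x)] - \mathbb{E}_Q[e^{T(x)-1}] \le \mathrm{KL}(P\,\|\,Q)$. Equality holds at $T^*(x) = 1 + \log(p(x)/q(x))$. We assume this is the form behind (8). Which of the generated and real distributions plays $P$ cannot be recovered from the text here, so the sketch uses generic 1D Gaussians $P$ and $Q$. The estimates are Monte Carlo, so they match the true KL only approximately.

```python
import numpy as np

def kl_fgan_value(T, samples_p, samples_q):
    """Monte Carlo estimate of the f-GAN value for KL(P || Q) with identity
    output activation: V(T) = E_P[T(x)] - E_Q[exp(T(x) - 1)] <= KL(P || Q)."""
    return np.mean(T(samples_p)) - np.mean(np.exp(T(samples_q) - 1.0))

rng = np.random.default_rng(0)
mu_p, mu_q = 0.0, 1.0
xp = rng.normal(mu_p, 1.0, size=500_000)   # samples from P = N(0, 1)
xq = rng.normal(mu_q, 1.0, size=500_000)   # samples from Q = N(1, 1)

def log_ratio(x):
    """log p(x) - log q(x) for the two unit-variance Gaussians."""
    return -0.5 * (x - mu_p) ** 2 + 0.5 * (x - mu_q) ** 2

def t_optimal(x):
    """Optimal critic T*(x) = 1 + log(p(x) / q(x))."""
    return 1.0 + log_ratio(x)

def t_constant(x):
    """A deliberately poor critic that outputs 0 everywhere."""
    return np.zeros_like(x)

true_kl = 0.5 * (mu_p - mu_q) ** 2          # closed form for unit-variance Gaussians
v_opt = kl_fgan_value(t_optimal, xp, xq)    # approximately true_kl
v_bad = kl_fgan_value(t_constant, xp, xq)   # a looser lower bound
```

The discriminator maximizes $V$ over $T$. Because the output activation is the identity, $T$ can take any real value, which is what allows it to reach $1 + \log(p/q)$.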

As shown by Fedus et al. and Lucic et al. [fedus2017many, lucic2018gans], the equilibrium of the adversarial game is hard to attain in practice. Moreover, because of the vanishing gradient problem, we do not train the generator with the theoretical value function. We therefore also try WGAN and GAN for the adversarial game in our experiments, and find that GAN provides consistently better and more stable results.

Reconstruction

Because the distribution of $z$ is simple, we make a reasonable parametric assumption on $q_\phi(z \mid x)$ so that its log-likelihood can be calculated explicitly. In this paper we assume $q_\phi(z \mid x) = \mathcal{N}\big(\mu(x), \operatorname{diag}(\sigma^2(x))\big)$ and define

(9)

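where $d$ is the dimension of the latent variable $z$. In this case, the encoder network only needs to output two vectors, the mean $\mu(x)$ and the variance $\sigma^2(x)$, that is, $E(x) = (\mu(x), \sigma^2(x))$. We can then compute the approximate negative posterior log-likelihood by plugging a prior sample $z$ and its generated sample $G(z)$ into (9).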
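The diagonal-Gaussian log-likelihood in (9) can be sketched as follows. The sketch makes two assumptions. First, the encoder outputs log-variances, a common choice for numerical stability that the text does not state. Second, (9) is the negative log-density of $\mathcal{N}(\mu(x), \operatorname{diag}(\sigma^2(x)))$ summed over the $d$ dimensions. Dividing by $d$ gives the per-dimension average mentioned in Section 4.

```python
import numpy as np

def gaussian_nll(z, mu, log_var):
    """-log N(z; mu, diag(exp(log_var))), summed over the last axis (the d latent dims)."""
    return 0.5 * np.sum(
        np.log(2.0 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var), axis=-1
    )

def per_dimension_nll(z, mu, log_var):
    """The same quantity divided by d, i.e. the average over latent dimensions."""
    return gaussian_nll(z, mu, log_var) / z.shape[-1]

rng = np.random.default_rng(0)
z = rng.standard_normal((3, 4))                # latent codes drawn from the prior
mu = rng.standard_normal((3, 4))               # hypothetical encoder means mu(G(z))
log_var = rng.normal(scale=0.1, size=(3, 4))   # hypothetical encoder log-variances
nll = gaussian_nll(z, mu, log_var)             # one value per example
```

Parameterizing the variance through its logarithm keeps $\sigma^2(x)$ positive without constraining the network output.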

Final Framework

To summarize, our final optimization problem is

(10)

Here, $\lambda$ is a hyper-parameter that must be set so that the two parts of (10) are on the same scale. We discuss the choice of $\lambda$ in the experiment section.

3.5 Training and Inference Procedures

The training procedure is summarized in Algorithm 1. Given random $z \sim p(z)$, we first generate samples $\tilde{x} = G(z)$ using the generator. Then the discriminator is updated to distinguish between generated and real samples. The encoder outputs the parameters of the distribution $q_\phi(z \mid \tilde{x})$, from which we calculate the log-likelihood in (II). Then the generator and encoder are updated together to minimize the reconstruction error (i.e. maximize the expected log-likelihood), while the generator has the extra goal of fooling the discriminator. For any data point $x$, its inferred latent code is set to the conditional mean $\mu(x)$, and the reconstruction of $x$ is $G(\mu(x))$. Besides the reconstruction, we can also generate more samples that are close to $x$ in the sense that they have similar latent codes. This is done by first sampling $z$'s from the posterior $q_\phi(z \mid x)$ and then mapping them to the data space with the generator.
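The inference procedure just described can be sketched in NumPy. It covers the inferred code $\mu(x)$, the reconstruction $G(\mu(x))$, and the generation of nearby samples from $q_\phi(z \mid x)$. Everything in the sketch is hypothetical. The networks are linear, and the encoder mean is the pseudo-inverse of the generator, so reconstruction is exact on the generator's range. The posterior variance is fixed. The reconstruction MSE computed at the end is the kind of quantity reported in Section 4.1.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D = 2, 5                                  # hypothetical latent and data dimensions
W_g = rng.normal(size=(D, d))                # hypothetical linear generator weights
W_mu = np.linalg.pinv(W_g)                   # hypothetical encoder mean: pseudo-inverse of G
LOG_VAR = np.log(0.01)                       # assumed fixed posterior log-variance

def generator(z):
    return z @ W_g.T                         # x = G(z), deterministic

def encoder(x):
    mu = x @ W_mu.T                          # mean of q_phi(z | x)
    return mu, np.full_like(mu, LOG_VAR)     # (mean, log-variance)

def infer(x):
    """Inferred latent code: the conditional mean mu(x)."""
    return encoder(x)[0]

def reconstruct(x):
    """Reconstruction G(mu(x))."""
    return generator(infer(x))

def sample_neighbors(x, k):
    """k generated points per input, with latents drawn from q_phi(z | x)."""
    mu, log_var = encoder(x)
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal((k,) + mu.shape)
    return generator(z)

z_true = rng.standard_normal((4, d))
x_test = generator(z_true)                   # test points on the generator's range
mse = np.mean((reconstruct(x_test) - x_test) ** 2)   # reconstruction MSE
neighbors = sample_neighbors(x_test, 3)      # shape (3, 4, D)
```

With trained networks the MSE would not be zero. The pseudo-inverse choice only makes the sketch easy to check.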

repeat
      Draw samples $z$ from the prior $p(z)$
      Generate samples $\tilde{x} = G(z)$ using the generator network
      Calculate mean and variance of $q_\phi(z \mid \tilde{x})$
      Compute discriminator predictions
      Compute discriminator loss
      Compute generator loss
      Compute encoder loss
      Compute reconstruction loss
      Gradient update on discriminator network
      Gradient update on generator and encoder networks
until convergence
Algorithm 1 The DALI training procedure.
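One iteration of Algorithm 1 can be sketched as the forward computation of its losses. This NumPy version is an illustration only, and it rests on four simplifications:

- The linear networks, dimensions and batch size are hypothetical.
- The adversarial losses are swapped in from the standard GAN: the usual discriminator loss and the non-saturating generator loss. The exact losses are not recoverable here.
- The "encoder loss" and "reconstruction loss" of Algorithm 1 are merged into a single Gaussian negative log-likelihood of $z$.
- The two gradient updates are omitted. They would be performed by an autodiff framework such as PyTorch, which the authors used.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D, m = 2, 4, 16                      # latent dim, data dim, batch size (hypothetical)
W_g = rng.normal(size=(D, d))           # hypothetical linear generator G
W_d = rng.normal(size=D)                # hypothetical linear discriminator (logit)
W_mu = rng.normal(size=(d, D))          # hypothetical encoder mean head
b_lv = np.zeros(d)                      # hypothetical encoder log-variance

def softplus(t):
    return np.logaddexp(0.0, t)         # log(1 + exp(t)), numerically stable

def dali_losses(x_real, lam):
    z = rng.standard_normal((m, d))                 # draw z ~ p(z)
    x_fake = z @ W_g.T                              # generate x = G(z)
    mu = x_fake @ W_mu.T                            # mean of q_phi(z | x)
    log_var = np.broadcast_to(b_lv, mu.shape)       # log-variance of q_phi(z | x)
    s_real, s_fake = x_real @ W_d, x_fake @ W_d     # discriminator logits
    # Standard GAN discriminator loss: -log D(x_real) - log(1 - D(x_fake))
    loss_d = np.mean(softplus(-s_real) + softplus(s_fake))
    # Non-saturating generator loss: -log D(G(z))
    loss_g = np.mean(softplus(-s_fake))
    # Reconstruction of z: Gaussian negative log-likelihood under q_phi(z | x)
    nll = 0.5 * (np.log(2 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var))
    loss_rec = np.mean(nll.sum(axis=1))
    # Gradient updates are omitted: D would descend loss_d,
    # G and E would jointly descend loss_g + lam * loss_rec.
    return loss_d, loss_g + lam * loss_rec, loss_rec

x_real = rng.normal(size=(m, D))                    # toy batch of "real" data
loss_d, loss_ge, loss_rec = dali_losses(x_real, lam=1.0 / d)
```

Here `lam` plays the role of the weighting hyper-parameter in (10). The value $1/d$ mirrors the per-dimension averaging described in Section 4.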

4 Experiments

We evaluate our proposed method, DALI, on both reconstruction and generation tasks, using the data sets MNIST [lecun1998gradient], CIFAR-10 [krizhevsky2009learning] and CelebA [liu2015deep]. To show the effectiveness of DALI in reducing mode collapse, we also conduct the same 2D Gaussian mixture experiment as Dumoulin et al. [dumoulin2016adversarially]. The architectures of our discriminator and generator are slightly simpler versions of DCGAN [radford2015unsupervised] and can easily be replaced by more advanced state-of-the-art GANs. We use a deterministic generator throughout the experiments. Our encoder network consists of convolutional layers followed by two separate fully connected networks, which predict the mean and the variance of the posterior $q_\phi(z \mid x)$, respectively. We use the Adam optimizer [kingma2014adam] with the learning rate decay strategy suggested by Kingma and Ba [kingma2014adam]. Since there are $d$ summands in (9), we simply set $\lambda = 1/d$ in our experiments, which amounts to the average distance on each dimension of $z$. We also observe that the discriminator shares a similar task with the encoder: both need to extract higher-level features from raw images. Therefore, to reduce the number of parameters and to stabilize training, our encoder takes the intermediate hidden representation learned by the discriminator as its input. Note that the encoder does not update these shared feature-extracting layers. We use PyTorch 1.1 to implement our model.

4.1 Quantitative Results on Real Datasets

In this section, we use quantitative measures (MSE, Inception Score (IS), Fréchet Inception Distance (FID)) to compare the inference and generation performance of DALI, GAN, ALI and ALICE. For a fair comparison, GAN is implemented with the same generator and discriminator as DALI. We also include a reduced version of DALI, named DALI-, in which the encoder's conditional distribution $q_\phi(z \mid x)$ is assumed to be a Gaussian with identity covariance matrix. To evaluate inference, we reconstruct test images and calculate the mean squared error (MSE), as in Li et al. [li2017alice]. For generation, we calculate the inception score [salimans2016improved] on randomly generated images. On MNIST, the inception scores are evaluated with the pre-trained classifier from Li et al. [li2017alice]. On CIFAR-10, they are based on an Inception network pre-trained on ImageNet. The quantitative results are summarized in Table 1.

          MNIST                               CIFAR-10
Method    MSE              Inception Score    MSE              Inception Score    FID
DALI      0.026 ± 0.018    9.483 ± 0.020      0.019 ± 0.009    6.450 ± 0.085      28.4
DALI-     0.028 ± 0.018    9.331 ± 0.021      0.037 ± 0.017    6.324 ± 0.056      29.2
GAN       -                9.464 ± 0.020      -                6.287 ± 0.061      37.1
ALI       0.480 ± 0.100    8.749 ± 0.090      0.672 ± 0.113    5.930 ± 0.044      58.9
ALICE     0.080 ± 0.007    9.279 ± 0.070      0.416 ± 0.202    6.015 ± 0.028      -
Table 1: MSE (lower is better), Inception Score (higher is better) and FID (lower is better) on MNIST and CIFAR-10. ALI and ALICE results are from the experiments of Li et al. [li2017alice].

Inference

From Table 1, DALI achieves the best reconstruction results on both data sets. On MNIST, DALI decreases the MSE by 68% and 95% relative to ALICE and ALI, respectively. On the more complicated CIFAR-10 data set, DALI decreases the MSE by 95% and 97%. To alleviate the non-identifiability issue of ALI, ALICE adds a conditional entropy constraint by explicitly regularizing the norm between the reconstructed and real images. However, as the data distribution becomes more complicated, as in CIFAR-10, such norms become inadequate for measuring reconstruction. Consequently, ALICE's reconstruction error on CIFAR-10 is much higher than on MNIST. In contrast, the reconstruction performance of DALI is consistent on both data sets. The reason is that our model explicitly specifies the dependency structure of the generative model, and matches both prior and conditional distributions without using simple data-fitting metrics in the data space. The performance of DALI-, which follows the same structure, further supports this. Compared with DALI-, DALI further decreases the MSE on CIFAR-10 by a relative 49%, which shows that the inferred conditional variance is crucial for faithful reconstructions on complicated data sets.
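The relative reductions quoted in this paragraph can be recomputed from the mean MSE values in Table 1. The short script below copies those values and prints each reduction.

```python
# Mean MSE values copied from Table 1 (standard deviations omitted).
mse = {
    "MNIST":    {"DALI": 0.026, "DALI-": 0.028, "ALI": 0.480, "ALICE": 0.080},
    "CIFAR-10": {"DALI": 0.019, "DALI-": 0.037, "ALI": 0.672, "ALICE": 0.416},
}

def relative_reduction(new, old):
    """Percentage by which the error `new` is lower than the error `old`."""
    return 100.0 * (old - new) / old

for dataset, row in mse.items():
    for baseline in ("ALICE", "ALI", "DALI-"):
        r = relative_reduction(row["DALI"], row[baseline])
        print(f"{dataset}: DALI vs {baseline}: {r:.1f}% lower MSE")
```

The rounded figures in the text (68%, 95%, 95%, 97% and 49%) agree with these ratios to within half a percentage point.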

Figure 1: Reconstruction comparison between our proposed model DALI (first row) and ALI (BiGAN) (second row) on the MNIST, CIFAR-10 and CelebA datasets. In each subfigure, the odd columns show original samples from the test set and the even columns show their reconstructions.

Generation

DALI outperforms all the baseline models, including GAN, on inception score, and also achieves the lowest FID on CIFAR-10. This suggests that DALI can further improve generation performance rather than deteriorate it, as the other baselines do. Both ALI and ALICE perform worse than GAN on generation because matching two complicated joint distributions, $p_\theta(x, z)$ and $q_\phi(x, z)$, is more difficult than the task of the regular GAN, which only matches the marginals $p_\theta(x)$ and $q(x)$. The proposed model DALI explicitly defines the dependency structure between $z$ and $x$, which is more effective than one-step joint distribution matching. The comparison between DALI and DALI- shows that the learned variance is also critical for better generation performance. We also highlight that DALI's generation performance can be further improved by replacing the adversarial network with more advanced state-of-the-art GANs.

4.2 Visualization of the Reconstructions

In Figure 1, we compare the reconstructions of DALI with the results reported for ALI [dumoulin2016adversarially] (BiGAN [donahue2016adversarial]). From the first column of Figure 1, we observe that ALI provides a certain level of reconstruction but fails to capture the precise style of the original digits. In contrast, DALI achieves very sharp and faithful reconstructions. On CIFAR-10, ALI's reconstructions are less faithful and often make mistakes in object placement, color, style, and object identity. Our model produces better reconstructions in all these aspects. On CelebA, DALI reproduces similar style, color and face placement, and even preserves face identity to a high degree. Dumoulin et al. [dumoulin2016adversarially] attribute ALI's unfaithful reconstructions to underfitting. This further suggests that our adversarial regime (marginal and conditional distribution matching) is more effective for inference than joint distribution matching regimes.

4.3 Mode Collapse Reduction

Method          Modes (max 25)    % High-quality samples
GAN             3.3               0.5
ALI             15.84             1.6
Unrolled GAN    23.6              16
VAEGAN          21.4              34.1
VEEGAN          24.6              40
SN-GAN          25                67.8
DALI            25                81.1
DALI-           25                66.4
Table 2: Degree of mode collapse, measured by the number of modes captured (higher is better) and the percentage of high-quality samples (higher is better) on 2D grid data. The baseline results of GAN, ALI and Unrolled GAN are reported by Srivastava et al. [srivastava2017veegan].

To show the effectiveness of our model in reducing mode collapse, we perform the same synthetic experiment as Dumoulin et al. [dumoulin2016adversarially]. The data is a 2D Gaussian mixture of 25 components laid out on a 5 × 5 grid. To quantify the degree of mode collapse, we use the two metrics of Srivastava et al. [srivastava2017veegan]: the number of modes captured and the percentage of high-quality samples. A generated sample is counted as high quality if it is within three standard deviations of the nearest mode. The number of modes captured is then the number of mixture components whose mean is nearest to at least one high-quality sample. We compare the proposed methods DALI and DALI- to ALI, Unrolled GAN [metz2016unrolled], VAEGAN [larsen2015autoencoding], VEEGAN [srivastava2017veegan] and SN-GAN [miyato2018spectral]. As shown in Table 2, DALI consistently provides the best performance on both measures. More specifically, DALI captures all 25 modes every time and generates more than 80% high-quality samples. This suggests that DALI significantly alleviates the mode collapse issue of the GAN framework and hence further improves generation performance.
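The two mode-collapse metrics can be computed as in the sketch below. Three details are our assumptions, not values from the text: the grid spacing, the component standard deviation, and the reading of "within three standard deviations" as a Euclidean distance below $3\sigma$ for isotropic components.

```python
import numpy as np

def grid_modes(k=5, spacing=2.0):
    """Means of a k x k grid of mixture components centred at the origin
    (the spacing is a hypothetical choice)."""
    coords = (np.arange(k) - (k - 1) / 2) * spacing
    return np.array([(a, b) for a in coords for b in coords])

def mode_collapse_metrics(samples, modes, std):
    """Return (number of modes captured, percentage of high-quality samples).

    A sample is high quality if its distance to the nearest mode is below
    3 * std; a mode is captured if it is the nearest mode of at least one
    high-quality sample."""
    dists = np.linalg.norm(samples[:, None, :] - modes[None, :, :], axis=-1)
    nearest = dists.argmin(axis=1)
    high_quality = dists.min(axis=1) < 3.0 * std
    n_captured = len(np.unique(nearest[high_quality]))
    return n_captured, 100.0 * high_quality.mean()

rng = np.random.default_rng(0)
modes = grid_modes()
std = 0.05                                   # hypothetical component standard deviation
samples = modes[rng.integers(0, 25, 2000)] + std * rng.standard_normal((2000, 2))
print(mode_collapse_metrics(samples, modes, std))
```

A collapsed generator that places all its samples on one mode scores a single captured mode even when every sample is high quality. This is why the two metrics are reported together.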

5 Related Work

The most straightforward way to learn an inference mechanism is to learn the inverse mapping of GAN's generator post hoc [zhu2016generative]. However, since its training process is the same as GAN's, it still suffers from the mode collapse problem. InfoGAN [chen2016infogan] maximizes the mutual information between a subset of the latent code and the generated samples, and hence can only perform partial inference on $z$. AGE [ulyanov2017takes] encourages the encoder and generator to be reciprocal by simultaneously minimizing a reconstruction error in the data space and an error in the code space. This is closely related to the cycle-consistency criterion [zhu2017unpaired, kim2017learning, yi2017dualgan, li2017alice]. Although the pairwise reconstruction errors help reduce mode collapse, the data reconstruction is still measured by a simple norm (such as $\ell_1$ or $\ell_2$), which brings the same problem as in VAE and VAE-GAN hybrids. The main difference between our method and VAE is not which divergence we use, but in which space we calculate it. VAE calculates the divergence on the $z$-space, whereas DALI calculates it on $x$. Putting the (reverse) KL-divergence on $x$ allows us to play the adversarial game on the more complicated distribution of $x$, while leaving the parametric reconstruction to the simpler $z$.

Unlike heuristic combinations of VAE and GANs, Mescheder et al. [mescheder2017adversarial] theoretically derived an adversarial game that replaces the KL-divergence term in the variational lower bound (ELBO). This gives their method, adversarial variational Bayes (AVB), much more flexibility in its dependence on the latent $z$. However, the reconstruction term on $x$ still exists, and so does the parametric assumption on the conditional data distribution, leading to blurriness in their reconstructed and generated samples.

ALI [dumoulin2016adversarially, donahue2016adversarial] is an elegant approach to bringing an inference mechanism into adversarial learning without assuming a parametric distribution on the data. Different from our work, it directly plays an adversarial game to match the joint distributions of the decoder and encoder. In practice, however, ALI's reconstructions are not necessarily faithful because the dependency structures within the two joint distributions are not specified [li2017alice]. ALICE [li2017alice] tries to solve this problem by regularizing ALI with an extra conditional entropy constraint on the data. The conditional entropy is either explicitly measured by a norm or implicitly learned by adversarial training. However, when the data distribution becomes complicated (e.g. CIFAR-10), the norm-based metric may lead to blurry reconstructions, and the adversarial training is hard to carry out [li2017alice]. Compared with ALI and ALICE, our method provably minimizes the KL-divergence between both the priors and the conditionals of the generator and encoder. It also provides consistently effective inference even on complicated distributions (see Section 4.1).

Srivastava et al. [srivastava2017veegan] proposed VEEGAN to tackle the mode collapse issue of GANs by adding implicit variational learning on the latent $z$. To the best of our knowledge, this is the only other approach that also reconstructs $z$. Unlike VAEs, VEEGAN autoencodes the latent variable, or noise, $z$. By doing so, it forces the generator not to collapse the mappings of $z$ to a single mode, because otherwise the encoder would not be able to recover all the noise $z$. Their model can be summarized as ALI regularized by an extra reconstruction of the latent $z$. VEEGAN is therefore similar to ALICE in the sense that both play adversarial games on the joint distribution with an extra regularization on either data or latent reconstruction. Our model DALI instead plays the adversarial game only on the marginal data distribution, and reconstructs the latent $z$ by maximizing its log-likelihood under the latent posterior distribution.

6 Conclusion and Future Work

We proposed a novel framework, DALI, which matches both the prior and the conditional distributions of the generator and the encoder. Adversarial inference is incorporated into this framework, and there is no parametric assumption on the conditional data distribution. We show in the experiments that the proposed method not only allows efficient inference but also improves image generation.

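The assumption on $q_\phi(z \mid x)$ can be further relaxed by using an autoregressive $q_\phi(z \mid x)$. However, the same technique cannot easily be applied to the conditionals on the data side. Therefore, we believe that reconstructing the latent code ($z \to x \to z$) is more expressive than the opposite direction ($x \to z \to x$).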

References