Metropolis-Hastings Generative Adversarial Networks
We introduce the Metropolis-Hastings generative adversarial network (MH-GAN), which combines aspects of Markov chain Monte Carlo and GANs. The MH-GAN draws samples from the distribution implicitly defined by a GAN's discriminator-generator pair, as opposed to sampling in a standard GAN which draws samples from the distribution defined by the generator. It uses the discriminator from GAN training to build a wrapper around the generator for improved sampling. With a perfect discriminator, this wrapped generator samples from the true distribution on the data exactly even when the generator is imperfect. We demonstrate the benefits of the improved generator on multiple benchmark datasets, including CIFAR-10 and CelebA, using DCGAN and WGAN.
Traditionally, density estimation is done with a model that can compute the data likelihood. Generative adversarial networks (GANs) (Goodfellow et al., 2014) present a radically new way to do density estimation: they implicitly represent the density of the data via a classifier that distinguishes real from generated data.
GANs iterate between updating a discriminator $D$ and a generator $G$, where $G$ generates new (synthetic) samples of data, and $D$ attempts to distinguish samples of $G$ from the real data.
In the typical setup, $D$ is thrown away at the end of training, and only $G$ is kept for generating new synthetic data points.
In this work we propose the Metropolis-Hastings GAN (MH-GAN), a GAN that constructs a new generator $G'$ that “wraps” $G$ using the information contained in $D$.
This principle is illustrated in Figure 1.
The code for this project is available at https://github.com/uber-research/metropolis-hastings-gans.
The MH-GAN uses Markov chain Monte Carlo (MCMC) methods to sample from the distribution implicitly defined by the discriminator $D$ learned for the generator $G$. This is built upon the notion that the discriminator classifies between the generator $G$ and a data distribution:
$$D(x) = \frac{p_D(x)}{p_D(x) + p_G(x)}, \tag{1}$$
where $p_G$ is the (intractable) density of samples from the generator $G$, and $p_D$ is the data density implied by the discriminator $D$ with respect to $G$. If GAN training reaches its global optimum, then this discriminator distribution $p_D$ is equal to both the data distribution and the generator distribution ($p_D = p_{\text{data}} = p_G$) (Goodfellow et al., 2014). Furthermore, if the discriminator is optimal for a fixed imperfect generator $G$, then the implied distribution still equals the data distribution ($p_D = p_{\text{data}} \neq p_G$).
We use an MCMC independence sampler (Tierney, 1994) to sample from $p_D$ by taking multiple samples from $G$. Amazingly, using the algorithm we present, one can show that given a perfect discriminator $D$ and a decent (but imperfect) generator $G$, one can obtain exact samples from the true data distribution $p_{\text{data}}$. Standard MCMC implementations require (unnormalized) densities for the target $p_D$ and the proposal $p_G$, neither of which is available for GANs. However, the Metropolis-Hastings (MH) algorithm requires only the ratio
$$\frac{p_D(x)}{p_G(x)} = \frac{D(x)}{1 - D(x)}, \tag{2}$$
which we can obtain using only evaluations of $D(x)$.
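Equation (2) is cheap to evaluate in practice. The following is a minimal NumPy sketch, assuming $D$ returns a probability in (0, 1), or equivalently a logit $l$ before a final sigmoid, so that $D/(1-D) = e^{l}$. The function names are hypothetical and not taken from the project's repository.

```python
import numpy as np


def density_ratio_from_prob(d):
    """Eq. (2): p_D(x) / p_G(x) = D(x) / (1 - D(x)) for D(x) in (0, 1)."""
    d = np.asarray(d, dtype=float)
    return d / (1.0 - d)


def density_ratio_from_logit(logit):
    """Same ratio from the pre-sigmoid logit l: if D = sigmoid(l), then D / (1 - D) = exp(l).
    Working with logits avoids dividing by 1 - D when D is very close to 1."""
    return np.exp(np.asarray(logit, dtype=float))


# D(x) = 0.8 means x is judged about four times more likely under p_D than under p_G.
ratio = density_ratio_from_prob(0.8)
```

The logit form is usually the safer choice numerically, since a saturated sigmoid makes $1 - D(x)$ round to zero.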
A few other works combine GANs and MCMC in some way. Song et al. (2017) use a GAN-like procedure to train a RealNVP (Dinh et al., 2016) MCMC proposal for sampling an externally provided target distribution. Whereas Song et al. (2017) use GANs to accelerate MCMC, we use MCMC to enhance the samples from a GAN. Similar to Song et al. (2017), Kempinska and Shawe-Taylor (2017) improve proposals, but in particle filters rather than MCMC. Song et al. (2017) was recently generalized by Neklyudov et al. (2018).
The GAN approach to density estimation is complementary to the earlier density ratio estimation (DRE) approach (Sugiyama et al., 2012). In DRE the generator $G$ is fixed, and the density is found by combining Bayes’ rule and the learned classifier $D$. In GANs, the key is learning $G$ well; in DRE, the key is learning $D$ well. The MH-GAN has flavors of both in that it uses both $D$ and $G$ to build $G'$.
A very similar concurrent work from Azadi et al. (2018) proposes discriminator rejection sampling (DRS) for GANs, which performs rejection sampling of outputs of $G$ by using the probabilities given by $D$. At first sight this approach is conceptually simpler than ours, but it suffers from two major shortcomings in practice. First, it is necessary to find an upper bound on $D$ over all possible samples in order to obtain a valid proposal distribution for rejection sampling. Because this is not possible, one must instead rely on estimating this bound by drawing many pilot samples. Second, even if one were to find a good bound, the acceptance rate would become very low due to the high dimensionality of the sampling space. This leads Azadi et al. (2018) to use an extra heuristic to shift the logit scores, making the model sample from a distribution different from $p_{\text{data}}$ even when $D$ is perfect. We use MCMC instead, which was invented precisely as a replacement for rejection sampling in higher dimensions. We further improve the robustness of MCMC by calibrating the discriminator, which yields more accurate probabilities for computing acceptance probabilities.

In this section, we briefly review the notation and equations for MCMC and GANs.
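To make the first shortcoming concrete, here is a sketch of plain rejection sampling with a bound estimated from pilot samples, in a 1-D toy setting where it works well. The toy densities, sample sizes, and function names are assumptions for illustration; this is not Azadi et al.'s implementation and it omits their logit shift.

```python
import numpy as np


def normal_pdf(x, std):
    """Density of N(0, std^2) evaluated at x."""
    return np.exp(-0.5 * (x / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def rejection_sample_pilot_bound(generate, disc_prob, n_pilot, n_candidates, rng):
    """Rejection sampling of G's outputs with acceptance prob (p_D / p_G) / M_hat.

    M_hat, the bound on p_D / p_G = D / (1 - D), is estimated from pilot samples,
    so it can underestimate the true supremum; any candidate whose ratio exceeds
    M_hat is then simply accepted (the probability is clipped at 1).
    """
    d_pilot = disc_prob(generate(n_pilot, rng))
    m_hat = np.max(d_pilot / (1.0 - d_pilot))
    x = generate(n_candidates, rng)
    d = disc_prob(x)
    accept_prob = np.minimum(1.0, (d / (1.0 - d)) / m_hat)
    accepted = rng.uniform(size=n_candidates) < accept_prob
    return x[accepted], accepted.mean()


# Toy 1-D setting: p_data = N(0, 1), imperfect generator G = N(0, 1.5^2),
# and the optimal discriminator D = p_data / (p_data + p_G) from eq. (1).
def generate(n, rng):
    return rng.normal(0.0, 1.5, size=n)


def disc_prob(x):
    p_data, p_g = normal_pdf(x, 1.0), normal_pdf(x, 1.5)
    return p_data / (p_data + p_g)


rng = np.random.default_rng(0)
samples, acceptance_rate = rejection_sample_pilot_bound(generate, disc_prob, 10_000, 200_000, rng)
```

Here the true bound is 1.5 and the pilot estimate finds it easily, so the expected acceptance rate is about 1/1.5. In high-dimensional image space, both the bound estimate and the acceptance rate degrade, which is the second shortcoming described above.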
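MCMC methods attempt to draw a chain of samples $x_1, \dots, x_K$ that marginally come from a target distribution $p_t$. We refer to the initial distribution as $p_0$ and the proposal for the independence sampler as $q(x' \mid x_k) = q(x')$. The proposal $x' \sim q$ is accepted with probability
$$\alpha(x', x_k) = \min\left(1, \frac{p_t(x')\, q(x_k)}{p_t(x_k)\, q(x')}\right). \tag{3}$$
If $x'$ is accepted, $x_{k+1} = x'$; otherwise $x_{k+1} = x_k$. Note that when estimating the distribution $p_t$, one must include the duplicate samples that result from rejections.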
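The acceptance rule in (3) can be sketched as a generic independence sampler. This is a minimal illustration, not the MH-GAN code: `propose`, `log_target`, and `log_proposal` are hypothetical callables, and both log densities may be unnormalized because their constants cancel in the ratio.

```python
import numpy as np


def independence_mh_chain(x0, propose, log_target, log_proposal, n_steps, rng):
    """Run an independence Metropolis-Hastings chain using the acceptance rule of eq. (3).

    propose(rng) draws x' ~ q independently of the current state; log_target and
    log_proposal may be unnormalized because the normalizing constants cancel.
    """
    x = x0
    chain = [x]
    for _ in range(n_steps):
        x_new = propose(rng)
        # log alpha = log p_t(x') + log q(x_k) - log p_t(x_k) - log q(x')
        log_alpha = (log_target(x_new) + log_proposal(x)
                     - log_target(x) - log_proposal(x_new))
        if np.log(rng.uniform()) < min(0.0, log_alpha):
            x = x_new  # accept: x_{k+1} = x'
        chain.append(x)  # on rejection, x_{k+1} = x_k is recorded again (a duplicate)
    return np.array(chain)


# Example: target N(0, 1), proposal N(0, 2^2), both as unnormalized log densities.
rng = np.random.default_rng(1)
chain = independence_mh_chain(
    x0=0.0,
    propose=lambda r: r.normal(0.0, 2.0),
    log_target=lambda x: -0.5 * x**2,
    log_proposal=lambda x: -0.5 * (x / 2.0) ** 2,
    n_steps=50_000,
    rng=rng,
)
```

Keeping the duplicates in `chain` is what makes its empirical distribution match the target, as noted above.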
Each chain samples $x_0 \sim p_0$ and then performs $K$ MH iterations to get $x_K$ as the output of the chain; in this case it is also the output of $G'$. Therefore, each MCMC chain yields a single sample $x_K$, and independent chains are used to draw multiple samples from $G'$.
The detailed balance condition implies that if $x_k \sim p_t$ exactly, then $x_{k+1} \sim p_t$ exactly as well. Additionally, even if $x_k$ is not exactly distributed according to $p_t$, the Kullback-Leibler (KL) divergence between the implied density it is drawn from and $p_t$ will always decrease as $k$ increases (Murray and Salakhutdinov, 2008; Cover and Thomas, 2012).
The curious reader may wonder why we do not simply stop the chain after the first accept. In general, allowing the chain length to be conditioned in any way on the state of the chain (including which samples were accepted or rejected) has the potential to introduce bias to the samples (Cowles et al., 1999).
GANs implicitly model the data $x$ via a synthetic data generator $G$ applied to latent noise $z$ drawn from a fixed prior $p(z)$:
$$x = G(z), \quad z \sim p(z). \tag{4}$$
This implies an (intractable) distribution on the data, $x \sim p_G$. We refer to the unknown true distribution on the data $x$ as $p_{\text{data}}$. The discriminator $D(x) \in [0, 1]$ is a soft classifier predicting whether a data point is real as opposed to being sampled from $p_G$. If $D$ converges optimally for a fixed $G$, then $D = p_{\text{data}} / (p_{\text{data}} + p_G)$, and if both $D$ and $G$ converge, then $p_G = p_{\text{data}}$ (Goodfellow et al., 2014). GAN training forms a game between $D$ and $G$. In practice, $D$ is often better at estimating the density ratio than $G$ is at generating high-fidelity samples (Shibuya, 2017). This motivates wrapping an imperfect $G$ to obtain an improved $G'$ by using the density ratio information encapsulated in $D$.
In this section we show how to sample from the distribution $p_D$ implied by the discriminator $D$. We apply (2) and (3) with target $p_t = p_D$ and proposal $q = p_G$:
$$\alpha(x', x_k) = \min\left(1, \frac{D(x_k)^{-1} - 1}{D(x')^{-1} - 1}\right). \tag{5}$$
The ratio is computed entirely from the discriminator scores $D$. If $D$ is perfect, $p_D = p_{\text{data}}$, so the sampler will marginally sample from $p_{\text{data}}$. A toy one-dimensional example with just such a perfect discriminator is shown in Figure 2.
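As a check that (5) really is (3) with the densities eliminated, the sketch below compares both forms for toy 1-D densities and a perfect discriminator. The Gaussian densities and all function names are illustrative assumptions; the logit variant assumes $D$ ends in a sigmoid.

```python
import numpy as np


def mh_gan_accept_prob(d_current, d_proposal):
    """Eq. (5): alpha = min(1, (1/D(x_k) - 1) / (1/D(x') - 1)), from discriminator scores only."""
    return np.minimum(1.0, (1.0 / d_current - 1.0) / (1.0 / d_proposal - 1.0))


def mh_gan_accept_prob_from_logits(logit_current, logit_proposal):
    """Same rule if D = sigmoid(logit): 1/D - 1 = exp(-logit), so alpha = min(1, exp(l' - l_k))."""
    return np.exp(np.minimum(0.0, logit_proposal - logit_current))


def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def p_d(x):  # toy target density p_D = N(0, 1)
    return normal_pdf(x, 0.0, 1.0)


def p_g(x):  # toy generator density p_G = N(0.5, 1.2^2)
    return normal_pdf(x, 0.5, 1.2)


def disc(x):  # perfect discriminator, eq. (1)
    return p_d(x) / (p_d(x) + p_g(x))


x_k = np.array([0.3, -1.0, 2.0])    # current states
x_new = np.array([-0.4, 1.5, 0.1])  # proposals x'
alpha_scores = mh_gan_accept_prob(disc(x_k), disc(x_new))
# Eq. (3) with target p_D and proposal p_G, using the densities directly.
alpha_densities = np.minimum(1.0, p_d(x_new) * p_g(x_k) / (p_d(x_k) * p_g(x_new)))
```

The check works because $D^{-1} - 1 = p_G / p_D$, so the ratio in (5) is exactly the density ratio of (3).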
The probabilities output by $D$ must not merely provide a good AUC score, but must also be well calibrated. In other words, a warped version of the probabilities of the perfect discriminator in (1) may still suffice for standard GAN training, but it will not work in the MCMC procedure defined in (5), because it yields erroneous density ratio values.
To calibrate $D$, we use a held-out calibration set (10% of the training data) and either logistic, isotonic, or beta (Kull et al., 2017) regression to warp the output of $D$. Furthermore, calibration is required for WGAN, because it does not learn a typical probabilistic GAN discriminator.
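Calibration here is a one-dimensional regression from raw discriminator outputs to probabilities. The sketch below shows the logistic variant (Platt-style scaling of the discriminator logit), fit by Newton's method in plain NumPy. It assumes a held-out set of raw logits with labels 1 for real and 0 for generated samples; the simulated miscalibration is made up for illustration. For WGAN, the same fit could take the critic's unbounded score in place of a logit (an assumption, not a detail from the text). For the isotonic variant, scikit-learn's `IsotonicRegression` is one off-the-shelf option.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def fit_logistic_calibrator(raw_logits, labels, n_iter=50):
    """Fit D_cal(x) = sigmoid(a * raw_logit(x) + b) on a held-out calibration set
    by Newton's method on the logistic log loss. Returns w = [a, b]."""
    X = np.column_stack([raw_logits, np.ones_like(raw_logits)])
    w = np.zeros(2)
    for _ in range(n_iter):
        p = sigmoid(X @ w)
        grad = X.T @ (p - labels)
        hess = (X * (p * (1.0 - p))[:, None]).T @ X + 1e-9 * np.eye(2)
        w -= np.linalg.solve(hess, grad)
    return w


def apply_calibrator(w, raw_logits):
    return sigmoid(w[0] * raw_logits + w[1])


# Simulated held-out set: labels (1 = real, 0 = generated) follow the "true" log-odds t,
# while the raw discriminator reports an overconfident, shifted logit 3 * t + 1.
rng = np.random.default_rng(0)
true_logit = rng.normal(0.0, 2.0, size=50_000)
labels = (rng.uniform(size=true_logit.size) < sigmoid(true_logit)).astype(float)
raw_logit = 3.0 * true_logit + 1.0

w = fit_logistic_calibrator(raw_logit, labels)
calibrated = apply_calibrator(w, raw_logit)
```

The fitted map should roughly undo the distortion ($a \approx 1/3$, $b \approx -1/3$). A monotone warp like this leaves the ranking, and hence the AUC, unchanged while fixing the probabilities that (5) depends on.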
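We also detect miscalibration of $D$ using the $Z$ statistic of Dawid (1997) on held-out samples $x_{1:N}$ and real/fake labels $y_{1:N} \in \{0, 1\}$. If $D$ is well calibrated, i.e., $y$ is indistinguishable from a draw $y \sim \text{Bernoulli}(D(x))$, then
$$Z = \frac{\sum_{i=1}^{N} \left(y_i - D(x_i)\right)}{\sqrt{\sum_{i=1}^{N} D(x_i)\left(1 - D(x_i)\right)}} \sim \mathcal{N}(0, 1). \tag{6}$$
This means that for large values of $|Z|$ (far in the tails of $\mathcal{N}(0, 1)$), we reject the hypothesis that $D$ is well calibrated.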
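The statistic in (6) takes one line to compute. A minimal sketch, with simulated labels standing in for a held-out real/fake set (the simulation and the function name are assumptions for illustration):

```python
import numpy as np


def dawid_z(probs, labels):
    """Z statistic of eq. (6): sum of residuals y_i - D(x_i), scaled by the Bernoulli
    standard deviation. Approximately N(0, 1) if labels ~ Bernoulli(probs)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return np.sum(labels - probs) / np.sqrt(np.sum(probs * (1.0 - probs)))


rng = np.random.default_rng(0)
p_true = rng.uniform(0.05, 0.95, size=10_000)
y = (rng.uniform(size=p_true.size) < p_true).astype(float)

z_calibrated = dawid_z(p_true, y)                          # the probabilities that generated y
z_shifted = dawid_z(np.clip(p_true + 0.1, 0.01, 0.99), y)  # systematically too high
```

A systematic bias of 0.1 over 10,000 samples pushes $Z$ far into the tail, whereas the calibrated probabilities give a value of ordinary $\mathcal{N}(0,1)$ size.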
We also avoid the burn-in issues that usually plague MCMC methods. Recall that via the detailed balance property (Gilks et al., 1996, Ch. 1), if the marginal distribution of a Markov chain state $x_k$ at time step $k$ matches the target $p_t$ ($x_k \sim p_t$), then the marginal at time step $k+1$ will also follow $p_t$ ($x_{k+1} \sim p_t$). In most MCMC applications it is not possible to get an initial sample from the target distribution ($x_0 \sim p_t$). However, in MH-GANs we are sampling from the data distribution, so we may use a trick: initialize the chain at a sample of real data (the correct distribution) to avoid burn-in. If no generated sample is accepted by the end of the chain, restart sampling from a synthetic sample to ensure the initial real sample is never output. To make these restarts rare, we set $K$ large (often 640).
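Combining (5) with the real-data initialization and the restart rule gives the sketch below. It is a simplified, unbatched illustration under stated assumptions: `generate` and `disc_prob` are hypothetical stand-ins for the GAN's generator and (calibrated) discriminator, and the toy densities, 1,000 chains, and $K = 40$ are chosen only so the example runs quickly.

```python
import numpy as np


def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def run_chain(x0, generate, disc_prob, k_steps, rng):
    """K independence-sampler steps with proposals from G and the eq. (5) acceptance rule.
    Returns the final state and whether any generated proposal was accepted."""
    x, d_x, accepted_any = x0, disc_prob(x0), False
    for _ in range(k_steps):
        x_new = generate(rng)
        d_new = disc_prob(x_new)
        alpha = min(1.0, (1.0 / d_x - 1.0) / (1.0 / d_new - 1.0))
        if rng.uniform() < alpha:
            x, d_x, accepted_any = x_new, d_new, True
    return x, accepted_any


def mh_gan_sample(real_x0, generate, disc_prob, k_steps, rng):
    """One output of G': start at a real sample to avoid burn-in; if nothing generated
    was accepted, restart from a generated sample so the real point is never returned."""
    x, accepted_any = run_chain(real_x0, generate, disc_prob, k_steps, rng)
    if not accepted_any:
        x, _ = run_chain(generate(rng), generate, disc_prob, k_steps, rng)
    return x


# Toy 1-D check with a perfect discriminator: p_data = N(0, 1), imperfect G = N(0.5, 1.5^2).
def generate(rng):
    return rng.normal(0.5, 1.5)


def disc_prob(x):
    p_data, p_g = normal_pdf(x, 0.0, 1.0), normal_pdf(x, 0.5, 1.5)
    return p_data / (p_data + p_g)


rng = np.random.default_rng(0)
samples = np.array([
    mh_gan_sample(rng.normal(0.0, 1.0), generate, disc_prob, k_steps=40, rng=rng)
    for _ in range(1_000)
])
```

The raw generator has mean 0.5 and standard deviation 1.5; the wrapped samples should instead look like draws from $\mathcal{N}(0, 1)$.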
The assumption of a perfect $D$ may be weakened for two reasons: (A) Because we recalibrate the discriminator, the actual probabilities can be incorrect as long as the decision boundary between real and fake is correct. (B) Because the discriminator is only ever evaluated at samples from $G$ or at the initial real sample $x_0$, $D$ only needs to be accurate on the manifold of samples from the generator $p_G$ and the real data $p_{\text{data}}$.
We first show an illustrative synthetic mixture model example followed by real data with images.
We consider the grid of two-dimensional Gaussians used in Azadi et al. (2018), which has become a popular toy example in the GAN literature (Dumoulin et al., 2016). The means of the 25 mixture components are arranged on a 5 × 5 grid, and all components share a common standard deviation $\sigma$. Following Azadi et al. (2018), we use four fully connected layers with ReLU activations for both the generator and the discriminator. The final output layer of the discriminator is a sigmoid, and no nonlinearity is applied to the final generator layer. The latent vector $z$ and the observed vector $x$ both have dimension two. All hidden layers have size 100. We used … training points and generated … points in test. The training data was standardized before training.

In Figure 3, we show the original data along with samples generated by the GAN. We also show samples enhanced via the MH-GAN (with calibration) and with DRS. The standard GAN creates spurious links along the grid lines between modes. It also misses some modes along the bottom row. DRS is able to “clean up” some of the spurious links but not to fill in the missing modes. The MH-GAN recovers these under-estimated modes and further “cleans up” the spurious links.
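A sketch of the toy data pipeline described above: an equally weighted 25-component Gaussian mixture followed by standardization. The grid coordinates {−2, −1, 0, 1, 2}, σ = 0.05, and the sample count are illustrative assumptions, not values stated in the text, and `sample_grid_mixture` is a hypothetical helper.

```python
import numpy as np


def sample_grid_mixture(n, rng, coords=(-2.0, -1.0, 0.0, 1.0, 2.0), sigma=0.05):
    """Draw n points from an equally weighted mixture of 2-D isotropic Gaussians whose
    means lie on a len(coords) x len(coords) grid. Returns the samples and the means."""
    means = np.array([(a, b) for a in coords for b in coords], dtype=float)
    component = rng.integers(len(means), size=n)  # uniform choice of mixture component
    return means[component] + sigma * rng.standard_normal((n, 2)), means


rng = np.random.default_rng(0)
x_train, means = sample_grid_mixture(10_000, rng)

# Standardize each coordinate using the training-set mean and standard deviation.
shift, scale = x_train.mean(axis=0), x_train.std(axis=0)
x_train_std = (x_train - shift) / scale
```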
These results are made more quantitative in Figure 4, where we follow some of the metrics for this example from Azadi et al. (2018). We consider the standard deviations within each mode in Figure 4(a) and the rate of “high quality” samples in Figure 4(b). A sample is assigned to a mode if its distance to that mode's mean is within four standard deviations ($4\sigma$). Samples within four standard deviations of any mixture component are considered “high quality”. The within-mode standard deviation plot (Figure 4(a)) shows a slight improvement for the MH-GAN, and the high-quality sample rate (Figure 4(b)) approaches 100% faster for the MH-GAN than for the GAN or DRS.
To test the spread of the distribution, we inspect the categorical distribution of the closest mode. Far-away (non-high-quality) samples are assigned to a 26th, “unassigned” category. For a perfect generator, this categorical distribution should be uniform over the 25 real modes. To assess generator quality, we look at the Jensen-Shannon divergence (JSD) between the sample mode distribution and a uniform distribution. This is a much more stringent test of the appropriate spread of probability mass than checking whether a single sample is produced near a mode (as was done in Azadi et al. (2018)).

[Figure caption: … “MH-GAN” denotes using the raw discriminator scores and “MH-GAN (cal)” the calibrated scores. The error bars on MH-GAN performance (in gray) are computed using a t-test on the variation per batch across 80 splits of the Inception score. In the center we show the Inception score vs. MCMC iteration $k$ for the GAN at epoch 15. On the right, we show the scores at epoch 13, where there is some overlap between the scores of fake and real images. When there is overlap, the MH-GAN corrects the distribution to have scores looking similar to the real data. DRS fails to fully shift the distribution because 1) it does not use calibration and 2) its logit-shift setup violates the validity of rejection sampling.]

In Figure 4(c), we see that the MH-GAN improves the JSD over DRS on average, meaning it achieves a much more balanced spread across modes. DRS fails to make gains after epoch 30. Using the principled approach of the MH-GAN along with calibrated probabilities ensures a correct spread of probability mass.
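The mode-based metrics described above can be sketched as follows. Assumptions: Euclidean distance in data space, natural logarithms in the JSD, and an illustrative 5 × 5 grid with σ = 0.05; `mode_metrics` is a hypothetical helper, not code from the paper's repository.

```python
import numpy as np


def mode_metrics(samples, means, sigma):
    """Return (high-quality rate, JSD) for 2-D samples against the mixture means."""
    dists = np.linalg.norm(samples[:, None, :] - means[None, :, :], axis=-1)  # (n, n_modes)
    closest = dists.argmin(axis=1)
    high_quality = dists.min(axis=1) <= 4.0 * sigma      # within four standard deviations
    n_modes = len(means)
    category = np.where(high_quality, closest, n_modes)   # index n_modes is the extra "far away" category
    p = np.bincount(category, minlength=n_modes + 1) / len(samples)
    q = np.append(np.full(n_modes, 1.0 / n_modes), 0.0)   # ideal: uniform over the real modes
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a = 0 contribute 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return high_quality.mean(), 0.5 * kl(p, m) + 0.5 * kl(q, m)


coords = np.arange(-2.0, 3.0)  # assumed 5 x 5 grid
means = np.array([(a, b) for a in coords for b in coords])
sigma = 0.05  # assumed standard deviation
rng = np.random.default_rng(0)
n = 20_000
balanced = means[rng.integers(len(means), size=n)] + sigma * rng.standard_normal((n, 2))
collapsed = means[0] + sigma * rng.standard_normal((n, 2))  # every sample in a single mode

hq_balanced, jsd_balanced = mode_metrics(balanced, means, sigma)
hq_collapsed, jsd_collapsed = mode_metrics(collapsed, means, sigma)
```

Both toy sample sets are almost entirely “high quality”, yet only the balanced one has a JSD near zero. This is the sense in which the JSD is the more stringent test.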
For real data experiments we considered the CelebA (Liu et al., 2015) and CIFAR-10 (Torralba et al., 2008) data sets modeled using the DCGAN (Radford et al., 2015) and WGAN (Arjovsky et al., 2017; Gulrajani et al., 2017). To evaluate the generator $G'$, we plot Inception scores (Salimans et al., 2016) per epoch in Figure 5(a) after $K$ MCMC iterations. The actual performance boost realized by the MH-GAN oscillates from one epoch to the next, perhaps due to fluctuations in the density ratio estimation performance per epoch. Accordingly, the statistical significance of the boost from $G$ to $G'$ with calibration varies from no significant change to a significant boost. Figure 5(b) shows the Inception score per MCMC iteration: most gains are made in the early iterations, but smaller gains continue in later iterations.
Table 1: Inception scores at epoch 60 (uncertainty in the last digits in parentheses) and p-values relative to the GAN baseline.

| Method | DCGAN CIFAR-10 | p | DCGAN CelebA | p | WGAN CIFAR-10 | p | WGAN CelebA | p |
|---|---|---|---|---|---|---|---|---|
| GAN | 2.8789 | – | 2.3317 | – | 3.0734 | – | 2.7876 | – |
| DRS | 2.977(77) | 0.0131 | 2.511(50) | <0.0001 | – | – | – | – |
| DRS (cal) | 3.073(80) | <0.0001 | 2.869(67) | <0.0001 | 3.137(64) | 0.0497 | 2.861(66) | 0.0277 |
| MH-GAN | 3.113(69) | <0.0001 | 2.682(50) | <0.0001 | – | – | – | – |
| MH-GAN (cal) | 3.379(66) | <0.0001 | 3.106(64) | <0.0001 | 3.305(83) | <0.0001 | 2.889(89) | 0.0266 |
In Table 1, we summarize performance (Inception score) across all experiments, running MCMC with $K$ iterations in all cases. Behavior is qualitatively similar to that in Figure 5(a). While DRS improves on the plain GAN, the MH-GAN improves the Inception score more in every case. Calibration helps in every case. There was no substantial difference between the calibration methods, but we found a slight advantage for isotonic regression. Results are computed at epoch 60, and, as in Figure 5(a), error bars and p-values are computed using a paired t-test across Inception score batches. All results are significantly better than the baseline GAN at $p < 0.05$.
In Figure 5(c), we visualize what MCMC does to the distribution of discriminator scores. MCMC shifts the distribution of scores for the fake images to match the distribution for the real images.
Figure 6 shows the calibration results per epoch for both CIFAR-10 and CelebA. It shows that the raw discriminator is highly miscalibrated, but that this can be fixed with any of the standard calibration methods. The $Z$ statistic for the raw discriminator (DCGAN on CIFAR-10) varies over a wide range during the first 60 epochs, far outside the range we would expect with 95% confidence for a calibrated classifier, even after Bonferroni correction. The $Z$ statistic for the calibrated discriminator stays within that range, which shows it is almost perfectly calibrated. Accordingly, it is unsurprising that the calibrated discriminator significantly boosts performance in the MH-GAN.
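For reference, a Bonferroni-corrected two-sided threshold on $|Z|$ can be computed from the standard normal quantile function in Python's standard library. Applying the correction as one test per epoch over 60 epochs at a family-wise level of 0.05 is an assumption about the setup, not a detail given above.

```python
from statistics import NormalDist


def bonferroni_z_threshold(alpha=0.05, n_tests=60):
    """Two-sided |Z| cutoff such that n_tests N(0, 1) statistics all stay below it
    with probability at least 1 - alpha (Bonferroni correction)."""
    return NormalDist().inv_cdf(1.0 - alpha / (2 * n_tests))


threshold = bonferroni_z_threshold()  # one test per epoch, 60 epochs (assumed)
single_test = bonferroni_z_threshold(n_tests=1)
```

With a single test this reduces to the familiar 1.96 cutoff; correcting for 60 tests raises it to roughly 3.3.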
We have shown how to incorporate the knowledge in the discriminator $D$ into an improved generator $G'$. Our method is based on the premise that $D$ is better at density ratio estimation than $G$ is at sampling data, which is inherently a harder task. The principled MCMC setup selects among samples from $G$ to correct biases in $G$. This is the only method in the literature which has the property that given a perfect $D$ one can recover $G'$ such that $p_{G'} = p_{\text{data}}$. We have shown that the raw discriminators in GANs and DRS are poorly calibrated. To our knowledge, this is the first work to evaluate the discriminator in this way and to rigorously show the poor calibration of the discriminator. The MH-GAN has great potential for extension.
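Arjovsky, M., Chintala, S., and Bottou, L. Wasserstein generative adversarial networks. In Proceedings of the International Conference on Machine Learning, volume 70, pages 214–223, 2017.
Kempinska, K. and Shawe-Taylor, J. Adversarial sequential Monte Carlo. In Bayesian Deep Learning (NIPS Workshop), 2017.
Kull, M., Silva Filho, T., and Flach, P. Beta calibration: a well-founded and easily implemented improvement on logistic calibration for binary classifiers. In Proceedings of the International Conference on Artificial Intelligence and Statistics, volume 54, pages 623–631, 2017.
Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pages 3730–3738, 2015.
Torralba, A., Fergus, R., and Freeman, W. T. 80 million tiny images: A large data set for nonparametric object and scene recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence, 30(11):1958–1970, 2008.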