GANsBeyondDivergenceMin
Code for replication of the paper "GANs beyond divergence minimization"
Generative adversarial networks (GANs) can be interpreted as an adversarial game between two players, a discriminator D and a generator G, in which D learns to classify real from fake data and G learns to generate realistic data by "fooling" D into thinking that fake data is actually real data. Currently, a dominant view is that G learns by minimizing a divergence, given that the general objective function is a divergence when D is optimal. However, this view has been challenged due to inconsistencies between theory and practice. In this paper, we discuss the properties associated with most loss functions for G (e.g., saturating/non-saturating f-GAN, LSGAN, WGAN, etc.). We show that these loss functions are not divergences and do not have the same equilibrium as expected of divergences. This suggests that G does not need to minimize the same objective function as D maximizes, nor maximize the objective of D after swapping real data with fake data (non-saturating GAN), but can instead use a wide range of possible loss functions to learn to generate realistic data. We define GANs through two separate and independent steps: D maximization and G minimization. We generalize the generator step to four new classes of loss functions, most of which are actual divergences (while traditional G loss functions are not). We test a wide variety of loss functions from these four classes on a synthetic dataset and on CIFAR-10. We observe that most loss functions converge well and provide data generation quality comparable to the non-saturating GAN, LSGAN, and WGAN-GP generator loss functions, whether we use divergences or non-divergences. These results suggest that GANs do not conform well to the divergence minimization theory and form a much broader range of models than previously assumed.
Generative adversarial networks (GANs) form a class of generative models best known for generating state-of-the-art photo-realistic images (Zhang et al., 2017) (Karras et al., 2017). Note that we refer to the original version of GAN (Goodfellow et al., 2014) as “standard GAN” and to all variants of generative adversarial networks that work in a similar fashion as “GANs”.
GANs consist of training two neural networks, a discriminator $D$ and a generator $G$, that work in competition so that $G$ can learn to generate fake data that appears to be genuine. $D$ is trained to differentiate real from fake data, which is done by classifying real from fake data ($D(x_r) = 1$ and $D(x_f) = 0$) or, more generally, by maximizing the expectation of $f_1(D(x_r))$ and $f_2(D(x_f))$, where $f_1$ is generally monotone increasing and $f_2$ is generally monotone decreasing. $G$ takes as input a random vector $z$, generally drawn from a multivariate normal distribution centered at 0 with variance 1, and outputs randomly generated fake data.
GANs are generally interpreted from two differing points of view: (1) an adversarial game and (2) divergence minimization. In the former, $G$ is trained to maximize the same objective function as $D$ but with real data swapped with fake data; intuitively, $G$ tries to fool $D$ into thinking that fake data is real data. In the latter, $G$ is trained to minimize the same objective function that was previously maximized by $D$. Given that the loss of $D$ is generally an approximation of a divergence (if $D$ is optimal, it is exactly equal to the divergence), $G$ is assumed to be minimizing a divergence.
We start by presenting these two views in detail. Then, we explain why minimizing the loss of $G$ cannot generally be interpreted as minimizing a divergence. We present four general forms of loss functions that can be used to train the generator, some of which are shown to be divergences. Through experiments, we show that most of these loss functions converge well. Finally, we discuss the implications of these results.
The main contributions of this paper are showing that:
In most GANs, the loss of $G$ is not a divergence, even when $D$ is optimal
$G$ does not directly minimize the divergence assumed by the objective function of $D$
The loss of $G$ does not need to match the objective of $D$ and can instead consist of:
matching the mean or individual discriminator output of the fake data to the real labels
matching the mean or individual discriminator output of the fake data to the classification threshold
mean matching the discriminator output of the fake data to the discriminator output of the real data
Using actual divergences as loss functions for $G$ does not provide any benefit
In the first interpretation, GANs are understood as an adversarial game (Goodfellow et al. (2014); Kodali et al. (2017); Heusel et al. (2017)) between $D$ and $G$ in which $D$ tries to classify which data is real or fake while $G$ tries to fool $D$ into thinking that the fake data it generates is actually real data. To fool $D$, the generator maximizes the objective function of $D$ after swapping real data with fake data. Goodfellow et al. (2014) showed that maximizing the objective of $D$ after swapping data (non-saturating GAN) works much better in practice than directly minimizing the objective of $D$ (saturating GAN). Non-saturating GAN can be represented mathematically as the following two steps:
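$$\begin{aligned}
&\max_D \; \mathbb{E}_{x_r \sim \mathbb{P}}\left[\log D(x_r)\right] + \mathbb{E}_{z \sim \mathbb{P}_z}\left[\log\left(1 - D(G(z))\right)\right] \\
&\max_G \; \mathbb{E}_{z \sim \mathbb{P}_z}\left[\log D(G(z))\right]
\end{aligned} \tag{1}$$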
where $\mathbb{P}$ is the distribution of the real data on domain $\mathcal{X}$ and $\mathcal{Z}$ is the domain of $z \sim \mathbb{P}_z$. Note that we generally denote the distribution of the fake data formed by $G(z)$ as $\mathbb{Q}$.
This can be generalized in the following manner:
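$$\begin{aligned}
&\max_D \; \mathbb{E}_{x_r \sim \mathbb{P}}\left[f_1(D(x_r))\right] + \mathbb{E}_{z \sim \mathbb{P}_z}\left[f_2(D(G(z)))\right] \\
&\max_G \; \mathbb{E}_{z \sim \mathbb{P}_z}\left[f_1(D(G(z)))\right]
\end{aligned} \tag{2}$$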
where $f_1$ and $f_2$ are scalar-to-scalar functions chosen so that $D$ is a discriminator that predicts the likelihood of the data being real; generally, $f_1$ is monotone increasing and $f_2$ is monotone decreasing.
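As a minimal sketch of the two steps of equation (2), the functions below estimate both objectives from mini-batches of discriminator outputs. The helper names (`discriminator_objective`, `generator_objective`) and the sample values are hypothetical, and D is assumed to already output probabilities in (0, 1); with $f_1(y) = \log y$ and $f_2(y) = \log(1 - y)$ the sketch reduces to the non-saturating GAN of equation (1).

```python
import math
from typing import Callable, Sequence

Scalar = Callable[[float], float]


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def discriminator_objective(d_real: Sequence[float], d_fake: Sequence[float],
                            f1: Scalar, f2: Scalar) -> float:
    """Mini-batch estimate of E_P[f1(D(x_r))] + E_z[f2(D(G(z)))], maximized by D."""
    return mean([f1(y) for y in d_real]) + mean([f2(y) for y in d_fake])


def generator_objective(d_fake: Sequence[float], f1: Scalar) -> float:
    """Mini-batch estimate of E_z[f1(D(G(z)))], maximized by G (real and fake swapped)."""
    return mean([f1(y) for y in d_fake])


def f1(y: float) -> float:
    """Standard GAN choice f1(y) = log y."""
    return math.log(y)


def f2(y: float) -> float:
    """Standard GAN choice f2(y) = log(1 - y)."""
    return math.log(1.0 - y)


d_real = [0.9, 0.8, 0.7]  # hypothetical D outputs on real samples
d_fake = [0.2, 0.3, 0.1]  # hypothetical D outputs on fake samples
print(discriminator_objective(d_real, d_fake, f1, f2))
print(generator_objective(d_fake, f1))
```

In practice, both quantities would be computed by an autodiff framework so that their gradients can update $D$ and $G$; this sketch only evaluates them.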
In the second interpretation, GANs are understood as divergence minimization (Nowozin et al. (2016); Arjovsky et al. (2017); Mroueh and Sercu (2017); Li et al. (2017); Bellemare et al. (2017); Mroueh et al. (2017)). Divergences are a weak form of distance between two probability distributions with the following two properties: they are non-negative, and they equal zero if and only if the distributions are equal. Some well-known divergences are the Kullback–Leibler divergence (KL), the Jensen–Shannon divergence (JSD), and the Wasserstein distance. To prevent confusion between the discriminator and the divergence, we denote divergences between two distributions $\mathbb{P}$ and $\mathbb{Q}$ as $\mathbb{D}(\mathbb{P}, \mathbb{Q})$. Some divergences, like JSD and Wasserstein, are symmetric, so that $\mathbb{D}(\mathbb{P}, \mathbb{Q}) = \mathbb{D}(\mathbb{Q}, \mathbb{P})$, but this is not the case in general.
From this perspective, one tries to find the generator (with parameters $\theta$) that minimizes a divergence between real and fake data:
$$\theta^* = \arg\min_\theta \; \mathbb{D}(\mathbb{P}, \mathbb{Q}_\theta)$$
However, commonly used divergences, such as $f$-divergences (a class of divergences of which KL and JSD are special cases), are difficult to minimize given that they require knowing the probability density functions of the real and fake data, $p$ and $q$ respectively. In practice, we do not know the probability distributions of the real data or the fake data.
Traditionally, one approximates $\mathbb{P}$ by the empirical distribution $\hat{\mathbb{P}}$, i.e., a discrete distribution where each of the $n$ data samples $x_1, \dots, x_n$ has probability $1/n$. In this case, it can be shown (Goodfellow, 2016) that minimizing the KL divergence is equivalent to maximizing the log-likelihood (one of the most popular approaches in machine learning and statistics):
$$\arg\min_\theta \; \mathrm{KL}(\hat{\mathbb{P}} \,\|\, \mathbb{Q}_\theta) = \arg\max_\theta \; \frac{1}{n} \sum_{i=1}^{n} \log q_\theta(x_i)$$
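The equivalence between minimizing $\mathrm{KL}(\hat{\mathbb{P}} \,\|\, \mathbb{Q}_\theta)$ and maximizing the average log-likelihood can be checked numerically on a toy problem. The sketch below assumes a hypothetical discrete model family (a binomial distribution with parameter $\theta$, chosen only for illustration); the two criteria differ only by the constant $\sum_x \hat{p}(x) \log \hat{p}(x)$, so they select the same $\theta$ on a grid.

```python
import math
from collections import Counter


def binomial_pmf(k: int, n_trials: int, theta: float) -> float:
    """Hypothetical model family q_theta: Binomial(n_trials, theta)."""
    return math.comb(n_trials, k) * theta**k * (1 - theta) ** (n_trials - k)


def empirical_distribution(samples):
    """Each of the n samples gets probability 1/n (repeated values accumulate)."""
    n = len(samples)
    return {x: c / n for x, c in Counter(samples).items()}


def kl_empirical_to_model(samples, n_trials, theta):
    """KL(P_hat || Q_theta) = sum_x p_hat(x) log(p_hat(x) / q_theta(x))."""
    p_hat = empirical_distribution(samples)
    return sum(p * math.log(p / binomial_pmf(x, n_trials, theta)) for x, p in p_hat.items())


def avg_log_likelihood(samples, n_trials, theta):
    """(1/n) sum_i log q_theta(x_i)."""
    return sum(math.log(binomial_pmf(x, n_trials, theta)) for x in samples) / len(samples)


samples = [0, 1, 1, 2, 2, 2, 3, 4]
n_trials = 4
grid = [i / 100 for i in range(1, 100)]
best_kl = min(grid, key=lambda t: kl_empirical_to_model(samples, n_trials, t))
best_ll = max(grid, key=lambda t: avg_log_likelihood(samples, n_trials, t))
print(best_kl, best_ll)  # 0.47 0.47
```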
In the divergence point of view, GANs try to minimize an equivalent “dual” parametrization of a divergence that does not require knowing the probability density functions. Common divergences can be represented with respect to $D$ rather than $p$ and $q$.
For saturating GAN, it can be shown (Goodfellow et al., 2014) that JSD is equal to an affine function of the minimum cross-entropy:
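$$\mathrm{JSD}(\mathbb{P}, \mathbb{Q}) = \log 2 + \frac{1}{2} \max_D \left\{ \mathbb{E}_{x_r \sim \mathbb{P}}\left[\log D(x_r)\right] + \mathbb{E}_{x_f \sim \mathbb{Q}}\left[\log\left(1 - D(x_f)\right)\right] \right\} \tag{3}$$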
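A quick numerical check of equation (3), assuming small discrete distributions with full support (the values of `p` and `q` are arbitrary): plugging in the optimal discriminator $D^*(x) = p(x)/(p(x) + q(x))$ recovers the JSD, while any other discriminator gives a smaller value of the bracketed objective.

```python
import math


def kl(p, q):
    """KL divergence between discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def gan_value(p, q, d):
    """E_P[log D(x)] + E_Q[log(1 - D(x))] for a discriminator given pointwise as d."""
    return sum(pi * math.log(di) + qi * math.log(1 - di) for pi, qi, di in zip(p, q, d))


p = [0.5, 0.3, 0.2]  # arbitrary real-data distribution
q = [0.1, 0.3, 0.6]  # arbitrary fake-data distribution
d_opt = [pi / (pi + qi) for pi, qi in zip(p, q)]  # optimal D*(x) = p / (p + q)

print(jsd(p, q))
print(math.log(2) + 0.5 * gan_value(p, q, d_opt))  # right-hand side of (3) with D = D*
```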
Nowozin et al. (2016) generalized this concept to the class of $f$-divergences and showed that:
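$$\mathbb{D}_f(\mathbb{P} \,\|\, \mathbb{Q}) = \sup_D \; \mathbb{E}_{x_r \sim \mathbb{P}}\left[D(x_r)\right] - \mathbb{E}_{x_f \sim \mathbb{Q}}\left[f^*(D(x_f))\right] \tag{4}$$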
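where $f^*(t) = \sup_u \{ut - f(u)\}$ is the convex conjugate of $f$, and $f$ is the convex function that defines the $f$-divergence $\mathbb{D}_f(\mathbb{P} \,\|\, \mathbb{Q}) = \int q(x) f\left(\frac{p(x)}{q(x)}\right) dx$ (e.g., $f(u) = u \log u$ leads to KL and $f(u) = u \log u - (u + 1)\log(u + 1)$ leads to saturating GAN).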
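For KL, $f(u) = u \log u$ has conjugate $f^*(t) = e^{t-1}$, and the supremum in equation (4) is attained at $D(x) = f'(p(x)/q(x)) = 1 + \log(p(x)/q(x))$. The sketch below, assuming discrete distributions with full support, checks that this choice recovers the KL divergence while other outputs only give a lower bound; the list `d_values` plays the role of $D(x)$ at each point of the support.

```python
import math


def kl(p, q):
    """KL divergence between discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_conjugate(t):
    """Convex conjugate of f(u) = u log u: f*(t) = exp(t - 1)."""
    return math.exp(t - 1)


def variational_objective(p, q, d_values):
    """E_P[D(x)] - E_Q[f*(D(x))], the quantity maximized over D in equation (4)."""
    return sum(pi * di - qi * kl_conjugate(di) for pi, qi, di in zip(p, q, d_values))


p = [0.5, 0.3, 0.2]
q = [0.1, 0.3, 0.6]
d_opt = [1 + math.log(pi / qi) for pi, qi in zip(p, q)]  # f'(p/q) with f'(u) = 1 + log u

print(kl(p, q), variational_objective(p, q, d_opt))
```

Shifting every output by a constant $\delta$ lowers the objective by $e^{\delta} - 1 - \delta \ge 0$, which is one way to see why only the optimal $D$ makes the bound tight.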
Thus, on a very general level, GANs can be formulated in the following manner:
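$$\min_G \max_D \; \mathbb{E}_{x_r \sim \mathbb{P}}\left[f_1(D(x_r))\right] + \mathbb{E}_{z \sim \mathbb{P}_z}\left[f_2(D(G(z)))\right] \tag{5}$$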
Note that, in practice, we let $D$ and $G$ be neural networks and we optimize their respective parameters $w$ and $\theta$.
Most GANs fit into this category; some examples are saturating GAN (Goodfellow et al., 2014), which minimizes the JSD; saturating $f$-GAN (Nowozin et al., 2016), which minimizes $f$-divergences; and Wasserstein GAN (WGAN) (Arjovsky et al., 2017), which minimizes the Wasserstein distance.
There are a few issues with this interpretation of GANs. Firstly, standard GAN and $f$-GANs actually converge better when solving the optimization problem of equation (2) rather than that of equation (5) (Nowozin et al., 2016). Secondly, non-saturating GANs are able to learn the distribution of the real data even when directly minimizing the JSD fails (Fedus et al., 2017) because the gradient of the divergence is actually constant or infinite (Arjovsky et al., 2017).
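With the exception of standard GAN and $f$-GANs, GANs that follow the divergence minimization interpretation are generally based on integral probability metrics (IPMs) (Müller, 1997):
$$\mathrm{IPM}_{\mathcal{F}}(\mathbb{P} \,\|\, \mathbb{Q}) = \sup_{D \in \mathcal{F}} \; \mathbb{E}_{x_r \sim \mathbb{P}}\left[D(x_r)\right] - \mathbb{E}_{x_f \sim \mathbb{Q}}\left[D(x_f)\right]$$
where $\mathcal{F}$ is a class of functions chosen to prevent the supremum from being infinite. See Mroueh et al. (2017) for a summary of the various IPMs used in the literature. Importantly, IPM-based GANs can still be understood as following equation (2), considering that, with $f_1(y) = y$ and $f_2(y) = -y$, minimizing $\mathbb{E}_{z \sim \mathbb{P}_z}[f_2(D(G(z)))]$ is the same as maximizing $\mathbb{E}_{z \sim \mathbb{P}_z}[f_1(D(G(z)))]$. Therefore, the fact that these GANs work without swapping real data with fake data does not necessarily disprove the adversarial game interpretation.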
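As a toy illustration of an IPM, assume a hypothetical function class $\mathcal{F} = \{x \mapsto w x : |w| \le 1\}$ on scalar data; the supremum then equals the absolute difference of the means. The sketch approximates it by a grid search over $w$ on empirical samples. IPM-based GANs instead use constrained neural networks, so this only illustrates the definition.

```python
def ipm_linear(real, fake, num_grid=201):
    """sup over D(x) = w * x with |w| <= 1 of E_P[D(x)] - E_Q[D(x)], via grid search on w.

    The linear class is a toy assumption; IPM-based GANs use constrained neural networks.
    """
    mean_real = sum(real) / len(real)
    mean_fake = sum(fake) / len(fake)
    weights = [-1 + 2 * i / (num_grid - 1) for i in range(num_grid)]
    return max(w * mean_real - w * mean_fake for w in weights)


real = [1.0, 2.0, 3.0]  # samples from P
fake = [0.0, 0.5, 1.0]  # samples from Q
print(ipm_linear(real, fake))  # |mean(real) - mean(fake)| = 1.5 for this class
```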
In the following section, we show what loss function $G$ actually minimizes in GANs, and we present four new types of loss functions, most of which are divergences, that $G$ can minimize instead of the non-saturating/saturating loss function generally assumed (equation 2 or 5).
If we concentrate entirely on the generator step, for the vast majority of GANs, the stochastic gradient descent (SGD) step can be formulated as:
$$\min_G \; \mathbb{E}_{x_f \sim \mathbb{Q}}\left[f(D(x_f))\right] \tag{6}$$
where $f$ is a scalar-to-scalar function (generally monotone decreasing). Given that this is the only part of the objective that is used by SGD, it is effectively the loss function minimized, rather than the divergence envisioned.
For saturating GAN, we have that $f(D(x_f)) = \log(1 - D(x_f))$. The loss function is not lower bounded, since $\log(1 - D(x_f)) \to -\infty$ as $D(x_f) \to 1$. This is problematic because there is no minimum: if $D(x_f) = 1$ for a single fake sample, the expectation equals $-\infty$ and the infimum is reached. On the other hand, for non-saturating GAN, the loss is bounded, since $f(D(x_f)) = -\log D(x_f) \geq 0$. Therefore, minimizing this loss, we have that $-\log D(x_f) \to 0$ as $D(x_f) \to 1$.
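The contrast between the two generator losses can be seen by evaluating them as $D(x_f)$ approaches 1. This is a minimal sketch with hypothetical helper names; the value $D(x_f) = 1$ itself is left out because `math.log(0.0)` raises an error in Python rather than returning $-\infty$.

```python
import math


def saturating_loss(d_fake: float) -> float:
    """Saturating GAN generator loss log(1 - D(x_f)); unbounded below as D(x_f) -> 1."""
    return math.log(1.0 - d_fake)


def non_saturating_loss(d_fake: float) -> float:
    """Non-saturating GAN generator loss -log D(x_f); bounded below by 0."""
    return -math.log(d_fake)


for y in [0.5, 0.9, 0.99, 0.999999]:
    print(f"D(x_f)={y}: saturating={saturating_loss(y):.3f}, "
          f"non-saturating={non_saturating_loss(y):.6f}")
```

Both losses keep decreasing as $D(x_f) \to 1$, but only the non-saturating one approaches a finite minimum (0).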
These observations can be generalized to a broader class of GANs with $f$-GANs (equation 4). For the saturating loss, $f(D(x_f)) = -f^*(D(x_f))$, where $f^*$ is a convex function (Nowozin et al., 2016). Given the convexity of $f^*$, the maximum of $f^*(D(x_f))$ (or equivalently the minimum of $-f^*(D(x_f))$) is reached when $D(x_f) \to -\infty$ or $D(x_f) \to \infty$. For most common divergences, $f^*$ is monotone increasing; out of all the $f$-divergences presented by Nowozin et al. (2016) (KL, reverse KL, Pearson/Neyman $\chi^2$, squared Hellinger, Jeffrey, JSD), only Pearson $\chi^2$ is not monotone increasing. Thus, in all common $f$-divergences except Pearson $\chi^2$, the optimum is reached when $D(x_f) \to \infty$ (or toward the upper boundary of the domain of $f^*$ when that domain is bounded above). The non-saturating loss is $f(D(x_f)) = -D(x_f)$; thus, it is also minimized when $D(x_f) \to \infty$. The loss function is lower bounded (has a minimum) if $D$ is upper bounded for the non-saturating loss function and if $f^*$ is upper bounded for the saturating loss function.
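The monotonicity claim can be spot-checked on a grid for a subset of the conjugates listed by Nowozin et al. (2016): KL, reverse KL, Pearson and Neyman $\chi^2$, squared Hellinger, and the GAN divergence (Jeffrey and JSD are omitted here). The grid ranges are assumptions chosen to stay inside each conjugate's domain, and a grid check is evidence, not a proof.

```python
import math

# name -> (f*(t), low, high): grid bounds chosen inside each conjugate's domain.
CONJUGATES = {
    "KL": (lambda t: math.exp(t - 1), -5.0, 5.0),
    "Reverse KL": (lambda t: -1 - math.log(-t), -5.0, -0.01),
    "Pearson chi^2": (lambda t: t + t * t / 4, -5.0, 5.0),
    "Neyman chi^2": (lambda t: 2 - 2 * math.sqrt(1 - t), -5.0, 0.99),
    "Squared Hellinger": (lambda t: t / (1 - t), -5.0, 0.99),
    "GAN": (lambda t: -math.log(1 - math.exp(t)), -5.0, -0.01),
}


def is_increasing_on_grid(fn, low, high, num_points=1000):
    """Check that fn is strictly increasing on an evenly spaced grid over [low, high]."""
    grid = [low + (high - low) * i / (num_points - 1) for i in range(num_points)]
    values = [fn(t) for t in grid]
    return all(b > a for a, b in zip(values, values[1:]))


for name, (fn, low, high) in CONJUGATES.items():
    print(f"{name}: increasing on grid = {is_increasing_on_grid(fn, low, high)}")
```

Pearson $\chi^2$ fails because $f^*(t) = t + t^2/4$ decreases for $t < -2$.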
In the commonly assumed parametrization of LSGAN (Mao et al., 2017), we have that $f(D(x_f)) = (D(x_f) - 1)^2$. Therefore, minimizing this loss, we have that $(D(x_f) - 1)^2 \to 0$ as $D(x_f) \to 1$, just as in saturating GAN.
WGAN-GP (Gulrajani et al., 2017) is a very popular variant of WGAN which imposes a constraint so that the magnitude of the gradient of $D$ at any point interpolated between real and fake data must be close to 1 (i.e., $\|\nabla_{\hat{x}} D(\hat{x})\|_2 \approx 1$, where $\hat{x} = \epsilon x_r + (1 - \epsilon) x_f$ and $\epsilon \sim U[0, 1]$). In WGAN-GP, we have that $f(D(x_f)) = -D(x_f)$. Although this loss is not lower bounded, the gradient penalty forces the gradient of $D$ to be around 1, so the step taken by SGD cannot be too large. This suggests that the gradient penalty may be useful because it limits how much $D(x_f)$ can change in one minimization step.
By definition, a zero divergence should arise when an optimal $D$ is not able to discriminate real from fake data, i.e., when $D(x) = \tau$ for all $x$, where $\tau$ is the classification threshold ($\tau = 1/2$ in standard GAN, $\tau = f'(1)$ in $f$-GAN), or when $\mathbb{E}_{x_r \sim \mathbb{P}}[D(x_r)] = \mathbb{E}_{x_f \sim \mathbb{Q}}[D(x_f)]$ in IPM-based GANs. However, as mentioned above, to reach the optimum, $G$ generally attempts to make $D(x_f) \to 1$ or $D(x_f) \to \infty$, thus pushing $D(x_f)$ as far away from $\tau$ as possible. This means that after training $G$, the discriminator cannot be optimal anymore. The estimated divergence can even become negative, which is impossible by the definition of a divergence. This is something that we observe in practice: the objective function of the discriminator becomes positive (or large) after training $D$ and negative (or very small) after training $G$, thus cycling between an approximated divergence and a non-divergence.
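The WGAN-GP constraint on interpolated points can be sketched as follows. To stay dependency-free, this sketch assumes a critic that maps a list of floats to a float and replaces automatic differentiation with central finite differences; in a real implementation the gradient would come from autodiff, the penalty would be averaged over a batch of interpolates, and it would be multiplied by the penalty weight (10 in the experiments below). The `one_sided` flag shows the one-sided variant used for the swiss roll experiments; all names are hypothetical.

```python
import random
from typing import Callable, List, Sequence

Critic = Callable[[Sequence[float]], float]


def numerical_gradient(critic: Critic, x: Sequence[float], h: float = 1e-5) -> List[float]:
    """Central finite-difference gradient (stand-in for autograd in this sketch)."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((critic(x_plus) - critic(x_minus)) / (2 * h))
    return grad


def gradient_penalty(critic: Critic, x_real: Sequence[float], x_fake: Sequence[float],
                     rng: random.Random, one_sided: bool = False) -> float:
    """Penalty on ||grad D(x_hat)||_2 at x_hat = eps * x_r + (1 - eps) * x_f, eps ~ U[0, 1]."""
    eps = rng.random()
    x_hat = [eps * r + (1 - eps) * f for r, f in zip(x_real, x_fake)]
    norm = sum(g * g for g in numerical_gradient(critic, x_hat)) ** 0.5
    if one_sided:
        return max(0.0, norm - 1.0) ** 2  # only penalize gradient norms above 1
    return (norm - 1.0) ** 2  # two-sided: push the norm toward 1


def steep(x: Sequence[float]) -> float:
    return 3.0 * x[0] + 4.0 * x[1]  # gradient norm 5


def gentle(x: Sequence[float]) -> float:
    return 0.3 * x[0] + 0.4 * x[1]  # gradient norm 0.5


rng = random.Random(0)
print(gradient_penalty(steep, [1.0, 2.0], [0.0, -1.0], rng))  # close to (5 - 1)^2 = 16
print(gradient_penalty(gentle, [1.0, 2.0], [0.0, -1.0], rng, one_sided=True))  # 0.0
```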
This means that in most currently used GANs ($f$-GAN and all GANs present in the large-scale study by Lucic et al. (2017)¹), the loss function of $G$ is not a divergence. This suggests that $G$ does not minimize the divergence directly, but only indirectly, by updating its weights so that $D(x_f) \to 1$ or $D(x_f) \to \infty$, which results in generating more realistic data.
¹With the exception of BEGAN (Berthelot et al., 2017), which is very unusual and cannot really be considered a GAN given that $D$ is an auto-encoder rather than a discriminator function.
One way to reconcile these observations with divergence minimization is to interpret GANs as effectively minimizing a divergence without the assumption of an optimal $D$ (overshooting the goal) and then re-estimating the divergence by training $D$ to optimality in the next steps. This only applies to GANs such as WGAN/WGAN-GP, which train $D$ for a large number of iterations before training $G$. From this point of view, GANs act comparably to projected gradient descent, i.e., we take a step in the gradient direction and then project back into the feasible set (with the constraint that we impose: an optimal $D$). However, the difference is that with constrained optimization we can make sure that the constraint is respected entirely, whereas with GANs we can never truly train $D$ to be optimal. Also, with constrained optimization, it is possible for the solution to still respect the constraint after an unconstrained minimization step, but with GANs, it is impossible for $D$ to be optimal after training $G$ for even a single step.
Training $D$ to optimality before training $G$, in an attempt to approximately minimize the divergence, does not necessarily lead to better results (Fedus et al., 2017). The current state-of-the-art in the generation of human faces (Karras et al., 2017) used WGAN-GP with only one discriminator update. Given that training $D$ multiple times before $G$ significantly increases the training time and does not necessarily lead to generated samples of better quality, trying to render GAN training analogous to divergence minimization may be unnecessary and may overly constrain research and practice to a small subset of all possible GANs.
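As shown above, the generator generally takes a step so that $D(x_f)$ reaches for 1 or $\infty$. For the loss of $G$ to be a divergence, one could think that we should instead have $D(x_f)$ reach for $\tau$ or $\mathbb{E}_{x_r \sim \mathbb{P}}[D(x_r)]$. We generalize this idea to four general forms of loss functions for $G$:
Discriminator Matching (DM):
$$\mathbb{E}_{x_r \sim \mathbb{P},\, x_f \sim \mathbb{Q}}\left[d\big(D(x_r), D(x_f)\big)\right] \tag{7}$$
Label Matching (LM):
$$\mathbb{E}_{x_f \sim \mathbb{Q}}\left[d\big(l, D(x_f)\big)\right] \tag{8}$$
Expectation Discriminator Matching (EDM):
$$d\big(\mathbb{E}_{x_r \sim \mathbb{P}}[D(x_r)],\, \mathbb{E}_{x_f \sim \mathbb{Q}}[D(x_f)]\big) \tag{9}$$
Expectation Label Matching (ELM):
$$d\big(l,\, \mathbb{E}_{x_f \sim \mathbb{Q}}[D(x_f)]\big) \tag{10}$$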
where $D$ is a discriminator trained in any way to differentiate real from fake data, $d$ is a distance function (ideally a metric), and $l$ is the target value: either the classification threshold $\tau$ or the label of real data (here 1). See Algorithm 1 for how to train GANs using these loss functions. If $l = \tau$, we are striving for equilibrium ($D(x_f) = \tau$). If $l = 1$, we are overshooting, as is usually done in GANs. Note that, although we presented 1 as the label for real data, many GANs do not have labels (e.g., $f$-GANs and IPM-based GANs). When there is no label, assuming that $D$ is trained using equation 2 or 5, the label of real data and $\tau$ must instead be defined in terms of $f_1$ and $f_2$. If $\tau$ exists, one can use LM or ELM with $l = \tau$. Also note that certain GANs do not have a $\tau$ (e.g., $D(x) = c$, for any constant $c$, leads to a Wasserstein distance of 0), so these approaches cannot use LM or ELM with $l = \tau$.
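A minimal sketch of mini-batch estimates of the four generator losses in equations (7) to (10), assuming the discriminator outputs have already been computed; in practice the fake outputs would be differentiable with respect to the parameters of $G$. Here DM averages over all real/fake pairs in the batch, which is one possible estimator of the expectation over $x_r$ and $x_f$. The function names and sample values are hypothetical.

```python
from typing import Callable, Sequence

Distance = Callable[[float, float], float]


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def dm_loss(d_real: Sequence[float], d_fake: Sequence[float], d: Distance) -> float:
    """Discriminator Matching: E_{x_r, x_f}[d(D(x_r), D(x_f))], averaged over all pairs."""
    return _mean([d(r, f) for r in d_real for f in d_fake])


def lm_loss(d_fake: Sequence[float], label: float, d: Distance) -> float:
    """Label Matching: E_{x_f}[d(l, D(x_f))]."""
    return _mean([d(label, f) for f in d_fake])


def edm_loss(d_real: Sequence[float], d_fake: Sequence[float], d: Distance) -> float:
    """Expectation Discriminator Matching: d(E[D(x_r)], E[D(x_f)])."""
    return d(_mean(d_real), _mean(d_fake))


def elm_loss(d_fake: Sequence[float], label: float, d: Distance) -> float:
    """Expectation Label Matching: d(l, E[D(x_f)])."""
    return d(label, _mean(d_fake))


def squared(a: float, b: float) -> float:
    return (a - b) ** 2


tau = 0.5  # classification threshold of standard GAN
d_real = [0.8, 0.6]  # hypothetical D outputs on real samples
d_fake = [0.2, 0.4]  # hypothetical D outputs on fake samples
print(dm_loss(d_real, d_fake, squared), lm_loss(d_fake, tau, squared),
      edm_loss(d_real, d_fake, squared), elm_loss(d_fake, tau, squared))
```

When every output equals $\tau$ (the equilibrium of an optimal discriminator with $\mathbb{P} = \mathbb{Q}$), all four losses are zero, so the generator receives no push away from equilibrium.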
It can be shown that if $d$ is a positive-definite function ($d(x, y) \geq 0$ and $d(x, y) = 0 \iff x = y$) and $D$ is optimal, all of these loss functions are divergences (see the Appendix for proof and details). DM and LM are more general, as they only require optimality at equilibrium ($D(x) = \tau \;\, \forall x \iff \mathbb{P} = \mathbb{Q}$), while EDM and ELM need the usual assumption of optimality and that $\mathbb{P}$ and $\mathbb{Q}$ have the same support.
Note that although these loss functions are divergences under the conditions mentioned above, the assumption of an optimal $D$ is still problematic, given that the discriminator will always lose optimality after minimizing any loss function of $G$. This is true for all GANs, because modifying $G$ cannot change $D$. The only exception is when $\mathbb{P} = \mathbb{Q}$: if the loss function of $G$ is a divergence, the loss will already be zero, so it will not change $G$. This is a neat theoretical property that traditional loss functions do not have, because they are not divergences: they push $D(x_f)$ toward imbalance rather than equilibrium.
We trained GANs on a synthetic dataset (the infinite swiss roll dataset (Marsland, 2015)) and on CIFAR-10 (Krizhevsky, 2009). All experiments were run in PyTorch (Paszke et al., 2017) using the Adam optimizer (Kingma and Ba, 2014) with hyperparameters $\beta_1$ and $\beta_2$.
As a first experiment, we trained GANs on the simple infinite swiss roll dataset from scikit-learn (Pedregosa et al., 2011) (see Appendix for more details). We used three different objective functions for $D$: cross-entropy (as in standard GAN), least squares (as in LSGAN), and the WGAN-GP objective function. As previously observed, the two-sided penalty of WGAN-GP works poorly on the swiss roll dataset (Viehmann), so we used a one-sided penalty, a variant that the authors of WGAN-GP found to lead to similar results in their own experiments (Gulrajani et al., 2017).
We experimented with a wide range of loss functions for $G$, including the saturating loss, the non-saturating loss, LSGAN, WGAN, the absolute value of the log difference, the squared value of the log difference, the absolute difference, the squared difference, and the pseudo-Huber loss (Barron, 2017). For most loss functions, we tried all variants of DM, EDM, LM, and ELM.
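The distance functions listed above can be sketched as follows. The pseudo-Huber scale `delta` is not given in this text, so `delta=1.0` is an assumption; the log-based distances require positive inputs, such as sigmoid outputs of the discriminator. Any of these can serve as the distance $d$ in DM, LM, EDM, or ELM.

```python
import math


def abs_log_diff(x: float, y: float) -> float:
    return abs(math.log(x) - math.log(y))


def squared_log_diff(x: float, y: float) -> float:
    return (math.log(x) - math.log(y)) ** 2


def abs_diff(x: float, y: float) -> float:
    return abs(x - y)


def squared_diff(x: float, y: float) -> float:
    return (x - y) ** 2


def pseudo_huber(x: float, y: float, delta: float = 1.0) -> float:
    """delta^2 * (sqrt(1 + ((x - y) / delta)^2) - 1): ~quadratic near 0, ~linear far away."""
    r = (x - y) / delta
    return delta**2 * (math.sqrt(1 + r * r) - 1)


DISTANCES = [abs_log_diff, squared_log_diff, abs_diff, squared_diff, pseudo_huber]
for dist in DISTANCES:
    print(dist.__name__, dist(0.3, 0.3), dist(0.3, 0.7))
```

All five are zero only when their arguments are equal and positive otherwise, which is the positive-definiteness required for the divergence results.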
Both neural networks consisted of three linear layers, each followed by a leaky ReLU activation function (Maas et al.), and one final linear layer. The discriminator was also followed by a sigmoid function when using the cross-entropy loss. The same learning rate was used for both $D$ and $G$. We trained the models for 5000 cycles (one cycle = all $D$ iterations and one $G$ iteration) with a batch size of 256. We used 10 discriminator updates per cycle with a penalty weight of 10 for the WGAN-GP models.
We report the average root mean squared difference between real/fake samples and their nearest fake/real sample (NNRMSE), using 1000 real and 1000 fake samples. This can be defined mathematically as:
$$\mathrm{NNRMSE} = \frac{1}{2}\left(\sqrt{\frac{1}{|R|}\sum_{x \in R} \|x - \mathrm{NN}(x, F)\|_2^2} + \sqrt{\frac{1}{|F|}\sum_{x \in F} \|x - \mathrm{NN}(x, R)\|_2^2}\right)$$
where $\mathrm{NN}(x, S)$ finds the nearest neighbor of $x$ in the set $S$, $R$ is the set of 1k real samples, and $F$ is the set of 1k fake samples. This simple measure gives a very good indication of how well the generator converges to the true data distribution, as it penalizes both under-coverage and over-coverage (see Figure 1 in the Appendix for more details). We report the median NNRMSE over five runs with seeds 1, 2, 3, 4, and 5.
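The NNRMSE described above can be computed with a brute-force nearest-neighbour search, which is quadratic in the number of samples but manageable for 1000 real and 1000 fake points. This sketch assumes NNRMSE is the average of the real-to-fake and fake-to-real RMSEs, as the text describes; the helper names and the Gaussian toy data are hypothetical. The usage example shows the coverage behaviour: a generator that collapses onto one point is penalized through the real-to-fake term.

```python
import math
import random
from typing import List, Sequence

Point = Sequence[float]


def nearest_sq_dist(x: Point, candidates: List[Point]) -> float:
    """Squared Euclidean distance from x to its nearest neighbour in candidates."""
    return min(sum((a - b) ** 2 for a, b in zip(x, c)) for c in candidates)


def rmse_to_nearest(source: List[Point], target: List[Point]) -> float:
    """Root mean squared distance from each point of source to its nearest point in target."""
    return math.sqrt(sum(nearest_sq_dist(x, target) for x in source) / len(source))


def nnrmse(real: List[Point], fake: List[Point]) -> float:
    """Average of real->fake (under-coverage) and fake->real (over-coverage) RMSEs."""
    return 0.5 * (rmse_to_nearest(real, fake) + rmse_to_nearest(fake, real))


rng = random.Random(1)
real = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
fake_good = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
fake_collapsed = [(0.0, 0.0)] * 200  # generator stuck on a single point
print(nnrmse(real, fake_good), nnrmse(real, fake_collapsed))
```

For larger sample sizes, a KD-tree or a library nearest-neighbour search would replace the brute-force loop.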
As a second experiment, we trained GANs on the CIFAR-10 dataset (Krizhevsky, 2009). We used the same objective functions for $D$ and loss functions for $G$ as in the synthetic experiment. For these experiments, we used the original two-sided penalty for WGAN-GP (Gulrajani et al., 2017). The neural networks followed the DCGAN architecture (Radford et al., 2015) with batch normalization (Ioffe and Szegedy, 2015). The same learning rate was used for training both $D$ and $G$, with a batch size of 32. We used 5 discriminator updates per cycle with a penalty weight of 10 for the models using the WGAN-GP objective function. To compare models, we report the Inception score (IS) (Salimans et al., 2016) (Barratt and Sharma, 2018) (larger is better) and the Fréchet Inception Distance (FID) (Heusel et al., 2017) (smaller is better). Note that most researchers calculate the IS and FID using TensorFlow (Abadi et al., 2015) implementations; therefore, the values we report may be slightly different. Given our limited computing power (a single GPU), we only trained the models once, for 25 epochs, using seed 1. Although 25 epochs were not enough to reach optimality, they were enough to detect non-convergence. Our goal with these experiments was not to achieve state-of-the-art results but simply to compare different loss functions on equal ground and show that most of them work just as well as standard loss functions.
Table 1: FIDs of the CIFAR-10 experiments (smaller is better).
GAN | ||||||
[Non-saturating] | 57.66 | |||||
[Saturating] | 90.46 | |||||
59.33 | 63.97 | 61.97 | 61.15 | 61.75 | 59.67 | |
64.50 | 61.34 | 62.06 | 59.33 | 63.40 | 56.00 | |
405.84 | 62.85 | 64.50 | 60.84 | 62.90 | 66.69 | |
63.17 | 64.66 | 65.82 | 65.62 | 65.08 | 62.38 | |
64.77 | 64.70 | 63.93 | 63.04 | 65.97 | 65.82 |
LSGAN | ||||||
[LSGAN] | 61.64 | |||||
63.79 | 65.80 | 56.57 | 58.06 | 63.14 | 61.59 | |
60.66 | 62.78 | 58.23 | 63.81 | 68.41 | 61.64 | |
58.83 | 58.14 | 59.27 | 59.76 | 62.49 | 63.42 |
WGAN-GP (two-sided penalty) | ||||||
[WGAN-GP] | 55.18 | |||||
498.30 | 55.76 | |||||
446.26 | 300.95 | |||||
136.25 | 54.86 |
FIDs of the CIFAR-10 experiments are shown in Table 1. See the Appendix for the results of the swiss roll experiments (Table 2) and the IS of the CIFAR-10 experiments (Table 3). Overall, most loss functions converged well, with the exception of DM and LM with $l = \tau$ on the swiss roll dataset. Importantly, no loss function performed much better than the others across a wide range of scenarios; there was no overall best.
A priori, we expected that the loss functions that are divergences without requiring the assumption of same support (DM and LM with $l = \tau$) would work best. However, these divergences performed badly on the swiss roll dataset, while all loss functions performed equally well on CIFAR-10. This provides further evidence that the generator does not improve by minimizing a divergence but simply by trying to increase $D(x_f)$. We suspect that striving for $D(x_f) = \tau$ may sometimes have more difficulty converging than overshooting ($l = 1$), as on the swiss roll dataset, because it does not take a strong enough step to prevent $D$ from dominating.
In summary, most GANs do not directly minimize a divergence and trying to make GANs analogous to divergence minimization does not confer any benefit. Instead of training using the saturating or non-saturating loss, one can instead train using a wide range of possible loss functions. What we have shown is just a very small set of all the possible loss functions that one could use and we did not attempt to determine if some of these loss functions could lead to better state-of-the-art results in data generation. This paper brings a greater level of customization to GANs which we hope will lead to more diversity in GANs research (enlarging the GAN zoo) and new ways to improve data generation quality.
In this paper, we focused solely on the generator step; however, the discriminator step is as important, if not more so. Issues or limitations of the discriminator will affect how well any loss function of $G$ performs, given that $D$ and its gradient are fundamental to the gradient of the loss of $G$. For example, in standard GAN, there are perfect discriminators ($D$ such that $D(x_r) = 1$ and $D(x_f) = 0$ for all $x_r$ and $x_f$) for which $\nabla_x D(x)$ is exactly zero under certain theoretical conditions (Arjovsky and Bottou, 2017). Thus, by the chain rule, the gradient of any loss function of $G$ will also be zero when $D$ is one of those perfect discriminators. In practice, we can never obtain a perfect discriminator, and close-to-perfect data separation becomes less likely over time as the supports of $\mathbb{P}$ and $\mathbb{Q}$ get closer to one another, but this still shows a major issue with standard GAN that cannot be resolved simply by changing the loss of $G$. Thus, understanding what makes a discriminator "good" remains paramount. We encourage research in this direction rather than focusing solely on finding a "good" divergence with informative gradients (generally an IPM), since $G$ is not minimizing this divergence directly.
Our results also suggest that feature matching (FM) (Salimans et al., 2016) may be more than a trick: it can be seen as a special case of GAN, namely an EDM loss applied to the intermediate layers of $D$.
A function $d(x, y)$ is positive definite if it respects the following two conditions:
$d(x, y) \geq 0$ for all $x, y$
$d(x, y) = 0 \iff x = y$
If $x$ and $y$ are probability distributions, $d$ is called a divergence.
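A discriminator $D: \mathcal{X} \to \mathbb{R}$ is said to be optimal at equilibrium on the distributions $\mathbb{P}$ and $\mathbb{Q}$ (with domain $\mathcal{X}$) if there exists a $\tau$ such that $D(x) = \tau \;\, \forall x \in \mathcal{X} \iff \mathbb{P} = \mathbb{Q}$.
A discriminator $D: \mathcal{X} \to \mathbb{R}$ is said to be optimal on the distributions $\mathbb{P}$ and $\mathbb{Q}$ (with domain $\mathcal{X}$) if
(1) $D$ is optimal at equilibrium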
(2) when , where is a positive definite function.
(2) when , where is a positive definite function.
This is just a way to formalize the notion of an optimal discriminator without resorting to any objective function for $D$. These definitions are very general; however, they do not apply to some GANs (e.g., WGAN, since any Lipschitz $D$ leads to a Wasserstein distance of 0 when $\mathbb{P} = \mathbb{Q}$; WGAN-GP, since $D(x) = \tau$ for all $x$ when $\mathbb{P} = \mathbb{Q}$ would mean that $\|\nabla_x D(x)\|_2 = 0$, but we enforce the constraint that $\|\nabla_x D(x)\|_2 \approx 1$). Note that with standard GAN and LSGAN (with default parameters), it has been shown [Goodfellow et al., 2014] [Mao et al., 2017] that the optimal discriminator is $D^*(x) = \frac{p(x)}{p(x) + q(x)}$. Thus, in both GAN and LSGAN, $D^*$ is optimal by def 7.3.
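The claim that standard GAN and LSGAN (with $a = 0$, $b = 1$) share the optimal discriminator $D^*(x) = p(x)/(p(x) + q(x))$ can be spot-checked pointwise, since both objectives are integrals over $x$ of a function of $D(x)$, $p(x)$, and $q(x)$. The sketch below, with hypothetical density values and helper names, grid-searches $D(x) \in (0, 1)$ at a single point.

```python
import math


def gan_pointwise(dv: float, p: float, q: float) -> float:
    """Pointwise standard GAN objective p log D + q log(1 - D), maximized by D."""
    return p * math.log(dv) + q * math.log(1.0 - dv)


def lsgan_pointwise(dv: float, p: float, q: float, a: float = 0.0, b: float = 1.0) -> float:
    """Pointwise LSGAN discriminator loss (p (D - b)^2 + q (D - a)^2) / 2, minimized by D."""
    return 0.5 * p * (dv - b) ** 2 + 0.5 * q * (dv - a) ** 2


GRID = [i / 10000 for i in range(1, 10000)]  # candidate values of D(x) in (0, 1)


def best_gan_value(p: float, q: float) -> float:
    return max(GRID, key=lambda dv: gan_pointwise(dv, p, q))


def best_lsgan_value(p: float, q: float) -> float:
    return min(GRID, key=lambda dv: lsgan_pointwise(dv, p, q))


for p, q in [(0.3, 0.1), (0.2, 0.2), (0.05, 0.6)]:
    print(p, q, best_gan_value(p, q), best_lsgan_value(p, q), p / (p + q))
```

At $p = q$ both searches return $1/2$, which is the equilibrium value $\tau$ of standard GAN.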
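Let $d$ be positive definite, $\mathbb{P}$ and $\mathbb{Q}$ distributions on the domain $\mathcal{X}$, and $x_r \sim \mathbb{P}$, $x_f \sim \mathbb{Q}$. If $D$ is optimal at equilibrium, then DM is a divergence, and LM is a divergence if $l = \tau$. If $D$ is optimal and $\mathbb{P}$ and $\mathbb{Q}$ have the same support, i.e., $\mathrm{supp}(\mathbb{P}) = \mathrm{supp}(\mathbb{Q})$, then EDM is a divergence, and ELM is a divergence if $l = \tau$.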
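$d$ is positive definite $\implies$ DM, LM, EDM, and ELM are always $\geq 0$.
$$\begin{aligned}
\mathrm{DM} = 0 &\iff d\big(D(x_r), D(x_f)\big) = 0 \;\; \forall x_r \in \mathrm{supp}(\mathbb{P}),\, x_f \in \mathrm{supp}(\mathbb{Q}) \\
&\iff D(x_r) = D(x_f) = \tau \;\; \forall x_r, x_f \quad \text{(otherwise } D \text{ wouldn't be optimal)} \\
&\iff \mathbb{P} = \mathbb{Q} \quad \text{(since } D \text{ is optimal)}
\end{aligned}$$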
Thus, DM is a divergence.
Thus, LM is only a divergence when $l = \tau$.
Thus, EDM is a divergence when the two distributions have the same support.
Thus, just as for LM, we need $l = \tau$ for ELM to possibly be a divergence.
(Follow the same arguments as in the proof for EDM.)
Thus, ELM is a divergence when the two distributions have the same support and $l = \tau$. ∎
Table 2: NNRMSE of the swiss roll experiments (median over five runs; smaller is better).
GAN | ||||||
[Non-saturating] | .057 | |||||
[Saturating] | .052 | |||||
1.401 | .282 | .054 | .055 | .062 | .052 | |
1.451 | 1.556 | .058 | 1.511 | .052 | .054 | |
1.488 | 1.565 | .050 | .053 | .060 | .050 | |
1.708 | 1.518 | .056 | .059 | .056 | .051 | |
1.536 | 1.693 | .050 | .055 | .052 | .050 |
LSGAN | ||||||
[LSGAN] | .052 | |||||
.163 | .164 | .053 | .053 | .057 | .052 | |
.089 | .147 | .052 | .055 | .055 | .054 | |
.088 | .153 | .052 | .069 | .053 | .050 |
WGAN-GP (one-sided penalty) | ||||||
[WGAN-GP] | .062 | |||||
.291 | .062 | |||||
.297 | .070 | |||||
.332 | .073 |
Table 3: IS of the CIFAR-10 experiments (larger is better).
GAN | ||||||
[Non-saturating] | 3.36 | |||||
[Saturating] | 2.09 | |||||