Improving GAN Training with Probability Ratio Clipping and Sample Reweighting

06/12/2020 · Yue Wu, et al.

Despite success on a wide range of problems related to vision, generative adversarial networks (GANs) can suffer from inferior performance due to unstable training, especially for text generation. We propose a new variational GAN training framework which enjoys superior training stability. Our approach is inspired by a connection between GANs and reinforcement learning under a variational perspective. The connection leads to (1) probability ratio clipping that regularizes generator training to prevent excessively large updates, and (2) a sample re-weighting mechanism that stabilizes discriminator training by downplaying bad-quality fake samples. We provide theoretical analysis on the convergence of our approach. By plugging the training approach into diverse state-of-the-art GAN architectures, we obtain significantly improved performance over a range of tasks, including text generation, text style transfer, and image generation.


1 Introduction

Generative adversarial networks (GANs) (Goodfellow et al., 2014) have drawn great research interest and achieved remarkable success in image synthesis Radford et al. (2015); Brock et al. (2018), video generation Mathieu et al. (2015), and other domains. However, it is usually hard to train a GAN well, because the training process is commonly unstable, subject to disturbances and even collapses. To alleviate this issue, substantial efforts have been devoted to improving training stability from different perspectives, e.g., divergence minimization Nowozin et al. (2016); Nock et al. (2017), Wasserstein distance with Lipschitz continuity of the discriminator Arjovsky et al. (2017); Gulrajani et al. (2017); Wei et al. (2018), energy-based models Zhao et al. (2016); Berthelot et al. (2017), etc.

Despite the above progress, training instability has not been fully resolved Chu et al. (2020), since it is difficult to balance the strength of the generator and the discriminator. Worse still, this instability is exacerbated in text generation due to the sequential and discrete nature of text Fedus et al. (2018); Caccia et al. (2020); Nie et al. (2018). Specifically, the high sensitivity of text generation to noise and the errors caused by sparse discriminator signals on the generated text can often result in destructive updates to both the generator and the discriminator, amplifying the instability of GANs.

In this work, we develop a novel variational GAN training framework to improve training stability, which is broadly applicable to GANs of varied architectures for image and text generation. This training framework is derived from a variational perspective of GANs and the resulting connections to reinforcement learning (in particular, RL-as-inference) Abdolmaleki et al. (2018); Levine (2018); Schulman et al. (2017) and other rich literature Hu et al. (2018b); Grover et al. (2019); Burda et al. (2015). Specifically, our approach consists of two stabilization techniques, namely, probability ratio clipping and sample re-weighting, for stabilizing the generator and the discriminator, respectively. (1) Under the new variational perspective, the generator update is subject to a KL penalty on the change of the generator distribution. This KL penalty closely resembles that in the popular Trust-Region Policy Optimization (TRPO) Schulman et al. (2015) and its related Proximal Policy Optimization (PPO) Schulman et al. (2017). This connection motivates a simple surrogate objective with a clipped probability ratio between the new generator and the old one. The probability ratio clipping discourages excessively large generator updates, and has been shown to be effective in stabilizing policy optimization Schulman et al. (2017). Figure 1 (left) illustrates the intuition behind the surrogate objective: the objective decreases once the generator change becomes overly large, thus regularizing the updates.

Figure 1: Illustration of the proposed approach for stabilizing GAN training. Results are from the CIFAR-10 experiment in Sec. 4.1. Left: The conventional and surrogate objectives for generator training, as we interpolate between the initial generator parameters and the updated generator parameters computed after one iteration of training. The updated parameters attain the maximal surrogate objective. The surrogate objective imposes a penalty for too large a generator update, since the curve starts decreasing beyond the updated parameters. In contrast, the conventional objective (for WGAN-GP) keeps increasing with larger generator updates. Middle and right: Discriminator and generator losses with and without sample re-weighting. WGAN-GP with our re-weighting plugged in shows lower variance in both discriminator and generator losses throughout training (and achieves better final performance, as shown in Sec. 4.1).

(2) When updating the discriminator, the new perspective induces an importance sampling mechanism, which effectively re-weights fake samples by their discriminator scores. Since low-quality samples tend to receive smaller weights, the discriminator trained on the re-weighted samples is more likely to maintain stable performance, and in turn provides informative gradients for subsequent generator updates. Figure 1 (middle/right) demonstrates the effect of the re-weighting in reducing the variance of both discriminator and generator losses. Similar importance weighting methods have recently been used in other contexts, such as de-biasing generative models Grover et al. (2019) and sampling from energy-based models Deng et al. (2020). Our derivation can be seen as a variant tailored to the new application of improving GANs.

We provide a theoretical analysis showing that the generator under our training framework can converge to the real data distribution. Empirically, we conduct extensive experiments on a range of tasks, including text generation, text style transfer, and image generation. Our approach shows significant improvements over state-of-the-art methods, demonstrating its broad applicability and efficacy.

2 Related Work

Wasserstein distance, WGAN, and Lipschitz continuity. The GAN framework Goodfellow et al. (2014) features two components: a generator G_θ that synthesizes samples x = G_θ(z) given some noise source z ∼ p(z), and a discriminator that distinguishes the generator's outputs from real data and provides gradient feedback to improve the generator's performance. WGAN Arjovsky et al. (2017) improves the training stability of GANs by minimizing the Wasserstein distance between the generation distribution p_θ (induced from G_θ) and the real data distribution p_r. Its training loss is formulated as:

\min_{\theta} \max_{f \in \mathcal{F}} \; \mathbb{E}_{x \sim p_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim p_\theta}[f(\tilde{x})]     (1)

where \mathcal{F} is the set of 1-Lipschitz functions; f acts as the discriminator and is usually implemented by a neural network. The original way to enforce the Lipschitz constraint was weight clipping Arjovsky et al. (2017). WGAN-GP Gulrajani et al. (2017) later improves on this by replacing weight clipping with a gradient penalty on the discriminator. CT-GAN Wei et al. (2018) further imposes the Lipschitz continuity constraint on the manifold of the real data. Our approach is orthogonal to these prior works and can serve as a drop-in replacement that stabilizes the generator and discriminator in various kinds of GANs, such as WGAN-GP and CT-GAN.
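As a point of reference, below is a minimal PyTorch-style sketch of the gradient-penalty regularizer used by WGAN-GP to (approximately) enforce the Lipschitz constraint; the function and argument names, and the default penalty weight, are illustrative rather than taken from any specific implementation discussed here.

```python
import torch

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    """WGAN-GP style penalty: push the critic's gradient norm towards 1
    on points interpolated between real and fake samples."""
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(scores.sum(), interp, create_graph=True)[0]
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1.0) ** 2).mean()
```

The penalty is simply added to the discriminator loss; CT-GAN's consistency term is a further refinement built on top of this idea.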

Research on the Lipschitz continuity of GAN discriminators has resulted in the theory of "informative gradients" Zhou et al. (2019, 2018). Under certain mild conditions, a Lipschitz discriminator can provide informative gradients to the generator in a GAN framework: when p_θ and p_r are disjoint, the gradient of the optimal discriminator with respect to each generated sample points towards some real sample, which guarantees that the generation distribution p_θ moves towards p_r. We extend the informative gradient theory to our new setting and show convergence of our approach.

Reinforcement Learning as Inference. Casting RL as probabilistic inference has a long history Dayan and Hinton (1997); Deisenroth et al. (2013); Rawlik et al. (2013); Levine (2018); Abdolmaleki et al. (2018). For example, Abdolmaleki et al. (2018) introduced maximum a-posteriori policy optimization from a variational perspective. TRPO Schulman et al. (2015) is closely related to this line, using a KL divergence regularizer to stabilize standard RL objectives. PPO Schulman et al. (2017) further proposed a practical clipped surrogate objective that emulates the regularization. Our approach draws on connections to this research, particularly the variational perspective and PPO, to improve GAN training.

Other related work. Importance re-weighting has been adopted in different problems, such as improving variational auto-encoders Burda et al. (2015), de-biasing generative models Grover et al. (2019), and learning knowledge constraints Hu et al. (2018b). Our derivation starts from the variational perspective, which leads to re-weighting in the new context of discriminator stabilization.

3 Improving GAN Training

3.1 Motivations

Our approach is motivated by connecting GAN training with the well-established RL-as-inference methods Abdolmaleki et al. (2018); Levine (2018) under a variational perspective. The connections enable us to augment GAN training with existing powerful probabilistic inference tools, as well as draw inspiration from the rich RL literature for stable training. In particular, the connection to the popular TRPO Schulman et al. (2015) and PPO Schulman et al. (2017) yields the probability ratio clipping in generator training that avoids destructive updates (Sec. 3.2), and the application of importance sampling estimation gives rise to sample re-weighting for adaptive discriminator updates (Sec. 3.3). The full training procedure is summarized in Alg.1.

Specifically, as described in Sec. 2, the conventional formulation, e.g., WGAN Arjovsky et al. (2017), for updating the generator maximizes the expected discriminator score max_θ E_{x∼p_θ}[f_φ(x)], where f_φ is the Lipschitz-continuous discriminator parameterized with φ. The objective straightforwardly relates to policy optimization in RL by viewing p_θ as a policy and f_φ as a reward function. Thus, inspired by the probabilistic inference formulations of policy optimization Abdolmaleki et al. (2018); Ding and Soricut (2017); Hu et al. (2018b), we transform the conventional objective by introducing a non-parametric auxiliary distribution q(x) and defining a new variational objective:

\max_{q,\, \theta} \; \mathcal{L}(q, \theta) = \mathbb{E}_{x \sim q}\big[f_\phi(x)\big] - \mathrm{KL}\big(q(x) \,\|\, p_\theta(x)\big)     (2)

where KL is the KL divergence. Intuitively, we maximize the expected discriminator score of the auxiliary distribution q (instead of the generator p_θ), while encouraging the generator to stay close to q.

As we shall see in more detail shortly, the new formulation allows us to take advantage of off-the-shelf inference methods, which naturally leads to new components that improve GAN training. In particular, the above objective is maximized with the expectation maximization (EM) algorithm Neal and Hinton (1998), which alternatingly optimizes q at the E-step and θ at the M-step. More specifically, at each iteration t, given the current state θ^(t), the E-step that maximizes the objective w.r.t. q has a closed-form solution:

q^{(t)}(x) = p_{\theta^{(t)}}(x)\, \exp\big(f_\phi(x)\big) / Z_\phi     (3)

where Z_φ is the normalization term depending on the discriminator parameters φ. We elaborate on the M-step in the following, where we develop the practical procedures for updating the generator and the discriminator, respectively.

3.2 Generator Training with Probability Ratio Clipping

The M-step optimizes the objective w.r.t. θ, which is equivalent to minimizing the KL divergence term in Eq.(2). However, since the generator in GANs is often an implicit distribution that does not permit likelihood evaluation, this KL term (which requires evaluating the likelihood of samples from q under the generator) is not directly applicable. We adopt an approximation, which has also been used in the classical wake-sleep algorithm Hinton et al. (1995) and recent work Hu et al. (2018b), by instead minimizing the reverse KL divergence as below. With Eq.(3) plugged in, we have:

\min_\theta \; \mathrm{KL}\big(p_\theta(x) \,\|\, q^{(t)}(x)\big) = \min_\theta \; -\mathbb{E}_{x \sim p_\theta}\big[f_\phi(x)\big] + \mathrm{KL}\big(p_\theta(x) \,\|\, p_{\theta^{(t)}}(x)\big) + \text{const}     (4)

As proven in the appendix, approximating with the reverse KL does not change the optimization problem. The first term on the right-hand side of the equation recovers the conventional objective for updating the generator. Of particular interest is the second term, which is a new KL regularizer between the generator and its "old" state from the previous iteration. The regularizer discourages the generator from changing too much between updates, which is useful for stabilizing the stochastic optimization procedure. The regularization closely resembles that of TRPO/PPO, where a similar KL regularizer is imposed to prevent uncontrolled policy updates and make policy gradients robust to noise. Sec. 3.4 gives a convergence analysis of the KL-regularized generator updates.

In practice, directly optimizing with the KL regularizer can be infeasible due to the same difficulty with the implicit distribution as above. Fortunately, PPO Schulman et al. (2017) has presented a simplified solution that emulates the regularized updates using a clipped surrogate objective, which is widely used in RL. We adapt this solution to our context, leading to the following practical procedure for generator updates.

Probability Ratio Clipping. Let r_θ(x) = p_θ(x) / p_{θ^(t)}(x) denote the probability ratio, which measures the difference between the new and old generator distributions; by definition, r_{θ^(t)}(x) = 1. The clipped surrogate objective for updating the generator, as adapted from PPO, is:

\max_\theta \; \mathbb{E}_{x \sim p_{\theta^{(t)}}}\Big[\min\big(r_\theta(x)\, f_\phi(x),\; \mathrm{clip}\big(r_\theta(x),\, 1-\epsilon,\, 1+\epsilon\big)\, f_\phi(x)\big)\Big]     (5)

where clip(r_θ(x), 1−ε, 1+ε) clips the probability ratio, so that moving r_θ(x) outside of the interval [1−ε, 1+ε] is discouraged. Taking the minimum puts a ceiling on the increase of the objective, so the generator does not benefit from moving far away from the old generator.
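To make the objective concrete, here is a minimal PyTorch-style sketch of Eq.(5), assuming the probability ratio and the discriminator scores for a batch of generated samples have already been computed; all names and the default threshold are illustrative.

```python
import torch

def clipped_surrogate_loss(ratio, scores, eps=0.2):
    """Clipped surrogate objective of Eq.(5), returned as a loss to minimize.

    ratio:  estimated probability ratio r_theta(x) = p_theta(x) / p_theta_old(x)
            for a batch of fake samples (must carry gradients w.r.t. theta).
    scores: discriminator scores f_phi(x) on the same samples (treated as constants).
    eps:    clipping threshold.
    """
    unclipped = ratio * scores
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * scores
    # Take the element-wise minimum, then negate so that minimizing this loss
    # maximizes the pessimistic (clipped) surrogate objective.
    return -torch.min(unclipped, clipped).mean()
```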

Finally, to estimate the probability ratio when p_θ is implicit, we use an efficient approximation similar to Che et al. (2017); Grover et al. (2019), by introducing a binary classifier C_ω trained to distinguish real and generated samples. Assuming an optimal C_ω Goodfellow et al. (2014); Che et al. (2017), which satisfies C*_ω(x) = p_r(x) / (p_r(x) + p_θ(x)), we can approximate r_θ(x) through:

(6)

Note that after plugging the rightmost expression into Eq.(5), gradients can propagate through x to the generator parameters θ since x = G_θ(z). In practice, during the generator training phase, we maintain C_ω by fine-tuning it for one iteration every time after θ is updated (Alg.1), so the maintenance of C_ω is cheap. We give more details of C_ω in the appendix.

3.3 Discriminator Training with Sample Re-weighting

We next discuss the training of the discriminator f_φ, where we augment the conventional training with an importance weighting mechanism for adaptive updates. Concretely, given the form of the auxiliary distribution q^(t) in Eq.(3), we first draw on the recent energy-based modeling work Kim and Bengio (2016); Hu et al. (2018b); Deng et al. (2020); Che et al. (2020) and propose to maximize the data log-likelihood of q^(t) w.r.t. φ. Taking the gradient, we have:

\nabla_\phi \, \mathbb{E}_{x \sim p_r}\big[\log q^{(t)}(x)\big] = \mathbb{E}_{x \sim p_r}\big[\nabla_\phi f_\phi(x)\big] - \mathbb{E}_{x \sim q^{(t)}}\big[\nabla_\phi f_\phi(x)\big]     (7)

We can observe that the resulting form resembles the conventional one (Sec. 2), as we are essentially maximizing f_φ on real data while minimizing it on fake samples. An important difference is that here fake samples are drawn from the auxiliary distribution q^(t) instead of the generator p_θ. This difference leads to the new sample re-weighting component below. Note that, as in WGAN (Sec. 2), we maintain f_φ within the class of Lipschitz-continuous functions with an upper-bounded Lipschitz constant, which is necessary for the convergence analysis in Sec. 3.4. In practice, we can use a gradient penalty following Gulrajani et al. (2017); Wei et al. (2018).

Sample Re-weighting. We use importance sampling to estimate the expectation under q^(t) in Eq.(7). Given the multiplicative form of q^(t) in Eq.(3), and similar to Abdolmaleki et al. (2018); Hu et al. (2018b); Deng et al. (2020), we use the generator p_{θ^(t)} as the proposal distribution. This leads to

\mathbb{E}_{x \sim q^{(t)}}\big[\nabla_\phi f_\phi(x)\big] = \mathbb{E}_{x \sim p_{\theta^{(t)}}}\Big[\tfrac{\exp(f_\phi(x))}{Z_\phi}\, \nabla_\phi f_\phi(x)\Big]     (8)

That is, fake samples from the generator are weighted by their exponentiated discriminator scores when used to update the discriminator. Intuitively, the mechanism assigns higher weights to samples that fool the discriminator better, while low-quality samples are downplayed to avoid degrading the discriminator performance. It is worth mentioning that a similar importance weighting scheme has been used in Che et al. (2017); Hu et al. (2018a) for generator training in GANs, and in Burda et al. (2015) for improving variational auto-encoders. Our work instead results in a re-weighting scheme in the new context of discriminator training.
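As an illustration, here is a minimal PyTorch-style sketch of the re-weighted discriminator update in Eqs.(7)-(8), using a self-normalized (softmax over the minibatch) estimate of the importance weights exp(f_φ(x))/Z_φ, consistent with the softmax weights visualized in Figure 5 of the appendix; the Lipschitz regularizer (e.g., a gradient penalty) is omitted, and all names are illustrative.

```python
import torch
import torch.nn.functional as F

def reweighted_critic_loss(critic, real, fake):
    """Discriminator loss with fake-sample re-weighting (sketch of Eqs. 7-8)."""
    real_scores = critic(real)
    fake_scores = critic(fake.detach())
    # Self-normalized importance weights over the batch:
    # softmax(f(x)) = exp(f(x)) / sum_batch exp(f(x)); detached so the weights
    # act as constants in the gradient.
    weights = F.softmax(fake_scores.detach(), dim=0)
    # Maximize scores on real data, minimize the weighted scores on fake data.
    return -(real_scores.mean() - (weights * fake_scores).sum())
```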

Alg.1 summarizes the proposed training procedure for the generator and discriminator.

1:  Initialize the generator G_θ, the discriminator f_φ, and the auxiliary classifier C_ω
2:  for t = 1 to T do
3:     for a certain number of steps do
4:        Update the discriminator f_φ with sample re-weighting through Eqs.(7)-(8), and maintain f_φ to have an upper-bounded Lipschitz constant through, e.g., a gradient penalty Gulrajani et al. (2017).
5:     end for
6:     for a certain number of steps do
7:        Finetune the real/fake classifier C_ω (for 1 step)
8:        Estimate the probability ratio r_θ(x) using C_ω through Eq.(6)
9:        Update the generator G_θ with probability ratio clipping through Eq.(5)
10:     end for
11:  end for
Algorithm 1 GAN Training with Probability Ratio Clipping and Sample Re-weighting

3.4 Theoretical Analysis

In this section, we show that the generator distribution under our training approach can converge to the real data distribution given an optimal discriminator. The analysis is based on the reverse-KL updates for the generator (Eq.4), while the probability ratio clipping serves as a practical emulation of these updates. We begin by adapting Proposition 1 in Gulrajani et al. (2017) to our problem:

Proposition 3.1.

Let p_r and p_θ be two distributions in \mathcal{X}, a compact metric space. Then, there is a 1-Lipschitz function f* which is the optimal solution of

\max_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim p_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim p_\theta}[f(\tilde{x})].

Let π be the optimal coupling between p_r and p_θ, defined as the minimizer of W(p_r, p_θ) = inf_{π ∈ Π(p_r, p_θ)} E_{(x, x̃) ∼ π}[‖x − x̃‖], where Π(p_r, p_θ) is the set of joint distributions whose marginals are p_r and p_θ, respectively. Then, if f* is differentiable, π(x = x̃) = 0, and x_t = t x̃ + (1 − t) x with 0 ≤ t ≤ 1, it holds that P_{(x, x̃) ∼ π}[∇f*(x_t) = (x − x_t)/‖x − x_t‖] = 1.

The proposition shows that the optimal f* provides informative gradients Zhou et al. (2018) that move p_θ towards p_r. We then generalize the conclusion to q^(t) by considering the correlation between q^(t) and p_{θ^(t)}.

By the definition of q^(t) with respect to p_{θ^(t)} in Equation (3), the supports of q^(t) and p_{θ^(t)} are the same; namely, wherever p_{θ^(t)} places mass, so does q^(t). Therefore, every sample from p_{θ^(t)} is also a valid sample from q^(t), and the f* in Proposition 3.1 provides informative gradients with respect to q^(t) as well.

Therefore, assuming f_φ is the optimal discriminator for (7), optimizing Eq.(4) can provide informative gradients to the generator and lead to convergence to p_r.

4 Experiments

We conduct extensive experiments on three unsupervised generation tasks: image generation, text generation, and text style transfer. The three tasks apply GANs to model different data modalities, namely images, text, and neural hidden representations. Our approach consistently offers improvements over the state of the art on all three tasks. We present more experimental details in the appendix. Our code is included in the supplementary materials and will be released upon acceptance.

Method                       IS (↑)         FID (↓)
Real data                    11.24 ± .12    7.8
WGAN-GP (2017)               7.86 ± .08     -
CT-GAN (2018)                8.12 ± .12     -
SN-GANs (2018)               8.22 ± .05     21.7 ± .21
WGAN-ALP (2020)              8.34 ± .06     12.96 ± .35
SRNGAN (2020)                8.53 ± .04     19.83
AutoGAN (2019)               8.55 ± .10     12.42
Ours (re-weighting only)     8.45 ± .14     13.21 ± .60
Ours (full)                  8.69 ± .13     10.70 ± .10
Table 1: CIFAR-10 results. Our method is run 3 times to report the average and standard deviation.

Figure 2: Generated samples by WGAN-GP (top-left), CT-GAN (bottom-left), and ours (right).

4.1 Image generation

We first use the popular CIFAR-10 benchmark for evaluation and in-depth analysis of our approach.

Setup. CIFAR-10 Krizhevsky and Hinton (2010) contains 50K images of size 32×32. Following the setup of CT-GAN Wei et al. (2018) and its public implementation (https://github.com/igul222/improved_wgan_training), we use a residual architecture to implement both the generator and the discriminator, and also impose a Lipschitz constraint on the discriminator. For each iteration, we update both the generator and the discriminator 5 times. We use Inception Score (IS) Salimans et al. (2016) and Fréchet Inception Distance (FID) Heusel et al. (2017) as our evaluation metrics. IS evaluates the quality and diversity of generated images, while FID captures model issues such as mode collapse Xu et al. (2018).

Results. Table 1 reports the results on CIFAR-10. Among the four latest methods, SN-GANs Miyato et al. (2018) introduced spectral normalization to stabilize discriminator training; WGAN-ALP Terjék (2020) developed an explicit Lipschitz penalty; SRNGAN Sanyal et al. (2020) introduced a stable rank normalization scheme for better generalization; and AutoGAN Gong et al. (2019) incorporated neural architecture search for the generator architecture. From Table 1, one can observe that our full approach (CT-GAN + discriminator sample re-weighting + generator probability ratio clipping) achieves the best performance, with both IS and FID significantly surpassing the baselines. These results accord with the visual results in Figure 2, where our generated samples show higher visual quality than those of the baselines. Moreover, comparing CT-GAN with our approach using only discriminator re-weighting shows a significant improvement. Further adding the probability ratio clipping to arrive at our full approach improves both IS and FID by a large margin. The results demonstrate the effectiveness of the two components of our approach.

Figure 3: Left: Inception score on CIFAR-10 v.s. training iterations. The DCGAN Radford et al. (2015) architecture is used. Right: The gradient norms of discriminators on fake samples.

Figure 1 in Sec. 1 has shown the effects of the proposed approach in stabilizing generator and discriminator training. Here we further analyze these two components. Figure 3 (left) shows the convergence curves of different GAN methods. For a fair comparison, all models use the same DCGAN architecture Radford et al. (2015), and both our approach and WGAN-GP Gulrajani et al. (2017) enforce the same discriminator Lipschitz constraint. From Figure 3 (left), one can observe that our full approach surpasses our approach with only sample re-weighting, and both converge faster and achieve a higher IS score than WGAN-GP and DCGAN. Figure 3 (right) further looks into how the re-weighting of fake samples affects discriminator training: by injecting sample re-weighting into WGAN-GP, its gradients on fake samples become more stable and show much lower variance, which also partially explains the higher training stability of the discriminator in Figure 1.

4.2 Text Generation

In this section, we evaluate our approach on text generation, a task that is known to be notoriously difficult for GANs due to the discrete and sequential nature of text.

Setup. We implement our approach based on the RelGAN Nie et al. (2018) architecture, a state-of-the-art GAN model for text generation. Specifically, we replace the generator and discriminator objectives in RelGAN with ours. We follow WGAN-GP Gulrajani et al. (2017) and impose the discriminator Lipschitz constraint with a gradient penalty. Same as Nie et al. (2018), we use the Gumbel-softmax approximation Jang et al. (2017); Maddison et al. (2017) on the discrete text to enable gradient backpropagation, and the generator is initialized with maximum likelihood (MLE) pre-training. Our implementation is based on the public PyTorch code of RelGAN (https://github.com/williamSYSU/TextGAN-PyTorch). Same as previous studies Nie et al. (2018); Guo et al. (2018); Yu et al. (2017), we evaluate our approach on both synthetic and real text datasets.
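For readers unfamiliar with the relaxation, the snippet below sketches how Gumbel-softmax yields differentiable "soft" one-hot token samples that can be fed to the discriminator; the tensor shapes and temperature are illustrative.

```python
import torch
import torch.nn.functional as F

# vocab_logits: generator outputs over the vocabulary (batch, seq_len, vocab_size).
vocab_logits = torch.randn(4, 16, 5255)
# Differentiable approximation of one-hot token samples; gradients flow back
# to the generator through the relaxed samples.
soft_onehot = F.gumbel_softmax(vocab_logits, tau=1.0, hard=False, dim=-1)
# The soft one-hot vectors can be multiplied with an embedding matrix and
# passed to the discriminator in place of discrete tokens.
```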

Results on Synthetic Data. The synthetic data consists of 10K discrete sequences generated by an oracle-LSTM with fixed parameters Yu et al. (2017). This setup facilitates evaluation, as the quality of generated samples can be directly measured by the negative log-likelihood (NLL) of the oracle on the samples. We use synthetic data with sequence lengths 20 and 40, respectively. Table 2 reports the results. MLE is the baseline with maximum likelihood training, whose output model is used to initialize the generators of GANs. Besides the previous text generation GANs Yu et al. (2017); Guo et al. (2018); Nie et al. (2018), we also compare with WGAN-GP which uses the same neural architecture as RelGAN and ours. From Table 2, one can observe that our approach significantly outperforms all other approaches on both synthetic sets. Our improvement over RelGAN and WGAN-GP demonstrates that our proposed generator and discriminator objectives are more effective than the previous ones.

Length MLE SeqGAN Yu et al. (2017) LeakGAN Guo et al. (2018) RelGAN Nie et al. (2018) WGAN-GP (Gulrajani et al., 2017) Ours Real
20 9.038 8.736 7.038 6.680 6.89 5.67 5.750
40 10.411 10.310 7.191 6.765 6.78 6.14 4.071
Table 2: Oracle negative log-likelihood scores () on synthetic data.
Method BLEU-2 () BLEU-3 () BLEU-4 () BLEU-5 () NLL ()
MLE 0.768 0.473 0.240 0.126 2.382
SeqGAN Yu et al. (2017) 0.777 0.491 0.261 0.138 2.773
LeakGAN Guo et al. (2018) 0.826 0.645 0.437 0.272 2.356
RelGAN (100) Nie et al. (2018) 0.881 0.705 0.501 0.319 2.482
RelGAN (1000) Nie et al. (2018) 0.837 0.654 0.435 0.265 2.285
WGAN-GP Gulrajani et al. (2017) 0.872 0.636 0.379 0.220 2.209
Ours 0.905 0.692 0.470 0.322 2.265
Table 3: Results on EMNLP2017 WMT News. BLEU measures text quality and NLL evaluates sample diversity. We copied the results of previous text GAN models from Nie et al. (2018), where RelGAN (100) and RelGAN (1000) use different hyperparameters as reported in the paper.

Results on Real Data. We then evaluate our method on EMNLP2017 WMT News, the largest real text dataset used in text GAN studies Guo et al. (2018); Nie et al. (2018). The dataset consists of 270K/10K training/test sentences with a maximum length of 51 and a vocabulary size of 5,255. To measure generation quality, we use the popular BLEU-n metric, which measures n-gram overlap between generated and real text (n = 2, ..., 5). To evaluate the diversity of generation, we use the negative log-likelihood of the generator on the real test set (NLL) Guo et al. (2018); Nie et al. (2018). From the results in Table 3, one can see that our approach shows comparable performance with the previous best model RelGAN (100) in terms of text quality (BLEU), but has better sample diversity. Our model also achieves much higher BLEU scores than WGAN-GP (e.g., 0.322 vs. 0.220 on BLEU-5), demonstrating its ability to generate higher-quality samples.

4.3 Text Style Transfer

We further apply our approach on the text style transfer task which is gaining increasing attention in NLP Hu et al. (2017); Shen et al. (2017). The task aims at rewriting a sentence to modify its style (e.g., sentiment) while preserving the content. Previous work applies GANs on neural hidden states to learn disentangled representations Shen et al. (2017); Tikhonov et al. (2019). The task thus can serve as a good benchmark task for GANs, as hidden state modeling provides a new modality that differs from image and text modeling as studied above.

Setup. We follow the same experimental setting and use the same model architecture as the latest work Tikhonov et al. (2019). In particular, Tikhonov et al. (2019) extended the variational autoencoder based model Hu et al. (2017); Kingma and Welling (2013) by adding a latent code discriminator which eliminates stylistic information in the latent code. We replace their adversarial objectives with our proposed ones, and impose the discriminator Lipschitz constraint with a gradient penalty Gulrajani et al. (2017). Our implementation is based on the public code (https://github.com/VAShibaev/text_style_transfer) released by Tikhonov et al. (2019). We test our approach on sentiment transfer, in which the sentiment (positive or negative) is treated as the style of the text. We use the standard Yelp review dataset (www.yelp.com/dataset), and the human-written output text provided by Li et al. (2018) as the ground truth for evaluation.

Results. Following the previous work Tikhonov et al. (2019), we first report the BLEU score that measures the similarity of the generated samples against the human-written text. Table 4 shows that our approach achieves the best performance, improving the state-of-the-art result Tikhonov et al. (2019) from 32.82 BLEU to 33.45.

The second widely used evaluation protocol measures (1) the style accuracy, by applying a pre-trained style classifier to the generated text, and (2) the content preservation, by computing the BLEU score between the generated text and the original input text (BLEU-X). There is often a trade-off between the two metrics. Figure 4 displays the trade-off achieved by different models. Our results lie in the top-right corner, indicating that our approach achieves the best overall style-content trade-off.

Method                       BLEU
Zhang et al. (2018)          24.48
Tian et al. (2018)           24.90
Subramanian et al. (2018)    31.20
Tikhonov et al. (2019)       32.82
Ours                         33.45 ± .95
Table 4: BLEU scores between model generations and human-written text on the Yelp data. We run our method 5 times and report the average and standard deviation.
Figure 4: Trade-off between style accuracy and content preservation. The orange circles denote our results using varying values for an objective weight Tikhonov et al. (2019) which manages the trade-off.

5 Conclusion

We have presented a new training framework for GANs, derived from a new variational perspective and drawing on rich connections with RL-as-inference. This results in probability ratio clipping for generator updates to discourage overly large changes, and fake sample re-weighting for discriminator updates to stabilize training. Experiments show our approach achieves better results than previous best methods on image generation, text generation, and text style transfer. Our approach also shows more stable training. We are interested in exploring more connections between GANs and other learning paradigms to inspire further techniques for improved GAN training.

References

  • [1] A. Abdolmaleki, J. T. Springenberg, Y. Tassa, R. Munos, N. Heess, and M. Riedmiller (2018) Maximum a posteriori policy optimisation. In ICLR, Cited by: §1, §2, §3.1, §3.1, §3.3.
  • [2] M. Arjovsky, S. Chintala, and L. Bottou (2017) Wasserstein generative adversarial networks. In International Conference on Machine Learning, pp. 214–223. Cited by: §1, §2, §3.1, §6.3.2, §6.3.3, §6.3.4.
  • [3] D. Berthelot, T. Schumm, and L. Metz (2017) Began: boundary equilibrium generative adversarial networks. arXiv preprint arXiv:1703.10717. Cited by: §1.
  • [4] A. Brock, J. Donahue, and K. Simonyan (2018) Large scale gan training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096. Cited by: §1.
  • [5] Y. Burda, R. Grosse, and R. Salakhutdinov (2015) Importance weighted autoencoders. arXiv preprint arXiv:1509.00519. Cited by: §1, §2, §3.3.
  • [6] M. Caccia, L. Caccia, W. Fedus, H. Larochelle, J. Pineau, and L. Charlin (2020) Language gans falling short. In ICLR, Cited by: §1.
  • [7] T. Che, Y. Li, R. Zhang, R. D. Hjelm, W. Li, Y. Song, and Y. Bengio (2017) Maximum-likelihood augmented discrete generative adversarial networks. arXiv preprint arXiv:1702.07983. Cited by: §3.2, §3.3.
  • [8] T. Che, R. Zhang, J. Sohl-Dickstein, H. Larochelle, L. Paull, Y. Cao, and Y. Bengio (2020) Your GAN is secretly an energy-based model and you should use discriminator driven latent sampling. arXiv preprint arXiv:2003.06060. Cited by: §3.3.
  • [9] C. Chu, K. Minami, and K. Fukumizu (2020) Smoothness and stability in gans. In ICLR, Cited by: §1.
  • [10] P. Dayan and G. E. Hinton (1997) Using expectation-maximization for reinforcement learning. Neural Computation 9 (2), pp. 271–278. Cited by: §2.
  • [11] M. P. Deisenroth, G. Neumann, J. Peters, et al. (2013) A survey on policy search for robotics. Foundations and Trends® in Robotics 2 (1–2), pp. 1–142. Cited by: §2.
  • [12] Y. Deng, A. Bakhtin, M. Ott, A. Szlam, and M. Ranzato (2020) Residual energy-based models for text generation. In ICLR, Cited by: §1, §3.3, §3.3.
  • [13] N. Ding and R. Soricut (2017) Cold-start reinforcement learning with softmax policy gradient. In NeurIPS, Cited by: §3.1.
  • [14] W. Fedus, I. Goodfellow, and A. M. Dai (2018) MaskGAN: better text generation via filling in the_. In ICLR, Cited by: §1.
  • [15] X. Gong, S. Chang, Y. Jiang, and Z. Wang (2019) AutoGAN: neural architecture search for generative adversarial networks. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3224–3234. Cited by: §4.1, §4.
  • [16] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial nets. In Advances in neural information processing systems, pp. 2672–2680. Cited by: §1, §2, §3.2.
  • [17] A. Grover, J. Song, A. Kapoor, K. Tran, A. Agarwal, E. J. Horvitz, and S. Ermon (2019) Bias correction of learned generative models using likelihood-free importance weighting. In NeurIPS, Cited by: §1, §1, §2, §3.2.
  • [18] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville (2017) Improved training of wasserstein gans. In Advances in neural information processing systems, pp. 5767–5777. Cited by: §1, §2, §3.3, §3.4, §4.1, §4.2, §4.3, Table 2, Table 3, §4, §6.3.2, 4.
  • [19] J. Guo, S. Lu, H. Cai, W. Zhang, Y. Yu, and J. Wang (2018) Long text generation via adversarial training with leaked information. In Thirty-Second AAAI Conference on Artificial Intelligence. Cited by: §4.2, §4.2, §4.2, Table 2, Table 3.
  • [20] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, and S. Hochreiter (2017) Gans trained by a two time-scale update rule converge to a local nash equilibrium. In Advances in neural information processing systems, pp. 6626–6637. Cited by: §4.1.
  • [21] G. E. Hinton, P. Dayan, B. J. Frey, and R. M. Neal (1995) The "wake-sleep" algorithm for unsupervised neural networks. Science 268 (5214), pp. 1158–1161. Cited by: §3.2.
  • [22] Z. Hu, H. Shi, B. Tan, W. Wang, Z. Yang, T. Zhao, J. He, L. Qin, D. Wang, X. Ma, et al. (2019) Texar: a modularized, versatile, and extensible toolkit for text generation. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics: System Demonstrations, pp. 159–164. Cited by: §6.3.4.
  • [23] Z. Hu, Z. Yang, X. Liang, R. Salakhutdinov, and E. P. Xing (2017) Toward controlled generation of text. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1587–1596. Cited by: §4.3, §4.3.
  • [24] Z. Hu, Z. Yang, R. Salakhutdinov, and E. P. Xing (2018) On unifying deep generative models. In ICLR, Cited by: §3.3.
  • [25] Z. Hu, Z. Yang, R. R. Salakhutdinov, L. Qin, X. Liang, H. Dong, and E. P. Xing (2018) Deep generative models with learnable knowledge constraints. In Advances in Neural Information Processing Systems, pp. 10501–10512. Cited by: §1, §2, §3.1, §3.2, §3.3, §3.3, §6.2.
  • [26] E. Jang, S. Gu, and B. Poole (2017) Categorical reparameterization with gumbel-softmax. In ICLR, Cited by: §4.2.
  • [27] T. Kim and Y. Bengio (2016) Deep directed generative models with energy-based probability estimation. arXiv preprint arXiv:1606.03439. Cited by: §3.3.
  • [28] D. P. Kingma and M. Welling (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. Cited by: §4.3.
  • [29] A. Krizhevsky and G. Hinton (2010) Convolutional deep belief networks on CIFAR-10. Unpublished manuscript 40 (7), pp. 1–9. Cited by: §4.1.
  • [30] S. Levine (2018) Reinforcement learning and control as probabilistic inference: tutorial and review. arXiv preprint arXiv:1805.00909. Cited by: §1, §2, §3.1.
  • [31] J. Li, R. Jia, H. He, and P. Liang (2018) Delete, retrieve, generate: a simple approach to sentiment and style transfer. In Proceedings of NAACL-HLT, pp. 1865–1874. Cited by: §4.3.
  • [32] C. J. Maddison, A. Mnih, and Y. W. Teh (2017) The concrete distribution: a continuous relaxation of discrete random variables. Cited by: §4.2.
  • [33] M. Mathieu, C. Couprie, and Y. LeCun (2015) Deep multi-scale video prediction beyond mean square error. arXiv preprint arXiv:1511.05440. Cited by: §1.
  • [34] T. Miyato, T. Kataoka, M. Koyama, and Y. Yoshida (2018) Spectral normalization for generative adversarial networks. arXiv preprint arXiv:1802.05957. Cited by: §4.1, §4, §6.3.2.
  • [35] R. M. Neal and G. E. Hinton (1998) A view of the em algorithm that justifies incremental, sparse, and other variants. In Learning in graphical models, pp. 355–368. Cited by: §3.1.
  • [36] W. Nie, N. Narodytska, and A. Patel (2018) Relgan: relational generative adversarial networks for text generation. Cited by: §1, §4.2, §4.2, §4.2, Table 2, Table 3.
  • [37] R. Nock, Z. Cranko, A. K. Menon, L. Qu, and R. C. Williamson (2017) F-GANs in an information geometric nutshell. In Advances in Neural Information Processing Systems, pp. 456–464. Cited by: §1.
  • [38] S. Nowozin, B. Cseke, and R. Tomioka (2016) F-gan: training generative neural samplers using variational divergence minimization. In Advances in neural information processing systems, pp. 271–279. Cited by: §1.
  • [39] A. Odena, C. Olah, and J. Shlens (2017) Conditional image synthesis with auxiliary classifier gans. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 2642–2651. Cited by: §6.3.4.
  • [40] A. Radford, L. Metz, and S. Chintala (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434. Cited by: §1, Figure 3, §4.1, §6.3.2.
  • [41] K. Rawlik, M. Toussaint, and S. Vijayakumar (2013) On stochastic optimal control and reinforcement learning by approximate inference. In IJCAI, Cited by: §2.
  • [42] T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, and X. Chen (2016) Improved techniques for training gans. In Advances in neural information processing systems, pp. 2234–2242. Cited by: §4.1.
  • [43] A. Sanyal, P. H. Torr, and P. K. Dokania (2020) Stable rank normalization for improved generalization in neural networks and gans. In ICLR, Cited by: §4.1, §4.
  • [44] J. Schulman, S. Levine, P. Abbeel, M. Jordan, and P. Moritz (2015) Trust region policy optimization. In International conference on machine learning, pp. 1889–1897. Cited by: §1, §2, §3.1.
  • [45] J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov (2017) Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347. Cited by: §1, §2, §3.1, §3.2.
  • [46] T. Shen, T. Lei, R. Barzilay, and T. Jaakkola (2017) Style transfer from non-parallel text by cross-alignment. In Advances in neural information processing systems, pp. 6830–6841. Cited by: §4.3.
  • [47] S. Subramanian, G. Lample, E. M. Smith, L. Denoyer, M. Ranzato, and Y. Boureau (2018) Multiple-attribute text style transfer. arXiv preprint arXiv:1811.00552. Cited by: §4.3.
  • [48] D. Terjék (2020) Adversarial lipschitz regularization. In ICLR, Cited by: §4.1, §4.
  • [49] Y. Tian, Z. Hu, and Z. Yu (2018) Structured content preservation for unsupervised text style transfer. arXiv preprint arXiv:1810.06526. Cited by: §4.3.
  • [50] A. Tikhonov, V. Shibaev, A. Nagaev, A. Nugmanova, and I. P. Yamshchikov (2019) Style transfer for texts: to err is human, but error margins matter. In EMNLP, Cited by: Figure 4, §4.3, §4.3, §4.3, §4.3, Figure 7, §6.3.4.
  • [51] X. Wei, B. Gong, Z. Liu, W. Lu, and L. Wang (2018) Improving the improved training of wasserstein GANs: a consistency term and its dual effect. arXiv preprint arXiv:1803.01541. Cited by: §1, §2, §3.3, §4.1, §4, §6.3.2, §6.3.2.
  • [52] Q. Xu, G. Huang, Y. Yuan, C. Guo, Y. Sun, F. Wu, and K. Weinberger (2018) An empirical study on evaluation metrics of generative adversarial networks. arXiv preprint arXiv:1806.07755. Cited by: §4.1.
  • [53] L. Yu, W. Zhang, J. Wang, and Y. Yu (2017) Seqgan: sequence generative adversarial nets with policy gradient. In Thirty-First AAAI Conference on Artificial Intelligence, Cited by: §4.2, §4.2, Table 2, Table 3.
  • [54] Z. Zhang, S. Ren, S. Liu, J. Wang, P. Chen, M. Li, M. Zhou, and E. Chen (2018) Style transfer as unsupervised machine translation. arXiv preprint arXiv:1808.07894. Cited by: §4.3.
  • [55] J. Zhao, M. Mathieu, and Y. LeCun (2016) Energy-based generative adversarial network. arXiv preprint arXiv:1609.03126. Cited by: §1.
  • [56] Z. Zhou, J. Liang, Y. Song, L. Yu, H. Wang, W. Zhang, Y. Yu, and Z. Zhang (2019) Lipschitz generative adversarial nets. arXiv preprint arXiv:1902.05687. Cited by: §2.
  • [57] Z. Zhou, Y. Song, L. Yu, H. Wang, J. Liang, W. Zhang, Z. Zhang, and Y. Yu (2018) Understanding the effectiveness of lipschitz-continuity in generative adversarial nets. arXiv preprint arXiv:1807.00751. Cited by: §2, §3.4, §6.2.

6 Appendix

6.1 Proof on the equivalence between Reverse KL Divergence and KL Divergence

We prove that optimizing the KL divergence KL(q ‖ p_θ) is equivalent to optimizing the reverse KL divergence KL(p_θ ‖ q). This provides a guarantee for the approximation that leads to (4).

Claim: Under the assumption that f_φ is Lipschitz, f_φ is bounded because its input is bounded. Let L be the Lipschitz constant of f_φ, and let

(9)

We then show that the two KL divergences differ by a constant. Since the function f_φ is lower- and upper-bounded, there exist constants such that, for any bounded input, the following holds:

(10)

where ① plugs ; ② uses the fact ; ③ uses . The above claim completes the theoretical guarantee on the reverse-KL approximation in (4).

6.2 Proof on the necessity of Lipschitz constraint on the discriminator

Although [25] shows preliminary connections between posterior regularization (PR) and GANs, the proposed PR framework does not provide informative gradients to the generator when treated as a GAN loss. Following [57], we consider the training problem when the discriminator (i.e., f_φ here) is optimal: in that case, the gradient of the generator objective can be very small due to the vanished gradients of the optimal discriminator. This makes it hard to push the generated data distribution p_θ towards the targeted real distribution p_r. The same problem also exists in the corresponding discriminator loss, because

(11)

So if p_θ and p_r are disjoint, we have

(12)

Note that, for samples from one distribution, the corresponding loss term does not depend on the other distribution, and thus its gradient carries no information about it; the same holds in the other direction. Therefore, the proposed loss in [25] cannot guarantee informative gradients [57] that push p_θ towards p_r.

6.3 Experiments: More Details and Results

6.3.1 Binary classifier for probability ratio clipping

For image generation and text generation, the binary classifier C_ω in Eq.(6) has the same architecture as the discriminator except for an additional softmax activation at the output layer. The binary classifier is trained with real and fake mini-batches alongside the generator, and requires no additional loops.

In addition, in the task of image generation, we observe similar overall performance between training on raw inputs from the generator/dataset and training on input features from the first residual block of the discriminator, which further reduces the computational overhead of the binary classifier.
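For illustration only (not the exact released architecture), the classifier can be implemented as a small softmax head on top of shared discriminator features; `first_block` and `feature_dim` below are hypothetical stand-ins.

```python
import torch.nn as nn

class RatioClassifier(nn.Module):
    """Binary real/fake classifier C_omega used for the ratio estimate (a sketch).

    It mirrors the discriminator but ends in a softmax over {real, fake}; to cut
    overhead it can consume features from the discriminator's first residual
    block instead of raw inputs, as noted above.
    """
    def __init__(self, first_block: nn.Module, feature_dim: int):
        super().__init__()
        self.first_block = first_block  # e.g. the discriminator's first block
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_dim, 2),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.head(self.first_block(x))
```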

6.3.2 Image Generation on CIFAR-10

We translate the code (github.com/biuyq/CT-GAN) provided by Wei et al. [51] into PyTorch to conduct our experiments. We use the same architecture: a residual architecture for both the generator and the discriminator, and we enforce the Lipschitz constraint on the discriminator in the same way as CT-GAN [51]. During training, we interleave 5 generator iterations with 5 discriminator iterations. We optimize the generator and the discriminator with Adam. We set the clipping threshold ε for the surrogate loss and linearly anneal the learning rate with respect to the number of training epochs.

Discriminator sample re-weighting stabilizes DCGAN

We quantitatively evaluate the effect of discriminator re-weighted sampling by comparing DCGAN [40] against DCGAN with discriminator re-weighting. Starting from the DCGAN architecture and hyper-parameters, we run 200 random configurations of learning rate, batch size, non-linearity (ReLU/LeakyReLU), and base filter count (32, 64). Results are summarized in Table 5. DCGANs trained with re-weighted sampling have a significantly lower collapse rate and achieve better overall performance in terms of Inception Score. These results demonstrate the effectiveness of the proposed discriminator re-weighted sampling mechanism.

Method Collapse rate Avg IS Best IS
DCGAN 52.4% 4.2 6.1
DCGAN + Re-weighting 30.2% 5.1 6.7
Table 5: Outcomes of 200 trials with random configurations. The performance of the models is measured by Inception Score. We identify training collapse when the average discriminator loss over 2000 batches falls below or above fixed thresholds. DCGAN re-weighted with our loss has a lower collapse rate and higher average performance.
Discriminator re-weighted samples

To illustrate how the discriminator weights help the discriminator concentrate on better-quality fake samples during training, in Figure 5 we plot the fake samples of a trained ResNet model alongside their corresponding discriminator weights.

Figure 5: One batch of generated images together with their corresponding softmax discriminator weights. The more photo-realistic images (columns 2, 3, 5, 8) receive higher discriminator weights. In this batch, the generator will be influenced more by gradients from the better-quality samples above.
Clipped surrogate objective

One unique benefit of the clipped surrogate objective is that it provides an estimate of the effectiveness of the discriminator, which enables us to follow a curriculum that takes more than one generator step per critic step. In practice, setting the clipping threshold appropriately achieves good quality while also allowing us to take more generator steps than prior works [2, 18, 51, 34] with the same number of discriminator iterations. Table 1 shows the improvement enabled by applying the surrogate objective.

Generated samples

Figure 6 shows more image samples by our model.


Figure 6: More samples from our generator on CIFAR-10

6.3.3 Text Generation

We build upon the PyTorch implementation of RelGAN (github.com/williamSYSU/TextGAN-PyTorch). We use the exact same model architecture as provided in the code, and enforce the Lipschitz constraint on the discriminator in the same way as in WGAN-GP [2].

During training, we interleave 5 generator iterations with 5 discriminator iterations. We use the Adam optimizer (generator lr: 1e-4, discriminator lr: 3e-4). We set the clipping threshold ε for the surrogate loss and linearly anneal the learning rate with respect to the number of training epochs.
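A minimal sketch of this optimizer setup with linear learning-rate annealing is shown below; the placeholder modules, the number of epochs, and the per-epoch scheduling granularity are illustrative assumptions, not the exact released configuration.

```python
import torch
import torch.nn as nn

# Placeholder modules standing in for the actual RelGAN generator / discriminator.
generator = nn.Linear(32, 32)
discriminator = nn.Linear(32, 1)

g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=3e-4)

# Linear annealing of both learning rates over the training epochs.
total_epochs = 100
linear = lambda epoch: max(0.0, 1.0 - epoch / total_epochs)
g_sched = torch.optim.lr_scheduler.LambdaLR(g_opt, lr_lambda=linear)
d_sched = torch.optim.lr_scheduler.LambdaLR(d_opt, lr_lambda=linear)

for epoch in range(total_epochs):
    # ... interleaved generator / discriminator updates for this epoch ...
    g_sched.step()
    d_sched.step()
```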

6.3.4 Text Style Transfer


Figure 7: Model architecture from [50], where the style discriminator is a structured constraint that the generator optimizes against. A latent code discriminator ensures the independence between the semantic part of the latent representation and the style of the text. Blue dashed arrows denote additional independence constraints between the latent representation and the controlled attribute; see [50] for details.

We build upon the Texar-TensorFlow [22] style-transfer model by Tikhonov et al. [50] (https://github.com/VAShibaev/text_style_transfer). We use the exact same model architecture and hyper-parameters as provided in the code, and enforce the Lipschitz constraint on the discriminator in the same way as WGAN-GP [2]. In addition, we replace the discriminator in Figure 7 with our loss, together with an auxiliary linear style classifier as in Odena et al. [39]. We did not apply the surrogate loss to approximate the KL divergence, but instead relied on gradient clipping on the generator.