Improving the Speed and Quality of GAN by Adversarial Training
Generative adversarial networks (GAN) have shown remarkable results in image generation tasks. High-fidelity class-conditional GAN methods often rely on stabilization techniques that constrain the global Lipschitz continuity. Such regularization leads to less expressive models and slower convergence; other techniques, such as large-batch training, require unconventional computing power and are not widely accessible. In this paper, we develop an efficient algorithm, namely FastGAN (Free AdverSarial Training), to improve the speed and quality of GAN training based on the adversarial training technique. We benchmark our method on CIFAR10, a subset of ImageNet, and the full ImageNet dataset. We choose strong baselines such as SNGAN and SAGAN; the results demonstrate that our training algorithm achieves better generation quality (in terms of the Inception score and Frechet Inception distance) with less overall training time. Most notably, our training algorithm brings ImageNet training to the broader public by requiring only 2-4 GPUs.
Generating high-quality samples from complex distributions is one of the fundamental problems in machine learning. One approach is the generative adversarial network (GAN) [5]. A GAN consists of a generator and a discriminator trained in an adversarial manner. GANs are recognized as an efficient method to sample from high-dimensional, complex distributions given limited examples, though they are hard to train to stable convergence. Major failure modes include gradient vanishing and mode collapse. A long line of research has been conducted to combat instability during training, e.g. [14, 1, 6, 17, 27, 2]. Notably, WGAN [1] and WGAN-GP [6] propose to replace the Jensen-Shannon divergence with the Wasserstein distance, which relieves the gradient vanishing problem; SNGAN [17] takes a step further by replacing the gradient penalty term with spectral normalization layers, improving both the training stability and the sample quality over the prior art. Since then, many successful architectures, including SAGAN [27] and BigGAN [2], have used the spectral normalization layer to stabilize training. The core idea behind WGAN and SNGAN is to constrain the smoothness of the discriminator. Examples include weight clipping, gradient norm regularization, and normalizing the weight matrix by its largest singular value. However, these techniques incur non-negligible computational overhead (higher-order gradients or power iteration).
Apart from spectral normalization, a bundle of other techniques have been proposed to aid training further; among them, the essential ones are the two time-scale update rule (TTUR) [7], self-attention [27], and large-batch training [2]. However, although these techniques improve the Inception score (IS) [22] and Frechet Inception Distance (FID) [7], they do so at the cost of increasing the per-iteration time or requiring careful tuning of two learning rates. Therefore, it would be useful to seek alternative ways to reach the same goal with less overhead. In particular, we answer the following question:
Can we train a plain ResNet generator and discriminator without architectural innovations, spectral normalization, TTUR or big batch size?
In this paper, we give an affirmative answer to the question above. We show that all we need is to add an adversarial training loop to the discriminator. Our contributions can be summarized as follows:
We show a simple training algorithm that accelerates the training speed and improves the image quality all at the same time.
We show that with the training algorithm, we could match or beat the strong baselines (SNGAN and SAGAN) with plain neural networks (the raw ResNet without spectral normalization or self-attention).
Our algorithm is widely applicable: experimental results show that the algorithm works on datasets ranging from tiny (CIFAR10) to large (1000-class ImageNet) using four or fewer GPUs.
Throughout this paper, we use $G$ to denote the generator network and $D$ to denote the discriminator network (and the parameters therein). $x \in \mathbb{R}^d$ is the data in the $d$-dimensional image space; $z$ is a Gaussian random variable; $y$ is a categorical variable indicating the class id. Subscripts $r$ and $f$ indicate real and fake data. For brevity, we denote the Euclidean 2-norm as $\|\cdot\|$. $\mathcal{L}$ is the loss function of the generator or the discriminator, depending on the subscript.
Generative Adversarial Network (GAN) can be seen as a min-max game between two players: the generator ($G$) and the discriminator ($D$). At training time, the generator transforms Gaussian noise into synthetic images to fool the discriminator, while the discriminator learns to distinguish fake data from real data. The loss function is formulated as

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big], \tag{1}$$
where $p_{\text{data}}$ is the distribution of real data, and $z \sim p_z$ is sampled from the standard normal distribution. Our desired solution of (1) is $G(z) \sim p_{\text{data}}$ and $D(x) \equiv 1/2$; in practice, however, we can rarely reach this solution. Early findings show that training $G$ and $D$ with multi-layer CNNs on the GAN loss (1) does not always generate sensible images, even on the simple CIFAR10 data. Researchers tackle this problem through better theoretical understanding and intensive empirical experiments. Theoretical works mainly focus on the local convergence to the Nash equilibrium [16, 15, 25] by analyzing the eigenvalues of the Jacobian matrix derived from gradient descent ascent (GDA). Intuitively, if the eigenvalues $\{\lambda_i\}$ (possibly complex) lie inside the unit circle, i.e. $|\lambda_i| < 1$, and the ratio $|\mathrm{Im}(\lambda_i)| / |\mathrm{Re}(\lambda_i)|$ is small, then the convergence to the local equilibrium will be fast.

| | SNGAN | SAGAN | BigGAN | FastGAN (ours) |
|---|---|---|---|---|
| ResNet | ✓ | ✓ | ✓ | ✓ |
| Wider | ✗ | ✓ | ✓ | ✗ |
| Shared embedding & skip-z | ✗ | ✗ | ✓ | ✗ |
| Self attention | ✗ | ✓ | ✓ | ✗ |
| TTUR | ✗ | ✓ | ✓ | ✗ |
| Spectral normalization | ✓ | ✓ | ✓ | ✗ |
| Orthogonal regularization | ✗ | ✗ | ✓ | ✗ |
| Adversarial training | ✗ | ✗ | ✗ | ✓ |
| Model size | 72.0M | 81.5M | 158.3M | 72.0M |
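To make the eigenvalue intuition concrete, the sketch below runs simultaneous GDA on the toy bilinear game $\min_x \max_y\, xy$, whose equilibrium is $(0, 0)$. This game is our own illustration, not one of the GAN losses discussed here. Its GDA Jacobian is $\begin{bmatrix} 1 & -\eta \\ \eta & 1 \end{bmatrix}$ with eigenvalues $1 \pm i\eta$, which lie outside the unit circle, so every step multiplies the distance to the equilibrium by $\sqrt{1 + \eta^2}$.

```python
import math


def simultaneous_gda(x, y, eta, steps):
    """Simultaneous gradient descent-ascent on the toy game f(x, y) = x * y."""
    for _ in range(steps):
        grad_x, grad_y = y, x  # df/dx = y, df/dy = x
        # x descends, y ascends; both use gradients from the same iterate.
        x, y = x - eta * grad_x, y + eta * grad_y
    return x, y


def gda_spectral_radius(eta):
    """Modulus of the GDA Jacobian eigenvalues 1 +/- i*eta for this game."""
    return abs(complex(1.0, eta))


eta = 0.1
x, y = simultaneous_gda(1.0, 1.0, eta, steps=100)
print(gda_spectral_radius(eta) > 1.0)  # True: eigenvalues outside the unit circle
print(math.hypot(x, y) > math.hypot(1.0, 1.0))  # True: iterates spiral away from (0, 0)
```

Methods such as FR, discussed next, modify the update so that the relevant eigenvalues move back inside the unit circle.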
Empirical findings mainly focus on regularization techniques, evolving from weight clipping [1], gradient norm regularization [6], spectral normalization [17], regularized optimal transport [23] and Lipschitz regularization [29], among others. Our method could also be categorized as a regularizer [20], in which we train the discriminator to be robust to adversarial perturbation on either real or fake images.
A recent correction to the traditional gradient descent ascent (GDA) solver for (1) is the Follow-the-Ridge (FR) algorithm [25]. In FR, the authors add a second-order term that drags the iterates back to the “ridge” of the loss surface, $\{(\theta, \phi) : \nabla_\phi f(\theta, \phi) = 0\}$; specifically,

$$\theta_{t+1} = \theta_t - \eta_\theta \nabla_\theta f(\theta_t, \phi_t), \qquad \phi_{t+1} = \phi_t + \eta_\phi \nabla_\phi f(\theta_t, \phi_t) + \eta_\theta H_{\phi\phi}^{-1} H_{\phi\theta} \nabla_\theta f(\theta_t, \phi_t), \tag{2}$$

where $f$ denotes the GAN loss in (1), $\theta$ and $\phi$ are the parameters of $G$ and $D$, $H_{\phi\phi} = \nabla^2_{\phi\phi} f$, and $H_{\phi\theta} = \nabla^2_{\phi\theta} f$. With the new update rule, the authors show less rotational behavior in training and improved convergence to the local minimax. The update rule (2) involves Hessian-vector products and a Hessian inverse; compared with the simple GDA updates, this is computationally expensive even when the inverse is applied by solving a linear system. Interestingly, under certain assumptions, our adversarial training scheme can be seen as an approximation of (2) with no extra overhead. As a result, FastGAN converges to the local minimax with fewer iterations (and less wall-clock time), even though we drop the bag of tricks (Table 1) commonly used in GAN training.

Adversarial training [4] was initially proposed to enhance the robustness of neural networks. The central idea is to find adversarial examples that maximize the loss and feed them as synthetic training data. Instances include FGSM [4], PGD [13], and free adversarial training [24]. All of them approximately solve the following problem:

$$\min_w \; \mathbb{E}_{(x, y) \sim \mathcal{D}} \Big[ \max_{\|\delta\| \le \epsilon} \ell\big(f_w(x + \delta), y\big) \Big], \tag{3}$$

where $(x, y) \sim \mathcal{D}$ are the training data; $f_w$ is the model parameterized by $w$; $\ell$ is the loss function; and $\delta$ is the perturbation with norm constraint $\|\delta\| \le \epsilon$. Solving (3) efficiently is nontrivial; previous methods solve it with alternating gradient descent-ascent: on each training batch $(x, y)$, they run projected gradient ascent on $\delta$ for a few steps and then one step of gradient descent on $w$ [13]. In this paper, we study how adversarial training helps GAN convergence. To the best of our knowledge, only [28] and [12] are related to our idea. However, [28] addresses unconditional GANs, and neither of them scales to ImageNet. In our experiments, we include [12] as a strong baseline on smaller datasets such as CIFAR10 and a subset of ImageNet.
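The alternating scheme for (3) can be illustrated on a toy logistic-regression model in pure Python. This is a minimal sketch under our own assumptions (linear model, labels in $\{-1, +1\}$, an $\ell_\infty$ ball, made-up data and step sizes), not the implementation of [13]: the inner loop runs a few projected sign-gradient ascent steps on $\delta$, and the outer step performs one gradient descent update on $w$ at the perturbed inputs.

```python
import math


def sign(v):
    return (v > 0) - (v < 0)


def loss_and_grads(w, b, x, y):
    """Logistic loss log(1 + exp(-y * (w.x + b))) and its gradients w.r.t. w, b, x."""
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = math.log1p(math.exp(-margin))
    s = -y / (1.0 + math.exp(margin))  # d loss / d (w.x + b)
    return loss, [s * xi for xi in x], s, [s * wi for wi in w]


def pgd_perturb(w, b, x, y, eps, step, n_steps):
    """Approximate max over ||delta||_inf <= eps of the loss by projected sign ascent."""
    delta = [0.0] * len(x)
    for _ in range(n_steps):
        x_adv = [xi + di for xi, di in zip(x, delta)]
        _, _, _, grad_x = loss_and_grads(w, b, x_adv, y)
        # Ascent step on the loss, then projection back onto the l_inf ball.
        delta = [max(-eps, min(eps, di + step * sign(gi))) for di, gi in zip(delta, grad_x)]
    return delta


def robust_loss(w, b, batch, eps, step, n_steps):
    total = 0.0
    for x, y in batch:
        delta = pgd_perturb(w, b, x, y, eps, step, n_steps)
        total += loss_and_grads(w, b, [xi + di for xi, di in zip(x, delta)], y)[0]
    return total / len(batch)


def adversarial_training_step(w, b, batch, eps, step, n_steps, lr):
    """One outer descent step on (w, b) at the PGD examples found by the inner loop."""
    grad_w, grad_b = [0.0] * len(w), 0.0
    for x, y in batch:
        delta = pgd_perturb(w, b, x, y, eps, step, n_steps)
        _, gw, gb, _ = loss_and_grads(w, b, [xi + di for xi, di in zip(x, delta)], y)
        grad_w = [acc + g for acc, g in zip(grad_w, gw)]
        grad_b += gb
    n = len(batch)
    return [wi - lr * g / n for wi, g in zip(w, grad_w)], b - lr * grad_b / n


batch = [([1.0, 1.0], 1), ([-1.0, -1.0], -1), ([1.0, 0.5], 1), ([-0.5, -1.0], -1)]
w, b = [0.0, 0.0], 0.0
before = robust_loss(w, b, batch, eps=0.1, step=0.05, n_steps=3)
for _ in range(50):
    w, b = adversarial_training_step(w, b, batch, eps=0.1, step=0.05, n_steps=3, lr=0.5)
after = robust_loss(w, b, batch, eps=0.1, step=0.05, n_steps=3)
print(after < before)  # True: the worst-case loss went down
```

Each outer step pays for `n_steps` extra forward/backward passes, which is the cost that free adversarial training [24] avoids by reusing a single backward pass.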
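Our new GAN objective is a three-level min-max-min problem:

$$\min_G \max_D \; \mathbb{E}_{x_r \sim p_{\text{data}}}\Big[\min_{\|\delta\| \le \epsilon} \log D(x_r + \delta)\Big] + \mathbb{E}_{z \sim p_z}\Big[\min_{\|\delta\| \le \epsilon} \log\big(1 - D(G(z) + \delta)\big)\Big]. \tag{4}$$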
Note that this objective function is similar to the one proposed in RobGAN [12]. However, RobGAN cannot scale to large datasets due to the bottleneck introduced by adversarial training (see our experimental results), and it provides no satisfactory explanation of why adversarial training improves GANs. We make two improvements. First, we employ the recent free adversarial training technique [24], which simultaneously updates the adversaries and the model weights. Second, in the label-conditioned GAN setting, we improve the loss function used in RobGAN (both are based on cross-entropy rather than hinge loss). In the following parts, we first show the connection to the follow-the-ridge update rule (2) and then elaborate on the algorithmic details as well as a better loss function.
Recently [25] showed that the Follow-the-Ridge (FR) optimizer improves GAN training, but FR relies on Hessian-vector products and cannot scale to large datasets. We show that solving (4) by one-step adversarial training can be regarded as an efficient simplification of FR, which partially explains why the proposed algorithm can stabilize and speed up GAN training. To show this, we first simplify the GAN training problem as follows:
Inside the minimax game, we replace the generator network by its output $x_f = G(z)$. This dramatically simplifies the notation, as we no longer need to calculate the gradient through $G$.
The standard GAN loss on fake images [5] is written as $\log\big(1 - D(x_f)\big)$. However, we often modify it to $-\log D(x_f)$ to mitigate the gradient vanishing problem.
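We first consider the original GAN loss:

$$\min_{x_f} \max_{\phi} \; f(x_f, \phi) := \mathbb{E}_{x_r \sim p_{\text{data}}}\big[\log D(x_r)\big] + \log\big(1 - D(x_f)\big), \tag{5}$$

where $\phi$ denotes the parameters of $D$.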
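With the FR algorithm, the update rule for $D$ (parameterized by $\phi$) can be written as follows (in this equation, we simplify the FR update: $H_{\phi\phi}^{-1}$ is dropped by setting $H_{\phi\phi} = -I$; the details are shown in the appendix):

$$\phi_{t+1} = \phi_t + \eta_\phi \nabla_\phi f(x_f, \phi_t) - \eta_x H_{\phi x} \nabla_{x_f} f(x_f, \phi_t), \tag{6}$$

where $H_{\phi x} = \nabla^2_{\phi x_f} f$ and $\eta_x$ is the step size of the $x_f$-update $x_f \leftarrow x_f - \eta_x \nabla_{x_f} f$. The last term is regarded as a correction for the “off-the-ridge move” caused by the $x_f$-update. If we decompose it further, we get

$$H_{\phi x} \nabla_{x_f} f = \frac{1}{2} \nabla_\phi \big\|\nabla_{x_f} f\big\|^2 = \frac{\nabla^2_{\phi x} D(x_f)\, \nabla_x D(x_f)}{\big(1 - D(x_f)\big)^2} + \frac{\|\nabla_x D(x_f)\|^2}{\big(1 - D(x_f)\big)^3}\, \nabla_\phi D(x_f). \tag{7}$$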
Since the first-order term in (7) is parallel to $\nabla_\phi D(x_f)$, a direction already present in the gradient term of (6), the second-order term in (7) plays the key role in the fast convergence of the FR algorithm. However, the algorithm involves Hessian-vector products in each iteration and is not very efficient for large data such as ImageNet. Next, we show that our adversarial training can be regarded as a Hessian-free way to achieve almost the same functionality. Recall that in the adversarial training of GAN, the loss (Eq. (4)) becomes (fixing $x_r$ for brevity)

$$\max_\phi \min_{\|\delta\| \le \epsilon} f(x_f + \delta, \phi) \approx \max_\phi \; f(x_f, \phi) - \alpha \big\|\nabla_{x_f} f(x_f, \phi)\big\|^2, \qquad \delta = -\alpha \nabla_{x_f} f(x_f, \phi); \tag{8}$$

here we assume the inner minimizer conducts only one gradient descent update with step size $\alpha$ (similar to the algorithm proposed in the next section, with $\alpha$ small enough that $\|\delta\| \le \epsilon$), and use a first-order Taylor expansion to approximate the inner minimization problem. So the gradient descent/ascent updates will be

$$x_f \leftarrow x_f - \eta_x \nabla_{x_f} f(x_f, \phi), \qquad \phi \leftarrow \phi + \eta_\phi \big(\nabla_\phi f(x_f, \phi) + \Delta\big). \tag{9}$$

Here we define $\Delta$ as the gradient correction term introduced by adversarial training, i.e.

$$\Delta = -\alpha \nabla_\phi \big\|\nabla_{x_f} f(x_f, \phi)\big\|^2 = -2\alpha\, H_{\phi x} \nabla_{x_f} f(x_f, \phi), \qquad H_{\phi x} = \nabla^2_{\phi x_f} f. \tag{10}$$

Comparing (10) with (7), we can see that both are linear combinations of a second-order term and a first-order term, except that FR is calculated on fake images while our adversarial training can be done on $x_r$ or $x_f$. The two become the same up to a scalar if a 1-Lipschitz constraint is enforced on the discriminator $D$, in which case $\|\nabla_x D\| = 1$. This constraint is commonly seen in previous works, including WGAN, WGAN-GP, SNGAN, SAGAN, and BigGAN. However, we do not enforce this constraint explicitly in FastGAN.
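The identity $H_{\phi x}\nabla_x f = \frac{1}{2}\nabla_\phi\|\nabla_x f\|^2$, and the claim that a one-step perturbation adds a correction along $H_{\phi x}\nabla_x f$, can be checked numerically. The sketch below is our own toy check, assuming a scalar input $x$, a logistic discriminator $D(x) = \sigma(wx + b)$ with $\phi = (w, b)$, and the fake-image term $f = \log(1 - D(x))$. It holds $\delta$ fixed when differentiating with respect to $\phi$, as practical adversarial training does; in that case the correction is $-\alpha H_{\phi x}\nabla_x f$ to first order in $\alpha$, the same direction as (10) with a different scalar factor.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def fake_term(x, w, b):
    """f(x, phi) = log(1 - D(x)) for a toy discriminator D(x) = sigmoid(w * x + b)."""
    return -math.log1p(math.exp(w * x + b))


def grad_x(x, w, b):
    """Closed-form df/dx = -sigmoid(w * x + b) * w."""
    return -sigmoid(w * x + b) * w


def grad_phi(fn, w, b, h=1e-5):
    """Central finite-difference gradient of fn(w, b) with respect to phi = (w, b)."""
    return ((fn(w + h, b) - fn(w - h, b)) / (2 * h),
            (fn(w, b + h) - fn(w, b - h)) / (2 * h))


def fr_correction(x, w, b):
    """H_{phi x} * df/dx, where H_{phi x} = d/dphi (df/dx)."""
    h_phi_x = grad_phi(lambda w_, b_: grad_x(x, w_, b_), w, b)
    g = grad_x(x, w, b)
    return tuple(h * g for h in h_phi_x)


def half_grad_of_sq_norm(x, w, b):
    """0.5 * d/dphi (df/dx)^2, which should match fr_correction."""
    return tuple(0.5 * v for v in grad_phi(lambda w_, b_: grad_x(x, w_, b_) ** 2, w, b))


def adversarial_grad(x, w, b, alpha):
    """d/dphi f(x + delta, phi) with the one-step perturbation delta held fixed."""
    delta = -alpha * grad_x(x, w, b)  # one gradient-descent step on D's objective
    return grad_phi(lambda w_, b_: fake_term(x + delta, w_, b_), w, b)


x, w, b, alpha = 0.7, 1.3, -0.2, 1e-3
fr = fr_correction(x, w, b)
clean = grad_phi(lambda w_, b_: fake_term(x, w_, b_), w, b)
adv = adversarial_grad(x, w, b, alpha)
print(fr, half_grad_of_sq_norm(x, w, b))  # the two pairs agree closely
print(tuple(a - c for a, c in zip(adv, clean)), tuple(-alpha * v for v in fr))
```

No Hessian is formed in `adversarial_grad`: the correction appears only through evaluating the ordinary gradient at the perturbed input.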
Our algorithm is described in Algorithm 1; the main difference from classic GAN training is the extra inner for-loop when updating the discriminator. In each iteration, we perform MAX_ADV_STEP steps of adversarial training on the discriminator network. Inside the adversarial training loop, the gradients with respect to the input images and with respect to the discriminator weights can be obtained with a single backpropagation. We train the discriminator with fake data immediately after each adversarial training step. A handy trick is to reuse the same fake images (generated at the beginning of each D_step) multiple times; we found no performance degradation and a faster wall-clock time, since this avoids some unnecessary propagations through $G$.

Contrary to the common belief [7, 27] that it is beneficial (for stability, performance, convergence rate, etc.) to use different learning rates for the generator and the discriminator, our empirical results show that once the discriminator undergoes robust training, it is no longer necessary to tune the learning rates of the two networks separately.
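As a rough illustration of the discriminator loop described above (not the authors' Algorithm 1), the sketch below trains a one-dimensional toy GAN in pure Python with a linear generator $G(z) = az + c$, a linear discriminator score $D(x) = wx + b$, and hinge losses. All model choices, hyperparameters, and names such as `max_adv_step` and `eps` are hypothetical, and only the real data is perturbed here. The structure follows the text: fake samples are generated once per D step and reused, each free adversarial step takes the weight gradient and the input gradient from the same pass on perturbed real data, and a fake-data update follows immediately.

```python
import random


def clip(v, lo, hi):
    """Clamp v to the interval [lo, hi]."""
    return max(lo, min(hi, v))


def sign(v):
    return (v > 0) - (v < 0)


def train_toy_fastgan(iters=200, d_steps=1, max_adv_step=2, eps=0.1,
                      lr_d=0.05, lr_g=0.05, batch=64, seed=0):
    """Toy 1-D GAN with a FastGAN-style discriminator loop (hypothetical setup)."""
    rng = random.Random(seed)
    a, c = 1.0, 0.0  # generator G(z) = a * z + c
    w, b = 0.0, 0.0  # discriminator score D(x) = w * x + b
    delta = [0.0] * batch
    for _ in range(iters):
        for _ in range(d_steps):
            x_real = [rng.gauss(2.0, 0.5) for _ in range(batch)]
            # Fake data is generated once per D step and reused by every adversarial step.
            x_fake = [a * rng.gauss(0.0, 1.0) + c for _ in range(batch)]
            delta = [0.0] * batch  # perturbations on the real batch
            for _ in range(max_adv_step):
                # Real-data hinge loss max(0, 1 - D(x + delta)); one pass gives the
                # gradients w.r.t. (w, b) and the sign of the gradient w.r.t. the input.
                gw = gb = 0.0
                gx = []
                for xr, dr in zip(x_real, delta):
                    active = 1.0 if 1.0 - (w * (xr + dr) + b) > 0 else 0.0
                    gw -= active * (xr + dr) / batch
                    gb -= active / batch
                    gx.append(-active * w)
                w, b = w - lr_d * gw, b - lr_d * gb
                # Free adversarial step: move delta to increase D's loss, within [-eps, eps].
                delta = [clip(d + eps * sign(g), -eps, eps) for d, g in zip(delta, gx)]
                # Immediately train D on the reused fake data: max(0, 1 + D(x_fake)).
                gw = gb = 0.0
                for xf in x_fake:
                    active = 1.0 if 1.0 + (w * xf + b) > 0 else 0.0
                    gw += active * xf / batch
                    gb += active / batch
                w, b = w - lr_d * gw, b - lr_d * gb
        # Generator step on the hinge generator loss -D(G(z)).
        z = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        grad_a = -w * sum(z) / batch
        grad_c = -w
        a, c = a - lr_g * grad_a, c - lr_g * grad_c
    return {"a": a, "c": c, "w": w, "b": b, "delta": delta}


result = train_toy_fastgan()
print(result["c"], result["w"])
```

Because `x_fake` is built once per D step, the generator is evaluated only once no matter how large `max_adv_step` is, which mirrors the wall-clock saving described above.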
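Projection-based loss functions [18] dominate current state-of-the-art GANs (e.g., [26, 2]) owing to their training stability; nevertheless, this does not imply that the traditional cross-entropy loss is an inferior choice. In parallel to other works [3], we believe the ACGAN loss can be as good as the projection loss after slight modifications. First of all, consider the objective of the discriminator, $\max_D \mathcal{L}_r + \mathcal{L}_f$, where $\mathcal{L}_r$ and $\mathcal{L}_f$ are the log-likelihoods on the real and fake minibatches, respectively. For instance, in ACGAN we have

$$\mathcal{L}_r = \mathbb{E}_{x_r}\big[\log P(\text{real}, y_r \mid x_r)\big], \qquad \mathcal{L}_f = \mathbb{E}_{x_f}\big[\log P(\text{fake}, y_f \mid x_f)\big]. \tag{11}$$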
We remark that in the class-conditional case, the discriminator contains two output branches: one for binary classification of real versus fake, the other for multi-class classification of class labels. The log-likelihood should be interpreted under the joint distribution of the two. However, as pointed out in [3, 12], the loss (11) encourages a degenerate solution featuring mode collapse. The solution of TwinGAN [3] is to incorporate another classifier (namely a “twin classifier”) to help the generator promote diversity, while RobGAN [12] removes the classification branch on fake data:

$$\mathcal{L}_r = \mathbb{E}_{x_r}\big[\log P(\text{real}, y_r \mid x_r)\big], \qquad \mathcal{L}_f = \mathbb{E}_{x_f}\big[\log P(\text{fake} \mid x_f)\big]. \tag{12}$$
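The difference between (11) and (12) can be seen on per-sample discriminator outputs. The sketch below is a hypothetical illustration: it assumes the joint log-likelihood factorizes over the two branches, $\log P(\text{real}, y \mid x) = \log P(\text{real} \mid x) + \log P(y \mid x)$, and uses made-up probabilities.

```python
import math
from statistics import fmean


def acgan_terms(real, fake):
    """Eq. (11)-style terms: the class likelihood is used on real AND fake samples.

    Each sample is (p_real, class_probs, label), where p_real = P(real | x) from the
    binary branch and class_probs[k] = P(k | x) from the classification branch.
    """
    l_r = fmean(math.log(p) + math.log(cp[y]) for p, cp, y in real)
    l_f = fmean(math.log(1.0 - p) + math.log(cp[y]) for p, cp, y in fake)
    return l_r, l_f


def robgan_terms(real, fake):
    """Eq. (12)-style terms: the classification branch is removed on fake data."""
    l_r = fmean(math.log(p) + math.log(cp[y]) for p, cp, y in real)
    l_f = fmean(math.log(1.0 - p) for p, _, _ in fake)
    return l_r, l_f


real = [(0.8, [0.5, 0.5], 0)]
fake = [(0.2, [0.25, 0.75], 1)]
# The fake terms differ by exactly the class log-likelihood log(0.75), about -0.2877.
print(acgan_terms(real, fake)[1] - robgan_terms(real, fake)[1])
```

In the ACGAN version, D is rewarded for classifying its own fake samples correctly, which is the term RobGAN drops.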
Overall, FastGAN uses a loss function similar to RobGAN's (12), except that we change the adversarial part from the probability-based form to the hinge loss [11, 17], which reduces the gradient vanishing and instability problems. As for the class-conditional branch, FastGAN inherits the auxiliary classifier from ACGAN, as it is more suitable for adversarial training. However, as reported in prior works [19, 18], a GAN trained with an auxiliary classification loss lacks intra-class diversity. To tackle this problem, we add a KL term to $\mathcal{L}_f$. Therefore, $\mathcal{L}_r$ and $\mathcal{L}_f$ become:

$$\mathcal{L}_r = \mathbb{E}_{x_r}\big[\min\big(0, -1 + D(x_r)\big) + \log P(y_r \mid x_r)\big],$$
$$\mathcal{L}_f = \mathbb{E}_{x_f}\big[\min\big(0, -1 - D(x_f)\big) - \lambda\, \mathrm{KL}\big(\mathcal{U}_K \,\|\, P(\cdot \mid x_f)\big)\big], \tag{13}$$

where $\lambda$ is a coefficient and $\mathcal{U}_K$ is the uniform categorical distribution over all $K$ classes. Our loss (13) differs sharply from the ACGAN loss in (11): in ACGAN, the discriminator is rewarded for assigning high probability to $y_f$, while in FastGAN it is encouraged to assign a uniform probability to all labels in order to enhance the intra-class diversity of the generated samples. Additionally, we found it worthwhile to add another factor $\gamma$ to $\mathcal{L}_G$ to balance image diversity and fidelity, so the generator loss in FastGAN is defined as

$$\mathcal{L}_G = -\mathbb{E}_{x_f}\big[D(x_f) + \gamma \log P(y_f \mid x_f)\big]. \tag{14}$$

Overall, in the improved objectives of FastGAN, $G$ is trained to minimize $\mathcal{L}_G$, while $D$ is trained to maximize $\mathcal{L}_r + \mathcal{L}_f$.
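A minimal sketch of the FastGAN objectives (13) and (14) on per-sample outputs follows. It assumes the SNGAN-style hinge form for the real/fake branch and the KL direction $\mathrm{KL}(\mathcal{U}_K \,\|\, P(\cdot \mid x_f))$; the coefficient names `lam` and `gamma` and the input format `(score, class_probs, label)` are hypothetical.

```python
import math
from statistics import fmean


def kl_uniform(probs):
    """KL(U || P) from the uniform distribution over K classes to probs."""
    k = len(probs)
    return sum(math.log((1.0 / k) / p) for p in probs) / k


def discriminator_objective(real, fake, lam):
    """L_r + L_f, which D maximizes: hinge terms plus class terms."""
    l_r = fmean(min(0.0, -1.0 + s) + math.log(cp[y]) for s, cp, y in real)
    # On fake data, D is pushed toward a uniform class prediction instead of the label.
    l_f = fmean(min(0.0, -1.0 - s) - lam * kl_uniform(cp) for s, cp, _ in fake)
    return l_r + l_f


def generator_loss(fake, gamma):
    """L_G, which G minimizes; gamma weights the classification term."""
    return fmean(-s - gamma * math.log(cp[y]) for s, cp, y in fake)


real = [(2.0, [0.9, 0.1], 0)]
fake = [(-2.0, [0.5, 0.5], 1)]
# Both hinge terms are saturated and the fake prediction is uniform, so this is log(0.9).
print(discriminator_objective(real, fake, lam=1.0))
```

Raising `gamma` strengthens the generator's push toward confidently classified samples (fidelity), while the KL term on the discriminator side counteracts collapse within a class (diversity).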
In this section, we test the performance of FastGAN on a variety of datasets. For the baselines, we choose SNGAN [17] with a projection discriminator, RobGAN [12], and SAGAN [27]. Although stronger GAN models exist, such as BigGAN [2] and LOGAN [26], we do not include them because they require large-batch training (batch size 1k) on much larger networks (see Table 1 for model sizes).
Datasets. We test on the following datasets: CIFAR10, CIFAR100 [10], ImageNet-143 [17, 12], and the full ImageNet [21]. Notably, the ImageNet-143 dataset is a 143-class subset of ImageNet [21], first seen in SNGAN [17]. We use both 64x64 and 128x128 resolutions in our experiments.
Choice of architecture. As our focus is not on architectural innovations, for a fair comparison we copy the ResNet backbone from SNGAN layer by layer (with the spectral normalization layers removed). Hence FastGAN, SNGAN, and RobGAN are directly comparable, whereas SAGAN has a larger model size. For experiments on CIFAR, we follow the architecture in WGAN-GP [6], which is also used in SNGAN [17, 18].
Optimizer. We use Adam [9] for both $G$ and $D$ with a shared learning rate; the momentum parameters $(\beta_1, \beta_2)$ take one setting for CIFAR/ImageNet-143 and another for ImageNet. We use an exponentially decaying learning-rate schedule indexed by the iteration number.
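As a hedged illustration of an exponentially decaying schedule, the sketch below assumes the common form $\eta_t = \eta_0 \cdot r^t$ with a hypothetical per-iteration factor $r$, and derives $r$ from a desired half-life. The base rate and half-life are illustrative values, not the settings used in these experiments.

```python
def exponential_lr(base_lr, decay, step):
    """Learning rate at iteration `step` under exponential decay: base_lr * decay**step."""
    return base_lr * decay ** step


def decay_for_half_life(half_life_iters):
    """Per-iteration decay factor that halves the learning rate every `half_life_iters` steps."""
    return 0.5 ** (1.0 / half_life_iters)


decay = decay_for_half_life(100_000)  # hypothetical half-life
for step in (0, 100_000, 200_000):
    print(step, exponential_lr(2e-4, decay, step))  # hypothetical base rate 2e-4
```

Parameterizing by half-life keeps the schedule interpretable when the total number of iterations changes between datasets.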
Other hyperparameters are attached in appendix.
We report the experimental results in Table 2. Note that SAGAN does not report official results on CIFAR, so we exclude it from this experiment. The metrics are measured after all GANs stop improving; the corresponding training times are listed in Table 2. As we can see, FastGAN outperforms SNGAN and RobGAN on the CIFAR datasets in terms of both IS and FID. Furthermore, to compare the convergence speed, we show the learning curves of all methods in Figure 1, which exhibits a consistent acceleration effect from FastGAN.
| | CIFAR10 IS | CIFAR10 FID | CIFAR10 Time | CIFAR100 IS | CIFAR100 FID | CIFAR100 Time |
|---|---|---|---|---|---|---|
| Real data | | | – | | | – |
| SNGAN | | | | | | |
| RobGAN | | | | | | |
| FastGAN | | | | | | |
| + Revert to RobGAN loss | | | | | | |
| + Disable adv. training | | | | | | |
| + Constant lr. | | | | | | |
| + Disable KL-term (13) | | | | | | |
Next, we study which parts of FastGAN contribute to the performance improvement. To this end, we disable some essential modifications of FastGAN (last four rows in Table 2). We also try different values of the generator coefficient in (14) in the CIFAR100 experiments (Figure 2) and show that it effectively controls the trade-off between diversity and fidelity.
ImageNet-143 is a dataset first introduced in SNGAN. As a test suite, its distribution complexity lies between CIFAR and the full ImageNet, as it has 143 classes at 64x64 or 128x128 pixels. We perform experiments on this dataset in the same way as on CIFAR. The results are shown in Table 3, in which FastGAN outperforms SNGAN and RobGAN in both IS and FID. As for the convergence rate, Figure 3 shows that FastGAN often needs only a fraction of the training time of SNGAN and RobGAN to reach the same scores.
| | 64x64 IS | 64x64 FID | 64x64 Time | 128x128 IS | 128x128 FID | 128x128 Time |
|---|---|---|---|---|---|---|
| Real data | | | – | | | – |
| SNGAN | | | | | | |
| RobGAN | | | | | | |
| FastGAN | | | | | | |
| + Revert to RobGAN loss | | | | | | |
| + Disable KL-term (13) | | | | | | |
| Trained with batch size 64 | IS | FID | Time | Trained with batch size 256 | IS | FID | Time |
|---|---|---|---|---|---|---|---|
| Real data | | | – | Real data | | | – |
| SNGAN | | | | SAGAN | | | |
| FastGAN | | | | FastGAN | | | |
In this experiment, we set SNGAN and SAGAN as the baselines. RobGAN is not included because the original paper contains no ImageNet experiment, nor can we scale RobGAN to ImageNet with the official implementation. The images are scaled and cropped to 128x128 pixels. Notably, SNGAN is trained with batch size 64, while SAGAN is trained with batch size 256. To make a fair comparison, we train FastGAN with both batch sizes and compare it with the corresponding official results. Table 4 shows that FastGAN is generally better on both metrics; only its FID is slightly worse than SAGAN's. Since our model contains no self-attention block, FastGAN is smaller than SAGAN and faster to train.
In this work, we propose FastGAN, which incorporates free adversarial training strategies to reduce the overall training time of GANs. We further modify the loss function to improve generation quality. We test FastGAN on datasets from small to large scale and compare it with strong prior works. FastGAN demonstrates better generation quality with faster convergence in most cases.
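We justify the stability of our simplified FR update rule in Eq. (6), i.e. (writing $x$ for $x_f$)

$$x_{t+1} = x_t - \eta_x \nabla_x f(x_t, \phi_t), \qquad \phi_{t+1} = \phi_t + \eta_\phi \nabla_\phi f(x_t, \phi_t) - \eta_x H_{\phi x} \nabla_x f(x_t, \phi_t). \tag{15}$$

Compared with the original update rule in [25], our update rule essentially sets $H_{\phi\phi} = -I$. Similar to [25], we analyze the spectral radius of the Jacobian of the update. Before that, we consider the Jacobian at a local Nash equilibrium $(x^*, \phi^*)$ (note that [25] assumes a local minimax). From [8], we have the following properties of a local Nash equilibrium:
Assuming $f$ is differentiable, any local Nash equilibrium satisfies $\nabla_x f(x^*, \phi^*) = 0$ and $\nabla_\phi f(x^*, \phi^*) = 0$, where $f$ is the minimax objective function.
Assuming $f$ is twice differentiable, any local Nash equilibrium satisfies $\nabla^2_{xx} f(x^*, \phi^*) \succeq 0$ and $\nabla^2_{\phi\phi} f(x^*, \phi^*) \preceq 0$.
Based on the properties above, the terms multiplied by $\nabla_x f$ vanish at the equilibrium, and the Jacobian of (15) is

$$J = \begin{bmatrix} I - \eta_x H_{xx} & -\eta_x H_{x\phi} \\ \eta_\phi H_{\phi x} - \eta_x H_{\phi x} H_{xx} & I + \eta_\phi H_{\phi\phi} - \eta_x H_{\phi x} H_{x\phi} \end{bmatrix}, \tag{16}$$

where $H_{xx} = \nabla^2_{xx} f$, $H_{x\phi} = H_{\phi x}^\top = \nabla^2_{x\phi} f$, and $H_{\phi\phi} = \nabla^2_{\phi\phi} f$, all evaluated at $(x^*, \phi^*)$. As in [25], with $H_{\phi\phi} = -I$ we have the following similarity transformation:

$$\begin{bmatrix} I & 0 \\ -H_{\phi x} & I \end{bmatrix} J \begin{bmatrix} I & 0 \\ H_{\phi x} & I \end{bmatrix} = \begin{bmatrix} I - \eta_x \big(H_{xx} + H_{x\phi} H_{\phi x}\big) & -\eta_x H_{x\phi} \\ 0 & (1 - \eta_\phi) I \end{bmatrix}. \tag{17}$$

With a small enough $\eta_x$, the Jacobian matrix therefore always has spectral radius $< 1$, given a positive definite $H_{xx} + H_{x\phi} H_{\phi x}$ and $0 < \eta_\phi < 2$.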
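The stability argument can be checked numerically in the scalar case. The sketch below is our own toy example, assuming the quadratic game $f(x, \phi) = \frac{a}{2}x^2 + b\,x\phi - \frac{1}{2}\phi^2$, for which $H_{xx} = a$, $H_{x\phi} = H_{\phi x} = b$, and $H_{\phi\phi} = -1$ exactly. It applies the simplified update $x \leftarrow x - \eta_x \nabla_x f$, $\phi \leftarrow \phi + \eta_\phi \nabla_\phi f - \eta_x H_{\phi x} \nabla_x f$; because this update is linear here, its Jacobian can be read off from the images of the unit vectors, and its eigenvalues should equal $1 - \eta_x(a + b^2)$ and $1 - \eta_\phi$.

```python
import cmath


def simplified_fr_step(x, phi, a, b, eta_x, eta_phi):
    """One simplified-FR step on f = a/2 * x^2 + b * x * phi - phi^2 / 2."""
    grad_x = a * x + b * phi  # df/dx
    grad_phi = b * x - phi    # df/dphi
    h_phi_x = b               # d^2 f / (dphi dx)
    new_x = x - eta_x * grad_x
    new_phi = phi + eta_phi * grad_phi - eta_x * h_phi_x * grad_x
    return new_x, new_phi


def jacobian(a, b, eta_x, eta_phi):
    """The update is linear, so its columns are the images of the unit vectors."""
    col_x = simplified_fr_step(1.0, 0.0, a, b, eta_x, eta_phi)
    col_phi = simplified_fr_step(0.0, 1.0, a, b, eta_x, eta_phi)
    return [[col_x[0], col_phi[0]], [col_x[1], col_phi[1]]]


def eigenvalues_2x2(m):
    tr = m[0][0] + m[1][1]
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    disc = cmath.sqrt(tr * tr - 4.0 * det)
    return (tr + disc) / 2.0, (tr - disc) / 2.0


a, b, eta_x, eta_phi = 0.5, 1.0, 0.1, 0.2
eigs = eigenvalues_2x2(jacobian(a, b, eta_x, eta_phi))
print(sorted(e.real for e in eigs))          # eigenvalues of the Jacobian
print(1 - eta_x * (a + b * b), 1 - eta_phi)  # predicted values (up to rounding)

x, phi = 1.0, 1.0
for _ in range(200):
    x, phi = simplified_fr_step(x, phi, a, b, eta_x, eta_phi)
print(abs(x) < 1e-6 and abs(phi) < 1e-6)  # True: iterates converge to (0, 0)
```

Here $H_{xx} + H_{x\phi}H_{\phi x} = a + b^2 > 0$, so both eigenvalues lie inside the unit circle for these step sizes.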
Throughout all experiments, we keep the number of $D$ updates per iteration and the number of free adversarial training steps in Algorithm 1 fixed. Other hyperparameters are searched manually; we list them in Table 5.
We use four datasets in our experiments: CIFAR10, CIFAR100, ImageNet-143, and the full ImageNet. CIFAR10 and CIFAR100 both contain 50,000 RGB images at 32x32 resolution, in 10 and 100 classes respectively. ImageNet-143 is a subset of ImageNet containing 180,373 RGB images in 143 classes; experiments on it were conducted at 64x64 and 128x128 resolutions. The full ImageNet contains 1,281,167 RGB images in 1000 classes, with approximately 1300 samples per class; our experiments on it ran at 128x128 resolution.
Throughout all experiments, we rescale all image pixel values to a fixed range. In addition, during training we apply random horizontal flips and random resized cropping (with a scale range of 0.8 to 1.0) to the training images.
We choose Inception Score (IS) [22] and Frechet Inception distance (FID) [7] to assess the quality of fake images.
IS uses the pre-trained Inception-v3 network to measure the KL divergence between the conditional class distribution and the marginal class distribution:

$$\mathrm{IS} = \exp\Big(\mathbb{E}_{x \sim p_g}\, \mathrm{KL}\big(p(y \mid x) \,\|\, p(y)\big)\Big), \tag{18}$$

where $x \sim p_g$ are the samples for testing, the conditional class distribution $p(y \mid x)$ is given by the pre-trained Inception-v3 network, and the marginal class distribution is calculated as $p(y) = \int_x p(y \mid x)\, p_g(x)\, dx$. In this work, we use 50k samples to measure the IS in all experiments, including the IS of the real training data.
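Given the class-probability rows $p(y \mid x_i)$ from a classifier (Inception-v3 in this paper), the score $\mathrm{IS} = \exp\big(\mathbb{E}_x \mathrm{KL}(p(y \mid x) \,\|\, p(y))\big)$ is a short computation, with $p(y)$ estimated as the average row. The sketch below takes those probabilities as plain Python lists; running Inception-v3 itself is outside its scope, and the toy inputs are made up.

```python
import math


def inception_score(cond_probs):
    """IS = exp(mean_i KL(p(y|x_i) || p(y))), with p(y) the average of the rows."""
    n, k = len(cond_probs), len(cond_probs[0])
    marginal = [sum(row[j] for row in cond_probs) / n for j in range(k)]
    mean_kl = sum(
        sum(p * math.log(p / q) for p, q in zip(row, marginal) if p > 0)
        for row in cond_probs
    ) / n
    return math.exp(mean_kl)


# Confident, evenly spread predictions reach the maximum score K; identical rows give 1.
print(inception_score([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]))
print(inception_score([[0.5, 0.5], [0.5, 0.5]]))
```

The two extremes show why IS rewards both per-sample confidence and diversity across samples.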
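FID measures the squared 2-Wasserstein (Frechet) distance between two Gaussian distributions fitted to the Inception-v3 features of samples drawn from $p_r$ and $p_g$:

$$\mathrm{FID} = \|\mu_r - \mu_g\|_2^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big), \tag{19}$$

where $\mu_r, \mu_g$ and $\Sigma_r, \Sigma_g$ are the means and covariance matrices of the activations. Here $p_r$ is the real distribution and $p_g$ is the generated distribution. We use the whole dataset to calculate $(\mu_r, \Sigma_r)$. For the generated samples, we randomly sample 5k images in the CIFAR experiments and 50k in the other experiments.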
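The full FID, $\|\mu_r - \mu_g\|_2^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}\big)$, needs a matrix square root (for example `scipy.linalg.sqrtm`). As a dependency-free illustration only, the sketch below assumes diagonal covariances, in which case $(\Sigma_r\Sigma_g)^{1/2}$ is the elementwise product of standard deviations. It is not a substitute for the full FID on Inception-v3 features, and the feature vectors are made up.

```python
import math


def diag_gaussian_stats(features):
    """Per-dimension mean and unbiased variance of a list of feature vectors."""
    n, d = len(features), len(features[0])
    mu = [sum(f[j] for f in features) / n for j in range(d)]
    var = [sum((f[j] - mu[j]) ** 2 for f in features) / (n - 1) for j in range(d)]
    return mu, var


def fid_diagonal(stats_r, stats_g):
    """||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}) for diagonal S_r, S_g."""
    (mu_r, var_r), (mu_g, var_g) = stats_r, stats_g
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    cov_term = sum(vr + vg - 2.0 * math.sqrt(vr * vg) for vr, vg in zip(var_r, var_g))
    return mean_term + cov_term


real = [[0.0, 1.0], [2.0, 3.0], [4.0, 0.5]]
fake = [[v + 1.0 for v in f] for f in real]  # same spread, every mean shifted by 1
print(fid_diagonal(diag_gaussian_stats(real), diag_gaussian_stats(fake)))  # 2.0
```

In the diagonal case the trace term reduces to $\sum_j (\sigma_{r,j} - \sigma_{g,j})^2$, so a pure mean shift contributes only through the first term.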
We ran all experiments on no more than four Nvidia 1080 Ti GPUs, depending on the dataset. We repeated the measurements multiple times and observed little variation. We report the details in Table 6.
| Dataset | Method | Speed (sec./iter.) | Total iter. | Total time (seconds) |
|---|---|---|---|---|
| CIFAR10 (1 GPU) | FastGAN | 95.5 ± 1.6 | 240 | |
| | SNGAN | 245.1 ± 2.4 | 100 | |
| | RobGAN | 385.1 ± 3.1 | 65 | |
| CIFAR100 (1 GPU) | FastGAN | 96.4 ± 0.9 | 240 | |
| | SNGAN | 255.7 ± 1.2 | 100 | |
| | RobGAN | 397.1 ± 1.1 | 65 | |
| ImageNet-143 64px (1 GPU) | FastGAN | 301.3 ± 1.4 | 120 | |
| | SNGAN | 741.2 ± 2.1 | 300 | |
| | RobGAN | 1496.7 ± 4.1 | 120 | |
| ImageNet-143 128px (2 GPUs) | FastGAN | 595.3 ± 2.4 | 120 | |
| | SNGAN | 1310.2 ± 5.9 | 450 | |
| | RobGAN | 2993.2 ± 5.5 | 60 | |
| ImageNet BS=64 (2 GPUs) | FastGAN | 1113.2 ± 5.1 | 1200 | |
| | SNGAN | 3506.2 ± 4.9 | 850 | |
| ImageNet BS=256 (4 GPUs) | FastGAN | 1699.6 ± 4.5 | 650 | |
| | SAGAN | 1793.3 ± 6.5 | 1000 | |
For RobGAN on CIFAR, we set the number of discriminator updates per iteration to 5, the number of adversarial training steps to 3, the PGD bound to 0.006, and the PGD step size to 0.002.
We try different values of the loss coefficient on the CIFAR datasets; the results are shown in Figure 4. Due to the stochasticity of training, the curves are not strictly monotone on either side of their roughly “V”-shaped profile. Nevertheless, our final choice maximizes performance on both CIFAR10 and CIFAR100.
Although FastGAN reaches better IS and FID than the other models in a shorter time, we would still like to see how it performs after full convergence. To this end, we double the training iterations of FastGAN in the ImageNet-143 experiments and present the results in Table 7. It turns out that FastGAN can further improve its IS and FID. Although the training time is doubled compared with the original runs, the total cost is still lower than that of SNGAN and RobGAN.
| | 64x64 IS | 64x64 FID | 64x64 Time | 128x128 IS | 128x128 FID | 128x128 Time |
|---|---|---|---|---|---|---|
| Real data | | | – | | | – |
| SNGAN | | | | | | |
| RobGAN | | | | | | |
| FastGAN | | | | | | |
| FastGAN (cont.) | | | | | | |
Figure: generated samples. Panel scores: (IS 8.87, FID 17.27), (IS 6.26, FID 43.69), (IS 7.42, FID 24.02), (IS 40.41, FID 14.48), (IS 45.94, FID 25.38), (IS 42.54, FID 17.51).
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 11234–11243.
Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083.