Self-Attention Generative Adversarial Networks Implementation in PyTorch
In this paper, we propose the Self-Attention Generative Adversarial Network (SAGAN), which allows attention-driven, long-range dependency modeling for image generation tasks. Traditional convolutional GANs generate high-resolution details as a function of only spatially local points in lower-resolution feature maps. In SAGAN, details can be generated using cues from all feature locations. Moreover, the discriminator can check that highly detailed features in distant portions of the image are consistent with each other. Furthermore, recent work has shown that generator conditioning affects GAN performance. Leveraging this insight, we apply spectral normalization to the GAN generator and find that this improves training dynamics. The proposed SAGAN achieves state-of-the-art results, boosting the best published Inception score from 36.8 to 52.52 and reducing the Fréchet Inception distance from 27.62 to 18.65 on the challenging ImageNet dataset. Visualization of the attention layers shows that the generator leverages neighborhoods that correspond to object shapes rather than local regions of fixed shape.
Image synthesis is an important problem in computer vision. There has been remarkable progress in this direction with the emergence of Generative Adversarial Networks (GANs) goodfellow2014generative . GANs based on deep convolutional networks Radford15 ; KarrasALL18 ; Han17stackgan2 have been especially successful. However, by carefully examining the generated samples from these models, we can observe that convolutional GANs Odena2016 ; Miyato18a ; Miyato18b have much more difficulty modeling some image classes than others when trained on multi-class datasets (e.g., ImageNet ILSVRC15 ). For example, while the state-of-the-art ImageNet GAN model Miyato18b excels at synthesizing image classes with few structural constraints (e.g., ocean, sky and landscape classes, which are distinguished more by texture than by geometry), it fails to capture geometric or structural patterns that occur consistently in some classes (for example, dogs are often drawn with realistic fur texture but without clearly defined separate feet). One possible explanation is that previous models rely heavily on convolution to model the dependencies across different image regions. Since the convolution operator has a local receptive field, long-range dependencies can only be processed after passing through several convolutional layers. This could prevent learning about long-range dependencies for a variety of reasons: a small model may not be able to represent them, optimization algorithms may have trouble discovering parameter values that carefully coordinate multiple layers to capture these dependencies, and these parameterizations may be statistically brittle and prone to failure when applied to previously unseen inputs. Increasing the size of the convolution kernels can increase the representational capacity of the network, but doing so also loses the computational and statistical efficiency obtained by using local convolutional structure. Self-attention Cheng16 ; ParikhT0U16 ; Ashish17 , on the other hand, exhibits a better balance between the ability to model long-range dependencies and computational and statistical efficiency. The self-attention module calculates the response at a position as a weighted sum of the features at all positions, where the weights, or attention vectors, are calculated with only a small computational cost.
In this work, we propose Self-Attention Generative Adversarial Networks (SAGANs), which introduce a self-attention mechanism into convolutional GANs. The self-attention module is complementary to convolutions and helps with modeling long-range, multi-level dependencies across image regions. Armed with self-attention, the generator can draw images in which fine details at every location are carefully coordinated with fine details in distant portions of the image. Moreover, the discriminator can also more accurately enforce complicated geometric constraints on the global image structure.
In addition to self-attention, we also incorporate recent insights relating network conditioning to GAN performance. Odena18 showed that well-conditioned generators tend to perform better. We propose enforcing good conditioning of GAN generators using the spectral normalization technique that has previously been applied only to the discriminator Miyato18a .
We have conducted extensive experiments on the ImageNet dataset to validate the effectiveness of the proposed self-attention mechanism and stabilization techniques. SAGAN significantly outperforms the state of the art in image synthesis by boosting the best reported Inception score from 36.8 to 52.52 and reducing Fréchet Inception distance from 27.62 to 18.65. Visualization of the attention layers shows that the generator leverages neighborhoods that correspond to object shapes rather than local regions of fixed shape.
Generative Adversarial Networks. GANs have achieved great success in various image generation tasks, including image-to-image translation pix2pix2017 ; cyclegan2017 ; Taigmaniclr17 ; LiuT16 , image super-resolution Christian2016 ; Casper2016 and text-to-image synthesis reed2016generative ; reed2016learning ; Han16 . Despite this success, the training of GANs is known to be unstable and sensitive to the choices of hyper-parameters. Several works have attempted to stabilize the GAN training dynamics and improve the sample diversity by designing new network architectures Radford15 ; Han16 ; KarrasALL18 , modifying the learning objectives and dynamics Martin17WGAN ; Salimans18 ; metz2017unrolled ; CheLJBL16 ; Zhao2016 , adding regularization methods GulrajaniAADC17 ; Miyato18a and introducing heuristic tricks Salimans2016 ; Odena2016 . Recently, Miyato et al. Miyato18a proposed limiting the spectral norm of the weight matrices in the discriminator in order to constrain the Lipschitz constant of the discriminator function. Combined with the projection-based discriminator Miyato18b , the spectrally normalized model greatly improves class-conditional image generation on ImageNet.
Attention Models. Recently, attention mechanisms have become an integral part of models that must capture global dependencies Dzmitry14 ; XuBKCCSZB15 ; YangHGDS16 ; Gregor15DRAW . In particular, self-attention Cheng16 ; ParikhT0U16 , also called intra-attention, calculates the response at a position in a sequence by attending to all positions within the same sequence. Vaswani et al. Ashish17 demonstrated that machine translation models could achieve state-of-the-art results by solely using a self-attention model. Parmar et al. Parmar18 proposed an Image Transformer model to add self-attention into an autoregressive model for image generation. Wang et al. Wang18 formalized self-attention as a non-local operation to model the spatial-temporal dependencies in video sequences. In spite of this progress, self-attention has not yet been explored in the context of GANs. (AttnGAN Xu18 uses attention over word embeddings within an input sequence, but not self-attention over internal model states.) SAGAN learns to efficiently find global, long-range dependencies within internal representations of images.
Most GAN-based models Radford15 ; Salimans2016 ; KarrasALL18 for image generation are built using convolutional layers. Convolution processes the information in a local neighborhood, thus using convolutional layers alone is computationally inefficient for modeling long-range dependencies in images. In this section, we adapt the non-local model of Wang18 to introduce self-attention to the GAN framework, enabling both the generator and the discriminator to efficiently model relationships between widely separated spatial regions.
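The image features from the previous hidden layer $\mathbf{x} \in \mathbb{R}^{C \times N}$ are first transformed into two feature spaces $f, g$ to calculate the attention, where $f(\mathbf{x}) = \mathbf{W}_f \mathbf{x}$ and $g(\mathbf{x}) = \mathbf{W}_g \mathbf{x}$:

$$\beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{k=1}^{N} \exp(s_{kj})}, \quad \text{where } s_{ij} = f(\mathbf{x}_i)^{\top} g(\mathbf{x}_j),$$

and $\beta_{j,i}$ indicates the extent to which the model attends to the $i$-th location when synthesizing the $j$-th region. Here, $C$ is the number of channels and $N$ is the number of feature locations in the previous hidden layer. Then the output of the attention layer is $\mathbf{o} = (\mathbf{o}_1, \mathbf{o}_2, \ldots, \mathbf{o}_N) \in \mathbb{R}^{C \times N}$, where

$$\mathbf{o}_j = \sum_{i=1}^{N} \beta_{j,i}\, h(\mathbf{x}_i), \quad h(\mathbf{x}_i) = \mathbf{W}_h \mathbf{x}_i.$$

In the above formulation, $\mathbf{W}_g \in \mathbb{R}^{\bar{C} \times C}$, $\mathbf{W}_f \in \mathbb{R}^{\bar{C} \times C}$ and $\mathbf{W}_h \in \mathbb{R}^{C \times C}$ are the learned weight matrices, which are implemented as 1×1 convolutions. We use $\bar{C} = C/8$ in all our experiments.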
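A minimal NumPy sketch of the attention computation, assuming a single image whose feature map is already flattened to shape (C, N), random weights standing in for learned ones, W_f and W_g of shape (C/8, C), and W_h of shape (C, C). The function names are our own, the batch dimension is omitted, and each 1×1 convolution is written as a matrix product applied at every location; it is an illustration, not the authors' implementation.

```python
import numpy as np


def softmax(s, axis):
    # Subtract the max before exponentiating for numerical stability.
    s = s - s.max(axis=axis, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, W_f, W_g, W_h):
    """x: (C, N) features at N = H*W locations.
    W_f, W_g: (C_bar, C); W_h: (C, C). Returns o (C, N) and beta (N, N)."""
    f = W_f @ x            # column i is f(x_i)
    g = W_g @ x            # column j is g(x_j)
    h = W_h @ x            # column i is h(x_i)
    s = f.T @ g            # s[i, j] = f(x_i)^T g(x_j)
    beta = softmax(s, axis=0)  # beta[i, j] = beta_{j,i}; each column sums to 1
    o = h @ beta           # column j is o_j = sum_i beta_{j,i} h(x_i)
    return o, beta


rng = np.random.default_rng(0)
C, H, W = 32, 4, 4
C_bar = C // 8                                # assumed reduction C_bar = C/8
x = rng.standard_normal((C, H * W))
W_f = 0.1 * rng.standard_normal((C_bar, C))   # random stand-ins for learned weights
W_g = 0.1 * rng.standard_normal((C_bar, C))
W_h = 0.1 * rng.standard_normal((C, C))
o, beta = self_attention(x, W_f, W_g, W_h)
print(o.shape, beta.shape)  # (32, 16) (16, 16)
```

Column j of `beta` is the attention distribution for query location j; reshaping it to (H, W) gives a per-query map of the kind shown in the attention visualizations.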
In addition, we further multiply the output of the attention layer by a scale parameter and add back the input feature map. Therefore, the final output is given by

$$\mathbf{y}_i = \gamma\, \mathbf{o}_i + \mathbf{x}_i,$$

where $\gamma$ is a learnable scalar initialized as 0. This allows the network to first rely on the cues in the local neighborhood, since this is easier, and then gradually learn to assign more weight to the non-local evidence. The intuition is straightforward: we want to learn the easy task first and then progressively increase the complexity of the task. In SAGAN, the proposed attention module is applied to both the generator and the discriminator, which are trained in an alternating fashion by minimizing the hinge version of the adversarial loss lim2017 ; Tran2017 ; Miyato18a ,

$$L_D = -\mathbb{E}_{(x,y)\sim p_{\mathrm{data}}}\left[\min(0, -1 + D(x,y))\right] - \mathbb{E}_{z\sim p_z,\, y\sim p_{\mathrm{data}}}\left[\min(0, -1 - D(G(z),y))\right],$$

$$L_G = -\mathbb{E}_{z\sim p_z,\, y\sim p_{\mathrm{data}}}\, D(G(z), y).$$
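The residual connection y_i = γ o_i + x_i and the hinge losses can be sketched in NumPy as follows, assuming the discriminator outputs D(x, y) for a batch are already available as arrays. The helper names and the sample values are hypothetical, and expectations are replaced by batch means.

```python
import numpy as np


def attention_residual(x, o, gamma):
    # y_i = gamma * o_i + x_i; with gamma = 0 the block is the identity map.
    return gamma * o + x


def d_hinge_loss(d_real, d_fake):
    # L_D = -E[min(0, -1 + D(x, y))] - E[min(0, -1 - D(G(z), y))]
    loss_real = -np.minimum(0.0, -1.0 + d_real).mean()
    loss_fake = -np.minimum(0.0, -1.0 - d_fake).mean()
    return float(loss_real + loss_fake)


def g_hinge_loss(d_fake):
    # L_G = -E[D(G(z), y)]
    return float(-d_fake.mean())


d_real = np.array([2.0, 0.5])   # hypothetical D outputs on real (x, y) pairs
d_fake = np.array([-1.5, 0.0])  # hypothetical D outputs on generated pairs
print(d_hinge_loss(d_real, d_fake))  # 0.75
print(g_hinge_loss(d_fake))          # 0.75
```

Real samples with D(x, y) ≥ 1 and generated samples with D(G(z), y) ≤ -1 contribute nothing to L_D, so the discriminator stops pushing on examples it already separates by the margin.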
We also investigate two techniques to stabilize the training of GANs on challenging datasets. First, we use spectral normalization Miyato18a in the generator as well as in the discriminator. Second, we confirm that the two-timescale update rule (TTUR) HeuselRUNH17 is effective, and we advocate using it specifically to address slow learning in regularized discriminators.
Miyato et al. Miyato18a originally proposed stabilizing the training of GANs by applying spectral normalization to the discriminator network. Doing so constrains the Lipschitz constant of the discriminator by restricting the spectral norm of each layer. Compared to other normalization techniques, spectral normalization does not require extra hyper-parameter tuning (setting the spectral norm of all weight layers to 1 consistently performs well in practice). Moreover, the computational cost is also relatively small.
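Miyato18a estimates the spectral norm σ(W) with power iteration rather than a full SVD, which is why the cost stays small. A minimal NumPy sketch under that scheme follows; the function name is ours, and the toy matrix and 20 iterations are chosen only so the estimate visibly converges. During training, one iteration per update with a persistent vector u is the usual choice (PyTorch's `torch.nn.utils.spectral_norm` defaults to `n_power_iterations=1`).

```python
import numpy as np


def spectral_normalize(W, u, n_iters=1, eps=1e-12):
    """Return (W / sigma, updated u, sigma), with sigma estimated by power iteration.

    W is flattened to (out_features, -1), the usual treatment of conv kernels.
    u is the running left singular vector estimate carried across training steps.
    n_iters must be at least 1.
    """
    W_mat = W.reshape(W.shape[0], -1)
    for _ in range(n_iters):
        v = W_mat.T @ u
        v = v / (np.linalg.norm(v) + eps)
        u = W_mat @ v
        u = u / (np.linalg.norm(u) + eps)
    sigma = float(u @ W_mat @ v)  # estimate of the largest singular value
    return W / sigma, u, sigma


# Build a 6x4 matrix with known singular values 5, 2, 1.
rng = np.random.default_rng(0)
Q1, _ = np.linalg.qr(rng.standard_normal((6, 3)))
Q2, _ = np.linalg.qr(rng.standard_normal((4, 3)))
W = Q1 @ np.diag([5.0, 2.0, 1.0]) @ Q2.T
W_sn, u, sigma = spectral_normalize(W, rng.standard_normal(6), n_iters=20)
print(round(sigma, 4))  # 5.0
```

After normalization the largest singular value of `W_sn` is 1, which is the "spectral norm of 1" setting described above.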
We argue that the generator can also benefit from spectral normalization, based on recent evidence that the conditioning of the generator is an important causal factor in GAN performance Odena18 . Spectral normalization in the generator can prevent the escalation of parameter magnitudes and avoid unusual gradients. We find empirically that spectral normalization of both generator and discriminator makes it possible to use fewer discriminator updates per generator update, thus significantly reducing the computational cost of training. The approach also shows more stable training behavior.
In previous work, regularization of the discriminator Miyato18a ; GulrajaniAADC17 often slows down the GAN learning process. In practice, methods using regularized discriminators typically require multiple (e.g., 5) discriminator update steps per generator update step during training. Independently, Heusel et al. HeuselRUNH17 have advocated using separate learning rates (TTUR) for the generator and the discriminator. We propose using TTUR specifically to compensate for the problem of slow learning in a regularized discriminator, making it possible to use fewer discriminator steps per generator step. Using this approach, we were able to produce better results given the same wall-clock time.
To evaluate the proposed methods, we conducted extensive experiments on the ILSVRC2012 (ImageNet) dataset ILSVRC15 . First, in Section 5.1, we present experiments designed to evaluate the effectiveness of the two proposed techniques for stabilizing GAN training. Next, the proposed self-attention mechanism is investigated in Section 5.2. Finally, SAGAN is compared with state-of-the-art methods Odena2016 ; Miyato18b on image generation in Section 5.3.
Evaluation metrics. We choose the Inception score (IS) Salimans2016 and the Fréchet Inception distance (FID) HeuselRUNH17 for quantitative evaluation. The Inception score Salimans2016 computes the KL divergence between the conditional class distribution and the marginal class distribution, averaged over generated samples and then exponentiated. A higher Inception score indicates better image quality. We include the Inception score because it is widely used and thus makes it possible to compare our results to previous work. However, it is important to understand that the Inception score has serious limitations: it is intended primarily to ensure that the model generates samples that can be confidently recognized as belonging to a specific class, and that the model generates samples from many classes, not necessarily to assess realism of details or intra-class diversity. FID is a more principled and comprehensive metric, and has been shown to be more consistent with human evaluation in assessing the realism and variation of the generated samples HeuselRUNH17 . FID calculates the Wasserstein-2 distance between the generated images and the real images in the feature space of an Inception-v3 network, modeling each set of features as a multivariate Gaussian. Lower FID values mean closer distances between synthetic and real data distributions. In all our experiments, 50k samples are randomly generated for each model to compute the Inception score and FID.
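A NumPy sketch of both metrics, assuming the classifier probabilities p(y|x) and the feature vectors have already been extracted (in practice from an Inception-v3 network, which is omitted here). The function names are hypothetical, and common reference implementations also average the Inception score over several splits of the samples, which this sketch skips. The matrix square root in FID is computed through symmetric eigendecompositions so that only NumPy is needed.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """probs: (num_samples, num_classes) classifier outputs p(y|x) for generated images."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))  # exp(E_x[KL(p(y|x) || p(y))])


def _sqrtm_psd(A):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition.
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def frechet_distance(feat_real, feat_fake):
    """feat_*: (num_samples, dim) features; each set is modeled as a Gaussian."""
    mu1, mu2 = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    s1 = np.cov(feat_real, rowvar=False)
    s2 = np.cov(feat_fake, rowvar=False)
    # Tr((S1 S2)^{1/2}) equals Tr((S1^{1/2} S2 S1^{1/2})^{1/2}); the latter is symmetric.
    r1 = _sqrtm_psd(s1)
    m = r1 @ s2 @ r1
    tr_covmean = np.trace(_sqrtm_psd((m + m.T) / 2))
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(s1) + np.trace(s2) - 2.0 * tr_covmean)


# Ten samples, each confidently assigned to a different class: IS reaches its maximum of 10.
print(round(inception_score(np.eye(10)), 3))  # 10.0
rng = np.random.default_rng(0)
a = rng.standard_normal((500, 8))
print(round(frechet_distance(a, a + 0.5), 6))  # 2.0 (= 8 dims * 0.5**2)
```

The second example shifts every feature by a constant, so the covariances match and FID reduces to the squared distance between the means.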
[Figure 4: images randomly generated at different iterations (iteration, FID). Panels: Baseline: SN on D (10k, FID=181.84); SN on G/D (10k, FID=93.52); SN on G/D (160k, FID=33.39); SN on G/D (260k, FID=72.41); SN on G/D+TTUR (10k, FID=99.04); SN on G/D+TTUR (160k, FID=40.96); SN on G/D+TTUR (260k, FID=34.62); SN on G/D+TTUR (1M, FID=22.96).]
Network structures and implementation details. All the SAGAN models we train are designed to generate 128×128 images. By default, spectral normalization Miyato18a is used for the layers in both the generator and the discriminator. Similar to Miyato18b , SAGAN uses conditional batch normalization in the generator and projection in the discriminator. For all models, we use the Adam optimizer KingmaB14 with $\beta_1 = 0$ and $\beta_2 = 0.9$ for training. By default, the learning rate for the discriminator is 0.0004 and the learning rate for the generator is 0.0001.
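The TTUR setting amounts to two separate Adam optimizers that differ only in learning rate: 0.0004 for D and 0.0001 for G, each with β1 = 0 and β2 = 0.9, stepped 1:1. The sketch below writes the textbook Adam update out in NumPy to make those hyper-parameters visible; the scalar "parameters" and gradient are hypothetical stand-ins. In a PyTorch implementation the same configuration would typically be `torch.optim.Adam(params, lr=4e-4, betas=(0.0, 0.9))` for D and `lr=1e-4` for G.

```python
import numpy as np


class Adam:
    """Textbook Adam update, written out to show the hyper-parameters."""

    def __init__(self, lr, beta1=0.0, beta2=0.9, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = None  # first-moment estimate
        self.v = None  # second-moment estimate
        self.t = 0

    def step(self, param, grad):
        if self.m is None:
            self.m = np.zeros_like(param)
            self.v = np.zeros_like(param)
        self.t += 1
        self.m = self.beta1 * self.m + (1.0 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1.0 - self.beta2) * grad ** 2
        m_hat = self.m / (1.0 - self.beta1 ** self.t)  # with beta1 = 0 this is just grad
        v_hat = self.v / (1.0 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# TTUR: separate optimizers with a 4x larger learning rate for D, used with 1:1 updates.
opt_d = Adam(lr=4e-4)
opt_g = Adam(lr=1e-4)

w_d = np.array([1.0])   # stand-in for a discriminator parameter
w_g = np.array([1.0])   # stand-in for a generator parameter
grad = np.array([2.0])  # hypothetical gradient, identical for both
w_d = opt_d.step(w_d, grad)  # one discriminator step ...
w_g = opt_g.step(w_g, grad)  # ... then one generator step
print(w_d, w_g)
```

On the first step Adam's update is roughly the learning rate times the sign of the gradient, so with the same gradient the discriminator parameter moves about four times as far as the generator parameter.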
In this section, experiments are conducted to evaluate the effectiveness of the proposed stabilization techniques, i.e., applying spectral normalization (SN) to the generator and utilizing imbalanced learning rates (TTUR). In Figure 3, our models “SN on G/D” and “SN on G/D+TTUR” are compared with a baseline model, which is implemented based on the state-of-the-art image generation method Miyato18a . In this baseline model, SN is only utilized in the discriminator. When we train it with 1:1 balanced updates for the discriminator (D) and the generator (G), the training becomes very unstable, as shown in the leftmost sub-figures of Figure 3. It exhibits mode collapse very early in training. For example, the top-left sub-figure of Figure 4 illustrates some images randomly generated by the baseline model at the 10k-th iteration. Although in the original paper Miyato18a this unstable training behavior is greatly mitigated by using 5:1 imbalanced updates for D and G, the ability to be stably trained with 1:1 balanced updates is desirable for improving the convergence speed of the model. Thus, using our proposed techniques means that the model can produce better results given the same wall-clock time. It also removes the need to search for a suitable update ratio for the generator and discriminator. As shown in the middle sub-figures of Figure 3, adding SN to both the generator and the discriminator greatly stabilized our model “SN on G/D”, even when it was trained with 1:1 balanced updates. However, the quality of samples does not improve monotonically during training. For example, the image quality as measured by FID and IS starts to drop at the 260k-th iteration. Example images randomly generated by this model at different iterations can be found in Figure 4. When we also apply the imbalanced learning rates to train the discriminator and the generator, the quality of images generated by our model “SN on G/D+TTUR” improves monotonically during the whole training process.
As shown in Figure 3 and Figure 4, we do not observe any significant decrease in sample quality or in the FID or the Inception score during one million training iterations. Thus, both quantitative and qualitative results demonstrate the effectiveness of the proposed stabilization techniques for GAN training. They also demonstrate that the effect of the two techniques is at least partly additive. In the remaining experiments, all models use spectral normalization for both the generator and the discriminator and use the imbalanced learning rates to train the generator and the discriminator with 1:1 updates.
To explore the effect of the proposed self-attention mechanism, we build several SAGAN models by adding the self-attention mechanism to different stages of the generator and discriminator (feat_k denotes the k×k feature map). As shown in Table 1, the SAGAN models with the self-attention mechanism at the middle-to-high level feature maps (e.g., feat_32 and feat_64) achieve better performance than the models with the self-attention mechanism at the low level feature maps (e.g., feat_8 and feat_16). For example, the FID of the model “SAGAN, feat_8” is improved from 22.98 to 18.28 by “SAGAN, feat_32”. The reason could be that the network receives more evidence with larger feature maps and enjoys more freedom to choose the conditions. The attention mechanism gives more power to both the generator and the discriminator to directly model the long-range dependencies in the feature maps. Thus, it is complementary to convolutions, whose advantage lies in modeling local dependencies. In addition, the comparison of our SAGAN models and the baseline model without attention (2nd column of Table 1) demonstrates the effectiveness of the proposed self-attention mechanism.
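To make the feature-map sizes concrete, here is a short count (our own arithmetic, not a result reported in the paper). With N = k·k locations in a k×k map, each output location at feat_k can draw on N candidate positions, and the N × N attention map per image and per attention layer grows with the fourth power of k, which also makes attention at the largest maps the most memory-hungry.

```python
def attention_map_sizes(sides):
    """For square k x k feature maps, return {k: (N, N * N)} with N = k * k locations."""
    return {k: (k * k, (k * k) ** 2) for k in sides}


for k, (n, entries) in attention_map_sizes([8, 16, 32, 64]).items():
    print(f"feat_{k}: N = {n} locations, attention map with {entries} entries")
# feat_8: N = 64 locations, attention map with 4096 entries
# feat_16: N = 256 locations, attention map with 65536 entries
# feat_32: N = 1024 locations, attention map with 1048576 entries
# feat_64: N = 4096 locations, attention map with 16777216 entries
```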
Compared with residual blocks with the same number of parameters, the self-attention blocks also achieve better results. For example, the training is not stable when we replace the self-attention block with the residual block in 8×8 feature maps, which leads to a significant decrease in performance (e.g., FID increases from 22.98 to 42.13). Even when training goes smoothly, replacing the self-attention block with the residual block still leads to worse results in terms of FID and Inception score (e.g., FID 18.28 vs. 27.33 in the 32×32 feature map). This comparison demonstrates that the performance improvement given by SAGAN is not simply due to an increase in model depth and capacity.
To better understand what has been learned during the generation process, we visualize the attention weights of the generator in SAGAN for different images. Some sample images with attention are shown in Figure 5 and Figure 1. See the caption of Figure 5 for descriptions of some of the properties of learned attention maps.
SAGAN is also compared with state-of-the-art GAN models Odena2016 ; Miyato18b for class-conditional image generation on ImageNet. As shown in Table 2, our proposed SAGAN achieves the best Inception score and FID. SAGAN significantly improves the best published Inception score from 36.8 to 52.52. The lower FID (18.65) achieved by SAGAN also indicates that SAGAN can better approximate the original image distribution by using the self-attention module to model the global dependencies between image regions. Figure 6 shows some sample images generated by SAGAN.
In this paper, we proposed Self-Attention Generative Adversarial Networks (SAGANs), which incorporate a self-attention mechanism into the GAN framework. The self-attention module is effective in modeling long-range dependencies. In addition, we show that spectral normalization applied to the generator stabilizes GAN training and that TTUR speeds up training of regularized discriminators. SAGAN achieves state-of-the-art performance on class-conditional image generation on ImageNet.
We thank Surya Bhupatiraju for feedback on drafts of this article. We also thank David Berthelot and Tom B. Brown for help with implementation details. Finally, we thank Jakob Uszkoreit, Tao Xu, and Ashish Vaswani for helpful discussions.
DRAW: A recurrent neural network for image generation. In ICML, 2015.
Image-to-image translation with conditional adversarial networks. In CVPR, 2017.
Conditional image synthesis with auxiliary classifier GANs. In ICML, 2017.