Convolutional neural networks (CNNs) lecun_89 are dominating computer vision today, thanks to their powerful capability of convolution (weight-sharing and local-connectivity) and pooling (translation equivariance). Recently, however, Transformer architectures vaswani2017transformer have started to rival CNNs in image chen2020generative; dosovitskiy2021vit; touvron2020deit and video bertasius2021timesformer; arnab2021vivit recognition tasks. In particular, Vision Transformers (ViTs) dosovitskiy2021vit, which interpret an image as a sequence of tokens (analogous to words in natural language), have been shown by Dosovitskiy et al. to achieve comparable classification accuracy with smaller computational budgets (i.e., fewer FLOPs) on the ImageNet benchmark. Different from local-connectivity in CNNs, ViTs rely on globally-contextualized representation where each patch attends to all patches of the same image. ViTs, along with their variants touvron2020deit; tolstikhin2021mlp, though still in their infancy, have demonstrated promising advantages in modeling non-local contextual dependencies Ranftl2021dpt; strudel2021segmenter as well as excellent efficiency and scalability. Since their recent inception, ViTs have been used in various other tasks such as object detection beal2020toward, video recognition bertasius2021timesformer; arnab2021vivit, and multitask pretraining chen2020pre.
In this paper, we are interested in examining whether the task of image generation can be achieved by Vision Transformers without using convolution or pooling, and more specifically, whether ViTs can be used to train generative adversarial networks (GANs) with quality competitive with well-studied CNN-based GANs. To this end, we train GANs with Vanilla-ViT (as in Fig. 2) following the design of the original ViT (dosovitskiy2021vit). The challenge is that GAN training becomes highly unstable when coupled with ViTs, and that adversarial training is frequently hindered by high-variance gradients (or spiking gradients) in the later stage of discriminator training. Furthermore, conventional regularization methods such as gradient penalty WGAN-GP; r1_penalty and spectral normalization miyato2018spectral cannot resolve the instability issue, even though they have proven effective for CNN-based GAN models (shown in Fig. 4). As unstable training is uncommon when CNN-based GANs are trained with appropriate regularization, this presents a unique challenge to the design of ViT-based GANs.
Hence, in this paper, we propose several necessary modifications to stabilize the training dynamics and facilitate convergence of ViT-based GANs. In the discriminator, we revisit the Lipschitz property of self-attention kim2021lipschitz, and further design an improved spectral normalization that enforces Lipschitz continuity. Different from the conventional spectral normalization, which fails to resolve the instability issue, these techniques exhibit high efficacy in stabilizing the training dynamics of the ViT-based discriminators. We conduct ablation studies (Fig. 4 and Table 2(b)) to validate the necessity of the proposed techniques and their central role in achieving stable and superior image generation performance. For the ViT-based generator, we study a variety of architecture designs and discover two key modifications to the layer normalization and output mapping layers. Experiments show that the modified ViT-based generator can better facilitate the adversarial training with both ViT-based and CNN-based discriminators.
We perform experiments on three standard image synthesis benchmarks. The results show that our model, named ViTGAN, outperforms the previous Transformer-based GAN model (jiang2021transgan) by a large margin, and achieves comparable performance to leading CNN-based GANs such as StyleGAN2 Karras2019stylegan2, even without using convolution or pooling. To the best of our knowledge, the proposed ViTGAN model is among the first approaches that leverage Vision Transformers in GANs, and more importantly, the first to demonstrate that such Transformers achieve performance comparable to the state-of-the-art convolutional architectures (Karras2019stylegan2; biggan) on standard image generation benchmarks including CIFAR, CelebA, and LSUN bedroom datasets.
2 Related Work
Generative Adversarial Networks
Generative adversarial networks (GANs) goodfellow2014generative model the target distribution using adversarial learning. Training is typically formulated as a min-max optimization problem minimizing some distance between the real and generated data distributions, e.g., through various $f$-divergences nowozin2016f or integral probability metrics (IPMs) muller1997integral; song2019bridging such as the Wasserstein distance WGAN.
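For concreteness, the min-max problem in its original form goodfellow2014generative can be written as follows, where $G$ is the generator, $D$ is the discriminator, $p_{\text{data}}$ is the data distribution, and $p_{\mathbf{z}}$ is the noise prior. This is the standard formulation from that work, shown only as a reference point; the $f$-divergence and IPM variants replace the two log terms with other functionals.

```latex
\min_{G} \max_{D} \;
  \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\big[\log D(\mathbf{x})\big]
  + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\big[\log\big(1 - D(G(\mathbf{z}))\big)\big]
```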
GAN models are notorious for unstable training dynamics. As a result, numerous efforts have been made to stabilize training and thereby ensure convergence. Common approaches include spectral normalization miyato2018spectral, gradient penalty WGAN-GP; r1_penalty; kodali2017convergence, consistency regularization CRGAN; ICRGAN, and data augmentation zhao2020diffaugment; Karras2020ada; Zhao2020aug; Tran2021OnDA. These techniques were all designed for convolutional neural networks (CNNs) and have only been verified in convolutional GAN models. However, we find that these methods are insufficient for stabilizing the training of Transformer-based GANs. A similar finding was reported in chen2021selfvit on a different task of pretraining. This may stem from ViTs’ exceeding capability dosovitskiy2021vit and their capturing of a different type of inductive bias than CNNs via self-attention kim2021lipschitz; dong2021attention; cordonnier2020multihead. This paper introduces novel techniques to overcome the unstable adversarial training of Vision Transformers.
Vision Transformer (ViT) dosovitskiy2021vit is a convolution-free Transformer that performs image classification over a sequence of image patches. ViT demonstrates the superiority of the Transformer architecture over the classical CNNs by taking advantage of pretraining on large-scale datasets. Afterward, DeiT touvron2020deit improves ViTs’ sample efficiency by knowledge distillation as well as regularization tricks. MLP-Mixer tolstikhin2021mlp further drops self-attention and replaces it with an MLP to mix the per-location features. In parallel, ViT has been extended to various computer vision tasks such as object detection beal2020toward, action recognition in video bertasius2021timesformer; arnab2021vivit, and multitask pretraining chen2020pre. Our work is among the first to exploit Vision Transformers in the GAN model for image generation.
Generative Transformer in Vision
Motivated by the success of GPT-3 gpt3, a few pilot works study image generation using Transformers by autoregressive learning chen2020generative; esser2020taming or cross-modal learning between image and text ramesh2021zero. These methods are different from ours as they model image generation as an autoregressive sequence learning problem. In contrast, our work trains Vision Transformers in the generative adversarial training paradigm. The closest work to ours is TransGAN jiang2021transgan, which presents a pure Transformer-based GAN model. While proposing multi-task co-training and localized initialization for better training, TransGAN neglects key techniques for training stability and underperforms the leading convolutional GAN models by a considerable margin. By virtue of our design, this paper is the first to demonstrate that Transformer-based GANs are able to achieve competitive performance compared to state-of-the-art CNN-based GAN models.
3 Preliminaries: Vision Transformers (ViTs)
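Vision Transformer (dosovitskiy2021vit) is a pure transformer architecture for image classification that operates upon a sequence of image patches. The 2D image $\mathbf{x} \in \mathbb{R}^{H \times W \times C}$ is flattened into a sequence of image patches, following the raster scan, denoted by $\mathbf{x}_p \in \mathbb{R}^{L \times (P^2 \cdot C)}$, where $L = \frac{H \times W}{P^2}$ is the effective sequence length and $P \times P \times C$ is the dimension of each image patch.
Following BERT (devlin2019bert), a learnable classification embedding $\mathbf{x}_{class}$ is prepended to the image sequence along with the added 1D positional embeddings $\mathbf{E}_{pos}$ to formulate the patch embedding $\mathbf{h}_0$. The architecture of ViT closely follows the Transformer architecture vaswani2017transformer:
$$\mathbf{h}_0 = [\mathbf{x}_{class};\, \mathbf{x}_p^1\mathbf{E};\, \mathbf{x}_p^2\mathbf{E};\, \cdots;\, \mathbf{x}_p^L\mathbf{E}] + \mathbf{E}_{pos}, \quad \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\ \mathbf{E}_{pos} \in \mathbb{R}^{(L+1) \times D} \tag{1}$$
$$\mathbf{h}'_\ell = \mathrm{MSA}(\mathrm{LN}(\mathbf{h}_{\ell-1})) + \mathbf{h}_{\ell-1} \tag{2}$$
$$\mathbf{h}_\ell = \mathrm{MLP}(\mathrm{LN}(\mathbf{h}'_\ell)) + \mathbf{h}'_\ell \tag{3}$$
$$\mathbf{y} = \mathrm{LN}(\mathbf{h}_L^0) \tag{4}$$
Here $\ell$ indexes the Transformer blocks, LN denotes layer normalization, and in Equation 4 the subscript selects the output of the last block while the superscript $0$ selects the classification token.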
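Equation 2 applies multi-headed self-attention (MSA). Given learnable matrices $\mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v$ corresponding to query, key, and value representations, a single self-attention head (indexed with $h$) is computed by:
$$\mathrm{Attention}_h(\mathbf{X}) = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_h}}\right)\mathbf{V}, \tag{5}$$
where $\mathbf{Q} = \mathbf{X}\mathbf{W}_q$, $\mathbf{K} = \mathbf{X}\mathbf{W}_k$, and $\mathbf{V} = \mathbf{X}\mathbf{W}_v$. Multi-headed self-attention aggregates information from $H$ self-attention heads by means of concatenation and linear projection, as follows:
$$\mathrm{MSA}(\mathbf{X}) = \mathrm{concat}_{h=1}^{H}\left[\mathrm{Attention}_h(\mathbf{X})\right]\mathbf{W} + \mathbf{b}. \tag{6}$$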
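The single-head and multi-head computations described above can be sketched in NumPy as follows. This is a minimal illustration under simplifying assumptions (one unbatched sequence, random weights), not the implementation used in the paper; the function names and the toy sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(X, W_q, W_k):
    """Scaled dot-product attention matrix of one head, shape (L, L)."""
    Q, K = X @ W_q, X @ W_k
    d_h = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_h), axis=-1)

def attention_head(X, W_q, W_k, W_v):
    """One self-attention head: attention weights applied to the values."""
    return attention_weights(X, W_q, W_k) @ (X @ W_v)

def multi_head_self_attention(X, heads, W_o, b_o):
    """Concatenate all heads, then apply the output linear projection."""
    concat = np.concatenate([attention_head(X, *h) for h in heads], axis=-1)
    return concat @ W_o + b_o

# Toy example (sizes are illustrative only).
rng = np.random.default_rng(0)
seq_len, D, num_heads = 5, 8, 2
d_h = D // num_heads
X = rng.normal(size=(seq_len, D))
heads = [tuple(rng.normal(size=(D, d_h)) for _ in range(3))
         for _ in range(num_heads)]
W_o = rng.normal(size=(num_heads * d_h, D))
b_o = np.zeros(D)

Y = multi_head_self_attention(X, heads, W_o, b_o)   # shape (seq_len, D)
A = attention_weights(X, heads[0][0], heads[0][1])  # each row sums to 1
```

Each row of `A` is a distribution over all tokens of the same sequence, which is the global receptive field that distinguishes self-attention from local convolution.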
Fig. 1 illustrates the architecture of the proposed ViTGAN with a ViT discriminator and a ViT-based generator. We find that directly using ViT as the discriminator makes the training volatile. We introduce techniques to both generator and discriminator to stabilize the training dynamics and facilitate the convergence: (1) regularization on ViT discriminator and (2) new architecture for generator.
4.1 Regularizing ViT-based discriminator
Enforcing Lipschitzness of Transformer Discriminator
Lipschitz continuity plays a critical role in GAN discriminators. It was first brought to attention as a condition to approximate the Wasserstein distance in WGAN WGAN, and was later confirmed in other GAN settings FedusRLDMG18; miyato2018spectral; SAGAN beyond the Wasserstein loss. In particular, ZhouLSYW00Z19 proves that a Lipschitz discriminator guarantees the existence of the optimal discriminative function as well as the existence of a unique Nash equilibrium. A very recent work kim2021lipschitz, however, shows that the Lipschitz constant of the standard dot product self-attention (i.e., Equation 5) layer can be unbounded, rendering Lipschitz continuity violated in ViTs. To enforce Lipschitzness of our ViT discriminator, we adopt the L2 attention proposed in kim2021lipschitz. As shown in Equation 7, we replace the dot product similarity with Euclidean distance and also tie the weights for the projection matrices for query and key in self-attention:
$$\mathrm{Attention}_h(\mathbf{X}) = \mathrm{softmax}\left(\frac{d(\mathbf{X}\mathbf{W}_q, \mathbf{X}\mathbf{W}_k)}{\sqrt{d_h}}\right)\mathbf{X}\mathbf{W}_v, \quad \text{where } \mathbf{W}_q = \mathbf{W}_k. \tag{7}$$
$\mathbf{W}_q$, $\mathbf{W}_k$, and $\mathbf{W}_v$ are the projection matrices for query, key, and value, respectively. $d(\cdot,\cdot)$ computes vectorized $L_2$ distances between two sets of points. $d_h$ is the feature dimension for each head. This modification improves the stability of Transformers when used for GAN discriminators.
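A minimal NumPy sketch of L2 attention with tied query/key weights is given below. It assumes the sign convention of kim2021lipschitz, where the attention logits are the negative squared Euclidean distances scaled by $\sqrt{d_h}$, so that nearby tokens receive larger weights; this convention, the function names, and the toy sizes are assumptions of the sketch, not details stated in the text above.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def pairwise_sq_dist(A, B):
    """Squared Euclidean distance between every row of A and every row of B."""
    a2 = (A ** 2).sum(axis=-1)[:, None]
    b2 = (B ** 2).sum(axis=-1)[None, :]
    return np.maximum(a2 + b2 - 2.0 * A @ B.T, 0.0)  # clamp rounding errors

def l2_attention_head(X, W_qk, W_v):
    """L2 self-attention head; queries and keys share one projection W_qk."""
    Q = X @ W_qk                # tied weights: the keys equal the queries
    d_h = Q.shape[-1]
    # Assumed convention: logits = -||q_i - k_j||^2 / sqrt(d_h).
    logits = -pairwise_sq_dist(Q, Q) / np.sqrt(d_h)
    A = softmax(logits, axis=-1)
    return A @ (X @ W_v), A

rng = np.random.default_rng(0)
seq_len, D, d_h = 6, 8, 4
X = rng.normal(size=(seq_len, D))
W_qk = rng.normal(size=(D, d_h))
W_v = rng.normal(size=(D, d_h))
out, A = l2_attention_head(X, W_qk, W_v)
```

Because the weights are tied, each token is at distance zero from itself, so under this convention the largest weight in every row of `A` falls on the diagonal.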
Improved Spectral Normalization.
To further strengthen the Lipschitz continuity, we also apply spectral normalization (SN) (miyato2018spectral) in the discriminator training. The standard SN uses power iterations to estimate the spectral norm of the projection matrix for each layer in the neural network. Then it divides the weight matrix by the estimated spectral norm, so that the Lipschitz constant of the resulting projection matrix equals 1. We find that Transformer blocks are sensitive to the scale of the Lipschitz constant, and the training exhibits very slow progress when SN is used (Table 2(b)). Similarly, we find that the R1 gradient penalty cripples GAN training when ViT-based discriminators are used (Figure 4). dong2021attention suggests that the small Lipschitz constant of the MLP block may cause the output of the Transformer to collapse to a rank-1 matrix. To resolve this, we propose to increase the spectral norm of the projection matrix.
We find that multiplying the normalized weight matrix of each layer by its spectral norm at initialization is sufficient to solve this problem. Concretely, we use the following update rule for our spectral normalization, where $\sigma$ computes the standard spectral norm of weight matrices:
$$\bar{\mathbf{W}}_{\mathrm{ISN}}(\mathbf{W}) := \sigma(\mathbf{W}_{init}) \cdot \mathbf{W} / \sigma(\mathbf{W}). \tag{8}$$
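The difference between standard and improved spectral normalization can be illustrated with the following NumPy sketch. It is a simplified illustration, not the authors' code: the spectral norm is estimated with many power iterations from a fresh random vector on every call, whereas practical SN implementations typically keep a persistent vector and run one iteration per training step. All function names are hypothetical.

```python
import numpy as np

def spectral_norm(W, n_iters=200, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v) + 1e-12
        u = W @ v
        u /= np.linalg.norm(u) + 1e-12
    return float(u @ W @ v)

def standard_sn(W):
    """Standard SN: rescale W so that its spectral norm is 1."""
    return W / spectral_norm(W)

def improved_sn(W, sigma_init):
    """Improved SN: rescale W to the spectral norm it had at initialization."""
    return sigma_init * W / spectral_norm(W)

rng = np.random.default_rng(1)
W_init = rng.normal(size=(6, 4))
sigma_init = spectral_norm(W_init)            # recorded once, at initialization

W = W_init + 0.5 * rng.normal(size=(6, 4))    # stand-in for trained weights
W_sn = standard_sn(W)                         # spectral norm equals 1
W_isn = improved_sn(W, sigma_init)            # spectral norm equals sigma_init
```

The only change relative to standard SN is the constant factor `sigma_init`, which keeps the per-layer Lipschitz constant fixed at its initial scale instead of forcing it to 1.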
Overlapping Image Patches.
ViT discriminators are prone to overfitting due to their exceeding learning capacity. Our discriminator and generator use the same image representation that partitions an image into a sequence of non-overlapping patches according to a predefined grid $P \times P$. These arbitrary grid partitions, if not carefully tuned, may encourage the discriminator to memorize local cues and stop providing a meaningful loss for the generator. We use a simple trick to mitigate this issue by allowing some overlap between image patches. For each border edge of the patch, we extend it by $o$ pixels, making the effective patch size $(P + 2o) \times (P + 2o)$.
This results in a sequence with the same length but less sensitivity to the predefined grids. It may also give the Transformer a better sense of which ones are neighboring patches to the current patch, hence giving a better sense of locality.
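The patch extraction with and without overlap can be sketched as follows. This is a minimal NumPy illustration; zero-padding at the image border is an assumption of this sketch (the text does not specify the border handling), and the function name is hypothetical.

```python
import numpy as np

def extract_patches(image, P, o=0):
    """Split an (H, W, C) image into a raster-scan sequence of flattened patches.

    Each patch covers its P x P grid cell extended by `o` pixels on every
    border edge, so the effective patch size is (P + 2o) x (P + 2o).
    The image border is zero-padded (an assumption of this sketch).
    """
    H, W, C = image.shape
    padded = np.pad(image, ((o, o), (o, o), (0, 0)))
    patches = []
    for i in range(0, H, P):          # raster scan: rows first
        for j in range(0, W, P):
            patch = padded[i:i + P + 2 * o, j:j + P + 2 * o, :]
            patches.append(patch.reshape(-1))
    return np.stack(patches)

rng = np.random.default_rng(0)
image = rng.normal(size=(32, 32, 3))

plain = extract_patches(image, P=4, o=0)      # (64, 4*4*3)
overlapped = extract_patches(image, P=4, o=2) # (64, 8*8*3)
```

Both calls yield the same sequence length; only the per-patch dimension grows, and the central $P \times P$ region of each overlapping patch coincides with the corresponding non-overlapping patch.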
4.2 Generator Design
Designing a generator based on the ViT architecture is a nontrivial task. A challenge is converting ViT from predicting a set of class labels to generating pixels over a spatial region. Before introducing our model, we discuss two plausible baseline models, as shown in Fig. 2 (a) and 2 (b). Both models swap ViT’s input and output to generate pixels from embeddings, specifically from the latent vector $\mathbf{w}$ derived from a Gaussian noise vector $\mathbf{z}$ by an MLP, i.e., $\mathbf{w} = \mathrm{MLP}(\mathbf{z})$ (called mapping network karras2019style in Fig. 2). The two baseline generators differ in their input sequences. The first baseline takes as input a sequence of positional embeddings and adds the intermediate latent vector $\mathbf{w}$ to every positional embedding. Alternatively, the second baseline prepends the sequence with the latent vector. This design is inspired by inverting ViT, where $\mathbf{w}$ is used to replace the classification embedding in Equation 4.
To generate pixel values, a linear projection is learned in both models to map a $D$-dimensional output embedding to an image patch of shape $P \times P \times C$. The sequence of image patches is finally reshaped to form a whole image $\mathbf{x}$.
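The linear output mapping and the final reshape used by the baseline generators can be sketched as follows. This is a minimal NumPy illustration with random weights and hypothetical names; the embedding dimension in the example is deliberately small and does not reflect the paper's configuration.

```python
import numpy as np

def image_to_patches(image, P):
    """(H, W, C) image -> raster-scan sequence of flattened P x P x C patches."""
    H, W, C = image.shape
    gh, gw = H // P, W // P
    x = image.reshape(gh, P, gw, P, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(gh * gw, P * P * C)

def patches_to_image(patches, H, W, P, C):
    """Inverse of image_to_patches: reassemble the patch sequence into an image."""
    gh, gw = H // P, W // P
    x = patches.reshape(gh, gw, P, P, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(H, W, C)

def linear_output_mapping(Y, W_out, b_out):
    """Map each D-dimensional output embedding to P*P*C pixel values."""
    return Y @ W_out + b_out

rng = np.random.default_rng(0)
H = W = 32
P, C, D = 4, 3, 16                       # D is illustrative only
L = (H // P) * (W // P)                  # sequence length

Y = rng.normal(size=(L, D))              # stand-in for Transformer outputs
W_out = rng.normal(size=(D, P * P * C))
b_out = np.zeros(P * P * C)

patches = linear_output_mapping(Y, W_out, b_out)   # (L, P*P*C)
x = patches_to_image(patches, H, W, P, C)          # (H, W, C)
```

Since every patch is produced independently by the same linear map, nothing in this output layer enforces smoothness inside a patch, which is one motivation for replacing it with an implicit neural representation.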
These baseline transformers perform poorly compared to the CNN-based generator. We propose a novel generator following the design principle of ViT. Our ViTGAN generator, shown in Fig. 2 (c), consists of two components: (1) a transformer block and (2) an output mapping layer.
The proposed generator incorporates two modifications to facilitate the training.
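Instead of sending the noise vector $\mathbf{z}$ as the input to ViT, we use $\mathbf{z}$ to modulate the layernorm operation in Equation 10. This is known as self-modulation (chen2018selfmod) since the modulation depends on no external information. The self-modulated layernorm (SLN) in Equation 10 is computed by:
$$\mathrm{SLN}(\mathbf{h}_\ell, \mathbf{w}) = \gamma_\ell(\mathbf{w}) \odot \frac{\mathbf{h}_\ell - \boldsymbol{\mu}}{\boldsymbol{\sigma}} + \beta_\ell(\mathbf{w}),$$
where $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$ track the mean and the variance of the summed inputs within the layer, and $\gamma_\ell$ and $\beta_\ell$ compute adaptive normalization parameters controlled by the latent vector $\mathbf{w}$ derived from $\mathbf{z}$. $\odot$ is the element-wise product.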
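A minimal NumPy sketch of the self-modulated layernorm follows. It assumes that the adaptive scale and shift are single affine functions of the latent vector and that statistics are computed per token over the feature dimension, as in standard layer normalization; both are assumptions of this sketch, and the parameter names are hypothetical.

```python
import numpy as np

def self_modulated_layernorm(h, w, W_gamma, b_gamma, W_beta, b_beta, eps=1e-5):
    """Layer-normalize h, then scale and shift it using the latent vector w.

    h: (L, D) token embeddings; w: (D_w,) latent vector.
    gamma(w) and beta(w) are modelled as affine maps of w (an assumption).
    """
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    h_norm = (h - mu) / np.sqrt(var + eps)
    gamma = w @ W_gamma + b_gamma      # (D,) adaptive scale
    beta = w @ W_beta + b_beta         # (D,) adaptive shift
    return gamma * h_norm + beta

rng = np.random.default_rng(0)
L, D, D_w = 4, 8, 6
h = rng.normal(size=(L, D))
w = rng.normal(size=D_w)

# With zero modulation weights, SLN reduces to a plain layernorm.
plain = self_modulated_layernorm(
    h, w, np.zeros((D_w, D)), np.ones(D), np.zeros((D_w, D)), np.zeros(D))

# With non-zero weights, the latent vector controls scale and shift.
modulated = self_modulated_layernorm(
    h, w, rng.normal(size=(D_w, D)), np.ones(D),
    rng.normal(size=(D_w, D)), np.zeros(D))
```

The same latent vector modulates every token, so the noise influences the whole sequence without being part of the input sequence itself.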
Implicit Neural Representation for Patch Generation.
We use an implicit neural representation (Park_2019_DeepSDF; occupancy_net; fourfeat2020; sitzmann2020siren) to learn a continuous mapping from a patch embedding $\mathbf{y}^i$ to patch pixel values $\mathbf{x}_p^i$. When coupled with Fourier features (fourfeat2020) or sinusoidal activation functions (sitzmann2020siren), implicit representations can constrain the space of generated samples to the space of smooth-varying natural signals. Concretely, similarly to cips, $\mathbf{x}_p^i = f_\theta(\mathbf{E}_{fou}, \mathbf{y}^i)$, where $\mathbf{E}_{fou}$ is a Fourier encoding of the $P \times P$ spatial locations and $f_\theta(\cdot, \cdot)$ is a 2-layer MLP. For details, please refer to Appendix B. We find implicit representation to be particularly helpful for training GANs with ViT-based generators (Table 2(a)).
It is noteworthy that the generator and discriminator can have different image grids and thus different sequence lengths. We find that it is often sufficient to increase the sequence length or feature dimension of only the discriminator when scaling our model to higher resolution images.
5.1 Experiment Setup
The CIFAR-10 dataset cifar10 is a standard benchmark for image generation, containing 50K training images and 10K test images. Inception score (IS) improved_gans and Fréchet Inception Distance (FID) FID are computed over the 50K images. The LSUN bedroom dataset yu15lsun is a large-scale image generation benchmark, consisting of 3 million training images and 300 images for validation. On this dataset, FID is computed against the training set due to the small validation set. The CelebA dataset CelebA comprises 162,770 unlabeled face images and 19,962 test images. By default, we generate 32×32 images on the CIFAR dataset and 64×64 images on the other two datasets.
For 32×32 resolution, we use a 4-block ViT-based discriminator and a 4-block ViT-based generator. For 64×64 resolution, we increase the number of blocks to 6. Following ViT-Small dosovitskiy2021vit, the input/output feature dimension is 384 for all Transformer blocks, and the MLP hidden dimension is 1,536. Unlike dosovitskiy2021vit, we choose the number of attention heads to be 6. We find that increasing the number of heads does not improve GAN training. For 32×32 resolution, we use patch size 4×4, yielding a sequence length of 64 patches. For 64×64 resolution, we simply increase the patch size to 8×8, keeping the same sequence length as in 32×32 resolution.
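The patch-size arithmetic above can be checked with a few lines of Python. The configuration dictionary below merely restates the numbers given in the text; its key names and the helper function are hypothetical and do not correspond to any released code.

```python
def sequence_length(image_size, patch_size):
    """Number of non-overlapping patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    return (image_size // patch_size) ** 2

# Settings restated from the text (key names are illustrative).
configs = {
    32: {"num_blocks": 4, "patch_size": 4},
    64: {"num_blocks": 6, "patch_size": 8},
}
shared = {"feature_dim": 384, "mlp_hidden_dim": 1536, "num_heads": 6}

seq_lens = {res: sequence_length(res, cfg["patch_size"])
            for res, cfg in configs.items()}
head_dim = shared["feature_dim"] // shared["num_heads"]

print(seq_lens)   # both resolutions give the same sequence length
print(head_dim)   # per-head feature dimension
```

Doubling the resolution while doubling the patch size leaves the sequence length unchanged, so the quadratic cost of self-attention stays the same and only the per-patch pixel count grows.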
Translation, Color, Cutout, Scaling data augmentations zhao2020diffaugment; Karras2020ada are applied with probability . All baseline transformer-based GAN models, including ours, use balanced consistency regularization (bCR) with . Other than bCR, we do not employ regularization methods typically used for training ViTs touvron2020deit such as Dropout, weight decay, or Stochastic Depth. We found that LeCam regularization tseng2021regularizing, similar to bCR, improves the performance. But for clearer ablation, we do not include the LeCam regularization. We train our models with Adam with , , and a learning rate of following the practice of Karras2019stylegan2. In addition, we employ non-saturating logistic loss goodfellow2014generative, exponential moving average of generator weights karras2018progressive, and equalized learning rate karras2018progressive. We use a mini-batch size of .
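Two of the training components named above, the non-saturating logistic loss goodfellow2014generative and the exponential moving average of generator weights, can be sketched as follows. This is a minimal NumPy illustration operating on raw discriminator logits; the decay value used for the moving average is a hypothetical placeholder, not a setting taken from the paper.

```python
import numpy as np

def softplus(x):
    """log(1 + exp(x)), computed stably."""
    return np.logaddexp(0.0, x)

def discriminator_loss(real_logits, fake_logits):
    """Logistic loss: push real logits up and fake logits down."""
    return softplus(-real_logits).mean() + softplus(fake_logits).mean()

def generator_loss(fake_logits):
    """Non-saturating loss: maximize log D(G(z)) instead of minimizing log(1 - D(G(z)))."""
    return softplus(-fake_logits).mean()

def ema_update(ema_params, params, decay=0.999):
    """One EMA step over a dict of weight arrays (decay is a placeholder value)."""
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}

# Toy usage with made-up logits and weights.
real_logits = np.array([2.0, 1.5, 3.0])
fake_logits = np.array([-1.0, 0.5, -2.0])
d_loss = discriminator_loss(real_logits, fake_logits)
g_loss = generator_loss(fake_logits)

params = {"w": np.ones(3)}
ema = {"w": np.zeros(3)}
ema = ema_update(ema, params, decay=0.9)
```

The averaged weights in `ema` are the ones that would be used for evaluation, while the raw `params` continue to receive gradient updates.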
Both ViTGAN and StyleGAN2 are based on a TensorFlow 2 implementation stylegan_tf2. We train our models on Google Cloud TPU v2-32 and v3-8.
5.2 Main Results
|BigGAN biggan + DiffAug zhao2020diffaugment||✓||✓||8.59||9.25||-||-||-||-|
Table 1 shows the main results on three standard benchmarks for image synthesis. Our method is compared with the following baseline architectures. TransGAN jiang2021transgan is the only existing convolution-free GAN that is entirely built on the Transformer architecture. Its best variant TransGAN-XL is compared. Vanilla-ViT is a ViT-based GAN that employs the generator illustrated in Fig. 2 and a vanilla ViT discriminator without any techniques discussed in Section 4.1. For fair comparison, R1 penalty and bCR ICRGAN + DiffAug zhao2020diffaugment were used for this baseline. The architecture with the generator illustrated in Fig. 2 is separately compared in Table 2(a). In addition, BigGAN biggan and StyleGAN2 Karras2019stylegan2 are also included as state-of-the-art CNN-based GAN models.
Our ViTGAN model outperforms other Transformer-based GAN models by a large margin. This results from the improved stability of GAN training on the Transformer architecture, as shown in Fig. 4. It achieves comparable performance to the state-of-the-art CNN-based models. This result provides empirical evidence that Transformer architectures may rival convolutional networks in generative adversarial training. Note that in Table 1, to focus on comparing the architectures, we use a generic version of StyleGAN2. More comprehensive comparisons with StyleGAN2 with data augmentation (zhao2020diffaugment; Karras2020ada) are included in Appendix A.1.
As shown in Fig. 3, the image fidelity of the best Transformer baseline (Middle Row) has been notably improved by the proposed ViTGAN model (Last Row). Even compared with StyleGAN2, ViTGAN generates images with comparable quality and diversity. Notice there appears to be a perceivable difference between the images generated by Transformers and CNNs, in the background of the CelebA images. Both the quantitative results and qualitative comparison substantiate the efficacy of the proposed ViTGAN as a competitive Transformer-based GAN model.
5.3 Ablation Studies
We conduct ablation experiments on the CIFAR dataset to study the contributions of the key techniques and verify the design choices in our model.
Compatibility with CNN-based GAN
In Table 2, we mix and match the generator and discriminator of our ViTGAN and the leading CNN-based GAN: StyleGAN2. With the StyleGAN2 generator, our ViTGAN discriminator outperforms the vanilla ViT discriminator. Besides, our ViTGAN generator still works together with the StyleGAN2 discriminator. The results show the proposed techniques are compatible with both Transformer-based and CNN-based generators and discriminators.
Table 2(a) shows GAN performances under three different generator architectures, as shown in Figure 2. One of the two baseline generators underperforms the other architectures, while the other baseline works well but lags behind our generator in Fig. 2 (c) due to its instability. Regarding the mapping between patch embedding and pixels, it seems consistently better to use an implicit neural representation (denoted as NeurRep in Table 2(a)) than a linear mapping. This observation substantiates our claim that implicit neural representation benefits the training of GANs with ViT-based generators.
Table 2(b) validates the necessity of the techniques discussed in Section 4.1. First, we compare GAN performances under different regularization methods. Training GANs with ViT discriminator under R1 penalty r1_penalty is highly unstable, as shown in Figure 4, sometimes resulting in complete training failure (indicated as IS=NaN in Row 1 of Table 2(b)). Spectral normalization (SN) is better than R1 penalty. But SN still exhibits high-variance gradients and therefore suffers from low quality scores. Our +ISN regularization improves the stability significantly (Fig. 4) and achieves the best IS and FID scores as a consequence. On the other hand, the overlapping patch is a simple trick that yields further improvement over the +ISN method. However, the overlapping patch by itself does not work well (see a comparison between Row 3 and 9). The above results validate the essential role of these techniques in achieving the final performance of the ViTGAN model.
We have introduced ViTGAN, leveraging Vision Transformers (ViTs) in GANs, and proposed essential techniques to ensure its training stability and improve its convergence. Our experiments on standard benchmarks (CIFAR-10, CelebA, and LSUN bedroom) demonstrate that the presented model achieves comparable performance to state-of-the-art CNN-based GANs. Regarding limitations, ViTGAN is a new generic GAN model built on the vanilla ViT architecture. It still cannot beat the best available CNN-based GAN model with sophisticated techniques developed over years. This could be improved by incorporating advanced training techniques (e.g., jeong2021contrad; unet-gan) into the ViTGAN framework. We hope that ViTGAN can facilitate future research in this area and could be extended to other image isola2017image; dcvae and video mocogan; tgan2 synthesis tasks.
This work was supported in part by Google Cloud Platform (GCP) Credit Award. We would also like to acknowledge Cloud TPU support from Google’s TensorFlow Research Cloud (TFRC) program.
Appendix A More Quantitative Results
A.1 Effects of Data Augmentation
Table 4 presents the comparison of the Convolution-based GAN architectures (BigGAN and StyleGAN2) and our Transformer-based architecture (ViTGAN). This table complements the results in Table 1 of the main paper by a closer examination of the network architecture performance with and without using data augmentation. The differentiable data augmentation (DiffAug) zhao2020diffaugment is used in this study.
As shown, data augmentation plays a more critical role in ViTGAN. This is not unexpected because discriminators built on Transformer architectures are more capable of over-fitting or memorizing the data. DiffAug increases the diversity of the training data, thereby mitigating the overfitting issue in adversarial training. Nevertheless, with DiffAug, ViTGAN performs comparably to the leading CNN-based GAN models: BigGAN and StyleGAN2.
In addition, Table 4 includes the model performance without using the balanced consistency regularization (bCR) ICRGAN.
|ViTGAN w/o. bCR||DiffAug||N||8.84||9.02|
Appendix B Implementation Notes
We use a simple trick to mitigate over-fitting of the ViT-based discriminator by allowing some overlap between image patches. For each border edge of the patch, we extend it by $o$ pixels, making the effective patch size $(P + 2o) \times (P + 2o)$, where $o = \frac{P}{2}$. Although this operation has a connection to a convolution operation with kernel $(P + 2o) \times (P + 2o)$ and stride $P \times P$, we do not regard it as a convolution operator in our model because we do not use convolution in our implementation. Note that the extraction of (non-overlapping) patches in the Vanilla ViT dosovitskiy2021vit also has a connection to a convolution operation with kernel $P \times P$ and stride $P \times P$.
Each positional embedding of ViT networks is a linear projection of patch position followed by a sine activation function. The patch positions are normalized to lie between and .
Implicit Neural Representation for Patch Generation
Each positional embedding is a linear projection of pixel coordinate followed by a sine activation function (hence the name Fourier encoding). The pixel coordinates for pixels are normalized to lie between and . The 2-layer MLP takes positional embedding as its input, and it is conditioned on patch embedding via weight modulation as in Karras2019stylegan2; cips.
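The patch generator described above can be sketched in NumPy as follows. Several details are assumptions of this sketch rather than statements from the text: pixel coordinates are placed on a grid in $[-1, 1]$, the weight modulation follows the modulate-then-demodulate scheme of Karras2019stylegan2, the hidden activation is a leaky ReLU, and all parameter names and sizes are hypothetical.

```python
import numpy as np

def fourier_encoding(P, W_f, b_f):
    """Sine of a linear projection of the P*P pixel coordinates -> (P*P, F)."""
    grid = np.linspace(-1.0, 1.0, P)              # assumed coordinate range
    yy, xx = np.meshgrid(grid, grid, indexing="ij")
    coords = np.stack([yy.ravel(), xx.ravel()], axis=-1)   # (P*P, 2)
    return np.sin(coords @ W_f + b_f)

def modulate_weight(W, style, eps=1e-8):
    """Scale input channels of W by `style`, then normalize each output column."""
    W_mod = W * style[:, None]
    return W_mod / np.sqrt((W_mod ** 2).sum(axis=0, keepdims=True) + eps)

def generate_patch(y, E_fou, params):
    """2-layer MLP on the Fourier features, conditioned on patch embedding y."""
    s1 = y @ params["A1"] + 1.0                   # per-channel styles from y
    h = E_fou @ modulate_weight(params["W1"], s1)
    h = np.maximum(0.2 * h, h)                    # leaky ReLU (assumed)
    s2 = y @ params["A2"] + 1.0
    return h @ modulate_weight(params["W2"], s2)  # (P*P, C) pixel values

rng = np.random.default_rng(0)
P, C, D, F, Hd = 4, 3, 16, 8, 12                  # illustrative sizes
E_fou = fourier_encoding(P, rng.normal(size=(2, F)), rng.normal(size=F))
params = {
    "A1": 0.1 * rng.normal(size=(D, F)),  "W1": rng.normal(size=(F, Hd)),
    "A2": 0.1 * rng.normal(size=(D, Hd)), "W2": rng.normal(size=(Hd, C)),
}
y = rng.normal(size=D)                            # one patch embedding
patch = generate_patch(y, E_fou, params)
W1_mod = modulate_weight(params["W1"], y @ params["A1"] + 1.0)
```

The Fourier features are shared by all patches, so only the patch embedding `y` changes from one patch to the next; every pixel within a patch is a smooth function of its coordinate.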