Transformers, the dominant architectures for a variety of natural language processing (NLP) tasks, have been attracting ever-increasing research interest in the computer vision community since the success of the Vision Transformer (ViT) [vit]. Built on top of self-attention mechanisms, transformers can effectively capture long-range dependencies among pixels/patches of input images, which is arguably one of the main reasons that they outperform standard CNNs in vision tasks ranging from image classification [touvron2021training_deit, t2t, chen2021crossvit, li2021localvit, wu2021cvt, liu2021swin, han2021transformer_tnt] to object detection [liu2021swin, wang2021pyramid, carion2020end_detr, chu2021twins], action recognition [fan2021multiscale, liu2021videoswin, zhang2021vidtr] and so forth.
Recent studies on vision transformers [vit, touvron2021training_deit, t2t, chen2021crossvit] typically adopt the Transformer [transformer] architecture from NLP with minimal surgery. Taking a sequence of sliced image patches, analogous to tokens/words, as input, the transformer backbone consists of stacked building blocks with two sublayers, i.e. a self-attention layer and a feed-forward network. To ensure that the model can jointly attend to information from different representation subspaces, multi-head attention is used in each block instead of a single attention function [transformer]. While these self-attention-based vision transformers have outperformed CNNs on a multitude of benchmarks like ImageNet [deng2009imagenet], the competitive performance does not come for free: the stacked multi-head attention blocks are computationally expensive, and their cost grows quadratically with the number of patches.
But do all patches need to be attended to throughout the network to classify images correctly? Do we need all the self-attention blocks with multiple heads to look for where to attend to and to model the underlying dependencies for all different images? After all, images vary widely in object shape, object size, occlusion and background complexity. Intuitively, more patches and self-attention blocks are required for complex images containing cluttered backgrounds or occluded objects, which need sufficient contextual information and an understanding of the whole image to infer their ground-truth classes (e.g. the barber shop in Figure 1), while only a small number of informative patches and attention heads/blocks suffice to classify easy images correctly.
With this in mind, we seek to develop an adaptive computation framework that learns which patches to use and which self-attention heads/blocks to activate on a per-input basis. By doing so, the computational cost of vision transformers can be reduced by discarding redundant input patches and backbone network layers for easy samples, and using the full model with all patches only for hard and complex samples. This is an orthogonal and complementary direction to recent approaches on efficient vision transformers that focus on designing static network architectures [liu2021swin, chen2021crossvit, t2t, graham2021levit].
To this end, we introduce Adaptive Vision Transformer (AdaViT), an end-to-end framework that adaptively determines the usage of patches, heads and layers of vision transformers conditioned on input images for efficient image classification. Our framework learns to derive instance-specific inference strategies on: 1) which patches to keep; 2) which self-attention heads to activate; and 3) which transformer blocks to skip for each image, improving inference efficiency with a minimal drop in classification accuracy. In particular, we insert a light-weight multi-head subnetwork (i.e. a decision network) into each transformer block of the backbone network, which learns to predict binary decisions on the usage of patch embeddings, self-attention heads and blocks throughout the network. Since binary decisions are non-differentiable, we resort to Gumbel-Softmax [maddison2016concrete_gumbel] during training to make the whole framework end-to-end trainable. The decision network is jointly optimized with the transformer backbone using a usage loss, which measures the computational cost of the produced usage policies, and a standard cross-entropy loss; together they incentivize the network to produce policies that reduce the computational cost while maintaining classification accuracy. The overall target computational cost can be controlled by a target-budget hyperparameter that corresponds to the percentage of the computational cost of the full model (with all patches as input) during training, making the framework flexible enough to suit different computational budgets.
We conduct extensive experiments on ImageNet [deng2009imagenet] to validate the effectiveness of AdaViT and show that our method is able to improve the inference efficiency of vision transformers by more than with only drop of classification accuracy, achieving good trade-offs between efficiency and accuracy when compared with other standard vision transformers and CNNs. In addition, we conduct quantitative and qualitative analyses on the learned usage policies, providing more intuitions and insights on the redundancy in vision transformers. We further show visualizations and demonstrate that AdaViT learns to use more computation for relatively hard samples with complex scenes, and less for easy object-centric samples.
2 Related Work
Vision Transformers. Inspired by the great success of the Transformer [transformer] in NLP tasks, many recent studies have explored adapting its architecture to multiple computer vision tasks [vit, liu2021swin, fan2021multiscale, wang2021end_vistr, ranftl2021vision, xie2021segformer, el2021training, he2021transreid, pan20213d_pointformer, mao2021voxel]. Following ViT [vit], a variety of vision transformer variants have been proposed to improve the recognition performance as well as training and inference efficiency. DeiT [touvron2021training_deit] incorporates distillation strategies to improve the training efficiency of vision transformers, outperforming standard CNNs without pretraining on a large-scale dataset like JFT [sun2017revisiting_jft]. Other approaches like T2T-ViT [t2t], Swin Transformer [liu2021swin], PVT [wang2021pyramid] and CrossViT [chen2021crossvit] seek to improve the network architecture of vision transformers. Efforts have also been made to introduce the advantages of 2D CNNs to transformers by using convolutional layers [li2021localvit, xiao2021early], hierarchical network structures [liu2021swin, liu2021videoswin, wang2021pyramid], multi-scale feature aggregation [fan2021multiscale, chen2021crossvit] and so on. While these models obtain superior performance, their computational cost is still intensive and scales up quickly as the numbers of patches, self-attention heads and transformer blocks increase.
Efficient Networks. Extensive studies have been conducted to improve the efficiency of CNNs for vision tasks by designing effective light-weight network architectures like MobileNets [mobilenets, mobilenetv2, howard2019searching_mobilenetv3], EfficientNets [tan2019efficientnet] and ShuffleNets [zhang2018shufflenet, ma2018shufflenetv2]. To match the inference efficiency of standard CNNs, recent work has also explored developing efficient vision transformer architectures. T2T-ViT [t2t] proposes to use a deep-narrow structure and a token-to-token module, achieving better accuracy and lower computational cost than ViT [vit]. LeViT [graham2021levit] and Swin Transformer [liu2021swin] develop multi-stage network architectures with down-sampling and obtain better inference efficiency. These methods, however, use a fixed network architecture for all input samples, regardless of the redundancy in patches and network architecture for easy samples. Our work is orthogonal to this direction and focuses on learning input-specific strategies that adaptively allocate computational resources, saving computation with a minimal drop in accuracy.
Adaptive Computation. Adaptive computation methods exploit the large variations within network inputs as well as the redundancy in network architectures to improve efficiency with instance-specific inference strategies. In particular, existing methods for CNNs have explored altering input resolution [autofocus, arnet, huanggaoresolution, whenandwhere], skipping network layers [andreasadaptive, skipnet, blockdrop, figurnov2017spatially] and channels [channelgated, runtime], and early exiting with a multi-classifier structure [icmladaptive, multiscale_densenet, huanggaoimproved], to name a few. A few attempts have also been made recently to accelerate vision transformers with adaptive inference policies exploiting the redundancy in patches, i.e. producing policies on what patch size [wang2021not] and which patches [pan2021iared2, rao2021dynamicvit] to use conditioned on the input image. In contrast, we exploit the redundancy in the attention mechanism of vision transformers and propose to improve efficiency by adaptively choosing which self-attention heads, transformer blocks and patch embeddings to keep/drop conditioned on the input samples.
We propose AdaViT, an end-to-end adaptive computation framework to reduce the computational cost of vision transformers. Given an input image, AdaViT learns to derive policies on which patches, self-attention heads and transformer blocks to use or activate in the transformer backbone conditioned on the input image, encouraging the use of less computation while maintaining classification accuracy. An overview of our method is shown in Figure 2. In this section, we first give a brief introduction to vision transformers in Sec. 3.1. We then present our proposed method in Sec. 3.2 and elaborate on the objective function of the framework in Sec. 3.3.
Vision transformers [vit, touvron2021training_deit, t2t] for image classification take a sequence of sliced patches from an image as input, and model their long-range dependencies with stacked multi-head self-attention layers and feed-forward networks. (In this section we consider the architecture of ViT [vit]; extending it to other variants of vision transformers is straightforward.) Formally, an input image $I \in \mathbb{R}^{h \times w \times C}$ is first split into a sequence of fixed-size 2D patches $x = [x_1, x_2, \dots, x_N] \in \mathbb{R}^{N \times (P^2 C)}$, where $N = hw / P^2$ is the number of patches (e.g. $N = 196$ for a $224 \times 224$ image split into $16 \times 16$ patches). These raw patches are then mapped into $D$-dimensional patch embeddings with a linear layer. A learnable embedding $x_{\text{cls}} \in \mathbb{R}^{D}$, termed the class token, is prepended to the sequence of patch embeddings and serves as the representation of the image. Positional embeddings are also optionally added to the patch embeddings to augment them with positional information. To summarize, the input to the first transformer block is:

$$X_0 = [\,x_{\text{cls}};\ x_1 E;\ x_2 E;\ \dots;\ x_N E\,] + E_{\text{pos}} \tag{1}$$
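where $E \in \mathbb{R}^{(P^2 C) \times D}$ and $E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ are the patch-embedding projection and the positional embeddings, respectively.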
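To make Eq. 1 concrete, the following is a minimal NumPy sketch of plain ViT-style patch embedding (not T2T-ViT's tokens-to-token module). The function name `embed_patches` and the random weights are illustrative assumptions; it only shows how an $h \times w \times C$ image becomes the $(N+1) \times D$ input $X_0$.

```python
import numpy as np

def embed_patches(image, P, E, x_cls, E_pos):
    """Sketch of Eq. 1: split an h x w x C image into N = hw / P^2 patches,
    project each flattened patch with E, prepend the class token, add E_pos."""
    h, w, C = image.shape
    assert h % P == 0 and w % P == 0, "image side must be divisible by the patch size"
    patches = (image.reshape(h // P, P, w // P, P, C)
                    .transpose(0, 2, 1, 3, 4)     # (h/P, w/P, P, P, C)
                    .reshape(-1, P * P * C))      # (N, P^2 C), row-major patch order
    tokens = patches @ E                          # (N, D) patch embeddings
    X0 = np.concatenate([x_cls[None, :], tokens], axis=0)  # prepend the class token
    return X0 + E_pos                             # (N + 1, D)

rng = np.random.default_rng(0)
P, C, D = 16, 3, 64
image = rng.standard_normal((224, 224, C))
N = (224 // P) ** 2                               # 14 * 14 = 196 patches
E = rng.standard_normal((P * P * C, D)) * 0.02    # hypothetical random projection
x_cls = np.zeros(D)
E_pos = rng.standard_normal((N + 1, D)) * 0.02
X0 = embed_patches(image, P, E, x_cls, E_pos)
print(X0.shape)  # (197, 64)
```

Row 0 of `X0` is the class token; rows 1 to N are the embedded patches in raster order.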
Similar to Transformers [transformer] in NLP, the backbone network of vision transformers consists of $L$ blocks, each of which consists of a multi-head self-attention layer (MSA) and a feed-forward network (FFN). In particular, single-head attention is computed as:

$$\text{Attention}(Q, K, V) = \text{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V \tag{2}$$
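where $Q$, $K$, $V$ are, in a broad sense, the query, key and value matrices respectively, and $\sqrt{d}$ is a scaling factor. For vision transformers, $Q$, $K$, $V$ are projected from the same input, i.e. the patch embeddings. For more effective attention on different representation subspaces, multi-head self-attention concatenates the outputs of several single-head attentions and projects the result with another parameter matrix:

$$\text{MSA}(X_l) = \text{Concat}(\text{head}_{l,1}, \dots, \text{head}_{l,H})\, W^{O}_{l} \tag{3}$$

$$\text{head}_{l,i} = \text{Attention}(X_l W^{Q}_{l,i},\ X_l W^{K}_{l,i},\ X_l W^{V}_{l,i}) \tag{4}$$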
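where $W^{Q}_{l,i}$, $W^{K}_{l,i}$, $W^{V}_{l,i}$ are the parameter matrices in the $i$-th attention head of the $l$-th transformer block, $W^{O}_{l}$ is the output projection, and $X_l$ denotes the input at the $l$-th block. The output of MSA is then fed into the FFN, a two-layer MLP, which produces the output of the transformer block, $X_{l+1}$. Residual connections are also applied on both MSA and FFN as follows:

$$X'_l = \text{MSA}(X_l) + X_l, \qquad X_{l+1} = \text{FFN}(X'_l) + X'_l \tag{5}$$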
The final prediction is produced by a linear layer taking the class token of the last transformer block ($X_L^{\text{cls}}$) as input.
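As a minimal sketch of Eqs. 2 to 5 and the final linear classifier, the NumPy code below runs a few randomly initialised blocks. It is a simplification under stated assumptions, not the paper's implementation: LayerNorm is omitted, ReLU stands in for GELU, and the helper names (`msa`, `ffn`, `block`, `init_block`) are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    """Eq. 2: single-head attention Softmax(Q K^T / sqrt(d)) V."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

def msa(x, p):
    """Eqs. 3-4: run H heads, concatenate their outputs, project with W^O."""
    heads = [attention(x @ p["Wq"][i], x @ p["Wk"][i], x @ p["Wv"][i])
             for i in range(p["Wq"].shape[0])]
    return np.concatenate(heads, axis=-1) @ p["Wo"]

def ffn(x, p):
    """Two-layer MLP; ReLU is used here for brevity (ViT uses GELU)."""
    return np.maximum(x @ p["W1"], 0.0) @ p["W2"]

def block(x, p):
    """Eq. 5: residual connections around MSA and FFN (LayerNorm omitted)."""
    x = x + msa(x, p)
    return x + ffn(x, p)

def init_block(rng, D, H, mlp_ratio=4, scale=0.02):
    d = D // H                               # per-head dimension
    return {
        "Wq": rng.standard_normal((H, D, d)) * scale,
        "Wk": rng.standard_normal((H, D, d)) * scale,
        "Wv": rng.standard_normal((H, D, d)) * scale,
        "Wo": rng.standard_normal((H * d, D)) * scale,
        "W1": rng.standard_normal((D, mlp_ratio * D)) * scale,
        "W2": rng.standard_normal((mlp_ratio * D, D)) * scale,
    }

rng = np.random.default_rng(0)
N, D, H, L, num_classes = 196, 64, 4, 3, 10
x = rng.standard_normal((N + 1, D))          # stands in for X_0 from Eq. 1
for p in [init_block(rng, D, H) for _ in range(L)]:
    x = block(x, p)                          # X_{l+1} = block(X_l)
W_head = rng.standard_normal((D, num_classes)) * 0.02
logits = x[0] @ W_head                       # linear classifier on the class token
```

Every sublayer keeps the $(N+1) \times D$ shape, which is what lets the residual connections in Eq. 5 add inputs and outputs directly.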
3.2 Adaptive Vision Transformer
While large vision transformer models have achieved superior image classification performance, their computational cost grows quickly as the numbers of patches, attention heads and transformer blocks are increased to obtain higher accuracy. In addition, a computationally expensive one-size-fits-all network is often overkill for many easy samples. To remedy this, AdaViT learns to adaptively choose, on a per-input basis, 1) which patch embeddings to use; 2) which self-attention heads in MSA to activate; and 3) which transformer blocks to skip, to improve the inference efficiency of vision transformers. We achieve this by inserting a light-weight decision network before each transformer block, which is trained to produce the three sets of usage policies for that block.
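Decision Network. The decision network at the $l$-th block consists of three linear layers with parameters $W_l = \{W^{p}_{l}, W^{h}_{l}, W^{b}_{l}\}$ to produce computation usage policies for patch selection, attention head selection and transformer block selection, respectively. Formally, given the input to the $l$-th block, $X_l$, the usage policy matrices for this block are computed as follows:

$$m^{p}_{l} = W^{p}_{l}(X_l), \quad m^{h}_{l} = W^{h}_{l}(X_l), \quad m^{b}_{l} = W^{b}_{l}(X_l), \qquad m^{p}_{l} \in \mathbb{R}^{N},\ m^{h}_{l} \in \mathbb{R}^{H},\ m^{b}_{l} \in \mathbb{R} \tag{6}$$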
where $N$ and $H$ denote the numbers of patches and self-attention heads in a transformer block, and $l$ indexes the transformer blocks. Each entry of $m^{p}_{l}$, $m^{h}_{l}$ and $m^{b}_{l}$ is further passed to a sigmoid function, indicating the probability of keeping the corresponding patch, attention head and transformer block, respectively. The $l$-th decision network reuses the features computed by the preceding transformer blocks, making the framework more efficient than using a standalone decision network.
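The text specifies three linear layers followed by a sigmoid, but not how patch-level and image-level outputs are read from $X_l$. The sketch below assumes one plausible design, labelled here as hypothetical: a shared per-token linear layer scores each patch, while the head and block logits are read from the class token. It also uses two block entries (MSA and FFN), anticipating the block-selection extension described later in this section.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def decision_network(x, Wp, Wh, Wb):
    """Hypothetical sketch of Eq. 6 plus the sigmoid for one block.
    x: (N + 1, D) block input with the class token in row 0.
    Wp: (D,) scores each patch token; Wh: (D, H) and Wb: (D, 2) read the class token."""
    cls, patches = x[0], x[1:]
    m_p = sigmoid(patches @ Wp)     # (N,) keep-probability of each patch
    m_h = sigmoid(cls @ Wh)         # (H,) keep-probability of each attention head
    m_b = sigmoid(cls @ Wb)         # (2,) keep-probability of the MSA and FFN sublayers
    return m_p, m_h, m_b

rng = np.random.default_rng(0)
N, D, H = 196, 64, 4
x_l = rng.standard_normal((N + 1, D))
m_p, m_h, m_b = decision_network(
    x_l,
    rng.standard_normal(D) * 0.1,
    rng.standard_normal((D, H)) * 0.1,
    rng.standard_normal((D, 2)) * 0.1,
)
```

Because the decision network consumes $X_l$ directly, its only extra cost is three small matrix products per block, which is consistent with the "light-weight" claim above.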
As the decisions are binary, the action of keeping or discarding can be selected by simply applying a threshold on the entries during inference. However, deriving the optimal thresholds for different samples is challenging. To this end, we define random variables $M^{p}_{l}$, $M^{h}_{l}$, $M^{b}_{l}$ to make decisions by sampling from $m^{p}_{l}$, $m^{h}_{l}$ and $m^{b}_{l}$. For example, the $j$-th patch embedding in the $l$-th block is kept when $M^{p}_{l,j} = 1$ and dropped when $M^{p}_{l,j} = 0$. We relax the sampling process with the Gumbel-Softmax trick [maddison2016concrete_gumbel] to make it differentiable during training, as elaborated in Sec. 3.3.
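Patch Selection. For the input to each transformer block, we aim to keep only the most informative patch embeddings and discard the rest to speed up inference. More formally, for the $l$-th block, patches are removed from the input to this block if the corresponding entries in $M^{p}_{l}$ equal $0$:

$$X_l = \left[\,x_{l,0};\ \{x_{l,j} \mid M^{p}_{l,j} = 1,\ 1 \le j \le N\}\,\right] \tag{7}$$

where $x_{l,0}$ denotes the class token and $x_{l,j}$ the $j$-th patch embedding at the $l$-th block.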
The class token is always kept since it is used as representation of the whole image.
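A minimal sketch of Eq. 7 with hard decisions, as they would be applied at inference time, is shown below. The helper name `select_patches` is hypothetical. Since the number of kept rows differs per image, a batched training implementation would presumably mask tokens instead of physically removing them; that detail is not given in the text.

```python
import numpy as np

def select_patches(x, keep):
    """Eq. 7 sketch: x is (N + 1, D) with the class token in row 0;
    keep is the binary vector M^p_l of length N. The class token is always kept."""
    keep = np.asarray(keep, dtype=bool)
    assert keep.shape == (x.shape[0] - 1,), "one decision per patch"
    return np.concatenate([x[:1], x[1:][keep]], axis=0)

x = np.arange(5 * 2, dtype=float).reshape(5, 2)   # class token + 4 patches, D = 2
out = select_patches(x, [1, 0, 0, 1])             # keep patches 1 and 4
print(out.shape)  # (3, 2)
```

Every later block then operates on fewer tokens, so the quadratic attention term shrinks together with the linear projection and FFN terms.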
Head Selection. Multi-head self-attention enables the model to jointly attend to different subspaces of the representation [transformer] and is adopted in most, if not all, vision transformer variants [vit, touvron2021training_deit, t2t, chen2021crossvit, liu2021swin]. Such a multi-head design is crucial for modeling the underlying long-range dependencies in images, especially those with complex scenes and cluttered backgrounds, but fewer attention heads could arguably suffice to look for where to attend to in easy images. With this in mind, we explore dropping attention heads adaptively, conditioned on the input image, for faster inference. Similar to patch selection, the decision of activating or deactivating a given attention head is determined by the corresponding entry in $M^{h}_{l}$. The “deactivation” of an attention head can be instantiated in different ways. In our framework, we explore two methods for head selection, namely partial deactivation and full deactivation.
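For partial deactivation, the softmax output in attention (Eqn. 2) is replaced with a predefined matrix such as the identity matrix $\mathbb{I}$, so that the cost of computing the attention map is saved. The attention in the $i$-th head of the $l$-th block is then computed as:

$$\text{head}_{l,i} = \begin{cases} \text{Attention}(X_l W^{Q}_{l,i},\ X_l W^{K}_{l,i},\ X_l W^{V}_{l,i}) & \text{if } M^{h}_{l,i} = 1 \\ \mathbb{I} \cdot X_l W^{V}_{l,i} & \text{if } M^{h}_{l,i} = 0 \end{cases} \tag{8}$$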
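For full deactivation, the entire head is removed from the multi-head self-attention layer, and the embedding size of the concatenated output in MSA is reduced correspondingly:

$$\text{MSA}(X_l) = \text{Concat}\left(\{\text{head}_{l,i} \mid M^{h}_{l,i} = 1\}\right) \tilde{W}^{O}_{l} \tag{9}$$

where $\tilde{W}^{O}_{l}$ denotes $W^{O}_{l}$ restricted to the rows that correspond to the active heads.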
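The two deactivation modes can be compared side by side in the NumPy sketch below, which implements Eq. 8 (`mode="partial"`) and Eq. 9 (`mode="full"`) for a single MSA layer under hard decisions. The function name and the choice to drop the inactive heads' rows of $W^O$ in full mode are assumptions made for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def msa_with_head_policy(x, p, keep, mode):
    """Apply a binary head policy keep (length H, i.e. M^h_l) to one MSA layer.
    mode="partial": an inactive head uses the identity as its attention map (Eq. 8).
    mode="full": an inactive head is removed with its rows of W^O (Eq. 9).
    p holds Wq, Wk, Wv of shape (H, D, d) and Wo of shape (H * d, D)."""
    if mode not in ("partial", "full"):
        raise ValueError(f"unknown mode: {mode}")
    H, _, d = p["Wq"].shape
    outs, rows = [], []
    for i in range(H):
        head_rows = np.arange(i * d, (i + 1) * d)   # rows of W^O fed by head i
        if keep[i]:
            q, k, v = x @ p["Wq"][i], x @ p["Wk"][i], x @ p["Wv"][i]
            outs.append(softmax(q @ k.T / np.sqrt(d)) @ v)
        elif mode == "partial":
            outs.append(x @ p["Wv"][i])             # I @ V = V, so Q K^T is never computed
        else:
            continue                                # full: head i is skipped entirely
        rows.append(head_rows)
    if not outs:                                    # every head removed: MSA adds nothing
        return np.zeros((x.shape[0], p["Wo"].shape[1]))
    return np.concatenate(outs, axis=-1) @ p["Wo"][np.concatenate(rows)]

rng = np.random.default_rng(0)
n, D, H = 10, 16, 4
d = D // H
p = {name: rng.standard_normal((H, D, d)) * 0.1 for name in ("Wq", "Wk", "Wv")}
p["Wo"] = rng.standard_normal((H * d, D)) * 0.1
x = rng.standard_normal((n, D))

keep = [1, 0, 1, 0]
out_partial = msa_with_head_policy(x, p, keep, "partial")
out_full = msa_with_head_policy(x, p, keep, "full")
```

Both modes return a $D$-dimensional output, so the residual connection is unaffected. They differ only in whether the inactive heads still contribute their value projections.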
In practice, full deactivation saves more computation than partial deactivation when the same percentage of heads is deactivated, yet it is likely to incur more classification errors because the embedding size is manipulated on the fly.
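The computational argument can be checked with a rough multiply-accumulate (MAC) count for one block. The sketch below is ours and rests on stated counting assumptions: under partial deactivation an inactive head still computes its V projection and feeds $W^O$, under full deactivation it computes nothing, and LayerNorm, softmax and bias costs are ignored. The example sizes ($D = 384$, $H = 6$) are arbitrary illustrative values, not the paper's backbone configuration.

```python
def block_macs(n, D, H, kept_heads, mode="full", mlp_ratio=4):
    """Approximate multiply-accumulates of one transformer block with n tokens,
    embedding size D, H heads of size d = D // H and kept_heads active heads."""
    d = D // H
    k = kept_heads
    attn_maps = 2 * n * n * d * k             # Q K^T and (attention map) @ V, active heads only
    if mode == "full":
        proj_in = 3 * n * D * d * k           # Q, K, V projections for active heads
        proj_out = n * (k * d) * D            # W^O applied to the shrunk concatenation
    elif mode == "partial":
        proj_in = 2 * n * D * d * k + n * D * d * H   # Q, K for active heads; V for all heads
        proj_out = n * (H * d) * D            # W^O applied to the full-width concatenation
    else:
        raise ValueError(f"unknown mode: {mode}")
    ffn = 2 * n * D * (mlp_ratio * D)         # the two linear layers of the MLP
    return proj_in + attn_maps + proj_out + ffn

n, D, H = 197, 384, 6
dense = block_macs(n, D, H, kept_heads=H)
full = block_macs(n, D, H, kept_heads=3, mode="full")
partial = block_macs(n, D, H, kept_heads=3, mode="partial")
```

Under these assumptions, full deactivation is cheaper than partial deactivation whenever some heads are dropped, and the two coincide when all heads are kept. The `attn_maps` term also makes explicit why the cost grows quadratically with the number of tokens.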
Block Selection. In addition to patch selection and head selection, a transformer block can also be skipped entirely when it is redundant, thanks to the residual connections throughout the network. To increase the flexibility of layer skipping, we increase the dimension of the block usage policy from $m^{b}_{l} \in \mathbb{R}^{1}$ to $m^{b}_{l} \in \mathbb{R}^{2}$, enabling the two sublayers (MSA and FFN) in each transformer block to be controlled individually. Eqn. 5 then becomes:

$$X'_l = M^{b}_{l,0} \cdot \text{MSA}(X_l) + X_l, \qquad X_{l+1} = M^{b}_{l,1} \cdot \text{FFN}(X'_l) + X'_l \tag{10}$$
In summary, given the input of each transformer block, the decision network produces the usage policies for this block, and the input is then forwarded through the block with the decisions applied. Finally, the classification prediction $P$ from the last layer and the decisions for all blocks, $M^{p}$, $M^{h}$ and $M^{b}$, are obtained.
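Eq. 10 can be sketched as a gated residual block, shown below. With hard 0/1 gates, a skipped sublayer is never evaluated and the residual path passes its input through unchanged. With relaxed gates, the sublayer output is simply scaled. Treating training-time gates this way is our assumption for illustration, and the helper name `gated_block` is hypothetical.

```python
import numpy as np

def gated_block(x, msa_fn, ffn_fn, m_b):
    """Eq. 10 sketch: m_b = (M^b_{l,0}, M^b_{l,1}) gates the MSA and FFN sublayers.
    A gate of exactly 0 skips the sublayer's computation; otherwise its
    output is scaled by the gate before being added to the residual path."""
    g_msa, g_ffn = m_b
    if g_msa != 0:
        x = x + g_msa * msa_fn(x)
    if g_ffn != 0:
        x = x + g_ffn * ffn_fn(x)
    return x

# Toy stand-ins for the two sublayers, chosen only to make the arithmetic visible.
msa_fn = lambda t: 2 * t
ffn_fn = lambda t: t + 1
x = np.ones(3)
skipped = gated_block(x, msa_fn, ffn_fn, (0, 0))    # both sublayers skipped: returns x
both = gated_block(x, msa_fn, ffn_fn, (1, 1))       # 1 + 2 = 3, then 3 + (3 + 1) = 7
soft = gated_block(x, msa_fn, ffn_fn, (0.5, 0.5))   # 1 + 1 = 2, then 2 + 0.5 * 3 = 3.5
```

Because the residual connection already carries $X_l$ forward, no extra projection is needed when a sublayer is skipped.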
|Method|Top-1 Acc (%)|FLOPs (G)|Image Size|# Patch|# Head|# Block|
|---|---|---|---|---|---|---|
|ResNet-50* [resnet, t2t]|79.1|4.1|224×224|-|-|-|
|ResNet-101* [resnet, t2t]|79.9|7.9|224×224|-|-|-|
3.3 Objective Function
Since our goal is to reduce the overall computational cost of vision transformers with a minimal drop in accuracy, the objective function of AdaViT is designed to incentivize correct classification and less computation at the same time. In particular, a usage loss and a cross-entropy loss are used to jointly optimize the framework. Given an input image $I$ with label $y$, the final prediction $P(I; \theta)$ is produced by the transformer with parameters $\theta$, and the cross-entropy loss is computed as follows:

$$\mathcal{L}_{\text{ce}} = -\log P_{y}(I; \theta) \tag{11}$$

where $P_{y}$ is the predicted probability of the ground-truth class $y$.
While the binary decisions on whether to keep/discard a patch/head/block can be readily obtained by applying a threshold during inference, determining the optimal thresholds is challenging. In addition, such an operation is not differentiable during training, which makes optimizing the decision network difficult. A common solution is to resort to reinforcement learning and optimize the network with policy gradient methods [policygradient], yet these can be slow to converge due to the large variance that scales with the dimension of the discrete variables [policygradient, maddison2016concrete_gumbel]. To this end, we use the Gumbel-Softmax trick [maddison2016concrete_gumbel] to relax the sampling process and make it differentiable. Formally, for the $j$-th entry of a probability vector $m$ (any of $m^{p}_{l}$, $m^{h}_{l}$, $m^{b}_{l}$), the relaxed decision is derived in the following way:

$$M_{j,k} = \frac{\exp\left((\log m_{j,k} + G_{j,k})/\tau\right)}{\sum_{k'=1}^{K} \exp\left((\log m_{j,k'} + G_{j,k'})/\tau\right)} \tag{12}$$

where $K$ is the total number of categories ($K = 2$ for binary decisions in our case, with $m_{j,1} = m_j$ for keeping and $m_{j,2} = 1 - m_j$ for dropping, so that $M_j = M_{j,1}$), and $G_{j,k} = -\log(-\log U_{j,k})$ follows the Gumbel distribution, in which $U_{j,k}$ is sampled i.i.d. from the uniform distribution $\mathrm{Uniform}(0, 1)$. The temperature $\tau$ controls the smoothness of $M_j$.
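For a single binary decision, Eq. 12 reduces to a two-way softmax over perturbed log-probabilities. The pure-Python sketch below (the function name is hypothetical) returns the relaxed keep weight $M_j$. Thresholding it at 0.5 is equivalent to taking the argmax of $\log m_{j,k} + G_{j,k}$, i.e. the Gumbel-max trick, so the hard decision is "keep" with probability $m_j$ for any $\tau$.

```python
import math
import random

def gumbel_softmax_binary(m, tau, rng):
    """Relaxed keep decision M_j for one entry with keep-probability m (Eq. 12, K = 2)."""
    if not 0.0 < m < 1.0:
        raise ValueError("m must lie strictly between 0 and 1")
    log_probs = [math.log(m), math.log(1.0 - m)]        # categories: (keep, drop)
    scores = []
    for lp in log_probs:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # U ~ Uniform(0, 1), kept off 0 and 1
        g = -math.log(-math.log(u))                     # G = -log(-log U)
        scores.append((lp + g) / tau)
    top = max(scores)                                   # subtract max for a stable softmax
    exps = [math.exp(s - top) for s in scores]
    return exps[0] / sum(exps)                          # M_j = M_{j,1}, the relaxed "keep" weight

rng = random.Random(0)
soft = [gumbel_softmax_binary(0.8, 1.0, rng) for _ in range(20000)]
hard = [1 if s > 0.5 else 0 for s in soft]
print(sum(hard) / len(hard))  # close to 0.8
```

Lowering `tau` pushes the relaxed values toward 0 or 1 (closer to hard decisions) at the cost of higher-variance gradients. That is the smoothness trade-off the temperature controls.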
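To encourage reducing the overall computational cost, we devise the usage loss as follows:

$$\mathcal{L}_{\text{usage}} = \left(\frac{1}{D_p}\sum_{j=1}^{D_p} m^{p}_{j} - \gamma_p\right)^2 + \left(\frac{1}{D_h}\sum_{j=1}^{D_h} m^{h}_{j} - \gamma_h\right)^2 + \left(\frac{1}{D_b}\sum_{j=1}^{D_b} m^{b}_{j} - \gamma_b\right)^2 \tag{13}$$

where $D_p$, $D_h$ and $D_b$ denote the sizes of the flattened probability vectors from the decision network for patch/head/block selection, i.e. the total numbers of patches, heads and blocks of the entire transformer, respectively, and $m^{p}_{j}$, $m^{h}_{j}$, $m^{b}_{j}$ are their entries. The hyperparameters $\gamma_p$, $\gamma_h$ and $\gamma_b$ indicate target computation budgets in terms of the percentage of patches/heads/blocks to keep.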
Finally, the two loss functions are combined and minimized in an end-to-end manner as in Eqn. 14:

$$\mathcal{L} = \mathcal{L}_{\text{ce}} + \mathcal{L}_{\text{usage}} \tag{14}$$
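The objective can be written in a few lines of plain Python. This sketch follows the reconstructed Eqs. 11, 13 and 14 above: an unweighted sum, with each usage term the squared gap between the mean keep-probability of a flattened policy vector (pooled over all blocks) and its target budget. The function names are hypothetical.

```python
import math

def cross_entropy(probs, label):
    """Eq. 11: negative log of the predicted probability of the ground-truth class."""
    return -math.log(probs[label])

def usage_loss(m_p, m_h, m_b, gamma_p, gamma_h, gamma_b):
    """Eq. 13: m_p, m_h, m_b are keep-probabilities flattened over all blocks
    (lengths D_p, D_h, D_b); each term penalises deviation from its budget."""
    def term(m, gamma):
        return (sum(m) / len(m) - gamma) ** 2
    return term(m_p, gamma_p) + term(m_h, gamma_h) + term(m_b, gamma_b)

def total_loss(probs, label, m_p, m_h, m_b, gammas):
    """Eq. 14: unweighted sum of the classification and usage losses."""
    return cross_entropy(probs, label) + usage_loss(m_p, m_h, m_b, *gammas)

# Patches kept with mean probability 0.9 against a 0.5 budget cost (0.9 - 0.5)^2 = 0.16.
loss = total_loss([0.25, 0.75], 1, [0.9, 0.9], [0.5], [1.0], (0.5, 0.5, 1.0))
```

Because the penalty is two-sided, policies that keep much less than the target are penalised as well, so the budget acts as a target rather than an upper bound.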
4.1 Experimental Setup
Dataset and evaluation metrics. We conduct experiments on ImageNet [deng2009imagenet] with 1.2M images for training and 50K images for validation, and report the Top-1 classification accuracy. To evaluate model efficiency, we report the number of giga floating-point operations (GFLOPs) per image.
Implementation details. We use T2T-ViT [t2t] as the transformer backbone due to its superior performance on ImageNet with a moderate computational cost. The backbone consists of blocks and heads in each MSA layer, and the number of tokens . The decision network is attached to each transformer block starting from -nd block. For head selection, we use the full deactivation method if not mentioned otherwise. We initialize the transformer backbone of AdaViT with the pretrained weights released in the official implementation of [t2t]. We will release the code.
We use 8 GPUs with a batch size of 512 for training. The model is trained with a learning rate , a weight decay and a cosine learning rate schedule for epochs following [t2t]. AdamW [loshchilov2017decoupled_adamw] is used as the optimizer. For all the experiments, we set the input size to . Temperature $\tau$ in Gumbel-Softmax is set to . The choices of $\gamma_p$, $\gamma_h$ and $\gamma_b$ vary flexibly for different desired trade-offs between classification accuracy and computational cost.
4.2 Main Results
We first evaluate the overall performance of AdaViT in terms of classification accuracy and efficiency, and report the results in Table 1. Besides standard CNN and transformer architectures such as ResNets [resnet], ViT [vit], DeiT [touvron2021training_deit], T2T-ViT [t2t] and so on, we also compare our method with the following baseline methods:
Upperbound: The original pretrained vision transformer model, with all patch embeddings kept as input and all self-attention heads and transformer blocks activated. This serves as an “upperbound” of our method regarding classification accuracy.
Random: Given the usage policies produced by AdaViT, we generate random policies on patch selection, head selection and block selection that use similar computational cost, and apply them to the pretrained models to validate the effectiveness of the learned policies.
Random+: The pretrained models are further finetuned with the random policies applied, in order to adapt to the varied input distribution and network architecture incurred by the random policies.
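One simple way to build the Random baseline is sketched below: for each block, keep exactly as many entries as the learned policy does, but at uniformly random positions, so that the computational cost matches. Matching counts per block is our assumption about what "similar computational cost" means here, and `random_policy_like` is a hypothetical helper.

```python
import numpy as np

def random_policy_like(learned, rng):
    """Random baseline sketch: learned is a binary (num_blocks, num_items) policy,
    e.g. patch decisions per block. Each row of the output keeps the same
    number of items as the learned row, chosen uniformly at random."""
    learned = np.asarray(learned, dtype=bool)
    out = np.zeros_like(learned)
    for r, row in enumerate(learned):
        k = int(row.sum())                                  # budget used by the learned policy
        idx = rng.choice(row.size, size=k, replace=False)   # random positions, same count
        out[r, idx] = True
    return out

rng = np.random.default_rng(0)
learned = np.array([[1, 1, 0, 0, 1],
                    [0, 0, 0, 0, 0],
                    [1, 1, 1, 1, 1]])
rand = random_policy_like(learned, rng)
```

The Random+ variant would then finetune the backbone with such policies applied, so that any remaining gap reflects the quality of the learned decisions rather than a distribution shift.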
As shown in Table 1, AdaViT obtains a good efficiency improvement with only a small drop in classification accuracy. Specifically, AdaViT obtains Top-1 accuracy requiring GFLOPs per image during inference, achieving more than efficiency than the original T2T-ViT model with only drop of accuracy. Compared with standard ResNets [resnet] and vision transformers that use a backbone architecture similar to ours [vit, touvron2021training_deit, t2t, chen2021crossvit], AdaViT obtains better classification performance with less computational cost, achieving a good efficiency/accuracy trade-off, as further shown in Figure 3. It is also worth pointing out that, compared with vision transformer variants [liu2021swin, wang2021pyramid] that resort to advanced design choices such as multi-scale feature pyramids and hierarchical downsampling, our method still obtains comparable or better accuracy under similar computational cost.
When using a similar computation budget, AdaViT outperforms random and random+ baselines by clear margins. Specifically, AdaViT with T2T-ViT as the backbone network obtains and higher accuracy than random and random+ respectively at a similar cost of GFLOPs per image, demonstrating that the usage policies learned by AdaViT can effectively maintain classification accuracy and reduce computational cost at the same time.
AdaViT with different computational budgets. AdaViT is designed to flexibly accommodate different computational budgets by varying the hyperparameters $\gamma_p$, $\gamma_h$ and $\gamma_b$ as discussed in Section 3.3. As demonstrated in Figure 4(a), AdaViT covers a wide range of trade-offs between efficiency and accuracy, and outperforms Random+ baselines by a large margin.
|Method|Top-1 Acc (%)|% Head|GFLOPs|
|---|---|---|---|
4.3 Ablation Study
Effectiveness of learned usage policies. Here we validate that each of the three sets of learned usage policies effectively maintains classification accuracy while reducing the computational cost of vision transformers. For this purpose, we replace the learned usage policies with randomly generated policies that cost similar computational resources and report the results in Table 2. As shown in Table 2, changing any set of learned policies to a random one results in a drop in accuracy by a clear margin. Compared with random patch/head/block selection, AdaViT obtains higher accuracy under a similar computational budget. This confirms the effectiveness of each learned usage policy.
Ablation of individual components. Having demonstrated the effectiveness of the jointly learned usage policies for patch, head and block selection, we now evaluate the performance when only one of the three selection methods is used. One could argue that part of the performance gap in Table 2 results from the change of input/feature distribution when random policies are applied, and thus we compare each component with its further finetuned Random+ counterpart. For faster training and evaluation, we train these models for epochs. As shown in Figure 4(b-d), our method with only patch/head/block selection also covers a wide range of accuracy/efficiency trade-offs and outperforms Random+ baselines by a clear margin, confirming the effectiveness of each component.
Partial vs. Full deactivation for head selection. As discussed in Sec. 3.2, we propose two methods to deactivate a head in the multi-head self-attention layer, namely partial deactivation and full deactivation. We now analyze their effectiveness on improving the efficiency of vision transformers. As demonstrated in Table 3, when deactivating the same percentage (i.e. ) of self-attention heads within the backbone, partial deactivation is able to obtain much higher accuracy than full deactivation ( vs. ), but also incurs higher computational cost ( vs. GFLOPs). This is intuitive since partial deactivation only skips the computation of attention maps before Softmax, while full deactivation removes the entire head and its output to the FFN. As the number of heads increases, full deactivation obtains better accuracy gradually. In practice these different head selection methods provide more flexible options to suit different computational budgets.
Computational saving throughout the network. AdaViT exploits the redundancy of computation to improve the efficiency of vision transformers. To better understand such redundancy, we collect the usage policies on patch/head/block selection predicted by our method on the validation set and show the distribution of computational cost (i.e. the percentage of patches/heads/blocks kept) throughout the backbone network. As shown in Figure 5, AdaViT tends to allocate more computation to earlier stages of the network. In particular, for patch selection, the average number of kept patches in each transformer block gradually decreases until the final output layer. This is intuitive since the patches keep aggregating information from all other patches in the stacked self-attention layers, so a few informative patches near the output layer suffice to represent the whole input image for correct classification. As visualized in Figure 7, the number of selected patches gradually decreases with a focus on the discriminative parts of the images.
For head selection and block selection, the patterns differ from patch selection: relatively more computation is kept in the last few blocks. We hypothesize that the last few layers in the backbone are more responsible for the final prediction and are thus kept more often.
Learned usage policies for different classes. We further analyze the distribution of learned usage policies for different classes. In Figure 8, we show the box plot of several classes that are allocated the most/least computational resources. As can be seen, our method learns to allocate more computation for difficult classes with complex scenes such as “shoe shop”, “barber shop”, “toyshop” but uses less computation for relatively easy and object-centric classes like “parachute” and “kite”.
Qualitative Results. Images allocated the least and the most computation by our method are shown in Figure 6. Object-centric images with simple backgrounds (like the parachute and the tennis ball) tend to use less computation, while hard samples with cluttered backgrounds (e.g. the drum and the toy shop) are allocated more.
Limitation. One potential limitation is that there is still a small drop in accuracy when comparing our method with the Upperbound baseline, which we leave to future work.
In this paper we presented AdaViT, an adaptive computation framework that learns which patches, self-attention heads and blocks to keep throughout the transformer backbone on a per-input basis to improve efficiency in image recognition. To achieve this, a light-weight decision network is attached to each transformer block and optimized jointly with the backbone in an end-to-end manner. Extensive experiments demonstrated that our method obtains more than improvement on efficiency with only a small drop of accuracy compared with state-of-the-art vision transformers, and covers a wide range of efficiency/accuracy trade-offs. We further analyzed the learned usage policies quantitatively and qualitatively, providing more insights into the redundancy in vision transformers.
Appendix A Qualitative Results
We further provide more qualitative results in addition to those in the main text. Images that are allocated the least/most computational resources by our method are shown in Figure 9, demonstrating that our method learns to use less computation on easy object-centric images and more computation on hard, complex images with cluttered backgrounds. Figure 11 shows more visualizations of the learned usage policies for patch selection, demonstrating that our method gradually allocates less computation throughout the backbone network, which indicates that more computational redundancy resides in the later stages of the vision transformer backbone.
Appendix B Compatibility with Other Backbones
Our method is by design model-agnostic and thus can be applied to different vision transformer backbones. To verify this, we use DeiT-small [touvron2021training_deit] as the backbone of AdaViT and show the results in Figure 10. AdaViT achieves better efficiency/accuracy tradeoff when compared with standard variants of DeiT, and consistently outperforms its Random+ baseline by large margins, as demonstrated in Figure 10(a) and 10(b) respectively.
We further show the visualization of patch selection usage policies with DeiT-small as the backbone as well in Figure 11. A similar trend of keeping more computation at earlier layers and gradually allocating less computation throughout the network is also observed.