CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

03/27/2021, by Chun-Fu Chen, et al.

The recently developed vision transformer (ViT) has achieved promising results on image classification compared to convolutional neural networks. Inspired by this, in this paper, we study how to learn multi-scale feature representations in transformer models for image classification. To this end, we propose a dual-branch transformer to combine image patches (i.e., tokens in a transformer) of different sizes to produce stronger image features. Our approach processes small-patch and large-patch tokens with two separate branches of different computational complexity, and these tokens are then fused purely by attention multiple times to complement each other. Furthermore, to reduce computation, we develop a simple yet effective token fusion module based on cross-attention, which uses a single token for each branch as a query to exchange information with the other branch. Our proposed cross-attention is linear in both computation and memory rather than quadratic. Extensive experiments demonstrate that the proposed approach performs better than or on par with several concurrent works on vision transformers, in addition to efficient CNN models. For example, on the ImageNet1K dataset, with some architectural changes, our approach outperforms the recent DeiT by a large margin of 2% with a small to moderate increase in FLOPs and model parameters.


1 Introduction

Figure 1: Improvement of our proposed approach over DeiT [37] and ViT [12]. The circle size is proportional to the model size. All models are trained on ImageNet1K from scratch. The results of ViT are referenced from [47].

The novel transformer architecture [38] has led to a big leap forward in capabilities for sequence-to-sequence modeling in NLP tasks [10]. The great success of transformers in NLP has sparked particular interest from the vision community in understanding whether transformers can be a strong competitor against the dominant Convolutional Neural Network (CNN) based architectures in vision, such as ResNet [16] and EfficientNet [36]. Previous research efforts on transformers in vision have, until very recently, been largely focused on combining CNNs with self-attention [2, 50, 32, 34]. While these hybrid approaches achieve promising performance, they have limited scalability in computation compared to purely attention-based transformers. Vision Transformer (ViT) [12], which uses a sequence of embedded image patches as input to a standard transformer, is the first convolution-free transformer to demonstrate comparable performance to CNN models. However, ViT requires very large datasets such as ImageNet21k [9] and JFT300M [35] for training. DeiT [37] subsequently shows that data augmentation and model regularization can enable training of high-performance ViT models with less data. Since then, ViT has inspired several attempts to improve its efficiency and effectiveness from different aspects [37, 47, 15, 40, 20].

Along the same line of research on building stronger vision transformers, in this work, we study how to learn multi-scale feature representations in transformer models for image recognition. Multi-scale feature representations have proven beneficial for many vision tasks [5, 4, 23, 22, 26, 25, 7], but such a benefit for vision transformers remains to be validated. Motivated by the effectiveness of multi-branch CNN architectures such as Big-Little Net [5] and Octave convolutions [6], we propose a dual-branch transformer to combine image patches (i.e., tokens in a transformer) of different sizes to produce stronger visual features for image classification. Our approach processes small and large patch tokens with two separate branches of different computational complexities, and these tokens are fused together multiple times to complement each other. The main focus of this work is to develop feature fusion methods that are appropriate for vision transformers, which, to the best of our knowledge, has not been addressed before. We do so with an efficient cross-attention module, in which each transformer branch creates a non-patch token as an agent to exchange information with the other branch by attention. This allows for linear-time generation of the attention map in fusion, instead of the quadratic time required otherwise. With proper architectural adjustments to the computational load of each branch, our proposed approach outperforms DeiT [37] by a large margin of 2% with a small to moderate increase in FLOPs and model parameters (see Figure 1).

The main contributions of our work are as follows:

  • We propose a novel dual-branch vision transformer to extract multi-scale feature representations for image classification. Moreover, we develop a simple yet effective token fusion scheme based on cross-attention, which is linear in both computation and memory to combine features at different scales.

  • Our approach performs better than or on par with several concurrent works based on ViT [12], and demonstrates comparable results to EfficientNet [36] with regard to accuracy, throughput and model parameters.

2 Related Works

Our work relates to three major research directions: convolutional neural networks with attention, vision transformers, and multi-scale CNNs. Here, we focus on some representative methods closely related to our work.

CNN with Attention. Attention has been widely used in many different forms to enhance feature representations; e.g., SENet [19] uses channel attention, CBAM [43] adds spatial attention, and ECANet [39] proposes an efficient channel attention to further improve SENet. There has also been a lot of interest in combining CNNs with different forms of self-attention [3, 34, 50, 32, 2, 18, 33, 41]. SASA [32] and SAN [50] deploy local attention layers to replace convolutional layers. Despite promising results, these prior approaches limited the attention scope to a local region due to the cost of global attention. LambdaNetwork [3] recently introduces an efficient global attention to model both content- and position-based interactions, which considerably improves the speed-accuracy trade-off of image classification models. BoTNet [34] replaces the spatial convolutions with global self-attention in the final three bottleneck blocks of a ResNet, resulting in models that achieve strong performance for image classification on the ImageNet benchmark. In contrast to these approaches that mix convolution with self-attention, our work is built on top of a pure self-attention network like the Vision Transformer [12], which has recently shown great promise in several vision applications.

Vision Transformer. Inspired by the success of Transformers [38] in machine translation, convolution-free models that only rely on transformer layers have gone viral in computer vision. In particular, Vision Transformer (ViT) [12] is the first such example of a transformer-based method to match or even surpass CNNs for image classification. Many variants of vision transformers have also been proposed recently, using distillation for data-efficient training of vision transformers [37], a pyramid structure like CNNs [40], or self-attention that improves efficiency by learning an abstract representation instead of performing all-to-all self-attention [44]. Perceiver [20] leverages an asymmetric attention mechanism to iteratively distill inputs into a tight latent bottleneck, allowing it to scale to handle very large inputs. T2T-ViT [47] introduces a layer-wise Tokens-to-Token (T2T) transformation to encode the important local structure for each token instead of the naive tokenization used in ViT [12]. Unlike these approaches, we propose a dual-path architecture to extract multi-scale features for better visual representation with vision transformers.

Multi-Scale CNNs. Multi-scale feature representations have a long history in computer vision (e.g., image pyramids [1], scale-space representations [30], and coarse-to-fine approaches [29]). In the context of CNNs, multi-scale feature representations have been used for detection and recognition of objects at multiple scales [4, 23, 46, 27], as well as to speed up neural networks in Big-Little Net [5] and octave convolutions [6]. bLVNet-TAM [13] uses a two-branch multi-resolution architecture while learning temporal dependencies across frames. SlowFast Networks [14] rely on a similar two-branch model, but each branch encodes a different frame rate, as opposed to frames with different spatial resolutions. While multi-scale features have been shown to benefit CNNs, their applicability to vision transformers remains a largely under-addressed problem.

3 Method

Our method is built on top of vision transformer [12], so we first present a brief overview of ViT and then describe our proposed method (CrossViT) for learning multi-scale features for image classification.

Figure 2: An illustration of our proposed transformer architecture for learning multi-scale features with cross-attention (CrossViT). Our architecture consists of a stack of multi-scale transformer encoders. Each multi-scale transformer encoder uses two different branches to process image tokens of different patch sizes (P_s and P_l, with P_s < P_l) and fuses the tokens at the end by an efficient module based on cross-attention of the CLS tokens. Our design includes different numbers of regular transformer encoders in the two branches (i.e., N and M) to balance computational costs.

3.1 Overview of Vision Transformer

Vision Transformer (ViT) [12] first converts an image into a sequence of patch tokens by dividing it with a certain patch size and then linearly projecting each patch into a token. An additional classification token (CLS) is added to the sequence, as in the original BERT [11]. Moreover, since self-attention in the transformer encoder is position-agnostic while vision applications highly depend on position information, ViT adds a position embedding to each token, including the CLS token. Afterwards, all tokens are passed through stacked transformer encoders and finally the CLS token is used for classification. A transformer encoder is composed of a sequence of blocks where each block contains multi-headed self-attention (MSA) with a feed-forward network (FFN).

FFN contains a two-layer multilayer perceptron with expanding ratio r at the hidden layer, and one GELU non-linearity is applied after the first linear layer. Layer normalization (LN) is applied before every block, and residual shortcuts after every block. The input of ViT, x_0, and the processing of the k-th block can be expressed as

x_0 = [x_cls || x_patch] + x_pos,   x_cls ∈ R^(1×C), x_patch ∈ R^(N×C), x_pos ∈ R^((1+N)×C)
y_k = x_(k-1) + MSA(LN(x_(k-1)))
x_k = y_k + FFN(LN(y_k))        (1)

where x_cls and x_patch are the CLS and patch tokens respectively, and x_pos is the position embedding. N and C are the number of patch tokens and the dimension of the embedding, respectively.
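To make the notation concrete, the following is a minimal PyTorch-style sketch of one such pre-norm encoder block; the class name, argument names and the use of nn.MultiheadAttention are our own illustration rather than the authors' implementation.

```python
# Minimal sketch of the pre-norm transformer encoder block in Eq. 1.
# Names and structure are illustrative, not the released CrossViT/DeiT code.
# Assumes PyTorch >= 1.9 (for the batch_first argument of nn.MultiheadAttention).
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)                      # expanding ratio r at the hidden layer
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):                                  # x: (B, 1 + N, C), CLS token first
        h = self.norm1(x)
        y = x + self.attn(h, h, h, need_weights=False)[0]  # y_k = x_{k-1} + MSA(LN(x_{k-1}))
        return y + self.ffn(self.norm2(y))                 # x_k = y_k + FFN(LN(y_k))

# Example: a DeiT-Ti-like block on 196 patch tokens plus one CLS token.
out = EncoderBlock(dim=192, num_heads=3)(torch.randn(2, 197, 192))
```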

Figure 3: Multi-scale fusion. (a) All-attention fusion, where all tokens are bundled together without considering the characteristics of the tokens. (b) Class token fusion, where only the CLS tokens are fused, as the CLS token can be considered a global representation of its branch. (c) Pairwise fusion, where tokens at corresponding spatial locations are fused together and the CLS tokens are fused separately. (d) Cross-attention fusion, where the CLS token from one branch and the patch tokens from the other branch are fused together.

It is worth noting that one design of ViT that is very different from CNNs is the CLS token. In CNNs, the final embedding is usually obtained by averaging the features over all spatial locations, while ViT uses the CLS token, which interacts with the patch tokens at every transformer encoder, as the final embedding. Thus, we consider CLS as an agent that summarizes all the patch tokens, and hence the proposed fusion module is designed around CLS to form a dual-path multi-scale ViT.

3.2 Proposed Multi-Scale Vision Transformer

The granularity of the patch size affects the accuracy and complexity of ViT; with a fine-grained patch size, ViT performs better but requires more FLOPs and memory. For example, the ViT with a patch size of 16 outperforms the ViT with a patch size of 32 by 6%, but the former needs 4× more FLOPs. Motivated by this, our proposed approach aims to leverage the advantages of fine-grained patch sizes while keeping the complexity under control. More specifically, we first introduce a dual-branch ViT where each branch operates at a different scale (or patch size in the patch embedding), and then propose a simple yet effective module to fuse information between the branches.

Figure 2 illustrates the network architecture of our proposed Cross-Attention Multi-Scale Vision Transformer (CrossViT). Our model is primarily composed of multi-scale transformer encoders, where each encoder consists of two branches: (1) L-Branch: a large (primary) branch that utilizes a coarse-grained patch size (P_l) with more transformer encoders and a wider embedding dimension; (2) S-Branch: a small (complementary) branch that operates on a fine-grained patch size (P_s) with fewer encoders and a smaller embedding dimension. The two branches are fused multiple times (once at the end of each multi-scale transformer encoder), and the CLS tokens of the two branches at the end are used for prediction. Note that for each token of both branches, we also add a learnable position embedding before the multi-scale transformer encoder for learning position information, as in ViT [12].
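As a rough illustration of the dual-branch tokenization described above, the sketch below builds the two token sequences with their own CLS tokens and position embeddings. The widths follow the CrossViT-S row of Table 1; the 240×240 resize for the 12-pixel patches is our own simplification so that the patch size divides the image evenly, and all names are hypothetical.

```python
# Sketch of the dual-branch tokenization in CrossViT. Values follow the
# CrossViT-S row of Table 1; the 240x240 resize for the 12-pixel patches is an
# assumption made here so the patch size divides the image evenly.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BranchEmbed(nn.Module):
    def __init__(self, img_size, patch_size, dim):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)  # linear patch embedding
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))                            # CLS token
        self.pos = nn.Parameter(torch.zeros(1, 1 + num_patches, dim))              # learnable position embedding

    def forward(self, img):                                  # img: (B, 3, H, W)
        x = self.proj(img).flatten(2).transpose(1, 2)        # (B, N, C) patch tokens
        cls = self.cls.expand(x.shape[0], -1, -1)
        return torch.cat([cls, x], dim=1) + self.pos         # (B, 1 + N, C)

l_branch = BranchEmbed(img_size=224, patch_size=16, dim=384)  # L-branch: coarse patches, wider embedding
s_branch = BranchEmbed(img_size=240, patch_size=12, dim=192)  # S-branch: fine patches, narrower embedding

img = torch.randn(2, 3, 224, 224)
tokens_l = l_branch(img)                                                                  # (2, 197, 384)
tokens_s = s_branch(F.interpolate(img, size=240, mode='bicubic', align_corners=False))    # (2, 401, 192)
```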

Effective feature fusion is the key to learning multi-scale feature representations. We explore four different fusion strategies: three simple heuristic approaches and the proposed cross-attention module, as shown in Figure 3. Below we provide the details of these fusion schemes.

3.3 Multi-Scale Feature Fusion

Let x^i be the token sequence (both patch and CLS tokens) at branch i, where i can be l or s for the large (primary) or small (complementary) branch. x^i_cls and x^i_patch represent the CLS and patch tokens of branch i, respectively.

All-Attention Fusion. A straightforward approach is to simply concatenate all the tokens from both branches without considering the property of each token and then fuse information via the self-attention module, as shown in Fig. 3(a). This approach requires quadratic computation time since all tokens are passed through the self-attention module. The output of the all-attention fusion scheme can be expressed as

y = [f^l(x^l) || f^s(x^s)],   o = y + MSA(LN(y)),   o = [o^l || o^s],   z^i = g^i(o^i)        (2)

where f^i(·) projects a token from one branch to the dimension of the other branch, while g^i(·) back-projects it to its own branch.

Class Token Fusion. The CLS token can be considered as an abstract global feature representation of a branch since it is used as the final embedding for prediction. Thus, a simple approach is to sum the CLS tokens of two branches, as shown in Figure 3(b). This approach is very efficient as only one token needs to be processed. Once CLS tokens are fused, the information will be passed back to patch tokens at the later transformer encoder. More formally, the output of this fusion module can be represented as

z^i = [ g^i( Σ_{j∈{l,s}} f^j(x^j_cls) ) || x^i_patch ]        (3)

where f^i(·) and g^i(·) play the same roles as in Eq. 2.

Pairwise Fusion. Figure 3(c) shows how both branches are fused in pairwise fusion. Since patch tokens correspond to spatial locations of the image, a simple heuristic for fusion is to combine them based on their spatial locations. However, the two branches process patches of different sizes and thus have different numbers of patch tokens. We therefore first interpolate the patch tokens to align the spatial size, and then fuse the patch tokens of both branches in a pairwise manner, while the two CLS tokens are fused separately. The output of pairwise fusion for branches l and s can be expressed as

z^i = [ g^i( Σ_{j∈{l,s}} f^j(x^j_cls) ) || x^i_patch + g^i( f^j(x^j_patch) ) ],   j ≠ i        (4)

where f^i(·) and g^i(·) play the same roles as in Eq. 2, and the patch tokens of the other branch are spatially interpolated to match the number of tokens in branch i before fusion.
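To make the simplest of these heuristics concrete, the sketch below implements class token fusion (Eq. 3), with plain linear layers standing in for the projection f and back-projection g; all names and widths are our own illustration, not the authors' code.

```python
# Sketch of class token fusion (Eq. 3): project each branch's CLS token to a
# common width, sum, and back-project to each branch; patch tokens are untouched.
# Linear layers stand in for f and g; names and widths are illustrative.
import torch
import torch.nn as nn

class ClassTokenFusion(nn.Module):
    def __init__(self, dim_l=384, dim_s=192):
        super().__init__()
        self.f_s = nn.Linear(dim_s, dim_l)   # f^s: project S-branch CLS to the L-branch width
        self.g_s = nn.Linear(dim_l, dim_s)   # g^s: back-project the fused CLS to the S-branch width
        # f^l and g^l are identities here because the common width equals the L-branch width.

    def forward(self, x_l, x_s):             # x_i: (B, 1 + N_i, C_i), CLS token first
        fused = x_l[:, :1] + self.f_s(x_s[:, :1])              # sum of (projected) CLS tokens
        z_l = torch.cat([fused, x_l[:, 1:]], dim=1)            # (B, 1 + N_l, C_l)
        z_s = torch.cat([self.g_s(fused), x_s[:, 1:]], dim=1)  # (B, 1 + N_s, C_s)
        return z_l, z_s

z_l, z_s = ClassTokenFusion()(torch.randn(2, 197, 384), torch.randn(2, 401, 192))
```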

Figure 4: Cross-attention module for the large branch. The CLS token of the large branch (circle) serves as a query token to interact with the patch tokens from the small branch through attention. f^l(·) and g^l(·) are projections to align dimensions. The small branch follows the same procedure but swaps the CLS and patch tokens of the two branches.

Cross-Attention Fusion. Figure 3(d) shows the basic idea of our proposed cross-attention, where the fusion involves the CLS token of one branch and the patch tokens of the other branch. Specifically, in order to fuse multi-scale features more efficiently and effectively, we first utilize the CLS token at each branch as an agent to exchange information with the patch tokens from the other branch and then back-project it to its own branch. Since the CLS token already learns abstract information among all patch tokens in its own branch, interacting with the patch tokens of the other branch helps to include information at a different scale. After the fusion with the other branch's tokens, the CLS token interacts with its own patch tokens again at the next transformer encoder, where it is able to pass the learned information from the other branch to its own patch tokens, enriching the representation of each patch token. In the following, we describe the cross-attention module for the large branch (L-branch); the same procedure is performed for the small branch (S-branch) by simply swapping the indices l and s.

An illustration of the cross-attention module for the large branch is shown in Figure 4. Specifically, for branch l, it first collects the patch tokens from the S-Branch and concatenates its own CLS token to them, as shown in Eq. 5:

x'^l = [ f^l(x^l_cls) || x^s_patch ]        (5)

where f^l(·) is the projection function for dimension alignment. The module then performs cross-attention (CA) between x'^l_cls and x'^l, where the CLS token is the only query, as the information of the patch tokens is fused into the CLS token. Mathematically, the CA can be expressed as

q = x'^l_cls W_q,   k = x'^l W_k,   v = x'^l W_v,
A = softmax( q k^T / sqrt(C/h) ),   CA(x'^l) = A v        (6)

where W_q, W_k, W_v ∈ R^(C×(C/h)) are learnable parameters, and C and h are the embedding dimension and the number of heads. Note that since we only use CLS in the query, the computation and memory complexity of generating the attention map (A) in cross-attention are linear rather than quadratic as in all-attention, making the entire process more efficient. Moreover, as in self-attention, we also use multiple heads in the cross-attention and denote it as multi-head cross-attention (MCA). However, we do not apply a feed-forward network after the cross-attention. Specifically, the output z^l of the cross-attention module for a given branch, with layer normalization and a residual shortcut, is defined as follows:

y^l_cls = f^l(x^l_cls) + MCA( LN( [ f^l(x^l_cls) || x^s_patch ] ) )
z^l = [ g^l(y^l_cls) || x^l_patch ]        (7)

where f^l(·) and g^l(·) are the projection and back-projection functions for dimension alignment, respectively. We empirically show in Section 4.3 that cross-attention achieves the best accuracy compared to the other three simple heuristic approaches while being efficient for multi-scale feature fusion.
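A minimal PyTorch-style sketch of this cross-attention fusion for the large branch (Eq. 5-7) is given below; the projection f^l and back-projection g^l are modeled as linear layers, and all names are our own illustration rather than the released implementation.

```python
# Sketch of cross-attention fusion for the L-branch (Eq. 5-7). Only the CLS token
# acts as a query, so the attention map has shape (B, h, 1, 1 + N_s): linear in
# the number of patch tokens. Names are illustrative, not the authors' code.
import torch
import torch.nn as nn

class CrossAttentionL(nn.Module):
    def __init__(self, dim_l=384, dim_s=192, num_heads=6):
        super().__init__()
        self.f = nn.Linear(dim_l, dim_s)          # f^l: project the L-branch CLS to the S-branch width
        self.g = nn.Linear(dim_s, dim_l)          # g^l: back-project the updated CLS
        self.norm = nn.LayerNorm(dim_s)
        self.h, self.hd = num_heads, dim_s // num_heads
        self.wq, self.wk, self.wv = (nn.Linear(dim_s, dim_s) for _ in range(3))

    def forward(self, x_l, x_s):                  # x_l: (B, 1+N_l, C_l), x_s: (B, 1+N_s, C_s)
        B = x_l.shape[0]
        cls_l = self.f(x_l[:, :1])                                     # (B, 1, C_s)
        tok = self.norm(torch.cat([cls_l, x_s[:, 1:]], dim=1))         # Eq. 5: [f^l(x^l_cls) || x^s_patch]
        q = self.wq(tok[:, :1]).reshape(B, 1, self.h, self.hd).transpose(1, 2)
        k = self.wk(tok).reshape(B, -1, self.h, self.hd).transpose(1, 2)
        v = self.wv(tok).reshape(B, -1, self.h, self.hd).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * self.hd ** -0.5             # (B, h, 1, 1 + N_s)
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, 1, -1)
        y_cls = cls_l + out                                            # residual shortcut (Eq. 7), no FFN
        return torch.cat([self.g(y_cls), x_l[:, 1:]], dim=1)           # z^l: fused CLS + original patch tokens

z_l = CrossAttentionL()(torch.randn(2, 197, 384), torch.randn(2, 401, 192))
```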

4 Experiments

In this section, we conduct extensive experiments to show the effectiveness of our proposed CrossViT over existing methods. First, we examine the advantages of our proposed models over the baseline DeiT in Table 2, and then we compare with several concurrent ViT variants and CNN-based models in Table 3 and Table 4, respectively. Moreover, we also test the transferability of CrossViT on 5 downstream tasks (Table 5). Finally, we perform ablation studies on different fusion schemes in Table 6 and discuss the effect of different parameters of CrossViT in Table 7.

4.1 Experimental Setup

Dataset. We validate the effectiveness of our proposed approach on the ImageNet1K dataset [9], and use the top-1 accuracy on the validation set as the metric to evaluate the performance of a model. ImageNet1K contains 1,000 classes, and the numbers of training and validation images are 1.28 million and 50,000, respectively. We also test the transferability of our approach using several smaller datasets, such as CIFAR10 [21] and CIFAR100 [21].

Training and Evaluation. The original ViT [12] achieves competitive results compared to some of the best CNN models, but only when trained on very large-scale datasets (e.g., ImageNet21k [9] and JFT300M [35]). Nevertheless, DeiT [37] shows that, with the help of a rich set of data augmentation techniques, ViT can be trained on ImageNet1K alone to produce results comparable to CNN models. Therefore, in our experiments, we build our models based on DeiT [37] and apply its default hyper-parameters for training. These data augmentation methods include RandAugment [8], mixup [49] and cutmix [48], as well as random erasing. We also apply drop path for model regularization, but instance repetition [17] is only enabled for CrossViT-18 as it does not improve smaller models.

We train all our models for 300 epochs (30 warm-up epochs) on 32 GPUs with a batch size of 4,096. Other settings include a cosine learning-rate scheduler with linear warm-up, an initial learning rate of 0.004 and a weight decay of 0.05. During evaluation, we resize the shorter side of an image to 256 and take a 224×224 center crop as the input. Moreover, in some cases we also fine-tune our models at a larger resolution (384×384) for a fairer comparison. In that case, bicubic interpolation is applied to adjust the size of the learnt position embedding, and the fine-tuning takes 30 epochs. More details can be found in the supplementary material and the attached code.
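Since changing the input resolution only changes the number of patch tokens, the learnt position embedding has to be resampled before fine-tuning; the sketch below shows one way to do the bicubic resizing mentioned above (the function name and shapes are hypothetical, following the L-branch with patch size 16).

```python
# Sketch of resizing a learned position embedding when fine-tuning at a larger
# input resolution (e.g., 224 -> 384 with patch size 16). The CLS position entry
# is kept as-is; only the patch-grid part is interpolated bicubically.
# Function name and shapes are illustrative.
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_grid=14, new_grid=24):
    # pos_embed: (1, 1 + old_grid**2, C)
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    C = patch_pos.shape[-1]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, C).permute(0, 3, 1, 2)   # (1, C, g, g)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, C)
    return torch.cat([cls_pos, patch_pos], dim=1)                                  # (1, 1 + new_grid**2, C)

# Example: 224/16 = 14 -> 384/16 = 24 patches per side.
new_pe = resize_pos_embed(torch.randn(1, 1 + 14 * 14, 384), old_grid=14, new_grid=24)
```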

Models. Table 1 specifies the architectural configurations of the CrossViT models used in our evaluation. Among these models, CrossViT-Ti, CrossViT-S and CrossViT-B set their large (primary) branches identical to the tiny (DeiT-Ti), small (DeiT-S) and base (DeiT-B) models introduced in DeiT [37], respectively. The other models vary in the expanding ratio of the feed-forward network (FFN), depth and embedding dimensions. In particular, the number at the end of a model name gives the total number of transformer encoders in the large branch. For example, CrossViT-15 has 3 multi-scale encoders, each of which includes 5 regular transformer encoders in the large branch, resulting in a total of 15 transformer encoders.

The original ViT paper shows that a hybrid approach that generates patch tokens from a CNN model such as ResNet-50 can improve the performance of ViT on the ImageNet1K dataset. Here we experiment with a similar idea by substituting the linear patch embedding in ViT with three convolutional layers as the patch tokenizer. These models are differentiated from the others by a † suffix in Table 1.

Model | Patch embedding | Patch size (S / L) | Dimension (S / L) | # of heads | M | r
CrossViT-Ti | Linear | 12 / 16 | 96 / 192 | 3 | 4 | 4
CrossViT-S | Linear | 12 / 16 | 192 / 384 | 6 | 4 | 4
CrossViT-B | Linear | 12 / 16 | 384 / 768 | 12 | 4 | 4
CrossViT-9 | Linear | 12 / 16 | 128 / 256 | 4 | 3 | 3
CrossViT-15 | Linear | 12 / 16 | 192 / 384 | 6 | 5 | 3
CrossViT-18 | Linear | 12 / 16 | 224 / 448 | 7 | 6 | 3
CrossViT-9† | 3 Conv. | 12 / 16 | 128 / 256 | 4 | 3 | 3
CrossViT-15† | 3 Conv. | 12 / 16 | 192 / 384 | 6 | 5 | 3
CrossViT-18† | 3 Conv. | 12 / 16 | 224 / 448 | 7 | 6 | 3

Table 1: Model architectures of CrossViT. K = 3, N = 1 and L = 1 for all models, and the number of heads is the same for both branches. K denotes the number of multi-scale transformer encoders. N, M and L denote the numbers of transformer encoders of the small and large branches and of cross-attention modules in one multi-scale transformer encoder, respectively. r is the expanding ratio of the feed-forward network (FFN) in the transformer encoder. † denotes models that use three convolutional layers for patch embedding instead of linear projection.

Model | Top-1 Acc. (%) | FLOPs (G) | Params (M)
DeiT-Ti | 72.2 | 1.3 | 5.7
CrossViT-Ti | 73.4 (+1.2) | 1.6 | 6.9
CrossViT-9 | 73.9 (+0.5) | 1.8 | 8.6
CrossViT-9† | 77.1 (+3.2) | 2.0 | 8.8
DeiT-S | 79.8 | 4.6 | 22.1
CrossViT-S | 81.0 (+1.2) | 5.6 | 26.7
CrossViT-15 | 81.5 (+0.5) | 5.8 | 27.4
CrossViT-15† | 82.3 (+0.8) | 6.1 | 28.2
DeiT-B | 81.8 | 17.6 | 86.6
CrossViT-B | 82.2 (+0.4) | 21.2 | 104.7
CrossViT-18 | 82.5 (+0.3) | 9.0 | 43.3
CrossViT-18† | 82.8 (+0.3) | 9.5 | 44.3

Table 2: Comparisons with the DeiT baselines on ImageNet1K. The numbers in brackets show the improvement from each change. See Table 1 for model details.

4.2 Main Results

Comparisons with DeiT. DeiT [37] is a better-trained version of ViT; we thus compare our approach with the three baseline models introduced in DeiT, i.e., DeiT-Ti, DeiT-S and DeiT-B. It can be seen from Table 2 that CrossViT improves DeiT-Ti, DeiT-S and DeiT-B by 1.2, 1.2 and 0.4 percentage points respectively when they are used as the primary branch of CrossViT. This clearly demonstrates that our proposed cross-attention is effective in learning multi-scale transformer features for image recognition. By making a few architectural changes (see Table 1), CrossViT further raises the accuracy of the baselines by another 0.3-0.5 points, with only a small increase in FLOPs and model parameters. Surprisingly, the convolution-based embedding provides a significant performance boost to CrossViT-9 (+3.2%) and CrossViT-15 (+0.8%). As the number of transformer encoders increases, the effectiveness of the convolutional layers seems to become weaker, but CrossViT-18† still gains another 0.3% improvement over CrossViT-18. We would like to point out that the concurrent work T2T [47] proposes a different approach, based on token-to-token transformation, to address the limitation of linear patch embedding in vision transformers.

Although the design of CrossViT is primarily intended for accuracy, efficiency is also taken into account. For example, CrossViT-9† and CrossViT-15† incur 30-50% more FLOPs and parameters than the baselines, but their accuracy is considerably improved, by 2.5-5%. On the other hand, CrossViT-18† reduces the FLOPs and parameters almost by half compared to DeiT-B while still being 1.0% more accurate.

Model | Top-1 Acc. (%) | FLOPs (G) | Params (M)
Perceiver [20] (arXiv, 2021-03) | 76.4 | 43.9 | —
DeiT-S [37] (arXiv, 2020-12) | 79.8 | 4.6 | 22.1
CentroidViT-S [44] (arXiv, 2021-02) | 80.9 | 4.7 | 22.3
PVT-S [40] (arXiv, 2021-02) | 79.8 | 3.8 | 24.5
PVT-M [40] (arXiv, 2021-02) | 81.2 | 6.7 | 44.2
T2T-ViT-14 [47] (arXiv, 2021-01) | 80.7 | 6.1 | 21.5
TNT-S [15] (arXiv, 2021-02) | 81.3 | 5.2 | 23.8
CrossViT-15 (Ours) | 81.5 | 5.8 | 27.4
CrossViT-15† (Ours) | 82.3 | 6.1 | 28.2
ViT-B@384 [12] (ICLR, 2021) | 77.9 | 17.6 | 86.6
DeiT-B [37] (arXiv, 2020-12) | 81.8 | 17.6 | 86.6
PVT-L [40] (arXiv, 2021-02) | 81.7 | 9.8 | 61.4
T2T-ViT-19 [47] (arXiv, 2021-01) | 81.4 | 9.8 | 39.0
T2T-ViT-24 [47] (arXiv, 2021-01) | 82.2 | 15.0 | 64.1
TNT-B [15] (arXiv, 2021-02) | 82.8 | 14.1 | 65.6
CrossViT-18 (Ours) | 82.5 | 9.0 | 43.3
CrossViT-18† (Ours) | 82.8 | 9.5 | 44.3
(FLOPs for some models are recomputed using our tools.)

Table 3: Comparisons with recent transformer-based models on ImageNet1K. All models are trained using only the ImageNet1K dataset. Numbers are taken from the most recent versions of the respective papers as of our submission date.

Comparisons with SOTA Transformers. We further compare our approach with some very recent concurrent works on vision transformers. They all improve the original ViT [12] with respect to efficiency, accuracy or both. Note that all of them were newly posted on arXiv and not yet published at the time our paper was submitted. As shown in Table 3, CrossViT-15 outperforms the small models of all the other approaches with comparable FLOPs and parameters. Interestingly, when compared with ViT-B, CrossViT-18† significantly outperforms it by 4.9% (77.9% vs 82.8%) in accuracy while requiring about 50% fewer FLOPs and parameters. Furthermore, CrossViT-18† performs as well as TNT-B and better than the others, while also having fewer FLOPs and parameters. Our approach is consistently better than T2T-ViT [47] and PVT [40] in terms of accuracy and FLOPs, showing the efficacy of multi-scale features in vision transformers.

Model | Top-1 Acc. (%) | FLOPs (G) | Throughput (images/s) | Params (M)
ResNet-101 [16] | 76.7 | 7.80 | 678.1 | 44.6
ResNet-152 [16] | 77.0 | 11.5 | 444.5 | 60.2
ResNeXt-101-32×4d [45] | 78.8 | 8.0 | 477.1 | 44.2
ResNeXt-101-64×4d [45] | 79.6 | 15.5 | 289.1 | 83.5
SEResNet-101 [19] | 77.6 | 7.8 | 564.0 | 49.3
SEResNet-152 [19] | 78.4 | 11.5 | 391.7 | 66.8
SENet-154 [19] | 81.3 | 20.7 | 201.3 | 115.1
ECA-Net101 [39] | 78.7 | 7.4 | 591.0 | 42.5
ECA-Net152 [39] | 78.9 | 10.9 | 427.5 | 59.1
RegNetY-8GF [31] | 79.9 | 8.0 | 557.0 | 39.2
RegNetY-12GF [31] | 80.3 | 12.1 | 438.7 | 51.8
RegNetY-16GF [31] | 80.4 | 15.9 | 336.0 | 83.6
RegNetY-32GF [31] | 81.0 | 32.3 | 208.2 | 145.0
EfficientNet-B4@380 [36] | 82.9 | 4.2 | 356.4 | 19
EfficientNet-B5@456 [36] | 83.7 | 9.9 | 168.7 | 30
EfficientNet-B6@528 [36] | 84.0 | 19.0 | 99.8 | 43
EfficientNet-B7@600 [36] | 84.3 | 37.0 | 55.1 | 66
CrossViT-15 | 81.5 | 5.8 | 640.3 | 27.4
CrossViT-15† | 82.3 | 6.1 | 626.3 | 28.2
CrossViT-15†@384 | 83.5 | 21.4 | 158.3 | 28.5
CrossViT-18 | 82.5 | 9.03 | 429.9 | 43.3
CrossViT-18† | 82.8 | 9.5 | 417.9 | 44.3
CrossViT-18†@384 | 83.9 | 32.4 | 112.1 | 44.6
CrossViT-18†@480 | 84.1 | 56.6 | 56.8 | 44.9

Table 4: Comparisons with CNN models on ImageNet1K. Models are evaluated at 224×224 if not specified. The inference throughput is measured with a batch size of 64 on an Nvidia Tesla V100 GPU with cuDNN 8.0. We report the average speed over 100 iterations.

Comparisons with CNN-based Models. CNN-based models are dominant in computer vision applications. In this experiment, we compare our proposed approach with some of the best CNN models, including both hand-crafted (e.g., ResNet [16]) and search-based ones (e.g., EfficientNet [36]). In addition to accuracy, FLOPs and model parameters, run-time speed is measured for all the models and shown as inference throughput (images/second) in Table 4. First, when compared to the ResNet family, including ResNet [16], ResNeXt [45], SENet [19], ECA-ResNet [39] and RegNet [31], CrossViT-15 outperforms all of them in accuracy while being smaller and running more efficiently (except ResNet-101, which is slightly faster). In addition, our best models, such as CrossViT-15† and CrossViT-18†, when evaluated at higher image resolution, are encouragingly competitive against EfficientNet [36] with regard to accuracy, throughput and parameters. We expect neural architecture search (NAS) [51] to further close the performance gap between our approach and EfficientNet.

Transfer Learning. Although our models achieve better accuracy on ImageNet1K than the baselines (Table 2), it is crucial to check the generalization of the models by evaluating transfer performance on tasks with fewer samples. We validate this by performing transfer learning on 5 image classification tasks: CIFAR10 [21], CIFAR100 [21], Pet [28], CropDisease [24], and ChestXRay8 [42]. While the first four datasets contain natural images, ChestXRay8 consists of medical images. We fine-tune the whole pretrained models for 1,000 epochs with a batch size of 768, a learning rate of 0.01, an SGD optimizer, a weight decay of 0.0001, and the same data augmentation used for training on ImageNet1K. Table 5 shows the results. While better on ImageNet1K, our models are on par with the DeiT models on all the downstream classification tasks. This result indicates that our models still generalize well rather than merely fitting ImageNet1K.

Model | CIFAR10 | CIFAR100 | Pet | CropDiseases | ChestXRay8
DeiT-S [37] | 99.15 | 90.89 | 94.93 | 99.96 | 55.39
DeiT-B [37] | 99.10 | 90.80 | 94.39 | 99.96 | 55.77
CrossViT-15 | 99.00 | 90.77 | 94.55 | 99.97 | 55.89
CrossViT-18 | 99.11 | 91.36 | 95.07 | 99.97 | 55.94
(Some entries are numbers reported in the original paper.)

Table 5: Transfer learning performance. Our CrossViT models are very competitive with the recent DeiT [37] models on all the downstream classification tasks.

4.3 Ablation Studies

In this section, we first compare the different fusion approaches (Section 3.3), and then analyze the effects of different parameters of our architecture design, including the patch sizes, the channel width and depth of the small branch, and the number of cross-attention modules. At the end, we also validate that the proposed cross-attention can cooperate with other concurrent works for better accuracy.

Comparison of Different Fusion Schemes. Table 6 shows the performance of different fusion schemes, including (I) no fusion, (II) all-attention, (III) class token fusion, (IV) pairwise fusion and (V) the proposed cross-attention fusion. Among all the compared strategies, the proposed cross-attention fusion achieves the best accuracy with only a minor increase in FLOPs and parameters. Surprisingly, despite using additional self-attention to combine information between the two branches, all-attention fails to achieve better performance than the simple class token fusion. In the other fusion strategies the primary L-branch dominates the accuracy and the effect of the complementary S-branch is diminished, whereas with the proposed cross-attention fusion both branches reach a reasonable accuracy on their own and their ensemble is the best, suggesting that the two branches learn different features for different images.

Fusion | Top-1 Acc. (%) | FLOPs (G) | Params (M) | L-Branch Acc. (%) | S-Branch Acc. (%)
None | 80.2 | 5.3 | 23.7 | 80.2 | 0.1
All-Attention | 80.0 | 7.6 | 27.7 | 79.9 | 0.5
Class Token | 80.3 | 5.4 | 24.2 | 80.6 | 7.6
Pairwise | 80.3 | 5.5 | 24.2 | 80.3 | 7.3
Cross-Attention | 81.0 | 5.6 | 26.7 | 68.1 | 47.2

Table 6: Ablation study with different fusion schemes on ImageNet1K. All models are based on CrossViT-S. Single-branch accuracy is computed using the CLS token from one branch only.

Effect of Patch Sizes. We perform experiments to understand the effect of patch sizes in CrossViT by testing two pairs of patch sizes, (8, 16) and (12, 16), and observe that (12, 16) achieves better accuracy with fewer FLOPs, as shown in Table 7 (A). Intuitively, (8, 16) should give better results since a patch size of 8 provides more fine-grained features; however, it is not as good as (12, 16) because the large difference in granularity between the two branches makes it difficult to learn the features smoothly. For the pair (8, 16), the numbers of patch tokens differ by a factor of 4, while the ratio is only about 2 for the model with (12, 16).

Channel Width and Depth in S-branch. Although our cross-attention is designed to be lightweight, we also check the performance of a more complex S-branch, as shown in Table 7 (B and C). Both models increase FLOPs and parameters without any improvement in accuracy. We attribute this to the fact that the L-branch plays the main role in extracting features while the S-branch only provides additional information; thus, a lightweight S-branch is sufficient.

Depth of Cross-Attention and Number of Multi-Scale Transformer Encoders. To increase the frequency of fusion across the two branches, we can either stack more cross-attention modules (larger L) or stack more multi-scale transformer encoders (larger K, while reducing M to keep the same total depth of the model). Results are shown in Table 7 (D and E). With CrossViT-S as the baseline, fusing the branches too frequently does not provide any performance improvement but introduces more FLOPs and parameters. This is because the patch tokens from the other branch are left untouched, and the benefit of stacking more than one cross-attention module is small since cross-attention is a linear operation without any non-linearity. Likewise, using more multi-scale transformer encoders does not help either, similar to the case of increasing the capacity of the S-branch.

Cooperation with Concurrent Works. Our proposed cross-attention is also capable of cooperating with other concurrent ViT variants. We consider T2T-ViT [47] as a case study and use the T2T module to replace the linear projection of the patch embedding in both branches of CrossViT-18. CrossViT-18+T2T achieves a top-1 accuracy of 83.0% on ImageNet1K, an additional 0.5% improvement over CrossViT-18. This shows that our proposed cross-attention can also learn multi-scale features on top of other ViT variants.

Additional results and discussions are included in the supplementary material.

Model | Patch size (S / L) | Dimension (S / L) | K | N | M | L | Top-1 Acc. (%) | FLOPs (G) | Params (M)
CrossViT-S | 12 / 16 | 192 / 384 | 3 | 1 | 4 | 1 | 81.0 | 5.6 | 26.7
A | 8 / 16 | 192 / 384 | 3 | 1 | 4 | 1 | 80.8 | 6.7 | 26.7
B | 12 / 16 | 384 / 384 | 3 | 1 | 4 | 1 | 80.1 | 7.7 | 31.4
C | 12 / 16 | 192 / 384 | 3 | 2 | 4 | 1 | 80.7 | 6.3 | 28.0
D | 12 / 16 | 192 / 384 | 3 | 1 | 4 | 2 | 81.0 | 5.6 | 28.9
E | 12 / 16 | 192 / 384 | 6 | 1 | 2 | 1 | 80.9 | 6.6 | 31.1

Table 7: Ablation study with different architecture parameters on ImageNet1K. Rows A-E each change one or two architectural parameters of CrossViT-S (patch size, dimension, K, N, M or L).

5 Conclusion

In this paper, we present CrossViT, a dual-branch vision transformer for learning multi-scale features to improve the recognition accuracy of image classification. To effectively combine image patch tokens of different scales, we further develop a fusion method based on cross-attention, which exchanges information between the two branches efficiently in linear time. With extensive experiments, we demonstrate that our proposed model performs better than or on par with several concurrent works on vision transformers, in addition to efficient CNN models. While our current work only scratches the surface of multi-scale vision transformers for image classification, we anticipate more future work on efficient multi-scale transformers for other vision applications, including object detection, semantic segmentation, and video action recognition.

References

  • [1] E. H. Adelson, C. H. Anderson, J. R. Bergen, P. J. Burt, and J. M. Ogden (1984) Pyramid methods in image processing. RCA engineer 29 (6), pp. 33–41. Cited by: §2.
  • [2] I. Bello, B. Zoph, A. Vaswani, J. Shlens, and Q. V. Le (2019) Attention augmented convolutional networks. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3286–3295. Cited by: §1, §2.
  • [3] I. Bello (2021) LambdaNetworks: modeling long-range interactions without attention. In International Conference on Learning Representations, External Links: Link Cited by: §2.
  • [4] Z. Cai, Q. Fan, R. S. Feris, and N. Vasconcelos (2016) A unified multi-scale deep convolutional neural network for fast object detection. In European conference on computer vision, pp. 354–370. Cited by: §1, §2.
  • [5] C. Chen, Q. Fan, N. Mallinar, T. Sercu, and R. Feris (2018) Big-little net: an efficient multi-scale feature representation for visual and speech recognition. arXiv preprint arXiv:1807.03848. Cited by: §1, §2.
  • [6] Y. Chen, H. Fan, B. Xu, Z. Yan, Y. Kalantidis, M. Rohrbach, S. Yan, and J. Feng (2019) Drop an octave: reducing spatial redundancy in convolutional neural networks with octave convolution. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3435–3444. Cited by: §1, §2.
  • [7] B. Cheng, B. Xiao, J. Wang, H. Shi, T. S. Huang, and L. Zhang (2020-06) HigherHRNet: scale-aware representation learning for bottom-up human pose estimation. In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §1.
  • [8] E. D. Cubuk, B. Zoph, J. Shlens, and Q. Le (2020) RandAugment: Practical Automated Data Augmentation with a Reduced Search Space. In Advances in Neural Information Processing Systems, H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin (Eds.), pp. 18613–18624. Cited by: §4.1.
  • [9] J. Deng, W. Dong, R. Socher, L. Li, K. Li, and L. Fei-Fei (2009) Imagenet: a large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248–255. Cited by: §1, §4.1, §4.1.
  • [10] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) Bert: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805. Cited by: §1.
  • [11] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019-06) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), Minneapolis, Minnesota, pp. 4171–4186. External Links: Link, Document Cited by: §3.1.
  • [12] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby (2021) An image is worth 16x16 words: transformers for image recognition at scale. In International Conference on Learning Representations, External Links: Link Cited by: Figure 1, 2nd item, §1, §2, §2, §3.1, §3.2, §3, §4.1, §4.2, Table 3.
  • [13] Q. Fan, C. R. Chen, H. Kuehne, M. Pistoia, and D. Cox (2019) More is less: learning efficient video representations by big-little network and depthwise temporal aggregation. In Advances in Neural Information Processing Systems, pp. 2261–2270. Cited by: §2.
  • [14] C. Feichtenhofer, H. Fan, J. Malik, and K. He (2019) Slowfast networks for video recognition. In Proceedings of the IEEE International Conference on Computer Vision, pp. 6202–6211. Cited by: §2.
  • [15] K. Han, A. Xiao, E. Wu, J. Guo, C. Xu, and Y. Wang (2021) Transformer in transformer. arXiv preprint arXiv:2103.00112. Cited by: §1, Table 3.
  • [16] K. He, X. Zhang, S. Ren, and J. Sun (2016-06) Deep Residual Learning for Image Recognition. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §1, §4.2, Table 4.
  • [17] E. Hoffer, T. Ben-Nun, I. Hubara, N. Giladi, T. Hoefler, and D. Soudry (2020-06) Augment your batch: improving generalization through instance repetition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §4.1.
  • [18] H. Hu, Z. Zhang, Z. Xie, and S. Lin (2019) Local relation networks for image recognition. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3464–3473. Cited by: §2.
  • [19] J. Hu, L. Shen, and G. Sun (2018) Squeeze-and-excitation networks. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, Vol. , pp. 7132–7141. External Links: Document Cited by: §2, §4.2, Table 4.
  • [20] A. Jaegle, F. Gimeno, A. Brock, A. Zisserman, O. Vinyals, and J. Carreira (2021) Perceiver: general perception with iterative attention. arXiv preprint arXiv:2103.03206. Cited by: §1, §2, Table 3.
  • [21] A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Cited by: §4.1, §4.2.
  • [22] X. Li, W. Wang, X. Hu, and J. Yang (2019-06) Selective kernel networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §1.
  • [23] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie (2017) Feature pyramid networks for object detection. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2117–2125. Cited by: §1, §2.
  • [24] S. P. Mohanty, D. P. Hughes, and M. Salathé (2016) Using deep learning for image-based plant disease detection. Frontiers in Plant Science 7, pp. 1419. Cited by: §4.2.
  • [25] S. Nah, T. Hyun Kim, and K. Mu Lee (2017-07) Deep multi-scale convolutional neural network for dynamic scene deblurring. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §1.
  • [26] A. Newell, K. Yang, and J. Deng (2016) Stacked hourglass networks for human pose estimation. In Computer Vision – ECCV 2016, B. Leibe, J. Matas, N. Sebe, and M. Welling (Eds.), Cham, pp. 483–499. Cited by: §1.
  • [27] A. Newell, K. Yang, and J. Deng (2016) Stacked hourglass networks for human pose estimation. In European conference on computer vision, pp. 483–499. Cited by: §2.
  • [28] O. M. Parkhi, A. Vedaldi, A. Zisserman, and C. V. Jawahar (2012) Cats and dogs. In IEEE Conference on Computer Vision and Pattern Recognition, Cited by: §4.2.
  • [29] M. Pedersoli, A. Vedaldi, J. Gonzalez, and X. Roca (2015) A coarse-to-fine approach for fast deformable object detection. Pattern Recognition 48 (5), pp. 1844–1853. Cited by: §2.
  • [30] P. Perona and J. Malik (1990) Scale-space and edge detection using anisotropic diffusion. IEEE Transactions on pattern analysis and machine intelligence 12 (7), pp. 629–639. Cited by: §2.
  • [31] I. Radosavovic, R. P. Kosaraju, R. Girshick, K. He, and P. Dollar (2020-06) Designing network design spaces. In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §4.2, Table 4.
  • [32] P. Ramachandran, N. Parmar, A. Vaswani, I. Bello, A. Levskaya, and J. Shlens (2019) Stand-Alone Self-Attention in Vision Models. In Advances in Neural Information Processing Systems, H. Wallach, H. Larochelle, A. Beygelzimer, F. d. A. e. Buc, E. Fox, and R. Garnett (Eds.), Cited by: §1, §2.
  • [33] P. Ramachandran, N. Parmar, A. Vaswani, I. Bello, A. Levskaya, and J. Shlens (2019) Stand-alone self-attention in vision models. arXiv preprint arXiv:1906.05909. Cited by: §2.
  • [34] A. Srinivas, T. Lin, N. Parmar, J. Shlens, P. Abbeel, and A. Vaswani (2021) Bottleneck transformers for visual recognition. arXiv preprint arXiv:2101.11605. Cited by: §1, §2.
  • [35] C. Sun, A. Shrivastava, S. Singh, and A. Gupta (2017) Revisiting unreasonable effectiveness of data in deep learning era. In 2017 IEEE International Conference on Computer Vision (ICCV), Vol. , pp. 843–852. External Links: Document Cited by: §1, §4.1.
  • [36] M. Tan and Q. Le (2019-06) EfficientNet: rethinking model scaling for convolutional neural networks. In Proceedings of the 36th International Conference on Machine Learning, K. Chaudhuri and R. Salakhutdinov (Eds.), Long Beach, California, USA, pp. 6105–6114. Cited by: 2nd item, §1, §4.1, §4.2, Table 4.
  • [37] H. Touvron, M. Cord, M. Douze, F. Massa, A. Sablayrolles, and H. Jégou (2020) Training data-efficient image transformers & distillation through attention. arXiv preprint arXiv:2012.12877. Cited by: Figure 1, §1, §1, §2, §4.1, §4.1, §4.2, Table 3, Table 5.
  • [38] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, u. Kaiser, and I. Polosukhin (2017) Attention is All you Need. In Advances in Neural Information Processing Systems, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), Cited by: §1, §2.
  • [39] Q. Wang, B. Wu, P. Zhu, P. Li, W. Zuo, and Q. Hu (2020) ECA-net: efficient channel attention for deep convolutional neural networks. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2, §4.2, Table 4.
  • [40] W. Wang, E. Xie, X. Li, D. Fan, K. Song, D. Liang, T. Lu, P. Luo, and L. Shao (2021) Pyramid vision transformer: a versatile backbone for dense prediction without convolutions. External Links: 2102.12122 Cited by: §1, §2, §4.2, Table 3.
  • [41] X. Wang, R. Girshick, A. Gupta, and K. He (2018) Non-local neural networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 7794–7803. Cited by: §2.
  • [42] X. Wang, Y. Peng, L. Lu, Z. Lu, M. Bagheri, and R. M. Summers (2017) Chestx-ray8: hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2097–2106. Cited by: §4.2.
  • [43] S. Woo, J. Park, J. Lee, and I. S. Kweon (2018-09) CBAM: convolutional block attention module. In Proceedings of the European Conference on Computer Vision (ECCV), Cited by: §2.
  • [44] L. Wu, X. Liu, and Q. Liu (2021) Centroid transformers: learning to abstract with attention. arXiv preprint arXiv:2102.08606. Cited by: §2, Table 3.
  • [45] S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He (2017-07) Aggregated Residual Transformations for Deep Neural Networks. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §4.2, Table 4.
  • [46] S. Yang and D. Ramanan (2015) Multi-scale recognition with dag-cnns. In Proceedings of the IEEE international conference on computer vision, pp. 1215–1223. Cited by: §2.
  • [47] L. Yuan, Y. Chen, T. Wang, W. Yu, Y. Shi, F. E. Tay, J. Feng, and S. Yan (2021) Tokens-to-token vit: training vision transformers from scratch on imagenet. External Links: 2101.11986 Cited by: Figure 1, §1, §2, §4.2, §4.2, §4.3, Table 3.
  • [48] S. Yun, D. Han, S. J. Oh, S. Chun, J. Choe, and Y. Yoo (2019-10) CutMix: regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), Cited by: §4.1.
  • [49] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz (2018) Mixup: beyond empirical risk minimization. In International Conference on Learning Representations, External Links: Link Cited by: §4.1.
  • [50] H. Zhao, J. Jia, and V. Koltun (2020-06) Exploring self-attention for image recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §1, §2.
  • [51] B. Zoph, V. Vasudevan, J. Shlens, and Q. V. Le (2018-06) Learning transferable architectures for scalable image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §4.2.

Appendix A More Comparisons and Analysis

To further examine the advantages of the proposed CrossViT, we trained models whose architectures are identical to the L-branch (primary branch) of our models; e.g., DeiT-9 is the baseline for CrossViT-9. As shown in Table 8, the proposed cross-attention fusion consistently improves the baseline vision transformers regardless of their primary branches and patch embeddings, suggesting that the proposed multi-scale fusion is effective for different vision transformers.

Model | Top-1 Acc. (%) | FLOPs (G) | Params (M)
DeiT-9 | 72.9 | 1.4 | 6.4
CrossViT-9 | 73.9 | 1.8 | 8.6
DeiT-9† | 75.6 | 1.5 | 6.6
CrossViT-9† | 77.1 | 2.0 | 8.8
DeiT-15 | 80.8 | 4.9 | 22.9
CrossViT-15 | 81.5 | 5.8 | 27.4
DeiT-15† | 81.7 | 5.1 | 23.5
CrossViT-15† | 82.3 | 6.1 | 28.2
DeiT-18 | 81.4 | 7.8 | 37.1
CrossViT-18 | 82.5 | 9.0 | 43.3
DeiT-18† | 81.2 | 8.1 | 37.9
CrossViT-18† | 82.8 | 9.5 | 44.3

Table 8: Comparisons with various baselines on ImageNet1K. See Table 1 of the main paper for model details. † denotes the models using three convolutional layers for patch embedding instead of linear projection.

Setting | Main Results | Transfer
Batch size | 4,096 | 768
Epochs | 300 | 1,000
Optimizer | AdamW | SGD
Weight decay | 0.05 | 1e-4
Learning-rate scheduler (initial LR) | Cosine (0.004) | Cosine (0.01)
Warmup epochs | 30 | 5
Warmup scheduler (initial LR) | Linear (1e-6) | Linear (1e-6)
Data augmentation | RandAugment (m=9, n=2) | RandAugment (m=9, n=2)
Mixup (α) | 0.8 | 0.8
CutMix (α) | 1.0 | 1.0
Random erasing | 0.25 | 0.0
Instance repetition | 3 (only for CrossViT-18) | —
Drop-path | 0.1 | 0.0
Label smoothing | 0.1 | 0.1

Table 9: Details of training settings.
Figure 5: Feature visualization of CrossViT-S. Features of patch tokens of both branches from the last multi-scale transformer encoder are shown. (36 random channels are selected.)

Figure 5 visualizes the features of both branches from the last multi-scale transformer encoder of CrossViT. The proposed cross-attention learns different features in the two branches: the small branch generates more low-level features because it has only three transformer encoders, while the features of the large branch are more abstract. The two branches complement each other and hence the ensemble results are better.