Visual Transformers: Token-based Image Representation and Processing for Computer Vision

06/05/2020 ∙ by Bichen Wu, et al. ∙ Facebook ∙ UC Berkeley

Computer vision has achieved great success using standardized image representations – pixel arrays – and the corresponding deep learning operators – convolutions. In this work, we challenge this paradigm: we instead (a) represent images as a set of visual tokens and (b) apply visual transformers to find relationships between visual semantic concepts. Given an input image, we dynamically extract a set of visual tokens from the image to obtain a compact representation for high-level semantics. We then use visual transformers to operate over the visual tokens and densely model relationships between them. We find that this paradigm of token-based image representation and processing drastically outperforms its convolutional counterparts on image classification and semantic segmentation. To demonstrate the power of this approach on ImageNet classification, we use ResNet as a convenient baseline and use visual transformers to replace the last stage of convolutions. This reduces the stage's MACs by up to 6.9x, while attaining up to 4.53 points higher top-1 accuracy. For semantic segmentation, we use a visual-transformer-based FPN (VT-FPN) module to replace a convolution-based FPN, using 6.5x fewer MACs while achieving up to 0.35 points higher mIoU on LIP and COCO-stuff.


1 Introduction

In modern digital imaging and computer vision, visual information is captured and processed as arrays of pixels: a 2D image is represented as a 3D array. The de facto deep learning operators for computer vision, convolutional filters, accept and output such 3D pixel arrays. Although convolving over pixel arrays has achieved great success, there are critical limitations:

1) Not all pixels are created equal: for the target task, each pixel has varying importance. For example, image classification models should prioritize foreground objects over the background, and semantic segmentation models should prioritize minuscule pedestrians over disproportionately large swaths of sky, road, and vegetation. Despite this, convolutions uniformly distribute computation over all pixels regardless of importance, leading to redundancy in both computation and representation.

2) Convolutions struggle to capture long-range interactions: each convolutional filter is constrained to operate on a small region (e.g., 3×3), since the computational cost grows quadratically with kernel size. However, long-range interactions between semantic concepts are vital, as evidenced by the large body of context-driven work that relies on operations such as dilated convolutions and global average pooling. These, however, are only workarounds for the underlying pixel-convolution issue.

3) Convolutions are not efficient for sparse, high-level semantic concepts: low-level visual features such as corners and edges are densely distributed across images; as a result, convolutions are apt for image processing early in a neural network. However, as the receptive field grows deeper in the network, the number of potential patterns grows exponentially, and semantic concepts become increasingly sparse: each concept appears only in small regions of a few images. Later convolutions become correspondingly inefficient.

We propose a new paradigm that overcomes the above limitations and achieves better accuracy with lower computational cost on vision tasks. Our intuition is to replace pixel arrays with language-like descriptors: a sentence of a few words (i.e., tokens) suffices to describe an image, and unlike pixel arrays, such token-based representations are compact. We thus propose visual transformers to (a) re-represent images with visual tokens and (b) process the tokens with transformers (i.e., self-attention). The visual transformer is illustrated in Figure 1: we first convolve over the input image and then re-represent the resulting feature map as a small number of visual, semantic tokens. We then apply a cascade of transformers, the self-attention module widely used in natural language processing vaswani2017attention , to process the visual tokens, capture their interactions, and compute output tokens. These tokens are used directly for image-level prediction tasks such as image classification. For pixel-level prediction tasks such as semantic segmentation, we project the output tokens back to the feature map and use the aggregated feature map.

To demonstrate this idea, we use visual transformers to replace convolution-based modules in baseline vision models. Experiments show that visual transformers achieve higher accuracy with lower computational cost on ImageNet classification deng2009imagenet and show similar advantages on semantic segmentation on the COCO-Stuff caesar2018cvpr and Look-Into-Person (LIP) liang2018look datasets. On ImageNet, we use ResNet he2016deep as a convenient baseline to enable comparison and replace its last stage with visual transformers, saving up to 6.9x MACs for the stage and achieving up to 4.53 points higher top-1 accuracy. ResNet-34 with visual transformers can outperform ResNet-101 while using 2.5x fewer MACs. For semantic segmentation, we use visual transformers to re-design the feature pyramid network (FPN), saving 6.4x MACs while achieving higher mIoU on the LIP and COCO-Stuff datasets.

2 Related work

Attention in vision models: Attention is an effective mechanism recently added to the deep learning toolbox. It has been widely used in natural language processing vaswani2017attention ; devlin2018bert and is gaining popularity in computer vision hu2018squeeze ; hu2018gather ; woo2018cbam ; wu2018squeezesegv2 ; xu2020squeezesegv3 ; wang2018non ; parmar2018image ; hu2018relation ; hu2019local ; bello2019attention ; zhao2020exploring ; ramachandran2019stand . Attention was first used to modulate the feature map: attention values computed from the input are multiplied with the feature map, as in woo2018cbam ; hu2018squeeze ; hu2018gather ; wu2018squeezesegv2 . Later work xu2020squeezesegv3 ; su2019pixel ; wang2018depth interprets this "modulation" as a way to make convolutions spatially adaptive and content-aware. In wang2018non , Wang et al. introduced self-attention, also called the non-local operator, to video understanding and used it to capture long-range interactions. However, self-attention in computer vision is computationally expensive, since its complexity grows quadratically with the number of pixels. bello2019attention uses self-attention to augment convolutions and reduces the compute cost by using small channel sizes for attention. ramachandran2019stand ; parmar2018image ; cordonnier2019relationship ; zhao2020exploring ; hu2019local , on the other hand, restrict the receptive field of self-attention and apply it in a convolutional manner. Since ramachandran2019stand , self-attention has also been used as a stand-alone building block for vision models. Our work differs from all of the above in that we are the first to challenge the paradigm of pixel-array-based image representation.

Efficient vision models: Many recent research efforts focus on building vision models that achieve better performance at lower computational cost. Early work iandola2016squeezenet ; gholami2018squeezenext achieves this by carefully choosing convolutional layers with smaller kernel and channel sizes; subsequent work howard2017mobilenets ; sandler2018mobilenetv2 ; howard2019searching ; zhang2018shufflenet ; ma2018shufflenet ; wu2018shift explores new variants of the convolution operator. More recently, neural architecture search wu2019fbnet ; dai2019chamnet ; wan2020fbnetv2 ; tan2019efficientnet ; tan2019mnasnet has been used to optimize a network's performance and efficiency within a predefined search space consisting of existing convolution-based operators. After years of progress, it is becoming increasingly hard to further improve vision models under the current paradigm. Our work takes a new angle on this problem and aims to eliminate the redundancy in image representation and processing based on pixel arrays and convolutions.

3 Visual transformer

In this section, we introduce the details of the visual transformer. The overall diagram of a visual transformer is illustrated in Figure 1. For a given image, we first use convolutions to extract a feature map, since convolutions are very effective at processing low-level visual features. We then feed the feature map into stacked visual transformers. Each visual transformer consists of three major components: a tokenizer, a transformer, and a projector. The tokenizer extracts a small number of visual tokens from the feature map, the transformer captures the interactions between the visual tokens and computes output tokens, and the projector fuses the output tokens back into the feature map. Throughout the visual transformers, we keep both visual tokens and feature maps as outputs, since visual tokens capture the high-level semantics of the image while the feature map preserves pixel-level details. Because computation is performed on only a small number of visual tokens, the computational cost of the visual transformer is very low. At the end of the network, we use the visual tokens for image-level prediction tasks, such as image classification, and the accumulated feature map for pixel-level prediction tasks, such as semantic segmentation. Pseudo-code for the key components of the visual transformer can be found in Appendix A.

Figure 1: Diagram of a visual transformer.

3.1 Extracting visual tokens

3.1.1 Static tokenization

First, we discuss how to extract visual tokens from a feature map, as illustrated in Figure 2(a). Each visual token represents a semantic concept in the image, and we compute visual tokens by spatially aggregating the feature map. Formally, a feature map can be represented by $\mathbf{X} \in \mathbb{R}^{H \times W \times C}$, where $H, W$ are the height and width of the feature map and $C$ is the channel size. For notational convenience, we denote by $\bar{\mathbf{X}} \in \mathbb{R}^{HW \times C}$ the matrix obtained from $\mathbf{X}$ by merging the two spatial dimensions into one. Visual tokens are represented by $\mathbf{T} \in \mathbb{R}^{L \times C_T}$, where $L$ is the number of visual tokens ($L \ll HW$) and $C_T$ is the channel size of a visual token. We compute visual tokens as

$$\mathbf{T} = \mathbf{A}^T \bar{\mathbf{X}} \mathbf{W}_V, \qquad \mathbf{A} = \operatorname{softmax}_{HW}\!\left(\bar{\mathbf{X}} \mathbf{W}_A\right) \in \mathbb{R}^{HW \times L}. \tag{1}$$

$\mathbf{W}_V$ is a learnable weight that converts the feature map $\bar{\mathbf{X}}$ into the values to be aggregated into $\mathbf{T}$. This step can be implemented as a point-wise 2D convolution. Visual tokens can thus be seen as a weighted average of the feature map: $\mathbf{A}$ is a normalized token coefficient matrix (softmax over the spatial dimension), and the value of $\mathbf{A}[p, l]$ determines the contribution of the $p$-th pixel to the $l$-th token $\mathbf{T}[l]$. $\mathbf{W}_A \in \mathbb{R}^{C \times L}$ is the weight we use to compute $\mathbf{A}$; it can be seen as a convolutional filter that divides the feature map into regions corresponding to different semantic concepts. This is illustrated in Figure 2(a). The visualization of the token coefficients on real images can be found in Figure 5.
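To make the shapes in equation (1) concrete, the following is a minimal single-image sketch of the static tokenizer in plain PyTorch. It omits the multi-head attention, normalization, and position encoding used in the full model (see Appendix A), and the tensor sizes and weights here are illustrative assumptions only.

import torch

def tokenize(X_bar, W_A, W_V):
    # X_bar: (HW, C) reshaped feature map; W_A: (C, L); W_V: (C, CT)
    A = torch.softmax(X_bar @ W_A, dim=0)  # (HW, L) token coefficients,
                                           # normalized over the spatial dim
    return A.T @ (X_bar @ W_V)             # (L, CT) visual tokens

X_bar = torch.randn(14 * 14, 256)          # a 14x14 feature map with C=256
W_A, W_V = torch.randn(256, 8), torch.randn(256, 1024)
tokens = tokenize(X_bar, W_A, W_V)         # 8 tokens with CT=1024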

3.1.2 Position encoding

Visual tokens are computed as a weighted average of the original feature map, but equation (1) does not retain where in the image each pixel's contribution comes from, so the spatial position of each visual token is lost. Intuitively, position information is important for understanding image semantics, and this information is recorded in the token coefficients $\mathbf{A}$. To encode positions, we compute

$$\mathbf{P} = \operatorname{downsample}(\mathbf{A})^T \mathbf{W}_P \in \mathbb{R}^{L \times C_p}. \tag{2}$$

We first downsample $\mathbf{A}$ spatially to a fixed size and multiply it with a learnable weight $\mathbf{W}_P$ to compress the position information in $\mathbf{A}$ into a smaller encoding vector with dimension $C_p$. Then, we augment the visual tokens by concatenating the position encoding $\mathbf{P}$ along the channel dimension to include the positional information. We can optionally use another point-wise 1D convolution to adjust the channel size of the concatenated visual tokens.

(a) Diagram of a tokenizer. It uses a static or a dynamic weight to group the pixels into tokens with different semantic concepts. Position encodings are computed from the token coefficients.
(b) Diagram of a projector that fuses visual tokens back to the feature map. Through attention, each pixel decides how to combine information from visual tokens.
Figure 2: Diagram of a tokenizer and a projector.

3.2 Transformer

Visual tokens extracted from the feature map have no intrinsic order among them, so unlike pixel arrays, we cannot process them with convolutions (except point-wise convolutions). Instead, we use transformers, which can process sets of elements. Formally, a transformer can be described as

$$\mathbf{T}'_{out} = \mathbf{T}_{in} + \operatorname{softmax}_{L}\!\left((\mathbf{T}_{in}\mathbf{W}_Q)(\mathbf{T}_{in}\mathbf{W}_K)^T\right)\mathbf{T}_{in}\mathbf{W}_V, \qquad \mathbf{T}_{out} = \mathbf{T}'_{out} + \mathbf{T}'_{out}\mathbf{W}_O, \tag{3}$$

where $\mathbf{T}_{in}, \mathbf{T}_{out} \in \mathbb{R}^{L \times C_T}$ are the input and output tokens, and $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V, \mathbf{W}_O$ are learnable weights used to compute queries, keys, values, and output tokens. The computational cost of equation (3) is $\mathcal{O}(L C_T^2 + L^2 C_T)$: the $\mathcal{O}(L C_T^2)$ term accounts for computing keys, queries, values, and output tokens, and the $\mathcal{O}(L^2 C_T)$ term for multiplying keys, queries, and values.

Note that the transformer (self-attention) has been used in computer vision before and is also referred to as the "non-local" operation. A non-local operation can also be described by equation (3); however, instead of operating on visual tokens, it operates directly on the entire feature map $\bar{\mathbf{X}} \in \mathbb{R}^{HW \times C}$. As a consequence, its computational cost under the same configuration is $\mathcal{O}(HW \cdot C^2 + (HW)^2 C)$, which is much larger since the number of pixels $HW$ far exceeds the number of tokens $L$. As a concrete example, under the configuration used in our experiments, a non-local operator requires 40.3 million MACs, whereas a visual transformer with a small set of visual tokens is sufficient to achieve better results at a cost of only 6.3 million MACs for the transformer, 6.4x smaller than the non-local operation on the pixel array.
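As a shape reference for equation (3), below is a minimal single-head sketch in plain PyTorch; the full model additionally uses multi-head attention, grouped projections, and scaling as in Listing 3, and the weights Wq, Wk, Wv, Wo here are assumed.

import torch

def token_transformer(T, Wq, Wk, Wv, Wo):
    # T: (L, CT) visual tokens; the attention matrix is only L x L,
    # independent of the image resolution.
    Q, K, V = T @ Wq, T @ Wk, T @ Wv
    A = torch.softmax(Q @ K.T / K.shape[1] ** 0.5, dim=-1)  # (L, L)
    T = T + A @ V      # self-attention with a residual connection
    return T + T @ Wo  # point-wise output projection with a residual connection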

3.3 Fusing visual tokens with the feature map

Many vision tasks require pixel-level details, but such details are not preserved in the visual tokens. Therefore, we fuse the output of the transformer back into the feature map to augment the pixel-level representation. To achieve this, we compute

$$\mathbf{X}_{out} = \mathbf{X}_{in} + \operatorname{softmax}_{L}\!\left((\mathbf{X}_{in}\mathbf{W}_Q)(\mathbf{T}\mathbf{W}_K)^T\right)\mathbf{T}\mathbf{W}_V. \tag{4}$$

Here $\mathbf{W}_Q$ is a learnable weight used to compute queries from the input feature map $\mathbf{X}_{in}$; the queries decide what information each pixel needs from the visual tokens. $\mathbf{W}_K$ is a learnable weight used to compute keys from the visual tokens $\mathbf{T}$, and the keys encode the information each visual token contains. $\mathbf{W}_V$ is a learnable weight used to compute, from the visual tokens, the values that are fused with the feature map. The intuition is illustrated in Figure 2(b).
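The following is a minimal single-head sketch of the fusion step in equation (4) in plain PyTorch, again ignoring the multi-head attention and grouped convolutions of the full projector (Listing 4); Wq, Wk, Wv are assumed weights.

import torch

def fuse_tokens(X, T, Wq, Wk, Wv):
    # X: (HW, C) feature map; T: (L, CT) visual tokens
    Q = X @ Wq                  # (HW, D) pixel queries
    K = T @ Wk                  # (L, D) token keys
    V = T @ Wv                  # (L, C) token values
    A = torch.softmax(Q @ K.T / K.shape[1] ** 0.5, dim=-1)  # (HW, L)
    return X + A @ V            # feature map augmented with token information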

3.4 Dynamic tokenization

As with convolutions, we cascade layers of visual transformers so that each layer takes the previous layer's visual tokens and feature map as input. Within each layer, in order to capture more semantic concepts from the feature map, we use equation (1) to extract new visual tokens and combine them with the existing ones. However, instead of using $\mathbf{W}_A$ as a static weight, we compute it dynamically from the input visual tokens. With a static weight, we would extract a fixed set of semantic concepts regardless of the input image; but high-level semantics in images are sparse and diverse, exceeding the capacity of a small set of fixed filters. So instead, we compute $\mathbf{W}_A$ based on the previous visual tokens. Given the input tokens $\mathbf{T}_{in}$, we first split them along the token dimension to obtain $\mathbf{T}_a, \mathbf{T}_b$, and we update the visual tokens as

$$\mathbf{T}_{new} = \operatorname{tokenize}\!\left(\bar{\mathbf{X}}, \mathbf{W}_A\right), \qquad \mathbf{W}_A = \mathbf{T}_a \mathbf{W}_{T \rightarrow A}, \tag{5}$$

where $\operatorname{tokenize}(\cdot, \cdot)$ is the function defined by equation (1) that extracts visual tokens from the feature map using the point-wise 2D convolutional filter $\mathbf{W}_A$, and $\mathbf{W}_{T \rightarrow A}$ is a learnable weight that generates the filter $\mathbf{W}_A$ from the input visual tokens $\mathbf{T}_a$. Then, we concatenate the new tokens $\mathbf{T}_{new}$ with $\mathbf{T}_b$ and feed the result to the transformer, as shown in Figure 1.

4 Using visual transformers to build vision models

Now that we have discussed the basics of visual transformers, we discuss how to use them as building blocks for vision models. We use three hyper-parameters to control a visual transformer block: the channel size of the output feature map C, the channel size of the visual tokens CT, and the number of visual tokens L. In a series of cascaded visual transformers, the first block accepts a feature map as input; if the channel size of the input feature map does not match the visual transformer's, we simply apply a point-wise 2D convolution to the input feature map to adjust it. The first visual transformer block uses static tokenization, and later blocks use dynamic tokenization. Similarly, if the channel size of the incoming visual tokens does not match CT, we apply a point-wise 1D convolution to adjust it. Following previous work vaswani2017attention ; ramachandran2019stand , we use multi-head attention in the tokenizer, transformer, and projector, and we use group convolutions to further reduce the MACs and parameter sizes. Pseudo-code for the visual transformer is provided in Appendix A, and we will open-source the model after publication.
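As an illustration of how these pieces fit together, below is a minimal sketch of one visual transformer block that chains the Tokenizer, Transformer, and Projector modules from Appendix A; the class name and the exact way blocks are stacked are our assumptions, not the released implementation.

import torch.nn as nn

class VisualTransformerBlock(nn.Module):
    # Composes the three components described in Section 3 using the
    # hyper-parameters C (feature channels), CT (token channels), L (tokens).
    def __init__(self, C, CT, L, dynamic=False):
        super().__init__()
        self.tokenizer = Tokenizer(L=L, CT=CT, C=C, dynamic=dynamic)
        self.transformer = Transformer(CT=CT)
        self.projector = Projector(CT=CT, C=C)

    def forward(self, feature, token=None):
        token = self.tokenizer(feature, token)    # N, CT, L
        token = self.transformer(token)           # N, CT, L
        feature = self.projector(feature, token)  # N, C, H, W
        return feature, token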

4.1 Visual transformers for image classification

For image classification, we build our networks with backbones inherited from ResNet he2016deep , since its simple and elegant architecture makes it an ideal baseline for demonstrating new ideas. Based on ResNet-{18, 34, 50, 101}, we build corresponding visual-transformer ResNets (VT-ResNets). For all models, we keep stages 1 through 4 of the original ResNet and use them to extract feature maps, and we replace the last stage of the network with visual transformer blocks. On the ImageNet dataset, all models accept input images with a resolution of 224×224, so the output of stage 4 is a feature map with a resolution of 14×14. The visual transformers take this as input and extract L=8 visual tokens with channel size CT=1024. We set the output feature map's channel size to 512 for VT-ResNet-{18, 34} and 1024 for VT-ResNet-{50, 101}. At the end of the network, we output 8 visual tokens and feed them to the classification head. A table summarizing the stage-wise description of the models is provided in Appendix B.
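Under our reading of this section, a VT-ResNet-18-style classifier could be assembled roughly as follows. This sketch reuses the hypothetical VisualTransformerBlock above and the torchvision ResNet-18 backbone, and the number of visual transformer blocks is an assumption, so it is illustrative rather than the released model.

import torch.nn as nn
import torchvision

class VTResNet18Sketch(nn.Module):
    def __init__(self, num_classes=1000, C=512, CT=1024, L=8, num_blocks=2):
        super().__init__()
        r18 = torchvision.models.resnet18(weights=None)
        # Stages 1-4 of ResNet-18: N x 256 x 14 x 14 for a 224x224 input.
        self.backbone = nn.Sequential(
            r18.conv1, r18.bn1, r18.relu, r18.maxpool,
            r18.layer1, r18.layer2, r18.layer3)
        self.adjust = nn.Conv2d(256, C, kernel_size=1)  # match the VT channel size
        self.vt_blocks = nn.ModuleList(
            [VisualTransformerBlock(C, CT, L, dynamic=(i > 0))
             for i in range(num_blocks)])
        self.head = nn.Linear(CT, num_classes)

    def forward(self, x):
        feature = self.adjust(self.backbone(x))
        token = None
        for block in self.vt_blocks:
            feature, token = block(feature, token)
        # classify from the average of the output visual tokens (N, CT, L)
        return self.head(token.mean(dim=2))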

4.2 Visual transformers for semantic segmentation

We hypothesize that visual transformers offer several advantages over convolutions as a building block for semantic segmentation. First, the computational cost of convolution grows quadratically with the image resolution. Second, convolutions struggle to capture long-range interactions between pixels. Visual transformers, on the other hand, operate on a small number of visual tokens regardless of the image resolution, and since they compute interactions between semantic concepts directly, they bypass the "long-range" challenge of pixel arrays.

Figure 3: Feature Pyramid Networks (FPN) (left) vs visual-transformer-FPN (VT-FPN) (right) for semantic segmentation.

To verify this, we use panoptic feature pyramid networks (FPN) kirillov2019panoptic as a baseline and improve the network with visual transformers. Panoptic FPN offers a simple, strong, and open-source baseline for demonstrating new ideas. Panoptic FPNs use ResNet backbones to extract feature maps from different stages at various resolutions. These feature maps are then fused by a feature pyramid network in a top-down manner to generate a multi-scale, detail-preserving feature map with rich semantics for segmentation, as illustrated in Figure 3. FPN is computationally expensive, since it relies heavily on spatial convolutions operating on high-resolution feature maps with large channel sizes. We therefore use visual transformers to replace the convolutions in FPN, and we name the new module VT-FPN, shown in Figure 3. From each feature map, VT-FPN extracts L=8 visual tokens with a channel size of CT=1024. The visual tokens are combined and fed into one transformer to capture interactions between visual tokens from different resolutions. The output tokens are then projected back to the original feature maps, which are then used to perform pixel-level prediction. Compared with the original convolution-based FPN, the computational cost of VT-FPN is much smaller, since we operate on a very small number of visual tokens rather than on all pixels: instead of fusing feature maps pixel by pixel via convolution and interpolation, we use a transformer to fuse the feature maps globally. Our experiments show that VT-FPN uses 6.5x fewer MACs than FPN while preserving or surpassing its performance.
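The sketch below captures our reading of VT-FPN: extract tokens from every backbone feature map, mix them with a single transformer, and project them back to each map. It reuses the Appendix A modules and glosses over details such as the per-level output heads, so it is an assumption-laden illustration rather than the paper's implementation.

import torch
import torch.nn as nn

class VTFPNSketch(nn.Module):
    def __init__(self, in_channels_list, CT=1024, L=8):
        super().__init__()
        self.tokenizers = nn.ModuleList(
            [Tokenizer(L=L, CT=CT, C=c) for c in in_channels_list])
        self.transformer = Transformer(CT=CT)
        self.projectors = nn.ModuleList(
            [Projector(CT=CT, C=c) for c in in_channels_list])

    def forward(self, features):
        # features: list of (N, C_i, H_i, W_i) maps from the backbone stages
        tokens = [tok(f, None) for tok, f in zip(self.tokenizers, features)]
        # one transformer mixes the tokens from all pyramid levels
        mixed = self.transformer(torch.cat(tokens, dim=2))  # N, CT, L * levels
        # project the mixed tokens back onto every pyramid level
        return [proj(f, mixed) for proj, f in zip(self.projectors, features)]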

5 Experiments

5.1 ImageNet classification

For image classification, we conduct experiments on the ImageNet dataset deng2009imagenet . We implement the VT-ResNets in PyTorch, train them on the ImageNet training set of around 1.3 million images, and evaluate on the validation set. We train the models for 400 epochs with RMSProp. We use an initial learning rate of 0.01, increase it to 0.16 over 5 warmup epochs, and then reduce it by a factor of 0.9875 per epoch. We use synchronized batch normalization and distributed training with a batch size of 2048, and we use label smoothing and AutoAugment cubuk2019autoaugment during training. Additionally, we set the stochastic depth survival probability huang2016deep and the dropout ratio to 0.9 and 0.2, respectively. We apply exponential moving average (EMA) to the model weights with a decay of 0.99985.
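For reference, the recipe above could be sketched in PyTorch roughly as follows; the RMSProp momentum, the label smoothing value, and the training-loop details are assumptions not specified in the text.

import torch

model = VTResNet18Sketch()  # hypothetical model from the Section 4.1 sketch
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)  # assumed value
ema = torch.optim.swa_utils.AveragedModel(  # EMA of weights with decay 0.99985
    model, avg_fn=lambda e, p, n: 0.99985 * e + (1 - 0.99985) * p)

def lr_at_epoch(epoch, base=0.01, peak=0.16, warmup=5, decay=0.9875):
    # linear warmup from the base lr to the peak lr, then exponential decay
    if epoch < warmup:
        return base + (peak - base) * epoch / warmup
    return peak * decay ** (epoch - warmup)

for epoch in range(400):
    for group in optimizer.param_groups:
        group["lr"] = lr_at_epoch(epoch)
    # ... standard training loop: forward, criterion(output, target), backward,
    # optimizer.step(), then ema.update_parameters(model) after each step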

Our results are reported in Table 1 and Figure 4. The baseline ResNet models are the reference implementations from PyTorch: https://pytorch.org/hub/pytorch_vision_resnet/. Compared with the baseline ResNets, the VT-ResNet models achieve up to 4.53 points higher accuracy with fewer MACs and parameters. Note that the parameter and MAC reduction comes from replacing only one stage of the baseline model; measured against the original stage alone, the MAC reduction is up to 6.9x, as shown in Table 2. Studying how to replace more stages of convolutions is left to future work. We also compare our results with previous methods hu2018squeeze ; woo2018cbam ; bello2019attention ; chen20182 ; hu2019local ; ramachandran2019stand ; zhao2020exploring that adopt attention on ResNets in different ways. Figure 4 shows that our models significantly outperform all of them, with higher accuracy and lower MACs and parameter sizes.

Figure 4: Accuracy vs. MACs and parameters for visual transformer and related work on ImageNet.
Models   Top-1 Acc (%)   ΔAcc   MACs (G)   ΔMACs (M)   Params (M)   ΔParams (M)
R-18 he2016deep 69.76 - 1.814 - 11.67 -
R-18 + SE hu2018squeeze ; woo2018cbam 70.59 +0.83 1.814 +0 11.78 +0.11
R-18 + CBAM woo2018cbam 70.73 +0.97 1.815 +1 11.78 +0.11
LR-R-18 hu2019local 74.6 +4.84 2.5 +686 14.4 +2.62
VT-R-18 (ours) 74.29 +4.53 1.587 -227 12.76 +1.09
R-34 he2016deep 73.32 - 3.664 - 21.78 -
R-34 + SE hu2018squeeze ; woo2018cbam 73.87 +0.55 3.664 0 21.96 +0.18
R-34 + CBAM woo2018cbam 74.01 +0.69 3.664 0 21.96 +0.18
AA-R-34 bello2019attention 74.7 +1.38 3.55 -224 20.7 -1.08
VT-R-34 (ours) 77.46 +4.14 3.151 -513 18.02 -3.76
R-50 he2016deep 76.0 - 4.089 - 25.5 -
R-50 + SE hu2018squeeze ; woo2018cbam 76.86 +0.86 3.860* +2* 28.1 +2.6
R-50 + CBAM woo2018cbam 77.34 +1.34 3.864* +6* 28.1 +2.6
LR-R-50 hu2019local 77.3 +1.3 4.3 +211 23.3 -2.2
Stand-Alone ramachandran2019stand 77.6 +1.6 3.6 -489 18.0 -7.5
AA-R-50 bello2019attention 77.7 +1.7 4.1 +11 25.6 +0.1
A²-R-50 chen20182 77.0 +1.0 - - - -
SAN19 zhao2020exploring 78.2 +2.2 3.3 -789 20.5 -5
VT-R-50 (ours) 78.88 +2.88 3.412 -677 21.4 -4.1
R-101 he2016deep 77.39 - 7.802 - 44.44 -
R-101 + SE hu2018squeeze ; woo2018cbam 77.65 +0.26 7.575* +5* 49.33 +4.89
R-101 + CBAM woo2018cbam 78.49 +1.10 7.581* +11* 49.33 +4.89
LR-R-101 hu2019local 78.5 +1.11 7.79 -12 42.0 -2.44
AA-R-101 bello2019attention 78.7 +1.31 8.05 +248 45.4 +0.96
VT-R-101 (ours) 81.09 +3.70 7.127 -675 40.33 -4.11
Table 1: Experiments of visual-transformer-based ResNets on the ImageNet dataset. R-18 denotes ResNet-18 and VT-R-18 denotes visual-transformer-ResNet-18. *The baseline MACs reported in woo2018cbam are lower than in this work; the ΔMACs for these rows are calculated against the baseline in woo2018cbam .
R18 R34 R50 R101
MACs Total 1.14x 1.16x 1.20x 1.09x
Stage-5 2.4x 5.0x 6.1x 6.9x
Params Total 0.91x 1.21x 1.19x 1.19x
Stage-5 0.9x 1.5x 1.26x 1.26x
Table 2: MAC and parameter size reduction of visual transformers on ResNets.

5.2 Ablation study

We introduce several new components in the visual transformer and conduct ablation studies to examine the roles of the key ones. To accelerate the experiments, we train VT-ResNet-{18, 34} for only 90 epochs. We verify the roles of the following components. 1) Transformer: we use transformers to capture interactions between visual tokens, which we hypothesize is important for understanding image semantics. To verify this, we replace the transformer with a point-wise 1D convolution; accuracy drops by 1.96 and 0.32 points without transformers. 2) Dynamic tokenization: we hypothesize that dynamic tokenization better handles sparse semantic concepts. Replacing all dynamic tokenization with static tokenization drops accuracy by 1.19 and 1.25 points. 3) Position encoding: we hypothesize that position encoding is important for understanding semantic concepts. Removing it drops accuracy by 0.02 and 0.15 points.

Model Ablation Top-1 Acc (%) MACs (G) Params (M)
VT-R-18 Base 71.45 1.587 12.76
- TF 69.49 (-1.96) 1.519 8.57
- PE 71.43 (-0.02) 1.586 12.69
- DT 70.26 (-1.19) 1.608 12.22
VT-R-34 Base 74.65 3.217 18.02
- TF 74.33 (-0.32) 3.204 17.24
- PE 74.50 (-0.15) 3.216 17.90
- DT 73.40 (-1.25) 3.246 16.92
Table 3: Ablation studies on VT-ResNet. "-TF" denotes removing the transformer, "-PE" removing the position encoding, "-DT" removing dynamic tokenization, and "Base" the baseline.

5.3 Visual transformers for semantic segmentation

We conduct experiments to test the effectiveness of visual transformers for semantic segmentation on the COCO-stuff dataset caesar2018cvpr and the LIP dataset liang2018look . The COCO-stuff dataset contains annotations for 91 stuff classes with 118K training images and 5K validation images. The LIP dataset is a collection of images of humans in challenging poses and views. For COCO-stuff, we train VT-FPN models with ResNet-{50, 101} backbones. Our implementation is based on Detectron2 wu2019detectron2 . Our training recipe follows the semantic segmentation FPN recipe with 1x training steps, except that we use synchronized batch normalization in the VT-FPN, a batch size of 32, and a base learning rate of 0.04. For the LIP dataset, we also use synchronized batch normalization, with a batch size of 96, and we optimize the model via SGD with a weight decay of 0.0005 and a learning rate of 0.01. Our models and training recipes will be open-sourced after publication.
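As a rough illustration of this setup, a Detectron2-style configuration might look like the sketch below; the base config file name and the VT-FPN head registration are placeholders for components that are not part of Detectron2 or any released code.

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file("semantic_R_50_FPN_1x.yaml")  # placeholder base config
cfg.SOLVER.IMS_PER_BATCH = 32      # batch size used for COCO-stuff
cfg.SOLVER.BASE_LR = 0.04          # base learning rate used for COCO-stuff
cfg.MODEL.RESNETS.NORM = "SyncBN"  # example norm override; the recipe applies
                                   # synchronized batch norm inside the VT-FPN
# cfg.MODEL.SEM_SEG_HEAD.NAME = "VTFPNHead"  # hypothetical VT-FPN head

trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()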

As shown in Table 4, after replacing the FPN with VT-FPN, both the ResNet-50 and ResNet-101 based models achieve comparable or slightly higher mIoU, while VT-FPN requires 6.5x fewer MACs than the FPN, demonstrating the effectiveness of visual transformers.

Backbone   Module   mIoU (%)   Total MACs (G)   FPN MACs (G)   Dataset
R-50 FPN 40.78 159 55.1 COCO-stuff
VT-FPN 41.00 113 (1.41x) 8.5 (6.48x)
FPN 47.04 37.1 12.8 LIP
VT-FPN 47.39 26.4 (1.41x) 2.0 (6.40x)
R-101 FPN 41.51 231 55.1 COCO-stuff
VT-FPN 41.50 185 (1.25x) 8.5 (6.48x)
FPN 47.35 54.4 12.8 LIP
VT-FPN 47.58 43.6 (1.25x) 2.0 (6.40x)
Table 4: Semantic segmentation results on the COCO-stuff validation set and the LIP validation set. The MACs for COCO-stuff are calculated with a typical input resolution of 800×1216, and the MACs for LIP with an input resolution of 473×473.

5.4 Visualizations of visual tokens

We hypothesize that the visual tokens extracted in the visual transformer correspond to different high-level semantics in the image. To better understand this, we visualize the token coefficients A, where each column of A is an attention map showing how each part of the image contributes to the corresponding token. We plot these attention maps in Figure 5: without any supervision, different visual tokens attend to different semantic concepts in the image, corresponding to parts of the background or foreground objects. More visualization results are provided in Appendix C.

Figure 5: Token coefficient visualization from the LIP dataset. Red pixels have larger token coefficients and blue pixels have smaller ones. Even without supervision, different visual tokens selectively focus on different regions, such as snow, forest, sheep, and people.

6 Conclusions

In this paper, we propose visual transformers to challenge the existing computer vision paradigm that represents and processes images with pixel arrays and convolutions. Visual transformers overcome the limitations of this paradigm by representing an image with a compact set of dynamically extracted visual tokens and densely modeling their interactions with transformers. We use visual transformers to replace convolution-based modules in vision models and observe significant accuracy and efficiency improvements in image classification and semantic segmentation. For example, ResNet-34 with visual transformers outperforms ResNet-101 while using 2.5x fewer MACs, and a visual-transformer-based FPN uses 6.5x fewer MACs than a regular FPN while achieving similar accuracy.

7 Broader impact discussion

Our primary goal in this work is to advance the scientific understanding of vision systems. The scientific questions we hope to answer include: 1) how to build vision systems that efficiently represent and process semantic concepts; and 2) how to bridge low-level perception and high-level reasoning in vision systems. We address these questions by building and experimenting with a new vision paradigm and testing it on several vision tasks. Our experiments show that the proposed paradigm achieves significantly better performance on several existing vision tasks while being more computationally efficient, which gives us promising evidence toward answering the scientific questions above. In addition, the practical impact of our work is that it can help build better and more efficient vision systems for many applications, including AR/VR, autonomous driving, medical imaging, and robotics. Our method is generic, and we sincerely hope it will be used for purposes that are beneficial to society. However, like many other technologies, there is a risk that the method proposed in this paper could be used for not-so-beneficial purposes. Given the generic nature of this work, it is hard for us to predict specific harmful applications. However, we hope that by publishing this paper so that everyone can access it, we increase the chance that the research community as a whole will collectively supervise its applications and ensure it is used for the common good.

References

Appendix A Pseudo-code implementation of visual transformers

In this section, we provide pseudo-code for the key components of the visual transformer. Note that the pseudo-code is intended only to help readers understand the visual transformer: to keep it easy to follow, we omit many important implementation details, such as normalization and initialization, so it is not a rigorous implementation that guarantees reproducibility. We will open-source our model and code afterwards.

Listing 1 gives the pseudo-code for the static and dynamic tokenizer described in Sections 3.1.1 and 3.4.

# Assumed imports and helper factories (not part of the original pseudo-code):
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1x1_1d(cin, cout, groups=1):
    return nn.Conv1d(cin, cout, kernel_size=1, groups=groups)

def conv1x1_2d(cin, cout, groups=1):
    return nn.Conv2d(cin, cout, kernel_size=1, groups=groups)

def conv3x3_2d(cin, cout, stride=1):
    return nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1)

class Tokenizer(nn.Module):
    def __init__(self, L, CT, C, head=16, groups=16, dynamic=False):
        super(Tokenizer, self).__init__()
        if not dynamic:
            # use static weights to compute token coefficients.
            self.conv_token_coef = conv1x1_2d(C, L)
        else:
            # use previous tokens to compute a query weight, which is
            # then used to compute token coefficients.
            self.conv_query = conv1x1_1d(CT, C)
            self.conv_key = conv1x1_2d(C, C, groups=groups)
        self.conv_value = conv1x1_2d(C, C, groups=groups)
        # PosEncoder arguments (C, size, num_downsample) omitted in this pseudo-code.
        self.pos_encoding = PosEncoder()
        self.conv_token = conv1x1_1d(
            C + self.pos_encoding.pos_dim, CT)
        self.head = head
        self.dynamic = dynamic

    def forward(self, feature, token):
        # compute token coefficients
        # feature: N, C, H, W; token: N, CT, L
        if not self.dynamic:
            token_coef = self.conv_token_coef(feature)
            N, L, H, W = token_coef.shape
            token_coef = token_coef.view(N, 1, L, H * W)
            token_coef = token_coef.permute(0, 1, 3, 2)  # N, 1, HW, L
            token_coef = token_coef / np.sqrt(feature.shape[1])
        else:
            L = token.shape[2]
            # split the input tokens; T_a, T_b: N, CT, L//2
            T_a, T_b = token[:, :, :L // 2], token[:, :, L // 2:]
            query = self.conv_query(T_a)
            N, C, L_a = query.shape
            # N, h, C//h, L_a
            query = query.view(N, self.head, C // self.head, L_a)
            N, C, H, W = feature.shape
            key = self.conv_key(feature).view(
                N, self.head, C // self.head, H * W)  # N, h, C//h, HW
            # compute token coefficients: N, h, HW, L_a
            token_coef = torch.matmul(key.permute(0, 1, 3, 2), query)
            token_coef = token_coef / np.sqrt(C / self.head)
        N, C, H, W = feature.shape
        token_coef = F.softmax(token_coef, dim=2)
        value = self.conv_value(feature).view(
            N, self.head, C // self.head, H * W)  # N, h, C//h, HW
        # extract tokens from the feature map
        # static tokens: N, C, L; dynamic tokens: N, C, L_a
        tokens = torch.matmul(value, token_coef).view(N, C, -1)
        # compute the position encoding
        # static: pos_encoding: N, Cp, L; dynamic: N, Cp, L_a
        pos_encoding = self.pos_encoding(token_coef, (H, W))
        tokens = torch.cat((tokens, pos_encoding), dim=1)
        if not self.dynamic:
            # N, C+Cp, L -> N, CT, L
            tokens = self.conv_token(tokens)
        else:
            # N, C+Cp, L_a -> N, CT, L_a, then concatenate to N, CT, (L_a + L_b)
            tokens = torch.cat((T_b, self.conv_token(tokens)), dim=2)
        return tokens
Listing 1: Pseudo-code of the tokenizer.

Listing 2 gives the pseudo-code for the position encoder described in Section 3.1.2.

class PosEncoder(nn.Module):
    def __init__(self, C, size, num_downsample):
        super(PosEncoder, self).__init__()
        # C: channel size of the token coefficients (the number of heads)
        self.size = size
        ds_conv = []
        for _ in range(num_downsample):
            ds_conv.append(conv3x3_2d(C, C, stride=2))
        ds_conv.append(conv1x1_2d(C, 1))
        self.ds_conv = nn.Sequential(*ds_conv)
        pos_dim = size ** 2 // (4 ** num_downsample)
        self.pos_conv = conv1x1_1d(pos_dim, pos_dim)
        self.pos_dim = pos_dim

    def forward(self, token_coef, input_size):
        H, W = input_size
        N, h, HW, L = token_coef.shape
        # one coefficient map of size H x W (with h heads) per token
        token_coef = token_coef.permute(0, 3, 1, 2).reshape(N * L, h, H, W)
        # interpolation to deal with inputs of varying sizes
        token_coef = F.interpolate(
            token_coef, size=(self.size, self.size))
        # downsampling
        token_coef = self.ds_conv(token_coef)
        token_coef = token_coef.view(N, L, -1).permute(0, 2, 1)
        # compress and compute the position encoding.
        return self.pos_conv(token_coef)  # N, Cp, L
Listing 2: Pseudo-code of the position encoder.

Listing 3 gives the pseudo-code for the transformer described in Section 3.2.

class Transformer(nn.Module):
    def __init__(self, CT, head=16, kqv_groups=8):
        super(Transformer, self).__init__()
        self.k_conv = conv1x1_1d(CT, CT // 2, groups=kqv_groups)
        self.q_conv = conv1x1_1d(CT, CT // 2, groups=kqv_groups)
        self.v_conv = conv1x1_1d(CT, CT, groups=kqv_groups)
        self.ff_conv = conv1x1_1d(CT, CT)
        self.head = head
        self.CT = CT

    def forward(self, tokens):
        N = tokens.shape[0]
        # k: N, h, CT//2//h, L
        k = self.k_conv(tokens).view(
            N, self.head, self.CT // 2 // self.head, -1)
        # q: N, h, CT//2//h, L
        q = self.q_conv(tokens).view(
            N, self.head, self.CT // 2 // self.head, -1)
        # v: N, h, CT//h, L
        v = self.v_conv(tokens).view(
            N, self.head, self.CT // self.head, -1)
        # kq: N, h, L, L
        kq = torch.matmul(k.permute(0, 1, 3, 2), q)
        kq = F.softmax(kq / np.sqrt(kq.shape[2]), dim=2)
        # kqv: N, CT, L
        kqv = torch.matmul(v, kq).view(N, self.CT, -1)
        tokens = tokens + kqv
        tokens = tokens + self.ff_conv(tokens)
        return tokens
Listing 3: Pseudo-code of the transformer.

Listing 4 gives the pseudo-code for the projector described in Section 3.3.

class Projector(nn.Module):
    def __init__(self, CT, C, head=16, groups=16):
        super(Projector, self).__init__()
        self.proj_value_conv = conv1x1_1d(CT, C)
        self.proj_key_conv = conv1x1_1d(CT, C)
        self.proj_query_conv = conv1x1_2d(C, CT, groups=groups)
        self.head = head

    def forward(self, feature, token):
        N, _, L = token.shape
        h = self.head
        proj_v = self.proj_value_conv(token).view(N, h, -1, L)
        proj_k = self.proj_key_conv(token).view(N, h, -1, L)
        proj_q = self.proj_query_conv(feature)
        N, C, H, W = proj_q.shape
        proj_q = proj_q.view(N, h, C // h, H * W).permute(0, 1, 3, 2)
        # proj_coef: N, h, HW, L
        proj_coef = F.softmax(
            torch.matmul(proj_q, proj_k) / np.sqrt(C / h), dim=3)
        # proj: N, h, C//h, HW
        proj = torch.matmul(proj_v, proj_coef.permute(0, 1, 3, 2))
        # fuse the projected tokens back into the feature map
        return feature + proj.view(N, -1, H, W)
Listing 4: Pseudo-code of the projector.

Appendix B Stage-wise model description of VT-ResNet

In this section, we provide a detailed description of the model configurations for visual-transformer-based ResNet (VT-ResNet). We use three hyper-parameters to control a block of visual transformer: channel size of the output feature map C, channel size of visual token CT, and the number of visual tokens L. The model configurations of VT-ResNets are described in Table 5.

Stage | Resolution | VT-R-18 | VT-R-34 | VT-R-50 | VT-R-101
1 | 56×56 | conv7×7-C64-S2, maxpool3×3-S2 (shared by all models)
2 | 56×56 | BB-C64 | BB-C64 | BN-C256 | BN-C256
3 | 28×28 | BB-C128 | BB-C128 | BN-C512 | BN-C512
4 | 14×14 | BB-C256 | BB-C256 | BN-C1024 | BN-C1024
5 | 14×14 / 8 tokens | VT-C512-L8-CT1024 | VT-C512-L8-CT1024 | VT-C1024-L8-CT1024 | VT-C1024-L8-CT1024
head | 8 tokens | avgpool-fc1000 (shared by all models)
Table 5: Model descriptions for VT-ResNets. VT-R-18 denotes visual-transformer-ResNet-18. "conv7×7-C64-S2" denotes a 7×7 convolution with an output channel size of 64 and a stride of 2. "BB-C64×2" denotes a basic block he2016deep with an output channel size of 64, repeated twice. "BN-C256×3" denotes a bottleneck block he2016deep with an output channel size of 256, repeated three times. "VT-C512-L8-CT1024×2" denotes a visual transformer block with an output feature map channel size of 512, a visual token channel size of 1024, and 8 visual tokens, repeated twice.

Appendix C More visualization results

We provide more visualization of the token coefficients in Figure 6.

Figure 6: Token coefficient visualization. Red denotes higher coefficient values and blue denotes lower values. The visualizations show that, without any supervision, visual tokens automatically focus on different areas of the image that correspond to different semantic concepts, supporting our intuition of how visual tokens work.