Long-Short Transformer: Efficient Transformers for Language and Vision

07/05/2021
by   Chen Zhu, et al.
University of Maryland
Nvidia

Transformers have achieved success in both language and vision domains. However, it is prohibitively expensive to scale them to long sequences such as long documents or high-resolution images, because the self-attention mechanism has quadratic time and memory complexity with respect to the input sequence length. In this paper, we propose Long-Short Transformer (Transformer-LS), an efficient self-attention mechanism for modeling long sequences with linear complexity for both language and vision tasks. It aggregates a novel long-range attention with dynamic projection to model distant correlations and a short-term attention to capture fine-grained local correlations. We propose a dual normalization strategy to account for the scale mismatch between the two attention mechanisms. Transformer-LS can be applied to both autoregressive and bidirectional models without additional complexity. Our method outperforms the state-of-the-art models on multiple tasks in language and vision domains, including the Long Range Arena benchmark, autoregressive language modeling, and ImageNet classification. For instance, Transformer-LS achieves 0.97 test BPC on enwik8 using half the number of parameters of previous methods, while being faster and able to handle 3x as long sequences compared to its full-attention version on the same hardware. On ImageNet, it can obtain state-of-the-art results (e.g., a moderate-size model of 55.8M parameters trained solely on 224x224 ImageNet-1K can obtain a Top-1 accuracy of 84.1%), while being more scalable on high-resolution images. The source code and models are released at https://github.com/NVIDIA/transformer-ls .


1 Introduction

Transformer-based models (vaswani2017attention) have achieved great success in the domains of natural language processing (NLP) (devlin2018bert; radford2019language) and computer vision (dosovitskiy2020image; wang2021pyramid; wu2021cvt). These models benefit from the self-attention module, which can capture both adjacent and long-range correlations between tokens while efficiently scaling on modern hardware. However, the time and memory consumed by self-attention scale quadratically with the input length, making it very expensive to process long sequences. Many language and vision tasks benefit from modeling long sequences. In NLP, document-level tasks require processing long articles (e.g., kwiatkowski2019natural; pappagari2019hierarchical), and the performance of language models often increases with sequence length (e.g., dai2019transformer; rae2019compressive). In computer vision, many tasks involve high-resolution images, which are converted to long sequences of image patches before being processed with Transformer models (dosovitskiy2020image; wu2021cvt; zhang2021visionlongformer). As a result, it is crucial to design an efficient attention mechanism for long sequence modeling that generalizes well across different domains.

Various methods have been proposed to reduce the quadratic cost of full attention; however, an efficient attention mechanism that generalizes well in both language and vision domains remains less explored. One family of methods sparsifies the attention matrix with predefined patterns such as sliding windows (e.g., child2019generating; parmar2018image; beltagy2020longformer; ainslie2020etc) and random sparse patterns (zaheer2020big). These methods leverage strong inductive biases to improve both computational efficiency and model performance, but they limit the capacity of a self-attention layer because each token can only attend to a subset of tokens. Another family of methods leverages low-rank projections to form a low-resolution representation of the input sequence, but the successful application of these methods has been limited to certain NLP tasks (e.g., wang2020linformer; xiong2021nformer; tay2020synthesizer). Unlike sparse attention, this family of methods allows each token to attend to the entire input sequence. However, due to the loss of high-fidelity token-wise information, their performance is sometimes not as good as full attention or sparse attention on tasks that require fine-grained local information, including standard benchmarks in language (tay2020long) and vision (zhang2021multi).

Despite the rapid progress in efficient Transformers, some proposed architectures can only be applied to bidirectional models (e.g., ainslie2020etc; zaheer2020big; xiong2021nformer). Transformer-based autoregressive models have achieved great success in language modeling (brown2020gpt3), image synthesis (razavi2019generating), and text-to-image synthesis (ramesh2021zero), which also involve long texts or high-resolution images. It is desirable to design an efficient transformer that can be applied to both autoregressive and bidirectional models.

In this work, we unify a local window attention and a novel long-range attention into a single efficient attention mechanism. We show that these two kinds of attention have complementary effects that together yield the state-of-the-art results on a range of tasks in language and vision, for both autoregressive and bidirectional models.

Summary of our contributions:


  • We propose Long-Short Transformer (Transformer-LS), an efficient Transformer that integrates a dynamic projection based attention to model long-range correlations, and a local window attention to capture fine-grained correlations. Long-Short Transformer can be applied to both autoregressive and bidirectional models with linear time and memory complexity.

  • We compute a dynamic low-rank projection, which depends on the content of the input sequence. In contrast to previous low-rank projection methods, our dynamic projection method is more flexible and robust to semantic-preserving positional variations (e.g., insertion, paraphrasing). We demonstrate that it outperforms previous low-rank methods (wang2020linformer; xiong2021nformer) on Long Range Arena benchmark (tay2020long).

  • We identify a scale mismatch problem between the embeddings from the long-range and short-term attentions, and design a simple but effective dual normalization strategy, termed DualLN, to account for the mismatch and enhance the effectiveness of the aggregation.

  • We demonstrate that Long-Short Transformer, despite its low memory and runtime complexity, outperforms the state-of-the-art models on a set of tasks from Long Range Arena, and on autoregressive language modeling on enwik8 and text8. In addition, the proposed efficient attention mechanism can be easily applied to the most recent vision transformer architectures (wu2021cvt; zhang2021visionlongformer) and provides state-of-the-art results, while being more scalable to high-resolution images. We also investigate the robustness properties of Transformer-LS on diverse ImageNet datasets.

We organize the rest of the paper as follows. Section 2 discusses related work. We introduce the preliminaries and present our method in Section 3. Section 4 presents experimental results, and we conclude the paper in Section 5.

2 Related Work

2.1 Efficient Transformers

In recent years, many methods have been introduced for dealing with the quadratic cost of full attention. In general, they can be categorized as follows: i) Sparse attention mechanisms with predefined patterns (e.g., sliding windows), including Sparse Transformer (child2019generating), Image Transformer (parmar2018image), Axial Transformer (ho2019axial) for modeling images, and Longformer (beltagy2020longformer), blockwise self-attention (qiu2019blockwise), ETC (ainslie2020etc), Big Bird (zaheer2020big) for modeling language. ii) Low-rank projection attention, including Linformer (wang2020linformer), Nyströmformer (xiong2021nformer), Synthesizer (tay2020synthesizer). For example, Linformer uses linear layers to project the original high-resolution keys ($K$) and values ($V$) of length $n$ to low-resolution ones of size $r \ll n$, and allows all query tokens ($Q$) to attend to these compressed representations. iii) Memory-based mechanisms like Compressive Transformer (rae2019compressive) and Set Transformer (lee2019set), which use extra memories for caching global long-range information for use in computing attention between distant tokens. iv) Kernel-based approximation of the attention matrix, including Performer (choromanski2020performer), Linear Transformer (katharopoulos2020transformers), and Random Feature Attention (peng2021random). v) Similarity- and clustering-based methods, including Reformer (kitaev2020reformer), Routing Transformer (roy2021efficient), and Sinkhorn Transformer (tay2020sparse).

Our method seamlessly integrates low-rank projection and local window attention, and leverages their strengths for modeling long-range and short-term correlations, respectively. In particular, our long-range attention uses a dynamic low-rank projection to encode the input sequence, and outperforms the previous low-rank projection method used by Linformer (wang2020linformer). In a similar vein, a few other methods try to combine the strengths of different mechanisms. For example, Longformer (beltagy2020longformer) and ETC (ainslie2020etc) augment local window attention with task-motivated global tokens. Such global tokens may not be applicable for some tasks (e.g., autoregressive modeling). BigBird (zaheer2020big) further combines local window and global token attention with random sparse attention. It is not applicable to autoregressive tasks because of the global tokens and the random sparse pattern it introduces. To compress the model footprint on edge devices, Lite Transformer (wu2020lite) combines convolution and self-attention, but it still has quadratic complexity for long sequences.

2.2 Vision Transformers

Vision Transformer (ViT) (dosovitskiy2020image) splits images into small patches and treats the patches as the input word tokens. It uses a standard transformer for image classification and has been shown to outperform convolutional neural networks (e.g., ResNet (he2016deep)) given sufficient training data. DeiT (touvron2020training) applies a teacher-student strategy to alleviate the data-efficiency problem of ViT and shows strong performance using only the standard ImageNet dataset (deng2009imagenet). Instead of applying the transformer to a single low-resolution grid of patches, very recent works, including Pyramid Vision Transformer (PVT) (wang2021pyramid), Swin-Transformer (liu2021swin), T2T-ViT (yuan2021t2t), Vision Longformer (ViL) (zhang2021visionlongformer) and Convolutional Vision Transformer (CvT) (wu2021cvt), stack a pyramid of ViTs to form a multi-scale architecture and model long sequences of image patches at much higher resolution.

Most previous methods have self-attention cost that is quadratic in the input image size. Swin-Transformer (liu2021swin) achieves linear complexity by limiting the computation of self-attention to within each local window. Perceiver (jaegle2021perceiver) uses cross-attention between data and latent arrays to replace the self-attention on data, thus removing the quadratic complexity bottleneck. HaloNet (vaswani2021scaling) applies local attention on blocked images and only has quadratic complexity with respect to the size of the block. Vision Longformer (ViL) (zhang2021visionlongformer), another concurrent work, achieves linear complexity by adapting Longformer (beltagy2020longformer) from the NLP domain. That approach augments local window attention with task-motivated global tokens, but is not applicable to decoding tasks (e.g., image synthesis (razavi2019generating; ramesh2021zero)). In contrast, our method reduces the quadratic cost to linear cost by combining local window attention with global dynamic projection attention, which can be applied to any vision transformer architecture for both encoding and decoding tasks.

3 Long-Short Transformer

Figure 1: Long-short term attention of a single attention head, where $n$ is the sequence length, $d$ the hidden dimension, $w$ the local window segment size, and $r$ the rank of the dynamic projection. Within the figure, $K$ denotes the key or value $V$. In the left figure, we virtually replicate $K$ or $V$ into $n$ rows, and highlight the keys and values within the attention span (denoted $\tilde{K}, \tilde{V}$) of all queries $Q$ for the short-term attention. In the middle figure, all queries attend to the same projected keys $\bar{K}$ and values $\bar{V}$ within the long-range attention. In the right figure, $\tilde{K}, \tilde{V}$ and $\bar{K}, \bar{V}$ are first normalized with two sets of Layer Normalizations, and then concatenated so that the queries attend to the normalized $\tilde{K}, \tilde{V}$ and $\bar{K}, \bar{V}$ within their attention span simultaneously.

Transformer-LS approximates the full attention by aggregating long-range and short-term attentions, while maintaining its ability to capture correlations between all input tokens. In this section, we first introduce the preliminaries of multi-head attention in Transformer. Then, we present the short-term attention via sliding window, and long-range attention via dynamic projection, respectively. After that, we propose the aggregating method and dual normalization (DualLN) strategy. See Figure 1 for an illustration of our long-short term attention.

3.1 Preliminaries and Notations

Multi-head attention is a core component of the Transformer (vaswani2017attention), which computes contextual representations for each token by attending to the whole input sequence at different representation subspaces. It is defined as

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(H_1, H_2, \dots, H_h)\, W_O, \qquad (1)$$

where $Q, K, V \in \mathbb{R}^{n \times d}$ are the query, key and value embeddings, $W_O \in \mathbb{R}^{d \times d}$ is the projection matrix for the output, the $i$-th head $H_i \in \mathbb{R}^{n \times d_k}$ is the scaled dot-product attention, and $d_k = d/h$ is the embedding dimension of each head,

$$H_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V) = \mathrm{softmax}\!\left(\frac{Q W_i^Q \left(K W_i^K\right)^{\top}}{\sqrt{d_k}}\right) V W_i^V = A_i\, V W_i^V, \qquad (2)$$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k}$ are learned projection matrices, and $A_i \in \mathbb{R}^{n \times n}$ denotes the full attention matrix for each attention head. The complexity of computing and storing $A_i$ is $O(n^2)$, which can be prohibitive when $n$ is large. For simplicity, our discussion below is based on the case of 1D input sequences. It is straightforward to extend to 2D image data given a predetermined order.
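To make the quadratic cost concrete, the following is a minimal PyTorch sketch (ours, not the released code) of a single attention head as in Eq. (2); the full n x n attention matrix A is materialized explicitly, which is exactly what Transformer-LS avoids. All tensor names and sizes are illustrative.

```python
import math
import torch

def full_attention_head(X, Wq, Wk, Wv):
    """Naive single-head attention (Eq. 2): O(n^2) time and memory.

    X:  (n, d) token embeddings
    Wq, Wk, Wv: (d, dk) learned projection matrices
    """
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                                  # each (n, dk)
    A = torch.softmax(Q @ K.T / math.sqrt(Q.shape[-1]), dim=-1)       # (n, n) full attention matrix
    return A @ V                                                      # (n, dk)

# Example: n = 1024 tokens already costs a 1024 x 1024 attention matrix per head.
n, d, dk = 1024, 64, 32
X = torch.randn(n, d)
Wq, Wk, Wv = (torch.randn(d, dk) * d ** -0.5 for _ in range(3))
out = full_attention_head(X, Wq, Wk, Wv)                              # (1024, 32)
```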

3.2 Short-term Attention via Segment-wise Sliding Window

We use the simple yet effective sliding window attention to capture fine-grained local correlations, i.e., each query attends to nearby tokens within a fixed-size neighborhood. Similar techniques have also been adopted in (beltagy2020longformer; zaheer2020big; zhang2021visionlongformer). Specifically, for efficiency reasons we divide the input sequence into disjoint segments of length $w$. All tokens within a segment attend to all tokens within their home segment, as well as $w/2$ consecutive tokens on the left and right side of the home segment (zero padding is used for the first and last segments), resulting in an attention span over a total of $2w$ key-value pairs. See Figure 5 in the Appendix for an illustration. For each query $Q_t$ at position $t$ within the $i$-th head, we denote the key-value pairs within its window as $\tilde{K}_{i,t}, \tilde{V}_{i,t} \in \mathbb{R}^{2w \times d_k}$. In a PyTorch implementation, this segment-wise sliding window attention is faster than the per-token sliding window attention in which each token attends to itself and $w$ tokens to its left and right, and its memory consumption scales linearly with sequence length; see (beltagy2020longformer) and our Figure 3 for more details.
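The segment-wise gathering of local keys and values can be implemented with padding and unfolding. The snippet below is a minimal, unofficial sketch assuming an even window size w that divides the sequence length n; the released implementation additionally handles batching, masking and multiple heads.

```python
import torch
import torch.nn.functional as F

def gather_local_kv(x, w):
    """Collect the 2w local keys/values for each length-w home segment.

    x: (n, dk) key or value embeddings of one head; w: segment length (even).
    Returns a (n // w, 2w, dk) tensor: for segment s, its own w tokens plus
    w/2 tokens on each side (zero-padded at the sequence boundaries).
    """
    n, dk = x.shape
    assert n % w == 0 and w % 2 == 0
    x_pad = F.pad(x, (0, 0, w // 2, w // 2))     # (n + w, dk), zero-pad both ends
    # Sliding windows of size 2w with stride w: one window per home segment.
    windows = x_pad.unfold(0, 2 * w, w)          # (n // w, dk, 2w)
    return windows.transpose(1, 2)               # (n // w, 2w, dk)

# Queries in segment s then attend only to gather_local_kv(K, w)[s].
```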

The sliding window attention can be augmented to capture long-range correlations in part, by introducing different dilations to different heads of sliding window attention (beltagy2020longformer). However, the dilation configurations for different heads need further tuning and an efficient implementation of multi-head attention with different dilations is non-trivial. A more efficient alternative is to augment sliding window attention with random sparse attention (zaheer2020big), but this does not guarantee that the long-range correlations are captured in each layer as in full attention. In the following section, we propose our long-range attention to address this issue.

3.3 Long-range Attention via Dynamic Projections

Previous works have shown that the self-attention matrix can be well approximated by the product of low-rank matrices (wang2020linformer). By replacing the full attention with the product of low-rank matrices (katharopoulos2020lineartrans; tay2020synthesizer; xiong2021nformer; peng2021rfa; choromanski2020performer), each query is able to attend to all tokens. Linformer (wang2020linformer) is one of the most representative models in this category. It learns a fixed projection matrix to reduce the length of the keys and values, but the fixed projection is inflexible to semantic-preserving positional variations.

Starting from these observations, we parameterize the dynamic low-rank projection at the $i$-th head as $P_i = f(K) \in \mathbb{R}^{n \times r}$, where $r \ll n$ is the low rank size and $P_i$ depends on all the keys of the input sequence. It projects the $n \times d_k$-dimensional key embeddings $K W_i^K$ and value embeddings $V W_i^V$ into shorter, $r \times d_k$-dimensional key and value embeddings. Unlike Linformer (wang2020linformer), the low-rank projection matrix is dynamic: it depends on the input sequence and is intended to be more flexible and robust to, e.g., insertion, deletion, paraphrasing, and other operations that change sequence length. Note that the query embeddings are kept at the same length, and we let each query attend to $\bar{K}_i$ and $\bar{V}_i$. In this way, the full attention matrix can be decomposed into the product of two matrices with $r$ columns or rows. Specifically, we define the dynamic projection matrix $P_i \in \mathbb{R}^{n \times r}$ and the key-value embeddings of low-rank attention as

$$P_i = \mathrm{softmax}(K W_i^P), \qquad \bar{K}_i = P_i^{\top} K W_i^K, \qquad \bar{V}_i = P_i^{\top} V W_i^V, \qquad (3)$$

where $W_i^P \in \mathbb{R}^{d \times r}$ are learnable parameters (for the CvT-based vision transformer model, we replace $W_i^P$ with a depth-wise separable convolution, just like its query, key and value projections), and the softmax normalizes the projection weights over all tokens along the first dimension, which stabilizes training in our experiments. Note that in all the experiments we have considered, the keys and values are computed from the same input sequence, so $P_i$ remains the same if it is computed from the values instead of the keys. The computational complexity of Eq. 3 is linear in the sequence length $n$.
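A minimal single-head, unbatched sketch of the dynamic projection in Eq. (3) might look as follows; the function and variable names are ours for illustration.

```python
import torch

def dynamic_projection(K, V, Wp, Wk, Wv):
    """Dynamic low-rank projection (Eq. 3).

    K, V: (n, d) key/value inputs; Wp: (d, r); Wk, Wv: (d, dk).
    Returns Kbar, Vbar of shape (r, dk), i.e. r "summary" key/value tokens.
    """
    P = torch.softmax(K @ Wp, dim=0)   # (n, r), normalized over the n tokens
    Kbar = P.T @ (K @ Wk)              # (r, dk)
    Vbar = P.T @ (V @ Wv)              # (r, dk)
    return Kbar, Vbar
```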

To see how the full attention is replaced by the product of low-rank matrices, we compute each head of the long-range attention as

$$\bar{H}_i = \mathrm{softmax}\!\left(\frac{Q W_i^Q\, \bar{K}_i^{\top}}{\sqrt{d_k}}\right) \bar{V}_i = \mathrm{softmax}\!\left(\frac{Q W_i^Q \left(P_i^{\top} K W_i^K\right)^{\top}}{\sqrt{d_k}}\right) P_i^{\top} V W_i^V, \qquad (4)$$

so the full attention matrix $A_i \in \mathbb{R}^{n \times n}$ is now replaced with the implicit product of two low-rank matrices of sizes $n \times r$ and $r \times n$, and the computational complexity is reduced to $O(nr)$. Note that the effective attention weights of a query on all tokens still sum to 1. Our global attention allows each query to attend to all token embeddings within the same self-attention layer. In contrast, the sparse attention mechanisms (beltagy2020longformer; zaheer2020big) need to stack multiple layers to build such correlations.
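Continuing the sketch above, each query then attends to the r projected key-value pairs, so the attention matrix has shape n x r instead of n x n. Again, this is an illustrative single-head version rather than the official implementation.

```python
import math
import torch

def long_range_attention(Q, Wq, Kbar, Vbar):
    """Long-range attention head (Eq. 4) against the r projected tokens.

    Q: (n, d); Wq: (d, dk); Kbar, Vbar: (r, dk).
    Cost is O(n * r) rather than O(n^2).
    """
    q = Q @ Wq                                                          # (n, dk)
    Abar = torch.softmax(q @ Kbar.T / math.sqrt(q.shape[-1]), dim=-1)   # (n, r)
    return Abar @ Vbar                                                  # (n, dk)
```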

Application to Autoregressive Models:  In autoregressive models, each token can only attend to previous tokens, so the long-range attention should have a different range for different tokens. A straightforward way to implement our global attention is to update $\bar{K}_i, \bar{V}_i$ for each query recurrently, but this requires re-computing the projection in Eq. (3) for every token due to the nonlinearity of softmax, which results in $O(n^2)$ computational complexity. To preserve the linear complexity, for autoregressive models we first divide the input sequence into equal-length segments of length $l$, and apply our dynamic projection to extract $\bar{K}_i, \bar{V}_i$ from each segment. Each token can only attend to the projected $\bar{K}_i, \bar{V}_i$ of segments that do not contain its future tokens. Formally, let $Q_t$ be the query at position $t$, $\bar{K}_i^{(s)}, \bar{V}_i^{(s)}$ be the projected key-value pairs from the $s$-th segment, and $\mathcal{S}_t$ be the set of segments that contain no tokens after position $t$. For autoregressive models, we compute the long-range attention of $Q_t$ by attending to $\bar{K}_{i,t} = [\bar{K}_i^{(s)}]_{s \in \mathcal{S}_t}$ and $\bar{V}_{i,t} = [\bar{V}_i^{(s)}]_{s \in \mathcal{S}_t}$, defined as

$$\bar{H}_{i,t} = \mathrm{softmax}\!\left(\frac{Q_t W_i^Q\, \bar{K}_{i,t}^{\top}}{\sqrt{d_k}}\right) \bar{V}_{i,t}. \qquad (5)$$

In this way, the dynamic low-rank projection is applied to each segment only once in parallel, preserving the linear complexity and the high training speed. By comparison, Random Feature Attention (peng2021random) is slow at training due to the requirement for recurrence.
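The segment-level causal constraint can be expressed as a boolean mask over the projected segments: under one consistent reading of the rule above, query position t may only attend to segments whose last token is at or before t. The helper below is a hypothetical sketch with 0-indexed positions and a segment length l that divides n.

```python
import torch

def segment_causal_mask(n, l):
    """allowed[t, s] is True iff segment s (positions s*l .. (s+1)*l - 1)
    contains no token after position t, i.e. its last position is <= t."""
    t = torch.arange(n).unsqueeze(1)                    # (n, 1)
    seg_last = torch.arange(n // l) * l + (l - 1)       # (n // l,)
    return seg_last.unsqueeze(0) <= t                   # (n, n // l) boolean

# e.g. with l = 4, position t = 5 may attend to the projection of segment 0 only;
# the sliding-window attention covers the remaining recent tokens.
```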

3.4 Aggregating Long-range and Short-term Attentions

To aggregate the local and long-range attentions, instead of adopting different attention mechanisms for different heads (child2019generating; beltagy2020longformer; wu2020lite), we let each query at the $i$-th head attend to the union of keys and values from the local window and the global low-rank projection, so that it can learn to select important information from either of them. We find this aggregation strategy works better than separating the heads in our initial trials with autoregressive language models. Specifically, for the $i$-th head, we denote the global low-rank projected keys and values as $\bar{K}_i, \bar{V}_i$, and the local keys and values within the local window of position $t$ for the query $Q_t$ as $\tilde{K}_{i,t}, \tilde{V}_{i,t}$. Then the $i$-th head's attention at position $t$ is

$$H_{i,t} = \mathrm{softmax}\!\left(\frac{Q_t W_i^Q\, [\bar{K}_i; \tilde{K}_{i,t}]^{\top}}{\sqrt{d_k}}\right) [\bar{V}_i; \tilde{V}_{i,t}], \qquad (6)$$

where $[\,\cdot\,;\,\cdot\,]$ denotes concatenation of the matrices along the first dimension. Furthermore, we find a scale mismatch between the initial norms of $\bar{K}_i, \bar{V}_i$ and $\tilde{K}_{i,t}, \tilde{V}_{i,t}$, which biases the attention towards the local window at initialization for both language and vision tasks. We introduce a normalization strategy (DualLN) in the following to align the norms and improve the effectiveness of the aggregation.

Figure 2: Left: Ratios of the average norms of the local window key/value embeddings to the global low-rank key/value embeddings at initialization. Without DualLN, the sparse (local) and low-rank (global) embeddings have a magnitude mismatch. With DualLN, the ratios are close to 1 at every layer, which facilitates optimization. Right: The validation loss of Transformer-LS with and without DualLN on enwik8 and text8.

DualLN: For Transformers with Layer Normalization (LN) (see (xiong2020layer) for an illustration), the key and value embeddings are the outputs of LN layers, so they have zero mean and unit variance at initialization. The squared $\ell_2$ norm of vectors with zero-mean entries is proportional to their variance in expectation. We note that a weighted average will reduce the variance, and therefore the norm, of such zero-mean vectors. As a result, the embedding vectors from the low-rank attention, produced by the weighted average of Eq. (3), will have smaller norms than the regular key and value embeddings from the sliding window attention (see Figure 2 Left for an illustration). This scale mismatch causes two side effects. First, the inner products from the low-rank component tend to have smaller magnitude than those from the local window component, so the attention scores on the long-range attention are systematically smaller. Second, the key-value pairs from the low-rank attention naturally have less impact on the direction of the aggregated output, even when the low-rank and local-window components are assigned the same attention scores, since $\bar{V}_i$ has smaller norms. Both effects lead to small gradients on the low-rank components and hinder the model from learning to effectively use the long-range correlations.

To avoid such issues, we add two sets of Layer Normalizations after the key and value projections of the local window and global low-rank attentions, so that their scales are aligned at initialization, while the network can still learn to re-weight the norms after training. Specifically, the aggregated attention is now computed as

$$H_{i,t} = \mathrm{softmax}\!\left(\frac{Q_t W_i^Q\, \left[\mathrm{LN}_G(\bar{K}_i); \mathrm{LN}_L(\tilde{K}_{i,t})\right]^{\top}}{\sqrt{d_k}}\right) \left[\mathrm{LN}_G(\bar{V}_i); \mathrm{LN}_L(\tilde{V}_{i,t})\right], \qquad (7)$$

where $\mathrm{LN}_L, \mathrm{LN}_G$ denote the Layer Normalizations for the local and global attentions respectively. As illustrated in Figure 2 Right, Transformer-LS models trained with DualLN have consistently lower validation loss than the models without DualLN.
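Putting the pieces together, a single head of the aggregated attention with DualLN (Eq. 7) could be sketched as below. This is our own simplified, unbatched illustration: `ln_local` and `ln_global` are the two LayerNorms of DualLN, `local_idx` is assumed to be a precomputed index tensor of each token's 2w-token window, and none of the names come from the official code.

```python
import math
import torch
import torch.nn as nn

class LongShortHead(nn.Module):
    """One attention head aggregating local-window and low-rank K/V with DualLN."""

    def __init__(self, d, dk, r):
        super().__init__()
        self.Wq = nn.Linear(d, dk, bias=False)
        self.Wk = nn.Linear(d, dk, bias=False)
        self.Wv = nn.Linear(d, dk, bias=False)
        self.Wp = nn.Linear(d, r, bias=False)      # dynamic projection weights
        self.ln_local = nn.LayerNorm(dk)           # DualLN: local branch
        self.ln_global = nn.LayerNorm(dk)          # DualLN: global branch
        self.dk = dk

    def forward(self, X, local_idx):
        """X: (n, d); local_idx: (n, 2w) indices of each token's local window."""
        Q, K, V = self.Wq(X), self.Wk(X), self.Wv(X)                           # (n, dk)
        P = torch.softmax(self.Wp(X), dim=0)                                   # (n, r)
        Kbar, Vbar = self.ln_global(P.T @ K), self.ln_global(P.T @ V)          # (r, dk)
        Kloc, Vloc = self.ln_local(K[local_idx]), self.ln_local(V[local_idx])  # (n, 2w, dk)
        # Concatenate the r global and 2w local keys/values for every query.
        Kcat = torch.cat([Kbar.expand(X.size(0), -1, -1), Kloc], dim=1)        # (n, r + 2w, dk)
        Vcat = torch.cat([Vbar.expand(X.size(0), -1, -1), Vloc], dim=1)
        attn = torch.softmax((Q.unsqueeze(1) @ Kcat.transpose(1, 2))
                             / math.sqrt(self.dk), dim=-1)                     # (n, 1, r + 2w)
        return (attn @ Vcat).squeeze(1)                                        # (n, dk)
```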

4 Experiments

In this section, we demonstrate the effectiveness and efficiency of our method in both language and vision domains. We use PyTorch for implementation and count the FLOPs using fvcore (fvcore).

4.1 Long Range Arena

To evaluate Long-Short Transformer as a bidirectional encoder for long text, we follow the setting of peng2021random and tay2021omninet and select the three NLP tasks from the recently proposed Long Range Arena (LRA) benchmark (tay2020long). As we will evaluate our method on the ImageNet dataset in the following section, we do not evaluate on the other two vision tasks from LRA, which are relatively small and synthetic datasets. The three tasks are:


  • ListOps. ListOps (nangia2018listops) is designed to measure the parsing ability of models through hierarchically structured data. We follow the setting in (tay2020long) in which each instance contains 500-2000 tokens.

  • Text. This is a binary sentiment classification task of predicting whether a movie review from IMDb is positive or negative (maas2011imdb). Making correct predictions requires a model to reason with compositional, unsegmented character-level long sequences with a maximum length of 4K.

  • Retrieval. This task is based on the ACL Anthology Network dataset (radev2013acl). The model needs to classify whether there is a common citation between a pair of papers, which evaluates the model's ability to encode long sequences for similarity-based matching. The max sequence length for each byte-level document is 4K, and the model processes two documents in parallel each time.

For a fair comparison, we use the PyTorch implementation and the same data preprocessing/split, training hyperparameters and model size as xiong2021nformer, except for Retrieval where we accidentally used more warmup steps and improved the results for all models. See Appendix B for more details.

Results. Comparisons with other models are given in Table 1. The configurations of window size $w$ and projection rank $r$ for Transformer-LS (best) on each task are given in Table 5 in Appendix B. We also report the results of using a fixed configuration on all tasks. Overall, our Transformer-LS (best) is significantly better than the other efficient Transformers, and the fixed-configuration model performs favorably while using only about 50% to 70% of the computation of the other efficient Transformers on all three tasks. The advantage of aggregating local and long-range attention is most significant on ListOps, which requires the model to understand tree structures involving both long-term and short-term relations. On Retrieval, where document-level encoding capability is tested, we find our global attention more effective than window attention. The test accuracy of using only the dynamic projection is about 10% higher than Linformer on Text (i.e., 66.28 vs. 56.12), the task with the highest variance in sequence length (standard deviation 893). This demonstrates the improved flexibility of dynamic projection at learning representations for data with highly variable sequence lengths, compared to the learned but fixed projection of Linformer. Similarly, Linformer, Nyströmformer and our model outperform full attention on ListOps, indicating they may have a better inductive bias, and that efficient Transformers can have better efficacy beyond efficiency.

ListOps Text Retrieval Average
(888 ± 339) (1296 ± 893) (3987 ± 560)
Model Acc. FLOPs Acc. FLOPs Acc. FLOPs Acc.
Full Attention (vaswani2017attention) 37.13 1.21 65.35 4.57 82.30 9.14 61.59
Reformer (kitaev2020reformer) (2) 36.44 0.27 64.88 0.58 78.64 1.15 59.99
Linformer (wang2020linformer) (=256) 37.38 0.41 56.12 0.81 79.37 1.62 57.62
Performer (choromanski2020performer) () 32.78 0.41 65.21 0.82 81.70 1.63 59.90
Nyströmformer (xiong2021nformer) () 37.34 0.61 65.75 1.02 81.29 2.03 61.46
Transformer-LS () 37.50 0.20 66.01 0.40 81.79 0.80 61.77
Dynamic Projection (best) 37.79 0.15 66.28 0.69 81.86 2.17 61.98
Transformer-LS (best) 38.36 0.16 68.40 0.29 81.95 2.17 62.90
Table 1: Accuracy (%) and FLOPs (G) on Long Range Arena (LRA), with the model configs annotated (see Table 5 for more). All results are averages of 4 runs with different random seeds.

4.2 Autoregressive Language Modeling

We compare our method with other efficient transformers on character-level language modeling, where each input token is a single character.

Setup.  We train and evaluate our model on enwik8 and text8, each containing 100M characters split into 90M, 5M and 5M for train, dev and test, following (Mahoney2009text8). Our smaller 12-layer and larger 30-layer models are Pre-LN Transformers with the same width and depth as Longformer (beltagy2020longformer), except that we add relative position encoding to the projected segments in each layer. We adopt the cache mechanism of Transformer-XL (dai2019transformer), setting the cache size to be the same as the input sequence length. We follow a training schedule similar to Longformer's and train our model in 3 phases with increasing sequence lengths of 2048, 4096 and 8192 respectively. By comparison, Longformer trains its model in 5 phases on GPUs with 48GB memory (ours have at most 32GB), and its sequence length is 23,040 in the last phase. The window size of Longformer increases with depth and its average window size is 4352 in phase 5, while our effective number of attended tokens is 1280 on average in the last phase. Each experiment takes around 8 days to finish on 8 V100 GPUs. Detailed hyperparameters are shown in Appendix C. For testing, as with Longformer, we split the dataset into overlapping sequences of length 32K at a step size of 512, and evaluate the BPC for predicting the next 512 tokens given the previous 32K characters.
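As a concrete illustration of this evaluation protocol, the snippet below sketches how the overlapping evaluation windows could be enumerated; the function is our illustrative pseudocode (with the 32K context and 512-token step from the description above), not the released evaluation script.

```python
def eval_windows(num_chars, context=32768, step=512):
    """Yield (start, end, eval_start) spans over the test text.

    BPC is computed only on the last `step` characters of each
    length-`context` window, conditioned on the preceding characters.
    """
    start = 0
    while start + context <= num_chars:
        end = start + context
        yield start, end, end - step   # predict positions [end - step, end)
        start += step
```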

Figure 3: Running time and memory consumption of Transformer-XL (full attention) and our Transformer-LS on Char-LM. We increase the sequence length until we use up the 32GB of memory on a V100 GPU. Transformer-LS is the same smaller model in Table 2. We use dashed lines to represent the full attention Transformer and solid lines to represent our model. We use different colors to represent different batch sizes.

Results. Table 2 shows comparisons on text8 and enwik8. Our method achieves state-of-the-art results. On text8, we achieve a test BPC of 1.09 with the smaller model. On enwik8, our smaller model achieves a test BPC of 0.99 and outperforms the state-of-the-art models with a comparable number of parameters. Our larger model obtains a test BPC of 0.97, on par with the Compressive Transformer, which has about twice as many parameters. Our results are consistently better than those of Longformer, which is trained on longer sequences with 5 stages and 48GB of GPU memory. In Figure 3, we show that our model is much more memory- and compute-efficient than full attention.

Smaller models:
Method #Param text8 Dev text8 Test enwik8 Dev enwik8 Test
T12 (al2019character) 44M - 1.18 - 1.11
Transformer-XL (dai2019transformer) 41M - - - 1.06
Reformer (kitaev2020reformer) - - - - 1.05
Adaptive (sukhbaatar2019adaptive) 38M 1.05 1.11 1.04 1.02
BP-Transformer (ye2019bp) 38M - 1.11 - 1.02
Longformer (beltagy2020longformer) 41M 1.04 1.10 1.02 1.00
Transformer-LS 44M 1.03 1.09 1.01 0.99

Larger models on enwik8:
Method #Param Test BPC
Transformer-XL (dai2019transformer) 88M 1.03
Transformer-XL (dai2019transformer) 277M 0.99
Routing (roy2021efficient) 223M 0.99
Longformer (beltagy2020longformer) 102M 0.99
Sparse (child2019generating) 95M 0.99
Adaptive (sukhbaatar2019adaptive) 209M 0.98
Compressive (rae2019compressive) 227M 0.97
Transformer-LS 110M 0.97

Table 2: BPC (lower is better) of smaller models on enwik8 and text8 (top), and larger models on enwik8 (bottom).
Method Model #Param Image FLOPs ImageNet Real V2
(M) Size (G) top-1 (%) top-1 (%) top-1 (%)
CNN ResNet-50 25 224 4.1 76.2 82.5 63.3
ResNet-101 45 224 7.9 77.4 83.7 65.7
ResNet-152 60 224 11 78.3 84.1 67.0
Transformer DeiT-S touvron2020training 22 224 4.6 79.8 85.7 68.5
DeiT-B touvron2020training 86 224 17.6 81.8 86.7 70.9
PVT-Medium wang2021pyramid 44 224 6.7 81.2 - -
PVT-Large wang2021pyramid 61 224 9.8 81.7 - -
Swin-S liu2021swin 50 224 8.7 83.2 - -
Swin-B liu2021swin 88 224 15.4 83.5 - -
PVTv2-B4 wang2021pvtv2 62.6 224 10.1 83.6 - -
PVTv2-B5 wang2021pvtv2 82.0 224 11.8 83.8 - -
ViT-B/16 (dosovitskiy2020image) 86 384 55.5 77.9 - -
ViT-L/16 (dosovitskiy2020image) 307 384 191.1 76.5 - -
DeiT-B (touvron2020training) 86 384 55.5 83.1 - -
Swin-B liu2021swin 88 384 47.1 84.5 - -
CvT-13 wu2021cvt 20 224 6.7 81.6 86.7 70.4
CvT-21 wu2021cvt 32 224 10.1 82.5 87.2 71.3
CvT-LS-13 20.3 224 4.9 81.9 87.0 70.5
CvT-LS-17 23.7 224 9.8 82.5 87.2 71.6
CvT-LS-21 32.1 224 7.9 82.7 87.5 71.9
CvT-LS-21S 30.1 224 11.3 82.9 87.4 71.7
CvT-13 (wu2021cvt) 20 384 31.9 83.0 87.9 71.9
CvT-21 (wu2021cvt) 32 384 45.0 83.3 87.7 71.9
CvT-LS-21 32.1 384 23.9 83.2 88.0 72.5
CvT-LS-21 32.1 448 34.2 83.6 88.2 72.9
ViL-Small (zhang2021visionlongformer) 24.6 224 4.9 82.4 - -
ViL-Medium (zhang2021visionlongformer) 39.7 224 8.7 83.5 - -
ViL-Base (zhang2021visionlongformer) 55.7 224 13.4 83.7 - -
ViL-LS-Medium 39.8 224 8.7 83.8 - -
ViL-LS-Base 55.8 224 13.4 84.1 - -
ViL-LS-Medium 39.9 384 28.7 84.4 - -
Table 3: Test accuracies on ImageNet, ImageNet Real (beyer2020imagenetreal), and ImageNet V2 (recht2019imagenetv2) of models trained on ImageNet-1K. Grey-colored rows are our results. CvT-LS denotes our long-short term attention based on the non-official CvT implementation. ViL models with LS suffixes are our long-short term attention based on the official ViL implementation with relative positional bias.

4.3 ImageNet Classification

We train and evaluate the models on the ImageNet dataset with 1.3M images and 1K classes. We use CvT (wu2021cvt) and ViL (zhang2021visionlongformer), state-of-the-art vision transformer architectures, as the backbones and replace their attention mechanisms with our long-short term attention, denoted as CvT-LS and ViL-LS in Table 3. CvT uses overlapping convolutions to extract dense patch embeddings from the input images and feature maps, resulting in a long sequence length in the early stages (e.g., 56 x 56 patches for images with 224 x 224 pixels). For ViL, our sliding window uses the same group size $w$, but each token attends to at most $2w$ (rounding when necessary) tokens inside the window, rather than $3w$ as in ViL, which allows adding our dynamic projection without increasing the FLOPs. We use the same rank $r$ of dynamic projection for both ViL-LS-Medium and ViL-LS-Base. Note that our efficient attention mechanism does not depend on the particular architecture, and it can be applied to other vision transformers (e.g., dosovitskiy2020image; touvron2020training; wang2021pyramid).

Setup. We implement the CvT model based on a public repository (https://github.com/rishikksh20/convolution-vision-transformers), because CvT is a concurrent work with no official implementation at the time we conducted this work. In Table 3, since our CvT re-implementation obtains worse test results than those reported in their arXiv paper, we still list the best test accuracies reported by wu2021cvt for a fair comparison. We report the FLOPs of CvT from our implementation for a reasonable comparison, because our CvT-LS implementation is based on it. As with CvT, all the models have three stages, where the first stage downsamples the image by a factor of 4 and each of the following stages downsamples the feature map by a factor of 2. CvT-LS-13 and CvT-LS-21 have the same configurations as CvT-13 and CvT-21. CvT-LS-17 and CvT-LS-21S are our customized models with more layers and higher embedding dimensions in the first two stages (see Table 8 in Appendix D). We train the models for 300 epochs using a cosine learning-rate schedule (loshchilov2016sgdr) with 5 epochs of warmup. We use the same set of data augmentations and regularizations as other works including PVT (wang2021pyramid) and ViL (zhang2021visionlongformer). Please refer to Appendix D for more details.

Classification Results. The results are shown in Table 3, where we also list test accuracies on ImageNet Real and ImageNet V2. Besides CvT, we compare with the original ViT (dosovitskiy2020image), the enhanced DeiT (touvron2020training), PVT (wang2021pyramid), which also uses a multi-scale strategy, and ViL (zhang2021visionlongformer), which uses window attention and global tokens to improve efficiency. Training at higher resolution usually improves the test accuracy of vision transformers. With our long-short term attention, we can easily scale the training to higher resolution, and the performance of CvT-LS and ViL-LS improves accordingly. Our best model with CvT (CvT-LS-21 at 448 x 448 resolution) achieves 0.3% higher accuracy than the best reported result of CvT while using the same number of parameters and 76% of its FLOPs. In the CvT architecture, the spatial dimensions of the feature maps in earlier stages are large, representing more fine-grained details of the image. Similar to training with high-resolution images, the model should also benefit from denser feature maps. With our efficient long-short term attention, we can better utilize these fine-grained feature maps with less concern about the computational budget. In this way, our CvT-LS-17 achieves a better result than CvT-21 at resolution 224 while using fewer parameters and FLOPs, and our CvT-LS-21S model further improves on our CvT-LS-21 model.

Our ViL-LS-Medium and ViL-LS-Base with long-short term attention improve the accuracies of ViL-Medium and ViL-Base from 83.5 and 83.7 to 83.8 and 84.1 respectively, without an increase in FLOPs. When increasing the resolution for training ViL-LS-Medium from 224 x 224 to 384 x 384, the FLOPs increase (approximately) linearly and the accuracy improves by 0.6%, showing that our method still benefits greatly from increased resolution while maintaining linear complexity in practice.

Short-term Attention Suppresses Oversmoothing. By restricting tokens from different segments to attend to different windows, our short-term sparse local attention encourages diversity of the feature representations and helps to alleviate the over-smoothing problem (gong2021improve) (where all queries extract similar information in deeper layers and the attention mechanism becomes less important), and can thus fully utilize the depth of the network. As in (gong2021improve), we provide the cosine similarity of patch embeddings of our CvT-LS-13 and the re-implemented CvT-13 (81.1 accuracy) in Figure 6 in the Appendix. This is one of the reasons why our efficient attention mechanism can obtain even better results than the full-attention CvT model in the same setting.

Model Params ImageNet IN-C (hendrycks2018imagenetc) IN-A hendrycks2019natural IN-R hendrycks2020many ImageNet-9 xiao2020noise
(M) Top-1 mCE (↓) Acc. Acc. Mixed-same Mixed-rand
ResNet-50 (he2016deep) 25.6 76.2 78.9 6.2 35.3 87.1 81.6
DeiT-S (touvron2020training) 22.1 79.8 57.1 19.0 41.9 89.1 84.2
CvT-LS-13 20.3 81.9 58.7 27.0 42.6 90.7 85.6
CvT-LS-21 32.1 82.7 55.2 29.3 45.0 91.5 85.8
Table 4: Robustness evaluation on various ImageNet datasets. Top-1/Acc.: Top-1 accuracy. mCE: Mean Corruption Error (lower is better). Mixed-same/Mixed-rand: accuracies on the Mixed-Same/Mixed-Rand subsets.

Robustness Evaluation on Diverse ImageNet Datasets. As vision models are widely used in safety-critical applications (e.g., autonomous driving), their robustness is vital. In addition to out-of-distribution robustness (ImageNet-Real and ImageNet-V2), we further investigate the robustness of our vision transformer against common corruptions (ImageNet-C), semantic shifts (ImageNet-R), background dependence (ImageNet-9) and natural adversarial examples (ImageNet-A). We compare our method with standard classification models, including a CNN-based model (ResNet (he2016deep)) and a Transformer-based model (DeiT (touvron2020training)) with similar numbers of parameters.

As shown in Table 4, our method significantly outperforms the CNN-based method (ResNet-50). These results further empirically verify the importance of attention mechanisms for the robustness of vision models. Compared to DeiT, our models also achieve favorable improvements. This indicates that the design of the attention mechanism plays an important role in model robustness, and it sheds new light on the design of robust vision transformers. More details and results can be found in Appendix D.

5 Conclusion

In this paper, we introduced Long-Short Transformer, an efficient transformer for long-sequence modeling in both the language and vision domains, covering both bidirectional and autoregressive models. We designed a novel global attention mechanism with computational and memory complexity linear in the sequence length, based on a dynamic projection. We identified the scale-mismatch issue and proposed the DualLN technique to eliminate the mismatch at initialization and more effectively aggregate the local and global attentions. We demonstrated that our method obtains state-of-the-art results on Long Range Arena, character-level language modeling and ImageNet classification. We look forward to extending our method to more domains, including document QA, object detection and semantic segmentation on high-resolution images.


Appendix A Details of Norm Comparisons

As shown in Figure 2, the norms of the key-value embeddings from the long-range and short-term attentions, $\bar{K}_i, \bar{V}_i$ and $\tilde{K}_{i,t}, \tilde{V}_{i,t}$, are different at initialization, and the norms of the short-term (local window) embeddings are consistently larger than those of the long-range (low-rank) embeddings on the different networks and datasets we have evaluated. Here, we give an explanation.

Intuitively, at initialization, following similar assumptions as (he2015delving; glorot2010understanding), the entries of the key and value embeddings should have zero mean. Since each row of $\bar{K}_i, \bar{V}_i$ is a weighted mean of the rows of $K W_i^K, V W_i^V$ (with weights summing to 1), it has smaller variance unless one of the weights is 1. Given that the rows of $K W_i^K, V W_i^V$ are also zero-mean, the norm of these embedding vectors, whose square is proportional to the variance, is smaller. The key-value embeddings used by the short-term attention, $\tilde{K}_{i,t}, \tilde{V}_{i,t}$, are just subsets of the rows of $K W_i^K, V W_i^V$, so their embedding vectors have the same norm as rows of $K W_i^K, V W_i^V$ in expectation. Therefore, the norms of the embedding vectors from the long-range attention are smaller than those from the short-term attention in expectation.
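This argument is easy to check numerically. The following small, self-contained script (ours, not from the paper) compares the average row norms of softmax-weighted averages against the original zero-mean embeddings, reproducing the qualitative gap in Figure 2 (left):

```python
import torch

torch.manual_seed(0)
n, dk, r = 1024, 64, 128
K = torch.randn(n, dk)                       # zero-mean, unit-variance "LN outputs"
P = torch.softmax(torch.randn(n, r), dim=0)  # projection weights, normalized over tokens
Kbar = P.T @ K                               # (r, dk) weighted averages of rows of K

print(K.norm(dim=-1).mean())     # roughly sqrt(dk) = 8.0
print(Kbar.norm(dim=-1).mean())  # noticeably smaller, matching Figure 2 (left)
```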

Appendix B Details for Experiments on Long Range Arena

Architecture.

On all tasks, the models have 2 layers, with an FFN hidden dimension of 128 and an embedding dimension and number of heads smaller than those used in (tay2020long). As in (tay2020long), we add a CLS token as a global token and use its embedding in the last layer for classification. We re-implement the methods evaluated by xiong2021nformer, and report the best results between our re-implementation and those reported by xiong2021nformer. For our method, we run a grid search on the window size $w$ and the projected dimension $r$, and keep the total attention span comparable so that the complexity is similar to the other methods. The maximum sequence lengths for ListOps and Text are 2048 and 4096. For Retrieval, we set the maximum sequence length for each of the two documents to 4096.

ListOps (2k) Text (4k) Retrieval (4k)
Dynamic Projection 0 4 0 128 0 256
Transformer-LS 16 2 1 1 1 254
Table 5: Configurations of our method corresponding to the best results (Transformer-LS (best)) in Table 1.

Hyperparameters for Training.

Our hyperparameters are the same as Nyströmformer (xiong2021nformer) unless otherwise specified. Specifically, we follow xiong2021nformer and use Adam with a fixed learning rate, no weight decay, and batch size 32 for all tasks. The numbers of warmup steps and total training steps differ across tasks due to the different numbers of training samples. For Retrieval, we accidentally found that using 8000 warmup steps rather than the default of xiong2021nformer improves the results for all models we have evaluated. See Table 6 for the configuration of each task.

Task Batch Size Warmup Steps Total Steps
ListOps 32 1000 5000
Text 32 8000 20000
Retrieval 32 8000 30000
Table 6: Training Hyperparameters for LRA tasks.

Error bars.

We have already provided the average of 4 runs with different random seeds in Table 1. Here we also provide the standard deviations for these experiments in Table 7.

ListOps Text Retrieval Average
(888 ± 339) (1296 ± 893) (3987 ± 560)
Model Acc. FLOPs Acc. FLOPs Acc. FLOPs Acc.
Full Attention (vaswani2017attention) 37.1 ± 0.4 1.21 65.4 ± 0.3 4.57 82.3 ± 0.4 9.14 61.59
Reformer (kitaev2020reformer) (2) 36.4 ± 0.4 0.27 64.9 ± 0.4 0.58 78.6 ± 0.7 1.15 59.99
Linformer (wang2020linformer) (=256) 37.4 ± 0.4 0.41 56.1 ± 1.5 0.81 79.4 ± 0.9 1.62 57.62
Performer (choromanski2020performer) 32.8 ± 9.4 0.41 65.2 ± 0.2 0.82 81.7 ± 0.2 1.63 59.90
Nyströmformer (xiong2021nformer) 37.3 ± 0.2 0.61 65.8 ± 0.2 1.02 81.3 ± 0.3 2.03 61.46
Transformer-LS 37.5 ± 0.3 0.20 66.0 ± 0.2 0.40 81.8 ± 0.3 0.80 61.77
Dynamic Projection (best) 37.8 ± 0.2 0.15 66.3 ± 0.7 0.69 81.9 ± 0.5 2.17 61.98
Transformer-LS (best) 38.4 ± 0.4 0.16 68.4 ± 0.8 0.29 82.0 ± 0.5 2.17 62.90
Table 7: Accuracy (%) and its standard deviation on Long Range Arena (LRA), with the model configurations and sequence length statistics (mean ± std under the dataset names) annotated. All results are averages of 4 runs with different random seeds. Note that Text has the largest variance in length (1296 ± 893).

Appendix C Details for Autoregressive Language Modeling

An example of long-short term attention for autoregressive models.

Figure 4: An illustration of the effective attention span (colored regions) in Transformer-LS when the segment size for the low-rank attention is $l$ and the segment size for the sliding window attention is $w$. Left: the attention span of only the low-rank attention (segment-wise dynamic projection). Right: the attention span of the aggregated attention.

We give an illustration of the segment-wise dynamic projection for autoregressive models, as discussed in Section 3.3. With the segment-wise formulation, we can first compute the low-rank projection for each segment in parallel, and each query will only attend to the tokens from segments that contain neither its future tokens nor the query token itself. The whole process is efficient and maintains linear complexity, unlike RFA (peng2021random), which slows down training due to its requirement for a cumulative sum. However, in this way, some of the most recent tokens are ignored by the long-range attention, as shown in Figure 4 (left). The window attention (with segment size $w$) therefore becomes an indispensable component, since it fills the gap for the missing recent tokens, as shown in Figure 4 (right).

Experimental Setup.

Throughout training, we use fixed values for the window size $w$, the segment length $l$ of the dynamic projection, and the rank $r$ of the dynamic low-rank projection, which in our initial experiments achieved a better efficiency-BPC trade-off than the alternative settings we tried. Our small and large models have the same architectures as Longformer (beltagy2020longformer), except for the attention mechanisms. We use training schedules similar to Longformer's. Specifically, for all models and both datasets, we train the models for 430k/50k/50k steps with 10k/5k/5k linear learning-rate warmup steps, and use input sequence lengths of 2048/4096/8192 for the 3 phases. We use a constant learning rate after warmup. We compared learning rates from {1.25e-4, 2.5e-4, 5e-4, 1e-3} for 100k iterations and found 2.5e-4 to work best for both models on enwik8, and 5e-4 to work best on text8. The batch sizes for the 3 phases are 32, 32 and 16 respectively. Unlike Longformer and Transformer-XL, we remove gradient clipping and found the model to have slightly faster convergence in the beginning while still converging reliably. For the smaller model, we use a dropout rate of 0.2 and weight decay of 0.01. For the larger model, we use dropout 0.4 and weight decay 0.1.

Appendix D Details for ImageNet Classification

Architecture.

Figure 5: An illustration of our sliding window attention in 1D autoregressive and bidirectional models. Here, we use a group size of $w$. Each token inside each group is restricted to attend to at most $2w$ tokens. In the bidirectional model, they attend to the $w$ tokens of the home segment, plus $w/2$ tokens to the left and right of the home segment respectively. In the autoregressive model, they attend to $w$ tokens to the left of the home segment, as well as all tokens within the home segment that are not future tokens.

In general, CvT-LS-13 and CvT-LS-21 closely follow the architectural designs of CvT for fair comparisons. Specifically, in CvT-LS, we feed the token embeddings extracted by the depth-wise separable convolution (chollet2017xception) of CvT into our long-short term attention. For the dynamic projection, we replace $W_i^P$ in Eq. (3) with a depth-wise separable convolution to maintain consistency with the patch embeddings, but we change its BN layer into weight standardization (huang2017centered; qiao2019ws) on the spatial convolution's weights for simplicity. We do not use position encoding. All of our models have 3 stages, and the feature map size is the same as CvT's in each stage when the image resolutions are the same. CvT-LS-13 and CvT-LS-21 follow the same layer configurations as CvT-13 and CvT-21, i.e., the number of heads, the dimension of each head and the number of Transformer blocks are the same as in CvT at each stage. For all models at resolution 224 x 224, we use the same fixed window size $w$ and projection rank $r$. For higher resolutions, we scale up $w$ and/or $r$ to maintain similar effective receptive fields for the attentions, using larger values for the 3 stages at resolutions 384 x 384 and 448 x 448.
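For reference, a weight-standardized depth-wise convolution of the kind described above could be sketched as follows; this is our approximation of the stated design (standardizing the spatial convolution's weights per output channel), not the official implementation.

```python
import torch.nn as nn
import torch.nn.functional as F

class WSDepthwiseConv2d(nn.Conv2d):
    """Depth-wise 2D convolution whose spatial weights are standardized
    (zero mean, unit variance per output channel) at every forward pass."""

    def __init__(self, channels, kernel_size, stride=1, padding=0):
        super().__init__(channels, channels, kernel_size, stride=stride,
                         padding=padding, groups=channels, bias=False)

    def forward(self, x):
        w = self.weight                                    # (channels, 1, k, k)
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        std = w.std(dim=(1, 2, 3), keepdim=True) + 1e-5
        return F.conv2d(x, (w - mean) / std, None, self.stride,
                        self.padding, self.dilation, self.groups)
```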

Besides maintaining the CvT architectures, we also try other architectures to further explore the advantages of our method. With the efficient long-short term attention, it becomes affordable to stack more layers on higher-resolution feature maps to fully utilize the expressive power of attention mechanisms. Therefore, we create two new architectures, CvT-LS-17 and CvT-LS-21S, that have more and wider layers in the first two stages, as shown in Table 8. Compared with CvT-21, CvT-LS-17 has about 25% fewer parameters and slightly fewer FLOPs, but obtains the same level of accuracy. CvT-LS-21S has fewer parameters than CvT-LS-21, more FLOPs, and higher accuracy (82.9 vs. 82.7), demonstrating the advantage of focusing the computation on higher-resolution feature maps.

Output Size Layer Name CvT-LS-17 CvT-LS-21S
Stage 1 56 x 56 Conv. Embed. 7 x 7, 128, stride 4

56 56 Conv. Proj.
LSTA
MLP
Stage 2 28 28 Conv. Embed. 3 3, 256, stride 2
28 28 Conv. Proj.
LSTA
MLP
Stage 3 14 14 Conv. Embed. 3 3, 384, stride 2
14 14 Conv. Proj.
LSTA
MLP
Table 8: Architectures of our CvT-LS-17 and CvT-LS-21S models. LSTA stands for our Long-Short Term Attention.
Figure 6: Pairwise cosine similarity between patch embeddings at different layers of CvT-13 and CvT-LS-13, averaged over 50k images of the ImageNet validation set. The larger cosine similarities at deeper layers suggest that the feature representations are less diverse.

The effect of DualLN.

We trained the CvT-LS-13 model without DualLN, which has a test accuracy of 81.3, lower than the 81.9 with DualLN.

Figure 7: Running memory consumption of full self-attention (CvT-13) and Long-Short Transformer on different tasks. We increase the sequence length (resolution) until the model runs out of memory on a V100 GPU with 32GB of memory.

More Results about Robustness.

For a fair comparison, we choose models with similar numbers of parameters. We select two representative models: the CNN-based ResNet and the Transformer-based DeiT. We give detailed results for all types of image corruptions in ImageNet-C in Table 9. We evaluate our method on the following ImageNet robustness benchmarks:


  • ImageNet-C. ImageNet-C is the common corruption dataset. It consists of 15 types of algorithmically generated corruptions from the noise, blur, weather, and digital categories, each with five levels of severity. In Table 4, we report the normalized mean corruption error (mCE) defined in hendrycks2018imagenetc. In Table 9, we report the corruption error for each corruption type. In both tables, lower values indicate higher robustness.

  • ImageNet-A. ImageNet-A is the natural adversarial example dataset. It contains 7,500 adversarially filtered images collected from online sources that mislead ImageNet classifiers. We use accuracy as the evaluation metric; higher accuracy indicates better robustness.

  • ImageNet-R. ImageNet-R (Rendition) aims to evaluate the model generalization performance on out-of-distribution data. It contains renditions of 200 ImageNet classes (e.g. cartoons, graffiti, embroidery). We use accuracy as the evaluation metric.

  • ImageNet-9. ImageNet-9 aims to evaluate the background robustness of models, i.e., the extent to which a model relies on the image background. Following the standard setting of xiao2020noise, we evaluate on two categories, Mixed-Same and Mixed-Rand. Mixed-Same replaces the background of the selected image with a random background from the same class (using GrabCut); Mixed-Rand replaces the image background with a random background from a random class.

From Table 4, we find that our method achieves significant improvements over the CNN-based network (ResNet). For instance, our method improves over ResNet by 23.6%, 22.1%, and 9.7% on ImageNet-C, ImageNet-A, and ImageNet-R, respectively. On ImageNet-9, our method also achieves a favorable improvement of 4.3% on average (Mixed-Same and Mixed-Rand), indicating that our method is insensitive to background changes. We conjecture that the potential reasons for these improvements are (1) the attention mechanism and (2) the strong data augmentation strategies used for training vision transformers (dosovitskiy2020image; touvron2020training). The former helps the model focus more on the global context of the image, as each patch can attend to the whole image, reducing the local texture bias of CNNs. The latter increases the diversity of the training data to improve the model's generalization ability. Compared to DeiT, we also, somewhat surprisingly, find that our method achieves slightly better performance. One plausible explanation is that our long-range attention has a favorable smoothing effect on noisy representations. These improvements also indicate that different designs of attention and network architecture can be essential for improving robustness. As the goal of this paper is not to design a robust vision transformer, the robustness is an additional bonus of our method. We believe that our observations open new directions for designing robust vision Transformers. We leave the in-depth study as important future work.

The detailed results of ImageNet-C and ImageNet-9 are shown in Table 9 and Table 10 respectively.

Arch. Noise Blur Weather Digital
Gauss. Shot Impulse Defocus Glass Motion Zoom Snow Frost Fog Bright Contrast Elastic Pixel JPEG
ResNet-50 34.24 49.25 55.84 56.24 57.04 63.53 63.68 64.02 64.04 64.89 69.25 70.72 73.14 75.29 75.76
DeiT-S 26.93 36.81 36.89 39.38 40.14 43.32 43.80 44.36 45.71 46.90 47.27 48.57 52.15 57.53 62.91
CvT-LS-13 25.64 36.89 37.06 38.06 43.78 43.78 44.62 45.92 47.77 47.91 49.60 49.66 54.92 57.24 68.72
CvT-LS-17 25.26 35.06 35.48 37.38 41.37 43.95 44.47 46.05 46.17 46.38 49.08 49.37 54.29 54.54 69.54
CvT-LS-21 24.28 34.95 35.03 35.93 39.86 40.71 41.27 41.78 44.72 45.24 45.50 47.19 51.84 53.78 67.05
Table 9: Corruption Error (CE) on ImageNet-C
Model Params (M) ImageNet (%) ImageNet-9 xiao2020noise(%)
Original Mixed-same Mixed-rand
ResNet-50 (he2016deep) 25.6 76.2 94.9 87.1 81.6
DeiT-S (touvron2020training) 22.1 79.8 97.1 89.1 84.2
CvT-LS-13 20.3 81.9 97.0 90.7 85.6
CvT-LS-21 32.1 82.7 97.2 91.5 85.8
Table 10: Robustness evaluation on ImageNet-9. We report Top-1 Accuracy.