Efficient Content-Based Sparse Attention with Routing Transformers

03/12/2020 ∙ by Aurko Roy, et al.

Self-attention has recently been adopted for a wide range of sequence modeling problems. Despite its effectiveness, self-attention suffers from quadratic compute and memory requirements with respect to sequence length. Successful approaches to reduce this complexity have focused on attending to local sliding windows or a small set of locations independent of content. Our work proposes to learn dynamic sparse attention patterns that avoid allocating computation and memory to attend to content unrelated to the query of interest. This work builds upon two lines of research: it combines the modeling flexibility of prior work on content-based sparse attention with the efficiency gains from approaches based on local, temporal sparse attention. Our model, the Routing Transformer, endows self-attention with a sparse routing module based on online k-means while reducing the overall complexity of attention to $O(n^{1.5}d)$ from $O(n^2 d)$ for sequence length $n$ and hidden dimension $d$. We show that our model outperforms comparable sparse attention models on language modeling on Wikitext-103 (15.8 vs 18.3 perplexity) as well as on image generation on ImageNet-64 (3.43 vs 3.44 bits/dim) while using fewer self-attention layers.




1 Introduction

Generative models of sequences have witnessed rapid progress driven by the application of attention to neural networks. In particular, (Bahdanau et al., 2014; Cho et al., 2014; Vaswani et al., 2017) relied on attention to drastically improve the state of the art in machine translation. Subsequent research (Radford et al., 2018; Devlin et al., 2018; Liu et al., 2019; Yang et al., 2019) demonstrated the power of self-attention for learning rich representations of language that address several natural language processing tasks. Self-attention also brought impressive progress for generative modeling outside of language, e.g. image (Parmar et al., 2018; Menick and Kalchbrenner, 2018; Child et al., 2019) and music generation (Huang et al., 2018; Child et al., 2019).

Self-attention operates over sequences in a step-wise manner: at every time-step, attention assigns an attention weight to each previous input element (representation of past time-steps) and uses these weights to compute the representation of the current time-step as a weighted sum of the past input elements (Vaswani et al., 2017). Self-attention (Shaw et al., 2018) is a particular case of attention (Bahdanau et al., 2014; Chorowski et al., 2015; Luong et al., 2015).

Self-attention is commonly used in auto-regressive generative models. These models generate observations step-by-step, modeling the probability of the next symbol given the previously generated ones. At every time step, self-attentive generative models can directly focus on any part of the previous context. In contrast, recurrent neural networks (RNNs) and convolutional neural networks (CNNs) have direct interactions with only a local neighborhood of context around the current time step.

This advantage, however, comes at a price: unlike recurrent or convolutional networks, the time and space complexity of self-attention is quadratic in $n$, the length of the sequence. Specifically, for every position $i \le n$, self-attention computes weights for its whole context of length $i$, which induces a complexity of $\sum_{i \le n} i = n(n+1)/2$. This makes it difficult to scale attention-based models to long sequences. However, long sequences are the norm in many domains, including music, image, speech and video generation, and document-level machine translation.
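The following Python sketch offers a quick sanity check of this count. It enumerates the query/key pairs of causal self-attention and compares the total with the closed form. Following the text, it assumes that position $i$ attends to a context of length $i$. The function names are hypothetical.

```python
def causal_pair_count(n: int) -> int:
    """Count (query, key) pairs in causal self-attention over n positions.

    Position i (1-indexed) attends to its whole context of length i.
    """
    return sum(i for i in range(1, n + 1))


def causal_pair_count_closed_form(n: int) -> int:
    # n(n + 1)/2 pairs, i.e. O(n^2) growth in the sequence length
    return n * (n + 1) // 2


if __name__ == "__main__":
    for n in (1024, 4096, 16384):
        print(n, causal_pair_count(n), causal_pair_count_closed_form(n))
```

Quadrupling $n$ multiplies the count by roughly 16. This is the quadratic growth the paragraph describes.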

Therefore, an important research direction is to investigate sparse and memory-efficient forms of attention in order to scale to tasks with long sequence lengths. Previous work has proposed data-independent or fixed sparsity patterns bounding temporal dependencies, such as local or strided attention. At each time step, the model attends only to a fixed number of time steps in the past (Child et al., 2019). Extensions to local attention have suggested learning the length of the temporal sparsity for each attention module in the network (Sukhbaatar et al., 2019). These strategies draw their inspiration from RNNs and CNNs and bound their complexity by attending only to representations summarizing a local neighborhood of the current time step. Their attention matrices (matrices containing the attention weights for every pair of current and previous time steps) are natively sparse and require instantiating only non-zero entries. While these approaches have achieved good results, fixing the sparsity pattern of a content-based mechanism such as self-attention can limit its ability to pool in information from large contexts.

As an alternative to local attention, (Correia et al., 2019) considers content-based sparsity, an approach allowing for arbitrary sparsity patterns. This formulation however does require instantiating a full dense attention matrix prior to sparsification through variants of -sparsity or sparsemax approximations (Blondel et al., 2019).

The present work builds upon these two lines of research and proposes to retain the modeling flexibility of content-based sparse attention while leveraging the efficiency of natively sparse attention matrices. Our formulation avoids sparsemax variants and relies on clustering of attention instead. Each attention module considers a clustering of the space: the current time-step only attends to context belonging to the same cluster. In other words, the current time-step query is routed to a limited number of context elements through its cluster assignment. This strategy draws inspiration from the application of $k$-means clustering to Non-negative Matrix Factorization (NMF) (Lee and Seung, 2001; Ding et al., 2005; Kim and Park, 2008), which is relevant to the sparsification of non-negative matrices like attention matrices.

Our proposed model, Routing Transformer, combines our efficient clustering-based sparse attention with classical local attention to reach excellent performance both for language and image generation. These results are obtained without the need to maintain attention matrices larger than the batch length, which the segment-level recurrence mechanism used in (Dai et al., 2019; Sukhbaatar et al., 2019) requires. We present experimental results on language modeling (Wikitext-103 and enwik-8) and unconditional image generation (ImageNet-64). Routing Transformer sets a new state of the art while having a comparable or smaller number of self-attention layers and heads, both on Wikitext-103 (15.8 vs 18.3 perplexity) and on ImageNet-64 (3.43 vs 3.44 bits/dim). We also report competitive results on enwik-8 (0.99 vs 0.98 bits per byte).

2 Related Work

Attention with Temporal Sparsity: Research on efficient attention neural models parallels the advent of attention-based architectures. In the context of speech recognition, (Jaitly et al., 2015) proposed the Neural Transducer, which segments sequences into non-overlapping chunks and performs attention in each chunk independently. Limiting attention to a fixed temporal context around the current prediction has also been explored in (Chorowski et al., 2015), while (Chiu and Raffel, 2017) dynamically segment the sequence into variable-sized chunks.

Hierarchical attention strategies have also been explored: the model first considers which part of the inputs should be attended to before computing full attention in a contiguous neighborhood of the selected area (Gregor et al., 2015; Xu et al., 2015; Luong et al., 2015). Later, hierarchical attention was simplified by (Liu et al., 2018), which alternates coarse layers (attending to the whole sequence at a lower temporal resolution) with local layers (attending to a neighborhood of the current prediction).

This alternating strategy is also employed by (Child et al., 2019), which introduces bounded and strided attention, i.e. attending to a fixed context in the past at a subsampled temporal resolution. This work formalizes such a strategy using a sparse attention formalism, showing how it relates to full attention with a specific sparsity pattern in the attention matrix. It shows that sparse attention is sufficient to get state-of-the-art results in modeling long sequences over language modeling, image generation and music generation. (Sukhbaatar et al., 2019) builds upon this work and shows that it is possible to obtain further sparsity by letting the model learn the length of the temporal context for each attention module. This work also makes use of the attention cache introduced in (Dai et al., 2019), a memory mechanism to train models over temporal contexts which extend beyond the length of the training batches.

Attention with Content-Based Sparsity: The above work mainly relies on two efficient ideas: attending to fewer elements by only considering a fixed bounded local context in the past, and attending to fewer elements by decreasing the temporal resolution of context. These ideas do not allow arbitrary sparsity patterns in attention matrices. Content-based sparse attention has been introduced to allow for richer patterns and more expressive models. (Martins and Kreutzer, 2017; Malaviya et al., 2018) propose to compute attention weights with variants of sparsemax. (Correia et al., 2019) generalizes this approach to every layer in a Transformer using entmax, which allows for more efficient inference. This line of work allows for learning arbitrary sparsity attention patterns from data, based on the content of the current query and past context. However, sparsity here cannot be leveraged to improve space and time complexity, since sparsemax/entmax formulations require instantiating the full attention matrix prior to sparsification. This is a drawback compared to temporal sparsity approaches. Our work is motivated by bridging this gap. It allows for arbitrary sparsity patterns while avoiding the instantiation of the zero entries of attention matrices.
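The following minimal NumPy sketch makes this drawback concrete. It implements the standard sort-based sparsemax projection of one row of attention logits onto the probability simplex. It is an illustration only, not the code of the cited works. Note that it needs the whole row `z` of query/key scores before it can decide which entries are zero, so the dense row must be computed first.

```python
import numpy as np


def sparsemax(z):
    """Project a vector of logits onto the probability simplex (sparsemax).

    The whole row z must be available: the threshold tau depends on
    every score, which is why the dense attention row is built first.
    """
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]                 # scores in decreasing order
    cumsum = np.cumsum(z_sorted)
    ks = np.arange(1, z.size + 1)
    in_support = 1.0 + ks * z_sorted > cumsum   # holds for a prefix of ks
    k = ks[in_support][-1]                      # size of the support
    tau = (cumsum[k - 1] - 1.0) / k             # threshold
    return np.maximum(z - tau, 0.0)


print(sparsemax([1.0, 0.0, -1.0]))  # all mass on the first entry
```

Entries below the threshold `tau` come out exactly zero, but every one of them had to be computed in order to find `tau`.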

Sparse Computation beyond Attention: Learning models with sparse representations/activations to save time and computation has been addressed in the past in various contexts. Previous work often refers to this goal as gating for conditional computation. Gating techniques relying on sampling and straight-through gradient estimators are common (Bengio et al., 2013; Eigen et al., 2013; Cho and Bengio, 2014). Conditional computation can also be addressed with reinforcement learning (Denoyer and Gallinari, 2014; Indurthi et al., 2019). Memory augmented neural networks with sparse reads and writes have also been proposed in (Rae et al., 2016) as a way to scale Neural Turing Machines (Graves et al., 2014). In the domain of language modeling, a related work is the sparsely gated Mixture-of-Experts (MoE) (Shazeer et al., 2017), where sparsity is induced by experts and a trainable gating network controls the routing strategy to each sub-network. Another related work is (Lample et al., 2019), who use product-quantization-based key-value lookups to replace the feed-forward network in the Transformer. Our work differs from theirs in that we make use of dynamic key-value pairs to infer sparsity patterns, while their key-value pairs are the same across examples.

3 Self-Attentive Auto-regressive Sequence Modeling

Auto-regressive sequence models decompose the probability of a sequence $x = (x_1, \ldots, x_n)$ as

$$p(x) = p(x_1, \ldots, x_n) = \prod_{i=1}^{n} p(x_i \mid x_{<i}).$$
In neural models, the conditional distribution $p(x_i \mid x_{<i})$ is modeled by a neural network with learned parameters, and these parameters are typically learned to maximize the likelihood of the training data. In particular, Transformer architectures have been shown to reach state-of-the-art accuracy in several domains, including language modeling (Vaswani et al., 2017; Radford et al., 2018), image generation (Parmar et al., 2018) and music generation (Huang et al., 2018). Transformer models compose a series of attention modules. Each module refines the input representation by taking a weighted average of the representations from the previous module.

For every module, the input representation is a sequence of $n$ vectors from a continuous space of dimension $d$. Thus one may treat the input sequence as a matrix $X \in \mathbb{R}^{n \times d}$. A self-attention layer operates on this representation. It first applies three linear projections,

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V,$$

where $Q$, $K$ and $V$ are referred to as queries, keys and values, while $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$ are learned projection matrices.

The key and query matrices determine the attention matrix $A = \mathrm{softmax}(Q K^\top)$, where the softmax operator over matrices denotes that the softmax function has been applied to each row. $A$ may be interpreted as a matrix of weights in $[0, 1]^{n \times n}$, where $A_{ij}$ denotes how much query position $i$ at the next layer must pay attention to key position $j$ at the previous layer. In the case of self-attention for auto-regressive models, queries attend only over keys from previous time-steps, i.e.

$$A = \mathrm{softmax}\big(\mathrm{ltr}(Q K^\top)\big),$$

where $\mathrm{ltr}$ denotes the lower triangular operator. In practice, entries above the diagonal are set to $-\infty$ so that they receive zero weight after the softmax. Given the attention matrix $A$, the next layer representation $X'$ is computed simply as $AV$. In summary,

$$X'_i = \sum_{j=1}^{i} A_{ij} V_j.$$
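The basic mechanism above can be sketched in a few lines of NumPy. This is a minimal single-head illustration under the row convention (one row of `X` per position), with no logit scaling. The function name is hypothetical. The $-\infty$ masking is the usual way to implement the lower-triangular operator so that future keys receive zero weight.

```python
import numpy as np


def causal_self_attention(X, W_Q, W_K, W_V):
    """Single-head causal self-attention, X' = softmax(ltr(Q K^T)) V.

    X has one row per position (shape n x d). No 1/sqrt(d) scaling is
    applied, matching the basic formulation before Transformer extensions.
    """
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    n = X.shape[0]
    logits = Q @ K.T                                     # (n, n) query-key scores
    causal = np.tril(np.ones((n, n), dtype=bool))        # keep keys j <= i
    logits = np.where(causal, logits, -np.inf)           # ltr: future keys get -inf
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(logits)                                   # exp(-inf) = 0
    A = A / A.sum(axis=1, keepdims=True)                 # row-wise softmax
    return A @ V, A


rng = np.random.default_rng(0)
n, d = 6, 4
X = rng.normal(size=(n, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
X_next, A = causal_self_attention(X, W_Q, W_K, W_V)
```

Each row of `A` sums to one and has zeros above the diagonal. The first output row is therefore just the first value vector.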


In practice, Transformer (Vaswani et al., 2017) adds several extensions to this basic self-attention mechanism. In particular, the dot products $Q K^\top$ are scaled by $1/\sqrt{d}$ before the softmax. Moreover, each layer relies on multiple attention heads, i.e. each layer performs multiple projections onto triplets (queries, keys, values), and attention is performed for each head. The attention results from all heads are then concatenated. This strategy allows each head to specialize on different aspects of the input sequence. In addition, Transformer further processes the result of attention through a learnable non-linear transformation (multi-layer perceptron, $\mathrm{MLP}$) followed by a residual connection and a normalization step, i.e.

$$X' = \mathrm{LN}\big(X + \mathrm{MLP}(\mathrm{Attention}(X))\big),$$

where $\mathrm{LN}$ denotes the parameterized normalization step from (Ba et al., 2016) and $\mathrm{Attention}(X)$ denotes the concatenated multi-head self-attention output. A full Transformer model is therefore a chain of such attention modules, preceded by an embedding module (learnable representation for symbols and their positions) and followed by a logistic classification module (learnable linear classifier to predict the next symbol).
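Here is a minimal NumPy sketch of one such attention module: multi-head causal attention, a ReLU MLP, a residual connection and layer normalization, composed as $\mathrm{LN}(X + \mathrm{MLP}(\mathrm{Attention}(X)))$. Each head's logits are scaled by $1/\sqrt{d_h}$, where $d_h$ is the head dimension. Several details are assumptions for illustration:

* the parameter names (`W_Q`, `W_1`, `gain`, ...),
* the ReLU activation,
* the random initialization,
* the absence of an output projection after concatenation.

The original Transformer of Vaswani et al. (2017) also wraps the attention sublayer in its own residual connection and normalization. This sketch omits that step to stay with the single residual described here.

```python
import numpy as np


def layer_norm(x, gain, bias, eps=1e-5):
    """Normalize each row to zero mean and unit variance, then rescale."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias


def causal_head(q, k, v):
    """One causal attention head with 1/sqrt(d_head) logit scaling."""
    n, d_head = q.shape
    logits = q @ k.T / np.sqrt(d_head)
    logits = np.where(np.tril(np.ones((n, n), dtype=bool)), logits, -np.inf)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v


def attention_module(X, p, num_heads):
    """X' = LN(X + MLP(MultiHeadAttention(X))) for one module."""
    n, d = X.shape
    dh = d // num_heads
    Q, K, V = X @ p["W_Q"], X @ p["W_K"], X @ p["W_V"]
    heads = [
        causal_head(Q[:, h * dh:(h + 1) * dh],
                    K[:, h * dh:(h + 1) * dh],
                    V[:, h * dh:(h + 1) * dh])
        for h in range(num_heads)
    ]
    attn = np.concatenate(heads, axis=1)                   # concatenate heads
    hidden = np.maximum(0.0, attn @ p["W_1"] + p["b_1"])   # ReLU MLP layer
    mlp_out = hidden @ p["W_2"] + p["b_2"]
    return layer_norm(X + mlp_out, p["gain"], p["bias"])   # residual + LN


def init_params(d, d_ff, rng):
    """Hypothetical random initialization, for illustration only."""
    s = 1.0 / np.sqrt(d)
    return {
        "W_Q": rng.normal(scale=s, size=(d, d)),
        "W_K": rng.normal(scale=s, size=(d, d)),
        "W_V": rng.normal(scale=s, size=(d, d)),
        "W_1": rng.normal(scale=s, size=(d, d_ff)),
        "b_1": np.zeros(d_ff),
        "W_2": rng.normal(scale=1.0 / np.sqrt(d_ff), size=(d_ff, d)),
        "b_2": np.zeros(d),
        "gain": np.ones(d),
        "bias": np.zeros(d),
    }


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
params = init_params(d=8, d_ff=16, rng=rng)
X_next = attention_module(X, params, num_heads=2)
```

Every operation after attention acts row by row, so the module stays causal: changing the last input row only changes the last output row.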

Our work is interested in the application of the Transformer to long sequences. This is a challenging problem, since the space and time complexity of attention is quadratic in the sequence length $n$. We describe various approaches to sparse attention, including ours, in the next section.

4 Efficient Content-Dependent Sparse Attention

Attention-based models can be problematic for long sequences. For a sequence of length $n$, the full attention matrix $A$, as introduced in Section 3, is $n \times n$-dimensional and can be prohibitive to instantiate. This motivates sparse attention models, i.e. models relying on attention matrices which have a majority of zero entries.

For each query, a sparse attention model defines a set of keys which can be attended to. In the following, we introduce the set $S_i$ as the set of key positions that the query at position $i$ can attend to, i.e.

$$X'_i = \sum_{j \in S_i} A_{ij} V_j.$$

For example, classical causal self-attention can attend to every key up to the current query, which translates to $S_i = \{j \mid j \le i\}$. Most previous work on attention sparsity defined such sets purely based on positions, independently of actual query and key vectors. For example, local attention (Luong et al., 2015) considers attending only to a $w$-long time window prior to the current query, $S_i = \{j \mid i - w < j \le i\}$. (Child et al., 2019) propose block sparse attention, where half the heads perform local attention and half the heads perform strided attention given by $S_i = \{j \mid j \le i,\ (i - j) \bmod s = 0\}$ for a stride $s$. (Sukhbaatar et al., 2019) is also a variant of local attention, where the cardinality of $S_i$ is learned from data with an $L_1$ penalty to trade off sparsity with modeling accuracy.
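The position-based sets above can be written as boolean masks, where entry $(i, j)$ is true when $j \in S_i$. The sketch below is a dense illustration. An efficient implementation would materialize only the allowed entries. The window `w` and stride `s` parameter names are assumptions. `masked_attention` computes $X'_i = \sum_{j \in S_i} A_{ij} V_j$ with the softmax restricted to $S_i$.

```python
import numpy as np


def _grid(n):
    # i indexes queries (rows), j indexes keys (columns)
    return np.arange(n)[:, None], np.arange(n)[None, :]


def causal_mask(n):
    i, j = _grid(n)
    return j <= i                          # S_i = {j | j <= i}


def local_mask(n, w):
    i, j = _grid(n)
    return (j <= i) & (j > i - w)          # S_i = {j | i - w < j <= i}


def strided_mask(n, s):
    i, j = _grid(n)
    return (j <= i) & ((i - j) % s == 0)   # S_i = {j | j <= i, (i - j) mod s = 0}


def masked_attention(Q, K, V, mask):
    """Attention whose softmax for query i only covers keys j in S_i."""
    logits = np.where(mask, Q @ K.T, -np.inf)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ V


print(local_mask(6, 3).sum(axis=1))        # keys per query: [1 2 3 3 3 3]
print(np.flatnonzero(strided_mask(7, 3)[6]))  # query 6 sees keys [0 3 6]
```

All three patterns contain the diagonal $j = i$, so every query has at least one key and the restricted softmax is always well defined.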

These local attention sparsity variants are effective in practice, since correlation between observations naturally decreases with time for many problems. In our experiments, we actually find that local attention is a surprisingly strong baseline in both image generation and language modeling. For example, a scaled-up ImageTransformer (Parmar et al., 2018) gets 3.48 bits/dim compared to the 3.44 bits/dim reported in (Child et al., 2019). Similarly, scaled-up versions of Transformer with local attention and the relative positional encoding scheme of (Shaw et al., 2018) are able to get 19.8 perplexity on Wikitext-103 and 1.10 bits per byte on enwik-8, while the state-of-the-art results using Transformer-XL (Dai et al., 2019) are 18.3 and 0.99 respectively. From an efficiency perspective, local attention is also interesting since sparsity patterns are regular, contiguous in memory and known in advance.

In this work, however, we are interested in a more generic formulation of attention sparsity and would like the sparsity pattern to be informed by the data, i.e., $S_i$ should depend on the query $Q_i$ and the keys $K$ rather than on positions alone. This approach has several modeling advantages: it can accommodate data without a clear ordering over observations. For temporal data, it can also discover patterns with greater sparsity if some types of queries have a longer lasting effect on future observations than others. Content-based sparse attention should, however, be carefully implemented if we need to avoid instantiating full attention matrices at any point in time. For instance, (Correia et al., 2019) infer sparsity from data, but their formulation instantiates a full attention matrix before finding its sparse counterpart. The next section explains how a natively sparse approach can be devised, inspired by non-negative matrix factorization (NMF).

4.1 Routing Attention with Clustering

Our strategy follows the motivation we delineated in the previous section: we model sparse attention matrices with low-rank sparsity patterns relying on $k$-means clustering. Our strategy first assigns queries and keys to clusters. Then only queries and keys from the same cluster are considered for attention.

Precisely, our model projects queries and keys into a routing matrix $R$ as follows:

$$R = \begin{bmatrix} Q \\ K \end{bmatrix} W_R \in \mathbb{R}^{2n \times d},$$

where $W_R \in \mathbb{R}^{d \times d}$ is a fixed random orthonormal routing projection matrix. Every query $Q_i$ thus receives its own routing vector $Q_i W_R$, and every key $K_j$ its own routing vector $K_j W_R$. The vectors of $R$ undergo $k$-means clustering in order to factorize the full attention matrix. The clustering parameters are the centroid vectors $(\mu_1, \ldots, \mu_k) \in \mathbb{R}^{k \times d}$. These parameters are model parameters shared across sequences. They are learned online along with the rest of the parameters, as delineated in (Bottou and Bengio, 1995). Once cluster membership for each position in the sequence is determined, we denote with $\mu(x)$ the cluster corresponding to the routing vector of $x$. This allows us to define our sparse attention strategy as

$$X'_i = \sum_{\substack{j \le i \\ \mu(K_j) = \mu(Q_i)}} A_{ij} V_j.$$

In summary, queries are routed to keys belonging to the same cluster. Therefore, our attention sparsity pattern has rank at most $k$, i.e. it can be written as $U_Q U_K^\top$, where $U_Q \in \{0,1\}^{n \times k}$ and $U_K \in \{0,1\}^{n \times k}$ are binary matrices denoting cluster memberships of queries and keys respectively. Note that, since we route both queries and keys via the routing matrix $R$, queries and keys are assigned to clusters defined by the same set of centroids. It is important to note that this low-rank property only concerns the sparsity pattern. The resulting attention matrix $A \odot (U_Q U_K^\top)$ can be of higher rank ($\odot$ denotes element-wise product).
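The following dense NumPy sketch illustrates the routing rule for one head. It proceeds in three steps:

1. It assigns each routing vector to its nearest centroid. This is a hard assignment, without the balancing used in practice.
2. It builds the binary membership matrices $U_Q$ and $U_K$.
3. It keeps key $j$ for query $i$ only when $j \le i$ and both fall in the same cluster.

The full $n \times n$ mask is materialized only for readability, so this sketch does not realize the method's memory savings. The function names are hypothetical. Giving a zero output to rows with no admissible key is also an assumption.

```python
import numpy as np


def nearest_centroid(R, centroids):
    """Index of the closest centroid (Euclidean distance) for each row of R."""
    d2 = ((R[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    return d2.argmin(axis=1)


def routing_attention(Q, K, V, W_R, centroids):
    """Causal attention restricted to keys routed to the query's cluster."""
    n, k = Q.shape[0], centroids.shape[0]
    q_cluster = nearest_centroid(Q @ W_R, centroids)
    k_cluster = nearest_centroid(K @ W_R, centroids)
    U_Q = np.eye(k, dtype=int)[q_cluster]          # (n, k) query memberships
    U_K = np.eye(k, dtype=int)[k_cluster]          # (n, k) key memberships
    same_cluster = (U_Q @ U_K.T) > 0               # sparsity pattern of rank <= k
    mask = same_cluster & np.tril(np.ones((n, n), dtype=bool))
    logits = np.where(mask, Q @ K.T, -np.inf)
    row_max = logits.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)  # rows with no key
    w = np.exp(logits - row_max)                   # masked entries become 0
    denom = w.sum(axis=1, keepdims=True)
    A = np.divide(w, denom, out=np.zeros_like(w), where=denom > 0)
    return A @ V, mask
```

With $W_Q = W_K$, so that $Q = K$, each position's query and key land in the same cluster, and every row attends at least to itself. With a single centroid, the sketch reduces to ordinary causal attention.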

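We work with queries and keys that are unit vectors, normalizing them (i.e. projecting them onto the unit sphere) before they are used for routing and attention. Note that performing $k$-means on unit vectors is equivalent to the spherical $k$-means algorithm. This differentiable normalization (Ba et al., 2016) is useful to link cluster memberships with proximity of queries and keys, as outlined below. We also assume that the projection matrices $W_Q$ and $W_K$ used to infer queries and keys are close to each other in max norm. More precisely, we assume the existence of an $\varepsilon > 0$ such that $\max_{a,b} |(W_Q - W_K)_{ab}| \le \varepsilon$. This can be enforced by adding an auxiliary loss or by explicitly setting $W_Q = W_K$. This assumption implies that for any vector $x \in \mathbb{R}^d$ it holds that:

$$-\varepsilon \|x\|_1 \mathbf{1} \le x W_Q - x W_K \le \varepsilon \|x\|_1 \mathbf{1},$$

where the inequality is entry-wise and $\mathbf{1}$ is the vector in $\mathbb{R}^d$ with all 1's. In particular, for small enough $\varepsilon$, the query and the key computed from the same input vector are close to each other.

Next, for any pair of positions $(i, j)$, the unit-norm queries and keys satisfy

$$\|Q_i - K_j\|^2 = \|Q_i\|^2 + \|K_j\|^2 - 2\, Q_i^\top K_j = 2 - 2\, Q_i^\top K_j,$$

so a small distance between $Q_i$ and $K_j$ is equivalent to a large dot product $Q_i^\top K_j$. Moreover, since $W_R$ is orthonormal, it is a distance-preserving transform:

$$\|Q_i W_R - K_j W_R\| = \|Q_i - K_j\|.$$

Thus, we have the following implication. Suppose the routing vectors $Q_i W_R$ and $K_j W_R$ both lie within distance $\delta$ of the same centroid $\mu$. Then, by the triangle inequality, $\|Q_i - K_j\| \le 2\delta$, which means that $Q_i^\top K_j \ge 1 - 2\delta^2$. Therefore, when two time steps are assigned to the same cluster due to a small distance, their attention logit $Q_i^\top K_j$ is high. This analysis shows that our clustering-based routing strategy preserves large attention weights as non-zero entries.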
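The facts used in this argument can be checked numerically:

* the identity $\|q - k\|^2 = 2 - 2\, q^\top k$ for unit vectors,
* distance preservation by an orthonormal $W_R$,
* the resulting bound $q^\top k \ge 1 - 2\delta^2$, where $\delta$ is the larger of the two routing-space distances to a shared centroid.

This NumPy sketch draws a random orthonormal $W_R$ via a QR decomposition. It then builds a query and a key whose routing vectors lie near a centroid `mu`. All names are illustrative.

```python
import numpy as np


def unit(x):
    return x / np.linalg.norm(x)


def routing_dot_bound(q, k, mu, W_R):
    """Return (q . k, 1 - 2 delta^2), where delta is the larger of the
    distances from the routing vectors q W_R and k W_R to the centroid mu."""
    delta = max(np.linalg.norm(q @ W_R - mu), np.linalg.norm(k @ W_R - mu))
    return q @ k, 1.0 - 2.0 * delta ** 2


rng = np.random.default_rng(0)
d = 16
W_R, _ = np.linalg.qr(rng.normal(size=(d, d)))   # random orthonormal matrix
mu = unit(rng.normal(size=d))                     # centroid in routing space
# unit query/key whose routing vectors (q @ W_R, k @ W_R) lie close to mu
q = unit((mu + 0.1 * rng.normal(size=d)) @ W_R.T)
k = unit((mu + 0.1 * rng.normal(size=d)) @ W_R.T)

# ||q - k||^2 = 2 - 2 q.k for unit vectors
assert np.isclose(np.linalg.norm(q - k) ** 2, 2 - 2 * q @ k)
# the orthonormal W_R preserves distances
assert np.isclose(np.linalg.norm(q @ W_R - k @ W_R), np.linalg.norm(q - k))
dot, bound = routing_dot_bound(q, k, mu, W_R)
print(dot, bound)   # the dot product is at least the bound
```

The bound is loose when $\delta$ is large. It becomes informative exactly when both routing vectors sit close to the same centroid.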

Since we route attention via the routing matrix $R$, we dub our model Routing Transformer. A visualization of the attention scheme and its comparison to local and strided attention is given in Figure 1. The computational complexity of this variant of sparse attention is $O(nkd + n^2 d / k)$. Cluster assignments correspond to the first term, i.e. they compare $n$ routing vectors to all $k$ centroids in a space of dimension $d$. Query/key dot products correspond to the second term, i.e. assuming balanced clusters, each of the $n$ queries is compared to the $n/k$ keys in its cluster through a dot product of dimension $d$. Therefore the optimal choice of $k$ is $\sqrt{n}$, as in (Child et al., 2019). This reduces overall memory and computational cost to $O(n^{1.5} d)$ instead of $O(n^2 d)$ (Vaswani et al., 2017).
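The trade-off between the two terms can be checked numerically. This sketch (hypothetical helper names) counts operations as $nkd + n \cdot (n/k) \cdot d$, assuming perfectly balanced clusters, and searches for the best $k$ by brute force.

```python
def routing_cost(n: int, k: int, d: int) -> float:
    """Operation count n*k*d + n*(n/k)*d for routing attention.

    n*k*d: compare n routing vectors with k centroids in dimension d.
    n*(n/k)*d: each query is compared with the n/k keys of its cluster.
    """
    return n * k * d + n * (n / k) * d


def best_num_clusters(n: int, d: int) -> int:
    # brute-force search; the analytic optimum is k = sqrt(n)
    return min(range(1, n + 1), key=lambda k: routing_cost(n, k, d))


n, d = 4096, 64
k = best_num_clusters(n, d)
full = n * n * d                        # dense attention: n^2 d
print(k, full / routing_cost(n, k, d))  # prints: 64 32.0
```

For $n = 4096$ the search returns $k = 64 = \sqrt{n}$. Under this count, routing attention then needs $\sqrt{n}/2 = 32$ times fewer operations than full attention.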

In practice, we apply regular online $k$-means to train the cluster centroids. However, in order to infer balanced routing patterns, we define the clusters to be of equal size, roughly $n/k$. For every centroid $\mu_c$, we sort tokens by distance to $\mu_c$, and cluster membership is determined by this threshold (top-$n/k$). This strategy is simple and efficient. In particular, it guarantees that all clusters have the same size, which is highly desirable for computational efficiency on parallel hardware such as GPUs. As a downside, this assignment does not guarantee that each point belongs to a single cluster. In the future, we want to investigate using balanced variants of $k$-means (Banerjee and Ghosh, 2004; Malinen and Fränti, 2014), which are not common in an online setting.
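The following minimal NumPy sketch covers the two pieces described above:

* the balanced assignment, where each centroid takes the $n/k$ closest routing vectors,
* one online $k$-means step for the centroids.

The exponential-moving-average form of the update and the `decay` value are assumptions. This is one common way to run online $k$-means, not a detail given in the text. For the spherical variant, one would additionally renormalize the centroids to unit norm.

```python
import numpy as np


def balanced_assign(R, centroids):
    """Each centroid selects its n // k closest routing vectors (top-(n/k)).

    Returns a boolean membership matrix of shape (k, n). A token may be
    chosen by several centroids or by none, the downside noted in the text.
    """
    n, k = R.shape[0], centroids.shape[0]
    size = n // k
    d2 = ((centroids[:, None, :] - R[None, :, :]) ** 2).sum(axis=-1)  # (k, n)
    closest = np.argsort(d2, axis=1)[:, :size]
    member = np.zeros((k, n), dtype=bool)
    np.put_along_axis(member, closest, True, axis=1)
    return member


def online_kmeans_update(centroids, R, decay=0.999):
    """One online k-means step with an exponential moving average (assumed).

    Each centroid moves toward the mean of the routing vectors whose nearest
    centroid it is; centroids with no assigned vector are left unchanged.
    """
    d2 = ((R[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    nearest = d2.argmin(axis=1)
    updated = centroids.copy()
    for c in range(centroids.shape[0]):
        assigned = R[nearest == c]
        if len(assigned) > 0:
            updated[c] = decay * centroids[c] + (1.0 - decay) * assigned.mean(axis=0)
    return updated


rng = np.random.default_rng(0)
R = rng.normal(size=(16, 8))          # 16 routing vectors of dimension 8
centroids = rng.normal(size=(4, 8))   # k = 4 centroids
member = balanced_assign(R, centroids)
print(member.sum(axis=1))             # every cluster holds n / k = 4 tokens
centroids = online_kmeans_update(centroids, R)
```

Because every row of `member` has exactly $n/k$ entries, all clusters can be processed as equally shaped blocks. This is the property that makes the scheme attractive on parallel hardware.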

5 Experiments

We evaluate our sparse attention model on various generative modeling tasks, including text and image generation. The following sections report our results on Wikitext-103 (Merity et al., 2016), enwik-8 (Mahoney, 2011), as well as ImageNet-64. We find that local attention is a surprisingly strong baseline. Our Routing Transformer outperforms Transformer-XL (Dai et al., 2019) on Wikitext-103 and the Sparse Transformer model of (Child et al., 2019) on ImageNet-64, while matching both on enwik-8. In all our models, we allocate half the heads to do local attention and the other half to route attention as described in Section 4.1. We use the Adam optimizer (Kingma and Ba, 2014) and follow the learning rate schedule described in (Vaswani et al., 2017). We present unconditional samples from our model as a part of the supplementary material.

5.1 Wikitext-103

Wikitext-103 (Merity et al., 2016) is a large public benchmark data-set for testing long-term dependencies in word-level language models. It contains over 100 million tokens from 28K articles extracted from Wikipedia, with an average of 3.6K tokens per article, which makes it a reference data-set for modeling long-term textual dependencies. We train a 10-layer Routing Transformer with 16 heads using the relative position encoding of (Shaw et al., 2018), with attention and ReLU dropout. For routing attention as in Section 4.1, we use the same number of clusters $k$ and the same attention window during both training and evaluation.

We describe our results in Table 2 and compare them to other recent work on sparse or recurrent attention, such as Adaptive Inputs (Baevski and Auli, 2018) and TransformerXL (Dai et al., 2019), as well as to a local attention baseline with relative position encoding (Huang et al., 2018). We find that local attention is a great inductive bias for sparse attention: it outperforms the adaptive span method of (Sukhbaatar et al., 2019). Moreover, our Routing Transformer model achieves a test perplexity of 15.8, improving on the 18.3 obtained by TransformerXL (Dai et al., 2019), while having fewer self-attention layers (10 vs 18) and without the need for segment-level recurrence.

5.2 enwik-8

enwik-8 (Mahoney, 2011) is a data-set used to benchmark text compression algorithms in the context of the Hutter prize. It consists of the first 100M bytes of unprocessed Wikipedia and is typically used to evaluate character-level language models. As in the prior work of (Dai et al., 2019; Child et al., 2019), we benchmark our results against various baselines, including local attention. We train a 12-layer model with 8 attention heads, with attention and ReLU dropout, using the relative position encoding of (Shaw et al., 2018). We report 0.99 bits per byte, matching TransformerXL and Sparse Transformer and slightly behind the 0.98 of Adaptive Transformer (Table 3).

5.3 ImageNet

To evaluate the ability of our model to capture long-term dependencies on a modality other than text, we report results on the ImageNet-64 data-set as used in (Child et al., 2019). For auto-regressive image generation, this data-set consists of images of $64 \times 64 \times 3$ bytes, represented as long sequences of length 12,288 in raster-scan, red-green-blue order. We train a 24-layer model with 16 attention heads, with half the heads performing local attention and the other half routing attention as in Section 4.1. We train our model for roughly as many epochs as (Child et al., 2019). We compare our model to a scaled-up ImageTransformer model with local attention (Parmar et al., 2018) and to the SparseTransformer model of (Child et al., 2019).

We find that local attention (Parmar et al., 2018) is a strong baseline for image generation. It obtains 3.48 bits/dim when scaled up to 24 layers and 16 heads, compared to 3.52 bits/dim for later work like Sub-scale Pixel Networks (SPN) (Menick and Kalchbrenner, 2018). Our Routing Transformer model achieves 3.43 bits/dim (see Table 1), compared to the previous state of the art of 3.44 bits/dim (Child et al., 2019). This shows the advantage of the content-based sparsity formulation of Section 4.1.

Model   Layers   Heads   Bits/dim
Glow (Kingma and Dhariwal, 2018)   -   -   3.81
PixelCNN (Van den Oord et al., 2016)   -   -   3.57
PixelSNAIL (Chen et al., 2017)   -   -   3.52
SPN (Menick and Kalchbrenner, 2018)   -   -   3.52
ImageTransformer (Parmar et al., 2018)   24   16   3.48
Sparse Transformer (Child et al., 2019)   48   16   3.44
Routing Transformer   24   16   3.43
Table 1: Results on image generation on ImageNet in bits/dim.
Model   Layers   Heads   Perplexity
LSTMs (Grave et al., 2016)   -   -   40.8
QRNNs (Merity et al., 2018)   -   -   33.0
Adaptive Transformer (Sukhbaatar et al., 2019)   36   8   20.6
Local Transformer   16   16   19.8
Adaptive Input (Baevski and Auli, 2018)   16   16   18.7
TransformerXL (Dai et al., 2019)   18   16   18.3
Routing Transformer   10   16   15.8
Table 2: Results on language modeling on Wikitext-103 data-set. Local Transformer refers to Transformer (Vaswani et al., 2017) with relative position encoding (Shaw et al., 2018) together with local attention. Perplexity is reported on the test set.
Model   Layers   Heads   Bits per byte
T64 (Al-Rfou et al., 2019)   64   2   1.13
Local Transformer   24   8   1.10
TransformerXL (Dai et al., 2019)   24   8   0.99
Sparse Transformer (Child et al., 2019)   30   8   0.99
Adaptive Transformer (Sukhbaatar et al., 2019)   24   8   0.98
Routing Transformer   12   8   0.99
Table 3: Results on language modeling on enwik-8 data-set. Local Transformer refers to Transformer (Vaswani et al., 2017) with relative position encoding (Shaw et al., 2018) together with local attention. Bits per byte (bpc) is reported on the test set.

6 Analysis

We evaluate the difference in attention patterns between local and routed attention by computing the Jensen-Shannon divergence (JSD) between local attention and routed attention for a random subset of heads in our network on the Wikitext-103 data-set. The divergence is computed over the entire sequence length. We average over multiple runs and report means and standard deviations of the JSD in Table 4. Note that the JSD is always non-negative and is upper-bounded by $\ln 2 \approx 0.6931$ when computed using the natural logarithm. We observe that the divergence between different local heads is always very low compared to the divergence between local and routing attention heads, which is almost always very close to the upper bound of $\ln 2$. Divergence between different routing attention heads falls somewhere in between, closer to the upper bound. This shows that the attention distribution inferred by the routing attention of Section 4.1 is highly non-local in nature, and that different heads specialize in attending to very different parts of the input.
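The following NumPy sketch shows how such divergences could be computed between two heads' attention matrices. Each row is a distribution over keys for one query. Averaging the per-row divergence over query positions is an assumption, since the text does not spell out how per-position values are aggregated. The function names are hypothetical.

```python
import numpy as np


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        nz = a > 0                      # 0 * log(0 / b) is taken as 0
        return float(np.sum(a[nz] * np.log(a[nz] / b[nz])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def mean_attention_jsd(A1, A2):
    """Average JSD between corresponding rows (queries) of two attention matrices."""
    return float(np.mean([js_divergence(a, b) for a, b in zip(A1, A2)]))


print(js_divergence([1.0, 0.0], [0.0, 1.0]), np.log(2))  # disjoint supports hit ln 2
```

Two heads attending to disjoint sets of keys reach the upper bound $\ln 2$. Identical heads give 0.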

layer 0      
layer 1      
layer 2      
layer 3      
layer 4      
layer 5      
layer 6      
layer 7      
layer 8      
layer 9      
Table 4: Jensen-Shannon divergence between the attention distributions of a random local attention head and a random head that routes attention as in Section 4.1, per layer, on the Wikitext-103 data-set. We report means and standard deviations computed over multiple runs and use the natural logarithm, so that divergences are upper-bounded by $\ln 2$.
(a) Local attention
(b) Strided attention
(c) Routing attention
Figure 1: Figures showing 2-D attention schemes for the Routing Transformer compared to local attention and strided attention of (Child et al., 2019). The rows represent the outputs while the columns represent the inputs. For local and strided attention, the colored squares represent the elements every output row attends to. For attention routed as in Section 4.1, the different colors represent cluster memberships for the output token.

7 Conclusion

Transformer models constitute the state of the art in auto-regressive generative models for sequential data. Their space-time complexity is, however, quadratic in sequence length, due to their attention modules. Our work proposes a sparse attention model, the Routing Transformer. It relies on content-based sparse attention motivated by non-negative matrix factorization. Compared with local attention models, it does not require fixed attention patterns but enjoys similar space-time complexity. In contrast with prior work on content-based sparse attention, it does not require computing a full attention matrix but still selects sparsity patterns based on content similarity.

Our experiments over text and image generation lead to two main conclusions. First, we show that a carefully tuned local attention model establishes a strong baseline on modern benchmarks, even compared to recent state-of-the-art models. Second, we show that the Routing Transformer redefines the state of the art on the large long-sequence benchmarks Wikitext-103 and ImageNet-64, while coming very close on enwik-8. Our analysis also shows that routed attention modules offer complementary attention patterns when compared to local attention.

Overall, our work contributes an efficient attention mechanism that applies to the modeling of long sequences and redefines the state of the art for auto-regressive generative modeling. Our approach could prove useful in domains where the inputs are naturally sparse, such as 3D point clouds, social networks or protein interactions.