ETC: Encoding Long and Structured Data in Transformers

April 17, 2020 · Joshua Ainslie et al., Google

Transformer-based models have pushed the state of the art in many natural language processing tasks. However, one of their main limitations is the quadratic computational and memory cost of the standard attention mechanism. In this paper, we present a new family of Transformer models, which we call the Extended Transformer Construction (ETC), that allows for significant increases in input sequence length by introducing a new global-local attention mechanism between a global memory and the standard input tokens. We also show that combining global-local attention with relative position encodings allows ETC to handle structured data with ease. Empirical results on the Natural Questions data set show the promise of the approach.


1 Introduction

Models such as BERT Devlin et al. (2018) based on the Transformer architecture Vaswani et al. (2017) or many of their variants have yielded state-of-the-art performance on an increasing number of natural language processing tasks such as language modeling Rae et al. (2019), question answering Lan et al. (2019), and summarization Zhang et al. (2019). One of the main limitations of the original Transformer architecture is that the computational and memory complexity of its attention mechanism scales quadratically as a function of the input sequence length. Thus most of these models limit their input lengths to 512 tokens. Many extensions of the original attention mechanism have been proposed recently to address this issue, including hierarchical processing of inputs Zhang et al. (2019), sparse attention Child et al. (2019), and segment-level recurrence Dai et al. (2019).

This paper presents the Extended Transformer Construction (ETC) architecture, an extension of the original Transformer architecture with a new attention mechanism that extends the original in two main ways: (1) it allows scaling up the input length from 512 to several thousand tokens; and (2) it allows ingesting structured inputs instead of just linear sequences. Additionally, unlike previous extensions to the Transformer architecture, ETC allows initialization from existing pre-trained standard BERT models, which, together with a GPU/TPU-friendly implementation, allows for efficient model training. (An exception to this is the Longformer Beltagy et al. (2020), a model developed concurrently with ETC, which allows lifting weights from RoBERTa.) Finally, we also show that by using a CPC loss Oord et al. (2018) in addition to the standard MLM loss during pre-training, performance can be improved even further.

The main architectural innovation in ETC is a new flexible attention mechanism called global-local attention, which divides the input into two separate sequences (which we call the global input and the long input). This new attention mechanism introduces local sparsity to reduce the quadratic scaling of attention. As we show below, when coupled with relative position encoding Shaw et al. (2018), it allows for handling structured data in a natural way (take web page data, for example, which has inherent structure in its content, such as headers, blocks, and other markup).

In order to evaluate our model, we report experiments on Google’s Natural Questions (NQ) data set Kwiatkowski et al. (2019), showing significant performance improvements over the base models thanks to the ability to ingest longer inputs. NQ is a challenging question answering data set containing actual user questions issued to Google search, paired with answers found from Wikipedia by annotators.

This preprint presents preliminary results of ETC and is structured as follows. We start by providing some background on existing approaches to extend Transformers to accept long inputs in Section 2. After that, Section 3 presents the technical details of our new model. Section 4 reports experimental results. The paper closes with conclusions and directions for future work.

2 Background

Figure 1: An illustration of various mechanisms used in the literature to scale attention to long inputs.

Since the publication of the original Transformer model Vaswani et al. (2017), and especially after the great success of BERT Devlin et al. (2018), a number of variations of the model have been proposed in the literature. For example, work exists on scaling up the training process (RoBERTa Liu et al. (2019)), scaling the internal representation of the model (ALBERT Lan et al. (2019)), or both (T5 Raffel et al. (2019)), outperforming the original BERT model in tasks such as GLUE Wang et al. (2018), SQuAD Rajpurkar et al. (2016) or RACE Lai et al. (2017). However, these models use input sequences of at most 512 tokens due to computational and memory constraints: namely, the O(n^2) computational and memory cost of attention in Transformers, where n is the length of the input sequence. This work builds on prior efforts to scale up the attention mechanism and network architecture to accommodate longer input sequences. We classify these approaches into four broad categories: sparse attention, recurrence, hierarchical mechanisms, and compressed attention, which we elaborate on below.

Sparse Attention involves limiting each token to attend only to a subset of the other tokens in the input. For example, the Sparse Transformer Child et al. (2019) used predefined attention patterns with applications to both natural language tasks and image generation. They showed, for instance, that attending only to previous pixels in the same row or column was enough for generating high quality images. This reduces attention cost from O(n^2) to O(n*sqrt(n)). Another idea is that of the Adaptive Attention Span Transformer Sukhbaatar et al. (2019), where each attention head is associated with a decaying masking function, which limits the number of tokens it can attend to. By making those masking functions learnable, they show that lower layers tend to learn short attention spans, and it is not until the higher layers of the model that attention spans grow longer. The Reformer Kitaev et al. (2020) model assumes that the result of attention will be dominated by the tokens with the highest attention weights, so it reduces computational cost by calculating the nearest neighbors to the attention query (those input tokens that would result in the highest attention weights) using locality sensitive hashing Andoni et al. (2015) and using only those for attention. This reduces attention cost to O(n log n). Another approach is found in the Routing Transformer Roy et al. (2020), which learns dynamic sparse attention patterns using online k-means, reducing complexity to O(n^1.5). Finally, the approach most related to the work presented in this paper is the Longformer Beltagy et al. (2020), developed concurrently with ETC, which features a global-local attention mechanism very similar to the one we describe below.

Recurrence incorporates elements of recurrent neural networks into Transformer models to enlarge their attention span. This was done in the Transformer-XL Dai et al. (2019) model, where the input sequence is divided into smaller segments of the same length. The input is then processed one segment at a time. At each layer, the model can attend to the embeddings resulting from the layer immediately below, both for the current segment and for the previous input segment. The end effect is that each successive layer is influenced by one additional earlier segment, so the receptive field grows with depth (the current segment plus a growing window of previous segments). This is illustrated in the top-right of Figure 1.

Hierarchical Mechanisms is a common approach where the input sequence is first split into sentences or blocks, which are ingested independently to produce a single embedding representing each sentence or block. A separate BERT-style model then ingests the concatenation of these embeddings. For example, HIBERT Zhang et al. (2019) uses this idea at the sentence level for extractive summarization. This is illustrated in the bottom-left of Figure 1, showing how the number of embeddings that need to be processed by the higher-level BERT model is much smaller than the number of input tokens. Notice that this idea of processing the input hierarchically is not specific to Transformer models; it has also been applied to recurrent neural network models, both at the level of sentences Yang et al. (2016); Miculicich et al. (2018) and blocks Shen et al. (2018). The main downside of this approach is that token-level attention across blocks is not possible, and thus long-range attention only happens at the summarized high level.

Compressed Attention takes the idea of hierarchical attention one step further by selectively compressing certain parts of the input. For example, the BP-Transformer Ye et al. (2019) model builds a binary partitioning tree over the input, and only lets the model attend to the leaves (the raw tokens) for nearby tokens, and higher nodes in the tree (which contain summaries of groups of tokens) as tokens get farther and farther (see Figure 1, middle top). Other ideas include memory compressed attention Liu et al. (2018) where groups of tokens are compressed via a convolution filter before they are attended to, and the Star Transformer Guo et al. (2019) where each token can attend only to its immediate left/right neighbors and to a separate special auxiliary token that represents a summary of the whole input. (See Figure 1, center bottom.) Finally, some models combine these ideas. The Compressive Transformer Rae et al. (2019) compresses tokens in the input sequence that are far away. The idea is that the model benefits more from detailed attention to nearby tokens, but for farther away tokens, summarized/compressed information of the content of those tokens might be enough. As illustrated in Figure 1 (bottom right), the Compressive Transformer integrates ideas from recurrence in the Transformer-XL. However, instead of discarding old segments, those segments are first pushed into a FIFO memory. When they are about to be evicted from this memory, they are moved to a compressed memory (where more than one token from the memory is compressed into a single token of the compressed memory, which is another FIFO queue).

3 Extended Transformer Construction

Our model closely follows the original Transformer architecture Vaswani et al. (2017), with a few key modifications (training losses, relative position encoding, and attention mechanism), which we explain below. In this paper, we consider only the encoder side of the Transformer, and leave the decoder for future work.

3.1 Relative Position Encoding

Inspired by the work of Shaw et al. (2018), ETC uses relative position encodings rather than the absolute position encodings used in the original BERT model. We briefly describe this mechanism here for completeness, but refer the reader to the work of Shaw et al. for full details.

The goal of relative position encodings is to provide the model with information about the relative position of each token in the input sequence with respect to the others. In the most general sense, given the input sequence x = (x_1, ..., x_n), we can see it as a labeled, fully connected, directed graph, where l_{ij} is the label on the edge that connects x_i to x_j. The idea is that these labels translate into learnable weight vectors that modify the way tokens attend to each other.

In the original work of Shaw et al., given a maximum clipping distance k, they define 2k+1 relative position labels l_{-k}, ..., l_{k}, and the label of the edge between two input tokens x_i and x_j depends only on their relative position j - i. For all input pairs whose relative position is larger than k, the label l_k is given, and for those with relative position smaller than -k, l_{-k} is given. Each different label then turns into a learnable vector a_l, which modifies the attention mechanism (exact equations in the next section). (In the original work of Shaw et al., a second vector was used per label, but we have not included this second vector in ETC, as their ablations showed dropping it may not affect performance.)

One advantage of relative position encodings is that they are independent of input length, so it is easy to adapt a model to greater input lengths than those seen during pre-training. Additionally, as in some other recent work Shaw et al. (2019), ETC's attention mechanism uses relative position labels not just for encoding relative positions in a sequence but also to express arbitrary token relations. This allows ETC to support structured data in a very flexible way.
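As a concrete illustration, the sketch below (not the authors' code) shows how clipped relative position labels in the style of Shaw et al. (2018) can be computed: relative positions j - i are clipped to [-k, k], giving 2k + 1 distinct labels, each of which is later mapped to a learnable vector during attention.

```python
import numpy as np

def relative_position_labels(seq_len: int, k: int) -> np.ndarray:
    """Return a (seq_len, seq_len) matrix of label ids in [0, 2k]."""
    positions = np.arange(seq_len)
    relative = positions[None, :] - positions[:, None]   # j - i
    clipped = np.clip(relative, -k, k)                    # clip to [-k, k]
    return clipped + k                                    # shift to [0, 2k]

labels = relative_position_labels(seq_len=8, k=3)
print(labels[0])  # [3 4 5 6 6 6 6 6]: distances beyond k share one label
```

Because the label only depends on the clipped offset, the same label table can be reused for inputs longer than any seen during pre-training.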

3.2 Global-Local Attention

Figure 2: Sparsity diagram showing which attention queries (rows) can attend to which attention keys (columns): a) standard Transformer attention with input size n; b) global-local attention with long input size n_l, global input size n_g, and local radius r; c) how the l2l attention piece is reshaped into a much smaller attention matrix, limited by the local radius.

ETC uses a new attention mechanism called global-local attention. Global-local attention is a generalization of several of the models presented above (such as the standard Transformer, the Star Transformer, and hierarchical mechanisms).

Whereas the input of a standard Transformer is a single token sequence, we split the input in ETC into two separate sequences: the global input x^g = (x^g_1, ..., x^g_{n_g}) and the long input x^l = (x^l_1, ..., x^l_{n_l}), which are treated separately. Typically, the long input contains the regular input a standard Transformer would receive, while the global input contains a much smaller number of auxiliary tokens (n_g << n_l). Attention is then split into four separate pieces (illustrated in Figure 2): global-to-global (g2g), global-to-long (g2l), long-to-global (l2g), and long-to-long (l2l).

The key idea is that attention in the l2l piece (the most computationally expensive piece) is restricted to a fixed local radius r. To compensate for this limited attention span, the tokens in the global input have unrestricted attention, and thus tokens that are arbitrarily far apart in the long input can transfer information to each other through the global input tokens. Thus, the g2g, g2l, and l2g pieces of attention are unrestricted (any token can attend to any other token), while attention in the l2l piece is restricted to the fixed radius r.

The idea is illustrated in Figure 2, which shows sparsity diagrams where the cell at row i and column j is shaded grey if input token i can attend to input token j. As we can see, in a regular Transformer any token can attend to any other, resulting in full attention. The attention mechanism of ETC, illustrated in Figure 2b, however, restricts the l2l piece to a local radius, significantly reducing computational and memory complexity for very long inputs. Conceptually, the l2l attention piece is reshaped into an n_l x (2r+1) matrix, as illustrated in Figure 2c. (In practice, for GPU/TPU implementation efficiency, a slightly different reshaping occurs, while still yielding identical outputs.)
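The following sketch (a hypothetical helper, not the authors' implementation) shows the idea behind this reshaping: each long-input query is scored against only the 2r + 1 keys within its local radius, giving an (n_l, 2r + 1) score matrix instead of the full (n_l, n_l) one.

```python
import numpy as np

def local_attention_scores(q, k, r):
    """q, k: (n_l, d) arrays; returns (n_l, 2r + 1) scores, -inf where padded."""
    n_l, d = q.shape
    scores = np.full((n_l, 2 * r + 1), -np.inf)
    for i in range(n_l):
        lo, hi = max(0, i - r), min(n_l, i + r + 1)
        window = k[lo:hi]                      # keys within the local radius
        offset = lo - (i - r)                  # left padding near the edges
        scores[i, offset:offset + (hi - lo)] = window @ q[i] / np.sqrt(d)
    return scores

rng = np.random.default_rng(0)
q = rng.normal(size=(16, 8)); k = rng.normal(size=(16, 8))
print(local_attention_scores(q, k, r=3).shape)   # (16, 7)
```

A vectorized, batched version of this gather is what makes the l2l piece scale linearly with n_l instead of quadratically.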

Notice that if we take r = 1 and n_g = 1, we recover exactly the Star Transformer described in Section 2. Similarly, if we place all the input tokens in the global input and make n_l = 0, we obtain the standard Transformer attention. In general, attention in ETC is thus O(n_g(n_g + n_l) + n_l(2r + 1)). If we assume n_g = O(r), this results in a complexity of O(r^2 + r*n_l), which is linear in the size of the long input.

Moreover, in order to provide flexibility, per-training-instance Boolean attention matrices M^{g2g}, M^{g2l}, M^{l2g}, and M^{l2l} are assumed to exist, giving further control over which tokens can attend to each other. For example, M^{g2g} contains zeroes for those pairs of input tokens that should not attend to each other.

Specifically, each head of g2g attention works as follows. Given the global input x^g = (x^g_1, ..., x^g_{n_g}), which is a sequence of token representations, the output of attention is z^g = (z^g_1, ..., z^g_{n_g}), where each z^g_i is calculated as follows:

e^{g2g}_{ij} = (x^g_i W^Q)(x^g_j W^K + a^K_{ij})^T / sqrt(d)   if M^{g2g}_{ij} = 1, and -inf otherwise

alpha^{g2g}_{ij} = exp(e^{g2g}_{ij}) / sum_k exp(e^{g2g}_{ik})

z^g_i = sum_j alpha^{g2g}_{ij} (x^g_j W^V)

where M^{g2g}_{ij} is a binary attention mask, W^Q, W^K, and W^V are learnable weight matrices, and the a^K_{ij} are learnable vectors representing the relative position labels between input positions i and j. A different vector is learned for each distinct relative position label. Attention for the other three pieces is analogous. To improve performance, attention is actually only split into 2 pieces internally instead of 4, as g2g+g2l can be computed jointly (top half of Figure 2c), and l2g+l2l can also be computed jointly (bottom half of Figure 2c). Thus, a single softmax is used to jointly calculate alpha^{g2g} and alpha^{g2l}, and another for alpha^{l2g} and alpha^{l2l}.
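The sketch below (numpy, single head, no bias terms; an illustration rather than the authors' implementation) shows how the masked, relative-position-aware attention above can be computed. Here `labels` holds a relative position label id per (i, j) pair and `a_k` holds one learnable vector per label, playing the role of the a^K_{ij} vectors.

```python
import numpy as np

def relative_attention(x, w_q, w_k, w_v, a_k, labels, mask):
    """x: (n, d); w_*: (d, d); a_k: (num_labels, d); labels, mask: (n, n)."""
    d = x.shape[-1]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    rel = a_k[labels]                                    # (n, n, d) per-pair vectors
    scores = (q @ k.T + np.einsum('id,ijd->ij', q, rel)) / np.sqrt(d)
    scores = np.where(mask.astype(bool), scores, -1e9)   # binary attention mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)       # softmax over keys
    return weights @ v

rng = np.random.default_rng(0)
n, d, num_labels = 6, 4, 7
x = rng.normal(size=(n, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
a_k = rng.normal(size=(num_labels, d))
labels = np.clip(np.arange(n)[None, :] - np.arange(n)[:, None], -3, 3) + 3
print(relative_attention(x, w_q, w_k, w_v, a_k, labels, np.ones((n, n))).shape)
```

The same routine covers any of the four attention pieces by changing which sequences supply the queries and keys, which mask matrix is used, and which label vocabulary is active.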

Finally, notice that each of the four attention pieces can, in principle, use a different subset of relative position labels, in order to allow for complex attention patterns (for example, to handle structured data), as we will describe below.

3.3 Handling Long Inputs with Global-Local Attention

Figure 3: Some example attention patterns for (a) handling long inputs, (b) handling structured inputs. A white background means attention is masked via the corresponding attention mask matrix M, and the different colors indicate different relative position labels.

Let us illustrate how this new attention mechanism can be used to handle tasks that require very long inputs. Although ETC allows for very flexible configurations to suit specific applications, one general way to handle long inputs is to recreate a hierarchical attention pattern similar to that used in models like HIBERT, described above. The long sequence of input tokens (e.g., word pieces) is placed in the long input. Then, assuming some division into segments (for example, dividing the long input by sentence, paragraph, or some other unit), we place one segment token in the global input for each segment in the long input. We then use different relative position labels to link the global segment tokens with the word piece tokens, depending on whether or not each word piece token belongs to the segment. Moreover, as we will show in the experiments below, using the attention masks to perform hard masking in one direction (g2l) can bring performance gains. This asymmetric hard masking is illustrated in Figure 3a, where different colors indicate different relative position labels. In this way, although tokens in the long input can only attend to the local neighborhood defined by the radius r, they can indirectly attend to all other tokens in the input sequence via the global tokens (which can all attend to each other), reminiscent of some of the hierarchical mechanisms described in Section 2.

Moreover, different relative position labels can be given to the attention edges between global and long input tokens to indicate the fact that these global segment tokens are “summary tokens”. Finally, if the different segments in the long input are not ordered (e.g., they are different retrieved documents), this can also be indicated by setting all the relative position labels between the segment tokens in the global input to be the same, and setting the mask appropriately so that word pieces from one document do not attend to word pieces of the next document. This eliminates the notion of ordering between them, naturally representing a “set of sequences” structure. This highlights the flexibility of the global-local attention mechanism when coupled with relative position encoding.
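As a small illustration of this configuration (a hypothetical layout, not the authors' preprocessing code), the sketch below builds the hard g2l mask for a long input split into segments: each global segment token may attend only to the word pieces of its own segment.

```python
import numpy as np

def build_g2l_mask(segment_ids):
    """segment_ids: (n_l,) array giving the segment index of each long token.
    Returns an (n_g, n_l) binary mask with n_g = number of segments."""
    n_g = segment_ids.max() + 1
    return (np.arange(n_g)[:, None] == segment_ids[None, :]).astype(np.int32)

# Long input of 10 word pieces split into 3 segments (e.g., sentences).
segment_ids = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
print(build_g2l_mask(segment_ids))
# Row i has ones only over the word pieces of segment i; in this configuration
# the g2g, l2g, and l2l masks stay all-ones (subject to the local radius).
```

The same mask-building idea extends to the "set of sequences" case: word pieces of different documents simply receive no cross-document attention.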

3.4 Handling Structured Inputs with Global-Local Attention

Transformer models are closely related to graph neural networks Scarselli et al. (2008) and graph attention networks Veličković et al. (2017); see Ye et al. (2019). As a matter of fact, a vanilla Transformer is very much like a graph neural network over a fully connected graph. Thanks to the combination of global-local attention and relative position labels, it is possible to exploit this relation to graph neural networks to ingest structured data in a natural way.

This is useful not just for graph data, but also for tackling some natural language processing tasks. Consider, for example, the HotpotQA dataset Yang et al. (2018). In this dataset, the input consists of a collection of sources, and each source is composed of a sequence of sentences. There is no order between the sources, but there is order between the sentences. This structure can easily be captured with relative position labels as follows: consider n sources with m sentences each. The global input will contain n source summary tokens and n x m sentence summary tokens. Attention between all source summary tokens will have the same relative position label, to indicate that there is no order, but attention between sentence summary tokens will use labels that depend on the distance between the sentences. Moreover, special labels can be used to mark the attention between the source summary tokens and the sentence summary tokens to indicate part-of relations. In this way, all the structure of the source data is preserved, and no arbitrary order is introduced just because the input needs to be serialized to be processed by BERT-style models. An illustration of this type of attention pattern is shown in Figure 3b, where we use different colors to indicate different relative position labels.
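The sketch below (hypothetical label ids, not the paper's exact scheme) shows one way the global-input relative position labels for such a sources-and-sentences structure could be assembled: source tokens share an unordered label, sentences within a source use clipped distance labels, and a "part-of" label links sentences to their source.

```python
import numpy as np

SAME_SOURCE = 0   # label between any two source summary tokens (no order)
PART_OF = 1       # label between a sentence token and its source token
UNRELATED = 2     # label between tokens of different sources
DIST_BASE = 3     # distance labels start here: DIST_BASE + clipped(j - i) + k

def structured_labels(n_sources, m_sentences, k=3):
    n_g = n_sources * (1 + m_sentences)       # source tokens, then sentence tokens
    labels = np.full((n_g, n_g), UNRELATED)
    labels[:n_sources, :n_sources] = SAME_SOURCE
    for s in range(n_sources):
        start = n_sources + s * m_sentences
        sent = np.arange(start, start + m_sentences)
        labels[sent, s] = PART_OF              # sentence -> its source
        labels[s, sent] = PART_OF              # source -> its sentences
        rel = np.clip(sent[None, :] - sent[:, None], -k, k)
        labels[np.ix_(sent, sent)] = DIST_BASE + rel + k
    return labels

print(structured_labels(n_sources=2, m_sentences=3).shape)   # (8, 8)
```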

This is a relatively simple structure, but the same idea can be generalized to capture more complex hierarchical structures, or more complex graph structures, such as those found when ingesting web page data with arbitrary markup.

3.5 Pre-training Tasks

We pre-train ETC using two tasks. First, we use a standard masked language model (MLM) task, but with whole word masking, so that if one word piece token is masked, then all the other word-piece tokens of the same word are also masked. Next, instead of training with the next sentence prediction (NSP) task used by BERT, we adapt Contrastive Predictive Coding (CPC) Oord et al. (2018) for ETC.

The key idea of CPC is to predict subsequent inputs in latent space, i.e., rather than predicting the actual next tokens, to predict their internal hidden representations. We adapted this idea in ETC by using the global input sentence summary tokens. For example, given an input sequence containing multiple sentences, we mask all the tokens corresponding to some of those sentences (but leave the sentence summary tokens in the global input). We then train the model to minimize the difference between the hidden representations of the global sentence summary tokens for the masked sentences and the hidden representations obtained when encoding those sentences themselves, using a Noise Contrastive Estimation (NCE) loss, as in the work of Oord et al. (2018).
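The following sketch (illustrative only, assuming the two sets of representations have already been computed) shows an InfoNCE-style loss in the spirit of this CPC task: the summary representation of each masked sentence should match the encoding of the true sentence more closely than those of the other sentences in the batch.

```python
import numpy as np

def info_nce_loss(predicted, targets, temperature=1.0):
    """predicted, targets: (n_masked, d). Row i of `targets` is the positive
    example for row i of `predicted`; all other rows act as negatives."""
    logits = predicted @ targets.T / temperature          # (n_masked, n_masked)
    logits -= logits.max(axis=-1, keepdims=True)          # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -np.mean(np.diag(log_probs))                   # positives on diagonal

rng = np.random.default_rng(0)
predicted = rng.normal(size=(4, 16))   # global summary tokens of masked sentences
targets = rng.normal(size=(4, 16))     # encodings of the actual sentences
print(info_nce_loss(predicted, targets))
```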

3.6 Lifting Weights from Existing Models

Pre-training for large inputs can be expensive. However, the structural similarities between ETC and BERT are enough so that BERT parameters are useful to perform a warm start. The parameters are compatible because the global-local attention mechanism in ETC includes BERT as a special case if the input is small enough (or the local radius is large enough) to eliminate sparsity completely.

There are two primary differences from BERT. First, ETC allows the addition of new global tokens, so pre-training helps the model learn how to use them if present. Second, ETC uses relative position encodings instead of absolute position encodings, so pre-training is required to learn these parameters. Despite these differences, initializing from BERT weights can result in convergence in less than half of the number of epochs compared to training from scratch.
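A minimal sketch of this warm start (hypothetical parameter names, not the actual checkpoint-loading code): copy every BERT parameter whose name and shape match an ETC parameter, and leave the ETC-specific parameters (relative position vectors, any new global-token embeddings) at their fresh initialization.

```python
def lift_weights(bert_params: dict, etc_params: dict) -> dict:
    """Both dicts map parameter names to arrays; returns warm-started ETC params."""
    lifted = dict(etc_params)
    for name, value in bert_params.items():
        if name in lifted and lifted[name].shape == value.shape:
            lifted[name] = value   # shared structure: reuse the BERT weight
    return lifted
```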

4 Empirical Evaluation

Model Input length CPC loss g2l mask Long answer F1 Short answer F1
BERT 512 no no 0.634 0.475
ETC 512 no no 0.638 0.476
ETC 4096 no no 0.706 0.502
ETC 4096 yes no 0.711 0.505
ETC 4096 no yes 0.712 0.520
ETC 4096 yes yes 0.717 0.529
ETC (2x pre-train) 4096 yes no 0.716 0.522
ETC (2x pre-train) 4096 yes yes 0.728 0.534
Table 1: Empirical results on the Natural Questions data set.

This section presents a preliminary evaluation of ETC, compared against the original BERT model.

4.1 Data Set

We used Google’s Natural Questions (NQ) data set Kwiatkowski et al. (2019) for these experiments. In this data set, the input consists of a question and a whole Wikipedia article. The task is to identify both a short answer (a few words from the article) and a long answer (e.g., a whole paragraph), if they exist within the article (and otherwise return null answers). Performance is measured by the F1 score of the model predictions with respect to the human-generated answers. The data set contains 307,373 training examples, 7,830 development examples, and 7,842 test examples. To give an idea of the input lengths in this dataset, the median length of the development set examples is 3,258 word piece tokens and the maximum is 77,962, clearly longer than the usual 512-token input length of models like BERT.

4.2 Training Configuration

All the models in our experiments were trained using the Base setup of BERT (768 hidden size and 12 layers). The local radius for global-local attention was set to r = 84 in all experiments, with a fixed maximum clipping distance for the relative position labels. We considered two input configurations for the ETC model: one with a long input length of 512 tokens and a global input length of 128 tokens, to make it comparable to BERT, and one with a long input length of 4096 tokens and a global input length of 230 tokens. (Although we have not attempted to push the limit, in our internal experimentation we have seen that ETC can scale to 16384 tokens with ease using gradient checkpointing and BERT Base settings on TPU v3, but we limit our experiments to 4096 tokens in this paper.) The size of the global input was chosen so that there are enough global sentence summary tokens to capture the different long answer candidates. For pre-training, all the word piece tokens in the input were placed in the long input, and the global input consisted of one sentence summary token per sentence in the long input. We used the same 30k English uncased word piece vocabulary used by BERT.

Models were pre-trained using the original BERT datasets, except that documents with fewer than 7 sentences were filtered out. Unless stated otherwise, models were pre-trained for 33 epochs. We used the LAMB optimizer You et al. (2019), with the learning rate set to the recommended scale for an 8x batch size increase relative to BERT, although for 4096-token pre-training we scale the tokens per batch rather than the batch size. Both BERT and ETC models were trained from scratch in these experiments.

When pre-training models with 4096-token windows, we split any input documents that are longer than this. For efficiency, we also concatenate as many shorter documents as will fit into the 4096 window and mask attention to prevent them from influencing each other. This results in a roughly 3x speedup in pre-training time, highlighting once more the advantage of flexible masking.

After pre-training, all models were fine-tuned on the NQ training set with a hyperparameter sweep over learning rates and number of epochs, with a batch size of 64, using the Adam optimizer. All reported results are on the development set.

For NQ instances that are longer than the long input size, a sliding window approach is used (with a stride of 2048 for input length 4096, and 128 for input length 512). Final predictions are then aggregated as in the work of Alberti et al. (2019).
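The sliding window splitting can be sketched as follows (an illustration, not the exact NQ preprocessing): a document longer than the long input size is cut into overlapping windows whose start positions are `stride` tokens apart, and each window is fed to the model separately before aggregating predictions.

```python
def sliding_windows(tokens, window_size=4096, stride=2048):
    """Split a token list into overlapping windows of at most window_size."""
    windows = []
    for start in range(0, max(1, len(tokens) - window_size + stride), stride):
        windows.append(tokens[start:start + window_size])
        if start + window_size >= len(tokens):
            break
    return windows

doc = list(range(10000))
spans = sliding_windows(doc)
print([(w[0], w[-1]) for w in spans])   # [(0, 4095), (2048, 6143), ...]
```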

4.3 Results

Results are shown in Table 1. The first two rows show a comparison of the standard BERT model with ETC when using 512 tokens as the input length. As the table shows, the performance of the two models is comparable. The smaller local radius of ETC (84) in these experiments puts ETC at a disadvantage with respect to BERT, but other ETC improvements, such as dynamic whole word masking and relative position encodings, seem to compensate. These results establish a baseline for ETC performance using the same input lengths as BERT.

The remaining rows in Table 1 show the performance of ETC when we increase the long input length to 4096 tokens. The first thing we see is that performance jumps significantly, from 0.638/0.476 F1 for long/short answers respectively to 0.706/0.502 F1. Moreover, turning the CPC loss on yields a further slight bump in performance (to 0.711/0.505 F1), which we attribute to CPC helping the model make better use of the global input summary tokens. Additionally, turning on hard g2l masking gives another performance boost. When g2l masking is off, all global tokens can see all long input tokens, and we just use different relative position labels to differentiate the word piece tokens belonging to a given sentence summary token from those that do not. When g2l masking is on, as shown in Figure 3a, global sentence summary tokens can only attend to the word piece tokens associated with them, and are blind to the rest. Together with the CPC loss, this pushes the results to 0.717/0.529 F1. This highlights the potential benefits of the combination of CPC loss, relative position labels, and flexible masking.

The bottom two rows of Table 1 show results where we increased the amount of pre-training in ETC. As we can see, this results in further gains in performance, reaching 0.728/0.534 F1 when CPC loss and g2l masking are activated.

5 Conclusions

This paper introduced the Extended Transformer Construction, or ETC, a novel extension of the original Transformer model designed specifically to (1) scale up the input length to sequences longer than 512 tokens (scaling linearly in the size of the input), and (2) allow ingesting structured inputs. ETC also allows lifting weights from existing BERT models, saving significant computational resources during training. The key ideas that enable this are a new global-local attention mechanism coupled with relative position encodings.

Our experimental results show that significant gains can be obtained thanks to the increased input sequence length. Additionally, the combination of relative position encodings, flexible masking, and the CPC loss present in ETC further improves model quality. We hypothesize that the gains stemming from CPC come from helping the model learn to use the higher-level global input summary tokens, as CPC plays a role akin to MLM but at the global input level. Flexible masking, in turn, seems to help the model focus the attention of the global input tokens, further improving performance.

As part of our future work, we would like to further analyze the ability of ETC to handle structured data, investigate the addition of complementary attention mechanisms like those of Reformer Kitaev et al. (2020) or Routing Transformer Roy et al. (2020), and also explore its scalability limits by incorporating ideas such as those from RevNet Gomez et al. (2017) that can save significant amounts of memory in certain internal operations.

References