The design choices in the Transformer attention mechanism, including weak inductive bias and quadratic computational complexity, have limited its application for modeling long sequences. In this paper, we introduce Mega, a simple, theoretically grounded, single-head gated attention mechanism equipped with (exponential) moving average to incorporate inductive bias of position-aware local dependencies into the position-agnostic attention mechanism. We further propose a variant of Mega that offers linear time and space complexity yet yields only minimal quality loss, by efficiently splitting the whole sequence into multiple chunks with fixed length. Extensive experiments on a wide range of sequence modeling benchmarks, including the Long Range Arena, neural machine translation, auto-regressive language modeling, and image and speech classification, show that Mega achieves significant improvements over other sequence models, including variants of Transformers and recent state space models.
Designing a single unified model to capture long-range dependencies in sequential data across a diverse range of modalities, such as language, audio, image and video, is a central and challenging problem in sequence modeling. A number of different architectures have been developed, including convolutional neural networks (CNNs)
(Kim, 2014; Strubell et al., 2017), recurrent neural networks (RNNs)
(Goller and Kuchler, 1996; Hochreiter and Schmidhuber, 1997; Cho et al., 2014), Transformers (Vaswani et al., 2017) and recent state space models (SSMs) (Gu et al., 2022a; Mehta et al., 2022). Among these models, the Transformer architecture (Vaswani et al., 2017) has stood out for its impressive empirical success on a wide range of language and vision tasks, including machine translation (Vaswani et al., 2017; Ott et al., 2018), language understanding (Devlin et al., 2019; Liu et al., 2019), image recognition (Dosovitskiy et al., 2020; Touvron et al., 2021) and genetic sequence modeling (Madani et al., 2020; Jumper et al., 2021), mainly because of the conceptually attractive attention mechanism (Bahdanau et al., 2015; Luong et al., 2015; Vaswani et al., 2017) which directly models interactions between each pair of input tokens.
Model | LRA (Acc.) | WMT14 (BLEU) | WT103 (PPL.) | ImageNet (Acc.) | SC (Acc.) |
---|---|---|---|---|---|
XFM | 59.24 | 27.68 | 18.66 | 81.80 | ✗ |
S4 | 85.86 | – | 20.95 | – | 97.50 |
Mega | 88.21 | 29.01 | 18.07 | 82.35 | 97.30 |
Table 1: Results of Transformer (XFM), S4 and Mega on five sequence modeling benchmarks of different types of data, including long range arena (LRA), machine translation (WMT14 en-de), language modeling (WikiText-103), image classification (ImageNet-1k), raw speech classification (SC-Raw).
Attention provides the key mechanism that captures contextual information from the entire sequence by modeling pairwise interactions between the inputs at every timestep. However, there are two common drawbacks in the design of attention mechanism: i) weak inductive bias; and ii) quadratic computational complexity. First, the attention mechanism does not assume prior knowledge of the patterns of dependencies between tokens (e.g. positional inductive bias), instead learning to predict the pairwise attention weights directly from data. Second, the cost to compute and store the attention weights is quadratic in the length of the input sequences. Recent studies have shown the limitations of applying Transformers to long sequence tasks, w.r.t both accuracy and efficiency (Tay et al., 2020).
In this work, we propose a moving average equipped gated attention mechanism (Mega) to solve the two weaknesses simultaneously. The key idea is to incorporate inductive biases into the attention mechanism across the timestep dimension, by leveraging the classic exponential moving average (EMA) approach (Hunter, 1986). EMA captures local dependencies that exponentially decay over time (see Figure 1), and has been widely used in time series data modeling (§2). We introduce a multi-dimensional damped form of EMA with learnable coefficients (§3.1), and subsequently develop the moving average equipped gated attention mechanism by integrating the EMA with a variant of the single-head gated attention (Hua et al., 2022) (§3.2). Theoretically, we show that the single-head gated attention is as expressive as the most commonly used multi-head attention (§3.3). Benefiting from the incorporated moving average mechanism, we further propose a variant of Mega with linear complexity, named Mega-chunk, which simply chunks input sequences into fixed blocks with minimal loss of contextual information (§3.5).
Experimentally, through five sequence modeling tasks across various data types, including long-context sequence modeling, neural machine translation, auto-regressive language modeling, and image and speech classification, we demonstrate that Mega significantly outperforms a variety of strong baseline models, in terms of both effectiveness and efficiency (§4) (see Table 1). These improvements illustrate the importance of modeling long- and short-term dependencies via different patterns of inductive biases.
In this section, we set up notations, briefly review two widely used approaches for sequence modeling—the self-attention mechanism and exponential moving average (EMA)—and discuss the motivation for combining them.
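We use $X = \{x_1, x_2, \ldots, x_n\} \in \mathbb{R}^{n \times d}$ to denote a sequence of input representations with length $n$. Let $Y = \{y_1, y_2, \ldots, y_n\} \in \mathbb{R}^{n \times d}$ be the sequence of output representations of each layer, with the same length $n$ as the input $X$. In this paper, we assume the representations of the input and output sequences have the same dimension $d$.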
The traditional self-attention mechanism is a function:
$$ Y = \mathrm{Attn}(X) = f\left(\frac{Q K^{T}}{\tau(X)}\right) V \tag{1} $$
where $\mathrm{Attn}: \mathbb{R}^{n \times d} \rightarrow \mathbb{R}^{n \times d}$ is the self-attention function. $Q = X W_q + b_q$, $K = X W_k + b_k$ and $V = X W_v + b_v$ are the sequences of queries, keys and values, with learnable parameters $W_q, W_k, W_v \in \mathbb{R}^{d \times d}$, and $b_q, b_k, b_v \in \mathbb{R}^{d}$. $f(\cdot)$ is an attention function, e.g. the softmax function $f_{\mathrm{softmax}}(\cdot)$ (Bahdanau et al., 2015), or the recently proposed squared ReLU function $f_{\mathrm{relu}^2}(\cdot)$ (So et al., 2021; Hua et al., 2022). $\tau(X)$ is a scaling term, which is commonly set to $\tau(X) = \sqrt{d}$ for $f_{\mathrm{softmax}}(\cdot)$, or $\tau(X) = n$ for $f_{\mathrm{relu}^2}(\cdot)$. The commonly used multi-head variant of attention performs the attention function $h$ times in parallel.
We can define a matrix $A = f\left(\frac{Q K^{T}}{\tau(X)}\right) \in \mathbb{R}^{n \times n}$ following (1), which is called the attention matrix, as it specifies the weight of the dependency strength between every pair of tokens in $X$. Since it models pairwise dependency weights, the matrix $A$ in principle delivers a flexible and powerful mechanism to learn long-distance dependencies with minimal inductive biases. However, it is in practice a challenging task to recognize all the dependency patterns in $A$ directly from data, particularly when processing long sequences. Moreover, calculating $A$ with $h$ attention heads takes $O(hn^2)$ time and space, and the quadratic dependency on sequence length becomes a significant bottleneck.
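The attention matrix described above can be sketched in a few lines of plain Python. This is a minimal illustration rather than the authors' implementation: the function names `attention_matrix` and `attend` are hypothetical, the queries, keys and values are taken to be the raw inputs (identity projections, no bias), and the scaling follows the convention stated in the text (the square root of the dimension for softmax, the sequence length for squared ReLU).

```python
import math

def attention_matrix(Q, K, f="softmax"):
    """Return the n x n matrix A = f(Q K^T / tau) for list-of-lists Q and K."""
    n, d = len(Q), len(Q[0])
    tau = math.sqrt(d) if f == "softmax" else n  # scaling term, per the text
    scores = [[sum(q * k for q, k in zip(Q[i], K[j])) / tau for j in range(n)]
              for i in range(n)]
    if f == "softmax":
        A = []
        for row in scores:
            m = max(row)                      # subtract the max for stability
            e = [math.exp(s - m) for s in row]
            z = sum(e)
            A.append([v / z for v in e])
        return A
    # Squared ReLU: max(0, s)^2 element-wise; rows are not normalized.
    return [[max(0.0, s) ** 2 for s in row] for row in scores]

def attend(A, V):
    """Output row i is the A[i]-weighted sum of the value vectors."""
    return [[sum(A[i][j] * V[j][c] for j in range(len(V)))
             for c in range(len(V[0]))] for i in range(len(A))]

# Toy input with n = 3 tokens and d = 2 (illustrative values only).
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = attention_matrix(X, X)   # every pair of tokens gets a weight: n * n entries
Y = attend(A, X)
```

Both the time and the memory of `attention_matrix` grow with the square of the sequence length, which is the bottleneck the paragraph refers to.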
The moving average is a classic approach for sequential data modeling, which has been widely used in time series data to smooth out short-term fluctuations and highlight long-term trends or cycles. The Exponential Moving Average (EMA) (Winters, 1960; Hunter, 1986), a special case of moving average, applies weighting factors that decrease exponentially. Formally, an EMA recursively calculates the output sequence $Y$:
$$ y_t = \alpha \odot x_t + (1 - \alpha) \odot y_{t-1}, \tag{2} $$
where $\alpha \in (0, 1)^{d}$ is the EMA coefficient representing the degree of weighting decrease, and $\odot$ is the element-wise product. A higher $\alpha$ discounts older observations faster (see Figure 1).
Using an EMA places a strong inductive bias on the learning of pairwise dependencies: the dependency weight between two tokens decreases exponentially over time with an input-agnostic decay factor $\alpha$. This property favors local dependencies, and limits long-distance dependencies. Despite the recurrent formulation in (2), the computation of EMA can be represented as individual convolutions, which can be computed efficiently using fast Fourier transforms (FFTs) (see Appendix A for details).
As discussed in Sections 2.1 and 2.2, EMA and attention mechanisms each have their own limitations, despite their wide applications and impressive successes in sequence modeling. By leveraging their properties to complement each other, we propose to embed an EMA into the calculation of the attention matrix $A$. The resulting model benefits from a strong inductive bias while maintaining the capacity to learn complex dependency patterns. Moreover, this integration enables the design of a computationally efficient chunk-wise attention mechanism with linear complexity w.r.t. sequence length (§3.5).
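The claim that the recurrent EMA can be computed as a convolution can be checked on a toy example. The sketch below is a hedged illustration for a single feature dimension with a scalar coefficient; the function names are hypothetical, the initial output is assumed to be zero, and the convolution is evaluated directly rather than with an FFT, which would compute the same sums faster on long sequences.

```python
def ema_recurrent(x, alpha):
    """Standard EMA: y_t = alpha * x_t + (1 - alpha) * y_{t-1}, with y_0 = 0."""
    y, prev = [], 0.0
    for xt in x:
        prev = alpha * xt + (1.0 - alpha) * prev
        y.append(prev)
    return y

def ema_convolution(x, alpha):
    """Same output as a causal convolution with kernel alpha * (1 - alpha)^k."""
    n = len(x)
    kernel = [alpha * (1.0 - alpha) ** k for k in range(n)]
    return [sum(kernel[k] * x[t - k] for k in range(t + 1)) for t in range(n)]

x = [1.0, 2.0, -1.0, 0.5, 3.0]          # illustrative input
y_rec = ema_recurrent(x, alpha=0.3)
y_conv = ema_convolution(x, alpha=0.3)  # agrees with y_rec up to rounding
```

The kernel weights decay geometrically with the distance `k`, which is the input-agnostic exponential decay described in the text.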
In this section, we describe in detail our proposed method, moving average equipped gated attention (Mega). We first introduce multi-dimensional damped EMA (§3.1), which is a key component combined with the single-head gated attention in Mega (§3.2), and discuss the relationship between Mega and three closely related models: GRU (Cho et al., 2014), Flash (Hua et al., 2022) and S4 (Gu et al., 2022a). We also provide theoretical justification for the design of single-head gated attention (§3.3). Then, we describe the detailed architecture of each Mega block, including feed-forward and normalization layers (§3.4). At last, we present Mega-chunk, a variant of Mega that simply splits input sequences into fixed chunks, reducing time and space complexity from quadratic to linear (§3.5).
Mega introduces a modification of the standard EMA, named multi-dimensional damped EMA, to improve its flexibility and capacity.
Previous studies (McKenzie and Gardner Jr, 2010; Svetunkov, 2016) have shown that relaxing the coupled weights of the previous and current observations ($\alpha$ vs. $1 - \alpha$ in (2)) produces robust dependency modeling. Inspired by this, Mega allows the damping of the influence of the previous time step:
$$ y_t = \alpha \odot x_t + (1 - \alpha \odot \delta) \odot y_{t-1}, \tag{3} $$
where $\delta \in (0, 1)^{d}$ is the damping factor.
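To further improve the expressiveness of EMA, we introduce a multi-dimensional variant of EMA. Concretely, we first expand each dimension of the input sequence $X$ individually into $h$ dimensions via an expansion matrix $\beta \in \mathbb{R}^{d \times h}$. Formally, for each dimension $j \in \{1, 2, \ldots, d\}$:
$$ u_t^{(j)} = \beta_j \, x_{t,j} \tag{4} $$
where $\beta_j \in \mathbb{R}^{h}$ is the $j$-th row of $\beta$, and $u_t^{(j)} \in \mathbb{R}^{h}$ is the expanded $h$-dimensional vector for the $j$-th dimension at timestep $t$.
Correspondingly, we extend the shape of $\alpha$ and $\delta$ from a one-dimensional vector to a two-dimensional matrix, i.e. $\alpha, \delta \in \mathbb{R}^{d \times h}$, where $\alpha_j, \delta_j \in \mathbb{R}^{h}$ denote the $j$-th row of $\alpha$ and $\delta$, respectively. Then, for each dimension $j$, the damped EMA is applied to the $h$-dimensional hidden space:
$$ h_t^{(j)} = \alpha_j \odot u_t^{(j)} + (1 - \alpha_j \odot \delta_j) \odot h_{t-1}^{(j)}, \qquad y_{t,j} = \eta_j^{T} h_t^{(j)} \tag{5} $$
where $h_t^{(j)} \in \mathbb{R}^{h}$ is the EMA hidden state for the $j$-th dimension at timestep $t$. $\eta \in \mathbb{R}^{d \times h}$ is the projection matrix to map the $h$-dimensional hidden state back to the $1$-dimensional output $y_{t,j} \in \mathbb{R}$. $\eta_j \in \mathbb{R}^{h}$ is the $j$-th row of $\eta$. The output $Y$ from (5) is denoted as $Y = \mathrm{EMA}(X)$. Because we do not need to explicitly compute $h_t^{(j)}$ to get the output $y_{t,j}$, the time and space complexity is similar to the standard EMA in (2) (see Appendix A for the details). Experimental improvements demonstrate its effectiveness (§4).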
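As a concrete reading of the multi-dimensional damped EMA, the sketch below runs the recurrence for one input dimension. It is a minimal illustration under stated assumptions: the update is taken to be "hidden = alpha * (beta * x) + (1 - alpha * delta) * hidden" followed by a projection with `eta`, the initial hidden state is zero, and the function name `damped_ema_1d` and all parameter values are hypothetical.

```python
def damped_ema_1d(x, alpha, delta, beta, eta):
    """Damped EMA for a single input dimension with an h-dimensional hidden state.

    x is a list of scalars; alpha, delta, beta, eta are lists of length h.
    """
    h_dim = len(alpha)
    state = [0.0] * h_dim                      # assumed zero initial state
    out = []
    for xt in x:
        state = [alpha[i] * beta[i] * xt        # expanded input, weighted by alpha
                 + (1.0 - alpha[i] * delta[i]) * state[i]   # damped carry-over
                 for i in range(h_dim)]
        out.append(sum(eta[i] * state[i] for i in range(h_dim)))  # project back
    return out

# With h = 1 and delta = beta = eta = 1 the recurrence reduces to the standard EMA.
plain = damped_ema_1d([1.0, 1.0], alpha=[0.5], delta=[1.0], beta=[1.0], eta=[1.0])

# A two-dimensional hidden state mixes a fast and a slow decay (illustrative values).
mixed = damped_ema_1d([1.0, 0.0, 0.0], alpha=[0.9, 0.2], delta=[1.0, 0.5],
                      beta=[1.0, 1.0], eta=[0.5, 0.5])
```

Each hidden component decays at its own rate, so the projected output can combine several decay patterns for the same input dimension.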
The gated attention mechanism in Mega adopts the Gated Recurrent Unit (GRU; Cho et al. (2014)) and Gated Attention Unit (GAU; Hua et al. (2022)) as the backbone architectures, with an EMA-based sub-layer embedded into the calculation of the attention matrix. Formally, we first use the output from (5) to compute the shared representation in GAU:
$$ X' = \mathrm{EMA}(X) \in \mathbb{R}^{n \times d} \tag{6} $$
$$ Z = \phi_{\mathrm{silu}}(X' W_z + b_z) \in \mathbb{R}^{n \times z} \tag{7} $$
where $X'$ can be regarded as the updated or contextual input, because it encodes contextual information through EMA. $Z$ is the shared representation with $z$ dimensions, with projection matrix $W_z \in \mathbb{R}^{d \times z}$ and bias term $b_z \in \mathbb{R}^{z}$. $\phi_{\mathrm{silu}}$ is the self-gated activation function (SiLU) (Ramachandran et al., 2017; Elfwing et al., 2018). Following GAU, the query and key sequences are computed by applying per-dimension scalars and offsets to $Z$, and the value sequence is from the original $X$:
$$ Q = \kappa_q \odot Z + \mu_q \in \mathbb{R}^{n \times z} \tag{8} $$
$$ K = \kappa_k \odot Z + \mu_k \in \mathbb{R}^{n \times z} \tag{9} $$
$$ V = \phi_{\mathrm{silu}}(X W_v + b_v) \in \mathbb{R}^{n \times v} \tag{10} $$
where $\kappa_q$, $\mu_q$, $\kappa_k$, $\mu_k \in \mathbb{R}^{z}$ are the learnable scalars and offsets of queries and keys, respectively. $v$ is the expanded intermediate dimension for the value sequence. The output of attention is computed as follows:
$$ O = f\left(\frac{Q K^{T}}{\tau(X)} + b_{\mathrm{rel}}\right) V \in \mathbb{R}^{n \times v} \tag{11} $$
The graphical specification is displayed in Figure 2 (c). $b_{\mathrm{rel}} \in \mathbb{R}^{n \times n}$ is the relative positional bias. We choose $b_{\mathrm{rel}}$ from existing approaches, including T5 (Raffel et al., 2020), RoPE (Su et al., 2021), TUPE (Ke et al., 2020) and ALiBi (Press et al., 2021).
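Subsequently, Mega introduces the reset gate $\gamma$, the update gate $\varphi$, and computes the candidate activation output $\hat{H}$:
$$ \gamma = \phi_{\mathrm{silu}}(X' W_\gamma + b_\gamma) \in \mathbb{R}^{n \times v} \tag{12} $$
$$ \varphi = \phi_{\mathrm{sigmoid}}(X' W_\varphi + b_\varphi) \in \mathbb{R}^{n \times d} \tag{13} $$
$$ \hat{H} = \phi_{\mathrm{silu}}(X' W_h + (\gamma \odot O) U_h + b_h) \in \mathbb{R}^{n \times d} \tag{14} $$
The final output $Y$ is computed with the update gate $\varphi$:
$$ Y = \varphi \odot \hat{H} + (1 - \varphi) \odot X \in \mathbb{R}^{n \times d} \tag{15} $$
The graphical architecture of a Mega sub-layer is visualized in Figure 2 (b).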
As mentioned in Section 2.1, the softmax function is the most common choice for the attention function $f(\cdot)$. So et al. (2021) recently introduced the squared ReLU function $f_{\mathrm{relu}^2}(\cdot)$ via architecture search techniques, which has shown faster convergence speed and competitive generalization performance on language tasks (Hua et al., 2022). However, one issue of $f_{\mathrm{relu}^2}(\cdot)$ is that neither its range nor its gradient is bounded, leading to unstable model training (see Appendix C.1 for details). To address this issue, we propose a new attention function based on the Laplace function:
$$ f_{\mathrm{laplace}}(x; \mu, \sigma) = 0.5 \times \left[ 1 + \mathrm{erf}\left( \frac{x - \mu}{\sigma \sqrt{2}} \right) \right] \tag{16} $$
where $\mathrm{erf}(\cdot)$ is the error function. $\mu$ and $\sigma$ are two coefficients that we adjust to approximate $f_{\mathrm{relu}^2}$, yielding $\mu = \sqrt{1/2}$ and $\sigma = \sqrt{1/(4\pi)}$. The derivations and visualization of the Laplace function are provided in Appendix C.
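The behavior of the Laplace attention function can be checked numerically with the standard library. The sketch below is illustrative only: it assumes the coefficient values $\mu = \sqrt{1/2}$ and $\sigma = \sqrt{1/(4\pi)}$, and it assumes that the two functions are matched in value and slope at the point $x = \sqrt{1/2}$, which is the point at which these coefficients make the two conditions hold. The helper names are hypothetical.

```python
import math

MU = math.sqrt(0.5)                       # assumed coefficient mu
SIGMA = math.sqrt(1.0 / (4.0 * math.pi))  # assumed coefficient sigma

def laplace(x, mu=MU, sigma=SIGMA):
    """Laplace attention function: 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def relu2(x):
    """Squared ReLU: max(0, x)^2."""
    return max(0.0, x) ** 2

def slope(f, x, eps=1e-6):
    """Central finite-difference estimate of the derivative of f at x."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

x0 = math.sqrt(0.5)                                   # assumed matching point
value_gap = abs(laplace(x0) - relu2(x0))              # close to zero
slope_gap = abs(slope(laplace, x0) - slope(relu2, x0))  # close to zero

# Unlike squared ReLU, the Laplace function stays inside [0, 1] for large inputs.
bounded = laplace(100.0)
unbounded = relu2(100.0)
```

The two functions agree near the matching point, while for large scores the Laplace function saturates at 1 and squared ReLU keeps growing, which is the boundedness property the paragraph appeals to.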
The computation of the reset gate $\gamma$, the update gate $\varphi$, and the candidate activation output $\hat{H}$ in (12-14) is reminiscent of GRUs (Cho et al., 2014). The main difference is that in a GRU the two gates are applied between the hidden states of the current and previous timesteps, while in Mega they are applied between the outputs from EMA and gated attention sub-layers. In addition, the output gating mechanism in (15) is similar to the gated residual connection proposed in Parisotto et al. (2020); Xu et al. (2020) to reduce the variance of output $Y$.
The computation of the shared representation $Z$, together with the sequences of queries, keys and values in (7-10), is inspired by GAU in Flash (Hua et al., 2022). Mega integrates EMA into GAU by computing $Z$ in (7) from the EMA output $X'$ rather than the original input $X$, and combining the GAU output with $X'$ for the candidate activation $\hat{H}$ in (14). Experimental gains over Flash demonstrate the effectiveness of this design choice (§4.1).
The multi-dimensional damped EMA can be seen as a simplified variant of a state space model. From this perspective, Mega is also closely related to S4 (Gu et al., 2022a), a state space model with structured state matrices. S4 leverages the HiPPO framework (Gu et al., 2020) to initialize its low-rank structured state matrices, and the computation of the convolutional kernel in S4 requires complex fast Fourier transforms (FFTs). The EMA sub-layer in Mega applies diagonalization on the state matrix and restricts the diagonal elements to the range $(0, 1)$. Thus, the convolution kernel would be a Vandermonde product, which can be computed in an efficient and numerically stable way. Moreover, the parameter initialization in Mega does not rely on the HiPPO framework.
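Single-head gated attention has been empirically shown to be as performant as vanilla multi-head attention (Liu et al., 2021; Hua et al., 2022), but without any discussion of its theoretical properties. In this section, we provide theoretical justifications of the expressiveness of single-head gated attention. To facilitate subsequent analysis, we simplify the notations of the multi-head attention. Specifically, we denote the sequences of queries, keys and values as the outputs of three transformations of the input sequence:
$$ Q = \mathcal{Q}(X), \qquad K = \mathcal{K}(X), \qquad V = \mathcal{V}(X) \tag{17} $$
where $\mathcal{Q}$, $\mathcal{K}$, $\mathcal{V}$ are three transformations, such as linear projections. Let $q \in Q = \{q_1, \ldots, q_n\}$ be a single query vector ($q \in \mathbb{R}^{d}$), and $a = \mathcal{A}(q, K)$ denote the corresponding attention weights of $q$, where $\mathcal{A}$ is the attention transformation, i.e. $f(\cdot)$ in (11).
For multi-head attention, a common implementation is to split the query into $h$ heads across the model dimension:
$$ q = \begin{bmatrix} q^{(1)} \\ \vdots \\ q^{(h)} \end{bmatrix} \tag{18} $$
where $q^{(i)} \in \mathbb{R}^{d/h}$, $i \in \{1, \ldots, h\}$, is the query of the $i$-th head. $K$ and $V$ are split in the same way. The attention weight of the $i$-th head is $a^{(i)} = \mathcal{A}(q^{(i)}, K^{(i)})$. Then, the outputs of single-head and multi-head attention are, respectively:
$$ O_{\mathrm{SHA}} = a^{T} V = \begin{bmatrix} a^{T} V^{(1)} \\ \vdots \\ a^{T} V^{(h)} \end{bmatrix}, \qquad O_{\mathrm{MHA}} = \begin{bmatrix} {a^{(1)}}^{T} V^{(1)} \\ \vdots \\ {a^{(h)}}^{T} V^{(h)} \end{bmatrix} \tag{19} $$
It is straightforward to see that $O_{\mathrm{MHA}}$ is more expressive than $O_{\mathrm{SHA}}$, because $O_{\mathrm{MHA}}$ leverages $h$ sets of attention weights.
In the single-head gated attention, we introduce a gate vector $\gamma = \mathcal{G}(X)$ for each $q$, and the output of single-head gated attention is $O_{\mathrm{SHGA}} = O_{\mathrm{SHA}} \odot \gamma$. The following theorem reveals the equivalence of $O_{\mathrm{SHGA}}$ and $O_{\mathrm{MHA}}$ w.r.t. expressiveness (proof in Appendix B).
Theorem 3.3. Suppose the transformation $\mathcal{G}$ is a universal approximator. Then, for each $X$ there exists $\gamma = \mathcal{G}(X)$ such that
$$ O_{\mathrm{SHGA}} = O_{\mathrm{MHA}} \tag{20} $$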
Theorem 3.3 indicates that by simply introducing the gate vector, $O_{\mathrm{SHGA}}$ is as expressive as $O_{\mathrm{MHA}}$. In practice, the transformation $\mathcal{G}$ is commonly modeled by a (shallow) neural network, whose universality of approximation has been extensively studied (Hornik et al., 1989; Yarotsky, 2017; Park et al., 2020).
The Mega layer (moving average equipped gated attention) is used as a drop-in replacement for regular attention in Transformer. It is followed by position-wise feed-forward networks (FFNs) and normalization layers to compose one Mega block. As the gated residual connection has already been included in (15), we omit the original residual connection and directly apply a normalization layer to $Y$. Concretely,
$$ Y = \mathrm{Norm}(\mathrm{Mega}(X)), \qquad Y' = \mathrm{Norm}(\mathrm{FFN}(Y) + Y) \tag{21} $$
where $Y'$ is the output of the Mega block. The overall architecture of a Mega block is shown in Figure 2 (a). In Transformer, the hidden dimension of FFNs is usually set to $d_{\mathrm{FFN}} = 4d$. To retain a similar model size with each Transformer block, we reduce the hidden dimension of FFN to $d_{\mathrm{FFN}} = 2d$ and set the expanded dimension $v = 2d$ for the value sequence in (10) throughout this paper, unless specified otherwise.
So far we have only focused on introducing stronger inductive bias into the attention mechanism, which still has quadratic computational complexity. In this section, we propose Mega-chunk, a variant of Mega with linear complexity, which simply applies attention to each local chunk of fixed length.
Specifically, we first split the sequences of queries, keys and values in (8-10) into chunks of length $c$, e.g. $Q = \{Q_1, \ldots, Q_k\}$, where $k = n/c$ is the number of chunks (keys and values are split in the same way). The attention operation in (11) is individually applied to each chunk, yielding linear complexity $O(kc^2) = O(nc)$ w.r.t. $n$. However, this method suffers from the critical limitation of losing contextual information from other chunks. Fortunately, the EMA sub-layer in Mega mitigates this problem by capturing local contextual information near each token, whose outputs are used as the inputs to the attention sub-layer. As a result, the effective context being exploited by chunk-wise attention can go beyond the chunk boundary. Figure 3 illustrates the largest possible dependency length captured by one Mega-chunk block.
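The cost argument for chunk-wise attention can be made concrete with a small counting sketch. It is illustrative only: the helper names are hypothetical, the sequence length is assumed to be an exact multiple of the chunk size (padding would be needed otherwise), and the sequence length and chunk size used below are example values, not settings reported in the paper.

```python
def split_chunks(seq, c):
    """Split a sequence into consecutive chunks of fixed length c."""
    if len(seq) % c != 0:
        raise ValueError("sequence length must be a multiple of the chunk size")
    return [seq[i:i + c] for i in range(0, len(seq), c)]

def attention_pairs(n, c=None):
    """Number of query-key pairs scored: n^2 for full attention, k * c^2 when chunked."""
    if c is None:
        return n * n
    k = n // c               # number of chunks
    return k * c * c         # equals n * c, linear in n for a fixed c

chunks = split_chunks(list(range(6)), 2)   # three chunks of two positions each
full_cost = attention_pairs(4096)          # quadratic in the sequence length
chunk_cost = attention_pairs(4096, 128)    # linear in the sequence length
```

Doubling the sequence length doubles `chunk_cost` but quadruples `full_cost`. Within a chunk, attention only sees positions of that chunk, which is why the EMA output fed into attention is what carries information across chunk boundaries.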
To evaluate Mega, we conduct experiments on five benchmark sequence modeling tasks across various data types, comparing with current state-of-the-art models on each task. Rows labeled "replicated" report results from baseline models replicated by us. More detailed descriptions, results and analysis are provided in Appendix D.
Models | ListOps | Text | Retrieval | Image | Pathfinder | Path-X | Avg. | Speed | Mem. |
---|---|---|---|---|---|---|---|---|---|
XFM | 36.37 | 64.27 | 57.46 | 42.44 | 71.40 | ✗ | 54.39 | – | – |
XFM (replicated) | 37.11 | 65.21 | 79.14 | 42.94 | 71.83 | ✗ | 59.24 | 1 | 1 |
Reformer | 37.27 | 56.10 | 53.40 | 38.07 | 68.50 | ✗ | 50.67 | 0.8 | 0.24 |
Linformer | 35.70 | 53.94 | 52.27 | 38.56 | 76.34 | ✗ | 51.36 | 5.5 | 0.10 |
BigBird | 36.05 | 64.02 | 59.29 | 40.83 | 74.87 | ✗ | 55.01 | 1.1 | 0.30 |
Performer | 18.01 | 65.40 | 53.82 | 42.77 | 77.05 | ✗ | 51.41 | 5.7 | 0.11 |
Luna- | 37.98 | 65.78 | 79.56 | 47.86 | 78.55 | ✗ | 61.95 | 4.9 | 0.16 |
S4-v1 | 58.35 | 76.02 | 87.09 | 87.26 | 86.05 | 88.10 | 80.48 | – | – |
S4-v2 | 59.60 | 86.82 | 90.90 | 88.65 | 94.20 | 96.35 | 86.09 | – | – |
S4-v2 (replicated) | 59.10 | 86.53 | 90.94 | 88.48 | 94.01 | 96.07 | 85.86 | 4.8 | 0.14 |
Mega | 63.14 | 90.43 | 91.25 | 90.44 | 96.01 | 97.98 | 88.21 | 2.9 | 0.31 |
Mega-chunk | 58.76 | 90.19 | 90.97 | 85.80 | 94.41 | 93.81 | 85.66 | 5.5 | 0.13 |
Table 2: Results on the Long Range Arena (LRA) benchmark. Speed and Mem. are relative to the replicated Transformer (XFM).
We begin our experiments with an evaluation on the Long Range Arena (LRA) benchmark recently introduced by Tay et al. (2021), which is designed for the purpose of evaluating sequence models under the long-context scenario. They collect six tasks in this benchmark which are ListOps (Nangia and Bowman, 2018), byte-level text classification (Text; Maas et al. (2011)), byte-level document retrieval (Retrieval; Radev et al. (2013)), image classification on sequences of pixels (Image; Krizhevsky and others (2009)), Pathfinder (Linsley et al., 2018) and its extreme long version (Path-X; Tay et al. (2021)). These tasks consist of input sequences ranging from 1K to 16K tokens and span across a variety of data types and modalities.
Table 2 compares Mega against several baselines, including Transformer and its efficient variants, and the state-of-the-art S4 models (both version 1 (Gu et al., 2022a) and version 2 (Gu et al., 2022b)). S4-v2 used larger model sizes and better-tuned hyper-parameters than S4-v1; note that our Mega has a similar model size to S4-v1 on each task. We also experimented with SRU++ (Lei, 2021) on Pathfinder, but it failed to converge on this dataset after tuning hyperparameters.
To ensure a fair comparison, we adjust the number of layers and model dimensions on each task so that Mega has a similar number of parameters to S4-v1. For each experiment, we report the average over 5 runs with different random seeds. The tuning information and the model details are provided in Appendix D.1.
On all six tasks, Mega substantially outperforms all the baselines. We also evaluate Mega-chunk on each task, using the same chunk size for all the tasks except Path-X, which uses a different one. We observe that Mega-chunk consistently performs well, particularly on the three language tasks. We also examine the speed and memory efficiency of Mega on the byte-level classification task with an input length of 4K. Mega-chunk is highly efficient: it is about 5.5 times faster and consumes only 13% as much memory as the vanilla Transformer (Table 2). It is interesting to see that Mega with a full attention field is also much more efficient than Transformer, benefiting from single-head gated attention.
To demonstrate the effectiveness of the multi-dimensional damped EMA component in Mega, we perform ablation studies on two LRA tasks: byte-level text classification (Text) and image classification on sequences of pixels (Image). We train Mega models with varying EMA dimension $h$, where $h = 0$ indicates removing the EMA component. From the left figure in Figure 4, we see that without the EMA component, model accuracy on both tasks declines rapidly. Meanwhile, with a single-dimensional EMA ($h = 1$), Mega obtains significant improvements, demonstrating the importance of incorporating inductive bias via EMA.
We further analyze the impact of the chunk size $c$ on the same two tasks, comparing several chunk sizes against the original Mega without chunking. The right figure in Figure 4 shows that image data is more sensitive to chunk size than text data. On the Text task, Mega-chunk with even a small chunk size is able to achieve around 90% accuracy. On the Image task, Mega-chunk with a small chunk size achieves around 75% accuracy, which is still much better than the vanilla Transformer model.
Finally, we evaluate performance with different attention functions. Table 3 shows the accuracy of the three attention functions on the same two tasks. On text data softmax obtains the best accuracy, while on image data it performs the worst. The Laplace function achieves the best accuracy on image data and a competitive result on text data, being consistently better than squared ReLU. In the following experiments we use softmax for language tasks and Laplace for vision and speech ones.
To evaluate the capability of Mega on the long-range modeling of speech signals, we apply Mega to classify raw speech (with length 16000), rather than using traditional preprocessing (e.g. converting to MFCC features). Following Gu et al. (2022a), we perform speech classification on the SC10 subset of the Speech Commands dataset (Warden, 2018). We experiment with the Mega-chunk variant, since the computation of Mega and Transformer cannot fit in GPU memory. As shown in Table 4, our Mega-chunk (base) model with 300K parameters achieves an accuracy of 96.92, slightly worse than the 97.50 of the state-of-the-art method S4, while by adding 0.18M parameters our Mega-chunk (big) model performs comparably well with S4. Our S4 number is obtained by directly running the official S4 code and is a bit worse than the originally reported number (98.32) because of different data splits: the file reading order is not deterministic across machines with os.listdir.
We evaluate Mega on two established language modeling benchmarks, WikiText-103 (Merity et al., 2017) and enwik8 (Hutter, 2006), which are next-token prediction tasks. WikiText-103 is a word-level language modeling dataset containing 103M training tokens from Wikipedia articles. Following previous work (Baevski and Auli, 2018; Dai et al., 2019), we adopt adaptive softmax and input embeddings and use a vocabulary of 260K tokens. Enwik8 is a character-level language modeling benchmark that has 100M tokens of unprocessed Wikipedia articles and a vocabulary size of about 200. At test time, we split the test data into segments and process each segment sequentially. In Table 5, we compare with previous top-performing models that are designed to take advantage of longer context, including Transformers (Baevski and Auli, 2018; Al-Rfou et al., 2019) (XFM-adaptive), Transformer-XL (Dai et al., 2019) (XFM-XL) and S4 (Gu et al., 2022a). On both WikiText-103 and enwik8, we obtain very competitive results, outperforming baselines by a large margin while enjoying much faster (9×) inference speed compared to the Transformer model. Mega can also naturally achieve length extrapolation at inference time to any sequences that are longer than those seen during training, due to the recurrent design of the EMA layer. In addition, we can extrapolate to a longer chunk size for Mega attention due to the use of rotary positional embeddings for training (Su et al., 2021). We describe these in detail and provide complete results for various test-time chunk sizes and segment lengths in Appendix D.3.
To evaluate Mega on sequence-to-sequence modeling, we conduct experiments on a standard machine translation benchmark, WMT 2014 (Bojar et al., 2014) English-German news translation (WMT’14), consisting of 4.5M sentence pairs of training data. The Mega models closely follow the architecture of Transformer-base: 6 encoder and decoder layers with model dimension $d = 512$.
Table 6 presents the BLEU scores on the test sets of WMT’14 in two directions: EN→DE and DE→EN. For each experiment, we report the average of both tokenized and SacreBLEU (Post, 2018) scores with 5 different random seeds (SacreBLEU signature: nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:1.5.1). Mega-base significantly outperforms Transformer-base by over 1 BLEU. We also report results of Mega with the Laplace attention function, which slightly but consistently underperforms Softmax.
To evaluate Mega on a large-scale image classification task, we conduct experiments on the ImageNet-1k (Deng et al., 2009) dataset, which consists of 1.28M training images and 50K validation images from 1000 classes. Top-1 accuracy on the validation set is reported in Table 7 to assess various models. Mega obtains about 0.5% accuracy improvement over DeiT-B (Touvron et al., 2021). We mostly follow DeiT’s approach of applying several data augmentation and regularization methods that facilitate the training process, including Cutmix (Yun et al., 2019), Mixup (Zhang et al., 2017), stochastic depth (Huang et al., 2016), repeated augmentation (Hoffer et al., 2020), Rand-Augment (Cubuk et al., 2020), and random erasing (Zhong et al., 2020). These methods were highly tuned towards optimizing the performance of DeiT, and might be sub-optimal for Mega. Exploring the optimal data augmentation and regularization methods for Mega is an interesting direction for future work. More training details are presented in Appendix D.5.
A number of techniques have been recently introduced to address the two issues of Transformer models; we only mention a few here due to space limits.
To incorporate stronger inductive bias into the attention mechanism, one research direction focuses on injecting position information via advanced positional encoding methods, including absolute and relative positional embeddings (Vaswani et al., 2017; Huang et al., 2020; Ke et al., 2020), and relative positional biases (Su et al., 2021; Press et al., 2021). Another line of research combines the attention mechanism with other neural architectures with intrinsic strong inductive bias, such as convolution (Gehring et al., 2017; Dai et al., 2021) and recurrence (Dai et al., 2019; Rae et al., 2020; Lei, 2021).
Many advanced variants of Transformer models (‘xformers’) (Tay et al., 2020, 2021) have recently emerged to improve the time and memory efficiency. Popular techniques include sparse attention patterns (Parmar et al., 2018; Beltagy et al., 2020; Kitaev et al., 2020), low-rank approximations of the attention matrix (Wang et al., 2020; Ma et al., 2021), and approximations through kernelization (Choromanski et al., 2020; Peng et al., 2021). Although these models demonstrate better asymptotic complexity for long sequences, their efficiency gains are less prominent for moderate length sequences and their performance remains behind that of Transformers with regular attention.
We have introduced Mega, a simple, efficient and effective neural architecture used as a drop-in replacement for regular multi-head attention. By leveraging the classic exponential moving average (EMA) approach, Mega is capable of incorporating stronger inductive biases into the attention mechanism. Moreover, the EMA approach enables the design of Mega-chunk, an efficient variant of Mega with linear complexity. On five sequence modeling tasks across various data types, Mega achieves impressive improvements over a variety of strong baselines, including previous state-of-the-art systems. These improvements lead to a potential direction of future work to apply Mega for multi-modality modeling.
Al-Rfou et al. (2019). In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 33, pp. 3159–3166.
Cubuk et al. (2020). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 702–703.
Dai et al. (2019). Transformer-XL: attentive language models beyond a fixed-length context. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 2978–2988.
Elfwing et al. (2018). Sigmoid-weighted linear units for neural network function approximation in reinforcement learning. Neural Networks 107, pp. 3–11.
Gehring et al. (2017). In International Conference on Machine Learning (ICML-2017), pp. 1243–1252.
Goller and Kuchler (1996). Learning task-dependent distributed representations by backpropagation through structure. In IEEE International Conference on Neural Networks, 1996, Vol. 1, pp. 347–352.
Kim (2014). In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP-2014), Doha, Qatar, pp. 1746–1751.
Maas et al. (2011). Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pp. 142–150.
Raffel et al. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. J. Mach. Learn. Res. 21 (140), pp. 1–67.
Note that the computations of the multi-dimensional damped EMAs of different dimensions are entirely independent of each other. Without loss of generality, we set $d = 1$ and omit the dimension index $j$ in the following formulations. We denote the initial hidden state as $h_0$. The multi-dimensional damped EMA defined in (5) can be vectorized into the following formulation:
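$$\boldsymbol{h}_t = \boldsymbol{\alpha} \odot \boldsymbol{u}_t + (1 - \boldsymbol{\alpha} \odot \boldsymbol{\delta}) \odot \boldsymbol{h}_{t-1} \tag{22}$$
$$y_t = \boldsymbol{\eta}^T \boldsymbol{h}_t \tag{23}$$
where $\boldsymbol{\alpha} \in (0,1)^h$, $\boldsymbol{\delta} \in (0,1)^h$, and $\boldsymbol{\eta} \in \mathbb{R}^h$. $\boldsymbol{u}_t = \boldsymbol{\beta} x_t \in \mathbb{R}^h$ and $\boldsymbol{h}_t \in \mathbb{R}^h$ is the EMA hidden state at timestep $t$.
Let's denote $\boldsymbol{\phi} = 1 - \boldsymbol{\alpha} \odot \boldsymbol{\delta}$. Then, unrolling the above two equations explicitly yields:
$$\boldsymbol{h}_1 = \boldsymbol{\phi} \odot \boldsymbol{h}_0 + \boldsymbol{\alpha} \odot \boldsymbol{\beta}\, x_1$$
$$\boldsymbol{h}_2 = \boldsymbol{\phi}^2 \odot \boldsymbol{h}_0 + \boldsymbol{\phi} \odot \boldsymbol{\alpha} \odot \boldsymbol{\beta}\, x_1 + \boldsymbol{\alpha} \odot \boldsymbol{\beta}\, x_2$$
$$\boldsymbol{h}_t = \boldsymbol{\phi}^t \odot \boldsymbol{h}_0 + \sum_{i=1}^{t} \boldsymbol{\phi}^{t-i} \odot \boldsymbol{\alpha} \odot \boldsymbol{\beta}\, x_i$$
This can be written into a vectorized formula:
$$y_t = \sum_{i=1}^{t} \mathcal{K}_{t-i}\, x_i + \boldsymbol{\eta}^T (\boldsymbol{\phi}^t \odot \boldsymbol{h}_0) \tag{24}$$
$$\boldsymbol{y}_{1:n} = \mathcal{K} * \boldsymbol{x}_{1:n} + \left( \boldsymbol{\eta}^T (\boldsymbol{\phi}^t \odot \boldsymbol{h}_0) \right)_{t=1}^{n} \tag{25}$$
where $\mathcal{K}$ is the convolution transform with kernel $\mathcal{K} \in \mathbb{R}^n$:
$$\mathcal{K} = \left( \boldsymbol{\eta}^T (\boldsymbol{\alpha} \odot \boldsymbol{\beta}),\ \boldsymbol{\eta}^T (\boldsymbol{\phi} \odot \boldsymbol{\alpha} \odot \boldsymbol{\beta}),\ \ldots,\ \boldsymbol{\eta}^T (\boldsymbol{\phi}^{n-1} \odot \boldsymbol{\alpha} \odot \boldsymbol{\beta}) \right) \tag{26}$$
In the proposed multi-dimensional damped EMA, $\mathcal{K}$ can be efficiently computed by the Vandermonde product. With $\mathcal{K}$ provided, the output $\boldsymbol{y}_{1:n}$ in (25) can be computed efficiently with FFTs.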
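The equivalence between the recurrent and the convolutional view can be checked numerically. The following is a minimal sketch in plain Python, assuming the damped EMA recurrence $\boldsymbol{h}_t = \boldsymbol{\alpha} \odot \boldsymbol{\beta} x_t + (1 - \boldsymbol{\alpha} \odot \boldsymbol{\delta}) \odot \boldsymbol{h}_{t-1}$ with output $y_t = \boldsymbol{\eta}^T \boldsymbol{h}_t$ for a single feature dimension. The function names and the random test values are illustrative and not taken from the Mega code base, and the convolution is evaluated directly in quadratic time rather than with FFTs to keep the example dependency-free.

```python
import random


def ema_recurrent(x, alpha, delta, beta, eta, h0):
    """Step-by-step damped EMA for one feature dimension (assumed recurrence)."""
    h = list(h0)
    ys = []
    for x_t in x:
        # h_t = alpha * (beta * x_t) + (1 - alpha * delta) * h_{t-1}, element-wise
        h = [a * b * x_t + (1.0 - a * d) * h_prev
             for a, d, b, h_prev in zip(alpha, delta, beta, h)]
        # y_t = eta^T h_t
        ys.append(sum(e * h_k for e, h_k in zip(eta, h)))
    return ys


def ema_kernel(n, alpha, delta, beta, eta):
    """Kernel entries K_k = eta^T (phi^k * alpha * beta) for k = 0..n-1."""
    phi = [1.0 - a * d for a, d in zip(alpha, delta)]
    return [sum(e * p ** k * a * b for e, p, a, b in zip(eta, phi, alpha, beta))
            for k in range(n)]


def ema_convolution(x, alpha, delta, beta, eta, h0):
    """Same output via a causal convolution plus the initial-state term."""
    n = len(x)
    phi = [1.0 - a * d for a, d in zip(alpha, delta)]
    kernel = ema_kernel(n, alpha, delta, beta, eta)
    ys = []
    for t in range(1, n + 1):
        conv = sum(kernel[t - i] * x[i - 1] for i in range(1, t + 1))
        init = sum(e * p ** t * h for e, p, h in zip(eta, phi, h0))
        ys.append(conv + init)
    return ys


# Illustrative random parameters (hypothetical values, EMA dimension 4).
rng = random.Random(0)
dim = 4
alpha = [rng.uniform(0.1, 0.9) for _ in range(dim)]
delta = [rng.uniform(0.1, 0.9) for _ in range(dim)]
beta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
eta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
h0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
x = [rng.gauss(0.0, 1.0) for _ in range(32)]

y_rec = ema_recurrent(x, alpha, delta, beta, eta, h0)
y_conv = ema_convolution(x, alpha, delta, beta, eta, h0)
max_abs_diff = max(abs(a - b) for a, b in zip(y_rec, y_conv))
```

In a practical implementation the inner sum over `kernel` would be replaced by an FFT-based convolution, which is what reduces the cost from quadratic to $O(n \log n)$.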
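We split $\boldsymbol{\gamma}$ into $h$ heads in the same way as $\boldsymbol{Q}$, $\boldsymbol{K}$, and $\boldsymbol{V}$:
$$\boldsymbol{\gamma} = \left[ \boldsymbol{\gamma}^{(1)}, \ldots, \boldsymbol{\gamma}^{(h)} \right]$$
Then we have
$$\boldsymbol{a}^T \boldsymbol{V} \odot \boldsymbol{\gamma} = \left[ \boldsymbol{a}^T \boldsymbol{V}^{(1)} \odot \boldsymbol{\gamma}^{(1)}, \ldots, \boldsymbol{a}^T \boldsymbol{V}^{(h)} \odot \boldsymbol{\gamma}^{(h)} \right]$$
To prove Theorem 3.3, we need to find $\boldsymbol{\gamma}$ such that
$$\boldsymbol{a}^T \boldsymbol{V}^{(i)} \odot \boldsymbol{\gamma}^{(i)} = {\boldsymbol{a}^{(i)}}^T \boldsymbol{V}^{(i)} \iff \boldsymbol{\gamma}^{(i)} = {\boldsymbol{a}^{(i)}}^T \boldsymbol{V}^{(i)} \oslash \boldsymbol{a}^T \boldsymbol{V}^{(i)}, \quad \forall i \in \{1, \ldots, h\}$$
where $\oslash$ is the element-wise divide operation. Since $\mathcal{G}(\boldsymbol{X})$ is a universal approximator and $\boldsymbol{Q}$, $\boldsymbol{K}$, $\boldsymbol{V}$ and $\boldsymbol{a}$ are all transformed from $\boldsymbol{X}$, $\boldsymbol{\gamma}$ can theoretically recover ${\boldsymbol{a}^{(i)}}^T \boldsymbol{V}^{(i)} \oslash \boldsymbol{a}^T \boldsymbol{V}^{(i)}$ for all $\boldsymbol{X}$.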
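To approximate the squared ReLU function with the Laplace function in (16), we need to select proper coefficients $\mu$ and $\sigma$. We derive the values of $\mu$ and $\sigma$ by solving the following two equations at $x = \sqrt{2}/2$:
$$f_{\mathrm{relu2}}(\sqrt{2}/2) = f_{\mathrm{laplace}}(\sqrt{2}/2) \tag{27}$$
$$f'_{\mathrm{relu2}}(\sqrt{2}/2) = f'_{\mathrm{laplace}}(\sqrt{2}/2) \tag{28}$$
Eq. (27) delivers $\mu = \sqrt{1/2}$ and Eq. (28) subsequently provides $\sigma = \sqrt{1/(4\pi)}$. Figure 5 visualizes the two functions.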
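The two matching conditions can be verified numerically. The sketch below assumes the Laplace function has the form $f_{\mathrm{laplace}}(x; \mu, \sigma) = 0.5 \left[ 1 + \mathrm{erf}\left( \frac{x - \mu}{\sigma \sqrt{2}} \right) \right]$, that squared ReLU is $\max(0, x)^2$, and that the matching point is $x = \sqrt{2}/2$ with $\mu = \sqrt{1/2}$ and $\sigma = \sqrt{1/(4\pi)}$. All names are illustrative, and the derivative is approximated by central differences rather than computed in closed form.

```python
import math

# Assumed coefficients: mu = sqrt(1/2), sigma = sqrt(1 / (4 * pi)).
MU = math.sqrt(0.5)
SIGMA = math.sqrt(1.0 / (4.0 * math.pi))


def relu2(x):
    """Squared ReLU: max(0, x)^2."""
    return max(0.0, x) ** 2


def laplace(x, mu=MU, sigma=SIGMA):
    """Assumed Laplace attention function (a Gaussian CDF written with erf)."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def central_difference(f, x, eps=1e-6):
    """Numerical first derivative of f at x."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


x0 = math.sqrt(2.0) / 2.0  # matching point

value_gap = abs(relu2(x0) - laplace(x0))
slope_gap = abs(central_difference(relu2, x0) - central_difference(laplace, x0))
```

Both gaps come out at numerical-noise level, which indicates that the two functions agree in value and in slope at the matching point under these assumptions.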
Besides performance improvements, we also investigate the stability of the two attention functions. We conduct experiments on the LRA Pathfinder task with Mega models using each of the two functions. Figure 5 presents the accuracy on the validation set across training epochs. We observe that Laplace is much more stable than squared ReLU.

For all tasks, we closely follow Tay et al. (2020) for details such as data preprocessing and data splits. The hyper-parameters of Mega models on these tasks are listed in Table 8.
Task | Depth | $d_{\mathrm{model}}$ | $d_{\mathrm{FFN}}$ | $z$ | $v$ | $h$ | Attn-FN | Norm | Pre-norm | BSZ | LR | Dropout | WD | Epochs |
ListOps | 6 | 80 | 160 | 64 | 160 | 16 | softmax | LN | False | 64 | 0.001 | 0.1 | 0.01 | 60 |
Text | 4 | 128 | 256 | 64 | 256 | 16 | softmax | SN | False | 50 | 0.004 | 0.1 | 0.01 | 50 |
Retrieval | 6 | 128 | 256 | 64 | 256 | 16 | softmax | SN | False | 64 | 0.003 | 0.1 | 0.04 | 40 |
Image | 8 | 160 | 320 | 96 | 320 | 16 | laplace | BN | True | 50 | 0.01 | 0.0 | 0.02 | 200 |
Pathfinder | 6 | 128 | 256 | 64 | 256 | 16 | laplace | BN | True | 128 | 0.01 | 0.0 | 0.01 | 200 |
Path-X | 4 | 64 | 128 | 32 | 128 | 16 | laplace | BN | True | 128 | 0.01 | 0.0 | 0.01 | 100 |
SC-Raw (base) | 6 | 60 | 120 | 30 | 120 | 16 | laplace | BN | True | 20 | 0.01 | 0.0 | 0.01 | 200 |
SC-Raw (big) | 6 | 72 | 144 | 36 | 144 | 16 | laplace | BN | True | 20 | 0.008 | 0.0 | 0.01 | 200 |
Table 8: Hyper-parameters of Mega models on LRA and raw speech classification tasks. BSZ is batch size, LR is learning rate and WD is weight decay. BN, LN and SN refer to Batch Normalization, Layer Normalization and Scale Normalization.
We use the WikiText-103 and enwik8 data and the splits provided by Dai et al. (2019). At training time, we split the training data into segments; each segment contains a number of consecutive chunks, where the chunk size is the effective attention length. The number of chunks per segment is a random integer sampled uniformly from a fixed range, which is set separately for WikiText-103 and enwik8. Other training hyper-parameters, including the optimizer, learning rate scheduler and architecture, are presented in Table 9.
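The segment construction described above can be sketched as follows. This is a minimal illustration rather than the FairSeq data pipeline: the function name `split_into_segments` and the range bounds `min_chunks=2` and `max_chunks=6` are hypothetical placeholders, and the final segment is simply allowed to be shorter when the data runs out.

```python
import random


def split_into_segments(tokens, chunk_size, min_chunks, max_chunks, rng):
    """Cut a token stream into segments of a random whole number of chunks.

    min_chunks and max_chunks are illustrative bounds, not values from the paper.
    """
    segments = []
    start = 0
    while start < len(tokens):
        # Number of consecutive chunks in this segment, sampled uniformly.
        num_chunks = rng.randint(min_chunks, max_chunks)
        end = min(start + num_chunks * chunk_size, len(tokens))
        segments.append(tokens[start:end])
        start = end
    return segments


rng = random.Random(0)
tokens = list(range(10_000))  # stand-in for a tokenized training corpus
segments = split_into_segments(
    tokens, chunk_size=1024, min_chunks=2, max_chunks=6, rng=rng
)
```

Every segment except possibly the last covers a whole number of chunks, so attention is computed within each chunk while context is passed from one chunk to the next.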
We employ Mega-chunk (§3.5) for training and set the attention chunk size to 1024 and 2048 for WikiText-103 and enwik8 respectively. To use a longer Mega attention length at inference time than the one used at training time (i.e. 1024 or 2048), we apply rotary positional embedding (Su et al., 2021) to the attention sublayer. At test time, we split the test data into segments and sequentially process each segment chunk by chunk, so the maximum context length of each segment is the number of chunks times the chunk size. In Table 5, we report test results that use longer chunk sizes (attention lengths) of 2048 and 4096 for WikiText-103 and enwik8 respectively. Mega can naturally extrapolate at inference time to sequences longer than those seen during training due to the recurrent design of the EMA layer. That design enables the inputs of each chunk to access the historic context through EMA, as illustrated in Figure 3. In addition, due to the use of rotary positional embeddings, attention can be performed on longer chunk sizes at test time than those seen during training. These are two distinct types of length extrapolation, and we provide ablation studies on both below: extrapolation to longer context by increasing the input sequence length, and extrapolation to longer attention lengths by increasing the chunk size.
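To illustrate why the recurrent EMA lets each chunk see earlier context, the sketch below runs a scalar damped EMA once over a full sequence and once chunk by chunk, carrying the final hidden state of each chunk into the next. It assumes the recurrence $h_t = \alpha x_t + (1 - \alpha \delta) h_{t-1}$ with scalar $\alpha$ and $\delta$. The function names and parameter values are illustrative, and the attention sublayer is omitted entirely.

```python
def ema(x, alpha, delta, h0=0.0):
    """Scalar damped EMA (assumed recurrence); returns outputs and final state."""
    h = h0
    out = []
    for x_t in x:
        h = alpha * x_t + (1.0 - alpha * delta) * h
        out.append(h)
    return out, h


def ema_by_chunks(x, chunk_size, alpha, delta):
    """Process the sequence chunk by chunk, carrying the EMA state across chunks."""
    h = 0.0
    out = []
    for start in range(0, len(x), chunk_size):
        chunk_out, h = ema(x[start:start + chunk_size], alpha, delta, h0=h)
        out.extend(chunk_out)
    return out


# Illustrative input: 20 values processed with a chunk size of 6.
x = [float(i % 7) - 3.0 for i in range(20)]
full_out, _ = ema(x, alpha=0.4, delta=0.9)
chunked_out = ema_by_chunks(x, chunk_size=6, alpha=0.4, delta=0.9)
max_abs_diff = max(abs(a - b) for a, b in zip(full_out, chunked_out))
```

Because the carried state summarizes everything before the chunk boundary, the chunked outputs match the full-sequence outputs, regardless of how many chunks the sequence is split into.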
First, we fix the chunk size to 2048 and vary the number of chunks per segment, which changes the maximum context length in tokens. We plot the test PPL against the context length in the left panel of Figure 6. Although the maximum context length the model has seen at training time is 6144, Mega can extrapolate to longer context lengths. The plot shows that PPL decreases as the context length increases, and the improvements saturate once the context length exceeds 25K. This is consistent with the observations in Press et al. (2021).
Next, we fix the context length to 25K and increase the chunk size from 512 to 3072. As shown in the right panel of Figure 6, Mega consistently improves as we increase the attention length, although it only uses an attention length of 1024 during training. This contrasts with the findings of ALiBi (Press et al., 2021), which reports that rotary embeddings do not generalize to longer lengths and result in higher PPL.
Hyper-parameter | WikiText-103 | enwik8 |
Batch Size × GPUs | 6144 × 24 | 8192 × 8 |
Optimizer | AdamW | AdamW |
Learning Rate | 0.005 | 0.005 |
Adam-$\beta$ | | |
Learning Rate Decay | linear | linear |
Weight Decay | 0.1 | 0.1 |
Dropout | 0.3 | 0.1 |
Attention Dropout | 0.1 | 0.0 |
FFN Hidden Dropout | 0.1 | 0.0 |
Gradient Clipping | 1.0 | 1.0 |
Warmup steps | 24K | 24K |
Total updates | 400K | 400K |
Decoder Layers | 16 | 12 |
Model size | 1024 | 512 |
FFN Hidden size | 1536 | 1024 |
Shared Repr. size ($z$) | 256 | 128 |
Value Seq. size ($v$) | 1536 | 1024 |
EMA dimension ($h$) | 16 | 16 |
Chunk size | 1024 | 2048 |
Total Parameters | 252M | 39M |
The WMT 2014 English-German dataset contains 4.5M parallel sentence pairs for training. We follow the standard setting (Vaswani et al., 2017), using Newstest2013 as the validation set and Newstest2014 as the test set. The dataset is pre-processed following Ma (2020), using the scripts from the FairSeq package (Ott et al., 2019; https://github.com/pytorch/fairseq). We share the source and target vocabularies within the language pair, with 37K byte pair encoding (BPE) types (Sennrich et al., 2016). The hyper-parameters of the Transformer and Mega models are listed in Table 10.
Hyper-parameter | XFM-Base | Mega-Base |
Batch Size × GPUs | 8192 × 8 | 8192 × 8 |
Optimizer | AdamW | AdamW |
Learning Rate | 0.0005 | 0.001 |
Adam-$\beta$ | | |
Learning Rate Decay | inv. sqrt | linear |
Weight Decay | 0.05 | |
Dropout | 0.1 | 0.15 |
Attention Dropout | 0.1 | 0.1 |
FFN Hidden Dropout | 0.1 | 0.1 |
Gradient Clipping | 1.0 | 1.0 |
Label Smoothing | 0.1 | 0.1 |
Warmup steps | 4K | 4K |
Total updates | 500K | 500K |
Encoder Layers | 6 | 6 |
Decoder Layers | 6 | 6 |
Model dimension | 512 | 512 |
FFN Hidden dimension | 2048 | 1024 |
Shared Repr. dimension ($z$) | – | 128 |
Value Seq. dimension ($v$) | – | 1024 |
EMA dimension ($h$) | – | 16 |
Total Parameters | 65M | 67M |
Hyper-parameters are listed in Table 11. We closely follow Touvron et al. (2021) by reusing most of their hyper-parameters.
Hyper-parameter | DeiT-B | Mega |
Batch size | 1024 | 1024 |
Optimizer | AdamW | AdamW |
Learning rate | 0.001 | 0.001 |
Learning rate decay | cosine | cosine |
Weight decay | 0.05 | 0.05 |
Epochs | 300 | 400 |
Warmup epochs | 5 | 30 |
Label smoothing | 0.1 | 0.1 |
Dropout | ✗ | ✗ |
Stoch. Depth | 0.1 | 0.2 |
Repeated Aug | 3 | 4 |
Gradient Clip. | ✗ | 1.0 |
Rand Augment | 9/0.5 | 9/0.5 |
Mixup prob. | 0.8 | 0.8 |
Cutmix prob. | 1.0 | 1.0 |
Erasing prob. | 0.25 | 0.25 |
Num. Layers | 12 | 12 |
Model size | 768 | 768 |
FFN Hidden size | 3072 | 1536 |
Shared Repr. size ($z$) | – | 256 |
Value Seq. size ($v$) | – | 1536 |
Total Parameters | 86M | 90M |