fast-weight-transformers
Official code repository of the paper Linear Transformers Are Secretly Fast Weight Programmers.
We show the formal equivalence of linearised self-attention mechanisms and fast weight memories from the early '90s. From this observation we infer a memory capacity limitation of recent linearised softmax attention variants. With finite memory, a desirable behaviour of fast weight memory models is to manipulate the contents of memory and dynamically interact with it. Inspired by previous work on fast weights, we propose to replace the update rule with an alternative rule yielding such behaviour. We also propose a new kernel function to linearise attention, balancing simplicity and effectiveness. We conduct experiments on synthetic retrieval problems as well as standard machine translation and language modelling tasks which demonstrate the benefits of our methods.
Transformers (Vaswani et al., 2017) have achieved impressive results in a myriad of sequence processing tasks, including machine translation, language modelling (Al-Rfou et al., 2019; Dai et al., 2019; Baevski & Auli, 2019; Radford et al., 2019), and question answering (Devlin et al., 2019), domains previously dominated by recurrent neural networks (Graves, 2013; Bahdanau et al., 2015).
The core component of a Transformer is the self-attention mechanism (Cheng et al., 2016; Parikh et al., 2016; Lin et al., 2017) which was recently connected to the modern Hopfield network (Ramsauer et al., 2021; Krotov & Hopfield, 2016; Demircigil et al., 2017). It extends a form of attention (Bahdanau et al., 2015) originally introduced to complement recurrent neural networks, e.g., (Hochreiter & Schmidhuber, 1997). While relinquishing the recurrence property, all computations across the time axis can be parallelised. However, this comes with drawbacks: self-attention computations scale quadratically with sequence length while the memory of the model grows linearly. Therefore, practitioners are forced to limit the context window to a reasonable size, which in turn makes it impossible to capture longer-term dependencies and limits the size of the memory.
Recent work proposed “linear Transformers” with constant size memory and time complexity linear in sequence length (Katharopoulos et al., 2020; Choromanski et al., 2021; Shen et al., 2018). This complexity reduction is mainly due to a linearisation of the softmax (reviewed in Sec. 3.2).
Here we relate this new family of linear Transformers to outer product-based fast weight memories from the 90’s (Schmidhuber, 1991, 1992, 1993). This view allows us to derive a limitation of the memory capacity of such models. When the sequence length exceeds storage capacity, the model may end up in an overcapacity regime (discussed in depth in Sec. 4.1). To properly operate under such a regime, the model should learn to dynamically interact with the memory contents and to selectively decide which key-value associations to keep and which ones to delete. The regular update rule may be inappropriate for this purpose. Therefore, we introduce an improved update rule inspired by recent work on fast weight memories (Schlag et al., 2021).
Furthermore, softmax linearisation techniques for Transformers are still underexplored. The existing techniques are either very simplistic (Katharopoulos et al., 2020) or mathematically well explained but complex (Choromanski et al., 2021). We provide a comprehensive comparison and propose a new method which is both simple and effective.
We demonstrate the benefits of the proposed methods on our own synthetic retrieval dataset (Sec. 6.1), the standard WMT14 English to German machine translation task (Sec. 6.2), and the Wikitext-103 (Merity et al., 2017) language modelling task (Sec. 6.3). Source code used in this paper is available at github.com/ischlag/fast-weight-transformers.
In this section, we review the concepts of fast weights before relating them to linear Transformer variants in Sec. 3.
In standard neural networks, the weights remain fixed after training, unlike the activations, which change depending on the inputs at test time. The general idea of fast weights is to make the weights also variable and input-dependent. This concept was called synaptic modulation (von der Malsburg, 1981), a method for variable binding in neural networks (see e.g. the recent survey by Greff et al. (2020)), or dynamic connections (Feldman, 1982). Von der Malsburg defines the effective weights as a (multiplicative) superposition of conventional, context-independent slow weights, and fast changing, context-dependent fast weights. Hinton & Plaut (1987) studied a net with (additive) superposition of two sets of weights with two different learning rates in a scenario of model retraining.
Context-dependent fast weight generation was introduced in a two-network system of the early 90s (Schmidhuber, 1991, 1992, 1993). A slow net with slow weights continually generates fast weights for a fast net, making the fast weights effectively dependent on the context. Simply put, the slow net learns to program its fast net. Among the proposed weight generation mechanisms, a particularly attractive one makes use of outer products (Schmidhuber, 1991, 1992): for a sequential input $\{x^{(i)}\}_{i=1}^{L}$, $x^{(i)} \in \mathbb{R}^{d_{\text{in}}}$, the model outputs the sequence $\{y^{(i)}\}_{i=1}^{L}$, $y^{(i)} \in \mathbb{R}^{d_{\text{out}}}$ as
$$a^{(i)}, b^{(i)} = W_a x^{(i)}, W_b x^{(i)} \tag{1}$$
$$W^{(i)} = \sigma\big(W^{(i-1)} + a^{(i)} \otimes b^{(i)}\big) \tag{2}$$
$$y^{(i)} = W^{(i)} x^{(i)} \tag{3}$$
where $\otimes$ denotes the outer product, $\sigma$ is an activation function, $W_a$ and $W_b$ are trainable slow weights, while the fast weights $W^{(i)}$ are generated at each time step $i$ and serve as a short-term memory. This is a key-value associative memory model in which the write operation is based on a summation (Eq. 2) and the retrieval is a matrix-vector multiplication (Eq. 3). Schmidhuber (1993) describes a recurrent version and discusses “internal spotlights of attention”. The use of outer products results in a model of associations similar to tensor product representations (Smolensky, 1990). In fact, outer-product based associative memory can be found in numerous works since Hebb’s informal rule (Hebb, 1949) and its more concrete formal variants (Steinbuch, 1961; Steinbuch & Piske, 1963; Kohonen, 1972; Palm, 1980) including Hopfield networks (Hopfield, 1982; Little, 1974) and bi-directional associative nets (Kosko, 1988).
The concept of context-dependent or fast weights has been revisited recently (Ba et al., 2016; Schlag & Schmidhuber, 2017), also under different names, e.g., hypernetworks (Ha et al., 2017; Galanti & Wolf, 2020), dynamic plasticity (Miconi et al., 2018, 2019), dynamic convolution (Klein et al., 2015; Noh et al., 2016; Jia et al., 2016), or lambda networks (Bello, 2021) used for applications including meta-learning (Munkhdalai & Yu, 2017; Munkhdalai & Trischler, 2018; Munkhdalai et al., 2019; Kirsch & Schmidhuber, 2020). Fast weights recently also improved memory models through explicit mechanisms for facilitating the replacement of deprecated information and updating associations (Schlag & Schmidhuber, 2018; Schlag et al., 2021).
Ba et al. (2016) have already pointed out a relation between a variant of outer product-based fast weights and attention (Bahdanau et al., 2015). Katharopoulos et al. (2020) have analysed linearised transformers. We review these derivations with a focus on showing the relation between Transformers and the fast weight model of the previous section.
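A self-attention layer in auto-regressive Transformers (Vaswani et al., 2017) maps an input sequence $\{x^{(i)}\}_{i=1}^{L}$, $x^{(i)} \in \mathbb{R}^{d_{\text{in}}}$ to an output sequence $\{y^{(i)}\}_{i=1}^{L}$, $y^{(i)} \in \mathbb{R}^{d_{\text{value}}}$ as
$$k^{(i)}, v^{(i)}, q^{(i)} = W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \tag{4}$$
$$K^{(i)} = \big[K^{(i-1)}, k^{(i)}\big] \in \mathbb{R}^{d_{\text{key}} \times i} \tag{5}$$
$$V^{(i)} = \big[V^{(i-1)}, v^{(i)}\big] \in \mathbb{R}^{d_{\text{value}} \times i} \tag{6}$$
$$y^{(i)} = V^{(i)} \, \mathrm{softmax}\big((K^{(i)})^\top q^{(i)}\big) \tag{7}$$
where $[A, a]$ denotes the concatenation of vector $a$ to matrix $A$ along the time dimension, $\mathrm{softmax}$ is applied along the time dimension, and $W_k$, $W_v$, $W_q$ are trainable weight matrices. We omit the scaling by $1/\sqrt{d_{\text{key}}}$ inside the softmax without loss of generality.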
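The step-by-step view of causal self-attention can be sketched in a few lines of NumPy. This is only an illustration under assumptions: the sizes, the random weights, and the function names are hypothetical, and the scaling inside the softmax is omitted as in the text. It makes visible that the key and value memories grow by one column per time step.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())   # subtract the max for numerical stability
    return e / e.sum()

def causal_self_attention(X, W_k, W_v, W_q):
    """Process X (L x d_in) one step at a time: project, append, attend."""
    keys, values, outputs = [], [], []
    for x in X:
        k, v, q = W_k @ x, W_v @ x, W_q @ x      # key, value, query projections
        keys.append(k)                            # key memory grows by one column
        values.append(v)                          # value memory grows by one column
        K = np.stack(keys, axis=1)                # d_key x i
        V = np.stack(values, axis=1)              # d_value x i
        outputs.append(V @ softmax(K.T @ q))      # attention over all past steps
    return np.stack(outputs)

rng = np.random.default_rng(0)
L, d_in, d_key, d_value = 6, 5, 4, 3              # illustrative sizes only
X = rng.normal(size=(L, d_in))
W_k = rng.normal(size=(d_key, d_in))
W_v = rng.normal(size=(d_value, d_in))
W_q = rng.normal(size=(d_key, d_in))
Y = causal_self_attention(X, W_k, W_v, W_q)
```

At the first step the softmax runs over a single score, so the first output is simply the first value vector.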
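Now if we remove the softmax in Eq. 7 we obtain:
$$y^{(i)} = V^{(i)} \big((K^{(i)})^\top q^{(i)}\big) = \sum_{j=1}^{i} v^{(j)} \big(k^{(j)} \cdot q^{(i)}\big) = \Big(\sum_{j=1}^{i} v^{(j)} \otimes k^{(j)}\Big) q^{(i)} \tag{8}$$
Denoting by $W^{(i)}$ the corresponding weight matrix generated from key and value vectors:
$$W^{(i)} = \sum_{j=1}^{i} v^{(j)} \otimes k^{(j)} \tag{9}$$
we can rewrite Eqs. 4-7 such that they directly relate to Eqs. 1-3 where the activation function $\sigma$ is the identity function and without query projection $W_q$:
$$k^{(i)}, v^{(i)}, q^{(i)} = W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \tag{4}$$
$$W^{(i)} = W^{(i-1)} + v^{(i)} \otimes k^{(i)} \tag{10}$$
$$y^{(i)} = W^{(i)} q^{(i)} \tag{11}$$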
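The equivalence between attention without softmax and an additive outer-product fast weight memory can be checked numerically. The sketch below assumes the two forms described above (a sum over past key-query dot products versus a running sum of outer products); sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_key, d_value = 5, 4, 3            # illustrative sizes only
K = rng.normal(size=(L, d_key))        # row j is the key at step j
V = rng.normal(size=(L, d_value))      # row j is the value at step j
Q = rng.normal(size=(L, d_key))        # row i is the query at step i

def attention_without_softmax(K, V, Q):
    """y_i = sum over j <= i of v_j * (k_j . q_i)."""
    out = []
    for i in range(len(Q)):
        scores = K[: i + 1] @ Q[i]           # dot products with all past keys
        out.append(V[: i + 1].T @ scores)    # weighted sum of past values
    return np.stack(out)

def fast_weight_form(K, V, Q):
    """W_i = W_{i-1} + outer(v_i, k_i), then y_i = W_i q_i."""
    W = np.zeros((V.shape[1], K.shape[1]))
    out = []
    for k, v, q in zip(K, V, Q):
        W = W + np.outer(v, k)               # additive write
        out.append(W @ q)                    # read by matrix-vector product
    return np.stack(out)

y_attn = attention_without_softmax(K, V, Q)
y_fw = fast_weight_form(K, V, Q)
```

Both functions return the same outputs, but the second one keeps only a matrix of fixed size instead of all past keys and values.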
Instead of removing the softmax as in Sec. 3.1, prior works have introduced techniques for linearising the softmax (Tsai et al., 2019), which has been shown to improve computational efficiency of self-attention for long sequences (Katharopoulos et al., 2020; Choromanski et al., 2021).
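By writing the softmax explicitly, Eq. 7 can be written as:
$$y^{(i)} = \sum_{j=1}^{i} \frac{\kappa(k^{(j)}, q^{(i)})}{\sum_{j'=1}^{i} \kappa(k^{(j')}, q^{(i)})} \, v^{(j)} \tag{12}$$
where $\kappa(k, q) = \exp(k \cdot q) \in \mathbb{R}_{>0}$ is the softmax kernel and $k \cdot q = k^\top q$ is the vector dot product.
The general idea is to replace the softmax kernel $\kappa$ by another kernel: $\kappa'(k, q) = \phi(k)^\top \phi(q)$ where $\phi$ is a function $\mathbb{R}^{d_{\text{key}}} \rightarrow \mathbb{R}^{d_{\text{dot}}}$. We discuss the necessary properties of $\phi$ in Sec. 5.1. By replacing $\kappa$ in Eq. 12 by $\kappa'$, we obtain
$$y^{(i)} = \sum_{j=1}^{i} \frac{v^{(j)} \, \phi(k^{(j)})^\top \phi(q^{(i)})}{\sum_{j'=1}^{i} \phi(k^{(j')}) \cdot \phi(q^{(i)})} \tag{13}$$
$$\phantom{y^{(i)}} = \frac{\sum_{j=1}^{i} \big(v^{(j)} \phi(k^{(j)})^\top\big) \phi(q^{(i)})}{\big(\sum_{j'=1}^{i} \phi(k^{(j')})\big) \cdot \phi(q^{(i)})} \tag{14}$$
Using the outer-product notation, the numerator is analogous to the case without softmax (Sec. 3.1):
$$\sum_{j=1}^{i} \big(v^{(j)} \phi(k^{(j)})^\top\big) \phi(q^{(i)}) = \Big(\sum_{j=1}^{i} v^{(j)} \otimes \phi(k^{(j)})\Big) \phi(q^{(i)})$$
By introducing the fast weight matrix $W^{(i)}$ and an additional vector $z^{(i)}$ for the denominator,
$$W^{(i)} = \sum_{j=1}^{i} v^{(j)} \otimes \phi(k^{(j)}) \tag{15}$$
$$z^{(i)} = \sum_{j=1}^{i} \phi(k^{(j)}) \tag{16}$$
forward computations of linear Transformers can be written as (Katharopoulos et al., 2020; Choromanski et al., 2021):
$$k^{(i)}, v^{(i)}, q^{(i)} = W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \tag{4}$$
$$W^{(i)} = W^{(i-1)} + v^{(i)} \otimes \phi(k^{(i)}) \tag{17}$$
$$z^{(i)} = z^{(i-1)} + \phi(k^{(i)}) \tag{18}$$
$$y^{(i)} = \frac{1}{z^{(i)} \cdot \phi(q^{(i)})} \, W^{(i)} \phi(q^{(i)}) \tag{19}$$
which is a fast weight model (Sec. 2) with normalisation. Hence, there is a clear connection between these linear Transformer variants and outer-product fast weight systems.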
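A minimal sketch of the normalised linear-attention recurrence follows, compared against the explicit kernel form over all past positions. As an assumption for illustration, the feature map is the element-wise ELU + 1 function of Katharopoulos et al. (2020); the sizes, the random inputs, and the function names are hypothetical.

```python
import numpy as np

def phi(x):
    """Element-wise ELU(x) + 1: strictly positive, keeps the input dimension."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_recurrent(K, V, Q):
    """Constant-size state (W, z): additive write, normalised read."""
    d_dot = phi(K[0]).shape[0]
    W = np.zeros((V.shape[1], d_dot))
    z = np.zeros(d_dot)
    out = []
    for k, v, q in zip(K, V, Q):
        W = W + np.outer(v, phi(k))                 # fast weight update
        z = z + phi(k)                              # normaliser accumulator
        out.append(W @ phi(q) / (z @ phi(q)))       # normalised retrieval
    return np.stack(out)

def linear_attention_quadratic(K, V, Q):
    """Explicit kernel weights phi(k_j) . phi(q_i) over all past positions."""
    out = []
    for i in range(len(Q)):
        w = phi(K[: i + 1]) @ phi(Q[i])
        out.append(V[: i + 1].T @ (w / w.sum()))
    return np.stack(out)

rng = np.random.default_rng(0)
L, d_key, d_value = 7, 4, 3                          # illustrative sizes only
K = rng.normal(size=(L, d_key))
V = rng.normal(size=(L, d_value))
Q = rng.normal(size=(L, d_key))
y_rec = linear_attention_recurrent(K, V, Q)
y_quad = linear_attention_quadratic(K, V, Q)
```

The recurrent form touches each time step once and stores only `W` and `z`, which is where the linear time and constant memory come from.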
Viewing linear Transformer variants as fast weight systems provides us with two insights which we investigate in this work: their capacity limits as associative memories (Sec. 4.1), and their ineptness to edit previously stored associations (Sec. 4.2).
Endlessly adding new associations to a memory of finite size, as in Eq. 17, will inevitably reach a limit. In linear attention, information is stored in a matrix and is retrieved using matrix multiplication (see Eq. 19). As a consequence, to prevent associations from interfering with each other upon retrieval, the respective keys need to be orthogonal. Otherwise, the dot product will attend to more than one key and return a linear combination of values. With keys embedded in a $d_{\text{dot}}$-dimensional space, there cannot be more than $d_{\text{dot}}$ orthogonal vectors. That is, storing more than $d_{\text{dot}}$ associations will result in a retrieval error. In linear Transformers, when the length of the sequence is longer than $d_{\text{dot}}$, the model might be in such an overcapacity regime. While we experimentally demonstrate this effect on toy tasks (Sec. 6.1), prior work on tensor product representations allows for a more formal discussion.
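The capacity argument can be illustrated with a small numerical experiment. The sketch below is not the paper's synthetic task: it writes associations with the plain sum rule into a matrix and reads every key back, first with exactly as many orthonormal keys as dimensions, then with one key more. The dimension 8 and the random values are arbitrary choices for illustration.

```python
import numpy as np

def retrieval_error(keys, values):
    """Write all pairs with the sum rule, then read every key back."""
    W = values.T @ keys            # sum over j of outer(v_j, k_j)
    retrieved = keys @ W.T         # row j is W k_j
    return np.abs(retrieved - values).max()

rng = np.random.default_rng(0)
d_dot, d_value = 8, 8

# Exactly d_dot orthonormal keys: rows of an orthogonal matrix from a QR decomposition.
ortho_keys, _ = np.linalg.qr(rng.normal(size=(d_dot, d_dot)))
values = rng.normal(size=(d_dot, d_value))
err_within = retrieval_error(ortho_keys, values)

# One key more than d_dot: the unit-norm keys can no longer be mutually orthogonal.
extra_keys = rng.normal(size=(d_dot + 1, d_dot))
extra_keys /= np.linalg.norm(extra_keys, axis=1, keepdims=True)
extra_values = rng.normal(size=(d_dot + 1, d_value))
err_over = retrieval_error(extra_keys, extra_values)
```

In the first case retrieval is exact up to floating-point error; in the second case the non-zero dot products between keys mix several values into each read-out.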
Early work in connectionist research investigated the usage of distributed representations as a means for storing symbolic structures. One highly-influential work is the tensor-product-based variable binding mechanism (Smolensky, 1990). A tensor product representation (TPR) of a structured symbolic system consisting of a set of variables and values is constructed from outer products of the so-called role and filler vectors. These terms directly translate into keys and values in our context. The fast weight memories of Eq. 17 are the most basic form of such representations (second order tensors). Therefore, many results discussed in Smolensky’s work transfer to our model. In particular, Theorem 3.3 and 3.1 of Smolensky (1990) discuss more formally the crosstalk and retrieval error intuitively described in the previous paragraph.
However, we also note an important difference: with the exception of recent work (Schlag & Schmidhuber, 2018), the classic TPRs of Smolensky (1990) are constructed with a priori knowledge of the symbolic structure. In contrast, our models learn all the vectors involved in constructing such a representation.
Sec. 4.1 argues that linear Transformers can end up in an overcapacity regime if the sequence length exceeds the dimension of the keys. Once in overcapacity, an ideal memory model should dynamically interact with the memory contents and selectively determine which associations to remember or to forget. This is in stark contrast to the standard Transformer which stores immutable pairs of key and value vectors by concatenation, thus increasing the storage size. While such models work well in practice, we consider a model’s capability to update previously acquired knowledge to be critical for many problems. Hence, from the perspective of dynamic interaction with the memory, the sum update rule of Eq. 17 may be sub-optimal. This motivates us to improve the update rule.
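Inspired by recent work on fast weight memories (Schlag et al., 2021) we propose the following memory update rule. Given a new input key-value pair $(k^{(i)}, v^{(i)})$, the model first accesses the current state of the memory $W^{(i-1)}$ and retrieves the value $\bar{v}^{(i)}$ currently paired with the key $k^{(i)}$. Then the model stores a convex combination $v^{(i)}_{\text{new}}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$ using an interpolation weight $0 \le \beta^{(i)} \le 1$ also generated by the model. The model thus sequentially transforms an input sequence $\{x^{(i)}\}_{i=1}^{L}$ into an output sequence $\{y^{(i)}\}_{i=1}^{L}$ as:
$$k^{(i)}, v^{(i)}, q^{(i)} = W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \tag{4}$$
$$\bar{v}^{(i)} = W^{(i-1)} \phi(k^{(i)}) \tag{20}$$
$$\beta^{(i)} = \sigma(W_\beta x^{(i)}) \tag{21}$$
$$v^{(i)}_{\text{new}} = \beta^{(i)} v^{(i)} + (1 - \beta^{(i)}) \, \bar{v}^{(i)} \tag{22}$$
where $W_\beta \in \mathbb{R}^{1 \times d_{\text{in}}}$, and $\sigma$ is the sigmoid function. The interpolation weight $\beta^{(i)}$ is the “write-strength” as it defines to which extent the new value will replace the previous value. We note that while $\beta^{(i)}$ only depends on $x^{(i)}$, in a multi-layer model, $x^{(i)}$ has the full context information except in the first layer. We set $W^{(0)} = 0$ and $z^{(0)} = 0$. Then the fast weight update rule and the final output $y^{(i)}$ are defined as follows (see Appendix A.1 for detailed derivations):
$$W^{(i)} = W^{(i-1)} + v^{(i)}_{\text{new}} \otimes \phi(k^{(i)}) - \bar{v}^{(i)} \otimes \phi(k^{(i)}) \tag{23}$$
$$\phantom{W^{(i)}} = W^{(i-1)} + \beta^{(i)} \big(v^{(i)} - \bar{v}^{(i)}\big) \otimes \phi(k^{(i)}) \tag{24}$$
$$y^{(i)} = W^{(i)} \phi(q^{(i)}) \tag{25}$$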
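The difference between the sum rule and the proposed update rule is easiest to see when the same key is written twice. The sketch below assumes the update described above (retrieve the value currently stored under the key, then write the interpolated difference). For simplicity the write strength is a fixed constant here, whereas in the model it is produced by a sigmoid of the input; the key is a hand-picked unit vector standing in for an already mapped key.

```python
import numpy as np

def sum_update(W, k_feat, v):
    """Purely additive write: nothing is ever removed."""
    return W + np.outer(v, k_feat)

def delta_update(W, k_feat, v, beta):
    """Retrieve the stored value, then move it towards v with strength beta."""
    v_bar = W @ k_feat                               # value currently paired with the key
    return W + beta * np.outer(v - v_bar, k_feat)    # write the new value, remove the old one

k = np.array([1.0, 0.0, 0.0])     # hypothetical mapped key with unit norm
v_old = np.array([1.0, 2.0])
v_new = np.array([-3.0, 0.5])
W0 = np.zeros((2, 3))

W_sum = sum_update(sum_update(W0, k, v_old), k, v_new)
W_delta = delta_update(delta_update(W0, k, v_old, 1.0), k, v_new, 1.0)
W_half = delta_update(delta_update(W0, k, v_old, 1.0), k, v_new, 0.5)

read_sum = W_sum @ k       # v_old + v_new: both values are superimposed
read_delta = W_delta @ k   # v_new: the old value was overwritten
read_half = W_half @ k     # halfway between v_old and v_new
```

With a write strength of one the stored association is replaced, with a write strength of zero the memory is left untouched, and values in between interpolate.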
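In the equations above, no normalisation is applied to the value we retrieve. A straightforward normalisation can be obtained by following the derivation in Sec. 3.2, i.e. by introducing an accumulator:
$$z^{(i)} = z^{(i-1)} + \phi(k^{(i)}) \tag{26}$$
and replacing Eqs. 20 and 25 respectively by:
$$\bar{v}^{(i)} = \frac{W^{(i-1)} \phi(k^{(i)})}{z^{(i-1)} \cdot \phi(k^{(i)})} \tag{27}$$
$$y^{(i)} = \frac{W^{(i)} \phi(q^{(i)})}{z^{(i)} \cdot \phi(q^{(i)})} \tag{28}$$
where we define $\bar{v}^{(1)} = 0$. In this approach, the output $y^{(i)}$ is a weighted average of $\beta^{(j)} \big(v^{(j)} - \bar{v}^{(j)}\big)$ for $1 \le j \le i$. We refer to this approach as attention normalisation.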
This approach, however, has drawbacks. First, the accumulation of positive values in Eq. 26 always grows with the number of steps, and may result in instability. Second, specifically for our update rule, this normalisation is not sufficient to balance the weights between write and remove operations in Eq. 23 (see derivations in Appendix A.2). Here we propose a better approach based on simple normalisation. We divide the effective key and query vectors $\phi(k^{(i)})$ and $\phi(q^{(i)})$ by the sum of their components, e.g., for the query:
$$\phi'(q^{(i)}) = \frac{\phi(q^{(i)})}{\sum_{j=1}^{d_{\text{dot}}} \phi(q^{(i)})_j} \tag{29}$$
before applying Eqs. 20-25. A general consequence of this normalisation is intuitively understood by noticing that the output of any matrix-vector operations (like Eq. 25) is a weighted sum of columns of the matrix where weights are the components of the vector; thus, if the vector components sum up to one, the operation can be viewed as an attention over the columns of the matrix. We provide further explanations and precise implications for our model in Appendix A.2. We refer to this approach as sum normalisation.
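A small sketch of sum normalisation, and of the "attention over columns" reading of a matrix-vector product, is given below. The small constant `eps` is an assumption added here to guard against an all-zero feature vector; the matrix and the query features are made-up numbers.

```python
import numpy as np

def sum_normalise(feat, eps=1e-6):
    """Divide a non-negative feature vector by the sum of its components."""
    return feat / (feat.sum(axis=-1, keepdims=True) + eps)

W = np.array([[1.0, 3.0, 5.0],
              [0.0, 2.0, 4.0]])        # made-up fast weight matrix, columns are stored values
q_feat = np.array([2.0, 1.0, 1.0])     # stands in for a non-negative mapped query
q_norm = sum_normalise(q_feat)         # approximately [0.5, 0.25, 0.25]
y = W @ q_norm                         # weighted average of the columns of W
```

Because the normalised components are non-negative and sum to approximately one, every entry of `y` lies between the smallest and largest entry of the corresponding row of `W`.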
The central component of softmax linearisation (Sec. 3.2) is the $\phi$ function which maps key and query vectors to the space where the dot product is executed: $\mathbb{R}^{d_{\text{key}}} \rightarrow \mathbb{R}^{d_{\text{dot}}}$. We first list desirable properties of such a function, and review the existing $\phi$ functions from the perspective of fast weight memories. Finally, we also propose our own $\phi$ function.
For Eq. 13 to define proper attention weights between 0 and 1, the codomain of $\phi$ should be positive. Another property of $\phi$ derives from the discussion of memory capacity in Sec. 4.1. The dimensionality of its codomain $d_{\text{dot}}$ defines the model’s capacity. Therefore, by including a transformation which projects the input dimension $d_{\text{key}}$ to a larger dimension $d_{\text{dot}}$, the $\phi$ function can potentially increase the upper bound of the capacity.
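Katharopoulos et al. (2020) propose to use the simple element-wise $\mathrm{ELU} + 1$ function (Clevert et al., 2016):
$$\phi(x) = \mathrm{ELU}(x) + 1 = \begin{cases} x + 1, & \text{if } x > 0 \\ \exp(x), & \text{if } x \le 0 \end{cases} \tag{30}$$
The choice of $\mathrm{ELU}$ over $\mathrm{ReLU}$ is motivated by non-zero gradients on the negative part. Importantly, as a simple element-wise function, this $\phi$ function preserves the dimension of the input key vector ($d_{\text{key}} = d_{\text{dot}}$), without modifying the memory capacity as discussed in Sec. 4.1.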
In contrast to Katharopoulos et al. (2020)’s $\phi$ function which merely satisfies positivity (and a good gradient) property, Choromanski et al. (2021) propose a mathematically rigorous method to approximate the softmax with random features. They propose the following $\phi$ function:
$$h(x) = \frac{1}{\sqrt{2}} \exp\Big(-\frac{\|x\|^2}{2}\Big) \tag{31}$$
$$\phi(x) = \frac{h(x)}{\sqrt{m}} \big[\exp(Rx), \exp(-Rx)\big] \tag{32}$$
where the concatenation $[a, b]$ of two vectors $a$ and $b$ is along the feature dimension, and $R \in \mathbb{R}^{m \times d_{\text{key}}}$ is a matrix with $m$ random features where each row vector $r \in \mathbb{R}^{1 \times d_{\text{key}}}$ is drawn from $\mathcal{N}(0, I_{d_{\text{key}}})$.
With FAVOR+, the dimension of the codomain $d_{\text{dot}}$ is $2m$ which increases the theoretical capacity of the memory if $2m > d_{\text{key}}$. At the same time, the model’s capacity is still limited, and equals the infinite capacity of the softmax memory only when $m$ goes to infinity, which is never achieved in practice. During training, we redraw these $m$ random vectors for each mini-batch. During evaluation, we draw a set of $m$ random vectors once, and keep them fixed. $m$ is the only hyperparameter of FAVOR+ and influences the quality of the softmax approximation. Choromanski et al. (2021) suggest to choose $m$ in the order of $d_{\text{key}} \log(d_{\text{key}})$. This sampling process is the main drawback of FAVOR+ as it introduces variance into the model’s output.
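The random-feature map can be sketched as follows. This is an illustrative re-implementation under the assumption that the map is a positive random-feature approximation of the softmax kernel exp(k · q) with Gaussian projections; it is not the Performer code base. The very large number of features is chosen only so that the approximation is visibly close in this toy check.

```python
import numpy as np

def favor_plus(x, R):
    """Positive random features for one vector x, given projections R (m x d_key)."""
    m = R.shape[0]
    h = np.exp(-0.5 * (x @ x)) / np.sqrt(2.0)
    proj = R @ x
    return h / np.sqrt(m) * np.concatenate([np.exp(proj), np.exp(-proj)])

rng = np.random.default_rng(0)
d_key, m = 4, 20000                          # m is deliberately huge for this demo
R = rng.normal(size=(m, d_key))              # rows drawn from a standard normal
k = np.array([0.3, -0.2, 0.1, 0.4])
q = np.array([0.1, 0.5, -0.3, 0.2])

approx = favor_plus(k, R) @ favor_plus(q, R) # dot product in the 2m-dimensional space
exact = np.exp(k @ q)                        # softmax kernel value
```

Redrawing `R` changes `approx`, which is the variance the text refers to; the output dimension is `2 * m`.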
The two previous sub-sections highlight the sub-optimality of the existing functions. Sampling introduces extra complexity to FAVOR+ (Sec. 5.3), while the Linear Transformer (Sec. 5.2) lacks the ability to project up the dot product dimension. Here we propose an alternative approach called deterministic parameter-free projection (DPFP). It is deterministic and easy to compute like Linear Transformers while increasing the dot product dimension without requiring FAVOR+’s random features.
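We begin with a low-dimensional example to foster an intuitive understanding before moving on to the general formulation. Consider 4 keys $k^{(i)}$, $i \in \{1, 2, 3, 4\}$ in $\mathbb{R}^2$ and $\phi: \mathbb{R}^2 \rightarrow \mathbb{R}^4_{\ge 0}$ where the $l$-th element of $\phi(x)$ is generated by the partial function $\phi_l: \mathbb{R}^2 \rightarrow \mathbb{R}_{\ge 0}$. We design $\phi$ such that it facilitates orthogonality in the projected space, i.e. $\phi(k^{(i)}) \cdot \phi(k^{(j)}) = 0$ for $i \ne j$. Towards this end, we construct $\phi$ such that if $\phi_l(x) > 0$ then $\phi_n(x) = 0$ for all $n \ne l$. Such a constraint can be enforced by limiting the domains of the partial functions to be non-overlapping. With the element-wise rectifier function $r(a) = \max(0, a)$ the partial functions are defined as:
$$\phi_1(k) = r(k_1) \, r(k_2) \tag{33}$$
$$\phi_2(k) = r(-k_1) \, r(k_2) \tag{34}$$
$$\phi_3(k) = r(k_1) \, r(-k_2) \tag{35}$$
$$\phi_4(k) = r(-k_1) \, r(-k_2) \tag{36}$$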
Figure 1 illustrates this function. The elements of the 4-dimensional space are displayed as the component of the four coloured surfaces. The figure shows how each vector in the 2d plane will have a single non-zero component in the 4d space and equally splits the input space into four areas which will be orthogonal in the projected space.
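The two-dimensional construction can be reproduced directly. The sketch below assumes the four partial functions are products of rectified coordinates, one per quadrant of the plane; the four example keys are arbitrary points, one in each quadrant.

```python
import numpy as np

def r(a):
    """Element-wise rectifier max(0, a)."""
    return np.maximum(0.0, a)

def phi_2d(k):
    """Map a 2d key to 4 non-negative features, one per quadrant."""
    k1, k2 = k
    return np.array([r(k1) * r(k2),      # first quadrant
                     r(-k1) * r(k2),     # second quadrant
                     r(k1) * r(-k2),     # fourth quadrant
                     r(-k1) * r(-k2)])   # third quadrant

# One arbitrary key per quadrant.
keys = np.array([[1.0, 2.0], [-1.5, 0.5], [0.7, -0.3], [-2.0, -1.0]])
feats = np.stack([phi_2d(k) for k in keys])
gram = feats @ feats.T                   # pairwise dot products in the projected space
```

Each row of `feats` has a single non-zero entry, so `gram` is diagonal: keys from different quadrants are orthogonal after the projection even though they are not orthogonal in the plane.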
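We generalise this method to higher dimensional inputs by constructing additional two-factor features. Given an input vector $k \in \mathbb{R}^{d_{\text{key}}}$ and $i \in [1, 2 d_{\text{key}}]$, the partial function is
$$\phi_{i\nu}(k) = r\big([k, -k]\big)_i \; r\big([k, -k]\big)_{i+\nu} \tag{37}$$
where $\nu \in \{1, 2, ..., 2 d_{\text{key}} - 1\}$ is a capacity controlling hyperparameter. The codomain dimensionality of $\phi(k)$ is thus $d_{\text{dot}} = 2 d_{\text{key}} \nu$. Eq. 37 is highly parallelizable because each partial function can be computed independently. This can be implemented in a few lines of code as we show in Appendix B.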
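A NumPy sketch of the general projection is given below. It is not the code of Appendix B; it assumes that the feature for offset j multiplies component i of the rectified, concatenated vector with component i + j, with indices wrapping around, and that all offsets from 1 to nu are concatenated.

```python
import numpy as np

def dpfp(k, nu=1):
    """Deterministic parameter-free projection of k (last axis) to 2 * d_key * nu features."""
    x = np.concatenate([np.maximum(0.0, k), np.maximum(0.0, -k)], axis=-1)   # r([k, -k])
    # For each offset j, align component i with component i + j (wrapping around).
    rolled = np.concatenate([np.roll(x, shift=-j, axis=-1) for j in range(1, nu + 1)], axis=-1)
    repeated = np.concatenate([x] * nu, axis=-1)
    return repeated * rolled

k = np.array([1.0, 2.0])
feat_2d = dpfp(k, nu=1)                                # 4 features for a 2d key

k_big = np.array([0.5, -1.0, 2.0, -0.1])
feat_big = dpfp(k_big, nu=3)                           # 2 * 4 * 3 = 24 features
```

For a two-dimensional key and `nu=1` this yields the same four products as the quadrant example, possibly in a different order. Larger `nu` adds further pairwise products and therefore a larger dot-product dimension.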
Now we present our experimental results on synthetic retrieval problems (Sec. 6.1.1 and 6.1.2), machine translation (Sec. 6.2), and language modelling (Sec. 6.3).
We illustrate the capacity issue (Sec. 4.1) of linear attention, and the effectiveness of our new update rule (Sec. 4.2) on two synthetic problems.
In both settings, our toy problem consists of retrieving the correct value from a sequence of randomly sampled key-value associations when queried with one of the used keys. Crucially, the query is given at the end of the sequence, such that the model is not aware of it while processing the inputs. To succeed, the model has to learn to store the observed associations in its memory without interference.
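Let $\mathcal{K}$ and $\mathcal{V}$ be the finite and fixed sets of keys and values and $S = |\mathcal{K}| = |\mathcal{V}|$. Then, the input to the model is the sequence $[(k, v)_1, ..., (k, v)_L]$ followed by $q$ where every pair $(k, v) \in \mathcal{K} \times \mathcal{V}$ is sampled randomly, and $q$ is randomly chosen to be one of the $L$ keys.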
Each value is assigned a fixed one-hot vector . Hence, the set of value vectors is an orthonormal basis. In contrast, the vector embedding of the key symbols is the learned function and where .
Following the write operations, the read function and the query vector are used to retrieve from memory. Finally, the loss is defined as where is the value vector assigned to in the input sequence. Each model is trained in mini-batches using this loss and Adam with default hyperparameters unless stated otherwise. For evaluation, we sample 20 sequences and test all possible queries, e.g., with unique keys, the evaluation batch is of size .
In this setting, we experimentally demonstrate the capacity limit of linear attention (Sec. 4.1). We conduct experiments for the various $\phi$ functions described in Sec. 5. We fix $d_{\text{key}}$ to be , while different $\phi$ functions produce different $d_{\text{dot}}$. We set the sequence length to be equal to the number of unique keys ($L = S$), and sample the keys and values without replacement to generate the sequences. By varying the sequence length $L$, our goal is to show that all linear attention models (using the simple sum update rule of Sec. 3.2) fail at retrieving when $L$ exceeds $d_{\text{dot}}$.
All models are trained with a mini-batch size of until the evaluation loss falls below or until lack of progress for steps. In Figure 2, the best validation set performance for each model and each $S$ is displayed (for the learning curves see Appendix C.1). The number of unique keys is initially and is incremented by until . The following models are compared: Softmax, Linear-Attention, FAVOR+ with 64, 128, and 512 random features, DPFP-$\nu$ with $\nu \in \{1, 2, 3\}$.
The results support our theoretical analysis. Linear-Attention has a capacity of due to the choice of . Experimentally, Linear-Attention begins to accumulate errors with or more associations. Similarly, DPFP projections 1, 2 and 3 start to accumulate errors as they approach their respective limits at , , and . FAVOR+, on the other hand, fails to achieve a loss of 0 in any experiment. Finally, as expected, softmax attention is outperforming all functions, although it struggles to fully converge with more than 500 keys.
In the second setting, we compare variations of the update rule. Unlike in setting 1, keys and values will be sampled with replacement and sequence length . As a result, in the same sequence, multiple keys can be re-assigned to a new value more than once. The expected value to retrieve is the most recent one associated with the query. With every new key, the previous value associated with this key deprecates and the model is required to update its finite size memory. The ability to update values associated with keys is essential to bind context-specific values to a key.
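The two sampling schemes can be sketched with a small generator. This is a hypothetical helper, not the script used for the experiments: symbols are represented as integer ids, the query is drawn from the keys that actually occur in the sequence, and the target is the value of the most recent occurrence of that key.

```python
import numpy as np

def sample_episode(num_symbols, length, with_replacement, rng):
    """Return (keys, values, query, target) as integer symbol ids."""
    keys = rng.choice(num_symbols, size=length, replace=with_replacement)
    values = rng.choice(num_symbols, size=length, replace=with_replacement)
    query = keys[rng.integers(length)]            # one of the keys used in the sequence
    last = np.flatnonzero(keys == query)[-1]      # most recent position of that key
    return keys, values, query, values[last]

rng = np.random.default_rng(0)

# Setting 1: every key and value appears exactly once (length equals number of symbols).
keys1, values1, query1, target1 = sample_episode(10, 10, False, rng)

# Setting 2: sampling with replacement, so a key may be re-assigned several times.
keys2, values2, query2, target2 = sample_episode(5, 20, True, rng)
```

In the second call the sequence is longer than the number of symbols, so at least one key is guaranteed to repeat and only its latest value counts as correct.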
We use DPFP-1 as the function. The sequence length is fixed at 40 with 20 unique keys and values. While this setting does not exceed the capacity of DPFP-1, our result is independent of the capacity regime (see results for different and in Appendix C.2).
We compare the proposed fast weight memory update rule with normalisation of Sec. 4.2 (denoted here by ours) to three baselines: the sum update rule of Sec. 3 (sum rule), and two variants of previous update rules (Schlag et al., 2021): Schlag (2021) and Schlag (2021) with DPFP. Schlag (2021) is simply the model from Schlag et al. (2021) ported to this setting (i.e. without the LSTM layer). Schlag (2021) has neither a function, nor the sum normalisation term of Sec. 4.2. Instead it uses a nonlinearity for its key representations. As an ablation we replace it with our DPFP-1 but we don’t use the normalisation term of Sec. 4.2, which we refer to as Schlag (2021) with DPFP.
Figure 3 presents the learning curves. They demonstrate that our new update rule outperforms all other variants. As expected, the baseline sum update rule fails.
Here we compare $\phi$ functions on the standard machine translation task. We compare Linear Transformer (Katharopoulos et al., 2020), Performer (Choromanski et al., 2021) and our $\phi$ function DPFP (Sec. 5.4) to the regular Transformer, complementing prior comparisons, e.g., Tay et al. (2021).
We use the standard WMT14 English to German Translation dataset and standard data setups (Ott et al., 2018; Vaswani et al., 2017). We adapt the recipe of Ott et al. (2019) (see Appendix D) and train Vaswani et al. (2017)’s “big” models for about 4 days on three V100 GPUs. We use the exact same training configurations for all models without model-specific hyper-parameter tuning. We only vary the model hyper-parameters $m$ in Performers and $\nu$ in DPFP models.
Table 1 shows the Bleu score (Papineni et al., 2002; Post, 2018) results. The Performer is as good as the basic Transformer when the number of samples is large enough (for , we have ). In fact, with , the recommended value for is . Our DPFP model outperforms the Linear Transformer as well as the Performer when is relatively small; providing a good trade-off between simplicity and performance.
Model | Valid, $d_{\text{dot}}$ = 64 | Valid, 256 | Valid, 512 | Test, $d_{\text{dot}}$ = 64 | Test, 256 | Test, 512 |
---|---|---|---|---|---|---|
Standard | 26.6 | - | - | 27.7 | - | - |
Linear | 25.5 | - | - | 26.8 | - | - |
Performer | 24.2 | 24.9 | 26.7 | 24.4 | 25.3 | 27.7 |
DPFP (ours) | - | 26.2 | 26.2 | - | 26.9 | 27.1 |
Toy experimental Setting 2 (Sec. 6.1.2) illustrated the effect of our update rule. Now our goal is to confirm its effectiveness on a large-vocabulary word-level language modelling task, and investigate its further potential.
Our update rule should be evaluated on a dataset with sufficiently long contextual dependencies. We use the standard WikiText-103 (Merity et al., 2017) dataset. WikiText-103 consists of long articles from Wikipedia; the training set contains about 28 K articles with a total of 103 M running words. This results in contextual text blocks of about 3600 words. The validation and test sets also contain similarly long dependencies, respectively with 218 K and 246 K running words for 60 articles each. The vocabulary size is about 268 K words.
We split the training data into $L$-word long segments (which is the backpropagation span). Unless stated otherwise, we treat these segments independently during training. For evaluation, we use a batch size of one, and go through the text with a sliding window of size $L$, taking into account only the last position for computing perplexity (except in the first segment where all positions are evaluated). This is usually done for Transformers with a limited context (Al-Rfou et al., 2019). Appendix E provides further experimental details.

Model | Update Rule | small, Valid | small, Test | medium, Valid | medium, Test |
---|---|---|---|---|---|
Transformer | - | 33.0 | 34.1 | 27.9 | 29.6 |
Linear | sum | 37.1 | 38.3 | 31.1 | 33.0 |
Linear | ours | 34.1 | 35.5 | 29.7 | 31.5 |
Performer | sum | 39.0 | 39.6 | 32.2 | 33.8 |
Performer | ours | 36.1 | 37.2 | 30.0 | 31.8 |
We first evaluate our update rule in two configurations. In the small configuration, we set the model dimension (same for key, value, and query) to 128, and the training and evaluation context length to 256. We note that where is the number of heads. is set to 8. The feed-forward layer dimension is 2048. The number of layers is 16 in all configurations. In the medium configuration, we set and . Both configurations represent an overcapacity regime. We evaluate both linear Transformers (Katharopoulos et al., 2020) and Performers (Choromanski et al., 2021). However, to keep the comparison simple, we set the capacity of Performers (Sec. 5.3) equal to the one of linear Transformers, by the right choice of projection dimension ( and , respectively, in small and medium configurations), even though this limits performance. We do not include DPFP here, since in both configurations even the smallest value for provides enough capacity. Here we investigate the effect of the update rule in an overcapacity scenario (see Appendix C.3 for experimental results in a non-overcapacity regime including DPFP). All models can be trained using two V100 GPUs in less than four days. Table 2 shows the perplexity results. In both configurations, our update rule provides convincing improvements over the models with the sum update rule.
We also conduct an ablation study to test the effect of the absolute positional encoding and an extra attention normalization (Sec. 4.2). Table 3 shows the results. The sum normalisation (Sec. 4.2) is used in all cases: the models diverged otherwise. In contrast, better perplexities are obtained when no additional attention normalization is applied. We also observe that the absolute positional encoding is not needed, confirming results of prior work (Irie et al., 2019a).
Position Encoding | Attn. Normalisation | Valid | Test |
---|---|---|---|
Yes | Yes | 30.4 | 32.1 |
No | Yes | 29.2 | 31.2 |
Yes | No | 29.7 | 31.5 |
No | No | 28.1 | 31.1 |

Model | Update Rule | Valid | Test |
---|---|---|---|
Linear Transformer | sum | 1600 | 1600 |
Linear Transformer | ours | 27.8 | 29.4 |
Transformer-XL | - | 24.6 | 25.5 |
Given the constant space requirements, we can feed inputs to linear Transformers for an arbitrary number of steps. To properly assess the model’s ability to process arbitrary long sequences, it is crucial to make the training consistent with the evaluation mode (Irie et al., 2019b). During training, we carry over the fast weight memory from one training segment to the following one, while still limiting the backpropagation span to be within the segment. We train a Linear Transformer with our update rule, using neither positional encoding nor attention normalisation (the best setting from Table 3). It was crucial to remove the attention normalization here, since the accumulator blows up as indicated in Sec. 4.2. Table 4 shows the corresponding results. Although our model was not designed for this use case, it does not break (unlike the base model with the naive sum update rule). It even yields a slight improvement over the best model with a limited context window (Table 3). While performance does not yet match the one of prior models (Dai et al., 2019; Rae et al., 2020) specifically designed for this use case (we train the Transformer-XL in our medium configuration), the results are promising for future work on alternative Transformer models which can run for an unlimited number of steps.
We connect linearised self-attention to outer product fast weights. The fast weight perspective allows for discussing the associative memory capacity limitation of linear attention, and for proposing an alternative update rule to dynamically edit the memory. We also propose and discuss a new method for linearising attention. Experiments on synthetic and real language tasks demonstrate the effectiveness of our proposals. The fast weight perspective opens up new avenues for investigating even better update rules and designs for Transformers with finite memory.
We thank Sjoerd van Steenkiste, Hubert Ramsauer and Sepp Hochreiter for valuable comments and suggestions on the first version of the manuscript. This research was partially funded by ERC Advanced grant no: 742870, project AlgoRNN. We thank NVIDIA Corporation for donating several DGX machines, and IBM for donating a Minsky machine. We also thank Katharopoulos et al. (2020) for releasing their cuda implementation of Linear Transformers, which was helpful to implement our models.
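By considering one-hot vectors $\{e^{(1)}, \dots, e^{(d_\text{key})}\}$ which form the Cartesian basis of $\mathbb{R}^{d_\text{key}}$, any matrix $W \in \mathbb{R}^{d_\text{value} \times d_\text{key}}$ can be written as

$$W = \sum_{i=1}^{d_\text{key}} w^{(i)} \otimes e^{(i)} \tag{41}$$

where $\{w^{(1)}, \dots, w^{(d_\text{key})}\}$ are the column vectors of $W$. In the context of associative memory, we can interpret this expression as a set of associations with fixed keys $e^{(i)}$ and the associated values $w^{(i)}$.

In this view, any update of $W$ can be written as updates of each $w^{(i)}$. This perspective allows us to derive the sum normalisation of Sec. 4.2. For that, we start by deriving the update of $w^{(i)}$.

Given an arbitrary weight $W$, we consider updating it to $W'$ by adding a new association $(k, v)$ using our update rule of Sec. 4.2 (where we omit $\beta$):

$$\bar{v} = W k \tag{42}$$

$$W' = W + (v - \bar{v}) \otimes k \tag{43}$$

By substituting $k$ in Eq. 43 by its expression in the Cartesian basis $k = \sum_{i=1}^{d_\text{key}} k_i e^{(i)}$ with $k_i \in \mathbb{R}$, we obtain:

$$W' = W + (v - \bar{v}) \otimes \sum_{i=1}^{d_\text{key}} k_i e^{(i)} \tag{44}$$

$$\phantom{W'} = \sum_{i=1}^{d_\text{key}} \left( w^{(i)} + k_i (v - \bar{v}) \right) \otimes e^{(i)} \tag{45}$$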
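The decomposition in Eq. 41 and the column-wise form of the update in Eqs. 43-45 can be checked numerically. The following NumPy sketch uses arbitrary small dimensions and random data (all assumptions for illustration), and verifies that updating $W$ with the rank-one term is the same as updating each column $w^{(i)}$ by $k_i (v - \bar{v})$.

```python
import numpy as np

rng = np.random.default_rng(0)
d_value, d_key = 3, 4

W = rng.normal(size=(d_value, d_key))
k = rng.normal(size=d_key)
v = rng.normal(size=d_value)
basis = np.eye(d_key)  # row i is the one-hot vector e^(i)

# Eq. 41: W is the sum of outer products of its columns with the basis vectors.
W_from_columns = sum(np.outer(W[:, i], basis[i]) for i in range(d_key))

# Eqs. 42-43: retrieve the stored value, then apply the update (beta omitted).
v_bar = W @ k
W_new = W + np.outer(v - v_bar, k)

# Eqs. 44-45: the same update written column by column.
W_new_columns = np.stack(
    [W[:, i] + k[i] * (v - v_bar) for i in range(d_key)], axis=1
)

print(np.allclose(W, W_from_columns), np.allclose(W_new, W_new_columns))
```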
In this section, we provide additional experimental results which we could not include in the main paper because of space limitations.
Figure 4 shows learning curves for the synthetic setting 1 (without replacement) with 600 unique keys and values. The scripts used to generate such figures can be found in our GitHub repository.
In Sec. 6.3, we evaluated our update rule when the model is in the overcapacity regime. Here we present an extra language modelling experiment which evaluates the benefits of our update rule in non-overcapacity scenarios. This also allows us to include DPFP in the evaluation. We train both Performer and DPFP models in the small setting, with the Performer's $m$ and the DPFP's $\nu$ chosen so that the resulting $d_\text{dot}$ is the same in both cases. Table 5 shows the perplexity results. First, we observe that the Performer and DPFP baseline models with the sum update rule do not outperform the Linear Transformer baseline from Table 2. In fact, language modelling might be less affected by the capacity issue than the synthetic retrieval task, as it might not require exact retrieval. Second, we observe that our update rule improves both variants of linear attention over the sum update-rule baselines even in this condition. This indicates the general benefits of our update rule in fast weight memories. We note that the improvement is larger for the DPFP model than for the Performer. This is similar to Table 2, where our update rule improves the deterministic Linear Transformers more than the Performers.
Model | Update rule | Valid (small) | Test (small) |
---|---|---|---|
Transformer | - | 33.0 | 34.1 |
Performer | sum | 38.0 | 38.8 |
Performer | ours | 36.0 | 37.0 |
DPFP | sum | 37.7 | 38.8 |
DPFP | ours | 33.9 | 35.0 |
We implemented different $\phi$ functions in the fairseq toolkit (Ott et al., 2019). The Transformer architecture used in the experiment is the one referred to as big in the original Transformer paper (Vaswani et al., 2017): the model has 6 layers each in the encoder and the decoder, with a hidden layer size of 1024 with 16 attention heads, 4096-dimensional feed-forward layers, using 32K byte-pair encoding sub-word units (Sennrich et al., 2016). fairseq provides a training configuration for the corresponding model (Ott et al., 2018), which we adapted for our infrastructure. We trained our models on three GPUs using a batch size of up to 3584 tokens per GPU and by accumulating gradients over 16 batches for 45 epochs, and selected the best model based on the validation BLEU score. In Table 1, we directly reported BLEU for different values of $d_\text{dot}$; Table 6 provides the conversion from the hyper-parameters $m$ of Performers or $\nu$ in the DPFP to $d_\text{dot}$.

$d_\text{dot}$ | 256 | 384 | 512 |
---|---|---|---|
Performer ($m$) | 128 | 192 | 256 |
DPFP ($\nu$) | 2 | 3 | 4 |
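The conversion in Table 6 can be reproduced with two small helper functions. The formulas below are assumptions that are consistent with the table: a Performer feature map with `m` random features is taken to produce `2 * m` outputs, and DPFP with hyper-parameter `nu` is taken to produce `2 * nu * d_key` outputs, where the per-head key dimension `d_key = 1024 / 16 = 64` follows from the hidden size and number of heads stated above. The function names are hypothetical.

```python
def performer_d_dot(m):
    """Assumed dot-product dimension of a Performer feature map with m random features."""
    return 2 * m

def dpfp_d_dot(nu, d_key):
    """Assumed dot-product dimension of DPFP with hyper-parameter nu."""
    return 2 * nu * d_key

hidden_size, num_heads = 1024, 16
d_key = hidden_size // num_heads  # 64 dimensions per head

# Hyper-parameter pairs from Table 6, in column order.
for m, nu in [(128, 2), (192, 3), (256, 4)]:
    print(m, nu, performer_d_dot(m), dpfp_d_dot(nu, d_key))
```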
All our implementations are based on PyTorch (Paszke et al., 2019). Our base language modelling code has been developed by using the public code by Dai et al. (2019) for Transformer-XL as a starting point. For the $\phi$ functions, we ported the same implementation we used for our translation experiments. For the implementation of our update rule, we modified the CUDA kernel made publicly available by Katharopoulos et al. (2020). We note that a custom implementation of the backward pass for the fast weights is crucial for language modelling. A naive backward computation generated by automatic differentiation would store the fast weights for each time step, which can quickly hit the GPU memory limit. With a custom implementation, one can make sure that the space requirement is on the order of one set of weights by recomputing the fast weights needed for computing the gradients at each time step (which still remains time-efficient, as the operations involved in the computation of our fast weights are rather inexpensive).
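The memory argument can be made concrete with a NumPy sketch of the recomputation idea. This is not the CUDA kernel: it assumes that the forward pass keeps only the final fast weight matrix plus one small vector per step (of size `d_value`), and that the backward pass walks through time in reverse, undoing each rank-one update to recover the weights it needs. The function names, dimensions, and the choice of positive sum-normalised keys are assumptions for illustration.

```python
import numpy as np

def forward(W0, keys, values, betas):
    """Apply the update rule, keeping the final W and one d_value-sized vector per step."""
    W = W0.copy()
    deltas = []
    for k, v, beta in zip(keys, values, betas):
        delta = beta * (v - W @ k)   # correction written under key k
        W += np.outer(delta, k)
        deltas.append(delta)
    return W, deltas

def recompute_backwards(W_final, keys, deltas):
    """Walk backwards in time, recovering W_{t-1} from W_t instead of storing it."""
    W = W_final.copy()
    for k, delta in zip(keys[::-1], deltas[::-1]):
        W -= np.outer(delta, k)      # undo the rank-one update of step t
        yield W.copy()               # W_{t-1}, as needed for the gradients at step t

rng = np.random.default_rng(0)
steps, d_dot, d_value = 50, 8, 4
keys = rng.random((steps, d_dot))
keys /= keys.sum(axis=1, keepdims=True)   # positive, sum-normalised keys
values = rng.normal(size=(steps, d_value))
betas = rng.uniform(size=steps)
W0 = np.zeros((d_value, d_dot))

# Reference: the naive approach that stores the fast weights of every step.
stored = [W0.copy()]
for k, v, beta in zip(keys, values, betas):
    W_prev = stored[-1]
    stored.append(W_prev + beta * np.outer(v - W_prev @ k, k))

W_final, deltas = forward(W0, keys, values, betas)
recovered = list(recompute_backwards(W_final, keys, deltas))
max_error = max(np.abs(a - b).max() for a, b in zip(recovered, stored[-2::-1]))
print(max_error)
```

In this sketch the naive reference holds `steps + 1` matrices of shape `(d_value, d_dot)`, whereas the recomputation path holds one such matrix and `steps` vectors of length `d_value`.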
Here we provide extra experimental details to complement the descriptions of Sec. 6.3. For the small and medium configurations respectively, we use batch sizes of 96 and 56 sequences, and train for about 120 and 70 epochs. In both settings, we apply 10% dropout (Hanson, 1990; Srivastava et al., 2014), and train using the Adam optimiser (Kingma & Ba, 2014) with an initial learning rate of 0.00025 and 2000 learning rate warm-up steps. For models with a large dot product dimension $d_\text{dot}$, we experienced some training instabilities, which we will investigate in the future. For experiments with Transformer-XL (Table 4), we train it with the same backpropagation span as our models in the medium configuration, and evaluate it with equal memory and target sequence lengths. For further details, we refer the readers to our code.
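The hyper-parameters above can be collected in a small configuration sketch. The dictionary layout and the function name are hypothetical, and the shape of the warm-up is an assumption: the text only gives the number of warm-up steps, so the sketch uses a linear ramp and simply holds the learning rate constant afterwards, without modelling any later decay.

```python
TRAIN_CONFIG = {
    "small": {"batch_size": 96, "epochs": 120},   # epochs are approximate
    "medium": {"batch_size": 56, "epochs": 70},   # epochs are approximate
}
DROPOUT = 0.1
BASE_LR = 0.00025
WARMUP_STEPS = 2000

def learning_rate(step):
    """Linear warm-up to BASE_LR (assumed shape), then constant (assumed)."""
    if step < WARMUP_STEPS:
        return BASE_LR * (step + 1) / WARMUP_STEPS
    return BASE_LR

for step in (0, 999, 1999, 5000):
    print(step, learning_rate(step))
```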