1 Introduction
Neural Attention was introduced by [Bahdanau et al., 2014] as a way of extracting information from variable-length representations. The Transformer model [Vaswani et al., 2017] uses "multi-head" attention, consisting of multiple attention layers ("heads") in parallel, each with different projections on its inputs and outputs. By using a dimensionality reduction in the input projections, the computational cost is kept similar to that of basic attention. Quality is improved, presumably due to the ability to attend to multiple positions simultaneously based on multiple different types of relationships.
As noted in [Vaswani et al., 2017] (Section (A) of table 3 in [Vaswani et al., 2017]; see also the first sections of tables 1 and 5 of this paper), taking this process to the extreme (more attention heads projected to lower dimensionality) becomes counter-productive. We believe that this is because the query-vectors and key-vectors become so low-dimensional that their dot product can no longer constitute an informative matching function.
In this paper, we introduce a new variant, "talking-heads attention", that addresses this problem by inserting a learned linear projection across the attention-heads dimension of the attention-logits tensor. This allows each attention function to depend on all of the keys and queries. We also insert a second such projection immediately following the softmax.
We show experimentally that inserting these "talking-heads" projections leads to better perplexities on masked language modeling tasks, as well as better quality when transfer-learning to language comprehension and question answering tasks.
2 Notation
In our pseudocode, we use capital letters to represent tensors and lower-case letters to represent their dimensions. Each tensor is followed by a dimension list in brackets. For example, a 4-dimensional image-tensor with (batch, height, width, channels) dimensions would be written as:
X[b, h, w, c]
We use einsum notation for generalized contractions between tensors of arbitrary dimension. The computation is numerically equivalent to broadcasting each input to have the union of all dimensions, multiplying component-wise, and summing across all dimensions not in the output. Rather than identifying the dimensions by an equation, as in TensorFlow and numpy, the dimensions are identified by the dimension-list annotations on the arguments and on the result. For example, multiplying two matrices would be expressed as:
Y[a, c] = einsum(X[a, b], W[b, c])
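As a minimal sketch of the equivalence described above, the following NumPy snippet compares `np.einsum` with explicit broadcasting, component-wise multiplication and summation for a matrix product. The array names and sizes are illustrative assumptions, and NumPy identifies the dimensions by an equation string rather than by the dimension-list annotations used in our pseudocode.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))   # X[a, b]
W = rng.normal(size=(4, 5))   # W[b, c]

# Matrix product as an einsum: contract over the shared dimension b.
Y_einsum = np.einsum("ab,bc->ac", X, W)

# Broadcast both inputs to the union of dimensions [a, b, c],
# multiply component-wise, then sum over b (the dimension not in the output).
Y_broadcast = (X[:, :, None] * W[None, :, :]).sum(axis=1)

# Both formulations should agree with each other and with a plain matmul.
matches = np.allclose(Y_einsum, Y_broadcast) and np.allclose(Y_einsum, X @ W)
```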
3 Review of Attention Algorithms
3.1 Dot-Product Attention
Simple dot-product attention can be described by the pseudocode below. The logits L are computed as the dot-products of the query-vectors and the memory-vectors. For each query, the logits are passed through a softmax function to produce weights, and the different memory-vectors are averaged together, weighted by those weights. In this code, we show the case where there are n different queries all attending to the same m memory-vectors. If there is only one query, the code is identical except that the "n" dimension is removed from all tensors.
def DotProductAttention(
        X[n, d],   # n query-vectors with dimensionality d
        M[m, d]):  # m memory-vectors with dimensionality d
    L[n, m] = einsum(X[n, d], M[m, d])          # Attention logits
    W[n, m] = softmax(L[n, m], reduced_dim=m)   # Attention weights
    Y[n, d] = einsum(W[n, m], M[m, d])
    return Y[n, d]
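For readers who want to execute the pseudocode above, a minimal NumPy sketch might look as follows. The function names `softmax` and `dot_product_attention` and the example sizes are our own illustrative choices, not part of any library.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(X, M):
    """X: [n, d] query-vectors, M: [m, d] memory-vectors; returns Y: [n, d]."""
    L = np.einsum("nd,md->nm", X, M)     # attention logits
    W = softmax(L, axis=1)               # attention weights, normalized over m
    return np.einsum("nm,md->nd", W, M)  # weighted average of memory-vectors

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 8))   # n = 2 queries
M = rng.normal(size=(5, 8))   # m = 5 memory-vectors
Y = dot_product_attention(X, M)
```

Because each output row is a convex combination of the memory-vectors, attending to a single memory-vector simply returns that vector.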
3.2 Dot-Product Attention With Projections
[Vaswani et al., 2017] propose a dimensionality-reduction to reduce the computational complexity of the attention algorithm. In this version, instead of computing the attention algorithm directly on the inputs $X$ and $M$, we first project the inputs using the learned linear projections $P_q$, $P_k$ and $P_v$, to produce lower-dimensional query-vectors, key-vectors and value-vectors $Q$, $K$ and $V$. We use a fourth learned linear projection, $P_o$, to produce the output.
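The projected variant can be sketched in NumPy as below. This is an illustrative reading of the description above rather than a reference implementation: the function name, the example sizes, and the parameter shapes (P_q[d_X, d_k], P_k[d_M, d_k], P_v[d_M, d_v], P_o[d_Y, d_v], chosen to mirror the multi-head pseudocode without the heads dimension) are assumptions.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_projections(X, M, P_q, P_k, P_v, P_o):
    """X: [n, d_X], M: [m, d_M]; returns Y: [n, d_Y]."""
    Q = np.einsum("nx,xk->nk", X, P_q)     # queries  [n, d_k]
    K = np.einsum("mc,ck->mk", M, P_k)     # keys     [m, d_k]
    V = np.einsum("mc,cv->mv", M, P_v)     # values   [m, d_v]
    L = np.einsum("nk,mk->nm", Q, K)       # logits   [n, m]
    W = softmax(L, axis=1)                 # weights, normalized over m
    O = np.einsum("nm,mv->nv", W, V)       # weighted values [n, d_v]
    return np.einsum("nv,yv->ny", O, P_o)  # output   [n, d_Y]

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 5, 8, 6, 7, 4, 2   # illustrative sizes
X = rng.normal(size=(n, d_X))
M = rng.normal(size=(m, d_M))
P_q = rng.normal(size=(d_X, d_k))
P_k = rng.normal(size=(d_M, d_k))
P_v = rng.normal(size=(d_M, d_v))
P_o = rng.normal(size=(d_Y, d_v))
Y = attention_with_projections(X, M, P_q, P_k, P_v, P_o)
```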
3.3 Multi-Head Attention
The multi-head attention described in [Vaswani et al., 2017] consists of the sum of multiple parallel attention layers. This can be represented by adding a "heads" dimension h to the above computation.
def MultiHeadAttention(
        X[n, d_X],          # n vectors with dimensionality d_X
        M[m, d_M],          # m vectors with dimensionality d_M
        P_q[d_X, d_k, h],   # learned linear projection to produce queries
        P_k[d_M, d_k, h],   # learned linear projection to produce keys
        P_v[d_M, d_v, h],   # learned linear projection to produce values
        P_o[d_Y, d_v, h]):  # learned linear projection of output
    Q[n, d_k, h] = einsum(X[n, d_X], P_q[d_X, d_k, h])   # queries  h*n*d_X*d_k
    K[m, d_k, h] = einsum(M[m, d_M], P_k[d_M, d_k, h])   # keys     h*m*d_M*d_k
    V[m, d_v, h] = einsum(M[m, d_M], P_v[d_M, d_v, h])   # values   h*m*d_M*d_v
    L[n, m, h] = einsum(Q[n, d_k, h], K[m, d_k, h])      # logits   h*n*m*d_k
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)      # weights
    O[n, d_v, h] = einsum(W[n, m, h], V[m, d_v, h])      #          h*n*m*d_v
    Y[n, d_Y] = einsum(O[n, d_v, h], P_o[d_Y, d_v, h])   # output   h*n*d_Y*d_v
    return Y[n, d_Y]
The pseudocode above illustrates the practical step-by-step computation of multi-head attention. The costs of the einsum operations (the number of multiplications in a naive implementation) are shown in the comments. The equivalent pseudocode below uses multi-way einsums and is more concise:
def MultiHeadAttentionConcise(X, M, P_q, P_k, P_v, P_o):
    L[n, m, h] = einsum(X[n, d_X],
                        M[m, d_M],
                        P_q[d_X, d_k, h],
                        P_k[d_M, d_k, h])
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
    Y[n, d_Y] = einsum(W[n, m, h],
                       M[m, d_M],
                       P_v[d_M, d_v, h],
                       P_o[d_Y, d_v, h])
    return Y[n, d_Y]
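The step-by-step and multi-way formulations can be checked against each other with a small NumPy sketch. The function names and the example sizes below are illustrative assumptions; only the tensor shapes follow the pseudocode.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, M, P_q, P_k, P_v, P_o):
    """Step-by-step version; P_q: [d_X, d_k, h], P_o: [d_Y, d_v, h]."""
    Q = np.einsum("nx,xkh->nkh", X, P_q)     # queries
    K = np.einsum("mc,ckh->mkh", M, P_k)     # keys
    V = np.einsum("mc,cvh->mvh", M, P_v)     # values
    L = np.einsum("nkh,mkh->nmh", Q, K)      # per-head logits
    W = softmax(L, axis=1)                   # normalize over memory positions m
    O = np.einsum("nmh,mvh->nvh", W, V)      # per-head weighted values
    return np.einsum("nvh,yvh->ny", O, P_o)  # sum over heads in the output projection

def multi_head_attention_concise(X, M, P_q, P_k, P_v, P_o):
    """Same function written with multi-way einsums."""
    L = np.einsum("nx,mc,xkh,ckh->nmh", X, M, P_q, P_k)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mc,cvh,yvh->ny", W, M, P_v, P_o)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v, h = 3, 5, 8, 6, 7, 4, 2, 4   # illustrative sizes
X = rng.normal(size=(n, d_X))
M = rng.normal(size=(m, d_M))
P_q = rng.normal(size=(d_X, d_k, h))
P_k = rng.normal(size=(d_M, d_k, h))
P_v = rng.normal(size=(d_M, d_v, h))
P_o = rng.normal(size=(d_Y, d_v, h))
Y_steps = multi_head_attention(X, M, P_q, P_k, P_v, P_o)
Y_concise = multi_head_attention_concise(X, M, P_q, P_k, P_v, P_o)
```

The multi-way form is convenient for exposition; the step-by-step form makes the intermediate tensors, and therefore the costs, explicit.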
Note: [Vaswani et al., 2017] include a constant scaling factor on the logits. We omit this in our code, as it can be folded into the linear projections $P_q$ or $P_k$.
4 Talking-Heads Attention
In multi-head attention, the different attention heads perform separate computations, which are then summed at the end. Our new variation, which we call "Talking-Heads Attention", breaks that separation. We insert two additional learned linear projections, $P_l$ and $P_w$, which transform the attention-logits and the attention-weights respectively, moving information across attention heads. (Appendix A presents a variation on this, where the projection matrices themselves are input-dependent.) Instead of one "heads" dimension $h$ across the whole computation, we now have three separate heads dimensions: $h_k$, $h$, and $h_v$, which can optionally differ in size (number of "heads"). $h_k$ refers to the number of attention heads for the keys and the queries, $h$ refers to the number of attention heads for the logits and the weights, and $h_v$ refers to the number of attention heads for the values. The algorithm is shown by the pseudocode below. The costs of the einsum operations are shown in the comments.
def TalkingHeadsAttention(
        X[n, d_X],            # n vectors with dimensionality d_X
        M[m, d_M],            # m vectors with dimensionality d_M
        P_q[d_X, d_k, h_k],   # learned linear projection to produce queries
        P_k[d_M, d_k, h_k],   # learned linear projection to produce keys
        P_v[d_M, d_v, h_v],   # learned linear projection to produce values
        P_o[d_Y, d_v, h_v],   # learned linear projection of output
        P_l[h_k, h],          # talking-heads projection for logits
        P_w[h, h_v]):         # talking-heads projection for weights
    Q[n, d_k, h_k] = einsum(X[n, d_X], P_q[d_X, d_k, h_k])   # queries              n*d_X*d_k*h_k
    K[m, d_k, h_k] = einsum(M[m, d_M], P_k[d_M, d_k, h_k])   # keys                 m*d_M*d_k*h_k
    V[m, d_v, h_v] = einsum(M[m, d_M], P_v[d_M, d_v, h_v])   # values               m*d_M*d_v*h_v
    J[n, m, h_k] = einsum(Q[n, d_k, h_k], K[m, d_k, h_k])    # dot prod.            n*m*d_k*h_k
    L[n, m, h] = einsum(J[n, m, h_k], P_l[h_k, h])           # Talking-heads proj.  n*m*h*h_k
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)          # Attention weights
    U[n, m, h_v] = einsum(W[n, m, h], P_w[h, h_v])           # Talking-heads proj.  n*m*h*h_v
    O[n, d_v, h_v] = einsum(U[n, m, h_v], V[m, d_v, h_v])    #                      n*m*d_v*h_v
    Y[n, d_Y] = einsum(O[n, d_v, h_v], P_o[d_Y, d_v, h_v])   #                      n*d_Y*d_v*h_v
    return Y[n, d_Y]
Again, we can write this more concisely using multi-way einsum operations:
def TalkingHeadsAttentionConcise(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    L[n, m, h] = einsum(X[n, d_X],
                        M[m, d_M],
                        P_q[d_X, d_k, h_k],
                        P_k[d_M, d_k, h_k],
                        P_l[h_k, h])
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
    Y[n, d_Y] = einsum(W[n, m, h],
                       M[m, d_M],
                       P_v[d_M, d_v, h_v],
                       P_o[d_Y, d_v, h_v],
                       P_w[h, h_v])
    return Y[n, d_Y]
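A runnable NumPy sketch of the talking-heads pseudocode is given below, under the assumption that the einsum letters `a` and `b` stand for the h_k and h_v dimensions; the function names and sizes are illustrative. As a sanity check, setting both talking-heads projections to the identity (with h_k = h = h_v) should recover ordinary multi-head attention.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def talking_heads_attention(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    """P_q, P_k: [., d_k, h_k]; P_v, P_o: [., d_v, h_v]; P_l: [h_k, h]; P_w: [h, h_v]."""
    Q = np.einsum("nx,xka->nka", X, P_q)     # queries            (a = h_k)
    K = np.einsum("mc,cka->mka", M, P_k)     # keys
    V = np.einsum("mc,cvb->mvb", M, P_v)     # values             (b = h_v)
    J = np.einsum("nka,mka->nma", Q, K)      # per-head dot products
    L = np.einsum("nma,ah->nmh", J, P_l)     # talking-heads projection on logits
    W = softmax(L, axis=1)                   # normalize over memory positions m
    U = np.einsum("nmh,hb->nmb", W, P_w)     # talking-heads projection on weights
    O = np.einsum("nmb,mvb->nvb", U, V)
    return np.einsum("nvb,yvb->ny", O, P_o)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v, h = 3, 5, 8, 6, 7, 4, 2, 4   # illustrative sizes
X = rng.normal(size=(n, d_X))
M = rng.normal(size=(m, d_M))
P_q = rng.normal(size=(d_X, d_k, h))
P_k = rng.normal(size=(d_M, d_k, h))
P_v = rng.normal(size=(d_M, d_v, h))
P_o = rng.normal(size=(d_Y, d_v, h))

# Identity projections: talking-heads attention reduces to multi-head attention.
Y_th = talking_heads_attention(X, M, P_q, P_k, P_v, P_o, np.eye(h), np.eye(h))
L_mh = np.einsum("nx,mc,xkh,ckh->nmh", X, M, P_q, P_k)
Y_mh = np.einsum("nmh,mc,cvh,yvh->ny", softmax(L_mh, axis=1), M, P_v, P_o)
```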
5 Complexity Analysis
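If we assume that $d_X = d_Y$, then the number of scalar multiplications in multi-head attention is:
$$h \cdot (d_k + d_v) \cdot (n \cdot d_X + m \cdot d_M + n \cdot m)$$
The number of scalar multiplications in talking-heads attention is:
$$(d_k \cdot h_k + d_v \cdot h_v) \cdot (n \cdot d_X + m \cdot d_M + n \cdot m) + n \cdot m \cdot h \cdot (h_k + h_v)$$
The first term in this expression matches up with the cost of multi-head attention. The second term is due to the talking-heads projections. If $h < d_k$ and $h < d_v$, then the costs of the new talking-heads projections, $n \cdot m \cdot h \cdot h_k$ and $n \cdot m \cdot h \cdot h_v$, are less than the existing terms $n \cdot m \cdot d_k \cdot h_k$ and $n \cdot m \cdot d_v \cdot h_v$, respectively.
In practice, the talking-heads projections may be expensive on some neural-network accelerators due to the small dimension sizes involved.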
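The counts can be tallied directly from the per-einsum costs annotated in the pseudocode of sections 3.3 and 4. The helper functions below are a hypothetical sketch of that bookkeeping (their names are ours), assuming d_X = d_Y as in the text; the example call additionally assumes 768-dimensional inputs and outputs, which is our reading of the base configuration.

```python
def multi_head_cost(n, m, d_X, d_M, d_k, d_v, h):
    """Scalar multiplications in multi-head attention, assuming d_X == d_Y."""
    return h * (d_k + d_v) * (n * d_X + m * d_M + n * m)

def talking_heads_cost(n, m, d_X, d_M, d_k, d_v, h_k, h, h_v):
    """Scalar multiplications in talking-heads attention, assuming d_X == d_Y."""
    shared = (d_k * h_k + d_v * h_v) * (n * d_X + m * d_M + n * m)
    projections = n * m * h * (h_k + h_v)   # the two talking-heads projections
    return shared + projections

def talking_heads_params(d_X, d_M, d_Y, d_k, d_v, h_k, h, h_v):
    """Parameter count: P_q, P_k, P_v, P_o plus the two talking-heads matrices."""
    return ((d_X + d_M) * d_k * h_k
            + (d_M + d_Y) * d_v * h_v
            + h_k * h + h * h_v)

# Example with n = m = 512, 12 heads, d_k = d_v = 64 (768-dim inputs assumed).
mh = multi_head_cost(512, 512, 768, 768, 64, 64, 12)
th = talking_heads_cost(512, 512, 768, 768, 64, 64, 12, 12, 12)
relative_overhead = (th - mh) / mh
params = talking_heads_params(768, 768, 768, 64, 64, 12, 12, 12)
```

Under these assumptions the parameter count for h_k = h = h_v = 12 comes out to 2359584, which agrees with the "parameters per att. layer" entry reported for that configuration in table 1.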
6 One More Way To Look At It
Mathematically, one can view multi-head attention and talking-heads attention as two special cases of the same general function, which we will call "general bilinear multi-head attention" (GBMA). GBMA uses two three-dimensional parameter tensors, as defined in the pseudocode below. Due to its high computational cost, GBMA may have no practical use. Multi-head attention is mathematically equivalent to a version of GBMA where each of the two parameter tensors is expressed as the product of two factors, as shown below. Talking-heads attention is mathematically equivalent to a version of GBMA where each of the two parameter tensors is expressed as the product of three factors, as shown below.
def GeneralBilinearMultiheadAttention(
        X[n, d_X],         # n vectors with dimensionality d_X
        M[m, d_M],         # m vectors with dimensionality d_M
        P[d_X, d_M, h],    # learned parameters
        Q[d_M, d_Y, h]):   # learned parameters
    L[n, m, h] = einsum(X[n, d_X], M[m, d_M], P[d_X, d_M, h])
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
    Y[n, d_Y] = einsum(W[n, m, h], M[m, d_M], Q[d_M, d_Y, h])
    return Y[n, d_Y]

def MultiHeadAttentionInefficient(X, M, P_q, P_k, P_v, P_o):
    P[d_X, d_M, h] = einsum(P_q[d_X, d_k, h], P_k[d_M, d_k, h])
    Q[d_M, d_Y, h] = einsum(P_v[d_M, d_v, h], P_o[d_Y, d_v, h])
    return GeneralBilinearMultiheadAttention(X, M, P, Q)

def TalkingHeadsAttentionInefficient(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    P[d_X, d_M, h] = einsum(P_q[d_X, d_k, h_k], P_k[d_M, d_k, h_k], P_l[h_k, h])
    Q[d_M, d_Y, h] = einsum(P_v[d_M, d_v, h_v], P_o[d_Y, d_v, h_v], P_w[h, h_v])
    return GeneralBilinearMultiheadAttention(X, M, P, Q)
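The claimed equivalence can be checked numerically. The sketch below, with illustrative names and sizes of our choosing, builds the two GBMA parameter tensors from the talking-heads factors and compares the result against the direct multi-way talking-heads computation.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gbma(X, M, P, Q):
    """General bilinear multi-head attention; P: [d_X, d_M, h], Q: [d_M, d_Y, h]."""
    L = np.einsum("nx,mc,xch->nmh", X, M, P)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mc,cyh->ny", W, M, Q)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 5, 8, 6, 7, 4, 2   # illustrative sizes
h_k, h, h_v = 2, 4, 3
X = rng.normal(size=(n, d_X))
M = rng.normal(size=(m, d_M))
P_q = rng.normal(size=(d_X, d_k, h_k))
P_k = rng.normal(size=(d_M, d_k, h_k))
P_v = rng.normal(size=(d_M, d_v, h_v))
P_o = rng.normal(size=(d_Y, d_v, h_v))
P_l = rng.normal(size=(h_k, h))
P_w = rng.normal(size=(h, h_v))

# Talking-heads attention expressed as GBMA with three-factor parameter tensors
# (a = h_k, b = h_v in the einsum strings).
P = np.einsum("xka,cka,ah->xch", P_q, P_k, P_l)
Q = np.einsum("cvb,yvb,hb->cyh", P_v, P_o, P_w)
Y_gbma = gbma(X, M, P, Q)

# Direct talking-heads computation with multi-way einsums.
L = np.einsum("nx,mc,xka,cka,ah->nmh", X, M, P_q, P_k, P_l)
Y_direct = np.einsum("nmh,mc,cvb,yvb,hb->ny", softmax(L, axis=1), M, P_v, P_o, P_w)
```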
7 Experiments
7.1 Text-to-Text Transfer Transformer (T5)
We test various configurations of multi-head attention and talking-heads attention on the transfer-learning setup from [Raffel et al., 2019]. An encoder-decoder transformer model [Vaswani et al., 2017] is pre-trained on a denoising objective of predicting missing text segments (average span length 3) from the C4 dataset [Raffel et al., 2019] (this is identical to one of the training objectives described in [Raffel et al., 2019]), and subsequently fine-tuned on various language understanding tasks. We use the same code base and model architecture as the base model from [Raffel et al., 2019]. The encoder and decoder each consist of 12 layers, with $d_{model} = 768$ and $d_{ff} = 3072$. Each encoder layer contains a multi-head self-attention layer, and each decoder layer contains a multi-head self-attention layer and a multi-head attention-over-encoder layer. For their base model, [Raffel et al., 2019] follow [Devlin et al., 2018] and others, using $h = 12$ and $d_k = d_v = 64$ for all of these attention layers. We compare this setting to a variety of other configurations of multi-head and talking-heads attention, as detailed in table 1.
Similar to [Raffel et al., 2019], we pre-train our models for 524288 steps. Each training batch consists of 128 examples, each of which has an input of 512 tokens and an output of 114 tokens, the output containing multiple spans of tokens which were deleted from the input. Similarly to [Raffel et al., 2019], we use the Adafactor optimizer [Shazeer and Stern, 2018] and an inverse-square-root learning-rate schedule. We also decay the learning rate linearly for the final 10 percent of the training steps. Our main departure from [Raffel et al., 2019] is that we, as suggested by [Lan et al., 2019], use no dropout during pre-training. We find this to produce superior results. We compute the log-perplexity on the training objective on a held-out shard of C4, which we believe to be a good indicator of model quality. For each configuration, we train one model for the "full" 524288 steps and four models for a shorter time (65536 steps) to measure inter-run variability. The results are listed in table 1.
We then fine-tune each of the models on an examples-proportional mixture of SQuAD [Rajpurkar et al., 2016], GLUE [Wang et al., 2018] and SuperGLUE [Wang et al., 2019]. Fine-tuning consists of 131072 additional steps with a learning rate of . Following [Raffel et al., 2019], we use a dropout rate on the layer outputs, feed-forward hidden-layers and attention weights. The embedding matrix (also used as the projection in the final classifier layer) is fixed during fine-tuning. Tables 1, 2, 3 and 4 include results for SQuAD and MNLI-m. Results for all other tasks are listed in the appendix.
7.1.1 Multi-Head vs Talking-Heads Attention
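| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| multi-head | | 6 | | 128 | 128 | 2.010 (0.005) | 1.695 | 89.88 | 85.34 | 0.14 | 2359296 | |
| multi-head | | 12 | | 64 | 64 | 1.982 (0.003) | 1.678 | 90.87 | 86.20 | 0.15 | 2359296 | |
| multi-head | | 24 | | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 | |
| multi-head | | 48 | | 16 | 16 | 2.011 (0.004) | 1.682 | 90.35 | 85.32 | 0.21 | 2359296 | |
| talking-heads | 6 | 6 | 6 | 128 | 128 | 1.965 (0.009) | 1.659 | 90.51 | 85.99 | 0.16 | 2359368 | |
| talking-heads | 12 | 12 | 12 | 64 | 64 | 1.932 (0.004) | 1.641 | 91.38 | 86.19 | 0.18 | 2359584 | |
| talking-heads | 24 | 24 | 24 | 32 | 32 | 1.910 (0.001) | 1.624 | 91.83 | 87.42 | 0.22 | 2360448 | |
| talking-heads | 48 | 48 | 48 | 16 | 16 | 1.903 (0.006) | 1.603 | 91.90 | 87.50 | 0.32 | 2363904 | |
| multi-head | | 24 | | 64 | 64 | 1.950 (0.005) | 1.625 | 91.46 | 86.58 | 0.22 | 4718592 | |
| general bilinear | | 12 | | 768 | 768 | 1.921 (0.011) | 1.586 | 90.83 | 86.50 | 0.47 | 14155776 | |
| [Raffel et al., 2019] | | 12 | | 64 | 64 | | | 89.66 | 84.85 | | 2359296 | |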
In table 1, we compare multi-head attention to talking-heads attention. For each of the two algorithms, we test versions with 6, 12, 24 and 48 heads. Following [Vaswani et al., 2017], as we increase the number of heads, we decrease the key/value dimensionality $d_k$ and $d_v$, so as to keep the number of parameters constant. For each number of heads, talking-heads attention improves over multi-head attention on all quality metrics.
Additionally, multi-head attention gets worse as we increase the number of heads from 24 to 48 and decrease the key and value dimensionality from 32 to 16, while talking-heads attention gets better. We presume that this is due to the keys being too short to produce a good matching signal.
For additional comparison, we include in table 1 two models with significantly more parameters and computation in the attention layers. In the first, we double the number of heads in our baseline model from $h = 12$ to $h = 24$ without reducing $d_k$ and $d_v$, resulting in a multi-head attention layer with double the parameters and double the computation. In the second, we use "general bilinear multi-head attention", as described in section 6.
We also list the results from [Raffel et al., 2019]. We believe that their results are worse due to their use of dropout during pre-training.
7.1.2 Varying the Heads-Dimensions Separately
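| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| talking-heads | 6 | 6 | 6 | 128 | 128 | 1.965 (0.009) | 1.659 | 90.51 | 85.99 | 0.16 | 2359368 | |
| talking-heads | 6 | 24 | 6 | 128 | 128 | 1.941 (0.009) | 1.641 | 90.91 | 86.29 | 0.18 | 2359584 | |
| talking-heads | 24 | 6 | 24 | 32 | 32 | 1.959 (0.008) | 1.667 | 90.77 | 86.15 | 0.20 | 2359584 | |
| talking-heads | 6 | 24 | 24 | 128 | 32 | 1.939 (0.011) | 1.633 | 91.06 | 86.31 | 0.20 | 2360016 | |
| talking-heads | 24 | 24 | 6 | 32 | 128 | 1.931 (0.013) | 1.628 | 90.98 | 86.81 | 0.21 | 2360016 | |
| talking-heads | 24 | 24 | 24 | 32 | 32 | 1.910 (0.001) | 1.624 | 91.83 | 87.42 | 0.22 | 2360448 | |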
In table 2, we experiment with independently varying the sizes of the three heads-dimensions. From the results, it appears that all three are good to increase, but that the softmax-heads dimension is particularly important.
7.1.3 Logits-Projection Only and Weights-Projection Only
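| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| multi-head | | 24 | | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 | |
| project logits | 24 | 24 | | 32 | 32 | 1.969 (0.004) | 1.652 | 91.29 | 85.86 | 0.23 | 2359872 | |
| project weights | | 24 | 24 | 32 | 32 | 1.951 (0.009) | 1.636 | 91.03 | 86.12 | 0.23 | 2359872 | |
| talking-heads | 24 | 24 | 24 | 32 | 32 | 1.910 (0.001) | 1.624 | 91.83 | 87.42 | 0.22 | 2360448 | |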
In the middle two experiments of table 3, we examine hybrids of multi-head attention and talking-heads attention, where there is a projection on one but not both of the logits and the weights.
7.1.4 Encoder vs. Decoder
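| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| multi-head | | 24 | | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 | |
| TH-enc-self | 24* | 24 | 24* | 32 | 32 | 1.969 (0.002) | 1.655 | 91.63 | 87.00 | 0.21 | various | various |
| TH-dec-self | 24* | 24 | 24* | 32 | 32 | 1.981 (0.005) | 1.671 | 90.56 | 85.56 | 0.17 | various | various |
| TH-encdec | 24* | 24 | 24* | 32 | 32 | 1.942 (0.003) | 1.646 | 90.86 | 86.07 | 0.18 | various | various |
| talking-heads | 24 | 24 | 24 | 32 | 32 | 1.910 (0.001) | 1.624 | 91.83 | 87.42 | 0.22 | 2360448 | |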
The transformer model contains three types of attention layers: self-attention in the encoder, self-attention in the decoder, and attention-over-encoder in the decoder. In each of the middle three experiments of table 4, we employ talking-heads attention in only one of these types of attention layers, and multi-head attention in the others. We find that modifying the encoder-self-attention layers has the biggest effect on the downstream language-understanding tasks. This is unsurprising, given that these tasks have more to do with analyzing the input than with generating output.
7.2 ALBERT
[Lan et al., 2019] introduce ALBERT, a variation on BERT [Devlin et al., 2018]. The main difference between the ALBERT and BERT architectures is that ALBERT shares layer parameters among all layers, significantly reducing the number of parameters. For example, a 12-layer ALBERT model has about 1/12 the number of parameters in the attention and feed-forward layers as a similar BERT model. Another difference is that the ALBERT model factorizes the word embedding as the product of two matrices with smaller bases, again significantly reducing the parameter count. This makes ALBERT appealing for memory-limited devices such as mobile phones. Besides the above architectural differences, ALBERT also uses sentence order prediction (SOP) to replace next sentence prediction (NSP) in BERT.
We report here experiments done with the base setting for ALBERT: a Transformer network with 12 layers of attention, with the hidden and embedding sizes set to 768. The pre-training and fine-tuning hyperparameter settings are also exactly the same as in [Lan et al., 2019]. We use the English Wikipedia and book corpus datasets [Devlin et al., 2018] to pre-train various models with different head sizes and talking-heads configurations. We evaluate the resulting representations by using them as a starting point to fine-tune for the SQuAD task (SQuAD1.1, SQuAD2.0 dev set) and various tasks (MNLI, SST-2, RACE) from the GLUE benchmark. Results are in Table 5.
| attention | heads | $d_k = d_v$ | SQuAD1.1 (f1) | SQuAD2.0 (f1) | MNLI | SST-2 | RACE | MLM | SOP | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Multi-head | 6 | 128 | 88.5 | 78.8 | 79.9 | 88.6 | 62.7 | 54.3 | 85.9 | 79.7 |
| Multi-head | 12 | 64 | 88.8 | 79.3 | 80.2 | 89.9 | 63.4 | 54.5 | 86.2 | 80.32 |
| Multi-head | 24 | 32 | 88.8 | 79.1 | 79.9 | 87.7 | 62.1 | 54.4 | 85.9 | 79.52 |
| Multi-head | 48 | 16 | 87.9 | 78.8 | 79.6 | 88.4 | 61.8 | 53.8 | 85.3 | 79.3 |
| Talking-heads | 6 | 128 | 88.7 | 78 | 80 | 88.5 | 62 | 54.1 | 85.2 | 79.44 |
| Talking-heads | 12 | 64 | 89.2 | 79.9 | 80.5 | 89 | 65.3 | 54.9 | 87.6 | 80.78 |
| Talking-heads | 24 | 32 | 89.3 | 80.5 | 80.5 | 87.6 | 65.6 | 55.3 | 86.3 | 80.7 |
| Talking-heads | 48 | 16 | 89.6 | 80.9 | 80.9 | 89.3 | 66.5 | 55.7 | 86.5 | 81.44 |
We find that as the number of heads increases beyond 12 and the dimensionality of the attention-keys and attention-values decreases below 64, the performance of multi-head attention decays. On the other hand, the performance of talking-heads attention keeps improving.
In addition, we also compare the logits projection and the weight projection separately with multi-head and talking-heads attention. The results are shown in Table 6. Similar to our observation in the T5 experiments, applying only the logits projection or only the weight projection does not result in a significant improvement over using neither. These results again confirm the importance of having both projections.
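| attention | heads | $d_k = d_v$ | SQuAD1.1 (f1) | SQuAD2.0 (f1) | MNLI | SST-2 | RACE | MLM | SOP | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Multi-head | 12 | 64 | 88.8 | 79.3 | 80.2 | 89.9 | 63.4 | 54.5 | 86.2 | 80.32 |
| Logit-project-only | 12 | 64 | 88.5 | 78.8 | 79.8 | 89.3 | 63 | 54.6 | 85.8 | 79.88 |
| Weight-project-only | 12 | 64 | 88.9 | 79.6 | 80.3 | 89 | 64 | 54.7 | 85.8 | 80.36 |
| Talking-heads | 12 | 64 | 89.2 | 79.9 | 80.5 | 89 | 65.3 | 54.9 | 87.6 | 80.78 |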
7.3 BERT
We test various configurations of talking-heads attention based on [Devlin et al., 2018]. All of our experiments use the simplified relative position embeddings [Raffel et al., 2019] instead of fixed position embeddings. We first pre-train a 12-layer Transformer using the same dataset as [Devlin et al., 2018], and then fine-tune on the SQuAD1.1 task and the MNLI task from the GLUE dataset. Our experiments show that quality continues to improve when we grow the number of heads up to 768 and decrease the key and value dimensionality down to 1. (These extreme hyperparameter settings likely have no practical use, due to the massive amount of computation.)
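| attention | heads | $d_k = d_v$ | SQuAD1.1 (f1) | MNLI |
| --- | --- | --- | --- | --- |
| Multi-head | 12 | 64 | 88.51 | 82.6 |
| Talking-heads | 6 | 128 | 88.8 | 83.4 |
| Talking-heads | 12 | 64 | 89.2 | 83.6 |
| Talking-heads | 24 | 32 | 89.4 | 83.6 |
| Talking-heads | 48 | 16 | 89.5 | 83.4 |
| Talking-heads | 64 | 12 | 89.9 | 83.8 |
| Talking-heads | 96 | 8 | 89.3 | 83.6 |
| Talking-heads | 192 | 4 | 89.8 | 83.9 |
| Talking-heads | 384 | 2 | 90.5 | 83.9 |
| Talking-heads | 768 | 1 | 90.5 | 84.2 |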
7.4 Visualizing the Projection Matrices of Talking-Heads
To illustrate how different heads exchange information with each other, we visualize the projection matrices ($P_l$ and $P_w$) of a 12-layer BERT with 12 talking-heads in figure 1. Since $P_w$ is applied after $P_l$ (although there is a softmax non-linearity in between), we also visualize the combined transformation in figure 1. As can be observed, the main diagonals of the projection matrices do not have significantly greater values than other entries. This is expected because with talking-heads, a pair of query and key does not correspond to any specific value-vector. All keys and queries jointly decide how the values in each head interchange data. Additionally, all projection matrices are well conditioned (magnitude of determinant above with smallest eigenvalue above ), indicating that no significant approximation can be achieved.
8 Conclusions and Future Work
We have proposed talking-heads attention and shown some promising results. One potential challenge is speed on modern deep-learning accelerators, which are optimized for large-dimension matrix multiplications. We imagine that this will be an area of future work. One approach is to build hardware which is better at small-dimension matrix-multiplication. Another potential approach is to decrease the number of memory-positions considered for each query-position, for example by using the local-attention and memory-compressed-attention approaches described in [Liu et al., 2018]. We look forward to more applications of talking-heads attention, as well as to further architectural improvements.
References
Bahdanau et al. [2014] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate, 2014.
Devlin et al. [2018] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
Lan et al. [2019] Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. ALBERT: A lite BERT for self-supervised learning of language representations, 2019.
Liu et al. [2018] Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating Wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018.
Raffel et al. [2019] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv e-prints, 2019.
Rajpurkar et al. [2016] Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250, 2016.
Shazeer and Stern [2018] Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. arXiv preprint arXiv:1804.04235, 2018.
Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.
Wang et al. [2018] Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R. Bowman. GLUE: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461, 2018.
Wang et al. [2019] Alex Wang, Yada Pruksachatkun, Nikita Nangia, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R. Bowman. SuperGLUE: A stickier benchmark for general-purpose language understanding systems. arXiv preprint arXiv:1905.00537, 2019.
Appendix A Variation: Dynamic Projections
In the basic talking-heads attention algorithm described in section 4, the talking-heads projections are represented by two learned weight matrices $P_l$ and $P_w$. In an additional wrinkle, we can make these projection matrices themselves input-dependent, adding terms to the projection matrices that are themselves learned linear projections of the inputs $X$ and $M$. The algorithm is described by the pseudocode below.
def TalkingHeadsAttentionWithDynamicProjections(
        X[n, d_X],            # n vectors with dimensionality d_X
        M[m, d_M],            # m vectors with dimensionality d_M
        P_q[d_X, d_k, h_k],   # learned linear projection to produce queries
        P_k[d_M, d_k, h_k],   # learned linear projection to produce keys
        P_v[d_M, d_v, h_v],   # learned linear projection to produce values
        P_o[d_Y, d_v, h_v],   # learned linear projection of output
        P_l[h_k, h],          # learned static talking-heads proj. on logits
        P_Xl[d_X, h_k, h],    # learned projection to generate dynamic talking-heads projection
        P_Ml[d_M, h_k, h],    # learned projection to generate dynamic talking-heads projection
        P_w[h, h_v],          # learned static talking-heads proj. on weights
        P_Xw[d_X, h, h_v],    # learned projection to generate dynamic talking-heads projection
        P_Mw[d_M, h, h_v]):   # learned projection to generate dynamic talking-heads projection
    Q[n, d_k, h_k] = einsum(X[n, d_X], P_q[d_X, d_k, h_k])    # queries        n*d_X*d_k*h_k
    K[m, d_k, h_k] = einsum(M[m, d_M], P_k[d_M, d_k, h_k])    # keys           m*d_M*d_k*h_k
    V[m, d_v, h_v] = einsum(M[m, d_M], P_v[d_M, d_v, h_v])    # values         m*d_M*d_v*h_v
    J[n, m, h_k] = einsum(Q[n, d_k, h_k], K[m, d_k, h_k])     # dot prod.      n*m*d_k*h_k
    R_Xl[n, h_k, h] = einsum(X[n, d_X], P_Xl[d_X, h_k, h])    # dynamic proj.  n*d_X*h_k*h
    R_Ml[m, h_k, h] = einsum(M[m, d_M], P_Ml[d_M, h_k, h])    # dynamic proj.  m*d_M*h_k*h
    L[n, m, h] = (
        einsum(J[n, m, h_k], P_l[h_k, h]) +        # Static talking-heads proj.   n*m*h*h_k
        einsum(J[n, m, h_k], R_Xl[n, h_k, h]) +    # Dynamic talking-heads proj.  n*m*h*h_k
        einsum(J[n, m, h_k], R_Ml[m, h_k, h]))     # Dynamic talking-heads proj.  n*m*h*h_k
    W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)           # Attention weights
    R_Xw[n, h, h_v] = einsum(X[n, d_X], P_Xw[d_X, h, h_v])    # dynamic proj.  n*d_X*h*h_v
    R_Mw[m, h, h_v] = einsum(M[m, d_M], P_Mw[d_M, h, h_v])    # dynamic proj.  m*d_M*h*h_v
    U[n, m, h_v] = (
        einsum(W[n, m, h], P_w[h, h_v]) +          # Static talking-heads proj.   n*m*h*h_v
        einsum(W[n, m, h], R_Xw[n, h, h_v]) +      # Dynamic talking-heads proj.  n*m*h*h_v
        einsum(W[n, m, h], R_Mw[m, h, h_v]))       # Dynamic talking-heads proj.  n*m*h*h_v
    O[n, d_v, h_v] = einsum(U[n, m, h_v], V[m, d_v, h_v])     #                n*m*d_v*h_v
    Y[n, d_Y] = einsum(O[n, d_v, h_v], P_o[d_Y, d_v, h_v])    #                n*d_Y*d_v*h_v
    return Y[n, d_Y]
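A NumPy sketch of the dynamic-projection variant follows. It assumes that the memory-derived dynamic projections R_Ml and R_Mw are indexed by the memory position m and that P_Mw has leading dimension d_M, which is what the einsums that consume them require. The function name, the sizes, and the standard deviation of 0.01 used to draw the projection-generating parameters are illustrative assumptions only.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def talking_heads_dynamic(X, M, P_q, P_k, P_v, P_o,
                          P_l, P_Xl, P_Ml, P_w, P_Xw, P_Mw):
    """Einsum letters: a = h_k, b = h_v, x = d_X, c = d_M, y = d_Y."""
    Q = np.einsum("nx,xka->nka", X, P_q)
    K = np.einsum("mc,cka->mka", M, P_k)
    V = np.einsum("mc,cvb->mvb", M, P_v)
    J = np.einsum("nka,mka->nma", Q, K)
    R_Xl = np.einsum("nx,xah->nah", X, P_Xl)      # query-dependent logits projection
    R_Ml = np.einsum("mc,cah->mah", M, P_Ml)      # memory-dependent logits projection
    L = (np.einsum("nma,ah->nmh", J, P_l)         # static term
         + np.einsum("nma,nah->nmh", J, R_Xl)     # dynamic terms
         + np.einsum("nma,mah->nmh", J, R_Ml))
    W = softmax(L, axis=1)
    R_Xw = np.einsum("nx,xhb->nhb", X, P_Xw)      # query-dependent weights projection
    R_Mw = np.einsum("mc,chb->mhb", M, P_Mw)      # memory-dependent weights projection
    U = (np.einsum("nmh,hb->nmb", W, P_w)
         + np.einsum("nmh,nhb->nmb", W, R_Xw)
         + np.einsum("nmh,mhb->nmb", W, R_Mw))
    O = np.einsum("nmb,mvb->nvb", U, V)
    return np.einsum("nvb,yvb->ny", O, P_o)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 5, 8, 6, 7, 4, 2   # illustrative sizes
h_k, h, h_v = 2, 4, 3
X = rng.normal(size=(n, d_X))
M = rng.normal(size=(m, d_M))
P_q = rng.normal(size=(d_X, d_k, h_k))
P_k = rng.normal(size=(d_M, d_k, h_k))
P_v = rng.normal(size=(d_M, d_v, h_v))
P_o = rng.normal(size=(d_Y, d_v, h_v))
P_l = rng.normal(size=(h_k, h))
P_w = rng.normal(size=(h, h_v))
small = 0.01  # illustrative small initializer scale, not a value from the paper
P_Xl = small * rng.normal(size=(d_X, h_k, h))
P_Ml = small * rng.normal(size=(d_M, h_k, h))
P_Xw = small * rng.normal(size=(d_X, h, h_v))
P_Mw = small * rng.normal(size=(d_M, h, h_v))
Y = talking_heads_dynamic(X, M, P_q, P_k, P_v, P_o,
                          P_l, P_Xl, P_Ml, P_w, P_Xw, P_Mw)
```

With all four projection-generating tensors set to zero, the function reduces to talking-heads attention with static projections.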
We observed that the model only trained well if we initialized the projection-generating parameter matrices ($P_{Xl}$, $P_{Ml}$, $P_{Xw}$, $P_{Mw}$) to contain small enough values. We used normal initializers with standard deviations of , , , and , respectively.
| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| multi-head | | 12 | | 64 | 64 | 1.982 (0.003) | 1.678 | 90.87 | 86.20 | 0.15 | 2359296 | |
| talking-heads | 12 | 12 | 12 | 64 | 64 | 1.932 (0.004) | 1.641 | 91.38 | 86.19 | 0.18 | 2359584 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.897 (0.007) | 1.595 | 90.17 | 86.18 | 0.36 | 2801952 | |
| multi-head | | 24 | | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 | |
| talking-heads | 24 | 24 | 24 | 32 | 32 | 1.910 (0.001) | 1.624 | 91.83 | 87.42 | 0.22 | 2360448 | |
| dyn. proj. | 24 | 24 | 24 | 32 | 32 | 1.873 (0.008) | 1.587 | 90.17 | 85.94 | 0.53 | 4129920 | |
A.1 Experiments
We evaluate talking-heads attention with dynamic projections on T5 [Raffel et al., 2019] in a set of experiments similar to those described in section 7.1.
Table 8 compares multi-head attention, talking-heads attention with static projections, and talking-heads attention with dynamic projections. The dynamic projections reduce perplexity on the pre-training task. However, in our experiments, we did not see an improvement on the downstream tasks.
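| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| talking-heads | 12 | 12 | 12 | 64 | 64 | 1.932 (0.004) | 1.641 | 91.38 | 86.19 | 0.18 | 2359584 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.932 (0.011) | 1.634 | 91.34 | 86.32 | 0.19 | 2470176 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.914 (0.005) | 1.619 | 90.70 | 86.43 | 0.19 | 2470176 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.930 (0.010) | 1.624 | 91.14 | 86.63 | 0.24 | 2470176 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.917 (0.003) | 1.624 | 90.54 | 86.45 | 0.25 | 2470176 | |
| dyn. proj. | 12 | 12 | 12 | 64 | 64 | 1.897 (0.007) | 1.595 | 90.17 | 86.18 | 0.36 | 2801952 | |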
A.1.1 Comparing the Four Dynamic Projections
Table 9 examines the effects of the four dynamic projections employed individually. The middle four rows represent experiments where only one of the four dynamic projections was employed. These are compared to static projections (top row) and all four dynamic projections together (bottom row).
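| attention | $h_k$ | $h$ | $h_v$ | $d_k$ | $d_v$ | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| multi-head | | 24 | | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 | |
| TH-enc-self | 24* | 24 | 24* | 32 | 32 | 1.969 (0.002) | 1.655 | 91.63 | 87.00 | 0.21 | various | various |
| DP-enc-self | 24* | 24 | 24* | 32 | 32 | 1.953 (0.006) | 1.639 | 91.99 | 86.97 | 0.42 | various | various |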
A.1.2 Talking-Heads in Encoder Only
In section 7.1.4 we saw that talking heads were particularly useful in the encoder part of the model. Table 10 presents a set of experiments where the decoder uses only multi-head attention, while the encoder uses either multi-head attention (top row), talking-heads attention with static projections (middle row), or talking-heads attention with dynamic projections (bottom row). We observe that in this case, the dynamic projections do not appear to degrade performance on the downstream tasks.
Appendix B T5 Fine-Tuning Full Results
Tables 11, 12 and 13 present the results of fine-tuning the models in section 7.1 and appendix A on the GLUE [Wang et al., 2018], SuperGLUE [Wang et al., 2019], and Stanford Question-Answering Dataset (SQuAD) [Rajpurkar et al., 2016] benchmarks.
Score  CoLA  SST2  MRPC  MRPC  STSB  STSB  QQP  QQP  MNLIm  MNLImm  QNLI  RTE  
Average  MCC  Acc  F1  Acc  PCC  SCC  F1  Acc  Acc  Acc  Acc  Acc  
multihead  6  128  128  
multihead  12  64  64  
multihead  24  32  32  93.59  91.18  
multihead  48  16  16  93.59  91.18  
talkingheads  6  6  6  128  128  
talkingheads  12  12  12  64  64  
talkingheads  24  24  24  32  32  89.19  87.42  
talkingheads  48  48  48  16  16  94.61  91.16  91.00  
multihead  24  64  64  55.99  
general bilinear  12  768  768  91.98  
talkingheads  6  6  6  128  128  
talkingheads  6  24  6  128  128  
talkingheads  24  6  24  32  32  
talkingheads  6  24  24  128  32  
talkingheads  24  24  6  32  128  84.85  84.84  
talkingheads  24  24  24  32  32  89.19  87.42  
multihead  24  32  32  93.59  91.18  
project logits  24  24  32  32  
project weights  24  24  32  32  
talkingheads  24  24  24  32  32  89.19  87.42  
multihead  24  32  32  93.59  91.18  
THencself  24*  24  24*  32  32  
THdecself  24*  24  24*  32  32  
THencdec  24*  24  24*  32  32  
talkingheads  24  24  24  32  32  89.19  87.42  
multihead  12  64  64  
talkingheads  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
multihead  24  32  32  93.59  91.18  
talkingheads  24  24  24  32  32  89.19  87.42  
dyn. proj.  24  24  24  32  32  
talkingheads  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
multihead  24  32  32  93.59  91.18  
THencself  24*  24  24*  32  32  
DPencself  24*  24  24*  32  32  87.52  93.48  
[Raffel et al., 2019]  12  64  64  
ibid. stddev.  

Score  BoolQ  CB  CB  CoPA  MultiRC  MultiRC  ReCoRD  ReCoRD  RTE  WiC  WSC  
Average  Acc  F1  Acc  Acc  F1  EM  F1  EM  Acc  Acc  Acc  
multihead  6  128  128  
multihead  12  64  64  
multihead  24  32  32  
multihead  48  16  16  
talkingheads  6  6  6  128  128  
talkingheads  12  12  12  64  64  
talkingheads  24  24  24  32  32  90.60  94.64  
talkingheads  48  48  48  16  16  76.39  82.94  87.50  
multihead  24  64  64  
general bilinear  12  768  768  
talkingheads  6  6  6  128  128  
talkingheads  6  24  6  128  128  
talkingheads  24  6  24  32  32  71.32  
talkingheads  6  24  24  128  32  
talkingheads  24  24  6  32  128  76.00  86.28  87.50  
talkingheads  24  24  24  32  32  90.60  94.64  
multihead  24  32  32  
project logits  24  24  32  32  
project weights  24  24  32  32  
talkingheads  24  24  24  32  32  90.60  94.64  
multihead  24  32  32  
THencself  24*  24  24*  32  32  78.18  43.55  
THdecself  24*  24  24*  32  32  
THencdec  24*  24  24*  32  32  
talkingheads  24  24  24  32  32  90.60  94.64  
multihead  12  64  64  
talkingheads  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
multihead  24  32  32  
talkingheads  24  24  24  32  32  90.60  94.64  
dyn. proj.  24  24  24  32  32  
talkingheads  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
dyn. proj.  12  12  12  64  64  
multihead  24  32  32  
THencself  24*  24  24*  32  32  78.18  43.55  
DPencself  24*  24  24*  32  32  77.88  76.99  
[Raffel et al., 2019]  12  64  64  
ibid. stddev. 