Talking-Heads Attention

03/05/2020 ∙ by Noam Shazeer, et al. ∙ Google

We introduce "talking-heads attention" - a variation on multi-head attention which includes linear projections across the attention-heads dimension, immediately before and after the softmax operation. While inserting only a small number of additional parameters and a moderate amount of additional computation, talking-heads attention leads to better perplexities on masked language modeling tasks, as well as better quality when transfer-learning to language comprehension and question answering tasks.


1 Introduction

Neural Attention was introduced by [Bahdanau et al., 2014] as a way of extracting information from variable-length representations. The Transformer model [Vaswani et al., 2017] uses "multi-head" attention, consisting of multiple attention layers ("heads") in parallel, each with different projections on its inputs and outputs. By using a dimensionality reduction in the input projections, the computational cost is kept similar to that of basic attention. Quality is improved, presumably due to the ability to attend to multiple positions simultaneously based on multiple different types of relationships.

As noted in [Vaswani et al., 2017] (footnote 1: Section (A) of table 3 in [Vaswani et al., 2017]; also the first sections of tables 1 and 5 of this paper), taking this process to the extreme (more attention heads projected to lower dimensionality) becomes counterproductive. We believe this is because the query-vectors and key-vectors become so low-dimensional that their dot product can no longer constitute an informative matching function.

In this paper, we introduce a new variant, "talking-heads attention", that addresses this problem by inserting a learned linear projection across the attention-heads dimension of the attention-logits tensor. This allows each attention function to depend on all of the keys and queries. We also insert a second such projection immediately following the softmax.

We show experimentally that inserting these "talking-heads" projections leads to better perplexities on masked language modeling tasks, as well as better quality when transfer-learning to language comprehension and question answering tasks.

2 Notation

In our pseudocode, we use capital letters to represent tensors and lower-case letters to represent their dimensions. Each tensor is followed by a dimension list in brackets. For example, a 4-dimensional image-tensor with (batch, height, width, channels) dimensions would be written as:

  X[b, h, w, c]

We use einsum notation for generalized contractions between tensors of arbitrary dimension. The computation is numerically equivalent to broadcasting each input to have the union of all dimensions, multiplying component-wise, and summing across all dimensions not in the output. Rather than identifying the dimensions by an equation, as in TensorFlow and numpy, the dimensions are identified by the dimension-list annotations on the arguments and on the result. For example, multiplying two matrices would be expressed as:


  Y[a, c] = einsum(X[a, b], W[b, c])
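The annotation style above corresponds closely to NumPy's `np.einsum`, except that NumPy identifies dimensions with an equation string. As a minimal sketch (the single-letter subscripts are our own choice), the matrix-multiplication example can be checked against the broadcast-multiply-sum definition:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))  # X[a, b]
W = rng.standard_normal((4, 5))  # W[b, c]

# Paper notation: Y[a, c] = einsum(X[a, b], W[b, c]).
# NumPy spells the same contraction with an equation string.
Y = np.einsum("ab,bc->ac", X, W)

# Equivalent definition: broadcast both inputs to the union of dimensions
# [a, b, c], multiply component-wise, then sum over b, the only dimension
# missing from the output.
Y_broadcast = (X[:, :, None] * W[None, :, :]).sum(axis=1)

print(np.allclose(Y, Y_broadcast), np.allclose(Y, X @ W))  # True True
```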

3 Review of Attention Algorithms

3.1 Dot-Product Attention

Simple dot-product attention can be described by the pseudocode below. The logits L are computed as the dot-products of the query-vectors and the memory-vectors. For each query, the logits are passed through a softmax function to produce weights, and the different memory-vectors are averaged together, weighted by those weights. In this code, we show the case where there are n different queries all attending to the same memory-vectors. If there is only one query, the code is identical except that the "n" dimension is removed from all tensors.
def DotProductAttention(
    X[n, d],   # n query-vectors with dimensionality d
    M[m, d]):  # m memory-vectors with dimensionality d
  L[n, m] = einsum(X[n, d], M[m, d])           # Attention logits
  W[n, m] = softmax(L[n, m], reduced_dim=m)    # Attention weights
  Y[n, d] = einsum(W[n, m], M[m, d])
  return Y[n, d]
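For readers who want to execute the pseudocode, here is a minimal NumPy sketch of DotProductAttention; the helper `softmax` and the function name `dot_product_attention` are our own, and the max-subtraction inside `softmax` is a standard stability trick that does not change the result.

```python
import numpy as np

def softmax(x, axis):
    # Subtracting the max avoids overflow in exp; the output is unchanged.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(X, M):
    """X: [n, d] query-vectors, M: [m, d] memory-vectors -> Y: [n, d]."""
    L = np.einsum("nd,md->nm", X, M)     # attention logits
    W = softmax(L, axis=1)               # weights, normalized over m
    return np.einsum("nm,md->nd", W, M)  # weighted average of memory-vectors

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 8))  # n = 2 queries
M = rng.standard_normal((5, 8))  # m = 5 memory-vectors
Y = dot_product_attention(X, M)
print(Y.shape)  # (2, 8)
```

Each output row is a convex combination of the memory-vectors, because each row of the weights sums to 1.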

3.2 Dot-Product Attention With Projections

[Vaswani et al., 2017] propose a dimensionality-reduction to reduce the computational complexity of the attention algorithm. In this version, instead of computing the attention algorithm directly on the inputs X and M, we first project the inputs using the learned linear projections P_q, P_k and P_v, to produce lower-dimensional query-vectors, key-vectors and value-vectors Q, K and V. We use a fourth learned linear projection, P_o, to produce the output.

def DotProductAttentionWithProjections(
    X[n, d_X],      # n vectors with dimensionality d_X
    M[m, d_M],      # m vectors with dimensionality d_M
    P_q[d_X, d_k],  # learned linear projection to produce queries
    P_k[d_M, d_k],  # learned linear projection to produce keys
    P_v[d_M, d_v],  # learned linear projection to produce values
    P_o[d_Y, d_v]): # learned linear projection of output
  Q[n, d_k] = einsum(X[n, d_X], P_q[d_X, d_k])  # queries
  K[m, d_k] = einsum(M[m, d_M], P_k[d_M, d_k])  # keys
  V[m, d_v] = einsum(M[m, d_M], P_v[d_M, d_v])  # values
  L[n, m] = einsum(Q[n, d_k], K[m, d_k])     # Attention logits
  W[n, m] = softmax(L[n, m], reduced_dim=m)  # Attention weights
  O[n, d_v] = einsum(W[n, m], V[m, d_v])
  Y[n, d_Y] = einsum(O[n, d_v], P_o[d_Y, d_v])
  return Y[n, d_Y]
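A minimal NumPy sketch of DotProductAttentionWithProjections follows. The einsum subscripts (x for d_X, z for d_M, y for d_Y, k for d_k, v for d_v) and the small random sizes are illustrative choices; note that X and M may now have different widths, since each is projected separately.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention_with_projections(X, M, P_q, P_k, P_v, P_o):
    # Subscripts: x = d_X, z = d_M, y = d_Y, k = d_k, v = d_v.
    Q = np.einsum("nx,xk->nk", X, P_q)      # queries  [n, d_k]
    K = np.einsum("mz,zk->mk", M, P_k)      # keys     [m, d_k]
    V = np.einsum("mz,zv->mv", M, P_v)      # values   [m, d_v]
    L = np.einsum("nk,mk->nm", Q, K)        # logits   [n, m]
    W = softmax(L, axis=1)                  # weights  [n, m]
    O = np.einsum("nm,mv->nv", W, V)        # [n, d_v]
    return np.einsum("nv,yv->ny", O, P_o)   # P_o is [d_Y, d_v] -> Y: [n, d_Y]

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 4, 6, 5, 7, 2, 3
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
P_q = rng.standard_normal((d_X, d_k))
P_k = rng.standard_normal((d_M, d_k))
P_v = rng.standard_normal((d_M, d_v))
P_o = rng.standard_normal((d_Y, d_v))
Y = dot_product_attention_with_projections(X, M, P_q, P_k, P_v, P_o)
print(Y.shape)  # (3, 7)
```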

3.3 Multi-Head Attention

The multi-head attention described in [Vaswani et al., 2017] consists of the sum of multiple parallel attention layers. This can be represented by adding a "heads" dimension h to the above computation.
def MultiHeadAttention(
    X[n, d_X],         # n vectors with dimensionality d_X
    M[m, d_M],         # m vectors with dimensionality d_M
    P_q[d_X, d_k, h],  # learned linear projection to produce queries
    P_k[d_M, d_k, h],  # learned linear projection to produce keys
    P_v[d_M, d_v, h],  # learned linear projection to produce values
    P_o[d_Y, d_v, h]): # learned linear projection of output
  Q[n, d_k, h] = einsum(X[n, d_X], P_q[d_X, d_k, h])  # queries   h*n*d_X*d_k
  K[m, d_k, h] = einsum(M[m, d_M], P_k[d_M, d_k, h])  # keys      h*m*d_M*d_k
  V[m, d_v, h] = einsum(M[m, d_M], P_v[d_M, d_v, h])  # values    h*m*d_M*d_v
  L[n, m, h] = einsum(Q[n, d_k, h], K[m, d_k, h])     # logits    h*n*m*d_k
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)     # weights
  O[n, d_v, h] = einsum(W[n, m, h], V[m, d_v, h])     #           h*n*m*d_v
  Y[n, d_Y] = einsum(O[n, d_v, h], P_o[d_Y, d_v, h])  # output    h*n*d_Y*d_v
  return Y[n, d_Y]
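The following NumPy sketch of MultiHeadAttention (our own function names and toy sizes) also checks the statement that multi-head attention is the sum of parallel attention layers: running each head as an independent projected attention layer and adding the outputs gives the same result.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, M, P_q, P_k, P_v, P_o):
    # Subscripts: x = d_X, z = d_M, y = d_Y, k = d_k, v = d_v, h = heads.
    Q = np.einsum("nx,xkh->nkh", X, P_q)     # queries  [n, d_k, h]
    K = np.einsum("mz,zkh->mkh", M, P_k)     # keys     [m, d_k, h]
    V = np.einsum("mz,zvh->mvh", M, P_v)     # values   [m, d_v, h]
    L = np.einsum("nkh,mkh->nmh", Q, K)      # logits   [n, m, h]
    W = softmax(L, axis=1)                   # normalized over m, per head
    O = np.einsum("nmh,mvh->nvh", W, V)      # [n, d_v, h]
    return np.einsum("nvh,yvh->ny", O, P_o)  # summed over heads -> [n, d_Y]

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v, h = 3, 4, 6, 5, 7, 2, 3, 4
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
P_q = rng.standard_normal((d_X, d_k, h))
P_k = rng.standard_normal((d_M, d_k, h))
P_v = rng.standard_normal((d_M, d_v, h))
P_o = rng.standard_normal((d_Y, d_v, h))
Y = multi_head_attention(X, M, P_q, P_k, P_v, P_o)

# Sum of h independent single-head attention layers.
Y_sum = np.zeros((n, d_Y))
for i in range(h):
    Q_i, K_i, V_i = X @ P_q[:, :, i], M @ P_k[:, :, i], M @ P_v[:, :, i]
    W_i = softmax(Q_i @ K_i.T, axis=1)
    Y_sum += W_i @ V_i @ P_o[:, :, i].T
print(np.allclose(Y, Y_sum))  # True
```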

The pseudo-code above illustrates the practical step-by-step computation of multi-head attention. The costs of the einsum operations (the number of multiplications in a naive implementation) are shown in the comments. The equivalent pseudo-code below uses multi-way einsums and is more concise:
def MultiHeadAttentionConcise(X, M, P_q, P_k, P_v, P_o):
  L[n, m, h] = einsum(X[n, d_X],
                      M[m, d_M],
                      P_q[d_X, d_k, h],
                      P_k[d_M, d_k, h])
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
  Y[n, d_Y] = einsum(W[n, m, h],
                     M[m, d_M],
                     P_v[d_M, d_v, h],
                     P_o[d_Y, d_v, h])
  return Y[n, d_Y]
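NumPy's `np.einsum` also accepts more than two operands, so the concise form can be written almost verbatim. This sketch (function names and sizes are ours) checks it against a step-by-step version. We pass `optimize=True` so NumPy chooses a pairwise contraction order; without it, a multi-operand einsum may be evaluated naively and can be much slower at realistic sizes.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_concise(X, M, P_q, P_k, P_v, P_o):
    # One multi-operand contraction on each side of the softmax.
    L = np.einsum("nx,mz,xkh,zkh->nmh", X, M, P_q, P_k, optimize=True)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mz,zvh,yvh->ny", W, M, P_v, P_o, optimize=True)

def multi_head_attention_steps(X, M, P_q, P_k, P_v, P_o):
    Q = np.einsum("nx,xkh->nkh", X, P_q)
    K = np.einsum("mz,zkh->mkh", M, P_k)
    V = np.einsum("mz,zvh->mvh", M, P_v)
    W = softmax(np.einsum("nkh,mkh->nmh", Q, K), axis=1)
    O = np.einsum("nmh,mvh->nvh", W, V)
    return np.einsum("nvh,yvh->ny", O, P_o)

rng = np.random.default_rng(1)
n, m, d_X, d_M, d_Y, d_k, d_v, h = 3, 4, 6, 5, 7, 2, 3, 4
args = (rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M)),
        rng.standard_normal((d_X, d_k, h)), rng.standard_normal((d_M, d_k, h)),
        rng.standard_normal((d_M, d_v, h)), rng.standard_normal((d_Y, d_v, h)))
Y_concise = multi_head_attention_concise(*args)
Y_steps = multi_head_attention_steps(*args)
print(np.allclose(Y_concise, Y_steps))  # True
```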

Note: [Vaswani et al., 2017] include a constant scaling factor on the logits. We omit this in our code, as it can be folded into the linear projections P_q or P_k.
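Because the logits are linear in the query projection, scaling the logits by the factor 1/sqrt(d_k) used by [Vaswani et al., 2017] gives the same values as scaling P_q itself. A minimal sketch with toy sizes of our own choosing:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_k = 3, 4, 6, 5, 8
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
P_q = rng.standard_normal((d_X, d_k))
P_k = rng.standard_normal((d_M, d_k))
scale = 1.0 / np.sqrt(d_k)

# Logits with an explicit scaling factor.
L_scaled = scale * np.einsum("nk,mk->nm", X @ P_q, M @ P_k)
# Same logits with the factor folded into the query projection P_q.
L_folded = np.einsum("nk,mk->nm", X @ (scale * P_q), M @ P_k)
print(np.allclose(L_scaled, L_folded))  # True
```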

4 Talking-Heads Attention

In multi-head attention, the different attention heads perform separate computations, which are then summed at the end. Our new variation, which we call "Talking-Heads Attention", breaks that separation. We insert two additional learned linear projections, P_l and P_w, which transform the attention-logits and the attention-weights respectively, moving information across attention heads (footnote 2: Appendix A presents a variation on this, where the projection matrices themselves are input-dependent). Instead of one "heads" dimension across the whole computation, we now have three separate heads dimensions: h_k, h, and h_v, which can optionally differ in size (number of "heads"). h_k refers to the number of attention heads for the keys and the queries, h refers to the number of attention heads for the logits and the weights, and h_v refers to the number of attention heads for the values. The algorithm is shown by the pseudo-code below. The costs of the einsum operations are shown in the comments.
def TalkingHeadsAttention(
    X[n, d_X],           # n vectors with dimensionality d_X
    M[m, d_M],           # m vectors with dimensionality d_M
    P_q[d_X, d_k, h_k],  # learned linear projection to produce queries
    P_k[d_M, d_k, h_k],  # learned linear projection to produce keys
    P_v[d_M, d_v, h_v],  # learned linear projection to produce values
    P_o[d_Y, d_v, h_v],  # learned linear projection of output
    P_l[h_k, h],         # talking-heads projection for logits
    P_w[h, h_v]):        # talking-heads projection for weights
  Q[n, d_k, h_k] = einsum(X[n, d_X], P_q[d_X, d_k, h_k])  # queries        n*d_X*d_k*h_k
  K[m, d_k, h_k] = einsum(M[m, d_M], P_k[d_M, d_k, h_k])  # keys           m*d_M*d_k*h_k
  V[m, d_v, h_v] = einsum(M[m, d_M], P_v[d_M, d_v, h_v])  # values         m*d_M*d_v*h_v
  J[n, m, h_k] = einsum(Q[n, d_k, h_k], K[m, d_k, h_k])   # dot prod.      n*m*d_k*h_k
  L[n, m, h] = einsum(J[n, m, h_k], P_l[h_k, h])    # Talking-heads proj.  n*m*h*h_k
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)   # Attention weights
  U[n, m, h_v] = einsum(W[n, m, h], P_w[h, h_v])    # Talking-heads proj.  n*m*h*h_v
  O[n, d_v, h_v] = einsum(U[n, m, h_v], V[m, d_v, h_v])   #                n*m*d_v*h_v
  Y[n, d_Y] = einsum(O[n, d_v, h_v], P_o[d_Y, d_v, h_v])  #                n*d_Y*d_v*h_v
  return Y[n, d_Y]

Again, we can write this more concisely using multi-way einsum operations:
def TalkingHeadsAttentionConcise(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
  L[n, m, h] = einsum(X[n, d_X],
                      M[m, d_M],
                      P_q[d_X, d_k, h_k],
                      P_k[d_M, d_k, h_k],
                      P_l[h_k, h])
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
  Y[n, d_Y] = einsum(W[n, m, h],
                     M[m, d_M],
                     P_v[d_M, d_v, h_v],
                     P_o[d_Y, d_v, h_v],
                     P_w[h, h_v])
  return Y[n, d_Y]
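A minimal NumPy sketch of the TalkingHeadsAttention and TalkingHeadsAttentionConcise pseudocode follows. The subscript letters (a for h_k, h for h, b for h_v), function names, and toy sizes are ours. The script checks that the two forms agree. It also checks a sanity property: with h_k = h = h_v and identity talking-heads projections, no information moves between heads, so the result matches ordinary multi-head attention computed head by head.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def talking_heads_attention(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    # Subscripts: x = d_X, z = d_M, y = d_Y, a = h_k, h = h, b = h_v.
    Q = np.einsum("nx,xka->nka", X, P_q)      # queries        [n, d_k, h_k]
    K = np.einsum("mz,zka->mka", M, P_k)      # keys           [m, d_k, h_k]
    V = np.einsum("mz,zvb->mvb", M, P_v)      # values         [m, d_v, h_v]
    J = np.einsum("nka,mka->nma", Q, K)       # dot products   [n, m, h_k]
    L = np.einsum("nma,ah->nmh", J, P_l)      # logits proj.   [n, m, h]
    W = softmax(L, axis=1)                    # weights, normalized over m
    U = np.einsum("nmh,hb->nmb", W, P_w)      # weights proj.  [n, m, h_v]
    O = np.einsum("nmb,mvb->nvb", U, V)       #                [n, d_v, h_v]
    return np.einsum("nvb,yvb->ny", O, P_o)   # output         [n, d_Y]

def talking_heads_attention_concise(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    L = np.einsum("nx,mz,xka,zka,ah->nmh", X, M, P_q, P_k, P_l, optimize=True)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mz,zvb,yvb,hb->ny", W, M, P_v, P_o, P_w, optimize=True)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 4, 6, 5, 7, 2, 3
h_k, h, h_v = 4, 6, 5  # the three heads dimensions may differ
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
P_q = rng.standard_normal((d_X, d_k, h_k))
P_k = rng.standard_normal((d_M, d_k, h_k))
P_v = rng.standard_normal((d_M, d_v, h_v))
P_o = rng.standard_normal((d_Y, d_v, h_v))
P_l = rng.standard_normal((h_k, h))
P_w = rng.standard_normal((h, h_v))

Y = talking_heads_attention(X, M, P_q, P_k, P_v, P_o, P_l, P_w)
Y_concise = talking_heads_attention_concise(X, M, P_q, P_k, P_v, P_o, P_l, P_w)
print(np.allclose(Y, Y_concise))  # True

# Identity talking-heads projections with one shared heads dimension.
P_v2 = rng.standard_normal((d_M, d_v, h_k))
P_o2 = rng.standard_normal((d_Y, d_v, h_k))
eye = np.eye(h_k)
Y_id = talking_heads_attention(X, M, P_q, P_k, P_v2, P_o2, eye, eye)
Y_mh = sum(
    softmax((X @ P_q[:, :, i]) @ (M @ P_k[:, :, i]).T, axis=1)
    @ (M @ P_v2[:, :, i]) @ P_o2[:, :, i].T
    for i in range(h_k))
print(np.allclose(Y_id, Y_mh))  # True
```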

5 Complexity Analysis

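If we assume that $d_X = d_Y$, then the number of scalar multiplications in multi-head attention is:

$$h \cdot (d_k + d_v) \cdot (n \cdot d_X + m \cdot d_M + n \cdot m)$$

The number of scalar multiplications in talking-heads attention is:

$$(d_k \cdot h_k + d_v \cdot h_v) \cdot (n \cdot d_X + m \cdot d_M + n \cdot m) + n \cdot m \cdot h \cdot (h_k + h_v)$$

The first term in this expression matches up with the cost of multi-head attention when $h_k = h_v = h$. The second term is due to the talking-heads projections. If $h < d_k$ and $h < d_v$, then the costs of the new talking-heads projections, $n \cdot m \cdot h \cdot h_k$ and $n \cdot m \cdot h \cdot h_v$, are less than the existing terms $n \cdot m \cdot d_k \cdot h_k$ and $n \cdot m \cdot d_v \cdot h_v$, respectively.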
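The multiplication counts can be tallied directly from the per-einsum costs in the pseudocode comments and compared with closed forms that assume d_X = d_Y. A plain-Python sketch; the example sizes (n = m = 512, d_X = d_M = d_Y = 768, 12 heads of size 64) are chosen for illustration:

```python
def multi_head_cost(n, m, d_X, d_M, d_Y, d_k, d_v, h):
    # Sum of the per-einsum costs listed in the MultiHeadAttention comments.
    return (h*n*d_X*d_k + h*m*d_M*d_k + h*m*d_M*d_v
            + h*n*m*d_k + h*n*m*d_v + h*n*d_Y*d_v)

def talking_heads_cost(n, m, d_X, d_M, d_Y, d_k, d_v, h_k, h, h_v):
    # Sum of the per-einsum costs listed in the TalkingHeadsAttention comments.
    return (n*d_X*d_k*h_k + m*d_M*d_k*h_k + m*d_M*d_v*h_v
            + n*m*d_k*h_k + n*m*h*h_k + n*m*h*h_v
            + n*m*d_v*h_v + n*d_Y*d_v*h_v)

def multi_head_closed_form(n, m, d_X, d_M, d_k, d_v, h):
    return h * (d_k + d_v) * (n*d_X + m*d_M + n*m)  # assumes d_X == d_Y

def talking_heads_closed_form(n, m, d_X, d_M, d_k, d_v, h_k, h, h_v):
    return ((d_k*h_k + d_v*h_v) * (n*d_X + m*d_M + n*m)
            + n*m*h*(h_k + h_v))                     # assumes d_X == d_Y

mh = multi_head_cost(512, 512, 768, 768, 768, 64, 64, 12)
th = talking_heads_cost(512, 512, 768, 768, 768, 64, 64, 12, 12, 12)
assert mh == multi_head_closed_form(512, 512, 768, 768, 64, 64, 12)
assert th == talking_heads_closed_form(512, 512, 768, 768, 64, 64, 12, 12, 12)
print(mh, th, (th - mh) / mh)  # 1610612736 1686110208 0.046875
```

At these sizes the talking-heads projections add 3/64 (about 4.7%) to the multiplication count, consistent with h = 12 being smaller than d_k = d_v = 64.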

In practice, the talking-heads projections may be expensive on some neural-network accelerators due to the small dimension sizes involved.

6 One More Way To Look At It

Mathematically, one can view multi-head attention and talking-heads attention as two special cases of the same general function, which we will call "general bilinear multihead attention" (GBMA). GBMA uses two three-dimensional parameter tensors, as defined in the pseudocode below. Due to its high computational cost, GBMA may have no practical use. Multi-head attention is mathematically equivalent to a version of GBMA where each of the two parameter tensors is expressed as the product of two factors, as shown below. Talking-heads attention is mathematically equivalent to a version of GBMA where each of the two parameter tensors is expressed as the product of three factors, as shown below.
def GeneralBilinearMultiheadAttention(
    X[n, d_X],       # n vectors with dimensionality d_X
    M[m, d_M],       # m vectors with dimensionality d_M
    P[d_X, d_M, h],  # learned parameters
    Q[d_M, d_Y, h]): # learned parameters
  L[n, m, h] = einsum(X[n, d_X], M[m, d_M], P[d_X, d_M, h])
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)
  Y[n, d_Y] = einsum(W[n, m, h], M[m, d_M], Q[d_M, d_Y, h])
  return Y[n, d_Y]

def MultiHeadAttentionInefficient(X, M, P_q, P_k, P_v, P_o):
  P[d_X, d_M, h] = einsum(P_q[d_X, d_k, h], P_k[d_M, d_k, h])
  Q[d_M, d_Y, h] = einsum(P_v[d_M, d_v, h], P_o[d_Y, d_v, h])
  return GeneralBilinearMultiheadAttention(X, M, P, Q)

def TalkingHeadsAttentionInefficient(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
  P[d_X, d_M, h] = einsum(P_q[d_X, d_k, h_k], P_k[d_M, d_k, h_k], P_l[h_k, h])
  Q[d_M, d_Y, h] = einsum(P_v[d_M, d_v, h_v], P_o[d_Y, d_v, h_v], P_w[h, h_v])
  return GeneralBilinearMultiheadAttention(X, M, P, Q)
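The equivalences above can be checked numerically. In this NumPy sketch (our own function names and toy sizes), talking-heads attention is compared with GBMA using the three-factor parameters. Multi-head attention, written as talking-heads attention with identity P_l and P_w and a single shared heads dimension, is compared with GBMA using the two-factor parameters.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def general_bilinear_multihead_attention(X, M, P, Q):
    # P: [d_X, d_M, h], Q: [d_M, d_Y, h]
    L = np.einsum("nx,mz,xzh->nmh", X, M, P, optimize=True)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mz,zyh->ny", W, M, Q, optimize=True)

def talking_heads_attention(X, M, P_q, P_k, P_v, P_o, P_l, P_w):
    # Subscripts: a = h_k, h = h, b = h_v.
    L = np.einsum("nx,mz,xka,zka,ah->nmh", X, M, P_q, P_k, P_l, optimize=True)
    W = softmax(L, axis=1)
    return np.einsum("nmh,mz,zvb,yvb,hb->ny", W, M, P_v, P_o, P_w, optimize=True)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v = 3, 4, 6, 5, 7, 2, 3
h_k, h, h_v = 4, 6, 5
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
P_q = rng.standard_normal((d_X, d_k, h_k))
P_k = rng.standard_normal((d_M, d_k, h_k))
P_v = rng.standard_normal((d_M, d_v, h_v))
P_o = rng.standard_normal((d_Y, d_v, h_v))
P_l = rng.standard_normal((h_k, h))
P_w = rng.standard_normal((h, h_v))

# Three-factor parameterization (TalkingHeadsAttentionInefficient).
P = np.einsum("xka,zka,ah->xzh", P_q, P_k, P_l, optimize=True)
Q = np.einsum("zvb,yvb,hb->zyh", P_v, P_o, P_w, optimize=True)
Y_th = talking_heads_attention(X, M, P_q, P_k, P_v, P_o, P_l, P_w)
print(np.allclose(Y_th, general_bilinear_multihead_attention(X, M, P, Q)))  # True

# Two-factor parameterization (MultiHeadAttentionInefficient).
P_v2 = rng.standard_normal((d_M, d_v, h_k))
P_o2 = rng.standard_normal((d_Y, d_v, h_k))
P2 = np.einsum("xkh,zkh->xzh", P_q, P_k)
Q2 = np.einsum("zvh,yvh->zyh", P_v2, P_o2)
eye = np.eye(h_k)
Y_mh = talking_heads_attention(X, M, P_q, P_k, P_v2, P_o2, eye, eye)
print(np.allclose(Y_mh, general_bilinear_multihead_attention(X, M, P2, Q2)))  # True
```

The GBMA tensors P and Q hold d_X·d_M·h and d_M·d_Y·h entries, which is why the factored forms are far cheaper in practice.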

7 Experiments

7.1 Text-to-Text Transfer Transformer (T5)

We test various configurations of multi-head attention and talking-heads attention on the transfer-learning setup from [Raffel et al., 2019]. An encoder-decoder transformer model [Vaswani et al., 2017] is pre-trained on a denoising objective of predicting missing text segments (average span length 3) from the C4 dataset [Raffel et al., 2019] (footnote 3: This is identical to one of the training objectives described in [Raffel et al., 2019]), and subsequently fine-tuned on various language understanding tasks. We use the same code base and model architecture as the base model from [Raffel et al., 2019]. The encoder and decoder each consist of 12 layers, with d_model = 768 and d_ff = 3072. Each encoder layer contains a multi-head self-attention layer, and each decoder layer contains a multi-head self-attention layer and a multi-head attention-over-encoder layer. For their base model, [Raffel et al., 2019] follow [Devlin et al., 2018] and others, using h = 12 and d_k = d_v = 64 for all of these attention layers. We compare this setting to a variety of other configurations of multi-head and talking-heads attention, as detailed in table 1.

Similar to [Raffel et al., 2019], we pre-train our models for 524288 steps. Each training batch consists of 128 examples, each of which has an input of 512 tokens and an output of 114 tokens, the output containing multiple spans of tokens which were deleted from the input. Similarly to [Raffel et al., 2019], we use the Adafactor optimizer [Shazeer and Stern, 2018] and an inverse-square-root learning-rate schedule. We also decay the learning rate linearly for the final 10 percent of the training steps. Our main departure from [Raffel et al., 2019] is that we, as suggested by [Lan et al., 2019], use no dropout during pre-training. We find this to produce superior results. We compute the log-perplexity on the training objective on a held-out shard of C4, which we believe to be a good indicator of model quality. For each configuration, we train one model for the "full" 524288 steps and four models for a shorter time (65536 steps) to measure inter-run variability. The results are listed in table 1.
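The learning-rate schedule described above can be sketched as follows. The constants are assumptions: the 10^4-step constant phase is the default of the inverse-square-root schedule in [Raffel et al., 2019] and is not stated here, and multiplying by a linear ramp to zero is only one plausible way to combine that schedule with the final-10% decay.

```python
import math

TOTAL_STEPS = 524288
CONSTANT_STEPS = 10_000  # assumption: T5's default, not stated in this paper
DECAY_FRACTION = 0.1     # "the final 10 percent of the training steps"

def learning_rate(step):
    """Hypothetical sketch: 1/sqrt(max(step, CONSTANT_STEPS)), multiplied by a
    factor that falls linearly from 1 to 0 over the final DECAY_FRACTION of
    training. How the two parts are combined is our assumption."""
    base = 1.0 / math.sqrt(max(step, CONSTANT_STEPS))
    decay_start = int(TOTAL_STEPS * (1 - DECAY_FRACTION))
    if step <= decay_start:
        return base
    return base * (TOTAL_STEPS - step) / (TOTAL_STEPS - decay_start)

for s in (0, 10_000, 100_000, 471_859, 500_000, 524_288):
    print(s, learning_rate(s))
```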

We then fine-tune each of the models on an examples-proportional mixture of SQUAD [Rajpurkar et al., 2016], GLUE [Wang et al., 2018] and SuperGlue [Wang et al., 2019]. Fine-tuning consists of 131072 additional steps with a learning rate of 10^-3. Following [Raffel et al., 2019], we use a dropout rate of 0.1 on the layer outputs, feed-forward hidden-layers and attention weights. The embedding matrix (also used as the projection in the final classifier layer) is fixed during fine-tuning. Tables 1, 2, 3 and 4 include results for SQUAD and MNLI-m. Results for all other tasks are listed in the appendix.

7.1.1 Multi-Head vs Talking-Heads Attention

attention | h_k h h_v (multi-head: h) | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
multi-head 6 128 128 2.010 (0.005) 1.695 89.88 85.34 0.14 2359296
multi-head 12 64 64 1.982 (0.003) 1.678 90.87 86.20 0.15 2359296
multi-head 24 32 32 1.989 (0.009) 1.669 91.04 86.41 0.17 2359296
multi-head 48 16 16 2.011 (0.004) 1.682 90.35 85.32 0.21 2359296
talking-heads 6 6 6 128 128 1.965 (0.009) 1.659 90.51 85.99 0.16 2359368
talking-heads 12 12 12 64 64 1.932 (0.004) 1.641 91.38 86.19 0.18 2359584
talking-heads 24 24 24 32 32 1.910 (0.001) 1.624 91.83 87.42 0.22 2360448
talking-heads 48 48 48 16 16 1.903 (0.006) 1.603 91.90 87.50 0.32 2363904
multi-head 24 64 64 1.950 (0.005) 1.625 91.46 86.58 0.22 4718592
general bilinear 12 768 768 1.921 (0.011) 1.586 90.83 86.50 0.47 14155776
[Raffel et al., 2019] 12 64 64 89.66 84.85 2359296
Table 1: Multi-Head vs. Talking-Heads attention on T5

In table 1, we compare multi-head attention to talking-heads attention. For each of the two algorithms, we test versions with 6, 12, 24 and 48 heads. Following [Vaswani et al., 2017], as we increase the number of heads, we decrease the key/value dimensionality d_k and d_v, so as to keep the number of parameters constant. For each number of heads, talking-heads attention improves over multi-head attention on nearly all quality metrics; the one exception is MNLI-m with 12 heads, where the two are essentially tied.

Additionally, multi-head attention gets worse as we increase the number of heads from 24 to 48 and decrease the key and value dimensionality from 32 to 16, while talking-heads attention gets better. We presume that this is due to the keys being too short to produce a good matching signal.

For additional comparison, we include in table 1 two models with significantly more parameters and computation in the attention layers. In the first, we double the number of heads in our baseline model from h = 12 to h = 24 without reducing d_k and d_v, resulting in a multi-head attention layer with double the parameters and double the computation. In the second, we use "general bilinear multihead attention", as described in section 6.
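The parameters-per-attention-layer values in table 1 follow directly from the tensor shapes in the pseudocode. A minimal sketch, assuming d_X = d_M = d_Y = 768 (the T5-base model dimension) and counting only the tensors that appear in the pseudocode, reproduces the entries we checked:

```python
D = 768  # d_X = d_M = d_Y = d_model in the T5 base configuration

def multi_head_params(h, d_k, d_v, d=D):
    # P_q and P_k are [d, d_k, h]; P_v and P_o are [d, d_v, h].
    return 2 * d * d_k * h + 2 * d * d_v * h

def talking_heads_params(h_k, h, h_v, d_k, d_v, d=D):
    # Same four projections with their own heads dims,
    # plus P_l [h_k, h] and P_w [h, h_v].
    return 2 * d * d_k * h_k + 2 * d * d_v * h_v + h_k * h + h * h_v

def general_bilinear_params(h, d=D):
    # P is [d_X, d_M, h] and Q is [d_M, d_Y, h].
    return 2 * d * d * h

print(multi_head_params(12, 64, 64))             # 2359296
print(multi_head_params(24, 64, 64))             # 4718592
print(talking_heads_params(24, 24, 24, 32, 32))  # 2360448
print(general_bilinear_params(12))               # 14155776
```

The talking-heads projections add only h_k·h + h·h_v parameters, which is why the talking-heads rows differ from the multi-head rows by at most a few thousand.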

We also list the results from [Raffel et al., 2019]. We believe that their results are worse due to their use of dropout during pre-training.

7.1.2 Varying the Heads-Dimensions Separately

attention | h_k h h_v | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
talking-heads 6 6 6 128 128 1.965 (0.009) 1.659 90.51 85.99 0.16 2359368
talking-heads 6 24 6 128 128 1.941 (0.009) 1.641 90.91 86.29 0.18 2359584
talking-heads 24 6 24 32 32 1.959 (0.008) 1.667 90.77 86.15 0.20 2359584
talking-heads 6 24 24 128 32 1.939 (0.011) 1.633 91.06 86.31 0.20 2360016
talking-heads 24 24 6 32 128 1.931 (0.013) 1.628 90.98 86.81 0.21 2360016
talking-heads 24 24 24 32 32 1.910 (0.001) 1.624 91.83 87.42 0.22 2360448
Table 2: Talking-heads attention has three "heads" dimensions that can vary independently.

In table 2, we experiment with independently varying the sizes of the three heads-dimensions. From the results, it appears that increasing any of the three helps, but that the softmax-heads dimension h is particularly important.

7.1.3 Logits-Projection Only and Weights-Projection Only

attention | heads dimensions | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
multi-head 24 32 32 1.989 (0.009) 1.669 91.04 86.41 0.17 2359296
project logits 24 24 32 32 1.969 (0.004) 1.652 91.29 85.86 0.23 2359872
project weights 24 24 32 32 1.951 (0.009) 1.636 91.03 86.12 0.23 2359872
talking-heads 24 24 24 32 32 1.910 (0.001) 1.624 91.83 87.42 0.22 2360448
Table 3: The logits-projection and the weights-projection can be employed separately.

In the middle two experiments of table 3, we examine hybrids of multi-head attention and talking-heads attention, where there is a projection on one but not both of the logits and the weights.

7.1.4 Encoder vs. Decoder

attention | h_k h h_v (multi-head: h) | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
multi-head 24 32 32 1.989 (0.009) 1.669 91.04 86.41 0.17 2359296
TH-enc-self 24* 24 24* 32 32 1.969 (0.002) 1.655 91.63 87.00 0.21 various various
TH-dec-self 24* 24 24* 32 32 1.981 (0.005) 1.671 90.56 85.56 0.17 various various
TH-encdec 24* 24 24* 32 32 1.942 (0.003) 1.646 90.86 86.07 0.18 various various
talking-heads 24 24 24 32 32 1.910 (0.001) 1.624 91.83 87.42 0.22 2360448
Table 4: In each of the middle three experiments, talking-heads attention is employed in only one of the three types of attention layers in the model.

The transformer model contains three types of attention layers - self-attention in the encoder, self-attention in the decoder, and attention-over-encoder in the decoder. In each of the middle three experiments of table 4, we employ talking-heads attention in only one of these types of attention layers, and multi-head attention in the others. We find that modifying the encoder-self-attention layers has the biggest effect on the downstream language-understanding tasks. This is unsurprising, given that these tasks have more to do with analyzing the input than with generating output.

7.2 Albert

[Lan et al., 2019] introduce ALBERT, a variation on BERT [Devlin et al., 2018]. The main difference between the ALBERT and BERT architectures is that ALBERT shares layer parameters among all layers, significantly reducing the number of parameters. For example, a 12-layer ALBERT model has about 1/12 the number of parameters in the attention and feed-forward layers as a similar BERT model. Another difference is that ALBERT factorizes the word-embedding matrix as the product of two smaller matrices, again significantly reducing the parameter count. This makes ALBERT appealing for memory-limited devices such as mobile phones. Besides the above architectural differences, ALBERT also uses sentence order prediction (SOP) in place of BERT's next sentence prediction (NSP).

We report here experiments done with the base setting for ALBERT: a Transformer network with 12 layers of attention, and the hidden and embedding size set to 768. The pre-training and fine-tuning hyperparameter settings are also exactly the same as in [Lan et al., 2019]. We use the English Wikipedia and BookCorpus datasets [Devlin et al., 2018] to pre-train various models with different head sizes and talking-heads configurations. We evaluate the resulting representations by using them as a starting point for fine-tuning on SQuAD (SQuAD1.1 and SQuAD2.0 dev sets), on MNLI and SST-2 from the GLUE benchmark, and on RACE. Results are in Table 5.

attention | h | d_k = d_v | SQuAD1.1 (f1) | SQuAD2.0 (f1) | MNLI | SST-2 | RACE | MLM | SOP | Average
Multi-head 6 128 88.5 78.8 79.9 88.6 62.7 54.3 85.9 79.7
Multi-head 12 64 88.8 79.3 80.2 89.9 63.4 54.5 86.2 80.32
Multi-head 24 32 88.8 79.1 79.9 87.7 62.1 54.4 85.9 79.52
Multi-head 48 16 87.9 78.8 79.6 88.4 61.8 53.8 85.3 79.3
Talking-heads 6 128 88.7 78 80 88.5 62 54.1 85.2 79.44
Talking-heads 12 64 89.2 79.9 80.5 89 65.3 54.9 87.6 80.78
Talking-heads 24 32 89.3 80.5 80.5 87.6 65.6 55.3 86.3 80.7
Talking-heads 48 16 89.6 80.9 80.9 89.3 66.5 55.7 86.5 81.44
Table 5: Multi-Head vs. Talking-Heads attention on ALBERT.
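The Average column in table 5 appears to be the mean of the five downstream scores (SQuAD1.1, SQuAD2.0, MNLI, SST-2, RACE), with the MLM and SOP columns excluded. This is our inference from the numbers, not something the text states. A quick check over the rows of table 5:

```python
import math

# (SQuAD1.1, SQuAD2.0, MNLI, SST-2, RACE, reported Average), copied from Table 5.
table5 = {
    "Multi-head 6 128":    (88.5, 78.8, 79.9, 88.6, 62.7, 79.7),
    "Multi-head 12 64":    (88.8, 79.3, 80.2, 89.9, 63.4, 80.32),
    "Multi-head 24 32":    (88.8, 79.1, 79.9, 87.7, 62.1, 79.52),
    "Multi-head 48 16":    (87.9, 78.8, 79.6, 88.4, 61.8, 79.3),
    "Talking-heads 6 128": (88.7, 78.0, 80.0, 88.5, 62.0, 79.44),
    "Talking-heads 12 64": (89.2, 79.9, 80.5, 89.0, 65.3, 80.78),
    "Talking-heads 24 32": (89.3, 80.5, 80.5, 87.6, 65.6, 80.7),
    "Talking-heads 48 16": (89.6, 80.9, 80.9, 89.3, 66.5, 81.44),
}
for name, (*scores, reported) in table5.items():
    mean = sum(scores) / len(scores)
    # Reported averages are rounded to two decimals.
    assert math.isclose(mean, reported, abs_tol=0.005), name
    print(f"{name}: mean of five tasks = {mean:.2f}, reported = {reported}")
```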

We find that as the number of heads increases beyond 12 and the dimensionality of the attention-keys and attention-values decreases below 64, the performance of multi-head attention decays. On the other hand, the performance of talking-heads attention keeps improving.

In addition, we compare the logits projection and the weights projection, each applied separately, with multi-head and talking-heads attention. The results are shown in Table 6. Similar to our observation in the T5 experiments, applying only one of the two projections does not result in a significant improvement over multi-head attention. These results again confirm the importance of having both projections.

attention | h | d_k = d_v | SQuAD1.1 (f1) | SQuAD2.0 (f1) | MNLI | SST-2 | RACE | MLM | SOP | Average
Multi-head 12 64 88.8 79.3 80.2 89.9 63.4 54.5 86.2 80.32
Logit-project-only 12 64 88.5 78.8 79.8 89.3 63 54.6 85.8 79.88
Weight-project-only 12 64 88.9 79.6 80.3 89 64 54.7 85.8 80.36
Talking-heads 12 64 89.2 79.9 80.5 89 65.3 54.9 87.6 80.78
Table 6: The logits-projection and the weights-projection can be employed separately.

7.3 Bert

We test various configurations of talking-heads attention based on [Devlin et al., 2018]. All of our experiments use the simplified relative position embeddings [Raffel et al., 2019] instead of fixed position embeddings. We first pre-train a 12-layer Transformer on the same dataset as [Devlin et al., 2018], and then fine-tune on the SQuAD1.1 task and on MNLI from the GLUE benchmark. Our experiments show that quality continues to improve when we grow the number of heads up to 768 and decrease the key and value dimensionality down to 1 (footnote 4: These extreme hyperparameter settings likely have no practical use, due to the massive amount of computation).

attention | h | d_k = d_v | SQuAD1.1 (f1) | MNLI
Multi-head 12 64 88.51 82.6
Talking-heads 6 128 88.8 83.4
Talking-heads 12 64 89.2 83.6
Talking-heads 24 32 89.4 83.6
Talking-heads 48 16 89.5 83.4
Talking-heads 64 12 89.9 83.8
Talking-heads 96 8 89.3 83.6
Talking-heads 192 4 89.8 83.9
Talking-heads 384 2 90.5 83.9
Talking-heads 768 1 90.5 84.2
Table 7: Talking-Heads attention on BERT.

7.4 Visualizing the Projection Matrices of Talking-Heads

To illustrate how different heads exchange information with each other, we visualize the projection matrices (P_l and P_w) of a 12-layer BERT with 12 talking-heads in figure 1. Since P_w is applied after P_l (although there is a softmax non-linearity in between), we also visualize the combined transformation P_l P_w in figure 1. As can be observed, the main diagonals of the projection matrices do not have significantly greater values than the other entries. This is expected because, with talking-heads, a pair of query and key does not correspond to any specific value-vector; all keys and queries jointly decide how the values in each head interchange data. Additionally, all projection matrices are well conditioned (magnitude of determinant above

with smallest eigenvalue above

), indicating that no significant approximation can be achieved.
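The diagnostics described above can be computed with a few NumPy calls: diagonal versus off-diagonal magnitude after normalizing to [-1, 1], determinant magnitude, and smallest eigenvalue. This sketch uses random matrices as hypothetical stand-ins for the trained BERT weights. Because eigenvalues of a non-symmetric matrix may be complex, it reports the smallest eigenvalue magnitude; that reading of "smallest eigenvalue" is our assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
h = 12
# Hypothetical stand-ins for the learned 12 x 12 talking-heads projections.
P_l = rng.standard_normal((h, h))   # [h_k, h]
P_w = rng.standard_normal((h, h))   # [h, h_v]
combined = P_l @ P_w                # [h_k, h_v]; ignores the softmax in between

off_diag = ~np.eye(h, dtype=bool)
for name, A in [("P_l", P_l), ("P_w", P_w), ("P_l @ P_w", combined)]:
    A_norm = A / np.abs(A).max()    # entries scaled into [-1, 1], as in Figure 1
    diag = np.abs(np.diag(A_norm)).mean()
    off = np.abs(A_norm[off_diag]).mean()
    abs_det = abs(np.linalg.det(A))
    min_eig = np.abs(np.linalg.eigvals(A)).min()
    print(f"{name}: |diag| {diag:.2f}, |off-diag| {off:.2f}, "
          f"|det| {abs_det:.3g}, min |eigenvalue| {min_eig:.3g}")
```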

Figure 1: Visualization of the learned projection matrices P_l, P_w, and the product P_l P_w. All entries in each matrix are normalized to [-1, 1]. The fact that these matrices are not diagonal or sparse at all shows that there is significant data exchange across different attention heads.

8 Conclusions and Future Work

We have proposed talking-heads attention and shown some promising results. One potential challenge is speed on modern deep-learning accelerators, which are optimized for large-dimension matrix multiplications. We imagine that this will be an area of future work. One approach is to build hardware which is better at small-dimension matrix-multiplication. Another potential approach is to decrease the number of memory-positions considered for each query-position - for example, by using the local-attention and memory-compressed-attention approaches described in [Liu et al., 2018]. We look forward to more applications of talking-heads attention, as well as to further architectural improvements.

References

Appendix A Variation: Dynamic Projections

In the basic talking-heads attention algorithm described in section 4, the talking-heads projections are represented by two learned weight matrices P_l and P_w. In an additional wrinkle, we can make these projection matrices themselves input-dependent, adding terms to the projection matrices that are themselves learned linear projections of the inputs X and M. The algorithm is described by the pseudo-code below.
def TalkingHeadsAttentionWithDynamicProjections(
    X[n, d_X],           # n vectors with dimensionality d_X
    M[m, d_M],           # m vectors with dimensionality d_M
    P_q[d_X, d_k, h_k],  # learned linear projection to produce queries
    P_k[d_M, d_k, h_k],  # learned linear projection to produce keys
    P_v[d_M, d_v, h_v],  # learned linear projection to produce values
    P_o[d_Y, d_v, h_v],  # learned linear projection of output
    P_l[h_k, h],         # learned static talking-heads proj. on logits
    P_Xl[d_X, h_k, h],   # learned projection to generate dynamic talking-heads projection
    P_Ml[d_M, h_k, h],   # learned projection to generate dynamic talking-heads projection
    P_w[h, h_v],         # learned static talking-heads proj. on weights
    P_Xw[d_X, h, h_v],   # learned projection to generate dynamic talking-heads projection
    P_Mw[d_M, h, h_v]):  # learned projection to generate dynamic talking-heads projection
  Q[n, d_k, h_k] = einsum(X[n, d_X], P_q[d_X, d_k, h_k])  # queries         n*d_X*d_k*h_k
  K[m, d_k, h_k] = einsum(M[m, d_M], P_k[d_M, d_k, h_k])  # keys            m*d_M*d_k*h_k
  V[m, d_v, h_v] = einsum(M[m, d_M], P_v[d_M, d_v, h_v])  # values          m*d_M*d_v*h_v
  J[n, m, h_k] = einsum(Q[n, d_k, h_k], K[m, d_k, h_k])   # dot prod.       n*m*d_k*h_k
  R_Xl[n, h_k, h] = einsum(X[n, d_X], P_Xl[d_X, h_k, h])  # dynamic proj.   n*d_X*h_k*h
  R_Ml[m, h_k, h] = einsum(M[m, d_M], P_Ml[d_M, h_k, h])  # dynamic proj.   m*d_M*h_k*h
  L[n, m, h] = (
    einsum(J[n, m, h_k], P_l[h_k, h]) +       # Static talking-heads proj.  n*m*h*h_k
    einsum(J[n, m, h_k], R_Xl[n, h_k, h]) +   # Dynamic talking-heads proj. n*m*h*h_k
    einsum(J[n, m, h_k], R_Ml[m, h_k, h]))    # Dynamic talking-heads proj. n*m*h*h_k
  W[n, m, h] = softmax(L[n, m, h], reduced_dim=m)   # Attention weights
  R_Xw[n, h, h_v] = einsum(X[n, d_X], P_Xw[d_X, h, h_v])  # dynamic proj.   n*d_X*h*h_v
  R_Mw[m, h, h_v] = einsum(M[m, d_M], P_Mw[d_M, h, h_v])  # dynamic proj.   m*d_M*h*h_v
  U[n, m, h_v] = (
    einsum(W[n, m, h], P_w[h, h_v]) +         # Static talking-heads proj.  n*m*h*h_v
    einsum(W[n, m, h], R_Xw[n, h, h_v]) +     # Dynamic talking-heads proj. n*m*h*h_v
    einsum(W[n, m, h], R_Mw[m, h, h_v]))      # Dynamic talking-heads proj. n*m*h*h_v
  O[n, d_v, h_v] = einsum(U[n, m, h_v], V[m, d_v, h_v])   #                 n*m*d_v*h_v
  Y[n, d_Y] = einsum(O[n, d_v, h_v], P_o[d_Y, d_v, h_v])  #                 n*d_Y*d_v*h_v
  return Y[n, d_Y]
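A minimal NumPy sketch of talking-heads attention with dynamic projections follows. The function name, subscripts, toy sizes, and the small initialization scale `std = 0.01` are all our own hypothetical choices. With the four projection-generating tensors set to zero, the function reduces to static talking-heads attention, which the usage code exercises.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def talking_heads_dynamic(X, M, P_q, P_k, P_v, P_o,
                          P_l, P_Xl, P_Ml, P_w, P_Xw, P_Mw):
    # Subscripts: x = d_X, z = d_M, y = d_Y, a = h_k, h = h, b = h_v.
    Q = np.einsum("nx,xka->nka", X, P_q)
    K = np.einsum("mz,zka->mka", M, P_k)
    V = np.einsum("mz,zvb->mvb", M, P_v)
    J = np.einsum("nka,mka->nma", Q, K)            # dot products [n, m, h_k]
    R_Xl = np.einsum("nx,xah->nah", X, P_Xl)       # per-query proj.  [n, h_k, h]
    R_Ml = np.einsum("mz,zah->mah", M, P_Ml)       # per-memory proj. [m, h_k, h]
    L = (np.einsum("nma,ah->nmh", J, P_l)          # static
         + np.einsum("nma,nah->nmh", J, R_Xl)      # dynamic, from X
         + np.einsum("nma,mah->nmh", J, R_Ml))     # dynamic, from M
    W = softmax(L, axis=1)
    R_Xw = np.einsum("nx,xhb->nhb", X, P_Xw)       # [n, h, h_v]
    R_Mw = np.einsum("mz,zhb->mhb", M, P_Mw)       # [m, h, h_v]
    U = (np.einsum("nmh,hb->nmb", W, P_w)
         + np.einsum("nmh,nhb->nmb", W, R_Xw)
         + np.einsum("nmh,mhb->nmb", W, R_Mw))
    O = np.einsum("nmb,mvb->nvb", U, V)
    return np.einsum("nvb,yvb->ny", O, P_o)

rng = np.random.default_rng(0)
n, m, d_X, d_M, d_Y, d_k, d_v, h_k, h, h_v = 3, 4, 6, 5, 7, 2, 3, 4, 6, 5
X, M = rng.standard_normal((n, d_X)), rng.standard_normal((m, d_M))
static = dict(
    P_q=rng.standard_normal((d_X, d_k, h_k)), P_k=rng.standard_normal((d_M, d_k, h_k)),
    P_v=rng.standard_normal((d_M, d_v, h_v)), P_o=rng.standard_normal((d_Y, d_v, h_v)),
    P_l=rng.standard_normal((h_k, h)), P_w=rng.standard_normal((h, h_v)))
std = 0.01  # hypothetical small init scale for the projection-generating tensors
dynamic = dict(
    P_Xl=std * rng.standard_normal((d_X, h_k, h)), P_Ml=std * rng.standard_normal((d_M, h_k, h)),
    P_Xw=std * rng.standard_normal((d_X, h, h_v)), P_Mw=std * rng.standard_normal((d_M, h, h_v)))
zeros = {k: np.zeros_like(v) for k, v in dynamic.items()}

Y_dyn = talking_heads_dynamic(X, M, **static, **dynamic)
Y_static = talking_heads_dynamic(X, M, **static, **zeros)  # static talking-heads
print(Y_dyn.shape, np.abs(Y_dyn - Y_static).max())
```

The four projection-generating tensors add (d_X + d_M)·(h_k·h + h·h_v) parameters. With d_X = d_M = 768 and h_k = h = h_v = 12, that is 442368, which matches the difference between the 12-head "dyn. proj." and "talking-heads" parameter counts in table 8.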

We observed that the model only trained well if we initialized the projection-generating parameter matrices (P_Xl, P_Ml, P_Xw, P_Mw) to contain small enough values. We used normal initializers with standard deviations of

, , , and , respectively.

attention | h_k h h_v (multi-head: h) | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
multi-head 12 64 64 1.982 (0.003) 1.678 90.87 86.20 0.15 2359296
talking-heads 12 12 12 64 64 1.932 (0.004) 1.641 91.38 86.19 0.18 2359584
dyn. proj. 12 12 12 64 64 1.897 (0.007) 1.595 90.17 86.18 0.36 2801952
multi-head 24 32 32 1.989 (0.009) 1.669 91.04 86.41 0.17 2359296
talking-heads 24 24 24 32 32 1.910 (0.001) 1.624 91.83 87.42 0.22 2360448
dynamic proj. 24 24 24 32 32 1.873 (0.008) 1.587 90.17 85.94 0.53 4129920
Table 8: Dynamic-Projections Results on T5

A.1 Experiments

We evaluate talking-heads attention with dynamic projections on T5 [Raffel et al., 2019] in a set of experiments similar to those described in section 7.1.

Table 8 compares multi-head attention, talking-heads attention with static projections, and talking-heads attention with dynamic projections. The dynamic projections reduce perplexity on the pre-training task. However, in our experiments, we did not see an improvement on the downstream tasks.

attention | h_k h h_v | d_k d_v | ln(PPL) after 65536 steps (inter-run variability) | ln(PPL) after 524288 steps | SQUAD v1.1 dev-f1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
talking-heads 12 12 12 64 64 1.932 (0.004) 1.641 91.38 86.19 0.18 2359584
dyn. proj. 12 12 12 64 64 1.932 (0.011) 1.634 91.34 86.32 0.19 2470176
dyn. proj. 12 12 12 64 64 1.914 (0.005) 1.619 90.70 86.43 0.19 2470176
dyn. proj. 12 12 12 64 64 1.930 (0.010) 1.624 91.14 86.63 0.24 2470176
dyn. proj. 12 12 12 64 64 1.917 (0.003) 1.624 90.54 86.45 0.25 2470176
dyn. proj. 12 12 12 64 64 1.897 (0.007) 1.595 90.17 86.18 0.36 2801952
Table 9: In each of the middle four experiments, only one of the dynamic projections is employed.

A.1.1 Comparing the Four Dynamic Projections

Table 9 examines the effects of the four dynamic projections employed individually. The middle four rows represent experiments where only one of the four dynamic projections was employed. These are compared to static projections (top row) and all four dynamic projections together (bottom row).

model | h_k | h | h_v | d_k | d_v | ln(PPL), 65536 steps | ln(PPL), 524288 steps | SQuAD v1.1 dev-F1 | MNLI-m dev | step time (s) | parameters per att. layer | multiplies per att. layer (n=m=512)
multi-head |  | 24 |  | 32 | 32 | 1.989 (0.009) | 1.669 | 91.04 | 86.41 | 0.17 | 2359296 |
TH-enc-self | 24* | 24 | 24* | 32 | 32 | 1.969 (0.002) | 1.655 | 91.63 | 87.00 | 0.21 | various | various
DP-enc-self | 24* | 24 | 24* | 32 | 32 | 1.953 (0.006) | 1.639 | 91.99 | 86.97 | 0.42 | various | various
Table 10: Effects of applying talking-heads attention (with or without dynamic projections) in the encoder only.

A.1.2 Talking-Heads in Encoder Only

In section 7.1.4 we saw that talking-heads attention was particularly useful in the encoder part of the model. Table 10 presents a set of experiments where the decoder uses only multi-head attention, while the encoder uses either multi-head attention (top row), talking-heads attention with static projections (middle row), or talking-heads attention with dynamic projections (bottom row). In this setting, the dynamic projections do not appear to degrade performance on the downstream tasks: SQuAD v1.1 dev-F1 is 91.99 with dynamic projections versus 91.63 with static projections, and MNLI-m is 86.97 versus 87.00.

Appendix B T5 Fine-Tuning Full Results

Tables 11, 12 and 13 present the results of fine-tuning the models from section 7.1 and appendix A on the GLUE [Wang et al., 2018], SuperGLUE [Wang et al., 2019], and Stanford Question Answering Dataset (SQuAD) [Rajpurkar et al., 2016] benchmarks.

Score average | CoLA MCC | SST-2 Acc | MRPC F1 | MRPC Acc | STS-B PCC | STS-B SCC | QQP F1 | QQP Acc | MNLI-m Acc | MNLI-mm Acc | QNLI Acc | RTE Acc
multi-head 6 128 128
multi-head 12 64 64
multi-head 24 32 32 93.59 91.18
multi-head 48 16 16 93.59 91.18
talking-heads 6 6 6 128 128
talking-heads 12 12 12 64 64
talking-heads 24 24 24 32 32 89.19 87.42
talking-heads 48 48 48 16 16 94.61 91.16 91.00
multi-head 24 64 64 55.99
general bilinear 12 768 768 91.98
talking-heads 6 6 6 128 128
talking-heads 6 24 6 128 128
talking-heads 24 6 24 32 32
talking-heads 6 24 24 128 32
talking-heads 24 24 6 32 128 84.85 84.84
talking-heads 24 24 24 32 32 89.19 87.42
multi-head 24 32 32 93.59 91.18
project logits 24 24 32 32
project weights 24 24 32 32
talking-heads 24 24 24 32 32 89.19 87.42
multi-head 24 32 32 93.59 91.18
TH-enc-self 24* 24 24* 32 32
TH-dec-self 24* 24 24* 32 32
TH-encdec 24* 24 24* 32 32
talking-heads 24 24 24 32 32 89.19 87.42
multi-head 12 64 64
talking-heads 12 12 12 64 64
dyn. proj. 12 12 12 64 64
multi-head 24 32 32 93.59 91.18
talking-heads 24 24 24 32 32 89.19 87.42
dyn. proj. 24 24 24 32 32
talking-heads 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
multi-head 24 32 32 93.59 91.18
TH-enc-self 24* 24 24* 32 32
DP-enc-self 24* 24 24* 32 32 87.52 93.48
[Raffel et al., 2019] 12 64 64
ibid. stddev.


Table 11: T5 on GLUE Language-Understanding Benchmark [Wang et al., 2018] (dev). Experiments described in Section 7.1 and appendix A.

Score average | BoolQ Acc | CB F1 | CB Acc | COPA Acc | MultiRC F1 | MultiRC EM | ReCoRD F1 | ReCoRD EM | RTE Acc | WiC Acc | WSC Acc
multi-head 6 128 128
multi-head 12 64 64
multi-head 24 32 32
multi-head 48 16 16
talking-heads 6 6 6 128 128
talking-heads 12 12 12 64 64
talking-heads 24 24 24 32 32 90.60 94.64
talking-heads 48 48 48 16 16 76.39 82.94 87.50
multi-head 24 64 64
general bilinear 12 768 768
talking-heads 6 6 6 128 128
talking-heads 6 24 6 128 128
talking-heads 24 6 24 32 32 71.32
talking-heads 6 24 24 128 32
talking-heads 24 24 6 32 128 76.00 86.28 87.50
talking-heads 24 24 24 32 32 90.60 94.64
multi-head 24 32 32
project logits 24 24 32 32
project weights 24 24 32 32
talking-heads 24 24 24 32 32 90.60 94.64
multi-head 24 32 32
TH-enc-self 24* 24 24* 32 32 78.18 43.55
TH-dec-self 24* 24 24* 32 32
TH-encdec 24* 24 24* 32 32
talking-heads 24 24 24 32 32 90.60 94.64
multi-head 12 64 64
talking-heads 12 12 12 64 64
dyn. proj. 12 12 12 64 64
multi-head 24 32 32
talking-heads 24 24 24 32 32 90.60 94.64
dyn. proj. 24 24 24 32 32
talking-heads 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
dyn. proj. 12 12 12 64 64
multi-head 24 32 32
TH-enc-self 24* 24 24* 32 32 78.18 43.55
DP-enc-self 24* 24 24* 32 32 77.88 76.99
[Raffel et al., 2019] 12 64 64
ibid. stddev.