
Transformer Acceleration with Dynamic Sparse Attention

Transformers are the mainstream of NLP applications and are becoming increasingly popular in other domains such as Computer Vision. Despite the improvements in model quality, their enormous computation costs make Transformers difficult to deploy, especially when the sequence length is large in emerging applications. The attention mechanism, the essential component of the Transformer, is the execution bottleneck because of its quadratic complexity. Prior art explores sparse patterns in attention to support long-sequence modeling, but those works rely on static or fixed patterns. We demonstrate that the sparse patterns are dynamic and depend on the input sequence. Thus, we propose Dynamic Sparse Attention (DSA), which efficiently exploits the dynamic sparsity in the attention of Transformers. Compared with other methods, our approach achieves better trade-offs between accuracy and model complexity. Moving forward, we identify challenges and provide solutions for implementing DSA on existing hardware (GPUs) and specialized hardware in order to achieve practical speedup and efficiency improvements for Transformer execution.


1 Introduction

Transformers (Vaswani et al., 2017) have become the driving force for sequence modeling tasks such as neural machine translation (Ott et al., 2018), language understanding (Devlin et al., 2019), and generative modeling (Parmar et al., 2018; Brown et al., 2020). Equipped with the self-attention mechanism, Transformers are capable of handling long-range dependencies.

Despite the impressive progress made by Transformers, their computational requirements make Transformer-based models difficult to deploy at inference time, especially when processing long sequences. The quadratically scaling self-attention modules are the execution bottleneck for long sequences. Therefore, many studies propose Transformer variants to mitigate the quadratic time and space complexity. Some approaches primarily target memory footprint reduction during training, while efficient inference remains understudied (Gong et al., 2019; Dai et al., 2019; Kitaev et al., 2020; Roy et al., 2021). Other methods use fixed or static sparse attention patterns to save computations (Child et al., 2019; Qiu et al., 2020; Beltagy et al., 2020; Zaheer et al., 2020). However, we find that the intrinsic sparse patterns in attention are naturally dynamic and depend on the input sequence. Intuitively, imposing static sparsity constraints on attention could be too restrictive to capture dynamic attention connections. Thus, we propose to exploit the dynamic sparse patterns to save attention computations without sacrificing the representation power of attention.

We propose the Dynamic Sparse Attention (DSA) approach, which exploits dynamic sparsity to improve efficiency. The challenge is to efficiently search for sparse patterns that are close to the oracle sparse patterns, i.e., the patterns that keep all the important attention weights. We formulate the search as a prediction problem and augment the standard attention mechanism with a prediction path. As discussed in Section 3, we first obtain an approximation of the attention scores at low computational cost. Then, we predict the sparse attention patterns from the approximate attention scores. With the predicted sparse attention patterns represented as binary masks, we can save the computations involved in the full attention scores, the softmax, and the attention outputs.

Compared with static sparse attention methods, our method is dynamic and naturally captures the sparse attention patterns of different input sequences. We observe important tokens that attract a large portion of attention weights from other tokens, similar to the global attention methods (Beltagy et al., 2020; Zaheer et al., 2020). However, the positions of global tokens are input-dependent, and our method can effectively identify such variations instead of relying on domain knowledge to predetermine global tokens at fixed positions. Compared with other low-rank approximation methods, the approximation in DSA is used only for sparsity prediction, without strict and static constraints on attention positions. Therefore, our method can maintain the representation power of full attention while removing unnecessary attention weights.

Although DSA can save theoretical computations and maintain attention capability, achieving practical speedups and energy savings on real hardware systems is challenging. We discuss the implications of DSA for existing GPU architectures and specialized hardware accelerators. We extend the fine-grained dynamic sparsity found by DSA to structural dynamic patterns, such as block-wise and vector-wise patterns. We study the trade-off between structural sparse patterns and the expressive power of attention, and explore the opportunities for dataflow optimization and data reuse that arise from dynamic sparsity.

Our evaluation in Section 4 shows that DSA can achieve 95% sparsity in attention weights without compromising model accuracy. Under this setting, the overall computational saving is up to 4.35× compared with full attention, while the sparsity prediction introduces only around 1.17% to 1.33% computational overhead. Experimental results on an NVIDIA V100 GPU show that, under a 90% sparsity ratio, applying vector-wise sparsity to DSA delivers speedups on attention score computation, softmax computation, and attention output computation, with only 0.1% accuracy loss. Finally, through hardware specialization, we can further explore token-level parallelism and computation reordering for DSA. Our characterization results show that this can reduce the total memory access of attention computation.

2 Background and Motivation

Before we describe our method in detail, we first introduce the preliminaries of the standard attention mechanism used in vanilla Transformers. Then, we discuss the challenge of serving long sequences under the quadratic complexity of attention. Finally, we demonstrate that redundancy exists in attention and that dynamic sparse patterns arise naturally in attention weights.

2.1 Preliminaries of Attention

The attention mechanism is the essential component of Transformers (Vaswani et al., 2017). Self-attention operates on input representations of length $l$, $X \in \mathbb{R}^{l \times d}$, with three linear projections, namely query, key, and value, as

$$Q, K, V = XW_Q,\; XW_K,\; XW_V \tag{1}$$

where $Q \in \mathbb{R}^{l \times d_k}$ denotes the queries, $K \in \mathbb{R}^{l \times d_k}$ denotes the keys, and $V \in \mathbb{R}^{l \times d_v}$ denotes the values. After the linear projections, the attention weights $A \in \mathbb{R}^{l \times l}$ are defined as

$$A = \phi\left(\frac{QK^\top}{\sqrt{d_k}}\right) \tag{2}$$

where $\phi$ is the row-wise $\mathrm{softmax}(\cdot)$ function. Finally, the output values are computed by multiplying the attention weights with the projected values as

$$Z = AV \tag{3}$$

Serving Transformer-based models is challenging when the input sequence length is large. With long sequences, computing Eq. (2) and Eq. (3) consumes the majority of operations and becomes the bottleneck of model evaluation. The asymptotic complexity of attention, $O(l^2 d_k + l^2 d_v)$, is quadratic in the sequence length $l$.
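As a point of reference for Eqs. (1)-(3), the following is a minimal NumPy sketch of single-head full attention. The function names, toy sizes, and random weights are illustrative assumptions and are not taken from the paper's implementation.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def full_attention(X, W_q, W_k, W_v):
    """Single-head attention following Eqs. (1)-(3)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v           # Eq. (1): linear projections
    d_k = Q.shape[-1]
    A = softmax_rows(Q @ K.T / np.sqrt(d_k))      # Eq. (2): l x l attention weights
    Z = A @ V                                     # Eq. (3): attention output
    return A, Z

# Toy sizes (hypothetical): sequence length l, input width d, head widths d_k and d_v.
rng = np.random.default_rng(0)
l, d, d_k, d_v = 8, 16, 4, 4
X = rng.standard_normal((l, d))
W_q = rng.standard_normal((d, d_k))
W_k = rng.standard_normal((d, d_k))
W_v = rng.standard_normal((d, d_v))

A, Z = full_attention(X, W_q, W_k, W_v)
print(A.shape, Z.shape)  # (8, 8) (8, 4)
```

The `l x l` matrix `A` is the quadratic term: both the score GEMM and the output GEMM touch every one of its entries.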

2.2 Intrinsic Sparsity in Attention Weights

A number of efficient Transformer variants have been proposed to mitigate the quadratic complexity of self-attention (Child et al., 2019; Beltagy et al., 2020; Zaheer et al., 2020; Shi et al., 2021). One straightforward way to exploit the intrinsic redundancy in attention is to form sparse patterns as in

$$A = \phi\left(QK^\top - c\,(1 - M)\right) \tag{4}$$

where $M \in \{0, 1\}^{l \times l}$ represents the sparse attention pattern and $c$ is a large constant such that $A_{ij} = 0$ after normalization wherever $M_{ij} = 0$, indicating unimportant attention. Here, we omit the scaling factor $\sqrt{d_k}$ for simplicity. The sparse patterns can be pre-determined as global, block, random, or a combination of different patterns. Another way to determine sparse patterns is through trainable masks. However, all these methods explore static or fixed sparse patterns, restricting the viable attention connections.
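To make the masking in Eq. (4) concrete, here is a small NumPy sketch that subtracts a large constant at masked positions before the softmax. The local-window mask stands in for one possible static pattern, and both the constant `c = 1e4` and the window width are assumptions chosen for illustration.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def local_window_mask(l, w):
    """Static example pattern: token i may attend to tokens j with |i - j| <= w."""
    idx = np.arange(l)
    return (np.abs(idx[:, None] - idx[None, :]) <= w).astype(np.float64)

def masked_attention(Q, K, V, M, c=1e4):
    """Eq. (4) style masking; the value of c is an assumed 'large constant'."""
    scores = Q @ K.T - c * (1.0 - M)   # push masked scores far below the rest
    A = softmax_rows(scores)           # masked entries underflow to exactly 0
    return A, A @ V

rng = np.random.default_rng(0)
l, d_k = 6, 4
Q, K, V = (rng.standard_normal((l, d_k)) for _ in range(3))
M = local_window_mask(l, w=1)
A, Z = masked_attention(Q, K, V, M)
```

Note that this formulation still computes all `l x l` scores; it expresses the sparse pattern but does not by itself save any computation.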

2.3 Dynamic Sparse Patterns in Attention

A common motivation of sparse attention methods is that not all attention weights, i.e., probabilities, are equally important in Eq. (3). A large portion of attention weights do not contribute to the attention output and are redundant. In other words, only a small portion of attention weights are useful. However, we find that sparse patterns in attention are inherently dynamic and data-dependent.

Figure 1: Visualization of attention weights from different inputs and attention heads. Only a small fraction of the attention weights are important. Note that values are clamped for display.

Here, we further support our hypothesis by showing the original attention weight matrix (after normalization) in Figure 1. The model used here is a vanilla Transformer and the benchmark is Text Classification from the Google Long-Range Arena (Tay et al., 2020b). Figure 1 indicates that only a small fraction of the attention weights have large magnitude, while a significant portion is near zero. We emphasize that this shows the raw attention weights without any enforced sparsity constraints or fine-tuning, which indicates that redundancy naturally exists in attention. In short, the attention mechanism focuses on a small set of important tokens.

More importantly, the attention weights have dynamic sparse patterns. As shown in Figure 1, the sparse patterns in attention weights are dynamically changing depending on the input sequence. Different heads in multi-head attention also have different sparse patterns. The characteristic of dynamic sparsity in attention weights motivates us to explore effective methods to eliminate the redundancy and save computations. Prior work on static or fixed sparse patterns cannot capture the dynamically changing attention weights.

3 Dynamic Sparse Attention

In Section 2, we showed that attention weights have intrinsic sparse patterns and that the positions of important attention weights change with the input sequence. While attention exhibits dynamic sparse patterns, obtaining these patterns efficiently and effectively remains challenging. We formulate the process of identifying sparse attention patterns as a prediction problem. The key challenge is to obtain an approximate attention predictor that can accurately find the sparse patterns while keeping the prediction overhead small.

Here, we present Dynamic Sparse Attention (DSA), which exploits sparsity in attention weights to reduce computations. The principle of our method is to search effectively for dynamic sparse patterns without enforcing strict and static constraints on attention, while keeping the search cost small. Our approach leverages a trainable approximation to predict sparse attention patterns. As shown in Figure 2, we use a prediction path based on low-rank transformation and low-precision computation. The prediction path processes input sequences in a way that is functionally similar to the query and key transformations but at much lower computational cost. Given prediction results that approximate the attention scores well, we can search for sparse patterns based on the magnitude of the prediction results.

Figure 2: (a) Standard full attention; (b) Dynamic sparse attention with approximation-based prediction and sparse computation.

3.1 Design of Prediction Path

We denote the attention scores as $S = QK^\top$ and omit the scaling factor for simplicity. As shown in Figure 2(a), two general matrix-matrix multiplication (GEMM) kernels and one softmax kernel consume the majority of computations in self-attention. We construct a pair of approximate query and key transformations in the prediction path to compute the approximate score $\tilde{S}$, as in

$$\tilde{Q}, \tilde{K} = XP\tilde{W}_Q,\; XP\tilde{W}_K \tag{5}$$

Here $P \in \mathbb{R}^{d \times k}$ is a sparse random projection matrix shared by both paths, and $\tilde{W}_Q, \tilde{W}_K \in \mathbb{R}^{k \times k}$ are the parameters for approximating the query and key.

Then, we have the approximate attention scores $\tilde{S} = \tilde{Q}\tilde{K}^\top$. From $\tilde{S}$, we can predict sparse attention masks $M$ using thresholds, where the threshold values are either fixed by tuning on the validation set or determined by searching. When $\tilde{S}$ approximates the accurate attention scores $S$ well, the large scores in $S$ are also large in $\tilde{S}$ with high probability. The resulting sparse attention weights $\bar{A}$ are used to multiply the value matrix $V$, similar to Eq. (3).
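The prediction path around Eq. (5) can be sketched in NumPy as follows. This is an illustration under stated assumptions rather than the authors' code: the projection is drawn Achlioptas-style from scaled {-1, 0, +1} entries, the approximate query and key weights are left at random initialization (in DSA they are trained), and the per-row threshold is chosen so that `keep` entries survive in each row.

```python
import numpy as np

def sparse_random_projection(d, k, rng):
    """Assumed construction: entries sqrt(3/k) * {-1, 0, +1} with probs 1/6, 2/3, 1/6."""
    signs = rng.choice([-1.0, 0.0, 1.0], size=(d, k), p=[1 / 6, 2 / 3, 1 / 6])
    return np.sqrt(3.0 / k) * signs

def predict_mask(X, P, Wq_t, Wk_t, keep):
    """Approximate scores in the reduced dimension, then threshold each row."""
    XP = X @ P                                   # shared projection, l x k
    Q_t, K_t = XP @ Wq_t, XP @ Wk_t              # Eq. (5): approximate query and key
    S_t = Q_t @ K_t.T                            # approximate attention scores, l x l
    # Per-row threshold: the keep-th largest approximate score in that row.
    thresh = np.sort(S_t, axis=-1)[:, -keep][:, None]
    M = (S_t >= thresh).astype(np.float64)       # binary mask of predicted positions
    return M, S_t

rng = np.random.default_rng(0)
l, d, k, keep = 32, 16, 4, 4                     # hypothetical toy sizes
X = rng.standard_normal((l, d))
P = sparse_random_projection(d, k, rng)          # fixed after initialization
Wq_t = rng.standard_normal((k, k))               # trainable in DSA; random here
Wk_t = rng.standard_normal((k, k))
M, S_t = predict_mask(X, P, Wq_t, Wk_t, keep)
```

The score product in the prediction path contracts over `k` instead of the full head dimension, which is where the low-rank saving comes from; the low-precision saving is orthogonal to it.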

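Optimization of Approximation. The random projection matrix $P$ is constant after initialization and shared by the two approximate transformations. We obtain the trainable parameters, $\tilde{W}_Q$ and $\tilde{W}_K$, by minimizing the mean squared error (MSE) between the accurate and approximate attention scores:

$$\mathcal{L}_{MSE} = \frac{1}{B}\,\lVert S - \tilde{S} \rVert_2^2 = \frac{1}{B}\,\lVert QK^\top - \tilde{Q}\tilde{K}^\top \rVert_2^2 \tag{6}$$

where $B$ is the mini-batch size.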
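The MSE criterion of Eq. (6) can be written as a short function. The sketch below assumes the accurate and approximate scores are stacked along a leading batch axis of size `B`; the function name and tensor layout are illustrative choices.

```python
import numpy as np

def mse_loss(S, S_tilde):
    """Squared error between accurate and approximate scores, divided by the batch size.

    S and S_tilde are assumed to have shape (B, l, l).
    """
    B = S.shape[0]
    return float(np.sum((S - S_tilde) ** 2) / B)

# Tiny check: two 2 x 2 score matrices, every entry off by exactly 1.
S = np.zeros((2, 2, 2))
S_tilde = np.ones((2, 2, 2))
loss = mse_loss(S, S_tilde)   # 8 squared errors of 1.0, divided by B = 2
```

In training, the gradient of this loss would flow into the approximate query and key parameters through an automatic differentiation framework; that part is omitted here.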

Given the motivation of finding dynamic sparse patterns, the hypothesis of our method is that there exist oracle sparse patterns that perform well. The optimization target is therefore to approximate the full attention scores well enough to predict these sparse patterns. We further report the results of applying oracle sparse patterns by directly dropping small-magnitude attention weights during inference, without fine-tuning the model. As listed in Table 1, around 90% (up to 97%) of small attention weights can be dropped with negligible accuracy loss.

| Case | Sparsity | EM | F1 |
|------|----------|-------|-------|
| Base | 0% | 81.49 | 88.70 |
| | 75% - 95% | 81.50 | 88.70 |
| | 94% - 97% | 80.51 | 87.85 |

Table 1: Sparsity in attention weights, where small-magnitude values are set to zero. A significant portion of attention weights that have small magnitude are redundant. The accuracy metrics are Exact Match (EM) and F1 Score.
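The oracle experiment behind Table 1 amounts to zeroing small attention weights and measuring the resulting sparsity. A minimal NumPy sketch follows; the threshold name `theta` and the example matrix are hypothetical and only illustrate the mechanics.

```python
import numpy as np

def drop_small_weights(A, theta):
    """Zero out attention weights below theta and report the fraction of zeros."""
    A_sparse = np.where(A < theta, 0.0, A)
    sparsity = float(np.mean(A_sparse == 0.0))
    return A_sparse, sparsity

# Two rows of already-normalized attention weights (made-up values).
A = np.array([[0.90, 0.05, 0.03, 0.02],
              [0.01, 0.97, 0.01, 0.01]])
A_sparse, sparsity = drop_small_weights(A, theta=0.04)
print(sparsity)  # 0.625, i.e. 5 of the 8 weights were dropped
```

Because a single threshold is applied to every row, the achieved sparsity varies with the input, which may be why the table reports ranges rather than single values.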

3.2 Model Adaptation

When attention scores are masked out to generate sparsity in attention, the remaining attention weights, i.e., the important weights, are scaled up because the softmax denominator becomes smaller. Leaving the disturbed attention weights intact will degrade model quality. As a countermeasure, we propose to fine-tune the model parameters under dynamic sparse constraints, which we refer to as model adaptation. With adaptation, the model accuracy recovers to be on par with the full attention baselines, while the computational costs are significantly reduced.
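The scaling effect described above is easy to see on a single row of scores. The sketch below uses made-up scores and compares the softmax over all positions with the softmax over only the kept positions.

```python
import numpy as np

def softmax(scores):
    """Softmax over a 1-D score vector; -inf entries receive zero weight."""
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()

scores = np.array([2.0, 1.0, 0.0, -1.0])        # hypothetical attention scores for one row
keep = np.array([True, True, False, False])     # predicted important positions

full = softmax(scores)                               # weights under full attention
masked = softmax(np.where(keep, scores, -np.inf))    # weights after masking

# The kept weights grow because the denominator now sums fewer terms.
growth = masked[keep] / full[keep]
```

Both kept weights are multiplied by the same factor, so their relative order is preserved but their absolute magnitude differs from what the downstream layers saw during dense training.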

We do not change the computational graph or the loss function of the original model, except for adding dynamic sparse constraints to attention as the mask $M$. As a result, the new attention weights $\bar{A}$ are sparse and only contain the important weights from prediction. Given a pre-trained model, our method jointly fine-tunes the model parameters and the parameters of the prediction path as in

$$\mathcal{L} = \mathcal{L}_{Model} + \lambda \mathcal{L}_{MSE} \tag{7}$$

where $\lambda$ is the regularization factor of the MSE. Our method can also train from scratch with randomly initialized model parameters.

Our method approximates the original attention score $S$ with a low-rank matrix $\tilde{S}$. When training the model with the MSE loss in Eq. (6), the gradient from $\mathcal{L}_{MSE}$ is passed to both the low-rank approximation $\tilde{S}$ and the original attention score $S$. Intuitively, this loss function not only makes $\tilde{S}$ a better approximation of $S$, but also makes $S$ easier to approximate with a low-rank matrix, i.e., by reducing the rank of $S$. On the other hand, the loss $\mathcal{L}_{Model}$ keeps the rank of $S$ high enough to preserve model accuracy. In other words, the joint optimization of $\mathcal{L}_{Model}$ and $\mathcal{L}_{MSE}$ implicitly learns a low-rank $S$ whose rank depends on the difficulty of the task. Our design brings two advantages. First, the rank of $S$ is automatically adjusted to tasks with different difficulty levels. Hence, our method can potentially achieve higher accuracy on difficult tasks and higher speedup on simple tasks compared with low-rank approximation methods that use a fixed rank. Second, as the rank of $\tilde{S}$ only implicitly influences the rank of $S$, the final result is less sensitive to the hyper-parameter $k$.

3.3 Computation Saving Analysis

DSA introduces additional computations in the prediction step, but the overall computation saving from the sparse attention kernels is substantial and can yield practical speedup. The original full attention takes $O(l^2 d)$ MACs (multiply-and-accumulate operations) asymptotically. However, the asymptotic analysis does not consider practical concerns such as sparsity, quantization, and data reuse. Here, we augment the traditional asymptotic analysis with a sparsity factor $\alpha$ and a quantization factor $\beta$. In this way, DSA prediction takes $O(\beta l^2 k)$ MACs and DSA attention takes $O((1 - \alpha) l^2 d)$ MACs. Both $\alpha$ and $\beta$ are determined by the task and the underlying hardware platform. In our settings, we choose $\alpha$ between 90% and 98%, and our GPU kernels can achieve practical speedups. We assume the baseline model uses FP32 as the compute precision and set the prediction precision to INT4. The execution time of softmax is not revealed by the asymptotic analysis, but it is one of the major time-consuming components. Our method can also reduce the time of the softmax kernel by using the same sparse attention patterns.
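The augmented analysis can be turned into a rough estimator. The sketch below is an assumption-laden illustration: it treats the sparsity factor `alpha` as the fraction of attention weights that are dropped, treats the quantization factor `beta` as the relative cost of a low-precision MAC (the value 0.125 is a hypothetical placeholder for INT4 versus FP32), and uses made-up sizes for `l`, `d`, and `k`.

```python
import math

def dense_attention_macs(l, d):
    """Full attention: l*l*d MACs for the scores plus l*l*d MACs for the output."""
    return 2 * l * l * d

def dsa_macs(l, d, k, alpha, beta):
    """Estimated MACs for sparse attention and for the low-precision prediction."""
    sparse_attention = (1.0 - alpha) * dense_attention_macs(l, d)  # only kept weights
    prediction = beta * l * l * k            # approximate scores in reduced dimension k
    return sparse_attention, prediction

l, d, k = 4000, 256, 64                      # hypothetical sizes
dense = dense_attention_macs(l, d)
sparse, pred = dsa_macs(l, d, k, alpha=0.95, beta=0.125)
attention_saving = dense / (sparse + pred)   # ratio for the attention part only
```

This estimate covers only the quadratic attention part; the linear projections and feed-forward layers are unaffected by DSA and dilute the end-to-end saving.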

3.4 Implications for Efficient Deployment

Compared with standard attention, DSA exhibits two new features that can potentially affect model deployment. Firstly, a light-weight prediction path is attached to the attention layer to search for dynamic sparse patterns. The prediction involves an approximation of the attention scores, which is essentially a low-precision matrix-matrix multiplication (GEMM). While NVIDIA GPUs with Tensor Cores support data precision as low as INT8 and INT4, DSA prediction can tolerate INT2 computation on certain benchmarks. Therefore, specialized hardware is preferable when seeking ultra-efficient attention estimation. In Section 5, we introduce two types of architectures to support multi-precision computations.
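To illustrate what a low-precision prediction GEMM might look like, the sketch below quantizes the approximate query and key with a symmetric per-tensor scheme and multiplies them in integer arithmetic. The quantization scheme is an assumption made for this example; the text does not specify the one used by DSA.

```python
import numpy as np

def quantize_symmetric(x, bits):
    """Map x to signed integers in [-qmax, qmax] with one scale per tensor (assumed scheme)."""
    qmax = 2 ** (bits - 1) - 1                        # 7 for INT4, 1 for INT2
    scale = max(float(np.abs(x).max()), 1e-12) / qmax
    q = np.clip(np.round(x / scale), -qmax, qmax).astype(np.int32)
    return q, scale

def approx_scores_lowprec(Q_t, K_t, bits=4):
    """Integer GEMM on quantized operands, rescaled back to floating point."""
    q_q, s_q = quantize_symmetric(Q_t, bits)
    q_k, s_k = quantize_symmetric(K_t, bits)
    return (q_q @ q_k.T).astype(np.float64) * (s_q * s_k)

rng = np.random.default_rng(0)
Q_t = rng.standard_normal((16, 4))                    # hypothetical approximate query
K_t = rng.standard_normal((16, 4))                    # hypothetical approximate key
q, scale = quantize_symmetric(Q_t, bits=4)
S_lowprec = approx_scores_lowprec(Q_t, K_t, bits=4)
```

Since the predictor only needs the relative order of scores within a row, a coarse grid of values can still be enough to locate the large entries.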

Secondly, the predicted sparse patterns can be used to eliminate unnecessary attention computations. In other words, instead of computing $QK^\top$ and $AV$ as two dense GEMM operations, we can reformulate $QK^\top$ as a sampled dense-dense matrix multiplication (SDDMM) and $AV$ as a sparse matrix-matrix multiplication (SpMM). When processing SDDMM and SpMM kernels on GPUs, low data reuse is the key factor that limits their performance compared with GEMM. Therefore, we extend DSA to support structural sparsity, which improves the data reuse of both the SDDMM and SpMM kernels. We implement customized kernels that take advantage of sparsity locality to improve kernel performance, achieving practical runtime speedup on an NVIDIA V100 GPU. We also present our choice of structural sparsity pattern and show that DSA is able to maintain the model's expressive power under the extra constraints.

As for specialized hardware, the advantage of DSA can be fully exploited because the specialized architecture and dataflow are able to handle fine-grained sparsity, thereby achieving the optimal sparsity ratio and computation reduction. However, a challenge also arises because irregular sparsity causes load imbalance and under-utilization of processing elements. Moreover, instead of independently executing SDDMM and then SpMM, we point out that more optimization opportunities can be explored when considering the whole process as a two-step SDDMM-SpMM chain. Please refer to Section 5 for more architectural design details and experimental results.
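The SDDMM and SpMM reformulation can be checked against dense masked attention with a small reference implementation. The NumPy sketch below stores the kept positions in coordinate form; it illustrates the arithmetic only and says nothing about how the customized GPU kernels are organized. All function names are illustrative, and the scaling factor is omitted as in the text.

```python
import numpy as np

def sddmm(Q, K, rows, cols):
    """Sampled dense-dense matmul: compute (Q K^T)[i, j] only at the listed (i, j)."""
    return np.einsum("nd,nd->n", Q[rows], K[cols])

def sparse_row_softmax(vals, rows, l):
    """Row-wise softmax over the stored entries only."""
    row_max = np.full(l, -np.inf)
    np.maximum.at(row_max, rows, vals)          # per-row maximum of stored scores
    exp = np.exp(vals - row_max[rows])
    row_sum = np.zeros(l)
    np.add.at(row_sum, rows, exp)               # per-row normalizer
    return exp / row_sum[rows]

def spmm(weights, rows, cols, V, l):
    """Sparse attention weights (coordinate form) times the dense value matrix."""
    Z = np.zeros((l, V.shape[1]))
    np.add.at(Z, rows, weights[:, None] * V[cols])
    return Z

def sparse_attention(Q, K, V, M):
    l = Q.shape[0]
    rows, cols = np.nonzero(M)
    vals = sddmm(Q, K, rows, cols)
    weights = sparse_row_softmax(vals, rows, l)
    return spmm(weights, rows, cols, V, l)

def dense_masked_attention(Q, K, V, M):
    """Reference: compute all scores, then mask before the softmax."""
    scores = np.where(M > 0, Q @ K.T, -np.inf)
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (exp / exp.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
l, d = 16, 8
Q, K, V = (rng.standard_normal((l, d)) for _ in range(3))
# Random mask with the diagonal kept so that every row has at least one entry.
M = (rng.random((l, l)) < 0.2) | np.eye(l, dtype=bool)
Z_sparse = sparse_attention(Q, K, V, M)
Z_dense = dense_masked_attention(Q, K, V, M)
```

The gather operations `Q[rows]`, `K[cols]`, and `V[cols]` are where irregular memory access appears, which is consistent with the data-reuse concern raised above.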

4 Evaluation

In this section, we evaluate the performance of DSA on representative benchmarks from Long-Range Arena (Tay et al., 2020b). We first compare the model accuracy of DSA with dense vanilla transformers and other efficient transformer models. Then, we present a sensitivity study over different configurations of the prediction path. By choosing different numbers of prediction parameters, DSA is able to achieve flexible trade-offs between computational cost and model accuracy. Finally, we study the model efficiency of DSA by analyzing the computational cost (MACs) and relative energy consumption.

4.1 Experiment Settings

The datasets used are from Long-Range Arena (LRA), which is a benchmark suite for evaluating model quality under long-sequence scenarios. In LRA, different transformer models are implemented using the JAX (Bradbury et al., 2018) API and optimized with just-in-time (jit) compilation. We implement DSA on top of the vanilla transformer provided by LRA and compare it with the other models included in LRA. Specifically, the self-attention layer in the vanilla transformer is augmented by the DSA method as described in Section 3. All other model configurations are kept the same for a fair comparison.

We incorporate three tasks from the LRA benchmark in our experiment, including Text Classification, Document Retrieval, and Image Classification. The Long ListOps and Pathfinder tasks are excluded. We provide benchmark descriptions and experiment configurations in Appendix A.

4.2 Accuracy

Figure 3 presents the overall model accuracy of DSA on different LRA tasks. In this experiment, the DSA model is fine-tuned from a pretrained vanilla transformer by jointly updating the model parameters and prediction parameters using the combined model loss and MSE loss (Eq. 7). The percentage numbers indicate the sparsity ratio that we applied to the DSA models. For instance, DSA-90% means that we only keep 10% of the attention weights in each row of the attention matrix, while masking out the other 90% of the weights. The sparsity ratio constraint is uniform across all heads and attention layers in the DSA model.

Figure 3: Overall model accuracy of DSA (fine-tuned from a pretrained checkpoint) compared with vanilla dense transformer.

As shown in Figure 3, for all the evaluated tasks, the dense transformer possesses a considerable amount of redundancy in the attention matrix under the long-sequence condition, which supports our previous claim in Section 2. Specifically, we can safely mask out up to 95% of the attention weights without suffering any accuracy degradation. In fact, by jointly optimizing the model parameters to adapt to dynamic sparse attention, DSA delivers slightly higher performance at 90% and 95% sparsity. Even with up to 99% sparsity, DSA still demonstrates promising performance, with a negligible accuracy drop compared with the dense baseline.

| Model | Text | Retrieval | Image | Avg |
|-----------------|-------|-------|-------|-------|
| Transformer | 65.12 | 62.5 | 42.74 | 56.79 |
| Local Attention | 52.98 | 53.39 | 41.46 | 50.89 |
| Sparse Trans. | 63.58 | 59.59 | 44.24 | 55.80 |
| Longformer | 62.85 | 56.89 | 42.22 | 53.99 |
| Linformer | 53.94 | 52.27 | 38.56 | 48.26 |
| Reformer | 56.10 | 53.40 | 38.07 | 49.19 |
| Sinkhorn Trans. | 61.20 | 53.83 | 41.23 | 52.09 |
| Synthesizer | 61.68 | 54.67 | 41.61 | 52.65 |
| BigBird | 64.02 | 59.29 | 40.83 | 54.71 |
| Linear Trans. | 65.90 | 53.09 | 42.34 | 53.78 |
| Performer | 65.40 | 53.82 | 42.77 | 54.00 |
| DSA-90% | 65.62 | 63.07 | 43.75 | 57.48 |

Table 2: Accuracy of different Transformer models on the LRA benchmark suite (Tay et al., 2020b). For a fair comparison, we follow the instructions in LRA and train our model from scratch. DSA-90% uses a fixed random projection scale and INT4 quantization.

To compare fairly with the other transformer variants provided by LRA, we follow several training constraints in our experiment. For example, instead of fine-tuning from a pretrained baseline, the DSA model used in the comparison is obtained from a randomly initialized model, i.e., trained from scratch. We also fix the other model parameters (e.g., number of layers, number of heads, hidden dimension) and training configurations (e.g., total training steps). The results are shown in Table 2. We use DSA-90% with INT4 quantization precision and a fixed random projection dimension scale. As the table shows, DSA achieves first-tier performance in all three tasks and delivers the leading average score on the LRA benchmarks.

Figure 4: Oracle attention mask generated by top-k selection.
Figure 5: Sparse attention mask generated by DSA prediction.

This encouraging performance mainly comes from two aspects. Firstly, joint optimization ensures that the DSA model adapts well to the sparse attention patterns when computing the attention output. Secondly, the trainable prediction path is able to accurately capture the input-dependent patterns. Figure 4 shows the oracle sparse patterns of four different input sequences obtained from top-k selection over the original full attention matrix. The yellow dots indicate the important positions in the attention matrix, while the purple region is masked out. Figure 5 shows the sparsity patterns generated by DSA prediction. Comparing the two figures horizontally, the sparse attention pattern changes with the input sequence. Comparing them vertically, the predicted patterns are very close to the oracle patterns. In our experiments, the prediction accuracy is around %.

To make sure that the high performance of DSA comes from the proposed approach rather than the task itself, we further test two cases on the Text Classification dataset. Firstly, we apply a 99% sparsity constraint to the vanilla transformer, but with a static local attention pattern. Secondly, we use a short sequence with dense attention, where the total number of tokens in the short sequence matches the number of important tokens in the long-sequence scenario. These two cases perform very poorly on the task, delivering model accuracies of only 53.24% and 54.16%, compared with the 64.04% accuracy achieved by DSA-99%. This further supports our previous discussion.

4.3 Design Space Exploration of Prediction Path

One of the most important design choices of DSA is the configuration of the prediction path. Overall, we want the predictor to accurately capture dynamic sparse patterns. However, we also want to minimize the cost of prediction while maintaining DSA model accuracy. Thus, while we involve trainable parameters for prediction, we also introduce the random projection matrix to control the number of prediction parameters, and use low-precision computation to reduce the overhead. Here, we present the sensitivity results for different choices of the reduced dimension size and quantization precision.

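| Projection scale | 0.1 | 0.16 | 0.2 | 0.25 | 0.33 | 0.4 | Baseline |
|------------------|-------|-------|-------|-------|-------|-------|-------|
| DSA-90% | 65.32 | 65.25 | 65.17 | 65.46 | 65.63 | 65.54 | 65.12 |

| Quantization | Random | INT2 | INT4 | INT8 | INT16 | FP32 | Baseline |
|--------------|-------|-------|-------|-------|-------|-------|-------|
| DSA-90% | 60.42 | 64.23 | 65.56 | 65.69 | 65.63 | 65.63 | 65.12 |

Table 3: Change of DSA-90% model accuracy when sweeping the random projection scale and the quantization precision.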

We first sweep over different sizes of the reduced dimension and evaluate the accuracy of DSA-90% on the LRA Text Classification task. Here, we use the random projection scale $\sigma$, i.e., the ratio of the reduced dimension to the hidden dimension, to represent the size of the predictor. A larger $\sigma$ indicates more prediction parameters and better representation power, but also larger computation overhead. As Table 3 shows, DSA demonstrates relatively stable performance across different $\sigma$ values. Even with $\sigma = 0.1$, DSA-90% still achieves slightly higher accuracy than the vanilla transformer. We believe this is because we use the predictor only to indicate the positions of the important attention weights, while passing the accurate attention weights to the output. Therefore, our predictor module can tolerate highly approximate computation as long as it captures the relative importance within the attention matrix.

To further study the performance and the impact of the predictor, we conduct another experiment that sweeps over different quantization precisions while fixing $\sigma$ to 0.25. As shown in Table 3, DSA-90% achieves good accuracy with quantized precision as low as 4-bit. Accuracy degradation occurs when the precision is further scaled down to 2-bit. Looking deeper into the predictor module, we collect and show the prediction accuracy in each attention block of this 4-layer DSA model. The prediction accuracy is defined as the percentage of correct guesses among the total number of predictions. For example, for a DSA-90% model working on a sequence length of 2000, the predictor marks 200 positions as important in each row of the attention matrix. If 100 of these 200 positions actually match the top-k results, the prediction accuracy is 50%. As shown in Figure 6, the predictor is able to maintain its prediction accuracy even with 4-bit quantization. When the precision is 2-bit, the prediction accuracy suffers a significant degradation. Despite this, the overall model accuracy is acceptable, with only 0.89% degradation compared with the baseline transformer. We believe this is because, for the binary Text Classification task, it is more crucial to capture the very few most important attention weights. Although the prediction accuracy becomes lower, the most important positions are preserved, which maintains the overall model accuracy. Finally, in Figure 6 and Table 3 we include a special case of randomly selecting 10% of the positions as important. With this random mask applied to the model, the prediction accuracy is less than 10%, and the overall model accuracy drops to 60.42%. This result supports our previous analysis.
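The prediction accuracy metric defined above can be computed from two binary masks. The sketch below reproduces the worked example from the text (200 predicted positions in a row of length 2000, of which 100 coincide with the top-k positions); the helper names are illustrative.

```python
import numpy as np

def topk_mask(S, keep):
    """Oracle mask: mark the `keep` largest scores in every row."""
    idx = np.argpartition(S, -keep, axis=-1)[:, -keep:]
    mask = np.zeros(S.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    return mask

def prediction_accuracy(pred_mask, oracle_mask):
    """Fraction of predicted positions that also appear in the oracle mask."""
    pred = np.asarray(pred_mask, dtype=bool)
    oracle = np.asarray(oracle_mask, dtype=bool)
    return float((pred & oracle).sum() / pred.sum())

# Worked example: one row of length 2000 with 200 positions kept (90% sparsity).
l = 2000
oracle = np.zeros((1, l), dtype=bool)
oracle[0, :200] = True            # pretend the top-k positions are 0..199
pred = np.zeros((1, l), dtype=bool)
pred[0, 100:300] = True           # predictor picks 100..299, overlapping on 100 positions
acc = prediction_accuracy(pred, oracle)
print(acc)  # 0.5
```

A random mask at 90% sparsity would be expected to land near 10% under this metric, which matches the random-selection case reported above.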

Figure 6: The prediction accuracy of DSA in a 4-layer DSA-90% model with different quantization precision.

4.4 Model Efficiency

As mentioned earlier, DSA has the potential to significantly reduce the computation and memory consumption of the self-attention layer, which is especially beneficial when deploying a long-sequence transformer model at inference time. While we acknowledge that the actual runtime performance and memory footprint depend largely on the underlying hardware implementation, in this subsection we shed light on this problem by quantitatively analyzing the cost of DSA.

We start by presenting the number of required MAC operations for each attention layer. We use the MAC count as the computational cost metric because the majority of the operations in the self-attention layer are matrix multiplications. We break down the total MAC operations into three parts: (1) Linear: general matrix-matrix multiplication (GEMM) for computing the query, key, and value. (2) Attention: GEMM for computing the attention weight matrix and the output values. (3) Other: other GEMMs inside the attention block, such as the feed-forward layers. As introduced earlier, the two GEMM operations in part (2) scale quadratically with the sequence length, and we transform them into SDDMM and SpMM in our DSA model to reduce both computation and memory consumption. Based on this setting, the computational cost breakdown of the different models used in our LRA experiment is shown in Figure 7. Comparing across tasks, those with longer sequence length (Text and Retrieval) are more bounded by the Attention part. The benefit of using DSA is also more significant on the 4K tasks. Comparing within each task, a DSA model with a higher sparsity ratio delivers higher computation savings. Overall, DSA achieves computation reduction without any accuracy degradation.
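The three-part breakdown can be approximated with a few lines of arithmetic. The sketch below is a rough model under assumptions not stated in the text: a single hidden width `d` shared by all projections, one output projection, and a two-layer feed-forward block of width `d_ff`. The sizes are hypothetical and do not correspond to the LRA configurations.

```python
import math

def layer_macs(l, d, d_ff, sparsity=0.0):
    """Approximate MACs per attention block, split into the three parts named above."""
    linear = 3 * l * d * d                             # query, key, value projections
    attention = (1.0 - sparsity) * 2 * l * l * d       # scores plus output, kept weights only
    other = l * d * d + 2 * l * d * d_ff               # output projection plus feed-forward
    return {"linear": linear, "attention": attention, "other": other}

def attention_share(macs):
    """Fraction of the block's MACs spent in the quadratic attention part."""
    return macs["attention"] / sum(macs.values())

d, d_ff = 256, 1024                                    # hypothetical model sizes
share_1k = attention_share(layer_macs(1000, d, d_ff))
share_4k = attention_share(layer_macs(4000, d, d_ff))
dense_4k = layer_macs(4000, d, d_ff)
dsa90_4k = layer_macs(4000, d, d_ff, sparsity=0.9)
total_reduction = sum(dense_4k.values()) / sum(dsa90_4k.values())
```

Under this model the attention share grows with the sequence length, so the same sparsity ratio yields a larger total reduction on longer inputs, in line with the trend described for Figure 7.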

Figure 7: Computational cost measured in the number of MACs.

Note that we do not include the computation overhead of the prediction path for generating the sparsity mask. This is because the computations in the prediction path are performed in reduced precision rather than full precision, and it is inappropriate to directly convert the number of low-precision MACs into the number of FP32 MACs. Therefore, we use relative energy consumption to illustrate the overall cost of DSA-augmented attention. Figure 8 shows the relative energy consumption of DSA-95% with INT4 quantization. The energy cost of each INT4 MAC is projected relative to that of an FP32 MAC, where the factor is taken from an industry-level simulator (Tang et al., 2021) with 45nm technology. The figure shows that, even with the predictor overhead considered, the overall benefit is still compelling by virtue of the high dynamic sparsity ratio and the low-cost prediction methodology.

Figure 8: Relative energy consumption projected to vanilla transformer.
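The energy projection can be sketched as a weighted MAC count, where each low-precision predictor MAC is scaled by its energy relative to an FP32 MAC. The function and every number in the example are hypothetical placeholders chosen for illustration; they are not the simulator-derived factor or the MAC counts used for Figure 8.

```python
def relative_energy(attn_macs, other_macs, sparsity, predictor_macs, int4_factor):
    """Energy of a DSA layer relative to its dense counterpart.

    `int4_factor` is the energy of one INT4 MAC divided by the energy of one
    FP32 MAC; it must come from an energy model and is an input here.
    """
    dense = attn_macs + other_macs
    sparse = (
        attn_macs * (1.0 - sparsity)      # attention MACs that survive the mask
        + other_macs                      # unchanged full-precision GEMMs
        + predictor_macs * int4_factor    # low-precision prediction path
    )
    return sparse / dense


# Placeholder numbers only, to show how the terms combine.
example = relative_energy(attn_macs=100, other_macs=50, sparsity=0.95,
                          predictor_macs=40, int4_factor=0.05)
print(f"relative energy: {example:.2f}")
```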

5 Algorithm-Hardware Co-design

In Section 4.4, we analyzed the potential of DSA to reduce the total cost of the Transformer model. While the estimated number of MAC operations and the relative energy consumption present very promising results, it remains challenging to achieve practical speedup and energy reduction on real hardware systems. In this section, we dive deeper into this problem and discuss the implementation of DSA on GPUs and accelerators. Specifically, we evaluate the challenge of mapping DSA onto different platforms, and we demonstrate the flexibility of DSA to enable efficient algorithm-hardware co-designs.

5.1 GPU Acceleration with DSA

Given the predicted sparse patterns, we can reformulate the attention-score GEMM as a sampled dense-dense matrix multiplication (SDDMM) and the attention-weighted Value GEMM as a sparse matrix-matrix multiplication (SpMM). Under fine-grained sparsity, a recent work (Gale et al., 2020) proposes SpMM and SDDMM kernels that outperform the dense GEMM kernel under and sparsity, respectively. Besides, cuSPARSE (Naumov et al., 2010) also achieves practical speedup at sparsity for single-precision data. As we presented in Section 4, DSA can easily deliver a sparsity ratio of more than with zero accuracy degradation, therefore enabling faster kernel implementations on GPUs.
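To make the two sparse operations concrete, the following pure-Python sketch computes attention scores only at the positions selected by a mask (SDDMM) and then multiplies the resulting sparse matrix by the Value matrix (SpMM). It is a reference illustration, not a GPU kernel; the function names and the per-row list/dict layout are assumptions made for readability, and softmax is omitted for simplicity.

```python
def sddmm(q, k, mask):
    """Sampled dense-dense matmul: evaluate q @ k^T only where the mask says so.

    mask[i] lists the key indices kept for query row i.
    Returns one {column: score} dict per query row.
    """
    return [
        {j: sum(a * b for a, b in zip(q[i], k[j])) for j in cols}
        for i, cols in enumerate(mask)
    ]


def spmm(sparse_rows, v):
    """Sparse-dense matmul: multiply the sparse score rows by the dense matrix v."""
    dim = len(v[0])
    out = []
    for row in sparse_rows:
        acc = [0.0] * dim
        for j, weight in row.items():
            for c in range(dim):
                acc[c] += weight * v[j][c]
        out.append(acc)
    return out


q = [[1, 0], [0, 1]]            # 2 queries
k = [[1, 2], [3, 4], [5, 6]]    # 3 keys
v = [[1, 0], [0, 1], [1, 1]]    # 3 values
mask = [[0, 2], [1]]            # positions predicted to be important

scores = sddmm(q, k, mask)
output = spmm(scores, v)
print(scores, output)
```

Only three of the six possible scores are ever computed here, which is where both the compute and the memory savings come from.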

While fine-grained sparse GPU kernels are able to outperform their dense counterparts at relatively high sparsity ratios, the speedup is significantly limited by low data reuse. Moreover, when half precision (FP16) is used for computation, the above fine-grained kernels can hardly compete with the GEMM kernel, as the NVIDIA Tensor Core provides much higher throughput for half-precision matrix multiplication. Thus, we find that the performance gain on sparse matrix multiplication can hardly offset the overhead of computing the prediction path in DSA, especially in the half-precision scenarios that are common at inference. To tackle this problem, structural dynamic sparsity can be introduced into the attention selection. Specifically, instead of selecting independent attention weights, we can enforce block-wise and vector-wise constraints. A trade-off can also be made by adjusting the block size, as larger blocks deliver higher speedup but can potentially cause accuracy loss.

Figure 9: Column-vector sparse encoding Chen et al. (2021).

In our work, we experiment with vector sparsity using the Text Classification benchmark. As shown in Figure 9, we choose column-vector sparse encoding, where the attention elements are pruned at column-vector granularity. Column-vector sparsity provides the same data reuse as block sparsity, but its smaller granularity makes it friendlier to model training Chen et al. (2021). Table 4 gives the corresponding kernel speedup and model accuracy under a 90% sparsity ratio. The data type is FP16 for vector sparsity and FP32 for fine-grained sparsity. As we can see, DSA can be flexibly combined with different sparsity patterns, achieving practical runtime speedup on GPU while maintaining model accuracy on par with full attention.

Sparsity Pattern | vec 14 | vec 18 | Fine-grained
SpMM Speedup | 1.57 | 1.94 | 1.85
SDDMM Speedup | 0.94 | 1.15 | 1.09
Accuracy (%) | -0.02 | -0.1 | +0.5
Table 4: Model accuracy and kernel speedup over cuBLAS Hgemm. We implement customized SDDMM/SpMM kernels for vector sparsity and reuse the kernel in Gale et al. (2020) for fine-grained sparsity. Experiments are done on an NVIDIA V100 GPU.
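The column-vector constraint can be illustrated with a small mask-selection sketch. Rows are grouped into chunks of `vec_len` consecutive queries, and every row in a chunk keeps the same columns, so a loaded key column is reused across the chunk. The function name, the use of summed predicted scores to rank columns, and the `keep_per_group` parameter are assumptions for illustration; the source does not specify how vectors are ranked.

```python
def column_vector_mask(scores, vec_len, keep_per_group):
    """Select columns at column-vector granularity.

    scores: dense matrix of predicted attention scores (list of rows).
    Every group of `vec_len` consecutive rows shares one set of kept columns.
    """
    n_cols = len(scores[0])
    mask = []
    for start in range(0, len(scores), vec_len):
        group = scores[start:start + vec_len]
        # Rank each column by the total predicted score inside this row group
        # (hypothetical ranking rule).
        col_score = [sum(row[c] for row in group) for c in range(n_cols)]
        ranked = sorted(range(n_cols), key=lambda c: -col_score[c])
        kept = sorted(ranked[:keep_per_group])
        mask.extend(list(kept) for _ in group)
    return mask


predicted = [
    [0.9, 0.1, 0.0, 0.0],
    [0.8, 0.0, 0.2, 0.0],
    [0.0, 0.1, 0.1, 0.8],
    [0.1, 0.0, 0.7, 0.2],
]
vec_mask = column_vector_mask(predicted, vec_len=2, keep_per_group=1)
print(vec_mask)
```

In this toy input the last row would individually prefer column 2, but it inherits column 3 from its group. That is the kind of accuracy cost a coarser granularity can introduce.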

To shed some light on the results, we can trace back to the visualizations of the attention matrix in Figure 1. As shown by the figure, despite the sparse and dynamic characteristics of the attention matrix, the distribution of important attention connections exhibits a certain degree of locality. For example, there exist some global tokens that attend to most of the tokens within a sequence. Therefore, some columns of the attention matrix will contain many important positions. Besides, local attention also indicates row-wise locality, as a token is likely to be influenced by its neighbors. Therefore, row-vector sparsity can be added to DSA for performance/accuracy exploration as well. While these fixed locality patterns have been well discussed in prior work Zaheer et al. (2020); Beltagy et al. (2020), DSA illustrates the dynamic distribution which motivates us to propose the prediction path to efficiently locate these important connections.

Figure 10: Speedup of softmax with different sparsity ratios.

Sparse Softmax Computation. Under the long-sequence scenario, the softmax function could be a bottleneck. Let , , and be the number of heads, sequence length, and feature dimension of each head, respectively. Our profiling result shows that with , , , softmax contributes 47% of the total execution time of the multi-head self-attention layer. By sparsifying the attention matrix, DSA directly reduces both the memory access and the computation of the softmax function, and thus its execution time. We evaluate the latency of the PyTorch-implemented softmax function on an NVIDIA V100 GPU. Following the configuration in the Text Classification benchmark, we set batch size=16, , and enforce different sparsity ratios. Figure 10 shows that the reduced softmax achieves speedup compared with the dense softmax function.
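A softmax restricted to the selected positions can be sketched as follows. The per-row dictionary layout and the function name are assumptions for illustration; the point is that both the exponentials and the normalizing sum touch only the kept entries, so the work scales with the number of selected attention weights rather than with the full sequence length.

```python
import math


def sparse_softmax(sparse_scores):
    """Row-wise softmax over only the kept entries of each row.

    sparse_scores: one {column: score} dict per query row.
    """
    out = []
    for row in sparse_scores:
        if not row:                      # a row with no kept entries stays empty
            out.append({})
            continue
        peak = max(row.values())         # subtract the max for numerical stability
        exps = {j: math.exp(s - peak) for j, s in row.items()}
        total = sum(exps.values())
        out.append({j: e / total for j, e in exps.items()})
    return out


probs = sparse_softmax([{0: 1.0, 2: 1.0}, {1: 3.0}])
print(probs)
```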

5.2 Hardware Specialization for DSA

While adding structural constraints can potentially benefit GPU kernel implementation, the expressive power of the model is still inevitably affected. For instance, as shown in Table 4, the vector encoding achieves comparable accuracy with full-attention, but is lower than the accuracy of using fine-grained sparsity under the same sparsity ratio. Thus, an alternative approach is to use hardware specialization to fully exploit the potential saving from DSA.

The self-attention mechanism mainly involves matrix-matrix multiplication, which can be efficiently handled with a 2D spatial array of processing elements (PEs). Prior work also proposes efficient dataflows for computing the attention layer using techniques such as operator fusion, loop tiling, and loop reordering (Park et al., 2020; Kao et al., 2021). With DSA, the underlying spatial array and dataflow should be adjusted accordingly. Specifically, DSA poses two architectural implications as follows.

Multi-precision Computation. DSA relies on dimension reduction and quantization to control the overhead of attention prediction. Therefore, the system needs to handle both high-precision (e.g., FP32/FX16) and low-precision (e.g., INT2/INT4) computations. This can be implemented with either a decoupled design or a coupled design. In a decoupled architecture, standalone hardware modules are implemented for different precisions Liu et al. (2020), e.g., using two PE arrays for low-precision and high-precision computation. The two modules work in a pipelined manner, where the small PE array generates sparsity information for the large PE array to skip unnecessary computations. A drawback of this type of design is that the computation throughput is fixed for the two modules, but the relative workload between prediction and execution is task-dependent. As a result, one module may become idle from time to time due to workload imbalance. In contrast, a coupled PE array tackles this problem by using multi-precision arithmetic units Sharma et al. (2018). Specifically, the computation precision of each PE is configurable, such that different sections of the PE array can be dynamically controlled to balance the relative throughput. Yet, this requires runtime configuration, which makes system control more complicated.

Sparsity-aware Execution. As mentioned above, with DSA, we can reformulate the attention computation into an SDDMM followed by an SpMM (softmax is omitted for simplicity). In other words, the sparsity information is used as output sparsity (OS) for the first matrix multiplication and as input sparsity (IS) for the second matrix multiplication. However, different PEs may encounter workload imbalance due to irregularly distributed sparsity, causing low utilization. Prior work tackles this problem with multiple approaches. For example, one can enable an early-finished PE to switch to the execution of other elements, or apply offline sorting and online shuffling to balance the computation Song et al. (2018); Aklaghi et al. (2018); Gondimalla et al. (2019). These approaches impose different software and hardware overheads, such as larger scratchpad memory, redundant memory access, and higher bandwidth requirements. In DSA, we use a simple and effective solution by enforcing a row-wise constraint such that different attention rows contain the same number of important attention weights.
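The row-wise constraint amounts to a per-row top-k selection over the predicted scores, so every PE that owns a row receives exactly the same amount of work. The sketch below is illustrative only; the function name and the list-of-lists representation are assumptions.

```python
def row_topk_mask(pred_scores, keep):
    """Keep the `keep` highest-scoring columns in every row.

    Each row ends up with the same number of selected attention weights,
    so rows assigned to different PEs carry identical workloads.
    """
    mask = []
    for row in pred_scores:
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        mask.append(sorted(ranked[:keep]))
    return mask


pred = [
    [0.1, 0.7, 0.2, 0.0],
    [0.5, 0.1, 0.1, 0.3],
]
balanced = row_topk_mask(pred, keep=2)
print(balanced)
```

A global threshold over the whole matrix could instead leave one row with many entries and another with none, which is the imbalance the row-wise constraint avoids.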

Figure 11: Using sparsity locality and compute reordering to improve data reuse.

Finally, the locality of important attention weights provides opportunities for data reuse. As shown in Figure 11, if multiple rows of the attention matrix are computed simultaneously, the corresponding columns in matrix can be loaded once and shared by different PEs. Similarly, this holds true for computing matrix , as the rows in matrix can be reused. In this example, suppose four PEs work in parallel to compute the attention matrix, and each PE is responsible for one row. The colored squares are the selected attention weights, and the number in each square indicates the computation order within each PE. In the left figure, each PE computes its selected attention weights from left to right. Thus, although the sparsity distribution exhibits some locality, data reuse is poor. In contrast, if we reorder the computation within each row as shown in the right figure, we can utilize the column locality to improve data reuse. We evaluate the benefit of computation reordering on real benchmarks. As shown in Table 5, on the Text Classification task, the locality naturally brings memory access reduction compared with row-by-row processing, while reordering further improves this ratio to . On Image Classification, the reduction ratio is without reordering and with reordering.

An important benefit of this type of out-of-order execution is that matrix does not need to be reshuffled and matrix is still generated in a regular order. This advantage comes from the attention mechanism itself, because the whole computation process is a two-step GEMM chain. Therefore, the reordered is completely consumed during the second matrix multiplication. In contrast, exploring the same reordering in a CNN would require a crossbar-like design to correctly store the output result Liu et al. (2020), causing additional performance and resource overhead.
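The effect of reordering on operand loads can be illustrated with a toy counting model. It assumes one PE per row, lock-step execution, and that a column requested by several PEs in the same step is loaded once. The popularity-based ordering is a simple heuristic chosen for this sketch and is not claimed to be the scheduling used to produce Table 5.

```python
from collections import Counter


def loads_row_by_row(mask):
    """Rows processed one after another: every selected column is loaded anew."""
    return sum(len(cols) for cols in mask)


def loads_row_parallel(schedule):
    """Rows processed in lock-step: a column shared within a step is loaded once."""
    steps = max(len(cols) for cols in schedule)
    return sum(
        len({cols[t] for cols in schedule if t < len(cols)})
        for t in range(steps)
    )


def reorder_by_popularity(mask):
    """Visit columns selected by many rows first so that shared columns align."""
    popularity = Counter(c for cols in mask for c in cols)
    return [sorted(cols, key=lambda c: (-popularity[c], c)) for cols in mask]


# Four rows (one per PE); column 3 is selected by every row.
mask = [[0, 3, 5], [1, 3, 6], [2, 3, 7], [3, 4, 8]]

baseline = loads_row_by_row(mask)                            # no sharing at all
left_to_right = loads_row_parallel(mask)                     # sharing by chance
reordered = loads_row_parallel(reorder_by_popularity(mask))  # sharing by design
print(baseline, left_to_right, reordered)
```

In this toy case, left-to-right order aligns only three of the four uses of column 3, while the reordered schedule lets all four PEs consume it in the same step.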

Dataflow | Image | Text
row-by-row | |
row-parallel w/o compute reordering | |
row-parallel w/ compute reordering | |
Table 5: Memory access reduction of the second matrix operand used in the multiplication of and .

6 Related Work

Transformers that use the self-attention mechanism are difficult to scale with sequence length because of the quadratic time and memory complexity. Our paper focuses on the exploration of sparse attention patterns in Transformers. Other orthogonal approaches such as parameter sharing (Gong et al., 2019) can also mitigate the issue. We refer readers to a survey paper for a more comprehensive view of efficient Transformers (Tay et al., 2020c).

Static Sparse Patterns. A straightforward way to exploit attention sparsity is to set static or fixed sparse patterns, such as local windows, block-wise patterns, dilated patterns, or a combination of static patterns (Zaheer et al., 2020; Child et al., 2019; Qiu et al., 2020). However, as the sparse attention patterns are inherently dynamic and depend on the input sequences, these works lack the capability of capturing dynamic sparse patterns. As shown in our evaluation, the sparsity-saving trade-offs of representative methods using static sparse patterns are worse than those of our dynamic sparse attention approach.

Clustering-based methods. Building upon static block-sparse patterns, another line of research groups similar tokens into chunks and performs local attention within chunks (Kitaev et al., 2020; Roy et al., 2021; Tay et al., 2020a). The similarity function used to group tokens can be hashing, clustering, or learned sorting. However, those methods are designed for training memory reduction and are impractical at inference time, when they must operate on each sequence. The quality of grouping, e.g., convergence of clustering, is not guaranteed for long sequences, and the overhead of on-the-fly clustering is not acceptable.

Approximation methods. Recent work proposes to replace standard attention with forms of approximation of the attention weights (Wang et al., 2020; Katharopoulos et al., 2020; Choromanski et al., 2021; Peng et al., 2021). While we provide a comparison in our evaluation, we regard those works as outside the scope of our discussion on exploring sparsity in (standard) attention. Whether to replace standard attention with a form of approximation or, as we suggest, to predict sparse patterns explicitly is a design choice left to practitioners.

Attention and Transformer Accelerators. Recent work adopts algorithm and hardware co-design to reduce the cost of the attention mechanism. MnnFast Jang et al. (2019) proposes to skip the computations of based on the magnitude of the calculated attention scores. This method can only benefit the second GEMM of the attention layer. A³ Ham et al. (2020) introduces attention approximation to prune the unimportant attentions. However, A³ involves expensive online sorting, which causes significant performance and energy overhead. ELSA Ham et al. (2021) uses sign random projection to estimate the attention weights, making the approximation much more hardware efficient, but the model quality is hurt by the inaccurate approximation. In DSA, we address these limitations by simultaneously considering approximation accuracy and efficiency. Finally, SpAtten Wang et al. (2021) proposes cascade token pruning and head pruning to reduce the cost of both the self-attention block and the subsequent layers. While removing several rows and columns of the attention matrix makes the operation regular and hardware-friendly, we find this constraint to be too aggressive, as the locality of attention weights usually exists at small granularity.

7 Conclusion

In this paper, we present Dynamic Sparse Attention (DSA), a novel method that exploits dynamic sparse patterns in attention to reduce the computational cost of serving Transformers. Specifically, we show that our method can achieve up to 95% attention sparsity without loss of model inference quality. Unlike prior art that uses static sparse patterns in attention, our method explores the dynamic sparse patterns that are inherent in attention when processing different input sequences. Instead of replacing standard attention with other variants such as low-rank approximation methods, we augment standard attention with a prediction path as the means to locate dynamic sparsity. On the one hand, attention approximation can be very efficient when used only for sparsity prediction. On the other hand, the expressive power of full attention is preserved, as the important attention weights from full attention remain effective in model inference. Experimental results on the LRA benchmark demonstrate the superior performance and model efficiency of DSA. Furthermore, we demonstrate the potential of using DSA to improve hardware performance and efficiency. With customized kernel design and structural sparsity, DSA delivers practical speedup on GPU. The algorithmic benefit can be further exploited with specialized architectures, as the hardware can fully benefit from low-precision prediction, fine-grained sparse computation, and data locality.

References

  • V. Aklaghi, A. Yazdanbakhsh, K. Samadi, H. Esmaeilzadeh, and R. Gupta (2018) Snapea: predictive early activation for reducing computation in deep convolutional neural networks. Cited by: §5.2.
  • I. Beltagy, M. E. Peters, and A. Cohan (2020) Longformer: the long-document transformer. arXiv preprint arXiv:2004.05150. Cited by: §1, §1, §2.2, §5.1.
  • J. Bradbury, R. Frostig, P. Hawkins, M. J. Johnson, C. Leary, D. Maclaurin, G. Necula, A. Paszke, J. VanderPlas, S. Wanderman-Milne, and Q. Zhang (2018) JAX: composable transformations of Python+NumPy programs. External Links: Link Cited by: §4.1.
  • T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. (2020) Language models are few-shot learners. arXiv preprint arXiv:2005.14165. Cited by: §1.
  • Z. Chen, Z. Qu, L. Liu, Y. Ding, and Y. Xie (2021) Efficient tensor core-based gpu kernels for structured sparsity under reduced precision. Cited by: Figure 9, §5.1.
  • R. Child, S. Gray, A. Radford, and I. Sutskever (2019) Generating long sequences with sparse transformers. External Links: 1904.10509 Cited by: §1, §2.2, §6.
  • K. Choromanski, V. Likhosherstov, D. Dohan, X. Song, A. Gane, T. Sarlos, P. Hawkins, J. Davis, A. Mohiuddin, L. Kaiser, D. Belanger, L. Colwell, and A. Weller (2021) Rethinking attention with performers. External Links: 2009.14794 Cited by: §6.
  • Z. Dai, Z. Yang, Y. Yang, J. Carbonell, Q. V. Le, and R. Salakhutdinov (2019) Transformer-xl: attentive language models beyond a fixed-length context. External Links: 1901.02860 Cited by: §1.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186. Cited by: §1.
  • T. Gale, M. Zaharia, C. Young, and E. Elsen (2020) Sparse gpu kernels for deep learning. arXiv preprint arXiv:2006.10901. Cited by: §5.1, Table 4.
  • A. Gondimalla, N. Chesnut, M. Thottethodi, and T. Vijaykumar (2019) SparTen: a sparse tensor accelerator for convolutional neural networks. In Proceedings of the 52nd Annual IEEE/ACM International Symposium on Microarchitecture, pp. 151–165. Cited by: §5.2.
  • L. Gong, D. He, Z. Li, T. Qin, L. Wang, and T. Liu (2019) Efficient training of BERT by progressively stacking. In Proceedings of the 36th International Conference on Machine Learning, Proceedings of Machine Learning Research, Vol. 97, pp. 2337–2346. External Links: Link Cited by: §1, §6.
  • T. J. Ham, S. J. Jung, S. Kim, Y. H. Oh, Y. Park, Y. Song, J. Park, S. Lee, K. Park, J. W. Lee, and D. Jeong (2020) A³: accelerating attention mechanisms in neural networks with approximation. In 2020 IEEE International Symposium on High Performance Computer Architecture (HPCA), pp. 328–341. External Links: Document Cited by: §6.
  • T. J. Ham, Y. Lee, S. H. Seo, S. Kim, H. Choi, S. J. Jung, and J. W. Lee (2021) ELSA: hardware-software co-design for efficient, lightweight self-attention mechanism in neural networks. In 2021 ACM/IEEE 48th Annual International Symposium on Computer Architecture (ISCA), pp. 692–705. Cited by: §6.
  • H. Jang, J. Kim, J. Jo, J. Lee, and J. Kim (2019) Mnnfast: a fast and scalable system architecture for memory-augmented neural networks. In Proceedings of the 46th International Symposium on Computer Architecture, pp. 250–263. Cited by: §6.
  • S. Kao, S. Subramanian, G. Agrawal, and T. Krishna (2021) ATTACC the quadratic bottleneck of attention layers. ArXiv abs/2107.06419. Cited by: §5.2.
  • A. Katharopoulos, A. Vyas, N. Pappas, and F. Fleuret (2020) Transformers are rnns: fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pp. 5156–5165. Cited by: §6.
  • N. Kitaev, L. Kaiser, and A. Levskaya (2020) Reformer: the efficient transformer. In International Conference on Learning Representations, External Links: Link Cited by: §1, §6.
  • A. Krizhevsky (2012) Learning multiple layers of features from tiny images. University of Toronto, pp. . Cited by: §A.3.
  • L. Liu, Z. Qu, L. Deng, F. Tu, S. Li, X. Hu, Z. Gu, Y. Ding, and Y. Xie (2020) DUET: boosting deep neural network efficiency on dual-module architecture. In 2020 53rd Annual IEEE/ACM International Symposium on Microarchitecture (MICRO), Vol. , pp. 738–750. External Links: Document Cited by: §5.2, §5.2.
  • A. L. Maas, R. E. Daly, P. T. Pham, D. Huang, A. Y. Ng, and C. Potts (2011) Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, Portland, Oregon, USA, pp. 142–150. External Links: Link Cited by: §A.1.
  • M. Naumov, L. Chien, P. Vandermersch, and U. Kapasi (2010) Cusparse library. In GPU Technology Conference, Cited by: §5.1.
  • M. Ott, S. Edunov, D. Grangier, and M. Auli (2018) Scaling neural machine translation. arXiv preprint arXiv:1806.00187. Cited by: §1.
  • J. Park, H. Yoon, D. Ahn, J. Choi, and J. Kim (2020) OPTIMUS: optimized matrix multiplication structure for transformer neural network accelerator. Proceedings of Machine Learning and Systems 2, pp. 363–378. Cited by: §5.2.
  • N. Parmar, A. Vaswani, J. Uszkoreit, L. Kaiser, N. Shazeer, A. Ku, and D. Tran (2018) Image transformer. In International Conference on Machine Learning, pp. 4055–4064. Cited by: §1.
  • H. Peng, N. Pappas, D. Yogatama, R. Schwartz, N. Smith, and L. Kong (2021) Random feature attention. In International Conference on Learning Representations, External Links: Link Cited by: §6.
  • J. Qiu, H. Ma, O. Levy, S. W. Yih, S. Wang, and J. Tang (2020) Blockwise self-attention for long document understanding. External Links: 1911.02972 Cited by: §1, §6.
  • D. Radev, P. Muthukrishnan, and V. Qazvinian (2009) The acl anthology network corpus. Vol. 47, pp. . External Links: Document Cited by: §A.2.
  • A. Roy, M. Saffar, A. Vaswani, and D. Grangier (2021) Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics 9, pp. 53–68. Cited by: §1, §6.
  • H. Sharma, J. Park, N. Suda, L. Lai, B. Chau, V. Chandra, and H. Esmaeilzadeh (2018) Bit fusion: bit-level dynamically composable architecture for accelerating deep neural network. In 2018 ACM/IEEE 45th Annual International Symposium on Computer Architecture (ISCA), pp. 764–775. Cited by: §5.2.
  • H. Shi, J. Gao, X. Ren, H. Xu, X. Liang, Z. Li, and J. T. Kwok (2021) SparseBERT: rethinking the importance analysis in self-attention. arXiv preprint arXiv:2102.12871. Cited by: §2.2.
  • M. Song, J. Zhao, Y. Hu, J. Zhang, and T. Li (2018) Prediction based execution on deep neural networks. In 2018 ACM/IEEE 45th Annual International Symposium on Computer Architecture (ISCA), pp. 752–763. Cited by: §5.2.
  • T. Tang, S. Li, L. Nai, N. Jouppi, and Y. Xie (2021) NeuroMeter: an integrated power, area, and timing modeling framework for machine learning accelerators industry track paper. In 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA), Vol. , pp. 841–853. External Links: Document Cited by: §4.4.
  • Y. Tay, D. Bahri, L. Yang, D. Metzler, and D. Juan (2020a) Sparse sinkhorn attention. In International Conference on Machine Learning, pp. 9438–9447. Cited by: §6.
  • Y. Tay, M. Dehghani, S. Abnar, Y. Shen, D. Bahri, P. Pham, J. Rao, L. Yang, S. Ruder, and D. Metzler (2020b) Long range arena: A benchmark for efficient transformers. CoRR abs/2011.04006. External Links: Link, 2011.04006 Cited by: §A.1, §2.3, Table 2, §4.
  • Y. Tay, M. Dehghani, D. Bahri, and D. Metzler (2020c) Efficient transformers: A survey. arXiv (August 2020), pp. 1–28. External Links: 2009.06732, ISSN 23318422 Cited by: §6.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS, pp. 6000–6010. External Links: Link Cited by: §1, §2.1.
  • H. Wang, Z. Zhang, and S. Han (2021) SpAtten: efficient sparse attention architecture with cascade token and head pruning. In 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA), Vol. , pp. 97–110. External Links: Document Cited by: §6.
  • S. Wang, B. Li, M. Khabsa, H. Fang, and H. Ma (2020) Linformer: self-attention with linear complexity. arXiv preprint arXiv:2006.04768. Cited by: §6.
  • M. Zaheer, G. Guruganesh, A. Dubey, J. Ainslie, C. Alberti, S. Ontanon, P. Pham, A. Ravula, Q. Wang, L. Yang, et al. (2020) Big bird: transformers for longer sequences. arXiv preprint arXiv:2007.14062. Cited by: §1, §1, §2.2, §5.1, §6.

Appendix A Benchmark Descriptions and Experiment Configurations

In our experiments, we choose Text Classification, Image Classification and Document Retrieval from Long-Range Arena, while excluding Long ListOps and Pathfinder. This is because the ListOps results in LRA exhibit significant divergence without much explanation. And for Pathfinder, we are unable to reproduce the baseline results with the given training configurations from LRA.

A.1 Text Classification

The Text Classification task, as introduced in LRA (Tay et al., 2020b), is a binary classification task that uses real-world data to benchmark the ability of the models to deal with compositionality. The IMDb review dataset (Maas et al., 2011) is selected, which is a common choice for document classification. Moreover, to make the problem more challenging, this task takes a byte-level setup instead of the normal character-level setup for language modeling. Therefore, the model needs to learn from the unsegmented data and make compositional decisions.

For model configuration, we use the original hyperparameters given in the LRA repository (https://github.com/google-research/long-range-arena). Specifically, the baseline transformer consists of 4 attention layers, each with 4 heads. The hidden dimension size is 256 and the positional FFN layer has a dimension size of 1024. The learning rate is 0.05 with a weight decay of 0.1. Finally, the baseline model is trained for steps where the first are warmup steps and the batch size is 32.

When compared with the dense baseline in Figure 2 of the full paper, the DSA-x% models are obtained from fine-tuning the dense model for steps with different levels of sparsity constraints. During fine-tuning, parameters from both original model and the predictor are updated simultaneously using the combination of cross-entropy loss and MSE loss. The weight factor of the MSE loss is 0.01 and the learning rate is uniformly set as 0.0002.
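The combined fine-tuning objective can be written as a weighted sum. The sketch below assumes that the MSE term compares the predictor's outputs against reference values from full attention, which is an assumption for illustration; the function name and flat-list inputs are hypothetical, while the 0.01 weight is the MSE loss factor stated above.

```python
def joint_loss(cross_entropy, predicted, target, mse_weight=0.01):
    """Task loss plus a weighted MSE term for the sparsity predictor.

    predicted / target: flat lists of predictor outputs and reference values
    (what exactly is compared is an assumption of this sketch).
    """
    mse = sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(predicted)
    return cross_entropy + mse_weight * mse


loss = joint_loss(1.0, predicted=[1.0, 3.0], target=[0.0, 1.0])
print(loss)
```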

When compared with other efficient transformers as shown in Table 1 of the full paper, we directly train the DSA prediction path from scratch. The overall training step is still , but we use the first to train the original model and freeze the predictor module. Therefore, the first steps are the same as training a dense baseline. After this, we jointly optimize the model and the predictor module during the last steps with the same MSE loss factor and learning rate as above.

Finally, to limit the training cost, we set the sequence length to 2000 for the baseline comparison and sensitivity study, and use a length of 4000 only when comparing with other models.

A.2 Document Retrieval

Document Retrieval is a binary classification task that serves as a test to evaluate how well a model compresses long sequences into representations for similarity-based matching. This task uses ACL Anthology Network (Radev et al., 2009) and aims to identify if two papers have a citation link. Similar to Text Classification, byte-level setup is used to increase the difficulty of the problem.

We use a uniform sequence length of 4000 in this task. The baseline transformer consists of 4 attention layers. Each attention layer has 4 heads, 128 hidden dimensions, and 512 FFN dimensions. The learning rate is 0.05 with a weight decay of 0.1. The model is trained for steps with Adam optimizer and a batch size of 32. Similar to the strategy in the Text Classification task, we use fine-tuning for baseline comparison and training-from-scratch for cross model comparison. The steps are equally divided into for dense training and for joint training in the training-from-scratch experiment. When jointly optimizing all the parameters, the weight factor of the MSE loss is 0.01 and the learning rate is 0.0002.

A.3 Image Classification

The final task we include in our evaluation is image classification using CIFAR-10 (Krizhevsky, 2012). Each input image is flattened into a sequence of pixels. Therefore, the sequence length of this task is 1024. The input images are mapped to a single gray-scale channel where each pixel is represented with an 8-bit integer value. Following the given settings, the baseline transformer model contains one attention layer with 8 heads, 64 query/key/value hidden dimensions, and 128 FFN dimensions.

There are in total 45,000 training samples and 15,000 validation samples. We train the model for 200 epochs with a learning rate of 0.0005 and a batch size of 128. As above, we use fine-tuning for baseline comparison and training-from-scratch for cross-model comparison. The steps are divided into for dense training and for joint training in the training-from-scratch experiment. When jointly optimizing all the parameters, the weight factor of the MSE loss is 0.01 and the learning rate is 0.0002.