Magic Pyramid: Accelerating Inference with Early Exiting and Token Pruning

by Xuanli He, et al.
Monash University

Pre-training and then fine-tuning large language models is commonly used to achieve state-of-the-art performance in natural language processing (NLP) tasks. However, most pre-trained models suffer from slow inference, and deploying such large models to applications with latency constraints is challenging. In this work, we focus on accelerating inference via conditional computation. To achieve this, we propose a novel idea, Magic Pyramid (MP), to reduce both width-wise and depth-wise computation via token pruning and early exiting for Transformer-based models, particularly BERT. The former saves computation by removing non-salient tokens, while the latter reduces computation by terminating the inference before reaching the final layer, whenever the exiting condition is met. Our empirical studies demonstrate that, compared to the previous state of the art, MP is not only able to achieve a speed-adjustable inference but also to surpass token pruning and early exiting, reducing up to 70% of giga floating point operations (GFLOPs) with less than 0.5% accuracy degradation. Token pruning and early exiting express distinctive preferences toward sequences of different lengths. In contrast, MP is capable of achieving an average of 8.06x speedup on two popular text classification tasks, regardless of the sizes of the inputs.




1 Introduction

In the past few years, owing to the success of Transformer-based [20] pre-trained models, such as BERT [1], RoBERTa [11], GPT-2 [13], etc., we have experienced a performance breakthrough in natural language processing (NLP) tasks. With a small amount of fine-tuning, the pre-trained models can achieve state-of-the-art performance across different tasks [1, 11, 13]. Nevertheless, these top-performing models are evaluated in offline settings, where the inference latency is not assessed or considered as a quality factor.

However, adapting and deploying such large pre-trained models to production systems (e.g., online shopping services) is not straightforward due to latency constraints and the large volume of incoming requests (e.g., millions of requests per second). Prior to this work, researchers have proposed to compress a large model via either model pruning [12, 2, 6] or token pruning [22, 3, 9]. In addition, compressing a large teacher model into a compact student model via knowledge distillation has been studied extensively [14, 17, 18, 7, 4]. Finally, another line of work plugs multiple sub-classifiers into deep neural networks to enable flexible computation on demand, a.k.a. early exiting [19, 8, 15, 10].

Figure 1: Speedup of LTP (token pruning), FastBERT (early exiting) and MP (ours) under different sequence lengths on (a) Yelp and (b) AG news. The X axis is the sequence length bucket (short: 1-35 tokens; middle: 35-70 tokens; long: >70 tokens), and the Y axis is the speedup.

Token pruning concentrates on a width-wise computational reduction, whereas early exiting delivers a depth-wise inference acceleration. Our study shows that for tasks where the input data is diverse in terms of sequence length, these two latency reduction methods behave in opposite directions. As illustrated in Figure 1 (a), the speedup achieved via early exiting (FastBERT) decreases for long inputs, whereas the speedup of token pruning (LTP) rises as the input size increases. We believe these two approaches are orthogonal and can be combined into a single model to maintain the latency reduction gain across variable input lengths. In this work, we present a novel approach, Magic Pyramid (MP), to enable a speed-adjustable inference. The contributions of this paper include:

  • Our empirical study shows that token pruning and early exiting are potentially orthogonal. This motivates further research on employing the two orthogonal inference optimization methods within a single model for model inference acceleration.

  • We propose a method (referred to as Magic Pyramid) to exploit the synergy between token pruning and early exiting and attain higher computational reduction from width and depth perspectives.

  • Compared to two strong baselines, our approach can significantly accelerate inference, yielding an additional 0.5-2x speedup with less than 0.5% degradation in accuracy across five classification tasks.

2 Related Work

Large pre-trained models have demonstrated that increasing the model capacity can pave the way for superior AI. However, as we have limited resources allocated for production systems, there has been a surge of interest in efficient inference. Previous works [14, 17, 18, 7, 4] have opened a window into effective model compression via knowledge distillation (KD) [5]. The core of KD is to use a compact student model to mimic the behavior or structure of a large teacher model. As such, the performance of the student model is nearly as accurate as its teacher's, while consuming less computation.

Other researchers approach efficient inference by manipulating the original model. One elegant solution is pruning, which can reduce the computation by removing non-essential components. These components can be either model parameters (model pruning) [12, 2, 6] or tokens (token pruning) [22, 3, 9]. In addition, one can boost the speed of the numerical operations of a model through quantization [24, 23, 16].

Despite some success, the aforementioned works lack flexibility in terms of speedup: to satisfy varying demands, we have to train multiple models. Since deep neural networks can be considered as a stack of basic building blocks, a line of works introduces early exiting [19, 8, 15, 10], which attaches a set of sub-classifiers to these sub-networks to enable an adjustable inference within a single model, when needed. As opposed to the prior works, which focus on a one-dimensional speedup, this work takes the first step to superimpose token pruning on early exiting. Our empirical studies confirm that these two approaches can accelerate the inference collaboratively and significantly.

3 Methodology

Prior to this work, token pruning and early exiting have been proven effective in accelerating the inference [22, 3, 15, 10, 9]. However, as shown in Figure 1, each approach falls short of reducing the latency at one of the two ends, i.e., short sequences or long sequences. For example, Figure 1 (a) shows that LTP (token pruning) provides the highest speedup for long input sequences, while the speedup of FastBERT (early exiting) drops as the input size grows from short to long. Therefore, we propose a novel approach, Magic Pyramid (MP), which benefits from a combination of token pruning and early exiting. Figure 2 provides a schematic illustration of MP. First of all, MP can terminate an inference at any layer, when needed. Second, as the depth of the Transformer increases, redundant tokens can be expelled. The detailed designs are provided in the rest of this section.

Figure 2: Schematic illustration of Magic Pyramid for a sentiment analysis task.

Transformer architecture

Owing to its outstanding performance, the Transformer [20] has become the de facto model for NLP tasks, especially after the triumph of pre-trained language models [1, 11, 13]. A standard Transformer is comprised of stacked Transformer blocks $\{B_1, \dots, B_L\}$, where each block $B_l$ is formulated as:

$$\mathbf{H}_{a}^{l} = \mathrm{MHA}(\mathbf{H}^{l-1}) \quad (1)$$
$$\mathbf{H}^{l} = \mathrm{FFN}(\mathbf{H}_{a}^{l}) \quad (2)$$

where $\mathbf{H}^{l-1} \in \mathbb{R}^{n \times d}$ and $\mathbf{H}^{l} \in \mathbb{R}^{n \times d}$ are the hidden states, $n$ is the sequence length, and $d$ is the feature dimension. $\mathrm{MHA}$ and $\mathrm{FFN}$ are the multi-head attention module and the position-wise feedforward module, respectively. We omit the LayerNorm module and the residual module in between for simplicity.
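As a rough illustration, the block computation above can be sketched in NumPy. This is a minimal, assumed sketch: it uses a single attention head and random weights for clarity, and omits LayerNorm and residual connections exactly as the text does.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mha(H, Wq, Wk, Wv):
    # Single-head self-attention for illustration; real blocks use N_h heads.
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (n, n) attention matrix
    return A @ V

def ffn(H, W1, W2):
    # Position-wise feedforward with ReLU.
    return np.maximum(H @ W1, 0.0) @ W2

def transformer_block(H, Wq, Wk, Wv, W1, W2):
    # H^l = FFN(MHA(H^{l-1})), matching Equ. (1)-(2); LayerNorm/residual omitted.
    return ffn(mha(H, Wq, Wk, Wv), W1, W2)
```

The input and output keep the same $(n, d)$ shape, which is what allows blocks to be stacked.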

Early exiting

As shown in Figure 2, in addition to the Transformer backbone and a main classifier, one has to attach an individual sub-classifier module $C_l$ to the output of each Transformer block $B_l$. As such, one can choose to terminate the computation at any layer, once a halt condition on the prediction uncertainty is met.

Following [10], each $C_l$ consists of a Transformer block, a pooling layer and a projection layer with a softmax function. The pooling layer extracts the hidden states of the [CLS] token as the representation of the input, while the projection layer maps this dense vector into $K$-class logits.

Similar to [10], we leverage a two-stage fine-tuning to enhance the performance of the sub-classifiers via knowledge distillation. Specifically, we first train the Transformer backbone and the primary classifier through a standard cross entropy between the ground truth $y$ and the predictions $\hat{y}$. Afterward, we freeze the backbone and the primary classifier, but train each sub-classifier $C_l$ via a Kullback–Leibler divergence:

$$\mathcal{L}_l = D_{\mathrm{KL}}(\mathbf{p}_l \parallel \mathbf{p}_t) \quad (3)$$

where $\mathbf{p}_l$ and $\mathbf{p}_t$ are the predicted probability distributions from the sub-classifier $C_l$ and the main classifier, respectively. Since there are $L-1$ sub-classifiers, the loss of the second stage can be formulated as:

$$\mathcal{L} = \sum_{l=1}^{L-1} D_{\mathrm{KL}}(\mathbf{p}_l \parallel \mathbf{p}_t) \quad (4)$$
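The second-stage distillation objective can be sketched as follows; this is a minimal sketch assuming plain probability vectors, and the function names are illustrative rather than from any particular library.

```python
import numpy as np

def kl_divergence(p_s, p_t, eps=1e-12):
    # D_KL(p_s || p_t) between a sub-classifier's distribution p_s and
    # the frozen main classifier's distribution p_t.
    p_s = np.clip(p_s, eps, 1.0)
    p_t = np.clip(p_t, eps, 1.0)
    return float(np.sum(p_s * np.log(p_s / p_t)))

def distillation_loss(sub_probs, main_probs):
    # Sum the KL terms over all L-1 sub-classifiers, as in Equ. (4).
    return sum(kl_divergence(p_s, main_probs) for p_s in sub_probs)
```

When a sub-classifier's distribution matches the teacher's, its KL term vanishes, so the loss only pushes on the sub-classifiers that still disagree.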
Once all modules are well-trained, we can stitch them together to achieve a speed-adjustable inference. At each layer $l$, we first obtain the hidden states $\mathbf{H}^l$ from the Transformer block $B_l$. Then a probability distribution $\mathbf{p}_l$ can be computed from the sub-classifier $C_l$. One can use $\mathbf{p}_l$ to calculate the uncertainty $u_l$ via:

$$u_l = -\frac{\sum_{k=1}^{K} p_l(k) \log p_l(k)}{\log K} \quad (5)$$

where $u_l$ is bound to $[0, 1]$. If $u_l$ falls below the halt value $speed$, we can terminate the computation. A larger $speed$ suggests a faster exit.
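The exit rule can be sketched as follows. The normalized-entropy uncertainty matches Equ. (5); the threshold name `speed` follows FastBERT's convention for the halt value.

```python
import numpy as np

def uncertainty(probs, eps=1e-12):
    # Normalized entropy in [0, 1]: 0 = fully confident,
    # 1 = uniform over the K classes.
    probs = np.clip(probs, eps, 1.0)
    K = probs.shape[-1]
    return float(-(probs * np.log(probs)).sum() / np.log(K))

def should_exit(probs, speed):
    # Terminate at this layer if the sub-classifier is confident enough.
    return uncertainty(probs) < speed
```

Raising `speed` loosens the condition, so more inputs exit at shallow layers and the average inference gets faster, at some cost in accuracy.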

Figure 3: An example of attention probability in a single head. Darker color suggests a higher attention score. The bottom heatmap is the column mean of the attention matrix.

Token pruning

The core of the Transformer block is the $\mathrm{MHA}$ module, which is responsible for a context-aware encoding of each token. Notably, we compute pairwise importance among all tokens within the input via self-attention. The attention score of head $h$ at layer $l$ between tokens $x_i$ and $x_j$ is obtained from:

$$\mathbf{A}^{(h,l)}(i, j) = \mathrm{softmax}\left(\frac{(\mathbf{h}_i \mathbf{W}_q)(\mathbf{h}_j \mathbf{W}_k)^{\top}}{\sqrt{d_h}}\right) \quad (6)$$

where $\mathbf{h}_i$ and $\mathbf{h}_j$ are the hidden states of $x_i$ and $x_j$ respectively, and $\mathbf{W}_q$ and $\mathbf{W}_k$ are learnable parameters. $d_h$ is set to $d / N_h$, where $N_h$ is the number of heads used in the Transformer block. Since we have to conduct such an operation $n^2$ times to acquire the attention score matrix $\mathbf{A}^{(h,l)} \in \mathbb{R}^{n \times n}$, the complexity of $\mathrm{MHA}$ scales quadratically with the sequence length. Therefore, we encounter a computational bottleneck when working on long sequences. However, if we take the average of $\mathbf{A}^{(h,l)}$ along the $i$-th column, we notice that different tokens have distinctive scores, as shown in Figure 3. Tokens with large scores tend to be more salient than others, as they receive more attention. As such, we can prune the non-salient tokens to save computation. We formally define the importance score of each token $x_i$ at layer $l$ as:

$$s^{(l)}(x_i) = \frac{1}{N_h} \sum_{h=1}^{N_h} \frac{1}{n} \sum_{j=1}^{n} \mathbf{A}^{(h,l)}(j, i) \quad (7)$$
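Given the per-head attention matrices, the importance score in Equ. (7) is simply a column mean averaged over heads; a minimal sketch:

```python
import numpy as np

def importance_scores(attn):
    # attn: (N_h, n, n) attention matrices A^(h,l), where attn[h, j, i]
    # is the attention token x_j pays to token x_i in head h.
    # The importance of x_i averages column i over rows j and heads h.
    return attn.mean(axis=(0, 1))  # shape (n,)
```

Because each attention row is a softmax distribution summing to 1, the resulting scores also sum to 1, so they can be read as a budget of attention spread across the tokens.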
Before this work, researchers have proposed two approaches to remove unimportant tokens based on $s^{(l)}(x_i)$: i) top-k based pruning and ii) threshold-based pruning [3, 22, 9]. In this work, we follow [9], which leverages a layer-wise learnable threshold $\theta^{(l)}$ to achieve a fast inference and is superior to other works [22, 3]. We first fine-tune the Transformer parameters on a downstream task. Then we introduce a two-stage pruning scheme to seek suitable thresholds and model parameters, which can accelerate the inference while maintaining a decent accuracy. In the first pruning stage, we apply a gating function to weight the outputs $\mathbf{h}_i^{l}$ from the current layer $l$, before we pass them to the next layer:

$$\tilde{\mathbf{h}}_i^{l} = \mathbf{h}_i^{l} \cdot \sigma\left(\frac{s^{(l)}(x_i) - \theta^{(l)}}{T}\right) \quad (8)$$

where $\sigma$ is a sigmoid function, $\cdot$ is an element-wise multiplication, $T$ is a temperature parameter and $\theta^{(l)}$ is a learnable threshold at layer $l$. If the gate value approaches zero, $\tilde{\mathbf{h}}_i^{l}$ will become zero as well. As such, $x_i$ has no impact on the subsequent layers. At this stage, since the soft gate allows the back-propagation to flow, both the model parameters and $\theta^{(l)}$ can be optimized. In addition, [9] also imposes a loss on the gate values as a regularizer to encourage the pruning operation. Please refer to their paper for the details.

In the second pruning stage, we binarize the mask values at the inference time via:

$$m_i^{l} = \begin{cases} 1 & \text{if } s^{(l)}(x_i) > \theta^{(l)} \\ 0 & \text{otherwise} \end{cases} \quad (9)$$

If $s^{(l)}(x_i)$ is below the threshold $\theta^{(l)}$, the token $x_i$ is removed from layer $l$ onward and will not contribute towards the final predictions. We freeze $\theta^{(l)}$ but update the model parameters, such that the model can learn to accurately predict the labels merely conditioning on the retained tokens.
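At inference time the gate of the previous stage becomes the hard mask of Equ. (9), and pruned tokens are physically dropped; a minimal sketch:

```python
import numpy as np

def hard_prune(hidden, scores, theta):
    # Keep only tokens whose importance exceeds the (now frozen) layer
    # threshold; dropped tokens never reach later layers, shrinking the
    # sequence length n and hence the quadratic attention cost.
    keep = scores > theta
    return hidden[keep], keep
```

Returning the boolean mask alongside the retained states lets the caller map predictions back to the surviving token positions.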

4 Experiments

To examine the effectiveness of the proposed approach, we use five language understanding tasks as the testbed. We describe our datasets and our experimental setup in the following.

Data Train Test Task
AG news 120K 7.6K topic
Yelp 560K 38K sentiment
QQP 364K 40K paraphrase
MRPC 3.7K 408 paraphrase
RTE 2.5K 277 language inference
Table 1: The statistics of datasets


The first two tasks are: i) AG news topic identification [25], and ii) Yelp polarity sentiment classification [25]. The last three are: i) the Quora Question Pairs (QQP) similarity detection dataset, ii) the Microsoft Research Paraphrase Corpus (MRPC) dataset and iii) the Recognizing Textual Entailment (RTE) dataset. These three datasets are from the GLUE benchmark [21] and focus on predicting a relationship between a sentence pair. The datasets are summarized in Table 1.

Datasets BERT DistilBERT LTP FastBERT MP (ours)
AG news 3,-,-,- 3,-,-,- 3,1,2,- 3,-,-,2 3,1,2,2
Yelp 3,-,-,- 3,-,-,- 3,1,2,- 3,-,-,2 3,1,2,2
QQP 5,-,-,- 5,-,-,- 5,2,5,- 5,-,-,5 5,2,5,5
MRPC 10,-,-,- 10,-,-,- 10,10,5,- 10,-,-,5 10,10,5,5
RTE 10,-,-,- 10,-,-,- 10,10,5,- 10,-,-,5 10,10,5,5
Table 2: The number of epochs used for regular training, soft pruning, hard pruning, and sub-classifier training on the different datasets. "-" indicates the corresponding stage is inactive.

Experimental setup

We compare our approach with four baselines: i) standard BERT [1], ii) distilBERT [14], iii) learned token pruning (LTP) [9] and iv) FastBERT [10]. Except for distilBERT, all approaches are fine-tuned from the uncased BERT-base model (12 layers).

For training, we use a batch size of 32 for QQP, MRPC, and RTE. We set this to 64 for AG news and Yelp. Since different approaches adopt different training strategies, we unify them as four steps:

  1. Regular training: training a model without additional components;

  2. Soft pruning: training a model together with the layer-wise thresholds $\theta^{(l)}$;

  3. Hard pruning: training a model with the binarized mask values;

  4. Sub-classifier training: training the sub-classifiers with the distillation loss of Equ. (4). For MP, we also activate the pruning operations.

We report the number of training epochs of the different steps for all approaches in Table 2. Similar to [9], we vary the threshold of the final layer from 0.01 to 0.08, and the thresholds of the other layers are scaled by layer depth, following [9]. We search the temperature $T$ over {1e-5, 2e-5, 5e-5} and vary the regularization weight from 0.001 to 0.2. We use a learning rate of 2e-5 for all experiments. We consider accuracy for the classification performance and giga floating point operations (GFLOPs) for the speedup.

Model       AG news             Yelp                QQP                MRPC               RTE
            Acc. GFLOPs         Acc. GFLOPs         Acc. GFLOPs        Acc. GFLOPs        Acc. GFLOPs
BERT        94.3 9.0 (1.00x)    95.8 17.2 (1.00x)   91.3 5.1 (1.00x)   85.3 9.2 (1.00x)   68.6 11.2 (1.00x)
distilBERT  94.4 4.5 (2.00x)    95.7 8.6 (2.00x)    90.4 2.6 (2.00x)   84.6 4.6 (2.00x)   58.8 5.6 (2.00x)
LTP         94.3 5.3 (1.72x)    94.7 7.4 (2.32x)    90.6 3.2 (1.60x)   84.8 6.2 (1.48x)   67.8 7.5 (1.50x)
FastBERT    94.3 2.3 (3.97x)    94.8 2.8 (6.18x)    90.7 1.6 (3.20x)   84.3 4.3 (2.13x)   67.6 8.4 (1.33x)
MP (ours)   94.3 1.8 (4.95x)    94.5 2.1 (8.25x)    90.4 1.3 (4.03x)   83.8 3.3 (2.77x)   67.5 6.5 (1.72x)
Table 3: The accuracy and GFLOPs of BERT [1], distilBERT [14], LTP (learned token pruning) [9], FastBERT [10] and MP (ours) on different datasets. The numbers in parentheses are speedup relative to BERT.
            AG news                  Yelp
speed       0.1    0.5    0.8       0.1    0.5    0.8
FastBERT    3.97x  10.30x 11.95x    3.15x  6.18x  8.84x
MP (ours)   4.95x  10.53x 11.95x    5.35x  8.25x  10.10x
Table 4: Speedup of FastBERT and MP with different halt values (speed).

For the token pruning approach, previous works [22, 9] have shown that there exists a trade-off between accuracy and speedup. Thus, we report the performance of the models achieving the smallest GFLOPs with at most a 1% accuracy drop compared to the BERT baseline. Similarly, the speedup of FastBERT is also controlled by the halt value. For the sake of a fair comparison, we select a halt value that obtains on-par accuracy with the token pruning competitors. This selection criterion is applied to MP as well.

Table 3 demonstrates that all approaches lose some accuracy when fast inference is activated. Overall, FastBERT is superior to distilBERT and LTP in terms of both accuracy and GFLOPs. At similar accuracy, our approach manages a significantly faster inference than FastBERT, leading to up to 2.13x extra speedup. We notice that the speedup and accuracy also correlate with the complexity of the tasks and the amount of training data. Specifically, among the sentence-pair classification tasks, since QQP has much more data (c.f., Table 1), it achieves a 4.03x speedup with a loss of 1% accuracy. In contrast, RTE and MRPC obtain at most a 2.77x speedup with the same amount of accuracy degradation. Under the same magnitude of training data, as AG news and Yelp are simpler than QQP, they can gain up to an 8.25x speedup after sacrificing 1% accuracy.

Gains over FastBERT

In Section 3, we claimed that MP can benefit from both token pruning and early exiting. Although this claim is evidenced in Table 3, we are interested in whether such gains consistently hold when tuning the halt value to control the inference speed. According to Table 4, MP drastically boosts the speedup of FastBERT, except under an aggressive halt value, which causes the computation to terminate within the first two layers.

Speedup on sequences with different lengths

Intuitively, longer sentences tend to have more redundant tokens, which can confuse the lower sub-classifiers. Consequently, longer sentences require more computation before reaching a low uncertainty. We bucket the Yelp and AG news datasets into three categories: i) short sequences (1-35 tokens), ii) middle sequences (35-70 tokens) and iii) long sequences (>70 tokens). Figure 1 indicates that LTP prefers long sequences, while FastBERT favors short sequences. Since MP combines early exiting with token pruning, it significantly accelerates both short and long sequences, compared to the two baselines.

5 Conclusion

In this work, we introduce Magic Pyramid, which maintains a trade-off between speedup and accuracy for BERT-based models. Since MP is powered by two effective efficiency-oriented approaches, it yields substantially faster inference than the baselines, with up to an additional 2x speedup. We also found that token pruning and early exiting each fail to efficiently handle sequences in certain length groups. In contrast, such limitations are overcome by MP; thereby, our approach can indiscriminately accelerate inference for every input (i.e., inference request) regardless of its length.


  • [1] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186. Cited by: §1, §3, §4, Table 3.
  • [2] A. Fan, E. Grave, and A. Joulin (2019) Reducing transformer depth on demand with structured dropout. In International Conference on Learning Representations, Cited by: §1, §2.
  • [3] S. Goyal, A. R. Choudhury, S. Raje, V. Chakaravarthy, Y. Sabharwal, and A. Verma (2020) PoWER-BERT: accelerating BERT inference via progressive word-vector elimination. In International Conference on Machine Learning, pp. 3690–3699. Cited by: §1, §2, §3, §3.
  • [4] X. He, I. Nassar, J. Kiros, G. Haffari, and M. Norouzi (2021) Generate, annotate, and learn: generative models advance self-training and knowledge distillation. arXiv preprint arXiv:2106.06168. Cited by: §1, §2.
  • [5] G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §2.
  • [6] L. Hou, Z. Huang, L. Shang, X. Jiang, X. Chen, and Q. Liu (2020) DynaBERT: dynamic bert with adaptive width and depth. Advances in Neural Information Processing Systems 33. Cited by: §1, §2.
  • [7] X. Jiao, Y. Yin, L. Shang, X. Jiang, X. Chen, L. Li, F. Wang, and Q. Liu (2020) TinyBERT: distilling bert for natural language understanding. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings, pp. 4163–4174. Cited by: §1, §2.
  • [8] Y. Kaya, S. Hong, and T. Dumitras (2019) Shallow-deep networks: understanding and mitigating network overthinking. In Proceedings of the 36th International Conference on Machine Learning, K. Chaudhuri and R. Salakhutdinov (Eds.), Proceedings of Machine Learning Research, Vol. 97, pp. 3301–3310. Cited by: §1, §2.
  • [9] S. Kim, S. Shen, D. Thorsley, A. Gholami, J. Hassoun, and K. Keutzer (2021) Learned token pruning for transformers. arXiv preprint arXiv:2107.00910. Cited by: §1, §2, §3, §3, §4, §4, §4, Table 3.
  • [10] W. Liu, P. Zhou, Z. Wang, Z. Zhao, H. Deng, and Q. Ju (2020) FastBERT: a self-distilling bert with adaptive inference time. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 6035–6044. Cited by: §1, §2, §3, §3, §3, §4, Table 3.
  • [11] Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov (2019) Roberta: a robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692. Cited by: §1, §3.
  • [12] P. Michel, O. Levy, and G. Neubig (2019) Are sixteen heads really better than one?. Advances in Neural Information Processing Systems 32, pp. 14014–14024. Cited by: §1, §2.
  • [13] A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, I. Sutskever, et al. (2019) Language models are unsupervised multitask learners. OpenAI blog 1 (8), pp. 9. Cited by: §1, §3.
  • [14] V. Sanh, L. Debut, J. Chaumond, and T. Wolf (2019) DistilBERT, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108. Cited by: §1, §2, §4, Table 3.
  • [15] R. Schwartz, G. Stanovsky, S. Swayamdipta, J. Dodge, and N. A. Smith (2020) The right tool for the job: matching model and instance complexities. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 6640–6651. Cited by: §1, §2, §3.
  • [16] S. Shen, Z. Dong, J. Ye, L. Ma, Z. Yao, A. Gholami, M. W. Mahoney, and K. Keutzer (2020) Q-BERT: hessian based ultra low precision quantization of BERT. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 34, pp. 8815–8821. Cited by: §2.
  • [17] S. Sun, Y. Cheng, Z. Gan, and J. Liu (2019) Patient knowledge distillation for bert model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pp. 4323–4332. Cited by: §1, §2.
  • [18] Z. Sun, H. Yu, X. Song, R. Liu, Y. Yang, and D. Zhou (2020) MobileBERT: a compact task-agnostic bert for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 2158–2170. Cited by: §1, §2.
  • [19] S. Teerapittayanon, B. McDanel, and H.T. Kung (2016) BranchyNet: fast inference via early exiting from deep neural networks. In 2016 23rd International Conference on Pattern Recognition (ICPR), pp. 2464–2469. Cited by: §1, §2.
  • [20] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in neural information processing systems, pp. 5998–6008. Cited by: §1, §3.
  • [21] A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. In Proceedings of the 2018 EMNLP Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP, pp. 353–355. Cited by: §4.
  • [22] H. Wang, Z. Zhang, and S. Han (2021) SpAtten: efficient sparse attention architecture with cascade token and head pruning. In 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA), pp. 97–110. Cited by: §1, §2, §3, §3, §4.
  • [23] K. Wróbel, M. Karwatowski, M. Wielgosz, M. Pietroń, and K. Wiatr (2020) Compression of convolutional neural network for natural language processing. Computer Science 21 (1). Cited by: §2.
  • [24] O. Zafrir, G. Boudoukh, P. Izsak, and M. Wasserblat (2019) Q8BERT: quantized 8bit BERT. Cited by: §2.
  • [25] X. Zhang, J. Zhao, and Y. LeCun (2015) Character-level convolutional networks for text classification. Advances in neural information processing systems 28, pp. 649–657. Cited by: §4.