Pruning a BERT-based Question Answering Model

10/14/2019, by J. S. McCarley, et al.

We investigate compressing a BERT-based question answering system by pruning parameters from the underlying BERT model. We start from models trained for SQuAD 2.0 and introduce gates that allow selected parts of transformers to be individually eliminated. Specifically, we investigate (1) reducing the number of attention heads in each transformer, (2) reducing the intermediate width of the feed-forward sublayer of each transformer, and (3) reducing the embedding dimension. We compare several approaches for determining the values of these gates. We find that a combination of pruning attention heads and the feed-forward layer almost doubles the decoding speed, with only a 1.5 f-point loss in accuracy.




1 Introduction

The recent surge in NLP model complexity has outstripped Moore's law Peters et al. (2018); Devlin et al. (2018); Narasimhan (2019). Deeply stacked layers of transformers (including BERT, RoBERTa Liu et al. (2019), XLNet Yang et al. (2019b), and ALBERT Lan et al. (2019)) have greatly improved state-of-the-art accuracies across a variety of NLP tasks, but their computational intensity raises concerns in the cloud-computing economy. Numerous techniques developed to shrink neural networks, including distillation, quantization, and pruning, are now being applied to transformers.

Question answering, in particular, has immediate applications in real-time systems. It has seen striking gains in accuracy due to transformers, as measured on the SQuAD Rajpurkar et al. (2016) and SQuAD 2.0 Rajpurkar et al. (2018) leaderboards. SQuAD is regarded as a worst case for speedup techniques based on quantization Shen et al. (2019), and the difficulty of distilling a SQuAD model (compared to sentence-level GLUE tasks) is acknowledged in Jiao et al. (2019). We speculate that these difficulties arise because answer selection requires token-level rather than passage-level annotation, and because of the need for long-range attention between query and passage.

In this paper we investigate pruning three aspects of BERT:

(1) the number of attention heads,

(2) the intermediate size of the feed-forward sublayer, and

(3) the embedding (hidden) dimension.

The contributions of this paper are (1) the application of structured pruning techniques to the feed-forward layer and the hidden dimension of the transformers, not just the attention heads, (2) thereby significantly pruning BERT with minimal loss of accuracy on a question answering task and considerable speedup, all without the expense of revisiting pretraining, and (3) a survey of multiple pruning techniques (both heuristic and trainable), with recommendations specific to transformer-based question answering models.

Widely distributed pre-trained models typically consist of 12-24 layers of identically sized transformers. We will see that an optimal pruning yields non-identical transformers: lightweight transformers near the top and bottom, with more complexity retained in the intermediate layers.

2 Related work

While distillation (student-teacher) of BERT has produced notably smaller models, Tang et al. (2019); Turc et al. (2019); Tsai et al. (2019); Yang et al. (2019a), the focus has been on sentence level annotation tasks that do not require long-range attention. Revisiting the pretraining phase during distillation is often a significant requirement. DistilBERT Sanh et al. (2019) reports modest speedup and small performance loss on SQuAD 1.1. TinyBERT Jiao et al. (2019) restricts SQuAD evaluation to using BERT-base as a teacher, and defers deeper investigation to future work.

Our work is perhaps most similar to Fan et al. (2019), an exploration of pruning as a form of dropout. They prune entire layers of BERT, but suggest that smaller structures could also be pruned. They evaluate on MT, language modeling, and generation-like tasks, but not SQuAD. L0 regularization was combined with matrix factorization to prune transformers in Wang et al. (2019). Gale et al. (2019) induced unstructured sparsity on a transformer-based MT model, but did not report speedups. Voita et al. (2019) focused on linguistic interpretability of attention heads and introduced L0 regularization, but did not report speedups. Kovaleva et al. (2019) also focused on interpreting attention, and achieved small accuracy gains on GLUE tasks by disabling (but not pruning) certain attention heads. Michel et al. (2019) achieved speedups on MT and MNLI by gating only the attention with simple heuristics.

3 Pruning transformers

3.1 Notation

dimension          base   large
layers             12     24
embeddings         768    1024
attention heads    12     16
intermediate size  3072   4096
Figure 1: Notation: important dimensions of a BERT model

The size of a BERT model is characterized by the values in figure 1.
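As a rough sanity check on these dimensions, the per-layer weight counts can be sketched in a few lines. The formulas below are the standard transformer sizes (weights only; biases and layer norm omitted) and are our illustration, not a calculation from the paper:

```python
# Rough per-layer parameter counts for the dimensions in figure 1.
# These are our own back-of-the-envelope formulas, not figures from the paper.

def transformer_params(hidden, intermediate):
    attention = 4 * hidden * hidden           # Q, K, V, and output projections
    feed_forward = 2 * hidden * intermediate  # the two feed-forward linear maps
    return attention, feed_forward

for name, (layers, hidden, inter) in {
    "base": (12, 768, 3072),
    "large": (24, 1024, 4096),
}.items():
    attn, ff = transformer_params(hidden, inter)
    total = layers * (attn + ff)
    print(f"{name}: {attn:,} attn + {ff:,} ff per layer; ~{total:,} in the stack")
```

Note that the feed-forward sublayer holds roughly twice as many weights as attention, which is consistent with the paper's emphasis on pruning it.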

3.2 Gate placement

Our approach to pruning each aspect of a transformer is similar. We insert three masks into each transformer. Each mask is a vector of binary gate variables, where a gate value of 0 indicates a slice of transformer parameters to be pruned, and a gate value of 1 indicates a slice that remains active. We describe the placement of each mask following the terminology of Vaswani et al. (2017), indicating the relevant sections of that paper.

In each self-attention sublayer, we place a mask, of size equal to the number of attention heads, which selects attention heads to remain active (section 3.2.2).

In each feed-forward sublayer, we place a mask, of size equal to the intermediate size, which selects ReLU/GeLU activations to remain active (section 3.3).

The final mask, of size equal to the embedding dimension, selects which embedding dimensions remain active (section 3.4). This gate is applied identically to both input and residual connections in each transformer.
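The effect of a gate can be sketched as elementwise multiplication by a 0/1 mask before the corresponding slice's output is used; the tiny example below is purely illustrative:

```python
# Minimal sketch of gating: a binary mask zeroes out slices before pruning.
# The dimensions here are toy values, not the real head counts.

def apply_gates(values, gates):
    """Multiply each slice by its 0/1 gate (0 = pruned, 1 = active)."""
    assert len(values) == len(gates)
    return [v * g for v, g in zip(values, gates)]

# e.g. four attention-head outputs, with the second and fourth heads gated off
head_outputs = [0.5, -1.2, 0.3, 2.0]
gated = apply_gates(head_outputs, [1, 0, 1, 0])
```

Because a zeroed slice contributes nothing downstream, removing it later leaves the network function unchanged.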

3.3 Determining Gate Values

We investigate four approaches to determining the gate values.

(1) “random:” each gate is sampled from a Bernoulli distribution whose parameter is manually adjusted to control the sparsity.

(2) “gain:” We follow the method of Michel et al. (2019) and estimate the influence of each gate on the training set likelihood by computing the mean absolute value of the gradient of the loss with respect to that gate (the “head importance score”) during one pass over the training data. We threshold this score to determine which transformer slices to retain.
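A minimal sketch of this heuristic, with grad_fn standing in (as a hypothetical placeholder) for whatever computes per-gate gradients on a batch:

```python
# Sketch of the "gain" heuristic: accumulate |dL/d(gate)| over the training
# data, then keep the gates with the largest mean score. grad_fn is a
# hypothetical stand-in, not an interface from the paper.

def gain_scores(batches, grad_fn, n_gates):
    totals = [0.0] * n_gates
    for batch in batches:
        grads = grad_fn(batch)          # per-gate dL/d(gate) for this batch
        for i, g in enumerate(grads):
            totals[i] += abs(g)
    return [t / len(batches) for t in totals]

def keep_top(scores, keep_fraction):
    """Build a 0/1 mask retaining the highest-scoring fraction of gates."""
    k = max(1, int(len(scores) * keep_fraction))
    cutoff = sorted(scores, reverse=True)[k - 1]
    return [1 if s >= cutoff else 0 for s in scores]
```

The appeal of this method is that a single pass over the training data scores every gate at once.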

(3) “leave-one-out:” We again follow the method of Michel et al. (2019) and evaluate the impact on dev-set score of a system with exactly one gate set to zero. Note that this procedure requires one pass through the data per gate. We control the sparsity during decoding by retaining those gates whose removal costs the most dev-set score.
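This procedure can be sketched as a loop over gates; eval_fn is a hypothetical stand-in for a full dev-set evaluation, which is why the cost grows with one pass per gate:

```python
# Sketch of leave-one-out scoring: zero each gate in turn, measure the
# dev-set score, and rank gates by how much the score drops.

def leave_one_out(gates, eval_fn):
    base = eval_fn(gates)
    drops = []
    for i in range(len(gates)):
        trial = list(gates)
        trial[i] = 0
        drops.append(base - eval_fn(trial))  # large drop = important gate
    return drops
```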

(4) “L0 regularization:” Following the method described in Louizos et al. (2017), during training the gate variables are sampled from a hard-concrete distribution Maddison et al. (2017), each parameterized by a corresponding trainable variable. The task-specific objective function is penalized in proportion to the expected number of nonzero gates. Separate proportionality constants in the penalty terms are manually adjusted to control the sparsity. We resample the gates with each minibatch. We note that the full objective function is differentiable with respect to the distribution parameters because of the reparameterization trick Kingma and Welling (2014); Rezende et al. (2014). The distribution parameters are updated by backpropagation for one training epoch with the SQuAD training data, with all other parameters held fixed. The final values of the gates are obtained by thresholding the learned parameters.
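A minimal sketch of the hard-concrete gate and its differentiable L0 penalty; the stretch constants gamma, zeta, and beta below are the values commonly used with this distribution in Louizos et al. (2017) and are our assumption, not taken from this paper:

```python
import math
import random

# Sketch of a hard-concrete gate (Louizos et al., 2017). GAMMA/ZETA/BETA are
# the commonly used stretch constants; this paper may use different values.
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0

def sample_gate(log_alpha, rng=random):
    """Sample a gate in [0, 1] via the reparameterization trick."""
    u = rng.random()
    s = 1.0 / (1.0 + math.exp(-((math.log(u) - math.log(1.0 - u)) + log_alpha) / BETA))
    stretched = s * (ZETA - GAMMA) + GAMMA
    return min(1.0, max(0.0, stretched))   # hard gate: exact zeros/ones possible

def expected_l0(log_alpha):
    """P(gate != 0): the differentiable penalty added to the task loss."""
    return 1.0 / (1.0 + math.exp(-(log_alpha - BETA * math.log(-GAMMA / ZETA))))
```

Because the gate is a deterministic function of the noise u and the parameter log_alpha, gradients flow through the sample, which is exactly the reparameterization trick mentioned above.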

3.4 Pruning

After the gate values have been determined by one of the above methods, the model is pruned. Attention heads whose gate is zero are removed. Slices of the feed-forward linear transformations whose gate is zero are removed. The pruned model no longer needs masks, and now consists of transformers of varying, non-identical sizes.
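The pruning step itself amounts to index selection: once gate values are fixed, the rows and columns whose gate is 0 are dropped. Plain nested lists stand in for weight matrices in this sketch:

```python
# Sketch of structured pruning as index selection. The two feed-forward maps
# share one gate vector: prune the first map's output rows and the second
# map's input columns with the same gates. Matrices here are toy values.

def kept_indices(gates):
    return [i for i, g in enumerate(gates) if g == 1]

def prune_rows(matrix, gates):
    keep = kept_indices(gates)
    return [matrix[i] for i in keep]

def prune_cols(matrix, gates):
    keep = kept_indices(gates)
    return [[row[j] for j in keep] for row in matrix]

gates = [1, 0, 1]               # keep intermediate units 0 and 2
w1 = [[1, 2], [3, 4], [5, 6]]   # intermediate x hidden
w2 = [[1, 3, 5], [2, 4, 6]]     # hidden x intermediate
w1_pruned = prune_rows(w1, gates)
w2_pruned = prune_cols(w2, gates)
```

Because the gated slices contributed nothing, the pruned matrices compute the same function with smaller dense multiplies.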

We note that task-specific training of all BERT parameters may be continued further with the pruned model.

Figure 2: F1 vs. percentage of attention heads pruned
Figure 3: F1 vs. percentage of feed-forward activations pruned
Figure 4: F1 vs. percentage of embedding dimensions removed
model                     time (sec)  f1    attn-prune (%)  ff-prune (%)  size (MiB)
no pruning                2605        84.6  0               0             1278
attention pruned          2253        84.2  44.3            0             1110
feed-forward pruned       2078        83.2  0               47.7          909
attention + feed-forward  1631        82.6  44.3            47.7          741
aggressive pruning        1359        80.9  52.6            65.2          575
  + retrain               1349        83.2                                575
Table 1: Decoding times, accuracies, and space savings achieved by two sample operating points on large-qa
Figure 5: Percentage of attention heads and feed forward activations remaining after pruning, by layer

4 Experiments

For development experiments (learning rate and penalty weight exploration), and in order to minimize overuse of the official dev set, we use part of the official SQuAD 2.0 training data for training gates, and report results on the remainder. Our development experiments (base-qa) are all initialized from a SQuAD 2.0 system that is itself initialized from bert-base-uncased and trained on our training split, which provides the baseline performance on this dataset. Our baseline SQuAD model depends upon externally distributed code and incorporates either bert-base-uncased or bert-large-uncased with a standard task-specific head.

Our validation experiments use the standard training/dev configuration of SQuAD 2.0. All are initialized from the system of Glass et al. (2019), which was itself initialized from bert-large-uncased.

The gate parameters of the “L0 regularization” experiments are trained for one epoch starting from the models above, with all transformer and embedding parameters fixed. The cost of training the gate parameters is comparable to extending fine-tuning for an additional epoch. We investigated three learning rates on base-qa, and chose the largest for presentation and for results on large-qa. This is notably larger than typical learning rates used to tune BERT parameters. We used a minibatch size of 24, and otherwise the default hyperparameters of the BERT-Adam optimizer. We used identical parameters for our large-qa experiments, except with gradaccsteps=3. Tables report median values across 5 random seeds; graphs overplot results for 5 seeds.

4.1 Accuracy as function of pruning

In figure 2 we plot the accuracy of base-qa as a function of the percentage of attention heads removed. As expected, the performance of “random” decays most abruptly. “Leave-one-out” and “gain” are better, and substantially similar to each other. “L0 regularization” is best, allowing considerable pruning at a modest cost in F1.

In figure 3 we similarly plot the accuracy as a function of the percentage of feed-forward activations removed. We see broadly similar trends as above, except that performance is robust to even larger pruning. “Leave-one-out” requires a prohibitive number of passes through the data.

In figure 4 we plot the accuracy as a function of the percentage of embedding dimensions removed. Performance falls much more steeply with the removal of embedding dimensions. Attempts to train “L0 regularization” for this mask were unsuccessful; we speculate that the strong cross-layer coupling may necessitate a different learning rate schedule.

4.2 Validating these results

On the basis of the development experiments, we select operating points (values of the penalty weights for the attention and feed-forward masks) and train the gates of large-qa with these penalties. The decoding times, accuracies, and model sizes are summarized in table 1. Models in which both attention and feed-forward components are pruned were produced by combining the independently trained gate configurations for attention and feed-forward. For the same penalty values, the large model is pruned somewhat less than the base model, and the loss due to pruning is somewhat smaller. We note that much of the performance loss can be recovered by continuing the training for an additional epoch after pruning.

The speedup in decoding due to pruning the model is not simply proportional to the amount pruned. There are computations in both the attention and feed-forward part of each transformer layer that necessarily remain unpruned, for example layer normalization.
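A simple Amdahl-style estimate illustrates this: if a fixed fraction of the per-layer cost cannot be pruned, the achievable speedup saturates well below the naive proportional value. The fractions below are illustrative, not measured:

```python
# Why speedup is sub-linear in the fraction pruned: treat the unprunable work
# (layer norm, softmax, etc.) as a fixed fraction of the original cost.
# The specific fractions below are illustrative assumptions, not measurements.

def speedup(pruned_fraction, fixed_fraction):
    """Overall speedup when only (1 - fixed_fraction) of the work is prunable."""
    remaining = fixed_fraction + (1.0 - fixed_fraction) * (1.0 - pruned_fraction)
    return 1.0 / remaining

# With no fixed overhead, pruning half the work doubles the speed; with a
# 20% fixed fraction, the same pruning yields well under a 2x speedup.
```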

4.3 Impact of pruning each layer

In Fig. 5 we show the percentage of attention heads and feed-forward activations remaining after pruning, by layer. We see that intermediate layers retain more, while layers close to the embedding and close to the answer are pruned more heavily.

5 Conclusion

We investigated various methods to prune transformer-based models, and evaluated the accuracy-speed tradeoff of this pruning. We find that both the attention heads and especially the feed-forward layer can be pruned considerably with minimal loss of accuracy, while pruning the embedding/hidden dimension is ineffective because of the resulting loss in accuracy. We find that L0 regularization pruning, when successful, is considerably more effective than the heuristic methods. We also find that pruning of the feed-forward layer and the attention heads can easily be combined and, especially after retraining, yields a considerably faster question answering model with minimal loss in accuracy.


  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. CoRR abs/1810.04805. External Links: Link, 1810.04805 Cited by: §1.
  • A. Fan, E. Grave, and A. Joulin (2019) Reducing transformer depth on demand with structured dropout. External Links: 1909.11556 Cited by: §2.
  • T. Gale, E. Elsen, and S. Hooker (2019) The state of sparsity in deep neural networks. CoRR abs/1902.09574. External Links: Link, 1902.09574 Cited by: §2.
  • M. Glass, A. Gliozzo, R. Chakravarti, A. Ferritto, L. Pan, G. P. S. Bhargav, D. Garg, and A. Sil (2019) Span selection pre-training for question answering. External Links: 1909.04120 Cited by: §4.
  • X. Jiao, Y. Yin, L. Shang, X. Jiang, X. Chen, L. Li, F. Wang, and Q. Liu (2019) TinyBERT: distilling bert for natural language understanding. External Links: 1909.10351 Cited by: §1, §2.
  • D. P. Kingma and M. Welling (2014) Auto-encoding variational Bayes. International Conference on Learning Representations. Cited by: §3.3.
  • O. Kovaleva, A. Romanov, A. Rogers, and A. Rumshisky (2019) Revealing the dark secrets of bert. CoRR. External Links: Link Cited by: §2.
  • Z. Lan, M. Chen, S. Goodman, K. Gimpel, P. Sharma, and R. Soricut (2019) ALBERT: a lite BERT for self-supervised learning of language representations. External Links: 1909.11942 Cited by: §1.
  • Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov (2019) RoBERTa: A robustly optimized BERT pretraining approach. CoRR abs/1907.11692. External Links: Link, 1907.11692 Cited by: §1.
  • C. Louizos, M. Welling, and D. P. Kingma (2017) Learning sparse neural networks through L0 regularization. External Links: 1712.01312 Cited by: §3.3.
  • C. J. Maddison, A. Mnih, and Y. W. Teh (2017) The concrete distribution: a continuous relaxation of discrete random variables. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France. External Links: Link Cited by: §3.3.
  • P. Michel, O. Levy, and G. Neubig (2019) Are sixteen heads really better than one?. CoRR abs/1905.10650. External Links: Link, 1905.10650 Cited by: §2, §3.3, §3.3.
  • S. Narasimhan (2019) NVIDIA clocks world’s fastest bert training time and largest transformer based model, paving path for advanced conversational ai. External Links: Link Cited by: §1.
  • M. E. Peters, M. Neumann, M. Iyyer, M. Gardner, C. Clark, K. Lee, and L. Zettlemoyer (2018) Deep contextualized word representations. External Links: 1802.05365 Cited by: §1.
  • P. Rajpurkar, R. Jia, and P. Liang (2018) Know what you don’t know: unanswerable questions for squad. CoRR abs/1806.03822. External Links: Link, 1806.03822 Cited by: §1.
  • P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100, 000+ questions for machine comprehension of text. CoRR abs/1606.05250. External Links: Link, 1606.05250 Cited by: §1.
  • D. J. Rezende, S. Mohamed, and D. Wierstra (2014) Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st International Conference on Machine Learning, E. P. Xing and T. Jebara (Eds.), Proceedings of Machine Learning Research, Vol. 32, Bejing, China, pp. 1278–1286. External Links: Link Cited by: §3.3.
  • V. Sanh, L. Debut, J. Chaumond, and T. Wolf (2019) DistilBERT, a distilled version of bert: smaller, faster, cheaper and lighter. External Links: 1910.01108 Cited by: §2.
  • S. Shen, Z. Dong, J. Ye, L. Ma, Z. Yao, A. Gholami, M. W. Mahoney, and K. Keutzer (2019) Q-bert: hessian based ultra low precision quantization of bert. External Links: 1909.05840 Cited by: §1.
  • R. Tang, Y. Lu, L. Liu, L. Mou, O. Vechtomova, and J. Lin (2019) Distilling task-specific knowledge from bert into simple neural networks. External Links: 1903.12136 Cited by: §2.
  • H. Tsai, J. Riesa, M. Johnson, N. Arivazhagan, X. Li, and A. Archer (2019) Small and practical bert models for sequence labeling. External Links: 1909.00100 Cited by: §2.
  • I. Turc, M. Chang, K. Lee, and K. Toutanova (2019) Well-read students learn better: on the importance of pre-training compact models. External Links: 1908.08962 Cited by: §2.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. External Links: 1706.03762 Cited by: §3.2.
  • E. Voita, D. Talbot, F. Moiseev, R. Sennrich, and I. Titov (2019) Analyzing multi-head self-attention: specialized heads do the heavy lifting, the rest can be pruned. CoRR abs/1905.09418. External Links: Link, 1905.09418 Cited by: §2.
  • Z. Wang, J. Wohlwend, and T. Lei (2019) Structured pruning of large language models. External Links: 1910.04732 Cited by: §2.
  • Z. Yang, L. Shou, M. Gong, W. Lin, and D. Jiang (2019a) Model compression with multi-task knowledge distillation for web-scale question answering system. External Links: 1904.09636 Cited by: §2.
  • Z. Yang, Z. Dai, Y. Yang, J. G. Carbonell, R. Salakhutdinov, and Q. V. Le (2019b) XLNet: generalized autoregressive pretraining for language understanding. CoRR abs/1906.08237. External Links: Link, 1906.08237 Cited by: §1.

Appendix A Supplemental Material