The recent surge in NLP model complexity has outstripped Moore's law, with parameter counts growing rapidly from ELMo Peters et al. (2018) to BERT Devlin et al. (2018) to Megatron Narasimhan (2019). Deeply stacked layers of transformers (including BERT, RoBERTa Liu et al. (2019), XLNet Yang et al. (2019b), and ALBERT Lan et al. (2019)) have greatly improved state-of-the-art accuracies across a variety of NLP tasks, but their computational intensity raises concerns in the cloud-computing economy. Numerous techniques developed to shrink neural networks, including distillation, quantization, and pruning, are now being applied to transformers.
Question answering, in particular, has immediate applications in real-time systems, and has seen striking gains in accuracy due to transformers, as measured on the SQuAD Rajpurkar et al. (2016) and SQuAD 2.0 Rajpurkar et al. (2018) leaderboards. SQuAD is reported as a worst case for performance loss among speedup techniques based on quantization Shen et al. (2019), and the difficulty of distilling a SQuAD model (compared to sentence-level GLUE tasks) is acknowledged in Jiao et al. (2019). We speculate that these difficulties arise because answer selection requires token-level rather than passage-level annotation, and because of the need for long-range attention between query and passage.
In this paper we investigate pruning three aspects of BERT:
(1) the number of attention heads,
(2) the size of the intermediate (feed-forward) layer, and
(3) the embedding or hidden dimension.
The contributions of this paper are (1) application of structured pruning techniques to the feed-forward layer and the hidden dimension of the transformers, not just the attention heads, (2) thereby significantly pruning BERT, with minimal loss of accuracy on a question answering task and considerable speedup, all without the expense of revisiting pretraining, and (3) a survey of multiple pruning techniques (both heuristic and trainable) with recommendations specific to transformer-based question answering models.
Widely distributed pre-trained models typically consist of 12-24 layers of identically sized transformers. We will see that an optimal pruning yields non-identical transformers: lightweight transformers near the top and bottom, with more complexity retained in the intermediate layers.
2 Related work
While distillation (student-teacher) of BERT has produced notably smaller models Tang et al. (2019); Turc et al. (2019); Tsai et al. (2019); Yang et al. (2019a), the focus has been on sentence-level annotation tasks that do not require long-range attention. Distillation also often requires revisiting the expensive pretraining phase. DistilBERT Sanh et al. (2019) reports modest speedup and small performance loss on SQuAD 1.1. TinyBERT Jiao et al. (2019) restricts SQuAD evaluation to using BERT-base as a teacher, and defers deeper investigation to future work.
Our work is perhaps most similar to Fan et al. (2019), an exploration of pruning as a form of dropout. They prune entire layers of BERT, but suggest that smaller structures could also be pruned. They evaluate on MT, language modeling, and generation-like tasks, but not SQuAD. L0 regularization was combined with matrix factorization to prune transformers in Wang et al. (2019). Gale et al. (2019) induced unstructured sparsity in a transformer-based MT model, but did not report speedups. Voita et al. (2019) focused on the linguistic interpretability of attention heads and pruned the heads of a transformer MT model with L0 regularization, but did not report speedups. Kovaleva et al. (2019) also focused on interpreting attention, and achieved small accuracy gains on GLUE tasks by disabling (but not pruning) certain attention heads. Michel et al. (2019) achieved speedups on MT and MNLI by gating only the attention with simple heuristics.
3 Pruning transformers
The size of a BERT model is characterized by the values in table 1.
3.2 Gate placement
Our approach to pruning each aspect of a transformer is similar. We insert three masks into each transformer. Each mask is a vector of binary gate variables: a gate value of 0 indicates a slice of transformer parameters to be pruned, and a value of 1 indicates a slice that remains active. We describe the placement of each mask following the terminology of Vaswani et al. (2017), indicating the relevant sections of that paper.
In each self-attention sublayer, we place a mask whose length equals the number of attention heads, selecting the heads that remain active. (section 3.2.2)
In each feed-forward sublayer, we place a mask whose length equals the intermediate size, selecting the ReLU/GeLU activations that remain active. (section 3.3)
The final mask, whose length equals the hidden dimension, selects which embedding dimensions remain active (section 3.4). This gate is applied identically to both input and residual connections in each transformer.
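The gate placement above can be sketched with plain NumPy (a minimal illustration, not the paper's implementation; the array names and toy dimensions are our own):

```python
import numpy as np

# Toy dimensions (illustrative only; e.g. BERT-base has 12 heads and
# intermediate size 3072).
seq, n_heads, head_dim, d_ff, d_model = 3, 4, 8, 64, 32

# One binary gate per prunable slice: 0 => prune, 1 => keep.
g_attn = np.ones(n_heads)   # one gate per attention head
g_ff = np.ones(d_ff)        # one gate per ReLU/GeLU activation
g_emb = np.ones(d_model)    # one gate per embedding/hidden dimension

def masked_attention(per_head_out, g_attn):
    """per_head_out: (seq, n_heads, head_dim). Zero gated heads, then concat."""
    seq_len = per_head_out.shape[0]
    return (per_head_out * g_attn[None, :, None]).reshape(seq_len, -1)

def masked_ffn(x, W1, b1, W2, b2, g_ff):
    """Gate the intermediate activations of the feed-forward sublayer."""
    h = np.maximum(0.0, x @ W1 + b1) * g_ff   # ReLU, then apply the mask
    return h @ W2 + b2

def masked_embedding(x, g_emb):
    """Applied identically to the input and residual connections."""
    return x * g_emb
```

Setting a gate to zero makes the corresponding slice contribute nothing to the output, so the slice can later be removed physically without changing the result.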
3.3 Determining gate values
We investigate four approaches to determining the gate values.
(1) “random:” each gate is sampled independently from a Bernoulli distribution whose parameter is manually adjusted to control the sparsity.
(2) “gain:” We follow the method of Michel et al. (2019) and estimate the influence of each gate on the training-set likelihood by computing the mean absolute gradient of the likelihood with respect to that gate (the “head importance score”) during one pass over the training data. We threshold these scores to determine which transformer slices to retain.
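As a toy illustration of the “gain” heuristic (our own sketch; real per-batch gradients would come from backpropagation, here they are simulated with fixed means):

```python
import numpy as np

# Importance of gate i is the mean |dL/dg_i| over the training data,
# accumulated one minibatch at a time with all gates held at 1.
rng = np.random.default_rng(0)
n_gates, n_batches = 6, 100

importance = np.zeros(n_gates)
for _ in range(n_batches):
    # Stand-in for the per-batch gradient dL/dg (simulated, not computed).
    grad = rng.normal(loc=[5, 4, 3, 2, 1, 0], scale=0.1)
    importance += np.abs(grad)
importance /= n_batches

def prune_by_threshold(importance, keep_fraction):
    """Retain the gates with the largest importance scores."""
    k = int(round(keep_fraction * len(importance)))
    keep = np.argsort(importance)[::-1][:k]
    gates = np.zeros(len(importance))
    gates[keep] = 1.0
    return gates

gates = prune_by_threshold(importance, keep_fraction=0.5)
```

Only one pass over the data is needed, regardless of the number of gates.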
(3) “leave-one-out:” We again follow the method of Michel et al. (2019) and evaluate the dev-set score of a system with exactly one gate set to zero. Note that this procedure requires one pass through the data per gate. We control the sparsity by retaining those gates whose removal causes the largest drop in score.
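A minimal sketch of the leave-one-out procedure, assuming a hypothetical score_fn that evaluates the dev-set metric for a given gate configuration; the one-evaluation-per-gate cost is explicit in the loop:

```python
import numpy as np

def leave_one_out_scores(score_fn, n_gates):
    """score_fn(gates) -> dev-set metric. Requires n_gates full evaluations."""
    base = np.ones(n_gates)
    full = score_fn(base)
    drops = np.empty(n_gates)
    for i in range(n_gates):
        g = base.copy()
        g[i] = 0.0                     # disable exactly one gate
        drops[i] = full - score_fn(g)  # large drop => important slice, keep it
    return drops
```

Gates with the largest drops are retained; the rest are pruned.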
(4) “L0 regularization:” Following the method described in Louizos et al. (2017), during training the gate variables are sampled from a hard-concrete distribution Maddison et al. (2017), each parameterized by a corresponding trainable variable. The task-specific objective function is penalized in proportion to the expected number of nonzero gates. Proportionality constants for the attention, feed-forward, and embedding penalty terms are manually adjusted to control the sparsity. We resample the gates with each minibatch. We note that the full objective function is differentiable with respect to the underlying variables because of the reparameterization trick Kingma and Welling (2014); Rezende et al. (2014). These variables are updated by backpropagation for one training epoch on the SQuAD training data, with all other parameters held fixed. The final values of the gates are obtained by thresholding.
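The hard-concrete sampling and the expected-L0 penalty can be sketched as follows (a simplified single-vector illustration following Louizos et al. (2017); the constants are the temperature and stretch limits suggested in that paper, and the function names are ours):

```python
import numpy as np

beta, gamma, zeta = 2.0 / 3.0, -0.1, 1.1   # temperature and stretch limits

def sample_hard_concrete(log_alpha, rng):
    """Sample gates: a stretched, clipped binary-concrete draw per gate."""
    u = rng.uniform(1e-6, 1 - 1e-6, size=log_alpha.shape)
    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1 - u) + log_alpha) / beta))
    s_bar = s * (zeta - gamma) + gamma     # stretch to (gamma, zeta)
    return np.clip(s_bar, 0.0, 1.0)        # hard clip yields exact 0s and 1s

def expected_l0(log_alpha):
    """Differentiable penalty: probability that each gate is nonzero."""
    return 1.0 / (1.0 + np.exp(-(log_alpha - beta * np.log(-gamma / zeta))))
```

The penalty term is the sum of expected_l0 over the gates of each mask, weighted by the corresponding proportionality constant.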
After the gate values have been determined by one of the above methods, the model is pruned. Attention heads whose gates are zero are removed, as are the corresponding slices of the feed-forward linear transformations. The pruned model no longer needs masks, and now consists of transformers of varying, non-identical sizes.
We note that task-specific training of all BERT parameters may be continued further with the pruned model.
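Once the gate values are fixed, the pruning itself amounts to slicing the weight matrices, e.g. (illustrative parameter names, not the actual BERT checkpoint names):

```python
import numpy as np

def prune_ffn(W1, b1, W2, g_ff):
    """Drop feed-forward slices whose gates are zero.
    W1: (d_model, d_ff), b1: (d_ff,), W2: (d_ff, d_model)."""
    keep = np.flatnonzero(g_ff)
    return W1[:, keep], b1[keep], W2[keep, :]

def prune_heads(Wq, Wk, Wv, Wo, g_attn, head_dim):
    """Drop whole attention heads: remove head_dim-wide blocks from the
    Q/K/V projections and the matching rows of the output projection."""
    cols = np.concatenate([np.arange(h * head_dim, (h + 1) * head_dim)
                           for h in np.flatnonzero(g_attn)])
    return Wq[:, cols], Wk[:, cols], Wv[:, cols], Wo[cols, :]
```

After slicing, the masks are discarded, and each layer may end up with a different number of heads and activations.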
Table 1 (columns): model, time (sec), f1, attn-prune, ff-prune, size (MiB).
4 Experiments
For development experiments (learning rate and penalty weight exploration), and in order to minimize overuse of the official dev set, we train gates on a subset of the official SQuAD 2.0 training data and report results on the remaining held-out portion. Our development experiments (base-qa) are all initialized from a SQuAD 2.0 system based on bert-base-uncased and trained on this split, which provides the baseline performance on this data. Our baseline SQuAD model depends upon code distributed by https://github.com/huggingface/transformers, and incorporates either bert-base-uncased or bert-large-uncased with a standard task-specific head.
Our validation experiments (large-qa) use the standard training/dev configuration of SQuAD 2.0. All are initialized from the system of Glass et al. (2019), which was itself initialized from bert-large-uncased, and whose score on the official dev set serves as our baseline.
The gate parameters of the “L0 regularization” experiments are trained for one epoch starting from the models above, with all transformer and embedding parameters fixed. The cost of training the gate parameters is comparable to extending fine-tuning for an additional epoch. We investigated three learning rates on base-qa, and chose the last of these for presentation and for results on large-qa. This is notably larger than typical learning rates used to fine-tune BERT parameters. We used a minibatch size of 24 and otherwise the default hyperparameters of the BERT-Adam optimizer. We used identical parameters for our large-qa experiments, except with 3 gradient accumulation steps. Tables report median values across 5 random seeds; graphs overplot results for 5 seeds.
4.1 Accuracy as function of pruning
In figure 2 we plot the accuracy of base-qa as a function of the percentage of attention heads removed. As expected, the performance of “random” decays most abruptly. “Leave-one-out” and “gain” are better, and substantially similar to each other. “L0 regularization” is best, allowing considerable pruning at a small cost in F1.
In figure 3 we plot accuracy as a function of the percentage of feed-forward activations removed. We see broadly similar trends as above, except that performance is robust to even larger amounts of pruning. “Leave-one-out” is omitted here, as it would require a prohibitive number of passes through the data.
In figure 4 we plot the accuracy for removing embedding dimensions. Performance falls much more steeply with the removal of embedding dimensions. Attempts to train “L0 regularization” were unsuccessful; we speculate that the strong cross-layer coupling may necessitate a different learning-rate schedule.
4.2 Validating these results
On the basis of the development experiments, we select operating points (values of the attention and feed-forward penalty weights) and train the gates of large-qa with these penalties. The decoding times, accuracies, and model sizes are summarized in table 1. Models in which both attention and feed-forward components are pruned were produced by combining the independently trained gate configurations for attention and feed-forward. For the same penalty values, the large model is pruned somewhat less than the base model, and its loss due to pruning is somewhat smaller. We note that much of the performance loss can be recovered by continuing task-specific training for an additional epoch after pruning.
The speedup in decoding due to pruning the model is not simply proportional to the amount pruned. There are computations in both the attention and feed-forward part of each transformer layer that necessarily remain unpruned, for example layer normalization.
4.3 Impact of pruning each layer
In Fig. 5 we show the percentage of attention heads and feed-forward activations remaining after pruning, by layer. The intermediate layers retained more, while layers close to the embedding and close to the answer were pruned more heavily.
5 Conclusion
We investigated various methods to prune transformer-based models, and evaluated the accuracy-speed tradeoff of this pruning. We find that both the attention heads and especially the feed-forward layer can be pruned considerably with minimal loss of accuracy, while pruning the embedding/hidden dimension is ineffective because of the resulting loss in accuracy. We find that L0 regularization pruning, when successful, is considerably more effective than heuristic methods. We also find that pruning of the feed-forward layer and the attention heads can easily be combined and, especially after retraining, yields a considerably faster question answering model with minimal loss in accuracy.
References
- Devlin et al. (2018). BERT: pre-training of deep bidirectional transformers for language understanding. CoRR abs/1810.04805.
- Fan et al. (2019). Reducing transformer depth on demand with structured dropout.
- Gale et al. (2019). The state of sparsity in deep neural networks. CoRR abs/1902.09574.
- Glass et al. (2019). Span selection pre-training for question answering.
- Jiao et al. (2019). TinyBERT: distilling BERT for natural language understanding.
- Kingma and Welling (2014). Auto-encoding variational Bayes. International Conference on Learning Representations.
- Kovaleva et al. (2019). Revealing the dark secrets of BERT. CoRR.
- Lan et al. (2019). ALBERT: a lite BERT for self-supervised learning of language representations.
- Liu et al. (2019). RoBERTa: a robustly optimized BERT pretraining approach. CoRR abs/1907.11692.
- Louizos et al. (2017). Learning sparse neural networks through L0 regularization.
- Maddison et al. (2017). The concrete distribution: a continuous relaxation of discrete random variables. ICLR 2017.
- Michel et al. (2019). Are sixteen heads really better than one? CoRR abs/1905.10650.
- Narasimhan (2019). NVIDIA clocks world's fastest BERT training time and largest transformer based model, paving path for advanced conversational AI.
- Peters et al. (2018). Deep contextualized word representations.
- Rajpurkar et al. (2016). SQuAD: 100,000+ questions for machine comprehension of text. CoRR abs/1606.05250.
- Rajpurkar et al. (2018). Know what you don't know: unanswerable questions for SQuAD. CoRR abs/1806.03822.
- Rezende et al. (2014). Stochastic backpropagation and approximate inference in deep generative models. Proceedings of the 31st International Conference on Machine Learning, PMLR vol. 32, pp. 1278-1286.
- Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter.
- Shen et al. (2019). Q-BERT: Hessian based ultra low precision quantization of BERT.
- Tang et al. (2019). Distilling task-specific knowledge from BERT into simple neural networks.
- Tsai et al. (2019). Small and practical BERT models for sequence labeling.
- Turc et al. (2019). Well-read students learn better: on the importance of pre-training compact models.
- Vaswani et al. (2017). Attention is all you need.
- Voita et al. (2019). Analyzing multi-head self-attention: specialized heads do the heavy lifting, the rest can be pruned. CoRR abs/1905.09418.
- Wang et al. (2019). Structured pruning of large language models.
- Yang et al. (2019a). Model compression with multi-task knowledge distillation for web-scale question answering system.
- Yang et al. (2019b). XLNet: generalized autoregressive pretraining for language understanding. CoRR abs/1906.08237.