Patient Knowledge Distillation for BERT Model Compression

08/25/2019 ∙ by Siqi Sun, et al. ∙ Microsoft

Pre-trained language models such as BERT have proven to be highly effective for natural language processing (NLP) tasks. However, the high demand for computing resources in training such models hinders their application in practice. In order to alleviate this resource hunger in large-scale model training, we propose a Patient Knowledge Distillation approach to compress an original large model (teacher) into an equally-effective lightweight shallow network (student). Different from previous knowledge distillation methods, which only use the output from the last layer of the teacher network for distillation, our student model patiently learns from multiple intermediate layers of the teacher model for incremental knowledge extraction, following two strategies: (i) PKD-Last: learning from the last k layers; and (ii) PKD-Skip: learning from every k layers. These two patient distillation schemes enable the exploitation of rich information in the teacher's hidden layers, and encourage the student model to patiently learn from and imitate the teacher through a multi-layer distillation process. Empirically, this translates into improved results on multiple NLP tasks with significant gain in training efficiency, without sacrificing model accuracy.




1 Introduction

Language model pre-training has proven to be highly effective in learning universal language representations from large-scale unlabeled data. ELMo Peters et al. (2018), GPT Radford et al. (2018) and BERT Devlin et al. (2018) have achieved great success in many NLP tasks, such as sentiment classification Socher et al. (2013), natural language inference Williams et al. (2017), and question answering Lai et al. (2017).

Despite its empirical success, BERT’s computational efficiency is a widely recognized issue because of its large number of parameters. For example, the original BERT-Base model has 12 layers and 110 million parameters. Training from scratch typically takes four days on 4 to 16 Cloud TPUs. Even fine-tuning the pre-trained model on a task-specific dataset may take several hours to finish one epoch. Thus, reducing computational costs for such models is crucial for their application in practice, where computational resources are limited.

Motivated by this, we investigate the redundancy issue of learned parameters in large-scale pre-trained models, and propose a new model compression approach, Patient Knowledge Distillation (Patient-KD), to compress the original teacher (e.g., BERT) into a lightweight student model without performance sacrifice. In our approach, the teacher model outputs probability logits and predicts labels for the training samples (extendable to additional unannotated samples), and the student model learns from the teacher network to mimic the teacher’s prediction.

Different from previous knowledge distillation methods Hinton et al. (2015); Sau and Balasubramanian (2016); Lu et al. (2017), we adopt a patient learning mechanism: instead of learning parameters from only the last layer of the teacher, we encourage the student model to extract knowledge also from previous layers of the teacher network. We call this ‘Patient Knowledge Distillation’. This patient learner has the advantage of distilling rich information through the deep structure of the teacher network for multi-layer knowledge distillation.

We also propose two different strategies for the distillation process: (i) PKD-Last: the student learns from the last k layers of the teacher, under the assumption that the top layers of the original network contain the most informative knowledge to teach the student; and (ii) PKD-Skip: the student learns from every k layers of the teacher, suggesting that the lower layers of the teacher network also contain important information and should be passed along for incremental distillation.

We evaluate the proposed approach on several NLP tasks, including Sentiment Classification, Paraphrase Similarity Matching, Natural Language Inference, and Machine Reading Comprehension. Experiments on seven datasets across these four tasks demonstrate that the proposed Patient-KD approach achieves superior performance and better generalization than standard knowledge distillation methods Hinton et al. (2015), with significant gains in training efficiency and storage reduction while maintaining accuracy comparable to the original large models. To the best of the authors’ knowledge, this is the first known effort for BERT model compression.

2 Related Work

Language Model Pre-training

Pre-training has been widely applied to universal language representation learning. Previous work can be divided into two main categories: (i) feature-based approaches; (ii) fine-tuning approaches.

Feature-based methods mainly focus on learning: (i) context-independent word representations (e.g., word2vec Mikolov et al. (2013), GloVe Pennington et al. (2014), FastText Bojanowski et al. (2017)); (ii) sentence-level representations (e.g., Kiros et al. (2015); Conneau et al. (2017); Logeswaran and Lee (2018)); and (iii) contextualized word representations (e.g., CoVe McCann et al. (2017), ELMo Peters et al. (2018)). Specifically, ELMo Peters et al. (2018) learns high-quality, deep contextualized word representations using a bidirectional language model, which can be directly plugged into standard NLU models for performance boosting.

On the other hand, fine-tuning approaches mainly pre-train a language model (e.g., GPT Radford et al. (2018), BERT Devlin et al. (2018)) on a large corpus with an unsupervised objective, and then fine-tune the model with in-domain labeled data for downstream applications Dai and Le (2015); Howard and Ruder (2018). Specifically, BERT is a large-scale language model consisting of multiple layers of Transformer blocks Vaswani et al. (2017). BERT-Base has 12 layers of Transformer and 110 million parameters, while BERT-Large has 24 layers of Transformer and 330 million parameters. By pre-training via masked language modeling and next sentence prediction, BERT has achieved state-of-the-art performance on a wide range of NLU tasks, such as the GLUE benchmark Wang et al. (2018) and SQuAD Rajpurkar et al. (2016).

However, these modern pre-trained language models contain millions of parameters, which hinders their application in practice where computational resources are limited. In this paper, we aim to address this critical and challenging problem, taking BERT as an example, i.e., how to compress a large BERT model into a shallower one without sacrificing performance. Moreover, the proposed approach can also be applied to other large-scale pre-trained language models, such as the recently proposed XLNet Yang et al. (2019) and RoBERTa Liu et al. (2019b).

Model Compression & Knowledge Distillation

Our focus is model compression, i.e., making deep neural networks more compact Han et al. (2016); Cheng et al. (2015). A similar line of work has focused on accelerating deep network inference at test time Vetrov et al. (2017) and reducing model training time Huang et al. (2016).

A conventional understanding is that a large number of connections (weights) is necessary for training deep networks Denil et al. (2013); Zhai et al. (2016). However, once the network has been trained, there will be a high degree of parameter redundancy. Network pruning Han et al. (2015); He et al. (2017), in which network connections are reduced or sparsified, is one common strategy for model compression. Another direction is weight quantization Gong et al. (2014); Polino et al. (2018), in which connection weights are constrained to a set of discrete values, allowing weights to be represented by fewer bits. However, most of these pruning and quantization approaches are applied to convolutional networks. Only a few works are designed for models with rich structural information, such as deep language models Changpinyo et al. (2017).

Knowledge distillation Hinton et al. (2015) aims to compress a network with a large set of parameters into a compact and fast-to-execute model. This can be achieved by training a compact model to imitate the soft output of a larger model. Romero et al. (2015) further demonstrated that intermediate representations learned by the large model can serve as hints to improve the training process and the final performance of the compact model. Chen et al. (2015) introduced techniques for efficiently transferring knowledge from an existing network to a deeper or wider network. More recently, Liu et al. (2019a) used knowledge from ensemble models to improve single-model performance on NLU tasks. Tan et al. (2019) tried knowledge distillation for multilingual translation. Different from the above efforts, we investigate the problem of compressing large-scale language models, and propose a novel patient knowledge distillation approach to effectively transfer knowledge from a teacher to a student model.

3 Patient Knowledge Distillation

In this section, we first introduce a vanilla knowledge distillation method for BERT compression (Section 3.1), then present the proposed Patient Knowledge Distillation (Section 3.2) in detail.

Problem Definition

The original large teacher network is represented by a function $f(\mathbf{x}; \theta)$, where $\mathbf{x}$ is the input to the network, and $\theta$ denotes the model parameters. The goal of knowledge distillation is to learn a new set of parameters $\theta'$ for a shallower student network $g(\mathbf{x}; \theta')$, such that the student network achieves similar performance to the teacher, with much lower computational cost. Our strategy is to force the student model to imitate outputs from the teacher model on the training dataset with a defined objective $L_{KD}$.

3.1 Distillation Objective

In our setting, the teacher $f(\mathbf{x}; \theta)$ is defined as a deep bidirectional encoder, e.g., BERT, and the student $g(\mathbf{x}; \theta')$ is a lightweight model with fewer layers. For simplicity, we use BERT_k to denote a model with k layers of Transformers. Following the original BERT paper Devlin et al. (2018), we also use BERT-Base and BERT-Large to denote BERT_12 and BERT_24, respectively.

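Assume $\{\mathbf{x}_i, \mathbf{y}_i\}_{i=1}^{N}$ are $N$ training samples, where $\mathbf{x}_i$ is the $i$-th input instance for BERT, and $\mathbf{y}_i$ is the corresponding ground-truth label. BERT first computes a contextualized embedding $\mathbf{h}_i = \text{BERT}(\mathbf{x}_i) \in \mathbb{R}^d$. Then, a softmax layer $\hat{\mathbf{y}}_i = P(\mathbf{y}_i \mid \mathbf{x}_i) = \text{softmax}(\mathbf{W}\mathbf{h}_i)$ for classification is applied to the embedding of BERT output, where $\mathbf{W}$ is a weight matrix to be learned.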

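To apply knowledge distillation, we first need to train a teacher network. For example, to train a 12-layer BERT-Base as the teacher model, the learned parameters are denoted as:

$$\hat{\theta}^t = \arg\min_{\theta} \sum_{i \in [N]} L_{CE}^t\big(\mathbf{x}_i, \mathbf{y}_i; [\theta_{\text{BERT}_{12}}, \mathbf{W}]\big) \tag{1}$$

where the superscript $t$ denotes parameters in the teacher model, $[N]$ denotes the set $\{1, 2, \dots, N\}$, $L_{CE}^t$ denotes the cross-entropy loss for the teacher training, and $\theta_{\text{BERT}_{12}}$ denotes the parameters of BERT_12.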

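The output probability for any given input $\mathbf{x}_i$ can be formulated as:

$$\hat{\mathbf{y}}_i = P^t(\mathbf{y}_i \mid \mathbf{x}_i) = \text{softmax}\left(\frac{\mathbf{W}\mathbf{h}_i}{T}\right) = \text{softmax}\left(\frac{\mathbf{W} \cdot \text{BERT}_{12}(\mathbf{x}_i; \hat{\theta}^t)}{T}\right) \tag{2}$$

where $P^t(\cdot \mid \cdot)$ denotes the probability output from the teacher. $\hat{\mathbf{y}}_i$ is fixed as soft labels, and $T$ is the temperature used in KD, which controls how much to rely on the teacher’s soft predictions. A higher temperature produces a softer, more diverse probability distribution over classes Hinton et al. (2015). Similarly, let $\theta^s$ denote parameters to be learned for the student model, and $P^s(\cdot \mid \cdot)$ denote the corresponding probability output from the student model. Thus, the distance between the teacher’s prediction and the student’s prediction can be defined as:

$$L_{DS} = -\sum_{i \in [N]} \sum_{c \in C} \Big[ P^t(\mathbf{y}_i = c \mid \mathbf{x}_i; \hat{\theta}^t) \cdot \log P^s(\mathbf{y}_i = c \mid \mathbf{x}_i; \theta^s) \Big] \tag{3}$$

where $c$ is a class label and $C$ denotes the set of class labels.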
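To make the role of the temperature $T$ in Eqn. (2) concrete, the following minimal Python sketch (standard library only) applies a temperature-scaled softmax to a hypothetical vector of teacher logits $\mathbf{W}\mathbf{h}_i$ and reports the entropy of the result. The logit values and helper names are illustrative assumptions, not numbers or code from our experiments.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], T: float = 1.0) -> List[float]:
    """Return softmax(logits / T), computed stably by subtracting the max."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    shift = max(scaled)
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; a higher value means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


if __name__ == "__main__":
    # Hypothetical 3-class teacher logits W h_i (illustrative only).
    teacher_logits = [4.0, 1.0, 0.5]
    for T in (1.0, 5.0, 20.0):
        probs = softmax_with_temperature(teacher_logits, T)
        print(f"T={T:>4}: probs={[round(p, 3) for p in probs]}, "
              f"entropy={entropy(probs):.3f}")
```

As $T$ grows, the distribution flattens (its entropy increases) while the ranking of the classes stays the same, so a higher temperature exposes more of the teacher's relative preferences among the non-target classes.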

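Besides encouraging the student model to imitate the teacher’s behavior, we can also fine-tune the student model on target tasks, where a task-specific cross-entropy loss is included for model training:

$$L_{CE}^s = -\sum_{i \in [N]} \sum_{c \in C} \Big[ \mathbb{1}[\mathbf{y}_i = c] \cdot \log P^s(\mathbf{y}_i = c \mid \mathbf{x}_i; \theta^s) \Big] \tag{4}$$

where $\mathbb{1}[\cdot]$ is the indicator function.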

Thus, the final objective function for knowledge distillation can be formulated as:

$$L_{KD} = (1 - \alpha) L_{CE}^s + \alpha L_{DS} \tag{5}$$

where $\alpha$ is the hyper-parameter that balances the importance of the cross-entropy loss and the distillation loss.
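Eqns. (3) to (5) can be sketched in plain Python for a batch of examples as follows. This is a minimal illustration rather than our training code: it assumes the student's soft distribution in $L_{DS}$ uses the same temperature $T$ as the teacher (as in Hinton et al. (2015)), while the hard-label term $L_{CE}^s$ uses $T = 1$; no $T^2$ rescaling is applied because Eqn. (5) has none. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], T: float = 1.0) -> List[float]:
    """Numerically stable log(softmax(logits / T))."""
    scaled = [z / T for z in logits]
    shift = max(scaled)
    log_total = math.log(sum(math.exp(z - shift) for z in scaled))
    return [z - shift - log_total for z in scaled]


def kd_loss(
    teacher_logits: Sequence[Sequence[float]],
    student_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float,
    T: float,
) -> float:
    """L_KD = (1 - alpha) * L_CE^s + alpha * L_DS, summed over the batch (Eqn. 5)."""
    l_ds = 0.0  # Eqn. (3): cross-entropy against the teacher's soft labels
    l_ce = 0.0  # Eqn. (4): cross-entropy against the ground-truth labels
    for t_z, s_z, y in zip(teacher_logits, student_logits, labels):
        # Teacher soft labels at temperature T (treated as fixed targets).
        p_teacher = [math.exp(v) for v in log_softmax(t_z, T)]
        # Assumption: the student is softened with the same temperature T.
        log_p_student_soft = log_softmax(s_z, T)
        l_ds -= sum(pt * lps for pt, lps in zip(p_teacher, log_p_student_soft))
        # The indicator 1[y_i = c] keeps only the gold class; here T = 1.
        l_ce -= log_softmax(s_z, 1.0)[y]
    return (1.0 - alpha) * l_ce + alpha * l_ds
```

With $\alpha = 0$ the objective reduces to ordinary fine-tuning on hard labels, and with $\alpha = 1$ it relies only on the teacher's soft labels; $L_{DS}$ is minimized when the student's softened distribution matches the teacher's.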

Figure 1: Model architecture of the proposed Patient Knowledge Distillation approach to BERT model compression. (Left) PKD-Skip: the student network learns the teacher’s outputs in every 2 layers. (Right) PKD-Last: the student learns the teacher’s outputs from the last 6 layers. Trm: Transformer.

3.2 Patient Teacher for Model Compression

Using a weighted combination of ground-truth labels and soft predictions from the last layer of the teacher network, the student network can achieve comparable performance to the teacher model on the training set. However, as the number of epochs increases, the student model learned with this vanilla KD framework quickly reaches saturation on held-out data (see the dev-set curves in Figure 2 in Section 4).

One hypothesis is that overfitting during knowledge distillation may lead to poor generalization. To mitigate this issue, instead of forcing the student to learn only from the logits of the last layer, we propose a “patient” teacher-student mechanism to distill knowledge from the teacher’s intermediate layers as well. Specifically, we investigate two patient distillation strategies: (i) PKD-Skip: the student learns from every k layers of the teacher (Figure 1: Left); and (ii) PKD-Last: the student learns from the last k layers of the teacher (Figure 1: Right).

Learning from the hidden states of all the tokens is computationally expensive and may introduce noise. In the original BERT implementation Devlin et al. (2018), prediction is performed using only the output of the last layer’s [CLS] token. In some variants of BERT, such as SDNet Zhu et al. (2018), a weighted average of all layers’ [CLS] embeddings is applied. In general, the final logit can be computed based on $\mathbf{h}_{final} = \sum_{j \in [k]} w_j \mathbf{h}_j$, where $w_j$ could be either a learned parameter or a pre-defined hyper-parameter, $\mathbf{h}_j$ is the embedding of [CLS] from hidden layer $j$, and $k$ is the number of hidden layers. Derived from this, if the compressed model can learn from the representation of [CLS] in the teacher’s intermediate layers for any given input, it has the potential to gain a generalization ability similar to that of the teacher model.
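The weighted combination $\mathbf{h}_{final} = \sum_{j \in [k]} w_j \mathbf{h}_j$ can be written as a short Python function. This sketch is illustrative only: the function name and the toy 2-dimensional embeddings are hypothetical, and in a real model each $\mathbf{h}_j$ would be a $d$-dimensional [CLS] vector produced by layer $j$.

```python
from typing import List, Sequence


def combine_cls(cls_per_layer: Sequence[Sequence[float]],
                weights: Sequence[float]) -> List[float]:
    """Return h_final = sum_j w_j * h_j over the k layers' [CLS] embeddings."""
    if len(cls_per_layer) != len(weights):
        raise ValueError("need exactly one weight per layer")
    dim = len(cls_per_layer[0])
    h_final = [0.0] * dim
    for w_j, h_j in zip(weights, cls_per_layer):
        for d in range(dim):
            h_final[d] += w_j * h_j[d]
    return h_final


# Toy [CLS] embeddings for k = 3 layers (hypothetical values).
layers = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(combine_cls(layers, [0.0, 0.0, 1.0]))        # [2.0, 2.0]: last layer only, as in original BERT
print(combine_cls(layers, [1 / 3, 1 / 3, 1 / 3]))  # uniform weights: a plain average of all layers
```

A one-hot weight vector on the last layer recovers the original BERT prediction path, which is the special case that vanilla KD distills from.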

Motivated by this, in our Patient-KD framework, the student is cultivated to imitate the representations of only the [CLS] token in the intermediate layers, following the aforementioned intuition that the [CLS] token is important in predicting the final labels. For an input $\mathbf{x}_i$, the outputs of the [CLS] tokens for all the layers are denoted as:

$$\mathbf{h}_i = [\mathbf{h}_{i,1}, \mathbf{h}_{i,2}, \dots, \mathbf{h}_{i,k}] = \text{BERT}_k(\mathbf{x}_i) \in \mathbb{R}^{k \times d} \tag{6}$$
We denote the set of intermediate layers to distill knowledge from as $I_{pt}$. Take distilling from BERT_12 to BERT_6 as an example. For the PKD-Skip strategy, $I_{pt} = \{2, 4, 6, 8, 10\}$; and for the PKD-Last strategy, $I_{pt} = \{7, 8, 9, 10, 11\}$. Note that $k = 5$ in both cases, because the output from the last layer (e.g., Layer 12 for BERT-Base) is omitted: its hidden states are connected to the softmax layer, which is already included in the KD loss defined in Eqn. (5). In general, for a BERT student with $n$ layers, $k$ always equals $n - 1$.
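The two index sets above can be generated programmatically. The following Python sketch reproduces $I_{pt}$ for BERT_12 to BERT_6; its generalization to other depths (a skip step of $n_{teacher} / n_{student}$, or the $n_{student} - 1$ teacher layers just below the final one) is our assumption for illustration, and the function name is hypothetical.

```python
from typing import List


def pkd_layers(n_teacher: int, n_student: int, strategy: str) -> List[int]:
    """Return I_pt: 1-based teacher layers matched to student layers 1..n_student-1.

    The final layers of both networks are excluded because they already feed
    the softmax layer covered by the KD loss (Eqn. 5).
    """
    if not 1 < n_student <= n_teacher:
        raise ValueError("need 1 < n_student <= n_teacher")
    if strategy == "skip":
        if n_teacher % n_student != 0:
            raise ValueError("this sketch assumes n_teacher divisible by n_student")
        step = n_teacher // n_student
        return [step * j for j in range(1, n_student)]
    if strategy == "last":
        return list(range(n_teacher - n_student + 1, n_teacher))
    raise ValueError(f"unknown strategy: {strategy!r}")


print(pkd_layers(12, 6, "skip"))  # [2, 4, 6, 8, 10]
print(pkd_layers(12, 6, "last"))  # [7, 8, 9, 10, 11]
```

Student layer $j$ is paired with teacher layer $I_{pt}(j)$, the $j$-th entry of the returned list; in both cases the list has $n_{student} - 1 = 5$ entries, matching $k = 5$.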

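The additional training loss introduced by the patient teacher is defined as the mean-square loss between the normalized hidden states:

$$L_{PT} = \sum_{i=1}^{N} \sum_{j=1}^{M} \left\| \frac{\mathbf{h}_{i,j}^s}{\|\mathbf{h}_{i,j}^s\|_2} - \frac{\mathbf{h}_{i,I_{pt}(j)}^t}{\|\mathbf{h}_{i,I_{pt}(j)}^t\|_2} \right\|_2^2 \tag{7}$$

where $M$ denotes the number of layers in the student network, $N$ is the number of training samples, and the superscripts $s$ and $t$ in $\mathbf{h}$ indicate the student and the teacher model, respectively. Combined with the KD loss introduced in Section 3.1, the final objective function can be formulated as: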
$$L_{PKD} = (1 - \alpha) L_{CE}^s + \alpha L_{DS} + \beta L_{PT} \tag{8}$$

where $\beta$ is another hyper-parameter that weights the importance of the features for distillation in the intermediate layers.
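A minimal Python sketch of Eqns. (7) and (8) follows. It assumes the [CLS] vectors are given as plain lists (in practice they would be tensors taken from each layer) and that the $j$-th entry of a hypothetical `layer_map` argument (the list $I_{pt}$) is paired with the $j$-th student layer; all function names are illustrative.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(v: Vector) -> List[float]:
    """Scale v to unit L2 norm (assumes v is not the zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def patient_loss(student_cls: Sequence[Sequence[Vector]],
                 teacher_cls: Sequence[Sequence[Vector]],
                 layer_map: Sequence[int]) -> float:
    """L_PT (Eqn. 7): squared distance between normalized [CLS] states.

    student_cls[i][j] is the [CLS] vector of student layer j+1 for sample i,
    teacher_cls[i][l] is the [CLS] vector of teacher layer l+1, and
    layer_map[j] is the 1-based teacher layer I_pt(j) paired with student layer j+1.
    """
    total = 0.0
    for s_layers, t_layers in zip(student_cls, teacher_cls):
        for j, t_layer in enumerate(layer_map):
            h_s = l2_normalize(s_layers[j])
            h_t = l2_normalize(t_layers[t_layer - 1])
            total += sum((a - b) ** 2 for a, b in zip(h_s, h_t))
    return total


def pkd_objective(l_ce_s: float, l_ds: float, l_pt: float,
                  alpha: float, beta: float) -> float:
    """L_PKD (Eqn. 8) = (1 - alpha) * L_CE^s + alpha * L_DS + beta * L_PT."""
    return (1.0 - alpha) * l_ce_s + alpha * l_ds + beta * l_pt
```

Because both vectors are normalized, each squared term equals $2 - 2\cos\phi$ for the angle $\phi$ between them, so it lies in $[0, 4]$ and depends only on the directions of the hidden states, not their magnitudes.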

4 Experiments

In this section, we describe our experiments on applying the proposed Patient-KD approach to four different NLP tasks. Details on the datasets and experimental results are provided in the following sub-sections.

Model SST-2 MRPC QQP MNLI-m MNLI-mm QNLI RTE
(# train) (67k) (3.7k) (364k) (393k) (393k) (105k) (2.5k)
BERT_12 (Google) 93.5 88.9/84.8 71.2/89.2 84.6 83.4 90.5 66.4
BERT_12 (Teacher) 94.3 89.2/85.2 70.9/89.0 83.7 82.8 90.4 69.1
BERT_6-FT 90.7 85.9/80.2 69.2/88.2 80.4 79.7 86.7 63.6
BERT_6-KD 91.5 86.2/80.6 70.1/88.8 80.2 79.8 88.3 64.7
BERT_6-PKD 92.0 85.0/79.9 70.7/88.9 81.5 81.0 89.0 65.5
BERT_3-FT 86.4 80.5/72.6 65.8/86.9 74.8 74.3 84.3 55.2
BERT_3-KD 86.9 79.5/71.1 67.3/87.6 75.4 74.8 84.0 56.2
BERT_3-PKD 87.5 80.7/72.5 68.1/87.8 76.7 76.3 84.7 58.2
Table 1: Results from the GLUE test server. The best results for 3-layer and 6-layer models are in bold. Google’s submission results are obtained from the official GLUE leaderboard. BERT_12 (Teacher) is our own implementation of the BERT teacher model. FT represents direct fine-tuning on each dataset without using knowledge distillation. KD represents using a vanilla knowledge distillation method. PKD represents our proposed Patient-KD-Skip approach. Results show that PKD-Skip outperforms the baselines on almost all the datasets except for MRPC. The numbers under each dataset indicate the corresponding number of training samples.
Figure 2: Accuracy on the training and dev sets of the QNLI and MNLI datasets, obtained by applying vanilla knowledge distillation (KD) and the proposed Patient-KD-Skip. The teacher and the student networks are BERT_12 and BERT_6, respectively. The student network learned with vanilla KD quickly saturates on the dev set, while the proposed Patient-KD starts to plateau only at a later stage.

4.1 Datasets

We evaluate our proposed approach on Sentiment Classification, Paraphrase Similarity Matching, Natural Language Inference, and Machine Reading Comprehension tasks. For Sentiment Classification, we test on the Stanford Sentiment Treebank (SST-2) Socher et al. (2013). For Paraphrase Similarity Matching, we use the Microsoft Research Paraphrase Corpus (MRPC) Dolan and Brockett (2005) and Quora Question Pairs (QQP) datasets. For Natural Language Inference, we evaluate on Multi-Genre Natural Language Inference (MNLI) Williams et al. (2017), QNLI (derived from the Stanford Question Answering Dataset, SQuAD, Rajpurkar et al. (2016)), and Recognizing Textual Entailment (RTE).

More specifically, SST-2 is a movie review dataset with binary annotations, where the label indicates a positive or negative review. MRPC contains pairs of sentences with labels indicating whether the two sentences in each pair are semantically equivalent. QQP is designed to predict whether a pair of questions is duplicate or not, and is collected from the popular online question-answering website Quora. MNLI is a multi-domain NLI task for predicting whether a given premise-hypothesis pair is entailment, contradiction, or neutral. Its test and development datasets are further divided into in-domain (MNLI-m) and cross-domain (MNLI-mm) splits to evaluate the generality of tested models. QNLI is a task for predicting whether a question-answer pair is entailment or not. Finally, RTE is based on a series of textual entailment challenges and is included in the General Language Understanding Evaluation (GLUE) benchmark Wang et al. (2018).

For the Machine Reading Comprehension task, we evaluate on RACE Lai et al. (2017), a large-scale dataset collected from English exams, containing 25,137 passages and 87,866 questions. For each question, four candidate answers are provided, only one of which is correct. The dataset is further divided into RACE-M and RACE-H, containing exam questions for middle school and high school students, respectively.

4.2 Baselines and Training Details

For experiments on the GLUE benchmark, since all the tasks can be considered as sentence (or sentence-pair) classification, we use the same architecture as in the original BERT Devlin et al. (2018), and fine-tune on each task independently.

For experiments on RACE, we denote the input passage as $P$, the question as $q$, and the four answers as $a_1, \dots, a_4$. We first concatenate the tokens in $q$ and each $a_i$, and arrange the input of BERT as [CLS] $P$ [SEP] $q + a_i$ [SEP] for each input pair $(P, q + a_i)$, where [CLS] and [SEP] are the special tokens used in the original BERT. In this way, we obtain a single logit value for each $a_i$. Finally, a softmax layer is placed on top of these four logits to obtain the normalized probability of each answer being correct, which is then used to compute the cross-entropy loss for model training.
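The RACE input construction and answer scoring described above can be sketched as follows. This is an illustrative Python sketch operating on already-tokenized lists; WordPiece tokenization, truncation to the maximum sequence length, and the model that produces the four logits are left out, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def build_race_inputs(passage: Sequence[str], question: Sequence[str],
                      answers: Sequence[Sequence[str]]) -> List[List[str]]:
    """One sequence per candidate a_i: [CLS] P [SEP] q+a_i [SEP]."""
    return [["[CLS]", *passage, "[SEP]", *question, *answer, "[SEP]"]
            for answer in answers]


def answer_probs_and_loss(logits: Sequence[float],
                          gold: int) -> Tuple[List[float], float]:
    """Softmax over the per-answer logits and cross-entropy for the gold answer."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs, -math.log(probs[gold])


# Toy example with hypothetical tokens.
seqs = build_race_inputs(["the", "cat", "sat"], ["who", "sat", "?"],
                         [["the", "cat"], ["a", "dog"], ["nobody"], ["a", "bird"]])
print(seqs[0])
# ['[CLS]', 'the', 'cat', 'sat', '[SEP]', 'who', 'sat', '?', 'the', 'cat', '[SEP]']
```

Because each candidate is scored in its own sequence, the same single-logit classification head serves all four answers, and the softmax across them turns the task into a four-way choice.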

We fine-tune BERT-Base (denoted as BERT_12) as the teacher model to compute soft labels for each task independently, where the pre-trained model weights are obtained from Google’s official BERT repository, and use 3 and 6 layers of Transformers as the student models (BERT_3 and BERT_6), respectively. We initialize BERT_k with the first k layers of parameters from pre-trained BERT-Base, where $k \in \{3, 6\}$. To validate the effectiveness of our proposed approach, we first conduct direct fine-tuning on each task without using any soft labels. In order to reduce the hyper-parameter search space, we fix the number of hidden units in the final softmax layer at 768, the batch size at 32, and the number of epochs at 4 for all the experiments, with a learning rate from {5e-5, 2e-5, 1e-5}. The model with the best validation accuracy is selected for each setting.

Besides direct fine-tuning, we further implement a vanilla KD method on all the tasks by optimizing the objective function in Eqn. (5). We set the temperature $T$ to {5, 10, 20}, choose $\alpha$ from a small set of candidate values, and perform grid search over $T$, $\alpha$, and learning rate, to select the model with the best validation accuracy. For our proposed Patient-KD approach, we conduct an additional search over $\beta$ from a set of candidate values on all the tasks. Since there are so many hyper-parameters to tune for Patient-KD, we fix $\alpha$ and $T$ to the values used in the model with the best performance from the vanilla KD experiments, and only search over $\beta$ and learning rate.
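The two-stage search can be sketched as below. This is a minimal Python illustration under stated assumptions: `evaluate` is a hypothetical callback that trains a student with a given configuration and returns its validation accuracy, the $\alpha$ and $\beta$ grids are supplied by the caller (their exact values are not restated here), and setting $\beta = 0$ in Eqn. (8) recovers the vanilla KD objective of Eqn. (5).

```python
import itertools
from typing import Callable, Dict, Sequence, Tuple

Config = Dict[str, float]

TEMPERATURES = (5.0, 10.0, 20.0)      # T grid from Section 4.2
LEARNING_RATES = (5e-5, 2e-5, 1e-5)   # learning-rate grid from Section 4.2


def two_stage_search(evaluate: Callable[[Config], float],
                     alphas: Sequence[float],
                     betas: Sequence[float],
                     temperatures: Sequence[float] = TEMPERATURES,
                     learning_rates: Sequence[float] = LEARNING_RATES,
                     ) -> Tuple[Config, Config]:
    """Stage 1 tunes vanilla KD; stage 2 fixes T and alpha and tunes beta and lr."""
    # Stage 1: vanilla KD is Patient-KD with beta = 0.
    kd_grid = ({"T": T, "alpha": a, "beta": 0.0, "lr": lr}
               for T, a, lr in itertools.product(temperatures, alphas, learning_rates))
    best_kd = max(kd_grid, key=evaluate)
    # Stage 2: reuse the best T and alpha, search only beta and learning rate.
    pkd_grid = ({"T": best_kd["T"], "alpha": best_kd["alpha"], "beta": b, "lr": lr}
                for b, lr in itertools.product(betas, learning_rates))
    best_pkd = max(pkd_grid, key=evaluate)
    return best_kd, best_pkd
```

The second stage costs only $|\beta\text{ grid}| \times |\text{lr grid}|$ training runs instead of the full product over all four hyper-parameters, at the price of assuming that the best $T$ and $\alpha$ for vanilla KD remain good choices once $L_{PT}$ is added.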

Model SST-2 MRPC QQP MNLI-m MNLI-mm QNLI RTE
BERT_6 (PKD-Last) 91.9 85.1/79.5 70.5/88.9 80.9 81.0 88.2 65.0
BERT_6 (PKD-Skip) 92.0 85.0/79.9 70.7/88.9 81.5 81.0 89.0 65.5
Table 2: Performance comparison between PKD-Last and PKD-Skip on the GLUE benchmark, with BERT_6 as the student.

4.3 Experimental Results

We submitted our model predictions to the official GLUE evaluation server to obtain results on the test data. Results are summarized in Table 1. Compared to direct fine-tuning and vanilla KD, our Patient-KD models with BERT_3 and BERT_6 students perform the best on almost all the tasks except MRPC. For MNLI-m and MNLI-mm, our 6-layer model improves 1.1% and 1.3% over the fine-tuning (FT) baselines; for QNLI and QQP, even though the gap between BERT_6-KD and the BERT_12 teacher is relatively small, our approach still improves over both the FT and KD baselines and further closes the gap between the student and the teacher models.

Furthermore, in 5 tasks out of 7 (SST-2 (-2.3% compared to the BERT-Base teacher), QQP (-0.1%), MNLI-m (-2.2%), MNLI-mm (-1.8%), and QNLI (-1.4%)), the proposed 6-layer student coached by the patient teacher achieved similar performance to the original BERT-Base, demonstrating the effectiveness of our approach. Interestingly, all five of these tasks have more than 60k training samples, which indicates that our method tends to perform better when there is a large amount of training data.
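The teacher-student gaps quoted above can be re-derived from Table 1 with a few lines of Python. For QQP we use the second reported number, which is the one that reproduces the -0.1% figure; the dictionary layout and variable names are our own.

```python
# Table 1 scores: BERT_12 teacher vs. BERT_6-PKD student, plus training-set sizes.
teacher = {"SST-2": 94.3, "QQP": 89.0, "MNLI-m": 83.7, "MNLI-mm": 82.8, "QNLI": 90.4}
student = {"SST-2": 92.0, "QQP": 88.9, "MNLI-m": 81.5, "MNLI-mm": 81.0, "QNLI": 89.0}
train_size_k = {"SST-2": 67, "QQP": 364, "MNLI-m": 393, "MNLI-mm": 393, "QNLI": 105}

# Absolute gap in points (student minus teacher), rounded to one decimal
# to remove floating-point noise from the subtraction.
gaps = {task: round(student[task] - teacher[task], 1) for task in teacher}

for task, gap in gaps.items():
    print(f"{task:8s} gap={gap:+.1f}  train={train_size_k[task]}k")

# Every task in this group has more than 60k training samples.
assert all(size > 60 for size in train_size_k.values())
```

The same check on MRPC (3.7k) and RTE (2.5k), the two small datasets, shows noticeably larger gaps, consistent with the observation about training-set size.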

For the QQP task, we can further reduce the model size to 3 layers, where BERT_3-PKD still achieves performance similar to the teacher model. The learning curves on the QNLI and MNLI datasets are provided in Figure 2. The student model learned with vanilla KD quickly saturates on the dev set, while the proposed Patient-KD keeps learning from the teacher and improving accuracy, only starting to plateau at a later stage.

For the MRPC dataset, one hypothesis for why vanilla KD outperforms our model is that the lack of enough training samples may lead to overfitting on the dev set. To investigate further, we repeat the experiments three times and compute the average accuracy on the dev set. We observe that fine-tuning and vanilla KD have mean dev accuracies of 82.23% and 82.84%, respectively. Our proposed method has a higher mean dev accuracy of 83.46%, which, together with its lower test score, indicates that our Patient-KD method slightly overfitted to the dev set of MRPC due to the small amount of training data. This can also be observed in the performance gap between teacher and student on RTE in Table 5, which also has a small training set.

We further investigate the performance gain from two different patient teacher designs: PKD-Last vs. PKD-Skip. Results of both PKD variants on the GLUE benchmark (with BERT_6 as the student) are summarized in Table 2. Although both strategies achieve improvement over the vanilla KD baseline (see Table 1), PKD-Skip performs slightly better than PKD-Last. Presumably, this might be because distilling information across every k layers captures more diverse representations of richer semantics from low level to high level, while focusing on the last k layers tends to capture relatively homogeneous semantic information.

Model RACE RACE-M RACE-H
BERT_12 (Leaderboard) 65.00 71.70 62.30
BERT_12 (Teacher) 65.30 71.17 62.89
BERT_6-FT 54.32 61.07 51.54
BERT_6-KD 58.74 64.69 56.29
BERT_6-PKD-Skip 60.34 66.57 57.78
Table 3: Results on the RACE test set. BERT_12 (Leaderboard) denotes results extracted from the official leaderboard. BERT_12 (Teacher) is our own implementation. Results of BERT_3 are not included due to the large gap between the teacher and the BERT_3 student.
# Layers # Params (Emb) # Params (Trm) Total Params Inference Time (s)
3 23.8M 21.3M 45.7M (2.40) 27.35 (3.73)
6 23.8M 42.5M 67.0M (1.64) 52.51 (1.94)
12 23.8M 85.1M 109M (1) 101.89 (1)
Table 4: The number of parameters and inference time for BERT_3, BERT_6 and BERT_12. Numbers in parentheses are ratios relative to BERT_12 (memory saving for Total Params, speedup for Inference Time). Parameters in Transformers (Trm) grow linearly with the number of layers. Note that the sum of # Params (Emb) and # Params (Trm) does not exactly equal Total Params, because there is another softmax layer with 0.6M parameters.

Results on RACE are reported in Table 3, which shows that vanilla KD outperforms direct fine-tuning by 4.42%, and our proposed patient teacher achieves a further 1.6% performance lift, which again demonstrates the effectiveness of Patient-KD.

4.4 Analysis of Model Efficiency

We have demonstrated that the proposed Patient-KD method can effectively compress BERT_12 into BERT_6 models without performance sacrifice. In this section, we further investigate the efficiency of Patient-KD on storage saving and inference-time speedup. Parameter statistics and inference time are summarized in Table 4. All the models share the same embedding layer with about 24 million parameters, which maps a 30k-word vocabulary to 768-dimensional vectors, so the savings come entirely from removing Transformer layers: BERT_6 and BERT_3 require 1.64 and 2.4 times less memory than BERT_12, respectively.
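The parameter counts in Table 4 can be reproduced with a back-of-the-envelope calculation. The sketch below assumes the standard BERT-Base configuration (hidden size 768, feed-forward size 4 × 768, a 30,522-token WordPiece vocabulary, 512 positions, 2 segment types, with all biases and LayerNorm parameters counted), and it assumes the extra 0.6M parameters mentioned in the Table 4 caption correspond to one 768 × 768 dense layer with bias; the function names are hypothetical.

```python
def embedding_params(d: int, vocab: int = 30522, max_pos: int = 512,
                     type_vocab: int = 2) -> int:
    """Token + position + segment embeddings plus one LayerNorm (gain and bias)."""
    return (vocab + max_pos + type_vocab) * d + 2 * d


def transformer_layer_params(d: int, ffn_mult: int = 4) -> int:
    """One Transformer encoder layer with biases: 12*d^2 + 13*d when ffn_mult = 4."""
    attention = 4 * (d * d + d)                       # Q, K, V and output projections
    ffn = 2 * (d * ffn_mult * d) + ffn_mult * d + d   # two dense layers with biases
    layer_norms = 2 * (2 * d)                         # two LayerNorms per layer
    return attention + ffn + layer_norms


def extra_dense_params(d: int) -> int:
    """Assumed d x d dense layer with bias on top of the encoder (~0.6M for d = 768)."""
    return d * d + d


def total_params(n_layers: int, d: int = 768) -> int:
    """Embeddings + n_layers Transformer layers + the assumed extra dense layer."""
    return (embedding_params(d) + n_layers * transformer_layer_params(d)
            + extra_dense_params(d))


if __name__ == "__main__":
    full = total_params(12)
    for n in (3, 6, 12):
        print(f"BERT_{n}: {total_params(n) / 1e6:.1f}M params, "
              f"ratio to BERT_12 = {full / total_params(n):.2f}")
```

Applying the same formulas with d = 1024 gives about 12.6M parameters per BERT-Large Transformer layer and about 108.4M parameters for a six-layer student built from BERT-Large layers, consistent with the figures reported in Section 4.5.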

To test inference speed, we ran experiments on the 105k samples of the QNLI training set Rajpurkar et al. (2016). Inference is performed on a single Titan RTX GPU with batch size 128, maximum sequence length 128, and FP16 enabled. The inference time for the embedding layer is negligible compared to the Transformer layers. Results in Table 4 show that the proposed Patient-KD approach achieves an almost linear speedup, 1.94 and 3.73 times for BERT_6 and BERT_3, respectively.

Setting Teacher Student SST-2 MRPC QQP MNLI-m MNLI-mm QNLI RTE
N/A N/A BERT_12 (Teacher) 94.3 89.2/85.2 70.9/89.0 83.7 82.8 90.4 69.1
N/A N/A BERT_24 (Teacher) 94.3 88.2/84.3 71.9/89.4 85.7 84.8 92.2 72.8
#1 BERT_12 BERT_6[Base]-KD 91.5 86.2/80.6 70.1/88.8 79.7 79.1 88.3 64.7
#2 BERT_24 BERT_6[Base]-KD 91.2 86.1/80.7 69.4/88.6 80.2 79.7 87.5 65.7
#3 BERT_24 BERT_6[Large]-KD 89.6 79.0/70.0 65.0/86.7 75.3 74.6 83.4 53.7
#4 BERT_24 BERT_6[Large]-PKD 89.8 77.8/68.3 67.1/87.9 77.2 76.7 83.8 53.2
Table 5: Performance comparison with different teacher and student models. BERT_6[Base]/[Large] denotes a BERT_6 model with a BERT-Base/Large Transformer in each layer. For PKD, we use the PKD-Skip architecture.

4.5 Does a Better Teacher Help?

To evaluate the effectiveness of the teacher model in our Patient-KD framework, we conduct additional experiments to measure the difference between BERT-Base teacher and BERT-Large teacher for model compression.

Each Transformer layer in BERT-Large has 12.6 million parameters, much more than a BERT-Base Transformer layer (about 7.1 million, i.e., 85.1M / 12 from Table 4). For a compressed BERT model with 6 layers, BERT_6 with BERT-Base Transformers (denoted as BERT_6[Base]) has only 67.0 million parameters, while BERT_6 with BERT-Large Transformers (denoted as BERT_6[Large]) has 108.4 million parameters. Since the size of the [CLS] token embedding differs between BERT-Large and BERT-Base (1,024 vs. 768 dimensions), we cannot directly compute the patient teacher loss in Eqn. (7) for BERT_6[Base] when BERT-Large is used as the teacher. Hence, in the case where the teacher is BERT-Large and the student is BERT_6[Base], we only conduct experiments in the vanilla KD setting.

Results are summarized in Table 5. When the teacher changes from BERT_12 to BERT_24 (i.e., Setting #1 vs. #2), there is not much difference between the students’ performance. Specifically, the BERT_12 teacher performs better on SST-2, QQP and QNLI, while BERT_24 performs better on MNLI-m, MNLI-mm and RTE. Presumably, distilling knowledge from a larger teacher requires a larger training dataset, thus better results are observed on MNLI-m and MNLI-mm.

We also report results on using BERT-Large as the teacher and BERT_6[Large] as the student. Interestingly, when comparing Setting #1 with #3, BERT_6[Large] performs much worse than BERT_6[Base], even though it is distilled from a better teacher (BERT_24 rather than BERT_12). The BERT_6[Large] student also has 1.6 times as many parameters as BERT_6[Base]. One intuition behind this is that the compression ratio for the BERT_6[Large] model is 4:1 (24:6), which is larger than the ratio used for the BERT_6[Base] model (2:1 (12:6)). The higher compression ratio makes it more challenging for the student model to absorb important weights.

When comparing Settings #2 and #3, we observe that even when the same large teacher is used, BERT_6[Large] still performs worse than BERT_6[Base]. Presumably, this may be due to initialization mismatch. Ideally, we should pre-train BERT_6[Large] and BERT_6[Base] from scratch, and use the weights learned from the pre-training step for weight initialization in KD training. However, due to the computational cost of training BERT from scratch, we only initialize the student model with the first six layers of BERT_12 or BERT_24. Therefore, the first six layers of BERT_24 may not be able to capture high-level features, leading to worse KD performance.

Finally, when comparing Setting #3 vs. #4, where Setting #4 uses Patient-KD-Skip instead of vanilla KD, we observe a performance gain on five of the seven tasks (all except MRPC and RTE), which suggests that Patient-KD is a generic approach independent of the choice of teacher model (BERT_12 or BERT_24).

5 Conclusion

In this paper, we propose a novel approach to compressing a large BERT model into a shallow one via Patient Knowledge Distillation. To fully utilize the rich information in deep structure of the teacher network, our Patient-KD approach encourages the student model to patiently learn from the teacher through a multi-layer distillation process. Extensive experiments over four NLP tasks demonstrate the effectiveness of our proposed model.

For future work, we plan to pre-train BERT from scratch to address the initialization mismatch issue, and potentially modify the proposed method such that it could also help during pre-training. Designing more sophisticated distance metrics for loss functions is another exploration direction. We will also investigate Patient-KD in more complex settings such as multi-task learning and meta learning.


  • P. Bojanowski, E. Grave, A. Joulin, and T. Mikolov (2017) Enriching word vectors with subword information. TACL. Cited by: §2.
  • S. Changpinyo, M. Sandler, and A. Zhmoginov (2017) The power of sparsity in convolutional neural networks. arXiv preprint arXiv:1702.06257. Cited by: §2.
  • T. Chen, I. J. Goodfellow, and J. Shlens (2015) Net2Net: accelerating learning via knowledge transfer. arXiv preprint arXiv:1511.05641.
  • Y. Cheng, F. X. Yu, R. S. Feris, S. Kumar, A. Choudhary, and S. Chang (2015) An exploration of parameter redundancy in deep networks with circulant projections. In ICCV.
  • A. Conneau, D. Kiela, H. Schwenk, L. Barrault, and A. Bordes (2017) Supervised learning of universal sentence representations from natural language inference data. In EMNLP.
  • A. M. Dai and Q. V. Le (2015) Semi-supervised sequence learning. In NIPS.
  • M. Denil, B. Shakibi, L. Dinh, M. Ranzato, and N. de Freitas (2013) Predicting parameters in deep learning. In NIPS.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • W. B. Dolan and C. Brockett (2005) Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing.
  • Y. Gong, L. Liu, M. Yang, and L. D. Bourdev (2014) Compressing deep convolutional networks using vector quantization. arXiv preprint arXiv:1412.6115.
  • S. Han, H. Mao, and W. J. Dally (2016) Deep compression: compressing deep neural networks with pruning, trained quantization and Huffman coding. In ICLR.
  • S. Han, J. Pool, J. Tran, and W. J. Dally (2015) Learning both weights and connections for efficient neural networks. In NIPS.
  • Y. He, X. Zhang, and J. Sun (2017) Channel pruning for accelerating very deep neural networks. In ICCV.
  • G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
  • J. Howard and S. Ruder (2018) Universal language model fine-tuning for text classification. In ACL.
  • G. Huang, Y. Sun, Z. Liu, D. Sedra, and K. Q. Weinberger (2016) Deep networks with stochastic depth. In ECCV.
  • R. Kiros, Y. Zhu, R. R. Salakhutdinov, R. Zemel, R. Urtasun, A. Torralba, and S. Fidler (2015) Skip-thought vectors. In NIPS.
  • G. Lai, Q. Xie, H. Liu, Y. Yang, and E. Hovy (2017) RACE: large-scale reading comprehension dataset from examinations. arXiv preprint arXiv:1704.04683.
  • X. Liu, P. He, W. Chen, and J. Gao (2019a) Improving multi-task deep neural networks via knowledge distillation for natural language understanding. arXiv preprint arXiv:1904.09482.
  • Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov (2019b) RoBERTa: a robustly optimized BERT pretraining approach. arXiv preprint arXiv:1907.11692.
  • L. Logeswaran and H. Lee (2018) An efficient framework for learning sentence representations. In ICLR.
  • L. Lu, M. Guo, and S. Renals (2017) Knowledge distillation for small-footprint highway networks. In ICASSP.
  • B. McCann, J. Bradbury, C. Xiong, and R. Socher (2017) Learned in translation: contextualized word vectors. In NIPS.
  • T. Mikolov, I. Sutskever, K. Chen, G. S. Corrado, and J. Dean (2013) Distributed representations of words and phrases and their compositionality. In NIPS.
  • J. Pennington, R. Socher, and C. Manning (2014) GloVe: global vectors for word representation. In EMNLP.
  • M. E. Peters, M. Neumann, M. Iyyer, M. Gardner, C. Clark, K. Lee, and L. Zettlemoyer (2018) Deep contextualized word representations. In NAACL.
  • A. Polino, R. Pascanu, and D. Alistarh (2018) Model compression via distillation and quantization. arXiv preprint arXiv:1802.05668.
  • A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever (2018) Improving language understanding by generative pre-training. arXiv.
  • P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100,000+ questions for machine comprehension of text. In EMNLP.
  • A. Romero, N. Ballas, S. E. Kahou, A. Chassang, C. Gatta, and Y. Bengio (2015) FitNets: hints for thin deep nets. In ICLR.
  • B. B. Sau and V. N. Balasubramanian (2016) Deep model compression: distilling knowledge from noisy teachers. arXiv preprint arXiv:1610.09650.
  • R. Socher, A. Perelygin, J. Wu, J. Chuang, C. D. Manning, A. Ng, and C. Potts (2013) Recursive deep models for semantic compositionality over a sentiment treebank. In EMNLP.
  • X. Tan, Y. Ren, D. He, T. Qin, Z. Zhao, and T. Liu (2019) Multilingual neural machine translation with knowledge distillation. In ICLR.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS.
  • D. P. Vetrov, J. Huang, L. Zhang, M. Collins, M. Figurnov, R. Salakhutdinov, and Y. Zhu (2017) Spatially adaptive computation time for residual networks. In CVPR.
  • A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. R. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.
  • A. Williams, N. Nangia, and S. R. Bowman (2017) A broad-coverage challenge corpus for sentence understanding through inference. arXiv preprint arXiv:1704.05426.
  • Z. Yang, Z. Dai, Y. Yang, J. Carbonell, R. Salakhutdinov, and Q. V. Le (2019) XLNet: generalized autoregressive pretraining for language understanding. arXiv preprint arXiv:1906.08237.
  • S. Zhai, Y. Cheng, W. Lu, and Z. M. Zhang (2016) Doubly convolutional neural networks. In NIPS.
  • C. Zhu, M. Zeng, and X. Huang (2018) SDNet: contextualized attention-based deep network for conversational question answering. arXiv preprint arXiv:1812.03593.