Knowledge Distillation in Document Retrieval

11/11/2019 ∙ by Siamak Shakeri, et al. ∙ Amazon ∙ University of Illinois at Urbana-Champaign

Complex deep learning models now achieve state-of-the-art performance on many document retrieval tasks. The best models process the query or claim jointly with the document. However, for fast, scalable search it is desirable to have document embeddings that are independent of the claim. In this paper we show that knowledge distillation can be used to encourage a model that generates claim-independent document encodings to mimic the behavior of a more complex model that generates claim-dependent encodings. We explore this approach in document retrieval for a fact extraction and verification task. We show that by using the soft labels from a complex cross-attention teacher model, the performance of claim-independent student LSTM or CNN models improves across all the ranking metrics. The student models we use are 12x faster at runtime and have 20x fewer parameters than the teacher.


1 Introduction

Deep learning models have shown promising results in the field of document retrieval. In particular, attention-based models (Vaswani et al. (2017a); Devlin et al. (2018); Xiong et al. (2016)) demonstrate clear improvements in the performance of neural models on question answering tasks. In such models, rich encodings of the claim (question) and document (answer) are generated using various attention mechanisms. A challenge when using such models in large scale document retrieval systems is the lack of separation between document and claim encodings, which makes it infeasible to pre-index the document encodings and retrieve them efficiently at runtime. In this paper we explore the use of knowledge distillation as a means of transferring the embedded attention information to a simpler, attention-free neural model.

Knowledge distillation, using the posterior probabilities of one model to improve the performance of another, has been widely studied (Bucila et al. (2006)). Hinton et al. (2015) discuss using the aggregate posteriors of an ensemble of deep acoustic models to improve the performance of a single model. Kim and Rush (2016) suggest word-level knowledge distillation for neural machine translation. Zagoruyko and Komodakis (2016) define an attention mechanism in convolutional neural networks and use knowledge distillation to improve a student model by forcing it to mimic that mechanism. Romero et al. (2014) explore training a student model that is deeper and thinner than the teacher while utilizing both the softmax posteriors and the intermediate layer representations of the teacher. Mou et al. (2015) experiment with distilling knowledge from a large embedding into a smaller one. Hu et al. (2018) use an ensemble of models as the teacher, similar to Hinton et al. (2015), to guide the alignments of the student model in machine reading comprehension.

We conduct knowledge distillation experiments on document retrieval for the fact extraction and verification task introduced in Thorne et al. (2018). In order to make our approach generic, no restrictions are imposed on the type of attention that the teacher model can employ. Furthermore, the student model does not need to be of the same type as the teacher model; e.g., the teacher can be a CNN based model while the student is an LSTM. The student models experimented with in this paper are both faster (up to 12x) and smaller (up to 20x) than the teacher model. We start with the problem definition and task setup in Section 2. Next we describe model training with knowledge distillation in Section 3 and present experimental results in Section 4.

2 Problem Description

In the knowledge distillation, or teacher-student, training framework, the student model is the target model, trained using the annotation labels together with information such as posteriors or hidden unit activations from a more complex teacher model. In this paper we consider single layer CNN and LSTM models with a linear layer on top as our student models. Specifically, we avoid models that require interactions between claim and document to create the encodings of either. The teacher model is more complex than the student model, both in the number of parameters and in the structure of the network. As shown in Figure 1, the teacher model uses claim dependent document encodings.

Figure 1: Teacher vs Student Model

The document retrieval task can be cast as a classification task: given a (claim, document) pair, assign a score indicating the relevancy of the document to the claim. For each claim, the documents are sorted by the assigned score and the top ones are picked. We discuss the metrics used in later sections.

2.1 Data

The publicly available FEVER dataset (Thorne et al. (2018)) is used in this paper. In FEVER, a corpus of Wikipedia documents is given, and the task is to classify a given claim as supported, refuted or not enough info using that corpus. Three subtasks are defined: document retrieval, sentence retrieval and textual entailment. In this paper we focus on the document retrieval task. The corpus consists of 5.4 million pages and more than 175,000 claims. Each sample in FEVER consists of a claim, all the relevant documents, all the relevant sentences in those documents, and the annotated label.

For training and evaluating our model, we construct (claim, document, label) tuples. For each claim, all the annotated relevant documents are labeled as positive samples. DrQA (Chen et al. (2017)) is employed to find the nearest documents to a claim from the entire corpus based on cosine similarity of TF-IDF vectors. The top results returned by DrQA that are not annotated as relevant are labeled as negative samples. The rationale is to use the irrelevant documents that are most similar to the claim as negative samples, which makes the resulting dataset non-trivial. Each claim has a fixed number of documents C. The claims are split into train, dev and test sets of 145,000, 20,000 and 10,000 claims respectively.

Table 1 shows, for a given C, what percentage of claims have all of their annotated relevant documents included. We use C=10, as it covers the vast majority of claims.

C           1      2      5      10     15
% claims    87.37  96.90  99.27  99.90  99.99
Table 1: Percentage of claims with all their relevant documents included, versus C.
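The paper builds these tuples with the DrQA TF-IDF retriever; the sketch below only illustrates the same negative-sampling idea, with scikit-learn's TfidfVectorizer standing in for DrQA. The function name and arguments are hypothetical.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def build_claim_tuples(claim, positive_ids, corpus_ids, corpus_texts, C=10):
    """Hypothetical helper: pair a claim with its annotated positives and pad
    up to C documents with the most TF-IDF-similar non-relevant ones."""
    vectorizer = TfidfVectorizer(stop_words="english")
    doc_matrix = vectorizer.fit_transform(corpus_texts)        # (|corpus|, vocab)
    claim_vec = vectorizer.transform([claim])                  # (1, vocab)
    scores = cosine_similarity(claim_vec, doc_matrix).ravel()  # cosine per document

    tuples = [(claim, doc_id, 1) for doc_id in positive_ids]   # positive samples
    for i in sorted(range(len(corpus_ids)), key=lambda i: -scores[i]):
        if len(tuples) >= C:
            break
        if corpus_ids[i] not in positive_ids:
            tuples.append((claim, corpus_ids[i], 0))           # hard negative
    return tuples
```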

2.2 Metrics

Being a ranking task, Discounted Cumulative Gain (DCG) and recall at top k are the performance metrics. In order to aggregate per-claim recall values, we define the following:

  • $\mathrm{recall}_{macro}(k) = \frac{1}{N}\sum_{i=1}^{N}\frac{\sum_{j=1}^{k} rel_{i,j}}{R_i}$

  • $\mathrm{recall}_{micro}(k) = \frac{\sum_{i=1}^{N}\sum_{j=1}^{k} rel_{i,j}}{\sum_{i=1}^{N} R_i}$

where $R_i$ indicates the number of relevant documents for claim $i$, $rel_{i,j}$ is $1$ if the document at the $j$-th position in the sorted document list for claim $i$ is relevant and $0$ otherwise, and $N$ indicates the total number of claims.
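As a concrete reading of these definitions, the following sketch computes both aggregates from per-claim relevance lists sorted by predicted score. It is an illustration of the formulas above, not the authors' evaluation code.

```python
import numpy as np

def recall_at_k(relevance_lists, k):
    """relevance_lists[i][j] is 1 if the j-th ranked document for claim i is
    relevant, 0 otherwise (documents sorted by predicted score, descending)."""
    hits, rel_counts, per_claim = [], [], []
    for rel in relevance_lists:
        rel = np.asarray(rel)
        r_i = rel.sum()                 # R_i: relevant documents for claim i
        h_i = rel[:k].sum()             # relevant documents retrieved in top k
        hits.append(h_i)
        rel_counts.append(r_i)
        per_claim.append(h_i / r_i if r_i > 0 else 0.0)
    recall_macro = float(np.mean(per_claim))                  # mean of per-claim recalls
    recall_micro = float(np.sum(hits) / np.sum(rel_counts))   # pooled counts
    return recall_macro, recall_micro

# Example: two claims, C=5 candidate documents each.
print(recall_at_k([[1, 0, 1, 0, 0], [0, 1, 0, 0, 0]], k=3))
```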

3 Training

In this section we discuss the training setup for our knowledge distillation experiments.

3.1 Teacher Model

As our teacher models, we experimented with architectures that have been shown to give state-of-the-art performance on the SQuAD task (Rajpurkar et al. (2016)). The two models that performed best were DCN (Xiong et al. (2016)) without the highway maxout network layers and the Transformer (Vaswani et al. (2017b)). Note that the purpose of our study is not to find the best teacher model, but a teacher model that significantly outperforms the baseline student model. Table 2 shows the performance of these two models. We picked the DCN model as the teacher for further experiments. Note that both models were modified for our classification task.

Model R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
DCN 63.79 82.21 84.83 92.12 93.61 91.56
Transformer 43.29 72.25 75.49 89.47 91.25 89.72
Table 2: Teacher Models Metrics
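The DCN and Transformer teachers are not reproduced here; the following PyTorch sketch is only a generic cross-attention scorer with illustrative layer sizes, meant to show why a teacher of this kind produces claim-dependent document encodings that cannot be pre-indexed.

```python
import torch
import torch.nn as nn

class CrossAttentionScorer(nn.Module):
    """Illustrative claim-dependent teacher (not the paper's DCN/Transformer):
    claim tokens attend over document tokens, so the pooled document
    representation changes with the claim and cannot be pre-indexed."""
    def __init__(self, vocab_size, dim=128, heads=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.score = nn.Linear(dim, 1)

    def forward(self, claim_ids, doc_ids):
        claim = self.embed(claim_ids)                        # (B, Lc, dim)
        doc = self.embed(doc_ids)                            # (B, Ld, dim)
        attended, _ = self.cross_attn(claim, doc, doc)       # claim-conditioned doc mix
        return self.score(attended.mean(dim=1)).squeeze(-1)  # relevance logit per pair
```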

3.2 Student Model

The following candidate models were employed as students:

SimpleCNN: CNN and max-pooling layers are applied separately to the claim and the document to create their encodings. A linear layer is used to join the encodings.

SimpleLSTM: Recurrent layers are applied separately to the claim and the document to create their encodings. As in SimpleCNN, a linear layer is used to join the encodings.

The final claim and document encodings are independent of each other, as discussed in Section 2.
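The following PyTorch sketch approximates the SimpleLSTM student. The layer sizes and pooling are assumptions, but it illustrates the key property that encode_doc does not depend on the claim and can therefore be pre-computed and indexed.

```python
import torch
import torch.nn as nn

class SimpleLSTMStudent(nn.Module):
    """Sketch of the claim-independent student: claim and document are encoded
    separately, so document encodings can be pre-computed; only the final
    linear join sees both. Layer sizes are illustrative guesses."""
    def __init__(self, vocab_size, emb_dim=100, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.claim_lstm = nn.LSTM(emb_dim, hidden, batch_first=True)
        self.doc_lstm = nn.LSTM(emb_dim, hidden, batch_first=True)
        self.join = nn.Linear(2 * hidden, 1)      # linear layer joining the encodings

    def encode_doc(self, doc_ids):
        _, (h, _) = self.doc_lstm(self.embed(doc_ids))
        return h[-1]                              # (B, hidden), claim-independent

    def encode_claim(self, claim_ids):
        _, (h, _) = self.claim_lstm(self.embed(claim_ids))
        return h[-1]

    def forward(self, claim_ids, doc_ids):
        c, d = self.encode_claim(claim_ids), self.encode_doc(doc_ids)
        return self.join(torch.cat([c, d], dim=-1)).squeeze(-1)   # relevance logit
```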

3.3 Objective Function

In order to train the student model, the trained teacher model is run over the entire training, dev and test sets, and the similarity score of each (claim, document) pair is recorded. When training the student network, these similarity scores are used alongside the annotated labels. We define the following losses:

$\mathcal{L}_{soft} = \ell\left(\frac{z^{T}_{c,d}}{T}, \frac{z^{S}_{c,d}}{T}\right)$   (1)

$\mathcal{L}_{hard} = \mathrm{CE}\left(y_{c,d}, \sigma(z^{S}_{c,d})\right)$   (2)

$\mathcal{L} = \lambda\,\mathcal{L}_{soft} + (1-\lambda)\,\mathcal{L}_{hard}$   (3)

where $z^{T}_{c,d}$, $z^{S}_{c,d}$ and $y_{c,d}$ denote the teacher logits, the student logits and the true label of document $d$ and claim $c$, respectively, $\ell$ is the soft loss (either MSE or cross entropy, CE), and $\sigma$ is the sigmoid. $\lambda$ is a hyperparameter that dictates the importance of $\mathcal{L}_{soft}$; setting it to 0 corresponds to training with no teacher. $T$ (temperature) is another hyperparameter that controls how much the classification scores are smoothed; setting it close to 0 is equivalent to picking the largest value only.
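A minimal PyTorch sketch of this objective is given below, assuming a single relevance logit per (claim, document) pair. The exact temperature scaling and the binary form of the cross entropy soft loss are assumptions rather than the authors' implementation.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      lam=0.2, T=3.0, soft_loss="mse"):
    """Combined objective sketched from Eqs. (1)-(3): a soft loss on
    temperature-scaled logits plus the hard loss on annotated labels,
    mixed by lambda."""
    if soft_loss == "mse":
        l_soft = F.mse_loss(student_logits / T, teacher_logits / T)
    else:  # "ce": binary cross entropy against softened teacher probabilities
        soft_targets = torch.sigmoid(teacher_logits / T)
        l_soft = F.binary_cross_entropy_with_logits(student_logits / T, soft_targets)
    l_hard = F.binary_cross_entropy_with_logits(student_logits, labels.float())
    return lam * l_soft + (1.0 - lam) * l_hard

# Example usage with a batch of 8 (claim, document) pairs:
s, t = torch.randn(8), torch.randn(8)
y = torch.randint(0, 2, (8,))
loss = distillation_loss(s, t, y, lam=0.2, T=3.0, soft_loss="mse")
```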

4 Experimental Results

4.1 Full Training Data

SoftLoss/λ/T R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
No Teacher 42.27 72.49 69.43 87.46 85.46 79.72
MSE/1.0/4 44.88 74.79 71.66 89.31 87.71 81.38
MSE/0.2/3 46.16 76.77 73.69 90.23 88.42 82.13
MSE/0.65/4 44.15 74.28 71.26 89.35 87.74 80.86
CE/0.2/3 43.57 75.89 72.89 89.42 87.56 80.93
CE/0.5/6 46.91 74.36 71.22 88.57 86.66 81.87
CE/0.4/5 40.71 72.75 69.49 88.66 86.53 78.87
Table 3: Teacher Student Training with SimpleLSTM. SoftLoss/λ/T indicates the soft loss type, the soft loss weight λ and the temperature T, respectively.
SoftLoss/λ/T R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
No Teacher 44.13 70.17 65.98 84.28 81.2 83.95
MSE/0.8/6 42.95 70.3 66.4 84.84 82.39 83.52
MSE/1.0/1 41.07 68.77 64.4 84.63 81.74 82.26
MSE/0.8/2 40.31 67.89 63.5 84.62 81.55 81.81
CE/0.5/3 43.59 71.16 67.13 85.11 82.28 83.94
CE/0.65/1 41.96 67.18 63.07 83.29 80.2 82.4
CE/0.3/2 41.0 66.67 62.51 82.14 79.01 81.65
Table 4: Teacher Student Training with SimpleCNN
Model R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
Best Diff A 3.89 4.28 4.26 2.77 2.96 2.41
MSE Avg A 45.06 75.28 72.20 89.63 87.96 81.46
CE Avg A 43.73 74.33 71.2 88.88 86.92 81.22
Best Diff B -0.54 0.99 1.15 0.83 1.08 -0.01
MSE Avg B 41.44 68.99 64.77 84.7 81.89 82.53
CE Avg B 42.18 68.34 64.24 83.51 80.50 82.66
Table 5: Teacher Student Training Averages. Best Diff indicates the difference between the best student model when teacher training is used vs when training with only hard labels. MSE and CE Avgs indicate the average across the top three performing configurations when using MSE and CE soft losses, respectively. A refers to SimpleLSTM and B refers to SimpleCNN.

SoftLoss/λ/T R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
No Teacher 39.69 69.95 66.56 85.57 83.21 77.6
CE/1.0/3 40.83 69.35 65.74 85.42 83.02 77.65
MSE/0.65/5 37.82 68.45 64.85 85.21 82.54 76.39
Table 6: Teacher Student Training with SimpleLSTM, first partial-data subset.
SoftLoss/λ/T R(1) Rmicro(3) Rmacro(3) Rmicro(5) Rmacro(5) DCG
No Teacher 33.11 62.89 59.28 81.74 78.97 73.03
CE/1.0/2 32.66 64.57 61.16 83.7 80.96 73.23
MSE/0.3/6 33.29 64.64 61.28 82.08 79.36 73.25
Table 7: Teacher Student Training with SimpleLSTM, second partial-data subset.

We first experiment with training the teacher model (Section 3.1) on the entire training data, and then use its posteriors when training the student models (Section 3.2). Tables 3 and 4 show the top performing results, and Table 5 summarizes them. For each loss type and model, the top three performing configurations are picked. Some observations are as follows:

  • Using teacher student training improves the performance of student models.

  • Improvements resulting from knowledge distillation are larger with SimpleLSTM. This indicates that the LSTM module is better able to benefit from the information embedded in the soft labels provided by the teacher model, consistent with its strength in encoding sequential inputs (Tan et al. (2015); Bahdanau et al. (2014)).

  • Best performance is achieved with temperatures greater than 1, which shows that smoothing the logits is crucial; Hinton et al. (2015) also report improvements from such smoothing. In fact, none of the best performing configurations use a temperature of 1.

  • The best performance is achieved when using a mix of teacher and hard labels: intermediate values of λ generate the largest improvements. Using only the soft labels from the teacher model discards the more credible annotated labels, while using only hard labels discards the extra information provided by the soft labels. This indicates that soft and hard labels provide complementary information.

  • MSE vs CE: the results do not show a consistent pattern of distillation favoring one over the other. For SimpleLSTM, MSE performs better, while for SimpleCNN, CE is the better choice.

4.2 Partial Training Data

It has been claimed (Hinton et al. (2015)) that teacher-student training can act as a regularizer. We test this claim with an experiment in which only a small portion of the dataset is used for training: if distillation acts as a regularizer, we expect less overfitting when knowledge distillation is employed than when no teacher training is involved.

In this section, we discuss experiments in which only partial training data is used to train the student model. Note that the teacher model is still trained on the entire training dataset. We experimented with the SimpleLSTM model, training it on two small subsets of the training set.

Tables 6 and 7 show the results. Small improvements are observed with one of the subsets and none with the other. These results do not support the hypothesis that teacher training acts as a regularizer. They instead suggest that with more training data, the benefits of the soft labels become more evident.

4.3 Running time

The experiments were done on AWS EC2 instances with Tesla V100 GPUs. PyTorch (Paszke et al. (2017)) was employed to implement the neural models. The student model is up to 12x faster and has up to 20x fewer parameters than the teacher. This is in addition to the reduction in computational cost from reusing indexed document encodings, as discussed in Section 1. In particular, if there are D documents and N claims, and each document must be evaluated against each claim, the encoding cost of the student model is O(N+D) while the teacher's is O(ND). Note that the cost is measured in units of computing the encoding of one claim or document. Table 8 shows detailed running time metrics.

Model # Parameters Loading Time Evaluation Time
Teacher 2.85M 5.35s 760.7s
SimpleLSTM 141k 5.1s 66.5s
Table 8: Run time Performance with Batch Size of 8
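To make the O(N+D) versus O(ND) argument concrete, the sketch below ranks all documents for all claims while encoding each document only once. It assumes the hypothetical SimpleLSTMStudent sketch from Section 3.2 and is not the paper's benchmark code.

```python
import torch

def rank_with_student(student, claim_batches, doc_batches):
    """Encode every document once (O(D)) and every claim once (O(N)); only the
    cheap linear join is evaluated per (claim, document) pair."""
    with torch.no_grad():
        doc_vecs = torch.cat([student.encode_doc(b) for b in doc_batches])        # (D, h), pre-indexable
        claim_vecs = torch.cat([student.encode_claim(b) for b in claim_batches])  # (N, h)
        scores = []
        for c in claim_vecs:
            pairs = torch.cat([c.unsqueeze(0).expand(doc_vecs.size(0), -1),
                               doc_vecs], dim=-1)           # (D, 2h)
            scores.append(student.join(pairs).squeeze(-1))   # (D,) relevance logits
        return torch.stack(scores)                           # (N, D) score matrix
```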

5 Conclusion

In this paper, we proposed using knowledge distillation to improve the performance of student models that generate claim independent document encodings in the document retrieval task for factual verification. We experimented with various configurations for adding the teacher model posteriors to student training, and the results show that significant improvements can be achieved across the ranking metrics without sacrificing the running time advantages of the simpler models. In future work, we propose applying this approach to a larger set of input documents (C), replacing the DrQA retriever with the student model.

References