
Forget Me Not: Reducing Catastrophic Forgetting for Domain Adaptation in Reading Comprehension

11/01/2019
by   Y. Xu, et al.

The creation of large-scale open domain reading comprehension data sets in recent years has enabled the development of end-to-end neural comprehension models with promising results. To use these models for domains with limited training data, one of the most effective approaches is to first pretrain them on large out-of-domain source data and then fine-tune them with the limited target data. The caveat of this is that after fine-tuning, the comprehension models tend to perform poorly in the source domain, a phenomenon known as catastrophic forgetting. In this paper, we explore methods that overcome catastrophic forgetting during fine-tuning without assuming access to data from the source domain. We introduce new auxiliary penalty terms and observe the best performance when a combination of auxiliary penalty terms is used to regularise the fine-tuning process for adapting comprehension models. To test our methods, we develop and release 6 narrow domain data sets that could potentially be used as reading comprehension benchmarks.


Introduction

Reading comprehension (RC) is the task of answering a question given a context passage. Related to Question-Answering (QA), RC is seen as a module in the full QA pipeline, where it assumes a related context passage has been extracted and the goal is to produce an answer based on the context. In recent years, the creation of large-scale open domain comprehension data sets [27, 15, 18, 5, 17, 8] has spurred the development of a host of end-to-end neural comprehension systems with promising results.

In spite of these successes, it is difficult to train these modern comprehension systems on narrow domain data (e.g. biomedical), as these models often have a large number of parameters. A better approach is to transfer knowledge via fine-tuning, i.e. by first pre-training the model using data from a large source domain and then continuing to train it with examples from the small target domain. This is an effective strategy, although a fine-tuned model often performs poorly when it is re-applied to the source domain, a phenomenon known as catastrophic forgetting [4, 26, 7, 19]. This is generally not an issue if the goal is to optimise purely for the target domain, but in real-world applications where model robustness is an important quality, over-optimising for a development set often leads to unexpectedly poor performance on test cases in the wild.

In this paper, we explore strategies to reduce forgetting for comprehension systems during domain adaptation. Our goal is to preserve the source domain's performance as much as possible, while keeping the target domain's performance optimal and assuming no access to the source data. We experiment with a number of auxiliary penalty terms to regularise the fine-tuning process for three modern RC models: QANet [28], decaNLP [13] and BERT [3]. We observe that combining different auxiliary penalty terms results in the best performance, outperforming benchmark methods that require source data.

Technically speaking, the methods we propose are not limited to domain transfer for reading comprehension; we also show that the methodology can be used for transferring to entirely different tasks. With that said, we focus on comprehension here because it is a practical problem in real-world applications, where the target domain often has a small number of QA pairs and over-fitting occurs easily when we fine-tune based on a small development set. In this scenario, developing a robust model is as important as achieving optimal development performance.

To demonstrate the applicability of our approach, we apply topic modelling to msmarco [15], a comprehension data set based on internet search queries, and collect examples that belong to a number of salient topics, producing 6 small to medium sized RC data sets for the following domains: biomedical, computing, film, finance, law and music. We focus on extractive RC, where the answer is a contiguous sub-span of the context passage. (Although RC with free-form answers is arguably a more challenging and interesting task, evaluation is generally more difficult [9].) Scripts to generate the data sets are available at: https://github.com/ibm-aur-nlp/domain-specific-QA.

Related Work

Most large comprehension data sets are open-domain because non-experts can be readily recruited via crowdsourcing platforms to collect annotations. Development of domain-specific RC data sets, on the other hand, is costly due to the need for subject matter experts, and as such these data sets are typically limited in size. Examples include bioasq [23] in the biomedical domain, which has fewer than 3K QA pairs, orders of magnitude smaller than most large-scale open-domain data sets [15, 18, 5, 8].

[26] explore supervised domain adaptation for reading comprehension by pre-training their model first on large open-domain comprehension data and fine-tuning it further on biomedical data. This approach improves performance in the biomedical domain substantially compared to training the model from scratch. At the same time, performance on the source domain decreases dramatically due to catastrophic forgetting [4, 14, 20].

This issue of catastrophic forgetting is less of a problem when data from multiple domains or tasks are present during training. For example, in [13] the decaNLP model is trained on 10 tasks simultaneously, all cast as QA problems, and forgetting is minimal. For multi-domain adaptation, [2] and [6] propose using a K+1 model to capture domain-general patterns shared by the K domains, resulting in a more robust model. Using multi-task learning to tackle catastrophic forgetting is effective and generates robust models. The drawback, however, is that when training for each new domain/task, data from the previous domains/tasks has to be available.

Several studies present methods to reduce forgetting with limited or no access to previous data [21, 10, 7, 22, 19]. Inspired by synaptic consolidation, [7] propose to selectively penalise parameter change during fine-tuning: significant updates to parameters deemed important to the source task incur a large penalty. [10] introduce gradient episodic memory (gem) to allow beneficial transfer of knowledge from previous tasks. More specifically, a subset of data from previous tasks is stored in an episodic memory, against which reference gradient vectors are calculated; the angles between these and the gradient vectors for the current task are constrained to lie between $-90°$ and $90°$, i.e. their dot product must be non-negative. [19] suggest combining gem with optimisation-based meta-learning to overcome forgetting. Among these three methods, only that of [7] assumes zero access to previous data. In comparison, the latter two rely on access to a memory storing data from previous tasks, which is not always feasible in real-world applications (e.g. due to data privacy concerns).

Partition  Domain   #Examples  #Unique Q  Mean C Length  Mean Q Length  Mean A Length
Train      ms-bm      22,134     21,902       70.9           6.4            13.7
           ms-cp       3,021      3,011       67.2           5.5            18.9
           ms-fm       3,522      3,481       65.8           6.4             6.5
           ms-fn       6,790      6,720       71.9           6.4            14.0
           ms-lw       3,105      3,078       64.7           6.2            18.5
           ms-ms       2,517      2,480       68.6           6.4             6.6
           bioasq      3,083        387       35.4          11.0             2.4
Dev        ms-bm       4,743      4,730       71.2           6.4            13.7
           ms-cp         647        646       65.4           5.3            19.6
           ms-fm         755        751       65.9           6.6             5.9
           ms-fn       1,455      1,453       71.6           6.5            14.4
           ms-lw         665        664       65.8           6.2            20.0
           ms-ms         539        536       69.2           6.4             6.1
           bioasq        674         83       39.7          11.1             2.4
Test       ms-bm       4,743      4,728       70.5           6.4            13.5
           ms-cp         648        645       66.6           5.6            18.3
           ms-fm         755        755       66.7           6.3             6.2
           ms-fn       1,455      1,452       70.8           6.5            13.6
           ms-lw         666        663       65.1           6.2            18.9
           ms-ms         540        540       67.4           6.6             7.0
           bioasq        631         84       34.9          13.2             2.9

Table 1: Statistics of our seven target domain data sets (Q: Question; C: Context; and A: Answer).

Data Set

We use squad v1.1 [18] as the source domain data for pre-training the comprehension model. It contains over 100K extractive (context, question, answer) triples with only answerable questions.

To create the target domain data, we leverage msmarco [15], a large RC data set where questions are sampled from Bing search queries and answers are manually generated by users based on passages in web documents. We apply an LDA topic model [1] to passages in msmarco and learn 100 topics. (When collecting the passages, we include only those selected as useful for answering a query, i.e. is_selected; we tokenise the passages with Stanford CoreNLP [11] and use MALLET [12] for topic modelling.) Given the topics, we label them and select 6 salient domains: biomedical (ms-bm), computing (ms-cp), film (ms-fm), finance (ms-fn), law (ms-lw) and music (ms-ms). A QA pair is categorised into one of these domains if its passage's top topic belongs to them. We create multiple (context, question, answer) training examples if a QA pair has multiple contexts, again considering only context passages marked as useful by annotators, and filter them to keep only extractive examples, where a triple is defined to be extractive if the answer has a case-insensitive match to the context. A sketch of this filtering step follows.
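As an illustration of the extractive filter just described, here is a minimal sketch. The field names (passages, is_selected, passage_text, query, answers) follow the msmarco JSON schema, but the helper itself is our own illustration, not the authors' released script.

```python
def make_extractive_examples(record):
    """Expand one msmarco QA record into (context, question, answer)
    triples, keeping only extractive ones (sketch)."""
    examples = []
    for passage in record["passages"]:
        if not passage.get("is_selected"):
            continue  # keep only passages marked useful by annotators
        context = passage["passage_text"]
        for answer in record["answers"]:
            # A triple is extractive if the answer matches a sub-span
            # of the context, case-insensitively.
            start = context.lower().find(answer.lower())
            if start >= 0:
                examples.append({
                    "context": context,
                    "question": record["query"],
                    "answer": context[start:start + len(answer)],
                })
    return examples
```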

In addition to the msmarco data sets, we also experiment with a real biomedical comprehension data set: bioasq [24]. Each question in bioasq is associated with a set of snippets as context, and the snippets are single sentences extracted from a scientific publication’s abstract/title in PubMed Central. There are four types of questions: factoid, list, yes/no, and summary. As our focus is on extractive RC, we use only the extractive factoid questions from bioasq. As before, we create multiple training examples for QA pairs with multiple contexts.

For each target domain, we split the examples into 70%/15%/15% training/development/test partitions. (Partitioning is done at the question level to ensure the same question does not appear in more than one partition.) We present some statistics for the data sets in Table 1.
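A minimal sketch of this question-level partitioning; the 70/15/15 proportions come from the paper, while the function name and seed are illustrative assumptions:

```python
import random

def split_by_question(examples, seed=0):
    """70/15/15 split at the question level, so the same question
    never appears in more than one partition (sketch)."""
    questions = sorted({ex["question"] for ex in examples})
    random.Random(seed).shuffle(questions)
    n = len(questions)
    train_q = set(questions[: int(0.7 * n)])
    dev_q = set(questions[int(0.7 * n): int(0.85 * n)])
    splits = {"train": [], "dev": [], "test": []}
    for ex in examples:
        if ex["question"] in train_q:
            splits["train"].append(ex)
        elif ex["question"] in dev_q:
            splits["dev"].append(ex)
        else:
            splits["test"].append(ex)
    return splits
```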

Methodology

We first pre-train a general domain RC model on squad, our source domain. Given the pre-trained model, we then perform fine-tuning (finetune) on the msmarco and bioasq data sets: 7 target domains in total. By fine-tuning we mean taking the pre-trained model parameters as initial parameters and updating them based on data from the new domain. To reduce forgetting on the source domain (squad), we experiment with adding auxiliary penalty terms (e.g. the L2 distance between new and old parameters) to the standard cross-entropy loss to regularise the fine-tuning process.

We explore 3 modern RC models in our experiments: QANet [28], decaNLP [13] and BERT [3]. QANet is a Transformer-based [25] comprehension model whose encoder consists of stacked convolution and self-attention layers. The objective of the model is to predict the positions of the starting and ending indices of the answer words in the context. decaNLP is a recurrent network-based comprehension model trained on ten NLP tasks simultaneously, all cast as question-answer problems. Much of decaNLP's flexibility is due to its pointer-generator network, which allows it to generate words by extracting them from the question or context passages, or by drawing them from a vocabulary. BERT is a deep bi-directional encoder model based on Transformers. It is pre-trained on a large corpus in an unsupervised fashion using masked language model and next-sentence prediction objectives. To apply BERT to a specific task, the standard practice is to add additional output layers on top of the pre-trained BERT and fine-tune the whole model for the task. In our case, for RC, 2 output layers are added: one for predicting the start index and another for the end index. [3] demonstrate that this transfer learning strategy produces state-of-the-art performance on a range of NLP tasks. For RC specifically, BERT (BERT-Large) achieved an F1 score of 93.2 on squad, outperforming human performance by 2 points.
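To make the span-prediction setup concrete, below is a minimal sketch of the two output layers added on top of a pre-trained encoder. The `encoder` argument, its assumed output shape, and the class name are our own illustrative assumptions, not the authors' implementation.

```python
import torch
import torch.nn as nn

class SpanPredictionHead(nn.Module):
    """Sketch: linear output layers on top of a pre-trained encoder
    (e.g. BERT), scoring each token as the answer start or end.
    We assume `encoder` returns token representations of shape
    (batch, seq_len, hidden)."""

    def __init__(self, encoder, hidden_size=768):
        super().__init__()
        self.encoder = encoder
        self.span_logits = nn.Linear(hidden_size, 2)  # start and end

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids, attention_mask)  # (B, T, H)
        logits = self.span_logits(hidden)                 # (B, T, 2)
        start_logits, end_logits = logits.unbind(dim=-1)  # (B, T) each
        # Training: cross-entropy against gold start/end indices.
        # Inference: argmax over the sequence dimension.
        return start_logits, end_logits
```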

Note that BERT and QANet RC models are extractive models (goal is to predict 2 indices), while decaNLP is a generative model (goal is to generate the correct word sequence). Also, unlike QANet and decaNLP, BERT is not designed specifically for RC. It represents a growing trend in the literature where large models are pre-trained on big corpora and further adapted to downstream tasks.

To reduce the forgetting of source domain knowledge, we introduce auxiliary penalty terms to regularise the fine-tuning process. We favour this approach as it does not require storing data samples from the source domain. In general, there are two types of penalty: selective and non-selective. The former penalises the model when certain parameters diverge significantly from the source model, while the latter uses a pre-defined distance function to measure the change of all parameters.

For selective penalty, we use elastic weight consolidation (EWC: [7]), which weighs the importance of a parameter based on its gradient when training the source model. For non-selective penalty, we explore L2 [26] and cosine distance. We detail the methods below.

Given a source and target domain, we pre-train the model first on the source domain and fine-tune it further on the target domain. We denote the optimised parameters of the source model as $\theta^{s}$ and those of the target model as $\theta^{t}$. For vanilla fine-tuning (finetune), the loss function is:

$\mathcal{L}(\theta^{t}) = \mathcal{L}_{ce}(\theta^{t})$

where $\mathcal{L}_{ce}$ is the cross-entropy loss.

For non-selective penalty, we measure the change of parameters based on a distance function (treating all parameters as equally important), and add it as a loss term in addition to the cross-entropy loss. One distance function we test is the L2 distance:

$\mathcal{L}(\theta^{t}) = \mathcal{L}_{ce}(\theta^{t}) + \lambda_{l2} \sum_{i} (\theta^{t}_{i} - \theta^{s}_{i})^{2}$

where $\lambda_{l2}$ is a scaling hyper-parameter to weigh the contribution of the penalty. Henceforth all scaling hyper-parameters are denoted using $\lambda$.

We also experiment with cosine distance, based on the idea that we want to encourage the parameters to point in the same direction after fine-tuning. In this case, we group parameters by the variables they are defined in, and measure the cosine distance between variables:

$\mathcal{L}(\theta^{t}) = \mathcal{L}_{ce}(\theta^{t}) + \lambda_{cd} \sum_{v} \left(1 - \cos(\theta^{t}_{v}, \theta^{s}_{v})\right)$

where $\theta_{v}$ denotes the vector of parameters belonging to variable $v$.

For selective penalty, EWC uses the Fisher matrix $F$ to measure the importance of parameter $\theta_{i}$ in the source domain. Unlike non-selective penalty, where all parameters are considered equally important, EWC provides a mechanism to weigh the update of individual parameters:

$\mathcal{L}(\theta^{t}) = \mathcal{L}_{ce}(\theta^{t}) + \lambda_{ewc} \sum_{i} F_{i} (\theta^{t}_{i} - \theta^{s}_{i})^{2}$

$F_{i} = \mathbb{E}\left[ \left( \frac{\partial \log p(y^{s} \mid x^{s}; \theta^{s})}{\partial \theta^{s}_{i}} \right)^{2} \right]$

where $\partial \log p(y^{s} \mid x^{s}; \theta^{s}) / \partial \theta^{s}_{i}$ is the gradient of the parameter update in the source domain, with $p$ representing the model and $x^{s}$/$y^{s}$ the data/label from the source domain.
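As a concrete illustration, here is a minimal sketch of estimating the diagonal Fisher values on the source domain by accumulating squared gradients. The function and variable names (`estimate_fisher_diagonal`, `loss_fn`, `data_loader`) are our own assumptions, and the batch-level accumulation is an approximation of the per-example expectation.

```python
import torch

def estimate_fisher_diagonal(model, data_loader, loss_fn):
    """Sketch: approximate the diagonal Fisher matrix by averaging
    squared gradients of the loss over source-domain batches."""
    fisher = {name: torch.zeros_like(p)
              for name, p in model.named_parameters() if p.requires_grad}
    n_batches = 0
    for x, y in data_loader:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
        n_batches += 1
    return {name: f / n_batches for name, f in fisher.items()}
```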

(a) EWC
(b) Normalised EWC
Figure 1: Mean Fisher Matrix values for variables in QANet on squad.

In preliminary experiments, we notice that EWC tends to assign most of the weight to a small subset of parameters. Figure 1(a) plots the mean Fisher values for all variables in QANet after it was trained on squad, the source domain. We see that only the last two variables carry significant weight (with only a tiny amount for the rest of the variables). We therefore propose a new variation of EWC, normalised EWC, which applies min-max normalisation to the weights within each variable, bringing up the weights for parameters in the other variables (Figure 1(b)):

$F'_{i} = \frac{F_{i} - \min(F_{\Theta(i)})}{\max(F_{\Theta(i)}) - \min(F_{\Theta(i)})}$

where $F_{\Theta(i)}$ denotes the set of Fisher values for variable $\Theta(i)$, the variable to which parameter $i$ belongs.
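A minimal sketch of this per-variable min-max normalisation, assuming Fisher values are stored per variable as in the sketch above (the small epsilon guarding against constant variables is our own addition):

```python
import torch

def minmax_normalise(fisher, eps=1e-12):
    """Sketch of normalised EWC: min-max normalise Fisher values
    within each variable so every variable contributes weight."""
    normalised = {}
    for name, f in fisher.items():
        lo, hi = f.min(), f.max()
        normalised[name] = (f - lo) / (hi - lo + eps)
    return normalised
```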

Among the four auxiliary penalty terms, L2 and EWC were proposed in previous work, while cosine distance and normalised EWC are novel penalty terms. Observing that EWC and normalised EWC are essentially weighted L2 distances, that L2 is based on Euclidean distance, and that cosine distance focuses on the angle between variables (and ignores the magnitude), we propose combining them all, as these different distance metrics may complement each other in regularising the fine-tuning process:

$\mathcal{L}(\theta^{t}) = \mathcal{L}_{ce}(\theta^{t}) + \lambda_{ewcn} \sum_{i} F'_{i} (\theta^{t}_{i} - \theta^{s}_{i})^{2} + \lambda_{cd} \sum_{v} \left(1 - \cos(\theta^{t}_{v}, \theta^{s}_{v})\right) + \lambda_{l2} \sum_{i} (\theta^{t}_{i} - \theta^{s}_{i})^{2}$
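Putting the pieces together, below is a sketch of the combined penalty (+all) as it might be added to the cross-entropy loss at each fine-tuning step. `source_params` (frozen source-model parameters), `fisher_norm` and the function name are illustrative assumptions; the λ values are tuned per model in practice.

```python
import torch
import torch.nn.functional as F

def combined_penalty(model, source_params, fisher_norm,
                     lam_ewcn=1.0, lam_cd=1.0, lam_l2=1.0):
    """Sketch of the +all penalty: normalised-EWC-weighted squared
    change + per-variable cosine distance + plain L2. The two dicts
    map parameter names to frozen source parameters and to
    min-max-normalised Fisher values respectively."""
    ewcn = cd = l2 = 0.0
    for name, p_t in model.named_parameters():
        p_s = source_params[name]
        diff_sq = (p_t - p_s) ** 2
        ewcn = ewcn + (fisher_norm[name] * diff_sq).sum()
        l2 = l2 + diff_sq.sum()
        # Cosine distance between whole variables (flattened vectors).
        cd = cd + (1 - F.cosine_similarity(p_t.flatten(),
                                           p_s.flatten(), dim=0))
    return lam_ewcn * ewcn + lam_cd * cd + lam_l2 * l2

# Usage at each step: loss = cross_entropy + combined_penalty(...)
```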

Model    Partition  Domain   scratch  finetune  +ewc    +ewcn   +cd     +l2     +all    gem
QANet    squad      ms-bm      –       62.92    63.35   63.93   63.49   64.93   65.54   63.22
                    ms-cp      –       39.13    41.62   43.43   41.19   41.61   51.84   43.49
                    ms-fm      –       56.32    58.23   58.46   57.01   58.48   60.79   57.53
                    ms-fn      –       65.08    65.45   67.03   65.36   66.27   68.14   66.53
                    ms-lw      –       68.29    68.64   68.63   68.75   68.38   69.39   69.04
                    ms-ms      –       69.60    69.96   70.11   69.72   69.74   71.13   70.63
                    bioasq     –       59.85    62.87   63.57   62.83   62.50   66.11   62.52
                    Avg.       –       60.17    61.45   62.17   61.19   61.70   64.71   61.85
         Test       ms-bm    62.75     68.45    67.96   67.85   67.80   68.05   67.33   68.31
                    ms-cp    60.67     68.86    69.26   69.86   70.27   69.42   70.42   69.17
                    ms-fm    59.57     73.84    72.70   74.13   73.94   73.50   73.47   72.00
                    ms-fn    63.62     70.96    70.70   70.60   70.49   70.15   70.27   69.18
                    ms-lw    61.66     71.29    71.27   71.39   71.25   71.28   71.41   71.49
                    ms-ms    58.36     69.58    69.94   69.89   69.62   69.92   70.67   71.47
                    bioasq   29.83     65.81    67.17   65.93   67.26   65.57   66.82   66.42
                    Avg.     56.64     69.83    69.86   69.95   70.09   69.70   70.06   69.72
decaNLP  squad      ms-bm      –       62.99    63.00   63.26   63.27   62.43   63.82   64.95
                    ms-cp      –       56.48    58.19   59.44   61.96   60.73   62.61   63.37
                    ms-fm      –       58.69    59.21   59.18   62.66   58.32   64.04   63.36
                    ms-fn      –       58.21    61.63   63.43   59.25   58.80   66.55   62.47
                    ms-lw      –       57.86    58.14   59.73   58.17   56.89   60.75   61.76
                    ms-ms      –       59.75    64.92   62.01   62.00   60.06   63.62   63.89
                    bioasq     –       67.42    67.19   67.21   67.44   67.46   67.49   68.94
                    Avg.       –       60.20    61.75   62.04   62.11   60.67   64.13   64.11
         Test       ms-bm    62.01     66.90    67.39   67.52   67.61   67.19   67.41   67.02
                    ms-cp    63.70     66.67    67.11   68.15   66.37   67.82   67.55   67.90
                    ms-fm    63.28     70.45    70.47   70.83   69.08   70.36   68.04   69.73
                    ms-fn    64.41     64.59    64.57   64.35   64.32   64.87   64.32   64.88
                    ms-lw    66.36     73.43    73.28   73.34   73.42   74.13   73.04   72.89
                    ms-ms    64.65     68.67    67.12   67.93   67.34   69.40   66.51   68.28
                    bioasq   43.25     63.80    63.89   63.89   63.96   63.96   64.70   66.36
                    Avg.     61.09     67.79    67.69   68.00   67.44   68.25   67.37   68.15
BERT     squad      ms-bm      –       72.55    74.24   76.51   72.36   74.14   77.32   74.14
                    ms-cp      –       68.41    69.63   75.65   76.92   75.98   77.86   73.37
                    ms-fm      –       73.82    75.175  79.75   75.28   74.71   81.42   76.89
                    ms-fn      –       72.59    74.27   75.52   73.22   74.84   78.18   76.16
                    ms-lw      –       71.93    81.11   81.05   78.77   77.97   83.11   75.90
                    ms-ms      –       72.59    78.06   83.56   75.67   74.29   83.54   76.99
                    bioasq     –       75.04    85.28   85.62   85.76   84.23   86.88   75.89
                    Avg.       –       72.42    76.82   79.67   76.85   76.59   81.19   75.62
         Test       ms-bm    66.83     68.30    68.20   68.00   68.04   68.24   67.87   68.02
                    ms-cp    65.99     70.57    71.21   71.41   69.33   69.57   69.49   70.40
                    ms-fm    72.59     74.73    74.75   74.36   73.73   74.85   75.78   74.63
                    ms-fn    66.70     69.13    70.42   70.60   69.07   70.05   69.15   69.54
                    ms-lw    67.38     69.99    70.73   71.59   70.57   70.91   68.59   68.87
                    ms-ms    70.45     73.56    73.19   73.07   72.97   73.43   72.50   72.73
                    bioasq   54.09     71.62    75.84   78.50   79.47   78.86   76.93   68.87
                    Avg.     66.29     71.13    72.05   72.50   71.88   72.27   71.47   70.44
Table 2: RC results over all domains. "squad" rows give the source squad development F1 after fine-tuning on each domain; "Test" rows give the target domain's test F1. Pre-trained QANet/decaNLP/BERT performance on squad is 80.47/75.50/87.62. A dash indicates the column does not apply (scratch models are never trained on squad).
(a) squad
(b) ms-bm
(c) ms-cp
Figure 2: decaNLP’s F1 performance during continuous learning.

Experiments

We test 3 comprehension models: QANet, decaNLP and BERT. To pre-process the data, we use the models' original tokenisation methods: spaCy (https://spacy.io/), revtok (https://github.com/jekbradbury/revtok), and WordPiece for QANet, decaNLP and BERT, respectively. For BERT, we use the smaller pre-trained model with 110M parameters (BERT-Base).

Fine-Tuning with Auxiliary Penalty

We first pre-train QANet and decaNLP on squad, tuning their hyper-parameters based on its development partition. (We tune dropout, batch size, learning rate and the number of training iterations, and keep other hyper-parameters in their default configuration.) For BERT, we fine-tune the released pre-trained model on squad by adding 2 output layers to predict the start/end indices (we made no changes to the hyper-parameters). We initialise the word vectors of QANet and decaNLP with pre-trained GloVe embeddings [16] and keep them fixed during training. We also freeze the input embeddings for BERT. (The input embeddings of BERT are the sum of token, segment and position embeddings; we freeze only the token embeddings.)

To measure performance, we use the standard macro-averaged F1 as the evaluation metric, which measures the average overlap of word tokens between the prediction and the ground-truth answer. (If there are multiple ground truths, the maximum F1 is taken.) Our pre-trained QANet, decaNLP and BERT achieve F1 scores of 80.47, 75.50 and 87.62, respectively, on the development partition of squad. Note that the test partition of squad is not released publicly, so all squad performance reported in this paper is on the development set.
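For reference, a minimal sketch of the token-overlap F1; the helper names are ours, and the official evaluation script additionally lower-cases and strips punctuation and articles before comparing:

```python
from collections import Counter

def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-overlap F1 between a predicted and a gold answer."""
    pred_tokens = prediction.split()
    gold_tokens = ground_truth.split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def max_f1(prediction: str, ground_truths: list) -> float:
    # With multiple ground truths, the maximum F1 is taken.
    return max(f1_score(prediction, gt) for gt in ground_truths)
```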

Given the pre-trained squad models, we fine-tune them on the msmarco and bioasq domains. We test vanilla fine-tuning (finetune) and 5 variants of fine-tuning with auxiliary penalty terms: (1) EWC (+ewc); (2) normalised EWC (+ewcn); (3) cosine distance (+cd); (4) L2 (+l2); and (5) combined normalised EWC, cosine distance and L2 (+all). As a benchmark, we also perform fine-tuning with gradient episodic memory (gem), noting that this approach requires storing the first $m$ training examples from squad as its memory.

To find the best hyper-parameter configuration, we tune it based on the development partition for each target domain. For a given domain, finetune and its variants (+ewc, +ewcn, +cd, +l2 and +all) all share the same hyper-parameter configuration. (The only exception is the scaling hyper-parameters $\lambda_{ewc}$, $\lambda_{ewcn}$, $\lambda_{cd}$ and $\lambda_{l2}$, which we tune separately for each model.) Detailed hyper-parameter settings are given in the supplementary material.

As a baseline, we train QANet, decaNLP and BERT from scratch (scratch) using the target domain data. As before, we tune their hyper-parameters based on development performance. We present the full results in Table 2.

For each target domain, we display two F1 scores: the source squad development performance ("squad"); and the target domain's test performance ("Test"). We first compare the performance between scratch and finetune. Across all domains for QANet, decaNLP and BERT, finetune substantially improves the target domain's performance compared to scratch. The largest improvement is seen in bioasq for QANet, where F1 more than doubles (from 29.83 to 65.81). Among the three RC models, BERT has the best performance for both scratch and finetune in most target domains (with a few exceptions such as ms-fn and ms-lw). Between QANet and decaNLP, we see that decaNLP tends to have better scratch performance, but the pattern is reversed for finetune, where QANet produces higher F1 than decaNLP in all domains except ms-lw.

In terms of squad performance, finetune degrades it considerably: the average drop across all domains relative to the pre-trained performance is 20.30, 15.30 and 15.07 points for QANet, decaNLP and BERT, respectively. For most domains, F1 drops by 10-20 points, while for ms-cp the degradation is much worse for QANet, at 41.34 points. Interestingly, BERT suffers from catastrophic forgetting just as much as the other models, even though it is a larger model with orders of magnitude more parameters.

We now turn to the fine-tuning results with auxiliary penalties (+ewc, +ewcn, +cd and +l2). Between +ewc and +ewcn, the normalised version consistently produces better recovery for the source domain (one exception is ms-ms for decaNLP), demonstrating that normalisation helps. Between +ewcn, +cd and +l2, performance varies across the three models depending on the domain, and there is no clear winner. Combining all of these losses (+all), however, produces the best squad performance for all models across most domains. The average recovery (+all − finetune) of squad performance is 4.54, 3.93 and 8.77 F1 points for QANet, decaNLP and BERT respectively, implying that BERT benefits from these auxiliary penalties more than decaNLP and QANet.

When compared to gem, +all preserves squad performance substantially better, on average 2.86 points more for QANet and 5.57 points more for BERT. For decaNLP, the improvement is minute (0.02); gem generally has the upper hand for most domains, but the advantage is cancelled out by its poor performance in one domain (ms-fn). As gem requires storing training data from the source domain (squad training examples in this case), the auxiliary penalty techniques are more favourable for real-world applications.

Does adding these penalty terms harm target performance? Looking at the "Test" performance between finetune and +all, we see that they are generally comparable. The average performance difference (+all − finetune) is 0.23, −0.42 and 0.34 for QANet, decaNLP and BERT respectively, implying that it does not (in fact, it has a small positive net impact for QANet and BERT). In some cases it improves target performance substantially: e.g. in bioasq for BERT, target performance improves from 71.62 to 76.93 when +all is applied.

Based on these observations, we see benefits in incorporating these penalties when adapting comprehension models, as doing so produces a more robust model that preserves its source performance (to a certain extent) without trading off its target performance. In some cases, it can even improve the target performance.

Continuous Learning

In previous experiments, we fine-tune a pre-trained model to each domain independently. With continuous learning, we seek to investigate the performance of finetune and its four variants (+l2, +cd, +ewcn and +all) when they are applied to a series of fine-tuning on multiple domains. For the remainder of experiments in the paper, we test only with decaNLP.

When computing the penalties, we consider the last trained model as the source model. (The implication is that we have to re-compute the Fisher matrix for the last domain before we fine-tune the model on a new domain.) Figure 2 shows the performance of the models on the development set of squad and the test sets of ms-bm and ms-cp as they are adapted to ms-bm, ms-cp, ms-fn, ms-ms, ms-fm and ms-lw in sequence. (For hyper-parameters, we choose a configuration that is generally good for most domains.) We exclude plots for the later domains as they are similar to that of ms-cp.

Including the pre-training on squad, all models are trained for a total of 170K iterations: squad from 0–44K, ms-bm from 45K–65K, ms-cp from 66K–86K, ms-fn from 87K–107K, ms-ms from 108K–128K, ms-fm from 129K–149K and ms-lw from 150K–170K.

We first look at the recovery for squad in Figure 2(a). +all (black line; legend in Figure 2(c)) stays well above all other models after a series of fine-tuning rounds, followed by +ewcn and +cd, while finetune produces the most forgetting. At the end of the continuous learning, +all recovers more than 5 F1 points compared to finetune. We see a similar trend for ms-bm (Figure 2(b)), although the difference is less pronounced. The largest gap between finetune and +all occurs when we fine-tune for ms-fm (iterations 129K–149K). Note that we are not trading off target performance when we first tune for ms-bm (iterations 45K–65K), where finetune and +all produce comparable F1.

For ms-cp (Figure 2(c)), we first notice that there is considerably less forgetting overall (ms-cp performance ranges from 65–75 F1, while squad performance in Figure 2(a) ranges from 45–75 F1). This is perhaps unsurprising, as the model is already generally well-tuned by this point (e.g. it takes fewer iterations to reach optimal performance for ms-cp than for ms-bm and squad). Most models perform similarly here; +all produces stronger recovery when fine-tuning on ms-fm (129K–149K) and ms-lw (150K–170K). At the end of the continuous learning, the gap between all models is around 2 F1 points.

Task Transfer

Partition  Task  finetune  +ewc   +ewcn  +cd    +l2    +all
squad      SUM     8.60    11.65  12.48  11.28   9.34  14.00
           SRL    50.51    51.30  56.99  55.40  51.56  57.64
           SP      6.95     9.69  10.20  10.61  19.39  28.36
           MT      3.55     4.03   4.29   3.48   3.15   3.59
           SA      1.74     2.69   2.38   3.63   2.51   6.43
Test       SUM    20.06    19.79  19.99  20.01  20.38  20.12
           SRL    71.69    71.80  71.74  72.12  71.90  72.56
           SP     92.52    92.77  92.70  92.62  92.59  91.11
           MT     24.99    25.10  25.04  25.00  24.90  24.90
           SA     84.79    86.38  84.84  85.06  86.27  85.89

Table 3: decaNLP’s squad and target performance for several tasks.
Figure 3: Averaged gradient cosine similarities on ms-fn.

In decaNLP, curriculum learning was used to train models for different NLP tasks. More specifically, decaNLP was first pre-trained on squad and then fine-tuned on 10 tasks (including squad) jointly. During the training process, each minibatch consists of examples from a particular task, and they are sampled in an alternating fashion among different tasks.

In situations where we do not have access to training data from previous tasks, catastrophic forgetting occurs when we adapt the model to a new task. In this section, we test our methods for task transfer (as opposed to the domain transfer of previous sections). To this end, we experiment with decaNLP and monitor its squad performance when we fine-tune it for other tasks, including semantic role labelling (SRL), summarisation (SUM), semantic parsing (SP), machine translation (MT), and sentiment analysis (SA). Note that we are not doing joint or continuous learning here: we take the pre-trained model (on squad) and adapt it to each new task independently. These tasks are described in detail in [13].

A core novelty of decaNLP is that its design allows it to generate words by extracting them from the question, context or its vocabulary, and this decision is made by the pointer-generator network. Based on the pointer-generator analysis in [13], we know that the pointer-generator network favours generating words using: (1) context for SRL, SUM, and SP; (2) question for SA; and (3) vocabulary for MT.

As before, finetune serves as our baseline, and we have 5 variants with auxiliary penalty terms. Table 3 displays the F1 performance on squad and the target task; the table shares the same format as Table 2.

In terms of target task performance ("Test"), we see similar performance for all models. This mirrors our earlier observations, and shows that incorporating the auxiliary penalty terms does not harm target task or domain performance.

For the source task squad, +all produces substantial recovery for SUM, SRL, SP and SA, but not for MT. We hypothesise that this is due to the difference in nature between the target task and the source task: for SUM, SRL and SP, the output is generated by selecting words from the context, which is similar to squad; MT, on the other hand, generates output using words from the vocabulary and the question, and so it is likely to be difficult to find an optimal model that performs well for both tasks.

Discussion

Observing that the model tends to focus on optimising for the target domain/task in early iterations (as the penalty term has a very small value), we explore using a dynamic scaling hyper-parameter $\lambda$ that starts at a larger value and decays over time. With just a simple linear decay, we found substantial improvement in +ewc for recovering squad's performance, although the results were mixed for the other penalties (particularly +ewcn). We therefore only report results based on static $\lambda$ values in this paper. With that said, we contend that this might be an interesting avenue for further research, e.g. by exploring more complex decay functions.
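A minimal sketch of such a linear decay schedule; the function name and endpoint values are illustrative assumptions, not the paper's exact settings:

```python
def linear_decay(lam_start: float, lam_end: float,
                 step: int, total_steps: int) -> float:
    """Linearly anneal the penalty scale from lam_start down to
    lam_end over total_steps fine-tuning iterations."""
    frac = min(step / total_steps, 1.0)
    return lam_start + frac * (lam_end - lam_start)

# e.g. lam = linear_decay(10.0, 1.0, step, 20_000) at each step
```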

To validate the assumption made by gem [10], we conduct a gradient analysis for the auxiliary penalty terms. During fine-tuning, at each step we calculate the gradient cosine similarity $\cos(g, g_{ref})$, where $g = \nabla_{\theta}\mathcal{L}(f_{\theta}(x), y)$ and $g_{ref} = \nabla_{\theta}\mathcal{L}(f_{\theta}(x^{m}), y^{m})$ for $(x^{m}, y^{m})$ drawn from a memory $\mathcal{M}$ containing squad examples, $f_{\theta}$ is the model, and $x$/$y$ is the training data/label from the current domain. We smooth the scores by averaging over every 1K steps, resulting in 20 cosine similarity values for 20K steps. Figure 3 plots the gradient cosine similarity for our models on ms-fn.
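A minimal sketch of this measurement, assuming a standard PyTorch training setup (the helper and argument names are our own):

```python
import torch
import torch.nn.functional as F

def gradient_cosine(model, loss_fn, current_batch, memory_batch):
    """Cosine similarity between the gradient on a current-domain
    batch (g) and the reference gradient on a batch of squad
    examples held in memory (g_ref)."""
    def flat_grad(batch):
        x, y = batch
        model.zero_grad()
        loss_fn(model(x), y).backward()
        return torch.cat([p.grad.flatten() for p in model.parameters()
                          if p.grad is not None]).detach()

    g = flat_grad(current_batch)
    g_ref = flat_grad(memory_batch)
    return F.cosine_similarity(g, g_ref, dim=0).item()
```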

Curiously, our best performing model, +all, produces the lowest cosine similarity at most steps (the only exception is between 0-1K steps). finetune, on the other hand, maintains relatively high similarity throughout. Similar trends are found for the other domains. These observations imply that the intuition gem draws on, i.e. that catastrophic forgetting can be reduced by constraining the dot product between $g$ and $g_{ref}$ to be positive, is perhaps not as empirically effective as one might expect, and that our auxiliary penalty methods represent an alternative (and very different) direction for preserving source performance.

Conclusion

To reduce catastrophic forgetting when adapting comprehension models, we explore several auxiliary penalty terms to regularise the fine-tuning process. We experiment with selective and non-selective penalties, and find that a combination of them consistently produces the best recovery for the source domain without harming performance in the target domain. We observe similar results when applying our approach to adaptation across tasks, demonstrating its general applicability. To test our approach, we develop and release six narrow domain reading comprehension data sets for the research community.

References

  • [1] D. Blei, A. Ng, and M. Jordan (2003) Latent Dirichlet allocation. Journal of Machine Learning Research 3, pp. 993–1022.
  • [2] H. Daume III (2007) Frustratingly easy domain adaptation. In Proceedings of the 45th Annual Meeting of the Association of Computational Linguistics, pp. 256–263.
  • [3] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • [4] R. M. French (1999) Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences 3 (4), pp. 128–135.
  • [5] M. Joshi, E. Choi, D. Weld, and L. Zettlemoyer (2017) TriviaQA: a large scale distantly supervised challenge dataset for reading comprehension. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 1601–1611.
  • [6] Y. Kim, K. Stratos, and R. Sarikaya (2016) Frustratingly easy neural domain adaptation. In Proceedings of COLING 2016, the 26th International Conference on Computational Linguistics: Technical Papers, pp. 387–396.
  • [7] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, D. Hassabis, C. Clopath, D. Kumaran, and R. Hadsell (2017) Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences 114, pp. 3521–3526.
  • [8] T. Kocisky, J. Schwarz, P. Blunsom, C. Dyer, K. M. Hermann, G. Melis, and E. Grefenstette (2018) The NarrativeQA reading comprehension challenge. Transactions of the Association for Computational Linguistics 6, pp. 317–328.
  • [9] T. Kwiatkowski, J. Palomaki, O. Redfield, M. Collins, A. Parikh, C. Alberti, D. Epstein, I. Polosukhin, M. Kelcey, J. Devlin, K. Lee, K. N. Toutanova, L. Jones, M. Chang, A. Dai, J. Uszkoreit, Q. Le, and S. Petrov (2019) Natural Questions: a benchmark for question answering research. Transactions of the Association of Computational Linguistics.
  • [10] D. Lopez-Paz et al. (2017) Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pp. 6467–6476.
  • [11] C. D. Manning, M. Surdeanu, J. Bauer, J. Finkel, S. J. Bethard, and D. McClosky (2014) The Stanford CoreNLP natural language processing toolkit. In Association for Computational Linguistics (ACL) System Demonstrations, pp. 55–60.
  • [12] A. K. McCallum (2002) MALLET: a machine learning for language toolkit. http://mallet.cs.umass.edu
  • [13] B. McCann, N. S. Keskar, C. Xiong, and R. Socher (2018) The natural language decathlon: multitask learning as question answering. CoRR abs/1806.08730.
  • [14] M. McCloskey and N. J. Cohen (1989) Catastrophic interference in connectionist networks: the sequential learning problem. In Psychology of Learning and Motivation, Vol. 24, pp. 109–165.
  • [15] T. Nguyen, M. Rosenberg, X. Song, J. Gao, S. Tiwary, R. Majumder, and L. Deng (2016) MS MARCO: a human generated machine reading comprehension dataset. CoRR abs/1611.09268.
  • [16] J. Pennington, R. Socher, and C. D. Manning (2014) GloVe: global vectors for word representation. In Empirical Methods in Natural Language Processing (EMNLP), pp. 1532–1543.
  • [17] P. Rajpurkar, R. Jia, and P. Liang (2018) Know what you don't know: unanswerable questions for SQuAD. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers), pp. 784–789.
  • [18] P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pp. 2383–2392.
  • [19] M. Riemer, I. Cases, R. Ajemian, M. Liu, I. Rish, Y. Tu, and G. Tesauro (2018) Learning to learn without forgetting by maximizing transfer and minimizing interference. arXiv preprint arXiv:1810.11910.
  • [20] M. Riemer, E. Khabiri, and R. Goodwin (2017) Representation stability as a regularizer for improved text analytics transfer learning. arXiv preprint arXiv:1704.03617.
  • [21] M. Riemer, T. Klinger, M. Franceschini, and D. Bouneffouf (2017) Scalable recollections for continual lifelong learning. arXiv preprint arXiv:1711.06761.
  • [22] J. Serra, D. Surís, M. Miron, and A. Karatzoglou (2018) Overcoming catastrophic forgetting with hard attention to the task. arXiv preprint arXiv:1801.01423.
  • [23] G. Tsatsaronis, G. Balikas, P. Malakasiotis, I. Partalas, M. Zschunke, M. R. Alvers, D. Weissenborn, A. Krithara, S. Petridis, D. Polychronopoulos, et al. (2015) An overview of the BioASQ large-scale biomedical semantic indexing and question answering competition. BMC Bioinformatics 16 (1), pp. 138.
  • [24] G. Tsatsaronis, M. Schroeder, G. Paliouras, Y. Almirantis, I. Androutsopoulos, E. Gaussier, P. Gallinari, T. Artieres, M. R. Alvers, M. Zschunke, et al. (2012) BioASQ: a challenge on large-scale biomedical semantic indexing and question answering. In 2012 AAAI Fall Symposium Series.
  • [25] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems 30, pp. 5998–6008.
  • [26] G. Wiese, D. Weissenborn, and M. Neves (2017) Neural domain adaptation for biomedical question answering. In Proceedings of the 21st Conference on Computational Natural Language Learning (CoNLL 2017), pp. 281–289.
  • [27] Y. Yang, W. Yih, and C. Meek (2015) WikiQA: a challenge dataset for open-domain question answering. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pp. 2013–2018.
  • [28] A. W. Yu, D. Dohan, M. Luong, R. Zhao, K. Chen, M. Norouzi, and Q. V. Le (2018) QANet: combining local convolution with global self-attention for reading comprehension. CoRR abs/1804.09541.