Transformer (Vaswani et al., 2017) based language representation has replaced many previous pre-training or initialization approaches (Devlin et al., 2018; Radford et al., 2019; Yang et al., 2019; Liu et al., 2019). Fine-tuning these architectures often yields state-of-the-art results within a few hours. The caveat to these robust models is that the initial training can take several weeks on many distributed GPUs, which is a costly endeavour.
Pre-trained language models are further adapted to perform strongly in other domains as well. For example, while the original BERT models (Devlin et al., 2018) were trained on English Wikipedia articles and BooksCorpus (Zhu et al., 2015), the same masked language modeling was continued on biomedical data. BioBERT (Lee et al., 2019) was trained using Pubmed abstracts and full articles, and Clinical BERT (Alsentzer et al., 2019) parameters were further refined using MIMIC-III clinical notes (Johnson et al., 2016). Evidence suggests that understanding the syntactic structure of scientific literature and clinical data from pre-training boosts performance in their respective downstream tasks (Peng et al., 2019). Training is done with the expectation of building robust, high-capacity, generalized language models that can continue to absorb new domain knowledge.
Catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990) is the unfortunate side-effect of incorporating new domain data one after another. Parameters shift towards capturing the current task, and if previous data is no longer available, the model loses its representation of it. In general, perplexity increases for older domains, and models lose confidence in continual learning settings (Yogatama et al., 2019). For many tasks the straightforward solution is to combine datasets during training and approach this as a multi-task learning (MTL) problem (Ruder, 2017). Mixing data has the desired effect of constraining parameters to find a space where both tasks reach close to optimal performance.
We further argue that these expensive pre-trained models are an example where MTL is not feasible in practice for several reasons. Time and hardware accessibility are the largest constraints for developing these models. Access to processed training data is generally not possible (Radford et al., 2019; Devlin et al., 2018), and exact training configurations are equally difficult to gather with results being arduous to reproduce. Resource usage has recently been criticized from another perspective as well. Strubell et al. (2019) show that as deep neural architectures in the natural language community grow we increasingly trade results for carbon emissions.
Current work in catastrophic forgetting mitigation has been limited to a few small scale tested methods.
Howard and Ruder (2018) introduced a multi stage training scheme for fine tuning LSTM based universal language models (ULMFiT).
The authors proposed that current methods, rather than data, are ineffective and focused on learning rate control across layers, as well as modifying learning rate scheduling.
A larger category of work deals with constraining model parameters to a latent space where they continue to capture previous tasks.
Initial work focused on model regularization and varying activations (Goodfellow et al., 2013). Kirkpatrick et al. (2017) provided a more sophisticated solution, termed elastic weight consolidation (EWC), which constrains weights individually.
We make use of both EWC and ULMFiT and provide further technical detail in this paper.
The final approach is focused on experience replay.
Using small samples of data from previous tasks coupled with local adaptation, d’Autume et al. (2019) demonstrate improvement in a Lifelong Learning (LL) training scheme.
Chaudhry et al. (2019) also explore LL by experimenting with updating the memory bank for experience replay.
Our work focuses on both of these techniques with the major difference being problem scale.
Many existing works apply these solutions on small networks whereas we experiment on architectures having several orders of magnitude more parameters.
Our contributions are as follows:
- We motivate the task by providing concrete evidence of catastrophic forgetting for language representation pre-training evaluated on the GLUE benchmark (Wang et al., 2018).
- We provide empirical evidence of catastrophic forgetting mitigation with experience replay, learning rate control, and elastic weight consolidation, applied towards large scale language model pre-training.
- We further demonstrate the robustness of elastic weight consolidation when pre-training under two stages of domain shift.
2 Continual Learning
Our work focuses on three forms of mitigation from catastrophic forgetting. We explore using constraint based training in the form of EWC, learning rate control from Howard and Ruder (2018), and episodic memory in the form of experience replay.
2.1 Elastic Weight Consolidation
EWC makes use of a simple Bayesian factorization of model representation (Kirkpatrick et al., 2017). This isolates the posterior of a learned task (A) while maintaining the objective of a current task (B). Due to the intractability of the true posterior, EWC makes use of the diagonal of a Fisher information matrix (Frieden, 2004) to approximate the effect of task A on the parameters of a model. Intuitively speaking, if a parameter had a large effect on task A, its Fisher value would be large, leaving it little variance with which to adapt to task B. The inverse holds when the Fisher value is small.
In practice, we initialize the Fisher matrix using gradients calculated with data sampled from Task A, which has already converged. This is demonstrated in Eq. 1, where $i$ and $j$ index parameters and data samples respectively.
The full objective for task B is given in Eq. 2, where $\mathcal{L}_B(\theta)$ is the objective of Task B, and EWC is represented as the second term, which regularizes model parameters by weighting their shift as the model trains on task B ($\theta_i$ and $\theta^{*}_{A,i}$ being the currently updated and frozen task A parameters at index $i$ respectively). The EWC objective component is further adjusted by the hyperparameter $\lambda$.
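The two equations referenced above can be illustrated with a small sketch. It assumes the standard diagonal EWC form from Kirkpatrick et al. (2017): a Fisher estimate $F_{i,i} = \frac{1}{N}\sum_{j}(\partial \mathcal{L}_A^{(j)} / \partial \theta_i)^2$ and a penalty $\lambda \sum_i F_{i,i}(\theta_i - \theta^{*}_{A,i})^2$ added to the task B loss. Some formulations scale the penalty by $\lambda/2$ instead. The function names and the toy gradients are hypothetical and are not taken from the authors' implementation.

```python
from typing import List, Sequence


def diagonal_fisher(per_sample_grads: Sequence[Sequence[float]]) -> List[float]:
    """Estimate the Fisher diagonal as the mean squared gradient per parameter.

    per_sample_grads[j][i] is the gradient of the task-A loss on sample j
    with respect to parameter i (assumed to be computed elsewhere).
    """
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(dim)]


def ewc_penalty(theta: Sequence[float], theta_a: Sequence[float],
                fisher: Sequence[float], lam: float) -> float:
    """Fisher-weighted squared distance from the frozen task-A parameters."""
    return lam * sum(f * (t - ta) ** 2 for f, t, ta in zip(fisher, theta, theta_a))


def ewc_objective(task_b_loss: float, theta: Sequence[float],
                  theta_a: Sequence[float], fisher: Sequence[float],
                  lam: float) -> float:
    """Task-B loss plus the EWC regularizer."""
    return task_b_loss + ewc_penalty(theta, theta_a, fisher, lam)


# Toy example (made-up numbers): parameter 0 had large task-A gradients,
# parameter 1 had small ones.
grads = [[2.0, 0.1], [4.0, -0.1]]
fisher = diagonal_fisher(grads)
theta_a = [1.0, 1.0]

# Moving each parameter by the same amount is penalized very differently.
moved_important = ewc_penalty([1.5, 1.0], theta_a, fisher, lam=1.0)
moved_unimportant = ewc_penalty([1.0, 1.5], theta_a, fisher, lam=1.0)
```

In this toy setting the same shift of 0.5 costs far more on the parameter with the larger Fisher value, which is the behaviour the regularizer is meant to produce.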
2.2 Learning rate control
Our approach models the second stage of ULMFiT (Howard and Ruder, 2018), namely target task fine-tuning. We begin with layer-wise modifications, applying a learning rate that decays as a function of layer depth, moving from the last layer towards the model input, where, following Howard and Ruder (2018), $\eta^{l-1} = \eta^{l}/2.6$ ($\eta$ and $l$ denoting learning rate and layer index respectively). Depth plays a factor in our model since the network consists of 14 layers (i.e. 12 transformer layers, one layer for input, and one for LM heads). Additionally, we switch from the polynomial decay learning rate scheduler to the slanted triangular learning rate (STLR).
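As a rough illustration of the two learning rate changes, the sketch below computes per-layer rates and an STLR schedule. It assumes the values published for ULMFiT (a per-layer decay factor of 2.6, `cut_frac = 0.1`, `ratio = 32`). The peak rate, step counts, and function names are hypothetical and only serve the example.

```python
import math
from typing import List


def layerwise_learning_rates(top_lr: float, num_layers: int,
                             decay: float = 2.6) -> List[float]:
    """Learning rates indexed from the input layer (0) to the last layer.

    The last layer receives top_lr; each layer below it is divided by
    `decay` once more (2.6 is the ULMFiT value, assumed here).
    """
    return [top_lr / decay ** (num_layers - 1 - layer) for layer in range(num_layers)]


def slanted_triangular_lr(step: int, total_steps: int, max_lr: float,
                          cut_frac: float = 0.1, ratio: float = 32.0) -> float:
    """Slanted triangular schedule: short linear warm-up, long linear decay."""
    cut = max(1, math.floor(total_steps * cut_frac))
    if step < cut:
        p = step / cut
    else:
        p = 1.0 - (step - cut) / (cut * (1.0 / cut_frac - 1.0))
    return max_lr * (1.0 + p * (ratio - 1.0)) / ratio


# 14 layers as described above; 1e-4 is a placeholder peak rate.
rates = layerwise_learning_rates(top_lr=1e-4, num_layers=14)
start = slanted_triangular_lr(0, 1000, 1e-4)
peak = slanted_triangular_lr(100, 1000, 1e-4)
end = slanted_triangular_lr(1000, 1000, 1e-4)
```

With these settings the schedule starts and ends at `max_lr / ratio` and peaks after the first tenth of training, while the per-layer rates shrink geometrically towards the input.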
2.3 Experience Replay
We explore experience replay in a very simple fashion. At a chosen interval we replay a buffer of batches retained from the domain(s) of the previous task. We explore the frequency of replay as well as the size of the replay data.
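A minimal sketch of such a replay schedule is shown below. It interleaves new-domain batches with batches drawn from a retained buffer at a fixed interval. The generator name, the tags, and the toy interval and buffer sizes are hypothetical. The buffer is read cyclically here, which is one possible choice rather than a detail stated in the text.

```python
from typing import Iterable, Iterator, Sequence, Tuple, TypeVar

Batch = TypeVar("Batch")


def interleave_replay(new_batches: Iterable[Batch],
                      replay_buffer: Sequence[Batch],
                      interval: int,
                      replay_size: int) -> Iterator[Tuple[str, Batch]]:
    """Yield new-domain batches, inserting replayed batches every `interval` steps.

    After every `interval` new-domain updates, `replay_size` batches from the
    previous domain are yielded, cycling through the buffer in order.
    """
    cursor = 0
    for step, batch in enumerate(new_batches, start=1):
        yield ("new", batch)
        if replay_buffer and step % interval == 0:
            for _ in range(replay_size):
                yield ("replay", replay_buffer[cursor % len(replay_buffer)])
                cursor += 1


# Toy run: six new-domain batches, replaying two old batches every three steps.
schedule = list(interleave_replay(range(6), ["old0", "old1", "old2"],
                                  interval=3, replay_size=2))
tags = [tag for tag, _ in schedule]
```

A training loop would consume this stream and apply the same update rule to both kinds of batch. Varying `interval` and `replay_size` corresponds to the frequency and size settings mentioned above.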
We processed publicly available biomedical and non-biomedical corpora for pre-training our models. For non-biomedical data, we use BookCorpus and English Wikipedia data, CommonCrawl Stories (Trinh and Le, 2018), and OpenWebText (Gokaslan and Cohen, 2019). This combined corpus contains roughly 18B tokens. For biomedical data, we use full Pubmed articles (https://www.ncbi.nlm.nih.gov/pmc/), which we processed to remove all tables, references, equations, and figures. This yields a dataset of over 4B tokens. For all datasets we retain training, validation, and test splits sampled at the document level with a respective ratio of 8:1:1.
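Document-level splitting at an 8:1:1 ratio could look like the following sketch. Splitting by document rather than by sentence keeps all text from one article in a single split. The function name, the seed, and the placeholder document ids are assumptions for illustration.

```python
import random
from typing import Dict, List, Sequence


def split_documents(doc_ids: Sequence[str], seed: int = 0) -> Dict[str, List[str]]:
    """Shuffle document ids and split them 8:1:1 into train/valid/test."""
    shuffled = list(doc_ids)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 8 // 10
    n_valid = len(shuffled) // 10
    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        # Any remainder from integer division falls into the test split.
        "test": shuffled[n_train + n_valid:],
    }


all_docs = [f"doc{i}" for i in range(100)]
splits = split_documents(all_docs)
```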
4 Experimental Details
For modeling we use the RoBERTa architecture (Liu et al., 2019), and implement EWC, learning rate control, and experience replay changes directly into the model. This extension of the original BERT removed next sentence prediction and was trained using only masked language modeling with very large batch sizes. We utilize all training hyperparameters as provided by Liu et al. (2019) unless otherwise noted, and use RoBERTa BASE as parameter initialization for all experiments.
To understand the extent of deterioration, we continue to train a model using Pubmed articles (denoted as PMC) with no mitigation techniques. For a baseline and potential upper bound of performance we train a multi-domain learning (MDL) model which utilizes the full combined training sets as input data. The learning rate control model (+LRC) uses the hyperparameters provided by Howard and Ruder (2018) and layer-wise learning rate decay as outlined in Section 2.2.
For EWC (+EWC) we tune both $\lambda$ [, , , ] and the size of the data used for Fisher initialization [, , ]; best values are underlined. For experience replay (+ER) we experiment with sampling update batches from the non-biomedical dataset (the subset used for EWC init.) at various intervals. We compare two schedules: ten original domain updates at every 1k, 2k, and 5k training steps, where each batch size is 2048; and a single update of
of the original domain at the end of an epoch of training. Best performance was obtained using the latter.
4.1 Evaluation Data
To evaluate modeling we track the perplexity of held-out test data for both domains. We report the average accuracy across GLUE tasks to track the performance of the model on general natural language understanding. Additionally, we evaluate on CoNLL-03 (Sang and De Meulder, 2003) named entity recognition (NER), and MATRES (Ning et al., 2018), a temporal relation dataset. To demonstrate domain shift we evaluate using BC5CDR (Li et al., 2016) and Chemprot (Krallinger et al., 2017), which are NER and relation extraction (RE) tasks respectively. The former dataset is from the 2015 CDR challenge for identifying chemicals and diseases expertly annotated from Pubmed abstracts. Chemprot contains annotations of chemical-protein reactions, also taken from Pubmed articles.
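For reference, perplexity is the exponential of the average negative log-likelihood per token. The sketch below assumes the model supplies natural-log probabilities for each scored token. For a masked language model these would typically be the masked positions, which is an assumption rather than a detail given here. The function name and the toy probabilities are hypothetical.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the mean negative log-likelihood over the scored tokens."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


# A model that assigns probability 0.25 to every held-out token is as
# uncertain as a uniform choice among four options.
uniform = perplexity([math.log(0.25)] * 8)
confident = perplexity([math.log(0.9)] * 8)
```

Lower values indicate that the model assigns higher probability to the held-out text, so a rise in perplexity on the original domain is one signal of forgetting.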
Our experimental results are highlighted in Table 2. The first two rows contain the off-the-shelf RoBERTa model as well as that which received no mitigation when further trained on biomedical data. The bottom section lists all other experimental settings described in Section 4. For all models pre-trained using Pubmed data we finetune on tasks after a single epoch of pre-training.
We divide columns by task domain. The first three tasks cover general language understanding. For measuring performance on GLUE, we further limit the selection of tasks to the five most deteriorated (i.e. CoLA (Warstadt et al., 2018), SST-2 (Socher et al., 2013), MNLI (Williams et al., 2018), QNLI (Rajpurkar et al., 2016) and RTE (Giampiccolo et al., 2007)). Tasks such as QQP (https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) and MRPC (Dolan and Brockett, 2005) are generally robust against domain change and perform well regardless of initialization. Biomedical tasks are displayed next, followed by model perplexity in both domains. Task scores, save for GLUE, are reported using micro-$F_1$.
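Assuming the micro-averaged score used for the non-GLUE tasks is micro-$F_1$, it can be computed by pooling true positives, false positives, and false negatives over all labels before taking the harmonic mean of precision and recall. The sketch below uses made-up counts for two BC5CDR-style entity types. The function name and the numbers are hypothetical.

```python
from typing import Dict, Tuple


def micro_f1(counts: Dict[str, Tuple[int, int, int]]) -> float:
    """Micro-averaged F1 from per-label (tp, fp, fn) counts."""
    tp = sum(c[0] for c in counts.values())
    fp = sum(c[1] for c in counts.values())
    fn = sum(c[2] for c in counts.values())
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Illustrative counts only: (true positives, false positives, false negatives).
score = micro_f1({"Chemical": (8, 2, 1), "Disease": (2, 0, 3)})
```

Because counts are pooled first, frequent labels contribute more to the score than rare ones.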
5.1 Catastrophic Forgetting
We complement our own findings with those from existing pre-trained models. To this end we fine-tuned a BERT BASE architecture on all nine GLUE tasks. These were compared directly against BioBERT, which has been further trained on full Pubmed articles. In Table 1, an overall trend of performance deterioration is apparent with a relative increased error of . BioBERT performed negligibly better than the original BERT on only a single task (MRPC). Furthermore, we observed that on tasks which BERT struggles with, such as CoLA and WNLI, the performance decrease is amplified when switching pre-training domains.
Our own results are similarly divided. Unsurprisingly, among these RoBERTa BASE performs best on GLUE, CoNLL and MATRES. Conversely, it under-performs on the biomedical tasks, validating the need to further pre-train on domain specific data. Similarly, we see that the PMC model performs the best in its domain; however, there is a significant drop in performance across GLUE, CoNLL and MATRES. The perplexity analysis further illustrates the degree of separation between tasks, with the biomedical model exhibiting a sharp change when leaving the generic domain.
5.2 Mitigation based models
EWC and LRC both respond well during domain shifts and are our best candidates for combating catastrophic forgetting. LRC has negligible degradation on GLUE tasks, and yields the best overall numbers for MATRES and Chemprot. Furthermore, this model has the highest combined confidence when observing perplexity across domains. Our trials with EWC left us with several findings. While the amount of data used for Fisher initialization did not have a profound effect, the model was quite sensitive to $\lambda$ values. With higher coefficients () EWC was able to halt deterioration nearly completely but performed quite poorly on biomedical tasks. To better understand the importance of the Fisher information, we trained EWC with no Fisher (i.e., removing the Fisher weighting from Eq. 2). We found that this resulted in lower biomedical results, which shows that giving equal weight to all the parameters results in poor generalization on source and target domains. Given its resource cost, MDL performed surprisingly averagely. While it does produce competitive results in the biomedical domain, the model struggles to retain generic knowledge. Experience replay grapples most with domain retention and produced the highest mitigated biomedical results coupled with the lower generic results.
5.2.1 Two stage domain shift
To further evaluate the robustness of the best performing methods we add a third domain to the continual learning setup. We processed 659M tokens of de-identified clinical notes and continued training the EWC and LRC models from Table 2 (denoted with a subscript 2). To evaluate the clinical domain we use NER from the 2010 i2b2 challenge. Due to the relatively small amount of clinical data we pre-train for five epochs. We compare against the deterioration of an unmitigated model trained first on Pubmed, and then on clinical data (PMC, clin.).
As expected the unmitigated model suffers from performance deterioration in both previous domains, with average GLUE dropping drastically. LRC worsens below baseline on BC5CDR and shows only a small boost in clinical results over RoBERTa BASE, although it continues to perform well on generic tasks. EWC gives the best performance across the board. The model exhibits slight decay on GLUE and CoNLL, robust performance on biomedical NER, and the best overall results on i2b2. This further indicates that the EWC objective has the capability to produce a model which generalizes better across multiple domains, outperforming unregulated methods.
In this work, we have demonstrated the existence of catastrophic forgetting in large language model pre-training. We further explored constraint and replay based mitigation techniques to close the performance gap between general and domain specific natural language tasks.
- Publicly available clinical bert embeddings. arXiv preprint arXiv:1904.03323. Cited by: §1.
- Continual learning with tiny episodic memories. arXiv preprint arXiv:1902.10486. Cited by: §1.
- Episodic memory in lifelong language learning. arXiv preprint arXiv:1906.01076. Cited by: §1.
- Bert: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805. Cited by: §1, §1, §1.
- Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing (IWP2005). Cited by: §5.
- Science from fisher information: a unification. Cambridge University Press. Cited by: §2.1.
- The third pascal recognizing textual entailment challenge. In Proceedings of the ACL-PASCAL workshop on textual entailment and paraphrasing, pp. 1–9. Cited by: §5.
- OpenWebText corpus. Note: http://Skylion007.github.io/OpenWebTextCorpus Cited by: §3.
- An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211. Cited by: §1.
- Universal language model fine-tuning for text classification. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 328–339. Cited by: §1, §2.2, §2, §4.
- MIMIC-iii, a freely accessible critical care database. Scientific Data 3, pp. 160035. Cited by: §1.
- Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences 114 (13), pp. 3521–3526. Cited by: §1, §2.1.
- Overview of the biocreative vi chemical-protein interaction track. In Proceedings of the sixth BioCreative challenge evaluation workshop, Vol. 1, pp. 141–146. Cited by: §4.1.
- Biobert: pre-trained biomedical language representation model for biomedical text mining. arXiv preprint arXiv:1901.08746. Cited by: §1.
- BioCreative v cdr task corpus: a resource for chemical disease relation extraction. Database 2016. Cited by: §4.1.
- RoBERTa: a robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692. Cited by: §1, §4.
- Catastrophic interference in connectionist networks: the sequential learning problem. In Psychology of learning and motivation, Vol. 24, pp. 109–165. Cited by: §1.
- A multi-axis annotation scheme for event temporal relations. arXiv preprint arXiv:1804.07828. Cited by: §4.1.
- Transfer learning in biomedical natural language processing: an evaluation of bert and elmo on ten benchmarking datasets. In Proceedings of the 2019 Workshop on Biomedical Natural Language Processing (BioNLP 2019), Cited by: §1.
- Language models are unsupervised multitask learners. OpenAI Blog 1 (8). Cited by: §1, §1.
- Squad: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250. Cited by: §5.
- Connectionist models of recognition memory: constraints imposed by learning and forgetting functions.. Psychological review 97 (2), pp. 285. Cited by: §1.
- An overview of multi-task learning in deep neural networks. arXiv preprint arXiv:1706.05098. Cited by: §1.
- Introduction to the conll-2003 shared task: language-independent named entity recognition. arXiv preprint cs/0306050. Cited by: §4.1.
- Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 conference on empirical methods in natural language processing, pp. 1631–1642. Cited by: §5.
- Energy and policy considerations for deep learning in nlp. arXiv preprint arXiv:1906.02243. Cited by: §1.
- A simple method for commonsense reasoning. arXiv preprint arXiv:1806.02847. Cited by: §3.
- Attention is all you need. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), pp. 5998–6008. Cited by: §1.
- Glue: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461. Cited by: 1st item.
- Neural network acceptability judgments. arXiv preprint arXiv:1805.12471. Cited by: §5.
- A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pp. 1112–1122. Cited by: §5.
- XLNet: generalized autoregressive pretraining for language understanding. arXiv preprint arXiv:1906.08237. Cited by: §1.
- Learning and evaluating general linguistic intelligence. arXiv preprint arXiv:1901.11373. Cited by: §1.
- Aligning books and movies: towards story-like visual explanations by watching movies and reading books. In Proceedings of the IEEE international conference on computer vision, pp. 19–27. Cited by: §1.