Code for EMNLP-IJCNLP 2019 MRQA Workshop Paper: "Domain-agnostic Question-Answering with Adversarial Training" (https://arxiv.org/abs/1910.09342)
Adapting models to a new domain without fine-tuning is a challenging problem in deep learning. In this paper, we use an adversarial training framework for domain generalization in the Question Answering (QA) task. Our model consists of a conventional QA model and a discriminator. Training is performed adversarially: the two models constantly compete, so that the QA model learns domain-invariant features. We apply this approach to the MRQA Shared Task 2019 and show that it improves performance over the baseline model.
Following the success of deep learning across many tasks, it has become important to build a single model that covers various domains without further fine-tuning on out-of-domain distributions. Real-world applications require a model to generalize to unseen sources of data.
Question Answering (QA) is one of the most promising areas in NLP. However, models that outperform humans on SQuAD rajpurkar2016squad do not generalize well to other datasets. Instead, they overfit to a specific dataset and require additional training on other datasets to adapt to a new domain yogatama2019learning.
Thus, to build a domain-agnostic QA model capable of handling out-of-domain data, the model must learn domain-invariant features rather than domain-specific ones. In this paper, we apply an adversarial training framework to train a QA model with a domain-agnostic representation. As shown in Figure 1, the model consists of two components: the QA model and the domain discriminator. The discriminator predicts the domain label of the hidden representation produced by the QA model. During training, the QA model tries to fool the discriminator so that its hidden representations become indistinguishable across domains. Meanwhile, the discriminator is trained to identify the domain label correctly. As a result, the QA model learns domain-invariant features. Our framework can be applied to any existing QA model because the architecture of the QA model stays unchanged.
We train and validate our method on 12 datasets provided by the MRQA Shared Task: 6 for training and 6 for validation. Each training dataset is treated as a different domain for adversarial learning. The QA model learns domain-invariant feature representations by competing with the discriminator. Our experimental results show that the proposed method improves performance over the baseline.
Pre-trained Language Model
Recently, pre-trained language models such as ELMo peters2018deep, GPT radford2018improving, and BERT devlin2018bert have been used to transfer knowledge from pre-training to various downstream NLP tasks.
BERT is pre-trained with a bidirectional Transformer encoder vaswani2017attention on large corpora. Auto-regressive language models are either unidirectional or a concatenation of forward and backward language models. Unlike them, BERT randomly masks some input tokens and predicts the masked tokens from their context. This masked language modeling objective enables bidirectional representations. These have led to significant improvements on a number of NLP tasks, such as sentence classification, POS tagging, and question answering.
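To make the masked-LM idea concrete, here is a minimal pure-Python sketch of how BERT corrupts its input. It follows the 15% selection rate and the 80/10/10 replacement split described in the BERT paper. It simplifies selection to an independent coin flip per token, whereas the original implementation selects a fixed number of positions per sequence. `mask_tokens` is a hypothetical helper, not part of any library.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, vocab, mask_prob=0.15, rng=None):
    """Sketch of BERT-style masked-LM input corruption.

    Each position is selected with probability `mask_prob`. A selected token is
    replaced by [MASK] 80% of the time, by a random vocabulary token 10% of the
    time, and left unchanged 10% of the time. Returns (corrupted, labels), where
    labels[i] holds the original token at selected positions (the prediction
    target) and None elsewhere.
    """
    rng = rng or random.Random()
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected: no prediction target at this position
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK_TOKEN
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)
        # otherwise the original token is kept unchanged
    return corrupted, labels
```

For example, `mask_tokens("the cat sat on the mat".split(), vocab=["dog", "ran"], rng=random.Random(0))` returns the corrupted sequence together with the per-position targets. Only positions with a non-`None` label contribute to the masked-LM loss.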
Domain Generalization
Even though many deep learning models surpass human-level performance on various tasks, they perform poorly on out-of-domain datasets. To address this problem, domain adaptation and domain generalization have been proposed to make models more robust to out-of-domain data. The difference between the two is that in domain generalization, data from the target domain is not available during training.
Several methods for domain generalization exist. One is to train a separate model for each in-domain dataset. When testing on out-of-domain data, the most correlated in-domain dataset is selected and its model is used for inference xu2014exploiting. In other works ghifary2015domain; muandet2013domain, the model learns domain-invariant features using multi-view autoencoders and mean map embedding-based techniques.
Other approaches khosla2012undoing; li2017deeper decompose a model's parameters into domain-specific and domain-agnostic components when training on in-domain datasets. They then use only the domain-invariant parameters to make predictions on data from an unseen target domain.
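The parameter decomposition idea can be illustrated with a toy linear regressor. This is a sketch under our own assumptions, not the cited authors' training procedure. Weights for a seen domain `d` are `shared + delta[d]`, and an L2 penalty keeps the domain-specific offsets small so that common structure is pushed into `shared`. An unseen domain is then scored with `shared` alone. The class name, loss, and penalty strength are hypothetical.

```python
def dot(x, w):
    return sum(xi * wi for xi, wi in zip(x, w))


class SharedPlusDomainLinear:
    """Toy linear regressor whose weights for domain d are shared + delta[d]."""

    def __init__(self, dim, domains, delta_l2=1.0):
        self.shared = [0.0] * dim  # domain-agnostic component, kept for unseen domains
        self.delta = {d: [0.0] * dim for d in domains}  # domain-specific offsets
        self.delta_l2 = delta_l2  # L2 strength that keeps offsets small

    def predict_seen(self, x, domain):
        w = [s + b for s, b in zip(self.shared, self.delta[domain])]
        return dot(x, w)

    def predict_unseen(self, x):
        # An unseen target domain has no learned offset, so only shared weights are used.
        return dot(x, self.shared)

    def sgd_step(self, x, y, domain, lr=0.1):
        # Gradient step on 0.5 * squared error + 0.5 * delta_l2 * ||delta[domain]||^2.
        err = self.predict_seen(x, domain) - y
        delta = self.delta[domain]
        for j, xj in enumerate(x):
            g = err * xj
            self.shared[j] -= lr * g
            delta[j] -= lr * (g + self.delta_l2 * delta[j])
```

When every domain agrees on the underlying relation, the offsets decay toward zero and `predict_unseen` recovers it. When domains disagree, the offsets absorb the domain-specific part instead.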
More recently, meta-learning has been proposed for domain generalization. Several methods li2018learning; balaji2018metareg; li2019feature leverage meta-learning frameworks for this purpose.
Adversarial Training
The idea of adversarial training was originally proposed in the field of image generation goodfellow2014generative, known as the Generative Adversarial Network (GAN). GAN has also been adopted in text generation.
Adversarial training has also been used for domain adaptation and domain generalization. The Domain-Adversarial Neural Network (DANN) has two classifiers: one predicts task-specific class labels, and the other predicts whether the data belong to the source or the target domain. Recently, one approach
Table 1: Statistics and descriptions of the out-of-domain validation datasets.

Dataset | Samples | Avg. question length | Avg. passage length | Source
---|---|---|---|---
BioASQ (BA) | 1,504 | 16.4 | 353.9 | Biomedical literature
DROP (DP) | 1,503 | 12.0 | 268.4 | Wiki + National Football League (NFL) game summaries and history articles
DuoRC (DR) | 1,501 | 9.8 | 798.9 | Wiki + IMDb
RACE (RA) | 674 | 12.4 | 381.0 | English exams for Chinese middle and high school students
RelationExtraction (RE) | 2,948 | 11.6 | 38.0 | Wiki (WikiReading dataset)
TextbookQA (TQ) | 1,503 | 12.1 | 751.0 | 1k lessons and 26k multi-modal questions from a middle school science curriculum
We assume that there exists a domain-invariant feature representation with which the QA model generalizes well when predicting answers on unseen out-of-domain data. To adapt to out-of-domain data, we leverage an adversarial learning procedure to learn this domain-invariant representation. We present the proposed method in detail in the following sections.
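We formulate the task as follows. We are given $K$ in-domain datasets $\mathcal{D}_k = \{(p_i, q_i, a_i)\}_{i=1}^{N_k}$, $k = 1, \dots, K$, each consisting of triplets of a passage $p$, a question $q$, and an answer $a$. The model learned from $\{\mathcal{D}_k\}_{k=1}^{K}$ predicts the answer $a$ from $(p, q)$ for each out-of-domain dataset $\mathcal{D}'$.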
Table 2: EM and F1 scores on the out-of-domain validation datasets (top) and test datasets (bottom).

Model | BA EM | BA F1 | DP EM | DP F1 | DR EM | DR F1 | RA EM | RA F1 | RE EM | RE F1 | TQ EM | TQ F1 | Avg EM | Avg F1
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
Bert-base | 46.44 | 60.81 | 28.31 | 37.88 | 42.78 | 53.32 | 28.23 | 39.51 | 73.33 | 83.89 | 44.30 | 52.03 | 43.90 | 54.57
Bert-base-adv | 43.35 | 60.04 | 30.51 | 40.01 | 45.97 | 57.89 | 26.50 | 39.73 | 72.67 | 83.53 | 45.62 | 55.67 | 44.10 | 56.15

Model | BP EM | BP F1 | CQ EM | CQ F1 | MC EM | MC F1 | MR EM | MR F1 | ST EM | ST F1 | TR EM | TR F1 | Avg EM | Avg F1
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
Bert-base | 38.36 | 57.38 | 47.40 | 55.29 | 54.16 | 66.12 | 47.83 | 64.81 | 58.64 | 77.02 | 36.73 | 53.96 | 47.19 | 62.43
Bert-base-adv | 42.92 | 61.09 | 48.13 | 56.50 | 55.83 | 69.30 | 52.82 | 68.78 | 52.73 | 75.63 | 39.08 | 56.79 | 48.59 | 64.68
Our method can be applied to any QA model that learns representations in the joint embedding space of passage and question. In this paper, we use BERT for QA because it is pre-trained on a large corpus and is known to generalize well across several different tasks. As in the standard QA task, the model is trained to minimize the negative log-likelihood of the answer over all given in-domain datasets. Here $N$ is the total number of in-domain samples, and $y_i^{s}$ and $y_i^{e}$ are the start and end positions of the answer in the $i$-th passage. $\theta$ denotes the parameters of the QA model.
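$$\mathcal{L}_{\mathrm{QA}} = -\frac{1}{N}\sum_{i=1}^{N}\left[\log P_\theta(y_i^{s} \mid p_i, q_i) + \log P_\theta(y_i^{e} \mid p_i, q_i)\right] \tag{1}$$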
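Equation (1) can be sketched in plain Python as follows. The sketch assumes that the QA model outputs one start score and one end score per passage token, as BERT-style span extractors do. The function names are hypothetical. Some implementations, such as HuggingFace's `BertForQuestionAnswering`, average the start and end losses instead of summing them; this only rescales the loss by a constant factor.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def qa_nll(start_logits, end_logits, start_positions, end_positions):
    """Equation (1): average over N examples of -[log P(y^s) + log P(y^e)].

    start_logits[i] / end_logits[i] hold one score per token of example i;
    start_positions[i] / end_positions[i] are the gold indices y_i^s and y_i^e.
    """
    n = len(start_positions)
    total = 0.0
    for s_logits, e_logits, ys, ye in zip(start_logits, end_logits,
                                          start_positions, end_positions):
        total -= log_softmax(s_logits)[ys] + log_softmax(e_logits)[ye]
    return total / n
```

With all-zero logits over four tokens, each position has probability 1/4, so `qa_nll([[0.0] * 4], [[0.0] * 4], [1], [2])` equals `2 * math.log(4)`.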
Minimizing the cross-entropy in equation (1) does not ensure that the model will generalize to unseen domains. Rather, it tends to overfit to certain datasets. Inspired by GAN goodfellow2014generative, we propose a simple yet effective method that regularizes the model so that it learns domain-invariant features.
In the adversarial training procedure, the QA model learns to make the discriminator uncertain about its predictions. The discriminator, on the other hand, is trained to classify the joint embedding of question and passage produced by the QA model into the given domains. Suppose the QA model can project questions and passages into an embedding space where the discriminator cannot distinguish embeddings from different domains. We then assume that the QA model has learned a domain-invariant feature representation.
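We formulate the adversarial training as follows. A discriminator with parameters $\phi$ is trained to minimize the cross-entropy loss in equation (2). Here $l_i$ is the domain category of the $i$-th sample, $h_i$ is the hidden representation of both its question and passage, and $P_\phi(l \mid h)$ is the discriminator's predicted distribution over domains. In our experiments, we use the [CLS] token representation from BERT as $h$.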
$$\mathcal{L}_{D} = -\frac{1}{N}\sum_{i=1}^{N}\log P_\phi(l_i \mid h_i) \tag{2}$$
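Below is a minimal sketch of the discriminator loss in equation (2). This section does not specify the discriminator's architecture, so a single linear layer over the [CLS] vector is assumed here purely for illustration. `linear_scores` and `discriminator_loss` are hypothetical names, and domain labels are 0-based indices.

```python
import math


def linear_scores(h, weights, bias):
    """Hypothetical one-layer discriminator: one logit per domain."""
    return [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weights, bias)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def discriminator_loss(hidden, labels, weights, bias):
    """Equation (2): mean of -log P_phi(l_i | h_i) over the batch.

    hidden[i] is the [CLS] vector h_i; labels[i] is its 0-based domain index l_i.
    """
    total = 0.0
    for h, label in zip(hidden, labels):
        probs = softmax(linear_scores(h, weights, bias))
        total -= math.log(probs[label])
    return total / len(labels)
```

An untrained discriminator with all-zero weights assigns probability 1/K to every domain, so its loss is log K. Training it to separate domains drives the loss toward zero.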
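The QA model, in turn, tries to maximize the entropy of $P_\phi(l \mid h)$, that is, to push the discriminator's prediction toward the uniform distribution. Concretely, it minimizes the Kullback-Leibler (KL) divergence between the uniform distribution $\mathcal{U}(l)$ over the domains and the discriminator's prediction: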
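$$\mathcal{L}_{\mathrm{adv}} = \frac{1}{N}\sum_{i=1}^{N}\mathrm{KL}\left(\mathcal{U}(l) \,\|\, P_\phi(l \mid h_i)\right) \tag{3}$$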
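The following sketch computes equation (3) from the discriminator's softmax outputs. It also shows one way it could be combined with the QA loss of equation (1). The weighting coefficient `lam` and the simple weighted sum are assumptions for illustration: the text only states that the model needs one additional hyperparameter. In the alternating scheme we assume here, this combined loss updates only the QA model, while the discriminator is updated separately with equation (2).

```python
import math


def kl_uniform_to_pred(probs):
    """KL(U || P) for one example, where U is uniform over len(probs) domains."""
    u = 1.0 / len(probs)
    return sum(u * (math.log(u) - math.log(p)) for p in probs)


def adversarial_loss(batch_probs):
    """Equation (3): batch mean of KL(U(l) || P_phi(l | h_i))."""
    return sum(kl_uniform_to_pred(p) for p in batch_probs) / len(batch_probs)


def qa_side_objective(qa_loss, batch_probs, lam):
    """Hypothetical combined loss for the QA model: eq. (1) + lam * eq. (3).

    Assumed to be minimized w.r.t. the QA model's parameters only, with the
    discriminator held fixed for this step.
    """
    return qa_loss + lam * adversarial_loss(batch_probs)
```

Equation (3) is zero exactly when the discriminator's prediction is uniform. That is also where the entropy of $P_\phi(l \mid h)$ is maximal, so both views of the QA model's goal agree at the optimum.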
We validate our adversarial model for the MRQA Shared Task on six out-of-domain datasets: BioASQ (BA) tsatsaronis2012bioasq, DROP (DP) dua2019drop, DuoRC (DR) saha2018duorc, RACE (RA) lai2017race, RelationExtraction (RE) levy2017zero, and TextbookQA (TQ) kembhavi2017you. Table 1 shows the statistics and descriptions of these datasets. Most of them have about 1.5k samples, but sizes vary, from 674 for RACE to 2,948 for RelationExtraction. Because the number of samples differs across datasets, we use stratified sampling to build class-balanced stochastic mini-batches that contain a fixed number of samples from every domain. We use maximum sequence lengths of 64 for questions and 384 for passages. Since some passages are longer than 384 tokens, each passage is split into several chunks with a window size of 128. We discard samples without answers because all questions in the MRQA Shared Task are considered answerable from the given context.
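The two preprocessing steps above can be sketched as follows. We read "window size of 128" as the stride between consecutive chunks, as in BERT's `run_squad.py`, whose `doc_stride` flag defaults to 128. This interpretation is an assumption. The text also does not say how many samples per domain go into a mini-batch, so `per_domain` is a hypothetical parameter. Both helper names are illustrative only.

```python
import random


def chunk_passage(tokens, max_len=384, stride=128):
    """Split a token list into overlapping windows of at most `max_len` tokens.

    A new window starts every `stride` tokens until the end of the passage is covered.
    """
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + max_len])
        if start + max_len >= len(tokens):
            return chunks
        start += stride


def stratified_batch(datasets, per_domain, rng):
    """Draw `per_domain` samples from each domain so every domain is equally represented.

    `datasets` maps a domain label to its list of samples; returns (sample, label) pairs.
    """
    batch = []
    for label, samples in datasets.items():
        batch.extend((s, label) for s in rng.sample(samples, per_domain))
    rng.shuffle(batch)  # mix domains within the mini-batch
    return batch
```

A 1,000-token passage becomes six overlapping chunks starting at tokens 0, 128, ..., 640. A stratified batch over two domains of very different sizes still contains the same number of samples from each, which keeps the discriminator's domain classes balanced.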
Note that the final evaluation shown in Table 2 was conducted by the MRQA organizers on six additional, undisclosed out-of-domain test datasets: BioProcess (BP) scaria2013learning, ComplexWebQuestions (CQ) talmor2018web, MCTest (MC) richardson2013mctest, QAMR (MR) michael2017crowdsourcing, QAST (ST) jitkrittum-etal-2009-qast, and TREC (TR) voorhees2001trec.
We implement our model on top of HuggingFace's open-source BERT implementation in PyTorch and follow the BERT hyperparameters. Specifically, we use "bert-base-uncased" with a learning rate of 3e-5 and a batch size of 64. Additionally, our model requires one more hyperparameter provided by MRQA. Models are trained for 1 or 2 epochs, and we select the best-performing model on the validation set. The code for our model is available at
Table 2 shows the evaluation results of the models on the out-of-domain datasets. Models trained with our adversarial learning are suffixed with '-adv'. The top of the table shows results on the validation datasets, and the bottom shows results on the test datasets. Overall, the model with adversarial learning outperforms the baseline in both average EM and average F1.
On the validation datasets, the average F1 score of our model is about 1.6 points higher than that of the baseline (56.15 vs. 54.57). Specifically, our model outperforms the baseline in F1 by a large margin on DP, DR, and TQ, and marginally on RA. Adversarial learning degrades performance on BA and RE. EM scores show a similar pattern, except that EM also drops on RA. On the test datasets, our model likewise achieves better EM (Exact Match) and F1 on most datasets, with ST as the only exception. Overall, our model outperforms the baseline on the test datasets by a considerable margin of over 2 points in average F1 (64.68 vs. 62.43).
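For reference, EM and F1 can be computed per question as sketched below. The sketch follows the answer normalization of the SQuAD v1.1 evaluation script: lowercasing, stripping punctuation and articles, and collapsing whitespace. We assume, but do not verify here, that the MRQA scorer mirrors this. Dataset-level scores such as those in Table 2 are then means over questions, reported as percentages.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """SQuAD-style normalization: lowercase, drop punctuation and articles, squash spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, gold):
    return float(normalize_answer(prediction) == normalize_answer(gold))


def f1_score(prediction, gold):
    """Token-level F1 between normalized prediction and gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_golds(metric, prediction, golds):
    """Score against every acceptable answer and keep the best one."""
    return max(metric(prediction, g) for g in golds)
```

For example, "The Eiffel Tower!" and "eiffel tower" are an exact match after normalization. "Eiffel Tower in Paris" scored against "the Eiffel Tower" gets precision 1/2 and recall 1, so F1 is 2/3.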
In this section, we discuss some approaches that failed to improve performance but might be helpful for future work.
A QA sample consists of a question, a passage, and an answer span. There may be multiple answer spans because more than one phrase in the passage can match the answer text. For simplicity, most baseline code uses only the first occurrence of the answer text for training. However, given the context and semantics of the question and answer, some phrases in the passage are more likely than others to be the correct answer span. To find the most plausible answer span, we encode the question and each sentence of the passage into fixed-size vectors with the Universal Sentence Encoder. We then choose the answer occurrence in the sentence most similar to the question, in terms of cosine similarity, as the gold span. In our experiments, this approach boosts performance on some datasets but substantially degrades it on others.
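The span-selection heuristic can be sketched as follows. The Universal Sentence Encoder is not reproduced here. `embed` is a hypothetical bag-of-words stand-in so the example runs without external models; a real sentence encoder would be swapped in to follow the described setup. The naive sentence splitter is also an assumption, and answers that cross a sentence boundary are not handled.

```python
import math
import re
from collections import Counter


def embed(text):
    """Stand-in for a sentence encoder: bag-of-words counts (hypothetical)."""
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a, b):
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def sentences_with_offsets(passage):
    """Naive splitter on . ! ? that keeps each sentence's character offset."""
    return [(m.start(), m.group()) for m in re.finditer(r"[^.!?]+[.!?]?", passage)]


def pick_answer_span(question, passage, answer):
    """Return (start, end) character offsets of `answer` inside the sentence most
    similar to the question, or None if the answer text does not occur."""
    q_vec = embed(question)
    best_score, best_span = -1.0, None
    for offset, sentence in sentences_with_offsets(passage):
        idx = sentence.find(answer)
        if idx == -1:
            continue  # this sentence cannot contain the gold span
        score = cosine(q_vec, embed(sentence))
        if score > best_score:
            best_score = score
            best_span = (offset + idx, offset + idx + len(answer))
    return best_span
```

Consider the passage "Paris is large. The capital of France is Paris." with the question "What is the capital of France?". The first-occurrence rule would pick the "Paris" at offset 0. This heuristic picks the occurrence in the second sentence, which shares more words with the question.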
We apply meta-learning for domain generalization li2018learning; li2019feature; balaji2018metareg to simulate train/test domain shift. In every epoch, one dataset is randomly selected as a virtual test domain. As described in maml, the QA model is trained to optimize a meta objective that improves performance not only on the training domains but also on the virtual test domain. However, this requires computing Hessian-vector products, which slows down training. The cost is even worse for BERT, which has 110M parameters to fine-tune. Moreover, contrary to previous work, meta-learning for domain generalization did not improve performance in our experiments.
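The source of the Hessian term can be shown on a toy problem. The sketch below uses an MLDG-style objective, our reading of li2018learning: $F(\theta) + \beta\, G(\theta - \alpha F'(\theta))$, where $F$ is the loss on the meta-train domains and $G$ the loss on the virtual test domain. Each "domain" is a hypothetical 1-D quadratic loss, so the second derivative is a constant and the full meta-gradient can be checked exactly. The values of `alpha`, `beta`, `lr`, and the domain centers are illustrative assumptions, not the settings used in our experiments.

```python
import random


def domain_loss(theta, center):
    """Toy per-domain loss: a quadratic bowl centered at `center`."""
    return (theta - center) ** 2


def meta_objective(theta, train_centers, test_center, alpha, beta):
    """MLDG-style objective F(theta) + beta * G(theta - alpha * F'(theta))."""
    f = sum(domain_loss(theta, c) for c in train_centers) / len(train_centers)
    f_grad = sum(2 * (theta - c) for c in train_centers) / len(train_centers)
    theta_virtual = theta - alpha * f_grad  # one simulated SGD step on meta-train domains
    return f + beta * domain_loss(theta_virtual, test_center)


def meta_gradient(theta, train_centers, test_center, alpha, beta):
    """Exact derivative of meta_objective with respect to theta."""
    f_grad = sum(2 * (theta - c) for c in train_centers) / len(train_centers)
    f_hess = 2.0  # second derivative of each quadratic domain loss
    theta_virtual = theta - alpha * f_grad
    g_grad = 2 * (theta_virtual - test_center)
    # The chain rule through theta_virtual brings in F''(theta); for a network with
    # millions of parameters this term becomes a Hessian-vector product.
    return f_grad + beta * g_grad * (1 - alpha * f_hess)


def train(centers, epochs=200, lr=0.05, alpha=0.1, beta=1.0, seed=0):
    rng = random.Random(seed)
    theta = 0.0
    for _ in range(epochs):
        test_idx = rng.randrange(len(centers))  # virtual test domain for this epoch
        train_centers = [c for i, c in enumerate(centers) if i != test_idx]
        theta -= lr * meta_gradient(theta, train_centers, centers[test_idx], alpha, beta)
    return theta
```

In this scalar toy, the Hessian term is a single multiplication. For BERT, it requires an extra backward pass through the gradient computation at every step, which is the slowdown described above.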
We leverage adversarial learning to learn domain-invariant features. In our experiments, the proposed method improves the average performance over the baseline on both the validation and test datasets, and it is applicable to any QA model. In future work, we plan to apply adversarial learning to pre-training with a diverse set of domains.
We would like to thank Seonghan Ryu, Donghyeon Lee and Nicolas Remond for their valuable feedback, as well as the anonymous reviewers for their insightful comments.