Recently, an increasing number of studies have suggested the use of general language models for improving natural language processing tasks Dai and Le (2015); Peters et al. (2018); Radford et al.; Howard and Ruder (2018). Among the most promising techniques, the unsupervised pretraining approach Dai and Le (2015); Radford et al. has emerged as a highly successful method that achieves state-of-the-art results on many language tasks, including sentiment analysis Socher et al. (2013), natural language inference Williams et al. (2017), and similarity and paraphrase tasks Dolan and Brockett (2005); Cer et al. (2017). This approach consists of a two-phase training procedure. The first phase trains a general language model on a large corpus without supervision. The second phase applies supervision to fine-tune the model for a given task.
More recently, unsupervised pretraining models such as BERT Devlin et al. (2018), XLNet Yang et al. (2019) and RoBERTa Liu et al. (2019) have achieved unprecedented performance, even exceeding human-level performance on some language tasks. For example, on the GLUE benchmark Wang et al. (2018), BERT Devlin et al. (2018) was reported to exceed human-level performance on several datasets, such as QNLI Rajpurkar et al. (2016), QQP Chen et al. (2018) and MRPC Dolan and Brockett (2005). However, despite the great progress achieved by these task-specific and dataset-specific models, it is not yet clear how well they generalize to different tasks, and how robust they are when the same task is evaluated on different datasets.
The most direct way to estimate the specificity of a learned model is by employing cross-benchmark experiments. These evaluations can be done by using datasets of the same task the model was specialized on (to measure robustness), or by utilizing datasets from different tasks (to measure generalization).
In our work, we build upon the multiverse method of Littwin and Wolf (2016), which was shown to lead to cross-dataset robustness in the computer vision task of face recognition, as well as on the CIFAR-100 small image recognition dataset. The multiverse loss generalizes the cross entropy loss by simultaneously training multiple linear classifier heads to perform the same task. To prevent the heads from becoming copies of the same classifier, the multiverse scheme requires each classifier to be orthogonal to the rest. The number of multiverse heads used was limited: never more than seven, and typically five.
We propose a novel fine-tuning procedure for enhancing the robustness of recent unsupervised pretrained language models by employing a large number of multiverse heads. The essence of our technique is as follows: given a pretrained language model and a downstream task with labeled data, we fine-tune the model using a maximal number of multiverse classifiers. The fine-tuning objective is to minimize both the task loss and an orthogonality loss applied to the classifier heads. When enforcing orthogonality hinders the classifiers' performance, we detect and eliminate the less effective classifier heads.
The technique therefore preserves a maximal set of classifiers, which comprises the best-performing ones. By maintaining this maximal subset during training, our method leverages the multiverse loss to its fullest. Hence, we name our method Maximal Multiverse Learning (MML).
Our contributions are as follows: (1) we present MML, a general training procedure for improving the robustness of neural models; (2) we apply MML to BERT and report its performance on various datasets; and (3) we propose a set of cross-dataset evaluations using common NLP benchmarks, demonstrating the effectiveness of MML compared to regular BERT fine-tuning.
2 Related Work
Recent breakthroughs in NLP are centered around unsupervised pretraining of language models. The different variants can be categorized into two main approaches: (1) feature-based models, such as Peters et al. (2018), and (2) fine-tuning models, such as Devlin et al. (2018); Liu et al. (2019); Yang et al. (2019). The former approach uses a neural language model as a feature extractor; the extracted features are then used as input for training separate models. The latter approach uses a similar pretrained model but fine-tunes it end-to-end to specialize on a given task. During fine-tuning, all of the model's parameters are updated, while only a relatively small number of parameters are trained from scratch.
The use of multiple classifiers can be found in a few places in the literature. In GoogLeNet Szegedy et al. (2015), the authors place multiple classifier heads at different depths of the architecture. The additional classifiers led to better gradient propagation during training. However, with the advent of better conditioning and normalization methods, as well as the introduction of skip connections in architectures such as ResNet He et al. (2016), the practice of adding intermediate branches to introduce a loss at lower levels was mostly abandoned.
The multiverse loss was shown to promote better transfer learning and to lead to a low-dimensional representation in the penultimate layer Littwin and Wolf (2016). However, the current literature does not present any principled way to select the number of multiverse heads, and the idea has only been applied with a handful of parallel classifiers.
In MML, hundreds of multiverse heads are used. An emphasis is put on the resulting multi-term loss, in which classifier accuracy is traded off against the orthogonality constraint. MML balances the two terms by pruning the multiverse classifiers that underperform during training.
This section presents the problem setup, the MML architecture and loss terms, and the training algorithm.
3.1 Problem Setup
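Let $V$ be the vocabulary of all supported tokens in a given language, and let $S$ be the set of all possible sentences that can be generated from $V$; $S$ may also contain the empty sentence. We define $f: S \times S \rightarrow \mathbb{R}^d$ to be a language model that receives pairs of elements from $S$ and returns coding vectors of $d$ dimensions. Given a dataset with $n$ training samples, $\{(s_i^1, s_i^2)\}_{i=1}^{n}$, each associated with a label $y_i$, we denote the coding vector of each sample by $v_i = f(s_i^1, s_i^2)$, where $v_i \in \mathbb{R}^d$. As a concrete example, for the BERT model, $v_i$ is the latent embedding of the CLS token.
Common language models use a classifier $C$ that projects the coding vectors by a matrix $W \in \mathbb{R}^{d \times k}$ ($k$ being the number of classes) and then adds a bias term $b \in \mathbb{R}^k$:
$$C(v_i) = W^{\top} v_i + b \tag{1}$$
The output of $C$ is a vector of $k$ logits, one per class.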
Unlike other language models that use a single classifier as defined above, our model uses a multiverse classifier defined as:
$$M(v_i) = \{C_j(v_i)\}_{j=1}^{m} \tag{2}$$
where $m$ is a multiplicity parameter and $C_1, \dots, C_m$ are parallel classifiers, each with its own weights, applying the same function as Eq. 1:
$$C_j(v_i) = W_j^{\top} v_i + b_j \tag{3}$$
Additionally, we define $\alpha = \{\alpha_j\}_{j=1}^{m}$, with $\alpha_j \in \{0, 1\}$, as a set of binary scalars. Each classifier head $C_j$ is associated with its own binary scalar $\alpha_j$. The binary scalars are set to concrete values during training (see Sec. 3.3). In our experiments we set $m$ equal to the coding vector size $d$, which entails a full rank of active multiverse classifiers at the beginning of training. All in all, our aggregated model is composed of $(f, M)$.
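To make the multiverse classifier concrete, the following plain-Python sketch applies $m$ independent linear heads to one shared coding vector and skips every head whose binary scalar is 0. It is an illustration under assumptions, not the authors' implementation: the helper names (`linear_head`, `init_heads`, `multiverse_forward`) and the Gaussian initialization with standard deviation 0.02 are hypothetical choices made only for this example.

```python
import random
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # d rows x k columns
Head = Tuple[Matrix, Vector]  # (W_j, b_j)


def linear_head(v: Vector, W: Matrix, b: Vector) -> Vector:
    """Compute W^T v + b for W of shape d x k and b of length k (as in Eq. 1)."""
    d, k = len(W), len(b)
    return [sum(W[r][c] * v[r] for r in range(d)) + b[c] for c in range(k)]


def init_heads(m: int, d: int, k: int, seed: int = 0) -> List[Head]:
    """Randomly initialize m heads from scratch (hypothetical N(0, 0.02^2) init, zero bias)."""
    rng = random.Random(seed)
    return [
        ([[rng.gauss(0.0, 0.02) for _ in range(k)] for _ in range(d)], [0.0] * k)
        for _ in range(m)
    ]


def multiverse_forward(v: Vector, heads: List[Head], alpha: List[int]) -> Dict[int, Vector]:
    """Return {j: C_j(v)} only for heads whose binary scalar alpha_j equals 1."""
    return {j: linear_head(v, W, b) for j, (W, b) in enumerate(heads) if alpha[j] == 1}
```

Following the text, one would start with $m = d$ heads and `alpha = [1] * m`, so every head is active at the beginning of training; a real implementation would batch the heads into a single tensor for efficiency.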
3.2 The Loss Function
Our loss function is composed of two components: the task loss and the multiverse loss. The task loss is set according to the task at hand; its purpose is to optimize the performance of every active multiverse classifier, each independently, using the supervision obtained from the given labels. The active multiverse classifiers are those that survive the dynamic elimination used during training (see Sec. 3.3 for more details), and are associated with a value $\alpha_j = 1$. The multiverse loss soft-enforces orthogonality among the active classifiers; its purpose is to regularize the model by encouraging $f$ to produce coding vectors that are robust enough to be effective for a large number of orthogonal classifiers.
As mentioned earlier, each classifier $C_j$ is associated with a binary value $\alpha_j$, which controls whether the classifier is applied and is configured during training. In the context of the loss function, setting $\alpha_j$ to 0 eliminates the impact of the classifier head on both the task loss and the multiverse loss.
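For a multi-class classification task we apply the following task loss:
$$\mathcal{L}_{task} = \frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{m}\alpha_j\,\ell_{CE}\big(C_j(v_i), y_i\big) \tag{4}$$
where $n$ is the number of training samples, and $\ell_{CE}$ is the cross entropy loss
$$\ell_{CE}(z, y) = -\log\frac{\exp(z_y)}{\sum_{c=1}^{k}\exp(z_c)} \tag{5}$$
For a binary classification task we set $k = 2$ and use the same loss from Eq. 5. For a regression task, each head outputs a single scalar $z$, and we replace $\ell_{CE}$ with the squared error $\ell_{MSE}$:
$$\ell_{MSE}(z, y) = (z - y)^2 \tag{6}$$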
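The task loss can be sketched as a masked sum: each active head is scored against the label on its own, inactive heads contribute nothing, and the total is averaged over the $n$ samples. This is a minimal plain-Python illustration; the function names are hypothetical, and the squared error used for regression is an assumption about the regression loss rather than a confirmed detail.

```python
import math
from typing import Callable, List, Sequence


def cross_entropy(z: Sequence[float], y: int) -> float:
    """-log softmax(z)[y], computed with the log-sum-exp trick for numerical stability."""
    mx = max(z)
    log_sum_exp = mx + math.log(sum(math.exp(zc - mx) for zc in z))
    return log_sum_exp - z[y]


def squared_error(z: Sequence[float], y: float) -> float:
    """Assumed regression loss (z - y)^2 for a head with a single output."""
    return (z[0] - y) ** 2


def task_loss(
    head_outputs: List[List[Sequence[float]]],  # head_outputs[i][j] = C_j(v_i)
    labels: Sequence,
    alpha: Sequence[int],
    loss_fn: Callable = cross_entropy,
) -> float:
    """(1/n) * sum_i sum_j alpha_j * loss(C_j(v_i), y_i); each active head is trained independently."""
    n = len(labels)
    total = 0.0
    for outputs_i, y_i in zip(head_outputs, labels):
        for j, z in enumerate(outputs_i):
            if alpha[j] == 1:  # inactive heads (alpha_j = 0) are skipped entirely
                total += loss_fn(z, y_i)
    return total / n
```

Because every active head receives its own loss term, the gradient reaching the shared coding vector is a sum over all active heads, which is what pushes $f$ toward representations that many classifiers can use.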
The second loss term enforces orthogonality between the set of classifiers, for each class separately. In our work, orthogonality is enforced on the weights of each classifier, using the multiverse loss:
$$\mathcal{L}_{MV} = \sum_{c=1}^{k}\sum_{j=1}^{m}\sum_{l=j+1}^{m}\alpha_j\,\alpha_l\left|\big(W_j^{c}\big)^{\top} W_l^{c}\right| \tag{7}$$
where $W_j^c$ is the $c$-th column of the weight matrix of classifier $C_j$. As mentioned, we use $\alpha_j$ to allow the training algorithm to dynamically eliminate the less effective multiverse classifiers during training (see Sec. 3.3).
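A minimal sketch of the orthogonality penalty, assuming it takes the form of the absolute dot product between the $c$-th weight columns of every pair of active heads, summed over classes. The exact penalty form (absolute vs. squared dot product) is an assumption here, and the helper names are hypothetical.

```python
from itertools import combinations
from typing import List

Matrix = List[List[float]]  # d rows x k columns


def column(W: Matrix, c: int) -> List[float]:
    """Return the c-th column of a d x k weight matrix."""
    return [row[c] for row in W]


def multiverse_loss(weights: List[Matrix], alpha: List[int]) -> float:
    """Sum over classes c and active head pairs j < l of |<W_j^c, W_l^c>| (assumed penalty form)."""
    active = [j for j, a in enumerate(alpha) if a == 1]  # eliminated heads are ignored
    k = len(weights[0][0])
    total = 0.0
    for c in range(k):
        cols = {j: column(weights[j], c) for j in active}
        for j, l in combinations(active, 2):
            total += abs(sum(x * y for x, y in zip(cols[j], cols[l])))
    return total
```

The penalty is zero exactly when, for every class, the active heads' weight columns are pairwise orthogonal. Since at most $d$ nonzero vectors in $\mathbb{R}^d$ can be mutually orthogonal, starting with $m = d$ heads is the largest set for which the penalty can, in principle, be driven to zero; the number of pairs grows quadratically in $m$, so a practical version would vectorize this as a Gram-matrix computation with an autodiff framework.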
The total loss is defined as:
$$\mathcal{L} = \mathcal{L}_{task} + \mathcal{L}_{MV} \tag{8}$$
3.3 Maximal Multiverse Training
The training algorithm begins by initializing the aggregated model $(f, M)$. $f$ may be initialized from any pretrained general language model. The multiverse classifiers are randomly initialized from scratch, and all classifiers are initially activated by setting $\alpha_j = 1$ for $j = 1, \dots, m$.
During training, we track the performance of each multiverse classifier separately. Every $T$ steps, we search for a subset of top-performing classifiers. When such a subset is found, we eliminate the lower-performing classifiers by setting their corresponding $\alpha_j$ values to 0.
To detect the top-performing subset of classifiers, we maintain a moving average variable $a_j$ for each multiverse classifier. Specifically, $a_j$ holds the moving average of the task loss value associated with classification head $C_j$. $a_j$ is updated at every training step, using a moving average momentum constant of 0.99.
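The per-head moving average can be sketched as a standard exponential moving average with momentum 0.99, i.e. $a_j \leftarrow 0.99\,a_j + 0.01\,\ell_j$, where $\ell_j$ is the current task loss of head $C_j$. The text does not say how $a_j$ is initialized, so seeding it with the first observed loss is an assumption of this sketch, as is the function name.

```python
from typing import List, Optional


def update_moving_averages(
    avg: List[Optional[float]],
    step_losses: List[float],
    alpha: List[int],
    momentum: float = 0.99,
) -> List[Optional[float]]:
    """a_j <- momentum * a_j + (1 - momentum) * loss_j for every active head.

    Assumption: a head with no history yet (None) is seeded with its first
    observed loss. Eliminated heads (alpha_j = 0) keep their last value.
    """
    updated = []
    for a, loss, active in zip(avg, step_losses, alpha):
        if not active:
            updated.append(a)
        elif a is None:
            updated.append(loss)
        else:
            updated.append(momentum * a + (1.0 - momentum) * loss)
    return updated
```

With momentum 0.99, each $a_j$ effectively averages roughly the last 100 steps, which smooths out batch-to-batch noise before heads are compared.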
During training, every $T$ steps, we run the MeanShift algorithm Comaniciu and Meer (2002) on the set $\{a_j \mid \alpha_j = 1\}$. MeanShift is a clustering algorithm that analyzes the underlying density function of the samples: it reveals the number of clusters in the given data and retrieves a centroid for each detected cluster. Using MeanShift, we define the subset of top-performing multiverse heads as the cluster with the minimal centroid value. Next, we eliminate the remaining multiverse heads by setting their corresponding $\alpha_j$ to 0. This adaptive elimination stops when a minimal number of active heads is reached; see Alg. 1.
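The elimination step can be sketched as follows: cluster the moving-average losses of the active heads with a one-dimensional, flat-kernel mean shift, keep the cluster whose centroid is lowest, and deactivate the rest. This is a minimal sketch under assumptions: the `bandwidth` and `min_active` values, the rule that elimination is skipped once the number of active heads is at or below `min_active`, and the mode-merging threshold are all hypothetical choices, not values from the text.

```python
from typing import List, Tuple


def mean_shift_1d(points: List[float], bandwidth: float,
                  tol: float = 1e-9, max_iter: int = 300) -> Tuple[List[float], List[int]]:
    """Flat-kernel mean shift on scalars; returns ascending centroids and a cluster label per point."""
    modes = []
    for p in points:
        x = p
        for _ in range(max_iter):
            # x always stays within `bandwidth` of some point, so the window is non-empty
            window = [q for q in points if abs(q - x) <= bandwidth]
            new_x = sum(window) / len(window)
            if abs(new_x - x) < tol:
                break
            x = new_x
        modes.append(x)
    # Merge modes that converged to (nearly) the same location into one centroid.
    centroids: List[float] = []
    for x in sorted(modes):
        if not centroids or x - centroids[-1] > bandwidth / 2:
            centroids.append(x)
    labels = [min(range(len(centroids)), key=lambda c: abs(centroids[c] - x)) for x in modes]
    return centroids, labels


def eliminate_heads(avg: List[float], alpha: List[int],
                    bandwidth: float, min_active: int) -> List[int]:
    """Keep only the cluster of active heads with the lowest mean loss; deactivate the rest."""
    active = [j for j, a in enumerate(alpha) if a == 1]
    if len(active) <= min_active:  # hypothetical stopping rule: elimination has ended
        return list(alpha)
    centroids, labels = mean_shift_1d([avg[j] for j in active], bandwidth)
    if len(centroids) < 2:  # a single cluster: no subset stands out yet
        return list(alpha)
    new_alpha = list(alpha)
    for j, label in zip(active, labels):
        if label != 0:  # centroids are ascending, so cluster 0 has the minimal centroid
            new_alpha[j] = 0
    return new_alpha
```

In practice, scikit-learn's `sklearn.cluster.MeanShift` (with its `bandwidth` parameter, and `estimate_bandwidth` to pick one) is a ready-made alternative; its `cluster_centers_` are not guaranteed to be sorted, so the minimal-centroid cluster should be selected explicitly.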
At inference, we use the active multiverse heads to produce predictions. Specifically, given a sample $(s^1, s^2)$, we calculate the logits as:
$$z = \frac{1}{\sum_{j=1}^{m}\alpha_j}\sum_{j=1}^{m}\alpha_j\,C_j\big(f(s^1, s^2)\big) \tag{9}$$
For classification tasks, we apply the softmax function to $z$ and return its output. For regression tasks, we simply return $z$.
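A minimal inference sketch, assuming the logits are the average of the active heads' outputs (a plain sum would give the same classification decision but would scale regression outputs by the number of heads). The function name and the `task` switch are hypothetical.

```python
import math
from typing import List, Sequence


def predict(head_outputs: List[Sequence[float]], alpha: List[int],
            task: str = "classification") -> List[float]:
    """Average the outputs of the active heads; apply a softmax for classification."""
    active = [z for z, a in zip(head_outputs, alpha) if a == 1]
    k = len(active[0])
    logits = [sum(z[c] for z in active) / len(active) for c in range(k)]
    if task == "regression":
        return logits  # a single averaged value when each head has one output
    mx = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(x - mx) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

Eliminated heads are ignored entirely, so a head pruned during training has no influence on the final prediction.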
In this study, we evaluate MML applied to a pretrained BERT Devlin et al. (2018) model, using nine NLP datasets in two different settings: (1) straightforward fine-tuning on different downstream tasks from the GLUE benchmark Wang et al. (2018), and (2) cross-dataset evaluations on different datasets of the same or a similar task.
In the first setting, we fine-tune MML on each dataset separately and evaluate its performance on the development set and the test set of the same dataset. In the second, we evaluate our fine-tuned MML models on the train and development sets of other datasets within the same task category. This allows us to study the robustness of all models across different datasets.
In addition, we perform an ablation study and report empirical results that showcase the effectiveness of MML and its variants, compared to a baseline BERT.
4.1 The Datasets
We adopt eight datasets from the GLUE benchmark Wang et al. (2018), and one additional dataset for the task of Natural Language Inference (NLI). The datasets are arranged by category as follows.
4.1.1 Inference Tasks
RTE The Recognizing Textual Entailment dataset Bentivogli et al. (2009) is composed of sentence pairs gathered from various online news sources. The task is to predict whether the second sentence is an entailment of the first sentence (binary classification).
MNLI The Multi-Genre Natural Language Inference Corpus Williams et al. (2017) is a dataset of sentence pairs with textual entailment annotations. For each pair of sentences, the task is to determine whether the second sentence is entailed by, contradicts, or is neutral with respect to the first one (multiclass classification).
SNLI The Stanford Natural Language Inference dataset Bowman et al. (2015) also contains sentence pairs. The task here is identical to MNLI, with the same three labels (multiclass classification). However, the two datasets were gathered from different sources.
QNLI The Question-answering Natural Language Inference dataset Rajpurkar et al. (2016) contains question-sentence pairs. The task is to determine whether a sentence contains the answer to its corresponding question (binary classification).
4.1.2 Similarity and Paraphrase Tasks
This category contains three datasets.
MRPC Microsoft Research Paraphrase Corpus Dolan and Brockett (2005) is a dataset of sentence pairs taken from online news websites. The task is to determine whether a pair of sentences are semantically equivalent (binary classification).
QQP The Quora Question Pairs dataset Chen et al. (2018) consists of question pairs taken from the Quora website. The goal is to determine whether a pair of questions are semantically equivalent (binary classification).
STS-B Semantic Textual Similarity Benchmark Cer et al. (2017) is a dataset composed of sentence pairs extracted from news headlines, video and image captions, and natural language inference data. Each pair is annotated with a score between 1 and 5, indicating the semantic similarity level of both sentences. The task is to predict these scores (regression).
4.1.3 Misc. datasets
There are two datasets in this category. The two datasets are not used for the cross dataset evaluation, due to the lack of commonality between their tasks.
CoLA The Corpus of Linguistic Acceptability Warstadt et al. (2018) consists of expert English sentence acceptability judgments drawn from multiple books. Each sample in this dataset is a sequence of English words annotated with whether it is a grammatical English sentence (binary classification).
SST-2 The Stanford Sentiment Treebank Socher et al. (2013) is a dataset of sentences extracted from movie reviews. The sentences carry human annotations of their sentiment, and the task is to determine whether the sentiment of each sentence is positive or negative (binary classification).
4.2 Evaluation on the GLUE Benchmark
We evaluated MML on eight different datasets from the GLUE benchmark and compared it to BERT Devlin et al. (2018). In addition, we conduct an ablation analysis of our method, demonstrating the importance of Maximal Multiverse Training, which allows training to adapt the number of multiverse classifiers to each dataset. The ablation disables the classifier elimination step during training and uses the same MML architecture with a fixed number of heads. Each model was trained and evaluated on a single dataset. Development and test set performance is reported for each model.
4.2.1 The Models
The BERT model we use is the BERT-Large model from Devlin et al. (2018). It contains 24 attention layers, each with 16 attention heads, and a hidden layer size of $d = 1024$ dimensions. The model was pretrained on sentence pairs, both to reconstruct masked words and to predict whether sentence pairs are consecutive. BERT's fine-tuning for downstream tasks employs supervision obtained from the given labels of each dataset.
MML uses a pretrained BERT-Large model and fine-tunes it via Maximal Multiverse Training to minimize the loss presented above. The MML model is initialized with $m = 1024$ active multiverse classifiers, equal to the hidden layer size $d$. During training, the model converges to a smaller number of multiverse classifiers. The number of active multiverse classifiers of each model is presented in the last row of Tab. 1.
Tab. 1 presents the results for the following models: (1) BERT (used as a baseline), (2) MML, and (3-4) two ablation multiverse models using a fixed number of multiverse classifiers, with 5 and 1024 classifiers, respectively.
As can be seen in the table, compared to BERT, MML yields significantly better results on the test set of four out of eight datasets. The largest gains are on the relatively small datasets, such as RTE, MRPC and STS-B, for which MML yields absolute improvements of 4.1, 1.3 and 1.6 points, respectively. We attribute this to the ability of MML to encourage more robust learning. On the rest of the datasets, MML yields similar test performance, except for CoLA, where a degradation of 1.9 points is observed. On the development set, MML outperforms BERT on all datasets, sometimes by a large margin. Specifically, for RTE and CoLA, MML yields improvements of almost ten and four points, respectively.
The ablation models MV-5 and MV-1024 use a fixed number of multiverse heads throughout training. We found that this hyperparameter can be crucial for model convergence and, when not set properly, may significantly reduce performance on the task at hand. Specifically, for the CoLA dataset, MV-1024 and MV-5 exhibit a relative performance gap of more than 11% in favor of MV-5, while on RTE there is a gap of 6.2% in favor of MV-1024. Compared with both MV-5 and MV-1024, MML produces better or similar performance on the development set of all datasets. More specifically, on RTE and MRPC, MML performs similarly to MV-1024, and it outperforms MV-1024 on the other six datasets. Compared to MV-5, MML yields significantly better performance on four datasets out of eight, and similar performance on the rest.
Fig. 2 presents the number of active multiverse heads when applying MML to the QNLI dataset. During the training of the MML-QNLI model, the MeanShift algorithm detected multiple clusters three times. Since an elimination is invoked every time MeanShift detects multiple clusters, the model eliminated heads three times, each time discarding the lower-performing subsets and keeping the top-performing multiverse classifiers as the active set. The model achieved its best development set performance at training step 85K, at which point MML-QNLI used 23 active multiverse heads. The plots in the figure show the cumulative loss of each multiverse head, sorted along the X axis according to the indices of the active heads. Red stars mark the classifier heads that were eliminated, and green stars mark the heads that were selected as the top-performing subset.
4.3 Cross Dataset Evaluations
To study the robustness of all models, we perform cross-dataset evaluations. In these evaluations, we use the fine-tuned MML models from Tab. 1. Each model trained on a dataset from the first two categories above (Sec. 4.1) is evaluated on all datasets from the same category.
Train and development set performance are reported to give a clear view of the robustness and stability of the models, and to exhibit the level of overfitting when evaluating on the dataset each model was trained on.
For a clean comparison, we fine-tune BERT with the same hyperparameters used for MML. Specifically, we use 10 epochs for the relatively large datasets, 30 epochs for the medium-sized datasets, and 100 epochs for the smaller datasets. All models were trained with a batch size of 32 and a learning rate of 2e-5. Our code can be found at https://github.com/ItzikMalkiel/MML.
4.3.1 Cross Inference Datasets Evaluation
First, we present performance on different inference datasets. We fine-tune both BERT and MML on each dataset separately, and evaluate on four NLI datasets: RTE, MNLI, SNLI and QNLI. Since MNLI and SNLI are multiclass classification tasks with three classes, we collapse the labels "neutral" and "contradiction" into one label ("non-entailment"). This modification, applied only at inference, allows us to evaluate the MNLI and SNLI models on RTE and QNLI, and vice versa.
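The inference-time label collapse can be sketched as a small mapping applied to both predictions and gold labels before computing accuracy, so that three-way and binary NLI datasets become comparable. The label strings and function names are hypothetical choices for illustration; the collapsed class corresponds to "non-entailment" in the text.

```python
from typing import Sequence

# Hypothetical label strings; the collapsed class plays the role of "non-entailment".
ENTAILMENT = "entailment"
NON_ENTAILMENT = "not_entailment"


def collapse_nli_label(label: str) -> str:
    """Map three-way NLI labels (and binary ones) to entailment / non-entailment."""
    if label == ENTAILMENT:
        return ENTAILMENT
    if label in ("neutral", "contradiction", NON_ENTAILMENT):
        return NON_ENTAILMENT
    raise ValueError(f"unknown NLI label: {label!r}")


def cross_nli_accuracy(predicted: Sequence[str], gold: Sequence[str]) -> float:
    """Accuracy after collapsing both predictions and gold labels (applied at inference only)."""
    pairs = [(collapse_nli_label(p), collapse_nli_label(g)) for p, g in zip(predicted, gold)]
    return sum(p == g for p, g in pairs) / len(pairs)
```

Because the mapping is applied only at evaluation time, each model is still fine-tuned on its original label set.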
The results are reported in Tab. 2. As can be seen, MML exhibits significantly improved robustness compared to BERT. Each row in the table represents an MML or BERT model trained on the single dataset indicated by its name. All models are evaluated on all four datasets. In the last column, we report the relative average improvement obtained by MML, calculated from the performance ratio between MML and BERT across the three holdout datasets. For example, for RTE, our MML-RTE model yields a 9.9% relative average improvement on the train sets of MNLI, QNLI and SNLI, and a 9.5% average improvement on the development sets of these datasets.
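One way to compute the relative average improvement is to take the MML/BERT score ratio on each holdout dataset, average the ratios, and express the excess over 1 as a percentage. Whether the ratios are averaged or the scores are averaged first is not specified, so the mean-of-ratios choice here is an assumption, and the scores below are hypothetical placeholders, not results from Tab. 2.

```python
from typing import Dict, Sequence


def relative_average_improvement(mml: Dict[str, float], bert: Dict[str, float],
                                 holdout: Sequence[str]) -> float:
    """Mean over holdout datasets of MML score / BERT score, reported as a % gain (assumed formula)."""
    ratios = [mml[name] / bert[name] for name in holdout]
    return 100.0 * (sum(ratios) / len(ratios) - 1.0)


# Hypothetical scores for illustration only; the training dataset is excluded from the average.
mml_scores = {"trained_on": 0.90, "holdout_a": 0.66, "holdout_b": 0.55}
bert_scores = {"trained_on": 0.90, "holdout_a": 0.60, "holdout_b": 0.50}
gain = relative_average_improvement(mml_scores, bert_scores, ["holdout_a", "holdout_b"])
# gain is approximately 10.0, i.e. a 10% relative improvement
```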
4.3.2 Cross Similarity and Paraphrase Datasets Evaluation
Next, we conduct cross-dataset evaluations on the three similarity and paraphrase datasets. We fine-tune MML and BERT on MRPC, QQP and STS-B. Since STS-B is a regression benchmark while MRPC and QQP are binary classification tasks, we adapt STS-B into a binary classification task to allow cross evaluations between these models. The adaptation maps labels in the range 1-2 to 0 and labels in the range 4-5 to 1. In addition, we omit all ambiguous samples with label values between 2 and 4. This modification allows us to identify distinct sets of similar and non-similar sentence pairs. The modified STS-B (denoted STS-B*) forms a binary classification dataset with 3.5K samples.
As can be seen in Tab. 3, MML yields better performance in the cross evaluations for the similarity and paraphrase datasets. As in Tab. 2, each row represents a single model trained on a single dataset. We evaluate all models on all three datasets, and report the average relative improvement obtained by MML, calculated on the two holdout datasets. We found MML to improve performance for all models; for example, MML-MRPC yields a +3.5% average improvement, calculated on both the train and development sets of STS-B* and QQP.
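The STS-B binarization can be sketched as a simple filter-and-map over (sentence, sentence, score) triples. Treating the boundary values 2 and 4 as belonging to the low and high classes, respectively, is an assumption of this sketch, as are the function and parameter names.

```python
from typing import List, Tuple

Sample = Tuple[str, str, float]


def binarize_stsb(samples: List[Sample], low: float = 2.0,
                  high: float = 4.0) -> List[Tuple[str, str, int]]:
    """Map similarity scores <= low to 0 and >= high to 1; drop the ambiguous middle range."""
    binary = []
    for s1, s2, score in samples:
        if score <= low:
            binary.append((s1, s2, 0))
        elif score >= high:
            binary.append((s1, s2, 1))
        # scores strictly between low and high are treated as ambiguous and omitted
    return binary
```

Dropping the middle of the score range trades dataset size for label clarity, which is why the resulting binary set is much smaller than the original STS-B.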
4.4 Discussion of results
As can be seen in both Tab. 2 and Tab. 3, the cross-dataset evaluations reveal a significant performance gap for all models when evaluated on holdout datasets, even though these holdout datasets share the same or a similar task as the one each model was trained for. For example, both MML-MRPC and BERT-MRPC models show a 20% degradation in absolute accuracy on the RTE dataset. Yet, compared to BERT, MML produces significantly better performance in the cross evaluations. Specifically, when evaluated on QQP, MML-MRPC outperforms BERT-MRPC by a relative improvement of 4.6% on both the development and train sets.
Perhaps counterintuitively, there is no direct link between the improvement obtained in the same-dataset evaluation and that obtained in the cross-dataset one. For example, our MML-QNLI model outperforms BERT-QNLI in the cross-dataset evaluation, although it exhibits somewhat degraded performance on QNLI's development and test sets. We attribute this to the ability of MML to encourage the model to produce more robust coding vectors.
In this work, we introduced MML, a method for fine-tuning general language models that is based on the multiverse loss. MML uses a large set of parallel multiverse heads and eliminates the relatively weaker heads during training. Head elimination, applied throughout training, ensures the use of a maximal set of top-performing multiverse heads.
We demonstrated the effectiveness of MML on nine common NLP datasets through inter- and intra-dataset evaluations, where it outperforms the originally introduced BERT model. Our results shed light on the robustness of both models and showcase the ability of MML to yield improved robustness.
- Bentivogli et al. (2009). The fifth PASCAL recognizing textual entailment challenge. In TAC.
- Bowman et al. (2015). A large annotated corpus for learning natural language inference. arXiv preprint arXiv:1508.05326.
- Cer et al. (2017). SemEval-2017 task 1: semantic textual similarity, multilingual and cross-lingual focused evaluation. arXiv preprint arXiv:1708.00055.
- Chen et al. (2018). Quora question pairs.
- Comaniciu and Meer (2002). Mean shift: a robust approach toward feature space analysis. IEEE Transactions on Pattern Analysis & Machine Intelligence (5), pp. 603–619.
- Dai and Le (2015). Semi-supervised sequence learning. In Advances in Neural Information Processing Systems, pp. 3079–3087.
- Devlin et al. (2018). BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Dolan and Brockett (2005). Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing (IWP2005).
- He et al. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
- Howard and Ruder (2018). Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.
- Littwin and Wolf (2016). The multiverse loss for robust transfer learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3957–3966.
- Liu et al. (2019). RoBERTa: a robustly optimized BERT pretraining approach. arXiv preprint arXiv:1907.11692.
- Peters et al. (2018). Deep contextualized word representations. arXiv preprint arXiv:1802.05365.
- Radford et al. Improving language understanding by generative pre-training.
- Rajpurkar et al. (2016). SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250.
- Socher et al. (2013). Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, pp. 1631–1642.
- Szegedy et al. (2015). Going deeper with convolutions. In Computer Vision and Pattern Recognition (CVPR).
- Wang et al. (2018). GLUE: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.
- Warstadt et al. (2018). Neural network acceptability judgments. arXiv preprint arXiv:1805.12471.
- Williams et al. (2017). A broad-coverage challenge corpus for sentence understanding through inference. arXiv preprint arXiv:1704.05426.
- Yang et al. (2019). XLNet: generalized autoregressive pretraining for language understanding. arXiv preprint arXiv:1906.08237.