In recent years, deep learning approaches have achieved state-of-the-art results on many NLP tasks, such as machine translation Vaswani et al. (2017) and sentiment analysis Shen et al. (2017). However, these models are usually trained from scratch. On one hand, training deep networks based on recurrent neural networks (RNNs) or convolutional neural networks (CNNs) requires a significant amount of manually labeled text and substantial time to converge, requirements that customers usually cannot satisfy. On the other hand, in supervised computer vision (CV) tasks, people can obtain a high-performing model with a very small dataset. So far, the most commonly used pre-training strategy for text is pre-trained word embeddings, e.g., McCann et al. (2017); Pennington et al. (2014), which help boost performance but cannot alleviate the requirement of sufficiently large labeled data.
Transfer learning methods have been widely used in CV. Models pre-trained on the ImageNet dataset can be fine-tuned for image classification, or for many other CV applications, on datasets coming from similar but different distributions. With fine-tuning, the target model can perform well even with a small dataset. Moreover, model training can be completed fairly quickly if the dataset size is reasonable.
However, for NLP applications, fine-tuning had not seen much success beyond pre-trained word embeddings until very recently. Several works Devlin et al. (2018); Howard and Ruder (2018); Peters et al. (2018); Radford et al. (2018) pre-trained a neural language model on large-scale datasets and then applied it to the target tasks. With suitable unsupervised pre-training and fine-tuning algorithms, one can improve performance on various downstream tasks. A language model (LM) is a broadly useful tool for building NLP systems. Language models have long been found useful for improving the quality of generated text in machine translation Schwenk et al. (2012); Vaswani et al. (2013), speech recognition Arisoy et al. (2012); Mikolov et al. (2010), and text summarization Filippova et al. (2015); Rush et al. (2015). More excitingly, it has recently been recognized that a neural language model trained on a large amount of unannotated data can serve as a powerful feature extractor for text, e.g., Peters et al. (2017, 2018); Radford et al. (2017).
In this paper, we propose an end-to-end algorithm for fine-tuning a pre-trained language model on text classification tasks. The contributions of our work lie in the following aspects:
We compare the performance of different language models pre-trained on different datasets, in order to identify which factor of the pre-trained model matters most for downstream task performance.
We adopt a self-attention mechanism to extract significant contextualized features from the pre-trained LM and use those features as the representations for classification. The experimental results show that we achieve new state-of-the-art results on all six datasets we have tested.
We introduce a multi-task learning approach, which simultaneously fine-tunes the classifier and the pre-trained LM on the target classification datasets, leading to an end-to-end algorithm given the pre-trained model. With multi-task learning, we can largely reduce the total training time.
2 Related Work
There has been a recent resurgence of interest in using language models as the starting point for generating contextualized word or sentence embeddings, which can later be fine-tuned for specific end applications, e.g., classification, entailment, and so on.
The ELMo model Peters et al. (2018) introduces a new type of deep contextualized word representation. The representations are learned functions of the internal states of a deep bidirectional language model, which is pre-trained on a large text corpus, namely the One Billion Word Benchmark dataset Chelba et al. (2013). For a specific downstream task, ELMo can then be easily added to existing models to improve their performance. This kind of method is known as hypercolumn in the domain of transfer learning Peters et al. (2017).
Furthermore, the ULMFiT (Universal Language Model Fine-tuning) algorithm Howard and Ruder (2018) is an effective transfer learning method that can be applied to a variety of classification problems in NLP, and it introduces techniques that are key for fine-tuning a language model. ULMFiT consists of three steps, namely general-domain LM pre-training, target task LM fine-tuning, and target task classifier fine-tuning. Unlike the ELMo model, which incorporates the embeddings into task-specific neural architectures, ULMFiT employs a simple two-block classifier for all downstream tasks. The input of the classifier is a concatenation of the hidden state at the last time step of the document with both the max-pooled and the mean-pooled representations of the hidden states over time steps. Although this method is general and has achieved state-of-the-art results on many classification datasets, the concatenated representation is considered insufficiently representative.
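The concatenated classifier input described above can be illustrated with a minimal pure-Python sketch. The function name `concat_pool` and the list-of-lists layout for hidden states are assumptions made for this illustration; they are not taken from the ULMFiT implementation.

```python
def concat_pool(hidden_states):
    """Concatenate [last state, max-pool over time, mean-pool over time].

    hidden_states: list of per-time-step vectors (lists of floats), all of
    the same dimension. The layout is a hypothetical choice for this sketch.
    """
    if not hidden_states:
        raise ValueError("need at least one time step")
    dim = len(hidden_states[0])
    steps = len(hidden_states)
    last = list(hidden_states[-1])
    max_pool = [max(h[j] for h in hidden_states) for j in range(dim)]
    mean_pool = [sum(h[j] for h in hidden_states) / steps for j in range(dim)]
    return last + max_pool + mean_pool


# Three time steps with 2-dimensional hidden states.
states = [[1.0, -2.0], [3.0, 0.0], [2.0, 4.0]]
features = concat_pool(states)  # length is 3 * 2 = 6
```

Every time step contributes to the mean-pooled part with the same weight, regardless of how informative the token is.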
Most recently, a couple of methods have adopted a multi-layer Transformer Vaswani et al. (2017), which is a multi-headed self-attention based model, for language model pre-training. OpenAI GPT (Generative Pre-trained Transformer) Radford et al. (2018) uses a left-to-right Transformer and introduces supervised fine-tuning with only one extra output layer. In that work, the authors propose novel task-specific input transformations, which convert structured inputs into an ordered sequence that fits into the pre-trained language model. Moreover, in the supervised fine-tuning, they also include language modeling as an auxiliary objective in order to improve the generalization of the supervised model. In addition, the BERT (Bidirectional Encoder Representations from Transformers) model Devlin et al. (2018) employs a bidirectional Transformer and is pre-trained using two novel unsupervised prediction tasks, namely Masked LM (MLM) and Next Sentence Prediction. MLM trains the LM by masking 15% of the tokens in each sequence at random, and Next Sentence Prediction pre-trains a binarized next sentence prediction task so that the model learns sentence relationships.
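As a rough illustration of the masking step, the sketch below selects about 15% of the positions at random and hides them behind a placeholder. It is a simplification under stated assumptions: the function name `mask_tokens`, the `[MASK]` string, the fixed seed, and the rule of always substituting the placeholder are choices made for this example, and the actual BERT procedure treats the selected positions in a more refined way.

```python
import random

MASK_TOKEN = "[MASK]"  # placeholder string chosen for this sketch


def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Hide roughly mask_prob of the tokens; return masked copy and targets."""
    if not tokens:
        return [], {}
    rng = random.Random(seed)
    n_mask = max(1, round(len(tokens) * mask_prob))
    positions = rng.sample(range(len(tokens)), n_mask)
    masked = list(tokens)
    targets = {}
    for pos in positions:
        targets[pos] = masked[pos]   # the model must predict this token
        masked[pos] = MASK_TOKEN
    return masked, targets


tokens = [f"w{i}" for i in range(20)]
masked, targets = mask_tokens(tokens)  # 3 of the 20 positions are hidden
```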
In this paper, we use ULMFiT as the main method for comparison for the following reasons. On one hand, although OpenAI GPT and BERT might perform better on different types of language understanding tasks, we only focus on text classification tasks in this work, as ULMFiT does. On the other hand, our goal is to demonstrate the effectiveness of the attention mechanism in fine-tuning. Although Transformer networks have a strong ability to capture longer-range linguistic structure, we use variants of the long short-term memory (LSTM) network Hochreiter and Schmidhuber (1997) in our LM because they are simple to implement and efficient to train.
In this section, we illustrate our training procedure, which is also shown in Figure 1. Our training consists of four steps. We first pre-train a language model on a large-scale text corpus. Then the pre-trained model is fine-tuned on the downstream classification dataset with the unsupervised language modeling objective. Finally, we use an attention-based decoder to build our classification model and train the classifier using the target task datasets.
3.1 Unsupervised Pre-training
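Given a sequence of tokens $x_1, \dots, x_T$, a language model computes the probability of the sequence by modeling the probability of token $x_t$ given the history $x_1, \dots, x_{t-1}$:
$$p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1}). \tag{1}$$
Assume we have a vector $h_{t-1}$ encoding the history $x_1, \dots, x_{t-1}$. Then the conditional probability of a token $x_t$ can be parametrized as
$$p(x_t = k \mid x_1, \dots, x_{t-1}) = \frac{\exp(w_k^\top h_{t-1})}{\sum_{j} \exp(w_j^\top h_{t-1})}, \tag{2}$$
where $W$ is the weight matrix to be learned and $w_k$ denotes its $k$-th row. In a recurrent neural network, the hidden state $h_t$ is usually computed from $x_t$ and $h_{t-1}$, namely
$$h_t = f(x_t, h_{t-1}).$$
In this paper, we use AWD-LSTM (Merity et al., 2017) to model the conditional probability. The update of the hidden state from an LSTM can be defined as
$$\begin{aligned}
i_t &= \sigma(W^i x_t + U^i h_{t-1}), \\
f_t &= \sigma(W^f x_t + U^f h_{t-1}), \\
o_t &= \sigma(W^o x_t + U^o h_{t-1}), \\
\tilde{c}_t &= \tanh(W^c x_t + U^c h_{t-1}), \\
c_t &= i_t \odot \tilde{c}_t + f_t \odot c_{t-1}, \\
h_t &= o_t \odot \tanh(c_t),
\end{aligned}$$
where $W^i, W^f, W^o, W^c, U^i, U^f, U^o, U^c$ are weight matrices, $x_t$ here denotes the input vector at time step $t$, $h_t$ is the current exposed hidden state, $c_t$ is the current memory cell state, and $\odot$ denotes the element-wise multiplication. In order to prevent overfitting within the recurrent connections of an RNN, the AWD-LSTM performs DropConnect (Wan et al., 2013) on the hidden-to-hidden matrices within the LSTM.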
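To make the recurrent update and the role of DropConnect concrete, here is a minimal pure-Python sketch of one standard LSTM step in which only the hidden-to-hidden matrices are randomly masked. The helper names, the dictionary layout of the weights, and the rescaling of kept weights by 1/(1 - p) are assumptions of this sketch rather than details confirmed by the text.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def drop_connect(matrix, drop_prob, rng):
    """Zero individual weights with probability drop_prob (0 <= drop_prob < 1).

    Kept weights are rescaled by 1 / (1 - drop_prob); this rescaling is an
    assumption of the sketch.
    """
    keep = 1.0 - drop_prob
    return [[w / keep if rng.random() < keep else 0.0 for w in row]
            for row in matrix]


def lstm_step(x, h_prev, c_prev, W, U):
    """One LSTM step; W and U map gate names 'i', 'f', 'o', 'c' to matrices."""
    pre = {
        gate: [a + b for a, b in zip(matvec(W[gate], x), matvec(U[gate], h_prev))]
        for gate in ("i", "f", "o", "c")
    }
    i = [sigmoid(v) for v in pre["i"]]
    f = [sigmoid(v) for v in pre["f"]]
    o = [sigmoid(v) for v in pre["o"]]
    c_tilde = [math.tanh(v) for v in pre["c"]]
    c = [ik * ck + fk * pk for ik, ck, fk, pk in zip(i, c_tilde, f, c_prev)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c


# Toy example: input size 1, hidden size 2, small hand-picked weights.
rng = random.Random(0)
W = {g: [[0.1], [-0.2]] for g in ("i", "f", "o", "c")}
U = {g: [[0.3, 0.0], [0.0, 0.3]] for g in ("i", "f", "o", "c")}
U_dropped = {g: drop_connect(U[g], 0.5, rng) for g in U}  # mask U only
h, c = lstm_step([1.0], [0.5, -0.5], [0.0, 0.0], W, U_dropped)
```

The input-to-hidden matrices `W` are left untouched, so the regularization acts only on the recurrent connections.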
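For training, we follow the standard language modeling objective and maximize the log-likelihood:
$$L_{\mathrm{LM}}(\theta) = \sum_{t=1}^{T} \log p(x_t \mid x_1, \dots, x_{t-1}; \theta),$$
where $p$ is defined in (2) and $\theta$ is the parameter set of the AWD-LSTM to be learned.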
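The chain-rule factorization and the log-likelihood objective can be sketched in a few lines of Python. The callback `next_token_logits`, which stands in for the recurrent network plus output layer, and the toy uniform model are hypothetical names introduced only for this example.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_log_likelihood(token_ids, next_token_logits):
    """Sum over t of log p(x_t | x_1, ..., x_{t-1}).

    next_token_logits(history) returns one logit per vocabulary entry; it is
    a stand-in for the language model in this sketch.
    """
    log_prob = 0.0
    for t, token in enumerate(token_ids):
        probs = softmax(next_token_logits(token_ids[:t]))
        log_prob += math.log(probs[token])
    return log_prob


def uniform_logits(history, vocab_size=4):
    # Toy model: every token is equally likely whatever the history is.
    return [0.0] * vocab_size


ll = sequence_log_likelihood([2, 0, 3], uniform_logits)  # equals 3 * log(1/4)
```

Training would adjust the model parameters so that this quantity, summed over the corpus, becomes as large as possible.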
3.2 End-to-end Fine-tuning
Note that the data of the downstream classification task usually comes from a distribution that is different from the pre-training data. In order to apply the pre-trained language model, we follow (Howard and Ruder, 2018) by introducing the target task LM fine-tuning using the target classification dataset without labels (Step 2 in Figure 1).
For classification, we adopt an attention-based encoder-decoder structure (Figure 1, Step 3). Self-attention has been widely used in a variety of tasks including reading comprehension Cheng et al. (2016), textual entailment (Parikh et al., 2016), and abstractive summarization (Paulus et al., 2017). As the encoder, our pre-trained model learns the contextualized features from the inputs of the dataset. Then the hidden states over time steps, denoted as $H = (h_1, \dots, h_T)$, can be viewed as the representation of the data to be classified and are also the input of the attention layer. Since we do not have any additional information from the decoder, we use self-attention to extract the relevant aspects from the input states. Specifically, the alignment is computed as
$$u_t = \tanh(W_u h_t + b_u)$$
for $t = 1, \dots, T$, where $W_u$ and $b_u$ are the weight matrix and bias term to be learned. Then the alignment scores are given by the following soft-max function:
$$\alpha_t = \frac{\exp(u_t)}{\sum_{s=1}^{T} \exp(u_s)}.$$
The final context vector, which is also the input of the decoder (classifier), is computed by
$$v = \sum_{t=1}^{T} \alpha_t h_t.$$
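The attention pooling described above can be sketched as follows. This is an illustration under assumptions: each hidden state is scored with a single weight vector and a scalar bias so that the alignment is a scalar, and the names `attention_pool`, `weight`, and `bias` are hypothetical.

```python
import math


def attention_pool(hidden_states, weight, bias):
    """Return (attention weights, context vector) for a list of hidden states.

    Scores are tanh(weight . h_t + bias), normalized with a soft-max over
    time; the context is the attention-weighted sum of the hidden states.
    """
    scores = [
        math.tanh(sum(w * h for w, h in zip(weight, state)) + bias)
        for state in hidden_states
    ]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    dim = len(hidden_states[0])
    context = [
        sum(a * state[j] for a, state in zip(alphas, hidden_states))
        for j in range(dim)
    ]
    return alphas, context


states = [[1.0, 2.0], [3.0, 4.0], [-1.0, 0.5]]
alphas, context = attention_pool(states, weight=[0.5, -0.25], bias=0.1)
```

With a zero weight vector every time step receives the same attention weight and the context reduces to mean pooling, so the learned weights are what let the classifier focus on informative tokens.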
For the classifier, we follow (Howard and Ruder, 2018) and the standard practice for CV classifiers, namely two additional linear blocks with batch normalization Ioffe and Szegedy (2015) and dropout Srivastava et al. (2014), with ReLU activations for the intermediate layer and a soft-max activation for the output layer that calculates a probability distribution over target classes. Assume the output of the last linear block is $z$. Moreover, denote by $\mathcal{D}$ the target classification data, where $\mathcal{D} = \{(x, y)\}$, $x$ is the input sequence of tokens and $y$ is the corresponding label. Then the classification objective we use to train the model (Figure 1, Step 4), i.e., the negative of the cross-entropy loss, can be computed by
$$L_{\mathrm{clf}}(\mathcal{D}) = \sum_{(x, y) \in \mathcal{D}} \log \mathrm{softmax}(z)_y, \tag{9}$$
where $z$ is computed from the input $x$.
However, for some large target datasets, such as Yelp and Sogou News, the LM fine-tuning (Figure 1, Step 2) can take up to a few days. Therefore, we can instead fine-tune the pre-trained model directly on the classification task in a multi-task learning fashion. More specifically, we combine the LM fine-tuning and the classification fine-tuning into one objective, which leads to end-to-end fine-tuning. In other words, instead of Eq. (9), we aim to maximize the following objective:
$$L(\mathcal{D}) = L_{\mathrm{clf}}(\mathcal{D}) + \lambda L_{\mathrm{LM}}(\mathcal{D}), \tag{11}$$
where $L_{\mathrm{clf}}$ is the classification objective of Eq. (9), $L_{\mathrm{LM}}$ is the language modeling log-likelihood of Section 3.1 evaluated on the target data, and $\lambda$ is the pre-defined weight. More discussion on the multi-task learning can be found in Section 4.4.
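The combined objective can be illustrated for a single training example with the sketch below. It assumes that both terms are log-likelihoods to be maximized and that they are mixed by a scalar weight; the function names and the numeric values are made up for illustration and are not the settings used in the experiments.

```python
import math


def log_softmax(logits):
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum_exp for v in logits]


def classification_log_likelihood(class_logits, label):
    """Log-probability assigned to the correct class."""
    return log_softmax(class_logits)[label]


def multitask_objective(class_logits, label, lm_log_likelihood, lm_weight):
    """Classification term plus weighted language modeling term."""
    return (classification_log_likelihood(class_logits, label)
            + lm_weight * lm_log_likelihood)


# Hypothetical values: 2 classes, an LM log-likelihood of -2.0, weight 0.5.
objective = multitask_objective([0.0, 0.0], label=0,
                                lm_log_likelihood=-2.0, lm_weight=0.5)
```

Setting the weight to zero recovers plain classifier training, while larger weights push the shared encoder to keep modeling the target-domain text.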
4.1 Experimental Setup
|Dataset||# classes||# examples||Avg. sequence length||# test|
We evaluate our model on six widely studied datasets, with varying document length and number of classes. The statistics for each dataset are presented in Table 1. Note that the average sequence length is the average number of tokens after data pre-processing.
For topic classification, we evaluate on the large-scale Sogou news, AG news and DBpedia ontology datasets Zhang et al. (2015). For sentiment analysis, we evaluate our approach on the binary movie review IMDb dataset Maas et al. (2011) and the binary and five-class versions of the Yelp review dataset Zhang et al. (2015). In addition, we use the same pre-processing as in earlier work Howard and Ruder (2018).
We first pre-train our LM with different model architectures and datasets in order to identify which factor of pre-training matters for downstream task performance. Then we demonstrate the effectiveness of our self-attention mechanism by comparing the classification performance with ULMFiT using the same pre-trained LM provided by ULMFiT in Howard and Ruder (2018). We remark that we obtain similar accuracy with an LM that we pre-trained ourselves. Finally, we present the classification results from multi-task learning and discuss the trade-offs between classification accuracy and training efficiency.
4.2 Unsupervised Pre-training
|LSTMP||One Bn Word||1.1B||44||91.39%|
For language model pre-training, we have tried three different settings: 1) LSTM with projection layer (LSTMP) Sak et al. (2014) on the One Billion Word Benchmark dataset; 2) AWD-LSTM on the WikiText-103 dataset Merity et al. (2016); and 3) AWD-LSTM on the WikiText-2 dataset Merity et al. (2016).
For LSTMP, we use a word embedding dimension of 512 and a one-layer LSTM with hidden size 2048 whose output is projected to 512 dimensions, while for AWD-LSTM we use a word embedding dimension of 400 and a 3-layer AWD-LSTM with hidden size 1150. In order to evaluate the performance of the pre-trained language models, we add an attention-based classifier on top of the LMs and fine-tune the model on the AG news dataset. The results of pre-training and classification fine-tuning are presented in Table 2.
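For reference, the two architecture settings stated above can be written down as a small configuration sketch. The class and field names are hypothetical; only the numeric values come from the text.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class LMConfig:
    """Hyperparameters of a pre-trained LM (field names are illustrative)."""
    name: str
    embedding_dim: int
    num_layers: int
    hidden_size: int
    projection_size: Optional[int] = None  # None means no projection layer


LSTMP = LMConfig(name="LSTMP", embedding_dim=512, num_layers=1,
                 hidden_size=2048, projection_size=512)
AWD_LSTM = LMConfig(name="AWD-LSTM", embedding_dim=400, num_layers=3,
                    hidden_size=1150)
```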
From the results in Table 2, we make the following observations. On one hand, although LSTMP on the One Billion Word dataset performs best in pre-training, it obtains the lowest accuracy in the classification task, which may indicate overfitting. Compared with the size of the target data (AG news), our language models are large enough, i.e., they have enough parameters. In this case, AWD-LSTM is more suitable than LSTMP because it applies DropConnect to the hidden-to-hidden weight matrices, which mitigates overfitting.
On the other hand, pre-training on larger source datasets does not always improve downstream task performance. WikiText-2 is a subset of WikiText-103 and is much smaller than WikiText-103 or the One Billion Word dataset, yet pre-training on WikiText-2 leads to the best performance among the three. We can see that the size of the source data is not significant once the pre-training dataset is large enough. This observation indicates the possibility that, when the source dataset is large enough, the performance of language modeling is a significant factor in transfer learning.
4.3 Self-attention Mechanism
We then carry out experiments to demonstrate that our self-attention mechanism is universally effective in the fine-tuning. In order to compare with ULMFiT, we use the pre-trained language model provided by ULMFiT and follow Howard and Ruder (2018) to fine-tune the LM on target datasets before incorporating the classifier. Table 3 shows the classification performance on various datasets.
|Dataset||Error rate (from scratch)||Error rate (SOTA)||Error rate (ULMFiT)||Error rate (self-attention)||Improvement over ULMFiT||Improvement over SOTA|
|AG||6.53||5.29 (Howard and Ruder (2018))||5.54||5.17||6.68%||2.27%|
|IMDb||9.86||5.00 (Howard and Ruder (2018))||5.08||4.59||9.65%||8.20%|
|DBpedia||1.00||0.84 (Johnson and Zhang (2016))||0.87||0.80||8.05%||4.76%|
|Yelp-bi||2.91||2.64 (Johnson and Zhang (2017))||2.37||1.97||16.88%||25.38%|
|Yelp-full||31.11||30.58 (Johnson and Zhang (2017))||30.73||28.86||6.05%||5.59%|
|Sogou||2.50||1.84 (Johnson and Zhang (2017))||2.26||1.69||25.22%||8.15%|
In Table 3, the second column presents the results of training from scratch. The third column presents the previous state-of-the-art (SOTA) results found in the literature. Note that ULMFiT Howard and Ruder (2018) achieves all the SOTA results except on Sogou news with a combination of both forward and backward LMs. However, in our experiments, the comparison is based on a single model. The error rates in the fourth column are the results we obtain by running the implementation from Howard and Ruder (https://github.com/fastai/fastai). The last two columns report the relative improvements of our method over ULMFiT and over the previous SOTA results.
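The relative improvements in the last two columns follow from the error rates as (baseline error minus new error) divided by the baseline error. The short sketch below, with a helper name chosen for this example, reproduces the AG news row of Table 3.

```python
def relative_improvement(baseline_error, new_error):
    """Relative error reduction in percent."""
    return 100.0 * (baseline_error - new_error) / baseline_error


# AG news row of Table 3: ULMFiT 5.54, previous SOTA 5.29, self-attention 5.17.
ag_over_ulmfit = relative_improvement(5.54, 5.17)  # about 6.68
ag_over_sota = relative_improvement(5.29, 5.17)    # about 2.27
```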
From Table 3, we can observe that by adding a self-attention layer, our model improves on the SOTA accuracies, as well as on the performance of ULMFiT, on all six datasets. The relative improvement can exceed 25%. In order to see the effectiveness of the attention layer more clearly, we visualize the attention scores with respect to the input texts on AG news. Randomly chosen examples for different classes are given in Figure 2, where a darker color means a higher attention score. Note that some tokens, such as <xbos> and <xfld 1>, represent the sentence tags that we used for data pre-processing and are not specific words in the original documents.
In our experiments, we have also tried ensembles of self-attention layers and multi-head attention, and the improvement is similar. Therefore, we adopted the simplest self-attention layer in our model architecture.
4.4 Multi-task Learning
As mentioned in Section 3.2, a separate target task LM fine-tuning is time-consuming. We list the LM fine-tuning time and the classification fine-tuning time in Table 4 for the experiments presented in the previous section. From the table, we can see that the LM fine-tuning can take more than half of the total training time.
In order to make the training more efficient, we can fine-tune the two parts together, which leads to multi-task learning with the objective Eq. (11). Figure 3 showcases the structure of our multi-task learning algorithm, which consists of two steps. We first pre-train a language model on a large-scale text corpus. We then use two decoders to build the target model, one attention-based decoder for the classifier and one simple linear-block decoder for incorporating the language modeling loss. In the experiments, we choose the weight in Eq. (11) to be .
We test the multi-task learning on AG news. The total training time until convergence is reduced from 6.5 hours to 3.5 hours. However, we get an error rate of 5.49%, which is slightly better than ULMFiT but not as good as the SOTA result. Therefore, customers should be aware that the multi-task learning framework trades some classification accuracy for training efficiency.
|Dataset||LM fine-tuning time (hours)||Classification fine-tuning time (hours)||Ratio of LM fine-tuning|
In this work, we have proposed an attention-based fine-tuning algorithm that provides a reliable and easy-to-use feature extractor on top of the pre-trained language model and uses those features for downstream text classification tasks. The proposed algorithm improves on state-of-the-art methods on various benchmark datasets. With this algorithm, customers can take a given language model and fine-tune the target model on their own data. In addition, customers can adopt another version of our algorithm, i.e., the multi-task learning approach, for faster training if they can accept a slight reduction in model performance.
References
- Arisoy et al. (2012). Deep neural network language models. Proceedings of the NAACL-HLT 2012 Workshop: Will We Ever Really Replace the N-gram Model? On the Future of Language Modeling for HLT, pp. 20–28.
- Chelba et al. (2013). One billion word benchmark for measuring progress in statistical language modeling. arXiv preprint arXiv:1312.3005.
- Cheng et al. (2016). Long short-term memory-networks for machine reading. arXiv preprint arXiv:1601.06733.
- Devlin et al. (2018). BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Filippova et al. (2015). Sentence compression by deletion with LSTMs. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pp. 360–368.
- Hochreiter and Schmidhuber (1997). Long short-term memory. Neural Computation 9 (8), pp. 1735–1780.
- Howard and Ruder (2018). Universal language model fine-tuning for text classification. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), Vol. 1, pp. 328–339.
- Ioffe and Szegedy (2015). Batch normalization: accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167.
- Johnson and Zhang (2016). Convolutional neural networks for text categorization: shallow word-level vs. deep character-level. arXiv preprint arXiv:1609.00718.
- Johnson and Zhang (2017). Deep pyramid convolutional neural networks for text categorization. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), Vol. 1, pp. 562–570.
- Maas et al. (2011). Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, Volume 1, pp. 142–150.
- McCann et al. (2017). Learned in translation: contextualized word vectors. In Advances in Neural Information Processing Systems, pp. 6294–6305.
- Merity et al. (2017). Regularizing and optimizing LSTM language models. arXiv preprint arXiv:1708.02182.
- Merity et al. (2016). Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843.
- Mikolov et al. (2010). Recurrent neural network based language model. In Eleventh Annual Conference of the International Speech Communication Association.
- Parikh et al. (2016). A decomposable attention model for natural language inference. arXiv preprint arXiv:1606.01933.
- Paulus et al. (2017). A deep reinforced model for abstractive summarization. arXiv preprint arXiv:1705.04304.
- Pennington et al. (2014). GloVe: global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 1532–1543.
- Peters et al. (2017). Semi-supervised sequence tagging with bidirectional language models. arXiv preprint arXiv:1705.00108.
- Peters et al. (2018). Deep contextualized word representations. arXiv preprint arXiv:1802.05365.
- Radford et al. (2017). Learning to generate reviews and discovering sentiment. arXiv preprint arXiv:1704.01444.
- Radford et al. (2018). Improving language understanding by generative pre-training.
- Rush et al. (2015). A neural attention model for abstractive sentence summarization. arXiv preprint arXiv:1509.00685.
- Sak et al. (2014). Long short-term memory recurrent neural network architectures for large scale acoustic modeling. In Fifteenth Annual Conference of the International Speech Communication Association.
- Schwenk et al. (2012). Large, pruned or continuous space language models on a GPU for statistical machine translation. In Proceedings of the NAACL-HLT 2012 Workshop: Will We Ever Really Replace the N-gram Model? On the Future of Language Modeling for HLT, pp. 11–19.
- Shen et al. (2017). DiSAN: directional self-attention network for RNN/CNN-free language understanding. arXiv preprint arXiv:1709.04696.
- Srivastava et al. (2014). Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research 15 (1), pp. 1929–1958.
- Vaswani et al. (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
- Vaswani et al. (2013). Decoding with large-scale neural language models improves translation. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, pp. 1387–1392.
- Wan et al. (2013). Regularization of neural networks using DropConnect. In International Conference on Machine Learning, pp. 1058–1066.
- Zhang et al. (2015). Character-level convolutional networks for text classification. In Advances in Neural Information Processing Systems, pp. 649–657.
- Zhou et al. (2017). Selective encoding for abstractive sentence summarization. arXiv preprint arXiv:1704.07073.