Text classification [sebastiani2002machine] encompasses tasks such as the classification of emails into spam and not-spam [shams2014semi] [kumar2019submodular] [tang2015document], topic classification [tong2001support], and relation classification [giridharastudy].
Depending upon the problem at hand, getting a good fit for a classifier model may require abundant labeled data [shams2014semi]. However, in many cases, and especially when developing AI systems for specific applications, labeled data is scarce and costly to obtain.
One example of text classification is intent classification in the growing market of automated chatbot platforms [collinaszy2017implementation]. A developer of an intent classifier for a new chatbot may start with a dataset containing two, three, or five samples per class, and in some cases no data at all.
Data augmentation [wong2016understanding] is a common strategy for handling scarce data situations. It works by synthesizing new data from existing training data, with the objective of improving the performance of the downstream model. This strategy has been a key factor in the performance improvement of various neural network models, mainly in the domains of computer vision and speech recognition. Specifically, for these domains there exist well-established methods for synthesizing labeled data to improve classification tasks. The simpler methods apply transformations to existing training examples, such as cropping, padding, flipping, and shifting along time and space dimensions, as these transformations are usually class preserving [krizhevsky2012imagenet, cui2015data, ko2015audio, szegedy2015going].
However, in the case of textual data, such transformations usually invalidate and distort the text, making it grammatically and semantically incorrect. This makes data augmentation more challenging. In fact, textual augmentation could even do more harm than good, since it is not an easy task to synthesize good artificial textual data. Thus, data augmentation methods for text usually involve replacing a single word with a synonym, deleting a word, or changing the word order, as suggested by [wei2019eda].
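The word-level operations mentioned above can be illustrated with a minimal sketch. It is not the EDA implementation from [wei2019eda]; the function names and the tiny synonym table are hypothetical, chosen only to show what synonym replacement, word deletion, and word-order changes look like on a tokenized sentence.

```python
import random

def synonym_replacement(tokens, synonyms, rng):
    """Replace one randomly chosen token that has an entry in `synonyms`."""
    out = list(tokens)
    candidates = [i for i, t in enumerate(out) if t in synonyms]
    if candidates:
        i = rng.choice(candidates)
        out[i] = rng.choice(synonyms[out[i]])
    return out

def random_deletion(tokens, p, rng):
    """Drop each token independently with probability p, keeping at least one."""
    kept = [t for t in tokens if rng.random() >= p]
    return kept if kept else [rng.choice(tokens)]

def random_swap(tokens, rng):
    """Swap the tokens at two distinct random positions."""
    out = list(tokens)
    if len(out) >= 2:
        i, j = rng.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out

rng = random.Random(0)
sentence = "show me the cheapest flight to boston".split()
# Hypothetical synonym table; a real system would use a thesaurus such as WordNet.
toy_synonyms = {"cheapest": ["lowest-priced"], "show": ["display"]}

replaced = synonym_replacement(sentence, toy_synonyms, rng)
deleted = random_deletion(sentence, 0.2, rng)
swapped = random_swap(sentence, rng)
print(" ".join(replaced))
print(" ".join(deleted))
print(" ".join(swapped))
```

Each operation changes the sentence only locally, which is why such methods tend to preserve the class but add little variability at the corpus level.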
Recent advances in text generation models [radford2018improving, kingma2014auto] facilitate an innovative approach for handling scarce data situations. Although improving text classification in these situations by using deep learning methods seems like an oxymoron, pre-trained models [radford2018improving, peters2018elmo, devlin2019bert] are opening new ways to address this task.
In this paper, we present a novel method, referred to as language-model-based data augmentation (LAMBADA), for synthesizing labeled data to improve text classification tasks.
LAMBADA is especially useful when only a small amount of labeled data is available, where its results go beyond state-of-the-art performance.
Models trained with LAMBADA exhibit increased performance compared to:
1) The baseline model, trained only on the existing data
2) Models trained on augmented corpora generated by the state-of-the-art techniques in textual data augmentation.
LAMBADA’s data augmentation pipeline builds upon a powerful language model: the generative pre-training (GPT) model [radford2018improving]. This neural-network model is pre-trained on huge bodies of text. As such, it captures the structure of natural language to a great extent, producing deeply coherent sentences and paragraphs. We adapt GPT to our needs by fine-tuning it on the existing small dataset. We then use the fine-tuned model to synthesize new labeled sentences. Independently, we train a classifier on the same original small dataset and use it to filter the synthesized data corpus, retaining only data that appears to be of sufficiently high quality. We then re-train the task classifier on both the existing and the synthesized data.
We compare LAMBADA to other data augmentation methods and find it statistically better across several datasets and classification algorithms. We mainly focus on small datasets, e.g., containing five examples per class, and show that LAMBADA significantly improves the baseline in such scenarios.
In summary, LAMBADA contributes along three main fronts:
Statistically improves classifiers’ accuracy.
Outperforms state-of-the-art data augmentation methods in scarce-data situations.
Suggests a compelling alternative to semi-supervised techniques when unlabeled data does not exist.
The rest of this paper is structured as follows: In Section 2, we present related work on the state-of-the-art in textual data augmentation techniques and the recent advances in generative pre-trained modeling. In Section 3, we define the problem of data augmentation for text classification, and in Section 4, we detail our LAMBADA method. In Section 5, we describe the experiments we conducted and their results, which analyze LAMBADA’s performance and support the paper’s main claims. We conclude with a discussion in Section 6.
2 Related Work
Previous textual data augmentation approaches focus on sample alteration [kobayashi2018contextual, wu2019conditional, wei2019eda, mueller2016siamese, jungiewicz2019towards], in which a single sentence is altered in one way or another to generate a new sentence while preserving the original class. One set of these approaches makes local changes only within a given sentence, primarily by synonym replacement of a word or multiple words. One of the recent methods in this category is easy data augmentation (EDA) [wei2019eda], which uses simple operations such as synonym replacement and random swap [miller1995wordnet]. Another method, conditional BERT contextual augmentation, recently introduced in [wu2019conditional], proposes a fine-tuned BERT [devlin2019bert] for data augmentation by carrying out a masked prediction of words while conditioning on the class label. Presumably, methods that make only local changes will produce sentences with a structure similar to the original ones, thus yielding low corpus-level variability.
Other recent possible approaches to textual data augmentation generate whole sentences rather than making a few local changes.
These approaches include using variational autoencoding (VAE) [kingma2014auto], round-trip translation [yu2018qanet], paraphrasing [kumar2019submodular], and methods based on generative adversarial networks [tanaka2019data]. They also include data noising techniques, such as altering words in the input of self-encoder networks in order to generate a different sentence [xie2017data, zolna2017fraternal, li2018undeepvo], or introducing noise on the word-embedding level. These methods were analyzed in [marivate2019improving]. Although a viable option when no access to a formal synonym model exists, they require abundant training data.
Last year, several exciting deep learning methods [vaswani2017attention] pushed the boundaries of natural language technology. They introduced new neural architectures and highly effective transfer learning techniques that dramatically improve natural language processing. These methods enable the development of new high-performance deep learning models such as ELMo [peters2018elmo], GPT [radford2018improving], BERT [devlin2019bert], and GPT-2 [radford2019language]. Common to these models is a pre-training phase, in which the models are trained on enormous bodies of publicly available text, and a fine-tuning phase, in which they are further trained on task-specific data and loss functions.
When introduced, these models processed natural language better than ever, breaking records in a variety of benchmark tasks related to natural language processing and understanding, as well as tasks involving text generation. For example, when GPT was first introduced [radford2018improving], it improved the state-of-the-art in 12 benchmark tasks, including textual entailment, semantic similarity, sentiment analysis, and commonsense reasoning. These models can produce high-quality sentences even when fine-tuned on small training data. Table 1 shows an example of a few generated sentences based on a small dataset consisting of five sentences per class.
|Class label||Generated sentence|
|Flight time||what time is the last flight from san francisco to washington dc on continental|
|Flight time||what is the flight schedule from oakland to san francisco on Tuesday morning|
|Flight time||show me the flight schedule between san francisco and washington dc on friday nights|
|Aircraft||show me all the types of aircraft used flying from atl to dallas|
|Aircraft||what is the smallest aircraft available to fly on from pittsburgh to san francisco|
|Aircraft||show me all the types of aircraft Boeing 737|
|City||show me the cities served by canadian airlines international|
|City||what are the cities that american airlines serves|
|City||what are the cities served by delta airlines|
These results suggest a counter-intuitive text classification approach: is it possible to fine-tune a pre-trained model and use it to generate new high-quality sentences that will improve the performance of a text classifier?
3 Problem Definition
Text classification is an instance of the supervised learning problem [russell2016artificial] over textual data. In this context, we are given a training dataset $D_{train} = \{(x_i, y_i)\}_{i=1}^{n}$ containing $n$ labeled sentences. Each $x_i$ is a string of text or, specifically, a sequence of tokens (roughly, words) $x_i^{1} \ldots x_i^{m_i}$. The label $y_i \in \{1, \ldots, q\}$ indicates the class of $x_i$ among a set of $q$ classes. Each $x_i$ is drawn independently from the entire set of strings $X$ (that is, $x_i \in X$), according to an unknown distribution on $X$, denoted by $P_X$. Moreover, we assume there is an unknown function $f: X \to \{1, \ldots, q\}$, and that in $D_{train}$, $y_i = f(x_i)$ for all $i = 1, \ldots, n$.
The objective of a supervised learning problem is to approximate $f$ on the entire $X$, given only the dataset $D_{train}$. In short, we are generalizing from the domain of $D_{train}$ to the entire $X$. Formally, a classification algorithm $\mathcal{A}$ receives the training dataset $D_{train}$, and after a training period, it outputs a classifier function $h = \mathcal{A}(D_{train})$, where $h: X \to \{1, \ldots, q\}$ is also known as a hypothesis. To estimate the extent to which $h$ approximates $f$ on $X$, it is customary to initially set aside both a training dataset $D_{train}$ and a test dataset $D_{test} = \{(\hat{x}_i, \hat{y}_i)\}_{i=1}^{\hat{n}}$. The test dataset is chosen randomly and has the same structure as $D_{train}$. Both parts are usually drawn from a single, extensive dataset.
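There are different ways of measuring the quality of classifier $h$ as an approximation to $f$ using $D_{test}$. The most straightforward way measures the accuracy:
$$\mathrm{accuracy}(h) = \frac{1}{\hat{n}} \sum_{i=1}^{\hat{n}} \delta\big(h(\hat{x}_i), \hat{y}_i\big),$$
where $\delta(\cdot, \cdot)$ is the Kronecker delta (equals $1$ when both arguments are equal, or $0$ otherwise), $\hat{n}$ is the size of the test set, and the pairs $\hat{x}_i$, $\hat{y}_i$ are drawn from the test set. When the test set is large, accuracy approximates the probability of having $h(x)$ equal $f(x)$, namely, $P_X(h(x) = f(x))$. We use accuracy as an estimate of classifier performance.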
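As a small illustration of the accuracy measure, the following sketch counts the test pairs on which a classifier agrees with the label and divides by the test-set size. The keyword classifier and the four test sentences are hypothetical stand-ins, not data from the experiments.

```python
def accuracy(h, test_set):
    """Fraction of test pairs (x, y) for which h(x) == y (mean Kronecker delta)."""
    if not test_set:
        raise ValueError("test set must be non-empty")
    hits = sum(1 for x, y in test_set if h(x) == y)
    return hits / len(test_set)

def toy_classifier(sentence):
    """Hypothetical stand-in for a trained classifier h."""
    return "aircraft" if "aircraft" in sentence else "flight_time"

toy_test = [
    ("what time is the last flight", "flight_time"),
    ("show me the types of aircraft", "aircraft"),
    ("what cities does delta serve", "city"),
    ("smallest aircraft to dallas", "aircraft"),
]

print(accuracy(toy_classifier, toy_test))  # 0.75: three of four predictions match
```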
Regardless of how we measure performance, if the training set $D_{train}$ is small, it will dramatically affect the performance of the algorithm $\mathcal{A}$. Data augmentation tries to solve this problem by synthesizing additional training pairs that, together with the existing dataset, better reflect the underlying distribution of the data while refraining from introducing too much noise.
Our work does not focus on the classification algorithm per se. Rather, given a training dataset $D_{train}$ and an algorithm $\mathcal{A}$, we are interested in a general method for synthesizing an artificial dataset, $D_{synthesized}$. We aim to apply algorithm $\mathcal{A}$ on $D_{train} \cup D_{synthesized}$, denoted by $\bar{h} = \mathcal{A}(D_{train} \cup D_{synthesized})$, to yield a relatively good classifier that outperforms the baseline classifier $h$.
In the following section, we describe our method, LAMBADA, and exactly how we obtain $D_{synthesized}$ from $D_{train}$ and $\mathcal{A}$. LAMBADA is specifically tailored to the case of small training sets, even minuscule ones, with only a few examples per class.
4 LAMBADA Method
We introduce a novel method for improving the performance of textual classification. Named LAMBADA for its use of Language Model Based Data Augmentation, this method adds more synthesized, weakly-labeled data samples to a given dataset. We define the method in Algorithm 1 and elaborate on its steps in the following section. LAMBADA has two key ingredients: 1) model fine-tuning (step 2), which synthesizes labeled data and 2) data filtering (step 4), which retains only high-quality sentences.
The main input to LAMBADA is a training dataset $D_{train}$, which we would like to augment with synthesized data. $D_{train}$ contains a set of sentences, each labeled with a class. To train a classifier, we use a training algorithm $\mathcal{A}$. As far as the LAMBADA method is concerned, $\mathcal{A}$ is arbitrary. However, LAMBADA synthesizes data for the algorithm $\mathcal{A}$, and this is given as a second input to LAMBADA. This is a distinctive feature of our method. We describe both $D_{train}$ and $\mathcal{A}$ in Section 3.
LAMBADA uses a pre-trained language model $G$ to synthesize new data. A language model [bengio2003neural] provides an estimate for the probability that a token (word) will appear, in accordance with a given distribution of text $P_{text}$, conditioned on the preceding and/or succeeding tokens. More formally, given a token $w$ and the preceding $k$ tokens (or fewer) $w_1, \ldots, w_k$, one would like $G(w \mid w_1, \ldots, w_k)$ to approximate the conditional probability $P_{text}(w \mid w_1, \ldots, w_k)$ of the appearance of $w$ in accordance with $P_{text}$. $G$ is usually calculated using a concrete corpus of text $U$, sampled from distribution $P_{text}$.
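To make the notion of estimating a conditional token probability from a corpus concrete, here is a deliberately simple sketch with a context of one preceding token (a bigram model). This is only an illustration of the definition; it is not how GPT-2 computes its estimates, and the toy corpus is made up.

```python
from collections import Counter, defaultdict

def train_bigram_model(tokens):
    """Count, for every token, which tokens follow it in the corpus."""
    counts = defaultdict(Counter)
    for prev, cur in zip(tokens, tokens[1:]):
        counts[prev][cur] += 1
    return counts

def cond_prob(counts, token, prev):
    """Relative-frequency estimate of P(token | prev); 0.0 for unseen contexts."""
    following = counts.get(prev)
    if not following:
        return 0.0
    return following[token] / sum(following.values())

# Toy corpus standing in for a large body of text.
corpus = "show me the flights show me the fares show me flights".split()
model = train_bigram_model(corpus)

print(cond_prob(model, "the", "me"))      # "me" is followed by "the" in 2 of 3 cases
print(cond_prob(model, "fares", "the"))   # "the" is followed by "fares" in 1 of 2 cases
```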
In contrast to $\mathcal{A}$, $G$ is far from being arbitrary. We use GPT-2, a recent pre-trained neural-network model (see Section 2), and show that our method outperforms state-of-the-art classifiers in our main use case, where $D_{train}$ is scarce.
GPT-2 is pre-trained on an enormous body of text available on the web. The corpus is organized as a long sequence of tokens, denoted by $U = w_1 \ldots w_j \ldots$. GPT-2, like GPT, is a left-to-right model based on the transformer architecture [vaswani2017attention]. It is pre-trained on $U$ with loss defined by
$$J_\theta = -\sum_{j} \log P_\theta(w_j \mid w_{j-k}, \ldots, w_{j-1}), \qquad (1)$$
where $\theta$ is the set of learnable parameters in the neural network of GPT-2, and $P_\theta$ is the trained language model: an estimate of the conditional probability distribution on the set of tokens, as calculated by the network. Specifically, we take $G = P_{\theta^*}$, where $\theta^*$ indicates the state of the learnable parameters after pre-training. Nonetheless, from a technical standpoint, the language model and its underlying technology can differ to a great extent, and it is thus presented as a third input to LAMBADA. As final input, in addition to the training set, classification algorithm, and language model, LAMBADA is given the number of labeled sentences to synthesize per class, $N_1, \ldots, N_q$.
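The pre-training loss sums the negative log-probability of each token given up to $k$ preceding tokens. The sketch below evaluates that sum for an arbitrary probability function; the uniform "model" is a hypothetical placeholder used only so the result can be checked by hand, and nothing here reflects GPT-2's actual training code.

```python
import math

def lm_loss(tokens, prob, k):
    """Negative log-likelihood: -sum_j log prob(w_j | up to k preceding tokens)."""
    total = 0.0
    for j, w in enumerate(tokens):
        context = tuple(tokens[max(0, j - k):j])
        total -= math.log(prob(w, context))
    return total

vocab = ["show", "me", "the", "flights"]

def uniform_prob(token, context):
    """Placeholder model: every token is equally likely regardless of context."""
    return 1.0 / len(vocab)

tokens = ["show", "me", "the", "flights", "me"]
loss = lm_loss(tokens, uniform_prob, k=2)
print(loss)  # equals 5 * log(4) for the uniform placeholder
```

A trained model lowers this value by assigning higher probability to the tokens that actually occur in the corpus.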
Step 1: Train baseline classifier
We train a baseline classifier $h = \mathcal{A}(D_{train})$ using the existing data $D_{train}$. This classifier will be used for filtering in Step 4.
Step 2: Fine-tune language model
Independently of Step 1, we fine-tune the language model $G$ to the task of synthesizing labeled sentences, to obtain the fine-tuned language model $G_{tuned}$. Here, $G$ is specifically fine-tuned to the linguistic domain of $D_{train}$ (that is, the sentences, vocabulary, style, etc.), as well as the particular classes in $D_{train}$. Generally speaking, we would like to use $G_{tuned}$ to generate a sentence set of any size, with each sentence labeled with a class.
In our case, $G$ is the neural model of GPT-2. We fine-tune GPT-2 by training it with the data in $D_{train} = \{(x_i, y_i)\}_{i=1}^{n}$.
We concatenate the sentences in $D_{train}$ in order to form $U^*$, in the following way:
$$U^* = y_1 \; SEP \; x_1 \; EOS \; y_2 \; SEP \; x_2 \; EOS \; y_3 \cdots y_n \; SEP \; x_n \; EOS$$
Here, the auxiliary token $SEP$ separates a class label from its corresponding sentence, while the token $EOS$ terminates a sentence and separates it from the label that follows. We further train the learnable parameters of GPT-2 to predict the next token in exactly the same way GPT-2 was pre-trained, using the loss function in Equation 1 (with the same training procedure and hyperparameters). However, we use $U^*$ instead of $U$, and the learnable parameters are already initialized. The resulting language model is referred to as $G_{tuned}$.
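The concatenation of labels and sentences into a single fine-tuning sequence can be sketched as follows. The literal spellings `<SEP>` and `<EOS>`, the whitespace joining, and the use of class names as label strings are assumptions made for illustration; the text only states that such auxiliary tokens exist.

```python
SEP = "<SEP>"  # assumed spelling of the label/sentence separator token
EOS = "<EOS>"  # assumed spelling of the end-of-sentence token

def build_finetuning_text(d_train):
    """Concatenate (sentence, label) pairs as: label SEP sentence EOS label SEP ..."""
    return " ".join(f"{label} {SEP} {sentence} {EOS}" for sentence, label in d_train)

d_train = [
    ("what time is the last flight to boston", "flight_time"),
    ("show me the types of aircraft", "aircraft"),
]

u_star = build_finetuning_text(d_train)
print(u_star)
```

Placing the label before the sentence is what later allows generation to be conditioned on a class: the model learns to continue a label prefix with a sentence of that class.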
Step 3: Synthesize labeled data
Given $G_{tuned}$, new labeled sentences can be synthesized. For any class label $y \in \{1, \ldots, q\}$, we can use the adapted language model to predict the continuation of the sequence "$y \; SEP$" until $EOS$, which terminates the generated sentence. This way, for each class, any number of sentences may be synthesized. For example, this allows us to balance the classes or otherwise control the ratio of generated sentences per class. Creating a more balanced training set can improve classification performance, especially in the case of classifiers that are sensitive to unbalanced classes.
In this step, we synthesize a set of labeled sentences, which is denoted by $D^*$. We use a simple and rather crude heuristic, where we generate, for each class $y$, 10 times more sentences than we wish to add to the class (i.e., $10 N_y$). Accordingly, the total number of generated sentences is $|D^*| = \sum_{y=1}^{q} 10 N_y$. Of course, more sophisticated heuristics can also be examined.
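The synthesis step can be sketched as a loop that prompts the fine-tuned model with a class label followed by the separator, cuts the continuation at the end-of-sentence token, and oversamples each class tenfold. The `generate` callable, the token spellings, and the canned continuations are hypothetical stand-ins for a real fine-tuned language model.

```python
import itertools

SEP, EOS = "<SEP>", "<EOS>"  # assumed token spellings

def synthesize(generate, n_per_class, oversample=10):
    """Generate oversample * N_y candidate sentences for every class y."""
    d_star = []
    for label, n in n_per_class.items():
        for _ in range(oversample * n):
            continuation = generate(f"{label} {SEP}")
            # Keep only the text before the first EOS token.
            sentence = continuation.split(EOS)[0].strip()
            if sentence:
                d_star.append((sentence, label))
    return d_star

# Stand-in for the fine-tuned model: cycles through canned continuations.
_canned = itertools.cycle([
    f"show me morning flights to boston {EOS} city {SEP} which",
    f"list the aircraft types {EOS}",
])

def fake_generate(prompt):
    return next(_canned)

d_star = synthesize(fake_generate, {"flight": 2, "aircraft": 1})
print(len(d_star))  # 10 * (2 + 1) = 30 candidates
```

Because the stand-in ignores its prompt, some candidates carry a label that does not match their content, which is exactly the kind of noise the filtering step is meant to remove.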
GPT-2 generates labeled sentences that are typically both high quality and diverse, facilitating the relative success of our method. This is also where the power of GPT-2 comes into play.
Step 4: Filter synthesized data
One obstacle in using synthesized text is the noise and error it may introduce. In the last step, we filter the data in $D^*$, which was synthesized by $G_{tuned}$ in Step 3, leaving only the instances of the highest quality. We do this using the classifier $h$ that was trained in Step 1.
For each class $y$, we take the top $N_y$ sentences from $D^*$ that are labeled by $y$, as follows: Given a synthesized sentence $(x, y) \in D^*$, we first verify that $h(x) = y$, and then use $h$’s confidence score (see below) as a rank for $(x, y)$. That is, we take the top ranked $N_y$ sentences for class $y$. This results in a synthesized dataset $D_{synthesized} \subseteq D^*$, consisting of labeled sentences and with the same structure as $D_{train}$. This is the outcome of LAMBADA.
The confidence score given to a data instance by $h$ can be regarded as the extent to which the instance is conservative with respect to $h$. In turn, $h$ takes into account both $D_{train}$ and the algorithm $\mathcal{A}$ that is to be used with the augmented dataset. This approach is borrowed from semi-supervised learning [shams2014semi], where it is used to classify and filter unlabeled data in a conservative manner. Note, however, that $G_{tuned}$ generates sentences conditioned on a class label. In our case, this means we have a type of double voting mechanism.
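The filtering rule described above (keep a candidate only if the classifier agrees with the generated label, then take the most confident ones per class) can be sketched as follows. The keyword-based classifier and its confidence score are hypothetical; any classifier that returns a predicted label together with a confidence would fit the same interface.

```python
def filter_synthesized(d_star, predict_with_confidence, n_per_class):
    """Keep, per class, the top-N candidates whose predicted label matches."""
    by_class = {}
    for sentence, label in d_star:
        predicted, confidence = predict_with_confidence(sentence)
        if predicted == label:  # classifier and generator agree ("double voting")
            by_class.setdefault(label, []).append((confidence, sentence))
    selected = []
    for label, scored in by_class.items():
        scored.sort(key=lambda pair: pair[0], reverse=True)
        selected.extend((s, label) for _, s in scored[: n_per_class.get(label, 0)])
    return selected

# Hypothetical classifier: confidence is the share of class keywords in the sentence.
KEYWORDS = {"aircraft": {"aircraft", "plane"}, "city": {"city", "cities"}}

def toy_predict(sentence):
    tokens = sentence.split()
    scores = {label: sum(t in kw for t in tokens) / len(tokens)
              for label, kw in KEYWORDS.items()}
    best = max(scores, key=scores.get)
    return best, scores[best]

d_star = [
    ("what aircraft is this plane", "aircraft"),
    ("show me the aircraft used on this route", "aircraft"),
    ("which cities are served", "aircraft"),   # label disagrees with classifier
    ("list the cities", "city"),
]

d_synthesized = filter_synthesized(d_star, toy_predict, {"aircraft": 1, "city": 1})
print(d_synthesized)
```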
While not addressed in this paper, the process described could generally be repeated by applying LAMBADA further on $D_{train} \cup D_{synthesized}$ to obtain $D'_{synthesized}$, $D''_{synthesized}$, and so on.
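Taken together, the four steps form a short pipeline whose components are all passed in as inputs. The sketch below wires them up with trivial stand-in functions so that it runs end to end; every stub (`train_stub`, `fine_tune_stub`, `synthesize_stub`, `filter_stub`) is hypothetical and would be replaced by a real classifier, a fine-tuned language model, and the filtering rule of Step 4.

```python
def lambada(d_train, train_classifier, fine_tune, synthesize,
            filter_synthesized, n_per_class):
    """Return a synthesized dataset following the four steps of the method."""
    h = train_classifier(d_train)                       # Step 1: baseline classifier
    g_tuned = fine_tune(d_train)                        # Step 2: fine-tune language model
    d_star = synthesize(g_tuned, n_per_class)           # Step 3: synthesize candidates
    return filter_synthesized(d_star, h, n_per_class)   # Step 4: filter candidates

# --- Trivial stand-ins so the sketch is runnable -----------------------------
def train_stub(d_train):
    # Returns a classifier mapping a sentence to (label, confidence).
    return lambda x: ("city" if "cities" in x else "flight", 1.0)

def fine_tune_stub(d_train):
    return "tuned-model"  # placeholder object for the fine-tuned model

def synthesize_stub(g_tuned, n_per_class):
    return [("flights to denver", "flight"),
            ("cities served by delta", "flight"),
            ("cities near boston", "city")]

def filter_stub(d_star, h, n_per_class):
    return [(x, y) for x, y in d_star if h(x)[0] == y]

d_train = [("flights to boston", "flight"), ("what cities are served", "city")]
d_synth = lambada(d_train, train_stub, fine_tune_stub, synthesize_stub,
                  filter_stub, {"flight": 1, "city": 1})
augmented = d_train + d_synth  # the classifier is then re-trained on this union
print(augmented)
```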
5 Experimental Results
We tested our method with three different classifiers (BERT, SVM and LSTM) on three distinct datasets (ATIS, TREC, and WVA) by running multiple experiments in which we varied the amount of training samples per class. Next, we compared LAMBADA to other data augmentation methods (CVAE, EDA, and CBERT) by using the above-mentioned classifiers and datasets. We statistically validated our results with the McNemar test [mcnemar1947note].
Table 2 presents a description of the datasets we used in our experiments.
|Name||Domain||# Classes||Size|
|WVA||Telco Customer support||87||17k|
Airline Travel Information Systems (ATIS) (https://www.kaggle.com/siddhadev/atis-dataset-from-ms-cntk/data): A dataset providing queries on flight-related information, widely used in language understanding research. ATIS is characterized as an imbalanced dataset, as most of the data belongs to the flight class.
Text Retrieval Conference (TREC) (https://cogcomp.seas.upenn.edu/Data/QA/QC/): A well-known dataset in the information retrieval community for question classification, consisting of open-domain, fact-based questions divided into broad semantic categories.
IBM Watson Virtual Assistant (WVA): A commercial dataset used for intent classification, comprising data from telco customer support chatbot systems.
We mainly focus on topic classification datasets with the task of classifying a sentence, not an entire document. Notably, classification of shorter text is considered a more difficult task.
We randomly split each dataset into train, validation, and test sets. We then randomly chose from the training set a subset including 5, 10, 20, 50, or 100 samples per class, which we used in each experiment for training. Once determined, we used the same subset throughout all experiments.
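Drawing a fixed number of samples per class can be sketched as below. The fixed seed mirrors the statement that the same subset was reused across experiments; keeping all samples of a class that has fewer than the requested number is an assumption of this sketch, since the text does not say how such classes were handled.

```python
import random
from collections import defaultdict

def sample_per_class(dataset, k, seed=0):
    """Pick up to k random (sentence, label) pairs for every class."""
    rng = random.Random(seed)  # fixed seed so the same subset is reproduced
    by_class = defaultdict(list)
    for sentence, label in dataset:
        by_class[label].append(sentence)
    subset = []
    for label in sorted(by_class):
        sentences = by_class[label]
        # Assumption: classes with fewer than k samples are kept whole.
        chosen = rng.sample(sentences, min(k, len(sentences)))
        subset.extend((s, label) for s in chosen)
    return subset

# Toy imbalanced dataset: 8 "flight" sentences and 3 "city" sentences.
dataset = ([(f"flight query {i}", "flight") for i in range(8)]
           + [(f"city query {i}", "city") for i in range(3)])

subset = sample_per_class(dataset, k=5)
print(len(subset))  # 5 flight samples + 3 city samples
```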
We demonstrated that our augmentation approach is independent of the classification algorithm by inspecting three different classifiers, representing three text classification "generations".
Support Vector Machine (SVM) classifiers were already commonly used before the deep neural network era. We employ a commercial SVM classifier (IBM Watson Natural Language Classifier) dedicated to natural language processing, which handles both the feature extraction process and the training of the classifier. While recent models are based on neural networks, in the context of our problem, SVM may have an advantage, since unlike neural-network-based models, it performs well even for relatively small datasets.
Long Short-Term Memory (LSTM) represents the type of classifiers that emerged after the advances in training recurrent neural networks and the introduction of word embeddings [DBLP:journals/corr/abs-1301-3781]. LSTMs were commonly used for sequential and textual data. We implemented a sequence-to-vector model based on an LSTM component followed by two fully connected layers and a softmax layer. For word embedding, we employed GloVe [pennington2014glove] of 100 dimensions. An LSTM classifier usually requires a large amount of data for training.
Bidirectional Encoder Representations from Transformers (BERT) is a relatively new family of classifiers. Based on the transformer architecture [vaswani2017attention], BERT is pre-trained using two unsupervised tasks, masked language modeling and next-sentence prediction, on the "BooksCorpus" (800 million words) [zhu2015aligning], and has achieved state-of-the-art performance on several text classification tasks. Therefore, BERT, like GPT-2, leverages the large amounts of data used in its pre-training phase in order to perform well, even on relatively small datasets.
We compared LAMBADA’s synthetic corpus quality to synthetic corpora generated by various other generative models. Similar to our selection of classifiers, we selected generators of various types representing different generation approaches. For a fair comparison, we mainly considered conditional generative models that allow generating samples conditioned on the class label. This enabled the creation of balanced synthetic corpora, an important feature for some classification models. In the following we provide a brief description of these generators.
Conditional Variational Autoencoder (CVAE) [kingma2014auto]. This generative model assumes a prior distribution over a latent space and uses deep neural networks to predict its parameters. It is an extension of the Variational Autoencoder model, enabling the conditional generation of an output sentence given a latent vector and the target class. We used a standard CVAE model with RNN-based encoder and decoder for generating sentences.
Easy Data Augmentation (EDA) [wei2019eda]. This is a recent but simple rule-based data augmentation framework for text. It includes synonym replacement, random insertion, random swap, and random deletion. These methods were found beneficial, especially for small training set sizes.
Conditional Bidirectional Encoder Representations from Transformers (CBERT) [wu2019conditional]. As a recent augmentation method for labeled sentences based on BERT, this model operates by randomly replacing words with more varied substitutions predicted by the language model. Similar to GPT-2 and BERT, CBERT is pre-trained on a large corpus in an unsupervised setting, allowing it to adapt to specific domains even when fine-tuned on relatively small datasets.
Table 3 describes the attributes of the three generative models mentioned above, including the GPT-2 model.
|CBERT||Language Model||Wikipedia and Book corpus|
|GPT-2||Language Model||Web Pages|
We conducted comprehensive experiments, testing LAMBADA’s quality from various aspects. We statistically validated all our results with McNemar’s test [mcnemar1947note].
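For readers unfamiliar with McNemar's test, the sketch below compares two classifiers evaluated on the same test items by looking only at the items on which exactly one of them is correct. The text names the test but not the variant used; the exact (binomial) two-sided form shown here is an assumption, and the toy predictions are made up.

```python
from math import comb

def mcnemar_exact(y_true, pred_a, pred_b):
    """Exact two-sided McNemar test on paired predictions.

    Returns (b, c, p_value), where b counts items only classifier A gets right
    and c counts items only classifier B gets right.
    """
    b = sum(1 for t, a, p in zip(y_true, pred_a, pred_b) if a == t and p != t)
    c = sum(1 for t, a, p in zip(y_true, pred_a, pred_b) if a != t and p == t)
    n = b + c
    if n == 0:
        return b, c, 1.0  # the classifiers never disagree in correctness
    k = min(b, c)
    # Two-sided binomial tail under the null hypothesis P(only A right) = 1/2.
    p_value = 2 * sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return b, c, min(1.0, p_value)

# Toy example: A is right on all ten items, B misses the first six.
y_true = [1] * 10
pred_a = [1] * 10
pred_b = [0] * 6 + [1] * 4

b, c, p_value = mcnemar_exact(y_true, pred_a, pred_b)
print(b, c, p_value)  # 6 0 0.03125
```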
Number of Samples and Classifiers
We compared the LAMBADA approach with the baseline using three different classifiers over varying numbers of training samples: 5, 10, 20, 50, and 100 for each class. We used the ATIS dataset to discover for which sample size our approach is beneficial.
Figure 1 clearly demonstrates the superiority of our LAMBADA approach over the baseline throughout all classifiers and all sample sizes that are smaller than or equal to 50. Larger amounts of data do not benefit as much from data augmentation and therefore, in the case of 100 samples for each class, the accuracy of LSTM and SVM does not improve using our approach. Figure 1 also demonstrates the differences between the three classifiers across sample sizes. BERT, which is a pre-trained model, is significantly better than SVM and LSTM throughout all sample sizes. However, the gap between the accuracy of BERT and the other classifiers is more pronounced in smaller sample sizes. SVM handles smaller data sizes better than LSTM, as expected. Notably, our approach was even able to improve BERT, which is state-of-the-art for text classification and is already pre-trained on large amounts of data.
We substantiate previous results by comparing the baseline to our LAMBADA approach over three datasets and three classifiers using five samples for each class. Table 4 shows that our approach significantly improves all classifiers over all datasets.
Similar to the ATIS dataset, the TREC and WVA datasets also demonstrate the dominance of BERT over SVM and LSTM. LSTM achieves poor results when using a small number of samples, as expected. Interestingly, on the ATIS dataset, with BERT and SVM classifiers, the percentage of improvement is far greater than on the other datasets. We believe that this improvement is due to ATIS’ imbalanced characteristics and our ability to generate additional data for the under-represented classes.
Comparison of Generative Models
Next, we compared our approach to other leading text generator approaches. Table 5 shows that our approach is statistically superior to all other generation algorithms in the ATIS and WVA datasets over all classifiers. In the TREC dataset, the results for BERT are significantly better than all other methods. On the TREC dataset with SVM classifier, our method is on par with EDA. Moreover, on the TREC dataset with LSTM classifier, our method is on par with CVAE.
LAMBADA vs. Unlabeled Data
Our augmentation framework does not require additional unlabeled data. As such, it can be applied when unlabeled data is unavailable or costly. To test the expected LAMBADA performance in such a scenario, we compared it to a semi-supervised approach [ruder2018strong] that uses unlabeled data. Table 6 compares different experimental settings on ATIS using three classifiers and five samples per class.
To create an unlabeled dataset, we randomly selected samples from the original dataset while ignoring their labels. Next, following a simple weak labeling approach, we classified the samples with one of the classifiers after training it on the labeled dataset. We compared LAMBADA’s classification results with the results we obtained from this classifier. These results appear in the LAMBADA and Unlabeled data columns of Table 6. Surprisingly, for most classifiers, LAMBADA achieves better accuracy compared to a simple weak labeling approach. Clearly, the generated dataset contributes more to improving the accuracy of the classifier than the unlabeled samples taken from the original dataset.
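The weak labeling baseline amounts to letting a classifier trained on the small labeled set assign labels to the unlabeled samples and then adding those pairs to the training data. A minimal sketch, with a hypothetical keyword classifier standing in for the trained baseline:

```python
def weak_label(unlabeled_sentences, classifier):
    """Attach the classifier's predicted label to each unlabeled sentence."""
    return [(x, classifier(x)) for x in unlabeled_sentences]

def toy_classifier(sentence):
    """Hypothetical stand-in for a baseline trained on the labeled data."""
    return "city" if "cities" in sentence.split() else "flight"

d_train = [("show me flights to boston", "flight")]
unlabeled = ["what cities does delta serve", "flights from denver on monday"]

augmented = d_train + weak_label(unlabeled, toy_classifier)
print(augmented)
```

Unlike the synthesized data, these weak labels come from a single source (the classifier), so there is no second, generator-side vote on the class.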
We may attribute this increased performance to two main factors:
LAMBADA uses its "generated" labels, which significantly improve performance.
LAMBADA allows us to control the number of samples per class by investing more effort in generating samples for classes that are under-represented in the original dataset.
We further assessed the importance of the "generated" labels by removing them from LAMBADA’s synthesized dataset. We provide the results for this experiment under the GPT Unlabeled column in Table 6. In future work, we plan to use various data balancing approaches on the unlabeled dataset to assess the importance of the second factor above.
6 Discussion and Future Work
We introduce LAMBADA for improving classifiers’ performance. It involves fine-tuning a language model, generating new label-conditioned sentences, and a filtering phase. We showed that our method statistically improves classifiers’ performance on small datasets. In addition, we showed that LAMBADA beats the state-of-the-art techniques in data augmentation.
Generative vs. discriminative
Generally speaking, training a generative model requires more data than training a discriminative model [ng2002discriminative]. This is attributed mainly to the fact that discriminative models aim at estimating the class boundaries, while generative models approximate the probability distribution of the samples in each class. Therefore, prima facie, it is counter-intuitive to employ a generative model to improve discriminative classifier accuracy. All the more so, when both models are trained on the same dataset. However, unlike discriminative models, generative models may exploit unsupervised data to compensate for the inherent higher sample complexity. Consequently, and given the available abundant amount of unlabeled textual data, language models, pre-trained on huge corpora, have recently become state-of-the-art. Fine-tuning these generative models requires an extremely small amount of labeled data, as we show in this work, and sampling from them is straightforward.
LAMBADA synthesizes data in two steps. It first generates a large number of sentences per class and then filters them by multiple conditions. In this work, we employ a simple filtering heuristic, inspired by the semi-supervised paradigm that takes into account: 1) the class label of the generated sentence 2) the class label as given by the filtering classifier, together with its confidence score and 3) the number of sentences per class. We plan to further investigate other filtering heuristics and approaches in future work.
Weak labeling and self-supervision
LAMBADA synthesizes corpora of weakly labeled data by conditionally generating sentences on a given class label. Thus, one may incorporate a LAMBADA synthesized corpus within any weak labeling or semi-supervised framework, such as one of those suggested by [ruder2018strong]. Moreover, one may use a synthesized corpus in situations where unlabeled data is not available and still expect comparable results.
Most textual datasets contain class names with semantic meaning. LAMBADA, an approach based on a language model, utilizes this class label meaning in its generation process. Consequently, it enables synthesizing samples for any meaningful, domain-related, class name. It thus potentially allows the generation of samples for unseen classes, a method also known as zero-shot learning [socher2013zero]. We plan to investigate this idea in future research.
Iterative training process
While a single step of the augmentation process may sufficiently improve the classifier, as shown in this paper, there is no real impediment to repeating the process by running several iterations of Algorithm 1. One of the possible hazards that the repetition of this process may cause is data drifting, in which biased synthesized samples gain domination over the training dataset.