Generation-Distillation for Efficient Natural Language Understanding in Low-Data Settings

by Luke Melas-Kyriazi, et al.
Harvard University

Over the past year, the emergence of transfer learning with large-scale language models (LMs) has led to dramatic performance improvements across a broad range of natural language understanding tasks. However, the size and memory footprint of these large LMs make them difficult to deploy in many scenarios (e.g. on mobile phones). Recent research points to knowledge distillation as a potential solution, showing that when training data for a given task is abundant, it is possible to distill a large (teacher) LM into a small task-specific (student) network with minimal loss of performance. However, when such data is scarce, there remains a significant performance gap between large pretrained LMs and smaller task-specific models, even when training via distillation. In this paper, we bridge this gap with a novel training approach, called generation-distillation, that leverages large finetuned LMs in two ways: (1) to generate new (unlabeled) training examples, and (2) to distill their knowledge into a small network using these examples. Across three low-resource text classification datasets, we achieve comparable performance to BERT while using 300x fewer parameters, and we outperform prior approaches to distillation for text classification while using 3x fewer parameters.




1 Introduction

Over the past year, rapid progress in unsupervised language representation learning has led to the development of increasingly powerful and generalizable language models Radford et al. (2019); Devlin et al. (2018). Widely considered to be NLP’s “ImageNet moment” (Ruder, 2018), this progress has led to dramatic improvements in a wide range of natural language understanding (NLU) tasks, including text classification, sentiment analysis, and question answering Wang et al. (2018); Rajpurkar et al. (2016). The now-common approach for employing these systems using transfer learning is to (1) pretrain a large language model (LM), (2) replace the top layer of the LM with a task-specific layer, and (3) finetune the entire model on a (usually relatively small) labeled dataset. Following this pattern, Peters et al. (2018), Howard and Ruder (2018), Radford et al. (2019), and Devlin et al. (2018) broadly outperform standard task-specific NLU models (e.g. CNNs/LSTMs), which are initialized from scratch (or only from word embeddings) and trained on the available labeled data.

Notably, transfer learning with LMs vastly outperforms training task-specific models from scratch in low-data regimes. For example, GPT-2 is capable of generating coherent text in a particular style (e.g. poetry, Java code, questions and answers) when conditioned on only a handful of sentences of that style Radford et al. (2019). Similarly, on discriminative tasks such as question answering, BERT reaches accuracies comparable to previous task-specific models with orders of magnitude less labeled data Devlin et al. (2018).

At the same time, however, these large language models are extremely unwieldy. The largest versions of GPT-2 and BERT have over 1.5B and 340M parameters, respectively; it is challenging to use either of these models on a modern GPU (with 12GB of VRAM) and nearly impossible to deploy them on mobile or embedded devices. Thus, there is a strong need for efficient task-specific models that can leverage the knowledge from large pretrained models while remaining highly compressed.

Figure 1: Our proposed generation-distillation training procedure. First, we use a large language model to augment our set of training examples; second, we train our student via distillation with a large language model-based classifier. In the diagram, green blocks indicate models and purple blocks indicate text data.

In this project, we attempt to bridge this gap for the task of low-resource text classification. We propose a new approach, called generation-distillation, to improve the training of small, task-specific text classification models by utilizing multiple large pretrained language models. First, we use a large LM (GPT-2) to generate text in the style of our training examples, augmenting our data with unlabeled synthetic examples. Second, we use the synthetic examples to distill a second large LM (BERT), which has already been finetuned for classification, into a small task-specific model (CNN).

In our experiments, we show that this procedure delivers significant gains over a standard distillation approach in low-data regimes. Specifically, on low-data versions of three widely adopted text classification datasets (AG News, DBPedia, Yahoo Answers), we obtain 98% of BERT’s performance with 300x fewer parameters. Moreover, compared to prior work on distilling BERT Chia et al. (2018) on these datasets, we outperform past approaches while using 3x fewer parameters.

2 Related Work

| Model | Params (1000s) | AG News | DBPedia | Yahoo Answers |
| --- | --- | --- | --- | --- |
| Baseline - TFIDF + SVM Ramos and others (2003) | 18.1 | 81.9 | 94.1 | 54.5 |
| Baseline - FastText Joulin et al. (2016) | N/A | 75.2 | 91.0 | 44.9 |
| BERT-Large | 340,000 | 89.9 | 97.1 | 67.0 |
| Chia et al. (2018) - BlendCNN* | 3617 | 87.6 | 94.6 | 58.3 |
| Chia et al. (2018) - BlendCNN + Dist* | 3617 | 89.9 | 96.0 | 63.4 |
| Ours (Kim-style) | 1124 | 85.7 | 94.3 | 62.4 |
| Ours (Res-style) | 1091 | 86.2 | 94.7 | 60.9 |
| Ours + Dist (Kim-style) | 1124 | 86.9 | 95.0 | 62.9 |
| Ours + Dist (Res-style) | 1091 | 87.3 | 95.4 | 62.2 |
| Ours + Gen-Dist (Kim-style) | 1124 | 89.9 | 96.3 | 64.2 |
| Ours + Gen-Dist (Res-style) | 1091 | 89.8 | 96.0 | 65.0 |

Table 1: (Results) A comparison of model size and accuracy on text classification datasets. Bold font indicates best accuracy and italics+underline indicates second-best accuracy. Generation-distillation broadly improves small model performance over distillation, which in turn broadly improves performance over training from scratch. * indicates results from other papers.

Designed to produce contextual word embeddings, large language models (LMs) build upon the now-classic idea of using pretrained word embeddings to initialize the first layer of deep natural language processing models Collobert et al. (2011). Early proponents of contextual word vectors, including CoVe, ULMFit, and ELMo McCann et al. (2017); Howard and Ruder (2018); Peters et al. (2018), extracted word representations from the activations of LSTMs, which were pretrained for either machine translation (CoVe) or for language modeling (ULMFit, ELMo).

Recent work has adopted the transformer architecture for large-scale language representation. BERT Devlin et al. (2018) trains a transformer using masked language modeling and next sentence prediction objectives, giving state-of-the-art performance across NLU tasks. GPT/GPT-2 Radford et al. (2019) trains a transformer with a unidirectional language modeling objective, showing the ability to generate impressively coherent text.

Due to the unwieldy size of these models, a line of recent research has investigated how to best compress these models Tang et al. (2019). In the most popular of these approaches, knowledge distillation Hinton et al. (2015), the outputs of a larger “teacher” model are used to train a smaller “student” model. These outputs may contain more information than is available in the true label, helping bring the performance of the student closer to that of the teacher. On the task of text classification, Tang et al. (2019) and Chia et al. (2018) both recently showed that it is possible to compress transformer-based LMs into CNNs/LSTMs with fewer parameters, at the cost of a small (but nontrivial) drop in accuracy.

Our project builds on prior work in multiple ways. When performing generation-distillation, we employ a finetuned GPT-2 Radford et al. (2019) as our generator and a finetuned BERT Devlin et al. (2018) as our teacher classifier. Additionally, the distillation component of our generation-distillation approach is similar to the method used in Chia et al. (2018), but with a different loss function (KL divergence in place of mean absolute error).

3 Methodology

As shown in Figure 1, our generation-distillation approach is divided into three steps: finetuning, generation and distillation.

3.1 Finetuning

The first step in our approach involves finetuning two different large LMs on our small task-specific dataset. First, we finetune a generative model (in our case, GPT-2) using only the text of the dataset. This model is used to generate new synthetic examples in the generation step. Second, we finetune a large LM-based classifier (in our case, BERT with an added classification head) using both the text and the labels of the dataset. This model is used as the teacher in the distillation step.

3.2 Generation

In the generation step, we use a large generative LM, finetuned in the first step, to augment our training dataset with synthetic examples. Specifically, we use GPT-2 to generate new sentences in the style of our training dataset and add these to our training dataset. We do not have labels for these generated sentences, but labels are not necessary because we train with distillation; our goal in generating synthetic examples is not to improve the large LM-based classifier, but rather to improve our ability to distill a large LM-based classifier into a small task-specific classifier.

3.3 Distillation

We combine both the real training examples and our synthetic examples into one large training set for distillation. We distill a large LM-based teacher classifier, finetuned in the first step, into our smaller student model via standard distillation as in Hinton et al. (2015). For our loss function, like Hinton et al. (2015), we use the KL divergence between the teacher logits and the student logits; this differs from Chia et al. (2018), who use the mean absolute error between the logits.
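As a rough illustration of the distillation loss described above, the following sketch computes a KL divergence between the class distributions implied by teacher and student logits. It is not the authors' code: the function names, the direction KL(teacher || student), the absence of a softmax temperature, and the averaging over a batch are all assumptions made for this example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def distillation_kl(teacher_logits, student_logits):
    """KL(teacher || student) between the two softmax distributions.

    Assumption: the teacher distribution is the reference distribution,
    and no temperature scaling is applied to either set of logits.
    """
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def batch_distillation_loss(teacher_batch, student_batch):
    """Mean per-example KL over a batch of (teacher, student) logit pairs."""
    losses = [distillation_kl(t, s) for t, s in zip(teacher_batch, student_batch)]
    return sum(losses) / len(losses)
```

Because the target is the teacher's full output distribution, this loss can be evaluated on unlabeled synthetic examples exactly as on labeled ones, which is what allows the generated text to be used without ground-truth labels.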

4 Experiments

4.1 Data

We perform text classification on three widely-used datasets: AG News, DBPedia, and Yahoo Answers (Gulli, ; Auer et al., 2007; Labrou and Finin, 1999). For purposes of comparison, we select our training set using the same procedure as Chia et al. (2018), such that the training set contains 100 examples from each class. For the generation-distillation experiments, we use GPT-2 to generate synthetic training examples on AG News and synthetic training examples on DBPedia and Yahoo Answers. Combining these with the , and original (labeled) examples yields a total of and examples on AG News, DBPedia, and Yahoo Answers, respectively.

4.2 Finetuning Details and Examples

We finetune GPT-2 345M using Neil Shepperd’s fork of GPT-2:

Finetuning is performed for a single epoch with a learning rate of

with the Adam optimizer. We use batch size 1 and gradient checkpointing in order to train on a single GPU with 12GB of VRAM. We choose to train for only 1 epoch after examining samples produced by models with different amounts of finetuning; due to the small size of the dataset relative to the number of parameters in GPT-2, finetuning for more than 1 epoch results in significant dataset memorization.

For sampling, we perform standard sampling (i.e. sampling from the full output distribution, not top-p or top-k sampling) with temperature parameter . Although we do not use top-k or top-p sampling, we believe it would be interesting to compare the downstream effect of different types of sampling in the future.
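The sampling scheme described above (drawing from the full output distribution after temperature scaling, with no top-k or top-p truncation) can be sketched for a single decoding step as follows. This is a minimal illustration rather than the GPT-2 implementation: the function name is hypothetical, the logits would come from the language model in practice, and the temperature is left as a parameter because its value is not given in this text.

```python
import math
import random


def sample_with_temperature(logits, temperature, rng):
    """Sample one token index from softmax(logits / temperature).

    The whole vocabulary stays eligible: no top-k or top-p filtering.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(z - m) for z in scaled]
    # random.Random.choices normalizes the weights internally.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

Lower temperatures concentrate probability mass on the highest-scoring tokens, while a temperature of 1 samples from the model's unmodified distribution.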

In Supplementary Table 3, we show examples of synthetic training texts generated by sampling from the finetuned GPT-2 model, for both DBPedia and Yahoo Answers.

In Supplementary Table 4, we show two synthetic training texts along with their nearest neighbors in the training set. Nearest neighbors were calculated by ranking all examples from the training dataset (1400 examples) according to cosine similarity of TF-IDF vectors. As can be seen in the example in the right column, the GPT-2 language model has memorized some of the entities in the training dataset (i.e. the exact words “Ain Dara Syria”), but provides a novel description of the entity. This novel description is factually incorrect, but it may still be helpful in training a text classification model in a low-resource setting, because the words the model generates (i.e. “Syria”, “Turkey”, “Karzahayel”) are broadly related to the original topic/label. For example, they may help the model learn the concept of the class “village”, which is the label of Nearest Neighbor 1.
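The nearest-neighbor lookup described above can be sketched in plain Python as follows. The whitespace tokenizer, the smoothed IDF formula, and all function names are assumptions chosen for this example; the text does not specify which TF-IDF variant or tooling was used.

```python
import math
from collections import Counter


def tokenize(text):
    """Lowercase whitespace tokenizer (an assumption for this sketch)."""
    return text.lower().split()


def fit_idf(corpus_tokens):
    """Smoothed inverse document frequency for every term in the corpus."""
    n = len(corpus_tokens)
    doc_freq = Counter()
    for tokens in corpus_tokens:
        doc_freq.update(set(tokens))
    return {t: math.log((1 + n) / (1 + c)) + 1.0 for t, c in doc_freq.items()}


def tfidf_vector(tokens, idf):
    """Sparse TF-IDF vector; terms unseen in the corpus are ignored."""
    counts = Counter(t for t in tokens if t in idf)
    return {t: c * idf[t] for t, c in counts.items()}


def cosine(u, v):
    """Cosine similarity between two sparse vectors stored as dicts."""
    dot = sum(w * v.get(t, 0.0) for t, w in u.items())
    norm_u = math.sqrt(sum(w * w for w in u.values()))
    norm_v = math.sqrt(sum(w * w for w in v.values()))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def nearest_neighbors(query, training_texts, k=1):
    """Return the k training texts most similar to the query."""
    corpus = [tokenize(t) for t in training_texts]
    idf = fit_idf(corpus)
    vectors = [tfidf_vector(tokens, idf) for tokens in corpus]
    q = tfidf_vector(tokenize(query), idf)
    ranked = sorted(range(len(training_texts)),
                    key=lambda i: cosine(q, vectors[i]), reverse=True)
    return [training_texts[i] for i in ranked[:k]]
```

Calling `nearest_neighbors(synthetic_text, training_texts, k=1)` for each generated example gives a quick way to check how closely the generator is copying its finetuning data.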

Figure 2: Accuracy of the final distilled model as a function of the number of synthetic training examples generated by GPT-2. Error bars show the standard deviation of accuracies on five separate runs. The same GPT-2 model (trained on 100 examples per class, or a total of 1000 examples) was used to generate all synthetic texts.

Hard Labeling vs. Distillation on Generated Examples (Yahoo Answers)

| | Hard Labeling with BERT | Distillation with BERT |
| --- | --- | --- |
| Accuracy | 62.9 ± 0.22 | 64.2 ± 0.13 |

Table 2: A comparison of hard labeling and distillation for labeling the synthetic examples produced by our generator network. We report the mean and standard error of the student (Kim) model accuracy across 5 random restarts on the Yahoo Answers dataset. Generation with distillation significantly outperforms generation with hard labeling.

4.3 Student Models & Optimization

We experiment with two main CNN architectures. The first is a standard CNN architecture from Kim (2014). The second is a new CNN based on ResNet He et al. (2016). This “Res-style” model has 3 hidden layers, each with hidden size 100, and dropout probability

. We use multiple models to demonstrate that our performance improvements over previous approaches are not attributable to architectural changes, and to show that our approach generalizes across architectures.

We train the CNNs using Adam (Kingma and Ba, 2014; Loshchilov and Hutter, 2017) with learning rate . Additionally, the CNNs both use 100-dimensional pretrained subword embeddings (Heinzerling and Strube, 2018), which are finetuned during training.

4.4 Results

We report the performance of our trained models in Table 1.

When trained with standard distillation, our KimCNN and ResCNN models perform as would be expected given the strong results in Chia et al. (2018). Our models perform slightly worse than the 8-layer BlendCNN from Chia et al. (2018) on AG News and DBPedia, while performing slightly better on Yahoo Answers. Standard distillation improves their performance, but there remains a significant gap between the CNNs and the BERT-Large based classifier. Training with the proposed generation-distillation approach significantly reduces the gap between the CNNs and BERT-Large; across all datasets, the model trained with generation-distillation matches or exceeds both the model trained with standard distillation and the BlendCNN.
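The parameter savings behind these comparisons can be checked directly from the parameter counts in Table 1. The short script below is only a worked arithmetic example; the dictionary and function names are introduced here for illustration.

```python
# Parameter counts in thousands, copied from Table 1.
PARAMS_THOUSANDS = {
    "BERT-Large": 340_000,
    "BlendCNN": 3617,
    "Ours (Kim-style)": 1124,
    "Ours (Res-style)": 1091,
}


def compression_ratio(large_model, small_model):
    """How many times more parameters large_model has than small_model."""
    return PARAMS_THOUSANDS[large_model] / PARAMS_THOUSANDS[small_model]


if __name__ == "__main__":
    for student in ("Ours (Kim-style)", "Ours (Res-style)"):
        for reference in ("BERT-Large", "BlendCNN"):
            ratio = compression_ratio(reference, student)
            print(f"{reference} vs {student}: {ratio:.1f}x")
```

Both student models come out at a little over 300 times smaller than BERT-Large and a little over 3 times smaller than BlendCNN.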

4.5 Ablation

In Figure 2, we show how the accuracy of the final distilled model varies with the number of synthetic training examples generated by GPT-2. The distilled model is trained entirely on synthetic examples, without ever seeing the original data. The model shows strong performance (60% accuracy) with as few as 500 generated training examples, or per class. Moreover, model performance continues to increase with more generated training examples, up to .

In Table 2, we compare two different methods of labeling the synthetic examples produced by our generator network (GPT-2): hard labeling and distillation. Hard labeling refers to taking the maximum-probability class according to our finetuned BERT model as the label for each generated example and using a standard cross entropy loss function. Distillation refers to using the probability distribution output by BERT as the label for each generated example and using a KL divergence loss function. Put differently, in the former we use BERT to generate labels, whereas in the latter we use BERT to perform distillation. We find that generation and distillation outperforms generation and hard labeling by a significant margin, consistent with previous work on knowledge distillation Hinton et al. (2015).
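To make the difference between the two labeling schemes concrete, the sketch below builds both kinds of target from one teacher distribution and scores a student prediction against each. The probabilities are invented toy values and the function names are hypothetical; this is an illustration of the idea, not the training code used in the experiments.

```python
import math


def hard_label(teacher_probs):
    """Index of the maximum-probability class under the teacher."""
    return max(range(len(teacher_probs)), key=teacher_probs.__getitem__)


def cross_entropy(student_probs, label):
    """Cross entropy of the student against a single hard label."""
    return -math.log(student_probs[label])


def kl_divergence(teacher_probs, student_probs):
    """KL(teacher || student); assumes every student probability is positive."""
    return sum(p * math.log(p / q)
               for p, q in zip(teacher_probs, student_probs) if p > 0)


if __name__ == "__main__":
    teacher = [0.6, 0.3, 0.1]    # toy teacher distribution over 3 classes
    student_a = [0.6, 0.3, 0.1]  # matches the teacher on every class
    student_b = [0.6, 0.1, 0.3]  # same top class, other two swapped
    label = hard_label(teacher)
    for name, student in (("A", student_a), ("B", student_b)):
        print(name, cross_entropy(student, label), kl_divergence(teacher, student))
```

In this toy case the hard-label loss cannot tell the two students apart, since both assign 0.6 to the top class, whereas the KL loss is zero only for the student that also matches the teacher's probabilities on the remaining classes. This is one way to see why soft targets can carry more information than hard labels.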

5 Conclusion

In this work, we present a new approach to compressing natural language understanding models in low-data regimes. Our approach leverages large finetuned language models in two ways: (1) to generate new (unlabeled) training examples, and (2) to distill their knowledge into a small network using these examples. Across three low-resource text classification datasets, we achieve comparable performance to BERT while using 300x fewer parameters, and we outperform prior approaches to distillation for text classification while using 3x fewer parameters. Although we focus on text classification in this paper, our proposed method may be extended to a host of other natural language understanding tasks in low-data settings, such as question answering or extractive summarization.


References

  • S. Auer, C. Bizer, G. Kobilarov, J. Lehmann, R. Cyganiak, and Z. Ives (2007) DBpedia: a nucleus for a web of open data. In The semantic web, pp. 722–735.
  • Y. K. Chia, S. Witteveen, and M. Andrews (2018) Transformer to CNN: label-scarce distillation for efficient text classification.
  • R. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu, and P. P. Kuksa (2011) Natural language processing (almost) from scratch. CoRR abs/1103.0398.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • A. Gulli.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778.
  • B. Heinzerling and M. Strube (2018) BPEmb: tokenization-free pre-trained subword embeddings in 275 languages. In Proceedings of the Eleventh International Conference on Language Resources and Evaluation (LREC 2018), Miyazaki, Japan. ISBN 979-10-95546-00-9.
  • G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
  • J. Howard and S. Ruder (2018) Universal language model fine-tuning for text classification. ACL 2018.
  • A. Joulin, E. Grave, P. Bojanowski, M. Douze, H. Jégou, and T. Mikolov (2016) FastText.zip: compressing text classification models. arXiv preprint arXiv:1612.03651.
  • Y. Kim (2014) Convolutional neural networks for sentence classification. arXiv preprint arXiv:1408.5882.
  • D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
  • Y. Labrou and T. Finin (1999) Yahoo! as an ontology: using Yahoo! categories to describe documents. In Proceedings of the eighth international conference on Information and knowledge management, pp. 180–187.
  • I. Loshchilov and F. Hutter (2017) Fixing weight decay regularization in Adam. arXiv preprint arXiv:1711.05101.
  • B. McCann, J. Bradbury, C. Xiong, and R. Socher (2017) Learned in translation: contextualized word vectors. In Advances in Neural Information Processing Systems, pp. 6294–6305.
  • M. E. Peters, M. Neumann, M. Iyyer, M. Gardner, C. Clark, K. Lee, and L. Zettlemoyer (2018) Deep contextualized word representations. NAACL 2018.
  • A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, and I. Sutskever (2019) Language models are unsupervised multitask learners. OpenAI Blog 1, pp. 8.
  • P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100,000+ questions for machine comprehension of text. CoRR abs/1606.05250.
  • J. Ramos et al. (2003) Using TF-IDF to determine word relevance in document queries.
  • S. Ruder (2018) NLP’s ImageNet moment has arrived.
  • R. Tang, Y. Lu, L. Liu, L. Mou, O. Vechtomova, and J. Lin (2019) Distilling task-specific knowledge from BERT into simple neural networks. arXiv preprint arXiv:1903.12136.
  • A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. R. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.