Transformer on a Diet

02/14/2020 ∙ by Chenguang Wang, et al. ∙ Amazon

The Transformer has been widely used thanks to its ability to capture sequence information efficiently. However, recent developments such as BERT and GPT-2 deliver only heavy architectures, with a focus on effectiveness. In this paper, we explore three carefully designed light Transformer architectures to find out whether a Transformer with less computation can produce competitive results. Experimental results on language model benchmark datasets hint that such a trade-off is promising: the light Transformer reduces parameters by up to 70% while obtaining perplexity competitive with the standard Transformer. The source code is publicly available.




1 Introduction

The Transformer has shown its effectiveness in modeling sequence information thanks to the combination of the self-attention mechanism and positional encoding. Variants of the Transformer architecture, such as BERT Devlin et al. (2018) and GPT-2 Radford et al. (2018, 2019), have obtained state-of-the-art results across a wide range of NLP tasks, including the GLUE benchmark Wang et al. (2018) and question answering datasets such as SQuAD Rajpurkar et al. (2016). However, the Transformer is by nature a fully-connected feed-forward neural network with heavy computation. BERT and GPT-2 are built as stacks of Transformer blocks: the largest GPT-2 stacks 48 Transformer blocks and contains 1542M parameters, and BERT-large stacks 24 Transformer blocks with 340M parameters, so the computational burden of the fully connected Transformer grows heavier still. A side effect in industrial applications is that such huge models are harder to deploy. A light version of the standard Transformer architecture is therefore expected to relieve the heavy computation and compress the model, easing deployment in real-world applications.

In this paper, we carefully design several light Transformer architectures. The intuition behind them is to preserve the Transformer connections that are useful for capturing the essential sequence information, while omitting those with less impact. In particular, we explore two directions: 1) preserving the connections that are useful for capturing long-range dependencies, where we adapt the idea of dilated convolutions Yu and Koltun (2015) to keep the connections that extend the effective history of the context; and 2) preserving the connections that are essential for capturing local context, where we leverage cascade connections that can intensively incorporate local context information in a flexible manner.

The contributions of the paper are twofold:

  • We explore three light Transformer architectures that preserve the necessary connections of the standard Transformer. We show that, compared to the standard Transformer, the light Transformer architectures reduce the computation from quadratic to linear in the sequence length.

  • We conduct experiments on two benchmark datasets for language modeling, one of the most traditional sequence modeling tasks. The results indicate that the lightest architecture reduces the parameters of the standard Transformer by 70% while performing competitively with it.

2 Light Transformers

We describe the three proposed light Transformer architectures in this section.

2.1 Background

Revisiting the Transformer architecture. The Transformer Vaswani et al. (2017) consists of an encoder and a decoder. Since we mainly focus on the sequence generation problem, we briefly describe only the Transformer decoder structure, which we call the full Transformer. For the sake of clarity, in the following sections "Transformer" refers to the Transformer decoder unless stated otherwise. As illustrated in Figure 1(a), the full Transformer block contains two sub-layers: 1) a masked multi-head attention layer; and 2) a position-wise fully connected feed-forward network. In addition, a residual connection He et al. (2016) is applied around each of the two sub-layers, followed by layer normalization Ba et al. (2016).

Transformer computation complexity analysis. For each Transformer block, let $n$ be the length of the sequence and $h$ the size of the hidden states; the computation of each Transformer block is then $O(n^2 \cdot h)$.
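As a rough illustration of where the quadratic term comes from, the sketch below counts the query-key pairs that one masked (causal) self-attention layer scores over a sequence of length $n$, and multiplies by the hidden size $h$. The helpers are hypothetical and written for illustration only; they ignore the projections and the feed-forward sub-layer, whose cost grows with $n \cdot h^2$ rather than with $n^2$.

```python
def causal_attention_pairs(n: int) -> int:
    """Count (query, key) pairs scored by one masked self-attention layer.

    With a causal mask, position i (0-indexed) attends to positions 0..i,
    i.e. i + 1 keys, so the total is n * (n + 1) / 2, which is O(n^2).
    """
    return sum(i + 1 for i in range(n))


def full_attention_cost(n: int, h: int) -> int:
    """Rough multiply count for the attention scores: one length-h dot product per pair."""
    return causal_attention_pairs(n) * h


if __name__ == "__main__":
    # Doubling the sequence length roughly quadruples the number of pairs.
    for n in (70, 140, 280):
        print(n, causal_attention_pairs(n), full_attention_cost(n, h=320))
```

For $n = 70$, the sequence length used in the experiments, this gives 2,485 pairs per layer.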

We refer to the following three architectures, which require less computation than the full Transformer, as light Transformers.

2.2 Dilated Transformer

Model assumption. The key strength of the Transformer is that the combination of self-attention and positional encoding is able to capture long-term dependency in the sequence. Thus we need to preserve the long-term dependency when lightening up the Transformer blocks. Inspired by the dilated convolutions Yu and Koltun (2015), we introduce the idea of dilated Transformer to enable an exponentially large receptive field in the Transformer scenario.

Model architecture. The model architecture is illustrated in Figure 1(b): the sub-layers are the same as in the full Transformer, but the connections across the sequence are dilated. To enable this, we introduce $d$ as the dilation factor and $k$ as the filter size. Following the common usage in dilated convolutions, we increase $d$ exponentially with the depth of the Transformer-based network, i.e., $d = O(2^l)$ at level $l$ of the Transformer, to enlarge the receptive field. This ensures that some filter hits each input within the effective history, while the deep Transformer architecture still allows an extremely large effective history.

Computation complexity analysis. In each dilated Transformer block, only $k$ nodes are needed to compute the output of the current node, so with sequence length $n$ and hidden size $h$ the computation complexity is $O(n \cdot k \cdot h)$. This is significantly lower than the cost of the full Transformer when $k$ is significantly smaller than $n$.
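The dilated connection pattern can be sketched as a causal neighbour list. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes the dilation at level $l$ is exactly $d = 2^l$, that a filter of size $k$ covers positions $i, i-d, \ldots, i-(k-1)d$, and that keys falling before the start of the sequence are dropped. All function names are hypothetical.

```python
from typing import List


def dilated_keys(i: int, level: int, k: int, base: int = 2) -> List[int]:
    """Positions that query position i attends to at the given level.

    Assumption: dilation d = base ** level and the filter of size k covers
    i, i - d, ..., i - (k - 1) * d; keys before position 0 are dropped so
    the layer stays causal.
    """
    d = base ** level
    return [i - j * d for j in range(k) if i - j * d >= 0]


def dilated_pairs(n: int, level: int, k: int, base: int = 2) -> int:
    """Total (query, key) pairs in one dilated layer: at most n * k."""
    return sum(len(dilated_keys(i, level, k, base)) for i in range(n))


def full_pairs(n: int) -> int:
    """Pairs in a full causal layer, n * (n + 1) / 2, for comparison."""
    return n * (n + 1) // 2


def receptive_field(num_levels: int, k: int, base: int = 2) -> int:
    """Number of distinct past offsets (including 0) reachable from the top
    node after stacking num_levels dilated layers, ignoring the sequence start."""
    offsets = {0}
    for level in range(num_levels):
        d = base ** level
        offsets = {o + j * d for o in offsets for j in range(k)}
    return len(offsets)


if __name__ == "__main__":
    print(dilated_keys(10, level=2, k=3))                   # keys at dilation 4
    print(dilated_pairs(70, level=3, k=2), full_pairs(70))  # dilated vs full
    print([receptive_field(L, k=2) for L in range(1, 5)])   # doubles per level
```

Each dilated layer scores at most $n \cdot k$ pairs versus $n(n+1)/2$ for a full causal layer, while the receptive field still grows exponentially with depth; multiplying each pair by the hidden size gives the linear-versus-quadratic gap discussed above.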

2.3 Dilated Transformer with Memory

Model assumption. Similar to the idea of dilated Transformer, we use dilated connections to preserve the long-range dependency in the sequence. Additionally, in dilated Transformer with memory, we try to cache more local contexts by memorizing the nodes in the previous dilated connections.

Model architecture. Figure 1(c) illustrates the model architecture, where the sub-layers are again the same as in the full Transformer. As in the dilated Transformer, we use the dilation factor $d$ and the filter size $k$ to construct the dilated connections. However, the dilated connections of the previous layer are also preserved. This keeps the large effective history provided by the dilation while adding richer local history, and could potentially preserve the connections that are necessary for decoding.
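One possible reading of preserving the dilated connections of the previous layer is that a query at level $l$ attends to the union of its own dilated keys (dilation $2^l$) and the keys the previous level used (dilation $2^{l-1}$). The sketch below encodes that reading; the exact dilation schedule, the one-level memory, and the function names are all assumptions made for illustration.

```python
from typing import List, Set


def _dilated(i: int, d: int, k: int) -> Set[int]:
    """Causal dilated keys i, i - d, ..., i - (k - 1) * d that are >= 0."""
    return {i - j * d for j in range(k) if i - j * d >= 0}


def memory_keys(i: int, level: int, k: int, base: int = 2) -> List[int]:
    """Keys for query i at `level` in the assumed dilated-with-memory block:
    this level's dilated keys plus those preserved from the previous level."""
    keys = _dilated(i, base ** level, k)
    if level > 0:
        keys |= _dilated(i, base ** (level - 1), k)
    return sorted(keys)


def memory_pairs(n: int, level: int, k: int, base: int = 2) -> int:
    """Total pairs for one layer: at most 2 * k per query, still linear in n."""
    return sum(len(memory_keys(i, level, k, base)) for i in range(n))


if __name__ == "__main__":
    for level in range(4):
        print(level, memory_keys(40, level, k=2), memory_pairs(70, level, k=2))
```

Under this reading the extra connections at most double the per-query cost, so the block stays linear in the sequence length while seeing denser local history than the plain dilated block.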

Computation complexity analysis. In each dilated Transformer with memory block, the computation of the connections in the previous layer would add to the current computation, which results in , where indicates the extra connections in the previous layer. If the sequence length is infinite, then .

Figure 1: Transformer architectures. (a) Full (standard) Transformer; (b) dilated Transformer; (c) dilated Transformer with memory; (d) cascade Transformer. Panels (b)-(d) show the proposed light Transformers.

2.4 Cascade Transformer

Model assumption. Instead of exploiting the dilation idea, we explore cascade connections, which incorporate exponentially more local connections as the network gets deeper. This lets us study how local connections at different depths of the network contribute to the results.

Model architecture. Figure 1(d) illustrates the cascade Transformer architecture, where the sub-layers are again the same as in the full Transformer. We introduce the base window size $b$ and the cardinal number $m$; the number of cascade connections at level $l$ of the Transformer is then $b \cdot m^{l}$. This lets us control the shape of the cascade across the levels of the Transformer, which gives the Transformer the flexibility to learn from the cascade connections.

Computation complexity analysis. With sequence length $n$ and hidden size $h$, the computation cost of each cascade Transformer block is $O(n \cdot b \cdot m^{l} \cdot h)$. Compared to the full Transformer, the complexity is still smaller since $b \cdot m^{l} < n$.
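A minimal sketch of the cascade pattern, with base window size $b$ and cardinal number $m$, assuming each query at level $l$ attends to the most recent $b \cdot m^{l}$ positions (itself included). This causal local window that widens by a factor of $m$ per level is an illustrative interpretation, and the function names are hypothetical rather than taken from the authors' code.

```python
from typing import List


def cascade_window(level: int, b: int, m: int) -> int:
    """Assumed window size at a level: base window b scaled by m ** level."""
    return b * m ** level


def cascade_keys(i: int, level: int, b: int, m: int) -> List[int]:
    """The cascade_window(level, b, m) most recent positions up to and including i."""
    w = cascade_window(level, b, m)
    return list(range(max(0, i - w + 1), i + 1))


def cascade_pairs(n: int, level: int, b: int, m: int) -> int:
    """Total (query, key) pairs in one layer: at most n * b * m ** level."""
    return sum(len(cascade_keys(i, level, b, m)) for i in range(n))


if __name__ == "__main__":
    # Hypothetical settings b = 4, m = 2 over 3 levels and n = 70 tokens.
    for level in range(3):
        print(level, cascade_window(level, 4, 2), cascade_pairs(70, level, 4, 2))
```

Once the window reaches the sequence length, the layer degenerates into full causal attention, so under this reading the savings hold only while $b \cdot m^{l} < n$.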

Model Computation Complexity
Table 1: Computation complexities of different Transformer architectures. Full: full Transformer; Dilated: dilated Transformer; Dilated-Memory: dilated Transformer with memory; Cascade: cascade Transformer; $n$ is the length of the sequence, $h$ is the size of the hidden state, $k$ is the filter size, $b$ is the base window size, and $m$ is the cardinal number.

3 Transformer Language Model

We select language modeling as the task for evaluating the proposed Transformer architectures, since it is one of the fundamental NLP tasks. In this section, we describe how the different Transformer blocks are adapted to language modeling.

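Given a corpus of tokens $\mathcal{U} = \{u_1, \ldots, u_N\}$, the objective of the language model is to maximize the likelihood in Eq. 1:

$$\mathcal{L}(\mathcal{U}) = \sum_{i} \log P(u_i \mid u_{i-n}, \ldots, u_{i-1}; \Theta) \qquad (1)$$

where $n$ is the context length and $\Theta$ denotes the model parameters. The Transformer (decoder) blocks are then used to generate the output distribution over the vocabulary, as indicated in Eq. 4:

$$h_0 = U W_e + W_p \qquad (2)$$
$$h_l = \mathrm{Block}(h_{l-1}), \quad \forall l \in [1, L] \qquad (3)$$
$$P(u) = \mathrm{softmax}(h_L W_e^{\top}) \qquad (4)$$

where $\mathrm{Block}$ can be replaced with any of the three proposed Transformer architectures, $h_l$ is the hidden output of the $l$-th layer, $L$ is the number of layers, $U$ is the matrix of context tokens, $W_e$ is the word embedding matrix, and $W_p$ is the positional embedding matrix.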
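The language model forward pass (token embeddings from $W_e$ plus positional embeddings from $W_p$, a stack of interchangeable blocks, and a softmax over $h_L W_e^{\top}$) can be traced with a toy, dependency-free sketch. It assumes tied input/output embeddings (the same $W_e$ at input and output), represents each block as a plain function on a list of hidden vectors, and uses an identity block purely as a placeholder; none of the names come from the paper's released code.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]
Block = Callable[[Matrix], Matrix]


def softmax(z: List[float]) -> List[float]:
    """Numerically stable softmax over one score vector."""
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def lm_forward(tokens: List[int], W_e: Matrix, W_p: Matrix,
               blocks: List[Block]) -> Matrix:
    """Return one next-token distribution per input position."""
    # h_0: token embedding rows of W_e plus positional embedding rows of W_p.
    h = [[e + p for e, p in zip(W_e[t], W_p[pos])]
         for pos, t in enumerate(tokens)]
    # h_l = Block(h_{l-1}); a full, dilated or cascade block would plug in here.
    for block in blocks:
        h = block(h)
    # P(u) = softmax(h_L W_e^T): score every vocabulary row against h_L.
    return [softmax([sum(a * b for a, b in zip(row, emb)) for emb in W_e])
            for row in h]


if __name__ == "__main__":
    W_e = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy vocabulary of 3, h = 2
    W_p = [[0.0, 0.0], [0.0, 0.0]]              # zero positions for clarity
    identity: Block = lambda h: h                # placeholder block
    for dist in lm_forward([0, 2], W_e, W_p, [identity]):
        print([round(p, 3) for p in dist])
```

Because the block is just a function on hidden states, swapping the placeholder for a different attention pattern leaves the embedding and output layers untouched, which is the property the three light architectures rely on.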

4 Experiments

Model            Parameters   PTB Val   PTB Test   WT-2 Val   WT-2 Test
Full             30.0M        109.19    103.72     148.76     140.74
Dilated          8.8M         115.67    110.92     157.67     147.58
Dilated-Memory   11.1M        115.35    110.98     167.35     157.08
Cascade          13.5M        109.16    105.27     145.96     136.02
Table 2: Results comparison (perplexity) of different Transformer language models on PTB and WT-2 data (Val: validation set; Test: test set). Full: full Transformer; Dilated: dilated Transformer; Dilated-Memory: dilated Transformer with memory; Cascade: cascade Transformer.

We compare the light Transformers with the standard Transformer in terms of both results (perplexity) and computation (model size).

4.1 Datasets and Metrics

We evaluate the proposed methods on two widely used language model benchmark datasets. Penn Treebank (PTB): we use the preprocessed version of Mikolov et al. (2010), which contains about 1M tokens. WikiText-2 (WT-2) is a small preprocessed version of Wikipedia containing about 2M training tokens Merity et al. (2016). We use perplexity to evaluate the language models.
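Perplexity is the exponential of the average per-token negative log-likelihood on held-out text. A minimal sketch, assuming the natural-log probability of each target token has already been collected (the function name is hypothetical):

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp(-(1/N) * sum_i log P(u_i | context)), using natural logarithms."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)


if __name__ == "__main__":
    # A model that is uniform over a 10,000-word vocabulary has perplexity 10,000.
    print(perplexity([math.log(1 / 10_000)] * 5))
```

Lower is better: a perplexity of about 110 on PTB means the model is, on average, about as uncertain as a uniform choice among roughly 110 words.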

4.2 Training Details

For a fair comparison, the full Transformer and the light Transformer architectures all use 3 layers, an embedding size of 320, and 16 heads in the multi-head attention. The dropout rate is set to 0.4 on PTB and 0.2 on WT-2. For dilated Transformer and dilated Transformer with memory, , the base of . For cascade Transformer, and . For both the light Transformers and the full Transformer, the hidden size is 2000. Apart from the dropout rate, these settings are shared across the two datasets.

We use truncated back-propagation through time to compute the gradients in all experimental settings. The batch size is 20 and the sequence length is 70 on both datasets. We train with SGD using a learning rate of 10.
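The batching behind truncated back-propagation through time can be sketched as follows: the token stream is split into `batch_size` parallel streams, which are then cut into windows of `seq_len` steps whose targets are the inputs shifted by one token. This mirrors common language-model training loops rather than the authors' released code, and the function names are hypothetical; with the settings above, `batch_size = 20` and `seq_len = 70`.

```python
from typing import Iterator, List, Tuple


def batchify(stream: List[int], batch_size: int) -> List[List[int]]:
    """Split one token stream into batch_size contiguous streams of equal
    length, dropping the tail that does not divide evenly."""
    per_stream = len(stream) // batch_size
    return [stream[b * per_stream:(b + 1) * per_stream] for b in range(batch_size)]


def bptt_windows(streams: List[List[int]],
                 seq_len: int) -> Iterator[Tuple[List[List[int]], List[List[int]]]]:
    """Yield (inputs, targets) windows of up to seq_len steps; each window
    is one truncated segment over which gradients would be computed."""
    length = len(streams[0])
    for start in range(0, length - 1, seq_len):
        end = min(start + seq_len, length - 1)
        inputs = [s[start:end] for s in streams]
        targets = [s[start + 1:end + 1] for s in streams]
        yield inputs, targets


if __name__ == "__main__":
    streams = batchify(list(range(10_000)), batch_size=20)
    windows = list(bptt_windows(streams, seq_len=70))
    print(len(streams), len(streams[0]), len(windows))
```

Keeping the streams contiguous means consecutive windows continue the same text, so each training step sees a fresh 70-token segment of every stream.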

4.3 Results Analysis

We compare the effectiveness of the proposed Transformer architectures with that of the full Transformer. The results in Table 2 show that the cascade Transformer performs close to the full Transformer: its perplexity is on par on PTB (109.16 vs. 109.19 validation, 105.27 vs. 103.72 test) and lower on WT-2 (145.96 vs. 148.76 validation, 136.02 vs. 140.74 test). This suggests that local context is very important for language modeling, and that the cascade Transformer is able to capture meaningful local dependencies.

We also compare the parameter sizes of the light Transformers and the full Transformer. Among all the architectures, the dilated Transformer is the lightest. Although it delivers only moderate results, it saves 70% of the model size compared to the full Transformer (8.8M vs. 30.0M parameters), and its computation could be more efficient. Table 2 also shows the trade-off between parameter size and perplexity on the two datasets, which can guide the choice of Transformer architecture under deployment constraints such as a model size limit, a latency requirement, or a quality target.
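The 70% figure follows directly from the parameter column of Table 2. The short calculation below reproduces the relative savings of each light architecture; the numbers are copied from Table 2 and the helper is only illustrative.

```python
# Parameter counts (millions) and WT-2 test perplexities copied from Table 2.
PARAMS_M = {"Full": 30.0, "Dilated": 8.8, "Dilated-Memory": 11.1, "Cascade": 13.5}
WT2_TEST_PPL = {"Full": 140.74, "Dilated": 147.58,
                "Dilated-Memory": 157.08, "Cascade": 136.02}


def param_reduction(model: str) -> float:
    """Fraction of the full Transformer's parameters saved by `model`."""
    return 1.0 - PARAMS_M[model] / PARAMS_M["Full"]


if __name__ == "__main__":
    for name in ("Dilated", "Dilated-Memory", "Cascade"):
        print(f"{name}: {param_reduction(name):.1%} fewer parameters, "
              f"WT-2 test perplexity {WT2_TEST_PPL[name]} vs {WT2_TEST_PPL['Full']}")
```

The dilated Transformer saves about 70.7% of the parameters, the dilated Transformer with memory about 63.0%, and the cascade Transformer about 55.0%, the last while also lowering WT-2 test perplexity.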

5 Related Work

Transformer architectures have been proposed to process sequence input efficiently. The basic Transformer block consists of a multi-head attention layer and a position-wise fully connected feed-forward network. The original Transformer architecture contains an encoder and a decoder, which share similar structures with 6 layers of Transformer blocks; unlike the encoder, the decoder uses masked multi-head attention in each block to prevent leftward information flow. Recently, stacked Transformer architectures, such as BERT Devlin et al. (2018), GPT(-2) Radford et al. (2018, 2019), and the most recent ones Peters et al. (2018); Wang et al. (2019); Raffel et al. (2019); Liu et al. (2019); Yang et al. (2019), have been proposed and have shown state-of-the-art results on a wide range of NLP tasks, such as the GLUE benchmark Wang et al. (2018) and question answering datasets Rajpurkar et al. (2016). However, these Transformer architectures are heavy and hard to deploy in practice when the environment has resource constraints. Lightened Transformer architectures Ye et al. (2019); Guo (2019) have been proposed to speed up the computation. Our work is aligned with such Transformers but uses even less computation. Compared to ALBERT Lan et al. (2019), the proposed method optimizes the base Transformer and could be further integrated into BERT.

Language models have been studied extensively in NLP. Neural language models have supplanted traditional n-gram models in recent years Bengio et al. (2003); Mnih and Hinton (2007); Mikolov et al. (2010). In particular, recurrent neural networks such as LSTMs Inan et al. (2016); Merity et al. (2017); Melis et al. (2017); Krause et al. (2018) have achieved state-of-the-art results on various benchmark datasets with different regularization techniques and post-training methods Grave et al. (2016); Krause et al. (2018). The mixture of softmaxes Yang et al. (2017) has helped address the low-rank (softmax bottleneck) problem in word prediction. More recently, advanced Transformer architectures such as GPT Radford et al. (2018) and GPT-2 Radford et al. (2019) have been applied to language modeling. Thanks to the efficiency of Transformer computation, these models have been trained on large-scale text corpora and have shown good results across language model datasets. We instead study how to lighten the Transformer; the resulting architectures could be generalized to large Transformer models (e.g., GPT) trained on large corpora to obtain better results.

6 Conclusion

We explore less computationally expensive Transformer architectures. The design principle is to preserve the long- and short-range dependencies in the sequence while using fewer connections. Experiments on language model datasets show that a lightweight Transformer can perform competitively with much improved computational efficiency. We plan to extend these architectures to deeper Transformers and to more tasks Wang et al. (2015, 2016).

Acknowledgements We are grateful to Mu Li, Da Zheng, Haibin Lin, and Leyuan Wang for their helpful inputs on the paper.


  • J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §2.1.
  • Y. Bengio, R. Ducharme, P. Vincent, and C. Jauvin (2003) A neural probabilistic language model. JMLR 3 (Feb), pp. 1137–1155. Cited by: §5.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805. Cited by: §1, §5.
  • E. Grave, A. Joulin, and N. Usunier (2016) Improving neural language models with a continuous cache. CoRR. Cited by: §5.
  • Q. Guo (2019) Star-transformer. arXiv preprint arXiv:1902.09113. Cited by: §5.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §2.1.
  • H. Inan, K. Khosravi, and R. Socher (2016) Tying word vectors and word classifiers: a loss framework for language modeling. arXiv preprint arXiv:1611.01462. Cited by: §5.
  • B. Krause, E. Kahembwe, I. Murray, and S. Renals (2018) Dynamic evaluation of neural sequence models. In ICML, pp. 2771–2780. Cited by: §5.
  • Z. Lan, M. Chen, S. Goodman, K. Gimpel, P. Sharma, and R. Soricut (2019) ALBERT: a lite BERT for self-supervised learning of language representations. arXiv preprint arXiv:1909.11942. Cited by: §5.
  • Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov (2019) RoBERTa: A robustly optimized BERT pretraining approach. CoRR abs/1907.11692. Cited by: §5.
  • G. Melis, C. Dyer, and P. Blunsom (2017) On the state of the art of evaluation in neural language models. arXiv preprint arXiv:1707.05589. Cited by: §5.
  • S. Merity, N. S. Keskar, and R. Socher (2017) Regularizing and optimizing LSTM language models. CoRR. Cited by: §5.
  • S. Merity, C. Xiong, J. Bradbury, and R. Socher (2016) Pointer sentinel mixture models. CoRR. Cited by: §4.1.
  • T. Mikolov, M. Karafiát, L. Burget, J. Černockỳ, and S. Khudanpur (2010) Recurrent neural network based language model. In Eleventh Annual Conference of the International Speech Communication Association, Cited by: §4.1, §5.
  • A. Mnih and G. Hinton (2007) Three new graphical models for statistical language modelling. In ICML, pp. 641–648. Cited by: §5.
  • M. E. Peters, M. Neumann, M. Iyyer, M. Gardner, C. Clark, K. Lee, and L. Zettlemoyer (2018) Deep contextualized word representations. In NAACL, pp. 2227–2237. Cited by: §5.
  • A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever (2018) Improving language understanding by generative pre-training. URL https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf. Cited by: §1, §5, §5.
  • A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, and I. Sutskever (2019) Language models are unsupervised multitask learners. Cited by: §1, §5, §5.
  • C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu (2019) Exploring the limits of transfer learning with a unified text-to-text transformer. CoRR abs/1910.10683. Cited by: §5.
  • P. Rajpurkar, J. Zhang, K. Lopyrev, and P. Liang (2016) SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250. Cited by: §1, §5.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS, pp. 5998–6008. Cited by: §2.1.
  • A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. R. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461. Cited by: §1, §5.
  • C. Wang, M. Li, and A. J. Smola (2019) Language models with transformers. CoRR abs/1904.09408. Cited by: §5.
  • C. Wang, Y. Song, A. El-Kishky, D. Roth, M. Zhang, and J. Han (2015) Incorporating world knowledge to document clustering via heterogeneous information networks. In SIGKDD, pp. 1215–1224. Cited by: §6.
  • C. Wang, Y. Song, H. Li, M. Zhang, and J. Han (2016) Text classification with heterogeneous information network kernels. In AAAI, pp. 2130–2136. Cited by: §6.
  • Z. Yang, Z. Dai, R. Salakhutdinov, and W. W. Cohen (2017) Breaking the softmax bottleneck: A high-rank RNN language model. CoRR. Cited by: §5.
  • Z. Yang, Z. Dai, Y. Yang, J. G. Carbonell, R. Salakhutdinov, and Q. V. Le (2019) XLNet: generalized autoregressive pretraining for language understanding. CoRR abs/1906.08237. Cited by: §5.
  • Z. Ye, Q. Guo, Q. Gan, X. Qiu, and Z. Zhang (2019) BP-transformer: modelling long-range context via binary partitioning. arXiv preprint arXiv:1911.04070. Cited by: §5.
  • F. Yu and V. Koltun (2015) Multi-scale context aggregation by dilated convolutions. arXiv preprint arXiv:1511.07122. Cited by: §1, §2.2.