Learned Token Pruning for Transformers

07/02/2021 ∙ by Sehoon Kim, et al.

A major challenge in deploying transformer models is their prohibitive inference cost, which scales quadratically with the input sequence length. This makes it especially difficult to use transformers for processing long sequences. To address this, we present a novel Learned Token Pruning (LTP) method that removes redundant tokens as the data passes through the different layers of the transformer. In particular, LTP prunes tokens with an attention score below a threshold value, which is learned during training. Importantly, our threshold-based method avoids algorithmically expensive operations such as top-k token selection, which are used in prior token pruning methods, and also leads to structured pruning. We extensively test the performance of our approach on multiple GLUE tasks and show that our learned threshold-based method consistently outperforms the prior state-of-the-art top-k token-based method by up to about 2% in accuracy. Our preliminary results show up to 1.4x and 1.9x throughput improvement on Tesla T4 GPU and Intel Haswell CPU, respectively, with less than 1% accuracy drop (and up to 2.1x FLOPs reduction). Our code has been developed in PyTorch and has been open-sourced.
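To make the threshold idea concrete, here is a minimal pure-Python sketch of pruning by attention score. It is not the authors' implementation. The abstract only says that tokens with an attention score below a threshold are pruned, so the score definition used here (the mean attention probability a token receives, averaged over heads and query positions) is an assumption, as are the names `token_importance` and `prune_by_threshold`. The threshold is a fixed constant here, whereas in LTP it is learned during training.

```python
def token_importance(attn):
    """Score each token by the attention it receives.

    attn[h][i][j] is the attention probability from query token i to
    key token j in head h. The score of token j is the mean of
    attn[h][i][j] over all heads h and queries i (an assumed definition).
    """
    num_heads = len(attn)
    seq_len = len(attn[0])
    scores = [0.0] * seq_len
    for head in attn:
        for row in head:
            for j, prob in enumerate(row):
                scores[j] += prob
    return [s / (num_heads * seq_len) for s in scores]


def prune_by_threshold(tokens, attn, threshold):
    """Keep only the tokens whose score is at least `threshold`.

    Each token is compared against the threshold independently, so no
    sorting or top-k selection over the sequence is needed.
    """
    scores = token_importance(attn)
    keep = [j for j, s in enumerate(scores) if s >= threshold]
    return [tokens[j] for j in keep], keep


# Toy input: one attention head over three tokens (each row sums to 1).
attn = [[
    [0.6, 0.3, 0.1],
    [0.5, 0.4, 0.1],
    [0.7, 0.2, 0.1],
]]
tokens = ["[CLS]", "movie", "the"]

kept_tokens, kept_idx = prune_by_threshold(tokens, attn, threshold=0.2)
print(kept_tokens, kept_idx)
```

In this toy input the third token receives little attention from every query, so it falls below the threshold and is dropped. Unlike top-k selection, the number of surviving tokens is not fixed in advance and can vary with the input.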


Code Repositories

LTP

Learned Token Pruning for Transformers

