Towards Efficient Post-training Quantization of Pre-trained Language Models

09/30/2021
by   Haoli Bai, et al.
HUAWEI Technologies Co., Ltd.

Network quantization has gained increasing attention with the rapid growth of large pre-trained language models (PLMs). However, most existing quantization methods for PLMs follow quantization-aware training (QAT) that requires end-to-end training with full access to the entire dataset. Therefore, they suffer from slow training, large memory overhead, and data security issues. In this paper, we study post-training quantization (PTQ) of PLMs, and propose module-wise reconstruction error minimization (MREM), an efficient solution to mitigate these issues. By partitioning the PLM into multiple modules, we minimize the reconstruction error incurred by quantization for each module. In addition, we design a new model parallel training strategy such that each module can be trained locally on separate computing devices without waiting for preceding modules, which brings nearly the theoretical training speed-up (e.g., 4× on 4 GPUs). Experiments on GLUE and SQuAD benchmarks show that our proposed PTQ solution not only performs close to QAT, but also enjoys significant reductions in training time, memory overhead, and data consumption.


1 Introduction

Large pre-trained language models (PLMs) have achieved remarkable success in various natural language processing tasks [vaswani2017attention, devlin2019bert, brown2020language]. However, the increasing size and computation overhead also make it prohibitive to deploy these PLMs on resource-constrained devices. To obtain compact PLMs, various model compression methods have been proposed, such as pruning [michel2019sixteen, fan2019reducing], knowledge distillation [sanh2019distilbert, sun2019patient, jiao2020tinybert], weight-sharing [dehghani2019universal, lan2020albert, wang2020revisiting, huang2021ghostbert], dynamic computation with adaptive depth or width [hou2020dynabert, xin2020deebert, zhou2020bert], and quantization [zafrir2019q8bert, shen2020qbert, zadeh2020gobo, zhang2020ternarybert, bai2021binarybert].

Among these methods, network quantization enjoys the reduction of both model size and computation overhead without modifying the network architecture. However, despite its remarkable performance, prior methods mostly follow quantization-aware training (QAT) and thus suffer from multiple challenges: 1) QAT generally requires the same order of training iterations as the original full-precision training [shen2020qbert, zhang2020ternarybert, bai2021binarybert], so obtaining the quantized model can be slow; 2) QAT usually adopts end-to-end training by back-propagation, yet it can be challenging to load the entire PLM into memory on resource-limited devices. Moreover, recent QAT efforts combine knowledge distillation to enhance the performance [zhang2020ternarybert, bai2021binarybert], which further increases the memory overhead due to the teacher model; 3) QAT needs full access to the entire training set, which may lead to data security issues when the data are exposed to third-party organizations for the quantization service.

Figure 1: An illustrative comparison between our parallel post-training quantization method (MREM) and QAT on four dimensions of the quantization pipeline: accuracy, training time, memory overhead, and data consumption. The results are based on a quantized BERT-large model with 4-bit weights and 8-bit activations over the MNLI dataset. Best viewed in color.

Given the above challenges, post-training quantization (PTQ) serves as an appealing alternative. In contrast to QAT, PTQ is efficient in terms of training time, memory overhead, and data consumption. Instead of the full training set, it is common in PTQ to adopt only a small portion of training data to minimize the reconstruction error incurred by quantization [nagel2019data, nahshan2019loss, nagel2020up, hubara2020improving]. This can be done by calibrating the batch normalization statistics [nagel2019data] or step sizes [nahshan2019loss] of the quantization functions in a layer-wise manner. The layer-wise objective is also more sample-efficient [zhou2020go] and memory-saving compared with end-to-end training in QAT. Despite the success of prior PTQ solutions on convolutional neural networks (CNNs), we show that it is non-trivial to directly apply them to PLMs such as BERT [devlin2019bert]. Different from CNNs, multiple linear layers are coupled together within the multi-head self-attention and feed-forward network of the Transformer architecture. Layer-wise training therefore ignores the underlying correlation among layers and leads to poor performance.

In this paper, we aim at improving the performance of post-training quantization for PLMs, while simultaneously maintaining its efficiency w.r.t. training time, memory overhead, and data accessibility. Firstly, we propose module-wise reconstruction error minimization (MREM) to incorporate more cross-layer correlation. We partition the PLM into multiple modules, each consisting of multiple Transformer layers that are optimized jointly. Meanwhile, the module size can be flexibly adjusted depending on the memory constraint, achieving an appropriate trade-off between layer-wise correlation and memory overhead. While similar block-wise objectives have been considered in [li2021brecq], they require computing second-order Hessian matrices for optimization, which can be computationally prohibitive for large PLMs. Secondly, we design a new model parallel strategy to further accelerate MREM. By allocating each module to an individual computing device, all modules can perform local training in parallel, achieving nearly the theoretical speed-up (e.g., 4× on 4 GPUs). Thirdly, we develop annealed teacher forcing for the parallel training. We find that naive parallel training suffers from the propagation of reconstruction error, since each quantized module passes its error to its successors before it is fully converged. Inspired by [williams1989learning], we use the full-precision module to provide clean signals to the next quantized module, which breaks the propagation of reconstruction error and improves the performance of quantized PLMs.

Empirical results on the GLUE and SQuAD benchmarks show that our proposed MREM not only significantly improves the performance of post-training quantization, but also enjoys faster training, less memory overhead, and improved data security over QAT. For instance, as shown in Figure 1, the BERT-large model trained with parallel MREM achieves accuracy close to QAT based on only thousands of training samples. Moreover, it consumes merely one-third of the memory per GPU and trains significantly faster than previous QAT.

We summarize the contributions of this paper as follows:

  • We study the post-training quantization of PLMs, and propose module-wise reconstruction error minimization (MREM), a fast, memory-saving, and data-efficient approach to improve the quantized PLMs.

  • We design a new model parallel strategy based on MREM to accelerate post-training quantization with nearly the theoretical speed-up for distributed training.

  • The parallel MREM can be combined with annealed teacher forcing to alleviate the propagation of reconstruction error and boost the performance.

  • We conduct extensive experiments on both GLUE and SQuAD benchmarks to verify the advantages of our approach w.r.t. multiple aspects of the quantization pipeline. We also provide detailed discussions on other important factors of our approach.

The rest of this paper is organized as follows. We summarize the related work in Section 2. In Section 3, we review the necessary background and explain the motivation for this research. In Section 4, we propose our PTQ solution together with the parallel training technique. The experiments are presented in Section 5. Finally, we conclude this work in Section 6.

2 Related Work

In this section, we review the literature related to our research, including both network compression and parallel training for pre-trained language models.

2.1 Network Compression for Pre-trained Language Models

As the research focus of this paper, network quantization replaces the original full-precision parameters and activations with low-bit representations [courbariaux2015binaryconnect, li2016ternary, hou2017loss, hou2018loss, esser2019learned, li2020rtn, zhuang2021effective, young2021transform]. To apply network quantization to PLMs, Q8BERT [zafrir2019q8bert] converts both parameters and activations to 8-bit representations, and finds a negligible accuracy drop on natural language understanding tasks. Q-BERT [shen2020qbert] exploits the Hessian of the loss curvature to determine the best layer-wise quantization bit-width, which achieves a higher compression rate. Additionally, TernaryBERT [zhang2020ternarybert] proposes to ternarize the parameters with 2-bit representations together with two-stage knowledge distillation [jiao2020tinybert]. Recently, BinaryBERT [bai2021binarybert] binarizes the model parameters by splitting from a ternarized model, which further reduces the quantization bit-width. Despite the promising performance of these quantization approaches, they mostly follow quantization-aware training that requires heavy fine-tuning, which can be prohibitive given constraints on training time, memory size, and data accessibility. In this work, we instead follow post-training quantization [nagel2019data, nahshan2019loss, zhao2019improving, nagel2020up, li2021brecq, hubara2020improving], an alternative way to improve quantized PLMs given limited training resources. More details on quantization are discussed in Section 3.

Aside from quantization, several other popular techniques exist to compress PLMs. Pruning removes unimportant parameters or connections from a well-trained model [he2017channelprunning, luo2018thinet, wen2019structured, wang2020bayesian, liu2021discrimination], and is widely explored for PLMs [gordon2020compressing, wang2020rethinking]. A direct way is to remove connections with small magnitudes during pre-training and add them back when necessary in downstream tasks [gordon2020compressing]. Structured pruning can also be applied to directly reduce the width of BERT [McCarley2019q], which is more friendly for practical inference acceleration. In [michel2019sixteen], it is shown that attention heads can be pruned without hurting the representation of PLMs. Furthermore, a comprehensive study in [wang2020rethinking] investigates the pruning sensitivity of different parts of the Transformer model. There are also efforts on dropping entire layers of Transformer models [fan2019reducing, xin2020deebert].

Knowledge distillation [hinton2015distilling, romero2014fitnets, wang2021distilling, zhang2021self, liu2020structured] is another successful tool to design efficient PLMs, by distilling knowledge from a large teacher model to a smaller student model. Knowledge distillation is first applied to PLMs by minimizing the soft cross-entropy between output logits [sanh2019distilbert]. Aside from this, recent efforts show that it is also helpful to minimize the mean squared error of hidden representations [jiao2020tinybert, sun2019patient]. While knowledge distillation is promising in performance, memory consumption is the main concern, as the teacher model itself or its pre-computed representations for distillation are generally large. Post-training quantization is also closely related to knowledge distillation [nahshan2019loss, nagel2019data, nagel2020up, li2021brecq], since the full-precision model acts as a teacher that provides layer-wise supervision signals to the quantized student model.

An orthogonal direction is to apply neural architecture search (NAS) for efficient PLM structures. AdaBERT [chen2020adabert] adopts differentiable search [liu2018darts] to automatically compress BERT into task-specific architectures. To obtain task-agnostic architectures, NAS can also be applied during the pre-training stage [so2019evolved, xu2021bert]. Recently, one-shot NAS has also been developed to search for tiny PLMs [yin2021autotinybert]. Despite the promising performance of these approaches, algorithmic efficiency [bender2019understanding, wang2020revisiting] remains a major concern for NAS-based PLMs.

2.2 Parallel Training for Pre-trained Language Models

Parallel training is also a popular topic in training large pre-trained language models, where pipeline model parallelism [huang2018gpipe, narayanan2019pipedream, tarnawski2020efficient, park2020hetpipe, fan2021dapple] is most closely related to our proposed parallel training strategy. Specifically, by partitioning the model into multiple modules, pipeline parallelism similarly puts each module on an individual computing device. However, pipeline parallelism incurs a high communication cost among workers to transmit the intermediate tensors in the forward and backward passes. GPipe [huang2018gpipe] adopts micro-batches to reduce the bubble time, but still suffers from limited speed-up, as will be discussed in Section 4.2.2. PipeDream [narayanan2019pipedream] optimizes the module partition to minimize the communication cost, but the resulting strategy still highly depends on the model architecture. Our parallel training strategy, on the other hand, does not require real-time communication among workers and thus brings nearly the theoretical speed-up. While the solution can be suboptimal, we propose a novel teacher-forcing mechanism to alleviate the problem, and the resulting performance is already reasonably good for post-training quantization.

Aside from pipeline model parallelism, there are several other dimensions of parallel training. Data parallelism [dean2008mapreduce, li2014communication] is the most widely used solution in modern deep learning frameworks, where the training data are partitioned over multiple workers. Op-level model parallelism [jia2018exploring, shoeybi2019megatron, song2020accpar] slices the parameters over multiple workers to compute local results and concatenates them afterwards, which incurs a high communication cost. Optimizer model parallelism [rajbhandari2020zero] partitions parameters, gradients, and optimizer states across computing devices, achieving linear memory reduction with the number of workers. It would be promising to combine all these approaches with our parallel strategy, enabling post-training quantization of gigantic pre-trained language models such as GPT-3 [ramesh2021zero] and PanGu-α [zeng2021pangu].

3 Motivation

Figure 2: Comparison between QAT and REM-based PTQ over four dimensions: (a) training time, (b) memory overhead, (c) data accessibility, and (d) performance. We use a BERT-large model over the MNLI dataset for illustration. The full-precision (FP) fine-tuning is also included as a baseline. We follow the procedure in [zhang2020ternarybert] for QAT, and REM in Equation (2) for PTQ. The training time and memory in (a) and (b) are measured with 4-bit weights and 8-bit activations (i.e., W4A8) on an NVIDIA V100 GPU.

In this section, we show that it is important yet challenging to conduct post-training quantization of PLMs. Before diving into details, we first review the necessary background on network quantization.

3.1 Quantization Background

Network quantization replaces the original full-precision weight or activation $w$ with its low-bit counterpart $\hat{w}$. Denoting $s$ as the step size, the $b$-bit symmetric uniform quantization function can be written as

    $\hat{w} = s \cdot \Pi_{\Omega_b}\left( w / s \right)$,    (1)

where $\Omega_b = \{-2^{b-1}, \dots, 2^{b-1}-1\}$ is the set of $b$-bit integers, and $\Pi_{\Omega_b}(\cdot)$ is the projection function that maps its argument to the closest integer in $\Omega_b$.
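To make Equation (1) concrete, the following minimal PyTorch sketch simulates $b$-bit symmetric uniform quantization; the function name and the max-based step-size initialization are our own illustrative choices, not details from the paper.

import torch

def symmetric_uniform_quantize(w: torch.Tensor, step: torch.Tensor, bits: int) -> torch.Tensor:
    # Equation (1): project w / s onto the closest b-bit integer, then rescale by s.
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    q = torch.clamp(torch.round(w / step), qmin, qmax)
    return q * step

# Example: 4-bit weight quantization with a simple max-based step size.
w = torch.randn(768, 768)
step = w.abs().max() / (2 ** 3 - 1)
w_hat = symmetric_uniform_quantize(w, step, bits=4)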

In the context of quantization of Transformer-based PLMs, we follow the default setting in previous works [zafrir2019q8bert, zhang2020ternarybert, bai2021binarybert]: we quantize both the network weights and activations in each matrix multiplication. We use symmetric uniform quantization for weights, embeddings, and activations, except the activations after the self-attention and the GeLU function. For these two kinds of activations, we adopt asymmetric quantization since their elements are mostly positive. We skip quantization for all layer-normalization layers, skip connections, biases, and the last classification head, due to either their limited computation overhead or the large performance drop quantization would incur.

In the following, we introduce two common branches in the quantization literature: quantization-aware training and post-training quantization.

3.1.1 Quantization-aware Training (QAT)

Quantization-aware training resembles normal training, but performs the forward propagation with the quantized network. It is thus also time-consuming, as it iterates over the entire training set. The typical training objective is either the cross-entropy loss between the predictions and ground-truth labels for classification tasks [zafrir2019q8bert], or a distillation objective between the quantized model and a full-precision teacher model [zhang2020ternarybert]. As the quantization function is non-differentiable, the straight-through estimator [courbariaux2015binaryconnect] is usually adopted to allow gradient back-propagation through these discrete operations.
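As a minimal illustration of how the straight-through estimator is commonly realized in PyTorch (our own sketch, not the authors' code), the rounding operation is treated as the identity function in the backward pass:

import torch

class RoundSTE(torch.autograd.Function):
    # Round in the forward pass; pass gradients through unchanged in the backward pass.
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d round(x) / dx is approximated by 1.
        return grad_output

def fake_quantize(w, step, bits):
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    q = torch.clamp(RoundSTE.apply(w / step), qmin, qmax)
    return q * step  # gradients w.r.t. w flow through despite the rounding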

3.1.2 Post-training Quantization (PTQ)

Unlike QAT, post-training quantization seeks to recover the performance degradation without intensive training over the entire training set. One line of PTQ research quantizes the network without using any training data, by removing outliers in the full-precision parameters. This can be achieved by splitting an outlier neuron with a large magnitude into two parts [zhao2019improving], so that the magnitude is halved. Alternatively, one can scale down outlier magnitudes and multiply them back in subsequent layers, a.k.a. weight equalization [nagel2019data]. Another solution is to treat the outliers and normal values in the distribution separately, by keeping two sets of quantization parameters [fang2020post, zadeh2020gobo].

Another line of PTQ research [nahshan2019loss, wang2020towards, nagel2020up, hubara2020improving] performs reconstruction error minimization (REM) using a very small portion of unlabeled data (a.k.a. the calibration set) from the original training set. Compared with training-free PTQ approaches, such an approach significantly improves the performance of the quantized network. REM minimizes the distance between the output of the matrix multiplication in the quantized network and its full-precision counterpart as follows:

    $\min_{\mathbf{s}} \ \big\| \hat{\mathbf{W}} \hat{\mathbf{x}} - \mathbf{W} \mathbf{x} \big\|^2$,    (2)

where $\mathbf{W}$ and $\mathbf{x}$ are the full-precision weights and activations, $\hat{\mathbf{W}}$ and $\hat{\mathbf{x}}$ are their quantized representations with the assigned weight and activation bit-widths, and $\mathbf{s}$ denotes all step sizes involved in quantization. REM is usually conducted in a greedy manner: it proceeds to the next matrix multiplication only after the training of the previous ones. Meanwhile, Equation (2) can be solved quickly on the calibration set. Recent work [zhou2020go] also theoretically shows that such a greedy objective is more sample-efficient than conventional end-to-end training. In this paper, we extend REM-based post-training quantization given its past success.
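The sketch below illustrates REM for a single linear layer as in Equation (2): only the two step sizes are optimized on a small calibration batch, while the weights stay fixed. The function and variable names, the Adam optimizer, and the max-based initialization are our own assumptions rather than details from the cited works.

import torch

def fake_quant(x, step, bits):
    # Simulated quantization with straight-through rounding, so gradients reach `step`.
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scaled = x / step
    rounded = scaled + (torch.round(scaled) - scaled).detach()
    return torch.clamp(rounded, qmin, qmax) * step

def rem_calibrate_linear(weight, calib_x, w_bits=4, a_bits=8, iters=500, lr=1e-3):
    # Minimize ||Q(W) Q(x) - W x||^2 over the weight and activation step sizes.
    weight, calib_x = weight.detach(), calib_x.detach()
    s_w = (weight.abs().max() / (2 ** (w_bits - 1) - 1)).clone().requires_grad_(True)
    s_a = (calib_x.abs().max() / (2 ** (a_bits - 1) - 1)).clone().requires_grad_(True)
    target = calib_x @ weight.t()                 # full-precision reference output
    opt = torch.optim.Adam([s_w, s_a], lr=lr)
    for _ in range(iters):
        out = fake_quant(calib_x, s_a, a_bits) @ fake_quant(weight, s_w, w_bits).t()
        loss = torch.mean((out - target) ** 2)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return s_w.detach(), s_a.detach()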

3.2 Why Post-training Quantization?

In this section, we discuss the difference between REM-based PTQ and QAT along four dimensions of a quantization pipeline: 1) training time; 2) memory overhead; 3) data accessibility and 4) performance. According to Figure 2, we summarize the findings in the following paragraphs.

3.2.1 Training Time

As QAT iterates over the full training set for multiple epochs, it is much more time-consuming than PTQ. Note that recent QAT methods [zhang2020ternarybert, bai2021binarybert] further combine two-stage knowledge distillation [jiao2020tinybert], which prolongs the training even beyond full-precision (FP) fine-tuning. As shown in Figure 2(a), QAT can take nearly four times longer than FP fine-tuning.

3.2.2 Memory Overhead

The increasing size of recent large PLMs makes it prohibitive to conduct QAT on memory-limited computing resources. From Figure 2(b), QAT [zhang2020ternarybert] consumes even more memory than FP fine-tuning when combined with knowledge distillation, since the full-precision teacher model must also be stored. On the other hand, PTQ only caches intermediate results during the layer-wise REM in Equation (2), which can fit into a single GTX 1080 Ti. Therefore, PTQ is also applicable on memory-limited computing devices.

3.2.3 Data Accessibility

Quantization services are usually offered by third-party organizations, where data security is always of high priority. As QAT requires access to the entire training set, it inevitably increases the risk of data exposure. PTQ, on the other hand, needs only a small calibration set, which can be easily constructed by randomly sampling a few thousand instances from the full training set, as shown in Figure 2(c). Therefore, most original training instances are kept untouched and data security can be largely guaranteed.

3.2.4 Performance

When fine-tuned over the entire training set, QAT usually maintains better quantized performance than PTQ. From Figure 2(d), the performance of QAT is close to the FP results, and remains steady across different bit-widths, i.e., W4A8, W2A8 and W2A4. However, the performance of PTQ drops significantly, which is the main concern to be addressed.

In summary, REM-based PTQ is superior to QAT with regard to training efficiency, memory overhead, and data accessibility. Nevertheless, it is still often less preferred than QAT due to its severe performance drop, especially at low quantization bit-widths [zafrir2019q8bert, shen2020qbert, zhang2020ternarybert]. In this paper, we aim at improving the performance of post-training quantization for PLMs, while preserving its merits of fast training, light memory overhead, and low data consumption.

4 Methodology

Figure 3: The overview of the proposed module-wise reconstruction error minimization (MREM). We partition both the full-precision model and quantized model into multiple modules and put these modules on different computing devices. By sampling tensors from the input queue, each module can be trained locally without waiting for its predecessors. Teacher forcing is applied to mitigate the issue of reconstruction error propagation on the quantized module.

In this section, we propose our solution to improve the post-training quantization of Transformer-based PLMs. We first extend the existing reconstruction error minimization from the layer-wise to the module-wise granularity to fit Transformer models. Secondly, based on the module partition, we design a new parallel training strategy that further speeds up the PTQ pipeline. An overview of our solution can be found in Figure 3.

4.1 Module-wise Reconstruction Error Minimization

We propose a new PTQ solution called module-wise reconstruction error minimization (MREM) for PLMs. Existing REM [nagel2020up] solves Equation (2) for each matrix multiplication. However, a standard Transformer layer in PLMs consists of a multi-head attention (MHA) and a feed-forward network (FFN), both of which contain a number of matrix multiplications that are coupled together. Greedily tackling each matrix multiplication as in REM can thus lead to suboptimal quantized networks. Moreover, the insufficiently minimized reconstruction error propagates and grows across Transformer layers, and finally deteriorates the network output [chen2019deep, bai2020few].

To this end, the proposed module-wise reconstruction error minimization adopts a larger granularity by jointly optimizing all the coupled linear layers inside each module. Specifically, given a PLM with $L$ Transformer layers, the embedding layer, and the classification head, we partition it into $N$ modules, where the $n$-th module includes Transformer layers $\{l_n, l_n+1, \dots, l_{n+1}-1\}$ with $l_n$ being the index of its first layer.¹ MREM minimizes the joint reconstruction error between all quantized FFN outputs in the module and their full-precision counterparts as follows:

    $\min_{\mathbf{w}_n, \mathbf{s}_n} \ \sum_{l=l_n}^{l_{n+1}-1} \big\| \hat{\mathbf{f}}_l - \mathbf{f}_l \big\|^2$,    (3)

where $\mathbf{f}_l$ and $\hat{\mathbf{f}}_l$ are the full-precision and quantized FFN outputs of the $l$-th Transformer layer, and $\mathbf{w}_n$ and $\mathbf{s}_n$ are all learnable parameters and quantization step sizes in the $n$-th module. Similar to REM, MREM can be optimized sequentially: given previously trained modules, only the parameters and quantization step sizes in the current module are updated. Besides the grouped Transformer layers, we also minimize the MSE loss on the Transformer embedding and the output logits, respectively.

¹Note that the embedding layer and the classification head are incorporated into the first and last module, respectively.

Note that the number of modules $N$ can be adjusted depending on the memory constraint of the computing resources. When $N = 1$, MREM reduces to intermediate-layer knowledge distillation [jiao2020tinybert], which can be memory-demanding when quantizing large PLMs on a single GPU.
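A minimal sketch of the module-wise objective in Equation (3) is given below. It assumes each Transformer layer is a callable module returning its hidden states; the partitioning helper, function names, and the equal-size split are our own simplifications rather than the paper's exact implementation.

import torch
import torch.nn as nn

def partition_layers(layers, num_modules):
    # Split the list of Transformer layers into consecutive, roughly equal-sized modules.
    per_module = (len(layers) + num_modules - 1) // num_modules
    return [nn.ModuleList(layers[i:i + per_module]) for i in range(0, len(layers), per_module)]

def mrem_loss(fp_module, q_module, fp_input, q_input):
    # Equation (3): sum of per-layer MSEs between quantized and full-precision outputs.
    loss, fp_h, q_h = 0.0, fp_input, q_input
    for fp_layer, q_layer in zip(fp_module, q_module):
        with torch.no_grad():
            fp_h = fp_layer(fp_h)          # clean full-precision target
        q_h = q_layer(q_h)                 # quantized output to be optimized
        loss = loss + torch.mean((q_h - fp_h) ** 2)
    return loss, fp_h.detach(), q_h        # module outputs feed the next module's queues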

4.2 Accelerated Parallel Training

Based on the proposed MREM, we design a new model parallel strategy to further accelerate the training. As shown in Figure 3, we place different modules on individual computing devices. A set of input queues is deployed between each pair of adjacent modules. For the $n$-th module, the queue collects its outputs from the most recent training steps. Meanwhile, the $(n{+}1)$-th module can always sample with replacement from the queue without waiting for the $n$-th module. Similar rules hold for the quantized modules and their input queues as well. The design of the input queue resembles stale synchronous parallel [ho2013more], which stores stale outputs in a local cache, up to a staleness threshold, so as to reduce the waiting time among workers.

The training workflow is as follows. Initially, every module is computed one after another in the first step to fill in the input queues, after which parallel training takes place. At each subsequent step, each module samples an input from its queue and calculates the loss in Equation (3) accordingly. Meanwhile, the input queues are updated in a first-in-first-out manner throughout training. In the backward pass, we constrain the gradients to propagate locally within each module, without affecting its predecessors. Such a design avoids load-imbalance issues from straggler modules, bringing nearly the theoretical speed-up.
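The input queue between adjacent modules can be sketched as a small FIFO cache from which the next module samples with replacement; the class and method names are ours, and the deque-based eviction is just one simple way to bound the staleness.

import random
from collections import deque

class InputQueue:
    """FIFO cache of recent module outputs; the next module samples with replacement."""

    def __init__(self, staleness: int):
        # Entries older than `staleness` steps are evicted automatically.
        self.buffer = deque(maxlen=staleness)

    def push(self, hidden_states):
        # Detach so gradients never propagate across module boundaries.
        self.buffer.append(hidden_states.detach())

    def sample(self):
        # The consumer never waits for the producer as long as the buffer is non-empty.
        return random.choice(self.buffer)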

4.2.1 Annealed Teacher Forcing

Since all modules proceed with training simultaneously instead of sequentially, the next module takes the output from the queue before its predecessor is fully optimized. Therefore, the reconstruction error from the predecessor is propagated to the following modules before it is sufficiently minimized.

Inspired by teacher forcing [williams1989learning] in training recurrent networks, the output $\mathbf{f}_{n-1}$ from the $(n{-}1)$-th full-precision module naturally serves as a clean input to the $n$-th quantized module, substituting the quantized output $\hat{\mathbf{f}}_{n-1}$. This stops the propagation of the reconstruction error accumulated in the quantized modules. Nevertheless, such an approach breaks the connection to previous quantized modules and may suffer from forward inconsistency between training and inference [bai2020few] on the quantized model. To achieve a proper trade-off, we take the convex combination of the full-precision and quantized outputs as follows:

    $\tilde{\mathbf{f}}_{n-1} = \lambda_t \mathbf{f}_{n-1} + (1 - \lambda_t) \hat{\mathbf{f}}_{n-1}$,    (4)

where the hyperparameter $\lambda_t \in [0, 1]$ controls the strength of teacher forcing at training step $t$. $\lambda_t = 1$ gives full correction of the reconstruction error but with forward inconsistency, while $\lambda_t = 0$ reduces to the conventional setting that suffers from the propagated reconstruction error. We adopt a linear decay strategy for $\lambda_t$: $\lambda_t = \max(1 - t / T_c, 0)$, where $T_c$ is the preset maximum number of decay steps. Intuitively, a large $\lambda_t$ is desired at the beginning, when each module is barely optimized. Later, a small $\lambda_t$ is preferred to transit to normal training such that the forward inconsistency can be bridged. The remaining steps stick to normal training without teacher forcing, so that each quantized module adapts to its own predecessors.
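The annealed teacher forcing of Equation (4) amounts to a linearly decayed convex combination of the two candidate inputs, as in the short sketch below; the variable names and the decay helper are ours.

def annealed_lambda(step: int, max_decay_steps: int) -> float:
    # Linearly decay the teacher-forcing strength from 1 to 0 over `max_decay_steps`.
    return max(1.0 - step / max_decay_steps, 0.0)

def teacher_forced_input(fp_prev_out, q_prev_out, step: int, max_decay_steps: int):
    # Equation (4): mix the clean full-precision signal with the quantized signal.
    lam = annealed_lambda(step, max_decay_steps)
    return lam * fp_prev_out + (1.0 - lam) * q_prev_out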

4.2.2 Comparison with Pipeline Parallelism

Notably, our MREM with stale synchronous parallelism is different from recent pipeline parallelism [huang2018gpipe, narayanan2019pipedream]. Pipeline parallelism adopts end-to-end training with synchronous updates between adjacent modules, which gives rise to bubble time on computing devices. While GPipe [huang2018gpipe] divides the original data batch into $M$ pipelined micro-batches, it still incurs a bubble time of roughly $(N-1)/(M+N-1)$ under $N$ partitions. On the one hand, a larger $N$ or a smaller $M$ increases the bubble time. On the other hand, a larger $M$ leads to smaller micro-batches that cannot fully exploit the computing power, which again limits the acceleration rate. Differently, our parallel strategy conducts local training with stale synchronous updates of each module. Hence there is negligible bubble time as long as the straggler module does not fall behind the staleness threshold, which can be easily satisfied with balanced module partitions or a larger threshold.
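For intuition on the bubble time, the helper below evaluates the idle fraction of a GPipe-style pipeline; the closed form is our restatement of the standard pipeline analysis, not a formula from this paper.

def pipeline_bubble_fraction(num_partitions: int, num_microbatches: int) -> float:
    # Approximate fraction of device time spent idle in a synchronous pipeline.
    return (num_partitions - 1) / (num_microbatches + num_partitions - 1)

# With 4 partitions, even 16 micro-batches leave roughly 16% of device time idle.
for m in (1, 4, 16, 64):
    print(m, round(pipeline_bubble_fraction(4, m), 3))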

1: procedure Main:
2:     Partition the PLM into $N$ modules
3:     Fill in the input queues $\mathcal{Q}_n$, $\hat{\mathcal{Q}}_n$
4:     for $n = 1, \dots, N$ do
5:         ▷ run in parallel
6:         while $t \le T$ do
7:             $\mathbf{f}_{n-1} \sim \mathcal{Q}_{n-1}$,  $\hat{\mathbf{f}}_{n-1} \sim \hat{\mathcal{Q}}_{n-1}$
8:             $\mathbf{f}_n$, $\hat{\mathbf{f}}_n \leftarrow$ MREM($\mathbf{f}_{n-1}$, $\hat{\mathbf{f}}_{n-1}$, $t$)
9:             Update $\mathcal{Q}_n$, $\hat{\mathcal{Q}}_n$ with $\mathbf{f}_n$, $\hat{\mathbf{f}}_n$
10:    return the quantized PLM
Algorithm 1 Efficient Post-training Quantization.

Finally, an overview of the proposed parallel module-wise reconstruction error minimization is shown in Algorithm 1 and Algorithm 2. The Update function in Algorithm 2 can be any gradient update rule, such as AdamW [loshchilov2018decoupled] with an appropriate learning rate.

5 Experiments

In this section, we empirically verify the proposed MREM for post-training quantization of PLMs. We first introduce the experimental setup in Section 5.1. We then present the main results in Section 5.2 and Section 5.3, including comparisons with QAT and REM, as well as with other existing quantization baselines. In Section 5.4, we provide further discussion on a variety of factors in our approach, such as the effect of teacher forcing, the number of model partitions, and the calibration data size. Code will be released upon acceptance.

Table I: Results of our proposed MREM-S and MREM-P against QAT and REM on the development set of MNLI, for BERT-base and BERT-large under the 4-4-8, 2-2-8, and 2-2-4 settings, reporting training time (min), memory overhead (GB), amount of training data (K), and accuracy. "#Bits (W-E-A)" represents the bit-widths for the weights of Transformer layers, word embedding, and activations. Acc-m and Acc-mm denote accuracies on the matched and mismatched sections of MNLI, respectively.
Table II: Results of our proposed MREM-S and MREM-P against QAT and REM on the development set of SQuAD v1.1, for BERT-base and BERT-large under the 4-4-8, 2-2-8, and 2-2-4 settings, reporting training time (min), memory overhead (GB), amount of training data (K), EM (%), and F1 (%). Underlined entries denote results with two gradient accumulation steps under the same total batch size due to the memory constraint.
Table III: Results of our proposed MREM-S and MREM-P against QAT and REM on the development set of SQuAD v2.0, for BERT-base and BERT-large under the 4-4-8, 2-2-8, and 2-2-4 settings, reporting training time (min), memory overhead (GB), amount of training data (K), EM (%), and F1 (%). Underlined entries denote results with two gradient accumulation steps under the same total batch size due to the memory constraint.
Table IV: Results on the GLUE development set (MNLI-m, QQP, QNLI, SST-2, CoLA, STS-B, MRPC, RTE), comparing the full-precision baseline, Q-BERT (2-8-8 and 2/4-8-8), Quant-Noise (PQ), TernaryBERT (2-2-8), GOBO (3-4-32 and 2-2-32), and our MREM-S and MREM-P (4-4-8 and 2-2-8). "Size" refers to model storage in MB. "PTQ" indicates whether the method belongs to post-training quantization. "Avg." denotes the average result over all tasks.
1: procedure MREM($\mathbf{f}_{n-1}$, $\hat{\mathbf{f}}_{n-1}$, $t$):
2:     if $t \le T_c$ then
3:         $\lambda_t \leftarrow 1 - t / T_c$
4:         Compute the teacher-forced input $\tilde{\mathbf{f}}_{n-1}$ by Equation (4)
5:     Compute the full-precision module output $\mathbf{f}_n$
6:     Compute the quantized module output $\hat{\mathbf{f}}_n$
7:     Compute the loss $\mathcal{L}_n$ by Equation (3)
8:     $\mathbf{w}_n \leftarrow \mathrm{Update}(\mathbf{w}_n, \nabla_{\mathbf{w}_n} \mathcal{L}_n)$
9:     $\mathbf{s}_n \leftarrow \mathrm{Update}(\mathbf{s}_n, \nabla_{\mathbf{s}_n} \mathcal{L}_n)$
10:    return $\mathbf{f}_n$, $\hat{\mathbf{f}}_n$
Algorithm 2 Module-wise Reconstruction Error Min.

5.1 Experimental Setup

5.1.1 Datasets and Metrics

We evaluate post-training quantization on both text classification with the GLUE benchmark [wang2018glue] and reading comprehension with the SQuAD benchmarks [rajpurkar2016squad]. The calibration set contains 4,096 instances by default, randomly sampled from the full training set. As both the RTE and MRPC tasks in the GLUE benchmark contain fewer than 4,096 samples, we use their full training sets for these two tasks. We leave the study of the calibration data size to Section 5.4. Each experiment is repeated ten times with different calibration sets, and both the mean and standard deviation are reported.

We use the same evaluation metrics as [devlin2019bert, zhang2020ternarybert] on the development sets of the GLUE and SQuAD benchmarks. For the results in Section 5.2, we report accuracies on both the matched and mismatched sections of MNLI, and EM (exact match) and F1 scores for SQuAD. Additionally, we report the training time (min), memory overhead (GB), and the size of the training data (K). We also provide comparisons with other existing methods in Section 5.3, where we adopt Matthews correlation for CoLA, Spearman correlation for STS-B, and accuracy for the remaining tasks (i.e., RTE, MRPC, SST-2, QQP, QNLI, MNLI). We also report the average performance on GLUE as an overview.

5.1.2 Implementation

We use standardly fine-tuned BERT-base and BERT-large models on downstream tasks for post-training quantization, following the default fine-tuning hyperparameter settings in Huggingface Transformers (https://github.com/huggingface/transformers). We implement MREM with both sequential training (abbreviated MREM-S, Section 4.1) and parallel training with teacher forcing (abbreviated MREM-P, Section 4.2.1). For each module, we train for a fixed number of steps with task-specific initial learning rates that decay linearly, as done in [jiao2020tinybert, zhang2020ternarybert]. By default, we partition the model into 4 modules on 4 NVIDIA V100 GPUs. The analysis of the training steps and partition numbers is provided in Section 5.4.

For baselines, we mainly compare with QAT and REM: the former measures how close PTQ can get to QAT, and the latter studies the effect of objective granularity in PTQ training. We conduct QAT following the state-of-the-art training pipeline [zhang2020ternarybert], i.e., intermediate-layer distillation followed by prediction-layer distillation, which takes 6 training epochs in total. Detailed hyperparameter settings can be found in [zhang2020ternarybert]. For REM, we follow the practice in [nagel2020up, hubara2020improving] to minimize the reconstruction error after each matrix multiplication, as introduced in Section 3.1. For a fair comparison, all methods use the same quantization schemes, i.e., TWN [li2016ternary] or LAQ [hou2018loss] for 2-bit and 4-bit weight quantization, and LSQ [esser2019learned] for all activation quantization. Unlike QAT, which picks the best model based on the development set results, MREM is only evaluated once after training, which also preserves the security of the development set. We leave the comparison with more existing quantization approaches to Section 5.3.
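For reference, the sketch below shows a simplified, per-tensor LSQ-style activation quantizer with a learnable step size; it is our own condensed rendition of the idea in [esser2019learned] (symmetric only, with a single gradient-scale heuristic), not the exact quantizer used in the experiments.

import torch
import torch.nn as nn

class LSQActivationQuantizer(nn.Module):
    """Simplified LSQ-style quantizer: symmetric, per-tensor, learnable step size."""

    def __init__(self, bits: int = 8, init_step: float = 0.1):
        super().__init__()
        self.qmin, self.qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        self.step = nn.Parameter(torch.tensor(init_step))

    def forward(self, x):
        # Scale the step-size gradient, as suggested by LSQ, to stabilize its updates.
        grad_scale = 1.0 / ((x.numel() * self.qmax) ** 0.5)
        step = self.step * grad_scale + (self.step - self.step * grad_scale).detach()
        scaled = x / step
        rounded = scaled + (torch.round(scaled) - scaled).detach()   # straight-through rounding
        return torch.clamp(rounded, self.qmin, self.qmax) * step

# Usage: quantizer = LSQActivationQuantizer(bits=8); y = quantizer(torch.randn(2, 128, 768))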

5.2 Main Results: Comparison with QAT and REM

We first compare MREM-S and MREM-P with QAT and REM over the MNLI and SQuAD benchmarks. We take BERT-base and BERT-large as backbone PLMs for quantization. The results on MNLI, SQuAD v1.1 and SQuAD v2.0 are summarized in Table I, Table II and Table III, respectively. We discuss the results along the four dimensions mentioned in Section 3.2.

5.2.1 Performance

It can be found that our proposed MREM-S improves significantly over REM given the same training time, and is much closer to QAT. For instance, according to Table I, MREM-S with 4-bit weight quantization on BERT-base and BERT-large achieves accuracies on the matched section of MNLI that are clearly higher than REM and only slightly inferior to QAT. With REM, BERT-base sometimes even outperforms BERT-large on MNLI. We speculate that this is due to the suboptimal solutions of REM, whose reconstruction errors propagate further when more neurons or Transformer layers are stacked in BERT.

Moreover, with all modules trained in parallel, MREM-P remains close to or only slightly inferior to MREM-S. From the results on SQuAD v1.1 in Table II, MREM-P can even outperform MREM-S with the "W2-E2-A4" quantized BERT-large model in terms of both the average EM and F1 scores.

5.2.2 Training Time

Our proposed MREM also requires significantly less training time than QAT. For instance, for low-bit weight-quantized training of BERT-large on MNLI, MREM completes in a small fraction of the time taken by QAT, and is also much faster than full-precision fine-tuning. When compared with REM, MREM does not need to cache the output after every matrix multiplication, which admits more iterations within the same amount of time. We discuss this further in Section 5.4.2. Moreover, when armed with the proposed parallel training, MREM-P is further faster than MREM-S, achieving nearly the theoretical linear speed-up across 4 GPUs. Together, these bring a large reduction of training time compared with QAT.

Table V: Ablation studies of teacher forcing (TF) at different numbers of training steps over MNLI-m, comparing BERT-base and BERT-large with and without TF under the 2-2-8 and 2-2-4 settings.
Table VI: Comparison of REM with our MREM-S on BERT-base over MNLI under the 4-4-8, 2-2-8, and 2-2-4 settings, reporting the number of training steps, time (min), memory (GB), and accuracies on the matched and mismatched sections.
Figure 4: The training loss curves with and without teacher forcing (TF) in MREM-P. The red area denotes teacher forcing during the early training steps. Panels (a)-(d) in the first row show the four modules trained for 250 steps, and panels (e)-(h) in the second row show the same modules trained for 2,000 steps.

5.2.3 Memory Overhead

While the module-wise training inevitably consumes more memory than REM, it still takes only around a third of the memory used by QAT, and about half of that used by full-precision fine-tuning. For instance, while QAT requires a large amount of memory on BERT-large, MREM consumes little enough to fit into a single NVIDIA GTX 1080 Ti. Moreover, for inputs with a longer sequence length (i.e., 384 tokens on the SQuAD datasets), QAT over BERT-large may suffer from memory overflow even on an NVIDIA V100 GPU with 32GB memory. QAT with gradient accumulation inevitably doubles the training time under the same total batch size (i.e., the underlined entries in Table II and Table III). Such issues, on the other hand, can be easily mitigated in both REM and our proposed MREM. Meanwhile, increasing the number of modules can further decrease the memory overhead of each module, but may harm the performance, as will be discussed in Section 5.4.3.

5.2.4 Data Accessibility

Both REM and our proposed MREM follow the common practice of PTQ, relying on only a few thousand randomly sampled training instances on both MNLI and SQuAD, which is a tiny fraction of the original dataset used by QAT. We provide more discussion on the effect of the calibration size in Section 5.4.

In summary, our MREM-S significantly improves post-training quantization of PLMs, while still enjoying fast training, light memory overhead, and data security. Moreover, with parallel training, the proposed MREM-P further strengthens the advantages of PTQ without an apparent performance drop.

5.3 Main Results: Comparison with Existing Methods

Next, we compare our MREM with a number of existing state-of-the-art BERT quantization methods, including QAT approaches such as Q-BERT [shen2020qbert], Quant-Noise [fan2020training] and TernaryBERT [zhang2020ternarybert], as well as the PTQ baseline GOBO [zadeh2020gobo]. Their results are taken from the respective original papers.

From Table IV, both our proposed MREM-S and MREM-P outperform existing PTQ approaches in most cases, and even achieve results close to the QAT approaches. For example, the "W4-E4-A8" quantized MREM-S and MREM-P achieve MNLI accuracies on par with the "W2/4-E8-A8" quantized Q-BERT. In terms of the "W2-E2-A8" quantized models, our MREM-S and MREM-P both surpass GOBO on MNLI-m.

5.4 Discussions

In this section, we provide further discussions to better understand the proposed approach. By default, all experiments in this section are based on the BERT-base model over the MNLI dataset.

Figure 5: Discussions on the proposed MREM approach: (a) number of modules and memory overhead, (b) size of calibration data, (c) reconstruction error propagation with 8-bit activations, and (d) reconstruction error propagation with 4-bit activations. In (a) and (b), the solid line and shaded area denote the averaged results and standard deviation of a "W2-E2-A4" quantized BERT-base model over 10 different seeds. (c) and (d) visualize the propagation of the reconstruction error on "W2-E2-A8" and "W2-E2-A4" quantized BERT-base models, respectively.

5.4.1 Teacher Forcing

We now study how teacher forcing benefits MREM-P under different numbers of training steps; the results are listed in Table V. It can be found that teacher forcing brings consistent improvement for both BERT-base and BERT-large. Moreover, the gain of teacher forcing is more significant with fewer training steps or lower quantization bit-widths, e.g., on the "W2-E2-A4" quantized BERT-base and BERT-large under the smaller number of training steps. This matches our intuition that fewer training steps or a higher compression ratio yield larger reconstruction errors, in which case the clean input from the full-precision module benefits the quantized module more. As further increasing the training steps brings only marginal improvement and diminishes the effect of teacher forcing, we set the number of training steps to 2,000 by default.

Additionally, we plot the training loss curves of the four modules under 250 and 2,000 training steps in Figure 4. We find that: 1) the loss curves with teacher forcing are clearly lower, especially when training with fewer steps, which matches the observations in Table V; 2) the loss curves of later modules are usually lower than those of earlier ones, indicating that modules closer to the output benefit more from teacher forcing. This matches the intuition that later modules have more accumulated error to correct. The red areas in Figure 4 mark the portion of the total iterations used for teacher forcing. We also try tuning this ratio and do not observe a large difference in the final performance, so we keep the default choice.

5.4.2 Further Comparison with REM

Here we provide further discussion of the training efficiency relative to REM. Note that both REM and MREM-S follow a sequential training procedure, where the output from the previous objective is cached for the next objective. However, as there are many matrix multiplications in each Transformer layer, it is time-consuming for REM to repeat this procedure recursively. According to the results in Section 5.2, while REM and MREM take roughly the same amount of time, REM completes far fewer optimization steps than MREM on both MNLI and SQuAD.

We also provide results where REM takes the same number of training steps as MREM-S in Table VI. It can be found that even with the same number of iterations, REM remains inferior to MREM-S across all quantization bit-widths. Meanwhile, REM then takes considerably more training time than MREM. Therefore, the module-wise granularity in MREM not only improves the quantization performance by incorporating more layer-wise dependencies, but also makes the training pipeline more efficient, with fewer stages at which intermediate results must be cached.

5.4.3 Number of Modules and Memory Overhead

We verify the effect of the model partition on the final quantized performance, as well as the corresponding memory consumption. According to Figure 5(a), when varying the number of modules, fewer model partitions give slightly better performance, as layer-wise dependencies can be better incorporated for reconstruction error minimization. However, this also comes with more running memory for each module. The memory saving also diminishes with fewer partitions. Therefore, as a trade-off, we partition the model into 4 modules by default.

5.4.4 Size of Calibration Data

The size of the calibration data directly relates to the security and privacy concerns in post-training quantization. To study its effect, we vary the calibration data size and report the results of REM, MREM-S, and MREM-P. From Figure 5(b), while REM is ahead of MREM-S/P with very few training samples, the accuracy of REM rises slowly and quickly saturates as the calibration size grows. We hypothesize that the simple training objective of REM can hardly exploit more training instances for optimization. MREM-S/P, on the other hand, can better exploit a larger calibration set, since the module-wise granularity admits higher flexibility for optimization. As the gain diminishes when the calibration size is increased further, we keep 4,096 samples by default.

5.4.5 Reconstruction Error Propagation

We visualize the propagation of the reconstruction error for both "W2-E2-A8" and "W2-E2-A4" quantized BERT-base models in Figure 5(c) and Figure 5(d), respectively. It can be observed that our MREM achieves both lower values and slower growth of the reconstruction error than REM across all layers, which verifies the advantage of the module-wise granularity in minimizing the reconstruction error. Interestingly, while the reconstruction error generally grows layer by layer over the first ten layers, it begins to decrease afterwards. We speculate this is due to the effect of the classification head, which encourages concentrated hidden representations for the task.

Table VII: Comparison of BERT-base results with and without per-channel quantization (PCQ) on MNLI, reporting matched and mismatched accuracies of REM and MREM under the 4-4-8, 2-2-8, and 2-2-4 settings.

5.4.6 Per-channel Quantization

Per-channel quantization (PCQ) is prevalent in the post-training quantization of convolutional neural networks [nahshan2019loss, nagel2020up, hubara2020improving]. PCQ assigns a different quantization step size to each output dimension of a linear layer, which is also known as row-wise quantization in [zhang2020ternarybert]. To study its effect on PLMs, we report the PCQ results of REM and MREM in Table VII. It can be found that while PCQ improves REM noticeably, the gain on MREM is marginal. We hypothesize that the larger number of training steps in MREM can already adjust the quantization grid well for PLMs. Our results are also in line with the findings in [zhang2020ternarybert], where row-wise quantization brings little improvement. As PCQ also requires storing more full-precision step sizes, we do not employ it by default.
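To illustrate the difference that Table VII probes, the snippet below contrasts per-tensor and per-channel (row-wise) step sizes for a weight matrix; it is a standalone illustration with our own function name, not the experimental code.

import torch

def quantize_weight(w: torch.Tensor, bits: int, per_channel: bool) -> torch.Tensor:
    # Symmetric weight quantization with one step size per tensor or per output row.
    qmax = 2 ** (bits - 1) - 1
    if per_channel:
        step = w.abs().amax(dim=1, keepdim=True) / qmax   # one step size per output channel
    else:
        step = w.abs().max() / qmax                       # a single step size for the tensor
    return torch.clamp(torch.round(w / step), -qmax - 1, qmax) * step

w = torch.randn(768, 3072)
err_per_tensor = (quantize_weight(w, 2, False) - w).pow(2).mean()
err_per_channel = (quantize_weight(w, 2, True) - w).pow(2).mean()
print(float(err_per_tensor), float(err_per_channel))      # per-channel typically yields a smaller error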

6 Conclusion

In this paper, we study post-training quantization for pre-trained language models. We show that existing quantization-aware training solutions suffer from slow training, huge memory overhead, and data privacy issues when accessing the full training set. To mitigate these issues, we propose module-wise reconstruction error minimization (MREM), an efficient solution to quantize PLMs. MREM can be conducted either sequentially or in parallel, where the parallel training achieves a speed-up close to the theoretical limit without apparent performance degradation. Experimental results show that the proposed solution greatly improves the performance of post-training quantization. Meanwhile, it significantly reduces the training time and memory overhead with only thousands of training instances.

There are several promising directions to explore in the future: 1) We can scale the proposed approach to larger PLMs, and thus more models can benefit from post-training quantization; 2) The proposed parallel strategy can be applied to warm up the pre-training of PLMs, such that the overall pre-training cost can be reduced; 3) While the current parallel strategy conducts local training separately, it would be interesting to cache module-wise gradients in some queues for the backward pass so that there is less discrepancy with conventional end-to-end training.

References