Controlling Computation versus Quality for Neural Sequence Models

by Ankur Bapna, et al.

Most neural networks utilize the same amount of compute for every example independent of the inherent complexity of the input. Further, methods that adapt the amount of computation to the example focus on finding a fixed inference-time computational graph per example, ignoring any external computational budgets or varying inference time limitations. In this work, we utilize conditional computation to make neural sequence models (Transformer) more efficient and computation-aware during inference. We first modify the Transformer architecture, making each set of operations conditionally executable depending on the output of a learned control network. We then train this model in a multi-task setting, where each task corresponds to a particular computation budget. This allows us to train a single model that can be controlled to operate on different points of the computation-quality trade-off curve, depending on the available computation budget at inference time. We evaluate our approach on two tasks: (i) WMT English-French Translation and (ii) Unsupervised representation learning (BERT). Our experiments demonstrate that the proposed Conditional Computation Transformer (CCT) is competitive with vanilla Transformers when allowed to utilize its full computational budget, while improving significantly over computationally equivalent baselines when operating on smaller computational budgets.



1 Introduction

Over the last few years, scaling neural networks has tremendously improved the quality of models on several machine learning tasks. State-of-the-art Natural Language Processing models have billions of parameters, especially for tasks like Machine Translation (Shazeer et al., 2018; Huang et al., 2019), Language Modeling (Radford et al.) and Natural Language Understanding (Devlin et al., 2018; Raffel et al., 2019). While training these models is feasible given the dramatic increase in the efficiency of training hardware (Jouppi et al., 2017) and research into efficient model-parallelism (Shazeer et al., 2018; Huang et al., 2019), the amount of computation that can be expended at inference is often limited. Moreover, these huge networks are usually inflexible and offer little control over the amount of computation used on any example, regardless of the complexity of the input or the computation budget available for inference.

Conditional Computation based approaches allow training networks where certain sub-networks can be conditionally executed, based on discrete decisions (optionally) trained with the model (Spall et al., 1992; Bengio et al., 2013). These methods also offer the potential for more control over the computation expended by the model during inference, conditioned on example difficulty or the available computation budget.

Figure 1: Our approach for adapting models for conditional computation. (a) Training with noisy continuous gating: sub-network outputs are gated by noised continuous outputs from control networks trained end-to-end with the model. (b) Inference with conditional execution: sub-networks are conditionally executed depending on discrete outputs from control networks. Outputs are optionally short-circuited with residual connections.

Training a model with discrete intermediate outputs requires back-propagating through discrete random variables, which hinders model trainability. While several approaches have been suggested to alleviate this problem, including the use of gradient estimators (Bengio et al., 2013; Jang et al., 2016) and reinforcement learning (Bengio et al., 2015), training neural networks with conditionally executable sub-networks is still rare. As a consequence, most work involving conditional computation is restricted to very specific applications.

In this work we present a general framework to adapt models for conditional computation and control the amount of computation used at inference. We make three major contributions: (i) We provide a simple approach to adapt models for conditional computation by adding control networks trained end-to-end with the model. These control networks produce continuous outputs during training which allows for back-propagation. During inference these networks act like binary stochastic units that control the execution of their respective sub-networks. (ii) We propose a multi-task training approach to train a single model at different computation budgets. This allows controlling the amount of compute expended by the model on any example at inference. (iii) We adapt the Transformer architecture for conditional computation and demonstrate the efficacy of our approach on two large scale sequence modeling tasks: WMT’14 En-Fr Translation and Representation learning with BERT.

2 Method

2.1 Adapting models for conditional computation

We adapt neural sequence models for conditional computation by allowing the model to selectively execute certain sub-networks of the computation graph, conditioned on the outputs of small control networks learned jointly with the model.

Let the input to a layer $l$ be a sequence $x^l = (x^l_1, \dots, x^l_T)$ of length $T$. Let the output of this operation be $y^l$, given by $y^l = f^l(x^l)$. In the presence of residual connections this can be re-written as $y^l = x^l + f^l(x^l)$.

We now introduce a control network $g^l$ to control the execution of layer $l$. While it is possible to train neural networks with back-propagation in the presence of discrete outputs, we avoid this problem by training in expectation. At training time, instead of sampling a discrete decision $d^l_t \in \{0, 1\}$ from $g^l$, we compute its expectation, $\mathbb{E}[d^l_t] = g^l(x^l_t)$. We define the operation of $g^l$ by $g^l(x) = \sigma(h^l(x))$, where $h^l$ can be any function mapping $\mathbb{R}^d \to \mathbb{R}$ and $\sigma$ is the logistic sigmoid function. As a result, $g^l(x) \in [0, 1]$ for any $x$ during training.

The gated version of layer $l$ can be written as:

$$y^l_t = x^l_t + g^l(x^l_t)\, f^l(x^l)_t \tag{1}$$

At inference, this operation simplifies to

$$y^l_t = x^l_t + d^l_t\, f^l(x^l)_t \tag{2}$$

where $d^l_t \in \{0, 1\}$ is the discrete decision of $g^l$ on token $x^l_t$; when $d^l_t = 0$, the sub-network $f^l$ is not executed for that token.

This, however, introduces a discrepancy between training and inference: the former uses soft decisions, while the latter selectively executes layers based on discrete decisions. To bridge these two modes of operation, we encourage $g^l$ to become more discrete as training progresses. We follow the approach used in previous work for training binary stochastic neurons for monotonic attention mechanisms (Raffel et al., 2017; Chiu & Raffel, 2017; Arivazhagan et al., 2019) and add zero-mean Gaussian noise to the output of $h^l$ during training:

$$g^l(x) = \sigma\big(h^l(x) + \epsilon\big), \qquad \epsilon \sim \mathcal{N}(0, \alpha) \tag{3}$$

where the noise variance $\alpha$ increases linearly during the training process. Since a gate value can only be robust to this noise when $|h^l(x)|$ is large, the increasing schedule carries the pre-activation towards the saturation range of the gating function, which in turn forces the output of $g^l$ closer to the boundaries of $[0, 1]$. While $h^l$ could have any parameterization, for simplicity we restrict it to a feed-forward network with a single hidden layer.
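To make the two modes concrete, here is a minimal pure-Python sketch of one gated residual layer applied to a single token. It is not the authors' implementation: the helper names (`sigmoid`, `noise_variance`, `gated_layer`), the linear schedule endpoints, and discretizing at inference by thresholding the gate at 0.5 are illustrative assumptions.

```python
import math
import random
from typing import Callable, List, Optional, Tuple

Vector = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def noise_variance(step: int, total_steps: int, start: float, end: float) -> float:
    """Linearly increase the noise variance from `start` to `end`, then hold it."""
    frac = min(step / total_steps, 1.0)
    return start + frac * (end - start)


def gated_layer(
    x: Vector,
    f: Callable[[Vector], Vector],   # sub-network being controlled
    h: Callable[[Vector], float],    # control-network pre-activation
    training: bool,
    variance: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Tuple[Vector, float]:
    """Return (y, gate) for y = x + gate * f(x) on a single token vector x."""
    if rng is None:
        rng = random.Random()
    pre = h(x)
    if training:
        # Soft, noisy gate: sigmoid(h(x) + eps) with eps ~ N(0, variance).
        gate = sigmoid(pre + rng.gauss(0.0, math.sqrt(variance)))
        return [xi + gate * fi for xi, fi in zip(x, f(x))], gate
    # Inference: hard decision (assumed threshold 0.5); f is never called when off.
    if sigmoid(pre) > 0.5:
        return [xi + fi for xi, fi in zip(x, f(x))], 1.0
    return list(x), 0.0
```

With a large positive pre-activation the soft training gate is already close to 1, which is the saturated regime the noise schedule is meant to push the control networks into, so the hard inference decision then barely changes the layer's output.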


2.2 Modulating the Inference Budget

In the absence of any other training signal, we would expect the training loss to pull the model towards using all (or most) of its computation in order to maximize performance. To control the amount of computation utilized by the model we impose a computational budget loss in addition to the training objective.

For any layer $l$ with control network $g^l$, the expected cost of computation used by the model on a token $x_t$ is $g^l(x^l_t)\, c^l$, where $c^l$ is the cost of applying layer $l$ to one token. In this work, we define the cost of a layer as its computational cost in FLOPs.

Given a batch of $B$ sequences with $T$ time-steps, let $x_{b,t}$ be the $t$-th token of the $b$-th sequence. We define the computational budget as a fraction, $\beta$, of the maximum computation available for the batch. Then the computational budget on the given batch of tokens is:

$$C_\beta = \beta \sum_{b=1}^{B} \sum_{t=1}^{T} \sum_{l} c^l$$

The expected computation used by the proposed conditional computation model is determined by the activations of the control networks on individual tokens of the batch:

$$\hat{C} = \sum_{b=1}^{B} \sum_{t=1}^{T} \sum_{l} g^l(x^l_{b,t})\, c^l$$

Then we define the computational budget loss on the given batch of tokens, $\mathcal{L}_\beta$ (Equation 7), as a two-sided penalty on the deviation of $\hat{C}$ from $C_\beta$, which penalizes both exceeding and under-using the budget.

We impose a constraint on the total computation used for a batch, instead of the compute used for a single sequence or token. This looser constraint allows the model to allocate more computation for ‘difficult’ examples by using less compute for ‘easy’ examples. Empirically we find that using the batch-level constraint performs better, especially at lower computation budgets.
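The batch-level bookkeeping can be sketched in a few lines of Python. Two pieces are assumptions rather than details from the text: per-token layer costs are estimated by counting each multiply-add as 2 FLOPs, and, as one possible two-sided penalty, the budget loss is taken to be the squared relative deviation $((\hat{C} - C_\beta)/C_\beta)^2$; the paper's exact form may differ. The function names are hypothetical.

```python
from typing import Sequence, Tuple


def ffn_cost(d_model: int, d_ff: int) -> int:
    """Per-token FLOPs of a feed-forward layer (2 FLOPs per multiply-add)."""
    return 2 * d_model * d_ff + 2 * d_ff * d_model


def budget_loss(
    gates: Sequence[Sequence[Sequence[float]]],  # gates[b][t][l] in [0, 1]
    costs: Sequence[float],                      # costs[l]: per-token cost c^l
    beta: float,
) -> Tuple[float, float, float]:
    """Return (loss, C_hat, C_beta) for one batch.

    Sequences may have different lengths, so padding tokens are simply absent
    and the constraint applies to the batch as a whole, not per sequence.
    """
    n_tokens = sum(len(seq) for seq in gates)
    c_beta = beta * n_tokens * sum(costs)
    c_hat = sum(g * c for seq in gates for tok in seq for g, c in zip(tok, costs))
    # Two-sided: over- and under-use of the budget are both penalized.
    loss = ((c_hat - c_beta) / c_beta) ** 2
    return loss, c_hat, c_beta
```

Under this counting convention, a Transformer Big feed-forward layer (`ffn_cost(1024, 4096)`) costs 16,777,216 FLOPs per token. Because only the batch total is constrained, a batch can meet its budget with some tokens fully on and others fully off.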

Training a conditional computation model with the above loss allows operating that model at a single computation budget, $\beta$. Given a set of desired computational budgets, $\{\beta_1, \dots, \beta_N\}$, that we want the model to operate at, we utilize a simple multi-task training approach. We define a set of control symbols, $\{s_1, \dots, s_N\}$, which can be fed as additional inputs to the model, and associate each budget, $\beta_i$, with a control input $s_i$. Given a batch of training sequences, $X$, we assign each sequence uniformly at random to one of the budgets. Let the batch of sequences assigned to budget $\beta_i$ be $X_i$. The corresponding control symbol $s_i$ is then fed to the model as an additional input when training on sequences in $X_i$. By associating specific control inputs with different computational budgets, we train a single model to operate at specific levels of computation controlled by these external inputs. This is similar to approaches used for training multilingual Machine Translation models (Johnson et al., 2017) and other multi-task models.

The total budget loss function in this multi-task training setup is then:

$$\mathcal{L}_{budget} = \sum_{i=1}^{N} \mathcal{L}_{\beta_i}(X_i)$$

where $\mathcal{L}_{\beta_i}(X_i)$ is the computational budget loss defined above, evaluated on $X_i$ with budget $\beta_i$.

In certain cases it might be desirable to control the amount of computation spent on different sub-networks of the model independently. For example, in auto-regressive seq2seq models there is an inherent difference in the mode of operation of the encoder and decoder sub-networks. To control the budgets of $M$ sub-networks independently, our multi-budget formulation can be extended to allow vector-valued budgets, $\beta_i \in [0, 1]^M$. Each symbol $s_i$ then maps to a tuple of budgets, $(\beta_i^1, \dots, \beta_i^M)$, specifying the desired budget for each sub-network.

The generalized budget loss function can be described as:

$$\mathcal{L}_{budget} = \sum_{i=1}^{N} \sum_{m=1}^{M} \mathcal{L}^{m}_{\beta_i^m}(X_i)$$

where $\mathcal{L}^{m}_{\beta_i^m}$ is the $i$-th budget loss for sub-network $m$, computed over the costs of the layers in that sub-network.

Given a model adapted for conditional computation following the approaches described above, controlling the inference-time computation just requires feeding the right control input, $s_i$, corresponding to the desired budget $\beta_i$.
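A minimal sketch of the multi-budget recipe: each sequence is assigned uniformly at random to a control symbol, and the generalized loss sums per-sub-network penalties over the symbols that received sequences. The function names, the dictionary layout and the squared relative deviation used for each term are illustrative assumptions, not the paper's code.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def assign_control_symbols(batch: Sequence, n_symbols: int, rng: random.Random) -> Dict[int, List]:
    """Uniformly assign every sequence to one control symbol s_i (stored by index i)."""
    groups: Dict[int, List] = defaultdict(list)
    for seq in batch:
        groups[rng.randrange(n_symbols)].append(seq)
    return groups


def generalized_budget_loss(
    budgets: Sequence[Tuple[float, ...]],   # budgets[i] = (beta_i^1, ..., beta_i^M)
    expected: Dict[int, Sequence[float]],   # expected[i][m]: expected FLOPs of sub-network m on X_i
    max_cost: Dict[int, Sequence[float]],   # max_cost[i][m]: FLOPs of sub-network m on X_i if always on
) -> float:
    """Sum over symbols i and sub-networks m of a two-sided per-term penalty."""
    total = 0.0
    for i in expected:  # symbols with no assigned sequences contribute nothing
        for m, beta in enumerate(budgets[i]):
            target = beta * max_cost[i][m]
            total += ((expected[i][m] - target) / target) ** 2
    return total
```

With $M = 2$ this covers the encoder/decoder case: each tuple in `budgets` holds an encoder budget and a decoder budget.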

3 Conditional Computation Transformer

We now apply our approach to the Transformer architecture (Vaswani et al., 2017). We follow the newer Transformer layout, in which layer normalization (LN) is applied to the input of each sub-layer instead of its output (the ‘nda’ layout as implemented in the Tensor2Tensor library; Vaswani et al., 2018).

3.1 Conditional Attention Layer

Figure 2: Conditional Computation Attention Layer.

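Given a vector $x$ and a sequence of vectors, $Y = (y_1, \dots, y_S)$, the Transformer attention layer can be described by the following sequence of operations. The vectors to be attended are first projected to keys ($K$) and corresponding values ($V$):

$$k_j = \mathrm{LN}(y_j)\, W_K, \qquad v_j = \mathrm{LN}(y_j)\, W_V, \qquad j = 1, \dots, S \tag{10}$$

The projected input query is then used to attend to the keys, summarizing the set to be attended:

$$q = \mathrm{LN}(x)\, W_Q, \qquad z = x + \mathrm{MultiHead}(q, K, V)\, W_O \tag{11}$$

We introduce two control networks to control the execution of the operations defined by Equations 10 and 11 respectively. The first control network, $g_{kv}$, controls the execution of the key-value projections. The second network, $g_q$, controls the execution of the query projection, multi-headed attention and the attention post-projection. We also introduce additional normalization layers to stabilize training in the presence of discrete operations. During training we implement these changes as:

$$k_j = g_{kv}(y_j)\, \mathrm{LN}_k\big(\mathrm{LN}(y_j)\, W_K\big), \qquad v_j = g_{kv}(y_j)\, \mathrm{LN}_v\big(\mathrm{LN}(y_j)\, W_V\big)$$

$$z = x + g_q(x)\, \mathrm{LN}_o\big(\mathrm{MultiHead}(\mathrm{LN}(x)\, W_Q, K, V)\, W_O\big)$$

where $\mathrm{LN}_k$, $\mathrm{LN}_v$ and $\mathrm{LN}_o$ are the additional normalization layers applied to the gated sub-network outputs.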
The above modifications are applied to all self-attention and cross-attention layers.
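For intuition about what is skipped at inference, here is a single-head, pure-Python sketch of conditional attention for one query vector. It assumes the discretized gates are given as boolean functions, omits layer normalization and the multi-head split, and assumes that a position whose key-value gate is off contributes a zero key and value (what multiplying by a gate of 0 would produce). All names are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (d_out, d_in)


def matvec(W: Matrix, x: Sequence[float]) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def conditional_attention(
    x: Vector,
    ys: Sequence[Vector],            # non-empty sequence to attend over
    Wq: Matrix, Wk: Matrix, Wv: Matrix, Wo: Matrix,
    gate_q: Callable[[Vector], bool],
    gate_kv: Callable[[Vector], bool],
) -> Vector:
    """Residual attention output for one query vector x."""
    if not gate_q(x):
        # Query projection, attention and post-projection are all skipped.
        return list(x)
    keys, values = [], []
    for y in ys:
        if gate_kv(y):
            keys.append(matvec(Wk, y))
            values.append(matvec(Wv, y))
        else:
            # Projections skipped; treated as a gated-off (zero) key and value.
            keys.append([0.0] * len(Wk))
            values.append([0.0] * len(Wv))
    q = matvec(Wq, x)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    context = [sum(w * v[j] for w, v in zip(weights, values)) / total
               for j in range(len(Wv))]
    return [xi + oi for xi, oi in zip(x, matvec(Wo, context))]
```

Note that the key-value gate is evaluated once per attended position, so in self-attention a skipped projection saves work for every query that attends to that position.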

3.2 Conditional Feed-forward Layer

Figure 3: Conditional Computation Feedforward Layer.

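Given a vector, $x$, the Transformer feed-forward layer can be described by

$$\mathrm{FF}(x) = \mathrm{ReLU}\big(\mathrm{LN}(x)\, W_1\big)\, W_2$$

where $W_1 \in \mathbb{R}^{d \times d_{ff}}$ and $W_2 \in \mathbb{R}^{d_{ff} \times d}$. Then the output of the layer incorporating residual connections is given by $y = x + \mathrm{FF}(x)$.

Adding conditional execution for this layer, our output can be written as $y = x + g_{ff}(x)\, \mathrm{FF}(x)$.

While it is straightforward to add conditional execution for the entire feed-forward layer, we can optionally decompose the large feed-forward layer into $n$ independently controlled smaller layers to provide more granular control over feed-forward layer capacity:

$$y = x + \sum_{i=1}^{n} g_{ff}^{(i)}(x)\, \mathrm{FF}_i(x), \qquad \mathrm{FF}_i(x) = \mathrm{ReLU}\big(\mathrm{LN}_i(x)\, W_1^{(i)}\big)\, W_2^{(i)} \tag{15}$$

where $W_1^{(i)} \in \mathbb{R}^{d \times d_{ff}/n}$, $W_2^{(i)} \in \mathbb{R}^{d_{ff}/n \times d}$, and $g_{ff}$ maps the input to $[0, 1]^n$, with $g_{ff}^{(i)}$ denoting its $i$-th output. This decomposition makes our feed-forward layer similar to the Sparsely-Gated Mixture-of-Experts layer (Shazeer et al., 2017). However, in our approach the number of experts applied per input is a function of the input.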
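The decomposed layer can be sketched at inference time as follows: only the sub-layers whose discretized gate is on are evaluated, so the FLOPs spent scale with the number of active sub-layers. This pure-Python sketch omits the per-sub-layer layer normalization; the weight layout, the 2-FLOPs-per-multiply-add convention and the names are assumptions.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (d_out, d_in)


def matvec(W: Matrix, x: Sequence[float]) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(v: Sequence[float]) -> Vector:
    return [max(0.0, a) for a in v]


def conditional_ffn(
    x: Vector,
    sublayers: Sequence[Tuple[Matrix, Matrix]],  # (W1_i: d_ff/n x d, W2_i: d x d_ff/n)
    gates: Sequence[bool],                       # discretized outputs of g_ff, one per sub-layer
) -> Tuple[Vector, int]:
    """Return (y, flops) for y = x + sum of the active sub-layer outputs."""
    y = list(x)
    flops = 0
    for (W1, W2), active in zip(sublayers, gates):
        if not active:
            continue  # this sub-layer costs nothing for this token
        hidden = relu(matvec(W1, x))
        out = matvec(W2, hidden)
        y = [a + b for a, b in zip(y, out)]
        flops += 2 * len(W1) * len(x) + 2 * len(W2) * len(hidden)
    return y, flops
```

Splitting into more sub-layers does not change the full-budget cost; it only makes the set of achievable per-token costs finer grained.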

3.3 Feeding Control Input

For training with multiple computation budgets, as described in Section 2.2, we need to feed an additional control input, a control symbol, with every input sequence. Given an input sequence, the input to the Transformer layers is a sequence of embeddings, one per input symbol, each summed with the corresponding position embedding. In addition to the position embeddings, we learn an embedding for each control symbol. The embedding of the control symbol is then added to the embedding of each token before feeding it into the model.
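A minimal sketch of this input construction: the learned control embedding is added, together with the position embedding, to every token embedding. Plain Python lists stand in for the embedding tables, and the function name is hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def embed_with_control(
    token_ids: Sequence[int],
    control_id: int,
    token_table: Sequence[Vector],
    position_table: Sequence[Vector],
    control_table: Sequence[Vector],
) -> List[Vector]:
    """Return token + position + control embedding for every input position."""
    ctrl = control_table[control_id]  # the same control vector is added to every token
    return [
        [t + p + c for t, p, c in zip(token_table[tid], position_table[pos], ctrl)]
        for pos, tid in enumerate(token_ids)
    ]
```

Selecting a different budget at inference then only changes `control_id`; the model weights are unchanged.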

Figure 4: Comparing the performance of CCT (red) at different encoder-decoder computation budgets against Transformer baselines (blue). The x-axis corresponds to the average encoder-decoder per-token FLOPs (in millions). The Transformer network size is denoted next to each corresponding data-point using the format (hidden layer size, number of layers). Note: we do not count the computation required for embedding lookup and softmax operations.

Figure 5: Comparing the performance of CCT at different decoder computation budgets against Transformer baselines, when allowed to use full encoder computation. The x-axis corresponds to the decoder per-token FLOPs (in millions). Blue dots denote the quality of individual Transformer baselines. The decoder size is denoted next to each corresponding data-point using the format (hidden layer size, number of layers). Note: we do not count the computation required for embedding lookup and softmax operations.

4 Experiments on Machine Translation

Most machine translation models fall within the sequence-to-sequence paradigm (Sutskever et al., 2014; Bahdanau et al., 2015), with an encoder that learns representations of the source sequence and a decoder that generates the target sequence, trained on the cross-entropy loss $\mathcal{L}_{CE}$. Since the encoder and decoder operate differently at inference time (the encoder processes all source tokens simultaneously, while the decoder generates one token at a time), we allow their computation budgets to be controlled separately. To elaborate, we permit using a set of computation budgets $\{(\beta_i^{enc}, \beta_i^{dec})\}_{i=1}^{N}$. Here the first budget of every tuple, $\beta_i^{enc}$, corresponds to the desired encoder budget, while the second budget, $\beta_i^{dec}$, corresponds to the desired decoder budget. The control symbol, $s_i$, is fed as an embedding added to every source and target token. We train the model end-to-end on $\mathcal{L}_{CE} + \lambda\, \mathcal{L}_{budget}$, where $\mathcal{L}_{budget}$ is the generalized budget loss from Section 2.2 and $\lambda$ is its weight.
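The encoder/decoder control tasks can be enumerated as the cross-product of an encoder budget set and a decoder budget set, one control symbol per tuple. A sketch, in which the budget values are illustrative assumptions rather than the experimental settings:

```python
from itertools import product
from typing import Dict, Sequence, Tuple


def make_control_tasks(
    encoder_budgets: Sequence[float], decoder_budgets: Sequence[float]
) -> Dict[int, Tuple[float, float]]:
    """Map each control-symbol id i to its (encoder, decoder) budget tuple."""
    return dict(enumerate(product(encoder_budgets, decoder_budgets)))


# Illustrative values only: 4 encoder budgets x 4 decoder budgets = 16 tasks.
tasks = make_control_tasks([0.25, 0.5, 0.75, 1.0], [0.25, 0.5, 0.75, 1.0])
```

`itertools.product` iterates the encoder budget in the outer loop, so symbol 1 pairs the smallest encoder budget with the second decoder budget.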

We now evaluate our approach on the WMT’14 English-French translation task. We use newstest13 for validation and newstest14 for test. BLEU scores are computed on tokenized, true-cased outputs and references with Moses.

We train a Transformer Big (Vaswani et al., 2017) model as our baseline. For smaller budget baselines, we reduce the capacity of our Transformer following two approaches: (i) Reducing the model depth by reducing the number of layers and (ii) Reducing the model width by reducing the hidden dimension of the feed-forward layers. We compare these baselines against a single CCT model operating at different computation budgets, with a maximum capacity equivalent to Transformer Big. Since decoder computation is typically the bottleneck for Transformer inference, we also train additional baselines where we only reduce the capacity of the decoder, while using a full Transformer Big encoder. These baselines are compared against the same CCT model from the previous comparison, but use full encoder computation while varying the decoder budget.

We use a Transformer learning rate schedule (Vaswani et al., 2017) of (3.0, 40K), shorthand for a learning rate of 3.0 with 40K warm-up steps, decayed with the inverse square root of the number of training steps after warm-up. All dropout probabilities are set to 0.1. For all our models, we use a shared-vocabulary SentencePiece model (Kudo & Richardson, 2018) for sub-word tokenization, with a vocabulary size of 32,000 tokens. We train each model for 300k steps with batches of 250k tokens. The CCT is trained with the same set of hyper-parameters. In addition, we set the weight of the budget loss and draw the encoder and decoder budgets from a fixed set of computation budgets, which results in one control task for each (encoder, decoder) tuple in the cross-product. We empirically find that allowing half of the control tasks to use their entire computational budgets strikes a good balance between properly training all parameters and learning to operate at reduced budgets. The noise variance in Equation 3 is linearly increased during training, from its initial value at the first step to its final value at 300k steps. For these experiments, we break the feed-forward layer into smaller, independently gated layers (Equation 15). All our models are trained on 32 Cloud TPUv3 chips and evaluated at 300k steps.
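One way to read the (3.0, 40K) shorthand is the inverse-square-root schedule of Vaswani et al. (2017). The sketch below assumes that 3.0 multiplies that schedule's $d_{model}^{-0.5} \min(\text{step} \cdot w^{-1.5}, \text{step}^{-0.5})$ form with $d_{model} = 1024$ (Transformer Big); the text does not state whether the model-dimension factor is included, and the function name is hypothetical.

```python
def transformer_lr(step: int, scale: float = 3.0, warmup: int = 40_000, d_model: int = 1024) -> float:
    """Linear warm-up for `warmup` steps, then decay proportional to step ** -0.5."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return scale * d_model ** -0.5 * min(step * warmup ** -1.5, step ** -0.5)
```

Under these assumptions the rate peaks at step 40K at about 4.7e-4 and halves by step 160K, since the decay is proportional to the inverse square root of the step.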

Figure 4 compares a single CCT against individual Transformer models with different amounts of encoder and decoder capacity. Our results suggest that CCT is competitive with Transformer Big even when operating at half its computation budget. At smaller computation budgets, CCT improves over smaller baseline Transformer models by up to 1-1.5 BLEU. Figure 5 depicts the results of our second experiment, comparing CCT against baseline Transformers when using a large encoder (equivalent to Transformer Big) and controlling for decoder computation. We observe a similar trend, with CCT being competitive with Transformers at higher computation budgets, while improving over baselines by almost 1 BLEU at reduced budgets.

5 Experiments with BERT

Figure 6: Comparing the performance of CCT at different encoder computation budgets against BERT baselines of different sizes on MNLI. The Transformer network size is denoted next to each corresponding data-point using the format (model size, number of layers).
Figure 7: Comparing the performance of CCT at different encoder computation budgets against BERT baselines of different sizes on SST-2. The Transformer network size is denoted next to each corresponding data-point using the format (model size, number of layers).

BERT (Devlin et al., 2018) uses a Masked Language Modeling objective, $\mathcal{L}_{MLM}$, to learn token-level representations of text using a Transformer architecture. Following the pre-training stage, the model is fine-tuned on individual tasks by training on (smaller) task-specific datasets and objectives. To control the amount of computation used by BERT for generating representations of text, we replace the Transformer model in the original BERT implementation with our CCT from Section 3. We train this model following the multi-computation-budget recipe described in Section 2.2. The objective function used for training this model is $\mathcal{L}_{MLM} + \lambda\, \mathcal{L}_{budget}$, where $\mathcal{L}_{budget}$ is the multi-budget loss from Section 2.2 and $\lambda$ is its weight. When fine-tuning on a downstream task, we use a different value of $\lambda$, scaled to the new objective.

Figure 8: Comparing the performance of CCT at different encoder computation budgets against BERT baselines of different sizes on CoLA. The Transformer network size is denoted next to each corresponding data-point using the format (model size, number of layers).
Figure 9: Comparing the performance of CCT at different encoder computation budgets against BERT baselines of different sizes on SQuAD. The Transformer network size is denoted next to each corresponding data-point using the format (model size, number of layers).

We train a BERT-Large (Devlin et al., 2018) model as our baseline. For smaller budget baselines, we reduce the capacity of BERT following two approaches: (i) Reducing the model depth by reducing the number of layers and (ii) Reducing the width by reducing the model dimension and hidden dimension of the feed-forward layers, maintaining a ratio of 4 between the model dimension and feed-forward hidden dimension. We compare these baselines against a single CCT model operating at different computation budgets, with maximum capacity equivalent to BERT-Large.

We use the same pre-training process as Devlin et al. (2018), with one difference: we train on sequences of length 512 for 1M steps with a batch size of 1024, instead of training on shorter sequences for 900k steps and then continuing with longer sequences. The CCT is trained with the same set of hyper-parameters. In addition, we set the weight of the budget loss and use a fixed set of computation budgets. The noise variance in Equation 3 is linearly increased during training, from its initial value at the first step to its final value at 300k steps, after which it is held constant. For these experiments, we break the feed-forward layer into smaller, independently gated layers (Equation 15). All our models are trained on 64 Cloud TPUv3 chips.

Figure 10: Comparing the fraction of tokens for which the feed-forward control network is active (equivalently, the fraction of times the feed-forward sub-network was executed) for different encoder layers at different budgets.
Figure 11: Comparing the performance of CCT when using linearly increasing Gaussian noise (variance increased from its initial value at the first step to its final value at 300k steps) against using noisy discrete decisions from the beginning of training (variance fixed at its final value for the entire process).
Figure 12: Comparing the performance of CCT when using different splits of the feed-forward layer (different numbers of sub-layers in Equation 15).

When fine-tuning BERT baselines on downstream tasks, we search over the same hyper-parameter grid used in Devlin et al. (2018). We re-use the fine-tuning parameters of BERT-Large for fine-tuning CCT. The weight of the budget loss used for fine-tuning CCT differs from pre-training, to scale to the downstream task loss. We report validation performance over 3 runs on three GLUE benchmark (Wang et al., 2019) tasks, MNLI, SST-2 and CoLA, and on SQuAD. A comparison of CCT with comparable baselines on MNLI, SST-2, CoLA and SQuAD is depicted in Figures 6, 7, 8 and 9 respectively. On MNLI and SST-2 we see a trend similar to translation: the performance of CCT is close to that of BERT-Large at the highest computation setting, while improving significantly over baselines at smaller computation budgets. On CoLA we see the reverse trend: CCT improves by a significant margin at the highest computation setting while losing to baselines at smaller computation budgets.

The performance of CCT on SQuAD is worse than the baselines at all computation budgets. It is worth noting that SQuAD is the only benchmark task that uses token-level outputs from the pre-trained representations, while all other tasks act on a pooled representation of the entire sequence. The weak performance on SQuAD suggests that token-level representations extracted from CCT-BERT, without pooling, might not perform as well as those from a static architecture that uses the same set of operations for every token.

6 Additional Insights

To shed more light on the training dynamics and the factors affecting the quality and performance of the proposed CCT approach, we conducted further analyses probing various aspects.

Layer usage under a limited budget

To understand the distribution of computation across layers at different budgets, we look at the fraction of times the feed-forward control network is active for different layers of our machine translation model. Figure 10 depicts the fraction of feed-forward operations applied per token in different encoder layers at different budgets. Our results suggest that encoders utilize more feed-forward computation at higher levels of the stack. We observe the same trend for the decoder. From analyzing the outputs of control networks for different sub-networks, we also find that lower layers tend to be either active or inactive for most tokens for a given inference budget, while upper layers have different outputs for different tokens. This behaviour could have two possible explanations: (i) lower layers perform more general operations applicable to all tokens, while higher layers apply more input-specific operations, or (ii) control networks of layers lower in the stack might be difficult to train, resulting in the trivial solution, i.e. ‘on’ for high computation budgets and ‘off’ for smaller budgets.
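The per-layer statistic behind this analysis is simple to compute from logged control decisions. A sketch, assuming decisions are recorded as 0/1 values per token and layer; the tolerance used to call a layer token-dependent is a hypothetical choice, and the function names are illustrative.

```python
from typing import List, Sequence


def layer_activity(decisions: Sequence[Sequence[int]]) -> List[float]:
    """decisions[t][l] in {0, 1}: was layer l's sub-network executed on token t?"""
    n_tokens = len(decisions)
    n_layers = len(decisions[0])
    return [sum(tok[l] for tok in decisions) / n_tokens for l in range(n_layers)]


def token_dependent_layers(activity: Sequence[float], tol: float = 0.05) -> List[int]:
    """Layers that are neither (almost) always on nor (almost) always off."""
    return [l for l, a in enumerate(activity) if tol < a < 1.0 - tol]
```

In the pattern described above, lower layers would show activity near 0 or 1, while upper layers would appear in `token_dependent_layers`.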

Importance of the Noise Schedule

We next attempt to understand the role of gradually increasing the noise variance when training with discrete decisions. We compare a CCT model trained with a linearly increasing noise variance (Equation 3) against one where the variance is set to its highest value from the beginning of training. From Figure 11, we notice that the two models are within 0.3 BLEU of each other at a budget of …, with the difference increasing to 0.5 BLEU for budgets of … and …. At a budget of …, the performance of the discrete model deteriorates much faster, with the difference growing to more than 1 BLEU. This suggests that the quality of control network training has a larger effect on model performance at smaller computation budgets.

Importance of Parallel Sub-Networks

Deciding how to divide the model’s computation graph into sub-networks controlled by different control networks can have a significant impact on model quality. For most of our experiments we split each feed-forward layer into 4 smaller, independently controlled feed-forward layers (4 sub-layers in Equation 15). We compare the effect of splitting the feed-forward sub-network at different granularities (varying the number of sub-layers). Our results in Figure 12 suggest that finer control over how network computation is utilized, by having control networks for smaller sub-networks, significantly impacts model quality, especially at lower computation budgets.

Tricks of the trade

We list some tricks and observations that were empirically found to be useful during the course of this work.

(i) Careful normalization was critical for stable training and good model quality. This includes additional layer normalization applied to every gated sub-network output and using separate layer normalization for the input of every independently gated feed-forward sub-network (i.e. separate layer-normalization for each of the feed-forward layers).

(ii) The range of budgets (the set of budgets from Section 2.2) used during training affected model quality. We observed significant quality deterioration when one of the values in the set was too low. For BERT experiments, one such setting resulted in a worse MLM loss at one of the budgets, while performance at the other budgets was not severely impacted. We suspect this is caused by the special role of the MASK token during pre-training.

(iii) Even with all the stabilization approaches, some of our runs deteriorated in performance when trained further beyond convergence.

(iv) Varying control network capacity did not have a large effect on model quality within the range of hidden dimensions we evaluated.

(v) The proposed computational budget loss (Equation 7) is two-sided: it also penalizes the model for under-utilization. While this is counter-intuitive, in practice we found that not penalizing the model for using less computation led to under-training of certain sub-networks and, in turn, to sub-optimal downstream quality.

7 Related Work

Activating a sub-network depending on the particular input example has been the focus of conditional computation approaches (Bengio et al., 2013; Davis & Arel, 2014). Following this line of research, Cho & Bengio (2014) studied increasing the capacity of neural networks without increasing required computation by exploiting the bit patterns associated with hidden units. As a majority of conditional computation approaches make use of stochastic binary units that pose trainability challenges, Bengio et al. (2015) cast the problem as a reinforcement learning problem and proposed a policy that maps the activations of layers to Bernoulli masks. Graves (2016) proposed the first application of conditional computation to neural sequence models, called Adaptive Computation Time (ACT), where a recurrent neural network is trained to learn the lag between reading an input and generating the output.

The recently introduced Transformer architecture (Vaswani et al., 2017) has allowed researchers to train neural networks with billions of parameters, reaffirming the need for more efficient and adaptive models. Universal Transformer (Dehghani et al., 2018) addressed the parameter inefficiency problem of Transformers by tying the weights of consecutive layers and utilizing ACT to decide the halting of such recurrence. The recently proposed depth-adaptive Transformer (DAT) (Elbayad et al., 2019) is perhaps the most similar to our approach. In DAT, decoder layers are equipped with halting classifiers that decide whether to exit and predict the output or continue processing, extending the ACT framework. DAT requires explicit supervision from oracles (or implicit supervision from multiple softmax computations) to train halting classifiers, restricting the approach to specific applications (like decoders in sequence-to-sequence models). Our approach trains control networks end-to-end with the rest of the model, allowing us to extend it to a wider range of sub-networks not directly connected with the final classifier (for example, key-value projections in self-attention layers or encoder layers in seq2seq models).

Fan et al. (2019) propose another approach to control the inference-time computation budget. Their method applies structured pruning (in the form of layer dropout) to yield a varying number of shallower networks that can be used at inference time. Our approach, however, results in a single network that can simultaneously adapt to the difficulty of the input example and to the computation budget available during inference.

In addition, parallels can be made with approaches utilizing mixture-of-experts (MoE) (Masoudnia & Ebrahimpour, 2014), where different examples are routed to different experts in order to maximize the output diversity (Shen et al., 2019) or device utilization (Shazeer et al., 2017).

8 Conclusion

In this work we present a general framework to adapt neural sequence models (Transformer) for conditional computation and control the amount of computation used at inference. Our proposed approach injects simple control networks into the core computation graph, in order to modulate the information flow through the network. The incorporated control networks are trained end-to-end simultaneously with the model, simulating the binary decisions to be made at inference time. We also introduce a novel multi-task objective that allows the network to operate at multiple computation budgets at inference time efficiently, addressing the need for on-demand computation requirements of large networks. Experiments on large scale machine translation (WMT’14 English-French) and unsupervised representation learning (BERT) demonstrate that our proposed approach is competitive with baseline Transformer models at the same computation budget, and significantly better at smaller computational budgets compared to computationally equivalent baselines.

We believe more analysis is needed to understand the decisions made by control networks and the behavior of the final sub-network on inputs with different levels of complexity.


We would like to thank the Google Translate and Tensorflow Lingvo teams for foundational contributions to the project.