Zero Time Waste: Recycling Predictions in Early Exit Neural Networks

06/09/2021 · Maciej Wołczyk et al., Jagiellonian University

The problem of reducing processing time of large deep learning models is a fundamental challenge in many real-world applications. Early exit methods strive towards this goal by attaching additional Internal Classifiers (ICs) to intermediate layers of a neural network. ICs can quickly return predictions for easy examples and, as a result, reduce the average inference time of the whole model. However, if a particular IC does not decide to return an answer early, its predictions are discarded, with its computations effectively being wasted. To solve this issue, we introduce Zero Time Waste (ZTW), a novel approach in which each IC reuses predictions returned by its predecessors by (1) adding direct connections between ICs and (2) combining previous outputs in an ensemble-like manner. We conduct extensive experiments across various datasets and architectures to demonstrate that ZTW achieves a significantly better accuracy vs. inference time trade-off than other recently proposed early exit methods.


1 Introduction

(a) Comparison of the proposed ZTW (bottom) with a conventional early-exit model, SDN (top).
(b) Detailed scheme of the proposed ZTW model architecture.
Figure 1: (a) In both approaches, internal classifiers (ICs) attached to the intermediate hidden layers of the base network allow us to return predictions quickly for examples that are easy to process. While SDN discards predictions of uncertain ICs (e.g. below a threshold of 75%), ZTW reuses computations from all previous ICs, which prevents information loss and waste of computational resources. (b) The backbone network lends its hidden-layer activations to the ICs, which share inferred information using cascade connections (red horizontal arrows in the middle row) and give predictions. The inferred predictions are combined by ensembling (bottom row), giving the final outputs.

Deep learning models achieve tremendous successes across a multitude of tasks, yet their training and inference often yield high computational costs and long processing times [13, 22].

For some applications, however, efficiency remains a critical challenge: to deploy a reinforcement learning (RL) system in production, policy inference must be done in real time [8], and robot performance suffers from the delay between measuring a system state and acting upon it [32]. Similarly, long inference latency in autonomous cars could impair their ability to control speed [15] and lead to accidents [12, 18].

Typical approaches to reducing the processing complexity of neural networks in latency-critical applications include compressing the model [24, 25, 40] or approximating its responses [21]. For instance, [25] proposes to compress an RL model by policy pruning, while [21] approximates the responses of LSTM-based modules in self-driving cars to accelerate inference. While those methods improve processing efficiency, they still require samples to pass through the entire model. In contrast, biological neural networks leverage simple heuristics to speed up decision making, e.g. by shortening the processing path even in the case of complex tasks [1, 11, 19].

This observation led to the inception of so-called early exit methods, such as Shallow-Deep Networks (SDN) [20] and Patience-based Early Exit (PBEE) [41], which attach simple classification heads, called internal classifiers (ICs), to selected hidden layers of neural models to shorten processing time. If the prediction confidence of a given IC is sufficiently high, its response is returned; otherwise, the example is passed to the subsequent classifier. Although these models achieve promising results, they discard the response returned by early ICs when evaluating the next IC, disregarding potentially valuable information (e.g. decision confidence) and wasting computational effort already incurred.

Motivated by the above observation, we postulate to look at the problem of neural model processing efficiency from the information-recycling perspective and introduce a new family of zero waste models. More specifically, we investigate how information available at different layers of neural models can contribute to the decision process of the entire model. To that end, we propose Zero Time Waste (ZTW), a method for intelligent aggregation of the information from previous ICs. A high-level view of our model is given in Figure 1. Our approach relies on combining ideas from networks with skip connections [37], gradient boosting [4], and ensemble learning [10, 23]. Skip connections between subsequent ICs (which we call cascade connections) allow us to explicitly pass the information contained within low-level features to the deeper classifiers, forming a cascading structure of ICs. In consequence, each IC improves on the predictions of previous ICs, as in gradient boosting, instead of generating them from scratch. To give every IC the opportunity to explicitly reuse the predictions of all previous ICs, we additionally build an ensemble of shallow ICs.

We evaluate our approach on standard classification benchmarks, such as CIFAR-100 and Tiny ImageNet, as well as on more latency-critical applications, such as reinforcement-learned models interacting with sequential environments. To the best of our knowledge, we are the first to show that early exit methods can be used to cut computational waste in a reinforcement learning setting.

The results show that ZTW saves much more computation while preserving accuracy than current state-of-the-art early exit methods. In order to better understand where the improvements come from, we introduce Hindsight Improvability, a metric measuring how efficiently a model reuses information from the past. We provide an ablation study and additional analysis of the proposed method in the Appendix.

To summarize, the contributions of our work are the following:

  • We introduce a family of zero waste models that quantify neural network efficiency with the Hindsight Improvability metric.

  • We propose an instance of zero waste models, dubbed the Zero Time Waste (ZTW) method, which uses cascade connections and ensembling to reuse the responses of previous ICs for the final decision.

  • We show how the state-of-the-art performance of ZTW in the supervised learning scenario generalizes to reinforcement learning.

2 Related Work

The drive towards reducing computational waste in the deep learning literature has so far focused on reducing inference time. Numerous approaches for accelerating deep learning models focus on building more efficient architectures [17], reducing the number of parameters [14], or distilling knowledge into smaller networks [16]. Thus, they decrease inference time by reducing the overall complexity of the model instead of using the conditional computation framework of adapting computational effort to each example. As such, we find them orthogonal to the main ideas of our work; e.g. we show that applying our method to architectures designed for efficiency, such as MobileNet [17], leads to even further acceleration. Hence, we focus here on methods that adaptively reduce the inference time for each example.

Conditional Computation

Conditional computation was first proposed for deep neural networks in [3] and [6], and numerous approaches to the early exit problem have appeared since then. In BranchyNet [36], a loss function consisting of a weighted sum of individual head losses is used in training, and the entropy of a head's prediction serves as the early exit criterion. Berestizshevsky and Even [5] propose to use confidence (the maximum of the softmax output) instead. Other approaches to conditional computation include framing it as a reinforcement learning problem [2], skipping intermediate layers [9, 37] or channels [38]. An overview of early exit methods is available in [30].

Shallow-Deep Networks (SDN) [20] is a conceptually simple yet effective method, in which the comparison of confidence with a fixed threshold is used as the exit criterion. The authors attach internal classifiers to layers selected based on the number of compute operations needed to reach them. The answer of each head is independent of the answers of the previous heads, although the authors analyze a measure of disagreement between the predictions of the final and intermediate heads.

Zhou et al. [41] propose the Patience-based Early Exit (PBEE) method, which terminates inference after a fixed number of consecutive unchanged answers, and show that it outperforms SDN on a range of NLP tasks. The idea of checking for agreement among preceding ICs is related to our approach of reusing information from the past. However, we find that applying PBEE in our setting does not always work better than SDN. Additionally, in the experiments of the original work, PBEE was trained simultaneously with the base network, making it impossible to preserve the original pre-trained model.

Ensembles

Ensembling is typically used to improve the accuracy of machine learning models [7]. Lakshminarayanan et al. [23] showed that it also greatly improves the calibration of deep neural networks. There have been several attempts to create an ensemble from different layers of a network. Scardapane et al. [29] adaptively exploit outputs of all internal classifiers, albeit not in a conditional computation context. Phuong & Lampert [27] used averaged answers of the heads up to the current head for anytime prediction, where the computational budget is unknown. Besides their method being much more basic, their setup is notably different from ours, as it assumes the same computational budget for all samples no matter how difficult the example is. Finally, none of these methods are designed to work with pre-trained models.

3 Zero Time Waste

Our goal is to reduce computational costs of neural networks by minimizing redundant operations and information loss. To achieve it, we use the conditional computation setting, in which we dynamically select the route of an input example in a neural network. By controlling the computational route, we can decide how the information is stored and utilized within the model for each particular example. Intuitively, difficult examples require more resources to process, but using the same amount of compute for the easy examples is wasteful. Below we describe our Zero Time Waste method in detail.

In order to adapt already trained models to conditional computation setting, we attach and train early exit classifier heads on top of several layers, without changing the parameters of the base network. During inference, the whole model exits through one of them when the response is likely enough, thus saving computational resources.

Formally, we consider a multi-class classification problem, where $x \in \mathcal{X}$ denotes an input example and $y \in \{1, \dots, K\}$ is its target class. Let $f_\theta : \mathcal{X} \to \mathbb{R}^K$ be a pre-trained neural network with logit output, designed for solving the above classification task. The weights $\theta$ will not be modified.

Model overview

Following typical early exit frameworks, we add shallow Internal Classifiers, $IC_1, \dots, IC_M$, on intermediate layers of $f_\theta$. Namely, let $IC_m$, for $m \in \{1, \dots, M\}$, be the $m$-th shallow network returning logits, attached to the hidden layer $g_m$ of the base network $f_\theta$. The index $m$ is independent of layer numbering. In general, $M$ is lower than the overall number of hidden layers, since we do not add early exits after every layer (see more details in Appendix A.1).

Although using ICs to return an answer early can reduce overall computation time [20], in a standard setting each IC makes its decision independently, ignoring the responses returned by previous ICs. As we show in Section 4.2, early layers often give correct answers for examples that are misclassified by later classifiers, and hence discarding their information leads to waste and performance drops. To address this issue, we need mechanisms that collect the information from the first $m-1$ ICs to inform the decision of $IC_m$. For this purpose, we introduce two complementary techniques: cascade connections and ensembling, and show how they help reduce information waste and, in turn, accelerate the model.

Cascade connections directly transfer already inferred information between consecutive ICs instead of re-computing it. Thus, they improve the performance of initial ICs that lack enough predictive power to classify correctly based on low-level features alone. Ensembling of individual ICs improves performance as the number of members increases, thus showing the greatest improvements in the deeper part of the network. This is visualized in Figure 1, where cascade connections are first used to pass already inferred information to later ICs, while ensembling is utilized to conclude the prediction. The details of these two techniques are presented in the following paragraphs.

Cascade connections

Inspired by gradient boosting, we allow each IC to improve on the predictions of previous ICs instead of inferring them from scratch. The idea of cascade connections is implemented by adding skip connections that combine the output of the base model's hidden layer $g_m$ with the logits of $IC_{m-1}$ and pass them to $IC_m$. The prediction of $IC_m$ is realized by the softmax function applied to $f_m$ (the $m$-th shallow network):

$$ p_m(x) = (\mathrm{softmax} \circ f_m)(x), \qquad f_m(x) = \phi_m\big(g_m(x), f_{m-1}(x)\big), \tag{1} $$

where $\circ$ denotes the composition of functions and $\phi_m$ is the $m$-th shallow head (for $m = 1$, $f_1(x) = \phi_1(g_1(x))$). Formally, $f_m = f_m(\cdot\,; \theta_m)$, where $\theta_m$ are the trainable parameters of $IC_m$, but we drop these parameters in notation for brevity. $IC_1$ uses only the information coming from the layer $g_1$, which does not need to be the first hidden layer of $f_\theta$. Figure 1 shows the skip connections as red horizontal arrows.

Each $IC_m$ is trained in parallel (with respect to the other ICs) to optimize the prediction of all output classes using an appropriate loss function $\mathcal{L}$, e.g. cross-entropy for classification. However, during the backward step it is crucial to stop the gradient of the loss function from passing to the previous classifier. Allowing the gradients of the loss of $IC_m$ to affect $IC_j$ for $j < m$ leads to a significant performance degradation of earlier layers due to an increased focus on the features important for $IC_m$, as we show in Appendix C.3.
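To make the cascade connection concrete, here is a minimal numpy sketch of a chain of IC heads, each consuming the backbone's hidden activation together with the previous IC's logits. This is not the authors' implementation; the class and variable names are ours, and in a real training setup the previous logits would be detached (stop-gradient) during backpropagation, as described above.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

class CascadeIC:
    """One internal classifier that sees both the backbone's hidden
    activation and the logits of the previous IC (cascade connection)."""
    def __init__(self, hidden_dim, num_classes):
        in_dim = hidden_dim + num_classes  # concatenated input
        self.W = rng.normal(0, 0.01, size=(in_dim, num_classes))
        self.b = np.zeros(num_classes)

    def forward(self, h, prev_logits):
        # During training, prev_logits would be detached (stop-gradient)
        # so that the loss of this IC does not degrade earlier ICs.
        x = np.concatenate([h, prev_logits], axis=-1)
        return x @ self.W + self.b  # logits of this IC
```

Running a batch through a chain of such heads, with zero logits fed to the first one, yields per-IC logits that can later be combined by the ensembling stage.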

Ensembling

Ensembling reliably increases the performance of machine learning models while improving robustness and uncertainty estimation [10, 23]. The main drawback of this approach is its wastefulness, as it requires training multiple models and using all of them to process the same examples. In our setup, however, we can adopt this idea to combine predictions which were already pre-computed in previous ICs, at near-zero additional computational cost.

To obtain a reliable zero-waste system, we build ensembles that combine the outputs of the first $m$ ICs to provide the final answer of the $m$-th classifier. Since the classifiers we are using vary significantly in predictive strength (later ICs achieve better performance than early ICs) and their predictions are correlated, the standard approach to deep model ensembling does not work in our case. Thus, we introduce a weighted geometric mean with class balancing, which allows us to reliably find a combination of pre-computed responses that maximizes the expected result.

Let $p_1, \dots, p_m$ be the outputs of the consecutive ICs (after the cascade connections stage) for a given input $x$ (Figure 1). We define the probability of the $i$-th class in the $m$-th ensemble to be:

$$ q_m^i(x) = \frac{1}{Z_m} \, b_i^m \prod_{j=1}^{m} p_j^i(x)^{w_j^m}, \tag{2} $$

where $w_j^m \geq 0$ and $b_i^m \geq 0$, for $j \in \{1, \dots, m\}$ and $i \in \{1, \dots, K\}$, are trainable parameters, and $Z_m$ is a normalization factor such that $\sum_i q_m^i(x) = 1$. Observe that $w_j^m$ can be interpreted as our prior belief in the predictions of $IC_j$. On the other hand, $b_i^m$ represents the prior of the $i$-th class for the $m$-th ensemble. The superscripts in $w_j^m$ and $b_i^m$ are needed, as these parameters are trained independently for each ensemble $\{IC_1, \dots, IC_m\}$. Although there are viable approaches to setting these parameters by hand, we verified that optimizing them directly by minimizing the cross-entropy loss on the training dataset works best.

Out of the additive and geometric ensemble settings, we found the latter preferable. In this formulation, a low class confidence of a single IC significantly reduces the probability of that class in the whole ensemble. In consequence, for the confidence of a given class to be high, we require all ICs to be confident in that class. Thus, in geometric ensembling, a confident but incorrect answer has less chance of ending the calculations prematurely. In the additive setting, the negative impact of a single confident but incorrect IC is much higher, as we show in Appendix C.2. Hence our choice of geometric ensembling.

Direct calculation of the product in (2) might lead to numerical instabilities whenever the probabilities are close to zero. To avoid this problem, we note that

$$ \ln q_m^i(x) = \ln b_i^m + \sum_{j=1}^{m} w_j^m \ln p_j^i(x) - \ln Z_m, $$

and that the log-probabilities $\ln p_j^i$ can be obtained by running the numerically stable log-softmax function on the logits of each classifier.
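The log-space computation above can be sketched in a few lines of numpy; this is an illustration under our own function names, not the released code:

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def geometric_ensemble(logits_list, weights, class_prior):
    """Weighted geometric mean of IC predictions, computed in log space:
    log q_i = log b_i + sum_j w_j * log p_j^i, then renormalized
    (the Z factor from Eq. 2)."""
    log_q = np.log(class_prior)
    for w, logits in zip(weights, logits_list):
        log_q = log_q + w * log_softmax(logits)
    log_q = log_q - log_q.max(axis=-1, keepdims=True)  # stability
    q = np.exp(log_q)
    return q / q.sum(axis=-1, keepdims=True)
```

With weights of zero for all but one IC and a uniform prior, the ensemble reduces to that IC's softmax output, which is a useful sanity check.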

Input: pre-trained model $f_\theta$, cross-entropy loss function $\mathcal{L}$, training set $\mathcal{D}$
Initialize shallow models $IC_1, \dots, IC_M$ at the selected layers $g_1, \dots, g_M$
for $m = 1, \dots, M$ do in parallel   // cascade-connect the ICs
    set $p_m$ according to (1)
    minimize $\mathcal{L}(p_m)$ w.r.t. $\theta_m$ by gradient descent
for $m = 1, \dots, M$ do   // geometric ensembling
    initialize $w^m$, $b^m$ and define $q_m$ according to (2)
    minimize $\mathcal{L}(q_m)$ w.r.t. $w^m$, $b^m$ by gradient descent
Algorithm 1: Zero Time Waste

Both cascade connections and ensembling impact the model differently. Cascade connections primarily boost the accuracy of early ICs. Ensembling, on the other hand, primarily improves the performance of later ICs, which combine the information from many previous classifiers.

This is not surprising, given that the power of an ensemble increases with the number of members, provided they are at least weak learners in the sense of boosting theory [31]. As such, the two techniques introduced above are complementary. The whole training procedure is presented in Algorithm 1.

Conditional inference

Once a ZTW model is trained, a question appears: how do we use the constructed system at test time? More precisely, we need to dynamically find the shortest processing path for a given input example. For this purpose, we use one of the standard confidence scores, given by the probability of the most confident class. If the $m$-th classifier is confident enough about its prediction, i.e. if

$$ \max_i q_m^i(x) > \tau, \tag{3} $$

where $i$ is the class index and $\tau$ is a fixed threshold, then we terminate the computation and return the response of this IC. If this condition is not satisfied, we continue processing and move on to the next IC.

The threshold $\tau$ in (3) is a manually selected value which controls the acceleration-performance trade-off of the model. A lower threshold leads to a significant speed-up at the cost of a possible drop in accuracy. Observe that for $\tau = 1$ we recover the original model $f_\theta$, since none of the ICs is ever confident enough to answer earlier. In practice, to select an appropriate value, we advise evaluating a range of possible values of $\tau$ on a held-out set.

4 Experiments

In this section we examine the performance of Zero Time Waste and analyze its impact on waste reduction in comparison to two recently proposed early-exit methods: (1) Shallow-Deep Networks (SDN) [20] and (2) Patience-Based Early Exit (PBEE) [41]. In contrast to SDN and PBEE, which train ICs independently, ZTW reuses the information from past classifiers to improve performance. SDN and ZTW use the maximum class probability as the confidence estimator, while PBEE checks the number of classifiers in sequence that gave the same prediction. For example, a patience of $t$ in PBEE means that if the answer of the current IC is the same as the answers of the $t$ preceding ICs, we return that answer; otherwise we continue the computation.
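For comparison, the PBEE patience criterion described above can be sketched as follows (our own illustrative helper; the original implementation may differ):

```python
def pbee_exit(predictions, patience):
    """Patience-based early exit: exit once `patience` consecutive ICs in
    a row agree with the current prediction. `predictions` is a list of
    class indices, one per IC, in order. Returns (answer, exit_index)."""
    streak = 0
    for m in range(1, len(predictions)):
        streak = streak + 1 if predictions[m] == predictions[m - 1] else 0
        if streak >= patience:
            return predictions[m], m
    return predictions[-1], len(predictions) - 1
```

Unlike the continuous threshold of SDN and ZTW, the patience value takes only a few discrete settings, which is the flexibility limitation discussed in Section 4.1.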

In our experiments, we measure how much computation we can save by re-using precomputed IC responses while keeping good performance, hence obeying the zero waste paradigm. To do this, we measure the inference cost as the average number of floating-point operations required to perform the forward pass for a single sample. For the evaluation in supervised learning, we use three datasets: CIFAR-10, CIFAR-100, and Tiny ImageNet, and four commonly used architectures as base networks: ResNet-56 [13], MobileNet [17], WideResNet [39], and VGG-16BN [35]. We check all combinations of methods, datasets, and architectures, and additionally evaluate a single architecture on the ImageNet dataset to show that the approach is scalable. We also examine how Zero Time Waste performs at reducing waste in a reinforcement learning setting on Atari 2600 environments. To the best of our knowledge, we are the first to apply early exit methods to reinforcement learning. Appendix A.1 contains the details of the network architectures, hyperparameters, and training process. We also provide the source code: https://github.com/gmum/Zero-Time-Waste.

[Table 1: four sub-tables, one per architecture (ResNet-56, MobileNet, WideResNet, VGG-16BN), each reporting SDN, PBEE, and ZTW accuracy on C10, C100, and T-IM at the 25%, 50%, 75%, 100%, and Max time budgets.]
Table 1: Results on four different architectures and three datasets: CIFAR-10, CIFAR-100, and Tiny ImageNet. Accuracy (in percentages) for time budgets of 25%, 50%, 75%, and 100% of the base network, and Max without any time limits. The first column shows the accuracy of the base network. The results represent a mean of three runs; standard deviations are provided in Appendix B. We bold results within two standard deviations of the best model.

4.1 Time Savings in Supervised Learning

We check what percentage of the computation of the base network can be saved by reusing the information from previous layers in a supervised learning setting. To do this, we evaluate how each method behaves at a particular fraction of the computational power (measured in floating-point operations) of the base network. We select the highest threshold $\tau$ such that the average inference time is smaller than the given fraction (e.g. 25%) of the original time, and then calculate the accuracy for that threshold. Table 1 contains a summary of this analysis, averaged over three seeds, with further details (plots for all thresholds, standard deviations) shown in Appendix B. There we also provide an extended description of the ImageNet experiment, the results of which are summarized in Table 2.
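The threshold-selection procedure above can be sketched as follows; this is an illustrative helper under our own naming, assuming per-IC confidences and cumulative FLOP counts have been precomputed on a held-out set:

```python
import numpy as np

def calibrate_threshold(max_confidences, flops_at_exit, budget_fraction):
    """max_confidences: (num_examples, num_ics) max class probability of
    each IC's prediction; flops_at_exit: cumulative FLOPs to reach each IC,
    with flops_at_exit[-1] the full network cost. Returns the highest tau
    (among observed confidences) whose average cost fits the budget."""
    full_cost = flops_at_exit[-1]
    candidates = sorted(set(np.round(max_confidences.ravel(), 3)), reverse=True)
    for tau in candidates:
        confident = max_confidences > tau
        exits = np.argmax(confident, axis=1)
        # examples where no IC is confident exit at the last IC
        exits[~confident.any(axis=1)] = len(flops_at_exit) - 1
        avg_cost = np.mean([flops_at_exit[e] for e in exits])
        if avg_cost <= budget_fraction * full_cost:
            return tau  # candidates are descending: first hit is highest
    return None
```

Because a higher threshold forces later exits (and thus a higher average cost), scanning candidate thresholds in descending order and stopping at the first one that fits the budget yields the highest admissible value.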

Looking at the results, we highlight the fact that methods which do not reuse information between ICs do not always achieve the goal of reducing waste. For example, SDN and PBEE cannot maintain the accuracy of the base network for MobileNet on Tiny ImageNet when using the same computational power, both scoring below the baseline. Adding ICs to the network and then discarding their predictions when they are not confident enough to return the final answer introduces computational overhead without any gains. By reusing the information from previous ICs, ZTW overcomes this issue and maintains the accuracy of the base network in all considered settings. In particular cases, such as ResNet-56 on Tiny ImageNet or MobileNet on CIFAR-100, Zero Time Waste even significantly outperforms the core network.

Algo 25% 50% 75% 100%
SDN 33.8 53.8 69.7 75.8
PBEE 28.3 28.3 62.9 73.3
ZTW 34.9 54.9 70.6 76.3
Table 2: ImageNet results (accuracy in percentage points) show that the zero-waste approach scales up to larger datasets.

A similar observation can be made for other inference time limits as well. ZTW consistently maintains higher accuracy using less computational resources than the other approaches, for all combinations of datasets and architectures. Although PBEE reuses information from previous layers to decide whether to stop computation, this is not sufficient to reduce the waste in the network. While PBEE outperforms SDN when given higher inference time limits, it often fails for smaller limits. We hypothesize that this is a result of PBEE's smaller flexibility with respect to $\tau$: while for SDN and ZTW the values of $\tau$ are continuous, for PBEE they represent a discrete number of ICs that must sequentially agree before returning an answer.

Given the performance of ZTW, the results show that paying attention to the minimization of computational waste leads to tangible, practical improvements in the inference time of the network. We therefore devote the next section to explaining where the empirical gains come from and how to measure information loss in these models.

4.2 Information Loss in Early Exit Models

Figure 2: Hindsight Improvability. For each IC (horizontal axis) we look at the examples it misclassified and check how many of them were classified correctly by any of the previous ICs. The lower the number, the better the IC is at reusing past information.

Since the ICs in a given model are heavily correlated, it is not immediately obvious why reusing past predictions should improve performance. Later ICs operate on high-level features, for which class separation is much easier than for early ICs, and hence achieve better accuracy. Thus, we ask the question: is there something that early ICs know that the later ICs do not?

For that purpose, we introduce a metric that evaluates how much a given IC could improve its performance by reusing information from all previous ICs. We measure it by checking how many examples incorrectly classified by $IC_m$ were classified correctly by any of the previous ICs. An IC which reuses past predictions perfectly would achieve a low score on this metric, since it would remember all the correct answers of the previous ICs. On the other hand, an IC in a model which trains each classifier independently would score higher, since it does not use past information at all. We call this metric Hindsight Improvability (HI), since it measures how many mistakes we could avoid if we used information from the past efficiently.

Let $\mathcal{C}_m$ denote the set of examples correctly classified by $IC_m$, with its complement $\overline{\mathcal{C}}_m$ being the set of examples classified incorrectly. To measure the Hindsight Improvability of $IC_m$ we calculate:

$$ \mathrm{HI}_m = \frac{\big|\, \overline{\mathcal{C}}_m \cap \bigcup_{j < m} \mathcal{C}_j \,\big|}{\big|\, \overline{\mathcal{C}}_m \,\big|}. $$
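Given the verbal definition above (the fraction of an IC's mistakes that some earlier IC got right), HI can be computed directly from per-IC correctness masks; a minimal numpy sketch under our own naming, not the authors' code:

```python
import numpy as np

def hindsight_improvability(correct):
    """correct: boolean array of shape (num_ics, num_examples), where
    correct[m, n] is True iff IC_m classified example n correctly.
    Returns HI_m for each IC: the fraction of IC_m's mistakes that were
    classified correctly by at least one earlier IC."""
    hi = []
    earlier_correct = np.zeros(correct.shape[1], dtype=bool)
    for m in range(correct.shape[0]):
        mistakes = ~correct[m]
        hi.append(float((mistakes & earlier_correct).sum()
                        / max(mistakes.sum(), 1)))
        earlier_correct |= correct[m]  # union over ICs 1..m
    return hi
```

By construction the first IC always scores zero, since it has no predecessors whose answers it could reuse.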

Figure 2 compares the values of HI for a method with independent ICs (SDN in this case) and for ZTW, which explicitly recycles computations. In the case of VGG-16 trained with independent ICs, a large fraction of the mistakes could be avoided if we properly used information from the past, which would translate into a considerable accuracy improvement. Similarly, for ResNet-56 trained on Tiny ImageNet, the errors could be cut by around 57%.

ZTW consistently outperforms the baseline, with the largest differences visible at the later ICs, which can in principle gain the most from reusing previous predictions. Thus, Zero Time Waste is able to efficiently recycle information from the past. At the same time, there is still room for significant improvement, which shows that future zero waste approaches could offer additional enhancements.

4.3 Time Savings in Reinforcement Learning

Although supervised learning is an important testbed for deep learning, it does not fully reflect the challenges encountered in the real world. In order to examine the impact of waste-minimization methods in a setting that reflects the sequential nature of interacting with the world, we evaluate ZTW in a reinforcement learning (RL) setting. In particular, we use environments from the commonly used suite of Atari 2600 games [26].

Figure 3: Inference time vs. average return of the ZTW policy in an RL setting on the Q*bert and Pong Atari 2600 environments. The plot was generated by using different values of the confidence threshold hyperparameter $\tau$. Since the RL environments are stochastic, we plot the return with a standard deviation calculated over 10 runs. ZTW saves a significant amount of computation while preserving the original performance, showing that waste can also be minimized in the reinforcement learning domain.

As in the supervised setting, we start with a pre-trained network, which in this case represents a policy trained with the Proximal Policy Optimization (PPO) algorithm [34]. We attach the ICs to the network and train them by distilling knowledge from the core network to the ICs. We use a behavioral cloning approach, where the states are sampled from the policy defined by the ICs and the labels are provided by the expert. Since actions in Atari 2600 are discrete, we can then use the same confidence threshold-based approach to early exit as in the case of classification. More details about the training process are provided in Appendix A.2.

In order to investigate the relationship between computation waste reduction and performance, we evaluate Zero Time Waste for different values of the confidence threshold $\tau$. By setting a very high value, we retrieve the performance of the original model (none of the ICs respond), and by slowly decreasing it we can reduce the computational cost (ICs begin to return answers earlier). In Figure 3 we sweep a range of values of $\tau$ to show how ZTW is able to control the acceleration-performance balance for Q*bert and Pong, two popular Atari 2600 environments. By setting lower thresholds for Q*bert we can save a substantial fraction of the computation without score degradation. Similarly, for Pong we can obtain a sizable reduction with minor impact on performance (note that the average human score is 9.3 points). This shows that even the small four-layer convolutional architecture commonly used for Atari [26] introduces a noticeable waste of computation, which can be mitigated within a zero-waste paradigm. We highlight this fact, as the field of reinforcement learning has largely focused on efficiency in terms of the number of samples and training time, while paying less attention to the issue of efficient inference.

4.4 Impact & Limitations

Our framework is a step towards environmentally aware computation, where information recycling within a model is carefully studied to avoid wasting resources. The focus on computational efficiency, however, introduces a natural trade-off between model accuracy and computational cost. Although in most cases we can adjust the method's hyperparameters to avoid a significant accuracy drop, some test samples remain surprisingly challenging for ZTW, which indicates the need for further investigation of the accuracy vs. computational cost trade-off offered by our method.

Figure 4 contains examples of images for which low-level features consistently point at a wrong class, while high-level features would allow us to deduce the correct class. Images of birds which contain sharp lines and grayscale silhouettes are interpreted as airplanes by early ICs, which operate on low-level features. If the confidence of these classifiers gets high enough, the answer might be returned before later classifiers can correct the decision. We highlight the problem of dealing with examples which seem easy but turn out to be difficult as an important future direction for conditional computation methods.

Figure 4: Examples of bird images which were incorrectly classified as airplanes by ZTW. The early ICs are misled by the low-level features (blue sky, sharp edges, grayscale silhouette) and return a prediction before the later ICs can detect the more subtle high-level features.

5 Conclusion

In this work, we show that discarding the predictions of previous ICs in early exit models leads to a waste of computational resources and a significant loss of information. This result is supported by the introduced Hindsight Improvability metric, as well as by empirical results on reducing computation in existing networks. The proposed Zero Time Waste method addresses these issues by incorporating the outputs of past heads through cascade connections and geometric ensembling. We show that ZTW outperforms other approaches on multiple standard datasets and architectures for supervised learning, as well as on the Atari 2600 reinforcement learning suite. At the same time, we postulate that focusing on reducing computational waste in a safe and stable way is an important direction for future research in deep learning.

References

  • [1] D. Ariely and M. I. Norton (2011) From thinking too little to thinking too much: a continuum of decision making. WIREs Cognitive Science 2 (1), pp. 39–46.
  • [2] E. Bengio, P. Bacon, J. Pineau, and D. Precup (2015) Conditional computation in neural networks for faster models. arXiv:1511.06297.
  • [3] Y. Bengio, N. Léonard, and A. Courville (2013) Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv:1308.3432.
  • [4] C. Bentéjac, A. Csörgő, and G. Martínez-Muñoz (2020) A comparative analysis of gradient boosting algorithms. Artificial Intelligence Review, pp. 1–31.
  • [5] K. Berestizshevsky and G. Even (2019) Dynamically sacrificing accuracy for reduced computation: cascaded inference based on softmax confidence. In Proceedings of the International Conference on Artificial Neural Networks, ICANN, pp. 306–320.
  • [6] A. Davis and I. Arel (2013) Low-rank approximations for conditional feedforward computation in deep neural networks. arXiv:1312.4461.
  • [7] T. G. Dietterich (2000) Ensemble methods in machine learning. In Proceedings of the International Workshop on Multiple Classifier Systems, pp. 15.
  • [8] G. Dulac-Arnold, D. J. Mankowitz, and T. Hester (2019) Challenges of real-world reinforcement learning. arXiv:1904.12901.
  • [9] M. Figurnov, M. D. Collins, Y. Zhu, L. Zhang, J. Huang, D. Vetrov, and R. Salakhutdinov (2017) Spatially adaptive computation time for residual networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, CVPR, pp. 1039–1048.
  • [10] S. Fort, H. Hu, and B. Lakshminarayanan (2019) Deep ensembles: a loss landscape perspective. arXiv:1912.02757.
  • [11] G. Gigerenzer and W. Gaissmaier (2011) Heuristic decision making. Annual Review of Psychology 62, pp. 451–482.
  • [12] S. Grigorescu, B. Trasnea, T. Cocias, and G. Macesanu (2020) A survey of deep learning techniques for autonomous driving. Journal of Field Robotics 37 (3), pp. 362–386.
  • [13] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, CVPR, pp. 770–778.
  • [14] Y. He, X. Zhang, and J. Sun (2017) Channel pruning for accelerating very deep neural networks. In Proceedings of the IEEE International Conference on Computer Vision, ICCV, pp. 1389–1397.
  • [15] T. Hester and P. Stone (2013) TEXPLORE: real-time sample-efficient reinforcement learning for robots. Machine Learning 90 (3), pp. 385–429.
  • [16] G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. In Proceedings of the NIPS Workshop on Deep Learning and Representation Learning.
  • [17] A. G. Howard, M. Zhu, B. Chen, D. Kalenichenko, W. Wang, T. Weyand, M. Andreetto, and H. Adam (2017) MobileNets: efficient convolutional neural networks for mobile vision applications. arXiv:1704.04861.
  • [18] S. Jung, S. Hwang, H. Shin, and D. H. Shim (2018) Perception, guidance, and navigation for indoor autonomous drone racing using deep learning. IEEE Robotics and Automation Letters 3 (3), pp. 2539–2544.
  • [19] D. Kahneman (2017) Thinking, fast and slow. Farrar, Straus and Giroux.
  • [20] Y. Kaya, S. Hong, and T. Dumitras (2019) Shallow-deep networks: understanding and mitigating network overthinking. In Proceedings of the International Conference on Machine Learning, ICML, pp. 3301–3310.
  • [21] A. Kouris, S. I. Venieris, M. Rizakis, and C. Bouganis (2019) Approximate LSTMs for time-constrained inference: enabling fast reaction in self-driving cars. arXiv:1905.00689.
  • [22] A. Krizhevsky, I. Sutskever, and G. E. Hinton (2017) ImageNet classification with deep convolutional neural networks. Communications of the ACM 60 (6), pp. 84–90.
  • [23] B. Lakshminarayanan, A. Pritzel, and C. Blundell (2017) Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, NIPS, pp. 6402–6413.
  • [24] J. Lee, S. Kim, S. Kim, W. Jo, and H. Yoo (2021) GST: group-sparse training for accelerating deep reinforcement learning. arXiv:2101.09650.
  • [25] D. Livne and K. Cohen (2020) PoPS: policy pruning and shrinking for deep reinforcement learning. arXiv:2001.05012.
  • [26] V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, et al. (2015) Human-level control through deep reinforcement learning. Nature 518 (7540), pp. 529–533.
  • [27] M. Phuong and C. H. Lampert (2019) Distillation-based training for multi-exit architectures. In Proceedings of the IEEE International Conference on Computer Vision, ICCV, pp. 1355–1364.
  • [28] A. Raffin, A. Hill, M. Ernestus, A. Gleave, A. Kanervisto, and N. Dormann (2019) Stable Baselines3. GitHub repository: https://github.com/DLR-RM/stable-baselines3.
  • [29] S. Scardapane, D. Comminiello, M. Scarpiniti, E. Baccarelli, and A. Uncini (2020) Differentiable branching in deep networks for fast inference. In Proceedings of the IEEE International Conference on Acoustics, Speech and Signal Processing, ICASSP, pp. 4167–4171.
  • [30] S. Scardapane, M. Scarpiniti, E. Baccarelli, and A. Uncini (2020) Why should we add early exits to neural networks?. arXiv:2004.12814.
  • [31] R. E. Schapire (1990) The strength of weak learnability. Machine Learning 5 (2), pp. 197–227.
  • [32] E. Schuitema, L. Buşoniu, R. Babuška, and P. Jonker (2010) Control delay in reinforcement learning for real-time dynamic systems: a memoryless approach. In 2010 IEEE/RSJ International Conference on Intelligent Robots and Systems, IROS, pp. 3226–3231.
  • [33] J. Schulman, P. Moritz, S. Levine, M. Jordan, and P. Abbeel (2015) High-dimensional continuous control using generalized advantage estimation. arXiv:1506.02438.
  • [34] J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov (2017) Proximal policy optimization algorithms. arXiv:1707.06347.
  • [35] K. Simonyan and A. Zisserman (2015) Very deep convolutional networks for large-scale image recognition. In Proceedings of the International Conference on Learning Representations, ICLR.
  • [36] S. Teerapittayanon, B. McDanel, and H. Kung (2016) BranchyNet: fast inference via early exiting from deep neural networks. In Proceedings of the International Conference on Pattern Recognition, ICPR, pp. 2464–2469.
  • [37] X. Wang, F. Yu, Z. Dou, T. Darrell, and J. E. Gonzalez (2018) SkipNet: learning dynamic routing in convolutional networks. In Proceedings of the European Conference on Computer Vision, ECCV, pp. 409–424.
  • [38] X. Wang, F. Yu, L. Dunlap, Y. Ma, R. Wang, A. Mirhoseini, T. Darrell, and J. E. Gonzalez (2020) Deep mixture of experts via shallow embedding. In Proceedings of the Conference on Uncertainty in Artificial Intelligence, UAI, pp. 552–562.
  • [39] S. Zagoruyko and N. Komodakis (2016) Wide residual networks. arXiv:1605.07146.
  • [40] H. Zhang, Z. He, and J. Li (2019) Accelerating the deep reinforcement learning with neural network compression. In Proceedings of the International Joint Conference on Neural Networks, IJCNN, pp. 1–8.
  • [41] W. Zhou, C. Xu, T. Ge, J. McAuley, K. Xu, and F. Wei (2020) BERT loses patience: fast and robust inference with early exit. arXiv:2006.04152.

Appendix A Training Details

All experiments were performed using a single Tesla V100 GPU.

a.1 Supervised Learning

We set up the core networks in our CIFAR-10, CIFAR-100, and Tiny ImageNet experiments following [20] for a fair comparison. We treat these trained networks as pre-trained models, i.e. we consider the "IC-only" setup, in which we do not change the base network.

For CIFAR-10 and CIFAR-100, we train the ICs for 50 epochs using the Adam optimizer, lowering the learning rate by a factor of 10 after 15 epochs. When training on Tiny ImageNet, the learning rate is additionally lowered again by the same factor after epoch 40. On ImageNet (on the pretrained ResNet-50 from the torchvision package), the ICs are trained with the initial learning rate reduced by a factor of 10 at epochs 20 and 30. To train the ensembling part of our method, we run SGD on the training dataset for 500 epochs. Since both the dataset and the model are very small at this stage, we use a high number of epochs to ensure convergence.
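The learning rate schedule described above (a fixed drop by a factor of 10 at given epochs) can be sketched as a small helper; the initial value of 1e-3 below is a placeholder, since the exact learning rates are not restated here:

```python
def multistep_lr(initial_lr, epoch, milestones, gamma=0.1):
    """Learning rate at a given epoch, dropped by `gamma` at each milestone."""
    lr = initial_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

# Drop by 10x after epoch 15, and (for Tiny ImageNet) again after epoch 40.
lr_start = multistep_lr(1e-3, 0, [15, 40])   # before the first milestone
lr_mid = multistep_lr(1e-3, 20, [15, 40])    # after the first drop
lr_late = multistep_lr(1e-3, 45, [15, 40])   # after both drops
```

In PyTorch the same schedule corresponds to `torch.optim.lr_scheduler.MultiStepLR` with `milestones=[15, 40]` and `gamma=0.1`.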

Architecture and Placement of ICs

Most common computer vision architectures, including the ones we use, are divided into blocks (e.g. residual blocks in ResNet). Because some blocks change the dimensionality of the features, we make the natural choice of attaching an IC after each block, which also considerably simplifies applying our method to future architectures. Note that the resulting uniform distribution of ICs along the base network is not necessarily optimal [30]. However, we focus on this setup for the sake of a fair comparison with SDN and PBEE, and consider the exploration of the best placement of ICs to be outside the scope of this work.

Each IC consists of a single convolutional layer, a pooling layer, and a fully-connected layer which outputs the class logits. The convolutional layer has a kernel size of 3, with the number of output filters equal to the number of input channels. When applying cascade connections in Zero Time Waste, we use the outputs of the previous IC as an additional input to the linear classification layer of the current IC, as shown earlier in Figure 1. Because Tiny ImageNet has a larger input image size than the CIFAR datasets, we use convolutions with a larger stride to reduce the number of operations in each IC.
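The structure of a single IC head with a cascade connection can be sketched in NumPy (an illustration only: mean pooling stands in for the learned pooling layer, and the shapes, names, and concatenation of the previous logits are our simplifying assumptions):

```python
import numpy as np

def ic_forward(features, prev_logits, w, b):
    """One internal classifier (IC) head with a cascade connection.

    features: (C, H, W) activations lent by the backbone network.
    prev_logits: output of the previous IC (zeros for the first IC).
    w, b: linear head over the concatenation [pooled features; prev_logits].
    """
    pooled = features.mean(axis=(1, 2))              # stand-in for the pooling layer
    head_in = np.concatenate([pooled, prev_logits])  # cascade connection input
    return w @ head_in + b

rng = np.random.default_rng(0)
channels, num_classes = 4, 3
feats = rng.normal(size=(channels, 8, 8))
prev = np.zeros(num_classes)                         # the first IC has no predecessor
w = rng.normal(size=(num_classes, channels + num_classes))
b = np.zeros(num_classes)
logits = ic_forward(feats, prev, w, b)
```

The `logits` of one IC then become the `prev_logits` of the next, which is how information propagates along the red cascade arrows of Figure 1.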

For the pooling layer we reuse the SDN pooling proposed by [20], a mixture of max- and average-pooling of the form

SdnPool(x) = σ(α) · MaxPool(x) + (1 − σ(α)) · AvgPool(x),

where α is a learnable scalar parameter. It reduces the spatial size of the convolutional maps to a small fixed resolution.
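Assuming the mixed max/average form of SDN pooling, a NumPy sketch could look as follows (the sigmoid gating of the learnable scalar and the 4×4 output size are assumptions based on [20], not a verbatim reimplementation):

```python
import numpy as np

def sdn_pool(x, alpha, out_size=4):
    """Mix of max- and average-pooling, gated by a learnable scalar alpha.

    x: (C, H, W) feature maps; H and W must be divisible by out_size.
    """
    c, h, w = x.shape
    kh, kw = h // out_size, w // out_size
    # Split each spatial axis into (out_size, kernel) so windows can be reduced.
    windows = x.reshape(c, out_size, kh, out_size, kw)
    maxp = windows.max(axis=(2, 4))
    avgp = windows.mean(axis=(2, 4))
    gate = 1.0 / (1.0 + np.exp(-alpha))  # sigmoid keeps the mix weight in (0, 1)
    return gate * maxp + (1.0 - gate) * avgp

x = np.arange(2 * 8 * 8, dtype=float).reshape(2, 8, 8)
out = sdn_pool(x, alpha=0.0)  # alpha = 0 gives an equal mix of max and average
```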

We keep the architecture and IC placement fixed between experiments, with small exceptions for Tiny ImageNet and ImageNet. For Tiny ImageNet, we use a larger stride in the IC convolutional layers whenever all spatial dimensions of the input are larger than 8. We do the same for ImageNet, but we additionally reduce the number of output channels of that convolution and place ICs only after every third ResNet block. Finally, we apply Layer Normalization to the output of the preceding IC before using it in the final linear layer.

a.2 Reinforcement Learning

We set up the Atari environments as follows. Every fourth frame is used (frame skipping) and max-pooled with the frame immediately before it. The resulting frame is then rescaled to a smaller square resolution and converted into grayscale. At every step the agent has a fixed probability of repeating its previous action irrespective of the policy probabilities (sticky actions). This introduces stochasticity into the environment, avoiding cases where the policy converges to a simple strategy that results in the same actions being taken in every run. Furthermore, the environment termination flag is set when a life is lost. Finally, the signum function of the reward is taken (reward clipping). The above setup is fairly common, and we base our code on the popular Stable Baselines repository [28].
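The frame max-pooling and grayscale conversion steps can be sketched as below (the luminance coefficients are the common ITU-R 601 values and may differ from the exact wrapper implementation):

```python
import numpy as np

def pool_frames(frame_a, frame_b):
    """Element-wise max of two consecutive raw frames; together with frame
    skipping this removes Atari sprite flickering."""
    return np.maximum(frame_a, frame_b)

def to_grayscale(rgb):
    """Luminance conversion of an (H, W, 3) image to (H, W)."""
    return rgb @ np.array([0.299, 0.587, 0.114])

frame_a = np.zeros((2, 2, 3))
frame_a[0, 0] = [255, 0, 0]  # red pixel visible only in the first frame
frame_b = np.zeros((2, 2, 3))
frame_b[1, 1] = [0, 255, 0]  # green pixel visible only in the second frame
pooled = pool_frames(frame_a, frame_b)  # keeps both sprites
gray = to_grayscale(pooled)
```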

Using that environment setup, we use the PPO algorithm to train the policy, and then extract the base network by discarding the value network. The PPO hyperparameters include the learning rate, the number of steps to run in each environment per update, the batch size, the number of epochs of surrogate loss optimization, the clip range ε, the entropy coefficient, the value function coefficient, the discount factor γ, the bias vs. variance trade-off factor λ of the Generalized Advantage Estimator [33], and the maximum value for gradient clipping. The policy is trained for a fixed total number of environment time steps.

We use the standard 'NatureCNN' [26] architecture with three convolutional layers and a single fully-connected layer. We attach two ICs, after the first and the second convolutional layer. As in the supervised setting, each IC has a single convolutional layer, an SDN pooling layer, and a fully-connected layer. The IC's convolutional layer preserves the number of channels.

To train the ICs, the early-exit policy interacts with the environment. In each step, an IC is chosen uniformly at random, and the action chosen by that IC is taken. However, the tuple (s, a) is saved to the replay buffer, with s and a being the observation and the action of the original policy, respectively. After a fixed number of concurrent steps on parallel environments, that buffer is used to train the ICs with behavioral cloning. That is, the Kullback–Leibler divergence between the PPO policy's actions and the IC's actions is used as the cost function. This is done for several epochs, with separate batch sizes for the cascading stage and the geometric ensembling stage. The entire process is repeated until a predefined total number of steps is taken.
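The behavioral-cloning objective, i.e. the Kullback–Leibler divergence between the PPO policy's action distribution and an IC's, can be sketched as follows (a minimal per-sample NumPy version; the actual training operates on batched tensors):

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) between two discrete action distributions."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.sum(p * (np.log(p) - np.log(q))))

teacher = np.array([0.7, 0.2, 0.1])    # PPO policy over actions (the target)
student = np.array([0.6, 0.25, 0.15])  # IC policy being trained
loss = kl_divergence(teacher, student)
perfect = kl_divergence(teacher, teacher)  # zero when the IC matches the policy
```

Minimizing this loss over the replay buffer pulls each IC's action distribution towards the original PPO policy.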

Appendix B Additional results

This section contains experimental results which were omitted in the main part of the paper due to page limitations.

b.1 Supervised Learning

For brevity, the main part of the paper only shows a table summarizing the acceleration results on multiple architectures and datasets. Here, we provide a fuller representation of these results. Figures 10, 11, and 12 (at the end of the Appendix) show the results of the tested methods on CIFAR-10, CIFAR-100, and Tiny ImageNet, respectively. Each figure contains plots for the four considered architectures: ResNet-56, MobileNet, WideResNet, and VGG16. The plots show that ZTW outperforms SDN and PBEE in almost all settings, which is consistent with the results summarized earlier. Additionally, in Table 3 we provide a summary of the results with standard deviations. Figures 13, 14, and 15 show the values of Hindsight Improvability for CIFAR-10, CIFAR-100, and Tiny ImageNet, respectively.

[Table 3: for each of the four architectures (ResNet-56, MobileNet, WideResNet, VGG) and each dataset (CIFAR-10, CIFAR-100, Tiny ImageNet), the table lists the accuracies of SDN, PBEE, and ZTW under the 25%, 50%, 75%, 100%, and Max time budgets.]
Table 3: Results on four different architectures and three datasets: CIFAR-10, CIFAR-100, and Tiny ImageNet. Accuracy (in percent) obtained under time budgets of 25%, 50%, 75%, and 100% of the base network, and Max without any limit. The first column shows the accuracy of the base network.

b.2 Results of ImageNet experiments

Figure 5: Inference time vs. accuracy for ResNet-50 trained on ImageNet, comparing the accuracy of the base network with those obtained by SDN, PBEE, and ZTW under the same inference time constraint.

In order to show that the proposed method scales up well to the ImageNet dataset, we apply our method to a pre-trained model provided by the torchvision package (https://pytorch.org/vision/stable/index.html). The obtained model allows for significant speed-ups on ImageNet while maintaining the same accuracy under the original inference time limit. The results presented in Figure 5 show that ZTW again outperforms the other methods, with SDN maintaining reasonable, although lower, performance and PBEE generally failing. We want to highlight that the IC architecture used here is very simple and nowhere near as thoroughly investigated as the architecture of ResNet or other common deep learning models. Adjusting the ICs for this problem could thus improve the results significantly, although we consider this outside the scope of this work.

b.3 Results of Reinforcement Learning experiments

In Figure 6 we show the results for all eight reinforcement learning environments on which we ran our experiments. The degree of time savings depends heavily on the environment. For some environments, such as AirRaid and Pong, the ICs obtain a return similar to that of the original policy. Because of that, the resulting plot is almost flat, allowing for significant inference time reduction without any performance drop. Other environments, such as Seaquest, Phoenix, and Riverraid, allow us to gradually trade off performance for inference time, just as in the supervised setting.

Figure 6: Mean and standard deviation of returns for multiple confidence thresholds on various Atari 2600 environments. Some environments allow significant computational savings with a negligible or no impact on performance.

Appendix C Ablation Studies

In this section, we present the results of experiments which motivate our design decisions. In particular, we focus on three questions: (1) what is the individual impact of cascade connections and geometric ensembling, (2) how do additive and geometric ensembles compare in our setting, and (3) how does stopping the gradient in cascade connections affect learning dynamics?

c.1 Impact of cascading and ensembling

Figure 7: Ablation studies showing the importance of both techniques proposed in the paper. Although both cascade connections and geometric ensembling help, the exact effect depends on the architecture and the chosen confidence threshold. For ResNet-56, cascade connections are much more helpful than ensembling, while for VGG16 the opposite is true. As such, both are required to consistently improve results.

An important question is whether we need both components of the proposed model (cascade connections and ensembling), and what role they play in its final performance. Figure 7 shows the results of independently applying cascade connections and geometric ensembling to ResNet-56 and VGG-16 trained on CIFAR-100. We observe that, depending on the threshold and the architecture, one of these techniques may be more important than the other. However, combining them consistently improves upon the performance each achieves independently. Thus we argue that both cascade connections and geometric ensembling are required in Zero Time Waste, and using only one of them leads to significant performance deterioration.

c.2 Geometric vs Additive Ensembles

In this work we proposed geometric ensembles for combining predictions from multiple ICs. Here, we show how this approach performs in comparison to an additive ensemble of the form:

$q^m_j = \frac{1}{Z}\Big(b_j + \sum_{i=1}^{m} w_i\, p^i_j\Big)$,   (4)

where $w_i$ and $b_j$, for $i \in \{1, \dots, m\}$, are trainable parameters, and $Z$ is a normalization value such that $\sum_j q^m_j = 1$. That is, we use the same approach as in geometric ensembles, but we substitute the product for a sum and change the weighting scheme.

The empirical comparison between the additive ensemble and the geometric ensemble on ResNet-56 is presented in Figure 8. The results show that the geometric ensemble consistently outperforms the additive one, although the magnitude of the improvement varies across datasets. While the difference on CIFAR-10 is negligible, it becomes evident on Tiny ImageNet, especially in the later layers. The results suggest that geometric ensembling is more helpful on more complex datasets with a higher number of classes.
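The two combination rules can be contrasted in a small sketch (the weighting scheme and prior terms of the actual ZTW ensembles are simplified to a plain weighted combination here):

```python
import numpy as np

def geometric_ensemble(probs, weights, eps=1e-12):
    """Weighted geometric combination of per-IC class probabilities."""
    logp = np.log(np.clip(probs, eps, 1.0))  # (num_ics, num_classes)
    combined = np.exp(weights @ logp)        # product of powers, done in log space
    return combined / combined.sum()

def additive_ensemble(probs, weights):
    """Weighted arithmetic combination of the same probabilities."""
    combined = weights @ probs
    return combined / combined.sum()

probs = np.array([[0.6, 0.4],   # predictions of the first IC
                  [0.9, 0.1]])  # predictions of the second IC
weights = np.array([0.5, 0.5])
geo = geometric_ensemble(probs, weights)
add = additive_ensemble(probs, weights)
```

Note how the geometric combination is more decisive than the additive one when the members agree, since low probabilities suppress a class multiplicatively rather than being averaged away.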

c.3 Stop gradients in cascade connections

Figure 8: Comparison of geometric and additive ensembling on ResNet-56 with cascade connections, conducted on CIFAR-10, CIFAR-100, and Tiny ImageNet.

As mentioned in Section 3 of the main paper, we decided to stop the gradient from flowing through the cascade connections. We motivate this decision by noticing that the gradients of later layers might destroy the predictive power of the earlier layers. In order to test this hypothesis empirically, we run our experiments on ResNet-56 with and without gradient stopping. As shown in Figure 9, the accuracy of the early ICs is lower without gradient stopping. The performance of the later ICs may vary, as omitting gradient stopping allows greater expressivity for the later ICs. Since the second component of our method, ensembling, is able to reuse information from the early ICs, we find it beneficial to use gradient stopping in the final model. This is especially evident on Tiny ImageNet, where the later cascade-connected ICs perform better without gradient stopping, but ZTW is able to reuse ICs trained with gradient stopping more effectively.

We provide a more in-depth explanation of why the gradients of later ICs might have a detrimental effect on the performance of early ICs. Observe that in the setting without the detach, the parameters of the first IC are updated using $g = \sum_k g_k$, where $g_k$ is the gradient of the loss of the $k$-th IC with respect to the parameters of the first IC. Experimental investigation showed that the cosine similarity of $g_1$ and $\sum_{k>1} g_k$ is approximately zero at the beginning of training, which means that these gradients point in different directions. Since the gradient $g_1$ represents the best direction for improving the first IC, using $g$ leads to a non-optimal update of its weights, thus reducing its predictive performance. With the detach, $g_k = 0$ for $k > 1$, and as such the cosine similarity between $g_1$ and the applied update is always one. This reasoning extends to the rest of the ICs.
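The role of the detach can be illustrated directly on toy gradient vectors: with gradient stopping, the update applied to the first IC is exactly its own gradient, so their cosine similarity is one (the vectors below are illustrative, not measured values):

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two gradient vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

g1 = np.array([1.0, 0.0])        # gradient of the first IC's own loss
g_later = np.array([0.0, 1.0])   # summed gradients flowing from later ICs
no_detach_update = g1 + g_later  # without gradient stopping
detach_update = g1               # with gradient stopping, later grads are zero

sim_no_detach = cosine(g1, no_detach_update)  # < 1: update deviates from g1
sim_detach = cosine(g1, detach_update)        # exactly 1
```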

Figure 9: Effects of stopping gradient in ResNet-56 trained on CIFAR-10, CIFAR-100, and Tiny ImageNet.
Figure 10: Inference time vs. accuracy obtained on various architectures trained on CIFAR-10.
Figure 11: Inference time vs. accuracy obtained on various architectures trained on CIFAR-100.
Figure 12: Inference time vs. accuracy obtained on various architectures trained on Tiny ImageNet.
Figure 13: Hindsight Improvability of various architectures trained on CIFAR-10.
Figure 14: Hindsight Improvability of various architectures trained on CIFAR-100.
Figure 15: Hindsight Improvability of various architectures trained on Tiny ImageNet.