1 Introduction
The time required to solve a problem is a function of more than just the size of the inputs. Problems commonly also have an inherent complexity that is independent of the input size: it is faster to add two numbers than to divide them. Most machine learning algorithms do not adjust their computational budget based on the complexity of the task they are learning to solve, or arguably, such adaptation is done manually by the machine learning practitioner. This adaptation is known as pondering. In prior work, Adaptive Computation Time (ACT; Graves, 2016) automatically learns to scale the required computation time via a scalar halting probability. This halting probability modulates the number of computational steps, called the “ponder time”, needed for each input. Unfortunately, ACT is notably unstable and sensitive to the choice of a hyperparameter that trades off accuracy and computation cost. Additionally, the gradient for the cost of computation can only backpropagate through the last computational step, leading to a biased estimation of the gradient. Another approach is represented by Adaptive Early Exit Networks (Bolukbasi et al., 2017), where the forward pass of an existing network is terminated at evaluation time if it is likely that the part of the network used so far already predicts the correct answer. More recently, work has investigated the use of REINFORCE (Williams, 1992) to perform conditional computation. A discrete latent variable is used to dynamically adjust the number of computation steps. This approach has been applied to recurrent neural networks (Chung et al., 2016; Banino et al., 2020), but has the downside that the estimated gradients have high variance, requiring large batch sizes to train them. A parallel line of research has explored using similar techniques to reduce the computation by skipping elements from a sequence of processed inputs (Yu et al., 2017; Campos Camunez et al., 2018).
In this paper we present PonderNet, which builds on these previous ideas. PonderNet is fully differentiable, which allows for low-variance gradient estimates (unlike REINFORCE). It has unbiased gradient estimates (unlike ACT). We achieve this by reformulating the halting policy as a probabilistic model. This has consequences in all aspects of the model:

Architecture: in PonderNet, the halting node predicts the probability of halting conditional on not having halted before. We exactly compute the overall probability of halting at each step as a geometric distribution.

Loss: we don’t regularize PonderNet to explicitly minimize the number of computing steps, but incentivize exploration instead. The pressure of using computation efficiently happens naturally as a form of Occam’s razor.

Inference: PonderNet is probabilistic both in terms of number of computational steps and the prediction produced by the network.
2 Methods
2.1 Problem setting
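We consider a supervised setting, where we want to learn a function $f: x \rightarrow y$ from data $(\mathbf{x}, \mathbf{y})$, with $\mathbf{x} = \{x^{(1)}, \ldots, x^{(k)}\}$ and $\mathbf{y} = \{y^{(1)}, \ldots, y^{(k)}\}$. We propose a new general architecture for neural networks that modifies the forward pass, as well as a novel loss function to train it.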
2.2 Step recurrence and halting process
The PonderNet architecture requires a step function $s$ of the form $\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$, as well as an initial state $h_0$. (Footnote 1: alternatively, one can consider a step function of the form $\hat{y}_n, h_{n+1}, \lambda_n = s(h_n)$ together with an encoder $e$ of the form $h_0 = e(x)$.) The outputs $\hat{y}_n$ and $\lambda_n$ are respectively the network’s prediction and scalar probability of halting at step $n$. The step function can be any neural network, such as MLPs, LSTMs, or encoder-decoder architectures such as transformers. We apply the step function recurrently up to $N$ times.
The output $\hat{y}_n$ is a learned prediction conditioned on the dynamic number of steps $n \in \{1, \ldots, N\}$. We rely on the value of $\lambda_n$ to learn the optimal value of $n$. We define a Bernoulli random variable $\Lambda_n$ in order to represent a Markov process for the halting with two states, “continue” ($\Lambda_n = 0$) and “halt” ($\Lambda_n = 1$). The decision process starts from state “continue” ($\Lambda_0 = 0$). We set the transition probability:
$$P(\Lambda_n = 1 \mid \Lambda_{n-1} = 0) = \lambda_n \quad \forall\, 1 \le n \le N \tag{1}$$
that is, the conditional probability of entering state “halt” at step $n$ given that there has been no previous halting. Note that “halt” is a terminal state. We can then compute the unconditioned probability $p_n$ that the halting happened at each step $n = 1, \ldots, N$, where $N$ is the maximum number of steps allowed before halting. We derive this probability distribution $p_n$ as a generalization of the geometric distribution:
$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j) \tag{2}$$
which is a valid probability distribution if we integrate over an infinite number of possible computation steps ($N \rightarrow \infty$).
The prediction $\hat{y} \sim \hat{Y}$ made by PonderNet is sampled from a random variable $\hat{Y}$ with probability distribution $P(\hat{Y} = \hat{y}_n) = p_n$. In other words, the prediction of PonderNet is the prediction made at the step at which it halts. This is in contrast with ACT, where model predictions are always weighted averages across steps. Additionally, PonderNet is more generic in this regard: if one wishes to do so, it is straightforward to calculate the expected prediction across steps, similar to how it is done in ACT.
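The halting distribution described above can be sketched in a few lines of Python. This is only an illustrative sketch: the function name `halting_distribution` and the list `lambdas`, which holds the conditional halting probability produced at each step, are hypothetical names and not part of any reference implementation.

```python
def halting_distribution(lambdas):
    """Unconditioned halting probabilities from conditional ones.

    For each step n (counting from 1) this returns
    lambdas[n-1] * prod_{j<n} (1 - lambdas[j-1]),
    i.e. "halt now" times "did not halt at any earlier step".
    """
    probs = []
    not_halted = 1.0  # probability that no halt happened before the current step
    for lam in lambdas:
        probs.append(not_halted * lam)
        not_halted *= 1.0 - lam
    return probs


# With a constant conditional probability the result is an ordinary
# geometric distribution (here cut off after three steps).
print(halting_distribution([0.5, 0.5, 0.5]))  # [0.5, 0.25, 0.125]
```

The cut-off list sums to less than 1 because part of the probability mass lies beyond the last unrolled step.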
2.3 Maximum number of pondering steps
Since in practice we can only unroll the step function for a limited number of iterations, we must correct for this so that the probabilities sum to 1. We can do this in two ways. One option is to normalize the probabilities $p_n$ so that they sum to 1 (this is equivalent to conditioning the probability of halting under the knowledge that $n \le N$). Alternatively, we could assign any remaining halting probability to the last step, so that $p_N = 1 - \sum_{j=1}^{N-1} p_j$ instead of $p_N$ as previously defined.
In our experiments, we specify the maximum number of steps using two different criteria. In evaluation, and under known temporal or computational limitations, $N$ can be set naively as a constant (or left without any limit, i.e. $N \rightarrow \infty$). For training, we found that a more effective (and interpretable) way of parameterizing $N$ is by defining a minimum cumulative probability of halting. $N$ is then the smallest value of $n$ such that $\sum_{j=1}^{n} p_j > 1 - \varepsilon$, with the hyperparameter $\varepsilon$ positive near 0 (in our experiments ).
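Both corrections, as well as the cumulative-probability criterion for choosing the unroll length, can be sketched as follows. All function names are hypothetical, `epsilon` is an assumed name for the small positive hyperparameter, and the rule "stop once the cumulative halting probability exceeds 1 - epsilon" is our reading of the criterion rather than a confirmed implementation detail.

```python
def normalize(probs):
    """Option 1: rescale the truncated probabilities so that they sum to 1."""
    total = sum(probs)
    return [p / total for p in probs]


def assign_remainder_to_last(probs):
    """Option 2: give all remaining probability mass to the last step."""
    remainder = 1.0 - sum(probs[:-1])
    return probs[:-1] + [remainder]


def steps_for_cumulative_mass(lambdas, epsilon):
    """Smallest number of steps whose cumulative halting probability
    exceeds 1 - epsilon (assumed criterion)."""
    cumulative = 0.0
    not_halted = 1.0
    for n, lam in enumerate(lambdas, start=1):
        cumulative += not_halted * lam
        not_halted *= 1.0 - lam
        if cumulative > 1.0 - epsilon:
            return n
    return len(lambdas)  # threshold never reached: use the full unroll


truncated = [0.5, 0.25, 0.125]
print(assign_remainder_to_last(truncated))          # [0.5, 0.25, 0.25]
print(steps_for_cumulative_mass([0.5] * 10, 0.05))  # 5
```

The value 0.05 passed as `epsilon` is an arbitrary illustrative choice.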
2.4 Training loss
The total loss is composed of reconstruction $L_{Rec}$ and regularization $L_{Reg}$ terms:
$$L = \underbrace{\sum_{n=1}^{N} p_n \, \mathcal{L}(y, \hat{y}_n)}_{L_{Rec}} + \beta \, \underbrace{KL\left(p_n \,\|\, p_G(\lambda_p)\right)}_{L_{Reg}} \tag{3}$$
where $\mathcal{L}$ is a predefined loss for the prediction (usually mean squared error or cross-entropy), $\beta$ weights the regularization term, and $\lambda_p$ is a hyperparameter that defines a geometric prior distribution $p_G(\lambda_p)$ on the halting policy (truncated at $N$). $L_{Rec}$ is the expectation of the predefined reconstruction loss $\mathcal{L}$ across halting steps. $L_{Reg}$ is the KL divergence between the distribution of halting probabilities $p_n$ and the prior (a geometric distribution truncated at $N$, parameterized by $\lambda_p$). This hyperparameter defines a prior on how likely it is that the network will halt at each step. This regularisation serves two purposes. First, it biases the network towards the expected prior number of steps $1/\lambda_p$. Second, it provides an incentive to give a non-zero probability to all possible numbers of steps, thus promoting exploration.
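A minimal sketch of this loss in plain Python is given below. The names `ponder_loss`, `beta` (weight of the regularization term) and `lambda_p` (parameter of the geometric prior) are assumptions introduced for illustration, and renormalizing the prior after truncation is one possible way of handling the cut-off, not necessarily the one used in the experiments.

```python
import math


def truncated_geometric(lambda_p, n_steps):
    """Geometric prior truncated at n_steps and renormalized (assumed handling).
    Requires 0 < lambda_p < 1 so that every step keeps non-zero mass."""
    raw = [lambda_p * (1.0 - lambda_p) ** n for n in range(n_steps)]
    total = sum(raw)
    return [r / total for r in raw]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same steps."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def ponder_loss(step_losses, halting_probs, lambda_p, beta):
    """Expected per-step loss plus beta times the KL to the prior."""
    reconstruction = sum(p * l for p, l in zip(halting_probs, step_losses))
    prior = truncated_geometric(lambda_p, len(halting_probs))
    regularization = kl_divergence(halting_probs, prior)
    return reconstruction + beta * regularization


# step_losses[n] is the prediction loss obtained if the network halts at step n + 1.
loss = ponder_loss(step_losses=[0.9, 0.4, 0.1],
                   halting_probs=[0.2, 0.3, 0.5],
                   lambda_p=0.5, beta=0.01)
```

When the halting distribution equals the prior, the KL term vanishes and only the expected prediction loss remains.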
2.5 Evaluation sampling
At evaluation, the network samples on a step basis from the halting Bernoulli random variable $\Lambda_n \sim \mathrm{Bernoulli}(\lambda_n)$, where $\lambda_n$ is the halting probability predicted at step $n$, to decide whether to continue or to halt. This process is repeated on every step until a “halt” outcome is sampled, at which point the output $\hat{y} = \hat{y}_n$ becomes the final prediction of the network. If a maximum number of steps $N$ is reached, the network is automatically halted and produces a prediction $\hat{y} = \hat{y}_N$.
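The evaluation procedure can be sketched as a simple loop. The step-function signature `(x, state) -> (prediction, new_state, halting_probability)`, the function name `ponder_at_evaluation` and the toy step function are all assumptions made for this example.

```python
import random


def ponder_at_evaluation(step_fn, x, state, max_steps, rng=random):
    """Apply step_fn until a "halt" outcome is sampled or max_steps is reached.

    Returns the prediction at the halting step and the number of steps used.
    """
    prediction = None
    for n in range(1, max_steps + 1):
        prediction, state, lam = step_fn(x, state)
        if rng.random() < lam:  # Bernoulli draw: halt with probability lam
            return prediction, n
    return prediction, max_steps  # forced halt at the step limit


def toy_step(x, state):
    """Hypothetical step function: accumulates x and halts with probability 0.5."""
    new_state = state + x
    return new_state, new_state, 0.5


prediction, steps_used = ponder_at_evaluation(
    toy_step, x=1.0, state=0.0, max_steps=20, rng=random.Random(0))
```

Passing a seeded `random.Random` instance makes the sampled halting step reproducible, which is convenient when comparing runs.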
3 Results
3.1 Parity
In this section we report results on the parity task as introduced in the original ACT paper (Graves, 2016). Out of the four tasks presented in that paper, we decided to focus on parity as it was the one showing the greatest benefit from adaptive compute. In our instantiation of the parity problem the input vectors had 64 elements, of which a random number from 1 to 64 were randomly set to 1 or -1 and the rest were set to 0. The corresponding target was 1 if there was an odd number of ones and 0 if there was an even number of ones. We refer the reader to the original ACT paper for specific details on the tasks (Graves, 2016). Also, please refer to Appendix B for further training and evaluation details.
In figure 1a we can see that PonderNet achieved better accuracy than ACT on the parity task and it did so with a more efficient use of thinking time (1a at the bottom). Moreover, if we consider the total computation time during training (figure 1c) we can see that, in comparison to ACT, PonderNet employed less computation and achieved a higher score.
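For concreteness, one parity example could be generated as in the sketch below. It follows the task description in Graves (2016), where non-zero entries take the values 1 or -1; the function name and the choice to place the non-zero entries at uniformly random positions are assumptions of this sketch.

```python
import random


def make_parity_example(size=64, rng=random):
    """Build one parity input vector and its binary target."""
    n_nonzero = rng.randint(1, size)  # number of non-zero entries, 1..size
    vector = [0] * size
    for index in rng.sample(range(size), n_nonzero):
        vector[index] = rng.choice((1, -1))
    # Target is 1 for an odd number of ones and 0 for an even number.
    target = sum(1 for v in vector if v == 1) % 2
    return vector, target


vector, target = make_parity_example(rng=random.Random(1))
```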
Another analysis we performed on this version of the parity task was to look at the effect of the prior probability $\lambda_p$ on performance. In figure 2b we show that the only case where PonderNet could not solve the task is when the prior ($\lambda_p$) was set to 0.9, that is, when the average number of thinking steps given as prior was roughly 1 ($1/0.9 \approx 1.1$). Interestingly, when the prior ($\lambda_p$) was set to 0.1, hence starting with a prior average thinking time of 10 steps ($1/0.1$), the network managed to overcome this and settled to a more efficient average thinking time of roughly 3 steps (figure 2c). These results are important as they show that our method is particularly robust with respect to the prior, and a clear advancement in comparison to ACT, where the $\tau$ parameter is difficult to set and is a source of training instability, as explained in the original paper and confirmed by our results. Indeed, Fig. 2a shows that ACT is able to solve the task for only a few configurations of $\tau$, and even when it does so there is great variance across seeds. Finally, one advantage of setting a prior probability is that this parameter is easy to interpret as the inverse of the “number of ponder steps”, whereas the $\tau$ parameter does not have any straightforward interpretation, which makes it harder to define a priori.
Next we moved to test the ability of PonderNet to allow extrapolation. To do this we consider input vectors of elements instead. We train the network on input vectors up from integers ranging from to elements and we then evaluate on integers between and . Figure 1b shows that PonderNet was able to achieve almost perfect accuracy on this hard extrapolation task, whereas ACT remained at chance level. It is interesting to see how PonderNet increased its thinking time to 5 steps, which is almost twice as many as in the interpolation set (see Fig. 1a), showing the capability of our method to adapt its computation to the complexity of the task.
3.2 bAbI
We then turn our attention to the bAbI question answering dataset (Weston et al., 2015), which consists of 20 different tasks. This task was chosen as it proved to be difficult for standard neural network architectures that do not employ adaptive computation (Dehghani et al., 2018). In particular we trained our model on the joint 10k training set. Also, please see Appendix C for further training and evaluation details.
Table 1 reports the averaged accuracy of our model and the other baselines on bAbI. Our model is able to match state-of-the-art results, but it achieves them faster and with a lower average error. The comparison with Universal Transformer (UT; Dehghani et al., 2018) is interesting as it uses the same transformer architecture as PonderNet, but the compute time is optimised with ACT. Interestingly, to solve the 20 tasks, Universal Transformer takes 10161 steps, whereas our method takes 1658, hence confirming that our approach uses less compute than ACT.
Table 1: Average error and number of tasks solved on bAbI.
Architecture  Average Error  Tasks Solved
Memory Networks (Sukhbaatar et al., 2015)  4.2 ± 0.2  17
DNC (Graves, 2016)  3.8 ± 0.6  18
Universal Transformer (Dehghani et al., 2018)  0.29 ± 1.4  20
Transformer+PonderNet  0.15 ± 0.9  20
3.3 Paired associative inference
Finally, we tested PonderNet on the Paired associative inference task (PAI) (Banino et al., 2020). This task is thought to capture the essence of reasoning, namely the appreciation of distant relationships among elements distributed across multiple facts or memories, and it has been shown to benefit from the addition of adaptive computation. Please refer to Appendix D for further details on the task and the training regime.
Table 2: Accuracy on the PAI task.
Length  UT  MEMO  PonderNet
3 items (trained on: A-B-C, accuracy on: A-C)  85.60  98.26 (0.67)  97.86 (3.78)
Table 2 reports the averaged accuracy of our model and the other baselines on PAI. Our model is able to match the results of MEMO, which was specifically designed with this task in mind. More interestingly, although our model uses the same architecture as UT (Dehghani et al., 2018), it is able to achieve higher accuracy. For the complete set of results please see Table 7 in Appendix D.
4 Discussion
We introduced PonderNet, a new algorithm for learning to adapt the computational complexity of neural networks. It optimizes a novel objective function that combines prediction accuracy with a regularization term that incentivizes exploration over the pondering time. We demonstrated on the parity task that a neural network equipped with PonderNet can increase its computation to extrapolate beyond the data seen during training. We also showed that our method matched or exceeded the accuracy of the baselines in complex domains such as question answering and multi-step reasoning. Finally, adapting existing recurrent architectures to work with PonderNet is very easy: it simply requires augmenting the step function with an additional halting unit and adding an extra term to the loss. Critically, we showed that this extra loss term is robust to the choice of $\lambda_p$, the hyperparameter that defines a prior on how likely it is that the network will halt, which is an important advancement over ACT.
References
Banino et al. (2020). MEMO: a deep network for flexible combination of episodic memories. In International Conference on Learning Representations.
Bolukbasi et al. (2017). Adaptive neural networks for efficient inference. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 527–536.
Campos Camunez et al. (2018). Skip RNN: learning to skip state updates in recurrent neural networks. In Sixth International Conference on Learning Representations, Monday April 30 - Thursday May 03, 2018, Vancouver Convention Center, Vancouver, pp. 1–17.
Chung et al. (2016). Hierarchical multiscale recurrent neural networks. arXiv preprint arXiv:1609.01704.
Dai et al. (2019). Transformer-XL: attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860.
Dehghani et al. (2018). Universal transformers. arXiv preprint arXiv:1807.03819.
Deng et al. (2009). ImageNet: A Large-Scale Hierarchical Image Database. In CVPR09.
Graves (2016). Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983.
He et al. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778.
Kingma and Ba (2014). Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
Sukhbaatar et al. (2015). End-to-end memory networks. In Advances in Neural Information Processing Systems, pp. 2440–2448.
Vaswani et al. (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
Veličković et al. (2019). Neural execution of graph algorithms. arXiv preprint arXiv:1910.10593.
Weston et al. (2015). Towards AI-complete question answering: a set of prerequisite toy tasks. arXiv preprint arXiv:1502.05698.
Williams (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8 (3-4), pp. 229–256.
Yu et al. (2017). Learning to skim text. arXiv preprint arXiv:1704.06877.
Appendix A Comparison to ACT
PonderNet builds on the ideas introduced in Adaptive Computation Time (ACT; Graves, 2016). The main contribution of this paper is to reformulate how the network learns to halt in a probabilistic way. This has consequences in all aspects of the model, including: the architecture and forward computation; the loss used to train the network; the deployment of the model; and the limitation of how multiple pondering modules can be combined. We explain in more detail all these differences below.
A.1 Forward computation
PonderNet’s step function (that is computed on every step) is identical to the one proposed in ACT. They both assume a mapping $\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$ from the input $x$ and the current state $h_n$ to a prediction, a next state and a halting value. The main difference between ACT and PonderNet’s forward computation is how the halting node $\lambda_n$ is used.
In ACT, the network is unrolled for a number of steps $N$. ACT’s halting nodes learn to predict the overall probability that the network halted at step $n$, so that $p_n = \lambda_n$. The value of the halting node in the last step is replaced with a remainder quantity $R = 1 - \sum_{n=1}^{N-1} \lambda_n$. In ACT it would not make sense to unroll the network for a larger number of steps than $N$ because the sum of probabilities of halting would be greater than 1. When training ACT, higher values of $N$ are not necessarily better, and $N$ is being determined (and learnt) via the halting node $\lambda_n$. In PonderNet, any sufficiently high value of $N$ can be used, and the unroll length of the network at training is distinguished from the learning of the halting policy (which is most critical for saving computation when deployed at evaluation).
The output of ACT is not treated probabilistically but as a weighted average over the outputs at each step. The halting, as well as the output, are computed identically for training and evaluation. In PonderNet, the output is probabilistic. In training, we compute the output and halting probabilities across many steps so that we can compute a weighted average of the loss. In evaluation, the network returns its prediction as soon as a halt state is sampled.
Finally, ACT considers the case of sequential data, where the step function can ponder dynamically for each new item in the input sequence. Given the introduction in recent years of attention mechanisms (e.g. Transformers; Vaswani et al., 2017) that can process arrays with dynamic shapes, we suggest that pondering should be done holistically instead of independently for each item in the sequence. This can be useful in learning e.g. how many message-passing steps to do in a graph network (Veličković et al., 2019).
A.2 Training loss
ACT proposes a heuristic training loss that combines two intuitive costs: the accuracy of the model, and the cost of computation. These two costs are in different units, and not easily comparable. Since the number of steps $N$ is not differentiable with respect to the halting values $\lambda_n$, ACT utilizes the remainder $R$ as a proxy for minimizing the total number of computational steps. This is unlike in PonderNet, where the expected number of steps can be computed (and differentiated) exactly as $\sum_{n=1}^{N} n \, p_n$.
In PonderNet, however, we propose that naively minimizing the number of steps (subject to good performance) is not necessarily a good objective. Instead, we propose that matching a prior halting distribution has multiple benefits: a) it provides an incentive for exploring alternative halting strategies; b) it provides robustness of the learnt step function, which may improve generalization; c) the KL is in the same units as information-theoretic losses such as cross-entropy; and d) it provides an incentive to not ponder for longer than the prior.
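The expected number of steps mentioned above is a probability-weighted sum over step indices. The sketch below uses hypothetical names (`expected_steps`, `probs`) and checks numerically that a geometric halting distribution with parameter 0.1, unrolled far enough, has an expected ponder time close to 10, the inverse of the prior.

```python
def expected_steps(halting_probs):
    """Expected number of steps: sum over n of n * p_n, counting steps from 1."""
    return sum(n * p for n, p in enumerate(halting_probs, start=1))


# Geometric halting distribution with conditional halting probability 0.1,
# unrolled for 1000 steps so that the truncated tail is negligible.
lam = 0.1
probs = [lam * (1 - lam) ** (n - 1) for n in range(1, 1001)]
print(round(expected_steps(probs), 6))  # 10.0
```

In an automatic differentiation framework the same sum would be differentiable with respect to the halting probabilities, since it only involves products and additions.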
Note that in PonderNet, we compute the loss for every possible number of computational steps, and then minimize the expectation (weighted average) over those. This is unlike in ACT where the expectation is taken over the predictions, and a loss is computed by comparing the average prediction with the target. This has the consequence that combining multiple networks is easier in ACT than in PonderNet. One could easily chain multiple ACT modules next to each other, and the size of the network during training would grow linearly with the number of modules. However, the network size when chaining PonderNet modules grows exponentially because the loss would need to be estimated conditioned on each PonderNet module halting at each step.
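As a rough illustration of this scaling argument, assume $M$ chained modules, each unrolled for at most $N$ steps. The symbol $M$ and the counts below are simplifying assumptions introduced for this example, not measurements.

```latex
\begin{align}
\text{ACT (chained):} \quad & \text{unrolled steps} \approx M \cdot N \\
\text{PonderNet (chained):} \quad & \text{joint halting configurations}
  = \underbrace{N \times \cdots \times N}_{M \text{ modules}} = N^{M}
\end{align}
```

For instance, with $N = 10$ and $M = 3$ this gives 30 unrolled steps for the chained ACT modules against 1000 joint halting configurations over which the PonderNet loss would have to be evaluated.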
In PonderNet we have introduced two loss hyperparameters, $\lambda_p$ and $\beta$, in comparison to a single hyperparameter $\tau$ in ACT that trades off accuracy with computational complexity. We note that, while $\beta$ and $\tau$ are superficially similar (they both apply a weight to the regularization term), their effect is not equivalent because the regularization terms of ACT and PonderNet have different interpretations.
A.3 Evaluation
ACT’s predictions are computed identically during training and evaluation. In both contexts, the maximum number of steps is determined based on the inputs, and the prediction is computed as a weighted average over the predictions in all steps. In PonderNet, training and evaluation are performed differently. During evaluation, the network halts probabilistically by sampling $\Lambda_n$, and either outputs the current prediction or performs an additional computational step. During training, we are not interested in the predictions per se but in the expected loss over steps, and so estimate this up to a maximum number of steps $N$ (the higher the better). This estimate will improve with higher probability that the network has halted at some point during the first $N$ steps (i.e. the cumulative probability of halting).
Appendix B Parity
B.1 Training and evaluation details
For this experiment we used the Parity task as explained by Graves (2016).
All the models used the same architecture, a simple RNN with a single hidden layer containing tanh units and a single logistic sigmoid output unit. All models were optimized using Adam (Kingma and Ba, 2014), with learning rate fixed to . The networks were trained with binary cross-entropy loss to predict the corresponding target, 1 if there was an odd number of ones and 0 if there was an even number of ones. We used minibatches of size 128. For PonderNet we sampled uniformly 10 values of $\lambda_p$ in the range (0, 1]. For ACT we sampled uniformly 19 values of $\tau$ in the range [2e-4, 2e-2] and we also added 0, which corresponds to not penalising the halting unit at all. For both ACT and PonderNet, $N$ was set to 20. For PonderNet, $\beta$ was fixed to
Appendix C bAbI
C.1 Training and evaluation details
For this experiment we used the English Question Answer dataset (Weston et al., 2015). We use the training and test datasets that they provide with the following preprocessing:

All text is converted to lowercase.

Periods and question marks are ignored.

Blank spaces are taken as word separation tokens.

Commas only appear in answers, and they are not ignored. This means that, e.g. for the path finding task, the answer ’n,s’ has its own independent label from the answer ’n,w’. This also implies that every input (consisting of ’query’ and ’stories’) corresponds to a single answer throughout the whole dataset.

All the questions are stripped out from the text and put separately (given as “queries” to our system).
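The text-level preprocessing rules listed above can be sketched as follows. The function name `preprocess` is hypothetical, and the sketch covers only lowercasing, punctuation removal and tokenization, not the separation of questions from stories.

```python
def preprocess(text):
    """Lowercase, drop periods and question marks, and split on blank spaces.

    Commas are deliberately kept, so an answer such as 'n,s' stays one token.
    """
    cleaned = text.lower().replace(".", "").replace("?", "")
    return cleaned.split()


print(preprocess("Where is Mary?"))  # ['where', 'is', 'mary']
print(preprocess("n,s"))             # ['n,s']
```

Keeping commas means that every distinct comma-separated answer receives its own label, as described for the path finding task.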
At training time, we sample a minibatch of queries from the training dataset, as well as its corresponding stories (which consist of the text prior to the question). As a result, the queries are a matrix of tokens, and sentences are of size , where is the batch size, is the max number of stories, and is the max sentence size. We pad with zeros every query and group of stories that do not reach the max sentence and stories size. For PonderNet, stories and query are used as their naturally corresponding inputs in their architecture. The details of the network architecture are described in Section C.2.
After that minibatch is sampled, we perform one optimization step using Adam (Kingma and Ba, 2014). We also performed a search on hyperparameters to train on bAbI, with ranges reported in Table 4. The network was trained for epochs, each one formed by batch updates.
For evaluation, we sample a batch of elements from the dataset and compute the forward pass in the same fashion as done in training. With that, we compute the mean accuracy over those examples, as well as the accuracy per task for each of the 20 tasks of bAbI. We report average values and standard deviation over the best hyperparameters we used.
C.2 Transformer architecture and hyperparameters
We use the same architecture as described in Dehghani et al. (2018). More concretely, we use the implementation and hyperparameters described as ’universal_transformer_small’ that is available at https://bit.ly/3frofUI. For completeness, we describe the hyperparameters used in Table 3.
We also performed a search on hyperparameters to train on our tasks, with ranges reported in Table 4.
Table 3: Hyperparameters used for bAbI.
Parameter name  Value
Optimizer algorithm  Adam
Learning rate  3e-4
Input embedding size  128
Attention type  as in Vaswani et al. (2017)
Attention hidden size  512
Attention number of heads  8
Transition function  MLP (1 layer)
Transition hidden size  128
Attention dropout rate  0.1
Activation function  ReLU
N  10
$\beta$  0.01

Table 4: Hyperparameter search ranges for bAbI.
Parameter name  Value
Attention hidden size  {128, 512}
Transition hidden size  {128, 512}
$\lambda_p$  uniform(0, 1.0]
Appendix D Paired Associative Inference
D.1 PAI - Task details
For this task we used the dataset published in Banino et al. (2020); the task is also available at https://github.com/deepmind/deepmindresearch/tree/master/memo
To build the dataset, Banino et al. (2020) started with raw images from the ImageNet dataset (Deng et al., 2009), which were embedded using a pretrained ResNet (He et al., 2016), resulting in embeddings of size . Here we are focusing on the dataset with sequences of length three (i.e. A - B - C) items, which is composed of training images, evaluation images and testing images.
A single entry in the batch is built by selecting sequences from the relevant pool (e.g. training) and it is composed of three items:

a memory,

a query,

a target.
Each memory content is created by storing all the possible pairwise associations between the items in each sequence, e.g. $A_1B_1$ and $B_1C_1$, $A_2B_2$ and $B_2C_2$, and so on for every sequence in the memory. With , this process results in a memory with rows each one with embeddings of size .
Each query is composed of 3 images, namely:

the cue

the match

the lure
The cue (e.g. $A_i$) and the match (e.g. $C_i$) are images extracted from the same sequence $i$, whereas the lure is an image from the same memory content but from a different sequence (e.g. $C_j$ with $j \neq i$). There are two types of queries: “direct” and “indirect”. In “direct” queries the cue and the match are sampled from the same memory slot. For example, if the sequence is $A_i$ - $B_i$ - $C_i$, then an example of direct query would be $A_i$ (cue) - $B_i$ (match) - $B_j$ (lure). Of more interest here is the case of “indirect” queries, as they require an inference across multiple facts stored at different locations in memory. For instance, if we consider again the previous example sequence $A_i$ - $B_i$ - $C_i$, then an example of inference trial would be $A_i$ (cue) - $C_i$ (match) - $C_j$ (lure).
The queries are presented to the network as a concatenation of three image embedding vectors (the cue, the match and the lure), that is a dimensional vector. The cue is always placed in the first position in the concatenation, but to avoid any trivial solution, the positions of the match and lure are randomized. It is worth noting that the lure image always has the same position in the sequence as the match (e.g. if the match image is a C the lure is also a C) but it is randomly drawn from a different sequence that is also present in the current memory. This way the task can only be solved by appreciating the correct connection between the images, and this needs to be done while avoiding the interference coming from other items in memory. For each entry in the batch we generated all possible queries that the current memory store could support and then one was selected at random. The batch was balanced, i.e. half of the elements were direct queries and the other half were indirect. Finally, the targets represent the ImageNet class ID of the matches.
To summarize, for each entry in each batch:

Memory was of size

Queries were of size

Target was of size
D.2 PAI - Architecture details
We used an architecture similar to Universal Transformers (Dehghani et al., 2018, UT), but we augmented the transformer with a memory as in Dai et al. (2019). The number of layers in the encoder and the decoder was learnt, but constrained to be the same. This number was identified as the “pondering time” in our PonderNet architecture. Also, we set an upper bound to the number of layers. The initial state was a learnt embedding of the input. On each step, the state was updated by applying the encoder layer once, that is: . Note that in this case PonderNet only received information about the inputs through its state. The prediction was computed by applying the decoder layer an equal number of times to the pondering step, that is . With this architecture, PonderNet was able to optimize how many times to apply the encoder and the decoder layers to improve its performance in this task.
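The recurrence described in this paragraph can be sketched schematically. The sketch below is a toy illustration with hypothetical names (`ponder_predictions`, `encoder_layer`, `decoder_layer`); it replaces the transformer layers with arbitrary callables and omits the halting probabilities and the memory.

```python
def ponder_predictions(encoder_layer, decoder_layer, initial_state, max_steps):
    """Return the prediction produced at every pondering step.

    At step n the encoder layer has been applied n times to the state,
    and the decoder layer is then applied n times to obtain the prediction.
    """
    predictions = []
    state = initial_state
    for n in range(1, max_steps + 1):
        state = encoder_layer(state)      # one encoder application per step
        decoded = state
        for _ in range(n):                # decoder applied n times at step n
            decoded = decoder_layer(decoded)
        predictions.append(decoded)
    return predictions


# Toy layers on numbers instead of transformer layers on embeddings.
print(ponder_predictions(lambda h: h + 1, lambda h: h * 2, 0, 3))  # [2, 8, 24]
```

In this form the pondering step directly controls the depth of both the encoder and the decoder, which is what lets the halting policy choose how many layers to apply.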
The weights were optimised using Adam (Kingma and Ba, 2014), using polynomial weight decay with a maximum learning rate of and learning rate linear warmup for the first epoch. The minibatch size was of size . For completeness, we describe the hyperparameters used on Table 5. We also performed a search on hyperparameters to train on our tasks, with ranges reported on Table 6.
Table 5: Hyperparameters used for PAI.
Parameter name  Value
Optimizer algorithm  Adam
Input embedding size  256
Attention type  as in Vaswani et al. (2017)
Attention hidden size  512
Attention number of heads  8
Transition function  MLP (2 layers)
Transition hidden size  128
Attention dropout rate  0.1
$\beta$  0.01

Table 6: Hyperparameter search ranges for PAI.
Parameter name  Value
Attention hidden size  {256, 512}
Transition hidden size  {128, 1024}
$\lambda_p$  uniform(0, 0.5]
N  [7, 10]
D.3 PAI - Results based on query type
The results reported below in Table 7 are from the evaluation set at the end of training. Each evaluation set contains 600 items.

Table 7: Accuracy by query type.
Query  MEMO  UT  PonderNet
A-B  99.82 (0.30)  97.43  98.01 (2.39)
B-C  99.76 (0.38)  98.28  97.43 (1.97)
A-C  98.26 (0.67)  85.60  97.86 (3.78)
For MEMO and for Universal transformer the results were taken from Banino et al. (2020).
Appendix E Broader impact statement
In this work we introduced PonderNet, a new method that enables neural networks to adapt their computational complexity to the task they are trying to solve. Neural networks achieve state of the art in a wide range of applications, including natural language processing, reinforcement learning, computer vision and more. Currently, they require much time, expensive hardware and energy to train and to deploy. They also often fail to generalize and to extrapolate to conditions beyond their training.
PonderNet expands the capabilities of neural networks, by letting them decide to ponder for an indefinite amount of time (analogous to how both humans and computers think). This can be used to reduce the amount of compute and energy at inference time, which makes it particularly well suited for platforms with limited resources such as mobile phones. Additionally, our experiments show that enabling neural networks to adapt their computational complexity also has benefits for their performance (beyond the computational requirements) when evaluating outside of the training distribution, which is one of the limiting factors when applying neural networks to real-world problems.
We encourage other researchers to pursue the questions we have considered in this work. We believe that biasing neural network architectures to behave more like algorithms, and less like “flat” mappings, will help develop deep learning methods to their full potential.