tf-rnn-adaptive
Adaptive Computation Time (Graves, 2016, arXiv:1603.08983) wrapper for TensorFlow RNN cells.
This paper introduces Adaptive Computation Time (ACT), an algorithm that allows recurrent neural networks to learn how many computational steps to take between receiving an input and emitting an output. ACT requires minimal changes to the network architecture, is deterministic and differentiable, and does not add any noise to the parameter gradients. Experimental results are provided for four synthetic problems: determining the parity of binary vectors, applying binary logic operations, adding integers, and sorting real numbers. Overall, performance is dramatically improved by the use of ACT, which successfully adapts the number of computational steps to the requirements of the problem. We also present character-level language modelling results on the Hutter prize Wikipedia dataset. In this case ACT does not yield large gains in performance; however, it does provide intriguing insight into the structure of the data, with more computation allocated to harder-to-predict transitions, such as spaces between words and ends of sentences. This suggests that ACT or other adaptive computation methods could provide a generic method for inferring segment boundaries in sequence data.
The amount of time required to pose a problem and the amount of thought required to solve it are notoriously unrelated. Pierre de Fermat was able to write in a margin the conjecture (if not the proof) of a theorem that took three and a half centuries and reams of mathematics to solve [35]. More mundanely, we expect the effort required to find a satisfactory route between two cities, or the number of queries needed to check a particular fact, to vary greatly, and unpredictably, from case to case. Most machine learning algorithms, however, are unable to dynamically adapt the amount of computation they employ to the complexity of the task they perform.
For artificial neural networks, where the neurons are typically arranged in densely connected layers, an obvious measure of computation time is the number of layer-to-layer transformations the network performs. In feedforward networks this is controlled by the network depth, or number of layers stacked on top of each other. For recurrent networks, the number of transformations also depends on the length of the input sequence, which can be padded or otherwise extended to allow for extra computation. The evidence that increased depth leads to more performant networks is by now inarguable [5, 4, 19, 9], and recent results show that increased sequence length can be similarly beneficial [31, 33, 25]. However it remains necessary for the experimenter to decide a priori on the amount of computation allocated to a particular input vector or sequence. One solution is to simply make every network very deep and design its architecture in such a way as to mitigate the vanishing gradient problem [13] associated with long chains of iteration [29, 17]. However in the interests of both computational efficiency and ease of learning it seems preferable to dynamically vary the number of steps for which the network ‘ponders’ each input before emitting an output. In this case the effective depth of the network at each step along the sequence becomes a dynamic function of the inputs received so far.
The approach pursued here is to augment the network output with a sigmoidal halting unit whose activation determines the probability that computation should continue. The resulting halting distribution is used to define a mean-field vector for both the network output and the internal network state propagated along the sequence. A stochastic alternative would be to halt or continue according to binary samples drawn from the halting distribution, a technique that has recently been applied to scene understanding with recurrent networks [7]. However the mean-field approach has the advantage of using a smooth function of the outputs and states, with no need for stochastic gradient estimates. We expect this to be particularly beneficial when long sequences of halting decisions must be made, since each decision is likely to affect all subsequent ones, and sampling noise will rapidly accumulate (as observed for policy gradient methods [36]).
A related architecture known as Self-Delimiting Neural Networks [26, 30] employs a halting neuron to end a particular update within a large, partially activated network; in this case however a simple activation threshold is used to make the decision, and no gradient with respect to halting time is propagated. More broadly, learning when to halt can be seen as a form of conditional computing, where parts of the network are selectively enabled and disabled according to a learned policy [3, 6].
We would like the network to be parsimonious in its use of computation, ideally limiting itself to the minimum number of steps necessary to solve the problem. Finding this limit in its most general form would be equivalent to determining the Kolmogorov complexity of the data (and hence solving the halting problem) [21]. We therefore take the more pragmatic approach of adding a time cost to the loss function to encourage faster solutions. The network then has to learn to trade off accuracy against speed, just as a person must when making decisions under time pressure. One weakness is that the numerical weight assigned to the time cost has to be hand-chosen, and the behaviour of the network is quite sensitive to its value.
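Consider a recurrent neural network $\mathcal{R}$ composed of a matrix of input weights $W_x$, a parametric state transition model $\mathcal{S}$, a set of output weights $W_y$ and an output bias $b_y$. When applied to an input sequence $x = (x_1, \ldots, x_T)$, $\mathcal{R}$ computes the state sequence $s = (s_1, \ldots, s_T)$ and the output sequence $y = (y_1, \ldots, y_T)$ by iterating the following equations from $t = 1$ to $T$:
$$s_t = \mathcal{S}(s_{t-1}, W_x x_t) \tag{1}$$
$$y_t = W_y s_t + b_y \tag{2}$$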
The state $s_t$ is a fixed-size vector of real numbers containing the complete dynamic information of the network. For a standard recurrent network this is simply the vector of hidden unit activations. For a Long Short-Term Memory network (LSTM) [14], the state also contains the activations of the memory cells. For a memory augmented network such as a Neural Turing Machine (NTM) [10], the state contains both the complete state of the controller network and the complete state of the memory. In general some portions of the state (for example the NTM memory contents) will not be visible to the output units; in this case we consider the corresponding columns of $W_y$ to be fixed to 0.
Adaptive Computation Time (ACT) modifies the conventional setup by allowing $\mathcal{R}$ to perform a variable number of state transitions and compute a variable number of outputs at each input step. Let $N(t)$ be the total number of updates performed at step $t$. Then define the intermediate state sequence $(s_t^1, \ldots, s_t^{N(t)})$ and intermediate output sequence $(y_t^1, \ldots, y_t^{N(t)})$ at step $t$ as follows:
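$$s_t^n = \begin{cases} \mathcal{S}(s_{t-1}, W_x x_t^1) & \text{if } n = 1 \\ \mathcal{S}(s_t^{n-1}, W_x x_t^n) & \text{otherwise} \end{cases} \tag{3}$$
$$y_t^n = W_y s_t^n + b_y \tag{4}$$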
where $x_t^n$ is the input $x_t$ augmented with a binary flag $\delta_{n,1}$ (1 on the first update of step $t$, 0 otherwise) that indicates whether the input step has just been incremented, allowing the network to distinguish between repeated inputs and repeated computations for the same input. Note that the same state function $\mathcal{S}$ is used for all state transitions (intermediate or otherwise), and similarly the output weights $W_y$ and bias $b_y$ are shared for all outputs. It would also be possible to use different state and output parameters for each intermediate step; however doing so would cloud the distinction between increasing the number of parameters and increasing the number of computational steps. We leave this for future work.
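A minimal pure-Python sketch of Equations 3 and 4 for one input step with a fixed number of updates may help. The names `augment`, `toy_state_fn`, `readout` and `intermediate_steps` are hypothetical; the elementwise tanh merely stands in for $\mathcal{S}$ (with the input weights folded into it), and the binary flag is assumed to be appended as one extra input element.

```python
import math


def augment(x_t, n):
    """Return x_t^n: the input x_t with a binary flag that is 1 only on the first update."""
    return list(x_t) + [1.0 if n == 1 else 0.0]


def toy_state_fn(s, x_aug, w_s=0.5, w_x=1.0):
    """Hypothetical stand-in for S: elementwise tanh(w_s * s_i + w_x * sum(x_aug))."""
    drive = w_x * sum(x_aug)
    return [math.tanh(w_s * s_i + drive) for s_i in s]


def readout(s, w_y, b_y):
    """Eq. (4): y = W_y s + b_y, with W_y given as a list of rows."""
    return [sum(w * s_i for w, s_i in zip(row, s)) + b for row, b in zip(w_y, b_y)]


def intermediate_steps(s_prev, x_t, num_updates, w_y, b_y, state_fn=toy_state_fn):
    """Eq. (3): the first update reads s_{t-1}; later ones read the previous intermediate state."""
    states, outputs = [], []
    s = s_prev
    for n in range(1, num_updates + 1):
        s = state_fn(s, augment(x_t, n))
        states.append(s)
        outputs.append(readout(s, w_y, b_y))
    return states, outputs
```

Only the first update sees the flag set to 1, which is what lets the network tell a fresh input apart from a repeated computation on the same input.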
To determine how many updates $\mathcal{R}$ performs at each input step, an extra sigmoidal halting unit $h$ is added to the network output, with associated weight matrix $W_h$ and bias $b_h$:
$$h_t^n = \sigma(W_h s_t^n + b_h) \tag{5}$$
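As with the output weights, some columns of $W_h$ may be fixed to zero to give selective access to the network state. The activation of the halting unit is then used to determine the halting probability $p_t^n$ of the intermediate steps:
$$p_t^n = \begin{cases} R(t) & \text{if } n = N(t) \\ h_t^n & \text{otherwise} \end{cases} \tag{6}$$
where
$$N(t) = \min\left\{n' : \sum_{n=1}^{n'} h_t^n \ge 1 - \epsilon\right\} \tag{7}$$
the remainder $R(t)$ is defined as follows
$$R(t) = 1 - \sum_{n=1}^{N(t)-1} h_t^n \tag{8}$$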
and $\epsilon$ is a small constant (0.01 for the experiments in this paper), whose purpose is to allow computation to halt after a single update if $h_t^1 \ge 1 - \epsilon$; since a sigmoid never reaches exactly 1, a minimum of two updates would otherwise be required for every input step. It follows directly from the definition that $\sum_{n=1}^{N(t)} p_t^n = 1$ and $0 \le p_t^n \le 1 \;\forall n$, so this is a valid probability distribution. A similar distribution was recently used to define differentiable push and pop operations for neural stacks and queues [11].
At this point we could proceed stochastically by sampling $\hat{n}$ from $p_t^n$ and setting $s_t = s_t^{\hat{n}}$, $y_t = y_t^{\hat{n}}$. However we will eschew sampling techniques and the associated problems of noisy gradients, instead using $p_t^n$ to determine mean-field updates for the states and outputs:
$$s_t = \sum_{n=1}^{N(t)} p_t^n s_t^n \qquad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n \tag{9}$$
The implicit assumption is that the states and outputs are approximately linear, in the sense that a linear interpolation between a pair of state or output vectors will also interpolate between the properties the vectors embody. There are several reasons to believe that such an assumption is reasonable. Firstly, it has been observed that the high-dimensional representations present in neural networks naturally tend to behave in a linear way [32, 20], even remaining consistent under arithmetic operations such as addition and subtraction [22]. Secondly, neural networks have been successfully trained under a wide range of adversarial regularisation constraints, including sparse internal states [23], stochastically masked units [28] and randomly perturbed weights [1]. This leads us to believe that the relatively benign constraint of approximately linear representations will not be too damaging. Thirdly, as training converges, the tendency for both mean-field and stochastic latent variables is to concentrate all the probability mass on a single value. In this case that yields a standard RNN with each input duplicated a variable, but deterministic, number of times, rendering the linearity assumption irrelevant.
A diagram of the unrolled computation graph of a standard RNN is illustrated in Figure 1, while Figure 2 provides the equivalent diagram for an RNN trained with ACT.
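Equations 6 to 9 can be sketched in a few lines of plain Python. The functions below are illustrative, not a reference implementation: `halting_distribution` takes the halting activations $h_t^1, h_t^2, \ldots$ already computed for one input step and returns $N(t)$, $R(t)$ and the probabilities $p_t^n$; `mean_field` then forms the weighted sums of Equation 9. The stochastic alternative mentioned above is included as `sample_step` for contrast; using `random.Random.choices` for it is our choice, not something the text specifies.

```python
import random


def halting_distribution(h, eps=0.01):
    """Return (N, R, p) for one input step from halting activations h = [h^1, h^2, ...].

    N(t) is the first n' whose cumulative activation reaches 1 - eps (Eq. 7),
    R(t) = 1 - sum_{n < N} h^n (Eq. 8), and p^n = h^n for n < N, p^N = R(t) (Eq. 6).
    """
    cumulative = 0.0
    for n, h_n in enumerate(h, start=1):
        if cumulative + h_n >= 1.0 - eps:
            remainder = 1.0 - cumulative
            return n, remainder, list(h[: n - 1]) + [remainder]
        cumulative += h_n
    raise ValueError("halting threshold not reached; supply more activations")


def mean_field(probs, vectors):
    """Eq. (9): sum_n p^n v^n, used for both the intermediate states and outputs."""
    if len(probs) != len(vectors):
        raise ValueError("need exactly one probability per intermediate vector")
    return [sum(p * v[i] for p, v in zip(probs, vectors)) for i in range(len(vectors[0]))]


def sample_step(probs, rng):
    """Stochastic alternative (not used by ACT): pick a single intermediate step index."""
    return rng.choices(range(len(probs)), weights=probs)[0]
```

With `h = [0.1, 0.3, 0.7]` the threshold is crossed at the third update, giving $N(t) = 3$ and $R(t) = 0.6$; with `h = [0.995]` computation halts after a single update, which is exactly the case $\epsilon$ exists for.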
If no constraints are placed on the number of updates $\mathcal{R}$ can take at each step, it will naturally tend to ‘ponder’ each input for as long as possible (so as to avoid making predictions and incurring errors). We therefore require a way of limiting the amount of computation the network performs. Given a length $T$ input sequence $x$, define the ponder sequence $(\rho_1, \ldots, \rho_T)$ of $\mathcal{R}$ on $x$ as
$$\rho_t = N(t) + R(t) \tag{10}$$
and the ponder cost $\mathcal{P}(x)$ as
$$\mathcal{P}(x) = \sum_{t=1}^{T} \rho_t \tag{11}$$
Since $0 < R(t) \le 1$, $\mathcal{P}(x)$ is an upper bound on the (non-differentiable) property we ultimately want to reduce, namely the total computation $\sum_{t=1}^{T} N(t)$ during the sequence.¹
¹ For a stochastic ACT network, a more natural halting distribution than the one described in Equations 8, 7 and 6 would be to simply treat $h_t^n$ as the probability of halting at step $n$, in which case $p_t^n = h_t^n \prod_{n'=1}^{n-1} (1 - h_t^{n'})$. One could then set $\rho_t = \sum_{n=1}^{N(t)} n\, p_t^n$, i.e. the expected ponder time under the stochastic distribution. However experiments show that networks trained to minimise expected rather than total halting time learn to ‘cheat’ in the following ingenious way: they set $h_t^1$ to a value just below the halting threshold, then keep $h_t^n = 0$ until some $N(t)$ when they set $h_t^{N(t)}$ high enough to ensure they halt. In this case $p_t^{N(t)} \ll p_t^1$, so the states and outputs at $n = N(t)$ have much lower weight in the mean-field updates (Equation 9) than those at $n = 1$; however, by making the magnitudes of the states and output vectors much larger at $N(t)$ than at $n = 1$, the network can still ensure that the update is dominated by the final vectors, despite having paid a low ponder penalty.
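We can encourage the network to minimise $\mathcal{P}(x)$ by modifying the sequence loss function $\mathcal{L}(x, y)$ used for training:
$$\hat{\mathcal{L}}(x, y) = \mathcal{L}(x, y) + \tau \mathcal{P}(x) \tag{12}$$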
where $\tau$ is a time penalty parameter that weights the relative cost of computation versus error. As we will see in the experiments section the behaviour of the network is quite sensitive to the value of $\tau$, and it is not obvious how to choose a good value. If computation time and prediction error can be meaningfully equated (for example if the relative financial cost of both were known) a more principled technique for selecting $\tau$ should be possible.
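As a worked example of Equations 10 to 12, the sketch below computes the ponder cost from per-step $(N(t), R(t))$ pairs and adds the time penalty to a task loss. The function names are hypothetical and the numbers are made up for illustration; they only show that $\mathcal{P}(x)$ exceeds the integer step count $\sum_t N(t)$ by the sum of the remainders.

```python
def ponder_cost(steps):
    """Eqs. (10)-(11): P(x) = sum over t of rho_t = N(t) + R(t)."""
    return sum(n + r for n, r in steps)


def total_updates(steps):
    """The non-differentiable quantity that P(x) bounds from above: sum over t of N(t)."""
    return sum(n for n, _ in steps)


def act_loss(task_loss, steps, tau):
    """Eq. (12): L_hat = L + tau * P(x)."""
    return task_loss + tau * ponder_cost(steps)


# Hypothetical per-step (N(t), R(t)) pairs: one step halts at once, one after three updates.
steps = [(1, 1.0), (3, 0.6)]
penalised = act_loss(2.0, steps, tau=0.01)
```

Here $\mathcal{P}(x) = 5.6$ against 4 actual updates; the difference is the sum of the remainders, which is the part of the cost that varies smoothly with the halting activations.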
To prevent very long sequences at the beginning of training (while the network is learning how to use the halting unit) the bias term $b_h$ can be initialised to a positive value. In addition, a hard limit $M$ on the maximum allowed value of $N(t)$ can be imposed to avoid excessive space and time costs. In this case Equation 7 is modified to
$$N(t) = \min\left\{M, \min\left\{n' : \sum_{n=1}^{n'} h_t^n \ge 1 - \epsilon\right\}\right\} \tag{13}$$
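Putting Equations 3 to 9 and 13 together, one ACT input step can be sketched as a loop that stops once the cumulative halting activation reaches $1 - \epsilon$ or the hard limit $M$ (here `max_steps`) is hit. Everything here is a hypothetical toy: `toy_state_fn` stands in for $\mathcal{S}$, the weights are plain lists, and no gradients are computed. The usage lines illustrate the effect of the initial bias $b_h$: a positive bias makes the step halt after few updates, while a strongly negative one runs straight into the cap.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def toy_state_fn(s, x_aug):
    """Hypothetical stand-in for S: elementwise tanh(s_i + sum(x_aug))."""
    drive = sum(x_aug)
    return [math.tanh(s_i + drive) for s_i in s]


def act_step(s_prev, x_t, w_h, b_h, w_y, b_y, state_fn=toy_state_fn, eps=0.01, max_steps=100):
    """One ACT input step; returns (s_t, y_t, N(t), R(t))."""
    states, outputs, halts = [], [], []
    s, cumulative = s_prev, 0.0
    for n in range(1, max_steps + 1):
        s = state_fn(s, list(x_t) + [1.0 if n == 1 else 0.0])    # Eq. (3) with the flag
        h = sigmoid(dot(w_h, s) + b_h)                             # Eq. (5)
        states.append(s)
        outputs.append([dot(row, s) + b for row, b in zip(w_y, b_y)])  # Eq. (4)
        halts.append(h)
        if cumulative + h >= 1.0 - eps or n == max_steps:          # Eqs. (7) and (13)
            break
        cumulative += h
    remainder = 1.0 - cumulative                                   # Eq. (8)
    probs = halts[:-1] + [remainder]                               # Eq. (6)

    def mix(vectors):                                              # Eq. (9), mean field
        return [sum(p * v[i] for p, v in zip(probs, vectors)) for i in range(len(vectors[0]))]

    return mix(states), mix(outputs), len(probs), remainder


# Positive initial bias b_h = 1: h is about 0.73 per update, so the step halts at N(t) = 2.
_, _, n_pos, r_pos = act_step([0.0, 0.0], [1.0], [0.0, 0.0], 1.0, [[1.0, 1.0]], [0.0])
# Strongly negative bias: the threshold is never reached and the cap M = 5 applies.
_, _, n_neg, r_neg = act_step([0.0, 0.0], [1.0], [0.0, 0.0], -10.0, [[1.0, 1.0]], [0.0], max_steps=5)
```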
The ponder costs $\rho_t$ are discontinuous with respect to the halting probabilities at the points where $N(t)$ increments or decrements (that is, when the summed probability mass up to some $n$ either decreases below or increases above $1 - \epsilon$). However they are continuous away from those points, as $N(t)$ remains constant and $R(t)$ is a linear function of the probabilities. In practice we simply ignore the discontinuities by treating $N(t)$ as constant and minimising $R(t)$ everywhere.
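Given this approximation, the gradient of the ponder cost with respect to the halting activations is straightforward:
$$\frac{\partial \mathcal{P}(x)}{\partial h_t^n} = \begin{cases} 0 & \text{if } n = N(t) \\ -1 & \text{otherwise} \end{cases} \tag{14}$$
and hence
$$\frac{\partial \hat{\mathcal{L}}(x, y)}{\partial h_t^n} = \frac{\partial \mathcal{L}(x, y)}{\partial h_t^n} - \begin{cases} 0 & \text{if } n = N(t) \\ \tau & \text{otherwise} \end{cases} \tag{15}$$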
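The halting activations only influence $\mathcal{L}$ via their effect on the halting probabilities, therefore
$$\frac{\partial \mathcal{L}(x, y)}{\partial h_t^n} = \sum_{n'=1}^{N(t)} \frac{\partial \mathcal{L}(x, y)}{\partial p_t^{n'}} \frac{\partial p_t^{n'}}{\partial h_t^n} \tag{16}$$
Furthermore, since the halting probabilities only influence $\mathcal{L}$ via their effect on the states and outputs, it follows from Equation 9 that
$$\frac{\partial \mathcal{L}(x, y)}{\partial p_t^n} = \frac{\partial \mathcal{L}(x, y)}{\partial y_t} y_t^n + \frac{\partial \mathcal{L}(x, y)}{\partial s_t} s_t^n \tag{17}$$
while, from Equations 8 and 6
$$\frac{\partial p_t^{n'}}{\partial h_t^n} = \begin{cases} \delta_{n,n'} & \text{if } n' < N(t) \text{ and } n < N(t) \\ -1 & \text{if } n' = N(t) \text{ and } n < N(t) \\ 0 & \text{if } n = N(t) \end{cases} \tag{18}$$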
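Combining Equations 18, 17 and 16 gives, for $n < N(t)$
$$\frac{\partial \mathcal{L}(x, y)}{\partial h_t^n} = \frac{\partial \mathcal{L}(x, y)}{\partial y_t}\left(y_t^n - y_t^{N(t)}\right) + \frac{\partial \mathcal{L}(x, y)}{\partial s_t}\left(s_t^n - s_t^{N(t)}\right) \tag{19}$$
while for $n = N(t)$
$$\frac{\partial \mathcal{L}(x, y)}{\partial h_t^{N(t)}} = 0 \tag{20}$$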
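Equations 14 to 20 can be checked numerically. The sketch below assumes, for illustration only, a loss that depends on the mean-field output alone (a hypothetical squared error, so the $\partial\mathcal{L}/\partial s_t$ term of Equation 19 drops out) plus the time penalty, and it holds $N(t)$ fixed at `len(h)` as the text prescribes. It compares the closed form of Equations 15, 19 and 20 against central finite differences; all names and numbers are hypothetical.

```python
def probs_from_h(h):
    """Eq. (6) with N(t) = len(h) held fixed: p^n = h^n for n < N, p^N = 1 - sum_{n<N} h^n."""
    return list(h[:-1]) + [1.0 - sum(h[:-1])]


def mixed_output(h, outputs):
    """Eq. (9) for the outputs: y_t = sum_n p^n y^n."""
    probs = probs_from_h(h)
    return [sum(p * y[i] for p, y in zip(probs, outputs)) for i in range(len(outputs[0]))]


def total_loss(h, outputs, target, tau):
    """Hypothetical 0.5 * ||y_t - target||^2 plus tau * rho_t, with rho_t = N(t) + R(t)."""
    y = mixed_output(h, outputs)
    task = 0.5 * sum((y_i - t_i) ** 2 for y_i, t_i in zip(y, target))
    return task + tau * (len(h) + probs_from_h(h)[-1])


def analytic_grad(h, outputs, target, tau):
    """Eqs. (15), (19), (20): dL/dy . (y^n - y^N) - tau for n < N, and 0 for n = N."""
    y = mixed_output(h, outputs)
    dl_dy = [y_i - t_i for y_i, t_i in zip(y, target)]
    last = outputs[-1]
    grads = [sum(g * (a - b) for g, a, b in zip(dl_dy, y_n, last)) - tau for y_n in outputs[:-1]]
    return grads + [0.0]


def numeric_grad(h, outputs, target, tau, delta=1e-6):
    """Central finite differences of total_loss with respect to each h^n."""
    grads = []
    for n in range(len(h)):
        up, down = list(h), list(h)
        up[n] += delta
        down[n] -= delta
        diff = total_loss(up, outputs, target, tau) - total_loss(down, outputs, target, tau)
        grads.append(diff / (2 * delta))
    return grads


h = [0.2, 0.3, 0.9]                       # cumulative sum first reaches 0.99 at n = 3, so N(t) = 3
outputs = [[1.0, 0.0], [0.0, 2.0], [0.5, 0.5]]
target = [0.0, 1.0]
```

Because the loss is quadratic in the halting activations here, the two gradients agree to within rounding, and the entry for $h_t^{N(t)}$ is exactly zero, as Equation 20 states.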
Thereafter the network can be differentiated as usual (e.g. with backpropagation through time [36]) and trained with gradient descent.
We tested recurrent neural networks (RNNs) with and without ACT on four synthetic tasks and one real-world language processing task. LSTM was used as the network architecture for all experiments except one, where a simple RNN was used. However we stress that ACT is equally applicable to any recurrent architecture.
All the tasks were supervised learning problems with discrete targets and cross-entropy loss. The data for the synthetic tasks was generated online and cross-validation was therefore not needed. Similarly, the character prediction dataset was sufficiently large that the network did not overfit. The performance metric for the synthetic tasks was the sequence error rate: the fraction of examples where any mistakes were made in the complete output sequence. This metric is useful as it is trivial to evaluate without decoding. For character prediction the metric was the average log-loss of the output predictions, in units of bits per character.
Most of the training parameters were fixed for all experiments: Adam [18] was used for optimisation with a learning rate of , the Hogwild! algorithm [24] was used for asynchronous training with 16 threads; the initial halting unit bias $b_h$ mentioned in Equation 5 was 1; the $\epsilon$ term from Equation 7 was 0.01. The synthetic tasks were all trained for 1M iterations, where an iteration is defined as a weight update on a single thread (hence the total number of weight updates is approximately 16 times the number of iterations). The character prediction task was trained for 10K iterations. Early stopping was not used for any of the experiments.
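The two performance metrics described above can be written down directly. This is a sketch: the function names are ours, predictions are assumed to be already-decoded discrete outputs of the same length as their targets, and bits per character is computed from the probability the model assigns to each true next character.

```python
import math


def sequence_error_rate(predictions, targets):
    """Fraction of sequences with at least one wrong output (sequences of equal length)."""
    wrong = sum(
        any(p != t for p, t in zip(pred, tgt)) for pred, tgt in zip(predictions, targets)
    )
    return wrong / len(targets)


def bits_per_character(target_probs):
    """Average log-loss in bits: the mean of -log2 p(true next character)."""
    return -sum(math.log2(p) for p in target_probs) / len(target_probs)
```

A single wrong digit anywhere in a sequence counts that whole sequence as an error, which is why this metric is stricter than a per-step error rate.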
A logarithmic grid search over time penalties was performed for each experiment, with 20 randomly initialised networks trained for each value of $\tau$. For the synthetic problems the grid search covered $\tau = i \times 10^{-j}$ with integer $i$ in the range 1–10 and the exponent $j$ in the range 1–4. For the language modelling task, which took many days to complete, the range of $j$ was limited to 1–3 to reduce training time (lower values of $\tau$, which naturally induce more pondering, tend to give greater data efficiency but slower wall clock training time).
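The logarithmic grid of time penalties can be enumerated as below. The helper name and the de-duplication step (for example $10 \times 10^{-2}$ and $1 \times 10^{-1}$ are the same value) are our own additions; whether the original search trained such duplicate values separately is not stated.

```python
def tau_grid(exponents=range(1, 5), multipliers=range(1, 11)):
    """Distinct time penalties tau = i * 10**-j, rounded to suppress float noise."""
    return sorted({round(i * 10.0 ** -j, 12) for j in exponents for i in multipliers})


synthetic_grid = tau_grid()                       # j in 1..4: from 1e-4 up to 1.0
language_grid = tau_grid(exponents=range(1, 4))   # j in 1..3: from 1e-3 up to 1.0
```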
Unless otherwise stated the maximum computation time $M$ (Equation 13) was set to 100. In all experiments the networks converged on learned values of $N(t)$ that were far less than $M$, which functions mainly as a safeguard against excessively long ponder times early in training.
Determining the parity of a sequence of binary numbers is a trivial task for a recurrent neural network [27], which simply needs to implement an internal switch that changes sign every time a one is received. For shallow feedforward networks receiving the entire sequence in one vector, however, the number of distinct input patterns, and hence the difficulty of the task, grows exponentially with the number of bits. We gauged the ability of ACT to infer an inherently sequential algorithm from statically presented data by presenting large binary vectors to the network and asking it to determine the parity. By varying the number of binary bits for which parity must be calculated we were also able to assess ACT’s ability to adapt the amount of computation to the difficulty of the vector.
The input vectors had 64 elements, of which a random number from to were randomly set to or and the rest were set to 0. The corresponding target was 1 if there was an odd number of ones and 0 if there was an even number of ones. Each training sequence consisted of a single input and target vector, an example of which is shown in Figure 3. The network architecture was a simple RNN with a single hidden layer containing 128 units and a single sigmoidal output unit, trained with binary cross-entropy loss on minibatches of size 128. Note that without ACT the recurrent connection in the hidden layer was never used since the data had no sequential component, and the network reduced to a feedforward network with a single hidden layer.
Figure 4 demonstrates that the network was unable to reliably solve the problem without ACT, with a mean of almost 40% error compared to 50% for random guessing. For penalties of 0.03 and below the mean error was below 5%. Figure 5 reveals that the solutions were both more rapid and more accurate with lower time penalties. It also highlights the relationship between the time penalty, the classification error rate and the average ponder time per input. The variance in ponder time for low $\tau$ networks is very high, indicating that many correct solutions with widely varying runtime can be discovered. We speculate that progressively higher $\tau$ values lead the network to compute the parities of successively larger chunks of the input vector at each ponder step, then iteratively combine these calculations to obtain the parity of the complete vector.
Figure 6 shows that for the networks without ACT and those with overly high time penalties, the error rate increases sharply with the difficulty of the task (where difficulty is defined as the number of bits whose parity must be determined), while the amount of ponder remains roughly constant. For the more successful networks, with intermediate $\tau$ values, ponder time appears to grow linearly with difficulty, with a slope that generally increases as $\tau$ decreases. Even for the best networks the error rate increased somewhat with difficulty. For some of the lowest $\tau$ networks there is a dramatic increase in ponder after about 32 bits, suggesting an inefficient algorithm.
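A minimal generator for the parity data, written from the description above, may make the setup concrete. It assumes that between 1 and 64 entries are non-zero, that each non-zero entry is $+1$ or $-1$ with equal probability, and that the non-zero entries sit at uniformly random positions (the text does not say where they are placed); the function name and the returned difficulty value are hypothetical conveniences.

```python
import random


def parity_example(rng, size=64):
    """Return (x, target, difficulty) for one parity example (see the assumptions above)."""
    k = rng.randint(1, size)                   # assumed number of non-zero entries
    x = [0.0] * size
    for i in rng.sample(range(size), k):       # assumption: positions chosen uniformly
        x[i] = rng.choice([1.0, -1.0])         # assumption: entries are +1 or -1
    ones = sum(1 for v in x if v == 1.0)
    return x, ones % 2, k                      # target is 1 iff the number of ones is odd


rng = random.Random(0)
batch = [parity_example(rng) for _ in range(128)]   # one minibatch of size 128
```

The difficulty value `k` is the quantity plotted against ponder time and error rate in the difficulty analysis, so returning it alongside each example makes that breakdown easy to reproduce.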
Like parity, the logic task tests if an RNN with ACT can sequentially process a static vector. Unlike parity it also requires the network to internally transfer information across successive input timesteps, thereby testing whether ACT can propagate coherent internal states.
Each input sequence consists of a random number (from 1 to 10) of input vectors of size 102. The first two elements of each input represent a pair of binary numbers; the remainder of the vector is divided up into 10 chunks of size 10. The first $B$ chunks, where $B$ is a random number from 1 to 10, contain one-hot representations of randomly chosen numbers between 1 and 10; each of these numbers corresponds to an index into the subset of binary logic gates whose truth tables are listed in Table 1. The remaining $10 - B$ chunks were zeroed to indicate that no further binary operations were defined for that vector. The binary target $b_{B+1}$ for each input is the truth value yielded by recursively applying the $B$ binary gates in the vector to the two initial bits $b_1, b_0$. That is, for $1 \le i \le B$:
$$b_{i+1} = T_i(b_i, b_{i-1}) \tag{21}$$
where $T_i(\cdot, \cdot)$ is the truth table indexed by chunk $i$ in the input vector.
P | Q | NOR | Xq | ABJ | XOR | NAND | AND | XNOR | if/then | then/if | OR |
---|---|---|---|---|---|---|---|---|---|---|---|
T | T | F | F | F | F | F | T | T | T | T | T |
T | F | F | F | T | T | T | F | F | F | T | T |
F | T | F | T | F | T | T | F | F | T | F | T |
F | F | T | F | F | F | T | F | T | T | T | F |
For the first vector in the sequence, the two input bits $b_0, b_1$ were randomly chosen to be false (0) or true (1) and assigned to the first two elements in the vector. For subsequent vectors, only $b_1$ was random, while $b_0$ was implicitly equal to the target bit from the previous vector (for the purposes of calculating the current target bit), but was always set to zero in the input vector. To solve the task, the network therefore had to learn both how to calculate the sequence of binary operations represented by the chunks in each vector, and how to carry the final output of that sequence over to the next timestep. An example input-target sequence pair is shown in Figure 7.
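The logic-task data can be generated as sketched below. The gate order follows the columns of Table 1; treating $P$ as the first argument of $T_i$ (so $P = b_i$ and $Q = b_{i-1}$), placing $b_0$ before $b_1$ in the first two input elements, and drawing every bit and gate index uniformly are our assumptions, and the function names are hypothetical.

```python
import random

# Table 1, columns 1-10 in order; each gate maps (P, Q) to a truth value.
GATES = {
    1: lambda p, q: not (p or q),   # NOR
    2: lambda p, q: (not p) and q,  # Xq
    3: lambda p, q: p and not q,    # ABJ
    4: lambda p, q: p != q,         # XOR
    5: lambda p, q: not (p and q),  # NAND
    6: lambda p, q: p and q,        # AND
    7: lambda p, q: p == q,         # XNOR
    8: lambda p, q: (not p) or q,   # if/then
    9: lambda p, q: p or not q,     # then/if
    10: lambda p, q: p or q,        # OR
}


def apply_gates(b0, b1, gates):
    """Eq. (21): b_{i+1} = T_i(b_i, b_{i-1}); returns the target bit b_{B+1}."""
    bits = [b0, b1]
    for g in gates:
        bits.append(GATES[g](bits[-1], bits[-2]))
    return bits[-1]


def logic_sequence(rng, max_len=10, num_chunks=10, chunk_size=10):
    """One (inputs, targets) sequence of 102-element vectors for the logic task."""
    inputs, targets, carried = [], [], None
    for _ in range(rng.randint(1, max_len)):
        b1 = rng.random() < 0.5
        if carried is None:                 # first vector: both bits random and visible
            b0 = rng.random() < 0.5
            x = [float(b0), float(b1)]
        else:                               # later vectors: b0 is the previous target,
            b0 = carried                    # but the network sees 0 in its place
            x = [0.0, float(b1)]
        gates = [rng.randint(1, 10) for _ in range(rng.randint(1, num_chunks))]
        for i in range(num_chunks):         # one-hot gate indices; unused chunks stay zero
            chunk = [0.0] * chunk_size
            if i < len(gates):
                chunk[gates[i] - 1] = 1.0
            x.extend(chunk)
        carried = apply_gates(b0, b1, gates)
        inputs.append(x)
        targets.append(float(carried))
    return inputs, targets
```

Hiding the carried bit in the input is what forces the network to keep the previous result in its internal state rather than read it off the current vector.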
The network architecture was single-layer LSTM with 128 cells. The output was a single sigmoidal unit, trained with binary cross-entropy, and the minibatch size was 16.
Figure 8 shows that the network reaches a minimum sequence error rate of around 0.2 without ACT (compared to 0.5 for random guessing), and virtually zero error for all . From Figure 9 it can be seen that low $\tau$ ACT networks solve the task very quickly, requiring about 10,000 training iterations. For higher $\tau$ values ponder time reduces to 1, at which point the networks trained with ACT behave identically to those without. For lower $\tau$ values, the spread of ponder values, and hence computational cost, is quite large. Again we speculate that this is due to the network learning more or less ‘chunked’ solutions in which composite truth tables are learned for multiple successive logic operations. This is somewhat supported by the clustering of the lowest $\tau$ networks around a ponder time of 5–6, which is approximately the mean number of logic gates applied per sequence, and hence the minimum number of computations the network would need if calculating one binary operation at a time.
Figure 10 shows a surprisingly high ponder time for the least difficult inputs, with some networks taking more than 10 steps to evaluate a single logic gate. From 5 to 10 logic gates, ponder gradually increases with difficulty as expected, suggesting that a qualitatively different solution is learned for the two regimes. This is supported by the error rates for the non-ACT and high $\tau$ networks, which increase abruptly after 5 gates. It may be that 5 is the upper limit on the number of successive gates the network can learn as a single composite operation, and thereafter it is forced to apply an iterative algorithm.
The addition task presents the network with an input sequence of 1 to 5 size 50 input vectors. Each vector represents a $D$ digit number, where $D$ is drawn randomly from 1 to 5, and each digit is drawn randomly from 0 to 9. The first $10D$ elements of the vector are a concatenation of one-hot encodings of the $D$ digits in the number, and the remainder of the vector is set to 0. The required output is the cumulative sum of all inputs up to the current one, represented as a set of 6 simultaneous classifications for the 6 possible digits in the sum (the sum of at most five 5-digit numbers is at most 499,995, which has 6 digits). There is no target for the first vector in the sequence, as no sums have yet been calculated. Because the previous sum must be carried over by the network, this task again requires the internal state of the network to remain coherent. Each classification is modelled by a size 11 softmax, where the first 10 classes are the digits and the 11th is a special marker used to indicate that the number is complete. An example input-target pair is shown in Figure 11.
The network was single-layer LSTM with 512 memory cells. The loss function was the joint cross-entropy of all 6 targets at each time-step where targets were present and the minibatch size was 32. The maximum ponder $M$ was set to 20 for this task, as it was found that some networks had very high ponder times early in training.
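The addition inputs and targets can be sketched as follows. Writing digits most-significant first in both inputs and targets, padding each target with the 11th "number complete" class, and allowing leading zeros in the input numbers are assumptions on our part; `MARKER` and the function names are hypothetical.

```python
import random

MARKER = 10  # hypothetical index of the 11th class: "number complete"


def encode_sum(total, out_digits=6):
    """Six classifications: the digits of the running sum, then MARKER padding."""
    classes = [int(c) for c in str(total)]
    if len(classes) > out_digits:
        raise ValueError("sum does not fit in the output digits")
    return classes + [MARKER] * (out_digits - len(classes))


def addition_sequence(rng, max_len=5, max_digits=5, size=50):
    """Inputs of 1..5 size-50 vectors; targets are running sums (None for the first)."""
    inputs, targets, total = [], [], 0
    for t in range(rng.randint(1, max_len)):
        digits = [rng.randint(0, 9) for _ in range(rng.randint(1, max_digits))]
        x = [0.0] * size
        for k, d in enumerate(digits):         # first 10*D entries: one-hot digits
            x[10 * k + d] = 1.0
        inputs.append(x)
        total += int("".join(str(d) for d in digits))
        targets.append(None if t == 0 else encode_sum(total))
    return inputs, targets
```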
The results in Figure 12 show that the task was perfectly solved by the ACT networks for all values of $\tau$ in the grid search. Unusually, networks with higher $\tau$ solved the problem with fewer training examples. Figure 14 demonstrates that the relationship between the ponder time and the number of digits was approximately linear for most of the ACT networks, and that for the most efficient networks (with the highest $\tau$ values) the slope of the line was close to 1, which matches our expectation that an efficient long addition algorithm should need one computation step per digit.
Figure 15 shows how the ponder time is distributed during individual addition sequences, providing further evidence of an approximately linear-time long addition algorithm.
The sort task requires the network to sort sequences of 2 to 15 numbers drawn from a standard normal distribution in ascending order. The experiments considered so far have been designed to favour ACT by compressing sequential information into single vectors, and thereby requiring the use of multiple computation steps to unpack them. For the sort task a more natural sequential representation was used: the random numbers were presented one at a time as inputs, and the required output was the sequence of indices into the number sequence placed in sorted order; an example is shown in Figure 16. We were particularly curious to see how the number of ponder steps scaled with the number of elements to be sorted, knowing that efficient sorting algorithms have $O(n \log n)$ computational cost for $n$ elements.
The network was single-layer LSTM with 512 cells. The output layer was a size 15 softmax, trained with cross-entropy to classify the indices of the sorted inputs. The minibatch size was 16.
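A sketch of the sort-task data follows: the targets are the input indices listed in ascending order of value, which is what the size 15 softmax classifies. The optional `min_gap` argument implements, by rejection sampling, the minimum separation between successive sorted values that the results discussion suggests would probably help; it was not part of the reported experiments, and the names are hypothetical.

```python
import random


def sort_example(rng, min_len=2, max_len=15, min_gap=0.0):
    """Return (values, target_indices) for one sort-task sequence."""
    n = rng.randint(min_len, max_len)
    while True:
        values = [rng.gauss(0.0, 1.0) for _ in range(n)]
        ordered = sorted(values)
        if all(b - a >= min_gap for a, b in zip(ordered, ordered[1:])):
            break                               # min_gap = 0.0 accepts the first draw
    targets = sorted(range(n), key=lambda i: values[i])   # indices in ascending order
    return values, targets
```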
Figure 17 shows that the advantage of using ACT is less dramatic for this task than the previous three, but still substantial (from around 12% error without ACT to around 6% for the best $\tau$ value). However from Figure 18 it is clear that these gains come at a heavy computational cost, with the best networks requiring roughly 9 times as much computation as those without ACT. Not surprisingly, Figure 19 shows that the error rate grew rapidly with the sequence length for all networks. It also indicates that the better networks had a sublinear growth in computations per input step with sequence length, though whether this indicates a logarithmic time algorithm is unclear. One problem with the sort task was that the Gaussian samples were sometimes very close together, making it hard for the network to determine which was greater; enforcing a minimum separation between successive values would probably be beneficial.
Figure 20 shows the ponder time during three sort sequences of varying length. As can be seen, there is a large spike in ponder time near (though not precisely at) the end of the input sequence, presumably when the majority of the sort comparisons take place. Note that the spike is much higher for the longer two sequences than the length 5 one, again pointing to an algorithm that is nonlinear in sequence length (the average ponder per timestep is nonetheless lower for longer sequences, as little pondering is done away from the spike).
The Wikipedia task is character prediction on text drawn from the Hutter prize Wikipedia dataset [15]. Following previous RNN experiments on the same data [8], the raw Unicode text was used, including XML tags and markup characters, with one byte presented per input timestep and the next byte predicted as a target. No validation set was used for early stopping, as the networks were unable to overfit the data, and all error rates are recorded on the training set. Sequences of 500 consecutive bytes were randomly chosen from the training set and presented to the network, whose internal state was reset to 0 at the start of each sequence.
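The training sequences for this task can be sampled as in the sketch below: random windows of 500 consecutive bytes, each byte (an integer from 0 to 255, matching a size 256 softmax) paired with the following byte as its target. The function name, the `count` parameter and the in-memory `bytes` input are hypothetical; resetting the network state to zero at the start of each window is left to the model code.

```python
import random


def sample_windows(data, rng, length=500, count=16):
    """Random (inputs, targets) windows of consecutive bytes for next-byte prediction."""
    if len(data) <= length:
        raise ValueError("data must be longer than the window length")
    windows = []
    for _ in range(count):
        start = rng.randrange(len(data) - length)   # leaves room for the final target byte
        chunk = data[start:start + length + 1]
        windows.append((list(chunk[:-1]), list(chunk[1:])))
    return windows
```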
LSTM networks were used with a single layer of 1500 cells and a size 256 softmax classification layer. As can be seen from Figures 22 and 21, the error rates are fairly similar with and without ACT, and across values of $\tau$ (although the learning curves suggest that the ACT networks are somewhat more data efficient). Furthermore the amount of ponder per input is much lower than for the other problems, suggesting that the advantages of extra computation were slight for this task.
However Figure 23 reveals an intriguing pattern of ponder allocation while processing a sequence. Character prediction networks trained with ACT consistently pause at spaces between words, and pause for longer at ‘boundary’ characters such as commas and full stops. We speculate that the extra computation is used to make predictions about the next ‘chunk’ in the data (word, sentence, clause), much as humans have been found to do in self-paced reading experiments [16]. This suggests that ACT could be useful for inferring implicit boundaries or transitions in sequence data. Alternative measures for inferring transitions include the next-step prediction loss and predictive entropy, both of which tend to increase during harder predictions. However, as can be seen from the figure, they are a less reliable indicator of boundaries, and are not likely to increase at points such as full stops and commas, as these are invariably followed by space characters. More generally, loss and entropy only indicate the difficulty of the current prediction, not the degree to which the current input is likely to impact future predictions.
Furthermore Figure 24 reveals that, as well as being an effective detector of non-text transition markers such as the opening brackets of XML tags, ACT does not increase computation time during random or fundamentally unpredictable sequences like the two ID numbers. This is unsurprising, as doing so will not improve its predictions. In contrast, both entropy and loss are inevitably high for unpredictable data. We are therefore hopeful that computation time will provide a better way to distinguish between structure and noise (or at least data perceived by the network as structure or noise) than existing measures of predictive difficulty.
This paper has introduced Adaptive Computation Time (ACT), a method that allows recurrent neural networks to learn how many updates to perform for each input they receive. Experiments on synthetic data prove that ACT can make otherwise inaccessible problems straightforward for RNNs to learn, and that it is able to dynamically adapt the amount of computation it uses to the demands of the data. An experiment on real data suggests that the allocation of computation steps learned by ACT can yield insight into both the structure of the data and the computational demands of predicting it.
ACT promises to be particularly interesting for recurrent architectures containing soft attention modules [2, 10, 34, 12], which it could enable to dynamically adapt the number of glances or internal operations they perform at each time-step.
One weakness of the current algorithm is that it is quite sensitive to the time penalty parameter that controls the relative cost of computation time versus prediction error. An important direction for future work will be to find ways of automatically determining and adapting the trade-off between accuracy and speed.
The author wishes to thank Ivo Danihelka, Greg Wayne, Tim Harley, Malcolm Reynolds, Jacob Menick, Oriol Vinyals, Joel Leibo, Koray Kavukcuoglu and many others on the DeepMind team for valuable comments and suggestions, as well as Albert Zeyer, Martin Abadi, Dario Amodei, Eugene Brevdo and Christopher Olah for pointing out the discontinuity in the ponder cost, which was erroneously described as smooth in an earlier version of the paper.
Universal artificial intelligence. Springer, 2005.
Hogwild: A lock-free approach to parallelizing stochastic gradient descent. In Advances in Neural Information Processing Systems, pages 693–701, 2011.