Neural networks can be represented as a graph of computational modules, and training these networks amounts to optimising the weights associated with the modules of this graph to minimise a loss. At present, training is usually performed with first-order gradient descent style algorithms, where the weights are adjusted along the direction of the negative gradient of the loss. In order to compute the gradient of the loss with respect to the weights of a module, one performs backpropagation [williams1986learning], sequentially applying the chain rule to compute the exact gradient of the loss with respect to a module. However, this scheme has many potential drawbacks, as well as lacking biological plausibility [marblestone2016toward, bengio2015towards]. In particular, backpropagation results in locking: the weights of a network module can only be updated after a full forward propagation of the data through the network, followed by loss evaluation, and finally after waiting for the backpropagation of error gradients. This locking constrains us to updating neural network modules in a sequential, synchronous manner.
One way of overcoming this issue is to apply Synthetic Gradients (SGs) to build Decoupled Neural Interfaces (DNIs) [DNI]. In this approach, models of error gradients are used to approximate the true error gradient. These models of error gradients are local to the network modules they are predicting the error gradient for, so that an update to the module can be computed by using the predicted, synthetic gradients, thus bypassing the need for subsequent forward execution, loss evaluation, and backpropagation. The gradient models themselves are trained at the same time as the modules they feed synthetic gradients to. The result is effectively a complex dynamical system composed of multiple sub-networks cooperating to minimise the loss.

DNIs have very appealing potential: the ability to distribute and parallelise training of networks across multiple GPUs and machines, the ability to asynchronously train multi-network systems, and the ability to extend the temporal modelling capabilities of recurrent networks. However, it is not clear that introducing SGs and DNIs into a learning system will not negatively impact the learning dynamics and the solutions found. While the empirical evidence in [DNI] suggests that SGs do not have a negative impact and that this potential is attainable, this paper digs deeper and analyses the result of using SGs, in order to answer accurately the question of the impact of synthetic gradients on learning systems.

In particular, we address the following questions, using feed-forward networks as our probe network architecture:

Does introducing SGs change the critical points of the neural network learning system? In our analysis of critical points we show that the critical points of the original optimisation problem are maintained when using SGs.

Can we characterise the convergence and learning dynamics for systems that use synthetic gradients in place of true gradients? Our analysis of learning dynamics gives the first convergence proofs when using synthetic gradients, together with empirical expositions of the impact of SGs on learning.

What is the difference in the representations and functional decomposition of networks learnt with synthetic gradients compared to backpropagation? Through experiments on deep neural networks, we find that while functionally the networks perform identically whether trained with backpropagation or synthetic gradients, the layer-wise functional decomposition is markedly different due to SGs.

In addition, we look at formalising the connection between SGs and other forms of approximate error propagation such as Feedback Alignment [lillicrap2016random], Direct Feedback Alignment [NIPS2016_6441, baldi2016learning], and Kickback [balduzzi2014kickback], and show that all these error approximation schemes can be captured in a unified framework, but crucially only by using synthetic gradients can one achieve unlocked training.
DNI using Synthetic Gradients
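The key idea of synthetic gradients and DNI is to approximate the true gradient of the loss with a learnt model which predicts gradients without performing full backpropagation.

Consider a feed-forward network consisting of $N$ layers $f_n$, $n \in \{1, \ldots, N\}$, each taking an input $h_{n-1}^i$ and producing an output $h_n^i = f_n(h_{n-1}^i)$, where $h_0^i = x_i$ is the input data point. A loss $L_i = L(h_N^i, y_i)$ is defined on the output of the network, where $y_i$ is the given label or supervision for $x_i$ (which comes from some unknown $P(y|x)$). Each layer $f_n$ has parameters $\theta_n$ that can be trained jointly to minimise $L_i$ with the gradient-based update rule

$$\theta_n \leftarrow \theta_n - \alpha \, \frac{\partial L(h_N^i, y_i)}{\partial h_n^i} \frac{\partial h_n^i}{\partial \theta_n}$$

where $\alpha$ is the learning rate and $\partial L_i / \partial h_n^i$ is computed with backpropagation. The reliance on $\partial L_i / \partial h_N^i$ means that an update to layer $n$ can only occur after every subsequent layer $f_j$, $j \in \{n+1, \ldots, N\}$, has been computed, the loss $L_i$ has been computed, and the error gradient $\partial L_i / \partial h_N^i$ has been backpropagated to get $\partial L_i / \partial h_n^i$. An update rule such as this is update locked as it depends on computing $L_i$, and also backwards locked as it depends on backpropagation to form $\partial L_i / \partial h_n^i$.

DNI introduces a learnt prediction of the error gradient, the synthetic gradient $SG(h_n^i, y_i) \simeq \partial L_i / \partial h_n^i$, resulting in the update

$$\theta_k \leftarrow \theta_k - \alpha \, SG(h_n^i, y_i) \frac{\partial h_n^i}{\partial \theta_k} \quad \forall k \le n$$

This approximation to the true loss gradient allows us to have both update and backwards unlocking: the update to layer $n$ can be applied without any other network computation as soon as $h_n^i$ has been computed, since the SG module is not a function of the rest of the network (unlike $\partial L_i / \partial h_n^i$). Furthermore, note that since the true $\partial L_i / \partial h_n^i$ can be described completely as a function of just $h_n^i$ and $y_i$, from a mathematical perspective this approximation is sufficiently parameterised.

The synthetic gradient module $SG(h_n^i, y_i)$ has parameters $\theta_{SG}$ which must themselves be trained to accurately predict the true gradient by minimising the L2 loss $L_{SG_i} = \| SG(h_n^i, y_i) - \partial L_i / \partial h_n^i \|^2$.

The resulting learning system consists of three decoupled parts. First, the part of the network above the SG module minimises $L$ with respect to its parameters $\{\theta_{n+1}, \ldots, \theta_N\}$. Second, the SG module minimises $L_{SG}$ with respect to $\theta_{SG}$. Finally, the part of the network below the SG module uses $SG(h, y)$ as the learning signal to train $\{\theta_1, \ldots, \theta_n\}$, and is thus minimising the loss modelled internally by SG.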
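The three decoupled parts described above can be sketched numerically. The following is a minimal NumPy illustration, not the authors' implementation: it assumes a two-layer linear network trained with MSE, a linear label-conditioned SG module of the form $hA + yB + C$, full-batch updates, and hypothetical toy data, sizes, and learning rates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy regression data: 256 points, 5 inputs, 1 target.
X = rng.normal(size=(256, 5))
Y = X @ rng.normal(size=(5, 1))

# Lower layer h = x W1 (below the SG module), upper layer p = h W2 (above it).
W1 = rng.normal(scale=0.1, size=(5, 4))
W2 = rng.normal(scale=0.1, size=(4, 1))

# Assumed linear SG module: SG(h, y) = h A + y B + C.
A = np.zeros((4, 4))
B = np.zeros((1, 4))
C = np.zeros((1, 4))


def mse(W1, W2):
    """Loss L = 0.5 * ||p - y||^2 averaged over the data set."""
    return float(0.5 * np.mean(np.sum((X @ W1 @ W2 - Y) ** 2, axis=1)))


lr, lr_sg = 0.02, 0.05  # illustrative learning rates
n = len(X)
initial_loss = mse(W1, W2)

for step in range(3000):
    h = X @ W1

    # Part 1: the lower layer is trained from the synthetic gradient only.
    sg = h @ A + Y @ B + C
    grad_W1 = X.T @ sg / n

    # Part 2: the upper layer is trained on the true loss.
    p = h @ W2
    dL_dp = p - Y
    grad_W2 = h.T @ dL_dp / n

    # Part 3: the SG module regresses onto the true gradient dL/dh (L2 loss).
    dL_dh = dL_dp @ W2.T
    err = sg - dL_dh
    grad_A = h.T @ err / n
    grad_B = Y.T @ err / n
    grad_C = err.mean(axis=0, keepdims=True)

    W1 -= lr * grad_W1
    W2 -= lr * grad_W2
    A -= lr_sg * grad_A
    B -= lr_sg * grad_B
    C -= lr_sg * grad_C

final_loss = mse(W1, W2)
print(initial_loss, final_loss)
```

In this sketch the true gradient is computed in the same loop purely for convenience; in an unlocked system the lower-layer update would be applied as soon as `h` is available, and the SG target would arrive later.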
Assumptions and notation
Throughout the remainder of this paper, we consider the use of a single synthetic gradient module at a single layer $k$ and for a generic data sample $j$, and so refer to $h = h^j = h_k^j$; unless specified we drop the superscript $j$ and subscript $k$. This model is shown in panel (b) of the model figure. We also focus on SG modules which take the point's true label/value as conditioning, $SG(h, y)$, as opposed to $SG(h)$. Note that without label conditioning, an SG module is trying to approximate not $\partial L / \partial h$ but rather $\mathbb{E}_{P(y|x)} \, \partial L / \partial h$, since $L$ is a function of both input and label. In theory, the lack of label is a sufficient parametrisation, but learning becomes harder, since the SG module has to additionally learn $P(y|x)$.

We also focus most of our attention on models that employ linear SG modules, $SG(h, y) = hA + yB + C$. Such modules have been shown to work well in practice, and furthermore are more tractable to analyse.

As a shorthand, we write $\theta_{<h}$ for the subset of the parameters contained in modules up to $h$ (and symmetrically $\theta_{>h}$), i.e. if $h$ is the $k$th layer then $\theta_{<h} = \{\theta_j\}_{j \le k}$.
Synthetic gradients in operation
Consider an $N$-layer feed-forward network with a single SG module at layer $k$. This network can be decomposed into two sub-networks: the first takes an input $x$ and produces an output $h = F_h(x) = f_k(f_{k-1}(\ldots(f_1(x))))$, while the second network takes $h$ as an input, produces an output $p = F_p(h) = f_N(\ldots(f_{k+1}(h)))$ and incurs a loss $L = L(p, y)$ based on a label $y$.

With regular backpropagation, the learning signal for the first network $F_h$ is $\partial L / \partial h$, which is a signal that specifies how the input to $F_p$ should be changed in order to reduce the loss. When we attach a linear SG between these two networks, the first sub-network $F_h$ no longer receives the exact learning signal from $F_p$, but an approximation $SG(h, y)$, which implies that $F_h$ will be minimising an approximation of the loss, because it is using approximate error gradients. Since the SG module is a linear model of $\partial L / \partial h$, the approximation of the true loss that $F_h$ is being optimised for will be a quadratic function of $h$ and $y$. Note that this is not what a second order method does when a function is locally approximated with a quadratic and used for optimisation: here we are approximating the current loss, which is a function of parameters $\theta$, with a quadratic which is a function of $h$. Three appealing properties of an approximation based on $h$ are that $h$ already encapsulates a lot of non-linearities due to the processing of $F_h$, that $h$ is usually vastly lower dimensional than $\theta_{<h}$, which makes learning more tractable, and that the error only depends on quantities ($h$) which are local to this part of the network rather than on $\theta$, which requires knowledge of the entire network.

With the SG module in place, the learning system decomposes into two tasks: the second sub-network $F_p$ is tasked with minimising $L$ given inputs $h$, while the first sub-network $F_h$ is tasked with pre-processing $x$ in such a way that the best fitted quadratic approximator of $L$ (with respect to $h$) is minimised. In addition, the SG module is tasked with best approximating $\partial L / \partial h$.

The approximations and changing of learning objectives (described above) that are imposed by using synthetic gradients may appear to be extremely limiting. However, in both the theoretical and empirical sections of this paper we show that SG models can, and do, learn solutions to highly non-linear problems (such as memorising noise).

The crucial mechanism that allows such rich behaviour is that the implicit quadratic approximation to the loss implied by the SG module is local (per data point) and non-stationary: it is continually trained itself. It is not a single quadratic fit to the true loss over the entire optimisation landscape, but a local quadratic approximation specific to each instantaneous moment in optimisation. In addition, because the quadratic approximation is a function only of $h$ and not of $\theta$, the loss approximation is still highly non-linear with respect to $\theta$.

If, instead of a linear SG module, one uses a more complex function approximator of gradients such as an MLP, the loss is effectively approximated by the integral of the MLP. More formally, the loss implied by the SG module in hypothesis space $\mathcal{H}$ is of class $\{ l : \exists g \in \mathcal{H} : \partial l / \partial h = g \}$ (we mean equality for all points where $\partial l / \partial h$ is defined). In particular, this shows an attractive mathematical benefit over predicting the loss directly: by modelling gradients rather than losses, we get to implicitly model higher order loss functions.
We now consider the effect SG has on critical points of the optimisation problem. Concretely, it seems natural to ask whether a model augmented with SG is capable of learning the same functions as the original model. We ask this question under the assumption of a locally converging training method, such that we always end up in a critical point. In the case of an SG-based model this implies a set of parameters $\theta$ such that $\partial L / \partial \theta_{>h} = 0$, $SG(h, y) \, \partial h / \partial \theta_{<h} = 0$ and $\partial L_{SG} / \partial \theta_{SG} = 0$. In other words, we are trying to establish whether SG introduces regularisation to the model class, which changes the critical points, or whether it merely introduces a modification to learning dynamics but retains the same set of critical points.

In general, the answer is positive: SG does induce a regularisation effect. However, in the presence of additional assumptions, we can show families of models and losses for which the original critical points are not affected.

Proposition 1. Every critical point of the original optimisation problem where SG can produce $\partial L / \partial h$ has a corresponding critical point of the SG-based model.

Proof. Directly from the assumption we have that there exists a set of SG parameters such that the $L_{SG}$ loss is minimal, thus $\partial L_{SG} / \partial \theta_{SG} = 0$ and also $SG(h, y) = \partial L / \partial h$ and $SG(h, y) \, \partial h / \partial \theta_{<h} = 0$.

The assumptions of this proposition are true, for example, when $L = 0$ (one attains the global minimum), when $\partial L / \partial h = 0$, or when the network is a deep linear model trained with MSE and SG is linear.

In particular, this shows that for a large enough SG module all the critical points of the original problem have a corresponding critical point in the SG-based model. Limiting the space of SG hypotheses leads to an inevitable reduction in the number of original critical points, thus acting as a regulariser. At first this might look like a somewhat negative result, since in practice we rarely use an SG module capable of exactly producing true gradients. However, there are three important observations to make:

(1) Our previous observation reflects having an exact representation of the gradient at the critical point, not in the whole parameter space.

(2) One does preserve all the critical points where the loss is zero, and given current neural network training paradigms these critical points are important. For such cases, even if SG is linear the critical points are preserved.

(3) In practice one rarely optimises to absolute convergence regardless of the approach taken; rather we obtain numerical convergence, meaning that $\| \partial L / \partial \theta \|$ is small enough. Thus, writing $e = SG(h, y) - \partial L / \partial h$ for the approximation error, all one needs from an SG-based model is a small enough $\| (\partial L / \partial h + e) \, \partial h / \partial \theta_{<h} \| \le \| \partial L / \partial \theta_{<h} \| + \| e \| \, \| \partial h / \partial \theta_{<h} \|$, implying that the approximation error at a critical point just has to be small with respect to $\| \partial h / \partial \theta_{<h} \|$ and need not be 0.

To recap: so far we have shown that SG can preserve critical points of the optimisation problem. However, SG can also introduce new critical points, leading to premature convergence and spurious additional solutions. As with our previous observation, this does not affect SG modules which are able to represent gradients exactly. But if the SG hypothesis space does not include a good approximator of the true gradient (in this case, our gradient approximation needs to be reasonable at every point through optimisation, not just at the critical ones), then we can get new critical points which end up being an equilibrium state between SG modules and the original network. We provide an example of such an equilibrium in the examples section of the Supplementary Materials.
Having demonstrated that important critical points are preserved and also that new ones might get created, we need a better characterisation of the basins of attraction, and to understand when, in both theory and practice, one can expect convergence to a good solution.
We conduct an empirical analysis of the learning dynamics on easily analysable artificial data. We create 2 and 100 dimensional versions of four basic datasets (details in the experiments section of the Supplementary Materials) and train four simple models (a linear model and a deep linear one with 10 hidden layers, trained to minimise MSE and log loss) with regular backprop and with an SG-based alternative to see whether it (numerically) converges to the same solution.

For MSE and both shallow and deep linear architectures, the SG-based model converges to the global optimum (exact numerical results are provided in the differences table of the Supplementary Materials). However, this is not the case for logistic regression. This effect is a direct consequence of a linear SG module being unable to model $\partial L / \partial p$ (where $p$ is the output of logistic regression), which often approaches the step function (when data is linearly separable) and cannot be well approximated with a linear function. Once one moves towards problems without this characteristic (e.g. random labeling) the problem vanishes, since now $\partial L / \partial p$ can be approximated much better. While this may not seem particularly significant, it illustrates an important characteristic of SG in the context of the log loss: it will struggle to overfit to training data, since doing so requires modeling step-function-type shapes, which is not possible with a linear model.

In particular this means that for best performance one should adapt the SG module architecture to the loss function used: for MSE a linear SG is a reasonable choice, whereas for log loss one should use architectures that include a sigmoid applied pointwise to a linear SG.
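The difficulty with log loss can be illustrated with a small numerical sketch. It relies on one assumption not spelled out in the text: that the quantity to be modelled is the gradient of the log loss with respect to the pre-sigmoid activation $z$, which for a binary label $y$ is $\sigma(z) - y$. The label term is linear and easy to fit, so the sketch only measures how well the best linear function of $z$ fits $\sigma(z)$ as the activations grow in scale (as they do on separable data). The function name and scales are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def linear_fit_rmse(scale):
    """RMSE of the best least-squares linear fit to sigmoid(z) on [-scale, scale]."""
    z = scale * np.linspace(-1.0, 1.0, 201)
    target = sigmoid(z)  # label-independent part of d(log loss)/dz
    design = np.stack([z, np.ones_like(z)], axis=1)
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return float(np.sqrt(np.mean((design @ coef - target) ** 2)))


rmse_soft = linear_fit_rmse(0.5)    # small activations: sigmoid is almost linear
rmse_sharp = linear_fit_rmse(20.0)  # large activations: sigmoid is close to a step
print(rmse_soft, rmse_sharp)
```

The fit error is much larger in the step-like regime, which is consistent with the observation that a linear module struggles once the model starts to fit separable data confidently.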
As described in Section 2, using a linear SG module makes the implicit assumption that the loss is a quadratic function of the activations. Furthermore, in such a setting we can actually reconstruct the loss being used up to some additive constant, since $\partial L / \partial h = hA + yB + C$ implies that $L(h) = \frac{1}{2} h A h^T + (yB + C) h^T + \mathrm{const}$. If we now construct a 2-dimensional dataset, where data points are arranged in a 2D grid, we can visualise the loss implicitly predicted by the SG module and compare it with the true loss for each point.

The artificial-data figure shows the results of such an experiment when learning a highly non-linear model (a ReLU network with 5 hidden layers). As one can see, the quality of the loss approximation has two main components to its dynamics. First, it is better in layers closer to the true loss (i.e. the topmost layers), which matches observations from [DNI] and the intuition that the lower layers solve a more complex problem (since they bootstrap their targets). Second, the loss is better approximated at the very beginning of training and the quality of the approximation degrades slowly towards the end. This is a consequence of the fact that close to the end of training, the highly non-linear model has quite complex derivatives which cannot be well represented in a space of linear functions. It is worth noting that in these experiments the quality of the loss approximation dropped significantly only when the true loss was around 0.001, so the module produced good approximations for the majority of the learning process. There is also an empirical confirmation of the previous claim that, with log loss and data that can be separated, linear SGs will have problems modeling this relation close to the end of training (artificial-data figure, panel (b), left), while there is no such problem for MSE loss (artificial-data figure, panel (a), left).
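The reconstruction of the implied loss can be checked numerically. The sketch below assumes a linear module $hA + yB + C$ with a symmetric matrix $A$ (an added assumption: only then is $hA$ exactly the gradient of $\frac{1}{2} h A h^T$), and compares a finite-difference gradient of the quadratic with the module output. All names and sizes are hypothetical.

```python
import numpy as np


def linear_sg(h, y, A, B, C):
    """Assumed linear SG module: predicted gradient with respect to h."""
    return h @ A + y @ B + C


def implied_loss(h, y, A, B, C):
    """Quadratic loss implied by the linear module, up to an additive constant."""
    return 0.5 * np.sum((h @ A) * h, axis=1) + np.sum((y @ B + C) * h, axis=1)


rng = np.random.default_rng(0)
d, k = 3, 2
M = rng.normal(size=(d, d))
A = 0.5 * (M + M.T)  # symmetric by construction
B = rng.normal(size=(k, d))
C = rng.normal(size=(1, d))
h = rng.normal(size=(1, d))
y = rng.normal(size=(1, k))

# Central finite differences of the implied loss with respect to h.
eps = 1e-6
num_grad = np.zeros_like(h)
for i in range(d):
    step = np.zeros_like(h)
    step[0, i] = eps
    upper = implied_loss(h + step, y, A, B, C)[0]
    lower = implied_loss(h - step, y, A, B, C)[0]
    num_grad[0, i] = (upper - lower) / (2 * eps)

max_abs_diff = float(np.max(np.abs(num_grad - linear_sg(h, y, A, B, C))))
print(max_abs_diff)
```

Evaluating `implied_loss` on a 2D grid of activations would give the kind of surface that the text compares against the true loss.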
It is trivial to note that if the SG module used is globally convergent to the true gradient, and we only use its output after it converges, then the whole model behaves like one trained with regular backprop. However, in practice we never do this, and instead train the two models in parallel without waiting for convergence of the SG module. We now discuss some of the consequences of this, and begin by showing that as long as the synthetic gradient produced is close enough to the true one we still get convergence to the true critical points. Namely, if the error introduced by SG, backpropagated to all the parameters, is consistently smaller than the norm of the true gradient multiplied by some positive constant smaller than one, the whole system converges. Thus, we essentially need the SG error to vanish around critical points.

Proposition 2. Let us assume that an SG module is trained in each iteration in such a way that it $\epsilon$-tracks the true gradient, i.e. that $\| SG(h, y) - \partial L / \partial h \| \le \epsilon$. If $\| \partial h / \partial \theta_{<h} \|$ is upper bounded by some $K$ and there exists a constant $\delta \in (0, 1)$ such that in every iteration $\epsilon K \le \| \partial L / \partial \theta_{<h} \| \frac{1 - \delta}{1 + \delta}$, then the whole training process converges to the solution of the original problem.

Proof. The proof follows from showing that, under the assumptions, we are effectively training with noisy gradients, where the noise is small enough for the convergence guarantees given by [zoutendijk1970nonlinear, gratton2011much] to apply. Details are provided in the proofs section of the Supplementary Materials.

As a consequence of Proposition 2 we can show that with specifically chosen learning rates (not merely ones that are small enough) we obtain convergence for deep linear models.

Corollary. For a deep linear model minimising MSE, trained with a linear SG module attached between two of its hidden layers, there exist learning rates in each iteration such that it converges to the critical point of the original problem.

Proof. The proof follows directly from Propositions 1 and 2. The full proof is given in the proofs section of the Supplementary Materials.
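The intuition behind training with noisy gradients can be illustrated with a toy sketch. It is not the setting of the proposition: it assumes a simple quadratic objective $\frac{1}{2}\|\theta\|^2$ and an arbitrary perturbation whose norm is a fixed fraction (smaller than one) of the true gradient norm, standing in for the SG error. The objective, step size, and error fraction are hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=10)
lr = 0.1
rel_err = 0.5  # perturbation norm as a fraction of the true gradient norm

initial_grad_norm = float(np.linalg.norm(theta))
for _ in range(500):
    grad = theta  # true gradient of 0.5 * ||theta||^2
    noise = rng.normal(size=theta.shape)
    # Rescale so that ||noise|| = rel_err * ||grad||: the error shrinks with the gradient.
    noise *= rel_err * np.linalg.norm(grad) / np.linalg.norm(noise)
    theta = theta - lr * (grad + noise)

final_grad_norm = float(np.linalg.norm(theta))
print(initial_grad_norm, final_grad_norm)
```

Because the perturbation is bounded relative to the gradient, each step contracts the parameter norm by at least a constant factor, so the error vanishes as the critical point is approached.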
For a shallow model we can guarantee convergence to the global solution provided we have a small enough learning rate, which is the main theoretical result of this paper.

Theorem. Let us consider linear regression trained with a linear SG module attached between its output and the loss. If one chooses the learning rate of the SG module using line search, then in every iteration there exists a small enough, positive learning rate of the main network such that it converges to the global solution.

Proof. The general idea (full proof in the proofs section of the Supplementary Materials) is to show that with the assumed learning rates the sum of norms of the network error and the SG error decreases in every iteration.

Despite covering a quite limited class of models, these are the very first convergence results for SG-based learning. Unfortunately, they do not seem to generalise easily to the non-linear cases, which we leave for future research.
We now shift our attention to more realistic data. We train deep ReLU networks of varied depth (up to 50 hidden layers) with batch-normalisation and with two different activation functions on MNIST, and compare models trained with full backpropagation to variants that employ an SG module in the middle of the hidden stack.
The corresponding figure shows that SG-based architectures converge well even if there are many hidden layers both below and above the module. Interestingly, SG-based models actually seem to converge faster (compare for example the 20- or 50-layer deep ReLU networks). We believe this may be due to some amount of loss function smoothing since, as described in Section 2, a linear SG module effectively models the loss function as quadratic; thus the lower network has a simpler optimisation task and makes faster learning progress.

Obtaining similar errors on MNIST does not necessarily mean that the trained models are the same or even similar. Since the use of synthetic gradients can alter learning dynamics and introduce new critical points, they might converge to different types of models. Assessing the representational similarity between different models is difficult, however. One approach is to compute and visualise Representational Dissimilarity Matrices [kriegeskorte2008representational] for our data. We sample a subset of 400 points from MNIST, order them by label, and then record activations on each hidden layer when the network is presented with these points. We plot the pairwise correlation matrix for each layer, as shown in the corresponding figure. This representation is permutation invariant, and thus the emergence of a block-diagonal correlation matrix means that at a given layer, points from the same class already have very correlated representations.

Under such visualisations one can notice qualitative differences between the representations developed under standard backpropagation training versus those delivered by an SG-based model. In particular, in the MNIST model with 20 hidden layers trained with standard backpropagation we see that the representation covariance after 9 layers is nearly the same as the final layer's representation. By contrast, if we consider the same architecture but with an SG module in the middle, we see that the layers before the SG module develop a qualitatively different style of representation. Note: this does not mean that layers before SG do not learn anything useful. To confirm this, we also introduced linear classifier probes [alain2016understanding] and observed that, as with the pure backpropagation trained model, such probes can achieve 100% training accuracy after the first two hidden layers of the SG-based model, as shown in the probes figure of the Supplementary Materials. With 20 SG modules (one between every pair of layers), the representation is scattered even more broadly: we see rather different learning dynamics, with each layer contributing a small amount to the final solution, and there is no longer a point in the progression of layers where the representation is more or less static in terms of correlation structure.

Another way to investigate whether the trained models are qualitatively similar is to examine the norms of the weight matrices connecting consecutive hidden layers, and to assess whether the general shapes of such norms are similar. While this does not definitively say anything about how much of the original classification is being solved in each hidden layer, it is a reasonable surrogate for how much computation is being performed in each layer. (We train with a small L2 penalty added to the weights to make the norm correspond roughly to the amount of computation.)
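The label-ordered correlation matrix used in the representational analysis above can be computed in a few lines. The sketch below is illustrative only: it uses synthetic activations with three well-separated hypothetical classes instead of recorded MNIST activations, and the helper name is an assumption.

```python
import numpy as np


def label_ordered_correlation(activations, labels):
    """Pairwise correlation between per-example activation vectors, sorted by label."""
    order = np.argsort(labels, kind="stable")
    # np.corrcoef treats each row as one variable, i.e. one example here.
    return np.corrcoef(activations[order]), labels[order]


rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=60)
centres = rng.normal(size=(3, 32))  # one hypothetical prototype per class
activations = centres[labels] + 0.1 * rng.normal(size=(60, 32))

corr, sorted_labels = label_ordered_correlation(activations, labels)

# A block-diagonal structure shows up as high within-class, low between-class correlation.
same_class = sorted_labels[:, None] == sorted_labels[None, :]
within = float(corr[same_class].mean())
between = float(corr[~same_class].mean())
print(corr.shape, within, between)
```

Plotting `corr` as an image for each hidden layer gives the kind of visualisation discussed in the text; the block structure emerges at the layer where same-class points become strongly correlated.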
According to our experiments (see the corresponding figure for a visualisation of one of the runs), models trained with backpropagation on MNIST tend to have norms slowly increasing towards the output of the network (with some fluctuations and differences coming from activation functions, random initialisations, etc.). If we now put an SG in between every two hidden layers, we get norms that start high and then decrease towards the output of the network (with much more variance now). Finally, if we have a single SG module we can observe that the behaviour after the SG module resembles, at least to some degree, the distribution of norms obtained with backpropagation, while before the SG it is more chaotic, with some similarities to the distribution of weights with SGs in between every two layers.

These observations match the results of the previous experiment and the qualitative differences observed. When synthetic gradients are used to deliver full unlocking we obtain a very basic model at the lowest layers and then see iterative corrections in deeper layers. For a one-point unlocked model with a single SG module, we have two slightly separated models where one behaves similarly to backprop, and the other supports it. Finally, a fully locked model (i.e. traditional backprop) solves the task relatively early on, and later just increases its confidence.

We note that the results of this section support our previous notion that we are effectively dealing with a multi-agent system, which looks for coordination/equilibrium between components, rather than a single model which simply has some small noise injected into the gradients (and this is especially true for more complex models).
SG and conspiring networks
Table: Unified view of the gradient-approximation methods discussed in this section. $B$ is a fixed, random matrix and $\mathbb{1}$ is a matrix of ones of an appropriate dimension. * In SG+Prop the network is locked if there is a single SG module; however, if we have multiple ones, then propagating the error signal only locks a module with the next one, not with the entire network. Direct error SG means that a model tries to solve the classification problem directly at layer $h$.
We now shift our attention and consider a unified view of several different learning principles that work by replacing true gradients with surrogates. We focus on three such approaches: Feedback Alignment (FA) [lillicrap2016random], Direct Feedback Alignment (DFA) [NIPS2016_6441], and Kickback (KB) [balduzzi2014kickback]. FA effectively uses a fixed random matrix during backpropagation, rather than the transpose of the weight matrix used in the forward pass. DFA does the same, except each layer directly uses the learning signal from the output layer rather than the subsequent local one. KB also pushes the output learning signal directly, but through a predefined matrix instead of a random one. By making appropriate choices for targets, losses, and model structure we can cast all of these methods in the SG framework, and view them as comprising two networks with an SG module in between them, wherein the first module builds a representation which makes the task of the SG predictions easier.

We begin by noting that in the SG models described thus far we do not backpropagate the SG error back into the part of the main network preceding the SG module (i.e. we assume $\partial L_{SG} / \partial h = 0$). However, if we relax this restriction, we can use this signal (perhaps with some scaling factor) and obtain what we will refer to as an SG+Prop model. Intuitively, this additional learning signal adds capacity to our SG model and forces both the main network and the SG module to "conspire" towards a common goal of making better gradient predictions. From a practical perspective, according to our experiments, this additional signal heavily stabilises the learning system. (In fact, ignoring the gradients predicted by SG and only using the derivative of the SG loss, i.e. $\partial L_{SG} / \partial h$, still provides enough learning signal to converge to a solution for the original task in the simple classification problems we considered. We posit a simple rationale for this: if one can predict gradients well using a simple transformation of network activations (e.g. a linear mapping), this suggests that the loss itself can be predicted well too, and thus (implicitly) so can the correct outputs.) However, this comes at the cost of no longer being unlocked.

Our main observation in this section is that FA, DFA, and KB can be expressed in the language of "conspiring" networks (see the table summarising these methods), i.e. of two-network systems that use an SG module. The only difference between these approaches is how one parametrises SG and what target we attempt to fit it to. This comes directly from the construction of these systems, and the fact that if we treat our targets as constants (as we do in SG methods), then the backpropagated error from each SG module ($\partial L_{SG} / \partial h$) matches the prescribed update rule of each of these methods. One direct result from this perspective is the fact that Kickback is essentially DFA with the fixed random matrix $B$ replaced by a matrix of ones, $B = \mathbb{1}$.

For completeness, we note that regular backpropagation can also be expressed in this unified view. To do so, we construct an SG module such that the gradients it produces attempt to align the layer activations with the negation of the true learning signal ($-\partial L / \partial h$). In addition to unifying several different approaches, our mapping also illustrates the potential utility and diversity of the generic idea of predicting gradients.
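The surrogate learning signals compared in this section can be written side by side. The sketch below follows the verbal descriptions in the text for a purely linear stack (non-linearities and their derivatives are omitted for brevity) and is an illustration under those assumptions rather than any method's reference implementation. All matrix names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_h, d_g, d_p = 4, 6, 5, 3  # batch size and layer widths: h -> g -> p

# Output learning signal dL/dp, e.g. p - y for MSE.
e = rng.normal(size=(n, d_p))

# Forward weights of the two layers above h.
W2 = rng.normal(size=(d_h, d_g))
W3 = rng.normal(size=(d_g, d_p))

# Fixed random feedback matrices (never trained).
B2 = rng.normal(size=(d_g, d_h))
B3 = rng.normal(size=(d_p, d_g))
B_direct = rng.normal(size=(d_p, d_h))

# Backpropagation: transposes of the forward weights.
signal_bp = e @ W3.T @ W2.T

# Feedback Alignment: same layer-by-layer path, random matrices instead of transposes.
signal_fa = e @ B3 @ B2

# Direct Feedback Alignment: output signal sent straight to h through one random matrix.
signal_dfa = e @ B_direct

# Kickback (as described in the text): direct path through a predefined matrix of ones.
signal_kb = e @ np.ones((d_p, d_h))

for name, s in [("bp", signal_bp), ("fa", signal_fa), ("dfa", signal_dfa), ("kb", signal_kb)]:
    print(name, s.shape)
```

Setting `B_direct` to a matrix of ones turns the DFA signal into the Kickback one, which mirrors the observation that Kickback is a special case of DFA in this view.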
This paper has presented new theory and analysis for the behaviour of synthetic gradients in feed-forward models. Firstly, we showed that introducing SG does not necessarily change the critical points of the original problem; however, at the same time it can introduce new critical points into the learning process. This is an important result showing that SG does not act like a typical regulariser despite simplifying the error signals. Secondly, we showed that (despite modifying learning dynamics) SG-based models converge to analogous solutions to the true model under some additional assumptions. We proved exact convergence for a simple class of models, and for more complex situations we demonstrated that the implicit loss model captures the characteristics of the true loss surface. It remains an open question how to characterise the learning dynamics in more general cases. Thirdly, we showed that despite these convergence properties the trained networks can be qualitatively different from the ones trained with backpropagation. While not necessarily a drawback, this is an important consequence one should be aware of when using synthetic gradients in practice. Finally, we provided a unified framework that can be used to describe alternative learning methods such as Synthetic Gradients, FA, DFA, and Kickback, as well as standard backprop. The approach taken shows that the language of predicting gradients is surprisingly universal and provides additional intuitions and insights into the models.