
A unified theory of adaptive stochastic gradient descent as Bayesian filtering

by Laurence Aitchison, et al.

There is a diverse array of schemes for adaptive stochastic gradient descent for optimizing neural networks, from fully factorised methods with and without momentum (e.g. RMSProp and ADAM), to Kronecker-factored methods that consider the Hessian for a full weight matrix. However, these schemes have been derived and justified using a wide variety of mathematical approaches, and as such, there is no unified theory of adaptive stochastic gradient descent methods. Here, we provide such a theory by showing that many successful adaptive stochastic gradient descent schemes emerge by considering filtering-based inference in a Bayesian optimization problem. In particular, we use backpropagated gradients to compute a Gaussian posterior over the optimal neural network parameters, given the data minibatches seen so far. Our unified theory is able to give some guidance to practitioners on how to choose between the large number of available optimization methods. In the fully factorised setting, we recover RMSProp and ADAM under different priors, along with additional improvements such as Nesterov acceleration and AdamW. Moreover, we obtain new recommendations, including the possibility of combining RMSProp and ADAM updates. In the Kronecker-factored setting, we obtain an adaptive natural gradient scheme that is derived specifically for the minibatch setting. Furthermore, under a modified prior, we obtain a Kronecker-factored analogue of RMSProp or ADAM that preconditions the gradient by whitening (i.e. by multiplying by the square root of the Hessian, as in RMSProp/ADAM). Our work raises the hope that it is possible to achieve a unified theoretical understanding of empirically successful adaptive gradient descent schemes for neural networks.



1 Introduction and Background

Neural network optimization methods fall into two broad classes: non-adaptive and adaptive. The canonical non-adaptive method is vanilla stochastic gradient descent (SGD) with momentum, which updates the parameters, $w$, by multiplying the exponential moving average gradient, $\langle g(t) \rangle$, by a learning rate, $\eta_\text{SGD}$,

$$\Delta w_\text{SGD}(t) = -\eta_\text{SGD} \frac{\langle g(t) \rangle}{\text{minibatch size}}. \tag{1}$$

Here, we divide by the minibatch size because we define $g(t)$ to be the gradient of the summed loss, whereas common practice is to use the gradient of the mean loss. Further, following the convention established by Adam (Kingma & Ba, 2015), $\langle g(t) \rangle$ is computed by debiasing a raw exponential moving average, $m(t)$,

$$m(t) = \beta_1 m(t-1) + (1 - \beta_1) g(t), \qquad \langle g(t) \rangle = \frac{m(t)}{1 - \beta_1^t}, \tag{2}$$

where $g(t)$ is the raw minibatch gradient, and $\beta_1$ is usually chosen to be 0.9. These methods are typically found to give excellent generalisation performance, and as such are used to train many state-of-the-art networks (e.g. ResNet (He et al., 2016), DenseNet (Huang et al., 2017), ResNeXt (Xie et al., 2017)).

Adaptive methods change the learning rates as a function of past gradients. These methods were initially introduced with vario-eta (Neuneier & Zimmermann, 1998), and many variants have recently been developed, including AdaGrad (Duchi et al., 2011), RMSprop (Hinton et al., 2012) and Adam (Kingma & Ba, 2015). The canonical adaptive method, Adam, normalises the exponential moving average gradient by the root mean square of past gradients,

$$\Delta w_\text{Adam}(t) = -\eta_\text{Adam} \frac{\langle g(t) \rangle}{\sqrt{\langle g^2(t) \rangle}}, \tag{3}$$

where

$$v(t) = \beta_2 v(t-1) + (1 - \beta_2) g^2(t), \qquad \langle g^2(t) \rangle = \frac{v(t)}{1 - \beta_2^t}, \tag{4}$$

and where $\beta_2$ is typically chosen to be 0.999. These methods are often observed to converge faster, and hence may be used on problems which are more difficult to optimize (Graves, 2013), but can give worse generalisation performance than non-adaptive methods (Keskar & Socher, 2017; Loshchilov & Hutter, 2017; Wilson et al., 2017; Luo et al., 2019).
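The two canonical updates described above can be sketched for a single scalar parameter as follows. This is a minimal illustration rather than the authors' code: the function names, the default values and the small constant `eps` are assumptions, and `g` is taken to be the gradient of the summed minibatch loss, so each step moves against it.

```python
import math


def debiased_ema(raw, value, beta, t):
    """Update a raw exponential moving average and return (raw, debiased)."""
    raw = beta * raw + (1.0 - beta) * value
    return raw, raw / (1.0 - beta ** t)


def sgd_momentum_step(w, m, g, t, lr=0.1, beta1=0.9, batch_size=1):
    """SGD with momentum; g is the gradient of the summed minibatch loss."""
    m, g_avg = debiased_ema(m, g, beta1, t)
    return w - lr * g_avg / batch_size, m


def adam_step(w, m, v, g, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam; eps is the usual small constant for numerical stability."""
    m, g_avg = debiased_ema(m, g, beta1, t)
    v, g2_avg = debiased_ema(v, g * g, beta2, t)
    return w - lr * g_avg / (math.sqrt(g2_avg) + eps), m, v


# At t = 1 the debiased averages equal the current gradient (up to rounding),
# so SGD moves by lr * g while Adam moves by roughly lr * sign(g).
w_sgd, _ = sgd_momentum_step(1.0, 0.0, 2.0, t=1)
w_adam, _, _ = adam_step(1.0, 0.0, 0.0, 2.0, t=1)
```

The step counter `t` starts at 1, so the debiasing denominators are never zero.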

Having these two distinct classes of method raises considerable problems at both a practical and a theoretical level. At the practical level, it dramatically increases the size of the space that we must search over to find the optimal hyperparameters, especially as we may need to consider switching from one method to another (Keskar & Socher, 2017), or using one method for some parameters and a different method for other parameters. At the theoretical level, there have been attempts to explain the effectiveness of adaptive and non-adaptive methods. For instance, it has been argued that non-adaptive methods approximate Langevin sampling (Mandt et al., 2017) whereas adaptive methods approximate natural-gradient updates (Zhang et al., 2017; Khan et al., 2017, 2018). However, to understand whether to use adaptive or non-adaptive methods, we need new theory that simultaneously addresses both cases.

Here we provide such a theory by reconciling state-of-the-art adaptive SGD algorithms with very early work that used Bayesian (Kalman) filtering to optimize the parameters of neural networks (Puskorius & Feldkamp, 1991; Sha et al., 1992; Puskorius & Feldkamp, 1994, 2001; Feldkamp et al., 2003; Ollivier, 2017). We begin by formulating the problem of neural network optimization as Bayesian inference, which gives our first algorithm, AdaBayes. Next, we develop a second algorithm, AdaBayes-SS, by considering the steady-state learning rates in AdaBayes; this algorithm recovers non-adaptive (SGD) and adaptive (AdamW; Loshchilov & Hutter, 2017) methods in the low- and high-data limits, respectively. Finally, we compare the performance of AdaBayes and AdaBayes-SS to standard baselines including SGD and Adam, and newer methods such as AdamW (Loshchilov & Hutter, 2017) and Ada/AMSBound (Luo et al., 2019).

2 Methods

Typically, to set up the problem of neural network optimization as Bayesian inference, we assume that there is some fixed, “true” set of neural network parameters, $w^*$, and infer a joint distribution over those parameters conditioned on the training data, $P(w^* | \mathcal{D})$. However, in the domain of neural networks, the large number of parameters forces us to use extremely strong (usually factorised) approximations to the true posterior. To simplify reasoning about factorised approximations, we choose to infer the distribution over a single parameter, $w_i^*$, conditioned on the current setting of the other parameters, $\mu_{-i}(t)$, which change over time as they are optimized. In this conditional framework, we can use Bayes' theorem to update the posterior with an additional minibatch of data, $d(t)$,

$$P(w_i^* | \mathcal{D}(t), \mu_{-i}(t)) \propto P(d(t) | w_i^*, \mu_{-i}(t)) \, P(w_i^* | \mathcal{D}(t-1), \mu_{-i}(t)), \tag{5}$$

where $\mathcal{D}(t)$ is all minibatches up to time $t$. We can perform this update using reasonably straightforward approximations (see below). However, this form raises a critical question: how to transform from $P(w_i^* | \mathcal{D}(t-1), \mu_{-i}(t-1))$, which conditions on $\mu_{-i}(t-1)$ and is the output of Eq. (5) at the previous time step, into $P(w_i^* | \mathcal{D}(t-1), \mu_{-i}(t))$, which conditions on $\mu_{-i}(t)$ and is required as input to Eq. (5) at this time step. These distributions are fundamentally different, as they are conditioned on different quantities, and probability theory does not offer an easy way to transform between them. All we can do is to recompute the likelihood for all past minibatches of data each time $\mu_{-i}(t)$ changes, which is clearly computationally intractable. Instead, we approximate the distribution over $w_i^*$ conditioned on $\mu_{-i}(t)$ by noting that it should be similar to, but slightly broader than, the distribution conditioned on $\mu_{-i}(t-1)$ (as $\mu_{-i}(t)$ and $\mu_{-i}(t-1)$ are similar, but the data is based on earlier values of $\mu_{-i}$). As such, we approximate the distribution conditioned on $\mu_{-i}(t)$ by convolving the distribution conditioned on $\mu_{-i}(t-1)$ with a narrow Gaussian kernel,

$$P(w_i^* | \mathcal{D}(t-1), \mu_{-i}(t)) = \int dw' \, \mathcal{N}\!\left(w_i^*; \left(1 - \tfrac{\eta^2}{2\sigma^2}\right) w', \eta^2\right) P(w_i^* = w' | \mathcal{D}(t-1), \mu_{-i}(t-1)), \tag{6}$$

where $\mathcal{N}(x; m, s^2)$ is the p.d.f. of $x$ under a Gaussian distribution with mean $m$ and variance $s^2$. Note that we set up the distribution with a small decay in the mean, such that, after many applications of this operator, the distribution over $w_i^*$ converges to a Gaussian with zero mean and variance $\sigma^2$, which constitutes a prior for $w_i^*$,

$$P(w_i^*) = \mathcal{N}(w_i^*; 0, \sigma^2). \tag{7}$$

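We can do inference under these updates by applying Eq. (5) and Eq. (6). In particular, as we will use a Gaussian approximation to the likelihood, we have,

$$P(w_i^* | \mathcal{D}(t-1), \mu_{-i}(t)) = \mathcal{N}(w_i^*; \mu_\text{prior}(t), \sigma^2_\text{prior}(t)), \tag{8}$$

$$P(w_i^* | \mathcal{D}(t), \mu_{-i}(t)) = \mathcal{N}(w_i^*; \mu_\text{post}(t), \sigma^2_\text{post}(t)), \tag{9}$$

where the updates for $\mu_\text{prior}$ and $\sigma^2_\text{prior}$ can be computed from Eq. (6),

$$\mu_\text{prior}(t) = \left(1 - \tfrac{\eta^2}{2\sigma^2}\right) \mu_\text{post}(t-1), \tag{10a}$$

$$\sigma^2_\text{prior}(t) = \left(1 - \tfrac{\eta^2}{2\sigma^2}\right)^2 \sigma^2_\text{post}(t-1) + \eta^2, \tag{10b}$$
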
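and the updates for $\mu_\text{post}$ and $\sigma^2_\text{post}$ come from approximating Eq. (5). In particular, we approximate the log-likelihood using a second-order Taylor expansion,

$$\log P(d(t) | w_i^*, \mu_{-i}(t)) \approx -\tfrac{1}{2} \Lambda_\text{like} \left(w_i^* - \mu_\text{like}\right)^2 + \text{const}, \tag{11}$$

where we identify the coefficients $\mu_\text{like}$ and $\Lambda_\text{like}$ using the minibatch gradient computed as usual by our automatic differentiation framework,

$$g(t) = -\frac{\partial}{\partial w_i^*} \log P(d(t) | w_i^*, \mu_{-i}(t)) \Big|_{w_i^* = \mu_\text{prior}(t)} = \Lambda_\text{like} \left(\mu_\text{prior}(t) - \mu_\text{like}\right), \tag{12}$$

and we use a Fisher-Information based estimate of $\Lambda_\text{like}$ (as in Zhang et al., 2017; Khan et al., 2017, 2018),

$$\Lambda_\text{like} \approx g^2(t). \tag{13}$$

Thus, we obtain,

$$\frac{1}{\sigma^2_\text{post}(t)} = \frac{1}{\sigma^2_\text{prior}(t)} + g^2(t), \tag{14a}$$

$$\mu_\text{post}(t) = \mu_\text{prior}(t) - \sigma^2_\text{post}(t) \, g(t). \tag{14b}$$
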
The full updates are now specified by iteratively applying Eq. (10) and Eq. (14). Note that the updates are analogous to Kalman filtering updates (Kalman, 1960), but in our case effective dynamics emerge implicitly through changes in the distribution over $w_i^*$ induced by changes in the other parameters, $\mu_{-i}(t)$, rather than being built into the problem definition.
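One filtering step for a single scalar parameter can be sketched as below. This is an illustrative sketch under stated assumptions, not the authors' implementation: it assumes that the kernel has mean decay $1 - \eta^2/(2\sigma^2)$ and variance $\eta^2$, that the likelihood precision is approximated by the squared gradient, and that `g` is the gradient of the minibatch loss (the negative log-likelihood) at the prior mean. The function names `predict` and `update` are hypothetical.

```python
def predict(mu_post, var_post, eta, sigma2):
    """Broaden and decay the previous posterior to get this step's prior."""
    decay = 1.0 - eta ** 2 / (2.0 * sigma2)
    return decay * mu_post, decay ** 2 * var_post + eta ** 2


def update(mu_prior, var_prior, g):
    """Gaussian update with likelihood precision approximated by g**2."""
    var_post = 1.0 / (1.0 / var_prior + g * g)
    return mu_prior - var_post * g, var_post


# One full step: predict, then condition on a minibatch gradient.
mu_prior, var_prior = predict(0.5, 1e-3, eta=1e-3, sigma2=1e-3)
mu_post, var_post = update(mu_prior, var_prior, g=2.0)

# With no data (predict steps only) the belief relaxes towards the
# zero-mean prior with variance close to sigma2.
mean, var = 1.0, 0.0
for _ in range(20000):
    mean, var = predict(mean, var, eta=1e-3, sigma2=1e-3)
```

The second part of the example illustrates why the small decay in the mean matters: without it, repeated broadening would let the variance grow without bound instead of settling near $\sigma^2$.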

Figure 1: The learning rate for AdaBayes (points) compared against the predicted steady-state value (green line), $\sigma^2_\text{post}$. The plot displays the low-data limit (orange line), which is valid when the value on the x-axis, $\eta / \sqrt{\langle g^2 \rangle}$, is much greater than $\sigma^2$ (purple line), and the high-data limit (blue line), which is valid when the value on the x-axis is much smaller than $\sigma^2$ (purple line).

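Next, we make two changes to the updates for the mean, to match current best practice for optimizing neural networks. First, we allow more flexibility in weight decay, by replacing the $\eta^2 / (2\sigma^2)$ term in Eq. (10a) with a new parameter, $\lambda$. Second, we incorporate momentum, by using an exponential moving average gradient, $\langle g(t) \rangle$, instead of the raw minibatch gradient in Eq. (14b). In combination, the updates for the mean become,

$$\mu_\text{prior}(t) = (1 - \lambda) \, \mu_\text{post}(t-1), \tag{15a}$$

$$\mu_\text{post}(t) = \mu_\text{prior}(t) - \sigma^2_\text{post}(t) \, \langle g(t) \rangle. \tag{15b}$$

Our complete Bayesian updates are now given by using Eq. (15) to update $\mu_\text{prior}$ and $\mu_\text{post}$, and using Eq. (10b) and Eq. (14a) to update $\sigma^2_\text{prior}$ and $\sigma^2_\text{post}$ (see Algo. 1).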

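Finally, this derivation highlights the need to reason about the effective dynamics that emerge from changes in the other parameters, $\mu_{-i}(t)$. In particular, if there were no effective dynamics (i.e. if $\eta^2 = 0$, and hence $\sigma^2_\text{prior}(t) = \sigma^2_\text{post}(t-1)$), then the learning rates would shrink too rapidly (as $1/t$),

$$\sigma^2_\text{post}(t) = \frac{1}{1/\sigma^2 + \sum_{t'=1}^{t} g^2(t')} \approx \frac{1}{t \, \langle g^2 \rangle} \quad \text{for large } t. \tag{16}$$
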
2.1 AdaBayes-SS, and recovering SGD and Adam

While we can (and do) run AdaBayes directly, we also consider a steady-state variant denoted AdaBayes-SS. We find that this variant has interesting performance characteristics in and of itself (see Experiments), and it allows us to recover standard adaptive and non-adaptive algorithms. This variant uses the same updates for the mean (Eq. 15), but uses a different value for $\sigma^2_\text{post}$. In particular, instead of running the usual updates for $\sigma^2_\text{post}$ (Eq. 10b and Eq. 14a), AdaBayes-SS (Algo. 2) uses the steady-state value of $\sigma^2_\text{post}$ implied by those equations (i.e. the expected value as $t \to \infty$). To obtain the steady state, we substitute the update for $\sigma^2_\text{post}$ (Eq. 14a) into the update for $\sigma^2_\text{prior}$ (Eq. 10b), and neglect small terms (see Appendix A), which tells us that $\sigma^2_\text{post}$ is given by the solution of a quadratic equation,

$$0 \approx \frac{\langle g^2 \rangle}{\eta^2} \sigma^4_\text{post} + \frac{1}{\sigma^2} \sigma^2_\text{post} - 1. \tag{17}$$

Solving for $\sigma^2_\text{post}$, we obtain,

$$\sigma^2_\text{post} \approx \frac{\eta^2}{2 \sigma^2 \langle g^2 \rangle} \left( \sqrt{1 + 4 \frac{\sigma^4}{\eta^2} \langle g^2 \rangle} - 1 \right). \tag{18}$$

To confirm that this steady-state expression is correct, we plotted the values for $\sigma^2_\text{post}$ given by AdaBayes (Fig. 1 points) and found a good match to the values suggested by this steady-state expression (green line).
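The same comparison can be approximated numerically without a network. The sketch below is an assumption-laden illustration: it takes the variance recursion to be $\sigma^2_\text{prior} = (1 - \eta^2/(2\sigma^2))^2 \sigma^2_\text{post} + \eta^2$ followed by $1/\sigma^2_\text{post} = 1/\sigma^2_\text{prior} + g^2$, holds the squared gradient fixed at a constant `g2`, and compares the result with the positive root of the quadratic $(\langle g^2 \rangle / \eta^2) s^2 + s/\sigma^2 - 1 = 0$. The function names and parameter values are hypothetical.

```python
import math


def steady_state_var(g2, eta, sigma2):
    """Positive root of (g2/eta^2) s^2 + s/sigma2 - 1 = 0, written in a
    numerically stable form (no subtraction of nearly equal numbers)."""
    return 2.0 * sigma2 / (1.0 + math.sqrt(1.0 + 4.0 * sigma2 ** 2 * g2 / eta ** 2))


def iterated_var(g2, eta, sigma2, steps=20000):
    """Run the variance recursion with a constant squared gradient g2."""
    decay2 = (1.0 - eta ** 2 / (2.0 * sigma2)) ** 2
    var_post = sigma2
    for _ in range(steps):
        var_prior = decay2 * var_post + eta ** 2
        var_post = 1.0 / (1.0 / var_prior + g2)
    return var_post


eta, sigma2 = 1e-3, 1e-3
for g2 in (1e-4, 1.0, 1e2):
    closed = steady_state_var(g2, eta, sigma2)
    iterated = iterated_var(g2, eta, sigma2)
    print(f"g2={g2:g}  closed form={closed:.4e}  iterated={iterated:.4e}")
```

The closed form is only approximate because small terms were neglected in deriving it, so the two values agree closely but not exactly. For very small `g2` the closed form tends to `sigma2`, and for very large `g2` it tends to `eta / sqrt(g2)`.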

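Algorithm 1: AdaBayes. While not converged, apply the mean updates (Eq. 15) and the variance updates (Eq. 10b and Eq. 14a).
Algorithm 2: AdaBayes-SS. While not converged, apply the mean updates (Eq. 15) using the steady-state value of the posterior variance.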

The full listings for the AdaBayes and AdaBayes-SS algorithms. We compute the raw first and second moment exponential moving averages in $m$ and $v$, and debias them to obtain $\langle g \rangle$ and $\langle g^2 \rangle$, as in Adam (Kingma & Ba, 2015).
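Since the listings themselves did not survive in this copy, the following is a hedged reconstruction of both algorithms for a single scalar parameter, assembled from the updates described in the text rather than taken from the original listings. Assumptions: the gradient is that of the loss evaluated at the current mean, the variance update in AdaBayes uses the debiased $\langle g^2 \rangle$, the posterior variance is initialised at $\sigma^2$, and the toy quadratic problem and all parameter values are hypothetical.

```python
import math


def moments(m, v, g, t, beta1=0.9, beta2=0.999):
    """Raw EMAs m, v and their debiased versions <g>, <g^2>, as in Adam."""
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    return m, v, m / (1.0 - beta1 ** t), v / (1.0 - beta2 ** t)


def adabayes_step(mu, var_post, m, v, g, t, eta, sigma2, lam):
    """AdaBayes sketch: the posterior variance is tracked step by step."""
    m, v, g_avg, g2_avg = moments(m, v, g, t)
    var_prior = (1.0 - eta ** 2 / (2.0 * sigma2)) ** 2 * var_post + eta ** 2
    var_post = 1.0 / (1.0 / var_prior + g2_avg)  # assumption: uses <g^2>
    mu = (1.0 - lam) * mu - var_post * g_avg
    return mu, var_post, m, v


def adabayes_ss_step(mu, m, v, g, t, eta, sigma2, lam):
    """AdaBayes-SS sketch: the variance is set to its steady-state value."""
    m, v, g_avg, g2_avg = moments(m, v, g, t)
    root = math.sqrt(1.0 + 4.0 * sigma2 ** 2 * g2_avg / eta ** 2)
    var_post = 2.0 * sigma2 / (1.0 + root)
    mu = (1.0 - lam) * mu - var_post * g_avg
    return mu, m, v


# Toy problem (hypothetical): summed loss 0.5 * (w - 3)^2, gradient w - 3.
eta, sigma2, lam = 0.05, 0.1, 0.0
mu_a, var_a, m_a, v_a = 0.0, sigma2, 0.0, 0.0
mu_s, m_s, v_s = 0.0, 0.0, 0.0
for t in range(1, 2001):
    mu_a, var_a, m_a, v_a = adabayes_step(
        mu_a, var_a, m_a, v_a, mu_a - 3.0, t, eta, sigma2, lam)
    mu_s, m_s, v_s = adabayes_ss_step(
        mu_s, m_s, v_s, mu_s - 3.0, t, eta, sigma2, lam)
```

In both sketches the effective learning rate is the posterior variance, which stays close to or below `sigma2`, so early steps behave like Adam (step size near `eta`) and later steps behave like SGD with a bounded learning rate.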

2.1.1 Recovering SGD in the low-data limit

In the low-data regime, where $\langle g^2 \rangle \ll \eta^2 / \sigma^4$, the Bayesian filtering learning rate (and equivalently uncertainty) is constant (Fig. 1; orange line),

$$\sigma^2_\text{post} \approx \sigma^2, \tag{19}$$

so the Bayesian filtering updates (Eq. 15b) become equivalent to vanilla SGD (Eq. 1). We can leverage this equivalence to set $\sigma^2$ using standard values of the SGD learning rate,

$$\sigma^2 = \frac{\eta_\text{SGD}}{\text{minibatch size}}. \tag{20}$$

Setting $\sigma^2$ in this way fixes its order of magnitude through the typical SGD learning rate and minibatch size (here we use $x \sim y$, as in Physics, to denote “$x$ has the same order of magnitude as $y$”; see Acklam and Weisstein, “Tilde”, MathWorld). It is important to sanity check that this value of $\sigma^2$ corresponds to Bayesian filtering in a sensible generative model. In particular, note that $\sigma^2$ is the variance of the prior over $w_i^*$ (Eq. 7), and as such should correspond to typical initialization schemes (e.g. He et al., 2015) which ensure that input and output activations have roughly the same scale. These schemes use $\sigma^2 = 2 / (\text{number of inputs})$, and if we consider the typical number of input channels and the size of the pixel patch that we convolve over, we obtain $\sigma^2 \sim \eta_\text{SGD} / \text{minibatch size}$, matching the previous value.
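The order-of-magnitude check can be illustrated with a few lines of arithmetic. The numbers below (64 input channels, a 3×3 patch, an SGD learning rate of 0.1 and a minibatch of 128) are illustrative assumptions, not values taken from the text, and the function names are hypothetical. The sketch assumes that the low-data equivalence identifies the prior variance with the SGD learning rate divided by the minibatch size, and that the initialization variance is 2 divided by the number of inputs to a unit.

```python
def he_variance(in_channels, kernel_size):
    """Initialisation variance 2 / fan_in for a square convolution kernel."""
    fan_in = in_channels * kernel_size * kernel_size
    return 2.0 / fan_in


def sgd_implied_variance(lr_sgd, batch_size):
    """Prior variance implied by matching the low-data limit to SGD."""
    return lr_sgd / batch_size


var_init = he_variance(in_channels=64, kernel_size=3)           # 2 / 576
var_sgd = sgd_implied_variance(lr_sgd=0.1, batch_size=128)      # 0.1 / 128
ratio = var_init / var_sgd
print(f"init variance={var_init:.2e}, SGD-implied variance={var_sgd:.2e}")
```

Under these assumed values the two variances differ by less than a factor of ten, which is the sense in which the two routes to $\sigma^2$ agree.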

Finally, this equivalence implies that — at least under this interpretation of the learning rate as the prior variance — SGD can be viewed as dimensionally consistent, in contrast to some claims from e.g. the natural gradient literature (Kakade, 2002).

2.1.2 Recovering Adam(W) in the high-data limit

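In the high-data regime, where $\langle g^2 \rangle \gg \eta^2 / \sigma^4$, the Bayesian filtering learning rate becomes (Fig. 1; blue line),

$$\sigma^2_\text{post} \approx \frac{\eta}{\sqrt{\langle g^2 \rangle}}, \tag{21}$$

so the updates (Eq. 15b) become equivalent to Adam updates if we take,

$$\eta = \eta_\text{Adam}. \tag{22}$$

As such, we are able to use past experience with good values for the Adam learning rate, $\eta_\text{Adam}$, to set $\eta$.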

Furthermore, when we consider the form of regularisation implied by our updates, we recover a state-of-the-art variant of Adam, known as AdamW (Loshchilov & Hutter, 2017). In standard Adam, weight-decay regularization is implemented by incorporating an L2 penalty on the weights in the loss function, so the gradients of the loss and the regularizer are both normalized by the root-mean-square gradient. In contrast, AdamW “decouples” weight decay from the loss, such that the gradient of the loss is normalized by the root-mean-square gradient, but the weight decay is not. To see that our updates correspond to AdamW, we combine Eq. (15a) and Eq. (15b), and substitute for $\sigma^2_\text{post}$ (Eq. 21),

$$\mu_\text{post}(t) = (1 - \lambda) \, \mu_\text{post}(t-1) - \eta \frac{\langle g(t) \rangle}{\sqrt{\langle g^2(t) \rangle}}. \tag{23}$$

Indeed, the root-mean-square normalization applies only to the gradient of the loss, as in AdamW, and not to the weight decay term, as in standard Adam.

Finally, note that AdaBayes-SS becomes exactly AdamW when we take $\eta^2 / \sigma^4 \to 0$ (with $\eta$ fixed), so that the high-data regime holds for any $\langle g^2 \rangle$,

$$\lim_{\eta^2 / \sigma^4 \to 0} \sigma^2_\text{post} = \frac{\eta}{\sqrt{\langle g^2 \rangle}}, \tag{24}$$

because we use the standard Adam(W) approach to computing unbiased estimates of $\langle g \rangle$ and $\langle g^2 \rangle$ (see Algo. 2).
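The difference between the two forms of regularisation is easiest to see on the first step, where the debiased moments equal the current gradient and its square. The sketch below is illustrative only: the function names, the constant `eps` and the parameter values are assumptions, and the decoupled form follows the multiplicative $(1 - \lambda)$ decay used in the text.

```python
import math


def adam_l2_first_step(w, g_loss, lr, weight_decay, eps=1e-8):
    """First Adam step with an L2 penalty: the penalty gradient is added to
    the loss gradient and both are normalised together."""
    g = g_loss + weight_decay * w
    # At t = 1 the debiased moments equal g and g**2.
    return w - lr * g / (math.sqrt(g * g) + eps)


def adamw_first_step(w, g_loss, lr, lam, eps=1e-8):
    """First decoupled step: decay multiplies the weight directly and only
    the loss gradient is normalised."""
    return (1.0 - lam) * w - lr * g_loss / (math.sqrt(g_loss * g_loss) + eps)


w = 2.0
# With zero loss gradient, only the regulariser acts on the weight.
w_l2 = adam_l2_first_step(w, g_loss=0.0, lr=1e-3, weight_decay=1e-4)
w_l2_strong = adam_l2_first_step(w, g_loss=0.0, lr=1e-3, weight_decay=1e-2)
w_decoupled = adamw_first_step(w, g_loss=0.0, lr=1e-3, lam=1e-4)
```

With the L2 penalty, the normalisation removes the magnitude of the penalty, so a hundredfold change in `weight_decay` leaves the step almost unchanged at about `lr`. With decoupled decay, the weight shrinks by exactly the factor `1 - lam`.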

3 Experiments

For our experiments, we have adapted the code and protocols from a recent paper (Luo et al., 2019) on alternative methods for combining non-adaptive and adaptive behaviour (AdaBound and AMSBound). They considered a 34-layer ResNet (He et al., 2016) and a 121-layer DenseNet (Huang et al., 2017) on CIFAR-10, trained for 200 epochs with learning rates that decreased by a constant factor at epoch 150. We used the exact same networks and protocol, except that we run for more epochs, we plot both classification error and the loss, and we use both CIFAR-10 and CIFAR-100. We used their optimized hyperparameter settings for standard baselines (including SGD and Adam), and their choice of hyperparameters for their methods (AdaBound and AMSBound). For AdamW and AdaBayes, we set $\sigma^2$ from the SGD learning rate using Eq. (20), and we matched $\eta$ to the optimal learning rate for standard Adam. We took the weight decay coefficient from Luo et al. (2019), and we used the equivalence of SGD with weight decay and SGD with decoupled weight decay to set the decoupled weight decay coefficient, $\lambda$, for AdamW, AdaBayes and AdaBayes-SS.

The results are given in Table 1 and Fig. 2. The best adaptive method is always one of our methods (AdaBayes or AdaBayes-SS), though SGD is frequently superior to all adaptive methods tested. To begin, we compare our methods (AdaBayes and AdaBayes-SS) to the canonical non-adaptive (SGD) and adaptive (Adam) methods (see Fig. A1 for a cleaner figure, including other baselines). Note that AdaBayes and AdaBayes-SS improve their accuracy and loss more rapidly than baseline methods during the initial part of learning (with the exception of Adam for CIFAR-100). Our algorithms give better test error and loss than Adam for all networks and datasets; they give better test error than SGD for CIFAR-10, and perform similarly to SGD in the other cases, with AdaBayes-SS often giving better performance than AdaBayes. Next, we see that AdaBayes-SS improves considerably over AdaBayes (see Fig. A2 for a cleaner figure), except in the case of CIFAR-10 classification error, where the difference is minimal. Next, we compared AdaBayes and AdaBayes-SS with AdaBound and AMSBound (Luo et al., 2019) (Fig. A3). We found that the learning curves for AdaBayes and AdaBound/AMSBound were very similar, with AdaBayes usually giving better test error. The behaviour of the loss curves was very different, with AdaBayes-SS giving considerably lower test loss than either AdaBayes or AdaBound/AMSBound. Finally, we compared AdaBound/AMSBound with SGD (Fig. A4), finding that while AdaBound/AMSBound perform slightly better in terms of CIFAR-10 classification error, SGD performs better in all other test domains.

Figure 2: Test loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, for multiple update algorithms.

              CIFAR-10                                  CIFAR-100
              ResNet              DenseNet              ResNet              DenseNet
optimizer     error (%)  loss     error (%)  loss       error (%)  loss     error (%)  loss
SGD           5.170      0.174    5.580      0.177      22.710     0.833    21.290     0.774
Adam          7.110      0.239    6.690      0.230      27.590     1.049    26.640     1.074
AdaGrad       6.840      0.307    7.490      0.338      30.350     1.347    30.110     1.319
AMSGrad       6.720      0.239    6.170      0.234      27.430     1.033    25.850     1.103
AdaBound      5.140      0.220    4.850      0.210      23.060     1.004    22.210     1.050
AMSBound      4.940      0.210    4.960      0.219      23.000     1.003    22.360     1.017
AdamW         5.080      0.239    5.190      0.214      24.850     1.142    23.480     1.043
AdaBayes-SS   5.230      0.187    4.910      0.176      23.120     0.935    22.600     0.934
AdaBayes      4.840      0.229    4.560      0.222      22.920     0.969    22.090     1.079

Table 1: A table displaying the minimal test error and test loss for a ResNet and DenseNet applied to CIFAR-10 and CIFAR-100 for different optimizers. The table displays the best adaptive algorithm (bold), which is always one of our methods: either AdaBayes or AdaBayes-SS. We also display the instances where SGD (gray) beats all adaptive methods (in which case we also embolden the SGD value).

4 Discussion

Here, we showed that Bayesian filtering recovers standard adaptive (AdamW) and non-adaptive (SGD) neural network optimization algorithms in the high and low data limits. Based on this insight, we provided two new algorithms, AdaBayes and AdaBayes-SS, which adaptively transition between SGD-like and Adam(W)-like behaviour, and converge more rapidly than standard adaptive methods such as Adam, despite giving generalisation performance similar to SGD. In our experiments, either AdaBayes or AdaBayes-SS outperformed other adaptive methods, including AdamW (Loshchilov & Hutter, 2017), and Ada/AMSBound (Luo et al., 2019), though SGD frequently outperformed all adaptive methods.

It is also important to note that we have tried to take a pragmatic, rather than purist Bayesian, approach to specifying the update rule. This emerges in several features of our updates. First, we introduced momentum by replacing the raw minibatch gradient with the exponential moving average gradient. That said, there may be scope to give a normative justification for introducing momentum, especially as there is nothing in the Bayesian framework that forces us to evaluate the gradient at the current posterior mean/mode. Second, the decoupled weight decay coefficient should be fixed by $\eta$ and $\sigma^2$ (Eq. 10a), whereas we allowed it to be a free parameter. It remains to be seen whether a model which obeys the required constraints can equal the performance of those presented here. Third, we assumed that $\sigma^2$ was constant across parameters. While this was necessary to match common practice in neural network optimization, in the Bayesian framework, $\sigma^2$ should be the prior variance, which should change across parameters, as in standard initialization schemes. Fourth, we considered conditioning on the other parameters, rather than, as is typical in variational inference, considering a distribution over those parameters. As such, it will be interesting to consider whether the same reasoning can be extended to the case of variational inference. Finally, our derivations apply in the online setting where each minibatch of data is assumed to be new. As such, it might be possible to slightly improve our updates by taking into account repeated data in the batch setting.

Finally, Bayesian filtering presents a novel approach to neural network optimization, and as such, there are a variety of directions for future work. First, it should be possible to develop a variant of the method in which the posterior represents the covariance of a full weight matrix, by exploiting Kronecker factorisation (Martens & Grosse, 2015; Grosse & Martens, 2016; Zhang et al., 2017). These schemes are potentially of interest, because they give a principled description of how to compute the relevant covariance matrices even when we have seen only a small number of minibatches. Second, stochastic regularization has been shown to be extremely effective at reducing generalization error in neural networks. This Bayesian interpretation of neural network optimization presents opportunities for new stochastic regularization schemes (Kutschireiter et al., 2015).


Appendix A Steady-state covariance

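For the steady-state covariance, it is slightly more convenient to work with the inverse variance,

$$\Lambda_\text{post} = \frac{1}{\sigma^2_\text{post}},$$

though the same results can be obtained through either route. Substituting Eq. (10b) into Eq. (14a) and taking $\eta^2 / \sigma^2 \ll 1$, we obtain an update from $\Lambda_\text{post}(t)$ to $\Lambda_\text{post}(t+1)$,

$$\Lambda_\text{post}(t+1) = \frac{1}{\left(1 - \eta^2/\sigma^2\right) / \Lambda_\text{post}(t) + \eta^2} + g^2(t+1).$$

Assuming $\Lambda_\text{post}$ has reached steady state, we have $\Lambda_\text{post} = \Lambda_\text{post}(t) = \Lambda_\text{post}(t+1)$,

$$\Lambda_\text{post} = \frac{\Lambda_\text{post}}{1 - \eta^2/\sigma^2 + \eta^2 \Lambda_\text{post}} + \langle g^2 \rangle.$$

Assuming that the magnitude of the update to $\Lambda_\text{post}$ is small, we can take a first-order Taylor expansion of the first term,

$$\Lambda_\text{post} \approx \Lambda_\text{post} \left(1 + \frac{\eta^2}{\sigma^2} - \eta^2 \Lambda_\text{post}\right) + \langle g^2 \rangle,$$

and rearranging,

$$0 \approx \frac{\eta^2}{\sigma^2} \Lambda_\text{post} - \eta^2 \Lambda_\text{post}^2 + \langle g^2 \rangle,$$

and finally substituting $\Lambda_\text{post} = 1/\sigma^2_\text{post}$ gives the expression in the main text.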

Appendix B Additional data figures

Here, we replot Fig. 2 to clarify particular comparisons. In particular, we compare AdaBayes(-SS) with standard baselines (Fig. A1), Adam, AdamW and AdaBayes-SS (Fig. A2), AdaBayes(-SS) and Ada/AMSBound (Fig. A3), and Ada/AMSBound and SGD (Fig. A4). Finally, we plot the training error and loss for all methods (Fig. A5; note the loss does not include the regularizer, so it may go up without this being evidence of overfitting).

Figure A1: Test loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, comparing our methods (AdaBayes and AdaBayes-SS) with standard baselines (SGD, Adam, AdaGrad and AMSGrad (Reddi et al., 2018)).
Figure A2: Test loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, comparing Adam, AdamW and AdaBayes-SS.
Figure A3: Test loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, comparing our methods (AdaBayes and AdaBayes-SS) with AdaBound/AMSBound (Luo et al., 2019).
Figure A4: Test loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, comparing AdaBound/AMSBound (Luo et al., 2019) and SGD.
Figure A5: Train loss and classification error for CIFAR-10 and CIFAR-100 for a ResNet-34 and a DenseNet-121, for all methods in Fig. 2.