# A unified theory of adaptive stochastic gradient descent as Bayesian filtering

12/30/2021


## 1 Introduction and Background

Neural network optimization methods fall into two broad classes: non-adaptive and adaptive. The canonical non-adaptive method is vanilla stochastic gradient descent (SGD) with momentum, which updates parameters by multiplying the exponential moving average gradient, ⟨g(t)⟩, by a learning rate, η_SGD,

Δw_SGD(t) = η_SGD ⟨g(t)⟩ / (minibatch size). (1)

Here, we divide by the minibatch size because we define g(t) to be the gradient of the summed loss, whereas common practice is to use the gradient of the mean loss. Further, following the convention established by Adam (Kingma & Ba, 2015), ⟨g(t)⟩ is computed by debiasing a raw exponential moving average, m(t),

m(t) = β1 m(t−1) + (1−β1) g(t),  ⟨g(t)⟩ = m(t) / (1 − β1^t), (2)

where g(t) is the raw minibatch gradient, and β1 is usually chosen to be 0.9. These methods are typically found to give excellent generalisation performance, and as such are used to train many state-of-the-art networks (e.g. ResNet (He et al., 2016), DenseNet (Huang et al., 2017), ResNeXt (Xie et al., 2017)).

Adaptive methods change the learning rates as a function of past gradients. These methods were initially introduced with vario-eta (Neuneier & Zimmermann, 1998), and many variants have recently been developed, including AdaGrad (Duchi et al., 2011), RMSprop (Hinton et al., 2012) and Adam (Kingma & Ba, 2015). The canonical adaptive method, Adam, normalises the exponential moving average gradient by the root mean square of past gradients,

Δw_Adam(t) = η_Adam ⟨g(t)⟩ / √⟨g²(t)⟩, (3)

where,

v(t) = β2 v(t−1) + (1−β2) g²(t),  ⟨g²(t)⟩ = v(t) / (1 − β2^t), (4)

and where β2 is typically chosen to be 0.999. These methods are often observed to converge faster, and hence may be used on problems which are more difficult to optimize (Graves, 2013), but they can give worse generalisation performance than non-adaptive methods (Keskar & Socher, 2017; Loshchilov & Hutter, 2017; Wilson et al., 2017; Luo et al., 2019).
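As a concrete sketch, the two update rules (Eqs. 1 and 3) differ only in the per-parameter scaling applied to the debiased average gradient. The following minimal Python illustration is ours, not the paper's code; the gradient sequence and hyperparameter values are purely illustrative:

```python
import math

def debiased_ema(avg, x, beta, t):
    """Raw exponential moving average plus Adam-style debiasing (Eqs. 2 and 4)."""
    avg = beta * avg + (1 - beta) * x
    return avg, avg / (1 - beta ** t)

# One scalar parameter and a few minibatch gradients of the summed loss.
g_seq = [0.4, 0.2, 0.3]
beta1, beta2 = 0.9, 0.999
eta_sgd, eta_adam, batch_size = 0.1, 1e-3, 128  # illustrative values

m = v = 0.0
w_sgd = w_adam = 0.0
for t, g in enumerate(g_seq, start=1):
    m, g_avg = debiased_ema(m, g, beta1, t)       # <g(t)>, Eq. (2)
    v, g2_avg = debiased_ema(v, g * g, beta2, t)  # <g^2(t)>, Eq. (4)
    w_sgd += eta_sgd * g_avg / batch_size         # Eq. (1), sign as written there
    w_adam += eta_adam * g_avg / math.sqrt(g2_avg)  # Eq. (3)
```

Note that at t = 1 the debiasing exactly recovers the raw gradient, since m(1) = (1−β1) g(1) and 1 − β1^1 = 1 − β1.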

Having these two distinct classes of method raises considerable problems at both a practical and a theoretical level. At the practical level, it dramatically increases the size of the space that we must search over to find the optimal hyperparameters, especially as we may need to consider switching from one method to another (Keskar & Socher, 2017), or using one method for some parameters and a different method for other parameters. At the theoretical level, there have been attempts to explain the effectiveness of adaptive and non-adaptive methods. For instance, it has been argued that non-adaptive methods approximate Langevin sampling (Mandt et al., 2017), whereas adaptive methods approximate natural-gradient updates (Zhang et al., 2017; Khan et al., 2017, 2018). However, to understand whether to use adaptive or non-adaptive methods, we need new theory that simultaneously addresses both cases.

Here we provide such a theory by reconciling state-of-the-art adaptive SGD algorithms with very early work that used Bayesian (Kalman) filtering to optimize the parameters of neural networks (Puskorius & Feldkamp, 1991; Sha et al., 1992; Puskorius & Feldkamp, 1994, 2001; Feldkamp et al., 2003; Ollivier, 2017). We begin by formulating the problem of neural network optimization as Bayesian inference, which gives our first algorithm, AdaBayes. Next, we develop a second algorithm, AdaBayes-SS, by considering the steady-state learning rates in AdaBayes; this algorithm recovers non-adaptive (SGD) and adaptive (AdamW; Loshchilov & Hutter, 2017) methods in the low- and high-data limits. Finally, we compare the performance of AdaBayes and AdaBayes-SS to standard baselines including SGD and Adam, and to newer methods such as AdamW (Loshchilov & Hutter, 2017) and Ada/AMSBound (Luo et al., 2019).

## 2 Methods

Typically, to set up the problem of neural network optimization as Bayesian inference, we assume that there is some fixed, “true” set of neural network parameters, w, and infer a joint distribution over those parameters conditioned on the training data, P(w | D). However, in the domain of neural networks, the large number of parameters forces us to use extremely strong (usually factorised) approximations to the true posterior. To simplify reasoning about factorised approximations, we choose to infer the distribution over a single parameter, w_i, conditioned on the current setting of the other parameters, w_{−i}(t), which change over time as they are optimized. In this conditional framework, we can use Bayes theorem to update the posterior with an additional minibatch of data,

P(w_i | D(t), w_{−i}(t)) ∝ P(d(t) | w_i, w_{−i}(t)) P(w_i | D(t−1), w_{−i}(t)), (5)

where D(t) denotes all minibatches up to time t. We can perform this update using reasonably straightforward approximations (see below). However, this form raises a critical question: how to transform P(w_i | D(t−1), w_{−i}(t−1)), which conditions on w_{−i}(t−1) and is the output of Eq. (5) at the previous time-step, into P(w_i | D(t−1), w_{−i}(t)), which conditions on w_{−i}(t) and is required as input to Eq. (5) at this time-step. These distributions are fundamentally different, as they are conditioned on different quantities, and probability theory does not offer an easy way to transform between them. All we can do is recompute the likelihood for all past minibatches of data each time w_{−i} changes, which is clearly computationally intractable. Instead, we approximate the distribution over w_i conditioned on w_{−i}(t) by noting that it should be similar to, but slightly broader than, the distribution conditioned on w_{−i}(t−1) (as w_{−i}(t) and w_{−i}(t−1) are similar, but the data is based on earlier values of w_{−i}). As such, we approximate the distribution conditioned on w_{−i}(t) by convolving the distribution conditioned on w_{−i}(t−1) with a narrow Gaussian kernel,

P(w_i | D(t−1), w_{−i}(t)) = ∫ dw′_i N(w_i; (1 − η²/(2σ²)) w′_i, η²) P(w′_i | D(t−1), w_{−i}(t−1)), (6)

where N(w_i; μ, η²) is the p.d.f. of w_i under a Gaussian distribution with mean μ = (1 − η²/(2σ²)) w′_i and variance η². Note that we set up the distribution with a small decay in the mean, such that, after many applications of this operator, the distribution over w_i converges to a Gaussian with zero mean and variance σ², which constitutes a prior for w_i,

P(w_i) = lim_{T→∞} P(w_i | D(t−1), w_{−i}(T)) = N(w_i; 0, σ²). (7)
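As a numeric sanity check on Eq. (7): iterating only the data-free part of the updates (the mean shrinking by (1 − η²/(2σ²)) and the variance recursion of Eq. 10b) drives any starting Gaussian towards approximately N(0, σ²) when η ≪ σ. A small sketch with illustrative values of our choosing:

```python
# Iterate the diffusion step of Eq. (6): the mean shrinks by
# (1 - eta^2/(2 sigma^2)) and the variance picks up eta^2 of noise per step.
eta, sigma2 = 0.01, 1.0   # illustrative values
shrink = 1.0 - eta**2 / (2.0 * sigma2)

mu, var = 5.0, 0.0        # start far from the prior
for _ in range(200_000):
    mu = shrink * mu                  # mean decay, cf. Eq. (10a)
    var = shrink**2 * var + eta**2    # variance recursion, cf. Eq. (10b)

# After many applications the distribution approaches the prior N(0, sigma^2).
print(mu, var)
```

The fixed point of the variance recursion is η² / (1 − (1 − η²/(2σ²))²), which equals σ² only to first order in η²/σ²; the agreement is close for small η, which is the regime the derivation assumes.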

We can do inference under these updates by applying Eq. (5) and Eq. (6). In particular, as we will use a Gaussian approximation to the likelihood, we have,

P(w_i | D(t−1), w_{−i}(t)) = N(w_i; μ_prior(t), σ²_prior(t)), (8)
P(w_i | D(t), w_{−i}(t)) = N(w_i; μ_post(t), σ²_post(t)), (9)

where the updates for μ_prior(t) and σ²_prior(t) can be computed from Eq. (6),

μ_prior(t) = (1 − η²/(2σ²)) μ_post(t−1), (10a)
σ²_prior(t) = (1 − η²/(2σ²))² σ²_post(t−1) + η², (10b)

and the updates for μ_post(t) and σ²_post(t) come from approximating Eq. (5). In particular, we approximate the log-likelihood using a second-order Taylor expansion,

log P(d(t) | w_i, w_{−i}) ≈ −(1/2) Λ_like (w_i − μ_like)², (11)

where we identify the coefficients Λ_like and μ_like using the minibatch gradient, g, computed as usual by our automatic differentiation framework,

g = ∂/∂w_i log P(d(t) | w_i, w_{−i}) = Λ_like (μ_like − w_i), (12)

and we use a Fisher-Information based estimate of Λ_like (as in Zhang et al., 2017; Khan et al., 2017, 2018),

Λ_like ≈ g². (13)

Thus, we obtain,

σ²_post(t) = 1 / (1/σ²_prior(t) + Λ_like) ≈ 1 / (1/σ²_prior(t) + g²(t)), (14a)
μ_post(t) = μ_prior(t) + σ²_post(t) g(t). (14b)

The full updates are now specified by iteratively applying Eq. (10) and Eq. (14). Note that the updates are analogous to Kalman filtering updates (Kalman, 1960), but in our case effective dynamics emerge implicitly, through changes in the distribution over w_i induced by changes in the other parameters, w_{−i}, rather than being built into the problem definition.

Next, we make two changes to the updates for the mean, to match current best practice for optimizing neural networks. First, we allow more flexibility in weight decay, by replacing the (1 − η²/(2σ²)) term in Eq. (10a) with a new parameter, (1 − λ). Second, we incorporate momentum, by using the exponential moving average gradient, ⟨g(t)⟩ (Eq. 2), instead of the raw minibatch gradient in Eq. (14b). In combination, the updates for the mean become,

μ_prior(t) = (1 − λ) μ_post(t−1), (15a)
μ_post(t) = μ_prior(t) + σ²_post(t) ⟨g(t)⟩. (15b)

Our complete Bayesian updates are now given by using Eq. (15) to update μ_prior(t) and μ_post(t), and using Eq. (10b) and Eq. (14a) to update σ²_prior(t) and σ²_post(t) (see Algo. 1).
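Putting Eqs. (10b), (14a) and (15) together, a single AdaBayes step for one parameter can be sketched as follows. This is our paraphrase of the update rules, not the authors' reference implementation, and the function signature is our own:

```python
def adabayes_step(mu_post, var_post, g_avg, g, eta, sigma2, lam):
    """One AdaBayes update for a single scalar parameter.

    mu_post, var_post: posterior mean/variance from the previous step
    g_avg: debiased exponential moving-average gradient <g(t)> (Eq. 2)
    g:     raw minibatch gradient, used for the g^2 precision term
    """
    # Diffusion step: Eq. (15a) for the mean, Eq. (10b) for the variance.
    mu_prior = (1.0 - lam) * mu_post
    var_prior = (1.0 - eta**2 / (2.0 * sigma2))**2 * var_post + eta**2
    # Measurement step: Eq. (14a) for the variance, then Eq. (15b) for the mean.
    var_post = 1.0 / (1.0 / var_prior + g * g)
    mu_post = mu_prior + var_post * g_avg
    return mu_post, var_post
```

After each step, the parameter itself is set to the posterior mean μ_post; σ²_post plays the role of a per-parameter learning rate.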

Finally, this derivation highlights the need to reason about the effective dynamics that emerge from changes in the other parameters, w_{−i}. In particular, if there were no effective dynamics (i.e. if η = 0 and λ = 0), then the learning rates would shrink too rapidly (as 1/t),

σ²_post(t) ≈ 1 / (1/σ² + Σ_{t′≤t} g²(t′)) ≈ 1 / (t ⟨g²⟩). (16)

### 2.1 Steady-state learning rates

With effective dynamics (η > 0), the posterior variance instead approaches a steady state in which the uncertainty injected by Eq. (10b) balances the information contributed by new gradients. Setting σ²_post(t+1) = σ²_post(t) in the combined updates (see Appendix A for the derivation) gives,

0 ≈ σ² (1/σ²_post)² − 1/σ²_post − ⟨g²⟩ σ²/η². (17)

Solving for 1/σ²_post, we obtain,

1/σ²_post ≈ (1/(2σ²)) (1 + √(1 + 4 (σ² √⟨g²⟩ / η)²)). (18)

To confirm that this steady-state expression is correct, we plotted the values for σ²_post given by AdaBayes (Fig. 1, points) and found a good match to the values suggested by this steady-state expression (green line).
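Eq. (18) can also be checked numerically by iterating the precision update (Eq. 26 in the appendix) with a fixed gradient magnitude and comparing against the closed-form root. A small sketch; the values of η, σ² and ⟨g²⟩ are illustrative, not taken from the paper's experiments:

```python
import math

eta, sigma2, g2 = 1e-3, 0.1, 4.0   # illustrative values; <g^2> held fixed

# Iterate lam(t+1) = 1/((1 - eta^2/sigma2)/lam(t) + eta^2) + g2  (Eq. 26)
lam = 1.0 / sigma2
for _ in range(100_000):
    lam = 1.0 / ((1.0 - eta**2 / sigma2) / lam + eta**2) + g2

# Closed-form steady state, Eq. (18): 1/sigma^2_post.
lam_ss = (1.0 + math.sqrt(1.0 + 4.0 * (sigma2 * math.sqrt(g2) / eta)**2)) / (2.0 * sigma2)
print(lam, lam_ss)
```

The two agree to within the error of the first-order Taylor expansion used in the derivation, which is small whenever the per-step change in the precision is small.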

#### 2.1.1 Recovering SGD in the low-data limit

In the low-data regime, where σ² √⟨g²⟩ ≪ η, the Bayesian filtering learning rate (and, equivalently, uncertainty) is constant (Fig. 1; orange line),

σ²_post ≈ σ², (19)

so the Bayesian filtering updates (Eq. 15b) become equivalent to vanilla SGD (Eq. 1). We can leverage this equivalence to set σ² using standard values of the SGD learning rate,

σ² = η_SGD / (minibatch size). (20)

Setting σ² in this way would suggest σ² ∼ 10⁻³ (here we use ∼, as in physics, to denote “has the same order of magnitude as”; see Acklam and Weisstein, “Tilde”, MathWorld. http://mathworld.wolfram.com/Tilde.html), as η_SGD ∼ 10⁻¹ and the minibatch size is ∼ 10². It is important to sanity check that this value of σ² corresponds to Bayesian filtering in a sensible generative model. In particular, note that σ² is the variance of the prior over w_i (Eq. 7), and as such should correspond to typical initialization schemes (e.g. He et al., 2015) which ensure that input and output activations have roughly the same scale. These schemes set the weight variance inversely proportional to the number of inputs to each unit; if we consider that there are typically ∼10² input channels and we convolve over a small pixel patch, we obtain σ² ∼ 10⁻³, matching the previous value.

Finally, this equivalence implies that — at least under this interpretation of the learning rate as the prior variance — SGD can be viewed as dimensionally consistent, in contrast to some claims from e.g. the natural gradient literature (Kakade, 2002).

#### 2.1.2 Recovering Adam(W) in the high-data limit

In the high-data regime, where σ² √⟨g²⟩ ≫ η, the Bayesian filtering learning rate becomes (Fig. 1; blue line),

σ²_post ≈ η / √⟨g²⟩. (21)

As such, we are able to use past experience with good values for the Adam learning rate, η_Adam, to set η: comparing Eq. (21) with the Adam update (Eq. 3) suggests setting η to a standard Adam learning rate, and this is what we do in our experiments.

Furthermore, when we consider the form of regularisation implied by our updates, we recover a state-of-the-art variant of Adam known as AdamW (Loshchilov & Hutter, 2017). In standard Adam, weight-decay regularization is implemented by incorporating an L2 penalty on the weights in the loss function, so the gradients of the loss and the regularizer are both normalized by the root-mean-square gradient. In contrast, AdamW “decouples” weight decay from the loss, such that the gradient of the loss is normalized by the root-mean-square gradient, but the weight decay is not. To see that our updates correspond to AdamW, we combine Eq. (15a) and Eq. (15b), and substitute for σ²_post (Eq. 21),

μ_post(t) ≈ (1 − λ) μ_post(t−1) + (η / √⟨g²(t)⟩) ⟨g(t)⟩. (23)

Indeed, the root-mean-square normalization applies only to the gradient of the loss, as in AdamW, and not to the weight decay term, as in standard Adam.

Note that the equivalence with AdamW becomes exact in the limit σ² → ∞, where the steady-state precision (Eq. 18) simplifies to,

lim_{σ²→∞} 1/σ²_post = lim_{σ²→∞} ( 1/(2σ²) + √(1/(4σ⁴) + ⟨g²⟩/η²) ) = √⟨g²⟩/η, (24)

and because we use the standard Adam(W) approach to computing unbiased estimates of ⟨g(t)⟩ and ⟨g²(t)⟩ (see Algo. 2), the resulting updates match AdamW exactly.
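The limit in Eq. (24) is easy to verify numerically: as σ² grows, the steady-state precision approaches √⟨g²⟩/η, the AdamW normalizer. A small sketch with illustrative values of our choosing:

```python
import math

eta, g2 = 1e-3, 4.0  # illustrative values of eta and <g^2>

def steady_state_precision(sigma2):
    """Steady-state 1/sigma^2_post as a function of the prior variance (Eqs. 18/24)."""
    return 1.0 / (2.0 * sigma2) + math.sqrt(1.0 / (4.0 * sigma2**2) + g2 / eta**2)

adam_limit = math.sqrt(g2) / eta  # the AdamW normalizer, sqrt(<g^2>)/eta
for sigma2 in (0.1, 1.0, 10.0, 100.0):
    print(sigma2, steady_state_precision(sigma2), adam_limit)
```

For finite σ² the precision always exceeds the Adam limit (the prior adds information), and the gap shrinks as 1/(2σ²).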

## 3 Experiments

For our experiments, we adapted the code and protocols from a recent paper (Luo et al., 2019) on alternative methods for combining non-adaptive and adaptive behaviour (AdaBound and AMSBound). They considered a 34-layer ResNet (He et al., 2016) and a 121-layer DenseNet (Huang et al., 2017) on CIFAR-10, trained for 200 epochs with learning rates that decreased by a factor of 10 at epoch 150. We used the exact same networks and protocol, except that we ran for more epochs, we plot both classification error and the loss, and we use both CIFAR-10 and CIFAR-100. We used their optimized hyperparameter settings for the standard baselines (including SGD and Adam), and their choice of hyperparameters for their own methods (AdaBound and AMSBound). For AdamW and AdaBayes, we set σ² using Eq. (20), and matched η to the optimal learning rate for standard Adam. We used decoupled weight decay (following Luo et al., 2019), and used the equivalence of SGD with L2 weight decay and SGD with decoupled weight decay to set the decoupled weight decay coefficient for AdamW, AdaBayes and AdaBayes-SS.

## 4 Discussion

It is also important to note that we have tried to take a pragmatic, rather than purist Bayesian, approach to specifying the update rule. This emerges in several features of our updates. First, we introduced momentum by replacing the raw minibatch gradient with the exponential moving average gradient. That said, there may be scope to give a normative justification for introducing momentum, especially as there is nothing in the Bayesian framework that forces us to evaluate the gradient at the current posterior mean/mode. Second, the decoupled weight decay coefficient, λ, should be fixed by η and σ² (Eq. 10a), whereas we allowed it to be a free parameter. It remains to be seen whether a model which obeys the required constraints can equal the performance of those presented here. Third, we assumed that σ² was constant across parameters. While this was necessary to match common practice in neural network optimization, in the Bayesian framework σ² should be the prior variance, which should change across parameters, as in standard initialization schemes. Fourth, we considered conditioning on the other parameters, rather than, as is typical in variational inference, maintaining a distribution over those parameters. As such, it will be interesting to consider whether the same reasoning can be extended to the case of variational inference. Finally, our derivations apply in the online setting, where each minibatch of data is assumed to be new. As such, it might be possible to slightly improve our updates by taking into account repeated data in the batch setting.

Finally, Bayesian filtering presents a novel approach to neural network optimization, and as such there are a variety of directions for future work. First, it should be possible to develop a variant of the method in which σ²_post represents the covariance of a full weight matrix, by exploiting Kronecker factorisation (Martens & Grosse, 2015; Grosse & Martens, 2016; Zhang et al., 2017). These schemes are potentially of interest because they give a principled description of how to compute the relevant covariance matrices even when we have seen only a small number of minibatches. Second, stochastic regularization has been shown to be extremely effective at reducing generalization error in neural networks. This Bayesian interpretation of neural network optimization presents opportunities for new stochastic regularization schemes (Kutschireiter et al., 2015).

## Appendix A Derivation of the steady-state learning rate

For the steady-state covariance, it is slightly more convenient to work with the inverse variance,

λ_post = 1/σ²_post, (25)

though the same results can be obtained through either route. Substituting Eq. (10b) into Eq. (14a), and taking (1 − η²/(2σ²))² ≈ 1 − η²/σ², we obtain an update from λ_post(t) to λ_post(t+1),

λ_post(t+1) = 1 / ( (1 − η²/σ²)/λ_post(t) + η² ) + g²(t). (26)

Assuming λ_post has reached steady state, we have λ_post = λ_post(t) = λ_post(t+1),

λ_post = 1 / ( (1 − η²/σ²)/λ_post + η² ) + ⟨g²⟩. (27)

Rearranging,

λ_post = λ_post / (1 − η²/σ² + η² λ_post) + ⟨g²⟩. (28)

Assuming that the magnitude of the update to λ_post is small, we can take a first-order Taylor expansion of the first term,

λ_post ≈ λ_post (1 + η²/σ² − η² λ_post) + ⟨g²⟩. (29)

Cancelling,

0 ≈ η² λ_post² − (η²/σ²) λ_post − ⟨g²⟩, (30)

and rearranging (multiplying through by σ²/η²),

0 ≈ σ² λ_post² − λ_post − ⟨g²⟩ σ²/η², (31)

and finally substituting λ_post = 1/σ²_post gives the expression in the main text (Eq. 17).

## Appendix B Additional data figures

Here, we replot Fig. 2 to clarify particular comparisons. In particular, we compare AdaBayes(-SS) with standard baselines (Fig. A1); Adam, AdamW and AdaBayes-SS (Fig. A2); AdaBayes(-SS) and Ada/AMSBound (Fig. A3); and Ada/AMSBound and SGD (Fig. A4). Finally, we plot the training error and loss for all methods (Fig. A5; note that the loss does not include the regularizer, so it may go up without this being evidence of overfitting).