1 Introduction and Background
Neural network optimization methods fall into two broad classes: non-adaptive and adaptive. The canonical non-adaptive method is vanilla stochastic gradient descent (SGD) with momentum, which updates the parameters, w(t), by multiplying the exponential moving average gradient, \bar{g}(t), by a learning rate, η_SGD,

    Δw(t) = (η_SGD / B) \bar{g}(t).    (1)
Here, we divide by the minibatch size, B, because we define g(t) to be the gradient of the summed log-likelihood (i.e. the negative summed loss), whereas common practice is to use the gradient of the mean loss. Further, following the convention established by Adam (Kingma & Ba, 2015), \bar{g}(t) is computed by debiasing a raw exponential moving average, g_avg(t),

    \bar{g}(t) = g_avg(t) / (1 − β₁^t),    where    g_avg(t) = β₁ g_avg(t−1) + (1 − β₁) g(t),    (2)
where g(t) is the raw minibatch gradient, and β₁ is usually chosen to be 0.9. These methods are typically found to give excellent generalisation performance, and as such are used to train many state-of-the-art networks (e.g. ResNet (He et al., 2016), DenseNet (Huang et al., 2017), ResNeXt (Xie et al., 2017)).
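In code, one step of Eqs. (1)-(2) might look as follows (a minimal sketch; the sign convention assumes g is the gradient of the summed log-likelihood so we ascend, and the default values of `eta_sgd` and `B` are illustrative, not taken from the text):

```python
import numpy as np

def sgd_momentum_step(w, g, g_avg, t, eta_sgd=0.1, B=128, beta1=0.9):
    """One step of SGD with momentum, Eqs. (1)-(2)."""
    g_avg = beta1 * g_avg + (1.0 - beta1) * g   # raw EMA, Eq. (2)
    g_bar = g_avg / (1.0 - beta1 ** t)          # debiased EMA, Eq. (2)
    w = w + (eta_sgd / B) * g_bar               # Eq. (1)
    return w, g_avg
```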
Adaptive methods change the learning rates as a function of past gradients. These methods were initially introduced with vario-eta (Neuneier & Zimmermann, 1998), and many variants have recently been developed, including AdaGrad (Duchi et al., 2011), RMSprop (Hinton et al., 2012) and Adam (Kingma & Ba, 2015). The canonical adaptive method, Adam, normalises the exponential moving average gradient by the root mean square of past gradients,

    Δw(t) = η_Adam \bar{g}(t) / √(\overline{g²}(t)),    (3)

where,

    \overline{g²}(t) = g²_avg(t) / (1 − β₂^t),    where    g²_avg(t) = β₂ g²_avg(t−1) + (1 − β₂) g²(t),    (4)
and where β₂ is typically chosen to be 0.999. These methods are often observed to converge faster, and hence may be used on problems which are more difficult to optimize (Graves, 2013), but they can give worse generalisation performance than non-adaptive methods (Keskar & Socher, 2017; Loshchilov & Hutter, 2017; Wilson et al., 2017; Luo et al., 2019).
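A sketch of one Adam step, Eqs. (3)-(4), reusing the EMA machinery above (`eps` is the usual small constant guarding against division by zero in standard Adam implementations; all names are illustrative):

```python
import numpy as np

def adam_step(w, g, g_avg, g2_avg, t, eta_adam=1e-3,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step: EMA gradient normalised by the RMS gradient."""
    g_avg = beta1 * g_avg + (1.0 - beta1) * g         # Eq. (2)
    g2_avg = beta2 * g2_avg + (1.0 - beta2) * g**2    # Eq. (4)
    g_bar = g_avg / (1.0 - beta1 ** t)                # debiased first moment
    g2_bar = g2_avg / (1.0 - beta2 ** t)              # debiased second moment
    w = w + eta_adam * g_bar / (np.sqrt(g2_bar) + eps)  # Eq. (3)
    return w, g_avg, g2_avg
```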
Having these two distinct classes of method raises considerable problems at both a practical and a theoretical level. At the practical level, it dramatically increases the size of the space that we must search over to find the optimal hyperparameters, especially as we may need to consider switching from one method to another (Keskar & Socher, 2017), or using one method for some parameters and a different method for other parameters. At the theoretical level, there have been attempts to explain the effectiveness of adaptive and non-adaptive methods. For instance, it has been argued that non-adaptive methods approximate Langevin sampling (Mandt et al., 2017), whereas adaptive methods approximate natural-gradient updates (Zhang et al., 2017; Khan et al., 2017, 2018). However, to understand whether to use adaptive or non-adaptive methods, we need new theory that simultaneously addresses both cases.

Here we provide such a theory by reconciling state-of-the-art adaptive SGD algorithms with very early work that used Bayesian (Kalman) filtering to optimize the parameters of neural networks (Puskorius & Feldkamp, 1991; Sha et al., 1992; Puskorius & Feldkamp, 1994, 2001; Feldkamp et al., 2003; Ollivier, 2017). We begin by formulating the problem of neural network optimization as Bayesian inference, which gives our first algorithm, AdaBayes. Next, we develop a second algorithm, AdaBayes-SS, by considering the steady-state learning rates in AdaBayes; this algorithm recovers non-adaptive (SGD) and adaptive (AdamW) methods (Loshchilov & Hutter, 2017) in the low- and high-data limits, respectively. Finally, we compare the performance of AdaBayes and AdaBayes-SS to standard baselines including SGD and Adam, and to newer methods such as AdamW (Loshchilov & Hutter, 2017) and Ada/AMSBound (Luo et al., 2019).

2 Methods
Typically, to set up the problem of neural network optimization as Bayesian inference, we assume that there is some fixed, "true" set of neural network parameters, w, and infer a joint distribution over those parameters conditioned on the training data, P(w | D). However, in the domain of neural networks, the large number of parameters forces us to use extremely strong (usually factorised) approximations to the true posterior. To simplify reasoning about factorised approximations, we choose to infer the distribution over a single parameter, w, conditioned on the current setting of the other parameters, w₋(t), which change over time as they are optimized. In this conditional framework, we can use Bayes theorem to update the posterior with an additional minibatch of data, y_t,

    P(w | D(t), w₋(t)) ∝ P(y_t | w, w₋(t)) P(w | D(t−1), w₋(t)),    (5)

where D(t) = (y₁, …, y_t) is all minibatches up to time t. We can perform this update using reasonably straightforward approximations (see below). However, this form raises a critical question: how to transform from P(w | D(t−1), w₋(t−1)), which conditions on w₋(t−1) and is the output of Eq. (5) at the previous time step, into P(w | D(t−1), w₋(t)), which conditions on w₋(t) and is required as input to Eq. (5) at this time step. These distributions are fundamentally different, as they are conditioned on different quantities, and probability theory does not offer an easy way to transform between them. All we can do is to recompute the likelihood for all past minibatches of data each time w₋(t) changes, which is clearly computationally intractable. Instead, we approximate the distribution over w conditioned on w₋(t) by noting that it should be similar to, but slightly broader than, the distribution conditioned on w₋(t−1) (as w₋(t) and w₋(t−1) are similar, but the data was seen under earlier values of the other parameters). As such, we approximate the distribution conditioned on w₋(t) by convolving the distribution conditioned on w₋(t−1) with a narrow Gaussian kernel,

    P(w | D(t−1), w₋(t)) ≈ ∫ dw′ N(w; (1 − σ²_dyn/(2η²)) w′, σ²_dyn) P(w′ | D(t−1), w₋(t−1)),    (6)
where N(w; μ, σ²) is the p.d.f. of w under a Gaussian distribution with mean μ and variance σ². Note that we set up the kernel with a small decay in the mean, such that, after many applications of this operator, the distribution over w converges (to first order in σ²_dyn) to a Gaussian with zero mean and variance η², which constitutes a prior for w,

    P(w) = N(w; 0, η²).    (7)
We can do inference under these updates by applying Eq. (5) and Eq. (6). In particular, as we will use a Gaussian approximation to the likelihood, we have,

    P(w | D(t−1), w₋(t)) ≈ N(w; μ_prior(t), σ²_prior(t)),    (8)
    P(w | D(t), w₋(t)) ≈ N(w; μ_post(t), σ²_post(t)),    (9)
where the updates for μ_prior(t) and σ²_prior(t) can be computed from Eq. (6), keeping terms to first order in σ²_dyn,

    μ_prior(t) = (1 − σ²_dyn/(2η²)) μ_post(t−1),    (10a)
    σ²_prior(t) = σ²_post(t−1) + σ²_dyn (1 − σ²_post(t−1)/η²),    (10b)
and the updates for μ_post(t) and σ²_post(t) come from approximating Eq. (5). In particular, we approximate the log-likelihood using a second-order Taylor expansion,

    log P(y_t | w, w₋(t)) ≈ const + g(t) (w − μ_prior(t)) − (Λ(t)/2) (w − μ_prior(t))²,    (11)

where we identify the linear coefficient, g(t), with the minibatch gradient computed as usual by our automatic differentiation framework,

    g(t) = ∂ log P(y_t | w, w₋(t)) / ∂w |_{w = μ_prior(t)},    (12)

and we use a Fisher-information based estimate of the curvature, Λ(t), obtained from the exponential moving average of the squared gradient (as in Zhang et al., 2017; Khan et al., 2017, 2018),

    Λ(t) ≈ \overline{g²}(t).    (13)
Thus, completing the square in Eq. (5), we obtain,

    1/σ²_post(t) = 1/σ²_prior(t) + Λ(t),    (14a)
    μ_post(t) = μ_prior(t) + σ²_post(t) g(t).    (14b)
The full updates are now specified by iteratively applying Eq. (10) and Eq. (14). Note that the updates are analogous to Kalman filtering updates (Kalman, 1960), but in our case the effective dynamics emerge implicitly through changes in the distribution over w induced by changes in the other parameters, w₋(t), rather than being built into the problem definition.
Next, we make two changes to the updates for the mean, to match current best practice for optimizing neural networks. First, we allow more flexibility in weight decay, by replacing the σ²_dyn/(2η²) term in Eq. (10a) with a new parameter, λ. Second, we incorporate momentum, by using the exponential moving average gradient, \bar{g}(t), instead of the raw minibatch gradient in Eq. (14b). In combination, the updates for the mean become,

    μ_prior(t) = (1 − λ) μ_post(t−1),    (15a)
    μ_post(t) = μ_prior(t) + σ²_post(t) \bar{g}(t).    (15b)
Our complete Bayesian updates are now given by using Eq. (15) to update μ_prior(t) and μ_post(t), and using Eq. (10b) and Eq. (14a) to update σ²_prior(t) and σ²_post(t) (see Algo. 1).
Finally, this derivation highlights the need to reason about the effective dynamics that emerge from changes in the other parameters, w₋(t). In particular, if there were no effective dynamics (i.e. if σ²_dyn = 0, so that μ_prior(t) = μ_post(t−1) and σ²_prior(t) = σ²_post(t−1)), then the learning rates would shrink too rapidly (as 1/t),

    σ²_post(t) = ( 1/η² + Σ_{t′=1}^{t} Λ(t′) )⁻¹ ≈ 1 / ( t ⟨Λ⟩ ).    (16)
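The complete AdaBayes updates above (Eqs. 10, 14 and 15) can be sketched per parameter as follows; `g_bar` and `g2_bar` are the debiased EMAs from Eqs. (2) and (4), and the default constants standing in for σ²_dyn, η² and λ are illustrative placeholders, not values from the text:

```python
import numpy as np

def adabayes_step(mu, sigma2, g_bar, g2_bar, s2_dyn=1e-6, eta2=1e-3, wd=0.0):
    """One AdaBayes update for the posterior mean and variance."""
    mu = (1.0 - wd) * mu                               # Eq. (15a): decayed mean
    sigma2 = sigma2 + s2_dyn * (1.0 - sigma2 / eta2)   # Eq. (10b): broaden variance
    sigma2 = 1.0 / (1.0 / sigma2 + g2_bar)             # Eq. (14a): Lambda(t) = g2_bar
    mu = mu + sigma2 * g_bar                           # Eq. (15b): adaptive step
    return mu, sigma2
```

Note that σ²_post(t) plays the role of a per-parameter learning rate: large gradients shrink it (Adam-like behaviour), while in the absence of data it relaxes towards the prior variance η² (SGD-like behaviour).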
2.1 AdaBayes-SS, and recovering SGD and Adam
While we can (and do) run AdaBayes directly, we also consider a steady-state variant, denoted AdaBayes-SS. We find that this variant has interesting performance characteristics in its own right (see Experiments), and it allows us to recover standard adaptive and non-adaptive algorithms. This variant uses the same updates for the mean (Eq. 15), but uses a different value for σ²_post(t). In particular, instead of running the usual updates for the variance (Eq. 10b and Eq. 14a), AdaBayes-SS (Algo. 2) uses the steady-state value of σ²_post implied by those equations (i.e. the value to which they converge as t → ∞). To obtain the steady state, we substitute the update for σ²_post (Eq. 14a) into the update for σ²_prior (Eq. 10b), and neglect small terms (see Appendix A), which tells us that σ²_post(t) is given by the solution of a quadratic equation,

    \overline{g²}(t) (σ²_post(t))² + (σ²_dyn/η²) σ²_post(t) − σ²_dyn = 0.    (17)

Solving for σ²_post(t), we obtain,

    σ²_post(t) = [ −σ²_dyn/η² + √( (σ²_dyn/η²)² + 4 \overline{g²}(t) σ²_dyn ) ] / ( 2 \overline{g²}(t) ).    (18)
To confirm that this steady-state expression is correct, we plotted the values for σ²_post(t) given by AdaBayes (Fig. 1, points) and found a good match to the values suggested by this steady-state expression (green line).
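The steady-state expression (Eq. 18) is easy to check numerically: for small \overline{g²} it returns the prior variance η², and for large \overline{g²} it returns √(σ²_dyn/\overline{g²}). The constants below are illustrative, not values from the text:

```python
import numpy as np

def sigma2_ss(g2_bar, s2_dyn=1e-6, eta2=1e-3):
    """Positive root of the steady-state quadratic, Eqs. (17)-(18)."""
    b = s2_dyn / eta2
    return (-b + np.sqrt(b**2 + 4.0 * g2_bar * s2_dyn)) / (2.0 * g2_bar)

# low-data limit -> eta2; high-data limit -> sqrt(s2_dyn / g2_bar)
print(sigma2_ss(1e-12), sigma2_ss(1e6))
```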
Algorithms 1 and 2 give the full listings for the AdaBayes and AdaBayes-SS algorithms. We compute the raw first and second moment exponential moving averages in g_avg(t) and g²_avg(t), and debias them to obtain \bar{g}(t) and \overline{g²}(t), as in Adam (Kingma & Ba, 2015).

2.1.1 Recovering SGD in the low-data limit
In the low-data regime, where \overline{g²}(t) ≪ σ²_dyn/η⁴, the Bayesian filtering learning rate (and, equivalently, uncertainty) is constant (Fig. 1; orange line),

    σ²_post(t) ≈ η²,    (19)

so the Bayesian filtering updates (Eq. 15b) become equivalent to vanilla SGD (Eq. 1). We can leverage this equivalence to set η² using standard values of the SGD learning rate,

    η² = η_SGD / B.    (20)
Setting η² in this way would suggest¹ η² ∼ 10⁻³, as η_SGD ∼ 10⁻¹ and B ∼ 10². It is important to sanity-check that this value of η² corresponds to Bayesian filtering in a sensible generative model. In particular, note that η² is the variance of the prior over w (Eq. 7), and as such should correspond to typical initialization schemes (e.g. He et al., 2015), which ensure that input and output activations have roughly the same scale. These schemes use a variance inversely proportional to the fan-in: if there are C_in input channels and we convolve over a K × K pixel patch, we obtain η² = 1/(C_in K²), which for typical values of C_in and K is again of order 10⁻³, matching the previous value.

¹ Here we use ∼, as in physics, to denote "has the same order of magnitude as"; see Acklam and Weisstein, "Tilde", MathWorld. http://mathworld.wolfram.com/Tilde.html
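As a purely illustrative order-of-magnitude check (the specific numbers below are assumptions chosen to be typical, not values from the text), η_SGD/B and a He-style fan-in variance do land in the same ballpark:

```python
# Hypothetical values: eta_sgd ~ 0.1, B ~ 128, and a conv layer with
# C_in = 64 input channels and a K = 3 kernel (all illustrative).
eta_sgd, B = 0.1, 128
C_in, K = 64, 3

eta2_from_sgd = eta_sgd / B             # Eq. (20)
eta2_from_init = 1.0 / (C_in * K * K)   # fan-in initialisation variance

print(eta2_from_sgd, eta2_from_init)    # both of order 1e-3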
Finally, this equivalence implies that — at least under this interpretation of the learning rate as the prior variance — SGD can be viewed as dimensionally consistent, in contrast to some claims from e.g. the natural gradient literature (Kakade, 2002).
2.1.2 Recovering Adam(W) in the high-data limit

In the high-data regime, where \overline{g²}(t) ≫ σ²_dyn/η⁴, the Bayesian filtering learning rate becomes (Fig. 1; blue line),

    σ²_post(t) ≈ √( σ²_dyn / \overline{g²}(t) ),    (21)
so the updates (Eq. 15b) become equivalent to Adam updates (Eq. 3) if we take,

    η_Adam = √σ²_dyn.    (22)

As such, we are able to use past experience with good values for the Adam learning rate, η_Adam, to set σ²_dyn: in our case we use the value of η_Adam matched to the optimal learning rate for standard Adam.
Furthermore, when we consider the form of regularisation implied by our updates, we recover a state-of-the-art variant of Adam, known as AdamW (Loshchilov & Hutter, 2017). In standard Adam, weight-decay regularization is implemented by incorporating an L2 penalty on the weights into the loss function, so the gradients of the loss and the regularizer are both normalized by the root-mean-square gradient. In contrast, AdamW "decouples" weight decay from the loss, such that the gradient of the loss is normalized by the root-mean-square gradient, but the weight decay is not. To see that our updates correspond to AdamW, we combine Eq. (15a) and Eq. (15b), and substitute for σ²_post(t) (Eq. 21),

    μ_post(t) = (1 − λ) μ_post(t−1) + η_Adam \bar{g}(t) / √( \overline{g²}(t) ).    (23)

Indeed, the root-mean-square normalization applies only to the gradient of the loss, as in AdamW, and not to the weight-decay term, as in standard Adam.
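The distinction can be seen directly in code: below is a sketch of the decoupled update of Eq. (23) next to standard Adam with an L2 penalty folded into the gradient. Names and default constants are illustrative, and the L2 variant is simplified in that the penalty gradient is not included in the second-moment estimate:

```python
import numpy as np

def adamw_style_step(mu, g_bar, g2_bar, eta_adam=1e-3, wd=1e-2, eps=1e-8):
    """Eq. (23): only the gradient is RMS-normalised; the decay is decoupled."""
    return (1.0 - wd) * mu + eta_adam * g_bar / (np.sqrt(g2_bar) + eps)

def adam_l2_step(mu, g_bar, g2_bar, eta_adam=1e-3, wd=1e-2, eps=1e-8):
    """Standard Adam with L2: the decay term is also RMS-normalised."""
    g_eff = g_bar - wd * mu  # ascent on log-likelihood minus L2 penalty
    return mu + eta_adam * g_eff / (np.sqrt(g2_bar) + eps)
```

With \bar{g} = 0, the decoupled step shrinks the weight by a fixed factor (1 − λ) regardless of the gradient scale, whereas the L2 step's shrinkage is rescaled by the RMS gradient.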
Finally, note that AdaBayes-SS becomes exactly AdamW when we set η² → ∞,

    σ²_post(t) = √( σ²_dyn / \overline{g²}(t) ) = η_Adam / √( \overline{g²}(t) ),    (24)

because we use the standard Adam(W) approach to computing unbiased estimates of \bar{g}(t) and \overline{g²}(t) (see Algo. 2).

3 Experiments
For our experiments, we adapted the code and protocols from a recent paper (Luo et al., 2019) on alternative methods for combining non-adaptive and adaptive behaviour (AdaBound and AMSBound). They considered a 34-layer ResNet (He et al., 2016) and a 121-layer DenseNet (Huang et al., 2017) on CIFAR-10, trained for 200 epochs with a learning rate that decreased by a fixed factor at epoch 150. We used the exact same networks and protocol, except that we ran for more epochs, we plot both classification error and the loss, and we use both CIFAR-10 and CIFAR-100. We used their optimized hyperparameter settings for the standard baselines (including SGD and Adam), and their choice of hyperparameters for their methods (AdaBound and AMSBound). For AdamW and AdaBayes, we set η² using Eq. (20), and we set η_Adam (and hence σ²_dyn; Eq. 22) to match the optimal learning rate for standard Adam. We used the decoupled weight decay from Luo et al. (2019), and we used the equivalence of SGD with weight decay and SGD with decoupled weight decay to set the decoupled weight-decay coefficient for AdamW, AdaBayes and AdaBayes-SS.

The results are given in Table 1 and Fig. 2. The best adaptive method is always one of our methods (AdaBayes or AdaBayes-SS), though SGD is frequently superior to all adaptive methods tested. To begin, we compare our methods (AdaBayes and AdaBayes-SS) to the canonical non-adaptive (SGD) and adaptive (Adam) methods (see Fig. A1 for a cleaner figure, including other baselines). Note that AdaBayes and AdaBayes-SS improve their accuracy and loss more rapidly than the baseline methods during the initial part of learning (with the exception of Adam for CIFAR-100). Our algorithms give better test error and loss than Adam for all networks and datasets; they give better test error than SGD for CIFAR-10, and perform similarly to SGD in the other cases, with AdaBayes-SS often giving better performance than AdaBayes. Next, we see that AdaBayes-SS improves considerably over AdaBayes in terms of test loss (see Fig. A2 for a cleaner figure), except in the case of CIFAR-10 classification error, where the difference is minimal. Next, we compared AdaBayes and AdaBayes-SS with AdaBound and AMSBound (Luo et al., 2019) (Fig. A3). We found that the learning curves for AdaBayes and AdaBound/AMSBound were very similar, with AdaBayes usually giving better test error. The behaviour of the loss curves was very different, with AdaBayes-SS giving considerably better test loss than either AdaBayes or AdaBound/AMSBound. Finally, we compared AdaBound/AMSBound with SGD (Fig. A4), finding that while AdaBound/AMSBound perform slightly better in terms of CIFAR-10 classification error, SGD performs better in all other test domains.
Table 1: Test error and loss for each optimizer.

                          CIFAR-10                      CIFAR-100
                 ResNet         DenseNet        ResNet         DenseNet
optimizer     error (%) loss  error (%) loss  error (%) loss  error (%) loss
SGD             5.170  0.174    5.580  0.177   22.710  0.833   21.290  0.774
Adam            7.110  0.239    6.690  0.230   27.590  1.049   26.640  1.074
AdaGrad         6.840  0.307    7.490  0.338   30.350  1.347   30.110  1.319
AMSGrad         6.720  0.239    6.170  0.234   27.430  1.033   25.850  1.103
AdaBound        5.140  0.220    4.850  0.210   23.060  1.004   22.210  1.050
AMSBound        4.940  0.210    4.960  0.219   23.000  1.003   22.360  1.017
AdamW           5.080  0.239    5.190  0.214   24.850  1.142   23.480  1.043
AdaBayes-SS     5.230  0.187    4.910  0.176   23.120  0.935   22.600  0.934
AdaBayes        4.840  0.229    4.560  0.222   22.920  0.969   22.090  1.079
4 Discussion
Here, we showed that Bayesian filtering recovers standard adaptive (AdamW) and non-adaptive (SGD) neural network optimization algorithms in the high- and low-data limits. Based on this insight, we provided two new algorithms, AdaBayes and AdaBayes-SS, which adaptively transition between SGD-like and Adam(W)-like behaviour, and converge more rapidly than standard adaptive methods such as Adam, despite giving generalisation performance similar to SGD. In our experiments, either AdaBayes or AdaBayes-SS outperformed other adaptive methods, including AdamW (Loshchilov & Hutter, 2017) and Ada/AMSBound (Luo et al., 2019), though SGD frequently outperformed all adaptive methods.
It is also important to note that we have tried to take a pragmatic, rather than purist Bayesian, approach to specifying the update rule. This emerges in several features of our updates. First, we introduced momentum by replacing the raw minibatch gradient with the exponential moving average gradient. That said, there may be scope to give a normative justification for introducing momentum, especially as there is nothing in the Bayesian framework that forces us to evaluate the gradient at the current posterior mean/mode. Second, the decoupled weight-decay coefficient should be fixed by σ²_dyn and η² (Eq. 10a), whereas we allowed it to be a free parameter. It remains to be seen whether a model which obeys the required constraints can equal the performance of those presented here. Third, we assumed that η² was constant across parameters. While this was necessary to match common practice in neural network optimization, in the Bayesian framework η² is the prior variance, which should change across parameters, as in standard initialization schemes. Fourth, we considered conditioning on the other parameters, rather than, as is typical in variational inference, taking a distribution over those parameters. As such, it will be interesting to consider whether the same reasoning can be extended to the case of variational inference. Finally, our derivations apply in the online setting, where each minibatch of data is assumed to be new. As such, it might be possible to slightly improve our updates by taking into account repeated data in the batch setting.
Finally, Bayesian filtering presents a novel approach to neural network optimization, and as such there are a variety of directions for future work. First, it should be possible to develop a variant of the method in which σ² represents the covariance of a full weight matrix, by exploiting Kronecker factorisation (Martens & Grosse, 2015; Grosse & Martens, 2016; Zhang et al., 2017). These schemes are potentially of interest because they give a principled description of how to compute the relevant covariance matrices even when we have seen only a small number of minibatches. Second, stochastic regularization has been shown to be extremely effective at reducing generalization error in neural networks, and this Bayesian interpretation of neural network optimization presents opportunities for new stochastic regularization schemes (Kutschireiter et al., 2015).
References

Duchi, J., Hazan, E., and Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12:2121–2159, 2011.
Feldkamp, L. A., Prokhorov, D. V., and Feldkamp, T. M. Simple and conditioned adaptive behavior from Kalman filter trained recurrent networks. Neural Networks, 16(5-6):683–689, 2003.
Graves, A. Generating sequences with recurrent neural networks. arXiv preprint arXiv:1308.0850, 2013.
Grosse, R. and Martens, J. A Kronecker-factored approximate Fisher matrix for convolution layers. In International Conference on Machine Learning, pp. 573–582, 2016.
He, K., Zhang, X., Ren, S., and Sun, J. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1026–1034, 2015.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hinton, G., Srivastava, N., and Swersky, K. Overview of mini-batch gradient descent. COURSERA: Neural Networks for Machine Learning, Lecture 6a, 2012.
Huang, G., Liu, Z., Van Der Maaten, L., and Weinberger, K. Q. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708, 2017.
Kakade, S. M. A natural policy gradient. In Advances in Neural Information Processing Systems, pp. 1531–1538, 2002.
Kalman, R. E. A new approach to linear filtering and prediction problems. Journal of Basic Engineering, 82(1):35–45, 1960.
Keskar, N. S. and Socher, R. Improving generalization performance by switching from Adam to SGD. arXiv preprint arXiv:1712.07628, 2017.
Khan, M. E., Liu, Z., Tangkaratt, V., and Gal, Y. Vprop: Variational inference using RMSprop. arXiv preprint arXiv:1712.01038, 2017.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Kingma, D. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
Kutschireiter, A., Surace, S. C., Sprekeler, H., and Pfister, J.-P. The neural particle filter. arXiv preprint arXiv:1508.06818, 2015.
Loshchilov, I. and Hutter, F. Fixing weight decay regularization in Adam. arXiv preprint arXiv:1711.05101, 2017.
Luo, L., Xiong, Y., Liu, Y., and Sun, X. Adaptive gradient methods with dynamic bound of learning rate. arXiv preprint arXiv:1902.09843, 2019.
Mandt, S., Hoffman, M. D., and Blei, D. M. Stochastic gradient descent as approximate Bayesian inference. The Journal of Machine Learning Research, 18(1):4873–4907, 2017.
Martens, J. and Grosse, R. Optimizing neural networks with Kronecker-factored approximate curvature. In International Conference on Machine Learning, pp. 2408–2417, 2015.
Neuneier, R. and Zimmermann, H. G. How to train neural networks. In Neural Networks: Tricks of the Trade, pp. 373–423. Springer, 1998.
Ollivier, Y. Online natural gradient as a Kalman filter. arXiv preprint arXiv:1703.00209, 2017.
Puskorius, G. V. and Feldkamp, L. A. Decoupled extended Kalman filter training of feedforward layered networks. In International Joint Conference on Neural Networks (IJCNN-91-Seattle), volume 1, pp. 771–777. IEEE, 1991.
Puskorius, G. V. and Feldkamp, L. A. Neurocontrol of nonlinear dynamical systems with Kalman filter trained recurrent networks. IEEE Transactions on Neural Networks, 5(2):279–297, 1994.
Puskorius, G. V. and Feldkamp, L. A. Parameter-based Kalman filter training: theory and implementation. In Kalman Filtering and Neural Networks. 2001.
Reddi, S. J., Kale, S., and Kumar, S. On the convergence of Adam and beyond. In International Conference on Learning Representations, 2018.
Sha, S., Palmieri, F., and Datum, M. Optimal filtering algorithms for fast learning in feedforward neural networks. Neural Networks, 1992.
Wilson, A. C., Roelofs, R., Stern, M., Srebro, N., and Recht, B. The marginal value of adaptive gradient methods in machine learning. In Advances in Neural Information Processing Systems, pp. 4148–4158, 2017.
Xie, S., Girshick, R., Dollár, P., Tu, Z., and He, K. Aggregated residual transformations for deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1492–1500, 2017.
Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy natural gradient as variational inference. arXiv preprint arXiv:1712.02390, 2017.
Appendix A Steady-state covariance
For the steady-state covariance, it is slightly more convenient to work with the inverse variance,

    λ_post(t) ≡ 1/σ²_post(t),    (25)

though the same results can be obtained through either route. Substituting Eq. (10b) into Eq. (14a), we obtain an update from σ²_post(t−1) to σ²_post(t),

    1/σ²_post(t) = 1/( σ²_post(t−1) + σ²_dyn (1 − σ²_post(t−1)/η²) ) + Λ(t);    (26)

assuming σ²_post has reached steady state, we have σ²_post(t) = σ²_post(t−1) = σ²_post,

    1/σ²_post = 1/( σ²_post + σ²_dyn (1 − σ²_post/η²) ) + Λ.    (27)

Rearranging,

    1/σ²_post − Λ = 1/( σ²_post + σ²_dyn (1 − σ²_post/η²) ).    (28)

Assuming that the magnitude of the update to σ²_post is small (i.e. σ²_dyn (1 − σ²_post/η²) ≪ σ²_post), we can take a first-order Taylor expansion of the right-hand side,

    1/σ²_post − Λ ≈ 1/σ²_post − σ²_dyn (1 − σ²_post/η²) / (σ²_post)²;    (29)

cancelling the 1/σ²_post terms,

    Λ = σ²_dyn (1 − σ²_post/η²) / (σ²_post)²;    (30)

and rearranging,

    Λ (σ²_post)² + (σ²_dyn/η²) σ²_post − σ²_dyn = 0;    (31)

and finally substituting Λ(t) = \overline{g²}(t) gives the expression in the main text.
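The algebra above can be sanity-checked numerically: for arbitrary (illustrative) constants, the positive root of the quadratic satisfies Eq. (31) up to floating-point error:

```python
import numpy as np

# Illustrative constants standing in for Lambda, sigma2_dyn and eta^2.
lam, s2_dyn, eta2 = 10.0, 1e-6, 1e-3
b = s2_dyn / eta2
sigma2 = (-b + np.sqrt(b**2 + 4.0 * lam * s2_dyn)) / (2.0 * lam)  # Eq. (18)
residual = lam * sigma2**2 + b * sigma2 - s2_dyn                  # Eq. (31)
print(residual)  # should be ~0
```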
Appendix B Additional data figures
Here, we replot Fig. 2 to clarify particular comparisons. In particular, we compare AdaBayes(-SS) with standard baselines (Fig. A1); Adam, AdamW and AdaBayes-SS (Fig. A2); AdaBayes(-SS) and Ada/AMSBound (Fig. A3); and Ada/AMSBound and SGD (Fig. A4). Finally, we plot the training error and loss for all methods (Fig. A5; note that the loss does not include the regularizer, so it may go up without this being evidence of overfitting).