1 Introduction
Optimization of functions involving large datasets and high-dimensional models is today widely applicable in several data-driven fields in science and industry. Given the growing role of deep learning, in this paper we look at optimization problems arising in the training of neural networks. The training of these models can be cast as the minimization or maximization of a certain objective function with respect to the model parameters. Because of the complexity and computational requirements of the objective function, the data and the models, the common practice is to resort to iterative training procedures, such as gradient descent. Among the iterative methods, stochastic gradient descent (SGD) [61] emerged as one of the most effective and computationally efficient. SGD owes its performance gains to the adoption of an approximate version of the objective function at each iteration step, which, in turn, yields an approximate or noisy gradient. While SGD seems to benefit greatly (e.g., in terms of rate of convergence) from such approximation, it has also been shown that too much noise hurts the performance [83, 6]. This suggests that, to further improve over SGD, one could attempt to model the noise of the objective function. We consider the iteration-time-varying loss function used in SGD as a stochastic process obtained by adding the expected risk to zero-mean Gaussian noise. A powerful approach designed to handle estimation with such processes is Kalman filtering [36]. The idea of using Kalman filtering to train neural networks is not new [24]. However, the way to apply it to this task can vary vastly. Indeed, in our approach, which we call KaFiStO, we introduce a number of novel ideas that result in a practical and effective training algorithm. Firstly, we introduce drastic approximations of the estimated covariance of Kalman's dynamical state, so that the corresponding matrix is described by only a few scalar parameters. Secondly, we approximate intermediate Kalman filtering calculations so that more accuracy can be achieved. Thirdly, because of the way we model the objective function, we can also define a schedule for the optimization that behaves similarly to learning rate schedules used in SGD and other iterative methods [37].
We highlight the following contributions: 1) KaFiStO is designed to handle high-dimensional data and models, and large datasets; 2) the tuning of the algorithm is automated, but it is also possible to introduce a learning rate schedule similar to those in existing methods, albeit with a very different interpretation; 3) KaFiStO adapts automatically to the noise in the loss, which might vary depending on the settings of the training (e.g., the mini-batch size), and to the variation in the estimated weights over iteration time; 4) it can incorporate iteration-time dynamics of the model parameters, which are analogous to momentum [74]; 5) it is a framework that can be easily extended (we show a few variations of KaFiStO); 6) as shown in our experiments, KaFiStO is on par with state-of-the-art optimizers and can yield better minima in a number of problems ranging from image classification to generative adversarial networks (GANs) and natural language processing (NLP).
2 Prior Work
In this section, we review optimization methods that have found application in machine learning, in particular for large-scale problems. Most of the progress in the last decades has aimed at improving the efficiency and accuracy of optimization algorithms.
First-Order Methods.
First-order methods exploit only the gradient of the objective function. The main advantage of these methods lies in their speed and simplicity. Robbins and Monro [61] introduced the very first stochastic optimization method (SGD) in 1951. Since then, the SGD method has been thoroughly analyzed and extended [42, 67, 32, 73]. Some works considered restarting techniques for optimization purposes [46, 82].
However, a limitation of SGD is that the learning rate must be manually defined and the approximations in the computation of the gradient hurt the performance.
Second-Order Methods. To address the manual tuning of the learning rates in first-order methods and to improve the convergence rate, second-order methods rely on the Hessian matrix.
However, this matrix becomes very quickly unmanageable as it grows quadratically with the number of model parameters.
Thus, most work reduces the computational complexity by approximating the Hessian with a block-diagonal matrix [20, 5, 39]. A number of methods looked at combining second-order information in different ways. For example, Roux and Fitzgibbon [63] combined Newton's method and natural gradient. Sohl-Dickstein [70] combined SGD with the second-order curvature information leveraged by quasi-Newton methods. Yao [89] dynamically incorporated the curvature of the loss via adaptive estimates of the Hessian. Henriques [27] proposed a method that does not require storing the Hessian at all. In contrast with these methods, KaFiStO does not compute second-order derivatives, but focuses instead on modeling the noise in the objective function.
Adaptive Methods.
An alternative to using second-order derivatives is to design methods that automatically adjust the step-size during the optimization process.
The adaptive selection of the update step-size has been based on several principles, including: the local sharpness of the loss function [91], incorporating a line search approach [80, 53, 49], the gradient change speed [15], the Barzilai-Borwein method [76], a "belief" in the current gradient direction [100], the linearization of the loss [62], the per-component unweighted mean of all historical gradients [10], handling noise by preconditioning based on a covariance matrix [34], the adaptive and momental bounds [13], decorrelating the second moment and gradient terms [99], the importance weights [40], the layer-wise adaptation strategy [90], the gradient scale invariance [55], multiple learning rates [66], controlling the increase in effective learning [93], learning the update-step size [87], and looking ahead at the sequence of fast weights generated by another optimizer [98]. Among the widely adopted methods is the work of Duchi [16], who presented a new family of subgradient methods called AdaGrad. AdaGrad dynamically incorporates knowledge of the geometry of the data observed in earlier iterations. Tieleman [77] introduced RMSProp, further extended by Mukkamala [52] with logarithmic regret bounds for strongly convex functions. Zeiler [94] proposed a per-dimension learning rate method for gradient descent called AdaDelta. Kingma and Ba [37] introduced Adam, based on adaptive estimates of lower-order moments. A wide range of variations and extensions of the original Adam optimizer has also been proposed [44, 60, 28, 47, 78, 86, 14, 72, 8, 33, 43, 48, 84, 85, 41]. Recent work proposed to decouple the weight decay [22, 17]. Chen [7] introduced a partially adaptive momentum estimation method. Some recent work also focused on the role of gradient clipping [97, 96]. Another line of research focused on reducing the memory overhead for adaptive algorithms [1, 69, 56]. In most prior work, adaptivity comes from the introduction of extra hyperparameters that also require task-specific tuning. In our case, this property is a direct byproduct of the Kalman filtering framework.
Kalman Filtering.
The use of Kalman filtering theory and methods for the training of neural networks is not new. Haykin [35] edited a book collecting a wide range of techniques on this topic. More recently, Shashua [68] incorporates Kalman filtering for value approximation in reinforcement learning. Ollivier [54] recovers the exact extended Kalman filter equations from first principles in statistical learning: the Extended Kalman filter is equal to Amari's online natural gradient, applied in the space of trajectories of the system. Vilmarest [11] applies the Extended Kalman filter to linear and logistic regressions. Takenga [75] compared GD to methods based on either Kalman filtering or the decoupled Kalman filter. To summarize, all of these prior Kalman filtering approaches either focus on a specific non-general formulation or face difficulties when scaling to the high-dimensional parameter spaces of large-scale neural models.
3 Modeling Noise in Stochastic Optimization
In machine learning, we are interested in minimizing the expected risk
$L(\theta) = \mathbb{E}_{x \sim p(x)}\left[\ell(x, \theta)\right]$   (1)
with respect to some loss $\ell$ that is a function of both the data $x \in \mathbb{R}^d$, with $d$ the data dimensionality, and the model parameters $\theta \in \mathbb{R}^n$ (e.g., the weights of a neural network), where $n$ is the number of parameters in the model. We consider the case, which is of common interest today, where both $d$ and $n$ are very large. For notational simplicity, we do not distinguish the supervised and unsupervised learning cases, by concatenating all data into a single vector $x$ (e.g., in the case of image classification we stack in $x$ both the input image and the output label). In practice, we have access to only a finite set of samples and thus resort to optimizing the empirical risk

$\hat{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \ell(x_i, \theta)$   (2)

where $x_i$, for $i = 1, \dots, N$, are our training dataset samples. Because of the non-convex nature of the loss function with respect to the model parameters, this risk is then optimized via an iterative method such as gradient descent.
Since $N$ can be very large in current datasets, the computation of the gradient of the empirical risk at each iteration is too demanding. To address this issue, the stochastic gradient descent (SGD) method [61] minimizes instead the following risk approximation
$L_t(\theta) = \frac{1}{|\mathcal{M}_t|}\sum_{i \in \mathcal{M}_t} \ell(x_i, \theta)$   (3)
where $\mathcal{M}_t \subset \{1, \dots, N\}$ is a sample set of the dataset indices that changes over the iteration time $t$. SGD then iteratively builds a sequence of parameters $\{\theta_t\}$ by recursively updating the parameters with a step in the opposite direction of the gradient of $L_t$, with some random initialization for $\theta_0$ and, for $t \geq 1$,
$\theta_t = \theta_{t-1} - \eta\, \nabla L_t(\theta_{t-1})$   (4)
where $\nabla L_t(\theta_{t-1})$ denotes the gradient of $L_t$ with respect to $\theta$ computed at $\theta_{t-1}$, and $\eta > 0$ is commonly referred to as the learning rate and regulates the speed of convergence.
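As a purely illustrative aside (not part of the paper), the SGD recursion of eq. (4) can be sketched on a toy least-squares problem; the data, mini-batch size and learning rate below are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: N samples, d features, generated from a known linear model.
N, d = 256, 4
X = rng.normal(size=(N, d))
theta_true = rng.normal(size=d)
y = X @ theta_true + 0.01 * rng.normal(size=N)

def minibatch_grad(theta, idx):
    """Gradient of the mini-batch least-squares risk L_t (cf. eq. (3))."""
    Xb, yb = X[idx], y[idx]
    return 2.0 / len(idx) * Xb.T @ (Xb @ theta - yb)

eta = 0.05           # learning rate
theta = np.zeros(d)  # initialization theta_0
for t in range(500):
    idx = rng.choice(N, size=32, replace=False)  # mini-batch M_t
    theta -= eta * minibatch_grad(theta, idx)    # eq. (4)

print(np.linalg.norm(theta - theta_true))  # small: close to the optimum
```

The gradients computed on each mini-batch are exactly the "noisy versions" of the full-risk gradient discussed next.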
While this approach is highly efficient, it is also affected by the training set sampling at each iteration. The gradients are computed on the time-varying objectives $L_t$ and can be seen as noisy versions of the gradient of the expected risk $L$. Due to the aleatoric nature of this optimization, it is necessary to apply a learning rate decay [4] to achieve convergence. There are also several methods to reduce noise in the gradients, which work on a dynamic sample size or a gradient aggregation strategy [6].
Rather than directly modeling the noise of the gradient, we adopt a different perspective on the minimization of the expected risk. Let us denote with $\theta^*$ the optimal set of model parameters and with $L^* = L(\theta^*)$ the expected risk at $\theta^*$. Then, we model

$L^* = L_t(\theta) - v_t$   (5)

where $\theta \sim \mathcal{N}(\theta^*, P)$ is a Gaussian random variable with the optimal parameters $\theta^*$ as the mean and covariance $P$, and we associate both the stochasticity of the sampling and of $\theta$ to the scalar noise variable $v_t$, which we assume to be zero-mean Gaussian with variance $R$. With eq. (5) we implicitly assume that $\theta$ and $v_t$ are statistically dependent, since $L^*$ is constant. Often, we know the value of $L^*$ up to some approximation. In the next sections we additionally show that it is possible to obtain an online estimate of $L^*$. The task now can be posed as that of identifying the parameters $\theta$ such that the observations (5) are satisfied. A natural way to tackle the identification of parameters given their noisy observations is to use Kalman filtering. As discussed in the prior work, there is an extensive literature on the application of Kalman filtering as a stochastic gradient descent algorithm. However, these methods differ from our approach in several ways. For instance, Vuckovic [81] uses the gradients as measurements. Thus, this method requires large matrix inversions, which are not scalable to the settings we consider in this paper and that are commonly used in deep learning. As we describe in the next section, we work instead directly with the scalar risks and introduce a number of computational approximations that make the training with large datasets and high-dimensional data feasible with our method.
3.1 Kalman Filtering for Stochastic Optimization
We assumed that $\theta$ is a random variable capturing the optimum $\theta^*$ up to some zero-mean Gaussian error (which represents our uncertainty about the parameters). Then, the values of the time-varying loss $L_k$ at samples of $\theta$ will be scattered close to $L^*$ (see eq. (5)). Thus, a possible system of equations for a sequence of samples $\theta_k$ of $\theta$ is

$\theta_k = \theta_{k-1} + w_k$   (6)
$L^* = L_k(\theta_k) - v_k$   (7)

Here, $w_k$ is modeled as a zero-mean Gaussian variable with covariance $Q$. The dynamical model (6) implies that the state does not change on average.
The equations (6) and (7) fit very well the equations used in Kalman filtering [36]. For completeness, we briefly recall here the general equations for an Extended Kalman filter

$x_k = f(x_{k-1}) + w_k$   (8)
$z_k = h(x_k) + v_k$   (9)

where $x_k$ is also called the hidden state, $z_k$ are the observations, and $f$ and $h$ are functions that describe the state transition and the measurement dynamics respectively. The Extended Kalman filter infers optimal estimates $\hat{x}_{k|k}$ of the state variables from the previous estimates $\hat{x}_{k-1|k-1}$ and the last observation $z_k$. Moreover, it also estimates the a posteriori covariance matrix $P_{k|k}$ of the state. This is done in two steps: Predict and Update, which we recall in Table 1.
Predict:  $\hat{x}_{k|k-1} = f(\hat{x}_{k-1|k-1})$,  $P_{k|k-1} = F_k P_{k-1|k-1} F_k^\top + Q$
Update:  $\hat{x}_{k|k} = \hat{x}_{k|k-1} + K_k y_k$,  $P_{k|k} = (I - K_k H_k) P_{k|k-1}$
with:  $y_k = z_k - h(\hat{x}_{k|k-1})$,  $S_k = H_k P_{k|k-1} H_k^\top + R$,  $K_k = P_{k|k-1} H_k^\top S_k^{-1}$,
where $F_k$ and $H_k$ denote the Jacobians of $f$ and $h$ evaluated at $\hat{x}_{k-1|k-1}$ and $\hat{x}_{k|k-1}$ respectively.
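To make the Predict and Update steps of Table 1 concrete, here is a minimal numpy sketch of one generic EKF iteration; the function names and the 1D random-walk usage example are our own illustrative assumptions, not part of the paper:

```python
import numpy as np

def ekf_step(x, P, z, f, F, h, H, Q, R):
    """One Extended Kalman filter iteration (Predict + Update) as in Table 1.

    x, P : previous state estimate and its covariance
    z    : new observation
    f, h : state-transition and measurement functions
    F, H : their Jacobians evaluated at the current estimates
    Q, R : state and measurement noise covariances
    """
    # Predict
    x_pred = f(x)
    Fk = F(x)
    P_pred = Fk @ P @ Fk.T + Q
    # Update
    Hk = np.atleast_2d(H(x_pred))
    y = np.atleast_1d(z - h(x_pred))       # innovation
    S = Hk @ P_pred @ Hk.T + R             # innovation covariance
    K = P_pred @ Hk.T @ np.linalg.inv(S)   # Kalman gain
    x_new = x_pred + K @ y
    P_new = (np.eye(len(x)) - K @ Hk) @ P_pred
    return x_new, P_new

# Usage on a 1D random-walk state observed directly (all constants arbitrary):
x, P = np.zeros(1), np.eye(1)
f = lambda x: x; F = lambda x: np.eye(1)
h = lambda x: x[0]; H = lambda x: np.array([[1.0]])
for z in [1.0, 1.1, 0.9, 1.0]:
    x, P = ekf_step(x, P, z, f, F, h, H, Q=0.01 * np.eye(1), R=0.1 * np.eye(1))
```

After a few observations the state estimate tracks the measurements while the covariance shrinks, which is the behavior the approximations below try to retain at scale.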
If we directly apply the equations in Table 1 to our equations (6) and (7), we would immediately find that the posterior covariance $P_{k|k}$ is an $n \times n$ matrix, which would be too large to store and update for values of $n$ used in practice. Hence, we approximate $P_{k|k}$ as a scaled identity matrix $\sigma^2_{k|k} I_n$. Since the update equation for the posterior covariance requires the computation of $\nabla L_k \nabla L_k^\top$, we need to approximate $\nabla L_k \nabla L_k^\top$ also with a scaled identity matrix. We do this by using its largest eigenvalue, $\|\nabla L_k\|^2$,

$\nabla L_k(\hat{\theta}_{k|k-1})\, \nabla L_k(\hat{\theta}_{k|k-1})^\top \simeq \|\nabla L_k(\hat{\theta}_{k|k-1})\|^2\, I_n$   (10)

where $I_n$ denotes the $n \times n$ identity matrix. Because we work with a scalar loss $L_k$, the innovation covariance $S_k$ is a scalar and thus it can be easily inverted. We call this first parameter estimation method the Vanilla Kalman algorithm, and summarize it in Algorithm 1.
$\hat{\theta}_{k|k-1} = \hat{\theta}_{k-1|k-1}, \qquad \sigma^2_{k|k-1} = \sigma^2_{k-1|k-1} + q$   (11)
$s_k = \sigma^2_{k|k-1}\, \|\nabla L_k(\hat{\theta}_{k|k-1})\|^2 + R$   (12)
$\hat{\theta}_{k|k} = \hat{\theta}_{k|k-1} - \frac{\sigma^2_{k|k-1}}{s_k}\left(L_k(\hat{\theta}_{k|k-1}) - L^*\right)\nabla L_k(\hat{\theta}_{k|k-1})$   (13)
$\sigma^2_{k|k} = \left(1 - \frac{\sigma^2_{k|k-1}\,\|\nabla L_k(\hat{\theta}_{k|k-1})\|^2}{s_k}\right)\sigma^2_{k|k-1}$   (14)
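A minimal sketch of one such Vanilla Kalman step, with the posterior covariance collapsed to a single scalar variance as described above; the constants q and R, the target L* = 0, and the toy quadratic loss are illustrative assumptions rather than the paper's settings:

```python
import numpy as np

def vanilla_kalman_step(theta, sigma2, loss_fn, grad_fn,
                        L_target=0.0, q=0.01, R=0.1):
    """One Vanilla Kalman parameter update under the scaled-identity
    approximation of the posterior covariance (a sketch of eqs. (11)-(14));
    q, R and L_target are illustrative constants."""
    sigma2_pred = sigma2 + q                     # predict step for the variance
    L = loss_fn(theta)
    g = grad_fn(theta)
    s = sigma2_pred * (g @ g) + R                # scalar innovation covariance
    theta = theta - sigma2_pred * (L - L_target) / s * g      # state update
    sigma2 = (1.0 - sigma2_pred * (g @ g) / s) * sigma2_pred  # posterior variance
    return theta, sigma2

# Usage on a toy quadratic loss with minimum value 0 at theta = [1, -2]:
opt = np.array([1.0, -2.0])
loss_fn = lambda th: float(np.sum((th - opt) ** 2))
grad_fn = lambda th: 2.0 * (th - opt)
theta, sigma2 = np.zeros(2), 1.0
for _ in range(200):
    theta, sigma2 = vanilla_kalman_step(theta, sigma2, loss_fn, grad_fn)
```

Note that the effective step size is driven by the variance estimate and the innovation, not by a hand-tuned learning rate: when the loss is far from the target, the variance keeps the updates large.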
3.2 Incorporating Momentum Dynamics
The framework introduced so far, which we call KaFiStO (as a shorthand notation for Kalman Filtering for Stochastic Optimization), is very flexible and allows several extensions. A first important change we introduce is the incorporation of Momentum [74]. Within our notation, this method could be written as
$u_k = \mu u_{k-1} - \eta\, \nabla L_k(\theta_{k-1})$   (15)
$\theta_k = \theta_{k-1} + u_k$   (16)

where $u_k$ are the so-called momenta or velocities, which accumulate the gradients from the past. The parameter $\mu$, commonly referred to as the momentum rate, controls the trade-off between current and past gradients. Such updates are reported to stabilize the training and prevent the parameters from getting stuck at local minima.
To incorporate the idea of Momentum within the KaFiStO framework, one can simply introduce the state velocities $u_k$ and define the following dynamics
$\theta_k = \theta_{k-1} + \gamma u_{k-1} + w_k^{\theta}$   (17)
$u_k = \gamma u_{k-1} + w_k^{u}$   (18)
$L^* = L_k(\theta_k) - v_k$   (19)

where $\gamma \in [0, 1]$ and $w_k = [w_k^{\theta}; w_k^{u}]$ is a zero-centered Gaussian random variable.
One can rewrite these equations again as Kalman filter equations by combining the parameters and the velocities into one state vector $x_k = [\theta_k; u_k]$, and similarly for the state noise $w_k$. This results in the following dynamical system

$x_k = A x_{k-1} + w_k$   (20)
$L^* = L_k(\Pi x_k) - v_k$   (21)

where $A = \begin{bmatrix} I_n & \gamma I_n \\ 0 & \gamma I_n \end{bmatrix}$ and $\Pi = [I_n \;\; 0]$ is the projection that extracts $\theta_k$ from $x_k$. Similarly to the Vanilla Kalman algorithm, we also aim to drastically reduce the dimensionality of the posterior covariance, which now is a $2n \times 2n$ matrix. We approximate $P_{k|k}$ with the following form $P_{k|k} = \begin{bmatrix} \sigma^2_{11,k} I_n & \sigma^2_{12,k} I_n \\ \sigma^2_{12,k} I_n & \sigma^2_{22,k} I_n \end{bmatrix}$, where $\sigma^2_{11,k}$, $\sigma^2_{12,k}$, $\sigma^2_{22,k}$ are scalars. In this formulation we have that $H_k = [\nabla L_k^\top \;\; 0]$ and thus our approximation for the Kalman update of the posterior covariance will use

$H_k P_{k|k-1} H_k^\top \simeq \sigma^2_{11,k|k-1}\, \|\nabla L_k(\hat{\theta}_{k|k-1})\|^2$   (22)
The remaining equations follow directly from the application of Table 1. We call this method the Kalman Dynamics algorithm.
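As a sketch of the momentum-augmented dynamics, the stacked state can be propagated with a block transition matrix; the specific block structure of A below follows our reading of the dynamics in this section, and gamma and the dimensions are arbitrary illustrations:

```python
import numpy as np

# Block transition matrix for the stacked state x = [theta; u]:
# theta is moved by gamma * u, and the velocity decays by gamma.
n, gamma = 3, 0.9
I = np.eye(n)
A = np.block([[I, gamma * I],
              [np.zeros((n, n)), gamma * I]])

theta = np.ones(n)
u = np.full(n, 0.5)
x = np.concatenate([theta, u])

x_next = A @ x
theta_next, u_next = x_next[:n], x_next[n:]
```

In the full algorithm this noiseless propagation is the Predict step; the Update step then corrects both the parameters and the velocities through the Kalman gain.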
3.3 Estimation of the Measurement and State Noise
In the KaFiStO framework we model the noise in the observations and the state transitions with zero-mean Gaussian variables with covariances $R$ and $Q$ respectively. So far, we assumed that these covariances were constant. However, they can also be estimated online, and lead to more accurate state and posterior covariance estimates. For $R$ we use the following running average

$R_k = (1 - \alpha_R)\, R_{k-1} + \alpha_R \left(L_k(\hat{\theta}_{k|k-1}) - L^*\right)^2$   (23)

where we set $\alpha_R$ to a fixed constant. Similarly, for the covariance $Q$, the online update for its scale $q_k$ is

$\tilde{q}_k = \frac{1}{n}\, \|\hat{\theta}_{k|k} - \hat{\theta}_{k-1|k-1}\|^2$   (24)
$q_k = (1 - \alpha_Q)\, q_{k-1} + \alpha_Q\, \tilde{q}_k$   (25)

where we set $\alpha_Q$ to a fixed constant. This adaptivity of the noise helps both to reduce the number of hyperparameters and to stabilize the training and convergence.
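The online noise estimates amount to exponential moving averages of the squared innovation (for R) and of the mean squared parameter change (for the scale of Q). A sketch, with illustrative EMA coefficients and variable names since the paper's exact values are not fixed here:

```python
def update_R(R_prev, loss_value, L_target, alpha=0.1):
    """EMA of the squared innovation, used as measurement variance R (cf. eq. (23))."""
    return (1.0 - alpha) * R_prev + alpha * (loss_value - L_target) ** 2

def update_q(q_prev, theta_new, theta_old, alpha=0.1):
    """EMA of the mean squared parameter change, used as state-noise scale q
    (cf. eqs. (24)-(25))."""
    n = len(theta_new)
    step2 = sum((a - b) ** 2 for a, b in zip(theta_new, theta_old)) / n
    return (1.0 - alpha) * q_prev + alpha * step2

# Usage: the estimates track the observed loss and parameter statistics.
R = 1.0
for L in [0.5, 0.4, 0.3]:
    R = update_R(R, L, L_target=0.0)
q = update_q(0.0, [1.0, 2.0], [0.9, 1.9])
```

Large innovations inflate R (distrusting the measurements), while large parameter moves inflate q (allowing the state to drift faster), which is what makes the filter self-calibrating.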
3.4 Learning Rate Scheduling
In both the Vanilla Kalman and the Kalman Dynamics algorithms, the update equation for the state estimate needs the target risk $L^*$ (see, e.g., eq. (13)). This term can in many cases be set to $0$, when we believe that this value can be achieved (e.g., in some image classification problems). Also, we have the option to change $L^*$ progressively with the iteration time $k$. For instance, we could set $L^*_k = (1 - \eta_k)\, L_k(\hat{\theta}_{k|k-1})$. By substituting this term in eq. (13) we obtain a learning rate that is $\eta_k$ times the learning rate with $L^* = 0$. By varying $\eta_k$ over the iteration time, we thus can define a learning rate schedule as in current SGD implementations [19, 46]. Notice, however, the very different interpretation of the schedule in the case of KaFiStO, where we are gradually decreasing the target expected risk.
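Assuming a target of the form L*_k = (1 − η_k)·L_k, the innovation (and hence the effective step) is scaled by η_k; a tiny sketch with arbitrary schedule values:

```python
def scheduled_target(loss_value, eta_k):
    """Target L*_k = (1 - eta_k) * L_k, so that L_k - L*_k = eta_k * L_k."""
    return (1.0 - eta_k) * loss_value

loss_value = 2.0
for eta_k in [1.0, 0.5, 0.1]:  # e.g., a decaying schedule
    innovation = loss_value - scheduled_target(loss_value, eta_k)
    # the innovation shrinks proportionally to eta_k
```

Stepping eta_k down over training thus mimics a conventional learning rate decay while keeping the Kalman interpretation of the update.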
3.5 Layer-wise Approximations
Let us consider the optimization problem specifically for large neural networks. Let us denote with $M$ the number of layers in a network. Next, we consider the observations to be $M$-dimensional vectors. The $l$-th entry in this observation vector is obtained by considering that only the parameters of the $l$-th layer are varying. Under these assumptions, the update equation (13) for both the Vanilla Kalman and the Kalman Dynamics algorithm will split into $M$ layer-wise equations, where each separate equation incorporates only the gradients with respect to the parameters of a specific layer. Additionally, now the matrix $H_k P_{k|k-1} H_k^\top$ also yields $M$ separate blocks (one per observation), each of which gets approximated by the corresponding largest block eigenvalue. Finally, the maximum of these approximations gives us the approximation of the whole matrix. That is

$s_k = \max_{l=1,\dots,M} s_k^{(l)}, \qquad s_k^{(l)} = \sigma^2_{k|k-1}\, \|\nabla_{\theta^{(l)}} L_k\|^2 + R$   (26)

where $\theta^{(l)}$ is the subset of parameters corresponding to the $l$-th layer and $s_k^{(l)}$ is the innovation covariance corresponding to only the $l$-th measurement. We observe that this procedure induces additional stability in training.
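A sketch of this layer-wise computation: one scalar innovation covariance per layer from that layer's gradient norm, then the maximum over layers; sigma2 and R are illustrative constants:

```python
import numpy as np

def layerwise_innovation(grads_per_layer, sigma2=0.1, R=0.1):
    """Per-layer innovation covariances and their maximum (cf. eq. (26));
    sigma2 and R are illustrative constants."""
    s_layers = [sigma2 * float(g @ g) + R for g in grads_per_layer]
    return max(s_layers), s_layers

# Usage with three hypothetical layer gradients:
grads = [np.array([0.1, -0.2]), np.array([1.0, 0.5]), np.array([0.0, 0.3])]
s_max, s_layers = layerwise_innovation(grads)
```

Taking the maximum is conservative: the layer with the largest gradient energy bounds the gain for all layers, which is consistent with the added training stability noted above.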
4 Ablations
In this section we ablate the following features and parameters of both the Vanilla Kalman and Kalman Dynamics algorithms: the dynamics of the weights and velocities, the initialization of the posterior covariance matrix, and the adaptivity of the measurement and state noise estimators. In some ablations we also separately test the Kalman Dynamics algorithm with adaptive $Q$, since it usually gives a large boost to performance. Furthermore, we show that our algorithm is relatively insensitive to different batch sizes and weight initialization techniques.
We evaluate our optimization methods by computing the test performance achieved by the model obtained with the estimated parameters. Although such performance may not uniquely correlate with the performance of our method, as it might also be affected by the data, model and regularization, it is a useful indicator. In all the ablations, we choose the classification task on CIFAR100 [38] with ResNet18 [25]. We train all the models for a fixed budget of epochs and decrease the learning rate by a constant factor on a step schedule. For the last two ablations and in the Experiments section, we use the Kalman Dynamics algorithm with $\gamma = 0.9$, adaptive $R$ and $Q$, and fixed initial posterior covariance parameters $\sigma^2_{11,0}$ and $\sigma^2_{22,0}$. We refer to this configuration as KaFiStO and have no need to tune it further.
Impact of the state dynamics. We compare the Vanilla Kalman algorithm (i.e., constant dynamics) to the Kalman Dynamics (i.e., with velocities). Additionally, we ablate $\gamma$, the decay rate of the velocities. The results are shown in Table 2. We observe that the use of velocities with a calibrated momentum rate has a positive impact on the estimated parameters, and that the adaptive state noise estimation provides a further substantial gain.
KaFiStO Variant  $\gamma$  Top1 Error  Top5 Error

Vanilla  —  34.16  12.02
Dynamics  0.50  33.52  11.34
Dynamics  0.90  33.63  12.13
Dynamics  0.99  diverge  diverge
Dynamics (adapt. $Q$)  0.50  28.25  8.91
Dynamics (adapt. $Q$)  0.90  23.39  6.50
Dynamics (adapt. $Q$)  0.99  33.63  9.82
Posterior covariance initialization. The KaFiStO framework requires initializing the posterior covariance matrix $P_{0|0}$. In the case of the Vanilla Kalman algorithm, we approximate the posterior covariance with a scaled identity matrix, i.e., $P_{0|0} = \sigma^2_0 I_n$, where $\sigma^2_0 > 0$. In the case of Kalman Dynamics, we approximate $P_{0|0}$ with a block-diagonal matrix, and we initialize it with

$P_{0|0} = \begin{bmatrix} \sigma^2_{11,0} I_n & 0 \\ 0 & \sigma^2_{22,0} I_n \end{bmatrix}$   (27)

where $\sigma^2_{11,0}, \sigma^2_{22,0} > 0$. In this section we ablate $\sigma^2_0$, $\sigma^2_{11,0}$ and $\sigma^2_{22,0}$ to show that the method quickly adapts to the observations and the initialization of $P_{0|0}$ does not have a significant impact on the final accuracy achieved with the estimated parameters. The results are given in Table 3 and in Figure 1.
Parameter  Value  KaFiStO Variant  Top1 Error  Top5 Error

$\sigma^2_0$  0.01  Vanilla  33.97  12.83
$\sigma^2_0$  0.10  Vanilla  34.16  12.02
$\sigma^2_0$  1.00  Vanilla  33.52  12.03

$\sigma^2_{11,0}$  0.01  Dynamics  33.42  12.28
$\sigma^2_{11,0}$  0.10  Dynamics  33.63  12.13
$\sigma^2_{11,0}$  1.00  Dynamics  33.16  12.16
$\sigma^2_{11,0}$  0.01  Dynamics (adapt. $Q$)  23.67  6.81
$\sigma^2_{11,0}$  0.10  Dynamics (adapt. $Q$)  23.39  6.50
$\sigma^2_{11,0}$  1.00  Dynamics (adapt. $Q$)  23.82  6.53

$\sigma^2_{22,0}$  0.01  Dynamics  33.28  11.88
$\sigma^2_{22,0}$  0.10  Dynamics  33.63  12.13
$\sigma^2_{22,0}$  1.00  Dynamics  34.73  13.36
$\sigma^2_{22,0}$  0.01  Dynamics (adapt. $Q$)  23.37  7.13
$\sigma^2_{22,0}$  0.10  Dynamics (adapt. $Q$)  23.39  6.50
$\sigma^2_{22,0}$  1.00  Dynamics (adapt. $Q$)  24.24  7.40

Noise adaptivity. We compare the performance obtained with a fixed measurement variance $R$ to the one with an online estimate based on the $k$-th mini-batch. Similarly, we ablate the adaptivity of the process noise $Q$. The results are shown in Table 4.
KaFiStO Variant  $R$  $Q$  Top1 Error  Top5 Error

Vanilla  adaptive  —  34.16  12.02
Vanilla  constant  —  diverge  diverge
Dynamics  adaptive  constant  33.63  12.13
Dynamics  constant  constant  diverge  diverge
Dynamics  adaptive  adaptive  23.39  6.50
Dynamics  constant  adaptive  diverge  diverge
We observe that the adaptivity of $R$ is essential for the model to converge, and the adaptivity of $Q$ helps to further improve the performance of the trained model. Moreover, with adaptive noise estimates there is no need to set initial values for them, which reduces the number of hyperparameters to tune.
Batch size. Usually one needs to adapt the learning rate to the chosen mini-batch size. In this experiment, we change the batch size in the range from 32 to 256 and show that KaFiStO adapts to it naturally. Table 5 shows that the accuracy of the model does not vary significantly with a varying batch size, which is a sign of stability.
Batch Size  Top1 Error  Top5 Error 

32  24.59  7.13 
64  23.11  6.93 
128  23.39  6.50 
256  24.34  7.59 
Weight initialization. Similarly to the batch size, here we use different initialization techniques to show that the algorithm is robust to them. We apply the same initializations to SGD for comparison. We test Kaiming Uniform [26], Orthogonal [65], Xavier Normal [18], Xavier Uniform [18]. The results are shown in Table 6.
Initialization  Optimizer  Top1 Error  Top5 Error 

XavierNormal  SGD  26.71  7.59 
KaFiStO  23.34  6.78  
XavierUniform  SGD  26.90  7.97 
KaFiStO  23.40  6.85  
KaimingUniform  SGD  27.82  7.95 
KaFiStO  23.35  6.76  
Orthogonal  SGD  26.83  7.59 
KaFiStO  23.27  6.63 
5 Experiments
In order to assess the efficiency of KaFiStO, we evaluate it on different tasks, including image classification (on CIFAR10, CIFAR100 and ImageNet [64]), generative learning and language modeling. For all these tasks, we report the quality metrics on the validation sets to compare KaFiStO to the optimizers commonly used in the training of existing models. We find that KaFiStO outperforms or is on par with the existing methods, while requiring fewer hyperparameters to tune.

Dataset  Architecture  Method  Top1 (100 ep)  Top5 (100 ep)  Top1 (200 ep)  Top5 (200 ep)
CIFAR10  ResNet18  SGD  5.60  0.16  7.53  0.29 
Adam  6.58  0.28  6.46  0.28  
KaFiStO  5.69  0.21  5.46  0.25  
ResNet50  SGD  6.37  0.19  8.10  0.27  
Adam  6.28  0.24  5.97  0.28  
KaFiStO  7.29  0.24  6.31  0.13  
WResNet502  SGD  6.08  0.15  7.60  0.24  
Adam  6.02  0.19  5.90  0.26  
KaFiStO  6.83  0.19  5.36  0.12  
CIFAR100  ResNet18  SGD  23.50  6.48  22.44  5.99 
Adam  26.30  7.85  25.61  7.74  
KaFiStO  23.38  6.70  22.22  6.13  
ResNet50  SGD  25.05  6.74  22.06  5.71  
Adam  24.95  6.96  24.44  6.81  
KaFiStO  22.34  5.96  21.03  5.33  
WResNet502  SGD  23.83  6.35  22.47  5.96  
Adam  23.73  6.64  24.04  7.06  
KaFiStO  21.25  5.35  20.73  5.08  
ImageNet32  ResNet50  SGD  34.07  13.38     
KaFiStO  34.99  14.06     
5.1 Classification
CIFAR10/100. We first evaluate KaFiStO on CIFAR10 and CIFAR100 using the popular ResNets [25] and WideResNets [92] for training. We compare our results with the ones obtained with commonly used existing optimization algorithms, such as SGD with Momentum and Adam. For SGD we set the momentum rate to 0.9, which is the default for many popular networks, and for Adam we use the default parameters ($\beta_1 = 0.9$, $\beta_2 = 0.999$). In all experiments on CIFARs, we use a fixed batch size and basic data augmentation (random horizontal flipping and random cropping with padding). For each configuration we have two runs, for 100 and 200 epochs respectively, with a suitable initial learning rate for each optimizer. For the 100-epoch runs we decrease the learning rate by a constant factor at regular epoch intervals. For 200 epochs on CIFAR10 we decrease the learning rate only once by the same factor, and for the 200-epoch training on CIFAR100 the learning rate is decreased at three epoch milestones. For all the algorithms, we additionally use weight decay. To show the benefit of using KaFiStO for training on classification tasks, we report the Top1 and Top5 errors on the validation set. For both the 100-epoch and 200-epoch configurations, we report the mean error among runs with different random seeds. The results are reported in Table 7. Figure 2 shows the behavior of the training loss, the validation loss and the Top1 error on the validation set, as well as the adaptive evolution of KaFiStO's "learning rate", i.e., the step size that scales the gradient in the update eq. (13).
ImageNet. Following [47], we train a ResNet50 [25] on downscaled 32×32 ImageNet images with the most common settings: a fixed number of epochs with a step learning rate decay and weight decay. We use random cropping and random horizontal flipping during training, and we report the validation accuracy on single center-crop images. As shown in Table 7, our model achieves a comparable accuracy to SGD, but without any task-specific hyperparameter tuning.
5.2 Generative Adversarial Networks Training
Generative Adversarial Networks (GAN) [21] are generative models trained to generate new samples from a given data distribution. A GAN consists of two networks, a generator and a discriminator, which are trained in an adversarial manner. The training alternates between these networks in a minimax game, which tends to be difficult to train. Algorithms like SGD struggle to find a good solution, and a common practice for training GANs is to use adaptive methods like Adam or RMSProp. Thus, a good performance on the training of GANs is a good indicator of stability and of the ability to handle complex loss functions.
Following [100], we test our method with one of the most popular models, the Wasserstein GAN [2] with gradient penalty (WGAN-GP) [23]. The objectives for both the generator and the discriminator in WGAN-GP are unbounded from below, which makes it difficult to apply our model directly. Indeed, our algorithm works under the assumption that the expected risk at the optimum is some given finite value. However, we can control the measurement equations in KaFiStO by adjusting the target $L^*$, as was done to obtain learning rate schedules. The simplest way to deal with unbounded losses is to set $L^*$ below the current estimate of the loss. That is, for a given mini-batch the target should be equal to $L_k(\hat{\theta}_{k|k-1}) - \delta$, for some constant $\delta > 0$, which we set before the training. In our experiments, we fix $\delta$ to a constant. We also set the remaining hyperparameters similarly to a common choice for Adam in GAN training.
We use a DCGAN [59] architecture with the WGAN-GP loss. For each optimizer, we train the model for a fixed number of epochs on CIFAR10 and evaluate the FID score [29], which captures the quality and diversity of the generated samples. We report the mean and standard deviation among runs with different random seeds. Usually GANs are trained in an alternating way, that is, the generator is updated once every few discriminator iterations. This should make the generator compete with a stronger discriminator and achieve convergence. We test KaFiStO on two settings for this update ratio. In both settings KaFiStO outperforms Adam in terms of FID score and is more stable. The results are shown in Table 8 and images sampled from the trained generator are reported in the Supplementary material.

Optimizer  FID (setting 1)  FID (setting 2)

Adam  78.21±15.91  79.05±7.89
KaFiStO  67.17±14.93  75.65±5.38
5.3 Language Modeling
Given the recent success of transfer learning in NLP with pretrained language models [57, 88, 12, 45, 9, 31], we trained both an LSTM [30] and a Transformer [79] for language modeling on the Penn TreeBank [50] and WikiText-2 [51] datasets. We use the default data splits for training and validation, and report the perplexity (lower is better) on the test set in Table 9. We used a two-layer LSTM with a fixed number of hidden neurons and a fixed input embedding size for our tiny-LSTM (the same settings as Bernstein [3]) and increased the sizes for the larger-LSTM experiment (the same as Zhang [95]). The learning rates for Adam and SGD were picked based on a grid search, and we used the default settings for our optimizer. In order to prevent overfitting, we used an aggressive dropout [71] rate and tied the input and output embeddings [58], which is a common practice in NLP. Since we are using small datasets, we use only a two-layer masked multi-head self-attention transformer with two heads, which performs worse than the LSTM. We find that even in these settings, KaFiStO is on par with other optimizers.

Model  Optimizer  PTB ppl  WikiText-2 ppl

tinyLSTM  SGD  207.7  216.62 
Adam  123.6  166.49  
Dynamics (adapt. $Q$)  110.1  124.69  
largerLSTM 
SGD  112.19  145.7 
Adam  81.27  94.39  
Dynamics (adapt. $Q$)  81.57  89.64  
tinyTransformer 
SGD  134.41  169.04 
Adam  140.13  189.66  
Dynamics (adapt. $Q$)  129.45  179.81  

6 Conclusions
We have introduced KaFiStO, a novel Kalman filtering-based approach to stochastic optimization. KaFiStO is suitable for training modern neural network models on current large-scale datasets with high-dimensional data. The method can self-tune and is quite robust to a wide range of training settings. Moreover, we design KaFiStO so that it can incorporate optimization dynamics such as those in Momentum and Adam, as well as learning rate schedules. The efficacy of this method is demonstrated in several experiments on image classification, image generation and language processing.
References
 [1] (2019) Memory efficient adaptive optimization. In Advances in Neural Information Processing Systems, External Links: Link Cited by: §2.
 [2] (2017) Wasserstein generative adversarial networks. In International conference on machine learning, Cited by: §5.2.
 [3] (2020) On the distance between two neural networks and the stability of learning. In Advances in Neural Information Processing Systems, H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin (Eds.), External Links: Link Cited by: §5.3.
 [4] (2000) Gradient convergence in gradient methods with errors. SIAM Journal on Optimization (3). Cited by: §3.
 [5] (2017) Practical Gauss-Newton optimisation for deep learning. In ICML, External Links: Link Cited by: §2.
 [6] (2018) Optimization methods for large-scale machine learning. SIAM Review (2). Cited by: §1, §3.

 [7] (2020) Closing the generalization gap of adaptive gradient methods in training deep neural networks. In Proceedings of the Twenty-Ninth International Joint Conference on Artificial Intelligence, IJCAI-20, Note: Main track External Links: Link Cited by: §2.
 [8] (2019) On the convergence of a class of Adam-type algorithms for non-convex optimization. In International Conference on Learning Representations, External Links: Link Cited by: §2.
[9] (2020) ELECTRA: pre-training text encoders as discriminators rather than generators. arXiv abs/2003.10555.
[10] (2021) Expectigrad: fast stochastic optimization with robust convergence properties.
[11] (2020) Stochastic online optimization using Kalman recursion. arXiv:2002.03636.
[12] (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers).
[13] (2019) An adaptive and momental bound method for stochastic learning. arXiv preprint arXiv:1910.12249.
[14] (2016) Incorporating Nesterov momentum into Adam.
[15] (2020) diffGrad: an optimization method for convolutional neural networks. IEEE Transactions on Neural Networks and Learning Systems.
[16] (2011) Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research.
[17] (2020) Stochastic gradient methods with layer-wise adaptive moments for training of deep networks. arXiv:1905.11286.
[18] (2010) Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the International Conference on Artificial Intelligence and Statistics (AISTATS'10).
[19] (1977) On convergence rates of subgradient optimization methods. Mathematical Programming 13, pp. 329–347.
[20] (2020) Practical quasi-Newton methods for training deep neural networks. In Advances in Neural Information Processing Systems.
[21] (2014) Generative adversarial nets. Advances in Neural Information Processing Systems.
[22] (2021) Beyond SGD: iterate averaged adaptive gradient method. arXiv:2003.01247.
[23] (2017) Improved training of Wasserstein GANs. In Advances in Neural Information Processing Systems.
[24] (2004) Kalman filtering and neural networks. John Wiley & Sons.
[25] (2016) Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
[26] (2015) Delving deep into rectifiers: surpassing human-level performance on ImageNet classification. In Proceedings of the 2015 IEEE International Conference on Computer Vision (ICCV).
[27] (2019) Small steps and giant leaps: minimal Newton solvers for deep learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV).
[28] (2021) AdamP: slowing down the slowdown for momentum optimizers on scale-invariant weights. In International Conference on Learning Representations (ICLR).
[29] (2017) GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems.
[30] (1997) Long short-term memory. Neural Computation.
[31] (2018) Universal language model fine-tuning for text classification. In ACL.
[32] (2020) Biased stochastic first-order methods for conditional stochastic optimization and applications in meta learning. In Advances in Neural Information Processing Systems.
[33] (2019) Nostalgic Adam: weighting more of the past gradients when designing the adaptive learning rate. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, IJCAI-19.
[34] (2017) Adaptive learning rate via covariance matrix based preconditioning for deep neural networks. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence, IJCAI-17.
[35] (2001) Kalman filtering and neural networks. Adaptive and Learning Systems for Signal Processing, Communications, and Control, Wiley, New York.
[36] (1960) A new approach to linear filtering and prediction problems. Transactions of the ASME – Journal of Basic Engineering (Series D).
[37] (2015) Adam: a method for stochastic optimization. In ICLR.
[38] (2009) Learning multiple layers of features from tiny images.
[39] (2012) Efficient backprop. In Neural Networks: Tricks of the Trade: Second Edition.
[40] (2018) Online adaptive methods, universality and acceleration. In Advances in Neural Information Processing Systems.
[41] (2020) AdaX: adaptive gradient descent with exponential long term memory.
[42] (2020) PAGE: a simple and optimal probabilistic gradient estimator for nonconvex optimization. arXiv:2008.10898.
[43] (2020) On the variance of the adaptive learning rate and beyond. In International Conference on Learning Representations.
[44] (2020) Adam with bandit sampling for deep learning. In Advances in Neural Information Processing Systems.
[45] (2019) RoBERTa: a robustly optimized BERT pretraining approach. arXiv abs/1907.11692.
[46] (2017) SGDR: stochastic gradient descent with warm restarts. In International Conference on Learning Representations (ICLR).
[47] (2019) Decoupled weight decay regularization. In International Conference on Learning Representations.
[48] (2019) Adaptive gradient methods with dynamic bound of learning rate. In International Conference on Learning Representations.
[49] (2015) Probabilistic line searches for stochastic optimization. In Advances in Neural Information Processing Systems.
[50] (1993) Building a large annotated corpus of English: the Penn Treebank. Computational Linguistics.
[51] (2017) Pointer sentinel mixture models. arXiv abs/1609.07843.
[52] (2017) Variants of RMSProp and Adagrad with logarithmic regret bounds. In Proceedings of the 34th International Conference on Machine Learning.
[53] (2020) Parabolic approximation line search for DNNs. In Advances in Neural Information Processing Systems.
[54] (2019) The extended Kalman filter is a natural gradient descent in trajectory space. arXiv:1901.00696.
[55] (2015) Scale-free algorithms for online linear optimization. In Algorithmic Learning Theory.
[56] (2020) The role of memory in stochastic optimization. In Proceedings of The 35th Uncertainty in Artificial Intelligence Conference.
[57] (2018) Deep contextualized word representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers).
[58] (2017) Using the output embedding to improve language models. In Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 2, Short Papers, pp. 157–163.
[59] (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
[60] (2018) On the convergence of Adam and beyond. In International Conference on Learning Representations.
[61] (1951) A stochastic approximation method. The Annals of Mathematical Statistics.
[62] (2018) L4: practical loss-based stepsize adaptation for deep learning. In Advances in Neural Information Processing Systems.
[63] (2010) A fast natural Newton method. In ICML.
[64] (2015) ImageNet large scale visual recognition challenge. International Journal of Computer Vision 115 (3).
[65] (2014) Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. In 2nd International Conference on Learning Representations, ICLR 2014.
[66] (2013) No more pesky learning rates. In ICML.
[67] (2018) VR-SGD: a simple stochastic variance reduction method for machine learning. arXiv:1802.09932.
[68] (2019) Trust region value optimization using Kalman filtering. arXiv:1901.07860.
[69] (2018) Adafactor: adaptive learning rates with sublinear memory cost. In ICML.
[70] (2014) Fast large-scale optimization by unifying stochastic gradient and quasi-Newton methods. In Proceedings of the 31st International Conference on Machine Learning.
[71] (2014) Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, pp. 1929–1958.
[72] (2019) Adathm: adaptive gradient method based on estimates of third-order moments. In 2019 IEEE Fourth International Conference on Data Science in Cyberspace (DSC).
[73] (2020) SSGD: symmetrical stochastic gradient descent with weight noise injection for reaching flat minima. arXiv:2009.02479.
[74] (2013) On the importance of initialization and momentum in deep learning. In Proceedings of the 30th International Conference on Machine Learning.
[75] (2004) Comparison of gradient descent method, Kalman filtering and decoupled Kalman in training neural networks used for fingerprint-based positioning. In IEEE 60th Vehicular Technology Conference, VTC2004-Fall.
[76] (2016) Barzilai-Borwein step size for stochastic gradient descent. In Advances in Neural Information Processing Systems.
[77] (2012) Lecture 6.5 - RMSProp: divide the gradient by a running average of its recent magnitude.
[78] (2019) On the convergence proof of AMSGrad and a new version. IEEE Access.
[79] (2017) Attention is all you need. arXiv abs/1706.03762.
[80] (2019) Painless stochastic gradient: interpolation, line-search, and convergence rates. In Advances in Neural Information Processing Systems.
[81] (2018) Kalman gradient descent: adaptive variance reduction in stochastic optimization. arXiv:1810.12273.
[82] (2020) Scheduled restart momentum for accelerated stochastic gradient descent. arXiv preprint arXiv:2002.10583.
[83] (2013) Variance reduction for stochastic gradient optimization. Advances in Neural Information Processing Systems.
[84] (2019) SignADAM++: learning confidences for deep neural networks. In 2019 International Conference on Data Mining Workshops (ICDMW).
[85] (2020) SAdam: a variant of Adam for strongly convex functions. In International Conference on Learning Representations.
[86] (2019) HyperAdam: a learnable task-adaptive Adam for network training.
[87] (2020) WNGrad: learn the learning rate in gradient descent. arXiv:1803.02865.
[88] (2019) XLNet: generalized autoregressive pretraining for language understanding. In Advances in Neural Information Processing Systems.
[89] (2021) ADAHESSIAN: an adaptive second order optimizer for machine learning. AAAI.
[90] (2020) Large batch optimization for deep learning: training BERT in 76 minutes. In International Conference on Learning Representations.
[91] (2020) SALR: sharpness-aware learning rates for improved generalization. arXiv:2011.05348.
[92] (2016) Wide residual networks.
[93] (2018) Adaptive methods for nonconvex optimization. In Advances in Neural Information Processing Systems.
[94] (2012) ADADELTA: an adaptive learning rate method. CoRR.
[95] (2017) YellowFin and the art of momentum tuning. arXiv preprint arXiv:1706.03471.
[96] (2020) Why are adaptive methods good for attention models? In Advances in Neural Information Processing Systems.
[97] (2020) Why are adaptive methods good for attention models? In Advances in Neural Information Processing Systems.
[98] (2019) Lookahead optimizer: k steps forward, 1 step back. In Advances in Neural Information Processing Systems.
[99] (2019) AdaShift: decorrelation and convergence of adaptive learning rate methods. In International Conference on Learning Representations.
[100] (2020) AdaBelief optimizer: adapting stepsizes by the belief in observed gradients. In Advances in Neural Information Processing Systems.