Increasing the batch size is one of the most appealing ways to accelerate neural network training on data parallel hardware. Larger batch sizes yield better gradient estimates and, up to a point, reduce the number of steps required for training, which reduces the training time. The importance of understanding the benefits of modern parallel hardware has motivated a lot of recent work on training neural networks with larger batch sizes (Goyal et al., 2017; Osawa et al., 2018; McCandlish et al., 2018; Shallue et al., 2018). To date, the most comprehensive empirical study of the effects of batch size on neural network training is Shallue et al. (2018), who confirmed that increasing the batch size initially achieves perfect scaling (i.e. doubling the batch size halves the number of steps needed) up to a problem-dependent critical batch size, beyond which it yields diminishing returns (Balles et al., 2017; Goyal et al., 2017; Jastrzębski et al., 2018; McCandlish et al., 2018). Shallue et al. (2018) also provided experimental evidence that the critical batch size depends on the optimization algorithm, the network architecture, and the data set. However, their experiments only covered plain SGD, SGD with (heavy-ball) momentum, and SGD with Nesterov momentum, leaving open the enticing possibility that other optimizers might extend perfect scaling to even larger batch sizes.
Empirical scaling curves like those in Shallue et al. (2018) are essential for understanding the effects of batch size, but generating such curves, even for a single optimizer on a single task, can be very expensive. On the other hand, existing theoretical analyses that attempt to derive critical batch sizes analytically (e.g. Ma et al. (2018); Yin et al. (2018); Jain et al. (2018)) do not answer our questions about which optimizers scale best with batch size. They tend to make strong assumptions, produce parameter-dependent results that are difficult to apply, or are restricted to plain SGD. It would be ideal to find a middle ground between a purely empirical investigation and theoretical analysis by building a model of neural network optimization problems that captures the essential behavior we see in real neural networks, while still being easy to understand. Additionally, we need to study optimizers beyond momentum SGD, since they might offer a way to exploit speedups at the very largest batch sizes. In this work, we make the following contributions:
- We show that a simple noisy quadratic model (NQM) is remarkably consistent with the batch size effects observed in real neural networks, while allowing us to run experiments in seconds, making it a great tool to generate testable predictions about neural network optimization.
- We show that the NQM successfully predicts that momentum should speed up training relative to plain SGD at larger batch sizes, but have no effect at small batch sizes.
- Through large scale experiments with Adam (Kingma and Ba, 2014) and K-FAC (Martens and Grosse, 2015), we confirm that, as predicted by the NQM, preconditioning extends perfect batch size scaling to larger batch sizes than are possible with momentum SGD alone. Furthermore, unlike momentum, preconditioning can help at small batch sizes as well.
2 Related Work
In a classic paper, Bottou and Bousquet (2008) studied the asymptotics of stochastic optimization algorithms and found SGD to be competitive with fancier approaches. They showed that stochastic optimization involves fundamentally different tradeoffs from full-batch optimization. More recently, several studies have investigated the relationship between batch size and training time for neural networks. Chen et al. (2018) studied the effect of network width on the critical batch size, and showed experimentally that it depends on both the data set and network architecture. Golmant et al. (2018) studied how various heuristics for adjusting the learning rate as a function of batch size affect the relationship between batch size and training time. Shallue et al. (2018) conducted a comprehensive empirical study on the relationship between batch size and training time with different neural network architectures and data sets using plain SGD, heavy-ball momentum, and Nesterov momentum. Finally, McCandlish et al. (2018) used the average gradient noise over training to predict the critical batch size. All of these studies described a basic relationship between batch size and training steps to a fixed error goal, which consists of three regions: perfect scaling initially, then diminishing returns, and finally no benefit for all batch sizes greater than the critical batch size.
Other studies have attempted to characterize the critical batch size analytically in stochastic optimization. Under varying assumptions, Ma et al. (2018), Yin et al. (2018), and Jain et al. (2018) all derived analytical notions of critical batch size but, to our knowledge, only for SGD.
Finally, previous studies have shown that SGD and momentum SGD are equivalent for small learning rates (after appropriate rescaling), both in the continuous limit (Leen and Orr, 1994) and in discrete settings (Yuan et al., 2016). However, they do not explain why momentum SGD (including heavy-ball and Nesterov momentum) sometimes outperforms plain SGD in mini-batch training (as observed by Kidambi et al. (2018) and Shallue et al. (2018)).
3 Analysis of the Noisy Quadratic Model (NQM)
In this section, we work with a noisy quadratic model (NQM), a stochastic optimization problem whose dynamics can be simulated analytically, in order to reason about various phenomena encountered in training neural networks. In this highly simplified model, we first assume the loss function being optimized is a convex quadratic, with noisy observations of the gradient. For analytic tractability, we further assume the noise covariance is codiagonalizable with the Hessian (an assumption we later test for neural networks). Because we are not interested in modeling overfitting effects, we focus on the online training setting, where the observations are drawn i.i.d. in every training iteration. Under these assumptions, we derive an analytic expression for the risk after any number of steps of SGD with a fixed step size, as well as a dynamic programming method to compute the risk following a given step size schedule.
Convex quadratics may seem an odd model for a complicated nonconvex optimization landscape. However, one obtains a convex quadratic objective by linearizing the network's function around a given weight vector and taking the second-order Taylor approximation to the loss function (assuming it is smooth and convex). Indeed, recent theoretical works (Jacot et al., 2018; Du et al., 2019; Zhang et al., 2019a) show that for wide enough networks, the weights stay close enough to the initialization for the linearized approximation to remain accurate. Empirically, linearized approximations closely match a variety of training phenomena for large but realistic networks (Lee et al., 2019).
3.1 Problem Setup
We now introduce the noisy quadratic model (Schaul et al., 2013; Martens, 2014; Wu et al., 2018), where the true function being optimized is a convex quadratic. Because we analyze rotation-invariant and translation-invariant optimizers such as SGD and heavy-ball momentum, we assume without loss of generality that the quadratic form is diagonal, and that the optimum is at the origin. Hence, our exact cost function decomposes as a sum of scalar quadratic functions for each coordinate:
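$$\mathcal{L}(\theta) = \frac{1}{2}\theta^\top H \theta = \frac{1}{2}\sum_{i=1}^{d} h_i \theta_i^2 \triangleq \sum_{i=1}^{d} \ell(\theta_i). \tag{1}$$

Without loss of generality, we assume $h_1 \ge h_2 \ge \cdots \ge h_d$. We consider a single gradient query to have the form $g(\theta) = H\theta + \epsilon$, where $\mathbb{E}[\epsilon] = 0$ and $\mathrm{Cov}(\epsilon) = C$. To reduce the variance of gradient estimation, we can average over multiple independent queries, which corresponds to "mini-batch training" in neural network optimization. We denote the averaged gradient as $g_B(\theta)$ and its covariance as $\mathrm{Cov}(g_B(\theta)) = C/B$, where $B$ is the number of queries (mini-batch size).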
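For analytical tractability, we make the nontrivial assumption that $H$ and $C$ are codiagonalizable. (Since $H$ is diagonal, this implies that $C = \mathrm{diag}(c_1, \ldots, c_d)$.) See Section 3.4 for justification of this assumption. Under gradient descent with fixed step size $\alpha$, each dimension evolves independently as

$$\theta_i(t+1) = (1 - \alpha h_i)\,\theta_i(t) - \alpha\sqrt{c_i/B}\;\xi_i, \tag{2}$$

where $\alpha$ is the learning rate and $\xi_i$ is zero-mean unit variance iid noise. By treating $\theta_i$ as a random variable, we immediately obtain the dynamics of its mean and variance:

$$\mathbb{E}[\theta_i(t+1)] = (1 - \alpha h_i)\,\mathbb{E}[\theta_i(t)], \qquad \mathbb{V}[\theta_i(t+1)] = (1 - \alpha h_i)^2\,\mathbb{V}[\theta_i(t)] + \frac{\alpha^2 c_i}{B}. \tag{3}$$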
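Based on eqn. (3), the expected risk after $t$ steps in a given dimension $i$ is

$$\mathbb{E}[\ell(\theta_i(t))] = (1 - \alpha h_i)^{2t}\,\mathbb{E}[\ell(\theta_i(0))] + \left(1 - (1 - \alpha h_i)^{2t}\right)\frac{\alpha c_i}{2B\,(2 - \alpha h_i)}, \tag{4}$$

where we have assumed that $\alpha h_i < 2$. (Note that this can be seen as a special case of the convergence result derived for convex quadratics in Martens (2014).)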
Remarkably, each dimension converges exponentially to a steady state risk. Unfortunately, there is a trade-off: higher learning rates (up to $1/h_i$) give faster convergence to the steady state risk, but also produce higher values of the steady state risk. The steady state risk is also inversely proportional to the batch size. This is important because in the following subsections we will show that traditional acceleration techniques (e.g., momentum and preconditioning) improve the convergence rate at the expense of increasing the steady state risk. Therefore, the NQM implies that momentum and preconditioning should benefit more from large-batch training than plain SGD does, as shown in later sections.
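A quick way to check eqn. (4) is to simulate eqn. (2) directly. The minimal sketch below uses illustrative parameter values and assumes the initialization $\theta_i(0) \sim \mathcal{N}(0, 1)$, so that $\mathbb{E}[\ell(\theta_i(0))] = h_i/2$. It averages the risk of many independent chains of a single dimension and compares the result with the closed form. It also shows the steady state term halving when the batch size doubles. The function names are hypothetical.

```python
import numpy as np

def closed_form_risk(t, h, c, alpha, batch_size, init_risk):
    """Expected risk of one dimension after t SGD steps, eqn. (4)."""
    decay = (1.0 - alpha * h) ** (2 * t)
    steady = alpha * c / (2.0 * batch_size * (2.0 - alpha * h))
    return decay * init_risk + (1.0 - decay) * steady

def simulated_risk(t, h, c, alpha, batch_size, num_chains=200_000, seed=0):
    """Monte Carlo estimate of the same risk by iterating eqn. (2)."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal(num_chains)  # theta_i(0) ~ N(0, 1) (assumption)
    for _ in range(t):
        xi = rng.standard_normal(num_chains)  # zero-mean, unit-variance noise
        theta = (1.0 - alpha * h) * theta - alpha * np.sqrt(c / batch_size) * xi
    return float(np.mean(0.5 * h * theta ** 2))  # l(theta_i) = h_i * theta_i^2 / 2

h, c, alpha, batch_size, t = 0.5, 1.0, 0.5, 4, 20
exact = closed_form_risk(t, h, c, alpha, batch_size, init_risk=h / 2)
estimate = simulated_risk(t, h, c, alpha, batch_size)
print(f"closed form {exact:.4f}, simulated {estimate:.4f}")

# After many steps only the steady state term remains; it scales as 1/B.
s1 = closed_form_risk(10**6, h, c, alpha, 1, init_risk=h / 2)
s2 = closed_form_risk(10**6, h, c, alpha, 2, init_risk=h / 2)
print(f"steady-state ratio, B=1 vs B=2: {s1 / s2:.3f}")
```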
3.2 The Role of Momentum
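Applied to the same noisy quadratic model as before, the update equations for momentum SGD are:

$$m(t+1) = \beta\, m(t) + g_B(\theta(t)), \qquad \theta(t+1) = \theta(t) - \alpha\, m(t+1). \tag{5}$$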
We show in the following theorem (see Appendix C for proof) that momentum SGD performs similarly to plain SGD in the regime of small batch sizes but helps in the large-batch regime, which can be viewed as a more "deterministically behaving" optimization problem.
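Theorem 1. Given a dimension index $i$, and $0 \le \beta < 1$ with $\beta \neq (1 - \sqrt{\alpha h_i})^2$, the expected risk at time $t$ associated with that dimension satisfies the upper bound

$$\mathbb{E}[\ell(\theta_i(t))] \le \left(\frac{(r_1^{t+1} - r_2^{t+1}) - \beta\,(r_1^{t} - r_2^{t})}{r_1 - r_2}\right)^2 \mathbb{E}[\ell(\theta_i(0))] + \frac{(1+\beta)\,\alpha c_i}{2B\,(2\beta + 2 - \alpha h_i)(1-\beta)}, \tag{6}$$

where $r_1$ and $r_2$ (with $r_1 \ge r_2$) are the two roots of the quadratic equation $x^2 - (1 - \alpha h_i + \beta)\,x + \beta = 0$.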
As with plain SGD (cf. eqn. (4)), the loss associated with each dimension can be expressed as the sum of two terms. The first term decays exponentially and corresponds to the behavior of the deterministic version of the algorithm, and the second remains constant.
Following the existing treatment of the deterministic version of the algorithm (Chiang, 1974; Qian, 1999; Yang et al., 2018; Goh, 2017), we divide our analysis into two cases: overdamping and underdamping. In the case of overdamping, where $\beta < (1 - \sqrt{\alpha h_i})^2$, both roots $r_1$ and $r_2$ are real, and the convergence rate is therefore determined by the larger one (i.e. $r_1$), which has the value

$$r_1 = \frac{1 - \alpha h_i + \beta + \sqrt{(1 - \beta)^2 - 2(1 + \beta)\,\alpha h_i + \alpha^2 h_i^2}}{2}. \tag{7}$$
With a fixed learning rate, the steady state risk will be constant, and the best achievable expected risk will be lower bounded by it. Thus, to achieve a certain target loss we must either drive the learning rate down or the batch size up. Assuming a small batch size and a low target risk, we are forced to pick a small learning rate, in which case one can show that $r_1 \approx 1 - \frac{\alpha h_i}{1-\beta}$. (To see this, note that the term in the square root of eqn. (7) can be written as $\left(1 - \beta - \frac{1+\beta}{1-\beta}\alpha h_i\right)^2 - \frac{4\beta\alpha^2 h_i^2}{(1-\beta)^2}$. Dropping the $\frac{4\beta\alpha^2 h_i^2}{(1-\beta)^2}$ term, which is second order in $\alpha h_i$, and simplifying gives the claimed expression for $r_1$.) In Figure 2 we plot the convergence rate as a function of $\beta$, and we indeed observe that the convergence rate closely matches $1 - \frac{\alpha h_i}{1-\beta}$, assuming a relatively small learning rate. We further note that the convergence rate and steady state risk of eqn. (6) are the same as the ones in plain SGD (eqn. (4)), except that they use an "effective learning rate" of $\frac{\alpha}{1-\beta}$. To help validate these predictions, in Appendix D.3 we provide a comparison of momentum SGD with plain SGD using the effective learning rate.
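The small-learning-rate approximation for $r_1$ can be checked numerically. This sketch uses hypothetical helper names and illustrative values of $\alpha h_i$ and $\beta$. It evaluates the exact larger root of $x^2 - (1 - \alpha h_i + \beta)x + \beta = 0$ in the overdamped regime and compares $1 - r_1$ with $\alpha h_i/(1-\beta)$. The relative gap should shrink as $\alpha h_i$ decreases, since the dropped term is second order in $\alpha h_i$.

```python
import math

def larger_root(alpha_h, beta):
    """Larger root r_1 of x^2 - (1 - alpha*h + beta) x + beta = 0, eqn. (7).

    Only valid in the overdamped regime beta < (1 - sqrt(alpha*h))^2,
    where the discriminant is positive and both roots are real.
    """
    if beta >= (1.0 - math.sqrt(alpha_h)) ** 2:
        raise ValueError("not overdamped: roots are not distinct reals")
    disc = (1.0 - beta) ** 2 - 2.0 * (1.0 + beta) * alpha_h + alpha_h ** 2
    return (1.0 - alpha_h + beta + math.sqrt(disc)) / 2.0

def approx_root(alpha_h, beta):
    """Small-learning-rate approximation r_1 ~ 1 - alpha*h / (1 - beta)."""
    return 1.0 - alpha_h / (1.0 - beta)

beta = 0.5
for alpha_h in (1e-2, 1e-3, 1e-4):
    exact_gap = 1.0 - larger_root(alpha_h, beta)
    approx_gap = 1.0 - approx_root(alpha_h, beta)
    print(f"alpha*h={alpha_h:g}: exact 1-r1={exact_gap:.4e}, approx {approx_gap:.4e}")
```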
In the case of underdamping, where $\beta > (1 - \sqrt{\alpha h_i})^2$, both $r_1$ and $r_2$ will be complex and have norm $\sqrt{\beta}$. We note that the optimal $\beta$ should be equal to or smaller than $(1 - \sqrt{\alpha h_d})^2$, since otherwise all dimensions are under-damped, and we can easily improve the convergence rate and steady state risk by reducing $\beta$.
Next we observe that the convergence of the total loss will eventually be dominated by the slowest converging dimension (which corresponds to the smallest curvature $h_d$), and this will be in the overdamping regime as argued above. By our analysis of the overdamping case, we can achieve the same convergence rate for this dimension by simply replacing the learning rate in the bound for plain SGD (eqn. (4)) with the effective learning rate $\frac{\alpha}{1-\beta}$.
So while momentum gives no long-term training acceleration for very low fixed learning rates (which we are forced to use when the batch size is small), it can help in large-batch training. With $\alpha h_i \ll 1$, the steady state risk is amplified by roughly a factor of $\frac{1}{1-\beta}$ relative to plain SGD with the same learning rate, while, as before, it decreases in inverse proportion to the batch size. Therefore, we expect momentum SGD to exhibit perfect scaling up to larger batch sizes than plain SGD.
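The steady state term of eqn. (6) and the $\frac{1}{1-\beta}$ amplification can also be verified without Monte Carlo noise. Eliminating $m$ from eqn. (5) gives the scalar recursion $\theta(t+1) = (1 + \beta - \alpha h_i)\,\theta(t) - \beta\,\theta(t-1) - \alpha\epsilon(t)$. Its second moments obey a closed linear recursion. The sketch below iterates that recursion from the optimum until it is stationary and compares the result with the closed form. The parameter values are illustrative and the function names are hypothetical.

```python
def momentum_stationary_risk(h, c, alpha, beta, batch_size, num_steps=5000):
    """Stationary risk of heavy-ball momentum (eqn. (5)) in one dimension,
    obtained by iterating the exact second-moment recursion from the optimum."""
    a = 1.0 + beta - alpha * h       # theta(t+1) = a*theta(t) - beta*theta(t-1) - alpha*eps(t)
    s = alpha ** 2 * c / batch_size  # variance of the injected noise alpha*eps(t)
    p = q = r = 0.0                  # E[theta_t^2], E[theta_t*theta_{t-1}], E[theta_{t-1}^2]
    for _ in range(num_steps):
        # Right-hand side uses the previous (p, q, r) simultaneously.
        p, q, r = a * a * p - 2.0 * a * beta * q + beta ** 2 * r + s, a * p - beta * q, p
    return 0.5 * h * p

def momentum_steady_term(h, c, alpha, beta, batch_size):
    """Second term of eqn. (6)."""
    return (1.0 + beta) * alpha * c / (
        2.0 * batch_size * (2.0 * beta + 2.0 - alpha * h) * (1.0 - beta))

def sgd_steady_term(h, c, alpha, batch_size):
    """Steady state term of eqn. (4)."""
    return alpha * c / (2.0 * batch_size * (2.0 - alpha * h))

h, c, alpha, beta, batch_size = 0.01, 1.0, 0.1, 0.9, 1  # alpha * h << 1, overdamped
sim = momentum_stationary_risk(h, c, alpha, beta, batch_size)
closed = momentum_steady_term(h, c, alpha, beta, batch_size)
ratio = closed / sgd_steady_term(h, c, alpha, batch_size)
print(f"iterated {sim:.6f}, eqn. (6) {closed:.6f}, "
      f"amplification {ratio:.3f} vs 1/(1-beta) = {1.0 / (1.0 - beta):.1f}")
```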
3.3 The Role of Preconditioning
Many optimizers, such as Adam and K-FAC, can be viewed as preconditioned gradient descent methods. In each update, the gradient is rescaled by a PSD matrix $P^{-1}$, called the preconditioner:

$$\theta(t+1) = \theta(t) - \alpha P^{-1} g_B(\theta(t)).$$

In lieu of trying to construct noisy quadratic analogues of particular optimizers, we analyze preconditioners of the form $P = H^p$ with $0 \le p \le 1$. Note that $P$ remains fixed throughout training since the Hessian is constant in the NQM. We can recover standard SGD by setting $p = 0$.
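Conveniently, for our NQM, the dynamics of preconditioned SGD are equivalent to the SGD dynamics in an NQM with Hessian $\tilde{H} = P^{-1/2} H P^{-1/2} = H^{1-p}$ and gradient covariance $\tilde{C} = P^{-1/2} C P^{-1/2} = C H^{-p}$. Hence, the dynamics can be simulated using eqn. (4), exactly like the non-preconditioned case. We immediately obtain the following bound on the risk:

$$\mathbb{E}[\ell(\theta_i(t))] \le (1 - \alpha h_i^{1-p})^{2t}\,\mathbb{E}[\ell(\theta_i(0))] + \frac{\alpha c_i h_i^{-p}}{2B\,(2 - \alpha h_i^{1-p})}. \tag{8}$$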
To qualitatively understand the effect of preconditioning, consider first the first term in eqn. (8). The convergence of this term resembles that of gradient descent on a deterministic quadratic, which (with optimal $\alpha$) converges exponentially at a rate of approximately $2/\tilde{\kappa}$, where $\tilde{\kappa} = \tilde{h}_1/\tilde{h}_d$ is the condition number of the transformed problem. Since $\tilde{\kappa} = \kappa^{1-p}$, this implies a factor of $\kappa^p$ improvement in the rate of convergence. Hence, for near-deterministic objectives where the first term dominates, values of $p$ closer to 1 correspond to better preconditioners and result in much faster convergence. Unfortunately, there is no free lunch, as larger values of $p$ will also increase the second term (steady state risk). Assuming an ill-conditioned loss surface ($\kappa \gg 1$), so that the optimal learning rate is close to $2/\tilde{h}_1 = 2/h_1^{1-p}$, the steady state risk of each dimension becomes

$$\frac{\alpha c_i h_i^{-p}}{2B\,(2 - \alpha h_i^{1-p})} \approx \frac{c_i}{2B h_1}\cdot\frac{(h_i/h_1)^{-p}}{1 - (h_i/h_1)^{1-p}}, \tag{9}$$

which is a monotonically increasing function of $p$. Even without this amplification effect, the steady state risk will eventually become the limiting factor in the minimization of the expected risk. One way to reduce the steady state risk, apart from using Polyak averaging (Polyak and Juditsky, 1992) or decreasing the learning rate (which will harm the rate of convergence), is to increase the batch size. This suggests that the benefits of using stronger preconditioners will be more clearly observed for larger batch sizes, an effect that we empirically demonstrate in later sections.
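Because preconditioning with $P = H^p$ only changes the effective spectrum, its trade-off can be tabulated directly. This sketch uses the $\{1/i\}$ spectrum with $C = H$ from Section 3.4, with an illustrative $d = 1000$. It reports the transformed condition number $\tilde\kappa$, which should equal $d^{1-p}$, and sums the steady state term of eqn. (8) over dimensions. To keep every dimension stable it uses $\alpha = 1/\tilde h_1$ rather than the near-optimal learning rate behind eqn. (9). That choice and the helper names are assumptions of the sketch.

```python
import numpy as np

def preconditioned_problem(h, c, p):
    """Diagonal Hessian and noise covariance of the SGD problem equivalent to
    preconditioned SGD with P = H^p: H^(1-p) and C H^(-p)."""
    h = np.asarray(h, dtype=float)
    c = np.asarray(c, dtype=float)
    return h ** (1.0 - p), c * h ** (-p)

def total_steady_state_risk(h, c, p, alpha, batch_size):
    """Sum over dimensions of the second term of eqn. (8)."""
    h_t, c_t = preconditioned_problem(h, c, p)
    return float(np.sum(alpha * c_t / (2.0 * batch_size * (2.0 - alpha * h_t))))

d = 1000
h = 1.0 / np.arange(1, d + 1)  # spectrum {1/i} from Section 3.4
c = h.copy()                   # C = H, as in Section 3.4
powers = [0.0, 0.25, 0.5, 0.75, 1.0]
kappas, risks = [], []
for p in powers:
    h_t, _ = preconditioned_problem(h, c, p)
    alpha = 1.0 / h_t.max()    # illustrative stable learning rate (assumption)
    kappas.append(h_t.max() / h_t.min())
    risks.append(total_steady_state_risk(h, c, p, alpha, batch_size=1))
    print(f"p={p:.2f}  kappa={kappas[-1]:8.1f}  steady-state risk={risks[-1]:.3f}")
```

Stronger preconditioning shrinks $\tilde\kappa$, which speeds up the deterministic term, while the steady state term grows. Raising the batch size divides that term by $B$.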
3.4 Choice of $H$ and $C$
We've found that the qualitative behavior of optimizers in our NQM depends on the choices of $H$ and $C$. Therefore, we choose matrices motivated by theoretical and empirical considerations about neural net training. First, we set the diagonal entries of $H$ to be $\{1/i\}_{i=1}^{d}$ for some integer $d$, giving a condition number of $d$. This closely matches the estimated eigenspectrum of the Hessian of a convolutional network (see Figure 7 and Appendix D.4), and is also consistent with recent work finding heavy tailed eigenspectra of neural network Hessians (Ubaru et al., 2017; Ghorbani et al., 2019). We choose $d$ to approximately match the condition number of the K-FAC Hessian approximation for ResNet8. (Qualitative behaviors were consistent for a wide range of $d$.)
We also set $C = H$ (a nontrivial assumption). This was motivated by theoretical arguments that, under the assumption that the implicit conditional distribution over the network's output is close to the conditional distribution of targets from the training distribution, the Hessian closely matches the gradient covariance in neural network training (Martens, 2014). Empirically, this relationship appears to hold tightly for a convolutional network and modestly well for a transformer (see Appendix D.2).
3.5 Information Theoretic Lower Bound
Since our NQM assumes the infinite data (online optimization) setting, it's instructive to compare the performance of optimizers against an information theoretic lower bound. Specifically, under the assumption that $H = C$, the NQM is equivalent to maximum likelihood estimation of the mean vector for a multivariate Gaussian distribution with covariance $H^{-1}$. Hence, the risk obtained by any optimizer can be bounded below by the risk of the maximum likelihood estimator for the Gaussian, which is $d/(2N)$, where $d$ is the dimension and $N$ is the total number of training examples visited. We indicate this bound with a dashed black line in our plots.
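The information theoretic bound can be checked by simulation. With $C = H$, a gradient query $g(\theta) = H(\theta - x)$ with $x \sim \mathcal{N}(0, H^{-1})$ has mean $H\theta$ and covariance $H$, so each query is equivalent to observing one Gaussian sample. The sketch below uses a hypothetical function name and a small diagonal $H$ chosen for illustration. It estimates the risk of the sample mean by Monte Carlo and compares it with $d/(2N)$.

```python
import numpy as np

def mle_risk(h, num_examples, num_trials=2000, seed=0):
    """Monte Carlo risk of the Gaussian maximum likelihood estimator under the NQM with C = H."""
    rng = np.random.default_rng(seed)
    std = 1.0 / np.sqrt(h)  # examples x ~ N(0, H^{-1}); H diagonal, optimum at the origin
    samples = rng.standard_normal((num_trials, num_examples, h.size)) * std
    error = samples.mean(axis=1)  # MLE of the mean is the sample mean; the true mean is zero
    return float(np.mean(0.5 * np.sum(h * error ** 2, axis=1)))  # risk 0.5 * e^T H e

d, n = 10, 50
h = 1.0 / np.arange(1, d + 1)
estimate = mle_risk(h, n)
print(f"Monte Carlo {estimate:.4f} vs d/(2N) = {d / (2 * n):.4f}")
```

The bound does not depend on the spectrum of $H$, because the curvature weighting in the risk exactly cancels the $H^{-1}$ covariance of the estimation error.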
3.6 Noisy Quadratic Experiments
In this section, we simulate noisy quadratic optimization using the closed-form dynamics. Our aim is to formulate hypotheses for how different optimizers would behave for neural network optimization. Our main metric is the number of steps required to achieve a target risk. For efficiency, rather than explicitly representing all the eigenvalues of $H$, we quantize them into 100 bins and count the number of eigenvalues in each bin. Unless otherwise specified, we initialize $\theta$ as $\mathcal{N}(0, I)$ and use a target risk of 0.01. (The results do not seem to be sensitive to either the initial variance or the target risk; some results with varying target risk thresholds are shown in Appendix D.5.)
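The steps-to-target measurement can be sketched for plain SGD with a fixed learning rate. The code below is an illustration under stated assumptions, not the authors' implementation. It quantizes the $\{1/i\}$ spectrum into log-spaced bins represented by their geometric centers, and it assumes $C = H$ and $\theta(0) \sim \mathcal{N}(0, I)$. It then evaluates the closed-form risk of eqn. (4), finds the first step below the target by binary search, and grid-searches the learning rate. The dimension, target, grid, and helper names are all illustrative, and momentum and preconditioning are omitted.

```python
import numpy as np

def quantized_spectrum(d, num_bins=100):
    """Eigenvalues h_i = 1/i grouped into log-spaced bins (binning scheme is an assumption)."""
    h = 1.0 / np.arange(1, d + 1)
    edges = np.logspace(np.log10(h.min()), np.log10(h.max()), num_bins + 1)
    edges[0], edges[-1] = h.min(), h.max()  # guard against rounding at the ends
    counts, _ = np.histogram(h, bins=edges)
    centers = np.sqrt(edges[:-1] * edges[1:])  # geometric center represents each bin
    keep = counts > 0
    return centers[keep], counts[keep]

def sgd_risk(t, h, counts, alpha, batch_size):
    """Total expected risk after t steps via eqn. (4), with C = H and
    theta(0) ~ N(0, I), so E[l(theta_i(0))] = h_i / 2."""
    decay = (1.0 - alpha * h) ** (2 * t)
    steady = alpha * h / (2.0 * batch_size * (2.0 - alpha * h))
    return float(np.sum(counts * (decay * h / 2.0 + (1.0 - decay) * steady)))

def steps_to_target(h, counts, alpha, batch_size, target, max_steps=10**7):
    """Smallest t with risk <= target, or None if unreachable within max_steps.

    With max(h) <= 1, batch_size >= 1 and 0 < alpha <= 1, every bin starts at or
    above its steady state risk, so the total risk is nonincreasing in t and
    binary search over t is valid.
    """
    if sgd_risk(max_steps, h, counts, alpha, batch_size) > target:
        return None
    lo, hi = 0, max_steps
    while lo < hi:
        mid = (lo + hi) // 2
        if sgd_risk(mid, h, counts, alpha, batch_size) <= target:
            hi = mid
        else:
            lo = mid + 1
    return lo

h, counts = quantized_spectrum(d=1000)
alphas = np.linspace(0.01, 1.0, 100)  # grid over the fixed learning rate
batch_sizes = [1, 4, 16, 64, 256, 1024]
best = []
for b in batch_sizes:
    found = [steps_to_target(h, counts, a, b, target=0.05) for a in alphas]
    best.append(min(s for s in found if s is not None))
    print(f"B={b:5d}: {best[-1]} steps")
```

Because the closed-form risk at any $t$ costs one pass over the bins, searching over $t$ is far cheaper than stepping through every iteration.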
3.6.1 Effect of Momentum and Preconditioning
We first experiment with momentum and varying preconditioner powers on our NQM. We treat both the (fixed) learning rate $\alpha$ and the momentum decay parameter $\beta$ as hyperparameters, which we tune using a fine-grained grid search.
Consistent with the empirical results of Shallue et al. (2018), each optimizer shows two distinct regimes: a small-batch (stochastic) regime with perfect linear scaling, and a large-batch (deterministic) regime insensitive to batch size. We call the phase transition between these regimes the critical batch size. Consistent with the analysis of Section 3.2 and the observations of Smith et al. (2018), Shallue et al. (2018), and Kidambi et al. (2018), the performance of momentum-based optimizers matches that of the plain SGD methods in the small-batch regime, but momentum increases the critical batch size and gives substantial speedups in the large batch regime. Preconditioning also increases the critical batch size and gives substantial speedups in the large batch regime, but, interestingly, it also improves performance by a small constant factor even for very small batches. Combining momentum with preconditioning extends both of these trends.
3.6.2 Optimal Learning Rate and Decay Scheme
In the NQM, we can calculate the optimal constant learning rate given a specific batch size. Figure 12 shows the optimal learning rate as a function of batch size for a fixed target risk. Notably, the optimal learning rate of plain (preconditioned) SGD (Figure 11(a)) scales linearly with batch size before it hits the critical batch size, matching the scheme used in Goyal et al. (2017). The linear scaling also holds for the effective learning rate of momentum SGD. In the small batch regime, the optimal effective learning rate for momentum SGD matches the optimal plain SGD learning rate, suggesting that the momentum and learning rate are interchangeable in the small batch regime.
While a fixed learning rate often works well for simple problems, good performance on the ImageNet benchmark (Russakovsky et al., 2015) requires a carefully tuned schedule. Here we explicitly optimize a piecewise constant learning rate schedule for SGD (with 50 pieces), in terms of the number of steps to reach the loss threshold. (For a given schedule and number of time steps, we obtain the exact risk using dynamic programming with eqn. (3). For stability, the learning rates are bounded above. For a fixed number of time steps, we minimize this risk using BFGS, and we determine the optimal number of time steps using binary search.) In Figure 2(b), we show that optimized learning rate schedules help significantly in the small batch regime, consistent with the analysis in Wu et al. (2018). We observe the same linear scaling as with fixed-learning-rate SGD, but with a better constant factor. In fact, optimized schedules nearly achieve the information theoretic optimum. However, learning rate schedules do not improve at all over fixed learning rates in the large batch regime. Figure 2(c) shows optimized schedules for different batch sizes; interestingly, they maintain a large learning rate throughout training followed by a roughly exponential decay, consistent with commonly used neural network training schedules. Additionally, even though the different batch sizes start with the same learning rate, their final learning rates at the end of training scale linearly with batch size (see Figure 13 in Appendix D.7).
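The dynamic programming used to score a schedule amounts to propagating the mean and variance of eqn. (3) step by step with a time-varying learning rate. This minimal sketch uses hypothetical helper names, two illustrative dimensions, $C = H$, and $\theta(0) \sim \mathcal{N}(0, I)$. It evaluates a constant schedule and a two-piece decaying one. It does not include the BFGS optimization of the pieces or the binary search over the number of steps.

```python
import numpy as np

def schedule_risk(schedule, h, c, batch_size, init_mean=0.0, init_var=1.0):
    """Exact expected risk after SGD with a per-step learning-rate schedule,
    propagating the mean and variance recursion of eqn. (3) in every dimension."""
    h = np.asarray(h, dtype=float)
    c = np.asarray(c, dtype=float)
    mean = np.full_like(h, init_mean)
    var = np.full_like(h, init_var)
    for alpha in schedule:
        mean = (1.0 - alpha * h) * mean
        var = (1.0 - alpha * h) ** 2 * var + alpha ** 2 * c / batch_size
    return float(np.sum(0.5 * h * (mean ** 2 + var)))  # E[l] = h/2 * (mean^2 + var)

def piecewise_constant(values, steps_per_piece):
    """Expand piecewise-constant learning rates into a per-step schedule."""
    return np.repeat(np.asarray(values, dtype=float), steps_per_piece)

h = np.array([1.0, 0.1])
c = h.copy()  # C = H, following Section 3.4
constant = schedule_risk(piecewise_constant([0.5], 2200), h, c, batch_size=1)
decayed = schedule_risk(piecewise_constant([0.5, 0.05], [200, 2000]), h, c, batch_size=1)
print(f"constant schedule: {constant:.4f}, two-piece decay: {decayed:.4f}")
```

For a constant schedule this recursion reproduces eqn. (4) exactly. Decaying the learning rate after the fast dimensions have converged trades a slower approach for a lower steady state risk.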
Table 1: Data sets and models used in our neural network experiments.

|Data set|Size|Model|Remarks|LR schedule|
|---|---|---|---|---|
|MNIST|55,000|Simple CNN|Same as Shallue et al. (2018) except without dropout regularization.|Constant|
|CIFAR10|45,000|ResNet8 without BN|Same as Shallue et al. (2018).|Constant|
|CIFAR10|45,000|ResNet32 with BN|Ghost batch norm is used.|Linear Decay|
|CIFAR10|45,000|VGG11 with BN|Ghost batch norm is used.|Linear Decay|
|LM1B|30M|Two-layer Transformer|Shallow model in Shallue et al. (2018)|Constant|
4 Neural Network Experiments
We investigated whether the predictions made by the NQM hold in practice by running experiments with five neural network architectures across three image classification tasks and one language modeling task (see Table 1). For each model and task, we compared a range of optimizers: SGD, momentum SGD, Adam (with and without momentum), and K-FAC (with and without momentum). See Appendix E for more details about our models and tasks.
The primary quantity we measured is the number of steps required to reach a target accuracy (for image classification tasks) or cross entropy (for language modeling). Unless otherwise specified, we measured steps to target on the validation set. We chose the target metric values based on an initial set of experiments with practical computational budgets. For each model, task, optimizer, and batch size, we independently tuned the learning rate, the parameters governing the learning rate schedule (where applicable), and optimizer-specific metaparameters (see Appendix E.4). We manually chose the search spaces based on our initial experiments, and we verified after each experiment that the optimal metaparameter values were far from the search space boundaries. We used quasi-random search (Bousquet et al., 2017) to tune the metaparameters with fixed budgets of non-divergent trials (100 for Simple CNN, ResNet8, and Transformer, and 200 for ResNet32 and VGG11). We discarded trials with a divergent training loss, which occurred when the learning rate was too high. We chose the trial that reached the target metric value in the fewest steps.
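Quasi-random search replaces random samples with a low-discrepancy sequence. The sketch below uses a two-dimensional Halton sequence (bases 2 and 3) as one such sequence. The exact sequence, the log-uniform search space over a learning rate and one minus the momentum, and the toy objective are all hypothetical. It follows the protocol described above: divergent trials are discarded and do not count toward the fixed budget, and the trial with the fewest steps to target is selected.

```python
def van_der_corput(n, base):
    """n-th term (n >= 1) of the van der Corput low-discrepancy sequence."""
    value, denom = 0.0, 1.0
    while n:
        n, digit = divmod(n, base)
        denom *= base
        value += digit / denom
    return value

def quasi_random_search(run_trial, budget, max_trials=10_000):
    """Quasi-random search over a 2-D Halton sequence (bases 2 and 3).

    run_trial(params) returns the number of steps to reach the target, or None
    if training diverged. Divergent trials do not count toward the budget.
    Returns (best_steps, best_params, number_of_non_divergent_trials).
    """
    completed = []
    for k in range(1, max_trials + 1):
        u1, u2 = van_der_corput(k, 2), van_der_corput(k, 3)
        params = {  # hypothetical log-uniform search space
            "learning_rate": 10.0 ** (-4.0 + 5.0 * u1),       # [1e-4, 10)
            "one_minus_momentum": 10.0 ** (-3.0 + 3.0 * u2),  # [1e-3, 1)
        }
        steps = run_trial(params)
        if steps is None:
            continue  # divergent trial: discard
        completed.append((steps, params))
        if len(completed) == budget:
            break
    best_steps, best_params = min(completed, key=lambda item: item[0])
    return best_steps, best_params, len(completed)

def toy_trial(params):
    """Hypothetical stand-in for a training run: diverges above lr = 1."""
    lr = params["learning_rate"]
    return None if lr > 1.0 else round(100.0 / lr)

best_steps, best_params, n_completed = quasi_random_search(toy_trial, budget=100)
print(best_steps, best_params, n_completed)
```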
4.1 Critical Batch Size Depends on the Optimizer
Figure 4 shows the relationship between batch size and steps to target for each model, task, and optimizer. In each case, as the batch size grows, there is an initial period of perfect scaling where doubling the batch size halves the steps to target, but once the batch size exceeds a problem-dependent critical batch size, there are rapidly diminishing returns, matching the results of Goyal et al. (2017), McCandlish et al. (2018), and Shallue et al. (2018). K-FAC has the largest critical batch size in all cases, highlighting the usefulness of preconditioning. Momentum SGD extends perfect scaling to larger batch sizes than plain SGD, but for batch sizes smaller than the plain SGD critical batch size, momentum SGD requires as many steps as plain SGD to reach the target. This is consistent with both the empirical results of Shallue et al. (2018) and our NQM simulations. By contrast, Adam and K-FAC can reduce the number of steps needed to reach the target compared to plain SGD even for the smallest batch sizes, although neither optimizer does so in all cases. Finally, we see some evidence that the benefit of momentum diminishes with preconditioning (Figures 3(a) and 3(b)), as predicted by our NQM simulations, although we do not see this in all cases (e.g. Figures 3(c) and 3(f)).
4.2 Optimal Learning Rate
The NQM predicts that the optimal constant learning rate for plain SGD (or effective learning rate for momentum SGD) scales linearly with batch size initially, and then levels off after a certain batch size. Figure 5 shows the empirical optimal (effective) learning rate as a function of batch size for simple CNN on MNIST and ResNet8 on CIFAR10. For small batch sizes, the optimal learning rate of plain SGD appears to match the optimal effective learning rate of momentum SGD. However, after a certain batch size, the optimal learning rate for plain SGD saturates while the optimal effective learning rate of momentum SGD keeps increasing. Interestingly, plain SGD and momentum SGD appear to deviate at the same batch size in the optimal effective learning rate and steps to target plots (Figures 4 and 5).
4.3 Steps to Target on the Training Set
Figure 6 shows the empirical relationship between batch size and steps to target, measured on the training set, for ResNet8 and ResNet32 on CIFAR10. For ResNet8, the curves are almost identical to those using validation accuracy (Figure 3(c)), but for ResNet32, the gaps between different optimizers become much smaller than in Figure 3(e) and the effects of momentum and preconditioning appear to become less significant. Nevertheless, the qualitative differences between optimizers are consistent with the validation set measurements.
In this work, we analyzed the interactions between the batch size and the optimization algorithm from two perspectives: experiments with real neural networks, and a noisy quadratic model with parameters chosen based on empirical observations about neural networks. Despite its simplicity, the noisy quadratic model agrees remarkably well with a variety of neural network training phenomena, including learning rate scaling, critical batch sizes, and the effects of momentum and preconditioning. More importantly, the noisy quadratic model allows us to run experiments in seconds, while it can take weeks, or even months, to conduct careful large-scale experiments with real neural networks. Therefore, the noisy quadratic model is a convenient and powerful way to quickly formulate testable predictions about neural network optimization.
- Ba et al. (2017) Jimmy Ba, Roger Grosse, and James Martens. Distributed second-order optimization using Kronecker-factored approximations. In International Conference on Learning Representations, 2017.
- Bae et al. (2018) Juhan Bae, Guodong Zhang, and Roger Grosse. Eigenvalue corrected noisy natural gradient. In Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2018.
- Balles et al. (2017) Lukas Balles, Javier Romero, and Philipp Hennig. Coupling adaptive batch sizes with learning rates. In Conference on Uncertainty in Artificial Intelligence (UAI) 2017. AUAI Press, 2017.
- Bottou and Bousquet (2008) Léon Bottou and Olivier Bousquet. The tradeoffs of large scale learning. In Advances in neural information processing systems, pages 161–168, 2008.
- Bousquet et al. (2017) Olivier Bousquet, Sylvain Gelly, Karol Kurach, Olivier Teytaud, and Damien Vincent. Critical hyper-parameters: No random, no cry. arXiv preprint arXiv:1706.03200, 2017.
- Chen et al. (2018) Lingjiao Chen, Hongyi Wang, Jinman Zhao, Dimitris Papailiopoulos, and Paraschos Koutris. The effect of network width on the performance of large-batch training. In Advances in Neural Information Processing Systems, pages 9302–9309, 2018.
- Chiang (1974) A.C. Chiang. Fundamental Methods of Mathematical Economics. International student edition. McGraw-Hill, 1974. ISBN 9780070107809.
- Du et al. (2019) Simon S. Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=S1eK3i09YQ.
- George et al. (2018) Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, and Pascal Vincent. Fast approximate natural gradient descent in a Kronecker-factored eigenbasis. In Advances in Neural Information Processing Systems, pages 9550–9560, 2018.
- Ghorbani et al. (2019) Behrooz Ghorbani, Shankar Krishnan, and Ying Xiao. An investigation into neural net optimization via hessian eigenvalue density. In Proceedings of the 36th International Conference on Machine Learning, pages 2232–2241, 2019.
- Goh (2017) Gabriel Goh. Why momentum really works. Distill, 2(4):e6, 2017.
- Golmant et al. (2018) Noah Golmant, Nikita Vemuri, Zhewei Yao, Vladimir Feinberg, Amir Gholami, Kai Rothauge, Michael W Mahoney, and Joseph Gonzalez. On the computational inefficiency of large batch sizes for stochastic gradient descent. arXiv preprint arXiv:1811.12941, 2018.
- Goodfellow et al. (2016) I. Goodfellow, Y. Bengio, and A. Courville. Deep Learning. MIT Press, 2016. http://www.deeplearningbook.org.
- Goyal et al. (2017) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training Imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
- Grosse and Martens (2016) Roger Grosse and James Martens. A kronecker-factored approximate fisher matrix for convolution layers. In International Conference on Machine Learning, pages 573–582, 2016.
- He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In
- Hoffer et al. (2017) Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: Closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, pages 1731–1741, 2017.
- Ioffe and Szegedy (2015) Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456, 2015.
- Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in neural information processing systems, pages 8571–8580, 2018.
- Jain et al. (2018) Prateek Jain, Sham M Kakade, Rahul Kidambi, Praneeth Netrapalli, and Aaron Sidford. Parallelizing stochastic gradient descent for least squares regression: mini-batching, averaging, and model misspecification. Journal of Machine Learning Research, 18(223):1–42, 2018.
- Jastrzębski et al. (2018) Stanisław Jastrzębski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in SGD. In International Conference on Artificial Neural Networks, 2018.
- Kidambi et al. (2018) Rahul Kidambi, Praneeth Netrapalli, Prateek Jain, and Sham Kakade. On the insufficiency of existing momentum schemes for stochastic optimization. In 2018 Information Theory and Applications Workshop (ITA), pages 1–9. IEEE, 2018.
- Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2014.
- Lee et al. (2019) Jaehoon Lee, Lechao Xiao, Samuel S Schoenholz, Yasaman Bahri, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720, 2019.
- Leen and Orr (1994) Todd K. Leen and Genevieve B. Orr. Optimal stochastic search and adaptive momentum. In J. D. Cowan, G. Tesauro, and J. Alspector, editors, Advances in Neural Information Processing Systems 6, pages 477–484. Morgan-Kaufmann, 1994. URL http://papers.nips.cc/paper/772-optimal-stochastic-search-and-adaptive-momentum.pdf.
- Ma et al. (2018) Siyuan Ma, Raef Bassily, and Mikhail Belkin. The power of interpolation: Understanding the effectiveness of SGD in modern over-parametrized learning. In International Conference on Machine Learning, pages 3331–3340, 2018.
- Martens (2014) James Martens. New insights and perspectives on the natural gradient method. arXiv preprint arXiv:1412.1193, 2014.
- Martens and Grosse (2015) James Martens and Roger Grosse. Optimizing neural networks with Kronecker-factored approximate curvature. In International Conference on Machine Learning, pages 2408–2417, 2015.
- McCandlish et al. (2018) Sam McCandlish, Jared Kaplan, Dario Amodei, and OpenAI Dota Team. An empirical model of large-batch training. arXiv preprint arXiv:1812.06162, 2018.
- Osawa et al. (2018) Kazuki Osawa, Yohei Tsuji, Yuichiro Ueno, Akira Naruse, Rio Yokota, and Satoshi Matsuoka. Second-order optimization method for large mini-batch: Training ResNet-50 on ImageNet in 35 epochs. arXiv preprint arXiv:1811.12019, 2018.
- Polyak and Juditsky (1992) Boris T Polyak and Anatoli B Juditsky. Acceleration of stochastic approximation by averaging. SIAM Journal on Control and Optimization, 30(4):838–855, 1992.
- Qian (1999) Ning Qian. On the momentum term in gradient descent learning algorithms. Neural Networks, 12(1):145–151, 1999.
- Russakovsky et al. (2015) Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 115(3):211–252, 2015.
- Sagun et al. (2016) Levent Sagun, Leon Bottou, and Yann LeCun. Eigenvalues of the Hessian in deep learning: Singularity and beyond. arXiv preprint arXiv:1611.07476, 2016.
- Schaul et al. (2013) Tom Schaul, Sixin Zhang, and Yann LeCun. No more pesky learning rates. In International Conference on Machine Learning, pages 343–351, 2013.
- Schraudolph (2002) Nicol N Schraudolph. Fast curvature matrix-vector products for second-order gradient descent. Neural Computation, 14(7):1723–1738, 2002.
- Shallue et al. (2018) Christopher J Shallue, Jaehoon Lee, Joe Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.
- Simonyan and Zisserman (2015) Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations, 2015.
- Smith et al. (2018) Samuel L. Smith, Pieter-Jan Kindermans, and Quoc V. Le. Don’t decay the learning rate, increase the batch size. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=B1Yy1BxCZ.
- Szegedy et al. (2016) Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2818–2826, 2016.
- Ubaru et al. (2017) Shashanka Ubaru, Jie Chen, and Yousef Saad. Fast estimation of tr(f(A)) via stochastic Lanczos quadrature. SIAM Journal on Matrix Analysis and Applications, 38(4):1075–1099, 2017.
- van Laarhoven (2017) Twan van Laarhoven. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 5998–6008, 2017.
- Wang et al. (2019) Chaoqi Wang, Roger Grosse, Sanja Fidler, and Guodong Zhang. Eigendamage: Structured pruning in the Kronecker-factored eigenbasis. In Proceedings of the 36th International Conference on Machine Learning, pages 6566–6575, 2019.
- Wu et al. (2018) Yuhuai Wu, Mengye Ren, Renjie Liao, and Roger Grosse. Understanding short-horizon bias in stochastic meta-optimization. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=H1MczcgR-.
- Yang et al. (2018) Lin Yang, Raman Arora, Tuo Zhao, et al. The physical systems behind optimization algorithms. In Advances in Neural Information Processing Systems, pages 4372–4381, 2018.
- Yin et al. (2018) Dong Yin, Ashwin Pananjady, Max Lam, Dimitris Papailiopoulos, Kannan Ramchandran, and Peter Bartlett. Gradient diversity: a key ingredient for scalable distributed learning. In International Conference on Artificial Intelligence and Statistics, pages 1998–2007, 2018.
- Yuan et al. (2016) Kun Yuan, Bicheng Ying, and Ali H. Sayed. On the influence of momentum acceleration on online learning. Journal of Machine Learning Research, 17(192):1–66, 2016. URL http://jmlr.org/papers/v17/16-157.html.
- Zhang et al. (2019a) Guodong Zhang, James Martens, and Roger Grosse. Fast convergence of natural gradient descent for overparameterized neural networks. arXiv preprint arXiv:1905.10961, 2019a.
- Zhang et al. (2019b) Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mechanisms of weight decay regularization. In International Conference on Learning Representations, 2019b. URL https://openreview.net/forum?id=B1lz-3Rct7.
Appendix A Kronecker-factored Approximate Curvature (K-FAC)
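Kronecker-factored approximate curvature (K-FAC) [Martens and Grosse, 2015] uses a Kronecker-factored approximation to the curvature matrix to perform efficient approximate natural gradient updates. Consider the $\ell$-th layer of a neural network with input activations $\mathbf{a} \in \mathbb{R}^{n_1}$, weight matrix $W \in \mathbb{R}^{n_1 \times n_2}$, and outputs $\mathbf{s} \in \mathbb{R}^{n_2}$, so that $\mathbf{s} = W^\top \mathbf{a}$. The weight gradient is therefore $\nabla_W L = \mathbf{a}(\nabla_{\mathbf{s}} L)^\top$. With this formula, K-FAC decouples this layer's Fisher matrix $F$ using an independence assumption:
$$F = \mathbb{E}\left[\mathrm{vec}\{\nabla_W L\}\,\mathrm{vec}\{\nabla_W L\}^\top\right] = \mathbb{E}\left[\nabla_{\mathbf{s}} L\,(\nabla_{\mathbf{s}} L)^\top \otimes \mathbf{a}\mathbf{a}^\top\right] \approx \mathbb{E}\left[\nabla_{\mathbf{s}} L\,(\nabla_{\mathbf{s}} L)^\top\right] \otimes \mathbb{E}\left[\mathbf{a}\mathbf{a}^\top\right] = G \otimes A,$$
where $A = \mathbb{E}[\mathbf{a}\mathbf{a}^\top]$ and $G = \mathbb{E}[\nabla_{\mathbf{s}} L\,(\nabla_{\mathbf{s}} L)^\top]$. Decomposing $F$ into $A$ and $G$ not only avoids the quadratic storage cost of the exact Fisher, but also enables tractable computation of the approximate natural gradient:
$$F^{-1}\,\mathrm{vec}\{\nabla_W L\} \approx (G^{-1} \otimes A^{-1})\,\mathrm{vec}\{\nabla_W L\} = \mathrm{vec}\{A^{-1}\,\nabla_W L\,G^{-1}\}. \tag{12}$$
As shown by eqn. (12), computing the natural gradient with K-FAC only involves operations on matrices comparable in size to $W$, which makes it very efficient.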
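As a minimal numerical sketch of the factored natural gradient, the NumPy snippet below checks that multiplying the weight gradient by $A^{-1}$ on the left and $G^{-1}$ on the right (where $A$ is the second moment of the layer inputs and $G$ that of the output gradients) matches a dense solve against the Kronecker-factored Fisher $G \otimes A$, using column-major vectorization. The layer sizes, the random data, and the small damping term added to each factor are illustrative assumptions, not settings from the experiments.

```python
import numpy as np

def vec(x):
    """Column-major vectorization: stack the columns of x."""
    return x.reshape(-1, order="F")

def kfac_natural_gradient(grad_w, A, G):
    """Return A^{-1} grad_w G^{-1} for a layer s = W^T a, with grad_w of shape (n1, n2),
    A of shape (n1, n1) and G of shape (n2, n2); uses solves rather than explicit inverses."""
    left = np.linalg.solve(A, grad_w)      # A^{-1} grad_w
    return np.linalg.solve(G, left.T).T    # left G^{-1}, since G is symmetric

rng = np.random.default_rng(0)
n1, n2, batch = 4, 3, 256
acts = rng.normal(size=(batch, n1))        # input activations a
out_grads = rng.normal(size=(batch, n2))   # output gradients dL/ds
damping = 1e-3                             # illustrative damping, not a tuned value
A = acts.T @ acts / batch + damping * np.eye(n1)
G = out_grads.T @ out_grads / batch + damping * np.eye(n2)
grad_w = acts.T @ out_grads / batch        # mini-batch average of a (dL/ds)^T

nat = kfac_natural_gradient(grad_w, A, G)
# Compare with a dense solve against the (n1*n2) x (n1*n2) matrix G kron A.
explicit = np.linalg.solve(np.kron(G, A), vec(grad_w))
print(np.allclose(vec(nat), explicit))     # True
```

The factored path only solves $n_1 \times n_1$ and $n_2 \times n_2$ systems, whereas the dense check has to build an $(n_1 n_2) \times (n_1 n_2)$ matrix, which is the storage and compute saving the paragraph describes.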
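Later, Grosse and Martens [2016] further extended K-FAC to convolutional layers under the additional assumptions of spatial homogeneity (SH) and spatially uncorrelated derivatives (SUD). Suppose the input is $\mathcal{A} \in \mathbb{R}^{C_{\text{in}} \times h \times w}$ and the output is $\mathcal{S} \in \mathbb{R}^{C_{\text{out}} \times h \times w}$. Then the gradient of the reshaped weight $W \in \mathbb{R}^{C_{\text{in}}k^2 \times C_{\text{out}}}$ is $\nabla_W L = \sum_{t \in \mathcal{T}} \mathbf{a}_t (\nabla_{\mathbf{s}_t} L)^\top$, and the corresponding Fisher matrix is:
$$F \approx \mathbb{E}\Big[\frac{1}{|\mathcal{T}|}\sum_{t \in \mathcal{T}} \nabla_{\mathbf{s}_t} L\,(\nabla_{\mathbf{s}_t} L)^\top\Big] \otimes \mathbb{E}\Big[\sum_{t \in \mathcal{T}} \mathbf{a}_t \mathbf{a}_t^\top\Big] \triangleq G \otimes A,$$
where $\mathcal{T}$ is the set of spatial locations, $\mathbf{a}_t \in \mathbb{R}^{C_{\text{in}}k^2}$ is the patch extracted from $\mathcal{A}$ at location $t$, $\nabla_{\mathbf{s}_t} L \in \mathbb{R}^{C_{\text{out}}}$ is the gradient with respect to spatial location $t$ in $\mathcal{S}$, and $k$ is the kernel size.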
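The convolutional factors can be estimated from already-extracted (im2col-style) patches, as sketched below. We assume the convention of Grosse and Martens [2016] under SH and SUD, in which the activation factor sums patch outer products over spatial locations while the gradient factor averages gradient outer products over them; the function name `conv_kfac_factors` and the array shapes are our own illustrative choices.

```python
import numpy as np

def conv_kfac_factors(patches, out_grads):
    """Hypothetical estimate of conv-layer K-FAC factors under the SH and SUD assumptions.

    patches:   (N, T, C_in * k * k) extracted input patches a_t at T spatial locations
    out_grads: (N, T, C_out) gradients dL/ds_t at the same locations
    Returns A = E[sum_t a_t a_t^T] and G = E[(1/T) sum_t g_t g_t^T].
    """
    n, t, _ = patches.shape
    A = np.einsum("nti,ntj->ij", patches, patches) / n            # sum over locations, mean over examples
    G = np.einsum("nti,ntj->ij", out_grads, out_grads) / (n * t)  # mean over locations and examples
    return A, G

rng = np.random.default_rng(0)
N, T, C_in, k, C_out = 8, 16, 3, 3, 5
A, G = conv_kfac_factors(rng.normal(size=(N, T, C_in * k * k)), rng.normal(size=(N, T, C_out)))
print(A.shape, G.shape)   # (27, 27) (5, 5)
```

The factors have the sizes of the reshaped weight's rows and columns, so the same two-sided solve used for fully connected layers applies unchanged.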
a.1 K-FAC for Transformer
K-FAC has previously been applied to autoencoders [Martens and Grosse, 2015] and to various convolutional networks [Grosse and Martens, 2016, Ba et al., 2017]. To our knowledge, this is the first time K-FAC has been implemented for the Transformer model. What differs from the previous models is the weight matrix shared between the embedding layer and the pre-softmax linear transformation [Vaswani et al., 2017]. In particular, the weight matrix is transposed at the pre-softmax layer: $\mathbf{s} = W^\top \mathbf{a}$ in the embedding layer and $\mathbf{s} = W\mathbf{a}$ in the pre-softmax layer. With the same assumptions as in the non-transposed case, the pre-softmax layer gives
$$F \approx \mathbb{E}\left[\mathbf{a}\mathbf{a}^\top\right] \otimes \mathbb{E}\left[\nabla_{\mathbf{s}} L\,(\nabla_{\mathbf{s}} L)^\top\right] = A \otimes G,$$
i.e. the positions of the two Kronecker factors are swapped. If we name the two Kronecker factors of the shared weight matrix by their roles in the embedding layer, i.e. the "input factor" (vocabulary size by vocabulary size) and the "output factor" (model dimension by model dimension), then the input factor receives contributions from both the embedding inputs and the gradients of the pre-softmax layer outputs, and the output factor receives contributions from both the pre-softmax layer inputs and the gradients of the embedding outputs. In practice, when computing a Kronecker factor, we treat contributions from multiple sources in the same way as contributions from multiple training examples in a mini-batch. Also note that because of the high dimensionality of the embedding weight matrix (with a vocabulary size of 32,768), the dense input factor would have size $32{,}768 \times 32{,}768$. To save memory, we use a diagonal matrix to estimate the input factor. The output factor is still estimated with a dense matrix.
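The pooling of two sources per factor and the diagonal input factor can be sketched as follows. This is a hypothetical helper (`tied_embedding_factors` is our name), not the authors' implementation; we assume the shared weight has shape (vocabulary size, model dimension), the embedding computes $\mathbf{s} = W^\top\mathbf{a}$ for one-hot $\mathbf{a}$, the pre-softmax layer computes $\mathbf{s} = W\mathbf{a}$, and "treating sources like extra examples" means averaging over the combined count. The memory figure assumes float32 storage.

```python
import numpy as np

def tied_embedding_factors(token_ids, emb_out_grads, presoftmax_inputs, logit_grads, vocab_size):
    """Hypothetical K-FAC factor estimate for a weight W of shape (vocab_size, d_model)
    shared by the embedding (s = W^T a, a one-hot) and the pre-softmax layer (s = W a).

    token_ids:         (N_emb,) integer token ids fed to the embedding
    emb_out_grads:     (N_emb, d_model) gradients w.r.t. the embedding outputs
    presoftmax_inputs: (N_pre, d_model) inputs to the pre-softmax layer
    logit_grads:       (N_pre, vocab_size) gradients w.r.t. the logits
    """
    # Input factor (V x V), kept diagonal: one-hot inputs contribute token counts and logit
    # gradients contribute their squared entries; both sources are pooled as extra examples.
    counts = np.bincount(token_ids, minlength=vocab_size).astype(np.float64)
    input_diag = (counts + np.sum(logit_grads ** 2, axis=0)) / (len(token_ids) + len(logit_grads))

    # Output factor (d_model x d_model), kept dense: embedding output gradients and
    # pre-softmax inputs, pooled the same way.
    n_out = len(emb_out_grads) + len(presoftmax_inputs)
    output_factor = (emb_out_grads.T @ emb_out_grads + presoftmax_inputs.T @ presoftmax_inputs) / n_out
    return input_diag, output_factor

rng = np.random.default_rng(0)
vocab, d_model, n_emb, n_pre = 50, 8, 64, 64
input_diag, output_factor = tied_embedding_factors(
    rng.integers(0, vocab, size=n_emb),
    rng.normal(size=(n_emb, d_model)),
    rng.normal(size=(n_pre, d_model)),
    rng.normal(size=(n_pre, vocab)),
    vocab,
)

# Memory for a dense float32 input factor at the vocabulary size quoted above.
dense_bytes = 32_768 ** 2 * 4
print(dense_bytes / 2 ** 30, "GiB")   # 4.0 GiB
```

Storing only the diagonal reduces the input factor from about 4 GiB to 32,768 entries, and inverting it becomes an elementwise division.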
Appendix B Dynamics of momentum SGD on noisy quadratic model
Similar to plain SGD, by treating as a random variable, we can explicitly write down the dynamics of its expectation and variance. However, because of momentum, we also need to take into account and its correlation with . Because each dimension evolves independently, we drop the dimension subscripts. We first calculate the expectation of the parameter and velocity:
We then calculate the variance:
where evolves as
Because the expected risk is entirely determined by , we define and . We can then simplify the dynamics as follows:
The convergence rate is determined by the transition matrix which has the characteristic polynomial
With the momentum value , all eigenvalues of the transition matrix are equal to each other with the value , giving the fastest convergence.
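To illustrate the critical-damping claim numerically, the sketch below assumes the heavy-ball form $m_{t+1} = \beta m_t + h\theta_t$, $\theta_{t+1} = \theta_t - \alpha m_{t+1}$ on a one-dimensional quadratic with curvature $h$ and learning rate $\alpha$ (our notation, chosen for illustration). It builds the $2 \times 2$ transition matrix of $(\mathbb{E}[\theta_t], \mathbb{E}[m_t])$, whose characteristic polynomial is $\lambda^2 - (1 - \alpha h + \beta)\lambda + \beta$, and checks that $\beta = (1 - \sqrt{\alpha h})^2$ gives a repeated eigenvalue $1 - \sqrt{\alpha h}$ and the smallest spectral radius on a grid of momentum values. The $3 \times 3$ matrix for the second moments has eigenvalues equal to pairwise products of these two, so they also coincide at this $\beta$.

```python
import numpy as np

def expectation_transition(alpha, beta, h):
    """Transition matrix for (E[theta_t], E[m_t]) under the assumed heavy-ball update
    m_{t+1} = beta * m_t + h * theta_t,  theta_{t+1} = theta_t - alpha * m_{t+1}."""
    return np.array([[1.0 - alpha * h, -alpha * beta],
                     [h, beta]])

def spectral_radius(m):
    return float(np.max(np.abs(np.linalg.eigvals(m))))

alpha, h = 0.01, 1.0
beta_star = (1.0 - np.sqrt(alpha * h)) ** 2   # critical damping
eig_star = np.linalg.eigvals(expectation_transition(alpha, beta_star, h))
print(beta_star, eig_star)                    # both eigenvalues close to 1 - sqrt(alpha * h)

betas = np.linspace(0.0, 0.99, 100)
radii = [spectral_radius(expectation_transition(alpha, b, h)) for b in betas]
print(spectral_radius(expectation_transition(alpha, beta_star, h)), min(radii))
```

With $\alpha h = 0.01$ the critically damped contraction factor is about 0.9 per step, compared with $1 - \alpha h = 0.99$ without momentum.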
Appendix C Proof of Theorem 1
To analyze the dynamics, we perform a change of basis so that the three dimensions evolve independently. To achieve this, we first take the eigendecomposition of the transition matrix . (Note that we implicitly assume , since otherwise the transition matrix is not diagonalizable.) Then the dynamics can be reformulated as follows:
We first analyze the stochastic term alone. By the identity , we have
In eqn. (22), we append a zero vector for convenience. To compute the infinite sum, we first focus on a single term. We have the following update:
Since we only care about , which entirely determine the loss, we eliminate by merging the two updates, which yields a second-order difference equation:
with initial conditions and . To solve this second-order difference equation, we use the Z-transform to obtain a closed form. Basic manipulations of the Z-transform give the Z-domain function
where and are the two roots of the equation . We then use the inverse Z-transform to obtain :
Now, we are ready to compute the infinite sum :
Because and are the two roots with , , we have
Now, we analyze the deterministic term. We have
Similar to the analysis of the stochastic term, we obtain the same second-order difference equation
except that the initial conditions become . Applying the Z-transform, we have
Along with eqn. (29), we have
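The root-based closed form produced by the Z-transform can be checked against direct iteration. The sketch below assumes the deterministic heavy-ball recurrence $\theta_{t+1} = (1 - \alpha h + \beta)\theta_t - \beta\theta_{t-1}$, obtained by eliminating the velocity from $m_{t+1} = \beta m_t + h\theta_t$, $\theta_{t+1} = \theta_t - \alpha m_{t+1}$, with $m_0 = 0$ so that $\theta_1 = (1 - \alpha h)\theta_0$. With distinct roots $r_1, r_2$ the solution is $\theta_t = \frac{(\theta_1 - r_2\theta_0) r_1^t - (\theta_1 - r_1\theta_0) r_2^t}{r_1 - r_2}$; this is our notation rather than the exact equations above, and complex arithmetic covers the underdamped case.

```python
import cmath

def heavy_ball_closed_form(t, theta0, theta1, alpha, beta, h):
    """theta_t for theta_{t+1} = (1 - alpha*h + beta) theta_t - beta theta_{t-1}, written with the
    roots r1, r2 of r^2 - (1 - alpha*h + beta) r + beta = 0 (assumes r1 != r2)."""
    b = 1.0 - alpha * h + beta
    disc = cmath.sqrt(b * b - 4.0 * beta)
    r1, r2 = (b + disc) / 2.0, (b - disc) / 2.0
    value = ((theta1 - r2 * theta0) * r1 ** t - (theta1 - r1 * theta0) * r2 ** t) / (r1 - r2)
    return value.real   # imaginary parts cancel for real initial conditions

def heavy_ball_iterate(steps, theta0, theta1, alpha, beta, h):
    """Iterate the same recurrence directly and return [theta_0, ..., theta_{steps-1}]."""
    b = 1.0 - alpha * h + beta
    thetas = [theta0, theta1]
    while len(thetas) < steps:
        thetas.append(b * thetas[-1] - beta * thetas[-2])
    return thetas

alpha, h, theta0 = 0.1, 1.0, 1.0
theta1 = (1.0 - alpha * h) * theta0      # first step from m_0 = 0
for beta in (0.3, 0.5):                  # real roots (overdamped) and complex roots (underdamped)
    seq = heavy_ball_iterate(50, theta0, theta1, alpha, beta, h)
    err = max(abs(heavy_ball_closed_form(t, theta0, theta1, alpha, beta, h) - seq[t]) for t in range(50))
    print(beta, err)
```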
Appendix D More results on the NQM
d.1 Eigenspectra of Neural Networks
The main objective of this section is to examine the loss surface of modern neural networks at different stages of training in order to justify the assumptions made in the NQM. However, it is hard to visualize such a high-dimensional space. Following recent work [Sagun et al., 2016, Ghorbani et al., 2019], we instead analyze the eigenspectrum of the Hessian/Fisher matrices. The Hessian/Fisher of the training loss (with respect to the parameters) is crucial in determining many behaviors of neural networks: its eigenvalues characterize the local curvature of the loss surface, which governs many aspects of training, including the convergence rates of first-order optimization methods (at least for convex problems).
To estimate the eigenspectrum of the true Fisher matrix, we use the Kronecker-factored approximation of the Fisher, whose eigenspectrum may shed light on the true one. Specifically, we train the network with K-FAC and then perform an eigendecomposition of the saved Kronecker factors of the Fisher to compute the eigenvalues.
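Kronecker factors make the spectrum cheap to obtain because the eigenvalues of $G \otimes A$ are exactly all products of an eigenvalue of $G$ with an eigenvalue of $A$. The sketch below checks this identity on random symmetric positive definite stand-ins for the factors (illustrative data, not saved K-FAC factors).

```python
import numpy as np

rng = np.random.default_rng(0)

def random_spd(n):
    """Random symmetric positive definite matrix used as a stand-in Kronecker factor."""
    x = rng.normal(size=(n, 2 * n))
    return x @ x.T / (2 * n)

A = random_spd(5)   # stand-in for an input-activation factor
G = random_spd(3)   # stand-in for an output-gradient factor

# All pairwise products of the factors' eigenvalues, without forming the 15 x 15 product.
kron_eigs = np.sort(np.outer(np.linalg.eigvalsh(G), np.linalg.eigvalsh(A)).ravel())
dense_eigs = np.linalg.eigvalsh(np.kron(G, A))   # returned in ascending order
print(np.allclose(kron_eigs, dense_eigs))        # True
```

For a real layer this means the full approximate spectrum costs two small eigendecompositions rather than one over all parameters.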
The eigenspectra are plotted in Figure 7. One interesting observation is that there are only a few large eigenvalues and a few small eigenvalues in the approximate Fisher matrices; the bulk of eigenvalues are in the middle of the spectrum. We also note that after 200 iterations of training the eigenspectrum remains mostly unchanged.
Scatter plots of the second moment vs. the variance of the gradients.
The gradients are projected onto the Kronecker-factored eigenbasis, which approximates the eigenbasis of the true Fisher. Each point compares the gradient variance and the second moment of the gradient in the direction of an eigenvector of the K-FAC approximated Fisher.
d.2 Gradient Covariance in the Kronecker-Factored Eigenbasis
To verify the assumption in Section 3.4 that and are codiagonalizable, we test it on practical neural networks by comparing the gradient variance to the curvature. This assumption is motivated by theoretical considerations that suggest for neural network training [Martens, 2014]. Ideally, we would like to compare the gradient variance and the curvature of the Fisher in the directions of the eigenvectors of the true Fisher. However, it is typically infeasible to compute all of these eigenvectors, especially for low-curvature directions. To address this, we instead use the Kronecker-factored eigenbasis [George et al., 2018, Bae et al., 2018, Wang et al., 2019], which is obtained from the K-FAC approximation. For this experiment, we are not relying on this basis being an accurate approximation to the eigendecomposition of the true Fisher; rather, we use the eigenbasis only as a way to obtain a diverse set of directions with both high and low curvature. For a given eigenvector , we project the gradients of each training example onto and compute the gradient variance , as well as the curvature . (The latter quantity can be obtained using matrix-vector products [Schraudolph, 2002].) As shown in Figure 8, the gradient variances closely match the curvature (especially for the ResNet8 model on CIFAR10), validating our assumption that .
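The two quantities compared in this experiment can be computed from per-example gradients as sketched below. For simplicity we assume the empirical Fisher $F = \mathbb{E}[\mathbf{g}\mathbf{g}^\top]$, for which the curvature along a unit vector $\mathbf{v}$ equals the second moment $\mathbb{E}[(\mathbf{g}^\top\mathbf{v})^2]$; the experiment described above instead obtains curvature via Fisher matrix-vector products. Variance and second moment differ exactly by the squared mean projection $(\mathbb{E}[\mathbf{g}]^\top\mathbf{v})^2$, so they nearly coincide when the mean gradient is small relative to its noise. The gradients here are synthetic.

```python
import numpy as np

def projected_stats(per_example_grads, v):
    """Mean, variance, and second moment of per-example gradients projected onto direction v.

    Under the empirical-Fisher assumption F = E[g g^T], the second moment equals v^T F v
    for unit v, i.e. the curvature along v."""
    v = v / np.linalg.norm(v)
    proj = per_example_grads @ v         # g_i^T v for every example i
    mean = np.mean(proj)
    variance = np.var(proj)              # E[(g^T v)^2] - (E[g^T v])^2
    second_moment = np.mean(proj ** 2)   # v^T F_emp v
    return mean, variance, second_moment

rng = np.random.default_rng(0)
grads = rng.normal(loc=0.05, size=(1000, 10))   # synthetic per-example gradients
mean, variance, second_moment = projected_stats(grads, rng.normal(size=10))
print(variance, second_moment, second_moment - variance - mean ** 2)
```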
d.3 Plots for the Evolution of the First Term in Eqn. (6)
In Section 3.2, we claimed that the convergence of momentum SGD in a single dimension is very close to that of plain SGD with an adjusted learning rate (note that we already verified in Figure 2 that the steady state risk of momentum SGD matches that of plain SGD with the effective learning rate). Here we verify this claim by comparing the two update rules in the NQM. The total risk consists of two terms (eqn. (6)): the first term determines convergence, while the second term (the steady state risk) stays constant throughout training. Since the second term is unchanged, we only plot the first term of eqn. (6) in Figure 9. Note that the values in the figures are normalized. We observe that the convergence dynamics of the two update rules closely match each other. For this experiment we set , but the results are not sensitive to this value.
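As a quick check of this comparison, the sketch below assumes the heavy-ball update $m_{t+1} = \beta m_t + h\theta_t$, $\theta_{t+1} = \theta_t - \alpha m_{t+1}$ and the effective learning rate $\alpha/(1-\beta)$ (our notation). It compares the per-step decay of the slowest mode of $\mathbb{E}[\theta_t]$ under momentum with that of plain SGD at the effective learning rate. A first-order expansion suggests a relative gap of roughly $\beta\alpha h/(1-\beta)^2$, so agreement should improve as $\alpha h$ shrinks relative to $(1-\beta)^2$. This compares asymptotic rates only, not the normalized curves in the figure.

```python
import numpy as np

def momentum_rate(alpha, beta, h):
    """Largest |root| of r^2 - (1 - alpha*h + beta) r + beta = 0: the per-step decay factor
    of the slowest mode of E[theta_t] under the assumed heavy-ball update."""
    return float(np.max(np.abs(np.roots([1.0, -(1.0 - alpha * h + beta), beta]))))

def plain_sgd_rate(alpha_eff, h):
    """Per-step decay factor of E[theta_t] under plain SGD with learning rate alpha_eff."""
    return abs(1.0 - alpha_eff * h)

beta, h = 0.9, 1.0
for alpha in (1e-4, 1e-3):
    alpha_eff = alpha / (1.0 - beta)   # assumed effective learning rate
    print(f"alpha={alpha:g}: momentum decay {1 - momentum_rate(alpha, beta, h):.6f}, "
          f"plain SGD decay {1 - plain_sgd_rate(alpha_eff, h):.6f}")
```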
d.4 Verification of Eigenspectrum
In Section 3.6, we assume the diagonal entries of are . To justify this choice, we compare the K-FAC eigenspectra of ResNet8 to this distribution in Figure 10. The distribution of eigenvalues we chose for in the NQM very closely matches the eigenspectra of the real neural network, validating the assumption that the diagonal entries of are in Section 3.4.
d.5 Effect of Loss Threshold
Recall that a main objective of this work is to characterize the effect of increasing the batch size on training time, as measured by the number of steps needed to reach a target error/loss. Here we experiment with different loss thresholds to study the relationship between batch size and the number of training steps. To obtain the minimal number of training steps for a given batch size, we perform a grid search over constant learning rates. Figure 11 shows that increasing the batch size initially decreases the required number of training steps proportionally, but eventually yields diminishing returns, which matches the empirical findings [Golmant et al., 2018, Shallue et al., 2018]. The shape of the curves is qualitatively the same for different loss thresholds, though the critical batch size seems to increase for more difficult thresholds.
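The grid-search procedure can be sketched in closed form for plain SGD on a noisy quadratic. Every modeling choice here is an illustrative assumption rather than the setting behind Figure 11: a diagonal curvature spectrum spaced geometrically between 1 and 0.01, per-dimension gradient noise variance $c_i/B$ with $c_i = h_i$, $\theta_0 = \mathbf{1}$, and a loss threshold of 0.1. For a constant learning rate $\alpha$, each dimension has $\mathbb{E}[\theta_t] = (1-\alpha h_i)^t\theta_{0,i}$ and $\mathrm{Var}[\theta_t] = \frac{\alpha^2 c_i}{B}\frac{1-(1-\alpha h_i)^{2t}}{1-(1-\alpha h_i)^2}$, and the expected risk is $\frac12\sum_i h_i(\mathbb{E}[\theta_{t,i}]^2 + \mathrm{Var}[\theta_{t,i}])$.

```python
import numpy as np

def expected_risk(alpha, batch_size, h, c, theta0, max_steps):
    """Expected NQM risk of plain SGD with constant learning rate alpha for t = 1..max_steps.
    Assumes per-dimension gradient noise variance c_i / batch_size and Var[theta_0] = 0."""
    t = np.arange(1, max_steps + 1)[:, None]          # (max_steps, 1)
    q = (1.0 - alpha * h) ** 2                        # per-dimension contraction of second moments
    qt = q ** t                                       # (max_steps, dim)
    mean_sq = qt * theta0 ** 2                        # E[theta_t]^2
    var = alpha ** 2 * c / batch_size * (1.0 - qt) / (1.0 - q)
    return 0.5 * np.sum(h * (mean_sq + var), axis=1)  # shape (max_steps,)

def steps_to_target(batch_size, target, h, c, theta0, alphas, max_steps):
    """Fewest steps to reach `target`, minimized over a grid of constant learning rates."""
    best = np.inf
    for alpha in alphas:
        hit = np.nonzero(expected_risk(alpha, batch_size, h, c, theta0, max_steps) <= target)[0]
        if hit.size:
            best = min(best, int(hit[0]) + 1)          # index 0 corresponds to t = 1
    return best

h = np.geomspace(1.0, 0.01, 20)                       # illustrative curvature spectrum
c = h.copy()                                          # illustrative noise: c_i = h_i
theta0 = np.ones_like(h)
alphas = np.geomspace(1e-3, 1.9 / h.max(), 30)        # stay below the stability limit 2 / h_max
batch_sizes = [1, 4, 16, 64, 256, 1024]
steps = [steps_to_target(b, 0.1, h, c, theta0, alphas, 5000) for b in batch_sizes]
for b, s in zip(batch_sizes, steps):
    print(b, s)
```

Because a larger batch only shrinks the variance term at every step, the minimal step count cannot increase with $B$; at large $B$ the largest stable learning rate caps the speedup, which is where the curve flattens.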
d.6 Results of Optimal Learning Rate on NQM
d.7 Final Learning Rate of Different Batch Sizes for PWC Learning Rate Scheme
In Section 3.6.2, we study the piecewise constant learning rate scheme. The optimal scheme starts with a high learning rate which drops later in training (Figure 2(c)). Recall that for fixed learning rates, we observed that the optimal learning rate scaled linearly with the batch size for small batch sizes, but it is unclear whether there is a similar phenomenon for learning rate decay. In Figure 13, we plot the final learning rate as a function of batch size and show that it also scales linearly with batch size.
Appendix E More Details for Experiments
e.1 Data Sets
e.2 Model Details
This section provides details of the models in Table 1. The models are very similar (and some identical) to those used in Shallue et al. [2018] (described in their Appendix B). Any modifications are highlighted in this section.
, we did not use any dropout regularization (while they used dropout with probability 0.4 in the fully connected layer). We used 32 and 64 filters in the convolutional layers and 1,024 units in the fully connected layer. This corresponds to the “base” configuration in Shallue et al. [2018].
ResNet8 [He et al., 2016] consists of 7 convolutional layers with residual connections followed by 1 fully connected hidden layer. We used the same architecture as Shallue et al. [2018]; in particular, we did not use batch normalization. The only difference is that we used data augmentation in our experiments.
ResNet32 [He et al., 2016] consists of 31 convolutional layers with residual connections followed by 1 fully connected hidden layer (see Section 4.2 of He et al. [2016]). We replaced batch normalization [Ioffe and Szegedy, 2015] with ghost batch normalization to keep the training objective fixed across batch sizes and to avoid possible negative effects from computing batch normalization statistics over a large number of examples [Hoffer et al., 2017]. We used a ghost batch size of 32 for all experiments. We also applied label smoothing [Szegedy et al., 2016] to regularize the model during training, which was helpful for larger batch sizes. We set the label smoothing parameter to 0.1 in all experiments. Instead of using weight decay, we applied channel-wise weight normalization by constraining the Frobenius norm of each convolutional channel to be exactly 1, which controls the effective learning rate [Zhang et al., 2019b, van Laarhoven, 2017].
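The three regularization choices above can be sketched in NumPy as follows. These are minimal illustrations under our own assumptions, not the training code: ghost batch normalization normalizes each consecutive group of 32 examples with that group's own statistics (omitting the learned scale and shift and any running averages), the convolution kernel uses an HWIO layout with "channel" taken to mean an output channel, and label smoothing mixes the one-hot target with a uniform distribution using parameter 0.1.

```python
import numpy as np

def ghost_batch_norm(x, ghost_size=32, eps=1e-5):
    """Normalize each group of `ghost_size` consecutive examples with its own statistics.
    x has shape (batch, features) with batch divisible by ghost_size; scale/shift omitted."""
    batch, features = x.shape
    groups = x.reshape(batch // ghost_size, ghost_size, features)
    mean = groups.mean(axis=1, keepdims=True)
    var = groups.var(axis=1, keepdims=True)
    return ((groups - mean) / np.sqrt(var + eps)).reshape(batch, features)

def normalize_conv_channels(kernel):
    """Rescale each output channel of an HWIO kernel to unit Frobenius norm."""
    norms = np.sqrt(np.sum(kernel ** 2, axis=(0, 1, 2), keepdims=True))
    return kernel / norms

def smooth_labels(labels, num_classes, smoothing=0.1):
    """Mix one-hot targets with the uniform distribution over classes."""
    one_hot = np.eye(num_classes)[labels]
    return (1.0 - smoothing) * one_hot + smoothing / num_classes

rng = np.random.default_rng(0)
x = 3.0 * rng.normal(size=(128, 8)) + 1.0
y = ghost_batch_norm(x)                                        # 4 ghost batches of 32
kernel = normalize_conv_channels(rng.normal(size=(3, 3, 16, 32)))
targets = smooth_labels(np.array([0, 2, 1]), num_classes=10)
```

Because each ghost batch always has 32 examples, the normalization statistics, and hence the training objective, do not change when the overall batch size grows.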
VGG11 [Simonyan and Zisserman, 2015] consists of 8 convolutional layers followed by 1 fully connected hidden layer. As in ResNet32, we used ghost batch normalization, label smoothing, and channel-wise weight normalization.
e.3 Learning Rate Schedules
This section describes the two learning rate schedules mentioned in Table 1: the constant schedule and the linear decay schedule. The constant schedule simply keeps a fixed learning rate throughout training:
where is the training step index. The linear decay schedule is
where is the initial learning rate, is the rate of decay, and is the number of steps taken to reach the final learning rate. Shallue et al. [2018] experimented with various learning rate schedules and found that linear decay matched the performance of the other schedules with fewer hyperparameters to tune. Therefore, we also chose the linear decay schedule, for which we tuned , and .
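For concreteness, the sketch below implements a linear decay schedule of the kind described above. The parameterization is our assumption, since the equation is not reproduced here: the learning rate falls linearly from `initial_lr` to `decay_factor * initial_lr` over `decay_steps` steps and is then held constant; setting `decay_factor = 1` recovers the constant schedule.

```python
def linear_decay_lr(step, initial_lr, decay_factor, decay_steps):
    """Assumed linear decay: interpolate from initial_lr to decay_factor * initial_lr over
    decay_steps steps, then hold the final value. decay_factor = 1 gives a constant schedule."""
    if step >= decay_steps:
        return decay_factor * initial_lr
    return initial_lr - (1.0 - decay_factor) * initial_lr * step / decay_steps

schedule = [linear_decay_lr(s, 0.1, 0.01, 1000) for s in (0, 500, 1000, 2000)]
print(schedule)
```

Under this parameterization the three tuned quantities map to `initial_lr`, `decay_factor`, and `decay_steps`.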
e.4 Optimizer-Specific Hyperparameters
For momentum SGD, we tuned the momentum . For Adam, we tuned , , and (see Kingma and Ba [2014]). For K-FAC, we tuned the damping and the trust region constraint (also known as the KL clipping term) for the Transformer, keeping the momentum and the moving average parameter for damping ; for all other models, we tuned all four parameters (see Martens and Grosse [2015]).