Which Algorithmic Choices Matter at Which Batch Sizes? Insights From a Noisy Quadratic Model

07/09/2019 ∙ by Guodong Zhang, et al. ∙ UNIVERSITY OF TORONTO

Increasing the batch size is a popular way to speed up neural network training, but beyond some critical batch size, larger batch sizes yield diminishing returns. In this work, we study how the critical batch size changes based on properties of the optimization algorithm, including acceleration and preconditioning, through two different lenses: large scale experiments, and analysis of a simple noisy quadratic model (NQM). We experimentally demonstrate that optimization algorithms that employ preconditioning, specifically Adam and K-FAC, result in much larger critical batch sizes than stochastic gradient descent with momentum. We also demonstrate that the NQM captures many of the essential features of real neural network training, despite being drastically simpler to work with. The NQM predicts our results with preconditioned optimizers, previous results with accelerated gradient descent, and other results around optimal learning rates and large batch training, making it a useful tool to generate testable predictions about neural network optimization.




1 Introduction

Increasing the batch size is one of the most appealing ways to accelerate neural network training on data parallel hardware. Larger batch sizes yield better gradient estimates and, up to a point, reduce the number of steps required for training, which reduces the training time. The importance of understanding the benefits of modern parallel hardware has motivated a lot of recent work on training neural networks with larger batch sizes (Goyal et al., 2017; Osawa et al., 2018; McCandlish et al., 2018; Shallue et al., 2018). To date, the most comprehensive empirical study of the effects of batch size on neural network training is Shallue et al. (2018), who confirmed that increasing the batch size initially achieves perfect scaling (i.e. doubling the batch size halves the number of steps needed) up to a problem-dependent critical batch size, beyond which it yields diminishing returns (Balles et al., 2017; Goyal et al., 2017; Jastrzębski et al., 2018; McCandlish et al., 2018). Shallue et al. (2018) also provided experimental evidence that the critical batch size depends on the optimization algorithm, the network architecture, and the data set. However, their experiments only covered plain SGD, SGD with (heavy-ball) momentum, and SGD with Nesterov momentum, leaving open the enticing possibility that other optimizers might extend perfect scaling to even larger batch sizes.

Empirical scaling curves like those in Shallue et al. (2018) are essential for understanding the effects of batch size, but generating such curves, even for a single optimizer on a single task, can be very expensive. On the other hand, existing theoretical analyses that attempt to analytically derive critical batch sizes (e.g. Ma et al. (2018); Yin et al. (2018); Jain et al. (2018)) do not answer our questions about which optimizers scale the best with batch size. They tend to make strong assumptions, produce parameter-dependent results that are difficult to apply, or are restricted to plain SGD. It would be ideal to find a middle ground between a purely empirical investigation and theoretical analysis by building a model of neural network optimization problems that captures the essential behavior we see in real neural networks, while still being easy to understand. Additionally, we need to study optimizers beyond momentum SGD since they might let us exploit speedups at the very largest batch sizes. In this work, we make the following contributions:

  1. We show that a simple noisy quadratic model (NQM) is remarkably consistent with the batch size effects observed in real neural networks, while allowing us to run experiments in seconds, making it a great tool to generate testable predictions about neural network optimization.

  2. We show that the NQM successfully predicts that momentum should speed up training relative to plain SGD at larger batch sizes, but do nothing at small batch sizes.

  3. Through large scale experiments with Adam (Kingma and Ba, 2014) and K-FAC (Martens and Grosse, 2015), we confirm that, as predicted by the NQM, preconditioning extends perfect batch size scaling to larger batch sizes than are possible with momentum SGD alone. Furthermore, unlike momentum, preconditioning can help at small batch sizes as well.

2 Related Work

In a classic paper, Bottou and Bousquet (2008) studied the asymptotics of stochastic optimization algorithms and found SGD to be competitive with fancier approaches. They showed that stochastic optimization involves fundamentally different tradeoffs from full-batch optimization. More recently, several studies have investigated the relationship between batch size and training time for neural networks. Chen et al. (2018) studied the effect of network width on the critical batch size, and showed experimentally that it depends on both the data set and network architecture. Golmant et al. (2018) studied how various heuristics for adjusting the learning rate as a function of batch size affect the relationship between batch size and training time.

Shallue et al. (2018) conducted a comprehensive empirical study on the relationship between batch size and training time with different neural network architectures and data sets using plain SGD, heavy-ball momentum, and Nesterov momentum. Finally, McCandlish et al. (2018) used the average gradient noise over training to predict the critical batch size. All of these studies described a basic relationship between batch size and training steps to a fixed error goal, which comprises three regions: perfect scaling initially, then diminishing returns, and finally no benefit for all batch sizes greater than the critical batch size.

Other studies have attempted to characterize the critical batch size analytically in stochastic optimization. Under varying assumptions, Ma et al. (2018); Yin et al. (2018); Jain et al. (2018) all derived analytical notions of critical batch size, but to our knowledge, all for SGD.

Finally, previous studies have shown that SGD and momentum SGD are equivalent for small learning rates (after appropriate rescaling), both for the continuous limit (Leen and Orr, 1994) and discrete settings (Yuan et al., 2016). However, they do not explain why momentum SGD (including heavy-ball and Nesterov momentum) sometimes outperforms plain SGD in mini-batch training (as observed by Kidambi et al. (2018) and Shallue et al. (2018)).

3 Analysis of the Noisy Quadratic Model (NQM)

In this section, we work with a noisy quadratic model (NQM), a stochastic optimization problem whose dynamics can be simulated analytically, in order to reason about various phenomena encountered in training neural networks. In this highly simplified model, we first assume the loss function being optimized is a convex quadratic, with noisy observations of the gradient. For analytic tractability, we further assume the noise covariance is codiagonalizable with the Hessian (an assumption we later test for neural networks). Because we are not interested in modeling overfitting effects, we focus on the online training setting, where the observations are drawn i.i.d. in every training iteration. Under these assumptions, we derive an analytic expression for the risk after any number of steps of SGD with a fixed step size, as well as a dynamic programming method to compute the risk following a given step size schedule.

Convex quadratics may appear an odd model for a complicated nonconvex optimization landscape. However, one obtains a convex quadratic objective by linearizing the network’s function around a given weight vector and taking the second-order Taylor approximation to the loss function (assuming it is smooth and convex). Indeed, recent theoretical works (Jacot et al., 2018; Du et al., 2019; Zhang et al., 2019a) show that for wide enough networks, the weights stay close enough to the initialization for the linearized approximation to remain accurate. Empirically, linearized approximations closely match a variety of training phenomena for large but realistic networks (Lee et al., 2019).

3.1 Problem Setup

Figure 1: Cartoon of the evolution of risk for different coordinates with and without learning rate decay.

We now introduce the noisy quadratic model (Schaul et al., 2013; Martens, 2014; Wu et al., 2018), where the true function being optimized is a convex quadratic. Because we analyze rotation-invariant and translation-invariant optimizers such as SGD and heavy-ball momentum, we assume without loss of generality that the quadratic form is diagonal, and that the optimum is at the origin. Hence, our exact cost function decomposes as a sum of scalar quadratic functions for each coordinate:

$$\ell(\theta) = \frac{1}{2}\theta^\top H \theta = \sum_{i=1}^{d} \ell_i(\theta_i) = \sum_{i=1}^{d} \frac{1}{2} h_i \theta_i^2. \tag{1}$$

Without loss of generality, we assume $h_1 \ge h_2 \ge \cdots \ge h_d$. We consider a single gradient query to have the form $g(\theta) = H\theta + \epsilon$, where $\mathbb{E}[\epsilon] = 0$ and $\mathrm{Cov}(\epsilon) = C$. To reduce the variance of gradient estimation, we can average over multiple independent queries, which corresponds to "mini-batch training" in neural network optimization. We denote the averaged gradient as $g_B(\theta)$ and its covariance as $\mathrm{Cov}(g_B(\theta)) = C/B$, where $B$ is the number of queries (mini-batch size).

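For analytical tractability, we make the nontrivial assumption that $H$ and $C$ are codiagonalizable. (Since $H$ is diagonal, this implies that $C = \mathrm{diag}(c_1, \ldots, c_d)$.) See Section 3.4 for justification of this assumption. Under gradient descent with fixed step size $\alpha$, each dimension evolves independently as

$$\theta_i(t+1) = (1 - \alpha h_i)\,\theta_i(t) - \alpha \sqrt{c_i / B}\;\xi_i, \tag{2}$$

where $\alpha$ is the learning rate and $\xi_i$ is zero-mean unit variance iid noise. By treating $\theta_i$ as a random variable, we immediately obtain the dynamics of its mean and variance:

$$\mathbb{E}[\theta_i(t+1)] = (1 - \alpha h_i)\,\mathbb{E}[\theta_i(t)], \qquad \mathbb{V}[\theta_i(t+1)] = (1 - \alpha h_i)^2\,\mathbb{V}[\theta_i(t)] + \frac{\alpha^2 c_i}{B}. \tag{3}$$

Since $\mathbb{E}[\ell_i(\theta_i)] = \frac{1}{2} h_i \left(\mathbb{E}[\theta_i]^2 + \mathbb{V}[\theta_i]\right)$, eqn. (3) gives the expected risk after $t$ steps in a given dimension $i$:

$$\mathbb{E}[\ell_i(\theta_i(t))] = (1 - \alpha h_i)^{2t}\,\mathbb{E}[\ell_i(\theta_i(0))] + \left(1 - (1 - \alpha h_i)^{2t}\right)\frac{\alpha c_i}{2B(2 - \alpha h_i)}, \tag{4}$$

where we have assumed that $\alpha h_i < 2$. (Note that this can be seen as a special case of the convergence result derived for convex quadratics in Martens (2014).)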

Remarkably, each dimension converges exponentially to a steady state risk. Unfortunately, there is a trade-off: higher learning rates (up to $1/h_i$) give faster convergence to the steady state risk, but also produce higher values of the steady state risk. The steady state risk is also inversely proportional to the batch size; this is important to note because in the following subsections, we will show that traditional acceleration techniques (e.g., momentum and preconditioning) help improve the convergence rate at the expense of increasing the steady state risk. Therefore, the NQM implies that momentum and preconditioning would benefit more from large-batch training compared to plain SGD, as shown in later sections.
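The closed form in eqn. (4) can be checked against direct iteration of the mean and variance recursion in eqn. (3). The Python sketch below does this for a single coordinate; the function names and the particular values of $h$, $c$, $B$ and $\alpha$ are illustrative choices, not the authors' code. It also exposes the two trade-offs described above: the steady-state risk grows with the learning rate, and it halves each time the batch size doubles.

```python
# Sketch of the per-dimension NQM under plain SGD (eqns. (3) and (4)).
# Illustrative assumptions: curvature h, noise variance c, batch size B,
# learning rate alpha, and theta(0) with mean m0 and variance v0.


def sgd_moments(h, c, B, alpha, m0, v0, steps):
    """Iterate eqn. (3): mean and variance of theta after `steps` SGD updates."""
    m, v = m0, v0
    for _ in range(steps):
        m = (1 - alpha * h) * m
        v = (1 - alpha * h) ** 2 * v + alpha ** 2 * c / B
    return m, v


def risk_from_moments(h, m, v):
    """E[0.5 * h * theta^2] = 0.5 * h * (mean^2 + variance)."""
    return 0.5 * h * (m ** 2 + v)


def sgd_steady_state_risk(h, c, B, alpha):
    """Limit of eqn. (4) as t -> infinity; note the 1/B dependence."""
    return alpha * c / (2 * B * (2 - alpha * h))


def sgd_risk_closed_form(h, c, B, alpha, m0, v0, t):
    """Eqn. (4): decaying initial risk plus the approach to the steady state."""
    if not 0 < alpha * h < 2:
        raise ValueError("eqn. (4) assumes 0 < alpha * h < 2")
    decay = (1 - alpha * h) ** (2 * t)
    return (decay * risk_from_moments(h, m0, v0)
            + (1 - decay) * sgd_steady_state_risk(h, c, B, alpha))


if __name__ == "__main__":
    h, c, alpha = 0.5, 0.5, 0.3
    m, v = sgd_moments(h, c, 8, alpha, m0=1.0, v0=0.2, steps=50)
    print(risk_from_moments(h, m, v), sgd_risk_closed_form(h, c, 8, alpha, 1.0, 0.2, 50))
    for B in (1, 2, 4):  # the steady-state risk halves each time B doubles
        print(B, sgd_steady_state_risk(h, c, B, alpha))
```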

3.2 The Role of Momentum

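Applied to the same noisy quadratic model as before, the update equations for momentum SGD are:

$$v_i(t+1) = \beta\, v_i(t) - h_i \theta_i(t) - \sqrt{c_i/B}\;\xi_i, \qquad \theta_i(t+1) = \theta_i(t) + \alpha\, v_i(t+1), \tag{5}$$

where $v_i$ is the velocity and $\beta$ is the momentum decay parameter.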
We show in the following theorem (see Appendix C for proof) that momentum SGD performs similarly to plain SGD in the regime of small batch sizes but helps in the large-batch regime, which can be viewed as a more "deterministically behaving" optimization problem.

Theorem 1.

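Given a dimension index $i$, and $0 \le \beta < 1$ with $\beta \ne (1 - \sqrt{\alpha h_i})^2$, the expected risk at time $t$ associated with that dimension satisfies the upper bound

$$\mathbb{E}[\ell_i(\theta_i(t))] \le \left(\frac{(r_1^{t+1} - r_2^{t+1}) - \beta\,(r_1^{t} - r_2^{t})}{r_1 - r_2}\right)^2 \mathbb{E}[\ell_i(\theta_i(0))] + \frac{(1 + \beta)\,\alpha c_i}{2B\,(2\beta + 2 - \alpha h_i)(1 - \beta)}, \tag{6}$$

where $r_1$ and $r_2$ (with $r_1 \ge r_2$) are the two roots of the quadratic equation $x^2 - (1 - \alpha h_i + \beta)\,x + \beta = 0$.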

As with plain SGD (cf. eqn. (4)), the loss associated with each dimension can be expressed as the sum of two terms, where the first one decays exponentially and corresponds to the behavior of the deterministic version of the algorithm, and the second remains constant.

Following the existing treatment of the deterministic version of the algorithm (Chiang, 1974; Qian, 1999; Yang et al., 2018; Goh, 2017), we divide our analysis into two cases: overdamping and underdamping. In the case of overdamping, where $\beta \le (1 - \sqrt{\alpha h_i})^2$, both roots $r_1$ and $r_2$ are real and therefore the convergence rate is determined by the larger one (i.e. $r_1$), which has the value

$$r_1 = \frac{1 - \alpha h_i + \beta + \sqrt{(1 - \alpha h_i + \beta)^2 - 4\beta}}{2}. \tag{7}$$
With a fixed learning rate, the steady state risk will be constant, and the best achievable expected risk will be lower bounded by it. Thus, to achieve a certain target loss we must either drive the learning rate down, or the batch size up. Assuming a small batch size and a low target risk, we are forced to pick a small learning rate, in which case one can show that $r_1 \approx 1 - \frac{\alpha h_i}{1 - \beta}$. (To see this, note that the term in the square root of eqn. (7) can be written as $(1 - \beta)^2 - 2\alpha h_i (1 + \beta) + \alpha^2 h_i^2$. Dropping the $\alpha^2 h_i^2$ term and expanding the square root to first order in $\alpha h_i$ gives the claimed expression for $r_1$.) In Figure 2 we plot the convergence rate as a function of $\beta$, and we indeed observe that the convergence rate closely matches $1 - \frac{\alpha h_i}{1 - \beta}$, assuming a relatively small learning rate. We further note that, to first order in $\alpha$, the convergence rate and steady state risk of eqn. (6) are the same as the ones in plain SGD (eqn. (4)), except that they use an "effective learning rate" of $\alpha / (1 - \beta)$. To help validate these predictions, in Appendix D.3 we provide a comparison of momentum SGD with plain SGD using the effective learning rate.

Figure 2: Convergence rate and steady state risk (SSR) as a function of momentum $\beta$ for a single dimension, at a fixed learning rate and batch size.

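In the case of underdamping, where $\beta > (1 - \sqrt{\alpha h_i})^2$, both $r_1$ and $r_2$ will be complex and have norm $\sqrt{\beta}$. We note that the optimal $\beta$ should be equal to or smaller than $(1 - \sqrt{\alpha h_d})^2$ (because $(1 - \sqrt{\alpha h})^2$ decreases with $h$, this is the largest of the per-dimension thresholds), since otherwise all dimensions are under-damped, and we can easily improve the convergence rate and steady state risk by reducing $\beta$.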

Next we observe that the convergence of the total loss will eventually be dominated by the slowest converging dimension (which corresponds to the smallest curvature $h_d$), and this will be in the overdamping regime as argued above. By our analysis of the overdamping case, we can achieve the same convergence rate for this dimension by simply replacing the learning rate in the bound for plain SGD (eqn. (4)) with the effective learning rate $\alpha / (1 - \beta)$.

So while momentum gives no long-term training acceleration for very low fixed learning rates (which we are forced to use when the batch size is small), we note that it can help in large-batch training. For a fixed learning rate, using momentum $\beta$ roughly amplifies the steady state risk by a factor of $1/(1 - \beta)$, and we note that the steady state risk is also inversely proportional to the batch size. Therefore, we expect momentum SGD to exhibit perfect scaling up to larger batch sizes than plain SGD.
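A small numerical sketch can make the momentum analysis concrete. It assumes the heavy-ball convention $v(t+1) = \beta v(t) - h\theta(t) - \sqrt{c/B}\,\xi$, $\theta(t+1) = \theta(t) + \alpha v(t+1)$ (our choice; conventions with the same characteristic equation behave identically), computes $r_1$ and $r_2$ from Theorem 1, compares $r_1$ with the small-learning-rate approximation $1 - \alpha h/(1-\beta)$, and checks the steady-state term of the bound by iterating the exact second moments starting from the optimum. All names and numbers are hypothetical.

```python
import math

# Sketch for Section 3.2. Assumed update convention (heavy-ball, one dimension):
#   v(t+1) = beta * v(t) - h * theta(t) - sqrt(c / B) * xi
#   theta(t+1) = theta(t) + alpha * v(t+1)


def momentum_roots(h, alpha, beta):
    """Roots r1, r2 of x^2 - (1 - alpha*h + beta) x + beta = 0 (Theorem 1)."""
    b = 1 - alpha * h + beta
    disc = b * b - 4 * beta
    if disc >= 0:  # overdamped: two real roots with r1 >= r2
        s = math.sqrt(disc)
        return (b + s) / 2, (b - s) / 2
    s = math.sqrt(-disc)  # underdamped: complex pair with modulus sqrt(beta)
    return complex(b / 2, s / 2), complex(b / 2, -s / 2)


def momentum_steady_state_risk(h, c, B, alpha, beta):
    """Second (constant) term of the bound in Theorem 1."""
    return (1 + beta) * alpha * c / (2 * B * (2 * beta + 2 - alpha * h) * (1 - beta))


def iterated_steady_state_risk(h, c, B, alpha, beta, steps=20000):
    """Iterate E[theta^2], E[theta v], E[v^2] from the optimum (only noise
    contributes) and return 0.5 * h * E[theta^2] after `steps` updates."""
    a, q = 1 - alpha * h, c / B
    tt = tv = vv = 0.0
    for _ in range(steps):
        tt, tv, vv = (
            a * a * tt + 2 * a * alpha * beta * tv + (alpha * beta) ** 2 * vv + alpha ** 2 * q,
            -a * h * tt + beta * (1 - 2 * alpha * h) * tv + alpha * beta ** 2 * vv + alpha * q,
            h * h * tt - 2 * h * beta * tv + beta ** 2 * vv + q,
        )
    return 0.5 * h * tt


if __name__ == "__main__":
    r1, _ = momentum_roots(1.0, 1e-4, 0.5)
    print(r1, 1 - 1e-4 / (1 - 0.5))  # exact r1 vs. small-alpha approximation
    print(momentum_steady_state_risk(1.0, 1.0, 2, 0.05, 0.6),
          iterated_steady_state_risk(1.0, 1.0, 2, 0.05, 0.6))
```

With $\alpha = 10^{-3}$ and $\beta = 0.9$, the momentum steady-state risk in this sketch lies within about 1% of plain SGD run at the effective learning rate $\alpha/(1-\beta) = 0.01$, which is the sense in which momentum and learning rate are interchangeable at small learning rates.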

3.3 The Role of Preconditioning

Many optimizers, such as Adam and K-FAC, can be viewed as preconditioned gradient descent methods. In each update, the gradient is rescaled by a PSD matrix $P$, called the preconditioner:

$$\theta(t+1) = \theta(t) - \alpha P\, g_B(\theta(t)).$$

In lieu of trying to construct noisy quadratic analogues of particular optimizers, we analyze preconditioners of the form $P = H^{-p}$ with $0 \le p \le 1$. Note that $P$ remains fixed throughout training since the Hessian is constant in the NQM. We can recover standard SGD by setting $p = 0$.

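Conveniently, for our NQM, the dynamics of preconditioned SGD are equivalent to the SGD dynamics in an NQM with Hessian $\tilde{H} = P^{1/2} H P^{1/2} = H^{1-p}$ and gradient covariance $\tilde{C} = P^{1/2} C P^{1/2} = H^{-p} C$ (this follows from the change of variables $\tilde{\theta} = P^{-1/2}\theta$, which leaves the risk unchanged). Hence, the dynamics can be simulated using eqn. (4), exactly like the non-preconditioned case. We immediately obtain the following bound on the risk:

$$\mathbb{E}[\ell(\theta(t))] \le \sum_{i=1}^{d} (1 - \alpha h_i^{1-p})^{2t}\, \mathbb{E}[\ell_i(\theta_i(0))] + \sum_{i=1}^{d} \frac{\alpha\, c_i\, h_i^{-p}}{2B\,(2 - \alpha h_i^{1-p})}. \tag{8}$$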
To qualitatively understand the effect of preconditioning, first consider the first term in eqn. (8). The convergence of this term resembles that of gradient descent on a deterministic quadratic, which (with optimal $\alpha$) contracts by a factor of approximately $1 - 2/\tilde{\kappa}$ per step, where $\tilde{\kappa} = \tilde{h}_1 / \tilde{h}_d$ is the condition number of the transformed problem. Since $\tilde{\kappa} = \kappa^{1-p}$, this implies a factor of $\kappa^p$ improvement in the rate of convergence. Hence, for near-deterministic objectives where the first term dominates, values of $p$ closer to 1 correspond to better preconditioners, and result in much faster convergence. Unfortunately, there is no free lunch, as larger values of $p$ will also increase the second term (steady state risk). Assuming an ill-conditioned loss surface ($\kappa \gg 1$), so that the optimal learning rate is $\alpha \approx 2/\tilde{h}_1 = 2/h_1^{1-p}$, the steady state risk of each dimension with $h_i < h_1$ becomes

$$\frac{\alpha\, c_i\, h_i^{-p}}{2B\,(2 - \alpha h_i^{1-p})} \approx \frac{c_i}{2B\,h_1} \cdot \frac{(h_i/h_1)^{-p}}{1 - (h_i/h_1)^{1-p}}, \tag{9}$$

which is a monotonically increasing function with respect to $p$. Even without this amplification effect, the steady state risk will eventually become the limiting factor in the minimization of the expected risk. One way to reduce the steady state risk, apart from using Polyak averaging (Polyak and Juditsky, 1992) or decreasing the learning rate (which will harm the rate of convergence), is to increase the batch size. This suggests that the benefits of using stronger preconditioners will be more clearly observed for larger batch sizes, which is an effect that we empirically demonstrate in later sections.
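The trade-off between the two terms of eqn. (8) can be tabulated directly. This sketch assumes $h_i = 1/i$ and $C = H$ as in Section 3.4 and, for simplicity, a fixed learning rate $\alpha = 1/\tilde{h}_1$ (which equals 1 here because $h_1 = 1$) instead of a tuned one; these are illustrative assumptions. As $p$ grows, the slowest per-step contraction factor shrinks while the total steady-state risk grows, and the latter remains inversely proportional to $B$.

```python
# Sketch for Section 3.3 under illustrative assumptions: h_i = 1/i, C = H,
# P = H^{-p}, and a fixed learning rate alpha = 1 / h_1^{1-p} = 1 (since h_1 = 1).


def preconditioned_terms(h, c, B, alpha, p):
    """Per-dimension contraction factor and steady-state risk from eqn. (8)."""
    h_tilde = h ** (1 - p)    # transformed curvature, from H^{1-p}
    c_tilde = c * h ** (-p)   # transformed noise variance, from H^{-p} C
    contraction = (1 - alpha * h_tilde) ** 2
    steady = alpha * c_tilde / (2 * B * (2 - alpha * h_tilde))
    return contraction, steady


def summarize(d, B, p):
    """Slowest per-step contraction and total steady-state risk over d dimensions."""
    alpha = 1.0  # assumed: 1 / h_1^{1-p} with h_1 = 1
    terms = [preconditioned_terms(1.0 / i, 1.0 / i, B, alpha, p) for i in range(1, d + 1)]
    slowest = max(contraction for contraction, _ in terms)
    total_steady = sum(steady for _, steady in terms)
    return slowest, total_steady


if __name__ == "__main__":
    for p in (0.0, 0.25, 0.5, 0.75):
        print(p, summarize(d=1000, B=1, p=p))
```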

3.4 Choice of $H$ and $C$

We’ve found that the qualitative behavior of optimizers in our NQM depends on the choices of $H$ and $C$. Therefore, we choose matrices motivated by theoretical and empirical considerations about neural net training. First, we set the diagonal entries of $H$ to be $\{1/i\}_{i=1}^{d}$ for some integer $d$, giving a condition number of $d$. This closely matches the estimated eigenspectrum of the Hessian of a convolutional network (see Figure 7 and Appendix D.4), and is also consistent with recent work finding heavy tailed eigenspectra of neural network Hessians (Ubaru et al., 2017; Ghorbani et al., 2019). We choose $d$ so that this condition number approximately matches the condition number of the K-FAC Hessian approximation for ResNet8. (Qualitative behaviors were consistent for a wide range of $d$.)

We also set $C = H$ (a nontrivial assumption). This was motivated by theoretical arguments that, under the assumption that the implicit conditional distribution over the network’s output is close to the conditional distribution of targets from the training distribution, the Hessian closely matches the gradient covariance in neural network training (Martens, 2014). Empirically, this relationship appears to hold tightly for a convolutional network and modestly well for a transformer (see Appendix D.2).

3.5 Information Theoretic Lower Bound

Since our NQM assumes the infinite data (online optimization) setting, it’s instructive to compare the performance of optimizers against an information theoretic lower bound. Specifically, under the assumption that $C = H$, the NQM is equivalent to maximum likelihood estimation of the mean vector for a multivariate Gaussian distribution with covariance $H^{-1}$: for an example $x \sim \mathcal{N}(\mu, H^{-1})$, the gradient of the negative log-likelihood, $H(\theta - x)$, has mean $H(\theta - \mu)$ and covariance $H$. Hence, the risk obtained by any optimizer can be bounded below by the risk of the maximum likelihood estimator for the Gaussian, which is $d/(2N)$, where $d$ is the dimension and $N$ is the total number of training examples visited. We indicate this bound with a dashed black line in our plots.
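The bound is easy to confirm by simulation. In this sketch (an illustration under assumed values, not the paper's code) examples are drawn from $\mathcal{N}(0, H^{-1})$ with a diagonal $H$, the maximum likelihood estimator is the sample mean, and each coordinate contributes $1/(2N)$ in expectation regardless of its curvature $h_i$, which is why the bound depends only on $d$ and $N$.

```python
import random

# Monte Carlo sketch of the Section 3.5 bound. Illustrative assumptions: diagonal H
# given as a list of curvatures, optimum mu = 0, and examples x ~ N(0, H^{-1}).


def mle_risk_monte_carlo(h, n_examples, trials, seed=0):
    """Average risk 0.5 * sum_i h_i * (mean_i - mu_i)^2 of the sample-mean MLE."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        for hi in h:
            std = (1.0 / hi) ** 0.5  # per-coordinate std under covariance H^{-1}
            mean = sum(rng.gauss(0.0, std) for _ in range(n_examples)) / n_examples
            total += 0.5 * hi * mean ** 2
    return total / trials


def information_theoretic_bound(d, n_examples):
    """Risk of the Gaussian-mean MLE after N examples: d / (2N)."""
    return d / (2 * n_examples)


if __name__ == "__main__":
    h = [1.0, 0.5, 0.25, 0.1, 0.01]
    print(mle_risk_monte_carlo(h, n_examples=20, trials=4000),
          information_theoretic_bound(len(h), 20))
```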

3.6 Noisy Quadratic Experiments

In this section, we simulate noisy quadratic optimization using the closed-form dynamics. Our aim is to formulate hypotheses for how different optimizers would behave for neural network optimization. Our main metric is the number of steps required to achieve a target risk. For efficiency, rather than explicitly representing all the eigenvalues of $H$, we quantize them into 100 bins and count the number of eigenvalues in each bin. Unless otherwise specified, we initialize $\theta$ randomly and use a target risk of 0.01. (The results don’t seem to be sensitive to either the initial variance or the target risk; some results with varying target risk thresholds are shown in Appendix D.5.)
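The steps-to-target metric can be sketched in a few lines for plain SGD. This is a minimal sketch under our own assumptions: $h_i = 1/i$ with $d = 1000$ (kept small for speed), $C = H$, $\theta(0) \sim \mathcal{N}(0, I)$, log-spaced bins built by a hypothetical `binned_spectrum` helper, and a grid of constant learning rates capped at $1/h_1$ so that the risk is non-increasing in $t$ and can be binary-searched. It reproduces the two regimes qualitatively: doubling $B$ roughly halves the steps at small batch sizes, while at very large batch sizes the step count stops changing.

```python
import math

# Steps-to-target sketch (plain SGD). Assumptions (ours): h_i = 1/i, C = H,
# theta(0) ~ N(0, I) so E[l_i(theta_i(0))] = h_i / 2, log-spaced bins, and
# learning rates alpha <= 1 / h_1 = 1, which keeps every term of eqn. (4)
# non-increasing in t (required for the binary search below).


def binned_spectrum(d, n_bins=100):
    """Group h_i = 1/i into log-spaced bins; return (mean curvature, count) pairs."""
    sums, counts = [0.0] * n_bins, [0] * n_bins
    for i in range(1, d + 1):
        k = min(int(n_bins * math.log(i) / math.log(d + 1)), n_bins - 1)
        sums[k] += 1.0 / i
        counts[k] += 1
    return [(s / n, n) for s, n in zip(sums, counts) if n > 0]


def sgd_risk(t, spectrum, B, alpha):
    """Eqn. (4) summed over bins, with c_i = h_i."""
    total = 0.0
    for h, n in spectrum:
        decay = (1 - alpha * h) ** (2 * t)
        steady = alpha * h / (2 * B * (2 - alpha * h))
        total += n * (decay * 0.5 * h + (1 - decay) * steady)
    return total


def steps_to_target(spectrum, B, alpha, target, t_max=2 ** 50):
    """Smallest integer t with sgd_risk(t) <= target, or math.inf if unreachable."""
    floor = sum(n * alpha * h / (2 * B * (2 - alpha * h)) for h, n in spectrum)
    if floor >= target:
        return math.inf  # the steady-state risk alone exceeds the target
    hi = 1
    while sgd_risk(hi, spectrum, B, alpha) > target:  # exponential search
        hi *= 2
        if hi > t_max:
            return math.inf
    lo = hi // 2
    while hi - lo > 1:  # binary search for the first t meeting the target
        mid = (lo + hi) // 2
        if sgd_risk(mid, spectrum, B, alpha) <= target:
            hi = mid
        else:
            lo = mid
    return hi


def best_steps(spectrum, B, target, alphas):
    """Tune a constant learning rate on a grid; return (steps, best alpha)."""
    return min((steps_to_target(spectrum, B, a, target), a) for a in alphas)


if __name__ == "__main__":
    spectrum = binned_spectrum(1000)
    alphas = [2.0 ** (-k / 2) for k in range(33)]  # 1 down to 2^-16
    for B in (1, 2, 4, 2 ** 20, 2 ** 22):
        print(B, best_steps(spectrum, B, 0.01, alphas))
```

Preconditioned variants ($p > 0$) would need a different search: with $P = H^{-p}$ and this initialization, the steady-state risk of low-curvature directions can exceed their initial risk, so the total risk need not decrease monotonically.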

3.6.1 Effect of Momentum and Preconditioning

We first experiment with momentum and varying preconditioner powers on our NQM. We treat both the (fixed) learning rate $\alpha$ and momentum decay parameter $\beta$ as hyperparameters, which we tune using a fine-grained grid search.

Consistent with the empirical results of Shallue et al. (2018), each optimizer shows two distinct regimes: a small-batch (stochastic) regime with perfect linear scaling, and a large-batch (deterministic) regime insensitive to batch size. We call the phase transition between these regimes the critical batch size. Consistent with the analysis of Section 3.2 and the observations of Smith et al. (2018); Shallue et al. (2018); Kidambi et al. (2018), the performance of momentum-based optimizers matches that of the plain SGD methods in the small-batch regime, but momentum increases the critical batch size and gives substantial speedups in the large batch regime. Preconditioning also increases the critical batch size and gives substantial speedups in the large batch regime, but interestingly, also improves performance by a small constant factor even for very small batches. Combining momentum with preconditioning extends both of these trends.

(a) Scaling with Constant LR
(b) Scaling with LR Schedules
(c) Optimized LR Schedules
Figure 3: (a) Effects of momentum and preconditioning. Steps required to reach target loss as a function of batch size under different preconditioning powers. Solid lines are momentum SGD while dashed lines are plain SGD. The black dashed line is the information theoretic lower bound. (b) Effect of learning rate decay. The solid lines use the optimized piecewise constant scheme, which are shown in (c) for a single preconditioning power $p$. The dashed curves in (b) are plain SGD for comparison. We observe that learning rate schedules close most of the gap between the fixed learning rate performance and the information theoretic lower bound.

3.6.2 Optimal Learning Rate and Decay Scheme

In the NQM, we can calculate the optimal constant learning rate given a specific batch size. Figure 12 shows the optimal learning rate as a function of batch size for a fixed target risk. Notably, the optimal learning rate of plain (preconditioned) SGD (Figure 12(a)) scales linearly with batch size before it hits the critical batch size, matching the scheme used in Goyal et al. (2017). The linear scaling also holds for the effective learning rate $\alpha/(1-\beta)$ of momentum SGD. In the small batch regime, the optimal effective learning rate for momentum SGD matches the optimal plain SGD learning rate, suggesting that the momentum and learning rate are interchangeable in the small batch regime.

While a fixed learning rate often works well for simple problems, good performance on the ImageNet benchmark (Russakovsky et al., 2015) requires a carefully tuned schedule. Here we explicitly optimize a piecewise constant learning rate schedule for SGD (with 50 pieces), in terms of the number of steps to reach the loss threshold. (For a given schedule and number of time steps, we obtain the exact risk using dynamic programming with eqn. (3). For stability, the learning rates are constrained to stay below an upper bound. For a fixed number of time steps, we minimize this risk using BFGS. We determine the optimal number of time steps using binary search.) In Figure 3(b), we show that optimized learning rate schedules help significantly in the small batch regime, consistent with the analysis in Wu et al. (2018). We observe the same linear scaling as with fixed-learning-rate SGD, but with a better constant factor. In fact, optimized schedules nearly achieve the information theoretic optimum. However, learning rate schedules do not improve at all over fixed learning rates in the large batch regime. Figure 3(c) shows optimized schedules for different batch sizes; interestingly, they maintain a large learning rate throughout training followed by a roughly exponential decay, consistent with commonly used neural network training schedules. Additionally, even though the different batch sizes start with the same learning rate, their final learning rates at the end of training scale linearly with batch size (see Figure 13 in Appendix D.7).
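Because each piece of a piecewise-constant schedule applies eqn. (3) with a fixed $\alpha$, the mean and variance can be advanced through a whole piece in closed form, which is one way to evaluate the risk of a schedule exactly. The sketch below (hypothetical names and a toy three-dimensional problem with $C = H$ and $\theta(0) = 1$) checks this against step-by-step iteration and shows that decaying the learning rate for the second half of training reaches a much lower risk than keeping it constant at the same batch size. The BFGS optimization over schedules and the binary search over the number of steps are not included.

```python
# Exact risk of a piecewise-constant schedule by propagating eqn. (3) piece by piece.
# Assumed inputs (illustrative): dims is a list of (h, c, mean0, var0) tuples and
# schedule is a list of (alpha, n_steps) pieces.


def propagate_piece(h, c, B, alpha, mean, var, n_steps):
    """Apply n_steps of eqn. (3) with a constant alpha, in closed form."""
    a2 = (1 - alpha * h) ** 2
    q = alpha ** 2 * c / B
    mean = (1 - alpha * h) ** n_steps * mean
    if a2 == 1.0:  # degenerate case alpha * h in {0, 2}
        var = var + n_steps * q
    else:  # geometric sum of the injected noise variance
        var = a2 ** n_steps * var + q * (1 - a2 ** n_steps) / (1 - a2)
    return mean, var


def schedule_risk(schedule, dims, B):
    """Expected risk after running the whole schedule."""
    total = 0.0
    for h, c, mean, var in dims:
        for alpha, n_steps in schedule:
            mean, var = propagate_piece(h, c, B, alpha, mean, var, n_steps)
        total += 0.5 * h * (mean ** 2 + var)
    return total


def schedule_risk_stepwise(schedule, dims, B):
    """Reference implementation: apply eqn. (3) one step at a time."""
    total = 0.0
    for h, c, mean, var in dims:
        for alpha, n_steps in schedule:
            for _ in range(n_steps):
                mean = (1 - alpha * h) * mean
                var = (1 - alpha * h) ** 2 * var + alpha ** 2 * c / B
        total += 0.5 * h * (mean ** 2 + var)
    return total


if __name__ == "__main__":
    dims = [(1.0, 1.0, 1.0, 0.0), (0.1, 0.1, 1.0, 0.0), (0.01, 0.01, 1.0, 0.0)]
    constant = [(0.5, 200)]
    decayed = [(0.5, 100), (0.05, 100)]
    print(schedule_risk(constant, dims, B=1), schedule_risk(decayed, dims, B=1))
```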

Data Set | Size | Model | Remarks | LR
MNIST | 55,000 | Simple CNN | Same as Shallue et al. (2018) except without dropout regularization. | Constant
FMNIST | 55,000 | Simple CNN | Same as Shallue et al. (2018) except without dropout regularization. | Constant
CIFAR10 | 45,000 | ResNet8 without BN | Same as Shallue et al. (2018). | Constant
CIFAR10 | 45,000 | ResNet32 with BN | Ghost batch norm is used. | Linear Decay
CIFAR10 | 45,000 | VGG11 with BN | Ghost batch norm is used. | Linear Decay
LM1B | 30M | Two-layer Transformer | Shallow model in Shallue et al. (2018). | Constant
Table 1: Data sets and models used in our experiments. See Appendix E.2 for full details.
(a) Simple CNN on MNIST
(b) Simple CNN on Fashion MNIST
(c) ResNet8 on CIFAR10
(d) VGG11 on CIFAR10
(e) ResNet32 on CIFAR10
(f) Transformer on LM1B
Figure 4: Empirical relationship between batch size and steps to result. Key observations: 1) momentum SGD has no benefit over plain SGD at small batch sizes, but extends the perfect scaling to larger batch sizes; 2) preconditioning also extends perfect scaling to larger batch sizes, i.e. K-FAC > Adam > momentum SGD. This is most noticeable in the Transformer model; 3) preconditioning (particularly K-FAC) reduces the number of steps needed to reach the target even for small batch sizes. All of these agree with the predictions of the NQM.

4 Neural Network Experiments

We investigated whether the predictions made by the NQM hold in practice by running experiments with five neural network architectures across three image classification tasks and one language modeling task (see Table 1). For each model and task, we compared a range of optimizers: SGD, momentum SGD, Adam (with and without momentum), and K-FAC (with and without momentum). See Appendix E for more details about our models and tasks.

The primary quantity we measured is the number of steps required to reach a target accuracy (for image classification tasks) or cross entropy (for language modeling). Unless otherwise specified, we measured steps to target on the validation set. We chose the target metric values based on an initial set of experiments with practical computational budgets. For each model, task, optimizer, and batch size, we independently tuned the learning rate $\alpha$, the parameters governing the learning rate schedule (where applicable), and optimizer-specific metaparameters (see Appendix E.4). We manually chose the search spaces based on our initial experiments, and we verified after each experiment that the optimal metaparameter values were far from the search space boundaries. We used quasi-random search (Bousquet et al., 2017) to tune the metaparameters with fixed budgets of non-divergent trials (100 for Simple CNN, ResNet8, and Transformer, and 200 for ResNet32 and VGG11); we discarded trials with a divergent training loss, which occurred when the learning rate was too high. We chose the trial that reached the target metric value in the fewest steps.

4.1 Critical Batch Size Depends on the Optimizer

Figure 4 shows the relationship between batch size and steps to target for each model, task, and optimizer. In each case, as the batch size grows, there is an initial period of perfect scaling where doubling the batch size halves the steps to target, but once the batch size exceeds a problem-dependent critical batch size, there are rapidly diminishing returns, matching the results of Goyal et al. (2017), McCandlish et al. (2018), and Shallue et al. (2018). K-FAC has the largest critical batch size in all cases, highlighting the usefulness of preconditioning. Momentum SGD extends perfect scaling to larger batch sizes than plain SGD, but for batch sizes smaller than the plain SGD critical batch size, momentum SGD requires as many steps as plain SGD to reach the target. This is consistent with both the empirical results of Shallue et al. (2018) and our NQM simulations. By contrast, Adam and K-FAC can reduce the number of steps needed to reach the target compared to plain SGD even for the smallest batch sizes, although neither optimizer does so in all cases. Finally, we see some evidence that the benefit of momentum diminishes with preconditioning (Figures 4(a) and 4(b)), as predicted by our NQM simulations, although we do not see this in all cases (e.g. Figures 4(c) and 4(f)).

4.2 Optimal Learning Rate

Figure 5: Optimal learning rates for plain SGD and momentum SGD. Left: Simple CNN on MNIST; Right: ResNet8 on CIFAR10

The NQM predicts that the optimal constant learning rate for plain SGD (or effective learning rate for momentum SGD) scales linearly with batch size initially, and then levels off after a certain batch size. Figure 5 shows the empirical optimal (effective) learning rate as a function of batch size for simple CNN on MNIST and ResNet8 on CIFAR10. For small batch sizes, the optimal learning rate of plain SGD appears to match the optimal effective learning rate of momentum SGD. However, after a certain batch size, the optimal learning rate for plain SGD saturates while the optimal effective learning rate of momentum SGD keeps increasing. Interestingly, plain SGD and momentum SGD appear to deviate at the same batch size in the optimal effective learning rate and steps to target plots (Figures 4 and 5).

4.3 Steps to Target on the Training Set

Figure 6: Steps to training accuracy versus batch size on CIFAR10. Left: ResNet8; Right: ResNet32.

Figure 6 shows the empirical relationship between batch size and steps to target, measured on the training set, for ResNet8 and ResNet32 on CIFAR10. For ResNet8, the curves are almost identical to those using validation accuracy (Figure 4(c)), but for ResNet32, the gaps between different optimizers become much smaller than in Figure 4(e) and the effects of momentum and preconditioning appear to become less significant. Nevertheless, the qualitative differences between optimizers are consistent with the validation set measurements.

5 Conclusion

In this work, we analyzed the interactions between the batch size and the optimization algorithm from two perspectives: experiments with real neural networks, and a noisy quadratic model with parameters chosen based on empirical observations about neural networks. Despite its simplicity, the noisy quadratic model agrees remarkably well with a variety of neural network training phenomena, including learning rate scaling, critical batch sizes, and the effects of momentum and preconditioning. More importantly, the noisy quadratic model allows us to run experiments in seconds, while it can take weeks, or even months, to conduct careful large-scale experiments with real neural networks. Therefore, the noisy quadratic model is a convenient and powerful way to quickly formulate testable predictions about neural network optimization.


Appendix A Kronecker-factored Approximate Curvature (K-FAC)

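Kronecker-factored approximate curvature (K-FAC) [Martens and Grosse, 2015] uses a Kronecker-factored approximation to the curvature matrix to perform efficient approximate natural gradient updates. Considering the $\ell$-th layer in a neural network whose input activations are $a_\ell$, weight matrix $W_\ell$, and outputs $s_\ell$, we have $s_\ell = W_\ell a_\ell$. Therefore, the weight gradient is $\nabla_{W_\ell}\mathcal{L} = (\nabla_{s_\ell}\mathcal{L})\, a_\ell^\top$. Writing $g_\ell = \nabla_{s_\ell}\mathcal{L}$, K-FAC decouples this layer's Fisher matrix $F_\ell$ using an independence assumption:

$$F_\ell = \mathbb{E}\left[\mathrm{vec}\{\nabla_{W_\ell}\mathcal{L}\}\,\mathrm{vec}\{\nabla_{W_\ell}\mathcal{L}\}^\top\right] = \mathbb{E}\left[a_\ell a_\ell^\top \otimes g_\ell g_\ell^\top\right] \approx \mathbb{E}\left[a_\ell a_\ell^\top\right] \otimes \mathbb{E}\left[g_\ell g_\ell^\top\right] = A_\ell \otimes G_\ell,$$

where $A_\ell = \mathbb{E}[a_\ell a_\ell^\top]$ and $G_\ell = \mathbb{E}[g_\ell g_\ell^\top]$. Decomposing $F_\ell$ into $A_\ell$ and $G_\ell$ not only avoids the quadratic storage cost of the exact Fisher, but also enables tractable computation of the approximate natural gradient:

$$F_\ell^{-1}\,\mathrm{vec}\{\nabla_{W_\ell}\mathcal{L}\} \approx \left(A_\ell^{-1} \otimes G_\ell^{-1}\right)\mathrm{vec}\{\nabla_{W_\ell}\mathcal{L}\} = \mathrm{vec}\left\{G_\ell^{-1}\,(\nabla_{W_\ell}\mathcal{L})\,A_\ell^{-1}\right\}. \tag{12}$$

As shown by eqn. (12), computing the natural gradient using K-FAC only involves matrix operations comparable in size to $W_\ell$, making it very efficient.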
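Eqn. (12) can be verified numerically for a small fully connected layer. This NumPy sketch uses random stand-ins for the activations and back-propagated gradients, column-stacking `vec`, and no damping (practical K-FAC implementations add damping); it checks that applying $G^{-1}$ and $A^{-1}$ on either side of the gradient matches solving against the full Kronecker product, whose size is $(d_{in} d_{out})^2$.

```python
import numpy as np

# Sketch of eqn. (12) for one fully connected layer; the data are synthetic
# stand-ins for real activations a and back-propagated output gradients g.


def vec(M):
    """Column-stacking vectorization, under which vec(g a^T) = kron(a, g)."""
    return M.reshape(-1, order="F")


def kfac_natural_gradient(grad_W, A, G):
    """Approximate natural gradient G^{-1} grad_W A^{-1} (no damping, for clarity)."""
    return np.linalg.solve(G, grad_W) @ np.linalg.inv(A)


rng = np.random.default_rng(0)
n, d_in, d_out = 2000, 4, 3
a = rng.standard_normal((n, d_in))   # input activations, one row per example
g = rng.standard_normal((n, d_out))  # output gradients, one row per example
A = a.T @ a / n                      # input Kronecker factor E[a a^T]
G = g.T @ g / n                      # output Kronecker factor E[g g^T]
grad_W = g.T @ a / n                 # mean of the per-example gradients g a^T

# The Kronecker identity behind the factorization, checked on one example.
assert np.allclose(vec(np.outer(g[0], a[0])), np.kron(a[0], g[0]))

# Solving with the dense Kronecker product gives the same direction.
dense = np.linalg.solve(np.kron(A, G), vec(grad_W))
print(np.allclose(dense, vec(kfac_natural_gradient(grad_W, A, G))))
```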

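Later, Grosse and Martens [2016] further extended K-FAC to convolutional layers under additional assumptions of spatial homogeneity (SH) and spatially uncorrelated derivatives (SUD). Suppose the input is $a$ and the output is $s$. Then the gradient of the reshaped weight $W$ is $\nabla_W \mathcal{L} = \sum_{t \in \mathcal{T}} g_t\, a_t^\top$, and the corresponding Fisher matrix is:

$$F = \mathbb{E}\left[\sum_{t, t' \in \mathcal{T}} a_t a_{t'}^\top \otimes g_t g_{t'}^\top\right] \approx \sum_{t \in \mathcal{T}} \mathbb{E}\left[a_t a_t^\top\right] \otimes \mathbb{E}\left[g_t g_t^\top\right] \approx |\mathcal{T}|\; \mathbb{E}\left[a_t a_t^\top\right] \otimes \mathbb{E}\left[g_t g_t^\top\right],$$

where $\mathcal{T}$ is the set of spatial locations, $a_t$ is the patch extracted from $a$, and $g_t$ is the gradient to each spatial location in $s$. The first approximation combines the independence assumption above with SUD ($\mathbb{E}[g_t g_{t'}^\top] = 0$ for $t \ne t'$), and the second uses SH.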

A.1 K-FAC for Transformer

K-FAC has previously been applied to autoencoders [Martens and Grosse, 2015] and various convolutional networks [Grosse and Martens, 2016, Ba et al., 2017]. To our knowledge, this is the first time K-FAC has been implemented for the Transformer model. What distinguishes the Transformer from these previous models is the weight matrix shared between the embedding layer and the pre-softmax linear transformation [Vaswani et al., 2017]. In particular, the weight matrix is transposed at the pre-softmax layer:

$$
\mathbf{s} = \mathbf{W}\mathbf{a} \qquad\text{and}\qquad \nabla_{\mathbf{W}}\mathcal{L} = (\nabla_{\mathbf{s}}\mathcal{L})\,\mathbf{a}^\top.
$$

With the same assumptions as in the non-transposed case, we get

$$
\mathbf{F} \approx \mathbb{E}\big[\mathbf{a}\mathbf{a}^\top\big] \otimes \mathbb{E}\big[\nabla_{\mathbf{s}}\mathcal{L}(\nabla_{\mathbf{s}}\mathcal{L})^\top\big] = \mathbf{A} \otimes \mathbf{G},
$$

i.e. the positions of the two Kronecker factors are swapped. If we name the two Kronecker factors $\mathbf{A}$ and $\mathbf{G}$ of a non-transposed layer the "input factor" and "output factor", respectively, i.e. $\mathbf{F} \approx \mathrm{output\_factor} \otimes \mathrm{input\_factor}$, then for the weight matrix shared between the embedding layer and the pre-softmax layer, the input_factor has contributions from both the embedding inputs and the gradients of the pre-softmax layer outputs, and the output_factor has contributions from both the pre-softmax layer inputs and the gradients of the embedding outputs. In practice, when computing a Kronecker factor, we treat contributions from multiple sources in the same way as contributions from multiple training examples in a mini-batch. Also note that because of the high dimensionality of the embedding weight matrix (with a vocabulary size of 32,768), the dense input factor would have size 32,768 × 32,768. To save memory, we use a diagonal matrix to estimate the input_factor. The output_factor is still estimated with a dense matrix.
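The pooling of contributions and the diagonal input factor described above can be sketched as follows. `shared_embedding_factors` is a hypothetical helper, not our training code: it concatenates the vocabulary-sized vectors (embedding inputs and pre-softmax output gradients) as if they were extra examples and keeps only $\mathbb{E}[x_i^2]$ for each coordinate, while the model-dimension vectors (pre-softmax inputs and embedding output gradients) form a dense factor. The memory arithmetic at the end assumes float32 storage, which the text does not specify.

```python
def shared_embedding_factors(emb_inputs, presoftmax_out_grads,
                             presoftmax_inputs, emb_out_grads):
    """Hypothetical helper: K-FAC factors for the weight shared by the embedding
    layer and the (transposed) pre-softmax layer.

    Contributions from both layers are pooled as if they were extra examples in
    the mini-batch. The vocabulary-sized input factor is kept diagonal
    (E[x_i^2] for each i); the model-dimension output factor stays dense.
    """
    vocab_vectors = list(emb_inputs) + list(presoftmax_out_grads)
    model_vectors = list(presoftmax_inputs) + list(emb_out_grads)

    input_diag = [sum(v[i] ** 2 for v in vocab_vectors) / len(vocab_vectors)
                  for i in range(len(vocab_vectors[0]))]
    d = len(model_vectors[0])
    output_dense = [[sum(v[i] * v[j] for v in model_vectors) / len(model_vectors)
                     for j in range(d)] for i in range(d)]
    return input_diag, output_dense

# Toy sizes: vocabulary of 4 tokens, model dimension 2.
input_diag, output_dense = shared_embedding_factors(
    emb_inputs=[[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],      # one-hot token inputs
    presoftmax_out_grads=[[0.5, -0.5, 0.0, 0.0], [0.0, 0.0, 0.2, -0.2]],
    presoftmax_inputs=[[1.0, 2.0]],
    emb_out_grads=[[0.0, 1.0]],
)

vocab_size = 32768
dense_bytes = vocab_size ** 2 * 4   # dense input factor, assuming float32 entries
diag_bytes = vocab_size * 4         # diagonal estimate, same assumption
print(input_diag, output_dense, dense_bytes / 2 ** 30, diag_bytes / 2 ** 10)
```

Under the float32 assumption, the dense input factor would take 4 GiB, while its diagonal takes 128 KiB.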

Appendix B Dynamics of momentum SGD on noisy quadratic model

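Similar to plain SGD, by treating $\theta_t$ as a random variable, we can explicitly write down the dynamics of its expectation and variance. But due to the use of momentum, we also need to take into account the velocity $v_t$ and its correlation with $\theta_t$. Because each dimension evolves independently, we drop the dimension subscripts and write the update as $v_{t+1} = \beta v_t - \alpha\hat{g}_t$ and $\theta_{t+1} = \theta_t + v_{t+1}$, where the stochastic gradient $\hat{g}_t = h\theta_t + \epsilon_t$ has noise $\epsilon_t$ with zero mean and variance $c/B$, independent of $(\theta_t, v_t)$. We first calculate the expectation of the parameter and velocity:

$$
\mathbb{E}[v_{t+1}] = \beta\,\mathbb{E}[v_t] - \alpha h\,\mathbb{E}[\theta_t], \qquad
\mathbb{E}[\theta_{t+1}] = (1 - \alpha h)\,\mathbb{E}[\theta_t] + \beta\,\mathbb{E}[v_t].
$$

We then calculate the variance:

$$
\begin{aligned}
\mathbb{V}[v_{t+1}] &= \alpha^2 h^2\,\mathbb{V}[\theta_t] - 2\alpha h\beta\,\mathrm{Cov}(\theta_t, v_t) + \beta^2\,\mathbb{V}[v_t] + \frac{\alpha^2 c}{B},\\
\mathbb{V}[\theta_{t+1}] &= (1 - \alpha h)^2\,\mathbb{V}[\theta_t] + 2\beta(1 - \alpha h)\,\mathrm{Cov}(\theta_t, v_t) + \beta^2\,\mathbb{V}[v_t] + \frac{\alpha^2 c}{B},
\end{aligned}
$$

where $\mathrm{Cov}(\theta_t, v_t)$ evolves as

$$
\mathrm{Cov}(\theta_{t+1}, v_{t+1}) = -\alpha h(1 - \alpha h)\,\mathbb{V}[\theta_t] + \beta(1 - 2\alpha h)\,\mathrm{Cov}(\theta_t, v_t) + \beta^2\,\mathbb{V}[v_t] + \frac{\alpha^2 c}{B}.
$$

Because the expected risk $\mathbb{E}[\ell(\theta_t)] = \frac{h}{2}\mathbb{E}[\theta_t^2]$ is entirely determined by the second moments, we define $\mathbf{u}_t = \big[\mathbb{E}[\theta_t^2], \mathbb{E}[\theta_t v_t], \mathbb{E}[v_t^2]\big]^\top$ and $\mathbf{n} = \frac{\alpha^2 c}{B}[1, 1, 1]^\top$. Since the means evolve linearly, the second moments obey the same recursions as the (co)variances above, and we can simplify the dynamics as follows:

$$
\begin{aligned}
\mathbb{E}[\theta_{t+1}^2] &= (1 - \alpha h)^2\,\mathbb{E}[\theta_t^2] + 2\beta(1 - \alpha h)\,\mathbb{E}[\theta_t v_t] + \beta^2\,\mathbb{E}[v_t^2] + \frac{\alpha^2 c}{B},\\
\mathbb{E}[\theta_{t+1} v_{t+1}] &= -\alpha h(1 - \alpha h)\,\mathbb{E}[\theta_t^2] + \beta(1 - 2\alpha h)\,\mathbb{E}[\theta_t v_t] + \beta^2\,\mathbb{E}[v_t^2] + \frac{\alpha^2 c}{B},\\
\mathbb{E}[v_{t+1}^2] &= \alpha^2 h^2\,\mathbb{E}[\theta_t^2] - 2\alpha h\beta\,\mathbb{E}[\theta_t v_t] + \beta^2\,\mathbb{E}[v_t^2] + \frac{\alpha^2 c}{B},
\end{aligned}
$$

or equivalently

$$
\mathbf{u}_{t+1} = \mathbf{T}\mathbf{u}_t + \mathbf{n}, \qquad
\mathbf{T} = \begin{bmatrix}
(1 - \alpha h)^2 & 2\beta(1 - \alpha h) & \beta^2\\
-\alpha h(1 - \alpha h) & \beta(1 - 2\alpha h) & \beta^2\\
\alpha^2 h^2 & -2\alpha h\beta & \beta^2
\end{bmatrix}.
$$

The convergence rate is determined by the transition matrix $\mathbf{T}$, which has the characteristic polynomial

$$
\det(\lambda\mathbf{I} - \mathbf{T}) = (\lambda - \beta)\Big(\lambda^2 - \big[(1 - \alpha h + \beta)^2 - 2\beta\big]\lambda + \beta^2\Big).
$$

With the momentum value $\beta = (1 - \sqrt{\alpha h})^2$, all eigenvalues of the transition matrix are equal to each other with the value $(1 - \sqrt{\alpha h})^2$, giving the fastest convergence.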
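The claim about the critical momentum can be checked numerically. The sketch below relies on a standard fact not stated in the text: the eigenvalues of the second-moment transition matrix $\mathbf{T}$ are the pairwise products $r_i r_j$ of the roots of $r^2 - (1 - \alpha h + \beta)r + \beta = 0$, which are the eigenvalues of the noise-free update of $(\theta, v)$. It scans $\beta$ on a grid for a hypothetical $\alpha h = 0.04$ and checks that the spectral radius of $\mathbf{T}$ is smallest at $\beta = (1 - \sqrt{\alpha h})^2 = 0.64$, where $\mathrm{trace}(\mathbf{T}) = 3\beta$, as expected when all three eigenvalues equal $\beta$.

```python
import cmath
import math

def momentum_roots(alpha, h, beta):
    """Roots r1, r2 of r^2 - (1 - alpha h + beta) r + beta = 0, i.e. the
    eigenvalues of the noise-free update of (theta, v)."""
    p = 1.0 - alpha * h + beta
    disc = cmath.sqrt(p * p - 4.0 * beta)
    return (p + disc) / 2.0, (p - disc) / 2.0

def second_moment_matrix(alpha, h, beta):
    """Transition matrix T acting on (E[theta^2], E[theta v], E[v^2])."""
    a = alpha * h
    return [
        [(1 - a) ** 2, 2 * beta * (1 - a), beta ** 2],
        [-a * (1 - a), beta * (1 - 2 * a), beta ** 2],
        [a ** 2, -2 * a * beta, beta ** 2],
    ]

def spectral_radius_T(alpha, h, beta):
    """Eigenvalues of T are r1^2, r1 r2 and r2^2, so its spectral radius is max|r|^2."""
    r1, r2 = momentum_roots(alpha, h, beta)
    return max(abs(r1), abs(r2)) ** 2

alpha, h = 0.04, 1.0                            # hypothetical values, alpha * h = 0.04
beta_star = (1.0 - math.sqrt(alpha * h)) ** 2   # critical momentum, about 0.64

grid = [k / 1000 for k in range(1000)]          # beta in [0, 0.999]
best_beta = min(grid, key=lambda b: spectral_radius_T(alpha, h, b))

T = second_moment_matrix(alpha, h, beta_star)
trace_T = T[0][0] + T[1][1] + T[2][2]           # sum of the eigenvalues of T
print(best_beta, round(beta_star, 6), round(trace_T, 6))
```

For smaller $\beta$ the larger real root dominates, and for larger $\beta$ the complex roots have modulus $\sqrt{\beta}$, which is why the minimum sits exactly at the critical value.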

Appendix C Proof of Theorem 1

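To analyze the dynamics, we can perform a change of basis so that the three dimensions evolve independently. Recall from Appendix B that the second moments $\mathbf{u}_t = \big[\mathbb{E}[\theta_t^2], \mathbb{E}[\theta_t v_t], \mathbb{E}[v_t^2]\big]^\top$ evolve as $\mathbf{u}_{t+1} = \mathbf{T}\mathbf{u}_t + \mathbf{n}$ with $\mathbf{n} = \frac{\alpha^2 c}{B}[1, 1, 1]^\top$. To decouple them, we first take the eigendecomposition $\mathbf{T} = \mathbf{Q}\boldsymbol{\Lambda}\mathbf{Q}^{-1}$ of the transition matrix (we implicitly assume $\beta \neq (1 - \sqrt{\alpha h})^2$; otherwise the transition matrix is not diagonalizable). Then the dynamics can be reformulated as follows:

$$
\mathbf{u}_t = \mathbf{T}^t\mathbf{u}_0 + \sum_{k=0}^{t-1}\mathbf{T}^k\mathbf{n} = \mathbf{Q}\boldsymbol{\Lambda}^t\mathbf{Q}^{-1}\mathbf{u}_0 + \sum_{k=0}^{t-1}\mathbf{Q}\boldsymbol{\Lambda}^k\mathbf{Q}^{-1}\mathbf{n}.
$$

We first analyze the stochastic term alone. Let $\mathbf{M} = \begin{bmatrix} 1 - \alpha h & \beta \\ -\alpha h & \beta \end{bmatrix}$ be the noise-free update of $[\theta_t, v_t]^\top$, and let $\psi([x_1, x_2]^\top) = [x_1^2, x_1 x_2, x_2^2]^\top$. By the identity $\mathbf{T}\psi(\mathbf{x}) = \psi(\mathbf{M}\mathbf{x})$ and $\mathbf{n} = \frac{\alpha^2 c}{B}\psi(\mathbf{1})$ with $\mathbf{1} = [1, 1]^\top$, we have

$$
\sum_{k=0}^{t-1}\mathbf{T}^k\mathbf{n} = \frac{\alpha^2 c}{B}\sum_{k=0}^{t-1}\psi(\mathbf{M}^k\mathbf{1}) = \frac{\alpha^2 c}{B}\sum_{k=0}^{t}\psi(\mathbf{x}_k), \qquad \mathbf{x}_0 = \mathbf{0}, \quad \mathbf{x}_k = \mathbf{M}^{k-1}\mathbf{1} \ (k \ge 1). \tag{22}
$$

In eqn. (22), we append the zero vector $\mathbf{x}_0$ for convenience; it contributes $\psi(\mathbf{0}) = \mathbf{0}$ to the sum. To compute the infinite sum, we first focus on a single term. Writing $\mathbf{x}_k = [\theta_k, v_k]^\top$, we have the following update:

$$
v_{k+1} = \beta v_k - \alpha h\,\theta_k, \qquad \theta_{k+1} = \theta_k + v_{k+1}, \qquad k \ge 1.
$$

Since we only care about $\theta_k$, which entirely determines the loss, we eliminate $v_k$ by merging the two updates (using $v_k = \theta_k - \theta_{k-1}$), which yields a second-order difference equation:

$$
\theta_{k+1} = (1 - \alpha h + \beta)\,\theta_k - \beta\,\theta_{k-1},
$$

with initial conditions $\theta_0 = 0$ and $\theta_1 = 1$. To solve the second-order difference equation, we leverage the Z-transform to get the analytical form. Based on basic manipulation of the Z-transform, we have the Z-domain function

$$
\Theta(z) = \frac{z}{z^2 - (1 - \alpha h + \beta)z + \beta} = \frac{z}{(z - r_1)(z - r_2)},
$$

where $r_1$ and $r_2$ are the two roots of the equation $x^2 - (1 - \alpha h + \beta)x + \beta = 0$. Then, we use the inverse Z-transform to get $\theta_k$:

$$
\theta_k = \frac{r_1^k - r_2^k}{r_1 - r_2},
$$

and therefore

$$
\theta_k^2 = \frac{r_1^{2k} - 2(r_1 r_2)^k + r_2^{2k}}{(r_1 - r_2)^2}.
$$

Now, we are ready to compute the infinite sum $\sum_{k=0}^{\infty}\theta_k^2$ (assuming $|r_1|, |r_2| < 1$, i.e. a stable learning rate):

$$
\sum_{k=0}^{\infty}\theta_k^2 = \frac{1}{(r_1 - r_2)^2}\left(\frac{1}{1 - r_1^2} - \frac{2}{1 - r_1 r_2} + \frac{1}{1 - r_2^2}\right).
$$

Because $r_1$ and $r_2$ are two roots with $r_1 + r_2 = 1 - \alpha h + \beta$ and $r_1 r_2 = \beta$, we have

$$
\sum_{k=0}^{\infty}\theta_k^2 = \frac{1 + \beta}{(1 - \beta)\,\alpha h\,\big(2(1 + \beta) - \alpha h\big)}, \qquad\text{so}\qquad \frac{h}{2}\cdot\frac{\alpha^2 c}{B}\sum_{k=0}^{\infty}\theta_k^2 = \frac{(1 + \beta)\,\alpha c}{2B(1 - \beta)\big(2(1 + \beta) - \alpha h\big)}. \tag{29}
$$

Now, we analyze the deterministic term. Assuming the velocity is initialized at $v_0 = 0$, we have $\mathbf{u}_0 = \mathbb{E}[\theta_0^2]\,\psi(\mathbf{e}_1)$ with $\mathbf{e}_1 = [1, 0]^\top$, and therefore

$$
\big[\mathbf{T}^t\mathbf{u}_0\big]_1 = \mathbb{E}[\theta_0^2]\,x_t^2, \qquad x_t = \big[\mathbf{M}^t\mathbf{e}_1\big]_1.
$$

Similar to the analysis of the stochastic term, we have the same second-order difference equation

$$
x_{k+1} = (1 - \alpha h + \beta)\,x_k - \beta\,x_{k-1},
$$

except that the initial conditions become $x_0 = 1$ and $x_1 = 1 - \alpha h$. According to the Z-transform, we have

$$
X(z) = \frac{z(z - \beta)}{(z - r_1)(z - r_2)} \quad\Longrightarrow\quad x_t = \frac{(r_1^{t+1} - r_2^{t+1}) - \beta(r_1^t - r_2^t)}{r_1 - r_2}.
$$

Along with eqn. (29), and since the partial sum in the stochastic term never exceeds the infinite sum, we have

$$
\mathbb{E}[\ell(\theta_t)] \le \left(\frac{(r_1^{t+1} - r_2^{t+1}) - \beta(r_1^t - r_2^t)}{r_1 - r_2}\right)^2 \mathbb{E}[\ell(\theta_0)] + \frac{(1 + \beta)\,\alpha c}{2B(1 - \beta)\big(2(1 + \beta) - \alpha h\big)},
$$

where $\ell(\theta) = \frac{h}{2}\theta^2$ is the risk along this dimension.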
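The closed forms can be cross-checked against a direct iteration of the second-moment recursion $\mathbf{u}_{t+1} = \mathbf{T}\mathbf{u}_t + \mathbf{n}$ with $v_0 = 0$. The sketch below compares the deterministic term $x_t^2\,\mathbb{E}[\ell(\theta_0)]$, with $x_t = \big((r_1^{t+1} - r_2^{t+1}) - \beta(r_1^t - r_2^t)\big)/(r_1 - r_2)$, against the noise-free recursion, and the steady-state term $(1+\beta)\alpha c / \big(2B(1-\beta)(2(1+\beta) - \alpha h)\big)$ against the recursion after many steps. The parameter values are hypothetical; complex arithmetic covers the case of complex roots, and $r_1 \neq r_2$ is assumed.

```python
import cmath

def risk_by_recursion(alpha, h, beta, c, B, theta0, steps):
    """Iterate u_{t+1} = T u_t + n for u = (E[theta^2], E[theta v], E[v^2])
    with v_0 = 0 and return the expected risk (h / 2) * E[theta_t^2]."""
    a = alpha * h
    noise = alpha ** 2 * c / B
    s_tt, s_tv, s_vv = theta0 ** 2, 0.0, 0.0
    for _ in range(steps):
        s_tt, s_tv, s_vv = (
            (1 - a) ** 2 * s_tt + 2 * beta * (1 - a) * s_tv + beta ** 2 * s_vv + noise,
            -a * (1 - a) * s_tt + beta * (1 - 2 * a) * s_tv + beta ** 2 * s_vv + noise,
            a ** 2 * s_tt - 2 * a * beta * s_tv + beta ** 2 * s_vv + noise,
        )
    return 0.5 * h * s_tt

def theorem_terms(alpha, h, beta, c, B, theta0, t):
    """Closed-form deterministic term at step t and the steady-state stochastic term."""
    p = 1 - alpha * h + beta
    disc = cmath.sqrt(p * p - 4 * beta)          # complex when the roots are complex
    r1, r2 = (p + disc) / 2, (p - disc) / 2      # assumes r1 != r2
    x_t = ((r1 ** (t + 1) - r2 ** (t + 1)) - beta * (r1 ** t - r2 ** t)) / (r1 - r2)
    deterministic = x_t.real ** 2 * 0.5 * h * theta0 ** 2   # x_t is real up to round-off
    steady = (1 + beta) * alpha * c / (2 * B * (1 - beta) * (2 * (1 + beta) - alpha * h))
    return deterministic, steady

alpha, h, beta, c, B = 0.1, 1.0, 0.5, 1.0, 4     # hypothetical single dimension
det_10, steady = theorem_terms(alpha, h, beta, c, B, theta0=1.0, t=10)
print(risk_by_recursion(alpha, h, beta, c, B, 1.0, 10), det_10 + steady)
print(risk_by_recursion(alpha, h, beta, c, B, 1.0, 2000), steady)
```

The first line illustrates the inequality at a finite step; the second shows the recursion settling at the steady-state risk.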


Appendix D More results on the NQM

d.1 Eigenspectra of Neural Networks

Figure 7: Eigenspectra of the K-FAC approximate Fisher matrix of ResNet8 at different training iterations. The model is trained on CIFAR-10 with batch size 3000.

The main objective of this section is to examine the loss surface of modern neural networks at different stages of training in order to justify the assumptions made in the NQM. However, it is hard to visualize such a high-dimensional space. Following recent work [Sagun et al., 2016, Ghorbani et al., 2019], we instead analyze the eigenspectrum of the Hessian/Fisher matrix. The Hessian/Fisher of the training loss (with respect to the parameters) is crucial in determining many behaviors of neural networks: its eigenvalues characterize the local curvature of the loss surface, which governs many aspects of training, including the convergence rates of first-order optimization methods (at least for convex problems).

To approximate the eigenspectrum of the true Fisher matrix, we leverage the Kronecker-factored approximation of the Fisher, whose eigenspectrum may shed light on the true one. Specifically, we train the network with K-FAC and then eigendecompose the saved Kronecker factors of the Fisher to calculate the eigenvalues.
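Eigendecomposing the factors is enough because of a standard property of the Kronecker product: if $\mathbf{G}\mathbf{u} = \lambda_G\mathbf{u}$ and $\mathbf{A}\mathbf{w} = \lambda_A\mathbf{w}$, then $(\mathbf{G} \otimes \mathbf{A})(\mathbf{u} \otimes \mathbf{w}) = \lambda_G\lambda_A(\mathbf{u} \otimes \mathbf{w})$. The sketch below checks this on two hypothetical symmetric $2 \times 2$ factors using closed-form eigenpairs; in practice the factors are much larger and would be handled by a numerical eigensolver.

```python
import math

def eig_sym2(m):
    """Eigenpairs of a symmetric 2x2 matrix [[a, b], [b, d]] with b != 0 (closed form)."""
    (a, b), (_, d) = m
    assert b != 0, "this sketch only handles non-diagonal matrices"
    mid = (a + d) / 2
    rad = math.sqrt(((a - d) / 2) ** 2 + b ** 2)
    pairs = []
    for lam in (mid + rad, mid - rad):
        v = [b, lam - a]                      # solves (m - lam I) v = 0 when b != 0
        norm = math.hypot(v[0], v[1])
        pairs.append((lam, [x / norm for x in v]))
    return pairs

def kron(x, y):
    """Kronecker product of two matrices given as nested lists."""
    return [[x[i][j] * y[k][l] for j in range(len(x[0])) for l in range(len(y[0]))]
            for i in range(len(x)) for k in range(len(y))]

G = [[2.0, 0.5], [0.5, 1.0]]      # hypothetical output-gradient factor
A = [[3.0, -1.0], [-1.0, 2.0]]    # hypothetical input-activation factor
F = kron(G, A)                    # 4x4 Kronecker-factored Fisher

eigenvalues, residuals = [], []
for lam_g, u in eig_sym2(G):
    for lam_a, w in eig_sym2(A):
        lam = lam_g * lam_a
        v = [ui * wj for ui in u for wj in w]             # u kron w
        Fv = [sum(row[j] * v[j] for j in range(4)) for row in F]
        residuals.append(max(abs(fv - lam * vj) for fv, vj in zip(Fv, v)))
        eigenvalues.append(lam)
print(sorted(eigenvalues, reverse=True), max(residuals))
```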

The eigenspectra are plotted in Figure 7. One interesting observation is that there are only a few large eigenvalues and a few small eigenvalues in the approximate Fisher matrices; the bulk of eigenvalues are in the middle of the spectrum. We also note that after 200 iterations of training the eigenspectrum remains mostly unchanged.

Figure 8: Scatter plots of second moment vs. variance of gradients for (a) ResNet8 and (b) Transformer. The gradients are projected onto the Kronecker-factored eigenbasis, which approximates the eigenbasis of the true Fisher. Each point compares the gradient variance and the second moment of the gradient in the direction of an eigenvector of the K-FAC approximated Fisher.

d.2 Gradient Covariance in the Kronecker-Factored Eigenbasis

To verify the assumption in Section 3.4 that $\mathbf{H}$ and $\mathbf{C}$ are codiagonalizable, we test it on practical neural networks by comparing the gradient variance to the curvature. This assumption is motivated by theoretical considerations that suggest $\mathbf{C} \approx \mathbf{H}$ for neural network training [Martens, 2014]. Ideally, we would compare the gradient variance and the curvature in the directions of the eigenvectors of the true Fisher. However, it is typically infeasible to compute all of these eigenvectors, especially for low-curvature directions. To resolve this, we instead use the Kronecker-factored eigenbasis [George et al., 2018, Bae et al., 2018, Wang et al., 2019], which is obtained from the K-FAC approximation. For this experiment, we do not rely on this basis being an accurate approximation to the eigendecomposition of the true Fisher; rather, we use it only as a way to obtain a diverse set of directions with both high and low curvature. For a given eigenvector $\mathbf{u}$, we project the gradient $\mathbf{g}$ of each training example onto $\mathbf{u}$ and compute the gradient variance $\mathrm{Var}[\mathbf{u}^\top\mathbf{g}]$, as well as the curvature $\mathbf{u}^\top\mathbf{F}\mathbf{u}$. (The latter quantity can be obtained using matrix-vector products [Schraudolph, 2002].) As shown in Figure 8, the gradient variances closely match the curvature (especially for the ResNet8 model on CIFAR10), validating our assumption that $\mathbf{C} \approx \mathbf{H}$.
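The two quantities compared above can be illustrated on a toy problem where the relation $\mathbf{C} \approx \mathbf{H}$ is known to hold: least squares with unit-variance Gaussian noise, evaluated at the true parameters, where the Fisher equals the Hessian $\mathbb{E}[\mathbf{x}\mathbf{x}^\top]$. The sketch below is only that illustration (not the Figure 8 experiment): the data, feature scales, and helper name are hypothetical, the coordinate axes stand in for the eigenbasis, and the curvature $\mathbf{u}^\top\mathbf{H}\mathbf{u}$ is formed from a Hessian-vector product rather than an explicit matrix.

```python
import random

def projected_variance_and_curvature(xs, ys, theta, u):
    """Variance of per-example gradients projected onto u, and the curvature
    u^T H u from a Hessian-vector product (least squares: H = E[x x^T])."""
    n = len(xs)
    projections = []
    hu = [0.0] * len(u)
    for x, y in zip(xs, ys):
        residual = sum(xi * ti for xi, ti in zip(x, theta)) - y
        xu = sum(xi * ui for xi, ui in zip(x, u))
        projections.append(residual * xu)           # u^T (per-example gradient x * residual)
        for i, xi in enumerate(x):
            hu[i] += xi * xu / n                    # accumulate H u
    mean = sum(projections) / n
    variance = sum((p - mean) ** 2 for p in projections) / n
    curvature = sum(ui * hui for ui, hui in zip(u, hu))
    return variance, curvature

rng = random.Random(0)
w_true = [1.0, -2.0, 0.5]
scales = [3.0, 1.0, 0.2]                            # high- to low-curvature directions
xs = [[s * rng.gauss(0.0, 1.0) for s in scales] for _ in range(20000)]
ys = [sum(xi * wi for xi, wi in zip(x, w_true)) + rng.gauss(0.0, 1.0) for x in xs]

results = [projected_variance_and_curvature(xs, ys, w_true, u)
           for u in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])]
print(results)   # variance and curvature agree up to sampling noise in each direction
```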

d.3 Plots for the Evolution of the First Term in Eqn. (6)

In Section 3.2, we claim that the convergence of momentum SGD for a single dimension is very close to that of plain SGD with an adjusted learning rate (note that we already verified in Figure 2 that the steady-state risk of momentum SGD matches that of plain SGD using the effective learning rate). Here we verify this claim by comparing the two update rules in the NQM. The total risk consists of two terms (eqn. (6)): the first term determines convergence, while the second term (the steady-state risk) stays constant throughout training. Since the second term stays unchanged, we only plot the first term of eqn. (6) in Figure 9. Note that the values are normalized in the figures. We observe that the convergence dynamics of the two update rules closely match each other. For this experiment we set , but the results are not sensitive to this value.
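A quick way to reproduce this kind of comparison for one dimension is to iterate the noise-free momentum update and compare $\theta_t^2/\theta_0^2$ with $(1 - \alpha_{\text{eff}} h)^{2t}$ for plain SGD at the effective learning rate $\alpha_{\text{eff}} = \alpha/(1-\beta)$. The sketch below uses hypothetical values for a low-curvature direction ($\alpha h$ small compared with $(1-\beta)^2$), where the match is closest in our quick check; in high-curvature, underdamped directions the momentum curve oscillates and deviates more, although both terms are then already small.

```python
def momentum_first_term(alpha, beta, h, steps):
    """Normalized first term theta_t^2 / theta_0^2 for noise-free momentum SGD
    (v_{t+1} = beta v_t - alpha h theta_t, theta_{t+1} = theta_t + v_{t+1}, v_0 = 0)."""
    theta, v = 1.0, 0.0
    out = [1.0]
    for _ in range(steps):
        v = beta * v - alpha * h * theta
        theta += v
        out.append(theta ** 2)
    return out

def sgd_first_term(lr, h, steps):
    """Normalized first term (1 - lr h)^(2t) for plain SGD."""
    return [(1.0 - lr * h) ** (2 * t) for t in range(steps + 1)]

alpha, beta, h = 1e-4, 0.9, 1.0            # hypothetical low-curvature direction
effective_lr = alpha / (1.0 - beta)        # about 1e-3
mom = momentum_first_term(alpha, beta, h, 2000)
sgd = sgd_first_term(effective_lr, h, 2000)
for t in (10, 100, 1000, 2000):
    print(t, round(mom[t], 4), round(sgd[t], 4))
```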

Figure 9: Comparison in convergence between momentum SGD and SGD with adjusted learning rate. This plot shows values for the first term in eqn. (6) as a function of , which is the scaling between the “effective learning rate” and the true learning rate for momentum SGD. The red curves show the first term when using momentum, while the blue curves show the first term when using plain SGD with the learning rate set to the effective learning rate of momentum.

d.4 Verification of Eigenspectrum

In Section 3.6, we make a specific assumption about the distribution of the diagonal entries of $\mathbf{H}$. To justify this choice, we compare the K-FAC eigenspectra of ResNet8 to this distribution in Figure 10. The distribution of eigenvalues we chose for $\mathbf{H}$ in the NQM very closely matches the eigenspectra of the real neural network, validating the assumption about the diagonal entries of $\mathbf{H}$ made in Section 3.4.

Figure 10: Comparison between K-FAC Fisher eigenspectra and the distribution used in the NQM.

d.5 Effect of Loss Threshold

Figure 11: Number of training steps required to reach a target loss as a function of batch size for different loss threshold values.

Recall that a main objective of this work is to characterize the effects of increasing the batch size on training time, as measured in the number of steps necessary to reach a goal target error/loss. Here we experiment with different loss thresholds to study the relationship between batch size and number of training steps. To obtain the minimal training steps for a given batch size, we do grid search over constant learning rates. Figure 11 shows that increasing the batch size initially decreases the required number of training steps proportionally, but eventually there are diminishing returns, which matches the empirical findings [Golmant et al., 2018, Shallue et al., 2018]. The shape of the curves is characteristically the same for different loss thresholds, though the critical batch size seems to increase for more difficult thresholds.
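The same kind of curve can be produced with a tiny NQM in plain Python. The sketch below is hypothetical in all its settings ($d = 50$, $h_i = 1/i$, $\mathbf{C} = \mathbf{H}$, $\theta_i(0) = 1$, loss threshold 0.05, a geometric grid of constant learning rates), which are illustrative choices and not the settings behind Figure 11. It uses the closed form for plain SGD on one NQM dimension, $\mathbb{E}[\ell_i(\theta_t)] = (1-\alpha h_i)^{2t}\,\mathbb{E}[\ell_i(\theta_0)] + \big(1 - (1-\alpha h_i)^{2t}\big)\,\alpha c_i / \big(2B(2 - \alpha h_i)\big)$, and, as a simplification, only keeps learning rates for which every dimension's risk decreases monotonically so that bisection over $t$ is valid.

```python
def nqm_sgd_risk(t, lr, batch, h, c, loss0):
    """Expected NQM risk after t SGD steps with a constant learning rate
    (closed form per dimension: decaying initial loss plus steady-state risk)."""
    total = 0.0
    for h_i, c_i, l_i in zip(h, c, loss0):
        decay = (1.0 - lr * h_i) ** (2 * t)
        steady = lr * c_i / (2.0 * batch * (2.0 - lr * h_i))
        total += decay * l_i + (1.0 - decay) * steady
    return total

def steps_to_target(lr, batch, h, c, loss0, target, t_max=10 ** 7):
    """Smallest t with risk <= target, or None if excluded or unreachable."""
    for h_i, c_i, l_i in zip(h, c, loss0):
        if lr * h_i >= 2.0 or lr * c_i / (2.0 * batch * (2.0 - lr * h_i)) > l_i:
            return None                      # unstable, or this dimension's risk would grow
    if nqm_sgd_risk(t_max, lr, batch, h, c, loss0) > target:
        return None
    lo, hi = 0, t_max
    while lo < hi:                           # bisection: risk is non-increasing in t here
        mid = (lo + hi) // 2
        if nqm_sgd_risk(mid, lr, batch, h, c, loss0) <= target:
            hi = mid
        else:
            lo = mid + 1
    return lo

d = 50
h = [1.0 / i for i in range(1, d + 1)]                # hypothetical spectrum h_i = 1/i
c = list(h)                                           # C = H
loss0 = [0.5 * h_i for h_i in h]                      # theta_i(0) = 1
target = 0.05                                         # hypothetical loss threshold
lrs = [2.0 * 2 ** (-k / 4) for k in range(1, 41)]     # grid of constant learning rates

batch_sizes = [2 ** k for k in range(14)]
min_steps = []
for batch in batch_sizes:
    candidates = [steps_to_target(lr, batch, h, c, loss0, target) for lr in lrs]
    min_steps.append(min(s for s in candidates if s is not None))
print(list(zip(batch_sizes, min_steps)))
```

In this toy setting, doubling a small batch roughly halves the required steps, while the largest batch sizes give almost no further improvement.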

d.6 Results of Optimal Learning Rate on NQM

Figure 12: Optimal learning rate vs. batch size for different preconditioning powers: (a) without momentum, (b) with fixed momentum, and (c) with tuned momentum. (a) When momentum is not used, the learning rate increases with batch size until it is limited by the maximum stable learning rate. Larger preconditioning powers reduce the optimal learning rate for the same batch size, thus extending the batch size where the learning rate levels off. (b, c) Fixed (0.9) and tuned momentum values. In (b) and (c), we plot the effective learning rate for momentum, defined as $\alpha/(1-\beta)$. The dashed lines are the same plots from (a) for easier comparison.

d.7 Final Learning Rate of Different Batch Sizes for PWC Learning Rate Scheme

Figure 13: Final learning rate of the piecewise-constant learning rate scheme vs. batch size.

In Section 3.6.2, we study the piecewise constant learning rate scheme. The optimal scheme starts with a high learning rate which drops later in training (Figure 2(c)). Recall that for fixed learning rates, we observed that the optimal learning rate scaled linearly with the batch size for small batch sizes, but it is unclear whether there is a similar phenomenon for learning rate decay. In Figure 13, we plot the final learning rate as a function of batch size and show that it also scales linearly with batch size.

Appendix E More Details for Experiments

e.1 Data Sets

The data sets in Table 1 (MNIST, Fashion MNIST, CIFAR10, ImageNet and LM1B) are identical to those of Shallue et al. [2018] (described in their Appendix A.1). For CIFAR10 we used data augmentation (including horizontal flip and random crop), but they did not.

e.2 Model Details

This section provides details of models in Table 1. The models are very similar (and some identical) to those used in Shallue et al. [2018] (described in their Appendix B). Any modifications from them are highlighted in this section.

Simple CNN consists of 2 convolutional layers with max-pooling followed by 1 fully connected hidden layer. The convolutional layers use 5×5 filters with stride length 1, “same” padding [Goodfellow et al., 2016], and the ReLU activation function. Max pooling uses 2×2 windows with stride length 2. Unlike Shallue et al. [2018], we did not use any dropout regularization (they used dropout with probability 0.4 in the fully connected layer). We used 32 and 64 filters in the convolutional layers and 1,024 units in the fully connected layer. This corresponds to the “base” configuration in Shallue et al. [2018].
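The architecture above fixes the feature-map sizes, which the following sketch works through as shape and parameter arithmetic. It assumes a 28×28×1 input with 10 classes (as for MNIST and Fashion MNIST), biases in every layer, and a final linear output layer; these are our assumptions for illustration, not details stated in the text.

```python
def simple_cnn_summary(height=28, width=28, in_channels=1, num_classes=10):
    """Flattened feature size and parameter count for the Simple CNN described
    above, assuming biases everywhere and a linear output layer (assumptions)."""
    params = 0
    channels = in_channels
    for filters in (32, 64):
        params += 5 * 5 * channels * filters + filters   # 5x5 conv; "same" padding keeps H, W
        channels = filters
        height, width = height // 2, width // 2          # 2x2 max pooling with stride 2
    flat = height * width * channels                     # 7 * 7 * 64 for a 28x28 input
    params += flat * 1024 + 1024                         # fully connected hidden layer
    params += 1024 * num_classes + num_classes           # output layer
    return flat, params

print(simple_cnn_summary())   # (3136, 3274634)
```

Under these assumptions, almost all parameters sit in the fully connected hidden layer.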

ResNet8 [He et al., 2016] consists of 7 convolutional layers with residual connections followed by 1 fully connected hidden layer. We used the same architecture as Shallue et al. [2018]; in particular, we did not use batch normalization. The only difference is that we used data augmentation in our experiments.

ResNet32 [He et al., 2016] consists of 31 convolutional layers with residual connections followed by 1 fully connected hidden layer (see Section 4.2 of He et al. [2016]). We replaced batch normalization [Ioffe and Szegedy, 2015] with ghost batch normalization to keep the training objective fixed between batch sizes and to avoid possible negative effects from computing batch normalization statistics over a large number of examples [Hoffer et al., 2017]. We used a ghost batch size of 32 for all experiments. We also applied label smoothing [Szegedy et al., 2016] to regularize the model at training time, which was helpful for larger batch sizes. We set the label smoothing parameter to 0.1 in all experiments. Instead of using weight decay, we applied channel-wise weight normalization by constraining the Frobenius norm of each convolutional channel to be exactly 1, which controls the effective learning rate [Zhang et al., 2019b, van Laarhoven, 2017].
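The three modifications in the paragraph above can be sketched in plain Python. This is a minimal illustration, not our training code: ghost batch normalization is shown in training mode only (no running averages and no learned scale or shift), channel-wise weight normalization treats each output channel as a flat list of weights (our assumption about which axis forms a channel), and the inputs are toy values. The ghost batch size of 32 and smoothing parameter of 0.1 follow the text.

```python
import math

def ghost_batch_norm(values, ghost_size=32, eps=1e-5):
    """Normalize one channel's activations with statistics computed separately
    over consecutive ghost batches of `ghost_size` examples (training mode only)."""
    out = []
    for start in range(0, len(values), ghost_size):
        chunk = values[start:start + ghost_size]
        mean = sum(chunk) / len(chunk)
        var = sum((x - mean) ** 2 for x in chunk) / len(chunk)
        out.extend((x - mean) / math.sqrt(var + eps) for x in chunk)
    return out

def smooth_labels(label, num_classes, smoothing=0.1):
    """Label smoothing: (1 - smoothing) * one_hot(label) + smoothing / num_classes."""
    return [(1.0 - smoothing) * (1.0 if k == label else 0.0) + smoothing / num_classes
            for k in range(num_classes)]

def normalize_channels(kernel):
    """Rescale each output channel (a flat list of weights) to unit Frobenius norm."""
    return [[w / math.sqrt(sum(x * x for x in channel)) for w in channel]
            for channel in kernel]

batch = [float(i) for i in range(32)] + [1000.0 + 2.0 * i for i in range(32)]
normed = ghost_batch_norm(batch)          # each ghost batch uses its own mean and variance
targets = smooth_labels(label=2, num_classes=10)
kernel = normalize_channels([[3.0, 4.0], [1.0, 1.0, 1.0, 1.0]])
print(round(normed[0], 4), round(normed[32], 4), targets[2], kernel[0])
```

Because the statistics are computed per ghost batch, the normalization of each example does not depend on how many ghost batches are stacked into the full batch, which is what keeps the training objective fixed across batch sizes.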

VGG11 [Simonyan and Zisserman, 2015] consists of 8 convolutional layers followed by 1 fully connected hidden layer. As in ResNet32, we used ghost batch normalization, label smoothing, and channel-wise weight normalization.

Transformer [Vaswani et al., 2017] is a self-attention model. We used a Transformer identical to the “base” model described in Vaswani et al. [2017], except with only two hidden layers instead of six. This is identical to the “Transformer Shallow” model in Shallue et al. [2018].

e.3 Learning Rate Schedules

This section describes the two learning rate schedules mentioned in Table 1: the constant schedule and the linear decay schedule. The constant schedule simply keeps a fixed learning rate throughout training:

$$
\eta_t = \eta_0,
$$

where $t$ is the training step index. The linear decay schedule is

$$
\eta_t = \begin{cases}
\eta_0 - (1 - \alpha)\,\eta_0\,\dfrac{t}{T}, & t \le T,\\[4pt]
\alpha\,\eta_0, & t > T,
\end{cases}
$$

where $\eta_0$ is the initial learning rate, $\alpha$ is the rate of decay (the final learning rate is $\alpha\eta_0$), and $T$ is the number of steps taken to reach the final learning rate. Shallue et al. [2018] experimented with various learning rate schedules and found that linear decay matched the performance of the other schedules with fewer hyperparameters to tune. Therefore, we also chose the linear decay schedule, for which we tuned $\eta_0$, $\alpha$, and $T$.
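The linear decay schedule translates directly into a small function. The sketch below is a minimal version of that schedule as written above; the parameter names and example values are hypothetical.

```python
def linear_decay_lr(step, lr0, decay_factor, decay_steps):
    """Linear decay from lr0 to decay_factor * lr0 over decay_steps steps,
    then held at decay_factor * lr0."""
    if step >= decay_steps:
        return decay_factor * lr0
    return lr0 - (1.0 - decay_factor) * lr0 * step / decay_steps

schedule = [linear_decay_lr(t, lr0=0.1, decay_factor=0.01, decay_steps=1000)
            for t in (0, 500, 1000, 5000)]
print(schedule)
```

Both branches agree at `step == decay_steps`, so the schedule is continuous; only three values need tuning, matching the three hyperparameters listed above.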

e.4 Optimizer-Specific Hyperparameters

For momentum SGD, we tuned the momentum $\beta$. For Adam, we tuned $\beta_1$, $\beta_2$, and $\epsilon$ (see Kingma and Ba [2014]). For K-FAC, we tuned the damping and the trust region constraint (also known as the KL clipping term) for Transformer, keeping the momentum and the moving average parameter for damping fixed; for all other models, we tuned all four parameters (see Martens and Grosse [2015]).