The large learning rate phase of deep learning: the catapult mechanism

03/04/2020 ∙ by Aitor Lewkowycz, et al.

The choice of initial learning rate can have a profound effect on the performance of deep networks. We present a class of neural networks with solvable training dynamics, and confirm their predictions empirically in practical deep learning settings. The networks exhibit sharply distinct behaviors at small and large learning rates. The two regimes are separated by a phase transition. In the small learning rate phase, training can be understood using the existing theory of infinitely wide neural networks. At large learning rates the model captures qualitatively distinct phenomena, including the convergence of gradient descent dynamics to flatter minima. One key prediction of our model is a narrow range of large, stable learning rates. We find good agreement between our model's predictions and training dynamics in realistic deep learning settings. Furthermore, we find that the optimal performance in such settings is often found in the large learning rate phase. We believe our results shed light on characteristics of models trained at different learning rates. In particular, they fill a gap between existing wide neural network theory and the nonlinear, large learning rate training dynamics relevant to practice.

1 Introduction

Deep learning has shown remarkable success across a variety of machine learning tasks. At the same time, our theoretical understanding of deep learning methods remains limited. In particular, the interplay between training dynamics, properties of the learned network, and generalization remains a largely open problem.

In this work we take a step toward addressing these questions. We present a dynamical mechanism that allows deep networks trained using SGD to find flat minima and achieve superior performance. Our theoretical predictions agree well with empirical results in a variety of deep learning settings. In many cases we are able to predict the regime of learning rates where optimal performance is achieved. Figure 1 summarizes our main results. This work builds on several existing results, which we now review.


Figure 1: A summary of our main results. (a) A visualization of gradient descent dynamics derived in our theoretical setup. A 2D slice of parameter space is shown, where lighter color indicates higher loss and dots represent points visited during optimization. Initially, the loss grows rapidly while local curvature decreases. Once curvature is sufficiently low, gradient descent converges to a flat minimum. We call this the catapult effect. See Figures 5 and S1 for more details. (b) Confirmation of our theoretical predictions in a practical deep learning setting. Line shows the test accuracy of a Wide ResNet trained on CIFAR-10 as a function of learning rate, each trained for a fixed number of steps. Dashed lines show our predictions for the boundaries of the large learning rate regime (the catapult phase), where we expect optimal performance to occur. Maximal performance is achieved between the dashed lines, confirming our predictions. See Section 3 for details.

1.1 Large learning rate SGD improves generalization

SGD training with large initial learning rates often leads to improved performance over training with small initial learning rates (see Li et al. (2019); Leclerc & Madry (2020); Xie et al. (2020); Frankle et al. (2020); Jastrzebski et al. (2020) for recent discussions). It has been suggested that one of the mechanisms underlying the benefit of large learning rates is that noise from stochastic gradient descent leads to flat minima, and that flat minima generalize better than sharp minima (Hochreiter & Schmidhuber, 1997; Keskar et al., 2016; Smith & Le, 2018; Jiang et al., 2020; Park et al., 2019) (though see Dinh et al. (2017) for discussion of some caveats). According to this suggestion, training with a large learning rate (or with a small batch size) can improve performance because it leads to more stochasticity during training (Mandt et al., 2017; Smith et al., 2017; Smith & Le, 2018; Smith et al., 2018).

We will develop a connection between large learning rate and flatness of minima in models trained via SGD. Unlike the relationship explored in most previous work though, this connection is not driven by SGD noise, but arises solely as a result of training with a large initial learning rate, and holds even for full batch gradient descent.

1.2 The existing theory of infinite width networks is insufficient to describe large learning rates

A recent body of work has investigated the gradient descent dynamics of deep networks in the limit of infinite width (Daniely, 2017; Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Zou et al., 2018; Allen-Zhu et al., 2019; Li & Liang, 2018; Chizat et al., 2019; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2018; Sirignano & Spiliopoulos, 2018; Woodworth et al., 2019; Naveh et al., ). Of particular relevance is the work by Jacot et al. (2018) showing that gradient flow in the space of functions is governed by a dynamical quantity called the Neural Tangent Kernel (NTK) which is fixed at its initial value in this limit. Lee et al. (2019) showed this result is equivalent to training the linearization of a model around its initialization in parameter space. Finally, moving away from the strict limit of infinite width by working perturbatively, Dyer & Gur-Ari (2020); Huang & Yau (2019) introduced an approach to computing the finite-width corrections to network evolution.

Despite this progress, it seems these results are insufficient to capture the full dynamics of deep networks, as well as their superior performance, in regimes applicable to practice. Prior work has focused on comparisons between various infinite-width kernels associated with deep networks and their finite-width, SGD-trained counterparts (Lee et al., 2018; Novak et al., 2019; Arora et al., 2019). Specific findings vary depending on precise choices for architecture and hyperparameters. However, dramatic performance gaps are consistently observed between non-linear CNNs and their limiting kernels, implying that the theory is not sufficient to explain the performance of deep networks in this realistic setup. Furthermore, some hyperparameter settings in finite-width models have no known analogue in the infinite width limit, and it is these settings that often lead to optimal performance.

In particular, finite width networks are often trained with large learning rates that would cause divergence for infinite width linearized models. Further, these large learning rates cause finite width networks to converge to flat minima. For infinite width linearized models, trained with MSE loss, all minima have the same curvature, and the notion of flat minima does not apply. We argue that the reduction in curvature during optimization, and support for learning rates that are infeasible for infinite width linearized models, may thus partially explain performance gaps observed between linear and non-linear models.

1.3 Our contribution: three learning rate regimes

In this work, we identify a dynamical mechanism which enables finite-width networks to stably access large learning rates. We show that this mechanism causes training to converge to flatter minima and is associated with improved generalization. We further show that this same mechanism can describe the behavior of infinite width networks, if training time is increased along with network width.

This new mechanism enables a characterization of gradient descent training in terms of three learning rate regimes, or phases: the lazy phase, the catapult phase, and the divergent phase. In Section 2 we analytically derive the behavior in these three learning rate regimes for one hidden layer linear networks with large but finite width, trained with MSE loss. We confirm experimentally in Section 3 that these phases also apply to deep nonlinear fully-connected, convolutional, and residual architectures. In Section 4 we study additional predictions of the analytic solution.

We now summarize all three phases, using $\eta$ to indicate the learning rate and $\lambda_0$ to indicate the initial curvature (defined precisely in Section 2.1). The phase is determined by the curvature at initialization and by the learning rate, even though the curvature may change significantly during training. Based on the experimental evidence, we expect the behavior described below to apply in typical deep learning settings when training sufficiently wide networks using SGD.

Lazy phase: $\eta < 2/\lambda_0$.

For sufficiently small learning rate, the curvature $\lambda_t$ at training step $t$ remains constant during the initial part of training. The model behaves (loosely) as a model linearized about its initial parameters (Lee et al., 2019); this becomes exact in the infinite width limit, where these dynamics are sometimes called lazy training (Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Li & Liang, 2018; Zou et al., 2018; Allen-Zhu et al., 2019; Chizat et al., 2019; Dyer & Gur-Ari, 2020). For a discussion of trainability and the connection to the NTK in the lazy phase see Xiao et al. (2019).

Catapult phase: $2/\lambda_0 < \eta < c_{\rm arch}/\lambda_0$.

In this phase, the curvature at initialization is too high for training to converge to a nearby point, and the linear approximation quickly breaks down. Optimization begins with a period of exponential growth in the loss, coupled with a rapid decrease in curvature, until curvature stabilizes at a value $\lambda \le 2/\eta$. Once the curvature drops below $2/\eta$, training converges, ultimately reaching a minimum that is flatter than those found in the lazy phase. This initial period lasts for a number of training steps that is of order $\log(n)$, where $n$ is the network width, and is therefore quite short for realistic networks (often lasting less than a single epoch). Optimal performance is often achieved when the initial learning rate is in this range. The gradient descent dynamics in this phase are visualized in SM Figure S1 and in Figure 1.

The maximum learning rate is approximately given by $\eta_{\max} = c_{\rm arch}/\lambda_0$, where $c_{\rm arch}$ is an architecture-dependent constant. Empirically, we find that this constant depends strongly on the non-linearity but only weakly on other aspects of the architecture. For networks with ReLU non-linearity we find empirically that $c_{\rm arch} \approx 12$. For the theoretical model, we show that $c_{\rm arch} = 4$.

Divergent phase: $\eta > c_{\rm arch}/\lambda_0$.

When the learning rate is above the maximum learning rate of the model, the loss diverges and the model does not train.
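The three regimes amount to a simple classification rule in terms of $\eta$ and $\lambda_0$. The sketch below is illustrative only: `learning_rate_phase` is a hypothetical helper, the boundaries are the idealized predictions $2/\lambda_0$ and $c_{\rm arch}/\lambda_0$, and the default `c_arch=12.0` is the approximate empirical value for ReLU networks (pass `c_arch=4.0` for the theoretical model). At finite width the transitions are not perfectly sharp, so points near a boundary should not be read too literally.

```python
def learning_rate_phase(eta: float, lam0: float, c_arch: float = 12.0) -> str:
    """Return the predicted phase: "lazy", "catapult", or "divergent".

    eta    -- learning rate
    lam0   -- top NTK eigenvalue at initialization (the curvature lambda_0)
    c_arch -- architecture-dependent constant: 4 for the solvable model,
              roughly 12 empirically for ReLU networks
    """
    if eta <= 0 or lam0 <= 0:
        raise ValueError("eta and lam0 must be positive")
    eta_crit = 2.0 / lam0     # linearized (lazy) training is stable below this
    eta_max = c_arch / lam0   # training is predicted to diverge above this
    if eta < eta_crit:
        return "lazy"
    if eta < eta_max:
        return "catapult"
    return "divergent"


# Solvable model with lambda_0 = 2, so 2/lambda_0 = 1 and 4/lambda_0 = 2.
for eta in (0.5, 1.5, 2.5):
    print(eta, learning_rate_phase(eta, lam0=2.0, c_arch=4.0))
```

The loop prints one phase per learning rate (lazy, catapult, divergent). Note that the rule only needs $\lambda_0$, which is why the phase can be predicted before training starts.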

2 Theoretical results

We now present our main theoretical result, an analysis of gradient descent dynamics for a neural network with large but finite width.

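Given a network function $f(x)$ with model parameters $\theta$, and a training set $\{(x_\alpha, y_\alpha)\}_{\alpha=1}^{m}$, the MSE loss is

$$L = \frac{1}{2m}\sum_{\alpha=1}^{m}\left(f(x_\alpha) - y_\alpha\right)^2. \tag{1}$$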

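The NTK is defined by

$$\Theta(x, x') = \frac{1}{m}\sum_{\mu}\frac{\partial f(x)}{\partial\theta^{\mu}}\,\frac{\partial f(x')}{\partial\theta^{\mu}}. \tag{2}$$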

We denote by $\lambda_t$ the maximum eigenvalue of the kernel, evaluated on the training set at step $t$. In large width models, $\lambda_t$ provides a local measure of the loss landscape curvature that is similar to the top eigenvalue of the Hessian (Dyer & Gur-Ari, 2020).

In this section, we will consider a network with one hidden layer and linear activations, where the network function is given by

$$f(x) = n^{-1/2}\, v^T u\, x. \tag{3}$$

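Here $n$ is the width (number of neurons in the hidden layer), $u \in \mathbb{R}^{n \times d}$ and $v \in \mathbb{R}^{n}$ are the model parameters (collectively denoted $\theta$), and $x \in \mathbb{R}^{d}$ is the training input. At initialization, the weights are drawn from $\mathcal{N}(0, 1)$.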
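For concreteness, here is a minimal pure-Python sketch of the model (3) and its NTK (2) on a small training set. Differentiating (3) gives $\partial f/\partial u_{ij} = n^{-1/2} v_i x_j$ and $\partial f/\partial v_i = n^{-1/2}(u x)_i$, so $\Theta(x, x') = \frac{1}{nm}\left(\|v\|^2 x^T x' + (u x)^T (u x')\right)$. The helper names (`init_params`, `network`, `ntk`, `top_eigenvalue`) are our own, and power iteration is simply one convenient way to obtain $\lambda_t$ for a small symmetric positive semi-definite matrix.

```python
import math
import random


def init_params(n, d, seed=0):
    """Draw u (n x d) and v (length n) i.i.d. from N(0, 1)."""
    rng = random.Random(seed)
    u = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return u, v


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def hidden(u, x):
    """u x, i.e. the hidden layer without the n^{-1/2} factor."""
    return [dot(row, x) for row in u]


def network(u, v, x):
    """f(x) = n^{-1/2} v^T u x, equation (3)."""
    return dot(v, hidden(u, x)) / math.sqrt(len(v))


def ntk(u, v, xs):
    """Theta_ab = (|v|^2 x_a.x_b + (u x_a).(u x_b)) / (n m), from equation (2)."""
    n, m = len(v), len(xs)
    v_sq = dot(v, v)
    ux = [hidden(u, x) for x in xs]
    return [[(v_sq * dot(xs[a], xs[b]) + dot(ux[a], ux[b])) / (n * m)
             for b in range(m)] for a in range(m)]


def top_eigenvalue(mat, iters=1000):
    """Largest eigenvalue of a symmetric PSD matrix via power iteration."""
    vec = [1.0] * len(mat)
    lam = 0.0
    for _ in range(iters):
        w = [dot(row, vec) for row in mat]
        lam = math.sqrt(dot(w, w))
        if lam == 0.0:
            return 0.0
        vec = [x / lam for x in w]
    return lam


u, v = init_params(n=500, d=3, seed=0)
xs = [[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]]
print(top_eigenvalue(ntk(u, v, xs)))  # lambda_0 for this toy training set
```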

2.1 Warmup: a simplified model

Before analyzing the dynamics of the model, we analyze a simpler setting which captures the most important aspects of the full solution. Consider a dataset with 1D inputs, and with a single training sample $x = 1$ with label $y = 0$. The network function evaluated on this input is then $f = n^{-1/2} v^T u$, with $u, v \in \mathbb{R}^n$, and the loss is $L = f^2/2$. The gradient descent equations at training step $t$ are

$$u_{t+1} = u_t - \eta\, n^{-1/2} f_t\, v_t, \qquad v_{t+1} = v_t - \eta\, n^{-1/2} f_t\, u_t. \tag{4}$$

Next, consider the update equations in function space. These can be written in terms of the Neural Tangent Kernel. For this model, the kernel evaluated on the training set is a scalar which is equal to $\lambda$, its top eigenvalue, and is given by

$$\lambda = n^{-1}\left(\|u\|^2 + \|v\|^2\right). \tag{5}$$

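At initialization, both $f$ and $\lambda$ scale as $n^0$ with width. The following update equations for $f$ and $\lambda$ at step $t$ can be derived from (4), using $u_t^T v_t = n^{1/2} f_t$:

$$f_{t+1} = \left(1 - \eta\lambda_t + \frac{\eta^2 f_t^2}{n}\right) f_t, \tag{6}$$

$$\lambda_{t+1} = \lambda_t + \frac{\eta f_t^2}{n}\left(\eta\lambda_t - 4\right). \tag{7}$$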

It is important to note that these are the exact update equations for this model, and that no higher-order terms were neglected. We now analyze these dynamical equations assuming the width is large. Two learning rates that will be important in the analysis are $\eta_{\rm crit} = 2/\lambda_0$ and $\eta_{\max} = 4/\lambda_0$. In terms of the notation introduced above, the architecture-dependent constant that determines the maximum learning rate in this model is $c_{\rm arch} = 4$.
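Because (6) and (7) are exact, they can be checked numerically against gradient descent on the parameters themselves. The sketch below assumes the warmup setup ($x = 1$, $y = 0$, $L = f^2/2$, weights from $\mathcal{N}(0,1)$); the function names and the choices $n = 200$, $\eta = 1.5$ are illustrative, and agreement is expected only up to floating-point round-off.

```python
import math
import random


def f_and_lambda(u, v):
    """f = n^{-1/2} u.v and lambda = (|u|^2 + |v|^2) / n."""
    n = len(u)
    f = sum(a * b for a, b in zip(u, v)) / math.sqrt(n)
    lam = (sum(a * a for a in u) + sum(b * b for b in v)) / n
    return f, lam


def gd_step_params(u, v, eta):
    """One gradient descent step (4) on L = f^2 / 2, using step-t values."""
    f, _ = f_and_lambda(u, v)
    c = eta * f / math.sqrt(len(u))
    return [a - c * b for a, b in zip(u, v)], [b - c * a for a, b in zip(u, v)]


def recursion_step(f, lam, eta, n):
    """Equations (6) and (7); both right-hand sides use step-t values."""
    f_next = (1.0 - eta * lam + eta * eta * f * f / n) * f
    lam_next = lam + eta * f * f / n * (eta * lam - 4.0)
    return f_next, lam_next


rng = random.Random(0)
n, eta = 200, 1.5
u = [rng.gauss(0.0, 1.0) for _ in range(n)]
v = [rng.gauss(0.0, 1.0) for _ in range(n)]
f, lam = f_and_lambda(u, v)
for t in range(10):
    u, v = gd_step_params(u, v, eta)
    f, lam = recursion_step(f, lam, eta, n)
    f_p, lam_p = f_and_lambda(u, v)
    print(t + 1, f_p - f, lam_p - lam)  # differences stay at round-off level
```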

2.1.1 Lazy phase

Taking the strict infinite width limit, equations (6) and (7) become

$$f_{t+1} = \left(1 - \eta\lambda_t\right) f_t, \qquad \lambda_{t+1} = \lambda_t. \tag{8}$$

When $\eta < \eta_{\rm crit}$, $\lambda_t$ remains constant throughout training. This is a special case of NTK dynamics, where the kernel is constant and the network evolves as a linear model (Lee et al., 2019). The function and the loss both shrink to zero because the multiplicative factor obeys $|1 - \eta\lambda_t| < 1$. This convergence happens in $O(1)$ steps.

2.1.2 Catapult phase

When $\eta_{\rm crit} < \eta < \eta_{\max}$, the loss diverges in the infinite width limit. Indeed, from (8) we see that the kernel is constant in the limit, while $f$ receives multiplicative updates $f_{t+1} = (1 - \eta\lambda_t) f_t$ where $|1 - \eta\lambda_t| > 1$. This is the well known instability of gradient descent dynamics for linear models with MSE loss. However, the underlying model is not linear in its parameters, and finite width contributions turn out to be important. We therefore relax the infinite width limit and analyze equations (6) and (7) for large but finite width $n$.

First, note that $\eta < 4/\lambda_0$ by assumption, and therefore the (additive) kernel updates in (7) are negative for all $t$: the first update is negative, so $\lambda_t$ only decreases and $\eta\lambda_t - 4 < 0$ continues to hold. During early training, $|f_t|$ grows (as in the infinite width limit) while $\lambda_t$ remains constant up to small $O(n^{-1})$ updates. After $t \sim \log(n)$ steps, $|f_t|$ grows to order $n^{1/2}$. At this point, the kernel updates are no longer negligible because $f_t^2/n$ is of order $n^0$. The kernel receives negative, non-negligible updates while both $|f_t|$ and the loss continue to grow (for now, we ignore the term in (6) with an explicit $n^{-1}$ dependence). This continues until the kernel is sufficiently small that the condition $\lambda_t \lesssim 2/\eta$ is met (the bound is not exact because of the term we neglected). We call this curvature-reduction effect the catapult effect. Beyond this point, $|1 - \eta\lambda_t| < 1$ holds, $|f_t|$ shrinks, and the loss converges to a global minimum. The $n$ dependence of the number of steps until optimization converges is $\log(n)$.

It remains to show that the term in (6) with an explicit $n^{-1}$ dependence does not affect these conclusions. Once $|f_t|$ grows to order $n^{1/2}$, this term is no longer negligible and can cause the multiplicative factor in front of $f_t$ to become smaller than 1 in absolute value, causing $|f_t|$ to start shrinking. However, once $|f_t|$ shrinks sufficiently this term again becomes negligible. Therefore, the loss will not converge to zero unless the curvature eventually drops below $2/\eta$. Conversely, notice that this term cannot cause $|f_t|$ to diverge for learning rates below $4/\lambda_0$. Indeed, if this were to happen then equation (7) would drive $\lambda_t$ to negative values, which is impossible because $\lambda \ge 0$ by (5). This completes the analysis in this phase.

Let us make a few comments about the catapult phase.

It is important for the analysis that we take a modified large width limit, in which the number of training steps grows like $\log(n)$ as $n$ becomes large. This is different from the large width limit commonly studied in the literature, in which the number of steps is kept fixed as the width is taken large. When using this modified limit, the analysis above holds even in the $n \to \infty$ limit. Note as well that the catapult effect takes place over $\log(n)$ steps, and for practical networks will occur within the first 100 steps or so of training.

In the catapult phase, the kernel at the end of training is smaller by an order $n^0$ amount compared with its value at initialization. The kernel provides a local measure of the loss curvature. Therefore, the minima that SGD finds in the catapult phase are flatter than those it finds in the lazy phase. Contrast this situation, in which the kernel receives non-negligible updates, with the conclusions of Jacot et al. (2018) where the kernel is constant throughout training. The difference is due to the large learning rate, which leads to a breakdown of the linearized approximation even at large width.

Figure 5 illustrates the dynamics in the catapult phase. For learning rates $\eta_{\rm crit} < \eta < \eta_{\max}$ we observe the catapult effect: the loss goes up before converging to zero. The curvature exhibits the expected sharp transitions as a function of the learning rate: it is constant in the lazy phase, decreases in the catapult phase, and diverges for $\eta > \eta_{\max}$.

Figure 5: Empirical results for the gradient descent dynamics of the warmup model with , for which . (a) Training loss for different learning rates. (b) Maximum NTK eigenvalue $\lambda_t$ as a function of time. For $\eta_{\rm crit} < \eta < \eta_{\max}$, $\lambda_t$ decreases rapidly to a fixed value. (c) Maximum NTK eigenvalue at . The shaded area indicates learning rates for which training diverges empirically. The results are presented as a function of (rather than ) for convenience.

2.1.3 Divergent phase

Completing the analysis of this model, when $\eta > \eta_{\max}$ the loss diverges because the kernel receives positive updates, accelerating the rate of growth of the function. Therefore, $\eta_{\max} = 4/\lambda_0$ is the maximum learning rate of the model.
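The three regimes of the warmup model can be reproduced by iterating (6) and (7) directly. In this sketch the starting point $f_0 = 0.5$, $\lambda_0 = 2$, $n = 1000$ is a hypothetical but valid initialization (every $u, v$ satisfy $\lambda \ge 2|f|/\sqrt{n}$), for which $2/\lambda_0 = 1$ and $4/\lambda_0 = 2$. The cutoff `blowup` used to declare divergence is an arbitrary choice.

```python
import math


def run_warmup(eta, f0=0.5, lam0=2.0, n=1000, steps=1000, blowup=1e8):
    """Iterate the exact warmup updates (6)-(7).

    Returns (losses, lambdas, diverged). Iteration stops early once |f|
    exceeds `blowup`, which we treat as divergence.
    """
    f, lam = f0, lam0
    losses, lams = [0.5 * f * f], [lam]
    for _ in range(steps):
        f, lam = ((1.0 - eta * lam + eta * eta * f * f / n) * f,
                  lam + eta * f * f / n * (eta * lam - 4.0))
        if not math.isfinite(f) or abs(f) > blowup:
            return losses, lams, True
        losses.append(0.5 * f * f)
        lams.append(lam)
    return losses, lams, False


for eta in (0.5, 1.5, 2.5):  # lazy, catapult, divergent for lambda_0 = 2
    losses, lams, diverged = run_warmup(eta)
    print(f"eta={eta}: peak loss={max(losses):.3g}, "
          f"final lambda={lams[-1]:.3g}, diverged={diverged}")
```

At $\eta = 0.5$ the loss decays immediately and $\lambda$ barely moves; at $\eta = 1.5$ the loss first rises by orders of magnitude while $\lambda$ falls below $2/\eta$, after which the loss converges; at $\eta = 2.5$ both $f$ and $\lambda$ blow up.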

2.2 Full model

We now turn to analyzing the model presented at the beginning of this section, with $d$-dimensional inputs and $m$ training samples with general labels. The full analysis is presented in SM Section D.1; here we summarize the argument. The conclusions are essentially the same as those of the warmup model.

We introduce the notation $f_\alpha := f(x_\alpha)$ for the function evaluated on a training sample, $e_\alpha := f_\alpha - y_\alpha$ for the error, and $\Theta_{\alpha\beta} := \Theta(x_\alpha, x_\beta)$ for the kernel elements. We will treat $f$, $e$, and $y$ evaluated on the training set as vectors in $\mathbb{R}^m$, whose elements are $f_\alpha$, $e_\alpha$, and $y_\alpha$. Consider the following update equation for the error, which can be derived from the update equations for the parameters. Note that this is the exact update equation for this model; no higher-order terms were neglected.

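$$e_\alpha^{t+1} = e_\alpha - \eta\sum_{\beta}\Theta_{\alpha\beta}\, e_\beta + \frac{\eta^2}{nm}\left(x_\alpha^T\zeta\right)\left(f^T e\right). \tag{9}$$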

Here, $\zeta := \frac{1}{m}\sum_{\beta} e_\beta\, x_\beta \in \mathbb{R}^d$, and all variables are implicitly evaluated at step $t$ unless specified otherwise.
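Since (9) is exact, it can be verified against one explicit gradient descent step on the loss (1). The gradients are $\partial L/\partial u = n^{-1/2} v\,\zeta^T$ and $\partial L/\partial v = n^{-1/2} u\,\zeta$. The sketch below is a hypothetical check in pure Python (helper names and random data are ours); the predicted and actual errors should agree to round-off.

```python
import math
import random


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def forward(u, v, x):
    """f(x) = n^{-1/2} v^T u x."""
    return dot(v, [dot(row, x) for row in u]) / math.sqrt(len(v))


def zeta_of(errs, xs):
    """zeta = (1/m) sum_beta e_beta x_beta."""
    m, d = len(xs), len(xs[0])
    return [sum(e * x[j] for e, x in zip(errs, xs)) / m for j in range(d)]


def gd_step(u, v, xs, ys, eta):
    """One full-batch gradient descent step on the MSE loss (1)."""
    errs = [forward(u, v, x) - y for x, y in zip(xs, ys)]
    zeta = zeta_of(errs, xs)
    s = eta / math.sqrt(len(v))
    new_u = [[uij - s * vi * zj for uij, zj in zip(row, zeta)]
             for row, vi in zip(u, v)]
    new_v = [vi - s * dot(row, zeta) for row, vi in zip(u, v)]
    return new_u, new_v


def predicted_errors(u, v, xs, ys, eta):
    """Right-hand side of equation (9), evaluated at step t."""
    n, m = len(v), len(xs)
    f = [forward(u, v, x) for x in xs]
    errs = [fa - y for fa, y in zip(f, ys)]
    zeta = zeta_of(errs, xs)
    v_sq = dot(v, v)
    ux = [[dot(row, x) for row in u] for x in xs]
    theta = [[(v_sq * dot(xs[a], xs[b]) + dot(ux[a], ux[b])) / (n * m)
              for b in range(m)] for a in range(m)]
    f_dot_e = dot(f, errs)
    return [errs[a] - eta * dot(theta[a], errs)
            + eta ** 2 / (n * m) * dot(xs[a], zeta) * f_dot_e
            for a in range(m)]


rng = random.Random(0)
n, d, m, eta = 100, 3, 4, 0.8
u = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
v = [rng.gauss(0.0, 1.0) for _ in range(n)]
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
ys = [rng.choice([-1.0, 1.0]) for _ in range(m)]
pred = predicted_errors(u, v, xs, ys, eta)
u2, v2 = gd_step(u, v, xs, ys, eta)
actual = [forward(u2, v2, x) - y for x, y in zip(xs, ys)]
print(max(abs(p - a) for p, a in zip(pred, actual)))  # round-off level
```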

We again take the modified large width limit $n \to \infty$, allowing the number of steps to scale logarithmically in the width. At initialization, $\Theta$, $f$, and $\zeta$ are all of order $n^0$. We now analyze the gradient descent dynamics as a function of the learning rate.

The maximum eigenvalue of the kernel at step $t$ is $\lambda_t$. When $\eta < 2/\lambda_0$, the norm $\|e\|$ shrinks to zero in $O(1)$ time while the kernel receives $O(n^{-1})$ corrections. Therefore, in the $n \to \infty$ limit the kernel remains constant until convergence. This is a special case of the NTK result (Jacot et al., 2018), and the model evolves as a linear model.

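Next, suppose that $2/\lambda_0 < \eta < 4/\lambda_0$. Early during training $e$ grows, with the fastest growth taking place along the direction of the top kernel eigenvector, $\xi$ (normalized so that $\|\xi\| = 1$). During this part of training the kernel receives $O(n^{-1})$ updates, and so does not change much. As a result, $e$ becomes aligned with $\xi$. In addition, $f$ becomes close to $e$ because $e$ grows while the label $y$ is constant. We therefore consider the following approximate update equations for $e_\xi := \xi^T e$ and for the maximum eigenvalue $\lambda$, which can be approximated by $\lambda \approx \xi^T\Theta\,\xi$. Writing $x_\xi := \sum_\alpha \xi_\alpha x_\alpha$, we find

$$e_\xi^{t+1} \approx \left(1 - \eta\lambda_t + \frac{\eta^2\|x_\xi\|^2}{nm^2}\, e_\xi^2\right) e_\xi, \tag{10}$$

$$\lambda_{t+1} \approx \lambda_t + \frac{\eta\|x_\xi\|^2}{nm^2}\, e_\xi^2\left(\eta\lambda_t - 4\right). \tag{11}$$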

We note in passing the similarity between these equations and (6), (7). We see that once $e_\xi$ and $f$ become of order $n^{1/2}$, $\lambda$ receives non-negligible negative corrections of order $n^0$. This evolution continues until $\lambda \lesssim 2/\eta$, after which the error converges to zero. Finally, if $\eta > 4/\lambda_0$, the error grows while $\lambda$ receives positive updates, and the loss diverges. This concludes the discussion of the theoretical model; further details can be found in Section 4 and in SM Section D.1.

3 Experimental results

In this section we test the extent to which the behavior of our theoretical model describes the dynamics of deep networks in practical settings. The theoretical results of Section 2, describing distinct learning rate phases, are not guaranteed to hold beyond the model analyzed there. We treat these results as predictions to be tested empirically, including the values $\eta_{\rm crit}$ and $\eta_{\max}$ of the learning rates that separate the three phases.

In a variety of deep learning settings, we find clear evidence of the different phases predicted by the model. The experiments all use MSE loss, sufficiently wide networks, and SGD (while our theoretical framework focused on full-batch gradient descent, we expect the phases to occur at similar points for SGD as long as evolution is not noise dominated; when it is noise dominated, we expect all phases to be shifted towards smaller learning rates). Parameters such as network architecture, choice of non-linearity, weight parameterization, and regularization do not significantly affect this conclusion.

In terms of the learning rates that determine the location of the transitions, the only modification needed to obtain good agreement with experiment is to replace the theoretical maximum learning rate, $\eta_{\max} = 4/\lambda_0$, with a 1-parameter function $\eta_{\max} = c_{\rm arch}/\lambda_0$, where $c_{\rm arch}$ is an architecture-dependent constant. We find that $c_{\rm arch} \approx 12$ for all networks that use ReLU non-linearity, and this parameter seems to depend only weakly on other details of the architecture. We find the level of agreement with the experiments surprising, given that our theoretical model involves a shallow network without non-linearities.

Building on the observed correlation between lower curvature and generalization performance (Keskar et al., 2016; Jiang et al., 2020), we conjecture that optimal performance occurs in the large learning rate (catapult) phase, where the loss converges to a flatter minimum. For a fixed computational budget, we find that this conjecture holds in all cases we tried. Even when comparing different learning rates trained for a fixed amount of physical time $t \cdot \eta$, we find that performance of models trained in the catapult phase either matches or exceeds that of models trained in the lazy phase.

3.1 Early time curvature dynamics

Our theoretical model makes detailed predictions for the gradient descent evolution of $\lambda_t$, the top eigenvalue of the NTK. Here we test these predictions against empirical results in a variety of deep learning models (see the Supplement for additional experimental results).

Figure 12 shows $\lambda_t$ during the early part of training for two deep learning settings. The results are compared against the theoretical predictions of a phase transition at $\eta_{\rm crit} = 2/\lambda_0$, and a maximum learning rate of $\eta_{\max} \approx 12/\lambda_0$. Here $\lambda_0$ is the top eigenvalue of the empirical NTK at initialization.

For learning rates $\eta < \eta_{\rm crit}$, we find that $\lambda_t$ is independent of the learning rate and constant throughout training, as expected in the lazy phase. For $\eta_{\rm crit} < \eta < \eta_{\max}$ we find that $\lambda_t$ decreases during training to below $2/\eta$, matching the predicted behavior in the catapult phase (note that in the Wide ResNet example, $\lambda_t$ initially increases before reaching its stable value).

The large learning rate behavior predicted by the model appears to persist up to the maximum learning rate, which is larger in these experiments than in the theoretical model. In these and other experiments involving ReLU networks, we find that $\eta_{\max} \approx 12/\lambda_0$ is a good predictor of the maximum learning rate (in SM Section C.4 we discuss other non-linearities). We conjecture that this is the typical maximum learning rate of networks with ReLU non-linearities.

Figure 12 also shows the loss initially increasing before converging in the catapult phase, confirming another prediction of the model. This transient behavior is very short, taking less than 10 steps to complete.

Figure 12: Early time dynamics. (a,b,c) A 3 hidden layer fully-connected network with ReLU non-linearity trained on MNIST (). (d,e,f) Wide ResNet 28-10 trained on CIFAR-10 (). Both networks are trained with vanilla SGD; for more experimental details see SM Section A. (a,d) Early time dynamics of the training loss for learning rates in the linear and catapult phases. (b,e) Early time dynamics of the curvature for learning rates in the linear and catapult phases. (c,f) $\lambda_t$ measured at (for FC) and (for WRN), as a function of learning rate, compared with theoretical predictions for the locations of phase transitions. Training diverges for learning rates in the shaded region.

3.2 Generalization performance

We now consider the performance of trained models in the different phases discussed in this work. Keskar et al. (2016) observed a correlation between the flatness of a minimum found by SGD and the generalization performance (see Jiang et al. (2020) for additional empirical confirmation of this correlation). In this work, we showed that the minima SGD finds are flatter in the catapult phase, as measured by the top kernel eigenvalue. Our measure of flatness differs from that of Keskar et al. (2016), but we expect that these measures are correlated.

We therefore conjecture that optimal performance is often obtained for learning rates above $\eta_{\rm crit}$ and below the maximum learning rate.

In this section we test this conjecture empirically. We find that performance in the large learning rate range always matches or exceeds the performance when $\eta < \eta_{\rm crit}$. For a fixed compute budget, we find that the best performance is always found in the catapult phase.

Figure 14 shows the accuracy as a function of the learning rate for a fully-connected ReLU network trained on a subset of MNIST. We find that the optimal performance is achieved above $\eta_{\rm crit}$ and close to $\eta \approx 12/\lambda_0$, the expected maximum learning rate.

Figure 14: Final accuracy versus learning rate for a fully-connected 1 hidden layer ReLU network, trained on 512 samples of MNIST with full-batch gradient descent until training accuracy reaches 1 or 700k physical steps (see SM Section A for details). We used a subset of samples to accentuate the performance difference between phases. The optimal performance is obtained when the learning rate is above $\eta_{\rm crit}$ and close to $\eta_{\max}$.

Next, Figure 19 shows the performance of a convolutional network and a Wide ResNet (WRN) trained on CIFAR-10. The experimental setup, which we now describe, was chosen to ensure a fair comparison of the performance across different learning rates. The network is trained with different initial learning rates, followed by a decay at a fixed physical time to the same final learning rate. This schedule is introduced in order to ensure that all experiments have the same level of SGD noise toward the end of training.
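The comparison protocol above can be expressed as a learning rate schedule indexed by physical time $t \cdot \eta$. A minimal sketch, assuming a single step decay: the names `lr_at_step` and `steps_for_physical_time` and the example values are hypothetical and are not the settings used in the experiments (those are given in SM Section A).

```python
import math


def lr_at_step(step: int, eta0: float, eta_final: float, t_decay: float) -> float:
    """Learning rate under a decay-at-fixed-physical-time rule.

    Physical time after `step` steps at rate eta0 is step * eta0; once it
    reaches t_decay, every run switches to the same final rate eta_final.
    """
    return eta0 if step * eta0 < t_decay else eta_final


def steps_for_physical_time(eta: float, t_phys: float) -> int:
    """Number of steps needed to cover physical time t_phys at rate eta."""
    return math.ceil(t_phys / eta)


# Larger initial learning rates reach the decay point in fewer steps.
for eta0 in (0.5, 2.0):
    print(eta0, steps_for_physical_time(eta0, t_phys=10.0))  # 20 and 5 steps
```

Indexing the decay by physical time rather than by step count is what makes runs with different initial learning rates comparable: all of them spend the same physical time at the shared final rate, and hence see the same level of SGD noise at the end of training.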

Figure 19: Test accuracy vs learning rate for (a,b) a CNN trained on CIFAR-10 using SGD with batch size 256 and regularization () and (c,d) WRN28-10 trained on CIFAR-10 using SGD with batch size 1024, regularization, and data augmentation (); see SM Section A for details. (a,c) have a fixed compute budget: (a) 437k steps and (c) 12k steps. (b,d) have been evolved for a fixed amount of physical time: (b) was evolved for 475 steps (purple) and evolved for 50k more steps at learning rate (red) and (d) was evolved for steps with learning rate (purple) and then evolved for 4800 more steps at learning rate (red). In all cases, optimal performance is achieved above $\eta_{\rm crit}$ and close to the expected maximum learning rate, in agreement with our predictions.

We present results using two different stopping conditions. In Figures 19a and 19c, all models were trained for a fixed number of training steps. We find a significant performance gap between small and large learning rates, with the optimal learning rate above $\eta_{\rm crit}$ and close to $\eta_{\max} \approx 12/\lambda_0$. Beyond this learning rate, performance drops sharply.

The fixed compute stopping condition, while of practical interest, biases the results in favor of large learning rates. Indeed, in the limit of small learning rate, training for a fixed number of steps will keep the model close to initialization. To control for this, in Figures 19b and 19d models were trained for the same amount of physical time $t \cdot \eta$. For the CNN of Figure 19b, decaying the learning rate does not have a significant effect on performance: we observe that performance is flat up to $\eta_{\max}$, and there is no correlation between our measure of curvature and generalization performance. Figure 19d shows the analogous experiment for WRN. When decaying the learning rate toward the end of training to control for SGD noise, we find that optimal performance is achieved above $\eta_{\rm crit}$. In all these cases, $12/\lambda_0$ is a good predictor of the maximal learning rate, despite significant differences in the architectures. Notice that by tuning the learning rate to the catapult phase, we are able to achieve performance using MSE loss, and without momentum, that is competitive with the best reported results for this model (Zagoruyko & Komodakis, 2016).

In SM B.1, we present additional results for WRN on CIFAR-100, with similar conclusions as those for WRN on CIFAR-10.

4 Additional properties of the model

So far we have focused on the generalization performance and curvature of the large learning rate phase. Here we investigate additional predictions made by our model.

4.1 Restoration of linear dynamics

One striking prediction of the model is that after a period of excursion, the logit differences settle back to $O(1)$ values, the NTK stops changing, and evolution is again well approximated by a linear model with constant kernel at large width.

We speculate that the return to linearity and constancy of the kernel may hold asymptotically in width for more general models for a range of learning rates above $\eta_{\rm crit}$. We test this by evolving the model for order $\log(n)$ steps until the catapult effect is over, linearizing the model, and comparing the evolution of the two models beyond this point. Figure 20 shows an example of this. At fixed width, the accuracy of the linear and non-linear networks match for a range of learning rates above the transition up to . We present additional evidence for this asymptotic linearization behavior in the Supplement.
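In the warmup model of Section 2.1 this restoration of linear dynamics can be seen directly. Linearizing at step $t_0$ freezes the kernel at $\lambda_{t_0}$, so the linearized function evolves as $f_{t+1} = (1 - \eta\lambda_{t_0}) f_t$. The sketch below compares that against the exact updates (6)-(7) for $\eta = 1.5$, using a hypothetical initialization $f_0 = 0.5$, $\lambda_0 = 2$, $n = 1000$ (so $2/\lambda_0 = 1$) and illustrative helper names: linearizing at initialization diverges, while linearizing after the catapult tracks the full model.

```python
def exact_step(f, lam, eta, n):
    """Exact warmup updates (6)-(7)."""
    return ((1.0 - eta * lam + eta * eta * f * f / n) * f,
            lam + eta * f * f / n * (eta * lam - 4.0))


def linearize_and_compare(eta, t0, extra, f0=0.5, lam0=2.0, n=1000):
    """Run the exact model for t0 steps, then evolve both the exact model and
    its linearization (kernel frozen at lambda_{t0}) for `extra` more steps.
    Returns (f_exact, f_linear) at the end."""
    f, lam = f0, lam0
    for _ in range(t0):
        f, lam = exact_step(f, lam, eta, n)
    f_lin, lam_frozen = f, lam
    for _ in range(extra):
        f, lam = exact_step(f, lam, eta, n)
        f_lin = (1.0 - eta * lam_frozen) * f_lin
    return f, f_lin


for t0 in (0, 40):
    f_exact, f_lin = linearize_and_compare(eta=1.5, t0=t0, extra=30)
    print(t0, f_exact, f_lin)
```

With $t_0 = 0$ the frozen factor is $1 - 1.5 \cdot 2 = -2$, so the linearized output doubles in magnitude every step, whereas with $t_0 = 40$ the curvature has already dropped below $2/\eta$ and the two trajectories essentially coincide.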

Figure 20: Evidence for linear dynamics after the catapult effect is over. Here we show the same model as in Figure 14 with the addition of models linearized at step and another linearized at step . We observe that the model linearized after steps tracks the non-linear performance in the catapult phase up to .

4.2 Non-perturbative phase transition

The large width analysis of the small learning rate phase has been the subject of much work. In this phase, at infinite width, the network map evolves as a linear random features model, $f_t(x) = f_t^{\rm lin}(x)$, where $f_t^{\rm lin}(x) = f_0(x) + \nabla_\theta f_0(x)^T(\theta_t - \theta_0)$ is the function of the linearized model. At large but finite width, corrections to this linear evolution can be systematically incorporated via a perturbative expansion (Taylor expansion) around infinite width (Dyer & Gur-Ari, 2020; Huang & Yau, 2019).

$$f_t(x) = f_t^{\rm lin}(x) + \frac{1}{n} f_t^{(1)}(x) + \frac{1}{n^2} f_t^{(2)}(x) + \cdots \tag{12}$$

The evolution equations (10) and (11) of the solvable model are an example of this. At large width and in the small learning rate phase, the $O(n^{-1})$ terms are suppressed for all times. In contrast, the leading order dynamics of $f^{\rm lin}$ diverge when $\eta > \eta_{\rm crit}$, and so the true evolution cannot be described by the linear model. Indeed, the logits grow to $O(n^{1/2})$ and thus all terms in (10) and (11) are of the same order. Similarly, the growth observed empirically in the catapult phase for more general models cannot be described by truncating the series (12) at any order, because the terms all become comparable.

5 Discussion

In this work we took a step toward understanding the role of large learning rates in deep learning. We presented a dynamical mechanism that allows deep networks to be trained at larger learning rates than those accessible to their linear counterparts. For MSE loss, linear model training diverges when the learning rate is above the critical value $\eta_{\rm crit} = 2/\lambda_0$, where $\lambda_0$ is the curvature at initialization. We showed that deep networks can train for larger learning rates by navigating to an area of the landscape that has sufficiently low curvature. Perhaps counterintuitively, training in this regime involves an initial period during which the loss increases before converging to its final, small value. We call this the catapult effect.

5.1 A tractable model illustrating catapult dynamics

These observations are made concrete in our theoretical model, where we fully analyze the gradient descent dynamics as a function of the learning rate. The analysis involves a modified large width limit, in which both the width and training time are taken to be large. Sweeping the learning rate from small to large, and working in the $n \to \infty$ limit, we find sharp transitions from a lazy phase where linearized model training is stable, to a catapult phase in which only the full model converges, and finally to a divergent phase in which training is unstable. These transitions have the hallmarks of phase transitions that commonly appear in physical systems such as ferromagnets or water as one changes parameters such as temperature. In particular, these transitions are non-perturbative: a Taylor expansion around the linearized model that takes finite width corrections into account is not sufficient to describe the behavior beyond the critical learning rate.
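The three phases can be reproduced in a few lines under the same kind of assumptions: the warmup model $f = n^{-1/2} v^\top u$ with $x = 1$, label $y = 0$, $\lambda = (\|u\|^2 + \|v\|^2)/n$, and exact $(f_t, \lambda_t)$ updates that we derived from gradient descent on $u$ and $v$. The hypothetical helper `classify_phase` labels a run as divergent if the logit blows up, as lazy if it converges with the curvature essentially unchanged, and as catapult if it converges only after the curvature has dropped. The thresholds (`tol`, `blowup`, the 1% curvature change) and the width are arbitrary choices for this sketch.

```python
import math


def warmup_step(f, lam, eta, n, y=0.0):
    """One exact gradient descent step of the assumed warmup model, in terms of (f, lambda)."""
    e = f - y
    f_next = f - eta * lam * e + (eta * eta / n) * e * e * f
    lam_next = lam + (eta * e / n) * (eta * lam * e - 4.0 * f)
    return f_next, lam_next


def classify_phase(eta, lam0, n=10**6, f0=1.0, steps=2000, tol=1e-6, blowup=1e12):
    """Label a training run as 'lazy', 'catapult', 'divergent' (or 'unconverged')."""
    f, lam = f0, lam0
    for _ in range(steps):
        f, lam = warmup_step(f, lam, eta, n)
        if not math.isfinite(f) or abs(f) > blowup:
            return "divergent"
    if abs(f) > tol:
        return "unconverged"
    # Lazy: the kernel barely moved. Catapult: convergence required a large curvature drop.
    return "lazy" if abs(lam - lam0) < 1e-2 * lam0 else "catapult"


sweep = [0.5, 1.0, 1.5, 1.9, 2.5, 3.0, 3.5, 4.5, 5.0]  # values of eta * lam0 (lam0 = 1)
phases = [classify_phase(eta, lam0=1.0) for eta in sweep]
for eta, phase in zip(sweep, phases):
    print(f"eta * lam0 = {eta:.1f}: {phase}")
```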

We derive the learning rates at which these transitions occur as a function of the curvature at initialization. We then treat these theoretical results as predictions, to be tested beyond the regime where they are guaranteed to hold, and find good quantitative agreement with empirical results across a variety of realistic deep learning settings.

We find it striking that a relatively simple theoretical model can correctly predict the behavior of realistic deep learning models. In particular, we conjecture that the maximum learning rate is typically a simple function of the curvature at initialization, with a single parameter that seems to depend only on the non-linearity. For ReLU networks, we conjecture that the maximum learning rate is approximately $12/\lambda_0$, which we confirm in many cases.
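In practice, the predicted thresholds reduce to a lookup on the measured curvature $\lambda_0$. The following is a minimal sketch in which `predicted_learning_rates` and `predicted_phase` are hypothetical helpers. The values $2/\lambda_0$ and $4/\lambda_0$ are the solvable model's thresholds, while the ReLU factor of 12 is only the empirical conjecture stated above, not a guarantee.

```python
def predicted_learning_rates(lam0, relu=False):
    """Return (eta_crit, eta_max) predicted from the top NTK eigenvalue lam0 at initialization."""
    if lam0 <= 0:
        raise ValueError("lam0 must be positive")
    eta_crit = 2.0 / lam0  # above this the linearized model diverges
    # 4/lam0: maximum for the solvable model; 12/lam0: empirical conjecture for ReLU networks
    eta_max = (12.0 if relu else 4.0) / lam0
    return eta_crit, eta_max


def predicted_phase(eta, lam0, relu=False):
    """Classify a learning rate as 'lazy', 'catapult' or 'divergent' using the thresholds above."""
    eta_crit, eta_max = predicted_learning_rates(lam0, relu)
    if eta < eta_crit:
        return "lazy"
    if eta < eta_max:
        return "catapult"
    return "divergent"


print(predicted_learning_rates(4.0))             # (0.5, 1.0)
print(predicted_learning_rates(4.0, relu=True))  # (0.5, 3.0)
print(predicted_phase(2.0, 4.0), predicted_phase(2.0, 4.0, relu=True))  # divergent catapult
```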

5.2 Reducing misalignment of activations and gradients

The catapult dynamics for the simplified model in Section 2.1 reduce curvature by shrinking the component of the first layer weights $u$ that is orthogonal to the second layer weights $v$, and shrinking the component of the second layer weights $v$ that is orthogonal to the first layer weights $u$. We can rewrite the simplified model in terms of a hidden layer $h = n^{-1/2} u x$, where $f = v^\top h$. The gradient with respect to this hidden layer is $\partial L/\partial h \propto (f - y)\, v$. These hidden layer gradients thus point in the same direction as $v$, while the hidden activations $h$ point in the same direction as $u$. An alternative interpretation of the catapult dynamics is then that they reduce the components of $h$ and $\partial L/\partial h$ that are orthogonal to each other. The catapult dynamics thus serve, in this simplified model, to reduce the misalignment between feedforward activations $h$ and backpropagated gradients $\partial L/\partial h$. We hypothesize that this reduction of misalignment between activations and gradients may be a feature of large learning rates and catapult dynamics in deep, as well as shallow, networks. We further hypothesize that it may play a directly beneficial role in generalization, for instance by making the model output less sensitive to orthogonal, out-of-distribution perturbations of activations.

5.3 Catapult dynamics often improve generalization

Our results shed light on the regularizing effect of training at large learning rates. The effect presented here is independent of the regularizing effect of stochastic gradient noise, which has been studied extensively. Building on previous works, we noted the observed correlation between flatness and generalization performance. Based on these observations, we expect the optimal performance to often occur for learning rates larger than $\eta_{\mathrm{crit}} = 2/\lambda_0$, where the linearized model is unstable. Observing this effect required controlling for several confounding factors that affect the comparison of performance between different learning rates. Under a fair comparison, and also for a fixed compute budget, we find that this expectation holds in practice.

5.4 Beyond infinite linear models

One outcome of our work is to address the performance gap between ordinary neural networks and linear models inspired by the theory of wide networks. Optimal performance is often obtained at large learning rates, which are inaccessible to linearized models. In such cases, we expect the performance gap to persist even at arbitrarily large widths. We hope our work can further improve the understanding of deep learning methods.

5.5 Other open questions

There are several remaining open questions. While the model predicts a maximum learning rate of $4/\lambda_0$, for models with ReLU activations we find that the maximum learning rate is consistently higher. This may be due to a separate dynamical curvature-reduction mechanism that relies on ReLU. In addition, we do not explore the degree to which our results extend to softmax classification. While we expect qualitatively similar behavior there, the non-constant Hessian of the softmax cross entropy makes controlled experiments more challenging. Similarly, behavior for other optimizers such as SGD with momentum may differ. For example, the maximum learning rate when training a linear model is larger for gradient descent with momentum than for vanilla gradient descent, and therefore the transition to the catapult phase (if it exists) will occur at a higher learning rate. We leave these questions to future work.

Acknowledgements

The authors would like to thank Kyle Aitken, Dar Gilboa, Justin Gilmer, Boris Hanin, Tengyu Ma, Andrea Montanari, and Behnam Neyshabur for useful discussions. We would also like to thank Jaehoon Lee for early discussions about empirical properties of the lazy phase.

References

Supplementary materials

Appendix A Experimental details

We use JAX (Bradbury et al., 2018) and the Neural Tangents library (Novak et al., 2020) for our experiments.

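All models are trained with mean squared error, normalized as $L = \frac{1}{2km}\sum_{a=1}^{m}\sum_{i=1}^{k}\big(f_i(x_a) - y_{a,i}\big)^2$, where $k$ is the number of classes, $m$ is the number of samples, and $y_a$ are one-hot targets.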

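In a similar way, we normalize the NTK as $\Theta_{ai,bj} = \frac{1}{km}\sum_{\mu}\partial_\mu f_i(x_a)\,\partial_\mu f_j(x_b)$, so that the eigenvalues of the NTK are the same as the non-zero eigenvalues of the Fisher information $F_{\mu\nu} = \frac{1}{km}\sum_{a,i}\partial_\mu f_i(x_a)\,\partial_\nu f_i(x_a)$.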
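This normalization can be checked numerically. The sketch below assumes NumPy, the loss $L = \frac{1}{2km}\|J\theta - y\|^2$ for a linear model, and $\Theta = JJ^\top/(km)$, where $J$ is the $(km) \times P$ Jacobian. A random matrix stands in for the Jacobian $\partial_\mu f_i(x_a)$; for a real network it would come from automatic differentiation (for example in JAX). It checks two things: the nonzero spectra of the $km \times km$ NTK and the $P \times P$ Fisher matrix coincide, and with the $1/(km)$ factors one gradient descent step changes the error by exactly $-\eta\,\Theta e$, so $2/\lambda_{\max}(\Theta)$ is the linear stability threshold.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, P = 8, 3, 50  # samples, classes, parameters (illustrative sizes)

# Random stand-in for the Jacobian J[a, i, mu] = d f_i(x_a) / d theta_mu, flattened to (m*k, P)
J = rng.normal(size=(m, k, P)).reshape(m * k, P)

ntk = J @ J.T / (k * m)     # (km x km) NTK with the 1/(km) normalization
fisher = J.T @ J / (k * m)  # (P x P) Fisher (Gauss-Newton) matrix of the normalized MSE

ntk_eigs = np.sort(np.linalg.eigvalsh(ntk))[::-1]
fisher_eigs = np.sort(np.linalg.eigvalsh(fisher))[::-1]
print("nonzero spectra match:", np.allclose(ntk_eigs, fisher_eigs[: m * k]))

# One gradient descent step on the linear model f = J theta with L = ||J theta - y||^2 / (2km)
theta = rng.normal(size=P)
y = np.eye(k)[rng.integers(0, k, size=m)].reshape(-1)  # flattened one-hot targets
e = J @ theta - y
grad = J.T @ e / (k * m)
eta = 0.1
e_next = J @ (theta - eta * grad) - y
print("error update is e - eta * Theta e:", np.allclose(e_next, e - eta * ntk @ e))
```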

In our experiments we measure the top eigenvalue of the NTK using Lanczos’ algorithm. We construct the NTK on a small batch of data, typically several hundred samples, compute the top eigenvalue, and then average over batches. In this work, we do not focus on precision aspects such as fluctuations in the top eigenvalue across batches.
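A matrix-free version of this measurement might look as follows. This is a sketch assuming NumPy: a plain Lanczos iteration (with full reorthogonalization, our choice for numerical robustness) is applied to the NTK product $\Theta v = J(J^\top v)/(km)$. Each batch is represented by a flattened Jacobian of shape $(km, P)$. The function names are hypothetical and not taken from the paper's code.

```python
import numpy as np


def lanczos_top_eigenvalue(matvec, dim, num_iters=30, seed=0):
    """Largest eigenvalue of a symmetric operator via Lanczos with full reorthogonalization."""
    rng = np.random.default_rng(seed)
    q = rng.normal(size=dim)
    basis = [q / np.linalg.norm(q)]
    alphas, betas = [], []
    n_iters = min(num_iters, dim)
    for j in range(n_iters):
        w = matvec(basis[-1])
        alphas.append(float(basis[-1] @ w))
        for b in basis:  # Gram-Schmidt against all previous Lanczos vectors
            w = w - (b @ w) * b
        beta = float(np.linalg.norm(w))
        if j == n_iters - 1 or beta < 1e-10:
            break
        betas.append(beta)
        basis.append(w / beta)
    # Tridiagonal projection T = Q^T A Q; its top eigenvalue approximates that of A.
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
    return float(np.linalg.eigvalsh(T)[-1])


def make_ntk_matvec(J, k):
    """Matrix-free NTK product Theta v = J (J^T v) / (k m) for a flattened Jacobian of shape (k*m, P)."""
    m = J.shape[0] // k

    def matvec(v):
        return J @ (J.T @ v) / (k * m)

    return matvec


def average_top_ntk_eigenvalue(jacobian_batches, k, num_iters=30):
    """Average the top NTK eigenvalue over batches (one flattened Jacobian per batch)."""
    tops = [lanczos_top_eigenvalue(make_ntk_matvec(J, k), J.shape[0], num_iters)
            for J in jacobian_batches]
    return float(np.mean(tops))


rng = np.random.default_rng(1)
batches = [rng.normal(size=(3 * 8, 60)) for _ in range(4)]  # k = 3 classes, m = 8 samples per batch
lam0 = average_top_ntk_eigenvalue(batches, k=3)
print("batch-averaged top NTK eigenvalue:", lam0)
```

Here each batch is small enough that Lanczos runs to the full Krylov dimension; for realistic sizes only a few dozen iterations are typically needed for the top eigenvalue.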

All experiments that compare different learning rates use the same seed for the weights at initialization, and we consider only one such initialization (unless otherwise stated), although we have not seen much variance in the phenomena described. We let $\sigma_w$ and $\sigma_b$ denote the constant (width-independent) coefficients of the standard deviation of the weight and bias initializations, respectively.

Here we describe experimental settings specific to a figure.

Figure (a)a,(b)b,(c)c. Fully connected network with three hidden layers and ReLU non-linearity, trained on MNIST using SGD (no momentum) with NTK normalization.

Figures (d)d,(e)e,(f)f. Wide ResNet 28-18 trained on CIFAR10 with SGD (no momentum) and LeCun initialization.

Figures 14, 20. Fully connected network with one hidden layer and ReLU non-linearity, trained on 512 samples of MNIST with SGD (no momentum) and NTK initialization.

Figures (a)a,(b)b. The convolutional network combines convolutional layers with 'SAME' or 'VALID' padding, MaxPool((2,2), 'VALID') layers, which perform max pooling with 'VALID' padding and a (2,2) window size, and fully connected layers. LeCun initialization is used for the weights and biases. The network is trained on CIFAR-10 with SGD, a batch size of 256 and L2 regularization of 0.001.

Figures 1, (c)c,(d)d. Wide ResNet on CIFAR10 using SGD (no momentum), trained on v3-8 TPUs. All models use L2 regularization and LeCun initialization. We also use data augmentation: flips, crops and mixup. Compared with the test accuracy these models reach with softmax classification and cosine decay, we do not observe a large performance drop from using MSE. Furthermore, we use JAX's implementation of Batch Norm, which does not keep track of training batch statistics for test mode evaluation. We have not tuned hyperparameters for the learning rates, nor the regularization parameter.

Figures S4,S6. Wide ResNet on CIFAR100 using SGD (no momentum). Same setting as figure (c)c, (d)d except for the different dataset, different L2 regularization and label smoothing (we have subtracted from the target one-hot labels).

Figure S18. Two-hidden-layer ReLU network trained on a single data point.

Figure S27. Fully connected network with two hidden layers and tanh non-linearity, trained on MNIST with SGD (no momentum) and LeCun initialization.

Figure (a)a. Two-hidden-layer fully connected network trained on MNIST with NTK normalization, using both SGD with momentum and vanilla SGD, for three different non-linearities: tanh, ReLU and identity (no non-linearity). The learning rate was chosen separately for each non-linearity.

Rest of SM figures. Small modifications of experiments in previous figures, specified in captions.


Figure S1: Visualization of training dynamics in all three phases. In the lazy phase, the network is approximately linear in its parameters and converges exponentially to a global minimum. In the catapult phase, the loss initially grows, while the weight norm and curvature decrease. Once the curvature is low enough, optimization converges. In the divergent phase, both the loss and parameter magnitudes diverge. (a)-(d) Loss surface and training dynamics visualized in a 2d linear subspace. The network has a single hidden layer, linear activations, and is trained with MSE loss on a single 1D sample. The parameter subspace is spanned by two orthonormal vectors, and the weight vectors are written in terms of two coordinates in this subspace. If initialized in this 2d subspace, the weight vectors remain in the subspace throughout training, so the training dynamics can be fully visualized with a two-dimensional plot. (e) Visualization of the loss surface and training dynamics in terms of a nonlinear reparameterization with interpretable axes: the x-axis is the correlation between the weight vectors and the y-axis is the curvature. The trajectory shown is identical to that in (c), and in Figure 1.

Appendix B Experimental results: Late time performance

b.1 CIFAR-100 performance

We can also repeat the performance experiments for CIFAR-100 with the same Wide ResNet 28-10 setup. In this case, with MSE and SGD the system must be evolved for longer, which requires a smaller L2 regularization; we did not tune this value. With only one learning rate decay, we get close to the performance of Zagoruyko & Komodakis (2016), who used softmax classification and two learning rate decays. However, longer evolution is needed: we found that different learning rates converge only after a large number of physical epochs. As in the main text experiments, we observe that if we decay after evolving for the same number of physical epochs, larger learning rates do better. See Figure S4.

(a)
(b)
Figure S4: Test accuracy vs learning rate for WRN28-10 and CIFAR100 with vanilla SGD, regularization, data augmentation, label smoothing and batch size 1024. The critical learning rate is . (a) Evolved for 38400 steps. (b) Evolved for 96000 steps with learning rate (blue) and then evolved for 7200 more steps at learning rate (red).

b.2 Different learning rates converge at the same physical time

We can also plot the test accuracy versus physical time for different learning rates. For vanilla SGD, the performance curves of different learning rates nearly lie on top of each other when plotted in physical time. This is why we consider a fair comparison between learning rates to be one made at the same physical time.
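One way to see why physical time is the natural axis is to take physical time to mean the learning rate times the number of steps (an assumption of this sketch). A linear model with curvature $\lambda$ trained to a fixed physical time $T$ then has error ratio $(1 - \eta\lambda)^{T/\eta} \approx e^{-\lambda T}$ when $\eta\lambda$ is small, independent of $\eta$. The sketch below, with hypothetical helper names, compares three learning rates that use very different numbers of steps.

```python
import math


def steps_for_physical_time(t_phys, eta):
    """Number of steps needed to reach physical time t_phys, assuming t_phys = eta * steps."""
    return math.ceil(t_phys / eta - 1e-9)  # small offset guards against float round-off


def linear_error_ratio(t_phys, eta, lam):
    """|e_t| / |e_0| for the linear model e_{t+1} = (1 - eta * lam) e_t at physical time t_phys."""
    return abs(1.0 - eta * lam) ** steps_for_physical_time(t_phys, eta)


lam, t_phys = 1.0, 5.0
for eta in (0.01, 0.05, 0.1):
    steps = steps_for_physical_time(t_phys, eta)
    ratio = linear_error_ratio(t_phys, eta, lam)
    print(f"eta={eta}: {steps} steps, error ratio {ratio:.5f}")
print(f"gradient-flow limit exp(-lam * T) = {math.exp(-lam * t_phys):.5f}")
```

The agreement degrades as $\eta\lambda$ grows, so the physical-time comparison is only approximate at larger learning rates.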

We have picked a subset of the learning rates of the WRN28-10 CIFAR100 experiment of SM B.1. In Figure S6 we see that, even if the curves are slightly different, they converge to roughly the same accuracy. The only curve that deviates somewhat corresponds to a rather high learning rate.

(a)
Figure S6: Test accuracy vs physical time for different learning rates in the WRN CIFAR100 experiment of the previous section B.1

b.3 Comparison of learning rates for different regularization for WRN28-10 on CIFAR10

Although in the main text we considered a model with fixed L2 regularization, we can also study the effect of removing it or of using a different value. In both examples below, we consider the same setup as figures (c)c,(d)d.

Without L2 regularization, the larger learning rate does better even in the absence of learning rate decay, although training takes a very long time. In our experience, comparing this setup with the state of the art, L2 regularization makes the experiment take longer to converge but does not influence performance much.

Figure S7: WRN28-10 on CIFAR10 without L2 regularization. Same setup as (d)d, but evolved for longer times.

In the presence of L2 regularization, we picked a particular value of the regularization strength. To make sure that our conclusion does not depend on this choice (it is the only hyperparameter we set other than the learning rate), we have also considered a larger value. The optimal performance in physical time is again peaked in the catapult phase, although the difference is smaller here.

(a)
(b)
Figure S10: Test accuracies for a larger CIFAR10 experiment like that of the main section. (a) WRN CIFAR-10 7200 steps as in figure (c)c. (b) WRN CIFAR10 2400 physical steps and then 4800 more steps at learning rate as in figure (d)d.

b.4 Training accuracy plots

The training accuracies of the previous experiments are shown in figure S15.

(a)
(b)
(c)
(d)
Figure S15: Training accuracies for the performance experiments. Smaller learning rates have higher training accuracy when compared in physical time. However, they still perform worse for a fixed number of steps. (a) WRN CIFAR-10 12000 steps as in figure (c)c. (b) WRN CIFAR10 3360 physical steps as in figure (d)d. (c) WRN CIFAR100 steps as in figure (a)a.(d) WRN CIFAR100 physical steps as in figure (b)b.

Appendix C Experimental results: Early time dynamics

c.1 ReLU activations for the simple model

In the main text we have been using ReLU non-linearities. Compared with the simple model with no non-linearities, ReLU networks remain trainable beyond $\eta = 4/\lambda_0$: these networks generically appear to train well up to $\eta \approx 12/\lambda_0$. This is a generic feature of deep ReLU networks and can already be observed for the model of Section 2 with a single target, two hidden layers and a ReLU non-linearity, as shown in Figure S18. In this single-sample setting, at larger learning rates the loss does not diverge, but the neurons die and the network ends up computing the trivial function. For deep networks with more than one hidden layer and multiple samples, as discussed in the main text, we observe that the loss diverges beyond $\eta \approx 12/\lambda_0$.

(a)
(b)
Figure S18: Simple model ReLU non-linearity (). (b) is evaluated at physical time .

c.2 Momenta

The optimizer also affects these dynamics. If we consider a similar setup with momentum, we first expect a linear model to converge in a broader range of learning rates, $\eta < 2(1+\beta)/\lambda_0$ for heavy-ball momentum with coefficient $\beta$. For smooth non-linearities, we observe that below this bound the top NTK eigenvalue $\lambda_t$ stays constant. However, this is not true for ReLU; see figure (a)a. In fact, for ReLU networks we observe that there is a small learning rate below which the time dynamics of $\lambda_t$ are similar (but non-constant). Above it, there are strong time dynamics; we illustrate this in figure (b)b with a 3-hidden-layer ReLU network.
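The wider linear-model range with momentum can be checked directly. For heavy-ball momentum, $\theta_{t+1} = \theta_t - \eta\nabla L(\theta_t) + \beta(\theta_t - \theta_{t-1})$, on a quadratic with curvature $\lambda$, the iteration is stable exactly when $0 < \eta\lambda < 2(1+\beta)$. The sketch below uses a hypothetical helper on a 1D quadratic; the coefficient name $\beta$ is our notation.

```python
def heavy_ball_final_error(eta, lam, beta, steps=2000, blowup=1e12):
    """Heavy-ball momentum on the 1D quadratic L(x) = lam * x**2 / 2, from x = 1 with zero velocity.

    Returns |x| after `steps` iterations, or infinity if the iterate blows up.
    """
    x_prev, x = 1.0, 1.0
    for _ in range(steps):
        grad = lam * x
        x_prev, x = x, x - eta * grad + beta * (x - x_prev)
        if abs(x) > blowup:
            return float("inf")
    return abs(x)


lam, beta = 1.0, 0.9
bound_gd = 2.0 / lam                        # plain gradient descent: eta < 2 / lam
bound_momentum = 2.0 * (1.0 + beta) / lam   # heavy ball: eta < 2 (1 + beta) / lam = 3.8
for eta in (0.95 * bound_gd, 0.95 * bound_momentum, 1.05 * bound_momentum):
    print(f"eta={eta:.2f}: GD -> {heavy_ball_final_error(eta, lam, 0.0):.3g}, "
          f"momentum -> {heavy_ball_final_error(eta, lam, beta):.3g}")
```

At $\eta = 3.61$ plain gradient descent already diverges while momentum still converges, so any catapult transition for the momentum optimizer would be expected at a correspondingly higher learning rate.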

(a)
(b)
Figure S21: (a) Evolution of the normalized curvature for fully connected networks trained with momentum (the same networks trained with SGD are shown with dashed lines for reference). We observe that ReLU networks trained with momentum do not have a constant kernel in the naive 'lazy' phase. (b) Same setup as the FC network of Figure 12, trained with momentum: fully connected, three hidden layers, ReLU non-linearity. The initial curvature $\lambda_0$ is slightly different due to variations at initialization.

c.3 Effect of regularization to early time dynamics

We do not expect L2 regularization to affect the early time dynamics, but because of the strong rearrangement that takes place in the first steps, it could potentially have a non-trivial effect; among other things, with L2 regularization the Hessian spectrum necessarily decays over time. We see that the dynamics driving the rearrangement are roughly the same, even though the maximum eigenvalue at early times decreases slowly.

(a)
(b)
Figure S24: Same WRN as figure 12d,f with regularization. Dynamics in physical steps of the and vs . a) , b) at physical time

c.4 Tanh activations

We observe that for tanh activations, the maximum learning rate is closer to the simple model expectation of $4/\lambda_0$; see Figure S27.

(a)
(b)
Figure S27: Maximum NTK eigenvalue at early times for a 2 hidden layer fully connected network with tanh non-linearity trained on MNIST, with . (a) Early time dynamics of the curvature for learning rates in the linear and catapult phase. (b) measured at .

c.5 WRN NTK Normalization

As illustrated in figures of the main text, we also see this behavior with NTK normalization. For completeness, we include the WRN model with NTK normalization. From the linearized intuition, we expect the phases to be determined by the quantity $\eta\lambda_0$, independently of the normalization. Figure S30 has the same setup as Figure 12.

(a)
(b)
Figure S30: Same as figures (e)e,(f)f but with NTK normalization. a,b) Wide Resnet 28-10. , vs at physical time

Appendix D Theoretical details

d.1 Full model analysis

Here we provide additional details on the theoretical analysis of the full model in Section 2.2. The gradient descent update equations are

(S1)

and

(S2)

The update equations for the error and kernel evaluated on training set inputs are

(S3)
(S4)
(S5)

We now consider the dynamics of the kernel projected onto the direction of the error $e_t$, which is given by

(S6)

Let us now analyze the phase structure of (S3) and (S6). For now, we neglect the last term on the right-hand side of (S3) (at initialization this term is of order $1/n$ and is negligible at large width). Let $\lambda_0$ be the maximal eigenvalue of the kernel at initialization, and let $\hat{e}_0$ be the corresponding eigenvector. Notice that $e_t$ projected onto the top eigenvector evolves as

(S7)
Lazy phase.

When $\eta < 2/\lambda_0$, we see that $e_t$ shrinks during training. The kernel updates are of order $1/n$, while convergence happens in a number of steps of order 1 (independent of $n$). Therefore the kernel does not change much during training. This is a special case of the NTK result (Jacot et al., 2018). Effectively, the model evolves as a linear model in this phase.

Catapult phase.

When $2/\lambda_0 < \eta < 4/\lambda_0$, $e_t$ grows exponentially fast, and it grows fastest in the $\hat{e}_0$ direction. Therefore, the vector $e_t$ becomes aligned with $\hat{e}_0$ after a number of steps of order $\log n$. Also, $f_t$ itself grows quickly while the labels $y$ stay constant, so we find $e_t \approx f_t$ after a similar number of steps. When these approximations hold, equation (S6) yields an approximate equation for the evolution of the top NTK eigenvalue $\lambda_t$:

(S8)

While $e_t$ grows exponentially fast, so does the size of the updates to $\lambda_t$. When $\|e_t\|$ becomes of order $\sqrt{n}$, the updates to the top eigenvalue become of order 1 (and negative), causing $\lambda_t$ to decrease by a non-negligible amount. This continues until $\eta\lambda_t < 2$, at which point $e_t$ starts converging to zero. Eventually, after a number of steps of order $\log n$, gradient descent converges to a global minimum that has a lower curvature than the curvature at initialization.

The justification for dropping the order $1/n$ term in (S7) was explained for the warmup model: while this term may affect the details of the dynamics, eventually the maximum kernel eigenvalue must drop below $2/\eta$ for the $\hat{e}_0$ component of the error (and therefore the loss) to converge to zero.

Divergent phase.

When $\eta > 4/\lambda_0$, both $e_t$ and $\lambda_t$ grow, and optimization diverges. Therefore, $4/\lambda_0$ is the maximum learning rate for this model.
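For reference, here is where the thresholds $2/\lambda$ and $4/\lambda$ come from in the simplest case. This is our own derivation sketch for the warmup model, assuming $f = n^{-1/2} v^\top u$ with a single input $x = 1$, loss $(f - y)^2/2$, error $e = f - y$ and $\lambda = (\|u\|^2 + \|v\|^2)/n$:

$$
\begin{aligned}
u_{t+1} &= u_t - \frac{\eta e_t}{\sqrt{n}}\, v_t, \qquad v_{t+1} = v_t - \frac{\eta e_t}{\sqrt{n}}\, u_t,\\
f_{t+1} &= f_t - \eta\lambda_t e_t + \frac{\eta^2}{n}\, e_t^2 f_t,\\
\lambda_{t+1} &= \lambda_t + \frac{\eta e_t}{n}\left(\eta\lambda_t e_t - 4 f_t\right),\\
\lambda_{t+1} &= \lambda_t + \frac{\eta f_t^2}{n}\left(\eta\lambda_t - 4\right) \qquad (y = 0).
\end{aligned}
$$

Dropping the $O(1/n)$ terms recovers the linear update $e_{t+1} = (1 - \eta\lambda_0)e_t$, stable for $\eta\lambda_0 < 2$. With $y = 0$ the curvature update is negative whenever $\eta\lambda_t < 4$, which is what drives the catapult, and positive when $\eta\lambda_t > 4$, in which case the growth of $f_t$ and of $\lambda_t$ reinforce each other and training diverges.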

Appendix E Model dynamics close to the critical learning rate

Here we consider the gradient descent dynamics of the model analyzed in Section 2 for learning rates that are close to the critical point $\eta_{\mathrm{crit}} = 2/\lambda_0$. The analysis reveals that the gradient descent dynamics of the model are qualitatively different above and below this point. For example, the loss decreases monotonically during training when $\eta < \eta_{\mathrm{crit}}$, but not when $\eta > \eta_{\mathrm{crit}}$. In this section we show that the transition from small to large learning rate becomes sharp once we take the modified large width limit, in the following sense: certain functions of the learning rate become non-analytic at $\eta_{\mathrm{crit}}$ in the $n \to \infty$ limit. This sharp transition bears close resemblance to phase transitions of the kind found in physical systems, such as the transition between the liquid and gaseous phases of water. In particular, our case involves a dynamical system, where the dynamics are governed by the gradient descent equations. These dynamics undergo a phase transition as a function of the learning rate, which plays the role of an external parameter. We point to the logistic map (May, 1976) as a well-known example of a dynamical system that undergoes phase transitions as a function of an external parameter.

e.1 Non-perturbative dynamics

A phase transition is a drastic change in a system's behavior incurred under a small change in external parameters. Mathematically, it is a non-analyticity in some property of the system as a function of these parameters. For example, consider the property $\lambda_{\mathrm{final}}(\eta)$, the curvature of the model at the end of training as a function of the learning rate. In the modified large width limit, $\lambda_{\mathrm{final}}$ is constant for $\eta < \eta_{\mathrm{crit}}$, but not for $\eta > \eta_{\mathrm{crit}}$. Therefore, this function is not analytic at $\eta_{\mathrm{crit}}$. Notice that this statement is true in the $n \to \infty$ limit but not necessarily at finite width, where the final curvature may be an analytic function of the learning rate even at $\eta_{\mathrm{crit}}$. It is well known in physics that phase transitions only occur in a limit where the number of dynamical variables (in this case the number of model parameters) is taken to infinity. One immediate consequence of the non-analyticity at $\eta_{\mathrm{crit}}$ is that the large learning rate phase is inaccessible from the small learning rate phase via a perturbative expansion. In other words, we cannot describe all properties of the model for some $\eta > \eta_{\mathrm{crit}}$ by doing a Taylor expansion around a point $\eta_0 < \eta_{\mathrm{crit}}$ and keeping a finite number of terms.
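A quick numerical look at $\lambda_{\mathrm{final}}(\eta)$ follows, again assuming the warmup model with $x = 1$, $y = 0$, $\lambda = (\|u\|^2 + \|v\|^2)/n$ and our exact $(f_t, \lambda_t)$ update. At large width the final curvature stays at $\lambda_0$ for all $\eta < 2/\lambda_0$ and drops below $2/\eta$ as soon as $\eta$ exceeds it, which is the kink that becomes a true non-analyticity as $n \to \infty$. The helper name, the width and the blow-up threshold are arbitrary choices for this sketch.

```python
import math


def warmup_step(f, lam, eta, n, y=0.0):
    """One exact gradient descent step of the assumed warmup model, in terms of (f, lambda)."""
    e = f - y
    f_next = f - eta * lam * e + (eta * eta / n) * e * e * f
    lam_next = lam + (eta * e / n) * (eta * lam * e - 4.0 * f)
    return f_next, lam_next


def final_curvature(eta, lam0=1.0, n=10**6, f0=1.0, steps=5000, blowup=1e12):
    """Curvature lambda at the end of training, or NaN if training diverges."""
    f, lam = f0, lam0
    for _ in range(steps):
        f, lam = warmup_step(f, lam, eta, n)
        if not math.isfinite(f) or abs(f) > blowup:
            return math.nan
    return lam


etas = [0.5, 1.0, 1.5, 1.9, 2.1, 2.5, 3.0, 3.5, 4.5]  # with lam0 = 1, these are values of eta * lam0
curv = {eta: final_curvature(eta) for eta in etas}
for eta, lam in curv.items():
    print(f"eta * lam0 = {eta:.1f} -> final curvature {lam:.4f}")
```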

Dyer & Gur-Ari (2020); Huang & Yau (2019) developed a formalism that allows one to compute finite-width corrections to various properties of deep networks, using a perturbative expansion around the infinite width limit. We have argued that the usual infinite width approximation to the training dynamics is not valid for learning rates above $\eta_{\mathrm{crit}}$, and that a full analysis must account for large finite-width effects. One may have hoped that including the perturbative finite-width corrections discussed in Dyer & Gur-Ari (2020); Huang & Yau (2019) would allow us to regain analytic control over the dynamics. The results presented here suggest that this is not the case: for $\eta > \eta_{\mathrm{crit}}$, we expect that the perturbative expansion will not provide a good approximation to the gradient descent dynamics at any finite order in inverse width.

e.2 Critical exponents

When the external parameters are close to a phase transition, one often finds that the dynamical properties of the system obey power law behavior. The exponents of these power laws (called critical exponents) are of interest because they are often found to be universal, in the sense that the same set of exponents is often found to describe the phase transitions of completely different physical systems.

Here we consider $t_{\mathrm{conv}}(\eta)$, the number of steps until convergence, as a function of the learning rate. We will now show that $t_{\mathrm{conv}}$ exhibits power-law behavior when $\eta$ is close to $\eta_{\mathrm{crit}}$. For simplicity we consider the warmup model studied in Section 2. First, suppose that we are below the transition, setting $\eta\lambda_0 = 2 - \epsilon$ for some small $\epsilon > 0$. From the update equation, $e_{t+1} \approx (1 - \eta\lambda_0)\, e_t = -(1 - \epsilon)\, e_t$, so $|e_t|$ will converge to some fixed small value after a time of order $\epsilon^{-1}$. Here we assumed that $\lambda_t$ is constant in $t$, which is true as long as $\epsilon$ is independent of $n$ (namely, we fix $\epsilon$ and then take $n$ large). Therefore, the convergence time below the transition scales as $t_{\mathrm{conv}} \sim \epsilon^{-1}$, and the critical exponent is $-1$.

Next, suppose that $\eta\lambda_0 = 2 + \epsilon$ with $\epsilon > 0$. Now the update equation reads $e_{t+1} \approx -(1 + \epsilon)\, e_t$. This approximation holds early during training, when the curvature updates are small. Initially, $|e_t|$ grows until it is of order $\sqrt{n}$, at which point the updates to $\lambda_t$ become of order 1. This happens after a time of order $\log(n)/\epsilon$. Following this, the optimizer converges. At this point $\eta\lambda_t$ is no longer tuned to be close to the transition, and so the convergence time measured from this point on will not be sensitive to $\epsilon$. Therefore, for small $\epsilon$ the convergence time is dominated by the early part of training, namely $t_{\mathrm{conv}} \sim \log(n)/\epsilon$. The critical exponent is again $-1$. Figure S32 shows an empirical verification of this behavior.
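The exponent can also be estimated numerically in a few lines. This sketch assumes the warmup model ($x = 1$, $y = 0$, $\lambda = (\|u\|^2 + \|v\|^2)/n$, our exact update for $(f_t, \lambda_t)$) at width $n = 10^8$. It defines convergence as $|f_t| < 10^{-6}$ and fits the slope of $\log t_{\mathrm{conv}}$ against $\log\epsilon$ by least squares. The tolerance, width and $\epsilon$ grid are arbitrary choices.

```python
import math


def warmup_step(f, lam, eta, n, y=0.0):
    """One exact gradient descent step of the assumed warmup model, in terms of (f, lambda)."""
    e = f - y
    f_next = f - eta * lam * e + (eta * eta / n) * e * e * f
    lam_next = lam + (eta * e / n) * (eta * lam * e - 4.0 * f)
    return f_next, lam_next


def convergence_time(eta, lam0, n=10**8, f0=1.0, tol=1e-6, max_steps=10**6):
    """Number of steps until |f_t| < tol (label y = 0)."""
    f, lam = f0, lam0
    for t in range(1, max_steps + 1):
        f, lam = warmup_step(f, lam, eta, n)
        if abs(f) < tol:
            return t
    raise RuntimeError("no convergence within max_steps")


def loglog_slope(xs, ys):
    """Least-squares slope of log(ys) against log(xs)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    mx, my = sum(lx) / len(lx), sum(ly) / len(ly)
    num = sum((a - mx) * (b - my) for a, b in zip(lx, ly))
    den = sum((a - mx) ** 2 for a in lx)
    return num / den


eta = 1.0  # so that eta * lam0 = 2 - eps (below) or 2 + eps (above)
eps = [0.02, 0.04, 0.08, 0.16]
t_below = [convergence_time(eta, 2.0 - e) for e in eps]
t_above = [convergence_time(eta, 2.0 + e) for e in eps]
exponent_below = loglog_slope(eps, t_below)
exponent_above = loglog_slope(eps, t_above)
print("below:", t_below, f"exponent ~ {exponent_below:.2f}")
print("above:", t_above, f"exponent ~ {exponent_above:.2f}")
```

Both fitted slopes should come out near $-1$. The fit above the transition is less clean because logarithmic corrections (from $\log n$ and from the size of the catapult) also vary with $\epsilon$.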

(a)
Figure S32: The convergence time diverges when the learning rate is close to the critical value $\eta_{\mathrm{crit}}$, indicated by the solid green line. The measured exponents (shown in parentheses) are close to the predicted value of -1. The experiment uses the warmup model of Section 2.

Appendix F Additional evidence for linearization in the catapult phase.

Here we present more detailed evidence for the re-emergence of linear dynamics in the catapult phase. Figure S35 shows results for models trained on subsets of MNIST with learning rates in the catapult phase. In Figure (a)a we see that for a one-hidden-layer fully connected model trained on 512 MNIST images, the performance of the full non-linear model and of the model linearized after 10 steps track each other closely. Models evolve as linear models when the NTK is constant. In Figure (b)b we give evidence that as networks become wider, the change in the kernel decreases.
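The return to linear dynamics can also be illustrated in the warmup model. This sketch makes the same assumptions as the warmup examples ($f = n^{-1/2} v^\top u$, $x = 1$, $y = 0$, exact $(f_t, \lambda_t)$ updates derived by us). It linearizes the model either at initialization or at the first step where the catapult is over, defined here, as an arbitrary choice, by $\eta\lambda_t < 2$ and $|f_t| < 1$. For a single sample with $y = 0$, the linearized model evolves as $f^{\mathrm{lin}}_{t+1} = (1 - \eta\lambda_{t_0})\, f^{\mathrm{lin}}_t$, with the kernel frozen at the linearization step $t_0$; `linearization_check` is a hypothetical helper.

```python
def warmup_step(f, lam, eta, n, y=0.0):
    """One exact gradient descent step of the assumed warmup model, in terms of (f, lambda)."""
    e = f - y
    f_next = f - eta * lam * e + (eta * eta / n) * e * e * f
    lam_next = lam + (eta * e / n) * (eta * lam * e - 4.0 * f)
    return f_next, lam_next


def linearization_check(eta=1.0, lam0=2.5, n=10**6, f0=1.0, steps=300):
    """Compare the full model with models linearized at step 0 and after the catapult."""
    fs, lams = [f0], [lam0]
    for _ in range(steps):
        f, lam = warmup_step(fs[-1], lams[-1], eta, n)
        fs.append(f)
        lams.append(lam)
    # First step at which the catapult is over: curvature below 2/eta and logit back to O(1)
    t_late = next(t for t in range(steps + 1) if eta * lams[t] < 2.0 and abs(fs[t]) < 1.0)

    def max_relative_gap(t0):
        # Linearized model at step t0 (y = 0): kernel frozen at lams[t0]
        f_lin, gap = fs[t0], 0.0
        for t in range(t0 + 1, steps + 1):
            f_lin *= 1.0 - eta * lams[t0]
            gap = max(gap, abs(fs[t] - f_lin))
        return gap / abs(fs[t0])

    return t_late, max_relative_gap(0), max_relative_gap(t_late)


t_late, gap_init, gap_late = linearization_check()
print(f"catapult over at step {t_late}")
print(f"linearized at step 0:        max relative gap {gap_init:.3g}")
print(f"linearized at step {t_late}: max relative gap {gap_late:.3g}")
```

The model linearized at initialization diverges because $\eta\lambda_0 > 2$, while the model linearized after the catapult inherits the reduced curvature and tracks the full model closely.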

(a)
(b)
Figure S35: Evidence for a return of linear dynamics after the catapult effect. (a,b) The same model as in Figure 14, with the addition of models linearized at step 0 and at step 10. We observe that the model linearized after 10 steps tracks the non-linear performance in the catapult phase, up to a maximum learning rate. (c) The change in the NTK between two training steps decreases as the width increases. Here we consider 2-class MNIST with 100 samples per class.