On the training dynamics of deep networks with L_2 regularization

06/15/2020
by   Aitor Lewkowycz, et al.
Google

We study the role of L_2 regularization in deep learning, and uncover simple relations between the performance of the model, the L_2 coefficient, the learning rate, and the number of training steps. These empirical relations hold when the network is overparameterized. They can be used to predict the optimal regularization parameter of a given model. In addition, based on these observations we propose a dynamical schedule for the regularization parameter that improves performance and speeds up training. We test these proposals in modern image classification settings. Finally, we show that these empirical relations can be understood theoretically in the context of infinitely wide networks. We derive the gradient flow dynamics of such networks, and compare the role of L_2 regularization in this context with that of linear models.


1 Introduction

Machine learning models are commonly trained with L_2 regularization. This involves adding a term proportional to λ‖θ‖² to the loss function, where θ is the vector of model parameters and λ is a hyperparameter. In some cases, the theoretical motivation for using this type of regularization is clear. For example, in the context of linear regression, L_2 regularization increases the bias of the learned parameters while reducing their variance across instantiations of the training data; in other words, it is a manifestation of the bias-variance tradeoff. In statistical learning theory, a "hard" variant of L_2 regularization, in which one imposes a norm constraint of the form ‖θ‖ ≤ B, is often employed when deriving generalization bounds.
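As a concrete illustration, a minimal sketch of adding an L_2 term to a training loss is shown below. The parameter pytree, the placeholder empirical loss, and the factor of 1/2 are our own conventions for this sketch, not the exact training code used in our experiments.

```python
import jax
import jax.numpy as jnp

def l2_penalty(params):
    # Sum of squared entries over an arbitrary pytree of parameters.
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))

def empirical_loss(params, batch):
    # Hypothetical placeholder for the task loss (e.g. a regression loss).
    inputs, targets = batch
    preds = inputs @ params["w"]
    return jnp.mean((preds - targets) ** 2)

def regularized_loss(params, batch, lam):
    # Empirical loss plus the L_2 term with coefficient lam; the factor of
    # 1/2 is a convention and only rescales lam.
    return empirical_loss(params, batch) + 0.5 * lam * l2_penalty(params)

# The gradient used by (stochastic) gradient descent then contains the
# lam * params contribution automatically.
grad_fn = jax.jit(jax.grad(regularized_loss))
```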

In deep learning, the use of L_2 regularization is prevalent and often leads to improved performance in practical settings [Hinton, 1986], although the theoretical motivation for its use is less clear. Indeed, it is well known that overparameterized models overfit far less than one may expect [Zhang et al., 2016], and so the classical bias-variance tradeoff picture does not apply [Neyshabur et al., 2017, Belkin et al., 2018, Geiger et al., 2020]. There is growing understanding that this is caused, at least in part, by the (implicit) regularization properties of stochastic gradient descent (SGD) [Soudry et al., 2017]. The goal of this paper is to improve our understanding of the role of L_2 regularization in deep learning.

1.1 Our contribution

We study the role of L_2 regularization when training over-parameterized deep networks, taken here to mean networks that can achieve training accuracy 1 when trained with SGD. Specifically, we consider the early stopping performance of a model, namely the maximum test accuracy it achieves during training, as a function of the L_2 parameter λ. We make the following observations based on the experimental results presented in the paper.

  1. The number of SGD steps until a model achieves maximum performance is t_max ≈ C/λ, where C is a coefficient that depends on the data, the architecture, and all other hyperparameters. We find that this relationship holds across a wide range of λ values.

  2. If we train with a fixed number of steps, model performance peaks at a certain value of the λ parameter. However, if we train for a number of steps proportional to 1/λ, then performance improves with decreasing λ. In such a setup, performance becomes independent of λ for sufficiently small λ. Furthermore, performance with a small, non-zero λ is often better than performance without any L_2 regularization.

Figure 1(a) shows the performance of an overparameterized network as a function of the λ parameter. When the model is trained with a fixed step budget, performance is maximized at one value of λ. However, when the training time is proportional to 1/λ, performance improves and approaches a constant value as we decrease λ.

(a) λ independence
(b) Optimal λ prediction
(c) AutoL2 schedule
Figure 1: Wide ResNet 28-10 trained on CIFAR-10 with momentum and data augmentation. (a) Final test accuracy vs. the λ parameter. When the network is trained for a fixed number of epochs, optimal performance is achieved at a certain value of λ. But when trained for a time proportional to 1/λ, performance plateaus and remains constant down to the lowest values of λ tested. This experiment includes a learning rate schedule. (b) Test accuracy vs. training epochs for the predicted optimal λ parameter, compared with the tuned parameter. (c) Training curves with our dynamical AutoL2 schedule, compared with a tuned, constant λ parameter.

As we demonstrate in the experimental section, these observations hold for a variety of training setups, including different architectures, data sets, and optimization algorithms. In particular, when training with vanilla SGD (without momentum), we observe that the number of steps until maximum performance depends on the learning rate η and on λ as t_max ∝ 1/(λη). The performance achieved after this many steps depends only weakly on the choice of learning rate.

Applications.

We present two practical applications of these observations. First, we propose a simple way to predict the optimal value of the λ parameter, based on a cheap measurement of the coefficient C. Figure 1(b) compares the performance of models trained with our predicted λ parameter with that of models trained with a tuned parameter. In this realistic setting, we find that our predicted parameter leads to performance that is within 0.4% of the tuned performance on CIFAR-10, at a cost that is marginally higher than that of a single training run. As shown below, we also find that the predicted parameter is consistently within an order of magnitude of the optimal, tuned value.

As a second application we propose AutoL2, a dynamical schedule for the λ parameter. The idea is that large λ values achieve worse performance but also lead to faster training. Therefore, in order to speed up training one can start with a large λ value and decay it during training (this is similar to the intuition behind learning rate schedules). In Figure 1(c) we compare the performance of a model trained with AutoL2 against that of a model trained with a tuned but constant λ parameter, and find that AutoL2 outperforms the tuned model both in speed and in performance.

Theoretical contribution.

Finally, we turn to a theoretical investigation of the empirical observations made above. As a first attempt at explaining these effects, consider the following argument based on the loss landscape. For overparameterized networks, the Hessian spectrum evolves rapidly during training [Sagun et al., 2017, Gur-Ari et al., 2018, Ghorbani et al., 2019]. After a small number of training steps with no L_2 regularization, the minimum eigenvalue is found to be close to zero. In the presence of a small L_2 term, we therefore expect that the minimal eigenvalue will be approximately λ. In quadratic optimization, the convergence time is inversely proportional to the smallest eigenvalue of the Hessian.[1]

[1] In linear regression with L_2 regularization, optimization is controlled by a linear kernel built from the sample matrix plus λ times the identity matrix in parameter space. Optimization in each kernel eigendirection converges exponentially at a rate set by the corresponding eigenvalue. When λ > 0 and the model is overparameterized, the lowest eigenvalue of the kernel will typically be close to λ, and therefore the time to convergence will be proportional to 1/λ.
Based on this intuition, we may then expect that convergence time will be proportional to 1/λ. The fact that performance is roughly constant for sufficiently small λ can then be explained if overfitting can be mostly attributed to optimization in the very low curvature directions [Rahaman et al., 2018, Wadia et al., 2020]. Now, our empirical finding is that the time it takes the network to reach maximum accuracy is proportional to 1/λ. In some cases this is the same as the convergence time, but in other cases (see, for example, the training curves in the SM) we find that performance decays after peaking, and so convergence happens later. Therefore, the loss landscape-based explanation above is not sufficient to fully explain the effect.
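To make the footnote's claim concrete, here is the standard gradient-flow computation for ridge regression; the factors of 1/2 and 1/n below are our conventions for this sketch.

\[
L(w) = \frac{1}{2n}\lVert Xw - Y\rVert^2 + \frac{\lambda}{2}\lVert w\rVert^2,
\qquad
\dot{w} = -\nabla_w L = -\Big(\tfrac{1}{n}X^\top X + \lambda\,\mathbf{1}\Big) w + \tfrac{1}{n}X^\top Y .
\]

In the eigenbasis of (1/n)XᵀX with eigenvalues λ_i, each mode relaxes to its fixed point as w_i(t) − w_i* ∝ e^{−(λ_i + λ)t}. For an overparameterized model many λ_i are close to zero, so the slowest modes converge on a time scale of order 1/λ.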

To gain a better theoretical understanding, we consider the setup of an infinitely wide neural network trained using gradient flow. We focus on networks with positive-homogeneous activations, which include deep networks with ReLU activations, fully-connected or convolutional layers, and other common components. By analyzing the gradient flow update equations of such networks, we are able to show that the performance peaks at a time of order 1/λ and deteriorates thereafter. This is in contrast to the performance of linear models with L_2 regularization, where no such peak is evident. These results are consistent with our empirical observations, and may help shed light on the underlying causes of these effects.

According to known infinite width theory, in the absence of explicit regularization the kernel that controls network training is constant [Jacot et al., 2018]. Our analysis extends the known results on infinitely wide network optimization, and indicates that the kernel decays in a predictable way in the presence of L_2 regularization. We hope that this analysis will shed further light on the observed performance gap between infinitely wide networks, which are under good theoretical control, and the networks trained in practical settings [Arora et al., 2019, Novak et al., 2019, Wei et al., 2018, Lewkowycz et al., 2020].

Related works.

L_2 regularization in the presence of batch normalization [Ioffe and Szegedy, 2015] has been studied in [van Laarhoven, 2017, Hoffer et al., 2018, Zhang et al., 2018]. These papers discussed how the effect of L_2 on scale-invariant models is merely that of setting an effective learning rate (with no genuine regularization). This was made precise in Li and Arora [2019], who showed that, at small learning rates, this effective learning rate grows exponentially with training time. Our theoretical analysis of large-width networks exhibits the same behaviour when the network is scale invariant.

2 Experiments

Performance and time scales.

We now turn to an empirical study of networks trained with L_2 regularization. In this section we present results for a fully-connected network trained on MNIST, a Wide ResNet [Zagoruyko and Komodakis, 2016] trained on CIFAR-10, and CNNs trained on CIFAR-10. The experimental details are in SM A. The empirical findings discussed in Section 1.1 hold across this variety of overparameterized setups.

Figure 2 presents experimental results on fully-connected and Wide ResNet networks, and Figure 3 presents experiments conducted on CNNs. We find that the number of steps until optimal performance is achieved (defined here as the minimum time required to be within a small margin of the maximum test accuracy) scales as 1/λ, as discussed in Section 1.1. Our experiments span several decades of λ (larger λ won't train at all, and smaller λ would take too long to train). Moreover, when we evolved the networks until they reached optimal performance, the maximum test accuracy for smaller λ parameters did not get worse. We compare this against the performance of a model trained with a fixed number of epochs, reporting the maximum performance achieved during training. In this case, we find that reducing λ beyond a certain value does hurt performance.

While here we consider the simplified setup of vanilla SGD and no data augmentation, our observations also hold in the presence of momentum and data augmentation; see SM C.2 for more experiments. We would like to emphasize again that while the smaller-λ models can reach the same test accuracy as their larger-λ counterparts, models like WRN 28-10 on CIFAR-10 need to be trained for a considerably larger number of epochs to achieve this.[2]

[2] The longer experiments ran for 5000 epochs, while one usually trains these models for 300 epochs.

(a) FC
(b) FC
(c) FC
(d) WRN
(e) WRN
(f) WRN
Figure 2: Sweep over λ and the learning rate, illustrating how smaller λ's require longer training times to achieve the same performance. In the left and middle plots, the learning rates are logarithmically spaced between the values displayed in the legend; the specific values are in SM A. Left: epochs to maximum test accuracy (within a small margin). Middle: maximum test accuracy (the line denotes the maximum test accuracy achieved among all learning rates). Right: maximum test accuracy for a fixed time budget. (a,b,c) Fully-connected 3-hidden-layer neural network trained on MNIST samples. (d,e,f) A Wide Residual Network 28-10 trained on CIFAR-10 without data augmentation. The λ = 0 case is included in (c) and (f); in (f), the λ = 0 curve was evolved for longer than the smallest λ, but there is still a gap.
(a) CNN No BN
(b) CNN No BN
(c) CNN No BN
(d) CNN BN
(e) CNN BN
(f) CNN BN
Figure 3: CNNs trained with and without batch normalization at a fixed learning rate. Presented results follow the same format as Figure 2.

Learning rate schedules.

So far we have considered training setups that do not include learning rate schedules. Figure 1(a) shows the results of training a Wide ResNet on CIFAR-10 with a learning rate schedule, momentum, and data augmentation. The schedule was determined as follows: given a total number of training epochs T, the learning rate is decayed by a constant factor at fixed fractions of T. We compare training with a fixed T against training with T proportional to 1/λ. We find that training with a fixed budget leads to an optimal value of λ, below which performance degrades. On the other hand, training with T ∝ 1/λ leads to improved performance at smaller λ, consistent with our previous observations.
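For concreteness, a schedule of this type can be implemented as a step function of the epoch. The decay factor and the decay points below are illustrative placeholders, not the exact values used in our runs.

```python
def make_lr_schedule(base_lr, total_epochs, decay_points=(0.5, 0.75), decay_factor=0.1):
    # Piecewise-constant learning rate: decay by `decay_factor` at fixed
    # fractions of the total training budget (placeholder fractions).
    boundaries = [int(p * total_epochs) for p in decay_points]

    def lr_at_epoch(epoch):
        lr = base_lr
        for b in boundaries:
            if epoch >= b:
                lr = lr * decay_factor
        return lr

    return lr_at_epoch

# Example: if total_epochs is taken proportional to 1/lambda, the decay
# epochs scale with the budget as well.
schedule = make_lr_schedule(base_lr=0.1, total_epochs=200)
```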

3 Applications

We now discuss two practical applications of the empirical observations made in the previous section.

Optimal λ.

We observed that the time to reach maximum test accuracy is proportional to 1/λ, which we can express as t_max ≈ C/λ. This relationship continues to hold empirically even for large values of λ. When λ is large, the network attains its (significantly degraded) maximum performance after a relatively short amount of training time. We can therefore measure the value of C by training the network with a large λ parameter until its performance peaks, at a fraction of the cost of a normal training run.

Based on our empirical observations, given a training budget of T epochs we predict that the optimal λ parameter can be approximated by λ_pred ≈ C/T. This is the smallest λ parameter such that model performance will peak within training time T. Figure 1(b) shows the result of testing this prediction in a realistic setting: a Wide ResNet trained on CIFAR-10 with momentum, a fixed learning rate, and data augmentation. The model is first trained with a large λ parameter for 2 epochs in order to measure C; see Figure 4(a). We then compare the tuned value of λ against our prediction for training budgets spanning close to two orders of magnitude, and find excellent agreement: the predicted λ's have a performance which is rather close to the optimal one. Furthermore, the tuned values are always within an order of magnitude of our predictions; see Figure 4(b).
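A minimal sketch of this procedure is shown below. The helper train_and_eval, the probe value of λ, and the probe length are hypothetical placeholders used only to illustrate the bookkeeping.

```python
def measure_coefficient(train_and_eval, lam_probe=1e-2, probe_epochs=2):
    # Train briefly with a large L_2 coefficient and locate the epoch at
    # which test accuracy peaks; since t_peak ~ C / lam, this gives C.
    test_accuracy = train_and_eval(lam=lam_probe, epochs=probe_epochs)  # accuracy per epoch
    t_peak = max(range(len(test_accuracy)), key=lambda i: test_accuracy[i]) + 1
    return lam_probe * t_peak

def predict_lambda(coefficient_c, budget_epochs):
    # Smallest lambda whose performance peak fits inside the budget.
    return coefficient_c / budget_epochs
```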

We have not studied this prediction thoroughly in the presence of learning rate schedules, but intuitively one wants to evolve for as long as possible with the initial large learning rate, so it seems natural to think that the optimal λ corresponds to the smallest λ whose performance peak (at the constant initial learning rate) occurs before the first decay. Using the learning rate schedule of Section 2, this fixes the predicted λ in terms of the epoch of the first decay. In the setup of Figure 1(a), the λ predicted in this way is close to the tuned optimum; even this small difference has substantial implications for performance, but we can nevertheless use the prediction as a reference for hyperparameter tuning. We leave a more precise estimation of the optimal λ for future work.

(a)
(b)
Figure 4: Wide ResNet trained with momentum and data augmentation. (a) We train the model with a large λ parameter for 2 epochs and measure the coefficient C from the approximate point along the epoch axis where accuracy is maximized. (b) Optimal (tuned) λ values compared with the theoretical prediction. The error bars represent the spread of λ values that achieve test accuracy within a small margin of the optimal one.

AutoL2: Automatic λ schedules.

We now turn to another application, based on the observation that models trained with larger λ parameters reach their peak performance faster. It is therefore plausible that one can speed up the training process by starting with a large λ parameter and decaying it according to some schedule. Here we propose to choose the schedule dynamically, by decaying the λ parameter when performance begins to deteriorate. See SM E for further details.

AutoL2 is a straightforward implementation of this idea: we begin training with a large λ parameter, and we decay it by a factor of 10 if either the empirical loss (the training loss without the L_2 term) or the training error increases. To improve stability, immediately after decaying λ we impose a refractory period during which the λ parameter cannot decay again. Figure 1(c) compares this algorithm against the model with the optimal constant λ parameter. We find that AutoL2 trains significantly faster and achieves superior performance.
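The sketch below captures this decision rule. The initial value, the refractory length, and the reading of "increases" as a comparison against the previous measurement are our assumptions for this sketch; the exact procedure is spelled out in SM E.

```python
class AutoL2Schedule:
    """Sketch of the AutoL2 rule: decay lambda by 10x when the bare training
    loss or the training error increases between measurements, with a
    refractory period after each decay (hyperparameters are placeholders)."""

    def __init__(self, lam_init=1e-2, decay_factor=0.1, refractory_measurements=10):
        self.lam = lam_init
        self.decay_factor = decay_factor
        self.refractory_measurements = refractory_measurements
        self.cooldown = 0
        self.prev_loss = float("inf")
        self.prev_error = float("inf")

    def update(self, bare_loss, train_error):
        # bare_loss is the training loss *without* the L_2 term.
        if self.cooldown > 0:
            self.cooldown -= 1
        elif bare_loss > self.prev_loss or train_error > self.prev_error:
            self.lam *= self.decay_factor
            self.cooldown = self.refractory_measurements
        self.prev_loss, self.prev_error = bare_loss, train_error
        return self.lam
```

In the run of Figure 1(c), measurements were taken every 10 training steps (see SM A).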

In other experiments we have found that this algorithm does not yield improved results when the training procedure includes a learning rate schedule. We leave the attempt to effectively combine learning rate schedules with λ schedules to future work.

4 Theoretical results

We now turn to a theoretical analysis of the training trajectory of networks trained with L_2 regularization. We focus on infinitely wide networks with positively-homogeneous activations. Consider a network function f(x; θ) with model parameters θ. The network is initialized using NTK parameterization [Jacot et al., 2018]: the initial parameters are sampled i.i.d. from a standard normal distribution. The model parameters are trained using gradient flow on the sum of the empirical loss (the average of the sample loss ℓ over the training set D) and the L_2 penalty with coefficient λ.

We say that the network function f is k-homogeneous if f(x; aθ) = a^k f(x; θ) for any a > 0. As an example, a fully-connected network with k layers and ReLU or linear activations is k-homogeneous. Networks made out of convolutional, max-pooling or batch-normalization layers are also positively homogeneous.[3]

[3] Batch normalization is often implemented with an ε parameter meant to prevent numerical instabilities. Such networks are only approximately homogeneous. See Li and Arora [2019] for a discussion of networks with homogeneous activations.
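The homogeneity property is easy to check numerically. The sketch below verifies it for a small bias-free fully-connected ReLU network with three weight matrices (hence k = 3); the architecture and shapes are chosen only for illustration.

```python
import jax

def mlp(params, x):
    # Bias-free fully-connected ReLU network; with three weight matrices it
    # is 3-homogeneous in its parameters.
    h = x
    for w in params[:-1]:
        h = jax.nn.relu(h @ w)
    return h @ params[-1]

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
params = [jax.random.normal(k1, (8, 16)),
          jax.random.normal(k2, (16, 16)),
          jax.random.normal(k3, (16, 1))]
x = jax.random.normal(key, (8,))

a, k = 2.0, 3
scaled_params = [a * w for w in params]
# f(x; a * theta) equals a**k * f(x; theta), up to floating point error.
print(mlp(scaled_params, x), (a ** k) * mlp(params, x))
```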

Jacot et al. [2018] showed that when an infinitely wide, fully-connected network is trained using gradient flow (and without L_2 regularization), its network function f_t obeys a differential equation of the form ∂_t f_t(x) = −(1/|D|) Σ_{(x',y')∈D} Θ(x, x') ℓ′(f_t(x'), y'), where t is the gradient flow time and Θ is the Neural Tangent Kernel (NTK).

Dyer and Gur-Ari [2020] presented a conjecture that allows one to derive the large-width asymptotic behavior of the network function, of the Neural Tangent Kernel, and of combinations involving higher-order derivatives of the network function. The conjecture was shown to hold for networks with polynomial activations [Aitken and Gur-Ari, 2020], and has been verified empirically for commonly used activation functions. In what follows, we will assume the validity of this conjecture. The following is our main theoretical result.

Theorem 1.

Consider a k-homogeneous network, and assume that the network obeys the correlation function conjecture of Dyer and Gur-Ari [2020]. In the infinite width limit, the network function and the kernel evolve according to the following equations at training time t.

(1)
(2)
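Schematically, and with order-one numerical factors that depend on how the L_2 term and the empirical loss are normalized, the evolution takes the form

\[
\partial_t f_t(x) = -\,e^{-2(k-1)\lambda t}\,\frac{1}{|D|}\sum_{(x',y')\in D}\Theta_0(x,x')\,\ell'\big(f_t(x'),y'\big) \;-\; k\lambda\, f_t(x),
\qquad
\Theta_t(x,x') = e^{-2(k-1)\lambda t}\,\Theta_0(x,x') .
\]

The key structural features are that the kernel decays exponentially at a rate proportional to (k − 1)λ (so it is constant for linear models, k = 1), and that the function acquires a decay term proportional to kλ.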

The proof hinges on the following identity, which holds for k-homogeneous functions: θᵀ∂_θ f(x; θ) = k f(x; θ). This identity allows us to show that the only effect of L_2 regularization at infinite width is to introduce simple terms proportional to λ in the gradient flow update equations for both the function and the kernel.
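For completeness, this identity follows by differentiating the homogeneity relation with respect to the scale parameter and setting it to one:

\[
f(x; a\theta) = a^k f(x;\theta)
\;\Longrightarrow\;
\theta^\top \partial_\theta f(x; a\theta) = k\, a^{k-1} f(x;\theta)
\;\stackrel{a=1}{\Longrightarrow}\;
\theta^\top \partial_\theta f(x;\theta) = k\, f(x;\theta) .
\]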

We refer the reader to the SM for the proof. We mention in passing that the case k = 0 corresponds to a scale-invariant network function, which was studied in Li and Arora [2019]. In this case, training with an L_2 term is equivalent to training with an exponentially increasing learning rate.

For commonly used loss functions, and for λ > 0, we expect that the solution decays to zero at late times. We will prove that this holds for MSE loss, but let us first discuss the intuition behind this statement. At late times the exponent in front of the first term in (1) decays to zero, leaving an approximate equation in which the λ term drives an exponential decay of the function to zero. Both the explicit exponent in the equation and the approximate late-time exponential decay suggest that this decay occurs at a time of order 1/λ. Therefore, we expect the minimum of the empirical loss to occur at a time proportional to 1/λ, after which the bare loss will increase because the function is decaying to zero. We observe this behaviour empirically for wide fully-connected networks and for Wide ResNet in the SM.

Furthermore, notice that once the dependence on the degree of homogeneity k is included, the decay time scale decreases with k. Models with a higher degree of homogeneity (for example, deeper fully-connected networks) will therefore converge faster.

We now focus on MSE loss and solve the gradient flow equation (1) for this case.

Theorem 2.

Let the sample loss be the MSE loss, and assume that λ > 0. Suppose that, at initialization, the kernel Θ_0 (restricted to the training set) has eigenvectors ê_i with corresponding eigenvalues λ_i. Then during gradient flow, the eigenvalues are rescaled by a common, exponentially decaying factor while the eigenvectors are static. Suppose we treat f as a vector defined on the training set. Then each mode of the function, f_i = ê_iᵀ f, evolves independently as

(3)

Here, y_i = ê_iᵀ y denotes the corresponding mode of the labels. At late times, the function decays to zero on the training set.

The properties of the solution (3) depend on whether the ratio of the kernel eigenvalue λ_i to the regularization scale λ is greater than or smaller than 1, as illustrated in Figure 5. When this ratio is large, the function approaches the label mode y_i at a time that is of order 1/λ_i. This behavior is the same as that of a linear model, and represents ordinary learning. Later, at a time of order 1/λ, the mode decays to zero as described above; this late-time decay is not present in the linear model. Next, when the ratio is small, the mode decays to zero at a time of order 1/λ, which is the same behavior as that of a linear model.

(a)
(b)
(c)
Figure 5: (a) The theoretical evolution of an infinitely wide 2-layer network (k = 2) with L_2 regularization. Two modes are shown, representing small and large ratios λ_i/λ. (b) The same, for a linear model (k = 1). (c) Training loss vs. time for a wide network trained on a subset of MNIST with even/odd labels, with λ > 0. We compare the kernel evolution with gradient descent for a 2-layer ReLU network. The blue and orange curves are the theoretical predictions obtained by setting k = 1 and k = 2 in the solution (3), respectively. The green curve is the result of a numerical experiment where we train a 2-layer ReLU network with gradient descent. We attribute the difference between the green and orange curves at late times to finite width effects.

Generalization of wide networks with L_2 regularization.

It is interesting to understand how L_2 regularization affects the generalization performance of wide networks. For linear models, which correspond to k = 1 in our notation, this is well understood to be an instance of the bias-variance tradeoff. In this case, gradient flow converges to the kernel ridge-regression function of the training samples X, the labels Y, and any input x (written out below). When λ = 0, the solution is highly sensitive to small input perturbations that affect the flat modes of the kernel, because the kernel is inverted in the solution. In other words, the solution has high variance. Choosing λ > 0 reduces variance by lifting the low kernel eigenvalues and reducing the sensitivity to small perturbations, at the cost of biasing the model parameters toward zero. While a linear model is the prototypical case of k = 1, the late-time solution above is valid for any k = 1 model. In particular, any homogeneous model that has batch normalization in the pre-logit layer will satisfy this property. It would be interesting to understand the generalization properties of such models based on these solutions.
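For reference, the k = 1 late-time solution referred to here is the familiar kernel ridge-regression form (up to the normalization convention for λ):

\[
f_\infty(x) = \Theta(x, X)\,\big(\Theta(X, X) + \lambda\,\mathbf{1}\big)^{-1} Y ,
\]

where Θ(X, X) is the kernel evaluated on the training inputs, Y are the training labels, and Θ(x, X) is the vector of kernel values between a test input x and the training inputs.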

Let us now return to infinitely wide networks. These behave like linear models with a fixed kernel when λ = 0, but as we have seen, for λ > 0 the kernel decays exponentially. Nevertheless, we argue that this decay is slow enough that the training dynamics follow those of the linear model (obtained by setting k = 1 in eq. (1)) up until a time of order 1/λ, when the function begins decaying to zero. This can be seen in Figure 5(c), which compares the training curves of a linear model and a 2-layer network using the same kernel. We see that the agreement extends until the linear model is almost fully trained, at which point the 2-layer model begins deteriorating due to the late-time decay. Therefore, if we stop training the 2-layer network at the loss minimum, we end up with a trained and regularized model. It would be interesting to understand how the generalization properties of this model with a decaying kernel differ from those of the linear model.

Finite-width network.

Theorem 1 holds in the strict large-width limit for NTK parameterization. At large but finite width, we expect (1) to be a good description of the training trajectory at early times, until the kernel and the function become small enough that the finite-width corrections are no longer negligible. Our experimental results imply that this approximation remains good until after the minimum in the loss, but that at late times the function will not decay to zero; see for example Figure 5(c). See the SM for further discussion of the case of deep linear models. We reserve a more careful study of these finite width effects for future work.

5 Discussion

In this work we consider the effect of L_2 regularization on overparameterized networks. We make two empirical observations: (1) the time it takes the network to reach peak performance is proportional to 1/λ, where λ is the regularization parameter, and (2) the performance reached in this way is independent of λ when λ is not too large. We find that these observations hold for a variety of overparameterized training setups; see the SM for some examples where they do not hold.

Motivated by these observations, we suggest two practical applications. The first is a simple method for predicting the optimal λ parameter at a given training budget. The performance obtained using this prediction is close to that of a tuned parameter, at a fraction of the training cost. The second is AutoL2, an automatic λ schedule. In our experiments, this method leads to better performance and faster training when compared against training with a tuned, constant λ parameter. We find that these proposals work well when training with a constant learning rate; we leave an extension of these methods to networks trained with learning rate schedules to future work.

We attempt to understand the empirical observations by analyzing the training trajectory of infinitely wide networks trained with L_2 regularization. We derive the differential equations governing this trajectory, and solve them explicitly for MSE loss. The solution reproduces the observation that the time to peak performance is of order 1/λ. This is due to an effect that is specific to deep networks and is not present in linear models: during training, the kernel (which is constant for linear models) decays exponentially due to the L_2 term.

Acknowledgments

The authors would like to thank Yasaman Bahri, Ethan Dyer, Jaehoon Lee and Behnam Neyshabur for useful discussions. We especially thank Behnam for encouraging us to use the 1/λ scaling to come up with a λ schedule.

References

Supplementary material

Appendix A Experimental details

We are using JAX [Bradbury et al., 2018].

All the models, except those in Section C.4, have been trained with a softmax cross-entropy loss normalized by the number of classes, using one-hot targets.

All experiments that compare different learning rates and λ parameters use the same seed for the weights at initialization, and we consider only one such initialization (unless otherwise stated), although we have not seen much variance in the phenomena described. We use standard normalization with LeCun initialization.

Batch Norm: we are using JAX's Stax implementation of Batch Norm, which doesn't keep track of training batch statistics for test-mode evaluation.

Data augmentation denotes flip, crop, and mixup.

We consider 3 different networks:

  • WRN: Wide ResNet 28-10 [Zagoruyko and Komodakis, 2016] with batch normalization, trained with a fixed batch size (split per device). Trained on CIFAR-10.

  • FC: Fully-connected network with three hidden layers of equal width, ReLU activations, and a fixed batch size. Trained on 512 samples of MNIST.

  • CNN: A stack of convolutional layers (with 'SAME' or 'VALID' padding and a fixed filter size), MaxPool((2,2), 'VALID') layers performing max pooling with 'VALID' padding and a (2,2) window size, and a final fully-connected layer producing the logits. Act denotes the activation, either 'Batch-Norm then ReLU' or 'ReLU', depending on whether we use batch normalization. Trained on CIFAR-10 without data augmentation, with a fixed batch size.

The WRN experiments are run on v3-8 TPUs and the rest on P100 GPUs.

Here we describe the particularities of each figure. Whenever we report performance for a given time budget, we report the maximum performance during training which does not have to happen at the end of training.

Figure 1(a): WRN trained using momentum, data augmentation, and a learning rate schedule in which the learning rate is held constant and then decayed at fixed fractions of the total number of epochs T. We compare training with a fixed training budget against training with a budget proportional to 1/λ.

Figures 1(b), 4, S4: WRN trained using momentum and data augmentation at a fixed learning rate. The predicted performance in Figure 1(b) was computed at the predicted λ for each training budget.

Figures 1(c), S9: WRN trained using momentum and data augmentation at a fixed learning rate. The AutoL2 algorithm is written explicitly in SM E and makes measurements every 10 steps.

Figure 2a,b,c: FC trained using SGD over a sweep of learning rates and regularization values λ. The models were evolved for a number of epochs exceeding the time to peak of the smallest λ.

Figure 2d,e,f: WRN trained using SGD without data augmentation, sweeping over learning rates and λ; smaller values of λ were evolved for a correspondingly larger number of epochs.

Figure S7: Fully-connected network of fixed depth and width trained on CIFAR-10 with cross-entropy loss.

Figure S8: ResNet-50 trained on ImageNet with batch size 8192, using the implementation in https://github.com/tensorflow/tpu.

Figure 5: (a,b) Plots of equation (3) with k = 2 (for the 2-layer network) and k = 1 (for the linear model), for different values of λ_i and λ. (c) The empirical kernel of a 2-layer ReLU network of width 5,000 was evaluated on a subset of MNIST samples with even/odd labels. The linear and 2-layer curves come from evolving equation (1) with this kernel and setting k = 1 and k = 2, respectively. The experimental curve comes from training the 2-layer ReLU network with gradient descent at the same width and a small learning rate, with time measured in gradient-flow units.

Figure 3a,b,c: CNN without BN trained using SGD over a sweep of learning rates and λ; the smallest λ was evolved for a larger number of epochs.

Figure 3d,e,f: CNN with BN trained using SGD over a sweep of learning rates and λ. The model with the smallest λ was evolved for a number of epochs that goes beyond where all the other λ's have peaked.

Figure S5: FC trained using SGD and MSE loss over a sweep of learning rates and λ; the smallest λ was trained for a larger number of epochs.

Rest of SM figures. Small modifications of experiments in previous figures, specified explicitly in captions.

Appendix B Details of theoretical results

In this section we prove the main theoretical results. We begin with two technical lemmas that apply to k-homogeneous network functions, namely network functions that obey f(x; aθ) = a^k f(x; θ) for any input x, parameter vector θ, and a > 0.

Lemma 1.

Let f be a k-homogeneous network function, and let T_{μ_1⋯μ_p} = ∂_{μ_1}⋯∂_{μ_p} f denote a p-th derivative tensor with respect to the parameters. Then θᵀ∂_θ T_{μ_1⋯μ_p} = (k − p) T_{μ_1⋯μ_p}; in particular, for p = 0, θᵀ∂_θ f = k f.

Proof.

We prove the claim by induction on p. For p = 0, we differentiate the homogeneity equation f(x; aθ) = a^k f(x; θ) with respect to a and set a = 1, obtaining

(S1)  θᵀ∂_θ f(x; θ) = k f(x; θ).

For p ≥ 1, we assume the identity holds for p − 1 and differentiate it with respect to θ_{μ_p}:

(S2)  ∂_{μ_p}(θᵀ∂_θ T_{μ_1⋯μ_{p−1}}) = (k − p + 1) ∂_{μ_p} T_{μ_1⋯μ_{p−1}},
(S3)  ∂_{μ_p} T_{μ_1⋯μ_{p−1}} + θᵀ∂_θ (∂_{μ_p} T_{μ_1⋯μ_{p−1}}) = (k − p + 1) ∂_{μ_p} T_{μ_1⋯μ_{p−1}},
(S4)  θᵀ∂_θ T_{μ_1⋯μ_p} = (k − p) T_{μ_1⋯μ_p}. ∎

Lemma 2.

Consider a k-homogeneous network function f, and a correlation function C that involves derivative tensors of f. Let L be a loss function defined as an average of the sample loss ℓ over the training set D, possibly together with an L_2 term. We train the network using gradient flow on this loss function, with update rule dθ/dt = −∂_θ L. If the conjecture of Dyer and Gur-Ari [2020] holds, and if the conjecture implies that C = O(n^{−a}) where n is the width, then dC/dt = O(n^{−a}) as well.

Proof.

The cluster graph of C has one vertex for each derivative tensor; we denote by n_e (n_o) the number of even (odd) components in the graph (we refer the reader to Dyer and Gur-Ari [2020] for the definition of the cluster graph and other terminology used in this proof). By assumption, the conjecture bounds C by a power of the width determined by n_e and n_o.

We can write the correlation function as an expectation value over initializations of a product of derivative tensors T^(1)⋯T^(m), each of the form ∂_{μ_1}⋯∂_{μ_p} f for some p, summed over all the free indices of the derivative tensors. Then dC/dt = Σ_i C_i, where C_i is obtained by replacing T^(i) with its time derivative. To bound the asymptotic behavior of dC/dt it is therefore enough to bound the asymptotics of each C_i.

Notice that each C_i is obtained from C by replacing a derivative tensor T^(i) with its time derivative inside the expectation value. Let us see how this affects the cluster graph. For any derivative tensor T, the gradient flow update gives

(S5)
(S6)

a sum of two terms: one in which T is multiplied by additional derivative tensors coming from the gradient of the empirical loss, and one proportional to λT itself. In the last step we used Lemma 1. We now compute how replacing the derivative tensor T by each of the terms in the last line of (S6) affects the cluster graph, and specifically the combination of n_e and n_o that controls the large-width bound.

The second term is equal to the original derivative tensor up to a width-independent factor, and therefore does not change the asymptotic behavior. For the first term, the sample-loss factor leaves n_e and n_o invariant, so it does not affect the asymptotic behavior either. The additional derivative-tensor factor increases the number of vertices in the cluster graph by 1. In addition, it increases the size of the graph component it attaches to by 1, therefore either turning an even-sized component into an odd-sized one or vice versa. The net effect leaves the large-width bound unchanged. Therefore, it follows from the conjecture that C_i = O(n^{−a}) for all i, and then dC/dt = O(n^{−a}). ∎

We now turn to the proof of Theorems 1 and 2.

Proof (Theorem 1).

A straightforward calculation leads to the following gradient flow equations for the network function and kernel.

(S7)
(S8)
(S9)

In deriving these we used the gradient flow update and Lemma 1. It was shown in Dyer and Gur-Ari [2020] that the correction terms appearing in these equations are suppressed by inverse powers of the width at initialization. It then follows from Lemma 2 that they remain suppressed at all training times, where the expectation value is taken at initialization. Furthermore, the results of Dyer and Gur-Ari [2020] imply that the fluctuations of these terms are also suppressed (see appendix D in Dyer and Gur-Ari [2020]). In the strict infinite width limit we can therefore neglect their contribution and write

(S10)

The solution of this set of equations (labelled by ) is the same as for the any-time equation , and the solution is given by

(S11)

Proof (Theorem 2).

The evolution of the kernel eigenvalues, and the fact that the eigenvectors do not evolve, follow immediately from (2). The solution (3) can be verified directly by plugging it into (1) after projecting the equation onto the eigenvector ê_i. Finally, the fact that the function decays to zero at late times can be seen from (3) as follows. Using the assumption λ > 0, each mode can be bounded as follows.

(S12)

Therefore, every mode decays to zero at late times. ∎

For completeness we now write down the solution (3) in functional form, for inputs in the training set.

(S13)
(S14)
(S15)

Here, the exponential denotes a matrix exponential, and the kernel is treated as a matrix whose size is set by the number of training samples.

B.1 Deep linear fixed point analysis

Let’s consider a deep linear model , with for NTK normalization and for standard normalization. The gradient descent equation will be:

(S16)

where we defined:

(S17)

Evolution will stop when a fixed point (where the parameter update vanishes) is reached:

(S18)
(S19)

Furthermore note that:

(S20)

Now, we would like to show that, at the fixed point

(S21)

This follows from induction:

(S22)
(S23)

This has a trivial solution if the parameters vanish. Let us assume that the solution is non-trivial. If we contract the previous equation with the parameters, we get:

(S24)

We can finally set and simplify:

(S25)

At large , to obtain a non-trivial fixed point should be finite as . From the previous equation, this implies that . In NTK normalization , for , we will get a non-trivial () fixed point. This also implies that these corrections will be important for in standard normalization since there . Note that if , we expect that we get the solution .

We can be very explicit if we consider a simple special case with one training sample and MSE loss. The fixed point has a logit:

(S26)

which is only different from for fixed .

Appendix C More on experiments

C.1 Time scale to reach training accuracy 1

We can see how the time it takes to reach training accuracy 1 depends very mildly on λ, and for small enough learning rates it scales like 1/η, the inverse learning rate.

(a) FC
(b) WRN
Figure S1: Time to reach training accuracy 1 vs. learning rate, in the setup of Figure 2. The specific values for the sweeps are in SM A.

C.2 More WRN experiments

We can also study the previous observations in the presence of momentum and data augmentation. These are the experiments that we used in Figure 4, evolved until convergence. As discussed before, in the presence of momentum the time scale depends on the learning rate in a more complicated way, so we fix the learning rate.

(a)
(b)
(c)
Figure S2: WRN 28-10 with momentum and data augmentation trained with a fixed learning rate.

C.3 More on the optimal λ

Here we give more details about the optimal λ prediction of Section 3. Figure S3 illustrates how performance changes as a function of λ for different time budgets, with the predicted λ marked by a dashed line. If one wanted to be more precise: from Figure 2 we see that while the 1/λ scaling works across λ's, the lower λ's generally have a scaling coefficient about a factor of two higher than the larger λ's. One could therefore try to get a more precise prediction by multiplying the predicted λ by two; see Figure S4. We reserve a more detailed analysis of this more fine-grained prescription for the future.

(a)
Figure S3: WRN trained with momentum and data augmentation. For each number of training epochs, we show the maximum test accuracy as a function of λ and compare the best λ with the predicted one. We see that this gives us the optimal λ within an order of magnitude.
(a)
(b)
(c)
Figure S4: Same as the previous figure, with the predicted λ multiplied by two.

C.4 MSE and the catapult effect

In Lewkowycz et al. [2020] it was argued that, in the absence of L_2 regularization, when training a network with SGD and MSE loss, large learning rates reach a rather different final accuracy, because at early times they undergo the "catapult effect". This seems to contradict the picture of Section 1.1, where we argue that performance doesn't depend strongly on the learning rate. In Figure S5 we see that, while performance depends strongly on the learning rate when training is stopped at training accuracy 1, this is no longer the case in the presence of L_2 regularization if we evolve for a time proportional to 1/λ. We also show how the training MSE loss has a minimum after which it increases.

(a)
(b)
(c)
(d)
Figure S5: MSE loss and the catapult effect. Even though there is a strong dependence of the test accuracy on the learning rate when the training accuracy first reaches 1, this dependence flattens out when the models are evolved until convergence in the presence of L_2 regularization. The specific values for the sweeps are in SM A.

C.5 Dynamics of loss and accuracy

In figure S6 we illustrate the training curves of the experiments we have discussed in the main text and SM.

Figure S6: Training curves for the FC, WRN, and CNN (no BN) experiments discussed in the main text and the SM.