Time Matters in Regularizing Deep Networks: Weight Decay and Data Augmentation Affect Early Learning Dynamics, Matter Little Near Convergence

05/30/2019 ∙ by Aditya Golatkar, et al. ∙ 0

Regularization is typically understood as improving generalization by altering the landscape of local extrema to which the model eventually converges. Deep neural networks (DNNs), however, challenge this view: We show that removing regularization after an initial transient period has little effect on generalization, even if the final loss landscape is the same as if there had been no regularization. In some cases, generalization even improves after interrupting regularization. Conversely, if regularization is applied only after the initial transient, it has no effect on the final solution, whose generalization gap is as bad as if regularization never happened. This suggests that what matters for training deep networks is not just whether or how, but when to regularize. The phenomena we observe are manifest in different datasets (CIFAR-10, CIFAR-100), different architectures (ResNet-18, All-CNN), different regularization methods (weight decay, data augmentation), different learning rate schedules (exponential, piece-wise constant). They collectively suggest that there is a "critical period" for regularizing deep networks that is decisive of the final performance. More analysis should, therefore, focus on the transient rather than asymptotic behavior of learning.


1 Introduction

There is no shortage of literature on what regularizers to use when training deep neural networks and how they affect the loss landscape but, to the best of our knowledge, no work has addressed when to apply regularization. We test the hypothesis that applying regularization at different epochs of training can yield different outcomes. Our curiosity stems from recent observations suggesting that the early epochs of training are decisive of the outcome of learning with a deep neural network (achille2017critical).

We find that regularization via weight decay or data augmentation has the same effect on generalization when applied only during the initial epochs of training as when applied throughout. Conversely, if regularization is applied only in the latter phase of convergence, it has little effect on the final solution, whose generalization is as bad as if regularization never happened. This suggests that, contrary to classical models, the mechanism by which regularization affects generalization in deep networks is not by changing the landscape of critical points at convergence, but by influencing the early transient of learning. This is unlike convex optimization (linear regression, support vector machines) where the transient is irrelevant.

In short, what matters for training deep networks is not just whether or how, but when to regularize.

In particular, the effect of temporary regularization on the final performance is maximal during an initial “critical period.” This mimics other phenomena affecting the learning process which, albeit temporary, can permanently affect the final outcome if applied at the right time, as observed in a variety of learning systems, from artificial deep neural networks to biological ones. We use the methodology of achille2017critical to regress the most critical epochs for various architectures and datasets.

Specifically, our findings are:

  1. Applying weight decay or data augmentation beyond the initial transient of training does not improve generalization (Figure 1, Left). The transient is decisive of asymptotic performance.

  2. Applying regularization only during the final phases of convergence does not improve, and in some cases degrades, generalization. Hence, regularization in deep networks does not work by re-shaping the loss function at convergence (Figure 1, Center).

  3. Applying regularization only during a short sliding window shows that its effect is most pronounced during a critical period of a few epochs (Figure 1, Right). Hence, the analysis of regularization in Deep Learning should focus on the transient, rather than asymptotics.

The explanation for these phenomena is not as simple as the solution being stuck in some local minimum: When turning regularization on or off after the critical period, the value of the weights changes, so the solution moves in the loss landscape. However, test accuracy, hence generalization, does not change. Adding regularization after the critical period does change the loss function, and also changes the final solution, but not for the better. Thus, the role of regularization is not to bias the final solution towards critical points with better generalization. Instead, it is to bias the initial transient towards regions of the loss landscape that contain multiple equivalent solutions with good generalization properties.

In the next section we place our observations in the context of prior related work, then introduce some of the nomenclature and notation (Sect. 3) before describing our experiments in Sect. 4. We discuss the results in Sect. 5.

2 Related Work

There is a considerable volume of work addressing regularization in deep networks, too vast to review here. Most of the efforts are towards analyzing the geometry and topology of the loss landscape at convergence. Work relating the local curvature of the loss around the point of convergence to regularization (“flat minima”; hochreiter1997flat; keskar2016large; dinh2017sharp; chaudhari2016entropy) has been especially influential (choromanska2015loss; li2018visualizing). Other work addresses the topological characteristics of the point of convergence (minima vs. saddles; dauphin2014identifying). (jastrzkebski2017three; NIPS2017_6770) discuss the effects of the learning rate and batch size on stochastic gradient descent (SGD) dynamics and generalization. At the other end of the spectrum, there is complementary work addressing initialization of deep networks (glorot2010understanding; henaff16). There is limited work addressing the timing of regularization, other than for the scheduling of learning rates (smith2017cyclical; loshchilov2016sgdr).

Changing the regularizer during training is common practice in many fields, and can be done in a variety of ways, either pre-scheduled – as in homotopy continuation methods (mobahi2015theoretical), or in a manner that depends on the state of learning – as in adaptive regularization (hong2017adaptive). For example, in variational stereo-view reconstruction, regularization of the reconstruction loss is typically varied during the optimization, starting with high regularization and, ideally, ending with no regularization. This is quite unlike the case of Deep Learning: Stereo is ill-posed, as the object of inference (the disparity field) is infinite-dimensional and not smooth due to occluding boundaries. So, ideally one would not want to impose regularization, except for wading through the myriad of local minima due to local self-similarity in images. Imposing regularization all along, however, causes over-smoothing, whereas the ground-truth disparity field is typically discontinuous. So, regularization is introduced initially and then removed to capture fine details. In other words, the ideal loss is not regularized, and regularization is introduced artificially to improve transient performance. In the case of machine learning, regularization is often interpreted as a prior on the solution. Thus, regularization is part of the problem formulation, rather than the mechanics of its solution.

Also related to our work, there have been attempts to interpret the mechanisms of action of certain regularization methods, such as weight decay (zhang2018three; van2017l2; loshchilov2017fixing; hoffer2018norm; krogh1992simple; bos1996using), data augmentation (vapnik2000vicinal), and dropout (srivastava2014dropout). It has been pointed out in (zhang2018three) that the Gauss-Newton norm correlates with generalization, and with the Fisher Information Matrix (fisher1925theory; amari1998natural), a measure of the flatness of the minimum, to conclude that the Fisher Information at convergence correlates with generalization. However, no causal link has been proven. In fact, we suggest this correlation may be an epiphenomenon: Weight decay causes an increase in Fisher Information during the transient, which is responsible for generalization (Figure 5), whereas the asymptotic value of the Fisher norm (i.e., sharpness of the minimum) is not causative. In particular, we show that increasing Fisher Information can actually improve generalization.

Figure 1: Critical periods for regularization in DNNs: (Left) Final test accuracy as a function of the epoch in which the regularizer is removed. Applying regularization beyond the initial transient of training (around 100 epochs) produces no appreciable increase in the test accuracy. In some cases, early removal of regularization, e.g., at epoch 75 for All-CNN, actually improves generalization. Despite the loss landscape at convergence being un-regularized, the network achieves accuracy comparable to a regularized one. (Center) Final test accuracy as a function of the onset of regularization. Applying regularization after the initial transient changes the convergence point (Fig. 2, B), but does not improve generalization. Thus, regularization does not influence generalization by re-shaping the loss landscape near the eventual solution. Instead, regularization biases the solution towards regions with good generalization properties during the initial transient. Weight decay (blue) shows a more marked time dependency than data augmentation (orange). The dashed line (green) in (Left) and (Center) corresponds to the final accuracy when we regularize throughout the training. (Right) Sensitivity (change in the final accuracy relative to un-regularized training) as a function of the onset of a 50-epoch regularization window. Initial learning epochs are more sensitive to weight decay compared to the intermediate training epochs for data augmentation. The shape of the sensitivity curve depends on the regularization scheme as well as the network architecture. The error bars indicate thrice the standard deviation across 5 independent trials. For experiments with weight decay (or data augmentation), we apply data augmentation (or weight decay) throughout the training.

3 Preliminaries and notation

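Given an observed input $x$ (e.g., an image) and a random variable $y$ we are trying to infer (e.g., a discrete label), we denote with $p_w(y|x)$ the output distribution of a deep network parameterized by weights $w$. For discrete $y$, we usually have $p_w(y|x) = \mathrm{softmax}(f_w(x))$ for some parametric function $f_w(x)$. Given a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N$, the cross-entropy loss of the network $p_w(y|x)$ on the dataset $\mathcal{D}$ is defined as $L_{\mathcal{D}}(w) := \frac{1}{N} \sum_{i=1}^N \ell(y_i, f_w(x_i)) = -\frac{1}{N} \sum_{i=1}^N \log p_w(y_i|x_i)$.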

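When minimizing $L_{\mathcal{D}}(w)$ with stochastic gradient descent (SGD), we update the weights $w$ with an estimate of the gradient computed from a small number of samples (mini-batch). That is,

$$w_{t+1} \leftarrow w_t - \eta \, \frac{1}{B} \sum_{i \in \xi_t} \nabla_w \ell(y_i, f_{w_t}(x_i)),$$

where $\eta$ is the learning rate and $\xi_t \subseteq \{1, \dots, N\}$ is a random subset of indices of size $|\xi_t| = B$ (mini-batch size). In our implementation, weight decay (WD) is equivalent to imposing a penalty to the $L_2$ norm of the weights, so that we minimize the regularized loss $L_{\mathcal{D}}(w) + \frac{\lambda}{2} \|w\|^2$.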
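As a minimal sketch of the update described above, assume the L2 penalty enters the loss as (λ/2)‖w‖², so that it contributes λw to the gradient. The function name, the plain-list representation of the weights and the toy values are illustrative assumptions, not the authors' implementation, and momentum is omitted for brevity.

```python
def sgd_step(weights, grad, lr, weight_decay=0.0):
    """One plain SGD step on the loss L_D(w) + (weight_decay / 2) * ||w||^2.

    `grad` is the mini-batch estimate of the gradient of the unregularized
    loss; the L2 penalty contributes an extra `weight_decay * w` term.
    """
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grad)]


# Toy check: with a zero data gradient, weight decay only shrinks the weights.
w = [1.0, -2.0]
w_next = sgd_step(w, grad=[0.0, 0.0], lr=0.1, weight_decay=0.0005)
```

Setting `weight_decay=0.0` recovers the unregularized update, which is all that "switching off" the regularizer amounts to in this sketch.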

Data augmentation (DA) expands the training set by choosing a set of random transformations of the data, $x' = g(x)$ (e.g., random translations, rotations, reflections of the domain and affine transformations of the range of the images), sampled from a known distribution $P_g$, to yield $\mathcal{D}'(g) = \{(g_j(x_i), y_i)\}_{g_j \sim P_g}$.
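The kind of random transformation used for data augmentation can be sketched as follows for a single-channel image stored as a list of rows. The zero-padding of 4 pixels before cropping is an assumption (a common choice for CIFAR-sized images that the text does not specify), and the function name is hypothetical.

```python
import random


def random_crop_and_flip(image, pad=4, rng=random):
    """Return a randomly shifted and possibly mirrored copy of `image`.

    `image` is a list of rows (each a list of pixel values). The image is
    zero-padded by `pad` pixels on every side, a window of the original size
    is cropped at a random offset, and the result is flipped left-right with
    probability 0.5. The padding amount is an assumption, not from the paper.
    """
    height, width = len(image), len(image[0])
    padded = [[0] * (width + 2 * pad) for _ in range(pad)]
    padded += [[0] * pad + list(row) + [0] * pad for row in image]
    padded += [[0] * (width + 2 * pad) for _ in range(pad)]

    top = rng.randint(0, 2 * pad)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + width] for row in padded[top:top + height]]

    if rng.random() < 0.5:
        crop = [row[::-1] for row in crop]
    return crop
```

Each call draws a fresh transformation, so the labels are untouched and only the inputs vary from epoch to epoch.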

In our experiments, we choose $g$ to be random cropping and horizontal flipping (reflections) of the images; $\mathcal{D}$ are the CIFAR-10 and CIFAR-100 datasets (krizhevsky2009learning), and the class of functions $f_w$ are ResNet-18 (he2016deep) and All-CNN (springenberg2014striving). For all experiments, unless otherwise noted, we train with SGD with momentum 0.9 and exponentially decaying learning rate with factor $\gamma = 0.97$ per epoch, starting from learning rate $\eta = 0.1$ (see also Appendix A).

Figure 2: Intermediate application or removal of regularization affects the final solution: (A-C) Norm of the weights as a function of the training epoch (corresponding to Figure 1 (Top)). The weights of the network move after application or removal of regularization, which can be seen by the change in their norm. The correlation between the norm of the weights and generalization properties is not as straightforward as lower norm implying better generalization. For instance, (C) applying weight decay only at the beginning (curve 0) reduces the norm only during the critical period, and yields a higher norm asymptotically than, for example, curve 25. Yet it has better generalization. This suggests that having a lower norm mostly helps only during the critical period. (D) PCA-projection of the training paths obtained by removing weight decay at different times (see Section A.1). Removing WD before the end of the critical period (curves 25, 50) makes the network converge to different regions of the parameter space. Removing WD after the critical period (curves 75 to 200) still appreciably changes the final point (in particular, critical periods are not due to the optimization being stuck in a local minimum), but all points lie in a similar area, supporting the Critical Period interpretation of (achille2017critical). (E) Same plots, but for DA, which unlike WD does not have a sharp critical period: all training paths converge to a similar area.

4 Experiments

To test the hypothesis that regularization can have different effects when applied at different epochs of training, we perform three kinds of experiments. In the first, we apply regularization up to a certain point, and then switch off the regularizer. In the second, we initially forgo regularization, and switch it on only after a certain number of epochs. In the third, we apply regularization for a short window during the training process. We describe these three experiments in order, before discussing the effect of batch normalization, and analyzing changes in the loss landscape during training using local curvature (Fisher Information).
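The three protocols differ only in the epochs during which the regularizer is active, which can be sketched as a simple predicate. The mode names, the 0-based epoch convention and the helper names are illustrative assumptions; the default coefficient 0.0005 is the ResNet-18 weight decay value given in Appendix A.

```python
def regularizer_is_on(epoch, mode, t0, window=50):
    """Decide whether the regularizer is active at a given (0-based) epoch.

    mode = "interrupted": regularize for the first t0 epochs, then stop.
    mode = "delayed":     start regularizing at epoch t0 and never stop.
    mode = "window":      regularize only for `window` epochs starting at t0.
    """
    if mode == "interrupted":
        return epoch < t0
    if mode == "delayed":
        return epoch >= t0
    if mode == "window":
        return t0 <= epoch < t0 + window
    raise ValueError(f"unknown mode: {mode!r}")


def weight_decay_at(epoch, mode, t0, coefficient=0.0005):
    """Weight-decay coefficient to use at `epoch` (0.0 when switched off)."""
    return coefficient if regularizer_is_on(epoch, mode, t0) else 0.0
```

The same predicate could gate data augmentation, by feeding either the augmented or the original training set to the optimizer at each epoch.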

Regularization interrupted.

We train standard DNN architectures (ResNet-18/All-CNN on CIFAR-10) using weight decay (WD) during the first $t_0$ epochs, then continue without WD. Similarly, we augment the dataset (DA) up to $t_0$ epochs, past which we revert to the original training set. We train both architectures for 200 epochs. In all cases, the training loss converges to essentially zero for all values of $t_0$. We then examine the final test accuracy as a function of $t_0$ (Figure 1, Left). We observe that applying regularization beyond the initial transient (around 100 epochs) produces no measurable improvement in generalization (test accuracy). In Figure 3 (Left), we observe similar results for a different data distribution (CIFAR-100). Surprisingly, limiting regularization to the initial learning epochs yields final test accuracy that is as good as that achieved by regularizing to the end, even if the final loss landscapes, and hence the minima encountered at convergence, are different.

It is tempting to ascribe the imperviousness to regularization in the latter epochs of training (Figure 1, Left) to the optimization being stuck in a local minimum. After all, the decreased learning rate, or the shape of the loss around the minimum, could prevent the solution from moving. However, Figure 2 (A, curves 75/100) shows that the norm of the weights changes significantly after switching off the regularizer: the optimization is not stuck. The point of convergence does change, just not in a way that improves test accuracy.

The fact that applying regularization only at the very beginning yields comparable results, suggests that regularization matters not because it alters the shape of the loss function at convergence, reducing convergence to spurious minimizers, but rather because it “directs” the initial phase of training towards regions with multiple extrema with similar generalization properties. Once the network enters such a region, removing regularization causes the solution to move to different extrema, with no appreciable change in test accuracy.

Figure 3: (Top) Critical periods for regularization are independent of the data distribution: We repeat the same experiment as in Figure 1 on CIFAR-100. We observe that the results are consistent with Figure 1. The dashed line (green) in (Left) and (Right) denotes the final accuracy when regularization is applied throughout the training. The dashed line on top corresponds to ResNet-18, while the one below it corresponds to All-CNN. (Bottom) Critical regularization periods with a piecewise constant learning rate schedule: We repeat the experiment in Figure 1, but change the learning rate schedule. Networks trained with a piecewise constant learning rate exhibit behavior that is qualitatively similar to those trained with an exponentially decaying learning rate. The same experiment with a constant learning rate is inconclusive since the network does not converge (see Appendix, Figure 8).

Regularization delayed.

In this experiment, we switch on regularization starting at some epoch $t_0$, and continue training to convergence. We train the DNNs for 200 epochs, except when regularization is applied late (from epoch 150/175), where we allow the training to continue for an additional 50 epochs to ensure the network’s convergence. Figure 1 (Center) displays the final accuracy as a function of the onset $t_0$, which shows that there is a “critical period” to perform regularization (around epoch 50), beyond which adding a regularizer yields no benefit.

Absence of regularization can be thought of as a form of learning deficit. The permanent effect of temporary deficits during the early phases of learning has been documented across different tasks and systems, both biological and artificial (achille2017critical). Critical periods thus appear to be fundamental phenomena, not just quirks of biology or the choice of the dataset, architecture, learning rate, or other hyperparameters in deep networks.

In Figure 1 (Top Center), we see that delaying WD by 50 epochs causes a 40% increase in test error, from 5% when regularizing all along to 7% with onset $t_0 = 50$ epochs. This is despite the two optimization problems sharing the same loss landscape at convergence. This reinforces the intuition that WD does not improve generalization by modifying the loss function; otherwise Figure 1 (Center) would show an increase in test accuracy after the onset of regularization.
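Assuming the 40% figure denotes a relative rather than an absolute increase, it follows directly from the two test error rates quoted above:

$$\frac{7\% - 5\%}{5\%} = 0.4 = 40\%.$$

In absolute terms the gap is 2 percentage points of test error.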

Here, too, we see that the optimization is not stuck in a local minimum: Figure 2 (B) shows the weights changing even after late onset of regularization. Unlike the previous case, in the absence of regularization, the network enters prematurely into regions with multiple sub-optimal local extrema, seen in the flat part of the curve in Figure 1 (Center).

Note that the magnitude of critical period effects depends on the kind of regularization. Figure 1 (Center) shows that WD exhibits more significant critical period behavior than DA. At convergence, data augmentation is more effective than weight decay. In Figure 3 (Center), we observe critical periods for DNNs trained on CIFAR-100, suggesting that they are independent of the data distribution.

Figure 4: Critical periods for regularization are independent of Batch-Normalization: We repeat the same experiment as in Figure 1, but without Batch-Normalization. The results are largely compatible with previous experiments, suggesting that the effects are not caused by the interaction between batch normalization and regularization. (Left) Notice that, surprisingly, removal of weight decay right after the initial critical period actually improves generalization. (Center) Data augmentation in this setting shows a more marked dependency on timing. (Right) Unlike weight decay which mainly affects initial epochs, data augmentation is critical for the intermediate epochs.

Sliding Window Regularization.

In an effort to regress which phase of learning is most impacted by regularization, we compute the maximum sensitivity against a sliding window of 50 epochs during which WD and DA are applied (Figure 1 Right). The early epochs are the most sensitive, and regularizing for as few as 50 epochs yields generalization that is almost as good as if we had regularized all along. This captures the critical period for regularization. Note that the shape of the sensitivity to critical periods depends on the type of regularization: Data augmentation has essentially the same effect throughout training, whereas weight decay critically impacts only the initial epochs. Similar to the previous experiments, we train the networks for 200 epochs, except when the window onset is late (epoch 125/150/175), where we train for 50 additional epochs after the end of the regularization window to ensure that the network converges.

Reshaping the loss landscape.

Regularization is classically understood as trading classification loss against the norm of the parameters (weights), which is a simple proxy for model complexity. The effects of such a tradeoff on generalization are established in classical models such as linear regression or support-vector machines. However, DNNs need not trade classification accuracy for the norm of the weights, as evident from the fact that the training error can always reach zero regardless of the amount of regularization. Current explanations (goodfellow2016deep) are based on asymptotic convergence properties, that is, on the effect of regularization on the loss landscape and the minima to which the optimization converges. In fact, for any learning algorithm that reduces to a convex problem, this is the only possible effect. However, Figure 1 shows that for DNNs, the critical role of regularization is to change the dynamics of the initial transient, which biases the model towards regions with good generalization. This can be seen in Figure 1 (Left), where despite halting regularization after 100 epochs, thus letting the model converge in the un-regularized loss landscape, the network achieves around 5% test error. Also in Figure 1 (Top Center), despite applying regularization after 50 epochs, thus converging in the regularized loss landscape, the DNN generalizes poorly (around 7% error). Thus, while there is reshaping of the loss landscape at convergence, this is not the mechanism by which deep networks achieve generalization. It is commonly believed that a smaller norm of the weights at convergence implies better generalization (wilson2017marginal; neyshabur2017pac). Our experiments show no such causation: Slight changes of the training algorithm can yield solutions with larger norm that generalize better (Figure 2, (C) & Figure 1, Top right: onset epoch 0 vs 25/50).

Effect of Batch-Normalization.

One would expect regularization to be ineffective when used in conjunction with Batch-Normalization (BN) (ioffe2015batch), since BN makes the network’s output invariant to changes in the norm of its weights. However, it has been observed that, in practice, WD improves generalization even, or especially, when used with BN. Several authors (zhang2018three; hoffer2018norm; van2017l2) have observed that WD increases the effective learning rate $\eta_{\mathrm{eff},t} = \eta_t / \|w_t\|_2^2$, where $\eta_t$ is the learning rate at epoch $t$ and $\|w_t\|_2^2$ is the squared-norm of weights at epoch $t$, by decreasing the weight norm, which increases the effective gradient noise, which promotes generalization (neelakantan2015adding; jastrzkebski2017three; hoffer2017train). However, in the sliding window experiment for regularization, we observe that networks with regularization applied around epoch 50, despite having smaller weight norm (Figure 2 (C), compare onset epoch 50 to onset epoch 0) and thus a higher effective learning rate, generalize poorly (Figure 1 Top Right: onset epoch 50 has a mean test accuracy increase of 0.24% compared to 1.92% for onset epoch 0). We interpret the latter (onset epoch 0) as having a higher effective learning rate during the critical period, while for the former (onset epoch 50) it was past its critical period. Thus, previous observations in the literature should be considered with more nuance: we contend that an increased effective learning rate induces generalization only insofar as it modifies the dynamics during the critical period, reinforcing the importance of studying when to regularize, in addition to how. In Figure 9 in the Appendix, we show that the initial effective learning rate correlates better with generalization (Pearson coefficient 0.96, p-value < 0.001) than the final effective learning rate (Pearson coefficient 0.85, p-value < 0.001).
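To make the notion of effective learning rate concrete, here is a minimal sketch that takes it to be the nominal learning rate divided by the squared L2 norm of the weights, as in the works cited above. The function name and the toy weight vectors are illustrative assumptions.

```python
def effective_learning_rate(lr, weights):
    """Effective learning rate lr / ||w||_2^2 for scale-invariant weights."""
    squared_norm = sum(w * w for w in weights)
    if squared_norm == 0.0:
        raise ValueError("effective learning rate is undefined for zero weights")
    return lr / squared_norm


# Halving the weight norm (e.g. through weight decay) quadruples the
# effective learning rate at a fixed nominal learning rate.
before = effective_learning_rate(0.1, [2.0, 0.0])
after = effective_learning_rate(0.1, [1.0, 0.0])
```

Tracking this quantity per epoch would give the initial and final values whose correlation with generalization is compared in the text.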

We repeat the experiments in Figure 1 without Batch-Normalization (Figure 4). We observe a similar result, suggesting that the positive effect of weight decay during the transient cannot be due solely to the use of batch normalization and an increased effective learning rate.

Figure 5: Fisher Information and generalization: (Left) Trace of the Fisher Information Matrix (FIM) as a function of the training epochs. Weight decay increases the peak of the FIM during the transient, with negligible effect on the final value (see left plot when regularization is terminated beyond 100 epochs). The FIM trace is proportional to the norm of the gradients of the cross-entropy loss. FIM trace plots for delayed application/sliding window can be found in the Appendix (Figure 7) (Center) & (Right): Peak vs. final Fisher Information correlate differently with test accuracy: Each point in the plot is a ResNet-18 trained on CIFAR-10 achieving 100% training accuracy. Surprisingly, the maximum value of the FIM trace correlates far better with generalization than its final value, which is instead related to the local curvature of the loss landscape (“flat minima”). The Pearson correlation coefficient for the peak FIM trace is 0.92 (p-value < 0.001) compared to 0.29 (p-value > 0.05) for the final FIM trace.

Weight decay, Fisher and flatness.

Generalization for DNNs is often correlated with the flatness of the minima to which the network converges during training (hochreiter1997flat; li2018visualizing; keskar2016large; chaudhari2016entropy), where solutions corresponding to flatter minima seem to generalize better. In order to understand if the effect of regularization is to increase the flatness at convergence, we use the Fisher Information Matrix (FIM), which is a positive semi-definite approximation of the Hessian of the loss function (martens2014new) and thus a measure of the curvature of the loss landscape. We recall that the Fisher Information Matrix is defined as:

$$F := \mathbb{E}_{x} \, \mathbb{E}_{y \sim p_w(y|x)} \left[ \nabla_w \log p_w(y|x) \, \nabla_w \log p_w(y|x)^T \right],$$

where $p_w(y|x)$ is the output distribution of the network with weights $w$ and the outer expectation is over the training inputs $x$.

In Figure 5 (Left) we plot the trace of the FIM as a function of the training epoch. Notice that, contrary to our expectations, weight decay increases the FIM norm, and hence the curvature of the convergence point, but this still leads to better generalization. Moreover, the effect of weight decay on the curvature is more marked during the transient (Figure 5).
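The trace of the FIM equals the expected squared norm of the gradient of the log-likelihood, with labels sampled from the model's own output distribution. The sketch below computes it exactly for a linear softmax classifier, which is a deliberately simplified stand-in and not the deep networks used in the experiments (those would require backpropagation); all function names are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fim_trace_linear_softmax(W, inputs):
    """Trace of the Fisher Information Matrix of a linear softmax classifier.

    The model is p_w(y|x) = softmax(W x)_y. For each input, the expectation
    over y ~ p_w(y|x) of ||grad_W log p_w(y|x)||^2 is computed exactly by
    summing over all labels, then averaged over the inputs.
    """
    total = 0.0
    for x in inputs:
        logits = [sum(w_kj * x_j for w_kj, x_j in zip(row, x)) for row in W]
        p = softmax(logits)
        x_sq = sum(x_j * x_j for x_j in x)
        for y, p_y in enumerate(p):
            # grad_W log p(y|x) = (e_y - p) x^T, so its squared Frobenius
            # norm factorizes as ||e_y - p||^2 * ||x||^2.
            err_sq = sum(((1.0 if k == y else 0.0) - p_k) ** 2
                         for k, p_k in enumerate(p))
            total += p_y * err_sq * x_sq
    return total / len(inputs)
```

Evaluating such a quantity at every epoch yields a curve whose peak and final value can then be compared against test accuracy.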

This suggests that the peak curvature reached during the transient, rather than its final value, may correlate with the effectiveness of regularization. To test this hypothesis, we consider the DNNs trained in Figure 1 (Top) and plot the relationship between peak/final FIM value and test accuracy in Figure 5 (Center, Right): Indeed, while the peak value of the FIM strongly correlates with the final test performance (Pearson coefficient 0.92, p-value < 0.001), the final value of the FIM norm does not (Pearson 0.29, p-value > 0.05). We report plots of the Fisher Norm for delayed/sliding window application of WD in the Appendix (Figure 7).
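The comparison between peak and final FIM values rests on the Pearson correlation coefficient, which can be computed as in the following sketch. The helper names are illustrative, and no data from the paper is reproduced here.

```python
import math


def pearson(xs, ys):
    """Sample Pearson correlation coefficient between two equal-length lists."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0.0 or var_y == 0.0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / math.sqrt(var_x * var_y)


def peak_and_final(trace_history):
    """Peak and last value of a per-epoch FIM-trace curve for one network."""
    return max(trace_history), trace_history[-1]
```

With one `(peak, final)` pair and one test accuracy per trained network, `pearson` would be applied once to the peaks and once to the final values.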

The FIM was also used to study critical periods for changes in the data distribution in (achille2017critical), who, however, observe in their setting an anti-correlation between Fisher Information and generalization. Indeed, the relationship between the flatness of the convergence point and generalization established in the literature emerges as rather complex, and we may hypothesize a bias-variance-like trade-off connecting the two, where either too low or too high curvature can be detrimental.

Jacobian norm.

(zhang2018three) relates the effect of regularization to the norm of the Gauss-Newton matrix, $G = \mathbb{E}[J_w^T J_w]$, where $J_w$ is the Jacobian of the network output $f_w(x)$ w.r.t. $w$, which in turn relates to the norm of the network's input-output Jacobian. The Fisher Information Matrix is indeed related to the GN matrix (more precisely, it coincides with the generalized Gauss-Newton matrix, $\mathbb{E}[J_w^T H J_w]$, where $H$ is the Hessian of the per-sample loss $\ell(y, f_w(x))$ w.r.t. $f_w(x)$). However, while the GN norm remains approximately constant during training, we found the changes of the Fisher norm during training (and in particular its peak) to be informative of the critical period for regularization, allowing for a more detailed analysis.

5 Discussion and Conclusions

We have tested the hypothesis that there exists a “critical period” for regularization in training deep neural networks. Unlike classical machine learning, where regularization trades off the training error in the loss being minimized, DNNs are not subject to this trade-off: One can train a model with sufficient capacity to zero training error regardless of the norm constraint imposed on the weights. Yet, weight decay works, even in the case where it seems it should not, for instance when the network is invariant to the scale of the weights, e.g., in the presence of batch normalization. We believe the reason is that regularization affects the early epochs of training by biasing the solution towards regions that have good generalization properties. Once there, there are many local extrema to which the optimization can converge. Which one is unimportant: Turning the regularizer on or off changes the loss function, and the optimizer moves accordingly, but test error is unaffected, at least for the variety of architectures, training sets, and learning rates we tested.

We believe that there are universal phenomena at play, and what we observe is not the byproduct of accidental choices of training set, architecture, and hyperparameters: One can see the absence of regularization as a learning deficit, and it has been known for decades that deficits that interfere with the early phases of learning, or critical periods, have irreversible effects, from humans to songbirds and, as recently shown by achille2017critical , deep neural networks. Critical periods depend on the type of deficits, the task, the species or architecture. We have shown results for two datasets, two architectures, two learning rate schedules.

While our exploration is by no means exhaustive, it supports the point that considerably more effort should be devoted to the analysis of the transient dynamics of Deep Learning. To this date, most of the theoretical work in Deep Learning focuses on the asymptotics and the properties of the minimum at convergence.

Our hypothesis also stands when considering the interaction with other forms of generalized regularization, such as batch normalization, and explains why weight decay still works, even though batch normalization makes the activations invariant to the norm of the weights, which challenges previous explanations of the mechanisms of action of weight decay.

We note that there is no trade-off between regularization and loss in DNNs, and the effects of regularization cannot (solely) be to change the shape of the loss landscape (WD), or to change the variety of gradient noise (DA) preventing the network from converging to some local minimizers: even when regularization is removed in the final phase of training, generalization is preserved. The main effect of regularization ought to be on the transient dynamics before convergence.

At present, there is no viable theory on transient regularization. The empirical results we present should be a call to arms for theoreticians interested in understanding Deep Learning. A possible interpretation advanced by achille2017critical is to interpret critical periods as the (irreversible) crossing of narrow bottlenecks in the loss landscape. Increasing the noise – either by increasing the effective learning rate (WD) or by adding variety to the samples (DA) – may help the network cross the right bottlenecks while avoiding those leading to irreversibly sub-optimal solutions. If this is the case, can better regularizers be designed for this task?

References

Appendix A Details of the Experiments

We use the standard ResNet-18 architecture [11] in all the experiments, unless stated otherwise. We train all the networks using SGD (momentum = 0.9) with a batch size of 128 for 200 epochs, except when regularization is applied late in training, where we train for an extra 50 epochs to ensure convergence. For experiments with ResNet-18, we use an initial learning rate of 0.1, with a learning rate decay factor of 0.97 per epoch and a weight decay coefficient of 0.0005. In the piece-wise constant learning rate experiment, we use an initial learning rate of 0.1 and decay it by a factor of 0.1 every 60 epochs, while in the constant learning rate experiment we fix it to 0.001. For the All-CNN [31] experiments, we use an initial learning rate of 0.05 with a weight decay coefficient of 0.001 and a learning rate decay of 0.97 per epoch. For All-CNN, we do not use dropout; instead, we add batch normalization to all layers.
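The schedules described above can be written down compactly. The following is a minimal sketch rather than the authors' training code: the function names are hypothetical, and the half-open interval used to switch weight decay on and off is an assumption about how a regularization window might be implemented.

```python
def exponential_lr(epoch, base_lr=0.1, decay=0.97):
    """Learning rate after `epoch` epochs of per-epoch exponential decay."""
    return base_lr * decay ** epoch


def piecewise_constant_lr(epoch, base_lr=0.1, factor=0.1, step=60):
    """Learning rate that is multiplied by `factor` every `step` epochs."""
    return base_lr * factor ** (epoch // step)


def weight_decay_at(epoch, start, stop, coeff=0.0005):
    """Weight decay coefficient when regularization is active only for
    start <= epoch < stop (assumed convention), and zero otherwise."""
    return coeff if start <= epoch < stop else 0.0


if __name__ == "__main__":
    # Print the schedules at a few epochs, with weight decay active for epochs 0-49.
    for epoch in (0, 49, 50, 60, 199):
        print(
            epoch,
            exponential_lr(epoch),
            piecewise_constant_lr(epoch),
            weight_decay_at(epoch, start=0, stop=50),
        )
```

In a training loop, these values would be recomputed at the start of each epoch and passed to the optimizer, so that removing or delaying regularization amounts to changing `start` and `stop`.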

A.1 Path Plotting

We follow the method proposed by [23] to plot the training trajectories of the DNNs for varying durations of regularization (Figure 2 A). More precisely, we combine the weights of the network (stored at regular intervals during training) for the different durations of regularization into a single matrix, and then project them onto the first two principal components of that matrix.
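The projection step can be sketched as follows. This is an illustrative reconstruction, not the code of [23] or of the authors: the function name `project_paths` is hypothetical, the checkpoints are assumed to be flattened into vectors, and centering the stacked matrix before computing principal components is an assumption.

```python
import numpy as np


def project_paths(paths, n_components=2):
    """Project weight checkpoints from several runs onto shared principal components.

    paths: list of arrays, one per run, each of shape (n_checkpoints, n_weights).
    Returns a list of arrays, each of shape (n_checkpoints, n_components).
    """
    # Stack the checkpoints of all runs into a single matrix.
    stacked = np.concatenate(paths, axis=0)
    mean = stacked.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions of the centered checkpoint matrix.
    _, _, vt = np.linalg.svd(stacked - mean, full_matrices=False)
    components = vt[:n_components]
    # Express every run in the same two-dimensional coordinate system.
    return [(p - mean) @ components.T for p in paths]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical stand-ins for flattened checkpoints of two training runs.
    runs = [np.cumsum(rng.normal(size=(20, 50)), axis=0) for _ in range(2)]
    projected = project_paths(runs)
    print([p.shape for p in projected])
```

Because all runs share the same components, their projected trajectories can be drawn in a single plot and compared directly.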

Figure 6: PCA-projection of the training paths for All-CNN on CIFAR-10.

Appendix B Additional Experiments

Figure 7: Trace of FIM as a function of training epochs for delayed and sliding window application of weight decay for ResNet-18 on CIFAR-10.
Figure 8: We repeat the experiments from Figure 1, replacing the exponential learning rate schedule with a constant learning rate (lr = 0.001). The results are inconclusive, however, since the error achievable with a constant rate does not saturate performance on the datasets and architectures we tested. Test error is almost twice what was achieved with an exponential or piecewise constant learning rate. It is possible that careful tuning of the constant could achieve baseline performance and also display critical period phenomena.
Figure 9: Effect of learning rate on generalization: (Top) The effective learning rate $\eta_{\mathrm{eff},t} = \eta_t / \|w_t\|_2^2$ [36, 14, 33], where $\eta_t$ is the learning rate at epoch $t$ and $\|w_t\|_2$ is the norm of the weights at epoch $t$, shown as a function of the training epoch for the experiments from Figure 1 (Top blue) / Figure 2. (Left) and (Center) The effective learning rate is higher throughout training (which includes the critical periods) for models which generalize better (Figure 1 Top blue). (Right) When we apply weight decay in a sliding window, the effective learning rate for the dark blue curve (weight decay from 0-50) is higher during the critical period but lower afterwards, compared to the light blue curves (weight decay from 25-75, 50-100). However, the dark blue curve yields a higher final accuracy (test accuracy increase of 1.92%) compared to the light blue curves (test accuracy increase of 1.39% and 0.24%, respectively), despite having a lower effective learning rate for the majority of training. This suggests that having a higher effective learning rate during the critical periods is more conducive to good generalization when using weight decay. (Bottom) We test this further by plotting the relationship between the effective learning rate and final test accuracy. The mean initial effective learning rate is the average effective learning rate over the first 100 epochs, while the mean final effective learning rate is the average computed over the final 100 epochs. The initial effective learning rate correlates better with generalization (Pearson correlation coefficient of 0.96, p-value < 0.001) than the final effective learning rate (Pearson correlation coefficient of 0.85, p-value < 0.001).
Figure 10: Norm of the weights as a function of the training epoch. We observe similar results (compared to Figure 2) for different architectures (ResNet-18 and All-CNN) and different datasets (CIFAR-10 and CIFAR-100).
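The quantities reported in Figure 9 can be computed along the following lines. This is a hedged sketch: it assumes the definition of effective learning rate commonly used for normalized networks (learning rate divided by the squared norm of the weights), the helper names are hypothetical, and the numbers in the usage example are placeholders rather than results from the paper.

```python
import math


def effective_lr(lr, weights):
    """Effective learning rate lr / ||w||_2^2 for a flat sequence of weights
    (assumed definition)."""
    return lr / sum(w * w for w in weights)


def mean_over_epochs(values, start, stop):
    """Average of a per-epoch quantity over epochs start <= epoch < stop."""
    window = values[start:stop]
    return sum(window) / len(window)


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


if __name__ == "__main__":
    # Placeholder values: one mean initial effective learning rate and one
    # final test accuracy per training run (not data from the paper).
    mean_initial_eff_lr = [0.8e-4, 1.1e-4, 1.5e-4, 2.0e-4]
    final_test_accuracy = [91.0, 92.2, 93.1, 94.0]
    print(pearson(mean_initial_eff_lr, final_test_accuracy))
```

Computing the same correlation with the mean over the final 100 epochs in place of the first 100 gives the comparison described in the bottom panel.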