How to decay your learning rate

03/23/2021
by Aitor Lewkowycz, et al.

Complex learning rate schedules have become an integral part of deep learning. We find empirically that common fine-tuned schedules decay the learning rate after the weight norm bounces. This leads to the proposal of ABEL: an automatic scheduler which decays the learning rate by keeping track of the weight norm. ABEL's performance matches that of tuned schedules and is more robust with respect to its parameters. Through extensive experiments in vision, NLP, and RL, we show that if the weight norm does not bounce, we can simplify schedules even further with no loss in performance. In such cases, a complex schedule has similar performance to a constant learning rate with a decay at the end of training.


1 Introduction

Learning rate schedules play a crucial role in modern deep learning. They were originally proposed with the goal of reducing noise to ensure the convergence of SGD in convex optimization [Bottou, 1998]. A variety of tuned schedules are often used, some of the most common being step-wise, linear or cosine decay. Each schedule has its own advantages and disadvantages and they all require hyperparameter tuning. Given this heterogeneity, it would be desirable to have a coherent picture of when schedules are useful and to come up with good schedules with minimal tuning.

While we do not expect dependence on the initial learning rate in convex optimization, large learning rates behave quite differently from small ones in deep learning [Li et al., 2020a, Lewkowycz et al., 2020]. We expect the situation to be similar for learning rate schedules: the non-convex landscape makes it desirable to reduce the learning rate as we evolve our models. The goal of this paper is to study empirically (a) in which situations schedules are beneficial and (b) when during training one should decay the learning rate. Given that stochastic gradients are used in deep learning, we will use a simple schedule as our baseline: a constant learning rate with one decay close to the end of training. Training with this smaller learning rate for a short time is expected to reduce the noise without letting the model explore the landscape too much. This is corroborated by the fact that the minimum test error often occurs almost immediately after decaying the learning rate. Part of the paper compares this simple schedule with standard complex schedules used in the literature, studying the situations in which complex schedules are advantageous. We find that complex schedules are considerably helpful whenever the weight norm bounces, which happens often in the usual, optimal setups. In the presence of a bouncing weight norm, we propose an automatic scheduler which performs as well as fine tuned schedules.

1.1 Our contribution

The goal of the paper is to study the benefits of learning rate schedules and when the learning rate should be decayed. We focus on the dynamics of the weight norm $|W|^2$, defined as the sum over layers of the squared $L_2$-norm of the weights in each layer, $|W|^2 = \sum_\ell |W_\ell|^2$ (a short sketch of this measurement is given below).
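For concreteness, a minimal sketch (ours, not from the paper) of this measurement for a JAX/Flax parameter pytree:

# Sketch of the weight-norm measurement: the sum over all parameter
# tensors of their squared L2-norm. `params` is any JAX/Flax pytree.
import jax
import jax.numpy as jnp

def weight_norm_sq(params):
  return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))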

We make the following observations based on extensive empirical results which include Vision, NLP and RL tasks.

  1. We observe that tuned step-wise schedules decay the learning rate after the weight norm bounces (we define bouncing as a monotonic decrease of the weight norm followed by a monotonic increase, which occurs for a fixed learning rate, as can be seen in figure 3), when the weight norm starts to converge. Towards the end of training, a last decay decreases the noise. See figure 3.

  2. We propose an Automatic Bouncing into Equilibration Learning rate scheduler (ABEL). ABEL is competitive with fine tuned schedules and needs less tuning (see table 1 and the discussion in section 2).

  3. A bouncing weight norm seems necessary for non-trivial learning rate schedules to outperform the simple decay baseline. $L_2$ regularization is required for the weight norm to bounce, and in its absence (which is common in NLP and RL) we do not see a benefit from complex schedules. This is explored in detail in section 3 and the results are summarized in table 2.

Figure 3: Evolution of the weight norm when training with step-wise decay (decay times marked by black dashed lines); (a) Resnet-50 on ImageNet, (b) WRN28-10 on CIFAR-100. The learning rate is decayed after the weight norm bounces, towards its convergence. Models were evolved in optimal settings whose tuning did not use the weight norm as input.

The origin of weight bouncing.

There is a simple heuristic for why the weight norm bounces. Without $L_2$ regularization, the weight norm usually increases for the learning rate values used in practice. In the presence of $L_2$ regularization, we expect the weight norm to decrease initially. As the weight norm decrease slows down, the natural tendency of the weight norm to increase in the absence of regularization will eventually dominate. This is explained in more detail in section 4.

Weight bouncing and performance.

Generally speaking, weight bouncing occurs when we have non-zero $L_2$ regularization and large enough learning rates. While $L_2$ regularization is crucial in vision tasks, it is not found to be that beneficial in NLP or Reinforcement Learning tasks (for example Vaswani et al. [2017] does not use $L_2$). If the weight norm does not bounce, ABEL yields the "simple" learning rate schedule that we expect naively from the noise reduction picture: decay the learning rate once towards the end of training (a minimal sketch of this baseline follows below). We confirm that, in the absence of bouncing, such a simple schedule is competitive with more complicated ones across a variety of tasks and architectures, see table 2. We also see that the well known advantage of momentum over Adam for image classification on ImageNet (see Agarwal et al. [2020] for example) seems to disappear in the absence of bouncing, when we turn off $L_2$ regularization. Weight norm bouncing thus seems empirically a necessary condition for non-trivial schedules to provide a benefit, but it is not sufficient: we observe that when the datasets are easy enough that simple schedules can reach zero training error, schedules do not make a difference.
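For reference, a minimal sketch of this simple baseline schedule; the 85% decay point mirrors the default used for ABEL's last decay later in the paper, while the base learning rate and decay factor shown are placeholders:

# Simple baseline schedule: constant learning rate with one decay near the
# end of training. base_lr and decay_factor are placeholder values.
def simple_schedule(epoch, num_epochs, base_lr=0.1, decay_factor=0.1,
                    decay_at=0.85):
  return base_lr * decay_factor if epoch >= decay_at * num_epochs else base_lr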

1.2 Related works

We do not know of any explicit discussion of weight bouncing in the literature. The dynamics of deep networks with $L_2$ regularization has drawn recent attention, see for example van Laarhoven [2017], Lewkowycz and Gur-Ari [2020], Li et al. [2020b], Wan et al. [2020], Kunin et al. [2020]. The recent paper Wan et al. [2020] observes that weight norm equilibration is a dynamical process (the weight norm still changes even if the equilibrium conditions are approximately satisfied) which happens soon after the bounce.

The classic justification for schedules comes from reducing the noise in a quadratic potential [Bottou, 1998]. Different schedules do not provide an advantage in convex optimization unless there is a substantial mismatch between train and test landscape [Nakkiran, 2020]; however, this is not the effect that we are observing in our setup: when schedules are beneficial, their training performance is substantially different, see for example figure S21. The work of Li et al. [2020a] could be helpful to better understand the theory behind our phenomena, although it is not clear to us how their mechanism can generalize to multiple decays (or other complex schedules). There has been a lot of empirical work trying to learn schedules/optimizers, see for example Maclaurin et al. [2015], Li and Malik [2016], Li et al. [2017], Wichrowska et al. [2017], Rolinek and Martius [2018], Qi et al. [2020]. Our approach does not have an outer loop: the learning rate is decayed depending on the weight norm. This is conceptually similar to the ReduceLROnPlateau scheduler present in most deep learning libraries, where the learning rate is decayed when the loss plateaus; however, ReduceLROnPlateau does not perform well across our tasks. A couple of papers which thoroughly compare the performance of learning rate schedules are Shallue et al. [2019], Kaplan et al. [2020].

2 An automatic learning rate schedule based on the weight norm

ABEL and its motivation

From the two setups in figure 3 it seems that optimal schedules tend to decay the learning rate after bouncing, when the weight norm growth slows down. We can use this observation to propose ABEL (Automatic Bouncing into Equilibration Learning rate scheduler): a schedule which implements this behaviour automatically, see algorithm 1. In words, we keep track of the change in the weight norm between subsequent epochs, $\Delta_t \equiv |W_t|^2 - |W_{t-1}|^2$. When the sign of $\Delta_t$ flips, the weight norm has necessarily gone through a local minimum: because the weight norm decreases initially ($\Delta_t < 0$), the first flip ($\Delta_{t-1} < 0$, $\Delta_t > 0$) means that $|W_{t-1}|^2$ is a minimum, the bounce. After this, the weight norm grows and its growth slows down, until at some point $\Delta_t$ is noise dominated. In this regime, the product $\Delta_t \cdot \Delta_{t-1}$ will become negative again, which we take as our decaying condition. In order to reduce SGD noise, near the end of training we decay the learning rate one last time. In practice we do this last decay at around 85% of the total training time and, as we can see in SM B.1, this particular value does not really matter.

  if Δ_t × Δ_{t-1} < 0 then
     if reached_minimum then
        learning_rate = decay_factor × learning_rate
        reached_minimum = False
     else
        reached_minimum = True
     end if
  end if
  if t = last_decay_epoch then
     learning_rate = decay_factor × learning_rate
  end if
Algorithm 1 ABEL Scheduler (Δ_t ≡ |W_t|² − |W_{t-1}|²)

Algorithm 1 is an implementation of this idea, with the base learning rate and the decay factor as the main hyperparameters. While alternative implementations could be more explicit about the weight norm slowing down after reaching the minimum, they would likely require more hyperparameters.
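As a standalone illustration of the decay rule, here is our own sketch applied to a recorded trace of per-epoch weight norms (in actual training the decays feed back into the dynamics; the full Flax implementation is in SM E):

# Algorithm 1 applied offline to a list of per-epoch weight norms: the first
# sign flip of the weight-norm change marks the bounce (a local minimum);
# the next flip means the change is noise dominated, so we decay.
def abel_decays(weight_norms, base_lr=0.1, decay_factor=0.1):
  lr, reached_minimum, decay_epochs = base_lr, False, []
  for t in range(2, len(weight_norms)):
    d1 = weight_norms[t] - weight_norms[t - 1]
    d0 = weight_norms[t - 1] - weight_norms[t - 2]
    if d0 * d1 < 0:           # sign of the weight-norm change flipped
      if reached_minimum:     # second flip: equilibrated, so decay
        lr *= decay_factor
        decay_epochs.append(t)
        reached_minimum = False
      else:                   # first flip: the bounce
        reached_minimum = True
  return lr, decay_epochs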

We have decided to focus on the total weight norm, but one might ask what happens with the layer-wise weight norms. In SM B.2, we study the evolution of the weight norm in different layers, focusing on the layers which contribute the most to the weight norm (these account for the bulk of the total weight norm). We see that most layers exhibit the same bouncing plus slowing down pattern as the total weight norm, and this happens at roughly the same time scale.

Performance comparison across setups.

We have run a variety of experiments comparing learning rate schedules with ABEL, see table 1 for a summary and figure 10 for some selected training curves (the rest of the training curves are in SM C.1). We use ABEL without tuning its hyperparameters: we plug in the base learning rate and the decay factor of the reference step-wise schedule (the reference decay factors differ between CIFAR and the other datasets). We see that ABEL is competitive with existing fine tuned learning rate schedules and slightly outperforms step-wise decay on ImageNet. Cosine often beats step-wise schedules; however, as we will discuss shortly, such decay has several drawbacks.

Dataset Architecture Step-wise ABEL Cosine (test error)
ImageNet Resnet-50 24.0 23.8 23.2
CIFAR-10 WRN 28-10 3.7 3.8 3.5
CIFAR-10 VGG-16 - 7.1 6.9
CIFAR-100 WRN 28-10 18.5 18.7 18.4
CIFAR-100 PyramidNet - 10.8 10.8
SVHN WRN 16-8 1.77 1.79 1.89
Table 1: Comparison of test error at the end of training for different setups and learning rate schedules. We see that ABEL has very similar performance to the fine tuned step-wise schedule without the need to tune when to decay. ABEL uses the baseline values of learning rates and decay factors and we have not fine tuned these. The cells denoted by - refer to setups for which we do not have reference step-wise decays. The experimental details can be found in the SM.
Figure 10: Training curves of two experiments from table 1: Resnet-50 on ImageNet and WRN28-10 on CIFAR-100 (three pairs of panels).

Robustness of ABEL

ABEL is quite robust with respect to the learning rate and the decay factor. Since it depends implicitly on the natural time scales of the system, it adapts when to decay the learning rate. We can illustrate this by repeating the ImageNet experiment with different base learning rates or decay factors. The results are shown in figure 13.

We would like to highlight the mild dependence of performance on the learning rate: if the learning rate is too high, the weight norm will bounce faster and ABEL will adapt to this by quickly decaying the learning rate. This can be seen quite clearly in the learning rate training curves, see SM C.3.

ABEL also has the ‘last_decay_epoch’ hyperparameter, which determines when to perform the last decay in order to reduce noise. Performance depends very weakly on this hyperparameter (see SM B.1) and for all setups in table 1 we have set it to 85% of the total training time. The most natural way to think about this is to run ABEL for a fixed amount of time and afterwards decay the learning rate for a small number of epochs in order to get the performance benefit of less SGD noise.

Comparison of ABEL with other schedules

It is very natural to compare ABEL with step-wise decay. Step-wise decay is complicated to use in new settings because, on top of the base learning rate and the decay factor, one has to determine when to decay the learning rate. ABEL takes care of the ‘when’ automatically without hurting performance. Because when to decay depends strongly on the system and its current hyperparameters, ABEL is much more robust to the choices of base learning rate and decay factor.

A second class of schedules are those which depend explicitly on the total number of training epochs $T$, like cosine or linear decay. This strongly determines the decay profile: with cosine decay, the learning rate does not decay by a factor of 10 with respect to its initial value until roughly 80% of training! Having $T$ as a determining hyperparameter is problematic: it takes a long time for these schedules to reach error rates comparable to step-wise decays, as can be seen in figures 10 and S43. This implies that until very late in training one cannot tell whether $T$ is too short, in which case there is no straightforward way to resume training (if we want to evolve the model with the same decay for a longer time, we have to start training from the beginning). This is part of the reason why large models in NLP and vision use schedules which can be easily resumed, like rsqrt decay [Vaswani et al., 2017], "clipped" cosine decay [Kaplan et al., 2020, Brown et al., 2020] or exponential decay [Tan and Le, 2020]. In contrast, for ABEL the learning rate at any given time is independent of the total training budget (while there is the last_decay_epoch parameter, training can easily be extended by loading the model from before the last decay).
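The ~80% figure can be checked directly from the functional form of cosine decay, $\mathrm{lr}(t) = \mathrm{lr}_0 \, (1 + \cos(\pi t / T)) / 2$:

# Where does cosine decay first reach lr0 / 10? Solve
# (1 + cos(pi * t / T)) / 2 = 0.1 for t / T.
import math

t_over_T = math.acos(2 * 0.1 - 1) / math.pi
print(t_over_T)  # ~0.795, i.e. roughly 80% of training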

We have decided to compare ABEL with the previous two schedules because they are the most commonly used. There are many automatic/learnt learning rate schedules (or optimizers), see Maclaurin et al. [2015], Li and Malik [2016], Li et al. [2017], Wichrowska et al. [2017], Yaida [2018], Rolinek and Martius [2018], Qi et al. [2020]; to our knowledge, most of them require either significant changes to the code (like the addition of non-trivial measurements) or outer loops, and they also add hyperparameters of their own, so they are never completely hyperparameter free. Compared with these algorithms, ABEL is simple, interpretable (it can be directly compared with fine tuned step-wise decays) and performs as well as tuned schedules. It is also quite robust compared with other automatic methods because it relies on the weight norm, which is mostly noise free through training (compared with other batched quantities like gradients or losses).

An algorithm similar in simplicity and interpretability is ReduceLROnPlateau, one of the basic schedulers of PyTorch and TensorFlow, which decays the learning rate whenever the loss equilibrates. We train a Resnet-50 ImageNet model and a WRN 28-10 CIFAR-10 model with this algorithm, see SM B.3 for details. We use the default hyperparameters; for the ImageNet experiment, the learning rate does not decay at all, which yields a poor test error. For CIFAR-10, ReduceLROnPlateau does fairly well in terms of test error, however the learning rate decays without bound rather fast. These two experiments suggest that ReduceLROnPlateau cannot really compete with the schedules described above.
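For reference, a sketch of the plateau logic (the patience and relative-threshold defaults follow the PyTorch scheduler described in SM B.3; this is an illustration, not the library code):

# ReduceLROnPlateau-style rule: decay when the tracked loss has not improved
# by a relative threshold for more than `patience` consecutive epochs.
class PlateauDecay:
  def __init__(self, lr, factor=0.1, patience=10, threshold=1e-4):
    self.lr, self.factor = lr, factor
    self.patience, self.threshold = patience, threshold
    self.best, self.bad_epochs = float('inf'), 0

  def step(self, loss):
    if loss < self.best * (1.0 - self.threshold):  # significant improvement
      self.best, self.bad_epochs = loss, 0
    else:
      self.bad_epochs += 1
      if self.bad_epochs > self.patience:
        self.lr *= self.factor
        self.bad_epochs = 0
    return self.lr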

Figure 13: ResNet-50 trained on ImageNet for different learning rates and decay factors. (a) ABEL beats the other schedules when using non-optimal learning rates; at the largest learning rate, only ABEL converges. (b) ABEL is robust with respect to changes in the decay factor: its performance does not depend too much on the decay factor because it adjusts the number of decays accordingly. Note: for one of the decay factors, the default epoch budget is too short a time for ABEL to adapt properly (the weight norm is still bouncing), so we evolved that point for longer; if evolved for the default budget like the other points, the standard decay performance does not change much, but ABEL's error is higher.

ABEL does not require a fixed training budget

In our empirical studies, the drop in test error after decaying the learning rate is upper bounded by the previous drops; a reason for this is that the drop can be attributed to a reduction of SGD noise, and smaller learning rates have less SGD noise. This provides an automatic way of prescribing the training budget: if the improvement in accuracy after a decay is smaller than some threshold, exit training after a small number of epochs (to process the last decay); see the sketch below. This approach does not need a manual decay at the end of training. Such a training setup would not be possible for cosine/linear decay by construction, since they depend on the training budget, and it seems hard for step-wise decay, since there is no way to predetermine when to decay the learning rate automatically.
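A sketch of this stopping rule (the threshold value and the names are illustrative, not from the paper):

# After each ABEL decay, stop training if the decay improved the best test
# error by less than `threshold` (here in percentage points).
def keep_training(best_errors_between_decays, threshold=0.2):
  if len(best_errors_between_decays) < 2:
    return True
  improvement = best_errors_between_decays[-2] - best_errors_between_decays[-1]
  return improvement >= threshold  # else: a few more epochs, then exit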

3 Schedules and performance in the absence of a bouncing weight norm

In this section, we study settings where the weight norm does not bounce, to understand the impact of learning rate schedules on performance. Setups without $L_2$ regularization are the most common situation with no bouncing; these setups often present a monotonically increasing weight norm. It is not clear to us what characteristics of a task make $L_2$ regularization beneficial, but vision seems to benefit considerably more from it than NLP or RL.

We conduct an extensive set of experiments in vision, NLP and RL; the results are summarized in table 2. In these experiments, we compare complex learning rate schedules with a simple schedule where the model is evolved with a constant learning rate that is decayed once towards the end of training. This simple schedule mainly reduces noise: the error decreases considerably immediately after decaying and does not change much afterwards (it often increases). Across these experiments, we observe that complicated learning rate schedules are not significantly better than the simple one. For a couple of tasks (like ALBERT finetuning or WRN on CIFAR-100), complex schedules are slightly better than the simple decay, but this small advantage is nothing compared with the substantial advantage that schedules have in vision tasks with $L_2$ regularization. Another situation where there is no bouncing weight norm is small learning rates, for example VGG-16 with a small learning rate; in that case there is also no benefit from using complex schedules, see SM B.5 for more details. Note that in this paper we use $L_2$ regularization and weight decay interchangeably: what matters is that there is explicit weight regularization. These experiments also show that the well known advantage of momentum over Adam for vision tasks is only significant in the presence of $L_2$. In the absence of a benefit from $L_2$ regularization/weight decay, Adam seems to be the better optimizer; Agarwal et al. [2020] suggested that this is because it can adjust the learning rate of each layer appropriately, and it would be interesting to understand whether there is any connection between that and bouncing.

These experiments have a growing weight norm, as can be seen in SM C.2. While the weight norm does not have to be always increasing in the absence of $L_2$ regularization (this is a function of the learning rate, see section 4.1), the learning rates used in practice exhibit this property. Homogeneous networks with cross entropy loss will have an increasing weight norm at late times, see Lyu and Li [2020]. Even if a simple schedule is competitive, this does not imply that other features of convex optimization, like the independence of performance on the learning rate, carry over. We repeated the CIFAR-100 experiments with a fixed small learning rate (the same as the final learning rate of the simple schedule): while there is a performance gap between small and large learning rates both with and without $L_2$, this gap is much smaller if there is no bouncing. For a fair comparison with small learning rates, we evolved these experiments for longer than the large learning rate runs, but this did not give any benefit.

While the experiments presented in table 2 do not have $L_2$ regularization, some NLP architectures like Devlin et al. [2019], Brown et al. [2020] use weight decays of 0.01 and 0.1, respectively. We tried adding weight decay to our translation models, and while performance did not change substantially, we were not able to get a bouncing weight norm.

The effect of different learning rate schedules in NLP was also studied thoroughly in appendix D.6 of Kaplan et al. [2020], with the similar conclusion that as long as the learning rate is not small and is decayed near the end of training, performance stays roughly the same.

The presence of a bouncing weight norm does not guarantee that schedules are beneficial.

From this section, a bouncing weight norm seems to be a necessary condition for learning rate schedules to matter, but it is not a sufficient condition. Learning rate schedules seem advantageous only if the training task is hard enough. In our experience, if the training data can be memorized with a simple learning rate schedule before the weight norm has bounced, then more complex schedules are not useful. This can be seen by removing data augmentation in our Wide Resnet CIFAR experiments, see figure 16. In the presence of data augmentation, simple schedules cannot reach zero training error even when evolved for many epochs, see SM.

Figure 16: Wide Resnet on CIFAR-10 without data augmentation, evolved for a longer budget (left) and a shorter budget (right). In this setup, the weight norm bounces partway through the longer run. a) Both schedules reach zero training error and their performance is the same. b) If we evolve the model for the shorter budget, both schedules can still get zero training error without a weight norm bounce; we think this is the reason why there is no performance difference in a).
Type Task and metric Architecture Complex Decay Simple Decay
NLP EN-DE, BLEU Transformer 29.0 28.9
NLP EN-FR, BLEU Transformer 43.0 43.0
NLP GLUE, Average score ALBERT finetuning 83.1 82.9
RL Seaquest, Score PPO 1750 1850
RL Pong, Score PPO 21.0 20.7
RL Qbert, Score PPO 22300 23000
Vision (no L2) ImageNet, Test Error Resnet-50 28.1 27.8
Vision (no L2) ImageNet, Test Error Resnet-50 + Adam 28.2 28.9
Vision (no L2) CIFAR-10, Test Error Wide Resnet 28-10 5.0 5.0
Vision (no L2) CIFAR-100, Test Error Wide Resnet 28-10 21.8 22.1
Vision (L2, has bounce) ImageNet, Test Error Resnet-50 23.2 28.5
Vision (L2, has bounce) ImageNet, Test Error Resnet-50 + Adam 24.7 26.0
Vision (L2, has bounce) CIFAR-10, Test Error Wide Resnet 28-10 3.5 4.9
Vision (L2, has bounce) CIFAR-100, Test Error Wide Resnet 28-10 17.8 22.2
Table 2: Comparison of performance between a simple learning rate decay and a "complex" decay across tasks: "complex" means cosine decay for vision tasks and linear decay for NLP and RL. For NLP and RL tasks, higher metrics imply better performance, while for vision tasks lower error is better. None of these tasks (except for the vision tasks with $L_2$, included as a reference) have weight norm bouncing nor an advantage from non-simple schedules. The RL scores are averaged over three runs and their differences are compatible with noise. See table S1 for the individual GLUE scores; as is common, we have omitted the problematic WNLI task.

4 Understanding weight norm bouncing

In this section, we take some first steps towards understanding the mechanism behind the phenomena that we found empirically in the previous sections.

4.1 Intuition behind bouncing behaviour

We can build intuition about the dynamics of the weight norm by studying its change under SGD updates:

$|W_{t+1}|^2 - |W_t|^2 \approx -2\eta\lambda |W_t|^2 + \eta^2 |\nabla_W \bar{L}|^2 - 2\eta\, W_t \cdot \nabla_W \bar{L} \qquad (1)$

where $\eta, \lambda$ are the learning rate and $L_2$ regularization coefficient, $\nabla_W \bar{L}$ is the gradient with respect to the bare loss (the loss in the absence of the $L_2$ term), and we have used that empirically $\eta\lambda \ll 1$. This equation holds layer by layer, see the SM for more details. In the absence of $L_2$ regularization, for large enough learning rates ($\eta |\nabla_W \bar{L}|^2 > 2\, W_t \cdot \nabla_W \bar{L}$), this suggests that the weight norm will be increasing.

Equation 1 can be further simplified for scale invariant networks, which satisfy $W \cdot \nabla_W \bar{L} = 0$, see for example van Laarhoven [2017]. (Scale invariant networks are defined by network functions which are independent of the weight norm: $f(aW) = f(W)$ for $a > 0$; the weight norm of such networks still affects the dynamics [Li and Arora, 2019].) In the absence of this term, the updates of the weight norm are determined by the relative values of the gradient norm and the weight norm. If $\lambda$ or $\eta$ is very small, the weight norm updates will have a fixed sign and thus there will be no bouncing. More generally, we expect that in the initial stages of training the weight norm is large and its dynamics are dominated by the decay term; as it shrinks, the relative value of the gradient norm term becomes larger, and it seems natural that at some point it will dominate, making the weight norm bounce. This is also studied in Wan et al. [2020], where it is shown that after the bounce the two terms in equation 1 are of the same order and the weight norm "dynamically equilibrates" (although it cannot stay constant, because the gradient norm changes with time). While we expect the $W \cdot \nabla_W \bar{L}$ term to be non-zero in our setups, only layers which are not scale invariant contribute to it, and roughly any layer before a BatchNorm layer is scale invariant, so we expect this term to be smaller than the other two.

In our experience, the only necessary conditions for a model to have a bouncing weight norm are that it has $L_2$ regularization (or weight decay) and that the learning rate is large enough. We expect the previous intuition to apply to other optimizers with weight decay: empirically, we have seen bouncing weight norms across different optimizers, losses and batch sizes.
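A toy iteration of equation 1 reproduces the bounce. This is our own illustration: the $W \cdot \nabla_W \bar{L}$ term is dropped, the gradient-norm term is taken to grow slowly (standing in for the observed tendency of the weight norm to grow without $L_2$), and all constants are made up:

# Iterate d|W|^2 = -2*eta*lam*|W|^2 + eta^2*|grad|^2: the decay term
# dominates early, then the (slowly growing) gradient term takes over.
import numpy as np

eta, lam = 0.1, 0.05
w2, grad2 = 10.0, 1.0   # squared weight norm, squared gradient norm
history = []
for t in range(2000):
  history.append(w2)
  w2 += -2 * eta * lam * w2 + eta ** 2 * grad2
  grad2 *= 1.001        # stand-in for slowly growing gradients
print(int(np.argmin(history)))  # the weight norm decreases, bounces, grows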

4.2 Towards understanding the benefits of bouncing and schedules

While it is still unclear why bouncing is correlated with the benefit of schedules, we would like to point to some directions which could help provide a more complete picture.

To better understand this phenomenon, it would be useful to distill its properties and find the simplest model that captures it. The bouncing of the weight norm appears to be generic, as long as we have $L_2$ regularization and the learning rate is big enough. We believe that the fact that learning rate schedules are only advantageous for hard tasks (as discussed in section 3) is the principal roadblock to finding theoretically tractable models of this phenomenon.

For bouncing setups, decaying the learning rate when the weight norm is equilibrating allows the weight decay term in equation 1 to dominate, causing the weight norm to bounce again. In contrast, from equation 1, decaying the learning rate in the absence of $L_2$ can only slow down the weight norm dynamics, which implies that the weights change more slowly, see the SM. The combination of weight bouncing and decaying the learning rate might be beneficial because it allows the model to explore a larger portion of the landscape. Exploring this direction further might yield better insights into this phenomenon, perhaps building on the results of Wan et al. [2020], Kunin et al. [2020].

We now conduct several experiments on models with bouncing weight norms in order to better understand its properties.

The disadvantage of decaying too early or too late.

Waiting for the weight norm to bounce seems key to getting good performance. Decaying too late might be harmful because the weight norm does not have enough time to bounce again, but it is not clear whether it is bad by itself. Here, we train a VGG-5 model on CIFAR-100 for a fixed number of epochs, decay the learning rate once (by a factor of 10) at different times, and compare the minimum test error, see figure 19a. We see that decaying too early significantly hurts performance, and the best time to decay is after the weight norm has started slowing down its growth, before it is fully equilibrated. We can compare this sweep over decay times with ABEL: ABEL would be equivalent to a simple schedule decayed at the time marked by the red line (here there is no benefit from decaying again at the end of training). Given the limitations of the experiment, we cannot conclude that decaying too late is hurtful, and from the success of cosine decay we expect it is not.

Dependence on initialization scale.

One could wonder if the bounce would disappear if we changed the initialization of the weights so that the initial weight norm is smaller than the minimum of the bounce with the original normalization. We study this in figure 19b (see SM for more details) and we see that even for very small initialization scales there is a bouncing weight norm. If the initialization scale is too small, the bouncing weight norm disappears and the performance gets significantly degraded.

Figure 19: Experiments exploring features of weight bouncing. a) Minimum test error for VGG-5 models decayed at different times (all evolved for the same number of epochs), together with the weight norm at a fixed learning rate. The models with the best performance are decayed after the bounce, soon before the weight norm dramatically slows down its growth. b) WRN 28-10 models trained with different weight initialization scales and cosine decay. The weight norm keeps bouncing even if the initialization scale is small, and when it stops bouncing, performance is degraded.

5 Conclusions

In this work we have studied the connections between learning rate schedules and the weight norm. We have made the empirical observation that a bouncing weight norm is a necessary condition for complex learning rate schedules to be beneficial, and we have observed that tuned step-wise schedules tend to decay the learning rate when the weight norm equilibrates after bouncing. We have checked these observations across architectures and datasets and have proposed ABEL: a learning rate scheduler which automatically decays the learning rate depending on the weight norm, performs as well as fine tuned schedules, and is more robust than standard schedules with respect to its initial learning rate and decay factor. In the absence of weight bouncing, complex schedules do not seem to matter much.

We would now like to comment briefly on some important points and future directions:

Practical implications.

In vision tasks with $L_2$ regularization, learning rate schedules have a substantial impact on performance, and tuning schedules or using ABEL should be beneficial. In other setups, training with a constant learning rate and decaying it once at the end of training should not hurt performance and might be a preferable, simpler method.

ABEL’s hyperparameters.

The main hyperparameters of ABEL are the base learning rate and decay factor: while our schedule is not hyperparameter free, ABEL is more robust than other schedules to these parameters because of its adaptive nature.

Warmup.

Our discussion focuses on learning rate decays and we have not studied the effects of warmup. This is an interesting problem but seems unrelated to bouncing weight norms and the benefit of learning rate schedules at late times. Consequently, whenever warmup is beneficial and widely used (such as for large-batch ImageNet or Transformer models), we added it on top of our schedule.

Comparison between ABEL and other schedules.

We have mainly compared ABEL with step-wise and cosine decay. ABEL is directly comparable with step-wise decay and is objectively better in that it does not require fine tuning when to decay the learning rate. Compared with cosine decay, the learning rate in ABEL does not depend explicitly on the total training budget, so it is preferable in situations where we do not know how long we want to train or might want to resume training where it was left off. Compared with other automatic methods, it does not require an outer loop nor adding complicated measurements to training; it only depends on the weight norm, which is quite stable through training (compared with quantities like the training loss or gradients).

The role of simple schedules is to reduce noise.

We used simple schedules as a baseline to reduce SGD noise. This picture is justified by the fact that the most dramatic drop in test error for such schedules is always immediately after the decay, as can be seen in the training curves of the SM.

Memory and compute considerations.

ABEL only depends on the current and past weight norm (two scalars) so it does not add any significant compute nor memory cost.

Weight norm bounce with Adam.

Despite Adam having an implicit schedule, models trained with Adam still exhibit weight norm bouncing, as can be seen in the SM.

What is the behaviour of the layerwise weight norm?

As discussed previously and in the SM B.2, most layers exhibit the same pattern as the total weight norm.

Understanding the source of the generalization advantage of learning rate schedules.

It would be nice to understand whether the bouncing of the weight norm is a proxy for some other phenomenon. While we have tried tracking other simple quantities, the weight norm seems the best predictor of when to decay the learning rate. For making theoretical progress, the two significant roadblocks we identify are that this phenomenon requires large learning rates and hard datasets, both of which are complicated to study theoretically.

Acknowledgments

The authors would like to thank Anders Andreassen, Yasaman Bahri, Ethan Dyer, Orhan Firat, Pierre Foret, Guy Gur-Ari, Jaehoon Lee, Behnam Neyshabur and Vinay Ramasesh for useful discussions.

References

  • Agarwal et al. [2020] Naman Agarwal, Rohan Anil, Elad Hazan, Tomer Koren, and Cyril Zhang. Disentangling adaptive gradient methods from learning rates, 2020.
  • Bottou [1998] Léon Bottou. Online learning and stochastic approximations, 1998.
  • Bradbury et al. [2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, and Skye Wanderman-Milne. JAX: composable transformations of Python+NumPy programs. 2018. URL http://github.com/google/jax.
  • Brown et al. [2020] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners, 2020.
  • Devlin et al. [2019] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding, 2019.
  • Foret et al. [2020] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization, 2020.
  • Kaplan et al. [2020] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models, 2020.
  • Kunin et al. [2020] Daniel Kunin, Javier Sagastuy-Brena, Surya Ganguli, Daniel L. K. Yamins, and Hidenori Tanaka. Neural mechanics: Symmetry and broken conservation laws in deep learning dynamics, 2020.
  • Lewkowycz and Gur-Ari [2020] Aitor Lewkowycz and Guy Gur-Ari. On the training dynamics of deep networks with regularization, 2020.
  • Lewkowycz et al. [2020] Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy Gur-Ari. The large learning rate phase of deep learning: the catapult mechanism, 2020.
  • Li and Malik [2016] Ke Li and Jitendra Malik. Learning to optimize, 2016.
  • Li et al. [2020a] Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks, 2020a.

  • Li et al. [2017] Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-sgd: Learning to learn quickly for few-shot learning, 2017.
  • Li and Arora [2019] Zhiyuan Li and Sanjeev Arora. An exponential learning rate schedule for deep learning, 2019.
  • Li et al. [2020b] Zhiyuan Li, Kaifeng Lyu, and Sanjeev Arora. Reconciling modern deep learning with traditional optimization analyses: The intrinsic learning rate, 2020b.
  • Lyu and Li [2020] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous neural networks, 2020.
  • Maclaurin et al. [2015] Dougal Maclaurin, David Duvenaud, and Ryan P. Adams. Gradient-based hyperparameter optimization through reversible learning, 2015.
  • Nakkiran [2020] Preetum Nakkiran. Learning rate annealing can provably help generalization, even for convex problems, 2020.
  • Qi et al. [2020] Xiaoman Qi, PanPan Zhu, Yuebin Wang, Liqiang Zhang, Junhuan Peng, Mengfan Wu, Jialong Chen, Xudong Zhao, Ning Zang, and P. Takis Mathiopoulos. Mlrsnet: A multi-label high spatial resolution remote sensing dataset for semantic scene understanding, 2020.

  • Rolinek and Martius [2018] Michal Rolinek and Georg Martius. L4: Practical loss-based stepsize adaptation for deep learning, 2018.
  • Shallue et al. [2019] Christopher J. Shallue, Jaehoon Lee, Joseph Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E. Dahl. Measuring the effects of data parallelism on neural network training, 2019.
  • Tan and Le [2020] Mingxing Tan and Quoc V. Le. Efficientnet: Rethinking model scaling for convolutional neural networks, 2020.

  • van Laarhoven [2017] Twan van Laarhoven. L2 regularization versus batch and weight normalization, 2017.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need, 2017.
  • Wan et al. [2020] Ruosi Wan, Zhanxing Zhu, Xiangyu Zhang, and Jian Sun. Spherical motion dynamics: Learning dynamics of neural network with normalization, weight decay, and sgd, 2020.
  • Wichrowska et al. [2017] Olga Wichrowska, Niru Maheswaranathan, Matthew W. Hoffman, Sergio Gomez Colmenarejo, Misha Denil, Nando de Freitas, and Jascha Sohl-Dickstein. Learned optimizers that scale and generalize, 2017.
  • Yaida [2018] Sho Yaida. Fluctuation-dissipation relations for stochastic gradient descent, 2018.

Supplementary material

Appendix A Experimental settings

We use Flax (https://github.com/google/flax), which is based on JAX [Bradbury et al., 2018]. The implementations are based on the Flax examples and the repository https://github.com/google-research/google-research/tree/master/flax_models/cifar [Foret et al., 2020]. Our simple Flax implementation of ABEL can be found in SM E. All datasets are downloaded from tensorflow-datasets and we use their training/test splits.

All experiments use the same seed for the weights at initialization and we consider only one such initialization unless otherwise stated. We have not seen much variance in the phenomena described, see the Table 1 description. All experiments use cross-entropy loss.

We run all experiments on v3-8 TPUs. Run times vary from minutes per task for the ALBERT finetuning experiments to hours per run for the WideResnet, ResNet-50, PyramidNet and Transformer experiments.

Here we describe the particularities of each experiment in the figure/tables.

Table 1. All models use SGD with momentum and none of them uses dropout. ABEL considerations: we reuse the reference learning rate decay factors (one value for ImageNet and SVHN and another for CIFAR), and we force ABEL to decay at 85% of the total epoch budget by the same decay factor. We run the CIFAR-10 and CIFAR-100 WideResnet experiments with three different seeds and report the average accuracy; the standard deviation across runs is small for all schedules on CIFAR-10 and CIFAR-100, except for ABEL on CIFAR-100, which has a somewhat larger standard deviation. From these observations we consider only a single run for the other experiments. For the CIFAR and WRN experiments, the gradient norm is clipped.

  • Resnet-50 on Imagenet: trained with $L_2$ regularization (we do not decay the batch-norm parameters here) and label smoothing. All experiments include a linear warmup. Standard data augmentation: horizontal flips and random crops. Step-wise decay: the learning rate is multiplied by the decay factor at fixed epochs.

  • Wide Resnet 16-8 on SVHN: trained with $L_2$ regularization and no dropout. Training data includes the "extra" training data; no data augmentation. Step-wise decay: the learning rate is multiplied by the decay factor at fixed epochs.

  • Wide Resnet 28-10 on CIFAR10/100: trained with $L_2$ regularization. Standard data augmentation: horizontal flips and random crops. Step-wise decay: the learning rate is multiplied by the decay factor at fixed epochs.

  • Shake-Drop PyramidNet on CIFAR100: trained with $L_2$ regularization; uses AutoAugment and cutout as data augmentation. Because training this model takes a long time and weight bounces generally take longer at smaller learning rates, we found it convenient to average the weight norm every five epochs when using ABEL in order to avoid the decays from being noise dominated (without this, the third decay happens too early and the test error increases).

  • VGG-16 on CIFAR10: trained with $L_2$ regularization. Basic data augmentation and no batch norm.

Table 2 NLP, RL. All these experiments use Adam, with no dropout and no weight decay.

  • Base Transformer trained for translation: linear learning rate warmup, evolved for 100k steps, and uses reverse translation. Simple decay: decay once near the end of training. The translation tasks correspond to WMT'17.

  • ALBERT and GLUE finetuning. Fine tuned from the base ALBERT model of https://github.com/google-research/albert. All tasks were trained for 10k steps (including 1k linear warmup steps) over four learning rates; we report the score at the best learning rate. Simple decay: decay once at a fixed fraction of training. The individual scores are summarized in table S1.

  • PPO: evolved for 40M frames. Simple decay: decay once at a fixed fraction of training. The rest of the configuration is the default configuration of https://github.com/google/flax/examples. The reported score is the average score over the final episodes. We have three runs per task and schedule; the per-task differences between linear and simple decay are compatible with noise at the 95% confidence level.

Table 2 Vision.

These experiments are the same as table 1 but without $L_2$ regularization. The simple decay is defined by evolving with a fixed learning rate for most of training and then decaying it once (with different decay factors for ImageNet and CIFAR). In the case of ImageNet+Adam, the learning rate was reduced. For simple decay without $L_2$, the test error would often increase after decaying; in such cases, we choose the test error at the minimum of the training loss, see figure S50a for an experiment with this behaviour.

Figure 19a: VGG-5 trained with a fixed learning rate and decayed by a factor of 10 at the respective time.

Rest of figures: small modifications or further plots of previous experiments; changes are specified explicitly.

Task MNLI(m/mm) QQP QNLI SST-2 CoLA STS-B MRPC RTE Average
Linear Decay 82.6/83.4 88.1 92.1 92.8 58.5 91.0 88.7 70.8 83.1
Simple Decay 82.9/83.5 87.6 91.1 92.1 57.0 91.2 88.2 73.3 82.9
Table S1: Evaluation metric is accuracy except for CoLA (Matthew correlation), MRPC (F1) and STS-B (Pearson correlation).

Appendix B Experimental details

b.1 The mild dependence on the last decay epoch

In figure S1 we run the CIFAR experiments again with ABEL for different values of ‘last_decay_epoch‘. The final performance does not depend much on this value.

Figure S1: Models trained with ABEL for different values of ‘last_decay_epoch‘; we plot the difference in test error with respect to the default last decay at 85% of training time. The test error does not change much; most changes are small.

b.2 Layer-wise dynamics of the weight norm

In this section, we study the dynamics of the layer-wise weight norm. We consider the Resnet-50 Imagenet setup of table 1 with step-wise decay. For better visualization, we focus on the top layers, ranking layers by their maximum weight norm during training; these top layers are mostly intermediate, but there are also early and late layers. In figure S4a, we plot the quotient between the weight norm of each layer (or the sum of the top layers) and the total weight norm. These top layers account for the bulk of the total weight norm and they exhibit the same dynamics: the red line is straight. At the individual layer level, several layers become constant rather fast, while others have a deeper bounce but then follow the same dynamics as the total weight norm. We illustrate this in figure S4b, where we compare the evolution of the different weight norms (normalized by their value at initialization). Most layers show the behaviour of the total weight norm: a bounce and a slow down of growth before decaying.

(a) Resnet-50 on Imagenet
(b) Resnet-50 on Imagenet
Figure S4: a) Evolution of the quotient between the weight norm of different layers and the total weight norm. b) Evolution of the weight norm of different layers, normalized by their values at initialization.
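A sketch of this measurement (ours, for a Flax parameter dict) is:

# Fraction of the total weight norm carried by each parameter tensor.
import flax
import jax.numpy as jnp

def layer_norm_fractions(params):
  flat = flax.traverse_util.flatten_dict(flax.core.unfreeze(params))
  norms = {'/'.join(k): float(jnp.sum(v ** 2)) for k, v in flat.items()}
  total = sum(norms.values())
  return {name: n / total for name, n in norms.items()}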

b.3 Decay LR on Loss Plateau

We implemented a simple version of this scheduler in Flax with the default values for the hyperparameters from the PyTorch implementation, except for the decay factor (called `factor`), which we set to the reference values for ImageNet and CIFAR. The other relevant hyperparameters, which we did not change, are `patience` (the number of epochs with no improvement after which the learning rate will be reduced) and `threshold` (the relative threshold for measuring a new optimum, to only focus on significant changes). See https://pytorch.org/docs/stable/optim.html for more details. The resulting test error is reasonable for WRN 28-10 on CIFAR-10 but poor for Resnet-50 on Imagenet. For Imagenet, the loss still decreases slightly throughout training, so it does not seem like the logic behind this schedule can yield good results.

For CIFAR-10, the first decay is caused by a real plateau, but after that the loss increases for more than `patience` epochs, so the learning rate decays again. One decay after that, the loss basically stays in a plateau and the learning rate decays without bound. It is surprising that this schedule does so well on CIFAR-10 despite the learning rate becoming so low so early.

These two examples exhibit opposite issues: too many decays vs too few. It seems unrealistic that proper hyperparameter tuning could fix both problems, among other things because the ImageNet Resnet loss is slowly decaying.

(a) Wide Resnet on CIFAR-10
(b) Wide Resnet on CIFAR-10
(c) Wide Resnet on CIFAR-10
(d) Resnet-50 on Imagenet
(e) Resnet-50 on Imagenet
(f) Resnet-50 on Imagenet
Figure S11: Experiments from the main section with a ReduceLROnPlateau schedule.

b.4 No advantage for easy tasks

We expand on the discussion of the main text. We train the same Wide Resnet models of table 1 but without data augmentation. In this case, simple decay reaches zero training error at the end of training and there is no advantage to cosine decay: the test errors without data augmentation are essentially the same for cosine and simple decay on both CIFAR-100 and CIFAR-10. We attribute this to the fact that, when evolved for a small number of epochs, before the weight norm has time to bounce, the simple decay schedule can already reach zero training error and thus cannot benefit from the bounce.

(a) Wide Resnet on CIFAR-10
(b) Wide Resnet on CIFAR-10
(c) Wide Resnet on CIFAR-10
(d) Wide Resnet on CIFAR-100
(e) Wide Resnet on CIFAR-100
(f) Wide Resnet on CIFAR-100
(g) Wide Resnet on CIFAR-10
(h) Wide Resnet on CIFAR-10
(i) Wide Resnet on CIFAR-10
Figure S21: Wide Resnets on CIFAR without data augmentation can reach zero training error with simple decay.

b.5 Learning rate dependence of bouncing

We can study the effect of the learning rate. Large learning rates seem important for weight norm bouncing: as we can see in figure S26, a small learning rate does not produce a bouncing norm even if we evolve for longer (smaller learning rates often take a longer time to train). We see that cosine schedules are beneficial only when there is a bouncing norm, see table S2.

Setup Simple Decay Cosine Decay
large lr 7.4 6.7
small lr 7.4 7.8
small lr, 1000 epochs 7.2 7.6
Table S2: Test error for VGG-16 on CIFAR-10 for different learning rate setups. Cosine decay is only superior at large learning rates, when there is a bouncing weight norm.
(a) VGG-16 on CIFAR-10
(b) VGG-16 on CIFAR-10
(c) VGG-16 on CIFAR-10, 1000 epochs
(d) VGG-16 on CIFAR-10, 1000 epochs
Figure S26: In order for the weight norm to bounce within our training run, the learning rate has to be big enough. For the small learning rate, the weight norm does not bounce; this is true even if we evolve for five times longer.

b.6 Bouncing for different initialization scales.

The bouncing behaviour suggests that maybe one can avoid bouncing altogether by starting with the minimum weight norm to begin with. In order to check this, we run our CIFAR-100 WRN 28-10 experiments with cosine decay for different initialization scales for all weights (including the batchnorm weights), and compare with the standard experiments that we have run. We see how, even if we start with a weight norm which is below the minimum weight norm of the standard run, there is still bouncing. As we keep decreasing the initialization scale, it gets to a point where performance gets significantly degraded (and bouncing disappears too). Such small initialization scales are probably pathological, but it is interesting to see again the correlation between degraded performance and no bouncing. Before reaching this small weight initialization, the final weight norm (and error rate) is very similar across the different initializations.

(a) WRN 28-10 on CIFAR-100
(b) WRN 28-10 on CIFAR-100
(c) WRN 28-10 on CIFAR-100
Figure S30: Wide Resnet 28-10 on CIFAR-100 experiments with different initialization scales. a) Test error: the two smallest initializations present substantial degradation of performance. b) Despite the weight norm being so different at initialization, the final weight norm is the same for the different models. c) Zoomed version of b): all the initializations but the smallest two exhibit clear bouncing.

Appendix C More training curves

c.1 Training curves for Table 1 experiments

We show the remaining training curves for the Table 1 experiments.

(a) Wide Resnet on CIFAR-10
(b) Wide Resnet on CIFAR-10
(c) Wide Resnet on CIFAR-10
(d) VGG on CIFAR-10
(e) VGG on CIFAR-10
(f) VGG on CIFAR-10
(g) Wide Resnet 16-8 on SVHN
(h) Wide Resnet 16-8 on SVHN
(i) Wide Resnet 16-8 on SVHN
(j) PyramidNet on CIFAR100
(k) PyramidNet on CIFAR100
(l) PyramidNet on CIFAR100
Figure S43: Remaining training curves for experiments of table 1.

c.2 Training curves for Table 2 experiments

Here we show the training curves for some experiments in Table 2.

(a) WRN 28-10 on CIFAR100
(b) WRN 28-10 on CIFAR100
(c) WRN 28-10 on CIFAR100
(d) Adam and Imagenet
(e) Adam and Imagenet
(f) Adam and Imagenet
Figure S50: A couple of vision tasks from table 2. In the absence of $L_2$, a non-trivial schedule does not affect performance and the weight norm is usually monotonically increasing. d,e,f) Resnet-50 with Adam on Imagenet: weight norm bouncing occurs despite Adam having an implicit schedule.
(a) Transformer on En-De WMT
(b) Transformer on En-De WMT
(c) WRN 28-10 on CIFAR100
Figure S54: Training curves for the Transformer trained on the English-German translation task WMT'14. There is no significant difference between the different learning rate schedules.
(a) PPO Pong
(b) PPO Qbert
(c) PPO Seaquest
(d) PPO Pong
(e) PPO Qbert
(f) PPO Seaquest
Figure S61: Game scores for different RL games under the two considered schedules: mean score over three runs, with individual scores in lighter color. "Simple" decays the base learning rate once by a fixed factor late in training and "linear" is linear decay.

c.3 Training curves for large learning rate experiment

We show the training curves for the ImageNet experiment with a large learning rate from section 2, figure 13. Cosine and step-wise schedules are basically not learning until the learning rate decays significantly; ABEL, however, adapts quickly to the large learning rate and decays it fast, yielding better performance.

(a) Resnet-50 on ImageNet
(b) Resnet-50 on ImageNet
(c) Resnet-50 on ImageNet
Figure S65: Training curves for the Imagenet experiment of section 2, figure 13, with a large learning rate.

Appendix D Derivation of equations of section 4

Weight dynamics equations.

Equation 1 follows directly from the SGD update $W_{t+1} = (1 - \eta\lambda) W_t - \eta \nabla_W \bar{L}$:

$|W_{t+1}|^2 = |(1 - \eta\lambda) W_t - \eta \nabla_W \bar{L}|^2 \qquad (S1)$
$\phantom{|W_{t+1}|^2} = |W_t|^2 - 2\eta\lambda |W_t|^2 + \eta^2 |\nabla_W \bar{L}|^2 - 2\eta\, W_t \cdot \nabla_W \bar{L} + \ldots \qquad (S2)$

The terms that we drop are subleading for our purposes because $\eta\lambda \ll 1$ in practice. We expect these terms to be relevant only if the dominant terms cancel each other. Note that $W_t \cdot \nabla_W \bar{L} = 0$ for scale invariant layers.

A scale invariant layer is one whose network function satisfies $f(aW) = f(W)$ for $a > 0$; the bare loss only depends on the weights through the network function. If we split the weight into its norm and its direction, $W = |W| \hat{W}$, and use $W \cdot \nabla_W \bar{L} = 0$, we get that the direction evolves as:

$\hat{W}_{t+1} - \hat{W}_t \approx -\frac{\eta}{|W_t|^2} \nabla_{\hat{W}} \bar{L} \qquad (S3)$

For scale invariant networks without $L_2$, the change in the weights becomes smaller as the weight norm equilibrates.

Using the SGD update for scale invariant layers ($W_t \cdot \nabla_W \bar{L} = 0$) we get:

$|W_{t+1}|^2 = (1 - \eta\lambda)^2 |W_t|^2 + \eta^2 |\nabla_W \bar{L}|^2 \qquad (S4)$

As $|W_{t+1}|^2 \to |W_t|^2$, $\eta^2 |\nabla_W \bar{L}|^2 \to \left(1 - (1 - \eta\lambda)^2\right) |W_t|^2 \approx 2\eta\lambda |W_t|^2$.
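The orthogonality statement underlying this derivation can be checked numerically on a toy scale invariant function (our example, not from the paper):

# For a function depending only on the direction of W (here a normalized
# linear model), the gradient is orthogonal to W, so W . grad = 0.
import jax
import jax.numpy as jnp

def loss(w, x, y):
  w_hat = w / jnp.linalg.norm(w)   # f(aW) = f(W): scale invariant
  return (jnp.dot(w_hat, x) - y) ** 2

w = jnp.array([1.5, -2.0, 0.5])
x = jnp.array([0.3, 0.1, -0.7])
g = jax.grad(loss)(w, x, 1.0)
print(jnp.dot(w, g))  # ~0 up to float precision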

Appendix E Flax implementation of ABEL

For completeness we include our Flax implementation of ABEL.

# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import flax
import jax


class ABELScheduler:
  """Implementation of the ABEL scheduler."""

  def __init__(self,
               num_epochs: int,
               learning_rate: float,
               steps_per_epoch: int,
               decay_factor: float,
               train_fn: Callable,
               warmup: int = 0):

    self.num_epochs = num_epochs
    self.learning_rate = learning_rate
    self.steps_per_epoch = steps_per_epoch
    self.decay_factor = decay_factor
    self.train_fn = train_fn
    self.warmup = warmup
    self.learning_rate_fn = self.get_learning_rate_fn(self.learning_rate)
    self.weight_list = []
    self.reached_minima = False
    self.epoch = 0

  def get_learning_rate_fn(self, lr):
    """Outputs a simple decay learning rate function from a base learning rate."""
    lr_fn = flax.training.lr_schedule.create_stepped_learning_rate_schedule(
        lr, self.steps_per_epoch // jax.host_count(),
        [[int(self.num_epochs * 0.85), self.decay_factor]])
    if self.warmup:
      warmup_fn = lambda step: jax.numpy.minimum(
          1., step / self.steps_per_epoch / self.warmup)
    else:
      warmup_fn = lambda step: 1
    return lambda step: lr_fn(step) * warmup_fn(step)

  def update(self, step_fn, weight_norm):
    """Optimizer update rule for the ABEL scheduler; this is basically Algorithm 1."""
    self.weight_list.append(weight_norm)
    self.epoch += 1

    if len(self.weight_list) < 3:
      return step_fn

    # The product of the last two weight-norm differences is negative when
    # the sign of the change flips (Algorithm 1's condition).
    if (self.weight_list[-1] - self.weight_list[-2]) * (
        self.weight_list[-2] - self.weight_list[-3]) < 0:
      if self.reached_minima:
        self.reached_minima = False
        self.learning_rate *= self.decay_factor
        step_fn = self.update_train_step(self.learning_rate)
      else:
        self.reached_minima = True

    return step_fn

  def update_train_step(self, learning_rate):
    learning_rate_fn = self.get_learning_rate_fn(learning_rate)
    return self.train_fn(learning_rate_fn=learning_rate_fn)

The main reason this is a Flax implementation is that update takes a train_step function and outputs a train_step function. It can easily be added to the Flax examples/baselines with two modifications. Before starting the training loop, we initiate ABEL. (Note that ABELScheduler takes train_step_fn as an argument: a function that takes a learning_rate_fn and outputs a train_step, i.e. train_step_fn = lambda lr_fn: functools.partial(train_step, learning_rate_fn=lr_fn).)

# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# Before training starts.
scheduler = ABELScheduler(num_epochs, base_learning_rate, steps_per_epoch = steps_per_epoch, decay_factor=decay_factor, train_fn = train_step_fn)
learning_rate_fn = scheduler.learning_rate_fn

The only remaining modification is to add the ABEL update rule at the end of each epoch: it takes the current train_step function and the mean weight norm and returns a (possibly updated) train_step function. ABEL will update the optimizer if the learning rate has to be decayed.

# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# At the end of each epoch.
train_step = scheduler.update(train_step, weight_norm)