Code of SALeRA paper https://arxiv.org/abs/1709.01427
When applied to training deep neural networks, stochastic gradient descent (SGD) often incurs steady progression phases, interrupted by catastrophic episodes in which loss and gradient norm explode. A possible mitigation of such events is to slow down the learning process. This paper presents a novel approach to control the SGD learning rate that uses two statistical tests. The first one, aimed at fast learning, compares the momentum of the normalized gradient vectors to that of random unit vectors and accordingly gracefully increases or decreases the learning rate. The second one is a change point detection test, aimed at the detection of catastrophic learning episodes; upon its triggering the learning rate is instantly halved. The ability to both speed up and slow down the learning rate allows the proposed approach, called SALeRA, to learn as fast as possible but not faster. Experiments on standard benchmarks show that SALeRA performs well in practice, and compares favorably to the state of the art.
SGD was revived in the last decade as an effective method for training deep neural networks with linear computational complexity in the size of the dataset (Bottou and Bousquet, 2008; Hardt et al., 2015). SGD faces two limitations, depending on the learning rate: if it is too large, the learning trajectory leads to catastrophic episodes; if it is too small, convergence takes ages. The dynamic adjustment of the learning rate has therefore been acknowledged as a key issue since the early days of SGD (Robbins and Monro, 1951).
The exploding gradient problem, described in (Goodfellow et al., 2016, Chapter 8) as the encounter of steep cliff structures in the derivative landscape during learning, is frequently met while training neural networks (and even more so when training recurrent neural networks (Bengio et al., 1994)).
Gradient clipping, which constrains the gradient norm to remain smaller than a constant (Pascanu et al., 2013), is another possibility. The introduction of batch normalization (Ioffe and Szegedy, 2015) also helps reduce the frequency of such events. Finally, proper initialization (Glorot and Bengio, 2010; Sutskever et al., 2013) and unsupervised pre-training (Erhan et al., 2010), i.e. starting the optimization trajectory in a good region of the parameter space, also make such events less frequent.
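The norm-clipping rule mentioned above can be sketched as follows. This is a minimal NumPy illustration assuming the common "rescale if the L2 norm exceeds a threshold" variant. The function name `clip_by_norm` is hypothetical and is not taken from any particular library.

```python
import numpy as np


def clip_by_norm(grad, max_norm):
    """Rescale grad so that its L2 norm does not exceed max_norm (direction is kept)."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        return grad * (max_norm / norm)
    return grad


g = np.array([3.0, 4.0])       # a gradient of norm 5
print(clip_by_norm(g, 1.0))    # same direction, norm reduced to 1
print(clip_by_norm(g, 10.0))   # unchanged: already below the threshold
```

Only the length of the gradient changes, never its direction. This is why clipping limits the size of a jump off a cliff without altering the descent direction.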
Addressing the slow speed of SGD, learning rate adaptation has been acknowledged as a key issue since the late 80s (see George and Powell (2006) for a review). Using the information contained in the correlation of successive gradient directions was already at the heart of the delta-delta and delta-bar-delta update rules proposed by Jacobs (1988). Briefly, the delta-bar-delta rule states that for each parameter, if the current gradient and the relaxed sum of past gradients have the same sign, the learning rate is incremented additively; if they have opposite signs, the learning rate is decremented multiplicatively. Jacobs already advocated decrementing the learning rates faster than incrementing them, so as to adapt quickly in case of catastrophic events.
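The delta-bar-delta rule described above can be sketched for a vector of parameters as follows. This is a hedged illustration, not Jacobs' original code. The function name and the values of `kappa` (additive increment), `phi` (multiplicative decrement) and `theta` (relaxation factor) are hypothetical choices.

```python
import numpy as np


def delta_bar_delta_step(w, grad, lr, bar, kappa=0.01, phi=0.1, theta=0.7):
    """One delta-bar-delta update with one learning rate per parameter.

    bar is the relaxed (exponentially averaged) sum of past gradients.
    """
    agreement = bar * grad
    lr = np.where(agreement > 0, lr + kappa, lr)        # same sign: additive increase
    lr = np.where(agreement < 0, lr * (1.0 - phi), lr)  # opposite sign: multiplicative decrease
    bar = (1.0 - theta) * grad + theta * bar            # update the relaxed gradient sum
    w = w - lr * grad                                   # plain gradient step with the new rates
    return w, lr, bar


w = np.zeros(2)
lr = np.array([0.1, 0.1])
bar = np.array([1.0, -1.0])
grad = np.array([1.0, 1.0])
w, lr, bar = delta_bar_delta_step(w, grad, lr, bar)
print(lr)  # first rate increased additively, second decreased multiplicatively
```

The rule is asymmetric: increases are additive while decreases are multiplicative. As a result, the rate can collapse quickly when gradients start to disagree, but it only grows slowly.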
The natural gradient descent (NGD) approach (Amari, 1998) considers the Riemannian geometry of the parameter space, using the Fisher information matrix (estimated by the gradient covariance matrix) to precondition the gradient. Due to its quadratic complexity in the dimension of the parameter space, NGD approximations have been designed for deep networks (Pascanu and Bengio, 2014). Notably, approaches such as the Hessian-free optimization of Martens (2010) can be interpreted as NGD (Pascanu and Bengio, 2014).
, RMSProp (Tieleman and Hinton, 2012), and Adam (Kingma and Ba, 2014). Adam estimates the first and second moments of the gradient with respect to each parameter, and uses their ratio to update the parameters. Moment estimates are maintained by exponential moving averages with different weight factors. With the default values ($\beta_1 = 0.9$, $\beta_2 = 0.999$), the second-moment estimate averages over a horizon two orders of magnitude longer than the first-moment estimate. As will be seen, SALeRA also builds upon the gradient second moment, with the difference that it compares this moment with a fixed agnostic counterpart.
In Schaul et al. (2013), the learning rate is computed at each time step so as to approximately maximize the expected decrease of the loss, where the loss function is locally approximated by a parabola. Finally, Andrychowicz et al. (2016) address learning rate adaptation as a reinforcement learning problem. They exploit the evidence gathered in the current time steps to infer which decisions would have been good earlier on, and optimize a hyper-parameter adjustment policy accordingly.
SALeRA involves two components: a learning rate adaptation scheme, which ensures that the learning system goes as fast as it can; and a catastrophic event manager, which is in charge of detecting undesirable behaviors and getting the system back on track.
The basic idea of the proposed learning rate update is to compare the current gradient descent to a random walk with uniformly chosen gradient directions. The sum of successive normalized gradient vectors, referred to as the cumulative path in the following, has in expectation a larger norm than the sum of uniformly drawn unit vectors if and only if gradient directions are positively correlated. In such cases, the learning process has a global direction and can afford to speed up. Conversely, if the norm of the cumulative path is smaller than its random equivalent, gradient directions are anti-correlated. The process then alternates between opposite directions (e.g., bouncing on the sides of a narrow valley, or hovering around some local optimum), and the learning rate should be decreased.
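The comparison described above can be illustrated numerically. The sketch below computes the squared norm of an exponential moving average of normalized directions in three cases. The first is uniformly random directions, which serves as the agnostic reference and is averaged over several random walks. The second is positively correlated directions, built as a fixed drift plus noise. The third is anti-correlated directions, obtained by flipping the sign at every step. The dimension, number of steps, weight `alpha` and noise level are arbitrary illustrative choices, and the helper name `ema_sq_norm` is hypothetical.

```python
import numpy as np


def ema_sq_norm(directions, alpha):
    """Squared norm of the exponential moving average of normalized vectors."""
    p = np.zeros(directions.shape[1])
    for v in directions:
        p = (1 - alpha) * p + alpha * v / np.linalg.norm(v)
    return float(p @ p)


rng = np.random.default_rng(0)
d, steps, alpha = 50, 500, 0.1

# Agnostic reference: average over random walks with uniformly drawn unit steps
# (a normalized standard Gaussian vector is uniform on the unit sphere).
reference = np.mean([ema_sq_norm(rng.standard_normal((steps, d)), alpha)
                     for _ in range(100)])

# Positively correlated directions: a fixed drift plus small noise.
drift = np.zeros(d)
drift[0] = 1.0
correlated = drift + 0.05 * rng.standard_normal((steps, d))

# Anti-correlated directions: the sign flips at every step (bouncing in a valley).
signs = np.where(np.arange(steps) % 2 == 0, 1.0, -1.0)[:, None]
alternating = signs * correlated

print(f"correlated:  {ema_sq_norm(correlated, alpha):.4f}")
print(f"random walk: {reference:.4f}")
print(f"alternating: {ema_sq_norm(alternating, alpha):.4f}")
```

The correlated path ends up well above the random-walk reference and the alternating one well below it. These are the two situations in which the learning rate should respectively be increased or decreased.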
The ALeRA scheme takes inspiration from the famed CMA-ES (Hansen and Ostermeier, 2001) and NES (Wierstra et al., 2014) algorithms, today considered among the best-performing derivative-free continuous optimization algorithms. These approaches de facto implement natural gradient optimization (Amari, 1998) and instantiate the Information-Geometric Optimization paradigm (Ollivier et al., 2017) in the space of normal distributions on $\mathbb{R}^d$. Formally, CMA-ES maintains a normal distribution. The variance of this distribution, a.k.a. the step-size, is updated by comparing the cumulative path of the algorithm (an exponential moving average of successive steps) with that of a random walk with Gaussian moves of fixed step-size. This mechanism is said to be agnostic as it makes no assumption whatsoever on the properties of the optimization objective.
The partial adaptation of the CMA-ES scheme to the minimization of a loss function on a $d$-dimensional parameter space is defined as follows. Let $\theta_t$ be the solution at time $t$, $g_t$ the gradient of the current loss, and $\eta_t$ the current learning rate. SGD computes the solution at time $t+1$ by $\theta_{t+1} = \theta_t - \eta_t\, g_t$. Let $\|\cdot\|$ denote the $L_2$ norm and $\langle \cdot, \cdot \rangle$ the associated dot product.
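Definition. For $t \ge 1$, the exponential moving average $p_t$ of the normalized gradients with weight $\alpha \in (0, 1)$ and its random equivalent $r_t$ are defined as:

$$p_0 = 0, \qquad p_t = (1-\alpha)\, p_{t-1} + \alpha\, \frac{g_t}{\|g_t\|},$$
$$r_0 = 0, \qquad r_t = (1-\alpha)\, r_{t-1} + \alpha\, u_t,$$

where $u_1, \dots, u_t$ are independent random unit vectors uniformly drawn in $\mathbb{R}^d$.

Proposition. For $t \ge 1$, the expectation and variance of $\|r_t\|^2$, with $r_t$ as defined above, are:

$$\mu_t := \mathbb{E}\,\|r_t\|^2 = \frac{\alpha}{2-\alpha}\left(1 - (1-\alpha)^{2t}\right),$$
$$\sigma_t^2 := \mathrm{Var}\,\|r_t\|^2 = \frac{2}{d}\left(\mu_t^2 - \alpha^4\, \frac{1 - (1-\alpha)^{4t}}{1 - (1-\alpha)^4}\right).$$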
Proof: Appendix A.
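As a hedged sanity check of the proposition, the sketch below simulates many independent random walks $r_t$ and compares the empirical mean and variance of $\|r_t\|^2$ with the closed forms. The variance expression is $\frac{2}{d}\left(\mu_t^2 - \sum_i w_i^4\right)$ with weights $w_i = \alpha(1-\alpha)^{t-i}$. It is derived under the assumption of independent, uniformly distributed unit vectors. The function names and the values of `alpha`, `d`, `t` are illustrative choices.

```python
import numpy as np


def random_unit_vectors(rng, n, d):
    """Draw n independent unit vectors uniformly on the sphere of R^d."""
    x = rng.standard_normal((n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def simulate_sq_norm(alpha, d, t, n_samples, seed=0):
    """Monte Carlo samples of ||r_t||^2 with r_0 = 0, r_k = (1-alpha) r_{k-1} + alpha u_k."""
    rng = np.random.default_rng(seed)
    r = np.zeros((n_samples, d))
    for _ in range(t):
        r = (1 - alpha) * r + alpha * random_unit_vectors(rng, n_samples, d)
    return np.sum(r ** 2, axis=1)


def predicted_moments(alpha, d, t):
    """Closed-form mean and variance of ||r_t||^2 (uniform unit steps assumed)."""
    mu = alpha / (2 - alpha) * (1 - (1 - alpha) ** (2 * t))
    sum_w4 = alpha ** 4 * (1 - (1 - alpha) ** (4 * t)) / (1 - (1 - alpha) ** 4)
    var = 2.0 / d * (mu ** 2 - sum_w4)
    return mu, var


alpha, d, t = 0.1, 10, 50
samples = simulate_sq_norm(alpha, d, t, n_samples=20000)
mu, var = predicted_moments(alpha, d, t)
print(f"mean: empirical {samples.mean():.5f} vs predicted {mu:.5f}")
print(f"var:  empirical {samples.var():.2e} vs predicted {var:.2e}")
```

As $t \to \infty$, the mean tends to $\alpha/(2-\alpha)$, independently of $d$. The variance, in contrast, shrinks like $1/d$, so large layers have a much tighter random reference.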
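Let $\mu_\infty$ and $\sigma_\infty^2$ denote the limits, as $t \to \infty$, of the expectation and variance of $\|r_t\|^2$. At time step $t$, the ALeRA scheme compares the squared norm of the cumulative path $p_t$ with the distribution of the agnostic momentum characterized by $\mu_\infty$ and $\sigma_\infty$. The learning rate is increased or decreased depending on the normalized gap between the squared norms of $p_t$ and $r_t$:

$$\eta_{t+1} = \eta_t \exp\left(C\, \frac{\|p_t\|^2 - \mu_\infty}{\sigma_\infty}\right),$$

with $C > 0$ a hyper-parameter of the algorithm.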
The approach is implemented in the ALeRA algorithm (non-greyed lines in Algorithm 1). The algorithm takes the hyper-parameters $\alpha$ and $C$, the initial learning rate $\eta_0$ and the mini-batch size. Each iteration over a mini-batch computes the new exponential moving average of the normalized gradient, performs the agnostic update of the learning rate, and then updates the parameters in the usual way. The learning rate can be controlled in a layer-wise fashion, by independently maintaining an exponential moving average and updating the learning rate for each layer of the neural network. This algorithm is used in all experiments of Section 3.
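The layer-wise procedure above can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation, and it rests on two assumptions. The first is that the update has the exponential form $\eta \leftarrow \eta \exp\left(C(\|p\|^2 - \mu_\infty)/\sigma_\infty\right)$. The second is that $\sigma_\infty$ is the limit of the random-walk variance for uniform unit steps in $\mathbb{R}^d$, with $d$ the layer size. The names `agnostic_reference` and `LayerwiseALeRA` and the default values of `lr0`, `alpha` and `C` are hypothetical.

```python
import numpy as np


def agnostic_reference(alpha, d):
    """Assumed limits (mu_inf, sigma_inf) of the mean and std of ||r_t||^2 in R^d."""
    mu = alpha / (2 - alpha)
    var = 2.0 / d * (mu ** 2 - alpha ** 4 / (1 - (1 - alpha) ** 4))
    return mu, np.sqrt(var)


class LayerwiseALeRA:
    """Hypothetical sketch of layer-wise agnostic learning-rate adaptation."""

    def __init__(self, shapes, lr0=1e-3, alpha=0.1, C=0.05):
        self.alpha, self.C = alpha, C
        self.lr = {name: lr0 for name in shapes}
        self.path = {name: np.zeros(shape) for name, shape in shapes.items()}
        self.ref = {name: agnostic_reference(alpha, int(np.prod(shape)))
                    for name, shape in shapes.items()}

    def step(self, params, grads):
        """Update each layer's cumulative path, then its learning rate, then its weights."""
        for name, g in grads.items():
            norm = np.linalg.norm(g)
            if norm == 0.0:
                continue  # a null gradient carries no direction information
            p = (1 - self.alpha) * self.path[name] + self.alpha * g / norm
            self.path[name] = p
            mu, sigma = self.ref[name]
            gap = (np.sum(p ** 2) - mu) / sigma   # normalized gap to the random walk
            self.lr[name] *= np.exp(self.C * gap)
            params[name] = params[name] - self.lr[name] * g
        return params


# Usage: a gradient that always points in the same direction makes the rate grow.
opt = LayerwiseALeRA({"w": (5,)}, lr0=1e-3)
params = {"w": np.zeros(5)}
g = np.ones(5)
for _ in range(10):
    params = opt.step(params, {"w": g})
print(opt.lr["w"])  # grows above lr0
```

Because $\sigma_\infty$ shrinks like $1/\sqrt{d}$, the normalized gap can be very large for big layers, so in such a sketch $C$ must be kept small. Gradients that alternate in sign, by contrast, keep $\|p\|^2$ below $\mu_\infty$ and shrink the rate.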
As noted by Kingma and Ba (2014), parameter-wise control of the learning rate is desirable in some contexts. The above scheme is extended to a parameter-wise update of the learning rate as follows. For $i \in \{1, \dots, d\}$, let $r_{t,i}^2$ denote the squared $i$-th coordinate of $r_t$. By symmetry of the uniform distribution, the expectation of $r_{t,i}^2$ is the expectation of $\|r_t\|^2$ divided by $d$, and its standard deviation is likewise obtained from that of $\|r_t\|^2$ by a dimension-dependent rescaling.
Given these reference values, the squared $i$-th coordinate of $p_t$, noted $p_{t,i}^2$, can likewise be compared with its random counterpart $r_{t,i}^2$. The learning rate $\eta_{t,i}$ of the $i$-th parameter is then updated as in the layer-wise case, using these coordinate-wise quantities.
See Appendix B for a full derivation of the parameter-wise algorithm.
As stated above, the ability to learn fast requires an emergency procedure, able both to detect an emergency and to recover from it.
In a healthy learning regime, the training error should decrease over time, up to the noise due to the inter-batch variance, unless the learning system abruptly meets a cliff structure (Goodfellow et al., 2016). Such cliffs are usually blamed on too large a learning rate in an uneven gradient landscape.
In a convex noiseless optimization setting, if computationally tractable, the best strategy is to compute (an approximation of) the optimal learning rate through line search (see Defazio et al. (2014) on how to handle the noise in the case of a composite loss function). In such a context, as a thought experiment, let $\eta_{\max}$ be such that, if used to update $\theta_t$, the resulting $\theta_{t+1}$ would yield the same performance as $\theta_t$, and let $\eta^*$ denote the optimal learning rate.
- For $\eta > \eta_{\max}$, $\theta_{t+1}$ yields a worse performance than $\theta_t$, and if continued the optimization process is likely to diverge.
- For $\eta \le \eta^*$, $\theta_{t+1}$ yields a performance improvement.
- For $\eta^* < \eta < \eta_{\max}$, $\theta_{t+1}$ yields a performance improvement too, but the subsequent trajectory is likely to bounce back and forth on the walls of the optimum valley.
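The three regimes above can be reproduced on a 1D quadratic. The sketch assumes $f(\theta) = \frac{\lambda}{2}\theta^2$, for which the update reads $\theta \leftarrow (1 - \eta\lambda)\,\theta$, so that $\eta^* = 1/\lambda$ and $\eta_{\max} = 2/\lambda$. The value $\lambda = 4$ and the helper name `gd_trajectory` are hypothetical illustrative choices.

```python
def gd_trajectory(lam, eta, theta0=1.0, steps=8):
    """Gradient descent on f(theta) = lam / 2 * theta**2, whose gradient is lam * theta."""
    thetas = [theta0]
    for _ in range(steps):
        thetas.append(thetas[-1] - eta * lam * thetas[-1])
    return thetas


lam = 4.0
eta_star, eta_max = 1.0 / lam, 2.0 / lam   # optimal and critical learning rates

for label, eta in [("eta < eta*", 0.5 * eta_star),
                   ("eta* < eta < eta_max", 1.5 * eta_star),
                   ("eta > eta_max", 1.25 * eta_max)]:
    traj = gd_trajectory(lam, eta)
    final_loss = 0.5 * lam * traj[-1] ** 2
    print(f"{label:22s} first iterates {[round(x, 3) for x in traj[:4]]}, "
          f"final loss {final_loss:.3g}")
```

Below $\eta^*$ the iterates shrink monotonically. Between $\eta^*$ and $\eta_{\max}$ they change sign at every step while still shrinking, which is the "bouncing" regime. Above $\eta_{\max}$ they grow without bound, and exactly at $\eta_{\max}$ the loss stays constant.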
Overall, the safety zone for the learning rate is $]0, \eta_{\max}[$, with the caveat that the safety zone is narrower for ill-conditioned optimization problems. The proposed safeguard strategy primarily aims to detect when $\eta$ steps outside of its safety zone (see the change point detection test below) and to apply a correction so as to get back into it. Upon change detection in the mini-batch loss, SALeRA implements a straightforward correction: it halves the learning rate and recovers the last solution before the test was triggered. The halving process is iterated if needed, sending $\eta$ back into $]0, \eta_{\max}[$ exponentially fast (up to the perturbations in the gradient due to the mini-batch variance).
The halving trick rests on a trade-off between two costs. The first is the number of successive dividing iterations, which could be made even smaller by using a larger dividing factor. The second is the number of standard ALeRA iterations needed to reach the optimal learning rate once back in the safety zone. The choice of the dividing factor is further discussed in Appendix C.
SALeRA applies a change detection test to the signal given by the mini-batch loss $\ell_t$.
The PH detection test (Page, 1954; Hinkley, 1970) is chosen because it provides optimal guarantees about the trade-off between the detection delay upon a change (affecting the average or standard deviation of the signal) and the mean time between false alarms. For $t \ge 1$, it maintains the empirical mean $\bar\ell_t$ of the signal and the cumulative deviation from the empirical mean, $m_t = \sum_{i=1}^{t} (\ell_i - \bar\ell_i)$. Finally, it records the empirical bounds of $m_t$, namely $m_t^{\max} = \max_{i \le t} m_i$ and $m_t^{\min} = \min_{i \le t} m_i$. In case of a stationary signal, the expectation of $m_t$ is 0 by construction. The PH change test is thus triggered when the gap between $m_t$ and its empirical bounds exceeds a problem-dependent threshold $\lambda$, which controls the alarm rate. (The PH test takes into account the extreme value phenomenon by considering an upper and a lower cumulative deviation, defined from $m_t$ by adding (resp. subtracting) a margin $\delta$ to $\ell_i - \bar\ell_i$. In SALeRA, $\delta$ is set to 0.)
The PH test is implemented in the SALeRA algorithm as follows (greyish lines in Algorithm 1). In all experiments, the threshold $\lambda$ is set to $k\,\ell_1$ with $k = 1/10$, that is, one tenth of the empirical loss on the first mini-batch (this issue is further discussed in Section 3.2). The variables $\bar\ell_t$, $m_t$ and $m_t^{\min}$ are maintained. In the learning context, a decrease of the loss signal is welcomed and expected, so only the case of an increasing signal is monitored. Upon test triggering ($m_t - m_t^{\min} > \lambda$), the learning rate is halved, the weight vector is reset to the last solution before the alarm, and the PH test is reinitialized.
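The monitoring logic above can be sketched as follows. This is a hedged, self-contained illustration, not the authors' code. The class `PageHinkleyIncrease` monitors increases only, with margin $\delta = 0$ and threshold $\lambda = k\,\ell_1$. The function `replay_losses` replays a loss stream and halves the learning rate at each alarm. Both names are hypothetical, and the weight restoration performed by SALeRA is only indicated by a comment. The synthetic loss signal is made up for illustration.

```python
class PageHinkleyIncrease:
    """Page-Hinkley test monitoring increases of a signal (margin delta, 0 by default)."""

    def __init__(self, threshold, delta=0.0):
        self.threshold, self.delta = threshold, delta
        self.reset()

    def reset(self):
        self.n, self.mean = 0, 0.0      # number of samples, empirical mean
        self.m, self.m_min = 0.0, 0.0   # cumulative deviation and its running minimum

    def update(self, x):
        """Add one observation; return True if an increase is detected."""
        self.n += 1
        self.mean += (x - self.mean) / self.n
        self.m += x - self.mean - self.delta
        self.m_min = min(self.m_min, self.m)
        return self.m - self.m_min > self.threshold


def replay_losses(losses, lr0, k=0.1):
    """Replay mini-batch losses, halving the learning rate and resetting PH at each alarm.

    A real optimizer would also restore the last saved weights at each alarm.
    """
    ph = PageHinkleyIncrease(threshold=k * losses[0])
    lr, alarms = lr0, []
    for t, loss in enumerate(losses):
        if ph.update(loss):
            lr /= 2.0
            alarms.append(t)
            ph.reset()
    return lr, alarms


# Synthetic signal: a smoothly decreasing loss, then a sudden jump at t = 100.
losses = [2.0 * 0.99 ** t for t in range(100)] + [5.0] * 5
lr, alarms = replay_losses(losses, lr0=0.1)
print(lr, alarms)
```

A steadily decreasing loss never raises the alarm, because every new value lies below the running mean. The jump triggers exactly one alarm, after which the reinitialized test treats the new loss level as its baseline.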
The goal of the following experiments is to validate the algorithmic ideas introduced in Section 2 by comparing them with widely used optimization techniques (see Section 1) on some straightforward NN architectures.
Experiments are performed on the MNIST and CIFAR-10 datasets, which respectively contain 60k and 50k training examples. Both contain 10k test examples, to be classified into 10 classes. The data is normalized according to the mean and standard deviation of each coordinate over the training set.
Adagrad, NAG, and Adam are used as baselines. The agnostic adaptation rule and the change detection can be applied independently. In order to separate their effects, four original algorithms are studied here. ALeRA (the white lines in Algorithm 1) implements the agnostic learning rate adaptation without the change detection. Ag-Adam applies the same agnostic learning rate adaptation on top of the Adam algorithm. Finally, combining the change detection mechanism with the agnostic adaptation yields SALeRA, as described in Algorithm 1, as well as its parameter-wise version SPALeRA (Algorithm 2 in Appendix B).
The exploration of the hyper-parameter space for all algorithms has been done on a grid of possible values (with the exception of Adagrad which has no hyper-parameter):
NAG: the momentum .
Adam: and ;
ALeRA: and (see Algorithm 1). The parameters for SALeRA and SPALeRA are the same, as there is no additional parameter for the Page-Hinkley part.
Ag-Adam: the recommended values for Adam ( and ) are used for the Adam part, and the same values as for ALeRA are used for the ALeRA part.
The initial learning rate ranges from to
depending on the algorithm. Finally, the mini-batch size was set to either 1% or 1‰ of the training set size. All reported results are based on 5 independent runs performed for each hyperparameter set unless otherwise specified.
All experiments consider the following 4 network architectures:
M0: a softmax regression model with cross-entropy loss (i.e. no hidden layers),
M2: 2 fully connected hidden layers with ReLU activation, on top of M0. The hidden layers are of respective sizes (500, 300) for MNIST, (1 500, 900) for CIFAR-10.
M2b is identical to Model M2 above, except that Batch Normalization layers (Ioffe and Szegedy, 2015) are added in each hidden layer.
M4: LeNet5-inspired convolutional models (Le Cun et al., 1998). These models contain 2 convolutional layers with max-pooling followed by 2 fully connected layers, all with ReLU activation. They are of respective sizes (32, 64, 128, 128) for MNIST, and (32, 64, 384, 384) for CIFAR-10. Batch normalization is used in each layer.
These architectures are not specifically optimized for the task at hand, but rather chosen to compare the performances of past and novel algorithms in a wide variety of situations.
All computations are performed on 46 GPUs (5 TITAN X (Pascal), 9 GTX 1080 and 32 Tesla K80) using the Torch library (Collobert et al., 2011) in double precision. A typical run on a TITAN X (Pascal) GPU for 20k mini-batches of size 50 for CIFAR-10 on M4 takes between 8 and 10 minutes for all algorithms.
MNIST and CIFAR-10 are classification problems. We therefore report the classification error on the test set after 5 epochs (i.e. 5 full passes over the training set) and after 20 epochs (end of all runs), together with its standard deviation over the 5 independent runs.
Table 1: Test error (%) after 5 and 20 epochs, mean (standard deviation) over 5 runs, for the best hyper-parameter setting of each algorithm.
| MNIST | M0 | 5 ep. | 7.75 (.14) | 7.73 (.06) | 7.72 (.12) | 7.31 (.09) | 7.51 (.09) | 7.51 (.09) | 8.03 (.09) |
| MNIST | M0 | 20 ep. | 7.59 (.07) | 7.51 (.09) | 7.43 (.10) | 7.29 (.09) | 7.43 (.03) | 7.44 (.04) | 7.60 (.08) |
| MNIST | M2 | 5 ep. | 1.95 (.17) | 2.00 (.06) | 2.07 (.11) | 1.93 (.17) | 1.86 (.11) | 1.87 (.05) | 1.93 (.14) |
| MNIST | M2 | 20 ep. | 1.58 (.08) | 1.71 (.08) | 1.56 (.06) | 1.57 (.04) | 1.55 (.10) | 1.59 (.09) | 1.55 (.09) |
| MNIST | M2b | 5 ep. | 1.82 (.13) | 1.72 (.07) | 1.81 (.07) | 1.66 (.08) | 1.59 (.08) | 1.59 (.08) | 1.78 (.08) |
| MNIST | M2b | 20 ep. | 1.47 (.10) | 1.48 (.06) | 1.57 (.94) | 1.53 (.05) | 1.43 (.04) | 1.48 (.09) | 1.50 (.09) |
| MNIST | M4b | 5 ep. | .85 (.08) | 1.02 (.09) | .89 (.31) | .91 (.06) | .82 (.30) | .82 (.30) | .82 (.14) |
| MNIST | M4b | 20 ep. | .72 (.09) | .82 (.08) | .80 (.08) | .79 (.05) | .63 (.05) | .64 (.07) | .63 (.11) |
| CIFAR | M0 | 5 ep. | 60.37 (.55) | 60.49 (.71) | 60.60 (.45) | 59.62 (.27) | 59.89 (.19) | 59.69 (.33) | 61.32 (.65) |
| CIFAR | M0 | 20 ep. | 59.73 (.19) | 59.76 (.36) | 59.81 (.24) | 59.34 (.24) | 59.31 (.11) | 59.31 (.25) | 59.71 (.48) |
| CIFAR | M2 | 5 ep. | 45.82 (.93) | 44.81 (.62) | 45.68 (.39) | 44.91 (.42) | 44.69 (17.58) | 44.42 (.24) | 45.74 (.85) |
| CIFAR | M2 | 20 ep. | 45.08 (.32) | 43.59 (.51) | 44.43 (.50) | 43.25 (.40) | 43.19 (.21) | 42.72 (.41) | 44.48 (.65) |
| CIFAR | M2b | 5 ep. | 45.01 (.84) | 44.18 (.62) | 44.30 (.96) | 43.33 (.33) | 43.08 (.17) | 43.08 (.17) | 44.92 (.89) |
| CIFAR | M2b | 20 ep. | 42.50 (.48) | 43.79 (.25) | 43.60 (.64) | 42.72 (.33) | 42.12 (.15) | 42.50 (.29) | 43.23 (.27) |
| CIFAR | M4b | 5 ep. | 27.74 (.48) | 34.93 (.96) | 28.50 (.68) | 25.60 (.29) | 28.61 (.50) | 28.61 (.50) | 29.61 (.99) |
| CIFAR | M4b | 20 ep. | 27.45 (.39) | 29.15 (.67) | 27.84 (.59) | 25.30 (.18) | 26.35 (.64) | 25.93 (.64) | 27.94 (.22) |
Table 2: Test error (%) after 5 and 20 epochs, mean (standard deviation) over 5 runs, for the robust hyper-parameter settings (Adagrad, which has no hyper-parameter, is omitted).
| MNIST | M0 | 5 ep. | 7.91 (.14) | 7.87 (.25) | 7.86 (.20) | 7.96 (.13) | 7.96 (.13) | 8.06 (.10) |
| MNIST | M0 | 20 ep. | 7.59 (.07) | 7.45 (.08) | 7.34 (.08) | 7.53 (.10) | 7.53 (.06) | 7.60 (.08) |
| MNIST | M2 | 5 ep. | 1.95 (.17) | 2.30 (.17) | 2.04 (.15) | 1.89 (.07) | 1.87 (.05) | 1.94 (.13) |
| MNIST | M2 | 20 ep. | 1.64 (.06) | 1.59 (.06) | 1.80 (.45) | 1.68 (.04) | 1.68 (.05) | 1.63 (.05) |
| MNIST | M2b | 5 ep. | 1.85 (.05) | 1.93 (.09) | 1.99 (.11) | 1.71 (.11) | 1.72 (.04) | 1.80 (.08) |
| MNIST | M2b | 20 ep. | 1.47 (.10) | 1.62 (.09) | 1.82 (.26) | 1.59 (.05) | 1.59 (.05) | 1.52 (.02) |
| MNIST | M4b | 5 ep. | .85 (.08) | 1.03 (.08) | 1.08 (.13) | .99 (.11) | .99 (.14) | .94 (.09) |
| MNIST | M4b | 20 ep. | .76 (.12) | .80 (.08) | .91 (.08) | .83 (.01) | .83 (.01) | .63 (.11) |
| CIFAR | M0 | 5 ep. | 60.73 (.44) | 60.74 (.39) | 60.55 (.49) | 60.46 (.81) | 60.43 (.42) | 61.71 (.76) |
| CIFAR | M0 | 20 ep. | 59.73 (.19) | 60.08 (.22) | 59.34 (.24) | 59.31 (.11) | 59.37 (.10) | 59.79 (.57) |
| CIFAR | M2 | 5 ep. | 46.62 (.12) | 45.71 (.54) | 45.15 (.15) | 44.84 (17.52) | 45.01 (.14) | 46.08 (.90) |
| CIFAR | M2 | 20 ep. | 45.08 (.32) | 44.44 (.61) | 44.30 (.42) | 43.92 (1.43) | 43.68 (.25) | 44.62 (.56) |
| CIFAR | M2b | 5 ep. | 45.48 (.67) | 44.90 (.50) | 43.33 (.33) | 44.15 (.29) | 44.49 (.18) | 45.45 (.57) |
| CIFAR | M2b | 20 ep. | 43.86 (.52) | 43.60 (.64) | 43.11 (.35) | 42.70 (.31) | 42.70 (.31) | 43.51 (.80) |
| CIFAR | M4b | 5 ep. | 27.74 (.48) | 29.34 (.83) | 28.15 (.21) | 28.61 (.50) | 28.61 (.50) | 30.09 (.73) |
| CIFAR | M4b | 20 ep. | 27.70 (.76) | 28.00 (.72) | 27.46 (.54) | 27.60 (.26) | 27.60 (.26) | 28.05 (.43) |
The experimental evidence (Table 1) shows that Ag-Adam quite often slightly but statistically significantly improves on Adam. A possible explanation is that Ag-Adam adjusts the learning rate more flexibly than Adam (possibly increasing it by a few orders of magnitude). In many cases, ALeRA and SALeRA yield similar results; indeed, whenever ALeRA does not meet catastrophic episodes, ALeRA and SALeRA behave identically. Fig. 1 depicts a representative run in which ALeRA and SALeRA undergo catastrophic episodes (both runs with the same random seed). ALeRA faces a series of catastrophic episodes, where the training error reaches up to 80%. It eventually stabilizes at a medium training loss, but with a test error above 80%. Meanwhile, SALeRA reacts to the first catastrophic episode around epoch 8 by halving the learning rate of each layer. It faces a further catastrophic episode around epoch 13, and halves the learning rates again. Overall, it faces fewer and less severe accidents (in terms of training loss and test error deterioration). Eventually, SALeRA recovers acceptable train and test errors.
It is interesting to note that the learning rates in Fig. 1 steadily increase, contrary to common knowledge. In practice, the learning rate behavior depends on the dataset, the neural architecture and the seed, and can be very diverse: constant decrease, constant increase, or, most of the time, an increase followed by a decrease. This diversity is viewed as an original feature of the proposed approach. It is made possible by the ability to detect, and recover from, catastrophic explosions of the training loss.
Fig. 2 depicts the actual behavior of all algorithms for CIFAR-10 with model M4. Ag-Adam, ALeRA and SALeRA respectively rank first, second and third in terms of test error at 20 epochs. In terms of optimization per se, SALeRA (respectively ALeRA) reaches a training error close to 0 at epoch 10 (resp. epoch 20), whereas Ag-Adam reaches a plateau after epoch 15. Meanwhile, the test error decreases and the test loss increases for all three algorithms. A tentative interpretation is that the neural net yields crisper outputs, close to 0 or 1, which leaves the error unchanged while increasing the loss. This result suggests several perspectives for further work (Section 4).
A sensitivity analysis was performed to determine to what extent the best results in Table 1 depend on the hyper-parameter settings, and to define best configurations for the considered benchmarks. For each algorithm, the settings were compared on all models and all epochs, and the setting with the lowest sum of ranks was chosen. The results for these robust settings are displayed in Table 2 (as Adagrad has no hyper-parameter, it is not listed there). Robust values of the ALeRA and SALeRA hyper-parameters $\alpha$ and $C$ were identified in this way. Interestingly enough, the optimal hyper-parameters found for Adam differ from the values $\beta_1 = 0.9$ and $\beta_2 = 0.999$ suggested in the original paper (Kingma and Ba, 2014). The proposed approaches still show some advantage over Adam and NAG, though they seem more sensitive to parameter tuning. Deriving precise recommendations depending on model characteristics is left for further work.
Let us define a failed run as a run that attains more than 80% test error after 20 epochs. ALeRA is observed to have 18.3% failed runs over the parameter range defined in Section 3.1. The catastrophe management scheme allows SALeRA to avoid over a third of these failures (about 36%), reaching a failure rate of 11.7% over the same parameter range.
It would of course be possible to further diminish the failure rate of SALeRA by setting the PH alarm threshold to a smaller fraction of the initial loss than the one tenth used in all experiments reported above. However, this could interfere with learning rate adaptation by triggering learning rate halvings even when there is no serious alert, thus preventing the adaptation from being as bold as possible. Indeed, a smaller threshold causes a decline of more than half of SALeRA's performances (data not shown), even though it approximately halves the number of failed runs. Furthermore, an even smaller threshold does not further diminish the failure rate, and therefore only harms SALeRA's learning performance. On the other hand, a larger threshold does not significantly improve the best performances, but has a higher failure rate.
One tenth of the initial loss is therefore a very good balance between the aggressive learning rate adaptation scheme and its braking counterpart, at least on these datasets and architectures.
The first proposed contribution relies on the comparison of the gradient momentum with a fixed reference. It is meant to estimate the overall correlation among the sequence of gradients. This correlation can be thought of as a signal-to-noise ratio in the process generated from the current solution, the objective and the successive mini-batches. Depending on this ratio, the process can be accelerated or slowed down. The ALeRA procedure, which implements this idea, proves able to both increase and decrease the learning rate significantly. Furthermore, this process can be plugged into Adam, with a performance improvement on average.
The price to pay for this flexibility is an increased risk of catastrophic episodes, with instant rocketing of the training loss and gradient norm. The proposed approach relies on the conjecture that catastrophic episodes can be rigorously observed and detected. A second conjecture is that a neural net optimizer is almost bound to face such episodes during optimization. These events are mostly detrimental. Before the run, one often chooses small learning rates (and thus slow convergence) to prevent them, and during the run, they are mostly not recovered from. Based on these two conjectures, the second contribution of the paper is an agnostic and principled way to detect and address such episodes. The detection relies on the Page-Hinkley change point detection test. As soon as an event is detected, learning rates are halved and the previous solution is recovered.
A short-term perspective for further research is to apply the proposed approach to recurrent neural networks, and to consider more complex datasets. Another perspective is to replace the halving trick by an approximate line search, e.g. by exploiting the gaps between the actual momentum and the reference one for several values of the momentum weight factor. A third perspective regards the adaptation of the PH detection threshold during learning. In some runs, the PH test is triggered over and over, resulting in a very small learning rate that prevents any further improvement. The goal is to adapt the threshold from the current loss values whenever the PH mechanism is re-initialized (in Algorithm 1). Another perspective is to apply SALeRA to
We warmly thank Steve Altschuler and Lani Wu for making this work possible and supporting it, as well as for very insightful discussions. We also thank Yann Ollivier and Sigurd Angenent for other insightful discussions, and the anonymous reviewers of a preliminary version of this paper for their accurate and constructive comments.
Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, AISTATS 2010, Chia Laguna Resort, Sardinia, Italy, May 13-15, 2010, pages 249–256, 2010. URL http://www.jmlr.org/proceedings/papers/v9/glorot10a.html.
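Let us define $r_0 = 0$, and for $t \ge 1$

$$r_t = (1-\alpha)\, r_{t-1} + \alpha\, u_t,$$

with $u_t$ a unit vector uniformly drawn in $\mathbb{R}^d$. We have by recurrence, for $t \ge 1$,

$$r_t = \alpha \sum_{i=1}^{t} (1-\alpha)^{t-i}\, u_i, \qquad \|r_t\|^2 = \alpha^2 \sum_{i=1}^{t} \sum_{j=1}^{t} (1-\alpha)^{2t-i-j}\, \langle u_i, u_j \rangle.$$

We are dealing with unit vectors which are independently and uniformly drawn. Thus, with $\delta_{ij}$ denoting the Kronecker delta,

$$\mathbb{E}\,\langle u_i, u_j \rangle = \delta_{ij}.$$

Taking the expectation of $\|r_t\|^2$, we have

$$\mathbb{E}\,\|r_t\|^2 = \alpha^2 \sum_{i=1}^{t} (1-\alpha)^{2(t-i)} = \frac{\alpha}{2-\alpha}\left(1 - (1-\alpha)^{2t}\right).$$

The derivation of the variance formula is similar, using in addition $\mathbb{E}\,\langle u_i, u_j \rangle^2 = 1/d$ for $i \ne j$.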
We use here the context and notations of Section 2.2, and make it even simpler by assuming one dimension (the gradient direction does not change) and minimizing the 1D parabola $f(\theta) = \frac{\lambda}{2}\theta^2$, with $\lambda > 0$. It is straightforward to show that, independently of the current solution $\theta$, the optimal value for the learning rate is $\eta^* = 1/\lambda$, and that the value above which the loss deteriorates is $\eta_{\max} = 2/\lambda = 2\eta^*$.
Let us assume that the current learning rate is $\eta > \eta_{\max}$, and that the recovery phase of SALeRA is used to prevent further catastrophic events, with a dividing factor $\beta > 1$. Let us look for the value of $\beta$ that gives the best trade-off between bringing $\eta$ back into $]0, \eta_{\max}[$ and then reaching the optimal value $\eta^*$, as discussed in Section 2.2.
After the PH test has been triggered, a first phase brings $\eta$ below $\eta_{\max}$ by successive divisions by $\beta$. The number of such divisions is $\lfloor \log_\beta(\eta/\eta_{\max}) \rfloor + 1$.
The second phase uses the standard ALeRA procedure to reach $\eta^*$ from the resulting value $\eta'$. In the 1D context, let us assume a simplified procedure that multiplies $\eta$ by $1+\epsilon$ if $\eta < \eta^*$ and by $1-\epsilon$ if $\eta > \eta^*$, for some small $\epsilon > 0$.
Let us consider two cases, depending on whether $\beta$ is smaller or greater than $2 = \eta_{\max}/\eta^*$. We can compute the expected number of SALeRA iterations needed to reach $\eta^*$ from $\eta'$.
If $\beta \le 2$, then $\eta' \ge \eta_{\max}/\beta \ge \eta^*$. The subsequent standard update phase therefore decreases $\eta'$ by multiplying it by $1-\epsilon$ until it becomes less than (and very close to) $\eta^*$. The length of this update phase grows with the gap between $\eta'$ and $\eta^*$, and is thus minimal for $\beta = 2$. Meanwhile, the length of the division phase decreases as $\beta$ increases; thus the optimal value for $\beta \le 2$ is $\beta = 2$.
If $\beta > 2$, then $\eta_{\max}/\beta < \eta^*$. Let us consider the two intervals $[\eta_{\max}/\beta, \eta^*]$ and $[\eta^*, \eta_{\max}]$, and let $U$ and $T$ be the respective expectations of the number of ALeRA iterations needed to reach $\eta^*$ from each of them. The expected total number of iterations is the sum of $U$ and $T$, weighted by the probabilities of arriving in the respective intervals, i.e., by the normalized lengths of these intervals.
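The expected value of $T$ is independent of $\beta$. For $\eta' \in [\eta_{\max}/\beta, \eta^*]$, and only counting the number of multiplications needed to get close to $\eta^*$, the number of iterations is $\log(\eta^*/\eta')/\log(1+\epsilon)$. Integrating over this interval gives $U$ as a function of $\beta$.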
Similarly, the expected value of $T$ can be computed over the interval $[\eta^*, \eta_{\max}]$ using the same approach, again only counting the number of multiplications needed to get close to $\eta^*$.
Let us now summarize the computational costs incurred after a catastrophe has been detected and before $\eta^*$ is reached. Each dividing iteration only involves a forward pass on the current mini-batch; let us denote this cost by $c_f$. The standard ALeRA iterations are more expensive, involving a forward pass plus a backward pass and the weight update; let us denote this cost by $c_b$.
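We are looking for the value of $\beta$ that minimizes the total cost of reaching $\eta^*$ after a catastrophic event has been detected. This cost is the number of divisions weighted by $c_f$ plus the expected number of ALeRA iterations weighted by $c_b$. Equivalently, we can minimize this cost divided by $c_b$, which only depends on the ratio $c = c_f/c_b$.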
It is easy to check empirically (Figure 3) that this function has its global minimum between 3 and 5, depending on the value of the constant $c$, assumed small here. The function then increases toward some asymptotic value. However, the value $\beta = 2$ was initially chosen for historical reasons, by reference to the famed doubling trick frequently used in different areas of machine learning. In the light of these results in the simple 1D case, further work will investigate slightly larger values.
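The trade-off above can be explored with a small simulation of the simplified 1D model. This is a hedged sketch, not the computation behind Figure 3, and it makes several hypothetical assumptions. Before the catastrophe, $\eta/\eta_{\max}$ is drawn log-uniformly in $[1, 100]$. $\eta^* = 1$ and $\epsilon = 0.01$. Each division costs `c_f = 1` and each ALeRA step costs `c_b = 3`. The function names are hypothetical. The printed costs depend on these choices and are not meant to reproduce the figure.

```python
import numpy as np


def recovery_cost(eta, beta, eta_star, eps=0.01, c_f=1.0, c_b=3.0):
    """Cost of recovering from a too-large learning rate eta in the 1D model.

    Phase 1: divide eta by beta until it falls below eta_max = 2 * eta_star
    (one forward pass each, cost c_f). Phase 2: simplified ALeRA steps that
    multiply eta by 1 - eps or 1 + eps until it crosses eta_star (cost c_b each).
    """
    eta_max = 2.0 * eta_star
    divisions = 0
    while eta >= eta_max:
        eta /= beta
        divisions += 1
    steps = 0
    if eta > eta_star:
        while eta > eta_star:
            eta *= 1.0 - eps
            steps += 1
    else:
        while eta < eta_star:
            eta *= 1.0 + eps
            steps += 1
    return divisions * c_f + steps * c_b


def expected_cost(beta, n_samples=1000, max_ratio=100.0, seed=0, **kwargs):
    """Average cost when eta / eta_max is log-uniform on [1, max_ratio] (an assumption)."""
    rng = np.random.default_rng(seed)
    ratios = np.exp(rng.uniform(0.0, np.log(max_ratio), size=n_samples))
    return float(np.mean([recovery_cost(2.0 * r, beta, 1.0, **kwargs) for r in ratios]))


for beta in [1.5, 2.0, 3.0, 4.0, 5.0, 8.0]:
    print(f"beta = {beta}: expected cost {expected_cost(beta):.1f}")
```

Varying `c_f`, `c_b` (i.e. the ratio $c$), `eps` or the assumed distribution of the initial learning rate shows how sensitive the location of the minimum is to the modeling choices.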