1 Introduction
Fast and stable optimization algorithms are what generations of researchers have been pursuing (Cauchy, 1847). Remarkably, stochastic gradient-based optimization, such as stochastic gradient descent (SGD), has witnessed tremendous success in many fields of science and engineering despite its simplicity. Recently, many new methods have been proposed to accelerate optimization by applying adaptive learning rates. In particular, Adagrad (Duchi et al., 2011) and its variants, e.g., RMSprop (Tieleman and Hinton, 2012), Adam (Kingma and Ba, 2014), Adadelta (Zeiler, 2012) and Nadam (Dozat, 2016), have been widely used due to their fast convergence.
However, it has been observed that in many cases these optimization methods converge to bad/suspicious local optima, and have to resort to a warmup heuristic – using a small learning rate in the first few epochs of training – to mitigate the convergence problem (Vaswani et al., 2017; Popel and Bojar, 2018). For example, on the De-En IWSLT’14 dataset, removing warmup increases the training perplexity of a Transformer model from 10 to over 500, as shown in Figure 1. Similar phenomena are observed in other scenarios like BERT pre-training (Devlin et al., 2018).
Since the theoretical underpinnings of the warmup heuristic are lacking, there is neither a guarantee that it works across machine learning settings nor guidance on how warmup should be conducted. Thus, researchers typically use different settings in different applications and have to take a trial-and-error approach, which can be tedious and time-consuming.
In this paper, we conduct both theoretical and empirical analysis of the convergence issue to identify its origin. Specifically, we show that its root cause is that the adaptive learning rate has undesirably large variance in the early stage of model training due to the limited amount of training samples being used. Thus, to reduce such variance, it is better to use smaller learning rates in the first few epochs of training, which justifies the warmup heuristic.
Moreover, we propose a new variant of Adam, called Rectified Adam (RAdam), which explicitly rectifies the variance of the adaptive learning rate based on our derivations. We conduct extensive experiments on language modeling, image classification, and neural machine translation. RAdam brings consistent improvement over vanilla Adam, which verifies that the variance issue generally exists across various tasks and network architectures.
Our main contributions are two-fold:
- We identify the variance issue of the adaptive learning rate and present a theoretical justification for the warmup heuristic. We show that the convergence issue is due to the undesirably large variance of the adaptive learning rate in the early stage of model training.
- We propose a new variant of Adam (i.e., RAdam), which not only explicitly rectifies the variance and is theoretically sound, but also compares favorably with the heuristic warmup.
2 Preliminaries and Motivations
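Generic adaptive methods. Algorithm 1 is a generic framework (all operations are element-wise) that describes many popular adaptive gradient descent algorithms (Reddi et al., 2019). Specifically, different optimization algorithms can be specified by different choices of $\phi(.)$ and $\psi(.)$, which compute the momentum and the adaptive learning rate, respectively. For example, in the Adam algorithm, these two functions are set to:
$$\phi(g_1, \dots, g_t) = \frac{(1-\beta_1)\sum_{i=1}^{t}\beta_1^{t-i} g_i}{1-\beta_1^t}, \qquad \psi(g_1, \dots, g_t) = \sqrt{\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2}}. \tag{1}$$
For numerical stability, $\psi(.)$ is usually calculated as $\hat\psi(g_1, \dots, g_t) = \frac{\sqrt{1-\beta_2^t}}{\epsilon + \sqrt{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2}}$, where $\epsilon$ is set to a relatively small value (e.g., $1 \times 10^{-8}$).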
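To make the generic framework concrete, here is a minimal scalar sketch that computes Adam's $\phi(.)$ and $\psi(.)$ from Equation (1) directly over the full gradient history and plugs them into the framework's update, which we take to be $\theta_t = \theta_{t-1} - \alpha_t\,\phi(.)\,\psi(.)$. The function names (`adam_phi`, `adam_psi`, `generic_adaptive_step`) are our own, and recomputing the sums from scratch is only for readability; practical implementations keep running averages instead.

```python
import math
from typing import Callable, List


def adam_phi(grads: List[float], beta1: float = 0.9) -> float:
    """Adam's phi (Equation 1): bias-corrected EMA of the gradients."""
    t = len(grads)
    ema = sum((1 - beta1) * beta1 ** (t - i) * g for i, g in enumerate(grads, start=1))
    return ema / (1 - beta1 ** t)


def adam_psi(grads: List[float], beta2: float = 0.999) -> float:
    """Adam's psi (Equation 1): inverse root of the bias-corrected EMA of g^2 (no epsilon)."""
    t = len(grads)
    ema_sq = sum((1 - beta2) * beta2 ** (t - i) * g * g for i, g in enumerate(grads, start=1))
    return math.sqrt((1 - beta2 ** t) / ema_sq)


def generic_adaptive_step(theta: float, grads: List[float], lr: float,
                          phi: Callable[[List[float]], float] = adam_phi,
                          psi: Callable[[List[float]], float] = adam_psi) -> float:
    """One step of the generic framework: theta - lr * phi(g_1..g_t) * psi(g_1..g_t)."""
    return theta - lr * phi(grads) * psi(grads)


# Usage: a few steps on f(theta) = theta**2, recomputing phi and psi from the full history.
theta, history = 1.0, []
for _ in range(20):
    history.append(2.0 * theta)  # gradient of theta**2
    theta = generic_adaptive_step(theta, history, lr=0.05)
print(theta)
```

Note that with a single gradient, `adam_psi([g])` reduces to $1/|g|$, which is exactly the special case analyzed at the start of Section 3.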
Learning rate warmup. Instead of setting the learning rate $\alpha_t$ as a constant or in decreasing order, a learning rate warmup strategy sets $\alpha_t$ as some small values in the first few steps. For example, linear warmup sets $\alpha_t = t\,\alpha_0$ when $t < T_w$. Warmup has been demonstrated to be beneficial in many deep learning applications. For example, in the NMT experiments in Figure 1, the perplexity converges to around 500 when warmup is not applied (Adam-vanilla), and it surprisingly decreases to below 10 after applying warmup (Adam-warmup).
To further analyze this phenomenon, we visualize the histogram of the absolute value of gradients on a log scale in Figure 2. We observe that, without applying warmup, the gradient distribution is distorted to have a mass center in relatively small values within 10 updates. Such gradient distortion means that vanilla Adam is trapped in bad/suspicious local optima after the first few updates. Warmup essentially reduces the impact of these problematic updates to avoid the convergence problem. In the following sections, we focus our analysis on learning rate warmup for the Adam algorithm, although it also applies to other algorithms that use similar adaptive learning rate ($\psi(.)$) designs, e.g., RMSprop (Tieleman and Hinton, 2012) and Nadam (Dozat, 2016).
3 Variance of Adaptive Rate
In this section, we first introduce empirical evidence, then analyze the variance of the adaptive learning rate to support our hypothesis: due to the lack of samples in the early stage, the adaptive learning rate has an undesirably large variance, which leads to suspicious/bad local optima.
To begin with, we first analyze a special case. When $t = 1$, we have $\psi(g_1) = \sqrt{1/g_1^2}$. We view $\{g_1, \dots, g_t\}$ as i.i.d. Gaussian random variables following $\mathcal{N}(0, \sigma^2)$. (Footnote 2: The mean-zero normal assumption is valid at the beginning of training, since weights are sampled from normal distributions with mean zero (Balduzzi et al., 2017); further analysis is conducted in Section 5.3.) Therefore, $1/g_1^2$ is subject to the scaled inverse chi-squared distribution, $\text{Scale-inv-}\chi^2(1, \frac{1}{\sigma^2})$. Note that $\mathrm{Var}[\sqrt{1/g_1^2}]$ is divergent. This means that the adaptive ratio can be undesirably large in the first stage of learning. Meanwhile, setting a small learning rate at the early stage can reduce the variance ($\mathrm{Var}[\alpha x] = \alpha^2\,\mathrm{Var}[x]$), thus alleviating this problem. Therefore, we suggest that it is the unbounded variance of the adaptive learning rate in the early stage that causes the problematic updates.
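The divergence can be made explicit in one line. Restricting the integral for the second moment of $\psi(g_1) = 1/|g_1|$ to $|x| \le \sigma$, where the $\mathcal{N}(0, \sigma^2)$ density is at least $e^{-1/2}/\sqrt{2\pi\sigma^2}$, already yields an infinite value:

$$\mathrm{E}\left[\psi(g_1)^2\right] = \mathrm{E}\left[\frac{1}{g_1^2}\right] = \int_{-\infty}^{\infty} \frac{1}{x^2}\,\frac{e^{-x^2/(2\sigma^2)}}{\sqrt{2\pi\sigma^2}}\,dx \;\ge\; \frac{e^{-1/2}}{\sqrt{2\pi\sigma^2}} \int_{-\sigma}^{\sigma} \frac{dx}{x^2} = \infty.$$

Hence $\psi(g_1)$ has no finite second moment and therefore no finite variance, for any $\sigma > 0$.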
3.1 Warmup as Variance Reduction
In this section, we design a set of controlled experiments to verify our hypothesis.
Particularly, we design two variants of Adam, Adam-2k and Adam-eps, and compare them to Adam with warmup and vanilla Adam (without warmup) on the IWSLT’14 German to English dataset (Cettolo et al., 2014).
In the first two thousand iterations of Adam-2k, only the adaptive learning rate ($\psi(.)$) is updated, while the momentum ($\phi(.)$) and parameters ($\theta$) are fixed (Footnote 3: Different from Gotmare et al. (2019), all parameters and first moments are frozen in the first 2000 iterations.); other than this, it follows the original Adam algorithm. For comparison with other methods, its iterations are indexed from -1999 instead of 1. As shown in Figure 1, after getting these additional two thousand samples for estimating the adaptive learning rate, Adam-2k avoids the convergence problem of vanilla Adam. Also, comparing Figure 2 and Figure 3, obtaining enough samples prevents the gradient distribution from being distorted. These observations verify our hypothesis that the lack of sufficient data samples in the early stage is the root cause of the convergence issue.
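The Adam-2k protocol can be sketched for a single scalar parameter as follows: during a freeze phase only the second-moment estimate $v$ is accumulated, while the first moment $m$ and the parameter $\theta$ stay fixed, after which ordinary Adam updates resume. This is a minimal illustration under our own assumptions, not the experimental code: the helper `adam_2k` and its caller-supplied `grad_fn` are hypothetical, and keeping separate bias-correction counters for $m$ and $v$ is our choice.

```python
import math
from typing import Callable


def adam_2k(grad_fn: Callable[[float], float], theta: float, lr: float = 1e-3,
            steps: int = 100, freeze: int = 2000, beta1: float = 0.9,
            beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Adam-2k sketch: the first `freeze` iterations only update the second moment v."""
    m = v = 0.0
    n_m = n_v = 0                      # separate bias-correction counters (our assumption)
    for _ in range(freeze):            # theta and m stay frozen; only v sees new gradients
        g = grad_fn(theta)
        v = beta2 * v + (1 - beta2) * g * g
        n_v += 1
    for _ in range(steps):             # afterwards: ordinary Adam updates
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        n_m += 1
        n_v += 1
        m_hat = m / (1 - beta1 ** n_m)
        v_hat = v / (1 - beta2 ** n_v)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


# Usage with a deterministic toy gradient (f(theta) = theta**2); in practice grad_fn
# would return a fresh stochastic minibatch gradient on every call.
print(adam_2k(lambda th: 2.0 * th, theta=1.0, lr=0.01, steps=200))
```

With a stochastic `grad_fn`, the freeze phase is what supplies the extra samples for estimating the adaptive learning rate before any parameter moves.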
We further demonstrate that the convergence problem can be avoided by reducing the variance of the adaptive learning rate. A straightforward way to reduce the variance is to increase the value of $\epsilon$ in $\hat\psi(.)$. Indeed, if we assume that $\hat\psi(.)$ is subject to the uniform distribution, its variance equals $\frac{1}{12\epsilon^2}$. Therefore, we design Adam-eps, which raises $\epsilon$ from its usual negligible value to a non-negligible one. Its performance is summarized in Figure 1. We observe that it does not suffer from the serious convergence problem of vanilla Adam. This demonstrates that the convergence problem can be alleviated by reducing the variance of the adaptive learning rate, and also explains why tuning $\epsilon$ is important in practice (Liu et al., 2019). Besides, similar to Adam-2k, it prevents the gradient distribution from being distorted (as shown in Figure 3). However, as shown in Figure 1, it performs much worse than Adam-2k and Adam-warmup. We conjecture that this is because a large $\epsilon$ induces a large bias into the adaptive learning rate and slows down the optimization process. Thus, we need a more principled and rigorous way to control the variance of the adaptive learning rate. In the next subsection, we present a theoretical analysis of the variance of the adaptive learning rate.
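The quoted variance is that of a uniform distribution on $[0, 1/\epsilon]$. Because the denominator of $\hat\psi(.)$ is at least $\epsilon$, we have $0 < \hat\psi(.) \le \sqrt{1-\beta_2^t}/\epsilon \le 1/\epsilon$, and under the (rough) uniformity assumption:

$$\mathrm{Var}[\hat\psi(.)] = \frac{(1/\epsilon - 0)^2}{12} = \frac{1}{12\epsilon^2}, \qquad \frac{\mathrm{Var}[\hat\psi(.)]\,\big|_{\epsilon}}{\mathrm{Var}[\hat\psi(.)]\,\big|_{k\epsilon}} = k^2.$$

So increasing $\epsilon$ by a factor $k$ lowers this variance by a factor $k^2$, at the price of the bias mentioned above.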
3.2 Analysis of Adaptive Learning Rate Variance
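As mentioned before, Adam uses the exponential moving average to calculate the adaptive learning rate. For gradients $\{g_1, \dots, g_t\}$, their exponential moving average has a larger variance than their simple average. Also, in the early stage ($t$ is small), the difference between the exponential weights of $\{g_1, \dots, g_t\}$ is relatively small (up to $1 - \beta_2^{t-1}$). Therefore, for ease of analysis, we approximate the distribution of the exponential moving average as the distribution of the simple average (Nau, 2014), i.e., $p(\psi(.)) = p\left(\sqrt{\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2}}\right) \approx p\left(\sqrt{\frac{t}{\sum_{i=1}^{t} g_i^2}}\right)$. Since $g_i \sim \mathcal{N}(0, \sigma^2)$, we have $\frac{t}{\sum_{i=1}^{t} g_i^2} \sim \text{Scale-inv-}\chi^2(t, \frac{1}{\sigma^2})$. Therefore, we assume $\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2}$ is also subject to a scaled inverse chi-square distribution with $\rho$ degrees of freedom (further analysis of this approximation is conducted in Section 5.3). Based on this assumption, we have $\mathrm{Var}[\psi^2(.)] = \frac{2\rho^2}{(\rho-2)^2(\rho-4)\sigma^4}$ for $\rho > 4$, and we can write down the PDF of $\psi^2(.)$. Now, we proceed to the analysis of its square root variance, i.e., $\mathrm{Var}[\psi(.)]$.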
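The simple-average step can be checked numerically: for i.i.d. $g_i \sim \mathcal{N}(0, \sigma^2)$, the quantity $t/\sum_i g_i^2$ should follow $\text{Scale-inv-}\chi^2(t, 1/\sigma^2)$, whose mean is $t/((t-2)\sigma^2)$ for $t > 2$. The sketch below is a plain Monte Carlo check; the helper names, sample size, and seed are arbitrary choices of ours.

```python
import random
import statistics


def sample_inv_mean_sq(t: int, sigma: float, n: int, seed: int = 0) -> list:
    """Draw n samples of t / sum(g_i^2) with g_i ~ N(0, sigma^2), i = 1..t."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        s = sum(rng.gauss(0.0, sigma) ** 2 for _ in range(t))
        out.append(t / s)
    return out


def scaled_inv_chi2_mean(rho: float, sigma: float) -> float:
    """Mean of Scale-inv-chi^2(rho, 1/sigma^2); finite only for rho > 2."""
    return rho / ((rho - 2) * sigma ** 2)


samples = sample_inv_mean_sq(t=20, sigma=1.0, n=20000)
print(statistics.fmean(samples), scaled_inv_chi2_mean(20, 1.0))  # should be close
```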
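Theorem 3.2. If $\psi^2(.) \sim \text{Scale-inv-}\chi^2(\rho, \frac{1}{\sigma^2})$, then $\mathrm{Var}[\psi(.)]$ monotonically decreases as $\rho$ increases.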
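Proof. For ease of notation, we refer to $\psi^2(.)$ as $x$ and to $\frac{1}{\sigma^2}$ as $\tau^2$. Thus, $x \sim \text{Scale-inv-}\chi^2(\rho, \tau^2)$ and:
$$p(x) = \frac{(\tau^2\rho/2)^{\rho/2}}{\Gamma(\rho/2)}\,\frac{\exp\left[-\frac{\rho\tau^2}{2x}\right]}{x^{1+\rho/2}}, \tag{2}$$
where $\Gamma(.)$ is the gamma function. Therefore, we have:
$$\mathrm{E}[\sqrt{x}] = \tau\sqrt{\frac{\rho}{2}}\,\frac{\Gamma\left(\frac{\rho-1}{2}\right)}{\Gamma\left(\frac{\rho}{2}\right)}, \tag{3}$$
and, for $\rho > 4$,
$$\mathrm{Var}[\psi(.)] = \mathrm{Var}[\sqrt{x}] = \mathrm{E}[x] - \mathrm{E}[\sqrt{x}]^2 = \tau^2\left(\frac{\rho}{\rho-2} - \frac{\rho\,2^{2\rho-5}}{\pi}\,\mathcal{B}\left(\frac{\rho-1}{2}, \frac{\rho-1}{2}\right)^2\right), \tag{4}$$
where $\mathcal{B}(.)$ is the Beta function. By analyzing the derivative of $\mathrm{Var}[\psi(.)]$, we know it monotonically decreases as $\rho$ increases. The detailed derivation is elaborated in Appendix A. ∎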
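Equation 4 can be evaluated stably by working with log-gamma values instead of the Beta-function form, using $\mathrm{E}[\sqrt{x}]^2 = \tau^2 \frac{\rho}{2}\left(\Gamma(\frac{\rho-1}{2})/\Gamma(\frac{\rho}{2})\right)^2$, which the Legendre duplication formula shows to equal the Beta term. The sketch below (function name ours) evaluates $\mathrm{Var}[\psi(.)]/\tau^2$ this way and checks on a grid that it decreases in $\rho$; a grid check illustrates the theorem but is of course not a proof.

```python
import math


def var_sqrt_scaled_inv_chi2(rho: float, tau2: float = 1.0) -> float:
    """Var[sqrt(x)] for x ~ Scale-inv-chi^2(rho, tau2), i.e. Equation 4 (rho > 4)."""
    e_x = tau2 * rho / (rho - 2)
    # E[sqrt(x)]^2 with the Gamma ratio taken in log space, so that large rho does
    # not overflow (the Beta form needs the factor 2**(2*rho - 5)).
    log_ratio = math.lgamma((rho - 1) / 2) - math.lgamma(rho / 2)
    e_sqrt_x_sq = tau2 * (rho / 2) * math.exp(2 * log_ratio)
    return e_x - e_sqrt_x_sq


values = [var_sqrt_scaled_inv_chi2(r) for r in range(5, 2001)]
print(values[0], values[-1])                            # rho = 5 and rho = 2000
print(all(a > b for a, b in zip(values, values[1:])))   # decreasing on this grid?
```

For large $\rho$ the two terms of Equation 4 nearly cancel, so the log-space form is also the safer one to subtract.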
Theorem 3.2 gives a qualitative analysis of the variance of the adaptive learning rate. It shows that, due to the lack of training samples in the early stage, $\mathrm{Var}[\psi(.)]$ is larger than in the late stage (Figure 8). To rigorously constrain the variance, we perform a quantified analysis of $\mathrm{Var}[\psi(.)]$ by estimating the degrees of freedom $\rho$.
4 Rectified Adaptive Learning Rate
In the previous section, Equation 4 gives the analytic form of $\mathrm{Var}[\psi(.)]$, where $\rho$ is the degrees of freedom. Here, we first give an estimation of $\rho$ based on $t$ to conduct a quantified analysis of $\mathrm{Var}[\psi(g_1, \dots, g_t)]$, then we describe the design of the learning rate rectification, and compare it to heuristic warmup strategies.
4.1 Estimation of $\rho$
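As the exponential moving average (EMA) is widely used in economics, it is usually interpreted as an approximation to the simple moving average (SMA) (Nau, 2014), i.e.,
$$\frac{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2}{1-\beta_2^t} \approx \frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)}, \tag{5}$$
where $f(t,\beta_2)$ is the length of the SMA that gives the SMA the same “center of mass” as the EMA. In other words, $f(t,\beta_2)$ satisfies:
$$\frac{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i}\cdot i}{1-\beta_2^t} = \frac{\sum_{i=1}^{f(t,\beta_2)}(t+1-i)}{f(t,\beta_2)}.$$
By solving this equation, we have: $f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2t\beta_2^t}{1-\beta_2^t}$. In the previous section, we assume: $\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^{t}\beta_2^{t-i} g_i^2} \sim \text{Scale-inv-}\chi^2(\rho, \frac{1}{\sigma^2})$. Here, since $g_i \sim \mathcal{N}(0, \sigma^2)$, we have $\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2} \sim \text{Scale-inv-}\chi^2(f(t,\beta_2), \frac{1}{\sigma^2})$. Thus, Equation 5 views $\text{Scale-inv-}\chi^2(f(t,\beta_2), \frac{1}{\sigma^2})$ as an approximation to $\text{Scale-inv-}\chi^2(\rho, \frac{1}{\sigma^2})$. Therefore, we treat $f(t,\beta_2)$ as an estimation of $\rho$. For ease of notation, we mark $f(t,\beta_2)$ as $\rho_t$. Also, we record $\frac{2}{1-\beta_2} - 1$ as $\rho_\infty$ (maximum length of the approximated SMA), due to the inequality $f(t,\beta_2) \le \lim_{t\to\infty} f(t,\beta_2) = \frac{2}{1-\beta_2} - 1$.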
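The closed form for $f(t, \beta_2)$ is easy to sanity-check numerically. The sketch below (helper names ours) computes $\rho_t$ and $\rho_\infty$ and compares the EMA's center of mass with that of an SMA of length $\rho_t$, using the fact that the mean of $t+1-i$ over $i = 1, \dots, f$ is $t + 1 - (f+1)/2$, which we also apply for non-integer $f$.

```python
def rho_inf(beta2: float) -> float:
    """Maximum length of the approximated SMA: 2 / (1 - beta2) - 1."""
    return 2.0 / (1.0 - beta2) - 1.0


def rho_t(t: int, beta2: float) -> float:
    """f(t, beta2): length of the SMA whose center of mass matches the EMA's."""
    b_t = beta2 ** t
    return rho_inf(beta2) - 2.0 * t * b_t / (1.0 - b_t)


def ema_center(t: int, beta2: float) -> float:
    """Left-hand side: bias-corrected EMA weights applied to the index i."""
    num = sum((1 - beta2) * beta2 ** (t - i) * i for i in range(1, t + 1))
    return num / (1 - beta2 ** t)


def sma_center(t: int, f: float) -> float:
    """Right-hand side: mean of t + 1 - i over i = 1..f, i.e. t + 1 - (f + 1) / 2."""
    return t + 1 - (f + 1) / 2


for t in (1, 5, 100, 10000):
    print(t, rho_t(t, 0.999), ema_center(t, 0.999), sma_center(t, rho_t(t, 0.999)))
```

With $\beta_2 = 0.999$, $\rho_t$ stays close to $t$ in the first steps and first exceeds 4 at $t = 5$, which matters for the rectification below.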
4.2 Variance Estimation and Rectification
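Based on the previous estimation, we have $\mathrm{Var}[\psi(.)] = \tau^2\left(\frac{\rho_t}{\rho_t-2} - \frac{\rho_t\,2^{2\rho_t-5}}{\pi}\,\mathcal{B}\left(\frac{\rho_t-1}{2}, \frac{\rho_t-1}{2}\right)^2\right)$. The value of this function in the early stage is significantly larger than in the late stage (as analyzed later, it decays roughly at the speed of $O(\frac{1}{\rho_t})$). For example, the variance at $\rho_t = 5$ is over 100 times larger than the variance at $\rho_t = 500$. Additionally, based on Theorem 3.2, we know $\min_{\rho_t}\mathrm{Var}[\psi(.)] = \mathrm{Var}[\psi(.)]\big|_{\rho_t = \rho_\infty}$ and mark this minimal value as $C_{\text{var}}$. In order to ensure that the adaptive learning rate ($\psi(.)$) has consistent variance, we rectify the variance at the $t$-th timestamp as below,
$$\mathrm{Var}[r_t\,\psi(g_1, \dots, g_t)] = C_{\text{var}}, \quad \text{where } r_t = \sqrt{\frac{C_{\text{var}}}{\mathrm{Var}[\psi(g_1, \dots, g_t)]}}.$$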
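Although we have the analytic form of $\mathrm{Var}[\psi(.)]$ (i.e., Equation 4), it is not numerically stable. Therefore, we use the first-order approximation to calculate the rectification term. Specifically, by approximating $\sqrt{\psi^2(.)}$ to the first order (Wolter, 2007),
$$\sqrt{\psi^2(.)} \approx \sqrt{\mathrm{E}[\psi^2(.)]} + \frac{1}{2\sqrt{\mathrm{E}[\psi^2(.)]}}\left(\psi^2(.) - \mathrm{E}[\psi^2(.)]\right) \quad \text{and} \quad \mathrm{Var}[\psi(.)] \approx \frac{\mathrm{Var}[\psi^2(.)]}{4\,\mathrm{E}[\psi^2(.)]}.$$
Since $\psi^2(.) \sim \text{Scale-inv-}\chi^2(\rho_t, \frac{1}{\sigma^2})$, we have:
$$\mathrm{Var}[\psi(.)] \approx \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}. \tag{6}$$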
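For completeness, the substitution behind Equation 6 uses the mean $\frac{\rho_t\tau^2}{\rho_t-2}$ and variance $\frac{2\rho_t^2\tau^4}{(\rho_t-2)^2(\rho_t-4)}$ of $\text{Scale-inv-}\chi^2(\rho_t, \tau^2)$ with $\tau^2 = 1/\sigma^2$ (valid for $\rho_t > 4$):

$$\mathrm{Var}[\psi(.)] \approx \frac{\mathrm{Var}[\psi^2(.)]}{4\,\mathrm{E}[\psi^2(.)]} = \frac{2\rho_t^2\tau^4}{(\rho_t-2)^2(\rho_t-4)} \cdot \frac{\rho_t-2}{4\rho_t\tau^2} = \frac{\rho_t\,\tau^2}{2(\rho_t-2)(\rho_t-4)} = \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}.$$

For large $\rho_t$ the right-hand side behaves like $\frac{1}{2\rho_t\sigma^2}$, which is the $O(1/\rho_t)$ decay referred to in the text.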
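In Section 5.3, we conduct simulation experiments to examine Equation 6 and find that it is a reliable approximation. Also, we know that $\mathrm{Var}[\psi(.)]$ decreases approximately at the speed of $O(\frac{1}{\rho_t})$. With this approximation, we can calculate the rectification term as:
$$r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}}.$$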
Applying our rectification term to Adam, we come up with a new variant of Adam, RAdam, as summarized in Algorithm 2. Specifically, when the length of the approximated SMA is less than or equal to 4, the variance of the adaptive learning rate is intractable and the adaptive learning rate is inactivated. Otherwise, we calculate the variance rectification term and update parameters with the adaptive learning rate. It is worth mentioning that, if $\beta_2 \le 0.6$, we have $\rho_\infty \le 4$ and RAdam degenerates to SGD with momentum.
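The procedure described in this paragraph can be sketched for a single scalar parameter as follows: Adam's moment estimates, the SMA length $\rho_t$, a momentum-only step while $\rho_t \le 4$, and a rectified adaptive step afterwards. This is our own minimal reading of the description, not the reference implementation; `RAdamState` and `radam_step` are hypothetical names, and the sketch omits the $\epsilon$ term and weight decay.

```python
import math
from dataclasses import dataclass


@dataclass
class RAdamState:
    m: float = 0.0   # EMA of gradients (first moment)
    v: float = 0.0   # EMA of squared gradients (second moment)
    t: int = 0       # number of updates taken so far


def radam_step(theta: float, g: float, state: RAdamState, lr: float = 1e-3,
               beta1: float = 0.9, beta2: float = 0.999) -> float:
    """One RAdam-style update of a scalar parameter (no epsilon, no weight decay)."""
    state.t += 1
    t = state.t
    state.m = beta1 * state.m + (1 - beta1) * g
    state.v = beta2 * state.v + (1 - beta2) * g * g
    m_hat = state.m / (1 - beta1 ** t)                 # bias-corrected momentum
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    b_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * b_t / (1.0 - b_t)      # length of the approximated SMA
    if rho_t > 4.0:
        l_t = math.sqrt((1 - b_t) / state.v)           # adaptive learning rate (needs v > 0)
        r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                        / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
        return theta - lr * r_t * m_hat * l_t
    return theta - lr * m_hat                           # variance intractable: momentum only


# Usage: minimize f(theta) = (theta - 3)**2 starting from theta = 0.
state, theta = RAdamState(), 0.0
for _ in range(1000):
    theta = radam_step(theta, 2.0 * (theta - 3.0), state, lr=0.1)
print(theta)  # should end up near the minimizer 3.0
```

With $\beta_2 \le 0.6$ we get $\rho_t \le \rho_\infty \le 4$ for every $t$, so the `if` branch is never taken and the sketch reduces to SGD with momentum, matching the degenerate case noted above.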
4.3 In Comparison with Warmup
We notice that $r_t$ has a similar form to the heuristic linear warmup, which can be viewed as setting the rectification term as $\frac{\min(t, T_w)}{T_w}$. This verifies our intuition that warmup works as a variance reduction technique. Comparing these two strategies, RAdam deactivates the adaptive learning rate when its variance is divergent, thus avoiding undesired instability in the first few updates. Besides, our method does not require an additional hyperparameter (i.e., $T_w$) to control the variance reduction and can automatically adapt to different moving average rules.
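To see the resemblance, the sketch below (helper names ours) tabulates the linear-warmup factor $\min(t, T_w)/T_w$ next to RAdam's $r_t$. The value $T_w = 2000$ is an arbitrary choice for illustration, and reporting $r_t = 0$ when $\rho_t \le 4$ is only a convention for the table, since RAdam then takes a momentum-only step rather than a scaled adaptive one.

```python
import math


def linear_warmup_factor(t: int, warmup_steps: int) -> float:
    """Heuristic linear warmup viewed as a rectification term: min(t, T_w) / T_w."""
    return min(t, warmup_steps) / warmup_steps


def radam_factor(t: int, beta2: float = 0.999) -> float:
    """RAdam's r_t; 0.0 stands in for 'adaptive learning rate deactivated'."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    b_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * b_t / (1.0 - b_t)
    if rho_t <= 4.0:
        return 0.0
    return math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                     / ((rho_inf - 4) * (rho_inf - 2) * rho_t))


for t in (1, 10, 100, 1000, 2000, 5000):
    print(f"t={t:5d}  warmup={linear_warmup_factor(t, 2000):.3f}  radam={radam_factor(t):.3f}")
```

Both factors start near zero and rise to one, but the length of RAdam's ramp is set by $\beta_2$ through $\rho_t$ rather than by a separate $T_w$.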
In this paper, we identify and fix an underlying issue of adaptive optimization methods rather than of neural architectures. Thus, the proposed rectification term is orthogonal to other training stabilization techniques such as gradient clipping (Bengio et al., 2013), initialization (Balduzzi et al., 2017; Zhang et al., 2019) and normalization (Ba et al., 2016; Ioffe and Szegedy, 2015). Indeed, these techniques can be integrated with our proposed variance rectification. Specifically, since warmup was originally proposed to handle gradient variance for SGD (Goyal et al., 2017; Gotmare et al., 2019; Xiao et al., 2019), RAdam can also be integrated with the warmup heuristic to handle some extreme cases such as training with very large batches.
5 Experiments
We evaluate RAdam on several benchmarks: One Billion Word for language modeling; CIFAR10 and ImageNet for image classification; and IWSLT’14 De-En/En-De and WMT’16 En-De for neural machine translation. (Footnote 4: The detailed hyperparameter settings are elaborated in Appendix B.) Following Loshchilov and Hutter (2017), we decouple weight decay in vanilla Adam, Adam with warmup, and RAdam in our experiments.
5.1 Comparing to Vanilla Adam
As analyzed before, the adaptive learning rate has undesirably large variance in the early stage of training and leads to suspicious/bad local optima on NMT. One question we are interested in answering is whether such an issue widely exists in other similar tasks and applications. Thus, we conduct a set of experiments with two classical tasks of NLP and CV, i.e., language modeling and image classification. RAdam not only results in consistent improvements over vanilla Adam, but also demonstrates its robustness to the change of learning rates. This verifies that the variance issue exists in various machine learning applications and has a big impact on model behavior. Detailed comparison and analysis are described as follows.
Table 1: Language modeling results on One Billion Word (columns: Method, One Billion Word).
Performance Comparison. The performances on language modeling (i.e., One Billion Word (Chelba et al., 2013)) and image classification (i.e., CIFAR10 (Krizhevsky et al., 2009) and ImageNet (Deng et al., 2009)) are summarized in Table 1 and Table 2, and their learning curves are presented in Figure 4 and Figure 5, respectively. (Footnote 5: In One Billion Word, rare words that occur less than 3 times are replaced with a special token, which shrinks the dictionary from 7.9M to 6.4M.) The results show that RAdam outperforms Adam on all three datasets. As shown in Figure 4, although the rectification term makes RAdam slower than vanilla Adam in the first few epochs, it allows RAdam to converge faster after that. In other words, by reducing the variance of the adaptive learning rate in the early stage, it achieves both faster convergence and better performance, which verifies the impact of the variance issue. We also observe that RAdam obtains consistent improvements over Adam on image classification. It is worth noting that, on both ImageNet and CIFAR10, although RAdam fails to outperform SGD in terms of test accuracy, it results in better training performance (e.g., a higher training accuracy than SGD on ImageNet).
Robustness to Learning Rate Change. Besides performance improvements, RAdam also improves the robustness of model training. We use different initial learning rates, conduct experiments with ResNet-20 on the CIFAR10 dataset, and summarize their performance in Figure 6. For learning rates within a broad range, RAdam achieves consistent model performance (its test accuracy curves highly overlap with each other), while Adam and SGD are shown to be sensitive to the learning rate. This observation can be interpreted as follows: by rectifying the variance of the adaptive learning rate, RAdam improves the robustness of model training and can adapt to a broader range of learning rates.
5.2 Comparing to Heuristic Warmup
To examine the effectiveness of RAdam, we first conduct comparisons on neural machine translation, on which the state-of-the-art employs Adam with linear warmup. Specifically, we conduct experiments on three datasets, i.e., IWSLT’14 De-En, IWSLT’14 En-De, and WMT’16 En-De. Due to the limited size of the IWSLT’14 dataset, we conduct experiments using 5 different random seeds and report their mean and standard deviation. As discussed before, the vanilla Adam algorithm leads to suspicious/bad local optima (i.e., converges to a training perplexity around 500), and needs a learning rate warmup stage to stabilize the training.
We summarize the performance obtained with the heuristic warmup and our proposed rectification term in Table 3 and visualize the training curve of IWSLT De-En in Figure 1. With a consistent adaptive learning rate variance, our proposed method achieves similar performance to that of previous state-of-the-art warmup heuristics. It verifies our intuition that the problematic updates of Adam are indeed caused by the undesirably large variance in the early stage.
Table 3: Comparison with the heuristic warmup on neural machine translation (columns: Method, IWSLT’14 DE-EN, IWSLT’14 EN-DE, WMT’16 EN-DE; rows include Adam with warmup).
Moreover, we applied Adam with warmup on the CIFAR10 dataset. Its best test accuracy is similar to that of RAdam. However, we found that RAdam requires less hyperparameter tuning. Specifically, we visualize their learning curves in Figure 7. For some warmup steps, Adam with warmup is relatively more sensitive to the choice of the learning rate. RAdam, at the same time, is not only more robust, but can also automatically control the warmup behavior (i.e., without requiring the length of warmup). For example, at one of the tested learning rates, Adam with 100 steps of warmup fails to get satisfying performance, whereas RAdam succeeds with the original setting of the moving average calculation. We conjecture that this is because RAdam, which is based on a rigorous variance analysis, explicitly avoids the extreme situation where the variance is divergent, and rectifies the variance to be consistent in other situations.
5.3 Simulated Verification
In Sections 3 and 4, we approximate $\sqrt{\psi^2(.)}$ to the first order, and assume $\psi^2(.)$ is subject to a scaled inverse chi-square distribution (this assumption covers the approximation from EMA to SMA). In this section, we examine these two approximations using simulations.
First Order Approximation of $\sqrt{\psi^2(.)}$. To compare Equations 6 and 4, we set $\tau = 1$ and plot their values and their difference as functions of $\rho$ in Figure 8. The curves of the analytic form and the first-order approximation highly overlap, and their difference is smaller than their values by more than an order of magnitude. This result verifies the reliability of our first-order approximation.
Scaled Inverse Chi-Square Distribution Assumption. In this paper, we assume $g_i$ follows a normal distribution with zero mean. We also assume $\psi^2(.)$ follows the scaled inverse chi-square distribution to derive $\mathrm{Var}[\psi(.)]$, based on the similarity between the exponential moving average and the simple moving average. Here, we empirically verify this assumption.
Specifically, since $g_t$ in the optimization problem may not be zero-mean, we assume its expectation is $\mu$ and sample $g_t$ from a normal distribution with mean $\mu$. Then, based on these samples, we calculate the variance of the original adaptive learning rate and of the proposed rectified adaptive learning rate, i.e., $\mathrm{Var}[\psi(.)]$ and $\mathrm{Var}[r_t\,\psi(.)]$, respectively. We fix $\beta_2$, the number of sampled trajectories, and the number of iterations, and summarize the simulation results in Figure 9. Across all six settings with different $\mu$, the adaptive learning rate has a larger variance in the first stage, while the rectified adaptive learning rate has relatively consistent variance. This verifies the reliability of our assumption.
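A scaled-down version of this simulation fits in plain Python. Every numeric setting here ($\beta_2 = 0.999$, $\mu = 0.5$ as the non-zero mean, unit variance, 1000 trajectories, 300 iterations, the seed) is our own choice rather than the paper's configuration, and steps with $\rho_t \le 4$ are skipped because $r_t$ is undefined there.

```python
import math
import random


def simulate(mu=0.5, beta2=0.999, n_traj=1000, n_iter=300, seed=0):
    """Per-step variance of psi and of r_t * psi across sampled trajectories."""
    rng = random.Random(seed)
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    v = [0.0] * n_traj                      # EMA of squared gradients, one per trajectory
    raw_var, rect_var = {}, {}
    for t in range(1, n_iter + 1):
        b_t = beta2 ** t
        rho_t = rho_inf - 2.0 * t * b_t / (1.0 - b_t)
        psi = []
        for k in range(n_traj):
            g = rng.gauss(mu, 1.0)          # non-zero-mean gradient sample
            v[k] = beta2 * v[k] + (1 - beta2) * g * g
            psi.append(math.sqrt((1 - b_t) / v[k]))
        if rho_t > 4.0:                     # r_t is only defined when rho_t > 4
            mean = sum(psi) / n_traj
            var = sum((p - mean) ** 2 for p in psi) / n_traj
            r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                            / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
            raw_var[t] = var                # Var[psi]
            rect_var[t] = r_t ** 2 * var    # Var[r_t * psi]; r_t is deterministic
    return raw_var, rect_var


raw, rect = simulate()
for t in (5, 10, 50, 300):
    print(t, raw[t], rect[t])
```

Because $r_t$ does not depend on the sampled gradients, $\mathrm{Var}[r_t\,\psi(.)] = r_t^2\,\mathrm{Var}[\psi(.)]$, so the rectified variance can be read off the raw one at each step.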
6 Conclusion
In this paper, we explore the underlying principle of the effectiveness of the warmup heuristic used for adaptive optimization algorithms. Specifically, we identify that, due to the limited amount of samples in the early stage of model training, the adaptive learning rate has an undesirably large variance and can cause the model to converge to suspicious/bad local optima. We provide both empirical and theoretical evidence to support our hypothesis, and further propose a new variant of Adam, whose adaptive learning rate is rectified so as to have a consistent variance. Empirical results demonstrate the effectiveness of our proposed method. In future work, we plan to apply the proposed method to other applications such as Named Entity Recognition (Reimers and Gurevych, 2017; Lin et al., 2019). Another interesting direction to pursue is to adapt the choice of $\beta_2$ based on the variance estimation of different parameters, i.e., use a larger $\beta_2$ for parameters with a larger variance.
References
- Ba et al. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
- Balduzzi et al. (2017). The shattered gradients problem: if ResNets are the answer, then what is the question? In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 342–350.
- Bengio et al. (2013). Advances in optimizing recurrent networks. In 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 8624–8628.
- Cauchy (1847). Méthode générale pour la résolution des systèmes d’équations simultanées. Comp. Rend. Sci. Paris 25, pp. 536–538.
- Cettolo et al. (2014). Report on the 11th IWSLT evaluation campaign, IWSLT 2014.
- Chelba et al. (2013). One billion word benchmark for measuring progress in statistical language modeling. arXiv preprint arXiv:1312.3005.
- Deng et al. (2009). ImageNet: a large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255.
- Devlin et al. (2018). BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Dozat (2016). Incorporating Nesterov momentum into Adam.
- Duchi et al. (2011). Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research 12 (Jul), pp. 2121–2159.
- Gotmare et al. (2019). A closer look at deep learning heuristics: learning rate restarts, warmup and distillation. In International Conference on Learning Representations.
- Goyal et al. (2017). Accurate, large minibatch SGD: training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677.
- He et al. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
- Ioffe and Szegedy (2015). Batch normalization: accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167.
- Kingma and Ba (2014). Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Krizhevsky et al. (2009). Learning multiple layers of features from tiny images. Technical report.
- Lin et al. (2019). Reliability-aware dynamic feature composition for name tagging. In Proceedings of the 57th Conference of the Association for Computational Linguistics, pp. 165–174.
- Liu et al. (2018). Efficient contextualized representation: language model pruning for sequence labeling. arXiv preprint arXiv:1804.07827.
- Liu et al. (2019). RoBERTa: a robustly optimized BERT pretraining approach. arXiv preprint arXiv:1907.11692.
- Loshchilov and Hutter (2017). Fixing weight decay regularization in Adam. arXiv preprint arXiv:1711.05101.
- Nau (2014). Forecasting with moving averages.
- Ott et al. (2019). fairseq: a fast, extensible toolkit for sequence modeling. arXiv preprint arXiv:1904.01038.
- Popel and Bojar (2018). Training tips for the Transformer model. The Prague Bulletin of Mathematical Linguistics 110 (1), pp. 43–70.
- Reddi et al. (2019). On the convergence of Adam and beyond. arXiv preprint arXiv:1904.09237.
- Reimers and Gurevych (2017). Optimal hyperparameters for deep LSTM-networks for sequence labeling tasks. arXiv preprint arXiv:1707.06799.
- Szegedy et al. (2016). Rethinking the Inception architecture for computer vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826.
- Tieleman and Hinton (2012). Lecture 6.5-rmsprop, Coursera: neural networks for machine learning. University of Toronto, Technical Report.
- Vaswani et al. (2017). Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008.
- Wolter (2007). Taylor series methods. In Introduction to Variance Estimation, pp. 226–271.
- Xiao et al. (2019). DSCOVR: randomized primal-dual block coordinate algorithms for asynchronous distributed optimization. Journal of Machine Learning Research 20 (43), pp. 1–58.
- Zeiler (2012). ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
- Zhang et al. (2019). Fixup initialization: residual learning without normalization. arXiv preprint arXiv:1901.09321.
Appendix A Proof of Theorem 3.2
To prove monotonicity, we need to show
The target inequality can be rewritten as
This inequality is equivalent to:
where this step follows from the Legendre duplication formula. Simplifying the above inequality, we get:
We only need to show
where the first inequality is from .
Therefore, we only need to show
which is equivalent to
where this step follows from the Legendre duplication formula.
So we only need to show
Using Gautschi’s inequality, we have
Appendix B Implementation Details
B.1 Language Modeling
Our implementation is based on previous work (Liu et al., 2018). Specifically, we use two-layer LSTMs with 2048 hidden states and adaptive softmax to conduct experiments on the One Billion Word dataset. Randomly initialized word embeddings of 300 dimensions are used as the input, and the adaptive softmax is incorporated with its default cut-off setting. Additionally, as pre-processing, we replace all tokens occurring 3 times or fewer with UNK, which shrinks the dictionary from 7.9M to 6.4M. Dropout is applied to each layer, and gradients are clipped at 5.0. We use the default hyper-parameters to update moving averages, i.e., $\beta_1 = 0.9$ and $\beta_2 = 0.999$. The learning rate starts from 0.001 and is decayed at the start of the 10th epoch. LSTMs are unrolled for 20 steps without resetting the LSTM states, and the batch size is set to 128. All models are trained on one NVIDIA Tesla V100 GPU.
B.2 Image Classification
We use the default ResNet architectures (He et al., 2016) in a public PyTorch re-implementation (Footnote 6: https://github.com/bearpaw/pytorch-classification). Specifically, we use a 20-layer ResNet (9 basic blocks) for CIFAR-10 and an 18-layer ResNet (8 basic blocks) for ImageNet. The batch size, the number of training epochs, and the two epochs at which the learning rate is decayed by a constant factor are set separately for CIFAR-10 and ImageNet. For Adam and RAdam, we set the moving-average coefficients $\beta_1$ and $\beta_2$; for SGD, we set a momentum factor; weight decay is also applied. Random cropping and random horizontal flipping are applied to training data.
B.3 Neural Machine Translation
Our experiments are based on the default Transformer (Vaswani et al., 2017) implementation from the fairseq package (Ott et al., 2019). Specifically, on the IWSLT’14 dataset we use word embeddings with 512 dimensions and a 6-layer encoder/decoder with 4 heads and 1024 hidden dimensions; on the WMT’16 dataset we use word embeddings with 512 dimensions and a 6-layer encoder/decoder with 8 heads and 2048 hidden dimensions. Label-smoothed cross entropy is used as the objective function with a label-smoothing uncertainty (Szegedy et al., 2016). We use linear learning rate decay, and the checkpoints of the last few epochs are averaged before evaluation. As to the warmup strategy, we use linear warmup for Adam over the first $T_w$ updates. On the IWSLT’14 dataset, we conduct training on one NVIDIA Tesla V100 GPU with a fixed maximum batch size, dropout, weight decay, and gradient-norm clipping. On the WMT’16 dataset, we conduct training on four NVIDIA Quadro R8000 GPUs.