Normalized Direction-preserving Adam

09/13/2017 ∙ by Zijun Zhang, et al. ∙ University of Calgary ∙ The University of Hong Kong

Optimization algorithms for training deep models not only affect the convergence rate and stability of the training process, but are also highly related to the generalization performance of the models. While adaptive algorithms, such as Adam and RMSprop, have shown better optimization performance than stochastic gradient descent (SGD) in many scenarios, they often lead to worse generalization performance than SGD when used for training deep neural networks (DNNs). In this work, we identify two problems of Adam that may degrade the generalization performance. As a solution, we propose the normalized direction-preserving Adam (ND-Adam) algorithm, which combines the best of both worlds, i.e., the good optimization performance of Adam, and the good generalization performance of SGD. In addition, we further improve the generalization performance in classification tasks by using batch-normalized softmax. This study suggests the need for more precise control over the training process of DNNs.


1 Introduction

In contrast with the growing complexity of neural network architectures (Szegedy et al., 2015; He et al., 2016; Hu et al., 2017), the training methods remain relatively simple. Most practical optimization methods for deep neural networks (DNNs) are based on the stochastic gradient descent (SGD) algorithm. However, the learning rate of SGD, as a hyperparameter, is often difficult to tune, since the magnitudes of different parameters can vary widely, and adjustment is required throughout the training process.

To tackle this problem, several adaptive variants of SGD have been developed, including Adagrad (Duchi et al., 2011), Adadelta (Zeiler, 2012), RMSprop (Tieleman & Hinton, 2012), Adam (Kingma & Ba, 2014), etc. These algorithms aim to adapt the learning rate to different parameters automatically, by normalizing the global learning rate based on historical statistics of the gradient w.r.t. each parameter. Although these algorithms can usually simplify learning rate settings and lead to faster convergence, it has been observed that their generalization performance tends to be significantly worse than that of SGD in some scenarios (Wilson et al., 2017). This intriguing phenomenon may explain why SGD (possibly with momentum) is still prevalent in training state-of-the-art deep models, especially feedforward DNNs (Szegedy et al., 2015; He et al., 2016; Hu et al., 2017). Furthermore, recent work has shown that DNNs are capable of fitting noise data (Zhang et al., 2017), suggesting that their generalization capability is not a property of the network architecture alone, but is entwined with the optimization process (Arpit et al., 2017).

This work aims to bridge the gap between SGD and its adaptive variants. To this end, we identify two problems of Adam that may degrade the generalization performance, and show how these problems are (partially) avoided by using SGD with L2 weight decay. First, the directions of Adam parameter updates differ from those of SGD, i.e., Adam does not preserve the directions of gradients as SGD does. This difference has been discussed in recent literature (Wilson et al., 2017), where the authors show that adaptive methods can find drastically different solutions than SGD in some cases. Second, while the magnitudes of Adam parameter updates are invariant to rescaling of the gradient, the effect of the updates on the same overall network function still varies with the magnitudes of the parameters. As we show, however, this problem can be partially avoided by using SGD with L2 weight decay, which implicitly normalizes the weight vectors, such that the magnitude of each vector's direction change does not depend on its L2-norm.

Next, we propose the normalized direction-preserving Adam (ND-Adam) algorithm, which preserves the direction of the gradient w.r.t. each weight vector, and incorporates a special form of weight normalization (Salimans & Kingma, 2016). By using ND-Adam, we are able to achieve significantly better generalization performance than vanilla Adam, and at the same time, obtain much lower training loss at convergence, compared to SGD with L2 weight decay.

Furthermore, we find that, without proper control, the learning signal backpropagated from the softmax layer varies with the overall magnitude of the logits. Based on this observation, we apply batch normalization to the logits with a single tunable scaling factor, which further improves the generalization performance in classification tasks.

In essence, our proposed methods, ND-Adam and batch-normalized softmax, enable more precise control over the directions of parameter updates, the learning rates, and the learning signals.

2 Background and Motivation

2.1 Adaptive Moment Estimation (Adam)

Adaptive moment estimation (Adam) (Kingma & Ba, 2014) is a stochastic optimization method that applies individual adaptive learning rates to different parameters, based on estimates of the first and second moments of the gradients. Specifically, for trainable parameters θ, Adam maintains a running average of the first and second moments of the gradient w.r.t. each parameter as

m_t = β_1 m_{t-1} + (1 - β_1) g_t,   (1a)
v_t = β_2 v_{t-1} + (1 - β_2) g_t^2.   (1b)

Here, t denotes the time step, g_t denotes the gradient at step t, m_t and v_t denote respectively the first and second moment estimates, and β_1 and β_2 are the corresponding decay factors. Kingma & Ba (2014) further notice that, since m_t and v_t are initialized to 0's, they are biased towards zero during the initial time steps, especially when the decay factors are large (i.e., close to 1). Thus, for computing the next update, they need to be corrected as

m̂_t = m_t / (1 - β_1^t),   v̂_t = v_t / (1 - β_2^t),   (2)

where β_1^t and β_2^t are the t-th powers of β_1 and β_2, respectively. Then, each parameter is updated as

θ_t = θ_{t-1} - α_t · m̂_t / (√v̂_t + ε),   (3)

where α_t is the global learning rate, and ε is a small constant to avoid division by zero. Note that the above computations between vectors are element-wise.
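For concreteness, the update rule of Eqs. (1a)-(3) can be written in a few lines of NumPy. This is a minimal sketch with variable names and default hyperparameters of our own choosing, not an excerpt from any particular library.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a parameter array theta; all operations are element-wise."""
    m = beta1 * m + (1 - beta1) * grad             # Eq. (1a): first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2        # Eq. (1b): second-moment estimate
    m_hat = m / (1 - beta1 ** t)                   # Eq. (2): bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)  # Eq. (3)
    return theta, m, v
```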

A distinguishing merit of Adam is that the magnitudes of parameter updates are invariant to rescaling of the gradient, as shown by the adaptive learning rate term, α_t / (√v̂_t + ε). However, there are two potential problems when applying Adam to DNNs.

First, in some scenarios, DNNs trained with Adam generalize worse than those trained with stochastic gradient descent (SGD) (Wilson et al., 2017). Zhang et al. (2017) demonstrate that over-parameterized DNNs are capable of memorizing the entire dataset, no matter whether it is natural data or meaningless noise data, and thus suggest that much of the generalization power of DNNs comes from the training algorithm, e.g., SGD and its variants. This coincides with another recent work (Wilson et al., 2017), which shows that simple SGD often yields better generalization performance than adaptive gradient methods, such as Adam. As pointed out by the latter, the difference in generalization performance may result from the different directions of the updates. Specifically, for each hidden unit, the SGD update of its input weight vector can only lie in the span of all possible input vectors, which, however, is not the case for Adam, due to the individually adapted learning rates. We refer to this problem as the direction missing problem.

Second, while batch normalization (Ioffe & Szegedy, 2015) can significantly accelerate the convergence of DNNs, the input weights and the scaling factor of each hidden unit can be scaled in infinitely many (but consistent) ways, without changing the function implemented by the hidden unit. Thus, for different magnitudes of an input weight vector, the updates given by Adam can have different effects on the overall network function, which is undesirable. Furthermore, even when batch normalization is not used, a network using linear rectifiers (e.g., ReLU, leaky ReLU) as activation functions is still subject to ill-conditioning of the parameterization (Glorot et al., 2011), and hence suffers from the same problem. We refer to this problem as the ill-conditioning problem.

2.2 L2 Weight Decay

L2 weight decay is a regularization technique frequently used with SGD. It often has a significant effect on the generalization performance of DNNs. Despite its simplicity and crucial role in the training process, it remains to be explained how L2 weight decay works in DNNs. A common justification is that L2 weight decay can be introduced by placing a Gaussian prior upon the weights, when the objective is to find the maximum a posteriori (MAP) weights (Blundell et al., 2015). However, as discussed in Sec. 2.1, the magnitudes of input weight vectors are irrelevant to the overall network function in some common scenarios, rendering the variance of the Gaussian prior meaningless.

We propose to view L2 weight decay in neural networks as a form of weight normalization, which may better explain its effect on the generalization performance. Consider a neural network trained with the following loss function:

L̃(θ) = L(θ; B) + (λ/2) Σ_{i∈N} ‖w_i‖_2^2,   (4)

where L(θ; B) is the original loss function specified by the task, B is a batch of training data, N is the set of all hidden units, and w_i denotes the input weight vector of hidden unit i, which is included in the trainable parameters, θ. For simplicity, we consider SGD updates without momentum. The update of w_i at each time step is therefore

Δw_i = -α (∂L/∂w_i + λ w_i),   (5)

where α is the step size. As we can see from Eq. (5), the gradient magnitude of the L2 penalty is proportional to ‖w_i‖_2, and thus forms a negative feedback loop that stabilizes ‖w_i‖_2 to an equilibrium value. Empirically, we find that ‖w_i‖_2 tends to increase or decrease dramatically at the beginning of training, and then varies mildly within a small range, which indicates ‖w_i + Δw_i‖_2 ≈ ‖w_i‖_2. In practice, we usually have αλ ≪ 1, thus Δw_i is approximately orthogonal to w_i, i.e., Δw_i · w_i ≈ 0.

Let Δw_i^∥ and Δw_i^⊥ be the vector projection and rejection of Δw_i on w_i, which are defined as

Δw_i^∥ = ((Δw_i · w_i) / (w_i · w_i)) w_i,   Δw_i^⊥ = Δw_i - Δw_i^∥.   (6)

From Eq. (5) and (6), together with the near-constancy of ‖w_i‖_2 at equilibrium, it is easy to show that

‖Δw_i^⊥‖_2 / ‖w_i‖_2 ≈ √(2αλ).   (7)

As discussed in Sec. 2.1, when batch normalization is used, or when linear rectifiers are used as activation functions, the magnitude of w_i is irrelevant; it is the direction of w_i that actually makes a difference in the overall network function. If L2 weight decay is not applied, the magnitude of w_i's direction change will decrease as ‖w_i‖_2 increases during the training process, which can potentially lead to overfitting (discussed in detail in Sec. 3.2). On the other hand, Eq. (7) shows that L2 weight decay implicitly normalizes the weights, such that the magnitude of w_i's direction change does not depend on ‖w_i‖_2, and can be tuned by the product of α and λ.
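The relation in Eq. (7) can be checked numerically. The following sketch uses a toy scale-invariant loss that we construct purely for illustration (the gradient of L(w) = -(w·x)/‖w‖_2 is orthogonal to w by construction); once ‖w‖_2 reaches its equilibrium, the measured relative direction change matches √(2αλ).

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, lam, d = 0.1, 0.01, 64
w = rng.normal(size=d)

ratios = []
for step in range(20000):
    x = rng.normal(size=d)
    x /= np.linalg.norm(x)
    w_hat = w / np.linalg.norm(w)
    # Gradient of the toy scale-invariant loss L(w) = -(w . x) / ||w||_2,
    # which is exactly orthogonal to w.
    g = -(x - np.dot(w_hat, x) * w_hat) / np.linalg.norm(w)
    dw = -alpha * (g + lam * w)                          # Eq. (5)
    dw_perp = dw - (np.dot(dw, w) / np.dot(w, w)) * w    # Eq. (6)
    if step > 10000:                                     # measure after ||w||_2 equilibrates
        ratios.append(np.linalg.norm(dw_perp) / np.linalg.norm(w))
    w = w + dw

print(np.mean(ratios), np.sqrt(2 * alpha * lam))         # both are approximately 0.045
```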

While L2 weight decay produces the normalization effect in an implicit and approximate way, we will show that explicitly doing so can result in further improved optimization and generalization performance.

3 Normalized Direction-preserving Adam

We first present the normalized direction-preserving Adam (ND-Adam) algorithm, which essentially improves the optimization of the input weights of hidden units, while employing the vanilla Adam algorithm to update other parameters. Specifically, we divide the trainable parameters, θ, into two sets, θ^v and θ^s, such that θ^v = {w_i | i ∈ N}, and θ^s = θ ∖ θ^v. Then we update θ^v and θ^s by different rules, as described by Alg. 1. The learning rates for the two sets of parameters are denoted respectively by α_t^v and α_t^s.

/* Initialization */
t ← 0;
for i ∈ N do
       w_i^0 ← w_i^0 / ‖w_i^0‖_2;
       m_0(w_i) ← 0;
       v_0(w_i) ← 0;

/* Perform T iterations of training */
while t < T do
       t ← t + 1;
       /* Update θ^v */
       for i ∈ N do
              g_t(w_i) ← ∇_{w_i} L_t;
              ḡ_t(w_i) ← g_t(w_i) - (g_t(w_i) · w_i^{t-1}) w_i^{t-1};
              m_t(w_i) ← β_1 m_{t-1}(w_i) + (1 - β_1) ḡ_t(w_i);
              v_t(w_i) ← β_2 v_{t-1}(w_i) + (1 - β_2) ‖ḡ_t(w_i)‖_2^2;
              m̂_t(w_i) ← m_t(w_i) / (1 - β_1^t);
              v̂_t(w_i) ← v_t(w_i) / (1 - β_2^t);
              w̄_i^t ← w_i^{t-1} - α_t^v · m̂_t(w_i) / (√v̂_t(w_i) + ε);
              w_i^t ← w̄_i^t / ‖w̄_i^t‖_2;
       /* Update θ^s using Adam */
       θ_t^s ← AdamUpdate(θ_{t-1}^s; α_t^s, β_1, β_2);
return θ_T;
Algorithm 1 Normalized direction-preserving Adam

In Alg. 1, the iteration over N can be performed in parallel, and thus introduces no extra computational complexity. Compared to Adam, computing ḡ_t(w_i) and normalizing w̄_i^t may take slightly more time, which, however, is negligible in practice. On the other hand, to estimate the second-order moment of each d-dimensional weight vector w_i, Adam maintains d scalars, whereas ND-Adam requires only one scalar, v_t(w_i). Thus, ND-Adam has a smaller memory overhead than Adam.
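The per-vector update of Alg. 1 (as reconstructed above) is compact enough to sketch directly. The following NumPy function is our own illustrative implementation, with hypothetical names and an arbitrary default for α_t^v; it covers only the θ^v branch, since θ^s is handled by vanilla Adam.

```python
import numpy as np

def nd_adam_step(w, grad, m, v, t, alpha_v=0.05, beta1=0.9, beta2=0.999, eps=1e-8):
    """One ND-Adam update for a single input weight vector w with ||w||_2 = 1.
    Other (scalar) parameters are updated with vanilla Adam, as in Alg. 1."""
    g_bar = grad - np.dot(grad, w) * w                    # Eq. (12): project the gradient onto the sphere
    m = beta1 * m + (1 - beta1) * g_bar                   # Eq. (8): first moment (a vector)
    v = beta2 * v + (1 - beta2) * np.dot(g_bar, g_bar)    # Eq. (9): second moment (a single scalar)
    m_hat = m / (1 - beta1 ** t)                          # Eq. (10): bias correction
    v_hat = v / (1 - beta2 ** t)
    w_bar = w - alpha_v * m_hat / (np.sqrt(v_hat) + eps)  # Eq. (13a)
    w = w_bar / np.linalg.norm(w_bar)                     # Eq. (13b): renormalize to the unit sphere
    return w, m, v
```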

In the following, we address the direction missing problem and the ill-conditioning problem discussed in Sec. 2.1, and explain Alg. 1 in detail. We show how the proposed algorithm jointly solves the two problems, as well as its relation to other normalization schemes.

3.1 Preserving Gradient Directions

Assuming the stationarity of a hidden unit's input distribution, the SGD update (possibly with momentum) of the input weight vector is a linear combination of historical gradients; since the gradient w.r.t. an input weight vector is always a scalar multiple of the corresponding input vector, the update can only lie in the span of the input vectors. As a result, the input weight vector itself will eventually converge to the same subspace.

On the contrary, the Adam algorithm adapts the global learning rate to each scalar parameter independently, such that the gradient of each parameter is normalized by a running average of its magnitudes, which changes the direction of the gradient. To preserve the direction of the gradient w.r.t. each input weight vector, we generalize the learning rate adaptation scheme from scalars to vectors.
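Before detailing the vector-wise scheme, the contrast between the two update directions is easy to see numerically. In the sketch below (a toy example of our own), the gradient w.r.t. an input weight vector is a linear combination of the minibatch inputs and hence lies in their span, whereas a per-coordinate rescaling of the kind Adam applies, approximated here by its first-step behavior, generally leaves that span.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 32
X = rng.normal(size=(n, d))          # a minibatch of n input vectors
delta = rng.normal(size=n)           # backpropagated errors w.r.t. the pre-activations

grad = X.T @ delta                   # gradient w.r.t. the input weight vector: sum_j delta_j * x_j

def norm_outside_span(v, X):
    """Norm of the component of v that lies outside the span of the rows of X."""
    coeffs, *_ = np.linalg.lstsq(X.T, v, rcond=None)
    return np.linalg.norm(v - X.T @ coeffs)

adam_like = np.sign(grad)            # Adam's first update is approximately sign(grad)

print(norm_outside_span(grad, X))       # ~1e-14: the SGD gradient stays in the span of the inputs
print(norm_outside_span(adam_like, X))  # clearly nonzero: element-wise rescaling leaves the span
```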

Let g_t(w_i), m_t(w_i), and v_t(w_i) be the counterparts of g_t, m_t, and v_t for the vector w_i. Since Eq. (1a) is a linear combination of historical gradients, it can be extended to vectors without any change; or equivalently, we can rewrite it for each vector as

m_t(w_i) = β_1 m_{t-1}(w_i) + (1 - β_1) g_t(w_i).   (8)

We then extend Eq. (1b) as

v_t(w_i) = β_2 v_{t-1}(w_i) + (1 - β_2) ‖g_t(w_i)‖_2^2,   (9)

i.e., instead of estimating the average gradient magnitude for each individual parameter, we estimate the average of ‖g_t(w_i)‖_2^2 for each vector w_i. In addition, we modify Eq. (2) and (3) accordingly as

m̂_t(w_i) = m_t(w_i) / (1 - β_1^t),   v̂_t(w_i) = v_t(w_i) / (1 - β_2^t),   (10)

and

w_i^t = w_i^{t-1} - α_t^v · m̂_t(w_i) / (√v̂_t(w_i) + ε).   (11)

Here, m̂_t(w_i) is a vector of the same dimension as w_i, whereas v̂_t(w_i) is a scalar. Therefore, when applying Eq. (11), the direction of the update is the negative direction of m̂_t(w_i), and thus lies in the span of the historical gradients of w_i.

It is worth noting that only the input to the first layer (i.e., the training data) is stationary throughout training. Thus, for the weights of an upper layer to converge to the span of its input vectors, it is necessary for the lower layers to converge first. Interestingly, this predicted phenomenon may have been observed in practice (Brock et al., 2017).

Despite the empirical success of SGD, a question remains as to why it is desirable to constrain the input weights to the span of the input vectors. A possible explanation is related to the manifold hypothesis, which suggests that real-world data presented in high-dimensional spaces (images, audio, text, etc.) concentrates on manifolds of much lower dimensionality (Cayton, 2005; Narayanan & Mitter, 2010). In fact, commonly used activation functions, such as (leaky) ReLU, sigmoid, and tanh, can only be activated (i.e., neither saturating nor having small gradients) by a portion of the input vectors, in whose span the input weights lie upon convergence. Assuming the local linearity of the manifolds of data or hidden-layer representations, constraining the input weights to the subspace that contains some of the input vectors encourages the hidden units to form local coordinate systems on the corresponding manifold, which can lead to good representations (Rifai et al., 2011).

3.2 Spherical Weight Optimization

The ill-conditioning problem occurs when the magnitude change of an input weight vector can be compensated by other parameters, such as the scaling factor of batch normalization, or the output weight vector, without affecting the overall network function. Consequently, if two DNNs parameterize the same function, but some of their input weight vectors have different magnitudes, applying the same SGD or Adam update rule will, in general, change the two network functions in different ways. Thus, the ill-conditioning problem makes the training process more opaque and difficult to control.

More importantly, when the weights are not properly regularized (e.g., when L2 weight decay is not used), the magnitude of w_i's direction change will decrease as ‖w_i‖_2 increases during the training process. As a result, the "effective" learning rate for w_i tends to decrease faster than expected, making the network converge to sharp minima (Hoffer et al., 2017). It is well known that sharp minima tend to generalize worse than flat minima (Hochreiter & Schmidhuber, 1997; Keskar et al., 2017).

As shown in Sec. 2.2, L2 weight decay can alleviate the ill-conditioning problem by implicitly and approximately normalizing the weights. However, we still do not have precise control over ‖Δw_i^⊥‖_2 / ‖w_i‖_2, since the equilibrium value of ‖w_i‖_2 is unknown and not necessarily stable. Moreover, the approximation fails when ‖w_i‖_2 is far from its equilibrium value.

To address the ill-conditioning problem in a more principled way, we restrict the L2-norm of each w_i to 1, and only optimize its direction. In other words, instead of optimizing w_i in a d-dimensional space, we optimize w_i on a (d-1)-dimensional unit sphere. Specifically, we first obtain the raw gradient w.r.t. w_i, g_t(w_i) = ∇_{w_i} L_t, and project the gradient onto the unit sphere as

ḡ_t(w_i) = g_t(w_i) - (g_t(w_i) · w_i^{t-1}) w_i^{t-1}.   (12)

Here, ‖w_i^{t-1}‖_2 = 1. Then we follow Eq. (8)-(10), with g_t(w_i) replaced by ḡ_t(w_i), and replace Eq. (11) with

w̄_i^t = w_i^{t-1} - α_t^v · m̂_t(w_i) / (√v̂_t(w_i) + ε),   (13a)
and
w_i^t = w̄_i^t / ‖w̄_i^t‖_2.   (13b)

In Eq. (12), we keep only the gradient component that is orthogonal to w_i^{t-1}. However, m̂_t(w_i), a running average of such components, is not necessarily orthogonal to w_i^{t-1} as well. In addition, even when m̂_t(w_i) is orthogonal to w_i^{t-1}, Eq. (13a) can still increase ‖w_i‖_2, according to the Pythagorean theorem. Therefore, we explicitly normalize w̄_i^t in Eq. (13b), to ensure ‖w_i^t‖_2 = 1 after each update.

We now have

‖Δw_i^t‖_2 = ‖w_i^t - w_i^{t-1}‖_2 ≈ α_t^v ‖m̂_t(w_i)‖_2 / (√v̂_t(w_i) + ε),   (14)

thus we can control the learning rate for w_i through a single hyperparameter, α_t^v. Note that it is possible to control ‖Δw_i^t‖_2 more precisely, by normalizing m̂_t(w_i) by its L2-norm, instead of by √v̂_t(w_i). However, by doing so, we lose the information provided by the gradient magnitudes at different time steps. In addition, since m̂_t(w_i) is less noisy than ḡ_t(w_i), ‖m̂_t(w_i)‖_2 / √v̂_t(w_i) becomes small near convergence, which is considered a desirable property of Adam (Kingma & Ba, 2014). Thus, we keep the gradient normalization scheme intact.

Compared to L2 weight decay, spherical weight optimization explicitly normalizes the weight vectors, such that each update to a weight vector only changes its direction, and strictly keeps its magnitude constant. For nonlinear activation functions, such as sigmoid and tanh, an extra scaling factor is needed for each hidden unit to express functions that would otherwise require unnormalized weight vectors. For instance, given an input vector x, and a nonlinearity φ, the activation of hidden unit i is then given by

y_i = φ(γ_i (w_i · x) + b_i),   (15)

where γ_i is the scaling factor, and b_i is the bias.

3.3 Relation to Weight Normalization and Batch Normalization

A related normalization and reparameterization scheme, weight normalization (Salimans & Kingma, 2016), has been developed as an alternative to batch normalization, aiming to accelerate the convergence of SGD optimization. We note the differences between spherical weight optimization and weight normalization. First, the weight vector of each hidden unit is not directly normalized in weight normalization, i.e., ‖w_i‖_2 ≠ 1 in general. At training time, the activation of hidden unit i is

y_i = φ(γ_i (w_i · x) / ‖w_i‖_2 + b_i),   (16)

which is equivalent to Eq. (15) for the forward pass. For the backward pass, however, the gradient w.r.t. w_i still depends on ‖w_i‖_2 in weight normalization, hence it does not solve the ill-conditioning problem. At inference time, both schemes can combine γ_i and w_i into a single equivalent weight vector, w̃_i = γ_i w_i, or w̃_i = γ_i w_i / ‖w_i‖_2.

While spherical weight optimization naturally encompasses weight normalization, it can further benefit from batch normalization. When combined with batch normalization, Eq. (15) evolves into

y_i = φ(γ_i · BN(w_i · x) + b_i),   (17)

where BN(·) represents the transformation done by batch normalization without scaling and shifting. Here, γ_i serves as the scaling factor for both the normalized weight vector and batch normalization. At training time, the distribution of the input vector, x, changes over time, slowing down the training of the sub-network composed of the upper layers. Salimans & Kingma (2016) observe that such a problem cannot be eliminated by normalizing the weight vectors alone, but can be substantially mitigated by combining weight normalization with mean-only batch normalization.
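A forward pass implementing Eq. (17) might look as follows. This is a sketch under our own naming, using ReLU and per-unit batch statistics; it is not a drop-in replacement for any framework layer.

```python
import numpy as np

def bn_no_affine(a, eps=1e-5):
    """Batch normalization without scaling and shifting: standardize each unit over the batch."""
    return (a - a.mean(axis=0)) / (a.std(axis=0) + eps)

def spherical_bn_layer(X, W, gamma, b):
    """Eq. (17): unit-norm input weight vectors combined with batch normalization.
    X: (batch, d_in); W: (d_in, d_out), one column per hidden unit; gamma, b: (d_out,)."""
    W = W / np.linalg.norm(W, axis=0, keepdims=True)   # keep each w_i on the unit sphere
    pre = bn_no_affine(X @ W)                          # BN(w_i . x), no per-unit affine parameters
    return np.maximum(0.0, gamma * pre + b)            # y_i = ReLU(gamma_i * BN(w_i . x) + b_i)
```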

Additionally, in linear rectifier networks, the scaling factors, γ_i, can be removed (or set to 1), without changing the overall network function. Since w_i · x is standardized by batch normalization, we have

E[BN(w_i · x)] = 0,   Var[BN(w_i · x)] = 1,   (18)

and hence

Var[y_i] = Var[φ(BN(w_i · x) + b_i)],   (19)

which depends only on the bias b_i and the nonlinearity φ. Therefore, the y_i's that belong to the same layer, i.e., the different dimensions of the input fed to the upper layer, will also have comparable variances, which potentially makes the weight updates of the upper layer more stable. For these reasons, we combine the use of spherical weight optimization and batch normalization, as shown in Eq. (17).

4 Batch-normalized Softmax

For multi-class classification tasks, the softmax function is the de facto activation function for the output layer. Despite its simplicity and intuitive probabilistic interpretation, the learning signal it backpropagates may not always be desirable.

When using cross entropy as the surrogate loss with one-hot target vectors, the prediction is considered correct as long as argmax_i z_i is the target class, where z_i is the logit before the softmax activation, corresponding to category i. Thus, the logits can be positively scaled together without changing the predictions, even though the cross entropy and its derivatives will vary with the scaling factor. Specifically, denoting the scaling factor by η, and the cross-entropy loss by E, the gradient w.r.t. each logit is

∂E/∂z_t = η (p_t - 1),   (20a)
and
∂E/∂z_i = η p_i,  for i ≠ t,   (20b)

where t is the target class, and p_i = exp(η z_i) / Σ_{j=1}^{C} exp(η z_j), with C denoting the number of classes.

For Adam and ND-Adam, since the gradient w.r.t. each scalar or vector is normalized, the absolute magnitudes of Eq. (20a) and (20b) are irrelevant; instead, it is their relative magnitudes that make a difference. When η is small, we have

∂E/∂z_t ≈ -η (C - 1)/C,   ∂E/∂z_i ≈ η/C  for i ≠ t,   (21)

which indicates that, when the magnitude of the logits is small, softmax encourages the logit of the target class to increase, while penalizing the logits of all other classes equally. On the other end of the spectrum, assuming no two logits are equal, we have

lim_{η→∞} (∂E/∂z_t) / (∂E/∂z_m) = -1,   lim_{η→∞} (∂E/∂z_i) / (∂E/∂z_m) = 0  for i ∉ {t, m},   (22)

where m = argmax_{i≠t} z_i, and z_t is assumed to be the largest logit. Eq. (22) indicates that, when the magnitude of the logits is large, softmax penalizes only the largest logit of the non-target classes. The latter case is also referred to as the saturation problem of softmax in the literature (Oland et al., 2017).
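The two regimes of Eq. (21) and (22) follow directly from Eq. (20a)-(20b); the snippet below (with example values of our own) prints the gradients for a small and a moderately large η, normalized to show only their relative magnitudes.

```python
import numpy as np

def grad_wrt_logits(z, target, eta):
    """dE/dz for softmax cross entropy on the scaled logits eta * z (Eqs. (20a)-(20b))."""
    p = np.exp(eta * z - np.max(eta * z))
    p /= p.sum()
    g = eta * p
    g[target] -= eta
    return g

z = np.array([2.0, 1.5, 0.3, -1.0])     # distinct logits; class 0 is the target class t
for eta in (0.01, 10.0):
    g = grad_wrt_logits(z, target=0, eta=eta)
    print(eta, g / np.abs(g).max())      # relative magnitudes only
# Small eta: the non-target gradients are nearly equal, as in Eq. (21).
# Large eta: only the target and the largest non-target logit (class 1) retain
# non-negligible gradients, as in Eq. (22).
```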

It is worth noting that both of these two cases can occur without an explicit scaling factor. For instance, varying the norm of the weights of the softmax layer is equivalent to varying the value of η, in terms of the relative magnitudes of the gradients. In the case of small η, the logits of all non-target classes are penalized equally, regardless of how close each of them is to the target logit. However, it is more reasonable to penalize more heavily the logits that are closer to the target logit, since they are more likely to cause misclassification. In the case of large η, although the logit that is most likely to cause misclassification is strongly penalized, the logits of the other non-target classes are ignored. As a result, the logits of the non-target classes tend to be similar at convergence, ignoring the fact that some classes are closer to each other than others.

To exploit the prior knowledge that the magnitude of the logits should be neither too small nor too large, we apply batch normalization to the logits. However, instead of setting the scaling factors of batch normalization as trainable variables, we replace them with a single hyperparameter, γ_L, shared by all logits. Tuning the value of γ_L can lead to a better trade-off between the two cases described by Eq. (21) and (22). We refer to this method as batch-normalized softmax (BN-Softmax).
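A possible implementation of BN-Softmax, as we read the description above, is sketched below; the function name, the per-class batch statistics, and the ε value are our assumptions.

```python
import numpy as np

def bn_softmax_loss(z, targets, gamma_L, eps=1e-5):
    """Batch-normalize the logits (no trainable affine parameters), scale them by a single
    hyperparameter gamma_L, then apply the usual softmax cross-entropy loss.
    z: (batch, C) raw logits; targets: (batch,) integer class labels."""
    z_hat = (z - z.mean(axis=0)) / (z.std(axis=0) + eps)      # BN over the batch, per class
    logits = gamma_L * z_hat                                  # single shared scaling factor
    logits = logits - logits.max(axis=1, keepdims=True)       # for numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(targets)), targets].mean()
```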

5 Experiments

In this section, we provide empirical evidence for the analysis in Sec. 2.2, and evaluate the performance of ND-Adam and BN-Softmax on CIFAR-10 and CIFAR-100.

5.1 The Effect of L2 Weight Decay

To empirically examine the effect of L2 weight decay, we train a wide residual network (WRN) (Zagoruyko & Komodakis, 2016); using the notation of Zagoruyko & Komodakis (2016), we refer to a network of n layers, with a width k times that of a vanilla ResNet, as WRN-n-k. We train the network on the CIFAR-10 dataset (Krizhevsky & Hinton, 2009), with a small modification to the original WRN architecture, and with a different learning rate annealing schedule. Specifically, for simplicity and slightly better performance, we replace the last fully connected layer with a convolutional layer whose number of output feature maps equals the number of classes, i.e., we change the layers after the last residual block from BN-ReLU-GlobalAvgPool-FC-Softmax to BN-ReLU-Conv-GlobalAvgPool-Softmax. In addition, for clearer comparisons, the learning rate is annealed according to a cosine function without restarts (Loshchilov & Hutter, 2016; Gastaldi, 2017).

As a common practice, we use SGD with momentum, for which the analysis is similar to that in Sec. 2.2. Due to the linearity of derivatives and momentum, Δw_i can be decomposed as Δw_i = Δw_i^L + Δw_i^reg, where Δw_i^L and Δw_i^reg are the components corresponding to the original loss function, L, and the L2 penalty term (see Eq. (4)), respectively. Fig. 1(a) shows the ratio between the scalar projection of Δw_i^L on w_i and ‖Δw_i^reg‖_2, which indicates how the tendency of ‖w_i‖_2 to increase is compensated by the weight decay term. Note that Δw_i^reg points in the negative direction of w_i, even when momentum is used, since the direction of w_i changes slowly. As shown in Fig. 1(a), at the beginning of training, Δw_i^reg dominates, and ‖w_i‖_2 quickly adjusts to its equilibrium value. During the middle stage of training, the projection of Δw_i^L on w_i and Δw_i^reg almost cancel each other out. Then, near the end of training, the gradient of L diminishes rapidly to near zero, making Δw_i^reg dominant again. Therefore, Eq. (7) holds more accurately during the middle stage of training.

In Fig. 1(b), we show how the relative magnitude of the weight updates, ‖Δw_i^⊥‖_2 / ‖w_i‖_2, varies under different hyperparameter settings. By Eq. (7), it is expected to remain the same as long as the product αλ stays constant, which is confirmed by the fact that the curves for two SGD settings sharing the same value of αλ overlap. However, comparing curves with different values of αλ, we can see that the relative magnitude does not change proportionally to αλ. On the other hand, by using ND-Adam, we can control the relative magnitude of the weight updates more precisely by adjusting the learning rate for weight vectors, α_t^v. For the same training step, changes in α_t^v lead to approximately proportional changes in ‖Δw_i^t‖_2, as shown by the two curves corresponding to ND-Adam in Fig. 1(b).

(a) The scalar projection of Δw_i^L on w_i, normalized by ‖Δw_i^reg‖_2.
(b) The relative magnitude of the weight updates.
Figure 1: An illustration of how L2 weight decay and ND-Adam control the relative magnitude of the weight updates. The results are obtained from one layer of the network; other layers show similar results.

5.2 Performance Evaluation

To compare the optimization and generalization performance of SGD, Adam, and ND-Adam, we train the same WRN on the CIFAR-10 and CIFAR-100 datasets. For SGD and ND-Adam, we first tune the hyperparameters of SGD (the learning rate, momentum, and L2 weight decay coefficient), and then tune the initial learning rate of ND-Adam for weight vectors, α_0^v, to match the relative magnitude of the weight updates to that of SGD, as shown in Fig. 1(b). While L2 weight decay can greatly affect the performance of SGD, it does not noticeably benefit Adam in our experiments. For Adam and ND-Adam, β_1 and β_2 are set to the default values of Adam, i.e., β_1 = 0.9 and β_2 = 0.999. Although the learning rate of Adam is usually set to a constant value, we observe better performance with the cosine decay scheme. The initial learning rate of Adam, and that of ND-Adam for scalar parameters, α_0^s, are tuned to the same value. We use the same data augmentation scheme as Zagoruyko & Komodakis (2016), including horizontal flips and random crops, but no dropout is used.

We first experiment with trainable scaling parameters (γ) of batch normalization. As shown in Fig. 2(a), ND-Adam converges to training losses comparable to those of Adam, which are much lower than those of SGD. More importantly, as shown in Fig. 2(b), the test accuracies of ND-Adam are significantly improved upon vanilla Adam, and match those of SGD. The average results of 3 runs are summarized in the first part of Table 1. Interestingly, compared to SGD, ND-Adam shows slightly better performance on CIFAR-10, but worse performance on CIFAR-100. This inconsistency may be related to the problem of softmax discussed in Sec. 4, i.e., the lack of proper control over the magnitude of the logits.

(a) The training losses on CIFAR-10/100.
(b) The test accuracies on CIFAR-10/100.
Figure 2: The training losses and test accuracies of the same network trained with SGD, Adam, and ND-Adam. Batch normalization with scaling factors is used.

Next, we repeat the experiments with BN-Softmax. As discussed in Sec. 3.3, the scaling factors γ_i can be removed from a linear rectifier network without changing the overall network function. Although this property does not strictly hold for residual networks, due to the skip connections, we find that simply removing the scaling factors results in slightly improved generalization performance when using ND-Adam. However, the improvement is not consistent, as it degrades the performance of SGD. Interestingly, when BN-Softmax is used, we observe consistent improvement for all three algorithms. Thus, we only report results for this setting.

The scaling factor of the logits, γ_L, is tuned separately for CIFAR-10 and CIFAR-100. As shown in Fig. 3(a), the training losses of Adam and ND-Adam are again much lower than that of SGD, although they are increased due to the regularization effect of BN-Softmax. As shown in the second part of Table 1, BN-Softmax significantly improves the performance of Adam and ND-Adam. Moreover, in this setting, we obtain the best generalization performance with ND-Adam, outperforming SGD and Adam on both CIFAR-10 and CIFAR-100.

(a) The training losses on CIFAR-10/100.
(b) The test accuracies on CIFAR-10/100.
Figure 3: The training losses and test accuracies of the same network trained with SGD, Adam, and ND-Adam. Batch normalization without scaling factors, and BN-Softmax are used.
Method CIFAR-10 Error (%) CIFAR-100 Error (%)
BN w/ scaling factors
SGD 4.61 20.60
Adam 6.14 25.51
ND-Adam 4.53 21.45
BN w/o scaling factors, BN-Softmax
SGD 4.49 20.18
Adam 5.43 22.48
ND-Adam 4.14 19.90
Table 1: Test error rates on CIFAR-10 and CIFAR-100.

6 Conclusion

In this paper, we introduced the normalized direction-preserving Adam algorithm, which is a version of Adam tailored for training DNNs. We showed that ND-Adam implements the normalization effect of L2 weight decay in a more principled way, and that it combines the good optimization performance of Adam with the good generalization performance of SGD. In addition, we introduced batch-normalized softmax, which regularizes the logits before the softmax activation, in order to provide better learning signals. We showed significantly improved generalization performance by combining ND-Adam and BN-Softmax. From a high-level view, our proposed methods and empirical results suggest the need for more precise control over the training process of DNNs.

References