Averaging Weights Leads to Wider Optima and Better Generalization

03/14/2018 ∙ by Pavel Izmailov, et al.

Deep neural networks are typically trained by optimizing a loss function with an SGD variant, in conjunction with a decaying learning rate, until convergence. We show that simple averaging of multiple points along the trajectory of SGD, with a cyclical or constant learning rate, leads to better generalization than conventional training. We also show that this Stochastic Weight Averaging (SWA) procedure finds much broader optima than SGD, and approximates the recent Fast Geometric Ensembling (FGE) approach with a single model. Using SWA we achieve notable improvement in test accuracy over conventional SGD training on a range of state-of-the-art residual networks, PyramidNets, DenseNets, and Shake-Shake networks on CIFAR-10, CIFAR-100, and ImageNet. In short, SWA is extremely easy to implement, improves generalization, and has almost no computational overhead.


1 Introduction

With a better understanding of the loss surfaces for multilayer networks, we can improve the convergence speed, stability, and accuracy of training procedures in deep learning. Recent work (Garipov et al., 2018; Draxler et al., 2018) shows that local optima found by SGD can be connected by simple curves of near-constant loss. Building upon this insight, Garipov et al. (2018) also developed Fast Geometric Ensembling (FGE) to sample multiple nearby points in weight space to create high-performing ensembles in the time required to train a single DNN.

Figure 1: Illustrations of SWA and SGD with a Preactivation ResNet- on CIFAR-. Left: test error surface for three FGE samples and the corresponding SWA solution (averaging in weight space). Middle and Right: test error and train loss surfaces showing the weights proposed by SGD (at convergence) and SWA, starting from the same initialization of SGD after 125 training epochs.

To construct the plane used in these plots, suppose we have three weight vectors $w_1, w_2, w_3$. We set $u = w_2 - w_1$ and $v = (w_3 - w_1) - \frac{\langle w_3 - w_1,\, w_2 - w_1 \rangle}{\|w_2 - w_1\|^2}\,(w_2 - w_1)$. Then the normalized vectors $\hat u = u / \|u\|$ and $\hat v = v / \|v\|$ form an orthonormal basis in the plane containing $w_1, w_2, w_3$. To visualize the loss in this plane, we define a Cartesian grid in the basis $\hat u, \hat v$ and evaluate the networks corresponding to each of the points in the grid. A point $P$ with coordinates $(x, y)$ in the plane is then given by $P = w_1 + x \cdot \hat u + y \cdot \hat v$.
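A minimal NumPy sketch of this plane construction, assuming each network's weights are flattened into a single vector: one Gram-Schmidt step turns $w_2 - w_1$ and $w_3 - w_1$ into an orthonormal pair $\hat u, \hat v$, grid coordinates $(x, y)$ map back to weights $w_1 + x\hat u + y\hat v$, and other weight vectors (for example, further points of a trajectory) can be projected onto the plane. The function names and the toy `loss_fn` are illustrative assumptions, not part of the released code.

```python
import numpy as np

def plane_basis(w1, w2, w3):
    """Orthonormal basis (u_hat, v_hat) of the plane through w1, w2, w3."""
    u = w2 - w1
    # Remove the component of (w3 - w1) along u (one Gram-Schmidt step).
    v = (w3 - w1) - np.dot(w3 - w1, u) / np.dot(u, u) * u
    return u / np.linalg.norm(u), v / np.linalg.norm(v)

def plane_point(w1, u_hat, v_hat, x, y):
    """Weights P = w1 + x * u_hat + y * v_hat for plane coordinates (x, y)."""
    return w1 + x * u_hat + y * v_hat

def plane_coords(w, w1, u_hat, v_hat):
    """Coordinates of the orthogonal projection of w onto the plane."""
    d = w - w1
    return float(np.dot(d, u_hat)), float(np.dot(d, v_hat))

def loss_grid(w1, w2, w3, loss_fn, xs, ys):
    """Evaluate loss_fn (a stand-in for evaluating a network) on a Cartesian grid."""
    u_hat, v_hat = plane_basis(w1, w2, w3)
    return np.array([[loss_fn(plane_point(w1, u_hat, v_hat, x, y)) for x in xs]
                     for y in ys])
```

Projected trajectory points get coordinates from `plane_coords`, but only points that actually lie in the plane (such as $w_1, w_2, w_3$) have loss values that can be read off the grid.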

FGE uses a high-frequency cyclical learning rate with SGD to select networks to ensemble. In Figure 1 (left) we see that the weights of the networks ensembled by FGE are on the periphery of the most desirable solutions. This observation suggests it is promising to average these points in weight space, and use a network with these averaged weights, instead of forming an ensemble by averaging the outputs of networks in model space. Although the general idea of maintaining a running average of weights traversed by SGD dates back to Ruppert (1988), this procedure is not typically used to train neural networks. It is sometimes applied as an exponentially decaying running average in combination with a decaying learning rate (where it is called an exponential moving average), which smooths the trajectory of conventional SGD but does not perform very differently. However, we show that an equally weighted average of the points traversed by SGD with a cyclical or high constant learning rate, which we refer to as Stochastic Weight Averaging (SWA), has many surprising and promising features for training deep neural networks, leading to a better understanding of the geometry of their loss surfaces. Indeed, SWA with cyclical or constant learning rates can be used as a drop-in replacement for standard SGD training of multilayer networks, but with improved generalization and essentially no overhead. In particular:

  • We show that SGD with cyclical (e.g., Loshchilov and Hutter, 2017) and constant learning rates traverses regions of weight space corresponding to high-performing networks. We find that while these models are moving around this optimal set they never reach its central points. We show that we can move into this more desirable space of points by averaging the weights proposed over SGD iterations.

  • While FGE ensembles (Garipov et al., 2018) can be trained in the same time as a single model, test predictions for an ensemble of k models require k times more computation. We show that SWA can be interpreted as an approximation to FGE ensembles but with the test-time cost, convenience, and interpretability of a single model.

  • We demonstrate that SWA leads to solutions that are wider than the optima found by SGD. Keskar et al. (2017) and Hochreiter and Schmidhuber (1997) conjecture that the width of the optima is critically related to generalization. We illustrate that the train loss surface is shifted with respect to the test error surface (Figure 1, middle and right panels, and sections 3, 4). We show that SGD generally converges to a point near the boundary of the wide flat region of optimal points. SWA, on the other hand, is able to find a point centered in this region, often with slightly worse train loss but with substantially better test error.

  • We show that the loss function is asymmetric in the direction connecting SWA with SGD. In this direction, SGD is near the periphery of sharp ascent. Part of the reason SWA improves generalization is that it finds solutions in flat regions of the training loss in such directions.

  • SWA achieves notable improvement for training a broad range of architectures over several consequential benchmarks. In particular, running SWA for just epochs on ImageNet we are able to achieve improvement for ResNet- and DenseNet-, and improvement for ResNet-. We achieve improvement of over on CIFAR- and of over on CIFAR- with Preactivation ResNet-, VGG- and Wide ResNet--. We also achieve substantial improvement for the recent Shake-Shake Networks and PyramidNets.

  • SWA is extremely easy to implement and has virtually no computational overhead compared to the conventional training schemes.

  • We provide an implementation of SWA at
    https://github.com/timgaripov/swa.

We emphasize that SWA is finding a solution in the same basin of attraction as SGD, as can be seen in Figure 1, but in a flatter region of the training loss. SGD typically finds points on the periphery of a set of good weights. By running SGD with a cyclical or high constant learning rate, we traverse the surface of this set of points, and by averaging we find a more centered solution in a flatter region of the training loss. Further, the training loss for SWA is often slightly worse than for SGD, suggesting that the SWA solution is not a local optimum of the loss. In the title of this paper, optima is used in a general sense to mean solutions (converged points of a given procedure), rather than different local minima of the same objective.

2 Related Work

This paper is fundamentally about better understanding the geometry of loss surfaces and generalization in deep learning. We follow the trajectory of weights traversed by SGD, leading to new geometric insights and the intuition that SWA will lead to better results than standard training. Empirically, we make the discovery that SWA notably improves training of many state-of-the-art deep neural networks over a range of consequential benchmarks, with essentially no overhead.

The procedures for training neural networks are constantly being improved. New methods are being proposed for architecture design, regularization and optimization. The SWA approach is related to work in both optimization and regularization.

In optimization, there is great interest in how different types of local solutions affect generalization in deep learning. Keskar et al. (2017) claim that SGD is more likely to converge to broad local optima than batch gradient methods, which tend to converge to sharp optima. Moreover, they argue that the broad optima found by SGD are more likely to have good test performance, even if the training loss is worse than for the sharp optima. On the other hand, Dinh et al. (2017) argue that all the known definitions of sharpness are unsatisfactory and cannot on their own explain generalization. Chaudhari et al. (2017) propose the Entropy-SGD method that explicitly forces optimization towards wide valleys. They report that although the optima found by Entropy-SGD are wider than those found by conventional SGD, the generalization performance is still comparable.

The SWA method is based on averaging multiple points along the trajectory of SGD with cyclical or constant learning rates. The general idea of maintaining a running average of weights proposed by SGD was first considered in convex optimization by Ruppert (1988) and later by Polyak and Juditsky (1992). However, this procedure is not typically used to train neural networks. Practitioners instead sometimes use an exponentially decaying running average of the weights found by SGD with a decaying learning rate, which smooths the trajectory of SGD but performs comparably.

SWA makes use of multiple samples gathered through exploration of the set of points corresponding to high-performing networks. To enforce exploration we run SGD with constant or cyclical learning rates. Mandt et al. (2017) show that under several simplifying assumptions running SGD with a constant learning rate is equivalent to sampling from a Gaussian distribution centered at the minimum of the loss, and the covariance of this Gaussian is controlled by the learning rate. Following this explanation from Mandt et al. (2017), we can interpret points proposed by SGD as being constrained to the surface of a sphere, since they come from a high-dimensional Gaussian distribution. SWA effectively allows us to go inside the sphere to find higher-density solutions.

In a procedure called Fast Geometric Ensembling (FGE), Garipov et al. (2018) showed that using a cyclical learning rate it is possible to gather models that are spatially close to each other but produce diverse predictions. They used the gathered models to train ensembles with no computational overhead compared to training a single DNN model. In recent work Neklyudov et al. (2018) also discuss an efficient approach for model averaging of Bayesian neural networks. SWA was inspired by following the trajectories of FGE proposals, in order to find a single model that would approximate an FGE ensemble, but provide greater interpretability, convenience, and test-time scalability.

Dropout (Srivastava et al., 2014) is an extremely popular approach to regularizing DNNs. Across each mini-batch used for SGD, a different architecture is created by randomly dropping out neurons. The authors make analogies between dropout, ensembling, and Bayesian model averaging. At test time, an ensemble approach is proposed, but then approximated with similar results by scaling each unit's outgoing weights by the probability of retaining that unit. At a high level, SWA and Dropout are both at once regularizers and training procedures, motivated to approximate an ensemble. Each approach implements these high-level ideas quite differently, and as we show in our experiments, the two can be combined for improved performance.

3 Stochastic Weight Averaging

We present Stochastic Weight Averaging (SWA) and analyze its properties. In section 3.1, we consider trajectories of SGD with a constant and cyclical learning rate, which helps understand the geometry of SGD training for neural networks, and motivates the SWA procedure. Then in section 3.2 we present the SWA algorithm in detail, in section 3.3 we derive its complexity, and in section 3.4 we analyze the width of solutions found by SWA versus conventional SGD training. In section 3.5 we then examine the relationship between SWA and the recently proposed Fast Geometric Ensembling (Garipov et al., 2018). Finally, in section 3.6 we consider SWA from the perspective of stochastic convex optimization.

We note the name SWA has two meanings: on the one hand, it is an average of SGD weights. On the other, with a cyclical or constant learning rate, SGD proposals are approximately sampling from the loss surface of the DNN, leading to stochastic weights.

3.1 Analysis of SGD Trajectories

SWA is based on averaging the samples proposed by SGD using a learning rate schedule that allows exploration of the region of weight space corresponding to high-performing networks. In particular we consider cyclical and constant learning rate schedules.

The cyclical learning rate schedule that we adopt is inspired by Garipov et al. (2018) and Smith and Topin (2017). In each cycle we linearly decrease the learning rate from $\alpha_1$ to $\alpha_2$. The formula for the learning rate at iteration $i$ is given by

$$\alpha(i) = (1 - t(i))\,\alpha_1 + t(i)\,\alpha_2, \qquad t(i) = \frac{1}{c}\big(\mathrm{mod}(i - 1, c) + 1\big).$$

The base learning rates $\alpha_1 \ge \alpha_2$ and the cycle length $c$ are the hyper-parameters of the method. Here by iteration we mean the processing of one batch of data. Figure 2 illustrates the cyclical learning rate schedule and the test error of the corresponding points. Note that unlike the cyclical learning rate schedule of Garipov et al. (2018) and Smith and Topin (2017), here we propose to use a discontinuous schedule that jumps directly from the minimum to the maximum learning rate, and does not steadily increase the learning rate as part of the cycle. We use this more abrupt cycle because for our purposes exploration is more important than the accuracy of individual proposals. For even greater exploration, we also consider constant learning rates $\alpha(i) = \alpha_1$.
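The cyclical schedule can be written as a small function. This sketch assumes 1-indexed iterations and the formula $\alpha(i) = (1 - t(i))\alpha_1 + t(i)\alpha_2$ with $t(i) = (\mathrm{mod}(i-1, c) + 1)/c$: within each cycle of $c$ iterations the rate falls linearly and reaches $\alpha_2$ exactly at the last iteration of the cycle, which is where models are captured. The function name is illustrative.

```python
def swa_learning_rate(i: int, alpha1: float, alpha2: float, c: int) -> float:
    """Cyclical SWA learning rate at iteration i (1-indexed) with cycle length c."""
    # t(i) takes the values 1/c, 2/c, ..., 1 within each cycle, then restarts.
    t = ((i - 1) % c + 1) / c
    # Linear interpolation: just below alpha1 at the start, exactly alpha2 at the end.
    return (1 - t) * alpha1 + t * alpha2
```

For instance, with `alpha1 = 0.05`, `alpha2 = 0.01` and `c = 4`, iterations 1 to 4 get rates of approximately 0.04, 0.03, 0.02 and 0.01, after which the schedule jumps back to 0.04 rather than ramping up.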

Figure 2: Top: cyclical learning rate as a function of iteration. Bottom: test error as a function of iteration for cyclical learning rate schedule with Preactivation-ResNet- on CIFAR-. Circles indicate iterations corresponding to the minimum learning rates.
Figure 3: The L2-regularized cross-entropy train loss and test error surfaces of a Preactivation ResNet- on CIFAR- in the plane containing the first, middle and last points (indicated by black crosses) in the trajectories with (left two) cyclical and (right two) constant learning rate schedules.

We run SGD with cyclical and constant learning rate schedules starting from a pretrained point for a Preactivation ResNet- on CIFAR-. We then use the first, middle and last point of each of the trajectories to define a 2-dimensional plane in the weight space containing all affine combinations of these points. In Figure 3 we plot the loss on train and error on test for points in these planes. We then project the other points of the trajectory to the plane of the plot. Note that the trajectories do not generally lie in the plane of the plot, except for the first, last and middle points, shown by black crosses in the figure. Therefore, for the other points of the trajectories it is not possible to read off the train loss and test error from the plots.

The key insight from Figure 3 is that both methods explore points close to the periphery of the set of high-performing networks. The visualizations suggest that both methods are doing exploration in the region of space corresponding to DNNs with high accuracy. The main difference between the two approaches is that the individual proposals of SGD with a cyclical learning rate schedule are in general much more accurate than the proposals of a fixed-learning rate SGD. After making a large step, SGD with a cyclical learning rate spends several epochs fine-tuning the resulting point with a decreasing learning rate. SGD with a fixed learning rate on the other hand is always making steps of relatively large sizes, exploring more efficiently than with a cyclical learning rate, but the individual proposals are worse.

Another important insight we can get from Figure 3 is that while the train loss and test error surfaces are qualitatively similar, they are not perfectly aligned. The shift between train and test suggests that more robust central points in the set of high-performing networks can lead to better generalization. Indeed, if we average several proposals from the optimization trajectories, we get a more robust point that has a substantially higher test performance than the individual proposals of SGD, and is essentially centered on the shifted mode for test error. We further discuss the reasons for this behaviour in sections 3.4, 3.5, 3.6.

3.2 SWA Algorithm

We now present the details of the Stochastic Weight Averaging algorithm, a simple but effective modification for training neural networks, motivated by our observations in section 3.1.

Following Garipov et al. (2018), we start with a pretrained model $\hat w$. We will refer to the number of epochs required to train a given DNN with the conventional training procedure as its training budget and will denote it by $B$. The pretrained model $\hat w$ can be trained with the conventional training procedure for the full training budget or for a reduced number of epochs. In the latter case we simply stop the training early without modifying the learning rate schedule. Starting from $\hat w$ we continue training, using a cyclical or constant learning rate schedule. When using a cyclical learning rate we capture the models that correspond to the minimum values of the learning rate (see Figure 2), following Garipov et al. (2018). For constant learning rates we capture models at each epoch. Next, we average the weights of all the captured networks to get our final model $w_{\text{SWA}}$.

Note that for cyclical learning rate schedule, the SWA algorithm is related to FGE (Garipov et al., 2018), except that instead of averaging the predictions of the models, we average their weights, and we use a different type of learning rate cycle. In section 3.5 we show how SWA can approximate FGE, but with a single model.

Batch normalization.

If the DNN uses batch normalization (Ioffe and Szegedy, 2015), we run one additional pass over the data, as in Garipov et al. (2018), to compute the running mean and standard deviation of the activations for each layer of the network with $w_{\text{SWA}}$ weights after the training is finished, since these statistics are not collected during training. For most deep learning libraries, such as PyTorch or TensorFlow, one can typically collect these statistics by making a forward pass over the data in training mode.
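PyTorch users can perform this extra pass with `torch.optim.swa_utils.update_bn(loader, model)`. To show what the pass computes, here is a framework-free NumPy sketch under simplifying assumptions: a single batch-normalized layer whose pre-normalization activations come from a hypothetical `features(x, w)` function, with exact per-feature statistics accumulated from per-batch sums (a design choice of this sketch, rather than a momentum-based running average).

```python
import numpy as np

def recompute_bn_stats(batches, features, w_swa):
    """Per-feature mean and std of activations computed with the averaged weights."""
    count, total, total_sq = 0, None, None
    for x in batches:
        h = features(x, w_swa)             # (batch, channels) activations under w_swa
        s, s2 = h.sum(axis=0), (h ** 2).sum(axis=0)
        total = s if total is None else total + s
        total_sq = s2 if total_sq is None else total_sq + s2
        count += h.shape[0]
    mean = total / count
    var = total_sq / count - mean ** 2     # population variance E[h^2] - E[h]^2
    return mean, np.sqrt(np.maximum(var, 0.0))
```

The statistics must come from a forward pass with the averaged weights themselves: averaging the batch-norm statistics of the individual captured models would not, in general, give the statistics of the averaged network.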

The SWA procedure is summarized in Algorithm 1.

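Require: weights ŵ, LR bounds α1, α2, cycle length c (for constant learning rate c = 1), number of iterations n
Ensure: w_SWA
  w ← ŵ   {Initialize weights with ŵ}
  w_SWA ← w
  for i ← 1, 2, ..., n do
      α ← α(i)   {Calculate LR for the iteration}
      w ← w − α ∇L_i(w)   {Stochastic gradient update}
      if mod(i, c) = 0 then
          n_models ← i / c   {Number of models}
          w_SWA ← (w_SWA · n_models + w) / (n_models + 1)   {Update average}
      end if
  end for
  {Compute BatchNorm statistics for w_SWA weights}
Algorithm 1 Stochastic Weight Averaging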
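Algorithm 1 translates almost line by line into NumPy. This sketch works on a flattened weight vector and replaces the network with a hypothetical stochastic-gradient function `grad_fn(w, i)`; the toy problem (a quadratic loss whose gradients carry Gaussian noise as a stand-in for mini-batch noise) is an assumption made for illustration. As in Algorithm 1, the starting weights count as the first model in the average, and the BatchNorm pass is omitted because the toy model has no normalization layers.

```python
import numpy as np

def swa(w_hat, grad_fn, lr_fn, n, c):
    """Stochastic Weight Averaging following Algorithm 1 (c = 1 for a constant LR)."""
    w = np.array(w_hat, dtype=float)        # initialize weights with w_hat
    w_swa = w.copy()
    for i in range(1, n + 1):
        alpha = lr_fn(i)                    # learning rate for this iteration
        w = w - alpha * grad_fn(w, i)       # stochastic gradient update
        if i % c == 0:
            n_models = i // c               # number of models already averaged
            w_swa = (w_swa * n_models + w) / (n_models + 1)   # update average
    # A network with batch normalization would need its statistics recomputed here.
    return w_swa, w

# Toy problem: loss 0.5 * ||w - w_star||^2 observed through noisy gradients.
rng = np.random.default_rng(0)
w_star = np.ones(100)

def noisy_grad(w, i):
    return (w - w_star) + rng.normal(size=w.shape)

w_swa, w_last = swa(np.zeros(100), noisy_grad, lambda i: 0.1, n=2000, c=1)
print(np.linalg.norm(w_last - w_star), np.linalg.norm(w_swa - w_star))
```

With a constant learning rate the last SGD iterate keeps bouncing around `w_star` at a distance set by the learning rate, so the printed distances should show the averaged weights lying much closer to the optimum.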

3.3 Computational Complexity

The time and memory overhead of SWA compared to conventional training is negligible. During training, we need to maintain a copy of the running average of DNN weights. Note, however, that the memory consumption of training a DNN is dominated by its activations rather than its weights, and thus is only slightly increased by the SWA procedure, even for large DNNs (e.g., on the order of 10%). After the training is complete we only need to store the model that aggregates the average, leading to the same memory requirements as standard training.

During training, extra time is only spent to update the aggregated weight average. This operation is of the form

$$w_{\text{SWA}} \leftarrow \frac{w_{\text{SWA}} \cdot n_{\text{models}} + w}{n_{\text{models}} + 1},$$

and it only requires computing a weighted sum of the weights of two DNNs. As we apply this operation at most once per epoch, SWA and SGD require practically the same amount of computation. Indeed, a similar operation is performed as a part of each gradient step, and each epoch consists of hundreds of gradient steps.
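The update keeps only one extra copy of the weights, yet after each step it equals the plain arithmetic mean of every weight vector folded in so far, because multiplying the current average by n_models recovers the running sum. A small check of that identity, with random vectors standing in for captured snapshots (an illustrative assumption):

```python
import numpy as np

def update_average(w_swa, n_models, w):
    """One SWA update: fold weights w into an average of n_models snapshots."""
    return (w_swa * n_models + w) / (n_models + 1)

rng = np.random.default_rng(1)
snapshots = rng.normal(size=(7, 10))      # seven captured weight vectors
w_swa = snapshots[0].copy()               # the first snapshot is the initial average
for n_models, w in enumerate(snapshots[1:], start=1):
    w_swa = update_average(w_swa, n_models, w)
```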

3.4 Solution Width

Figure 4: (Left) Test error and (Right) L2-regularized cross-entropy train loss as a function of a point on a random ray starting at SWA (blue) and SGD (green) solutions for Preactivation ResNet- on CIFAR-. Each line corresponds to a different random ray.
Figure 5: L2-regularized cross-entropy train loss and test error as a function of a point on the line connecting SWA and SGD solutions on CIFAR-. Left: Preactivation ResNet-. Right: VGG-.

Keskar et al. (2017) and Chaudhari et al. (2017) conjecture that the width of a local optimum is related to generalization. The general explanation for the importance of width is that the surfaces of train loss and test error are shifted with respect to each other and it is thus desirable to converge to the modes of broad optima, which stay approximately optimal under small perturbations. In this section we compare the solutions found by SWA and SGD and show that SWA generally leads to much wider solutions.

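Let $w_{\text{SWA}}$ and $w_{\text{SGD}}$ denote the weights of DNNs trained using SWA and conventional SGD, respectively. Consider the rays

$$w_{\text{SWA}}(t, d) = w_{\text{SWA}} + t \cdot d, \qquad w_{\text{SGD}}(t, d) = w_{\text{SGD}} + t \cdot d,$$

which follow a direction vector $d$ on the unit sphere, starting at $w_{\text{SWA}}$ and $w_{\text{SGD}}$, respectively. In Figure 4 we plot train loss and test error of $w_{\text{SWA}}(t, d_i)$ and $w_{\text{SGD}}(t, d_i)$ as a function of $t$ for random directions $d_i$ drawn from a uniform distribution on the unit sphere. For this visualization we use a Preactivation ResNet- on CIFAR-.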
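Random rays of this kind can be drawn by normalizing a standard Gaussian vector, which yields a direction uniformly distributed on the unit sphere. The sketch below evaluates a loss along such rays; the quadratic `quadratic(h)` with per-coordinate curvatures `h` is a hypothetical stand-in for a network's train loss, used only to show how a flatter minimum produces a wider profile along the same directions.

```python
import numpy as np

def random_unit_directions(n_dirs, dim, rng):
    """Directions drawn uniformly from the unit sphere in R^dim."""
    d = rng.normal(size=(n_dirs, dim))
    return d / np.linalg.norm(d, axis=1, keepdims=True)

def ray_profiles(w0, loss_fn, directions, ts):
    """Loss at w0 + t * d for every direction d (rows) and step t (columns)."""
    return np.array([[loss_fn(w0 + t * d) for t in ts] for d in directions])

def quadratic(h):
    """Toy loss 0.5 * sum_j h_j * w_j^2 with its minimum at the origin."""
    return lambda w: 0.5 * float(np.sum(h * w ** 2))

rng = np.random.default_rng(0)
dirs = random_unit_directions(10, 50, rng)
ts = np.linspace(0.0, 5.0, 11)
h = rng.uniform(0.5, 2.0, size=50)
flat = ray_profiles(np.zeros(50), quadratic(h), dirs, ts)
sharp = ray_profiles(np.zeros(50), quadratic(10 * h), dirs, ts)
```

Along every ray the tenfold sharper quadratic rises ten times faster, so one must step much further from the flat minimum to reach a given loss level; this is the sense in which the ray plots compare solution widths.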

First, while the loss values on train for $w_{\text{SWA}}$ and $w_{\text{SGD}}$ are quite similar (and in fact $w_{\text{SGD}}$ has a slightly lower train loss), the test error for $w_{\text{SWA}}$ is lower (at the converged value corresponding to $t = 0$). Further, the shapes of both train loss and test error curves are considerably wider for $w_{\text{SWA}}$ than for $w_{\text{SGD}}$, suggesting that SWA indeed converges to a wider solution: we have to step much further away from $w_{\text{SWA}}$ to increase error by a given amount. We even see that the error curve for SGD has an inflection point that is not present for these distances with SWA.

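Notice that in Figure 4 every random direction from $w_{\text{SGD}}$ increases test error. However, we know that the direction from $w_{\text{SGD}}$ to $w_{\text{SWA}}$ would decrease test error, since $w_{\text{SWA}}$ has considerably lower test error than $w_{\text{SGD}}$. In other words, the path from $w_{\text{SGD}}$ to $w_{\text{SWA}}$ is qualitatively different from all directions shown in Figure 4, because along this direction $w_{\text{SGD}}$ is far from optimal. We therefore consider the line connecting $w_{\text{SGD}}$ and $w_{\text{SWA}}$:

$$w(t) = t \cdot w_{\text{SGD}} + (1 - t) \cdot w_{\text{SWA}}, \qquad t \in \mathbb{R}.$$

In Figure 5 we plot the train loss and test error of $w(t)$ as a function of signed distance from $w_{\text{SWA}}$ for Preactivation ResNet- and VGG- on CIFAR-.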
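On the line $w(t) = t \cdot w_{\text{SGD}} + (1 - t) \cdot w_{\text{SWA}}$ we have $w(0) = w_{\text{SWA}}$ and $w(1) = w_{\text{SGD}}$, and the signed distance from $w_{\text{SWA}}$ is $t\,\|w_{\text{SGD}} - w_{\text{SWA}}\|$, so negative $t$ moves past $w_{\text{SWA}}$ away from SGD. A short sketch of this parametrization, where `loss_fn` is a hypothetical stand-in for evaluating train loss or test error:

```python
import numpy as np

def line_profile(w_sgd, w_swa, loss_fn, ts):
    """Signed distances from w_swa and loss values along w(t) = t*w_sgd + (1-t)*w_swa."""
    gap = np.linalg.norm(w_sgd - w_swa)
    distances = np.asarray(ts, dtype=float) * gap   # negative t: beyond w_swa, away from SGD
    losses = np.array([loss_fn(t * w_sgd + (1 - t) * w_swa) for t in ts])
    return distances, losses
```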

We can extract several key insights about $w_{\text{SWA}}$ and $w_{\text{SGD}}$ from Figure 5. First, the train loss and test error plots are indeed substantially shifted, and the point obtained by minimizing the train loss is far from optimal on test. Second, $w_{\text{SGD}}$ lies near the boundary of a wide flat region of the train loss. Further, the loss is very steep near $w_{\text{SGD}}$.

Keskar et al. (2017) argue that the loss near sharp optima found by SGD with very large batches is actually flat in most directions, but there exist directions in which the optima are extremely steep. They conjecture that because of this sharpness the generalization performance of large-batch optimization is substantially worse than that of solutions found by small-batch SGD. Remarkably, in our experiments in this section we observe that there exist directions of steep ascent even for small-batch optima, and that SWA provides even wider solutions (at least along random directions) with better generalization. Indeed, we can see clearly in Figure 5 that SWA is not finding a different minimum than SGD, but rather a flatter region in the same basin of attraction. We can also see that the significant asymmetry of the loss function in certain directions, such as the direction from SWA to SGD, has a role in understanding why SWA provides better generalization than SGD. In these directions SWA finds a much flatter solution than SGD, which can be near the periphery of sharp ascent.

3.5 Connection to Ensembling

Garipov et al. (2018) proposed the Fast Geometric Ensembling (FGE) procedure for training ensembles in the time required to train a single model. Using a cyclical learning rate, FGE generates a sequence of points that are close to each other in the weight space, but produce diverse predictions. In SWA instead of averaging the predictions of the models we average their weights. However, the predictions proposed by FGE ensembles and SWA models have similar properties.

Let $f(w)$ denote the predictions of a neural network parametrized by weights $w$. We will assume that $f$ is a scalar (e.g. the probability for a particular class) twice continuously differentiable function with respect to $w$.

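Consider points $w_1, \ldots, w_n$ proposed by FGE. These points are close in the weight space by design, and concentrated around their average $w_{\text{SWA}} = \frac{1}{n}\sum_{i=1}^{n} w_i$. We denote $\Delta_i = w_i - w_{\text{SWA}}$. Note $\sum_{i=1}^{n} \Delta_i = 0$. Ensembling the networks corresponds to averaging the function values

$$\bar f = \frac{1}{n}\sum_{i=1}^{n} f(w_i).$$

Consider the linearization of $f$ at $w_{\text{SWA}}$:

$$f(w_j) = f(w_{\text{SWA}}) + \langle \nabla f(w_{\text{SWA}}), \Delta_j \rangle + O(\|\Delta_j\|^2),$$

where $\langle \cdot, \cdot \rangle$ denotes the dot product. Thus, the difference between averaging the weights and averaging the predictions is

$$\bar f - f(w_{\text{SWA}}) = \frac{1}{n}\sum_{i=1}^{n}\Big(\langle \nabla f(w_{\text{SWA}}), \Delta_i \rangle + O(\|\Delta_i\|^2)\Big) = \Big\langle \nabla f(w_{\text{SWA}}), \frac{1}{n}\sum_{i=1}^{n}\Delta_i \Big\rangle + O(\Delta^2) = O(\Delta^2),$$

where $\Delta = \max_{i} \|\Delta_i\|$. Note that the difference between the predictions of different perturbed networks is

$$f(w_i) - f(w_j) = \langle \nabla f(w_{\text{SWA}}), \Delta_i - \Delta_j \rangle + O(\Delta^2),$$

and is thus of the first order of smallness, while the difference between averaging predictions and averaging weights is of the second order of smallness. Note that for the points proposed by FGE the distances between proposals are relatively small by design, which justifies the local analysis.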
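The first-order versus second-order behaviour can be checked numerically on a toy model. This sketch assumes a logistic "prediction" f(w) = sigmoid(⟨x, w⟩) for one fixed input x (an illustrative stand-in for a class probability), builds n points around their mean with deviations that sum to zero and have largest norm eps, and compares the gap between ensembling and weight averaging with the spread between individual members.

```python
import numpy as np

def f(w, x):
    """Toy scalar prediction: sigmoid of a linear score, smooth in w."""
    return 1.0 / (1.0 + np.exp(-np.dot(x, w)))

def ensemble_vs_average(eps, n=5, dim=20, seed=0):
    """Return (|ensemble - weight average|, spread of members) for perturbation size eps."""
    rng = np.random.default_rng(seed)      # same seed: same directions for every eps
    x = rng.normal(size=dim)
    w_swa = rng.normal(size=dim) / np.sqrt(dim)
    deltas = rng.normal(size=(n, dim))
    deltas -= deltas.mean(axis=0)          # enforce sum_i Delta_i = 0
    deltas *= eps / np.linalg.norm(deltas, axis=1).max()   # max_i ||Delta_i|| = eps
    preds = np.array([f(w_swa + delta, x) for delta in deltas])
    gap = abs(preds.mean() - f(w_swa, x))  # averaging predictions vs averaging weights
    spread = preds.max() - preds.min()     # disagreement between individual members
    return gap, spread

for eps in (1e-1, 1e-2):
    print(eps, *ensemble_vs_average(eps))
```

Shrinking eps tenfold should shrink the spread roughly tenfold but the gap roughly a hundredfold, consistent with the first-order and second-order terms derived in this section.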

To analyze the difference between ensembling and averaging the weights of FGE proposals in practice, we run FGE for epochs and compare the predictions of different models on the test dataset with a Preactivation ResNet-164 (He et al., 2016) on CIFAR-. The norm of the difference between the class probabilities of consecutive FGE proposals averaged over the test dataset is . We then average the weights of the proposals and compute the class probabilities on the test dataset. The norm of difference of the probabilities for the SWA model and the FGE ensemble is , which is substantially smaller than the difference between the probabilities of consecutive FGE proposals. Further, the fraction of objects for which consecutive FGE proposals output the same labels is not greater than . For FGE and SWA the fraction of identically labeled objects is .

The theoretical considerations and empirical results presented in this section suggest that SWA can approximate the FGE ensemble with a single model.

3.6 Connection to Convex Minimization

DNN (Budget)          SGD    FGE (1 Budget)    SWA (1 Budget)    SWA ( Budgets)    SWA ( Budgets)
CIFAR-100
  VGG-16 ()
  ResNet-164 ()
  WRN-28-10 ()
  PyramidNet- ()
CIFAR-10
  VGG-16 ()
  ResNet-164 ()
  WRN-28-10 ()
  ShakeShake-2x64d ()
Table 1: Accuracies (%) of SWA, SGD and FGE methods on CIFAR-100 and CIFAR-10 datasets for different training budgets. Accuracies for the FGE ensemble are from Garipov et al. (2018).

Mandt et al. (2017) showed that under strong simplifying assumptions SGD with a fixed learning rate approximately samples from a Gaussian distribution centered at the minimum of the loss. Suppose this is the case when we run SGD with a fixed learning rate for training a DNN.

Let us denote the dimensionality of the weight space of the neural network by $d$. Denote the samples produced by SGD by $w_i$, $i = 1, 2, \ldots, k$. Assume the points $w_i$ are concentrated around the local optimum $w^{*}$. The SWA solution is given by $w_{\text{SWA}} = \frac{1}{k}\sum_{i=1}^{k} w_i$. The points $w_i$ are samples from a multidimensional Gaussian $\mathcal{N}(w^{*}, \Sigma)$ for some covariance matrix $\Sigma$ defined by the curvature of the loss, batch size and the learning rate. Note that the samples from a multidimensional Gaussian are concentrated on the ellipsoid

$$\left\{ z \in \mathbb{R}^d \;:\; \left\| \Sigma^{-1/2} (z - w^{*}) \right\| = \sqrt{d} \right\},$$

and the probability mass for a sample to end up inside the ellipsoid near $w^{*}$ is negligible. On the other hand, $w_{\text{SWA}}$ is guaranteed to converge to $w^{*}$ as $k \to \infty$.
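The concentration effect is easy to see numerically. This sketch assumes a diagonal covariance with random per-coordinate standard deviations (purely illustrative) in d = 1000 dimensions: each Gaussian sample sits at Mahalanobis distance close to √d from the center, while the average of k samples lies at a distance of only about √(d/k).

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 1000, 50
center = rng.normal(size=d)                 # plays the role of the local optimum
std = rng.uniform(0.1, 2.0, size=d)         # diagonal covariance: Sigma = diag(std**2)
samples = center + std * rng.normal(size=(k, d))

def mahalanobis(z):
    """||Sigma^{-1/2} (z - center)|| for diagonal Sigma (vectorized over leading axes)."""
    return np.linalg.norm((z - center) / std, axis=-1)

sample_dist = mahalanobis(samples)          # one distance per sample
swa_dist = mahalanobis(samples.mean(axis=0))
print(sample_dist.min(), sample_dist.max(), swa_dist, np.sqrt(d), np.sqrt(d / k))
```

Every individual sample lands near the shell of radius √d, whereas the average moves well inside it, which is the geometric picture behind averaging SGD iterates.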

Moreover, Polyak and Juditsky (1992) showed that averaging SGD proposals achieves the best possible convergence rate among all stochastic gradient algorithms. The proof relies on the convexity of the underlying problem and in general there are no convergence guarantees if the loss function is non-convex (see e.g. Ghadimi and Lan, 2013). While DNN loss functions are known to be non-convex (e.g. Choromanska et al., 2015), over the trajectory of SGD these loss surfaces are approximately convex (e.g. Goodfellow et al., 2015). However, even when the loss is locally non-convex, SWA can improve generalization. For example, in Figure 5 we see that SWA converges to a central point of the training loss.

In other words, there is a set of points that all achieve low training loss. By running SGD with a high constant or cyclical schedule, we traverse over the surface of this set. Then by averaging the corresponding iterates, we get to move inside the set. This observation explains both convergence rates and generalization. In deep learning we mostly observe benefits in generalization from averaging. Averaging can move to a more central point, which means one has to move further from this point to increase the loss by a given amount, in virtually any direction. By contrast, conventional SGD with a decaying schedule will converge to a point on the periphery of this set. With different initializations conventional SGD will find different points on the boundary of the set of solutions with low training loss, but it will not move inside.

4 Experiments

We compare SWA against conventional SGD training on CIFAR-10, CIFAR-100 and ImageNet ILSVRC-2012 (Russakovsky et al., 2012). We also compare to Fast Geometric Ensembling (FGE) (Garipov et al., 2018), but we note that FGE is an ensemble whereas SWA corresponds to a single model. Conventional SGD training uses a standard decaying learning rate schedule (details in the Appendix) until convergence. We found an exponentially decaying average of SGD to perform comparably to conventional SGD at convergence. We release the code for reproducing the results in this paper at https://github.com/timgaripov/swa.

4.1 CIFAR Datasets

For the experiments on CIFAR datasets we use VGG-16 (Simonyan and Zisserman, 2014), a 164-layer Preactivation-ResNet (He et al., 2016) and Wide ResNet-28-10 (Zagoruyko and Komodakis, 2016) models. Additionally, we experiment with the recent Shake-Shake-2x64d (Gastaldi, 2017) on CIFAR-10 and PyramidNet- (bottleneck, ) (Han et al., 2016) on CIFAR-100. All models are trained using L2-regularization, and VGG-16 also uses dropout.

For each model we define budget as the number of epochs required to train the model until convergence with conventional SGD training, such that we do not see improvement with SGD beyond this budget. We use the same budgets for VGG, Preactivation ResNet and Wide ResNet models as Garipov et al. (2018). For Shake-Shake and PyramidNets we use the budgets indicated by the papers that proposed these models (Gastaldi, 2017; Han et al., 2016). We report the results of SWA training within , and budgets of epochs.

For VGG, Wide ResNet and Preactivation-ResNet models we first run standard SGD training for of the training budget, and then use the weights at the last epoch as an initialization for SWA with a fixed learning rate schedule. We ran SWA for , and budget to complete the training within , and budgets respectively.

For Shake-Shake and PyramidNet architectures we do not report the results in one budget. For these models we use a full budget to get an initialization for the procedure, and then train with a cyclical learning rate schedule for and budgets. We used long cycles of small learning rates for Shake-Shake, because this architecture already involves many stochastic components.

We present the details of the learning rate schedules for each of these models in the Appendix.

For each model we also report the results of conventional SGD training, which we denote by SGD. For VGG, Preactivation ResNet and Wide ResNet we also provide the results of the FGE method with one budget reported in Garipov et al. (2018). Note that for FGE we report the accuracy of an ensemble of to networks, while for SWA we report the accuracy of a single model.

We summarize the experimental results in Table 1. For all models we report the mean and standard deviation of test accuracy over runs. In all conducted experiments SWA substantially outperforms SGD in one budget, and improves further, as we allow more training epochs. Across different architectures we see consistent improvement by on CIFAR- (excluding Shake-Shake, for which SGD performance is already extremely high) and by - on CIFAR-. Amazingly, SWA is able to achieve comparable or better performance than FGE ensembles with just one model. On CIFAR- SWA usually needs more than one budget to get results comparable with FGE ensembles, but on CIFAR- even with budget SWA outperforms FGE.

4.2 ImageNet

On ImageNet we experimented with ResNet-, ResNet- (He et al., 2016) and DenseNet- (Huang et al., 2017). For these architectures we used pretrained models from PyTorch's torchvision package. For each of the models we ran SWA for epochs with a cyclical learning rate schedule with the same parameters for all models (the details can be found in the Appendix), and report the mean and standard deviation of test error averaged over runs. The results are shown in Table 2.

DNN | SGD | SWA ( epochs) | SWA ( epochs)
ResNet-50 | | |
ResNet-152 | | |
DenseNet-161 | | |
Table 2: Top-1 accuracies (%) on ImageNet for SWA and SGD with different architectures.

For all architectures SWA provides consistent improvement by - over the pretrained models.

4.3 Effect of the Learning Rate Schedule

Figure 6: Test error as a function of training epoch for SWA with different learning rate schedules with a Preactivation ResNet- on CIFAR-.

In this section we explore how the learning rate schedule affects the performance of SWA. We run experiments on Preactivation ResNet- on CIFAR-. For all schedules we use the same initialization from a model trained for epochs using conventional SGD training. As a baseline we use a fully trained model trained with conventional SGD for epochs.

We consider a range of constant and cyclical learning rate schedules. For cyclical learning rates we fix the cycle length to , and consider the pairs of base learning rate parameters . Among the constant learning rates we consider .
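The cyclical schedules can be sketched as follows, assuming the linear-within-cycle form α(i) = (1 − t(i))·α1 + t(i)·α2 with t(i) = (mod(i − 1, c) + 1)/c, where i is a 1-indexed iteration counter, c is the cycle length and (α1, α2) is a pair of base learning rates with α1 ≥ α2. Within each cycle the rate decreases linearly from α1 towards α2, and in this sketch a model is collected for averaging at the end of every cycle, when the rate equals α2. The function names and the numeric values in the usage lines are illustrative assumptions.

```python
def cyclic_lr(i: int, alpha1: float, alpha2: float, c: int) -> float:
    """Learning rate at iteration i (1-indexed) of a cyclical schedule that
    decreases linearly from alpha1 towards alpha2 within each cycle of c iterations."""
    if i < 1 or c < 1:
        raise ValueError("iteration index and cycle length must be >= 1")
    t = ((i - 1) % c + 1) / c  # position within the current cycle, in (0, 1]
    return (1 - t) * alpha1 + t * alpha2


def is_collection_step(i: int, c: int) -> bool:
    """Collect a model for averaging at the end of each cycle (rate equal to alpha2)."""
    return i % c == 0


# Usage with hypothetical values alpha1 = 0.05, alpha2 = 0.001, c = 4.
for i in range(1, 9):
    print(i, round(cyclic_lr(i, 0.05, 0.001, 4), 5), is_collection_step(i, 4))
```

A constant learning rate is the special case α1 = α2, for which every iteration receives the same rate.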

We plot the test error of the SWA procedure for different learning rate schedules as a function of the number of training epochs in Figure 6.

We find that, in general, the more aggressive constant learning rate schedule leads to faster convergence of SWA. In our experiments, setting the learning rate to an intermediate value between the largest and the smallest learning rates used in the annealing scheme of conventional training usually gave the best results. The approach is, however, general and can work well with other learning rate schedules tailored to particular tasks.

4.4 DNN Training with a Fixed Learning Rate

Figure 7: Test error as a function of training epoch for constant (green) and decaying (blue) learning rate schedules for a Wide ResNet-- on CIFAR-100. In red we average the points along the trajectory of SGD with constant learning rate starting at epoch .

In this section we show that it is possible to train DNNs from scratch with a fixed learning rate using SWA. We run SGD with a fixed learning rate of on a Wide ResNet-- (Zagoruyko and Komodakis, 2016) for epochs from a random initialization on CIFAR-100. We then average the weights at the end of each epoch from epoch and until the end of training. The final test accuracy of this SWA model was .

Figure 7 illustrates the test error as a function of the number of training epochs for SWA and conventional training. The accuracy of the individual models with weights averaged by SWA stays at the level of which is less than the accuracy of the SWA model. These results correspond to our intuition presented in section 3.6 that SGD with a constant learning rate oscillates around the optimum, but SWA converges.
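The intuition that SGD with a constant learning rate oscillates around the optimum while the average of its iterates converges can be checked on a toy problem. The sketch below is a hypothetical stand-in, not the experiment above: it runs SGD on the one-dimensional loss f(w) = w²/2 with gradients perturbed by unit Gaussian noise, averages the iterates after a burn-in period (every step rather than every epoch, for simplicity), and compares the squared distance to the optimum w* = 0 of the final iterate and of the average across 200 random seeds. All names and constants are assumptions chosen for illustration.

```python
import random
from statistics import fmean


def constant_lr_sgd(lr: float, n_steps: int, avg_start: int, seed: int):
    """SGD with a constant learning rate on the toy loss f(w) = 0.5 * w**2,
    using gradients corrupted by unit Gaussian noise. Returns the final iterate
    and the equal-weight average of the iterates from step avg_start onward."""
    rng = random.Random(seed)
    w = 5.0  # arbitrary start; the optimum is w* = 0
    w_avg, n_avg = 0.0, 0
    for step in range(n_steps):
        grad = w + rng.gauss(0.0, 1.0)  # exact gradient w plus noise
        w -= lr * grad
        if step >= avg_start:
            # Running average of iterates, as in SWA.
            w_avg = (w_avg * n_avg + w) / (n_avg + 1)
            n_avg += 1
    return w, w_avg


# Squared distance to the optimum, averaged over many random seeds.
results = [constant_lr_sgd(lr=0.1, n_steps=1100, avg_start=100, seed=s) for s in range(200)]
final_mse = fmean(w ** 2 for w, _ in results)
avg_mse = fmean(a ** 2 for _, a in results)
print(f"final iterate MSE: {final_mse:.4f}, averaged iterate MSE: {avg_mse:.4f}")
```

With a constant step size the final iterate keeps fluctuating with a variance set by the learning rate and the gradient noise, whereas averaging about 1000 correlated iterates shrinks that variance substantially, so the averaged iterate should land much closer to the optimum than the final one.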

While being able to train a DNN with a fixed learning rate is a surprising property of SWA, for practical purposes we recommend initializing SWA from a model pretrained with conventional training (possibly for a reduced number of epochs), as it leads to faster and more stable convergence than running SWA from scratch.

5 Discussion

We have presented Stochastic Weight Averaging (SWA) for training neural networks. SWA is extremely easy to implement, architecture-agnostic, and improves generalization performance at virtually no additional cost over conventional training.

There are so many exciting directions for future research. SWA does not require each weight in its average to correspond to a good solution, due to the geometry of weights traversed by the algorithm. It therefore may be possible to develop SWA for much faster convergence than standard SGD. One may also be able to combine SWA with large batch sizes while preserving generalization performance, since SWA discovers much broader optima than conventional SGD training. Furthermore, a cyclic learning rate enables SWA to explore regions of high posterior density over neural network weights. Such learning rate schedules could be developed in conjunction with stochastic MCMC approaches, to encourage exploration while still providing high quality samples. One could also develop SWA to average whole regions of good solutions, using the high-accuracy curves discovered in Garipov et al. (2018).

A better understanding of the loss surfaces for multilayer networks will help continue to unlock the potential of these rich models. We hope that SWA will inspire further progress in this area.

Acknowledgements.

This work was supported by NSF IIS-1563887, Samsung Research, Samsung Electronics and Russian Science Foundation grant 17-11-01027. We also thank Vadim Bereznyuk for helpful comments.

References

Appendix A

A.1 Experimental Details

For the experiments on CIFAR datasets (section 4.1) we used the following implementations (embedded links):

Models for ImageNet are from here. Pretrained networks can be found here.

SWA learning rates.

For PyramidNet SWA uses a cyclic learning rate with and and cycle length . For VGG and Wide ResNet we used constant learning rate . For ResNet we used constant learning rates on CIFAR- and on CIFAR-.

For Shake-Shake Net we used a custom cyclic learning rate based on the cosine annealing used when training Shake-Shake with SGD. Each cycle replicates the learning rates corresponding to epochs of the standard training, and the cycle length is epochs. The learning rate schedule is depicted in Figure 8 and follows the formula

where epoch(i) is the number of data passes completed before iteration .
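One way to build a cyclic schedule that replays part of a cosine-annealing run is sketched below. It assumes the base schedule is standard cosine annealing, α_cos(e) = ½·α0·(1 + cos(π·e/E)) over E epochs, that each cycle of length `cycle_len` epochs reuses the rates of epochs [`window_start`, `window_start` + `cycle_len`) of that run, and that epoch(i) is the fractional number of data passes completed before iteration i. All parameter names and values are hypothetical placeholders, not the constants used in the Shake-Shake experiments.

```python
import math


def cosine_annealing_lr(epoch: float, alpha0: float, total_epochs: float) -> float:
    """Standard cosine annealing: decays from alpha0 at epoch 0 to 0 at total_epochs."""
    return 0.5 * alpha0 * (1.0 + math.cos(math.pi * epoch / total_epochs))


def replayed_cosine_lr(i: int, iters_per_epoch: int, alpha0: float, total_epochs: float,
                       window_start: float, cycle_len: float) -> float:
    """Cyclic schedule whose every cycle replays the cosine-annealing rates of
    epochs [window_start, window_start + cycle_len) of standard training.

    Hypothetical helper; i is a 1-indexed iteration counter.
    """
    if i < 1 or iters_per_epoch < 1 or cycle_len <= 0:
        raise ValueError("invalid iteration index, epoch size or cycle length")
    epoch_i = (i - 1) / iters_per_epoch  # data passes completed before iteration i
    return cosine_annealing_lr(window_start + epoch_i % cycle_len, alpha0, total_epochs)


# Usage with hypothetical values: cosine annealing over 100 epochs from 0.2,
# replaying epochs 80-90 in every 10-epoch cycle, 50 iterations per epoch.
for i in (1, 250, 500, 501):
    print(i, round(replayed_cosine_lr(i, 50, 0.2, 100, 80, 10), 5))
```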

For all experiments with ImageNet we used a cyclic learning rate schedule with the same hyperparameters , and .

Figure 8: Cyclical learning rate used for Shake-Shake as a function of iteration.

SGD learning rates.

For conventional SGD training we used SGD with momentum and an annealed learning rate schedule. For VGG, Wide ResNet and Preactivation ResNet we fixed the learning rate to for the first half of epochs (), then linearly decreased the learning rate to for the next of epochs (), and then kept it constant for the last of epochs (). For VGG we set , and for Preactivation ResNet and Wide ResNet we set . For Shake-Shake Net and PyramidNets we used the cosine and piecewise-constant learning rate schedules described in Gastaldi (2017) and Han et al. (2016), respectively.
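The annealed SGD schedule can be sketched as a piecewise function of the fraction of the budget completed: constant for the first half, a linear decrease, then constant again. The sketch assumes the linear phase ends at a fraction `decay_end_frac` of the budget and stops at `final_ratio`·α1, after which the rate is held fixed; these two parameters, the function name and the values in the usage lines are illustrative assumptions.

```python
def annealed_sgd_lr(epoch: float, budget: float, alpha1: float,
                    decay_end_frac: float, final_ratio: float) -> float:
    """Piecewise schedule: alpha1 for the first half of the budget, then a linear
    decrease to final_ratio * alpha1 until decay_end_frac of the budget, then constant.

    decay_end_frac and final_ratio are hypothetical placeholders.
    """
    if not 0.5 < decay_end_frac <= 1.0:
        raise ValueError("decay_end_frac must lie in (0.5, 1]")
    frac = epoch / budget
    if frac <= 0.5:
        return alpha1
    if frac <= decay_end_frac:
        progress = (frac - 0.5) / (decay_end_frac - 0.5)  # 0 -> 1 over the decay phase
        return alpha1 * (1.0 - (1.0 - final_ratio) * progress)
    return alpha1 * final_ratio


# Usage with illustrative values: a budget of 100 epochs and alpha1 = 0.1.
for epoch in (0, 50, 70, 90, 100):
    print(epoch, annealed_sgd_lr(epoch, 100, 0.1, decay_end_frac=0.9, final_ratio=0.01))
```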

A.2 Training ResNet with a Constant Learning Rate

In this section we present the experiment on training Preactivation ResNet- using a constant learning rate. The experimental setup is the same as in section 4.4. We set the learning rate to and start averaging after epoch . The results are presented in Figure 9.

Figure 9: Test error as a function of training epoch for constant (green) and decaying (blue) learning rate schedules for a Preactivation ResNet- on CIFAR-100. In red we average the points along the trajectory of SGD with constant learning rate starting at epoch .