Robust Neural Network Training using Periodic Sampling over Model Weights

05/14/2019 · by Samarth Tripathi, et al. · LG Electronics Inc.

Deep neural networks provide best-in-class performance for a number of computer vision problems. However, training these networks is computationally intensive and requires fine-tuning various hyperparameters. In addition, performance swings widely as the network converges, making it hard to decide when to stop training. In this paper, we introduce a trio of techniques (PSWA, PWALKS, and PSWM) centered around periodic sampling of model weights that provide consistent and more robust convergence on a variety of vision problems (classification, detection, segmentation) and gradient update methods (vanilla SGD, Momentum, Adam) with marginal additional computation time. Our techniques use existing optimal training policies but converge in a less volatile fashion with performance improvements that are approximately monotonic. Our analysis of the loss surface shows that these techniques also produce minima that are deeper and wider than those found by SGD.


1. Introduction

Optimizing deep neural networks is especially challenging due to the nonconvex nature of their loss functions. Hence, the development of gradient-based methods that use back-propagation to approximate optimal solutions has been crucial for neural network adoption. Optimization techniques over gradient updates like Stochastic Gradient Descent (SGD) and gradient-based adaptive optimizers have made the training process more effective. However, optimal convergence of the loss function is still time-consuming and volatile, and needs many finely tuned hyperparameters. In this paper we show that by manipulating the model weights directly, using their distributions over batchwise updates, we can achieve significant improvements in training convergence and add robustness to the optimization process at a negligible cost in additional training time. Since our technique modifies the model weights directly using their distribution over gradient updates, it remains independent of the gradient optimization method.

Using the model weight distribution to improve either the training process or a trained model has been widely studied by extending the Polyak-Ruppert Averaging (PRA) method. Lacoste-Julien et al. (Lacoste-Julien et al., 2012) explored many techniques to speed up convergence of convex functions using the projected stochastic subgradient method. Their work explored gradient-based averaging, weighted averaging, and other variations, as well as the theoretical justifications for such an approach. Moulines and Bach (Moulines and Bach, 2011) explored how PRA on SGD has better convergence guarantees, especially when the initial conditions on the weights are carefully removed from the averages and the learning rate is decayed correctly. However, most of this earlier research focused on a theoretical understanding of weight-averaging methods and lacks practical analysis, especially of their application to highly nonlinear DNN models.

Recently, a similar technique has been applied over the model weight distribution, but mostly to pretrained models (Izmailov et al., 2018). It shows better generalization and reaches wider local minima after sampling. However, when such PRA-based methods are directly applied to train a DNN from scratch, they fail to produce performance that matches the state of the art. Meanwhile, these weight-averaging approaches also increase the computational load, leading to increased training time. Even when we compute the weight averages in an online fashion (which has essentially linear time and space complexity), it adds a computational load proportional to the number of batches. Moreover, as (Izmailov et al., 2018) mention, recalibration of Batch Normalization (BN) layers (Ioffe and Szegedy, 2015) is needed after the reassignment of weights, which is additionally time-consuming.

The rest of the paper is organized as follows. In Sec. 2, we present related work and highlight the problem that all previous works share: the inability to converge to state-of-the-art performance when training a DNN model from scratch. In Sec. 3 we propose three new techniques, PSWA, PWALKS, and PSWM, which build upon the prior works while addressing their major flaws. In Sec. 4 we demonstrate the success of these techniques with an extensive empirical study on various computer vision tasks such as classification, localization, detection, and segmentation, on datasets including Cifar10, ImageNet, ADE20K, Coco, and MPII. In Sec. 5, we compare the loss surfaces of DNN models trained by the different approaches, and quantify the performance and stability improvements our approach provides over the baseline model.

2. Related Work

Given a DNN model with a loss function $L(w; x_i)$ on a training sample $x_i$, the mini-batch SGD method aims to minimize the loss over the training data by updating the model weights $w$ iteratively as:

$$w_{t+1} = w_t - \eta \, \frac{1}{|\mathcal{B}_t|} \sum_{x_i \in \mathcal{B}_t} \nabla_w L(w_t; x_i) \qquad (1)$$

The gradient of the loss over a mini-batch $\mathcal{B}_t$ of training data gives the direction of steepest ascent, and SGD steps against it. The hyperparameter $\eta$ is the learning rate that controls the step size of the update. Most research in this area focuses on the effects of different learning rate schedules, gradient update algorithms, optimal batch sizes, etc., and how improvements in these areas can provide better convergence and add robustness (Duchi et al., 2011; Tieleman and Hinton, 2012; Kingma and Ba, 2015; Smith et al., 2017, as examples). These works mainly modify the update term of Eq. (1). In comparison, the weight averaging approach aims to reassign the final value of the weights as

$$\bar{w} = \frac{1}{n} \sum_{t=1}^{n} w_t \qquad (2)$$

from a sample of $n$ weights $\{w_t\}$ collected after batch updates. Variations of this technique have been studied previously by the research community. Rather than simply discuss their shortcomings, we show empirically the failings of two salient previous works using the well-adopted ResNet18 (He et al., 2016) on the Cifar10 dataset (Krizhevsky, 2009). We used the publicly available implementation (https://github.com/kuangliu/pytorch-cifar/blob/master/main.py, commit 3407511) with SGD updates and the stepwise learning rate decay presented in Sec. 4.

Izmailov et al. (Izmailov et al., 2018) proposed the Stochastic Weight Averaging (SWA) method, which uses PRA over the model weight distribution when retraining pretrained models to achieve flatter minima and better generalization. This technique provides better generalization when finetuning a model. Since we are interested in how this technique can be used to train a model from scratch, we modify it to train our aforementioned neural model from scratch. Following the SWA algorithm, we initialized a running mean for all model parameters and, after training for ‘c’ epochs (‘c’ is a pre-defined hyperparameter), we replaced the model weights with their respective running means. Note that the running mean is initialized only once. It is then kept updated and reassigned every ‘c’ epochs, consistent with the original algorithm. Similarly, we also calibrate the BN layers for both approaches as mentioned in (Izmailov et al., 2018), by performing a forward pass over the training data after each reassignment. We emphasize that the SWA technique we implement for baseline comparisons is a modification of the original implementation, adapted to train models from scratch.
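For concreteness, the sketch below shows how we read this from-scratch SWA baseline (updating the running mean once per epoch and reassigning it every ‘c’ epochs); the helper names and the per-epoch update granularity are our assumptions, not the released SWA code.

```python
# Hedged sketch of the SWA-from-scratch baseline described above: a single
# running mean, initialized once, updated each epoch, and reassigned to the
# model every c epochs (BN recalibration after reassignment is not shown).
import torch

@torch.no_grad()
def swa_from_scratch_step(model, swa_state, epoch, c):
    """Call at the end of every epoch; swa_state = {'mean': None, 'count': 0}."""
    if swa_state["mean"] is None:                        # initialized only once
        swa_state["mean"] = [p.detach().clone() for p in model.parameters()]
        swa_state["count"] = 1
    else:
        swa_state["count"] += 1
        n = swa_state["count"]
        for m, p in zip(swa_state["mean"], model.parameters()):
            m += (p.detach() - m) / n                    # online running mean
    if (epoch + 1) % c == 0:                             # reassign every c epochs
        for p, m in zip(model.parameters(), swa_state["mean"]):
            p.copy_(m)
```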

Lacoste-Julien et al. (Lacoste-Julien et al., 2012) presented an averaging technique for the projected stochastic subgradient method, in which an iteration-based weighted averaging approach to model training and its variations are explored. They presented a theoretical analysis of the technique and discussed the finite variance bound of their approach for SVM models. We investigated two variations of their technique with weighted averaging on ResNet18 on Cifar10, where again we initialize a running mean for each model parameter. In the first approach, ‘BachEpoch’, we update the mean estimation after each epoch, weighted by the epoch index (as we perform weighted averaging), and then reassign the mean values to the model weights. In the second approach, ‘BachBatch’, we update the mean estimation after each batch, weighted by the batch index (with $N$ being the total number of batches in the epoch and $b$ the current batch), and reassign the model weights at the end of the epoch. Hence ‘BachEpoch’ provides a linear weighted averaging approach and ‘BachBatch’ provides an exponential weighted averaging approach. We also apply the recalibration of the BN layers as discussed earlier.
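Read this way, both variants reduce to an online weighted running mean of the model weights, with the weight of a snapshot growing with its (epoch or batch) index; the sketch below is our reading of that scheme, and the exact weighting of (Lacoste-Julien et al., 2012) may differ.

```python
# Hedged sketch of an online weighted running mean where snapshot `step`
# receives weight proportional to `step`, so later snapshots matter more.
# For 'BachEpoch' this would be called once per epoch, for 'BachBatch' once
# per batch; `mean` holds one tensor per model parameter.
import torch

@torch.no_grad()
def weighted_mean_update(mean, model, step, total_weight):
    """Update `mean` in place with the current model weights; returns the new
    accumulated weight mass."""
    total_weight += step
    for m, p in zip(mean, model.parameters()):
        m += (step / total_weight) * (p.detach() - m)
    return total_weight
```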

Figure 1. Performance of previous algorithms (Lacoste-Julien et al., 2012; Izmailov et al., 2018) with SGD.

As evident in Fig. 1, none of the approaches can replicate the accuracy achieved by SGD. The SWA-based variants of (Izmailov et al., 2018) degrade performance and hinder convergence; the performance improvement at higher values of ‘c’ comes from the fewer reassignments the SWA technique then performs. The approach of (Lacoste-Julien et al., 2012), which uses weighted averaging that gives more significance to later epochs in a linear (BachEpoch) or exponential (BachBatch) fashion, also fails to converge optimally. Moreover, all approaches add computational load: maintaining the PRA for each model parameter while training, reassigning the computed values, and recalibrating the BN layers. These costs add up because all three tasks are performed every epoch during training. Hence the techniques of (Izmailov et al., 2018) and (Lacoste-Julien et al., 2012), which work impressively for improving the generalization of pretrained neural networks and for optimizing convex learning models respectively, increase the computational load and training time when translated to DNN training, without providing state-of-the-art convergence.

3. Methods

The analysis of the two methods shows that both SWA and weighted averaging do provide better generalization at the early stages of the training process, typically in the underfitting regime. However, because the mean is biased by the model weights from the early stages of training, it cannot converge properly at the later stages, even when one weights the average in favor of models from later training stages. We address this problem by removing, from the general PRA approach, any dependency on weight distribution estimates from earlier epochs.

We call this technique Periodically Sampled Weight Averaging (PSWA), as we sample the model weights over the batchwise updates and repeat this periodically over epochs. Fig. 2 depicts the application of PSWA to ResNet18 on Cifar10 with SGD (consistent with the prior experiments), showing accuracy and cross-entropy loss on the test dataset. The approach lets the model train normally for one epoch while keeping a running mean for every model parameter over the weight distribution produced by the batchwise gradient updates; at the end of the epoch the running mean is reassigned to the parameter weights and then reinitialized. This additional step allows SGD to gradually converge the model to the optimum through gradient updates, while averaging over the batchwise distribution provides a stabilizing effect on the model.

Figure 2. Performance of PSWA on test data during training.

Another challenge for applying general weight averaging techniques to DNN models is the added computational load, which leads to longer training times. The time complexity of model training with weight averaging typically contains three parts:

$$T = T_{\mathrm{bp}} + T_{\mathrm{recal}} + T_{\mathrm{sample}} \qquad (3)$$

where $T_{\mathrm{bp}}$, $T_{\mathrm{recal}}$, and $T_{\mathrm{sample}}$ denote the time spent on back-propagation, BN recalibration, and weight sampling using the full training dataset. Using plain PSWA for the same number of epochs therefore clearly leads to a longer training time.

To remedy this additional computational load, we improve plain-vanilla PSWA such that we update the running mean for only a small fraction ($p$ percent) of the batches, spread evenly over the randomized training data; similarly, we recalibrate the global mean and variance of each BN layer with $p$ percent of the training data using a fast forward pass. We demonstrate later that by reducing the number of updates in this fashion, the added computational cost becomes negligible.

1: procedure PSWA
2:
3:     Initialize DNN model weights w
4:     Initialize learning rate schedule η
5:     Initialize training data batches B = {B_1, …, B_N}
6:     Initialize total epochs E
7:     Initialize running mean w̄ for all parameters
8:     Determine sampling strategy and fraction p
9:
10:     for e in 1 … E do
11:         randomize(B)
12:         reset(w̄, 0)
13:         for b in 1 … N do
14:             w ← w − η ∇L(w; B_b)          ▷ gradient update, Eq. (1)
15:             update(w̄, w, b, p)            ▷ online running-mean update, Eq. (4)
16:         w.assign(w̄)                       ▷ reassign running mean to model weights
17:         BN recalibration(p)
Algorithm 1 Periodically Sampled Weight Averaging

Algorithm 1 presents the general workflow of the PSWA method for training a DNN. After initializing the model parameters and the data for training, we repeatedly update the model weights by SGD or another gradient-based optimization, and after each sampled batch we update the mean estimation of each weight. The update is carried out in an online fashion; for the case where we use the full training set it is:

$$\bar{w}_n = \bar{w}_{n-1} + \frac{w_n - \bar{w}_{n-1}}{n} \qquad (4)$$

where $w_n$ denotes the model weights after the $n$-th sampled batch and $\bar{w}_n$ the running mean over the first $n$ samples.

To reduce the computational time, we select only $p$ percent of the batches to be used for mean estimation and change the count $n$ correspondingly. Before each epoch we always reset $\bar{w}$ to 0, and after the epoch we reassign the mean weights to the model weights. After reassignment, the BN layers are no longer well suited to the new set of weights, so we recalibrate them using $p$ percent of the training data: a fast forward pass recomputes the global mean and variance statistics of each BN layer.
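A minimal PyTorch sketch of one PSWA epoch under these choices is shown below; `model`, `train_loader`, `optimizer`, and `criterion` are placeholders, and the stride-based selection is just one simple way to sample roughly a fraction p of the batches.

```python
# Hedged sketch of one PSWA epoch (Algorithm 1): gradient updates, an online
# running mean over ~p of the batches (Eq. (4)), reassignment of the mean to
# the model weights at the end of the epoch, and BN recalibration on ~p of
# the training data.
import torch
import torch.nn as nn

def pswa_epoch(model, train_loader, optimizer, criterion, p=0.05, device="cpu"):
    model.train()
    running_mean, n_sampled = None, 0
    stride = max(1, int(round(1.0 / p)))            # sample every `stride`-th batch

    for b, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()                            # ordinary gradient update

        if b % stride == 0:                         # periodic sampling of the weights
            n_sampled += 1
            with torch.no_grad():
                if running_mean is None:
                    running_mean = [q.detach().clone() for q in model.parameters()]
                else:
                    for m, q in zip(running_mean, model.parameters()):
                        m += (q.detach() - m) / n_sampled   # online mean, Eq. (4)

    with torch.no_grad():                           # reassign the mean at epoch end
        for m, q in zip(running_mean, model.parameters()):
            q.copy_(m)

    recalibrate_bn(model, train_loader, fraction=p, device=device)


def recalibrate_bn(model, train_loader, fraction=0.05, device="cpu"):
    """Approximately recompute BN running statistics with forward passes over
    roughly `fraction` of the training data."""
    for mod in model.modules():
        if isinstance(mod, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            mod.reset_running_stats()
    model.train()
    limit = max(1, int(fraction * len(train_loader)))
    with torch.no_grad():
        for b, (x, _) in enumerate(train_loader):
            if b >= limit:
                break
            model(x.to(device))
```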

Figure 3. The PSWA convergence problem on ResNet50.

Although the PSWA method achieves optimal final test accuracy with the shallower ResNet18 and other lightweight models, we found that for deeper networks PSWA still does not converge properly to the optimum. Fig. 3 shows the effect on ResNet50. This problem is pervasive across similar deep networks like Inception and DenseNet, and also on datasets like ImageNet. However, it is important to note that the learning rate schedule decreases by a factor of 10 at epochs 80 and 120, and it is only after the 120th epoch that the SGD method converges to a better result than PSWA.

To investigate why this happens, we analyze PSWA in more detail. In essence, for PSWA we modify the algorithm of (Lacoste-Julien et al., 2012), which works over the entire training process, to run over only one epoch. This approach does not burden the running mean with the weight distribution of earlier epochs while still providing a regularizing effect from PRA over the batchwise weight distribution. It becomes a hindrance, however, once the learning rate has decreased significantly, because the batchwise descent of the SGD loss can then reach a deeper minimum for more complex models. We believe that performing PRA over the SGD walk at this stage causes the regularization to counteract optimal convergence to the minimum in the latter part of the batchwise training. To address this problem we return to the conclusions of (Moulines and Bach, 2011), who show the need to carefully remove from the running mean the initial weights, which bias the mean.

Figure 4. PWALKS comparisons

We solve the suboptimal convergence problem for deeper networks by proposing two different modifications to PSWA. In both approaches we allocate more importance to the model weights during the final batches while still maintaining the regularization afforded by using the weight distribution.

In the first approach, Periodic Weight Averaging over Last K Samples (PWALKS), instead of sampling weights evenly from all batches (for the mean weight distribution), we sample only the last ‘k’% of the samples, where ‘k’ is a hyperparameter ranging between 0 (last batch only, i.e., standard SGD) and 100 (PSWA with full sampling). Empirically a small k value between 2 and 5 provides consistently good performance: it improves over plain SGD during early training (though not as much as PSWA) and consistently converges to the optimum, as demonstrated in Fig. 4. Parallels can be drawn between the PWALKS technique and constructing an ensemble of models over the last few batches, but they are beyond the scope of this paper. To convert PSWA to PWALKS, the running-mean update (line 15 of Algorithm 1) is applied only when

$$b \geq N\left(1 - \frac{k}{100}\right) \qquad (5)$$

where $N$ is the total number of batches in the epoch and $b$ is the current batch index.

The parameter k plays the same role as $p$ in the PSWA method in terms of controlling the computational cost.
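Relative to the PSWA sketch above, the only change in code is the sampling condition; a hedged fragment (with `b` the 0-based batch index and `num_batches` the number of batches per epoch, both as in that sketch) might look like:

```python
# Hedged PWALKS variant of the sampling test in the earlier PSWA sketch:
# the running mean is built only from (roughly) the last k% of the batches.
def in_last_k_percent(b, num_batches, k=2.5):
    """True for roughly the last k% of batches; b is the 0-based batch index."""
    last = max(1, round(num_batches * k / 100.0))   # at least the final batch
    return b >= num_batches - last

# Inside the batch loop, replace `if b % stride == 0:` with:
#     if in_last_k_percent(b, len(train_loader), k):
#         ... update the running mean as before ...
```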

Figure 5. PSWM comparisons

A second approach to solving the PSWA convergence problem is to view it from the perspective of a cumulative adjustment to the weights. We propose a momentum-based modification of PSWA called Periodically Sampled Weight Momentum (PSWM), where instead of keeping a running mean we keep running weights updated using momentum. For each of the model’s parameters we keep a running momentum term, which we update at the end of each batch and reassign at the end of the epoch. Empirically, momentum values between 0.5 and 0.9 yield good performance, with m = 0 corresponding to standard SGD. To convert PSWA to PSWM, the running-mean update (line 15 of Algorithm 1) is changed to:

$$\bar{w} \leftarrow m\,\bar{w} + (1 - m)\,w_b \qquad (6)$$

where $w_b$ denotes the model weights after batch $b$ and $m$ is the momentum hyperparameter.

Since PSWM is built on PSWA, the sampling technique developed for PSWA can also be applied to reduce the time complexity of PSWM.
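A hedged sketch of this running update is shown below; the exponential form of Eq. (6) is our reading of the description, and `running` holds one tensor per model parameter, reassigned to the weights at the end of each epoch exactly as in PSWA.

```python
# PSWM: keep running weights updated with momentum after each (sampled) batch;
# at the end of the epoch, `running` is copied back into the model weights.
import torch

@torch.no_grad()
def pswm_update(running, model, m=0.9):
    for r, p in zip(running, model.parameters()):
        r.mul_(m).add_(p.detach(), alpha=1.0 - m)   # r <- m * r + (1 - m) * w
```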

Figure 6. Comparison of computation time for plain-vanilla SGD, PSWA with full-dataset sampling, and PWALKS.

We next compare the computational performance of plain-vanilla SGD, PSWA using the full training data, and PWALKS with small sampling fractions. The code is based on the fastest Cifar10 training code listed in the DAWN project (Coleman et al., 2017); the original implementation (https://github.com/davidcpage/cifar10-fast, commit d31ad8d) is changed from half precision to full precision. We repeated the training process 10 times for each technique and report the corresponding means and standard deviations. Fig. 6 shows that PSWA leads to a 34% overhead when using the full training dataset for the weight update and the recalibration of the BN layers; by adopting PWALKS, we achieve the same prediction accuracy on the test dataset without significantly sacrificing speed, even without code-level optimizations. In addition, we also observe that the variation of the training process is much smaller when the weight averaging techniques are applied, which we discuss in Sec. 5.

4. Experiments and Results

We demonstrate that our techniques can be effectively applied to various computer vision tasks such as classification, detection, and segmentation, across different convolutional neural network architectures such as ResNet18, ResNet50, DenseNet121, Inception, and MobileNet, trained using a variety of optimization techniques such as SGD, SGD with momentum, and Adam, and over diverse learning rate schedules. Our techniques consistently provide increased stability and improved intermediate performance while converging optimally over a broad spectrum of hyperparameter values. We compare our approach against SGD-based approaches, since existing weight averaging techniques such as (Izmailov et al., 2018) and (Lacoste-Julien et al., 2012) do not provide state-of-the-art performance and are, in addition, computationally more expensive.

4.1. Dataset: Cifar-10

We have already discussed in detail the results of our techniques on ResNet18 and ResNet50 over Cifar10. ResNet18 is trained for 150 epochs and ResNet50 for 180 epochs. Both use SGD with momentum of 0.9 and an L2 penalty of 0.0005, and have a learning rate schedule that decreases by a factor of 10 at epochs 80, 120, and 150. We use standard cross-entropy loss and a batch size of 128.
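For reference, this configuration corresponds roughly to the following PyTorch setup; the torchvision ResNet18 here is only a stand-in for the Cifar-specific architectures of the referenced implementation, and the initial learning rate of 0.1 is that implementation's default rather than a value stated above.

```python
# Sketch of the Cifar10 training configuration described above (assumptions:
# torchvision model as a stand-in, initial learning rate 0.1).
import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=5e-4)            # L2 penalty of 0.0005
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[80, 120, 150], gamma=0.1)      # decay by 10x
# Batch size 128; 150 epochs for ResNet18, 180 for ResNet50; call
# scheduler.step() once per epoch after the PSWA/PWALKS/PSWM epoch routine.
```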

Figure 7. MobileNet trained on Cifar10.
Figure 8. Inception trained on Cifar10.

In our experiments on shallower networks like MobileNet-v2 (Sandler et al., 2018) and ResNet18, we find that PSWA not only provides faster and more robust convergence, but also converges to a better minimum, as evident in Fig. 7. Another interesting comparison on shallow networks is between PWALKS and PSWM on the one hand and PSWA on the other: PSWA converges to a deeper minimum that PWALKS and PSWM are unable to reach. However, for deeper networks like Inception (Szegedy et al., 2016), DenseNet-121 (Huang et al., 2017), and ResNet50, as discussed before, PSWA does not converge properly, while both PWALKS and PSWM do. Fig. 8 shows the Inception network trained using the same implementation as above. We observe that PSWA and its variants consistently reach the 90% and 94% accuracy thresholds much faster, while still training at the larger learning rate, whereas SGD needs a learning rate decrease by a factor of 10 to cross these thresholds.

4.2. Optimizer: Adam

Figure 9. ResNet50 with Cifar10 using Adam.

Until now we have only addressed SGD (with momentum)-based optimization methods. To show that our techniques can be used effectively on adaptive optimizers as well, we present experiments on Adam (Kingma and Ba, 2014), which performs first-order gradient-based optimization of stochastic objective functions based on adaptive estimates of lower-order moments.

Fig. 9 shows the results of ResNet50 on Cifar10, consistent with the prior implementations except that we use Adam instead of SGD, with a starting learning rate of 0.001. As we can see, PSWA, PWALKS, and PSWM all offer a marginal but consistent improvement over Adam, across epochs and over multiple runs. The improvement is not as significant and dramatic as with SGD, because Adam itself alleviates common problems of SGD such as large fluctuations and slow convergence. Since Adam modulates the learning rate of each weight based on the magnitudes of its gradients, instead of using the complete raw and noisy gradient vector, the spread of the parameter weights over batchwise updates remains small compared to SGD.

Figure 10. ResNet50 on Cifar10 with Adam and high learning rate.

Adam and other adaptive optimizers suffer from some important documented problems. Though Adam converges faster, it does not generalize as well (Keskar and Socher, 2017). In our experiments, PSWA over Adam also provided reduced cross-entropy loss during training. Another problem of adaptive optimizers like RMSProp and Adam is that they can become unstable at a high learning rate near convergence. This happens because the current gradient is divided by the (square root of the) rolling mean of squared gradients, so very small gradients can introduce instability. In Fig. 10 we present such a scenario, where we use a learning rate of 0.01 (instead of 0.001), which causes Adam to become unstable. However, Adam with PSWA remains stable and converges better.

4.3. Dataset: ImageNet

ImageNet (Deng et al., 2009) is another standard image classification dataset; it has 1.2 million high-resolution images from 1000 classes. Our implementation uses ResNet50 as the underlying network and SGD with momentum as the optimizer. We use a learning rate of 0.1, which decays by a factor of 0.1 every 30 epochs, for a total of 150 epochs.
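This schedule corresponds roughly to the following sketch (batch size and weight decay are taken from Appendix A; the torchvision ResNet50 stands in for the implementation referenced there):

```python
# Sketch of the ImageNet training configuration described above.
import torch
from torchvision.models import resnet50

model = resnet50()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=1e-4)             # L2 penalty of 0.0001
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# Batch size 32, 150 epochs; scheduler.step() is called once per epoch.
```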

Figure 11. ResNet50 on ImageNet with PWALKS.
Figure 12. ResNet50 on ImageNet with PSWM.

Our results on ImageNet follow similar trends to those on Cifar10, as presented in Fig. 11 and Fig. 12 for PWALKS and PSWM respectively. As on Cifar10, PSWA converges much faster during early training but does not converge optimally. PWALKS generalizes better than PSWM without compromising final convergence, while both techniques provide an improvement over SGD. An important threshold for ImageNet classification is 90% top-5 classification accuracy. In our experiments SGD needs a learning rate of 0.001 to reach this threshold, while PSWA and its variants consistently cross it at a learning rate of 0.01 and in fewer epochs.

4.4. Task: Human-Pose Detection

Figure 13. COCO keypoint detection with ResNet50 and Adam. (An aberrant drop in PSWA accuracy in Fig. 14 appears to result from a biased subset of data points during BN recalibration.)
Figure 14. MPII human-pose detection with ResNet50 and Adam.

We apply our techniques to the work of (Xiao et al., 2018b), which performs human keypoint detection on MS-COCO (Lin et al., 2014) and human pose detection on the MPII dataset (Andriluka et al., 2014). Both tasks use ResNet50 pretrained on ImageNet and perform transfer learning on the new dataset. Both experiments use Adam as the optimizer with a learning rate of 0.001. Consistent with our prior experiments, PWALKS and PSWM provide a consistent improvement over Adam in the early stages of training.

4.5. Task: Segmentation

We also apply our techniques to the work of (Zhou et al., 2017), which performs scene segmentation on the MIT ADE20K dataset (Zhou et al., 2016), the largest open-source dataset for semantic segmentation and scene parsing. The implementation uses an encoder-decoder architecture with ResNet50 pretrained on ImageNet as the encoder and a Pyramid Pooling Module with bilinear upsampling as the decoder, with deep supervision (Xiao et al., 2018a). The implementation uses a per-pixel cross-entropy loss, SGD as the optimizer, and a ‘poly’ learning rate policy.

Figure 15. Pixel accuracy of segmentation on ADE20K.
Figure 16. Mean IOU of segmentation on ADE20K.

For our implementation we initialize two running-mean distributions, one each for the encoder and the decoder. We update both distributions together and reassign them at the end of the epoch. We do not need to recalibrate the BN layers, since the implementation uses Synchronized Batch Normalization (Peng et al., 2017). Fig. 15 shows the pixel-wise accuracy of the segmentation models on the test set and Fig. 16 shows the mean IoU of the predicted segmentation on the test data; PSWA provides a significant improvement over SGD-based training.

5. Analysis

5.1. Loss Surface

To better understand the results of our approach, we investigate the effect of PSWA on the loss surface of the model during training, compared to SGD. Training neural networks requires minimizing a high-dimensional nonconvex loss function, with deeper minima correlating with better performance. An important characteristic of a minimum is its ‘flatness’, the size of the connected region around the minimum where the training loss remains low. There exist strong claims that “flat” minima generalize better, while increased sharpness of a minimum could indicate poor generalization (Hochreiter and Schmidhuber, 1997; Kawaguchi, 2016). (Izmailov et al., 2018) show that their technique, SWA (based on averaging multiple points along the trajectory of SGD), leads to solutions corresponding to wider optima than SGD. We show that similar conclusions can be drawn for PSWA.

Figure 17. Comparison of the loss surfaces of models trained by SGD and by SGD with PSWA at different training stages: epoch 20 (early training), epoch 100 (near convergence), and epoch 150 (at convergence).

We follow the approach of (Li et al., 2017), which calculates and visualizes the loss surface along random direction(s) in the weight space. They use a novel “filter normalization” scheme that enables side-by-side comparisons of different minima and addresses problems with 1-dimensional linear interpolation. Filter normalization makes the visualization scale-invariant; otherwise, perturbing the weights by one unit would have very little effect on network performance if the weights live on a scale much larger than one.
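A hedged sketch of such a 1-D filter-normalized loss scan is given below; the helper names are ours, and the exact normalization in the reference implementation of (Li et al., 2017) may differ in detail.

```python
# Sketch of a filter-normalized random direction and a 1-D loss-surface scan:
# evaluate the loss at w + alpha * d for a range of displacements alpha.
import torch

@torch.no_grad()
def filter_normalized_direction(model):
    """Random Gaussian direction, rescaled filter-wise to match weight norms."""
    direction = []
    for p in model.parameters():
        d = torch.randn_like(p)
        if p.dim() > 1:                      # conv/linear weights: per-filter scaling
            for dd, pp in zip(d, p):
                dd.mul_(pp.norm() / (dd.norm() + 1e-10))
        else:                                # biases, BN parameters: whole-tensor scaling
            d.mul_(p.norm() / (d.norm() + 1e-10))
        direction.append(d)
    return direction

@torch.no_grad()
def loss_along_direction(model, direction, loader, criterion, alphas, device="cpu"):
    """Average loss of the perturbed model at each displacement in `alphas`."""
    model.eval()
    base = [p.detach().clone() for p in model.parameters()]
    losses = []
    for alpha in alphas:
        for p, w, d in zip(model.parameters(), base, direction):
            p.copy_(w + alpha * d)
        total, count = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total += criterion(model(x), y).item() * x.size(0)
            count += x.size(0)
        losses.append(total / count)
    for p, w in zip(model.parameters(), base):   # restore the original weights
        p.copy_(w)
    return losses
```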

Fig. 17 presents a loss-surface comparison at different stages of training (early training, near convergence, and at convergence) for ResNet18 on Cifar10 trained by SGD and by SGD with PSWA (consistent with Sec. 4.1). The horizontal axis represents the displacement along the random Gaussian direction vector; the red lines indicate accuracy and the blue lines the loss values; the dashed lines represent values on the test dataset while the solid lines represent the training set. As is clearly evident, the model trained with PSWA has a much flatter and deeper minimum, for both the training and the test set, at the early training stage. The trend continues near convergence and at convergence, though it becomes less pronounced. We can see that, with similar test and train accuracies, PSWA still retains the wider minimum.

Figure 18. Loss surface of model weights before and after applying PSWA on one epoch, during the early training stage.

Fig. 18 presents a different view of the loss surface at an early training stage (epoch 50), before and after reassigning the model weights. The PSWA-based model is located at index 0 on the horizontal axis and the SGD model at index 1, while points between them represent displacements along the “filter normalized” direction between the two sets of weights (since we use the same model). We notice a steady improvement in performance in the direction of the weights obtained after PSWA is applied.

5.2. Performance Statistics

While it is clear from Sec. 4 that PSWA and its variants converge more consistently and robustly than the baseline models, it is non-trivial to quantify this improvement, as both the accuracy and the loss form non-stationary and volatile time series. We analyze the test accuracy distribution of ResNet50 on ImageNet and find that all three of our techniques provide a consistently and significantly lower standard deviation (SD), both at convergence and over the saturation phases at constant learning rates, as shown in Tab. 1. The lower variance in the early saturation phase (epochs 20-30) points to a less volatile training process, and the lower variance over the final 20 epochs points to a more stable convergence. Moreover, PWALKS and PSWM both converge optimally compared to SGD across a wide range of hyperparameters, which addresses the major challenges of the prior research.

Method          SD, last 20 epochs    SD, epochs 20-30    Mean accuracy, last 20 epochs
SGD             0.080                 1.06                76.16
PSWA            0.058                 0.39                74.77
PWALKS k=5      0.045                 0.31                76.35
PWALKS k=2.5    0.047                 0.21                76.50
PSWM m=0.9      0.047                 0.39                75.97
PSWM m=0.5      0.056                 0.24                76.16
Table 1. Volatility analysis (test accuracy, %) when training ResNet50 on ImageNet.

Another important observation is that PSWA’s performance on the test set monotonically increases or remains stable over epochs until convergence. This is especially important since it indicates that, with high probability, the model is consistently improving and the performance does not fluctuate sporadically like the baseline model. Again analyzing the test accuracy distribution of ResNet50 on ImageNet, we find that with PSWA the accuracy improves for almost 70% of consecutive epochs, and 95% of them are stable to within a 0.2 percentage-point decrement, unlike SGD-based training, which only shows 57% and 77% respectively (Tab. 2). We also compare the performance of the current epoch against the best performance over all previous epochs and again report stability within a 0.2 percentage-point decrement. As we can see, PSWA improves upon the best previous performance or remains stable for 99% of the epochs.
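The metrics in Tab. 2 can be computed directly from the per-epoch accuracy curve; a hedged sketch (assuming `acc` is the list of test accuracies, in percent, of one run) is:

```python
# Hedged sketch of the stability metrics reported in Tab. 2; the exact
# bookkeeping used for the paper's numbers may differ in detail.
def stability_metrics(acc, tol=0.2):
    """Fractions of epoch-to-epoch transitions that (a) do not decrease,
    (b) stay within `tol` of the previous epoch, and (c) stay within `tol`
    of the best accuracy seen so far."""
    improved = sum(b >= a for a, b in zip(acc, acc[1:]))
    stable = sum(b >= a - tol for a, b in zip(acc, acc[1:]))
    best_so_far = [max(acc[:i + 1]) for i in range(len(acc) - 1)]
    stable_vs_best = sum(b >= m - tol for m, b in zip(best_so_far, acc[1:]))
    n = len(acc) - 1
    return improved / n, stable / n, stable_vs_best / n
```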

Method          Consecutive epochs    Consecutive epochs        Epochs with stable improvement
                with improvement      with stable improvement   over all previous epochs
SGD             57%                   77%                       77%
PSWA            70%                   95%                       99%
PWALKS k=5      66%                   93%                       95%
PWALKS k=2      65%                   89%                       97%
PSWM m=0.9      60%                   85%                       91%
PSWM m=0.5      65%                   85%                       90%
Table 2. Monotonic improvement demonstrating stability when training ResNet50 on ImageNet.

We emphasize that in all the examples presented in Sec. 4, the optimizer, the learning rate (and its schedule), and the training hyperparameters have all been finetuned for convergence in the original implementations, and we do not modify them when training with our techniques, to ensure fair comparisons. For most machine learning applications, however, these optimal hyperparameters and learning rate schedules are not known in advance, which can introduce volatility in the training process and uncertainty about the final convergence. In both scenarios our techniques can provide improvements in stability and performance. Moreover, we may be constrained by computational resources or training time, requiring us to train for a short time and at high learning rates. In all such scenarios our techniques provide consistently better generalization and higher expected accuracy at intermediate and early epochs, as presented in Tab. 3.

Method          Average improvement over SGD    Average improvement over SGD
                accuracy, all epochs            accuracy, first 30 epochs
PSWA            2.75                            12.52
PWALKS k=5      3.71                            13.31
PWALKS k=2.5    3.50                            12.37
PSWM m=0.9      1.05                            4.68
PSWM m=0.5      2.01                            7.81
Table 3. Intermediate accuracy improvement when training ResNet50 on ImageNet.

6. Conclusion

Deep learning has become the de facto tool for a number of machine learning problems, but these networks need to be carefully fine-tuned, and there exists no optimal training regime that works over the wide variety of available datasets and learning methods. In this paper, we introduced a trio of techniques (PSWA, PWALKS, and PSWM) based on sampling over model weights that solve the issues of previous weight averaging approaches and provide stable and more robust convergence for many problems and across different gradient update techniques, while remaining straightforward to implement.

While PSWA converges efficiently and excels when applied to smaller networks, PWALKS and PSWM work across deeper and more complex networks and converge optimally while still providing the same improvements as PSWA. In light of the advantages offered by these techniques, they provide a good starting point for training DNNs, especially in cases where no optimal training regime is known. It is important to note that these measures are empirical, and there remains work to be done in showing theoretical convergence guarantees. Future research can also explore schedules for faster convergence, as PSWA performs better at high learning rates (which offers a training speed-up) and plateaus in a shorter time. This ability can be especially useful when training on large datasets with constrained computational resources. Moreover, since PSWA is fairly independent of the training process, it could essentially be ‘turned off’ after exploiting its fast convergence during the early stages of training.

Also, while both PWALKS and PSWM offer optimal convergence, they depend on the hyperparameters ‘k’ and ‘m’ respectively. Empirically we find that both techniques converge to within a small margin of each other across a wide range of their hyperparameter values; however, PWALKS does offer minor improvements over PSWM during the early stages of training. An avenue for future research could involve exploring PWALKS and PSWM with adaptive and dynamic values for ‘k’ and ‘m’, as well as using them in conjunction.

References

Appendix A Reproducibility

We provide hyperparameter values and additional implementation details for the experiments in Sec. 4.

For the experiment in Sec. 4.1 we adopted the implementation at https://github.com/kuangliu/pytorch-cifar/blob/master/main.py, with the exception of the custom learning rate schedule of Sec. 4.1, as the original was sub-optimal. Data augmentation on the training set was performed using random crop (padding 4) and horizontal flip, while both train and test sets were normalized. The dataloaders perform random shuffling of data batches, with 2 concurrent workers each for the test and train data queues. The code runs on PyTorch 1.0, Python 3.6, CUDA 9.0 with cuDNN. We use 1 Tesla V100 with 16 GB GPU memory and 8 vCPU Intel Skylake.

For Sec. 4.3 we used the implementation at https://github.com/pytorch/examples/tree/master/imagenet. We used a batch size of 32, an L2 penalty of 0.0001, momentum of 0.9, and standard data augmentation (cropping, horizontal flipping, and input normalization). The code runs on PyTorch 1.0, Python 3.6, CUDA 9.0 with cuDNN. We use 4 Tesla V100 with 64 GB GPU memory and 32 vCPU Intel Skylake.

For Sec. 4.4 we adopted the implementation at https://github.com/Microsoft/human-pose-estimation.pytorch. The code runs on PyTorch 1.0, Python 3.6, CUDA 9.0 with cuDNN. We use 4 Tesla V100 with 64 GB GPU memory and 32 vCPU Intel Skylake.

For Sec. 4.5 we used https://github.com/CSAILVision/semantic-segmentation-pytorch for implementation details. The code runs on PyTorch 1.0, Python 3.6, CUDA 9.0 with cuDNN. We use 4 Tesla V100 with 64 GB GPU memory and 32 vCPU Intel Skylake.

Appendix B Random Sampling

Another simple use of the batchwise model weight distribution is to randomly sample parameter configurations and perform fast forward passes on the training set to iteratively choose the configuration that best optimizes the objective. We analyze this scenario in comparison to standard SGD and PSWA on our Cifar experiments. We sample ‘s’ (s=10) configurations from the model weight distribution, along with the standard (post-SGD) configuration, after each epoch. We then perform fast forward passes on a random subset of the training dataset (10%) and reassign the model weights to the configuration that yields the lowest cross-entropy loss. We use the same training data subset for all configurations for an unbiased comparison. We present the results in Fig. 19. As we can see, the approach does provide an improvement over SGD performance and matches PSWA for most of the training process; however, it fails to converge optimally and yields a suboptimal final performance compared to both SGD and PSWA.
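A hedged sketch of this selection step is given below; `snapshots` stands for the candidate weight configurations drawn from the epoch’s batchwise weight distribution (plus the post-SGD weights), and `subset_loader` for the fixed ~10% training subset, both placeholders.

```python
# Sketch of the random-sampling baseline: assign to the model the candidate
# configuration with the lowest cross-entropy loss on a fixed data subset.
import torch

@torch.no_grad()
def pick_best_configuration(model, snapshots, subset_loader, criterion, device="cpu"):
    best_loss, best = float("inf"), None
    for weights in snapshots:
        for p, w in zip(model.parameters(), weights):
            p.copy_(w)                                  # load candidate weights
        total, count = 0.0, 0
        for x, y in subset_loader:
            x, y = x.to(device), y.to(device)
            total += criterion(model(x), y).item() * x.size(0)
            count += x.size(0)
        loss = total / count
        if loss < best_loss:
            best_loss, best = loss, weights
    for p, w in zip(model.parameters(), best):          # keep the best candidate
        p.copy_(w)
    return best_loss
```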

Figure 19. Random sampling on the ResNet50 convergence problem.