Recently, there has been a surge of interest in acquiring a theoretical understanding of the behavior of deep neural networks. Breakthroughs have been made in characterizing the optimization process, showing that learning algorithms such as stochastic gradient descent (SGD) tend to end up in one of many local minima with close-to-zero training loss (choromanska2015loss; dauphin2014identifying; kawaguchi2016deep; nguyen2018optimization; du2018gradient). However, these numerically similar local minima typically exhibit very different generalization behavior. It is therefore natural to ask two closely related questions: (a) What kind of local minima generalize better? (b) How can we find those better local minima?
To our knowledge, existing work has focused on only one of the two questions. For the “what” question, various definitions of “flatness/sharpness” have been introduced and analyzed (keskar2016large; neyshabur2017pac; neyshabur2017exploring; wu2017towards; liang2017fisher). However, they suffer from one or more of the following problems: (1) they are mostly theoretical, with no or poor empirical evaluation on modern neural networks; (2) they lack theoretical analysis and understanding; (3) they are not applicable in practice to finding better local minima. Regarding the “how” question, existing approaches (hochreiter1997flat; sokolic2017robust; chaudhari2016entropy; hoffer2017train; neyshabur2015path; izmailov2018averaging) share some common drawbacks: (1) they are derived only from intuition, with no specific metric provided to characterize local minima; (2) there is no or only weak analysis of such metrics; (3) they are not applicable to, or yield no consistent generalization improvement for, modern DNNs.
In this paper, we tackle both the “what” and the “how” questions in a unified manner. Our answer provides both the theory and applications for the generalization problems across different local minima. Based on the determinant of the Fisher information estimated from the training set, we propose a metric that solves all the aforementioned issues. The metric captures well the properties that characterize local minima of different generalization ability. We provide its theoretical analysis, primarily a generalization bound based on PAC-Bayes (mcallester1999some; mcallester1999pac). For modern DNNs in practice, a tractable approximation of our metric is necessary, so we propose an intuitive and efficient approximation for comparing it across different local minima. Our empirical evaluations illustrate the effectiveness of the metric as a strong indicator of a local minimum’s generalizability. Moreover, from the metric we further derive a practical regularization technique that guides the optimization process toward better-generalizing local minima. Experiments on image classification datasets demonstrate that our approach gives a consistent generalization boost for a range of DNN architectures.
2 Related Work
It has been empirically shown that larger batch sizes lead to worse generalization (keskar2016large). hoffer2017train analyzed how the training dynamics are affected by different batch sizes and presented a perturbed batch normalization technique for better generalization. While it effectively improves generalization for large-batch training, a specific metric that indicates generalizability is missing. Similarly, elsayed2018large employed a structured margin loss to improve the performance of DNNs under noise and adversarial attacks, yet no metric was proposed. Furthermore, this approach provided essentially no generalization gain in the normal training setup.
The local entropy of the loss landscape was proposed as a measure of “flatness” in chaudhari2016entropy, which also designed an entropy-guided SGD that achieves faster convergence when training DNNs. However, the method does not consistently improve generalization, e.g., it decreases performance on CIFAR-10 (krizhevsky2009learning). Another method that focuses on modifying the optimization process is Path-SGD (neyshabur2015path). Specifically, the authors derived an approximate steepest descent algorithm that utilizes path-wise norm regularization to achieve better generalization. They only evaluated it on a two-layer neural network, likely because the path norm is computationally expensive to optimize during training.
A flat minimum search algorithm was proposed by hochreiter1997flat based on a notion of “flatness” of local minima defined as the volume of local boxes. Yet since the boxes have their axes aligned with the axes of the model parameters, their volumes can significantly underestimate “flatness” for over-parametrized networks, due to the specific spectral density of the Hessian of DNNs studied in pennington2018spectrum; sagun2017empirical. The authors of wu2017towards also characterized “flatness” by volume. They considered the inverse volume of the basin of attraction and proposed the Frobenius norm of the Hessian at the local minimum as a metric. In our experiments, we show that their metric does not accurately capture the generalization ability of local minima under different scenarios. Moreover, they did not derive a regularizer from their metric.
Based on a “robustness” metric, sokolic2017robust derived a regularization technique that successfully improves generalization on multiple image classification datasets. Nevertheless, we show that their metric fails to capture the generalizability across different local minima.
Using the Bayes factor, mackay1992practical studied the generalization ability of different local minima obtained by varying the coefficient of L2 regularization, and derived a formula involving the determinant of the Hessian, similar to ours. However, that approach applies only in restricted settings and, without an efficient approximation, its metric is not applicable to modern DNNs, let alone usable as a regularizer. A generalization bound is also missing in mackay1992practical.
In a broader context of the “what” question, properties that capture the generalization of neural networks have been extensively studied. Various complexity measures for DNNs have been proposed based on norm, margin, Lipschitz constant, compression and robustness (bartlett2002rademacher; neyshabur2015norm; sokolic2017robust; xu2012robustness; bartlett2017spectrally; zhou2018nonvacuous; dziugaite2017computing; arora2018stronger; jiang2018predicting). While some of them aimed to provide tight generalization bounds and some of them to provide better empirical results, none of the above approaches explored the “how” question at the same time.
Very recently, karakida2019universal and sun2019lightlike studied the Fisher information of neural networks through the lens of its spectral density. Specifically, karakida2019universal applied mean field theory to study the statistics of the spectrum and the appropriate size of the learning rate. Also taking an information-theoretic approach, sun2019lightlike derived a novel formulation of the minimum description length in the context of deep learning using tools from singular semi-Riemannian geometry.
3 Outline and Notations
In a typical -way classification setting, each sample belongs to a single class denoted , where is the k-dimensional probability simplex so that and . Denote a feed-forward DNN parametrized by as . Denote the training set as , defined over with . The training objective is given as . Assuming is sampled from some true data distribution denoted , we can define the expected loss . Throughout this paper, we refer to a local minimum of corresponding to a local minimizer simply as the local minimum . Given such , our paper’s outline as well as our main contributions are:
3.1 Other Notations
Denote as the gradient, as the Jacobian matrix, as the Hessian, as the KL divergence, as the spectral or Euclidean norm, as the Frobenius norm, as the determinant, as the trace norm, as the spectral radius, as the log-likelihood on , and for selecting the entry.
We define whose entry is so that . We define as and the one-hot vector whose -th entry is 1 and otherwise 0. Then we define as the “simplified” loss vector of whose entries are for , i.e., we approximate the cross entropy loss by .
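Since the exact symbols above are elided, here is a small numerical illustration (with hypothetical numbers) of the stated approximation: when the target is label-smoothed and therefore nearly one-hot, the full cross entropy is close to the negative log-probability of the labeled class alone.

```python
import numpy as np

K, alpha = 5, 0.02
y = (1 - alpha) * np.eye(K)[1] + alpha / K    # label-smoothed target for class 1
f = np.array([0.03, 0.90, 0.03, 0.02, 0.02])  # model output on the simplex
ce = -np.sum(y * np.log(f))                   # full cross entropy
approx = -np.log(f[1])                        # simplified: -log prob of the labeled class
```

For a nearly one-hot target, the two values agree up to a small remainder contributed by the off-class mass.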
4 Local Minimum and Fisher Information
First of all, if is strictly one-hot, no local minimum with 100% training accuracy exists, since the cross entropy loss is always positive. To admit good local minima in the first place, we assume that the widely used label smoothing (szegedy2016rethinking) is applied when training all models in our analysis. Label smoothing enables us to assume a local minimum (in this case, also a global minimum) of the training loss with .
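Label smoothing as in szegedy2016rethinking mixes the one-hot target with the uniform distribution over classes, so the optimal softmax output lies in the interior of the simplex. A minimal sketch:

```python
import numpy as np

K, alpha = 10, 0.1                           # number of classes, smoothing strength
y_hard = np.eye(K)[3]                        # strictly one-hot target for class 3
y_smooth = (1 - alpha) * y_hard + alpha / K  # label-smoothed target: all entries positive
```

The smoothed target still sums to one, but every class receives positive probability, so the cross entropy attains a finite minimum.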
Each sample has its label sampled by , denoted as . The joint probability modeled by the DNN is with . We can relate the training loss to the negative log-likelihood by:
Also, corresponds to a local maximum of the likelihood function. The observed Fisher information (efron1978assessing) evaluated at is defined as , i.e.
Remark: When we assume global optimality, we have as ; yet it does not indicate in Equation 2.
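The defining equality is elided above, but the idea can be sanity-checked in one dimension: for a Bernoulli model at its MLE, the negative second derivative of the log-likelihood (the observed Fisher information) coincides with the sum of squared per-sample scores. A small check, assuming this standard definition:

```python
import numpy as np

x = np.array([1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=float)
n, theta = len(x), x.mean()                      # theta-hat = 0.7 is the MLE

def loglik(t):
    """Bernoulli log-likelihood of the sample at parameter t."""
    return np.sum(x * np.log(t) + (1 - x) * np.log(1 - t))

# observed Fisher: negative second derivative at the MLE (central difference)
h = 1e-4
obs_fisher = -(loglik(theta + h) - 2 * loglik(theta) + loglik(theta - h)) / h**2

# equivalent outer-product (score) form at the MLE
scores = (x - theta) / (theta * (1 - theta))     # per-sample score
outer_form = np.sum(scores**2)
```

Both quantities equal n / (theta * (1 - theta)), illustrating the two equivalent views of the observed Fisher information at an optimum.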
5 Local Minima Characterization
In this section, we derive and propose our metric, provide a PAC-Bayes generalization bound, and lastly, propose and give intuitions of an effective approximation of our metric for modern DNNs.
5.1 Fisher Determinant as Generalization Metric
We would like a metric to compare different local minima. Similar to the various definitions of “flatness/sharpness”, we take a small neighborhood of the target local minimum into account. Formally for a sufficiently small , we define the model class as the largest connected subset of that contains , where the height is defined as a real number such that the volume (namely the Lebesgue measure) of is . By the Intermediate Value Theorem, for any sufficiently small volume there exists a corresponding height .
We propose our metric , where lower indicates a better local minimum :
As a metric, requires . Therefore, we state the following Assumption 1.
The local minima we care about in the comparison are well isolated and unique in their corresponding neighborhood .
Assumption 1 is quite reasonable; for the state-of-the-art network architectures used in practice, it typically holds. To be precise, Assumption 1 is violated when the Hessian at a local minimum is singular. Specifically, orhan2018skip summarizes three sources of singularity: (i) dead neurons, (ii) identical neurons, and (iii) linear dependence among neurons. As demonstrated in orhan2018skip, networks with skip connections, e.g., ResNet (he2016deep), WRN (wideresnet), and DenseNet (huang2017densely) used in our experiments, can effectively eliminate all of the aforementioned singularities.
Prior work has pointed out another source of singularity specific to networks with scale-invariant activation functions, e.g., ReLU, referred to as the rescaling issue: one can rescale the model parameters layer-wise so that the underlying function represented by the network remains unchanged. In practice, this issue is not critical. Firstly, most modern deep ReLU networks, e.g., ResNet, WRN, and DenseNet, have normalization layers, e.g., BatchNorm (ioffe2015batch), applied before the activations. BatchNorm shifts all the inputs to the ReLU function, equivalently shifting the ReLU horizontally, which makes it no longer scale-invariant. Secondly, due to the ubiquitous use of Gaussian weight initialization and weight decay, most local minima obtained by gradient-based learning have weights of relatively small norm. Consequently, in practice we will not compare two local minima that are essentially the same but with one being a rescaled version of the other with a much larger weight norm.
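The rescaling issue can be reproduced in a few lines for a plain ReLU network without biases or normalization (a minimal sketch; the sizes and the constant are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(4)
relu = lambda z: np.maximum(z, 0)

W1 = rng.normal(size=(6, 4))   # first layer weights
W2 = rng.normal(size=(3, 6))   # second layer weights
x = rng.normal(size=4)
c = 10.0                       # positive rescaling factor

out = W2 @ relu(W1 @ x)
# scale layer 1 up by c and layer 2 down by c: relu(c*z) = c*relu(z) for c > 0,
# so the function is unchanged even though the weights are very different
out_rescaled = (W2 / c) @ relu((c * W1) @ x)
```

The two parameter settings represent the same function, yet their Hessians and Fisher matrices differ, which is exactly the singularity direction the rescaling issue describes.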
Note that we normally have a dataset of limited size, so an approximation of is a must. We present our approximation scheme and its intuition in Section 5.3.
5.1.1 Connection to Fisher Information Approximation (FIA) Criterion
Our metric is closely related to the FIA criterion. In information theory, the MDL principle suggests that among different statistical models, the best is the one that best compresses both the sampled data and the model itself (rissanen1978modeling). Accordingly, rissanen1996fisher derived the FIA criterion to compare statistical models, each of which is a class of models in the neighborhood of a global minimum . The model class’s FIA criterion is written as (lower FIA is better):
On the right hand side, the first two terms are both constants. To see the connection to our metric, we replace the expected Fisher information with the tractable observed one . Assuming the training loss is locally quadratic in , an assumption later formalized and validated as Assumption 2, since from Equation 1, we can modify the last term to be .
Remark: Although in a similar format, the FIA criterion and our metric are essentially different due to the appearance of observed Fisher information in place of the expected one, making our metric both tractable and much more applicable (no longer requires global optimality).
5.1.2 Connection to Existing Flatness/Sharpness Metrics
As mentioned in Section 2, the “flatness” of a local minimum was first related to the generalization ability of neural networks in hochreiter1997flat, where both the concept and the method are preliminary. The idea has recently been popularized in the context of deep learning by a series of papers such as keskar2016large; chaudhari2016entropy; wu2017towards. Our approach roughly shares the same intuition with these existing works, namely, a “flat” local minimum admits less complexity and so generalizes better than a “sharp” one. To the best of our knowledge, our paper is the first among these works to provide both theoretical analysis, including a generalization bound, and empirical verification of both an efficient metric and a practical regularizer for modern network architectures.
5.2 Generalization Bound
Assumption 2 is quite reasonable as well. grunwald2007minimum suggests that a log-likelihood function, under the regularity conditions of (1) existence of its , and derivatives and (2) uniqueness of its maximum in the region, behaves locally like a quadratic function around its maximum. In our case, corresponds to the log-likelihood function and so corresponds to a local maximum of . Since is analytic and is the only local minimum of in , the training loss can indeed be considered locally quadratic.
Similar to langford2002not, harvey2017nearly, and neyshabur2017exploring, we apply the PAC-Bayes theorem (mcallester2003simplified) to derive a generalization bound for our metric. Specifically, we pick a uniform prior over according to the maximum entropy principle and, after observing the training data, pick a posterior of density . Then Theorem 1 bounds the expected using . See its proof in Appendix B.
In short, Theorem 1 shows that a lower indicates a more generalizable local minimum .
As stated in Section 4, in practice an approximation of as is necessary, since calculating involves computing the determinant of a matrix. Let us first assume we have an imagined training set of size , a local minimum of , and correspondingly a full-rank observed Fisher information matrix so that is well defined. In reality, we only have a training set with a singular . Notice that is also a local minimum of since , as assumed in Section 4. We then approximate the eigenvalues of by those of its sub-matrices and thereby approximate .
First of all, we replace by its one-hot version defined in Section 3.1 since they are very close. This drastically reduces the cost of gradient calculation. With and defined in Section 3.1, according to Equation 2, the observed Fisher information is:
Let denote the eigenvalues of ; then . Without calculating all eigenvalues, we can perform a Monte-Carlo estimation of by randomly sampling eigenvalues from . We denote the samples as and we have . Supposing the estimation is run times, we have .
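A toy illustration of this Monte-Carlo step, using a random SPD matrix as a stand-in for the Fisher information (in a real setting the full eigendecomposition is unavailable; here it is computed only to check the estimate):

```python
import numpy as np

rng = np.random.default_rng(2)
W = 200                                   # parameter dimension (toy scale)
A = rng.normal(size=(W, W))
M = A @ A.T + np.eye(W)                   # SPD stand-in for the observed Fisher
lams = np.linalg.eigvalsh(M)              # full spectrum, for reference only
logdet_exact = np.sum(np.log(lams))

# Monte-Carlo: average the log of m sampled eigenvalues, scale by W,
# and repeat the estimate T times
m, T = 20, 50
draws = [W * np.mean(np.log(rng.choice(lams, size=m, replace=False)))
         for _ in range(T)]
logdet_mc = np.mean(draws)
```

Each draw is an unbiased estimate of the log-determinant, and averaging over T trials shrinks the variance; in practice the sampled eigenvalues would come from sub-matrices rather than the full spectrum.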
In practice is inaccessible since we don’t have in the first place. Instead, we sample with for times and define
Notice that is a principal sub-matrix of by removing rows & columns for data in \ . According to Theorem 3, one can roughly estimate the size of eigenvalues of a matrix by those of its sub-matrices. Therefore we propose to estimate by with:
We leave Theorem 3 as well as the derivation of Equation 5 to Appendix C. In proposing , we ignore constants and irrelevant scaling factors because what matters is the relative size of when comparing different local minima. Empirically, we find that given a relatively large number of sample trials , our metric can effectively capture the generalizability of a local minimum even for a small (details in Section 7.1 and Appendix D).
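The eigenvalue-interlacing property that justifies estimating eigenvalues from principal sub-matrices (the Cauchy/Poincaré interlacing theorem, presumably the content of Theorem 3) can be checked numerically:

```python
import numpy as np

rng = np.random.default_rng(3)
n, k = 10, 6
A = rng.normal(size=(n, n))
S = (A + A.T) / 2                          # symmetric "full" matrix
idx = np.sort(rng.choice(n, size=k, replace=False))
sub = S[np.ix_(idx, idx)]                  # principal submatrix: rows & columns removed

lam = np.linalg.eigvalsh(S)                # ascending eigenvalues of S
mu = np.linalg.eigvalsh(sub)               # ascending eigenvalues of the submatrix
# interlacing: lam[i] <= mu[i] <= lam[i + n - k] for each i
interlaced = all(lam[i] - 1e-9 <= mu[i] <= lam[i + n - k] + 1e-9 for i in range(k))
```

Each submatrix eigenvalue is pinned between two eigenvalues of the full matrix, which is why submatrix spectra give rough but informative estimates of the size of the full spectrum.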
6 Local Minima Regularization
Beyond its practical value, devising a practical regularizer based on also “verifies” our theoretical understanding of DNN training, informing future improvements of learning algorithms. However, converting into a practical regularizer is non-trivial due to the computational burden of:
optimizing terms related to the gradient, which involves calculating the Hessian
computing the eigenvalues in each training step, which is even more expensive
We first solve the second issue and then the first one. To solve the second issue, we propose to optimize a surrogate term for which avoids eigenvalue computations, namely the trace norm of the observed Fisher information . These two terms have the relation:
Another major benefit of using the trace norm is that, unlike , still remains well defined even with a small training set . From Equation 2 we have:
The cost of computing is linear in the number of its terms (in the double summation). We therefore simplify the calculation by replacing with similar to Equation 5.3 so that
where and are defined in Section 3.1. Since gradient-based training never exactly reaches the local minimum , we choose to optimize throughout the entire training process. We have for each mini-batch . We can further reduce the computational cost by batching. Specifically, we randomly split into sub-batches of equal size, namely . We define and compute for instead of computing for each data point in .
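A quick check of why the trace surrogate is cheap: for a matrix built from per-sample gradient outer products (the form suggested by Equation 2), the trace equals the sum of squared per-sample gradient norms, so it never requires forming the full matrix:

```python
import numpy as np

rng = np.random.default_rng(1)
N, W = 8, 4
G = rng.normal(size=(N, W))               # per-sample gradients, one per row

fisher = sum(np.outer(g, g) for g in G)   # explicit W x W outer-product construction
trace_direct = np.trace(fisher)
trace_cheap = np.sum(G**2)                # sum of squared gradient norms, no matrix needed
```

The identity tr(sum_i g_i g_i^T) = sum_i ||g_i||^2 is what makes the trace-norm regularizer computable with ordinary back-propagation.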
We deal with the first computation burden by adopting first order approximation. For any , with a sufficiently small we have . Then
Therefore, we propose to optimize the following regularized training objective for each update step:
We omit all second-order terms when computing , simply by not back-propagating through . On the other hand, we find that gradient clipping, especially at the beginning of training, is necessary to make the generalization boost consistent. We have 4 hyper-parameters: , , the number of sub-batches , and the gradient clipping threshold . Our approach is formalized as:
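The algorithm block itself is not reproduced above; as a rough, hypothetical sketch of one regularized update step (a toy linear softmax model in numpy, with the gradient-norm penalty differentiated via the first-order finite-difference trick described in the text; all names and constants are illustrative, not the paper's exact settings):

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, N = 5, 3, 60                          # input dim, classes, samples
X = rng.normal(size=(N, D))
w_true = rng.normal(size=(D, K))
y = np.argmax(X @ w_true, axis=1)           # learnable labels for the toy task
w = rng.normal(scale=0.1, size=(D, K))      # model parameters

def loss_grad(w, Xb, yb):
    """Mean cross-entropy and its gradient for a linear softmax model."""
    z = Xb @ w
    z = z - z.max(axis=1, keepdims=True)    # stabilized softmax
    p = np.exp(z)
    p /= p.sum(axis=1, keepdims=True)
    nll = -np.mean(np.log(p[np.arange(len(yb)), yb]))
    d = p.copy()
    d[np.arange(len(yb)), yb] -= 1.0
    return nll, Xb.T @ d / len(yb)

lr, lam, eps, clip = 0.5, 0.05, 1e-3, 5.0   # illustrative hyper-parameters
for step in range(200):
    idx = rng.choice(N, size=16, replace=False)
    Xb, yb = X[idx], y[idx]
    loss, g = loss_grad(w, Xb, yb)
    # gradient of the penalty ||g||^2 is 2*H*g; approximate the Hessian-vector
    # product with a finite difference, avoiding second-order back-propagation
    _, g_eps = loss_grad(w + eps * g, Xb, yb)
    reg_grad = 2.0 * (g_eps - g) / eps
    update = g + lam * reg_grad
    norm = np.linalg.norm(update)
    if norm > clip:                         # gradient clipping at threshold
        update *= clip / norm
    w -= lr * update

final_loss, _ = loss_grad(w, X, y)
```

The two extra gradient evaluations per step make the overhead roughly one additional forward-backward pass, which is the efficiency argument behind the first-order approximation.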
7 Experiments
We perform two sets of experiments to illustrate the effectiveness of our metric . We demonstrate that: (1) the approximation captures generalizability well across local minima; (2) our regularization technique based on provides a consistent generalization gain for DNNs.
Throughout our theoretical analysis, we assume that label smoothing (LS) is applied during model training in order to obtain well-defined local minima (first mentioned in Section 4). In all our empirical evaluations, we run both versions, with and without LS applied. The results are very similar, so we report the version without LS, since this matches the original training setup in the papers of the various network architectures we use.
7.1 Experiments on Local Minima Characterization
We perform comprehensive evaluations comparing our metric with several others on ResNet-20 (he2016deep) for the CIFAR-10 dataset (architecture details in Appendix E). Our metric consistently outperforms the others in indicating local minima’s generalizability. Specifically, sokolic2017robust proposed a robustness-based metric used as a regularizer; wu2017towards proposed the Frobenius norm of the Hessian as a metric; keskar2016large proposed a metric closely related to the spectral radius of the Hessian. In summary, we compare 4 metrics, all evaluated at a local minimum given . For all four metrics, smaller values indicate better generalization.
Both the Frobenius norm and the spectral radius based metrics are related to ours, as from Equation 1 we have and . These two metrics, however, are too expensive to compute on the entire training set ; we instead calculate them by averaging the results over sampled , similar to how we compute . We leave the details of how we compute these metrics in our experiments to Appendix D.
We perform evaluations in three scenarios, similar to neyshabur2017exploring; keskar2016large. We examine different local minima obtained by varying (1) the size of a confusion set used in training, (2) the data augmentation scheme, and (3) the batch size. Specifically,
In Scenario I, we randomly select a subset of 10000 images as the training set and train the DNN with a confusion set consisting of CIFAR-10 samples with random labels. We vary the size of the confusion set so that the resulting local minima generalize differently to the test set while all reach close-to-zero training loss. We consider confusion sizes of , 1k, 2k, 3k, 4k, and 5k. We calculate all metrics based on the sampled 10000 training images.
In Scenario II, we vary the level of data augmentation. We apply horizontal flipping, denoted flip-only; random cropping from images with 1 pixel of padding on each side plus flipping, denoted 1-crop-f; random cropping with 4 pixels of padding on each side plus flipping, denoted 4-crop-f; and no data augmentation at all, denoted no-aug. Under all schemes, the network achieves perfect training accuracy. All the metrics are computed on the un-augmented training set.
In Scenario III, we vary the batch size. hoffer2017train suggests that large batch sizes lead to poor generalization. We consider batch sizes of , , and .
The default values for the 3 variables are confusion size 0, 4-crop-f augmentation, and batch size 128. For each configuration in each scenario, we train 5 models and report results (averages & standard deviations) for all metrics as well as the test errors (in percent). For the confusion set experiments, we sample a new training set and a new confusion set every time. In all scenarios, we train the model for 200 epochs with an initial learning rate of 0.1, divided by 10 whenever the training loss plateaus. Within each scenario, we find the final training loss to be very small and very similar across different models, and the training accuracy essentially equal to 1, indicating convergence to local minima.
7.2 Experiments on Local Minima Regularization
We evaluate our regularizer on CIFAR-10 & CIFAR-100 with four different network architectures: a plain CNN, ResNet-20, Wide ResNet (wideresnet), and DenseNet (huang2017densely). We use WRN-28-2-B(3,3) from the Wide ResNet paper and DenseNet-BC-k=12 from the DenseNet paper. See Appendix E for further architecture details. We denote the four networks as CNN, ResNet, WRN, and DenseNet, respectively.
We manually set in all experiments and select the other three hyper-parameters in Algorithm 1 by validation via a 45k/5k training data split for each network architecture on each dataset. Specifically, we consider , and . We keep all other training hyper-parameters, schemes, and setup identical to those in the original papers (details in Appendix E). Training details for the plain CNN are also in Appendix E. We train 5 separate models for each network-dataset combination and report the test errors in percent (mean ± std.) in Table 1, where “+reg” indicates training with our regularizer applied. The results demonstrate that our method provides a consistent generalization improvement for a wide range of DNNs.
7.2.1 The Choice of The Optimizer
As described in Algorithm 1, our proposed regularizer is not tied to a specific optimizer. We perform experiments with SGD+Momentum because it is the optimizer used in ResNet, WRN, and DenseNet, helping all of them achieve current or previous state-of-the-art results. Our regularizer aims to find better, “flatter” minima to improve generalization, whereas adaptive optimization methods such as Adam (kingma2014adam) and AdaGrad (duchi2011adaptive) aim to speed up convergence, usually at the cost of generalizability. Recent work (wilson2017marginal; keskar2017improving) shows that adaptive methods generalize worse than SGD+Momentum. In particular, in a setup very similar to ours, keskar2017improving demonstrates that SGD+Momentum consistently outperforms the others on ResNet and DenseNet for CIFAR-10 and CIFAR-100. Other approaches that also utilize local curvature to improve SGD, such as Entropy-SGD (chaudhari2016entropy) mentioned in Section 2, report empirical results that are preliminary compared to ours.
Table 1: test errors (%); each pair of columns shows results without and with our regularizer (“+reg”) for one architecture.

|CIFAR-10|8.52 ± 0.08|7.46 ± 0.13|5.44 ± 0.04|4.80 ± 0.10|4.61 ± 0.08|4.31 ± 0.08|8.50 ± 0.31|7.91 ± 0.25|
|CIFAR-100|31.12 ± 0.35|29.19 ± 0.21|25.52 ± 0.15|23.65 ± 0.14|22.54 ± 0.32|22.19 ± 0.28|-|-|
Table 2: the metric computed at local minima of similar training loss, trained without and with our regularizer (ResNet, WRN, and DenseNet on CIFAR-10).

| |ResNet|WRN|DenseNet|
|w/o reg.|-979.3 ± 22.3|-737.6 ± 20.3|-850.3 ± 23.5|
|with reg.|-1138.1 ± 11.0|-804.8 ± 18.7|-886.2 ± 20.5|
7.2.2 Generalization Boost As a Result of Better Local Minima
Our regularizer essentially optimizes an upper bound of the proposed metric during training. We perform a sanity check to verify that the regularizer indeed induces better local minima as characterized by our metric. For ResNet, WRN, and DenseNet trained on CIFAR-10, we compute the metric at local minima of similar training loss obtained with and without the regularizer. Table 2 shows that the resulting generalization boost aligns with what is captured by our metric.
8 Conclusion and Future Work
In this paper, we build a bridge between deep learning theory and regularization methods with respect to the generalizability of local minima. We propose a metric that captures the generalization properties of different local minima and provide its theoretical analysis, including a generalization bound. We further derive an efficient approximation of the metric and devise a practical and effective regularizer from it. Empirical results demonstrate our success in both capturing and improving the generalizability of DNNs. Our exploration suggests a promising direction for future work on the regularization and optimization of DNNs.
Appendix A Proof of Equation 1
To prove the second equality in Equation 1, it suffices to prove the following equality:
For convenience, we change the notation of the local minimum from to and further denote as . Since , for each and , we have:
Since is a local minimum (also a global minimum) described in Section 4 as for , when taking the double summation, the first term above becomes:
Then it follows that:
Appendix B Proof of the Generalization Bound in Section 5.2
First let us review the PAC-Bayes Theorem in mcallester2003simplified:
For any data distribution and loss function , let and be the expected loss and the training loss, respectively, for the model parametrized by , with the training set . For any prior distribution with model class as its support, any posterior distribution over (not necessarily the Bayesian posterior), and any , we have with probability at least that:
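The inequality itself is elided above; simplified PAC-Bayes bounds in the style of mcallester2003simplified take the following form (a sketch using the same $\ln\frac{2N}{\delta}$ and $(N-1)$ terms that appear later in this appendix; the paper's exact statement may differ in constants):

\[
\mathbb{E}_{w \sim Q}\big[L(\mathcal{D}, w)\big] \;\le\; \mathbb{E}_{w \sim Q}\big[L(S, w)\big] \;+\; \sqrt{\frac{\operatorname{KL}(Q \,\|\, P) + \ln\frac{2N}{\delta}}{2(N-1)}}.
\]

The bound trades off the empirical loss under the posterior against how far the posterior moves from the prior, which is where the Fisher determinant enters through the KL term.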
As , we can rewrite the generalization bound we want to prove in Section 5.2 as:
\[
\mathbb{E}_{w \sim Q}\left[L(\mathcal{D}, w)\right] \;\le\; L_0 \;+\; \frac{W \cdot \ln\dfrac{V^{2/W}\,\pi^{1/W}\,|I(w_0)|^{1/W}}{4\pi e} \;+\; 2\sqrt{W \cdot \ln\!\big(V^{2/W}\,\pi^{1/W}\,|I(w_0)|^{1/W}\big) + 4\pi e\, L_0 + 2\pi e \ln\dfrac{2N}{\delta}}}{2\pi e\,(N-1)}
\]
As defined in Section 5.2, given the model class , whose volume is , for the neural network , the uniform prior attains the probability density function for any , and the posterior has density . Based on Assumption 2 and the observed Fisher information derived in Section 4, especially Equation 2, we have:
Denote . Then is a truncated multivariate Gaussian distribution whose density function is:
Denote the denominator of Equation 7 as Z and define:
Then can also be written as:
In order to derive a generalization bound in the form of the PAC-Bayes Theorem, it suffices to prove an upper bound of the KL divergence term:
where is the height of defined in Section 5.1. For convenience, we shift down by and denote the shifted training loss so that . Then
Furthermore, the following two sets are equivalent
both of which are the -dimensional hyperellipsoid given by the equation , which can be converted to the standard form for hyperellipsoids as:
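For reference, the standard form of a $W$-dimensional hyperellipsoid, its enclosed volume, and the Stirling approximation used below are (standard facts, written with generic semi-axes $a_1, \dots, a_W$ since the paper's symbols are elided):

\[
\Big\{ x \in \mathbb{R}^W : \sum_{i=1}^{W} \frac{x_i^2}{a_i^2} = 1 \Big\},
\qquad
\operatorname{vol} = \frac{\pi^{W/2}}{\Gamma\!\left(\frac{W}{2} + 1\right)} \prod_{i=1}^{W} a_i,
\qquad
\Gamma\!\left(\tfrac{W}{2} + 1\right) \approx \sqrt{\pi W} \left(\frac{W}{2e}\right)^{W/2}.
\]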
The volume enclosed by this hyperellipsoid is exactly the volume of , i.e., ; so we have
Solving for with Stirling’s approximation for the factorial , we have
where denotes the Gamma function. Notice that for modern DNNs we have , and so . Then the generalization bound in the form of the PAC-Bayes Theorem is given as: