
A Bayesian Perspective on Training Speed and Model Selection

by   Clare Lyle, et al.
University of Oxford

We take a Bayesian perspective to illustrate a connection between training speed and the marginal likelihood in linear models. This provides two major insights: first, that a measure of a model's training speed can be used to estimate its marginal likelihood. Second, that this measure, under certain conditions, predicts the relative weighting of models in linear model combinations trained to minimize a regression loss. We verify our results in model selection tasks for linear models and for the infinite-width limit of deep neural networks. We further provide encouraging empirical evidence that the intuition developed in these settings also holds for deep neural networks trained with stochastic gradient descent. Our results suggest a promising new direction towards explaining why neural networks trained with stochastic gradient descent are biased towards functions that generalize well.



1 Introduction

Choosing the right inductive bias for a machine learning model, such as convolutional structure for an image dataset, is critical for good generalization. The problem of model selection concerns itself with identifying good inductive biases for a given dataset. In Bayesian inference, the marginal likelihood (ML) provides a principled tool for model selection. In contrast to cross-validation, for which computing gradients is cumbersome, the ML can be conveniently maximised using gradients when its computation is feasible. Unfortunately, computing the marginal likelihood for complex models such as neural networks is typically intractable. Workarounds such as variational inference suffer from expensive optimization of many parameters in the variational distribution and differ significantly from standard training methods for Deep Neural Networks (DNNs), which optimize a single parameter sample from initialization. A method for estimating the ML that closely follows standard optimization schemes would pave the way for new practical model selection procedures, yet remains an open problem.

A separate line of work aims to perform model selection by predicting a model’s test set performance. This has led to theoretical and empirical results connecting training speed and generalization error (hardt2015train; jiang2019fantastic). This connection has yet to be fully explained, as most generalization bounds in the literature depend only on the final weights obtained by optimization, rather than on the trajectory taken during training, and therefore are unable to capture this relationship. Understanding the link between training speed, optimization and generalization thus presents a promising step towards developing a theory of generalization which can explain the empirical performance of neural networks.

In this work, we show that the above two lines of inquiry are in fact deeply connected. We investigate the connection between the log ML and the sum of predictive log likelihoods of datapoints, conditioned on preceding data in the dataset. This perspective reveals a family of estimators of the log ML which depend only on predictions sampled from the posterior of an iterative Bayesian updating procedure. We study the proposed estimator family in the context of linear models, where we can conclusively analyze its theoretical properties. Leveraging the fact that gradient descent can produce exact posterior samples for linear models (matthews2017) and the infinite-width limit of deep neural networks (matthews2018gaussian; lee2018deep), we show that this estimator can be viewed as the sum of a subset of the model’s training losses in an iterative optimization procedure. This immediately yields an interpretation of marginal likelihood estimation as measuring a notion of training speed in linear models. We further show that this notion of training speed is predictive of the weight assigned to a model in a linear model combination trained with gradient descent, hinting at a potential explanation for the bias of gradient descent towards models that generalize well in more complex settings.

We demonstrate the utility of the estimator through empirical evaluations on a range of model selection problems, confirming that it can effectively approximate the marginal likelihood of a model. Finally, we empirically evaluate whether our theoretical results for linear models may have explanatory power for more complex models. We find that an analogue of our estimator for DNNs trained with stochastic gradient descent (SGD) is predictive of both final test accuracy and the final weight assigned to the model after training a linear model combination. Our findings in the deep learning setting hint at a promising avenue of future work in explaining the empirical generalization performance of DNNs.

2 Background and Related Work

2.1 Bayesian Parameter Inference

A Bayesian model $\mathcal{M}$ is defined by a prior distribution over parameters $\theta$, $P(\theta \mid \mathcal{M})$, and a prediction map from parameters $\theta$ to a likelihood over the data $\mathcal{D}$, $P(\mathcal{D} \mid \theta, \mathcal{M})$. Parameter fitting in the Bayesian framework entails finding the posterior distribution $P(\theta \mid \mathcal{D}, \mathcal{M})$, which yields robust and principled uncertainty estimates. Though exact inference is possible for certain models like Gaussian processes (GPs) (rasmussen2003gaussian), it is intractable for DNNs. Here approximations such as variational inference (blei2017variational) are used (gal2016dropout; blundell2015weight; mackay1992bayesian; graves2011practical; duvenaud2016early) to improve robustness and obtain useful uncertainty estimates.

Variational approximations require optimisation over the parameters of the approximate posterior distribution. This optimization over distributions changes the loss landscape, and is significantly slower than the pointwise optimization used in standard DNNs. Pointwise optimization methods inspired by Bayesian posterior sampling can produce similar variation and uncertainty estimates as variational inference, while improving computational efficiency (welling2011bayesian; mandt2017stochastic; maddox2019simple). An appealing example of this is ensembling (lakshminarayanan2017simple), which works by training a collection of models in the usual pointwise manner, starting from independently initialized points.

In the case of linear models, this is exactly equivalent to Bayesian inference, as this sample-then-optimize approach yields exact posterior samples (matthews2017; osband2018randomized). he2020bayesian extend this approach to obtain posterior samples from DNNs in the infinite-width limit.

2.2 Bayesian Model Selection

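In addition to finding model parameters, Bayesian inference can also perform model selection over different inductive biases, which are specified through both model structure (e.g. convolutional vs fully connected) and the prior distribution on parameters. The Bayesian approach relies on finding the posterior over models $P(\mathcal{M} \mid \mathcal{D})$, which uses the marginal likelihood (ML) as its likelihood function:

$$P(\mathcal{M} \mid \mathcal{D}) \propto P(\mathcal{D} \mid \mathcal{M})\, P(\mathcal{M}), \qquad P(\mathcal{D} \mid \mathcal{M}) = \int P(\mathcal{D} \mid \theta, \mathcal{M})\, P(\theta \mid \mathcal{M})\, d\theta. \tag{1}$$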


Instead of computing the full posterior, it is common to select the model with the highest marginal likelihood. This is known as type-II maximum likelihood (mackay1992bayesian; mackay2003information) and is less prone to overfitting than performing maximum likelihood over the parameters and model combined. This is because the marginal likelihood is able to trade off between model fit and model complexity (rasmussen2001occam). Maximising the ML is standard procedure when it is easy to compute. For example, in Gaussian processes it is used to set simple model parameters like smoothness (rasmussen2003gaussian), while recent work has demonstrated that complex inductive biases in the form of invariances can also be learned (van_der_wilk_learning_2018).

For many deep models, computing Equation 1 is intractable, and obtaining approximations that are accurate enough for model selection and that scale to complex models is an active area of research (khan2019approximate). In general, variational lower bounds that scale are too loose when applied to DNNs (blundell2015weight). Deep Gaussian processes provide a case where the bounds do work (damianou13a; dutordoir20a), but heavy computational load holds performance several years behind deep learning. While ensembling methods provide useful uncertainty estimates and improve the computational efficiency of the variational approach, they have not yet provided a solution for Bayesian model selection.

2.3 Generalization and Risk Minimization

Bayesian model selection addresses a subtly different problem from the risk minimization framework used in many learning problems. Nonetheless, the two are closely related; germain_pac-bayesian_2016 show that in some cases optimizing a PAC-Bayesian risk bound is equivalent to maximizing the marginal likelihood of a Bayesian model. In practice, maximizing an approximation of the marginal likelihood in DNNs trained with SGD can improve generalization performance (smith2017bayesian). More recently, arora2019fine computed a data-dependent complexity measure which resembles the data-fit term in the marginal likelihood of a Bayesian model and which relates to optimization speed, hinting at a potential connection between the two.

At the same time, generalization in deep neural networks (DNNs) remains mysterious, with classical learning-theoretic bounds failing to predict the impressive generalization performance of DNNs (zhang2016understanding; nagarajan2019uniform). Recent work has shown that DNNs are biased towards functions that are ‘simple’, for various definitions of simplicity (kalimeris2019sgd; frankle2018the; valle2018deep; smith2018). PAC-Bayesian generalization bounds, which can quantify a broad range of definitions of complexity, can attain non-vacuous values (mcallester1999; dziugaite_computing_2017; dziugaite_data-dependent_2018), but nonetheless exhibit only modest correlation with generalization error (jiang2019fantastic). These bounds depend only on the final distribution over parameters after training; promising alternatives consider properties of the trajectory taken by a model during optimization (hardt2015train; negrea2019information). This trajectory-based perspective is a promising step towards explaining the correlation between the number of training steps required for a model to minimize its objective function and its final generalization performance observed in a broad range of empirical analyses (jiang2019fantastic; belkin2018reconciling; nakkiran2019deep; ru2020revisiting).

3 Marginal Likelihood Estimation with Training Statistics

In this section, we investigate the equivalence between the marginal likelihood (ML) and a notion of training speed in models trained with an exact Bayesian updating procedure. For linear models and infinitely wide neural networks, exact Bayesian updating can be done using gradient descent optimisation. For these cases, we derive an estimator of the marginal likelihood which 1) is related to how quickly a model learns from data, 2) only depends on statistics that can be measured during pointwise gradient-based parameter estimation, and 3) becomes tighter for ensembles consisting of multiple parameter samples. We also investigate how gradient-based optimization of a linear model combination can implicitly perform approximate Bayesian model selection in Section 3.3.

3.1 Training Speed and the Marginal Likelihood

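Let $\mathcal{D}$ denote a dataset of the form $\mathcal{D} = (\mathcal{D}_i)_{i=1}^n = (x_i, y_i)_{i=1}^n$, and let $\mathcal{D}_{<i} = (\mathcal{D}_j)_{j=1}^{i-1}$ with $\mathcal{D}_{<1} = \emptyset$. We will abbreviate $P(\mathcal{D} \mid \mathcal{M}) := P(\mathcal{D})$ when considering a single model $\mathcal{M}$. Observe that $P(\mathcal{D}) = \prod_{i=1}^n P(\mathcal{D}_i \mid \mathcal{D}_{<i})$ to get the following form of the log marginal likelihood:

$$\log P(\mathcal{D}) = \log \prod_{i=1}^n P(\mathcal{D}_i \mid \mathcal{D}_{<i}) = \sum_{i=1}^n \log P(\mathcal{D}_i \mid \mathcal{D}_{<i}) = \sum_{i=1}^n \log \mathbb{E}_{P(\theta \mid \mathcal{D}_{<i})}\left[ P(\mathcal{D}_i \mid \theta) \right]. \tag{2}$$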


If we define training speed as the number of data points required by a model to form an accurate posterior, then models which train faster (i.e. whose posteriors assign high likelihood to the data after conditioning on only a few data points) will obtain a higher marginal likelihood. Interpreting the negative log posterior predictive probability $-\log P(\mathcal{D}_i \mid \mathcal{D}_{<i})$ of each data point as a loss function, the log ML then takes the form of the sum over the losses incurred by each data point during training, i.e. the area under a training curve defined by a Bayesian updating procedure.
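The decomposition above can be checked numerically. The following is a minimal sketch, assuming a Bayesian linear regression model with an isotropic Gaussian prior and Gaussian observation noise; the function names and the hyperparameters `prior_var` and `noise_var` are illustrative choices, not taken from the paper. It sums the log posterior predictive densities of the data points, each conditioned on the preceding ones, and compares the result with the closed-form log marginal likelihood.

```python
import numpy as np

def log_ml_closed_form(Phi, y, prior_var, noise_var):
    """Closed form: log N(y | 0, prior_var * Phi Phi^T + noise_var * I)."""
    n = len(y)
    K = prior_var * Phi @ Phi.T + noise_var * np.eye(n)
    _, logdet = np.linalg.slogdet(K)
    quad = y @ np.linalg.solve(K, y)
    return -0.5 * (quad + logdet + n * np.log(2 * np.pi))

def log_ml_sequential(Phi, y, prior_var, noise_var):
    """Per-point terms log P(D_i | D_<i) from sequential Bayesian updating."""
    d = Phi.shape[1]
    mean = np.zeros(d)
    cov = prior_var * np.eye(d)
    terms = []
    for phi_i, y_i in zip(Phi, y):
        # Posterior predictive for y_i given the data seen so far.
        pred_mean = phi_i @ mean
        pred_var = phi_i @ cov @ phi_i + noise_var
        terms.append(-0.5 * (np.log(2 * np.pi * pred_var)
                             + (y_i - pred_mean) ** 2 / pred_var))
        # Rank-one posterior update after observing (phi_i, y_i).
        gain = cov @ phi_i / pred_var
        mean = mean + gain * (y_i - pred_mean)
        cov = cov - np.outer(gain, phi_i @ cov)
    return np.array(terms)

rng = np.random.default_rng(0)
Phi = rng.normal(size=(20, 3))
y = Phi @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.normal(size=20)

per_point = log_ml_sequential(Phi, y, prior_var=1.0, noise_var=0.1)
total = log_ml_closed_form(Phi, y, prior_var=1.0, noise_var=0.1)
print(per_point.sum(), total)  # the two values agree up to floating-point error
```

Plotting `-per_point` against the data index gives the "training curve" whose area equals the negative log ML in this setting.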

3.2 Unbiased Estimation of a Lower Bound

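In practice, computing $\log P(\mathcal{D}_i \mid \mathcal{D}_{<i})$ may be intractable, necessitating approximate methods to estimate the model evidence. In our analysis, we are interested in estimators of $\log P(\mathcal{D})$ computed by drawing $k$ samples $\theta^i_1, \dots, \theta^i_k \sim P(\theta \mid \mathcal{D}_{<i})$ for each $i = 1, \dots, n$. We can directly estimate a lower bound $\mathcal{L}(\mathcal{D}) = \sum_{i=1}^n \mathbb{E}_{P(\theta \mid \mathcal{D}_{<i})}\left[\log P(\mathcal{D}_i \mid \theta)\right]$ using the log likelihoods of these samples:

$$\hat{\mathcal{L}}(\mathcal{D}) = \sum_{i=1}^n \frac{1}{k} \sum_{j=1}^k \log P(\mathcal{D}_i \mid \theta^i_j). \tag{3}$$

This will produce a biased estimate of the log marginal likelihood due to Jensen’s inequality. We can get a tighter lower bound by first estimating $\mathbb{E}_{P(\theta \mid \mathcal{D}_{<i})}\left[P(\mathcal{D}_i \mid \theta)\right]$ using our posterior samples before applying the logarithm, obtaining

$$\hat{\mathcal{L}}_k(\mathcal{D}) = \sum_{i=1}^n \log \frac{1}{k} \sum_{j=1}^k P(\mathcal{D}_i \mid \theta^i_j). \tag{4}$$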

Proposition 3.1.

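Both $\hat{\mathcal{L}}$ and $\hat{\mathcal{L}}_k$ as defined in Equations 3 and 4 are estimators of lower bounds on the log marginal likelihood; that is

$$\mathbb{E}[\hat{\mathcal{L}}(\mathcal{D})] = \mathcal{L}(\mathcal{D}) \le \log P(\mathcal{D}) \quad \text{and} \quad \mathbb{E}[\hat{\mathcal{L}}_k(\mathcal{D})] = \mathcal{L}_k(\mathcal{D}) \le \log P(\mathcal{D}).$$

Further, the bias term in $\mathcal{L}$ can be quantified as follows.

$$\mathcal{L}(\mathcal{D}) = \log P(\mathcal{D}) - \sum_{i=1}^n \mathrm{KL}\left( P(\theta \mid \mathcal{D}_{<i}) \,\|\, P(\theta \mid \mathcal{D}_{<i+1}) \right)$$

We include the proof of this and future results in Appendix A. We observe that both lower bound estimators exhibit decreased variance when using multiple posterior samples; however, $\hat{\mathcal{L}}_k$ also exhibits decreasing bias (with respect to the log ML) as $k$ increases; each $k$ defines a distinct lower bound $\mathcal{L}_k(\mathcal{D}) = \mathbb{E}[\hat{\mathcal{L}}_k(\mathcal{D})]$ on $\log P(\mathcal{D})$. The gap induced by the lower bound $\mathcal{L}(\mathcal{D})$ is characterized by the information gain each data point provides to the model about the posterior, as given by the Kullback-Leibler (KL) divergence (kullback1951) between the posterior at time $i$ and the posterior at time $i+1$. Thus, while $\mathcal{L}$ has a Bayesian interpretation it is arguably more closely aligned with the minimum description length notion of model complexity (hinton1993keeping).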
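The two sample-based estimators can be sketched as follows. This is an illustrative example only: it assumes a Bayesian linear regression model whose exact posterior is available in closed form, and the function name `lower_bound_estimates` and the hyperparameters are hypothetical. For each data point it draws `k` parameter samples from the posterior given the preceding data, then accumulates the mean of the log likelihoods (the naive estimator) and the log of the mean likelihood (the tighter estimator).

```python
import numpy as np

def gaussian_logpdf(y, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (y - mean) ** 2 / var)

def lower_bound_estimates(Phi, y, prior_var, noise_var, k, rng):
    """Return (naive estimate, k-sample estimate) of lower bounds on log P(D)."""
    d = Phi.shape[1]
    mean, cov = np.zeros(d), prior_var * np.eye(d)
    l_hat, l_hat_k = 0.0, 0.0
    for phi_i, y_i in zip(Phi, y):
        # k samples from the posterior conditioned on the preceding data.
        thetas = rng.multivariate_normal(mean, cov, size=k)
        logliks = gaussian_logpdf(y_i, thetas @ phi_i, noise_var)
        # Naive estimator: average of the log likelihoods.
        l_hat += logliks.mean()
        # Tighter estimator: log of the average likelihood (log-sum-exp for stability).
        m = logliks.max()
        l_hat_k += m + np.log(np.mean(np.exp(logliks - m)))
        # Exact posterior update with the new data point.
        s = phi_i @ cov @ phi_i + noise_var
        gain = cov @ phi_i / s
        mean = mean + gain * (y_i - phi_i @ mean)
        cov = cov - np.outer(gain, phi_i @ cov)
        cov = 0.5 * (cov + cov.T)  # keep the covariance numerically symmetric
    return l_hat, l_hat_k

rng = np.random.default_rng(0)
Phi = rng.normal(size=(15, 3))
y = Phi @ np.array([0.5, -1.0, 2.0]) + 0.5 * rng.normal(size=15)
l_hat, l_hat_k = lower_bound_estimates(Phi, y, prior_var=1.0, noise_var=0.25, k=32, rng=rng)
print(l_hat, l_hat_k)
```

Because the logarithm is concave, the naive value never exceeds the k-sample value when both are computed from the same samples, which is the per-sample counterpart of the ordering of the two bounds.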

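When the posterior predictive distribution of our model is Gaussian, we consider a third approach which, unlike the previous two methods, also applies to noiseless models. Let $\mathcal{D}_i = (x_i, y_i)$, and $(\theta^i_j)_{j=1}^k$ be $k$ parameter samples from $P(\theta \mid \mathcal{D}_{<i})$. We assume a mapping $f$ such that sampling parameters $\theta$ and computing $f(\theta, x_i)$ is equivalent to sampling from the posterior $P(\cdot \mid \mathcal{D}_{<i}, x_i)$. We can then obtain the following estimator of a lower bound on $\log P(\mathcal{D})$.

Proposition 3.2.

Let $P(y_i \mid \mathcal{D}_{<i}, x_i) = \mathcal{N}(\mu_i, \sigma_i^2)$ for some $\mu_i, \sigma_i^2$. Define the standard mean and variance estimators $\hat{\mu}_i = \frac{1}{k} \sum_{j=1}^k f(\theta^i_j, x_i)$ and $\hat{\sigma}_i^2 = \frac{1}{k-1} \sum_{j=1}^k \left( f(\theta^i_j, x_i) - \hat{\mu}_i \right)^2$. Then the estimator

$$\hat{\mathcal{L}}_S(\mathcal{D}) = \sum_{i=1}^n \log P(y_i \mid \hat{\mu}_i, \hat{\sigma}_i^2)$$

is a lower bound on the log ML: i.e. $\mathbb{E}[\hat{\mathcal{L}}_S(\mathcal{D})] \le \log P(\mathcal{D})$.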

We provide an empirical evaluation of the rankings provided by the different estimators in Section 4. We find that $\hat{\mathcal{L}}_S$ exhibits the least bias in the presence of limited samples from the posterior, though we emphasize its limitation to Gaussian posteriors; for more general posterior distributions, $\hat{\mathcal{L}}_k$ minimizes bias while still estimating a lower bound.

3.2.1 Lower bounds via gradient descent trajectories

The bounds on the marginal likelihood we introduced in the previous section required samples from the sequence of posteriors $P(\theta \mid \mathcal{D}_{<i})$, $i = 1, \dots, n$, as data points were incrementally added. Ensembles of linear models trained with gradient descent yield samples from the model posterior. We now show that we can use these samples to estimate the log ML using the estimators introduced in the previous section.

We will consider the Bayesian linear regression problem of modelling data $\mathcal{D} = (x_i, y_i)_{i=1}^n$ assumed to be generated by the process $y = \theta^\top \phi(x) + \epsilon$, $\epsilon \sim \mathcal{N}(0, \sigma_N^2)$, for some unknown $\theta$, known $\sigma_N^2$, and feature map $\phi$. Typically, a Gaussian prior is placed on $\theta$; this prior is then updated as data points are seen to obtain a posterior over parameters. In the overparameterised, noiseless linear regression setting, matthews2017 show that the distribution over parameters obtained by sampling from the prior on $\theta_0$ and running gradient descent to convergence on the data $\mathcal{D}_{<i}$ is equivalent to sampling from the posterior conditioned on $\mathcal{D}_{<i}$. osband2018randomized extend this result to posteriors which include observation noise under the assumption that the targets are themselves noiseless observations.

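Input: A dataset $\mathcal{D} = (x_i, y_i)_{i=1}^n$, parameters $\mu_0, \sigma_0^2, \sigma_N^2$
Result: An estimate of $\mathcal{L}(\mathcal{D})$
$\theta_t \leftarrow \theta_0 \sim \mathcal{N}(\mu_0, \sigma_0^2)$ ;  $\tilde{Y} \leftarrow Y + \epsilon$, $\epsilon \sim \mathcal{N}(0, \sigma_N^2)$ ;  sumLoss $\leftarrow$ 0 ;
for $\mathcal{D}_i \in \mathcal{D}$ do
       sumLoss $\leftarrow$ sumLoss $+\ \ell(\mathcal{D}_i, \theta_t)$ ;
       $\theta_t \leftarrow$ GradientDescent($\ell, \theta_t, \tilde{\mathcal{D}}_{\le i}$) ;
end for
return sumLoss
Algorithm 1 Marginal Likelihood Estimation for Linear Models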

We can use this procedure to obtain posterior samples for our estimators by iteratively running sample-then-optimize on the sets $\mathcal{D}_{<i}$, as outlined in Algorithm 1. Theorem 3.3 shows that this procedure yields an unbiased estimate of $\mathcal{L}(\mathcal{D})$ when a single prior sample is used, and an unbiased estimate of $\mathcal{L}_k(\mathcal{D})$ when an ensemble of $k$ models is trained in parallel.

Theorem 3.3.

Let and let be generated by the procedure outlined above. Then the estimators and , applied to the collection , are lower bounds on . Further, expressing as the regression loss plus a constant, we then obtain


We highlight that Theorem 3.3 precisely characterizes the lower bound on the marginal likelihood as a sum of ‘training losses’ based on the regression loss .
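The sample-then-optimize step that this construction relies on can be illustrated in the simplest setting. The sketch below assumes the overparameterised, noiseless case with a standard normal prior; the function name, step size rule and problem sizes are illustrative assumptions, not the paper's experimental configuration. It draws a parameter vector from the prior, runs plain gradient descent on the squared loss, and compares the result with the closed-form expression for the corresponding posterior sample (the prior draw corrected by the minimum-norm interpolating update).

```python
import numpy as np

def sample_then_optimize(Phi, y, theta0, lr, steps):
    """Gradient descent on 0.5 * ||Phi theta - y||^2 starting from a prior sample."""
    theta = theta0.copy()
    for _ in range(steps):
        grad = Phi.T @ (Phi @ theta - y)
        theta -= lr * grad
    return theta

rng = np.random.default_rng(0)
n, d = 5, 20                      # fewer data points than parameters
Phi = rng.normal(size=(n, d))
y = rng.normal(size=n)
theta0 = rng.normal(size=d)       # sample from the N(0, I) prior

# Step size below 2 / (largest eigenvalue of Phi^T Phi) guarantees convergence.
lr = 1.0 / np.linalg.norm(Phi, 2) ** 2
theta_gd = sample_then_optimize(Phi, y, theta0, lr, steps=5000)

# Closed form: gradient descent only moves theta within the row space of Phi,
# so the limit is the prior sample plus the minimum-norm correction.
theta_exact = theta0 + np.linalg.pinv(Phi) @ (y - Phi @ theta0)
print(np.max(np.abs(theta_gd - theta_exact)))
```

Repeating this on the growing subsets of the data, and recording the loss on each new point before it is added, gives the quantities summed in Algorithm 1.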

3.2.2 From Linear Models to Infinite Neural Networks

Beyond linear models, our estimators can further perform model selection in the infinite-width limit of neural networks. Using the optimization procedure described by he2020bayesian, we can obtain an exact posterior sample from a GP given by the neural tangent kernel (jacot2018neural). The iterative training procedure described in Algorithm 1 will thus yield a lower bound on the marginal likelihood of this GP using sampled losses from the optimization trajectory of the neural network. We evaluate this bound in Section 4, and formalize this argument in the following corollary.

Corollary 3.4.

Let be a dataset indexed by our standard notation. Let be sampled from an infinitely wide neural network architecture under some initialization distribution, and let be the limiting solution under the training dynamics defined by he2020bayesian applied to the initialization and using data . Let denote the neural tangent kernel for , and the induced Gaussian Process. Then , and in the limit of infinite training time, the iterative sample-then-optimize procedure yields an unbiased estimate of . Letting denote the scaled squared regression loss and be a constant, we obtain as a direct corollary of Theorem 3.3


This result provides an additional view on the link between training speed and generalisation in wide neural networks noted by arora2019fine, who analysed the convergence of gradient descent. They compute a PAC generalization bound which features a data complexity term equal to that in the marginal likelihood of a Gaussian process (rasmussen2003gaussian). This term provides a bound on the rate of convergence of gradient descent, whereas our notion of training speed is more closely related to sample complexity and makes the connection to the marginal likelihood more explicit.

It is natural to ask if such a Bayesian interpretation of the sum over training losses can be extended to non-linear models trained with stochastic gradient descent. Although SGD lacks the exact posterior sampling interpretation of our algorithm, we conjecture a similar underlying mechanism connecting the sum over training losses and generalization. Just as the marginal likelihood measures how well model updates based on previous data points generalize to a new unseen data point, the sum of training losses measures how well parameter updates based on one mini-batch generalize to the rest of the training data. If the update generalizes well, we expect to see a sharper decrease in the training loss, i.e. for the model to train more quickly and exhibit a lower sum over training losses. This intuition can be related to the notion of ‘stiffness’ proposed by fort2019stiffness. We provide empirical evidence supporting our hypothesis in Section 4.2.

3.3 Bayesian Model Selection and Optimization

The estimator reveals an intriguing connection between pruning in linear model combinations and Bayesian model selection. We assume a data set and a collection of models . A linear regressor is trained to fit the posterior predictive distributions of the models to the target ; i.e. to regress on the dataset


The following result shows that the optimal linear regressor on this data generating distribution assigns the highest weight to the model with the highest lower bound $\mathcal{L}(\mathcal{D})$ whenever the model errors are independent. This shows that magnitude pruning in a linear model combination is equivalent to approximate Bayesian model selection, under certain assumptions on the models.

Proposition 3.5.

Let be Bayesian linear regression models with fixed noise variance and Gaussian likelihoods. Let

be a (random) matrix of posterior prediction samples, of the form

. Suppose the following two conditions on the columns of are satisfied: for all , and . Let denote the least-squares solution to the regression problem . Then the following holds


The assumption on the independence of model errors is crucial in the proof of this result: families of models with large and complementary systematic biases may not exhibit this behaviour. We observe in Section 4 that the conditions of Proposition 3.5 are approximately satisfied in a variety of model comparison problems, and running SGD on a linear combination of Bayesian models still leads to solutions that approximate Bayesian model selection. We conjecture that analogous phenomena occur during training within a neural network. The proof of Proposition 3.5 depends on the observation that, given a collection of features, the best least-squares predictor will assign the greatest weight to the feature that best predicts the training data. While neural networks are not linear ensembles of fixed models, we conjecture that, especially for later layers of the network, a similar phenomenon will occur wherein weights from nodes that are more predictive of the target values over the course of training will be assigned higher magnitudes. We empirically investigate this hypothesis in Section 4.2.
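The observation behind the proof, that a least-squares combination puts the most weight on the most predictive column when errors are independent, can be illustrated with a toy simulation. This is a simplified stand-in rather than the setting of Proposition 3.5: the "model predictions" here are just the targets corrupted by independent Gaussian errors of different scales, and all names and constants are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
y = rng.normal(size=n)

# Three hypothetical models whose prediction errors are independent of each
# other and of the targets, with increasing error standard deviation.
error_std = np.array([0.5, 1.0, 2.0])
Y_hat = y[:, None] + error_std * rng.normal(size=(n, 3))

# Least-squares weights for the linear model combination Y_hat @ w ~ y.
w, *_ = np.linalg.lstsq(Y_hat, y, rcond=None)
print(w)  # weights decrease as the error scale of the model increases
```

In this toy setting the weights are ordered inversely to the error variances, so keeping only the largest-magnitude weight selects the most accurate model.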

4 Empirical Evaluation

Section 3 focused on two key ideas: that training statistics can be used as an estimator for a Bayesian model’s marginal likelihood (or a lower bound thereof), and that gradient descent on a linear ensemble implicitly arrives at the same ranking as this estimator in the infinite-sample, infinite-training-time limit. We further conjectured that similar phenomena may also hold for deep neural networks. We now illustrate these ideas in a range of settings. Section 4.1 provides confirmation and quantification of our results for linear models, the model class for which we have theoretical guarantees, while Section 4.2 provides preliminary empirical confirmation that the mechanisms at work in linear models also appear in DNNs.

4.1 Bayesian Model Selection

While we have shown that our estimators correspond to lower bounds on the marginal likelihood, we would also like the relative rankings of models given by our estimator to correlate with those assigned by the marginal likelihood. We evaluate this correlation in a variety of linear model selection problems. We consider three model selection problems; for space we focus on one, feature dimension selection, and provide full details and evaluations on the other two tasks in Appendix B.1.

For the feature dimension selection task, we construct a synthetic dataset inspired by wilson2020bayesian of the form , where , and consider a set of models with feature embeddings . The optimal model in this setting is the one which uses exactly the set of ‘informative’ features .

We first compare the relative rankings given by the true marginal likelihood with those given by our estimators. We compare $\hat{\mathcal{L}}$, $\hat{\mathcal{L}}_k$ and $\hat{\mathcal{L}}_S$; we first observe that all methods agree on the optimal model: this is a consistent finding across all of the model selection tasks we considered. While all methods lower bound the log marginal likelihood, $\hat{\mathcal{L}}_k$ and $\hat{\mathcal{L}}_S$ exhibit a reduced gap compared to the naive lower bound. In the rightmost plot of Figure 1, we further quantify the reduction in the bias of the estimator $\hat{\mathcal{L}}_k$ described in Section 3. We use exact posterior samples (which we denote in the figure simply as posterior samples) and approximate posterior samples generated by the gradient descent procedure outlined in Algorithm 1 using a fixed step size and thus inducing some approximation error. We find that both sampling procedures exhibit decreasing bias as the number of samples $k$ is increased, with the exact sampling procedure exhibiting a slightly smaller gap than the approximate sampling procedure.

Figure 1: Left: ranking according to , with exact posterior samples, and computed on samples generated by gradient descent. Right: gap between true marginal likelihood and estimator shrinks as a function of for both exact and gradient descent-generated samples.

We next empirically evaluate the claims of Proposition 3.5 in settings with relaxed assumptions. We compare the ranking given by the true log marginal likelihood, the estimated $\mathcal{L}(\mathcal{D})$, and the weight assigned to each model by the trained linear regressor. We consider three variations on how sampled predictions from each model are drawn to generate the features: sampling the prediction for point $x_i$ from the posterior predictive conditioned on the preceding data, $P(y_i \mid \mathcal{D}_{<i}, x_i)$ (‘concurrent sampling’; this is the setting of Proposition 3.5), as well as two baselines: sampling from the full posterior $P(y_i \mid \mathcal{D}, x_i)$ (‘posterior sampling’), and from the prior $P(y_i \mid x_i)$ (‘prior sampling’). We find that the marginal likelihood, its lower bound, and concurrent optimization all agree on the best model in all three of the model selection problems outlined previously, while the prior and posterior sampling baselines do not exhibit a consistent ranking with the log ML. We visualize these results for the feature dimension selection problem in Figure 2; full results are shown in Figure 5.

Figure 2: Left: Relative rankings given by optimize-then-prune, ML, and estimated $\mathcal{L}(\mathcal{D})$ on the feature selection problem. Right: visualizing the interpretation of $\mathcal{L}(\mathcal{D})$ as the ‘area under the curve’ of training losses: we plot the relative change in the estimator for convolutional and fully-connected NTK-GP models, and shade their area.

We further illustrate how the estimator can select inductive biases in the infinite-width neural network regime in Figure 2. Here we evaluate the relative change in the log ML of a Gaussian Process induced by a fully-connected MLP (MLP-NTK-GP) and a convolutional neural network (Conv-NTK-GP), both of which perform regression on the MNIST dataset. The fully-connected model sees a consistent decrease in its log ML with each additional data point added to the dataset, whereas the convolutional model sees the incremental change in its log ML become less negative as more data points are added as a result of its implicit bias, as well as a much higher incremental change in its log ML from the start of training. This leads to the Conv-NTK-GP having a higher value for

than the MLP-NTK-GP. We provide an analogous plot evaluating in the appendix.

4.2 Training Speed, Ensemble Weight, and Generalization in DNNs

We now address our conjectures from Section 3, which aim to generalize our results for linear models to deep neural networks trained with SGD. Recall that our hypothesis involves translating iterative posterior samples to minibatch training losses over an SGD trajectory, and Bayesian model evidence to generalization error; we conjectured that just as the sum of the log posterior likelihoods is useful for Bayesian model selection, the sum of minibatch training losses will be useful to predict generalization error. In this section, we evaluate whether this conjecture holds for a simple convolutional neural network trained on the FashionMNIST dataset. Our results provide preliminary evidence in support of this claim, and suggest that further work investigating this relationship may reveal valuable insights into how and why neural networks generalize.

4.2.1 Linear Combination of DNN Architectures

We first evaluate whether the sum over training losses (SOTL) obtained over an SGD trajectory correlates with a model’s generalization error, and whether SOTL predicts the weight assigned to a model by a linear ensemble. To do so, we train a linear combination of DNNs with SGD to determine whether SGD upweights NNs that generalize better. Further details of the experiment can be found in Appendix B.2. Our results are summarized in Figure 3.
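As a rough illustration of how a sum over training losses can be accumulated along an SGD trajectory, the sketch below trains two linear regressors with minibatch SGD and sums the minibatch losses measured before each update. It is a toy stand-in for the DNN experiments, not a reproduction of them: the function `sgd_sotl`, the synthetic data and all hyperparameters are assumptions made for this example.

```python
import numpy as np

def sgd_sotl(X, y, lr=0.05, epochs=5, batch_size=20, seed=0):
    """Run minibatch SGD on the mean squared error and return the sum of
    minibatch training losses recorded before each parameter update."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    total = 0.0
    for _ in range(epochs):
        order = rng.permutation(len(y))
        for start in range(0, len(y), batch_size):
            idx = order[start:start + batch_size]
            residual = X[idx] @ w - y[idx]
            total += np.mean(residual ** 2)          # loss before the update
            w -= lr * 2 * X[idx].T @ residual / len(idx)
    return total

rng = np.random.default_rng(0)
n = 200
X_informative = rng.normal(size=(n, 5))   # features that generate the targets
X_noise = rng.normal(size=(n, 5))         # features unrelated to the targets
w_true = np.array([1.0, -1.0, 2.0, 0.5, -1.5])
y = X_informative @ w_true + 0.1 * rng.normal(size=n)

sotl_informative = sgd_sotl(X_informative, y)
sotl_noise = sgd_sotl(X_noise, y)
print(sotl_informative, sotl_noise)
```

The model with informative features trains quickly and accumulates a much smaller sum of losses than the model whose features carry no signal about the targets.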

Figure 3: Linear combinations of DNNs on FashionMNIST trained. Left: ensemble weights versus the test loss for concurrent training. Middle: sum over training losses (SOTL), standardized by the number of training samples, versus test loss for parallel training. Right: training curves for the different models trained in parallel. All results are averaged over

runs, and standard deviations are shown by the shaded regions around each observation. The model parameters, given in the parentheses, are the number of layers (

), nodes per layer () and kernel size (), respectively.

We observe a strong correlation between SOTL and average test cross-entropy (see Figure 3 middle column), validating that the SOTL is correlated with generalization. Further, we find that architectures with lower test error (when trained individually) are given higher weight by the linear ensembling layer – as can be seen from the left plot in Figure 3. This supports our hypothesis that SGD favours models that generalize well.

4.2.2 Subnetwork Selection in Neural Networks

Finally, we evaluate whether our previous insights apply to submodels within a neural network, suggesting a potential mechanism which may bias SGD towards parameters with better generalization performance. Based on the previous experiments, we expect that nodes that have a lower sum over training errors (if evaluated as a classifier on their own) are favoured by gradient descent and therefore have a larger final weight than those which are less predictive of the data. If so, we can then view SGD followed by pruning (in the final linear layer of the network) as performing an approximation of a Bayesian model selection procedure. We replicate the model selection problem of the previous setting, but replace the individual models with the activations of the penultimate layer of a neural network, and replace the linear ensemble with the final linear layer of the network. Full details on the experimental set-up can be found in Appendix B.3. We find that our hypotheses hold here: SGD assigns larger weights to subnetworks that perform well, as can be seen in Figure 4. This suggests that SGD is biased towards functions that generalize well, even within a network. We find the same trend holds for CIFAR-10, which is shown in the appendix.


Figure 4: Weight assigned to subnetwork by SGD in a deep neural network (x-axis) versus the subnetwork performance (estimated by the sum of cross-entropy, on the y-axis) for different FashionMNIST classes. The light blue ovals depict confidence intervals, estimated over 10 seeds (i.e. $2\sigma$ for both the weight and SOTL). The orange line depicts the general trend.

5 Conclusion

In this paper, we have proposed a family of estimators of the marginal likelihood which illustrate the connection between training speed and Bayesian model selection. Because gradient descent can produce exact posterior samples in linear models, our result shows that Bayesian model selection can be done by training a linear model with gradient descent and tracking how quickly it learns. This approach also applies to the infinite-width limit of deep neural networks, whose dynamics resemble those of linear models. We further highlight a connection between magnitude-based pruning and model selection, showing that models for which our lower bound is high will be assigned more weight by an optimal linear model combination. This raises the question of whether similar mechanisms exist in finitely wide neural networks, which do not behave as linear models. We provide preliminary empirical evidence that the connections shown in linear models have predictive power towards explaining generalization and training dynamics in DNNs, suggesting a promising avenue for future work.

6 Broader Impact

Due to the theoretical nature of this paper, we do not foresee any immediate applications (positive or negative) that may arise from our work. However, improvement in our understanding of generalization in deep learning may lead to a host of downstream impacts which we outline briefly here for completeness, noting that the marginal effect of this paper on such broad societal and environmental impacts is likely to be very small.

  1. Safety and robustness. Developing a stronger theoretical understanding of generalization will plausibly lead to training procedures which improve the test-set performance of deep neural networks. Improving generalization performance is crucial to ensuring that deep learning systems applied in practice behave as expected based on their training performance.

  2. Training efficiency and environmental impacts. In principle, obtaining better estimates of model and sub-model performance could lead to more efficient training schemes, thus potentially reducing the carbon footprint of machine learning research.

  3. Bias and Fairness. The setting of our paper, like much of the related work on generalization, does not consider out-of-distribution inputs or training under constraints. If the training dataset is biased, then a method which improves the generalization performance of the model under the i.i.d. assumption will be prone to perpetuating this bias.


Lisa Schut was supported by Accenture Labs and the Alan Turing Institute.


Appendix A Proofs of Theoretical Results

See 3.1


The result for follows from a straightforward derivation:


The result for follows immediately from Jensen’s inequality, yielding



applies Jensen’s inequality to a random variable with decreasing variance as a function of

, we expect the bias of to decrease as grows, an observation characterized in Section 4. ∎
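As a numerical illustration of the two quantities compared in this proof, the sketch below assumes a Bayesian linear regression model with an isotropic Gaussian prior and known noise variance. It checks that the sequential posterior-predictive terms sum to the exact log marginal likelihood, and that replacing each term by the expected log likelihood under the current posterior (the Jensen step) gives a smaller value. The function names, hyperparameters, and synthetic data are illustrative assumptions, not the paper's experimental code.

```python
import numpy as np

def posterior(X, y, prior_var, noise_var):
    """Gaussian posterior over weights under the prior N(0, prior_var * I)."""
    d = X.shape[1]
    precision = np.eye(d) / prior_var + X.T @ X / noise_var
    cov = np.linalg.inv(precision)
    mean = cov @ (X.T @ y) / noise_var
    return mean, cov

def log_gauss(y, mean, var):
    """Log density of a univariate Gaussian."""
    return -0.5 * (np.log(2 * np.pi * var) + (y - mean) ** 2 / var)

def sequential_terms(X, y, prior_var, noise_var):
    """Sum of log predictive terms and the corresponding Jensen lower bound."""
    log_ml, lower = 0.0, 0.0
    for i in range(len(y)):
        # Posterior given the first i points (the prior when i == 0).
        mean, cov = posterior(X[:i], y[:i], prior_var, noise_var)
        pred_mean = X[i] @ mean
        pred_var = X[i] @ cov @ X[i]
        # Exact term: log of the posterior predictive density of y[i].
        log_ml += log_gauss(y[i], pred_mean, pred_var + noise_var)
        # Jensen term: expectation over the posterior of the log likelihood.
        lower += log_gauss(y[i], pred_mean, noise_var) - 0.5 * pred_var / noise_var
    return log_ml, lower

def exact_log_ml(X, y, prior_var, noise_var):
    """Closed-form log marginal likelihood, y ~ N(0, prior_var X X^T + noise_var I)."""
    n = len(y)
    K = prior_var * X @ X.T + noise_var * np.eye(n)
    _, logdet = np.linalg.slogdet(K)
    return -0.5 * (y @ np.linalg.solve(K, y) + logdet + n * np.log(2 * np.pi))

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 4))
y = X @ rng.normal(size=4) + 0.5 * rng.normal(size=30)
log_ml, lower = sequential_terms(X, y, prior_var=1.0, noise_var=0.25)
exact = exact_log_ml(X, y, prior_var=1.0, noise_var=0.25)
```

In this sketch the gap between `log_ml` and `lower` comes entirely from the posterior predictive variance, which shrinks as more data points are conditioned on.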

See 3.2


To show that the sum of the estimated log likelihoods is a lower bound on the log marginal likelihood, it suffices to show that each term in the sum of the estimates is a lower bound on the corresponding term in log marginal likelihood expression. Thus, without loss of generality we consider a single data point and posterior distribution .

Let , the standard estimators for sample mean and variance given sample sampled from . We want to show


We first note that for a collection of i.i.d. Gaussian random variables [basu1955]. We also take advantage of the fact that the log likelihood of a Gaussian is concave with respect to its parameter and its parameter. Notably, the log likelihood is not concave w.r.t. the joint pair , but because our estimators are independent, this will not be a problem for us. We proceed as follows by first decomposing the expectation over the samples into an expectation over and

We apply Jensen’s inequality first to the inner expectation, then to the outer.

So we obtain our lower bound. ∎

See 3.3


The heavy lifting for this result has largely been achieved by Propositions 3.1 and 3.2, which state that provided the samples are distributed according to the posterior, the inequalities will hold. It therefore remains only to show that the sample-then-optimize procedure yields samples from the posterior. The proof of this result can be found in Lemma 3.8 of osband2018randomized, who show that the optimum for the gradient descent procedure described in Algorithm 1 does indeed correspond to the posterior distribution for each subset .

Finally, it is straightforward to express the lower bound estimator as the sum of regression losses. We obtain this result by showing that the inequality holds for each term in the summation.


We note that in practice, the solutions found by gradient descent for finite step size and finite number of steps will not necessarily correspond to the exact local optimum. However, it is straightforward to bound the error obtained from this approximate sampling in terms of the distance of from the optimum . Denoting the difference by , we get


and so the error in the estimate of will be proportional to the distance induced by the approximate optimization procedure. ∎
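The claim that sample-then-optimize yields posterior samples can be checked numerically. The sketch below assumes the standard randomized-prior form of the objective for Bayesian linear regression (perturbed targets plus a regularizer anchored at a prior sample); this form, the function name, and the hyperparameters are assumptions for illustration. It also swaps in the closed-form minimizer of the objective in place of the gradient descent procedure of Algorithm 1, so it illustrates the exact-optimum case only.

```python
import numpy as np

def sample_then_optimize(X, y, prior_var, noise_var, rng):
    """Draw one weight sample by optimizing a randomly perturbed objective."""
    d = X.shape[1]
    # Random anchor drawn from the prior N(0, prior_var * I).
    theta0 = rng.normal(scale=np.sqrt(prior_var), size=d)
    # Targets perturbed with observation noise.
    y_tilde = y + rng.normal(scale=np.sqrt(noise_var), size=len(y))
    # Exact minimizer of
    #   ||X theta - y_tilde||^2 / noise_var + ||theta - theta0||^2 / prior_var.
    A = X.T @ X / noise_var + np.eye(d) / prior_var
    b = X.T @ y_tilde / noise_var + theta0 / prior_var
    return np.linalg.solve(A, b)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.normal(size=20)
prior_var, noise_var = 1.0, 0.25

samples = np.array([
    sample_then_optimize(X, y, prior_var, noise_var, rng) for _ in range(5000)
])

# Analytic posterior for comparison.
post_cov = np.linalg.inv(X.T @ X / noise_var + np.eye(3) / prior_var)
post_mean = post_cov @ X.T @ y / noise_var
```

The empirical mean and covariance of `samples` should approach `post_mean` and `post_cov` as the number of draws grows.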

See 3.4


Follows immediately from the results of he2020bayesian stating that the limiting distribution of is precisely . We therefore obtain the same result as for Theorem 3.3, plugging in the kernel gradient descent procedure on for the parameter-space gradient descent procedure on . ∎

The following Lemma will be useful in order to prove Proposition 3.5. Intuitively, this result states that in a linear regression problem in which each feature is ‘normalized’ (the dot product for some and all ) and ‘independent’ (i.e. ), then the optimal linear regression solution assigns highest weight to the feature which obtains the least error in predicting on its own.

Lemma A.1.

Let , and be a design matrix such that for some fixed , with , and for all . Let be the solution to the least squares regression problem on and . Then


We express the minimization problem as follows. We let = , where , with . We denote by

the vector containing all ones (of length

). We observe that we can decompose the design matrix into one component whose columns are parallel to , denoted , and one component whose columns are orthogonal to , denoted . Let . By assumption, , and . We then observe the following decomposition of the squared error loss of a weight vector , denoted .

In particular, the loss decomposes into a term which depends on the sum of the , and another term which will depend on the norm of the component of each model’s predictions orthogonal to the targets .

As this is a quadratic optimization problem, it is clear that an optimal exists, and so will take some finite value, say . We will show that for any fixed , the solution to the minimization problem


is such that the argmax over of is equal to that of the minimum variance. This follows by applying the method of Lagrange multipliers to obtain that the optimal satisfies


In particular, is inversely proportional to the variance of , and so is maximized for .
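The conclusion of the lemma can be illustrated with a small numerical sketch. It assumes a particular construction that satisfies the lemma's conditions: every feature equals the target plus a noise vector that is orthogonal to the target and to all other noise vectors. The variable names and noise scales are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 50, 4

# Orthonormal directions: the first carries the target, the rest carry noise.
Q, _ = np.linalg.qr(rng.normal(size=(n, k + 1)))
y = 3.0 * Q[:, 0]

# Feature i is y plus noise of norm noise_scale[i], orthogonal to y and to
# the noise of every other feature.
noise_scale = np.array([0.5, 2.0, 1.0, 1.5])
Phi = y[:, None] + Q[:, 1:] * noise_scale

# Least squares solution of Phi w = y.
w, *_ = np.linalg.lstsq(Phi, y, rcond=None)

# The product w_i * noise_scale_i^2 is constant across features, so the
# weights are inversely proportional to the squared noise norm.
scaled = w * noise_scale ** 2
```

Under this construction the largest entry of `w` belongs to the feature with the smallest noise norm, which is also the feature with the least error in predicting the target on its own.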

See 3.5


We first clarify the independence assumptions as they pertain to the assumptions of the previous lemma: writing as with corresponding to the noise from the posterior distribution and its mean, the first independence assumption is equivalent to the requirement that with for all . The second independence assumption is an intuitive expression of the constraint that in the linear-algebraic sense of independence, and that is sampled independently (in the probabilistic sense) for all and .

We note that our lower bound for each model in the linear regression setting is equal to where is a fixed normalizing constant. By the previous Lemma, we know that the linear regression solution based on the posterior means satisfies, . It is then straightforward to extend this result to the noisy setting.


We again note via the same reasoning as in the previous Lemma that the model with the greatest lower bound will be the one which minimizes , and that the weight given to index will be inversely proportional to this term.

It only remains to show that for each model , the model which maximizes will also minimize . This follows precisely from the Gaussian likelihood assumption. As we showed previously


and so finding the model which maximizes is equivalent to picking the maximal index of which optimizes the expected loss of the least squares regression problem. ∎

Appendix B Experiments

B.1 Experimental details: Model Selection using Trajectory Statistics

Figure 5: Relative rankings given by optimize-then-prune, ML, and estimated . Left: feature selection. Middle: prior variance selection. Right: RFF frequency selection. Rankings are consistent with what our theoretical results predict. Results are averaged over runs.

We consider 3 model selection settings in which to evaluate the practical performance of our estimators. In prior variance selection we evaluate a set of BLR models on a synthetic linear regression data set. Each model has a prior distribution over the parameters of the form for some , and the goal is to select the optimal prior variance (in other words, the optimal regularization coefficient). We additionally evaluate an analogous initialization variance selection method on an NTK network trained on a toy regression dataset. In frequency (lengthscale) selection we use as input a subset of the handwritten digits dataset MNIST given by all inputs labeled with a 0 or a 1. We compute random Fourier features (RFF) of the input to obtain the features for a Bayesian linear regression model, and perform model selection over the frequency of the features (full details on this in the appendix). This is equivalent to obtaining the lengthscale of an approximate radial basis function kernel. In feature dimension selection, we use a synthetic dataset [wilson2020bayesian] of the form , where . We then consider a set of models with feature embeddings . The optimal model in this setting is the one which uses exactly the set of ‘informative’ features .

The synthetic data simulation used in this experiment is identical to that used in [wilson2020bayesian]. Below, we provide the details.

Let be the number of informative features and the total number of features. We generate a datapoint as follows:

  1. Sample :

  2. Sample informative features:

  3. Sample noise features:

  4. Concatenate the features:

We set , , , and let vary from to . We then run our estimators on the Bayesian linear regression problem for each feature dimension, and find that all estimators agree on the optimal number of features, .

To compute the random Fourier features used for MNIST classification, we vectorize the MNIST input images and follow the procedure outlined by [rahimi2008random] (Algorithm 1) to produce RFF features, which are then used for standard Bayesian linear regression against the binarized labels. The frequency parameter (which can also be interpreted as a transformation of the lengthscale of the RBF kernel approximated by the RFF model) is the parameter of interest for model selection. Results can be found in Figure
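A minimal sketch of the cosine form of random Fourier features is given below. It assumes that the frequency parameter acts as the standard deviation of the Gaussian projection weights, so that the approximated RBF kernel has lengthscale equal to the inverse frequency; the exact parameterization used in the experiments is not stated here, and the function name and inputs are illustrative.

```python
import numpy as np

def random_fourier_features(X, num_features, frequency, rng):
    """Map inputs to features whose inner products approximate an RBF kernel."""
    d = X.shape[1]
    # Projection directions drawn from N(0, frequency^2 * I).
    W = rng.normal(scale=frequency, size=(d, num_features))
    # Random phases drawn uniformly from [0, 2*pi).
    b = rng.uniform(0.0, 2.0 * np.pi, size=num_features)
    return np.sqrt(2.0 / num_features) * np.cos(X @ W + b)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
frequency = 0.7

Z = random_fourier_features(X, num_features=20000, frequency=frequency, rng=rng)
approx_kernel = Z @ Z.T

# RBF kernel with lengthscale 1 / frequency, for comparison.
sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
exact_kernel = np.exp(-0.5 * frequency ** 2 * sq_dists)
```

Model selection over `frequency` then amounts to comparing Bayesian linear regression models fitted on `Z` for different values of this parameter.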


We additionally provide an analogue to our evaluation of model selection in NTK-GPs, with the change in the log marginal likelihood plotted instead of . We obtain analogous results, as can be seen in Figure 6.

Figure 6: Evaluation of change in log ML after data point is added for NTK-GPs on a random subset of MNIST.

B.2 Experimental details: Bayesian model comparison

Here we provide further detail of the experiment in Section 4.2.1. The goal of the experiment is to determine whether the connection between sum-over-training losses (SOTL) and model evidence observed in the linear regression setting extends to DNNs. In particular, the two sub-questions are:

  1. Do models with a lower SOTL generalize better?

  2. Are these models favoured by SGD?

To answer these questions, we train a linear combination of NNs. We can answer subquestion [1] by plotting the correlation between SOTL and test performance of an individual model. Further, we address subquestion [2] by considering the correlation between test loss and linear weights assigned to each model.

Below we explain the set-up of the linear combination in more detail. We train a variety of deep neural networks along with a linear ‘ensemble’ layer that performs a linear transformation of the concatenated logit outputs of the classification models. (These are pre-softmax outputs; to obtain the predicted probability of a class, they are fed through a softmax function.) Let be the logit output of model for input , be the loss for point (where is a logit) and be the weight corresponding to model at time step .

We consider two training strategies. In the first, we train the models individually using the cross-entropy loss between each model’s prediction and the true label, and use only the cross-entropy loss of the final ensemble prediction to train the linear weights. Mathematically, we update the models using the gradients


and the ‘ensemble’ weights using


We refer to this training scheme as Parallel Training, as the models are trained in parallel. We also consider the setting in which the models are trained using the cross-entropy loss from the ensemble prediction backpropagated through the linear ensemble layer, i.e. the model parameters are now updated using:


We refer to this scheme as Concurrent Training.

We train a variety of different MLPs (with varying layers and nodes) and convolutional neural networks (with varying layers, nodes and kernels) on FashionMNIST using SGD until convergence.
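The difference between the two schemes lies in which loss supplies the gradient for each model's parameters. The sketch below makes this explicit with hand-written gradients. It is a simplification under stated assumptions: the "models" are linear softmax classifiers on hypothetical feature subsets rather than MLPs and CNNs, each model receives a single scalar ensemble weight, full-batch gradient descent replaces SGD, and the data are synthetic.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(logits, Y):
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -np.mean(np.sum(Y * log_p, axis=1))

def train(X, Y, feature_sets, scheme, steps=300, lr=0.1):
    """Train linear models plus scalar ensemble weights under one scheme."""
    n, num_classes = Y.shape
    Ws = [np.zeros((len(cols), num_classes)) for cols in feature_sets]
    a = np.full(len(feature_sets), 1.0 / len(feature_sets))
    for _ in range(steps):
        logits = [X[:, cols] @ W for cols, W in zip(feature_sets, Ws)]
        ens_logits = sum(a_m * f for a_m, f in zip(a, logits))
        # Gradient of the ensemble cross-entropy w.r.t. the ensemble logits.
        g_ens = (softmax(ens_logits) - Y) / n
        # Ensemble weights always follow the ensemble loss.
        grad_a = np.array([np.sum(g_ens * f) for f in logits])
        grad_Ws = []
        for m, cols in enumerate(feature_sets):
            if scheme == "parallel":
                # Each model follows its own cross-entropy loss.
                g = (softmax(logits[m]) - Y) / n
            else:
                # Concurrent: ensemble loss backpropagated through a[m].
                g = a[m] * g_ens
            grad_Ws.append(X[:, cols].T @ g)
        a = a - lr * grad_a
        Ws = [W - lr * gW for W, gW in zip(Ws, grad_Ws)]
    logits = [X[:, cols] @ W for cols, W in zip(feature_sets, Ws)]
    ens_logits = sum(a_m * f for a_m, f in zip(a, logits))
    return a, cross_entropy(ens_logits, Y)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
labels = np.argmax(X @ rng.normal(size=(5, 3)), axis=1)
Y = np.eye(3)[labels]
feature_sets = [[0, 1, 2], [2, 3, 4]]  # hypothetical inputs for each model

a_parallel, loss_parallel = train(X, Y, feature_sets, "parallel")
a_concurrent, loss_concurrent = train(X, Y, feature_sets, "concurrent")
```

In the concurrent branch a model's update is scaled by its current ensemble weight, so models that gain weight early also receive larger updates.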

B.3 Experimental Details: SGD upweights submodels that perform well

Below we provide further details of the experiment in Section 4.2.2. The goal of the experiment is to determine whether SGD upweights sub-models that fit the data better.

We train an MLP network (with units ) on FashionMNIST using SGD until convergence. After training is completed, for every class of , we rank all nodes in the penultimate layer by the norm of their absolute weight (in the final dense layer). We group the nodes into submodels according to their ranking: the nodes with the highest weights are grouped together, next the ranked nodes are grouped, etc. We set .

We determine the performance of a submodel by training a simple logistic classifier to predict the class of an input, based on the output of the submodel. To measure the performance of the classifier, we use the cross-entropy loss. To capture the equivalent notion of the AUC, we estimate the performance of the submodels throughout training, and sum over the estimated cross-entropy losses.
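The ranking and grouping step can be sketched as follows. The sketch assumes the final dense layer's weights are stored as an array of shape (penultimate nodes, classes); the function names, the toy weights, and the group size are hypothetical, and the logistic classifier fitted on each submodel's output is omitted.

```python
import numpy as np

def group_nodes_by_weight(final_layer_weights, class_index, group_size):
    """Rank penultimate-layer nodes by |weight| for one class and chunk them."""
    magnitudes = np.abs(final_layer_weights[:, class_index])
    ranking = np.argsort(-magnitudes)  # largest magnitude first
    return [
        ranking[i:i + group_size].tolist()
        for i in range(0, len(ranking), group_size)
    ]

def sum_over_training_losses(losses_per_checkpoint):
    """Sum a submodel's cross-entropy losses recorded throughout training."""
    return float(np.sum(losses_per_checkpoint))

# Toy final layer with 4 penultimate nodes and 2 classes.
weights = np.array([[0.1, 0.3],
                    [-0.9, 0.2],
                    [0.5, -0.7],
                    [-0.2, 0.0]])
groups = group_nodes_by_weight(weights, class_index=0, group_size=2)
sotl = sum_over_training_losses([1.0, 0.5, 0.25])
```

Each entry of `groups` lists the node indices of one submodel, ordered from the highest-weight group to the lowest.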

Below, we show additional plots for the parallel and concurrent training schemes. The results are the same as those presented in the main text, and we observe [1] a negative correlation between test performance and ensemble weights and [2] a strong correlation between SOTL and average test cross-entropy.

Figure 7: Linear combinations of DNNs on FashionMNIST. Left: ensemble weights versus the test loss for parallel training; we observe a negative correlation. Middle: SOTL (standardized by the number of training samples) versus test loss for parallel and concurrent training. We observe a strong correlation indicating that the SOTL generalizes well. Right: training curves for the different models in concurrent training schemes. All results are averaged over runs, and standard deviations are shown by the shaded regions around each observation. The model parameters, given in the parentheses, are the number of layers (), nodes per layer () and kernel size (), respectively.

However, similarly to the linear setting, the difference in assigned weights is magnified in the concurrent training scheme. Here we find that in the concurrent training scheme, the ensemble focuses on training the CNNs (as can be seen from the training curve in Figure 3 in the main text). This is likely because CNNs are able to learn more easily, leading to larger weights earlier on.

Above, we show additional plots to those shown in Figure 4, Section 4.2.2. Figure 8 shows the results for all FashionMNIST classes, and Figure 9 shows the results for the experiment on CIFAR-10. From both, we see that SGD assigns higher weights to subnetworks that perform better.

Figure 8: Weight assigned to subnetwork by SGD in a deep neural network (x-axis) versus the subnetwork performance (estimated by the sum of cross-entropy, on the y-axis) for different FashionMNIST classes. The light blue ovals depict confidence intervals, estimated over 10 seeds (i.e. 2 for both the weight and SOTL). The orange line depicts the general trend.
Figure 9: Weight assigned to subnetwork by SGD in a deep neural network (x-axis) versus the subnetwork performance (estimated by the sum of cross-entropy, on the y-axis) for different CIFAR-10 classes. The light blue ovals depict confidence intervals, estimated over 10 seeds (i.e. 2 for both the weight and SOTL). The orange line depicts the general trend.