While modern neural network-based machine learning architectures have yielded a step change in prediction accuracy over traditional methods, especially in application domains such as computer vision, speech recognition and natural language processing, accurately estimating uncertainty remains an issue. Providing accurate estimates of the uncertainty associated with each prediction is of critical importance, particularly in the healthcare domain. Classical Bayesian inference methods (wang2020survey; mackay1992bayesian; neal2012bayesian; murphy2012machine) address the uncertainty question, but they typically do not scale with architecture and data set size and so resort to computationally tractable approximations via parametrisation.
This paper introduces a framework for uncertainty estimation that provides a novel and simple way to describe and extend many uncertainty estimation methods. The main idea relies on considering hyperparameters involved in classical training procedures as random variables. We show that through marginalising out different combinations of hyperparameters, different sources of uncertainty can be estimated in the (model) parameter space. Within this framework, methods such as SWAG (maddox_simple_2019), deep ensembles (lakshminarayanan2017simple), MultiSWAG (wilson_bayesian_2020), MC-dropout (gal2016dropout) and hyperparameter ensembles (levesque2016bayesian; wenzel2020hyperparameter) are shown to be approximations to particular marginalisations.
We apply methods resulting from this framework on benchmarking problems to investigate which forms and combinations of marginalisation are most useful from a practical point of view. As one may expect, results suggest that increasing the number of random variables that are marginalised out tends to increase the quality of the uncertainty estimates, with certain random variables playing a more evident role. In particular, some combinations of marginalisations can produce reliable estimates of uncertainty without the need for extensive hyperparameter tuning. However, this is not the case with all combinations.
2 Finding uncertainties
Given a set of data points , one usually wishes to determine the predictive distribution for some new value
. As this problem is generally intractable, it is common practice to assume the predictive distribution takes on a certain functional form described by a parametric model for some parameter space , and to determine
where the value of , given and (and consequently ) is deterministic and is the distribution over that likely gave rise to .
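For orientation, the standard parametric form that this paragraph describes can be sketched as below. The symbols are assumptions introduced here for illustration and need not match the paper's notation: a data set \mathcal{D}, a new input x with output y, and parameters \theta in a parameter space \Theta.

```latex
% Assumed notation: \mathcal{D} data set, (x, y) new input/output pair,
% \theta \in \Theta model parameters.
p(y \mid x, \mathcal{D})
  = \int_{\Theta} p(y \mid x, \theta)\, p(\theta \mid \mathcal{D})\, \mathrm{d}\theta
```

Under this reading, the first factor is fixed once the input and the parameters are given, and all remaining uncertainty sits in the distribution over parameters.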
For modern architectures and data sets, computing
is computationally intractable. This has given rise to a number of proposed approximate methods. In the next sections, we show that many of these methods can be realised within our framework as approximating the conditional probability distribution of given the data set and additional random variables that capture different aspects of the distribution, a graphical representation of which is depicted in Figure 1.
2.1 A starting point: classical training
In the classical (deterministic) modern neural network training paradigm, the training goal is to find the optimal model parameter that minimises some notion of difference between and , usually referred to as loss , by solving
Given , classical training methods can be seen as approximating . Determining an exact solution to (2) is computationally intractable for most problems arising in machine learning due to factors such as data dimensionality and size, non-convexity of the problem, etc. The aim then is to find a suboptimal solution via an iterative algorithm from the countably infinite set of optimisation algorithms , with chosen hyperparameter point in an algorithm-appropriate hyperparameter space . (Footnote 1: Hereafter, when talking about hyperparameters, we refer only to those characterising the algorithm .)
Let be the solution estimate at iteration of . The sequence produced by iterations of is determined by the initial condition and the hyperparameters , e.g., step-size, batch-size and batch process order. With a slight abuse of notation, we consider as a function and write as to represent compositions (iterations) of .
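The view of training as repeated composition of a single update map can be illustrated with a minimal sketch. It assumes a one-parameter least-squares problem, and the names `train`, `order_seed` and the toy data are hypothetical choices made for this example only. Once the initial condition, the hyperparameters and the number of iterations are fixed, the returned trajectory is fully determined.

```python
import random

def grad(theta, batch):
    # Gradient of the mean of 0.5 * (theta * x - y)^2 over the batch.
    return sum((theta * x - y) * x for x, y in batch) / len(batch)

def train(data, theta0, step_size, batch_size, order_seed, num_iters):
    """Apply the SGD update num_iters times and return the whole trajectory."""
    rng = random.Random(order_seed)  # fixes which batches are drawn, and in what order
    theta, trajectory = theta0, [theta0]
    for _ in range(num_iters):
        batch = rng.sample(data, batch_size)
        theta = theta - step_size * grad(theta, batch)
        trajectory.append(theta)
    return trajectory

# Noise-free toy data from y = 2x, so the loss minimiser is theta = 2.
data = [(k / 10.0, 2.0 * k / 10.0) for k in range(1, 21)]
run_a = train(data, theta0=0.0, step_size=0.1, batch_size=4, order_seed=0, num_iters=200)
run_b = train(data, theta0=0.0, step_size=0.1, batch_size=4, order_seed=0, num_iters=200)
```

Here `run_a` and `run_b` coincide iterate by iterate, whereas changing any one of the arguments generally changes the trajectory. This dependence is what the marginalisations in the following sections exploit.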
It should now be clear that since problem (2) is generally intractable, rather than simply making the approximation , classical training computes the value of further conditioned on , and , i.e.,
Then the output of a trained model is characterised as
2.2 Intrinsic uncertainties in classical training
By marginalising out the different variables and possibly the model family , we should obtain estimates for that are better than the point mass distribution (4). In the next sections, we will show that this is indeed the case and that many of the current state-of-the-art algorithms for uncertainty estimation can be considered as estimating (or ) by marginalising out different subsets of the conditioning variables in (3) (or (4)).
2.2.1 Number of iterations
Under suitable assumptions, algorithms in usually generate a sequence of solution estimates that, in the limit, converges to a neighbourhood of a stationary point . In practice, convergence to this neighbourhood usually occurs after a certain number of iterations . While in general there are no guarantees regarding the position of in relation to within the neighbourhood, the trajectory of the solution estimates generated by the algorithm provides information on the geometry of the parameter space near
which in turn induces a probability distribution over (see Figure 1, plot in box (a)). Formally, we have
which for stochastic gradient descent algorithms can be shown to be Gaussian in the limit. Note that this formalism nicely describes the SWAG algorithm (maddox_simple_2019), which assumes
is a Gaussian distribution and computes the first- and second-order moments by taking finite samples from , and the TRADI algorithm (franchi2019tradi), which again makes the Gaussian assumption and tracks the moments of
along the algorithm evolution through a Kalman filter.
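The moment-matching step described above can be sketched as follows. This is a simplified illustration and not the reference SWAG implementation: it keeps only a diagonal covariance and omits the low-rank term of the full method. The iterates are synthetic stand-ins for late-stage SGD iterates, and the function names are hypothetical.

```python
import random

def swag_diagonal(iterates):
    """Fit a diagonal Gaussian to parameter vectors via first and second moments."""
    n, dim = len(iterates), len(iterates[0])
    mean = [sum(it[d] for it in iterates) / n for d in range(dim)]
    second = [sum(it[d] ** 2 for it in iterates) / n for d in range(dim)]
    # Clip at zero to guard against tiny negative values from rounding.
    var = [max(second[d] - mean[d] ** 2, 0.0) for d in range(dim)]
    return mean, var

def sample_swag(mean, var, rng):
    """Draw one parameter vector from the fitted diagonal Gaussian."""
    return [rng.gauss(m, v ** 0.5) for m, v in zip(mean, var)]

# Synthetic iterates jittering around the point (1.0, -2.0).
rng = random.Random(0)
iterates = [[1.0 + rng.gauss(0, 0.1), -2.0 + rng.gauss(0, 0.05)] for _ in range(500)]
mean, var = swag_diagonal(iterates)
theta_sample = sample_swag(mean, var, rng)
```

Each call to `sample_swag` yields one plausible parameter vector near the stationary point, which can then be pushed through the model to obtain a prediction.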
2.2.2 Initial condition
Due to the likely non-convexity of the loss function, the initial condition may determine the stationary point around which
stabilises. Thus, the initial condition induces a probability distribution over that can be computed by marginalising (3) with respect to , as
is chosen according to an appropriate parameter initialisation strategy for the model architecture such as the Glorot uniform distribution (pmlr-v9-glorot10a). See Figure 1, plot in box (a) for a graphical representation.
Approximating (5) is clearly the goal of classical ensemble methods for uncertainty estimation (see, e.g., lakshminarayanan2017simple).
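A minimal sketch of marginalising over the initial condition with an ensemble is given below. It assumes a one-dimensional non-convex toy loss with two minima, and a plain uniform initialisation stands in for a Glorot-style scheme. The bound `limit` and the ensemble size are arbitrary choices for illustration.

```python
import random

def grad(theta):
    # Gradient of the non-convex toy loss (theta^2 - 1)^2, with minima at -1 and +1.
    return 4.0 * theta * (theta ** 2 - 1.0)

def train(theta0, step_size=0.05, num_iters=500):
    """Plain gradient descent from the given initial condition."""
    theta = theta0
    for _ in range(num_iters):
        theta -= step_size * grad(theta)
    return theta

rng = random.Random(0)
limit = 1.5  # assumed bound of the uniform initialisation distribution
ensemble = [train(rng.uniform(-limit, limit)) for _ in range(10)]
```

Every ensemble member ends at one of the two minima, and which one is reached depends only on the initial condition. The collection of end points is a Monte Carlo approximation of the distribution induced by the initialisation.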
2.2.3 Algorithm and hyperparameters
As mentioned, in general, different algorithm procedures and/or different points in the appropriate hyperparameter space produce different approximate solutions to problem (2). For example, in the case of simple Stochastic Gradient Descent (SGD), the step size, batch size and the order in which batches are processed by the algorithm all influence the trajectory . This is also true for more complex algorithms where other hyperparameters may play a role. To capture the uncertainty associated with the choice of algorithm, we can marginalise out from to give
In Appendix B we characterise for the SGD and for the general class of stochastic algorithms. A graphical representation is reported in Figure 1, plot in box (a), where different choices of lead to different values of .
Ensembling over hyperparameter space to improve model performance and uncertainty estimation has been proposed in (levesque2016bayesian) and (wenzel2020hyperparameter).
2.2.4 Model family
Recall that assumptions are made about the functional form of the predictive distribution to enable its computation. With this in mind, so far, we have presented our framework in the context of a fixed parametric model family over some parameter space . The particular choice of model family typically involves architecture and hyperparameter space selection and is usually informed by the problem domain and the available data set . As different parametric model families yield different estimates of the predictive distribution, the choice of model family is also a source of uncertainty which one can account for by marginalising out to give
where is some unknown distribution over the space of all possible parametric model families over all possible parameter spaces. In practice, when deciding on a parametric model family, a typical approach is to focus on a finite subset of possible modelling approaches informed by a priori knowledge of the problem domain and then choose the best family or families via a model selection procedure (see, e.g., (murphy2012machine)). Another approach is to use Monte Carlo dropout (gal2016dropout), which ensembles models obtained via edge dropout (see Appendix C). Figure 1, plot in box (a), pictorially shows how selecting different models (which in turn produces different loss landscapes) may lead to different solutions.
2.3 Multiple marginalisation
Until now, we have focussed on marginalising out single conditioning random variables to obtain better approximations to . A natural step at this point is to combine two or more of the proposed marginalisations.
In Figure 1, boxes (b)-(d), all the possible combinations of marginalisations in our framework are depicted, some of which have recently been studied. For example, MultiSWAG (wilson_bayesian_2020) is an extension to SWAG that can be seen as an approximation to marginalising out and to give
This is depicted in Figure 1, box (b), plot . Another example is hyper-ensembles (wenzel2020hyperparameter) which can be seen as marginalising out both and the hyperparameter space (Figure 1, box (b), plot ).
In general, this framework allows a number of marginalisation combinations that have not been addressed in the literature thus far. Depending on the particular problem at hand, different combinations may lead to improved uncertainty estimates.
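As an illustration of combining two marginalisations, the sketch below samples parameters from an equally weighted mixture of diagonal Gaussians, which is one way to read the MultiSWAG construction mentioned above. The component means and variances are hypothetical values standing in for SWAG fits obtained from different initial conditions.

```python
import random

def sample_multiswag(components, rng):
    """Draw one parameter vector from an equally weighted mixture of diagonal Gaussians."""
    mean, var = rng.choice(components)  # pick one training run uniformly at random
    return [rng.gauss(m, v ** 0.5) for m, v in zip(mean, var)]

# Hypothetical (mean, variance) pairs from three runs with different initial conditions.
components = [([1.0], [0.01]), ([-1.0], [0.01]), ([1.1], [0.02])]
rng = random.Random(0)
samples = [sample_multiswag(components, rng) for _ in range(1000)]
```

The outer choice of component covers the initial condition and the inner Gaussian draw covers the position along the trajectory, so the samples spread over several basins rather than a single one.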
3 Estimating uncertainty
Until now, we have focussed on estimating uncertainty in the parameter solution space. By using (1), that uncertainty can be propagated to the predictive distribution.
However, computing it explicitly is often intractable and approximations are preferred. Monte Carlo sampling is arguably the most widely used approach to approximating as
where is the number of samples drawn from the trained model parameter space and . As increases, the accuracy of
also increases. Similarly, the mean and variance of the predictive distribution can be approximated as and , with .
Assumed Density Filtering (ADF) (ghosh2016assumed) can also be used to compute the expected value and variance of in a single forward pass (see Appendix D for further details).
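The Monte Carlo estimate of the predictive mean and variance described in this section can be sketched as follows. The linear `predict` function and the Gaussian parameter samples are assumptions made for illustration; in practice the samples would come from one of the marginalisation strategies of Section 2. The variance here is the population variance of the sampled model outputs.

```python
import random
from statistics import fmean, pvariance

def predict(x, theta):
    # Hypothetical model: a line with slope theta[0] and intercept theta[1].
    return theta[0] * x + theta[1]

def mc_predictive(x, theta_samples):
    """Mean and variance of the model output over the sampled parameters."""
    outputs = [predict(x, theta) for theta in theta_samples]
    return fmean(outputs), pvariance(outputs)

rng = random.Random(0)
theta_samples = [[rng.gauss(2.0, 0.1), rng.gauss(0.5, 0.05)] for _ in range(2000)]
mean, var = mc_predictive(3.0, theta_samples)
```

Each parameter sample costs one forward pass, which is the overhead that the single-pass ADF alternative aims to avoid.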
4 Experimental results
In this section, we apply methods described by the proposed framework on benchmarking problems to investigate which forms and combinations of marginalisation are most useful from a practical point of view. We consider problems on benchmark UCI datasets and CIFAR-10. Additional details and experiments are provided in Appendix A.
4.1 Regression on UCI datasets
We consider regression problems on UCI data sets, originally proposed in (hernandez2015probabilistic) and used for benchmarking MC-dropout (gal2016dropout), SWAG (maddox_simple_2019) and deep ensembles (lakshminarayanan2017simple). Each data set is split into 20 train-test folds, except for protein, where 5 folds were used. The network architecture had one hidden layer consisting of units ( for the protein dataset), followed by a dropout layer with dropout rate fixed at . Each model was trained for epochs to minimise the Mean Squared Error with no regularisation, using Adam with a learning rate and batch-size .
We used SWAG, MC-dropout and -model ensembles to approximate the marginalisation of , (as shown in the appendix) and respectively. Results on all marginalisation combinations are reported in Table 1 in terms of negative log-likelihood (NLL). The tendency is for NLL to decrease (and the quality of the uncertainty estimate to improve) as more random variables are marginalised out, with combined , and producing the lowest NLL in 6 of 9 data sets.
4.2 Classification on CIFAR-10
We train VGG16 neural networks for epochs over CIFAR-10 using SGD with batch-size . We approximated the marginalisation over and using SWAG and MC-dropout respectively, while ensembles were used to marginalise over and learning rates (see Appendix A.3 for further details).
The NLL and accuracy on the test data set with increasing ensemble size are reported in Figure 2. As we increase the ensemble size, the NLL decreases and the accuracy increases, an outcome we know from (lakshminarayanan2017simple). However, one must note that even with a single model, marginalising out additional random variables , and produces a significantly lower NLL (and higher accuracy). In particular, while have a significant impact, other marginalisations seem to be less effective. This trend continues when we add more models to the ensemble.
From the preliminary experiments we performed, some trends appear. First, marginalising and consistently produced good results. This is probably because these two marginalisations better capture the geometry of the parameter space for a fixed model and set of hyperparameters. Marginalising out other hyperparameters can be useful (wenzel2020hyperparameter), but it probably requires some preliminary fine tuning to select a good distribution from which to draw them. A similar reasoning can be applied to the marginalisation of the model family, where a preliminary search for good candidates seems to be required. MC-dropout with a very small dropout rate can be a good option to minimally perturb a good architecture. If fine tuning is not an option, then marginalising and seems to be the best choice among the options available in this framework.
Appendix A Details and additional experiments
A.1 Regression on toy data
We first consider a simple 1-dimensional toy regression dataset consisting of training samples generated from with and , as in (hernandez2015probabilistic; blundell2015weight; franchi2019tradi), to evaluate the performance of different marginalisations. A neural network with one hidden layer consisting of
units with ReLU activation was trained for epochs to fit the data using the stochastic gradient descent algorithm. We compared the effects of approximating the marginalisation over and with hyperparameters . The marginalisation over was approximated using SWAG, while over , we used an ensemble of models. RMSE and NLL were calculated on a test set of 1000 equally spaced points , that lie on the black line.
The results for the various marginalisation strategies over the different hyperparameter combinations are depicted in Figure 3. The first and second rows correspond to models trained with hyperparameter points and respectively, while results in the third row are obtained by ensembles combining the two hyperparameter points. For the first, third and fifth columns, we ensembled over two models (one for each hyperparameter point), whereas the remaining columns, corresponding to marginalising out , are ensembles of 5 initial conditions, trained once with and once with .
Our results show that the quality of predictive uncertainty clearly benefits from multiple marginalisations, with a decreasing NLL trend along the rows and columns. Interesting points to note are that, regardless of the hyperparameter choice, often produces a lower NLL than , and that the combined hyperparameters (third row) outperform both individual hyperparameters (first and second rows). For RMSE, outperforms , with the two combined generally lying somewhere in between, but with the additional advantage of a lower NLL. Also worthy of note is the computation time, measured in terms of number of trained models. While marginalising out will outperform other random variable combinations given a sufficiently large ensemble, the point is that good estimates can be achieved with a limited computational budget, e.g., combined hyperparameters with SWAG and MC-dropout (third row, sixth column).
A.2 UCI datasets
A.2.1 Details on datasets
Recall that a given data set , with , , has number of samples , (input) random variable dimension and (output) random variable dimension . Table 2 provides additional details on the UCI data sets used for our experiments and the tuned learning rates we will use in the next section. As in (gal2016dropout) and other works using these UCI data sets, for energy and naval, we predict the first output variable so, .
A.2.2 Training for 40 epochs
Here we report the results obtained by training the networks for epochs with two different set-ups: firstly using Adam with a fixed learning rate for all data sets and secondly using the tuned learning rates reported in Table 2.
Fixed learning rate
The results for the first case are reported in Table 3. Clearly, the number of epochs is too low for MC-dropout to converge (see (gal2016dropout)), thus the NLL for is generally poor. However, when also marginalising out and/or (again approximated via SWAG and an ensemble of models respectively), the addition of MC-dropout generally results in lower NLL suggesting that even with few training epochs, MC-dropout may help when marginalised out in combination with other random variables.
Tuned learning rates
Results with tuned learning rates are reported in Table 4. We approximate jointly marginalising out and by drawing 5 samples from , one for each in the ensemble, and combine this marginalisation approximation with SWAG and MultiSWAG as required. While we propose that aggressive hyperparameter tuning can be avoided with multiple marginalisations, we do see an improvement in NLL when compared to Table 1, and for a number of data sets the NLL decreases further by perturbing the value of .
The trend of NLL decreasing as the number of marginalisation combinations increases is maintained.
Training for epochs with a (fixed) higher or tuned learning rate generally produces comparable or even better results than those reported in Table 1. However, this comes at the cost of a higher RMSE for most data sets, which is likely due to non-convergence after training for epochs and hence higher prediction error. In general, the training set-up very much depends on the desired use of a prediction. In domains such as disease diagnostics, understanding prediction certainty may be more valuable than accuracy.
Fine-tuning the learning rates leads to an improvement in NLL across all data sets. However, the addition of MC-dropout can produce better results without the need for hyperparameter fine tuning.
A.3 Implementation details for CIFAR-10
As in (maddox_simple_2019), the learning rate was defined as a decreasing schedule along epochs as
for some . When approximating the marginalisation over the learning rate , we randomly selected and when marginalising out , otherwise we fixed and .
Appendix B Algorithm marginalisation
We first provide a detailed description of marginalising out hyperparameters in the case of stochastic gradient descent. Then, we extend this idea to the (countably infinite) set of optimisation algorithms typically used in machine learning.
Stochastic gradient descent
Over a single iteration, the SGD is governed by the step size , the batch size and the batch number , where , which without loss of generality, corresponds to
In the context of our framework, the hyperparameter space is defined as
where points are drawn randomly with probability , and and are assumed to be well chosen distributions, for example, through hyperparameter optimisation. Note that is a continuous distribution whereas and are discrete. Since random variables and are independent, taking products of their respective marginal distributions is well defined.
To allow for scheduling of step size and batch size, over the course of a complete training run, the hyperparameter space over which we marginalise is defined as
where variable vectors are indexed over. Hyperparameter points are drawn randomly with probability
Note that at each iteration, we are compelled to select the next batch for processing uniformly at random to ensure that, in the limit, the gradient estimator is unbiased. If happens to be a multiple of , we can consider
as a number of epochs, each of which processes every batch exactly once. Thus the stochastic gradient is guaranteed to be an unbiased estimator which allows us to select a more appropriate distribution for batch ordering.
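Drawing one hyperparameter point for a complete SGD run, as described above, might look like the sketch below. The log-uniform step-size range and the candidate batch sizes are assumptions made for illustration, not the distributions used in the experiments. The batch order is built epoch by epoch so that every batch is processed exactly once per epoch.

```python
import random

def sample_hyperparameters(num_samples, num_epochs, rng):
    """Draw a step size, a batch size and a batch order for one training run."""
    step_size = 10 ** rng.uniform(-4, -1)  # continuous, log-uniform (assumed range)
    # Discrete choice among assumed candidates that divide the data set evenly.
    batch_size = rng.choice([b for b in (8, 16, 32, 64) if num_samples % b == 0])
    num_batches = num_samples // batch_size
    order = []
    for _ in range(num_epochs):
        epoch = list(range(num_batches))
        rng.shuffle(epoch)  # a uniformly random permutation of the batch indices
        order.extend(epoch)
    return step_size, batch_size, order

rng = random.Random(0)
step_size, batch_size, order = sample_hyperparameters(num_samples=640, num_epochs=3, rng=rng)
```

Repeating the draw with fresh random state and training once per draw gives an ensemble that approximates the marginalisation over the hyperparameter space.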
We can now marginalise out each of the hyperparameter variables to express as
The class of optimisation algorithms
Let be a (countably infinite) set of optimisation algorithms that, given an initial condition and a point in the appropriate hyperparameter space , return a solution to the given problem after iterations. Let be a weight associated to the output of algorithm and assume that
Then, if we define
the output of the -th algorithm is recovered by imposing (since, thanks to (7), for all ).
Let be the hyperparameter space associated to . We define and let . Then, for some , we can rewrite (8) as
Then, we marginalise out and to express as
Appendix C MC-dropout as model marginalisation
Given a particular and parametric model family , let be the parametric model family obtained by masking elements of . This is a finite set of models given by
where the dropout masks take values from and is the Hadamard product. Masks can be drawn randomly with probability , and we can marginalise out to give
and then to give
Similarly, we can condition on all the variables we have considered thus far to express as
The idea of marginalising out the entire class of parametric model families can be formulated in a similar way to the algorithm marginalisation strategy described in Appendix B.
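The mask-based construction of this appendix can be sketched for a single linear unit as below. Masks are drawn elementwise from a Bernoulli distribution and applied to the weights through a Hadamard product. The rescaling by the keep probability is the common inverted-dropout convention and is an assumption here, as are the function names and the example values.

```python
import random

def forward(x, weights, mask):
    # One linear unit whose weights are masked elementwise (Hadamard product).
    return sum(w * m * xi for w, m, xi in zip(weights, mask, x))

def mc_dropout_predict(x, weights, drop_rate, num_masks, rng):
    """Mean and variance of the output over randomly drawn binary masks."""
    keep = 1.0 - drop_rate
    outputs = []
    for _ in range(num_masks):
        mask = [1 if rng.random() < keep else 0 for _ in weights]
        outputs.append(forward(x, weights, mask) / keep)  # inverted-dropout rescaling (assumed)
    mean = sum(outputs) / num_masks
    var = sum((o - mean) ** 2 for o in outputs) / num_masks
    return mean, var

rng = random.Random(0)
mean, var = mc_dropout_predict([1.0] * 4, [1.0, 2.0, 3.0, 4.0],
                               drop_rate=0.1, num_masks=5000, rng=rng)
```

Each mask selects one member of the finite family of masked models, so averaging over masks is a Monte Carlo approximation of the marginalisation over that family.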
Appendix D One-shot estimation of the predictive probability distribution
In the main paper we showed how to compute statistics of the predictive probability distribution via Monte Carlo sampling. The main drawback of Monte Carlo sampling is the requirement to sample multiple times from the parameter distribution in order to compute , which in turn, requires multiple forward passes through the network.
To overcome this issue, Assumed Density Filtering (ADF) can be used to perform a one-shot estimation, where we compute the expected value and variance of in a single forward pass. ADF has previously been used to propagate input uncertainty through the network to provide output uncertainty (gast2018lightweight), and to learn probability distributions over parameters (ghosh2016assumed) (a similar approach has been used in (wu2018deterministic)). Also in (brach2020single), a one-shot approach has been proposed to approximate uncertainty for MC-dropout.
Let represent the output of the
-th layer of a feed-forward neural network (considered as a function between two random variables). We can convert the -th layer into an uncertainty propagation layer by simply matching first- and second-order central moments (minka2001family), i.e.,
where and . By doing so, the values of and are obtained as the output of the last layer of the modified neural network. Notice that this procedure can account for input (aleatoric) uncertainty as long as the input .
The main drawback with ADF is the reliance on a modification to the structure of the neural network, where each layer is replaced with its probabilistic equivalent. Also, while many probabilistic layers admit a closed form solution to compute (9), (10) (see below), in some cases approximation is needed.
D.1 Probabilistic layers
Let be independent random variables and consider a linear layer, . Then
In this case, if
are assumed to be normally distributed, can be approximated by a Gaussian with the above statistics.
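Moment propagation through a linear layer can be sketched as follows, assuming that all weight, input and bias entries are mutually independent. The inclusion of a bias term and the argument names are assumptions made for this example. The variance uses the identity Var[wx] = Var[w]Var[x] + Var[w]E[x]^2 + E[w]^2 Var[x] for independent w and x.

```python
def linear_moments(w_mean, w_var, x_mean, x_var, b_mean, b_var):
    """Mean and variance of y = W x + b with all entries of W, x and b independent."""
    y_mean, y_var = [], []
    for wm_row, wv_row, bm, bv in zip(w_mean, w_var, b_mean, b_var):
        mean = bm + sum(wm * xm for wm, xm in zip(wm_row, x_mean))
        var = bv + sum(
            wv * xv + wv * xm ** 2 + wm ** 2 * xv
            for wm, wv, xm, xv in zip(wm_row, wv_row, x_mean, x_var)
        )
        y_mean.append(mean)
        y_var.append(var)
    return y_mean, y_var

# One output unit, two inputs; the first weight is uncertain, the second is exact.
y_mean, y_var = linear_moments(
    w_mean=[[1.0, 2.0]], w_var=[[0.5, 0.0]],
    x_mean=[3.0, 4.0], x_var=[1.0, 0.25],
    b_mean=[0.0], b_var=[0.0],
)
```

Setting every entry of `w_var` to zero recovers the propagation of input uncertainty alone through a deterministic layer.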
Given a random variable
the ReLU non-linearity activation function leads to a closed form solution of its moments (frey1999variational).
where is the standard normal probability density function
and is the standard normal cumulative distribution function
Almost all types of layers have closed form solutions or can be approximated. See (Gast_2018_CVPR; loquercio2020uncertainty) for more details and (brach2020single) for dropout layers.
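For the ReLU layer discussed above, the closed-form moments for a Gaussian input can be sketched as below. The sketch uses the standard rectified-Gaussian formulas with assumed argument names `mu` and `sigma` for the input mean and standard deviation, and requires `sigma` to be strictly positive.

```python
import math

def relu_moments(mu, sigma):
    """Mean and variance of max(0, z) for z ~ N(mu, sigma^2), with sigma > 0."""
    alpha = mu / sigma
    pdf = math.exp(-0.5 * alpha ** 2) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(alpha / math.sqrt(2.0)))          # standard normal CDF
    mean = mu * cdf + sigma * pdf
    second = (mu ** 2 + sigma ** 2) * cdf + mu * sigma * pdf      # E[max(0, z)^2]
    return mean, second - mean ** 2

mean0, var0 = relu_moments(0.0, 1.0)     # input centred on the kink
mean_hi, var_hi = relu_moments(10.0, 1.0)  # input almost surely positive
```

When the input is almost surely positive, the ReLU acts as the identity and the output moments approach those of the input, which gives a quick sanity check on the formulas.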