Bayesian Prior Networks with PAC Training

06/03/2019
by   Manuel Haussmann, et al.
Bosch
University of Heidelberg

We propose to train Bayesian Neural Networks (BNNs) by empirical Bayes as an alternative to posterior weight inference. By approximately marginalizing out an i.i.d. realization of a finite number of sibling weights per data point using the Central Limit Theorem (CLT), we attain a scalable and effective Bayesian deep predictor. This approach directly models the posterior predictive distribution, bypassing the intractable posterior weight inference step. However, it introduces a number of hyperparameters that is prohibitively large for stable training. As the prior weights are marginalized and hyperparameters are optimized, the model also no longer provides a means to incorporate prior knowledge. We overcome both of these drawbacks by deriving a trivial PAC bound that comprises the marginal likelihood of the predictor and a complexity penalty. The outcome integrates organically into the prior networks framework, bringing about an effective and holistic Bayesian treatment of prediction uncertainty. We observe on various regression, classification, and out-of-domain detection benchmarks that our scalable method provides an improved model fit accompanied by significantly better uncertainty estimates than the state of the art.


1 Introduction

As the machine learning community's interest in data-efficient and uncertainty-aware predictors increases, research on Bayesian Neural Networks (BNNs) mackay1995probable ; neal1995bayesian is gradually gaining prominence. Unlike deterministic neural nets, BNNs have stochasticity on their synaptic weights. Thanks to their stacked structure, they propagate prediction uncertainty through the hidden layers. Hence, they can characterize complex uncertainty structures. Exact inference in such a highly nonlinear system is not only analytically intractable but also extremely hard to approximate with high precision. Consequently, research on BNNs has thus far focused largely on improving approximate inference techniques in terms of precision and computational cost lobato2015probabilistic ; kingma2015variational ; louizos2017multiplicative . All these prior attempts take the posterior inference of global parameters as given and develop their approximation based on it.

This paper proposes a modified BNN formulation and an accompanying training method derived from learning-theoretic first principles. Instead of global weight parameters, we assume independent local weight random variables controlling the BNN for each input/target pair of data points, which share common hyperparameters. We marginalize out these data-point specific weights of our network and perform training via empirical Bayes on the prior hyperparameters. This analytically intractable marginalization is tightly approximated using the Central Limit Theorem (CLT). Unlike earlier weight marginalization approaches, which assign global weight distributions to infinitely many neurons and thereby recover a Gaussian Process (GP) alonso2019deep ; leep2018deep ; neal1995bayesian , our formulation maintains finitely many hidden units per layer and assigns them individual weight distributions. Treating weight marginalization per data point makes BNN training scale linearly with the data size. Adopting empirical Bayes training on this simplified setup, our method avoids the explicit approximation of a highly nonlinear and intractable weight posterior, yet can improve the quality of uncertainty estimates.

The advantages of such a BNN with data-point specific marginalization come at the expense of two major drawbacks. First, the number of hyperparameters in a weight-marginalized BNN grows proportionally to the number of synaptic connections, and hence maximizing the marginal likelihood w.r.t. such a large number of hyperparameters is prone to overfitting bauer2016understanding . Second, since the weight variables are marginalized out and the hyperparameters of the weight prior are set via optimization, the model can no longer incorporate prior knowledge other than the parametric form of the prior distribution (e.g., normal with mean and variance as free parameters). We address both of these drawbacks by deriving a trivial Probably Approximately Correct (PAC) macallester1999pac ; macallester2003pac bound that contains the marginal likelihood as its empirical risk term. Minimizing this PAC bound automatically balances the fit to the data against the deviation from a prior hypothesis as a result of hyperparameter tuning.

Our contribution.

We propose a new Bayesian deep learning method that

  1. introduces the local weight marginalization concept to empirical Bayes training of BNNs,

  2. uses CLT-based moment matching for the first time for marginalizing the local weights up to a tight and affordable analytical approximation, leveraging scalability at prediction time,

  3. applies the resultant method to a Bayesian prior of a likelihood, resulting in an inference scheme within the prior networks framework malinin2018prior ; sensoy2018evidential that can account for model uncertainty,

  4. uses a novel and principled PAC regularizer that both stabilizes training and substantially improves the uncertainty estimation quality.

We compare our method on various standard regression, classification, and out-of-domain detection benchmarks against state-of-the-art approximate posterior inference based BNN training approaches. We observe that our method provides competitive prediction accuracy and significantly better uncertainty estimation scores than those baselines.

2 Bayesian prior networks

Bayesian Local Neural Nets.

Given a data set $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}$ consisting of $N$ pairs of input $x_n$ and the corresponding target $y_n$, parameterizing the likelihood by a BNN with random variables $w_n$ as the weights results in the following probabilistic model

$$w_n \sim p(w_n \mid \phi), \qquad y_n \mid x_n, w_n \sim p\big(y_n \mid f(x_n; w_n)\big), \qquad n = 1, \dots, N, \tag{1}$$

where $p(y_n \mid f(x_n; w_n))$ is a likelihood function, e.g. for regression one would have $\mathcal{N}(y_n \mid f(x_n; w_n), \beta^{-1})$ with precision $\beta$, and $p(w_n \mid \phi)$ is some prior over local weights with shared hyperparameters $\phi$. Unlike a canonical BNN, where all data points share the same global weight latent variable blundell2015weight ; gal2015bayesian ; kandemir2018sampling ; kingma2015variational ; louizos2017multiplicative , here the mapping between each input-output pair is determined by a separate random variable $w_n$, giving a unique mapping constrained by sharing a common set of prior hyperparameters $\phi$. As this modification implies a collection of only local latent variables, we refer to the resultant probabilistic model as a Bayesian Local Neural Net (BLNN). This model contains two sources of uncertainty. First, the model (epistemic) uncertainty captured by the prior over the parameters, i.e. $p(w_n \mid \phi)$, which accounts for the mismatch between the model and the true functional mapping from $x$ to $y$. Second, the data (aleatoric) uncertainty given by $p(y_n \mid f(x_n; w_n))$, stemming from irreducible measurement noise. The mainstream BNN methods are designed specifically to model these two sources of uncertainty lobato2015probabilistic ; kendall2017what ; wu2018fixing .

Prior Networks.

Classification with cross-entropy loss in deep learning can be interpreted as a categorical likelihood parameterized by a neural net. Prior networks (PNs) malinin2018prior ; sensoy2018evidential generalize this setup by parameterizing a prior by a deterministic neural net, instead of the likelihood, formally

$$\lambda_n \mid x_n \sim p\big(\lambda_n \mid f(x_n; w)\big), \qquad y_n \mid \lambda_n \sim p(y_n \mid \lambda_n), \qquad n = 1, \dots, N. \tag{2}$$

This way, the model explicitly accounts for the distributional uncertainty which may arise due to a mismatch between the train and test data distributions quionero2009dataset . For classification, a natural choice of prior on the categorical likelihood is a Dirichlet distribution over the class probabilities $\lambda_n$. These probabilities, in turn, determine the categorical likelihood from which the actual observations $y_n$ are generated. For regression, one could use a normal distribution over the means and log-variances to capture distributional uncertainty. However, in both cases calculating a full posterior distribution over parameters has so far been avoided in favor of deterministic nets due to its high computational cost.

Bayesian Prior Networks (BLNNs as Prior Networks).

The original formulation for PNs builds on deterministic neural nets and hence assumes a point estimate on the model weights, consciously ignoring model uncertainty. We improve this framework by assigning a local prior $p(w_n \mid \phi)$ on the PN's weights (Footnote 1: Throughout this work we assume the weights to be normally distributed and independent, so that $\phi$ collects their means and variances.), which amounts to employing a BLNN as a prior on the likelihood. We name the eventual model that combines BLNNs with PNs a Bayesian Prior Network (BPN). By virtue of the localized weights, the marginal likelihood of the BPN factorizes across data points, yielding additive data-point specific marginal log-likelihoods and maintaining the central source of PN scalability, formally

$$p(y \mid X, \phi) = \prod_{n=1}^{N} p(y_n \mid x_n, \phi) = \prod_{n=1}^{N} \int\!\!\int p(y_n \mid \lambda_n)\, p\big(\lambda_n \mid f(x_n; w_n)\big)\, p(w_n \mid \phi)\, dw_n\, d\lambda_n. \tag{3}$$

The marginalization of $\lambda_n$ in the last step can be performed analytically under conjugacy or can be efficiently approximated by Taylor expansion or one-step Monte Carlo (MC) sampling, while one can employ the CLT argument to marginalize the weights $w_n$.

BPN training with empirical Bayes and prediction.

Introducing an independent $w_n$ for each data point yields a marginal likelihood that factorizes into $N$ independent per-data-point marginal likelihoods. This independence across data points has several benefits. First, provided that the integrals in Eq. (3) are analytically tractable, the computational complexity scales linearly with the training set size. More importantly, directly optimizing the marginal likelihood avoids the cost of estimating a full posterior over weights, including complex covariance structures and multi-modal distributions. A well-established alternative to Bayesian model training with posterior inference is learning by model selection. This approach marginalizes out all latent variables, compares the marginal likelihoods of all candidate hypotheses, and chooses the one with the highest value kass1995bayes . Choosing the hyperparameter values that maximize the marginal likelihood serves as training and hence avoids the posterior inference step on latent variables. Referred to as empirical Bayes Efron:2012 , this technique is fundamental for fitting non-parametric models such as Gaussian Processes (GPs). The optimization objective for empirical Bayes on a BPN is

$$\phi^{*} = \arg\max_{\phi} \log p(y \mid X, \phi) = \arg\max_{\phi} \sum_{n=1}^{N} \log p(y_n \mid x_n, \phi) \approx \arg\max_{\phi} \frac{N}{M} \sum_{n \in \mathcal{B}} \log p(y_n \mid x_n, \phi), \tag{4}$$

where $\mathcal{B}$ is a mini-batch of $M$ data points. The objective is thus amenable to mini-batches for further scalability, as illustrated in the last step of the derivation. The density function of the marginal likelihood of a training data point is identical to the posterior predictive density $p(y^{*} \mid x^{*}, \phi)$ for new test data $(x^{*}, y^{*})$. Hence, an analytic approximation developed for training is directly applicable at test time.
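The mini-batch step of the objective can be sketched in a few lines. This is only an illustration: `log_marginals` is a hypothetical list that already holds the per-data-point values of the log marginal likelihood, whereas in practice these would be computed on the fly for the sampled batch.

```python
import random


def minibatch_objective(log_marginals, batch_size, rng):
    """Unbiased mini-batch estimate of the sum of per-point log marginal likelihoods.

    log_marginals: hypothetical precomputed list of log p(y_n | x_n, phi).
    """
    n = len(log_marginals)
    batch = rng.sample(range(n), batch_size)  # indices drawn without replacement
    # Rescale by N / M so that the estimate matches the full sum in expectation.
    return n / batch_size * sum(log_marginals[i] for i in batch)
```

With `batch_size` equal to the data set size the estimate coincides with the full-data objective; smaller batches trade variance for cost per step.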

Analytic marginalization of local weights with moment matching.

Marginalizing out the local weights in Eq. (3) is intractable due to the highly nonlinear neural net appearing in the prior density. However, we can approximate this marginalization by recursive moment matching, resorting to the Central Limit Theorem (CLT). This technique has previously been used in BNNs for other purposes, such as expectation propagation ghosh2016assumed , fast dropout wang2013fast , and variational inference wu2018fixing . We employ the same technique for marginalizing out the weights of the BLNN. Focusing on a single data point $x_n$ and the $l$-th fully-connected hidden layer (Footnote 2: Convolutional layers follow analogously.) consisting of $K_l$ units with an arbitrary activation function $r(\cdot)$, the post-activation layer output is given as $h^{l} = r(a^{l})$, where $a^{l} = W^{l} h^{l-1}$. The $j$-th pre-activation output $a_j^{l} = \sum_i w_{ij}^{l} h_i^{l-1}$ is a sum of $K_{l-1}$ terms, which allows assuming it to be normally distributed via the CLT, independent of the individual distributions of the $w_{ij}^{l}$ and $h_i^{l-1}$ terms. The mean and the variance of this random variable can be computed as

$$\mathbb{E}\big[a_j^{l}\big] = \sum_i \mathbb{E}\big[w_{ij}^{l}\big]\, \mathbb{E}\big[h_i^{l-1}\big], \tag{5}$$

$$\mathrm{Var}\big[a_j^{l}\big] = \sum_i \Big( \mathbb{E}\big[(w_{ij}^{l})^2\big]\, \mathbb{E}\big[(h_i^{l-1})^2\big] - \mathbb{E}\big[w_{ij}^{l}\big]^2\, \mathbb{E}\big[h_i^{l-1}\big]^2 \Big). \tag{6}$$

The mean and the variance of the weights are readily available via the distributions $p(w_{ij}^{l} \mid \phi)$, leaving only the first two moments of $h_i^{l-1}$ undetermined. For common activations such as the ReLU, $r(a) = \max(0, a)$, which we will rely on in this work, closed-form solutions to these moments are tractable frey1999variational given the moments of the pre-activations of the previous layer $a^{l-1}$. This gives a recursive scheme terminating at the input layer, since we have $h^{0} = x_n$. As $x_n$ is a constant, its first moment is itself and its variance is zero. Consequently,

$$\mathbb{E}\big[a_j^{1}\big] = \sum_i \mathbb{E}\big[w_{ij}^{1}\big]\, x_{ni}, \qquad \mathrm{Var}\big[a_j^{1}\big] = \sum_i \mathrm{Var}\big[w_{ij}^{1}\big]\, x_{ni}^2, \tag{7}$$

completing the full recipe of how all weights of a BNN can be recursively integrated out from bottom to top, subject to a tight approximation. Scenarios with stochastic input typically entail controllable assumptions on $p(x)$. The equations above remain intact after adding an expectation operator around $x_{ni}$ and $x_{ni}^2$, readily available for any explicitly defined $p(x)$. In contrast to the case in GPs, stochastic inputs can be trivially adapted into this framework, greatly simplifying the math for uncertainty-sensitive setups, such as PILCO deisenroth2011tpami . For a net with $L$ layers, the outcome of recursive moment matching is a normal distribution $\mathcal{N}(a_n^{L} \mid m_n, \mathrm{diag}(s_n^2))$ over the final latent hidden layer $a_n^{L} = f(x_n; w_n)$, simplifying the highly nonlinear integral in Eq. (3) to

$$p(y_n \mid x_n, \phi) \approx \int\!\!\int p(y_n \mid \lambda_n)\, p\big(\lambda_n \mid a_n^{L}\big)\, \mathcal{N}\big(a_n^{L} \mid m_n, \mathrm{diag}(s_n^2)\big)\, da_n^{L}\, d\lambda_n. \tag{8}$$

The CLT is observed to provide a tight approximation after summing approximately ten random variables, which can easily be satisfied even for convolutional layers. (Footnote 3: A single filter with a single input channel already sums nine values, and the standard practice is to apply filters over multiple input channels.) This leaves us in a much simpler situation, as we are able to choose suitable distribution families for $p(y_n \mid \lambda_n)$ and $p(\lambda_n \mid a_n^{L})$. We show in Section 4 pairs of distributions suitable for regression and classification tasks.
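The recursion described above can be illustrated with a small, dependency-free sketch. It assumes independent weights with given means and variances, ReLU activations, and a deterministic input; the function names and the nested-list layout `w_mean[j][i]` are illustrative choices, not the authors' implementation. The ReLU moments use the standard closed forms for a rectified normal variable.

```python
import math


def _phi(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def _Phi(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def relu_moments(mean, var):
    """Mean and variance of max(0, a) for a ~ N(mean, var)."""
    if var <= 0.0:  # deterministic pre-activation
        return max(mean, 0.0), 0.0
    std = math.sqrt(var)
    z = mean / std
    first = mean * _Phi(z) + std * _phi(z)
    second = (mean * mean + var) * _Phi(z) + mean * std * _phi(z)
    return first, max(second - first * first, 0.0)


def linear_moments(h_mean, h_var, w_mean, w_var):
    """Moments of a_j = sum_i w_ij * h_i for independent w_ij and h_i.

    w_mean[j][i] and w_var[j][i] hold the weight moments of output unit j.
    """
    out_mean, out_var = [], []
    for mu_row, var_row in zip(w_mean, w_var):
        m = sum(mu * hm for mu, hm in zip(mu_row, h_mean))
        # Var[w h] = E[w^2] E[h^2] - E[w]^2 E[h]^2 for independent w and h.
        v = sum(
            (wv + mu * mu) * (hv + hm * hm) - (mu * hm) ** 2
            for mu, wv, hm, hv in zip(mu_row, var_row, h_mean, h_var)
        )
        out_mean.append(m)
        out_var.append(v)
    return out_mean, out_var


def propagate(x, layers):
    """Push a deterministic input through layers given as (w_mean, w_var) pairs.

    Returns the mean and variance of the final pre-activation.
    """
    h_mean, h_var = list(x), [0.0] * len(x)
    for idx, (w_mean, w_var) in enumerate(layers):
        a_mean, a_var = linear_moments(h_mean, h_var, w_mean, w_var)
        if idx == len(layers) - 1:
            return a_mean, a_var
        moments = [relu_moments(m, v) for m, v in zip(a_mean, a_var)]
        h_mean = [m for m, _ in moments]
        h_var = [v for _, v in moments]
    return h_mean, h_var  # no layers: the input itself
```

For a single layer and a deterministic input, `propagate` reduces to the input-layer case in which the variance is the sum of weight variances scaled by the squared inputs. Covariances between units are ignored, which corresponds to a diagonal covariance structure.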

Extensions.

Adapting CLT-based recursive moment matching to many other activation types and to skip connections is feasible without further approximations. Max pooling can also be incorporated using approximations such as jang2017categorical , but it has also been shown to be replaceable altogether by strided convolutions without a performance loss springenberg2015striving . The variance computation in Eq. (6) and the ReLU-specific post-activation moments do not model any potential covariance structure between the pre-/post-activation units of a layer. While this is in principle feasible, e.g. along the lines of wu2018fixing , it leads to an explosion in the required computational cost and memory, hindering the applicability of the approach to deeper nets. Hence, we stick to a diagonal covariance structure throughout, as wu2018fixing have also shown that modeling it brings only little benefit in test set performance.

3 Regularizing marginal likelihood with a trivial PAC bound

Training the objective in Eq. (4) is effective for fitting a predictor on data. It also naturally provides a learned loss attenuation mechanism. However, it lacks two key advantages of the Bayesian modeling paradigm. First, as the hyperparameters of the weight priors are left free for model fitting, they no longer serve as a means for incorporating prior knowledge. Second, it is well known from the GP literature that marginal likelihood based training is prone to overfitting for models with a large number of hyperparameters bauer2016understanding . (Footnote 4: An obvious objection is that one could just go one level higher in the hierarchy, introducing hyperpriors over the hyperparameters. We derive in Appendix I how one would do this, but preliminary experiments have shown it to perform considerably worse than the PAC-based approach.) We address these shortcomings by complementing the marginal likelihood objective of Eq. (4) with a penalty term derived from learning-theoretic first principles. We tailor the eventual loss only for robust model training and keep it maximally generic across learning setups. This comes at the expense of arriving at a generalization bound that makes a theoretically trivial statement, yet brings significant improvements to training quality, as illustrated in our experiments.

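PAC bounds have commonly been used for likelihood-free and loss-driven learning settings. A rare exception by germain2016pac proves the theoretical equivalence of a particular sort of PAC bound to variational inference. Similarly, we keep the notion of a likelihood in our risk definition but, in contrast, we relate our bound to the marginal likelihood. Given a predictor $h$ chosen from a hypothesis class $\mathcal{H}$ as a mapping from $x$ to $y$, we define the true and the empirical risks as

$$R(h) = -\mathbb{E}_{(x, y) \sim \Delta}\big[p(y \mid h(x))\big], \qquad R_{\mathcal{D}}(h) = -\frac{1}{N} \sum_{n=1}^{N} p\big(y_n \mid h(x_n)\big) \tag{9}$$

for the data set $\mathcal{D}$ drawn from an arbitrary and unknown data distribution $\Delta$. The risks $R(h)$ and $R_{\mathcal{D}}(h)$ are bounded below by $-\max_{x, y} p(y \mid h(x))$ and above by zero. Although this setting relaxes the common assumption that bounds the risk to the $[0, 1]$ interval, it is substantially simpler than the one suggested in germain2016pac , which defines the risk via the unbounded negative log-likelihood $-\log p(y \mid h(x))$. This unboundedness brings severe technical complications, which are no longer relevant for our approach. Denoting by $Q$ the distributions learnable over $\mathcal{H}$ and by $P$ some prior distribution on $\mathcal{H}$, according to Theorem 2.1 (Footnote 5: The theorem assumes risks to be defined within the $[0, 1]$ interval in the original paper; however, it is valid for any bounded risk. Furthermore, our risk definitions can trivially be squashed into $[0, 1]$ up to a constant.) in germain2009pac we have for any $\delta \in (0, 1]$ and any convex function $d$ the bound below

$$\Pr_{\mathcal{D}}\left( \forall Q: \; d\Big(\mathbb{E}_{h \sim Q}\big[R_{\mathcal{D}}(h)\big], \mathbb{E}_{h \sim Q}\big[R(h)\big]\Big) \le \frac{\mathrm{KL}(Q \,\|\, P) + \log(B / \delta)}{N} \right) \ge 1 - \delta, \tag{10}$$

where $B = \mathbb{E}_{\mathcal{D}}\, \mathbb{E}_{h \sim P}\big[\exp\big(N\, d(R_{\mathcal{D}}(h), R(h))\big)\big]$. Using a quadratic distance measure $d(r, r') = (r - r')^2$ and suitably bounding $B$ by exploiting the boundedness of the likelihood and in turn the risk, we get as an upper bound on the expected true risk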

(11)

which is the objective we use to train BPN. It contains the marginal likelihood in its first term and a theoretically grounded regularizer in its second (see Appendix III for a detailed derivation of the bound). The regularizer provides a mechanism to incorporate domain-specific prior knowledge via the prior distribution on the hypothesis class and substantially improves the quality of uncertainty estimates.
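To make the shape of such a complexity penalty concrete, the sketch below computes a KL divergence between two factorized normal distributions and plugs it into a square-root term. Both choices are assumptions for illustration: the factorized normal form of the learned and prior distributions, and the square-root form, which is what a PAC-Bayesian bound with a quadratic distance yields when solved for the true risk. The argument `log_b` stands for whatever upper bound on the logarithm of the bound's moment term is available; none of the names are taken from the authors' code.

```python
import math


def kl_diag_gaussians(q_mean, q_var, p_mean, p_var):
    """KL(Q || P) between factorized normals with the given means and variances."""
    return 0.5 * sum(
        math.log(pv / qv) + (qv + (qm - pm) ** 2) / pv - 1.0
        for qm, qv, pm, pv in zip(q_mean, q_var, p_mean, p_var)
    )


def pac_penalty(kl, log_b, delta, n):
    """Square-root complexity term sqrt((KL + log(B / delta)) / N).

    Assumed form for illustration; log_b is a bound on log B.
    """
    return math.sqrt((kl + log_b + math.log(1.0 / delta)) / n)
```

The penalty vanishes only when the learned distribution equals the prior and the remaining terms are zero, and it shrinks at rate one over the square root of the data set size, so its influence fades as more data become available.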

4 Application to classification and regression

Below we illustrate how the BPN-PAC framework can be used in two concrete, widespread learning setups. Specifically, we design the likelihood and the prior on its parameters for regression and classification.

Regression.

We place a normal likelihood over the targets, treating the as the mean and another normal as the distribution over parameterizing both mean and variance with a BNN, giving for the -th sample in Eq. (8) with , and that

The approximation is computed via another moment matching step approximating the result of the inner integral with a normal distribution (See Appendix II.2 for the derivation) while the final equality follows directly from standard results on normal distributions. We bound in Eq. (11) by exploiting that is fixed prior to training. Consequently, we obtain as a bound (see also Appendix III), which gives only a trivial performance guarantee (exceeding the maximum possible risk) but provides a justified training scheme for BPN.
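The standard result on normal distributions invoked for the final equality is that integrating a normal likelihood against a normal distribution over its mean gives another normal whose variance is the sum of the two variances. The sketch below evaluates the resulting log density; it covers only this simplified case with a fixed precision `beta` and given moments `m` and `s2` for the mean, and the function name is illustrative.

```python
import math


def gaussian_log_marginal(y, m, s2, beta):
    """log of the integral of N(y | lam, 1/beta) N(lam | m, s2) over lam.

    The integral equals N(y | m, s2 + 1/beta).
    """
    var = s2 + 1.0 / beta
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - m) ** 2 / var)
```

The predictive variance thus splits into a noise part, one over the precision, and a part `s2` propagated from the weight distribution through moment matching.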

Classification.

For $K$-class classification with one-hot encoded targets $y_n$, one can use a categorical distribution as the likelihood and a Dirichlet distribution as its prior, parameterized by the network output, giving us for the $n$-th term in Eq. (8),

where . Since the computational bottleneck on marginalizing the weights through the neural net layers is circumvented by analytically computed CLT-based moment matching, the final integral can be efficiently approximated by sampling. Similarly as within the regression setting, a trivial bound on can be obtained (see Appendix III).
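A sampling-based approximation of the final integral could look as follows. The sketch draws the final-layer output from its moment-matched normal distribution, maps each draw to Dirichlet concentrations, and averages the Dirichlet mean, which is the exact result of integrating the categorical likelihood against the Dirichlet. The exponential link from network output to concentrations is an assumption made here for illustration, as are all function names.

```python
import math
import random


def dirichlet_mean(alpha):
    """Class probabilities after integrating a categorical against Dir(alpha)."""
    total = sum(alpha)
    return [a / total for a in alpha]


def predictive_probs(f_mean, f_var, num_samples, rng):
    """Monte Carlo estimate of class probabilities from final-layer moments."""
    probs = [0.0] * len(f_mean)
    for _ in range(num_samples):
        f = [rng.gauss(m, math.sqrt(v)) for m, v in zip(f_mean, f_var)]
        # Assumed link alpha = exp(f); subtracting max(f) rescales all alphas
        # by a common factor, which leaves the Dirichlet mean unchanged.
        shift = max(f)
        alpha = [math.exp(fi - shift) for fi in f]
        for k, p in enumerate(dirichlet_mean(alpha)):
            probs[k] += p / num_samples
    return probs
```

Because sampling happens only at the output layer, its cost is independent of the network depth.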

5 Related work

CLT-based moment matching of local weight realizations.

The objective for variational inference on BNNs gal2016dropout ; kingma2015variational ; wu2018fixing , which optimizes a global variational posterior, contains a computationally intractable expected log-likelihood term that decomposes across data points. Fast dropout wang2013fast approximates these terms via local reparameterization with moment matching. The same local reparameterization has later been combined with a KL term to perform mean-field variational inference via MC sampling kingma2015variational or moment matching wu2018fixing . Our BLNN formulation is akin to an amortized VI approach that learns a single global set of posterior parameters for the variational posterior approximation. We use the same trick to marginalize the local weights, which keeps the machinery intact until the top-most step, where the order of the logarithm and expectation operations is swapped. This small change, however, has a significant impact on the quality of uncertainty estimates.

Wide neural nets as GPs.

The equivalence of a GP to a weight-marginalized BNN with a single infinitely wide hidden layer was discovered long ago neal1995bayesian using the multivariate version of the CLT. This result has later been generalized to multiple dense layers leep2018deep ; matthews2018gaussian , as well as to convolutional layers alonso2019deep . The asymptotic treatment of the neuron count makes this approach exact at the expense of losing neuron-specific parameterization. The eventual GP has few hyperparameters to train (at most one per layer) but a prohibitively expensive covariance matrix to calculate. We employ the same training method on a middle ground, where the hyperparameter count is twice that of a deterministic net and the cross-covariances across data points are not explicitly modeled.

Prior networks.

Earlier work malinin2018prior ; sensoy2018evidential parameterizes a prior to a classification-specific likelihood with deterministic neural nets, hence, discards model uncertainty. BPN and BPN-PAC reformulate prior networks independently from the output structure, extend them to support also model uncertainty, and introduce a principled scheme for their training.

PAC learning and neural networks.

The PAC framework within the GP literature has already been employed to bound the generalization performance of a trained GP Seeger:2002 as well as using it as a training objective reeb2018learning . For BNNs, closest prior work to ours uses PAC-Bayesian bounds for training Dziugaite:2017 ; Zhou:2018 . However, many design choices within the PAC formulation are different and, unlike ours, the bound does not directly generalize to regression. Although Zhou et al. Zhou:2018 provide tight theoretical generalization performance bounds, the work does not investigate the quality of uncertainty estimations. As Dziugaite et al. Dziugaite:2017 report test error on binary MNIST classification only ( with the best performing architecture compared to our ) and the inference scheme is sampling based, we exclude it from the comparison in the next section.

6 Experiments

We evaluate BPN and its PAC-regularized version BPN-PAC on a diverse selection of regression and classification tasks. Complete details on the training procedure can be found in Appendix IV. (Footnote 6: We provide an implementation of the model at https://github.com/manuelhaussmann/bpnpac/.)

Regression.

Table 1: Average test log-likelihood ± standard error over 20 random train/test splits. N/d give the number of data points in the complete data set and the number of input features.
Columns: boston, concrete, energy, kin8nm, naval, power, protein, wine.
Rows: Sparse GP bui2016deep , MC Dropout gal2016dropout , VarOut, PBP lobato2015probabilistic , DVI wu2018fixing , BPN-Hyper (Ours), BPN (Ours), BPN-PAC (Ours).

We evaluate the regression performance of BPN-PAC and the baselines on eight standard UCI benchmark data sets. Adopting the experiment protocol introduced in lobato2015probabilistic , we use 20 random train-test splits comprising 90% and 10% of the samples, respectively. The nets consist of a single hidden layer with 50 units and ReLU nonlinearities. (Footnote 7: Except for the larger protein data set, for which we use 100 hidden units.) The hypothesis class in this task is over the prior parameters , with precision , while the posterior is given as . We compare BPN-PAC against the state of the art in BNN inference methods that do not require sampling across neural net weights, such as Probabilistic Back-Propagation (PBP) lobato2015probabilistic and Deterministic Variational Inference (DVI) wu2018fixing , which use CLT-based moment matching for expectation propagation and VI, respectively. For completeness, we also compare against the two most common sampling-based alternatives, Variational Dropout (VarOut) kingma2015variational ; molchanov2017variational and MC Dropout gal2016dropout . In the results summarized in Table 1, BPN-PAC outperforms all baselines on the majority of the data sets and is competitive on the others.

Table 2: Test error in % on two image classification tasks. BPN reaches a lower error rate than previously proposed neural net based GP construction alternatives by two convolutional layers with filters of size and stride . BPN converges in 50 epochs, amounting to circa 30 minutes of training time on a single GPU. The GP alternatives have been reported to have significantly larger time and memory requirements.
Columns: MNIST, CIFAR10.
Rows: NNGP leep2018deep , Convolutional GP wilk2018convolutional , ConvNet GP alonso2019deep , Residual CNN GP alonso2019deep , ResNet GP alonso2019deep , BPN (Ours), BPN-PAC (Ours).

The PAC regularization improves over BPN in all data sets except one. We also report results for a sparse GP with 50 inducing points, which approximates a BNN of one infinitely wide hidden layer neal1995bayesian . As expected, the GP sets a theoretical upper bound on BPN-PAC as well as the baselines for one hidden layer architectures. Lastly, we compare our tediously derived PAC regularizer to straightforward Maximum-A-Posteriori estimation on the BPN hyperpriors (BPN-Hyper) (see Appendix I for details), which deteriorates performance on all UCI data sets.

Classification and out-of-domain detection.

Table 3: Test error and the area under curve of the empirical CDF (ECDF-AUC) of the predictive entropies on two pairs of datasets. Smaller values are better for both metrics.
Columns: MNIST (In Domain) Test Error (%), Fashion-MNIST (Out-of-Domain) ECDF-AUC, CIFAR 1-5 (In Domain) Test Error (%), CIFAR 6-10 (Out-of-Domain) ECDF-AUC.
Rows: MC Dropout, VarOut, DVI, PN/EDL, BPN (Ours), BPN-PAC (Ours).

We evaluate classification and out-of-domain (OOD) sample detection performance of BPN-PAC on image classification with deep architectures, adhering to the protocol repeatedly used in prior work louizos2017multiplicative ; sensoy2018evidential . We train LeNet-5 networks on the MNIST train split, evaluate their classification accuracy on the MNIST test split as the in-domain task, and measure their uncertainty on the Fashion-MNIST data set as the out-of-domain task. (Footnote 8: Due to the license status of the not-MNIST data being in conflict with the affiliation of the authors, we have to change the setup of earlier work, e.g. lakshminarayanan2017simple ; louizos2017multiplicative ; sensoy2018evidential , using instead Fashion-MNIST as the closest substitute.) We expect a perfect model to predict true classes with high accuracy on the in-domain task and to always predict a uniform probability mass on the out-of-domain task, i.e. the area under curve of the empirical CDF (ECDF-AUC) of its predictive distribution entropy is zero. We perform the same experiment on CIFAR10, using the first five classes for the in-domain task and treating the rest as out-of-domain. We use a symmetric Dirichlet distribution as the prior on the class assignment parameters, which has the uniform probability mass on its mean, encouraging an OOD alarm in the absence of contrary evidence. In Table 3, we compare BPN-PAC against EDL sensoy2018evidential , the state of the art in neural net based uncertainty quantification and also the non-Bayesian and heuristically trained counterpart of BPN-PAC. We also consider EDL as the special case of Prior Networks malinin2018prior that does not use OOD data during training time, commensurate with our training assumptions. We maintain MC Dropout, VarOut, and DVI as baselines also in this setup. BPN-PAC improves the state of the art in all four metrics except the CIFAR10 in-domain task, where it ranks second after MC Dropout, which relies on weight sampling at prediction time and is hence less scalable. Remarkably, BPN-PAC detects the OOD samples with significantly better calibrated ECDF-AUC scores than EDL, the former state of the art.
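The ECDF-AUC metric can be sketched as follows. The sketch assumes that entropies are normalized by the logarithm of the number of classes so that they lie in the unit interval, and that the area is taken over that interval; under these assumptions the area has a simple closed form. The function names are illustrative.

```python
import math


def normalized_entropy(probs):
    """Entropy of a predictive distribution divided by its maximum log(K)."""
    h = -sum(p * math.log(p) for p in probs if p > 0.0)
    return h / math.log(len(probs))


def ecdf_auc(entropies):
    """Area under the empirical CDF of values in [0, 1], integrated over [0, 1].

    Each value e contributes the area of its step function, which is 1 - e.
    """
    return sum(1.0 - e for e in entropies) / len(entropies)
```

A model that always outputs the uniform distribution on OOD inputs attains normalized entropy one everywhere and hence an area of zero, while fully confident predictions give an area of one.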

Comparison to GP variants.

Lastly, we evaluate the impact of local weight realization on prediction performance by comparing BPN-PAC to GPs with kernels derived from BNNs with global weight realizations alonso2019deep ; leep2018deep ; neal1995bayesian on the MNIST and CIFAR10 data sets. It is technically not possible to perform this evaluation in a fully commensurate setup, as these baselines assume infinitely many neurons per layer and do not have weight-specific degrees of freedom. Furthermore, alonso2019deep performs neural architecture search, and leep2018deep uses only part of the CIFAR10 training set, reporting that the rest does not fit into the memory of a powerful workstation. We nevertheless view the performance scores reported in these papers as practical upper bounds and provide a qualitative comparison. For the choice of neural net depth, we take NNGP leep2018deep as reference and devise, in contrast, a thin two-layer convolutional BPN-PAC network. The results and the architectural details are summarized in Table 2. BPN and BPN-PAC can reach lower error rates using significantly fewer computational resources.

Computational cost.

Table 4: Per data point computational cost analysis in FLOPs. F: Forward pass cost of a deterministic neural net. W: Number of weights in the neural net. L: Analytical calculation cost for the exact or approximate likelihood or the loss term. S: Number of samples taken for approximation. R: The cost of the regularization term per unit (weight or data point).
Columns: Training per iteration, Prediction.
Rows: MC Dropout, VarOut, DVI, PN / EDL, BPN, BPN-PAC.

Table 4 summarizes the computational cost analysis of the considered approaches. MC Dropout and VarOut can quantify uncertainty only by taking samples across weights, which increases the prediction cost linearly with the sample count. DVI and BPN-PAC perform the forward pass during both training and prediction time via analytical moment matching at double and triple costs, respectively. Both methods have a sampling cost for intractable likelihoods. (Footnote 9: Even this sampling step could be avoided by a suitable Taylor approximation, see e.g. Appendix B.4/B.5 in wu2018fixing . As the added approximation error to stay completely sampling-free was more detrimental to model performance than a cheap MC approach in preliminary experiments, we stay with the latter for both.) BPN-PAC may also have another additive per-data-point sampling cost for calculating intractable functional mapping regularizers. Favorably, both of these overheads are only additive to the forward pass cost, i.e. sampling time is independent of the neural net depth, hence they do not create a computational bottleneck. The training and prediction cost of BPN-PAC is three times that of PN and EDL, which build on deterministic neural nets. However, it provides substantial improvements in both prediction accuracy and uncertainty quantification.

7 Conclusion

In this paper, we present a method for performing Bayesian inference within the framework of prior networks. Employing empirical Bayesian methods for inference and combining them with PAC bounds for regularization, we achieve higher accuracy and better predictive uncertainty estimates while maintaining scalable inference. Exact inference in a fully Bayesian model such as a GP (cf. Table 1) or Hamiltonian Monte Carlo inference for BNNs bui2016deep is known to provide better error rates and TLL scores, yet its computational demand does not scale well to large networks and data sets. Our method, on the other hand, shows strong indicators for improvement in uncertainty quantification and predictive performance when compared to other BNN approximate inference schemes with reasonable computational requirements. These benefits of the BPN-PAC framework might be especially fruitful in setups such as model-based deep reinforcement learning, active learning, and data synthesis, where uncertainty quantification is a vital ingredient of the predictor.

References

  • (1) M. Bauer, M. v.d. Wilk, and C.E. Rasmussen. Understanding Probabilistic Sparse Gaussian Process Approximations. In NIPS, 2016.
  • (2) C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra. Weight Uncertainty in Neural Networks. In ICML, 2015.
  • (3) T. Bui, D. Hernández-Lobato, J. M. Hernandez-Lobato, Y. Li, and R. Turner. Deep Gaussian Processes for Regression using Approximate Expectation Propagation. In ICML, 2016.
  • (4) O. Catoni. PAC-Bayesian Supervised Classification: The Thermodynamics of Statistical Learning. IMS Lecture Notes Monograph Series, 56, 2007.
  • (5) M.P. Deisenroth, D. Fox, and C.E. Rasmussen. Gaussian Processes for Data-Efficient Learning in Robotics and Control. IEEE Trans. Pattern Analysis and Machine Intelligence, 17(2):402–423, 2015.
  • (6) G. Dziugaite and D.M. Roy. Computing Nonvacuous Generalization Bounds for Deep (Stochastic) Neural Networks with Many More Parameters than Training Data. In UAI, 2017.
  • (7) G.K. Dziugaite and D.M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In UAI, 2017.
  • (8) B. Efron. Large-scale Inference: Empirical Bayes Methods for Estimation, Testing, and Prediction, volume 1. Cambridge University Press, 2012.
  • (9) B.J. Frey and G.E. Hinton. Variational Learning in Nonlinear Gaussian Belief Networks. Neural Computation, 11(1):193–213, 1999.
  • (10) Y. Gal and Z. Ghahramani. Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning. In ICML, 2016.
  • (11) Y. Gal and Z. Ghahramani. Bayesian Convolutional Neural Networks with Bernoulli Approximate Variational Inference. arXiv preprint arXiv:1506.02158, 2015.
  • (12) A. Garriga-Alonso, C.E. Rasmussen, and L. Aitchison. Deep Convolutional Networks as Shallow Gaussian Processes. In ICLR, 2019.
  • (13) P. Germain, F. Bach, A. Lacoste, and S. Lacoste-Julien. PAC-Bayesian Theory Meets Bayesian Inference. In NIPS, 2016.
  • (14) P. Germain, A. Lacasse, F. Laviolette, and M. Marchand. PAC-Bayesian Learning of Linear Classifiers. In ICML, 2009.
  • (15) S. Ghosh, J. Yedidia, and F.M. Delle Fave. Assumed Density Filtering Methods for Scalable Learning of Bayesian Neural Networks. In AAAI, 2016.
  • (16) M. Haussmann, F.A. Hamprecht, and M. Kandemir. Sampling-Free Variational Inference of Bayesian Neural Nets with Variance Backpropagation. In UAI, 2019.
  • (17) J. M. Hernández-Lobato and R. Adams. Probabilistic Backpropagation for Scalable Learning of Bayesian Neural Networks. In ICML, 2015.
  • (18) E. Jang, S. Gu, and B. Poole. Categorical Reparameterization with Gumbel Softmax. In ICLR, 2017.
  • (19) R.E. Kass and A.E. Raftery. Bayes Factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
  • (20) A. Kendall and Y. Gal. What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision? In NIPS, 2017.
  • (21) D.P. Kingma, T. Salimans, and M. Welling. Variational Dropout and The Local Reparameterization Trick. In NIPS, 2015.
  • (22) B. Lakshminarayanan, A. Pritzel, and C. Blundell. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. In NIPS, 2017.
  • (23) J. Lee, Y. Bahri, R. Novak, S. Schoenholz, J. Pennington, and J. Sohl-Dickstein. Deep Neural Networks as Gaussian Processes. In ICLR, 2018.
  • (24) C. Louizos and M. Welling. Multiplicative Normalizing Flows for Variational Bayesian Neural Networks. In ICML, 2017.
  • (25) D.J. MacKay. Probable Networks and Plausible Predictions – A Review of Practical Bayesian Methods for Supervised Neural Networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
  • (26) A. Malinin and M. Gales. Predictive uncertainty estimation via prior networks. In NeurIPS, 2018.
  • (27) A. G. de G. Matthews, M. Rowland, J. Hron, R. E. Turner, and Z. Ghahramani. Gaussian process behaviour in wide deep neural networks. In ICLR, 2018.
  • (28) D. McAllester. PAC-Bayesian Model Averaging. In COLT, 1999.
  • (29) D. McAllester. PAC-Bayesian Stochastic Model Selection. Machine Learning, 51:5–21, 2003.
  • (30) D. Molchanov, A. Ashukha, and D. Vetrov. Variational Dropout Sparsifies Deep Neural Networks. In ICML, 2017.
  • (31) R. Neal. Bayesian Learning for Neural Networks. PhD Thesis, 1995.
  • (32) J. Quiñonero-Candela, M. Sugiyama, A. Schwaighofer, and N. D. Lawrence. Dataset Shift in Machine Learning. The MIT Press, 2009.
  • (33) D. Reeb, A. Doerr, S. Gerwinn, and B. Rakitsch. Learning Gaussian Processes by Minimizing PAC-Bayesian Generalization Bounds. In NeurIPS, 2018.
  • (34) M. Seeger. PAC-Bayesian generalisation error bounds for Gaussian process classification. Journal of machine learning research, 3:233–269, 2002.
  • (35) M. Seeger. PAC-Bayesian Generalisation Error Bounds for Gaussian Process Classification. Journal of Machine Learning Research, 3:233–269, 2002.
  • (36) M. Sensoy, L. Kaplan, and M. Kandemir. Evidential Deep Learning to Quantify Classification Uncertainty. In NeurIPS, 2018.
  • (37) J.T. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller. Striving for Simplicity: The All Convolutional Net. In ICLR, 2015.
  • (38) M. v.d. Wilk, C.E. Rasmussen, and J. Hensman. Convolutional Gaussian Processes. In NIPS, 2017.
  • (39) S.I. Wang and C.D. Manning. Fast Dropout Training. In ICML, 2013.
  • (40) A. Wu, S. Nowozin, E. Meeds, R.E. Turner, J.M. Hernández-Lobato, and A.L. Gaunt. Deterministic Variational Inference for Bayesian Neural Networks. In ICLR, 2019.
  • (41) W. Zhou, V. Veitch, M. Austern, R.P. Adams, and P. Orbanz. Non-vacuous Generalization Bounds at the ImageNet Scale: A PAC-Bayesian Compression Approach. In ICLR, 2019.

I: BPN with Hyperpriors

Instead of relying on the PAC-bound-based regularization, one could try to incorporate further hyperpriors on the prior hyperparameters. This would not share the benefit of allowing for the incorporation of prior knowledge on the functional mapping itself, but could reintroduce the regularization that BPN is missing. The hierarchical model then has the following structure:

(12)
(13)
(14)
(15)

The marginal to be optimized over is then given as

(16)

The first term is our regular marginal likelihood, while the second serves as a regularizer: an optimization scheme aims to choose the hyperparameters such that the marginal likelihood is high, but also such that the prior density is large. The form of this hyperprior will vary depending on the problem at hand, but if we consider e.g. each weight of the BNN to follow a normal distribution, we have

(17)

An obvious choice for the prior is then given as

(18)

The regression results summarized in Table 1, however, show that this approach tends to perform worse than both the PAC-regularized BPN and the unregularized BPN.

II: Further details on the BPN derivation

II.1 First two moments of the ReLU activation

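Mean and variance of a normally distributed variable transformed by the ReLU activation function are analytically tractable [9]. Following the notation from the main paper, we have that for $a = \max(0, h)$, where $h \sim \mathcal{N}(\mu, \sigma^2)$ for some mean $\mu$ and variance $\sigma^2$, they can be computed as

$\mathbb{E}[a] = \mu \, \Phi(\mu/\sigma) + \sigma \, \phi(\mu/\sigma)$, (19)
$\mathrm{var}[a] = (\mu^2 + \sigma^2) \, \Phi(\mu/\sigma) + \mu \sigma \, \phi(\mu/\sigma) - \mathbb{E}[a]^2$, (20)

where $\Phi$ and $\phi$ are the cdf and pdf of the standard normal distribution respectively.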
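As a numerical illustration of these closed-form moments, the following minimal sketch evaluates the standard rectified-Gaussian results using only the Python standard library. The function and argument names (`relu_moments`, `mu`, `var`) are illustrative assumptions and are not taken from the paper's implementation.

```python
import math


def std_normal_pdf(z):
    """Density of the standard normal distribution at z."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    """Cumulative distribution function of the standard normal at z."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def relu_moments(mu, var):
    """Mean and variance of max(0, h) for h ~ N(mu, var), with var > 0.

    Uses the standard rectified-Gaussian identities
      E[a]   = mu * Phi(mu/sigma) + sigma * phi(mu/sigma)
      E[a^2] = (mu^2 + sigma^2) * Phi(mu/sigma) + mu * sigma * phi(mu/sigma)
    """
    sigma = math.sqrt(var)
    alpha = mu / sigma
    cdf = std_normal_cdf(alpha)
    pdf = std_normal_pdf(alpha)
    mean = mu * cdf + sigma * pdf
    second_moment = (mu * mu + var) * cdf + mu * sigma * pdf
    # Clamp at zero to guard against tiny negative values from round-off.
    variance = max(second_moment - mean * mean, 0.0)
    return mean, variance


if __name__ == "__main__":
    print(relu_moments(0.0, 1.0))
```

For a strongly positive pre-activation mean the ReLU is almost the identity, so the output moments approach the input moments; for a strongly negative mean both moments approach zero.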

II.2 Derivation of the marginalization for regression

We approximate the marginal distribution

(21)

with a normal distribution by a further moment matching step. Dropping indices for notational simplicity, the mean of the right-hand side is given as

(22)
(23)
(24)

For the variance term we rely on the law of total variance and have

(25)
(26)
(27)
(28)

where the last integral is given as the mean of a log-normal random variable. Altogether we end up with the desired

(29)
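The moment matching step can be illustrated with a small sketch. It rests on an assumption that is not confirmed by the text above: the likelihood is taken to be $\mathcal{N}(y \mid f, \exp(g))$, with independent Gaussian quantities $f$ and $g$ for the predictive mean and the log noise variance. Under this assumed parametrization, the law of total variance gives the variance of $f$ plus the expected noise variance, and the latter is the mean of a log-normal random variable. All names below are hypothetical.

```python
import math


def matched_marginal_moments(mean_f, var_f, mean_g, var_g):
    """Moments of y under the ASSUMED model
         y | f, g ~ N(f, exp(g)),  f ~ N(mean_f, var_f),  g ~ N(mean_g, var_g),
    with f and g independent.

    Mean:      E[y] = E[E[y | f, g]] = mean_f
    Variance:  var[y] = E[var[y | f, g]] + var[E[y | f, g]]
                      = E[exp(g)] + var_f
    where E[exp(g)] = exp(mean_g + var_g / 2) is the log-normal mean.
    """
    mean_y = mean_f
    expected_noise_var = math.exp(mean_g + 0.5 * var_g)
    var_y = expected_noise_var + var_f
    return mean_y, var_y


if __name__ == "__main__":
    print(matched_marginal_moments(1.0, 0.5, 0.0, 0.0))
```

The matched normal distribution then has these two values as its mean and variance. If the paper's parametrization differs, only the expression for the expected noise variance would change.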

III: Derivation of the PAC-bound

This section gives a more detailed derivation of the individual results stated in the main paper. As stated there, given a predictor chosen from a hypothesis class as a mapping from to , we define the true and the empirical risks as

(30)
(31)

for the data set drawn from an arbitrary and unknown data distribution . and are bounded below by and above by zero.

Theorem 2.1 in [14] gives us that for any and any convex function

(32)

where . The PAC framework necessitates a convex and non-negative distance measure for risk evaluations. Common practice is to rescale the risk into the unit interval, define the KL divergence as the distance measure, and upper bound its intractable inverse [13] using Pinsker’s inequality [4, 6]. We follow an alternative path. As our risk is bounded but not restricted to the unit interval, we choose our distance measure as and avoid the Pinsker’s inequality step.

Adapting the standard KL inversion trick [35] to the Euclidean distance, we can simply define for some . We apply this function to both sides of the inequality and get

where and by definition and because and we have . Since

(33)

directly follows from , we bound the true risk as

This outcome has a similar structure to the application of Pinsker’s inequality to a setup with risk defined on the unit interval, but without such a restriction. Hence, the implied upper bound is no longer trivial. In order to arrive at the final bound we have to further approximate each of the two terms of the bound.

For the first term, we have that

(34)
(35)
(36)

where the inequality uses that and the equality follows by the marginalisation techniques discussed in the main paper.

To get a tractable second term we proceed further. Exploiting the fact that

(37)

for any and , we can drop the expectation term and get

(38)

For multiclass classification, the likelihood is bounded in the interval such that with we have that

(39)

For regression, with the likelihood , we have that and are bounded from above by and from below by , i.e. by the density at the mode of a normal distribution with precision . Hence,

(40)

Combining these relaxations we get the objectives described in the main paper.

IV: Experimental details and further experiments

This section contains experiments on synthetic toy data as well as details on the hyperparameters of the experiments performed in the main paper. See https://github.com/manuelhaussmann/bpnpac/ for a PyTorch implementation to reproduce the reported results.

IV.1 Synthetic experiment

In order to visualize the predictive uncertainty, we include an experiment on 1D regression, focusing on the case of a data-dependent noise structure to also show how to include prior knowledge in the model via the PAC regularization.

Regression on heteroscedastic noise.

We visualize the predictive uncertainty behavior of BPN and illustrate how it can benefit from the incorporation of prior knowledge via the PAC bound regularization in a 1D regression setting. We sample training inputs uniformly over the interval and the related targets are generated as , where . That is, we have a sinusoidal function with location-dependent observation noise. We build neural nets with two fully-connected hidden layers of 250 neurons each. To let the models explain the variance of the observed data in a maximally data-driven manner, we assume a homoscedastic observation noise with a fixed and large precision parameter . Visual inspection of the data shows its periodic structure with an increasing amplitude, and that the noise seems to grow with an increase in . We assume an oracle that provides us this prior knowledge in the form of a prior predictive . The posterior predictive is the hypothesis class determined by a BNN. Each weight in the BNN has its prior mean and log variance as tunable hyperparameters. As the model capacity is too high compared to the sparsity of the training data, we can see in Figure 1 that the performance of BPN without PAC regularization deteriorates, especially in the noisier regions with . BPN-PAC, however, can fit the data and predict the increase in observation noise reasonably well using the prior knowledge provided by .

Figure 1: Heteroscedastic Synthetic. Comparison of BPN and BPN-PAC in the presence of prior knowledge. The dashed lines give three standard deviations of the true underlying noise, the light areas show three standard deviations of the predictive uncertainty for each model.

IV.2 Experimental details and hyperparameters

Regression.

The neural net used consists of a single hidden layer of 50 units for all data sets except protein, which gets 100. The results for all of the baselines except for Variational Dropout (VarOut) are quoted from the respective papers that introduced them, while the results on the sparse GP are reported via [3]. For VarOut we rely on our own implementation as there are no official results. BPN, BPN-PAC, and VarOut all share the same initialization scheme for the mean and variance parameters of each weight, following the initialization of [24], i.e. He-Normal for the means and for the log variances. VarOut gets a Normal prior with a precision of , and all three get an observation precision of , in order to encourage them to learn as much of the predictive uncertainty as possible instead of relying on a fixed hyperparameter. Note that we keep these values fixed and independent of the data set, unlike many of the baselines, which set them to data-set-specific values via cross-validation on separate validation subsets.

Each model is trained with the Adam optimizer with default parameters for 100 epochs with a learning rate of , with varying minibatch sizes depending on the data set size.

Classification and out-of-domain detection.

The network for this task follows the common LeNet5 architecture with the following modifications. Instead of max-pooling layers after the two convolutional layers, the convolutional layers themselves use a larger stride to mimic the behavior. For the more complex CIFAR data set, the number of channels in the two convolutional layers is increased from the default 20 and 50 to 192 each, while the number of hidden units in the fully connected layer is increased from 500 to 1000, following [11].
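To make the effect of replacing pooling by strided convolutions concrete, the sketch below computes the feature-map sizes of such a two-layer stack. The kernel size of 5, the stride of 2, and the absence of padding are assumptions made for illustration (the text only states that a larger stride is used), and the helper names are hypothetical.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Spatial output size of a square convolution (floor convention)."""
    return (size + 2 * padding - kernel) // stride + 1


def strided_stack_shape(input_size, channels, kernel=5, stride=2):
    """Return (final spatial size, flattened feature count) for a stack of
    strided convolutions with the given output channels per layer.

    kernel=5 and stride=2 are ASSUMED values, not taken from the paper.
    """
    size = input_size
    for _ in channels:
        size = conv_out_size(size, kernel, stride)
    return size, channels[-1] * size * size


if __name__ == "__main__":
    # 28x28 inputs with the default 20 and 50 channels.
    print(strided_stack_shape(28, [20, 50]))
    # 32x32 inputs with 192 channels in both layers.
    print(strided_stack_shape(32, [192, 192]))
```

Under these assumed settings a 28x28 input is reduced to a 4x4 map after the second layer, the same spatial size that two 5x5 convolutions each followed by 2x2 max-pooling would produce, which is one way to read "mimic the behavior" above.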

Since there are no OOD results for the BNN baselines we compare against, we rely on our own reimplementations of them, ensuring that they all share the same initialization schemes as in the regression setup. For DVI we implement the diagonal version and use a sampling-based approximation for the intractable softmax. Each model gets access to five samples whenever it needs to conduct an MC sampling approximation. All models are trained via the Adam optimizer with the default hyperparameters and a learning rate of . For EDL we rely on the public implementation provided by the authors [36] and use their hyperparameters to learn the model.

GP Variants Comparison.

The results for the baselines are taken from the respective original papers. The nets for BPN and BPN-PAC consist of two convolutional layers with 96 filters of size and a stride of 5. They are trained until convergence (50 epochs) using Adam with the default hyperparameters and a learning rate of .