Exact posterior distributions of wide Bayesian neural networks

Jiri Hron et al. · 18 June 2020

Recent work has shown that the prior over functions induced by a deep Bayesian neural network (BNN) behaves as a Gaussian process (GP) as the width of all layers becomes large. However, many BNN applications are concerned with the BNN function space posterior. While some empirical evidence of the posterior convergence was provided in the original works of Neal (1996) and Matthews et al. (2018), it is limited to small datasets or architectures due to the notorious difficulty of obtaining and verifying exactness of BNN posterior approximations. We provide the missing theoretical proof that the exact BNN posterior converges (weakly) to the one induced by the GP limit of the prior. For empirical validation, we show how to generate exact samples from a finite BNN on a small dataset via rejection sampling.


1 Introduction

A BNN is composed of a finite number of layers, where the output $f^\ell$ of the $\ell$-th layer is a function of the previous layer's outputs $f^{\ell-1}$, a nonlinearity $\phi$, and parameters $\theta^\ell = (W^\ell, b^\ell)$. For a fully connected network (FCN)

$$f^\ell(x) = W^\ell \phi\bigl(f^{\ell-1}(x)\bigr) + b^\ell ,$$

with $W^\ell \in \mathbb{R}^{n^\ell \times n^{\ell-1}}$, $b^\ell \in \mathbb{R}^{n^\ell}$, $\ell \in \{1, \dots, L\}$, and $\phi(f^0(x)) := x$ for convenience.
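For concreteness, here is a minimal NumPy sketch of this forward pass. It is our own illustration, not code from the paper; the `erf` nonlinearity and the layer sizes in the usage example are arbitrary choices.

```python
import numpy as np
from scipy.special import erf

def fcn_forward(params, x, phi=erf):
    """f^L(x) for params = [(W^1, b^1), ..., (W^L, b^L)] and inputs x of shape (N, d)."""
    h = x                              # phi(f^0(x)) := x
    for W, b in params:
        f = h @ W.T + b                # f^ell = W^ell phi(f^{ell-1}) + b^ell
        h = phi(f)                     # consumed by the next layer, if any
    return f

# Example: a 1 -> 3 -> 1 network with arbitrary parameter values.
params = [(np.array([[0.5], [-1.0], [2.0]]), np.zeros(3)),
          (np.array([[1.0, -0.5, 0.25]]), np.zeros(1))]
print(fcn_forward(params, np.array([[0.0], [1.0]])))
```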

Where a BNN differs from its NN equivalent is in the handling of the parameters $\theta = \{\theta^\ell\}_{\ell=1}^L$. In particular, a BNN treats the parameters as random variables following some prior distribution, and, instead of gradient-based optimisation, uses Bayes' rule to calculate the posterior given a fixed dataset $\mathcal{D} = (X, Y)$,

$$p(\theta \mid \mathcal{D}) \propto p(\mathcal{D} \mid \theta)\, p(\theta),$$

where $p(\theta)$ and $p(\theta \mid \mathcal{D})$ are the density functions of the prior and the posterior, and $p(\mathcal{D} \mid \theta)$ is a likelihood function appropriate for the dataset.

2 Large width behaviour of the prior

Since direct interpretation of the parameter space distribution is difficult, neal1996 proposed to instead study the distribution $\pi_n$ over input-to-output mappings $x \mapsto f_\theta(x)$ induced by computing the forward pass with the randomly sampled $\theta$, i.e., for any measurable set $B$ (the usual Borel product $\sigma$-algebra is assumed throughout)

$$\pi_n(B) := \mathbb{P}(f_\theta \in B),$$

where we use the subscript in $f_\theta$ to emphasise that the mapping is fully determined by $\theta$. Assuming a fully connected single hidden layer architecture and an i.i.d. zero mean Gaussian prior over $\theta$, with the variance of the readout weights inversely proportional to the hidden layer width $n$, neal1996 was able to show that $\pi_n$ converges weakly¹ to a centred GP distribution as $n \to \infty$.

¹ A sequence of distributions $(\pi_n)_n$ converges weakly to $\pi$ if $\mathbb{E}_{\pi_n}[h] \to \mathbb{E}_{\pi}[h]$ for all real-valued continuous bounded functions $h$.

neal1996's results were recently generalised to various deep NN architectures including convolutional, pooling, residual, and attention layers (matthews2018gaussian; lee2018deep; novak2019bayesian; garriga2019deep; yang2019v1; yang2019v2; hron20). These papers study the function space priors $\pi_n$ for a sequence of increasingly wide NNs, and establish their weak convergence to a centred GP distribution with covariance determined by the architecture and the underlying sequence of parameter space priors. To ensure the asymptotic variance neither vanishes nor explodes, each parameter space prior is assumed zero mean with marginal variances inversely proportional to layer input size. For example, a common choice satisfying this assumption for an FCN is

$$W^\ell_{ij} \overset{\text{i.i.d.}}{\sim} \mathcal{N}\!\left(0, \tfrac{\sigma_w^2}{n^{\ell-1}}\right), \qquad b^\ell_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \sigma_b^2), \tag{1}$$

for each layer $\ell$. Throughout, we assume that the architecture and the sequence of parameter space priors were chosen such that $\pi_n \Rightarrow \pi$ for some fixed distribution $\pi$, where $\Rightarrow$ denotes weak convergence.
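A quick numerical illustration of this convergence (our own sketch; the widths, inputs, and hyperparameters are arbitrary choices): under an Equation 1 style prior, the empirical output covariance of a single-hidden-layer FCN stabilises as the hidden layer grows, consistent with a GP limit.

```python
import numpy as np
from scipy.special import erf

def sample_prior_outputs(width, x, num_draws=2000, sigma_w=1.5, sigma_b=0.05, seed=0):
    """Draws of f_theta(x) for a single-hidden-layer FCN under an Equation 1 style prior."""
    rng = np.random.default_rng(seed)
    d = x.shape[1]
    W1 = rng.normal(0, sigma_w / np.sqrt(d), size=(num_draws, width, d))
    b1 = rng.normal(0, sigma_b, size=(num_draws, width, 1))
    W2 = rng.normal(0, sigma_w / np.sqrt(width), size=(num_draws, 1, width))
    b2 = rng.normal(0, sigma_b, size=(num_draws, 1, 1))
    hidden = erf(W1 @ x.T[None] + b1)        # shape (num_draws, width, num_points)
    return (W2 @ hidden + b2)[:, 0, :]       # shape (num_draws, num_points)

x = np.array([[-1.0], [0.5]])                # two one-dimensional inputs
for width in [8, 64, 512, 2048]:
    f = sample_prior_outputs(width, x)
    # The empirical 2x2 covariance should change less and less as width grows.
    print(width, np.cov(f, rowvar=False))
```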

3 Large width behaviour of the posterior

While the results establishing convergence of the function space prior have been very influential and provided useful insights, many practical applications of BNNs require computation of expectations with respect to the function space posterior. Some previous works (e.g., neal1996; matthews2018gaussian) have shown good empirical agreement of the wide BNN posterior with the one induced by the GP limit $\pi$ for certain architectures, datasets, and likelihoods, but a theoretical proof of the asymptotic convergence was until now missing.

Here we prove that the sequence of exact function space posteriors $\tilde\pi_n$ (induced by the sequence of exact parameter space posteriors $\tilde\mu_n$) converges weakly to $\tilde\pi$, the Bayesian posterior induced by the weak limit $\pi$ of the priors $\pi_n$, under the following assumption on the likelihood.²

² We abuse notation in the rest of the document by treating random quantities and the values they take interchangeably; the correct interpretation should be clear from the context.

Assumption 1.

The targets $Y$ depend on the network parameters $\theta$ and the inputs $X$ only through the network outputs $f_\theta(X)$, and there exists a measure $\nu$ such that the distribution of $Y$ given the network outputs is absolutely continuous w.r.t. $\nu$ for every value of the outputs. Further, the resulting likelihood, written as a function $\mathcal{L}$ of the network outputs, does not change with $n$, and $\mathcal{L}$ is a continuous bounded function.

Put another way, Assumption 1 says that the data $Y$ is modelled as conditionally independent of $\theta$ given the network outputs $f_\theta(X)$ (i.e., $f_\theta(X)$ is a sufficient statistic), and that the corresponding conditional distribution does not change with $n$. Fortunately, this is satisfied by many popular likelihood choices, like the Gaussian

$$p(Y \mid f(X)) = \prod_{i} \mathcal{N}\!\bigl(y_i \,;\, f(x_i), \sigma^2\bigr), \tag{2}$$

with $\nu$ the Lebesgue measure, or the categorical likelihood over any number of categories

$$p(Y \mid f(X)) = \prod_{i} \prod_{c} \bigl[\operatorname{softmax}(f(x_i))\bigr]_c^{\,y_{ic}},$$

where $\operatorname{softmax}$ is the softmax function, each $y_i$ is assumed to be a one-hot encoding of the label, and $\nu$ is the counting measure on the (finite) set of one-hot label vectors. Any continuous transformation of the network outputs (like the softmax) can be assumed part of the likelihood in the statement of our main result (see Appendix B for the proofs).
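As a small illustration of Assumption 1, the sketch below (our own, with illustrative function names) writes the two likelihoods above purely as functions of the network outputs, so that they depend on $\theta$ and $X$ only through $f_\theta(X)$.

```python
import numpy as np

def gaussian_log_likelihood(outputs, targets, noise_var=0.1):
    """log p(Y | f(X)) for the Gaussian likelihood of Equation 2; depends only on the outputs."""
    resid = targets - outputs
    n = targets.size
    return -0.5 * (np.sum(resid**2) / noise_var + n * np.log(2 * np.pi * noise_var))

def categorical_log_likelihood(outputs, one_hot_targets):
    """log p(Y | f(X)) with a softmax link; outputs has shape (num_points, num_classes)."""
    logits = outputs - outputs.max(axis=1, keepdims=True)        # numerically stable softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return np.sum(one_hot_targets * log_probs)

# Any two parameter settings producing the same outputs yield the same likelihood value.
f_of_X = np.array([[0.3], [-1.2]])
Y = np.array([[0.1], [-1.0]])
print(gaussian_log_likelihood(f_of_X, Y))
```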

Proposition 1.

Assume $\pi_n$ and $\pi$ are defined on the usual Borel product $\sigma$-algebra, that Assumption 1 holds for the chosen likelihood $\mathcal{L}$, and that $\pi_n \Rightarrow \pi$. Then

$$\tilde\pi_n \Rightarrow \tilde\pi, \tag{3}$$

with $\tilde\pi_n$ and $\tilde\pi$ the Bayesian posteriors induced by the likelihood and, respectively, the priors $\pi_n$ and $\pi$.

Proposition 1 essentially says that whenever we can establish weak convergence of the prior, weak convergence of the posterior comes almost for free. Even though we usually cannot compute the exact parameter space posterior analytically, we will often be able to compute the exact function space posterior it (weakly) converges to in the wide limit.

A few technical comments are due. Firstly, we make the Borel product -algebra assumption only to exclude the cases where the coordinate projection is not continuous; all the prior work cited in Section 2 satisfies this assumption. Secondly, note that neither nor need to be Gaussian; whilst often will be (though exceptions exist; see yang2019v1; hron20), will not unless the prior and the likelihood are Gaussian. Finally, we emphasise by definition, i.e., even though does not appear in Proposition 1 explicitly, it is implicit in which means we could have equally well replaced Equation 3 by . In contrast to the finite case, in is not to be subscripted with since for ‘

’, the mapping between an infinite parameter vector

and is ill-defined, and some at first reasonably sounding definitions entail undesirable conclusions (see Section 4).

While Proposition 1 is useful, it does not by itself imply convergence of certain expected values, such as the predictive mean and variance, since these are expectations of unbounded functions. This is rectified by combining Corollary 1 with the results on convergence of expectations w.r.t. the prior (see (yang2019v1; yang2019v2) for an overview).

Corollary 1.

If $h$ is a real-valued continuous function such that $\mathbb{E}_{\pi_n}[|h|] \to \mathbb{E}_{\pi}[|h|]$ with $\mathbb{E}_{\pi}[|h|] < \infty$, then

$$\mathbb{E}_{\tilde\pi_n}[h] \to \mathbb{E}_{\tilde\pi}[h]. \tag{4}$$
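Under a Gaussian likelihood (Equation 2), the limiting function space posterior is a GP posterior, whose predictive mean and covariance have the standard closed form (rasmussen2006gaussian, page 16). Below is a generic NumPy sketch of these formulas; the RBF kernel is only a stand-in for the architecture-dependent NNGP kernel, which we do not derive here.

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=1.0):
    """Stand-in kernel; in the paper's setting this would be the NNGP kernel."""
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)

def gp_posterior(x_train, y_train, x_test, kernel, noise_var=0.1):
    """Predictive mean and covariance of a GP posterior with Gaussian observation noise."""
    K = kernel(x_train, x_train) + noise_var * np.eye(len(x_train))
    K_star = kernel(x_test, x_train)
    K_ss = kernel(x_test, x_test)
    alpha = np.linalg.solve(K, y_train)                     # K^{-1} y
    mean = K_star @ alpha
    cov = K_ss - K_star @ np.linalg.solve(K, K_star.T)
    return mean, cov

x_train = np.array([[-1.0], [0.0], [1.0]])
y_train = np.array([[0.2], [-0.1], [0.4]])
x_test = np.linspace(-2, 2, 5)[:, None]
mean, cov = gp_posterior(x_train, y_train, x_test, rbf_kernel)
print(mean.ravel(), np.diag(cov))
```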

4 Parameter space

In light of Section 3, one may wonder about the posterior behaviour of other quantities, such as the parameters $\theta$ themselves. Such questions are complicated by the fact that the dimension of the parameter space grows with $n$, implying that any $\mu_n$ and $\mu_m$, $n \neq m$, are not distributions on the same (measurable) space, a necessity for establishing any form of convergence. We choose the resolution provided by the 'infinite-width, finite fan-out' interpretation (matthews2018gaussian), where for each $n$, a countably infinite number of hidden units (convolutional filters, attention heads, etc.) and their corresponding parameters are instantiated, but only a finite number affects the layer output. In the FCN example, with $j$ the neuron index,

$$f^\ell_i(x) = b^\ell_i + \sum_{j=1}^{n^{\ell-1}} W^\ell_{ij}\, \phi\bigl(f^{\ell-1}_j(x)\bigr), \tag{5}$$

where the units and their parameters are instantiated for all $j \in \mathbb{N}$, but only the first $n^{\ell-1}$ of them enter the sum for the $n$-th network.

With the 'infinite-width, finite fan-out' construction, each parameter vector is embedded into the same infinite dimensional space, and we interpret $(\mu_n)_n$ as the corresponding sequence of prior distributions on this space (with the usual Borel product $\sigma$-algebra). From now on, all results should be viewed as regarding prior and posterior distributions constructed in this way unless explicitly stated otherwise.

Assumption 2.

Let the assumptions of Proposition 1 hold, and let the underlying sequence of BNNs be composed only of fully connected, convolutional, and attention layers whose number of units (neurons, filters, heads) goes to infinity with $n$, or of layers without trainable parameters (e.g., average pooling, residual connections). Further, let the parameter space prior be centred Gaussian with diagonal covariance whose non-zero entries are equal to $\sigma^2 / n$ for a fixed $\sigma^2$ and the appropriate layer input size $n$ (resp. a diagonal of all ones under the NTK parametrisation),³ except for the biases, whose variance may be just a fixed $\sigma_b^2$.

³ In the Neural Tangent Kernel (NTK) (Jacot2018ntk) parametrisation, weights are a priori i.i.d. $\mathcal{N}(0, 1)$, and scaled by $\sigma / \sqrt{n}$ as part of the forward pass, ensuring the induced function space distribution is the same. See (jsd2020infinite) for more details.

Proposition 2.

Let Assumption 2 hold, denote by $\theta_{\setminus b}$ all the parameters except for the top-layer bias, and denote the corresponding marginal of the parameter space posterior by $\tilde\mu_{n, \setminus b}$.

Then $\tilde\mu_{n, \setminus b}$ converges weakly to a product distribution in which the parameters with prior variance inversely proportional to $n$ are distributed as $\delta_0$ (a point mass at zero), and the others remain independent zero mean Gaussians with their prior variance (the biases under the standard parametrisation, and all parameters under the NTK parametrisation). The top-layer bias converges in distribution to the posterior induced by summing it with the bias-free limit of the network outputs, where the two are treated as independent under the prior and enter the likelihood only through their sum (see the end of the proof for the details). The joint posterior over all the parameters converges weakly to the product of the marginal limits over the top-layer bias and $\theta_{\setminus b}$.

Figure 1: Posterior sample mean and variance of a deep finite BNN converge to those of an NNGP (infinite BNN) as width increases (left to right). For a given training set, posterior mean and variances are computed using rejection sampling for the finite BNN, and in closed form as a GP posterior (see (rasmussen2006gaussian, page 16) for the NNGP, and Equation 16 in (lee2019wide) for the NTK). As width increases, the posterior predictions of the finite BNN and the NNGP become indistinguishable, but remain clearly different from those of the NTK. Presented is a deep fully connected network with Erf nonlinearities. All models use a Gaussian likelihood (Equation 2) with a fixed observation noise variance. NNGP and NTK predictions are computed with the Neural Tangents library (neuraltangents2020). For each width, the same total number of proposals is sampled, resulting in a different number of accepts per width. See Figure 2 for a numerical measure of convergence.
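For reference, NNGP and NTK predictions of the kind shown in Figure 1 can be obtained with the Neural Tangents library roughly as sketched below; the architecture, widths, and hyperparameters are illustrative, and the call signatures should be checked against the library version used.

```python
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# A fully connected architecture with Erf nonlinearities (depth and widths are illustrative).
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05),
)

x_train = jnp.linspace(-1.0, 1.0, 8).reshape(-1, 1)
y_train = jnp.sin(3.0 * x_train)
x_test = jnp.linspace(-2.0, 2.0, 50).reshape(-1, 1)

# Closed-form predictions of the infinite-width limits; diag_reg plays the role of the
# Gaussian observation noise variance.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-2)
nngp_mean, nngp_cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)
ntk_mean, ntk_cov = predict_fn(x_test=x_test, get='ntk', compute_cov=True)
print(nngp_mean.shape, ntk_mean.shape)
```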

To understand Proposition 2, we can note that the 'infinite-width, finite fan-out' construction ensures the posterior marginal over 'active' parameters (those used in computation of the outputs) is exactly the posterior distribution we would have obtained had no extra parameters been introduced. Since convergence on the infinite dimensional space typically implies convergence of all finite-dimensional marginals,⁴ we can draw conclusions about the behaviour of the 'active' marginals.

⁴ By the continuous mapping theorem for weak convergence, and, e.g., by definition for the total variation distance (modulo continuity, resp. measurability, of finite coordinate projections).

However, the types of conclusions we can draw also reveal the crucial limitation of this approach. For example, weak convergence in the 'infinite-width, finite fan-out' sense implies convergence of the expectation of any continuous bounded real-valued function of the parameters, including functions which only depend on the 'active' parameters of the $n$-th network for any chosen $n$; unfortunately, this does not guarantee the expectations are close for the $n$-th network itself! In other words, Proposition 2 tells us little about the behaviour of finite BNNs.

One might still wonder about the $\delta_0$ limit under the standard parametrisation, since most NN architectures output a constant if all the parameters but the biases are zero. The limit is due to the requirement that the prior is selected such that the function space distribution converges, which will generally force each individual weight's variance to vanish as $n \to \infty$ (see Equation 1 for an example). This typically translates into the same scaling under the posterior, essentially because neurons in each layer are exchangeable (matthews2018gaussian; garriga2019deep; hron20), implying no single parameter will be 'pushed too far away' from the prior by the likelihood (a similar effect can be seen in Example 1 in Appendix A). Since concentration in an increasingly 'small' region around zero is sufficient for weak convergence,⁵ the result follows.

⁵ 'Small' w.r.t., e.g., any metric which metrises the assumed product topology.

Furthermore, we emphasise that the limit in Proposition 2 only describes a point in the distribution space to which the 'infinite-width, finite fan-out' posteriors converge, but should not be interpreted as a distribution over the parameters of an infinitely wide BNN, since the map $\theta \mapsto f_\theta$ for '$n = \infty$' is not well-defined (as mentioned in Section 3). To see why, let us consider the single hidden layer FCN example with prior as in Equation 1. Our first impulse may be to define the map as the pointwise limit of the maps $\theta \mapsto f_{\theta, n}$, where each takes in a point and computes the function implemented by the NN with corresponding width index $n$ (as in Equation 5).

If the resulting function is to be real-valued, we need the infinite sum over hidden units to be well-defined and finite, which is only true if its summands are summable for every input at the same time. This is not an issue under the standard parametrisation (since the readout weights are a.s. zero in the limit), but it is not satisfied under the NTK parametrisation, where the support of the limiting weight distribution is all of $\mathbb{R}$.

Since the pointwise limit approach yields a constant function a.s. when the weights are a.s. zero, and a.s. diverges or is undefined (if the limit does not exist) when the weight distribution has full support (NTK parametrisation), we may instead try to look for a map $\theta \mapsto f_\theta$ whose pushforward of the limiting parameter space posterior agrees with the limiting function space posterior from Proposition 1 for all possible finite datasets. As demonstrated, this is not satisfied by the pointwise limit. It also cannot be satisfied by any other deterministic map if the limiting parameter distribution is degenerate, and there will be more than one solution if it has full support (at least if we only require agreement on a fixed countable marginal of $f$).⁶

⁶ Since the set of finite subsets of a countable set is countable, there is a bijection between the countably many evaluations of $f$ at the countably many relevant input points, and the countably many dimensions of the parameter space.

All in all, we see no obvious way of defining the parameter-to-function map at infinite width without placing additional restrictive assumptions. This is related to the dimensionality issue discussed at the beginning, which forced us to adopt the additional 'infinite-width, finite fan-out' assumption. Since the 'infinite-width, finite fan-out' interpretation proved much less innocuous than in the case of input-to-output mappings, where it is little more than a technicality (matthews2018gaussian; garriga2019deep; hron20), we study two alternative choices in Appendix A. Unfortunately, neither yields a parameter space limit free of the pathologies we observed here.

5 Experimental validation

We verify our results empirically by sampling from the exact posterior of a finite BNN using rejection sampling. For a given width $n$, we use the prior $p(\theta)$ as the proposal density, which envelops our target unnormalised posterior density

$$\tilde p(\theta \mid \mathcal{D}) = \exp\!\Bigl(-\tfrac{1}{2\sigma^2}\,\lVert Y - f_\theta(X)\rVert_2^2\Bigr)\, p(\theta), \tag{6}$$

where the first factor is the unnormalised Gaussian likelihood from Equation 2. In Figures 1 and 2 we confirm that as the finite BNN gets wider, its posterior sample mean and covariance converge to those of the NNGP.
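The following NumPy sketch (our own; the width, data, and noise level are illustrative) implements this rejection sampler for a single-hidden-layer FCN: proposals are drawn from the prior and accepted with probability equal to the unnormalised Gaussian likelihood in Equation 6, which is bounded by one.

```python
import numpy as np
from scipy.special import erf

def sample_prior(width, input_dim, rng, sigma_w=1.5, sigma_b=0.05):
    """One draw of theta for a single-hidden-layer FCN with Equation 1 style scaling."""
    W1 = rng.normal(0, sigma_w / np.sqrt(input_dim), size=(width, input_dim))
    b1 = rng.normal(0, sigma_b, size=width)
    W2 = rng.normal(0, sigma_w / np.sqrt(width), size=(1, width))
    b2 = rng.normal(0, sigma_b, size=1)
    return W1, b1, W2, b2

def forward(theta, x):
    W1, b1, W2, b2 = theta
    return (W2 @ erf(x @ W1.T + b1).T + b2[:, None]).T      # shape (num_points, 1)

def rejection_sample_posterior(x_train, y_train, width, num_proposals=20000,
                               noise_var=0.1, seed=0):
    """Exact posterior samples: propose from the prior, accept with probability
    exp(-||Y - f_theta(X)||^2 / (2 sigma^2)) <= 1 (the unnormalised likelihood)."""
    rng = np.random.default_rng(seed)
    accepted = []
    for _ in range(num_proposals):
        theta = sample_prior(width, x_train.shape[1], rng)
        sq_err = np.sum((y_train - forward(theta, x_train)) ** 2)
        if rng.uniform() < np.exp(-0.5 * sq_err / noise_var):
            accepted.append(theta)
    return accepted

x_train = np.array([[-0.5], [0.5]])
y_train = np.array([[0.2], [-0.3]])
samples = rejection_sample_posterior(x_train, y_train, width=64)
x_test = np.linspace(-1, 1, 5)[:, None]
preds = np.stack([forward(t, x_test) for t in samples])     # exact posterior function samples
print(len(samples), preds.mean(axis=0).ravel(), preds.var(axis=0).ravel())
```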

6 Conclusion

We proved that the sequence of exact posteriors of increasingly wide BNNs converges to the posterior induced by the infinite width limit of the prior and the same likelihood (when treated as a function of the NN outputs only). If computation of the infinite width limit posterior is tractable, our result opens a path to tractable function space inference even if evaluation of the parameter space posterior is intractable. We further studied conditions under which infinite width analysis in parameter space is possible, and outlined several potential pitfalls. In experiments, we have shown how to draw samples from the exact BNN posterior on small datasets, and validated our function space convergence predictions. We hope our work provides a theoretical basis for further study of exact BNN posteriors, and inspires development of more accurate BNN approximation techniques.

Figure 2: Posterior sample mean and covariance of a deep finite BNN converge to those of an NNGP (infinite BNN) as width increases (left to right, as measured by the relative Frobenius distance). Values are evaluated at equidistant test points on the input interval. See Figure 1 for a visual demonstration and further details about the setup.

Acknowledgements

We thank Lechao Xiao for discussion, help, and feedback on the initial draft, and Wei Hu and Matej Balog for useful discussions.

References

Appendix A Alternatives to the ‘infinite width, finite fan-out’ interpretation

The following is an (admittedly unconventional) attempt to gain intuition for the behaviour of the parameter space posterior in wide BNNs by studying the simpler Bayesian linear regression model, and in particular, by measuring the discrepancy between the prior and the posterior of this model in the 2-Wasserstein distance and the Kullback-Leibler (KL) divergence.

Example 1.

Let $\Phi \in \mathbb{R}^{D \times n}$ be a matrix of inputs and $Y \in \mathbb{R}^D$ the vector of corresponding regression targets. Assume the usual Bayesian linear regression model $Y = \Phi w + \varepsilon$ with prior $w \sim \mathcal{N}(0, n^{-1} I_n)$ and noise $\varepsilon \sim \mathcal{N}(0, \sigma^2 I_D)$; to avoid notational clutter, we take $\sigma^2 = 1$. The induced posterior has the closed form $\mathcal{N}(\mu_w, \Sigma_w)$ with

$$\Sigma_w = \bigl(n I_n + \Phi^\top \Phi\bigr)^{-1}, \qquad \mu_w = \Sigma_w \Phi^\top Y.$$

Note that if $\Phi$ was replaced, e.g., by the post-nonlinearity outputs of an FCN's penultimate layer, $\tfrac{1}{n}\Phi\Phi^\top$ would be converging almost surely to a constant matrix as $n \to \infty$ (for an overview see, e.g., yang2019v1; yang2019v2). To simplify our analysis, we assume the entries of $\Phi$ are uniformly bounded, with the implicit understanding that the results would have to be converted to high probability statements in order to hold for an actual BNN (e.g., using the results of yang2019v1).

(I) We look at the squared 2-Wasserstein distance between the posterior and the prior,

$$W_2^2\bigl(\mathcal{N}(\mu_w, \Sigma_w), \mathcal{N}(0, n^{-1} I_n)\bigr) = \lVert\mu_w\rVert_2^2 + \operatorname{tr}\Bigl[\bigl(\Sigma_w^{1/2} - n^{-1/2} I_n\bigr)^2\Bigr].$$

By the uniform entry bound assumption, $\lVert\mu_w\rVert_2^2 \leq \lambda_{\max}\,\lVert Y\rVert_2^2 / (n + \lambda_{\min})^2 = O(n^{-1})$, since $D$ is constant and $\lambda_{\max} = O(n)$, where $\lambda_{\min}$ and $\lambda_{\max}$ are the minimum and maximum eigenvalues of $\Phi\Phi^\top$. With a bit of algebra, one can also show that the value of the trace can be upper bounded by $D / n$. Hence the Wasserstein distance between the prior and the posterior shrinks to zero at a $O(n^{-1/2})$ rate.⁷

⁷ As an aside, if both $n$ and $D$ were allowed to vary, the distance would be proportional to $\sqrt{D / n}$.

If we used the 'infinite-width, finite fan-out' construction, the prior over $w$ would be converging weakly to $\delta_0$, and the same can be shown for the induced posterior analogously to Proposition 2. On one hand, the convergence of the prior-to-posterior 2-Wasserstein distance to zero could be interpreted as a confirmation of this result. On the other hand, $W_2\bigl(\mathcal{N}(0, n^{-1} I_n), \delta_0\bigr) = \sqrt{\operatorname{tr}(n^{-1} I_n)} = 1$ for all $n$, meaning that the prior (and thus the posterior) never approaches $\delta_0$ in $W_2$. This is because $W_2$ is defined w.r.t. the Euclidean metric here, which is inappropriate for the infinite dimensional parameter space since the (infinitely many) parameters are a.s. not square summable.⁸

⁸ Convergence to $\delta_0$ could be recovered by using a metric which metrises the product topology instead. Since that is the topology w.r.t. which weak convergence in Proposition 2 is defined, and convergence in the corresponding Wasserstein distance is equivalent to weak convergence together with convergence of the first two moments, a proof analogous to that of Proposition 2 yields the result.

(II) The KL-divergence between the posterior and the prior is

$$\mathrm{KL}\bigl(\mathcal{N}(\mu_w, \Sigma_w)\,\big\|\,\mathcal{N}(0, n^{-1} I_n)\bigr) = \tfrac{1}{2}\Bigl[\, n\,\lVert\mu_w\rVert_2^2 + \operatorname{tr}(n\Sigma_w) - n - \log\det(n\Sigma_w) \Bigr],$$

where we know that the sum of all the terms from the second to the last must be non-negative (it is equal to twice the KL-divergence between the same two distributions with their means set to zero). Hence we can lower bound the KL-divergence by $\tfrac{n}{2}\lVert\mu_w\rVert_2^2$, which is of order one (this can be obtained analogously to the upper bound derived in our discussion of the 2-Wasserstein distance). This is perhaps not surprising, as the KL-divergence is lower bounded by (two times the square of) the total variation distance (Pinsker's inequality), in which even $\mathcal{N}(0, \sigma^2 n^{-1} I_n)$, for any fixed $\sigma^2 > 0$, does not converge to $\delta_0$.
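The quantities in Example 1 are easy to compute numerically. The sketch below (ours; the dimensions, inputs, and targets are arbitrary) evaluates the closed-form posterior, the squared 2-Wasserstein distance (using the simplification available because the prior covariance is a multiple of the identity), and the KL-divergence to the $\mathcal{N}(0, n^{-1}I)$ prior, so that the $O(1/n)$ and order-one behaviours can be checked directly.

```python
import numpy as np

def linreg_posterior(Phi, Y, n):
    """Posterior of w for Y = Phi w + eps, eps ~ N(0, I), prior w ~ N(0, I/n)."""
    Sigma = np.linalg.inv(n * np.eye(n) + Phi.T @ Phi)
    mu = Sigma @ Phi.T @ Y
    return mu, Sigma

def w2_sq_to_prior(mu, Sigma, n):
    """Squared 2-Wasserstein distance to N(0, I/n); valid since I/n commutes with Sigma."""
    evals = np.linalg.eigvalsh(Sigma)
    return float(mu @ mu) + np.sum((np.sqrt(evals) - 1.0 / np.sqrt(n)) ** 2)

def kl_to_prior(mu, Sigma, n):
    """KL( N(mu, Sigma) || N(0, I/n) )."""
    evals = np.linalg.eigvalsh(Sigma)
    return 0.5 * (n * float(mu @ mu) + n * np.sum(evals) - n - np.sum(np.log(n * evals)))

rng = np.random.default_rng(0)
D = 5                                        # number of data points
Y = rng.normal(size=D)
for n in [10, 100, 1000]:
    Phi = rng.uniform(-1, 1, size=(D, n))    # bounded 'features', D x n
    mu, Sigma = linreg_posterior(Phi, Y, n)
    print(n, w2_sq_to_prior(mu, Sigma, n), kl_to_prior(mu, Sigma, n))
```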

While Example 1 assumes the standard parametrisation, it is instructive to compare with the conclusions that would have been drawn under the NTK parametrisation. Since the posterior remains Gaussian (with the mean and the covariance scaled respectively by $\sqrt{n}$ and $n$), it is easy to check that the KL-divergence remains unchanged (as it is invariant under any injective transformation), but the 2-Wasserstein distance grows by a multiplicative factor of $\sqrt{n}$, since

$$W_2^2(\rho_1, \rho_2) = \inf_{\gamma} \mathbb{E}_{(u, v) \sim \gamma}\bigl[\lVert u - v \rVert_2^2\bigr],$$

where the infimum ranges over all joint distributions $\gamma$ which have $\rho_1$ and $\rho_2$ as their respective $n$-dimensional marginals. In other words, the prior-to-posterior Wasserstein distance does not converge to zero when the Euclidean distance is used (it will converge to zero when used with the metric from Footnote 8 though, which is why the above does not contradict Proposition 2; cf. the last statement in (I), Example 1).

The above implies we need to be careful in interpreting rates of convergence, and in particular, that some discrepancy measures like the Wasserstein metrics necessitate a choice of measurement unit for the parameters. The KL-divergence does not suffer from such issues, but its relation with the total variation distance could make it excessively strict (convergence in total variation distance implies weak convergence but the reverse is not true; see the example in the last statement of (II) in Example 1).

While we cannot offer a definite answer to the above issues, it is worth pointing out that what we care about in practice is the accuracy of the function space approximation where issues of changing dimensionality disappear, and measurement units are dictated by the dataset we are trying to model. Hence a potentially more fruitful approach would be to refocus our attention from the parameter space to measurement and optimisation of function space approximation accuracy.

Appendix B Proofs

Proof of Proposition 1.

By the definition of weak convergence, all we need to show is that for any continuous bounded function $h$, the expectation $\mathbb{E}_{\tilde\pi_n}[h]$ converges to $\mathbb{E}_{\tilde\pi}[h]$. The key observation is that Assumption 1 ensures each posterior has a density w.r.t. the prior (e.g., schervish2012theory, theorem 1.31)

$$\frac{\mathrm{d}\tilde\pi_n}{\mathrm{d}\pi_n}(f) = \frac{\mathcal{L}(f(X))}{Z_n}, \qquad \text{where } Z_n := \mathbb{E}_{\pi_n}\bigl[\mathcal{L}(f(X))\bigr].$$

Substituting,

$$\mathbb{E}_{\tilde\pi_n}[h] = \frac{\mathbb{E}_{\pi_n}\bigl[h(f)\,\mathcal{L}(f(X))\bigr]}{\mathbb{E}_{\pi_n}\bigl[\mathcal{L}(f(X))\bigr]}.$$

Since $h \cdot \mathcal{L}$ is continuous bounded and $\pi_n \Rightarrow \pi$ by assumption, the numerator converges to $\mathbb{E}_{\pi}[h(f)\,\mathcal{L}(f(X))]$. Similarly, $\mathcal{L}$ is continuous bounded, and thus also the denominator converges to $\mathbb{E}_{\pi}[\mathcal{L}(f(X))]$. The proof is concluded by observing that theorem 1.31 (schervish2012theory) applies also to the density of $\tilde\pi$ w.r.t. $\pi$, so the limiting ratio equals $\mathbb{E}_{\tilde\pi}[h]$. ∎

Proof of Corollary 1.

Following the proof strategy of Proposition 1,

$$\mathbb{E}_{\tilde\pi_n}[h] = \frac{\mathbb{E}_{\pi_n}\bigl[h(f)\,\mathcal{L}(f(X))\bigr]}{\mathbb{E}_{\pi_n}\bigl[\mathcal{L}(f(X))\bigr]},$$

we see that all we need to prove is the convergence of the expectation in the numerator (convergence of the denominator was established in Proposition 1). Let $A_n := h(f_n)$ with $f_n \sim \pi_n$, and $B_n := h(f_n)\,\mathcal{L}(f_n(X))$. Since $h$ is continuous, $A_n \Rightarrow h(f)$ with $f \sim \pi$ by the continuous mapping theorem. Because the expectation of $|h|$ converges under the prior by assumption, $(A_n)_n$ is uniformly integrable by theorem 3.6 in (billingsey86). Because $\mathcal{L}$ is bounded by assumption, $(B_n)_n$ is uniformly integrable as well by definition. Since $B_n \Rightarrow h(f)\,\mathcal{L}(f(X))$ by the continuous mapping theorem again,

$$\mathbb{E}_{\pi_n}\bigl[h(f)\,\mathcal{L}(f(X))\bigr] \to \mathbb{E}_{\pi}\bigl[h(f)\,\mathcal{L}(f(X))\bigr]$$

by theorem 3.5 in (billingsey86). ∎

Proof of Proposition 2.

By (billingsey86, theorem 2.4), it will be sufficient to prove convergence of the finite dimensional marginals of the parameter space posterior $\tilde\mu_n$. Denoting the indices of such a marginal by $I$ and the corresponding sequence of marginal distributions by $(\tilde\mu_n^I)_n$, all we need to establish is that $\mathbb{E}_{\tilde\mu_n^I}[g]$ converges to the corresponding expectation under the claimed limit for any continuous bounded real-valued function $g$. By Assumption 1, we can rewrite both expectations in terms of the prior; for the former this yields

$$\mathbb{E}_{\tilde\mu_n^I}[g] = \frac{\mathbb{E}_{\mu_n}\bigl[g(\theta_I)\,\mathcal{L}(f_\theta(X))\bigr]}{\mathbb{E}_{\mu_n}\bigl[\mathcal{L}(f_\theta(X))\bigr]},$$

where $\theta_I$ denotes the appropriate subset of entries of $\theta$, and $\mu_n$ is the parameter space prior. By the same argument as in the proof of Proposition 1, the denominator converges to $\mathbb{E}_{\pi}[\mathcal{L}(f(X))]$, which is strictly positive by assumption. Hence we can focus on the numerator,

$$\mathbb{E}_{\mu_n}\bigl[g(\theta_I)\,\mathcal{L}(f_\theta(X))\bigr] = \mathbb{E}_{\theta_I}\Bigl[g(\theta_I)\,\mathbb{E}_{\theta_{\setminus I}}\bigl[\mathcal{L}(f_\theta(X)) \mid \theta_I\bigr]\Bigr],$$

where $\theta_{\setminus I}$ are all the entries of $\theta$ not in $I$, and the equality is by boundedness of both $g$ and $\mathcal{L}$, the Tonelli-Fubini theorem, and the diagonal Gaussian prior (implying the prior over $\theta_{\setminus I}$ does not depend on $\theta_I$). Also by the diagonal Gaussian assumption, we can use the change of variable formula to replace any weight in $\theta_I$ by $\tfrac{\sigma}{\sqrt{n}}\,\omega$ for an appropriate $\sigma$ and i.i.d. $\omega \sim \mathcal{N}(0, 1)$ (this step is of course not necessary under the NTK parametrisation). The r.h.s. above can then be rewritten as

$$\mathbb{E}_{\omega_I}\bigl[g_n(\omega_I)\,\psi_n(\omega_I)\bigr], \qquad \text{where} \quad \psi_n(\omega_I) := \mathbb{E}\bigl[\mathcal{L}(f_\theta(X)) \mid \omega_I\bigr]$$

and $g_n$ is $g$ composed with the above rescaling ($g_n = g$ under the NTK parametrisation).

Let us assume there are no top-layer biases for now, and add them back at a later point. Our current goal is to show that $\psi_n \to \psi := \mathbb{E}_{\pi}[\mathcal{L}(f(X))]$ pointwise, so that the limiting marginal can be identified with a point mass at zero under the standard, and with the standard Gaussian prior under the NTK parametrisation. Since both $g$ and $\mathcal{L}$ are bounded by assumption, the integrands are uniformly bounded, and thus the pointwise convergence can be combined with the dominated convergence theorem to conclude the proof. Since $g$ is continuous by assumption, $g_n(\omega_I) = g(\tfrac{\sigma}{\sqrt{n}}\omega_I) \to g(0)$ under the standard parametrisation, and $g_n(\omega_I) = g(\omega_I)$ for all values under the NTK parametrisation. One can easily verify that in both cases the resulting limit of $\mathbb{E}_{\tilde\mu_n^I}[g]$ matches the statement of Proposition 2, as required. All that remains is thus to show that $\psi_n \to \psi$ pointwise.

To do so, we will show that fixing a finite set of parameters while letting the others vary does not affect the convergence of the induced input-to-output mappings, i.e., that $\pi_n^{\theta_I} \Rightarrow \pi$, where $\pi_n^{\theta_I}$ denotes the function space distribution given the fixed $\theta_I$. We achieve this by a modification of the proof techniques in (matthews2018gaussian; garriga2019deep; hron20). The arguments therein are invariably built around an inductive application of the central limit theorem for infinitely exchangeable triangular arrays (eCLT) due to Blum1958 to linear projections of a finite subset of units. Since convergence in distribution of all such projections implies pointwise convergence of the characteristic function (the complex exponential is continuous bounded), convergence in distribution follows. It will thus be sufficient to show how to modify the recursive application of the eCLT.

What follows is a high-level description of this modification; a detailed description showcasing how to fill in the details on the FCN example can be found in Section B.1. Let us consider layer $\ell$ and the corresponding vector of activations (by the definition of $\psi_n$, we only need weak convergence of the activations evaluated on the training set). By theorem 2.4 in (billingsey86), weak convergence is implied by weak convergence of all finite marginals, so fix one such finite marginal and the corresponding finite set of unit indices.

As mentioned, weak convergence of the marginal is implied by weak convergence of its linear projections, so fix a projection vector and consider the resulting scalar random variable. Conveniently, it can always be rewritten as

$$\sum_{j} c_j\, \omega_j$$

for some (random) coefficients $c_j$ and Gaussian random variables $\omega_j$. Here the $c_j$ are essentially a combination of the projection coefficients and the inputs to the layer, whereas the $\omega_j$ are the Gaussian random variables constituting the layer's weights (either directly under the NTK, or by reparametrisation under the standard parametrisation); see Equations 8, 9 and 10 for an example.

Defining $F_n$ as the (finite) set of indices $j$ whose summands involve at least one of the fixed parameters, for all $n$ large enough

$$\sum_{j} c_j\, \omega_j = \sum_{j \in F_n} c_j\, \omega_j + \sum_{j \notin F_n} c_j\, \omega_j . \tag{7}$$

Using the $1/\sqrt{n}$ scaling of the summands, the first term on the r.h.s. can be shown to converge to zero in probability. Since only finitely many summands are excluded, and the remaining sum retains the proper $1/\sqrt{n}$ scaling, an argument analogous to that of matthews2018gaussian can be used to establish that it converges in distribution to the desired limit: it does not depend on any of the fixed parameters in the $\ell$-th layer, and its dependence on the fixed values in the previous layers vanishes as $n \to \infty$ by the recursive application of the above argument (of course, there are no terms that depend on the fixed values when $\ell = 1$). A simple application of Slutsky's lemmas (if $A_n \Rightarrow A$ and $B_n \to c$ in probability, with $c$ a constant, then $A_n + B_n \Rightarrow A + c$ and $B_n A_n \Rightarrow c A$) then yields that, for any fixed values of the conditioned parameters, the projection converges in distribution to the desired limit. Hence $\pi_n^{\theta_I} \Rightarrow \pi$ for any fixed value of $\theta_I$, and thus $\psi_n \to \psi$ pointwise, as desired.

All that remains is to add back the top-layer biases. As we have seen, the distribution without top-layer biases converges to the prior limit after subtraction of the top-layer biases, and thus the biases may simply be added on top. In the case of a Gaussian limit, this will result in an additive term in the covariance matrix as usual (this can be proved by a standard argument via characteristic functions, using the assumed Gaussian diagonal prior over all parameters). The posterior over the top-layer biases will then be the same as if we did a joint posterior update over the biases and a bias-free output whose prior is distributed according to the weak limit of the corresponding marginals of the function space prior. ∎

B.1 Proving $\psi_n \to \psi$ pointwise in a fully-connected network

Note: All of the references to (matthews2018gaussian) here are to the version accessible at https://arxiv.org/abs/1804.11271v2

The goal of this section is to adapt the original proof by matthews2018gaussian. We thus omit introducing the notation, as well as substantial discussion of the steps that do not require modification. We also modify our notation to match that of matthews2018gaussian to make comparison easier. It is thus advisable to consult section 2 in (matthews2018gaussian), which introduces the general notation, before reading on, and then to refer to section 6 whenever necessary.

We can follow the same steps as matthews2018gaussian right until the application of the Cramér-Wold device and the definition of projections and summands (matthews2018gaussian, p. 19-20). The application of the eCLT (resp. its modified version (matthews2018gaussian, p. 22, lemma 10)) essentially reduces the problem of establishing weak convergence of the projections to that of proving convergence of their first few moments. Following matthews2018gaussian, we define the projections and summands as in their equations (25) to (27), which we restate here for convenience:

(8)   $\mathcal{T}^{(\ell)}_n(\alpha) := \sum_{(i, x) \in \mathcal{J}} \alpha_{(i, x)} \bigl(f^{(\ell)}_i(x)[n] - b^{(\ell)}_i\bigr)$

(9)   $\gamma^{(\ell)}_{j, n}(\alpha) := \frac{\sigma_w}{\sqrt{H_{\ell-1}[n]}} \sum_{(i, x) \in \mathcal{J}} \alpha_{(i, x)}\, \epsilon^{(\ell)}_{ij}\, g^{(\ell-1)}_j(x)[n]$

(10)  $\mathcal{T}^{(\ell)}_n(\alpha) = \sum_{j=1}^{H_{\ell-1}[n]} \gamma^{(\ell)}_{j, n}(\alpha)$

Here $g^{(\ell-1)}_j(x)[n]$ is the $j$-th post-nonlinearity in the $(\ell-1)$-th layer of the $n$-th network evaluated at the point $x$, $H_{\ell-1}[n]$ is the width of the same layer, $\mathcal{J}$ identifies the finite marginal of the countably infinite vector of activations under consideration, and $\alpha$ is the Cramér-Wold projection vector.

Note that matthews2018gaussian define $f^{(\ell)}_i(x)[n]$ as the sum of the inner product of the relevant weight vector with the previous layer's post-nonlinearities and the bias term $b^{(\ell)}_i$, which is why the bias is subtracted in Equation 8. In contrast, we omitted the subtraction of the bias in the previous section to reduce notational clutter. From now on, we stick with the notation of Equations 8, 9 and 10. The last point where our notation differs from matthews2018gaussian is in omitting the explicit dependence of the projections and summands on the marginal under consideration.

Our goal is to prove convergence of the outputs given that a finite subset of the parameters is fixed to an arbitrary value. As in (matthews2018gaussian), we approach this problem by an inductive application of their lemma 10 to the projections, combined with theorem 3.5 from (billingsey86). We will thus need to prove that the sums defined in Equation 10 satisfy all the desired properties for any choice of the layer $\ell$, the marginal $\mathcal{J}$, and the projection vector $\alpha$; recall that $\ell = 1$ corresponds to the pre-nonlinearities in the first layer, and thus even for a single hidden layer neural network the output is only reached at $\ell = 2$. This is important because the input dimension is fixed, and thus the distribution of the first layer pre-nonlinearities need not be Gaussian once some parameters are fixed: we can trivially fix more parameters than the input dimension and thus control the value of any finite subset of the pre-nonlinearities in the first layer. As you may suspect, the fact that we can only ever affect a finite subset of these activations will be crucial in the next paragraphs.

We turn to applying lemma 10 from (matthews2018gaussian) for $\ell \geq 2$. As the lemma applies only to exchangeable sequences, our first step will be to isolate the non-exchangeable terms. matthews2018gaussian prove exchangeability of the summands over the index $j$ in their lemma 8. The key observation here is that the same proof still works if we exclude all indices $j$ for which at least one of the weights $\epsilon^{(\ell)}_{ij}$ with $i \in I_\mathcal{J}$ is fixed (where $I_\mathcal{J}$ is the set of width indices in $\mathcal{J}$), i.e., if we exclude all summands for which at least one weight is fixed through the conditioning. Defining the set of excluded indices accordingly, we can rewrite Equation 10 as