1 Introduction
This paper considers the probabilistic inference problem. Given a known distribution $p(z,x)$ and observing some specific value $x$, one wishes to infer $z$ (e.g., predict its mean or variance). Unless $p$ is simple, there is no simple form for the posterior $p(z|x)$. Many approximate methods exist, including variants of MCMC, message-passing, Laplace's method, and variational inference (VI). All these methods produce large errors on some problems. Thus, diagnostic techniques are of high interest to understand when a given inference method will perform well on a given problem.
For MCMC, there are several widely-used diagnostics. The potential scale reduction factor diagnostic [Gelman_1992_InferenceIterativeSimulation] runs multiple chains and then compares within-chain and between-chain variances. The effective sample size diagnostic considers correlations in a single chain. Diagnostics of this type are an active research area [Vehtari_2020_Ranknormalizationfoldinglocalization].
While successful, these diagnostics are grounded in the fact that MCMC is asymptotically exact. That is, under mild conditions MCMC will converge to the stationary distribution if run long enough. Informally, diagnostics for MCMC only need to diagnose “has the chain converged?”, rather than “has it converged to the correct distribution?”.
For inference methods that are asymptotically approximate, different diagnostics are needed. This paper is in the line of simulation-based diagnostics. These are a fairly radical departure. Rather than measuring how well inference performs on the given data $x$, these estimate how well inference performs on average over data generated from the model. These diagnostics repeatedly sample $(z, x) \sim p$ and then do inference on the simulated $x$. The power of this approach is that the true latent $z$ corresponding to the observed $x$ is known.
To the best of our knowledge, this simulation-based approach was first pursued by Cook_2006_ValidationSoftwareBayesian, who sample $(z, x) \sim p$ and then perform inference to approximately sample from $p(z|x)$. The quantiles of each component of $z$ generated this way are compared to those generated directly from the prior. This can be done visually (looking at histograms) or by using a Kolmogorov-Smirnov test. More recently, Yao_2018_YesDidIt suggest testing for symmetry. These error measures may not be appropriate for all situations. First, these measures can be challenging to automate or interpret, since they do not provide a scalar quantity but rather a procedure to perform in each dimension. Second, there could be inference errors not detected by looking at univariate distributions.
In this paper, we observe that some inference methods, such as Laplace's method and variational inference (VI), do not simply give a set of samples, but an approximate distribution $q(z|x)$. This turns out to enable diagnostics that would be impossible with MCMC.
Our central idea is simple. Suppose that on input $x$, inference returns a distribution $q(z|x)$ approximating $p(z|x)$. Define the joint distribution $q(z,x) = p(x)\, q(z|x)$. Then, our diagnostic is an estimate of the symmetric KL-divergence between $p(z,x)$ and $q(z,x)$. This essentially measures how far $q(z|x)$ is from $p(z|x)$ on average, over simulated datasets.
The key observation is that the symmetric divergence induces cancellations among the unknown normalization terms. Specifically, if $(z, x) \sim p(z,x)$ and $z' \sim q(z'|x)$, then we can simulate
$$r = \log \frac{p(z,x)\, q(z'|x)}{q(z|x)\, p(z',x)},$$
and the expected value of $r$ is the symmetric divergence. To compute this diagnostic, one must: (1) simulate $(z,x) \sim p(z,x)$ and $z' \sim q(z'|x)$, and (2) compute $\log p(\cdot,x)$ and $\log q(\cdot\,|x)$ at both $z$ and $z'$. We do not need to be able to evaluate $p(x)$, even though it is part of the definitions of $p(z|x)$ and $q(z,x)$.
We also show that this idea can be extended to situations with conditional or hidden variables. As an example of the latter, we show that it can be used with importance-weighted inference methods that generate many samples and then select one according to the importance weights [Burda_2015_ImportanceWeightedAutoencoders]. Experiments show that the diagnostic gives practical measures of performance, for regular VI, for Laplace's method, and for importance-weighted variants of both.
1.1 Notation
The KL-divergence is $\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{p(z)} \log \frac{p(z)}{q(z)}$. Sans-serif font marks random variables; this disambiguates conflicting conventions in machine learning and information theory. $\mathrm{KL}\big(q(\mathsf{z}|x) \,\|\, p(\mathsf{z}|x)\big)$ is a divergence over $z$ for a fixed $x$. Meanwhile, $\mathrm{KL}\big(q(\mathsf{z}|\mathsf{x}) \,\|\, p(\mathsf{z}|\mathsf{x})\big)$ is the conditional divergence [Cover_2006_Elementsinformationtheory], with an expectation over both $\mathsf{z}$ and $\mathsf{x}$. In all cases, symmetric divergences are defined as $\mathrm{SKL}(p, q) = \mathrm{KL}(p \,\|\, q) + \mathrm{KL}(q \,\|\, p)$.
2 A New Simulation-Based Diagnostic
This section gives a novel simulation-based diagnostic based on the symmetric KL-divergence. The key idea is that some inference methods (e.g. VI or Laplace's method) do not just give approximate samples, but an approximate distribution that can be evaluated at any point. This enables certain diagnostics that would be impossible with just a set of samples. Again, let $p(z,x)$ be the target. We consider approximate inference methods that input some $x$ and produce a distribution over $z$ that approximates $p(z|x)$. We denote that approximation as $q(z|x)$.
One might hope to use the KL divergence $\mathrm{KL}\big(q(\mathsf{z}|x) \,\|\, p(\mathsf{z}|x)\big)$ as a diagnostic. This is almost never tractable, since $p(z|x)$ is unknown. One can instead compute the evidence lower bound (ELBO), which is equal to $\log p(x)$ minus the KL-divergence. The ELBO precisely measures the relative error for different algorithms, but gives little information about the absolute error, since $\log p(x)$ is unknown.
Instead, our diagnostic is based on the symmetric KL-divergence. The basic idea of the diagnostic is to define a joint distribution $q(z,x) = p(x)\, q(z|x)$ with the same distribution over $x$ as $p(z,x)$. Then, cancellations make it possible to estimate the joint symmetric divergence between $p(z,x)$ and $q(z,x)$. This is formalized in the following result.
Theorem 1.
Given $p(z,x)$ and $q(z|x)$, define $q(z,x) = p(x)\, q(z|x)$. Then,
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{x})\big) = \mathbb{E}\left[\log \frac{p(\mathsf{z}, \mathsf{x})\, q(\mathsf{z}'|\mathsf{x})}{q(\mathsf{z}|\mathsf{x})\, p(\mathsf{z}', \mathsf{x})}\right],$$
where $(\mathsf{z}, \mathsf{x}) \sim p(z,x)$ is sampled from the model distribution and $\mathsf{z}' \sim q(z'|\mathsf{x})$ is sampled from the approximating distribution.
Pseudocode for how one would use this result is given in Alg. 1. One attractive aspect is that the output is the mean of a set of $M$ independent quantities. This makes it easy to produce uncertainty measures such as confidence intervals. These bound how far the estimated diagnostic may be from the true symmetric divergence.
2.1 Inference in Conditional Models
Many inference problems are conditional, meaning one is given a model $p(z, y \mid x)$, with no distribution specified over $x$. After observing $x$ and $y$, the goal is to predict $z$. For example, in regression or classification problems, $x$ would represent the input features, $y$ the output values/labels, and $z$ the latent parameters.
In these cases, inference takes as input a pair $(x, y)$ and produces a distribution $q(z \mid y, x)$ approximating $p(z \mid y, x)$. It is easy to see that the following generalization of Thm. 1 holds. This is given by taking Thm. 1, substituting $y$ for $x$, and then conditioning all distributions on the fixed value $x$.
Corollary 2.
Given $p(z, y \mid x)$ and $q(z \mid y, x)$, define $q(z, y \mid x) = p(y \mid x)\, q(z \mid y, x)$. Then,
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{y} \mid x),\, q(\mathsf{z}, \mathsf{y} \mid x)\big) = \mathbb{E}\left[\log \frac{p(\mathsf{z}, \mathsf{y} \mid x)\, q(\mathsf{z}' \mid \mathsf{y}, x)}{q(\mathsf{z} \mid \mathsf{y}, x)\, p(\mathsf{z}', \mathsf{y} \mid x)}\right],$$
where $(\mathsf{z}, \mathsf{y}) \sim p(z, y \mid x)$ is sampled from the model distribution and $\mathsf{z}' \sim q(z' \mid \mathsf{y}, x)$ is sampled from the approximating distribution.
Pseudocode for how the diagnostic would be used with conditional models is given as Alg. 2. It is critical that $x$ is not a random variable: it is the actual observed input data. The simulated latent variables $z_m$ and datasets $y_m$ are conditioned on $x$.
In order to run this algorithm, one must be able to perform the following operations: (1) simulate $(z, y) \sim p(z, y \mid x)$; (2) compute $\log p(z, y \mid x)$ for a given $z$, $y$, and $x$; (3) compute $\log q(z \mid y, x)$ for a given $z$, $y$, and $x$. It is not necessary to be able to evaluate $p(y \mid x)$, despite the fact that it is part of the definition of $q(z, y \mid x)$.
Example Results. Before moving on to more complex cases, we give some examples of the use of this diagnostic. Fig. 1 shows an example of running the diagnostic on five example models using variational inference (VI) and Laplace's method, which maximizes $\log p(z, x)$ to get a point $\hat z$ and uses a Gaussian centered at $\hat z$ with a covariance determined by the Hessian of $\log p(z, x)$ at $\hat z$. We also compare to an "adjusted" Laplace's method that better matches the curvature if $\hat z$ is only an approximate maximum. This instead uses a mean of $\hat z - H^{-1} g$, where $g$ is the gradient and $H$ the Hessian of $\log p(z, x)$ at $\hat z$. A fuller description of the models and inference algorithms is given in Sec. 5.
3 Inference with Augmented Variables
Many approximate inference methods use the idea of augmentation. The idea is to create an extra variable $u$ and then approximate $p(z, u \mid x)$ with $q(z, u \mid x)$. Why would this be useful? The basic reason is that many powerful approximating families are obtained by integrating out other random variables. Such families often do not have tractable densities $q(z \mid x)$, but can be represented as the marginal of some density $q(z, u \mid x)$. If we choose the augmenting distribution $p(u \mid z, x)$ in a way that is "easy" for $q(u \mid z, x)$ to match, then augmented inference might be nearly as accurate as directly approximating $p(z \mid x)$ with $q(z \mid x)$. Agakov_2004_AuxiliaryVariationalMethod introduced the idea of auxiliary variational inference, which fits this form.
To apply the diagnostic to inference with hidden variables, we need another version of Thm. 1. This can be proven by taking Thm. 1 and substituting $(z, u)$ for $z$.
Corollary 3.
Given $p(z, x)$, an augmenting distribution $p(u \mid z, x)$, and $q(z, u \mid x)$, define $p(z, u, x) = p(z, x)\, p(u \mid z, x)$ and $q(z, u, x) = p(x)\, q(z, u \mid x)$. Then,
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{u}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{u}, \mathsf{x})\big) = \mathbb{E}\left[\log \frac{p(\mathsf{z}, \mathsf{u}, \mathsf{x})\, q(\mathsf{z}', \mathsf{u}' \mid \mathsf{x})}{q(\mathsf{z}, \mathsf{u} \mid \mathsf{x})\, p(\mathsf{z}', \mathsf{u}', \mathsf{x})}\right],$$
where $(\mathsf{z}, \mathsf{x}) \sim p(z, x)$ is sampled from the model distribution, $\mathsf{u} \sim p(u \mid \mathsf{z}, \mathsf{x})$ is sampled from the augmenting distribution, and $(\mathsf{z}', \mathsf{u}') \sim q(z', u' \mid \mathsf{x})$ is sampled from the approximating distribution.
The diagnostic is useful because (by the chain rule of KL-divergence)
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{u}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{u}, \mathsf{x})\big) \ge \mathrm{SKL}\big(p(\mathsf{z}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{x})\big),$$
i.e. the diagnostic yields an upper bound on the error of $q(z \mid x)$.
The corresponding algorithm is given as Alg. 3. This shows an important computational constraint. We assume that $q(z, u \mid x)$ is given as a "black box". Thus, the augmented model $p(z, u, x)$ must be chosen so that it is tractable to simulate $\mathsf{u} \sim p(u \mid z, x)$. The next section will give a special case of this result for a certain class of approximate augmented inference methods.
For $m = 1, \dots, M$:
- Simulate $(z_m, x_m) \sim p(z, x)$.
- Infer to get $q(z, u \mid x_m)$. (fix $x = x_m$)
- Simulate $u_m \sim p(u \mid z_m, x_m)$.
- Simulate $(z'_m, u'_m) \sim q(z, u \mid x_m)$. (fix $x = x_m$)
- Use $r_m = \log \dfrac{p(z_m, u_m, x_m)\, q(z'_m, u'_m \mid x_m)}{q(z_m, u_m \mid x_m)\, p(z'_m, u'_m, x_m)}$.
For $m = 1, \dots, M$:
- Simulate $(z_m, x_m) \sim p(z, x)$.
- Run inference to find a base distribution $q(z \mid x_m)$.
- Simulate $z_{m,2}, \dots, z_{m,K} \sim q(z \mid x_m)$.
- Simulate $z'_{m,1}, \dots, z'_{m,K} \sim q(z \mid x_m)$. (fix $x = x_m$)
- Use $r_m = \log \dfrac{1}{K} \sum_{k=1}^{K} \dfrac{p(z'_{m,k}, x_m)}{q(z'_{m,k} \mid x_m)} - \log \dfrac{1}{K} \left( \dfrac{p(z_m, x_m)}{q(z_m \mid x_m)} + \sum_{k=2}^{K} \dfrac{p(z_{m,k}, x_m)}{q(z_{m,k} \mid x_m)} \right)$.
4 Importance Sampling
Self-normalized importance sampling is a classic Monte-Carlo method [Owen_2013_MonteCarlotheory]. Given any distribution , one can approximately sample from the posterior by drawing a set of samples
, selecting one with probability proportional to the importance weights
and then returning the final sample Call the resulting densityOne might hope to directly apply the diagnostic to self-normalizing importance sampling. However, this cannot be done because it is intractable to to evaluate However, we can identify augmented distributions, and thereby upper-bound the symmetric divergence. Define the distribution
$$\bar p(z_1, \dots, z_K, x) = p(z_1, x) \prod_{k=2}^{K} q(z_k \mid x). \tag{1}$$
This can be seen as an augmented distribution with the hidden variables $z_2, \dots, z_K$ augmenting the original $z_1$. Define also
$$q_{\mathrm{IS}}(z_1, \dots, z_K, x) = p(x) \left[\prod_{k=1}^{K} q(z_k \mid x)\right] \frac{w_1}{\frac{1}{K}\sum_{k=1}^{K} w_k}, \qquad w_k = \frac{p(z_k, x)}{q(z_k \mid x)}. \tag{2}$$
It is not immediately obvious that this augments the self-normalized importance-sampling density introduced at the beginning of this section (or indeed that this is a valid density at all). However, it was recently shown [Domke_2018_ImportanceWeightingVariational] that the following algorithm samples from it.
Claim 4.
The following process yields a sample from $q_{\mathrm{IS}}(z_1, \dots, z_K, x)$ as defined in Eq. 2.
- Draw $z'_1, \dots, z'_K \sim q(z \mid x)$ independently.
- Choose $k$ with probability proportional to $w_k = p(z'_k, x) / q(z'_k \mid x)$.
- Set $(z_1, \dots, z_K)$ to $(z'_1, \dots, z'_K)$ with $z'_1$ and $z'_k$ swapped.
Informally, this algorithm draws $K$ samples from $q(z \mid x)$, then swaps one to be in the first position, chosen according to the importance weights. Thus, this is a valid augmentation of the self-normalized importance-sampling distribution.
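The resampling-and-swap process of Claim 4 can be sketched as follows (the Gaussian model, the broad proposal, and all function names are our illustrative assumptions). The first coordinate of the returned vector is the self-normalized importance-sampling posterior sample; the remaining coordinates are the hidden augmenting variables.

```python
import numpy as np

rng = np.random.default_rng(0)

def snis_sample(x, K, sample_q, log_p, log_q):
    """Draw (z_1,...,z_K) by the process of Claim 4: K i.i.d. proposal draws,
    then swap one, chosen proportionally to its importance weight, into slot 1."""
    z = np.array([sample_q(x) for _ in range(K)])
    logw = np.array([log_p(zk, x) - log_q(zk, x) for zk in z])
    w = np.exp(logw - logw.max())        # stabilized importance weights
    k = rng.choice(K, p=w / w.sum())     # P(k) proportional to w_k
    z[0], z[k] = z[k], z[0]              # swap the chosen draw to position 1
    return z                             # z[0] is the SNIS posterior sample

# Toy example (illustrative): z ~ N(0,1), x|z ~ N(z,1); exact posterior N(x/2, 1/2).
def logN(z, mu, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (z - mu) ** 2 / var

log_p = lambda z, x: logN(z, 0, 1) + logN(x, z, 1)  # log p(z, x)
log_q = lambda z, x: logN(z, 0, 2)                  # broad proposal, ignores x
sample_q = lambda x: rng.normal(0, np.sqrt(2))

x_obs = 1.0
first = [snis_sample(x_obs, K=50, sample_q=sample_q,
                     log_p=log_p, log_q=log_q)[0] for _ in range(4000)]
print(np.mean(first))  # should approach the posterior mean x_obs/2 = 0.5
```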
This shows that Eq. 1 and Eq. 2 augment the target and the self-normalized importance-sampling density, and thus that $\mathrm{SKL}(\bar p, q_{\mathrm{IS}})$ upper-bounds $\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{x}), q_{\mathrm{IS}}(\mathsf{z}, \mathsf{x})\big)$. It can be shown that the (non-symmetric) divergence from $q_{\mathrm{IS}}$ to $\bar p$ asymptotically decreases at a $1/K$ rate [Maddison_2017_FilteringVariationalObjectives, Domke_2018_ImportanceWeightingVariational].
4.1 Diagnosing Importance Weighted Inference
The above discussion shows that adding importance sampling can improve any distribution $q(z \mid x)$. Even better results can be obtained by explicitly optimizing $q$ to work well after augmentation. Importance-weighted variational inference directly performs an optimization to maximize the ELBO from $q_{\mathrm{IS}}$ to $\bar p$, equivalent to minimizing the KL-divergence between the augmented distributions. This ELBO can be simplified using the relationship
$$\mathrm{ELBO}\big(q_{\mathrm{IS}} \,\big\|\, \bar p\big) = \mathbb{E} \log \frac{1}{K} \sum_{k=1}^{K} \frac{p(\mathsf{z}_k, \mathsf{x})}{q(\mathsf{z}_k \mid \mathsf{x})}, \qquad \mathsf{z}_1, \dots, \mathsf{z}_K \sim q(z \mid \mathsf{x}), \tag{3}$$
which follows from cancellations between $\bar p$ and $q_{\mathrm{IS}}$, together with the observation that the argument of the expectation is constant with respect to permutations of $z_1, \dots, z_K$. This objective was originally introduced in the context of importance-weighted auto-encoders [Burda_2015_ImportanceWeightedAutoencoders] (without explicitly identifying $\bar p$ and $q_{\mathrm{IS}}$) and subsequently studied by various others [Cremer_2017_ReinterpretingImportanceWeightedAutoencoders, Domke_2018_ImportanceWeightingVariational, Bachman_2015_TrainingDeepGenerative, Naesseth_2018_VariationalSequentialMonte, Le_2018_AutoEncodingSequentialMonte].
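The objective in Eq. 3 is straightforward to estimate by Monte Carlo. A sketch follows, with a toy Gaussian model and over-dispersed proposal of our own choosing, and a log-sum-exp stabilization of the average weight; the bound tightens toward $\log p(x)$ as $K$ grows.

```python
import numpy as np

rng = np.random.default_rng(0)

def logN(z, mu, var):  # Gaussian log-density, vectorized
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (z - mu) ** 2 / var

# Toy setting (illustrative): z ~ N(0,1), x|z ~ N(z,1), proposal q(z|x) = N(0,2).
def iw_elbo(x, K, n_mc=4000):
    """Monte Carlo estimate of E log (1/K) sum_k p(z_k,x)/q(z_k|x)  (Eq. 3)."""
    z = rng.normal(0.0, np.sqrt(2.0), size=(n_mc, K))      # z_{1:K} ~ q, i.i.d.
    logw = logN(z, 0.0, 1.0) + logN(x, z, 1.0) - logN(z, 0.0, 2.0)
    m = logw.max(axis=1, keepdims=True)                    # log-sum-exp trick
    log_avg_w = m[:, 0] + np.log(np.exp(logw - m).mean(axis=1))
    return log_avg_w.mean()

x_obs = 1.0
vals = {K: iw_elbo(x_obs, K) for K in (1, 10, 100)}
for K, v in vals.items():
    print(K, v)  # increases with K toward log p(x_obs)
```

For K = 1 this reduces to the ordinary ELBO; the monotone improvement with K mirrors the $1/K$ rate discussed above.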
The following result specializes Cor. 3 to the case of importance-weighted inference.
Corollary 5.
$$\mathrm{SKL}\big(\bar p,\, q_{\mathrm{IS}}\big) = \mathbb{E}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(\mathsf{z}'_k, \mathsf{x})}{q(\mathsf{z}'_k \mid \mathsf{x})} - \log \frac{1}{K} \left( \frac{p(\mathsf{z}, \mathsf{x})}{q(\mathsf{z} \mid \mathsf{x})} + \sum_{k=2}^{K} \frac{p(\mathsf{z}_k, \mathsf{x})}{q(\mathsf{z}_k \mid \mathsf{x})} \right)\right],$$
where $(\mathsf{z}, \mathsf{x})$ is sampled from the model distribution, $\mathsf{z}'_1, \dots, \mathsf{z}'_K$ are sampled from the approximating distribution $q(z \mid \mathsf{x})$, and $\mathsf{z}_2, \dots, \mathsf{z}_K$ are also sampled from the approximating distribution.
A proof is in the supplement. Unlike Cor. 2 and Cor. 3, the result is not trivial. The main idea is to substitute $\bar p$ for $p$ and $q_{\mathrm{IS}}$ for $q$ in Thm. 1. Then, many expressions can be simplified based on the particular forms of $\bar p$ and $q_{\mathrm{IS}}$. Finally, we can observe that the argument of the expectation is independent of permutations of $\mathsf{z}_2, \dots, \mathsf{z}_K$. This allows a final simplification.
5 Experiments
Models. We use five models, described in detail in Sec. 7 (Supplement). glm_binomial is a hierarchical model of the size of a bird population over time. heart_transplants models the survival times of patients after surgery. hospitals models the number of deaths in different hospitals. ionosphere is a Bayesian logistic regression model on a classic dataset. concrete is a Bayesian linear regression model, included as a baseline because the exact posterior is Gaussian.
Optimization. The first inference algorithm we consider is Laplace's method, which produces a multivariate Gaussian approximation $q(z \mid x)$. For this method, we run Adam on $\log p(z, x)$ with a step size of 0.01 for the first half of iterations and 0.001 for the second half. The resulting point $\hat z$ is the mean. Finally, the Hessian $H$ of $\log p(z, x)$ at $\hat z$ is used to estimate the covariance of $q$ as $\Sigma = -H^{-1}$. This method fails to match the local curvature of $\log p(z, x)$ when $\hat z$ is far from an optimum. We also consider an "adjusted" Laplace's method that instead uses a mean of $\hat z - H^{-1} g$, where $g$ is the gradient of $\log p(z, x)$ at $\hat z$. This guarantees that $\log q$ has the same gradient at $\hat z$ as $\log p$.
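The adjustment can be illustrated on a one-dimensional quadratic target (a toy of our own construction, where the Laplace covariance is exact and only the mean is affected): when the optimizer stops short of the optimum, shifting the mean by a Newton step recovers the exact mean, while plain Laplace does not.

```python
import numpy as np

# Illustrative 1-D target with Gaussian posterior: log p(z, x) = -(z - 3)^2 / 2 + const.
logp = lambda z: -0.5 * (z - 3.0) ** 2
grad = lambda z: -(z - 3.0)
hess = lambda z: -1.0

# A few steps of gradient ascent that deliberately stop short of the optimum.
z_hat = 0.0
for _ in range(20):
    z_hat += 0.1 * grad(z_hat)

H = hess(z_hat)
var = -1.0 / H          # Laplace covariance (1-D): -1/H
mean_plain = z_hat      # plain Laplace: mean is the (inexact) optimizer output

# "Adjusted" Laplace: shift the mean by a Newton step so that the Gaussian
# approximation matches the gradient of log p at z_hat.
mean_adj = z_hat - grad(z_hat) / H

print(mean_plain, mean_adj)  # adjusted mean recovers the exact mean 3.0 here
```

For a quadratic target the adjustment is exact; for general targets it only corrects the first-order mismatch at $\hat z$.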
We also consider variational inference. We initialize $q$ to a standard Gaussian and optimize with Adam, with a step size of 0.001 for the first half of iterations and 0.0001 for the second half. We estimate the gradient of the ELBO using the reparameterization trick, with the "sticking the landing" estimator [Roeder_2017_StickingLandingSimple].
Constraints and Transformations. Often, random variables have constraints: they are not supported over all of the reals. As is common [Kucukelbir_2017_AutomaticDifferentiationVariational], we deal with this through a process of transformations. For our models, it is sufficient to consider two cases:
- Random variables defined over the non-negative reals. In this case, we replace the variable with a new random variable that is unconstrained.
- Random variables defined on a closed interval. Here, we transform to an unconstrained variable.
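For concreteness, here is a sketch of one standard choice of such transformations (the specific log and logit maps below are common in automatic-transform VI, but their use here is our assumption; the log-Jacobian terms would be added to the transformed log-density):

```python
import numpy as np

# Change of variables for constrained latents. If y = T(z) is unconstrained,
# the transformed density picks up a log-Jacobian: log p(y) = log p(z) + log|dz/dy|.

def to_unconstrained_pos(z):   # z > 0  ->  y = log z
    return np.log(z)

def logdet_pos(y):             # dz/dy = e^y, so log|dz/dy| = y
    return y

def to_unconstrained_interval(z, a, b):   # z in (a, b) -> logit((z - a)/(b - a))
    u = (z - a) / (b - a)
    return np.log(u) - np.log1p(-u)

def logdet_interval(y, a, b):
    # Inverse: z = a + (b - a) * sigmoid(y); dz/dy = (b - a) * s * (1 - s).
    s = 1.0 / (1.0 + np.exp(-y))
    return np.log(b - a) + np.log(s) + np.log1p(-s)

# Round-trip check for the interval transform on (a, b) = (0, 1):
z = 0.25
y = to_unconstrained_interval(z, 0.0, 1.0)
s = 1.0 / (1.0 + np.exp(-y))
z_back = 0.0 + 1.0 * s
print(z_back)  # recovers 0.25
```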
Results. Results comparing VI and the two variants of Laplace’s method are shown in Fig. 1, averaging over simulated datasets. Laplace’s method is reasonably accurate in many cases, but usually has a “floor” of accuracy it does not exceed. The adjustment to Laplace’s method is often helpful and never harmful. VI performs better with many iterations. For these models, the diagnostic shows that inference error is reasonably low with many iterations, but not quite “exact”.
Fig. 2 shows the results of importance sampling with a proposal computed using Laplace’s method with adjustment. Using more samples yields a clear improvement.
6 Discussion
This paper proposed a new diagnostic for approximate inference methods. This is a simulation-based diagnostic, meaning it is computed by repeatedly simulating latent variables along with datasets and running inference on each dataset. The central idea is that cancellations in unknown constants make it possible to estimate a symmetric divergence. This is notable in being a simple, scalar quantity with a clear information-theoretic interpretation. It can also be computed in a fully automated way along with error measures like confidence intervals. We showed that the diagnostic can be extended to augmented inference methods, in particular importance-weighted inference. Empirically, the method gives reasonable diagnostic information on several test models.
While Thm. 1 is quite simple, there are numerous points worth clarifying in its use as a diagnostic:
What the diagnostic measures. One possibly counter-intuitive aspect of this diagnostic (like all simulation-based diagnostics) is that it does not use the actual observed data $x$. Rather, it measures the typical error, averaged over $(\mathsf{z}, \mathsf{x})$ simulated from the model $p(z, x)$. It is therefore essential that the prior and likelihood be selected so that $p$ yields realistic simulated datasets. In particular, very broad priors might lead to "nonsense" observations that are unrepresentative of the data that would be seen in practice.
Computational considerations. In order to compute this diagnostic, one must be able to perform several operations: (1) simulate $(\mathsf{z}, \mathsf{x}) \sim p(z, x)$ and $\mathsf{z}' \sim q(z \mid \mathsf{x})$; (2) compute $\log p(z, x)$ for a given $z$ and $x$; (3) compute $\log q(z \mid x)$ for a given $z$ and $x$. Crucially, it is not necessary to be able to evaluate $p(x)$. This is true despite the fact that $p(x)$ is part of the definition of $q(z, x)$.
Other representations of the diagnostic. The quantity that the diagnostic is representing can be written in a different form that emphasizes that it measures errors over $z$. This uses the notion of a conditional divergence. It is not hard to show, in the setting of Thm. 1, that
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{x})\big) = \mathrm{SKL}\big(p(\mathsf{z} \mid \mathsf{x}),\, q(\mathsf{z} \mid \mathsf{x})\big).$$
This is true because of cancellations between the marginals of $p$ and $q$ over $\mathsf{x}$, due to the inclusion of the $p(x)$ term in $q(z, x)$.
Use with randomized inference methods. In practice, approximate inference algorithms are often non-deterministic. This is not reflected by the notation $q(z \mid x)$. With non-determinism, no modification to the diagnostic technique in Alg. 1 is needed; only the interpretation is slightly different. To formalize what the diagnostic measures in this case, define $q_\omega(z \mid x)$ to be the approximate posterior produced when $\omega$ are the random numbers underlying the algorithm. Then, a diagnostic can be defined as the expected symmetric divergence between $p$ and $q_\omega$, with the expectation taken over $\omega$. This is still a very reasonable measure of the accuracy of inference. For simplicity, our presentation mostly neglects the issue of non-determinism in approximating distributions.
6.1 Related Work
There are several lines of related work not mentioned so far. One recent line of work explores inference diagnostics based on Stein’s method [Gorham_2017_MeasuringSampleQuality, Gorham_2015_MeasuringSampleQuality]. The idea is to create sets of functions whose true expectations must be zero. Any deviation from zero in those functions indicates inference failure. This diagnostic can be used with methods that are asymptotically approximate. However, it is intended for cases when error decreases to zero. There is no claim that the magnitude of the diagnostic is a good measure of the usefulness of an approximate posterior.
The test proposed by Geweke_2004_GettingItRight is an interesting early diagnostic that repeatedly simulates datasets in a non-independent manner. The idea is to iteratively simulate a dataset given the current latent variable, run inference, and then sample a new latent variable from the approximate posterior. Then, one compares the expectations of test functions under this process to those computed on exact samples from the model. If inference is exact, these expectations should match. We prefer an approach where each simulation is independent, since this is easier to parallelize, avoids correlations between simulations, and makes it easier to compute error measures like confidence intervals.
Bidirectional MCMC [Grosse_2015_Sandwichingmarginallikelihood] runs MCMC on repeated simulated datasets to get upper and lower bounds on the marginal likelihood $p(x)$. This is intended as a technique to evaluate the quality of a model, not as a diagnostic for inference. Still, in principle one could use these bounds to transform an ELBO into bounds on the KL-divergence. One drawback is the expense of repeatedly running MCMC; typically, variational inference is used in settings where MCMC would be too expensive.
6.2 Limitations and Future Work
This work has several limitations shared with all simulation-based diagnostics. First, computing them requires repeating inference numerous times, with an associated cost. Second, these methods can be overly pessimistic when used with extremely broad or uninformative priors. It is important that the model is chosen so that simulated data are representative of the datasets one cares about. Third, the diagnostic measures average accuracy over data simulated from the prior, as opposed to the expected accuracy for a particular dataset. (Put another way, the diagnostic is arguably frequentist rather than Bayesian.)
One might be concerned about the success of this diagnostic when used with variational inference methods. Namely, VI typically minimizes $\mathrm{KL}(q \,\|\, p)$, while the diagnostic is based on the symmetric divergence. Informally, VI cares about finding a distribution that is close in a "mode finding" divergence, while the diagnostic measures both "mode finding" and "mode spanning". It is possible that a distribution could be close in VI's objective, yet yield a high diagnostic value. This is arguably a flaw not of the diagnostic, but of variational inference. One interesting future direction would be to investigate recent VI variants that try to minimize other divergences [Li__RenyiDivergenceVariational, Dieng_2017_VariationalInferencechi].
In future work, it would be interesting to address MCMC methods. Of course, most MCMC methods are not suitable for this framework. However, some methods like annealed importance sampling [Neal_1998_AnnealedImportanceSampling] formally create augmented target and proposal densities at a variety of “temperatures”. It may be possible to use the diagnostic proposed here to measure the symmetric divergence between these augmented distributions. This could potentially offer a diagnostic for MCMC with the unusual property that the diagnostic going to zero is both necessary and sufficient to guarantee convergence to the stationary distribution.
References
7 Models
GLM Binomial. This is a hierarchical model of the number of peregrine pairs in the French Jura per year; the years covered by the data are rescaled before modeling. The model and data are from Kery_2012_Bayesianpopulationanalysis.
Heart Transplants. This is a model of a hypothetical population of patients who underwent surgery, some of whom survived [Lunn_2013_BUGSbookpractical, Ex. 3.5.1]. Survivors were tracked to record the number of years each lived post-surgery, which is assumed to follow an exponential distribution. Note that, when generating synthetic datasets for this model, we always use the same set of survival variables, independent of the simulated number of survivors. This is done because of the difficulties posed by having different dimensionality in different realizations of the posterior. While not fully in keeping with the spirit of the original model, this still defines a perfectly valid probabilistic model and test of the diagnostic.
Hospitals. This is a hierarchical model of the mortality rate of English hospitals performing heart surgery [Lunn_2013_BUGSbookpractical, Ex. 10.1.1]. The data give, for each hospital, the number of operations and the corresponding number of deaths. The logit of the true mortality rate of each hospital is a Gaussian with unknown mean and variance; the latent variables are the per-hospital rates together with that mean and variance. The original model has a very wide prior on the mean and variance, which leads to the problems with broad priors discussed in Sec. 6. We use the above model with more modest priors.
Ionosphere. This is a classic dataset for binary classification. We model it as a Bayesian logistic regression problem with a standard Gaussian prior over the weights.
Concrete. This is a well-known dataset for linear regression. We model it as a Bayesian linear regression problem with a standard Gaussian prior over the weights. This model is particularly notable because the true posterior is exactly Gaussian. Since both Laplace's method and variational inference can exactly represent such a posterior, this provides an important test of whether the diagnostic can correctly recognize inference success when it occurs.
8 Theory
See Thm. 1
Proof.
The symmetric divergence is equal to
$$\mathrm{SKL}\big(p(\mathsf{z}, \mathsf{x}),\, q(\mathsf{z}, \mathsf{x})\big) = \mathbb{E}_{p(\mathsf{z}, \mathsf{x})} \log \frac{p(\mathsf{z}, \mathsf{x})}{p(\mathsf{x})\, q(\mathsf{z} \mid \mathsf{x})} + \mathbb{E}_{q(\mathsf{z}', \mathsf{x})} \log \frac{p(\mathsf{x})\, q(\mathsf{z}' \mid \mathsf{x})}{p(\mathsf{z}', \mathsf{x})} = \mathbb{E} \log \frac{p(\mathsf{z}, \mathsf{x})\, q(\mathsf{z}' \mid \mathsf{x})}{q(\mathsf{z} \mid \mathsf{x})\, p(\mathsf{z}', \mathsf{x})}.$$
In the first equality we use the fact that $q(z, x) = p(x)\, q(z \mid x)$, while in the second we pull out a factor of $\log p(\mathsf{x})$ from each term, which cancels. This is the claimed result. ∎
See Cor. 5
Proof.
Now, note that most factors of $\bar p$ and $q_{\mathrm{IS}}$ cancel, leaving the expectation of the difference between the two log average-weight terms. Finally, note that the expectation is unchanged under permutations of the order of $\mathsf{z}_2, \dots, \mathsf{z}_K$. Thus, the expectation is unchanged if we replace the distribution over these variables with i.i.d. sampling from $q(z \mid \mathsf{x})$.
∎