An Easy to Interpret Diagnostic for Approximate Inference: Symmetric Divergence Over Simulations

02/25/2021
by   Justin Domke, et al.

It is important to estimate the errors of probabilistic inference algorithms. Existing diagnostics for Markov chain Monte Carlo methods assume inference is asymptotically exact, and are not appropriate for approximate methods like variational inference or Laplace's method. This paper introduces a diagnostic based on repeatedly simulating datasets from the prior and performing inference on each. The central observation is that it is possible to estimate a symmetric KL-divergence defined over these simulations.


1 Introduction

This paper considers the probabilistic inference problem. Given a known distribution $p(z, x)$ and observing some specific value $x$, one wishes to infer $p(z \mid x)$ (e.g. to predict its mean or variance). Unless $p$ is simple, there is no simple form for the posterior $p(z \mid x)$. Many approximate methods exist, including variants of MCMC, message-passing, Laplace’s method, and variational inference (VI). All of these methods produce large errors on some problems. Thus, diagnostic techniques are of high interest for understanding when a given inference method will perform well on a given problem.

For MCMC, there are several widely used diagnostics. The potential scale reduction factor diagnostic [Gelman_1992_InferenceIterativeSimulation] runs multiple chains and then compares within-chain and between-chain variances. The effective sample size diagnostic considers autocorrelations within a single chain. Diagnostics of this type are an active research area [Vehtari_2020_Ranknormalizationfoldinglocalization].
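For readers unfamiliar with the potential scale reduction factor, the following is a minimal sketch of the classic (non-split, non-rank-normalized) form of $\hat{R}$, assuming each chain is a list of scalar draws of equal length. The function names are ours, and the independent Gaussian "chains" are only a stand-in for real MCMC output; modern variants such as the rank-normalized version cited above differ in details.

```python
import random
import statistics


def potential_scale_reduction(chains):
    """Classic R-hat for a list of equal-length chains of scalar draws.

    Compares the between-chain variance B with the mean within-chain variance W.
    Values near 1 are consistent with convergence; values well above 1 are not.
    """
    m = len(chains)
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    grand_mean = statistics.fmean(chain_means)
    # Between-chain variance, scaled by the chain length n.
    b = n / (m - 1) * sum((cm - grand_mean) ** 2 for cm in chain_means)
    # Average of the within-chain sample variances.
    w = statistics.fmean(statistics.variance(c) for c in chains)
    # Pooled estimate of the target variance.
    var_hat = (n - 1) / n * w + b / n
    return (var_hat / w) ** 0.5


def simulate_chains(means, n=1000, seed=0):
    """Independent Gaussian draws standing in for MCMC output (one list per chain)."""
    rng = random.Random(seed)
    return [[rng.gauss(mu, 1.0) for _ in range(n)] for mu in means]


if __name__ == "__main__":
    print(potential_scale_reduction(simulate_chains([0.0, 0.0, 0.0, 0.0])))  # close to 1
    print(potential_scale_reduction(simulate_chains([0.0, 0.0, 3.0, 3.0])))  # well above 1
```

Note what this checks: only whether the chains agree with each other. Chains that all converge to the same wrong distribution would still give $\hat{R} \approx 1$, which is exactly the limitation discussed next.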

While successful, these diagnostics are grounded in the fact that MCMC is asymptotically exact. That is, under mild conditions MCMC will converge to the stationary distribution if run long enough. Informally, diagnostics for MCMC only need to diagnose “has the chain converged?”, rather than “has it converged to the correct distribution?”.

For inference methods that are asymptotically approximate, different diagnostics are needed. This paper is in the line of simulation-based diagnostics. These are a fairly radical departure. Rather than measuring how well inference performs on the given $x$, these estimate how well inference performs on average over data generated from the model. These diagnostics repeatedly sample $z, x \sim p(z, x)$ and then do inference on the simulated $x$. The power of this approach is that the true latent $z$ corresponding to each simulated $x$ is known.

To the best of our knowledge, this simulation-based approach was first pursued by Cook_2006_ValidationSoftwareBayesian, who sample $z, x \sim p(z, x)$ and then perform inference to approximately sample $z' \sim p(z' \mid x)$. The quantiles of each component $z'_i$ generated this way are compared to those generated directly from the prior $p(z_i)$. This can be done visually (looking at histograms) or by using a Kolmogorov-Smirnov test. More recently, Yao_2018_YesDidIt suggest testing for symmetry. These error measures may not be appropriate for all situations. First, they can be challenging to automate or interpret, since they do not provide a scalar quantity but rather a procedure to perform in each dimension. Second, there could be inference errors that are not detected by looking at univariate distributions.
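A simplified sketch of the check described above (not a reproduction of Cook et al.'s exact procedure): if inference were exact, each $z'$ would be marginally distributed as the prior, so its prior-CDF value would be uniform on $(0, 1)$, and a Kolmogorov-Smirnov distance measures the departure. The toy model ($z \sim N(0,1)$, $x \mid z \sim N(z, 1)$, exact posterior $N(x/2, 1/2)$) and the hypothetical approximation $q(z \mid x) = N(x/2, \texttt{q\_var})$ are assumptions for illustration.

```python
import math
import random


def normal_cdf(t):
    """CDF of the standard normal distribution."""
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))


def ks_statistic_uniform(us):
    """Kolmogorov-Smirnov distance between the empirical CDF of `us` and Uniform(0, 1)."""
    us = sorted(us)
    n = len(us)
    return max(max((i + 1) / n - u, u - i / n) for i, u in enumerate(us))


def prior_matching_check(q_var, n_sims=20000, seed=0):
    """Simulate z ~ p(z), x ~ p(x | z), z' ~ q(z | x); return the KS distance between
    the prior-CDF values of z' and Uniform(0, 1).

    Toy model (assumed): z ~ N(0, 1), x | z ~ N(z, 1); hypothetical q(z | x) = N(x/2, q_var).
    """
    rng = random.Random(seed)
    us = []
    for _ in range(n_sims):
        z = rng.gauss(0.0, 1.0)                           # z ~ p(z)
        x = rng.gauss(z, 1.0)                             # x ~ p(x | z)
        z_prime = rng.gauss(x / 2.0, math.sqrt(q_var))    # z' ~ q(z | x)
        # Under exact inference z' ~ N(0, 1), so normal_cdf(z') ~ Uniform(0, 1).
        us.append(normal_cdf(z_prime))
    return ks_statistic_uniform(us)
```

With `q_var = 0.5` (exact) the distance is at the level of sampling noise, while an overdispersed `q_var = 2.0` gives a clearly larger distance. The check is per component: it would say nothing about errors in the dependence between components of $z$.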

In this paper, we observe that some inference methods, such as Laplace’s method and variational inference (VI), do not simply give a set of samples, but an approximate distribution $q(z \mid x)$. This turns out to enable diagnostics that would be impossible with MCMC.

Our central idea is simple. Suppose that on input $x$, inference returns a distribution $q(z \mid x)$. Define the joint distribution

$$q(z, x) = p(x)\, q(z \mid x).$$

Then, our diagnostic is an estimate of the symmetric KL-divergence between $p(z, x)$ and $q(z, x)$. This essentially measures how far $q(z \mid x)$ is from $p(z \mid x)$ on average, over simulated datasets.

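The key observation is that the symmetric divergence induces cancellations among the unknown normalization terms. Specifically, if $z, x \sim p(z, x)$ and $z' \sim q(z' \mid x)$, then we can simulate

$$\delta = \log \frac{p(z, x)}{q(z \mid x)} - \log \frac{p(z', x)}{q(z' \mid x)},$$

and the expected value of $\delta$ is the symmetric divergence. To compute this diagnostic, one must (1) simulate $z, x \sim p(z, x)$ and $z' \sim q(z' \mid x)$, and (2) compute $p(z, x)$ and $q(z \mid x)$. We do not need to be able to evaluate $p(x)$, even though it is part of the definitions of $p(z \mid x)$ and $q(z, x)$.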

We also show that this idea can be extended to situations with conditional or hidden variables. As an example of the latter, we show that it can be used with importance-weighted inference methods that generate many samples and then select one according to the importance weights [Burda_2015_ImportanceWeightedAutoencoders]. Experiments show that the diagnostic gives practical measures of performance for regular VI, for Laplace’s method, and for importance-weighted variants of both.

1.1 Notation

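The KL-divergence is $\mathrm{KL}\big(p(\mathsf{Z}) \,\|\, q(\mathsf{Z})\big) = \mathbb{E}_{p(z)} \log \frac{p(z)}{q(z)}$. Sans-serif font marks random variables. This disambiguates conflicting conventions in machine learning and information theory. $\mathrm{KL}\big(p(\mathsf{Z} \mid x) \,\|\, q(\mathsf{Z} \mid x)\big)$ is a divergence over $\mathsf{Z}$ for a fixed $x$. Meanwhile, $\mathrm{KL}\big(p(\mathsf{Z} \mid \mathsf{X}) \,\|\, q(\mathsf{Z} \mid \mathsf{X})\big)$ is the conditional divergence [Cover_2006_Elementsinformationtheory], with an expectation over both $\mathsf{Z}$ and $\mathsf{X}$. In all cases, the symmetric divergence between $p$ and $q$ is defined as the sum $\mathrm{KL}(p \,\|\, q) + \mathrm{KL}(q \,\|\, p)$.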
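As a concrete instance of these definitions, here is a small sketch for univariate Gaussians using the standard closed form of the KL-divergence (variances, not standard deviations, are passed). The function names are ours.

```python
import math


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for univariate Gaussians with variances v1, v2."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def sym_kl_gauss(m1, v1, m2, v2):
    """Symmetric divergence KL(p || q) + KL(q || p)."""
    return kl_gauss(m1, v1, m2, v2) + kl_gauss(m2, v2, m1, v1)


if __name__ == "__main__":
    p, q = (0.0, 0.5), (0.0, 1.0)
    print(kl_gauss(*p, *q), kl_gauss(*q, *p))  # the two directions differ
    print(sym_kl_gauss(*p, *q))                # their sum
```

For $p = N(0, 1/2)$ and $q = N(0, 1)$ the two directions are roughly 0.097 and 0.153, and their sum is exactly $1/4$: the symmetric divergence counts both the "mode finding" and "mode spanning" errors.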

2 A New Simulation-Based Diagnostic

This section gives a novel simulation-based diagnostic based on the symmetric KL-divergence. The key idea is that some inference methods (e.g. VI or Laplace’s method) do not just give approximate samples, but an approximate distribution that can be evaluated at any point. This enables certain diagnostics that would be impossible with just a set of samples. Again, let $p(z, x)$ be the target. We consider approximate inference methods that input some $x$ and produce a distribution over $z$ that approximates $p(z \mid x)$. We denote that approximation as $q(z \mid x)$.

One might hope to use the KL-divergence $\mathrm{KL}\big(q(\mathsf{Z} \mid x) \,\|\, p(\mathsf{Z} \mid x)\big)$ as a diagnostic. This is almost never tractable, since $p(x)$ is unknown. One can instead compute the evidence lower bound (ELBO) $\mathbb{E}_{q(z \mid x)} \log \frac{p(z, x)}{q(z \mid x)}$, which is equal to $\log p(x)$ minus the KL-divergence. The ELBO precisely measures the relative error of different algorithms, but gives little information about the absolute error, since $\log p(x)$ is unknown.
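Spelling out the ELBO identity used in this argument, using only $p(z, x) = p(x)\, p(z \mid x)$:

```latex
\begin{aligned}
\mathrm{ELBO}(x)
  &= \mathbb{E}_{q(z \mid x)}\left[\log \frac{p(z, x)}{q(z \mid x)}\right]
   = \mathbb{E}_{q(z \mid x)}\left[\log p(x) + \log \frac{p(z \mid x)}{q(z \mid x)}\right] \\
  &= \log p(x) - \mathrm{KL}\big(q(\mathsf{Z} \mid x) \,\|\, p(\mathsf{Z} \mid x)\big).
\end{aligned}
```

Because $\log p(x)$ does not depend on $q$, the difference between two ELBOs equals the difference between their KL-divergences, but neither ELBO alone reveals the KL-divergence itself.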

Instead, our diagnostic is based on the symmetric KL-divergence. The basic idea is to define a joint distribution $q(z, x)$ with the same distribution over $x$ as $p(z, x)$. Then, cancellations make it possible to estimate the joint symmetric divergence between $p(z, x)$ and $q(z, x)$. This is formalized in the following result.

Theorem 1.

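Given $p(z, x)$ and $q(z \mid x)$, define $q(z, x) = p(x)\, q(z \mid x)$. Then,

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{X})\big) = \mathbb{E}\left[\log \frac{p(z, x)}{q(z \mid x)} - \log \frac{p(z', x)}{q(z' \mid x)}\right],$$

where $z, x \sim p(z, x)$ is sampled from the model distribution and $z' \sim q(z' \mid x)$ is sampled from the approximating distribution.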

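For $n = 1, \dots, N$:

  • Simulate $z_n, x_n \sim p(z, x)$.

  • Run inference on $x_n$ to get $q(z \mid x_n)$. (fix $x_n$)

  • Simulate $z'_n \sim q(z \mid x_n)$. (fix $x_n$)

Use $\frac{1}{N} \sum_{n=1}^{N} \delta_n$, where $\delta_n = \log \frac{p(z_n, x_n)}{q(z_n \mid x_n)} - \log \frac{p(z'_n, x_n)}{q(z'_n \mid x_n)}$.

Algorithm 1 Computing the proposed diagnostic.

Input $x$.

For $n = 1, \dots, N$:

  • Simulate $z_n, y_n \sim p(z, y \mid x)$.

  • Run inference on $(x, y_n)$ to get $q(z \mid x, y_n)$. (fix $y_n$)

  • Simulate $z'_n \sim q(z \mid x, y_n)$. ($x$, $y_n$ fixed)

Use $\frac{1}{N} \sum_{n=1}^{N} \delta_n$, where $\delta_n = \log \frac{p(z_n, y_n \mid x)}{q(z_n \mid x, y_n)} - \log \frac{p(z'_n, y_n \mid x)}{q(z'_n \mid x, y_n)}$.

Algorithm 2 Diagnostic for conditional models.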

Pseudocode for how one would use this result is given in Alg. 1. One attractive aspect is that the output is the mean of a set of $N$ independent quantities. This makes it easy to produce uncertainty measures such as confidence intervals, which indicate how far the estimated diagnostic may be from the true symmetric divergence.
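A minimal sketch of Alg. 1 follows, on a toy conjugate model chosen (as an assumption) so the answer is known in closed form: $z \sim N(0,1)$, $x \mid z \sim N(z, 1)$, so $p(z \mid x) = N(x/2, 1/2)$. The "inference method" `make_inference` is hypothetical: it returns $q(z \mid x) = N(x/2, \texttt{q\_var})$, which is exact only when `q_var = 0.5`. The interval is a normal approximation over the $N$ independent deltas.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def log_joint(z, x):
    # Toy model (assumed): z ~ N(0, 1), x | z ~ N(z, 1). Only p(z, x) is evaluated.
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def make_inference(q_var):
    """Hypothetical inference: maps x to (sampler, log density) of q(z | x) = N(x/2, q_var)."""
    def infer(x):
        mean = x / 2.0
        sd = math.sqrt(q_var)

        def sample(rng):
            return rng.gauss(mean, sd)

        def log_q(z):
            return log_normal(z, mean, q_var)

        return sample, log_q

    return infer


def symmetric_divergence_diagnostic(infer, n_sims=20000, seed=0):
    """Alg. 1: returns the mean of the deltas and a normal-approximation 95% interval."""
    rng = random.Random(seed)
    deltas = []
    for _ in range(n_sims):
        z = rng.gauss(0.0, 1.0)        # z ~ p(z)
        x = rng.gauss(z, 1.0)          # x ~ p(x | z)
        sample_q, log_q = infer(x)     # run inference on the simulated x
        z_prime = sample_q(rng)        # z' ~ q(z' | x), with x fixed
        deltas.append((log_joint(z, x) - log_q(z)) - (log_joint(z_prime, x) - log_q(z_prime)))
    mean = statistics.fmean(deltas)
    half_width = 1.96 * statistics.stdev(deltas) / math.sqrt(n_sims)
    return mean, (mean - half_width, mean + half_width)


def closed_form_divergence(q_var):
    """Symmetric KL between N(m, 1/2) and N(m, q_var); identical for every x in this toy."""
    return 0.5 * (0.5 / q_var + q_var / 0.5 - 2.0)
```

Nothing in the loop evaluates $p(x)$. With `make_inference(0.5)` every delta is zero up to rounding, since $\log p(z, x) - \log q(z \mid x) = \log p(x)$ for any $z$; with `make_inference(1.0)` the estimate should land near the closed-form value $0.25$.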

2.1 Inference in Conditional Models

Many inference problems are conditional, meaning one is given a model $p(z, y \mid x)$, with no distribution specified over $x$. After observing $x$ and $y$, the goal is to predict $z$. For example, in regression or classification problems, $x$ would represent the input features, $y$ the output values/labels, and $z$ the latent parameters.

In these cases, inference takes as input a pair $(x, y)$ and produces a distribution $q(z \mid x, y)$ approximating $p(z \mid x, y)$. It is easy to see that the following generalization of Thm. 1 holds. It is obtained by taking Thm. 1, substituting $y$ for $x$, and then conditioning all distributions on the fixed value $x$.

Corollary 2.

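Given $p(z, y \mid x)$ and $q(z \mid x, y)$, define $q(z, y \mid x) = p(y \mid x)\, q(z \mid x, y)$. Then,

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{Y} \mid x) \,\|\, q(\mathsf{Z}, \mathsf{Y} \mid x)\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{Y} \mid x) \,\|\, p(\mathsf{Z}, \mathsf{Y} \mid x)\big) = \mathbb{E}\left[\log \frac{p(z, y \mid x)}{q(z \mid x, y)} - \log \frac{p(z', y \mid x)}{q(z' \mid x, y)}\right],$$

where $z, y \sim p(z, y \mid x)$ is sampled from the model distribution and $z' \sim q(z' \mid x, y)$ is sampled from the approximating distribution.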

Pseudocode for how the diagnostic would be used with conditional models is given as Alg. 2. It is critical that $x$ is not a random variable: it is the actual observed input data. The simulated latent variables and datasets are conditioned on $x$.

To use this algorithm, one must be able to perform the following operations: (1) simulate $z, y \sim p(z, y \mid x)$ and $z' \sim q(z' \mid x, y)$; (2) compute $p(z, y \mid x)$ for a given $z$, $y$, and $x$; (3) compute $q(z \mid x, y)$ for a given $z$, $x$, and $y$. It is not necessary to be able to evaluate $p(y \mid x)$, despite the fact that it is part of the definition of $q(z, y \mid x)$.
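A minimal sketch of Alg. 2, assuming a hypothetical one-parameter regression model $z \sim N(0, 1)$, $y_i \mid z, x_i \sim N(z\, x_i, 1)$ with fixed inputs `X_INPUTS` (illustrative values). The hypothetical inference routine returns the exact Gaussian posterior with its variance multiplied by `var_inflation`, so `1.0` is exact. Note that the inputs are never resampled: only $z$ and $y$ are simulated.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


# Fixed, observed inputs x (not random). Hypothetical values for illustration.
X_INPUTS = (0.5, -1.0, 2.0)


def log_joint(z, ys, xs):
    # Toy conditional model (assumed): z ~ N(0, 1), y_i | z, x_i ~ N(z * x_i, 1).
    return log_normal(z, 0.0, 1.0) + sum(log_normal(y, z * x, 1.0) for x, y in zip(xs, ys))


def infer(xs, ys, var_inflation):
    """Hypothetical inference: exact posterior mean, variance scaled by var_inflation."""
    precision = 1.0 + sum(x * x for x in xs)
    mean = sum(x * y for x, y in zip(xs, ys)) / precision
    return mean, var_inflation / precision


def conditional_diagnostic(var_inflation, xs=X_INPUTS, n_sims=20000, seed=0):
    """Alg. 2: the Alg. 1 estimator with every distribution conditioned on the fixed xs."""
    rng = random.Random(seed)
    deltas = []
    for _ in range(n_sims):
        z = rng.gauss(0.0, 1.0)                       # z ~ p(z | x)
        ys = [rng.gauss(z * x, 1.0) for x in xs]      # y ~ p(y | z, x)
        mean, var = infer(xs, ys, var_inflation)      # q(z | x, y)
        z_prime = rng.gauss(mean, math.sqrt(var))     # z' ~ q(z' | x, y)
        delta = (log_joint(z, ys, xs) - log_normal(z, mean, var)) - (
            log_joint(z_prime, ys, xs) - log_normal(z_prime, mean, var)
        )
        deltas.append(delta)
    return statistics.fmean(deltas)
```

Doubling the posterior variance gives a symmetric divergence of $\tfrac{1}{2}(1/2 + 2 - 2) = 0.25$ for every simulated $y$, so `conditional_diagnostic(2.0)` should be close to 0.25, while `conditional_diagnostic(1.0)` is zero up to rounding.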

Example Results. Before moving on to more complex cases, we give some examples of the use of this diagnostic. Fig. 1 shows the diagnostic on five example models, using variational inference (VI) and using Laplace’s method, which maximizes $\log p(z, x)$ to get $z^*$ and uses a Gaussian centered at $z^*$ whose covariance is the negative inverse of the Hessian $H$ of $\log p(z, x)$ at $z^*$. We also compare to an “adjusted” Laplace’s method that better matches the local shape of $\log p(z, x)$ if $z^*$ is only an approximate maximum. This instead uses a mean of $z^* - H^{-1} g$, where $g$ is the gradient of $\log p(z, x)$ at $z^*$. A fuller description of the models and inference algorithms is given in Sec. 5.

Figure 1: The diagnostic gives plausible measures of the value of different inference algorithms and of optimizing for different numbers of iterations. The diagnostic is computed by averaging over repeated simulations. Lines show the mean, while the colored areas show 95% confidence intervals. Confidence intervals are computed before the log-transform and therefore appear large for the lower bounds seen on concrete. Laplace’s method fails due to numerical problems with few iterations on hospitals. Adjusted Laplace’s method is exact for concrete.

3 Inference with Augmented Variables

Many approximate inference methods use the idea of augmentation. The idea is to create an extra variable $u$ with a distribution $p(u \mid z, x)$ and then approximate $p(z, u \mid x)$ with $q(z, u \mid x)$. Why would this be useful? The basic reason is that many powerful approximating families are obtained by integrating out other random variables. Such families often do not have tractable densities $q(z \mid x)$, but can be represented as the marginal of some density $q(z, u \mid x)$. If we choose $p(u \mid z, x)$ in a way that is “easy” for $q(z, u \mid x)$ to match, then augmented inference might be nearly as accurate as directly approximating $p(z \mid x)$ with $q(z \mid x)$. Agakov_2004_AuxiliaryVariationalMethod introduced the idea of auxiliary variational inference, which fits this form.

To apply the diagnostic to inference with hidden variables, we need another version of Thm. 1. This can be proven by taking Thm. 1 and substituting $(z, u)$ for $z$.

Corollary 3.

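Given $p(z, x)$, $p(u \mid z, x)$, and $q(z, u \mid x)$, define $p(z, u, x) = p(z, x)\, p(u \mid z, x)$ and $q(z, u, x) = p(x)\, q(z, u \mid x)$. Then

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) = \mathbb{E}\left[\log \frac{p(z, x)\, p(u \mid z, x)}{q(z, u \mid x)} - \log \frac{p(z', x)\, p(u' \mid z', x)}{q(z', u' \mid x)}\right],$$

where $z, x \sim p(z, x)$ is sampled from the model distribution, $u \sim p(u \mid z, x)$ is sampled from the augmenting distribution, and $z', u' \sim q(z', u' \mid x)$ is sampled from the approximating distribution.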

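The diagnostic is useful because, by the chain rule of KL-divergence,

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) \geq \mathrm{KL}\big(p(\mathsf{Z}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{X})\big),$$

i.e. the diagnostic yields an upper bound on the error of the marginal $q(z \mid x) = \int q(z, u \mid x)\, du$.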

The corresponding algorithm is given as Alg. 3. This shows an important computational constraint. We assume that $p(z, x)$ is given as a “black box”. Thus, the augmenting distribution $p(u \mid z, x)$ must be chosen so that it is tractable both to simulate $u \sim p(u \mid z, x)$ and to evaluate $p(u \mid z, x)$. The next section will give a special case of this result for a certain class of approximate augmented inference methods.
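A sketch of Alg. 3 on a toy setup (all assumptions, for illustration): the model is $z \sim N(0,1)$, $x \mid z \sim N(z, 1)$, and a hypothetical augmented approximation draws $u \sim N(0,1)$, then $z \mid u, x \sim N(x/2 + u/2, 1/4)$. Its marginal $q(z \mid x) = N(x/2, 1/2)$ happens to be exact, so the true marginal divergence is zero. The example compares two choices of the augmenting $p(u \mid z, x)$ to show how the choice affects the tightness of the bound.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def log_joint(z, x):
    # Toy model (assumed): z ~ N(0, 1), x | z ~ N(z, 1).
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def log_q_joint(z, u, x):
    # Hypothetical augmented approximation: u ~ N(0, 1), z | u, x ~ N(x/2 + u/2, 1/4).
    return log_normal(u, 0.0, 1.0) + log_normal(z, x / 2.0 + u / 2.0, 0.25)


def augmented_diagnostic(log_p_u, sample_p_u, n_sims=20000, seed=0):
    """Alg. 3; log_p_u(u, z, x) and sample_p_u(rng, z, x) define p(u | z, x)."""
    rng = random.Random(seed)
    deltas = []
    for _ in range(n_sims):
        z = rng.gauss(0.0, 1.0)
        x = rng.gauss(z, 1.0)                      # z, x ~ p(z, x)
        u = sample_p_u(rng, z, x)                  # u ~ p(u | z, x)
        u_q = rng.gauss(0.0, 1.0)                  # z', u' ~ q(z', u' | x)
        z_q = rng.gauss(x / 2.0 + u_q / 2.0, 0.5)
        ratio_p = log_joint(z, x) + log_p_u(u, z, x) - log_q_joint(z, u, x)
        ratio_q = log_joint(z_q, x) + log_p_u(u_q, z_q, x) - log_q_joint(z_q, u_q, x)
        deltas.append(ratio_p - ratio_q)
    return statistics.fmean(deltas)


# Choice 1: q's own reverse conditional q(u | z, x) = N(z - x/2, 1/2). The bound is tight.
def log_p_u_matched(u, z, x):
    return log_normal(u, z - x / 2.0, 0.5)


def sample_p_u_matched(rng, z, x):
    return rng.gauss(z - x / 2.0, math.sqrt(0.5))


# Choice 2: ignore (z, x) and use u ~ N(0, 1). Still valid, but the bound is loose.
def log_p_u_prior(u, z, x):
    return log_normal(u, 0.0, 1.0)


def sample_p_u_prior(rng, z, x):
    return rng.gauss(0.0, 1.0)
```

With the matched augmentation every delta is zero up to rounding. With the prior-based augmentation the estimate should land near 1.0 (the closed-form symmetric divergence between the two joint Gaussians over $(z, u)$ here), even though $q(z \mid x)$ itself is exact: the diagnostic is an upper bound, and its tightness depends on $p(u \mid z, x)$.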

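For $n = 1, \dots, N$:

  • Simulate $z, x \sim p(z, x)$.

  • Run inference on $x$ to get $q(z, u \mid x)$. (fix $x$)

  • Simulate $u \sim p(u \mid z, x)$.

  • Simulate $z', u' \sim q(z', u' \mid x)$. (fix $x$)

Use $\frac{1}{N} \sum_{n=1}^{N} \delta_n$, where $\delta_n = \log \frac{p(z, x)\, p(u \mid z, x)}{q(z, u \mid x)} - \log \frac{p(z', x)\, p(u' \mid z', x)}{q(z', u' \mid x)}$ for the $n$-th simulation.

Algorithm 3 Diagnostic with augmentation.

For $n = 1, \dots, N$:

  • Simulate $z_1, x \sim p(z, x)$.

  • Run inference to find a base distribution $q(z \mid x)$.

  • Simulate $z_2, \dots, z_M \sim q(z \mid x)$.

  • Simulate $z'_1, \dots, z'_M \sim q(z \mid x)$. (fix $x$)

Use $\frac{1}{N} \sum_{n=1}^{N} \delta_n$, where $\delta_n = \log \frac{1}{M} \sum_{m=1}^{M} \frac{p(z_m, x)}{q(z_m \mid x)} - \log \frac{1}{M} \sum_{m=1}^{M} \frac{p(z'_m, x)}{q(z'_m \mid x)}$ for the $n$-th simulation.

Algorithm 4 With importance-weighting.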

4 Importance Sampling

Self-normalized importance sampling is a classic Monte-Carlo method [Owen_2013_MonteCarlotheory]. Given any distribution $q(z \mid x)$, one can approximately sample from the posterior $p(z \mid x)$ by drawing a set of $M$ samples $z_1, \dots, z_M \sim q(z \mid x)$, selecting one index $m$ with probability proportional to the importance weights

$$w_m = \frac{p(z_m, x)}{q(z_m \mid x)},$$

and then returning the final sample $z_m$. Call the resulting density $q_M(z \mid x)$.
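The sampling procedure just described can be sketched as follows; `snis_sample` is our name, and the demo uses an assumed toy setup ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$, observed $x = 2$, so $p(z \mid x) = N(1, 1/2)$) with a too-wide proposal $q(z \mid x) = N(1, 1)$. Log-weights are shifted by their maximum before exponentiating, which does not change the normalized selection probabilities.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def snis_sample(rng, log_joint, sample_q, log_q, num_samples):
    """One draw from the self-normalized importance sampling density q_M(z | x).

    log_joint(z) evaluates log p(z, x) at the fixed observed x. Because the weights
    are normalized before selection, p(x) is never needed.
    """
    zs = [sample_q(rng) for _ in range(num_samples)]
    log_ws = [log_joint(z) - log_q(z) for z in zs]
    shift = max(log_ws)  # subtract the largest log-weight for numerical stability
    weights = [math.exp(lw - shift) for lw in log_ws]
    return rng.choices(zs, weights=weights, k=1)[0]


def demo_variance(num_samples, n_draws=4000, seed=0):
    """Variance of repeated SNIS draws in the assumed toy setup (target variance 0.5)."""
    rng = random.Random(seed)

    def log_joint(z):
        return log_normal(z, 0.0, 1.0) + log_normal(2.0, z, 1.0)

    def sample_q(r):
        return r.gauss(1.0, 1.0)

    def log_q(z):
        return log_normal(z, 1.0, 1.0)

    draws = [snis_sample(rng, log_joint, sample_q, log_q, num_samples) for _ in range(n_draws)]
    return statistics.pvariance(draws)
```

With $M = 1$ the draws simply follow the proposal (variance near 1); with $M = 100$ their variance should be close to the posterior variance of 0.5. The density of a single returned draw, $q_M(z \mid x)$, is what has no tractable form.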

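One might hope to directly apply the diagnostic to self-normalized importance sampling. However, this cannot be done, because it is intractable to evaluate $q_M(z \mid x)$. Instead, we can identify augmented distributions, and thereby upper-bound the symmetric divergence. Define the distribution

$$p_M(z_{1:M}, x) = p(z_1, x) \prod_{m=2}^{M} q(z_m \mid x). \qquad (1)$$

This can be seen as an augmented distribution, with the hidden variables $z_{2:M}$ augmenting the original $p(z_1, x)$. Define also

$$q_M(z_{1:M} \mid x) = \frac{p(z_1, x) \prod_{m=2}^{M} q(z_m \mid x)}{\frac{1}{M} \sum_{m=1}^{M} \frac{p(z_m, x)}{q(z_m \mid x)}}. \qquad (2)$$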

It is not immediately obvious that this augments the self-normalized importance sampling density introduced at the beginning of this section (or indeed that this is a valid density at all). However, it was recently shown [Domke_2018_ImportanceWeightingVariational] that the following algorithm samples from $q_M(z_{1:M} \mid x)$.

Claim 4.

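The following process yields a sample from $q_M(z_{1:M} \mid x)$ as defined in Eq. 2.

  1. Draw $z_1, \dots, z_M \sim q(z \mid x)$.

  2. Choose $m \in \{1, \dots, M\}$ with probability proportional to $p(z_m, x) / q(z_m \mid x)$.

  3. Set $z'_1 = z_m$, $z'_m = z_1$, and $z'_j = z_j$ for all other $j$, and return $z'_{1:M}$.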

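Informally, this algorithm draws $M$ samples from $q(z \mid x)$, then swaps one of them, chosen according to the importance weights, into the first position. The first coordinate is therefore distributed exactly as a self-normalized importance sample, so Eq. 2 is a valid augmentation of the self-normalized importance sampling distribution.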
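A sketch of the Claim 4 sampler, together with a Monte Carlo check that Eq. 2 integrates to one: dividing Eq. 2 by $\prod_m q(z_m \mid x)$ leaves $w(z_1) / \frac{1}{M}\sum_m w(z_m)$, whose expectation under i.i.d. draws from $q$ is exactly 1 by symmetry. The toy setup (observed $x = 2$ under $z \sim N(0,1)$, $x \mid z \sim N(z,1)$, base $q(z \mid x) = N(1, 1)$) and all names are assumptions for illustration.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


X_OBS = 2.0  # hypothetical observed value


def log_w(z):
    """Log importance weight log p(z, x) - log q(z | x) in the toy setup."""
    return log_normal(z, 0.0, 1.0) + log_normal(X_OBS, z, 1.0) - log_normal(z, 1.0, 1.0)


def sample_augmented(rng, num_samples):
    """Claim 4: draw z_1..z_M ~ q, pick m with probability proportional to w_m,
    then swap z_m into the first position."""
    zs = [rng.gauss(1.0, 1.0) for _ in range(num_samples)]
    lws = [log_w(z) for z in zs]
    shift = max(lws)
    m = rng.choices(range(num_samples), weights=[math.exp(lw - shift) for lw in lws], k=1)[0]
    zs[0], zs[m] = zs[m], zs[0]
    return zs


def normalization_check(num_samples, n_sims=20000, seed=0):
    """Monte Carlo estimate of the total mass of Eq. 2 (should be close to 1)."""
    rng = random.Random(seed)
    ratios = []
    for _ in range(n_sims):
        ws = [math.exp(log_w(rng.gauss(1.0, 1.0))) for _ in range(num_samples)]
        ratios.append(ws[0] / statistics.fmean(ws))
    return statistics.fmean(ratios)
```

Only the first coordinate of `sample_augmented` is the self-normalized importance sample; the remaining coordinates are the "unused" draws that play the role of the augmenting variables.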

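This shows that Eq. 1 and Eq. 2 augment the target $p(z_1, x)$ and the self-normalized importance-sampling density $q_M(z_1 \mid x)$, respectively, and thus that the symmetric divergence between $p_M(z_{1:M}, x)$ and $q_M(z_{1:M}, x) = p(x)\, q_M(z_{1:M} \mid x)$ upper-bounds the symmetric divergence between $p(z, x)$ and $q_M(z, x) = p(x)\, q_M(z \mid x)$. It can be shown that the (non-symmetric) divergence from $q_M$ to $p_M$ asymptotically decreases at a rate of $O(1/M)$ [Maddison_2017_FilteringVariationalObjectives, Domke_2018_ImportanceWeightingVariational].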

Figure 2: The diagnostic applied to self-normalized importance weighting with $M$ samples. The proposal distribution is estimated by Laplace’s method with adjustment. The $M = 1$ curve is the same as the “laplace (adjusted)” curve in Fig. 1.

4.1 Diagnosing Importance Weighted Inference

The above discussion shows that adding importance sampling can improve any distribution $q(z \mid x)$. Even better results can be obtained by explicitly optimizing $q$ to work well after augmentation. Importance-weighted variational inference directly maximizes the ELBO from $q_M$ to $p_M$, which is equivalent to minimizing the KL-divergence between the augmented distributions. This ELBO can be simplified using the relationship

$$\mathbb{E}_{q_M(z_{1:M} \mid x)} \log \frac{p_M(z_{1:M}, x)}{q_M(z_{1:M} \mid x)} = \mathbb{E}_{z_1, \dots, z_M \sim q(z \mid x)} \log \frac{1}{M} \sum_{m=1}^{M} \frac{p(z_m, x)}{q(z_m \mid x)}, \qquad (3)$$

which follows from cancellations between $p_M$ and $q_M$, followed by the observation that the argument of the expectation is invariant to permutations of $z_{1:M}$. This objective was originally introduced in the context of importance-weighted auto-encoders [Burda_2015_ImportanceWeightedAutoencoders, Cremer_2017_ReinterpretingImportanceWeightedAutoencoders, Domke_2018_ImportanceWeightingVariational, Naesseth_2018_VariationalSequentialMonte, Le_2018_AutoEncodingSequentialMonte] (without explicitly identifying $p_M$ and $q_M$) and subsequently studied by various others [Cremer_2017_ReinterpretingImportanceWeightedAutoencoders, Domke_2018_ImportanceWeightingVariational, Bachman_2015_TrainingDeepGenerative, Naesseth_2018_VariationalSequentialMonte, Le_2018_AutoEncodingSequentialMonte].
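The right-hand side of Eq. 3 can be estimated by plain Monte Carlo. The sketch below does this for an assumed toy setup ($z \sim N(0,1)$, $x \mid z \sim N(z, 1)$, observed $x = 2$, hypothetical base $q(z \mid x) = N(1, 1)$), where the exact $\log p(x) = \log N(2; 0, 2)$ is available for comparison. A log-mean-exp with a max shift keeps the computation stable; all names are ours.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log((1/M) * sum(exp(v)))."""
    shift = max(values)
    return shift + math.log(statistics.fmean(math.exp(v - shift) for v in values))


def iw_elbo(num_samples, n_reps=5000, seed=0, x=2.0, q_mean=1.0, q_var=1.0):
    """Monte Carlo estimate of Eq. 3 in the assumed toy setup."""
    rng = random.Random(seed)
    estimates = []
    for _ in range(n_reps):
        zs = [rng.gauss(q_mean, math.sqrt(q_var)) for _ in range(num_samples)]
        log_ws = [log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, q_mean, q_var)
                  for z in zs]
        estimates.append(log_mean_exp(log_ws))
    return statistics.fmean(estimates)


def log_evidence(x=2.0):
    """Exact log p(x) for the toy model, where marginally x ~ N(0, 2)."""
    return log_normal(x, 0.0, 2.0)
```

With $M = 1$ this is the standard ELBO, about $0.153$ below $\log p(x)$ here (the KL-divergence from $N(1,1)$ to $N(1, 1/2)$); with $M = 10$ the estimate moves much closer to $\log p(x)$ while remaining below it in expectation.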

The following result specializes Cor. 3 to the case of importance-weighted inference.

Corollary 5.

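Given $p(z, x)$ and $q(z \mid x)$, define $p_M$ and $q_M$ as in Eq. 1 and Eq. 2, with $q_M(z_{1:M}, x) = p(x)\, q_M(z_{1:M} \mid x)$. Further, set $R_M(z_{1:M}, x) = \frac{1}{M} \sum_{m=1}^{M} \frac{p(z_m, x)}{q(z_m \mid x)}$. Then

$$\mathrm{KL}\big(p_M(\mathsf{Z}_{1:M}, \mathsf{X}) \,\|\, q_M(\mathsf{Z}_{1:M}, \mathsf{X})\big) + \mathrm{KL}\big(q_M(\mathsf{Z}_{1:M}, \mathsf{X}) \,\|\, p_M(\mathsf{Z}_{1:M}, \mathsf{X})\big) = \mathbb{E}\left[\log R_M(z_{1:M}, x) - \log R_M(z'_{1:M}, x)\right],$$

where $z_1, x \sim p(z, x)$ is sampled from the model distribution, $z_{2:M} \sim q(z \mid x)$ is sampled from the approximating distribution, and $z'_{1:M} \sim q(z \mid x)$ are also sampled (independently) from the approximating distribution.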

A proof is in the supplement. Unlike Cor. 2 and Cor. 3, this result is not trivial. The main idea is to substitute $z_1$ for $z$ and $z_{2:M}$ for $u$. Then, many expressions can be simplified based on the particular forms of $p_M$ and $q_M$. Finally, we can observe that the argument of the expectation is invariant to permutations of $z'_{1:M}$. This allows a final simplification.
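A sketch of Alg. 4 in an assumed toy setup ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$, hypothetical base approximation $q(z \mid x) = N(x/2, \texttt{q\_var})$). Following the corollary, the "model side" keeps the simulated $z_1$ and pads it with $M - 1$ draws from $q$, while the "approximation side" uses $M$ fresh i.i.d. draws from $q$; no sample from $q_M$ itself is needed. Names are ours.

```python
import math
import random
import statistics


def log_normal(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def log_mean_exp(values):
    shift = max(values)
    return shift + math.log(statistics.fmean(math.exp(v - shift) for v in values))


def iw_diagnostic(num_samples, q_var=1.0, n_sims=5000, seed=0):
    """Alg. 4 in the assumed toy setup; returns the mean of the deltas."""
    rng = random.Random(seed)
    sd = math.sqrt(q_var)

    def log_r(zs, x):
        # log R_M = log (1/M) sum_m p(z_m, x) / q(z_m | x)
        return log_mean_exp([
            log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, x / 2.0, q_var)
            for z in zs
        ])

    deltas = []
    for _ in range(n_sims):
        z1 = rng.gauss(0.0, 1.0)
        x = rng.gauss(z1, 1.0)                                                # z_1, x ~ p(z, x)
        zs = [z1] + [rng.gauss(x / 2.0, sd) for _ in range(num_samples - 1)]  # z_{2:M} ~ q
        zs_prime = [rng.gauss(x / 2.0, sd) for _ in range(num_samples)]       # z'_{1:M} ~ q
        deltas.append(log_r(zs, x) - log_r(zs_prime, x))
    return statistics.fmean(deltas)
```

For $M = 1$ this reduces to Alg. 1 (the estimate should be near 0.25 for `q_var = 1.0`), and increasing $M$ should drive the estimated bound down, mirroring the trend reported for importance weighting.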

5 Experiments

Figure 3: The diagnostic as computed using Alg. 4, where the parameters of $q$ are determined by optimizing Eq. 3 with the same value of $M$. The $M = 1$ curve is the same as the “vi” curve in Fig. 1.

Models. We use five models, described in detail in Sec. 7 (supplement). glm_binomial is a hierarchical model of the number of pairs in a bird population over time. heart_transplants models the survival times of patients after surgery. hospitals measures the number of deaths in different hospitals. ionosphere is a Bayesian logistic regression model on a classic dataset. concrete is a Bayesian linear regression model, included as a baseline because the exact posterior is Gaussian.

Optimization. The first inference algorithm we consider is Laplace’s method, which produces a multivariate Gaussian approximation $q(z \mid x)$. For this method, we run Adam to maximize $\log p(z, x)$ with a step-size of 0.01 for the first half of iterations and 0.001 for the second half. The resulting point $z^*$ is the mean. Finally, the Hessian $H$ of $\log p(z, x)$ at $z^*$ is used to estimate the covariance of $q$ as $-H^{-1}$. This method fails to match the local shape of $\log p(z, x)$ when $z^*$ is far from an optimum, since $\log q$ then has zero gradient at $z^*$ while $\log p(z, x)$ does not. We also consider an “adjusted” Laplace’s method that instead uses a mean of $z^* - H^{-1} g$, where $g$ is the gradient of $\log p(z, x)$ at $z^*$. This guarantees that $\log q$ has the same gradient at $z^*$ as $\log p(z, x)$.
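A one-dimensional sketch of the two Laplace variants on a hypothetical Poisson model with closed-form derivatives (not one of the paper's five models): $z \sim N(0,1)$, $x \mid z \sim \mathrm{Poisson}(e^z)$. Instead of Adam, any $z^*$ can be passed in; `find_mode` (Newton's method) stands in for a fully converged optimizer. The adjusted mean $z^* - H^{-1} g$ is exactly one Newton step from $z^*$.

```python
import math


# Hypothetical 1-D model (an assumption): z ~ N(0, 1), x | z ~ Poisson(exp(z)).
# Up to an additive constant, log p(z, x) = -z^2 / 2 + x * z - exp(z).

def grad_log_p(z, x):
    return -z + x - math.exp(z)


def hess_log_p(z, x):
    return -1.0 - math.exp(z)  # always negative, so log p is concave


def laplace(z_star, x, adjusted=False):
    """Gaussian (mean, var) from the curvature of log p(z, x) at z_star.

    Plain Laplace centers q at z_star; the adjusted version uses z_star - g / H.
    """
    h = hess_log_p(z_star, x)
    var = -1.0 / h
    mean = z_star - grad_log_p(z_star, x) / h if adjusted else z_star
    return mean, var


def grad_log_q(z, mean, var):
    """Gradient of log N(z; mean, var) with respect to z."""
    return -(z - mean) / var


def find_mode(x, z=0.0, iters=50):
    """Newton's method on log p(z, x), standing in for a converged optimizer."""
    for _ in range(iters):
        z -= grad_log_p(z, x) / hess_log_p(z, x)
    return z
```

For example, with $x = 3$ and an unconverged $z^* = 0$, the gradient of $\log p$ is 2; the plain Laplace Gaussian has gradient 0 there, while the adjusted one (mean 1.0, variance 0.5) reproduces the gradient of 2. At the true mode the two variants coincide.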

We also consider variational inference. We initialize to a standard Gaussian and optimize with Adam with a step size of 0.001 for the first half of iterations, and 0.0001 for the second half. We estimate the gradient of the ELBO using the reparameterization trick, using the “sticking the landing” estimator [Roeder_2017_StickingLandingSimple].

Constraints and Transformations. Often, random variables have constraints – they are not supported over the reals. As is common [Kucukelbir_2017_AutomaticDifferentiationVariational] we deal with this through a process of transformations. For our models, it is sufficient to consider two cases:

  • Random variables defined over the non-negative reals $[0, \infty)$. In this case, we replace the variable with a new random variable that is unconstrained.

  • Random variables defined on a closed interval $[a, b]$. Here, we transform the variable to a new random variable that is unconstrained.
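The exact maps used in the experiments are not spelled out here, so the following sketch assumes two common choices (as in automatic differentiation variational inference): $v = e^{\zeta}$ for positive variables and a scaled logistic $v = a + (b - a)\,\sigma(\zeta)$ for intervals. Inference then runs on the unconstrained $\zeta$, whose density picks up the log-Jacobian from the change-of-variables formula. All names are ours.

```python
import math


def positive_from_unconstrained(zeta):
    """v = exp(zeta) maps R to (0, inf); returns (v, log |dv / dzeta|)."""
    return math.exp(zeta), zeta


def interval_from_unconstrained(zeta, a, b):
    """v = a + (b - a) * sigmoid(zeta) maps R to (a, b); returns (v, log |dv / dzeta|)."""
    s = 1.0 / (1.0 + math.exp(-zeta))
    return a + (b - a) * s, math.log(b - a) + math.log(s) + math.log1p(-s)


def log_density_unconstrained(log_density_constrained, transform, zeta, *args):
    """Log density of zeta implied by a density on v, via the change-of-variables formula."""
    v, log_jac = transform(zeta, *args)
    return log_density_constrained(v) + log_jac
```

For instance, an Exponential(1) density on $v > 0$ becomes $\exp(\zeta - e^{\zeta})$ on the real line, which still integrates to one.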

Results. Results comparing VI and the two variants of Laplace’s method are shown in Fig. 1, averaging over simulated datasets. Laplace’s method is reasonably accurate in many cases, but usually has a “floor” of accuracy it does not exceed. The adjustment to Laplace’s method is often helpful and never harmful. VI performs better with many iterations. For these models, the diagnostic shows that inference error is reasonably low with many iterations, but not quite “exact”.

Fig. 2 shows the results of importance sampling with a proposal computed using Laplace’s method with adjustment. Using more samples yields a clear improvement.

Finally, results with full importance-weighted variational inference (optimizing Eq. 3 rather than the standard ELBO) are shown in Fig. 3. The same value of $M$ is used during optimization and at test time.

6 Discussion

This paper proposed a new diagnostic for approximate inference methods. This is a simulation-based diagnostic, meaning it is computed by repeatedly simulating latent variables along with datasets and running inference on each dataset. The central idea is that cancellations in unknown constants make it possible to estimate a symmetric divergence. This is notable in being a simple, scalar quantity with a clear information-theoretic interpretation. It can also be computed in a fully automated way along with error measures like confidence intervals. We showed that the diagnostic can be extended to augmented inference methods, in particular importance-weighted inference. Empirically, the method gives reasonable diagnostic information on several test models.

While Thm. 1 is quite simple, there are numerous points worth clarifying in its use as a diagnostic:

What the diagnostic measures. One possibly counter-intuitive aspect of this diagnostic (like all simulation-based diagnostics) is that it does not use the actual observed data $x$. Rather, it measures the typical error, averaged over $x$ simulated from the model $p(z, x)$. It is thus essential that the prior and likelihood be selected so that $p(x)$ yields realistic simulated datasets. In particular, very broad priors might lead to “nonsense” observations that are unrepresentative of the data that would be seen in practice.

Computational considerations. In order to compute this diagnostic, one must be able to perform several operations: (1) simulate $z, x \sim p(z, x)$ and $z' \sim q(z' \mid x)$; (2) compute $p(z, x)$ for a given $z$ and $x$; (3) compute $q(z \mid x)$ for a given $z$ and $x$. Crucially, it is not necessary to be able to evaluate $p(x)$. This is true despite the fact that $p(x)$ is part of the definition of $q(z, x)$.

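Other representations of the diagnostic. The quantity that the diagnostic estimates can be written in a different form that emphasizes that it measures errors over $z$. This uses the notion of a conditional divergence $\mathrm{KL}\big(p(\mathsf{Z} \mid \mathsf{X}) \,\|\, q(\mathsf{Z} \mid \mathsf{X})\big)$. It is not hard to show that, in the setting of Thm. 1,

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{X})\big) = \mathrm{KL}\big(p(\mathsf{Z} \mid \mathsf{X}) \,\|\, q(\mathsf{Z} \mid \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z} \mid \mathsf{X}) \,\|\, p(\mathsf{Z} \mid \mathsf{X})\big),$$

where in both conditional divergences the expectation is over $\mathsf{X} \sim p(x)$. This is true because of cancellations between $p(z, x)$ and $q(z, x)$ due to the inclusion of the $p(x)$ term in $q(z, x)$.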
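The cancellation can be checked term by term, writing $p(z, x) = p(x)\, p(z \mid x)$ and $q(z, x) = p(x)\, q(z \mid x)$:

```latex
\begin{aligned}
\mathrm{KL}\big(p(\mathsf{Z},\mathsf{X}) \,\|\, q(\mathsf{Z},\mathsf{X})\big)
  &= \mathbb{E}_{p(x)\, p(z \mid x)}\left[\log \frac{p(x)\, p(z \mid x)}{p(x)\, q(z \mid x)}\right]
   = \mathbb{E}_{p(x)}\, \mathrm{KL}\big(p(\mathsf{Z} \mid x) \,\|\, q(\mathsf{Z} \mid x)\big)
   = \mathrm{KL}\big(p(\mathsf{Z} \mid \mathsf{X}) \,\|\, q(\mathsf{Z} \mid \mathsf{X})\big), \\
\mathrm{KL}\big(q(\mathsf{Z},\mathsf{X}) \,\|\, p(\mathsf{Z},\mathsf{X})\big)
  &= \mathbb{E}_{p(x)\, q(z \mid x)}\left[\log \frac{p(x)\, q(z \mid x)}{p(x)\, p(z \mid x)}\right]
   = \mathbb{E}_{p(x)}\, \mathrm{KL}\big(q(\mathsf{Z} \mid x) \,\|\, p(\mathsf{Z} \mid x)\big)
   = \mathrm{KL}\big(q(\mathsf{Z} \mid \mathsf{X}) \,\|\, p(\mathsf{Z} \mid \mathsf{X})\big).
\end{aligned}
```

Both joint distributions share the marginal $p(x)$, so each joint divergence collapses to an average over simulated datasets of a per-dataset divergence over $z$.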

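Use with randomized inference methods. In practice, approximate inference algorithms are often non-deterministic. This is not reflected by the notation $q(z \mid x)$. With non-determinism, no modification to the diagnostic technique in Alg. 1 is needed; only the interpretation is slightly different. To formalize what the diagnostic measures in this case, define $q(z \mid x, \epsilon)$ to be the approximate posterior produced when $\epsilon$ are the random numbers underlying the algorithm. Then, the diagnostic can be defined as the expected divergence between $p(z, x)$ and $q(z, x \mid \epsilon)$, i.e., $\mathbb{E}_{\epsilon}\big[\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{X} \mid \epsilon)\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{X} \mid \epsilon) \,\|\, p(\mathsf{Z}, \mathsf{X})\big)\big]$, where $q(z, x \mid \epsilon) = p(x)\, q(z \mid x, \epsilon)$. This is still a very reasonable measure of the accuracy of inference. For simplicity, our presentation mostly neglects the issue of non-determinism in approximating distributions.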

6.1 Related Work

There are several lines of related work not mentioned so far. One recent line of work explores inference diagnostics based on Stein’s method [Gorham_2017_MeasuringSampleQuality, Gorham_2015_MeasuringSampleQuality]. The idea is to create sets of functions whose true expectations must be zero. Any deviation from zero in those functions indicates inference failure. This diagnostic can be used with methods that are asymptotically approximate. However, it is intended for cases when error decreases to zero. There is no claim that the magnitude of the diagnostic is a good measure of the usefulness of an approximate posterior.

The test proposed by Geweke_2004_GettingItRight is an interesting early diagnostic that repeatedly simulates datasets in a non-independent manner. The idea is to iteratively sample $x \sim p(x \mid z)$, then run inference to produce $q(z \mid x)$, and then sample $z \sim q(z \mid x)$. Then, one compares the expectation of some function $f(z, x)$ on these samples to its expectation on exact samples $z, x \sim p(z, x)$. If $q$ is exact, these expectations should match. We prefer an approach where each simulation is independent, since this is easier to parallelize, avoids correlations between simulations, and makes it easier to compute error measures like confidence intervals.

Bidirectional Monte Carlo [Grosse_2015_Sandwichingmarginallikelihood] runs MCMC on repeatedly simulated datasets to get upper and lower bounds on the marginal likelihood $p(x)$. This is intended as a technique to evaluate the quality of a model, not as a diagnostic for inference. Still, in principle one could use these bounds to transform an ELBO into bounds on the KL-divergence. One drawback is the expense of repeatedly running MCMC. Typically, variational inference is used in settings where MCMC would be too expensive.

6.2 Limitations and Future Work

This work has several limitations shared with all simulation-based diagnostics. First, computing them requires repeating inference numerous times, which comes with an associated cost. Second, these methods can be overly pessimistic when used with extremely broad or uninformative priors. It is important that the model is chosen so that simulated data are representative of the datasets one cares about. Third, the diagnostic measures average accuracy over data simulated from the prior, as opposed to the expected accuracy for a particular dataset. (Put another way, the diagnostic is arguably frequentist rather than Bayesian.)

One might be concerned about the success of this diagnostic when used with variational inference methods. Namely, VI typically minimizes $\mathrm{KL}(q \,\|\, p)$, while the diagnostic is based on the symmetric divergence. Informally, VI cares about finding a distribution that is close in a “mode finding” divergence, while the diagnostic measures both “mode finding” and “mode spanning”. It is possible that a distribution could be close in VI’s objective, yet yield a high diagnostic value. This is arguably a flaw not of the diagnostic, but of variational inference. One interesting future direction would be to investigate recent VI variants that try to minimize other divergences [Li__RenyiDivergenceVariational, Dieng_2017_VariationalInferencechi].

In future work, it would be interesting to address MCMC methods. Of course, most MCMC methods are not suitable for this framework. However, some methods like annealed importance sampling [Neal_1998_AnnealedImportanceSampling] formally create augmented target and proposal densities at a variety of “temperatures”. It may be possible to use the diagnostic proposed here to measure the symmetric divergence between these augmented distributions. This could potentially offer a diagnostic for MCMC with the unusual property that the diagnostic going to zero is both necessary and sufficient to guarantee convergence to the stationary distribution.

References

7 Models

GLM Binomial. This is a model of the number of peregrine pairs $y_t$ in the French Jura in year $t$. The years in the data are rescaled before modeling. The model and data are from Kery_2012_Bayesianpopulationanalysis.

Heart Transplants. This is a model of a hypothetical population of patients who underwent surgery, only some of whom survived [Lunn_2013_BUGSbookpractical, Ex. 3.5.1]. The survivors were tracked to record the number of years $t_i$ that the $i$-th patient who survived surgery lived post-surgery. This is assumed to be determined by an exponential distribution with parameter $\theta$. Thus, the model is:

Note that, when generating synthetic datasets for this model, we always use the same set of survival-time variables $t_i$, independent of the simulated number of survivors. This is done because of the difficulties posed by having different dimensionality in different realizations of the posterior. While not fully in keeping with the spirit of the original model, this still defines a perfectly valid probabilistic model and test of the diagnostic.

Hospitals. This is a hierarchical model of the mortality rate of English hospitals performing heart surgery [Lunn_2013_BUGSbookpractical, Ex. 10.1.1]. The data is $\{(n_i, r_i)\}$, where $n_i$ is the number of operations in hospital $i$ and $r_i$ is the corresponding number of deaths. The logit of the true mortality rate $\rho_i$ of hospital $i$ is Gaussian with unknown mean $\mu$ and standard deviation $\sigma$. The latent variables are $\rho$, $\mu$, and $\sigma$.

The original model has a very wide prior on $\mu$ and $\sigma$, which leads to the problems discussed in Sec. 6. We use the above model with more modest priors.

Ionosphere. This is a classic dataset for binary classification. We model it as a Bayesian logistic regression problem with a standard Gaussian prior over the weights $w$.

Concrete. This is a well-known dataset for linear regression. We model it as a Bayesian linear regression problem with a standard Gaussian prior over the weights $w$. This model is particularly notable because the true posterior is exactly Gaussian. Since both Laplace’s method and variational inference can exactly represent such a posterior, this provides an important test of whether the diagnostic can correctly recognize inference success when it occurs.

8 Theory

See Thm. 1

Proof.

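The divergence $\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{X})\big)$ is equal to

$$\begin{aligned}
& \mathbb{E}_{p(z, x)} \log \frac{p(z, x)}{p(x)\, q(z \mid x)} + \mathbb{E}_{p(x)\, q(z' \mid x)} \log \frac{p(x)\, q(z' \mid x)}{p(z', x)} \\
&= \mathbb{E}_{p(z, x)} \log \frac{p(z, x)}{q(z \mid x)} + \mathbb{E}_{p(x)\, q(z' \mid x)} \log \frac{q(z' \mid x)}{p(z', x)}.
\end{aligned}$$

In the first line we use the fact that $q(z, x) = p(x)\, q(z \mid x)$, while in the second line we pull out a factor of $\log p(x)$ from each term; since both expectations have the same marginal $p(x)$ over $x$, these cancel. The claimed result is the same as the last line with a sign change in the second term. ∎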

See Cor. 5

Proof.

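Start with the result of Cor. 3:

$$\mathrm{KL}\big(p(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, q(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) + \mathrm{KL}\big(q(\mathsf{Z}, \mathsf{U}, \mathsf{X}) \,\|\, p(\mathsf{Z}, \mathsf{U}, \mathsf{X})\big) = \mathbb{E}\left[\log \frac{p(z, x)\, p(u \mid z, x)}{q(z, u \mid x)} - \log \frac{p(z', x)\, p(u' \mid z', x)}{q(z', u' \mid x)}\right].$$

Now, make the following transformations:

$$z \to z_1, \qquad u \to z_{2:M}, \qquad p(u \mid z, x) \to \prod_{m=2}^{M} q(z_m \mid x), \qquad q(z, u \mid x) \to q_M(z_{1:M} \mid x).$$

Then, we get

$$\mathbb{E}\left[\log \frac{p_M(z_{1:M}, x)}{q_M(z_{1:M} \mid x)} - \log \frac{p_M(z'_{1:M}, x)}{q_M(z'_{1:M} \mid x)}\right], \qquad z_1, x \sim p(z, x), \quad z_{2:M} \sim q(z \mid x), \quad z'_{1:M} \sim q_M(z'_{1:M} \mid x).$$

Now, note that, by Eq. 1 and Eq. 2,

$$\frac{p_M(z_{1:M}, x)}{q_M(z_{1:M} \mid x)} = \frac{1}{M} \sum_{m=1}^{M} \frac{p(z_m, x)}{q(z_m \mid x)} = R_M(z_{1:M}, x).$$

This leaves us with the result of

$$\mathbb{E}\left[\log R_M(z_{1:M}, x) - \log R_M(z'_{1:M}, x)\right].$$

Now, finally, note that the expectation is unchanged under permutations of the order of $z'_{1:M}$. Thus, the expectation is unchanged if we replace the distribution $q_M(z'_{1:M} \mid x)$ with its average over all permutations of its arguments, which is $\prod_{m=1}^{M} q(z'_m \mid x)$.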