On Statistical Optimality of Variational Bayes

12/25/2017 ∙ by Debdeep Pati, et al. ∙ Florida State University, Texas A&M University

The article addresses a long-standing open problem on the justification of using variational Bayes methods for parameter estimation. We provide general conditions for obtaining optimal risk bounds for point estimates acquired from mean-field variational Bayesian inference. The conditions pertain to the existence of certain test functions for the distance metric on the parameter space and minimal assumptions on the prior. A general recipe for verification of the conditions is outlined which is broadly applicable to existing Bayesian models with or without latent variables. As illustrations, specific applications to Latent Dirichlet Allocation and Gaussian mixture models are discussed.




1 Introduction

Variational inference [25, 7, 40] is now a well-established tool for approximating intractable posterior distributions in hierarchical multi-layered Bayesian models. The traditional Markov chain Monte Carlo (MCMC; [17]) approach to approximating distributions with intractable normalizing constants draws (correlated) samples from a discrete-time Markov chain whose stationary distribution is the target distribution. Despite their success and popularity, MCMC methods can be slow to converge and can lack scalability in big data problems and/or problems involving very many latent variables, which has fueled the search for alternatives.

In contrast to the sampling approach of MCMC, variational inference approaches the problem from an optimization viewpoint. First, a class of analytically tractable distributions, referred to as the variational family, is identified for the problem at hand. For example, in mean-field approximation, the set of parameters and latent variables is divided into blocks and the variational distribution is assumed to be independent across blocks. The distribution in the variational family closest to the target distribution relative to the Kullback–Leibler (KL) divergence is then used as a proxy to the target. Implementation-wise, the above optimization is commonly solved using coordinate descent or alternating minimization. A comprehensive review of various aspects of variational inference can be found in the recent article by [9].
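To make the coordinate-wise optimization concrete, here is a minimal sketch on a toy problem that does not come from this article. The target is a bivariate Gaussian with mean $m$ and precision matrix $\Lambda$, and it is approximated by a mean-field product $q_1(x_1)\,q_2(x_2)$. For this target each optimal factor is Gaussian with variance $1/\Lambda_{ii}$. The mean of $q_1$ given $q_2$ is $m_1 - \Lambda_{11}^{-1}\Lambda_{12}(a_2 - m_2)$, and the update for $q_2$ is symmetric, so coordinate ascent reduces to the loop below. The function name `cavi_bivariate_gaussian` is hypothetical.

```python
def cavi_bivariate_gaussian(m, Lam, a_init=(0.0, 0.0), n_iter=100):
    """Coordinate ascent for the mean-field approximation q1(x1) q2(x2) to a
    bivariate Gaussian target with mean m and 2x2 precision matrix Lam.

    Returns the means (a1, a2) and the variances of the two Gaussian factors.
    """
    a1, a2 = a_init
    for _ in range(n_iter):
        # Optimal q1 given q2 is N(a1, 1 / Lam[0][0]) with this mean.
        a1 = m[0] - Lam[0][1] / Lam[0][0] * (a2 - m[1])
        # Optimal q2 given the freshly updated q1.
        a2 = m[1] - Lam[1][0] / Lam[1][1] * (a1 - m[0])
    # The factor variances do not depend on the current means.
    variances = (1.0 / Lam[0][0], 1.0 / Lam[1][1])
    return (a1, a2), variances
```

Take a target with mean $(1, -2)$, unit variances and correlation $0.9$. Its precision matrix is $(1 - 0.81)^{-1}$ times the matrix with rows $(1, -0.9)$ and $(-0.9, 1)$. Each sweep shrinks the error in the means by a factor of $0.81$, so 100 sweeps recover the mean to high accuracy. Both returned variances equal $0.19$.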

Variational inference has arguably found its most potent applications in latent variable models such as mixture models, hidden Markov models, graphical models, topic models, and neural networks; see [29, 5, 40, 24, 15, 38, 8, 10, 21] for a flavor of this enormous literature. Due to the fast convergence properties of the variational objective, variational inference algorithms are typically orders of magnitude faster in big data problems compared to MCMC approaches [27, 26, 1]. However, in spite of their tremendous empirical success, a general statistical theory characterizing the statistical properties of a variational solution is lacking. Existing results operate in a case-by-case manner, often directly analyzing the specific iterative algorithm to characterize properties of its limit; see Section 5.2 of [9] for a comprehensive review. These analyses typically require sufficient tractability of the successive steps of the iterative algorithm, and can be difficult to adapt to minor changes in the prior. More recently, [3, 43, 2] modified the objective function to introduce an inverse-temperature parameter, and obtained general guarantees for the variational solution under this modified objective function.

In this article, we aim to address the general question of whether point estimates obtained from usual variational approximations share the same statistical accuracy as those from the actual posterior. We clarify at the outset that we operate in a frequentist setting, assuming the existence of a true data generating parameter. Our novel contribution is to relate the Bayes risk relative to the variational solution for a general distance metric (defined on the parameter space) to (i) the existence of certain test functions for testing the true parameter against complements of its neighborhood under the error metric, and (ii) the size of the variational objective function. As an important consequence of such Bayes risk bounds, risk bounds for variational point estimates can be readily derived when the distance metric is convex. If the risk of the variational point estimate coincides with the contraction rate of the exact posterior at the true parameter, it can be argued that there is no loss of statistical efficiency, at least asymptotically, in using the variational approximation.

We identify a number of popularly used models where the conditions can be satisfied and the variational point estimate attains the minimax rate. Since variational Bayes is primarily used for point estimation, our theory suggests that variational Bayes successfully achieves its desiderata. As vignettes, we present two non-trivial examples in the form of density estimation in Latent Dirichlet Allocation (LDA) for topic modeling, and estimating component specific parameters in Gaussian mixture models.

2 Background

In this section, we introduce notation and offer some background to set up our theoretical results. Let $h(p,q) = \{\int (p^{1/2} - q^{1/2})^2\, d\lambda\}^{1/2}$ and $D(p\,\|\,q) = \int p \log(p/q)\, d\lambda$ denote the Hellinger distance and Kullback–Leibler divergence, respectively, between two probability density functions $p$ and $q$ relative to a common dominating measure $\lambda$. Also denote $V(p\,\|\,q) = \int p \log^2(p/q)\, d\lambda$. For a set $A$, we use $\mathbb{1}_A$ to denote its indicator function. For any vector $a$ and positive semidefinite matrix $\Sigma$, we use $N(a, \Sigma)$ to denote the normal distribution with mean $a$ and covariance matrix $\Sigma$, and use $N(x; a, \Sigma)$ to denote its pdf at $x$. We use w.h.p. to abbreviate “with high probability”, when the probability is evident from the context. Throughout, $C$ denotes a constant independent of everything else whose value may change from one line to the other. We write $a \lesssim b$ to denote $a \le Cb$ for some constant $C > 0$. Similarly, $a \gtrsim b$ denotes $a \ge Cb$.
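For distributions on a finite set, with counting measure as the dominating measure, the quantities just defined reduce to finite sums. The sketch below computes the Hellinger distance, the KL divergence, and the second moment $\sum_x p(x)\log^2\{p(x)/q(x)\}$ of the log-ratio. The function names are ours, and $q$ is assumed to be positive wherever $p$ is.

```python
import math


def hellinger(p, q):
    """Hellinger distance {sum_x (sqrt(p(x)) - sqrt(q(x)))^2}^{1/2}."""
    return math.sqrt(sum((math.sqrt(pi) - math.sqrt(qi)) ** 2 for pi, qi in zip(p, q)))


def kl(p, q):
    """KL divergence sum_x p(x) log(p(x) / q(x)); terms with p(x) = 0 contribute zero."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_second_moment(p, q):
    """Second moment of the log-ratio: sum_x p(x) log^2(p(x) / q(x))."""
    return sum(pi * math.log(pi / qi) ** 2 for pi, qi in zip(p, q) if pi > 0)
```

By Jensen's inequality, the second moment is always at least the squared KL divergence.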

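Suppose we have $n$ observations $X^{(n)} = (X_1, \ldots, X_n)$ and a probabilistic model $p(X^{(n)} \mid \theta)$ for the joint distribution of the data $X^{(n)}$, with a density relative to the Lebesgue measure. Here, $\theta$ is the unknown parameter to be estimated from the data, which lives in some parameter space $\Theta$. Our formulation does not require the $X_i$ to be identically distributed or even independent. In a Bayesian setup, uncertainty regarding the parameter $\theta$ is quantified through a prior distribution with density $p_\theta$ on $\Theta$. Upon observing the data, the prior is updated to the posterior distribution using Bayes’ theorem:

$$\Pi(\theta \in B \mid X^{(n)}) = \frac{\int_B p(X^{(n)} \mid \theta)\, p_\theta(\theta)\, d\theta}{\int_\Theta p(X^{(n)} \mid \theta)\, p_\theta(\theta)\, d\theta} \tag{1}$$

for any measurable subset $B \subseteq \Theta$.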

In a wide variety of practical problems, the likelihood function $p(X^{(n)} \mid \theta)$ may be intractable or difficult to analyze directly. For example, in a 2-component Gaussian mixture model, $p(X^{(n)} \mid \theta)$ is a combinatorial sum over $2^n$ terms. A widely used trick in such situations is to introduce latent variables to simplify the conditional likelihood. Specifically, assume one can decompose

$$p(X^{(n)} \mid \theta) = \sum_{s^n} p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w), \tag{2}$$

where $S^n = (S_1, \ldots, S_n)$ denotes a collection of discrete latent variables, with $S_i$ the latent variable for the $i$th observation. We have assumed the parameter $\theta$ can be decomposed as $\theta = (\mu, w)$, with $\mu \in \Theta_\mu$ and $w \in \Theta_w$, and $p(s^n \mid w)$ denotes the probability of the latent vector $S^n$ taking on the value $s^n$. In the 2-component mixture model example, $S_i \in \{1, 2\}$ denotes the latent membership indicator for the $i$th observation. We assume discrete latent variables for notational convenience and note that our results generalize to continuous latent variables in a straightforward fashion.
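The following brute-force check is a sketch for a two-component mixture of unit-variance Gaussians with i.i.d. labels $P(S_i = k) = w_k$. The setup is illustrative and the helper names are our own. The likelihood, written as a sum over latent labels, has $2^n$ terms. When the labels are i.i.d., it still equals the usual product over observations of per-observation mixture densities.

```python
import itertools
import math


def normal_pdf(x, mean):
    """Density of N(mean, 1) at x."""
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2.0 * math.pi)


def likelihood_by_latent_sum(x, mu, w):
    """Sum over all 2^n label vectors s of p(x | s, mu) * p(s | w)."""
    total = 0.0
    for s in itertools.product(range(2), repeat=len(x)):
        term = 1.0
        for xi, si in zip(x, s):
            term *= normal_pdf(xi, mu[si]) * w[si]
        total += term
    return total


def likelihood_by_product(x, mu, w):
    """The same likelihood computed as prod_i sum_k w_k N(x_i; mu_k, 1)."""
    total = 1.0
    for xi in x:
        total *= sum(wk * normal_pdf(xi, mk) for wk, mk in zip(w, mu))
    return total
```

The first function costs $O(2^n n)$ operations and the second costs $O(nK)$. The latent representation pays off only through the simpler conditional structure it exposes.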

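The augmented posterior of $(\theta, S^n)$ assumes the form

$$p(\theta, s^n \mid X^{(n)}) \propto p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)\, p_\theta(\theta), \tag{3}$$

whose constituent terms are typically more tractable. Variational inference in this setup proceeds in two steps. First, we identify a variational family $\Gamma$ comprising distributions on $\Theta \times \mathcal{S}^n$, where $\mathcal{S}^n$ denotes the set of possible values of $S^n$. Second, we find the member of this family closest to $p(\theta, s^n \mid X^{(n)})$ relative to the KL divergence:

$$\hat q = \arg\min_{q \in \Gamma} D\bigl(q \,\|\, p(\cdot, \cdot \mid X^{(n)})\bigr) = \arg\max_{q \in \Gamma} \mathcal{L}(q), \tag{4}$$

where

$$\mathcal{L}(q) = \int_\Theta \sum_{s^n} q(\theta, s^n) \log \frac{p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)\, p_\theta(\theta)}{q(\theta, s^n)}\, d\theta \tag{5}$$

is the evidence lower bound (ELBO). It gives a lower bound to the log marginal likelihood $\log p(X^{(n)}) = \log \int_\Theta p(X^{(n)} \mid \theta)\, p_\theta(\theta)\, d\theta$. Indeed, $D(q \,\|\, p(\cdot, \cdot \mid X^{(n)})) = \log p(X^{(n)}) - \mathcal{L}(q) \ge 0$, which also explains why the two optimization problems in (4) are equivalent. If $\Gamma$ is completely unrestricted, then $\hat q$ coincides with the posterior distribution $p(\theta, s^n \mid X^{(n)})$. In practice, the choice of the variational family is dictated by a trade-off between flexibility and computational tractability. For example, in mean-field variational approximation, it is common to assume independence among the parameters and the latent variables in the variational family, whence $q$ decomposes as

$$q(\theta, s^n) = q_\theta(\theta)\, q_{S^n}(s^n). \tag{6}$$

We shall assume the mean-field decomposition (6) throughout, so that the minimizer in (4) is necessarily of the form $\hat q(\theta, s^n) = \hat q_\theta(\theta)\, \hat q_{S^n}(s^n)$. In many situations, the constituent terms may be further decomposed as $q_\theta(\theta) = q_\mu(\mu)\, q_w(w)$ and $q_{S^n}(s^n) = \prod_{i=1}^n q_{S_i}(s_i)$. Such a decomposition may arise for computational reasons or may be implied by the conditional independence structure of the model.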

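Under the mean-field decomposition (6) and using (3), we have, after some simplification,

$$\mathcal{L}(q_\theta \otimes q_{S^n}) = \int_\Theta q_\theta(\theta)\, \ell_n(\theta)\, d\theta - D(q_\theta \,\|\, p_\theta), \qquad \ell_n(\theta) = \sum_{s^n} q_{S^n}(s^n) \log \frac{p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)}{q_{S^n}(s^n)}. \tag{7}$$

The quantity $\ell_n(\theta)$ is an approximation to the log likelihood $\log p(X^{(n)} \mid \theta)$ in terms of the latent variables. To see this, multiply and divide the right-hand side of (2) by $q_{S^n}(s^n)$. Then apply Jensen’s inequality to the concave function $x \mapsto \log x$ to conclude that $\ell_n(\theta) \le \log p(X^{(n)} \mid \theta)$. Replacing $\ell_n(\theta)$ by $\log p(X^{(n)} \mid \theta)$ in (7) and adjusting the error term, we obtain

$$\mathcal{L}(q_\theta \otimes q_{S^n}) = \int_\Theta q_\theta(\theta) \log p(X^{(n)} \mid \theta)\, d\theta - D(q_\theta \,\|\, p_\theta) - \Delta(q_\theta, q_{S^n}), \qquad \Delta(q_\theta, q_{S^n}) = \int_\Theta q_\theta(\theta) \bigl\{\log p(X^{(n)} \mid \theta) - \ell_n(\theta)\bigr\}\, d\theta. \tag{8}$$

The quantity $\Delta(q_\theta, q_{S^n})$ is an average error due to the likelihood approximation and is clearly nonnegative. In fact, $\log p(X^{(n)} \mid \theta) - \ell_n(\theta) = D\bigl(q_{S^n} \,\|\, p(\cdot \mid X^{(n)}, \theta)\bigr)$, where $p(s^n \mid X^{(n)}, \theta)$ is the conditional distribution of the latent variables. In the specific situation where no latent variables are present in the model, $\Delta \equiv 0$. In general, however, $\Delta$ is a strictly positive quantity.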
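The size of the Jensen gap can be checked numerically. The sketch below compares the exact log-likelihood $\log p(X^{(n)} \mid \theta)$ with the lower bound $\ell_n(\theta) = \sum_{s^n} q_{S^n}(s^n) \log\{p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)/q_{S^n}(s^n)\}$. It uses a toy two-component mixture with unit-variance components, a fixed $\theta = (\mu, w)$, and a product-form $q_{S^n}$. These are illustrative assumptions, and the names are ours. The gap equals $D\bigl(q_{S^n} \,\|\, p(\cdot \mid X^{(n)}, \theta)\bigr)$. It is therefore positive for an arbitrary $q_{S^n}$, and it vanishes when each $q_{S_i}$ is the conditional law of the label given $X_i$ and $\theta$.

```python
import math


def normal_pdf(x, mean):
    """Density of N(mean, 1) at x."""
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2.0 * math.pi)


def log_likelihood(x, mu, w):
    """Exact log p(x | theta) for the mixture sum_k w_k N(mu_k, 1)."""
    return sum(math.log(sum(wk * normal_pdf(xi, mk) for wk, mk in zip(w, mu))) for xi in x)


def latent_lower_bound(x, mu, w, q):
    """ell_n(theta) for a product q: sum_i sum_k q[i][k] log{w_k N(x_i; mu_k, 1) / q[i][k]}."""
    total = 0.0
    for xi, qi in zip(x, q):
        for wk, mk, qik in zip(w, mu, qi):
            if qik > 0:
                total += qik * math.log(wk * normal_pdf(xi, mk) / qik)
    return total


def label_posterior(x, mu, w):
    """Conditional label probabilities p(s_i = k | x_i, theta); this choice closes the gap."""
    q = []
    for xi in x:
        joint = [wk * normal_pdf(xi, mk) for wk, mk in zip(w, mu)]
        total = sum(joint)
        q.append([j / total for j in joint])
    return q
```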

3 Variational risk bounds

We are now prepared to state a general Bayes risk bound for the variational distribution. We shall operate in a frequentist framework and assume the existence of a true data generating parameter $\theta^* \in \Theta$. In other words, we assume that the data $X^{(n)}$ is distributed according to $P_{\theta^*}$, the distribution with density $p(\cdot \mid \theta^*)$. Let $d(\cdot, \cdot)$ be a distance metric on the parameter space $\Theta$ which quantifies the distance between two putative parameter values. For example, if $\Theta$ designates a space of densities so that each $\theta \in \Theta$ can be identified with a density function, $d$ can be chosen as the Hellinger or total variation distance. We are interested in obtaining bounds on

$$\int_\Theta d^\rho(\theta, \theta^*)\, \hat q_\theta(\theta)\, d\theta$$

for some $\rho > 0$ that hold with high probability with respect to $P_{\theta^*}$. In particular, if $d^\rho$ is convex in its first argument, then by Jensen’s inequality,

$$d^\rho(\bar\theta, \theta^*) \le \int_\Theta d^\rho(\theta, \theta^*)\, \hat q_\theta(\theta)\, d\theta,$$

where $\bar\theta = \int_\Theta \theta\, \hat q_\theta(\theta)\, d\theta$ is the mean of the variational distribution and a surrogate for the posterior mean. We are specifically interested in obtaining sufficient conditions for the variational point estimate $\bar\theta$ to contract at the same rate as the posterior mean. Since variational approaches are overwhelmingly used for rapidly obtaining point estimates, such a result will indicate that, at least in terms of rates of convergence, there is no loss of statistical accuracy in using a variational approximation. Moreover, the negative result of [41] shows that the spread of the variational distribution $\hat q_\theta$ is typically “too small” compared with that of the sampling distribution of the maximum likelihood estimator. This fact, combined with the Bernstein–von Mises theorem (Chapter 10 of [39]), implies the inadequacy of using $\hat q_\theta$ for approximating the true posterior distribution. A rate-optimal variational point estimator is therefore the best one can hope for in general.
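A textbook Gaussian calculation illustrates one mechanism behind this “too small” spread. We add it here only as an illustration; it is not the argument of [41]. Take a target $N(m, \Sigma)$ with precision $\Lambda = \Sigma^{-1}$. The optimal mean-field factor for coordinate $i$ is Gaussian with variance $\Lambda_{ii}^{-1}$. By the block-inverse (Schur complement) formula, this variance never exceeds the true marginal variance, with strict inequality whenever $\Sigma_{12} \neq 0$:

```latex
q_1^{*}(x_1) = N\bigl(a_1,\, \Lambda_{11}^{-1}\bigr), \qquad
\Lambda_{11}^{-1} = \Sigma_{11} - \Sigma_{12}\,\Sigma_{22}^{-1}\,\Sigma_{21} \;\le\; \Sigma_{11}.
```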

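Define $r_n(\theta, \theta^*) = \log\{p(X^{(n)} \mid \theta)/p(X^{(n)} \mid \theta^*)\}$ to be the log-likelihood ratio between $\theta$ and $\theta^*$. We can replace $p(X^{(n)} \mid \mu, s^n)$ with $p(X^{(n)} \mid \mu, s^n)/p(X^{(n)} \mid \theta^*)$ inside the integrand of the ELBO without affecting the optimization problem. This replacement only shifts the objective by $\log p(X^{(n)} \mid \theta^*)$, which does not depend on $q$. It is done purely for theoretical reasons, to harness the structure of the log-likelihood ratio. Let us denote the negative of the resulting objective by $\Psi_n(q_\theta, q_{S^n})$, so that minimizing $\Psi_n$ is equivalent to maximizing the ELBO and

$$\Psi_n(q_\theta, q_{S^n}) = -\int_\Theta q_\theta(\theta) \sum_{s^n} q_{S^n}(s^n) \log \frac{p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)}{p(X^{(n)} \mid \theta^*)\, q_{S^n}(s^n)}\, d\theta + D(q_\theta \,\|\, p_\theta). \tag{9}$$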

We are now ready to state the main assumption.
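Assumption T: (existence of tests) Let $\varepsilon_n$ be a sequence satisfying $\varepsilon_n \to 0$ and $n\varepsilon_n^\rho \to \infty$ for some $\rho > 0$. Let $\phi_n = \phi_n(X^{(n)})$ be a sequence of test functions for testing

$$H_0\colon \theta = \theta^* \quad \text{versus} \quad H_1\colon \theta \in \{\theta \in \Theta : d(\theta, \theta^*) \ge \varepsilon_n\}$$

with type-I and II error rates satisfying

$$E_{\theta^*}(\phi_n) \le e^{-Cn\varepsilon_n^\rho}, \qquad E_\theta(1 - \phi_n) \le e^{-Cn\, d^\rho(\theta, \theta^*)}$$

for any $\theta \in \Theta$ with $d(\theta, \theta^*) \ge \varepsilon_n$, where $E_\theta$ denotes an expectation with respect to $P_\theta$. While $\rho = 2$ appears naturally in most problems, we provide a non-standard example (estimating component-specific means in a mixture model) in Section 4 with $\rho \neq 2$.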

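In models satisfying the monotone likelihood ratio property [12], such as exponential families, one can construct such tests (with $\rho = 2$) from the generalized likelihood ratio test (GLRT) statistic when $d$ corresponds to the Euclidean metric on the natural parameter. A general recipe [19] for constructing such tests when $\Theta$ is compact relative to $d$ has three steps:

(i) construct an $\varepsilon_n$-net $\{\theta_1, \ldots, \theta_{N_n}\}$ such that for any $\theta$ with $d(\theta, \theta^*) \ge \varepsilon_n$, there exists $\theta_j$ with $d(\theta, \theta_j) \le \varepsilon_n$;

(ii) construct a test $\phi_{n,j}$ for $\theta = \theta^*$ versus the $d$-ball of radius $\varepsilon_n$ around $\theta_j$, with type-I and II error rates as in Assumption T;

(iii) set $\phi_n = \max_{1 \le j \le N_n} \phi_{n,j}$.

The type-II error of $\phi_n$ retains the same upper bound, while the type-I error can be bounded by $N_n e^{-Cn\varepsilon_n^\rho}$. Since $N_n$ can be further bounded by $N(\varepsilon_n, \Theta, d)$, the covering number of $\Theta$ by $d$-balls of radius $\varepsilon_n$, it suffices to show that $\log N(\varepsilon_n, \Theta, d) \lesssim n\varepsilon_n^\rho$. For example, suppose $\Theta$ is a compact subset of $\mathbb{R}^m$ and $d$ is the Euclidean metric. Then $\log N(\varepsilon_n, \Theta, d) \lesssim m \log(1/\varepsilon_n) \lesssim n\varepsilon_n^\rho$ as long as $\varepsilon_n^\rho \gtrsim (m \log n)/n$. More generally, suppose $\Theta$ is a space of densities and $d$ is the Hellinger or total variation metric. Then the point-by-point tests in (ii) (with $\rho = 2$) can be constructed from likelihood ratio test statistics using the classical Birgé–Le Cam testing theory [6, 28]; see also [19].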
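Spelled out, the union bound behind step (iii) reads as follows. For readability, this sketch assumes the same constant $C$ in every point test $\phi_{n,j}$, $j = 1, \ldots, N_n$, with $\phi_n = \max_j \phi_{n,j}$. The first line uses $\max_j \phi_{n,j} \le \sum_j \phi_{n,j}$. The second line shows why the type-II bound carries over from whichever point test covers $\theta$.

```latex
\begin{aligned}
E_{\theta^*}\phi_n &\le \sum_{j=1}^{N_n} E_{\theta^*}\phi_{n,j}
  \le N_n\, e^{-Cn\varepsilon_n^{\rho}}
  = \exp\bigl\{\log N_n - Cn\varepsilon_n^{\rho}\bigr\}
  \le e^{-Cn\varepsilon_n^{\rho}/2}
  \quad \text{whenever } \log N_n \le \tfrac{C}{2}\, n\varepsilon_n^{\rho},\\
E_{\theta}(1-\phi_n) &= E_{\theta}\min_{j}\,(1-\phi_{n,j})
  \le E_{\theta}(1-\phi_{n,j}) \quad \text{for every } j .
\end{aligned}
```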

We are now ready to state our first theorem on the variational risk bound, which relates the Bayes risk under the variational solution to the size of the objective function .

Theorem 3.1.

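Suppose there exists a sequence of test functions $\{\phi_n\}$ for the metric $d$ satisfying Assumption T. Then, it holds w.h.p. (with respect to $P_{\theta^*}$) that

$$\int_\Theta d^\rho(\theta, \theta^*)\, \hat q_\theta(\theta)\, d\theta \le \frac{C}{n} \bigl\{\Psi_n(q_\theta, q_{S^n}) + n\varepsilon_n^\rho\bigr\}$$

for any $q_\theta$ and any probability distribution $q_{S^n}$ on $\mathcal{S}^n$ which is nowhere zero, such that $q_\theta \otimes q_{S^n}$ lies in the variational family $\Gamma$.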

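Theorem 3.1 implies that minimizing the Bayes risk bound is equivalent to minimizing the objective function $\Psi_n$. Taking the infimum of the right-hand side over the variational family shows that the bound is governed by $\Psi_n(\hat q_\theta, \hat q_{S^n})$, which is exactly the quantity the variational solution minimizes. [43] obtained a similar result for a modified variational objective function, with $d$ limited to Rényi divergence measures. Theorem 3.1 instead allows any metric $d$ as long as the testing condition in Assumption T can be satisfied.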

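A detailed proof of Theorem 3.1 (as well as proofs of other results) is provided in Section 6. We sketch some of the key steps to highlight the main features of our argument. Our first key step is to show, using the testing Assumption T, that w.h.p. (w.r.t. $P_{\theta^*}$),

$$\int_\Theta e^{c\, n\, d^\rho(\theta, \theta^*)}\, e^{r_n(\theta, \theta^*)}\, p_\theta(\theta)\, d\theta \le e^{Cn\varepsilon_n^\rho} \tag{10}$$

for a sufficiently small constant $c > 0$.

To show this, write the integral in (10) as $I_1 + I_2$, where $I_1$ and $I_2$ correspond to integrating over $\{\theta : d(\theta, \theta^*) < \varepsilon_n\}$ and $\{\theta : d(\theta, \theta^*) \ge \varepsilon_n\}$, respectively. Markov’s inequality, together with the fact that $E_{\theta^*} e^{r_n(\theta, \theta^*)} \le 1$ for every $\theta$, shows that $I_1 \le e^{Cn\varepsilon_n^\rho}$ w.h.p. To tackle $I_2$, write $I_2 = I_{21} + I_{22}$ by decomposing $1 = \phi_n + (1 - \phi_n)$, where $\phi_n$ is the test from Assumption T. Markov’s inequality and the small type-II error of $\phi_n$ show that $I_{22} \le e^{Cn\varepsilon_n^\rho}$ w.h.p. The bound on the type-I error of $\phi_n$, along with Markov’s inequality, yields $\phi_n = 0$ w.h.p. (for an indicator test), which in turn yields $I_{21} = 0$ w.h.p. Combining, one gets $I_1 + I_2 \le e^{Cn\varepsilon_n^\rho}$ w.h.p.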

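Once (10) is established, the next step is to link the integrand in (10) with the latent variables. To that end, fix any probability distribution $q_{S^n}$ on $\mathcal{S}^n$ that is nowhere zero. Multiplying and dividing the summands in (2) by $q_{S^n}(s^n)$ and applying Jensen’s inequality gives

$$e^{r_n(\theta, \theta^*)} = \sum_{s^n} q_{S^n}(s^n)\, \frac{p(X^{(n)} \mid \mu, s^n)\, p(s^n \mid w)}{q_{S^n}(s^n)\, p(X^{(n)} \mid \theta^*)} \ge \exp\bigl\{\ell_n(\theta) - \log p(X^{(n)} \mid \theta^*)\bigr\},$$

where $\ell_n(\theta)$ is defined in (7) with this choice of $q_{S^n}$. Combining the above with (10), we have, w.h.p.,

$$\int_\Theta \exp\bigl\{c\, n\, d^\rho(\theta, \theta^*) + \ell_n(\theta) - \log p(X^{(n)} \mid \theta^*)\bigr\}\, p_\theta(\theta)\, d\theta \le e^{Cn\varepsilon_n^\rho}.$$

Next, use a well-known variational/dual representation of the KL divergence (see, e.g., Corollary 4.15 of [11]). It states that for any probability measure $Q$ and any measurable function $h$ with $\int e^{h}\, dQ < \infty$,

$$\log \int e^{h}\, dQ = \sup_{R} \Bigl\{\int h\, dR - D(R \,\|\, Q)\Bigr\},$$

where the supremum is over all probability measures $R \ll Q$. In the present context, set $Q = p_\theta$, $R = q_\theta$, and $h(\theta) = c\, n\, d^\rho(\theta, \theta^*) + \ell_n(\theta) - \log p(X^{(n)} \mid \theta^*)$. The variational lemma and some rearrangement of terms then give, w.h.p.,

$$c\, n \int_\Theta d^\rho(\theta, \theta^*)\, q_\theta(\theta)\, d\theta \le -\int_\Theta q_\theta(\theta) \bigl\{\ell_n(\theta) - \log p(X^{(n)} \mid \theta^*)\bigr\}\, d\theta + D(q_\theta \,\|\, p_\theta) + Cn\varepsilon_n^\rho.$$

By the definition of $\Psi_n$, the right-hand side of the above display equals $\Psi_n(q_\theta, q_{S^n}) + Cn\varepsilon_n^\rho$. The theorem then follows, since by definition, $\Psi_n(\hat q_\theta, \hat q_{S^n}) \le \Psi_n(q_\theta, q_{S^n})$ for any $(q_\theta, q_{S^n})$ in the variational family $\Gamma$.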
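On a finite space the variational representation can be checked directly. The identity $\log \sum_x Q(x) e^{h(x)} = \max_R \{\sum_x R(x) h(x) - D(R \,\|\, Q)\}$ has its maximum at the Gibbs measure $R^*(x) \propto Q(x) e^{h(x)}$, and every other $R$ gives a smaller value. Below is a minimal sketch. The function names are ours, `base` plays the role of $Q$, and `h` is a list of reals.

```python
import math


def log_partition(base, h):
    """log sum_x base(x) exp(h(x)) for a probability vector base."""
    return math.log(sum(b * math.exp(hx) for b, hx in zip(base, h)))


def dual_objective(nu, base, h):
    """sum_x nu(x) h(x) - D(nu || base), the quantity inside the supremum."""
    linear = sum(n * hx for n, hx in zip(nu, h))
    divergence = sum(n * math.log(n / b) for n, b in zip(nu, base) if n > 0)
    return linear - divergence


def gibbs(base, h):
    """The maximizer nu*(x) proportional to base(x) exp(h(x))."""
    weights = [b * math.exp(hx) for b, hx in zip(base, h)]
    total = sum(weights)
    return [wt / total for wt in weights]
```

In the proof, the role of `base` is played by the prior $p_\theta$, and the variational density $q_\theta$ is one particular choice of `nu`.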

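Now we discuss choices of good variational distributions $q_\theta$ and $q_{S^n}$ for minimizing $\Psi_n(q_\theta, q_{S^n})$, the stochastic component of the variational upper bound in Theorem 3.1. For the subsequent development, we first make some additional assumptions on the augmented likelihood and prior in (3). First, assume independent priors on $\mu$ and $w$, so that $p_\theta(\theta) = p_\mu(\mu)\, p_w(w)$. Next, assume that $p(X^{(n)} \mid \mu, s^n) = \prod_{i=1}^n p(X_i \mid \mu, s_i)$ and that $p(s^n \mid w) = \prod_{i=1}^n p(s_i \mid w)$ splits into independent components. This implies $p(X^{(n)} \mid \theta) = \prod_{i=1}^n p(X_i \mid \theta)$, where $p(X_i \mid \theta) = \sum_{s_i} p(X_i \mid \mu, s_i)\, p(s_i \mid w)$. For the variational distribution $q_{S^n}$, we additionally assume a mean-field decomposition $q_{S^n}(s^n) = \prod_{i=1}^n q_{S_i}(s_i)$.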

Recall that the objective function decomposes as $\Psi_n(q_\theta, q_{S^n}) = -\int_\Theta q_\theta(\theta)\, r_n(\theta, \theta^*)\, d\theta + \Delta(q_\theta, q_{S^n}) + D(q_\theta \,\|\, p_\theta)$. The first, model-fit, term is a log-likelihood ratio averaged with respect to the variational distribution. It tends to be small when the variational distribution places more mass near the true parameter $\theta^*$. The second term arises from approximating $\log p(X^{(n)} \mid \theta)$ by $\ell_n(\theta)$. It becomes small under a proper choice of $q_{S^n}$, as we illustrate in the proof of Theorem 3.2 below. The last, regularization or penalty, term prevents over-fitting by constraining the KL divergence between the variational solution and the prior. Consequently, a good variational distribution should put all its mass in an appropriately small neighborhood around the truth, so that the first two terms in $\Psi_n$ become small. On the other hand, the neighborhood has to be large enough so that the last regularization term is not too large.

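Motivated by the above discussion, we follow the development of [43] and define two KL neighborhoods with radius $\varepsilon$ around $\theta^* = (\mu^*, w^*)$:

$$B_n(\mu^*, \varepsilon) = \Bigl\{\mu : \sup_{s} D\bigl(p(\cdot \mid \mu^*, s) \,\|\, p(\cdot \mid \mu, s)\bigr) \le \varepsilon^2,\ \sup_{s} V\bigl(p(\cdot \mid \mu^*, s) \,\|\, p(\cdot \mid \mu, s)\bigr) \le \varepsilon^2\Bigr\},$$

$$B_n(w^*, \varepsilon) = \bigl\{w : D(w^* \,\|\, w) \le \varepsilon^2,\ V(w^* \,\|\, w) \le \varepsilon^2\bigr\},$$

where we used the shorthand $D(w^* \,\|\, w)$ to denote the KL divergence between categorical distributions with parameters $w^*$ and $w$ in the $(K-1)$-dim simplex $\mathcal{S}^{K-1}$. Consistent with the notation introduced at the beginning of Section 2, $V(w^* \,\|\, w) = \sum_{k=1}^K w_k^* \log^2(w_k^*/w_k)$.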

Theorem 3.2.

For any fixed , with probability at least , it holds that

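The proof follows Theorem 4.5 of [43] and is omitted; we provide a sketch here. As discussed above, we first make a good choice of $q_{S^n}$ as follows. Let $\tilde q_{S^n}$ be a probability distribution over $\mathcal{S}^n$ defined as

$$\tilde q_{S^n}(s^n) = \prod_{i=1}^n \frac{p(X_i \mid \mu^*, s_i)\, p(s_i \mid w^*)}{\sum_{s_i'} p(X_i \mid \mu^*, s_i')\, p(s_i' \mid w^*)}.$$

Intuitively, $\tilde q_{S^n}$ takes the full conditional distribution $p(s^n \mid X^{(n)}, \theta)$ of $S^n$ and replaces $\theta$ by the true parameter $\theta^*$. With this choice, $\Psi_n(q_\theta, \tilde q_{S^n})$ simplifies to

$$\int_\Theta q_\theta(\theta) \sum_{s^n} \tilde q_{S^n}(s^n) \log \frac{p(X^{(n)} \mid \mu^*, s^n)}{p(X^{(n)} \mid \mu, s^n)}\, d\theta + \int_\Theta q_\theta(\theta) \sum_{s^n} \tilde q_{S^n}(s^n) \log \frac{p(s^n \mid w^*)}{p(s^n \mid w)}\, d\theta + D(q_\theta \,\|\, p_\theta).$$

It now remains to choose $q_\theta$. The first term in the above display naturally suggests choosing $q_\theta$ as the restriction of $p_\theta$ to $B = B_n(\mu^*, \varepsilon) \times B_n(w^*, \varepsilon)$, in which case $D(q_\theta \,\|\, p_\theta) = -\log \int_B p_\theta(\theta)\, d\theta$. For this choice, the second term can also be controlled w.h.p., leading to the conclusion in Theorem 3.2.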

4 Applications

As described in Section 3, variational risk bounds for the parameter of interest depend on the existence of appropriate test functions, which characterize the ability of the likelihood to identify the parameter. The use of test functions for studying convergence rates of estimators in classical and Bayesian statistics dates back to [36, 28], with renewed attention in the Bayesian context due to [19]. Specific tests have been constructed for nonparametric density estimation [19, 33], semi/non-parametric regression [4, 31], convergence of latent mixing measures [30, 22], high-dimensional problems [32, 16, 14], and empirical Bayes methods [35], among others. As long as the prior distributions are supported on compact subsets of the parameter space, these existing tests can be used to prove minimax optimality of the variational estimate in each of these problems. We omit the details due to space constraints.

In this section, we focus on two novel examples involving latent variables where variational methods are commonly used and no theoretical guarantee is available for the variational solutions. The first one is the Latent Dirichlet Allocation (LDA; [10]), a generative probabilistic model for topic modeling. The second example is concerned with estimating the component specific parameters in Gaussian mixture models.

4.1 Latent Dirichlet allocation

We first consider LDA [10], a conditionally conjugate probabilistic topic model [8] for learning the latent “topics” contained in a collection of documents. Since the original paper [10], the mean-field variational Bayes approximation has become a routine approach for implementing LDA. However, despite its empirical success, theoretical guarantees for the variational solution remain an open problem. In this subsection, we show the rate optimality of the estimate from the mean-field approximation to LDA.

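In LDA, each document is assumed to contain multiple topics, where a topic is defined as a distribution over words in a vocabulary. Our presentation of the model follows the notation of [23]. Let $K$ be the total number of topics, $V$ the vocabulary size, $d$ the total number of documents, and $n$ the number of words in each document (for simplicity, we assume the same number of words across documents). Recall that we use the notation $\mathcal{S}^{K-1}$ to denote the $(K-1)$-dim simplex. LDA contains two parameters: the word distribution matrix $B = (\beta_1, \ldots, \beta_K)^\top \in \mathbb{R}^{K \times V}$ and the topic distribution matrix $\Omega = (\omega_1, \ldots, \omega_d)^\top \in \mathbb{R}^{d \times K}$. Here, $\beta_k \in \mathcal{S}^{V-1}$ is the word proportion vector of the $k$th topic, $k = 1, \ldots, K$, and $\omega_j \in \mathcal{S}^{K-1}$ is the topic proportion vector of the $j$th document, $j = 1, \ldots, d$. Given parameters $B$ and $\Omega$, the data generative model of LDA is:

  1. for each document $j$ in $[d] = \{1, \ldots, d\}$, draw topic assignments $z_{ji} \mid \omega_j \sim \mathrm{Cat}(\omega_j)$ independently for $i \in [n]$, then

  2. for each word $i$ in $[n]$, draw a word $y_{ji} \mid z_{ji}, B \sim \mathrm{Cat}(B^\top z_{ji})$,

where $\mathrm{Cat}(p)$ stands for the categorical distribution with probability vector $p$. Here, $z_{ji} \in \{0,1\}^K$ is the latent class variable over topics, so that $z_{jik} = 1$ indicates the $i$th word in document $j$ is assigned to the $k$th topic; in this case $B^\top z_{ji} = \beta_k$. Similarly, $y_{ji} \in \{0,1\}^V$ is the (observed) class variable over the words in the vocabulary, so that $y_{jiv} = 1$ indicates that the $i$th word in document $j$ is the $v$th word in the vocabulary. A common prior distribution over the parameters $B$ and $\Omega$ is:

  1. For each topic $k$ in $[K]$, the word proportion vector $\beta_k$ has prior $\mathrm{Dir}(\beta_0, \ldots, \beta_0)$,

  2. For each document $j$ in $[d]$, the topic proportion vector $\omega_j$ has prior $\mathrm{Dir}(\alpha_1, \ldots, \alpha_K)$.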

Here, $\beta_0$ is a hyper-parameter of the symmetric Dirichlet prior on the topics $\beta_k$, and $\alpha_1, \ldots, \alpha_K$ are hyper-parameters of the Dirichlet prior on the topic proportions for each document. To facilitate adaptation to sparsity using Dirichlet distributions when , we choose and for some fixed number [42].
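Below is a minimal simulation of the LDA generative process and priors described above. It uses illustrative sizes and hyper-parameter values; any positive values work. The names `sample_lda`, `beta0` and `alpha` are hypothetical, and a single common value `alpha` is used for all documents' Dirichlet parameters. Labels and words are stored as integer indices rather than one-hot vectors. Dirichlet draws are obtained by normalizing independent Gamma variables, a standard construction. Very small hyper-parameters can make all Gamma draws underflow to zero, which the sampler reports as an error.

```python
import random


def sample_dirichlet(alphas, rng):
    """Draw from Dir(alphas) by normalizing independent Gamma(alpha_k, 1) variables."""
    draws = [rng.gammavariate(a, 1.0) for a in alphas]
    total = sum(draws)
    if total == 0.0:
        raise ValueError("all Gamma draws underflowed; use larger hyper-parameters")
    return [g / total for g in draws]


def sample_lda(K, V, d, n, beta0, alpha, seed=0):
    """Simulate B, Omega, topic labels z and words y from the LDA model.

    z[j][i] is the topic index of word i in document j, y[j][i] its word index.
    """
    rng = random.Random(seed)
    B = [sample_dirichlet([beta0] * V, rng) for _ in range(K)]       # rows beta_k
    Omega = [sample_dirichlet([alpha] * K, rng) for _ in range(d)]   # rows omega_j
    z, y = [], []
    for j in range(d):
        z_j = rng.choices(range(K), weights=Omega[j], k=n)           # z_ji ~ Cat(omega_j)
        y_j = [rng.choices(range(V), weights=B[k])[0] for k in z_j]  # y_ji ~ Cat(beta_{z_ji})
        z.append(z_j)
        y.append(y_j)
    return B, Omega, z, y
```

For example, `sample_lda(K=3, V=10, d=4, n=20, beta0=0.5, alpha=0.5)` returns four documents of 20 word indices each, together with the topic-word and document-topic matrices that generated them.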

To apply our theory, we view as the sample size, and as the “dimension” of the parameters in the model. Under our vanilla notation, we are interested in learning parameters , with and , from the posterior distribution , where with are latent variables, with are the data, and the priors for are independent Dirichlet distributions and whose densities are denoted by and . The conditional distribution of the observation given the latent variable is

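We consider the following mean-field approximation [10], which decomposes the variational distribution into

$$q(B, \Omega, z) = \prod_{k=1}^K q_{\beta_k}(\beta_k) \prod_{j=1}^d q_{\omega_j}(\omega_j) \prod_{j=1}^d \prod_{i=1}^n q_{z_{ji}}(z_{ji})$$

for approximating the joint posterior distribution of $(B, \Omega, z)$. For LDA, each observation $y_j = (y_{j1}, \ldots, y_{jn})$ is composed of $n$ independent observations. It is therefore natural to present the variational oracle inequality with respect to the average squared Hellinger distance $d^{-1} \sum_{j=1}^d h^2(p_{\omega_j, B}, p_{\omega_j^*, B^*})$. Here, $p_{\omega_j, B}(v) = \sum_{k=1}^K \omega_{jk} \beta_{kv}$, $v \in [V]$, denotes the likelihood function of the $i$th observation in $y_j$. We make the following assumption.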

Assumption S: (sparsity and regularity condition)

Suppose for each , is sparse, and for each , is sparse. Moreover, there exists some constant , such that each nonzero component of or is at least .

Theorem 4.1.

Suppose Assumption S holds. If as , then it holds with probability tending to one that as

Theorem 4.1 implies the estimation consistency as long as the “effective” dimensionality of the model is as the “effective sample size” . In addition, the upper bound depends only logarithmically on the vocabulary size due to the sparsity assumption.

4.2 Gaussian mixture models

Variational inference methods are routinely used in conjugate exponential-family mixtures [18] to speed up computation and perform inference on component-specific parameters, with Gaussian mixtures constituting an important special case. Traditional MCMC methods face difficulties in inferring component-specific parameters due to label switching [37]. It has been empirically verified [7] that variational inference for Gaussian mixtures provides accurate estimates of the true density as well as the labels (up to permutation of the indices). However, theoretical guarantees for such a phenomenon have remained an open problem to date. In this section, we close this gap and provide an affirmative answer under reasonable assumptions on the true mixture density. In particular, we show that variational techniques using the mean-field approximation provide optimal estimates for the component-specific parameters, up to a permutation of the labels.

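Suppose the true data generating model is the $m$-dimensional Gaussian mixture model with $K$ components,

$$p(X_i \mid \mu, w) = \sum_{k=1}^K w_k\, N(X_i; \mu_k, I_m), \qquad i = 1, \ldots, n, \tag{14}$$

where $\mu_k \in \mathbb{R}^m$ is the mean vector associated with the $k$th component and $w = (w_1, \ldots, w_K)$ is a vector of probabilities lying in the $(K-1)$-dimensional simplex $\mathcal{S}^{K-1}$. Assume that $\mu_k \in \Theta_0$ for $k = 1, \ldots, K$, where $\Theta_0$ is a compact subset of $\mathbb{R}^m$. Set $\mu = (\mu_1, \ldots, \mu_K)$ and $\theta = (\mu, w)$ as before. Although we assume the covariance matrix of each Gaussian component to be $I_m$, it is straightforward to extend our results to diagonal covariance matrices.

Introducing independent latent variables $S_i \in \{1, \ldots, K\}$ for $i = 1, \ldots, n$ such that $P(S_i = k \mid w) = w_k$ for $k = 1, \ldots, K$, (14) can be re-written as

$$X_i \mid S_i = k, \mu \sim N(\mu_k, I_m), \tag{15}$$

$$S_i \mid w \sim \mathrm{Cat}(w), \tag{16}$$

where $S^n = (S_1, \ldots, S_n)$ and $\mathrm{Cat}(w)$ denotes a discrete distribution with support $\{1, \ldots, K\}$ and probabilities $w_1, \ldots, w_K$. For simplicity, we assume independent priors for $\mu_1, \ldots, \mu_K$ and $w$.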

We apply the mean-field approximation using the family of density functions of the form

$$q(\mu, w, s^n) = q_\mu(\mu)\, q_w(w) \prod_{i=1}^n q_{S_i}(s_i)$$

to approximate the joint posterior distribution of $(\mu, w, S^n)$.
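As a concrete instance of mean-field coordinate ascent for mixtures, here is a sketch of CAVI in a deliberately simplified version of the model above. It works in one dimension with unit component variances. The mixture weights are fixed at $1/K$, so $q_w$ is not updated, and the component means have independent $N(0, \sigma^2)$ priors. These simplifications and the names are ours, not the article's. The update for $q(S_i = k)$ is $\varphi_{ik} \propto \exp\{m_k x_i - (s_k^2 + m_k^2)/2\}$. The update for $q(\mu_k) = N(m_k, s_k^2)$ is $s_k^2 = (1/\sigma^2 + \sum_i \varphi_{ik})^{-1}$ and $m_k = s_k^2 \sum_i \varphi_{ik} x_i$.

```python
import math


def cavi_gmm_1d(x, means_init, prior_var=100.0, n_iter=50):
    """CAVI for a 1-D Gaussian mixture with unit component variances, fixed
    equal weights 1/K, and independent N(0, prior_var) priors on the means.

    Returns m, s2 (parameters of q(mu_k) = N(m[k], s2[k])) and phi[i][k] = q(S_i = k).
    """
    K = len(means_init)
    m = list(means_init)
    s2 = [1.0] * K
    phi = [[1.0 / K] * K for _ in x]
    for _ in range(n_iter):
        # Label update: log phi_ik = E[mu_k] x_i - E[mu_k^2] / 2 + const.
        for i, xi in enumerate(x):
            logits = [m[k] * xi - 0.5 * (s2[k] + m[k] ** 2) for k in range(K)]
            top = max(logits)  # subtract the max for numerical stability
            unnorm = [math.exp(v - top) for v in logits]
            total = sum(unnorm)
            phi[i] = [u / total for u in unnorm]
        # Mean update: Gaussian factor combining the prior with soft assignments.
        for k in range(K):
            nk = sum(phi[i][k] for i in range(len(x)))
            sk = sum(phi[i][k] * xi for i, xi in enumerate(x))
            s2[k] = 1.0 / (1.0 / prior_var + nk)
            m[k] = sk * s2[k]
    return m, s2, phi
```

The fitted means are identified only up to relabeling of the components. Any comparison with the truth should therefore be made after sorting, or after matching labels.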

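Since we are interested in the accuracy of estimating the component-specific parameters $\mu_1, \ldots, \mu_K$, we turn our attention to relating the discrepancy in estimating $\mu$ to that of the associated mixing measure. We work with the Wasserstein metric [30] between the mixing measures associated with the density. Note that (15)-(16) can also be written in terms of the mixing measure $G = \sum_{k=1}^K w_k\, \delta_{\mu_k}$ as

$$p(X_i \mid G) = \int N(X_i; \eta, I_m)\, dG(\eta) = \sum_{k=1}^K w_k\, N(X_i; \mu_k, I_m). \tag{17}$$

Henceforth $\mu_1, \ldots, \mu_K$ will be referred to as the atoms of $G$. Let $\mathcal{G}_K$ denote the class of all such mixing measures $G = \sum_{k=1}^K w_k\, \delta_{\mu_k}$ whose atoms lie in a compact subset of $\mathbb{R}^m$. Define the $r$-Wasserstein distance, denoted $W_r$, between two mixing measures $G$ and $G'$ in $\mathcal{G}_K$ as

$$W_r(G, G') = \Bigl\{\inf_{\pi \in \mathcal{C}(G, G')} \int \|\eta - \eta'\|^r\, d\pi(\eta, \eta')\Bigr\}^{1/r}, \tag{18}$$

where $\|\cdot\|$ is the Euclidean norm, $r \ge 1$, and $\mathcal{C}(G, G')$ is the set of all possible couplings, i.e., joint distributions of $(\eta, \eta')$ with marginals $G$ and $G'$, respectively. Writing $G' = \sum_{l=1}^K w'_l\, \delta_{\mu'_l}$, one can express (18) in terms of $(w, \mu)$ and $(w', \mu')$ as

$$W_r(G, G') = \Bigl\{\min_{\pi \in \mathcal{C}(w, w')} \sum_{k=1}^K \sum_{l=1}^K \pi_{kl}\, \|\mu_k - \mu'_l\|^r\Bigr\}^{1/r}, \tag{19}$$

where $\pi$ varies over $\mathcal{C}(w, w')$, the set of joint probability mass functions over $\{1, \ldots, K\} \times \{1, \ldots, K\}$ satisfying $\sum_{l} \pi_{kl} = w_k$ and $\sum_{k} \pi_{kl} = w'_l$.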
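Consider the special case of equal weights, $w = w' = (1/K, \ldots, 1/K)$. The feasible couplings are then doubly stochastic matrices scaled by $1/K$, and by the Birkhoff–von Neumann theorem a linear objective over them is minimized at a permutation. The discrete formula above can therefore be computed by brute force over the $K!$ relabelings. This also makes visible that $W_r$ ignores how the components are labeled. The sketch below is for small $K$ only. The names are ours, and unequal weights would need a general optimal-transport solver.

```python
import itertools
import math


def wasserstein_equal_weights(atoms_g, atoms_h, r=2):
    """W_r between G = (1/K) sum_k delta_{atoms_g[k]} and H = (1/K) sum_l delta_{atoms_h[l]}.

    With equal weights an optimal coupling can be taken to be a permutation, so a
    brute-force search over all K! permutations gives the exact value.
    """
    K = len(atoms_g)
    if len(atoms_h) != K:
        raise ValueError("both mixing measures need the same number of atoms")
    best = min(
        sum(math.dist(atoms_g[k], atoms_h[sigma[k]]) ** r for k in range(K)) / K
        for sigma in itertools.permutations(range(K))
    )
    return best ** (1.0 / r)
```

Relabeling the atoms of $G$ gives distance zero. Shifting every atom by the same vector $v$ gives $W_2 = \|v\|$.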

In the following, we will consider $r = 2$, that is, we work with the $W_2$ metric. It is known [30] that $\mathcal{G}_K$ is compact with respect to