Stochastic Variational Inference

06/29/2012 ∙ by Matt Hoffman, et al. ∙ Adobe ∙ Princeton University ∙ Carnegie Mellon University ∙ UC Berkeley

We develop stochastic variational inference, a scalable algorithm for approximating posterior distributions. We develop this technique for a large class of probabilistic models and we demonstrate it with two probabilistic topic models, latent Dirichlet allocation and the hierarchical Dirichlet process topic model. Using stochastic variational inference, we analyze several large collections of documents: 300K articles from Nature, 1.8M articles from The New York Times, and 3.8M articles from Wikipedia. Stochastic inference can easily handle data sets of this size and outperforms traditional variational inference, which can only handle a smaller subset. (We also show that the Bayesian nonparametric topic model outperforms its parametric counterpart.) Stochastic variational inference lets us apply complex Bayesian models to massive data sets.


1 Introduction

Modern data analysis requires computation with massive data. As examples, consider the following. (1) We have an archive of the raw text of two million books, scanned and stored online. We want to discover the themes in the texts, organize the books by subject, and build a navigator for users to explore our collection. (2) We have data from an online shopping website containing millions of users’ purchase histories as well as descriptions of each item in the catalog. We want to recommend items to users based on this information. (3) We are continuously collecting data from an online feed of photographs. We want to build a classifier from these data. (4) We have measured the gene sequences of millions of people. We want to make hypotheses about connections between observed genes and other traits.

These problems illustrate some of the challenges to modern data analysis. Our data are complex and high-dimensional; we have assumptions to make—from science, intuition, or other data analyses—that involve structures we believe exist in the data but that we cannot directly observe; and finally our data sets are large, possibly even arriving in a never-ending stream.

Statistical machine learning research has addressed some of these challenges by developing the field of probabilistic modeling, a field that provides an elegant approach to developing new methods for analyzing data (Pearl, 1988; Jordan, 1999; Bishop, 2006; Koller and Friedman, 2009; Murphy, 2012). In particular, probabilistic graphical models give us a visual language for expressing assumptions about data and its hidden structure. The corresponding posterior inference algorithms let us analyze data under those assumptions, inferring the hidden structure that best explains our observations.

In descriptive tasks, like problems #1 and #4 above, graphical models help us explore the data—the organization of books or the connections between genes and traits—with the hidden structure probabilistically “filled in.” In predictive tasks, like problems #2 and #3, we use models to form predictions about new observations. For example, we can make recommendations to users or predict the class labels of new images. With graphical models, we enjoy a powerful suite of probability models to connect and combine; and we have general-purpose computational strategies for connecting models to data and estimating the quantities needed to use them.

The problem we face is scale. Inference algorithms of the 1990s and 2000s used to be considered scalable, but they cannot easily handle the amount of data that we described in the four examples above. This is the problem we address here. We present an approach to computing with graphical models that is appropriate for massive data sets, data that might not fit in memory or even be stored locally. Our method does not require clusters of computers or specialized hardware, though it can be further sped up with these amenities.

As an example of this approach to data analysis, consider topic models. Topic models are probabilistic models of text used to uncover the hidden thematic structure in a collection of documents (Blei, 2012). The main idea in a topic model is that there is a set of topics that describe the collection and each document exhibits those topics with different degrees. As a probabilistic model, the topics and how they relate to the documents are hidden structure and the main computational problem is to infer this hidden structure from an observed collection. Figure 1 illustrates the results of our algorithm on a probabilistic topic model. These are two sets of topics, weighted distributions over the vocabulary, found in 1.8M articles from the New York Times and 300,000 articles from Nature. Topic models are motivated by applications that require analyzing massive collections of documents like this, but traditional algorithms for topic model inference do not easily scale to collections of this size.

Figure 1: Posterior topics from the hierarchical Dirichlet process topic model on two large data sets. These posteriors were approximated using stochastic variational inference with 1.8M articles from the New York Times (top) and 350K articles from Nature (bottom). (See Section 3.3 for the modeling details behind the hierarchical Dirichlet process and Section 4 for details about the empirical study.) Each topic is a weighted distribution over the vocabulary and each topic’s plot illustrates its most frequent words.

Our algorithm builds on variational inference, a method that transforms complex inference problems into high-dimensional optimization problems (Jordan et al., 1999; Wainwright and Jordan, 2008). Traditionally, the optimization is solved with a coordinate ascent algorithm, iterating between re-analyzing every data point in the data set and re-estimating its hidden structure. This is inefficient for large data sets, however, because it requires a full pass through the data at each iteration.

In this paper we derive a more efficient algorithm by using stochastic optimization (Robbins and Monro, 1951), a technique that follows noisy estimates of the gradient of the objective. When used in variational inference, we show that this gives an algorithm which iterates between subsampling the data and adjusting the hidden structure based only on the subsample. This is much more efficient than traditional variational inference. We call our method stochastic variational inference.

We will derive stochastic variational inference for a large class of graphical models. We will study its performance on two kinds of probabilistic topic models. In particular, we demonstrate stochastic variational inference on latent Dirichlet allocation (Blei et al., 2003), a simple topic model, and the hierarchical Dirichlet process topic model (Teh et al., 2006a), a more flexible model where the number of discovered topics grows with the data. (This latter application demonstrates how to use stochastic variational inference in a variety of Bayesian nonparametric settings.) Stochastic variational inference can efficiently analyze massive data sets with complex probabilistic models.

Technical summary.

  We now turn to the technical context of our method. In probabilistic modeling, we use hidden variables to encode hidden structure in observed data; we articulate the relationship between the hidden and observed variables with a factorized probability distribution (i.e., a graphical model); and we use inference algorithms to estimate the posterior distribution, the conditional distribution of the hidden structure given the observations.

Consider a graphical model of hidden and observed random variables for which we want to compute the posterior. For many models of interest, this posterior is not tractable to compute and we must appeal to approximate methods. The two most prominent strategies in statistics and machine learning are Markov chain Monte Carlo (MCMC) sampling and variational inference. In MCMC sampling, we construct a Markov chain over the hidden variables whose stationary distribution is the posterior of interest (Metropolis et al., 1953; Hastings, 1970; Geman and Geman, 1984; Gelfand and Smith, 1990; Robert and Casella, 2004). We run the chain until it has (hopefully) reached equilibrium and collect samples to approximate the posterior. In variational inference, we define a flexible family of distributions over the hidden variables, indexed by free parameters (Jordan et al., 1999; Wainwright and Jordan, 2008). We then find the setting of the parameters (i.e., the member of the family) that is closest to the posterior. Thus we solve the inference problem by solving an optimization problem.

Neither MCMC nor variational inference scales easily to the kinds of settings described in the first paragraph. Researchers have proposed speed-ups of both approaches, but these usually are tailored to specific models or compromise the correctness of the algorithm (or both). Here, we develop a general variational method that scales.

As we mentioned above, the main idea in this work is to use stochastic optimization (Robbins and Monro, 1951; Spall, 2003). In stochastic optimization, we find the maximum of an objective function by following noisy (but unbiased) estimates of its gradient. Under the right conditions, stochastic optimization algorithms provably converge to an optimum of the objective. Stochastic optimization is particularly attractive when the objective (and therefore its gradient) is a sum of many terms that can be computed independently. In that setting, we can cheaply compute noisy gradients by subsampling only a few of these terms.

Variational inference is amenable to stochastic optimization because the variational objective decomposes into a sum of terms, one for each data point in the analysis. We can cheaply obtain noisy estimates of the gradient by subsampling the data and computing a scaled gradient on the subsample. If we sample independently then the expectation of this noisy gradient is equal to the true gradient. With one more detail—the idea of a natural gradient (Amari, 1998)—stochastic variational inference has an attractive form:

Subsample one or more data points from the data.

Analyze the subsample using the current variational parameters.

Implement a closed-form update of the variational parameters.

Repeat.

While traditional algorithms require repeatedly analyzing the whole data set before updating the variational parameters, this algorithm only requires that we analyze randomly sampled subsets. We will show how to use this algorithm for a large class of graphical models. (A toy sketch of this loop follows below.)
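To make these four steps concrete, here is a minimal Python sketch on a toy conjugate model that is not from the paper: a two-component Gaussian mixture with known component means and a Beta prior on the mixing weight. The model, variable names, and step-size settings are illustrative assumptions; the weighted-average update anticipates the derivation in Section 2.4.

import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(0)

# Toy conjugate model (an assumption for illustration): a two-component
# Gaussian mixture with known component means (-2, +2) and unit variance.
# Global hidden variable: the mixing weight pi, with a Beta(1, 1) prior.
# Local hidden variable: each data point's component assignment z_n.
true_pi, N = 0.3, 5000
z_true = rng.random(N) < true_pi
x = np.where(z_true, 2.0, -2.0) + rng.standard_normal(N)

a, b = 1.0, 1.0              # Beta prior pseudo-counts
lam_a, lam_b = a, b          # global variational parameters: q(pi) = Beta(lam_a, lam_b)
tau, kappa = 1.0, 0.7        # step-size schedule rho_t = (t + tau)**(-kappa)

for t in range(1, 5001):
    xi = x[rng.integers(N)]                           # 1. subsample a data point
    # 2. analyze it: the optimal local parameter phi = q(z_i = component 1)
    #    uses the expected log mixing weights under the current q(pi).
    elog_pi = digamma(lam_a) - digamma(lam_a + lam_b)
    elog_1mpi = digamma(lam_b) - digamma(lam_a + lam_b)
    log_r1 = elog_pi - 0.5 * (xi - 2.0) ** 2
    log_r0 = elog_1mpi - 0.5 * (xi + 2.0) ** 2
    phi = 1.0 / (1.0 + np.exp(log_r0 - log_r1))
    # 3. closed-form update as though x_i were repeated N times
    lam_a_hat = a + N * phi
    lam_b_hat = b + N * (1.0 - phi)
    # 4. weighted average of the old and intermediate parameters
    rho = (t + tau) ** (-kappa)
    lam_a = (1.0 - rho) * lam_a + rho * lam_a_hat
    lam_b = (1.0 - rho) * lam_b + rho * lam_b_hat

print("posterior mean of the mixing weight:", lam_a / (lam_a + lam_b))  # close to 0.3

Each pass touches a single data point, yet the global parameters steadily approach the values that batch inference would find on the full data set.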

Related work.   Variational inference for probabilistic models was pioneered in the mid-1990s. In Michael Jordan’s lab, the seminal papers of Saul et al. (1996); Saul and Jordan (1996) and Jaakkola (1997) grew out of reading the statistical physics literature (Peterson and Anderson, 1987; Parisi, 1988). In parallel, the mean-field methods explained in Neal and Hinton (1999) (originally published in 1993) and Hinton and Van Camp (1993) led to variational algorithms for mixtures of experts (Waterhouse et al., 1996).

In subsequent years, researchers began to understand the potential for variational inference in more general settings and developed generic algorithms for conjugate exponential-family models (Attias, 1999, 2000; Wiegerinck, 2000; Ghahramani and Beal, 2001; Xing et al., 2003). These innovations led to automated variational inference, allowing a practitioner to write down a model and immediately use variational inference to estimate its posterior (Bishop et al., 2003). For good reviews of variational inference see Jordan et al. (1999) and Wainwright and Jordan (2008).

In this paper, we develop scalable methods for generic Bayesian inference by solving the variational inference problem with stochastic optimization (Robbins and Monro, 1951). Our algorithm builds on the earlier approach of Sato (2001), whose algorithm only applies to the limited set of models that can be fit with the EM algorithm (Dempster et al., 1977). Specifically, we generalize his approach to the much wider set of probabilistic models that are amenable to closed-form coordinate ascent inference. Further, in the sense that EM itself is a mean-field method (Neal and Hinton, 1999), our algorithm builds on the stochastic optimization approach to EM (Cappé and Moulines, 2009). Finally, we note that stochastic optimization was also used with variational inference in Platt et al. (2008) for fast approximate inference in a specific model of web service activity.

For approximate inference, the main alternative to variational methods is Markov chain Monte Carlo (MCMC) (Robert and Casella, 2004). Despite its popularity in Bayesian inference, relatively little work has focused on developing MCMC algorithms that can scale to very large data sets. One exception is sequential Monte Carlo, although these typically lack strong convergence guarantees (Doucet et al., 2001). Another is the stochastic gradient Langevin method of Welling and Teh (2011), which enjoys asymptotic convergence guarantees and also takes advantage of stochastic optimization. Finally, in topic modeling, researchers have developed several approaches to parallel MCMC (Newman et al., 2009; Smola and Narayanamurthy, 2010; Ahmed et al., 2012).

The organization of this paper.   In Section 2, we review variational inference for graphical models and then derive stochastic variational inference. In Section 3, we review probabilistic topic models and Bayesian nonparametric models and then derive the stochastic variational inference algorithms in these settings. In Section 4, we study stochastic variational inference on several large text data sets.

2 Stochastic Variational Inference

We derive stochastic variational inference, a stochastic optimization algorithm for mean-field variational inference. Our algorithm approximates the posterior distribution of a probabilistic model with hidden variables, and can handle massive data sets of observations.

We divide this section into four parts.

  1. We define the class of models to which our algorithm applies. We define local and global hidden variables, and requirements on the conditional distributions within the model.

  2. We review mean-field variational inference, an approximate inference strategy that seeks a tractable distribution over the hidden variables which is close to the posterior distribution. We derive the traditional variational inference algorithm for our class of models, which is a coordinate ascent algorithm.

  3. We review the natural gradient and derive the natural gradient of the variational objective function. The natural gradient closely relates to coordinate ascent variational inference.

  4. We review stochastic optimization, a technique that uses noisy estimates of a gradient to optimize an objective function, and apply it to variational inference. Specifically, we use stochastic optimization with noisy estimates of the natural gradient of the variational objective. These estimates arise from repeatedly subsampling the data set. We show how the resulting algorithm, stochastic variational inference, easily builds on traditional variational inference algorithms but can handle much larger data sets.

2.1 Models with local and global hidden variables

Figure 2: A graphical model with observations x_{1:N}, local hidden variables z_{1:N}, and global hidden variables β. The distribution of each observation x_n only depends on its corresponding local variable z_n and the global variables β. (Though not pictured, each hidden variable z_n, observation x_n, and global variable β may be a collection of multiple random variables.)

Our class of models involves observations, global hidden variables, local hidden variables, and fixed parameters. The observations are x = x_{1:N}; the vector of global hidden variables is β; the local hidden variables are z = z_{1:N}, each of which is a collection of J variables z_n = z_{n,1:J}; the vector of fixed parameters is α. (Note we can easily allow α to partly govern any of the random variables, such as fixed parts of the conditional distribution of observations. To keep notation simple, we assume that they only govern the global hidden variables.)

The joint distribution factorizes into a global term and a product of local terms,

p(x, z, β | α) = p(β | α) ∏_{n=1}^{N} p(x_n, z_n | β).   (1)

Figure 2 illustrates the graphical model. Our goal is to approximate the posterior distribution of the hidden variables given the observations, p(β, z | x).

The distinction between local and global hidden variables is determined by the conditional dependencies. In particular, the nth observation x_n and the nth local variable z_n are conditionally independent, given the global variables β, of all other observations and local hidden variables,

p(x_n, z_n | x_{-n}, z_{-n}, β, α) = p(x_n, z_n | β, α).

The notation x_{-n} and z_{-n} refers to the set of variables except the nth.

This kind of model frequently arises in Bayesian statistics. The global variables β are parameters endowed with a prior p(β) and each local variable z_n contains the hidden structure that governs the nth observation. For example, consider a Bayesian mixture of Gaussians. The global variables are the mixture proportions and the means and variances of the mixture components; the local variable z_n is the hidden cluster label for the nth observation x_n.

We have described the independence assumptions of the hidden variables. We make further assumptions about the complete conditionals in the model. A complete conditional is the conditional distribution of a hidden variable given the other hidden variables and the observations. We assume that these distributions are in the exponential family,

p(β | x, z, α) = h(β) exp{η_g(x, z, α)ᵀ t(β) − a_g(η_g(x, z, α))},   (2)
p(z_{nj} | x_n, z_{n,-j}, β) = h(z_{nj}) exp{η_ℓ(x_n, z_{n,-j}, β)ᵀ t(z_{nj}) − a_ℓ(η_ℓ(x_n, z_{n,-j}, β))}.   (3)

The scalar functions h(·) and a(·) are respectively the base measure and log-normalizer; the vector functions η(·) and t(·) are respectively the natural parameter and sufficient statistics. (Footnote 1: We use overloaded notation for the functions h(·) and t(·) so that they depend on the names of their arguments; for example, h(z_{nj}) can be thought of as a shorthand for the more formal, but more cluttered, notation h_{z_{nj}}(z_{nj}). This is analogous to the standard convention of overloading the probability function p(·).) These are conditional distributions, so the natural parameter is a function of the variables that are being conditioned on. (The subscripts on the natural parameter η indicate complete conditionals for local or global variables.) For the local variables z_{nj}, the complete conditional distribution is determined by the global variables β and the other local variables in the nth context, i.e., the nth data point x_n and the local variables z_{n,-j}. This follows from the factorization in Equation 1.

These assumptions on the complete conditionals imply a conjugacy relationship between the global variables β and the local contexts (x_n, z_n), and this relationship implies a specific form of the complete conditional for β. Specifically, the distribution of the local context given the global variables must be in an exponential family,

p(x_n, z_n | β) = h(x_n, z_n) exp{βᵀ t(x_n, z_n) − a_ℓ(β)}.   (4)

The prior distribution p(β) must also be in an exponential family,

p(β) = h(β) exp{αᵀ t(β) − a_g(α)}.   (5)

The sufficient statistics are t(β) = (β, −a_ℓ(β)) and thus the hyperparameter α has two components, α = (α_1, α_2). The first component α_1 is a vector of the same dimension as β; the second component α_2 is a scalar.

Equations 4 and 5 imply that the complete conditional for the global variable in Equation 2 is in the same exponential family as the prior with natural parameter

η_g(x, z, α) = (α_1 + ∑_{n=1}^{N} t(x_n, z_n), α_2 + N).   (6)

This form will be important when we derive stochastic variational inference in Section 2.4. See Bernardo and Smith (1994) for a general discussion of conjugacy and the exponential family.
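As a small worked example of this conjugate structure (an added illustration, not the paper's running example, written in terms of the Poisson rate for readability): each count x_n is Poisson with rate β, the sufficient statistic is t(x_n) = x_n, and the conjugate prior density is proportional to β^{α_1} e^{−α_2 β}. The complete conditional then has hyperparameters (α_1 + ∑_n x_n, α_2 + N), mirroring Equation 6.

import numpy as np

rng = np.random.default_rng(1)

# Gamma-Poisson illustration of the update in Equation 6 (assumed toy example).
# Likelihood: x_n ~ Poisson(beta), sufficient statistic t(x_n) = x_n.
# Conjugate prior: p(beta) proportional to beta**alpha_1 * exp(-alpha_2 * beta),
# i.e. a Gamma distribution with shape alpha_1 + 1 and rate alpha_2.
alpha_1, alpha_2 = 2.0, 1.0
x = rng.poisson(lam=4.0, size=100)   # observations (no local hidden variables here)
N = len(x)

# Mirroring Equation 6: add the summed sufficient statistics to alpha_1
# and the number of observations N to alpha_2.
eta_1 = alpha_1 + x.sum()
eta_2 = alpha_2 + N

post_shape, post_rate = eta_1 + 1.0, eta_2
print("posterior mean of beta:", post_shape / post_rate)   # near the true rate 4.0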

This family of distributions—those with local and global variables, and where the complete conditionals are in the exponential family—contains many useful statistical models from the machine learning and statistics literature. Examples include Bayesian mixture models (Ghahramani and Beal, 2000; Attias, 2000), latent Dirichlet allocation (Blei et al., 2003), hidden Markov models (and many variants) (Rabiner, 1989; Fine et al., 1998; Fox et al., 2011b; Paisley and Carin, 2009), Kalman filters (and many variants) (Kalman, 1960; Fox et al., 2011a), factorial models (Ghahramani and Jordan, 1997), hierarchical linear regression models (Gelman and Hill, 2007), hierarchical probit classification models (McCullagh and Nelder, 1989; Girolami and Rogers, 2006), probabilistic factor analysis/matrix factorization models (Spearman, 1904; Tipping and Bishop, 1999; Collins et al., 2002; Wang, 2006; Salakhutdinov and Mnih, 2008; Paisley and Carin, 2009; Hoffman et al., 2010b), certain Bayesian nonparametric mixture models (Antoniak, 1974; Escobar and West, 1995; Teh et al., 2006a), and others. (Footnote 2: We note that our assumptions can be relaxed to the case where the full conditional is not tractable, but each partial conditional associated with the global variable is in a tractable exponential family. The topic models of the next section do not require this complexity, so we chose to keep the derivation a little simpler.)

Analyzing data with one of these models amounts to computing the posterior distribution of the hidden variables given the observations,

p(z, β | x) = p(x, z, β) / ∫ p(x, z, β) dz dβ.   (7)

We then use this posterior to explore the hidden structure of our data or to make predictions about future data. For many models however, such as the examples listed above, the denominator in Equation 7 is intractable to compute. Thus we resort to approximate posterior inference, a problem that has been a focus of modern Bayesian statistics. We now turn to mean-field variational inference, the approximate inference technique on which our strategy for scalable inference is built.

2.2 Mean-field variational inference

Variational inference casts the inference problem as an optimization. We introduce a family of distributions over the hidden variables that is indexed by a set of free parameters, and then optimize those parameters to find the member of the family that is closest to the posterior of interest. (Closeness is measured with Kullback-Leibler divergence.) We use the resulting distribution, called the variational distribution, to approximate the posterior.

In this section we review mean-field variational inference, the form of variational inference that uses a family where each hidden variable is independent. We describe the variational objective function, discuss the mean-field variational family, and derive the traditional coordinate ascent algorithm for fitting the variational parameters. This algorithm is a stepping stone to stochastic variational inference.

The evidence lower bound.   Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to the posterior distribution. It maximizes the evidence lower bound (ELBO), a lower bound on the logarithm of the marginal probability of the observations, log p(x). The ELBO is equal to the negative KL divergence up to an additive constant.

We derive the ELBO by introducing a distribution over the hidden variables, q(z, β), and using Jensen's inequality. (Jensen's inequality and the concavity of the logarithm function imply that log E[f(y)] ≥ E[log f(y)] for any random variable y.) This gives the following bound on the log marginal,

log p(x) = log ∫ p(x, z, β) dz dβ
         = log ∫ p(x, z, β) (q(z, β) / q(z, β)) dz dβ
         = log E_q[p(x, z, β) / q(z, β)]
         ≥ E_q[log p(x, z, β)] − E_q[log q(z, β)]
         ≜ L(q).   (8)

The ELBO contains two terms. The first term is the expected log joint, E_q[log p(x, z, β)]. The second term is the entropy of the variational distribution, −E_q[log q(z, β)]. Both of these terms depend on q(z, β), the variational distribution of the hidden variables.

We restrict q(z, β) to be in a family that is tractable, one for which the expectations in the ELBO can be efficiently computed. We then try to find the member of the family that maximizes the ELBO. Finally, we use the optimized distribution as a proxy for the posterior.

Solving this maximization problem is equivalent to finding the member of the family that is closest in KL divergence to the posterior (Jordan et al., 1999; Wainwright and Jordan, 2008),

KL(q(z, β) || p(z, β | x)) = E_q[log q(z, β)] − E_q[log p(z, β | x)]
                           = E_q[log q(z, β)] − E_q[log p(x, z, β)] + log p(x)
                           = −L(q) + const;

log p(x) is replaced by a constant because it does not depend on q.

The mean-field variational family.   The simplest variational family of distributions is the mean-field family. In this family, each hidden variable is independent and governed by its own parameter,

q(z, β) = q(β | λ) ∏_{n=1}^{N} ∏_{j=1}^{J} q(z_{nj} | φ_{nj}).   (9)

The global parameters λ govern the global variables; the local parameters φ_n govern the local variables in the nth context. The ELBO is a function of these parameters.

Equation 9 gives the factorization of the variational family, but does not specify its form. We set q(β | λ) and q(z_{nj} | φ_{nj}) to be in the same exponential family as the complete conditional distributions p(β | x, z) and p(z_{nj} | x_n, z_{n,-j}, β), from Equations 2 and 3. The variational parameters λ and φ_{nj} are the natural parameters to those families,

q(β | λ) = h(β) exp{λᵀ t(β) − a_g(λ)},   (10)
q(z_{nj} | φ_{nj}) = h(z_{nj}) exp{φ_{nj}ᵀ t(z_{nj}) − a_ℓ(φ_{nj})}.   (11)
These forms of the variational distributions lead to an easy coordinate ascent algorithm. Further, the optimal mean-field distribution, without regard to its particular functional form, has factors in these families (Bishop, 2006).

Note that assuming that these exponential families are the same as their corresponding conditionals means that h(·) and a_g(·) in Equation 10 are the same functions as h(·) and a_g(·) in Equation 2. Likewise, h(·) and a_ℓ(·) in Equation 11 are the same as in Equation 3. We will sometimes suppress the explicit dependence on the conditioning variables when they are clear from context.

The mean-field family has several computational advantages. For one, the entropy term decomposes,

−E_q[log q(z, β)] = −E_λ[log q(β)] − ∑_{n=1}^{N} ∑_{j=1}^{J} E_{φ_{nj}}[log q(z_{nj})],

where E_λ denotes an expectation with respect to q(β | λ) and E_{φ_{nj}} denotes an expectation with respect to q(z_{nj} | φ_{nj}). Its other computational advantages will emerge as we derive the gradients of the variational objective and the coordinate ascent algorithm.

The gradient of the ELBO and coordinate ascent inference.   We have defined the objective function in Equation 8 and the variational family in Equations 9, 10 and 11. Our goal is to optimize the objective with respect to the variational parameters.

In traditional mean-field variational inference, we optimize Equation 8 with coordinate ascent. We iteratively optimize each variational parameter, holding the other parameters fixed. With the assumptions that we have made about the model and variational distribution—that each conditional is in an exponential family and that the corresponding variational distribution is in the same exponential family—we can optimize each coordinate in closed form.

We first derive the coordinate update for the parameter λ of the variational distribution of the global variables, q(β | λ). As a function of λ, we can rewrite the objective as

L(λ) = E_q[log p(β | x, z)] − E_q[log q(β)] + const.   (12)

The first two terms are expectations that depend on λ; the third term is constant with respect to λ. The constant absorbs quantities that depend only on the other hidden variables. Those quantities do not depend on λ because all variables are independent in the mean-field family.

Equation 12 reproduces the full ELBO in Equation 8. The second term of Equation 12 is the entropy of the global variational distribution. The first term derives from the expected log joint likelihood, where we use the chain rule to separate terms that depend on the variable β from terms that do not,

E_q[log p(x, z, β)] = E_q[log p(β | x, z)] + E_q[log p(x, z)].

The constant absorbs E_q[log p(x, z)], leaving the expected log conditional E_q[log p(β | x, z)].

Finally, we substitute the form of q(β | λ) in Equation 10 to obtain the final expression for the ELBO as a function of λ,

L(λ) = E_q[η_g(x, z, α)]ᵀ ∇_λ a_g(λ) − λᵀ ∇_λ a_g(λ) + a_g(λ) + const.   (13)

In the first and second terms on the right side, we used the exponential family identity that the expectation of the sufficient statistics is the gradient of the log normalizer, E_q[t(β)] = ∇_λ a_g(λ). The constant has further absorbed the expected log normalizer of the conditional distribution, −E_q[a_g(η_g(x, z, α))], which does not depend on q(β).

Equation 13 simplifies the ELBO as a function of the global variational parameter. To derive the coordinate ascent update, we take the gradient,

∇_λ L = ∇²_λ a_g(λ) (E_q[η_g(x, z, α)] − λ).   (14)

We can set this gradient to zero by setting

λ = E_q[η_g(x, z, α)].   (15)

This sets the global variational parameter equal to the expected natural parameter of its complete conditional distribution. Implementing this update, holding all other variational parameters fixed, optimizes the ELBO over λ. Notice that the mean-field assumption plays an important role. The update is the expected conditional parameter E_q[η_g(x, z, α)], which is an expectation of a function of the other random variables and observations. Thanks to the mean-field assumption, this expectation is only a function of the local variational parameters and does not depend on λ.

We now turn to the local parameters φ_{nj}. The gradient is nearly identical to the global case,

∇_{φ_{nj}} L = ∇²_{φ_{nj}} a_ℓ(φ_{nj}) (E_q[η_ℓ(x_n, z_{n,-j}, β)] − φ_{nj}).

It equals zero when

φ_{nj} = E_q[η_ℓ(x_n, z_{n,-j}, β)].   (16)

Mirroring the global update, this expectation does not depend on φ_{nj}. However, while the global update in Equation 15 depends on all the local variational parameters—and note there is a set of local parameters for each of the N observations—the local update in Equation 16 only depends on the global parameters and the other parameters associated with the nth context. The computational difference between local and global updates will be important in the scalable algorithm of Section 2.4.

1:  Initialize λ randomly.
2:  repeat
3:     for each local variational parameter φ_{nj} do
4:        Update φ_{nj}, φ_{nj} = E_q[η_ℓ(x_n, z_{n,-j}, β)] (Equation 16).
5:     end for
6:     Update the global variational parameters, λ = E_q[η_g(x, z, α)] (Equation 15).
7:  until the ELBO converges
Figure 3: Coordinate ascent mean-field variational inference.

The updates in Equations 15 and 16 form the algorithm for coordinate ascent variational inference, iterating between updating each local parameter and the global parameters. The full algorithm is in Figure 3, which is guaranteed to find a local optimum of the ELBO. Computing the expectations at each step is easy for directed graphical models with tractable complete conditionals, and in Section 3 we show that these updates are tractable for many topic models. Figure 3 is the “classical” variational inference algorithm, used in many settings.
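To ground Figure 3, here is a short coordinate ascent sketch for a toy model that is not one of the paper's running examples: a K-component mixture of unit-variance Gaussians with uniform mixing weights and a Gaussian prior on each component mean. The update equations follow the standard mean-field derivation for this model; the variable names and hyperparameter values are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(2)

# Toy data: a 1-D mixture of three unit-variance Gaussians.
true_means = np.array([-4.0, 0.0, 5.0])
x = np.concatenate([m + rng.standard_normal(300) for m in true_means])
N, K = len(x), 3
sigma0_sq = 10.0                       # prior variance on each component mean

# Mean-field variational parameters: q(mu_k) = N(m[k], s2[k]),
# q(c_n) = Categorical(phi[n, :]).
m = np.array([-1.0, 0.0, 1.0])         # spread initialization to avoid a collapsed optimum
s2 = np.ones(K)
phi = np.full((N, K), 1.0 / K)

for it in range(100):
    # Local step (the analogue of Equation 16): update each point's
    # assignment probabilities given the current q(mu).
    logits = np.outer(x, m) - 0.5 * (s2 + m ** 2)      # E[mu_k] x_n - E[mu_k^2] / 2
    logits -= logits.max(axis=1, keepdims=True)
    phi = np.exp(logits)
    phi /= phi.sum(axis=1, keepdims=True)
    # Global step (the analogue of Equation 15): update each component
    # mean's Gaussian variational factor given the soft assignments.
    Nk = phi.sum(axis=0)
    s2 = 1.0 / (1.0 / sigma0_sq + Nk)
    m = s2 * (phi * x[:, None]).sum(axis=0)

print(np.sort(m))   # close to (-4, 0, 5)

Note the structure mirrors Figure 3 exactly: every iteration touches all N data points in the local step before the global step runs, which is the inefficiency that stochastic variational inference removes.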

As an aside, these updates reveal a connection between mean-field variational inference and Gibbs sampling (Gelfand and Smith, 1990). In Gibbs sampling, we iteratively sample from each complete conditional. In variational inference, we take variational expectations of the natural parameters of the same distributions. The updates also show a connection to the expectation-maximization (EM) algorithm (Dempster et al., 1977)—Equation 16 corresponds to the E step, and Equation 15 corresponds to the M step (Neal and Hinton, 1999).

We mentioned that the local steps (Steps 3 and 4 in Figure 3) only require computation with the global parameters and the nth local context. Thus, the data can be distributed across many machines and the local variational updates can be implemented in parallel. These results can then be aggregated in Step 6 to find the new global variational parameters.

However, the local steps also reveal an inefficiency in the algorithm. The algorithm begins by initializing the global parameters randomly—the initial value of λ does not reflect any regularity in the data. But before completing even one iteration, the algorithm must analyze every data point using these initial (random) values. This is wasteful, especially if we expect that we can learn something about the global variational parameters from only a subset of the data.

We solve this problem with stochastic optimization. This leads to stochastic variational inference, an efficient algorithm that continually improves its estimate of the global parameters as it analyzes more observations. Though the derivation requires some details, we have now described all of the computational components of the algorithm. (See Figure 4.) At each iteration, we sample a data point from the data set and compute its optimal local variational parameters; we form intermediate global parameters using classical coordinate ascent updates as though the sampled data point were repeated N times; finally, we set the new global parameters to a weighted average of the old estimate and the intermediate parameters.

The algorithm is efficient because it need not analyze the whole data set before improving the global variational parameters, and the per-iteration steps only require computation about a single local context. Furthermore, it only uses calculations from classical coordinate ascent inference. Any existing implementation of variational inference can easily be adapted into this scalable alternative.

We now show how stochastic inference arises by applying stochastic optimization to the natural gradients of the variational objective. We first discuss natural gradients and their relationship to the coordinate updates in mean-field variational inference.

2.3 The natural gradient of the ELBO

The natural gradient of a function accounts for the information geometry of its parameter space, using a Riemannian metric to adjust the direction of the traditional gradient. Amari (1998) discusses natural gradients for maximum-likelihood estimation, which give faster convergence than standard gradients. In this section we describe Riemannian metrics for probability distributions and the natural gradient of the ELBO.

Gradients and probability distributions.   The classical gradient method for maximization tries to find a maximum of a function f(λ) by taking steps of size ρ in the direction of the gradient,

λ_{t+1} = λ_t + ρ ∇_λ f(λ_t).

The gradient (when it exists) points in the direction of steepest ascent. That is, the gradient points in the same direction as the solution to

argmax_{dλ} f(λ + dλ)   subject to   ||dλ||² < ε²   (17)

for sufficiently small ε. Equation 17 implies that if we could only move a tiny distance ε away from λ then we should move in the direction of the gradient. Initially this seems reasonable, but there is a complication. The gradient direction implicitly depends on the Euclidean distance metric associated with the space in which λ lives. However, the Euclidean metric might not capture a meaningful notion of distance between settings of λ.

The problem with Euclidean distance is especially clear in our setting, where we are trying to optimize an objective with respect to a parameterized probability distribution q(β | λ). When optimizing over a probability distribution, the Euclidean distance between two parameter vectors λ and λ' is often a poor measure of the dissimilarity of the distributions q(β | λ) and q(β | λ'). For example, suppose q(β) is a univariate normal and λ = (μ, σ) is the mean and scale. The distributions N(0, 10000) and N(10, 10000) are almost indistinguishable, and the Euclidean distance between their parameter vectors is 10. In contrast, the distributions N(0, 0.01) and N(0.1, 0.01) barely overlap, but this is not reflected in the Euclidean distance between their parameter vectors, which is only 0.1. The natural gradient corrects for this issue by redefining the basic definition of the gradient (Amari, 1998).
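A quick numerical check of this point (added here as an illustration; it interprets the second argument of N(·, ·) above as the scale σ and uses the standard closed-form KL between univariate Gaussians):

import numpy as np

def kl_normal(mu1, sig1, mu2, sig2):
    # KL( N(mu1, sig1^2) || N(mu2, sig2^2) ) for univariate Gaussians.
    return np.log(sig2 / sig1) + (sig1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sig2 ** 2) - 0.5

def sym_kl(p, q):
    return kl_normal(*p, *q) + kl_normal(*q, *p)

# Parameter vectors lambda = (mean, scale).
wide_a, wide_b = (0.0, 10000.0), (10.0, 10000.0)     # nearly identical distributions
narrow_a, narrow_b = (0.0, 0.01), (0.1, 0.01)        # barely overlapping distributions

for name, (p, q) in {"wide": (wide_a, wide_b), "narrow": (narrow_a, narrow_b)}.items():
    euclid = np.linalg.norm(np.array(p) - np.array(q))
    print(f"{name:7s} Euclidean distance = {euclid:8.1f}   symmetrized KL = {sym_kl(p, q):.2e}")

# The wide pair is far apart in parameter space (distance 10) but has a
# negligible symmetrized KL (about 1e-06); the narrow pair is close in
# parameter space (distance 0.1) but has a symmetrized KL of about 100.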

Natural gradients and probability distributions.   A natural measure of dissimilarity between probability distributions is the symmetrized KL divergence,

D^sym_KL(λ, λ') = E_λ[log (q(β | λ) / q(β | λ'))] + E_{λ'}[log (q(β | λ') / q(β | λ))].   (18)

Symmetrized KL depends on the distributions themselves, rather than on how they are parameterized; it is invariant to parameter transformations.

With distances defined using symmetrized KL, we find the direction of steepest ascent in the same way as for gradient methods,

argmax_{dλ} f(λ + dλ)   subject to   D^sym_KL(λ, λ + dλ) < ε.   (19)

As ε → 0, the solution to this problem points in the same direction as the natural gradient. While the Euclidean gradient points in the direction of steepest ascent in Euclidean space, the natural gradient points in the direction of steepest ascent in the Riemannian space, i.e., the space where local distance is defined by KL divergence rather than the L2 norm.

We manage the more complicated constraint in Equation 19 with a Riemannian metric G(λ) (Do Carmo, 1992). This metric defines linear transformations of λ under which the squared Euclidean distance between λ and a nearby vector λ + dλ is the KL between q(β | λ) and q(β | λ + dλ),

dλᵀ G(λ) dλ = D^sym_KL(λ, λ + dλ),   (20)

and note that the transformation can be a function of λ. Amari (1998) showed that we can compute the natural gradient by premultiplying the gradient by the inverse of the Riemannian metric G(λ)⁻¹,

∇̂_λ f(λ) ≜ G(λ)⁻¹ ∇_λ f(λ),

where G(λ) is the Fisher information matrix of q(β | λ) (Amari, 1982; Kullback and Leibler, 1951),

G(λ) = E_λ[(∇_λ log q(β | λ)) (∇_λ log q(β | λ))ᵀ].   (21)

We can show that Equation 21 satisfies Equation 20 by approximating log q(β | λ + dλ) with a first-order Taylor approximation about λ, plugging the result into Equation 18, and simplifying. For small enough dλ we can ignore the higher-order terms.

When q(β | λ) is in the exponential family (Equation 10) the metric is the second derivative of the log normalizer,

G(λ) = E_λ[(t(β) − E_λ[t(β)]) (t(β) − E_λ[t(β)])ᵀ] = ∇²_λ a_g(λ).

This follows from the exponential family identity that the Hessian of the log normalizer function with respect to the natural parameter λ is the covariance matrix of the sufficient statistic vector t(β).
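For instance, if q(β | λ) is a Dirichlet with parameter vector λ (its natural parameter up to a constant shift, which does not affect the Hessian), the log normalizer is a_g(λ) = ∑_k log Γ(λ_k) − log Γ(∑_k λ_k), so the metric is diag(ψ'(λ_k)) − ψ'(∑_k λ_k), where ψ' is the trigamma function. The short check below (added for illustration) compares this analytic Hessian against finite differences.

import numpy as np
from scipy.special import gammaln, polygamma

def log_normalizer(lam):
    # Dirichlet log normalizer: sum_k logGamma(lam_k) - logGamma(sum_k lam_k).
    return gammaln(lam).sum() - gammaln(lam.sum())

def fisher_metric(lam):
    # Hessian of the log normalizer: diag(psi'(lam_k)) - psi'(sum_k lam_k).
    return np.diag(polygamma(1, lam)) - polygamma(1, lam.sum())

lam = np.array([2.0, 0.5, 3.0])
G = fisher_metric(lam)

# Numerical check via second differences of the log normalizer.
eps = 1e-5
K = len(lam)
H = np.zeros((K, K))
for i in range(K):
    for j in range(K):
        e_i, e_j = np.eye(K)[i] * eps, np.eye(K)[j] * eps
        H[i, j] = (log_normalizer(lam + e_i + e_j) - log_normalizer(lam + e_i)
                   - log_normalizer(lam + e_j) + log_normalizer(lam)) / eps ** 2

print(np.max(np.abs(G - H)))   # small: the analytic and numerical Hessians agree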

Natural gradients and mean field variational inference.   We now return to variational inference and compute the natural gradient of the ELBO with respect to the variational parameters. Researchers have used the natural gradient in variational inference for nonlinear state space models (Honkela et al., 2008) and Bayesian mixtures (Sato, 2001). (Footnote 3: Our work here—using the natural gradient in a stochastic optimization algorithm—is closest to that of Sato (2001), though we develop the algorithm via a different path and Sato does not address models for which the joint conditional is not tractable.)

Consider the global variational parameter λ. The gradient of the ELBO with respect to λ is in Equation 14. Since λ is a natural parameter to an exponential family distribution, the Fisher metric defined by q(β | λ) is ∇²_λ a_g(λ). Note that the Fisher metric is the first term in Equation 14. We premultiply the gradient by the inverse Fisher information to find the natural gradient. This reveals that the natural gradient has the following simple form,

∇̂_λ L = E_q[η_g(x, z, α)] − λ.   (22)

An analogous computation goes through for the local variational parameters,

∇̂_{φ_{nj}} L = E_q[η_ℓ(x_n, z_{n,-j}, β)] − φ_{nj}.

The natural gradients are closely related to the coordinate ascent updates of Equation 15 or Equation 16. Consider a full set of variational parameters λ and φ. We can compute the natural gradient by computing the coordinate updates in parallel and subtracting the current setting of the parameters. The classical coordinate ascent algorithm can thus be interpreted as a projected natural gradient algorithm (Sato, 2001). Updating a parameter by taking a natural gradient step of length one is equivalent to performing a coordinate update.

We motivated natural gradients by mathematical reasoning around the geometry of the parameter space. More importantly, however, natural gradients are easier to compute than classical gradients. They are easier to compute because premultiplying by the Fisher information matrix—which we must do to compute the classical gradient in Equation 14 but which disappears from the natural gradient in Equation 22—is prohibitively expensive for variational parameters with many components. In the next section we will see that efficiently computing the natural gradient lets us develop scalable variational inference algorithms.

2.4 Stochastic variational inference

The coordinate ascent algorithm in Figure 3 is inefficient for large data sets because we must optimize the local variational parameters for each data point before re-estimating the global variational parameters. Stochastic variational inference uses stochastic optimization to fit the global variational parameters. We repeatedly subsample the data to form noisy estimates of the natural gradient of the ELBO, and we follow these estimates with a decreasing step-size.

We have reviewed mean-field variational inference in models with exponential family conditionals and showed that the natural gradient of the variational objective function is easy to compute. We now discuss stochastic optimization, which uses a series of noisy estimates of the gradient, and use it with noisy natural gradients to derive stochastic variational inference.

Stochastic optimization.   Stochastic optimization algorithms follow noisy estimates of the gradient with a decreasing step size. Noisy estimates of a gradient are often cheaper to compute than the true gradient, and following such estimates can allow algorithms to escape shallow local optima of complex objective functions. In statistical estimation problems, including variational inference of the global parameters, the gradient can be written as a sum of terms (one for each data point) and we can compute a fast noisy approximation by subsampling the data. With certain conditions on the step-size schedule, these algorithms provably converge to an optimum (Robbins and Monro, 1951). Spall (2003) gives an overview of stochastic optimization;  Bottou (2003) gives an overview of its role in machine learning.

Consider an objective function f(λ) and a random function B(λ) whose expectation is the gradient, so that E[B(λ)] = ∇_λ f(λ). The stochastic gradient algorithm, which is a type of stochastic optimization, optimizes f by iteratively following realizations of B(λ). At iteration t, the update for λ is

λ_t = λ_{t-1} + ρ_t b_t(λ_{t-1}),

where b_t is an independent draw from the noisy gradient B. If the sequence of step sizes ρ_t satisfies

∑_t ρ_t = ∞,   ∑_t ρ_t² < ∞,   (23)

then λ_t will converge to the optimal λ* (if f is convex) or to a local optimum of f (if not convex). (Footnote 4: To find a local optimum, f must be three-times differentiable and meet a few mild technical requirements (Bottou, 1998). The variational objective satisfies these criteria.) The same results apply if we premultiply the noisy gradients by a sequence of positive-definite matrices G_t⁻¹ (whose eigenvalues are bounded) (Bottou, 1998). The resulting algorithm is

λ_t = λ_{t-1} + ρ_t G_t⁻¹ b_t(λ_{t-1}).

As our notation suggests, we will use the Fisher metric for G_t, replacing stochastic Euclidean gradients with stochastic natural gradients.
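A minimal Robbins-Monro example (added for illustration, with made-up numbers): maximize f(θ) = −E[(θ − X)²]/2, whose exact gradient is E[X] − θ, by following single-sample noisy gradients x_t − θ.

import numpy as np

rng = np.random.default_rng(3)

theta = 0.0
true_mean = 5.0
for t in range(1, 10001):
    x_t = true_mean + rng.standard_normal()   # one noisy observation
    noisy_grad = x_t - theta                  # unbiased estimate of E[X] - theta
    rho = t ** (-0.7)                         # satisfies sum rho = inf, sum rho^2 < inf
    theta += rho * noisy_grad

print(theta)   # close to the optimum theta* = E[X] = 5.0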

Stochastic variational inference.   We use stochastic optimization with noisy natural gradients to optimize the variational objective function. The resulting algorithm is in Figure 4. At each iteration we have a current setting of the global variational parameters. We repeat the following steps:

  1. Sample a data point from the set; optimize its local variational parameters.

  2. Form intermediate global variational parameters, as though we were running classical coordinate ascent and the sampled data point were repeated N times to form the collection.

  3. Update the global variational parameters to be a weighted average of the intermediate parameters and their current setting.

We show that this algorithm is stochastic natural gradient ascent on the global variational parameters.

Our goal is to find a setting of the global variational parameters λ that maximizes the ELBO. Writing the ELBO as a function of the global and local variational parameters, L(λ, φ), let the function φ(λ) return a local optimum of the local variational parameters, so that

∇_φ L(λ, φ(λ)) = 0.

Define the locally maximized ELBO L(λ) to be the ELBO when λ is held fixed and the local variational parameters φ are set to a local optimum φ(λ),

L(λ) ≜ L(λ, φ(λ)).

We can compute the (natural) gradient of L(λ) by first finding the corresponding optimal local parameters φ(λ) and then computing the (natural) gradient of L(λ, φ(λ)), holding φ(λ) fixed. The reason is that the gradient of L(λ) is the same as the gradient of the two-parameter ELBO L(λ, φ(λ)),

∇_λ L(λ) = ∇_λ L(λ, φ(λ)) + (∇_λ φ(λ))ᵀ ∇_φ L(λ, φ(λ))   (24)
         = ∇_λ L(λ, φ(λ)),   (25)

where ∇_λ φ(λ) is the Jacobian of φ(λ) and we use the fact that the gradient of L(λ, φ) with respect to φ is zero at φ(λ).

Stochastic variational inference optimizes the maximized ELBO L(λ) by subsampling the data to form noisy estimates of the natural gradient. First, we decompose L(λ) into a global term and a sum of local terms,

L(λ) = E_q[log p(β)] − E_q[log q(β)] + ∑_{n=1}^{N} max_{φ_n} (E_q[log p(x_n, z_n | β)] − E_q[log q(z_n)]).   (26)

Now consider a variable that chooses an index of the data uniformly at random, I ~ Unif(1, ..., N). Define L_I(λ) to be the following random function of the variational parameters,

L_I(λ) ≜ E_q[log p(β)] − E_q[log q(β)] + N max_{φ_I} (E_q[log p(x_I, z_I | β)] − E_q[log q(z_I)]).   (27)

The expectation of L_I is equal to the objective in Equation 26. Therefore, the natural gradient of L_I with respect to each global variational parameter λ is a noisy but unbiased estimate of the natural gradient of the variational objective. This process—sampling a data point and then computing the natural gradient of L_I—will provide cheaply computed noisy gradients for stochastic optimization.

We now compute the noisy gradient. Suppose we have sampled the ith data point. Notice that Equation 27 is equivalent to the full objective of Equation 26 where the ith data point is observed N times. Thus the natural gradient of Equation 27—which is a noisy natural gradient of the ELBO—can be found using Equation 22,

∇̂_λ L_i = E_q[η_g(x_i^{(N)}, z_i^{(N)}, α)] − λ,

where x_i^{(N)} is a data set formed by N replicates of observation x_i, with corresponding hidden variables z_i^{(N)}.

We compute this expression in more detail. Recall the complete conditional natural parameter η_g(x, z, α) from Equation 6. From this equation, we can compute the conditional natural parameter for the global parameter given N replicates of x_i,

η_g(x_i^{(N)}, z_i^{(N)}, α) = (α_1 + N t(x_i, z_i), α_2 + N).

Using this in the natural gradient of Equation 22 gives a noisy natural gradient,

∇̂_λ L_i = (α_1 + N E_{φ_i}[t(x_i, z_i)], α_2 + N) − λ,

where φ_i gives the elements of φ associated with the ith local context. While the full natural gradient would use the local variational parameters for the whole data set, the noisy natural gradient only considers the local parameters for one randomly sampled data point. These noisy gradients are cheaper to compute.

Finally, we use the noisy natural gradients in a Robbins-Monro algorithm to optimize the ELBO. We sample a data point at each iteration. Define the intermediate global parameter λ̂_t to be the estimate of λ that we would obtain if the sampled data point were replicated N times,

λ̂_t ≜ (α_1 + N E_{φ_i}[t(x_i, z_i)], α_2 + N).

This comprises the first two terms of the noisy natural gradient. At each iteration we use the noisy gradient (with step size ρ_t) to update the global variational parameter. The update is

λ_t = λ_{t-1} + ρ_t ∇̂_λ L_i = (1 − ρ_t) λ_{t-1} + ρ_t λ̂_t.

This is a weighted average of the previous estimate of λ and the estimate λ̂_t that we would obtain if the sampled data point were replicated N times.

1:  Initialize λ randomly.
2:  Set the step-size schedule ρ_t appropriately.
3:  repeat
4:     Sample a data point x_i uniformly from the data set.
5:     Compute its local variational parameters, φ_{ij} = E_q[η_ℓ(x_i, z_{i,-j}, β)] for each j, using the current λ.
6:     Compute intermediate global parameters as though x_i is replicated N times, λ̂ = (α_1 + N E_φ[t(x_i, z_i)], α_2 + N).
7:     Update the current estimate of the global variational parameters, λ = (1 − ρ_t) λ + ρ_t λ̂.
8:  until forever
Figure 4: Stochastic variational inference.

Figure 4 presents the full algorithm. At each iteration, the algorithm has an estimate of the global variational parameter λ_{t-1}. It samples a single data point x_i from the data and cheaply computes the intermediate global parameter λ̂_t, i.e., the next value of λ if the data set contained N replicates of the sampled point. It then sets the new estimate of the global parameter to be a weighted average of the previous estimate and the intermediate parameter.

We set the step-size at iteration t as follows,

ρ_t = (t + τ)^(−κ).   (28)

This satisfies the conditions in Equation 23. The forgetting rate κ ∈ (0.5, 1] controls how quickly old information is forgotten; the delay τ ≥ 0 down-weights early iterations. In Section 4 we fix the delay to be one and explore a variety of forgetting rates. Note that this is just one way to parameterize the learning rate. As long as the step size conditions in Equation 23 are satisfied, this iterative algorithm converges to a local optimum of the ELBO.
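The snippet below (added for illustration, with assumed settings) evaluates this schedule for a few forgetting rates; any κ ∈ (0.5, 1] and τ ≥ 0 satisfy the conditions in Equation 23.

import numpy as np

def step_size(t, tau=1.0, kappa=0.7):
    # Learning rate rho_t = (t + tau)^(-kappa) from Equation 28.
    return (t + tau) ** (-kappa)

ts = np.array([1, 10, 100, 1000, 10000])
for kappa in (0.6, 0.7, 0.9, 1.0):
    print(f"kappa = {kappa:.1f}:", np.round(step_size(ts, tau=1.0, kappa=kappa), 4))

# Larger kappa shrinks the step size faster, so later data points move the
# global parameters less; a larger delay tau damps the earliest, noisiest updates.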

2.5 Extensions

We now describe two extensions of the basic stochastic inference algorithm in Figure 4: the use of multiple samples ("minibatches") to improve the algorithm's stability, and empirical Bayes methods for hyperparameter estimation.

Minibatches.   So far, we have considered stochastic variational inference algorithms where only one observation is sampled at a time. Many stochastic optimization algorithms benefit from "minibatches," i.e., several examples at a time (Bottou and Bousquet, 2008; Liang et al., 2009; Mairal et al., 2010). In stochastic variational inference, we can sample a set of S examples at each iteration (with or without replacement), compute the local variational parameters for each data point, compute the intermediate global parameters λ̂_s for each data point x_s, and finally average the λ̂_s variables in the update

λ_t = (1 − ρ_t) λ_{t-1} + ρ_t (1/S) ∑_{s=1}^{S} λ̂_s.

The stochastic natural gradients associated with each point x_s have expected value equal to the gradient. Therefore, the average of these stochastic natural gradients has the same expectation and the algorithm remains valid.

There are two reasons to use minibatches. The first reason is to amortize any computational expenses associated with updating the global parameters across more data points; for example, if the expected sufficient statistics of are expensive to compute, using minibatches allows us to incur that expense less frequently. The second reason is that it may help the algorithm to find better local optima. Stochastic variational inference is guaranteed to converge to a local optimum but taking large steps on the basis of very few data points may lead to a poor one. As we will see in Section 4, using more of the data per update can help the algorithm.
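A hedged sketch of the minibatch variant, reusing the toy two-component mixture model from the introduction's example (all names and settings remain illustrative assumptions):

import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(4)

# Same toy model: mixture of N(-2, 1) and N(+2, 1) with a Beta(1, 1) prior
# on the mixing weight pi; lam = (lam_a, lam_b) parameterizes q(pi).
true_pi, N, S = 0.3, 5000, 64
x = np.where(rng.random(N) < true_pi, 2.0, -2.0) + rng.standard_normal(N)

a = b = 1.0
lam = np.array([a, b])
tau, kappa = 1.0, 0.7

for t in range(1, 501):
    batch = x[rng.integers(N, size=S)]            # sample a minibatch of S points
    elog = digamma(lam) - digamma(lam.sum())      # (E[log pi], E[log(1 - pi)])
    log_r1 = elog[0] - 0.5 * (batch - 2.0) ** 2
    log_r0 = elog[1] - 0.5 * (batch + 2.0) ** 2
    phi = 1.0 / (1.0 + np.exp(log_r0 - log_r1))   # per-point local parameters
    # One intermediate global parameter per sampled point, then their average.
    lam_hat_s = np.stack([a + N * phi, b + N * (1.0 - phi)], axis=1)
    lam_hat = lam_hat_s.mean(axis=0)
    rho = (t + tau) ** (-kappa)
    lam = (1.0 - rho) * lam + rho * lam_hat

print("posterior mean of the mixing weight:", lam[0] / (lam[0] + lam[1]))  # close to 0.3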

Empirical Bayes estimation of hyperparameters.   In some cases we may want to both estimate the posterior of the hidden random variables β and z and obtain a point estimate of the values of the hyperparameters α. One approach to fitting α is to try to maximize the marginal likelihood of the data p(x | α), which is also known as empirical Bayes (Maritz and Lwin, 1989) estimation. Since we cannot compute p(x | α) exactly, an approximate approach is to maximize the fitted variational lower bound L over α. In the non-stochastic setting, α can be optimized by interleaving the coordinate ascent updates in Figure 3 with an update for α that increases the ELBO. This is called variational expectation-maximization.

In the stochastic setting, we update α simultaneously with λ. We can take a step in the direction of the gradient of the noisy ELBO (Equation 27) with respect to α, scaled by the step-size ρ_t,

α_t = α_{t-1} + ρ_t ∇_α L_i(λ_{t-1}, φ, α_{t-1}).

Here λ_{t-1} are the global parameters from the previous iteration and φ are the optimized local parameters for the currently sampled data point. We can also replace the standard Euclidean gradient with a natural gradient or Newton step.

3 Stochastic Variational Inference in Topic Models

We derived stochastic variational inference, a scalable inference algorithm that can be applied to a large class of hierarchical Bayesian models. In this section we show how to use the general algorithm of Section 2 to derive stochastic variational inference for two probabilistic topic models: latent Dirichlet allocation (LDA) (Blei et al., 2003) and its Bayesian nonparametric counterpart, the hierarchical Dirichlet process (HDP) topic model (Teh et al., 2006a).

Topic models are probabilistic models of document collections that use latent variables to encode recurring patterns of word use (Blei, 2012). Topic modeling algorithms are inference algorithms; they uncover a set of patterns that pervade a collection and represent each document according to how it exhibits them. These patterns tend to be thematically coherent, which is why the models are called “topic models.” Topic models are used for both descriptive tasks, such as to build thematic navigators of large collections of documents, and for predictive tasks, such as to aid document classification. Topic models have been extended and applied in many domains.

Topic models assume that the words of each document arise from a mixture of multinomials. Across a collection, the documents share the same mixture components (called topics). Each document, however, is associated with its own mixture proportions (called topic proportions). In this way, topic models represent documents heterogeneously—the documents share the same set of topics, but each exhibits them to a different degree. For example, a document about sports and health will be associated with the sports and health topics; a document about sports and business will be associated with the sports and business topics. They both share the sports topic, but each combines sports with a different topic. More generally, this is called mixed membership (Erosheva, 2003).

The central computational problem in topic modeling is posterior inference: Given a collection of documents, what are the topics that it exhibits and how does each document exhibit them? In practical applications of topic models, scale is important—these models promise an unsupervised approach to organizing large collections of text (and, with simple adaptations, images, sound, and other data). Thus they are a good testbed for stochastic variational inference.

More broadly, this section illustrates how to use the results from Section 2 to develop algorithms for specific models. We will derive the algorithms in several steps: (1) we specify the model assumptions; (2) we derive the complete conditional distributions of the latent variables; (3) we form the mean-field variational family; (4) we derive the corresponding stochastic inference algorithm. In Section 4, we will report our empirical study of stochastic variational inference with these models.

3.1 Notation

We follow the notation of Blei et al. (2003).

  • Observations are words, organized into documents. The nth word in the dth document is w_{dn}. Each word is an element in a fixed vocabulary of V terms.

  • A topic β_k is a distribution over the vocabulary. Each topic is a point on the V-simplex, a positive vector of length V that sums to one. We denote the vth entry in the kth topic as β_{kv}. In LDA there are K topics; in the HDP topic model there are an infinite number of topics.

  • Each document in the collection is associated with a vector of topic proportions θ_d, which is a distribution over topics. In LDA, θ_d is a point on the K-simplex. In the HDP topic model, θ_d is a point on the infinite simplex. (We give details about this below in Section 3.3.) We denote the kth entry of the topic proportion vector θ_d as θ_{dk}.

  • Each word in each document is assumed to have been drawn from a single topic. The topic assignment z_{dn} indexes the topic from which w_{dn} is drawn.

The only observed variables are the words of the documents. The topics, topic proportions, and topic assignments are latent variables.

3.2 Latent Dirichlet allocation

Var      Type         Conditional                          Param    Relevant expectations
z_{dn}   Multinomial  log θ_{dk} + log β_{k,w_{dn}}        φ_{dn}   E[z^k_{dn}] = φ^k_{dn}
θ_d      Dirichlet    α + ∑_n z_{dn}                       γ_d      E[log θ_{dk}] = Ψ(γ_{dk}) − Ψ(∑_j γ_{dj})
β_k      Dirichlet    η + ∑_d ∑_n z^k_{dn} w_{dn}          λ_k      E[log β_{kv}] = Ψ(λ_{kv}) − Ψ(∑_y λ_{ky})
Figure 5: (Top) The graphical model representation of Latent Dirichlet allocation. Note that in practice each document d may not have the same number of words N. (Bottom) In LDA: hidden variables, complete conditionals, variational parameters, and expected sufficient statistics.

LDA is the simplest topic model. It assumes that each document exhibits topics with different proportions. The generative process is

  1. Draw topics β_k ~ Dirichlet(η, ..., η) for k ∈ {1, ..., K}.

  2. For each document d ∈ {1, ..., D}:

    1. Draw topic proportions θ_d ~ Dirichlet(α, ..., α).

    2. For each word n ∈ {1, ..., N}:

      1. Draw topic assignment z_{dn} ~ Multinomial(θ_d).

      2. Draw word w_{dn} ~ Multinomial(β_{z_{dn}}).

Figure 5 illustrates LDA as a graphical model.

In LDA, each document exhibits the same shared topics but with different proportions. LDA assumes Dirichlet priors for θ_d and β_k. Dirichlet distributions over the K-simplex and the V-simplex take K and V parameters respectively, but for simplicity we assume exchangeable Dirichlet priors; that is, we require that all of these parameters are set to the same value. (The prior on θ_d has parameter α; the prior on β_k has parameter η.) We note that Blei et al. (2003) and Wallach et al. (2009) found improved empirical performance with non-exchangeable priors.

LDA models an observed collection of documents w = w_{1:D}, where each w_d is a collection of words w_{d,1:N}. Analyzing the documents amounts to posterior inference of p(β, θ, z | w). Conditioned on the documents, the posterior distribution captures the topics that describe them (β_{1:K}), the degree to which each document exhibits those topics (θ_{1:D}), and which topics each word was assigned to (z_{1:D,1:N}). We can use the posterior to explore large collections of documents. Figure 1 illustrates posterior topics found with stochastic variational inference.

The posterior is intractable to compute (Blei et al., 2003). Approximating the posterior in LDA is a central computational problem for topic modeling. Researchers have developed many methods, including Markov chain Monte Carlo methods (Griffiths and Steyvers, 2004), expectation propagation (Minka and Lafferty, 2002), and variational inference (Blei et al., 2003; Teh et al., 2006b; Asuncion et al., 2009). Here we use the results of Section 2 to develop stochastic inference for LDA. This scales the original variational algorithm for LDA to massive collections of documents. (Footnote 5: The algorithm we present was originally developed in Hoffman et al. (2010a); it is a special case of the stochastic variational inference algorithm we developed in Section 2.)

Figure 7 illustrates the performance of 100-topic LDA on three large collections—Nature contains 350K documents, New York Times contains 1.8M documents, and Wikipedia contains 3.8M documents. (Section 4 describes the complete study, including the details of the performance measure and corpora.) We compare two inference algorithms for LDA: stochastic inference on the full collection and batch inference on a subset of 100,000 documents. (This is the size of collection that batch inference can handle.) We see that stochastic variational inference converges faster and to a better model. It is both more efficient and lets us handle the full data set.

Indicator vectors and Dirichlet distributions.   Before deriving the algorithm, we discuss two mathematical details. These will be useful both here and in the next section.

First, we represent categorical variables like the topic assignments z_{dn} and observed words w_{dn} with indicator vectors. An indicator vector is a binary vector with a single one. For example, the topic assignment z_{dn} can take on one of K values (one for each topic). Thus, it is represented as a K-vector with a one in the component corresponding to the value of the variable: if z^k_{dn} = 1 then the nth word in document d is assigned to the kth topic. Likewise, w^v_{dn} = 1 implies that the nth word in document d is the term v. In a slight abuse of notation, we will sometimes use w_{dn} and z_{dn} as indices—for example, if z^k_{dn} = 1, then β_{z_{dn}} refers to the kth topic β_k.

Second, we review the Dirichlet distribution. As we described above, a K-dimensional Dirichlet is a distribution on the K-simplex, i.e., positive vectors of length K that sum to one. It is parameterized by a positive K-vector γ,

Dirichlet(θ; γ) = (Γ(∑_k γ_k) / ∏_k Γ(γ_k)) ∏_k θ_k^(γ_k − 1),

where Γ(·) is the Gamma function, which is a real-valued generalization of the factorial function. The expectation of the Dirichlet is its normalized parameter,

E[θ_k] = γ_k / ∑_j γ_j.

The expectation of its log uses the digamma function Ψ(·), which is the first derivative of the log Gamma function,

E[log θ_k] = Ψ(γ_k) − Ψ(∑_j γ_j).   (29)

This can be derived by putting the Dirichlet in exponential family form, noticing that log θ is the vector of sufficient statistics, and computing its expectation by taking the gradient of the log-normalizer with respect to the natural parameter.
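A short numerical check of these identities (added here; the parameter values are arbitrary):

import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(5)

gamma = np.array([2.0, 1.0, 5.0])                        # Dirichlet parameter

analytic_mean = gamma / gamma.sum()                      # E[theta] is the normalized parameter
analytic_elog = digamma(gamma) - digamma(gamma.sum())    # Equation 29

samples = rng.dirichlet(gamma, size=200000)
print(np.round(analytic_mean, 4), np.round(samples.mean(axis=0), 4))
print(np.round(analytic_elog, 4), np.round(np.log(samples).mean(axis=0), 4))
# The Monte Carlo estimates match the closed forms to within sampling error.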

Complete conditionals and variational distributions.   We specify the global and local variables of LDA to place it in the stochastic variational inference setting of Section 2. In topic modeling, the local context is a document d. The local observations are its observed words w_{d,1:N}. The local hidden variables are the topic proportions θ_d and the topic assignments z_{d,1:N}. The global hidden variables are the topics β_{1:K}.

Recall from Section 2 that the complete conditional is the conditional distribution of a variable given all of the other variables, hidden and observed. In mean-field variational inference, the variational distributions of each variable are in the same family as the complete conditional.

We begin with the topic assignment z_{dn}. The complete conditional of the topic assignment is a multinomial,

p(z_{dn} = k | θ_d, β_{1:K}, w_{dn}) ∝ exp{log θ_{dk} + log β_{k,w_{dn}}}.   (30)

Thus its variational distribution is a multinomial q(z_{dn}) = Mult(φ_{dn}), where the variational parameter φ_{dn} is a point on the K-simplex. Per the mean-field approximation, each observed word is endowed with a different variational distribution for its topic assignment, allowing different words to be associated with different topics.

The complete conditional of the topic proportions is a Dirichlet,

p(θ_d | z_d) = Dirichlet(α + ∑_{n=1}^{N} z_{dn}).   (31)

Since z_{dn} is an indicator vector, the kth element of the parameter to this Dirichlet is the sum of the hyperparameter α and the number of words assigned to topic k in document d. Note that, although we have assumed an exchangeable Dirichlet prior, when we condition on z_d the conditional p(θ_d | z_d) is a non-exchangeable Dirichlet.

With this conditional, the variational distribution of the topic proportions is also Dirichlet, q(θ_d) = Dirichlet(γ_d), where γ_d is a K-vector Dirichlet parameter. There is a different variational Dirichlet parameter for each document, allowing different documents to be associated with different topics in different proportions.

These are local hidden variables. The complete conditionals only depend on other variables in the local context (i.e., the document) and the global variables; they do not depend on variables from other documents.

Finally, the complete conditional for the topic β_k is also a Dirichlet,

p(β_k | z, w) = Dirichlet(η + ∑_{d=1}^{D} ∑_{n=1}^{N} z^k_{dn} w_{dn}).   (32)

The vth element of the parameter to the Dirichlet conditional for topic β_k is the sum of the hyperparameter η and the number of times that the term v was assigned to topic k. This is a global variable—its complete conditional depends on the words and topic assignments of the entire collection.

The variational distribution for each topic is a V-dimensional Dirichlet,

q(β_k) = Dirichlet(λ_k).

As we will see in the next section, the traditional variational inference algorithm for LDA is inefficient with large collections of documents. The root of this inefficiency is the update for the topic parameter λ_k, which (from Equation 32) requires summing over variational parameters for every word in the collection.

Batch variational inference.

With the complete conditionals in hand, we now derive the coordinate ascent variational inference algorithm, i.e., the batch inference algorithm of Figure 3. We form each coordinate update by taking the expectation of the natural parameter of the complete conditional. This is the stepping stone to stochastic variational inference.

The variational parameters are the global per-topic Dirichlet parameters λ_k, the local per-document Dirichlet parameters γ_d, and the local per-word multinomial parameters φ_{dn}. Coordinate ascent variational inference iterates between updating all of the local variational parameters (Equation 16) and updating the global variational parameters (Equation 15).

We update each document d's local variational parameters in a local coordinate ascent routine, iterating between updating each word's topic assignment and the per-document topic proportions,

φ^k_{dn} ∝ exp{Ψ(γ_{dk}) + Ψ(λ_{k,w_{dn}}) − Ψ(∑_v λ_{kv})},   k ∈ {1, ..., K},   (33)
γ_d = α + ∑_{n=1}^{N} φ_{dn}.   (34)

These updates derive from taking the expectations of the natural parameters of the complete conditionals in Equation 30 and Equation 31. (We then map back to the usual parameterization of the multinomial.) For the update on the topic assignment, we have used the Dirichlet expectations in Equation 29. For the update on the topic proportions, we have used that the expectation of an indicator is its probability, E_q[z^k_{dn}] = φ^k_{dn}.
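A vectorized sketch of this per-document local step (added for illustration and consistent with Equations 33 and 34; the array names, sizes, and hyperparameter values are assumptions):

import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(6)

K, V = 10, 500                                # number of topics, vocabulary size
alpha = 0.1                                   # Dirichlet hyperparameter on topic proportions
lam = rng.gamma(100.0, 0.01, size=(K, V))     # current global variational parameters lambda_k

def local_step(word_ids, lam, n_iters=50):
    # Coordinate ascent on one document's (gamma_d, phi_d), Equations 33-34.
    elog_beta = digamma(lam) - digamma(lam.sum(axis=1, keepdims=True))   # E[log beta_{kv}]
    gamma_d = np.ones(K)
    for _ in range(n_iters):
        elog_theta = digamma(gamma_d) - digamma(gamma_d.sum())           # E[log theta_{dk}]
        log_phi = elog_theta[:, None] + elog_beta[:, word_ids]           # K x N_d
        log_phi -= log_phi.max(axis=0, keepdims=True)
        phi = np.exp(log_phi)
        phi /= phi.sum(axis=0, keepdims=True)                            # Equation 33
        gamma_d = alpha + phi.sum(axis=1)                                # Equation 34
    return gamma_d, phi

word_ids = rng.integers(V, size=80)           # one toy document of 80 word tokens
gamma_d, phi = local_step(word_ids, lam)
print(np.round(gamma_d, 2))

In the stochastic algorithm, this local step for a sampled document is followed by forming intermediate topic parameters as though that document were replicated across the whole collection and then taking the weighted-average global update of Figure 4.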