Approximate statistical inference is fundamentally important in statistics and in probabilistic modelling for machine learning. In Bayesian statistics, posterior distributions are in general intractable to compute, because the partition function requires an integral over all possible model parameters. For this reason, approximate inference for distributions with unnormalised probability functions has long been an active research topic in Bayesian inference. In frequentist inference, approximate inference is also essential, particularly for maximum likelihood estimation in models involving latent variables. Expectation-Maximisation (EM) is one of the classic training algorithms for latent-variable models (LVMs), in which approximate inference is critical in the expectation step for marginalising out latent variables in the gradient of model parameters. Markov chain Monte Carlo (MCMC) methods and variational inference (VI), which were originally developed in statistics and physics, are the two most important approximate inference methods in machine learning. In spite of the success of MCMC and VI on many classic probabilistic models, like latent topic models Blei et al. (2003) Hinton (2002b); Salakhutdinov and Larochelle (2010), Bayesian non-parametric models Neal (2000); Kurihara et al. (2007), they face great challenges in recent research on deep generative models.
Deep Generative Models (DGMs) are a new type of LVM. The innovation of DGMs is to use deep neural networks (NNs) to transform simple latent random variables into complex distributions, a technique that has been shown to greatly boost the representation power of LVMs. Recent advances in DGMs show impressive progress on many challenging unsupervised learning tasks, which makes DGMs one of the most active research topics in both deep learning and probabilistic modelling. However, powerful DGMs also pose great challenges to classic inference methods. This has motivated many recent advances that enhance classic inference methods with NNs. In particular, variational autoencoders Kingma and Welling (2014) and normalising flows Rezende and Mohamed (2015) are the two most influential works in this direction. Many recent state-of-the-art training algorithms for DGMs are based on variants of these two methods. However, the limitations of NN-based methods have been noticed in recent work. In particular, the success of NN-enhanced inference methods is determined by the design of the inference NNs, which is mainly based on heuristic tricks and engineering expertise on the specific model. This makes it difficult to generalise these methods to DGMs with new types of NNs or to other probabilistic models without NNs, like traditional LVMs and Bayesian NNs.
In this work, we propose a novel approximate inference method based on the classic inference theory of MCMC. Our method is inspired by the idea of parallel simulations of MCMC and by recent advances in NN-based approximate inference. From the computational perspective, our method is highly scalable, like recent NN-based inference. In particular, it is straightforward to accelerate the computation of our method using parallelised simulations on Graphics Processing Units (GPUs). More importantly, the proposed method has solid theoretical foundations in the theory of MCMC, which guarantees asymptotic strong convergence to any distribution of interest. This is a great advantage over NN-based inference, because our method is better understood theoretically. Finally, like classic inference methods, our method can be applied to general inference tasks in a wide range of probabilistic models in both Bayesian and frequentist statistics.
2.1 Deep Generative Models
We introduce the basic concepts and notation for inference in generative models. Generally speaking, generative models are probabilistic models whose goal is to extract the intrinsic structure of data as stochastic latent representations, which can then be used to generate synthetic data by simulating samples from the model.
Formally, let be a dataset formed by a collection of data points . Assume that the data points in follow the distribution that is unobserved. In many popular machine learning applications, like natural language processing and computer vision, data contains rich and complex structure, which means the data distributions are typically high-dimensional and multi-modal. One important modelling technique for handling complex data is the use of latent variables. Intuitively, a latent variable represents some intrinsic representation of a data point and the conditional distribution specifies the data generation distribution given a specific latent representation . To regularise the space of latent variables, we define a prior distribution . By the product rule of probability, the joint distribution of data and latent representation is defined as
To fit an LVM to the data distribution , we marginalise out in the model, which gives
Although the joint probability function is often known and straightforward to compute, the marginal probability above is not in general MacKay (2002). Approximation of the marginal probability is one of the key research topics in LVMs.
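For concreteness, the two definitions above can be written out as follows. The symbols are an assumption made here for illustration only: $x$ stands for a data point, $z$ for its latent variable, $p(x \mid z)$ for the conditional distribution and $p(z)$ for the prior.

```latex
% Product rule (joint distribution) and marginalisation over the latent variable.
% Notation assumed for illustration: x = data point, z = latent variable.
\begin{align}
  p(x, z) &= p(x \mid z)\, p(z), \\
  p(x)    &= \int p(x, z)\, \mathrm{d}z = \int p(x \mid z)\, p(z)\, \mathrm{d}z .
\end{align}
```

The integrand is cheap to evaluate pointwise, but the integral runs over the whole latent space, which is what makes the marginal hard to compute in general.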
Recent advances in deep generative models (DGMs) Kingma and Welling (2014); Goodfellow et al. (2014) show that deep neural networks can be used to parametrise and greatly boost the representation power of LVMs. In classic LVMs, the distribution is typically specified as some parametric family with parameter , like generalized linear models Murphy (2012); Bishop (2006). In DGMs, such a parametric family is constructed by NNs and the parameter is the weights of NNs Rezende et al. (2014); Kingma and Welling (2014). In particular, there can be multiple layers of latent representations
in DGMs, where adjacent layers are connected by deterministic non-linear transformations constructed by NNs. For example, given the observed, follows the distribution with the parameter and is an NN with parameter . In powerful DGMs, the NNs often have multiple hidden layers to create a highly non-linear mapping between the stochastic layers.
Maximum likelihood estimation (MLE) is the most straightforward way to train a generative model, that is, to fit the distribution . The optimal model parameters under MLE are specified as
where is the marginal likelihood
Because does not have a closed form, expectation-maximisation (EM) MacKay (2002) is the classic training algorithm for latent generative models, which requires computing the expectation of under the distribution . Unfortunately, the expectation with respect to the distribution is also intractable to compute, but it can be approximated by variational inference or Markov chain Monte Carlo.
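To make the marginal likelihood concrete, the following minimal sketch estimates it by naive Monte Carlo in a toy model where the answer is known in closed form. The model (prior N(0, 1), likelihood N(z, 1)) and the function names are assumptions chosen for illustration; they are not the models studied in this work.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(mean, var) evaluated at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def mc_marginal_likelihood(x, n_samples, rng):
    """Naive Monte Carlo estimate of p(x) = E_{z ~ p(z)}[p(x | z)]."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)         # toy prior (assumed): z ~ N(0, 1)
        total += normal_pdf(x, z, 1.0)  # toy likelihood (assumed): x | z ~ N(z, 1)
    return total / n_samples


rng = random.Random(0)
x_obs = 1.5
estimate = mc_marginal_likelihood(x_obs, 20000, rng)
exact = normal_pdf(x_obs, 0.0, 2.0)     # closed form for this toy model: x ~ N(0, 2)
print(estimate, exact)
```

Sampling from the prior works here only because the latent space is one-dimensional and the posterior is close to the prior; in high dimensions the variance of this estimator grows quickly, which is one reason VI and MCMC are used instead.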
2.2 Variational Auto-Encoders and Normalizing Flows
Variational inference (VI) is a popular inference method for approximating distributions with intractable partition functions. The idea of VI is to approximate the distribution of interest, like in DGMs, by another distribution from a parametric family with a closed-form probability function. To find the distribution in the family that is closest to the target distribution, VI methods optimise a lower bound on the log normalising constant of the target distribution. In the case of DGMs, that is
In general, the more flexible the family is, the tighter the variational lower bound can be. Rainforth et al. (2018) argue that a tighter variational lower bound is not necessarily helpful for some specific variational inference methods. However, it is worth clarifying that, if the family is guaranteed to approximate the target arbitrarily well, a tighter variational lower bound is always desirable. We provide both theoretical and empirical evidence in later sections. In classic mean-field variational methods, the proposal distribution is often in a factorized form, which is often not flexible enough to produce good approximations to complex posteriors, like in DGMs.
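The relation between the flexibility of the family and the tightness of the bound can be seen in a toy case where everything is available in closed form. The sketch below assumes a model with prior N(0, 1) and likelihood N(z, 1) and a Gaussian variational family; these choices and all names are illustrative assumptions.

```python
import math


def log_evidence(x):
    """log p(x) for the toy model z ~ N(0, 1), x | z ~ N(z, 1), i.e. x ~ N(0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - 0.25 * x * x


def elbo(x, m, s2):
    """Closed-form variational lower bound for q(z) = N(m, s2) in the toy model."""
    exp_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    exp_log_prior = -0.5 * math.log(2.0 * math.pi) - 0.5 * (m * m + s2)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s2)
    return exp_log_lik + exp_log_prior + entropy


x_obs = 1.5
loose = elbo(x_obs, 0.0, 1.0)            # q set to the prior
tight = elbo(x_obs, x_obs / 2.0, 0.5)    # q set to the exact posterior N(x/2, 1/2)
gap_loose = log_evidence(x_obs) - loose  # equals KL(q || posterior), positive
gap_tight = log_evidence(x_obs) - tight  # zero up to rounding error
print(gap_loose, gap_tight)
```

The gap is the KL divergence from the variational distribution to the posterior, so it vanishes only when the family contains the posterior. A family that excludes it, such as a Gaussian with a fixed wrong variance, leaves a gap that no amount of optimisation can close.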
Because in VI can be any normalized probability function, one can use NNs to construct an expressive . There are two popular ways to do this. One option is to use another deep generative model to approximate , where the NN architecture in is often chosen to mirror the NN in . This approach is known as variational auto-encoders (VAEs), proposed by Kingma and Welling (2014). The variational parameter in VAEs is essentially the weights of the NN in . Rezende and Mohamed (2015) proposed an alternative flexible family of variational distributions called normalizing flows (NFs). Unlike VAEs, NFs are constructed by composing invertible non-linear deterministic transformations of a random variable , that is
where is often a simple distribution, e.g. uniform or normal. In NFs, the variational parameter is the collection of parameters in . Because the mapping from to is a deterministic smooth transformation, following the change-of-variables rule for integrals, the density of is defined as
If the transformations are volume-preserving, that is, their Jacobian determinants are equal to 1, then (3) can be simplified as
We can rewrite the variational lower bound as
It is straightforward to train normalizing flows by maximizing the variational lower bound (equivalently, minimizing its negative) using stochastic gradient descent (SGD).
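The change-of-variables rule behind normalizing flows can be checked in a one-dimensional sketch. The flow below is a composition of affine maps, chosen here only because the resulting density is known exactly; the layer values and function names are hypothetical.

```python
import math


def std_normal_logpdf(e):
    """Log density of the base distribution N(0, 1)."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * e * e


def flow_forward(z0, layers):
    """Apply z_k = a_k * z_{k-1} + b_k and accumulate sum_k log|dz_k / dz_{k-1}|."""
    z, log_det = z0, 0.0
    for a, b in layers:
        z = a * z + b
        log_det += math.log(abs(a))
    return z, log_det


layers = [(2.0, 1.0), (0.5, -3.0), (3.0, 0.25)]  # hypothetical (scale, shift) pairs
z0 = 0.7
zK, log_det = flow_forward(z0, layers)
log_q_flow = std_normal_logpdf(z0) - log_det     # change-of-variables formula

# The composition is itself affine, zK = scale * z0 + shift, so zK ~ N(shift, scale^2).
shift, _ = flow_forward(0.0, layers)
scale = math.exp(log_det)                        # product of the |a_k|
log_q_exact = (-0.5 * math.log(2.0 * math.pi * scale ** 2)
               - 0.5 * ((zK - shift) / scale) ** 2)
print(log_q_flow, log_q_exact)
```

If every layer had scale 1 the log-determinant term would vanish, which is the volume-preserving special case described above.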
Although both VAEs and NFs are more powerful than mean-field variational methods, they still suffer from the fundamental limitation of VI: there is no guarantee of the existence of an optimal that is sufficiently close to the target distribution . Because of this limitation, the choice of the NN architecture in is critical to the success of VI. However, there is a lack of theoretical understanding of the convergence of NNs to distributions. For this reason, NN-based inference methods often rely on engineering skills and knowledge of specific NN architectures.
The limitation of NNs in inference is more fundamental than engineering complexity. Hoffman (2017) pointed out that NN-based VI may not be sufficiently flexible because the variational distribution is restricted to the parametric family with closed-form density functions. Although NNs can in theory approximate any function arbitrarily well in supervised learning problems, this is not the case for posterior inference. From the perspective of function approximation, the goal of statistical inference is to approximate the function. However, the partition functions of posteriors are in general intractable to compute, so the target function is essentially unobserved. This raises the question of how useful NNs are for general inference.
2.3 Markov chain Monte Carlo
Markov chain Monte Carlo (MCMC) is an alternative way to approximate posteriors. MCMC generates correlated but asymptotically unbiased samples from the distribution of interest by simulating a stationary Markov chain. In particular, to construct a stationary Markov chain, MCMC only requires the unnormalized probability density function of the distribution we want to simulate. Formally, a Markov chain is a sequence of random variables or shortly . The transition probability from state to is denoted by given . Given the distribution of the initial state , the joint probability of all states of a Markov chain is defined as
The Markov chains in MCMC are special Markov chains that have a stationary distribution . Intuitively, that means, given n samples from , if we apply the MCMC transition kernel to these samples, the output samples of the kernel also follow distribution . Formally, that is
If the Markov chain has not mixed well, which means the distribution of is different from the stationary distribution , the distribution of is guaranteed to be closer to than . So, irrespective of the initial distribution , the samples of the MCMC chain are essentially exact samples from after a sufficiently long simulation of the chain. Because of this guarantee of convergence, MCMC methods are very popular in statistics and physics.
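As an illustration of a chain that needs only an unnormalised density, the sketch below runs random-walk Metropolis, one standard MCMC kernel, on a toy target. The target (proportional to N(2, 1)), the step size, the burn-in length and the deliberately poor starting point are all assumptions made for this example.

```python
import math
import random


def log_unnorm_target(x):
    """Log of an unnormalised density; here proportional to N(2, 1) (toy choice)."""
    return -0.5 * (x - 2.0) ** 2


def random_walk_metropolis(log_p, x0, n_steps, step, rng):
    """Simulate a Markov chain whose stationary distribution is exp(log_p) / Z."""
    x, log_px = x0, log_p(x0)
    chain = []
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step)
        log_pp = log_p(proposal)
        # Accept with probability min(1, p(proposal) / p(x)); Z cancels in the ratio.
        if rng.random() < math.exp(min(0.0, log_pp - log_px)):
            x, log_px = proposal, log_pp
        chain.append(x)
    return chain


rng = random.Random(1)
chain = random_walk_metropolis(log_unnorm_target, x0=-10.0, n_steps=20000,
                               step=2.0, rng=rng)
kept = chain[2000:]                      # discard early, unmixed states
mean = sum(kept) / len(kept)
var = sum((x - mean) ** 2 for x in kept) / len(kept)
print(mean, var)
```

The chain starts far from the target yet its later states behave like draws from it. The step size is exactly the kind of kernel parameter that usually has to be tuned by hand.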
However, there is much less interest in MCMC than in VI in research on DGMs, because of the following pitfalls of MCMC. First, it is very difficult to analyse the convergence rate of MCMC methods theoretically in the general case. Even diagnosing the convergence of MCMC chains is very challenging in practice. MCMC methods are very sensitive to the choice of the parameters in the kernel. Unfortunately, the theoretical results on the convergence of MCMC chains are not useful for constructing tractable loss functions with which to optimise the kernel parameters and accelerate convergence. In practice, the parameters are tuned manually based on trial simulations or heuristic adaptive strategies. Second, even after the chain mixes well, the samples from MCMC can still be strongly correlated. As a consequence, more samples from MCMC chains than i.i.d. samples are required to reduce the variance of estimation, especially in high-dimensional spaces Robert and Casella (2005). Third, the computational time of simulation increases linearly with the number of samples. This greatly limits the use of MCMC in problems with complex models and large amounts of data. Trivially running parallel simulations of MCMC chains often does not improve the computational efficiency, because each single simulation may still take a long time to mix well.
3 Measure Preserving Flows
We are interested in an approximate inference framework that avoids the problems of MCMC and VI mentioned in the previous section. The idea of our method is very simple: we use the distribution given by the composition of from an MCMC chain as the approximate distribution in VI and optimise a variational lower bound w.r.t. the parameters of the Markov kernels . By formulating MCMC chains variationally, we can avoid manually tuning the kernel parameters. Moreover, since converges asymptotically to the target, we know that our variational lower bound can become arbitrarily tight. Salimans et al. (2015) have proposed a method following this idea. However, it is very challenging because the density function is intractable to compute, which makes it impossible to optimise the variational lower bound directly. Instead, they proposed an NN-based approximation to in the variational lower bound, which limits the potential of their method. Here, we propose a novel solution based on transformations from ergodic theory, which allows us to avoid NN-based approximation and optimise the kernel parameters w.r.t. a tractable loss function on the convergence rate of to the target.
3.1 Probability Measure Basics
We provide the relevant basics of probability theory in this section. A probability space includes an arbitrary set , a collection of subsets of denoted by and a probability measure that maps an element of to a real number in . A random variable is a deterministic function of denoted by . A random vector is a function mapping from to denoted by . Lebesgue measure is an important measure defined in Euclidean space . It is specified by the requirement that bounded rectangles have measure . The probability density function of a probability measure of is defined as the Radon-Nikodym derivative with respect to Lebesgue measure as
where denotes the preimage of in . Lebesgue measure is preserved by any linear transformation that also preserves the volume of a set in Euclidean space Billingsley (1986). In particular, we present the following theorem from Billingsley (1986):
If is linear and non-singular, then implies that and
Shear transformations are among the best-known Lebesgue measure preserving transformations. In particular, shear transformations in are defined as , where can be an arbitrary function. It is straightforward to verify that shear transformations have a Jacobian determinant equal to 1. By the chain rule, the Jacobian determinant of a composition of transformations is simply the product of the Jacobian determinants of the individual transformations. So, the composition of shear transformations also preserves Lebesgue measure.
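The unit Jacobian determinant of a shear, and of a composition of shears, can be verified numerically. The sketch below uses two hypothetical two-dimensional shears with smooth shearing functions chosen for illustration, and a finite-difference Jacobian.

```python
import math


def shear_y(p):
    """Shear that leaves x fixed and shifts y by a function of x (illustrative)."""
    x, y = p
    return (x, y + math.sin(3.0 * x) + x * x)


def shear_x(p):
    """Shear that leaves y fixed and shifts x by a function of y (illustrative)."""
    x, y = p
    return (x + math.cos(y), y)


def composed(p):
    return shear_x(shear_y(p))


def jacobian_det(f, p, h=1e-5):
    """Determinant of the 2x2 Jacobian of f at p, by central finite differences."""
    x, y = p
    fxp, fxm = f((x + h, y)), f((x - h, y))
    fyp, fym = f((x, y + h)), f((x, y - h))
    j11 = (fxp[0] - fxm[0]) / (2.0 * h)
    j21 = (fxp[1] - fxm[1]) / (2.0 * h)
    j12 = (fyp[0] - fym[0]) / (2.0 * h)
    j22 = (fyp[1] - fym[1]) / (2.0 * h)
    return j11 * j22 - j12 * j21


point = (0.3, -1.2)
det_single = jacobian_det(shear_y, point)
det_composed = jacobian_det(composed, point)
print(det_single, det_composed)
```

The composed map is far from linear, and none of its individual Jacobian entries is constant, yet the determinant stays at 1 everywhere.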
Measure preserving transformations are transformations that preserve a given measure. For example, any transformation with the determinant of its Jacobian equal to 1 preserves Lebesgue measure. Billingsley (1986) defines measure preserving transformations as follows.
Measure Preserving Transformations. Let () be a probability space and be a consistent measure with . A mapping is a measure preserving transformation if is measurable in both the input field and the output field and for all , the measurable subsets under . If is a one-to-one mapping onto , then T preserves : .
Consider a -dimensional random vector with probability measure . The are the marginal distributions of . If has a density function in , then has the marginal density function
This can be generalised to the marginal distribution over any subset of variables in . The definition of the marginal density function implies that preservation of the joint probability distribution is a sufficient condition for preservation of the marginal distributions.
Given the target distribution with unnormalised density function , we define an approximate distribution by a mixture of sequential deterministic transformations that preserve the measure . We call such approximate distributions measure preserving flows (MPFs). Formally, let be a random variable with distribution in Euclidean -space . Following the definition of measure preserving transformation (Definition 3.1), it is straightforward to show that the following three conditions are sufficient conditions for measure preserving transformations:
Bijection: is invertible,
Preservation of density function: for all ,
Preservation of the reference measure: in the case of Lebesgue measure, that means the determinant of Jacobian .
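The three conditions can be checked directly on a simple example. The sketch below is an illustration chosen here, not a transformation taken from the text: a rotation of the pair (x, v) where both components are standard normal, so the joint density depends only on the squared radius. The angle and all names are hypothetical.

```python
import math

THETA = 0.6435  # hypothetical rotation angle


def log_joint(x, v):
    """Log density of the product of two N(0, 1) factors, one for x and one for v."""
    return -math.log(2.0 * math.pi) - 0.5 * (x * x + v * v)


def T(x, v, theta=THETA):
    """Rotation of the joint state (x, v) by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * x + s * v, -s * x + c * v)


def T_inv(x, v, theta=THETA):
    """Inverse transformation: rotation by the opposite angle."""
    return T(x, v, -theta)


x, v = 1.3, -0.4
x_new, v_new = T(x, v)
x_back, v_back = T_inv(x_new, v_new)                 # (i) bijection

density_change = log_joint(x_new, v_new) - log_joint(x, v)   # (ii) density preserved

c, s = math.cos(THETA), math.sin(THETA)
jac_det = c * c + s * s                              # (iii) det of [[c, s], [-s, c]]
print(x_back, v_back, density_change, jac_det)
```

A rotation satisfies all three conditions exactly for this particular joint density. For other targets, a transformation has to be designed so that condition (ii) holds for that target.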
In probability theory, the composition of measure preserving transformations constructs an ergodic stochastic process, which, under reasonable conditions, converges to an invariant measure as the number of transformations grows. This convergence applies irrespective of the initial state’s distribution. Markov chains are an example of ergodic processes. Robert and Casella (2005) formally define MCMC methods as follows.
A Markov chain Monte Carlo method for the simulation of a distribution is any method producing an ergodic Markov chain whose stationary distribution is .
In this work, we are interested in the MPFs that are equivalent to MCMC. Formally, an MCMC chain with steps has the stochastic states , where is sampled from an initial distribution and the distribution of given the previous state is denoted by . The distribution is also known as the transition kernel. By the product rule of probability, the joint probability of all states of the MCMC chain is
Integrating out the history of the chain , we have the distribution of the last state of the chain as
where denotes the composition of transition kernels
Recall that by (6) we know that MCMC kernel preserves the stationary distribution .
Let be the distribution we want to sample and be a measure-preserving transformation (MPT) that preserves the probability measure with , where can be any distribution. We assume that satisfies the three conditions of measure preserving transformations mentioned earlier. The projection of in the space of , denoted by , is a stochastic transition from to with . Assume that the auxiliary variable follows the distribution ; we can then reformulate the MCMC kernel as . It is not hard to see that each state of the MCMC chain can be reformulated by applying measure-preserving transformations to in a sequence, that is
In the joint space of , (8) is essentially
where each preserves and, given each output auxiliary variable and the state , we can invert this process by the inverse of , that is .
If we are only interested in , we denote (8) simply by the composition . Following the rule of changing variables, it is straightforward to derive the density of
where denotes the Dirac delta function. MPFs constructed by reparameterising ergodic Markov chains as above enjoy exactly the same ergodicity as the Markov chains. This is important because ergodicity guarantees the convergence of to the invariant distribution as grows.
3.3 Understanding Measure Preserving Conditions
It is worth addressing a common confusion about the measure preservation conditions in Section 3.2 before further discussion of measure preserving flows. In particular, a common misunderstanding is that the preservation of volume is equivalent to condition (iii) in Section 3.2 on the preservation of Lebesgue measure. This is not true in a general sense. Notice that, by the construction of MPFs, we are interested in sampling the random variable , but the measure preserving transformation defined in the previous section preserves the joint measure . Following the conditions of measure preserving transformations in Section 3.2, it seems necessary to show that, given a specific , the projection of in the space of also satisfies the measure preserving conditions with respect to the marginal . In particular, it seems necessary to include a correction term if the Jacobian of with respect to is not equal to 1.
However, this is not the case. In particular, volume preservation in the space of (the Jacobian of is equal to 1) is NOT necessary for measure preserving flows. It is important to understand that the measure preservation conditions in the augmented space are sufficient for the preservation of the marginal distribution in the space of . Formally, we have the following proposition.
Given a transformation that preserves the distribution , if is sampled from the marginal , then the marginal distribution is also preserved by the projection of in the space of , that is .
Because preserves the probability measure , then for any measurable set in Borel set , we have
Given a set , the set generated by , is measurable under with the measure
that is essentially a probability measure in the space of , also known as the marginal probability
Because preserves the joint measure, applying on gives
where denotes the projection of in the space of . Following the definition of measure preserving transformations, we have
where denotes the preimage of under . Because is essentially the marginal probability , , we know the marginal distribution is preserved by because
This implies that the marginal distribution is preserved by the stochastic mapping
where . If is invertible for any , we have
where denotes the projection of in the space of . Therefore, we know that if we sample from and apply on , then the marginal distribution is preserved. ∎
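The statement of the proposition can be checked by simulation in a toy case, again using a rotation of a pair of standard normal variables as an illustrative assumption. The projection onto the first coordinate has derivative 0.8 with respect to x, so it is not volume preserving in the space of x alone.

```python
import math
import random

C, S = 0.8, 0.6   # cos and sin of a hypothetical rotation angle; C**2 + S**2 == 1


def project_x(x, v):
    """x-component of the joint rotation of (x, v); d(project_x)/dx = C, not 1."""
    return C * x + S * v


def variance(values):
    m = sum(values) / len(values)
    return sum((u - m) ** 2 for u in values) / len(values)


rng = random.Random(2)
n = 50000
xs = [rng.gauss(0.0, 1.0) for _ in range(n)]   # x drawn from its marginal N(0, 1)

# Auxiliary variable drawn from its own marginal: the marginal of x is preserved.
var_random_v = variance([project_x(x, rng.gauss(0.0, 1.0)) for x in xs])

# Auxiliary variable held at a fixed value, like a flow parameter: x is contracted.
var_fixed_v = variance([project_x(x, 0.5) for x in xs])
print(var_random_v, var_fixed_v)
```

With a random auxiliary variable the output variance stays near 1, while with a fixed one it shrinks to about 0.64. This is the distinction the proposition draws: no Jacobian correction in the space of x is needed when the auxiliary variable is sampled from its marginal.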
Proposition 1 gives us some important insight into the difference between MPFs and normalising flows (NFs). Similar to MPFs, NFs also use a sequence of deterministic transformations to approximate the distribution of interest, but the parameter in NFs is not a stochastic variable and is optimised by maximising the variational lower bound. Because the value of is fixed, following the rule of changing variables, it is necessary to consider the Jacobian of . In contrast, the projection transformation in MPFs is fundamentally different from the transformations in NFs. Given a fixed value of , the transformation in MPFs is deterministic like the transformations in NFs. However, in MPFs, is a projection of the transformation which preserves a joint probability measure , which means is a random variable rather than a fixed parameter. More importantly, by Proposition 1 we know that is not required to preserve the volume in the space of to preserve the marginal if is sampled from . In fact, given a fixed , the Jacobian of can be very complex: it may not even be available in closed form and is intractable to compute in general. For this reason, the family of transformations used in MPFs is much more general than the family of NNs used in NFs.
3.4 Ergodicity and Convergence of Measure Preserving Flows
As mentioned in Section 3.2, we are interested in the MPFs that can be reformulated as MCMC chains, because ergodicity provides the guarantee of convergence of MPFs. Ergodicity and measure preserving transformations are closely related in probability theory. The background on ergodicity can be found in many textbooks, so we will not cover the basics. However, we introduce some important results from ergodic theory that will be used in later sections.
First, recall that we only consider MPFs that are equivalent to MCMC chains, so there is no need to prove the ergodicity of such MPFs. Ergodicity is important because it can establish the convergence of random variables in total variation distance. In Vaart (2000), convergence in total variation is defined as follows.
A sequence of random variables converges in total variation to a variable if
where the supremum is taken over all measurable sets .
The total variation is also known as the distance metric between two distributions and . We denote the total variation distance by . Convergence in total variation is stronger than convergence in distribution, because it requires that the sequence converges for every Borel set and the convergence must also be uniform in . A simple sufficient condition for convergence in total variation is pointwise convergence of densities. Unsurprisingly, it is generally very difficult to establish the convergence in total variation distance from scratch.
The convergence in total variation distance of an ergodic Markov chain to its stationary distribution is the foundation of the theory of MCMC. Formally, this is described as Theorem 6.51 in Robert and Casella (2005), which is the following theorem.
If a Markov chain with kernel is ergodic with stationary distribution , then
for every initial distribution .
Furthermore, we know that the convergence of ergodic Markov chains is monotonic, as stated in the following proposition from Robert and Casella (2005) (Proposition 6.52).
If is an invariant distribution for the ergodic Markov chain, then
is decreasing in .
Because MPFs we consider in this work are equivalent to MCMC chains, MPFs enjoy the convergence to the invariant distribution in total variation distance. Following Theorem 2 and Proposition 2, we have the following theorem on the convergence of measure preserving flows.
If is an invariant distribution of a measure preserving flow and is the marginal distribution of the final state of the flow with measure preserving transformations,
for every initial distribution and is decreasing in .
See the appendix. ∎
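The monotone decrease of the total variation distance can be observed in a toy chain whose marginals are known exactly. The sketch below assumes a Gaussian autoregressive kernel with stationary distribution N(0, 1) and a Gaussian initial distribution, so the marginal after each step is Gaussian with a closed-form mean and variance. The chain, its coefficient and the grid settings are illustrative assumptions.

```python
import math


def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def tv_to_standard_normal(mean, var, lo=-15.0, hi=15.0, n=15000):
    """0.5 * integral of |N(x; mean, var) - N(x; 0, 1)| dx by the midpoint rule."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * h
        total += abs(normal_pdf(x, mean, var) - normal_pdf(x, 0.0, 1.0))
    return 0.5 * total * h


RHO = 0.8               # hypothetical coefficient: x' = RHO * x + sqrt(1 - RHO^2) * v
m0, s0_sq = 5.0, 0.25   # arbitrary initial distribution N(5, 0.25)

tv = []
for t in range(21):
    r = RHO ** t
    mean_t = r * m0                            # mean of the state after t steps
    var_t = r * r * s0_sq + (1.0 - r * r)      # variance of the state after t steps
    tv.append(tv_to_standard_normal(mean_t, var_t))

print(tv[0], tv[10], tv[20])
```

The distance starts close to 1, because the initial distribution barely overlaps the target, and shrinks at every step, as the results above require for an ergodic chain.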
3.5 Variational Inference in MPFs
As we discussed in Section 2.2, one important application of variational inference (VI) is to train latent variable models (LVMs) by maximising the marginal likelihood of data. In particular, VI uses the KL divergence to measure the discrepancy between the model distribution and the approximate distribution . By Jensen’s inequality, we obtain a lower bound on the log marginal likelihood
This lower bound is often known as the evidence lower bound (ELBO). ELBO cannot be evaluated exactly, but it is straightforward to approximate ELBO by Monte Carlo if we can generate samples from
VI allows us to fit the model to the data by optimising the ELBO with respect to model parameter using stochastic gradient descent (SGD). In this section, we will explore the use of measure preserving flows as variational proposal for training LVMs.
Given the model distribution , we construct an MPF with the stationary measure of as , where can be any distribution with a tractable density function and from which samples can be simulated. Let be the th measure preserving transformation (MPT) and be the parameter of . We transform to by sequentially applying the MPTs , that is
The composition of the transformations above forms a mapping from to . We use the shorthand notation for the composition of transformations . By the preservation of density function of MPTs (the condition (ii) in Section 3.2), we have the following equality
where the transformation from to is given by as (12).
By the preservation of Lebesgue measure by MPTs (the condition (iii) in Section 3.2), we have the following equality of density function
where is the parameter of simply because as (12). It is important to clarify the meaning of this equality (14). First, equation (14) implies that the density values of the initial proposal and transformed are identical for any and its image under the transformation . But this does not imply that the distributions and are identical. In particular, and in general can be arbitrarily different because can be an arbitrary transformation with Jacobian determinant equal to 1. Second, the equality (14) is only applicable to the joint density of . The marginal probability density of any subset of can be arbitrarily different from the initial marginal density and . Finally, the marginal probability of any subset of is intractable to compute in general, but we are free to choose to be some simple distribution, e.g., Gaussian. Because the density of the flow distribution is preserved, the entropy is also preserved, that is
Following (11), it is straightforward to derive the ELBO of initial distribution as
We call the ELBO above the simple ELBO. Multiplying both the numerator and the denominator of the log ratio in (16) by the density of the auxiliary variables , we have the equivalent form of as
By the reparameterisation trick Kingma and Welling (2014), we can rewrite the ELBO using MPTs (12). In particular, we reparameterise with in (17) by the sequential transformation. By the preservation of density in the stationary distribution (13) and the flow distribution (14), we have the ELBO after reparameterization as
where we denote the reparameterization of by because and are determined by the parameter of MPTs . We call the ELBO in (18) the reparameterised ELBO.
To optimise the ELBO (18), we need to compute the gradient of with respect to model parameter and flow parameter . It is straightforward to derive the gradient with respect to as
Notice that the gradient term is discarded, simply because of the preservation of entropy of the flow (15) and .
Notice that depends on , so this gradient estimator can have high variance. This can be solved by the reparameterization trick. In particular, we reparameterize by in the density by the inverse of (12). By the preservation of the flow density (14), the gradient (20) is equivalent to
where denotes the projection of in space and denotes the projection of in space. Because and do not depend on after reparameterisation, we can move the operator of derivative inside of integral, that gives
The reparameterised ELBO (18) can only be as tight as the ELBO (16) with initial distribution . However, optimising the reparameterised ELBO may lead to faster convergence than optimising . Recall that by the ergodicity of MPFs, we know that the total variation distance between and decreases in the flow length . In other words, is guaranteed to be closer to the target than the initial .
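Two properties used in this section, that a transformation with unit Jacobian determinant preserves the entropy of the joint distribution while the marginals may change, can be checked in a linear-Gaussian special case. The sketch below is an illustration under assumed choices: a two-dimensional standard Gaussian pushed through a linear shear, for which the covariance and the entropy are available in closed form.

```python
import math


def gaussian_entropy_2d(cov):
    """Differential entropy of a 2-D Gaussian with covariance [[a, b], [b, d]]."""
    (a, b), (_, d) = cov
    det = a * d - b * b
    return 0.5 * math.log((2.0 * math.pi * math.e) ** 2 * det)


def push_cov_through_shear(cov, c):
    """Covariance of A s for s ~ N(0, cov) and the shear A = [[1, 0], [c, 1]]."""
    (a, b), (_, d) = cov
    off_diag = c * a + b
    return [[a, off_diag], [off_diag, c * c * a + 2.0 * c * b + d]]


cov0 = [[1.0, 0.0], [0.0, 1.0]]               # factorised initial distribution
cov1 = push_cov_through_shear(cov0, c=3.0)    # hypothetical shear strength

h0 = gaussian_entropy_2d(cov0)
h1 = gaussian_entropy_2d(cov1)
marginal_var_before = cov0[1][1]
marginal_var_after = cov1[1][1]
print(h0, h1, marginal_var_before, marginal_var_after)
```

The joint entropy is unchanged because the determinant of the covariance is unchanged, while the variance of the second coordinate grows from 1 to 10. The joint quantity is preserved exactly even though the marginals differ substantially.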
3.6 Ergodic Lower Bound and Ergodic Inference
In the last section, we proposed a reparameterised ELBO of the initial distribution involving MPF parameters. However, it is not hard to see that this reparameterised ELBO (18) is of limited use, because it can only be as tight as the ELBO (16) with initial approximate distribution . Moreover, the reparameterised ELBO can be less attractive than the simple ELBO, as it is more expensive to compute than the simple ELBO with . More importantly, the reparameterised ELBO cannot be arbitrarily tight because the simple ELBO cannot be. This seems to erase the benefits of using an ergodic MPF, which we know will converge to the invariant distribution, given a sufficiently long flow.
To overcome the fundamental limitation of the reparameterised ELBO, we propose a special ELBO variant which is tailored to the MPF setting, and allows for a variational lower bound which becomes arbitrarily tight as the length of the flow grows. We call such an ELBO the ergodic lower bound (ERLBO). Formally, we define the ergodic lower bound as follows.
Given an ergodic measure preserving flow with the invariant measure , the ergodic lower bound is an asymptotically tight lower bound of the integral
which means can be arbitrarily tight if the flow is sufficiently long, that is
The most important difference between ERLBO and the ELBO with or its reparametrisation with the transforms is that ERLBO can be arbitrarily tight.
To derive ERLBO, we first rewrite the reparameterised ELBO (18) as