Forward Amortized Inference for Likelihood-Free Variational Marginalization

05/29/2018 ∙ by Luca Ambrogioni, et al. ∙ Radboud University ∙ University of Amsterdam ∙ UMC Utrecht

In this paper, we introduce a new form of amortized variational inference by using the forward KL divergence in a joint-contrastive variational loss. The resulting forward amortized variational inference is a likelihood-free method as its gradient can be sampled without bias and without requiring any evaluation of either the model joint distribution or its derivatives. We prove that our new variational loss is optimized by the exact posterior marginals in the fully factorized mean-field approximation, a property that is not shared with the more conventional reverse KL inference. Furthermore, we show that forward amortized inference can be easily marginalized over large families of latent variables in order to obtain a marginalized variational posterior. We consider two examples of variational marginalization. In our first example we train a Bayesian forecaster for predicting a simplified chaotic model of atmospheric convection. In the second example we train an amortized variational approximation of a Bayesian optimal classifier by marginalizing over the model space. The result is a powerful meta-classification network that can solve arbitrary classification problems without further training.


1 Introduction

Bayesian inference is a principled statistical framework for estimating the probability of latent factors given a set of observations. Unfortunately, most complex Bayesian models are intractable since computing the posterior distribution involves the solution of integrals over high-dimensional spaces. Variational inference (VI) is a family of approximation methods that reframes Bayesian inference as an optimization problem that can be solved using stochastic optimization techniques Jordan et al. (1999). Recent developments in stochastic VI have scaled Bayesian inference to massive datasets and paved the way for the integration of deep learning and Bayesian statistics Hoffman et al. (2013); Ranganath et al. (2014); Rezende et al. (2014); Kucukelbir et al. (2017); Tran et al. (2016). In many applications, VI is made more efficient by optimizing a whole family of variational distributions at once Kingma and Welling (2013); Huszár (2017); Ritchie et al. (2016). This approach is usually referred to as amortized inference. Amortized inference can be seen as a special case of the larger framework of joint-contrastive variational inference Huszár (2017); Dumoulin et al. (2017).

In this paper we introduce forward amortized variational inference (FAVI) as a flexible and tractable new form of likelihood-free VI. FAVI is obtained by using the forward KL divergence on a joint-contrastive variational loss. One of the most important features of FAVI is that it can be used for marginalizing over a large space of nuisance variables without explicitly modeling their joint density. Marginalization of nuisance variables is important in many real-world problems such as weather forecasting Gneiting and Raftery (2005). FAVI is particularly suitable for model-based problems such as weather forecasting because it is trained on samples from the generative model. However, the applicability of FAVI goes far beyond model-based problems. As an example of a model-free problem, we use FAVI to obtain a meta-classifier as a variational approximation of the Bayes optimal classifier of an infinite ensemble of classification models. The resulting variational meta-classifier is algorithmically similar to the meta-learning methods introduced in Prokhorov et al. (2002) and recently expanded in Santoro et al. (2016); Vinyals et al. (2016).

2 Related work

In spite of its theoretical advantages, the intractability of the expectation in the forward KL divergence limits its applicability in the conventional VI framework Bishop (2006). The forward KL is adopted by expectation propagation (EP) methods Minka (2001); Barthelmé and Chopin (2011), but EP is not a form of VI since it does not minimize a global divergence between the two distributions. Likelihood-free Bayesian inference is often based on approximate Bayesian computation (ABC) Tavaré et al. (1997); Pritchard et al. (1999). Recently the ABC approach has been applied to both VI Tran et al. (2017a) and EP Barthelmé and Chopin (2011). However, despite its success in many applications, ABC has some important limitations. In particular, the efficiency of rejection-based ABC methods tends to degrade sharply as the dimensionality grows, and the use of low-dimensional summary statistics can severely affect the performance. Similarly, methods based on some form of density estimation such as S. et al. (2018) are strongly affected by the curse of dimensionality, since high-dimensional density estimation is notoriously challenging. An alternative approach, which is algorithmically similar to our method, is to treat Bayesian inference as a nonlinear regression problem. This approach was first introduced in Blum and François (2010) and recently extended in Papamakarios and Murray (2016). In this latter work, a loss similar to our FAVI loss is iteratively optimized using an importance sampling scheme, so that the distribution of the simulated data gradually narrows down to the distribution of the observed data. Note that this work does not draw any connection with VI, and its importance sampling scheme is explicitly designed to avoid inference amortization. In general, the FAVI approach offers a theoretical foundation to several previous works based on training deep networks on simulated data Le et al. (2017); Jaderberg et al. (2014, 2016); Gupta et al. (2016); Stark et al. (2015); Güçlütürk et al. (2016); Ambrogioni et al. (2017a). Most of the recent literature about likelihood-free approximate Bayesian inference is based on adversarial training. This line of research was initiated by adversarially learned inference (ALI), which can be shown to minimize the Jensen-Shannon divergence in the limit of an optimal discriminator Dumoulin et al. (2017). Several other adversarial VI methods have recently been introduced Mescheder et al. (2017); Tran et al. (2017b); Huszár (2017). These variational methods share some of the flexibility of FAVI, but they usually require the samples from the variational posterior to be differentiable. A drawback of adversarial methods is that the adversarial minimax problem is equivalent to the minimization of a divergence only in the nonparametric limit Goodfellow et al. (2014); Mescheder et al. (2017). From a practical perspective, adversarial methods tend to generate very realistic samples, but often suffer from instability during training and mode collapse A. et al. (2018); Arora et al. (2017).

3 Background on joint-contrastive variational inference

Joint-contrastive variational inference was first introduced in the context of ALI Dumoulin et al. (2017) and more explicitly outlined in Huszár (2017). The loss functional of joint-contrastive variational inference is a divergence between the model joint distribution and a joint variational distribution:

D[ p(x, z), q(x, z) ] .     (1)

Without further constraints, the minimization of this loss functional is not particularly useful, as the model joint is usually tractable and does not need to be approximated. The key idea for approximating the intractable posterior by minimizing Eq. 1 is to factorize the variational joint as the product of a variational posterior and the sampling distribution of the data:

q(x, z) = q(z|x) k(x) .     (2)

Usually k(x) is a re-sampling distribution of a training set, as in the case of variational autoencoders Kingma and Welling (2013). Given this factorization, the minimization of Eq. 1 with respect to both q(z|x) and the generative model simultaneously approximates the model posterior p(z|x) with q(z|x) and the real-world data distribution k(x) with the model evidence p(x). Importantly, we can usually sample from both p(x, z) and q(z|x) k(x), which implies that we can stochastically optimize Eq. 1 for a large class of divergence measures.

3.1 Amortized inference

If we adopt the reverse KL divergence in Eq. 1 together with the factorization in Eq. 2, the joint-contrastive variational inference loss decomposes into an evidence loss and an amortized inference loss term:

D_KL[ q(x, z) || p(x, z) ] = D_KL[ k(x) || p(x) ] + E_{k(x)}[ D_KL[ q(z|x) || p(z|x) ] ]
                           = - H[ k(x) ] - E_{k(x)}[ E_{q(z|x)}[ log p(x, z) - log q(z|x) ] ] .     (3)

This result suggests that conventional amortized inference is a special case of joint-contrastive variational inference. We can see this by studying the gradients of Eq. 3. In the following, δ/δq denotes the functional gradient with respect to the density q(z|x). We use this functional notation in order to avoid referring to an explicit parametrization. Since the term corresponding to the entropy of k(x) in Eq. 3 does not depend on q(z|x), this divergence has the same functional gradient as the (negative) amortized ELBO:

(δ/δq) D_KL[ q(x, z) || p(x, z) ] = - (δ/δq) E_{k(x)}[ E_{q(z|x)}[ log p(x, z) - log q(z|x) ] ] .     (4)

Therefore, amortized variational inference is a special case of joint-contrastive variational inference.
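As a concrete point of comparison with the forward loss introduced in the next section, the following sketch shows a Monte Carlo estimate of the negative amortized ELBO of Eq. 4 for a toy linear-Gaussian model using the reparametrization trick; the toy model, the network and all hyperparameters are illustrative assumptions and are not part of our experiments.

    import torch

    prior_std, obs_std = 1.0, 0.5                        # toy model: z ~ N(0, 1), x | z ~ N(z, 0.5^2)
    enc_net = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.ReLU(),
                                  torch.nn.Linear(32, 2))    # q(z|x): outputs (mu, log_sigma)

    def neg_amortized_elbo(x):
        mu, log_sigma = enc_net(x).chunk(2, dim=-1)
        z = mu + log_sigma.exp() * torch.randn_like(mu)      # reparametrized sample from q(z|x)
        log_q = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(z)
        log_p = (torch.distributions.Normal(0.0, prior_std).log_prob(z)
                 + torch.distributions.Normal(z, obs_std).log_prob(x))  # log p(x, z)
        return (log_q - log_p).mean()                        # -E_q[log p(x, z) - log q(z|x)]

    x_batch = torch.randn(128, 1)                            # stand-in for samples from k(x)
    loss = neg_amortized_elbo(x_batch)
    loss.backward()                                          # gradient must flow through the samples z

Note that the gradient has to be back-propagated through the sampled z, which is precisely the restriction that FAVI removes.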

4 Forward amortized variational inference

The reverse KL divergence has a central position in the classical (posterior-contrastive) variational framework because it leads to a tractable variational lower bound. Conversely, the forward KL divergence is intractable in a posterior-contrastive sense, as it requires the computation of an expectation with respect to the true posterior. We will now show that the forward KL is tractable when used in a joint-contrastive loss. In this case we obtain the following divergence:

D_KL[ p(x, z) || q(z|x) k(x) ] = - H[ p(x, z) ] - E_{p(x,z)}[ log q(z|x) ] - E_{p(x)}[ log k(x) ] .     (5)

Note that in this expression there is only one term that depends on q(z|x). Therefore, by ignoring the constant terms, we can define the FAVI loss as follows:

L_FAVI[q] = - E_{p(x,z)}[ log q(z|x) ] .     (6)

The resulting functional gradient is given by

(δ/δq) L_FAVI[q] = - E_{p(x,z)}[ (δ/δq) log q(z|x) ] .     (7)

Note that the computation of this gradient requires neither reparametrization tricks nor black-box methods, since the expectation is taken with respect to p(x, z) while the gradient is taken with respect to q(z|x).

The FAVI variational loss can also be derived as an amortized form of posterior-contrastive variational inference. The forward KL posterior-contrastive variational loss is given by:

D_KL[ p(z|x) || q(z|x) ] = E_{p(z|x)}[ log p(z|x) - log q(z|x) ] .     (8)

It is challenging to obtain unbiased samples of the gradient of this expression, as the expectation is taken with respect to the intractable posterior p(z|x). We can recover the FAVI loss (up to a term that is constant in q) if we amortize the problem with respect to the model evidence p(x):

E_{p(x)}[ D_KL[ p(z|x) || q(z|x) ] ] = - E_{p(x,z)}[ log q(z|x) ] + const.     (9)

FAVI has several advantages over reverse amortized inference. First of all, it is very simple to obtain Monte Carlo samples of the gradient of the stochastic loss in Eq. 6, since the expectation is taken with respect to p(x, z). This avoids the use of methods such as the reparametrization trick, which limits the family of possible probability distributions and lengthens the computational graph, since the loss needs to be back-propagated through the samples Rezende et al. (2014). Another important advantage is that the model joint probability does not need to be evaluated explicitly. This implies that FAVI can be used when the likelihood is intractable, in situations where ABC methods are usually adopted Csilléry et al. (2010); Blum and François (2010); Marin et al. (2012); Tran et al. (2017a). A downside of FAVI is that the loss in Eq. 5 cannot be directly minimized with respect to the generative model, since the corresponding gradient cannot be expressed in closed form. There are several possible ways of dealing with this problem. In Appendix A we outline an adversarial method that only requires the differentiability of the samples from the generative model. Note that the optimization of the generative model is not, strictly speaking, part of Bayesian inference. Therefore, in the rest of the paper we focus on the case where the generative model is known a priori. One of the most interesting features of FAVI is that its loss is optimized by the exact marginals even when the variational approximation is fully factorized, as we shall demonstrate in the next section.
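The following sketch illustrates the resulting training loop on an assumed toy problem (the black-box simulator, the Gaussian posterior network and the hyperparameters are illustrative, not the settings of our experiments): the loss of Eq. 6 is estimated purely from joint samples, so neither the likelihood nor its gradients are ever evaluated.

    import torch

    def simulate(z):
        # black-box simulator for p(x|z): only sampling is required, never a density
        return torch.sin(z) + 0.1 * torch.randn_like(z)

    post_net = torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.ReLU(),
                                   torch.nn.Linear(64, 2))   # q(z|x) = N(mu(x), sigma(x)^2)
    opt = torch.optim.Adam(post_net.parameters(), lr=1e-3)

    for step in range(1000):
        z = torch.randn(256, 1)                   # z ~ p(z)
        x = simulate(z)                           # x ~ p(x|z); together a joint sample (x, z)
        mu, log_sigma = post_net(x).chunk(2, dim=-1)
        loss = -torch.distributions.Normal(mu, log_sigma.exp()).log_prob(z).mean()   # Eq. 6
        opt.zero_grad(); loss.backward(); opt.step()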

4.1 Marginalization properties of FAVI

In the fully factorized mean field approximation the FAVI loss is minimized by the exact marginals of the true posterior, as stated in the following theorem:

Theorem 1 (Exact marginals).

Consider a joint distribution p(x, z) and a fully factorized variational posterior q(z|x) = ∏_j q_j(z_j|x). The functional L_FAVI[q] = - E_{p(x,z)}[ Σ_j log q_j(z_j|x) ] is minimized when q_j(z_j|x) = p(z_j|x) for every j and for all x in the support of p(x). Furthermore, the minimizer is unique when all values of x are in the support of p(x).

Proof.

In the fully factorized case, the FAVI loss can be rewritten as follows:

L_FAVI[q] = - E_{p(x,z)}[ Σ_j log q_j(z_j|x) ]
          = Σ_j E_{p(x)}[ D_KL[ p(z_j|x) || q_j(z_j|x) ] ] + Σ_j E_{p(x)}[ H[ p(z_j|x) ] ] .     (10)

The conditional entropy term on the right-hand side of the final expression does not depend on q and can therefore be ignored. Since the KL divergence is always non-negative and vanishes only when the two distributions are identically equal, the expectations in the remaining term are equal to zero if and only if q_j(z_j|x) = p(z_j|x) for all j and for all x in the support of p(x). ∎

The situation is radically different in reverse KL VI where the factorized approximation can lead to a severe underestimation of the uncertainty of the marginals Murphy (2012); Bishop (2006).
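This difference can be checked numerically on a toy two-dimensional Gaussian posterior (an illustrative example, not taken from our experiments): the factorized forward KL minimizer retains the exact marginal variances, while the reverse KL mean-field solution is determined by the diagonal of the precision matrix and therefore underestimates them Bishop (2006).

    import numpy as np

    Sigma = np.array([[1.0, 0.9],
                      [0.9, 1.0]])          # strongly correlated posterior covariance
    Lambda = np.linalg.inv(Sigma)           # precision matrix

    forward_kl_var = np.diag(Sigma)         # factorized forward KL: exact marginal variances (1.0, 1.0)
    reverse_kl_var = 1.0 / np.diag(Lambda)  # mean-field reverse KL fixed point: (0.19, 0.19)

    print(forward_kl_var, reverse_kl_var)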

Theorem 1 straightforwardly generalizes to variational models that are factorized into two blocks. From this, an important result follows:

Theorem 2 (Consistent marginalization).

Consider a joint distribution p(x, z, w) and the (nonparametric) conditionally independent variational model q(z, w|x) = q(z|x) q(w|x). The following equality holds:

argmin_{q(z|x)} L_FAVI[ q(z|x) q(w|x) ] = ∫ p(z, w|x) dw .     (11)

Proof.

argmin_{q(z|x)} L_FAVI[ q(z|x) q(w|x) ] = p(z|x) = ∫ p(z, w|x) dw ,     (12)

where the first equality is a direct consequence of Theorem 1 applied to the two-block factorization. ∎

Therefore, there is no need to explicitly model the conditional dependencies between z and w when the aim is to estimate p(z|x). In practice, it is straightforward to obtain Monte Carlo estimates of the marginalized variational loss -E_{p(x,z)}[ log q(z|x) ], since a sample from p(x, z) is obtained by 'ignoring' w in a sample from the full joint distribution. Conversely, marginalization in the reverse KL approach requires either performing the challenging integration p(x, z) = ∫ p(x, z, w) dw explicitly or modeling the conditional dependencies between z and w and marginalizing w out of the resulting variational distribution. Note that Theorem 2 does not hold in the case of reverse KL inference and, consequently, assuming conditional independence could severely bias the resulting marginalized posterior.
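In code, the marginalized loss is obtained simply by discarding the nuisance block from each joint sample, as in the following sketch (the toy joint distribution is an illustrative assumption, and post_net denotes the same kind of Gaussian posterior network used in the earlier FAVI sketch):

    import torch

    def sample_full_joint(n):
        z = torch.randn(n, 1)                        # latent of interest
        w = torch.randn(n, 3)                        # nuisance latent block
        x = z + w.sum(dim=-1, keepdim=True) + 0.1 * torch.randn(n, 1)
        return x, z, w

    x, z, _ = sample_full_joint(512)                 # 'ignoring' w marginalizes it out
    mu, log_sigma = post_net(x).chunk(2, dim=-1)     # post_net as in the earlier FAVI sketch
    marginal_loss = -torch.distributions.Normal(mu, log_sigma.exp()).log_prob(z).mean()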

5 Applications

We begin this section with a direct comparison between amortized reverse KL VI and FAVI, in which we approximate the posterior distribution of a pre-trained variational autoencoder generative model and compare the accuracy of the resulting variational posteriors. Subsequently, we discuss two applications where the reverse KL approach is not easily applicable. These applications involve large-scale likelihood-free marginalization of latent variables. In the first application we use FAVI to obtain a variational forecaster of chaotic time series. This is an example of a model-based problem, since the dynamic equations are assumed to reliably describe the dynamics of real-world systems such as the Earth's atmosphere. In the second application we apply FAVI to the model-free problem of meta-classification. In this case, predictive performance is obtained by marginalizing over the posterior distribution of a weakly structured ensemble of random classification models that span a very large space of possible classification problems.

5.1 Comparison between amortized inference methods

In order to compare FAVI with other amortized inference approaches, we approximated the posterior distribution of the generative model p(z) p(x|z), with p(z) = N(0, I) and p(x|z) = N(x; μ(z), diag(σ²(z))), where x is a vector containing the intensities of the pixels of a black-and-white image and z is a vector of latent variables. The functions μ(z) and σ(z) are the two outputs of a pre-trained deep neural network. The network has a three-layered fully connected architecture with ReLU nonlinearities in the hidden layers and was trained on the MNIST dataset using a variational autoencoder Kingma and Welling (2013). We decided to use a common pre-trained generator in order to have a clean comparison between the performances of the approximate Bayesian inference methods. The variational posterior was parametrized by a three-layered fully connected architecture with ReLU nonlinearities. Both models were trained with Adam Kingma and Ba (2014) with the same number of epochs and the same batch size. The reverse KL inference network was trained by re-sampling MNIST images, while FAVI was trained on simulated samples. This difference follows from the fact that the former is amortized with respect to k(x) while the latter is amortized with respect to p(x). We also included ALI Dumoulin et al. (2017) in this comparison as an example of an adversarial likelihood-free method.

5.1.1 Results

Figure 1: Comparison of the performance of the variational inference methods on the MNIST dataset. A. Reconstruction error of the latent variables. B. Reconstruction error of the images. C. Example of variational distributions given the (synthetic) image shown in the upper right corner. The black line denotes the real value of the latent variable.

The latent and observable reconstruction errors were quantified respectively as

ε_z = (1/D) || z − ẑ ||²   and   ε_x = (1/N) || x − x̂ ||² ,     (13)

where ẑ denotes the latent vector estimated from the variational posterior, x̂ the corresponding reconstructed image, D the dimension of the latent space and N the number of pixels. We tested the statistical difference between the errors using two-sample t-tests. Figure 1A shows the reconstruction error of the latent variables given a generated image. As we can see, FAVI has a remarkably lower latent reconstruction error than reverse KL VI (p < 0.001). The latent error of ALI is slightly smaller than the error of reverse KL VI (p < 0.001). The superior performance of FAVI could have been expected, since FAVI is trained on generated images while the reverse KL method is trained directly on real data. However, FAVI also has a slightly lower and less variable observable reconstruction error (p < 0.05), as can be seen in Fig. 1B. Conversely, ALI has a very high reconstruction error. Figure 1C shows the two variational distributions of the first component of the latent vector for an example image.

5.2 Bayesian variational forecaster

Forecasting the future of a dynamical system based on past noisy measurements and a system of dynamic equations is crucial for many scientific applications West (1996). The most well-known of these applications is arguably weather forecasting Gneiting and Raftery (2005). FAVI is particularly appropriate for dynamic forecasting problems for three main reasons. First, in these problems the generator is known with good accuracy, and this benefits approaches like FAVI where the training samples are drawn from the generator. Second, it is often difficult to obtain analytic expressions for the probability densities of the dynamic and the noise models. Third, forecasting highly benefits from the marginalization of nuisance variables and unknown parameters Gneiting and Raftery (2005). We validated our FAVI forecaster on a simulated dataset. We generated chaotic time series using a very simplified model of atmospheric convection: the Lorenz dynamical system Lorenz (1963). The system is given by the following differential equations:

ẋ = σ (y − x) ,    ẏ = x (ρ − z) − y ,    ż = x y − β z ,

where the dot denotes a derivative with respect to time. In our case, the task is to estimate the probability of the value of x at a future time point T given a set of noise-corrupted observations s_1, ..., s_N of x at past time points t_1 < ... < t_N < T, where

s_j = x(t_j) + ε_j ,     (14)

with ε_j denoting independent observation noise. Note that the variables y and z are not observed and need to be marginalized out. The graphical models of the complete and marginalized joint are given in Fig. 2A. The FAVI loss is given by:

L_FAVI[q] = - E[ log q( x(T) | s_1, ..., s_N ) ] ,     (15)

where the expectation is taken over trajectories of the Lorenz system and the corresponding noisy observations, thereby implicitly marginalizing over y and z.

We parametrized the variational posterior q( x(T) | s_1, ..., s_N ) using a dilated convolutional neural network Yu and Koltun (2015) with a kernel mixture network output Ambrogioni et al. (2017b). The details of the architecture are given in Appendix B.
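The generation of FAVI training pairs for this problem can be sketched as follows (the step size, trajectory length, noise level and the Euler integrator are illustrative assumptions rather than the settings used in our experiments): each pair consists of noisy observations of x on a past window and the clean value of x at the forecast time.

    import numpy as np

    def lorenz_trajectory(n_steps, dt=0.01, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
        state = 5.0 * np.random.randn(3)
        traj = np.empty((n_steps, 3))
        for i in range(n_steps):
            x, y, z = state
            state = state + dt * np.array([sigma * (y - x),
                                           x * (rho - z) - y,
                                           x * y - beta * z])   # Euler step of the Lorenz system
            traj[i] = state
        return traj

    def favi_training_pair(obs_len=100, horizon=50, noise_std=1.0):
        traj = lorenz_trajectory(obs_len + horizon)
        observations = traj[:obs_len, 0] + noise_std * np.random.randn(obs_len)   # s_j = x(t_j) + eps_j
        target = traj[obs_len + horizon - 1, 0]                                   # x(T), the regression target
        return observations, target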

5.2.1 Results

We compared the performance of our variational Bayesian forecaster against the extended Kalman filter (EKF), one of the most popular off-the-shelf dynamic forecasting methods Evensen (2009). Specifically, we used the EKF to obtain the joint posterior probability density of the state variables at the last observed time point given the observations. By construction of the EKF approximation, this probability is a multivariate normal distribution. We then made a forecast by numerically integrating a large number of time series from the last observed time point to the forecast time T, where the initial conditions were sampled from the EKF posterior density. Figure 2B shows the forecast of a randomly sampled example trial together with the ground truth. The predictive distribution of the Bayesian variational forecaster tightly tracks the ground truth. Interestingly, the variational posterior bifurcates into the two possible 'wings' of the Lorenz attractor. For each validation trial, the performances of the EKF and of the variational Bayesian forecaster were quantified as the probability assigned to a symmetric interval centered around the ground truth. In the EKF case this probability was obtained by counting the number of integrated samples falling inside the interval and dividing by the total number of samples, while in the case of the Bayesian variational forecaster the probability was obtained by integrating the variational posterior probability density over the interval. Figure 2C shows the scatter plot of these probabilities for 500 validation trials. On average, the performance of the variational Bayesian forecaster is markedly higher than the performance of the EKF.

Figure 2: A. The total and marginalized generative models for the Bayesian variational forecaster. B. EKF (top panel) and variational (bottom panel) forecasts of a time series sampled from the Lorenz system. The blue dots denote the noise-corrupted observations. C. Forecasting performance on the Lorenz dynamical system. The blue dots are individual simulated trials and the red dot denotes the mean (center of mass).

5.3 Bayesian variational meta-classifier

We now introduce a real-world application that showcases the flexibility and scalability of FAVI when the real generative model is unknown. Our aim is to construct a Bayesian meta-classifier as an amortized variational approximation of the Bayes optimal classifier of an ensemble. Conventional variational methods are not suited for this task, as they would need to introduce a variational distribution over the potentially infinite and unstructured model space and explicitly marginalize over the resulting posterior. Furthermore, the model likelihood is very often non-differentiable and even impossible to evaluate in closed form. The lack of differentiability would rule out adversarial variational methods. We begin by giving a brief introduction to ensemble methods and Bayes optimal classifiers.

5.3.1 Bayesian ensembles

Figure 3: A. The total and marginalized generative models for the Bayesian meta-classifier. B. The classification accuracy of the Bayes variational meta-classifier versus three common alternatives on three data sets.

In a classification task the aim is to estimate the probability of the target class assignments t given a set of predictors x. In an ensemble learning setting we assume that the classification task is sampled from a predefined family of classification models {m_k}. In our notation, two models with the same parametric form but different parameter values count as different models. An ensemble classifier has the following form:

p(t̃ | x̃, D) = Σ_k w_k p(t̃ | x̃, m_k) ,     (16)

where x̃ denotes a new vector of predictors, t̃ denotes the corresponding label and D denotes the training data. Different ensemble models use different techniques for setting the weights w_k. The optimal way of setting the weights can be obtained formally using Bayes' rule. The posterior probability of each model given the training data is given by:

p(m_k | D) = p(D | m_k) p(m_k) / Σ_j p(D | m_j) p(m_j) ,     (17)

where D = {(x_i, t_i)} is a training set of predictors and target class assignments. Assuming that we know the prior over the family of classification models, the optimal solution to the classification problem is given by marginalizing the posterior distribution over all models Mitchell (1997). This is known as the Bayesian optimal classifier:

p(t̃ | x̃, D) = Σ_k p(t̃ | x̃, m_k) p(m_k | D) .     (18)

In practice, computing the Bayesian optimal classifier is intractable as it involves a sum (or an integral) over the whole (usually infinite) ensemble of models.
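The construction of Eqs. 16-18 can be made concrete with a deliberately tiny finite ensemble, for which the marginalization is still exact (the three logistic-regression models and the data below are purely illustrative and unrelated to the ensemble of Appendix D):

    import numpy as np

    def model_prob(t, x, w):
        p1 = 1.0 / (1.0 + np.exp(-x @ w))            # p(t = 1 | x, m) for a logistic model with weights w
        return p1 if t == 1 else 1.0 - p1

    models = [np.array([2.0, 0.0]), np.array([0.0, 2.0]), np.array([-2.0, 1.0])]
    prior = np.full(len(models), 1.0 / len(models))   # uniform prior over the ensemble

    D = [(np.array([1.0, 0.5]), 1), (np.array([-1.0, 0.2]), 0)]   # training data {(x_i, t_i)}

    likelihood = np.array([np.prod([model_prob(t, x, w) for x, t in D]) for w in models])
    posterior = prior * likelihood / np.sum(prior * likelihood)   # Eq. 17

    x_new = np.array([0.3, -0.7])
    p_bayes = np.sum(posterior * np.array([model_prob(1, x_new, w) for w in models]))   # Eq. 18

With more than a handful of models the sum over the ensemble becomes infeasible, which is exactly the computation that the variational meta-classifier amortizes.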

5.3.2 Variational meta-classifier

A Bayesian variational meta-classifier can be obtained by approximating the Bayesian optimal classifier using FAVI. The model is amortized with respect to whole training sets consisting of feature/label pairs, which are assumed to be generated by one (and only one) of the models in the ensemble. The resulting amortized posterior model is a meta-classifier, as it takes as input a training set and outputs the predictive distribution over the label of an arbitrary new data point. The forward amortized loss is given by:

L_FAVI[q] = - E_{p(t̃, x̃, D)}[ log q(t̃ | x̃, D) ] ,

where

p(t̃, x̃, D) = Σ_k p(m_k) p(D | m_k) p(x̃) p(t̃ | x̃, m_k) .

The graphical models of both the total and the marginalized joint are shown in Fig. 3A. We trained an RNN using the FAVI loss in order to approximate the predictive distribution p(t̃ | x̃, D). Our variational posterior is given by:

q(t̃ | x̃, D) = f( x̃, h(D) ) ,     (19)

where h(D) is the final state of a recurrent architecture that has received as input the training pairs (x_i, t_i) in D, and f maps this state together with the new predictors x̃ to a distribution over the label. The details of our RNN architecture are given in Appendix C.
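A compressed sketch of the resulting training loop is given below; the random linear ensemble, the GRU summarizer and all hyperparameters are illustrative assumptions and do not correspond to the architecture of Appendix C. Each batch element is a small classification task drawn from the ensemble, and the cross-entropy term is a Monte Carlo estimate of the FAVI loss -E[ log q(t̃ | x̃, D) ].

    import torch

    n_feat, n_train = 5, 20
    gru = torch.nn.GRU(input_size=n_feat + 1, hidden_size=64, batch_first=True)
    readout = torch.nn.Linear(64 + n_feat, 1)
    opt = torch.optim.Adam(list(gru.parameters()) + list(readout.parameters()), lr=1e-3)

    def sample_task(batch):
        w = torch.randn(batch, n_feat, 1)                        # one random linear model per task
        x = torch.randn(batch, n_train + 1, n_feat)              # training set plus one query point
        t = torch.bernoulli(torch.sigmoid(x @ w)).squeeze(-1)    # labels generated by the sampled model
        return x, t

    for step in range(2000):
        x, t = sample_task(64)
        pairs = torch.cat([x[:, :n_train], t[:, :n_train, None]], dim=-1)
        _, h = gru(pairs)                                         # summarize the training set D
        logits = readout(torch.cat([h[-1], x[:, n_train]], dim=-1)).squeeze(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, t[:, n_train])
        opt.zero_grad(); loss.backward(); opt.step()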

5.3.3 Results

We trained a Bayesian variational meta-classifier on the ensemble of generative models described in Appendix D. The network was trained on binary classification problems with a fixed number of predictors, and the Chainer deep learning framework Tokui et al. (2015) was used for model training. After training, the model was tested separately on three public real-world datasets: the Boston house-prices dataset, the diabetes dataset and the breast cancer Wisconsin dataset Harrison and Rubinfeld (1978); Efron et al. (2004); Street et al. (1993). In all datasets only the first predictors, up to the number used during training, were included. The Boston dataset is a regression problem, but we converted it into a classification problem by assigning one label to output values below the median and the other label to the remaining values. The datasets contain 506, 442 and 569 data points, respectively. However, in order to reliably evaluate the model performance on small data, in each dataset we sampled small data subsets of varying length at random, and the model was tested by making a prediction for the next held-out sample. The sampling and testing procedure was repeated several times with different re-samplings of the full dataset and the model performance scores were averaged. The model performance was compared to three other models: random forest, AdaBoost and decision trees Freund et al. (1999); Breiman (2001); Quinlan (1986); Dietterich (2000). Our experiments show that the Bayesian variational meta-classifier is competitive with these ensemble approaches, achieving the best performance on the diabetes and breast cancer datasets (Fig. 3B). On the house-prices dataset the Bayesian variational meta-classifier is competitive for small training sets, but its performance degrades as the number of training samples increases. This decline in performance is likely to be caused by the limitations of our recurrent architecture. Note that the variational meta-classifier is applied to each dataset without any further training, while the other methods are trained separately on each dataset.

6 Conclusions

In this paper we introduced a likelihood-free variational method based on the minimization of the forward KL divergence between the model joint distribution and a factorized variational joint distribution. We focused our exposition on variational marginalization problems where a Bayesian predictive distribution is obtained by marginalizing over a large space of latent variables.

References