Wasserstein Variational Inference

05/29/2018 ∙ by Luca Ambrogioni, et al. ∙ Radboud Universiteit, University of Amsterdam

This paper introduces Wasserstein variational inference, a new form of approximate Bayesian inference based on optimal transport theory. Wasserstein variational inference uses a new family of divergences that includes both f-divergences and the Wasserstein distance as special cases. The gradients of the Wasserstein variational loss are obtained by backpropagating through the Sinkhorn iterations. This technique results in a very stable likelihood-free training method that can be used with implicit distributions and probabilistic programs. Using the Wasserstein variational inference framework, we introduce several new forms of autoencoders and test their robustness and performance against existing variational autoencoding techniques.




1 Introduction

Variational Bayesian inference is gaining a central role in machine learning. Modern stochastic variational techniques can be easily implemented using differentiable programming frameworks Hoffman et al. (2013); Ranganath et al. (2014); Rezende et al. (2014). As a consequence, complex Bayesian inference is becoming almost as user friendly as deep learning Kucukelbir et al. (2017); Tran et al. (2016). This is in sharp contrast with older variational methods, which required model-specific mathematical derivations and imposed strong constraints on the possible family of models and variational distributions. Given the speed of this transition it is not surprising that modern variational inference research is still influenced by some legacy effects from the days when analytical tractability was the main concern. One of the most salient examples of this is the central role of the (reverse) KL divergence Blei et al. (2017); Zhang et al. (2017). While several other divergence measures have been suggested Li and Turner (2016); Ranganath et al. (2016); Dieng et al. (2017); Bamler et al. (2017), the reverse KL divergence still dominates both research and applications. Recently, optimal transport divergences such as the Wasserstein distance Villani (2003); Peyré G. (2018) have gained substantial popularity in the generative modeling literature as they can be shown to be well-behaved in several situations where the KL divergence is either infinite or undefined Arjovsky et al. (2017); Gulrajani et al. (2017); Genevay et al. (2018); Montavon et al. (2016). For example, the distribution of natural images is thought to span a sub-manifold of the original pixel space Arjovsky et al. (2017). In these situations Wasserstein distances are considered to be particularly appropriate because they can be used for fitting degenerate distributions that cannot be expressed in terms of densities Arjovsky et al. (2017).

In this paper we introduce the use of optimal transport methods in variational Bayesian inference. To this end, we define the new c-Wasserstein family of divergences, which includes both Wasserstein metrics and all f-divergences (which in turn include both the forward and the reverse KL) as special cases. Using this family of divergences we introduce the new framework of Wasserstein variational inference, which exploits the celebrated Sinkhorn iterations Sinkhorn and Knopp (1967); Cuturi (2013) and automatic differentiation. Wasserstein variational inference provides a stable gradient-based black-box method for solving Bayesian inference problems even when the likelihood is intractable and the variational distribution is implicit Huszár (2017); Tran et al. (2017). Importantly, as opposed to most other implicit variational inference methods Huszár (2017); Tran et al. (2017); Dumoulin et al. (2017); Mescheder et al. (2017), our approach does not rely on potentially unstable adversarial training Arjovsky and Bottou (2017).

1.1 Background on joint-contrastive variational inference

We start by briefly reviewing the framework of joint-contrastive variational inference Dumoulin et al. (2017); Huszár (2017). For notational convenience we will express distributions in terms of their densities. Note however that those densities could be degenerate. For example, the density of a discrete distribution can be expressed in terms of delta functions. The posterior distribution of the latent variable $z$ given the observed data $x$ is

$$p(z \mid x) = \frac{p(z, x)}{p(x)}.$$

While the joint probability $p(z, x)$ is usually tractable, the evaluation of $p(x)$ often involves an intractable integral or summation. The central idea of variational Bayesian inference is to minimize a divergence functional between the intractable posterior $p(z \mid x)$ and a tractable parametrized family of variational distributions. This form of variational inference is sometimes referred to as posterior-contrastive. Conversely, in joint-contrastive inference the divergence to minimize is defined between two structured joint distributions. For example, using the reverse KL we have the following cost functional:

$$D_{KL}\big(q(x, z) \,\|\, p(x, z)\big) = \mathbb{E}_{q(x, z)}\left[\log \frac{q(x, z)}{p(x, z)}\right],$$

where $q(x, z) = q(z \mid x)\, k(x)$ is the product between the variational posterior and the sampling distribution of the data. Usually $k(x)$ is approximated as the re-sampling distribution of a finite training set, as in the case of variational autoencoders (VAE) Kingma and Welling (2013). The advantage of this joint-contrastive formulation is that it does not require the evaluation of the intractable distribution $p(z \mid x)$. Joint-contrastive variational inference can be seen as a generalization of amortized inference Huszár (2017).

1.2 Background on optimal transport

Intuitively speaking, optimal transport divergences quantify the distance between two probability distributions as the cost of transporting probability mass from one to the other. Let $\Gamma[p, q]$ be the set of all bivariate probability measures on the product space $X \times X$ whose marginals are $p$ and $q$ respectively. An optimal transport divergence is defined by the following optimization:

$$W_c(p, q) = \inf_{\gamma \in \Gamma[p, q]} \int c(x_1, x_2)\, \mathrm{d}\gamma(x_1, x_2), \tag{2}$$

where $c(x_1, x_2)$ is the cost of transporting probability mass from $x_1$ to $x_2$. When the cost is a metric function the resulting divergence is a proper distance and it is usually referred to as the Wasserstein distance. We will denote the Wasserstein distance as $W(p, q)$.
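As a small illustration of the optimization in Eq. 2, consider two empirical distributions on the real line with the same number of equally weighted samples and the cost $c(x_1, x_2) = |x_1 - x_2|$ (these choices are assumptions made for the example). In this special case an optimal coupling can be taken to be a one-to-one matching, and matching the sorted samples is optimal. The sketch below compares a brute-force search over matchings with the sorting shortcut; the function names are illustrative.

```python
import itertools
import numpy as np

def transport_cost_bruteforce(xs, ys):
    """Minimum average |x - y| over all one-to-one matchings (small n only)."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    n = len(xs)
    return min(
        float(np.mean(np.abs(xs - ys[list(perm)])))
        for perm in itertools.permutations(range(n))
    )

def wasserstein_1d(xs, ys):
    """Same quantity via sorting; valid for equal-size samples in one dimension."""
    xs = np.sort(np.asarray(xs, dtype=float))
    ys = np.sort(np.asarray(ys, dtype=float))
    return float(np.mean(np.abs(xs - ys)))

xs = [0.0, 2.0, 5.0]
ys = [1.0, 6.0, 2.5]
cost_bruteforce = transport_cost_bruteforce(xs, ys)
cost_sorted = wasserstein_1d(xs, ys)
```

The brute-force search scales factorially and is only meant to make the infimum over couplings concrete.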

Solving the optimization problem in Eq. 2 exactly has super-cubic computational complexity. Recent work showed that this complexity can be greatly reduced by adopting entropic regularization Cuturi (2013). We begin by defining a new set of joint distributions:

$$U_\epsilon[p, q] = \left\{ \gamma \in \Gamma[p, q] \;\middle|\; D_{KL}\big(\gamma(x, y) \,\|\, p(x)\, q(y)\big) \le \epsilon^{-1} \right\}.$$

These distributions are characterized by having the mutual information between the two variables bounded by the regularization parameter $\epsilon^{-1}$. Using this family of distributions we can define the entropy regularized optimal transport divergence:

$$W_{c, \epsilon}(p, q) = \inf_{u \in U_\epsilon[p, q]} \int c(x_1, x_2)\, \mathrm{d}u(x_1, x_2).$$

This regularization turns the optimal transport into a strictly convex problem. When $p$ and $q$ are discrete distributions the regularized optimal transport cost can be efficiently obtained using the Sinkhorn iterations Sinkhorn and Knopp (1967); Cuturi (2013). The $\epsilon$-regularized optimal transport divergence is then given by:

$$W_{c, \epsilon}(p, q) = \lim_{t \to \infty} S_t^{\epsilon}[p, q, c],$$

where the function $S_t^{\epsilon}[p, q, c]$ gives the output of the $t$-th Sinkhorn iteration. The pseudocode of the Sinkhorn iterations is given in Algorithm 1. Note that all the operations in this algorithm are differentiable.

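procedure Sinkhorn(C, t, ε)
    K = exp(-C/ε);  n, m = shape(C)
    r = ones(n, 1)/n;  c = ones(m, 1)/m;  u_0 = r;  τ = 0
    while τ < t do
        a = K^T u_τ                          ▷ Juxtaposition denotes matrix product
        b = c / a                            ▷ "/" denotes entrywise division
        u_{τ+1} = r / (K b);  τ = τ + 1
    v = c / (K^T u_t);  S = sum(u_t * ((K * C) v))    ▷ "*" denotes entrywise product
    return S
Algorithm 1: Sinkhorn iterations. C: cost matrix, t: number of iterations, ε: regularization strength.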
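A minimal NumPy sketch of the Sinkhorn iterations is given below. It follows the standard scaling updates of Cuturi (2013) and assumes uniform weights on both sample sets; the function name `sinkhorn` and the toy cost matrices are illustrative choices rather than part of the original algorithm listing.

```python
import numpy as np

def sinkhorn(C, t, eps):
    """Transport cost of the entropy-regularized plan after t Sinkhorn iterations.

    C: (n, m) cost matrix, t: number of iterations, eps: regularization strength.
    Uniform weights on rows and columns are an assumption of this sketch.
    """
    C = np.asarray(C, dtype=float)
    n, m = C.shape
    K = np.exp(-C / eps)            # Gibbs kernel
    r = np.full(n, 1.0 / n)         # row marginal
    c = np.full(m, 1.0 / m)         # column marginal
    u = r.copy()
    for _ in range(t):
        v = c / (K.T @ u)           # rescale to match the column marginal
        u = r / (K @ v)             # rescale to match the row marginal
    v = c / (K.T @ u)
    # Cost of the plan diag(u) K diag(v): sum_jk u_j K_jk C_jk v_k
    return float(np.sum(u * ((K * C) @ v)))

# Two identical two-point sets: the plan concentrates on the zero-cost diagonal.
cost_same = sinkhorn(np.array([[0.0, 1.0], [1.0, 0.0]]), t=50, eps=0.05)
# A constant cost matrix: every plan with unit mass has cost 1.
cost_const = sinkhorn(np.ones((2, 2)), t=10, eps=0.5)
```

For very small `eps` the kernel `K` underflows, which is one concrete form of the instability of the unregularized limit; a log-domain implementation is a common remedy.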

2 Wasserstein variational inference

We can now introduce the new framework of Wasserstein variational inference for general-purpose approximate Bayesian inference. We begin by introducing a new family of divergences that includes both optimal transport divergences and f-divergences as special cases. Subsequently, we develop a black-box and likelihood-free variational algorithm based on automatic differentiation through the Sinkhorn iterations.

2.1 c-Wasserstein divergences

Traditional divergence measures such as the KL divergence depend explicitly on the distributions $p$ and $q$. Conversely, optimal transport divergences depend on $p$ and $q$ only through the constraints of an optimization problem. We will now introduce the family of c-Wasserstein divergences that generalize both forms of dependencies. A c-Wasserstein divergence has the following form:

$$W_C(p, q) = \inf_{\gamma \in \Gamma[p, q]} \int C^{p,q}(x_1, x_2)\, \mathrm{d}\gamma(x_1, x_2), \tag{5}$$

where the real-valued functional $C^{p,q}(x_1, x_2)$ depends both on the two scalars $x_1$ and $x_2$ and on the two distributions $p$ and $q$. Note that we are writing this dependency in terms of the densities only for notational convenience and that this dependency should be interpreted in terms of distributions. The cost functional is assumed to respect the following requirements:

1. $C^{p,p}(x_1, x_2) \ge 0$ for all $x_1, x_2 \in \mathrm{supp}(p)$;
2. $C^{p,p}(x, x) = 0$ for all $x \in \mathrm{supp}(p)$;
3. $\mathbb{E}_\gamma\big[C^{p,q}(x_1, x_2)\big] \ge 0$ for all $\gamma \in \Gamma[p, q]$,

where $\mathrm{supp}(p)$ denotes the support of the distribution $p$. From these requirements we can derive the following theorem:

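Theorem 1.

The functional $W_C(p, q)$ is a (pseudo-)divergence, meaning that $W_C(p, q) \ge 0$ for all $p$ and $q$ and $W_C(p, p) = 0$ for all $p$.

Proof. From property 1 and property 2 it follows that, when $p$ is equal to $q$, $C^{p,p}(x_1, x_2)$ is a non-negative function of $x_1$ and $x_2$ that vanishes when $x_1 = x_2$. In this case, the optimization in Eq. 5 is solved by the diagonal transport $\gamma(x_1, x_2) = p(x_1)\, \delta(x_1 - x_2)$. In fact:

$$W_C(p, p) = \int C^{p,p}(x_1, x_2)\, p(x_1)\, \delta(x_1 - x_2)\, \mathrm{d}x_1 \mathrm{d}x_2 = \int C^{p,p}(x_1, x_1)\, p(x_1)\, \mathrm{d}x_1 = 0.$$

This is a global minimum since property 3 implies that $W_C(p, q)$ is always non-negative. ∎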

All optimal transport divergences are part of the c-Wasserstein family, where $C^{p,q}(x_1, x_2)$ reduces to a non-negative valued function $c(x_1, x_2)$ independent from $p$ and $q$.

Proving property 3 for an arbitrary cost functional can be a challenging task. The following theorem provides a criterion that is often easier to verify:

Theorem 2.

Let $f: \mathbb{R} \to \mathbb{R}$ be a convex function such that $f(1) = 0$. The cost functional $C^{p,q}(x_1, x_2) = f(g(x_1, x_2))$ respects property 3 when $\mathbb{E}_\gamma[g(x_1, x_2)] = 1$ for all $\gamma \in \Gamma[p, q]$.

Proof. The result follows directly from Jensen’s inequality. ∎
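The Jensen step can be written out explicitly. The notation here is assumed for the illustration: the cost functional is written as $C^{p,q}(x_1, x_2) = f(g(x_1, x_2))$ with $f$ convex and $f(1) = 0$, and $\gamma$ is any coupling with marginals $p$ and $q$ for which $\mathbb{E}_\gamma[g(x_1, x_2)] = 1$. Then:

$$\mathbb{E}_\gamma\big[C^{p,q}(x_1, x_2)\big] = \mathbb{E}_\gamma\big[f(g(x_1, x_2))\big] \ge f\big(\mathbb{E}_\gamma[g(x_1, x_2)]\big) = f(1) = 0.$$

Since the bound holds for every admissible coupling, it also holds for the infimum over couplings, which is the non-negativity requirement on the expected cost.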

2.2 Stochastic Wasserstein variational inference

We can now introduce the general framework of Wasserstein variational inference. The loss functional is a c-Wasserstein divergence between $p(x, z)$ and $q(x, z)$:

$$\mathcal{L}_C[p, q] = W_C\big(p(z, x), q(z, x)\big) = \inf_{\gamma \in \Gamma[p, q]} \int C^{p,q}(x_1, z_1; x_2, z_2)\, \mathrm{d}\gamma(x_1, z_1; x_2, z_2). \tag{7}$$

From Theorem 1 it follows that this variational loss is always minimized when $p$ is equal to $q$. Note that we are allowing members of the c-Wasserstein divergence family to be pseudo-divergences, meaning that $\mathcal{L}_C[p, q]$ could be $0$ even if $p \neq q$. It is sometimes convenient to work with pseudo-divergences when some features of the data are not deemed to be relevant.

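We can now derive a black-box Monte Carlo estimate of the gradient of Eq. 7 that can be used together with gradient-based stochastic optimization methods Fouskakis and Draper (2002). A Monte Carlo estimator of Eq. 7 can be obtained by computing the discrete c-Wasserstein divergence between two empirical distributions:

$$\mathcal{L}_C[p_n, q_n] = \inf_{\gamma} \sum_{j, k} C^{p,q}\big(x_1^{(j)}, z_1^{(j)}; x_2^{(k)}, z_2^{(k)}\big)\, \gamma\big(x_1^{(j)}, z_1^{(j)}; x_2^{(k)}, z_2^{(k)}\big), \tag{8}$$

where $(x_1^{(j)}, z_1^{(j)})$ and $(x_2^{(k)}, z_2^{(k)})$ are sampled from $p(x, z)$ and $q(x, z)$ respectively. In the case of the Wasserstein distance, we can show that this estimator is asymptotically unbiased: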

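Theorem 3.

Let $W(p_n, q_n)$ be the Wasserstein distance between two empirical distributions $p_n$ and $q_n$. For $n$ tending to infinity, there is a positive number $s$ such that

$$\mathbb{E}_{pq}\big[W(p_n, q_n)\big] \lesssim W(p, q) + n^{-1/s}.$$

Proof. Using the triangle inequality and the linearity of the expectation we obtain:

$$\mathbb{E}_{pq}\big[W(p_n, q_n)\big] \le \mathbb{E}_p\big[W(p_n, p)\big] + W(p, q) + \mathbb{E}_q\big[W(q, q_n)\big].$$

In Weed and Bach (2017) it was proven that for any distribution $u$:

$$\mathbb{E}_u\big[W(u_n, u)\big] \lesssim n^{-1/s_u}$$

when $s_u$ is larger than the upper Wasserstein dimension (see definition in Weed and Bach (2017)). The result follows with $s = \max(s_p, s_q)$. ∎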

Unfortunately the Monte Carlo estimator is biased for finite values of $n$. In order to eliminate the bias when $p$ is equal to $q$, we use the following modified loss:

$$\tilde{\mathcal{L}}_C[p_n, q_n] = \mathcal{L}_C[p_n, q_n] - \frac{\mathcal{L}_C[p_n, p_n] + \mathcal{L}_C[q_n, q_n]}{2}.$$

It is easy to see that the expectation of this new loss is zero when $p$ is equal to $q$. Furthermore:

$$\lim_{n \to \infty} \tilde{\mathcal{L}}_C[p_n, q_n] = \mathcal{L}_C[p, q].$$

As we discussed in Section 1.2, the entropy-regularized version of the optimal transport cost in Eq. 8 can be approximated by truncating the Sinkhorn iterations. Importantly, the Sinkhorn iterations are differentiable and consequently we can compute the gradient of the loss using automatic differentiation Genevay et al. (2018). The approximated gradient of the $\epsilon$-regularized loss can be written as

$$\nabla \mathcal{L}_C[p_n, q_n] \approx \nabla S_t^{\epsilon}\big[p_n, q_n, C^{p,q}\big],$$

where the function $S_t^{\epsilon}[p_n, q_n, C^{p,q}]$ is the output of $t$ steps of the Sinkhorn algorithm with regularization $\epsilon$ and cost function $C^{p,q}$. Note that the cost is a functional of $p$ and $q$ and consequently the gradient contains the term $\nabla C^{p,q}$. Also note that this approximation converges to the real gradient of Eq. 7 for $n \to \infty$ and $\epsilon \to 0$ (however the Sinkhorn algorithm becomes unstable when $\epsilon \to 0$).
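The bias-corrected loss combined with truncated Sinkhorn iterations can be sketched as follows. This is a forward-pass illustration only: NumPy does not provide automatic differentiation, so obtaining gradients would require porting the same operations to an autodiff framework. The squared Euclidean cost, the default values of `t` and `eps`, the use of the same sample set in the self-terms, and all function names are assumptions made for this example.

```python
import numpy as np

def sinkhorn(C, t, eps):
    """Transport cost of the regularized plan after t Sinkhorn iterations (uniform weights)."""
    n, m = C.shape
    K = np.exp(-C / eps)
    r = np.full(n, 1.0 / n)
    c = np.full(m, 1.0 / m)
    u = r.copy()
    for _ in range(t):
        v = c / (K.T @ u)
        u = r / (K @ v)
    v = c / (K.T @ u)
    return float(np.sum(u * ((K * C) @ v)))

def pairwise_sq_dist(a, b):
    """Squared Euclidean distances between the rows of a and the rows of b."""
    diff = a[:, None, :] - b[None, :, :]
    return np.sum(diff ** 2, axis=-1)

def debiased_loss(samples_p, samples_q, t=50, eps=1.0):
    """Cross term minus the average of the two self terms."""
    l_pq = sinkhorn(pairwise_sq_dist(samples_p, samples_q), t, eps)
    l_pp = sinkhorn(pairwise_sq_dist(samples_p, samples_p), t, eps)
    l_qq = sinkhorn(pairwise_sq_dist(samples_q, samples_q), t, eps)
    return l_pq - 0.5 * (l_pp + l_qq)

samples_p = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
samples_q = samples_p + 2.0           # same shape, shifted by (2, 2)
loss_same = debiased_loss(samples_p, samples_p)
loss_shifted = debiased_loss(samples_p, samples_q)
```

With identical sample sets the three terms coincide and the corrected loss vanishes, while the shifted set gives a strictly positive value.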

3 Examples of c-Wasserstein divergences

We will now introduce two classes of c-Wasserstein divergences that are suitable for deep Bayesian variational inference. Moreover, we will show that the KL divergence and all f-divergences are part of the c-Wasserstein family.

3.1 A metric divergence for latent spaces

In order to apply optimal transport divergences to a Bayesian variational problem we need to assign a metric, or more generally a transport cost, to the latent space of the Bayesian model. The geometry of the latent space should depend on the geometry of the observable space since differences in the latent space are only meaningful as far as they correspond to differences in the observables. The simplest way to assign a geometric transport cost to the latent space is to pull back a metric function from the observable space:

$$C^{p}_{PB}(z_1, z_2) = d_x\big(g_p(z_1), g_p(z_2)\big),$$

where $d_x(x_1, x_2)$ is a metric function in the observable space and $g_p(z)$ is a deterministic function that maps $z$ to the expected value of $p(x \mid z)$. In our notation the subscript $p$ in $g_p$ denotes the fact that the distribution $p(x \mid z)$ and the function $g_p$ depend on a common set of parameters which are optimized during variational inference. The resulting pullback cost function is a proper metric if $g_p$ is a diffeomorphism (i.e. a differentiable map with differentiable inverse) Burago et al. (2001).
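As an illustration, a pullback cost matrix between two sets of latent samples can be computed by decoding both sets and measuring distances in the observable space. In the sketch below the Euclidean distance stands in for the observable metric and a fixed linear map stands in for the decoder mean; both are hypothetical choices for the example, not the networks used in the paper.

```python
import numpy as np

def pullback_cost(z1, z2, decoder):
    """Pairwise cost d_x(g(z1_j), g(z2_k)), with d_x taken to be Euclidean distance."""
    x1 = decoder(z1)                           # (n, observable_dim)
    x2 = decoder(z2)                           # (m, observable_dim)
    diff = x1[:, None, :] - x2[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))

# Hypothetical toy decoder: a fixed linear map from 2-D latents to 3-D observables.
A = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
decoder = lambda z: z @ A.T

z = np.array([[0.0, 0.0], [1.0, 0.0]])
C_pb = pullback_cost(z, z, decoder)
```

Because the toy decoder is linear and injective, the resulting cost is a metric on the latent space; for a non-injective decoder distinct latent points can have zero cost.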

3.2 Autoencoder divergences

Another interesting special case of c-Wasserstein divergence can be obtained by considering the distribution of the residuals of an autoencoder. Consider the case where the expected value of $q(z \mid x)$ is given by the deterministic function $h_q(x)$. We can define the latent autoencoder cost functional as the transport cost between the latent residuals of the two models:

$$C^{q}_{LA}(x_1, z_1; x_2, z_2) = d\big(z_1 - h_q(x_1),\, z_2 - h_q(x_2)\big),$$

where $d$ is a distance function. It is easy to check that this cost functional defines a proper c-Wasserstein divergence since it is non-negative valued and it is equal to zero when $p$ is equal to $q$ and $x_1, z_1$ are equal to $x_2, z_2$. Similarly, we can define the observable autoencoder cost functional as follows:

$$C^{p}_{OA}(x_1, z_1; x_2, z_2) = d\big(x_1 - g_p(z_1),\, x_2 - g_p(z_2)\big),$$

where again $g_p(z)$ gives the expected value of the generator. In the case of a deterministic generator, this expression reduces to

$$C^{p}_{OA}(x_1, z_1; x_2, z_2) = d\big(0,\, x_2 - g_p(z_2)\big).$$

Note that the transport optimization is trivial in this special case since the cost does not depend on $x_1$ and $z_1$. In this case, the resulting divergence is just the average reconstruction error:

$$\inf_{\gamma \in \Gamma[p, q]} \int d\big(0,\, x_2 - g_p(z_2)\big)\, \mathrm{d}\gamma = \mathbb{E}_{q(x, z)}\big[d\big(0,\, x - g_p(z)\big)\big].$$

As expected, this is a proper (pseudo-)divergence since it is non-negative valued and $x - g_p(z)$ is always equal to zero when $x$ and $z$ are sampled from $p(x, z)$.
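The two residual-based costs can be sketched as pairwise matrices between two batches of (observable, latent) pairs. The Euclidean distance between residuals and the toy scalar encoder and decoder below are assumptions made for the example; the function names are illustrative.

```python
import numpy as np

def latent_autoencoder_cost(x1, z1, x2, z2, encoder):
    """Pairwise distance between latent residuals z - encoder(x) of two batches."""
    r1 = z1 - encoder(x1)
    r2 = z2 - encoder(x2)
    return np.linalg.norm(r1[:, None, :] - r2[None, :, :], axis=-1)

def observable_autoencoder_cost(x1, z1, x2, z2, decoder):
    """Pairwise distance between observable residuals x - decoder(z) of two batches."""
    r1 = x1 - decoder(z1)
    r2 = x2 - decoder(z2)
    return np.linalg.norm(r1[:, None, :] - r2[None, :, :], axis=-1)

# Hypothetical deterministic generator and its exact inverse as the encoder.
decoder = lambda z: 2.0 * z
encoder = lambda x: 0.5 * x

z1 = np.array([[1.0], [2.0]])
x1 = decoder(z1)                      # first batch: generated data, zero residual
x2 = np.array([[3.0], [4.5]])         # second batch: data paired with latents z2
z2 = np.array([[1.0], [2.0]])

C_oa = observable_autoencoder_cost(x1, z1, x2, z2, decoder)
C_la_same = latent_autoencoder_cost(x1, z1, x1, z1, encoder)
```

With a deterministic generator every row of `C_oa` is the same, so the cost depends only on the second batch and the transport problem reduces to an average reconstruction error.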

3.3 f-divergences

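We can now show that all f-divergences are part of the c-Wasserstein family. Consider the following cost functional:

$$C^{p,q}_f(x_1, x_2) = f\left(\frac{p(x_2)}{q(x_2)}\right),$$

where $f$ is a convex function such that $f(1) = 0$. From Theorem 2 it follows that this cost functional defines a valid c-Wasserstein divergence. We can now show that the c-Wasserstein divergence defined by this functional is the $f$-divergence defined by $f$. In fact

$$\inf_{\gamma \in \Gamma[p, q]} \int f\left(\frac{p(x_2)}{q(x_2)}\right) \mathrm{d}\gamma(x_1, x_2) = \int f\left(\frac{p(x_2)}{q(x_2)}\right) q(x_2)\, \mathrm{d}x_2 = \mathbb{E}_{q}\left[f\left(\frac{p(x)}{q(x)}\right)\right],$$

since $q(x_2)$ is the marginal of all $\gamma(x_1, x_2)$ in $\Gamma[p, q]$.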
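This can be checked numerically on a small discrete example. The sketch below uses $f(u) = u \log u$, which generates the KL divergence, and a cost that depends only on the second argument through the density ratio. The two distributions and the two couplings are arbitrary choices for illustration; any coupling with the right marginals should give the same transport cost.

```python
import numpy as np

def f_kl(u):
    """f(u) = u log u is convex with f(1) = 0 and generates KL(p || q)."""
    return u * np.log(u)

p = np.array([0.2, 0.5, 0.3])
q = np.array([0.4, 0.4, 0.2])

# Cost matrix that ignores the first argument: C[j, k] = f(p[k] / q[k]).
C = np.tile(f_kl(p / q), (len(p), 1))

# Two couplings with row marginal p and column marginal q.
gamma_independent = np.outer(p, q)
gamma_corner = np.array([[0.2, 0.0, 0.0],
                         [0.2, 0.3, 0.0],
                         [0.0, 0.1, 0.2]])

cost_independent = float(np.sum(gamma_independent * C))
cost_corner = float(np.sum(gamma_corner * C))
kl_pq = float(np.sum(p * np.log(p / q)))
```

Both couplings give the same cost, equal to the KL divergence, so the infimum over couplings is trivial for this family of cost functionals.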

4 Wasserstein variational autoencoders

We will now use the concepts developed in the previous sections in order to define a new form of autoencoder. VAEs are generative deep amortized Bayesian models where the parameters of both the probabilistic model and the variational model are learned by minimizing a joint-contrastive divergence Kingma and Welling (2013); Pu et al. (2016); Makhzani et al. (2015). Let $D_p$ and $D_q$ be parametrized probability distributions and $g_p(z)$ and $h_q(x)$ be the outputs of deep networks determining the parameters of these distributions. The probabilistic model (decoder) of a VAE has the following form:

$$p(z, x) = D_p\big(x \mid g_p(z)\big)\, p(z).$$

The variational model (encoder) is given by:

$$q(z, x) = D_q\big(z \mid h_q(x)\big)\, k(x).$$

We can define a large family of objective functions of VAEs by combining the cost functionals defined in the previous section. The general form is given by the following total autoencoder cost functional:

$$C^{p,q}_{w,f}(x_1, z_1; x_2, z_2) = w_1\, d_x(x_1, x_2) + w_2\, C^{p}_{PB}(z_1, z_2) + w_3\, C^{q}_{LA}(x_1, z_1; x_2, z_2) + w_4\, C^{p}_{OA}(x_1, z_1; x_2, z_2) + w_5\, C^{p,q}_{f}(x_1, z_1; x_2, z_2), \tag{16}$$

where $w$ is a vector of non-negative valued weights, $d_x(x_1, x_2)$ is a metric on the observable space and $f$ is a convex function.
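In a sample-based implementation, the total cost can be assembled as a weighted sum of pairwise cost matrices, one per component, before it is passed to the Sinkhorn iterations. The sketch below is a hypothetical helper: the function names, the ordering of the components, and the binary-string convention (borrowed from the model labels used in the experiments) are assumptions for illustration.

```python
import numpy as np

def weights_from_string(bits):
    """Map a binary string such as "1100" to the weight vector [1., 1., 0., 0.]."""
    if set(bits) - {"0", "1"}:
        raise ValueError("expected a string made of the characters 0 and 1")
    return np.array([float(b) for b in bits])

def total_cost(cost_matrices, weights):
    """Weighted sum of pairwise cost matrices that share the same shape."""
    weights = np.asarray(weights, dtype=float)
    if len(cost_matrices) != len(weights):
        raise ValueError("one weight is needed per cost matrix")
    if np.any(weights < 0):
        raise ValueError("weights must be non-negative")
    total = np.zeros_like(np.asarray(cost_matrices[0], dtype=float))
    for w, C in zip(weights, cost_matrices):
        total = total + w * np.asarray(C, dtype=float)
    return total

# Placeholder cost matrices standing in for four components of the cost.
terms = [k * np.ones((2, 2)) for k in (1.0, 2.0, 3.0, 4.0)]
C_1100 = total_cost(terms, weights_from_string("1100"))
C_1111 = total_cost(terms, weights_from_string("1111"))
```

Setting a weight to zero removes the corresponding component, which is how the different autoencoder variants in the experiments can be expressed with a single cost functional.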

5 Connections with related methods

In the previous sections we showed that variational inference based on f-divergences is a special case of Wasserstein variational inference. We now discuss several theoretical links with some recent variational methods.

5.1 Operator variational inference

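Wasserstein variational inference can be shown to be a special case of a generalized version of operator variational inference Ranganath et al. (2016). The (amortized) operator variational objective is defined as follows:

$$\mathcal{L}_{OP} = \sup_{f \in \mathcal{F}} \zeta\big(\mathbb{E}_{q(x, z)}[\mathcal{O}^{p,q} f]\big), \tag{17}$$

where $\mathcal{F}$ is a set of test functions and $\zeta(\cdot)$ is a positive valued function. The dual representation of the optimization problem in the c-Wasserstein loss (Eq. 5) is given by the following expression:

$$W_C(p, q) = \sup_{f \in L_C[p, q]} \Big[\mathbb{E}_{p(x, z)}[f(x, z)] - \mathbb{E}_{q(x, z)}[f(x, z)]\Big], \tag{18}$$

where

$$L_C[p, q] = \big\{ f \;\big|\; f(x_1, z_1) - f(x_2, z_2) \le C^{p,q}(x_1, z_1; x_2, z_2) \big\}.$$

Converting the expectation over $p$ to an expectation over $q$ using importance sampling, we obtain the following expression:

$$W_C(p, q) = \sup_{f \in L_C[p, q]} \mathbb{E}_{q(x, z)}\left[\left(\frac{p(x, z)}{q(x, z)} - 1\right) f(x, z)\right],$$

which has the same form as the operator variational loss in Eq. 17 with $\zeta(x) = x$ and $\mathcal{O}^{p,q} = p(x, z)/q(x, z) - 1$. Note that the fact that $\zeta$ is not positive valued is irrelevant since the optimum of Eq. 18 is always non-negative. This is a generalized form of operator variational loss where the functional family can now depend on $p$ and $q$. In the case of optimal transport divergences, where $C^{p,q}(x_1, z_1; x_2, z_2) = c(x_1, z_1; x_2, z_2)$, the resulting loss is a special case of the regular operator variational loss.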

5.2 Wasserstein autoencoders

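The recently introduced Wasserstein autoencoder uses a regularized optimal transport divergence between $p(x)$ and $k(x)$ in order to train a generative model Tolstikhin et al. (2018). The regularized loss has the following form:

$$\mathcal{L}_{WA} = \mathbb{E}_{q(x, z)}\big[c_x(x, g_p(z))\big] + \lambda\, D\big(p(z) \,\|\, q(z)\big),$$

where $c_x$ does not depend on $p$ and $q$ and $D(p(z) \,\|\, q(z))$ is an arbitrary divergence. This loss was not derived from a variational Bayesian inference problem. Instead, the Wasserstein autoencoder loss is derived as a relaxation of an optimal transport loss between $p(x)$ and $k(x)$:

$$\mathcal{L}_{WA} \approx W_{c_x}\big(p(x), k(x)\big).$$

When $D(p(z) \,\|\, q(z))$ is a c-Wasserstein divergence, we can show that $\mathcal{L}_{WA}$ is a Wasserstein variational inference loss and consequently that Wasserstein autoencoders are approximate Bayesian methods. In fact:

$$\mathbb{E}_{q(x, z)}\big[c_x(x, g_p(z))\big] + \lambda \inf_{\gamma \in \Gamma[p(z), q(z)]} \int C^{p,q}(z_1, z_2)\, \mathrm{d}\gamma(z_1, z_2) = \inf_{\gamma \in \Gamma[p, q]} \int \big[c_x(x_2, g_p(z_2)) + \lambda\, C^{p,q}(z_1, z_2)\big]\, \mathrm{d}\gamma(x_1, z_1; x_2, z_2).$$

In the original paper the regularization term is either the Jensen-Shannon divergence (optimized using adversarial training) or the maximum mean discrepancy (optimized using a reproducing kernel Hilbert space estimator). Our reformulation suggests another way of training the latent space using a metric optimal transport divergence and the Sinkhorn iterations.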

6 Experimental evaluation

We will now demonstrate experimentally the effectiveness and robustness of Wasserstein variational inference. We focused our analysis on variational autoencoding problems. We decided to use simple deep architectures and to avoid any form of structural and hyper-parameter optimization for three main reasons. First and foremost, our main aim is to show that Wasserstein variational inference works off-the-shelf without user tuning. Second, it allows us to run a large number of analyses and consequently to systematically investigate the performance of several variants of the Wasserstein autoencoder on several datasets. Finally, it minimizes the risk of inducing a bias that disfavors the baselines.

In our first experiment, we assessed the performance and the robustness of our Wasserstein variational autoencoder against a conventional VAE and ALI, a more recent (adversarial) likelihood-free alternative Dumoulin et al. (2017). We used the same neural architecture for all models. The generative models were parametrized by three-layered fully connected networks (100-300-500-1568) with ReLU nonlinearities in the hidden layers. Similarly, the variational models were parametrized by three-layered ReLU networks (784-500-300-100). The cost functional of our Wasserstein variational autoencoder (see Eq. 16) had the weights $w_1$, $w_2$, $w_3$ and $w_4$ different from zero. Conversely, in this experiment $w_5$ was set to zero, meaning that we did not use an f-divergence component. We refer to this model as 1111. We trained 1111 using a fixed number $t$ of Sinkhorn iterations and a fixed regularization strength $\epsilon$. We assessed the robustness of the methods by running several re-runs of the experiment for each method. In each of these re-runs, the parameters of the networks were re-initialized and the weights of the losses of 1111, VAE and ALI were randomly sampled from a fixed interval. We evaluated three performance metrics: 1) mean squared deviation in the latent space, 2) pixelwise mean squared reconstruction error in the image space and 3) sample quality estimated as the smallest Euclidean distance with an image in the validation set. The results are reported in Table 1. Our 1111 model outperforms both VAE and ALI in all metrics. Furthermore, the performance of 1111 with respect to all metrics is very stable to perturbations in the weights, as demonstrated by the small standard deviations. Note that the maximum error of 1111 is lower than the minimum errors of the other methods in four different comparisons. Figure 1 shows the reconstruction of several real images and some generated images for all methods in a randomly chosen run. In this run, the reconstructions from ALI collapsed onto only two digit classes. In our setup, this phenomenon was observed in all runs with the reconstructions collapsing on different digits. This explains the high observable reconstruction error of ALI in Table 1.

In our second experiment we tested several other forms of Wasserstein variational autoencoders on three different datasets. We denote different versions of our autoencoder with a binary string indicating which weights were set to zero and which to one. For example, we denote the purely metric version without autoencoder divergences as 1100. We also included two hybrid models, h-VAE and h-ALI, obtained by combining our loss (1111) with the VAE and the ALI losses. These methods are special cases of Wasserstein variational autoencoders with non-zero weight $w_5$ and where the function $f$ is chosen to give either the reverse KL divergence or the Jensen-Shannon divergence respectively. Note that this fifth component of the loss was not obtained from the Sinkhorn iterations. As can be seen in Table 2, most versions of the Wasserstein variational autoencoder perform better than both VAE and ALI on all datasets. The 0011 model has good reconstruction errors but significantly lower sample quality as it does not explicitly train the marginal distribution of $x$. Interestingly, the purely metric version 1100 has a small reconstruction error even though the cost functional is solely defined in terms of the marginals over $x$ and $z$. Also interestingly, the hybrid methods h-VAE and h-ALI perform well. This result is promising as it suggests that the Sinkhorn loss can be used for stabilizing adversarial methods.

| Method | Latent mean | Latent std | Latent min, max | Observable mean | Observable std | Observable min, max | Sample mean | Sample std | Sample min, max |
|---|---|---|---|---|---|---|---|---|---|
| ALI | 1.040 | 0.070 | 1.003, 1.307 | 0.168 | 0.047 | 0.105, 0.332 | 0.057 | 0.002 | 0.051, 0.059 |
| VAE | 3.670 | 3.630 | 0.965, 16.806 | 0.042 | 0.003 | 0.036, 0.049 | 0.241 | 0.124 | 0.033, 0.473 |
| 1111 | 0.877 | 0.026 | 0.811, 0.938 | 0.029 | 0.007 | 0.022, 0.060 | 0.041 | 0.002 | 0.035, 0.045 |

Table 1: Perturbation analysis on MNIST.

| Method | MNIST Latent | MNIST Observable | MNIST Sample | Fashion-MNIST Latent | Fashion-MNIST Observable | Fashion-MNIST Sample | Quick Sketch Latent | Quick Sketch Observable | Quick Sketch Sample |
|---|---|---|---|---|---|---|---|---|---|
| ALI | 1.0604 | 0.1419 | 0.0631 | 1.0179 | 0.1210 | 0.0564 | 1.0337 | 0.3477 | 0.1157 |
| VAE | 1.1807 | 0.0406 | 0.1766 | 1.7671 | 0.0214 | 0.0567 | 0.9445 | 0.0758 | 0.0687 |
| 1001 | 0.9256 | 0.0710 | 0.0448 | 0.9453 | 0.0687 | 0.0277 | 0.9777 | 0.1471 | 0.0654 |
| 0110 | 1.0052 | 0.0227 | 0.0513 | 1.4886 | 0.0244 | 0.0385 | 0.8894 | 0.0568 | 0.0743 |
| 0011 | 1.0030 | 0.0273 | 0.0740 | 1.0033 | 0.0196 | 0.0447 | 1.0016 | 0.0656 | 0.1204 |
| 1100 | 1.0145 | 0.0268 | 0.0483 | 1.3748 | 0.0246 | 0.0291 | 1.0364 | 0.0554 | 0.0736 |
| 1111 | 0.8991 | 0.0293 | 0.0441 | 0.9053 | 0.0258 | 0.0297 | 0.8822 | 0.0642 | 0.0699 |
| h-ALI | 0.8865 | 0.0289 | 0.0462 | 0.9026 | 0.0260 | 0.0300 | 0.8961 | 0.0674 | 0.0682 |
| h-VAE | 0.9007 | 0.0292 | 0.0442 | 0.9072 | 0.0227 | 0.0306 | 0.8983 | 0.0638 | 0.0677 |

Table 2: Detailed analysis on MNIST, Fashion-MNIST and Quick Sketch.

Figure 1: Observable reconstructions (A) and samples (B).

7 Conclusions

In this paper we showed that Wasserstein variational inference offers an effective and robust method for black-box (amortized) variational Bayesian inference. Importantly, Wasserstein variational inference is a likelihood-free method and can be used together with implicit variational distributions and differentiable variational programs Tran et al. (2017); Huszár (2017). These features make Wasserstein variational inference particularly suitable for probabilistic programming, where the aim is to combine declarative general purpose programming and automatic probabilistic inference.