Scalable-Wasserstein-Barycenter
https://arxiv.org/abs/2007.04462
The Wasserstein barycenter is a principled approach to representing the weighted mean of a given set of probability distributions, utilizing the geometry induced by optimal transport. In this work, we present a novel scalable algorithm to approximate Wasserstein barycenters, aimed at high-dimensional applications in machine learning. Our proposed algorithm is based on the Kantorovich dual formulation of the 2-Wasserstein distance as well as a recent neural network architecture, the input convex neural network, which is known to parametrize convex functions. The distinguishing features of our method are: i) it only requires samples from the marginal distributions; ii) unlike existing semi-discrete approaches, it represents the barycenter with a generative model; iii) it can compute the barycenter with arbitrary weights after one training session. We demonstrate the efficacy of our algorithm by comparing it with state-of-the-art methods in multiple experiments.
The Wasserstein barycenter is concerned with the (weighted) average of multiple given probability distributions. It is based on the natural geometry over the space of distributions induced by optimal transport theory (Villani, 2003) and serves as a counterpart of the arithmetic mean for distribution-type data. Compared to other methods, the Wasserstein barycenter provides a principled approach to averaging probability distributions that fully utilizes the underlying geometric structure of the data (Agueh and Carlier, 2011). During the past few years, it has found applications in several machine learning problems. For instance, in sensor fusion, the Wasserstein barycenter is used to merge datasets collected from multiple sensors into a single collective result (Elvander et al., 2018). Its advantage is its ability to preserve the modality of the different datasets, a highly desirable property in practice (Jiang et al., 2012). The Wasserstein barycenter has also been observed to be effective in removing batch effects from sensor measurements (Yang and Tabak, 2019). It has also found application in large-scale Bayesian inference, where it averages the results of Markov chain Monte Carlo (MCMC) Bayesian inference carried out over subsets of the observations (Srivastava et al., 2015; Staib et al., 2017; Srivastava et al., 2018). It has also been useful in image processing for texture mixing (Rabin et al., 2011) and shape interpolation (Solomon et al., 2015).

The bottleneck of the Wasserstein barycenter in machine learning applications remains its computational complexity. When the data is discrete, namely when the given probability distributions are over a discrete space (e.g., a grid), the Wasserstein barycenter problem can be solved using linear programming (Anderes et al., 2016). This has been greatly accelerated by introducing an entropy term (Cuturi and Doucet, 2014; Solomon et al., 2015), as in the Sinkhorn algorithm (Cuturi, 2013). However, these methods are not suitable for machine learning applications involving distributions over continuous spaces. First, they require discretization of the continuous space and thus do not scale to high-dimensional settings. In addition, in some applications such as MCMC Bayesian inference (Andrieu et al., 2003; Srivastava et al., 2018), explicit formulas for the distributions are not accessible, which precludes these discretization-based algorithms.

In this work we propose a scalable algorithm for computing the Wasserstein barycenter of probability distributions over continuous spaces using only samples from them. Our method is based on a Kantorovich-type dual characterization of the Wasserstein barycenter, which optimizes over convex functions. The recently introduced input convex neural networks (Amos et al., 2017; Chen et al., 2018a) parametrize convex functions and make such optimization possible. The (weighted) barycenter is modeled by a generative model, which samples from the barycenter through the reparametrization trick (Goodfellow et al., 2014; Arjovsky et al., 2017). Thus, the barycenter we obtain for continuous distributions is itself a continuous distribution. Even though no explicit expression of the barycenter density is available, one can easily draw as many samples from it as desired.
Contribution: We present a scalable algorithm for computing the Wasserstein barycenter of continuous distributions. Our algorithm has the following advantages: i) it is a sample-based method that only requires samples generated from the marginal distributions; ii) the generative model representation of the barycenter characterizes a continuous distribution and allows fast sampling from the barycenter once trained; iii) the method adapts to the weighted setting, where Wasserstein barycenters with different weights can be obtained through a single training process. We demonstrate the performance of our algorithm through extensive evaluations on various examples and comparisons with several state-of-the-art algorithms: the convolutional Wasserstein method (Solomon et al., 2015) and the stochastic Wasserstein barycenter algorithm (Claici et al., 2018). We also observe that the performance of our algorithm is less sensitive to the dimensionality of the distributions than that of other approaches.
Related work: Our proposed algorithm is most closely related to the stochastic Wasserstein barycenter method (Claici et al., 2018), which also aims at computing barycenters of continuous distributions from samples. One major difference between the two is that Claici et al. (2018) adopt a semi-discrete approach that models the barycenter with a finite set of points. That is, even though the marginal distributions are continuous, the barycenter is discrete. In contrast, our algorithm models the barycenter using a generative model, which yields a continuous distribution. Several other sample-based algorithms (Staib et al., 2017; Kuang and Tabak, 2019; Mi et al., 2020) are also semi-discrete. Most other Wasserstein barycenter algorithms are designed for discrete distributions and require discretization when applied to continuous distributions; an incomplete list includes (Cuturi and Doucet, 2014; Benamou et al., 2015; Solomon et al., 2015).
The subject of this work is also related to the vast literature on estimating the optimal transport map and the Wasserstein distance (see Peyré et al. (2019) for a comprehensive survey). Closely related to this paper are recent works that extend optimal transport map estimation to large-scale machine learning settings (Genevay et al., 2016; Seguy et al., 2017; Liu et al., 2018; Chen et al., 2018b; Leygonie et al., 2019; Xie et al., 2019). In particular, our algorithm is inspired by recent advances in estimating the optimal transport map and the 2-Wasserstein distance using input convex neural networks (Taghvaei and Jalali, 2019; Makkuva et al., 2019; Korotin et al., 2019).

Given two probability distributions $\mu, \nu$ over the Euclidean space $\mathbb{R}^n$ with finite second moments, the optimal transport (OT) problem (Villani, 2003) with quadratic unit cost seeks a joint distribution of $\mu$ and $\nu$ that minimizes the total transport cost. More specifically, it is formulated as $\min_{\pi \in \Pi(\mu,\nu)} \mathbb{E}_{(X,Y)\sim\pi}\big[\|X-Y\|^2\big]$, where $\Pi(\mu,\nu)$ denotes the set of all joint distributions of $\mu$ and $\nu$. The square root of the minimum transport cost defines the celebrated Wasserstein-2 distance $W_2(\mu,\nu)$, which is known to enjoy many nice geometric properties compared to other distance functions for probability distributions, and endows the space of probability distributions with a Riemannian-like structure (Ambrosio et al., 2008). The Kantorovich dual (Villani, 2003) of the OT problem reads

$$\tfrac{1}{2} W_2^2(\mu,\nu) = \sup_{(f,g)\in\Phi_c}\ \mathbb{E}_{\mu}[f(X)] + \mathbb{E}_{\nu}[g(Y)], \tag{1}$$
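where $\Phi_c$ is defined as $\{(f,g) \in L^1(\mu)\times L^1(\nu) : f(x)+g(y) \le \tfrac{1}{2}\|x-y\|^2 \text{ for all } x, y\}$. Let $\varphi(x) = \tfrac12\|x\|^2 - f(x)$ and $\psi(y) = \tfrac12\|y\|^2 - g(y)$; then (1) can be rewritten as

$$\tfrac12 W_2^2(\mu,\nu) = C_{\mu,\nu} - \inf_{\varphi\in \mathrm{CVX}(\mu)} \Big\{ \mathbb{E}_\mu[\varphi(X)] + \mathbb{E}_\nu[\varphi^*(Y)] \Big\}, \tag{2}$$

where CVX stands for the set of convex functions, $C_{\mu,\nu} = \tfrac12\mathbb{E}_\mu\|X\|^2 + \tfrac12\mathbb{E}_\nu\|Y\|^2$, and $\varphi^*(y) = \sup_x \langle x, y\rangle - \varphi(x)$ is the convex conjugate (Rockafellar, 1970) of $\varphi$. The formulation (2) is known as the semi-dual formulation of OT.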
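The Wasserstein barycenter is an OT-based average of probability distributions. Given a set of probability distributions $\{\mu_i\}_{i=1}^N$ and a weight vector $\alpha = (\alpha_1,\dots,\alpha_N)$ ($\alpha_i \ge 0$ and $\sum_{i=1}^N \alpha_i = 1$), the associated Wasserstein barycenter is defined as the minimizer of

$$\min_{\nu}\ \sum_{i=1}^{N} \alpha_i\, W_2^2(\mu_i,\nu). \tag{3}$$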
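Definition (3) has a closed form in one dimension, which makes a convenient sanity check for sample-based solvers: the barycenter's quantile function is the $\alpha$-weighted average of the marginals' quantile functions. The NumPy sketch below applies this standard fact to equally sized sample sets. It is not the method proposed in this paper, which targets the multivariate case. The function name `barycenter_1d` and the example marginals are hypothetical.

```python
import numpy as np

def barycenter_1d(sample_sets, alphas):
    """W2 barycenter of 1-D empirical distributions with equal sample sizes.

    In 1-D the barycenter's quantile function is the alpha-weighted average of the
    marginals' quantile functions, so a weighted average of sorted samples gives
    (sorted) samples from the barycenter.
    """
    sorted_sets = [np.sort(np.asarray(s, dtype=float)) for s in sample_sets]
    n = len(sorted_sets[0])
    if any(len(s) != n for s in sorted_sets):
        raise ValueError("all sample sets must have the same size")
    return sum(a * s for a, s in zip(alphas, sorted_sets))

rng = np.random.default_rng(0)
marginals = [rng.normal(-2.0, 0.5, 100_000), rng.normal(3.0, 2.0, 100_000)]
bary = barycenter_1d(marginals, [0.25, 0.75])
# For Gaussians the 1-D barycenter is N(sum_i a_i m_i, (sum_i a_i s_i)^2) = N(1.75, 1.625^2).
print(bary.mean(), bary.std())  # approximately 1.75 and 1.625
```

Sorting pairs the samples comonotonically, which is the optimal coupling in one dimension. No such shortcut exists in higher dimensions.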
This barycenter problem (3) can be reformulated as a linear program (Agueh and Carlier, 2011). However, linear-programming-based algorithms do not scale well to high-dimensional problems. A special case that can be solved efficiently is when the marginal distributions are Gaussian. Denote the mean and covariance of $\mu_i$ by $m_i$ and $\Sigma_i$, respectively. Their Wasserstein barycenter is then a Gaussian distribution with mean $\bar m = \sum_{i=1}^N \alpha_i m_i$. Its covariance $\bar\Sigma$ is the unique solution to the fixed-point equation $\bar\Sigma = \sum_{i=1}^N \alpha_i \big(\bar\Sigma^{1/2}\Sigma_i\bar\Sigma^{1/2}\big)^{1/2}$. In Álvarez-Esteban et al. (2016), a simple yet efficient algorithm was proposed to compute $\bar\Sigma$.

An Input Convex Neural Network (ICNN) is a deep neural network architecture that characterizes convex functions (Amos et al., 2017). A fully ICNN (FICNN) yields a function that is convex with respect to all of its inputs. A partially ICNN (PICNN) models a function that is convex with respect to only part of its inputs.
The FICNN architecture is shown in Fig. 1(a). It is an $L$-layer feedforward neural network that propagates as follows, for $l = 0, 1, \dots, L-1$:

$$z_{l+1} = \sigma_l\big(W_l z_l + A_l x + b_l\big), \qquad f(x;\theta) = z_L, \tag{4}$$

where $\{W_l\}$ and $\{A_l\}$ are weight matrices (with the convention that $W_0 = 0$), $\{b_l\}$ are the bias terms, and $\sigma_l$ denotes the entry-wise activation function at layer $l$. Denote the total set of parameters by $\theta$. This network then defines a map from the input $x$ to $f(x;\theta)$. This map is convex in $x$ provided that 1) the $W_l$ are non-negative; 2) the $\sigma_l$ are convex; and 3) the $\sigma_l$ are non-decreasing (Makkuva et al., 2019). We remark that FICNN can approximate any convex function over a compact domain to a desired accuracy (Chen et al., 2018a), which makes FICNN an ideal candidate for modeling convex functions.

PICNN is an extension of FICNN that can model functions that are convex with respect to only part of their input. The architecture of PICNN is depicted in Fig. 1(b). It is an $L$-layer architecture with inputs $(x, u)$. Under suitable assumptions on the weights (the feed-forward weights for $z_l$ are non-negative) and on the activation functions of the network, the map $x \mapsto f(x, u;\theta)$ is convex for each fixed $u$. We refer the reader to (Amos et al., 2017) for more details.
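The following sketch makes recursion (4) and its convexity conditions concrete. It is a minimal NumPy forward pass of an FICNN, randomly initialized with non-negative $W_l$, followed by a numerical convexity check along a segment. The choices here are assumptions for illustration: softplus activations in the hidden layers (the appendix experiments use CELU), a linear last layer, the layer sizes, and the function names. Training is omitted.

```python
import numpy as np

def softplus(t):
    """Convex, non-decreasing activation log(1 + exp(t)), computed stably."""
    return np.logaddexp(0.0, t)

def init_ficnn(dims, rng):
    """Random FICNN parameters for layer widths dims = [n, h_1, ..., h_{L-1}, 1]."""
    W, A, b = [], [], []
    for l in range(len(dims) - 1):
        A.append(rng.normal(size=(dims[l + 1], dims[0])))  # A_l: input weights, unconstrained
        b.append(rng.normal(size=dims[l + 1]))               # b_l: biases
        # W_0 = 0 by convention; W_l for l >= 1 is made non-negative.
        W.append(None if l == 0 else np.abs(rng.normal(size=(dims[l + 1], dims[l]))))
    return W, A, b

def ficnn(x, params):
    """Evaluate f(x; theta) = z_L via recursion (4); the last layer uses the identity,
    which is also convex and non-decreasing."""
    W, A, b = params
    L = len(A)
    z = None
    for l in range(L):
        pre = A[l] @ x + b[l]
        if l > 0:
            pre = pre + W[l] @ z
        z = softplus(pre) if l < L - 1 else pre
    return float(z[0])

rng = np.random.default_rng(0)
params = init_ficnn([2, 16, 16, 1], rng)
x, y, t = rng.normal(size=2), rng.normal(size=2), 0.3
lhs = ficnn(t * x + (1 - t) * y, params)
rhs = t * ficnn(x, params) + (1 - t) * ficnn(y, params)
print(lhs <= rhs + 1e-12)  # True: convexity along the segment from y to x
```

Flipping the sign of any entry of a $W_l$ with $l \ge 1$ can break convexity. This is why the constraint, or the penalty used later in place of it, matters.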
We study the Wasserstein barycenter problem (3) for a given set of marginal distributions $\{\mu_i\}_{i=1}^N$. We consider the setting where the analytic forms of the marginals are not available; instead, we only have access to independent samples from them. This covers both the case where a fixed set of samples is provided a priori, as in supervised learning, and the case where one can keep sampling from the marginals, as in MCMC Bayesian inference (Srivastava et al., 2018). Our goal is to recover the true continuous barycenter distribution $\nu$.

For a fixed $\nu$, the objective function of (3) is simply a weighted sum of the Wasserstein costs between $\mu_i$ and $\nu$. Thus, we use the semi-dual formulation (2) of OT to evaluate the objective function of (3). The convex conjugate function $\varphi^*$ is characterized by

$$\mathbb{E}_\nu[\varphi^*(Y)] = \sup_{\psi \in \mathrm{CVX}} \mathbb{E}_\nu\big[\langle Y, \nabla\psi(Y)\rangle - \varphi(\nabla\psi(Y))\big], \tag{5}$$

with the maximum achieved at $\psi = \varphi^*$. The semi-dual formulation (2) can therefore be rewritten as

$$\tfrac12 W_2^2(\mu,\nu) = C_{\mu,\nu} + \sup_{\varphi\in\mathrm{CVX}}\ \inf_{\psi\in\mathrm{CVX}} J_{\mu,\nu}(\varphi,\psi), \tag{6}$$

where $J_{\mu,\nu}$ is a functional of $\varphi$ and $\psi$ defined as

$$J_{\mu,\nu}(\varphi,\psi) = -\mathbb{E}_\mu[\varphi(X)] - \mathbb{E}_\nu\big[\langle Y, \nabla\psi(Y)\rangle - \varphi(\nabla\psi(Y))\big]. \tag{7}$$

The formulation (6) has been used in conjunction with FICNN to solve the OT problem in (Makkuva et al., 2019), where it proved advantageous.
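Identity (5) can be checked numerically in one dimension. We compute $\varphi^*$ by a brute-force Legendre transform on a grid; the maximizing point is $\nabla\varphi^*(y)$. Plugging that map into the right-hand side of (5) recovers $\mathbb{E}_\nu[\varphi^*(Y)]$, and any other gradient map gives a smaller value. The potential $\varphi$, the sampling distribution, the grid, and the function names are hypothetical choices for illustration only.

```python
import numpy as np

def phi(x):
    """A hypothetical strictly convex potential on R."""
    return 0.25 * x**4 + 0.5 * x**2

grid = np.linspace(-5.0, 5.0, 4001)  # search grid for the Legendre transform

def conjugate(y):
    """Brute-force phi^*(y) = max_x (x*y - phi(x)) and the maximizer x = grad phi^*(y)."""
    vals = np.outer(y, grid) - phi(grid)[None, :]
    k = np.argmax(vals, axis=1)
    return vals[np.arange(len(y)), k], grid[k]

def rhs_of_5(grad_psi, Y):
    """Sample mean of <Y, grad psi(Y)> - phi(grad psi(Y)), the right-hand side of (5)."""
    T = grad_psi(Y)
    return float(np.mean(Y * T - phi(T)))

rng = np.random.default_rng(0)
Y = rng.normal(size=500)                       # samples from a hypothetical nu = N(0, 1)
phi_star, _ = conjugate(Y)
best = rhs_of_5(lambda y: conjugate(y)[1], Y)  # psi = phi^*
other = rhs_of_5(lambda y: 0.5 * y, Y)         # another convex psi(y) = y^2 / 4
print(float(np.mean(phi_star)), best, other)   # first two agree; the third is smaller
```

Because the supremum in (5) is attained at $\psi = \varphi^*$, the inner minimization over $\psi$ in (6) recovers the conjugate without ever evaluating it in closed form.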
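Plugging (6) into the Wasserstein barycenter problem (3), scaled by the factor $\tfrac12$ (which does not change its minimizer), we obtain the following reformulation

$$\min_{\nu}\ \sum_{i=1}^{N} \alpha_i \Big\{ C_{\mu_i,\nu} + \sup_{\varphi_i \in \mathrm{CVX}}\ \inf_{\psi_i \in \mathrm{CVX}} J_{\mu_i,\nu}(\varphi_i,\psi_i) \Big\}. \tag{8}$$

Note that we use a different pair of functions $(\varphi_i,\psi_i)$ to estimate each term $W_2^2(\mu_i,\nu)$. The outer minimization is over all probability distributions $\nu$, in search of the Wasserstein barycenter. This min-max-min formulation enjoys the following property, whose proof is in the appendix.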
Obtaining a convergence rate for first-order optimization algorithms solving (8) is challenging. This holds even in the idealized setting where the optimization is carried out in function space and in the space of probability distributions. The difficulty arises from the optimization over $\nu$. The inner optimization problems over the functions $\varphi_i$ and $\psi_i$ are concave and convex, respectively, but the outer optimization problem over $\nu$ is not convex. Precisely, it is not geodesically convex on the space of probability distributions equipped with the 2-Wasserstein metric (Ambrosio et al., 2008). However, guarantees can be obtained in a restricted setting by establishing a Polyak-Łojasiewicz type inequality. In particular, when all $\mu_i$ are Gaussian distributions with positive-definite covariance matrices, gradient descent has been shown to admit a linear convergence rate (Chewi et al., 2020).
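Consider the Wasserstein barycenter problem for a fixed weight vector $\alpha$. Following (Makkuva et al., 2019), we use the FICNN architecture to represent the convex functions $\varphi_i$ and $\psi_i$. We model the distribution $\nu$ with a generator $G$, which transforms samples $Z$ from a simple distribution $\eta$ (e.g., Gaussian or uniform) into samples from a complicated distribution; we thereby recover a continuous barycenter distribution. Using this network parametrization, with $\varphi_i$ and $\psi_i$ ranging over FICNNs, and discarding constant terms, we arrive at the following optimization problem

$$\min_{G}\ \max_{\{\varphi_i\}}\ \min_{\{\psi_i\}}\ \sum_{i=1}^{N} \alpha_i \Big( \mathbb{E}_{Z\sim\eta}\Big[\tfrac12\|G(Z)\|^2 - \big\langle G(Z), \nabla\psi_i(G(Z))\big\rangle + \varphi_i\big(\nabla\psi_i(G(Z))\big)\Big] - \mathbb{E}_{X\sim\mu_i}\big[\varphi_i(X)\big] \Big). \tag{9}$$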
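We propose the Neural Wasserstein Barycenter-I (NWB-I) algorithm (Algorithm 1) to solve this three-loop min-max-min problem by alternately updating $G$, $\{\varphi_i\}$, and $\{\psi_i\}$ with stochastic optimization algorithms. This pipeline is illustrated by the block diagram in Figure 2. We remark that the objective function in (9) can be estimated using samples from $\{\mu_i\}$ and $\eta$. Thus, computing the Wasserstein barycenter only requires samples generated by the marginal distributions, not their analytic forms. In practice, we found it more effective to replace the convexity constraints for $\psi_i$ with a convexity penalty on the negative entries of the weight matrices in the FICNN (4). Denoting the parameters of $G$, $\varphi_i$, and $\psi_i$ by $\theta_G$, $\theta_{\varphi_i}$, and $\theta_{\psi_i}$, respectively, we arrive at the batch estimate of the objective

$$\frac{1}{M}\sum_{j=1}^{M}\sum_{i=1}^{N}\alpha_i\Big(\tfrac12\|G(Z_j)\|^2 - \big\langle G(Z_j), \nabla\psi_i(G(Z_j))\big\rangle + \varphi_i\big(\nabla\psi_i(G(Z_j))\big) - \varphi_i(X_j^i)\Big) + \lambda\sum_{i=1}^{N} R(\theta_{\psi_i}), \tag{10}$$

where $Z_j \sim \eta$, $G(Z_j)$ represents the sample generated by $G$, and $\{X_j^i\}_{j=1}^M$ are samples from $\mu_i$. The penalty $R(\theta_{\psi_i}) = \sum_l \|\max(-W_l, 0)\|_F^2$ sums over the weight matrices of $\psi_i$ that (4) requires to be non-negative, and $\lambda$ is a hyper-parameter weighing the intensity of regularization.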
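The following is a minimal NumPy sketch of how the batch estimate (10) is assembled. The FICNNs are replaced by hypothetical stand-ins: exact quadratic potentials for one-dimensional Gaussians $\mu_i = N(0, a_i^2)$ and a fixed $\nu = N(0, b^2)$, for which the expected value is known in closed form. All function names are hypothetical. In the actual algorithm, these quantities would come from the networks and be differentiated with respect to $\theta_G$, $\theta_{\varphi_i}$, and $\theta_{\psi_i}$ by an autodiff framework, which is omitted here.

```python
import numpy as np

def convexity_penalty(weight_mats):
    """R(theta) = sum_l ||max(-W_l, 0)||_F^2 over weights that (4) requires to be non-negative."""
    return sum(float(np.sum(np.maximum(-W, 0.0) ** 2)) for W in weight_mats)

def nwb1_batch_objective(alphas, gz, xs, phis, grad_psis, psi_weights, lam):
    """Batch estimate (10).

    gz: (M, n) generator samples G(Z_j); xs[i]: (M, n) samples X_j^i from mu_i;
    phis[i](v) -> (M,) evaluates phi_i row-wise; grad_psis[i](v) -> (M, n);
    psi_weights[i]: list of the constrained weight matrices of psi_i.
    """
    total = 0.0
    for a, x, phi, grad_psi, ws in zip(alphas, xs, phis, grad_psis, psi_weights):
        t = grad_psi(gz)                              # grad psi_i(G(Z_j))
        per_sample = (0.5 * np.sum(gz**2, axis=1)     # 1/2 ||G(Z_j)||^2
                      - np.sum(gz * t, axis=1)        # <G(Z_j), grad psi_i(G(Z_j))>
                      + phi(t)                        # phi_i(grad psi_i(G(Z_j)))
                      - phi(x))                       # phi_i(X_j^i)
        total += a * np.mean(per_sample) + lam * convexity_penalty(ws)
    return total

def make_quadratic(a_i, b):
    """Exact potentials for mu_i = N(0, a_i^2), nu = N(0, b^2) in 1-D (stand-ins for FICNNs)."""
    phi = lambda v: b / (2 * a_i) * np.sum(v**2, axis=1)
    grad_psi = lambda v: (a_i / b) * v
    return phi, grad_psi

rng = np.random.default_rng(1)
M, b = 200_000, 2.0
a_list, alphas = [1.0, 3.0], [0.5, 0.5]
gz = rng.normal(0.0, b, size=(M, 1))                  # plays the role of G(Z_j)
xs = [rng.normal(0.0, a, size=(M, 1)) for a in a_list]
pots = [make_quadratic(a, b) for a in a_list]
val = nwb1_batch_objective(alphas, gz, xs, [p for p, _ in pots], [g for _, g in pots],
                           psi_weights=[[], []], lam=1.0)
print(val)  # close to -2.0, i.e. sum_i alpha_i (b^2/2 - a_i b) for these exact potentials
```

Adding back the discarded constant $\tfrac12\mathbb{E}_{\mu_i}\|X\|^2$ turns each term into $\tfrac12(a_i-b)^2 = \tfrac12 W_2^2(\mu_i,\nu)$, consistent with (6). This sketch assumes, as in (10) above, that the penalty is applied to the $\psi_i$ weights.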
It is tempting to combine the two minimization steps over $G$ and $\psi_i$ into one and reduce (9) to a min-max saddle-point problem. The resulting algorithm would only alternate between $\varphi_i$ updates and joint $(G, \psi_i)$ updates, instead of the three-way alternation in Algorithm 1. However, in our implementations, we observed that this strategy is highly unstable.
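We next consider a more challenging Wasserstein barycenter problem with free weights. More specifically, given a set of marginal distributions $\{\mu_i\}_{i=1}^N$, we aim to compute their Wasserstein barycenters for all possible weights. We could, of course, use Algorithm 1 to solve the fixed-weight problem (9) separately for each weight. However, this becomes extremely expensive when the number of weights is large. It turns out that Algorithm 1 can be easily adapted to obtain the barycenters for all weights in one shot. To this end, we include the weight $\alpha$ as an input to all the neural networks $\varphi_i$, $\psi_i$, and $G$, rendering the maps $\varphi_i(x,\alpha)$, $\psi_i(y,\alpha)$, and $G(z,\alpha)$. For each fixed weight $\alpha$, the networks $\varphi_i(\cdot,\alpha)$, $\psi_i(\cdot,\alpha)$, and $G(\cdot,\alpha)$ solve the barycenter problem with this weight. Note that $\varphi_i$ and $\psi_i$ are only required to be convex with respect to the samples, not the weight $\alpha$. Therefore, we use PICNN instead of FICNN as the architecture for $\varphi_i$ and $\psi_i$. The problem then becomes

$$\min_{G}\ \max_{\{\varphi_i\}}\ \min_{\{\psi_i\}}\ \mathbb{E}_{\alpha\sim\Gamma}\Bigg[\sum_{i=1}^{N}\alpha_i\Big(\mathbb{E}_{Z\sim\eta}\Big[\tfrac12\|Y_\alpha\|^2 - \big\langle Y_\alpha, \nabla\psi_i(Y_\alpha,\alpha)\big\rangle + \varphi_i\big(\nabla\psi_i(Y_\alpha,\alpha),\alpha\big)\Big] - \mathbb{E}_{X\sim\mu_i}\big[\varphi_i(X,\alpha)\big]\Big)\Bigg],\quad Y_\alpha = G(Z,\alpha), \tag{11}$$

where $\Gamma$ is a probability distribution on the probability simplex from which the weight $\alpha$ is sampled, and $\nabla\psi_i$ denotes the gradient with respect to the first argument. In our experiments, we used the uniform distribution, but $\Gamma$ can be any distribution that is simple to sample from, e.g., a Dirichlet distribution. Effectively, the objective function in (11) amounts to the total Wasserstein cost over all possible weights. This formulation is well suited to stochastic gradient descent/ascent, which solves the problem jointly in one training session. As in the fixed-weight setting, the (partial) convexity constraints of $\psi_i$ can be replaced by a penalty term. For the batch implementation, in each batch we randomly draw one $\alpha$ from $\Gamma$, samples $\{X_j^i\}_{j=1}^M$ from $\mu_i$, and samples $\{Z_j\}_{j=1}^M$ from $\eta$. The unbiased batch estimate of the objective in (11) reads

$$\frac1M\sum_{j=1}^{M}\sum_{i=1}^{N}\alpha_i\Big(\tfrac12\|Y_j\|^2 - \big\langle Y_j, \nabla\psi_i(Y_j,\alpha)\big\rangle + \varphi_i\big(\nabla\psi_i(Y_j,\alpha),\alpha\big) - \varphi_i(X_j^i,\alpha)\Big) + \lambda\sum_{i=1}^{N}R(\theta_{\psi_i}), \tag{12}$$

where $Y_j = G(Z_j,\alpha)$ and $R(\theta_{\psi_i})$ is the penalty on the negative entries of the weights of $\psi_i$ that must be non-negative. By alternately updating $G$, $\{\varphi_i\}$, and $\{\psi_i\}$, we obtain Neural Wasserstein Barycenter-II (NWB-II) (Algorithm 2).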
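Relative to NWB-I, the only new ingredient in a batch is drawing the weight $\alpha$. The sketch below assumes that the "uniform distribution" mentioned above means the uniform distribution on the simplex. That distribution coincides with Dirichlet(1, ..., 1), which NumPy samples directly. The text does not say how $\alpha$ enters the generator, so the concatenation below is a hypothetical choice; the PICNNs would instead receive $\alpha$ through their non-convex input path. The function names are also hypothetical.

```python
import numpy as np

def sample_weight(rng, n_marginals):
    """Draw alpha uniformly from the probability simplex, i.e. Dirichlet(1, ..., 1)."""
    return rng.dirichlet(np.ones(n_marginals))

def condition_on_weight(samples, alpha):
    """Hypothetical conditioning: append the batch's weight vector to every sample row,
    e.g. to form the generator input (Z_j, alpha)."""
    return np.hstack([samples, np.tile(alpha, (samples.shape[0], 1))])

rng = np.random.default_rng(0)
alpha = sample_weight(rng, 3)        # one alpha per batch, shared by all samples in it
z = rng.normal(size=(5, 2))          # M = 5 latent samples Z_j ~ eta
gen_input = condition_on_weight(z, alpha)
print(alpha.sum(), gen_input.shape)  # 1.0 (up to round-off), (5, 5)
```

Holding $\alpha$ fixed within a batch keeps (12) an unbiased estimate of (11) while reusing the fixed-weight machinery of NWB-I.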
In Section 4.1, we present numerical experiments on several two-dimensional datasets. These serve as a proof of concept and qualitatively illustrate the performance of our approach in comparison with existing state-of-the-art algorithms. In Section 4.2, we test the ability of our algorithm to recover sharp distributions. In Section 4.3, we numerically study the effect of the problem dimension and demonstrate the scalability of our algorithm to high-dimensional problems. In Section 4.4, we show that our proposed free-weight Wasserstein barycenter algorithm (Algorithm 2) learns the barycenter for arbitrary weights after one training session. The implementation details of our algorithm and further experiments are included in the supplementary materials.
For comparison, we choose the following state-of-the-art algorithms: (i) convolutional Wasserstein barycenter (CWB) (Solomon et al., 2015); and (ii) stochastic Wasserstein barycenter (SWB) (Claici et al., 2018). CWB serves as the baseline for approaches that add an entropic regularization to the problem and discretize the space. SWB serves as the baseline for semi-discrete approaches that represent the barycenter with a finite number of Dirac delta distributions. For the two approaches, we used the code made available by Flamary and Courty (2017) and Claici (2018), respectively.
The Wasserstein barycenter of four Gaussian marginal distributions is computed and depicted in Figure 3. The marginals are chosen to be Gaussian because in this case the exact Wasserstein barycenter can be computed (see Section 2.2). The Wasserstein barycenters computed with NWB-I (Algorithm 1) and with the baseline approaches CWB and SWB are depicted in panels (f), (g), and (h), respectively. The density of the barycenter for NWB-I is computed based on
samples using a kernel density estimation with a Gaussian kernel of bandwidth
. The density for the CWB method is readily available because the algorithm outputs a density on a grid. The grid size in the CWB method is and the regularization parameter is (smaller values introduced numerical instability). The SWB method terminates after outputting samples to represent the barycenter.

Figure 3 shows qualitatively that our proposed approach performs as well as existing approaches in approximating the exact barycenter. The KL-divergence errors between the density of the exact barycenter and the densities computed using NWB-I, SWB, and CWB are , and respectively; our algorithm exhibits a smaller error than the CWB and SWB methods. This may be due to the bias introduced by the regularization in the CWB method and to the finite number of samples in the SWB method.
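The density-based evaluation described above can be sketched as follows. A Gaussian kernel density estimate is built from generated samples and tabulated on a grid, and a Riemann sum then approximates the KL divergence to the exact density. Several values are assumptions for illustration, since the ones used in the experiment are not given here: the direction KL(exact || estimate), the grid, the bandwidth h = 0.3, and the sample size. The test densities are also illustrative.

```python
import numpy as np

def gaussian_kde_2d(samples, grid, h):
    """Gaussian kernel density estimate (isotropic bandwidth h) evaluated at grid points."""
    d2 = (np.sum(grid**2, axis=1)[:, None] + np.sum(samples**2, axis=1)[None, :]
          - 2.0 * grid @ samples.T)                       # squared distances, shape (G, n)
    return np.mean(np.exp(-d2 / (2 * h**2)), axis=1) / (2 * np.pi * h**2)

def kl_on_grid(p, q, cell_area, eps=1e-12):
    """Riemann-sum approximation of KL(p || q) for densities tabulated on a uniform grid."""
    return float(np.sum(p * np.log((p + eps) / (q + eps))) * cell_area)

xs = np.linspace(-4.0, 4.0, 41)
gx, gy = np.meshgrid(xs, xs)
grid = np.column_stack([gx.ravel(), gy.ravel()])
cell_area = (xs[1] - xs[0]) ** 2
p_exact = np.exp(-0.5 * np.sum(grid**2, axis=1)) / (2 * np.pi)  # exact density: N(0, I)

rng = np.random.default_rng(0)
good = rng.normal(size=(2000, 2))          # samples from the right distribution
shifted = good + np.array([1.0, 0.0])      # samples from a wrong (shifted) distribution
kl_good = kl_on_grid(p_exact, gaussian_kde_2d(good, grid, h=0.3), cell_area)
kl_shifted = kl_on_grid(p_exact, gaussian_kde_2d(shifted, grid, h=0.3), cell_area)
print(kl_good < kl_shifted)  # True
```

The kernel smoothing and the finite sample size both bias this estimate upward. Comparisons across methods are therefore only meaningful with the same grid, bandwidth, and sample size.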
The Wasserstein barycenter of two Gaussian-mixture marginal distributions is computed and depicted in Figure 4. The marginal distributions are depicted in Figure 4(a) and (b). The exact Wasserstein barycenter is not known in this setting. The barycenters computed with NWB-I and the state-of-the-art algorithms are depicted in Figure 4(c), (d), and (e). The densities are computed using the same parameters as in Section 4.1.1. Our proposed algorithm performs as well as the state-of-the-art algorithms on this mixture-of-Gaussians problem. We also observed that the SWB method takes twice as long to run as our approach.
We illustrate the performance of NWB-I in learning the Wasserstein barycenter when the marginal distributions are sharp. We use the example reported in Claici et al. (2018, Figure 4), where the marginal distributions are uniform distributions on random two-dimensional lines, as shown in Figure 5. We present the results of NWB-I and SWB (SWB is reported to outperform CWB on this problem (Claici et al., 2018)). Our algorithm is able to learn the sharp barycenter. Unlike SWB, NWB-I represents the barycenter as a continuous distribution by generating samples from the learned generator, whereas SWB approximates the barycenter with 13 samples. We also tested NWB-I on another example (Claici et al., 2018, Figure 6), learning the barycenter of 20 uniform marginals supported on ellipses, and obtained excellent results.
We study the performance of our proposed algorithm in learning the barycenter of two Gaussian marginal distributions as the dimension grows. The Gaussian marginals have zero mean and random covariance matrices. We choose Gaussian distributions because the exact barycenter is explicitly available as a baseline. We implemented NWB-I and the SWB method for comparison; the CWB method is practically impossible to implement for high-dimensional problems because it relies on discretization of the space. The resulting error in estimating the exact barycenter is depicted in Figure 6(b). The error is defined as the Frobenius-norm error in estimating the covariance matrix, i.e., $\|\hat\Sigma - \bar\Sigma\|_F$, where $\bar\Sigma$ is the exact covariance and $\hat\Sigma$ is the output of the algorithm.
Figure 6(b) shows that the error of NWB-I grows slowly with the dimension, while that of SWB grows rapidly. To gain better insight, Figure 6(a) depicts a two-dimensional projection of the 128-dimensional barycenter onto its first two coordinates. The samples produced by the SWB method tend to collapse in one direction.
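The exact Gaussian barycenter used as the baseline here can be computed with the fixed-point iteration of Álvarez-Esteban et al. (2016), $\Sigma_{k+1} = \Sigma_k^{-1/2}\big(\sum_i \alpha_i (\Sigma_k^{1/2}\Sigma_i\Sigma_k^{1/2})^{1/2}\big)^2\Sigma_k^{-1/2}$. The NumPy sketch below is our own illustration, not the authors' code. The helper names, the initialization, the random test covariances, and the stopping rule are assumptions. It also computes the Frobenius-norm covariance error used in this experiment.

```python
import numpy as np

def sqrtm_psd(S):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def gaussian_barycenter_cov(covs, alphas, iters=200, tol=1e-12):
    """Fixed-point iteration for the covariance of the Gaussian W2 barycenter."""
    S = sum(a * C for a, C in zip(alphas, covs))  # initial guess (assumption): weighted mean
    for _ in range(iters):
        R = sqrtm_psd(S)
        R_inv = np.linalg.inv(R)
        T = sum(a * sqrtm_psd(R @ C @ R) for a, C in zip(alphas, covs))
        S_new = R_inv @ T @ T @ R_inv
        S_new = 0.5 * (S_new + S_new.T)           # remove round-off asymmetry
        if np.linalg.norm(S_new - S, ord="fro") < tol:
            return S_new
        S = S_new
    return S

def covariance_error(S_hat, S_true):
    """Frobenius-norm error between an estimated and the exact covariance."""
    return float(np.linalg.norm(S_hat - S_true, ord="fro"))

rng = np.random.default_rng(0)
dim = 8
covs = []
for _ in range(2):
    A = rng.normal(size=(dim, dim))
    covs.append(A @ A.T + dim * np.eye(dim))      # random SPD covariance (hypothetical)
S_bar = gaussian_barycenter_cov(covs, [0.5, 0.5])

# S_bar should satisfy the fixed-point equation S = sum_i a_i (S^{1/2} Sigma_i S^{1/2})^{1/2}.
R = sqrtm_psd(S_bar)
residual = covariance_error(sum(0.5 * sqrtm_psd(R @ C @ R) for C in covs), S_bar)
print(residual)  # round-off level
```

Any fixed point of the iteration satisfies the equation of Section 2.2, which is why the residual check above is a meaningful correctness test.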
We present the experimental results of applying NWB-II (Algorithm 2) to compute the Wasserstein barycenter for all combinations of weights with a single training session. The results for Gaussian marginal distributions and 12 combinations of weight values are depicted in Figure 11, together with the exact barycenters for comparison. Qualitatively, the barycenters computed by our approach for the selected weight combinations compare well with the exact barycenters.
To quantitatively verify the performance of NWB-II, we compare the computed barycenters with the ground truth for several different weights in terms of KL divergence. The resulting errors are respectively for , for , and for . The errors of NWB-II are consistently small across the different weight combinations.
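In International Conference on Scale Space and Variational Methods in Computer Vision, pages 435–446. Springer, 2011.
Optimal transport for Gaussian mixture models. IEEE Access, 7:6269–6278, 2018b.

Proof. Fix $\nu$ absolutely continuous with respect to the Lebesgue measure, and fix $\{\varphi_i\}$. By (5), the solutions to the inner-loop minimization problems over $\psi_i$ are $\psi_i = \varphi_i^*$. The problem (8) then becomes

$$\min_\nu \sum_{i=1}^N \alpha_i \Big\{ C_{\mu_i,\nu} - \inf_{\varphi_i\in\mathrm{CVX}} \big( \mathbb{E}_{\mu_i}[\varphi_i(X)] + \mathbb{E}_\nu[\varphi_i^*(Y)] \big) \Big\}.$$

In view of (2), it boils down to

$$\min_\nu \sum_{i=1}^N \alpha_i\, \tfrac12 W_2^2(\mu_i,\nu),$$

which, up to the constant factor $\tfrac12$, is exactly the Wasserstein barycenter problem (3). Since all the marginal distributions are absolutely continuous with respect to the Lebesgue measure, their barycenter exists and is unique. This completes the proof.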
The block diagram for Neural Wasserstein Barycenter-II (Algorithm 2) is shown in Figure 8.
In this section, we provide the experiment details as well as additional supporting experimental results. The common experimental setup is as follows:
1) All $\varphi$ and $\psi$ networks use the CELU activation function, while the $G$ network uses the PReLU activation function.
2) The weight for the regularizer .
3) All optimizers are Adam.
4) We use 60000 training samples for each epoch.
All generators $G$ used in this section are plain feedforward networks.
The four marginal Gaussian distributions are given by
and the exact Wasserstein barycenter is
The networks $\varphi_i$ and $\psi_i$ each have 3 layers and $G$ has 2 layers. All networks have 10 neurons in each hidden layer. The initial learning rate is 0.001 and the learning rate drops percent every 20 epochs. The inner loop iteration numbers are and . The batch size is .

The following setup is for the results displayed in Figure 4. Each of the two marginals has 4 Gaussian components. The first marginal is a uniform combination of the Gaussian distributions
The second marginal is a uniform combination of the Gaussian distributions
For NWB-I, the networks $\varphi_i$ and $\psi_i$ each have 3 layers and the generative network $G$ has 2 layers. All the networks have 10 neurons in each hidden layer. The initial learning rate is 0.001 and the learning rate drops percent every 20 epochs. The inner loop iteration numbers are and . The batch size is . For CWB, the regularization intensity is set to 0.004. The SWB algorithm generates 115 samples.
We further test NWB-I with 3 marginals of Gaussian mixtures. The first marginal is a uniform combination of 4 Gaussian components
The second marginal is a uniform combination of 3 Gaussian components
The third marginal is a uniform combination of 3 Gaussian components
For NWB-I, the networks $\varphi_i$ and $\psi_i$ each have 3 layers and the generative network $G$ has 4 layers. All networks have 10 neurons in each hidden layer. The initial learning rate is 0.001 and the learning rate drops percent every 20 epochs. The inner loop iteration numbers are and . The batch size is . For CWB, the entropy regularization intensity is 0.004. The SWB algorithm generates 110 samples. The experiment results are depicted in Figure 9. The performance of NWB-I is on par with that of CWB, and both outperform SWB.
In the case $N = 1$, that is, when only samples from a single distribution are given, our algorithm NWB-I behaves like a Generative Adversarial Network (GAN): it learns a generative model from samples of an underlying distribution. Learning a Gaussian mixture distribution is a simple yet challenging task for GANs due to mode collapse. Here we compare NWB-I with the original GAN, the Wasserstein GAN (WGAN), and an improved version of WGAN (WGAN-GP) in learning a Gaussian mixture model with 10 components, as shown in Figure 10. The Gaussian mixture distribution is a uniform combination of 10 Gaussians evenly distributed on a circle. NWB-I avoids mode collapse and achieves results comparable to the state of the art.
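The target distribution in this experiment, and a common way to quantify mode collapse, can be sketched as follows. Several values are hypothetical choices, not values from the text: the radius, the component standard deviation, the sample size, and the coverage criterion. Under that criterion, a mode counts as covered if at least 2% of the samples fall within three standard deviations of its center. The function names are also hypothetical.

```python
import numpy as np

def ring_centers(k=10, radius=2.0):
    """Means of k Gaussians evenly spaced on a circle."""
    angles = 2 * np.pi * np.arange(k) / k
    return radius * np.column_stack([np.cos(angles), np.sin(angles)])

def sample_ring_mixture(rng, n, k=10, radius=2.0, std=0.1):
    """Samples from a uniform mixture of k isotropic Gaussians on a circle."""
    labels = rng.integers(0, k, size=n)
    return ring_centers(k, radius)[labels] + std * rng.normal(size=(n, 2))

def modes_covered(samples, k=10, radius=2.0, std=0.1, min_frac=0.02):
    """Number of components that receive at least min_frac of the samples
    within 3 standard deviations of their center."""
    centers = ring_centers(k, radius)
    d = np.linalg.norm(samples[:, None, :] - centers[None, :, :], axis=-1)  # (n, k)
    nearest = np.argmin(d, axis=1)
    close = d[np.arange(len(samples)), nearest] < 3 * std
    counts = np.bincount(nearest[close], minlength=k)
    return int(np.sum(counts >= min_frac * len(samples)))

rng = np.random.default_rng(0)
data = sample_ring_mixture(rng, 5000)
collapsed = ring_centers()[0] + 0.1 * rng.normal(size=(5000, 2))  # generator stuck on one mode
print(modes_covered(data), modes_covered(collapsed))  # 10 1
```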
For NWB-I, the networks $\varphi$ and $\psi$ each have 4 layers and the generative network $G$ has 4 layers. All networks have 10 neurons in each hidden layer. The initial learning rate is 0.001 and the learning rate drops percent every 20 epochs. The inner loop iteration numbers are and . The batch size is .
For a fair comparison, GAN, WGAN, and WGAN-GP all use the same network structure: fully connected linear layers with ReLU activations. All discriminators and generators have 4 layers with 512 neurons in each hidden layer. The learning rate is 0.0001 and the batch size is 256.
In the experiment with marginals concentrated on lines, for NWB-I, the networks $\varphi_i$ and $\psi_i$ each have 3 layers and $G$ has 2 layers. All networks have 6 neurons in each hidden layer. The generator $G$ is linear. The learning rate is 0.0001. The inner loop iteration numbers are and . The batch size is . The SWB algorithm generates 13 samples.
In the experiment with marginals concentrated on ellipses, for NWB-I, the networks $\varphi_i$ and $\psi_i$ each have 3 layers and $G$ has 3 layers. All networks have 6 neurons in each hidden layer. The initial learning rate is 0.001 and the learning rate drops percent every 15 epochs. The inner loop iteration numbers are and . The batch size is . The SWB algorithm generates 30 samples.
The experiment results are shown in Figure 5.
To test the scalability of our algorithm NWB-I, we carry out experiments to calculate the barycenters of two Gaussian distributions of dimension , , , and . All the Gaussian distributions are of zero mean and the covariances are randomly generated. (For generating random covariances: we first generate random positive integer vectors and then use these vectors to generate diagonal covariance matrices.) We compare its performance with SWB and the result is shown in Figure 6.
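The following sketches the covariance construction described above. It assumes that the random integer vector is used directly as the diagonal (the variances) and that it is drawn from the hypothetical range 1 to 10; the text specifies neither. The listed dimensions and the function names are also hypothetical. Because diagonal covariances commute, the exact barycenter needed for the comparison has a closed form, $\bar\Sigma^{1/2} = \sum_i \alpha_i \Sigma_i^{1/2}$, which satisfies the fixed-point equation of Section 2.2.

```python
import numpy as np

def random_diag_cov(rng, dim, low=1, high=10):
    """Diagonal covariance built from a random vector of positive integers in [low, high].

    Using the integers directly as variances, and the range itself, are assumptions.
    """
    return np.diag(rng.integers(low, high + 1, size=dim).astype(float))

def diag_barycenter_cov(covs, alphas):
    """Exact barycenter covariance when all covariances are diagonal (hence commute):
    its square root is the alpha-weighted average of the marginals' square roots."""
    root = sum(a * np.sqrt(np.diag(C)) for a, C in zip(alphas, covs))
    return np.diag(root**2)

rng = np.random.default_rng(0)
alphas = [0.5, 0.5]
for dim in (2, 16, 64, 128):  # hypothetical list of dimensions
    covs = [random_diag_cov(rng, dim) for _ in range(2)]
    S_bar = diag_barycenter_cov(covs, alphas)
    # Check S = sum_i a_i (S^{1/2} Sigma_i S^{1/2})^{1/2}, which is entrywise for diagonal matrices.
    s = np.diag(S_bar)
    residual = np.max(np.abs(s - sum(a * np.sqrt(s * np.diag(C)) for a, C in zip(alphas, covs))))
    print(dim, residual)  # round-off level for every dimension
```

With this exact $\bar\Sigma$ in hand, the Frobenius-norm error of Section 4.3 can be evaluated for any estimated covariance.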
In this part, we evaluate the performance of NWB-II, an algorithm that computes the Wasserstein barycenter of a given set of marginals for all weights in one shot. Unlike in NWB-I, the networks $\varphi_i$ and $\psi_i$ have a PICNN structure. We carry out 3 sets of experiments, in which the marginal distributions are Gaussians, Gaussian mixtures, and sharp distributions, respectively.
The three marginals are
The networks $\varphi_i$ and $\psi_i$ each have 2 layers and the generative network $G$ has 2 layers. All networks have 12 neurons in each hidden layer. The learning rate is 0.001. The inner loop iteration numbers are and . The batch size is . The results are shown in Figure 11.
The marginal distributions are the same as those associated with Figure 4. We apply NWB-II to obtain the Wasserstein barycenter for all weights in one shot. The networks $\varphi_i$ and $\psi_i$ each have 4 layers and the generative network $G$ has 4 layers. All networks have 12 neurons in each hidden layer. Batch normalization is used in $G$. The learning rate is 0.001. The inner loop iteration numbers are and . The batch size is .

The experiment results are depicted in Figure 11 in comparison with CWB. We remark that this is not a fair comparison, since NWB-II obtains all the barycenters with different weights in one shot, while CWB has to be run separately for each weight. Nevertheless, NWB-II generates reasonable results.
Given two marginals supported on two lines, we apply NWB-II to obtain the Wasserstein barycenter for all weights in one shot. The networks $\varphi_i$ and $\psi_i$ each have 3 layers and the generative network $G$ has 2 layers. All networks have 12 neurons in each hidden layer. Batch normalization is used in $G$. The learning rate is 0.001. The inner loop iteration numbers are and . The batch size is . The experiment results are depicted in Figure 12 in comparison with the ground truth.