Optimal transport theory is becoming an essential tool for modern machine learning research. The development of efficient optimal transport algorithms has led to a wide range of machine learning applications Peyré and Cuturi (2017); Cuturi (2013); Solomon et al. (2014); Kloeckner (2015); Ho et al. (2017); Arjovsky et al. (2017); Lee and Raginsky (2018); Genevay et al. (2018); Staib et al. (2017); Mi et al. (2018), including autoencoders Patrini et al. (2018); Tolstikhin et al. (2018) and general purpose variational Bayesian inference Ambrogioni et al. (2018). However, the field that has been most deeply influenced by optimal transport theory is arguably generative modeling Arjovsky et al. (2017); Genevay et al. (2018); Gulrajani et al. (2017); Adler and Lunz (2018); Genevay et al. (2017); Gemici et al. (2018). The introduction of the Wasserstein generative adversarial network (wGAN) Arjovsky et al. (2017) was a milestone, as it provided a more stable form of adversarial training. Generative adversarial networks (GANs) have greatly improved the state of the art in image generation. Nevertheless, GAN training often leads to mode collapse, where part of the data space is ignored by the generative model. A possible way to mitigate this phenomenon is to use a collection of GANs, each modeling a part of the data space Wang et al. (2016).

However, existing ensembling techniques are often heuristic in nature and do not provide a principled way of ensuring that the different generators model non-overlapping parts of the data space. In this paper, we derive an ensemble-of-generative-models algorithm from first principles using the theory of semi-discrete optimal transport. The basic idea is to jointly learn a series of elements (prototypes) of a discrete distribution and the optimal transportation functions from these prototypes to the data space. The notion of optimality is determined by a transportation cost that measures the dissimilarity between the prototypes and the data points. The resulting k-GANs algorithm has strong theoretical connections with the k-means and k-medoids algorithms. In the k-GANs algorithm, we learn prototypes that implicitly define a partitioning of the data space into non-overlapping cells. The distribution of the data in each of these cells is generated by a stochastic transportation function that maps the prototype into the data space. These transportation functions are parameterized by deep networks and are trained as regular GANs within their cells. The prototypes and the transportation functions are learned jointly, so the boundaries of the cells shift during training as a consequence of the changes in the prototypes.
2 Related work
From a theoretical point of view, our algorithm has a strong connection with the traditional k-means and k-medoids clustering methods Forgy (1965); Graf and Luschgy (2007). This connection between k-means and semi-discrete optimal transport stems from the fact that semi-discrete transport problems implicitly define a Laguerre tessellation of the space, which reduces to the more familiar Voronoi tessellation in special cases Peyré and Cuturi (2017); Graf and Luschgy (2007). Recently, this connection has been exploited in a variational clustering algorithm that uses optimal transport theory to derive a more powerful clustering method Mi et al. (2018).
3 Background on optimal transport
In machine learning and statistics, optimal transport divergences are often used for comparing probability measures. Consider two probability measures $p$ and $q$, defined over spaces $X$ and $Y$ respectively. The optimal transport divergence between them is defined by the following optimization problem:
where $c(x, y)$ is the cost of transporting probability mass from $x$ to $y$ and $\Gamma[p, q]$ is the set of joint probability measures that have $p$ and $q$ as marginal measures over $X$ and $Y$ respectively. The transportation nature of this problem can be seen by slightly reformulating the objective function, writing the joint measure $\gamma(x, y)$ as the product of a conditional measure $\gamma(y \mid x)$ and the marginal measure $p(x)$:
where the integral of the conditional measures against $p(x)$ has to be equal to $q(y)$:
Therefore, the conditional measures $\gamma(y \mid x)$ can be interpreted as stochastic transportation functions that map each element of $X$ into a probability measure over $Y$.
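As a concrete illustration (not part of the original derivation), when both measures are discrete the optimization problem above reduces to a small linear program over the coupling matrix. The following sketch solves such a toy instance with `scipy.optimize.linprog`; the supports, marginal weights, and absolute-value cost are all illustrative assumptions.

```python
import numpy as np
from scipy.optimize import linprog

# Two discrete measures on the real line (illustrative values).
x = np.array([0.0, 1.0])           # support of p
y = np.array([0.0, 2.0])           # support of q
p = np.array([0.7, 0.3])           # weights of p
q = np.array([0.4, 0.6])           # weights of q

# Transportation cost c(x, y) = |x - y|.
C = np.abs(x[:, None] - y[None, :])

# Decision variables: the coupling gamma[i, j] >= 0, flattened row-major.
# Marginal constraints: row sums equal p, column sums equal q.
n, m = C.shape
A_eq, b_eq = [], []
for i in range(n):                 # sum_j gamma[i, j] = p[i]
    row = np.zeros(n * m)
    row[i * m:(i + 1) * m] = 1.0
    A_eq.append(row)
    b_eq.append(p[i])
for j in range(m):                 # sum_i gamma[i, j] = q[j]
    col = np.zeros(n * m)
    col[j::m] = 1.0
    A_eq.append(col)
    b_eq.append(q[j])

res = linprog(C.ravel(), A_eq=np.array(A_eq), b_eq=np.array(b_eq),
              bounds=(0, None), method="highs")
ot_divergence = res.fun
```

For this one-dimensional instance the optimal plan keeps 0.4 units in place and moves the remaining mass to the right, giving a divergence of 0.9; in practice, dedicated solvers (e.g. Sinkhorn iterations Cuturi (2013)) scale far better than a generic linear program.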
4 Semi-discrete optimal transport as ensemble of generative models
One of the main advantages of using optimal transport divergences is that they can be defined between probability measures with very different supports. An important case is semi-discrete transport, where $p$ is absolutely continuous with respect to the Lebesgue measure while $q$ is a finite mixture of Dirac measures:
Semi-discrete optimal transport has important machine learning applications. For our purposes, the minimization of a semi-discrete optimal transport divergence can be used to approximate the probability distribution of the data with a discrete distribution over a finite number of “prototypes”. The semi-discrete optimal transport divergence can be rewritten as follows:
with the following constraint:
Note that each conditional measure can be interpreted as a generative model that maps a prototype into a probability distribution over the data points. The optimization in Eq. 5 ensures that these distributions are centered around their prototypes (in a sense given by the cost function), while the marginalization constraint in Eq. 6 guarantees that the weighted sum of all generative models is the real distribution of the data. In other words, the solution of this semi-discrete optimal transport problem provides an ensemble of local generative models.
5 Geometry of semi-discrete optimal transport problems
In this section, we summarize some known results of semi-discrete optimal transport that provide the theoretical foundation for our work. Semi-discrete optimal transport has a deep connection with geometry, as it can be proven that the optimal transportation maps are piecewise constant and define a tessellation of the space $X$. In order to show this, it is useful to introduce the unconstrained dual formulation of the optimal transport problem Peyré and Cuturi (2017):
where $f^c$ denotes the c-transform of the vector of dual weights $f$:
We can now reformulate the objective function in terms of a tessellation of $X$:
where the sets $L_j(f)$ are defined as
yielding a so-called Laguerre tessellation of $X$. We can finally state the following important theorem, which expresses the transportation maps in terms of the optimized Laguerre tessellation:
Since the problem is unconstrained, this implies that, for the optimal dual weights $f^*$, we have
By plugging this result into Eq. 9, we obtain:
By comparing this expression with the primal formulation in Eq. 5, it follows immediately that the optimal transportation maps are given by the (normalized) restrictions of $p$ to the Laguerre cells.
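The cell membership rule implied by the Laguerre tessellation can be sketched numerically: each point is assigned to the prototype minimizing the cost offset by the dual weight, and equal weights recover an ordinary Voronoi assignment. The function name, the default squared Euclidean cost, and the toy prototypes below are all illustrative assumptions.

```python
import numpy as np

def laguerre_assign(points, prototypes, dual_weights, cost=None):
    """Assign each point to its Laguerre cell: argmin_j c(x, x_j) - f_j.

    With equal dual weights this reduces to a plain Voronoi assignment.
    """
    if cost is None:  # default: squared Euclidean cost (an assumption)
        cost = lambda a, b: np.sum((a - b) ** 2, axis=-1)
    # scores[i, j] = c(points[i], prototypes[j]) - f_j
    scores = np.stack([cost(points, proto) for proto in prototypes], axis=1)
    scores -= dual_weights[None, :]
    return np.argmin(scores, axis=1)

# Two prototypes on the line; a larger dual weight enlarges a cell.
protos = np.array([[0.0], [2.0]])
pts = np.array([[0.9], [1.1], [1.9]])
equal = laguerre_assign(pts, protos, np.zeros(2))             # Voronoi split
shifted = laguerre_assign(pts, protos, np.array([1.0, 0.0]))  # cell 0 grows
```

The second call shows the geometric role of the dual weights: raising $f_0$ moves the cell boundary toward the other prototype, so the point at 1.1 switches cells.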
6 Simultaneous optimization of weights and transportation maps
In this section, we prove the main theoretical result behind our method. Consider the problem of finding the set of prototypes, weights and transportation maps that minimizes the semi-discrete optimal transport divergence in Eq. 4. This results in the following joint optimization problem:
The solution of this optimization problem is an optimal ensemble of generative models.
From the previous section, we know that the solution of the semi-discrete optimal transport problem is given by a tessellation of the target space into Laguerre cells parameterized by a vector of dual weights . These cells are the support sets of the transportation maps. In the general case, these cells can be computed using computational geometry algorithms Peyré and Cuturi (2017). Fortunately, the problem can be solved in closed form if we simultaneously optimize the weights and the transportation maps, as stated in the following theorem:
Theorem 2 (Formal solution of the joint optimization problem).
The optimization problem given by
under the marginalization constraint given in Eq. 3 is solved by the following Voronoi tessellation:
where the transportation maps are obtained by restricting the data distribution to each set of the tessellation
and the optimal weights are given by
It is easier to work with the dual optimization problem (see Eq. 9). Enforcing the constraint that the weight vector sums to one using Lagrange multipliers, we obtain the following unconstrained minimax optimization:
We can find the critical point by setting the gradient to zero:
The second equation implies that all the dual weights are equal to a constant. This implies that the Laguerre sets are Voronoi sets (Eq. 17). The first equation gives Eq. 2, and the transportation maps in Eq. 18 are a consequence of Theorem 1. Note that the resulting weights clearly respect the marginalization constraint.
Using Theorem 2, we can write a simple expression for the optimal prototypes:
In other words, the prototypes are the medoids of the Voronoi sets with respect to the cost function $c$.
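The medoid characterization above admits a simple nonparametric sketch over a finite sample: each prototype is replaced by the member of its Voronoi cell that minimizes the summed cost to the other members. The function name, toy data, and squared Euclidean cost are assumptions of this illustration.

```python
import numpy as np

def update_medoids(data, assignments, k, cost):
    """Replace each prototype with the medoid of its Voronoi cell:
    the cell member minimizing the summed cost to all other members."""
    prototypes = []
    for j in range(k):
        cell = data[assignments == j]
        # pairwise costs within the cell
        pair = np.array([[cost(a, b) for b in cell] for a in cell])
        prototypes.append(cell[np.argmin(pair.sum(axis=1))])
    return np.array(prototypes)

sq_euclid = lambda a, b: float(np.sum((a - b) ** 2))
data = np.array([[0.0], [0.2], [0.4], [5.0], [5.1]])
assign = np.array([0, 0, 0, 1, 1])
medoids = update_medoids(data, assign, 2, sq_euclid)
```

With the squared Euclidean cost this recovers a k-medoids update; the k-GANs algorithm instead updates the prototypes by gradient descent on a parameterized proxy of the same objective.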
7 Learning the transportation maps
The solution given in Eq. 15
is purely formal and does not directly provide a useful algorithm, since the distribution that generated the data is not available. However, a practical algorithm can be obtained by minimizing a statistical divergence between this formal solution and a parametric model, such as a deep generative network. We will denote the probability measure induced by passing a latent measure $\mu$ through a deep network $g_j$ as $g_{j*}\mu$ (where the lower star denotes the pushforward of a measure through a measurable function). We approximate each optimal transportation map as follows:
In practice, in the most naive implementation, this means that we reject samples that land outside the cell $V_j$. We train each network by minimizing a statistical divergence:
This is possible when the divergence solely requires the ability to sample from the model, since we can sample from the dataset and reject all samples that do not land in $V_j$. For example, using the dual formulation of the Wasserstein distance, we can train the generators (Wasserstein GANs) by optimizing the following minimax problem using stochastic gradient descent (SGD) Arjovsky et al. (2017):
where the maximization ranges over the space of Lipschitz continuous functions. In practice, we approximate the samples with samples from a finite dataset and we parameterize both the generator and the discriminator as deep neural networks. Furthermore, the discriminator is regularized by the following soft Lipschitz regularization term:
Using the trained generators, we can obtain a parameterized proxy for the loss in Eq. 25, which we will use to train the prototypes:
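The naive rejection step described above can be sketched as follows: draw points from the dataset and keep only those whose nearest prototype (under the cost) is prototype $j$. The function name, toy data, and squared Euclidean cost are illustrative assumptions.

```python
import numpy as np

def cell_batch(data, prototypes, j, batch_size, cost, rng):
    """Draw a training batch for generator j by rejection sampling:
    sample data points uniformly and keep those whose nearest prototype
    (under the cost) is prototype j, i.e. points inside the cell V_j."""
    batch = []
    while len(batch) < batch_size:
        x = data[rng.integers(len(data))]
        costs = [cost(x, p) for p in prototypes]
        if int(np.argmin(costs)) == j:   # x lands inside V_j: accept
            batch.append(x)
    return np.stack(batch)

# Toy usage with a squared Euclidean cost (illustrative data).
sq = lambda a, b: float(np.sum((a - b) ** 2))
data = np.array([[0.0], [0.1], [5.0], [5.1]])
protos = np.array([[0.0], [5.0]])
rng = np.random.default_rng(0)
batch = cell_batch(data, protos, 0, 8, sq, rng)   # only points near 0
```

In practice it is cheaper to assign a whole minibatch to its nearest prototypes in one vectorized pass and mask the batch, rather than rejecting samples one at a time.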
8 The k-GANs algorithm
We are finally ready to formulate the algorithm. The basic idea is to minimize Eq. 15
using a two-step approach similar to the expectation-maximization scheme used in the k-means algorithm Forgy (1965). In the first step, we keep the generators fixed and train the prototypes by minimizing Eq. 29 with SGD steps. In the second step, we keep the prototypes (and consequently the tessellation) fixed and train the generators by minimizing Eq. 27 with SGD steps (cf. Algorithm 1).
We named this algorithm k-GANs since it can be interpreted as a parametric version of the well-known k-medoids method. Specifically, a stochastic version of k-medoids is obtained if we replace the trained deep generators with nonparametric generators that sample, with uniform probability, the elements of the training set that belong to the corresponding set. This further reduces to a stochastic version of the k-means algorithm if we use the squared Euclidean distance as cost function.
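The alternating scheme of Algorithm 1 can be summarized in a short sketch. The `assign`, `gen_step`, and `proto_step` callables below are hypothetical stand-ins for the Voronoi assignment, a WGAN update, and a prototype update; the demonstration uses a k-means-style mean update in place of SGD on the prototype loss.

```python
import numpy as np

def k_gans(data, prototypes, generators, assign, gen_step, proto_step,
           n_outer, n_gen_steps, n_proto_steps):
    """Two-step scheme: (i) with prototypes (and hence the tessellation)
    fixed, train each generator inside its cell; (ii) with generators
    fixed, update the prototypes."""
    for _ in range(n_outer):
        cells = assign(data, prototypes)            # Voronoi assignment
        for j, gen in enumerate(generators):        # step 1: per-cell training
            for _ in range(n_gen_steps):
                gen_step(gen, data[cells == j])
        for _ in range(n_proto_steps):              # step 2: prototype update
            prototypes = proto_step(prototypes, data, cells)
    return prototypes, generators

# Stub components (all hypothetical stand-ins for the real networks):
data = np.array([[0.0], [0.2], [5.0], [5.2]])
assign = lambda d, P: np.argmin(((d[:, None, :] - P[None, :, :]) ** 2).sum(-1), 1)
gen_step = lambda gen, batch: None          # stand-in for a WGAN update
proto_step = lambda P, d, cells: np.array(  # k-means-style mean update
    [d[cells == j].mean(axis=0) for j in range(len(P))])
protos, _ = k_gans(data, np.array([[1.0], [4.0]]), [None, None],
                   assign, gen_step, proto_step, 3, 1, 1)
```

With these stubs the prototypes converge to the two cluster centers, mirroring the k-means limit discussed above; in the full algorithm both steps are SGD updates on the losses of Eqs. 27 and 29.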
8.1 The k-generators algorithm
The theory outlined in this paper is not specific to GANs and can be directly applied to any generative model based on the minimization of a statistical divergence. For example, the approach can be used with variational autoencoders Kingma and Welling (2014)
and with sequential generative models such as those used in natural language processing Sundermeyer et al. (2012).
9 Choosing the cost function
The clustering behavior of the k-GANs algorithm depends on the choice of the cost function $c$. The shape of the cost determines the boundaries between the sets of the Voronoi tessellation. These boundaries are in general curved, except when the cost is a monotonic function of a quadratic form.
9.1 Euclidean and feature costs
The simplest choice for the cost function is given by the $p$-th power of an $L^p$ norm:
The boundaries induced by this family of norms are well studied and lead to different clustering behaviors Hathaway et al. (2000). The most common choice is of course the familiar $L^2$ norm, which leads to the familiar (Euclidean) k-means clusters. However, $L^p$ norms can lead to sub-optimal clustering in highly structured data such as natural images, as the boundaries tend to be driven by low-level features and to ignore semantic information. A possible way of basing the partitioning on more semantic features is to consider a norm in an appropriate feature space:
where the feature map $\phi$ maps the raw data to a feature space. Usually, $\phi$ is chosen to be a deep network trained on a supervised task.
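A feature-space cost of the form above can be sketched as a higher-order function. The toy feature map here (comparing points by their norm only) is an arbitrary assumption standing in for a pretrained deep network.

```python
import numpy as np

def feature_cost(phi, p=2.0):
    """Build a cost c(x, y) = ||phi(x) - phi(y)||_p^p in feature space.
    phi would normally be a pretrained deep network; here it is arbitrary."""
    def cost(x, y):
        return float(np.sum(np.abs(phi(x) - phi(y)) ** p))
    return cost

# Toy feature map: compare points by their Euclidean norm only (an assumption).
phi = lambda x: np.array([np.linalg.norm(x)])
c = feature_cost(phi)
# Points on a circle of equal radius have zero cost in this feature space:
d = c(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
```

The example makes the point of the section concrete: the induced Voronoi boundaries live in feature space, so points that differ greatly in pixel space can still fall in the same cell if their features agree.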
9.2 Semi-supervised costs
Another interesting way of inserting semantic information into the cost function is to use labels on a subset of the data. For example, we can use a cost of the following form:
where one function assigns a discrete value based on whether the data-point is labeled and on its label, and a second function then scales the loss based on this label information. For example, the scaling can take one value when two data-points have the same label, a different value when the data-points have different labels, and be equal to one when one or both of the data-points are unlabeled. Note that, in order to use this semi-supervised cost in the k-GANs algorithm, we need to be able to assign labels to the prototypes. One possibility is to train a classifier on the labeled part of the dataset. Alternatively, we can simply select the label of the closest labeled data-point.
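A minimal sketch of such a semi-supervised cost is given below. The scaling factors (0.5 for matching labels, 2.0 for mismatching ones), the label dictionary, and the function names are illustrative assumptions, not the values used in the paper.

```python
import numpy as np

def semisupervised_cost(base_cost, labels, same=0.5, diff=2.0):
    """Scale a base cost by label agreement: factor `same` if both points
    share a label, `diff` if the labels differ, and 1.0 when either point
    is unlabeled (label None). The factors are illustrative only."""
    def cost(i, x, j, y):
        li, lj = labels.get(i), labels.get(j)
        if li is None or lj is None:
            u = 1.0
        else:
            u = same if li == lj else diff
        return u * base_cost(x, y)
    return cost

base = lambda x, y: float(np.sum((x - y) ** 2))
labels = {0: "cat", 1: "cat", 2: "dog"}        # index 3 is unlabeled
c = semisupervised_cost(base, labels)
x, y = np.array([0.0]), np.array([1.0])
c_same = c(0, x, 1, y)      # shrunk: same label
c_diff = c(0, x, 2, y)      # inflated: different labels
c_unlb = c(0, x, 3, y)      # unchanged: one point unlabeled
```

Shrinking the cost between same-label points pulls them toward the same Voronoi cell, which is the intended clustering effect of the label information.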
10 Experiments
In this section, we validate the k-GANs method on a clustered toy dataset and on two image datasets: MNIST LeCun et al. (1998) and fashion MNIST Xiao et al. (2017). We compare the performance of the k-GANs method against the performance of individual GANs. In all our experiments we used the Euclidean distance as cost function.
10.1 Toy dataset
We constructed several toy datasets (TD) comprised of randomly sampled coordinates on the 2D plane, masked to create circular clusters. The first TD had two circular clusters of data points falling within a radius of 0.25, centered on (-0.5, 0) and (0.5, 0); similarly, the second TD had three circular clusters of data points, centered on (-0.5, -0.5), (0.5, -0.5) and (0, 0.5); finally, the third TD had four circular clusters centered on (-0.5, -0.5), (0.5, -0.5), (0.5, 0.5) and (-0.5, 0.5). We trained Wasserstein k-GANs for $k$ ranging from 1 (baseline) to 4. We repeated the experiment 10 times. We used a 10-dimensional latent space for each of our generators. The generator network architecture was constructed as follows: a fully connected input layer of 32 units with batch normalization and leaky ReLU activation function, followed by fully connected layers of 16, 8 and finally 2 units. The first two layers had batch normalization and leaky ReLUs, while the last one had a sigmoid activation. The discriminator network had two fully connected layers with 16 and 8 units. For optimization, we used Adam for the generator and discriminator networks and for the prototypes. A burn-in of 600 iterations was used within the 60,000 training iterations of each prototype and the corresponding generator/discriminator networks, trained by minimizing the Wasserstein distance. Figure 1
shows the resulting tessellation and the samples produced by each generator. The case corresponding to $k$ equal to one is the baseline Wasserstein GAN. We evaluated the performance of the methods using two metrics: coverage and precision. Coverage is quantified by binning the plane in a 2D grid and counting the fraction of bins inside the circular masks that contain a generated data-point. The precision metric is given by the fraction of generated data-points that are inside the masks. We compared a GAN baseline with two k-GANs runs where $k$ was set equal to the number of clusters in the dataset. In one of these two runs the prototypes were initialized using k-means on the generated data, while in the other they were sampled randomly from uniform distributions ranging from -1 to 1. Figure 1 shows the metrics for all methods. Both k-GANs methods reach significantly higher performance than the baseline.
10.2 MNIST and fashion MNIST
We applied the k-GANs algorithm to MNIST and fashion MNIST. We trained Wasserstein k-GANs for $k$ ranging from 1 (baseline) to 4. Given our limited computational resources, we trained a single run of both models. Prototypes were initialized using the k-means algorithm, and samples were assigned to the nearest prototype in batches of 100 during training. We used a 100-dimensional latent space for each of our generators. We used the following generator network architecture: a fully connected input layer of 12544 units with batch normalization and leaky ReLU activation function (the output of which was reshaped to 256 x 7 x 7), followed by three deconvolution layers of 128, 64 and 1 units. The first two had batch normalization and leaky ReLUs, while the last one had a sigmoid activation. All of them had 5 x 5 kernels with a stride of 2 x 2, except for the first, which had a stride of 1 x 1. The discriminator network had two convolutional layers with 64 and 128 units of size 5 x 5 and stride 2 x 2, and a linear layer with a single output unit. Figure 2 shows the partition of the image space embedded into a 2D plane using a t-SNE embedding. The images inside the sets are their prototypes. Figure 2B shows the real images and generated samples corresponding to each prototype. Figure 3 shows prototypes and samples on MNIST and fashion MNIST for the baseline and for k-GANs. The k-GANs produced diversified samples, except for one of the generators in fashion MNIST that collapsed on its mode. On the other hand, both baseline models suffered from severe mode collapse. While it is difficult to draw strong conclusions from a single run, the results suggest that the k-GANs approach improves the stability of the base model.
11 Conclusions
In this paper, we introduced a method for training an ensemble of generators based on semi-discrete optimal transport theory. Each generator of the ensemble is associated with a prototype. These prototypes induce a partition of the data space, and each set of the partition is modeled by an individual generator. This protects the algorithm from mode collapse, as each generator only needs to cover a localized portion of the data space.
- Peyré and Cuturi  G. Peyré and M. Cuturi. Computational Optimal Transport. Now Publishers, Inc., 2017.
- Cuturi  M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 2013.
- Solomon et al.  J. Solomon, R. Rustamov, L. Guibas, and A. Butscher. Wasserstein propagation for semi-supervised learning. International Conference on Machine Learning, 2014.
- Kloeckner  B. R. Kloeckner. A geometric study of Wasserstein spaces: Ultrametrics. Mathematika, 61(1):162–178, 2015.
- Ho et al.  N. Ho, X. Nguyen, M. Yurochkin, H. H. Bui, V. Huynh, and D. Phung. Multilevel clustering via Wasserstein means. International Conference on Machine Learning, 2017.
- Arjovsky et al.  M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein generative adversarial networks. International Conference on Machine Learning, 2017.
- Patrini et al.  G. Patrini, M. Carioni, P. Forré, S. Bhargav, M. Welling, R. van den Berg, T. Genewein, and F. Nielsen. Sinkhorn autoencoders. arXiv preprint arXiv:1810.01118, 2018.
- Lee and Raginsky  J. Lee and M. Raginsky. Minimax statistical learning with Wasserstein distances. Advances in Neural Information Processing Systems, 2018.
- Genevay et al.  A. Genevay, G. Peyré, and M. Cuturi. Learning generative models with Sinkhorn divergences. Proceedings of the International Conference on Artificial Intelligence and Statistics, 2018.
- Staib et al.  M. Staib, S. Claici, J. M. Solomon, and S. Jegelka. Parallel streaming Wasserstein barycenters. Advances in Neural Information Processing Systems, pages 2647–2658, 2017.
- Ambrogioni et al.  L. Ambrogioni, U. Güçlü, Y. Güçlütürk, M. Hinne, E. Maris, and M. A. J. van Gerven. Wasserstein variational inference. Advances in Neural Information Processing Systems, pages 2473–2482, 2018.
- Mi et al.  L. Mi, W. Zhang, X. Gu, and Y. Wang. Variational Wasserstein clustering. European Conference on Computer Vision, 2018.
- Tolstikhin et al.  I. Tolstikhin, O. Bousquet, S. Gelly, and B. Schoelkopf. Wasserstein auto-encoders. International Conference on Machine Learning, 2018.
- Gulrajani et al.  I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville. Improved training of Wasserstein GANs. Advances in Neural Information Processing Systems, 2017.
- Adler and Lunz  J. Adler and S. Lunz. Banach Wasserstein GAN. Advances in Neural Information Processing Systems, 2018.
- Genevay et al.  A. Genevay, G. Peyré, and M. Cuturi. GAN and VAE from an optimal transport point of view. arXiv preprint arXiv:1706.01807, 2017.
- Gemici et al.  M. Gemici, Z. Akata, and M. Welling. Primal-dual Wasserstein GAN. arXiv preprint arXiv:1805.09575, 2018.
- Wang et al.  Y. Wang, L. Zhang, and J. van de Weijer. Ensembles of generative adversarial networks. arXiv preprint arXiv:1612.00991, 2016.
- Forgy  E. W. Forgy. Cluster analysis of multivariate data: efficiency versus interpretability of classifications. Biometrics, 21:768–769, 1965.
- Graf and Luschgy  S. Graf and H. Luschgy. Foundations of Quantization for Probability Distributions. Springer, 2007.
- Kingma and Welling  D. P. Kingma and M. Welling. Auto-encoding variational Bayes. International Conference on Learning Representations, 2014.
- Sundermeyer et al.  M. Sundermeyer, R. Schlüter, and H. Ney. LSTM neural networks for language modeling. Annual Conference of the International Speech Communication Association, 2012.
- Hathaway et al.  R. J. Hathaway, J. C. Bezdek, and Y. Hu. Generalized fuzzy c-means clustering strategies using Lp norm distances. IEEE Transactions on Fuzzy Systems, 8(5):576–582, 2000.
- Xiao et al.  H. Xiao, K. Rasul, and R. Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
- LeCun et al.  Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, et al. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.