k-GANs: Ensemble of Generative Models with Semi-Discrete Optimal Transport

07/09/2019 ∙ by Luca Ambrogioni, et al. ∙ Radboud Universiteit

Generative adversarial networks (GANs) are the state of the art in generative modeling. Unfortunately, most GAN methods are susceptible to mode collapse, meaning that they tend to capture only a subset of the modes of the true distribution. A possible way of dealing with this problem is to use an ensemble of GANs, where (ideally) each network models a single mode. In this paper, we introduce a principled method for training an ensemble of GANs using semi-discrete optimal transport theory. In our approach, each generative network models the transportation map between a point mass (Dirac measure) and the restriction of the data distribution to a tile of a Voronoi tessellation that is defined by the location of the point masses. We iteratively train the generative networks and the point masses until convergence. The resulting k-GANs algorithm has a strong theoretical connection with the k-medoids algorithm. In our experiments, we show that our ensemble method consistently outperforms baseline GANs.




1 Introduction

Optimal transport theory is becoming an essential tool for modern machine learning research. The development of efficient optimal transport algorithms has led to a wide range of machine learning applications Peyré and Cuturi (2017); Cuturi (2013); Solomon et al. (2014); Kloeckner (2015); Ho et al. (2017); Arjovsky et al. (2017); Patrini et al. (2018); Lee and Raginsky (2018); A. et al. (2018); Staib et al. (2017); Ambrogioni et al. (2018); Mi et al. (2018). A notable example is approximate Bayesian inference, where optimal transport techniques have been used for constructing probabilistic autoencoders Tolstikhin et al. (2018); Patrini et al. (2018) and for general-purpose variational Bayesian inference Ambrogioni et al. (2018). However, the field that has been most deeply influenced by optimal transport theory is arguably generative modeling Arjovsky et al. (2017); A. et al. (2018); Gulrajani et al. (2017); Adler and Lunz (2018); Genevay et al. (2017); Gemici et al. (2018). The introduction of the Wasserstein generative adversarial network (wGAN) Arjovsky et al. (2017) was a milestone, as it provided a more stable form of adversarial training. Generative adversarial networks (GANs) have greatly improved the state of the art in image generation. Nevertheless, GAN training often leads to mode collapse, where part of the data space is ignored by the generative model. A possible way to mitigate this phenomenon is to use a collection of GANs, each modeling a part of the data space Wang et al. (2016). However, existing ensembling techniques are often heuristic in nature and do not provide a principled way of ensuring that the different generators model non-overlapping parts of the data space.

In this paper, we derive an algorithm for training an ensemble of generative models from first principles, using the theory of semi-discrete optimal transport. The basic idea is to jointly learn a series of elements (prototypes) of a discrete distribution and the optimal transportation functions from these prototypes to the data space. The notion of optimality is determined by a transportation cost that measures the dissimilarity between the prototypes and the data points. The resulting k-GANs algorithm has strong theoretical connections with the k-means and k-medoids algorithms. In the k-GANs algorithm we learn k prototypes that implicitly define a partitioning of the data space into non-overlapping cells. The distribution of the data in each of these cells is generated by a stochastic transportation function that maps the prototype into the data space. These transportation functions are parameterized by deep networks and are trained as regular GANs within their cell. The prototypes and the transportation functions are learned jointly, so that the boundaries of the cells shift during training as a consequence of the changes in the prototypes.

2 Related work

From a theoretical point of view, our algorithm has a strong connection with the traditional k-means and k-medoids clustering methods Forgy (1965); Graf and Luschgy (2007). This connection between k-means and semi-discrete optimal transport stems from the fact that semi-discrete transport problems implicitly define a Laguerre tessellation of the space, which reduces to the more familiar Voronoi tessellation in special cases Peyré and Cuturi (2017); Graf and Luschgy (2007). Recently, this connection has been exploited in a variational clustering algorithm that uses optimal transport theory to derive a more powerful clustering method Mi et al. (2018).

3 Background on optimal transport

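In machine learning and statistics, optimal transport divergences are often used for comparing probability measures. Consider two probability measures $\mu$ and $\nu$, defined on the spaces $X$ and $Y$ respectively. The optimal transport divergence between them is defined by the following optimization problem:

$$ \mathcal{W}_c(\mu, \nu) = \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{X \times Y} c(x, y)\, \mathrm{d}\gamma(x, y), \tag{1} $$

where $c(x, y)$ is the cost of transporting probability mass from $x$ to $y$ and $\Gamma(\mu, \nu)$ is the set of probability measures that have $\mu$ and $\nu$ as marginal measures over $X$ and $Y$ respectively. The transportation nature of this problem can be seen by slightly reformulating the objective function, writing the joint measure $\gamma$ as the product of a conditional measure $\gamma(\cdot \mid x)$ and of the marginal measure $\mu$:

$$ \mathcal{W}_c(\mu, \nu) = \inf_{\gamma(\cdot \mid x)} \int_X \int_Y c(x, y)\, \mathrm{d}\gamma(y \mid x)\, \mathrm{d}\mu(x), \tag{2} $$

where the integral of the conditional measures has to be equal to $\nu$:

$$ \int_X \gamma(\cdot \mid x)\, \mathrm{d}\mu(x) = \nu. \tag{3} $$

Therefore, the conditional measures $\gamma(\cdot \mid x)$ can be interpreted as stochastic transportation functions that map each element of $X$ into a probability measure over $Y$.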
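As a concrete check of Eq. 1, consider the special case of two uniform empirical measures with the same number n of atoms. For uniform weights an optimal coupling can be chosen to be a permutation (Birkhoff–von Neumann theorem), so for very small n a brute-force search over permutations already solves the problem. The function name `discrete_ot_uniform` and the squared Euclidean cost are illustrative assumptions, not part of the paper; practical solvers use linear programming or entropic regularization instead.

```python
import itertools


def sq_euclidean(x, y):
    """Squared Euclidean cost c(x, y) = ||x - y||^2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def discrete_ot_uniform(xs, ys, cost=sq_euclidean):
    """Optimal transport cost between two uniform empirical measures
    with the same number of atoms (brute force, tiny n only).

    With weights 1/n on both sides an optimal coupling can be taken to be a
    permutation, so we minimise the average cost over all permutations.
    Returns (cost, perm): the mass at xs[i] is sent to ys[perm[i]].
    """
    n = len(xs)
    if n != len(ys):
        raise ValueError("both measures must have the same number of atoms")
    best_cost, best_perm = float("inf"), None
    for perm in itertools.permutations(range(n)):
        total = sum(cost(xs[i], ys[perm[i]]) for i in range(n)) / n
        if total < best_cost:
            best_cost, best_perm = total, perm
    return best_cost, best_perm


xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(0.0, 1.1), (0.1, 0.0), (1.0, 0.2)]
cost, perm = discrete_ot_uniform(xs, ys)
print(round(cost, 4), perm)  # prints 0.02 (1, 2, 0)
```

Each point is matched to the nearby target, and the divergence of a measure with itself is zero, as expected.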

4 Semi-discrete optimal transport as ensemble of generative models

One of the main advantages of using optimal transport divergences is that they can be defined between probability measures with very different supports. An important case is semi-discrete transport, where $\nu$ is absolutely continuous with respect to the Lebesgue measure while $\mu$ is a discrete measure, that is, a weighted sum of Dirac measures located at $k$ prototypes $x_1, \dots, x_k$:

$$ \mathcal{W}_c\Big(\sum_{j=1}^{k} w_j\, \delta_{x_j},\ \nu\Big), \qquad w_j \ge 0, \quad \sum_{j=1}^{k} w_j = 1. \tag{4} $$

Semi-discrete optimal transport has important machine learning applications. For our purposes, the minimization of a semi-discrete optimal transport divergence can be used for approximating the probability distribution of the data with a discrete distribution over a finite number of “prototypes”. Writing $\nu_j = \gamma(\cdot \mid x_j)$ for the conditional measure attached to the $j$-th prototype, the semi-discrete optimal transport divergence can be rewritten as follows:

$$ \mathcal{W}_c\Big(\sum_{j=1}^{k} w_j\, \delta_{x_j},\ \nu\Big) = \min_{\nu_1, \dots, \nu_k} \sum_{j=1}^{k} w_j \int_Y c(x_j, y)\, \mathrm{d}\nu_j(y), \tag{5} $$

with the following constraint:

$$ \sum_{j=1}^{k} w_j\, \nu_j = \nu. \tag{6} $$

Note that each conditional measure $\nu_j$ can be interpreted as a generative model that maps the prototype $x_j$ into a probability distribution over the data points $y$. The optimization in Eq. 5 ensures that these distributions are concentrated around their prototype (in a sense given by the cost function), while the marginalization constraint in Eq. 6 guarantees that the weighted sum of all generative models is the real distribution of the data. In other words, the solution of this semi-discrete optimal transport problem provides an ensemble of local generative models.

5 Geometry of semi-discrete optimal transport problems

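In this section we summarize some known results on semi-discrete optimal transport that provide the theoretical foundation for our work. Semi-discrete optimal transport has a deep connection with geometry, as it can be proven that the transportation maps are piecewise constant and define a tessellation of the space $Y$. In order to show this, it is useful to introduce the unconstrained dual formulation of the optimal transport problem Peyré and Cuturi (2017):

$$ \mathcal{W}_c\Big(\sum_{j=1}^{k} w_j\, \delta_{x_j},\ \nu\Big) = \max_{\varphi \in \mathbb{R}^k}\ \sum_{j=1}^{k} w_j \varphi_j + \int_Y \varphi^{c}(y)\, \mathrm{d}\nu(y), \tag{7} $$

where $\varphi^{c}$ denotes the c-transform of the vector of dual weights $\varphi = (\varphi_1, \dots, \varphi_k)$:

$$ \varphi^{c}(y) = \min_{j \in \{1, \dots, k\}} \big( c(x_j, y) - \varphi_j \big). \tag{8} $$

We can now reformulate the objective function in terms of a tessellation of $Y$:

$$ \mathcal{E}(\varphi) = \sum_{j=1}^{k} w_j \varphi_j + \sum_{j=1}^{k} \int_{L_j(\varphi)} \big( c(x_j, y) - \varphi_j \big)\, \mathrm{d}\nu(y), \tag{9} $$

where the sets $L_j(\varphi)$ are defined as

$$ L_j(\varphi) = \big\{ y \in Y : c(x_j, y) - \varphi_j \le c(x_i, y) - \varphi_i \ \text{ for all } i \big\}, \tag{10} $$

yielding a so-called Laguerre tessellation of $Y$. We can finally state the following important theorem, expressing the transportation maps in terms of the optimized Laguerre tessellation: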

Theorem 1.

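The optimal transportation maps in the optimization problem in Eq. 4 are given by the following formula:

$$ \nu_j^{*} = \gamma^{*}(\cdot \mid x_j) = \nu|_{L_j(\varphi^{*})}, \tag{11} $$

where $\nu|_{L}$ denotes the renormalized restriction of the probability measure $\nu$ to the set $L$, i.e. $\nu|_{L}(A) = \nu(A \cap L) / \nu(L)$, and $\varphi^{*}$ is the solution of the optimization in Eq. 7.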


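Proof. The derivative of Eq. 9 with respect to $\varphi_j$ is given by Peyré and Cuturi (2017):

$$ \frac{\partial \mathcal{E}(\varphi)}{\partial \varphi_j} = w_j - \nu\big(L_j(\varphi)\big). \tag{12} $$

Since the problem is unconstrained, this implies that, for the optimal dual weights $\varphi^{*}$, we have

$$ w_j = \nu\big(L_j(\varphi^{*})\big), \qquad j = 1, \dots, k. \tag{13} $$

By plugging this result into Eq. 9, the terms involving $\varphi^{*}$ cancel, since $\sum_j w_j \varphi^{*}_j - \sum_j \varphi^{*}_j\, \nu(L_j(\varphi^{*})) = 0$, and we obtain:

$$ \mathcal{W}_c\Big(\sum_{j=1}^{k} w_j\, \delta_{x_j},\ \nu\Big) = \sum_{j=1}^{k} \int_{L_j(\varphi^{*})} c(x_j, y)\, \mathrm{d}\nu(y) = \sum_{j=1}^{k} w_j \int_Y c(x_j, y)\, \mathrm{d}\nu|_{L_j(\varphi^{*})}(y). \tag{14} $$

By comparing this expression with the primal formulation in Eq. 5, it follows immediately that the optimal transportation maps are given by the measures $\nu|_{L_j(\varphi^{*})}$. These measures also satisfy the constraint in Eq. 6, because the Laguerre cells partition $Y$ and therefore $\sum_j w_j\, \nu|_{L_j(\varphi^{*})} = \sum_j \nu(\cdot \cap L_j(\varphi^{*})) = \nu$.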
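The proof suggests a simple numerical procedure: the derivative of the dual objective with respect to the j-th dual weight is the prototype weight w_j minus the data mass of the j-th Laguerre cell, so gradient ascent on the dual weights stops exactly when every cell carries mass w_j. The sketch below runs this ascent on an empirical data measure (uniform samples on the unit square) with the squared Euclidean cost and a fixed step size; these choices and all function names are assumptions for illustration, and plain gradient ascent is not claimed to be the solver used by the authors.

```python
import random


def cost(x, y):
    """Squared Euclidean cost c(x, y) = ||x - y||^2 (illustrative choice)."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def laguerre_cell(y, prototypes, phi):
    """Index j of the Laguerre cell containing y: argmin_j c(x_j, y) - phi_j."""
    return min(range(len(prototypes)), key=lambda j: cost(prototypes[j], y) - phi[j])


def cell_masses(samples, prototypes, phi):
    """Empirical data mass of every Laguerre cell."""
    counts = [0] * len(prototypes)
    for y in samples:
        counts[laguerre_cell(y, prototypes, phi)] += 1
    return [count / len(samples) for count in counts]


def dual_ascent(samples, prototypes, weights, lr=0.5, steps=100):
    """Gradient ascent on the (concave) semi-discrete dual objective.

    The partial derivative w.r.t. phi_j is w_j - mass(L_j(phi)), so at a
    stationary point every cell holds exactly the prescribed weight w_j.
    """
    phi = [0.0] * len(prototypes)
    for _ in range(steps):
        masses = cell_masses(samples, prototypes, phi)
        phi = [p + lr * (w - m) for p, w, m in zip(phi, weights, masses)]
    return phi, cell_masses(samples, prototypes, phi)


rng = random.Random(0)
data = [(rng.random(), rng.random()) for _ in range(1000)]  # empirical data measure
protos = [(0.25, 0.5), (0.75, 0.5)]
target = [0.3, 0.7]                                         # prescribed weights w_j
phi, masses = dual_ascent(data, protos, target)
print([round(m, 2) for m in masses], [round(p, 2) for p in phi])
```

With equal dual weights the two cells would split the square in half; the ascent lowers the first dual weight and raises the second until the boundary moves so that the left cell holds roughly 30% of the samples.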

6 Simultaneous optimization of weights and transportation maps

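In this section, we prove the main theoretical result behind the method. Consider the problem of finding the set of prototypes, weights and transportation maps that minimize the semi-discrete optimal transport divergence in Eq. 4. This results in the following joint optimization problem:

$$ \min_{x_1, \dots, x_k}\ \min_{w \in \Delta_k}\ \min_{\nu_1, \dots, \nu_k}\ \sum_{j=1}^{k} w_j \int_Y c(x_j, y)\, \mathrm{d}\nu_j(y) \quad \text{subject to} \quad \sum_{j=1}^{k} w_j\, \nu_j = \nu, \tag{15} $$

where $\Delta_k$ denotes the probability simplex. The solution of this optimization problem is an optimal ensemble of generative models.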

From the previous section, we know that the solution of the semi-discrete optimal transport problem is given by a tessellation of the target space $Y$ into Laguerre cells parameterized by a vector of dual weights $\varphi$. These cells are the support sets of the transportation maps. In the general case, these cells can be computed using computational geometry algorithms Peyré and Cuturi (2017). Fortunately, the problem can be solved in closed form if we simultaneously optimize the weights and the transportation maps, as stated in the following theorem:

Theorem 2 (Formal solution of the joint optimization problem).

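The optimization problem given by

$$ \min_{w \in \Delta_k}\ \min_{\nu_1, \dots, \nu_k}\ \sum_{j=1}^{k} w_j \int_Y c(x_j, y)\, \mathrm{d}\nu_j(y), \tag{16} $$

under the marginalization constraint given in Eq. 3, is solved by the following Voronoi tessellation:

$$ V_j = \big\{ y \in Y : c(x_j, y) \le c(x_i, y) \ \text{ for all } i \big\}, \tag{17} $$

where the transportation maps are obtained by restricting the data distribution to each set of the tessellation,

$$ \nu_j^{*} = \nu|_{V_j}, \tag{18} $$

and the optimal weights are given by

$$ w_j^{*} = \nu(V_j). \tag{19} $$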
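Proof. It is easier to work with the dual optimization problem (see Eq. 9). Enforcing the constraint that the weight vector sums to one with a Lagrange multiplier $\lambda$, we obtain the following unconstrained minimax optimization:

$$ \min_{w}\ \max_{\varphi, \lambda}\ \sum_{j=1}^{k} w_j \varphi_j + \sum_{j=1}^{k} \int_{L_j(\varphi)} \big( c(x_j, y) - \varphi_j \big)\, \mathrm{d}\nu(y) + \lambda \Big( \sum_{j=1}^{k} w_j - 1 \Big). \tag{20} $$

We can find the critical point by setting the gradient to zero:

$$ \frac{\partial}{\partial \varphi_j}: \quad w_j - \nu\big(L_j(\varphi)\big) = 0, \tag{21} $$

$$ \frac{\partial}{\partial w_j}: \quad \varphi_j + \lambda = 0. \tag{22} $$

The second equation implies that all the dual weights are equal to the same constant $-\lambda$. Since a common shift of all dual weights cancels in the definition of the Laguerre cells, the Laguerre sets are Voronoi sets (Eq. 17). The first equation gives Eq. 19, and the transportation maps in Eq. 18 are a consequence of Theorem 1. Note that the resulting weights clearly respect the marginalization constraint.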

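Using Theorem 2, we can write a simple expression for the optimal prototypes:

$$ x_j^{*} = \arg\min_{x} \int_{V_j} c(x, y)\, \mathrm{d}\nu(y). \tag{23} $$

In other words, the prototypes are the medoids of the Voronoi sets with respect to the cost function $c$. Since the sets $V_j$ depend on the prototypes themselves, this is a fixed-point condition, as in the k-medoids algorithm.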
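Theorem 2 and the medoid characterization can be illustrated on an empirical measure: assign each sample to its nearest prototype (Eq. 17), read off the weights as the fraction of samples in each cell, and move each prototype to the medoid of its cell. The sketch restricts the medoid to data points, as in classical k-medoids; the helper names and the absolute-value cost on the real line are assumptions chosen to keep the example small.

```python
def abs_cost(x, y):
    """c(x, y) = |x - y| on the real line (illustrative choice)."""
    return abs(x - y)


def voronoi_cells(samples, prototypes, cost):
    """Assign every sample to the cell of its nearest prototype."""
    cells = [[] for _ in prototypes]
    for y in samples:
        j = min(range(len(prototypes)), key=lambda i: cost(prototypes[i], y))
        cells[j].append(y)
    return cells


def optimal_weights(cells, n):
    """Weight of each prototype = empirical data mass of its Voronoi cell."""
    return [len(cell) / n for cell in cells]


def medoid(cell, cost):
    """Data point of the cell with the smallest summed cost to the cell."""
    return min(cell, key=lambda x: sum(cost(x, y) for y in cell))


samples = [0.0, 0.1, 0.2, 1.9, 2.0, 2.1, 2.2, 4.0]
protos = [0.0, 3.0]
cells = voronoi_cells(samples, protos, abs_cost)
weights = optimal_weights(cells, len(samples))
new_protos = [medoid(cell, abs_cost) for cell in cells]
print(cells, weights, new_protos)
```

The weights sum to one by construction, which is the marginalization constraint; repeating the assignment and medoid steps is one round of k-medoids.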

7 Learning the transportation maps

The solution given in Eq. 15 is purely formal and does not directly provide a useful algorithm, since the distribution $\nu$ that generated the data is not available. However, a practical algorithm can be obtained by minimizing a statistical divergence between this formal solution and a parametric model, such as a deep generative network. We denote the probability measure induced by passing a latent measure $\eta$ through a deep network $G_j$ as $G_{j*}\eta$ (where the bottom star denotes the pushforward of a measure through a measurable function). We approximate each optimal transportation map as follows:

$$ \nu|_{V_j} \approx G_{j*}\eta. \tag{24} $$

In practice, in the most naive implementation, this means that we reject samples that land outside $V_j$. We train each network by minimizing a statistical divergence:

$$ \min_{G_j}\ D\big(\nu|_{V_j},\, G_{j*}\eta\big). \tag{25} $$
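This is possible when the divergence $D$ solely requires the ability to sample from $\nu|_{V_j}$, since we can sample from the dataset and reject all samples that do not land in $V_j$. For example, using the dual formulation of the Wasserstein distance Arjovsky et al. (2017),

$$ \mathcal{W}_1\big(\nu|_{V_j},\, G_{j*}\eta\big) = \sup_{f \in \mathcal{L}}\ \mathbb{E}_{y \sim \nu|_{V_j}}\big[f(y)\big] - \mathbb{E}_{z \sim \eta}\big[f(G_j(z))\big], \tag{26} $$

we can train the generators (Wasserstein GANs) by optimizing the following minimax problem using stochastic gradient descent (SGD):

$$ \min_{G_j}\ \max_{f_j \in \mathcal{L}}\ \mathbb{E}_{y \sim \nu|_{V_j}}\big[f_j(y)\big] - \mathbb{E}_{z \sim \eta}\big[f_j(G_j(z))\big], \tag{27} $$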
where $\mathcal{L}$ is the space of 1-Lipschitz continuous functions. In practice, we approximate the samples from $\nu|_{V_j}$ with samples from a finite dataset, and we parameterize both $f_j$ and $G_j$ as deep neural networks. Furthermore, $f_j$ is regularized by the following soft Lipschitz regularization term:


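Using the trained generators, we can obtain a parameterized proxy for the prototype loss (the integral of the cost over each Voronoi cell) that we will use to train the prototypes:

$$ \mathcal{L}_j(x_j) = \mathbb{E}_{z \sim \eta}\big[ c\big(x_j, G_j(z)\big) \big]. \tag{29} $$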
8 The k-GANs algorithm

We are finally ready to formulate the algorithm. The basic idea is to minimize Eq. 15 using a two-step approach similar to the expectation-maximization scheme used in the k-means algorithm Forgy (1965). In the first step, we keep the generators fixed and we train the prototypes by minimizing Eq. 29 with SGD steps. In the second step, we keep the prototypes (and consequently the tessellation) fixed and we train the generators by minimizing Eq. 27 with SGD steps (cf. Algorithm 1).
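The two non-adversarial ingredients of this scheme can be sketched without any deep learning: sampling from the data restricted to the j-th Voronoi cell by rejection (draw data points, keep those whose nearest prototype is prototype j), and the sample-based prototype loss of Eq. 29. In the sketch a "generator" is any callable mapping a latent draw z to a data point; the 1-D standard normal latent, the squared Euclidean cost, the hypothetical `toy_generator` and all helper names are assumptions. For the squared Euclidean cost the minimizer of the Monte Carlo loss is the mean of the generated samples, so the sketch computes it directly instead of taking SGD steps.

```python
import random


def sq_cost(x, y):
    """Squared Euclidean cost c(x, y) = ||x - y||^2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def nearest(y, prototypes, cost=sq_cost):
    """Index of the Voronoi cell (nearest prototype) containing y."""
    return min(range(len(prototypes)), key=lambda i: cost(prototypes[i], y))


def sample_restricted(dataset, prototypes, j, batch_size, rng, max_tries=100_000):
    """Rejection sampling from the data restricted to cell j."""
    batch = []
    for _ in range(max_tries):
        y = rng.choice(dataset)
        if nearest(y, prototypes) == j:  # reject samples outside cell j
            batch.append(y)
            if len(batch) == batch_size:
                return batch
    raise RuntimeError(f"cell {j} appears to be (nearly) empty")


def prototype_loss(x, generator, n_samples, rng, cost=sq_cost):
    """Monte Carlo estimate of E_z[c(x, G_j(z))] with z ~ N(0, 1)."""
    total = 0.0
    for _ in range(n_samples):
        total += cost(x, generator(rng.gauss(0.0, 1.0)))
    return total / n_samples


def best_prototype_sq(generator, n_samples, rng):
    """Minimiser of the Monte Carlo loss for the squared Euclidean cost: the sample mean."""
    points = [generator(rng.gauss(0.0, 1.0)) for _ in range(n_samples)]
    dim = len(points[0])
    return tuple(sum(p[d] for p in points) / n_samples for d in range(dim))


def toy_generator(z):
    """Hypothetical stand-in for a trained generator G_1."""
    return (1.05 + 0.01 * z, 0.0)


rng = random.Random(0)
dataset = ([(-1.0 + 0.1 * rng.random(), 0.0) for _ in range(50)]
           + [(1.0 + 0.1 * rng.random(), 0.0) for _ in range(50)])
protos = [(-1.0, 0.0), (1.0, 0.0)]
batch = sample_restricted(dataset, protos, j=1, batch_size=10, rng=rng)
new_x1 = best_prototype_sq(toy_generator, 1000, rng)
loss_old = prototype_loss(protos[1], toy_generator, 1000, rng)
loss_new = prototype_loss(new_x1, toy_generator, 1000, rng)
print(new_x1, loss_old, loss_new)
```

Rejection sampling is the "most naive implementation" mentioned in the text; its cost grows as a cell's share of the data shrinks, which is why the sketch raises an error for (nearly) empty cells.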

We named this algorithm k-GANs since it can be interpreted as a parametric version of the well-known k-medoids method. Specifically, a stochastic version of k-medoids is obtained if we replace the trained deep generators with nonparametric generators that sample with uniform probability the elements of the training set that belong to the corresponding Voronoi set. This further reduces to a stochastic version of the k-means algorithm if we use the squared Euclidean distance as cost function.

1: procedure k-GANs(k, N, M)
2:     Initialize k generators $G_1, \dots, G_k$
3:     Initialize k discriminators $f_1, \dots, f_k$
4:     Initialize the prototypes $x_1, \dots, x_k$
5:     for n from 1 to N do                              ▷ loop over epochs
6:         for j from 1 to k do                          ▷ loop over GANs
7:             for m from 1 to M do                      ▷ loop over iterations
8:                 Sample a batch from the dataset
9:                 Reject samples outside the set $V_j$
10:                Train discriminator and generator using (Eq. 27, 28)
11:                Train prototypes using samples from the generator (Eq. 29)
Algorithm 1 k-GANs. k: Number of GANs, N: Number of epochs, M: Number of iterations per epoch.
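The reduction to stochastic k-medoids and k-means described above can be sketched by keeping the loop structure of Algorithm 1 but replacing each generator with a nonparametric one that samples uniformly from the training points of its Voronoi cell. With the squared Euclidean cost, the prototype update becomes an SGD step toward the mean of the sampled batch, i.e. stochastic k-means. The function name, step size, batch size and initialization from random data points are illustrative assumptions, not the authors' settings, and no discriminators are needed because nothing adversarial is trained.

```python
import random


def sq_cost(x, y):
    """Squared Euclidean cost c(x, y) = ||x - y||^2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def k_gans_nonparametric(data, k, epochs, iters, batch_size=8, lr=0.1, seed=0):
    """Algorithm 1 with nonparametric 'generators' (a stochastic k-means sketch)."""
    rng = random.Random(seed)
    prototypes = [list(p) for p in rng.sample(data, k)]  # initialize from data points
    for _ in range(epochs):                              # loop over epochs
        # Tessellation induced by the current prototypes; generator j simply
        # resamples the training points that fall in cell j.
        cells = [[] for _ in range(k)]
        for y in data:
            cells[min(range(k), key=lambda i: sq_cost(prototypes[i], y))].append(y)
        for j in range(k):                               # loop over "GANs"
            if not cells[j]:
                continue                                 # empty cell: nothing to sample
            for _ in range(iters):                       # loop over iterations
                batch = [rng.choice(cells[j]) for _ in range(batch_size)]
                # SGD step on the batch estimate of E[||x_j - y||^2],
                # whose gradient is 2 * (x_j - mean(batch)).
                for d in range(len(prototypes[j])):
                    grad = 2.0 * sum(prototypes[j][d] - y[d] for y in batch) / batch_size
                    prototypes[j][d] -= lr * grad
    return [tuple(p) for p in prototypes]


rng = random.Random(1)
data = ([(rng.gauss(-2.0, 0.3), rng.gauss(0.0, 0.3)) for _ in range(100)]
        + [(rng.gauss(2.0, 0.3), rng.gauss(0.0, 0.3)) for _ in range(100)])
protos = k_gans_nonparametric(data, k=2, epochs=10, iters=20)
print(sorted(round(p[0], 1) for p in protos))
```

In the full method the inner loop would instead take adversarial training steps on a generator/discriminator pair and update the prototype with generated samples, as in lines 10 and 11 of Algorithm 1.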

8.1 The k-generators algorithm

The theory outlined in this paper is not specific to GANs and can be directly applied to any generative model based on the minimization of a statistical divergence. For example, the approach can be used with variational autoencoders Kingma and Welling (2014) and with sequential generative models such as those used in natural language processing Sundermeyer et al. (2012).

9 Choosing the cost function

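The clustering behavior of the k-GANs algorithm depends on the choice of the cost function $c$. The shape of the cost determines the boundaries between the sets of the Voronoi tessellation. These boundaries are in general curved, except when $c$ is a monotonic function of a quadratic form, for instance $c(x, y) = g\big((x - y)^{\top} A (x - y)\big)$ with $g$ strictly increasing and $A$ symmetric positive definite, in which case the boundaries are hyperplanes.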

9.1 Euclidean and feature costs

The simplest choice for the cost function is given by the $p$-th power of an $\ell_p$ norm:

$$ c(x, y) = \| x - y \|_p^p. \tag{30} $$

The boundaries induced by this family of norms are well studied and lead to different clustering behaviors Hathaway et al. (2000). The most common choice is $p = 2$, the squared Euclidean distance, which leads to the familiar k-means clusters. However, $\ell_p$ norms can lead to sub-optimal clustering in highly structured data such as natural images, as the boundaries tend to be driven by low-level features and ignore semantic information. A possible way of basing the partitioning on more semantic features is to consider a norm in an appropriate feature space:

$$ c(x, y) = \| \Phi(x) - \Phi(y) \|_p^p, \tag{31} $$

where the feature map $\Phi$ maps the raw data to a feature space. Usually, $\Phi$ is chosen to be a deep network trained on a supervised task.
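The effect of the cost on the tessellation can be checked on a toy example: with the same two prototypes, the same point can fall into different cells under the first power of the $\ell_1$ norm and the squared $\ell_2$ norm, and a feature cost compares points only through the feature map. Here the feature map `first_coordinate` is a hypothetical hand-written stand-in for a trained network, and the prototypes and test point are chosen purely for illustration.

```python
def lp_cost(p):
    """Return the cost c(x, y) = ||x - y||_p^p."""
    def cost(x, y):
        return sum(abs(a - b) ** p for a, b in zip(x, y))
    return cost


def feature_cost(feature_map, p=2):
    """Return the cost c(x, y) = ||Phi(x) - Phi(y)||_p^p for a feature map Phi."""
    base = lp_cost(p)

    def cost(x, y):
        return base(feature_map(x), feature_map(y))
    return cost


def assign(y, prototypes, cost):
    """Index of the Voronoi cell containing y under the given cost."""
    return min(range(len(prototypes)), key=lambda j: cost(prototypes[j], y))


def first_coordinate(x):
    """Hypothetical feature map: keeps only the first coordinate."""
    return (x[0],)


protos = [(0.0, 0.0), (4.0, 1.0)]
y = (1.4, 4.0)
print(assign(y, protos, lp_cost(1)), assign(y, protos, lp_cost(2)),
      assign(y, protos, feature_cost(first_coordinate)))  # prints 0 1 0
```

Under the $\ell_1$ cost the point is closer to the first prototype (5.4 versus 5.6), under the squared Euclidean cost it is closer to the second (17.96 versus 15.76), and the feature cost ignores the second coordinate entirely.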

9.2 Semi-supervised costs

Another interesting way to insert semantic information into the cost function is to use labels on a subset of the data. For example, we can have a cost of the following form:

$$ c(x, y) = h\big(\ell(x), \ell(y)\big)\, \tilde{c}(x, y), \tag{32} $$

where $\tilde{c}$ is a base cost such as the one in Eq. 30 and the function $\ell$ assigns a discrete value based on whether the data point is labeled and on its label. The function $h$ then scales the loss based on this label information. For example, $h$ can be equal to a constant $\alpha$ when two data points have the same label, equal to a constant $\beta$ when the data points have different labels, and equal to one when one or both of the data points are unlabeled. Note that, in order to use this semi-supervised cost in the k-GANs algorithm, we need to be able to assign a label to the prototypes. A possibility is to train a classifier on the labeled part of the dataset. Alternatively, we can simply select the label of the closest labeled data point.
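A minimal sketch of this label-dependent scaling follows. It assumes a squared Euclidean base cost, represents unlabeled points with the label `None`, and uses the illustrative values α = 0.5 (same label) and β = 2.0 (different labels), which are not taken from the paper. Prototype labels are copied from the closest labeled data point, the second option mentioned above; all names are hypothetical.

```python
def sq_cost(x, y):
    """Squared Euclidean base cost."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def make_semi_supervised_cost(alpha=0.5, beta=2.0, base=sq_cost):
    """Cost on (point, label) pairs: h(label_x, label_y) * base(x, y)."""
    def h(label_x, label_y):
        if label_x is None or label_y is None:
            return 1.0                      # at least one point unlabeled: no rescaling
        return alpha if label_x == label_y else beta

    def cost(xp, yp):
        (x, label_x), (y, label_y) = xp, yp
        return h(label_x, label_y) * base(x, y)

    return cost


def label_prototype(proto, labeled, base=sq_cost):
    """Give a prototype the label of its closest labeled data point."""
    _, label = min(labeled, key=lambda pair: base(proto, pair[0]))
    return (proto, label)


labeled = [((0.0, 0.0), "a"), ((2.0, 0.0), "b")]
protos = [label_prototype((0.4, 0.0), labeled), label_prototype((1.4, 0.0), labeled)]
cost = make_semi_supervised_cost()


def assign(point):
    """Voronoi cell of a (point, label) pair under the semi-supervised cost."""
    return min(range(len(protos)), key=lambda j: cost(protos[j], point))


print(assign(((0.8, 0.0), "b")), assign(((0.8, 0.0), None)))  # prints 1 0
```

The same location is assigned to the nearer prototype when it is unlabeled, but its label pulls it into the cell of the prototype that shares that label.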

10 Experiments

In this section we validate the k-GANs method on clustered toy datasets and on two image datasets: MNIST LeCun et al. (1998) and Fashion-MNIST Xiao et al. (2017). We compare the performance of the k-GANs algorithm against the performance of individual GANs. In all our experiments we used the Euclidean distance as cost function.

Figure 1: Results of the experiments on the toy datasets. A) Generated samples and Voronoi partition induced by the prototypes. The top row shows the results of the Wasserstein GAN baseline, while the bottom row shows the results for the k-GANs. B) Coverage and precision of the generated samples, aggregated over the three toy datasets.

10.1 Toy dataset

We constructed several toy datasets (TD) comprising randomly sampled coordinates on a 2D plane, which were masked to create circular clusters. The first TD had two circular clusters of data points that fell within a radius of 0.25, centered on (-0.5, 0) and (0.5, 0); similarly, the second TD had three circular clusters of data points, centered on (-0.5, -0.5), (0.5, -0.5) and (0, 0.5); finally, the third TD had four circular clusters centered on (-0.5, -0.5), (0.5, -0.5), (0.5, 0.5) and (-0.5, 0.5). We trained a Wasserstein k-GANs for k ranging from 1 (baseline) to 4. We repeated the experiment 10 times. We used a 10-dimensional latent space for each of our generators. The generator network architecture was constructed as follows: a fully connected input layer of 32 units with batch normalization and leaky ReLU activation function, followed by fully connected layers of 16, 8 and finally 2 units. The first two of these layers had batch normalization and leaky ReLUs, while the last one had a sigmoid activation. The discriminator network had two fully connected layers with 16 and 8 units. For optimization, we used Adam for the generator and discriminator networks as well as for the prototypes. Each prototype and the corresponding generator/discriminator networks were trained for 60 000 iterations by minimizing the Wasserstein distance, with a burn-in parameter of 600. Figure 1 shows the resulting tessellation and the samples produced by each generator. The case corresponding to k equal to one is the baseline Wasserstein GAN. We evaluated the performance of the methods using two metrics: coverage and precision. Coverage is quantified by binning the plane in a 2D grid and counting the fraction of bins inside the circular masks that contain a generated data point. Precision is the fraction of generated data points that are inside the masks. We compared a GAN baseline with two k-GANs runs in which k was set equal to the number of clusters in the dataset. In one of these two runs the prototypes were initialized using k-means on the generated data, while in the other they were sampled randomly from a uniform distribution ranging from -1 to 1. Figure 1B shows the metrics for all methods. Both k-GANs methods reach significantly higher performance than the baseline.
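The toy data and the two metrics can be sketched as follows. The text does not state the sampling region, grid resolution or sample sizes, so the uniform square [-1, 1]², the 40 x 40 grid, and the rule that a bin counts as "inside the masks" when its centre is inside a circle are assumptions; `make_toy_dataset`, `precision` and `coverage` are hypothetical names.

```python
import math
import random

CENTERS = {
    2: [(-0.5, 0.0), (0.5, 0.0)],
    3: [(-0.5, -0.5), (0.5, -0.5), (0.0, 0.5)],
    4: [(-0.5, -0.5), (0.5, -0.5), (0.5, 0.5), (-0.5, 0.5)],
}
RADIUS = 0.25


def in_mask(p, centers, radius=RADIUS):
    """True if p lies inside one of the circular masks."""
    return any(math.dist(p, c) <= radius for c in centers)


def make_toy_dataset(n_clusters, n_points, rng):
    """Uniform points on [-1, 1]^2 (assumed range), kept only inside the masks."""
    centers, data = CENTERS[n_clusters], []
    while len(data) < n_points:
        p = (rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0))
        if in_mask(p, centers):
            data.append(p)
    return data


def precision(generated, centers):
    """Fraction of generated points that fall inside the masks."""
    return sum(in_mask(p, centers) for p in generated) / len(generated)


def coverage(generated, centers, bins=40, lo=-1.0, hi=1.0):
    """Fraction of in-mask grid bins that contain at least one generated point."""
    width = (hi - lo) / bins
    occupied = set()
    for x, y in generated:
        i, j = int((x - lo) // width), int((y - lo) // width)
        if 0 <= i < bins and 0 <= j < bins:
            occupied.add((i, j))
    mask_bins = [(i, j) for i in range(bins) for j in range(bins)
                 if in_mask((lo + (i + 0.5) * width, lo + (j + 0.5) * width), centers)]
    return sum(b in occupied for b in mask_bins) / len(mask_bins)


rng = random.Random(0)
data = make_toy_dataset(2, 5000, rng)
collapsed = [p for p in data if p[0] < 0.0]  # mimics a generator that found only one mode
print(precision(data, CENTERS[2]), coverage(data, CENTERS[2]), coverage(collapsed, CENTERS[2]))
```

A mode-collapsed sample set can keep perfect precision while its coverage drops to about one half, which is why the two metrics are reported together.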

Figure 2: Results on MNIST with k = 4. A) Partition induced by the prototypes (black and white figures) in the t-SNE space. B) real (top row) and generated (bottom row) images corresponding to each prototype. The color surrounding the images matches the color scheme of the partition.

10.2 MNIST and fashion MNIST

We applied the k-GANs algorithm to MNIST and Fashion-MNIST. We trained a Wasserstein k-GANs for k ranging from 1 (baseline) to 4. Given our limited computational resources, we could only train a single run on each dataset. Prototypes were initialized using the k-means algorithm, and samples were assigned to the nearest prototype in batches of 100 during training. We used a 100-dimensional latent space for each of our generators. We used the following generator network architecture: a fully connected input layer of 12544 units with batch normalization and leaky ReLU activation function (whose output was reshaped to 256 x 7 x 7), followed by three deconvolution layers with 128, 64 and 1 channels. The first two had batch normalization and leaky ReLUs, while the last one had a sigmoid activation. All of them had 5 x 5 kernels with a stride of 2 x 2, except for the first, which had a stride of 1 x 1. The discriminator network had two convolutional layers with 64 and 128 channels, 5 x 5 kernels and stride 2 x 2, followed by a linear layer with a single output unit. Figure 2 shows the results corresponding to k = 4. The figure shows the partition of the image space embedded into a 2D plane using a t-SNE embedding. The images inside the sets are their prototypes. Figure 2B shows the real images and generated samples corresponding to each prototype. Figure 3 shows prototypes and samples on MNIST and Fashion-MNIST for the baseline and for the k-GANs. The k-GANs produced diverse samples, except for one of the generators on Fashion-MNIST, which collapsed on its mode. In contrast, both baseline models suffered from severe mode collapse. While it is difficult to draw strong conclusions from a single run, the results suggest that the k-GANs approach improves the stability of the base model.
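The stated layer sizes are consistent with 28 x 28 images: 12544 = 256 · 7 · 7, and with 5 x 5 kernels the strides 1, 2, 2 take the spatial size from 7 to 7, 14 and 28. The text does not give padding values, so the sketch below assumes padding 2 everywhere and, for the stride-2 transposed convolutions, an output padding of 1; it simply evaluates the standard convolution and transposed-convolution output-size formulas (dilation 1).

```python
def deconv_out(size, kernel, stride, padding, output_padding=0):
    """Spatial size after a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def conv_out(size, kernel, stride, padding):
    """Spatial size after a standard convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


assert 256 * 7 * 7 == 12544  # dense layer output reshaped to 256 x 7 x 7

# Generator: three 5x5 deconvolutions given as (channels, stride, assumed output_padding)
size = 7
for channels, stride, out_pad in [(128, 1, 0), (64, 2, 1), (1, 2, 1)]:
    size = deconv_out(size, kernel=5, stride=stride, padding=2, output_padding=out_pad)
    print(f"deconv -> {channels} x {size} x {size}")

# Discriminator: two 5x5, stride-2 convolutions applied to a 28 x 28 image
d = 28
for channels in (64, 128):
    d = conv_out(d, kernel=5, stride=2, padding=2)
    print(f"conv -> {channels} x {d} x {d}")
print("inputs to the final linear layer:", 128 * d * d)
```

Under these padding assumptions the generator outputs 1 x 28 x 28 images and the discriminator's linear layer receives 128 · 7 · 7 = 6272 features.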

Figure 3: Samples of k-GANs and baselines for MNIST and fashion MNIST.

11 Discussion

In this paper we introduce a method for training an ensemble of generators based on semi-discrete optimal transport theory. Each generator of the ensemble is associated with a prototype. These prototypes induce a partition of the data space, and each set of the partition is modeled by an individual generator. This makes the algorithm less prone to mode collapse, as each generator only needs to cover a localized portion of the data space.


  • Peyré and Cuturi [2017] G. Peyré and M. Cuturi. Computational Optimal Transport. Now Publishers, Inc., 2017.
  • Cuturi [2013] M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 2013.
  • Solomon et al. [2014] J. Solomon, R. Rustamov, L. Guibas, and A. Butscher. Wasserstein propagation for semi-supervised learning. International Conference on Machine Learning, 2014.
  • Kloeckner [2015] B. R. Kloeckner. A geometric study of Wasserstein spaces: Ultrametrics. Mathematika, 61(1):162–178, 2015.
  • Ho et al. [2017] N. Ho, X. Nguyen, M. Yurochkin, H. H. Bui, V. Huynh, and D. Phung. Multilevel clustering via Wasserstein means. International Conference on Machine Learning, 2017.
  • Arjovsky et al. [2017] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein generative adversarial networks. International Conference on Machine Learning, 2017.
  • Patrini et al. [2018] G. Patrini, M. Carioni, P. Forré, S. Bhargav, M. Welling, R. van den Berg, T. Genewein, and F. Nielsen. Sinkhorn autoencoders. arXiv preprint arXiv:1810.01118, 2018.
  • Lee and Raginsky [2018] J. Lee and M. Raginsky. Minimax statistical learning with Wasserstein distances. Advances in Neural Information Processing Systems, 2018.
  • A. et al. [2018] A. Genevay, G. Peyré, and M. Cuturi. Learning generative models with Sinkhorn divergences. Proceedings of the International Conference on Artificial Intelligence and Statistics, 2018.
  • Staib et al. [2017] M. Staib, S. Claici, J. M. Solomon, and S. Jegelka. Parallel streaming Wasserstein barycenters. Advances in Neural Information Processing Systems, pages 2647–2658, 2017.
  • Ambrogioni et al. [2018] L. Ambrogioni, U. Güçlü, Y. Güçlütürk, E. Maris, M. Hinne, and M. A. J. van Gerven. Wasserstein variational inference. Advances in Neural Information Processing Systems, pages 2473–2482, 2018.
  • Mi et al. [2018] L. Mi, W. Zhang, X. Gu, and Y. Wang. Variational Wasserstein clustering. European Conference on Computer Vision, 2018.
  • Tolstikhin et al. [2018] I. Tolstikhin, O. Bousquet, S. Gelly, and B. Schölkopf. Wasserstein auto-encoders. International Conference on Learning Representations, 2018.
  • Gulrajani et al. [2017] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville. Improved training of Wasserstein GANs. Advances in Neural Information Processing Systems, 2017.
  • Adler and Lunz [2018] J. Adler and S. Lunz. Banach Wasserstein GAN. Advances in Neural Information Processing Systems, 2018.
  • Genevay et al. [2017] A. Genevay, G. Peyré, and M. Cuturi. GAN and VAE from an optimal transport point of view. arXiv preprint arXiv:1706.01807, 2017.
  • Gemici et al. [2018] M. Gemici, Z. Akata, and M. Welling. Primal-dual Wasserstein GAN. arXiv preprint arXiv:1805.09575, 2018.
  • Wang et al. [2016] Y. Wang, L. Zhang, and J. van de Weijer. Ensembles of generative adversarial networks. arXiv preprint arXiv:1612.00991, 2016.
  • Forgy [1965] E. W. Forgy. Cluster analysis of multivariate data: efficiency versus interpretability of classifications. Biometrics, 21:768–769, 1965.
  • Graf and Luschgy [2007] S. Graf and H. Luschgy. Foundations of Quantization for Probability Distributions. Springer, 2007.
  • Kingma and Welling [2014] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. International Conference on Learning Representations, 2014.
  • Sundermeyer et al. [2012] M. Sundermeyer, R. Schlüter, and H. Ney. LSTM neural networks for language modeling. Annual Conference of the International Speech Communication Association, 2012.
  • Hathaway et al. [2000] R. J. Hathaway, J. C. Bezdek, and Y. Hu. Generalized fuzzy c-means clustering strategies using L_p norm distances. IEEE Transactions on Fuzzy Systems, 8(5):576–582, 2000.
  • Xiao et al. [2017] H. Xiao, K. Rasul, and R. Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
  • LeCun et al. [1998] Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, et al. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.