Variational Wasserstein gradient flow

The gradient flow of a function over the space of probability densities with respect to the Wasserstein metric often exhibits nice properties and has been utilized in several machine learning applications. The standard approach to compute the Wasserstein gradient flow is the finite difference method, which discretizes the underlying space over a grid and is not scalable. In this work, we propose a scalable proximal gradient type algorithm for Wasserstein gradient flow. The key of our method is a variational formulation of the objective function, which makes it possible to realize the JKO proximal map through a primal-dual optimization. This primal-dual problem can be efficiently solved by alternately updating the parameters in the inner and outer loops. Our framework covers all the classical Wasserstein gradient flows including the heat equation and the porous medium equation. We demonstrate the performance and scalability of our algorithm with several numerical examples.


1 Introduction

The Wasserstein gradient flow models the gradient dynamics over the space of probability densities with respect to the Wasserstein metric. It was first discovered by Jordan, Kinderlehrer, and Otto (JKO) in their seminal work (JorKinOtt98). They pointed out that the Fokker-Planck equation is in fact the Wasserstein gradient flow of the free energy, bringing tremendous physical insights to this type of partial differential equations (PDEs). Since then, the Wasserstein gradient flow has played an important role in optimal transport, PDEs, physics, machine learning, and many other areas (AmbGigSav08; Ott01; AdaDirPelZim11; San17; CarDuvPeySch17; FroPog20).

Despite the abundant theoretical results on the Wasserstein gradient flow established over the past decades (AmbGigSav08; San17), its computation remains a challenge. Most existing methods are based either on finite differences of the underlying PDEs or on finite-dimensional optimization; both require discretization of the underlying space (Pey15; benamou2016discretization; CarDuvPeySch17; LiLuWan20; CarCraWanWei21). The computational complexity of these methods scales exponentially with the problem dimension, making them unsuitable for cases involving probability densities over high-dimensional spaces.

Our goal is to develop a scalable method to compute the Wasserstein gradient flow without discretizing the underlying space. One target application we are specifically interested in is optimization over the space of probability densities. Many problems such as variational inference can be viewed as special cases of such optimization. We aim to establish a method for this type of optimization that is applicable to a large class of objective functions.

Our algorithm is based on the JKO scheme (JorKinOtt98), which is essentially a backward Euler time discretization of the continuous-time Wasserstein gradient flow. In each step of the JKO scheme, one needs to find a probability density that minimizes a weighted sum of the squared Wasserstein distance to the probability density at the previous step and the objective function. We reparametrize this problem in each step so that the optimization variable becomes the optimal transport map from the probability density at the previous step to the one we want to optimize, recasting the problem into a stochastic optimization framework. This transport map can be modeled either by a standard feed-forward network or by the gradient of an input convex neural network. The latter is justified by the fact that, for the quadratic cost, the optimal transport map between any two marginals (with the source marginal admitting a density) is the gradient of a convex function. Another crucial ingredient of our algorithm is a variational form of the objective function leveraging the $f$-divergence, which has been employed in multiple machine learning applications, such as generative models (nowozin2016f) and Bayesian inference. The variational form allows the evaluation of the objective with samples and without density estimation. At the end of the algorithm, a sequence of transport maps connecting an initial distribution and the target distribution is obtained. One can then sample from the target distribution by sampling from the initial distribution (often Gaussian) and then propagating these particles through the sequence of transport maps. When the transport map is modeled by the gradient of an input convex neural network, one can also evaluate the target density at every point.

Our contributions can be summarized as follows.
i). We develop a neural network based algorithm to compute Wasserstein gradient flow without spatial discretization. Our algorithm is applicable to any objective function that has a variational representation.

ii). We specialize our algorithm to three important cases where the objective functions are the Kullback-Leibler divergence, the generalized entropy, and the interaction energy.

iii). We apply our algorithm to several representative problems, including sampling and the aggregation-diffusion equation, and obtain respectable performance.

Related works: Most existing methods to compute Wasserstein gradient flows are finite difference based (Pey15; benamou2016discretization; CarDuvPeySch17; LiLuWan20; CarCraWanWei21). These methods require spatial discretization and are thus not scalable to high-dimensional settings. SalKorLui20 analyze the convergence of a forward-backward scheme but leave the implementation of JKO as an open question. There is a line of research that uses particle-based methods to estimate the Wasserstein gradient flow (CarCraPat19; FroPog20). In these algorithms, the current density value is often estimated using a kernel method whose complexity scales at least quadratically with the number of particles. More recently, several interesting neural network based methods (MokKorLiBur21; AlvSchMro21; YanZhaCheWan20; bunne2021jkonet; bonet2021sliced) were proposed for Wasserstein gradient flow. MokKorLiBur21 focuses on the special case with the Kullback-Leibler divergence as the objective function. AlvSchMro21 uses a density estimation method that evaluates the objective function by back-propagating to the initial distribution, which can become a computational burden when the number of time steps is large. YanZhaCheWan20 is based on a forward Euler time discretization of the Wasserstein gradient flow and is more sensitive to the time step size. bunne2021jkonet utilizes the JKO scheme to approximate population dynamics given an observed trajectory, with applications in computational biology. bonet2021sliced replaces the Wasserstein distance in JKO by its sliced counterpart, but its connection to the original Wasserstein gradient flow remains unclear. Over the past few years, many neural network based algorithms have been proposed to compute optimal transport maps or Wasserstein barycenters (MakTagOhLee20; KorEgiAsaBur19; FanTagChe20; KorLiSolBur21). These can be viewed as special cases of Wasserstein gradient flows or of optimization over the space of probability densities.

2 Background

2.1 Optimal transport and Wasserstein distance

Given two probability distributions $P, Q$ over the Euclidean space $\mathbb{R}^n$ with finite second moments, the optimal transport problem with quadratic cost reads

$$\min_{T:\,T_\sharp P = Q}\; \int_{\mathbb{R}^n}\|x - T(x)\|_2^2\,dP(x), \qquad (1)$$

where the minimization is over all the feasible transport maps $T$ that transport mass from distribution $P$ to distribution $Q$. The feasibility is characterized by the pushforward operator (Bog07) as $T_\sharp P = Q$. When the initial distribution $P$ admits a density, the above optimal transport problem (1) has a unique solution $T^\star$, and it is the gradient of a convex function, that is,

$$T^\star = \nabla\varphi$$

for some convex function $\varphi$. In this paper, we assume probability measures admit densities and use the same notation for the measure and the density interchangeably.

The square root of the minimum transport cost, namely, the minimum of (1), defines a metric on the space of probability distributions known as the Wasserstein-2 distance (Vil03), denoted by $W_2(P,Q)$. The Wasserstein distance has many nice geometric properties compared with other distances between probability distributions (for instance, it remains finite and informative for distributions with disjoint supports), making it a popular choice in applications.
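As a small illustration of (1), in one dimension the optimal map for the quadratic cost is the monotone rearrangement, which is the derivative of a convex function, so the optimal coupling between two empirical measures with the same number of equally weighted atoms simply pairs sorted samples. The following Python/NumPy sketch covers only this 1D special case; the helper name `w2_empirical_1d` is hypothetical.

```python
import numpy as np


def w2_empirical_1d(x, y):
    """W2 distance between two 1D empirical measures with the same number of
    equally weighted atoms.

    In 1D the optimal map for the quadratic cost is nondecreasing (the
    derivative of a convex function), so pairing the i-th smallest atom of x
    with the i-th smallest atom of y is an optimal coupling.
    """
    xs = np.sort(np.asarray(x, dtype=float))
    ys = np.sort(np.asarray(y, dtype=float))
    if xs.shape != ys.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.sqrt(np.mean((xs - ys) ** 2)))


rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
print(w2_empirical_1d(x, rng.permutation(x) + 3.0))  # a pure shift by 3: about 3.0
```

For general dimensions and unequal sample sizes, (1) has no such shortcut, which is one reason the method below works with parametrized maps instead.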

2.2 Wasserstein gradient flow

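Given a function $\mathcal{F}$ over the space of probability densities, the Wasserstein gradient flow describes the dynamics of the probability density when it follows the steepest descent direction of the function $\mathcal{F}$ with respect to the Wasserstein metric $W_2$. The Wasserstein gradient flow can be explicitly represented by the PDE

$$\frac{\partial P_t}{\partial t} = \nabla\cdot\left(P_t\,\nabla\frac{\delta\mathcal{F}}{\delta P}(P_t)\right), \qquad (2)$$

where $\frac{\delta\mathcal{F}}{\delta P}$ stands for the gradient of the function $\mathcal{F}$ with respect to the standard $L^2$ metric (Vil03, Ch. 8).

Many important PDEs are the Wasserstein gradient flow for minimizing certain objective functions $\mathcal{F}$. For instance, when $\mathcal{F}$ is the free energy $\mathcal{F}(P) = \int V(x)P(x)\,dx + \int P(x)\log P(x)\,dx$, the gradient flow is the Fokker-Planck equation $\partial_t P = \nabla\cdot(P\nabla V) + \Delta P$ (JorKinOtt98). When $\mathcal{F}$ is the generalized entropy $\mathcal{F}(P) = \frac{1}{m-1}\int P^m(x)\,dx$ for some number $m > 1$, the gradient flow is the porous medium equation $\partial_t P = \Delta P^m$ (Ott01; Vaz07).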

3 Methods and algorithms

We are interested in solving the optimization problem

$$\min_{P\in\mathcal{P}_2(\mathbb{R}^n)}\; \mathcal{F}(P) \qquad (3)$$

over the space of probability densities $\mathcal{P}_2(\mathbb{R}^n)$. In particular, our objective is to develop a particle-based Wasserstein gradient flow algorithm to numerically solve (3).

The objective function $\mathcal{F}$ can take different forms depending on the application. In this paper, we present our algorithm for linear combinations of the following three important cases:

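Case I The functional is equal to the KL divergence with respect to a given target distribution $Q$:

$$\mathcal{F}(P) = D(P\|Q) := \int P(x)\log\frac{P(x)}{Q(x)}\,dx. \qquad (4)$$

This is important for the problem of sampling from a target distribution.

Case II The objective functional is equal to the generalized entropy

$$\mathcal{F}(P) = \mathcal{G}(P) := \frac{1}{m-1}\int P^m(x)\,dx, \qquad m > 1. \qquad (5)$$

This case is important for modeling the porous medium.

Case III The objective functional is equal to the interaction energy with a symmetric interaction kernel $W$:

$$\mathcal{F}(P) = \mathcal{W}(P) := \frac{1}{2}\iint W(x-y)\,dP(x)\,dP(y). \qquad (6)$$

This case is important for modeling the aggregation equation.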

These functionals have been widely studied in the Wasserstein gradient flow literature (CarDuvPeySch17; San17; AmbGigSav08) due to their desirable properties. It can be shown that if $\mathcal{F}$ is composed of the above functionals, then under proper assumptions, the Wasserstein gradient flow associated with $\mathcal{F}$ converges to the unique solution to (3) (San17).

In Sections 3.1 and 3.2, we first assume $\mathcal{F}$ does not include the interaction energy, and introduce the JKO (backward) scheme to solve (3). We then take $\mathcal{W}$ into consideration and present a forward-backward scheme in Section 3.3, and close by presenting our algorithm in Section 3.4.

3.1 JKO scheme and reparametrization

To realize the Wasserstein gradient flow, a discretization over time is needed. One such discretization is the famous JKO scheme (JorKinOtt98)

$$P_{k+1} = \arg\min_{P\in\mathcal{P}_2(\mathbb{R}^n)}\; \frac{1}{2\tau}W_2^2(P, P_k) + \mathcal{F}(P). \qquad (7)$$

This is essentially a backward Euler discretization, or a proximal point method, with respect to the Wasserstein metric. The solution to (7) converges to the continuous-time Wasserstein gradient flow as the step size $\tau \to 0$.

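Thanks to Brenier’s Theorem (brenier1991polar), (7) can be recast as an optimization in terms of the transport maps $T$ from $P_k$ to $P$, i.e., by defining $P = T_\sharp P_k$. The optimal $T$ is the optimal transport map from $P_k$ to $T_\sharp P_k$ and is thus the gradient of a convex function $\varphi$. Therefore, in view of the definition of the Wasserstein distance (1), bunne2021jkonet; MokKorLiBur21; AlvSchMro21 propose to reparametrize $T$ as the gradient of an input convex neural network (ICNN) (AmoXuKol17) and express (7) as

$$\varphi_{k+1} = \arg\min_{\varphi\in\mathrm{CVX}}\; \frac{1}{2\tau}\int_{\mathbb{R}^n}\|x - \nabla\varphi(x)\|_2^2\,dP_k(x) + \mathcal{F}\left((\nabla\varphi)_\sharp P_k\right), \qquad (8)$$

where CVX stands for the space of convex functions. In our method, we extend this idea and propose, alternatively, to reparametrize $T$ by a residual neural network. With this reparametrization, the JKO step (7) becomes

$$T_{k+1} = \arg\min_{T}\; \frac{1}{2\tau}\int_{\mathbb{R}^n}\|x - T(x)\|_2^2\,dP_k(x) + \mathcal{F}(T_\sharp P_k). \qquad (9)$$

We use the preceding two schemes (8) and (9) in our numerical method depending on the application.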
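To make the reparametrization (9) concrete, consider the simplest functional for which the JKO step can be solved exactly: a potential energy $\mathcal{F}(P) = \int V\,dP$ (used here only for illustration; it is not one of the three cases above). Then $\mathcal{F}(T_\sharp P_k) = \mathbb{E}_{X\sim P_k}[V(T(X))]$, so (9) decouples over particles and each particle takes an implicit (backward) Euler step. The sketch below uses hypothetical names and plain gradient descent as the per-particle solver, and checks the result against the exact step for a quadratic $V$.

```python
import numpy as np


def jko_step_potential(x, grad_V, tau, lr=0.1, n_iter=500):
    """One JKO step (9) on particles for a potential energy F(P) = E_P[V(X)].

    Since F(T#P_k) = E_{P_k}[V(T(X))], problem (9) decouples over particles:
    T(x) = argmin_y ||y - x||^2 / (2 tau) + V(y), the proximal map of tau * V.
    Its optimality condition T(x) = x - tau * grad_V(T(x)) is an implicit
    (backward) Euler step. Solved here by plain gradient descent, assuming V is
    smooth and convex and lr < 2 / (1 / tau + L), L = Lipschitz constant of grad_V.
    """
    x = np.asarray(x, dtype=float)
    y = x.copy()                                   # start from the identity map
    for _ in range(n_iter):
        y -= lr * ((y - x) / tau + grad_V(y))      # gradient of the per-particle objective
    return y


# Hypothetical quadratic potential V(y) = a * ||y||^2 / 2, whose exact step is x / (1 + tau * a).
a, tau = 2.0, 0.5
x0 = np.random.default_rng(0).normal(size=(200, 2))
x1 = jko_step_potential(x0, lambda y: a * y, tau)
print(np.abs(x1 - x0 / (1 + tau * a)).max())
```

For the entropy-type functionals of Cases I and II, $\mathcal{F}(T_\sharp P_k)$ depends on the density of $T_\sharp P_k$ and no longer decouples, which is what the variational formulation of Section 3.2 addresses.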

3.2 $D(P\|Q)$ and $\mathcal{G}(P)$ reformulation with variational formula

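The main challenge in implementing the JKO scheme is to evaluate the functional $\mathcal{F}(T_\sharp P_k)$ in terms of samples from $P_k$. We achieve this goal by using a variational formulation of $\mathcal{F}$. In order to do so, we use the notion of $f$-divergence between two distributions $P$ and $Q$:

$$D_f(P\|Q) = \int f\left(\frac{P(x)}{Q(x)}\right)Q(x)\,dx, \qquad (10)$$

where $f$ is a convex function. The $f$-divergence admits the variational formulation

$$D_f(P\|Q) = \sup_h\; \left\{\mathbb{E}_{X\sim P}[h(X)] - \mathbb{E}_{Y\sim Q}\left[f^*(h(Y))\right]\right\}, \qquad (11)$$

where $f^*(y) = \sup_x\{xy - f(x)\}$ is the convex conjugate of $f$. The variational form has the special feature that it does not involve the densities of $P$ and $Q$ explicitly and can be approximated in terms of samples from $P$ and $Q$. The functionals $D(P\|Q)$ and $\mathcal{G}(P)$ can both be expressed as $f$-divergences.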
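The following sketch evaluates the objective in (11) from samples only, which is the property the method relies on. It uses the KL case, $f(x) = x\log x$ with $f^*(y) = e^{y-1}$, on a toy pair of Gaussians chosen so that $D(P\|Q) = 0.5$ and the optimal $h = 1 + \log(P/Q)$ is known in closed form; the function and variable names are illustrative.

```python
import numpy as np


def f_divergence_objective(h, f_conj, x_p, y_q):
    """Sample estimate of E_P[h(X)] - E_Q[f*(h(Y))], the objective in (11).
    For any h its expectation is a lower bound on D_f(P||Q), tight at h = f'(P/Q)."""
    return float(np.mean(h(x_p)) - np.mean(f_conj(h(y_q))))


def kl_conj(y):
    """Convex conjugate of f(x) = x log x."""
    return np.exp(y - 1.0)


# Toy pair: P = N(1, 1), Q = N(0, 1), so D(P||Q) = 0.5 and log(P/Q)(x) = x - 0.5.
rng = np.random.default_rng(0)
x_p = rng.normal(1.0, 1.0, 200_000)
y_q = rng.normal(0.0, 1.0, 200_000)
print(f_divergence_objective(lambda x: 1.0 + (x - 0.5), kl_conj, x_p, y_q))  # optimal h
print(f_divergence_objective(lambda x: np.ones_like(x), kl_conj, x_p, y_q))  # constant h = 1
```

With the optimal $h$ the estimate is close to 0.5, while another choice such as the constant function 1 gives a smaller value (here exactly 0), as expected from a variational lower bound.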

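With the help of the $f$-divergence variational formula, when $\mathcal{F}(P) = D(P\|Q)$ or $\mathcal{F}(P) = \mathcal{G}(P)$, the JKO scheme (9) can be equivalently expressed (up to an additive constant in the objective) as

$$T_{k+1} = \arg\min_{T}\max_{h}\; \mathbb{E}_{X\sim P_k}\left[\frac{1}{2\tau}\|X - T(X)\|_2^2 + \mathcal{A}(T,h)\right] - \mathbb{E}_{Z\sim\Gamma}\left[\mathcal{B}(h)\right], \qquad (12)$$

where $X \sim P_k$, $Z \sim \Gamma$, $\Gamma$ is a user-designed distribution which is easy to sample from, and $\mathcal{A}$ and $\mathcal{B}$ are functionals whose form depends on the functional $\mathcal{F}$. The form of these two functionals for the KL divergence and the generalized entropy appears in Table 1. The details appear in Sections 3.2.1 and 3.2.2.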

3.2.1 KL divergence

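The KL divergence is the special instance of the $f$-divergence obtained by replacing $f$ with $f(x) = x\log x$ in (10).

Proposition 1.

The variational formulation for $D(P\|Q)$ reads

$$D(P\|Q) = 1 + \sup_{h>0}\; \left\{\mathbb{E}_{X\sim P}\left[\log\frac{\Gamma(X)h(X)}{Q(X)}\right] - \mathbb{E}_{Z\sim\Gamma}[h(Z)]\right\}, \qquad (13)$$

where $\Gamma$ is a user-designed distribution which is easy to sample from. The optimal function $h^*$ is equal to the ratio between the densities of $P$ and $\Gamma$.

The proof of Proposition 1 can be found in Appendix A. This formulation is practical when we only have access to the unnormalized density of $Q$, which is the case for the sampling problem: replacing $Q$ by an unnormalized version only shifts (13) by a constant. Using this variational form in the JKO scheme (9) yields

$$\mathcal{A}(T,h) = \log\frac{\Gamma(T(X))\,h(T(X))}{Q(T(X))}, \qquad \mathcal{B}(h) = h(Z).$$

In practice, we choose $\Gamma$ adaptively as the Gaussian with the same mean and covariance as the current samples. We noticed that this choice improves the numerical stability of the algorithm.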
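A quick numerical check of Proposition 1 under toy assumptions of our own: $P = \mathcal{N}(1,1)$, an unnormalized target $Q(x) \propto e^{-x^2/2}$, and $\Gamma = \mathcal{N}(0, 2^2)$. With the optimal $h = P/\Gamma$, the estimate of (13) should be close to $D(P\|Q) - \log Z_Q$, where $Z_Q = \sqrt{2\pi}$ is the normalizing constant that the unnormalized target omits; this constant shift does not affect the minimization over $T$. The helper names are hypothetical.

```python
import numpy as np


def log_normal_pdf(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2.0 * np.pi))


def kl_prop1_objective(h, log_q_unnorm, log_gamma, x_p, z_gamma):
    """Sample estimate of 1 + E_P[log(Gamma h / Q)] - E_Gamma[h], the objective in (13).
    If log_q_unnorm = log Q + log Z_Q, the value shifts by -log Z_Q, a constant."""
    log_ratio = log_gamma(x_p) + np.log(h(x_p)) - log_q_unnorm(x_p)
    return float(1.0 + np.mean(log_ratio) - np.mean(h(z_gamma)))


# Toy setup: P = N(1, 1), unnormalized Q(x) = exp(-x^2 / 2) (Z_Q = sqrt(2 pi)), Gamma = N(0, 2^2).
rng = np.random.default_rng(0)
x_p = rng.normal(1.0, 1.0, 200_000)
z_g = rng.normal(0.0, 2.0, 200_000)


def log_gamma(x):
    return log_normal_pdf(x, 0.0, 2.0)


def h_opt(x):  # optimal h = P / Gamma
    return np.exp(log_normal_pdf(x, 1.0, 1.0) - log_gamma(x))


val = kl_prop1_objective(h_opt, lambda x: -0.5 * x**2, log_gamma, x_p, z_g)
print(val, 0.5 - 0.5 * np.log(2.0 * np.pi))  # estimate vs. D(P||Q) - log Z_Q
```

Choosing $\Gamma$ wider than $P$ keeps the variance of the $\mathbb{E}_\Gamma[h]$ term finite in this toy; a $\Gamma$ much narrower than $P$ would make $P/\Gamma$ heavy-tailed under $\Gamma$.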

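Energy function | $\mathcal{A}(T,h)$ | $\mathcal{B}(h)$ | $\Gamma$
$D(P\|Q)$ | $\log\frac{\Gamma(T(X))\,h(T(X))}{Q(T(X))}$ | $h(Z)$ | Gaussian distribution
$\mathcal{G}(P)$ | $\frac{m}{(m-1)|\Omega|^{m-1}}h^{m-1}(T(X))$ | $\frac{1}{|\Omega|^{m-1}}h^m(Z)$ | Uniform distribution
Table 1: Variational formula for $\mathcal{A}$ and $\mathcal{B}$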

3.2.2 Generalized entropy

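The generalized entropy can also be represented as an $f$-divergence. In particular, let $f(x) = \frac{1}{m-1}(x^m - x)$ and let $Q$ be the uniform distribution on a set $\Omega$ which is a superset of the support of the density $P$ and has volume $|\Omega|$. Then

$$\mathcal{G}(P) = \frac{1}{|\Omega|^{m-1}}\left(D_f(P\|Q) + \frac{1}{m-1}\right).$$

Proposition 2.

The variational formulation for $\mathcal{G}(P)$ reads

$$\mathcal{G}(P) = \frac{1}{|\Omega|^{m-1}}\sup_{h\ge 0}\; \left\{\frac{m}{m-1}\mathbb{E}_{X\sim P}\left[h^{m-1}(X)\right] - \mathbb{E}_{Z\sim Q}\left[h^m(Z)\right]\right\}. \qquad (14)$$

The optimal function $h^*$ is equal to the ratio between the densities of $P$ and $Q$.

The proof of Proposition 2 is postponed to Appendix A. Using this in the JKO scheme yields $\Gamma = Q$,

$$\mathcal{A}(T,h) = \frac{m}{(m-1)|\Omega|^{m-1}}h^{m-1}(T(X)), \qquad \mathcal{B}(h) = \frac{1}{|\Omega|^{m-1}}h^m(Z),$$

where $|\Omega|$ is the volume of a set large enough to contain the support of $T_\sharp P_k$ for any $T$ that is not too far away from the identity map.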
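The sketch below checks Proposition 2 numerically under assumptions chosen for convenience: $P$ is the Beta(2, 2) distribution on $[0,1]$ (compactly supported, so $\Omega = [-1, 2]$ with $|\Omega| = 3$ contains its support), and $h$ is set to the optimal ratio $P/Q$. For this $P$, $\int P^2 = 1.2$ and $\int P^3 = 108/70$, so $\mathcal{G}(P) = 1.2$ for $m = 2$ and $108/140$ for $m = 3$. Names are hypothetical.

```python
import numpy as np


def gen_entropy_objective(h, m, x_p, z_q, vol):
    """Sample estimate of the objective in (14):
    (m / (m - 1) * E_P[h^(m-1)] - E_Q[h^m]) / |Omega|^(m-1), with Q uniform on Omega."""
    return float((m / (m - 1) * np.mean(h(x_p) ** (m - 1)) - np.mean(h(z_q) ** m)) / vol ** (m - 1))


def beta22_pdf(x):
    """Density of Beta(2, 2): 6 x (1 - x) on [0, 1] and 0 elsewhere."""
    return np.where((x >= 0.0) & (x <= 1.0), 6.0 * x * (1.0 - x), 0.0)


rng = np.random.default_rng(0)
vol = 3.0                                   # Omega = [-1, 2] contains the support [0, 1]
x_p = rng.beta(2.0, 2.0, 400_000)           # samples from P
z_q = rng.uniform(-1.0, 2.0, 400_000)       # samples from Q = Uniform(Omega)


def h_opt(x):  # optimal h = P / Q, with Q = 1 / |Omega| on Omega
    return vol * beta22_pdf(x)


for m, exact in [(2, 1.2), (3, 108.0 / 140.0)]:
    print(m, gen_entropy_objective(h_opt, m, x_p, z_q, vol), exact)
```

Note that only samples from $P$ and from the uniform $Q$ are used; the density of $P$ enters only through the (here known) optimal $h$, which the algorithm instead learns.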

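Input: Objective function $\mathcal{F}$, initial distribution $P_0$, step size $\tau$, number of JKO steps, number of outer-loop iterations, number of inner-loop iterations, batch size.
Initialization: Parameterized $T_\theta$ and $h_\lambda$
for each JKO step $k = 0, 1, 2, \ldots$ do
    if $\mathcal{F}$ includes $\mathcal{W}$: apply the forward step (15) to $P_k$
    if $k > 0$: initialize $\theta$ and $\lambda$ with their values from step $k-1$ {// use last iteration as a warm-up}
   for each outer-loop iteration do
      Sample a batch $\{X_i\}$ from $P_k$. Sample a batch $\{Z_i\}$ from $\Gamma$.
      for each inner-loop iteration do
         Apply Adam to $\lambda$ to maximize (12).
      end for
      Apply Adam to $\theta$ to minimize (12).
   end for
   $P_{k+1} = (T_\theta)_\sharp P_k$
end for
Algorithm 1 Primal-dual gradient flow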
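To make the alternating structure of Algorithm 1 concrete, the following NumPy sketch performs one JKO step of (12) on a 1D toy problem small enough to differentiate by hand. Its assumptions are ours, not the paper's setup: the target is $Q = \mathcal{N}(0,1)$, $T$ is affine, $\log h$ is quadratic, $\Gamma$ is re-fitted at every update as the Gaussian moment-matched to the current pushforward samples (one reading of the adaptive choice in Section 3.2.1), and plain gradient steps replace Adam. With this $\Gamma$ the optimal $h$ is close to 1, so the sketch exercises the loop structure rather than a hard inner problem.

```python
import numpy as np


def jko_step_primal_dual(x, tau, n_outer=300, n_inner=3, lr_T=0.2, lr_h=0.1, seed=0):
    """One JKO step for F(P) = KL(P || N(0, 1)) in 1D via the min-max problem (12).

    Toy assumptions (not the paper's networks): T(x) = a * (x - mean(x)) + b is
    affine, log h(y) = c0 + c1 * u + c2 * u^2 with u = (y - m_g) / s_g, Gamma is
    the Gaussian moment-matched to the current pushforward samples (held fixed
    inside each update), and plain gradient steps replace Adam.
    """
    rng = np.random.default_rng(seed)
    xc = x - x.mean()                 # centered particles from P_k
    a, b = 1.0, x.mean()              # start T at the identity map
    c = np.zeros(3)                   # parameters of log h
    for _ in range(n_outer):
        y = a * xc + b                # pushforward samples T(X)
        m_g, s_g = y.mean(), y.std()  # Gamma = N(m_g, s_g^2)
        u = (y - m_g) / s_g
        for _ in range(n_inner):      # inner loop: gradient ascent on h
            eps = rng.standard_normal(x.size)                 # Z = m_g + s_g * eps ~ Gamma
            e = np.exp(c[0] + c[1] * eps + c[2] * eps**2)     # h(Z)
            grad_c = np.array([1.0 - e.mean(),
                               u.mean() - (eps * e).mean(),
                               (u**2).mean() - (eps**2 * e).mean()])
            c += lr_h * grad_c
        # d/dy of log Gamma(y) + log h(y) - log Q(y), with Q = N(0, 1)
        g = -(y - m_g) / s_g**2 + (c[1] + 2.0 * c[2] * u) / s_g + y
        grad_a = np.mean(-(x - y) * xc) / tau + np.mean(g * xc)
        grad_b = np.mean(-(x - y)) / tau + np.mean(g)
        a -= lr_T * grad_a            # outer loop: gradient descent on T
        b -= lr_T * grad_b
    return a * xc + b                 # samples of P_{k+1} = T#P_k


rng = np.random.default_rng(1)
x = rng.normal(2.0, 0.5, 5000)
tau = 0.5
y = jko_step_primal_dual(x, tau)
s = x.std()
s_new = (s + np.sqrt(s**2 + 4 * tau * (1 + tau))) / (2 * (1 + tau))
print(y.mean(), x.mean() / (1 + tau))  # learned vs. exact JKO mean
print(y.std(), s_new)                  # learned vs. exact JKO standard deviation
```

For Gaussian $P_k = \mathcal{N}(m, s^2)$ and this target, the exact JKO step is $\mathcal{N}(m', s'^2)$ with $m' = m/(1+\tau)$ and $(1+\tau)s'^2 - s\,s' - \tau = 0$, which is what the printed pairs compare against. Treating $\Gamma$ as fixed inside each update is harmless here because, at the optimal $h$, the value of (13) does not depend on $\Gamma$.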

3.3 Forward Backward (FB) scheme

When $\mathcal{F}$ involves the interaction energy $\mathcal{W}$, we add an additional forward step to solve the gradient flow:

$$P_{k+\frac{1}{2}} = \left(I - \tau\nabla W * P_k\right)_\sharp P_k, \qquad (15)$$

$$P_{k+1} = (T_{k+1})_\sharp P_{k+\frac{1}{2}}, \qquad (16)$$

where $I$ is the identity map, and $T_{k+1}$ is defined by replacing $P_k$ by $P_{k+\frac{1}{2}}$ in (12). In other words, the first gradient descent step (15) is a forward discretization of the gradient flow and the second JKO step (16) is a backward discretization. The term $\nabla W * P_k(x)$ can be written as the expectation $\mathbb{E}_{Y\sim P_k}[\nabla W(x - Y)]$, and thus can also be approximated by samples. SalKorLui20 first proposed this method to solve Wasserstein gradient flows and provided a theoretical convergence analysis. We make this scheme practical by giving a scalable implementation of JKO.
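The forward step (15) only needs samples from $P_k$. A minimal NumPy sketch, assuming for illustration the kernel $W(z) = \|z\|^4/4 - \|z\|^2/2$ (a hypothetical choice, not necessarily a kernel used in the experiments) and using the particles themselves as the samples from $P_k$:

```python
import numpy as np


def forward_step(x, y, grad_W, tau):
    """Forward (explicit Euler) step (15) on particles x of shape (N, d):
    x - tau * (grad W * P_k)(x), where (grad W * P_k)(x) = E_{Y ~ P_k}[grad W(x - Y)]
    is replaced by the average over samples y ~ P_k of shape (M, d)."""
    diffs = x[:, None, :] - y[None, :, :]          # pairwise x_i - y_j, shape (N, M, d)
    return x - tau * grad_W(diffs).mean(axis=1)


def grad_W_example(z):
    """Gradient of the hypothetical kernel W(z) = ||z||^4 / 4 - ||z||^2 / 2."""
    return (np.sum(z**2, axis=-1, keepdims=True) - 1.0) * z


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2))
x_half = forward_step(x, x, grad_W_example, tau=0.05)
print(np.abs(x_half.mean(axis=0) - x.mean(axis=0)).max())  # tiny: grad W is odd
```

Because $\nabla W$ is odd for a symmetric kernel, the pairwise contributions cancel and the particle mean is preserved, matching the fact that the continuous aggregation flow conserves the mean. The memory and time of the pairwise differences grow with $N \cdot M \cdot d$, so averaging over a smaller batch $y$ trades cost for a noisier estimate.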

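Since $\mathcal{W}(P)$ can be equivalently written as the expectation $\frac{1}{2}\mathbb{E}_{X,Y\sim P}[W(X - Y)]$ with $X, Y$ independent, there exists another, non-forward-backward (non-FB) method, which removes the first step and integrates $\mathcal{W}$ into a single JKO step:

$$T_{k+1} = \arg\min_{T}\max_{h}\; \mathbb{E}_{X\sim P_k}\left[\frac{1}{2\tau}\|X - T(X)\|_2^2 + \mathcal{A}(T,h)\right] - \mathbb{E}_{Z\sim\Gamma}\left[\mathcal{B}(h)\right] + \frac{1}{2}\mathbb{E}_{X,Y\sim P_k}\left[W(T(X) - T(Y))\right]$$

and $P_{k+1} = (T_{k+1})_\sharp P_k$.

In practice, we observe that the FB scheme is more stable and gives more regular results, but converges more slowly than the non-FB scheme. The detailed discussion appears in Appendices C.2 and C.3.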

Remark 1.

In principle, one can single out the potential term $-\mathbb{E}_{X\sim P}[\log Q(X)]$ from (13) and perform a similar forward step (SalKorLui20), but we do not observe improved performance from doing this in the sampling task.

3.4 Primal-dual algorithm and parametrization of $T$ and $h$

The two optimization variables $T$ and $h$ in our minimax formulation (12) can both be parameterized by neural networks, denoted by $T_\theta$ and $h_\lambda$. With this neural network parametrization, we can then solve the problem by iteratively updating $\theta$ and $\lambda$. This primal-dual method to solve (3) is depicted in Algorithm 1.

In this work, we implemented two different architectures for the map $T$. One way is to use a residual neural network to represent $T$ directly; another way is to parametrize $T$ as the gradient of an ICNN $\varphi$. The latter has been widely used in optimal transport (MakTagOhLee20; FanTagChe20; KorLiSolBur21). Note that an ICNN can be modified to be strictly convex, and if $\varphi$ is strictly convex, the gradient $\nabla\varphi$ is invertible. However, several recent works (rout2021generative; KorEgiAsaBur19) find poor expressiveness of the ICNN architecture and propose to replace the gradient of an ICNN by a neural network. In our experiments, we find that the first parameterization gives more regular results, which aligns with the results in bonet2021sliced. However, with the first parameterization it is very difficult to calculate the density of the pushforward distribution, so our method becomes a particle-based method, i.e., we cannot query the density directly. As we discuss in Section B, when density evaluation is needed, we adopt the ICNN parameterization, since we need to invert $T$ and compute the determinant of its Jacobian.

3.5 Computational complexity

For each update in Algorithm 1, the cost of the forward step (15) is governed by the total number of particles to push forward. The cost of the backward step (16) is governed by the number of iterations per JKO step, the batch size, the size of the network, and the index $k$ of the current JKO step; the latter appears because sampling from $P_k$ requires us to push samples from $P_0$ forward through $k$ maps.

However, the cost of MokKorLiBur21 is cubic in the dimension $n$ because it needs to evaluate $\log\det\nabla^2\varphi$ in each iteration. We refer to MokKorLiBur21 for the complexity details of calculating the Hessian term. There exists a fast approximation (huang2020convex) of $\log\det\nabla^2\varphi$ and its gradient using Stochastic Lanczos Quadrature (ubaru2017fast) and the Hutchinson trace estimator (hutchinson1989stochastic). AlvSchMro21 applies this approximation, so the cubic dependence on $n$ can be improved to a quadratic dependence. Nonetheless, this comes with an additional cost, namely the number of iterations of the conjugate gradient (CG) method. CG is guaranteed to converge exactly in $n$ steps in this setting. If one wants to obtain the gradient precisely, the cost is still cubic in $n$, which is the same as calculating $\log\det\nabla^2\varphi$ directly. If one uses an error-based stopping condition in CG, the number of CG iterations can be reduced to scale with $\sqrt{\kappa}$, where $\kappa$ is the condition number of $\nabla^2\varphi$, but this sacrifices accuracy. Thus our method has the advantage that its complexity does not depend on the dimension in this way.
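For reference, the Hutchinson trace estimator mentioned above needs only matrix-vector products, which is what lets the fast log-determinant approximations avoid forming and factorizing the Hessian. The sketch below is a generic standalone illustration with Rademacher probes on an explicit test matrix; it is not the implementation of huang2020convex or AlvSchMro21, and the Lanczos and CG parts are not shown.

```python
import numpy as np


def hutchinson_trace(matvec, dim, n_probes, rng):
    """Hutchinson estimator tr(A) ~ mean_i z_i^T A z_i with Rademacher probes z_i.
    Only products v -> A v are needed (for example Hessian-vector products)."""
    z = rng.choice([-1.0, 1.0], size=(n_probes, dim))
    return float(np.mean([zi @ matvec(zi) for zi in z]))


# Hypothetical test matrix: symmetric positive definite, formed explicitly only for the check.
rng = np.random.default_rng(0)
n = 50
B = rng.normal(size=(n, n))
A = B @ B.T / n + np.eye(n)
est = hutchinson_trace(lambda v: A @ v, n, 2000, rng)
print(est, np.trace(A))
```

For a diagonal matrix the estimate is exact with a single Rademacher probe, since $z_i^2 = 1$; in general its variance decreases in proportion to one over the number of probes, and each probe costs one matrix-vector product.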

We provide training time details in Section 4.2 and Appendix C.4. Other than training and sampling time, the complexity of evaluating the density is the same as for the above two methods because of the standard density evaluation process (see Section B).

Figure 1: (a) ground truth; (b) ours. The left figure shows samples from the target 16-GMM distribution and the right figure shows samples obtained by our method. Each plot contains 4000 points.

4 Numerical examples

4.1 Sampling

We first consider the problem of sampling from a target distribution $Q$. Note that $Q$ does not have to be normalized. To this end, we consider the Wasserstein gradient flow with objective function $\mathcal{F}(P) = D(P\|Q)$, that is, the KL divergence between distributions $P$ and $Q$. When this objective is minimized, $P = Q$. In our experiments, we consider two types of target distribution: the two moons distribution and the Gaussian mixture model (GMM) with spherical Gaussian components. In this set of experiments, the step size $\tau$ is fixed and the initial measure is a spherical Gaussian.

Two moons: The two moons distribution is a popular target distribution for sampling tasks. It is a 2D mixture model composed of 16 Gaussian components; each moon shape consists of 8 Gaussian components. The results are displayed in Figure 1, which shows that our method is able to generate samples that match the target distribution.

(a) Dimension 8
(b) Dimension 13
Figure 2: Comparison between the target GMM and the fitted measure of samples generated by our method. Samples are projected onto a 2D plane by performing PCA. We refer the reader to MokKorLiBur21 for the performance of another algorithm in a similar setup.

GMM with spherical Gaussians: We also test our algorithm in sampling from a GMM in higher-dimensional space. The target GMM has 9 Gaussian components with equal weights and the same covariance. The results for dimensions $n = 8$ and $n = 13$ are depicted in Figure 2. In Figure 2, we display not only the samples as grey dots, but also the kernel density estimate of the generated samples as level sets. As can be seen from the results, both the samples and the densities obtained with our algorithm match the target distribution well.

4.2 Ornstein-Uhlenbeck Process

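We study the performance of our method in modeling the Ornstein-Uhlenbeck process as the dimension grows. The gradient flow is associated with the free energy (4), where $Q(x) \propto \exp\left(-\frac{1}{2}(x-b)^\top A(x-b)\right)$ with $A$ a positive definite matrix and $b \in \mathbb{R}^n$. Given an initial Gaussian distribution $P_0 = \mathcal{N}(\mu_0, \Sigma_0)$, the gradient flow at each time $t$ is a Gaussian distribution $P_t = \mathcal{N}(\mu_t, \Sigma_t)$ with mean vector

$$\mu_t = b + e^{-At}(\mu_0 - b)$$

and covariance

$$\Sigma_t = A^{-1} + e^{-At}\left(\Sigma_0 - A^{-1}\right)e^{-At}$$

(vatiwutipong2019alternative). We calculate $P_t$ with JKO step size $\tau$ and compare with the Fokker-Planck (FP) JKO (MokKorLiBur21). We quantify the error as the SymKL divergence between the estimated distribution $\hat P_t$ and the ground truth $P_t$ in Figure 3, where $\mathrm{SymKL}(P_1, P_2) = D(P_1\|P_2) + D(P_2\|P_1)$.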

Figure 3: SymKL error at two different times, panels (a) and (b). We repeat the experiments 15 times in each dimension.

We believe there are several reasons for our better performance. 1) The proposed distribution $\Gamma$ is Gaussian, which is consistent with $P_t$ for any $t$. This is beneficial for the inner maximization to find a precise $h$. 2) The map $T$ is linear, which guarantees that our generated samples always follow a Gaussian distribution. Note that if one approximates the map $T$ as the gradient of an ICNN, it is not guaranteed that the mapped distribution is still Gaussian. We also compare the training time per every two JKO steps with FP JKO. The computation time for FP JKO grows from around 20 to around 100 as the dimension increases, whereas our method's training time remains roughly the same for all the dimensions tested. This is because we fix the neural network size for both methods and our method's computational complexity does not depend on the dimension.
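The ground truth used in this comparison can be computed in a few lines. The sketch below assumes $Q(x) \propto \exp(-\frac{1}{2}(x-b)^\top A(x-b))$ with symmetric positive definite $A$, so that the flow corresponds to the Ornstein-Uhlenbeck SDE $dX_t = -A(X_t - b)\,dt + \sqrt{2}\,dB_t$, whose Gaussian marginals have mean $b + e^{-At}(\mu_0 - b)$ and covariance $A^{-1} + e^{-At}(\Sigma_0 - A^{-1})e^{-At}$. It also computes the symmetric KL divergence $D(P_1\|P_2) + D(P_2\|P_1)$ between Gaussians. The names and the 4D test instance are hypothetical.

```python
import numpy as np


def ou_marginal(mu0, S0, A, b, t):
    """Mean and covariance at time t of dX = -A (X - b) dt + sqrt(2) dB with
    X_0 ~ N(mu0, S0). A is assumed symmetric positive definite."""
    w, V = np.linalg.eigh(A)
    E = V @ np.diag(np.exp(-w * t)) @ V.T       # matrix exponential e^{-A t}
    A_inv = V @ np.diag(1.0 / w) @ V.T
    mu_t = b + E @ (mu0 - b)
    S_t = A_inv + E @ (S0 - A_inv) @ E.T
    return mu_t, S_t


def kl_gauss(m1, S1, m2, S2):
    """KL(N(m1, S1) || N(m2, S2)) in closed form."""
    S2_inv = np.linalg.inv(S2)
    dm = m2 - m1
    logdet1 = np.linalg.slogdet(S1)[1]
    logdet2 = np.linalg.slogdet(S2)[1]
    return 0.5 * (np.trace(S2_inv @ S1) + dm @ S2_inv @ dm - m1.size + logdet2 - logdet1)


def sym_kl(m1, S1, m2, S2):
    """Symmetric KL divergence D(P1||P2) + D(P2||P1) between two Gaussians."""
    return kl_gauss(m1, S1, m2, S2) + kl_gauss(m2, S2, m1, S1)


# Hypothetical 4D instance: random SPD matrix A, random b, P_0 = N(0, 0.5 I).
rng = np.random.default_rng(0)
n = 4
B = rng.normal(size=(n, n))
A = B @ B.T + n * np.eye(n)
b = rng.normal(size=n)
mu0, S0 = np.zeros(n), 0.5 * np.eye(n)
mu1, S1 = ou_marginal(mu0, S0, A, b, t=0.5)
print(mu1, np.diag(S1))
```

Using the eigendecomposition of the symmetric $A$ gives $e^{-At}$ and $A^{-1}$ without an extra dependency; for a non-symmetric drift matrix a general matrix exponential would be needed instead.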

4.3 Porous media equation

Figure 4: Comparison among exact density, finite difference method solution given by CVXOPT, and the density given by our method. To better visualize the distributed particles from each distribution, we also plot the histograms of our method as the blue shadow.

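We next consider the porous media equation with only diffusion: $\partial_t P = \Delta P^m$. This is the Wasserstein gradient flow associated with the energy function $\mathcal{F}(P) = \mathcal{G}(P) = \frac{1}{m-1}\int P^m(x)\,dx$. A representative closed-form solution of the porous media equation is the Barenblatt profile (Gi52; Vaz07)

$$P_t(x) = t^{-\alpha}\left(C - k\|x\|^2 t^{-2\beta}\right)_+^{\frac{1}{m-1}}, \qquad \alpha = \frac{n}{n(m-1)+2},\quad \beta = \frac{\alpha}{n},\quad k = \frac{\alpha(m-1)}{2mn},$$

where $t \ge t_0$, $t_0 > 0$ is the starting time, and $C > 0$ is a free parameter.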
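For reference, a sketch of the Barenblatt profile in its standard form for $\partial_t P = \Delta P^m$ in $\mathbb{R}^n$ (Vaz07), $P_t(x) = t^{-\alpha}(C - k\|x\|^2 t^{-2\beta})_+^{1/(m-1)}$ with $\alpha = n/(n(m-1)+2)$, $\beta = \alpha/n$, $k = \alpha(m-1)/(2mn)$. The parameter values in the check are arbitrary; the check relies on the profile conserving mass, which for $m = 2$, $n = 1$, $C = 1$ equals $\frac{4}{3}\sqrt{12}$ at every time.

```python
import numpy as np


def barenblatt(x, t, m, C):
    """Barenblatt profile for dP/dt = Laplacian(P^m) at points x of shape (N, n), time t > 0."""
    n = x.shape[-1]
    alpha = n / (n * (m - 1) + 2)
    beta = alpha / n
    k = alpha * (m - 1) / (2 * m * n)
    r2 = np.sum(x**2, axis=-1)
    base = np.maximum(C - k * r2 * t ** (-2 * beta), 0.0)   # positive part (.)_+
    return t ** (-alpha) * base ** (1.0 / (m - 1))


# 1D check that the total mass is conserved over time (here m = 2, C = 1).
grid = np.linspace(-8.0, 8.0, 20001)[:, None]
dx = grid[1, 0] - grid[0, 0]
for t in (1.0, 3.0):
    print(t, barenblatt(grid, t, m=2, C=1.0).sum() * dx)
```

The support of the profile has radius $\sqrt{C/k}\,t^{\beta}$, so the grid must be wide enough for the latest time evaluated; at $t = 3$ this radius is about 5 here.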

In principle, our algorithm should match the analytical solution when the step size $\tau$ is sufficiently small. When $\tau$ is not that small, time-discretization error is inevitable. To account for the time-discretization error of the JKO scheme, we consider the porous media equation in 1D and use the solution computed by a finite difference method as a reference. The details appear in Appendix D.3.

In the experiments, we fix the step size $\tau$ of the JKO scheme, the initial time $t_0$, and the remaining parameters of the profile. We parametrize the transport map as the gradient of an ICNN, so we can evaluate the density following Section B. In Figure 4, we observe that the gap between the density computed using our algorithm and the ground-truth density is dominated by the time-discretization error of the JKO scheme. Our result matches the discrete-time solution nicely.

4.4 Aggregation–Diffusion Equation

Figure 5: Histogram for simulated measures by FB scheme at different .

We finally simulate the evolution of solutions to the following aggregation-diffusion equation:

$$\partial_t P = \nabla\cdot\left(P\,\nabla W * P\right) + \nu\,\Delta P^m.$$

This corresponds to the energy function $\mathcal{F}(P) = \mathcal{W}(P) + \nu\,\mathcal{G}(P)$. We use the same kernel $W$ and parameters as in CarCraWanWei21. The initial distribution is a uniform distribution on a bounded domain, and the JKO step size $\tau$ is fixed. In Figure 5, we use the FB scheme to simulate the gradient flow for this equation. With this choice of $W$, the term $\nabla W * P_k$ in the gradient descent step (15) is an expectation over $Y \sim P_k$, and we estimate it with samples from $P_k$.

Throughout the process, the aggregation term and the diffusion term compete with each other, causing the probability measure to split into four pulses and eventually converge to a single pulse (carrillo2019nonlinear).

5 Conclusion

In this paper we presented a novel neural network based algorithm to compute the Wasserstein gradient flow. Our algorithm follows the JKO time discretization scheme. We reparametrize the problem so that the optimization variable becomes the transport map between consecutive steps. By utilizing a variational formula of the objective function, we further reformulate the problem in every step as a min-max problem over the map and the dual function, respectively. This formulation does not require density estimation from samples and can be optimized using stochastic optimization. Its computational complexity is also independent of the dimension. Our method can be extended to minimize other objective functions that can be written as $f$-divergences. A limitation is that the accuracy is not satisfactory in sampling tasks with high-dimensional, complex densities.


Appendix A Details about variational formula in Section 3.2

a.1 KL divergence

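The KL divergence is the special instance of the $f$-divergence obtained by replacing $f$ with $f(x) = x\log x$ in (10):

$$D(P\|Q) = \int P(x)\log\frac{P(x)}{Q(x)}\,dx,$$

which, according to (11), admits the variational formulation

$$D(P\|Q) = \sup_h\; \left\{\mathbb{E}_{X\sim P}[h(X)] - \mathbb{E}_{Y\sim Q}\left[e^{h(Y)-1}\right]\right\}, \qquad (17)$$

where the convex conjugate $f^*(y) = e^{y-1}$ is used.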

The variational formulation can be approximated in terms of samples from $P$ and $Q$. For the case where we only have access to the unnormalized density of $Q$, which is the case for the sampling problem, we use the change of variable $h(x) \to 1 + \log\frac{\Gamma(x)h(x)}{Q(x)}$, where $\Gamma$ is a user-designed distribution which is easy to sample from. Under this change of variable, the variational formulation reads

$$D(P\|Q) = 1 + \sup_{h>0}\; \left\{\mathbb{E}_{X\sim P}\left[\log\frac{\Gamma(X)h(X)}{Q(X)}\right] - \mathbb{E}_{Z\sim\Gamma}[h(Z)]\right\}.$$

Note that the optimal function $h^*$ is equal to the ratio between the densities of $P$ and $\Gamma$.

Remark 2.

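The Donsker-Varadhan formula

$$D(P\|Q) = \sup_h\; \left\{\mathbb{E}_{X\sim P}[h(X)] - \log\mathbb{E}_{Y\sim Q}\left[e^{h(Y)}\right]\right\}$$

is another variational representation of the KL divergence, and it is stronger than (17) because, for any fixed $h$, its objective is an upper bound of the objective in (17) (this follows from $\log y \le y/e$ for $y > 0$). However, we cannot get an unbiased estimate of its objective using samples.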
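The two objectives can be compared numerically. The sketch below evaluates both from the same samples for a few linear test functions $h$ (hypothetical choices) with $P = \mathcal{N}(1,1)$ and $Q = \mathcal{N}(0,1)$. Since $\log y \le y/e$ for all $y > 0$, the Donsker-Varadhan value is never below the value of (17) for the same $h$, even on samples; the price is that the logarithm of a sample mean is a biased estimate of $\log\mathbb{E}_Q[e^{h}]$.

```python
import numpy as np


def kl_fdiv_objective(h_p, h_q):
    """Objective of (17) from sample values h(X_i), X_i ~ P, and h(Y_j), Y_j ~ Q."""
    return float(np.mean(h_p) - np.mean(np.exp(h_q - 1.0)))


def donsker_varadhan_objective(h_p, h_q):
    """Donsker-Varadhan objective E_P[h] - log E_Q[exp(h)] from the same samples.
    The log of a sample mean is a biased estimate of log E_Q[exp(h)]."""
    return float(np.mean(h_p) - np.log(np.mean(np.exp(h_q))))


rng = np.random.default_rng(0)
x_p = rng.normal(1.0, 1.0, 10_000)   # P = N(1, 1)
y_q = rng.normal(0.0, 1.0, 10_000)   # Q = N(0, 1); KL(P || Q) = 0.5
for slope, shift in [(1.0, 0.5), (1.0, -2.0), (0.3, 1.0)]:
    h_p, h_q = slope * x_p + shift, slope * y_q + shift   # h(x) = slope * x + shift
    print(slope, shift, kl_fdiv_objective(h_p, h_q), donsker_varadhan_objective(h_p, h_q))
```

The first choice, $h(x) = x + 0.5$, is optimal for both objectives in this example, so both values are close to $D(P\|Q) = 0.5$. The Donsker-Varadhan objective is also unchanged when a constant is added to $h$, whereas (17) is not.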

a.2 Generalized entropy

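The generalized entropy can also be represented as an $f$-divergence. In particular, let $f(x) = \frac{1}{m-1}(x^m - x)$ and let $Q$ be the uniform distribution on a set $\Omega$ which is a superset of the support of the density $P$ and has volume $|\Omega|$. Then

$$D_f(P\|Q) = \int_\Omega f\left(|\Omega|P(x)\right)\frac{1}{|\Omega|}\,dx = \frac{|\Omega|^{m-1}}{m-1}\int P^m(x)\,dx - \frac{1}{m-1}.$$

As a result, the generalized entropy can be expressed in terms of the $f$-divergence according to

$$\mathcal{G}(P) = \frac{1}{|\Omega|^{m-1}}\left(D_f(P\|Q) + \frac{1}{m-1}\right).$$

Upon using the variational representation (11) of the $f$-divergence with

$$f^*(y) = \left(\frac{(m-1)y + 1}{m}\right)^{\frac{m}{m-1}}, \qquad (m-1)y + 1 \ge 0,$$

the generalized entropy admits the following variational formulation

$$\mathcal{G}(P) = \frac{1}{|\Omega|^{m-1}}\left(\sup_h\; \left\{\mathbb{E}_{X\sim P}[h(X)] - \mathbb{E}_{Z\sim Q}\left[\left(\frac{(m-1)h(Z)+1}{m}\right)^{\frac{m}{m-1}}\right]\right\} + \frac{1}{m-1}\right).$$

In practice, we find it numerically useful to let $h = \frac{m}{m-1}\tilde h^{m-1} - \frac{1}{m-1}$ with $\tilde h \ge 0$, so that

$$\mathcal{G}(P) = \frac{1}{|\Omega|^{m-1}}\sup_{\tilde h\ge 0}\; \left\{\frac{m}{m-1}\mathbb{E}_{X\sim P}\left[\tilde h^{m-1}(X)\right] - \mathbb{E}_{Z\sim Q}\left[\tilde h^m(Z)\right]\right\}.$$

With such a change of variable, the optimal function is $\tilde h^* = P/Q$.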

Appendix B Evaluation of the density

In this section, we assume the solving process does not use the forward-backward scheme, i.e., all the probability measures are obtained by performing JKO steps one after another. Otherwise, the maps include the forward step, whose expectation term makes it intractable to push particles backward to compute the density.

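If each map $T_k$ is invertible, there exists a standard approach to evaluate the density of $P_K$ (AlvSchMro21; MokKorLiBur21) through the change of variables formula. More specifically, we assume $T_k = \nabla\varphi_k$ is parameterized by the gradient of an ICNN $\varphi_k$ that is assumed to be strictly convex. To evaluate the density $P_K(x_K)$ at a point $x_K$, we back-propagate through the sequence of maps $T_K, \ldots, T_1$ to get $x_{K-1}, \ldots, x_0$, where $x_k = T_k(x_{k-1})$.

The inverse map can be obtained by solving the convex optimization

$$x_{k-1} = \arg\min_{x}\; \varphi_k(x) - \langle x, x_k\rangle. \qquad (19)$$

Then, by the change of variables formula, we obtain

$$P_k(x_k) = \frac{P_{k-1}(x_{k-1})}{\det\nabla^2\varphi_k(x_{k-1})}, \qquad (20)$$

where $\nabla^2\varphi_k$ is the Hessian of $\varphi_k$ and $\det$ is its determinant. By iteratively solving (19) and plugging the resulting $x_{k-1}$ into (20), we can recover the density $P_K$ at any point.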
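The procedure above can be sketched for a hypothetical strictly convex potential whose gradient is easy to write by hand, $\varphi(x) = \frac{1}{2}\|x\|^2 + \sum_i \log\cosh x_i$, in place of an ICNN. Because its Hessian $I + \operatorname{diag}(\cosh^{-2} x_i)$ has eigenvalues in $[1, 2]$, plain gradient descent with step $0.5$ solves (19) reliably; a general ICNN would need a more careful solver. The check pushes $P_0 = \mathcal{N}(0, I)$ through the same map twice in 1D and verifies that the recovered density integrates to about one.

```python
import numpy as np


def invert_gradient_map(grad_phi, y, step=0.5, n_iter=60):
    """Solve (19), x = argmin_x phi(x) - <x, y>, by gradient descent, so grad_phi(x) = y.
    Assumes the Hessian of phi has eigenvalues in [1, 2], which makes step = 0.5 a contraction."""
    x = np.array(y, dtype=float)
    for _ in range(n_iter):
        x -= step * (grad_phi(x) - y)
    return x


def pushforward_density(y, log_p0, maps):
    """Density at y of P_K = (grad phi_K)# ... (grad phi_1)# P_0, using (19) and (20).
    maps = [(grad_phi_1, hess_phi_1), ..., (grad_phi_K, hess_phi_K)]."""
    x, log_det = np.array(y, dtype=float), 0.0
    for grad_phi, hess_phi in reversed(maps):
        x = invert_gradient_map(grad_phi, x)            # x_{k-1} from x_k
        log_det += np.linalg.slogdet(hess_phi(x))[1]    # log det of the Hessian at x_{k-1}
    return float(np.exp(log_p0(x) - log_det))


# Hypothetical strictly convex phi(x) = ||x||^2 / 2 + sum_i log cosh(x_i).
def grad_phi(x):
    return x + np.tanh(x)


def hess_phi(x):
    return np.eye(x.size) + np.diag(1.0 / np.cosh(x) ** 2)


def log_p0(x):  # P_0 = N(0, I)
    return -0.5 * x @ x - 0.5 * x.size * np.log(2.0 * np.pi)


# 1D check: push P_0 through the same map twice; the density should integrate to about 1.
ys = np.linspace(-15.0, 15.0, 601)
dens = np.array([pushforward_density(np.array([v]), log_p0, [(grad_phi, hess_phi)] * 2) for v in ys])
print(dens.sum() * (ys[1] - ys[0]))
```

Accumulating the log-determinants rather than multiplying determinants avoids underflow when many maps are composed.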

Appendix C Additional experiment results and discussions

c.1 Sampling using ICNN parameterization

Figure 6: Sampling Gaussian mixture models by parameterizing the map as $\nabla\varphi$, where $\varphi$ is an ICNN; panels (a) and (b) show two different dimensions.

In Figure 6, we present the sampling results with the map parameterized as $\nabla\varphi$, where $\varphi$ is an ICNN. The experiment setting is the same as in Section 4.1, and we observe that an MLP map gives better fitted measures.

c.2 Aggregation equation

AlvSchMro21 proposes using the neural-network-based JKO, i.e., the backward method, to solve (21). They parameterize $T$ as the gradient of an ICNN. In this section, we use two cases to compare the forward and backward methods when $\mathcal{F} = \mathcal{W}$. This could help explain the performance difference between the FB and non-FB schemes in Section C.3.

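We study the gradient flow associated with the aggregation equation

$$\partial_t P = \nabla\cdot\left(P\,\nabla W * P\right). \qquad (21)$$

The forward method is

$$P_{k+1} = \left(I - \tau\nabla W * P_k\right)_\sharp P_k.$$

The backward method, or JKO, is

$$P_{k+1} = \arg\min_{P}\; \frac{1}{2\tau}W_2^2(P, P_k) + \mathcal{W}(P).$$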

Example 1

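We follow the setting in CarCraWanWei21 for the interaction kernel $W$, and the initial measure is a Gaussian. In this case, $\nabla W * P_k$ in the forward method can be written as an expectation over $Y \sim P_k$ and estimated with samples. We use the same step size $\tau$ for both methods and show the results in Figure 7.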

Figure 7: (a), (b) forward method; (c), (d) backward method, each shown at two different times. The steady state is supported on a ring of radius 0.5. The backward method converges faster to the steady state but is unstable: as $t$ grows, it cannot keep the regular ring shape and eventually collapses.
Example 2

We again follow the setting in CarCraWanWei21 for the interaction kernel and the initial measure; the steady state is unique in this case.

The reader can refer to AlvSchMro21 for the performance of the backward method. As for the forward method, $\nabla W * P_k$ is again an expectation over $Y \sim P_k$. Because the kernel enforces repulsion near the origin and the initial measure is concentrated around the origin, $\nabla W * P_k$ easily blows up. So the forward method is not suitable for this kind of interaction kernel.

The two examples above suggest that, when $W$ is smooth, the backward method converges faster but is not stable when solving (21). This sheds light on the performance of the FB and non-FB schemes in Sections 4.4 and C.3. However, if $W$ is singular, as in Example 2, the forward method loses its competitiveness.

c.3 Aggregation-diffusion equation with non-FB scheme

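In Figure 8, we show the non-FB solutions to the aggregation-diffusion equation in Section 4.4. The FB scheme itself is independent of the implementation of JKO, but in the following we assume that FB and non-FB are both the neural-network-based methods discussed in Section 3. The non-FB scheme reads

$$T_{k+1} = \arg\min_{T}\; \frac{1}{2\tau}\mathbb{E}_{X\sim P_k}\left[\|X - T(X)\|_2^2\right] + \mathcal{F}(T_\sharp P_k), \qquad P_{k+1} = (T_{k+1})_\sharp P_k,$$

where the interaction energy in $\mathcal{F}$ is written as an expectation over pairs of samples and the generalized entropy $\mathcal{G}(T_\sharp P_k)$ is represented by the variational formula (14). We use the same step size and other PDE parameters as in Section 4.4.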

Figure 8: Histograms for simulated measures by non-FB scheme at different .

Comparing the FB scheme results in Figure 5 and the non-FB scheme results in Figure 8, we observe that both converge more slowly than the finite difference method (CarCraWanWei21), and that FB converges more slowly than non-FB. This may be because splitting one JKO step into the forward and backward steps removes the aggregation term from the JKO step, and the diffusion term alone is too weak to make a difference in the loss. Note that during the first several steps the measures are nearly uniform, so the optimal dual function in (14), the ratio $P/Q$ with $Q$ uniform, is nearly constant and exerts little effect in the variational formula of $\mathcal{G}$. Another possible reason is that a single forward step for the aggregation term converges more slowly than integrating aggregation into the backward step, as we discuss in Section C.2 and Figure 7.

However, FB generates more regular measures: the four pulses produced by FB are visibly more symmetric. We speculate that this is because the gradient descent step in FB uses the geometric structure of  directly, whereas integrating  into the neural-network-based JKO step loses the geometric meaning of .


C.4 Computational time

Our experiments are conducted on a GeForce RTX 3090 GPU. The forward step (15) takes about 14 seconds to push forward one million points.

Assuming that each JKO step involves 500 iterations and that the number of inner iterations is , each JKO step (16) takes 100 seconds if the energy function contains the generalized energy  and 25 seconds if it contains the KL divergence .
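These timings translate into per-iteration and per-run costs by simple arithmetic: 100 s / 500 iterations = 0.2 s per iteration with the generalized energy, 25 s / 500 = 0.05 s with the KL divergence, and the forward step processes about 10^6 / 14 ≈ 7.1 × 10^4 points per second. The helpers below are a hypothetical back-of-the-envelope calculator; the number of JKO steps in the usage lines is an arbitrary example, not a value from our experiments.

```python
def seconds_per_iteration(seconds_per_jko: float, iterations_per_jko: int = 500) -> float:
    """Average wall-clock time of one iteration inside a JKO step."""
    return seconds_per_jko / iterations_per_jko


def run_hours(seconds_per_jko: float, num_jko_steps: int) -> float:
    """Wall-clock hours for a run of num_jko_steps JKO steps."""
    return seconds_per_jko * num_jko_steps / 3600.0


def forward_throughput(num_points: int = 1_000_000, seconds: float = 14.0) -> float:
    """Points pushed forward per second by the forward step."""
    return num_points / seconds


# Usage: per-iteration cost for both energies, and an arbitrary 36-step run
generalized = seconds_per_iteration(100.0)
kl = seconds_per_iteration(25.0)
hours_36_steps = run_hours(100.0, 36)
```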

Appendix D Implementation details

Our code is written in PyTorch Lightning (falcon2020framework). For some of the plots in Sections 4.1 and 4.2, we adopt the code provided by MokKorLiBur21.

Unless otherwise specified, we use the following parameters:
1) The number of outer-loop iterations is 600.
2) The number of inner-loop iterations is 3.
3) The batch size is fixed to .
4) The learning rate is fixed to .
5) All activation functions are PReLU.
6)  has 4 layers and 16 neurons in each layer.
7)  has 5 layers and 16 neurons in each layer.
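The default settings listed above can be collected into a single frozen config object, with section-specific settings expressed as overrides. This is a hypothetical sketch (`DefaultHyperparams` and its field names are ours); the batch size, learning rate, and the names of the networks in items 6 and 7 are not recoverable from this text, so they are left as `None` or given neutral names rather than guessed.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class DefaultHyperparams:
    """Default settings from the list above (field names are hypothetical)."""
    outer_iterations: int = 600             # 1) outer-loop iterations
    inner_iterations: int = 3               # 2) inner-loop iterations
    batch_size: Optional[int] = None        # 3) value not given in this text
    learning_rate: Optional[float] = None   # 4) value not given in this text
    activation: str = "PReLU"               # 5) all activation functions
    net6_layers: int = 4                    # 6) network of item 6 (name not recoverable)
    net6_width: int = 16
    net7_layers: int = 5                    # 7) network of item 7 (name not recoverable)
    net7_width: int = 16


DEFAULTS = DefaultHyperparams()
# Section-specific settings override single fields, e.g. a short debug run:
debug = replace(DEFAULTS, outer_iterations=10)
```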

The transport map can be parametrized in different ways. We use a residual MLP in Sections 4.1, 4.2, and C.2, and the gradient of a strongly convex ICNN in Sections 4.3, 4.4, C.1, and C.3. The dual test function is always an MLP with a dropout layer before each layer.
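To show why the gradient of a strongly convex ICNN is a natural transport map, here is a minimal one-hidden-layer version in NumPy: f(x) = w · softplus(Ax + b) + (alpha/2)|x|^2 with w >= 0 is alpha-strongly convex, so T = grad f satisfies (T(x) - T(y)) · (x - y) >= alpha |x - y|^2. This is a simplified sketch under our own assumptions (one hidden layer, softplus activation, hand-written gradient, hypothetical class name `OneLayerICNN`); it is not the deeper ICNN architecture used in our experiments.

```python
import numpy as np


def softplus(z):
    # log(1 + e^z), computed stably
    return np.logaddexp(0.0, z)


def sigmoid(z):
    # derivative of softplus
    return 1.0 / (1.0 + np.exp(-z))


class OneLayerICNN:
    """Convex potential f(x) = w . softplus(A x + b) + (alpha / 2) |x|^2 with w >= 0.

    softplus of an affine function is convex and w >= 0, so the first term is
    convex; the quadratic term makes f alpha-strongly convex.
    """

    def __init__(self, dim, width, alpha=0.1, seed=0):
        rng = np.random.default_rng(seed)
        self.A = rng.normal(size=(width, dim))
        self.b = rng.normal(size=width)
        self.w = np.abs(rng.normal(size=width))  # nonnegativity keeps f convex
        self.alpha = alpha

    def value(self, x):
        # x: (N, dim) -> (N,)
        hidden = softplus(x @ self.A.T + self.b)
        return hidden @ self.w + 0.5 * self.alpha * np.sum(x**2, axis=-1)

    def transport_map(self, x):
        # T(x) = grad f(x) = A^T (w * sigmoid(A x + b)) + alpha x, for x: (N, dim)
        gates = sigmoid(x @ self.A.T + self.b) * self.w
        return gates @ self.A + self.alpha * x


icnn = OneLayerICNN(dim=2, width=16)
pushed = icnn.transport_map(np.random.default_rng(1).normal(size=(1000, 2)))
```

In a deep learning framework the gradient would normally come from automatic differentiation rather than a hand-written formula; the closed form here keeps the sketch dependency-free.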

D.1 Sampling (Sections 4.1 and C.1)

Two moons

We run JKO steps with inner iterations. has 6 layers. has 5 layers.


The 8D example is trained for  JKO steps.  has 6 layers and 64 neurons in each layer, and  has 3 layers and 128 neurons in each layer.

The 13D example is trained for  JKO steps.  has 8 layers and 64 neurons in each layer, and  has 9 layers and 64 neurons in each layer.

D.2 Ornstein-Uhlenbeck Process (Section 4.2)

For Fokker-Planck JKO, we use the implementation and default parameters provided by the authors of MokKorLiBur21. We also estimate the SymKL by Monte Carlo, following the authors' instructions. The MC estimate is straightforward for Fokker-Planck JKO because it has access to the density of the pushforward distributions. Our method uses a  parameterization, so we cannot query the density of the estimated  exactly. Nevertheless, our estimated  is guaranteed to be Gaussian, since our map  is a linear transformation. We therefore use the density of  as a surrogate, where  are the empirical mean and empirical covariance computed from  samples.
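The surrogate evaluation can be sketched as follows: fit the empirical mean and covariance of the particles produced by the learned map, evaluate the Gaussian log-density with those parameters, and plug it into a Monte Carlo estimate of SymKL(p, q) = KL(p||q) + KL(q||p). This is a generic sketch under our own assumptions; the function names, sample sizes, and the synthetic samples standing in for the output of the learned map are hypothetical, and it is not the evaluation code of MokKorLiBur21.

```python
import numpy as np


def fit_gaussian(samples):
    """Empirical mean and covariance of (N, d) samples."""
    mu = samples.mean(axis=0)
    cov = np.atleast_2d(np.cov(samples, rowvar=False))  # keep (d, d) even when d == 1
    return mu, cov


def gaussian_logpdf(x, mu, cov):
    """Log-density of N(mu, cov) at the rows of x, shape (N, d) -> (N,)."""
    d = mu.shape[0]
    chol = np.linalg.cholesky(cov)
    z = np.linalg.solve(chol, (x - mu).T)  # whitened residuals, shape (d, N)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (np.sum(z**2, axis=0) + log_det + d * np.log(2.0 * np.pi))


def symkl_mc(samples_p, log_p, samples_q, log_q):
    """Monte Carlo estimate of KL(p||q) + KL(q||p)."""
    kl_pq = np.mean(log_p(samples_p) - log_q(samples_p))
    kl_qp = np.mean(log_q(samples_q) - log_p(samples_q))
    return kl_pq + kl_qp


rng = np.random.default_rng(0)
n = 100_000
target_mu, target_cov = np.zeros(2), np.eye(2)
# Synthetic stand-in for the particles produced by the learned linear map
particles = rng.normal(size=(n, 2)) + np.array([1.0, 0.0])
mu_hat, cov_hat = fit_gaussian(particles)


def log_target(x):
    return gaussian_logpdf(x, target_mu, target_cov)


def log_surrogate(x):
    return gaussian_logpdf(x, mu_hat, cov_hat)


estimate = symkl_mc(
    rng.multivariate_normal(target_mu, target_cov, size=n), log_target,
    rng.multivariate_normal(mu_hat, cov_hat, size=n), log_surrogate,
)
```

The synthetic case is chosen so the answer is known: for two Gaussians with identity covariance whose means differ by a unit vector, each KL term equals 1/2, so the estimate should be close to 1.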

We use almost the same hyperparameters as MokKorLiBur21, including the learning rate, hidden-layer width, and number of iterations per JKO step. Specifically, we use a linear residual feed-forward network, i.e., one without activation functions, as .  and  both have 3 layers with 64 hidden neurons per layer for all dimensions. We train them for  iterations per JKO step with learning rate . The batch size is .
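Because the network used as the map has no activation functions, every residual layer x -> x + Wx + c is affine, and so is their composition. A Gaussian input therefore stays Gaussian, with mean M mu + t and covariance M Sigma M^T, which is what justifies the Gaussian surrogate. The sketch below checks this numerically; it is a minimal illustration under our own assumptions (each of the 3 layers is modeled as one residual block, and the function names are hypothetical).

```python
import numpy as np


def linear_residual_net(x, layers):
    """Residual feed-forward net without activations; x: (N, d)."""
    for W, c in layers:
        x = x + x @ W.T + c  # one residual block: x <- (I + W) x + c
    return x


def collapse_to_affine(layers, d):
    """Return (M, t) with linear_residual_net(x) == x @ M.T + t for every x."""
    M, t = np.eye(d), np.zeros(d)
    for W, c in layers:
        step = np.eye(d) + W
        M, t = step @ M, step @ t + c
    return M, t


def pushforward_gaussian(mu, cov, layers):
    """Mean and covariance of the image of N(mu, cov) under the network."""
    M, t = collapse_to_affine(layers, mu.shape[0])
    return M @ mu + t, M @ cov @ M.T


rng = np.random.default_rng(0)
d = 3
layers = [(0.1 * rng.normal(size=(d, d)), 0.1 * rng.normal(size=d)) for _ in range(3)]
mu, cov = np.zeros(d), np.eye(d)
mean_pf, cov_pf = pushforward_gaussian(mu, cov, layers)
```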

D.3 Porous media equation (Section 4.3)

In the experiment, and both have 10 neurons in each layer.

To account for the time-discretization error of the JKO scheme, we consider the porous media equation in 1D and use the solution computed by a finite difference method as a reference. More specifically, on the 1D domain , we discretize the density over a fixed grid with grid size  and grid resolution . With this discretization, the probability densities become (scaled) probability vectors, and problem (7) can be converted into the convex optimization problem


where  is the discretized unit transport cost,  is the probability vector at the previous step,  is the all-ones vector, and the optimization variable  is the joint distribution between  and . This is a standard convex program that can be solved with generic solvers. Once an optimal  is obtained,  can be computed as . We use the CVXOPT library to solve the convex program (22), which gives a reference solution for our algorithm.
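For concreteness, one way to write a discretized JKO step of this type for the porous medium energy F(rho) = (1/(m-1)) ∫ rho^m dx is given below. It works with probability masses rather than scaled densities, and the symbols m, Δx, and mu_k are our own notation; the exact scaling used in (22) may differ. The objective is convex because ⟨C, P⟩ is linear in P and t -> t^m is convex on t >= 0 for m > 1, which is why a generic convex solver applies.

```latex
% Hypothetical notation: grid points x_1, ..., x_n with spacing \Delta x,
% C_{ij} = |x_i - x_j|^2, \mu_k \in \mathbb{R}^n the masses at step k
% (\mu_{k,i} \approx \rho_k(x_i)\,\Delta x), porous-medium exponent m > 1.
\begin{aligned}
P^\star \in \operatorname*{arg\,min}_{P \in \mathbb{R}^{n \times n}_{\ge 0}} \quad
  & \frac{1}{2h} \langle C, P \rangle
    + \frac{\Delta x}{m-1} \sum_{i=1}^{n}
      \left( \frac{(P^\top \mathbf{1})_i}{\Delta x} \right)^{m} \\
\text{s.t.} \quad & P \mathbf{1} = \mu_k , \\[4pt]
\mu_{k+1} &= (P^\star)^\top \mathbf{1}.
\end{aligned}
```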

D.4 Aggregation-diffusion equation (Sections 4.4 and C.3)

Each JKO step contains iterations. The batch size is .