1 Introduction
The Wasserstein gradient flow models gradient dynamics over the space of probability densities with respect to the Wasserstein metric. It was first discovered by Jordan, Kinderlehrer, and Otto (JKO) in their seminal work (JorKinOtt98). They pointed out that the Fokker-Planck equation is in fact the Wasserstein gradient flow of the free energy, bringing tremendous physical insight to this type of partial differential equation (PDE). Since then, the Wasserstein gradient flow has played an important role in optimal transport, PDEs, physics, machine learning, and many other areas (AmbGigSav08; Ott01; AdaDirPelZim11; San17; CarDuvPeySch17; FroPog20).
Despite the abundant theoretical results on the Wasserstein gradient flow established over the past decades (AmbGigSav08; San17), its computation remains a challenge. Most existing methods are based either on finite differences for the underlying PDEs or on finite-dimensional optimization; both require discretization of the underlying space (Pey15; benamou2016discretization; CarDuvPeySch17; LiLuWan20; CarCraWanWei21). The computational complexity of these methods scales exponentially with the problem dimension, making them unsuitable for problems involving probability densities over high-dimensional spaces.
Our goal is to develop a scalable method to compute the Wasserstein gradient flow without discretizing the underlying space. One target application we are specifically interested in is optimization over the space of probability densities. Many problems such as variational inference can be viewed as special cases of such optimization. We aim to establish a method for this type of optimization that is applicable to a large class of objective functions.
Our algorithm is based on the JKO scheme (JorKinOtt98), which is essentially a backward Euler time discretization of the continuous-time Wasserstein gradient flow. In each step of the JKO scheme, one needs to find a probability density that minimizes a weighted sum of the squared Wasserstein distance to the probability density at the previous step and the objective function. We reparametrize the problem in each step so that the optimization variable becomes the optimal transport map from the probability density at the previous step to the one we want to optimize, recasting the problem into a stochastic optimization framework. This transport map can be modeled either by a standard feedforward network or by the gradient of an input convex neural network. The latter is justified by the fact that the optimal transport map for the optimal transport problem with quadratic cost, for any marginals, is the gradient of a convex function. Another crucial ingredient of our algorithm is a variational form of the objective function leveraging the $f$-divergence, which has been employed in multiple machine learning applications, such as generative modeling (nowozin2016f) and Bayesian inference (wan2020f). The variational form allows the objective to be evaluated with samples, without density estimation. At the end of the algorithm, a sequence of transport maps connecting an initial distribution and the target distribution is obtained. One can then sample from the target distribution by sampling from the initial distribution (often a Gaussian) and propagating these particles through the sequence of transport maps. When the transport map is modeled by the gradient of an input convex neural network, one can also evaluate the target density at every point.
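The sampling procedure described above can be sketched numerically. In the sketch below, simple affine maps stand in as hypothetical placeholders for the trained transport maps (in the actual method, each map is a neural network fitted by one JKO step):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the trained transport maps T_1, ..., T_K.
# In the actual method each map is a neural network fitted by one JKO step.
maps = [lambda x, a=a, b=b: a * x + b
        for a, b in [(1.2, 0.1), (0.9, 0.0), (1.1, -0.2)]]

x = rng.standard_normal(100_000)  # samples from the initial Gaussian rho_0
for T in maps:                    # push particles through the map sequence
    x = T(x)

# The pushforward of N(0,1) through these affine maps is again Gaussian,
# with mean and variance obtained by composing the affine transforms.
print(x.mean(), x.var())
```

Sampling therefore costs one forward pass per map, with no density computation involved.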
Our contributions can be summarized as follows.
i). We develop a neural-network-based algorithm to compute the Wasserstein gradient flow without spatial discretization. Our algorithm is applicable to any objective function that admits a variational representation.
ii). We specialize our algorithm to three important cases where the objective functions are the Kullback-Leibler divergence, the generalized entropy, and the interaction energy.
iii). We apply our algorithm to several representative problems, including sampling and the aggregation-diffusion equation, and obtain respectable performance.
Related works: Most existing methods to compute the Wasserstein gradient flow are based on finite differences (Pey15; benamou2016discretization; CarDuvPeySch17; LiLuWan20; CarCraWanWei21). These methods require spatial discretization and are thus not scalable to high-dimensional settings. SalKorLui20 analyze the convergence of a forward-backward scheme but leave the implementation of the JKO step as an open question. There is a line of research that uses particle-based methods to estimate the Wasserstein gradient flow (CarCraPat19; FroPog20). In these algorithms, the current density value is often estimated using kernel methods, whose complexity scales at least quadratically with the number of particles. More recently, several interesting neural-network-based methods (MokKorLiBur21; AlvSchMro21; YanZhaCheWan20; bunne2021jkonet; bonet2021sliced) were proposed for the Wasserstein gradient flow. MokKorLiBur21 focuses on the special case with the Kullback-Leibler divergence as the objective function. AlvSchMro21 uses a density estimation method to evaluate the objective function by backpropagating to the initial distribution, which could become a computational burden when the number of time discretization steps is large. YanZhaCheWan20 is based on a forward Euler time discretization of the Wasserstein gradient flow and is more sensitive to the time step size. bunne2021jkonet utilizes the JKO scheme to approximate a population dynamics given an observed trajectory, which finds application in computational biology. bonet2021sliced replaces the Wasserstein distance in JKO with a sliced alternative, but its connection to the original Wasserstein gradient flow remains unclear. Over the past few years, many neural-network-based algorithms have been proposed to compute optimal transport maps or Wasserstein barycenters (MakTagOhLee20; KorEgiAsaBur19; FanTagChe20; KorLiSolBur21). These can be viewed as special cases of Wasserstein gradient flows or optimization over the space of probability densities.
2 Background
2.1 Optimal transport and Wasserstein distance
Given two probability distributions $\mu$ and $\nu$ over the Euclidean space $\mathbb{R}^d$ with finite second moments, the optimal transport problem with quadratic cost reads
$$\min_{T:\, T_\#\mu = \nu}\; \int_{\mathbb{R}^d} \|x - T(x)\|^2\, d\mu(x), \tag{1}$$
where the minimization is over all feasible transport maps that transport mass from $\mu$ to $\nu$. Feasibility is characterized by the pushforward operator (Bog07) as $T_\#\mu = \nu$. When the initial distribution $\mu$ admits a density, the optimal transport problem (1) has a unique solution, and it is the gradient of a convex function, that is, $T = \nabla\varphi$ for some convex function $\varphi$. In this paper, we assume probability measures admit densities and use the same notation for a measure and its density interchangeably.
The square root of the minimum transport cost, namely the minimum of (1), defines a metric on the space of probability distributions known as the Wasserstein-2 distance (Vil03), denoted by $W_2$. The Wasserstein distance has many nice geometric properties compared with other distances between probability distributions, making it a popular choice in applications.
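In one dimension the optimal map is monotone, so $W_2$ between two empirical measures with equally many samples reduces to matching sorted samples. A small sanity-check sketch (our own illustration, not part of the paper's method):

```python
import numpy as np

def w2_empirical_1d(x, y):
    """Empirical Wasserstein-2 distance between two 1D sample sets of
    equal size: the monotone (sorted) coupling is optimal in 1D."""
    xs, ys = np.sort(x), np.sort(y)
    return float(np.sqrt(np.mean((xs - ys) ** 2)))

rng = np.random.default_rng(1)
x = rng.standard_normal(50_000)
# Shifting a distribution by a constant c moves it W2-distance exactly |c|.
print(w2_empirical_1d(x, x + 2.0))  # close to 2.0
```

This translation property is one of the geometric features that make $W_2$ attractive compared with densities-based distances.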
2.2 Wasserstein gradient flow
Given a functional $F$ over the space of probability densities, the Wasserstein gradient flow describes the dynamics of the probability density as it follows the steepest descent direction of $F$ with respect to the Wasserstein metric $W_2$. The Wasserstein gradient flow can be explicitly represented by the PDE
$$\partial_t \rho_t = \nabla\cdot\Big(\rho_t\, \nabla \frac{\delta F}{\delta \rho}(\rho_t)\Big), \tag{2}$$
where $\frac{\delta F}{\delta \rho}$ is the first variation of $F$ and $\nabla$ stands for the gradient with respect to the standard Euclidean metric (Vil03, Ch. 8).
Many important PDEs are Wasserstein gradient flows minimizing certain objective functionals $F$. For instance, when $F$ is the free energy $F(\rho) = \int V\rho\,dx + \int \rho\log\rho\,dx$, the gradient flow is the Fokker-Planck equation (JorKinOtt98). When $F$ is the generalized entropy $F(\rho) = \frac{1}{m-1}\int\rho^m\,dx$ for some number $m>1$, the gradient flow is the porous medium equation (Ott01; Vaz07).
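The Fokker-Planck connection can be checked numerically: particles following the Langevin dynamics $dX = -\nabla V(X)\,dt + \sqrt{2}\,dW$ have a density obeying the Fokker-Planck equation, i.e., the gradient flow of the free energy. A minimal Euler-Maruyama sketch (our own check) with $V(x) = x^2/2$, whose stationary density is $\mathcal{N}(0,1)$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, dt, steps = 50_000, 0.01, 1_000     # particles, step size, horizon T = 10
x = 3.0 + rng.standard_normal(n)       # start far from equilibrium

for _ in range(steps):
    # Euler-Maruyama for dX = -X dt + sqrt(2) dW  (V(x) = x^2 / 2)
    x += -x * dt + np.sqrt(2 * dt) * rng.standard_normal(n)

# The particle distribution approaches the stationary N(0, 1),
# i.e., the minimizer of the free energy.
print(x.mean(), x.var())
```

The particle mean decays like $e^{-t}$ while the variance relaxes to the stationary value 1, mirroring the decay of the free energy along the flow.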
3 Methods and algorithms
We are interested in solving the optimization problem
$$\min_{\rho \in \mathcal{P}(\mathbb{R}^d)}\; F(\rho) \tag{3}$$
over the space of probability densities $\mathcal{P}(\mathbb{R}^d)$. In particular, our objective is to develop a particle-based Wasserstein gradient flow algorithm to numerically solve (3).
The objective function can take different forms depending on the application. In this paper, we present our algorithm for linear combinations of the following three important cases:
Case I: The objective functional is the Kullback-Leibler (KL) divergence with respect to a given target distribution $\mu$:
$$F(\rho) = \mathrm{KL}(\rho\,\|\,\mu) = \int \rho(x)\,\log\frac{\rho(x)}{\mu(x)}\,dx. \tag{4}$$
This case is important for the problem of sampling from a target distribution.
Case II: The objective functional is the generalized entropy
$$F(\rho) = \frac{1}{m-1}\int \rho(x)^m\,dx. \tag{5}$$
This case is important for modeling the porous medium.
Case III: The objective functional is the interaction energy
$$F(\rho) = \frac{1}{2}\iint W(x-y)\,\rho(x)\,\rho(y)\,dx\,dy. \tag{6}$$
This case is important for modeling the aggregation equation.
These functionals have been widely studied in the Wasserstein gradient flow literature (CarDuvPeySch17; San17; AmbGigSav08) due to their desirable properties. It can be shown that if $F$ is composed of the above functionals, then, under proper assumptions, the Wasserstein gradient flow associated with $F$ converges to the unique solution of (3) (San17).
In Sections 3.1 and 3.2, we first assume $F$ does not include the interaction energy and introduce the JKO (backward) scheme to solve (3). We then take the interaction energy into account, present a forward-backward scheme in Section 3.3, and close by presenting our algorithm in Section 3.4.
3.1 JKO scheme and reparametrization
To realize the Wasserstein gradient flow, a discretization over time is needed. One such discretization is the famous JKO scheme (JorKinOtt98)
$$\rho_{k+1} = \operatorname*{arg\,min}_{\rho}\; F(\rho) + \frac{1}{2\tau}\, W_2^2(\rho, \rho_k) \tag{7}$$
with step size $\tau > 0$. This is essentially a backward Euler discretization, or a proximal point method with respect to the Wasserstein metric. The solution to (7) converges to the continuous-time Wasserstein gradient flow as the step size $\tau \to 0$.
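As a toy illustration of the backward Euler nature of (7) (our own example, not from the paper): for the KL objective with target $\mathcal{N}(0,1)$, restricted to unit-variance Gaussians $\rho = \mathcal{N}(m,1)$, one has $F = m^2/2$ and $W_2^2(\mathcal{N}(m,1), \mathcal{N}(m_k,1)) = (m - m_k)^2$, so each JKO step is a closed-form proximal (implicit Euler) update on the mean:

```python
def jko_step_mean(m_prev: float, tau: float) -> float:
    # argmin_m  m^2/2 + (m - m_prev)^2 / (2 tau)   (implicit Euler update)
    return m_prev / (1.0 + tau)

# Iterating the scheme contracts the mean geometrically, matching the
# continuous-time flow m(t) = m(0) * exp(-t) as tau -> 0.
m, tau = 1.0, 0.01
for _ in range(100):   # total time t = 1
    m = jko_step_mean(m, tau)
print(m)  # close to exp(-1)
```

The discrete iterate $(1+\tau)^{-t/\tau}$ converges to $e^{-t}$ as $\tau \to 0$, which is exactly the convergence statement above in this restricted Gaussian family.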
Thanks to Brenier's theorem (brenier1991polar), (7) can be recast as an optimization over the transport maps $T$ from $\rho_k$ to $\rho$, i.e., by defining $\rho = T_\#\rho_k$. The optimal $T$ is the optimal transport map from $\rho_k$ to $\rho_{k+1}$ and is thus the gradient of a convex function. Therefore, in view of the definition of the Wasserstein distance (1), bunne2021jkonet; MokKorLiBur21; AlvSchMro21 propose to reparametrize $T$ as the gradient of an input convex neural network (ICNN) (AmoXuKol17) and express (7) as
$$\varphi_{k+1} = \operatorname*{arg\,min}_{\varphi \in \mathrm{CVX}}\; F\big((\nabla\varphi)_\#\rho_k\big) + \frac{1}{2\tau}\, \mathbb{E}_{x\sim\rho_k}\big[\|\nabla\varphi(x) - x\|^2\big], \tag{8}$$
where CVX stands for the space of convex functions. In our method, we extend this idea and propose to reparametrize $T$ alternatively by a residual neural network. With this reparametrization, the JKO step (7) becomes
$$T_{k+1} = \operatorname*{arg\,min}_{T}\; F\big(T_\#\rho_k\big) + \frac{1}{2\tau}\, \mathbb{E}_{x\sim\rho_k}\big[\|T(x) - x\|^2\big]. \tag{9}$$
We use the preceding two schemes (8) and (9) in our numerical method depending on the application.
3.2 $f$-divergence and reformulation with variational formula
The main challenge in implementing the JKO scheme is evaluating the functional $F$ in terms of samples from $\rho$. We achieve this goal by using a variational formulation of $F$. To do so, we use the notion of $f$-divergence between two distributions $\mu$ and $\nu$:
$$D_f(\mu\,\|\,\nu) = \int f\Big(\frac{\mu(x)}{\nu(x)}\Big)\,\nu(x)\,dx, \tag{10}$$
where $f$ is a convex function. The $f$-divergence admits the variational formulation
$$D_f(\mu\,\|\,\nu) = \sup_{h}\; \mathbb{E}_{x\sim\mu}[h(x)] - \mathbb{E}_{y\sim\nu}\big[f^*(h(y))\big], \tag{11}$$
where $f^*$ is the convex conjugate of $f$. The variational form has the special feature that it does not involve the densities of $\mu$ and $\nu$ explicitly, and it can be approximated in terms of samples from $\mu$ and $\nu$. The functionals (4) and (5) can both be expressed as $f$-divergences.
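For intuition (our own check, not the paper's estimator), the variational formula can be verified by Monte Carlo for $f(u)=u\log u$, $f^*(t)=e^{t-1}$: plugging in the optimal dual $h^\star = 1 + \log(\mu/\nu)$ recovers the KL divergence, here between $\mu=\mathcal{N}(0.5,1)$ and $\nu=\mathcal{N}(0,1)$, whose KL equals $0.125$:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 400_000
x_mu = 0.5 + rng.standard_normal(n)   # samples from mu = N(0.5, 1)
x_nu = rng.standard_normal(n)         # samples from nu = N(0, 1)

def log_ratio(x):
    # log(mu(x)/nu(x)) for the two unit-variance Gaussians above
    return 0.5 * x - 0.125

h_mu = 1.0 + log_ratio(x_mu)          # optimal dual h* = 1 + log(mu/nu)
# Variational objective: E_mu[h] - E_nu[f*(h)], with f*(t) = exp(t - 1)
est = h_mu.mean() - np.exp(log_ratio(x_nu)).mean()
print(est)  # close to KL(mu || nu) = 0.125
```

Only samples and the dual function are needed; neither density enters the estimate explicitly, which is exactly the feature the method exploits.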
With the help of the $f$-divergence variational formula, when $F$ is the KL divergence (4) or the generalized entropy (5), the JKO scheme (9) can be equivalently expressed as
$$T_{k+1} = \operatorname*{arg\,min}_{T}\,\max_{g}\;\; \mathbb{E}_{x\sim\rho_k}\Big[F_1\big(g(T(x))\big) + \frac{1}{2\tau}\,\|T(x)-x\|^2\Big] - \mathbb{E}_{y\sim\eta}\big[F_2(g(y))\big], \tag{12}$$
where $\eta$ is a user-designed distribution that is easy to sample from, and $F_1$ and $F_2$ are functionals whose forms depend on the objective $F$. The forms of these two functionals for the KL divergence and the generalized entropy appear in Table 1. The details appear in Sections 3.2.1 and 3.2.2.
3.2.1 KL divergence
The KL divergence is the special instance of the $f$-divergence obtained by setting $f(u) = u\log u$ in (10).
Proposition 1.
The variational formulation for $\mathrm{KL}(\rho\,\|\,\mu)$ reads
$$\mathrm{KL}(\rho\,\|\,\mu) = \sup_{g>0}\; \mathbb{E}_{x\sim\rho}[\log g(x)] - \mathbb{E}_{y\sim\eta}\Big[\frac{\mu(y)}{\eta(y)}\, g(y)\Big] + 1,$$
where $\eta$ is a user-designed distribution that is easy to sample from. The optimal function $g^\star$ is equal to the ratio between the densities of $\rho$ and $\mu$.
The proof of Proposition 1 can be found in Appendix A. The formulation is practical when we only have access to the unnormalized density of $\mu$, which is the case for the sampling problem. Using this variational form in the JKO scheme (9) yields
$$T_{k+1} = \operatorname*{arg\,min}_{T}\,\max_{g>0}\;\; \mathbb{E}_{x\sim\rho_k}\Big[\log g(T(x)) + \frac{1}{2\tau}\,\|T(x)-x\|^2\Big] - \mathbb{E}_{y\sim\eta}\Big[\frac{\mu(y)}{\eta(y)}\, g(y)\Big]. \tag{13}$$
In practice, we choose $\eta$ adaptively as the Gaussian with the same mean and covariance as the current pushforward distribution. We noticed that this choice improves the numerical stability of the algorithm.
Table 1: The functionals $F_1$, $F_2$ and the sampling distribution $\eta$ in (12) (constant factors absorbed).
Energy function | $F_1(g)$ | $F_2(g)$ | $\eta$
KL divergence (4) | $\log g$ | $\frac{\mu}{\eta}\, g$ | Gaussian distribution
Generalized entropy (5) | $m\, g^{m-1}$ | $(m-1)\, g^m$ | Uniform distribution
3.2.2 Generalized entropy
The generalized entropy can also be represented as an $f$-divergence. In particular, let $f(u) = u^m$ and let $U$ be the uniform distribution on a set $\Omega$ that is a superset of the support of the density $\rho$ and has volume $|\Omega|$. Then $D_f(\rho\,\|\,U) = |\Omega|^{m-1}\int\rho^m\,dx$, so the generalized entropy (5) equals $\frac{|\Omega|^{1-m}}{m-1}\, D_f(\rho\,\|\,U)$.
Proposition 2.
The variational formulation for the generalized entropy (5) reads
$$\frac{1}{m-1}\int \rho^m\,dx = \frac{|\Omega|^{1-m}}{m-1}\,\sup_{g>0}\; m\,\mathbb{E}_{x\sim\rho}\big[g(x)^{m-1}\big] - (m-1)\,\mathbb{E}_{y\sim U}\big[g(y)^m\big]. \tag{14}$$
The optimal function $g^\star = |\Omega|\,\rho$ is equal to the ratio between the densities of $\rho$ and $U$.
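The $f$-divergence representation behind this proposition can be sanity-checked on a grid. Assuming $f(u)=u^m$ with $m=2$, a triangular density $\rho(x)=2x$ on $[0,1]$, and $U$ uniform on $\Omega=[0,2]$ (so $|\Omega|=2$), the identity $D_f(\rho\,\|\,U) = |\Omega|^{m-1}\int\rho^m\,dx$ holds numerically (our own check, not from the paper):

```python
import numpy as np

# Grid on Omega = [0, 2]; rho is the triangular density 2x on [0, 1].
xs = np.linspace(0.0, 2.0, 400_001)
dx = xs[1] - xs[0]
rho = np.where(xs <= 1.0, 2.0 * xs, 0.0)
u = np.full_like(xs, 0.5)               # uniform density on Omega, |Omega| = 2

m = 2
int_rho_m = float(np.sum(rho ** m) * dx)          # integral of rho^m  (= 4/3)
d_f = float(np.sum((rho / u) ** m * u) * dx)      # f-divergence, f(t) = t^m

print(int_rho_m, d_f)  # d_f equals |Omega|^(m-1) * int_rho_m
```

The constant $|\Omega|^{m-1}$ is exactly the prefactor that the variational formulation absorbs.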
3.3 Forward-Backward (FB) scheme
When $F$ involves the interaction energy (6), we add an additional forward step to solve the gradient flow:
$$\tilde\rho_k = \big(\mathrm{Id} - \tau\,\nabla(W*\rho_k)\big)_\#\,\rho_k, \tag{15}$$
$$T_{k+1} = \operatorname*{arg\,min}_{T}\; \tilde F\big(T_\#\tilde\rho_k\big) + \frac{1}{2\tau}\,\mathbb{E}_{x\sim\tilde\rho_k}\big[\|T(x)-x\|^2\big], \tag{16}$$
where $\mathrm{Id}$ is the identity map and $\tilde F$ denotes the objective without the interaction energy; (16) is implemented by replacing $\rho_k$ with $\tilde\rho_k$ in (12). In other words, the first gradient descent step (15) is a forward discretization of the gradient flow and the second JKO step (16) is a backward discretization. The gradient $\nabla(W*\rho_k)(x) = \mathbb{E}_{y\sim\rho_k}[\nabla W(x-y)]$ can be written as an expectation and can thus also be approximated by samples. SalKorLui20 first proposed this method to solve the Wasserstein gradient flow and provided a theoretical convergence analysis. We make this scheme practical by giving a scalable implementation of the JKO step.
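The forward step (15) is easy to realize on particles. A sketch assuming the quadratic kernel $W(x)=\|x\|^2/2$ (our choice for illustration), for which $\nabla(W*\rho)(x) = x - \mathbb{E}_\rho[y]$ and the step contracts particles toward their mean:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10_000, 2))   # particles approximating rho_k
tau = 0.1

# Forward (explicit) step for the interaction energy with W(x) = |x|^2 / 2:
# grad(W * rho)(x) = x - mean(rho), estimated from the particles themselves.
grad = x - x.mean(axis=0)
x_tilde = x - tau * grad               # particles approximating rho_tilde_k

# The step shrinks the spread around the mean by a factor (1 - tau).
print(x.var(axis=0), x_tilde.var(axis=0))
```

Each particle only needs the empirical mean here; for a general kernel, the expectation over all particles makes the step quadratic in the particle count.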
Since the interaction energy (6) can be equivalently written as the expectation $\frac{1}{2}\,\mathbb{E}_{x,y\sim\rho}[W(x-y)]$, there exists another, non-forward-backward (non-FB) method: removing the first step and integrating the interaction energy into a single JKO step.
In practice, we observe that the FB scheme is more stable and gives more regular results, but converges more slowly than the non-FB scheme. A detailed discussion appears in Appendices C.2 and C.3.
Remark 1.
In principle, one can single out the potential term from (13) and perform a similar forward step (SalKorLui20), but we do not observe improved performance from doing so in the sampling task.
3.4 Primal-dual algorithm and parametrization of $T$ and $g$
The two optimization variables $T$ and $g$ in our minimax formulation (12) can both be parameterized by neural networks. With this neural network parametrization, we can solve the problem by iteratively updating the two networks. This primal-dual method for solving (3) is depicted in Algorithm 1.
In this work, we implemented two different architectures for the map $T$. One is to use a residual neural network to represent $T$ directly; another is to parametrize $T$ as the gradient of an ICNN. The latter has been widely used in optimal transport (MakTagOhLee20; FanTagChe20; KorLiSolBur21). Note that an ICNN can be modified to be strictly convex, and if the function is strictly convex, its gradient is invertible. However, several recent works (rout2021generative; KorEgiAsaBur19) find that the ICNN architecture has poor expressiveness and also propose to replace the gradient of an ICNN by a generic neural network. In our experiments, we find that the first parametrization gives more regular results, which aligns with the findings in bonet2021sliced. However, with it, it is very difficult to calculate the density of the pushforward distribution; our method then becomes a particle-based method, i.e., we cannot query the density directly. As we discuss in Appendix B, when density evaluation is needed, we adopt the ICNN parametrization since we need to compute the inverse map.
3.5 Computational complexity
Per update in Algorithm 1, the forward step (15) requires at most $O(N^2)$ operations, where $N$ is the total number of particles to push forward. The backward step (16) requires $O(MBkE)$ operations, where $M$ is the number of iterations per JKO step, $B$ is the batch size, $E$ is the size of the network, and $k$ is the index of the current JKO step; $k$ shows up in the bound because sampling requires pushing particles forward through $k$ maps.
However, MokKorLiBur21 incurs an additional cost with cubic dependence on the dimension $d$ because they need to query $\log\det\nabla^2\varphi$ in each iteration. We refer to MokKorLiBur21 for the complexity details of calculating the Hessian term. There exist fast approximations (huang2020convex) of $\log\det\nabla^2\varphi$ and its gradient using stochastic Lanczos quadrature (ubaru2017fast) and the Hutchinson trace estimator (hutchinson1989stochastic). AlvSchMro21 applies these, so the cubic dependence on $d$ can be improved to quadratic. Nonetheless, this is accompanied by an additional cost: the number of iterations to run the conjugate gradient (CG) method. CG is guaranteed to converge exactly in $d$ steps in this setting. If one wants to obtain $\log\det\nabla^2\varphi$ precisely, the cost is still cubic in $d$, the same as calculating it directly. If one uses an error-based stopping condition in CG, the complexity can be improved to quadratic in $d$ times a factor depending on the condition number of $\nabla^2\varphi$, but this sacrifices accuracy. Thus our method has the advantage of being independent of the dimension.
We provide training time details in Section 4.2 and Appendix C.4. Other than training and sampling time, the complexity of evaluating the density is the same as for the above two methods, due to the standard density evaluation process (see Appendix B).
4 Numerical examples
4.1 Sampling
We first consider the problem of sampling from a target distribution $\mu$. Note that $\mu$ does not have to be normalized. To this end, we consider the Wasserstein gradient flow with objective function $F(\rho) = \mathrm{KL}(\rho\,\|\,\mu)$, that is, the KL divergence between the distributions $\rho$ and $\mu$. When this objective is minimized, $\rho = \mu$. In our experiments, we consider two types of target distributions: the two-moons distribution and the Gaussian mixture model (GMM) with spherical Gaussian components. In this set of experiments, the step size is fixed and the initial measure is a spherical Gaussian.
Two moons: The two-moons distribution is a popular target distribution for sampling tasks. It is a 2D mixture model composed of 16 Gaussian components; each moon shape consists of 8 Gaussian components. The results are displayed in Figure 1, from which we see that our method is able to generate samples that match the target distribution.
GMM with spherical Gaussians: We also test our algorithm on sampling from a GMM in higher-dimensional spaces. The target GMM has 9 Gaussian components with equal weights and the same covariance. The results for dimensions 8 and 13 are depicted in Figure 2, where we display not only the samples as grey dots but also level sets of a kernel density estimate of the generated samples. As can be seen from the results, both the samples and the densities obtained with our algorithm match the target distribution well.
4.2 Ornstein-Uhlenbeck Process
We study the performance of our method in modeling the Ornstein-Uhlenbeck process as the dimension grows. The gradient flow is associated with the free energy (4), where the target is a Gaussian $\mu \propto \exp\big(-\frac{1}{2}(x-b)^\top A (x-b)\big)$ with a positive definite matrix $A$ and mean $b$. Given an initial Gaussian distribution, the gradient flow at each time $t$ remains a Gaussian distribution whose mean vector and covariance are available in closed form (vatiwutipong2019alternative). We calculate the flow with a fixed JKO step size and compare with the Fokker-Planck (FP) JKO (MokKorLiBur21). We quantify the error as the SymKL divergence between the estimated distribution and the ground truth in Figure 3, where $\mathrm{SymKL}(p, q) = \mathrm{KL}(p\,\|\,q) + \mathrm{KL}(q\,\|\,p)$.
We believe there are several reasons for our better performance. 1) The proposal distribution $\eta$ is Gaussian, which is consistent with the flow being Gaussian at every $t$. This is beneficial for the inner maximization to find a precise dual function. 2) The map $T$ is linear, which guarantees that our generated samples always come from a Gaussian distribution. Note that if one approximates the map as the gradient of an ICNN, it is not guaranteed that the mapped distribution is still Gaussian. We also compare the training time per every two JKO steps with FP JKO. The computation time for FP JKO is around 20 seconds for the smallest dimension and increases to about 100 seconds for the largest, whereas our method's training time remains constant across all dimensions. This is because we fix the neural network size for both methods and our method's computational complexity does not depend on the dimension.
4.3 Porous media equation
We next consider the porous media equation with only diffusion: $\partial_t\rho = \Delta\rho^m$. This is the Wasserstein gradient flow associated with the generalized entropy (5). A representative closed-form solution of the porous media equation is the Barenblatt profile (Gi52; Vaz07)
$$\rho(x,t) = (t+t_0)^{-\alpha}\Big(C - \beta\,\|x\|^2\,(t+t_0)^{-2\alpha/d}\Big)_+^{\frac{1}{m-1}}, \qquad \alpha = \frac{d}{d(m-1)+2}, \quad \beta = \frac{(m-1)\,\alpha}{2md},$$
where $t_0$ is the starting time and $C$ is a free parameter.
In principle, our algorithm should match the analytical solution when the step size is sufficiently small. When the step size is not that small, time-discretization error is inevitable. To account for the time-discretization error of the JKO scheme, we consider the porous media equation in 1D space and use the solution obtained via a finite difference method as a reference. The details appear in Appendix D.3.
In the experiments, we fix the step size of the JKO scheme, the initial time $t_0$, and the free parameters of the Barenblatt profile. We parametrize the transport map as the gradient of an ICNN, and thus we can evaluate the density following Appendix B. In Figure 4, we observe that the gap between the density computed using our algorithm and the ground-truth density is dominated by the time-discretization error of the JKO scheme. Our result matches the discrete-time solution nicely.
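A quick sanity check (our own, not from the paper) that a Barenblatt-type profile conserves total mass over time, in $d=1$ with $m=2$ and an arbitrary choice of the free parameter $C$:

```python
import numpy as np

def barenblatt_1d(x, t, m=2, C=0.1):
    # d = 1: alpha = 1 / (m + 1), beta = (m - 1) * alpha / (2 * m * d)
    alpha = 1.0 / (m + 1)
    beta = (m - 1) * alpha / (2 * m)
    val = C - beta * x ** 2 * t ** (-2 * alpha)
    return t ** (-alpha) * np.maximum(val, 0.0) ** (1.0 / (m - 1))

xs = np.linspace(-5.0, 5.0, 200_001)
dx = xs[1] - xs[0]
mass_t1 = float(np.sum(barenblatt_1d(xs, 1.0)) * dx)
mass_t4 = float(np.sum(barenblatt_1d(xs, 4.0)) * dx)
print(mass_t1, mass_t4)  # equal: the profile spreads but conserves mass
```

Mass conservation is exactly the constraint the JKO scheme preserves by construction, since each step is a pushforward of the previous density.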
4.4 Aggregation–Diffusion Equation
We finally simulate the evolution of solutions to the aggregation-diffusion equation
$$\partial_t\rho = \nabla\cdot\big(\rho\,\nabla(W*\rho)\big) + \Delta\rho^m.$$
This corresponds to the energy functional given by the sum of the generalized entropy (5) and the interaction energy (6). We use the same parameters as in CarCraWanWei21. The initial distribution is a uniform distribution, and we use a fixed JKO step size. In Figure 5, we utilize the FB scheme to simulate the gradient flow for this equation. With this choice of $W$, the gradient $\nabla(W*\rho_k)$ in the gradient descent step (15) can be written as an expectation, and we estimate it with samples from $\rho_k$.
Throughout the process, the aggregation term and the diffusion term adversarially exert their effects, causing the probability measure to split into four pulses and finally converge to a single pulse (carrillo2019nonlinear).
5 Conclusion
In this paper we presented a novel neural-network-based algorithm to compute the Wasserstein gradient flow. Our algorithm follows the JKO time discretization scheme. We reparametrize the problem so that the optimization variable becomes the transport map between consecutive steps. By utilizing a variational formula of the objective function, we further reformulate the problem in every step as a min-max problem over the transport map and the dual function, respectively. This formulation does not require density estimation from samples and can be optimized using stochastic optimization. It also enjoys dimension-independent computational complexity. Our method can be extended to minimize any other objective function that can be written as an $f$-divergence. A limitation is that the accuracy is not yet satisfying in sampling tasks with high-dimensional, complex densities.
References
Appendix A Details about variational formula in Section 3.2
a.1 KL divergence
The KL divergence is the special instance of the $f$-divergence obtained by setting $f(u) = u\log u$ in (10), which, according to (11), admits the variational formulation
$$\mathrm{KL}(\rho\,\|\,\mu) = \sup_{h}\; \mathbb{E}_{x\sim\rho}[h(x)] - \mathbb{E}_{y\sim\mu}\big[e^{h(y)-1}\big], \tag{17}$$
where the convex conjugate $f^*(t) = e^{t-1}$ is used.
The variational formulation can be approximated in terms of samples from $\rho$ and $\mu$. For the case where we only have access to the unnormalized density of $\mu$, which is the case for the sampling problem, we use the following change of variables: $h(y) = 1 + \log g(y)$, together with $\mathbb{E}_{y\sim\mu}[\,\cdot\,] = \mathbb{E}_{y\sim\eta}\big[\frac{\mu(y)}{\eta(y)}\,\cdot\,\big]$, where $\eta$ is a user-designed distribution that is easy to sample from. Under such a change of variables, the variational formulation reads
$$\mathrm{KL}(\rho\,\|\,\mu) = \sup_{g>0}\; \mathbb{E}_{x\sim\rho}[\log g(x)] - \mathbb{E}_{y\sim\eta}\Big[\frac{\mu(y)}{\eta(y)}\, g(y)\Big] + 1.$$
Note that the optimal function $g^\star$ is equal to the ratio between the densities of $\rho$ and $\mu$.
Remark 2.
The Donsker-Varadhan formula
$$\mathrm{KL}(\rho\,\|\,\mu) = \sup_{h}\; \mathbb{E}_{x\sim\rho}[h(x)] - \log \mathbb{E}_{y\sim\mu}\big[e^{h(y)}\big]$$
is another variational representation of the KL divergence, and it is stronger than (17) because its objective upper-bounds that of (17) for any fixed $h$. However, we cannot obtain an unbiased estimate of its objective using samples.
a.2 Generalized entropy
The generalized entropy can also be represented as an $f$-divergence. In particular, let $f(u) = u^m$ and let $U$ be the uniform distribution on a set $\Omega$ that is a superset of the support of the density $\rho$ and has volume $|\Omega|$. Then
$$D_f(\rho\,\|\,U) = \int_\Omega \big(|\Omega|\,\rho(x)\big)^m\,\frac{1}{|\Omega|}\,dx = |\Omega|^{m-1}\int \rho(x)^m\,dx.$$
As a result, the generalized entropy can be expressed in terms of the $f$-divergence according to
$$\frac{1}{m-1}\int \rho(x)^m\,dx = \frac{|\Omega|^{1-m}}{m-1}\, D_f(\rho\,\|\,U).$$
Upon using the variational representation (11) of the $f$-divergence with the convex conjugate
$$f^*(t) = (m-1)\Big(\frac{t}{m}\Big)^{\frac{m}{m-1}},$$
the generalized entropy admits the following variational formulation:
$$\frac{1}{m-1}\int \rho^m\,dx = \frac{|\Omega|^{1-m}}{m-1}\,\sup_{h}\; \mathbb{E}_{x\sim\rho}[h(x)] - \mathbb{E}_{y\sim U}\Big[(m-1)\Big(\frac{h(y)}{m}\Big)^{\frac{m}{m-1}}\Big].$$
In practice, we find it numerically useful to let $h = m\, g^{m-1}$, so that
$$\frac{1}{m-1}\int \rho^m\,dx = \frac{|\Omega|^{1-m}}{m-1}\,\sup_{g>0}\; m\,\mathbb{E}_{x\sim\rho}\big[g(x)^{m-1}\big] - (m-1)\,\mathbb{E}_{y\sim U}\big[g(y)^m\big]. \tag{18}$$
With such a change of variables, the optimal function is $g^\star = |\Omega|\,\rho$.
Appendix B Evaluation of the density
In this section, we assume the solution process does not use the forward-backward scheme, i.e., all probability measures are obtained by performing JKO steps one by one. Otherwise, the map includes an expectation term, and pulling particles back to compute the density becomes intractable.
If every map $T_j$ is invertible, there exists a standard approach to evaluate the density of $\rho_k$ (AlvSchMro21; MokKorLiBur21) through the change-of-variables formula. More specifically, we assume $T_j = \nabla\varphi_j$ is parameterized by the gradient of an ICNN $\varphi_j$ that is strictly convex. To evaluate the density at a point $x$, we propagate backward through the sequence of maps to get
$$\rho_k(x) = \rho_0(z_0)\,\prod_{j=1}^{k}\big|\det \nabla T_j(z_{j-1})\big|^{-1}, \qquad z_{j-1} = T_j^{-1}\circ\cdots\circ T_k^{-1}(x).$$
The inverse map $T_j^{-1} = (\nabla\varphi_j)^{-1}$ can be obtained by solving the convex optimization
$$T_j^{-1}(y) = \operatorname*{arg\,min}_{x}\; \varphi_j(x) - \langle x, y\rangle. \tag{19}$$
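Problem (19) is smooth and strictly convex, so any first- or second-order method applies. A 1D sketch (with a hypothetical convex potential, not an actual ICNN) inverting $T = \varphi'$ with Newton's method:

```python
# Invert T = phi' for the convex potential phi(x) = x^4/4 + x^2/2,
# i.e., solve phi'(x) = x^3 + x = y. Minimizing phi(x) - x*y (cf. (19))
# with Newton's method uses phi'' = 3x^2 + 1 > 0, so steps are well defined.
def t_map(x):
    return x ** 3 + x

def t_inverse(y, iters=50):
    x = 0.0
    for _ in range(iters):
        x -= (t_map(x) - y) / (3 * x ** 2 + 1)   # Newton step on phi(x) - x*y
    return x

y = 10.0
x = t_inverse(y)
print(x, t_map(x))  # t_map(x) recovers y
```

Strict convexity of the potential is what guarantees a unique minimizer, i.e., a well-defined inverse map.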
Appendix C Additional experiment results and discussions
c.1 Sampling using ICNN parameterization
c.2 Aggregation equation
AlvSchMro21 proposes using the neural-network-based JKO, i.e., the backward method, to solve (21). They parameterize the map $T$ as the gradient of an ICNN. In this section, we use two examples to compare the forward and the backward methods when the energy consists of the interaction term only. This helps explain the performance difference between the FB and non-FB schemes discussed later in Section C.3.
We study the gradient flow associated with the aggregation equation
$$\partial_t\rho = \nabla\cdot\big(\rho\,\nabla(W*\rho)\big). \tag{21}$$
The forward method is
$$\rho_{k+1} = \big(\mathrm{Id} - \tau\,\nabla(W*\rho_k)\big)_\#\,\rho_k.$$
The backward method, or JKO, is
$$\rho_{k+1} = \operatorname*{arg\,min}_{\rho}\; \frac{1}{2}\iint W(x-y)\,\rho(x)\,\rho(y)\,dx\,dy + \frac{1}{2\tau}\, W_2^2(\rho, \rho_k).$$
Example 1
We follow the setting in CarCraWanWei21. The interaction kernel is quadratic and the initial measure is a Gaussian. In this case, $\nabla(W*\rho_k)(x)$ reduces to $x - \mathbb{E}_{y\sim\rho_k}[y]$. We use the same step size for both methods and show the results in Figure 7.
Example 2
We follow the setting in CarCraWanWei21. The interaction kernel contains a repulsive term near the origin, and the initial measure is concentrated around the origin. The unique steady state for this case is known in closed form.
The reader can refer to AlvSchMro21 for the backward method's performance. As for the forward method, because the kernel enforces repulsion near the origin and the initial measure is concentrated around the origin, $\nabla(W*\rho_k)$ easily blows up. So the forward method is not suitable for this kind of interaction kernel.
From the above two examples, we notice that when $W$ is smooth, the backward method converges faster but is less stable when solving (21). This sheds light on the FB and non-FB scheme performance in Sections 4.4 and C.3. However, if $W$ behaves badly near the origin, as in Example 2, the forward method loses its competitiveness.
c.3 Aggregation-diffusion equation with non-FB scheme
In Figure 8, we show the non-FB solutions to the aggregation-diffusion equation of Section 4.4. The FB scheme is in principle independent of the implementation of JKO, but in the following we assume FB and non-FB are both the neural-network-based methods discussed in Section 3. The non-FB scheme removes the forward step and integrates the interaction energy into a single JKO step, where the generalized entropy term is represented by the variational formula (14). We use the same step size and the same PDE parameters as in Section 4.4.
Comparing the FB scheme results in Figure 5 and the non-FB scheme results in Figure 8, we observe that the non-FB scheme converges more slowly than the finite difference method (CarCraWanWei21), and the FB scheme converges more slowly still. This may be because splitting one JKO step into the forward and backward steps removes the aggregation term's effect from the JKO objective, and the diffusion term is too weak to make a difference in the loss. Note that at the first several steps, both $\rho_k$ and $T_\#\rho_k$ are nearly the same uniform distribution, so the optimal dual function is nearly constant and exerts little effect in the variational formula of the entropy. Another possible reason is that a single forward step for the aggregation term converges more slowly than integrating aggregation into the backward step, as we discuss in Section C.2 and Figure 7.
However, FB generates more regular measures; the four pulses given by FB are more symmetric. We speculate this is because the gradient descent step in FB utilizes the geometric structure of the interaction energy directly, whereas integrating the interaction energy into the neural-network-based JKO loses this geometric meaning.
c.4 Computational time
Our experiments are conducted on a GeForce RTX 3090. The forward step (15) takes about 14 seconds to push forward one million points.
Assuming each JKO step involves 500 iterations with the default number of inner iterations, each JKO step (16) takes 100 seconds if the energy function contains the generalized entropy (5) and 25 seconds if the energy function contains the KL divergence (4).
Appendix D Implementation details
Our code is written in PyTorch Lightning (falcon2020framework). For some parts of the plotting in Sections 4.1 and 4.2, we adopt the code given by MokKorLiBur21.
Without further specification, we use the following parameters:
1) The number of iterations of the outer loop is 600.
2) The number of iterations of the inner loop is 3.
3) The batch size is fixed.
4) The learning rate is fixed.
5) All activation functions are set to PReLU.
6) The map $T$ has 4 layers and 16 neurons in each layer.
7) The dual function $g$ has 5 layers and 16 neurons in each layer.
The transport map $T$ can be parametrized in different ways. We use a residual MLP network for it in Sections 4.1, 4.2, and C.2, and the gradient of a strongly convex ICNN in Sections 4.3, 4.4, C.1, and C.3. The dual test function $g$ is always an MLP network with a dropout layer before each layer.
d.1 Sampling (Section 4.1 and c.1)
Two moons
We run the JKO scheme with the default number of inner iterations. The map $T$ has 6 layers, and the dual function $g$ has 5 layers.
GMM
The 8D example trains with $T$ having 6 layers and 64 neurons in each layer, and $g$ having 3 layers and 128 neurons in each layer.
The 13D example trains with $T$ having 8 layers and 64 neurons in each layer, and $g$ having 9 layers and 64 neurons in each layer.
d.2 Ornstein-Uhlenbeck Process (Section 4.2)
For Fokker-Planck JKO, we use the implementation provided by the authors and the default parameters given in MokKorLiBur21. We also estimate the SymKL using Monte Carlo according to the authors' instructions. It is straightforward to use an MC estimate for Fokker-Planck JKO because it has access to the density of the pushforward distributions. Our method cannot query the density of the estimated distribution exactly under our parameterization. Nevertheless, our estimate is guaranteed to be a Gaussian distribution, since our map is a linear transformation. Thus we use the density of the Gaussian with the empirical mean and empirical covariance computed from samples as a surrogate.
We use nearly all the same hyperparameters as MokKorLiBur21, including the learning rate, hidden layer width, and the number of iterations per JKO step. Specifically, we use a linear residual feedforward NN, i.e., one without activation functions, as the map $T$. $T$ and $g$ both have 3 layers and 64 hidden neurons per layer for all dimensions. We also train them for the same number of iterations per JKO step, with the same learning rate and batch size.
d.3 Porous media equation (Section 4.3)
In the experiment, $T$ and $g$ both have 10 neurons in each layer.
To account for the time-discretization error of the JKO scheme, we consider the porous media equation in 1D space and use the solution obtained via a finite difference method as a reference. More specifically, on an interval in 1D, we discretize the density over a fixed grid. With this discretization, the probability densities become (scaled) probability vectors and the problem (7) can be converted into the convex optimization
$$\min_{P \ge 0}\; \frac{1}{2\tau}\,\langle C, P\rangle + F\big(P^\top \mathbf{1}\big) \quad \text{subject to}\;\; P\,\mathbf{1} = p_k, \tag{22}$$
where $C$ is the matrix of discretized unit transport costs, $p_k$ is the probability vector at the previous step, $\mathbf{1}$ is the all-ones vector, and the optimization variable $P$ is the joint distribution between $p_k$ and $p_{k+1}$. This is a standard convex optimization problem and can be solved with generic solvers. When an optimal $P^\star$ is obtained, $p_{k+1}$ can be computed as $p_{k+1} = (P^\star)^\top\mathbf{1}$. We adopt the library CVXOPT (http://cvxr.com/cvx/) to solve the convex programming problem (22). In so doing, we arrive at a reference solution for our algorithm.
d.4 Aggregation-diffusion equation (Section 4.4 and C.3)
Each JKO step contains a fixed number of iterations with a fixed batch size.