TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics

02/09/2020 ∙ by Alexander Tong, et al. ∙ Yale University

It is increasingly common to encounter data from dynamic processes captured by static cross-sectional measurements over time, particularly in biomedical settings. Recent attempts to model individual trajectories from this data use optimal transport to create pairwise matchings between time points. However, these methods cannot model continuous dynamics and non-linear paths that entities can take in these systems. To address this issue, we establish a link between continuous normalizing flows and dynamic optimal transport that allows us to model the expected paths of points over time. Continuous normalizing flows are generally underconstrained, as they are allowed to take an arbitrary path from the source to the target distribution. We present TrajectoryNet, which controls the continuous paths taken between distributions. We show how this is particularly applicable for studying cellular dynamics in data from single-cell RNA sequencing (scRNA-seq) technologies, and that TrajectoryNet improves upon recently proposed static optimal transport-based models that can be used for interpolating cellular distributions.




1 Introduction

In data science we are often confronted with cross-sectional samples of time-varying phenomena, especially in biomedical data. Examples include health measurements of different age cohorts Oeppen and Vaupel (2002), or disease measurements at different stages of disease progression Waddington (1942). In these measurements we consider data that is sampled at multiple timepoints, but at each timepoint we have access only to a distribution (cross-section) of the population at that time. Extracting the longitudinal dynamics of development or disease from static snapshot measurements can be challenging as there are few methods of interpolation between distributions. Further exacerbating this problem is the fact that the same entities are often not measured at each time, resulting in a lack of point-to-point correspondences. Here, we propose to formulate this problem as one of unbalanced dynamic transport, where the goal is to transport entities from one cross-sectional measurement to the next using efficient and smooth paths. Our main contribution is to establish a link between continuous normalizing flows (CNF) Grathwohl et al. (2019) and dynamic optimal transport Benamou and Brenier (2000), allowing us to efficiently solve the transport problem using a Neural ODE framework Chen et al. (2018). To our knowledge, TrajectoryNet is the first method to consider the specific paths taken by a CNF between distributions.

We add penalties to the continuous normalizing flow such that it performs unbalanced dynamic optimal transport between discrete timepoints with paths optionally following data density, or known directions. Our penalties include:

  1. An energy regularization encouraging lower energy paths, as measured by kinetic energy over Euclidean space, which we show results in dynamic optimal transport Benamou and Brenier (2000).

  2. A growth rate regularization that accommodates unbalanced transport.

  3. A density-based penalty which encourages interpolations that lie on dense regions of the data. Often data lies on a low-dimensional manifold, and it is desirable for paths to follow this manifold at the cost of higher energy.

  4. A velocity regularization where we enforce local estimates of velocity at measured datapoints to match the first time derivative of cell state change.

Although widely applicable, we focus our results on the problem of understanding cellular differentiation from snapshot scRNA-seq measurements.

Single-cell RNA sequencing Macosko et al. (2015) is a relatively new technology that has made it possible for scientists to randomly sample the entire transcriptome, i.e., 20-30 thousand species of mRNA molecules representing transcribed genes of the cell. This technology can reveal very precise information about the identity of individual cells based on transcription factors, surface marker expression, cell cycle and many other facets of cellular behavior. In particular, this technology can be used to learn how cells differentiate from one state to another: for example, from embryonic stem cells to specified lineages such as neuronal or cardiac. However, hampering this understanding is the fact that scRNA-seq only offers static snapshots of data, since all cells are destroyed upon measurement. Thus it is impossible to monitor how an individual cell changes over time. Moreover, due to the expensive nature of this technology, generally only a handful of discrete timepoints are collected in measuring any transition process. TrajectoryNet is especially well suited to this data modality. Existing methods attempt to infer a trajectory within one timepoint Saelens et al. (2019); La Manno et al. (2018), or interpolate linearly between two timepoints Yang and Uhler (2019); Schiebinger et al. (2019), but TrajectoryNet can interpolate non-linearly using information from more than two timepoints. TrajectoryNet has advantages over existing methods in that it:

  1. can interpolate by following the manifold of observed entities between measured timepoints, thereby solving the static-snapshot problem,

  2. can create time-continuous trajectories of individual entities, giving researchers the ability to follow an entity in time,

  3. forms a deep representational model of system dynamics, which can then be used to understand drivers of dynamics (gene logic in the cellular context), via perturbation of this deep model.

While our experiments apply this work specifically to cellular dynamics, these penalties can be used in many other situations where we would like to model dynamics based on cross-sectional population level data.

Figure 1: TrajectoryNet uses a Neural ODE to learn the derivative of the dynamics function. To find the output at time $t_1$ for a given input at time $t_0$, we integrate the learned derivative from $t_0$ to $t_1$, letting the ODE solver choose the integration timepoints.

2 Background and Related Work

Optimal Transport.

Introduced originally by Monge (1781) and in modern form by Kantorovich (1942), the linear program formulation of static optimal transport (OT) has the relatively high cost of $O(n^3)$ for discrete measures on $n$ points. Recently, there have been a number of fast approximations using entropic regularization. Cuturi (2013) presented a parallel algorithm for the discrete case as an application of Sinkhorn’s algorithm Sinkhorn (1964). Recent efforts approximate OT on subspaces Muzellec and Cuturi (2019) or even on a single dimension Kolouri et al. (2019). These efforts emphasize the importance to the field of obtaining fast OT algorithms. Another direction that has recently received increased attention is unbalanced optimal transport, where the goal is to relax the problem to add and remove mass Benamou (2003); Chizat et al. (2018); Liero et al. (2018); Schiebinger et al. (2019). While many efficient static optimal transport algorithms exist, including recent ones for the unbalanced case Yang and Uhler (2019), much less attention has focused on dynamic optimal transport, the focus of this work.
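To make the entropic approximation concrete, the following is a minimal NumPy sketch of Sinkhorn iterations between two small point clouds. It is only an illustration of the idea attributed to Cuturi (2013) above, not the implementation used in this work; the function name `sinkhorn`, the regularization strength `reg`, the iteration count, and the toy data are all assumptions made for the example.

```python
import numpy as np

def sinkhorn(a, b, cost, reg=0.5, n_iters=200):
    """Entropy-regularized transport plan between histograms a and b."""
    K = np.exp(-cost / reg)            # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)              # rescale so column sums match b
        u = a / (K @ v)                # rescale so row sums match a
    return u[:, None] * K * v[None, :]

# Toy data (hypothetical): two small 2D point clouds with uniform weights.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
y = rng.normal(loc=1.0, size=(6, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
cost = cost / cost.max()               # normalize the cost for numerical stability
a = np.full(5, 1.0 / 5)
b = np.full(6, 1.0 / 6)

plan = sinkhorn(a, b, cost)
transport_cost = float((plan * cost).sum())
```

Each iteration only needs matrix-vector products with the kernel, which is why the method parallelizes well; the price is that the returned plan is a smoothed version of the exact linear-program solution.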

Dynamic Optimal Transport.

Another formulation of optimal transport is known as dynamic optimal transport. Benamou and Brenier (2000) showed how the addition of a natural time interpolation variable gives an alternative interpretation with links to fluid dynamics that surprisingly leads to a convex optimization problem. However, while solvers for the discretized dynamic OT problem are effective in low dimensions and for small problems, they require a discretization of space into grids, giving cost exponential in the dimension (see Peyré and Cuturi (2019), Chap. 7, for a good overview of this problem). One of our main contributions is to provide an approximate solver for high dimensional smooth problems using a neural network.

Single-cell Trajectories from a Static Snapshot.

Temporal interpolation in single-cell data started with solutions that attempt to infer an axis within a single time point of data, called cell “pseudotime” and used as a proxy for developmental progression, using known markers of development and the asynchronous nature of cell development Trapnell et al. (2014); Bendall et al. (2014). An extensive comparison of 45 methods for this type of analysis gives method recommendations based on prior assumptions on the general structure of the data Saelens et al. (2019). However, these methods can be biased and fail in a number of circumstances Weinreb et al. (2018); Lederer and La Manno (2020), and they do not take into account experimental time.

Matching Populations from Multiple Time Points.

Recent methods get around some of these challenges using multiple timepoints Hashimoto et al. (2016); Schiebinger et al. (2019); Yang and Uhler (2019). However, these methods generally resort to matching populations between coarse-grained timepoints, but do not give much insight into how they move between measured timepoints. Often paths are assumed to minimize total Euclidean cost, which is not realistic in this setting. In contrast, the methods that estimate dynamics from single timepoints La Manno et al. (2018); Bergen et al. (2019); Erhard et al. (2019); Hendriks et al. (2019) have the potential to give relatively accurate estimation of local direction, but cannot give accurate global estimation of distributional shift.

With TrajectoryNet, we aim to unite these two approaches into a single model that infers continuous-time trajectories globally from multiple timepoints while respecting local dynamics within a single timepoint.

3 Preliminaries

We provide an overview of static optimal transport, dynamic optimal transport Benamou and Brenier (2000), and continuous normalizing flows.

3.1 The Monge-Kantorovich Problem

We adopt notation from the standard text Villani (2008). For two probability measures $\mu, \nu$ defined on $\mathcal{X} \subset \mathbb{R}^d$, let $\Pi(\mu, \nu)$ denote the set of all joint probability measures on $\mathcal{X} \times \mathcal{X}$ whose marginals are $\mu$ and $\nu$. Then the $p$-Kantorovich distance (or Wasserstein distance of order $p$) between $\mu$ and $\nu$ is

$$W_p(\mu, \nu) = \left( \inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^p \, d\pi(x, y) \right)^{1/p} \tag{1}$$

where $p \in [1, \infty)$. This formulation has led to many useful interpretations both in GANs and biological networks. For the entropy regularized problem, the Sinkhorn algorithm Sinkhorn (1964) provides a fast and parallelizable numerical solution in the discrete case. Recent work tackles computationally efficient solutions to the exact problem Jambulapati et al. (2019) for the discrete case. However, for the continuous case, solutions to the discrete problem do not scale well in high dimensional spaces, as the rate of convergence of the Wasserstein metric between empirical measures $\mu_n$ and $\nu_n$ with bounded support is shown in Dudley (1969) to be

$$\mathbb{E}\left[ \, \left| W_p(\mu_n, \nu_n) - W_p(\mu, \nu) \right| \, \right] \lesssim n^{-1/d} \tag{2}$$

where $d$ is the ambient dimension and $n$ is the number of samples. However, recent work shows that in high dimensions a more careful treatment reveals that the rate depends on the intrinsic dimension, not the ambient dimension Weed and Bach (2019). As long as the data lies on a low dimensional manifold in the ambient space, we can reasonably approximate the Wasserstein distance. In this work we approximate the support of this manifold using a neural network.
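As a small worked instance of the Wasserstein distance between empirical measures, the sketch below handles the one-dimensional case with equal sample sizes and uniform weights, where the optimal coupling simply matches sorted samples. This special case is chosen only because it has a closed form; the function name `wasserstein_1d` is hypothetical, and this is not the estimator used in the experiments.

```python
import numpy as np

def wasserstein_1d(x, y, p=2):
    """W_p between two equal-size 1D empirical measures with uniform weights.

    In one dimension the optimal coupling for a convex cost is monotone,
    so it pairs the i-th smallest point of x with the i-th smallest of y.
    """
    x_sorted = np.sort(np.asarray(x, dtype=float))
    y_sorted = np.sort(np.asarray(y, dtype=float))
    return float(np.mean(np.abs(x_sorted - y_sorted) ** p) ** (1.0 / p))

# Shifting every sample by a constant c moves the measure a distance c.
samples = np.array([0.3, -1.2, 2.0, 0.7])
shifted_distance = wasserstein_1d(samples, samples + 3.0)
```

In higher dimensions no such sorting shortcut exists, which is where the linear program, entropic, or dynamic formulations discussed in this paper come in.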

3.2 Dynamic Optimal Transport

Benamou and Brenier (2000) defined and explored a dynamic version of Kantorovich distance. Their work linked optimal transport distances with dynamics and partial differential equations (PDEs). For a fixed time interval $[0, T]$ with smooth enough, time dependent density and velocity fields, $\rho(t, x) \geq 0$ and $v(t, x) \in \mathbb{R}^d$, subject to the continuity equation

$$\partial_t \rho + \nabla \cdot (\rho v) = 0 \tag{3}$$

for $0 < t < T$ and $x \in \mathbb{R}^d$, and the conditions

$$\rho(0, \cdot) = \mu, \qquad \rho(T, \cdot) = \nu \tag{4}$$

we can relate the squared $L^2$ Wasserstein distance to $(\rho, v)$ in the following way

$$W_2(\mu, \nu)^2 = \inf_{(\rho, v)} \; T \int_{\mathbb{R}^d} \int_0^T \rho(t, x) \, |v(t, x)|^2 \, dt \, dx \tag{5}$$

In other words, a velocity field $v(t, x)$ with minimum $L^2$ norm that transports mass at $\mu$ to mass at $\nu$ when integrated over the time interval is the optimal plan for an $L^2$ Wasserstein distance. The continuity equation is applied over all points of the field and asserts that no point is a source or sink for mass. The solution to this flow can be shown to follow a pressureless flow on a time-dependent potential function. Mass moves with constant velocity that linearly interpolates between the initial and final measures. For problems where interpolation of the density between two known densities is of interest, this formulation is very attractive. Existing computational methods for solving the dynamic formulation for continuous measures approximate the flow using a discretization of space-time Papadakis et al. (2014). This works well in low dimensions, but scales poorly to high dimensions as the complexity is exponential in the input dimension $d$. We next give background on continuous normalizing flows, which we show can provide a solution with computational complexity polynomial in $d$.
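The claim that constant-velocity straight paths minimize the kinetic energy can be checked numerically for a single unit of mass. The sketch below, a toy illustration with hypothetical names (`path_energy`, `straight`, `arc`), compares a straight path with a semicircular detour between the same endpoints, using a Riemann sum for the time integral.

```python
import numpy as np

def path_energy(path, T=1.0):
    """T * integral over [0, T] of |v(t)|^2 dt for a path sampled at uniform times."""
    dt = T / (len(path) - 1)
    v = np.diff(path, axis=0) / dt                 # finite-difference velocity
    return float(T * np.sum(np.sum(v ** 2, axis=1) * dt))

s = np.linspace(0.0, 1.0, 2001)[:, None]           # normalized time
x0, x1 = np.array([0.0, 0.0]), np.array([1.0, 0.0])

# Straight path at constant speed: energy equals the squared distance |x1 - x0|^2.
straight = x0 + s * (x1 - x0)

# Semicircular detour of radius 0.5 between the same endpoints, also at constant speed.
arc = np.hstack([0.5 - 0.5 * np.cos(np.pi * s), 0.5 * np.sin(np.pi * s)])

energy_straight = path_energy(straight)            # close to 1
energy_arc = path_energy(arc)                      # close to (pi / 2) ** 2
```

The detour is longer by a factor of $\pi/2$, and because the energy is quadratic in speed, its cost grows by the square of that factor. This is the quantity that an energy penalty on a learned flow would push down.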

3.3 Continuous normalizing flows

A normalizing flow Rezende and Mohamed (2015) transforms a parametric (usually simple) distribution to a more complicated one. Using an invertible transformation $f$ applied to an initial latent random variable $z$ with density $p_z$, we define $x = f(z)$ as the output of the flow. Then by the change of variables formula, we can compute the density of the output $x$:

$$\log p_x(x) = \log p_z(z) - \log \left| \det \frac{\partial f}{\partial z} \right| \tag{6}$$

A large effort has gone into creating architectures where the log determinant of the Jacobian is efficient to compute Rezende and Mohamed (2015); Kingma et al. (2016); Papamakarios et al. (2017).
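A one-dimensional affine flow is perhaps the simplest case in which the change of variables formula can be verified by hand. The sketch below assumes a standard normal latent variable and the map $x = a z + b$, for which the output is known to be Gaussian with mean $b$ and standard deviation $|a|$; the function names are illustrative only.

```python
import numpy as np

def std_normal_logpdf(z):
    """Log-density of the standard normal distribution."""
    return -0.5 * z ** 2 - 0.5 * np.log(2.0 * np.pi)

def affine_flow_logpdf(x, a, b):
    """log p_x(x) for x = f(z) = a * z + b with z ~ N(0, 1)."""
    z = (x - b) / a                                   # invert the flow
    return std_normal_logpdf(z) - np.log(np.abs(a))   # subtract log |det df/dz|

def normal_logpdf(x, mean, std):
    """Closed-form Gaussian log-density, used only as a reference."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

x = np.linspace(-4.0, 6.0, 11)
flow_logp = affine_flow_logpdf(x, a=2.0, b=1.0)
reference_logp = normal_logpdf(x, mean=1.0, std=2.0)
```

Here the Jacobian is the scalar $a$, so the log-determinant is trivial; the architectures cited above exist because this term is expensive for general high dimensional maps.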

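Now consider a continuous-time transformation, where the derivative of the transformation is parameterized by $\theta$, thus at any timepoint $t$, $\frac{\partial x(t)}{\partial t} = f_\theta(x(t), t)$. At the initial time $t_0$, $x(t_0)$ is drawn from a distribution $p(x, t_0)$, which we also denote $p_{t_0}(x)$ for clarity, and it is continuously transformed to $x(t_1)$ by following the differential equation $f_\theta(x(t), t)$:

$$x(t_1) = x(t_0) + \int_{t_0}^{t_1} f_\theta(x(t), t) \, dt, \qquad x(t_0) \sim p_{t_0}(x)$$

$$\log p_{t_1}(x(t_1)) = \log p_{t_0}(x(t_0)) - \int_{t_0}^{t_1} \mathrm{Tr}\left( \frac{\partial f_\theta}{\partial x(t)} \right) dt \tag{7}$$

where at any time $t$ the $x(t)$ associated with every $x(t_1)$ through the flow can be found by following the inverse flow. This model is referred to as a continuous normalizing flow (CNF). It can be likened to the dynamic version of optimal transport, where we model the measure over time rather than the mapping from $\mu$ to $\nu$.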

Unsurprisingly, there is a deep connection between CNFs and dynamic optimal transport. In the next section we exploit this connection and show how CNFs can be used to provide a high dimensional solution to the dynamic optimal transport problem with TrajectoryNet.
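The way a CNF tracks density along a path can be sketched with a fixed-step Euler integrator. The example below assumes the instantaneous change of variables of Chen et al. (2018), in which the log-density changes at rate minus the trace of the Jacobian of the dynamics, and uses a hand-written linear field $f(x, t) = a x$ so that the answer can be compared with a closed form. The integrator, the field, and all names are assumptions for illustration; the actual model uses a learned network and an adaptive solver.

```python
import numpy as np

def integrate_cnf(f, trace_jac, x0, logp0, t0, t1, n_steps=1000):
    """Euler integration of the state and its log-density along the flow."""
    x, logp, t = x0.copy(), logp0.copy(), t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        logp = logp - trace_jac(x, t) * dt     # d log p / dt = -Tr(df/dx)
        x = x + f(x, t) * dt                   # dx / dt = f(x, t)
        t += dt
    return x, logp

a, d = 0.5, 2                                  # expansion rate and dimension (toy values)
f = lambda x, t: a * x                         # linear dynamics, Jacobian a * I
trace_jac = lambda x, t: a * d                 # trace of a * I in d dimensions

rng = np.random.default_rng(0)
x0 = rng.normal(size=(100, d))                 # samples from a standard Gaussian
logp0 = -0.5 * np.sum(x0 ** 2, axis=1) - 0.5 * d * np.log(2.0 * np.pi)

x1, logp1 = integrate_cnf(f, trace_jac, x0, logp0, 0.0, 1.0)

# Closed form: x(1) = exp(a) * x(0), so x(1) ~ N(0, exp(2a) * I).
exact_logp1 = (-0.5 * np.sum(x1 ** 2, axis=1) / np.exp(2 * a)
               - d * a - 0.5 * d * np.log(2.0 * np.pi))
```

No determinant is ever formed: only the trace of the Jacobian is needed, which is the property that makes this approach attractive in high dimensions.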

Figure 2: Transporting a Gaussian (a) to an S-curve (b) via (c) static optimal transport, (d) Base TrajectoryNet (e) TrajectoryNet with energy regularization. TrajectoryNet can follow density or direct paths.

4 TrajectoryNet: Efficient Dynamic Optimal Transport

In this section, we describe our TrajectoryNet model for general usage (Figure 1). We first adapt the optimization of dynamic optimal transport to that of continuous normalizing flows, then show how we use this to optimize over flows consisting of multiple timepoints.

Architecture and Training.

The neural network architecture of TrajectoryNet consists of three fully connected layers of 64 nodes with leaky ReLU activations. It takes as input a cell state and time and outputs the derivative with respect to time at that point. This is represented by $f_\theta(x, t)$. For a dataset measured at $k$ timepoints $t_1, \dots, t_k$, we denote the measured observations from each timepoint as $X_{t_1}, \dots, X_{t_k}$. To train a continuous normalizing flow we need access to the density function of the source distribution. Since this is not accessible for an empirical distribution, we use an additional Gaussian at $t_0$, defining $p_{t_0}(x) = \mathcal{N}(0, I)$, the standard Gaussian distribution, where $p_t(x)$ is the density function at time $t$.
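A forward pass of such a dynamics network can be sketched in plain NumPy. The example below assumes that the time is appended to the cell state as one extra input feature and that the three 64-node layers are hidden layers followed by a linear output layer; neither detail is spelled out in the text, and the initialization, the leaky ReLU slope, and the names `init_params` and `dynamics` are likewise hypothetical.

```python
import numpy as np

def leaky_relu(h, slope=0.01):
    return np.where(h > 0, h, slope * h)

def init_params(dim, hidden=64, n_hidden=3, seed=0):
    """Random weights for a (dim + 1) -> 64 -> 64 -> 64 -> dim network."""
    rng = np.random.default_rng(seed)
    sizes = [dim + 1] + [hidden] * n_hidden + [dim]
    return [(rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def dynamics(params, x, t):
    """Time derivative of each cell state in the batch x at scalar time t."""
    h = np.concatenate([x, np.full((x.shape[0], 1), t)], axis=1)  # append time
    for W, b in params[:-1]:
        h = leaky_relu(h @ W + b)
    W, b = params[-1]
    return h @ W + b                               # linear output, same shape as x

params = init_params(dim=5)
x = np.random.default_rng(1).normal(size=(8, 5))   # batch of 8 cells in 5 dimensions
dxdt = dynamics(params, x, t=0.5)
```

Because the output has the same dimension as the input, the function can be handed directly to an ODE solver as the right-hand side of the flow.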

For a training step we draw samples $x_{t_i} \sim X_{t_i}$ for $i \in \{1, \dots, k\}$. Our full loss is

$$L = -\sum_{i=1}^{k} \log p_{t_i}(x_{t_i}) \tag{8}$$

While there are a number of ways to computationally approximate these quantities, we use a parallel method to iteratively calculate $\log p_{t_i}$ based on $\log p_{t_{i-1}}$. To make a backward pass through all timepoints we start at the final timepoint, integrate the batch to the second to last timepoint, concatenate these points to the samples from the second to last timepoint, and continue until $t_0$, where the density is known for each sample. We note that this can compound the error, especially for later timepoints if $k$ is large or if the learned system is stiff, but gives significant speedup during training.

To sample from $p_{t_i}$ we first sample $z \sim p_{t_0}$, then use the adjoint method to perform the integration $x_{t_i} = z + \int_{t_0}^{t_i} f_\theta(x(t), t) \, dt$.

Approximation of Dynamic OT.

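In TrajectoryNet we consider a maximum likelihood model. The well-known equivalence between maximum likelihood and KL divergence allows us to consider a Kullback-Leibler (KL) divergence penalty on measured timepoints Murphy (2013). Our TrajectoryNet loss in eq. (8) can therefore alternatively be expressed as a KL divergence between the data distribution and the model density at each measured timepoint.

From a numerical standpoint, it is difficult to satisfy the constraint in eq. (4). We relax this constraint to a KL divergence penalty with weight $\lambda$, and the dynamic optimal transport equation then becomes

$$\tilde{W}(\mu, \nu)^2 = \inf_{(\rho, v)} \; T \int_{\mathbb{R}^d} \int_0^T \rho(t, x) \, |v(t, x)|^2 \, dt \, dx + \lambda \, D_{\mathrm{KL}}\big( \rho(T, \cdot) \,\|\, \nu \big)$$

For sufficiently large $\lambda$ under constraint (3) this converges to the optimal solution in (5). This is encapsulated in the following theorem.

Theorem 4.1.

The squared $L^2$ Wasserstein distance is equal to the infimum of

$$T \int_{\mathbb{R}^d} \int_0^T \rho(t, x) \, |v(t, x)|^2 \, dt \, dx + \lambda \, D_{\mathrm{KL}}\big( \rho(T, \cdot) \,\|\, \nu \big)$$

among all $(\rho, v)$ subject to (3) and $\rho(0, \cdot) = \mu$, for a sufficiently large $\lambda$.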

See Appendix A.1 for full proof.

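We now show how continuous normalizing flows can be used to solve this optimization. As noted by Chen et al. (2018), it follows from the Liouville equation Hilbert (1902) that for a continuous normalizing flow eq. (3) holds with $v = f_\theta$, i.e. $\partial_t \rho + \nabla \cdot (\rho f_\theta) = 0$ for all $t$. We can then rewrite the relaxed dynamic optimal transport objective as an expectation over the paths of the flow:

$$\inf_{\theta} \; T \, \mathbb{E}_{x(0) \sim \mu} \left[ \int_0^T \| f_\theta(x(t), t) \|^2 \, dt \right] + \lambda \, D_{\mathrm{KL}}\big( \rho(T, \cdot) \,\|\, \nu \big)$$

We can approximate the first part of this equation using a Riemann sum over the integration timepoints. This requires a forward integration through TrajectoryNet to compute. If we consider the case where the model density matches the data well, then we can combine this with the backward integration for very little added computation. Using the maximum likelihood and KL divergence equivalence, we obtain a loss

$$L = -\sum_{i=1}^{k} \log p_{t_i}(x_{t_i}) + \lambda_e \int_{t_0}^{t_k} \| f_\theta(x(t), t) \|^2 \, dt$$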

where the integral above is computed using a Riemann sum. In practice, either a penalty on the Jacobian or additional training noise helped to get straight paths with a lower energy regularization weight $\lambda_e$. We found that a value of $\lambda_e$ large enough to encourage straight paths unsurprisingly also shortens the paths, undershooting the target distribution. Adding these additional penalties encourages locally consistent paths, as in denoising and contractive autoencoders Vincent et al. (2010); Rifai et al. (2011). Since $f_\theta$ represents the derivative of the path, this discourages paths with high local curvature. Our energy loss is then

$$L_{\text{energy}} = \lambda_e \int_{t_0}^{t_k} \| f_\theta(x(t), t) \|^2 \, dt + \lambda_j \, \| J_{f_\theta} \|_F^2$$

where $\| J_{f_\theta} \|_F$ is the Frobenius norm of the Jacobian of $f_\theta$. Figure 2 shows the effect of this regularization. Without energy regularization, TrajectoryNet tends to have paths that follow the data density. However, with energy regularization we approach the paths of the optimal map. The TrajectoryNet solution is biased towards undershooting the target distribution. Our energy loss gives control over how much to penalize indirect, high energy paths.
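The two ingredients of the energy loss, a Riemann sum of the squared speed along a path and a Frobenius-norm penalty on the Jacobian of the dynamics, can be sketched for a single trajectory as follows. This is a toy version under several assumptions: fixed Euler steps instead of solver-chosen timepoints, a finite-difference Jacobian instead of automatic differentiation, and a Jacobian penalty accumulated along the path. The names `jacobian_fd`, `energy_loss`, `lam_e`, and `lam_j` are hypothetical.

```python
import numpy as np

def jacobian_fd(f, x, t, eps=1e-5):
    """Central finite-difference Jacobian of f(., t) at the point x."""
    d = x.shape[0]
    J = np.zeros((d, d))
    for j in range(d):
        e = np.zeros(d)
        e[j] = eps
        J[:, j] = (f(x + e, t) - f(x - e, t)) / (2.0 * eps)
    return J

def energy_loss(f, x0, t0, t1, lam_e=1.0, lam_j=1.0, n_steps=100):
    """Riemann-sum kinetic energy plus squared Frobenius norm of the Jacobian."""
    x, t = x0.copy(), t0
    dt = (t1 - t0) / n_steps
    kinetic, jac = 0.0, 0.0
    for _ in range(n_steps):
        v = f(x, t)
        kinetic += np.sum(v ** 2) * dt                     # |f|^2 dt
        jac += np.sum(jacobian_fd(f, x, t) ** 2) * dt      # ||J||_F^2 dt
        x = x + v * dt                                     # Euler step along the flow
        t += dt
    return lam_e * kinetic + lam_j * jac

# Constant field: straight path, zero Jacobian, energy |c|^2 * (t1 - t0) = 5.
constant = lambda x, t: np.array([1.0, 2.0])
loss_constant = energy_loss(constant, np.zeros(2), 0.0, 1.0)

# Rotation field: the Jacobian is the matrix A itself, with ||A||_F^2 = 2.
A = np.array([[0.0, -1.0], [1.0, 0.0]])
rotation = lambda x, t: A @ x
loss_rotation_jac = energy_loss(rotation, np.array([1.0, 0.0]), 0.0, 1.0, lam_e=0.0)
```

The constant field incurs no Jacobian penalty because its paths are straight, while the rotation field is penalized even where its speed is small, which matches the stated goal of discouraging high local curvature.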

OT extended to a series of distributions over time.

Optimal transport is traditionally performed between a source and target distribution. Extensions to a series of distributions are normally done by performing optimal transport between successive pairs of distributions, as in Schiebinger et al. (2019). This creates flows which have discontinuities at the sampled times, which may be undesirable when the underlying system is smooth in time, as in biological systems. The dynamic model approximates dynamic OT for two timepoints, but by using a single smooth function to model the whole series the flow becomes the minimal cost smooth flow over time.

Unbalanced optimal transport.

We use a simple and computationally efficient method that adapts discrete static unbalanced optimal transport to our framework in the continuous setting. This is a necessary extension, but is by no means a focus of our work. While we could also apply an adversarial framework, we choose to avoid the instabilities of adversarial training and use a simple network trained from the solution to the discrete problem. We train a network that takes as input a cell state and time pair and produces a growth rate of a cell at that time. This is trained to match the result from discrete optimal transport. For further specification see Appendix B. We then fix the weights of this network and modify the way we integrate mass over time to account for this growth rate.


We note that adding growth rate regularization in this way does not guarantee conservation of mass. We could normalize the resulting mass to be a probability distribution during training; however, this would require an integration over the whole state space, which is too computationally costly. Instead we use the equivalence of the maximum likelihood formulation over a fixed growth function and normalize after the network is trained.

Encoding the density assumption.

Methods that display or perform computations on the cellular manifold often include an implicit or explicit way of normalizing for density and modeling data geometry. PHATE Moon et al. (2019) uses scatterplots and adaptive kernels to display the geometry. SUGAR Lindenbaum et al. (2018) explicitly models the data geometry. We would like to constrain our flows to the manifold geometry but not to its density. We penalize the flow such that it is always close to at least a few measured points across all timepoints:

$$L_{\text{density}}(x, t) = \sum_{j=1}^{k} \max\big( 0, \; d_j(x) - h \big)$$

where $d_j(x)$ is the Euclidean distance from $x$ to its $j$-th nearest measured point. This can be thought of as a loss that penalizes points until they are within Euclidean distance $h$ of their $k$ nearest neighbors. We use fixed values of $h$ and $k$ in all of our experiments. We evaluate $L_{\text{density}}$ on an interpolated time every batch.
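The density penalty described above can be sketched as a hinge on the distances to the nearest measured points. The example below uses brute-force pairwise distances, which is only sensible for small data, and the default values `k=5` and `h=0.1` are illustrative assumptions rather than settings confirmed by the text; the function name `density_penalty` is also hypothetical.

```python
import numpy as np

def density_penalty(x, data, k=5, h=0.1):
    """Hinge penalty on the distance from each row of x to its k nearest data points.

    Points already within distance h of all k nearest neighbors incur no penalty.
    """
    dists = np.linalg.norm(x[:, None, :] - data[None, :, :], axis=-1)  # (n, m)
    knn = np.sort(dists, axis=1)[:, :k]            # k smallest distances per query
    return np.maximum(0.0, knn - h).sum(axis=1)

# Toy data: five measured points at the origin and two interpolated query points.
data = np.zeros((5, 2))
queries = np.array([[0.0, 0.0],     # sits on the data: no penalty
                    [1.0, 0.0]])    # one unit away from all five neighbors
penalty = density_penalty(queries, data)
```

Because the hinge vanishes inside the radius `h`, the penalty pulls interpolated points toward the support of the data without rewarding them for crowding into its densest regions.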

Encoding Velocity.

Often it is easy to measure the direction of change over a short time horizon, even when such measurements do not have good predictive power at the scale of the measured timesteps. In health data, we can often collect data from a few visits over a short time horizon, estimating the direction of a single patient in the near future. In single-cell data, RNA-velocity La Manno et al. (2018); Bergen et al. (2019) provides an estimate at every measured cell. We use these measurements to regularize the direction of flow at every measured point. Our regularization requires periodically evaluating $f_\theta$ at every measured cell and adding the regularization:

$$L_{\text{velocity}}(x, t, \hat{v}) = -\,\text{cosine-similarity}\big( f_\theta(x, t), \hat{v} \big)$$

where $\hat{v}$ is the local velocity estimate at $x$. This encourages the direction of the flow at a measured point to be similar to the direction of local velocity. This ignores the magnitude of the estimate, and only heeds the direction. While RNA-velocity provides some estimate of relative speed, the vector length is considered not as informative, as it is unclear how to normalize these vectors in a system specific way La Manno et al. (2018); Bergen et al. (2019). We note that while current estimates of velocity can only estimate direction, this does not preclude future methods that can give accurate magnitude estimates. $L_{\text{velocity}}$ can easily be adapted to take magnitudes into account by considering a magnitude-sensitive similarity, for instance.
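A direction-only velocity penalty of the kind described above can be sketched with cosine similarity. The version below returns one minus the cosine similarity so that perfectly aligned vectors score zero; this offset, the small `eps` guard against zero-length vectors, and the name `velocity_penalty` are choices made for the example rather than details taken from the text.

```python
import numpy as np

def velocity_penalty(flow_dirs, velocity_estimates, eps=1e-12):
    """One minus cosine similarity between model flow and estimated velocity, per cell."""
    dot = np.sum(flow_dirs * velocity_estimates, axis=1)
    norms = (np.linalg.norm(flow_dirs, axis=1)
             * np.linalg.norm(velocity_estimates, axis=1))
    return 1.0 - dot / (norms + eps)

flow = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
estimates = np.array([[3.0, 0.0],     # same direction, different length
                      [0.0, 2.0],     # orthogonal
                      [-0.5, 0.0]])   # opposite direction
penalty = velocity_penalty(flow, estimates)
```

The first row shows that rescaling the velocity estimate leaves the penalty unchanged, which reflects the point made above that only the direction of RNA-velocity is trusted.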

5 Experiments

All experiments were performed with the TrajectoryNet framework with a network consisting of three layers with LeakyReLU activations. Optimization was performed on 10,000 iterations of batches of size 1,000 using the dopri5 solver Dormand and Prince (1980) with both absolute and relative tolerances set to the same fixed value, and the ADAM optimizer Kingma and Ba (2014) with learning rate 0.001 and weight decay as in Grathwohl et al. (2019). We evaluate using three TrajectoryNet models with different regularization terms. The Base model refers to a standard normalizing flow. +E adds energy regularization, +D adds density regularization, +V adds velocity regularization, and +G adds growth rate regularization.

Comparison to Existing Methods.

Since there are no ground truth methods to calculate the trajectory of a single cell, we evaluate our model using interpolation of held-out timepoints. We leave out an intermediary timepoint and measure the Kantorovich distance, also known as the earth mover’s distance (EMD), between the predicted and held-out distributions. For EMD, lower is more accurate. We compare the distribution interpolated by TrajectoryNet with four other distributions: the previous timepoint, the next timepoint, a random timepoint, and the McCann interpolant in the discrete OT solution as used in Schiebinger et al. (2019).

Figure 3: Density regularization or velocity regularization can be used to follow a 1D manifold in 2D.
Figure 4: A 1D distribution of data over time embedded in two dimensions along a smooth manifold. On a single branch (left), with a tree structure (center), and circle (right).
Method | EMD Arch | EMD Cycle | EMD Tree | MSE Arch | MSE Cycle | MSE Tree
Base | 0.691 | 0.037 | 0.490 | 0.300 | 0.190 | 0.218
Base + D | 0.607 | 0.049 | 0.373 | 0.236 | 0.191 | 0.145
Base + V | 0.243 | 0.033 | 0.143 | 0.107 | 0.068 | 0.098
Base + D + V | 0.415 | 0.034 | 0.252 | 0.156 | 0.081 | 0.132
OT | 0.644 | 0.032 | 0.492 | 0.252 | 0.192 | 0.196
prev | 1.086 | 0.035 | 1.092 | 0.652 | 0.192 | 0.666
next | 1.090 | 0.035 | 1.068 | 0.659 | 0.192 | 0.689
rand | 0.622 | 0.406 | 0.420 | 0.243 | 0.346 | 0.161
Table 1: Shows the Wasserstein distance (EMD) and MSE between the left out timepoint and the predicted points for our generated artificial datasets. Mean over 3 seeds.
Figure 5: Cell growth model learned on Embryoid Body Data Moon et al. (2019)

5.1 Artificial Data

For artificial data where we have known paths, we can measure the mean squared error (MSE) of the paths predicted by the model from the first timepoint. Here we leave out the middle timepoint for training, then calculate the MSE between the predicted point and the true point at the held-out time for 5000 sampled trajectories. This gives a measure of how accurately we can model simple dynamical systems.

We first test TrajectoryNet on two datasets where points lie on a 1D manifold in 2D with Gaussian noise (see Figure 4). First, two half Gaussians are sampled with means zero and one in one dimension. These progressions are then lifted onto curved manifolds in two dimensions, either an arch or a tree, mimicking a differentiating system where we have two sampled timepoints that have some overlap. Table 1 shows the Wasserstein distance (EMD) and the mean squared error for different interpolation methods between the interpolated distribution and the true distribution at the held-out time. Because optimal transport considers the shortest Euclidean distance, the base model and OT methods follow the lowest energy path, which is straight across. With density regularization or velocity regularization, TrajectoryNet learns paths that follow the density manifold. Figure 3 and Figure S2 demonstrate how TrajectoryNet with density or velocity regularization learns to follow the manifold.

A third artificial dataset shows the necessity of using velocity estimates for some data. Here we have an unchanging distribution of points distributed uniformly over the unit circle, where every point travels counterclockwise at a constant angular velocity. This is similar to the cell-cycle process in adult systems. Without velocity estimates it is impossible to pick up this type of dynamical system, because the distribution looks identical at every timepoint. This is illustrated by the lower MSE on the cycle dataset when using velocity regularization in Table 1.
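The reason distribution-matching alone fails on the cycle data can be seen in a few lines. The sketch below generates points on the unit circle rotating at an assumed angular speed `omega` (the value here is illustrative) and compares two predictions of where each point ends up: leaving every point in place, which already matches the unchanged distribution, and rotating every point with the known dynamics. All names are hypothetical.

```python
import numpy as np

omega = np.pi / 5                       # assumed angular speed, radians per unit time
rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 2.0 * np.pi, size=1000)

def position(theta, t):
    """Points on the unit circle after rotating counterclockwise for time t."""
    return np.stack([np.cos(theta + omega * t), np.sin(theta + omega * t)], axis=1)

def velocity(points):
    """Tangential velocity omega * (-y, x) of the rotation at each point."""
    return omega * np.stack([-points[:, 1], points[:, 0]], axis=1)

x0, x1 = position(theta, 0.0), position(theta, 1.0)

# Prediction 1: no movement. The marginal distribution is still uniform on the
# circle, but every individual trajectory is wrong by the same chord length.
mse_static = np.mean(np.sum((x1 - x0) ** 2, axis=1))

# Prediction 2: rotate each point by omega, i.e. follow the velocity field.
R = np.array([[np.cos(omega), -np.sin(omega)],
              [np.sin(omega),  np.cos(omega)]])
mse_rotated = np.mean(np.sum((x1 - x0 @ R.T) ** 2, axis=1))
```

The static prediction has a per-point squared error of $2 - 2\cos(\omega)$ despite matching the distribution, so a distributional distance cannot distinguish it from the true dynamics; only per-point velocity information can.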

5.2 Single-Cell Data

We run our model on 5D PCA due to computational constraints, but note that computation time scales roughly linearly with dimension for our test cases (see Appendix C), which is consistent with what was found in Grathwohl et al. (2019). Since there are no ground truth trajectories in real data, we can only evaluate using distributional distances. We do leave-one-out validation, training the model on all but one of the intermediate timepoints, then evaluating the EMD between the validation data and the model’s predicted distribution. We evaluate and compare our method on two single-cell RNA sequencing datasets.

Figure 6: Shows the first 2 PCs of the mouse cortex dataset. (a-c) show the distributions for the first three timepoints. (d) shows the distribution of cells over PC1. (e-f) show the interpolated points for E14.5 using (e) static OT and (f) TrajectoryNet with density regularization. (g-i) show expression of three markers of early (Pax6), mid (Eomes) and late (Tbr1) stage neurons.
Method | rep1 | rep2 | mean
Base | 0.888 ± 0.07 | 0.905 ± 0.06 | 0.897 ± 0.06
Base + D | 0.882 ± 0.03 | 0.895 ± 0.03 | 0.888 ± 0.03
Base + V | 0.900 ± 0.09 | 0.898 ± 0.10 | 0.899 ± 0.10
Base + D + V | 0.851 ± 0.08 | 0.866 ± 0.07 | 0.859 ± 0.07
OT | 1.098 | 1.095 | 1.096
prev | 1.628 | 1.573 | 1.600
next | 1.324 | 1.391 | 1.357
rand | 1.333 | 1.288 | 1.311
Table 2: Shows the Wasserstein distance between the left out timepoint and the predicted distribution for various methods on a 4 timepoint mouse embryo cortex dataset. Mean and standard deviation over 3 seeds.

Mouse Cortex Data.

The first dataset has structure similar to the Arch toy dataset. It consists of cells collected from mouse embryos at days E12.5, E14.5, E16, and E17.5. In Figure 6(d) we can see that at this time in the development of the mouse cortex the distribution of cells moves from a mostly neural stem cell population at E12.5 to a fairly developed and differentiated neuronal population at E17.5 Cotney et al. (2015); Katayama et al. (2016). The major axis of variation is neuron development. Over the 4 timepoints we have 2 biological replicates that we can use to evaluate variation between animals. In Table 2, we can see that TrajectoryNet outperforms baseline models, especially when adding density and velocity information. The curved manifold structure of this data, and of gene expression data in general, means that methods that interpolate with straight paths cannot fully capture the structure of the data. Since TrajectoryNet models full paths between timepoints, adding density and velocity information can bend the cell paths to follow the manifold, utilizing all available data rather than two timepoints as in standard optimal transport.

Figure 7: Shows the Embryoid body dataset projected into 2D with PHATE Moon et al. (2019) with paths and densities imputed using TrajectoryNet.

Embryoid body Data.

Next, we evaluate on a differentiating Embryoid body scRNA-seq time course. Figure 7 shows this data projected into two dimensions using a non-linear dimensionality reduction method called PHATE Moon et al. (2019). This data consists of 5 timepoints of single cell data collected in a developing human embryo system (Day 0-Day 24). See Figure 5 for a depiction of the growth rate. Initially, cells start as a single stem cell population, but differentiate into roughly 4 cell precursor types. This gives a branching structure similar to our artificial tree dataset. In Table 3 we show results when each of the three intermediate timepoints is left out. In this case velocity regularization does not seem to help; we hypothesize this has to do with the low unspliced RNA counts present in the data (see Figure S3). We find that energy regularization and growth rate regularization help only on the first timepoint, and that density regularization helps the most overall.

We can also project trajectories back to gene space. This gives insights into when populations might be distinguishable. In Figure 8, we demonstrate how TrajectoryNet can be projected back to the gene space. We sample cells from the end of the four main branches, then integrate TrajectoryNet backwards to get their paths through gene space. This recapitulates known biology in Moon et al. (2019). See appendix D for a more in-depth treatment.

Method | t=1 | t=2 | t=3 | mean
Base | 0.764 | 0.811 | 0.863 | 0.813
Base + D | 0.759 | 0.783 | 0.811 | 0.784
Base + V | 0.816 | 0.839 | 0.865 | 0.840
Base + D + V | 0.930 | 0.806 | 0.810 | 0.848
Base + E | 0.737 | 0.896 | 0.842 | 0.825
Base + G | 0.700 | 0.913 | 0.829 | 0.814
OT | 0.791 | 0.831 | 0.841 | 0.821
prev | 1.715 | 1.400 | 0.814 | 1.309
next | 1.400 | 0.814 | 1.694 | 1.302
rand | 0.872 | 1.036 | 0.998 | 0.969
Table 3: Shows the Wasserstein distance (EMD) between the left out timepoint and the predicted distribution for various methods on the 5 timepoint Embryoid body dataset.
Figure 8: For curated endpoints, shows location on PHATE dimensions, TrajectoryNet paths projected into PCA space, and trajectories for 4 genes.

6 Conclusion

TrajectoryNet computes dynamic optimal transport between distributions of samples at discrete times to model realistic paths of samples continuously in time. In the single-cell case, TrajectoryNet “reanimates” cells which are destroyed by measurement to recreate a continuous-time trajectory. This is also relevant when modeling any underlying system that is high-dimensional, dynamic, and non-linear. In this case, existing static OT methods are under-powered and do not interpolate well to intermediate timepoints between measured ones. Existing dynamic OT methods (non-neural network based) are computationally infeasible for this task.

In this work we integrate multiple priors and assumptions into one model to bias TrajectoryNet towards more realistic dynamic optimal transport solutions. We demonstrated how this gives more power to discover hidden and time-specific relationships between features. In future work, we would like to consider stochastic dynamics Liu et al. (2019) and to learn the growth term together with the dynamics.


References

  • J. Benamou and Y. Brenier (2000) A computational fluid mechanics solution to the Monge-Kantorovich mass transfer problem. Numerische Mathematik 84 (3), pp. 375–393 (en). External Links: ISSN 0029-599X, 0945-3245, Document Cited by: Appendix C, item 1, §1, §2, §3.2, §3.
  • J. Benamou (2003) Numerical resolution of an “unbalanced” mass transport problem. ESAIM: Mathematical Modelling and Numerical Analysis 37 (5), pp. 851–868 (en). External Links: ISSN 0764-583X, 1290-3841, Document Cited by: §2.
  • S. C. Bendall, K. L. Davis, E. D. Amir, M. D. Tadmor, E. F. Simonds, T. J. Chen, D. K. Shenfeld, G. P. Nolan, and D. Pe’er (2014) Single-Cell Trajectory Detection Uncovers Progression and Regulatory Coordination in Human B Cell Development. Cell 157 (3), pp. 714–725 (en). External Links: ISSN 00928674, Document Cited by: §2.
  • V. Bergen, M. Lange, S. Peidli, F. A. Wolf, and F. J. Theis (2019) Generalizing RNA velocity to transient cell states through dynamical modeling. Preprint Bioinformatics (en). External Links: Document Cited by: Appendix D, §E.2, §2, §4.
  • R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud (2018) Neural Ordinary Differential Equations. In NeurIPS, External Links: 1806.07366 Cited by: §1, §4.
  • L. Chizat, G. Peyré, B. Schmitzer, and F. Vialard (2018) Unbalanced optimal transport: Dynamic and Kantorovich formulations. Journal of Functional Analysis 274 (11), pp. 3090–3123 (en). External Links: ISSN 00221236, Document Cited by: §2.
  • J. Cotney, R. A. Muhle, S. J. Sanders, L. Liu, A. J. Willsey, W. Niu, W. Liu, L. Klei, J. Lei, J. Yin, S. K. Reilly, A. T. Tebbenkamp, C. Bichsel, M. Pletikos, N. Sestan, K. Roeder, M. W. State, B. Devlin, and J. P. Noonan (2015) The autism-associated chromatin modifier CHD8 regulates other autism risk genes during human neurodevelopment. Nature Communications 6 (1), pp. 6404 (en). External Links: ISSN 2041-1723, Document Cited by: §5.2.
  • M. Cuturi (2013) Sinkhorn Distances: Lightspeed Computation of Optimal Transport. NeurIPS, pp. 9 (en). Cited by: §2.
  • J.R. Dormand and P.J. Prince (1980) A family of embedded Runge-Kutta formulae. Journal of Computational and Applied Mathematics 6 (1), pp. 19–26 (en). External Links: ISSN 03770427, Document Cited by: §5.
  • R. M. Dudley (1969) The Speed of Mean Glivenko-Cantelli Convergence. The Annals of Mathematical Statistics 40 (1), pp. 40–50. Cited by: §3.1.
  • F. Erhard, M. A. P. Baptista, T. Krammer, T. Hennig, M. Lange, P. Arampatzi, C. S. Jürges, F. J. Theis, A. Saliba, and L. Dölken (2019) scSLAM-seq reveals core features of transcription dynamics in single cells. Nature 571 (7765), pp. 419–423 (en). External Links: ISSN 0028-0836, 1476-4687, Document Cited by: §2.
  • W. Grathwohl, R. T. Q. Chen, J. Bettencourt, I. Sutskever, and D. Duvenaud (2019) FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models. In ICLR, External Links: 1810.01367 Cited by: Appendix C, §1, §5.2, §5.
  • T. B. Hashimoto, D. K. Gifford, and T. S. Jaakkola (2016) Learning Population-Level Diffusions with Generative Recurrent Networks. ICML, pp. 10 (en). Cited by: §2.
  • G. Hendriks, L. A. Jung, A. J. M. Larsson, M. Lidschreiber, O. Andersson Forsman, K. Lidschreiber, P. Cramer, and R. Sandberg (2019) NASC-seq monitors RNA synthesis in single cells. Nature Communications 10 (1), pp. 3138 (en). External Links: ISSN 2041-1723, Document Cited by: §2.
  • D. Hilbert (1902) Mathematical Problems. Bulletin of the American Mathematical Society 8 (10), pp. 437–479. Cited by: §4.
  • A. Jambulapati, A. Sidford, and K. Tian (2019) A Direct $\tilde{O}(1/\epsilon)$ Iteration Parallel Algorithm for Optimal Transport. NeurIPS, pp. 12 (en). Cited by: §3.1.
  • S. Kanton, M. J. Boyle, Z. He, M. Santel, A. Weigert, F. Sanchís-Calleja, P. Guijarro, L. Sidow, J. S. Fleck, D. Han, Z. Qian, M. Heide, W. B. Huttner, P. Khaitovich, S. Pääbo, B. Treutlein, and J. G. Camp (2019) Organoid single-cell genomic atlas uncovers human-specific features of brain development. Nature 574 (7778), pp. 418–422 (en). External Links: ISSN 0028-0836, 1476-4687, Document Cited by: Appendix D.
  • L. V. Kantorovich (1942) On the Translocation of Masses. Doklady Akademii Nauk, pp. 7–8 (en). Cited by: §2.
  • Y. Katayama, M. Nishiyama, H. Shoji, Y. Ohkawa, A. Kawamura, T. Sato, M. Suyama, T. Takumi, T. Miyakawa, and K. I. Nakayama (2016) CHD8 haploinsufficiency results in autistic-like phenotypes in mice. Nature 537 (7622), pp. 675–679 (en). External Links: ISSN 0028-0836, 1476-4687, Document Cited by: §5.2.
  • D. P. Kingma and J. Ba (2014) Adam: A Method for Stochastic Optimization. arXiv:1412.6980 [cs]. External Links: 1412.6980 Cited by: §5.
  • D. P. Kingma, T. Salimans, R. Jozefowicz, X. Chen, I. Sutskever, and M. Welling (2016) Improving Variational Inference with Inverse Autoregressive Flow. In NeurIPS, External Links: 1606.04934 Cited by: §3.3.
  • S. Kolouri, K. Nadjahi, U. Simsekli, R. Badeau, and G. Rohde (2019) Generalized Sliced Wasserstein Distances. NeurIPS, pp. 12 (en). Cited by: §2.
  • G. La Manno, R. Soldatov, A. Zeisel, E. Braun, H. Hochgerner, V. Petukhov, K. Lidschreiber, M. E. Kastriti, P. Lönnerberg, A. Furlan, J. Fan, L. E. Borm, Z. Liu, D. van Bruggen, J. Guo, X. He, R. Barker, E. Sundström, G. Castelo-Branco, P. Cramer, I. Adameyko, S. Linnarsson, and P. V. Kharchenko (2018) RNA velocity of single cells. Nature 560 (7719), pp. 494–498 (en). External Links: ISSN 0028-0836, 1476-4687, Document Cited by: Figure S3, Appendix D, §E.2, §1, §2, §4.
  • A. R. Lederer and G. La Manno (2020) The emergence and promise of single-cell temporal-omics approaches. Current Opinion in Biotechnology 63, pp. 70–78 (en). External Links: ISSN 09581669, Document Cited by: §2.
  • M. Liero, A. Mielke, and G. Savaré (2018) Optimal Entropy-Transport problems and a new Hellinger–Kantorovich distance between positive measures. Inventiones mathematicae 211 (3), pp. 969–1117 (en). External Links: ISSN 0020-9910, 1432-1297, Document Cited by: §2.
  • O. Lindenbaum, J. Stanley, G. Wolf, and S. Krishnaswamy (2018) Geometry Based Data Generation. NeurIPS, pp. 12 (en). Cited by: §4.
  • X. Liu, T. Xiao, S. Si, Q. Cao, S. Kumar, and C. Hsieh (2019) Neural SDE: Stabilizing Neural ODE Networks with Stochastic Noise. arXiv:1906.02355 [cs, stat]. External Links: 1906.02355 Cited by: §6.
  • E. Z. Macosko, A. Basu, R. Satija, J. Nemesh, K. Shekhar, M. Goldman, I. Tirosh, A. R. Bialas, N. Kamitaki, E. M. Martersteck, J. J. Trombetta, D. A. Weitz, J. R. Sanes, A. K. Shalek, A. Regev, and S. A. McCarroll (2015) Highly Parallel Genome-wide Expression Profiling of Individual Cells Using Nanoliter Droplets. Cell 161 (5), pp. 1202–1214 (en). External Links: ISSN 00928674, Document Cited by: §1.
  • G. Monge (1781) Mémoire sur la théorie des déblais et des remblais. Histoire de l’Académie Royale des Science. Cited by: §2.
  • K. R. Moon, D. van Dijk, Z. Wang, S. Gigante, D. B. Burkhardt, W. S. Chen, K. Yim, A. van den Elzen, M. J. Hirn, R. R. Coifman, N. B. Ivanova, G. Wolf, and S. Krishnaswamy (2019) Visualizing Structure and Transitions for Biological Data Exploration. bioRxiv, pp. 120378 (en). External Links: Document Cited by: Appendix D, §4, Figure 5, Figure 7, §5.2, §5.2.
  • K. P. Murphy (2013) Machine Learning: a Probabilistic Perspective. The MIT Press. External Links: ISBN 978-0-262-01802-9 Cited by: §4.
  • B. Muzellec and M. Cuturi (2019) Subspace Detours: Building Transport Plans that are Optimal on Subspace Projections. NeurIPS, pp. 12 (en). Cited by: §2.
  • J. Oeppen and J. W. Vaupel (2002) Broken Limits to Life Expectancy. Science 296 (5570), pp. 1029–1031. External Links: Document Cited by: §1.
  • N. Papadakis, G. Peyré, and E. Oudet (2014) Optimal Transport with Proximal Splitting. SIAM Journal on Imaging Sciences 7 (1), pp. 212–238 (en). External Links: ISSN 1936-4954, Document, 1304.5784 Cited by: Appendix C, §3.2.
  • G. Papamakarios, T. Pavlakou, and I. Murray (2017) Masked Autoregressive Flow for Density Estimation. arXiv:1705.07057 [cs, stat]. External Links: 1705.07057 Cited by: §3.3.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala (2019) PyTorch: An Imperative Style, High-Performance Deep Learning Library. NeurIPS, pp. 12 (en). Cited by: Appendix C.
  • G. Peyré and M. Cuturi (2019) Computational Optimal Transport. arXiv:1803.00567 [stat] (en). External Links: 1803.00567 Cited by: Appendix C, §2.
  • D. J. Rezende and S. Mohamed (2015) Variational Inference with Normalizing Flows. arXiv:1505.05770 [cs, stat]. External Links: 1505.05770 Cited by: §3.3, §3.3.
  • S. Rifai, P. Vincent, X. Muller, X. Glorot, and Y. Bengio (2011) Contractive Auto-Encoders: Explicit Invariance During Feature Extraction. In Proceedings of the 29th International Conference on Machine Learning, pp. 8 (en). Cited by: §4.
  • W. Saelens, R. Cannoodt, H. Todorov, and Y. Saeys (2019) A comparison of single-cell trajectory inference methods. Nature Biotechnology 37 (5), pp. 547–554 (en). External Links: ISSN 1087-0156, 1546-1696, Document Cited by: §1, §2.
  • G. Schiebinger, J. Shu, M. Tabaka, B. Cleary, V. Subramanian, A. Solomon, J. Gould, S. Liu, S. Lin, P. Berube, L. Lee, J. Chen, J. Brumbaugh, P. Rigollet, K. Hochedlinger, R. Jaenisch, A. Regev, and E. S. Lander (2019) Optimal-Transport Analysis of Single-Cell Gene Expression Identifies Developmental Trajectories in Reprogramming. Cell 176 (4), pp. 928–943.e22 (en). External Links: ISSN 00928674, Document Cited by: §1, §2, §2, §4, §5.
  • R. Sinkhorn (1964) A relationship between arbitrary positive matrices and doubly stochastic matrices. Cited by: §2, §3.1.
  • C. Trapnell, D. Cacchiarelli, J. Grimsby, P. Pokharel, S. Li, M. Morse, N. J. Lennon, K. J. Livak, T. S. Mikkelsen, and J. L. Rinn (2014) The dynamics and regulators of cell fate decisions are revealed by pseudotemporal ordering of single cells. Nature Biotechnology 32 (4), pp. 381–386 (en). External Links: ISSN 1087-0156, 1546-1696, Document Cited by: §2.
  • C. Villani (2008) Optimal transport, old and new. Springer. External Links: ISBN 978-3-540-71050-9 Cited by: §3.1.
  • P. Vincent, H. Larochelle, I. Lajoie, Y. Bengio, and P. Manzagol (2010) Stacked Denoising Autoencoders: Learning Useful Representations in a Deep Network with a Local Denoising Criterion. Journal of Machine Learning Research, pp. 3371–3408 (en). Cited by: §4.
  • C. H. Waddington (1942) The epigenotype. Endeavour 1, pp. 18–20. Cited by: §1.
  • J. Weed and F. Bach (2019) Sharp asymptotic and finite-sample rates of convergence of empirical measures in Wasserstein distance. Bernoulli 25 (4A), pp. 2620–2648 (en). External Links: ISSN 1350-7265, Document Cited by: §3.1.
  • C. Weinreb, S. Wolock, B. K. Tusi, M. Socolovsky, and A. M. Klein (2018) Fundamental limits on dynamic inference from single-cell snapshots. Proceedings of the National Academy of Sciences 115 (10), pp. E2467–E2476 (en). External Links: ISSN 0027-8424, 1091-6490, Document Cited by: §2.
  • K. D. Yang and C. Uhler (2019) Scalable Unbalanced Optimal Transport Using Generative Adversarial Networks. ICLR, pp. 20 (en). Cited by: §1, §2, §2.

Appendix A Technical details

A.1 Proof of Theorem 4.1

First, we apply the Lagrange multiplier method by introducing the variable to the minimization problem of (5) subject to constraints (4). As we always begin with the base distribution, .


Since the KL divergences are non-negative, . The optimal solution of the min-max problem is the optimal solution to the original problem. Consider the true minimal loss given by the optimal solution to be , we know that , where


In order to show that the solution of the max-min problem converges to , we first show that it is monotonic in . For easier reading, set , . Both and are functions of . For any pair of values, if , . Thus the maximum and minimum of the function is also monotonic in , and it will converge to the supremum.

Next, we show that the divergence term decreases monotonically as increases, and it converges to as goes to infinity. For a given , let . By definition, , and . Thus . If ,then . The sequence decreases monotonically as increases. Because is upper bounded by and , converges to zero as goes to infinity.

Now we have shown that is a monotone sequence and is upper bounded by , and that the divergence term converges to zero, we next show that converges to , and that the optimal solution of the max-min problem in (18) is the optimal solution of the original problem. Since the divergence term is non-negative, we have a lower bound for as


Because is monotonically increasing, and is monotonically decreasing, increases as increases.

Lemma A.1.

where is the diameter of the probability space and is the transformation completion time.


For a certain , starting from the base distribution , at time the distribution is transformed, by following , to , and KL. Now consider a different transformation , which is composed of two parts: the first part is an accelerated , so that is achieved by time , and the second part transforms to in the remaining time of . Thus at time , by following , we achieve zero divergence, . The new transformation has an increased , from the acceleration and the additional transformation.


where has distribution , has distribution , and is a mapping from to . The first part of is just . The second part is upper bounded by , where is the total variation between and , which is in turn upper bounded by . We choose . ∎

By definition, , as is the infimum at zero divergence. Assuming converges to , and , . Then by Lemma A.1, , and we have a contradiction. Now we complete the proof that converges to and that for a large enough , the solution of the max-min problem (18) is the solution of (5) when subject to conditions (4).

Appendix B Growth Rate Model Training

Our growth network is trained to match the discrete unbalanced optimal transport problem with entropic regularization:


where is the Frobenius norm of the elementwise product of the transportation matrix and the cost matrix , and where are regularization constants on the source and target unbalanced distributions . The growth rate of each cell in to is then


In our experiments we set and tune for reasonable growth rates. This gives a growth rate at every observed cell; however, our model needs a growth rate defined continuously at every measured timepoint. For this, we learn a neural network that is trained to match the growth rate at measured cells and to equal one at negatively sampled points. We use a simple form of negative sampling: for each batch of real points we sample an equal-sized batch of points from a uniform distribution over the [-1,1] hypercube, and these negative points are given a growth rate of 1. The network is trained with mean squared error loss to match these growth rates at all measured times.
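
The negative-sampling step can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' training code: the function names `make_growth_batch` and `mse` are hypothetical, the data are assumed to be pre-scaled to lie inside the [-1,1] hypercube, and the toy growth rates are made up.

```python
import random


def make_growth_batch(real_points, real_growth, rng=None):
    """Pair a batch of measured cells with an equal-sized batch of negatives.

    real_points: list of d-dimensional points (lists of floats), assumed to be
                 scaled so that the data lie inside the [-1, 1] hypercube.
    real_growth: list of target growth rates, one per real point.
    Returns (points, targets) with the negatives appended after the real points.
    """
    rng = rng or random.Random()
    dim = len(real_points[0])
    # Negative samples: uniform over [-1, 1]^d, same count as the real batch.
    negatives = [
        [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        for _ in range(len(real_points))
    ]
    # Negative points are assigned a target growth rate of 1.
    points = list(real_points) + negatives
    targets = list(real_growth) + [1.0] * len(negatives)
    return points, targets


def mse(predictions, targets):
    """Mean squared error between predicted and target growth rates."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


# Toy usage with made-up numbers (not data from the paper).
real_x = [[0.1, -0.3], [0.5, 0.2], [-0.7, 0.4]]
real_g = [1.8, 0.6, 1.2]
batch_x, batch_g = make_growth_batch(real_x, real_g, random.Random(0))

# A constant predictor g(x) = 1 is exact on the negatives, so its loss
# comes only from the measured cells.
constant_loss = mse([1.0] * len(batch_g), batch_g)
```

In an actual training loop the predictions would come from the growth network evaluated on `batch_x`; the constant predictor here only shows how the negatives pull the learned function toward 1 away from the data.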

Appendix C Scaling with Dimension

Runtime Considerations.

Existing numerical methods for solving dynamic OT rely on proximal splitting methods over a discretized staggered grid Benamou and Brenier (2000); Papadakis et al. (2014); Peyré and Cuturi (2019). This results in a non-smooth but convex optimization problem over these grid points. However, the number of grid points scales exponentially with the dimension, so these methods are only applicable in low dimensions. TrajectoryNet scales polynomially with the dimension. See Figure S1 for empirical measurements.

To test how computation time varies with dimension, we run TrajectoryNet for 100 batches of 1000 points on the mouse cortex dataset at different dimensionalities. For hardware we use a single machine with an AMD Ryzen Threadripper 2990WX 32-core processor, 128GB of memory, and three Nvidia TITAN RTX GPUs. Our model is implemented in the PyTorch framework Paszke et al. (2019). We count the total number of function evaluations (both forward and backward) and divide the total time by this count. Figure S1 shows that the seconds per evaluation is roughly linear in the dimensionality of the data. This does not imply that convergence of the model is linear in dimension, only that computation per iteration is linear. As suggested in Grathwohl et al. (2019), the number of iterations until convergence is a function of how complicated the distributions are, and is less dependent on the ambient dimension itself. By learning flows along a manifold with , our method may scale closer to the intrinsic dimensionality of the data rather than the ambient dimensionality.
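
The timing measurement described above (total wall time divided by the total number of forward and backward function evaluations) can be sketched as follows. This is an illustrative stand-in using only the Python standard library: `CountedFunction`, `seconds_per_evaluation`, and `toy_step` are hypothetical names, and the toy dynamics replace the actual network and ODE solver.

```python
import time


class CountedFunction:
    """Wrap a callable and count how many times it is evaluated."""

    def __init__(self, fn):
        self.fn = fn
        self.n_evals = 0

    def __call__(self, *args, **kwargs):
        self.n_evals += 1
        return self.fn(*args, **kwargs)


def seconds_per_evaluation(step, counted_fns, n_batches):
    """Run `step` for n_batches and divide wall time by total evaluations."""
    for f in counted_fns:
        f.n_evals = 0
    start = time.perf_counter()
    for _ in range(n_batches):
        step()
    elapsed = time.perf_counter() - start
    total_evals = sum(f.n_evals for f in counted_fns)
    return elapsed / total_evals, total_evals


# Placeholders for the forward and backward passes of the dynamics network.
forward = CountedFunction(lambda x: [2.0 * v for v in x])
backward = CountedFunction(lambda g: [2.0 * v for v in g])


def toy_step():
    # Four forward and four backward calls per batch, chosen arbitrarily;
    # with an adaptive solver this count would depend on the data.
    x = [1.0, 2.0, 3.0]
    for _ in range(4):
        x = forward(x)
    for _ in range(4):
        x = backward(x)


sec_per_eval, n_evals = seconds_per_evaluation(
    toy_step, [forward, backward], n_batches=100
)
```

Normalizing by the evaluation count rather than by the batch count separates the cost of one pass through the network from the number of solver steps, which is why the measurement speaks to per-evaluation cost and not to convergence.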

Figure S1: The computation per evaluation is roughly linear in terms of dimension.
Figure S2: Density regularization or velocity regularization can be used to follow a 1D manifold in 2D.

Appendix D Biological Considerations

Quality control and normalization are important when estimating RNA-velocity from existing single-cell measurements. We believe that the RNA-velocity measurements from the Embryoid body data may be unreliable given the low number of unspliced RNA counts present. In Figure S3 we can see that each timepoint consists of around - of unspliced RNA. This is low relative to numbers in other recent works La Manno et al. (2018); Bergen et al. (2019); Kanton et al. (2019). Low unspliced RNA counts lead to noisier, lower-quality estimates of RNA velocity.

Figure S3: The ratio of spliced, ambiguous, and unspliced RNA counts over the 5 timepoints in the Embryoid body dataset. Mean unspliced here is around - of total counts; in other systems this is near  La Manno et al. (2018).

In Figure 8 we showed how TrajectoryNet can be projected back to the gene space. These projections can be used to infer differences between populations much earlier in time than they can be identified in the gene space. Here we have four populations that are easily identified by marker genes or clustering at the final timepoint. Since all four populations emerge from a single relatively uniform stem cell population, the question becomes how early we can identify the features of progenitor cells, the cells leading to these differentiated populations. Since TrajectoryNet models cells as probabilities continuously over time, we can find the path for each differentiated cell at earlier timepoints. This allows inferences such as the fact that HAND1, a gene that is generally high in cardiac cells, is high at earlier timepoints, and may even start to distinguish the population as early as day 6. A gene like ONECUT2 only starts to distinguish neuronal populations at later timepoints. For further information on this particular system see Moon et al. (2019) Figure 6.

Appendix E Reproducibility

To foster reproducibility, we provide as many details as possible on the experiments in the main paper. Code is available at

E.1 2D Examples

In Figure 2 we transport a Gaussian to an s-curve. The Gaussian consists of 10000 points sampled from a standard normal distribution. The s-curve is generated using the sklearn function sklearn.datasets.make_s_curve with noise of 0.05, and 10000 samples. We then take the first and third dimension, and multiply by 1.5 for the proper scaling. To generate the OT subplot we used the McCann interpolant from 200 points sampled from the Gaussian. To generate panel (d), we used the procedure detailed in the beginning of Section 5 to train TrajectoryNet, then sampled 200 points from a Gaussian distribution and used the adjoint with these points as the initial state at time to generate points at time . For panel (e) we added an energy regularization with and . These were found by experimentation, although parameters in the range of and were largely visually similar.
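
The source and target point clouds described above can be generated roughly as follows. This is a sketch under assumptions: the random seed is arbitrary (the text does not state one), the Gaussian is taken to be two-dimensional to match the 2D s-curve, and the variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import make_s_curve

seed = 0  # assumed; no seed is stated in the text
rng = np.random.default_rng(seed)

# Source: 10000 points from a standard normal (assumed 2D here).
gaussian = rng.standard_normal(size=(10000, 2))

# Target: s-curve with noise 0.05 and 10000 samples.
# make_s_curve returns the 3D points and their position along the curve.
s_curve_3d, _ = make_s_curve(n_samples=10000, noise=0.05, random_state=seed)

# Keep the first and third dimensions and rescale by 1.5.
s_curve = 1.5 * s_curve_3d[:, [0, 2]]
```

Dropping the second coordinate removes the axis along which the s-curve is merely extruded, leaving the planar "S" shape.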

To generate the arch and tree datasets we started with two half Gaussians at mean zero and one (as pictured in Figure 4) with 5000 points each, then found the McCann interpolant at as the test distribution. We then lift these into 2D by embedding on the half circle of radius 1 and adding noise to the radius. To generate velocity, we add a velocity tangent to the circle for each point. For the tree dataset we additionally flip a random half of the points with over the line .
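
In one dimension the McCann (displacement) interpolant between two equal-sized samples has a simple form: for a convex transport cost the optimal matching pairs the sorted source points with the sorted target points, and each pair is linearly interpolated. The sketch below illustrates this and a lift onto the half circle. Several details are assumptions made only for illustration and are not specified in the text: the interpolation time of 0.5, the half-Gaussian width, the radial noise level, and the mapping from 1D position to angle.

```python
import math
import random


def mccann_interpolant_1d(source, target, t):
    """Displacement interpolation between two equal-sized 1D samples.

    In one dimension the optimal matching for a convex cost pairs the
    sorted source points with the sorted target points.
    """
    if len(source) != len(target):
        raise ValueError("this sketch assumes equal sample sizes")
    return [
        (1.0 - t) * a + t * b
        for a, b in zip(sorted(source), sorted(target))
    ]


def lift_to_half_circle(positions, radius_noise, rng):
    """Embed 1D positions in [0, 1] near the upper unit half circle.

    ASSUMPTION: position s maps to angle pi * (1 - s), so s = 0 lands at
    (-1, 0) and s = 1 at (1, 0). The text does not spell out this map.
    """
    points = []
    for s in positions:
        angle = math.pi * (1.0 - s)
        r = 1.0 + rng.gauss(0.0, radius_noise)  # noisy radius around 1
        points.append((r * math.cos(angle), r * math.sin(angle)))
    return points


rng = random.Random(0)
n = 5000
sigma = 0.1  # hypothetical half-Gaussian width

# Half Gaussians: one rising from 0, one falling from 1 (assumed reading).
start = [abs(rng.gauss(0.0, sigma)) for _ in range(n)]
end = [1.0 - abs(rng.gauss(0.0, sigma)) for _ in range(n)]

t_test = 0.5  # hypothetical held-out time
middle = mccann_interpolant_1d(start, end, t_test)
middle_2d = lift_to_half_circle(middle, radius_noise=0.05, rng=rng)
```

The sorted matching is only valid in one dimension; in higher dimensions the interpolant requires solving an assignment problem first.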

For the Cycle dataset, we start with 5000 uniformly sampled points around the circle, with radius as We then add an arrow tangent to the circle with magnitude , Thus in one time unit the points should move of the way around the circle.

E.2 Single Cell Datasets

Both single cell datasets were sequenced using 10X sequencing. The Embryoid body data is publicly available and consists of roughly 30,000 cells unfiltered, and 16,000 cells after filtering. The mouse cortex dataset is not currently publicly available, but consists of roughly 20,000 cells after filtering. For both datasets no batch correction was used. Raw sequences were processed with CellRanger. We then used velocyto La Manno et al. (2018) to produce the unspliced and spliced count matrices, and the default parameters in scVelo Bergen et al. (2019) to generate velocity arrows on a PCA embedding. These defaults include count normalization across rows, selection of the 3000 most variable genes, filtering of low quality genes, and smoothing counts between cells.

For parameters we did a grid search over , . For Base+E we did a search of . A more extensive search could lead to better results. We intended to show how these regularizations can be used and to demonstrate the viability of this approach rather than to fully explore parameter space.

E.3 Software Versioning

The following software versions were used.

scvelo==0.1.24, torch==1.3.1, torchdiffeq==0.0.1, velocyto==0.17.17, scprep==1.0.3, scipy==1.4.1, scikit-learn==0.22, scanpy==1.4.5
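
When reproducing the experiments, it can help to check an environment against these pins before running anything. The following is a small optional helper, not part of the released code: `PINNED` transcribes the versions listed above, while `version_mismatches` is a hypothetical function built on the standard-library `importlib.metadata` module (Python 3.8 or later).

```python
from importlib import metadata

# Versions transcribed from the list above.
PINNED = {
    "scvelo": "0.1.24",
    "torch": "1.3.1",
    "torchdiffeq": "0.0.1",
    "velocyto": "0.17.17",
    "scprep": "1.0.3",
    "scipy": "1.4.1",
    "scikit-learn": "0.22",
    "scanpy": "1.4.5",
}


def version_mismatches(pinned, lookup=metadata.version):
    """Return {package: (wanted, found)} for packages that differ or are missing."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


# Report every package whose installed version differs from the pin.
for name, (wanted, found) in version_mismatches(PINNED).items():
    print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

The `lookup` argument exists only so the check can be exercised without the packages installed; by default it queries the active environment.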