Learning transport cost from subset correspondence

by Ruishan Liu et al.

Learning to align multiple datasets is an important problem with many applications, and it is especially useful when we need to integrate multiple experiments or correct for confounding. Optimal transport (OT) is a principled approach to align datasets, but a key challenge in applying OT is that we need to specify a transport cost function that accurately captures how the two datasets are related. Reliable cost functions are typically not available, and practitioners often resort to hand-crafted or Euclidean cost even when it may not be appropriate. In this work, we investigate how to learn the cost function using a small amount of side information which is often available. The side information we consider captures subset correspondence---i.e. certain subsets of points in the two datasets are known to be related. For example, we may have some images labeled as cars in both datasets; or we may have a common annotated cell type in single-cell data from two batches. We develop an end-to-end optimizer (OT-SI) that differentiates through the Sinkhorn algorithm and effectively learns the suitable cost function from side information. In systematic experiments on images, marriage-matching and single-cell RNA-seq, our method substantially outperforms state-of-the-art benchmarks.





1 Introduction

In many applications, we have multiple related datasets from different sources or domains, and learning efficient computational mappings between these datasets is an important problem LongJordan17, zamir2018taskonomy. For example, we might have single-cell RNA-seq datasets generated for the same tissue type by two different labs. Since the data come from the same type of tissue, we would like to map cells between the two datasets to merge them, so that we can analyze them jointly. However, there are often complex nonlinear batch artifacts generated by the different labs. Moreover, the cells are not paired—for each cell measured in the first lab, there is not an identical clone in the second lab. How to integrate or align these two datasets is therefore a challenging problem.

Optimal transport (OT) is a principled analytical framework for aligning heterogeneous datasets Santambrogio2015. It has been increasingly applied to problems in domain adaptation and transfer learning seguy2017large, genevay2017learning, courty2017optimal, li2019learning. Optimal transport takes two datasets and computes a mapping between them in the form of a "transport plan". The mapping is optimal in the sense that among all reasonable mappings (precisely defined in Section 2), it minimizes the cost of aligning the two datasets. The transport cost is given by the user and encodes expert knowledge about how the datasets relate to each other. For example, if the expert believes that one dataset is essentially the other with added Gaussian noise, then a Euclidean cost could be natural. If the cost is correctly specified, then there are powerful methods for finding the globally optimal transport Villani2008. A major challenge in practice, e.g. for single-cell RNA-seq, is that we and the experts do not know what cost is appropriate. Users often resort to Euclidean or other hand-crafted cost functions, which can give misleading mappings.

Our contributions.

We propose a novel approach to automatically learn good transport costs by leveraging side information we may have about the data. The side information that we model is knowledge that a certain subset of points in one dataset should be mapped to a corresponding subset of points in the other dataset. In the single-cell example, we often have cell type annotations that certain points are, say, T cells. Then we can deduce that T cells from lab one should at least be mapped to T cells from lab two. We only need T cells to be crudely annotated in both datasets, which is reasonable; we don't need to know that a particular T cell should be mapped to another specific T cell.

We present the first algorithm, OT-SI, to leverage subset correspondence as a general form of side information. In contrast, previous works mainly focus on pair matching problems li2019learning, galichon2010matching — the extreme case of subset correspondence when the subset sizes are 1. In practice, exact one-to-one matching labels are often expensive to obtain or even intractable. OT-SI is an end-to-end framework that learns the transport cost. The intuition is to optimize over a parametrized family of transport costs to identify the cost under which the annotated subsets are naturally mapped to each other via optimal transport. OT-SI efficiently leverages even a small amount of side information, and it generalizes well to new, unannotated data. The learned transport cost is also interpretable. We demonstrate in extensive experiments across image, single-cell, marriage and synthetic datasets that our method OT-SI substantially outperforms state-of-the-art methods for mapping datasets.

Related Work

Optimal transport has been well studied in the mathematics, statistics and optimization literature Villani2008, courty2017learning, li2019learning, courty2017optimal. OT can be used to define a distance metric between distributions (e.g., the Wasserstein distance) or to produce an explicit mapping across two datasets; the latter is the focus of our paper. In machine learning, there has been significant work on developing fast algorithms for efficient computation of the optimal transport plan cuturi2013sinkhorn, altschuler2017near, staib2017parallel, and on analyzing the properties of the transport plan under various structures and constraints on the optimization problem alvarez2018structured, titouan2019optimal. The previous work on learning the transport cost considers a very different setting from ours – learning feature-histogram distances between many pairs of datapoints cuturi2014ground. Some classical clustering and alignment methods xing2003distance, HLS05, WM08, WM09 have realized benefits by including side information, but these nonparametric methods differ from our explicit parametrization and optimization of the transport cost function.

Separately, there have been recent efforts to directly map between datasets without learning a transport cost. The standard alignment methods can be divided into two categories: GAN-based zhu2017unpaired, choi2018stargan and OT-based grave2019unsupervised, alvarez2019towards. GAN-based approaches have been used to align single-cell RNA-seq data when pairs of cells are known to be related amodio2018magan. However, the exact pairing of individual cells is often not readily available, or even intractable to obtain. To address this issue, our method OT-SI allows for more general correspondence between subsets, i.e., clusters, cell types, and also individual cells. Meanwhile, the OT-based methods typically rely on Procrustes analysis rangarajan1997softassign — a linear transformation between the datasets is assumed, which lacks the flexibility to handle nonlinear artifacts, and the side information cannot be utilized. In contrast, a major benefit of our approach is its graceful adaptation to partial subset correspondence information, where we frame the problem as semi-supervised.

2 Learning Cost Metrics

A good choice of the cost function for optimal transport is key to a successful mapping between two datasets. In this section, we present the algorithm OT-SI, which parametrizes the cost function with a weight vector and adaptively learns it using side information about the training data. The side information we consider is subset correspondence — a common situation in which some subsets of training points are known to be related; pair matching is included as an extreme case. The learned cost function is then evaluated on unseen test data to assess generalization.

2.1 Optimal Transport

Consider learning a mapping between two datasets X = {x_1, …, x_n} and Y = {y_1, …, y_m}. Here we use n and m to denote the numbers of datapoints; each sample x_i or y_j could be a vector as well. We briefly recall the optimal transport framework in this setting. Given probability vectors a and b, the transport polytope is defined as

U(a, b) = { P ∈ R_+^{n×m} : P 1_m = a, Pᵀ 1_n = b },

where 1_n (1_m) is the n (m) dimensional vector of ones. Here the probability vector a (b) lies in the simplex Σ_n (Σ_m). For two random variables with distributions a and b, the transport polytope represents the set of all possible joint probabilities of the two variables. In this paper, we consider a and b to represent the empirical distributions of the samples X and Y, respectively, and set a = 1_n/n, b = 1_m/m.

Given a cost matrix C ∈ R^{n×m}, the classical optimal transport plan between a and b is defined as P* = argmin_{P ∈ U(a, b)} ⟨P, C⟩, where ⟨·, ·⟩ denotes the Frobenius inner product.

P* is also called a coupling. Despite its intuitive formulation, the computation of this linear program quickly becomes prohibitive, especially in the common situation when n and m, the sizes of the datasets, exceed a few hundred. For computational efficiency, the Sinkhorn-Knopp iteration is widely used to compute the optimal transport cuturi2013sinkhorn. Sinkhorn-Knopp is a fast iterative algorithm for approximately solving the optimization problem with entropy regularization [Santambrogio2015]:

P_ε = argmin_{P ∈ U(a, b)} ⟨P, C⟩ − ε H(P),

where ε > 0 is a regularization parameter and H(P) = −Σ_{ij} P_{ij} log P_{ij} denotes the entropy. The regularized solution converges to the classical one as the regularization diminishes, i.e., ε → 0, with an exponential convergence rate cominetti1994asymptotic. The transport P_ε treats X and Y symmetrically.
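The Sinkhorn-Knopp update can be sketched in a few lines. The following is a minimal pure-Python illustration (not the paper's implementation), alternately rescaling rows and columns of the kernel matrix exp(−C/ε) so that the plan matches the target marginals:

```python
import math

def sinkhorn(C, a, b, eps=0.1, n_iters=200):
    """Approximate the entropy-regularized optimal transport plan.

    C : n x m cost matrix (list of lists)
    a, b : marginal probability vectors of lengths n and m
    Returns an n x m plan P whose row sums ~ a and column sums ~ b.
    """
    n, m = len(a), len(b)
    # Gibbs kernel: small cost -> large kernel weight
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to hit the marginals a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
```

As the text notes, shrinking eps toward 0 drives this regularized plan toward the classical (unregularized) optimal transport plan.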

The cost matrix C is computed from the cost function c as C_{ij} = c(x_i, y_j). A good choice of the cost function is the key to influencing the learned mapping P_ε. However, reliable cost functions are typically not available, and the Euclidean cost is commonly used by default. In this paper, the cost function is instead adaptively learned using side information about the data.

2.2 Side Information

Subset correspondence describes a common situation when certain subsets of points are known to be related. For example, images with the same objects should always be mapped together in domain adaptation tasks courty2017optimal, while cells in a single-cell dataset need to be aligned to those with the same cell type, where the cell type annotation is available.

Given K corresponding subsets, we write S_k ⊂ {1, …, n} and S'_k ⊂ {1, …, m}, k = 1, …, K, to denote the sets of data indices in the corresponding subsets, i.e., {x_i : i ∈ S_k} corresponds to {y_j : j ∈ S'_k}. Note that S_k and S'_k could have different probability mass. If the mass of S_k is no larger than that of S'_k, we take this side information to be that S_k should be mapped into S'_k. In other words, all the entries of the transport matrix P that map S_k to outside of S'_k should be 0. Everything is swapped if the mass of S_k is larger. Mathematically, this side information corresponds to the constraint that

P_{ij} = 0  for all i ∈ S_k, j ∉ S'_k,  k = 1, …, K.    (3)
Note that pair matching is an extreme case of subset correspondence with subset size 1 — that is, the exact pairwise relation is known. The pair matching problem has been addressed in literature li2019learning, galichon2010matching. However, in practice, exact one-to-one matching labels are often expensive to obtain or even intractable. In this paper, we show that subset correspondence, as a small amount of side information, can significantly aid cost function learning.
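For illustration, the degree to which a transport plan violates one corresponding-subset pair can be measured directly. The helper below (a hypothetical name, not from the paper) sums the mass sent from the first subset to points outside its counterpart:

```python
def leaked_mass(P, subset_x, subset_y):
    """Transported mass that violates one subset correspondence:
    mass sent from indices in subset_x to columns *outside* subset_y.

    P : n x m transport plan (list of lists)
    subset_x, subset_y : iterables of row / column indices
    """
    target = set(subset_y)
    return sum(P[i][j] for i in subset_x
               for j in range(len(P[0])) if j not in target)
```

A plan satisfies the side information for that pair exactly when this quantity is zero.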

We investigate how to learn the cost function using subset correspondence information between training datasets and . The learned cost function is evaluated via the mapping quality on the test datasets and . Note that the training and test sets are not necessarily under the same distribution. We demonstrate the power of OT-SI in generalizing to new subsets that were not seen in the training process.

2.3 The OT-SI Algorithm

Our ultimate goal is to learn a cost function such that the computed optimal transport satisfies the side information given in Eq. (3) as faithfully as possible.

Cost function parametrization.

When the cost function is the (squared) Euclidean distance, the entry of the cost matrix is computed as C_{ij} = Σ_{l=1}^{d} (x_{i,l} − y_{j,l})², where d is the data dimension. To learn the cost function systematically, we parametrize it as c_θ(x, y) with weights θ. Here the functional form can be chosen by the user. To illustrate the improvement over the commonly used Euclidean cost, we parametrize c_θ as a polynomial in (x, y) with coefficients θ and degree 2 for low-dimensional data. The Euclidean cost is then equivalent to a specific choice of θ, which is set as the initialization; see the Appendix for more discussion. For high-dimensional data, the memory required to store the second-order polynomial coefficients becomes too large, and we instead use a fully connected neural network to parametrize c_θ, with input (x, y) and weights θ. Throughout this paper, the polynomial parametrization is used if not specified otherwise.
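As a concrete sketch of the polynomial parametrization (illustrative code, not the paper's), the weights below reproduce the squared Euclidean cost, which can serve as the initialization described above; only the quadratic terms are populated, since the squared Euclidean cost needs no linear or constant terms:

```python
def squared_euclidean_theta(d):
    """Quadratic weights theta for a degree-2 polynomial cost on the
    concatenation z = (x, y) that reproduce sum_l (x_l - y_l)^2."""
    theta = {}  # keys: coordinate pairs (u, v) of z; values: coefficients
    for l in range(d):
        theta[(l, l)] = 1.0          # x_l^2
        theta[(d + l, d + l)] = 1.0  # y_l^2
        theta[(l, d + l)] = -2.0     # cross term -2 x_l y_l
    return theta

def poly_cost(x, y, theta):
    """Degree-2 polynomial cost c_theta(x, y) on the concatenation z = (x, y)."""
    z = list(x) + list(y)
    return sum(w * z[u] * z[v] for (u, v), w in theta.items())
```

Learning then amounts to adjusting these coefficients away from the Euclidean initialization.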

The optimal transport solution is characterized by θ as

P_ε(θ) = argmin_{P ∈ U(a, b)} ⟨P, C_θ⟩ − ε H(P).

Then the problem can be formulated as optimizing θ to make the transport P_ε(θ) approximately satisfy the conditions defined in Eq. (3), penalizing the mass transported in violation of them with the loss

L(θ) = Σ_k Σ_{i ∈ S_k} Σ_{j ∉ S'_k} [P_ε(θ)]_{ij},    (5)

where S_k and S'_k index the k-th pair of corresponding subsets.

Theorem 1.

For any ε > 0, the optimal transport plan P_ε, viewed as a function of the cost weights θ, is infinitely differentiable, and P_ε lies in the interior of its domain U(a, b).

The infinite differentiability of the Sinkhorn distance is previously known luise2018differential; Thm. 1 proves that the Sinkhorn transport plan also has this desirable property. Because Thm. 1 guarantees that P_ε(θ) is infinitely differentiable, we are able to optimize the loss in Eq. (5) by gradient descent. In practice, we iterate Sinkhorn's update a sufficient number of times to converge to P_ε(θ). Each iteration is a matrix operation involving the cost matrix C_θ, and when the number of iterations is fixed, we can propagate the gradient with respect to θ through all of the iterations using the chain rule. Updating θ by one forward and backward pass has, up to logarithmic terms, the same complexity as the Procrustes-based OT methods, which alternately optimize over the coupling matrix and a linear transformation matrix grave2019unsupervised, alvarez2019towards. To further boost the performance, we propose to use a mimic learning method for initialization, which does not need to propagate the gradient. The pseudocode for OT-SI is in Algorithm 1 and details are in the Appendix. The proof of Thm. 1, the mimic learning algorithm, and discussions about convergence are also in the Appendix.

Input: training datasets X and Y; corresponding subset indices S_k and S'_k (k = 1, …, K); step size η; total training steps T; initial weights θ from the initialization procedure; Sinkhorn regularization parameter ε; number of Sinkhorn iterations L.
1: for t = 1 to T do
2:     Compute the cost matrix C_θ with entries [C_θ]_{ij} = c_θ(x_i, y_j).
3:     Solve P_ε(θ) with Sinkhorn's update for L iterations.
4:     Derive the gradient of the loss with respect to θ by backpropagating through all L Sinkhorn-Knopp iterations.
5:     Update the weights θ by a gradient step with step size η.
6: end for
Output: learned weights θ.
Algorithm 1 OT-SI
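To make the training loop concrete, here is a self-contained toy sketch. The one-parameter cost family c_θ(x, y) = (θx − y)², the data, and the backtracking step rule are all illustrative assumptions, and finite-difference gradients stand in for backpropagating through Sinkhorn (the paper uses PyTorch autograd); the loss is the mass leaking outside the corresponding subsets:

```python
import math

def sinkhorn_plan(C, eps=0.5, n_iters=150):
    """Entropy-regularized OT plan with uniform marginals (Sinkhorn-Knopp)."""
    n, m = len(C), len(C[0])
    a, b = [1.0 / n] * n, [1.0 / m] * m
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # max(..., tiny) guards against an underflowed all-zero row/column
        u = [a[i] / max(sum(K[i][j] * v[j] for j in range(m)), 1e-300)
             for i in range(n)]
        v = [b[j] / max(sum(K[i][j] * u[i] for i in range(n)), 1e-300)
             for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

def loss(theta, X, Y, subsets):
    """Mass that the plan under cost (theta*x - y)^2 sends outside the
    corresponding subsets -- an illustrative stand-in for the training loss."""
    C = [[(theta * x - y) ** 2 for y in Y] for x in X]
    P = sinkhorn_plan(C)
    bad = 0.0
    for S, Sp in subsets:
        keep = set(Sp)
        bad += sum(P[i][j] for i in S for j in range(len(Y)) if j not in keep)
    return bad

# Toy data: Y is X scaled by 2; side information says each half of X
# corresponds to the matching half of Y.
X = [0.0, 0.1, 1.0, 1.1]
Y = [0.0, 0.2, 2.0, 2.2]
subsets = [([0, 1], [0, 1]), ([2, 3], [2, 3])]

theta, h = 0.2, 1e-4
cur = loss(theta, X, Y, subsets)
for _ in range(30):
    # Finite differences stand in for the chain rule through Sinkhorn.
    g = (loss(theta + h, X, Y, subsets) - loss(theta - h, X, Y, subsets)) / (2 * h)
    step = 0.1
    for _ in range(8):  # backtracking: accept a step only if the loss drops
        new = loss(theta - step * g, X, Y, subsets)
        if new < cur:
            theta, cur = theta - step * g, new
            break
        step *= 0.5
```

The loop never increases the loss by construction; with a real autograd framework, steps 2-5 of Algorithm 1 replace the finite-difference gradient.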

2.4 Experiments Setup

The OT-SI algorithm is implemented in PyTorch paszke2017automatic and trained on GPU. The model is fitted on the training set and evaluated on the test set. We use a validation set for hyperparameter selection and early stopping. We evaluate OT-SI with different types of data and correspondence information.

Comparison methods.

We use optimal transport with the Euclidean cost function as a baseline for comparison, referred to as "OT-baseline". We also compare our results with state-of-the-art GAN-based data alignment methods, MAGAN amodio2018magan and CycleGAN zhu2017unpaired, as well as OT-based methods: RIOT li2019learning, which is developed for specific pair matching applications, and Procrustes-based OT grave2019unsupervised, referred to as "OT-Procrustes". For MAGAN and CycleGAN, the matching point for a source sample is set as its nearest neighbor in the target after mapping. Among the five comparison methods, OT-baseline, OT-Procrustes and CycleGAN do not use any side information; MAGAN makes use of matching pairs; RIOT requires one-to-one matching labels for all the datapoints. Because MAGAN and RIOT require pairwise correspondence, they are not applicable in some experiments; these are marked as N/A. We use the same settings and hyperparameters for the comparison methods as in their original implementations.

Evaluation metrics.

When the subset correspondence is known on the test set (but not shown to the algorithm), we evaluate a transport plan by how well it satisfies the correspondence. Concretely, each source point is matched to the target point receiving the most mass under the plan, and the subset matching accuracy Acc is the fraction of points whose match lies in the correct corresponding subset:

Acc = (1/n) Σ_k Σ_{i ∈ S_k} 1[ argmax_j P_{ij} ∈ S'_k ].

From the definition, Acc gives the probability of mapping to the correct corresponding subsets. When all the test datapoints are mapped into the correct subsets, the accuracy is 1; when all the data are matched to the wrong subsets, the accuracy is 0. As an extreme case, pair matching is equivalent to subset correspondence with subset sizes 1, and we refer to the corresponding metric as pair matching accuracy. In the next few sections, we thoroughly evaluate OT-SI and several state-of-the-art methods in extensive and diverse experiments—aligning single-cell RNA-seq data to correct for batch effects, aligning single-cell gene expression and protein abundances, a marriage dataset, an image dataset, and the synthetic two-moon data for illustration.
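Under a hard argmax matching (matching each source point to the target point receiving the most mass, which is one reasonable reading of the metric), subset matching accuracy can be computed as:

```python
def subset_matching_accuracy(P, subsets_x, subsets_y):
    """Fraction of source points whose highest-mass match under plan P
    lands in the correct corresponding subset.

    P : n x m transport plan (list of lists)
    subsets_x, subsets_y : parallel lists of corresponding index sets
    """
    correct, total = 0, 0
    for S, Sp in zip(subsets_x, subsets_y):
        target = set(Sp)
        for i in S:
            j_best = max(range(len(P[i])), key=lambda j: P[i][j])
            correct += j_best in target
            total += 1
    return correct / total
```

With subsets of size 1, the same function computes pair matching accuracy.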

3 Benchmark on Synthetic Datasets

Table 1: Subset matching and pair matching accuracy on test data for the two-moon datasets, with side information given as subset correspondence, 1 known pair, or 10 known pairs (methods that cannot use a given form of side information are marked N/A). Here the subset (pair) matching accuracy corresponds to the proportion of the data points that are aligned to the correct moon (data points) on the test set. Higher is better. We generated 10 independent datasets and the standard deviation is shown.

Figure 1: Illustration of the two-moon datasets and the optimal transport results. (a) Data: the target domain (blue) is built by adding noise to the source (orange) and rotating by 60 degrees; corresponding subsets are denoted by circles and crosses. (b-d) Optimal transport plan under (b) Euclidean cost (OT-baseline), (c) the cost function learned by OT-SI, and (d) Procrustes-based OT. Points learned to be matched are connected by solid curves. When a datapoint is matched to the wrong subset, i.e., to the other moon, the connecting curve is colored red.

We first experiment with a benchmark toy example for domain adaptation — the two-moon datasets — to illustrate the challenges of data alignment, before we move on to complex real-world data germain2013pac, courty2017optimal. The dataset is simulated with two domains, source and target. As shown in Fig. 1(a), each domain contains two standard entangled moons. The two moons are associated with two different classes, denoted by circles and crosses respectively. The target (colored in blue) is built by adding noise to the source (colored in orange) and rotating by 60 degrees. In the experiment, we generate the training, test and validation datasets with 100, 100, and 50 samples of each moon, fix the Sinkhorn regularization parameter and the number of Sinkhorn-Knopp iterations, and run the algorithm for 100 epochs with step size 1. There are two types of side information available for OT tasks: (i) subset correspondence — datapoints are known to be mapped into the corresponding moon class; (ii) pair matching — matched datapoints after rotation are known. The results are averaged over 10 (50) independent runs when the side information is subset correspondence (pair matching).
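The two-moon setup can be reproduced with a short script. The generator below is a hypothetical stand-in (e.g. scikit-learn's make_moons would serve equally well), producing a noisy source and a 60-degree-rotated target:

```python
import math
import random

def make_moons(n_per_moon, noise=0.05, seed=0):
    """Two entangled half-circles ('moons'); returns points and moon labels."""
    rng = random.Random(seed)
    pts, labels = [], []
    for i in range(n_per_moon):
        t = math.pi * i / (n_per_moon - 1)
        # upper moon (class 0) and lower, shifted moon (class 1)
        pts.append((math.cos(t) + rng.gauss(0, noise),
                    math.sin(t) + rng.gauss(0, noise)))
        labels.append(0)
        pts.append((1 - math.cos(t) + rng.gauss(0, noise),
                    0.5 - math.sin(t) + rng.gauss(0, noise)))
        labels.append(1)
    return pts, labels

def rotate(points, degrees):
    """Rotate 2-D points about the origin."""
    r = math.radians(degrees)
    c, s = math.cos(r), math.sin(r)
    return [(c * x - s * y, s * x + c * y) for x, y in points]

source, moon_labels = make_moons(100)
target = rotate(make_moons(100, seed=1)[0], 60)  # noisy copy, rotated 60 degrees
```

The moon labels then provide the subset correspondence side information for training.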

Baseline performance.

The optimal transport plan under the Euclidean cost function is depicted in Fig. 1(b). Datapoints learned to be matched are connected by solid curves. The red curves indicate wrong transports, which map the data into the wrong subset, i.e., the other moon. As shown by Fig. 1(b), most wrong transports are between the points at the edges of the moons. In the Euclidean space, the edge of one moon becomes "closer" to the other moon after rotation, which leads to a small cost between datapoints in different moon classes. A new cost function which captures the rotation property is needed. Quantitatively, only a small fraction of the data are mapped into the correct moon, and even fewer are matched to their corresponding pairs, as given in Table 1.

Subset correspondence.

We first evaluate our method OT-SI when the side information is only that the data on the corresponding moons are known to be related—i.e., the two moons in dataset one correspond to the two moons in dataset two. With the learned cost function, almost all the datapoints are mapped into the corresponding moon, i.e., the subset matching accuracy on the test data is near perfect for both methods, as shown in Table 1. The results are averaged over 10 independent runs. Interestingly, although no pair matching information is provided during training, the learned cost function significantly improves the pair matching performance: as shown in Table 1, many datapoints are transported onto their exact matching points after rotation. The learned mapping of OT-SI is depicted in Fig. 1(c); the rotation property of the datasets is correctly captured. In contrast, OT-Procrustes sometimes learns as well as OT-SI, similar to Fig. 1(c), but sometimes mistakenly learns the rotation as a flip, as in Fig. 1(d). This results in overall worse accuracy and larger variance for OT-Procrustes.

Pair matching.

OT-SI demonstrates substantial improvement when pair matching information is provided—only 1 or 10 pairs are known out of the total 100 training pairs. The matching pairs are randomly selected from the training data. OT-SI significantly outperforms all four benchmark methods, particularly when the number of known pairs is very limited. Even when only 1 matching pair is provided, the learned cost function greatly improves the OT performance, as given in Table 1. The improvement here is largely attributable to the unlabeled data, i.e., the datapoints without any pair matching information. For comparison, we carry out another experiment with only 3 unlabeled datapoints, with all other settings unchanged. The algorithm is then not able to learn the right cost function — the test pair matching accuracy after learning is even worse than the Euclidean baseline. With 199 unlabeled datapoints, the accuracy is substantially higher. In contrast, the competing methods MAGAN and RIOT learn barely any patterns, because too few labeled datapoints are available and the unlabeled ones are wasted.

4 Biological Manifold Alignment

In this section, we apply our method OT-SI to learn a cost function that aligns biological manifolds with partial supervision — annotations of some cell types or clusters, which is the common situation in biological studies. The pair matching methods, MAGAN and RIOT, are not compared here because cell-to-cell matching information is not available. As in Sec. 3, CycleGAN does not learn correctly for these data types and is not presented here. We show that OT-SI yields substantial improvement for aligning datasets with different data types and for aligning data from different batches.

4.1 Alignment of Protein and RNA sequencing data

How to align datasets with different data types has been a major topic in many fields. For example, in single-cell studies, RNA and protein sequencing can both be done at cellular resolution. How to map between those two types of data, i.e., map cells with certain mRNA level to cells with certain protein level, becomes critical for downstream studies such as RNA and protein co-expression analysis.

We demonstrate the power of OT-SI in learning a cost function that aligns two different data types — RNA and protein expression in the CITE-seq cord blood mononuclear cells (CBMCs) experiments stoeckius2017simultaneous. The dataset is subset to 8,005 human cells with the expression of 5,001 highly variable genes and 13 proteins. Fifteen clusters are identified using Louvain modularity clustering. The CITE-seq technology has enabled the simultaneous measurement of RNA and protein expression at the single-cell level, hence the ground truth about the cell pairing is available. To emulate the common situation, we only use the cluster correspondence information in training, and report the performance of both subset (cluster) matching and pair matching at test time. We randomly sampled 500 and 500 cells for validation and test purposes. When OT-SI is learned in the original data space, with the expression of 5,001 mRNAs and 13 proteins for each cell, we use a fully connected neural network to parametrize the cost function, with two hidden layers of 100 and 5 neurons.

We align the RNA and protein expression datasets in two scenarios: i) the embedding space, where we use the first 10 PCs of both datasets; ii) the original expression space. Table 2 shows that OT-SI substantially outperforms the other methods. OT-SI is able to learn a good cost function even when the dimensions of the two datasets are as unbalanced as 5001:13. The learned cost metrics in the original expression space can be used for future biological analysis of the effect of and relation between RNA and protein expression. Although the single-cell sequencing data are noisy and the algorithm is not informed of any matching pair during training, the test pair matching accuracy is also improved.

                  Embedding Space (10:10)                  Original Space (5001:13)
                  OT-Baseline  OT-SI  OT-Procrustes        OT-Baseline  OT-SI  OT-Procrustes
Subset Matching      9.9%      56.1%      44.3%                N/A      43.9%      N/A
Pair Matching        0%         3.2%       1.0%                N/A       0.8%      N/A
Table 2: Subset matching and pair matching test accuracy for the alignment of protein and RNA expression data in the CITE-seq CBMCs experiment. Here subset (pair) matching accuracy denotes the proportion of the cells that are aligned to the correct cluster (cells). Higher is better. For alignment in the original space, the expression of 5,001 mRNAs is mapped to the expression of 13 proteins. For alignment in the embedding space, the first 10 principal components are used for both the RNA and protein expression datasets.

4.2 Batch Alignment

#Training Cell Types      2       5       8      0 (OT-Baseline)
OT-SI                   70.0%   75.0%   83.8%        70.0%
OT-Procrustes           36.9%   56.8%   61.5%          —
Table 3: Subset matching accuracy on the two held-out cell types, T cell and immature T cell, for aligning FACS and droplet data. The algorithms OT-SI and OT-Procrustes are trained on 2, 5 and 8 other cell types; the column with 0 training cell types corresponds to the OT-Baseline. Higher is better.

Another fascinating biological application of optimal transport is aligning data from different batches. In biological studies, samples processed or measured in different batches usually exhibit non-biological variations, known as batch effects chen2011removing. Here we use OT to align two batches of data (https://github.com/czbiohub/tabula-muris-vignettes/blob/master/data) — data collected with fluorescence activated cell sorting (FACS) and droplet methods. In this case, the cell types are used as the subset correspondence information. For illustration purposes, we subsample the top 10 cell types with 400 samples of each. There are 1,682 genes after filtering and the first 10 principal components are used for analysis. The dataset is split into training, validation and test sets with ratios 50%, 20% and 30%. The experiment setting is otherwise the same as in Sec. 3.

We demonstrate the power of our learned metric in generalizing to entirely new cell types that were not used to train the cost. This is a hard task (a zero-shot learning task), and it is more realistic because in most settings we only have partial annotations for cell types and we would like the mapping to generalize to all of the data. To do this, we choose two cell types — T cell and immature T cell — as held-out and train on the rest. Among all the ground-truth expert annotations, T cell and immature T cell are the most difficult to align. With OT-baseline, only a fraction of these cells are mapped into the correct cell types. Substantial improvement is achieved, as shown in Table 3. From a small number of annotated cell types, OT-SI is able to learn a transport cost that captures the batch artifacts between FACS and droplet data and that generalizes to mapping these two new cell types. We anticipate future uses of our formulation to further investigate the cost function, particularly in biological discovery applications like isolating genes that mark single-cell heterogeneity.

5 Experiments on Images and Tabular Data Alignment

5.1 Image Alignment for New Digits

                 watering   swirl   sphere    flip
OT-baseline       70.0%     62.5%   57.5%    35.0%
OT-SI               —         —       —        —
OT-Procrustes       —         —       —        —
Table 4: Subset matching accuracy on the two held-out digits 3 and 5 when our model is trained on the remaining eight digits of the MNIST dataset. OT is used to align the original images with the perturbed ones (watering, swirl, sphere, flip).

To illustrate the use of OT-SI in image alignment, we use it to learn cost metrics for aligning images with partial annotation on the MNIST dataset, which contains 28 × 28 images of handwritten digits. We subsample 200 images from each digit class and split them into training, validation, and test sets with ratios 50%, 20% and 30%. For illustration purposes, we use the first ten principal components for the alignment analysis. We generate four different types of perturbations of the original images, as listed in Table 4, and align the original images with each set of perturbed ones using OT. The experiment setting is the same as in Sec. 3. As in Sec. 4, we test how well our algorithm generalizes to new classes that were not used in learning the cost function: we hold out digits 3 and 5, and show that the metric learned on the other eight digits helps the alignment of digits 3 and 5. We achieve consistent improvement over the baseline on all four perturbation types, as indicated by Table 4.

5.2 Marriage Dataset for Pair Matching

While OT-SI is designed for the more general form of side information — subset correspondence — it can also be used for pair matching. We finally benchmark it against other state-of-the-art pair matching methods, including RIOT, the factorization machine model (FM) rendle2012factorization, the probabilistic matrix factorization model (PMF) mnih2008probabilistic, the item-based collaborative filtering model (itemKNN) cremonesi2010performance, the classical SVD model koren2009matrix, and a baseline random predictor model. These methods were also used as comparisons in [li2019learning]. We follow the same experimental protocol as in [li2019learning] for the Dutch Household Survey (DHS) dataset. The exact matching matrix between 50 datapoints with 11 features is known. We note that the coupling matrix is treated as continuous here and OT-Procrustes is not applicable.

The performance is evaluated by the root mean square error (RMSE) and the mean absolute error (MAE) of the predicted matching matrix, as given in Table 5. When used for pair matching, OT-SI reports performance comparable to state-of-the-art matching algorithms. Note that this marriage dataset was a primary motivating dataset used to design RIOT [li2019learning], and we therefore expect RIOT to perform very well on this task.

RMSE   54.7   77.8   109.0   2.4   2.4   9.5   2.4
MAE    36.5   36.1    62.0   1.6   1.5   7.5   1.5
Table 5: Root mean square error (RMSE) and mean absolute error (MAE) of the compared pair matching algorithms on the marriage-matching (DHS) dataset. Lower is better.

6 Discussion

In this paper, we study the problem of learning the transport cost using side information in the form of a small number of corresponding subsets between the two datasets. To the best of our knowledge, this is a new problem formulation. Previous works rely on more restrictive information, such as knowing that specific pairs of points should be aligned. In settings such as genomics and imaging, it is often difficult to say that a single point in one dataset should be mapped onto a particular point in the other. It is more common to have partial annotation of subsets of points, e.g., T cells annotated in two single-cell RNA-seq datasets, which motivates our generalization.

We propose a flexible and principled method to learn the transport cost with side information. Experiments demonstrate that it works significantly better than state-of-the-art methods when the side information is very limited, which is often the case. We compare against state-of-the-art methods for the special case where the side information consists of matching pairs, since we are not aware of other published OT methods that handle the more general subset correspondence. One interesting reason for the improved performance is that, by learning the transport cost directly, our algorithm makes more efficient use of the unannotated datapoints that are not in any pairs or subsets. These unannotated data act as regularization (similar to semi-supervised learning), which helps the model avoid overfitting to the limited side information. An interesting direction for future work is to interpret the learned cost function for insights into how the datasets differ.


Appendix A Cost function

In Sec. 2.3, the cost function is parametrized as $c_\theta(x, y)$ with weight $\theta$. For low-dimensional data, we choose $c_\theta$ as a polynomial in $x$ and $y$ with coefficients $\theta$ and degree 2:

$$c_\theta(x, y) = \sum_{i,j} \theta^{xy}_{ij}\, x_i y_j + \sum_{i,j} \theta^{xx}_{ij}\, x_i x_j + \sum_{i,j} \theta^{yy}_{ij}\, y_i y_j + \sum_i \theta^{x}_i x_i + \sum_j \theta^{y}_j y_j + \theta_0.$$

Here we do not require the dimensions of the two datasets $X$ and $Y$ to be the same. When the two datasets $X$ and $Y$ have different features, the learned weights $\theta^{xy}$ indicate the coupling between different features in the data mapping. In general, the functional form can be chosen by users. We have also investigated parametrizing $c_\theta$ as a small fully connected neural network and achieved very similar performance.

The Euclidean cost is equivalent to a specific choice of the weights $\theta$ when the two datasets share the same feature space: the $(i, j)$ entry of the cost matrix is $C_{ij} = \|x_i - y_j\|^2 = \|x_i\|^2 - 2\, x_i^\top y_j + \|y_j\|^2$, a degree-2 polynomial in $(x_i, y_j)$.
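As an illustration, here is a minimal NumPy sketch of one such degree-2 parametrization (the weight names and the omission of linear/constant terms are our simplifications, not the paper's exact form), together with a check that a particular weight choice recovers the squared-Euclidean cost:

```python
import numpy as np

def poly_cost_matrix(X, Y, Wxy, Wxx, Wyy):
    """Degree-2 polynomial cost with learnable weights.

    Wxy couples features across the two datasets; Wxx and Wyy weight
    quadratic terms within each dataset. X: (n, dx), Y: (m, dy).
    """
    cross = X @ Wxy @ Y.T                      # sum_ij Wxy_ij x_i y_j
    qx = np.einsum('ni,ij,nj->n', X, Wxx, X)   # x^T Wxx x per point
    qy = np.einsum('mi,ij,mj->m', Y, Wyy, Y)   # y^T Wyy y per point
    return cross + qx[:, None] + qy[None, :]

# The squared-Euclidean cost ||x - y||^2 is one specific weight choice:
# Wxy = -2 I, Wxx = Wyy = I (all other coefficients zero).
rng = np.random.default_rng(0)
X, Y = rng.normal(size=(5, 3)), rng.normal(size=(4, 3))
d = X.shape[1]
C = poly_cost_matrix(X, Y, -2 * np.eye(d), np.eye(d), np.eye(d))
C_euc = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
assert np.allclose(C, C_euc)
```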

Appendix B OT-SI Algorithm

The Lagrangian dual of Eq. (4) is

$$\max_{\alpha, \beta}\; \alpha^\top a + \beta^\top b - \epsilon \sum_{i,j} \exp\!\big((\alpha_i + \beta_j - C_{ij})/\epsilon\big). \qquad (8)$$

By Sinkhorn’s scaling theorem [sinkhorn1967concerning], the optimal transport plan is computed as

$$P^*_{ij} = \exp(\alpha_i/\epsilon)\, \exp(-C_{ij}/\epsilon)\, \exp(\beta_j/\epsilon), \qquad (9)$$

where $\alpha$ and $\beta$ are the solutions to the dual problem in Eq. (8).
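For concreteness, the scaling iterations behind Eq. (9) can be sketched as below (a textbook Sinkhorn implementation; the value of the regularization strength and the fixed iteration count are illustrative choices, not the paper's settings):

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.5, n_iter=200):
    """Entropy-regularized OT: returns P* = diag(u) K diag(v), K = exp(-C/eps)."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)        # rescale to match the row marginals a
        v = b / (K.T @ u)      # rescale to match the column marginals b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n, m = 6, 8
C = rng.random((n, m))
a, b = np.full(n, 1 / n), np.full(m, 1 / m)
P = sinkhorn(C, a, b)
# The converged plan satisfies both marginal constraints.
assert np.allclose(P.sum(1), a) and np.allclose(P.sum(0), b)
```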

B.1 Proof of Theorem 1

The analysis for Theorem 1 follows the strategy of the proof of Theorem 2 in [luise2018differential].

Theorem 1.

For any $\epsilon > 0$, the optimal transport plan $P^*: \theta \mapsto P^*(\theta)$ is smooth at every $\theta$ in the interior of its domain $\Theta$.

Proof. Based on Eq. (9), the optimal transport plan $P^*$ is a smooth function of $\theta$ whenever $\alpha$, $\beta$ and $C_\theta$ are smooth in $\theta$. In the meantime, the cost $C_\theta$ is a linear function of $\theta$, as indicated by Eq. (7). Thus, to prove the smoothness of $P^*$, we only need to demonstrate that $\alpha$ and $\beta$ are both smooth in $\theta$.

Here we define

$$H(\alpha, \beta; \theta) = -\alpha^\top a - \beta^\top b + \epsilon \sum_{i,j} \exp\!\big((\alpha_i + \beta_j - C_\theta(x_i, y_j))/\epsilon\big).$$

The dual problem in Eq. (8) becomes $\min_{\alpha, \beta} H(\alpha, \beta; \theta)$. From the definition, $H$ is smooth and strictly convex in $(\alpha, \beta)$, and $C_\theta$ is linear in $\theta$. Then for any fixed $\theta_0$ in the interior of $\Theta$, there exists $(\alpha^*, \beta^*)$ such that $\nabla_{(\alpha,\beta)} H(\alpha^*, \beta^*; \theta_0) = 0$, and the map $\nabla_{(\alpha,\beta)} H$ is smooth due to the smoothness of $H$. We now fix such a triple $(\alpha^*, \beta^*, \theta_0)$. The strict convexity of $H$ in $(\alpha, \beta)$ ensures that the Hessian $\nabla^2_{(\alpha,\beta)} H(\alpha^*, \beta^*; \theta_0)$ is invertible.

By the implicit function theorem, we can find a function $\varphi$ and a neighborhood $U \subseteq \Theta$ of $\theta_0$ such that i) $\varphi(\theta_0) = (\alpha^*, \beta^*)$; ii) $\nabla_{(\alpha,\beta)} H(\varphi(\theta); \theta) = 0$ for any $\theta \in U$; iii) $\varphi$ is smooth on $U$. That is, $\varphi(\theta)$ is a stationary point of $H(\cdot\,; \theta)$ for every $\theta \in U$. Together with the strict convexity of $H$, we derive that $\varphi(\theta) = (\alpha(\theta), \beta(\theta))$ is the minimizer of the dual, so $\alpha$ and $\beta$ are smooth in $\theta$. Recalling Eq. (9), we conclude that $P^*$ is smooth at every $\theta$ in the interior of its domain. ∎

B.2 Convergence properties

The gradient of $P^*$ in Eq. (9) with respect to $\theta$ is computed as

$$\nabla_\theta P^*_{ij} = \frac{P^*_{ij}}{\epsilon}\big(\nabla_\theta \alpha_i + \nabla_\theta \beta_j - \nabla_\theta C_{ij}\big).$$

The convergence of the Sinkhorn scaling factors is linear (i.e., the error decays exponentially in the number of iterations), with the rate bounded as given by [franklin1989scaling]; in our experiments, we do not find the Sinkhorn iterations to be a bottleneck for convergence. Beyond this point, convergence is determined by the loss landscape in $\theta$. Analyzing convergence to the global optimum, and the role of the primal (entropic) regularization, remains a relevant open question.
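The linear (geometric) convergence of the scaling iterations can be observed empirically; the sketch below (problem sizes and regularization strength are arbitrary choices of ours) tracks the row-marginal violation across sweeps:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 10, 12
C = rng.random((n, m))
a, b = np.full(n, 1 / n), np.full(m, 1 / m)

eps = 0.5
K = np.exp(-C / eps)
u, v = np.ones(n), np.ones(m)
errs = []
for _ in range(60):
    u = a / (K @ v)
    v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]
    errs.append(np.abs(P.sum(1) - a).sum())  # row-marginal violation

# Geometric decay: the violation shrinks by a roughly constant factor per sweep.
assert errs[15] < errs[5] < errs[0]
assert errs[-1] < 1e-9
```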

Appendix C Mimic Learning for Initialization

In this section, we derive mimic learning, a fast initialization method that boosts performance and accelerates learning. While Algorithm 1 requires differentiating through the Sinkhorn updates, the mimic learning approach does not need to propagate gradients through all the iterations and is applicable to any kind of OT algorithm. Here we take the classical optimal transport plan as an example.

As discussed in Sec. 2.3, our ultimate goal is to learn a cost function such that the optimal transport satisfies the side information defined in Eq. (3) as faithfully as possible. From another perspective, we can enforce an additional set of constraints on the transport plan to fulfill the condition in Eq. (3):

$$U_s(a, b) = \Big\{ P \in U(a, b) \,:\, P_{ij} = 0 \ \text{for all}\ i \in S_k,\ j \notin S'_k,\ k = 1, \dots, K \Big\}. \qquad (12)$$
To quantify how much the learned $P^*$ in Eq. (2) follows the side information, we compare it with

$$P_s^* = \arg\min_{P \in U_s(a, b)} \langle P, C_\theta \rangle. \qquad (13)$$

Here $P_s^*$ is interpreted as the optimal transport plan when the side information is completely satisfied, and $\langle P_s^*, C_\theta \rangle$ is the smallest transport distance under the constraint. With the cost function parametrized by $\theta$, the optimal transport solution in Eq. (13) is characterized by $\theta$ as $P_s^*(\theta)$.

We then expect a good cost function to make the distance under the constraint as close as possible to the unconstrained optimum; that is, we optimize $\theta$ to minimize the loss

$$L(\theta) = \langle P_s^*(\theta), C_\theta \rangle - \langle P^*(\theta), C_\theta \rangle. \qquad (14)$$
We refer to this method as mimic learning, because its objective is to make the constrained plan $P_s^*(\theta)$ mimic the cost performance of the unconstrained optimum $P^*(\theta)$.

Note that $P^*$ is the optimal solution over all transport matrices in $U(a, b)$. That is, $\langle P^*, C_\theta \rangle \le \langle P, C_\theta \rangle$ for any $P \in U(a, b)$, with equality only when $P = P^*$ for the convex transport problem. In the meantime, we have $U_s(a, b) \subseteq U(a, b)$. Thus the loss is always greater than or equal to 0. When zero loss is achieved, we have $P_s^* = P^*$, coinciding with the optimal solution for $\theta$ in Eq. (4).

Equation (14) describes the absolute difference between the two transport distances, but a relative difference is more desirable in practice to adjust for the scale of the objective function. For example, scaling the cost matrix $C_\theta$ by a constant does not change the solutions $P^*$ and $P_s^*$, but does scale the loss defined in Eq. (14) by the same constant. We modify the loss to be invariant to such scaling:

$$L(\theta) = \frac{\langle P_s^*(\theta), C_\theta \rangle - \langle P^*(\theta), C_\theta \rangle}{\langle P_u, C_\theta \rangle - \langle P^*(\theta), C_\theta \rangle}. \qquad (15)$$

Here $P_u = a b^\top$ is a uniform matrix that stands for the average performance of random transport plans. Eq. (15) captures how close the distance under the constraint is to the best one, relative to random transports.
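Here is a toy sketch of the relative loss in Eq. (15). The constrained plan is approximated by running Sinkhorn on a masked kernel (one convenient way to impose subset constraints; the paper's constrained solver may differ), on an example whose unconstrained optimum already respects the side information, so the loss should be near zero:

```python
import numpy as np

def sinkhorn_plan(K, a, b, n_iter=300):
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

# Two corresponding subsets: points {0,1} <-> {0,1} and {2,3} <-> {2,3}.
# Cost is 0 within corresponding subsets and 1 across, so the unconstrained
# optimum already satisfies the side information.
C = np.array([[0., 0., 1., 1.],
              [0., 0., 1., 1.],
              [1., 1., 0., 0.],
              [1., 1., 0., 0.]])
a = b = np.full(4, 0.25)
eps = 0.05
K = np.exp(-C / eps)

P_star = sinkhorn_plan(K, a, b)          # unconstrained plan
mask = (C == 0)                          # entries allowed by the constraint
P_side = sinkhorn_plan(K * mask, a, b)   # constrained (side-information) plan
P_unif = np.outer(a, b)                  # uniform "random transport" reference

num = (P_side * C).sum() - (P_star * C).sum()
den = (P_unif * C).sum() - (P_star * C).sum()
loss = num / den                         # relative loss in the style of Eq. (15)
assert -0.01 < loss < 0.05               # side info consistent with optimum -> ~0
```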

Mimic learning is approximately solved by alternating minimization. As described in Algorithm 2, we iterate over two steps: (i) compute $P^*$ and $P_s^*$ while fixing $\theta$; (ii) take one gradient step with respect to $\theta$ with $P^*$ and $P_s^*$ fixed. The computation of the optimal transport plans and the optimization of $\theta$ are thus carried out in alternating fashion.

1: Input: training datasets $X$ and $Y$, corresponding subset indices $S_k$ and $S'_k$ ($k = 1, \dots, K$), step size $\eta$, total steps $T$, optimal transport solver OTSolver.
2: Output: learned weights $\theta$.
3: Define the transport polytopes $U(a, b)$ and $U_s(a, b)$ from the constraints in Eq. (12).
4: Initialize $\theta$ such that the cost function $C_\theta$ is equivalent to the Euclidean cost.
5: for $t = 1$ to $T$ do
6:     Compute the cost matrix $C_\theta$ as $(C_\theta)_{ij} = c_\theta(x_i, y_j)$
7:     $P^* \leftarrow$ OTSolver($C_\theta$, $U(a, b)$),  $P_s^* \leftarrow$ OTSolver($C_\theta$, $U_s(a, b)$)
8:     Compute the loss $L(\theta)$ in Eq. (15)
9:     Update weights $\theta \leftarrow \theta - \eta \nabla_\theta L(\theta)$
10: end for
Algorithm 2 Mimic Learning for Initialization

The OT solver is used only to estimate the values of $P^*$ and $P_s^*$ in the first step, requiring no gradient propagation. Given such estimates of the transport mappings $P^*$ and $P_s^*$, the second step can be interpreted as learning a cost function that equates their transport costs, i.e., makes the behavior of $P_s^*$ mimic that of $P^*$. In the experiments, the total steps $T$ and step size $\eta$ are chosen for initialization purposes only.
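The alternating scheme of Algorithm 2 can be sketched end-to-end as follows. The diagonal cost parametrization, the finite-difference gradient (standing in for automatic differentiation), the learning rate, and the box projection on the weights are all simplifications of ours for this toy, not the paper's choices:

```python
import numpy as np

def sinkhorn_plan(C, a, b, eps=0.5, n_iter=200, mask=None):
    """Entropic OT plan; `mask` zeroes kernel entries forbidden by the constraint."""
    K = np.exp(-C / eps)
    if mask is not None:
        K = K * mask
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

def mimic_loss(theta, X, Y, a, b, mask):
    # Hypothetical cost c_theta(x, y) = sum_d theta_d (x_d - y_d)^2,
    # a simple stand-in for the paper's degree-2 polynomial parametrization.
    C = ((X[:, None, :] - Y[None, :, :]) ** 2 * theta).sum(-1)
    P_star = sinkhorn_plan(C, a, b)              # step (i): unconstrained plan
    P_side = sinkhorn_plan(C, a, b, mask=mask)   # step (i): constrained plan
    P_unif = np.outer(a, b)                      # random-plan reference
    return ((P_side - P_star) * C).sum() / ((P_unif - P_star) * C).sum()

rng = np.random.default_rng(0)
X, Y = rng.normal(size=(6, 2)), rng.normal(size=(6, 2))
a = b = np.full(6, 1 / 6)
mask = np.kron(np.eye(2), np.ones((3, 3)))  # subsets {0,1,2}<->{0,1,2}, {3,4,5}<->{3,4,5}
theta = np.ones(2)                          # Euclidean-equivalent initialization

lr, h = 0.2, 1e-5
for _ in range(20):                          # alternate: fix theta -> plans -> step on theta
    grad = np.array([
        (mimic_loss(theta + h * e, X, Y, a, b, mask)
         - mimic_loss(theta - h * e, X, Y, a, b, mask)) / (2 * h)
        for e in np.eye(2)])
    # step (ii): one gradient step; project weights to a box for numerical
    # stability in this toy example.
    theta = np.clip(theta - lr * grad, 0.1, 5.0)
```

In practice the gradient step would use automatic differentiation through the fixed plans rather than finite differences; the loop structure is the same.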