Graph Optimal Transport for Cross-Domain Alignment

Liqun Chen et al. · 06/26/2020

Cross-domain alignment between two sets of entities (e.g., objects in an image, words in a sentence) is fundamental to both computer vision and natural language processing. Existing methods mainly focus on designing advanced attention mechanisms to simulate soft alignment, with no training signals to explicitly encourage alignment. The learned attention matrices are also dense and lack interpretability. We propose Graph Optimal Transport (GOT), a principled framework that germinates from recent advances in Optimal Transport (OT). In GOT, cross-domain alignment is formulated as a graph matching problem, by representing entities in a dynamically-constructed graph. Two types of OT distances are considered: (i) Wasserstein distance (WD) for node (entity) matching; and (ii) Gromov-Wasserstein distance (GWD) for edge (structure) matching. Both WD and GWD can be incorporated into existing neural network models, effectively acting as a drop-in regularizer. The inferred transport plan also yields a sparse and self-normalized alignment, enhancing the interpretability of the learned model. Experiments show consistent outperformance of GOT over baselines across a wide range of tasks, including image-text retrieval, visual question answering, image captioning, machine translation, and text summarization.

1 Introduction

Cross-domain Alignment (CDA), which aims to associate related entities across different domains, plays a central role in a wide range of deep learning tasks, such as image-text retrieval (Karpathy and Fei-Fei, 2015; Lee and others, 2018), visual question answering (VQA) (Malinowski and Fritz, 2014; Antol and others, 2015), and machine translation (Bahdanau et al., 2015; Vaswani et al., 2017). Take VQA as an example. In order to understand the contexts in the image and the question, a model needs to interpret the latent alignment between regions in the input image and words in the question. Specifically, a good model should: (i) identify entities of interest in both the image (e.g., objects/regions) and the question (e.g., words/phrases), (ii) quantify both intra-domain (within the image or sentence) and cross-domain relations between these entities, and then (iii) design good metrics for measuring the quality of cross-domain alignment drawn from these relations, in order to optimize towards better results.

CDA is particularly challenging as it constitutes a weakly supervised learning task. That is, only paired sets of entities are given (e.g., an image paired with a question), while the ground-truth relations between these entities are not provided (e.g., there is no supervision signal aligning a "dog" region in an image with the word "dog" in the question). State-of-the-art methods mostly focus on designing advanced attention mechanisms to simulate soft alignment (Bahdanau et al., 2015; Xu et al., 2015; Yang et al., 2016b, a; Vaswani et al., 2017). For example, Lee and others (2018); Kim et al. (2018); Yu et al. (2019) have shown that learned co-attention can model dense interactions between entities and infer cross-domain latent alignments for vision-and-language tasks. Graph attention has also been applied to relational reasoning for image captioning (Yao et al., 2018) and VQA (Li et al., 2019a), such as the graph attention network (GAT) (Veličković et al., 2018) for capturing relations between entities in a graph via masked attention, and the graph matching network (GMN) (Li et al., 2019b) for graph alignment via cross-graph soft attention. However, conventional attention mechanisms are guided by task-specific losses, with no training signal to explicitly encourage alignment. The learned attention matrices are also often dense and uninterpretable, thus inducing less effective relational inference.

Is there a more principled approach to scalable discovery of cross-domain relations? To explore this, we present Graph Optimal Transport (GOT), a new framework for cross-domain alignment that germinates from recent advances in Optimal Transport (OT). OT-based learning aims to optimize for distribution matching by minimizing the cost of transporting one distribution to another. We extend this to CDA (here a domain can be language, images, videos, etc.). The transport plan is thus redefined as transporting the distribution of embeddings from one domain (e.g., language) to another (e.g., images). By minimizing the cost of the learned transport plan, we explicitly minimize the embedding distance between the domains, i.e., we optimize towards better cross-domain alignment.

Specifically, we convert entities (e.g., objects, words) in each domain (e.g., image, sentence) into a graph, where each entity is represented by a feature vector, and the graph representations are recurrently updated via graph propagation. Cross-domain alignment can then be formulated as a graph matching problem, and addressed by calculating matching scores based on graph distance. In our GOT framework, we utilize two types of OT distance: (i) Wasserstein distance (WD) (Peyré et al., 2019), applied to node (entity) matching; and (ii) Gromov-Wasserstein distance (GWD) (Peyré et al., 2016), adopted for edge (structure) matching. WD only measures the distance between node embeddings across domains, without considering the topological information encoded in the graphs. GWD, on the other hand, compares graph structures by measuring the distance between a pair of nodes within each graph. When fused together, the two distances allow the proposed GOT framework to take both node and edge information into account for better graph matching.

The main contributions are summarized as follows. (i) We propose Graph Optimal Transport (GOT), a new framework that tackles cross-domain alignment by adopting Optimal Transport for graph matching. (ii) GOT is compatible with existing neural network models, acting as an effective drop-in regularizer to the original objective. (iii) To demonstrate the versatile generalization ability of the proposed approach, we conduct extensive experiments on five diverse tasks: image-text retrieval, visual question answering, image captioning, machine translation, and text summarization. Results show that GOT provides a consistent performance lift over strong baselines across all the tasks.

2 Graph Optimal Transport Framework

In this section, we first introduce the problem formulation of Cross-domain Alignment in Sec. 2.1, then present the proposed Graph Optimal Transport (GOT) framework in Sec. 2.2-2.4.

2.1 Problem Formulation

Assume we have two sets of entities from two different domains (denoted as $\mathcal{D}_x$ and $\mathcal{D}_y$). For each set, every entity is represented by a feature vector, i.e., $\mathbf{X} = \{\mathbf{x}_i\}_{i=1}^{n}$ and $\mathbf{Y} = \{\mathbf{y}_j\}_{j=1}^{m}$, where $n$ and $m$ are the number of entities in each domain, respectively. The scope of this paper mainly focuses on tasks involving images and text, so entities here correspond to objects in an image or words in a sentence. An image can be represented as a set of detected objects, each associated with a feature vector (e.g., from a pre-trained Faster R-CNN (Anderson et al., 2018)). With a word embedding layer, a sentence can be represented as a sequence of word feature vectors.

A deep neural network $f_\theta(\cdot)$ can be designed to take both $\mathbf{X}$ and $\mathbf{Y}$ as initial inputs, and generate contextualized representations:

$\tilde{\mathbf{X}}, \tilde{\mathbf{Y}} = f_\theta(\mathbf{X}, \mathbf{Y})$,    (1)

where $\tilde{\mathbf{X}} = \{\tilde{\mathbf{x}}_i\}_{i=1}^{n}$, $\tilde{\mathbf{Y}} = \{\tilde{\mathbf{y}}_j\}_{j=1}^{m}$, and advanced attention mechanisms (Bahdanau et al., 2015; Vaswani et al., 2017) can be applied inside $f_\theta(\cdot)$ to simulate soft alignment. The final supervision signal $\boldsymbol{l}$ is then used to learn $\theta$, i.e., the training objective is defined as:

$\mathcal{L}_{\text{sup}}(\theta) = \mathcal{L}\big(f_\theta(\mathbf{X}, \mathbf{Y}), \boldsymbol{l}\big)$.    (2)

Here are some instantiations for different tasks: (i) Image-text Retrieval. $\mathbf{X}$ and $\mathbf{Y}$ are image and text features, respectively. $\boldsymbol{l}$ is the binary label, indicating whether the input image and sentence are paired or not. $f_\theta$ can be the SCAN model (Lee and others, 2018), and $\mathcal{L}$ corresponds to a ranking loss (Faghri et al., 2018; Chechik et al., 2010). (ii) VQA. Here $\boldsymbol{l}$ denotes the ground-truth answer, $f_\theta$ can be the BUTD or BAN model (Anderson et al., 2018; Kim et al., 2018), and $\mathcal{L}$ is the cross-entropy loss. (iii) Machine Translation. $\mathbf{X}$ and $\mathbf{Y}$ are textual features from the source and target sentences, respectively. $f_\theta$ can be an encoder-decoder Transformer model (Vaswani et al., 2017), and $\mathcal{L}$ corresponds to the cross-entropy loss that models the conditional distribution $p(\mathbf{Y}|\mathbf{X})$; in this case $\boldsymbol{l}$ is not needed. To simplify subsequent discussions, all the tasks are abstracted into $f_\theta(\mathbf{X}, \mathbf{Y})$ and $\mathcal{L}_{\text{sup}}(\theta)$.

In most previous work, the learned attention can be interpreted as a soft alignment between $\mathbf{X}$ and $\mathbf{Y}$. However, only the final supervision signal $\boldsymbol{l}$ is used for model training, so there is no objective that explicitly encourages cross-domain alignment. To enforce alignment and cast a regularizing effect on model training, we propose a new training objective for Cross-domain Alignment:

$\mathcal{L}(\theta) = \mathcal{L}_{\text{sup}}(\theta) + \alpha \cdot \mathcal{L}_{\text{CDA}}(\theta)$,    (3)

where $\mathcal{L}_{\text{CDA}}(\theta)$ is a regularization term that explicitly encourages alignment, and $\alpha$ is a hyper-parameter that balances the two terms. Through gradient back-propagation, the learned $\theta$ supports more effective relational inference. Section 2.4 describes $\mathcal{L}_{\text{CDA}}$ in detail.
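For illustration, a minimal PyTorch-style sketch of how this objective is used in practice is given below (not the exact implementation used in our experiments); got_distance is assumed to compute $\mathcal{L}_{\text{CDA}}$ as in the sketch given after Algorithm 3 below, and the value of alpha is illustrative.

def cda_training_loss(task_loss, x_embed, y_embed, got_distance, alpha=0.1):
    # task_loss: the original supervised objective L_sup (e.g., ranking or cross-entropy)
    # x_embed: (n, d) contextualized entity features from one domain (e.g., image regions)
    # y_embed: (m, d) contextualized entity features from the other domain (e.g., words)
    reg = got_distance(x_embed, y_embed)   # alignment cost L_CDA, a differentiable scalar
    return task_loss + alpha * reg         # Eq. (3): L_sup + alpha * L_CDA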

2.2 Dynamic Graph Construction

Image and text data inherently contain rich sequential/spatial structures. By representing them as graphs and performing graph alignment, not only can cross-domain relations be modeled, but intra-domain relations are also exploited (e.g., semantic/spatial relations among detected objects in an image (Li et al., 2019a)).

Figure 1: Illustration of the Wasserstein Distance (WD) and the Gromov-Wasserstein Distance (GWD) used for node and structure matching, respectively. WD: the cost $c(\mathbf{x}_i, \mathbf{y}_j)$ is calculated between nodes across the two domains; GWD: the cost $L(\mathbf{x}_i, \mathbf{y}_j, \mathbf{x}_{i'}, \mathbf{y}_{j'})$ is calculated between the edge $(\mathbf{x}_i, \mathbf{x}_{i'})$ and the edge $(\mathbf{y}_j, \mathbf{y}_{j'})$. See Sec. 2.3 for details.

Given $\tilde{\mathbf{X}}$, we aim to construct a graph $\mathcal{G}_x = (\mathcal{V}_x, \mathcal{E}_x)$, where each node $i \in \mathcal{V}_x$ is represented by its feature vector $\tilde{\mathbf{x}}_i$. To add edges $\mathcal{E}_x$, we first calculate the cosine similarity between every pair of entities inside the graph, $[\mathbf{C}_x]_{ij} = \cos(\tilde{\mathbf{x}}_i, \tilde{\mathbf{x}}_j)$. If this similarity exceeds a threshold hyper-parameter $\tau$ for the graph cost matrix, an edge is added between node $i$ and node $j$. Given $\tilde{\mathbf{Y}}$, another graph $\mathcal{G}_y = (\mathcal{V}_y, \mathcal{E}_y)$ can be similarly constructed. Since both $\tilde{\mathbf{X}}$ and $\tilde{\mathbf{Y}}$ evolve through the update of parameters $\theta$ during training, this graph construction process is considered "dynamic". By representing the entities in both domains as graphs, cross-domain alignment is naturally formulated as a graph matching problem.
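To make the construction concrete, here is a minimal PyTorch sketch (not the exact code used in our experiments); the default threshold value tau is an illustrative assumption.

import torch
import torch.nn.functional as F

def build_graph(features, tau=0.1):
    # features: (n, d) contextualized entity embeddings for one domain.
    # Returns the cosine-similarity matrix C and a binary adjacency matrix A,
    # with an edge (i, j) wherever the similarity exceeds the threshold tau.
    z = F.normalize(features, dim=-1)   # unit-norm rows
    C = z @ z.t()                       # pairwise cosine similarity
    A = (C > tau).float()               # thresholded edges
    return C, A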

In our proposed framework, we use Optimal Transport (OT) for graph matching, where a transport plan $\mathbf{T} \in \mathbb{R}_+^{n \times m}$ is learned to optimize the alignment between $\mathcal{G}_x$ and $\mathcal{G}_y$. OT possesses several idiosyncratic characteristics that make it a good choice for solving the CDA problem: (i) Self-normalization: all the elements of $\mathbf{T}^*$ sum to 1 (Peyré et al., 2019). (ii) Sparsity: when solved exactly, OT yields a sparse solution containing at most $(2r-1)$ non-zero elements, where $r = \max(n, m)$, leading to a more interpretable and robust alignment (De Goes and others, 2011). (iii) Efficiency: compared with conventional linear programming solvers, our solution can be readily obtained using iterative procedures that only require matrix-vector products (Xie et al., 2018), hence it is readily applicable to large deep neural networks.

1:  Input: $\{\mathbf{x}_i\}_{i=1}^{n}$, $\{\mathbf{y}_j\}_{j=1}^{m}$, $\beta$
2:  $\boldsymbol{\sigma} = \frac{1}{m}\mathbf{1}_m$, $\mathbf{T}^{(1)} = \mathbf{1}_n \mathbf{1}_m^\top$
3:  $\mathbf{C}_{ij} = c(\mathbf{x}_i, \mathbf{y}_j)$, $\mathbf{A}_{ij} = e^{-\mathbf{C}_{ij}/\beta}$
4:  for $t = 1, 2, 3, \ldots$ do
5:      $\mathbf{Q} = \mathbf{A} \odot \mathbf{T}^{(t)}$ // $\odot$ is Hadamard product
6:      for $k = 1, \ldots, K$ do
7:          $\boldsymbol{\delta} = \frac{1}{n \mathbf{Q} \boldsymbol{\sigma}}$, $\boldsymbol{\sigma} = \frac{1}{m \mathbf{Q}^\top \boldsymbol{\delta}}$
8:      end for
9:      $\mathbf{T}^{(t+1)} = \mathrm{diag}(\boldsymbol{\delta})\, \mathbf{Q}\, \mathrm{diag}(\boldsymbol{\sigma})$
10:  end for
11:  $\mathcal{D}_w = \langle \mathbf{C}, \mathbf{T} \rangle$
12:  Return $\mathbf{T}$, $\mathcal{D}_w$ // $\langle \cdot, \cdot \rangle$ is the Frobenius dot-product
Algorithm 1 Computing Wasserstein Distance.
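For reference, below is a minimal PyTorch sketch of this solver (not the exact implementation used in our experiments), following the IPOT-style proximal point iterations of Xie et al. (2018) with uniform marginals and a cosine-distance cost; the values of beta and the iteration counts are illustrative assumptions rather than tuned settings.

import torch
import torch.nn.functional as F

def ipot_from_cost(C, beta=0.5, n_outer=20, n_inner=1):
    # Proximal point iterations for a given (n, m) cost matrix C, assuming
    # uniform marginals u = 1/n and v = 1/m. Returns the transport plan T.
    n, m = C.size()
    A = torch.exp(-C / beta)                       # Gibbs kernel
    T = torch.ones_like(C) / (n * m)               # initial plan
    sigma = torch.ones(m, 1, device=C.device) / m
    for _ in range(n_outer):
        Q = A * T                                  # Hadamard product (Line 5)
        for _ in range(n_inner):                   # Sinkhorn-style scalings (Lines 6-8)
            delta = 1.0 / (n * (Q @ sigma))
            sigma = 1.0 / (m * (Q.t() @ delta))
        T = delta * Q * sigma.t()                  # diag(delta) Q diag(sigma) (Line 9)
    return T

def wasserstein_distance(x, y, beta=0.5):
    # Node-level OT between two sets of embeddings x: (n, d) and y: (m, d).
    C = 1.0 - F.normalize(x, dim=-1) @ F.normalize(y, dim=-1).t()   # cosine-distance cost
    T = ipot_from_cost(C, beta)
    return T, torch.sum(C * T)                     # <C, T>, the Frobenius dot-product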
1:  Input: $\{\mathbf{x}_i\}_{i=1}^{n}$, $\{\mathbf{y}_j\}_{j=1}^{m}$, probability vectors $\mathbf{p}$, $\mathbf{q}$
2:  Compute intra-domain similarities:
3:      $[\mathbf{C}_x]_{ii'} = \cos(\mathbf{x}_i, \mathbf{x}_{i'})$, $[\mathbf{C}_y]_{jj'} = \cos(\mathbf{y}_j, \mathbf{y}_{j'})$
4:  Compute cross-domain similarities:
5:      $\mathbf{C}_{xy} = \mathbf{C}_x^2\, \mathbf{p}\, \mathbf{1}_m^\top + \mathbf{1}_n\, \mathbf{q}^\top (\mathbf{C}_y^2)^\top$
6:  for $t = 1, 2, 3, \ldots$ do
7:      // Compute the pseudo-cost matrix
8:      $\hat{\mathbf{C}} = \mathbf{C}_{xy} - 2\, \mathbf{C}_x \hat{\mathbf{T}} \mathbf{C}_y^\top$
9:      Apply Algorithm 1 to solve transport plan $\hat{\mathbf{T}}$
10:  end for
11:  $\mathcal{D}_{gw} = \langle \hat{\mathbf{C}}, \hat{\mathbf{T}} \rangle$
12:  Return $\hat{\mathbf{T}}$, $\mathcal{D}_{gw}$
Algorithm 2 Computing Gromov-Wasserstein Distance.
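A corresponding sketch of Algorithm 2 is given below, reusing ipot_from_cost from the previous sketch. It assumes uniform probability vectors p and q and the square-loss pseudo-cost decomposition of Peyré et al. (2016); the iteration count is illustrative.

import torch
import torch.nn.functional as F

def gromov_wasserstein_distance(x, y, beta=0.5, n_iter=10):
    # x: (n, d), y: (m, d) entity embeddings from the two domains.
    n, m = x.size(0), y.size(0)
    p = torch.full((n, 1), 1.0 / n, device=x.device)
    q = torch.full((m, 1), 1.0 / m, device=x.device)
    xn, yn = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    Cx, Cy = xn @ xn.t(), yn @ yn.t()              # intra-domain cosine similarities
    ones_n = torch.ones(n, 1, device=x.device)
    ones_m = torch.ones(1, m, device=x.device)
    Cxy = (Cx ** 2) @ p @ ones_m + ones_n @ q.t() @ (Cy ** 2).t()   # constant term
    T_hat = p @ q.t()                              # initial plan
    for _ in range(n_iter):
        C_hat = Cxy - 2.0 * Cx @ T_hat @ Cy.t()    # pseudo-cost matrix (Line 8)
        T_hat = ipot_from_cost(C_hat, beta)        # inner OT solve (Line 9)
    return T_hat, torch.sum(C_hat * T_hat)         # transport plan and D_gw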
Figure 2: Schematic computation graph of the Graph Optimal Transport (GOT) distance used for cross-domain alignment. WD is short for Wasserstein Distance, and GWD is short for Gromov-Wasserstein Distance. See Sec. 2.1 and 2.4 for details.

2.3 Optimal Transport Distances

As illustrated in Figure 1, two types of OT distance are adopted for our graph matching: Wasserstein distance for node matching, and Gromov-Wasserstein distance for edge matching.

Wasserstein Distance

Wasserstein distance (WD) is commonly used for matching two distributions (e.g., two sets of node embeddings). In our setting, discrete WD can be used as a solver for network flow and bipartite matching (Luise et al., 2018). The definition of WD is described as follows.

Definition 2.1.

Let $\boldsymbol{\mu}$ and $\boldsymbol{\nu}$ denote two discrete distributions, formulated as $\boldsymbol{\mu} = \sum_{i=1}^{n} u_i \delta_{\mathbf{x}_i}$ and $\boldsymbol{\nu} = \sum_{j=1}^{m} v_j \delta_{\mathbf{y}_j}$, with $\delta_{\mathbf{x}}$ as the Dirac function centered on $\mathbf{x}$. $\Pi(\boldsymbol{\mu}, \boldsymbol{\nu})$ denotes all the joint distributions $\boldsymbol{\gamma}(\mathbf{x}, \mathbf{y})$, with marginals $\boldsymbol{\mu}(\mathbf{x})$ and $\boldsymbol{\nu}(\mathbf{y})$. The weight vectors $\mathbf{u} = \{u_i\}_{i=1}^{n}$ and $\mathbf{v} = \{v_j\}_{j=1}^{m}$ belong to the $n$- and $m$-dimensional simplex, respectively (i.e., $\sum_{i=1}^{n} u_i = \sum_{j=1}^{m} v_j = 1$), where both $\boldsymbol{\mu}$ and $\boldsymbol{\nu}$ are probability distributions. The Wasserstein distance between the two discrete distributions is defined as:

$\mathcal{D}_w(\boldsymbol{\mu}, \boldsymbol{\nu}) = \min_{\mathbf{T} \in \Pi(\mathbf{u}, \mathbf{v})} \sum_{i=1}^{n} \sum_{j=1}^{m} \mathbf{T}_{ij} \cdot c(\mathbf{x}_i, \mathbf{y}_j)$,    (4)

where $\Pi(\mathbf{u}, \mathbf{v}) = \{\mathbf{T} \in \mathbb{R}_+^{n \times m} \,|\, \mathbf{T}\mathbf{1}_m = \mathbf{u}, \, \mathbf{T}^\top\mathbf{1}_n = \mathbf{v}\}$, $\mathbf{1}_m$ denotes an $m$-dimensional all-one vector, and $c(\mathbf{x}_i, \mathbf{y}_j)$ is the cost function evaluating the distance between $\mathbf{x}_i$ and $\mathbf{y}_j$. For example, the cosine distance $c(\mathbf{x}_i, \mathbf{y}_j) = 1 - \frac{\mathbf{x}_i^\top \mathbf{y}_j}{\|\mathbf{x}_i\|_2 \|\mathbf{y}_j\|_2}$ is a popular choice. The matrix $\mathbf{T}$ is denoted as the transport plan, where $\mathbf{T}_{ij}$ represents the amount of mass shifted from $u_i$ to $v_j$.

$\mathcal{D}_w(\boldsymbol{\mu}, \boldsymbol{\nu})$ defines an optimal transport distance that measures the discrepancy between each pair of samples across the two domains. In our graph matching, this is a natural choice for node (entity) matching.
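As a toy illustration (not an example from our experiments), consider two uniform two-point distributions with $\mathbf{u} = \mathbf{v} = (\tfrac{1}{2}, \tfrac{1}{2})$ and cost matrix

$\mathbf{C} = \begin{pmatrix} 0.1 & 0.9 \\ 0.8 & 0.2 \end{pmatrix} \;\Rightarrow\; \mathbf{T}^* = \begin{pmatrix} 0.5 & 0 \\ 0 & 0.5 \end{pmatrix}, \quad \mathcal{D}_w(\boldsymbol{\mu}, \boldsymbol{\nu}) = \langle \mathbf{C}, \mathbf{T}^* \rangle = 0.15,$

since pairing each $\mathbf{x}_i$ with its cheapest counterpart already satisfies both marginal constraints, whereas the off-diagonal coupling would cost 0.85. The optimal plan is sparse, illustrating the interpretability property noted in Sec. 2.2.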

Gromov-Wasserstein Distance

Instead of directly calculating distances between two sets of nodes as in WD, Gromov-Wasserstein distance (GWD) (Peyré et al., 2016; Chowdhury and Mémoli, 2019) can be used to calculate distances between pairs of nodes within each domain, as well as measuring how these distances compare to those in the counterpart domain. GWD in the discrete matching setting can be formulated as follows.

Definition 2.2.

Following the same notation as in Definition 2.1, the Gromov-Wasserstein distance between $\boldsymbol{\mu}$ and $\boldsymbol{\nu}$ is defined as:

$\mathcal{D}_{gw}(\boldsymbol{\mu}, \boldsymbol{\nu}) = \min_{\hat{\mathbf{T}} \in \Pi(\mathbf{u}, \mathbf{v})} \sum_{i, i', j, j'} \hat{\mathbf{T}}_{ij} \hat{\mathbf{T}}_{i'j'}\, L(\mathbf{x}_i, \mathbf{y}_j, \mathbf{x}_{i'}, \mathbf{y}_{j'})$,    (5)

where $L(\cdot)$ is the cost function evaluating the intra-graph structural similarity between the two pairs of nodes $(\mathbf{x}_i, \mathbf{x}_{i'})$ and $(\mathbf{y}_j, \mathbf{y}_{j'})$, i.e., $L(\mathbf{x}_i, \mathbf{y}_j, \mathbf{x}_{i'}, \mathbf{y}_{j'}) = \|c_1(\mathbf{x}_i, \mathbf{x}_{i'}) - c_2(\mathbf{y}_j, \mathbf{y}_{j'})\|$, where $c_1$ and $c_2$ are functions that evaluate node similarity within the same graph (e.g., the cosine similarity).

Similar to WD, in the GWD setting, $c_1(\mathbf{x}_i, \mathbf{x}_{i'})$ and $c_2(\mathbf{y}_j, \mathbf{y}_{j'})$ (corresponding to the edges) can be viewed as two nodes in the dual graphs (Van Lint et al., 2001), where edges are projected into nodes. The learned matrix $\hat{\mathbf{T}}$ now becomes a transport plan that helps align the edges in different graphs. Note that the same $c_1$ and $c_2$ are also used for graph construction in Sec. 2.2.

2.4 Graph Matching via OT Distances

Though GWD is capable of capturing edge similarity between graphs, it cannot be directly applied to graph alignment, since only the similarity between the node pairs $(\mathbf{x}_i, \mathbf{x}_{i'})$ and $(\mathbf{y}_j, \mathbf{y}_{j'})$ is considered, without taking into account the node representations themselves. For example, the word pair ("boy", "girl") has a cosine similarity similar to that of the pair ("football", "basketball"), but the semantic meanings of the two pairs are completely different and they should not be matched.

On the other hand, WD can match nodes in different graphs, but fails to capture the similarity between edges. If there are duplicated entities represented by different nodes in the same graph, WD will treat them as identical and ignore their neighboring relations. For example, given a sentence "there is a red book on the blue desk" paired with an image containing several desks and books in different colors, it is difficult to correctly identify which book in the image the sentence is referring to, without understanding the relations among the objects in the image.

To combine the strengths of WD and GWD and unify the two distances in a mutually beneficial way, we propose to use a transport plan $\mathbf{T}$ shared by both WD and GWD. Compared with naively employing two different transport plans, we observe that this joint plan works better (see Table 8) and is faster, since we only need to solve for $\mathbf{T}$ once (instead of twice). Intuitively, with a shared transport plan, WD and GWD can enhance each other effectively, as $\mathbf{T}$ utilizes both node and edge information simultaneously. Formally, the proposed GOT distance is defined as:

$\mathcal{D}_{got}(\boldsymbol{\mu}, \boldsymbol{\nu}) = \min_{\mathbf{T} \in \Pi(\mathbf{u}, \mathbf{v})} \sum_{i, i', j, j'} \mathbf{T}_{ij} \big( \lambda\, c(\mathbf{x}_i, \mathbf{y}_j) + (1 - \lambda)\, \mathbf{T}_{i'j'}\, L(\mathbf{x}_i, \mathbf{y}_j, \mathbf{x}_{i'}, \mathbf{y}_{j'}) \big)$.    (6)

We apply the Sinkhorn algorithm (Cuturi, 2013; Cuturi and Peyré, 2017) to solve the WD problem in (4) with an entropic regularizer (Benamou et al., 2015):

$\mathcal{D}_w(\boldsymbol{\mu}, \boldsymbol{\nu}) = \min_{\mathbf{T} \in \Pi(\mathbf{u}, \mathbf{v})} \sum_{i=1}^{n} \sum_{j=1}^{m} \mathbf{T}_{ij}\, c(\mathbf{x}_i, \mathbf{y}_j) + \beta H(\mathbf{T})$,    (7)

where $H(\mathbf{T}) = \sum_{i,j} \mathbf{T}_{ij} \log \mathbf{T}_{ij}$, and $\beta$ is the hyper-parameter controlling the importance of the entropy term. Details are provided in Algorithm 1. The solver for GWD can be readily developed based on Algorithm 1, where the probability vectors $\mathbf{p}$ and $\mathbf{q}$ are defined as uniform distributions (as shown in Algorithm 2), following Alvarez-Melis and Jaakkola (2018). With the help of the Sinkhorn algorithm, GOT can be efficiently implemented in popular deep learning libraries, such as PyTorch and TensorFlow.

To obtain a unified solver for the GOT distance, we define the unified cost function as:

$\mathbf{C}_{\text{unified}} = \lambda\, \mathbf{C} + (1 - \lambda)\, \hat{\mathbf{C}}$,    (8)

where $\mathbf{C}$ is the node cost matrix used in WD, $\hat{\mathbf{C}}$ is the pseudo-cost matrix used in GWD, and $\lambda$ is the hyper-parameter controlling the relative importance of the two cost functions. Instead of using projected gradient descent or conjugate gradient descent as in Xu et al. (2019b, a); Vayer et al. (2018), we can approximate the transport plan by adding $\mathbf{C}$ back in Algorithm 2 (Line 8), so that Line 9 in Algorithm 2 helps solve for both WD and GWD at the same time, effectively matching both nodes and edges simultaneously. The solver for calculating the GOT distance is illustrated in Figure 2, and the detailed algorithm is summarized in Algorithm 3. The calculated GOT distance is used as the cross-domain alignment loss $\mathcal{L}_{\text{CDA}}$ in (3), acting as a regularizer to update the parameters $\theta$.

1:  Input: $\{\mathbf{x}_i\}_{i=1}^{n}$, $\{\mathbf{y}_j\}_{j=1}^{m}$, hyper-parameter $\lambda$
2:  Compute intra-domain similarities:
3:      $[\mathbf{C}_x]_{ii'} = \cos(\mathbf{x}_i, \mathbf{x}_{i'})$, $[\mathbf{C}_y]_{jj'} = \cos(\mathbf{y}_j, \mathbf{y}_{j'})$
4:  $\mathbf{x}'_i = g_1(\mathbf{x}_i)$, $\mathbf{y}'_j = g_2(\mathbf{y}_j)$ // $g_1, g_2$ denote two MLPs
5:  Compute cross-domain similarities:
6:      $\mathbf{C}_{ij} = c(\mathbf{x}'_i, \mathbf{y}'_j)$
7:  if $\mathbf{T}$ is shared: then
8:      Update $\hat{\mathbf{C}}$ in Algorithm 2 (Line 8) with:
9:          $\mathbf{C}_{\text{unified}} = \lambda\, \mathbf{C} + (1 - \lambda)\, \hat{\mathbf{C}}$
10:      Plug $\mathbf{C}_{\text{unified}}$ back into Algorithm 2 and solve the new $\mathbf{T}$
11:      Compute $\mathcal{D}_{got} = \langle \mathbf{C}_{\text{unified}}, \mathbf{T} \rangle$
12:  else
13:      Apply Algorithm 1 to obtain $\mathcal{D}_w$
14:      Apply Algorithm 2 to obtain $\mathcal{D}_{gw}$
15:      $\mathcal{D}_{got} = \lambda\, \mathcal{D}_w + (1 - \lambda)\, \mathcal{D}_{gw}$
16:  end if
17:  Return $\mathcal{D}_{got}$
Algorithm 3 Computing GOT Distance.
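Putting the pieces together, the sketch below mirrors the shared-plan branch of Algorithm 3, again reusing ipot_from_cost from the Algorithm 1 sketch. For simplicity it computes the node cost directly on the input features rather than through the two MLPs $g_1$ and $g_2$, and the values of lam, beta, and the iteration count are illustrative assumptions rather than tuned settings.

import torch
import torch.nn.functional as F

def got_distance(x, y, lam=0.8, beta=0.5, n_iter=10):
    # Returns D_got between x: (n, d) and y: (m, d), used as L_CDA in Eq. (3).
    n, m = x.size(0), y.size(0)
    p = torch.full((n, 1), 1.0 / n, device=x.device)
    q = torch.full((m, 1), 1.0 / m, device=x.device)
    xn, yn = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    Cx, Cy = xn @ xn.t(), yn @ yn.t()              # intra-domain similarities
    C_node = 1.0 - xn @ yn.t()                     # cross-domain node cost (WD part)
    ones_n = torch.ones(n, 1, device=x.device)
    ones_m = torch.ones(1, m, device=x.device)
    Cxy = (Cx ** 2) @ p @ ones_m + ones_n @ q.t() @ (Cy ** 2).t()
    T = p @ q.t()                                  # shared transport plan
    for _ in range(n_iter):
        C_edge = Cxy - 2.0 * Cx @ T @ Cy.t()       # GWD pseudo-cost (Algorithm 2, Line 8)
        C_unified = lam * C_node + (1.0 - lam) * C_edge   # Eq. (8)
        T = ipot_from_cost(C_unified, beta)        # one plan solves WD and GWD jointly
    return torch.sum(C_unified * T)                # D_got = <C_unified, T>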

3 Related Work

Optimal Transport

Wasserstein distance (WD), a.k.a. Earth Mover's distance, has been widely applied to machine learning tasks. In computer vision, Rubner et al. (1998) use WD to discover the structure of color distributions for image search. In natural language processing, WD has been applied to document retrieval (Kusner et al., 2015) and sequence-to-sequence learning (Chen et al., 2019a). There are also studies adopting WD in Generative Adversarial Networks (GAN) (Goodfellow et al., 2014; Salimans et al., 2018; Chen et al., 2018; Mroueh et al., 2018; Zhang et al., 2020) to alleviate the mode-collapse issue. Recently, it has also been used in vision-and-language pre-training to encourage word-region alignment (Chen et al., 2019b). Besides WD, the Gromov-Wasserstein distance (Peyré et al., 2016) has been proposed for distributional metric matching and applied to unsupervised machine translation (Alvarez-Melis and Jaakkola, 2018).

There are different ways to solve the OT distance, such as linear programming. However, this solver is not differentiable, and thus cannot be applied in deep learning frameworks. WGAN (Arjovsky and others, 2017) proposes to approximate the dual form of WD by imposing a 1-Lipschitz constraint on the discriminator; note that the duality used in WGAN is restricted to the Wasserstein-1 distance. The Sinkhorn algorithm was first proposed in Cuturi (2013) as a solver for an entropic-regularized OT distance. Thanks to the Envelope Theorem (Cuturi and Peyré, 2017), the Sinkhorn algorithm can be calculated efficiently and readily applied to neural networks. More recently, Vayer et al. (2018) proposed the fused GWD for graph matching. Our proposed GOT framework enjoys the benefits of both the Sinkhorn algorithm and fused GWD: (i) it captures more structured information by marrying WD and GWD; and (ii) it is scalable to large datasets and trainable with deep neural networks.

Sentence Retrieval Image Retrieval
Method R@1 R@5 R@10 R@1 R@5 R@10 Rsum
VSE++ (ResNet) (Faghri et al., 2018) 52.9 87.2 39.6 79.5
DPC (ResNet) (Zheng et al., 2020) 55.6 81.9 89.5 39.1 69.2 80.9 416.2
DAN (ResNet) (Nam et al., 2017) 55.0 81.8 89.0 39.4 69.2 79.1 413.5
SCO (ResNet) (Huang et al., 2018) 55.5 82.0 89.3 41.1 70.5 80.1 418.5
SCAN (Faster R-CNN, ResNet) (Lee and others, 2018) 67.7 88.9 94.0 44.0 74.2 82.6 452.2
Ours (Faster R-CNN, ResNet):
SCAN + WD 70.9 92.3 95.2 49.7 78.2 86.0 472.3
SCAN + GWD 69.5 91.2 95.2 48.8 78.1 85.8 468.6
SCAN + GOT 70.9 92.8 95.5 50.7 78.7 86.2 474.8
VSE++ (ResNet) (Faghri et al., 2018) 41.3 81.2 30.3 72.4
DPC (ResNet) (Zheng et al., 2020) 41.2 70.5 81.1 25.3 53.4 66.4 337.9
GXN (ResNet) (Gu et al., 2018) 42.0 84.7 31.7 74.6
SCO (ResNet) (Huang et al., 2018) 42.8 72.3 83.0 33.1 62.9 75.5 369.6
SCAN (Faster R-CNN, ResNet)(Lee and others, 2018) 46.4 77.4 87.2 34.4 63.7 75.7 384.8
Ours (Faster R-CNN, ResNet):
SCAN + WD 50.2 80.1 89.5 37.9 66.8 78.1 402.6
SCAN + GWD 47.2 78.3 87.5 34.9 64.4 76.3 388.6
SCAN + GOT 50.5 80.2 89.8 38.1 66.8 78.5 403.9
Table 1: Results on image-text retrieval, evaluated with Recall@K (R@K). Upper panel: Flickr30K; lower panel: COCO.

Graph Neural Network

Neural networks operating on graph data were first introduced by Gori et al. (2005) using recurrent neural networks. Later, Duvenaud et al. (2015) proposed a convolutional neural network over graphs for classification tasks. However, these methods suffer from scalability issues, because they need to learn node-degree-specific weight matrices for large graphs. To alleviate this issue, Kipf and Welling (2016) propose to use a single weight matrix per layer in the neural network, which is capable of handling varying node degrees through an appropriate normalization of the adjacency matrix of the data. To further improve classification accuracy, the graph attention network (GAT) (Veličković et al., 2018) uses a learned weight matrix instead of the adjacency matrix, with masked attention to aggregate node neighbourhood information.

Recently, graph neural networks have been extended to tasks beyond classification. Li et al. (2019b) propose the graph matching network (GMN) for learning similarities between graphs. Similar to GAT, masked attention is applied to aggregate information from each node within a graph, and cross-graph information is further exploited via soft attention. Task-specific losses are then used to guide model training. In this setting, the adjacency matrix can be obtained directly from the data, and soft attention is used to induce alignment. In contrast, our GOT framework does not rely on explicit graph structures in the data, and uses OT for graph alignment.

Figure 3: (a) A comparison of the inferred transport plan from GOT (top chart) and the learned attention matrix from SCAN (bottom chart). Both serve as a lens to visualize cross-domain alignment. The horizontal axis represents image regions, and the vertical axis represents word tokens. (b) The original image.

4 Experiments

To validate the effectiveness of the proposed GOT framework, we evaluate it on a selection of diverse tasks: vision-and-language understanding, including (i) image-text retrieval and (ii) visual question answering; and text generation, including (iii) image captioning, (iv) machine translation, and (v) abstractive text summarization.

4.1 Vision-and-Language Tasks

Image-Text Retrieval

For the image-text retrieval task, we use a pre-trained Faster R-CNN (Ren et al., 2015) to extract bottom-up-attention features (Anderson et al., 2018) as the image representation. A set of 36 features is created for each image, each feature represented by a 2048-dimensional vector. For captions, a bi-directional GRU (Schuster and Paliwal, 1997; Bahdanau et al., 2015) is used to obtain textual features.

We evaluate our model on the Flickr30K (Plummer and others, 2015) and COCO (Lin et al., 2014) datasets. Flickr30K contains about 31,000 images, with five human-annotated captions per image. We follow previous work (Karpathy and Fei-Fei, 2015; Faghri et al., 2018) for the data split: 1,000 images are used for validation, 1,000 for test, and the rest for training. COCO contains 123,287 images, each also accompanied by five captions. We follow the data split in Faghri et al. (2018), where 113,287, 5,000, and 5,000 images are used for training, validation and test, respectively.

We measure the performance of image retrieval and sentence retrieval with Recall at K (R@K) (Karpathy and Fei-Fei, 2015), defined as the percentage of queries for which the correct image/sentence is retrieved within the top K highest-ranked results. In our experiments, K = 1, 5, 10, and Rsum (Huang et al., 2017) (the summation over all R@K values) is used to evaluate the overall performance. Results are summarized in Table 1. Both WD and GWD can boost the performance of the SCAN model, while WD achieves a larger margin than GWD. This indicates that when used alone, GWD may not be a good metric for graph alignment. When the two distances are combined, GOT achieves the best performance.
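For clarity, R@K can be computed as in the short sketch below, which assumes a similarity matrix in which the ground-truth candidate for query i sits at index i (the actual Flickr30K/COCO protocol pairs each image with five captions, which the full evaluation accounts for).

import torch

def recall_at_k(sim, k=10):
    # sim: (num_queries, num_candidates) similarity scores.
    ranks = sim.argsort(dim=1, descending=True)              # candidate indices, best first
    gt = torch.arange(sim.size(0), device=sim.device).unsqueeze(1)
    hits = (ranks[:, :k] == gt).any(dim=1).float()           # ground truth within top k?
    return hits.mean().item()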

Figure 3 provides a visualization of the learned transport plan in GOT and the learned attention matrix in SCAN. Both serve as a proxy to lend insights into the learned alignment. As shown, the attention matrix from SCAN is much denser and noisier than the transport plan inferred by GOT. This shows that our model can better discover cross-domain relations between image-text pairs, since the inferred transport plan is more interpretable and has less ambiguity. For example, both the words "sidewalk" and "skateboard" match their corresponding image regions very well.

Because of the Envelope Theorem (Cuturi and Peyré, 2017), GOT needs to be calculated only during the forward phase of model training, so it adds little extra time. For example, when using the same machine for the image-text retrieval experiments, SCAN required 6hr 34min for training and SCAN+GOT 6hr 57min.

Visual Question Answering

Model BAN BAN+GWD BAN+WD BAN+GOT
Score 66.00 66.21 66.26 66.44
Table 2: Results (accuracy) on VQA 2.0 validation set, using BAN (Kim et al., 2018) as baseline.
Model BUTD BAN-1 BAN-2 BAN-4 BAN-8
w/o GOT 63.37 65.37 65.61 65.81 66.00
w/ GOT 65.01 65.68 65.88 66.10 66.44
Table 3: Results (accuracy) of applying GOT to BUTD (Anderson et al., 2018) and BAN-k (Kim et al., 2018) on VQA 2.0, where k denotes the number of glimpses.

We also evaluate on the VQA 2.0 dataset (Goyal et al., 2017), which contains human-annotated QA pairs on COCO images (Lin et al., 2014). For each image, multiple questions are collected, and each question is annotated with 10 answers by different annotators; the most frequent answer is selected as the correct answer. Following previous work (Kim et al., 2018), we take the answers that appear more than 9 times in the training set as candidate answers, which results in 3,129 candidates. Classification accuracy is used as the evaluation metric, defined as $\min\big(\tfrac{\#\text{humans that provided that answer}}{3}, 1\big)$.
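For reference, the commonly used form of this accuracy can be sketched as follows (a simplification of the official metric, which additionally averages the score over subsets of annotators).

def vqa_accuracy(predicted, human_answers):
    # human_answers: the 10 annotator answers collected for one question.
    matches = sum(a == predicted for a in human_answers)
    return min(matches / 3.0, 1.0)   # full credit once 3 annotators agree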

The BAN model (Kim et al., 2018) is used as baseline, with the original codebase used for fair comparison. Results are summarized in Table 2. Both WD and GWD improve the BAN model on the validation set, and GOT achieves further performance lift.

We also investigate whether different architecture designs affect the performance gain. We consider BUTD (Anderson et al., 2018) as an additional baseline, and apply different numbers of glimpses to the BAN model, denoted as BAN-k. Results are summarized in Table 3. Here are our observations: (i) When the number of parameters in the tested model is small, such as BUTD, the improvement brought by GOT is more significant. (ii) BAN-4, a simpler model than BAN-8, can outperform BAN-8 without GOT when combined with GOT (66.10 vs. 66.00). (iii) For complex models such as BAN-8, which might have limited room for improvement, GOT is still able to achieve a performance gain.

Method CIDEr BLEU-4 BLEU-3 BLEU-2 BLEU-1 ROUGE METEOR
Soft Attention (Xu et al., 2015) - 24.3 34.4 49.2 70.7 - 23.9
Hard Attention (Xu et al., 2015) - 25.0 35.7 50.4 71.8 - 23.0
Show & Tell (Vinyals et al., 2015) 85.5 27.7 - - - - 23.7
ATT-FCN (You et al., 2016) - 30.4 40.2 53.7 70.9 - 24.3
SCN-LSTM (Gan et al., 2017) 101.2 33.0 43.3 56.6 72.8 - 25.7
Adaptive Attention (Lu et al., 2017) 108.5 33.2 43.9 58.0 74.2 - 26.6
MLE 106.3 34.3 45.3 59.3 75.6 55.2 26.2
MLE + WD 107.9 34.8 46.1 60.1 76.2 55.6 26.5
MLE + GWD 106.6 33.3 45.2 59.1 75.7 55.0 25.9
MLE + GOT 109.2 35.1 46.5 60.3 77.0 56.2 26.7
Table 4: Results of image captioning on the COCO dataset.
Model EN-VI uncased EN-VI cased EN-DE uncased EN-DE cased
Transformer (Vaswani et al., 2017)
Transformer + WD
Transformer + GWD
Transformer + GOT
Table 5: Results of neural machine translation on EN-DE and EN-VI.

Method ROUGE-1 ROUGE-2 ROUGE-L
ABS+ (Rush et al., 2015)
LSTM (Hu et al., 2018)
LSTM + GWD
LSTM + WD
LSTM + GOT
Table 6: Results of abstractive text summarization on the English Gigawords dataset.

4.2 Text Generation Tasks

Image Captioning

We conduct experiments on image captioning using the same COCO dataset. The same bottom-up-attention features (Anderson et al., 2018) used in image-text retrieval are adopted here. The text decoder is a one-layer LSTM with 256 hidden units, and the word embedding dimension is set to 256. Results are summarized in Table 4. A similar performance gain from GOT can be observed. The relative boost from WD to GOT on the CIDEr score is (109.2 - 107.9)/107.9 ≈ 1.2%. This is attributed to the additional GWD introduced in GOT, which helps model implicit intra-domain relationships in images and captions, leading to more accurate caption generation.

Machine Translation

In machine translation (and abstractive summarization), the word embedding spaces of the source and target sentences are different, and can be considered as different domains. Therefore, GOT can be used to align words with similar semantic meanings between the source and target sentences for better translation/summarization. We choose two machine translation benchmarks for experiments: (i) the English-Vietnamese TED-talks corpus, which contains 133K pairs of sentences from the IWSLT 2015 Evaluation Campaign (Cettolo et al., 2015); and (ii) a large-scale English-German parallel corpus with 4.5M pairs of sentences, from the WMT 2014 Evaluation Campaign (Vaswani et al., 2017). The Texar codebase (Hu et al., 2018) is used in our experiments.

We apply GOT to the Transformer model (Vaswani et al., 2017) and use the BLEU score (Papineni et al., 2002) as the evaluation metric. Results are summarized in Table 5. As also observed in Chen et al. (2019a), using WD can improve the performance of the Transformer for sequence-to-sequence learning. However, if only GWD is used, the test BLEU score drops. Since GWD can only match the edges, it ignores the supervision signal from node representations. This serves as empirical evidence to support our hypothesis that using GWD alone may not be enough to improve performance. However, GWD serves as a complementary method for capturing graph information that might be missed by WD. Therefore, when the two are combined, GOT achieves the best performance. Example translations are provided in Table 7.

Figure 4: Inferred transport plan for aligning source and output sentences in abstractive summarization.
Reference: India’s new prime minister, Narendra Modi, is meeting his Japanese counterpart, Shinzo Abe, in Tokyo to discuss
economic and security ties, on his first major foreign visit since winning May’s election.
MLE: India ‘ s new prime minister , Narendra Modi , meets his Japanese counterpart , Shinzo Abe , in Tokyo , during his
first major foreign visit in May to discuss economic and security relations .
GOT: India ’ s new prime minister , Narendra Modi , is meeting his Japanese counterpart Shinzo Abe in Tokyo in his first
major foreign visit since his election victory in May to discuss economic and security relations.
Reference: Chinese leaders presented the Sunday ruling as a democratic breakthrough because it gives Hong Kongers a direct
vote, but the decision also makes clear that Chinese leaders would retain a firm hold on the process through a
nominating committee tightly controlled by Beijing.
MLE: The Chinese leadership presented the decision of Sunday as a democratic breakthrough , because it gives Hong
Kong citizens a direct right to vote , but the decision also makes it clear that the Chinese leadership maintains the
expiration of a nomination committee closely controlled by Beijing .
GOT: The Chinese leadership presented the decision on Sunday as a democratic breakthrough , because Hong Kong
citizens have a direct electoral right , but the decision also makes it clear that the Chinese leadership remains
firmly in hand with a nominating committee controlled by Beijing.
Table 7: Comparison of German-to-English translation examples. For each example, we show the human translation (reference) and the translations from MLE and GOT. We highlight the key-phrase differences between the reference and the translation outputs in blue and red, and denote errors in translation in bold. In the first example, GOT correctly maintains all the information in "since winning May's election" by translating it to "since his election victory in May", whereas MLE only generates "in May". In the second example, GOT successfully keeps the information "Beijing", whereas MLE generates the wrong words "expiration of".

Abstractive Summarization

We evaluate abstractive summarization on the English Gigawords benchmark (Graff et al., 2003). A basic LSTM model, as implemented in Texar (Hu et al., 2018), is used in our experiments. ROUGE-1, -2 and -L scores (Lin, 2004) are reported. Table 6 shows that both GWD and WD can improve the performance of the LSTM. The transport plan for aligning the source and output sentences is illustrated in Figure 4. The learned alignment is sparse and interpretable. For instance, the words "largest" and "projects" in the source sentence match the words "more" and "investment" in the output summary very well.

4.3 Ablation study

We conduct additional ablation studies on the EN-VI and EN-DE machine translation datasets.

Shared Transport Plan

As discussed in Sec. 2.4, we use a transport plan shared by WD and GWD to solve the GOT distance. An alternative is not to share this plan. The comparison results are provided in Table 8. GOT with a shared transport plan achieves better performance than the alternative. Since we only need to run the iterative Sinkhorn algorithm once, it also saves training time compared with the unshared case.

Model EN-VI uncased EN-DE uncased
GOT (shared)
GOT (unshared)
Table 8: Ablation study on transport plan in machine translation. Both models were run 5 times with the same hyper-parameter setting.
λ 0 0.1 0.3 0.5 0.8 1.0
BLEU 29.92
Table 9: Ablation study of the hyper-parameter λ on the EN-VI machine translation dataset.

Hyper-parameter λ

We perform an ablation study on the hyper-parameter λ in Eqn. (8). We select λ from {0, 0.1, 0.3, 0.5, 0.8, 1.0} and report results in Table 9. EN-VI translation performs best when λ = 0.8, which indicates that the weight on WD needs to be larger than the weight on GWD, since intuitively node matching is more important than edge matching for machine translation. However, both WD and GWD contribute to GOT achieving the best performance.

5 Conclusions

We propose Graph Optimal Transport, a principled framework for cross-domain alignment. With the Wasserstein and Gromov-Wasserstein distances, both intra-domain and cross-domain relations are captured for better alignment. Empirically, we observe that enforcing alignment can serve as an effective regularizer for model training. Extensive experiments show that the proposed method is a generic framework that can be applied to a wide range of cross-domain tasks. For future work, we plan to apply the proposed framework to self-supervised representation learning.

Acknowledgements

The authors would like to thank the anonymous reviewers for their insightful comments. The research was supported in part by DARPA, DOE, NIH, NSF and ONR.

References

  • D. Alvarez-Melis and T. S. Jaakkola (2018) Gromov-wasserstein alignment of word embedding spaces. arXiv:1809.00013. Cited by: §2.4, §3.
  • P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, and L. Zhang (2018) Bottom-up and top-down attention for image captioning and visual question answering. In CVPR, Cited by: §2.1, §2.1, §4.1, §4.1, §4.2, Table 3.
  • S. Antol et al. (2015) Vqa: visual question answering. In ICCV, Cited by: §1.
  • M. Arjovsky et al. (2017) Wasserstein generative adversarial networks. In ICML, Cited by: §3.
  • D. Bahdanau, K. Cho, and Y. Bengio (2015) Neural machine translation by jointly learning to align and translate. In ICLR, Cited by: §1, §1, §2.1, §4.1.
  • J. Benamou, G. Carlier, M. Cuturi, L. Nenna, and G. Peyré (2015) Iterative bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing. Cited by: §2.4.
  • M. Cettolo, J. Niehues, S. Stüker, L. Bentivogli, R. Cattoni, and M. Federico (2015) The IWSLT 2015 evaluation campaign. In International Workshop on Spoken Language Translation, Cited by: §4.2.
  • G. Chechik, V. Sharma, U. Shalit, and S. Bengio (2010) Large scale online learning of image similarity through ranking. Journal of Machine Learning Research. Cited by: §2.1.
  • L. Chen, S. Dai, C. Tao, H. Zhang, Z. Gan, D. Shen, Y. Zhang, G. Wang, R. Zhang, and L. Carin (2018) Adversarial text generation via feature-mover’s distance. In NeurIPS, Cited by: §3.
  • L. Chen, Y. Zhang, R. Zhang, C. Tao, Z. Gan, H. Zhang, B. Li, D. Shen, C. Chen, and L. Carin (2019a) Improving sequence-to-sequence learning via optimal transport. arXiv preprint arXiv:1901.06283. Cited by: §3, §4.2.
  • Y. Chen, L. Li, L. Yu, A. E. Kholy, F. Ahmed, Z. Gan, Y. Cheng, and J. Liu (2019b) Uniter: learning universal image-text representations. arXiv preprint arXiv:1909.11740. Cited by: §3.
  • S. Chowdhury and F. Mémoli (2019) The gromov–wasserstein distance between networks and stable network invariants. Information and Inference: A Journal of the IMA. Cited by: §2.3.
  • M. Cuturi and G. Peyré (2017) Computational optimal transport. Cited by: §2.4, §3, §4.1.
  • M. Cuturi (2013) Sinkhorn distances: lightspeed computation of optimal transport. In NeurIPS, Cited by: §2.4, §3.
  • F. De Goes et al. (2011) An optimal transport approach to robust reconstruction and simplification of 2d shapes. In Computer Graphics Forum, Cited by: §2.2.
  • D. K. Duvenaud, D. Maclaurin, J. Iparraguirre, R. Bombarell, T. Hirzel, A. Aspuru-Guzik, and R. P. Adams (2015) Convolutional networks on graphs for learning molecular fingerprints. In NeurIPS, Cited by: §3.
  • F. Faghri, D. J. Fleet, J. R. Kiros, and S. Fidler (2018) Vse++: improved visual-semantic embeddings. In BMVC, Cited by: §2.1, Table 1, §4.1.
  • Z. Gan, C. Gan, X. He, Y. Pu, K. Tran, J. Gao, L. Carin, and L. Deng (2017) Semantic compositional networks for visual captioning. In CVPR, Cited by: Table 4.
  • I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial nets. In NeurIPS, Cited by: §3.
  • M. Gori, G. Monfardini, and F. Scarselli (2005) A new model for learning in graph domains. In IEEE International Joint Conference on Neural Networks, Cited by: §3.
  • Y. Goyal, T. Khot, D. Summers-Stay, D. Batra, and D. Parikh (2017) Making the v in vqa matter: elevating the role of image understanding in visual question answering. In CVPR, Cited by: §4.1.
  • D. Graff, J. Kong, K. Chen, and K. Maeda (2003) English gigaword. Linguistic Data Consortium, Philadelphia. Cited by: §4.2.
  • J. Gu, J. Cai, S. R. Joty, L. Niu, and G. Wang (2018) Look, imagine and match: improving textual-visual cross-modal retrieval with generative models. In CVPR, Cited by: Table 1.
  • Z. Hu, H. Shi, Z. Yang, B. Tan, T. Zhao, J. He, W. Wang, X. Yu, L. Qin, D. Wang, et al. (2018) Texar: a modularized, versatile, and extensible toolkit for text generation. arXiv preprint arXiv:1809.00794. Cited by: §4.2, §4.2, Table 6.
  • Y. Huang, W. Wang, and L. Wang (2017) Instance-aware image and sentence matching with selective multimodal lstm. In CVPR, Cited by: §4.1.
  • Y. Huang, Q. Wu, C. Song, and L. Wang (2018) Learning semantic concepts and order for image and sentence matching. In CVPR, Cited by: Table 1.
  • A. Karpathy and L. Fei-Fei (2015) Deep visual-semantic alignments for generating image descriptions. In CVPR, Cited by: §1, §4.1, §4.1.
  • J. Kim, J. Jun, and B. Zhang (2018) Bilinear attention networks. In NeurIPS, Cited by: §1, §2.1, §4.1, §4.1, Table 2, Table 3.
  • T. N. Kipf and M. Welling (2016) Semi-supervised classification with graph convolutional networks. arXiv:1609.02907. Cited by: §3.
  • M. Kusner, Y. Sun, N. Kolkin, and K. Weinberger (2015) From word embeddings to document distances. In ICML, Cited by: §3.
  • K. Lee et al. (2018) Stacked cross attention for image-text matching. In ECCV, Cited by: §1, §1, §2.1, Table 1.
  • L. Li, Z. Gan, Y. Cheng, and J. Liu (2019a) Relation-aware graph attention network for visual question answering. In ICCV, Cited by: §1, §2.2.
  • Y. Li, C. Gu, T. Dullien, O. Vinyals, and P. Kohli (2019b) Graph matching networks for learning the similarity of graph structured objects. In ICML, Cited by: §1, §3.
  • C. Lin (2004) Rouge: a package for automatic evaluation of summaries. Text Summarization Branches Out. Cited by: §4.2.
  • T. Lin, M. Maire, S. Belongie, J. Hays, P. Perona, D. Ramanan, P. Dollár, and C. L. Zitnick (2014) Microsoft COCO: common objects in context. In ECCV, Cited by: §4.1, §4.1.
  • J. Lu, C. Xiong, D. Parikh, and R. Socher (2017) Knowing when to look: adaptive attention via a visual sentinel for image captioning. In CVPR, Cited by: Table 4.
  • G. Luise, A. Rudi, M. Pontil, and C. Ciliberto (2018) Differential properties of sinkhorn approximation for learning with wasserstein distance. arXiv:1805.11897. Cited by: §2.3.
  • M. Malinowski and M. Fritz (2014) A multi-world approach to question answering about real-world scenes based on uncertain input. In NeurIPS, Cited by: §1.
  • Y. Mroueh, C. Li, T. Sercu, A. Raj, and Y. Cheng (2018) Sobolev GAN. In ICLR, Cited by: §3.
  • H. Nam, J. Ha, and J. Kim (2017) Dual attention networks for multimodal reasoning and matching. In CVPR, Cited by: Table 1.
  • K. Papineni, S. Roukos, T. Ward, and W. Zhu (2002) BLEU: a method for automatic evaluation of machine translation. In ACL, Cited by: §4.2.
  • G. Peyré, M. Cuturi, et al. (2019) Computational optimal transport. Foundations and Trends® in Machine Learning. Cited by: §1, §2.2.
  • G. Peyré, M. Cuturi, and J. Solomon (2016) Gromov-wasserstein averaging of kernel and distance matrices. In ICML, Cited by: §1, §2.3, §3.
  • B. A. Plummer et al. (2015) Flickr30k entities: collecting region-to-phrase correspondences for richer image-to-sentence models. In ICCV, Cited by: §4.1.
  • S. Ren, K. He, R. Girshick, and J. Sun (2015) Faster r-cnn: towards real-time object detection with region proposal networks. In NeurIPS, Cited by: §4.1.
  • Y. Rubner, C. Tomasi, and L. J. Guibas (1998) A metric for distributions with applications to image databases. In ICCV, Cited by: §3.
  • A. M. Rush, S. Chopra, and J. Weston (2015) A neural attention model for abstractive sentence summarization. In EMNLP, Cited by: Table 6.
  • T. Salimans, H. Zhang, A. Radford, and D. Metaxas (2018) Improving GANs using optimal transport. In ICLR, Cited by: §3.
  • M. Schuster and K. K. Paliwal (1997) Bidirectional recurrent neural networks. Transactions on Signal Processing. Cited by: §4.1.
  • J. H. Van Lint, R. M. Wilson, and R. M. Wilson (2001) A course in combinatorics. Cambridge university press. Cited by: §2.3.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NeurIPS, Cited by: §1, §1, §2.1, §4.2, §4.2, Table 5.
  • T. Vayer, L. Chapel, R. Flamary, R. Tavenard, and N. Courty (2018) Optimal transport for structured data with application on graphs. arXiv:1805.09114. Cited by: §2.4, §3.
  • P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Lio, and Y. Bengio (2018) Graph attention networks. In ICLR, Cited by: §1, §3.
  • O. Vinyals, A. Toshev, S. Bengio, and D. Erhan (2015) Show and tell: a neural image caption generator. In CVPR, Cited by: Table 4.
  • Y. Xie, X. Wang, R. Wang, and H. Zha (2018) A fast proximal point method for Wasserstein distance. In arXiv:1802.04307, Cited by: §2.2.
  • H. Xu, D. Luo, and L. Carin (2019a) Scalable gromov-wasserstein learning for graph partitioning and matching. In NeurIPS, Cited by: §2.4.
  • H. Xu, D. Luo, H. Zha, and L. Carin (2019b) Gromov-wasserstein learning for graph matching and node embedding. In ICML, Cited by: §2.4.
  • K. Xu, J. Ba, R. Kiros, K. Cho, A. C. Courville, R. Salakhutdinov, R. S. Zemel, and Y. Bengio (2015) Show, attend and tell: neural image caption generation with visual attention.. In ICML, Cited by: §1, Table 4.
  • Z. Yang, X. He, J. Gao, L. Deng, and A. Smola (2016a) Stacked attention networks for image question answering. In CVPR, Cited by: §1.
  • Z. Yang, D. Yang, C. Dyer, X. He, A. Smola, and E. Hovy (2016b) Hierarchical attention networks for document classification. In NAACL, Cited by: §1.
  • T. Yao, Y. Pan, Y. Li, and T. Mei (2018) Exploring visual relationship for image captioning. In ECCV, Cited by: §1.
  • Q. You, H. Jin, Z. Wang, C. Fang, and J. Luo (2016) Image captioning with semantic attention. In CVPR, Cited by: Table 4.
  • Z. Yu, J. Yu, Y. Cui, D. Tao, and Q. Tian (2019) Deep modular co-attention networks for visual question answering. In CVPR, Cited by: §1.
  • R. Zhang, C. Chen, Z. Gan, Z. Wen, W. Wang, and L. Carin (2020) Nested-wasserstein self-imitation learning for sequence generation. arXiv:2001.06944. Cited by: §3.
  • Z. Zheng, L. Zheng, M. Garrett, Y. Yang, M. Xu, and Y. Shen (2020) Dual-path convolutional image-text embeddings with instance loss. TOMM. Cited by: Table 1.