Wasserstein Embedding for Graph Learning

06/16/2020 ∙ by Soheil Kolouri, et al.

We present Wasserstein Embedding for Graph Learning (WEGL), a novel and fast framework for embedding entire graphs in a vector space, in which various machine learning models are applicable for graph-level prediction tasks. We leverage new insights on defining similarity between graphs as a function of the similarity between their node embedding distributions. Specifically, we use the Wasserstein distance to measure the dissimilarity between node embeddings of different graphs. Different from prior work, we avoid pairwise calculation of distances between graphs and reduce the computational complexity from quadratic to linear in the number of graphs. WEGL calculates Monge maps from a reference distribution to each node embedding and, based on these maps, creates a fixed-sized vector representation of the graph. We evaluate our new graph embedding approach on various benchmark graph-property prediction tasks, showing state-of-the-art classification performance, while having superior computational efficiency.




1 Introduction

Many exciting and practical machine learning applications involve learning from graph-structured data. While images, videos, and temporal signals (e.g., audio or biometrics) are instances of data that are supported on grid-like structures, data in social networks, cyber-physical systems, communication networks, chemistry, and bioinformatics often live on irregular structures [backstrom2011supervised, sadreazami2017distributed, naderializadeh2020wireless, jin2017predicting, agrawal2018large]. One can represent such data as (attributed) graphs, which are universal data structures. Efficient and generalizable learning from graph-structured data opens the door to a vast number of applications, which were beyond the reach of classic machine learning (ML) and, more specifically, deep learning (DL) algorithms.

Analyzing graph-structured data has received significant attention from the ML, network science, and signal processing communities over the past few years. On the one hand, there has been a rush toward extending the success of deep neural networks to graph-structured data, which has led to a variety of graph neural network (GNN) architectures. On the other hand, the research on kernel approaches [gartner2003graph], perhaps most notably the random walk kernel [kashima2003marginalized] and the Weisfeiler-Lehman (WL) subtree kernel [shervashidze2011weisfeiler], remains an active field of study, and the methods developed therein provide competitive performance in various graph representation tasks (see the recent survey by Kriege et al. [kriege2020survey]).

To learn graph representations, GNN-based frameworks make use of three generic modules, which provide i) feature aggregation, ii) graph pooling (i.e., readout), and iii) classification [Hu*2020Strategies]. The feature aggregator provides a vector representation for each node of the graph, referred to as a node embedding. The graph pooling module creates a representation for the graph from its node embeddings, whose dimensionality is fixed regardless of the underlying graph size, and which can then be analyzed using a downstream classifier of choice. On the graph kernel side, one leverages a kernel to measure the similarities between pairs of graphs, and uses conventional kernel methods to perform learning on a set of graphs [hofmann2008kernel]. A recent example of such methods is the framework provided by Togninalli et al. [togninalli2019wasserstein], in which the authors propose a novel node embedding inspired by the WL kernel, and combine the resulting node embeddings with the Wasserstein distance [villani2008optimal, kolouri2017optimal] to measure the dissimilarity between two graphs. Afterwards, they leverage conventional kernel methods based on the pairwise-measured dissimilarities to perform learning on graphs.

Considering the ever-increasing scale of graph datasets, which may contain tens of thousands of graphs or millions to billions of nodes per graph, the issue of scalability and algorithmic efficiency becomes of vital importance for graph learning methods [hu2020open, hernandez2020measuring]. However, both of the aforementioned paradigms of GNNs and kernel methods suffer in this sense. On the GNN side, acceleration of the training procedure is challenging and scales poorly as the graph size grows [mlg2019_50]. On the graph kernel side, the need for calculating the matrix of all pairwise similarities can be a burden in datasets with a large number of graphs, especially if calculating the similarity between each pair of graphs is computationally expensive. For instance, in the method proposed in [togninalli2019wasserstein], the computational complexity of each calculation of the Wasserstein distance is cubic in the number of nodes (or linearithmic for the entropy-regularized distance).

To overcome these issues, inspired by the linear optimal transport framework of Wang et al. [wang2013linear], we propose a linear Wasserstein Embedding for Graph Learning, which we refer to as WEGL. Our proposed approach embeds a graph into a Hilbert space, where the distance between two embedded graphs provides a true metric between the graphs that approximates their 2-Wasserstein distance. For a set of graphs, the proposed method provides:

  1. Reduced computational complexity of estimating the graph Wasserstein distance [togninalli2019wasserstein] for a dataset of $M$ graphs from a quadratic complexity in the number of graphs, i.e., $M(M-1)/2$ calculations, to linear complexity, i.e., $M$ calculations of the Wasserstein distance; and

  2. An explicit Hilbertian embedding for graphs, which is not restricted to kernel methods, and therefore can be used in conjunction with any downstream classification framework.
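To make the first benefit concrete, the short sketch below counts how many Wasserstein-distance computations each strategy needs for a dataset of M graphs: M(M-1)/2 when every pair is compared, and M when each graph is compared once to a shared reference. The function names are illustrative and not part of any WEGL code.

```python
def pairwise_ot_solves(num_graphs: int) -> int:
    """Distance computations needed when every pair of graphs is compared."""
    return num_graphs * (num_graphs - 1) // 2


def reference_ot_solves(num_graphs: int) -> int:
    """Distance computations needed when each graph is compared to one reference."""
    return num_graphs


for m in (100, 1_000, 10_000):
    print(m, pairwise_ot_solves(m), reference_ot_solves(m))
```

For 1,000 graphs this is 499,500 pairwise computations versus 1,000 reference-based ones.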

We show that compared to multiple GNN and graph kernel baselines, WEGL achieves either state-of-the-art or competitive results on benchmark graph-level classification tasks, including classical graph classification datasets [KKMMN2016] and the recent molecular property-prediction benchmarks [hu2020open]. We also compare the algorithmic efficiency of WEGL with two baseline GNN and graph kernel methods and demonstrate that it is much more computationally efficient relative to those algorithms.

2 Background and Related Work

In this section, we provide a brief background on different methods for deriving representations for graphs and an overview on Wasserstein distances by reviewing the related work in the literature.

2.1 Graph Representation Methods

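Let $G = (\mathcal{V}, \mathcal{E})$ denote a graph, comprising a set of nodes $\mathcal{V}$ and a set of edges $\mathcal{E} \subseteq \mathcal{V}^2$, where two nodes $u, v \in \mathcal{V}$ are connected to each other if and only if $(u, v) \in \mathcal{E}$. For each node $v \in \mathcal{V}$, we define its set of neighbors as $\mathcal{N}_v \triangleq \{u \in \mathcal{V} : (u, v) \in \mathcal{E}\}$. The nodes of the graph may have categorical labels and/or continuous attribute vectors. We use a unified notation of $x_v \in \mathbb{R}^F$ to denote the label and/or attribute vector of node $v \in \mathcal{V}$, where $F$ denotes the node feature dimensionality. Moreover, we use $w_{uv} \in \mathbb{R}^E$ to denote the edge feature vector for any edge $(u, v) \in \mathcal{E}$, where $E$ denotes the edge feature dimensionality. Node and edge features may be present depending on the graph dataset under consideration.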

To learn graph properties from the graph structure and its node/edge features, one can use a function $\psi: \mathcal{G} \to \mathcal{H}$ to map any graph $G$ in the space of all possible graphs $\mathcal{G}$ to an embedding $\psi(G)$ in a Hilbert space $\mathcal{H}$. Kernel methods have been among the most popular ways of creating such graph embeddings. A graph kernel is defined as a function $k: \mathcal{G}^2 \to \mathbb{R}$, where for two graphs $G$ and $G'$, $k(G, G')$ represents the inner product of the embeddings $\psi(G)$ and $\psi(G')$ over the Hilbert space $\mathcal{H}$.

Kashima et al. [kashima2003marginalized] introduced graph kernels based on random walks on labeled graphs. Subsequently, shortest-path kernels were introduced in [borgwardt2005shortest]. These works have been followed by graphlet and Weisfeiler-Lehman subtree kernel methods [shervashidze2009efficient, shervashidze2011weisfeiler, morris2017glocalized]. More recently, kernel methods using assignment-based approaches [kriege2016valid, nikolentzos2017matching], spectral approaches [kondor2016multiscale], and graph decomposition algorithms [nikolentzos2018degeneracy] have also been proposed in the literature.

Despite being successful for many years, kernel methods often fail to leverage the explicit continuous features that are provided for the graph nodes and/or edges, making them less adaptable to the underlying data distribution. To alleviate these issues, and thanks in part to the prominent success of deep learning in many domains, including computer vision and natural language processing, techniques based on graph neural networks (GNNs) have emerged as an alternative paradigm for learning representations from graph-based data. In its most general form, a GNN consists of $L$ hidden layers, where at the $l$th layer, each node $v \in \mathcal{V}$ aggregates and combines messages from its 1-hop neighboring nodes $\mathcal{N}_v$, resulting in the feature vector

$$x_v^{(l)} = \Psi_{\theta_l}\left(x_v^{(l-1)}, \{x_u^{(l-1)}\}_{u \in \mathcal{N}_v}, \{w_{uv}\}_{u \in \mathcal{N}_v}\right), \quad \forall l \in \{1, \dots, L\},\ \forall v \in \mathcal{V}, \tag{1}$$

where $\Psi_{\theta_l}(\cdot)$ denotes a parametrized and differentiable combining function.

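At the input layer, each node $v \in \mathcal{V}$ starts with its initial feature vector $x_v^{(0)} \in \mathbb{R}^F$, and the sequential application of $L$ GNN layers, as in (1), computes intermediate feature vectors $\{x_v^{(l)}\}_{l=1}^{L}$. At the GNN output, the feature vectors of all nodes from all layers go through a global pooling (i.e., readout) function $\Phi(\cdot)$, resulting in the final graph embedding

$$\psi(G) = \Phi\left(\{x_v^{(l)}\}_{v \in \mathcal{V},\, l \in \{0, \dots, L\}}\right). \tag{2}$$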

Kipf and Welling [kipf2016semi] proposed a GNN architecture based on a graph convolutional network (GCN) framework. This work, alongside other notable works on geometric deep learning [defferrard2016convolutional], initiated a surge of interest in GNN architectures, which has led to several architectures, including the Graph ATtention network (GAT) [velivckovic2017graph], Graph SAmple and aggreGatE (GraphSAGE) [hamilton2017inductive], and the Graph Isomorphism Network (GIN) [xu2018how]. Each of these architectures modifies the combine and readout functions $\Psi_{\theta_l}(\cdot)$ and $\Phi(\cdot)$ in (1) and (2), respectively, demonstrating state-of-the-art performance in a variety of graph representation learning tasks.

2.2 Wasserstein Distances


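Let $\mu_i$ denote a Borel probability measure with finite $p$th moment defined on $\mathcal{Z} \subseteq \mathbb{R}^d$, with corresponding probability density function $p_i$, i.e., $d\mu_i(z) = p_i(z)\,dz$. The 2-Wasserstein distance between $\mu_i$ and $\mu_j$ defined on $\mathcal{Z}, \mathcal{Z}' \subseteq \mathbb{R}^d$ is the solution to the optimal mass transportation problem with $\ell_2$ transport cost [villani2008optimal]:

$$W_2(\mu_i, \mu_j) = \left(\inf_{\gamma \in \Gamma(\mu_i, \mu_j)} \int_{\mathcal{Z} \times \mathcal{Z}'} \|z - z'\|^2 \, d\gamma(z, z')\right)^{\frac{1}{2}}, \tag{3}$$

where $\Gamma(\mu_i, \mu_j)$ is the set of all transportation plans $\gamma$ such that $\gamma(A \times \mathcal{Z}') = \mu_i(A)$ and $\gamma(\mathcal{Z} \times B) = \mu_j(B)$ for any Borel subsets $A \subseteq \mathcal{Z}$ and $B \subseteq \mathcal{Z}'$. Due to Brenier’s theorem [brenier1991polar], for absolutely continuous probability measures $\mu_i$ and $\mu_j$ (with respect to the Lebesgue measure), the 2-Wasserstein distance can be equivalently obtained from

$$W_2(\mu_i, \mu_j) = \left(\inf_{f \in MP(\mu_i, \mu_j)} \int_{\mathcal{Z}} \|z - f(z)\|^2 \, d\mu_i(z)\right)^{\frac{1}{2}}, \tag{4}$$

where $MP(\mu_i, \mu_j) = \{f: \mathcal{Z} \to \mathcal{Z}' \mid f_\#\mu_i = \mu_j\}$ and $f_\#\mu_i$ represents the pushforward of measure $\mu_i$, characterized as

$$\int_{f^{-1}(A)} d\mu_i(z) = \int_{A} df_\#\mu_i(z) \quad \text{for any Borel subset } A \subseteq \mathcal{Z}'. \tag{5}$$

The mapping $f$ is referred to as a transport map [kolouri2017optimal], and the optimal transport map is called the Monge map. For discrete probability measures, when the transport plan is a deterministic optimal coupling, such a transport plan is referred to as a Monge coupling [villani2008optimal].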
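As a small illustration of the 2-Wasserstein distance between discrete measures, the sketch below handles the one special case where the optimal coupling is known in closed form: two 1D samples of equal size with uniform weights, where the Monge coupling matches the k-th smallest point of one sample to the k-th smallest point of the other. The function name `w2_1d` is illustrative; the general case in higher dimensions requires solving a transport problem instead.

```python
import math


def w2_1d(xs, ys):
    """2-Wasserstein distance between two equal-size 1D samples with uniform weights."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1D the optimal (Monge) coupling pairs points in sorted order.
    xs_sorted, ys_sorted = sorted(xs), sorted(ys)
    mean_sq = sum((x - y) ** 2 for x, y in zip(xs_sorted, ys_sorted)) / len(xs)
    return math.sqrt(mean_sq)


a = [0.0, 1.0, 2.0]
b = [3.0, 5.0, 4.0]
print(w2_1d(a, b))  # every point moves by 3, so the distance is 3.0
```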

Recently, Togninalli et al. [togninalli2019wasserstein] proposed a Wasserstein kernel for graphs that involves pairwise calculation of the Wasserstein distance between graph representations. Pairwise calculation of the Wasserstein distance, however, could be expensive, especially for large graph datasets. In what follows, we apply the linear optimal transportation framework [wang2013linear] to define a Hilbertian embedding, in which the $\ell_2$ distance provides a true metric between the probability measures that approximates $W_2$. We show that in a dataset containing $M$ graphs, this framework reduces the computational complexity from calculating $M(M-1)/2$ linear programs to $M$.

3 Linear Wasserstein Embedding

Figure 1: Graphical representation of the linear Wasserstein embedding framework, where the probability distributions are mapped to the tangent space with respect to a fixed reference distribution. The figure is adapted from


Wang et al. [wang2013linear] and the follow-up work by Kolouri et al. [kolouri2016continuous] describe a framework for an isometric Hilbertian embedding of 2D images (treated as probability measures) such that the Euclidean distance between the embedded images approximates $W_2$. Going beyond pattern recognition and image analysis, the framework can be used for any set of probability measures to provide a linear Wasserstein embedding.

3.1 Theoretical Foundation

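We adhere to the definition of the linear Wasserstein embedding for continuous measures. However, all derivations hold for discrete measures as well. More precisely, let $\mu_0$ be a reference probability measure defined on $\mathcal{Z} \subseteq \mathbb{R}^d$, with a positive probability density function $p_0$, s.t. $d\mu_0(z) = p_0(z)\,dz$ and $p_0(z) > 0$ for all $z \in \mathcal{Z}$. Let $f_i$ denote the Monge map that pushes $\mu_0$ into $\mu_i$, i.e.,

$$f_i = \operatorname*{argmin}_{f \in MP(\mu_0, \mu_i)} \int_{\mathcal{Z}} \|z - f(z)\|^2 \, d\mu_0(z). \tag{6}$$

Define $\phi(\mu_i) \triangleq (f_i - id)\sqrt{p_0}$, where $id(z) = z$ for all $z \in \mathcal{Z}$ is the identity function. In cartography, such a mapping is known as the equidistant azimuthal projection, while in differential geometry, it is called the logarithmic map. The mapping $\phi(\cdot)$ has the following characteristics (partially illustrated in Figure 1):

  1. $\phi(\cdot)$ provides an isometric embedding for probability measures, i.e., using the Jacobian equation $p_i = \det(Df_i^{-1})\, p_0(f_i^{-1})$, where $f_i = \frac{\phi(\mu_i)}{\sqrt{p_0}} + id$.

  2. $\phi(\mu_0) = 0$, i.e., the reference is mapped to zero.

  3. $\|\phi(\mu_i) - \phi(\mu_0)\|_2 = \|\phi(\mu_i)\|_2 = W_2(\mu_i, \mu_0)$, i.e., the mapping preserves distances to $\mu_0$.

  4. $\|\phi(\mu_i) - \phi(\mu_j)\|_2 \approx W_2(\mu_i, \mu_j)$, i.e., the $\ell_2$ distance between $\phi(\mu_i)$ and $\phi(\mu_j)$, while being a true metric between $\mu_i$ and $\mu_j$, is an approximation of $W_2(\mu_i, \mu_j)$.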

Embedding probability measures via $\phi(\cdot)$ involves calculation of Monge maps. The fourth characteristic above states that $\phi(\cdot)$ provides a linear embedding for the probability measures. Therefore, we call it the linear Wasserstein embedding. In practice, and for discrete distributions, the Monge coupling is used, which could be approximated from the Kantorovich plan (i.e., the transport plan) via the so-called barycentric projection (see [wang2013linear]). A detailed description of the capabilities of the linear Wasserstein embedding framework is included in Appendix A.1. Below, we provide the numerical details of the barycentric projection [ambrosio2008gradient, wang2013linear].
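The characteristics listed above can be checked numerically in a toy setting. The sketch below is an illustration under strong simplifying assumptions (1D samples, equal sample sizes, uniform weights), where the Monge map is obtained by sorting rather than by solving a transport problem; in this 1D setting the approximation in the fourth characteristic happens to be exact. The names `lot_embed_1d` and `l2_distance` are hypothetical helpers, not part of any WEGL implementation.

```python
import math


def lot_embed_1d(reference, target):
    """Linear Wasserstein embedding of `target` with respect to `reference`.

    Assumes 1D samples, equal sample sizes, and uniform weights, so the Monge
    map matches the k-th smallest reference point to the k-th smallest target.
    """
    n = len(reference)
    if len(target) != n:
        raise ValueError("this sketch assumes equal sample sizes")
    ref_order = sorted(range(n), key=lambda j: reference[j])
    target_sorted = sorted(target)
    monge = [0.0] * n
    for rank, j in enumerate(ref_order):
        monge[j] = target_sorted[rank]
    # Displacement f(z) - z, scaled by 1/sqrt(n) (discrete analogue of sqrt(p_0)).
    return [(monge[j] - reference[j]) / math.sqrt(n) for j in range(n)]


def l2_distance(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


reference = [0.0, 1.0, 2.0, 3.0]
a = [5.0, 4.0, 6.0, 7.0]      # the reference shifted by 4, in arbitrary order
b = [10.0, 12.0, 11.0, 13.0]  # the reference shifted by 10, in arbitrary order

phi_ref = lot_embed_1d(reference, reference)
phi_a = lot_embed_1d(reference, a)
phi_b = lot_embed_1d(reference, b)
zero = [0.0] * len(reference)

print(phi_ref)                    # all zeros: the reference maps to the origin
print(l2_distance(phi_a, zero))   # 4.0, the 2-Wasserstein distance from a to the reference
print(l2_distance(phi_a, phi_b))  # 6.0, the 2-Wasserstein distance between a and b
```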

3.2 Numerical Details

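Consider a set of probability distributions $\{p_i\}_{i=1}^{M}$, and let $Z_i = [z_1^i, \dots, z_{N_i}^i]^T \in \mathbb{R}^{N_i \times d}$ be an array containing $N_i$ i.i.d. samples from distribution $p_i$, i.e., $z_k^i \sim p_i$ for all $k \in \{1, \dots, N_i\}$. Let us define $p_0$ to be a reference distribution, with $Z_0 = [z_1^0, \dots, z_N^0]^T \in \mathbb{R}^{N \times d}$, where $N = \lfloor \frac{1}{M} \sum_{i=1}^{M} N_i \rfloor$ and $z_j^0 \sim p_0$. The optimal transport plan between $Z_i$ and $Z_0$, denoted by $\pi_i^* \in \mathbb{R}^{N \times N_i}$, is the solution to the following linear program,

$$\pi_i^* = \operatorname*{argmin}_{\pi \in \Pi_i} \sum_{j=1}^{N} \sum_{k=1}^{N_i} \pi_{jk} \|z_j^0 - z_k^i\|^2, \tag{7}$$

where $\Pi_i \triangleq \left\{\pi \in \mathbb{R}_{\geq 0}^{N \times N_i} \,\middle|\, N_i \sum_{j=1}^{N} \pi_{jk} = N \sum_{k=1}^{N_i} \pi_{jk} = 1,\ \forall k \in \{1, \dots, N_i\},\ \forall j \in \{1, \dots, N\}\right\}$. The Monge map is then approximated from the optimal transport plan by barycentric projection via

$$F_i = N (\pi_i^* Z_i) \in \mathbb{R}^{N \times d}. \tag{8}$$

Finally, the embedding can be calculated by $\phi(Z_i) = (F_i - Z_0)/\sqrt{N} \in \mathbb{R}^{N \times d}$. With a slight abuse of notation, we use $\phi(p_i)$ and $\phi(Z_i)$ interchangeably throughout the paper. Due to the barycentric projection, here, $\phi(\cdot)$ is only pseudo-invertible.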
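The barycentric projection step can be sketched in a few lines once a transport plan is available. The example below assumes uniform weights, takes the plan as an input (here a small hand-written plan that is optimal for this toy configuration, standing in for the output of a linear-programming solver), maps each reference point to N times the plan-weighted sum of the target points, and scales the displacement by the square root of N. The function name `barycentric_embedding` and the toy data are illustrative assumptions.

```python
import math


def barycentric_embedding(plan, target, reference):
    """Fixed-size embedding of `target` relative to `reference`.

    plan:      N x N_i non-negative matrix, rows sum to 1/N, columns to 1/N_i
    target:    N_i points, each a list of d coordinates
    reference: N points, each a list of d coordinates
    Returns an N x d list of lists.
    """
    n, d = len(reference), len(reference[0])
    embedding = []
    for j in range(n):
        # Barycentric image of reference point j: N * (row j of plan) @ target.
        image = [
            n * sum(plan[j][k] * target[k][c] for k in range(len(target)))
            for c in range(d)
        ]
        embedding.append([(image[c] - reference[j][c]) / math.sqrt(n) for c in range(d)])
    return embedding


reference = [[0.0, 0.0], [4.0, 0.0]]                     # N = 2
target = [[1.0, 1.0], [1.0, -1.0], [6.0, 1.0], [6.0, -1.0]]  # N_i = 4
# Hand-written plan: each reference point sends half its mass to each of its two
# nearest target points (optimal here by inspection).
plan = [
    [0.25, 0.25, 0.0, 0.0],
    [0.0, 0.0, 0.25, 0.25],
]

emb = barycentric_embedding(plan, target, reference)
print(emb)  # 2 x 2 output, regardless of the number of target points
```

The output has as many rows as the reference, which is what gives every graph an embedding of the same size.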

4 WEGL: A Linear Wasserstein Embedding for Graphs

The application of the optimal transport problem to graphs is multifaceted. For instance, some works focus on solving the “structured” optimal transport concerning an optimal probability flow, where the transport cost comes from distances on an often unchanging underlying graph [leonard2016lazy, essid2018quadratically, titouan2019]. Here, we are interested in applying optimal transport to measure the dissimilarity between two graphs [maretic2019got, dong2020copt, togninalli2019wasserstein]. Our work differs significantly from [maretic2019got, dong2020copt], which measure the dissimilarity between non-attributed graphs based on distributions defined by their Laplacian spectra, and is closer to [togninalli2019wasserstein].

Our proposed graph embedding framework, termed Wasserstein Embedding for Graph Learning (WEGL), combines node embedding methods for graphs with the linear Wasserstein embedding explained in Section 3. More precisely, let $\{G_i = (\mathcal{V}_i, \mathcal{E}_i)\}_{i=1}^{M}$ denote a set of $M$ individual graphs, each with a set of possible node features $\{x_v\}_{v \in \mathcal{V}_i}$ and a set of possible edge features $\{w_{uv}\}_{(u,v) \in \mathcal{E}_i}$. Let $h(\cdot)$ be an arbitrary node embedding process, where $h(G_i) = Z_i = [z_1, \dots, z_{|\mathcal{V}_i|}]^T \in \mathbb{R}^{|\mathcal{V}_i| \times d}$. Having the node embeddings $\{Z_i\}_{i=1}^{M}$, we can then calculate a reference node embedding $Z_0$ (see Section 4.2 for details), which leads to the linear Wasserstein embedding $\phi(Z_i)$ with respect to $Z_0$, as described in Section 3. Therefore, the entire embedding for each graph $G_i$ is obtained by composing $\phi(\cdot)$ and $h(\cdot)$, i.e., $\psi(G_i) = \phi(h(G_i))$. Figure 2 visualizes this process.

Figure 2: Our proposed graph embedding framework, WEGL, combines node embedding methods with the linear Wasserstein embedding framework described in Section 3. Given a graph $G_i$, we first embed the graph nodes into a $d$-dimensional Hilbert space and obtain an array of node embeddings, denoted by $h(G_i) = Z_i \in \mathbb{R}^{|\mathcal{V}_i| \times d}$. We then calculate the linear Wasserstein embedding of $Z_i$ with respect to a reference $Z_0$, i.e., $\phi(Z_i)$, to derive the final graph embedding.

4.1 Node Embedding

There are many choices for node embedding methods [chami2020machine]. These methods in general could be parametric or non-parametric, for instance, as in propagation/diffusion-based embeddings. Parametric embeddings are often implemented via a GNN encoder. The encoder can capture different graph properties depending on the type of supervision (e.g., supervised or unsupervised). Among the recent work on this topic, self-supervised embedding methods have been shown to be promising [Hu*2020Strategies].

In this paper, for our node embedding process $h(\cdot)$, we follow a similar non-parametric propagation/diffusion-based encoder as in [togninalli2019wasserstein]. One of the appealing advantages of this framework is its simplicity, as there are no trainable parameters involved. In short, given a graph $G = (\mathcal{V}, \mathcal{E})$ with node features $\{x_v\}_{v \in \mathcal{V}}$ and scalar edge features $\{w_{uv}\}_{(u,v) \in \mathcal{E}}$, we use the following instantiation of (1) to define the combining function as

$$x_v^{(l)} = \sum_{u \in \mathcal{N}_v \cup \{v\}} \frac{w_{uv}}{\sqrt{\deg(u)\deg(v)}} \cdot x_u^{(l-1)}, \quad \forall l \in \{1, \dots, L\},\ \forall v \in \mathcal{V}, \tag{9}$$

where for any node $v \in \mathcal{V}$, its degree $\deg(v)$ is defined as its number of neighbors in $G$ augmented with self-connections, i.e., $\deg(v) \triangleq 1 + |\mathcal{N}_v|$. Note that the normalization of the messages between graph nodes by the (square root of) the two end-point degrees in (9) has also been used in other architectures, including GCN [kipf2016semi]. For the cases where the edge weights are not available, including self-connection weights $\{w_{vv}\}_{v \in \mathcal{V}}$, we set them to one. In Appendix A.2, we show how we use an extension of (9) to treat graphs with multiple edge features/labels. Finally, we let $z_v = g\left(\{x_v^{(l)}\}_{l=0}^{L}\right)$ represent the resultant embedding for each node $v \in \mathcal{V}$, where $g(\cdot)$ is a local pooling process on a single node (not a global pooling), e.g., concatenation or averaging.
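A minimal sketch of this parameter-free diffusion is given below. It follows the description in the text as read here: each node sums the features of its neighbors and of itself, each message is scaled by the edge weight divided by the square root of the product of the two end-point degrees, degrees count the self-connection, and missing weights default to one. The function name `diffuse` and the dictionary-based graph representation are assumptions made for illustration.

```python
import math


def diffuse(features, neighbors, num_layers, edge_weight=None):
    """Run `num_layers` rounds of degree-normalized feature diffusion.

    features:    dict node -> list of floats (initial features x^(0))
    neighbors:   dict node -> set of neighboring nodes (undirected graph)
    edge_weight: optional dict (u, v) -> float; missing entries default to 1
    Returns a list [x^(0), x^(1), ..., x^(L)] of dicts node -> feature list.
    """
    # Degree counts the neighbors plus one self-connection.
    deg = {v: 1 + len(neighbors[v]) for v in features}
    layers = [{v: list(x) for v, x in features.items()}]
    for _ in range(num_layers):
        prev, cur = layers[-1], {}
        for v in features:
            acc = [0.0] * len(prev[v])
            for u in list(neighbors[v]) + [v]:  # neighbors and the node itself
                w = 1.0 if edge_weight is None else edge_weight.get((u, v), 1.0)
                scale = w / math.sqrt(deg[u] * deg[v])
                for c, value in enumerate(prev[u]):
                    acc[c] += scale * value
            cur[v] = acc
        layers.append(cur)
    return layers


# Path graph 0 - 1 - 2 with a single scalar feature on node 0.
neighbors = {0: {1}, 1: {0, 2}, 2: {1}}
features = {0: [1.0], 1: [0.0], 2: [0.0]}
layers = diffuse(features, neighbors, num_layers=2)
print(layers[1])  # after one layer the feature has spread from node 0 to node 1
```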

4.2 Calculation of the Reference Distribution

To calculate the reference distribution, we use the $k$-means clustering algorithm on $\bigcup_{i=1}^{M} Z_i$ with $N = \lfloor \frac{1}{M} \sum_{i=1}^{M} N_i \rfloor$ centroids. Alternatively, one can calculate the Wasserstein barycenter [cuturi2014fast] of the node embeddings or simply use $N$ samples from a normal distribution. In Appendix A.3, we show that WEGL is not sensitive to the choice of the reference.
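The reference computation can be sketched as follows: pool the node embeddings of all graphs and cluster them, using the centroids as the reference points. The sketch uses a plain Lloyd's k-means with a deterministic initialization (the first k pooled points), and it assumes the number of centroids is the floor of the average number of nodes per graph; both choices are assumptions of this illustration, and a library k-means would normally be used in practice.

```python
def kmeans(points, k, num_iters=20):
    """Plain Lloyd's algorithm; the first k points serve as initial centroids."""
    centroids = [list(p) for p in points[:k]]
    for _ in range(num_iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            # Assign each point to its nearest centroid (squared Euclidean distance).
            nearest = min(
                range(k),
                key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])),
            )
            clusters[nearest].append(p)
        for c, members in enumerate(clusters):
            if members:  # an empty cluster keeps its previous centroid
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def reference_embedding(node_embeddings):
    """node_embeddings: one list of d-dimensional node vectors per graph."""
    pooled = [z for graph in node_embeddings for z in graph]
    # Assumed choice: floor of the average number of nodes per graph.
    num_centroids = len(pooled) // len(node_embeddings)
    return kmeans(pooled, num_centroids)


graphs = [
    [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0]],  # graph with 3 nodes
    [[10.0, 11.0], [0.0, 0.5]],              # graph with 2 nodes
]
reference = reference_embedding(graphs)
print(reference)  # two centroids, one per well-separated group of node embeddings
```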

5 Experimental Evaluation

In this section, we discuss the evaluation results of our proposed algorithm on multiple benchmark graph classification datasets. We used the PyTorch Geometric framework for implementing WEGL. In all experiments, we used the scikit-learn implementation of random forest as our downstream classifier on the embedded graphs [sklearn_api, breiman2001random]. Moreover, for a graph with $F$-dimensional initial node features, we report our results using the following three types of node embedding:

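  • Concat.: $z_v = g\left(\{x_v^{(l)}\}_{l=0}^{L}\right) = \left[x_v^{(0)} \,\|\, \dots \,\|\, x_v^{(L)}\right] \in \mathbb{R}^{(L+1)F}$, where $\|$ denotes concatenation.

  • Average: $z_v = g\left(\{x_v^{(l)}\}_{l=0}^{L}\right) = \frac{1}{L+1} \sum_{l=0}^{L} x_v^{(l)} \in \mathbb{R}^{F}$.

  • Final: $z_v = g\left(\{x_v^{(l)}\}_{l=0}^{L}\right) = x_v^{(L)} \in \mathbb{R}^{F}$.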
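The three local pooling choices above operate on the per-layer feature vectors of a single node. The sketch below shows them side by side for one node with three layers of two-dimensional features; the function names are illustrative.

```python
def pool_concat(layers):
    """Concatenate the node's feature vectors from all layers."""
    return [value for layer in layers for value in layer]


def pool_average(layers):
    """Average the node's feature vectors over all layers."""
    return [sum(column) / len(layers) for column in zip(*layers)]


def pool_final(layers):
    """Keep only the node's feature vector from the last layer."""
    return list(layers[-1])


# Feature vectors of one node at layers 0, 1, and 2 (F = 2, L = 2).
node_layers = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(pool_concat(node_layers))   # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(pool_average(node_layers))  # [3.0, 4.0]
print(pool_final(node_layers))    # [5.0, 6.0]
```

Concatenation grows the embedding dimension with the number of layers, while the other two keep it equal to the input feature dimension.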

5.1 TUD Benchmark Datasets

We first consider a set of social network, bioinformatics and molecule graph datasets [KKMMN2016]. The social network datasets (IMDB-BINARY, IMDB-MULTI, COLLAB, REDDIT-BINARY, REDDIT-MULTI-5K, and REDDIT-MULTI-12K) lack both node and edge features. Therefore, in these datasets, for node embedding type “Concat.,” we use the actual node degrees as their initial scalar features, while for node embedding types of “Average” and “Final,” we use a one-hot representation of the node degrees as their initial feature vectors, as also used in prior work, e.g., [xu2018how]. To handle the large scale of the REDDIT-MULTI-5K and REDDIT-MULTI-12K datasets, we clip the node degrees at 500 and we use scalar node degree features for all the three aforementioned node embedding types.

Moreover, for the molecule (MUTAG, PTC-MR, and NCI1) and bioinformatics (ENZYMES, PROTEINS, and D&D) datasets, we use the readily-provided node labels in [KKMMN2016] as the initial node feature vectors. Besides, for datasets with edge labels (MUTAG and PTC-MR), as explained in Appendix A.2, we use an extension of (9) to use the one-hot encoded edge features in the diffusion process.

To evaluate the performance of WEGL, we follow the methodology used in [yanardag2015deep, niepert2016learning, xu2018how], where for each dataset, we perform 10-fold cross-validation with random splitting on the entire dataset, conducting a grid search over a set of random forest hyperparameters and the number of diffusion layers, i.e., $L$ in (9). The complete list of hyperparameters can be found in Appendix A.2. Using the best set of hyperparameters returned by grid search, we then report the mean and standard deviation of the validation accuracies achieved during cross-validation.

Tables 1 and 2 show the classification accuracies achieved by WEGL on the aforementioned datasets as compared with several GNN and graph kernel baselines, whose results are extracted from the corresponding original papers. As the tables demonstrate, our proposed algorithm achieves either state-of-the-art or competitive results across all the datasets, and in particular, it is among the top-three performers in the majority of them. This shows the effectiveness of the proposed linear Wasserstein embedding for learning graph-level properties across different domains.

| Group | Method | IMDB-B | IMDB-M | COLLAB | RE-B | RE-M5K | RE-M12K |
|---|---|---|---|---|---|---|---|
| GNN | DGCNN¹ [zhang2018end] | 69.2 ± 3.0 | 45.6 ± 3.4 | 71.2 ± 1.9 | 87.8 ± 2.5 | 49.2 ± 1.2 | - |
| GNN | GraphSAGE [hamilton2017inductive] | 68.8 ± 4.5 | 47.6 ± 3.5 | 73.9 ± 1.7 | 84.3 ± 1.9 | 50.0 ± 1.3 | - |
| GNN | GIN [xu2018how] | 75.1 ± 5.1 | 52.3 ± 2.8 | 80.2 ± 1.9 | 92.4 ± 2.5 | 57.5 ± 1.5 | - |
| GNN | GNTK [du2019graph] | 76.9 ± 3.6 | 52.8 ± 4.6 | 83.6 ± 1.0 | - | - | - |
| GNN | CapsGNN [xinyi2019capsule] | 73.1 ± 4.8 | 50.3 ± 2.6 | 79.6 ± 0.9 | - | 52.9 ± 1.5 | 46.6 ± 1.9 |
| GK | DGK [yanardag2015deep] | 67.0 ± 0.6 | 44.6 ± 0.5 | 73.1 ± 0.3 | 78.0 ± 0.4 | 41.3 ± 0.2 | 32.2 ± 0.1 |
| GK | WL [shervashidze2011weisfeiler] | 73.8 ± 3.9 | 49.8 ± 0.5 | 74.8 ± 0.2 | 68.2 ± 0.2 | 51.2 ± 0.3 | 32.6 ± 0.3 |
| GK | RetGK [zhang2018retgk] | 71.0 ± 0.6 | 46.7 ± 0.6 | 73.6 ± 0.3 | 90.8 ± 0.2 | 54.2 ± 0.3 | 45.9 ± 0.2 |
| GK | AWE [ivanov2018anonymous] | 74.5 ± 5.8 | 51.5 ± 3.6 | 73.9 ± 1.9 | 87.9 ± 2.5 | 50.5 ± 1.9 | 39.2 ± 2.1 |
| GK | WWL [togninalli2019wasserstein] | 74.4 ± 0.8 | - | - | - | - | - |
| Ours | WEGL - Concat. | 74.9 ± 6.3 | 50.8 ± 4.0 | 79.8 ± 1.5 | 92.0 ± 0.8 | 55.1 ± 2.5 | 47.8 ± 0.8 |
| Ours | WEGL - Average | 74.4 ± 5.6 | 52.0 ± 4.1 | 78.8 ± 1.4 | 87.8 ± 3.0 | 53.3 ± 2.0 | 45.1 ± 0.9 |
| Ours | WEGL - Final | 75.4 ± 5.0 | 51.7 ± 4.6 | 79.1 ± 1.1 | 87.5 ± 2.0 | 53.2 ± 1.8 | 45.3 ± 1.4 |

¹ The results for DGCNN [zhang2018end] are reported from [Errica2020A].

Table 1: Accuracy (%) results of our method and comparison with the state-of-the-art in GNNs and graph kernels (GKs) on social network datasets. The top-three performers on each dataset are shown in bold.

| Group | Method | MUTAG | PTC-MR | ENZYMES | PROTEINS | D&D | NCI1 |
|---|---|---|---|---|---|---|---|
| GNN | DGCNN¹ [zhang2018end] | 85.8 | 58.6 | 38.9 ± 5.7 | 72.9 ± 3.5 | 76.6 ± 4.3 | 76.4 ± 1.7 |
| GNN | GraphSAGE [hamilton2017inductive] | 85.1 ± 7.6 | 63.9 ± 7.7 | - | 75.9 ± 3.2 | - | 77.7 ± 1.5 |
| GNN | GIN [xu2018how] | 89.4 ± 5.6 | 64.6 ± 7.0 | - | 76.2 ± 2.8 | - | 82.7 ± 1.7 |
| GNN | GNTK [du2019graph] | 90.0 ± 8.5 | 67.9 ± 6.9 | - | 75.6 ± 4.2 | - | 84.2 ± 1.5 |
| GNN | CapsGNN [xinyi2019capsule] | 86.7 ± 6.9 | - | 54.7 ± 5.7 | 76.3 ± 3.6 | 75.4 ± 4.2 | 78.4 ± 1.6 |
| GK | DGK [yanardag2015deep] | 82.7 ± 1.5 | 57.3 ± 1.1 | 27.1 ± 0.8 | 71.7 ± 0.5 | - | 62.5 ± 0.3 |
| GK | WL [shervashidze2011weisfeiler] | 80.7 ± 3.0 | 57.0 ± 2.0 | 53.2 ± 1.1 | 72.9 ± 0.6 | 79.8 ± 0.4 | 80.1 ± 0.5 |
| GK | RetGK [zhang2018retgk] | 90.1 ± 1.0 | 67.9 ± 1.4 | 59.1 ± 1.1 | 75.2 ± 0.3 | 81.0 ± 0.5 | 83.5 ± 0.2 |
| GK | AWE [ivanov2018anonymous] | 87.9 ± 9.8 | - | 35.8 ± 5.9 | - | 71.5 ± 4.0 | - |
| GK | WWL [togninalli2019wasserstein] | 87.3 ± 1.5 | 66.3 ± 1.2 | 59.1 ± 0.8 | 74.3 ± 0.6 | 79.7 ± 0.5 | 85.8 ± 0.3 |
| Ours | WEGL - Concat. | 88.3 ± 5.1 | 64.6 ± 7.4 | 60.5 ± 5.9 | 76.1 ± 3.3 | 78.6 ± 2.8 | 76.8 ± 1.7 |
| Ours | WEGL - Average | 86.2 ± 5.8 | 67.5 ± 7.7 | 58.7 ± 6.9 | 75.8 ± 3.6 | 78.5 ± 4.4 | 76.6 ± 1.1 |
| Ours | WEGL - Final | 86.2 ± 9.5 | 66.3 ± 6.5 | 57.5 ± 5.5 | 76.5 ± 4.2 | 77.9 ± 4.3 | 75.9 ± 1.2 |

Table 2: Accuracy (%) results of our method and comparison with the state-of-the-art in GNNs and graph kernels (GKs) on molecule and bioinformatics datasets. The top-three performers on each dataset are shown in bold.

5.2 Molecular Property Prediction on the Open Graph Benchmark

We also tested our algorithm on the molecular property prediction task on the ogbg-molhiv dataset [hu2020open]. This dataset is part of the recently-released Open Graph Benchmark [hu2020open], which involves node-level, link-level, and graph-level learning and prediction tasks on multiple datasets spanning diverse problem domains. The ogbg-molhiv dataset, in particular, is a molecular tree-like dataset, consisting of 41,127 graphs, with an average of 25.5 nodes and 27.5 edges per graph. Each graph is a molecule, with nodes representing atoms and edges representing bonds between them, and it includes both node and edge attributes, characterizing the atom and bond features. The goal is to predict a binary label indicating whether or not a molecule inhibits HIV replication.

To train and evaluate our proposed method, we use the scaffold split provided by the dataset, and report the mean and standard deviation of the results across 10 different random seeds. Aside from searching over the set of hyperparameters used in Section 5.1, we also optimize the initial node feature dimensionality, derived through the atom feature encoder module provided in [hu2020open]. The complete implementation details can be found in Appendix A.2.

Table 3 shows the evaluation results of WEGL on the ogbg-molhiv dataset in terms of the ROC-AUC (i.e., Receiver Operating Characteristic Area Under the Curve), alongside baseline GCN and GIN results, extracted from [hu2020open]. The virtual node variants of the algorithms correspond to cases where each molecular graph is augmented with a master virtual node that is connected to all the nodes in the graph. This node serves as a shortcut for message passing among the graph nodes, bringing any pair of nodes within at most two hops of each other. While achieving full training accuracy, our proposed method achieves state-of-the-art test results on this dataset, showing the high expressive power of WEGL in large-scale graph datasets.
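The virtual node augmentation described above amounts to adding one extra node connected to every existing node. The sketch below illustrates this on an adjacency-set representation; the function name `add_virtual_node` and the node identifier `"virtual"` are hypothetical choices for this example.

```python
def add_virtual_node(neighbors, virtual="virtual"):
    """Return a copy of the graph with one extra node connected to all nodes.

    neighbors: dict node -> set of neighboring nodes (undirected graph)
    """
    if virtual in neighbors:
        raise ValueError("virtual node identifier is already used in the graph")
    augmented = {v: set(nbrs) | {virtual} for v, nbrs in neighbors.items()}
    augmented[virtual] = set(neighbors)  # the virtual node sees every original node
    return augmented


# Path graph 0 - 1 - 2 - 3: nodes 0 and 3 are three hops apart.
path = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
augmented = add_virtual_node(path)
# After augmentation, nodes 0 and 3 share the virtual node as a common neighbor,
# so they are two hops apart.
print(augmented[0] & augmented[3])  # {'virtual'}
```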

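| Group | Method | Virtual Node | Training | Validation | Test |
|---|---|---|---|---|---|
| GNN | GCN | No | 88.7 ± 2.2 | 82.0 ± 1.4 | 76.1 ± 1.0 |
| GNN | GCN | Yes | 90.1 ± 4.7 | 83.8 ± 0.9 | 76.0 ± 1.2 |
| GNN | GIN | No | 88.6 ± 2.5 | 82.3 ± 0.9 | 75.6 ± 1.4 |
| GNN | GIN | Yes | 92.7 ± 3.8 | 84.8 ± 0.7 | 77.1 ± 1.5 |
| Ours | WEGL - Concat. | No | 100.0 ± 0.0 | 79.2 ± 2.2 | 75.5 ± 1.5 |
| Ours | WEGL - Average | No | 100.0 ± 0.0 | 78.0 ± 2.0 | 76.5 ± 1.2 |
| Ours | WEGL - Final | No | 100.0 ± 0.0 | 77.7 ± 1.9 | 76.5 ± 1.7 |
| Ours | WEGL - Concat. | Yes | 100.0 ± 0.0 | 81.9 ± 1.3 | 76.5 ± 1.8 |
| Ours | WEGL - Average | Yes | 100.0 ± 0.0 | 81.0 ± 1.9 | 77.1 ± 0.9 |
| Ours | WEGL - Final | Yes | 100.0 ± 0.0 | 81.0 ± 1.0 | 77.6 ± 1.1 |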

Table 3: ROC-AUC (%) results on ogbg-molhiv dataset. The best validation and test results are shown in bold.

5.3 Computation Time

We compared the algorithmic efficiency of WEGL with GIN and the Wasserstein Weisfeiler-Lehman (WWL) graph kernel. For this experiment, we used the IMDB-BINARY dataset and measured the wall-clock training and inference times for the three methods (to achieve results reported in Table 1). For WEGL and WWL, we used the exact linear programming solver (as opposed to the entropy-regularized version). We carried out our experiments for WEGL and WWL on an Intel® Core™ i7-4980HQ CPU, while we used an NVIDIA® Tesla® P100 GPU for GIN. Figure 3 shows the wall-clock time for training and testing the three algorithms. As the figure illustrates, training WEGL is several orders of magnitude faster than WWL and GIN. During inference, WEGL is slightly slower than GIN (CPU vs. GPU) and significantly faster than WWL. Using GPU-accelerated implementations of the diffusion process (9) and the entropy-regularized transport problem could potentially further enhance the computational efficiency of WEGL during inference.

Figure 3: Average wall-clock time comparison of our proposed method (WEGL - Concat.) with WWL [togninalli2019wasserstein] and GIN [xu2018how] on the IMDB-BINARY dataset, where the shaded areas represent the standard deviation of the run times. WEGL and WWL were implemented on an Intel® Core™ i7-4980HQ CPU, while GIN was implemented using an NVIDIA® Tesla® P100 GPU.

6 Conclusion

We considered the problem of graph property prediction and introduced the linear Wasserstein Embedding for Graph Learning, which we denoted as WEGL. Similar to [togninalli2019wasserstein], our approach also relies on measuring the Wasserstein distances between the node embeddings of graphs. Unlike [togninalli2019wasserstein], however, we further embed the node embeddings of graphs into a Hilbert space, in which their Euclidean distance approximates their 2-Wasserstein distance. WEGL provides two significant benefits: 1) it has linear complexity in the number of graphs (as opposed to the quadratic complexity of [togninalli2019wasserstein]), and 2) it enables the application of any ML algorithm of choice, such as random forest, which was the downstream classifier we used in this work. Finally, we demonstrated WEGL’s superior performance on various benchmark datasets. The current formulation of WEGL assumes a fixed node embedding as input to the linear Wasserstein embedding process. Therefore, it does not allow for end-to-end training of a parametric node embedding method. We leave the extension of the proposed method to enable end-to-end training for future work.


This material is supported by the United States Air Force under Contract No. FA8750-19-C-0098, and by the National Institute of Health (NIH) under Contract No. GM130825. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the United States Air Force, DARPA, and NIH.


Appendix A Appendix

Here we provide further details on the theoretical aspect of WEGL, our implementation details, and the sensitivity of the results to the choice of reference distribution.

A.1 Detailed Discussion on Linear Wasserstein Embedding

Figure 4: Illustration of (a) the meaning behind the transport plan $\gamma$ used in the LOT distance in (13), and (b) the idea of the barycentric projection, which provides a fixed-size representation (i.e., of size $N \times d$).

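The linear Wasserstein embedding used in WEGL is based on the Linear Optimal Transport (LOT) framework introduced in [wang2013linear]. The main idea is to compute the “projection” of the manifold of probability measures to the tangent space at a fixed reference measure. In particular, the tangent space at measure $\mu_0$ is the set of vector fields $v: \mathcal{Z} \to \mathbb{R}^d$ such that the inner product is the weighted $\ell_2$:

$$\langle v_i, v_j \rangle_{\mu_0} = \int_{\mathcal{Z}} v_i(z) \cdot v_j(z) \, d\mu_0(z). \tag{10}$$

We can then define $v_i(z) = f_i(z) - z$, where $f_i$ is the optimal transport map from $\mu_0$ to $\mu_i$. Note that $v_0 = 0$, $\|v_i - v_0\|_{\mu_0}^2 = \|v_i\|_{\mu_0}^2 = W_2^2(\mu_i, \mu_0)$, and

$$\|v_i - v_j\|_{\mu_0}^2 \approx W_2^2(\mu_i, \mu_j). \tag{11}$$

In the paper, we use $\phi(\mu_i) = (f_i - id)\sqrt{p_0}$ to turn the weighted $\ell_2$ into $\ell_2$.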

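The discussion above assumes an absolutely continuous reference measure $\mu_0$. A more interesting treatment of the problem is via the generalized geodesics defined in [ambrosio2008gradient], connecting $\mu_i$ and $\mu_j$ and enabling us to use discrete reference measures. Following the notation in [ambrosio2008gradient], given the reference measure $\mu_0$, let $\Gamma(\mu_0, \mu_i)$ be the set of optimal transport plans between $\mu_0$ and $\mu_i$, and let $\Gamma(\mu_0, \mu_i, \mu_j)$ be the set of all measures on the product space $\mathcal{Z} \times \mathcal{Z} \times \mathcal{Z}$ such that the marginals over the first and second coordinates and over the first and third coordinates belong to $\Gamma(\mu_0, \mu_i)$ and $\Gamma(\mu_0, \mu_j)$, respectively. Then the linearized optimal transport distance is defined as

$$d_{LOT,\mu_0}^2(\mu_i, \mu_j) = \inf_{\gamma \in \Gamma(\mu_0, \mu_i, \mu_j)} \int_{\mathcal{Z} \times \mathcal{Z} \times \mathcal{Z}} \|z - z'\|^2 \, d\gamma(z^0, z, z'). \tag{12}$$

In a discrete setting, where $\mu_0 = \frac{1}{N} \sum_{n=1}^{N} \delta_{z_n^0}$, $\mu_i = \frac{1}{N_i} \sum_{k=1}^{N_i} \delta_{z_k^i}$, and $\mu_j = \frac{1}{N_j} \sum_{l=1}^{N_j} \delta_{z_l^j}$, we have

$$d_{LOT,\mu_0}^2(\mu_i, \mu_j) = \min_{\gamma \in \Gamma(\mu_0, \mu_i, \mu_j)} \sum_{n=1}^{N} \sum_{k=1}^{N_i} \sum_{l=1}^{N_j} \gamma_{nkl} \|z_k^i - z_l^j\|^2. \tag{13}$$

See Figure 4a for a depiction of Equation (13)’s meaning. Finally, the idea of barycentric projection used to approximate Monge couplings and provide a fixed-size representation is shown in Figure 4b.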

Next, to demonstrate the capability of the linear Wasserstein embedding, we present the following experiment. Consider a set of distributions $\{p_i\}_{i=1}^{M}$, where each $p_i$ is a translated and dilated ring distribution in $\mathbb{R}^2$, and $N_i$ samples are observed from $p_i$, where $N_i$ and $N_j$ could be different for $i \neq j$. We then consider a normal distribution as the reference distribution and calculate the linear Wasserstein embedding with respect to the reference (see Figure 5a). Given the pseudo-invertible nature of the embedding, to demonstrate the modeling capability of the framework, we calculate the mean in the embedding space (i.e., on the vector fields), and invert it to obtain the mean distribution $\bar{p}$. Figure 5b shows the calculated mean, indicating that the linear Wasserstein embedding framework has successfully retrieved a ring distribution as the mean. Finally, we calculate Euclidean geodesics in the embedding space (i.e., the convex combination of the vector fields) between two pairs of embedded distributions, and show the inverted geodesics in Figures 5c and 5d, respectively. As the figures demonstrate, the calculated geodesics follow the Wasserstein geodesics.

Figure 5: An experiment demonstrating the capability of the linear Wasserstein embedding. (a) A simple dataset consisting of shifted and scaled noisy ring distributions , where we only observe samples from each distribution, together with the process of obtaining the linear Wasserstein embedding with respect to a reference distribution. In short, for each distribution , the embedding approximates the Monge map (i.e., a vector field) from the reference samples to the target samples by a barycentric projection of the optimal transport plan. Adding samples in the embedding space corresponds to adding their vector fields, which can be used to calculate (b) the mean distribution in the embedding space, i.e., and (c)-(d) the Euclidean geodesics in this space, i.e., for . As can be seen, the calculated mean is the Wasserstein barycenter of the dataset, and the Euclidean geodesics in the embedding space follow the Wasserstein geodesics in the original space.

A.2 Implementation Details

To derive the node embeddings, we use the diffusion process in (9) for the datasets without edge features/labels, i.e., all the social network datasets (IMDB-BINARY, IMDB-MULTI, COLLAB, REDDIT-BINARY, REDDIT-MULTI-5K, and REDDIT-MULTI-12K) and four of the molecule and bioinformatics datasets (ENZYMES, PROTEINS, D&D, and NCI1). We specifically set for any and also for all self-connections, i.e., .

The remaining datasets contain edge labels that cannot be directly used with (9). Specifically, each edge in the MUTAG and PTC-MR datasets has a categorical label, encoded as a one-hot vector of dimension four. Moreover, in the ogbg-molhiv dataset, each edge has three categorical features indicating bond type (five categories), bond stereochemistry (six categories) and whether the bond is conjugated (two categories). We first convert each categorical feature to its one-hot representation, and then concatenate them together, resulting in a binary 13-dimensional feature vector for each edge.

In each of the three aforementioned datasets, for any edge , let us denote its binary feature vector by , where is equal to 4, 4, and 13 for MUTAG, PTC-MR, and ogbg-molhiv, respectively. We then use the following extension of the diffusion process in (9),


where for any , denotes the th element of , and for any node , we define as its degree over the th elements of the edge features; i.e., . We assign vectors of all-one features to the self-connections in the graph; i.e., . Note that the formulation of the diffusion process in (14) can be seen as an extension of (9), where the underlying graph with multi-dimensional edge features is broken into parallel graphs with non-negative single-dimensional edge features, and the parallel graphs perform message passing at each round/layer of the diffusion process.
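The "parallel graphs" view of the extended diffusion can be sketched as follows. The exact form of (9) and (14) is not reproduced in this section, so the symmetric degree normalization used here (dividing each weight by the square root of the product of the per-channel degrees of its endpoints) is an assumption, as are the function name `diffusion_layer` and the dense-matrix representation.

```python
import numpy as np

def diffusion_layer(node_feats, edge_index, edge_feats):
    """One round of message passing over E parallel single-channel graphs.

    node_feats: (V, d) node features from the previous layer
    edge_index: (M, 2) undirected edges as (u, v) pairs, no self-loops
    edge_feats: (M, E) non-negative edge features
    """
    num_nodes = node_feats.shape[0]
    num_channels = edge_feats.shape[1]
    u, v = edge_index[:, 0], edge_index[:, 1]
    propagation = np.zeros((num_nodes, num_nodes))
    for e in range(num_channels):
        # Self-connections get an all-one feature in every channel.
        weights = np.eye(num_nodes)
        weights[u, v] = edge_feats[:, e]
        weights[v, u] = edge_feats[:, e]
        # Per-channel degree; at least 1 because of the self-connection.
        degree = weights.sum(axis=1)
        scale = 1.0 / np.sqrt(degree)
        # Assumed symmetric normalization of this channel's graph.
        propagation += scale[:, None] * weights * scale[None, :]
    # All parallel graphs pass messages in the same round.
    return propagation @ node_feats

# Toy molecule-like graph: a 3-node path with one-hot edge labels (E = 2).
node_feats = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = np.array([[0, 1], [1, 2]])
edge_feats = np.array([[1.0, 0.0], [0.0, 1.0]])
updated = diffusion_layer(node_feats, edge_index, edge_feats)
```

Applying `diffusion_layer` repeatedly gives the per-layer node features. A dense adjacency matrix keeps the sketch short, whereas a sparse representation would be more appropriate for large graphs.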

For the ogbg-molhiv experiments in which virtual nodes were appended to the original molecule graphs, we set the initial feature vectors of all virtual nodes to all-zero vectors. Moreover, for any graph in the dataset with nodes, we set the edge features for the edge between the virtual node and each of the original graph nodes as . The normalization by the number of graph nodes is included so as to regulate the degree of the virtual node used in (14). We also include the resultant embedding of the virtual node at the end of the diffusion process in the calculation of the graph embedding .

In the experiments conducted on each dataset, once the node embeddings are derived from the diffusion process, we standardize them by subtracting the mean embedding and dividing by the standard deviation of the embeddings, where the statistics are calculated based on all the graphs in the dataset. Moreover, to reduce the computational complexity of estimating the graph embeddings for the ogbg-molhiv dataset, we further apply a 20-dimensional PCA on the node embeddings.
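The preprocessing described above might look like the following sketch. It assumes per-dimension statistics pooled over all nodes of all graphs, adds a small `eps` to avoid division by zero, and computes PCA through an SVD. These details and the function name `standardize_and_reduce` are assumptions, not taken from the source.

```python
import numpy as np

def standardize_and_reduce(node_embeddings, n_components=None, eps=1e-8):
    """Standardize node embeddings over the dataset, then optionally apply PCA.

    node_embeddings: list of (num_nodes_i, d) arrays, one per graph
    n_components:    target dimensionality, or None to skip PCA
    """
    sizes = [len(x) for x in node_embeddings]
    stacked = np.concatenate(node_embeddings, axis=0)
    # Statistics are computed over all nodes of all graphs.
    mean = stacked.mean(axis=0)
    std = stacked.std(axis=0)
    result = (stacked - mean) / (std + eps)
    if n_components is not None:
        # PCA via SVD; the standardized data is already centered.
        _, _, components = np.linalg.svd(result, full_matrices=False)
        result = result @ components[:n_components].T
    # Split back into one array per graph.
    return np.split(result, np.cumsum(sizes)[:-1])

# Hypothetical data: three graphs with 5, 7, and 9 nodes, 32-dim embeddings.
rng = np.random.default_rng(0)
graphs = [rng.standard_normal((n, 32)) for n in (5, 7, 9)]
reduced = standardize_and_reduce(graphs, n_components=20)
print([g.shape for g in reduced])
```

Per the text, the PCA step would only be enabled for ogbg-molhiv; for the other datasets `n_components=None` leaves the standardized embeddings at their original dimensionality.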


We perform a grid search over the following set of hyperparameters in each of the experiments:

  • Random Forest: , , and .

  • Number of Diffusion Layers in (9) and (14): .

  • Initial Node Feature Dimensionality (for ogbg-molhiv only): .

Figure 6: ROC-AUC (%) results on the ogbg-molhiv dataset when the reference distribution is calculated by k-means (Section 4.2) on the training dataset (denoted as k-means), compared to when it is fixed to be a normal distribution (denoted as Normal). With a p-value, the choice of the template is statistically insignificant.

A.3 Sensitivity to the Choice of Reference Distribution

To measure how strongly WEGL depends on the choice of reference distribution, we changed the reference to a normal distribution (i.e., a data-independent reference). We compared the results of WEGL using this new reference distribution to those obtained with a reference distribution calculated via k-means on the training set. We used the ogbg-molhiv dataset with initial node embedding of size and diffusion layers. We ran the experiment with 100 different random seeds, and measured the test ROC-AUC of WEGL calculated with the two aforementioned reference distributions. Figure 6 shows the results of this experiment, indicating that the effect of the choice of reference distribution is statistically insignificant.
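The source does not state which statistical test produced the reported p-value. As one possible way to compare two sets of per-seed scores, the sketch below uses a two-sided permutation test on the difference of means. The test choice, the function name `permutation_p_value`, and the placeholder scores are all assumptions, and the placeholder numbers are not results from the paper.

```python
import numpy as np

def permutation_p_value(scores_a, scores_b, n_permutations=10_000, seed=0):
    """Two-sided permutation test on the difference of mean scores."""
    rng = np.random.default_rng(seed)
    scores_a = np.asarray(scores_a, dtype=float)
    scores_b = np.asarray(scores_b, dtype=float)
    observed = abs(scores_a.mean() - scores_b.mean())
    pooled = np.concatenate([scores_a, scores_b])
    n_a = len(scores_a)
    count = 0
    for _ in range(n_permutations):
        # Randomly reassign the pooled scores to the two groups.
        shuffled = rng.permutation(pooled)
        diff = abs(shuffled[:n_a].mean() - shuffled[n_a:].mean())
        count += int(diff >= observed)
    # Add-one correction keeps the estimate strictly positive.
    return (count + 1) / (n_permutations + 1)

# Placeholder scores standing in for 100 per-seed test ROC-AUC values
# under each reference choice (synthetic, for illustration only).
rng = np.random.default_rng(1)
placeholder_kmeans = rng.normal(loc=0.0, scale=1.0, size=100)
placeholder_normal = rng.normal(loc=0.0, scale=1.0, size=100)
p_value = permutation_p_value(placeholder_kmeans, placeholder_normal)
```

A large p-value would be consistent with the conclusion that the two reference choices perform comparably, while a small one would indicate a systematic difference between them.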