Recently, there has been considerable interest in developing learning algorithms for structured data such as graphs. For example, molecular property prediction has many applications in chemistry and drug discovery (yang2019analyzing; vamathevan2019applications). Historically, graphs were either decomposed into features such as molecular fingerprints or compared through non-parametric graph kernels (vishwanathan2010graph; shervashidze2011weisfeiler); more recently, they have been mapped to learned representations via graph neural networks (GNNs) (duvenaud2015convolutional; NIPS2016_6081; kipf2016semi).
Despite their successes, graph neural networks are often underutilized in whole-graph prediction tasks such as molecular property prediction. Specifically, while GNNs produce an embedding for each atom of the molecule, these embeddings are typically aggregated by a simple operation such as a sum or an average, turning the molecule into a single vector prior to classification or regression. As a result, some of the information naturally extracted by the node embeddings may be lost.
Recent work by Togninalli et al. togninalli2019wasserstein proposed to dispense with the aggregation step altogether and instead derive a kernel function over graphs by directly comparing node embeddings as point clouds through optimal transport (Wasserstein distance). Their non-parametric model yields better empirical performance than popular graph kernels, but has not yet been extended to the more challenging parametric case.
Motivated by this observation and drawing inspiration from prototypical networks (snell2017prototypical), we introduce a new class of graph neural networks where the key representational step consists of comparing each input graph to a set of abstract prototypes (fig. 1). These prototypes play the role of dictionary items or basis functions in the comparison; they are stored as point clouds, as if they had been encoded from real graphs. Each input graph is first encoded into a set of node embeddings using a GNN. We then compare the resulting point cloud of embeddings to the point clouds of the prototypes, measuring the distance between two point clouds with the Wasserstein distance from optimal transport. Viewed as abstract basis functions, the prototypes can be understood as keys that highlight property values associated with different structural features. In contrast to kernel methods, the prototypes are learned together with the GNN parameters in an end-to-end manner.
Our model improves upon traditional aggregation by explicitly tapping into the full set of node embeddings without first collapsing them into a single vector. We prove that, unlike standard GNN aggregation, our model defines a class of set functions that are universal approximators.
Introducing point clouds as free parameters creates a challenging optimization problem. Indeed, as the models are trained end-to-end, the primary signal is initially available only in aggregate form. If trained as is, the prototypes often collapse to single points, which reduces the Wasserstein distance between point clouds to a Euclidean comparison of their means. To counter this effect, we introduce a contrastive regularizer which effectively prevents the model from collapsing (Section 3.2). We demonstrate empirically that it both improves model performance and generates richer prototypes.
Our contributions are summarized as follows. First, we introduce an efficiently trainable class of graph neural networks enhanced with optimal transport (OT) primitives for computing graph representations. Second, we devise a principled noise contrastive regularizer that prevents the model from collapsing back to standard aggregation, thus fully exploiting the OT geometry. Third, we provide a mathematical justification of the increased representational power compared to the standard aggregation methods used in popular GNNs. Finally, our model shows consistent empirical improvements over the previous state of the art on molecular datasets, while also yielding smoother graph embedding spaces.
2.1 Directed Message Passing Neural Networks
We briefly recall the simplified D-MPNN architecture (dai2016discriminative), which was successfully used for molecular property prediction by Yang et al. yang2019analyzing.
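This model takes as input a directed graph $G = (V, E)$, with node and edge features denoted by $x_v$ and $e_{vw}$ respectively, for $v, w$ in the vertex set $V$ and $v \to w$ an edge in $E$. The parameters of D-MPNN are the weight matrices $\{W_i, W_m, W_o\}$.
It keeps track of messages $m_{vw}^t$ and hidden states $h_{vw}^t$ for each step $t$, defined as follows. An initial hidden state is set to $h_{vw}^0 := \mathrm{ReLU}(W_i\,\mathrm{cat}(x_v, e_{vw}))$, where “cat” denotes concatenation. Then, the message passing operates as
$$m_{vw}^{t+1} = \sum_{k \in N(v) \setminus \{w\}} h_{kv}^t, \qquad h_{vw}^{t+1} = \mathrm{ReLU}\big(h_{vw}^0 + W_m m_{vw}^{t+1}\big), \tag{1}$$
where $N(v) = \{k \in V \mid (k, v) \in E\}$ denotes $v$'s incoming neighbors. After $T$ steps of message passing, node embeddings are obtained by summing edge embeddings:
$$m_v = \sum_{w \in N(v)} h_{wv}^T, \qquad h_v = \mathrm{ReLU}\big(W_o\,\mathrm{cat}(x_v, m_v)\big). \tag{2}$$
A final graph embedding is then obtained as $h = \sum_{v \in V} h_v$, which is usually fed to a multilayer perceptron (MLP) for classification or regression.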
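The update rules above can be sketched in a few lines of NumPy. This is a minimal, dict-based illustration written for readability rather than speed, not the ChemProp implementation; the function name `dmpnn_forward`, the argument layout and the random weights are our own assumptions, and bias terms are omitted.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def dmpnn_forward(x, edge_feats, W_i, W_m, W_o, steps=3):
    """Simplified D-MPNN forward pass (dict-based, for readability only).

    x:          (num_nodes, d_x) node features x_v
    edge_feats: dict mapping each directed edge (v, w) to its feature vector e_vw
    W_i: (d_h, d_x + d_e), W_m: (d_h, d_h), W_o: (d_out, d_x + d_h)
    Returns the node embeddings H (num_nodes, d_out) and the graph embedding h = sum_v h_v.
    """
    edges = list(edge_feats)
    d_h = W_m.shape[0]
    # Incoming neighbours N(v) = {k : (k, v) is an edge}
    incoming = {v: [k for (k, u) in edges if u == v] for v in range(len(x))}
    # Initial hidden state h0_vw = ReLU(W_i cat(x_v, e_vw))
    h0 = {(v, w): relu(W_i @ np.concatenate([x[v], edge_feats[(v, w)]])) for (v, w) in edges}
    h = dict(h0)
    for _ in range(steps):
        new_h = {}
        for (v, w) in edges:
            # m_vw = sum of h_kv over incoming neighbours k of v, excluding w
            m_vw = sum((h[(k, v)] for k in incoming[v] if k != w), np.zeros(d_h))
            new_h[(v, w)] = relu(h0[(v, w)] + W_m @ m_vw)
        h = new_h
    # m_v = sum of the final states of v's incoming edges; h_v = ReLU(W_o cat(x_v, m_v))
    H = np.stack([
        relu(W_o @ np.concatenate([x[v], sum((h[(k, v)] for k in incoming[v]), np.zeros(d_h))]))
        for v in range(len(x))
    ])
    return H, H.sum(axis=0)


# Toy 3-atom chain 0-1-2 with both edge directions
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
edge_feats = {e: rng.normal(size=2) for e in [(0, 1), (1, 0), (1, 2), (2, 1)]}
W_i, W_m, W_o = rng.normal(size=(5, 6)), rng.normal(size=(5, 5)), rng.normal(size=(8, 9))
H, h = dmpnn_forward(x, edge_feats, W_i, W_m, W_o)
print(H.shape, h.shape)  # (3, 8) (8,)
```

The per-edge dictionaries mirror the equations one to one; a practical implementation would batch edges into index arrays instead.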
2.2 Optimal Transport Geometry
Optimal Transport (OT) is a mathematical framework that defines distances or similarities between objects such as probability distributions, either discrete or continuous, as the cost of an optimal transport plan from one to the other.
Wasserstein for point clouds.
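Let a point cloud $X = \{x_i\}_{i=1}^n$ of size $n$ be a set of $n$ points $x_i \in \mathbb{R}^d$. Given point clouds $X, Y$ of respective sizes $n, m$, a transport plan (or coupling) is a matrix $T$ of size $n \times m$ with entries in $[0, 1]$, satisfying the two following marginal constraints: $T\mathbf{1}_m = \frac{1}{n}\mathbf{1}_n$ and $T^\top \mathbf{1}_n = \frac{1}{m}\mathbf{1}_m$. Intuitively, the marginal constraints mean that $T$ preserves the mass from $X$ to $Y$. We denote the set of such couplings as $\mathcal{C}_{XY}$.
Given a cost function $c$ on $\mathbb{R}^d \times \mathbb{R}^d$, its associated Wasserstein discrepancy is defined as
$$\mathcal{W}_c(X, Y) := \min_{T \in \mathcal{C}_{XY}} \sum_{i,j} T_{ij}\, c(x_i, y_j). \tag{3}$$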
We further describe the shape of optimal transport plans for point clouds of the same size in Appendix B.3.
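For two point clouds of the same size with uniform weights, an optimal coupling can be taken to be a rescaled permutation (Appendix B.3), so the Wasserstein discrepancy defined above can be computed exactly by enumerating all $n!$ permutations. The NumPy sketch below does this only to make the definition concrete on tiny clouds; the helper names are hypothetical, and practical code would call an OT solver such as POT's `ot.emd`, which is what the experiments use.

```python
import itertools

import numpy as np


def sq_euclidean_cost(X, Y):
    """Cost matrix C_ij = ||x_i - y_j||^2 for clouds X (n, d) and Y (m, d)."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def wasserstein_equal_size(X, Y, cost=sq_euclidean_cost):
    """Exact W_c(X, Y) = min_T sum_ij T_ij c(x_i, y_j) for two clouds of equal size n with
    uniform marginals 1/n, found by brute force over the n! rescaled permutation matrices."""
    C = cost(X, Y)
    n = C.shape[0]
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)), key=lambda p: C[rows, list(p)].sum())
    T = np.zeros((n, n))
    T[rows, list(best)] = 1.0 / n  # each row and each column carries mass 1/n
    return float((T * C).sum()), T


X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
w_same, T = wasserstein_equal_size(X, X[[2, 0, 1]])     # same points, shuffled order
w_shift, _ = wasserstein_equal_size(X, X + [1.0, 0.0])  # every point translated by (1, 0)
print(round(w_same, 6), round(w_shift, 6))  # 0.0 1.0
```

Since the cost is squared Euclidean, translating a cloud by a vector $t$ costs exactly $\|t\|^2$, which the second call reproduces.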
3 Model & Practice
3.1 Architecture Enhancement
Reformulating standard architectures.
As mentioned at the end of Section 2.1, the final graph embedding $h$ obtained by aggregating node embeddings is usually fed to an MLP performing a matrix multiplication $Wh = (\langle w_i, h \rangle)_i$. Replacing $\langle \cdot, \cdot \rangle$ by a distance/kernel $k(\cdot, \cdot)$ allows the processing of more general graph representations than just vectors in $\mathbb{R}^d$, such as point clouds or adjacency tensors.
From a single point to a point cloud.
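We propose to replace the aggregated graph embedding $h = \sum_v h_v$ by the point cloud (of unaggregated node embeddings) $H = \{h_v\}_{v \in V}$, and the inner products $\langle h, w_i \rangle$ by the following Wasserstein discrepancy:
$$\mathcal{W}_c(H, Q_i) := \min_{T \in \mathcal{C}_{HQ_i}} \sum_{v, j} T_{vj}\, c(h_v, q_i^j), \tag{4}$$
where the $Q_i = \{q_i^j\}_j$ are point clouds and free parameters, and the cost is chosen as $c = \|\cdot - \cdot\|_2^2$ or $c = -\langle \cdot, \cdot \rangle$. Note that both options yield identical optimal transport plans.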
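The claim that both costs yield identical optimal plans follows from $\|h - q\|_2^2 = \|h\|^2 + \|q\|^2 - 2\langle h, q \rangle$: under the marginal constraints, the two norm terms contribute the same constant for every coupling, so the squared-L2 cost is a constant plus twice the negative dot-product cost. The sketch below checks this numerically on random equal-size clouds, where brute force over permutations is exact; the data and helper names are illustrative assumptions.

```python
import itertools

import numpy as np


def best_permutation(C):
    """Rescaled permutation minimising sum_i C[i, p(i)] (exact OT for equal-size uniform clouds)."""
    n = C.shape[0]
    return min(itertools.permutations(range(n)), key=lambda p: C[np.arange(n), list(p)].sum())


rng = np.random.default_rng(1)
H = rng.normal(size=(4, 3))  # node embeddings h_v of one graph
Q = rng.normal(size=(4, 3))  # one free prototype point cloud Q_i

C_l2 = ((H[:, None, :] - Q[None, :, :]) ** 2).sum(axis=-1)  # c = ||h - q||^2
C_dot = -H @ Q.T                                            # c = -<h, q>

p_l2, p_dot = best_permutation(C_l2), best_permutation(C_dot)
n, rows = len(H), np.arange(len(H))
w_l2 = C_l2[rows, list(p_l2)].sum() / n
w_dot = C_dot[rows, list(p_dot)].sum() / n
# W_L2 = mean ||h_v||^2 + mean ||q_j||^2 + 2 * W_{-dot}
identity_gap = w_l2 - ((H ** 2).sum(axis=1).mean() + (Q ** 2).sum(axis=1).mean() + 2 * w_dot)
print(p_l2 == p_dot, abs(identity_gap) < 1e-9)
```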
Greater representational power.
In Section 4, we formalize the extent to which this kernel has strictly greater power to distinguish between different point clouds than the kernel given by the standard inner product on top of a sum aggregation. In practice, we would also like our model to exploit this additional representational power. This practical concern is discussed in the next subsection.
3.2 Contrastive Regularization
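What would happen to $\mathcal{W}_c(H, Q_i)$ if all points $q_i^j$ belonging to the point cloud $Q_i$ collapsed to the same point $q_i$? All transport plans would then yield the same cost, giving for $c = -\langle \cdot, \cdot \rangle$:
$$\mathcal{W}_c(H, Q_i) = -\sum_{v, j} T_{vj} \langle h_v, q_i \rangle = -\left\langle \frac{h}{|V|}, q_i \right\rangle. \tag{5}$$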
In this scenario, our proposition would simply over-parametrize the standard Euclidean model.
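The degenerate case of Eq. (5) is easy to check numerically: once a prototype collapses to a single point, every coupling pays the same cost, so the transport plan carries no information. The sketch below, with random data of our own choosing, evaluates several valid couplings for a collapsed prototype and compares them with $-\langle h/|V|, q_i \rangle$.

```python
import itertools

import numpy as np

rng = np.random.default_rng(2)
n, d = 4, 3
H = rng.normal(size=(n, d))  # node embeddings h_v, |V| = n
q = rng.normal(size=d)       # the point the prototype has collapsed to
Q = np.tile(q, (n, 1))       # collapsed prototype: n copies of q
C = -H @ Q.T                 # cost c(h_v, q_i^j) = -<h_v, q_i^j>

# Valid couplings with marginals 1/n: all rescaled permutations plus the uniform coupling
couplings = [np.eye(n)[list(p)] / n for p in itertools.permutations(range(n))]
couplings.append(np.full((n, n), 1.0 / n**2))
costs = np.array([(T * C).sum() for T in couplings])

expected = -np.dot(H.sum(axis=0) / n, q)  # -<h / |V|, q> with h = sum_v h_v
print(np.allclose(costs, expected))  # True
```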
A first obstacle. Our first empirical trials with OT-enhanced GNNs showed that a model trained with only the Wasserstein component would sometimes perform similarly to the Euclidean baseline, in spite of its greater representational power. Since both models achieved similar training and test performance, the absence of improvement in generalization was most likely not due to overfitting.
Further investigation revealed that the Wasserstein model would naturally displace the points in each of its free point clouds in such a way that the optimal transport plan $T$ obtained by maximizing $\sum_{v,j} T_{vj} \langle h_v, q_i^j \rangle$ was not discriminative, i.e. many other transport plans would yield a similar Wasserstein cost. Indeed, as shown in Eq. (5), if each point cloud collapses to its mean, then the Wasserstein geometry collapses to Euclidean geometry, and any transport plan yields the same Wasserstein cost. Further explanations are provided in Appendix A.1.
This observation led us to consider a regularizer that encourages the model to displace its free point clouds so that the optimal transport plans it computes are discriminative against chosen contrastive transport plans. Namely, consider a point cloud $H$ of node embeddings and let $T^i$ be an optimal transport plan obtained in the computation of $\mathcal{W}_c(H, Q_i)$; for each $T^i$, we then build a set $\mathcal{N}(T^i) \subset \mathcal{C}_{HQ_i}$ of noisy/contrastive transports. If we denote by $\mathcal{W}_T(H, Q_i) := \sum_{v,j} T_{vj}\, c(h_v, q_i^j)$ the Wasserstein cost obtained for a particular transport $T$, then our contrastive regularization consists in maximizing the term:
$$\sum_i \log \frac{e^{-\mathcal{W}_{T^i}(H, Q_i)}}{e^{-\mathcal{W}_{T^i}(H, Q_i)} + \sum_{T \in \mathcal{N}(T^i)} e^{-\mathcal{W}_T(H, Q_i)}}, \tag{6}$$
which can be interpreted as the log-likelihood that the correct transport $T^i$ is (as it should be) a better minimizer of $\mathcal{W}_T(H, Q_i)$ than its negative samples. This can be considered as an approximation of $\log \Pr(T^i \mid H, Q_i)$, where the partition function is approximated by our selection of negative examples, as done e.g. by Nickel & Kiela nickel2017poincare. Its effect is shown in Figure 3.
The selection of negative examples must reflect the following trade-off: the set should (i) not be too large, for computational efficiency, while (ii) containing sufficiently meaningful and challenging contrastive samples. Details about the choice of contrastive samples are given in Appendix A.2. Note that replacing the set $\mathcal{N}(T^i)$ with a singleton $\{T\}$ for a contrastive random variable $T$ would let us rewrite Eq. (6) as $\sum_i \log \sigma\big(\mathcal{W}_T(H, Q_i) - \mathcal{W}_{T^i}(H, Q_i)\big)$, where $\sigma(x) = 1/(1 + e^{-x})$ is the sigmoid function, reminiscent of noise contrastive estimation (gutmann2010noise).
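A minimal NumPy sketch of the regularization term of Eq. (6) for a single prototype, using a log-sum-exp for numerical stability. The function names are hypothetical, and the plans in the usage lines are arbitrary valid couplings standing in for solver outputs. The final check confirms the singleton case: with one negative plan, the term reduces to a log-sigmoid of the cost difference.

```python
import numpy as np


def transport_cost(T, C):
    """W_T(H, Q_i) = sum_vj T_vj c(h_v, q_i^j) for a fixed plan T and cost matrix C."""
    return float((T * C).sum())


def contrastive_term(C, T_opt, negatives):
    """log( exp(-W_{T*}) / (exp(-W_{T*}) + sum_{T in N(T*)} exp(-W_T)) ) for one prototype.
    This quantity is maximised (summed over prototypes) during training."""
    scores = -np.array([transport_cost(T_opt, C)] + [transport_cost(T, C) for T in negatives])
    m = scores.max()
    log_partition = m + np.log(np.exp(scores - m).sum())  # stable log-sum-exp
    return scores[0] - log_partition


def log_sigmoid(z):
    return -np.logaddexp(0.0, -z)  # log(1 / (1 + e^{-z}))


rng = np.random.default_rng(3)
C = rng.normal(size=(3, 3))       # stand-in cost matrix
T_opt = np.eye(3) / 3             # stand-in for the optimal plan T^i
T_neg = np.eye(3)[[1, 2, 0]] / 3  # a single contrastive plan
lhs = contrastive_term(C, T_opt, [T_neg])
rhs = log_sigmoid(transport_cost(T_neg, C) - transport_cost(T_opt, C))
print(np.isclose(lhs, rhs))  # True
```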
3.3 Optimization & Complexity
Backpropagating gradients through optimal transport (OT) has been the subject of recent research: Genevay et al. genevay2017learning explain how to unroll and differentiate through the Sinkhorn procedure solving OT, which was extended by Schmitz et al. schmitz2018wasserstein to Wasserstein barycenters. More recently, however, Xu xu2019gromovfactor proposed to simply invoke the envelope theorem (afriat1971theory) to justify keeping the optimal transport plan fixed during the back-propagation of gradients through Wasserstein distances. For the sake of simplicity and training stability, we resort to the latter procedure: keeping $T$ fixed during back-propagation. We discuss complexity in Appendix C.
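The fixed-plan rule can be checked on a toy example. For the squared-L2 cost, freezing $T$ gives the prototype gradient $\partial / \partial q^j = 2 \sum_v T_{vj} (q^j - h_v)$. Where the optimal plan is unique, the envelope theorem says this equals the gradient of the re-solved Wasserstein cost, which the sketch below compares against central finite differences. Equal-size clouds and brute-force OT are simplifying assumptions of this illustration.

```python
import itertools

import numpy as np


def ot_equal_size(H, Q):
    """Exact squared-L2 OT between equal-size uniform clouds; returns (cost, plan)."""
    n = len(H)
    C = ((H[:, None, :] - Q[None, :, :]) ** 2).sum(axis=-1)
    rows = np.arange(n)
    best = list(min(itertools.permutations(range(n)), key=lambda p: C[rows, list(p)].sum()))
    T = np.zeros((n, n))
    T[rows, best] = 1.0 / n
    return float((T * C).sum()), T


def grad_fixed_plan(H, Q, T):
    """Gradient w.r.t. prototype points q_j of sum_vj T_vj ||h_v - q_j||^2 with T frozen:
    2 * sum_v T_vj (q_j - h_v)."""
    return 2.0 * (T.sum(axis=0)[:, None] * Q - T.T @ H)


rng = np.random.default_rng(4)
H, Q = rng.normal(size=(4, 2)), rng.normal(size=(4, 2))
_, T = ot_equal_size(H, Q)
g = grad_fixed_plan(H, Q, T)

# Central finite differences of the re-solved Wasserstein cost
eps = 1e-6
g_fd = np.zeros_like(Q)
for j in range(Q.shape[0]):
    for k in range(Q.shape[1]):
        Q_plus, Q_minus = Q.copy(), Q.copy()
        Q_plus[j, k] += eps
        Q_minus[j, k] -= eps
        g_fd[j, k] = (ot_equal_size(H, Q_plus)[0] - ot_equal_size(H, Q_minus)[0]) / (2 * eps)
print(np.allclose(g, g_fd, atol=1e-5))
```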
4 Theoretical Analysis
In this section we show that the standard architecture lacks a fundamental property of universal approximation of functions defined on point clouds, and that our proposed architecture recovers this property. We denote by $\mathcal{X}_n$ the set of point clouds $X = \{x_i\}_{i=1}^n$ of size $n$ in $\mathbb{R}^d$.
As seen in Section 3.1, we have replaced the sum aggregation followed by the Euclidean inner-product by Wasserstein discrepancies. How does this affect the function class and representations?
A common framework used to analyze the geometry inherited from similarities and discrepancies is that of kernel theory. A kernel on a set $\mathcal{X}$ is a symmetric function $k: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$, which can measure either similarities or discrepancies. An important property of a given kernel on a space is whether simple functions defined on top of this kernel can approximate any continuous function on the same space. This is called universality: a crucial property for regressing unknown target functions.
Universal kernels. A kernel $k$ defined on $\mathcal{X}_n$ is said to be universal if the following holds: for any compact subset $K \subset \mathcal{X}_n$, the set of functions of the form $\sum_{j=1}^m \alpha_j\, \sigma\big(k(\cdot, \theta_j) + \beta_j\big)$ (for $m \in \mathbb{N}$, $\alpha_j, \beta_j \in \mathbb{R}$ and $\theta_j \in \mathcal{X}_n$) is dense in the set $C(K)$ of continuous functions from $K$ to $\mathbb{R}$, w.r.t. the sup norm $\|\cdot\|_{\infty, K}$, $\sigma$ denoting the sigmoid. Although the notion of universality does not indicate how easy it is in practice to learn the correct function, it at least guarantees the absence of a fundamental bottleneck in the model using this kernel.
Theorem 1. We have that:
(i) the aggregation kernel $k_{\mathrm{agg}}$ is not universal;
(ii) the Wasserstein kernel $\mathcal{W}_{L2}$ defined in Theorem 2 is universal.
Proof: See Appendix B.1.
Universality of the Wasserstein kernel essentially comes from the fact that its square root defines a metric, and in particular from the separation axiom of metrics: if $\mathcal{W}_{L2}(X, Y) = 0$ then $X = Y$.
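A concrete instance of the non-universality of the aggregation kernel (the inner product of the summed points): two different clouds with the same sum receive identical kernel values against every comparison cloud, so no function built on that kernel can tell them apart, whereas the squared-L2 Wasserstein discrepancy separates them. The clouds below are our own toy choice, and the Wasserstein value is computed by brute force, which is exact for equal-size uniform clouds.

```python
import itertools

import numpy as np


def agg_kernel(X, Y):
    """Aggregation kernel: inner product of the summed points."""
    return float(X.sum(axis=0) @ Y.sum(axis=0))


def w_l2(X, Y):
    """Squared-L2 Wasserstein discrepancy for equal-size uniform clouds (exact brute force)."""
    n = len(X)
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return float(min(C[np.arange(n), list(p)].sum() for p in itertools.permutations(range(n))) / n)


# Two different clouds with the same sum (the zero vector)
X = np.array([[1.0, 0.0], [-1.0, 0.0]])
X_prime = np.array([[0.0, 1.0], [0.0, -1.0]])

rng = np.random.default_rng(5)
thetas = [rng.normal(size=(2, 2)) for _ in range(5)]  # arbitrary comparison clouds

# Identical values against every theta: no function of agg_kernel(., theta) separates X and X_prime
print(all(np.isclose(agg_kernel(X, t), agg_kernel(X_prime, t)) for t in thetas))  # True
# The Wasserstein discrepancy does separate them
print(w_l2(X, X_prime), w_l2(X, X))  # 2.0 0.0
```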
To simplify the mathematical analysis, similarity kernels are often required to be positive definite (p.d.), which corresponds to discrepancy kernels being conditionally negative definite (c.n.d.). Although such a property has the benefit of yielding the mathematical framework of Reproducing Kernel Hilbert Spaces, it essentially implies linearity, i.e. the possibility of embedding the geometry defined by that kernel in a linear vector space.
We now show that, interestingly, the Wasserstein kernel we use does not satisfy this property, and hence constitutes an interesting instance of a universal, non-p.d. kernel. We first recall these notions.
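Kernel definiteness. A kernel $k$ is positive definite (p.d.) on $\mathcal{X}$ if for $n \in \mathbb{N}^*$, $x_1, \dots, x_n \in \mathcal{X}$ and $c_1, \dots, c_n \in \mathbb{R}$, we have $\sum_{i,j} c_i c_j\, k(x_i, x_j) \ge 0$. It is conditionally negative definite (c.n.d.) on $\mathcal{X}$ if for $n \in \mathbb{N}^*$, $x_1, \dots, x_n \in \mathcal{X}$ and $c_1, \dots, c_n \in \mathbb{R}$ such that $\sum_i c_i = 0$, we have $\sum_{i,j} c_i c_j\, k(x_i, x_j) \le 0$.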
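These two notions relate to each other via the following result (boughorbel2005conditionally):
Proposition 1. Let $k$ be a symmetric kernel on $\mathcal{X}$, let $x_0 \in \mathcal{X}$ and define the kernel:
$$\tilde{k}(x, y) := \frac{1}{2}\big[k(x, x_0) + k(y, x_0) - k(x, y) - k(x_0, x_0)\big].$$
Then $\tilde{k}$ is p.d. if and only if $k$ is c.n.d. Example: $k = \|\cdot - \cdot\|_2^2$ and $x_0 = 0$ yield $\tilde{k} = \langle \cdot, \cdot \rangle$.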
The aggregation kernel against which we wish to compare the Wasserstein kernel is the inner product over a summation of the points in the point clouds: $k_{\mathrm{agg}}(X, Y) := \langle \sum_i x_i, \sum_j y_j \rangle$.
One can easily show that this also defines a p.d. kernel, and that $k_{\mathrm{agg}}(X, Y) \le n^2\, \mathcal{W}_{\mathrm{dot}}(X, Y)$. However, the Wasserstein kernel is not p.d., as shown by the following theorem, which we prove in Appendix B.2.
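Theorem 2. We have that:
(i) the (similarity) Wasserstein kernel $\mathcal{W}_{\mathrm{dot}}$ is not positive definite;
(ii) the (discrepancy) Wasserstein kernel $\mathcal{W}_{L2}$ is not conditionally negative definite, where:
$$\mathcal{W}_{\mathrm{dot}}(X, Y) := \max_{T \in \mathcal{C}_{XY}} \sum_{i,j} T_{ij} \langle x_i, y_j \rangle, \qquad \mathcal{W}_{L2}(X, Y) := \min_{T \in \mathcal{C}_{XY}} \sum_{i,j} T_{ij} \|x_i - y_j\|_2^2.$$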
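The sketch below gives a numerical illustration of Theorem 2 and of the contrasting behaviour of the aggregation kernel. The clouds are our own choice and may differ from the construction in Appendix B.2: four antipodal pairs $\{u, -u\}$ in $\mathbb{R}^2$ with $u$ at angles $0, \pi/4, \pi/2, 3\pi/4$. For such pairs the similarity kernel (maximum over couplings of the transported inner products) equals $|\langle u, v \rangle|$, its Gram matrix has eigenvalues $1 \pm \sqrt{2}$ and $1$ (twice), and the zero-sum coefficients $(1, -1, 1, -1)$ give the positive value $8(\sqrt{2} - 1)$ for the discrepancy kernel (minimum over couplings of the transported squared distances). Optima are computed by brute force over permutations, which is exact for equal-size clouds.

```python
import itertools

import numpy as np


def ot_value(C, maximize=False):
    """Optimum of sum_ij T_ij C_ij over uniform couplings of two size-n clouds; attained at a
    rescaled permutation (Appendix B.3), so brute force is exact for tiny n."""
    n = C.shape[0]
    vals = [C[np.arange(n), list(p)].sum() / n for p in itertools.permutations(range(n))]
    return max(vals) if maximize else min(vals)


def w_dot(X, Y):  # similarity kernel: max over couplings of sum T_ij <x_i, y_j>
    return ot_value(X @ Y.T, maximize=True)


def w_l2(X, Y):  # discrepancy kernel: min over couplings of sum T_ij ||x_i - y_j||^2
    return ot_value(((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1))


def agg(X, Y):  # aggregation kernel <sum_i x_i, sum_j y_j>
    return float(X.sum(axis=0) @ Y.sum(axis=0))


def gram(kernel, clouds):
    return np.array([[kernel(A, B) for B in clouds] for A in clouds])


# Antipodal pairs {u, -u} with u at angles 0, 45, 90, 135 degrees
angles = np.arange(4) * np.pi / 4
pairs = [np.array([[np.cos(t), np.sin(t)], [-np.cos(t), -np.sin(t)]]) for t in angles]
K_dot = gram(w_dot, pairs)
print(np.linalg.eigvalsh(K_dot).min())  # approx. 1 - sqrt(2) < 0: not p.d.
c = np.array([1.0, -1.0, 1.0, -1.0])     # zero-sum coefficients
print(c @ gram(w_l2, pairs) @ c)         # approx. 8 (sqrt(2) - 1) > 0: not c.n.d.

# By contrast, the aggregation kernel is p.d. and bounded by n^2 times the similarity kernel
rng = np.random.default_rng(6)
n = 3
clouds = [rng.normal(size=(n, 2)) for _ in range(6)]
print(np.linalg.eigvalsh(gram(agg, clouds)).min() >= -1e-9)                                  # True
print(all(agg(A, B) <= n**2 * w_dot(A, B) + 1e-9 for A in clouds for B in clouds))  # True
```

The bound in the last line holds because the uniform coupling $T_{ij} = 1/n^2$ is feasible, so $\sum_{i,j} \langle x_i, y_j \rangle$ cannot exceed $n^2$ times the maximum over couplings.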
5.1 Experimental Setup
We test our model on 4 benchmark molecular property prediction datasets (yang2019analyzing), including both regression (ESOL, Lipophilicity) and classification (BACE, BBBP) tasks. These datasets cover a variety of complex chemical properties (e.g. ESOL: water solubility; LIPO: octanol/water distribution coefficient; BACE: inhibition of human β-secretase; BBBP: blood-brain barrier penetration). We show that our models improve over state-of-the-art baselines.
GNN is the state-of-the-art graph neural network that we use as our primary baseline, as well as the underlying graph model for our prototype models. Its architecture is described in Section 2.1.
The Wasserstein prototype model treats point clouds as point sets and uses the Wasserstein distances to each free point cloud (with either the L2 distance or the (minus) dot product as cost function) as the molecular embedding.
The Euclidean prototype model is a special case of the Wasserstein prototype model in which the point clouds have a single point. Instead of using Wasserstein distances, we simply compute Euclidean distances between the aggregated graph embedding and these single points. Here, we omit dot-product distances, as that model is mathematically equivalent to the GNN model.
We use the POT library (flamary2017pot) to compute Wasserstein distances with the network simplex algorithm (ot.emd), which we find empirically to be faster than the Sinkhorn algorithm because of the small size of the graphs in our datasets. We define the cost matrix by taking the pairwise L2 or negative dot-product distances. As mentioned in Section 3.3, we fix the transport plan and only backpropagate through the cost matrix, for computational efficiency. Additionally, since the sum aggregation operator naturally accounts for the sizes of input graphs, we multiply the OT distance between two point clouds by their respective sizes. To prevent point clouds from collapsing, we employ the contrastive regularizer defined in Section 3.2. More details about the experimental setup are given in Appendix D.1. We also tried extending our prototype models to Gromov-Wasserstein geometry; however, these models proved much harder to optimize in practice.
| Model | ESOL (RMSE) | Lipo (RMSE) | BACE (AUC) | BBBP (AUC) |
| | .635 ± .027 | .646 ± .041 | .865 ± .013 | .915 ± .010 |
| | .611 ± .034 | .580 ± .016 | .865 ± .010 | .918 ± .009 |
| (no reg.) | .608 ± .029 | .637 ± .018 | .867 ± .014 | .919 ± .009 |
| | .594 ± .031 | .629 ± .015 | .871 ± .014 | .919 ± .009 |
| (no reg.) | .616 ± .028 | .615 ± .025 | .870 ± .012 | .920 ± .010 |
| | .605 ± .029 | .604 ± .014 | .873 ± .015 | .920 ± .010 |
5.2 Experimental Results
5.2.1 Regression and Classification
The results on the property prediction datasets are shown in Table 1. We find that the prototype models outperform the GNN on all 4 property prediction tasks, showing that this model paradigm can be more powerful than conventional GNN models. Moreover, the prototype models using the Wasserstein distance achieve better performance on 3 out of 4 datasets than the prototype model using only Euclidean distances. This supports our hypothesis that the Wasserstein distance confers greater discriminative power than traditional aggregation methods (summation).
5.2.2 Noise Contrastive Regularizer
Without any constraints, the Wasserstein prototype model will often collapse the set of points in a point cloud into a single point. As mentioned in Section 3.2, we use a contrastive regularizer to force the model to meaningfully distribute point clouds in the embedding space. We show 2D embeddings in Fig. 3, illustrating that without contrastive regularization, prototype point clouds are often displaced close to their mean, while regularization forces them to nicely scatter.
5.2.3 Learned Embedding Space: Qualitative and Quantitative Results
To further support our claim that the Wasserstein distance provides more powerful representations, we also examine the embedding spaces of the GNN baseline and of our Wasserstein model. Using the best-performing models, we compute, for each pair of test data points of the ESOL dataset, the difference between their embedding vectors and the difference between their labels. We then compute two correlation measures between these pairwise differences: Spearman's rank correlation coefficient ($\rho$) and Pearson's correlation coefficient ($r$). This procedure is reminiscent of evaluation tasks for word embeddings w.r.t. how semantic similarity in embedding space correlates with human labels (luong2013better).
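A sketch of this evaluation under our own assumptions about details the text leaves open: Euclidean distance between embeddings, absolute difference between labels, all unordered pairs $i < j$, and no tied values (a library routine such as `scipy.stats.spearmanr` would handle ties properly). The function names and the synthetic data are hypothetical.

```python
import numpy as np


def upper_pairs(M):
    """Entries M[i, j] with i < j of a square pairwise matrix, flattened."""
    i, j = np.triu_indices(M.shape[0], k=1)
    return M[i, j]


def ranks(a):
    """Ranks 0..n-1 of a 1-D array (assumes no ties)."""
    r = np.empty(len(a))
    r[np.argsort(a)] = np.arange(len(a))
    return r


def embedding_label_correlations(emb, labels):
    """Spearman rho and Pearson r between pairwise embedding distances ||e_i - e_j||
    and pairwise label differences |y_i - y_j|."""
    d_emb = upper_pairs(np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1))
    d_lab = upper_pairs(np.abs(labels[:, None] - labels[None, :]))
    pearson = np.corrcoef(d_emb, d_lab)[0, 1]
    spearman = np.corrcoef(ranks(d_emb), ranks(d_lab))[0, 1]
    return spearman, pearson


rng = np.random.default_rng(10)
labels = rng.normal(size=20)                         # synthetic property values
emb = np.c_[labels, 0.3 * rng.normal(size=20)]      # embeddings partly aligned with the labels
rho, r = embedding_label_correlations(emb, labels)
print(round(rho, 3), round(r, 3))
```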
Our Wasserstein prototype model achieves better $\rho$ and $r$ scores than the GNN model (Table 5), indicating that it constructs more meaningful embeddings with respect to the label distribution. Figure 5 plots the pairwise scores for the GNN model (left) and our model (right): our model, trained to optimize distances in the embedding space, produces more meaningful representations with respect to the label of interest.
Moreover, as can be seen in Figure 6, our model also provides more robust molecular embeddings than the baseline, in the following sense: a small perturbation of a molecular embedding corresponds to a small change in the predicted property value, a desirable phenomenon that rarely holds for the baseline GNN model. Our Wasserstein prototype models yield smoother heatmaps, which is desirable for molecular optimization in the latent space via gradient methods.
Figure 6: Projections of molecular embeddings (before the last linear layer) w.r.t. their associated predicted labels. Heat colors are interpolations based only on the test molecules from each dataset. Comparing (a) vs (b) and (c) vs (d), we can observe a smoother space for our model than for the GNN baseline, as explained in the main text.
6 Related Work
Graph Neural Networks were introduced by Gori et al. gori2005new and Scarselli et al. scarselli2008graph as a form of recurrent neural network. Graph convolutional networks (GCN) appeared later in various forms. Duvenaud et al. duvenaud2015convolutional and Atwood et al. atwood2016diffusion proposed propagation rules inspired by convolution and diffusion, although these methods do not scale to graphs with either large degree distributions or large node cardinality, respectively. Niepert et al. niepert2016learning defined a GCN as a 1D convolution on a chosen node ordering. Kearnes et al. kearnes2016molecular also used graph convolutions with great success to generate high-quality molecular fingerprints. Efficient spectral methods were also proposed (bruna2013spectral; NIPS2016_6081). Kipf & Welling kipf2016semi simplified their propagation rule, also motivated by spectral graph theory (hammond2011wavelets), achieving impressive empirical results. Most of these architectures were later unified into the message passing neural network framework by Gilmer et al. gilmer2017neural, who applied it to molecular property prediction. A directed variant of message passing was introduced by Dai et al. dai2016discriminative and later used by ChemProp (yang2019analyzing) to improve the state of the art in molecular property prediction on a wide variety of datasets. Another notable application is recommender systems (ying2018graph). Ying et al. ying2018hierarchical proposed DiffPool, which performs pooling for GNNs in a hierarchical fashion. Inspired by DeepSets (zaheer2017deep), Xu et al. xu2018powerful suggest both a simplification and a generalization of certain GNN architectures, which should theoretically be powerful enough to discriminate between any different local neighborhoods, provided that hidden dimensions grow as much as the input size. Other recent approaches propose modifying the sum aggregation of node embeddings in the GCN architecture in order to preserve more information (kondor2018covariant; hongbin2020geomgcn).
In particular, Hongbin et al. hongbin2020geomgcn propose to preserve more semantic information by performing a bi-level aggregation that depends on the local geometry of the neighborhood of the given node in the graph. Other recent geometry-inspired GNNs include adaptations to embeddings lying in hyperbolic spaces (liu2019hyperbolic; chami2019hyperbolic) or in spaces of constant sectional curvature (bachmann2019constant).
We propose an enhancement of GNN architectures that replaces the sum aggregation of node embeddings with Optimal Transport geometry. We introduce an efficient regularizer that prevents the enhanced model from collapsing back to standard aggregation. Empirically, our models show strong performance on different molecular property prediction tasks. The induced geometry of their latent representations also exhibits a stronger correlation with target labels.
We thank Louis Abraham for a counter example of positive-definiteness of the Wasserstein Kernel, Andreas Bloch for help with figures and Wengong Jin & Rachel Wu for detailed comments and feedback.
Appendix A Further Details on Contrastive Regularization
One may speculate that it was locally easier for the model to extract valuable information by behaving like the Euclidean component, which prevented it from exploring other regions of the optimization landscape. To better understand this situation, consider the scenario in which a subset of points in a free point cloud "collapses", i.e. the points become close to each other (see Figure 3), thus sharing similar distances to all the node embeddings of real input graphs. The submatrix of the optimal transport matrix corresponding to these collapsed points can be equally replaced by any other submatrix with the same marginals (i.e. the same two vectors obtained by summing rows or columns), meaning that the optimal transport matrix is not discriminative. In general, we want to avoid any two rows or columns of the Wasserstein cost matrix being proportional. An additional problem of point collapsing is that it is hard to escape with gradient-based learning methods: the gradients of the collapsed points become and remain identical, so nothing encourages them to "separate" in the future.
A.2 On the Choice of Contrastive Samples
Our experiments were conducted with ten negative samples for each correct transport plan. Five of them were obtained by initializing a matrix with i.i.d. uniform random entries and performing around five Sinkhorn iterations [cuturi2013sinkhorn] to make the matrix (approximately) satisfy the marginal constraints. The other five were obtained by randomly permuting the columns of the correct transport plan. The latter choice has the desirable effect of penalizing the collapse of the points of a free point cloud onto the same point. Indeed, the rows of $T^i$ index points in $H$, while its columns index points in $Q_i$.
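A sketch of this construction for one optimal plan. The interval of the random entries is not given above, so uniform draws on $[0, 1)$ are an assumption, as are the function names. With only a few Sinkhorn steps the marginal constraints hold approximately: after the final column rescaling the column sums are exact, while the row sums are only close.

```python
import numpy as np


def sinkhorn_project(M, n_iters=5):
    """Alternately rescale rows and columns of a positive matrix toward uniform marginals
    1/n (rows) and 1/m (columns); a few iterations give an approximate coupling."""
    n, m = M.shape
    T = M.copy()
    for _ in range(n_iters):
        T *= (1.0 / n) / T.sum(axis=1, keepdims=True)  # match row marginals
        T *= (1.0 / m) / T.sum(axis=0, keepdims=True)  # match column marginals
    return T


def negative_samples(T_opt, rng, n_random=5, n_perm=5):
    """Contrastive transports for one optimal plan: random matrices pushed toward the set of
    couplings by a few Sinkhorn steps, plus random column permutations of T_opt."""
    n, m = T_opt.shape
    negs = [sinkhorn_project(rng.uniform(size=(n, m))) for _ in range(n_random)]
    negs += [T_opt[:, rng.permutation(m)] for _ in range(n_perm)]
    return negs


rng = np.random.default_rng(7)
T_opt = np.eye(4) / 4  # stand-in optimal plan between 4 node embeddings and 4 prototype points
negs = negative_samples(T_opt, rng)
print(len(negs))  # 10
```

Permuting columns reassigns node mass to different prototype points; if some of those points had collapsed, the permuted plans would cost the same as $T^i$, which is exactly what the regularizer penalizes.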
Appendix B Theoretical Results
B.1 Proof of Theorem 1
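Let us first justify why $k_{\mathrm{agg}}$ is not universal. Consider a function $f \in C(K)$ such that there exist $X, Y \in K$ satisfying both $f(X) \neq f(Y)$ and $\sum_i x_i = \sum_i y_i$. Clearly, any function of the form $\sum_j \alpha_j\, \sigma(k_{\mathrm{agg}}(\cdot, \theta_j) + \beta_j)$ would take equal values on $X$ and $Y$, and hence would not approximate $f$ arbitrarily well.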
To justify that $\mathcal{W}_{L2}$ is universal, we take inspiration from the proof of universality of neural networks [cybenko1989approximation].
Denote by $\mathcal{M}(K)$ the space of finite, signed regular Borel measures on $K$.
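We say that $\sigma$ is discriminatory w.r.t. a kernel $k$ if, for a measure $\mu \in \mathcal{M}(K)$,
$$\int_K \sigma\big(k(X, \theta) + \beta\big)\, d\mu(X) = 0$$
for all $\theta \in \mathcal{X}_n$ and $\beta \in \mathbb{R}$ implies that $\mu = 0$.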
We start by recalling a lemma from the original paper on the universality of neural networks by Cybenko [cybenko1989approximation].
Lemma 1. If $\sigma$ is discriminatory w.r.t. $k$, then $k$ is universal.
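Proof: Let $S$ be the subset of functions of the form $\sum_{j=1}^m \alpha_j\, \sigma(k(\cdot, \theta_j) + \beta_j)$ for any $m \in \mathbb{N}$, $\alpha_j, \beta_j \in \mathbb{R}$ and $\theta_j \in \mathcal{X}_n$, and denote by $\bar{S}$ the closure (w.r.t. the topology defined by the sup norm $\|\cdot\|_{\infty, K}$) of $S$ in $C(K)$. Assume by contradiction that $\bar{S} \neq C(K)$. By the Hahn-Banach theorem, there exists a bounded linear functional $L$ on $C(K)$ such that $L(g) = 0$ for all $g \in \bar{S}$, and such that there exists $g_0 \in C(K)$ s.t. $L(g_0) \neq 0$. By the Riesz representation theorem, this bounded linear functional is of the form:
$$L(g) = \int_K g(X)\, d\mu(X)$$
for all $g \in C(K)$, for some $\mu \in \mathcal{M}(K)$. Since $\sigma(k(\cdot, \theta) + \beta)$ is in $\bar{S}$, we have
$$\int_K \sigma\big(k(X, \theta) + \beta\big)\, d\mu(X) = 0$$
for all $\theta \in \mathcal{X}_n$ and $\beta \in \mathbb{R}$. Since $\sigma$ is discriminatory w.r.t. $k$, this implies that $\mu = 0$ and hence $L \equiv 0$, which contradicts $L(g_0) \neq 0$. Hence $\bar{S} = C(K)$, i.e. $S$ is dense in $C(K)$ and $k$ is universal.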
Now let us look at the part of the proof that is new.
Lemma 2. $\sigma$ is discriminatory w.r.t. $\mathcal{W}_{L2}$.
Proof: Note that for any , when we have that goes to 1 if , to 0 if and to if .
Denote by and for and for . By the Lebesgue Bounded Convergence Theorem we have:
Since this is true for any , it implies that . From (because for ), we also have . Hence is zero on all balls defined by the metric .
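From the Hahn decomposition theorem, there exist disjoint Borel sets $S^+, S^-$ such that $K = S^+ \cup S^-$ and $\mu = \mu^+ - \mu^-$, where $\mu^\pm(E) := \pm\mu(E \cap S^\pm)$ for any Borel set $E$, with $\mu^+, \mu^-$ being positive measures. Since $\mu^+$ and $\mu^-$ coincide on all balls of a finite-dimensional metric space, they coincide everywhere [hoffmann1976measures], and hence $\mu = 0$.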
Combining the previous lemmas with concludes the proof.
B.2 Proof of Theorem 2
We build a counter example. We consider point clouds of size and dimension . First, define for . Then take , , and . On the one hand, if , then all vectors in the two point clouds are orthogonal, which can only happen for . On the other hand, if , then either or . This yields the following Gram matrix
whose determinant is
, which implies that this matrix has a negative eigenvalue.
This follows from Proposition 1. Choosing $k = \mathcal{W}_{L2}$ and $x_0$ to be the trivial point cloud made of $n$ times the zero vector yields $\tilde{k} = \mathcal{W}_{\mathrm{dot}}$. Since $\mathcal{W}_{\mathrm{dot}}$ is not positive definite by the previous point of the theorem, $\mathcal{W}_{L2}$ is not conditionally negative definite by Proposition 1.
B.3 Shape of the optimal transport plan for point clouds of the same size
The below result describes the shape of optimal transport plans for point clouds of same size. For the sake of curiosity, we also illustrate in Figure 2 the optimal transport for point clouds of different sizes. We note that non-square transports seem to remain relatively sparse as well. This is in line with our empirical observations.
For $X, Y \in \mathcal{X}_n$, there exists a rescaled permutation matrix $\frac{1}{n}(\delta_{i\sigma(j)})_{1 \le i, j \le n}$ which is an optimal transport plan, i.e.
$$\mathcal{W}_c(X, Y) = \frac{1}{n} \sum_{j=1}^n c(x_{\sigma(j)}, y_j).$$
It is well known from Birkhoff's theorem that every square doubly stochastic matrix is a convex combination of permutation matrices. Since the Wasserstein cost for a given transport is a linear function, it is both convex and concave, and hence it is maximized/minimized over the convex compact set of couplings at one of its extremal points, namely one of the rescaled permutations, yielding the desired result. ∎
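The linearity argument can be illustrated numerically: build a coupling as a random convex combination of rescaled permutations and check that its cost is the same convex combination of the vertex costs, hence never below the best permutation. The random cost matrix and the Dirichlet weights are our own illustrative choices.

```python
import itertools

import numpy as np

rng = np.random.default_rng(8)
n = 4
C = rng.normal(size=(n, n))  # arbitrary cost matrix c(x_i, y_j)
# Rescaled permutation matrices: the vertices of the set of couplings with marginals 1/n
perms = [np.eye(n)[list(p)] / n for p in itertools.permutations(range(n))]
perm_costs = np.array([(P * C).sum() for P in perms])

# A coupling built as a random convex combination of rescaled permutations
weights = rng.dirichlet(np.ones(len(perms)))
T = sum(w * P for w, P in zip(weights, perms))
cost_T = (T * C).sum()

# Linearity: cost_T is the same convex combination of the vertex costs
print(np.isclose(cost_T, weights @ perm_costs))        # True
print(perm_costs.min() <= cost_T <= perm_costs.max())  # True
```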
Appendix C Complexity
Computing the Wasserstein optimal transport plan between two point clouds consists in the minimization of a linear function under linear constraints. It can either be performed exactly by using network simplex methods or interior point methods as done by [pele2009fast] in time , or approximately up to via the Sinkhorn algorithm [cuturi2013sinkhorn] in time . More recently, [dvurechensky2018computational] proposed an algorithm solving OT up to with time complexity via a primal-dual method inspired from accelerated gradient descent.
In our experiments, we used the Python Optimal Transport (POT) library [flamary2017pot]. We noticed empirically that the EMD solver yielded faster and more accurate solutions than Sinkhorn for our datasets, because the graphs and point clouds were small enough ( elements). However, Sinkhorn may take the lead for larger graphs.
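For reference, a minimal sketch of the Sinkhorn alternative, compared with the exact value on a tiny equal-size instance (computed by brute force over permutations, which is exact by Appendix B.3). This plain-domain version, with a fixed iteration count and regularization strength `eps = 0.1`, is our own simplification; production code would typically rely on a library solver such as POT's `ot.sinkhorn`. When the marginals are met, the entropic plan's transport cost lies slightly above the exact optimum, by at most about `eps * log(n * m)`.

```python
import itertools

import numpy as np


def sinkhorn(C, a, b, eps=0.1, n_iters=2000):
    """Entropic OT plan T = diag(u) K diag(v) with K = exp(-C / eps), obtained by alternately
    matching the row marginals a and column marginals b (plain, non-log-domain iterations)."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


def exact_equal_size(C):
    """Exact OT value for equal-size uniform clouds, by brute force over permutations."""
    n = C.shape[0]
    return min(C[np.arange(n), list(p)].sum() for p in itertools.permutations(range(n))) / n


rng = np.random.default_rng(9)
X, Y = 0.5 * rng.normal(size=(5, 2)), 0.5 * rng.normal(size=(5, 2))
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
a = b = np.full(5, 1 / 5)

T_eps = sinkhorn(C, a, b)
sinkhorn_cost, exact_cost = (T_eps * C).sum(), exact_equal_size(C)
print(sinkhorn_cost, exact_cost)
```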
C.2 General remarks
Significant speed-ups could potentially be obtained by rewriting the POT library to solve OT in batches on GPUs. In our experiments, we ran all jobs on CPUs. We observed a slow-down by a factor of 4 when moving from purely Euclidean to purely Wasserstein models.
Appendix D Further Experimental Details
D.1 Setup of Experiments
Each dataset is split randomly 5 times into 80%:10%:10% train, validation and test sets. For each of the 5 splits, we run each model 5 times to reduce the variance in particular data splits (resulting in each model being run 25 times). Hyperparameters are searched separately for each data split, so that the model performance is based on a completely unseen test set and there is no data leakage across data splits; we then report the average performance over all splits. The models are trained for 150 epochs with a batch size of 16, with early stopping if the validation error has not improved in 50 epochs. We train the models using the Adam optimizer with a learning rate of 5e-4. For the prototype models, we use different learning rates for the GNN and the point clouds (5e-4 and 5e-3 respectively), because we empirically find that the gradients are much smaller for the point clouds. The molecular datasets used in these experiments are small (from 1k to 4k data points), so this is a fair method of comparison, and it is indeed what is done in other works on molecular property prediction [yang2019analyzing].
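The evaluation protocol can be sketched as follows: five independent random 80%/10%/10% splits, each used for five runs, i.e. 25 runs per model. The helper name, the seeding scheme and the dataset size in the example are hypothetical, and hyperparameter search and training are left out.

```python
import numpy as np


def random_splits(num_items, n_splits=5, fractions=(0.8, 0.1, 0.1), seed=0):
    """Yield (train, val, test) index arrays for n_splits independent random splits."""
    rng = np.random.default_rng(seed)
    n_train = int(fractions[0] * num_items)
    n_val = int(fractions[1] * num_items)
    for _ in range(n_splits):
        perm = rng.permutation(num_items)
        yield perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


splits = list(random_splits(1000))  # hypothetical dataset of 1000 molecules
runs = [(split_id, run_seed) for split_id in range(len(splits)) for run_seed in range(5)]
print([len(part) for part in splits[0]], len(runs))  # [800, 100, 100] 25
```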
| Parameter Name | Search Values | Description |
| nepochs | | Number of epochs trained |
| batchsize | | Size of each batch |
| lr | 5e-4 | Overall learning rate for model |
| lrpc | 5e-3 | Learning rate for the free parameter point clouds |
| nlayers | | Number of layers in the GNN |
| nhidden | | Size of hidden dimension in GNN |
| nffnhidden | | Size of the output feed forward layer |
| dropoutgnn | | Dropout probability for GNN |
| dropoutfnn | | Dropout probability for feed forward layer |
| npc | | Number of free parameter point clouds in prototype models |
| pcsize | | Number of points in free parameter point clouds |
| pchidden | | Size of hidden dimension in point clouds |
| nccoef | | Coefficient for noise contrastive regularization |