Hierarchical Graph Representation Learning with Differentiable Pooling

06/22/2018 ∙ by Rex Ying, et al. ∙ University of Southern California, TU Dortmund, Stanford University

Recently, graph neural networks (GNNs) have revolutionized the field of graph representation learning through effectively learned node embeddings, and have achieved state-of-the-art results in tasks such as node classification and link prediction. However, current GNN methods are inherently flat and do not learn hierarchical representations of graphs, a limitation that is especially problematic for the task of graph classification, where the goal is to predict the label associated with an entire graph. Here we propose DiffPool, a differentiable graph pooling module that can generate hierarchical representations of graphs and can be combined with various graph neural network architectures in an end-to-end fashion. DiffPool learns a differentiable soft cluster assignment for nodes at each layer of a deep GNN, mapping nodes to a set of clusters, which then form the coarsened input for the next GNN layer. Our experimental results show that combining existing GNN methods with DiffPool yields an average improvement of 5-10% accuracy on graph classification benchmarks, compared to all existing pooling approaches, achieving a new state-of-the-art on four out of five benchmark data sets.




1 Introduction

In recent years there has been a surge of interest in developing graph neural networks (GNNs), general deep learning architectures that can operate over graph-structured data, such as social network data hamilton2017inductive ; kipf2017semi ; Vel+2018 or graph-based representations of molecules dai2016discriminative ; Duv+2015 ; Gil+2017 . The general approach with GNNs is to view the underlying graph as a computation graph and learn neural network primitives that generate individual node embeddings by passing, transforming, and aggregating node feature information across the graph Gil+2017 ; hamilton2017inductive . The generated node embeddings can then be used as input to any differentiable prediction layer, e.g., for node classification hamilton2017inductive or link prediction Sch+2017 , and the whole model can be trained in an end-to-end fashion.

However, a major limitation of current GNN architectures is that they are inherently flat as they only propagate information across the edges of the graph and are unable to infer and aggregate the information in a hierarchical way. For example, in order to successfully encode the graph structure of organic molecules, one would ideally want to encode the local molecular structure (e.g., individual atoms and their direct bonds) as well as the coarse-grained structure of the molecular graph (e.g., groups of atoms and bonds representing functional units in a molecule). This lack of hierarchical structure is especially problematic for the task of graph classification, where the goal is to predict the label associated with an entire graph. When applying GNNs to graph classification, the standard approach is to generate embeddings for all the nodes in the graph and then to globally pool all these node embeddings together, e.g., using a simple summation or a neural network that operates over sets dai2016discriminative ; Duv+2015 ; Gil+2017 ; Li+2016 . This global pooling approach ignores any hierarchical structure that might be present in the graph, and it prevents researchers from building effective GNN models for predictive tasks over entire graphs.

Here we propose DiffPool, a differentiable graph pooling module that can be adapted to various graph neural network architectures in a hierarchical and end-to-end fashion (Figure 1). DiffPool allows for developing deeper GNN models that can learn to operate on hierarchical representations of a graph. We develop a graph analogue of the spatial pooling operation in CNNs krizhevsky2012imagenet , which allows for deep CNN architectures to iteratively operate on coarser and coarser representations of an image. The challenge in the GNN setting, compared to standard CNNs, is that graphs contain no natural notion of spatial locality, i.e., one cannot simply pool together all nodes in a “patch” on a graph, because the complex topological structure of graphs precludes any straightforward, deterministic definition of a “patch”. Moreover, unlike image data, graph data sets often contain graphs with varying numbers of nodes and edges, which makes defining a general graph pooling operator even more challenging.

In order to solve the above challenges, we require a model that learns how to cluster together nodes to build a hierarchical multi-layer scaffold on top of the underlying graph. Our approach DiffPool learns a differentiable soft assignment at each layer of a deep GNN, mapping nodes to a set of clusters based on their learned embeddings. In this framework, we generate deep GNNs by “stacking” GNN layers in a hierarchical fashion (Figure 1): the input nodes at the layer $l+1$ GNN module correspond to the clusters learned at the layer $l$ GNN module. Thus, each layer of DiffPool coarsens the input graph more and more, and DiffPool is able to generate a hierarchical representation of any input graph after training. We show that DiffPool can be combined with various GNN approaches, resulting in an average 7% gain in accuracy and a new state of the art on four out of five benchmark graph classification tasks. Finally, we show that DiffPool can learn interpretable hierarchical clusters that correspond to well-defined communities in the input graphs.

Figure 1: High-level illustration of our proposed method DiffPool. At each hierarchical layer, we run a GNN model to obtain embeddings of nodes. We then use these learned embeddings to cluster nodes together and run another GNN layer on this coarsened graph. This whole process is repeated for $L$ layers and we use the final output representation to classify the graph.

2 Related Work

Our work builds upon a rich line of recent research on graph neural networks and graph classification.

General graph neural networks. A wide variety of graph neural network (GNN) models have been proposed in recent years, including methods inspired by convolutional neural networks Bru+2014 ; Def+2015 ; Duv+2015 ; hamilton2017inductive ; kipf2017semi ; Lei+2017 ; niepert2016learning ; Vel+2018 , recurrent neural networks Li+2016 , recursive neural networks bianchini2001 ; Sca+2009 and loopy belief propagation dai2016discriminative . Most of these approaches fit within the framework of “neural message passing” proposed by Gilmer et al. Gil+2017 . In the message passing framework, a GNN is viewed as a message passing algorithm where node representations are iteratively computed from the features of their neighbor nodes using a differentiable aggregation function. Hamilton et al. Ham+2017a provide a conceptual review of recent advancements in this area, and Bronstein et al. bronstein2017geometric outline connections to spectral graph convolutions.

Graph classification with graph neural networks. GNNs have been applied to a wide variety of tasks, including node classification hamilton2017inductive ; kipf2017semi , link prediction kipf2018 , graph classification dai2016discriminative ; Duv+2015 ; zhang2018end , and chemoinformatics Mer+2005 ; Lus+2013 ; Fou+2017 ; Jin+2018 ; Sch+2017 . In the context of graph classification (the task that we study here), a major challenge in applying GNNs is going from node embeddings, which are the output of GNNs, to a representation of the entire graph. Common approaches to this problem include simply summing up or averaging all the node embeddings in a final layer Duv+2015 , introducing a “virtual node” that is connected to all the nodes in the graph Li+2016 , or aggregating the node embeddings using a deep learning architecture that operates over sets Gil+2017 . However, all of these methods have the limitation that they do not learn hierarchical representations (i.e., all the node embeddings are globally pooled together in a single layer), and thus are unable to capture the natural structures of many real-world graphs. Some recent approaches have also proposed applying CNN architectures to the concatenation of all the node embeddings niepert2016learning ; zhang2018end , but this requires specifying (or learning) a canonical ordering over nodes, which is in general very difficult and equivalent to solving graph isomorphism.
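The flat readouts mentioned above can be made concrete with a short NumPy sketch. It assumes the final node embeddings are already available as an n × d array; the function names `sum_readout` and `mean_readout` are hypothetical, and the toy embeddings are made up for illustration.

```python
import numpy as np

def sum_readout(Z: np.ndarray) -> np.ndarray:
    """Flat global pooling: add all node embeddings, (n, d) -> (d,)."""
    return Z.sum(axis=0)

def mean_readout(Z: np.ndarray) -> np.ndarray:
    """Flat global pooling: average all node embeddings, (n, d) -> (d,)."""
    return Z.mean(axis=0)

# Toy final-layer embeddings for a 4-node graph with d = 2.
Z = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0],
              [1.0, 1.0]])
perm = np.array([2, 0, 3, 1])   # an arbitrary reordering of the nodes

g_sum = sum_readout(Z)          # [4., 4.]
g_mean = mean_readout(Z)        # [1., 1.]

# Both readouts are independent of node order...
assert np.allclose(sum_readout(Z[perm]), g_sum)
assert np.allclose(mean_readout(Z[perm]), g_mean)
# ...but every node is collapsed in one step, so any grouping of nodes
# into communities or functional units is invisible to the readout.
```

Any node-order-independent set function could replace the sum or mean here; the point is that the whole graph is summarized in a single layer.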

Lastly, there are some recent works that learn hierarchical graph representations by combining GNNs with deterministic graph clustering algorithms Def+2015 ; simonovsky2017dynamic ; Fey+2018 , following a two-stage approach. However, unlike these previous approaches, we seek to learn the hierarchical structure in an end-to-end fashion, rather than relying on a deterministic graph clustering subroutine.

3 Proposed Method

The key idea of DiffPool is that it enables the construction of deep, multi-layer GNN models by providing a differentiable module to hierarchically pool graph nodes. In this section, we outline the DiffPool module and show how it is applied in a deep GNN architecture.

3.1 Preliminaries

We represent a graph $G$ as $(A, F)$, where $A \in \{0, 1\}^{n \times n}$ is the adjacency matrix, and $F \in \mathbb{R}^{n \times d}$ is the node feature matrix assuming each node has $d$ features. (We do not consider edge features, although one can easily extend the algorithm to support edge features using techniques introduced in simonovsky2017dynamic .) Given a set of labeled graphs $\mathcal{D} = \{(G_1, y_1), (G_2, y_2), \ldots\}$ where $y_i \in \mathcal{Y}$ is the label corresponding to graph $G_i \in \mathcal{G}$, the goal of graph classification is to learn a mapping $f: \mathcal{G} \to \mathcal{Y}$ that maps graphs to the set of labels. The challenge, compared to a standard supervised machine learning setup, is that we need a way to extract useful feature vectors from these input graphs. That is, in order to apply standard machine learning methods for classification, e.g., neural networks, we need a procedure to convert each graph to a finite-dimensional real-valued vector.


Graph neural networks. In this work, we build upon graph neural networks in order to learn useful representations for graph classification in an end-to-end fashion. In particular, we consider GNNs that employ the following general “message-passing” architecture:

$$H^{(k)} = M\left(A, H^{(k-1)}; \theta^{(k)}\right) \tag{1}$$

where $H^{(k)} \in \mathbb{R}^{n \times d}$ are the node embeddings (i.e., “messages”) computed after $k$ steps of the GNN and $M$ is the message propagation function, which depends on the adjacency matrix, trainable parameters $\theta^{(k)}$, and the node embeddings $H^{(k-1)}$ generated from the previous message-passing step. (For notational convenience, we assume that the embedding dimension is $d$ for all $H^{(k)}$; however, in general this restriction is not necessary.) The input node embeddings $H^{(0)}$ at the initial message-passing iteration $(k = 1)$ are initialized using the node features on the graph, $H^{(0)} = F$.

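There are many possible implementations of the propagation function $M$ Gil+2017 ; hamilton2017inductive . For example, one popular variant of GNNs, the Graph Convolutional Network (GCN) of Kipf et al. kipf2017semi , implements $M$ using a combination of linear transformations and ReLU non-linearities:

$$H^{(k)} = M\left(A, H^{(k-1)}; W^{(k)}\right) = \mathrm{ReLU}\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(k-1)} W^{(k)}\right) \tag{2}$$

where $\tilde{A} = A + I$, $\tilde{D}$ is the diagonal degree matrix with $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$, and $W^{(k)} \in \mathbb{R}^{d \times d}$ is a trainable weight matrix. The differentiable pooling model we propose can be applied to any GNN model implementing Equation (1), and is agnostic with regards to the specifics of how $M$ is implemented.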
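As a concrete instance of Equation (2), the following NumPy sketch performs one dense GCN propagation step. It assumes a small, dense, symmetric adjacency matrix and randomly initialized weights; `gcn_step` is a hypothetical helper name, and practical implementations typically use sparse matrices instead.

```python
import numpy as np

def gcn_step(A, H, W):
    """One GCN propagation step (Equation (2)):
    ReLU(D~^{-1/2} A~ D~^{-1/2} H W), where A~ = A + I."""
    A_tilde = A + np.eye(A.shape[0])          # add self-loops
    deg = A_tilde.sum(axis=1)                 # D~_ii = sum_j A~_ij
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Symmetric normalisation via broadcasting instead of diagonal matrices.
    A_norm = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ H @ W, 0.0)    # ReLU non-linearity

# Path graph 0 - 1 - 2 with d = 4 input features.
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
rng = np.random.default_rng(0)
H0 = rng.normal(size=(3, 4))   # H^(0) = F
W1 = rng.normal(size=(4, 4))   # trainable weights (random here)
H1 = gcn_step(A, H0, W1)       # H^(1), shape (3, 4)
```

Multiplying by the degree vectors through broadcasting is equivalent to the two diagonal-matrix products in Equation (2) but avoids building $\tilde{D}^{-\frac{1}{2}}$ explicitly.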

A full GNN module will run $K$ iterations of Equation (1) to generate the final output node embeddings $Z = H^{(K)} \in \mathbb{R}^{n \times d}$, where $K$ is usually in the range 2-6. For simplicity, in the following sections we will abstract away the internal structure of the GNNs and use $Z = \mathrm{GNN}(A, X)$ to denote an arbitrary GNN module implementing $K$ iterations of message passing according to some adjacency matrix $A$ and initial input node features $X$.

Stacking GNNs and pooling layers. GNNs implementing Equation (1) are inherently flat, as they only propagate information across edges of a graph. The goal of this work is to define a general, end-to-end differentiable strategy that allows one to stack multiple GNN modules in a hierarchical fashion. Formally, given $Z = \mathrm{GNN}(A, X)$, the output of a GNN module, and a graph adjacency matrix $A \in \mathbb{R}^{n \times n}$, we seek to define a strategy to output a new coarsened graph containing $m < n$ nodes, with weighted adjacency matrix $A' \in \mathbb{R}^{m \times m}$ and node embeddings $Z' \in \mathbb{R}^{m \times d}$. This new coarsened graph can then be used as input to another GNN layer, and this whole process can be repeated $L$ times, generating a model with $L$ GNN layers that operate on a series of coarser and coarser versions of the input graph (Figure 1). Thus, our goal is to learn how to cluster or pool together nodes using the output of a GNN, so that we can use this coarsened graph as input to another GNN layer. What makes designing such a pooling layer for GNNs especially challenging, compared to the usual graph coarsening task, is that our goal is not simply to cluster the nodes in one graph, but to provide a general recipe to hierarchically pool nodes across a broad set of input graphs. That is, we need our model to learn a pooling strategy that will generalize across graphs with different numbers of nodes and edges, and that can adapt to the various graph structures during inference.

3.2 Differentiable Pooling via Learned Assignments

Our proposed approach, DiffPool, addresses the above challenges by learning a cluster assignment matrix over the nodes using the output of a GNN model. The key intuition is that we stack GNN modules and learn to assign nodes to clusters at layer $l$ in an end-to-end fashion, using embeddings generated from a GNN at layer $l-1$. Thus, we are using GNNs both to extract node embeddings that are useful for graph classification and to extract node embeddings that are useful for hierarchical pooling. Using this construction, the GNNs in DiffPool learn to encode a general pooling strategy that is useful for a large set of training graphs. We first describe how the DiffPool module pools nodes at each layer given an assignment matrix; following this, we discuss how we generate the assignment matrix using a GNN architecture.

Pooling with an assignment matrix. We denote the learned cluster assignment matrix at layer $l$ as $S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}$. Each row of $S^{(l)}$ corresponds to one of the $n_l$ nodes (or clusters) at layer $l$, and each column of $S^{(l)}$ corresponds to one of the $n_{l+1}$ clusters at the next layer $l+1$. Intuitively, $S^{(l)}$ provides a soft assignment of each node at layer $l$ to a cluster in the next coarsened layer $l+1$.

Suppose that $S^{(l)}$ has already been computed, i.e., that we have computed the assignment matrix at the $l$-th layer of our model. We denote the input adjacency matrix at this layer as $A^{(l)}$ and denote the input node embedding matrix at this layer as $Z^{(l)}$. Given these inputs, the DiffPool layer $(A^{(l+1)}, X^{(l+1)}) = \mathrm{DiffPool}(A^{(l)}, Z^{(l)})$ coarsens the input graph, generating a new coarsened adjacency matrix $A^{(l+1)}$ and a new matrix of embeddings $X^{(l+1)}$ for each of the nodes/clusters in this coarsened graph. In particular, we apply the two following equations:

$$X^{(l+1)} = {S^{(l)}}^{T} Z^{(l)} \in \mathbb{R}^{n_{l+1} \times d} \tag{3}$$

$$A^{(l+1)} = {S^{(l)}}^{T} A^{(l)} S^{(l)} \in \mathbb{R}^{n_{l+1} \times n_{l+1}} \tag{4}$$

Equation (3) takes the node embeddings $Z^{(l)}$ and aggregates these embeddings according to the cluster assignments $S^{(l)}$, generating embeddings for each of the $n_{l+1}$ clusters. Similarly, Equation (4) takes the adjacency matrix $A^{(l)}$ and generates a coarsened adjacency matrix denoting the connectivity strength between each pair of clusters.
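Equations (3) and (4) amount to two matrix products. The NumPy sketch below applies them to a toy graph of two triangles joined by a bridge edge; for readability it uses a hard (one-hot) assignment, whereas DiffPool's learned $S$ is soft. The graph, the embeddings, and the function name `diffpool_coarsen` are illustrative assumptions.

```python
import numpy as np

def diffpool_coarsen(A, Z, S):
    """Equations (3) and (4): pool embeddings and adjacency with S.

    A: (n_l, n_l) adjacency, Z: (n_l, d) node embeddings,
    S: (n_l, n_{l+1}) assignment matrix whose rows sum to 1.
    """
    X_next = S.T @ Z          # (n_{l+1}, d) cluster embeddings
    A_next = S.T @ A @ S      # (n_{l+1}, n_{l+1}) cluster connectivity
    return A_next, X_next

# Two triangles {0,1,2} and {3,4,5} joined by the single edge 2-3.
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0
Z = np.arange(12, dtype=float).reshape(6, 2)

# Hard assignment: each triangle becomes one cluster node.
S = np.zeros((6, 2))
S[:3, 0] = 1.0
S[3:, 1] = 1.0

A_next, X_next = diffpool_coarsen(A, Z, S)
# A_next == [[6, 1], [1, 6]], X_next == [[6, 9], [24, 27]]
```

The diagonal entries equal 6 because each of the three intra-triangle edges is counted in both directions of the symmetric $A$; the off-diagonal 1 is the single bridge edge. With a soft $S$, the same products spread each node's contribution over several clusters.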

Through Equations (3) and (4), the DiffPool layer coarsens the graph: the next layer adjacency matrix $A^{(l+1)}$ represents a coarsened graph with $n_{l+1}$ nodes or cluster nodes, where each individual cluster node in the new coarsened graph corresponds to a cluster of nodes in the graph at layer $l$. Note that $A^{(l+1)}$ is a real matrix and represents a fully connected edge-weighted graph; each entry $A^{(l+1)}_{ij}$ can be viewed as the connectivity strength between cluster $i$ and cluster $j$. Similarly, the $i$-th row of $X^{(l+1)}$ corresponds to the embedding of cluster $i$. Together, the coarsened adjacency matrix $A^{(l+1)}$ and cluster embeddings $X^{(l+1)}$ can be used as input to another GNN layer, a process which we describe in detail below.

Learning the assignment matrix. In the following we describe the architecture of DiffPool, i.e., how DiffPool generates the assignment matrix $S^{(l)}$ and embedding matrices $Z^{(l)}$ that are used in Equations (3) and (4). We generate these two matrices using two separate GNNs that are both applied to the input cluster node features $X^{(l)}$ and coarsened adjacency matrix $A^{(l)}$. The embedding GNN at layer $l$ is a standard GNN module applied to these inputs:

$$Z^{(l)} = \mathrm{GNN}_{l,\text{embed}}\left(A^{(l)}, X^{(l)}\right) \tag{5}$$

i.e., we take the adjacency matrix between the cluster nodes at layer $l$ (from Equation 4) and the pooled features for the clusters (from Equation 3) and pass these matrices through a standard GNN to get new embeddings $Z^{(l)}$ for the cluster nodes. In contrast, the pooling GNN at layer $l$ uses the input cluster features $X^{(l)}$ and cluster adjacency matrix $A^{(l)}$ to generate an assignment matrix:

$$S^{(l)} = \mathrm{softmax}\left(\mathrm{GNN}_{l,\text{pool}}\left(A^{(l)}, X^{(l)}\right)\right) \tag{6}$$

where the softmax function is applied in a row-wise fashion. The output dimension of $\mathrm{GNN}_{l,\text{pool}}$ corresponds to a pre-defined maximum number of clusters in layer $l$, and is a hyperparameter of the model.
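A minimal NumPy sketch of Equation (6) follows. It assumes, purely for illustration, that the pooling GNN is a single GCN-style layer without the final ReLU (so that cluster logits may be negative); the helper names and the toy graph are hypothetical, and any GNN could take the place of `gcn_logits`.

```python
import numpy as np

def row_softmax(logits):
    """Row-wise softmax, stabilised by subtracting each row's maximum."""
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def gcn_logits(A, X, W):
    """One GCN-style layer without the ReLU (hypothetical pooling GNN),
    so that the cluster logits can also be negative."""
    A_tilde = A + np.eye(A.shape[0])
    d = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return (d[:, None] * A_tilde * d[None, :]) @ X @ W

def assignment_matrix(A, X, W_pool):
    """Equation (6): S = softmax(GNN_pool(A, X)), applied row-wise.
    W_pool has one output column per cluster (the maximum cluster count)."""
    return row_softmax(gcn_logits(A, X, W_pool))

n, d, max_clusters = 5, 3, 2
A = np.zeros((n, n))
for i, j in [(0, 1), (1, 2), (2, 0), (2, 3), (3, 4)]:
    A[i, j] = A[j, i] = 1.0
rng = np.random.default_rng(1)
X = rng.normal(size=(n, d))
W_pool = rng.normal(size=(d, max_clusters))
S = assignment_matrix(A, X, W_pool)      # shape (5, 2); each row sums to 1
```

Because the softmax is taken over each row, every node distributes a total probability mass of 1 across the candidate clusters, which is what lets Equations (3) and (4) treat $S$ as a soft assignment.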

Note that these two GNNs consume the same input data but have distinct parameterizations and play separate roles: The embedding GNN generates new embeddings for the input nodes at this layer, while the pooling GNN generates a probabilistic assignment of the input nodes to clusters.

In the base case, the inputs to Equations (5) and (6) at layer $l = 0$ are simply the input adjacency matrix $A$ and the node features $F$ for the original graph. At the penultimate layer $L-1$ of a deep GNN model using DiffPool, we set the assignment matrix $S^{(L-1)}$ to be a vector of $1$'s, i.e., all nodes at the final layer $L$ are assigned to a single cluster, generating a final embedding vector corresponding to the entire graph. This final output embedding can then be used as feature input to a differentiable classifier (e.g., a softmax layer), and the entire system can be trained end-to-end using stochastic gradient descent.
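To show how Equations (3) to (6) chain together, here is a minimal two-level forward pass ($L = 2$) in NumPy. It is a sketch under stated assumptions, not the authors' implementation: each GNN is a single hypothetical GCN-style layer, the weights are random rather than trained, there is no classifier or loss, and the last assignment is a vector of ones, so the graph embedding is the sum of the final cluster embeddings.

```python
import numpy as np

def norm_adj(A):
    """D~^{-1/2} (A + I) D~^{-1/2}, as in Equation (2)."""
    A_tilde = A + np.eye(A.shape[0])
    d = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return d[:, None] * A_tilde * d[None, :]

def gnn(A, X, W, relu=True):
    """Hypothetical one-layer GCN standing in for GNN(A, X)."""
    out = norm_adj(A) @ X @ W
    return np.maximum(out, 0.0) if relu else out

def softmax_rows(M):
    e = np.exp(M - M.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def diffpool_forward(A, F, W_embed, W_pool, W_final):
    """Two-level DiffPool sketch: one learned pooling step, then a final
    assignment of all cluster nodes to a single node."""
    Z0 = gnn(A, F, W_embed)                            # Eq. (5), (n, d)
    S0 = softmax_rows(gnn(A, F, W_pool, relu=False))   # Eq. (6), (n, k)
    X1 = S0.T @ Z0                                     # Eq. (3), (k, d)
    A1 = S0.T @ A @ S0                                 # Eq. (4), (k, k)
    Z1 = gnn(A1, X1, W_final)                          # embed clusters, (k, d)
    S1 = np.ones((A1.shape[0], 1))                     # final assignment
    return (S1.T @ Z1).ravel()                         # graph embedding, (d,)

rng = np.random.default_rng(2)
n, d_in, d, k = 8, 4, 6, 3
upper = np.triu((rng.random((n, n)) < 0.4).astype(float), 1)
A = upper + upper.T                                    # symmetric, no self-loops
F = rng.normal(size=(n, d_in))
W_embed = rng.normal(size=(d_in, d))
W_pool = rng.normal(size=(d_in, k))
W_final = rng.normal(size=(d, d))
g = diffpool_forward(A, F, W_embed, W_pool, W_final)  # shape (6,)
```

In a trained model, `g` would be fed to a classifier and all three weight matrices would be learned jointly by backpropagating the classification loss.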

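Permutation invariance. Note that in order to be useful for graph classification, the pooling layer should be invariant under node permutations. For DiffPool we get the following positive result, which shows that any deep GNN model based on DiffPool is permutation invariant, as long as the component GNNs are permutation equivariant (i.e., permuting the input nodes permutes the output rows in the same way).

Proposition 1.

Let $P \in \{0, 1\}^{n \times n}$ be any permutation matrix, then $\mathrm{DiffPool}(A, Z) = \mathrm{DiffPool}(PAP^{T}, PZ)$ as long as $\mathrm{GNN}(PAP^{T}, PX) = P\,\mathrm{GNN}(A, X)$ (i.e., as long as the GNN method used is permutation equivariant).

Proof. Under this assumption, permuting the input graph turns the outputs of Equations (5) and (6) into $PZ$ and $PS$ (the row-wise softmax commutes with row permutations). Since any permutation matrix is orthogonal ($P^{T}P = I$), substituting into Equations (3) and (4) gives $(PS)^{T}(PZ) = S^{T}Z$ and $(PS)^{T}(PAP^{T})(PS) = S^{T}AS$, which finishes the proof. ∎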
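The cancellation at the heart of the proof can also be checked numerically. The NumPy snippet below assumes a permutation-equivariant GNN, so that permuting the nodes by $P$ maps the embeddings to $PZ$ and the assignments to $PS$, and then verifies that Equations (3) and (4) return the same coarsened graph; the random matrices are stand-ins for actual GNN outputs.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, k = 6, 4, 2
A = rng.random((n, n))
A = A + A.T                                  # symmetric weighted adjacency
Z = rng.normal(size=(n, d))                  # stand-in for Equation (5) output
S = rng.random((n, k))
S = S / S.sum(axis=1, keepdims=True)         # stand-in for Equation (6) output

P = np.eye(n)[rng.permutation(n)]            # random permutation matrix

# Coarsening of the original graph, Equations (3) and (4).
X_next, A_next = S.T @ Z, S.T @ A @ S
# Permuted graph: an equivariant GNN yields PZ and PS instead of Z and S.
PS, PZ, PAPt = P @ S, P @ Z, P @ A @ P.T
X_perm, A_perm = PS.T @ PZ, PS.T @ PAPt @ PS

assert np.allclose(P.T @ P, np.eye(n))       # P is orthogonal
assert np.allclose(X_perm, X_next)           # (PS)^T PZ = S^T Z
assert np.allclose(A_perm, A_next)           # (PS)^T PAP^T PS = S^T A S
```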

3.3 Auxiliary Link Prediction Objective and Entropy Regularization

In practice, it can be difficult to train the pooling GNN (Equation 6) using only gradient signal from the graph classification task. Intuitively, we have a non-convex optimization problem and it can be difficult to push the pooling GNN away from spurious local minima early in training. To alleviate this issue, we train the pooling GNN with an auxiliary link prediction objective, which encodes the intuition that nearby nodes should be pooled together. In particular, at each layer $l$, we minimize $L_{\mathrm{LP}} = \left\|A^{(l)} - S^{(l)} {S^{(l)}}^{T}\right\|_F$, where $\|\cdot\|_F$ denotes the Frobenius norm. Note that the adjacency matrix $A^{(l)}$ at deeper layers is a function of lower level assignment matrices, and changes during training.
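The link prediction objective rewards assignments under which nodes that share an edge also share a cluster. The NumPy sketch below evaluates $\|A - SS^{T}\|_F$ for two hard assignments of the same toy graph; the graph and the helper names are illustrative assumptions, and in DiffPool $S$ would be the soft output of the pooling GNN.

```python
import numpy as np

def link_pred_loss(A, S):
    """Auxiliary link prediction loss ||A - S S^T||_F (Frobenius norm)."""
    return float(np.linalg.norm(A - S @ S.T, ord="fro"))

def hard_assignment(labels, k):
    """One-hot assignment matrix built from integer cluster labels."""
    return np.eye(k)[labels]

# Two triangles {0,1,2} and {3,4,5} joined by the bridge edge 2-3.
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

good = link_pred_loss(A, hard_assignment([0, 0, 0, 1, 1, 1], 2))  # sqrt(8)
bad = link_pred_loss(A, hard_assignment([0, 0, 1, 0, 1, 1], 2))   # sqrt(24)
```

The triangle-respecting assignment still pays for the zero diagonal of $A$ and for the single bridge edge, while the mixed assignment also pays for every cut edge and every non-adjacent pair placed in the same cluster, so it scores worse.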

Another important characteristic of the pooling GNN (Equation 6) is that the output cluster assignment for each node should generally be close to a one-hot vector, so that the membership for each cluster or subgraph is clearly defined. We therefore regularize the entropy of the cluster assignment by minimizing $L_{\mathrm{E}} = \frac{1}{n} \sum_{i=1}^{n} H(S_i)$, where $H$ denotes the entropy function, and $S_i$ is the $i$-th row of $S$.
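A small NumPy sketch of the entropy regularizer $L_{\mathrm{E}}$ shows why it favors near one-hot rows. It assumes the natural logarithm and clips probabilities to a tiny `eps` (an implementation choice of this sketch) so that $0 \log 0$ does not produce NaN.

```python
import numpy as np

def entropy_loss(S, eps=1e-12):
    """L_E = (1/n) * sum_i H(S_i): mean natural-log entropy of the rows of S.
    Entries are clipped to eps so that 0 * log(0) does not produce NaN."""
    P = np.clip(S, eps, 1.0)
    return float(np.mean(-np.sum(P * np.log(P), axis=1)))

one_hot = np.array([[1.0, 0.0], [0.0, 1.0]])
uniform = np.full((2, 2), 0.5)
h_sharp = entropy_loss(one_hot)    # close to 0: clear cluster membership
h_flat = entropy_loss(uniform)     # log(2) ~ 0.693: maximally ambiguous
```

Minimizing this term therefore pushes each node toward a single cluster, complementing the link prediction loss, which decides *which* cluster.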

During training, $L_{\mathrm{LP}}$ and $L_{\mathrm{E}}$ from each layer are added to the classification loss. In practice we observe that training with the side objective takes longer to converge, but nevertheless achieves better performance and more interpretable cluster assignments.

4 Experiments

We evaluate the benefits of DiffPool against a number of state-of-the-art graph classification approaches, with the goal of answering the following questions:

  Q1. How does DiffPool compare to other pooling methods proposed for GNNs (e.g., using sort pooling zhang2018end or the Set2Set method Gil+2017 )?

  Q2. How does DiffPool combined with GNNs compare to the state of the art for the graph classification task, including both GNNs and kernel-based methods?

  Q3. Does DiffPool compute meaningful and interpretable clusters on the input graphs?

Data sets. To probe the ability of DiffPool to learn complex hierarchical structures from graphs in different domains, we evaluate on a variety of relatively large graph data sets chosen from benchmarks commonly used in graph classification KKMMN2016 . We use protein data sets including Enzymes, Proteins Borgwardt2005a ; Fer+2013 , D&D Dob+2003 , the social network data set Reddit-Multi-12k Yan+2015a , and the scientific collaboration data set Collab Yan+2015a . See Appendix A for statistics and properties. For all these data sets, we perform 10-fold cross-validation to evaluate model performance, and report the accuracy averaged over 10 folds.
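The evaluation protocol (10-fold cross-validation, accuracy averaged over folds) can be sketched as follows. This is a plain shuffled split chosen for illustration: the text does not say whether folds are stratified by label or how validation data for early stopping is carved out, and `train_and_eval` is a hypothetical callback.

```python
import numpy as np

def k_fold_indices(n_graphs, k=10, seed=0):
    """Shuffle graph indices and split them into k disjoint folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_graphs), k)

def cross_validate(n_graphs, train_and_eval, k=10, seed=0):
    """Mean test accuracy over k folds.

    train_and_eval(train_idx, test_idx) -> accuracy is any user-supplied
    callable (hypothetical here), e.g. one that trains a DiffPool model."""
    folds = k_fold_indices(n_graphs, k, seed)
    accs = []
    for i, test_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        accs.append(train_and_eval(train_idx, test_idx))
    return float(np.mean(accs))

def dummy_eval(train_idx, test_idx):
    # Stand-in for training + evaluation; checks that the split is disjoint.
    assert len(np.intersect1d(train_idx, test_idx)) == 0
    return 0.5

mean_acc = cross_validate(25, dummy_eval)   # 0.5
```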

Model configurations. In our experiments, the GNN model used for DiffPool is built on top of the GraphSage architecture, as we found this architecture to have superior performance compared to the standard GCN approach as introduced in kipf2017semi . We use the “mean” variant of GraphSage hamilton2017inductive and apply a DiffPool layer after every two GraphSage layers in our architecture. A total of 2 DiffPool layers are used for the datasets. For small datasets such as Enzymes and Collab, 1 DiffPool layer can achieve similar performance. After each DiffPool layer, 3 layers of graph convolutions are performed before the next DiffPool layer or the readout layer. The embedding matrix and the assignment matrix are computed by two separate GraphSage models. In both the 2 DiffPool layer and the 1 DiffPool layer architectures, the number of clusters is set as a fixed percentage of the number of nodes before applying DiffPool, with a different percentage for each architecture. Batch normalization ioffe2015batch is applied after every layer of GraphSage. We also found that adding an $\ell_2$ normalization to the node embeddings at each layer made the training more stable. In Section 4.2, we also test an analogous variant of DiffPool on the Structure2Vec dai2016discriminative architecture, in order to demonstrate how DiffPool can be applied on top of other GNN models. All models are trained for 3 000 epochs with early stopping applied when the validation loss stops improving. We also evaluate two simplified versions of DiffPool:


  • DiffPool-Det is a DiffPool model where assignment matrices are generated using a deterministic graph clustering algorithm dhillon2007weighted .

  • DiffPool-NoLP is a variant of DiffPool where the link prediction side objective is turned off.
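For readers unfamiliar with the “mean” GraphSage variant, the following NumPy sketch shows one such layer: the node's own embedding is concatenated with the mean of its neighbors' embeddings, transformed, passed through a ReLU, and rescaled to unit $\ell_2$ norm per node, as in the original GraphSAGE algorithm. It is an illustration under stated assumptions (binary adjacency, random weights, batch normalization omitted), not the code used in these experiments.

```python
import numpy as np

def sage_mean_layer(A, H, W, eps=1e-12):
    """GraphSAGE-style layer with a mean aggregator.

    h_v' = ReLU([h_v || mean_{u in N(v)} h_u] @ W), then each row is scaled
    to unit l2 norm. W has shape (2 * d_in, d_out)."""
    deg = A.sum(axis=1, keepdims=True)
    neigh_mean = (A @ H) / np.maximum(deg, 1.0)        # isolated nodes -> 0
    out = np.maximum(np.concatenate([H, neigh_mean], axis=1) @ W, 0.0)
    norms = np.linalg.norm(out, axis=1, keepdims=True)
    return out / np.maximum(norms, eps)                # per-node l2 normalisation

rng = np.random.default_rng(6)
A = np.zeros((4, 4))
for i, j in [(0, 1), (1, 2), (2, 3)]:
    A[i, j] = A[j, i] = 1.0
H = rng.normal(size=(4, 3))
W = rng.normal(size=(6, 5))
H_next = sage_mean_layer(A, H, W)                      # shape (4, 5)
```

The normalization keeps every embedding on the unit sphere, which bounds the scale of the inputs seen by the next layer; a row that the ReLU zeroes out stays at zero.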

4.1 Baseline Methods

In the performance comparison on graph classification, we consider baselines based upon GNNs (combined with different pooling methods) as well as state-of-the-art kernel-based approaches.

GNN-based methods.


  • GraphSage with global mean-pooling hamilton2017inductive . Other GNN variants such as those proposed in kipf2017semi are omitted as empirically GraphSAGE obtained higher performance in the task.

  • Structure2Vec (S2V) dai2016discriminative is a state-of-the-art graph representation learning algorithm that combines a latent variable model with GNNs. It uses global mean pooling.

  • Edge-conditioned filters in CNN for graphs (ECC) simonovsky2017dynamic incorporates edge information into the GCN model and performs pooling using a graph coarsening algorithm.

  • PatchySan niepert2016learning defines a receptive field (neighborhood) for each node, and using a canonical node ordering, applies convolutions on linear sequences of node embeddings.

  • Set2Set replaces the global mean-pooling in the traditional GNN architectures by the aggregation used in Set2Set vinyals2015order . Set2Set aggregation has been shown to perform better than mean pooling in previous work Gil+2017 . We use GraphSage as the base GNN model.

  • SortPool zhang2018end applies a GNN architecture and then performs a single layer of soft pooling followed by 1D convolution on sorted node embeddings.

For all the GNN baselines, we use 10-fold cross validation numbers reported by the original authors when possible. For the GraphSage and Set2Set baselines, we use the same base implementation and hyperparameter sweeps as in our DiffPool approach. When baseline approaches did not have the necessary published numbers, we contacted the original authors and used their code (if available) to run the model, performing a hyperparameter search based on the original authors’ guidelines.

Kernel-based algorithms. We use the Graphlet She+2009 , the Shortest-Path Borgwardt2005 , Weisfeiler-Lehman kernel (WL) She+2011 , and Weisfeiler-Lehman Optimal Assignment kernel (WL-OA) kriege2016valid as kernel baselines. For each kernel, we computed the normalized Gram matrix. We computed the classification accuracies using the C-SVM implementation of LibSvm Cha+11 , using 10-fold cross validation. The C parameter was selected by 10-fold cross validation on the training folds. Moreover, for WL and WL-OA we additionally selected the number of iterations.
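The Gram-matrix normalization is not spelled out above; the usual choice, assumed in this sketch, is cosine normalization $\hat{K}_{ij} = K_{ij} / \sqrt{K_{ii} K_{jj}}$, which gives every graph unit self-similarity. The explicit feature vectors below are hypothetical stand-ins for, e.g., substructure counts.

```python
import numpy as np

def normalize_gram(K):
    """Cosine normalisation: K_ij / sqrt(K_ii * K_jj)."""
    diag = np.sqrt(np.diag(K))
    return K / np.outer(diag, diag)

# Hypothetical explicit feature maps (e.g. substructure counts) of 3 graphs.
Phi = np.array([[2.0, 0.0, 1.0],
                [1.0, 1.0, 0.0],
                [4.0, 0.0, 2.0]])
K = Phi @ Phi.T              # linear kernel on the feature maps
K_norm = normalize_gram(K)   # unit diagonal; graphs 0 and 2 give 1.0
```

Graphs 0 and 2 have proportional feature vectors, so after normalization they are indistinguishable to the kernel even though graph 2 is “twice as large”. The normalized matrix can then be passed to an SVM that accepts precomputed kernels; for example, scikit-learn's `SVC(kernel="precomputed")` is built on LIBSVM.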

4.2 Results for Graph Classification

Table 1 compares the performance of DiffPool to these state-of-the-art graph classification baselines. These results provide positive answers to our motivating questions Q1 and Q2: we observe that our DiffPool approach obtains the highest average performance among all pooling approaches for GNNs, improves upon the base GraphSage architecture by an average of 6.27 percentage points, and achieves state-of-the-art results on 4 out of 5 benchmarks. Interestingly, our simplified model variant, DiffPool-Det, achieves state-of-the-art performance on the Collab benchmark. This is because many collaboration graphs in Collab show only single-layer community structures, which can be captured well with a pre-computed graph clustering algorithm dhillon2007weighted . One observation is that despite significant performance improvement, DiffPool could be unstable to train, and there is significant variation in accuracy across different runs, even with the same hyperparameter setting. It is observed that adding the link prediction objective makes training more stable, and reduces the standard deviation of accuracy across different runs.

Method          Enzymes   D&D     Reddit-Multi-12k   Collab   Proteins   Gain

Kernel-based
Graphlet        41.03     74.85   21.73              64.66    72.91      –
Shortest-path   42.32     78.86   36.93              59.10    76.43      –
1-WL            53.43     74.02   39.03              78.61    73.76      –
WL-OA           60.13     79.04   44.38              80.74    75.26      –

GNN-based
PatchySan       –         76.27   41.32              72.60    75.00      4.17
GraphSage       54.25     75.42   42.24              68.25    70.48      –
ECC             53.50     74.10   41.73              67.79    72.65      0.11
Set2set         60.15     78.12   43.49              71.75    74.29      3.32
SortPool        57.12     79.37   41.82              73.76    75.54      3.39
DiffPool-Det    58.33     75.47   46.18              82.13    75.62      5.42
DiffPool-NoLP   61.95     79.98   46.65              75.58    76.22      5.95
DiffPool        62.53     80.64   47.08              75.48    76.25      6.27
Table 1: Classification accuracies in percent. The far-right column gives the average increase in accuracy, in percentage points, compared to the baseline GraphSage approach.
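The Gain column can be reproduced from the table itself. The sketch below assumes (our reading of the numbers, not stated explicitly in the text) that a row's gain is its mean accuracy over the five data sets minus the mean accuracy of GraphSage; under this reading it reproduces, e.g., the DiffPool, DiffPool-Det, DiffPool-NoLP, and SortPool entries.

```python
import numpy as np

# Accuracies (%) from Table 1, in the order
# Enzymes, D&D, Reddit-Multi-12k, Collab, Proteins.
graphsage = np.array([54.25, 75.42, 42.24, 68.25, 70.48])
diffpool = np.array([62.53, 80.64, 47.08, 75.48, 76.25])
diffpool_det = np.array([58.33, 75.47, 46.18, 82.13, 75.62])

def gain(method, baseline=graphsage):
    """Mean accuracy of `method` minus mean accuracy of the baseline,
    in percentage points."""
    return float(method.mean() - baseline.mean())

print(round(gain(diffpool), 2))      # 6.27
print(round(gain(diffpool_det), 2))  # 5.42
```

Under this reading the gain is an absolute difference of averages (percentage points) rather than a ratio; per-data-set relative improvements of DiffPool over GraphSage would average to roughly 10%.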

Differentiable Pooling on Structure2Vec. DiffPool can be applied to other GNN architectures besides GraphSage to capture hierarchical structure in the graph data. To further support answering Q1, we also applied DiffPool on Structure2Vec (S2V). We ran experiments using S2V with a three-layer architecture, as reported in dai2016discriminative . In the first variant, one DiffPool layer is applied after the first layer of S2V, and two more S2V layers are stacked on top of the output of DiffPool. The second variant applies one DiffPool layer after each of the first and second layers of S2V. In both variants, the S2V model is used to compute the embedding matrix, while a GraphSage model is used to compute the assignment matrix.

Data Set   S2V     S2V with 1 DiffPool   S2V with 2 DiffPool
Enzymes    61.10   62.86                 63.33
D&D        78.92   80.75                 82.07
Table 2: Accuracy results of applying DiffPool to S2V.

The results in terms of classification accuracy are summarized in Table 2. We observe that DiffPool improves the accuracy of S2V on both the Enzymes and D&D data sets, by 2.23 and 3.15 percentage points respectively when two DiffPool layers are used. Similar performance trends are also observed on other data sets. The results demonstrate that DiffPool is a general strategy to pool over hierarchical structure that can benefit different GNN architectures.

Running time. Although applying DiffPool requires additional computation of an assignment matrix, we observed that DiffPool did not incur substantial additional running time in practice. This is because each DiffPool layer reduces the size of graphs by extracting a coarser representation of the graph, which speeds up the graph convolution operation in the next layer. Concretely, we found that GraphSage with DiffPool was 12× faster than the GraphSage model with Set2Set pooling, while still achieving significantly higher accuracy on all benchmarks.

4.3 Analysis of Cluster Assignment in DiffPool

Hierarchical cluster structure. To address Q3, we investigated the extent to which DiffPool learns meaningful node clusters by visualizing the cluster assignments in different layers. Figure 2 shows such a visualization of node assignments in the first and second layers on a graph from the Collab data set, where node color indicates its cluster membership. Node cluster membership is determined by taking the argmax of its cluster assignment probabilities. We observe that even when learning cluster assignment based solely on the graph classification objective, DiffPool can still capture the hierarchical community structure. We also observe a significant improvement in membership assignment quality with the auxiliary link prediction objective.

Dense vs. sparse subgraph structure. In addition, we observe that DiffPool learns to collapse nodes into soft clusters in a non-uniform way, with a tendency to collapse densely-connected subgraphs into clusters. Since GNNs can efficiently perform message-passing on dense, clique-like subgraphs (due to their small diameters) liao2018graph , pooling together nodes in such a dense subgraph is not likely to lead to any loss of structural information. This intuitively explains why collapsing dense subgraphs is a useful pooling strategy for DiffPool. In contrast, sparse subgraphs may contain many interesting structures, including path-, cycle- and tree-like structures, and given the high diameter induced by sparsity, GNN message-passing may fail to capture these structures. Thus, by separately pooling distinct parts of a sparse subgraph, DiffPool can learn to capture the meaningful structures present in sparse graph regions (e.g., as in Figure 2).

Assignment for nodes with similar representations. Since the assignment network computes the soft cluster assignment based on features of input nodes and their neighbors, nodes with both similar input features and neighborhood structure will have similar cluster assignment. In fact, one can construct synthetic cases where 2 nodes, although far away, have exactly the same neighborhood structure and features for self and all neighbors. In this case the pooling network is forced to assign them into the same cluster, which is different from the concept of pooling in other architectures such as image ConvNets. In some cases we do observe that disconnected nodes are pooled together.
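The situation described above is easy to reproduce. In the NumPy sketch below (a hypothetical two-layer GCN-style pooling GNN with random weights), the end nodes of a five-node path are four hops apart but are mirror images of each other and carry identical features, so they necessarily receive identical assignment rows.

```python
import numpy as np

def softmax_rows(M):
    e = np.exp(M - M.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def pool_assign(A, X, weights):
    """Hypothetical pooling GNN: GCN-style layers (ReLU between them),
    followed by a row-wise softmax over the final cluster logits."""
    A_tilde = A + np.eye(A.shape[0])
    d = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    A_norm = d[:, None] * A_tilde * d[None, :]
    H = X
    for W in weights[:-1]:
        H = np.maximum(A_norm @ H @ W, 0.0)
    return softmax_rows(A_norm @ H @ weights[-1])

# Path graph 0-1-2-3-4: nodes 0 and 4 are four hops apart but symmetric.
A = np.zeros((5, 5))
for i in range(4):
    A[i, i + 1] = A[i + 1, i] = 1.0
X = np.ones((5, 3))                      # identical input features
rng = np.random.default_rng(4)
weights = [rng.normal(size=(3, 8)), rng.normal(size=(8, 3))]
S = pool_assign(A, X, weights)           # rows 0 and 4 (and 1 and 3) coincide
```

Giving node 0 a distinguishing feature, for example by changing one entry of its row in `X`, generally breaks this tie, which is what an identifiability assumption on node features relies on.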

In practice, we rely on an identifiability assumption similar to Theorem 1 in GraphSAGE hamilton2017inductive , under which nodes are identifiable via their features. This holds in many real datasets (however, some chemistry molecular graph datasets contain many structurally similar nodes, and the assignment network is observed to pool together nodes that are far apart). The auxiliary link prediction objective is also observed to help discourage nodes that are far apart from being pooled together. Furthermore, it is possible to use more sophisticated GNN aggregation functions, such as high-order moments verma2018graph , to distinguish nodes that are similar in structure and feature space. The overall framework remains unchanged.
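The sketch below illustrates how an auxiliary link prediction objective can discourage pooling distant nodes. It assumes the objective is the Frobenius norm ||A − SSᵀ||_F between the adjacency matrix A and the co-assignment matrix SSᵀ. It evaluates this norm for two hard assignments of a 4-node path: one that groups adjacent nodes, and one that places the two endpoints, three hops apart, in the same cluster.

```python
import numpy as np


def link_prediction_loss(adj, s):
    """Frobenius norm ||A - S S^T||_F (assumed form of the auxiliary objective)."""
    return np.linalg.norm(adj - s @ s.T, ord="fro")


# Path graph 0-1-2-3, pooled into two clusters.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)

# Hard assignments written as one-hot rows of S.
s_local = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)  # {0,1}, {2,3}
s_far = np.array([[1, 0], [0, 1], [0, 1], [1, 0]], dtype=float)    # {0,3}, {1,2}

print(link_prediction_loss(adj, s_local))  # sqrt(6), about 2.449
print(link_prediction_loss(adj, s_far))    # sqrt(10), about 3.162
```

Grouping the endpoints raises the loss from √6 to √10. In that case SSᵀ predicts a non-existent edge (0, 3) and misses the edges (0, 1) and (2, 3). Because the norm is differentiable in S, the same penalty also applies to soft assignments during training.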

Sensitivity of the Pre-defined Maximum Number of Clusters. We found that the assignment varies with the depth of the network and with the pre-defined maximum number of clusters. With a larger maximum number of clusters, the pooling GNN can model more complex hierarchical structure. The trade-off is that a very large maximum results in more noise and lower efficiency. Although this maximum is a pre-defined parameter, the pooling network learns to use an appropriate number of clusters through end-to-end training. In particular, some clusters might not be used by the assignment matrix: the column corresponding to an unused cluster has low values for all nodes. This is observed in Figure 2(c), where nodes are assigned predominantly to 3 clusters.
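Unused clusters show up as columns with uniformly low values. This suggests a simple way to read off how many clusters a trained assignment matrix actually uses. The sketch below is illustrative only: the threshold of 0.1 and the helper name are hypothetical choices, and the 6 × 5 matrix is made-up data, not taken from Figure 2.

```python
import numpy as np


def used_clusters(s, threshold=0.1):
    """Indices of clusters whose column in S exceeds `threshold` for at least one node.

    `threshold` is a hypothetical cutoff for "low values for all nodes".
    """
    return np.flatnonzero(s.max(axis=0) > threshold)


# Made-up soft assignment of 6 nodes to a maximum of 5 clusters (rows sum to 1).
s = np.array([
    [0.90, 0.04, 0.02, 0.02, 0.02],
    [0.85, 0.05, 0.04, 0.03, 0.03],
    [0.03, 0.90, 0.03, 0.02, 0.02],
    [0.02, 0.88, 0.05, 0.03, 0.02],
    [0.02, 0.03, 0.91, 0.02, 0.02],
    [0.04, 0.02, 0.86, 0.05, 0.03],
])

print(used_clusters(s).tolist())  # → [0, 1, 2]
```

Columns 3 and 4 never exceed 0.05, so only three of the five available clusters are effectively in use. This mirrors the behavior described for Figure 2(c).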

Figure 2: Visualization of hierarchical cluster assignment in DiffPool, using example graphs from Collab. The left figure (a) shows hierarchical clustering over two layers, where nodes in the second layer correspond to clusters in the first layer. (Colors are used to connect the nodes/clusters across the layers, and dotted lines are used to indicate clusters.) The right two plots (b and c) show two more examples of first-layer clusters in different graphs. Note that although we globally set the number of clusters to a fixed percentage of the number of nodes, the assignment GNN automatically learns the appropriate number of meaningful clusters to assign for these different graphs.

5 Conclusion

We introduced a differentiable pooling method for GNNs that is able to extract the complex hierarchical structure of real-world graphs. By using the proposed pooling layer in conjunction with existing GNN models, we achieved new state-of-the-art results on several graph classification benchmarks. Interesting future directions include learning hard cluster assignments to further reduce computational cost in higher layers while also ensuring differentiability, and applying the hierarchical pooling method to other downstream tasks that require modeling of the entire graph structure.


This research has been supported in part by DARPA SIMPLEX, Stanford Data Science Initiative, Huawei, JD and Chan Zuckerberg Biohub. Christopher Morris is funded by the German Science Foundation (DFG) within the Collaborative Research Center SFB 876 “Providing Information by Resource-Constrained Data Analysis”, project A6 “Resource-efficient Graph Mining”. The authors also thank Marinka Zitnik for help in visualizing the high-level illustration of the proposed methods.


  • (1) M. Bianchini, M. Gori, and F. Scarselli. Processing directed acyclic graphs with recursive neural networks. IEEE Transactions on Neural Networks, 12(6):1464–1470, 2001.
  • (2) K. M. Borgwardt and H.-P. Kriegel. Shortest-path kernels on graphs. In IEEE International Conference on Data Mining, pages 74–81, 2005.
  • (3) K. M. Borgwardt, C. S. Ong, S. Schönauer, S. V. N. Vishwanathan, A. J. Smola, and H.-P. Kriegel. Protein function prediction via graph kernels. Bioinformatics, 21(Supplement 1):i47–i56, 2005.
  • (4) M. M. Bronstein, J. Bruna, Y. LeCun, A. Szlam, and P. Vandergheynst. Geometric deep learning: Going beyond euclidean data. IEEE Signal Processing Magazine, 34(4):18–42, 2017.
  • (5) J. Bruna, W. Zaremba, A. Szlam, and Y. LeCun. Spectral networks and deep locally connected networks on graphs. In International Conference on Learning Representations, 2014.
  • (6) C.-C. Chang and C.-J. Lin. LIBSVM: A library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2:27:1–27:27, 2011. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm.
  • (7) H. Dai, B. Dai, and L. Song. Discriminative embeddings of latent variable models for structured data. In International Conference on Machine Learning, pages 2702–2711, 2016.
  • (8) M. Defferrard, X. Bresson, and P. Vandergheynst. Convolutional neural networks on graphs with fast localized spectral filtering. In Advances in Neural Information Processing Systems, pages 3844–3852, 2016.
  • (9) I. S. Dhillon, Y. Guan, and B. Kulis. Weighted graph cuts without eigenvectors: A multilevel approach. IEEE Transactions on Pattern Analysis and Machine Intelligence, 29(11):1944–1957, 2007.
  • (10) P. D. Dobson and A. J. Doig. Distinguishing enzyme structures from non-enzymes without alignments. Journal of Molecular Biology, 330(4):771–783, 2003.
  • (11) D. K. Duvenaud, D. Maclaurin, J. Iparraguirre, R. Bombarell, T. Hirzel, A. Aspuru-Guzik, and R. P. Adams. Convolutional networks on graphs for learning molecular fingerprints. In Advances in Neural Information Processing Systems, pages 2224–2232, 2015.
  • (12) A. Feragen, N. Kasenburg, J. Petersen, M. de Bruijne, and K. M. Borgwardt. Scalable kernels for graphs with continuous attributes. In Advances in Neural Information Processing Systems, pages 216–224, 2013. Erratum available at http://image.diku.dk/aasa/papers/graphkernels_nips_erratum.pdf.
  • (13) M. Fey, J. E. Lenssen, F. Weichert, and H. Müller. SplineCNN: Fast geometric deep learning with continuous B-spline kernels. In IEEE Conference on Computer Vision and Pattern Recognition, 2018.
  • (14) A. Fout, J. Byrd, B. Shariat, and A. Ben-Hur. Protein interface prediction using graph convolutional networks. In Advances in Neural Information Processing Systems, pages 6533–6542, 2017.
  • (15) J. Gilmer, S. S. Schoenholz, P. F. Riley, O. Vinyals, and G. E. Dahl. Neural message passing for quantum chemistry. In International Conference on Machine Learning, pages 1263–1272, 2017.
  • (16) W. L. Hamilton, R. Ying, and J. Leskovec. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pages 1025–1035, 2017.
  • (17) W. L. Hamilton, R. Ying, and J. Leskovec. Representation learning on graphs: Methods and applications. IEEE Data Engineering Bulletin, 40(3):52–74, 2017.
  • (18) S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456, 2015.
  • (19) W. Jin, C. W. Coley, R. Barzilay, and T. S. Jaakkola. Predicting organic reaction outcomes with Weisfeiler-Lehman network. In Advances in Neural Information Processing Systems, pages 2604–2613, 2017.
  • (20) K. Kersting, N. M. Kriege, C. Morris, P. Mutzel, and M. Neumann. Benchmark data sets for graph kernels, 2016.
  • (21) T. N. Kipf, E. Fetaya, K. C. Wang, M. Welling, and R. Zemel. Neural relational inference for interacting systems. In International Conference on Machine Learning, 2018.
  • (22) T. N. Kipf and M. Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017.
  • (23) N. M. Kriege, P.-L. Giscard, and R. Wilson. On valid optimal assignment kernels and applications to graph classification. In Advances in Neural Information Processing Systems, pages 1623–1631, 2016.
  • (24) A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105, 2012.
  • (25) T. Lei, W. Jin, R. Barzilay, and T. S. Jaakkola. Deriving neural architectures from sequence and graph kernels. In International Conference on Machine Learning, pages 2024–2033, 2017.
  • (26) Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. Gated graph sequence neural networks. In International Conference on Learning Representations, 2016.
  • (27) R. Liao, M. Brockschmidt, D. Tarlow, A. L. Gaunt, R. Urtasun, and R. Zemel. Graph partition neural networks for semi-supervised classification. In International Conference on Learning Representations (Workshop Track), 2018.
  • (28) A. Lusci, G. Pollastri, and P. Baldi. Deep architectures and deep learning in chemoinformatics: The prediction of aqueous solubility for drug-like molecules. Journal of Chemical Information and Modeling, 53(7):1563–1575, 2013.
  • (29) C. Merkwirth and T. Lengauer. Automatic generation of complementary descriptors with molecular graph networks. Journal of Chemical Information and Modeling, 45(5):1159–1168, 2005.
  • (30) M. Niepert, M. Ahmed, and K. Kutzkov. Learning convolutional neural networks for graphs. In International Conference on Machine Learning, pages 2014–2023, 2016.
  • (31) F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, and G. Monfardini. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
  • (32) M. Schlichtkrull, T. N. Kipf, P. Bloem, R. van den Berg, I. Titov, and M. Welling. Modeling relational data with graph convolutional networks. In Extended Semantic Web Conference, 2018.
  • (33) K. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, and K.-R. Müller. SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. In Advances in Neural Information Processing Systems, pages 992–1002, 2017.
  • (34) N. Shervashidze, P. Schweitzer, E. J. van Leeuwen, K. Mehlhorn, and K. M. Borgwardt. Weisfeiler-Lehman graph kernels. Journal of Machine Learning Research, 12:2539–2561, 2011.
  • (35) N. Shervashidze, S. V. N. Vishwanathan, T. H. Petri, K. Mehlhorn, and K. M. Borgwardt. Efficient graphlet kernels for large graph comparison. In International Conference on Artificial Intelligence and Statistics, pages 488–495, 2009.
  • (36) M. Simonovsky and N. Komodakis. Dynamic edge-conditioned filters in convolutional neural networks on graphs. In IEEE Conference on Computer Vision and Pattern Recognition, pages 29–38, 2017.
  • (37) P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio. Graph attention networks. In International Conference on Learning Representations, 2018.
  • (38) S. Verma and Z.-L. Zhang. Graph capsule convolutional neural networks. arXiv preprint arXiv:1805.08090, 2018.
  • (39) O. Vinyals, S. Bengio, and M. Kudlur. Order matters: Sequence to sequence for sets. In International Conference on Learning Representations, 2015.
  • (40) P. Yanardag and S. V. N. Vishwanathan. A structural smoothing framework for robust graph comparison. In Advances in Neural Information Processing Systems, pages 2134–2142, 2015.
  • (41) M. Zhang, Z. Cui, M. Neumann, and Y. Chen. An end-to-end deep learning architecture for graph classification. In AAAI Conference on Artificial Intelligence, 2018.