Memory-Based Graph Networks

Graph neural networks (GNNs) are a class of deep models that operate on data with arbitrary topology represented as graphs. We introduce an efficient memory layer for GNNs that can jointly learn node representations and coarsen the graph. We also introduce two new networks based on this layer: the memory-based GNN (MemGNN) and the graph memory network (GMN), both of which can learn hierarchical graph representations. The experimental results show that the proposed models achieve state-of-the-art results on eight out of nine graph classification and regression benchmarks. We also show that the learned representations can correspond to chemical features in molecule data. Code and reference implementations are released at:




1 Introduction

Graph neural networks (GNNs) (Wu et al., 2019; Zhou et al., 2018; Zhang et al., 2018) are a class of deep models that operate on data with arbitrary topology represented as graphs, such as social networks (Kipf and Welling, 2016), knowledge graphs (Vivona and Hassani, 2019), molecules (Duvenaud et al., 2015), point clouds (Hassani and Haley, 2019), and robots (Wang et al., 2018). Unlike regular-structured inputs such as grids (e.g., images and volumetric data) and sequences (e.g., speech and text), GNN inputs are permutation-invariant, variable-size graphs consisting of nodes and their interactions. GNNs such as the gated GNN (GGNN) (Li et al., 2015), message passing neural network (MPNN) (Gilmer et al., 2017), graph convolutional network (GCN) (Kipf and Welling, 2016), and graph attention network (GAT) (Veličković et al., 2018) learn node representations through an iterative process of transferring, transforming, and aggregating node representations from topological neighbors. Each iteration expands the receptive field by one hop, so after k iterations the nodes within k hops influence one another's representations. GNNs have been shown to learn better representations than random walks (Grover and Leskovec, 2016; Perozzi et al., 2014), matrix factorization (Belkin and Niyogi, 2002; Ou et al., 2016), kernel methods (Shervashidze et al., 2011; Kriege et al., 2016), and probabilistic graphical models (Dai et al., 2016).
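The k-hop receptive field claim can be checked with a small sketch (plain NumPy, not tied to any of the cited architectures): on a path graph, repeating a neighbor-aggregation step spreads a one-hot node signal exactly one hop per iteration.

```python
import numpy as np

# Path graph 0-1-2-3; aggregation with self-loops, as in a simplified GCN step.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
hop = A + np.eye(4)                      # include each node's own representation
H = np.array([[1.], [0.], [0.], [0.]])   # one-hot signal on node 0

for _ in range(2):                       # two message-passing iterations
    H = hop @ H                          # aggregate from one-hop neighbors

# After 2 iterations, exactly the nodes within 2 hops of node 0 are nonzero.
print((H != 0).ravel())                  # [ True  True  True False]
```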

These models, however, cannot learn hierarchical representations as they do not exploit the compositional nature of graphs. Recent work such as differentiable pooling (DiffPool) (Ying et al., 2018), TopKPool (Gao and Ji, 2019), and self-attention graph pooling (SAGPool) (Lee et al., 2019) introduces parametric graph pooling layers that allow GNNs to learn hierarchical graph representations by stacking interleaved GNN and pooling layers. These layers cluster nodes in the latent space; the clusters may correspond to communities in a social network or to potent functional groups in a chemical dataset. Nevertheless, these models are inefficient because they require an iterative round of message passing after each pooling layer.

In this paper, we introduce a memory layer for joint graph representation learning and graph coarsening that consists of a multi-head array of memory keys and a convolution operator to aggregate the soft cluster assignments from different heads. The queries to a memory layer are the node representations from the previous layer, and the outputs are the node representations of the coarsened graph. The memory layer does not explicitly require connectivity information and, unlike GNNs, relies on global information rather than local topology; hence, it does not struggle with the over-smoothing problem (Xu et al., 2018; Li et al., 2018). These properties make memory layers more efficient and improve their performance. We also introduce two networks based on the proposed layer: the memory-based GNN (MemGNN) and the graph memory network (GMN). MemGNN consists of a GNN that learns the initial node representations, and a stack of memory layers that learn hierarchical representations up to the global graph representation. GMN, on the other hand, learns the hierarchical representations purely with memory layers and hence does not require message passing.

2 Related Work

Memory augmented neural networks (MANNs) utilize external memory with differentiable read-write operators, allowing them to explicitly access past experiences; they have been shown to enhance reinforcement learning (Pritzel et al., 2017), meta learning (Santoro et al., 2016), few-shot learning (Vinyals et al., 2016), and multi-hop reasoning (Weston et al., 2015). Unlike RNNs, in which memory is represented within hidden states, the decoupled memory in MANNs allows them to store and retrieve longer-term memories with fewer parameters. The memory can be implemented as key-value memory, such as neural episodic control (Pritzel et al., 2017) and product-key memory layers (Lample et al., 2019), or as array-structured memory, such as the neural Turing machine (NTM) (Graves et al., 2014), prototypical networks (Snell et al., 2017), memory networks (Weston et al., 2015), and sparse access memory (SAM) (Rae et al., 2016). Our memory layer consists of a multi-head array of memory keys.

Graph neural networks (GNNs) mostly use message passing to learn node representations over graphs. GraphSAGE (Hamilton et al., 2017) learns representations by sampling and aggregating neighbor nodes, whereas GAT (Veličković et al., 2018) uses an attention mechanism to aggregate representations from all neighbors. GCNs extend the convolution operator to arbitrary topologies. Spectral GCNs (Bruna et al., 2014; Defferrard et al., 2016; Kipf and Welling, 2016) use spectral filters over the graph Laplacian to define the convolution in the Fourier domain. These models are less efficient than spatial GCNs (Schlichtkrull et al., 2018; Ma et al., 2019), which directly define the convolution on graph patches centered on nodes. Our memory layer uses a feed-forward network to learn the node representations.

Graph pooling can be defined in a global or hierarchical manner. In the former, node representations are aggregated into a graph representation by a readout layer implemented using arithmetic operators such as summation or averaging (Hamilton et al., 2017; Kipf and Welling, 2016), or set neural networks such as Set2Set (Vinyals et al., 2015) and SortPool (Morris et al., 2019). In the latter, graphs are coarsened in each layer to capture the hierarchical structure. Efficient non-parametric methods such as clique pooling (Luzhnica et al., 2019), kNN pooling (Wang et al., 2018), and Graclus (Dhillon et al., 2007) rely on topological information but are outperformed by parametric models such as edge contraction pooling (Diehl, 2019).

DiffPool (Ying et al., 2018) trains two parallel GNNs to compute node representations and cluster assignments using a multi-term loss comprising classification, link prediction, and entropy losses, whereas Mincut pool (Bianchi et al., 2019) trains a sequence of a GNN and an MLP using the classification loss and the minimum cut objective. TopKPool (Cangea et al., 2018; Gao and Ji, 2019) computes node scores by learning projection vectors and dropping all nodes except the top-scoring ones. SAGPool (Lee et al., 2019) extends TopKPool by using graph convolutions to take neighbor node features into account. We use a clustering-friendly distribution to compute the attention scores between nodes and clusters.

3 Method

3.1 Memory Layer

We define a memory layer M^(l) in layer l as a parametric function that takes in n^(l) query vectors of size d^(l) and generates n^(l+1) query vectors of size d^(l+1) such that n^(l+1) < n^(l). The input and output queries represent the node representations of the input graph and the coarsened graph, respectively. The memory layer learns to jointly coarsen the input nodes, i.e., pooling, and transform their features, i.e., representation learning. As shown in Figure 1, a memory layer consists of arrays of memory keys, i.e., a multi-head memory, and a convolutional layer. Assuming h memory heads, a shared input query is compared against all the keys in each head, resulting in h attention matrices which are then aggregated into a single attention matrix by the convolution layer.

In a content addressable memory (Graves et al., 2014; Sukhbaatar et al., 2015; Weston et al., 2015), the task of attending to memory, i.e., the addressing scheme, is formulated as computing the similarity of the memory keys to a given query q. Specifically, the attention weight of key k_i for query q is defined as w_i = d(q, k_i), where d is a similarity measure, typically Euclidean distance or cosine similarity (Rae et al., 2016). The soft read operation on the memory is then defined as a weighted average over the memory keys: r = Σ_i w_i k_i.

In this work, we treat the input queries as the node representations of an input graph and the keys as the cluster centroids of the queries. To satisfy this assumption, we impose a clustering-friendly distribution as the distance metric between keys and a query. Following (Xie et al., 2016; Maaten and Hinton, 2008), we use the Student's t-distribution as a kernel to measure the normalized similarity between query q_i and key k_j:

    C_ij = (1 + ||q_i - k_j||^2 / τ)^(-(τ+1)/2) / Σ_j' (1 + ||q_i - k_j'||^2 / τ)^(-(τ+1)/2)

where C_ij is the normalized score between query q_i and key k_j, i.e., the probability of assigning node i to cluster j (equivalently, the attention score between query q_i and memory key k_j), and τ is the degree of freedom of the Student's t-distribution, i.e., the temperature.
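As a concrete illustration, the normalized Student's t-kernel scores between queries and keys can be computed as below; this is a minimal NumPy sketch under assumed shapes and temperature, not the authors' released code.

```python
import numpy as np

def soft_assignment(Q, K, tau=1.0):
    """Normalized Student's t-kernel similarity between queries and keys.

    Q: (n, d) node queries; K: (m, d) memory keys (cluster centroids);
    tau: degrees of freedom (temperature). Returns C of shape (n, m),
    where each row sums to 1, i.e., soft cluster assignments per node.
    """
    # squared Euclidean distance between every query and every key
    d2 = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1)          # (n, m)
    sim = (1.0 + d2 / tau) ** (-(tau + 1.0) / 2.0)
    return sim / sim.sum(axis=1, keepdims=True)

Q = np.random.RandomState(0).randn(5, 8)   # 5 nodes, 8-dim queries
K = np.random.RandomState(1).randn(3, 8)   # 3 memory keys
C = soft_assignment(Q, K)
print(C.shape)                             # (5, 3); each row sums to 1
```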

Figure 1: The proposed architecture for hierarchical graph representation learning using the proposed memory layer. The query network projects the initial node features into a latent query space and each memory layer jointly coarsens the input queries and transforms them into a new query space.

To increase the capacity, we model the memory keys as a multi-head array. Applying a shared input query against the memory keys of h heads produces a tensor of cluster assignments [C_1; ...; C_h] ∈ R^{h × n^(l) × n^(l+1)}, where h denotes the number of heads. To aggregate the heads into a single assignment matrix, we treat the heads and the assignment matrices as the depth, height, and width channels in the standard convolution analogy and apply a convolution operator over them. Because there is no spatial structure, we use a 1×1 convolution to aggregate the information across heads; the convolution therefore behaves as a weighted pooling that reduces the heads to a single matrix. The aggregated assignment matrix is computed as follows:

    C = softmax(Γ_φ([C_1 ‖ C_2 ‖ ... ‖ C_h]))

where Γ_φ is a 1×1 convolutional operator parametrized by φ, ‖ is the concatenation operator, and C ∈ R^{n^(l) × n^(l+1)} is the aggregated soft assignment matrix.

A memory read generates a value matrix that represents the coarsened node representations in the same space as the input queries, defined as the product of the soft assignment scores and the original queries:

    V^(l) = C^T Q^(l)

The value matrix is fed to a single-layer feed-forward neural network that projects the coarsened embeddings from R^{n^(l+1) × d^(l)} into R^{n^(l+1) × d^(l+1)}:

    Q^(l+1) = σ(V^(l) W)

where Q^(l+1) is the output queries, W is the network parameters, and σ is the non-linearity implemented using LeakyReLU.
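Putting the pieces together, one layer's forward pass (multi-head t-kernel assignments, 1×1-convolution aggregation across heads, soft read, feed-forward projection) can be sketched as follows. This is an illustrative NumPy re-implementation under assumed shapes, not the released code.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def memory_layer(Q, keys, W_conv, W_out, tau=1.0):
    """One memory layer: multi-head soft assignment -> 1x1-conv head
    aggregation -> soft read -> feed-forward projection.

    Q:      (n, d_in)       input node queries
    keys:   (h, m, d_in)    h heads of m memory keys each
    W_conv: (h,)            1x1-conv weights reducing heads to one matrix
    W_out:  (d_in, d_out)   projection of the coarsened embeddings
    Returns the (m, d_out) queries of the coarsened graph and the
    (n, m) aggregated soft assignment matrix.
    """
    heads = []
    for K in keys:                                   # per-head t-kernel scores
        d2 = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1)
        sim = (1.0 + d2 / tau) ** (-(tau + 1.0) / 2.0)
        heads.append(sim / sim.sum(axis=1, keepdims=True))
    stacked = np.stack(heads)                        # (h, n, m)
    # 1x1 conv across heads = weighted sum, then row-wise softmax over keys
    C = softmax(np.tensordot(W_conv, stacked, axes=1), axis=1)   # (n, m)
    V = C.T @ Q                                      # soft read: (m, d_in)
    leaky = lambda x: np.where(x > 0, x, 0.01 * x)   # LeakyReLU
    return leaky(V @ W_out), C

rng = np.random.RandomState(0)
Q = rng.randn(6, 4)                                  # 6 nodes, 4-dim queries
keys = rng.randn(2, 3, 4)                            # 2 heads, 3 keys each
Qn, C = memory_layer(Q, keys, np.array([0.5, 0.5]), rng.randn(4, 2))
print(Qn.shape, C.shape)                             # (3, 2) (6, 3)
```

Stacking such layers shrinks the graph from 6 nodes to 3, and a further layer with a single key would yield the global graph representation.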

Thanks to these parametrized transformations, a memory layer can jointly learn the node representations and coarsen the graph end-to-end. The computed queries Q^(l+1) are the input queries to the subsequent memory layer M^(l+1). For graph classification tasks, one can simply stack memory layers up to the level where the input graph is coarsened into a single node representing the global graph representation, and then feed it to a fully-connected layer to predict the graph class:

    ŷ = softmax(FC(Q^(L)))

where Q^(0) is the initial query representation (we use initial node representation and initial query representation interchangeably throughout the paper) generated by the query network over graph G. We introduce two architectures based on the memory layer, GMN and MemGNN, which differ in how the query network is implemented. More specifically, GMN uses a feed-forward network to initialize the query, Q^(0) = FF(G), whereas MemGNN implements the query network as a message passing GNN, Q^(0) = GNN(G).

3.2 GMN Architecture

A GMN is a stack of memory layers on top of a query network that generates the initial query representations without any message passing. Similar to set neural networks (Vinyals et al., 2015) and transformers (Vaswani et al., 2017), nodes in GMN are treated as a permutation-invariant set of representations. The query network projects the initial node features into a latent space that represents the initial query space.

Assume a training set of graphs where each graph is represented as G = (A, X); A denotes the adjacency matrix, X the initial node features, and y the graph label. Considering that GMN treats a graph as a set of order-invariant nodes and does not use message passing, and also considering that the memory layers do not rely on connectivity information, the topological information of each node must somehow be encoded into its initial representation. To define the topological embedding, we use an instantiation S of a general graph diffusion matrix. More specifically, we use random walk with restart (RWR) (Pan et al., 2004) to compute the topological embeddings and then sort them row-wise to force the embedding to be order-invariant. For further details please see section A.4. Inspired by transformers (Vaswani et al., 2017), we then fuse the topological embeddings with the initial node features into the initial query representations using a query network implemented as a two-layer feed-forward neural network:

    Q^(0) = σ([X ‖ S] W_1) W_2

where W_1 and W_2 are the parameters of the query network, ‖ is the concatenation operator, and σ is the non-linearity implemented using LeakyReLU.
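A minimal sketch of the RWR topological embedding and its row-wise sort is given below; the restart probability and the exact normalization are assumptions (see Section A.4 for the actual construction). The permutation check at the end illustrates why sorting makes the embedding order-invariant.

```python
import numpy as np

def rwr_embedding(A, c=0.15):
    """Random walk with restart scores S = c * (I - (1-c) * A D^{-1})^{-1};
    row i holds the visiting probabilities of a walker restarting at node i.
    Rows are then sorted so the embedding is invariant to node permutations.
    The restart probability c = 0.15 is an assumed value."""
    n = A.shape[0]
    T = A / np.maximum(A.sum(axis=0, keepdims=True), 1e-12)  # column-stochastic
    S = c * np.linalg.inv(np.eye(n) - (1 - c) * T)
    return np.sort(S, axis=1)                                # row-wise sort

# Permuting the graph permutes the rows but leaves each node's sorted
# embedding unchanged, so corresponding nodes get identical embeddings.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
P = np.eye(4)[[2, 0, 3, 1]]                                  # permutation matrix
S1 = rwr_embedding(A)
S2 = rwr_embedding(P @ A @ P.T)
print(np.allclose(P @ S1, S2))                               # True
```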

3.2.1 Permutation Invariance

Given the inherent permutation invariance of graph-structured data, a model designed for graph classification should enforce the same property: it should produce identical outputs for isomorphic input graphs. We impose this on the GMN architecture by sorting the topological embeddings row-wise as a pre-processing step.

Proposition 1. Given a row-wise sorting function o, the topological embedding o(S) is permutation-invariant.

Proof. Let P be an arbitrary permutation matrix. For each node in graph G with adjacency matrix A, the corresponding node in the graph G' with permuted adjacency matrix PAP^T has a permuted version of that node's topological embedding in G. Sorting the embeddings cancels out the effect of the permutation and makes the corresponding embeddings in G and G' identical.

3.3 MemGNN Architecture

Unlike the GMN architecture, the query network in MemGNN relies on message passing to compute the initial query:

    Q^(0) = GNN(A, X)

where the query network is an arbitrary parameterized message passing GNN (Gilmer et al., 2017; Li et al., 2015; Kipf and Welling, 2016; Veličković et al., 2018).

In our implementation, we use a modified variant of GAT (Veličković et al., 2018). Specifically, we introduce an extension of the original GAT model called edge-based GAT (e-GAT) and use it as the query network. Unlike GAT, e-GAT learns attention weights not only from the neighbor nodes but also from the input edge features. This is especially important for data containing edge information (e.g., the various bonds among atoms represented as edges in molecule datasets). In an e-GAT layer, the attention score between two neighbor nodes is computed as follows:

    α_ij = softmax(σ(W [W_n h_i^(l) ‖ W_n h_j^(l) ‖ W_e e_ij^(l)]))

where h_i^(l) and e_ij^(l) denote the representation of node i and the representation of the edge connecting node i to its one-hop neighbor j in layer l, respectively. W_n and W_e are learnable node and edge weights, W is the parameter of a single-layer feed-forward network that computes the attention score, and σ is the non-linearity implemented using LeakyReLU.
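The e-GAT attention score can be sketched as follows; this is a hypothetical NumPy rendering in which the concatenation layout, weight shapes, and LeakyReLU slope are assumptions rather than the paper's exact configuration.

```python
import numpy as np

def egat_scores(h, e, nbrs, Wn, We, w):
    """Attention of each node over its neighbors using node AND edge features.

    h: (n, d) node representations; e: dict (i, j) -> edge feature vector;
    nbrs: dict node -> neighbor list; Wn, We: node/edge projection matrices;
    w: scoring vector of the single-layer feed-forward scorer.
    """
    def leaky(x):
        return np.where(x > 0, x, 0.2 * x)

    alpha = {}
    for i, js in nbrs.items():
        logits = np.array([
            leaky(w @ np.concatenate([Wn @ h[i], Wn @ h[j], We @ e[(i, j)]]))
            for j in js
        ])
        ex = np.exp(logits - logits.max())
        alpha[i] = ex / ex.sum()       # softmax over node i's neighborhood
    return alpha

rng = np.random.RandomState(0)
h = rng.randn(3, 4)                                  # 3 nodes, 4-dim features
e = {(0, 1): rng.randn(2), (0, 2): rng.randn(2)}     # 2-dim edge features
Wn, We, w = rng.randn(4, 4), rng.randn(2, 2), rng.randn(10)
alpha = egat_scores(h, e, {0: [1, 2]}, Wn, We, w)
print(alpha[0])                                      # two weights summing to 1
```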

3.4 Training

We jointly train the model using two loss functions: a supervised loss and an unsupervised clustering loss. The supervised loss, denoted L_sup, is the cross-entropy loss for graph classification tasks and the root mean square error (RMSE) for regression tasks. The unsupervised clustering loss is inspired by deep clustering methods (Razavi et al., 2019; Xie et al., 2016; Aljalbout et al., 2018): it encourages the model to learn clustering-friendly embeddings in the latent space by learning from high-confidence assignments with the help of an auxiliary target distribution. The unsupervised loss is defined as the Kullback-Leibler (KL) divergence between the soft assignments C and the auxiliary distribution P:

    L_KL = KL(P ‖ C) = Σ_i Σ_j P_ij log(P_ij / C_ij)
For the target distribution P, we use the distribution proposed in (Xie et al., 2016), which normalizes the loss contributions and improves cluster purity while emphasizing samples with higher confidence:

    P_ij = (C_ij^2 / Σ_i C_ij) / Σ_j' (C_ij'^2 / Σ_i C_ij')
We define the total loss as follows, where L is the number of memory layers and λ is a scalar weight:

    L_total = L_sup + λ Σ_{l=1}^{L} L_KL^(l)
We initialize the model parameters, the keys, and the queries randomly and optimize them jointly with respect to L_total using mini-batch stochastic gradient descent. To stabilize training, the gradients of L_sup are back-propagated batch-wise while the gradients of L_KL are applied epoch-wise, by periodically switching between L_sup and L_KL. Updating the centroids, i.e., the memory keys, with the same frequency as the network parameters can destabilize training. To address this, we optimize all model parameters and the queries in each batch with respect to L_sup and in each epoch with respect to L_KL; the memory keys, on the other hand, are only updated at the end of each epoch by the gradients of L_KL. This technique has also been applied in (Hassani and Haley, 2019; Caron et al., 2018) to avoid trivial solutions.

4 Experiments

4.1 Datasets

We use nine benchmarks including seven graph classification and two graph regression datasets to evaluate the proposed method. These datasets are commonly used in both graph kernel (Borgwardt and Kriegel, 2005; Yanardag and Vishwanathan, 2015; Shervashidze et al., 2009; Ying et al., 2018; Shervashidze et al., 2011; Kriege et al., 2016) and GNN (Cangea et al., 2018; Ying et al., 2018; Lee et al., 2019; Gao and Ji, 2019) literature. The summary of the datasets is as follows where the first two benchmarks are regression tasks and the rest are classification tasks.

ESOL (Delaney, 2004) contains water solubility data for compounds.
Lipophilicity (Gaulton et al., 2016) contains experimental results of octanol/water distribution of compounds.
Bace (Subramanian et al., 2016) provides quantitative binding results for a set of inhibitors of human β-secretase 1 (BACE-1).
DD (Dobson and Doig, 2003) is used to distinguish enzyme structures from non-enzymes.
Enzymes (Schomburg et al., 2004) is for predicting functional classes of enzymes.
Proteins (Dobson and Doig, 2003) is used to predict the protein function from structure.
Collab (Yanardag and Vishwanathan, 2015) is for predicting the field of a researcher given one’s ego-collaboration graph.
Reddit-Binary (Yanardag and Vishwanathan, 2015) is for predicting the type of community given a graph of online discussion threads.
Tox21 (Challenge, 2014) is for predicting toxicity on 12 different targets.

For more information about the detailed statistics of the datasets refer to Appendix A.2.

4.2 Graph Classification Results

To evaluate the performance of our models on DD, Enzymes, Proteins, Collab, and Reddit-Binary datasets, we follow the experimental protocol in (Ying et al., 2018) and perform 10-fold cross-validation and report the mean accuracy over all folds. We also report the performance of four graph kernel methods including Graphlet (Shervashidze et al., 2009), shortest path (Borgwardt and Kriegel, 2005), Weisfeiler-Lehman (WL) (Shervashidze et al., 2011), and WL Optimal Assignment (Kriege et al., 2016), and ten GNN models.

The results shown in Table 1 suggest that: (i) our models achieve state-of-the-art results w.r.t. GNN models and significantly improve the performance on the Enzymes, Proteins, DD, Collab, and Reddit-Binary datasets by absolute margins of 14.49%, 6.0%, 3.76%, 2.62%, and 8.98% accuracy, respectively; (ii) our models outperform graph kernels on all datasets except Collab, where our models are competitive with the best kernel, i.e., an absolute margin of 0.56%; (iii) both proposed models achieve better or competitive performance compared to the baseline GNNs; (iv) GMN achieves better results than MemGNN, which suggests that replacing local adjacency information with global topological embeddings provides the model with more useful information; and (v) on Collab, our models are outperformed by a variant of DiffPool (diffpool-det) (Ying et al., 2018), in which assignment matrices are generated using a deterministic graph clustering algorithm, and by WL Optimal Assignment (Kriege et al., 2016). The former is a GNN augmented with a deterministic clustering algorithm, whereas the latter is a graph kernel method. We speculate this is because of the high edge-to-node ratio in this dataset, and the augmentations used in these two methods help them extract near-optimal cliques.

Method | Enzymes | Proteins | DD | Collab | Reddit-B
--- | --- | --- | --- | --- | ---
Graphlet (Shervashidze et al., 2009) | 41.03 | 72.91 | 64.66 | 64.66 | 78.04
ShortestPath (Borgwardt and Kriegel, 2005) | 42.32 | 76.43 | 78.86 | 59.10 | 64.11
WL (Shervashidze et al., 2011) | 53.43 | 73.76 | 74.02 | 78.61 | 68.20
WL Optimal (Kriege et al., 2016) | 60.13 | 75.26 | 79.04 | 80.74 | 89.30
PatchySan (Niepert et al., 2016) | – | 75.00 | 76.27 | 72.60 | 86.30
GraphSage (Hamilton et al., 2017) | 54.25 | 70.48 | 75.42 | 68.25 | –
ECC (Simonovsky and Komodakis, 2017) | 53.50 | 72.65 | 74.10 | 67.79 | –
Set2Set (Vinyals et al., 2015) | 60.15 | 74.29 | 78.12 | 71.75 | –
SortPool (Morris et al., 2019) | 57.12 | 75.54 | 79.37 | 73.76 | –
DiffPool (Ying et al., 2018) | 60.53 | 76.25 | 80.64 | 75.48 | 85.95
CliquePool (Luzhnica et al., 2019) | 60.71 | 72.59 | 77.33 | 74.50 | –
Sparse HGC (Cangea et al., 2018) | 64.17 | 75.46 | 78.59 | 75.46 | 79.20
TopKPool (Gao and Ji, 2019) | – | 77.68 | 82.43 | 77.56 | 74.70
SAGPool (Lee et al., 2019) | – | 71.86 | 76.45 | – | 73.90
GMN (ours) | 78.66 | 82.25 | 84.40 | 80.18 | 95.28
MemGNN (ours) | 75.50 | 81.35 | 82.92 | 77.00 | 85.55

Table 1: Mean validation accuracy over 10 folds ("–": not reported).
Method | BACE (validation) | BACE (test) | Tox21 (validation) | Tox21 (test)
--- | --- | --- | --- | ---
Logistic Regression | 0.719 ± 0.003 | 0.781 ± 0.010 | 0.772 ± 0.011 | 0.794 ± 0.015
KernelSVM | 0.739 ± 0.000 | 0.862 ± 0.000 | 0.818 ± 0.010 | 0.822 ± 0.006
XGBoost | 0.756 ± 0.000 | 0.850 ± 0.008 | 0.775 ± 0.018 | 0.794 ± 0.014
Random Forest | 0.728 ± 0.004 | 0.867 ± 0.008 | 0.763 ± 0.002 | 0.769 ± 0.015
IRV | 0.715 ± 0.001 | 0.838 ± 0.000 | 0.807 ± 0.006 | 0.799 ± 0.006
Multitask | 0.696 ± 0.037 | 0.824 ± 0.0006 | 0.795 ± 0.017 | 0.803 ± 0.012
Bypass | 0.745 ± 0.017 | 0.829 ± 0.006 | 0.800 ± 0.008 | 0.810 ± 0.013
GCN | 0.627 ± 0.015 | 0.783 ± 0.014 | 0.825 ± 0.013 | 0.829 ± 0.006
Weave | 0.638 ± 0.014 | 0.806 ± 0.002 | 0.828 ± 0.008 | 0.820 ± 0.010
MemGNN (ours) | 0.859 ± 0.000 | 0.907 ± 0.000 | 0.862 ± 0.009 | 0.828 ± 0.004

Table 2: AUC-ROC on BACE and Tox21 datasets.

To evaluate the performance on the BACE and Tox21 datasets, we follow the evaluation protocol in (Wu et al., 2018) and report the area under the receiver operating characteristic curve (AUC-ROC). Considering that the BACE and Tox21 datasets contain initial edge features, we train the MemGNN model and compare its performance to the baselines reported in (Wu et al., 2018). The results shown in Table 2 suggest that our model achieves state-of-the-art results by an absolute margin of 0.040 AUC-ROC on the BACE benchmark and is competitive with the state-of-the-art GCN model on the Tox21 dataset, i.e., an absolute margin of 0.001.

4.3 Graph Regression Results

For the ESOL and Lipophilicity datasets, we follow the evaluation protocol in (Wu et al., 2018) and report their RMSEs. Considering that these datasets contain initial edge features (refer to Appendix A.2 for further details), we train the MemGNN model and compare the results to the baseline models reported in (Wu et al., 2018)

including graph based methods such as GCN, MPNN, directed acyclic graph (DAG) model, and Weave as well as other conventional methods such as kernel ridge regression (KRR) and influence relevance voting (IRV). Results shown in Table

3 suggest that our MemGNN model achieves state-of-the-art results by absolute margin of 0.07 and 0.1 RMSE on ESOL and Lipophilicity benchmarks, respectively. For further details on regression datasets and baselines please refer to (Wu et al., 2018).

Method | ESOL (validation) | ESOL (test) | Lipophilicity (validation) | Lipophilicity (test)
--- | --- | --- | --- | ---
Multitask | 1.17 ± 0.13 | 1.12 ± 0.19 | 0.852 ± 0.048 | 0.859 ± 0.013
Random Forest | 1.16 ± 0.15 | 1.07 ± 0.19 | 0.835 ± 0.036 | 0.876 ± 0.040
XGBoost | 1.05 ± 0.10 | 0.99 ± 0.14 | 0.783 ± 0.021 | 0.799 ± 0.054
GCN | 1.05 ± 0.15 | 0.97 ± 0.01 | 0.678 ± 0.040 | 0.655 ± 0.036
MPNN | 0.55 ± 0.02 | 0.58 ± 0.03 | 0.757 ± 0.030 | 0.715 ± 0.035
KRR | 1.65 ± 0.19 | 1.53 ± 0.06 | 0.889 ± 0.009 | 0.899 ± 0.043
DAG | 0.74 ± 0.04 | 0.82 ± 0.08 | 0.857 ± 0.050 | 0.835 ± 0.039
Weave | 0.57 ± 0.04 | 0.61 ± 0.07 | 0.734 ± 0.011 | 0.715 ± 0.035
MemGNN (ours) | 0.53 ± 0.03 | 0.54 ± 0.01 | 0.555 ± 0.039 | 0.556 ± 0.023

Table 3: RMSE on ESOL and Lipophilicity datasets.

4.4 Ablation Study

4.4.1 Effect of edge features

To investigate the effect of the proposed e-GAT model, we train the MemGNN model using both GAT and e-GAT layers as the query network. Considering that the ESOL, Lipophilicity, and BACE datasets contain edge features, we use them as the benchmarks. Since nodes have richer features than edges, we set the node and edge feature dimensions to 16 and 4, respectively. The performance of the two layers on the ESOL dataset, shown in Appendix A.3, suggests that e-GAT achieves better validation results in each epoch than the standard GAT model. We observed the same effect on the Lipophilicity and BACE datasets.

4.4.2 Effect of Topological Embedding

To investigate the effect of topological embeddings on the GMN model, we evaluated three initial topological features: the adjacency matrix, the normalized adjacency matrix, and RWR (for further details on RWR, see section A.4). The results suggest that using RWR as the initial positional embedding achieves the best performance. For instance, the 10-fold cross-validation accuracies of a GMN model trained on the Enzymes dataset with RWR, the adjacency matrix, and the normalized adjacency matrix are 78.66%, 77.16%, and 77.33%, respectively. Furthermore, sorting the topological embeddings to guarantee invariance to permutations improves the performance; for example, it increases the accuracy on the DD dataset from 82.24% to 84.40%.

4.4.3 Down-sampling Neighbors with Random Walks

We investigate two methods for down-sampling neighbors in dense datasets such as Collab (i.e., an average of 66 neighbors per node) to reduce memory and computation. The first method randomly selects 10% of the edges, whereas the second ranks the neighbors by their RWR scores with respect to the center node and keeps the top 10% of the edges. We trained the MemGNN model on Collab using both sampling methods, which resulted in 73.9% and 73.1% 10-fold cross-validation accuracy for the random and RWR-based sampling methods, respectively, suggesting that random sampling performs slightly better than RWR-based sampling.

4.4.4 Effect of Number of Keys and Heads

We speculate that although the keys represent the clusters, the number of keys is not necessarily proportional to the number of nodes in the input graphs. In fact, datasets with smaller graphs might have more meaningful clusters to capture. For example, molecules are comprised of many functional groups, and yet the average number of nodes in the ESOL dataset is 13.3. Moreover, our experiments show that for Enzymes, with an average of 32.69 nodes, the best performance is achieved with 10 keys, whereas for the ESOL dataset 64 keys yields the best performance: in ESOL, 8, 64, and 160 keys result in RMSEs of 0.56, 0.52, and 0.54, respectively. We also observed that, with a fixed number of parameters, increasing the number of memory heads improves performance. For instance, trained on ESOL with 160 keys and 1 head, the model achieves an RMSE of 0.54, whereas with 32 keys and 5 heads the same model achieves an RMSE of 0.53.

Figure 2: Visualization of the learned clusters of two molecule instances from the (a) ESOL and (b) Lipophilicity datasets. The visualizations show that the learned clusters correspond to known chemical groups. Note that a node without a label represents a carbon atom. For more visualizations and discussion, see section A.5.

4.4.5 What Do the Keys Represent?

Intuitively, the memory keys represent the cluster centroids and enhance model performance by capturing meaningful structures. To investigate this, we used the learned keys to interpret the knowledge learned by the models through visualizations. Figure 2 visualizes the learned clusters over atoms (i.e., atoms with the same color are within the same cluster), indicating that the clusters mainly consist of meaningful chemical substructures such as a carbon chain and a hydroxyl group (OH) in Figure 2(a), and a carboxyl group (COOH) and a benzene ring in Figure 2(b). From a chemical perspective, hydroxyl and carboxyl groups and carbon chains have a significant impact on the solubility of a molecule in water or lipids. This confirms that the network has learned chemical features that are essential for determining molecule solubility. It is noteworthy that we tried warm-starting the memory keys by initializing them with the K-Means algorithm over the initial node representations, but did not observe any significant improvement over randomly initialized keys.

5 Conclusion

We proposed an efficient memory layer and two models for hierarchical graph representation learning. We evaluated the proposed models on nine graph classification and regression tasks, achieving state-of-the-art results on eight of them. We also experimentally showed that the learned representations can capture well-known chemical features of molecules. Furthermore, we showed that concatenating node features with topological embeddings and passing them through a few memory layers achieves notable results without message passing, and that defining the topological embeddings using graph diffusion achieves the best performance. Finally, we showed that although connectivity information is not explicitly imposed on the model, the memory layer can process node representations and properly cluster and aggregate the learned representations.

Limitations: In section 4.2, we discussed that a graph kernel and a GNN augmented with deterministic clustering achieve better performance than our models on the Collab dataset. Analyzing samples from this dataset suggests that in graphs with dense communities, such as cliques, our model struggles to properly detect the dense sub-graphs.

Future Directions: We plan to extend our models to also perform node classification by attending to the node representations and the centroids of the clusters, at different layers of the hierarchy, that each node belongs to. Moreover, we plan to evaluate other graph diffusion models, e.g., personalized PageRank and the heat kernel, for initializing the topological embeddings. We are also planning to investigate the representation learning capabilities of the proposed models in a self-supervised setting.


  • E. Aljalbout, V. Golkov, Y. Siddiqui, M. Strobel, and D. Cremers (2018) Clustering with deep learning: taxonomy and new methods. arXiv preprint arXiv:1801.07648. Cited by: §3.4.
  • M. Belkin and P. Niyogi (2002) Laplacian eigenmaps and spectral techniques for embedding and clustering. In Advances in Neural Information Processing Systems, pp. 585–591. Cited by: §1.
  • F. M. Bianchi, D. Grattarola, and C. Alippi (2019) Mincut pooling in graph neural networks. arXiv preprint arXiv:1907.00481. Cited by: §2.
  • K. M. Borgwardt and H. Kriegel (2005) Shortest-path kernels on graphs. In International Conference on Data Mining, pp. 74–81. Cited by: §4.1, §4.2, Table 1.
  • J. Bruna, W. Zaremba, A. Szlam, and Y. LeCun (2014) Spectral networks and locally connected networks on graphs. In International Conference on Learning Representation, Cited by: §2.
  • C. Cangea, P. Veličković, N. Jovanović, T. Kipf, and P. Liò (2018) Towards sparse hierarchical graph classifiers. In Advances in Neural Information Processing Systems, Workshop on Relational Representation Learning, Cited by: §2, §4.1, Table 1.
  • M. Caron, P. Bojanowski, A. Joulin, and M. Douze (2018) Deep clustering for unsupervised learning of visual features. In European Conference on Computer Vision, pp. 132–149. Cited by: §3.4.
  • T. D. Challenge (2014) Tox21 data challenge 2014. Cited by: §4.1.
  • H. Dai, B. Dai, and L. Song (2016) Discriminative embeddings of latent variable models for structured data. In International Conference on Machine Learning, pp. 2702–2711. Cited by: §1.
  • M. Defferrard, X. Bresson, and P. Vandergheynst (2016) Convolutional neural networks on graphs with fast localized spectral filtering. In Advances in Neural Information Processing Systems, pp. 3844–3852. Cited by: §2.
  • J. S. Delaney (2004) ESOL: estimating aqueous solubility directly from molecular structure. Journal of Chemical Information and Computer Sciences 44 (3), pp. 1000–1005. Cited by: §4.1.
  • I. S. Dhillon, Y. Guan, and B. Kulis (2007) Weighted graph cuts without eigenvectors: a multilevel approach. IEEE Transactions on Pattern Analysis and Machine Intelligence 29 (11), pp. 1944–1957. Cited by: §2.
  • F. Diehl (2019) Edge contraction pooling for graph neural networks. arXiv preprint arXiv:1905.10990. Cited by: §2.
  • P. D. Dobson and A. J. Doig (2003) Distinguishing enzyme structures from non-enzymes without alignments. Journal of Molecular Biology 330 (4), pp. 771–783. Cited by: §4.1.
  • D. K. Duvenaud, D. Maclaurin, J. Iparraguirre, R. Bombarell, T. Hirzel, A. Aspuru-Guzik, and R. P. Adams (2015) Convolutional networks on graphs for learning molecular fingerprints. In Advances in Neural Information Processing Systems, pp. 2224–2232. Cited by: §1.
  • H. Gao and S. Ji (2019) Graph u-nets. In International Conference on Machine Learning, pp. 2083–2092. Cited by: §1, §2, §4.1, Table 1.
  • A. Gaulton, A. Hersey, M. Nowotka, A. P. Bento, J. Chambers, D. Mendez, P. Mutowo, F. Atkinson, L. J. Bellis, E. Cibrián-Uhalte, et al. (2016) The chembl database in 2017. Nucleic Acids Research 45, pp. D945–D954. Cited by: §4.1.
  • J. Gilmer, S. S. Schoenholz, P. F. Riley, O. Vinyals, and G. E. Dahl (2017) Neural message passing for quantum chemistry. In International Conference on Machine Learning, pp. 1263–1272. Cited by: §1, §3.3.
  • A. Graves, G. Wayne, and I. Danihelka (2014) Neural turing machines. arXiv preprint arXiv:1410.5401. Cited by: §2, §3.1.
  • A. Grover and J. Leskovec (2016) Node2Vec: scalable feature learning for networks. In International Conference on Knowledge Discovery and Data Mining, pp. 855–864. Cited by: §1.
  • W. Hamilton, Z. Ying, and J. Leskovec (2017) Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pp. 1024–1034. Cited by: §2, §2, Table 1.
  • K. Hassani and M. Haley (2019) Unsupervised multi-task feature learning on point clouds. In International Conference on Computer Vision, pp. 8160–8171. Cited by: §1, §3.4.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pp. 448–456. Cited by: §A.1.
  • D. P. Kingma and J. L. Ba (2014) Adam: a method for stochastic optimization. In International Conference on Learning Representation, Cited by: §A.1.
  • T. N. Kipf and M. Welling (2016) Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, Cited by: §1, §2, §2, §3.3.
  • N. M. Kriege, P. Giscard, and R. Wilson (2016) On valid optimal assignment kernels and applications to graph classification. In Advances in Neural Information Processing Systems, pp. 1623–1631. Cited by: §1, §4.1, §4.2, §4.2, Table 1.
  • G. Lample, A. Sablayrolles, M. Ranzato, L. Denoyer, and H. Jégou (2019) Large memory layers with product keys. arXiv preprint arXiv:1907.05242. Cited by: §2.
  • J. Lee, I. Lee, and J. Kang (2019) Self-attention graph pooling. In International Conference on Machine Learning, pp. 3734–3743. Cited by: §1, §2, §4.1, Table 1.
  • Q. Li, Z. Han, and X. Wu (2018) Deeper insights into graph convolutional networks for semi-supervised learning. In AAAI Conference on Artificial Intelligence, pp. 3538–3545. Cited by: §1.
  • Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel (2015) Gated graph sequence neural networks. In International Conference on Learning Representations, Cited by: §1, §3.3.
  • E. Luzhnica, B. Day, and P. Lio (2019) Clique pooling for graph classification. In International Conference on Learning Representations, Workshop on Representation Learning on Graphs and Manifolds, Cited by: §2, Table 1.
  • J. Ma, P. Cui, K. Kuang, X. Wang, and W. Zhu (2019) Disentangled graph convolutional networks. In International Conference on Machine Learning, pp. 4212–4221. Cited by: §2.
  • L. v. d. Maaten and G. Hinton (2008) Visualizing data using t-sne. Journal of Machine Learning Research 9, pp. 2579–2605. Cited by: §3.1.
  • C. Morris, M. Ritzert, M. Fey, W. L. Hamilton, J. E. Lenssen, G. Rattan, and M. Grohe (2019) Weisfeiler and leman go neural: higher-order graph neural networks. In AAAI Conference on Artificial Intelligence, Vol. 33, pp. 4602–4609. Cited by: §2, Table 1.
  • M. Niepert, M. Ahmed, and K. Kutzkov (2016) Learning convolutional neural networks for graphs. In International Conference on Machine Learning, pp. 2014–2023. Cited by: Table 1.
  • M. Ou, P. Cui, J. Pei, Z. Zhang, and W. Zhu (2016) Asymmetric transitivity preserving graph embedding. In International Conference on Knowledge Discovery and Data Mining, pp. 1105–1114. Cited by: §1.
  • J. Pan, H. Yang, C. Faloutsos, and P. Duygulu (2004) Automatic multimedia cross-modal correlation discovery. In International Conference on Knowledge Discovery and Data Mining, pp. 653–658. Cited by: §3.2.
  • A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer (2017) Automatic differentiation in PyTorch. Cited by: §A.1.
  • B. Perozzi, R. Al-Rfou, and S. Skiena (2014) DeepWalk: online learning of social representations. In International Conference on Knowledge Discovery and Data Mining, pp. 701–710. Cited by: §1.
  • A. Pritzel, B. Uria, S. Srinivasan, A. P. Badia, O. Vinyals, D. Hassabis, D. Wierstra, and C. Blundell (2017) Neural episodic control. In International Conference on Machine Learning, pp. 2827–2836. Cited by: §2.
  • J. Rae, J. J. Hunt, I. Danihelka, T. Harley, A. W. Senior, G. Wayne, A. Graves, and T. Lillicrap (2016) Scaling memory-augmented neural networks with sparse reads and writes. In Advances in Neural Information Processing Systems, pp. 3621–3629. Cited by: §2, §3.1.
  • A. Razavi, A. v. d. Oord, and O. Vinyals (2019) Generating diverse high-fidelity images with vq-vae-2. arXiv preprint arXiv:1906.00446. Cited by: §3.4.
  • A. Santoro, S. Bartunov, M. Botvinick, D. Wierstra, and T. Lillicrap (2016) Meta-learning with memory-augmented neural networks. In International Conference on Machine Learning, pp. 1842–1850. Cited by: §2.
  • M. Schlichtkrull, T. N. Kipf, P. Bloem, R. Van Den Berg, I. Titov, and M. Welling (2018) Modeling relational data with graph convolutional networks. In European Semantic Web Conference, pp. 593–607. Cited by: §2.
  • I. Schomburg, A. Chang, C. Ebeling, M. Gremse, C. Heldt, G. Huhn, and D. Schomburg (2004) BRENDA, the enzyme database: updates and major new developments. Nucleic Acids Research 32, pp. D431–D433. Cited by: §4.1.
  • N. Shervashidze, P. Schweitzer, E. J. v. Leeuwen, K. Mehlhorn, and K. M. Borgwardt (2011) Weisfeiler-lehman graph kernels. Journal of Machine Learning Research 12 (Sep), pp. 2539–2561. Cited by: §1, §4.1, §4.2, Table 1.
  • N. Shervashidze, S. Vishwanathan, T. Petri, K. Mehlhorn, and K. Borgwardt (2009) Efficient graphlet kernels for large graph comparison. In Artificial Intelligence and Statistics, pp. 488–495. Cited by: §4.1, §4.2, Table 1.
  • M. Simonovsky and N. Komodakis (2017) Dynamic edge-conditioned filters in convolutional neural networks on graphs. In International Conference on Computer Vision and Pattern Recognition, pp. 3693–3702. Cited by: Table 1.
  • J. Snell, K. Swersky, and R. Zemel (2017) Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087. Cited by: §2.
  • N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine Learning Research 15, pp. 1929–1958. Cited by: §A.1.
  • G. Subramanian, B. Ramsundar, V. Pande, and R. A. Denny (2016) Computational modeling of β-secretase 1 (BACE-1) inhibitors using ligand based approaches. Journal of Chemical Information and Modeling 56 (10), pp. 1936–1949. Cited by: §4.1.
  • S. Sukhbaatar, J. Weston, R. Fergus, et al. (2015) End-to-end memory networks. In Advances in Neural Information Processing Systems, pp. 2440–2448. Cited by: §3.1.
  • H. Tong, C. Faloutsos, and J. Pan (2006) Fast random walk with restart and its applications. In International Conference on Data Mining, pp. 613–622. Cited by: §A.4.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008. Cited by: §3.2, §3.2.
  • P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio (2018) Graph attention networks. In International Conference on Learning Representations, Cited by: §1, §2, §3.3, §3.3.
  • O. Vinyals, S. Bengio, and M. Kudlur (2015) Order matters: sequence to sequence for sets. In International Conference on Learning Representations, Cited by: §2, §3.2, Table 1.
  • O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. (2016) Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pp. 3630–3638. Cited by: §2.
  • S. Vivona and K. Hassani (2019) Relational graph representation learning for open-domain question answering. arXiv preprint arXiv:1910.08249. Cited by: §1.
  • C. Wang, B. Samari, and K. Siddiqi (2018) Local spectral graph convolution for point set feature learning. In European Conference on Computer Vision, pp. 52–66. Cited by: §2.
  • J. Weston, S. Chopra, and A. Bordes (2015) Memory networks. In International Conference on Learning Representation, Cited by: §2, §3.1.
  • Z. Wu, B. Ramsundar, E. N. Feinberg, J. Gomes, C. Geniesse, A. S. Pappu, K. Leswing, and V. Pande (2018) MoleculeNet: a benchmark for molecular machine learning. Chemical Science 9 (2), pp. 513–530. Cited by: §4.2, §4.3.
  • Z. Wu, S. Pan, F. Chen, G. Long, C. Zhang, and P. S. Yu (2019) A comprehensive survey on graph neural networks. arXiv preprint arXiv:1901.00596. Cited by: §1.
  • J. Xie, R. Girshick, and A. Farhadi (2016) Unsupervised deep embedding for clustering analysis. In International Conference on Machine Learning, pp. 478–487. Cited by: §3.1, §3.4, §3.4.
  • K. Xu, C. Li, Y. Tian, T. Sonobe, K. Kawarabayashi, and S. Jegelka (2018) Representation learning on graphs with jumping knowledge networks. In International Conference on Machine Learning, pp. 5453–5462. Cited by: §1.
  • P. Yanardag and S.V.N. Vishwanathan (2015) Deep graph kernels. In International Conference on Knowledge Discovery and Data Mining, pp. 1365–1374. Cited by: §4.1, §4.1.
  • Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec (2018) Hierarchical graph representation learning with differentiable pooling. In Advances in Neural Information Processing Systems, pp. 4800–4810. Cited by: §1, §2, §4.1, §4.2, §4.2, Table 1.
  • Z. Zhang, P. Cui, and W. Zhu (2018) Deep learning on graphs: a survey. arXiv preprint arXiv:1812.04202. Cited by: §1.
  • J. Zhou, G. Cui, Z. Zhang, C. Yang, Z. Liu, and M. Sun (2018) Graph neural networks: a review of methods and applications. arXiv preprint arXiv:1812.08434. Cited by: §1.

Appendix A Appendix

A.1 Implementation Details

We implemented the model in PyTorch (Paszke et al., 2017) and optimized it using the Adam optimizer (Kingma and Ba, 2014). We trained the model for a maximum of 2000 epochs and decayed the learning rate by 0.5 every 500 epochs. The model uses batch normalization (Ioffe and Szegedy, 2015), skip-connections, LeakyReLU activation functions, and dropout (Srivastava et al., 2014) for regularization. We set the temperature in the Student's t-distribution to 1.0 and the restart probability in RWR to 0.1. We chose the hidden dimension and the number of model parameters using a random hyper-parameter search strategy. The best-performing hyper-parameters for each dataset are shown in Table 4.
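For concreteness, the temperature enters the soft assignment of queries to memory keys through a Student's t-distribution kernel (following Xie et al., 2016). The sketch below is a minimal NumPy illustration under that assumption; the helper name `soft_assignment` is hypothetical:

```python
import numpy as np

def soft_assignment(queries: np.ndarray, keys: np.ndarray, tau: float = 1.0) -> np.ndarray:
    """Soft assignment of queries to memory keys via a Student's t kernel.

    C[i, j] is the probability of assigning query i to key j; each row sums to 1.
    tau is the temperature (set to 1.0 in our experiments).
    """
    # Squared Euclidean distances between every query and every key.
    d2 = ((queries[:, None, :] - keys[None, :, :]) ** 2).sum(-1)
    # Student's t kernel with tau degrees of freedom (as in Xie et al., 2016).
    sim = (1.0 + d2 / tau) ** (-(tau + 1.0) / 2.0)
    # Normalize over keys to obtain a row-stochastic assignment matrix.
    return sim / sim.sum(1, keepdims=True)
```

With tau = 1.0 the kernel reduces to an inverse-quadratic similarity, giving heavier tails than a Gaussian and softer cluster boundaries.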

A.2 Dataset Statistics

Table 5 summarizes the statistics of the datasets used for graph classification and regression tasks.

A.3 Effect of e-GAT

In Section 4.4.1, we introduced e-GAT. Figures 2(a) and 2(b) illustrate the RMSE and R² score on the validation set of the ESOL dataset achieved by a MemGNN model using GAT and e-GAT as the query network, respectively. As shown, e-GAT outperforms GAT on both metrics.

A.4 Random Walk with Restart

Suppose an agent randomly traverses a graph starting from node v_i and iteratively walks towards one of its neighbors with a probability proportional to the weight of the edge connecting them. The agent can also restart the traversal, jumping back to v_i, with probability c. Eventually, the agent stops at node v_j with a probability called the relevance score of v_j with respect to v_i (Tong et al., 2006). The relevance scores of node v_i with respect to every other node of the graph are defined as follows:

r_i = (1 − c) Ã r_i + c e_i

where r_i is the vector of RWR scores corresponding to node v_i, c is the restart probability, Ã is the normalized adjacency matrix, and e_i is the one-hot vector representation of node v_i.

Note that the restart probability c determines how far the agent can walk from the source node and therefore controls the trade-off between local and global information around node v_i.
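The RWR recursion can be solved by simple power iteration. The sketch below is a minimal NumPy illustration (the helper name `rwr_scores` is hypothetical; it column-normalizes the adjacency matrix internally):

```python
import numpy as np

def rwr_scores(A: np.ndarray, source: int, c: float = 0.1,
               n_iters: int = 100) -> np.ndarray:
    """Relevance scores of every node w.r.t. `source` via power iteration.

    Iterates r <- (1 - c) * A_norm @ r + c * e until (approximate) convergence,
    where c is the restart probability (0.1 in our experiments).
    """
    # Column-normalize the adjacency matrix so each column sums to 1.
    col_sums = A.sum(0, keepdims=True)
    A_norm = A / np.where(col_sums == 0, 1, col_sums)
    # One-hot restart vector for the source node.
    e = np.zeros(len(A))
    e[source] = 1.0
    r = e.copy()
    for _ in range(n_iters):
        r = (1.0 - c) * A_norm @ r + c * e
    return r
```

Because c > 0 keeps probability mass anchored at the source, the iteration is a contraction and converges to a probability distribution over the nodes.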

A.5 Learned Clusters

Figure 4 shows how the unsupervised loss helps the model push the nodes into distinct clusters. Figures 3(a) and 3(c) illustrate the clusters learned with the unsupervised loss, while Figures 3(b) and 3(d) show the clusters computed without it. The visualizations suggest that the unsupervised loss helps the model avoid trivial solutions in which the latent node representations collapse into meaningless clusters. In addition, Figure 5 shows meaningful chemical groups extracted by MemGNN: Figures 4(b) and 4(d) are from the LIPO dataset, and Figures 4(a) and 4(c) are from the ESOL dataset.

Dataset #Keys #Heads #Layers Hidden Dimension Batch Size
Enzymes [10,1] 5 2 100 20
Proteins [10,1] 5 2 80 20
DD [16, 8, 1] 5 3 120 64
Collab [32, 8, 1] 5 3 100 64
Reddit-B [32,1] 1 2 16 32
ESOL [64,1] 5 2 16 32
Lipophilicity [32,1] 5 2 16 32
BACE [32,1] 5 2 8 32
Table 4: Hyper-parameters selected for the models.
Name Task Graphs Classes Avg. Nodes Avg. Edges Node Attr. Edge Attr.
Enzymes classification 600 6 32.63 62.14 18 0
Proteins classification 1113 2 39.06 72.82 29 0
DD classification 1178 2 284.32 715.66 0 0
Collab classification 5000 3 74.49 2475.78 0 0
Reddit-B classification 2000 2 429.63 497.75 0 0
BACE classification 1513 2 34.09 36.86 32 7
Tox21 classification 8014 2 17.87 18.50 32 7
ESOL regression 1144 - 13.29 13.68 32 7
Lipophilicity regression 4200 - 27.04 29.50 32 7
Table 5: Summary of statistics of the benchmark datasets.
Figure 3: Validation (a) R² score, and (b) RMSE achieved by the MemGNN model on ESOL with GAT and e-GAT based query networks.
Figure 4: Figures (b) and (d) show computed clusters without using unsupervised clustering loss, whereas Figures (a) and (c) show the clusters learned using the unsupervised clustering loss. The visualizations suggest that the unsupervised loss helps the model in learning distinct and meaningful clusters.
Figure 5: Clusters learned by a MemGNN for the ESOL and LIPO datasets. Chemical groups such as OH (hydroxyl), CCl3, COOH (carboxyl), and CO (ketone), as well as benzene rings, are recognized during training. These chemical groups are highly active and have a great impact on the solubility of molecules.