Two-stage Training of Graph Neural Networks for Graph Classification

11/10/2020 ∙ by Manh Tuan Do, et al. ∙ KAIST 수리과학과 Yonsei University 0

Graph Neural Networks (GNNs) have received massive attention in the field of machine learning on graphs. Inspired by the success of neural networks, a line of research has been conducted to train GNNs to deal with various tasks, such as node classification, graph classification, and link prediction. In this work, our task of interest is graph classification. Several GNN models have been proposed and shown great accuracy in this task. However, the question is whether usual training methods fully realizes the power of the GNN models. In this work, we propose a two-stage training framework based on triplet loss. In the first stage, GNN is trained to map each graph to a Euclidean-space vector so that graphs of the same class are close while those of different classes are mapped far apart. Once graphs are well-separated based on labels, a classifier can be trained to distinguish between different classes. This method is generic in the sense that it is compatible with any GNN model. By adapting five GNN models to our method, we demonstrate the consistent improvement in accuracy over the original training method of each model up to 5.4 12 datasets.



There are no comments yet.


page 1

page 2

page 3

page 4

This week in AI

Get the week's most popular data science and artificial intelligence research sent straight to your inbox every Saturday.

1 Introduction

With the pervasiveness of graph-structured data, graph representation learning has become an increasingly important task. Its goal is to learn meaningful embeddings (i.e., vector representations) of nodes and/or (sub)graphs. These embeddings can be used in various downstream tasks, such as node classification, link prediction, and graph classification.

Metric learning is about learning distance between objects in a metric space. While it remains a difficult task to properly define an efficient metric measure directly based on graph topology, a common approach is to map the graphs into vectors in the Euclidean space and measure the distance among those vectors. In addition to satisfying the basic properties of metric, this mapping is also expected to separate graphs of different classes to distinguishable clusters.

Graph neural networks (GNNs) have received a lot of attention in the graph mining literature. Despite the challenge of applying the message-passing mechanism of neural networks to the graph structure, GNNs have proved successful in dealing with graph learning problems, including node classification  [velickovic2018graph, kipf2016semi], link prediction [schlichtkrull2018modeling] and graph classification [zhang2018end, dai2016discriminative, duvenaud2015convolutional]. The common approach is to start from node features, allow information to flow among neighboring nodes and finalize the meaningful node embeddings. GNN models differ by the information-passing method and the objectives of the final embeddings.

2STG+ (Proposed) [hlugzlplstrategies]
Accuracy improvement 5 out of 5 GNNs 3 out of 4 GNNs
All datasets Only within domains
Required datasets No additional set Large ( 400K graphs)
Total training time Short ( 1 hour) Long ( 1 day)
Table 1: Comparison of our method 2STG+ with [hlugzlplstrategies]. While both need a pre-training step before graph classification, our method consistently improves accuracy of several GNNs even in datasets of a far domain while requiring shorter training time and no additional rich dataset.

Graph classification involves separating graph instances of different classes and predicting the label of an unknown graph. This task requires a graph representation vector distinctive enough to distinguish graphs of different classes. The subtlety is how to combine the node embeddings into an expressive graph representation vector, and a number of approaches have been proposed.

Although GNNs are shown to achieve high accuracy of graph classification, we observe that, with usual end-to-end training methods, they cannot realize their full potential. Thus, we propose 2STG+, a new training method with two stages. The first stage is metric learning with triplet loss, and the second stage is training a classifier. We observed that 2STG+ significantly improves the accuracy of five different GNN models, compared to their original training methods.

Our training method is, to some extent, similar to [hlugzlplstrategies] in the sense that GNNs are pre-trained on a task before being used for graph classification. However, [hlugzlplstrategies]

does transfer learning by pre-training GNNs on a different massive graph, either in chemistry or biology domain, with numerous tasks, on both node and graph levels. On the other hand,

2STG+ pre-trains GNNs on the same training dataset with only one graph-level task as the first stage. As highlighted in Table 1, 2STG+ is faster without requiring pre-training on rich and massive datasets, and it consistently achieves improved accuracy of more GNN models in more datasets than [hlugzlplstrategies].

In short, the contributions of our paper are three-fold.

  • Observation: In the graph classification task, GNNs often fail to exhibit their full power. Using a proper training method, their expressiveness can be further utilized.

  • Method Design: We propose a two-stage learning method with pre-training based on triplet loss. With this method, up to 5.4% points in accuracy can be increased.

  • Extensive Experiments: We conducted comprehensive experiments with 5 different GNN models and 12 datasets to illustrate the consistent accuracy improvement by our two-stage training method. We also compare our method with a strong graph transfer-learning framework to highlight its competency of our method.

2 Related Works

2.1 Graph neural networks

Graph neural networks (GNNs) attempt to learn embeddings (i.e, vector representations) of nodes and/or graphs, utilizing the mechanisms of neural networks adapted to the topology of graphs. The core idea of GNNs is to allow messages to pass among neighbors so that the representation of each node can incorporate the information from its neighborhood and thus to enable the GNNs to indirectly learn the graph structures. Numerous novel architectures for GNNs have been proposed and tested, which differ by the information-passing mechanisms. Among the most recent architectures are graph convolutions [kipf2016semi], attention mechanisms [velickovic2018graph]

, and those inspired by convolutional neural networks 

[gao2019graph, niepert2016learning, defferrard2016convolutional]. The final embeddings obtained from GNNs can be utilized for various graph mining tasks, such as node classification [kipf2016semi], link prediction [schlichtkrull2018modeling], and graph classification [zhang2018end, dai2016discriminative].

2.2 Graph classification by GNN

In graph classification, GNNs are tasked with predicting the label of an unseen graph. While node embeddings can be updated within a graph, the elusive step here is how to combine them into a vector representation of the entire graph that can distinguish among different labels. Two of the most common approaches are global pooling [duvenaud2015convolutional] and hierarchical pooling [zhitao2018hierarchical, yao2019graph, lee2019self]. The simplest ways for global pooling are global mean and global max of the final node embeddings. In contrast, hierarchical pooling iteratively reduces the number of nodes either by merging similar nodes into supernodes [zhitao2018hierarchical, yao2019graph] or selecting most significant nodes [lee2019self] until reaching a final (super)node whose embedding is used to represent the whole graph.

2.3 Transfer learning for graphs

While most existing methods attempt to train GNNs as an end-to-end classification system, some studies considered transfer learning in which the GNN is trained on a large dataset before being applied to the task of interest, often in a much smaller dataset. [hlugzlplstrategies] succeeded in improving 3 (out of 4 attempted) existing GNNs by transfer learning from other tasks. Rather than training a GNN to classify a dataset right away, the authors pre-trained that GNN on another massive dataset (up to 456K graphs); then they added a classifier and trained the whole on the graph classification task. However, transfer learning for graph remains a major challenge, as [ching2018opportunities, wang2019data] pointed out that considerable domain knowledge is required to design the appropriate pre-training procedure.

2.4 Metric learning

Metric learning aims to approximate a real-valued distance between two objects. Some work has focused on metric learning on graphs [ktena2018metric, liu2019community, ling2020hierarchical][ktena2018metric, liu2019community] employ a Siamese network structure, in which a twin network sharing the same weights is applied on a pair of graphs, and the two output vectors acting as representation for the two graphs are passed through a distance measure.

In computer vision

[schroff2015facenet] learns metric on triplet of images, where two (anchor and positive) share the same label and one (negative) has another label. The model aims to minimize the distance between the anchor and the positive, while maximizing the distance between the anchor and the negative. This inspired our interest in learning graph metrics in a triplet fashion.

3 Proposed Method

In this section, we first define our task of interest: graph classification. Then, we describe each component of our proposed training method of GNNs for graph classification.

3.0.1 Problem definition

We tackle the task of graph classification. Given where is the class of the graph , the goal of graph classification is to learn a mapping that maps graphs to the set of classes and predicts the class of unknown graphs.

3.0.2 Outline of our method

Our method incorporates the advantages of both GNN and metric learning. Specifically, to facilitate a better accuracy of the classifier, our method first maps input graphs into the Euclidean space such that their corresponding embeddings are well-separated based on classes. Below, we first briefly introduce GNNs and a learning scheme on triplet loss. Then, we describe the two stages of our method: pre-training a GNN and training a classifier.

3.1 Graph neural networks

Various GNNs have been proposed and proven effective in approximating such a function . Starting from a graph with node features , GNNs obtain final embeddings of nodes and a final embedding of the graph after layers of aggregation. Specifically, at each -th layer, the embedding of each node incorporates the embeddings of itself and its neighboring nodes at the -th layer as follows:

The embedding of the graph is then obtained by pooling all node embeddings into a single vector as follows:

Different GNNs differ by how the incorporating function , the aggregating function , and the final pooling function are implemented.

Figure 1: First training stage for a GNN. is trained to differentiate graphs of different classes via the triplet loss.

Figure 2: Second training stage for a classifier. After maps to , a classifier is added to map

to the class probability prediction. At this step, either the classifier is trained independently (

2STG) or the whole architecture is trained together (2STG+).

3.2 Metric learning based on triplet loss

Triplet loss was first introduced in  [schroff2015facenet]. The core idea is to enforce a margin between classes of samples. This results in embeddings of the same class mapped to a cluster distant apart from that of other classes. Specifically, given a mapping , we wish for a graph (anchor) to be closer to another graph (positive) of the same class than to a graph (negative) of another class by at least a margin

, which is a hyperparameter, i.e.,

The triplet loss for the whole dataset becomes:

with the summation over all considered triplets.

Our two-stage method combines the power of both GNNs and the metric learning method, as described below.

3.3 First training stage (pre-training a GNN)

In the first training stage (depicted in Fig. 1), given a GNN architecture , its weights are shared among a triplet network , which consists of three identical GNN architectures having the same weights as . The parameters of are trained on each triplet of graphs

(anchor, positive, negative), in which the anchor and the positive graphs are of the same class while the negative graph is of another class. Instead of estimating the class probabilities,

maps each graph to a real-valued vector in the Euclidean space: . Ideally, and should be close while is far from them both. The triplet loss for is defined as:

3.4 Second training stage (training a classifier)

At the second stage, a classifier is either trained independently, or added on top of the trained GNN and trained together on the graph classification task (see Fig. 2).

In summary, we propose two training methods for GNNs: 2STG and 2STG+, both consist of two stages.

  • 2STG (Pre-training Setting): In the first stage, the GNN maps each triplet of graphs into a corresponding triplet of Euclidean-space vectors, and in turns the GNN is trained on triplet loss. In the second stage, a classifier is trained independently to classify the graph embeddings.

  • 2STG+ (Fine-tuning setting): It has the same structure as 2STG except that in the second stage, the classifier is plugged on top of the trained GNN, and then the whole architecture is trained together in an end-to-end manner.

Note that our methods are compatible to any GNN model that maps each graph to a representation vector. As shown in the next section, when applied to this method, each GNN model outperformed itself in the original setting.

4 Experiments

In this section, we describe the details of our experiments.

4.1 GNN architectures

In order to demonstrate that our two-stage method helps realize better accuracy of GNNs, for each of the following GNN model, we compare the accuracy obtained in the original setting versus that from our method:

  • GraphSage [hamiltoninductive]

    : After obtaining node embeddings, global mean/max pooling is applied to combine them into one graph embedding.

  • GAT [velickovic2018graph]: Similarly, global mean/max pooling is employed to combine all the node embeddings.

  • Diff-pool [zhitao2018hierarchical]: A hierarchical approach for pooling the node embeddings.

  • Eigen-GCN [yao2019graph]: A different design for hierarchical pooling.

  • SAG-Pool [lee2019self]: A hierarchical graph pooling with self-attention mechanisms.

In previous studies, these models were trained end-to-end, mapping each graph to a prediction of class probabilities. To further illustrate the competency of our method, we also compare it with a transfer-learning method [hlugzlplstrategies].

4.2 Datasets

To validate the claims, we apply our methods on 12 datasets. They include some commonly tested binary-class datasets  [morris2020tudataset]

for graph classification: DD, MUTAG, Mutagenicity, PROTEINS, PTC-FM, and IMDB-BINARY. In addition, we also test our method on New York City Taxi datasets.

111 More details can be found in the Appendix.

4.3 Experimental procedure

We tested the ability of each GNN architecture to classify graphs in the following three settings:

  • Original setting: The GNN with a final classifier outputs the estimated class probabilities, and the weights are updated by the cross-entropy loss with respect to the true class. We use the implementation provided by the authors. To enhance the capacity of the final classifier, we tune the classifier by using up to three fully-connected layers and select the model based on validation sets.

  • 2STG and 2STG+: See the Proposed Method section.

Additionally, we also compare our two-stage method with the transfer-learning method in [hlugzlplstrategies], which also claims the effectiveness of a pre-training strategy. Out of 5 GNN models investigated in our work, GraphSage and GAT are provided with trained weights by [hlugzlplstrategies], and they are compared with GraphSage/GAT trained in 2STG+.

Each dataset is randomly split into three sets: training (80%), validation (10%) and test (10%). Details about hyperparameter search can be found in the appendix. The reported results are average and standard deviation of test accuracy of five splits.

We initialize node features as learnable features that are also optimized alongside GNN parameters during training. Even though input features are provided in some datasets, we empirically observe that in most cases, using learnable features leads to better accuracy.

In 2STG and 2STG+

, each graph can be anchor once, while the respective positive and negative graphs are chosen randomly. The classifier is a multi-layer perceptron (MLP).

4.4 Results and discussion

4.4.1 Improvement by our methods

The results of comparing 2STG and 2STG+ with the original setting are summarized in Tables 2 and 3. Several observations can be drawn:

  • Pre-training using triplet loss (i.e., the first stage of 2STG and 2STG+) consistently enhances the graph classification accuracy of each GNN model by 0.9-5.4% points, compared to its original setting.

  • Fine-tuning the weights of GNNs (i.e, the second stage of 2STG+) further improves the accuracy from 2STG in some cases by up to 1.3% points.

Method Data Set Average Gain
D&D MUTAG Mutagenicity PTC-FM PROTEINS IMDB-BINARY (in % points)
GraphSage 69.24 0.52 65.13 0.87 75.44 0.50 61.77 1.11 71.25 1.38 65.52 0.96 -
GraphSage (2STG) 75.13 0.82 80.86 1.19 76.84 0.54 62.75 1.20 71.29 0.41 68.37 0.63 4.47
GraphSage (2STG+) 76.52 1.47 81.14 0.68 77.71 0.41 62.65 0.72 72.34 0.56 68.24 0.83 5.01
GAT 66.50 1.24 65.18 1.03 76.23 0.67 60.65 0.42 66.92 0.75 67.13 0.88 -
GAT (2STG) 72.95 0.91 77.84 0.63 76.34 0.52 62.04 1.16 70.17 0.72 69.15 0.87 4.11
GAT (2STG+) 74.13 1.47 78.17 1.41 76.49 1.23 61.61 0.53 72.64 0.58 67.25 0.89 5.37
DiffPool 72.11 0.42 86.32 0.83 77.21 1.16 61.15 0.35 72.24 0.67 64.93 0.74 -
DiffPool (2STG) 74.93 0.53 86.14 0.77 77.94 1.28 62.03 0.32 73.87 0.64 65.22 0.83 1.03
DiffPool (2STG+) 78.84 0.54 87.38 0.62 77.08 1.23 62.15 0.68 73.07 1.17 64.90 0.81 1.07
EigenGcn 75.62 0.63 79.87 0.66 76.65 1.14 63.34 1.23 75.63 0.82 71.86 0.55 -
EigenGcn (2STG) 77.56 0.48 80.21 0.71 77.98 0.62 64.13 0.95 75.93 0.56 72.66 0.42 0.91
EigenGcn (2STG+) 78.13 0.51 81.42 0.86 77.02 1.72 63.52 1.43 77.31 1.46 72.04 0.53 1.07
Sag-gPool 76.12 0.79 78.34 0.65 76.83 1.27 63.27 0.78 74.34 1.25 71.23 1.12 -
Sag-Pool (2STG) 78.32 1.26 79.63 0.95 78.03 0.68 63.83 0.83 77.52 0.54 71.73 0.81 1.48
Sag-Pool (2STG+) 78.22 0.70 79.03 0.89 77.03 0.63 64.34 0.86 76.23 1.12 72.36 0.73 1.24
Table 2: Average and standard deviation (in %) of graph classification accuracies in benchmark datasets in three settings: original, 2STG, and 2STG+. Pre-training GNNs in 2STG and 2STG+ improves the classification accuracy compared to the original setting, andfine-tuning GNNs in 2STG+ further improves the accuracy. The average gain is in % points.
Method Data Set Average Gain
Jan. G. Feb. G. Mar. G. Jan. Y. Feb. Y. Mar. Y. (in % points)
GraphSage 73.14 0.62 66.35 1.25 64.63 0.83 72.86 0.92 64.37 0.87 68.12 0.76 -
GraphSage (2STG) 76.14 0.93 66.67 1.31 67.13 0.85 75.24 1.16 65.43 0.68 70.15 0.64 1.88
GraphSage (2STG+) 76.63 0.82 67.74 0.88 68.95 1.41 75.21 1.70 67.64 0.73 70.23 1.25 2.82
GAT 71.26 1.51 67.82 0.77 66.13 0.72 72.64 0.54 64.76 0.73 67.51 1.69 -
GAT (2STG) 75.23 0.82 67.24 0.56 67.34 0.71 76.82 1.23 66.45 0.85 70.66 0.78 2.27
GAT (2STG+) 74.65 0.98 68.11 0.69 69.15 1.37 74.79 1.27 68.75 0.66 70.44 0.93 3.04
DiffPool 78.43 0.74 73.12 0.42 71.39 1.56 72.52 1.23 67.43 0.87 74.34 0.77 -
DiffPool (2STG) 80.28 1.16 75.69 1.21 73.79 0.81 75.09 0.72 68.19 0.50 74.87 0.83 1.78
DiffPool (2STG+) 79.63 0.82 74.56 1.32 72.92 0.65 75.95 1.21 69.31 0.97 75.76 0.86 1.81
EigenGcn 75.45 0.44 69.32 1.82 72.21 0.83 73.21 1.35 69.64 0.76 69.52 1.54 -
EigenGcn (2STG) 77.14 0.81 70.03 0.62 74.12 1.34 74.36 1.65 69.72 0.97 70.03 0.86 1.02
EigenGcn (2STG+) 76.73 1.21 71.27 1.33 73.37 1.85 75.33 1.14 71.65 1.67 71.84 0.62 1.80
Sag-Pool 73.23 0.59 67.46 0.73 72.78 1.34 72.65 0.72 68.83 1.25 69.68 1.35 -
Sag-Pool (2STG) 76.36 1.37 69.07 1.48 74.34 1.52 71.11 0.73 70.02 0.64 70.04 1.48 1.05
Sag-Pool (2STG+) 75.38 0.86 69.27 1.12 73.19 1.34 72.51 0.85 69.16 0.79 70.59 0.52 0.91
Table 3: Average and standard deviation (in %) of graph classification accuracies in NYC Taxi datasets in three settings: original, 2STG, and 2STG+. Similarly to the case of benchmark datasets, 2STG and 2STG+ significantly outperforms the original setting. G., Y. stand for Green, Yellow. The average gain is in % points.
Method Data Set Average Gain
D&D Imdb-Binary MUTAG Proteins Mutagenicity PTC-FM (in % points)
GAT (2STG+) 74.13 1.47 67.25 0.89 78.17 1.41 72.64 0.58 76.49 1.23 61.61 0.53 1.56
GAT [hlugzlplstrategies] 72.24 0.83 65.16 1.47 76.86 1.35 71.76 0.77 75.59 1.48 61.28 0.97 -
GraphSage (2STG+) 76.52 1.47 68.24 0.83 81.14 0.68 72.34 0.56 77.71 0.41 62.65 0.72 0.31
GraphSage [hlugzlplstrategies] 75.26 1.36 67.14 0.52 82.43 1.49 72.15 0.83 76.83 0.95 62.61 0.78 -
Table 4: 2STG+ outperforms transfer learning [hlugzlplstrategies], in terms of average classification accuracy in most cases when the benchmark datasets are used. Note that 2STG+ has several advantages over [hlugzlplstrategies], as highlighted in Table 1.
Method Data Set Average Gain
Jan.G Feb. G Mar. G Jan. Y. Feb. Y. Mar. Y. (in % points)
GAT (2STG+) 74.65 0.98 68.11 0.69 69.15 1.37 74.79 1.27 68.75 0.66 70.44 0.93 1.97
GAT [hlugzlplstrategies] 73.87 1.13 65.89 1.04 67.25 1.27 71.87 1.35 66.24 0.92 68.95 1.25 -
GraphSage (2STG+) 76.63 0.82 67.74 0.88 68.95 1.41 75.21 1.70 67.64 0.73 70.23 1.25 0.53
GraphSage [hlugzlplstrategies] 75.19 0.98 67.82 0.43 68.03 0.66 73.66 1.27 68.24 0.79 70.25 0.73 -
Table 5: The accuracy gap between 2STG+ and transfer learning [hlugzlplstrategies] is larger in some Taxi datasets, as the networks for transfer learning were pre-trained on either a biology or chemistry dataset, which is of a far domain.

According to the results, 2STG and 2STG+ yield better accuracy than the original setting of each GNN. This observation suggests two possible explanations:

  • The end-to-end training methods fail to realize the full potential of the GNN models. Even if the final classifier of is upgraded from a fully-connected layer to an MLP, the accuracy is not as high as in 2STG and 2STG+.

  • Learning meaningful embeddings in between that are fairly separated based on classes (see Figure 3), for example through metric learning as in our methods, facilitates a better accuracy of the final classifier.

Figure 3: Visualization of the final embeddings in Original (left), 2STG (middle) and 2STG+ (right) settings (DD dataset). Instances of the two classes are separated better in 2STG and 2STG+ than in the original setting.

4.4.2 Comparison with a transfer learning method

We compare the results of the pre-trained models of GraphSage/GAT in [hlugzlplstrategies] with GraphSage/GAT with the same architecture (5 layers, 300 dimensional hidden units and global mean pooling) trained in 2STG+. Note that the pre-trained models are fine-tuned on the considered datasets. Results of comparison are in Tables 4 and 5, where the average and standard deviation of test accuracy of 5 splits are reported. Despite being pre-trained on a much smaller dataset with only one task, our method achieves better accuracy in of the considered cases: up to 2% points in the benchmark datasets and up to 3% points in the Taxi datasets. This validates our claims in Table 1.

4.4.3 Running time

For each hyperparameter setting in a dataset, the original setting takes up to half an hour to train a model. Due to having two stages, 2STG and 2STG+ take up to an hour for both stages. In [hlugzlplstrategies], the transfer-learning method was reported to take up to one day to pre-train on a rich dataset.

5 Conclusion

Graph Neural Networks are powerful tools in dealing with many graph mining tasks, including graph classification, which our work focuses on. However, training them end-to-end to predict class probabilities often fails to realize their full capability. Thus, we apply GNN models into a triplet framework to learn discriminative embeddings first, and then train a classifier on those embeddings. Extensive experiments in 12 datasets lead to following observations:

  • End-to-end training often fails to realize the full potential of GNN models. Applying GNN models in our method enhances their accuracy by up to 5.4% points.

  • Our two-stage training method leads to better accuracy than a state-of-the-art pre-training method based on transfer-learning in 83% of the considered cases.

  • Despite not requiring any additional massive rich datasets or long training time, our training method consistently improves the accuracy of (out of tested) GNN models in (out of considered) datasets.


Appendix A Datasets

We tested our training method using 12 datasets:

a.1 Benchmark datasets

These are the commonly tested binary-class datasets  [morris2020tudataset] for the graph classification task: DD, MUTAG, Mutagenicity, PROTEINS, PTC-FM, and IMDB-BINARY.

a.2 New York City Taxi

We extracted the taxi ridership data in 2019 from New York City (NYC) Taxi Commission. The areas in New York are represented as nodes, and each taxi trip is an edge connecting the source and destination nodes. All taxi trips in an 1-hour interval form a graph, and each dataset spans a month of taxi operations. We augmented the binary label for each graph as taxi trips in weekdays (Mon-Thu) vs. weekend (Fri-Sun). We considered two taxi operators (Yellow and Green) and processed data in January, February and March of 2019, making 6 datasets in total.

Appendix B GNN architectures

In order to demonstrate that our two-stage training method helps realize a better performance of GNNS, for each GNN architecture, we compare the accuracy obtained in the original setting versus that from our method. The GNN architectures we considered in this work are:

  • GraphSage [hamiltoninductive]: This is often used as a strong baseline in graph classification. After obtaining node embeddings, global mean/max pooling is applied to combine all node embeddings into one graph embedding.

  • GAT [velickovic2018graph]: Instead of uniformly passing neighbor information into a node embedding, [velickovic2018graph] employs an attention mechanism for the importance of each neighboring node.

  • Diff-pool [zhitao2018hierarchical]: While using the same aggregation mechanism as [hamiltoninductive], [zhitao2018hierarchical] proposes a hierarchical approach to pool the node embeddings. Rather than a “flat-pooling” step at the end, diff-pool repeatedly merges nodes into “supernodes” until there is only 1 supernode whose embedding is treated as the graph embedding.

  • Eigen-GCN [yao2019graph]: Attempting to implement hierarchical pooling like [zhitao2018hierarchical], [yao2019graph]

    formulates a different way to combine nodes and their respective embeddings making use of the eigenvectors of the Laplacian matrix.

  • SAG-Pool [lee2019self]: Hierarchical graph pooling employing self-attention mechanisms.

Appendix C Hyperparameter search

For each GNN, the hyperparameters regarding the network architecture were tuned in the same search space for the three settings: original, 2STG, 2STG+. The search space for the dimensions of the input vector, hidden vector and output vector for all GNNs was . For Diff-pool, we used three layers of graph convolution and one DIFFPOOL layer as described in the original paper. For Eigen-GCN, we used three pooling operators as it was shown to achieve the best performance in the original paper. For SAG-Pool, we used three pooling layers as explained in the original paper. Other hyperparameters that are exclusive to each GNN architecture were set to default values provided in each paper’s original code of each architecture’s authors.

In all three settings, the architecture of the final classifier was also tuned. The number of fully-connected layers was up to 3 while the search space for the hidden dimension was where is the dimension of the output vector.

The two settings 2STG and 2STG+ require an additional hyperparameter . While [schroff2015facenet] found to be effective, we empirically found that this value is too small to separate instances of different classes. Instead, the search space for we used is .