With the pervasiveness of graph-structured data, graph representation learning has become an increasingly important task. Its goal is to learn meaningful embeddings (i.e., vector representations) of nodes and/or (sub)graphs. These embeddings can be used in various downstream tasks, such as node classification, link prediction, and graph classification.
Metric learning aims to learn a distance function between objects in a metric space. Since it remains difficult to define an effective metric directly on graph topology, a common approach is to map graphs to vectors in Euclidean space and measure the distance between those vectors. Beyond satisfying the basic properties of a metric, this mapping is also expected to separate graphs of different classes into distinguishable clusters.
Graph neural networks (GNNs) have received a lot of attention in the graph mining literature. Despite the challenge of adapting the message-passing mechanism of neural networks to graph structure, GNNs have proved successful in graph learning problems, including node classification [velickovic2018graph, kipf2016semi], link prediction [schlichtkrull2018modeling] and graph classification [zhang2018end, dai2016discriminative, duvenaud2015convolutional]. The common approach is to start from node features, let information flow among neighboring nodes, and distill it into meaningful node embeddings. GNN models differ in their information-passing methods and in the objectives of the final embeddings.
Table 1: Comparison between our method (2STG+) and the transfer-learning method of [hlugzlplstrategies].

| | 2STG+ (ours) | Transfer learning [hlugzlplstrategies] |
|---|---|---|
| Accuracy improvement | 5 out of 5 GNNs | 3 out of 4 GNNs |
| | All datasets | Only within domains |
| Required datasets | No additional set | Large (≈400K graphs) |
| Total training time | Short (≈1 hour) | Long (≈1 day) |
Graph classification involves separating graph instances of different classes and predicting the label of an unknown graph. This task requires a graph representation vector distinctive enough to distinguish graphs of different classes. The subtlety is how to combine the node embeddings into an expressive graph representation vector, and a number of approaches have been proposed.
Although GNNs are shown to achieve high accuracy in graph classification, we observe that, with the usual end-to-end training methods, they cannot realize their full potential. Thus, we propose 2STG+, a new training method with two stages. The first stage is metric learning with triplet loss, and the second stage is training a classifier. We observed that 2STG+ significantly improves the accuracy of five different GNN models compared to their original training methods.
Our training method is, to some extent, similar to [hlugzlplstrategies] in the sense that GNNs are pre-trained on a task before being used for graph classification. However, [hlugzlplstrategies] performs transfer learning by pre-training GNNs on a different, massive graph dataset, in either the chemistry or the biology domain, with numerous tasks at both the node and graph levels. In contrast, 2STG+ pre-trains GNNs on the same training dataset with only one graph-level task as the first stage. As highlighted in Table 1, 2STG+ is faster, does not require pre-training on rich and massive datasets, and consistently improves the accuracy of more GNN models on more datasets than [hlugzlplstrategies].
In short, the contributions of our paper are three-fold.
Observation: In the graph classification task, GNNs often fail to exhibit their full power. With a proper training method, their expressiveness can be better utilized.
Method Design: We propose a two-stage learning method with pre-training based on triplet loss. With this method, accuracy increases by up to 5.4% points.
Extensive Experiments: We conducted comprehensive experiments with 5 different GNN models and 12 datasets to illustrate the consistent accuracy improvement by our two-stage training method. We also compare our method with a strong graph transfer-learning framework to highlight the competency of our method.
2 Related Works
2.1 Graph neural networks
Graph neural networks (GNNs) attempt to learn embeddings (i.e., vector representations) of nodes and/or graphs, utilizing the mechanisms of neural networks adapted to the topology of graphs. The core idea of GNNs is to allow messages to pass among neighbors so that the representation of each node can incorporate the information from its neighborhood, enabling GNNs to indirectly learn the graph structure. Numerous novel architectures for GNNs have been proposed and tested, which differ in their information-passing mechanisms. Among the most recent architectures are graph convolutions [kipf2016semi], attention mechanisms [velickovic2018graph], and those inspired by convolutional neural networks [gao2019graph, niepert2016learning, defferrard2016convolutional]. The final embeddings obtained from GNNs can be utilized for various graph mining tasks, such as node classification [kipf2016semi], link prediction [schlichtkrull2018modeling], and graph classification [zhang2018end, dai2016discriminative].
2.2 Graph classification by GNN
In graph classification, GNNs are tasked with predicting the label of an unseen graph. While node embeddings can be updated within a graph, the elusive step here is how to combine them into a vector representation of the entire graph that can distinguish among different labels. Two of the most common approaches are global pooling [duvenaud2015convolutional] and hierarchical pooling [zhitao2018hierarchical, yao2019graph, lee2019self]. The simplest forms of global pooling are the global mean and global max of the final node embeddings. In contrast, hierarchical pooling iteratively reduces the number of nodes, either by merging similar nodes into supernodes [zhitao2018hierarchical, yao2019graph] or by selecting the most significant nodes [lee2019self], until reaching a final (super)node whose embedding represents the whole graph.
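As a concrete illustration of the flat (global) pooling readouts mentioned above, here is a minimal NumPy sketch; the array contents are toy values, not drawn from any dataset:

```python
import numpy as np

def global_mean_pool(node_embs):
    """Average all node embeddings into a single graph embedding."""
    return node_embs.mean(axis=0)

def global_max_pool(node_embs):
    """Element-wise maximum over all node embeddings."""
    return node_embs.max(axis=0)

# Toy graph with 3 nodes and 2-dimensional node embeddings.
H = np.array([[1.0, 2.0],
              [3.0, 0.0],
              [2.0, 4.0]])
print(global_mean_pool(H))  # [2. 2.]
print(global_max_pool(H))   # [3. 4.]
```

Both readouts are permutation-invariant, which is why they are safe defaults; hierarchical pooling replaces this single step with repeated coarsening.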
2.3 Transfer learning for graphs
While most existing methods attempt to train GNNs as an end-to-end classification system, some studies considered transfer learning, in which the GNN is trained on a large dataset before being applied to the task of interest, often on a much smaller dataset. [hlugzlplstrategies] succeeded in improving 3 (out of 4 attempted) existing GNNs by transfer learning from other tasks. Rather than training a GNN to classify a dataset right away, the authors pre-trained that GNN on another massive dataset (up to 456K graphs); then they added a classifier and trained the whole model on the graph classification task. However, transfer learning for graphs remains a major challenge, as [ching2018opportunities, wang2019data] pointed out that considerable domain knowledge is required to design an appropriate pre-training procedure.
2.4 Metric learning
Metric learning aims to approximate a real-valued distance between two objects. Some work has focused on metric learning on graphs [ktena2018metric, liu2019community, ling2020hierarchical]. [ktena2018metric, liu2019community] employ a Siamese network structure, in which twin networks sharing the same weights are applied to a pair of graphs, and the two output vectors, acting as representations of the two graphs, are passed through a distance measure.
In computer vision, [schroff2015facenet] learns a metric on triplets of images, where two (anchor and positive) share the same label and one (negative) has a different label. The model aims to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative. This inspired our interest in learning graph metrics in a triplet fashion.
3 Proposed Method
In this section, we first define our task of interest: graph classification. Then, we describe each component of our proposed training method of GNNs for graph classification.
3.0.1 Problem definition
We tackle the task of graph classification. Given a set of labeled graphs $\{(G_i, y_i)\}_{i=1}^{N}$, where $y_i$ is the class of the graph $G_i$, the goal of graph classification is to learn a mapping from graphs to the set of classes that predicts the class of unseen graphs.
3.0.2 Outline of our method
Our method incorporates the advantages of both GNN and metric learning. Specifically, to facilitate a better accuracy of the classifier, our method first maps input graphs into the Euclidean space such that their corresponding embeddings are well-separated based on classes. Below, we first briefly introduce GNNs and a learning scheme on triplet loss. Then, we describe the two stages of our method: pre-training a GNN and training a classifier.
3.1 Graph neural networks
Various GNNs have been proposed and proven effective at approximating such a function. Starting from a graph $G = (V, E)$ with node features $\{x_v\}_{v \in V}$, GNNs obtain final embeddings $h_v^{(K)}$ of the nodes and a final embedding $h_G$ of the graph after $K$ layers of aggregation. Specifically, at the $k$-th layer, the embedding of each node $v$ incorporates the embeddings of itself and its neighboring nodes $\mathcal{N}(v)$ at the $(k-1)$-th layer as follows:

$$h_v^{(k)} = \text{COMBINE}\left(h_v^{(k-1)},\ \text{AGGREGATE}\left(\{h_u^{(k-1)} : u \in \mathcal{N}(v)\}\right)\right), \quad h_v^{(0)} = x_v.$$
The embedding of the graph is then obtained by pooling all node embeddings into a single vector as follows:

$$h_G = \text{READOUT}\left(\{h_v^{(K)} : v \in V\}\right).$$
Different GNNs differ in how the incorporating function COMBINE, the aggregating function AGGREGATE, and the final pooling function READOUT are implemented.
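The layer update above can be sketched concretely. The mean aggregator and ReLU below are illustrative choices (GraphSage-style); concrete GNNs implement COMBINE and AGGREGATE differently:

```python
import numpy as np

def gnn_layer(h, adj, w_self, w_neigh):
    """One message-passing layer with a mean aggregator.
      h:       (n, d) node embeddings at layer k-1
      adj:     (n, n) binary adjacency matrix
      w_self:  (d, d_out) weight for the node's own embedding
      w_neigh: (d, d_out) weight for the aggregated neighbor embeddings
    """
    deg = adj.sum(axis=1, keepdims=True)
    deg[deg == 0] = 1                              # guard isolated nodes
    neigh_mean = (adj @ h) / deg                   # AGGREGATE: mean over N(v)
    return np.maximum(h @ w_self + neigh_mean @ w_neigh, 0.0)  # COMBINE + ReLU

# Triangle graph, 2-d features, identity weights for readability.
adj = np.array([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])
x = np.array([[1., 0.], [0., 1.], [1., 1.]])
h1 = gnn_layer(x, adj, np.eye(2), np.eye(2))
print(h1.shape)  # (3, 2)
```

Stacking $K$ such layers and applying a readout over the final embeddings yields the graph embedding $h_G$ described above.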
3.2 Metric learning based on triplet loss
Triplet loss was first introduced in [schroff2015facenet]. The core idea is to enforce a margin between classes of samples, so that embeddings of the same class form a cluster distant from those of other classes. Specifically, given a mapping $f$, we wish for a graph $G_a$ (anchor) to be closer to another graph $G_p$ (positive) of the same class than to a graph $G_n$ (negative) of another class by at least a margin $\alpha$, which is a hyperparameter, i.e.,

$$\|f(G_a) - f(G_p)\|_2^2 + \alpha \le \|f(G_a) - f(G_n)\|_2^2.$$
The triplet loss for the whole dataset becomes:

$$\mathcal{L} = \sum_{(G_a, G_p, G_n)} \max\left(\|f(G_a) - f(G_p)\|_2^2 - \|f(G_a) - f(G_n)\|_2^2 + \alpha,\; 0\right),$$

with the summation over all considered triplets.
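The per-triplet term of this loss can be computed directly. The vectors below are toy embeddings chosen to show both a satisfied and a violated margin:

```python
import numpy as np

def triplet_loss(f_a, f_p, f_n, alpha=1.0):
    """One triplet's loss: max(||f_a - f_p||^2 - ||f_a - f_n||^2 + alpha, 0)."""
    d_pos = np.sum((f_a - f_p) ** 2)
    d_neg = np.sum((f_a - f_n) ** 2)
    return max(d_pos - d_neg + alpha, 0.0)

anchor   = np.array([0.0, 0.0])
positive = np.array([1.0, 0.0])   # squared distance to anchor: 1.0
far_neg  = np.array([3.0, 0.0])   # squared distance to anchor: 9.0
near_neg = np.array([1.0, 0.5])   # squared distance to anchor: 1.25

print(triplet_loss(anchor, positive, far_neg))   # 0.0  (margin satisfied)
print(triplet_loss(anchor, positive, near_neg))  # 0.75 (negative too close)
```

Only triplets that violate the margin contribute gradient, which is why triplet selection matters in practice.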
Our two-stage method combines the power of both GNNs and the metric learning method, as described below.
3.3 First training stage (pre-training a GNN)
In the first training stage (depicted in Fig. 1), given a GNN architecture $f$, its weights are shared across a triplet network consisting of three identical copies of $f$. The parameters of $f$ are trained on triplets of graphs $(G_a, G_p, G_n)$ (anchor, positive, negative), in which the anchor and the positive graphs are of the same class while the negative graph is of another class. Instead of estimating class probabilities, $f$ maps each graph to a real-valued vector in Euclidean space: $f: \mathcal{G} \to \mathbb{R}^d$. Ideally, $f(G_a)$ and $f(G_p)$ should be close while $f(G_n)$ is far from both. The triplet loss for $f$ is defined as:

$$\mathcal{L} = \sum_{(G_a, G_p, G_n)} \max\left(\|f(G_a) - f(G_p)\|_2^2 - \|f(G_a) - f(G_n)\|_2^2 + \alpha,\; 0\right).$$
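To make the weight sharing of the first stage concrete, the sketch below applies one shared weight matrix to all three graphs of a triplet. It is forward-pass only (real training would update the weights by gradient descent on this loss), and the one-layer mean-aggregation GNN is an illustrative stand-in for any GNN that maps a graph to a vector:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(2, 4))  # the ONE weight matrix shared by all 3 copies

def embed(x, adj):
    """Stand-in GNN: one mean-aggregation layer followed by mean pooling."""
    deg = adj.sum(axis=1, keepdims=True)
    deg[deg == 0] = 1
    h = np.maximum(((adj @ x) / deg + x) @ W, 0.0)
    return h.mean(axis=0)  # graph-level embedding in R^4

def stage1_loss(anchor, positive, negative, alpha=1.0):
    """Triplet loss for one triplet of graphs, each given as (features, adj)."""
    fa, fp, fn = (embed(*g) for g in (anchor, positive, negative))
    d_pos = np.sum((fa - fp) ** 2)
    d_neg = np.sum((fa - fn) ** 2)
    return max(d_pos - d_neg + alpha, 0.0)

# Two tiny 2-node graphs of one class and one of another (toy data).
edge = np.array([[0., 1.], [1., 0.]])
g_a = (np.array([[1., 0.0], [1., 0.]]), edge)
g_p = (np.array([[1., 0.1], [1., 0.]]), edge)
g_n = (np.array([[0., 5.0], [0., 5.]]), edge)
loss = stage1_loss(g_a, g_p, g_n)
```

Because the same `W` is used for all three forward passes, any gradient step on `stage1_loss` updates a single set of GNN parameters, which is exactly what "sharing weights across the triplet network" means.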
3.4 Second training stage (training a classifier)
At the second stage, a classifier is either trained independently, or added on top of the trained GNN and trained together on the graph classification task (see Fig. 2).
In summary, we propose two training methods for GNNs: 2STG and 2STG+, both consisting of two stages.
2STG (Pre-training Setting): In the first stage, the GNN maps each triplet of graphs into a corresponding triplet of Euclidean-space vectors, and in turn the GNN is trained on the triplet loss. In the second stage, a classifier is trained independently to classify the graph embeddings.
2STG+ (Fine-tuning Setting): It has the same structure as 2STG except that, in the second stage, the classifier is added on top of the trained GNN, and then the whole architecture is trained together in an end-to-end manner.
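As a minimal illustration of the 2STG second stage, the sketch below fits a classifier on frozen stage-1 embeddings. A nearest-centroid rule stands in for the paper's MLP classifier, purely to keep the example dependency-free:

```python
import numpy as np

def fit_centroids(embs, labels):
    """Toy second-stage classifier: store the mean embedding per class."""
    return {c: embs[labels == c].mean(axis=0) for c in np.unique(labels)}

def predict(centroids, e):
    """Assign the class whose centroid is nearest in Euclidean distance."""
    return min(centroids, key=lambda c: np.sum((e - centroids[c]) ** 2))

# Stage-1 embeddings that are well separated by class, as triplet
# pre-training encourages (toy values).
embs = np.array([[0.1, 0.0], [0.2, 0.1], [2.0, 2.1], [1.9, 2.0]])
labels = np.array([0, 0, 1, 1])
centroids = fit_centroids(embs, labels)
print(predict(centroids, np.array([0.0, 0.2])))  # 0
print(predict(centroids, np.array([2.2, 1.8])))  # 1
```

In 2STG the GNN stays frozen while only this classifier is fitted; in 2STG+ the classifier and the GNN would be optimized jointly on the classification loss.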
Note that our methods are compatible with any GNN model that maps each graph to a representation vector. As shown in the next section, when trained with our method, each GNN model outperformed its original setting.
4 Experiments
In this section, we describe the details of our experiments.
4.1 GNN architectures
In order to demonstrate that our two-stage method helps realize better accuracy of GNNs, for each of the following GNN models, we compare the accuracy obtained in the original setting with that obtained using our method:
GraphSage [hamiltoninductive]: After obtaining node embeddings, global mean/max pooling is applied to combine them into one graph embedding.
GAT [velickovic2018graph]: Similarly, global mean/max pooling is employed to combine all the node embeddings.
Diff-pool [zhitao2018hierarchical]: A hierarchical approach for pooling the node embeddings.
Eigen-GCN [yao2019graph]: A different design for hierarchical pooling.
SAG-Pool [lee2019self]: A hierarchical graph pooling with self-attention mechanisms.
In previous studies, these models were trained end-to-end, mapping each graph to a prediction of class probabilities. To further illustrate the competency of our method, we also compare it with a transfer-learning method [hlugzlplstrategies].
4.2 Datasets
To validate our claims, we apply our methods to 12 datasets. They include commonly tested binary-class datasets [morris2020tudataset] for graph classification: DD, MUTAG, Mutagenicity, PROTEINS, PTC-FM, and IMDB-BINARY. In addition, we also test our method on the New York City Taxi datasets (https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page). More details can be found in the Appendix.
4.3 Experimental procedure
We tested the ability of each GNN architecture to classify graphs in the following three settings:
Original setting: The GNN with a final classifier outputs the estimated class probabilities, and the weights are updated by the cross-entropy loss with respect to the true class. We use the implementation provided by the authors. To enhance the capacity of the final classifier, we tune it using up to three fully-connected layers and select the model based on the validation sets.
2STG and 2STG+: See the Proposed Method section.
Additionally, we compare our two-stage method with the transfer-learning method of [hlugzlplstrategies], which also demonstrates the effectiveness of a pre-training strategy. Out of the 5 GNN models investigated in our work, GraphSage and GAT are provided with trained weights by [hlugzlplstrategies], and they are compared with GraphSage/GAT trained with 2STG+.
Each dataset is randomly split into three sets: training (80%), validation (10%) and test (10%). Details about the hyperparameter search can be found in the Appendix. The reported results are the average and standard deviation of test accuracy over five splits.
We initialize node features as learnable features that are also optimized alongside GNN parameters during training. Even though input features are provided in some datasets, we empirically observe that in most cases, using learnable features leads to better accuracy.
In 2STG and 2STG+, each graph serves as the anchor once, while the respective positive and negative graphs are chosen randomly. The classifier is a multi-layer perceptron (MLP).
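This triplet sampling scheme can be sketched as follows; graphs are represented by their indices, since only the labels matter for sampling:

```python
import random

def sample_triplets(labels, seed=0):
    """Each graph (index) serves as the anchor exactly once; its positive is
    a random other graph of the same class and its negative a random graph
    of a different class."""
    rng = random.Random(seed)
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    triplets = []
    for a, y in enumerate(labels):
        pos_pool = [i for i in by_class[y] if i != a]
        neg_pool = [i for c, idxs in by_class.items() if c != y for i in idxs]
        if pos_pool and neg_pool:
            triplets.append((a, rng.choice(pos_pool), rng.choice(neg_pool)))
    return triplets

trips = sample_triplets([0, 0, 1, 1, 1])
print(len(trips))  # 5: every graph is an anchor exactly once
```

One epoch of the first stage then iterates over these triplets, feeding each through the shared GNN and accumulating the triplet loss.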
4.4 Results and discussion
4.4.1 Improvement by our methods
Pre-training using triplet loss (i.e., the first stage of 2STG and 2STG+) consistently enhances the graph classification accuracy of each GNN model by 0.9-5.4% points, compared to its original setting.
Fine-tuning the weights of GNNs (i.e., the second stage of 2STG+) further improves the accuracy over 2STG in some cases, by up to 1.3% points.
Table 2: Graph classification test accuracy (%) on the benchmark datasets.

| Method | D&D | MUTAG | Mutagenicity | PTC-FM | PROTEINS | IMDB-BINARY | Average Gain (% points) |
|---|---|---|---|---|---|---|---|
| GraphSage | 69.24 ± 0.52 | 65.13 ± 0.87 | 75.44 ± 0.50 | 61.77 ± 1.11 | 71.25 ± 1.38 | 65.52 ± 0.96 | – |
| GraphSage (2STG) | 75.13 ± 0.82 | 80.86 ± 1.19 | 76.84 ± 0.54 | 62.75 ± 1.20 | 71.29 ± 0.41 | 68.37 ± 0.63 | 4.47 |
| GraphSage (2STG+) | 76.52 ± 1.47 | 81.14 ± 0.68 | 77.71 ± 0.41 | 62.65 ± 0.72 | 72.34 ± 0.56 | 68.24 ± 0.83 | 5.01 |
| GAT | 66.50 ± 1.24 | 65.18 ± 1.03 | 76.23 ± 0.67 | 60.65 ± 0.42 | 66.92 ± 0.75 | 67.13 ± 0.88 | – |
| GAT (2STG) | 72.95 ± 0.91 | 77.84 ± 0.63 | 76.34 ± 0.52 | 62.04 ± 1.16 | 70.17 ± 0.72 | 69.15 ± 0.87 | 4.11 |
| GAT (2STG+) | 74.13 ± 1.47 | 78.17 ± 1.41 | 76.49 ± 1.23 | 61.61 ± 0.53 | 72.64 ± 0.58 | 67.25 ± 0.89 | 5.37 |
| DiffPool | 72.11 ± 0.42 | 86.32 ± 0.83 | 77.21 ± 1.16 | 61.15 ± 0.35 | 72.24 ± 0.67 | 64.93 ± 0.74 | – |
| DiffPool (2STG) | 74.93 ± 0.53 | 86.14 ± 0.77 | 77.94 ± 1.28 | 62.03 ± 0.32 | 73.87 ± 0.64 | 65.22 ± 0.83 | 1.03 |
| DiffPool (2STG+) | 78.84 ± 0.54 | 87.38 ± 0.62 | 77.08 ± 1.23 | 62.15 ± 0.68 | 73.07 ± 1.17 | 64.90 ± 0.81 | 1.07 |
| EigenGCN | 75.62 ± 0.63 | 79.87 ± 0.66 | 76.65 ± 1.14 | 63.34 ± 1.23 | 75.63 ± 0.82 | 71.86 ± 0.55 | – |
| EigenGCN (2STG) | 77.56 ± 0.48 | 80.21 ± 0.71 | 77.98 ± 0.62 | 64.13 ± 0.95 | 75.93 ± 0.56 | 72.66 ± 0.42 | 0.91 |
| EigenGCN (2STG+) | 78.13 ± 0.51 | 81.42 ± 0.86 | 77.02 ± 1.72 | 63.52 ± 1.43 | 77.31 ± 1.46 | 72.04 ± 0.53 | 1.07 |
| SAG-Pool | 76.12 ± 0.79 | 78.34 ± 0.65 | 76.83 ± 1.27 | 63.27 ± 0.78 | 74.34 ± 1.25 | 71.23 ± 1.12 | – |
| SAG-Pool (2STG) | 78.32 ± 1.26 | 79.63 ± 0.95 | 78.03 ± 0.68 | 63.83 ± 0.83 | 77.52 ± 0.54 | 71.73 ± 0.81 | 1.48 |
| SAG-Pool (2STG+) | 78.22 ± 0.70 | 79.03 ± 0.89 | 77.03 ± 0.63 | 64.34 ± 0.86 | 76.23 ± 1.12 | 72.36 ± 0.73 | 1.24 |
Table 3: Graph classification test accuracy (%) on the NYC Taxi datasets (G = Green, Y = Yellow taxi operator).

| Method | Jan. G. | Feb. G. | Mar. G. | Jan. Y. | Feb. Y. | Mar. Y. | Average Gain (% points) |
|---|---|---|---|---|---|---|---|
| GraphSage | 73.14 ± 0.62 | 66.35 ± 1.25 | 64.63 ± 0.83 | 72.86 ± 0.92 | 64.37 ± 0.87 | 68.12 ± 0.76 | – |
| GraphSage (2STG) | 76.14 ± 0.93 | 66.67 ± 1.31 | 67.13 ± 0.85 | 75.24 ± 1.16 | 65.43 ± 0.68 | 70.15 ± 0.64 | 1.88 |
| GraphSage (2STG+) | 76.63 ± 0.82 | 67.74 ± 0.88 | 68.95 ± 1.41 | 75.21 ± 1.70 | 67.64 ± 0.73 | 70.23 ± 1.25 | 2.82 |
| GAT | 71.26 ± 1.51 | 67.82 ± 0.77 | 66.13 ± 0.72 | 72.64 ± 0.54 | 64.76 ± 0.73 | 67.51 ± 1.69 | – |
| GAT (2STG) | 75.23 ± 0.82 | 67.24 ± 0.56 | 67.34 ± 0.71 | 76.82 ± 1.23 | 66.45 ± 0.85 | 70.66 ± 0.78 | 2.27 |
| GAT (2STG+) | 74.65 ± 0.98 | 68.11 ± 0.69 | 69.15 ± 1.37 | 74.79 ± 1.27 | 68.75 ± 0.66 | 70.44 ± 0.93 | 3.04 |
| DiffPool | 78.43 ± 0.74 | 73.12 ± 0.42 | 71.39 ± 1.56 | 72.52 ± 1.23 | 67.43 ± 0.87 | 74.34 ± 0.77 | – |
| DiffPool (2STG) | 80.28 ± 1.16 | 75.69 ± 1.21 | 73.79 ± 0.81 | 75.09 ± 0.72 | 68.19 ± 0.50 | 74.87 ± 0.83 | 1.78 |
| DiffPool (2STG+) | 79.63 ± 0.82 | 74.56 ± 1.32 | 72.92 ± 0.65 | 75.95 ± 1.21 | 69.31 ± 0.97 | 75.76 ± 0.86 | 1.81 |
| EigenGCN | 75.45 ± 0.44 | 69.32 ± 1.82 | 72.21 ± 0.83 | 73.21 ± 1.35 | 69.64 ± 0.76 | 69.52 ± 1.54 | – |
| EigenGCN (2STG) | 77.14 ± 0.81 | 70.03 ± 0.62 | 74.12 ± 1.34 | 74.36 ± 1.65 | 69.72 ± 0.97 | 70.03 ± 0.86 | 1.02 |
| EigenGCN (2STG+) | 76.73 ± 1.21 | 71.27 ± 1.33 | 73.37 ± 1.85 | 75.33 ± 1.14 | 71.65 ± 1.67 | 71.84 ± 0.62 | 1.80 |
| SAG-Pool | 73.23 ± 0.59 | 67.46 ± 0.73 | 72.78 ± 1.34 | 72.65 ± 0.72 | 68.83 ± 1.25 | 69.68 ± 1.35 | – |
| SAG-Pool (2STG) | 76.36 ± 1.37 | 69.07 ± 1.48 | 74.34 ± 1.52 | 71.11 ± 0.73 | 70.02 ± 0.64 | 70.04 ± 1.48 | 1.05 |
| SAG-Pool (2STG+) | 75.38 ± 0.86 | 69.27 ± 1.12 | 73.19 ± 1.34 | 72.51 ± 0.85 | 69.16 ± 0.79 | 70.59 ± 0.52 | 0.91 |
Table 4: Comparison of 2STG+ with the transfer-learning method of [hlugzlplstrategies] on the benchmark datasets (test accuracy, %).

| Method | D&D | IMDB-BINARY | MUTAG | PROTEINS | Mutagenicity | PTC-FM | Average Gain (% points) |
|---|---|---|---|---|---|---|---|
| GAT (2STG+) | 74.13 ± 1.47 | 67.25 ± 0.89 | 78.17 ± 1.41 | 72.64 ± 0.58 | 76.49 ± 1.23 | 61.61 ± 0.53 | 1.56 |
| GAT [hlugzlplstrategies] | 72.24 ± 0.83 | 65.16 ± 1.47 | 76.86 ± 1.35 | 71.76 ± 0.77 | 75.59 ± 1.48 | 61.28 ± 0.97 | – |
| GraphSage (2STG+) | 76.52 ± 1.47 | 68.24 ± 0.83 | 81.14 ± 0.68 | 72.34 ± 0.56 | 77.71 ± 0.41 | 62.65 ± 0.72 | 0.31 |
| GraphSage [hlugzlplstrategies] | 75.26 ± 1.36 | 67.14 ± 0.52 | 82.43 ± 1.49 | 72.15 ± 0.83 | 76.83 ± 0.95 | 62.61 ± 0.78 | – |
Table 5: Comparison of 2STG+ with the transfer-learning method of [hlugzlplstrategies] on the NYC Taxi datasets (test accuracy, %).

| Method | Jan. G. | Feb. G. | Mar. G. | Jan. Y. | Feb. Y. | Mar. Y. | Average Gain (% points) |
|---|---|---|---|---|---|---|---|
| GAT (2STG+) | 74.65 ± 0.98 | 68.11 ± 0.69 | 69.15 ± 1.37 | 74.79 ± 1.27 | 68.75 ± 0.66 | 70.44 ± 0.93 | 1.97 |
| GAT [hlugzlplstrategies] | 73.87 ± 1.13 | 65.89 ± 1.04 | 67.25 ± 1.27 | 71.87 ± 1.35 | 66.24 ± 0.92 | 68.95 ± 1.25 | – |
| GraphSage (2STG+) | 76.63 ± 0.82 | 67.74 ± 0.88 | 68.95 ± 1.41 | 75.21 ± 1.70 | 67.64 ± 0.73 | 70.23 ± 1.25 | 0.53 |
| GraphSage [hlugzlplstrategies] | 75.19 ± 0.98 | 67.82 ± 0.43 | 68.03 ± 0.66 | 73.66 ± 1.27 | 68.24 ± 0.79 | 70.25 ± 0.73 | – |
According to the results, 2STG and 2STG+ yield better accuracy than the original setting of each GNN. This observation suggests two possible explanations:
The end-to-end training methods fail to realize the full potential of the GNN models. Even when the final classifier is upgraded from a single fully-connected layer to an MLP, the accuracy is not as high as in 2STG and 2STG+.
Learning intermediate embeddings that are well separated by class (see Figure 3), for example through metric learning as in our methods, facilitates better accuracy of the final classifier.
4.4.2 Comparison with a transfer learning method
We compare the pre-trained GraphSage/GAT models of [hlugzlplstrategies] with GraphSage/GAT of the same architecture (5 layers, 300-dimensional hidden units, and global mean pooling) trained with 2STG+. Note that the pre-trained models are fine-tuned on the considered datasets. The results of this comparison are in Tables 4 and 5, where the average and standard deviation of test accuracy over 5 splits are reported. Despite pre-training on a much smaller dataset with only one task, our method achieves better accuracy in 83% of the considered cases: by up to 2% points on the benchmark datasets and up to 3% points on the Taxi datasets. This validates our claims in Table 1.
4.4.3 Running time
For each hyperparameter setting in a dataset, the original setting takes up to half an hour to train a model. Due to having two stages, 2STG and 2STG+ take up to an hour for both stages. In [hlugzlplstrategies], the transfer-learning method was reported to take up to one day to pre-train on a rich dataset.
5 Conclusions
Graph neural networks (GNNs) are powerful tools for many graph mining tasks, including graph classification, which our work focuses on. However, training them end-to-end to predict class probabilities often fails to realize their full capability. Thus, we plug GNN models into a triplet framework to learn discriminative embeddings first, and then train a classifier on those embeddings. Extensive experiments on 12 datasets lead to the following observations:
End-to-end training often fails to realize the full potential of GNN models. Applying GNN models in our method enhances their accuracy by up to 5.4% points.
Our two-stage training method leads to better accuracy than a state-of-the-art pre-training method based on transfer-learning in 83% of the considered cases.
Despite not requiring any additional massive, rich datasets or long training time, our training method consistently improves the accuracy of all 5 tested GNN models on all 12 considered datasets.
Appendix A Datasets
We tested our training method using 12 datasets:
a.1 Benchmark datasets
These are the commonly tested binary-class datasets [morris2020tudataset] for the graph classification task: DD, MUTAG, Mutagenicity, PROTEINS, PTC-FM, and IMDB-BINARY.
a.2 New York City Taxi
We extracted taxi ridership data in 2019 from the New York City (NYC) Taxi Commission. Areas in New York are represented as nodes, and each taxi trip is an edge connecting its source and destination nodes. All taxi trips in a 1-hour interval form a graph, and each dataset spans a month of taxi operations. We assigned a binary label to each graph: weekday (Mon-Thu) vs. weekend (Fri-Sun) trips. We considered two taxi operators (Yellow and Green) and processed data from January, February, and March of 2019, yielding 6 datasets in total.
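The construction above can be sketched as follows. The record layout (timestamp, pickup zone, dropoff zone) and zone names are illustrative, not the exact TLC schema:

```python
from collections import defaultdict
from datetime import datetime

def trips_to_hourly_graphs(trips):
    """Bucket trips into one graph per 1-hour window. Nodes are NYC zones;
    each trip contributes a directed edge (pickup_zone -> dropoff_zone)."""
    edges = defaultdict(list)
    for ts, src, dst in trips:
        hour = ts.replace(minute=0, second=0, microsecond=0)
        edges[hour].append((src, dst))
    # Binary label as in the text: weekday (Mon-Thu) = 0, weekend (Fri-Sun) = 1.
    return {hour: (es, 0 if hour.weekday() <= 3 else 1)
            for hour, es in edges.items()}

trips = [
    (datetime(2019, 1, 7, 8, 15), "Midtown", "JFK"),     # Monday
    (datetime(2019, 1, 7, 8, 45), "SoHo", "Harlem"),     # same 1-hour window
    (datetime(2019, 1, 11, 23, 5), "JFK", "Midtown"),    # Friday
]
graphs = trips_to_hourly_graphs(trips)
print(len(graphs))  # 2 hourly graphs
```

Each resulting (edge list, label) pair is one labeled graph instance; a month of such windows forms one dataset.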
Appendix B GNN architectures
In order to demonstrate that our two-stage training method helps realize better performance of GNNs, for each GNN architecture, we compare the accuracy obtained in the original setting with that obtained using our method. The GNN architectures we considered in this work are:
GraphSage [hamiltoninductive]: This is often used as a strong baseline in graph classification. After obtaining node embeddings, global mean/max pooling is applied to combine all node embeddings into one graph embedding.
GAT [velickovic2018graph]: Instead of uniformly passing neighbor information into a node embedding, [velickovic2018graph] employs an attention mechanism for the importance of each neighboring node.
Diff-pool [zhitao2018hierarchical]: While using the same aggregation mechanism as [hamiltoninductive], [zhitao2018hierarchical] proposes a hierarchical approach to pool the node embeddings. Rather than a “flat-pooling” step at the end, diff-pool repeatedly merges nodes into “supernodes” until there is only 1 supernode whose embedding is treated as the graph embedding.
Eigen-GCN [yao2019graph]: Attempting to implement hierarchical pooling like [zhitao2018hierarchical], [yao2019graph]
formulates a different way to combine nodes and their respective embeddings making use of the eigenvectors of the Laplacian matrix.
SAG-Pool [lee2019self]: Hierarchical graph pooling employing self-attention mechanisms.
Appendix C Hyperparameter search
For each GNN, the hyperparameters regarding the network architecture were tuned in the same search space for the three settings: original, 2STG, and 2STG+. The search space for the dimensions of the input vector, hidden vector and output vector for all GNNs was . For Diff-pool, we used three layers of graph convolution and one DIFFPOOL layer, as described in the original paper. For Eigen-GCN, we used three pooling operators, as this was shown to achieve the best performance in the original paper. For SAG-Pool, we used three pooling layers, as explained in the original paper. Other hyperparameters that are exclusive to each GNN architecture were set to the default values provided in the original code released by each architecture's authors.
In all three settings, the architecture of the final classifier was also tuned. The number of fully-connected layers was up to 3 while the search space for the hidden dimension was where is the dimension of the output vector.
The two settings 2STG and 2STG+ require an additional hyperparameter, the margin $\alpha$. While [schroff2015facenet] found $\alpha = 0.2$ to be effective, we empirically found that this value is too small to separate instances of different classes, so we used a larger search space for $\alpha$.