Graph Neural Networks (GNNs) have become increasingly popular since much real-world relational data can be represented as graphs, such as social networks Bian et al. (2020), molecules Gilmer et al. (2017), and financial data Yang et al. (2020). Following a message passing paradigm to learn node representations, GNNs have achieved state-of-the-art performance in node classification, graph classification, and link prediction Kipf and Welling (2017); Veličković et al. (2017); Xu et al. (2019). Despite the remarkable effectiveness of GNNs, explaining their predictions remains a challenging open problem. Without understanding the rationales behind the predictions, these black-box models cannot be fully trusted or widely applied in critical areas such as medical diagnosis. In addition, model explanations can facilitate model debugging and error analysis. These considerations indicate the necessity of investigating the explainability of GNNs.
Recently, extensive efforts have been made to study explanation techniques for GNNs Yuan et al. (2020b). These methods explain the node or graph classification predictions of trained GNNs with different strategies. For example, GNNExplainer Ying et al. (2019) and PGExplainer Luo et al. (2020) select, as the explanation, a compact subgraph structure that maximizes the mutual information with the GNN's predictions. PGM-Explainer Vu and Thai (2020) first obtains a local dataset by random node feature perturbation. It then employs an interpretable Bayesian network to fit the local dataset and to explain the predictions of the original GNN model. In addition, XGNN Yuan et al. (2020a) generates graph patterns that maximize the predicted probability for a certain class and provides model-level explanations. Despite the tremendous developments in the interpretation of GNNs, most existing approaches can be classified as post-hoc explanations, where another explanatory model is used to provide explanations for a trained GNN. Post-hoc explanations can be inaccurate or incomplete in revealing the actual reasoning process of the original model Rudin (2018). Therefore, it is more desirable to build models with inherent interpretability, where the explanations are generated by the models themselves.
We leverage the concept of prototype learning to construct GNNs with built-in interpretability (i.e., self-explaining GNNs). In contrast to post-hoc explanation methods, the explanations generated by self-explaining GNNs are actually used during classification rather than produced after the fact. Prototype learning is a form of case-based reasoning Kolodner (1992); Schmidt et al. (2001), which makes predictions for new instances by comparing them with several learned exemplar cases (i.e., prototypes). It is a natural practice in solving problems with graph-structured data. For example, chemists identify potential drug candidates based on known functional groups (i.e., key subgraphs) in molecular graphs He et al. (2010); Zhang et al. (2021). Prototype learning provides better interpretability by imitating such a human problem-solving process. Recently, the concept of prototypes has been incorporated into convolutional neural networks to build interpretable image classifiers Chen et al. (2018); Rymarczyk et al. (2021). However, prototype learning has not yet been explored for explaining GNNs.
Building self-explaining GNNs based on prototype learning brings unique challenges. First, the discreteness of the edges makes the projection and visualization of the graph prototypes difficult. Second, the combinatorial nature of graph structure makes it hard to build self-explaining models with both efficiency and high accuracy for graph modeling.
In this paper, we tackle the aforementioned challenges and propose Prototype Graph Neural Network (ProtGNN), which provides a new perspective on the explanations of GNNs. Specifically, various popular GNN architectures can be employed as the graph encoder in ProtGNN. Prediction on a new input graph is performed based on its similarity to the prototypes in the prototype layer. Furthermore, we propose to employ the Monte Carlo tree search algorithm Silver et al. (2017) to efficiently explore different subgraphs for prototype projection and visualization. In addition, in ProtGNN+, we design a conditional subgraph sampling module to identify which part of the input graph is most similar to each prototype, for better interpretability and efficiency. Finally, extensive experiments on several real-world datasets show that ProtGNN/ProtGNN+ provide built-in interpretability while achieving performance comparable to their non-interpretable counterparts.
Graph Neural Networks
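Graph neural networks have demonstrated their effectiveness on various graph tasks. Let $G = (\mathcal{V}, \mathcal{E})$ denote a graph with node attributes $\mathbf{x}_i$ for $v_i \in \mathcal{V}$ and a set of edges $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}$. GNNs leverage the graph connectivity as well as node and edge features to learn a representation vector (i.e., embedding) $\mathbf{h}_i$ for each node $v_i \in \mathcal{V}$ or a vector $\mathbf{h}_G$ for the entire graph $G$. Generally, GNNs follow a message passing paradigm, in which the representation of node $v_i$ is iteratively updated by aggregating the representations of $v_i$'s neighboring nodes $\mathcal{N}(v_i)$. Here we use Graph Convolutional Network (GCN) Kipf and Welling (2017) as an example to illustrate such message passing procedures:

$$\mathbf{h}_i^{(\ell)} = \sigma\Big(\sum_{j \in \mathcal{N}(v_i) \cup \{v_i\}} \hat{A}_{ij}\, \mathbf{h}_j^{(\ell-1)} W^{(\ell)}\Big), \qquad (1)$$

where $\mathbf{h}_i^{(\ell)}$ is the representation vector of node $v_i$ at the $\ell$-th layer and $\hat{A} = \hat{D}^{-1/2}\tilde{A}\hat{D}^{-1/2}$ is the normalized adjacency matrix. $\tilde{A} = A + I$ is the adjacency matrix of the graph with self connections added, and $\hat{D}$ is a diagonal matrix with $\hat{D}_{ii} = \sum_j \tilde{A}_{ij}$. $\sigma(\cdot)$ in Eq. (1) is the ReLU function and $W^{(\ell)}$ is the trainable weight matrix of the $\ell$-th layer.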
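To make Eq. (1) concrete, here is a minimal pure-Python sketch of one GCN layer on a dense adjacency matrix. The helper names (`normalized_adjacency`, `matmul`, `gcn_layer`) are our own; a practical implementation would use a tensor library and sparse operations. The sketch only illustrates the normalization $\hat{D}^{-1/2}\tilde{A}\hat{D}^{-1/2}$ followed by the weight matrix and ReLU.

```python
import math

def normalized_adjacency(adj):
    """Return A_hat = D^{-1/2} (A + I) D^{-1/2} for a dense 0/1 adjacency matrix."""
    n = len(adj)
    # A_tilde = A + I adds a self-connection to every node
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # D_hat_ii = sum_j A_tilde_ij (always >= 1 thanks to the self-loop)
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_tilde]
    return [[inv_sqrt_deg[i] * a_tilde[i][j] * inv_sqrt_deg[j] for j in range(n)]
            for i in range(n)]

def matmul(x, y):
    """Plain list-of-lists matrix product."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]

def gcn_layer(adj, h, weight):
    """One GCN layer, Eq. (1): H_out = ReLU(A_hat @ H @ W); row i holds node v_i."""
    pre_activation = matmul(matmul(normalized_adjacency(adj), h), weight)
    return [[max(0.0, v) for v in row] for row in pre_activation]

# Usage: a 3-node path graph v0 - v1 - v2 with one-hot features and a 3x1 weight
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
h0 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
w = [[1.0], [1.0], [1.0]]
h1 = gcn_layer(adj, h0, w)
```

The end nodes have degree 2 and the middle node degree 3 after adding self-loops, so each output entry is a degree-normalized sum over the node's closed neighborhood.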
Explainability in Graph Neural Networks
As the application of GNNs grows, understanding why GNNs make their predictions becomes increasingly critical. Recently, research on the explainability of GNNs has developed rapidly. As suggested by a recent survey Yuan et al. (2020b), existing methods for explaining GNNs can be categorized into several classes: gradients/features-based methods Baldassarre and Azizpour (2019); Pope et al. (2019), perturbation-based methods Ying et al. (2019); Luo et al. (2020); Yuan et al. (2021); Schlichtkrull et al. (2020), decomposition methods Schwarzenberg et al. (2019); Schnake et al. (2020), and surrogate methods Vu and Thai (2020); Huang et al. (2020).
Specifically, gradients/features-based methods employ the gradients or the feature values to indicate the importance of different input features. These methods simply adapt existing explanation techniques from the image domain to the graph domain without incorporating the properties of graph data. Perturbation-based methods monitor the changes in the predictions when different input features are perturbed and identify the most influential features. Decomposition methods explain GNNs by decomposing the original model predictions into several terms and associating these terms with graph nodes or edges. Given an input example, surrogate methods first sample a dataset from the neighborhood of the given example and then fit a simple and interpretable model, e.g., a decision tree, to the sampled dataset. The surrogate models are usually easier to interpret, shedding light on the inner workings of more complex models.
However, all the above methods are post-hoc explanation methods. Compared with post-hoc explanation methods, built-in interpretability Chen et al. (2018); Ming et al. (2019) is more desirable since post-hoc explanations usually do not fit the original model precisely Rudin (2018). Therefore, it is necessary to build models with inherent interpretability and high accuracy.
The Proposed ProtGNN
In this section, we introduce the architecture of ProtGNN/ProtGNN+, formulate the learning objective, and describe the training procedures.
We let $\mathcal{D} = \{(G_1, y_1), \ldots, (G_n, y_n)\}$ be a labeled training dataset, where $G_i$ is an input attributed graph and $y_i \in \{1, \ldots, C\}$ is its label. We aim to learn representative prototypical graph patterns that can serve as classification references and analogical explanations. For a new input graph, its similarity to each prototype is measured in the latent space. The prediction for the new instance can then be derived from, and explained by, the prototype graph patterns it is most similar to.
In Figure 1, we show an overview of the architecture of our proposed ProtGNN. The network consists of three key components: a graph encoder $f$, a prototype layer $g_{\mathbf{p}}$, and a fully connected layer $\phi$ followed by softmax to output the class probabilities in multi-class classification tasks.
For a given input graph $G_i$, the graph encoder $f$ maps the whole graph into a single graph embedding $\mathbf{h}_i = f(G_i)$ of fixed length. The encoder can be any backbone GNN, e.g., GCN, GAT, or GIN. The graph embedding $\mathbf{h}_i$ can be obtained by sum or max pooling over the node embeddings of the last GNN layer.
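A minimal sketch of the sum and max readouts, assuming node embeddings are given as plain Python lists (the function name `readout` is hypothetical). Both poolings ignore node order, which is what lets a graph of any size map to a fixed-length embedding $\mathbf{h}_i$.

```python
def readout(node_embeddings, mode="sum"):
    """Pool a list of equal-length node vectors into one graph embedding."""
    if not node_embeddings:
        raise ValueError("graph has no nodes")
    columns = list(zip(*node_embeddings))  # one tuple per embedding dimension
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling mode: {mode!r}")

nodes = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
graph_sum = readout(nodes)         # [4.5, 1.5]
graph_max = readout(nodes, "max")  # [3.0, 2.0]
```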
In the prototype layer, we allocate a pre-determined number of prototypes to each class. In the final trained ProtGNN, each class is represented by a set of learned prototypes, which should capture the graph patterns most relevant for identifying graphs of that class. For each input graph $G_i$ and its embedding $\mathbf{h}_i$, the prototype layer computes the similarity scores:

$$\mathrm{sim}(\mathbf{p}_j, \mathbf{h}_i) = \log\left(\frac{\|\mathbf{p}_j - \mathbf{h}_i\|_2^2 + 1}{\|\mathbf{p}_j - \mathbf{h}_i\|_2^2 + \epsilon}\right), \qquad (2)$$

where $\mathbf{p}_j$ is the $j$-th prototype, with the same dimension as the graph embedding $\mathbf{h}_i$. The similarity function is designed to be monotonically decreasing with respect to $\|\mathbf{p}_j - \mathbf{h}_i\|_2$ and always greater than zero. In experiments, $\epsilon$ is set to a small value, e.g., 1e-4. Finally, given the similarity scores, the fully connected layer with softmax computes the output probabilities for each class.
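The sketch below evaluates the similarity of Eq. (2) and passes the scores through a fully connected layer and softmax. It is illustrative only: the function names are hypothetical, and the 1/0 weight matrix in the usage lines simply mirrors the last-layer setting assumed in Theorem 1, not trained weights.

```python
import math

def similarity(p, h, eps=1e-4):
    """Eq. (2): log((||p - h||^2 + 1) / (||p - h||^2 + eps)); positive, decreasing in distance."""
    d2 = sum((pi - hi) ** 2 for pi, hi in zip(p, h))
    return math.log((d2 + 1.0) / (d2 + eps))

def class_probabilities(h, prototypes, weights, eps=1e-4):
    """Prototype layer -> fully connected layer -> softmax.

    prototypes: list of M vectors; weights: C x M list, weights[k][j] links prototype j to class k.
    """
    scores = [similarity(p, h, eps) for p in prototypes]
    logits = [sum(w * s for w, s in zip(row, scores)) for row in weights]
    top = max(logits)  # shift by the max logit for a numerically stable softmax
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Usage: one prototype per class, weight 1 to its own class logit and 0 elsewhere
prototypes = [[0.0, 0.0], [3.0, 3.0]]
weights = [[1.0, 0.0], [0.0, 1.0]]
probs = class_probabilities([0.2, -0.1], prototypes, weights)
```

Because $\epsilon < 1$, the ratio inside the logarithm always exceeds 1, so every score is positive; an embedding sitting exactly on a prototype gets the maximum score $\log(1/\epsilon)$.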
Our goal is to learn a ProtGNN with both accuracy and inherent interpretability. For accuracy, we minimize the cross-entropy loss on the training dataset: $\frac{1}{n}\sum_{i=1}^{n}\mathrm{CrsEnt}\big(\phi \circ g_{\mathbf{p}} \circ f(G_i), y_i\big)$. For better interpretability, we consider several constraints in constructing prototypes for the explanation. Firstly, the cluster cost (Clst) encourages each graph embedding to be close to at least one prototype of its own class. Secondly, the separation cost (Sep) encourages each graph embedding to stay far away from prototypes not of its class. Finally, we found in experiments that some learned prototypes are very close to each other in the latent space. We encourage the diversity of the learned prototypes by adding the diversity loss (Div), which penalizes prototypes that are too close to each other.
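To sum up, the objective function we aim to minimize is

$$\frac{1}{n}\sum_{i=1}^{n}\mathrm{CrsEnt}\big(\phi \circ g_{\mathbf{p}} \circ f(G_i),\, y_i\big) + \lambda_1 \mathrm{Clst} + \lambda_2 \mathrm{Sep} + \lambda_3 \mathrm{Div}, \qquad (3)$$

$$\mathrm{Clst} = \frac{1}{n}\sum_{i=1}^{n}\ \min_{j:\, \mathbf{p}_j \in \mathbf{P}_{y_i}} \|f(G_i) - \mathbf{p}_j\|_2^2, \qquad (4)$$

$$\mathrm{Sep} = -\frac{1}{n}\sum_{i=1}^{n}\ \min_{j:\, \mathbf{p}_j \notin \mathbf{P}_{y_i}} \|f(G_i) - \mathbf{p}_j\|_2^2, \qquad (5)$$

$$\mathrm{Div} = \sum_{k=1}^{C}\ \sum_{\substack{i \neq j \\ \mathbf{p}_i, \mathbf{p}_j \in \mathbf{P}_k}} \max\big(0,\ \cos(\mathbf{p}_i, \mathbf{p}_j) - s_{\max}\big), \qquad (6)$$

where $\lambda_1$, $\lambda_2$, and $\lambda_3$ are hyper-parameters controlling the weights of the losses, and $\mathbf{P}_{y_i}$ is the set of prototypes belonging to class $y_i$. $s_{\max}$ is the threshold of the cosine similarity, measured by $\cos(\cdot, \cdot)$, in the diversity loss.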
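As a rough illustration of Eqs. (4) to (6), the following pure-Python sketch computes the cluster, separation, and diversity terms for embeddings and prototypes stored as lists. It assumes every class owns at least one prototype, that there are at least two classes, and that no prototype is the zero vector (where cosine similarity is undefined); all helper names are hypothetical.

```python
import math

def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def clst_sep(embeddings, labels, prototypes, proto_class):
    """Cluster cost (Eq. 4) and separation cost (Eq. 5) averaged over n graphs."""
    clst = sep = 0.0
    for h, y in zip(embeddings, labels):
        own = [sq_dist(h, p) for p, c in zip(prototypes, proto_class) if c == y]
        other = [sq_dist(h, p) for p, c in zip(prototypes, proto_class) if c != y]
        clst += min(own)    # pull toward the closest prototype of the true class
        sep -= min(other)   # push away from the closest prototype of other classes
    n = len(embeddings)
    return clst / n, sep / n

def diversity(prototypes, proto_class, s_max=0.3):
    """Eq. (6): hinge on cosine similarity over ordered same-class pairs i != j."""
    loss = 0.0
    for i, (pi, ci) in enumerate(zip(prototypes, proto_class)):
        for j, (pj, cj) in enumerate(zip(prototypes, proto_class)):
            if i != j and ci == cj:
                loss += max(0.0, cosine(pi, pj) - s_max)
    return loss
```

Note that Sep enters the objective with a negative sign, so minimizing the total loss increases the distance to the nearest wrong-class prototype.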
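The learned prototypes are embedding vectors that are not directly interpretable. For better interpretation and visualization, we design a projection procedure performed during the training stage. Specifically, we project each prototype $\mathbf{p}_j$ ($\mathbf{p}_j \in \mathbf{P}_k$) onto the nearest latent training subgraph from the same class $k$ as that of $\mathbf{p}_j$ (see Eq. (7)). In this way, we can conceptually equate each prototype with a subgraph, which is more intuitive and human-intelligible. To reduce the computational cost, the projection step is only performed every few training epochs:

$$\mathbf{p}_j \leftarrow \arg\min_{\mathbf{z} \in \mathcal{Z}_k} \|\mathbf{z} - \mathbf{p}_j\|_2, \qquad \mathcal{Z}_k = \{ f(G') : G' \in \mathrm{Subgraph}(G_i),\ (G_i, y_i) \in \mathcal{D},\ y_i = k \}. \qquad (7)$$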
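For intuition about Eq. (7), the sketch below performs the projection by brute force on tiny graphs: it enumerates connected node subsets up to `max_size`, embeds each induced subgraph, and keeps the candidate nearest to the prototype. The encoder is replaced by a hypothetical `toy_embed` (node and edge counts). This exhaustive enumeration is exactly what becomes infeasible on realistic graphs.

```python
import itertools
import math

def is_connected(nodes, adj):
    """True if the subgraph induced by `nodes` is connected (adj: node -> set of neighbors)."""
    nodes = set(nodes)
    start = next(iter(nodes))
    seen, stack = {start}, [start]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v in nodes and v not in seen:
                seen.add(v)
                stack.append(v)
    return seen == nodes

def toy_embed(adj, nodes):
    """Hypothetical stand-in for f: [number of nodes, number of edges] of the induced subgraph."""
    s = set(nodes)
    n_edges = sum(1 for u in s for v in adj[u] if v in s) // 2
    return [float(len(s)), float(n_edges)]

def project_prototype(p, graphs, embed, max_size):
    """Brute-force Eq. (7): nearest embedded connected subgraph over graphs of p's class."""
    best, best_dist = None, math.inf
    for g_idx, adj in enumerate(graphs):
        for size in range(1, max_size + 1):
            for nodes in itertools.combinations(sorted(adj), size):
                if not is_connected(nodes, adj):
                    continue
                z = embed(adj, nodes)
                dist = math.dist(z, p)
                if dist < best_dist:
                    best, best_dist = (g_idx, nodes, z), dist
    return best, best_dist  # the projected prototype is best[2]

path = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
(best_graph, best_nodes, new_prototype), d = project_prototype([2.0, 1.0], [path], toy_embed, max_size=3)
```

The number of candidate node subsets grows combinatorially with graph size, which motivates replacing the enumeration with a guided search.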
Unlike grid-like data such as images, the combinatorial nature of graphs makes it unrealistic to find the nearest subgraph by enumeration Chen et al. (2018). In graph prototype projection, we employ the Monte Carlo tree search algorithm (MCTS) Silver et al. (2017) as the search algorithm to guide our subgraph explorations (see Figure 2). We build a search tree in which the root is associated with the input graph and each of the other nodes corresponds to an explored subgraph. Formally, we denote each node in the search tree as $N_i$, with $N_0$ denoting the root node. The edges in the search tree represent pruning actions: the graph associated with a child node is obtained by pruning nodes from the graph corresponding to its parent node. To limit the search space, we add two constraints: each explored subgraph has to be connected, and the size of the projected subgraph should be small.
During the search process, the MCTS algorithm records the statistics of visit counts and rewards to guide the exploration and reduce the search space. Specifically, for the node and pruning action pair $(N_i, a_j)$, we assume that the subgraph $G_j$ is obtained by action $a_j$ from $G_i$. The MCTS algorithm records four variables for $(N_i, a_j)$:
$C(N_i, a_j)$ denotes the number of times action $a_j$ has been selected at node $N_i$.
$W(N_i, a_j)$ is the total reward over all visits of $(N_i, a_j)$.
$Q(N_i, a_j)$ is the average reward over multiple visits, i.e., $Q(N_i, a_j) = W(N_i, a_j) / C(N_i, a_j)$.
$R(N_i, a_j)$ is the immediate reward for selecting $a_j$ at $N_i$, which in this paper is measured by the similarity $\mathrm{sim}(\mathbf{p}, \mathbf{h}_{G_j})$ between the prototype $\mathbf{p}$ being projected and the subgraph embedding $\mathbf{h}_{G_j}$. The subgraph embedding is obtained by encoding the subgraph with the GNN encoder $f$, i.e., $\mathbf{h}_{G_j} = f(G_j)$.
Guided by these statistics, MCTS searches for the nearest subgraphs over multiple iterations. Each iteration consists of two phases. In the forward pass, MCTS selects a path from the root $N_0$ to a leaf node $N_\ell$. To keep subgraphs connected, we choose to prune peripheral nodes with minimum degrees. A leaf node is defined based on the number of nodes in its subgraph, such that $|N_\ell| \le N_{\min}$. The action selection criterion at node $N_i$ is:

$$a^{*} = \arg\max_{a_j}\ Q(N_i, a_j) + U(N_i, a_j), \qquad (8)$$

$$U(N_i, a_j) = \lambda\, R(N_i, a_j)\, \frac{\sqrt{\sum_{j'} C(N_i, a_{j'})}}{1 + C(N_i, a_j)}, \qquad (9)$$

where $\lambda$ is a hyper-parameter controlling the trade-off between exploration and exploitation. The strategy initially prefers child nodes with low visit counts, exploring different pruning actions, but asymptotically prefers actions leading to higher similarity scores.
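A minimal sketch of the forward-pass ingredients, under stated assumptions: `pruning_actions` is one concrete reading of "prune peripheral nodes with minimum degrees" (minimum-degree nodes whose removal keeps the subgraph connected), and `select_action` scores children with Eqs. (8) and (9). Both names are hypothetical, and the statistics are passed as parallel lists for brevity.

```python
import math

def is_connected(nodes, adj):
    """True if the subgraph induced by `nodes` is non-empty and connected."""
    nodes = set(nodes)
    if not nodes:
        return False
    start = next(iter(nodes))
    seen, stack = {start}, [start]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v in nodes and v not in seen:
                seen.add(v)
                stack.append(v)
    return seen == nodes

def pruning_actions(nodes, adj):
    """Minimum-degree nodes whose removal leaves the current subgraph connected."""
    nodes = set(nodes)
    degree = {u: sum(1 for v in adj[u] if v in nodes) for u in nodes}
    safe = [u for u in nodes if is_connected(nodes - {u}, adj)]
    if not safe:
        return []
    d_min = min(degree[u] for u in safe)
    return sorted(u for u in safe if degree[u] == d_min)

def select_action(counts, q_values, rewards, lam=5.0):
    """Eqs. (8)-(9): argmax_j Q_j + lam * R_j * sqrt(sum_j' C_j') / (1 + C_j)."""
    total_visits = sum(counts)
    scores = [q + lam * r * math.sqrt(total_visits) / (1 + c)
              for c, q, r in zip(counts, q_values, rewards)]
    return max(range(len(scores)), key=scores.__getitem__)  # ties -> lowest index
```

With identical statistics apart from visit counts, an unvisited action wins (exploration); once two actions have been visited equally often, the one with the higher average reward wins (exploitation).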
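In the backward pass, the statistics of all node and action pairs selected along this path are updated:

$$C(N_i, a_j) \leftarrow C(N_i, a_j) + 1, \qquad (10)$$

$$W(N_i, a_j) \leftarrow W(N_i, a_j) + \mathrm{sim}(\mathbf{p}, \mathbf{h}_{N_\ell}), \qquad (11)$$

where $\mathbf{h}_{N_\ell}$ is the embedding of the subgraph associated with the leaf node $N_\ell$. In the end, we select the subgraph with the highest similarity score among all expanded nodes as the new projected prototype.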
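The backward pass of Eqs. (10) and (11) only touches the pairs on the selected path. The sketch below keeps the MCTS statistics in a small dataclass, derives $Q = W/C$ on demand, and picks the final projected subgraph among the expanded nodes. `ActionStats`, `backpropagate`, and `best_subgraph` are hypothetical names, not the authors' code.

```python
from dataclasses import dataclass

@dataclass
class ActionStats:
    """Statistics for one (N_i, a_j) pair."""
    count: int = 0       # C(N_i, a_j)
    total: float = 0.0   # W(N_i, a_j)
    reward: float = 0.0  # R(N_i, a_j), set when the child subgraph is expanded

    @property
    def q(self) -> float:
        """Q(N_i, a_j) = W / C, taken as 0 before the first visit."""
        return self.total / self.count if self.count else 0.0

def backpropagate(path, leaf_similarity):
    """Eqs. (10)-(11): update every (node, action) pair on the selected path."""
    for stats in path:
        stats.count += 1
        stats.total += leaf_similarity

def best_subgraph(expanded):
    """expanded: list of (node_set, similarity) pairs collected during the search."""
    return max(expanded, key=lambda item: item[1])[0]
```

Every pair on the path receives the same leaf similarity, so pairs that lead to high-similarity leaves accumulate a higher average $Q$ over time.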
Conditional Subgraph Sampling module
We further propose ProtGNN+ with a novel conditional subgraph sampling module to provide better interpretation. In ProtGNN+, we not only show the similarity scores to prototypes, but also identify which part of the input graph is most similar to each prototype as part of the reasoning process. In Figure 1, the conditional subgraph sampling module outputs a different subgraph embedding for each prototype. While this task can also be accomplished by MCTS, its time complexity grows exponentially with the graph size, and it is difficult to parallelize and to generalize across graphs, which makes MCTS an undesirable choice. Instead, we adopt a parameterized method for efficient selection of similar subgraphs conditioned on given prototypes.
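Formally, we let $a_{ij} \in \{0, 1\}$ be the binary variable indicating whether the edge between nodes $v_i$ and $v_j$ is selected. The matrix of the $a_{ij}$ is denoted as $A_s$. Given a prototype $\mathbf{p}$, the optimization objective of the conditional subgraph sampling module is:

$$\max_{G_s}\ \mathrm{sim}\big(\mathbf{p}, f(G_s)\big) \quad \text{s.t.}\ |G_s| \le B, \qquad (12)$$

where $G_s$ is the selected subgraph whose adjacency matrix is $A_s$, and $B$ is the maximum size of the subgraph.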
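The combinatorial and discrete nature of graphs makes direct optimization of the above objective intractable. We first consider a relaxation by assuming that the explanatory graph is a Gilbert random graph Gilbert (1959), in which the state of each edge is independent of the others. Furthermore, for ease of gradient computation and update, we relax $A_s$ into the convex space $[0, 1]^{|\mathcal{V}| \times |\mathcal{V}|}$, where $|\mathcal{V}|$ is the number of nodes in the input graph; each entry $\theta_{ij}$ of the relaxed matrix $\Theta$ is the probability that the edge between $v_i$ and $v_j$ is selected. For efficiency and generalizability, we adopt deep neural networks to learn $\Theta$:

$$\theta_{ij} = \sigma\big(\mathrm{MLP}_{\Omega}([\mathbf{z}_i; \mathbf{z}_j; \mathbf{p}])\big), \qquad (13)$$

where $\sigma(\cdot)$ is the Sigmoid function, $\mathrm{MLP}_{\Omega}$ is a multi-layer neural network parameterized with $\Omega$, and $[\cdot;\cdot;\cdot]$ is the concatenation operation. $\mathbf{z}_i$ and $\mathbf{z}_j$ are node embeddings obtained from the GNN encoder, which encode the feature and structure information of the nodes' neighborhoods. The objective in Eq. (12) then becomes

$$\max_{\Omega}\ \mathbb{E}_{G_s \sim q(\Theta)}\big[\mathrm{sim}(\mathbf{p}, f(G_s))\big] - \beta\, \mathcal{R}_B(\Theta), \qquad (14)$$

where $q(\Theta)$ is the distribution of subgraphs induced by the edge probabilities and $\beta$ is the weight for the budget regularization $\mathcal{R}_B(\Theta)$, which discourages the selected subgraph from exceeding the size budget $B$. In our experiments, we adopt stochastic gradient descent to optimize the objective function.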
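Eq. (14) can be estimated by Monte Carlo sampling of the Gilbert random graph. The sketch below only evaluates the objective for fixed edge probabilities; an SGD step would additionally require an autodiff framework and a differentiable treatment of the Bernoulli samples, which is not shown. The hinge form $\mathcal{R}_B(\Theta) = \max(0, \sum_{ij}\theta_{ij} - B)$ and the stand-in encoder `embed` are assumptions for illustration, and `relaxed_objective` is a hypothetical name.

```python
import math
import random

def relaxed_objective(theta, p, embed, beta=0.01, budget=10, n_samples=64, seed=0, eps=1e-4):
    """Monte Carlo estimate of Eq. (14) for fixed edge probabilities.

    theta: dict mapping an edge (i, j) to its selection probability theta_ij.
    embed: hypothetical stand-in for the encoder f, mapping a list of kept edges to a vector.
    """
    rng = random.Random(seed)
    total_sim = 0.0
    for _ in range(n_samples):
        # Gilbert random graph: every edge is kept independently with probability theta_ij
        kept = [edge for edge, prob in theta.items() if rng.random() < prob]
        z = embed(kept)
        d2 = sum((a - b) ** 2 for a, b in zip(z, p))
        total_sim += math.log((d2 + 1.0) / (d2 + eps))  # sim(p, f(G_s)) as in Eq. (2)
    expected_sim = total_sim / n_samples
    # Assumed hinge form of the budget regularizer R_B(theta)
    budget_penalty = max(0.0, sum(theta.values()) - budget)
    return expected_sim - beta * budget_penalty

theta = {(0, 1): 0.9, (1, 2): 0.6, (2, 3): 0.1}
score = relaxed_objective(theta, [2.0], lambda edges: [float(len(edges))])
```

Using the expected number of selected edges, $\sum_{ij}\theta_{ij}$, as the size measure keeps the penalty a smooth function of $\Theta$.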
Comparison with MCTS: Our conditional subgraph sampling module is much more efficient than MCTS and easier to parallelize. Its parameters are fixed and independent of the graph size. To sample from a graph with $|\mathcal{E}|$ edges, the time complexity of our method is $O(|\mathcal{E}|)$. One limitation of the conditional subgraph sampling module is that it requires additional training. Therefore, MCTS is still used in the prototype projection step of ProtGNN+ for the stability of optimization.
Theorem on Subgraph Sampling
To provide more understandable visualization, ProtGNN+ prunes the input graph to find the subgraphs most similar to prototypes and then calculates the similarity scores. Compared with ProtGNN, the subgraph sampling procedure may affect the classification accuracy. The following theorem provides some theoretical understanding of how input graph sampling affects classification accuracy.
Theorem 1: Let be a ProtGNN. The embedding of the input graph is . We assume that the number of prototypes is the same for each class, which is denoted as . For each class , the weight connection in the last layer between a class prototype and the class logit is 1, and that between a non-class prototype and the class logit is 0. We denote as the -th prototype for class and the embedding of the pruned subgraph. ProtGNN and ProtGNN+ have the same graph encoder . We make the following assumptions: there exists some with ,
for the correct class, we have and ;
for the incorrect classes, , .
For one correctly classified input graph in ProtGNN, if the output logits between the top-2 classes are at least , then ProtGNN+ can classify the input graph correctly as well.
The intuition behind Theorem 1 is that if the subgraph sampling does not change the graph embedding too much, ProtGNN+ will have the same correct predictions as ProtGNN. The proof is included in the appendix.
In Algorithm 1, we show the training procedure of ProtGNN/ProtGNN+. Before training starts, we randomly initialize the model parameters. We let $W$ be the weight matrix of the fully connected layer and $w_{k,j}$ be the weight connection between the output of the $j$-th prototype and the logit of class $k$. In particular, for a class k, we set for all with and for all with . Intuitively, such initialization of $W$ encourages prototypes belonging to class $k$ to learn semantic concepts that are characteristic of class $k$. After training begins, we employ gradient descent to optimize the objective function in Eq. (3). If the training epoch is larger than the projection epoch $T_p$, we perform the prototype projection step every $\tau$ training epochs. Furthermore, when training ProtGNN+, the conditional subgraph sampling module and ProtGNN are iteratively optimized after the warm-up epoch $T_w$, once the optimization of the GNN encoder and prototypes has stabilized.
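The epoch schedule described above can be sketched as a small function that lists which steps run at a given epoch. The defaults use $T_p = 100$, $\tau = 50$, and $T_w = 200$ from the experimental settings; reading "every $\tau$ epochs" as `epoch % tau == 0` and running one sampler update alongside each ProtGNN update after warm-up are our assumptions, and the step names are hypothetical.

```python
def steps_for_epoch(epoch, plus=False, proj_epoch=100, proj_period=50, warmup_epoch=200):
    """Return the (hypothetical) step names executed at a 1-indexed training epoch."""
    steps = ["update_protgnn"]  # gradient step on Eq. (3)
    if plus and epoch > warmup_epoch:
        # ProtGNN+ only: alternate with the conditional subgraph sampler after warm-up
        steps.append("update_sampler")
    if epoch > proj_epoch and epoch % proj_period == 0:
        steps.append("project_prototypes")  # MCTS projection, Eq. (7)
    return steps

projection_epochs = [e for e in range(1, 501) if "project_prototypes" in steps_for_epoch(e)]
# -> [150, 200, 250, 300, 350, 400, 450, 500]
```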
ProtGNN for Generic Graph Tasks
In the above sections and illustrations, we have described ProtGNN/ProtGNN+ using graph classification as an example. ProtGNN/ProtGNN+ can be easily generalized to other graph tasks, such as node classification and link prediction. For example, in the node classification task, the explanation target is to provide the reasoning process behind the prediction of node $v_i$. Assuming the GNN encoder has $L$ layers, the prediction of node $v_i$ only relies on its $L$-hop computation graph. Therefore, prototype projection and conditional subgraph sampling are both performed within the $L$-hop computation graph.
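Restricting node-level explanations to the $L$-hop computation graph amounts to a breadth-first search of depth $L$ around the target node. A minimal sketch with hypothetical helper names, assuming the graph is given as an adjacency dict:

```python
from collections import deque

def k_hop_nodes(adj, target, num_layers):
    """Node set of the num_layers-hop neighborhood of target (BFS up to that depth)."""
    dist = {target: 0}
    queue = deque([target])
    while queue:
        u = queue.popleft()
        if dist[u] == num_layers:
            continue  # do not expand beyond the L-th hop
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return set(dist)

def computation_subgraph(adj, target, num_layers):
    """Subgraph induced on the L-hop neighborhood, as an adjacency dict."""
    keep = k_hop_nodes(adj, target, num_layers)
    return {u: {v for v in adj[u] if v in keep} for u in keep}
```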
Datasets and Experimental Settings
Table 1: Classification accuracies and standard deviations of ProtGNN, ProtGNN+, and the original GNNs.
Datasets: We conduct extensive experiments on different datasets and GNN models to demonstrate the effectiveness of our proposed model. These datasets are listed below:
Graph-SST2 Socher et al. (2013) and Graph-Twitter Dong et al. (2014) are sentiment graph datasets for graph classification. Sentences are converted to graphs with the Biaffine parser Gardner et al. (2018), where nodes denote words and edges represent the relationships between words. The node embeddings are initialized with BERT word embeddings Devlin et al. (2018). The labels are determined by the sentiment of the sentences.
BA-Shape is a synthetic node classification dataset. Each graph contains a base graph obtained from the Barabási-Albert (BA) model Albert and Barabási (2002) and a house-like five-node motif attached to the base graph. Each node is labeled according to whether it belongs to the base graph or, if it belongs to the motif, its spatial location within the motif.
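For readers who want a toy version of BA-Shape, the sketch below grows a Barabási-Albert base graph by preferential attachment and attaches house motifs. The base size, the number of houses, the attachment rule (one bottom node linked to a random base node), and the label scheme (0 for base nodes, 1/2/3 for top/middle/bottom motif nodes) are assumptions modeled on common BA-Shape generators, not the exact settings used here; the function names are hypothetical.

```python
import random

def barabasi_albert(n, m, rng):
    """Preferential attachment: each new node links to m distinct existing nodes."""
    adj = {u: set() for u in range(n)}
    targets = list(range(m))  # the first new node links to the m seed nodes
    repeated = []             # each node appears here once per incident edge
    for new in range(m, n):
        for t in targets:
            adj[new].add(t)
            adj[t].add(new)
        repeated.extend(targets)
        repeated.extend([new] * m)
        chosen = set()
        while len(chosen) < m:  # degree-proportional sampling without replacement
            chosen.add(rng.choice(repeated))
        targets = list(chosen)
    return adj

def ba_shape(n_base=20, m=2, n_houses=2, seed=0):
    """BA base graph plus house motifs; returns (adjacency dict, node labels)."""
    rng = random.Random(seed)
    adj = barabasi_albert(n_base, m, rng)
    labels = {u: 0 for u in adj}  # label 0: base-graph node
    for _ in range(n_houses):
        top, mid_a, mid_b, bot_a, bot_b = range(len(adj), len(adj) + 5)
        house = [(top, mid_a), (top, mid_b), (mid_a, mid_b),
                 (mid_a, bot_a), (mid_b, bot_b), (bot_a, bot_b)]
        for u in (top, mid_a, mid_b, bot_a, bot_b):
            adj[u] = set()
        for u, v in house:
            adj[u].add(v)
            adj[v].add(u)
        anchor = rng.randrange(n_base)  # attach the motif to a random base node
        adj[bot_a].add(anchor)
        adj[anchor].add(bot_a)
        labels.update({top: 1, mid_a: 2, mid_b: 2, bot_a: 3, bot_b: 3})
    return adj, labels
```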
Experimental Settings: In our evaluation, we use three variants of GNNs, i.e., GCN, GAT, and GIN. The split for train/validation/test sets is . All models are trained for 500 epochs with an early stopping strategy based on accuracy on the validation set. We adopt the ADAM optimizer with a learning rate of 0.005. In Eq. (3), the hyper-parameters $\lambda_1$, $\lambda_2$, and $\lambda_3$ are set to 0.10, 0.05, and 0.01, respectively. $s_{\max}$ is set to 0.3 in Eq. (6). The number of prototypes per class $m$ is set to 5. In MCTS for prototype projection, we set $\lambda$ in Eq. (9) to 5 and the number of iterations to 20. Each node in the Monte Carlo tree can expand up to 10 child nodes, and $N_{\min}$ is set to 5. The prototype projection period $\tau$ is set to 50 and the projection epoch $T_p$ is set to 100. In the training of ProtGNN+, the warm-up epoch $T_w$ is set to 200. We employ a three-layer neural network to learn edge weights. In Eq. (14), $\beta$ is set to 0.01 and $B$ is set to 10. We select hyper-parameters based on related works or grid search; an analysis of hyper-parameters is included in the appendix. All our experiments are conducted on one Tesla V100 GPU.
Evaluations on ProtGNN/ProtGNN+
Comparison with Baselines
In Table 1, we compare the classification accuracy of ProtGNN/ProtGNN+ with that of the original GNNs. We conduct 3 independent runs with random data splits and report the means and standard deviations. In the following sections, we use GCN as the default backbone model. As we can see, ProtGNN and ProtGNN+ achieve classification performance comparable to the corresponding original GNN models, which is also consistent with Theorem 1.
Reasoning Process of Our Network
In Figure 3, we perform case studies on MUTAG and Graph-SST2 to qualitatively evaluate our proposed method. We visualize the prototypes and show the reasoning process of ProtGNN+ in reaching a classification decision on input graphs. In particular, given an input graph $G$, the network estimates the likelihood of each class by comparing $G$ with the prototypes of that class. The conditional subgraph sampling module finds the most similar subgraphs in $G$. These similarity scores are calculated, weighted, and summed together to give a final score for $G$ belonging to each class. For example, Figure 3(a) shows the reasoning process of ProtGNN+ in deciding whether the input molecular graph is mutagenic. Based on chemical domain knowledge Debnath et al. (1991), carbon rings and $\mathrm{NO}_2$ groups tend to be mutagenic. In the Prototype column of the mutagenic class, we can observe that the prototypes capture the structures of $\mathrm{NO}_2$ groups and carbon rings well. Moreover, in the Similar Subgraphs column, the conditional subgraph sampling module effectively identifies the most similar subgraphs. For instance, in the first row of the mutagenic class, the $\mathrm{NO}_2$ group and part of the carbon ring are identified, which are quite similar to the prototype on the right.
Compared with biochemistry datasets, examples on text data can be more understandable since no special domain knowledge is required. In Figure 3(b), the input graph “can take the grandkids or the grandparents and never worry about anyone being bored” is positive. Our method effectively captures the key phrase/subgraph leading to positiveness, “never worry about bored”. Furthermore, we can observe that the similarity scores between the input graph and the positive prototypes, e.g., “kind of entertainment love to have”, are much larger than those with the negative prototypes, e.g., “embarrassed by invention”.
Overall, our method provides interpretable evidence to support its classifications. Such explanations participate in the actual model computation and are therefore always faithful to the classification decisions. More examples and case studies are reported in the appendix.
t-SNE Visualization of Prototypes
Figure 4 shows a t-SNE visualization of the graph and prototype embeddings on the BBBP dataset. We can observe that the prototypes occupy the centers of the graph embeddings, which supports the effectiveness of prototype learning.
| | GCN | ProtGNN | ProtGNN+ | ProtGNN+* |
|---|---|---|---|---|
| Time | 177.9 s | 506.3 s | 632.7 s | 2 hrs |
Finally, we study the efficiency of our proposed methods. In Table 2, we show the time required to finish training for each model. Here ProtGNN+* denotes using MCTS for subgraph sampling in the training of ProtGNN+. The time cost of ProtGNN+* is extremely high due to the complexity of MCTS, and the proposed conditional subgraph sampling module effectively reduces the time cost of ProtGNN+. Although ProtGNN and ProtGNN+ take longer to train than GCN (largely due to prototype projection with MCTS), the time cost is still acceptable considering the built-in interpretability they provide.
While extensive efforts have been made to explain GNNs from different angles, none of the existing methods provides built-in explanations for GNNs. In this paper, we propose ProtGNN/ProtGNN+, which provide a new perspective on the explanations of GNNs. The prediction of ProtGNN is obtained by comparing the input to a few learned prototypes in the prototype layer. For better interpretability and higher efficiency, a novel conditional subgraph sampling module is proposed to indicate the subgraphs most similar to the prototypes. Extensive experimental results show that our method can provide a human-intelligible reasoning process with acceptable classification accuracy and time cost.
- Statistical mechanics of complex networks. Reviews of modern physics 74 (1), pp. 47.
- Explainability techniques for graph convolutional networks. ICML workshop.
- Rumor detection on social media with bi-directional graph convolutional networks. In AAAI, Vol. 34, pp. 549–556.
- This looks like that: deep learning for interpretable image recognition. NeurIPS.
- Structure-activity relationship of mutagenic aromatic and heteroaromatic nitro compounds. correlation with molecular orbital energies and hydrophobicity. Journal of medicinal chemistry 34 (2), pp. 786–797.
- Bert: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Adaptive recursive neural network for target-dependent twitter sentiment classification. In ACL, pp. 49–54.
- Allennlp: a deep semantic natural language processing platform. arXiv preprint arXiv:1803.07640.
- Random graphs. The Annals of Mathematical Statistics 30 (4), pp. 1141–1144.
- Neural message passing for quantum chemistry. In ICML, pp. 1263–1272.
- Predicting drug-target interaction networks based on functional groups and biological features. PloS one 5 (3), pp. e9603.
- Graphlime: local interpretable model explanations for graph neural networks. arXiv preprint arXiv:2001.06216.
- Semi-supervised classification with graph convolutional networks. ICLR.
- An introduction to case-based reasoning. Artificial intelligence review 6 (1), pp. 3–34.
- Parameterized explainer for graph neural network. NeurIPS.
- Interpretable and steerable sequence learning via prototypes. In SIGKDD.
- Explainability methods for graph convolutional neural networks. In CVPR, pp. 10772–10781.
- Please stop explaining black box models for high stakes decisions. stat 1050, pp. 26.
- ProtoPShare: prototype sharing for interpretable image classification and similarity discovery. SIGKDD.
- Interpreting graph neural networks for nlp with differentiable edge masking. arXiv preprint arXiv:2010.00577.
- Case-based reasoning for medical knowledge-based systems. International Journal of Medical Informatics 64 (2-3), pp. 355–367.
- XAI for graphs: explaining graph neural network predictions by identifying relevant walks. arXiv e-prints, pp. arXiv–2006.
- Layerwise relevance visualization in convolutional text graph classifiers. arXiv preprint arXiv:1909.10911.
- Mastering the game of go without human knowledge. nature 550 (7676), pp. 354–359.
- Recursive deep models for semantic compositionality over a sentiment treebank. In EMNLP, pp. 1631–1642.
- Graph attention networks. arXiv preprint arXiv:1710.10903.
- PGM-explainer: probabilistic graphical model explanations for graph neural networks. In NeurIPS, H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin (Eds.), Vol. 33.
- MoleculeNet: a benchmark for molecular machine learning. Chemical science 9 (2), pp. 513–530.
- How powerful are graph neural networks?. In ICLR.
- Financial risk analysis for smes with graph-based supply chain mining. In IJCAI, pp. 4661–4667.
- Gnnexplainer: generating explanations for graph neural networks. NeurIPS 32, pp. 9240.
- Xgnn: towards model-level explanations of graph neural networks. In SIGKDD, pp. 430–438.
- Explainability in graph neural networks: a taxonomic survey. arXiv preprint arXiv:2012.15445.
- On explainability of graph neural networks via subgraph explorations. ICML.
- Motif-based graph self-supervised learning for molecular property prediction. NeurIPS 34.
Appendix A Dataset Statistics
In Table 3, we show the detailed statistics of the five datasets. These datasets include biological data, text data, and synthetic data. The first four datasets are used for graph classification, while BA-Shape is used for node classification.
Appendix B Proof of Theorem 1
In this section, we provide a proof for Theorem 1 in the main paper.
Theorem 1: Let be a ProtGNN. The embedding of the input graph is . We assume that the number of prototypes is the same for each class, which is denoted as . For each class , the weight connection in the last layer between a class prototype and the class logit is 1, and that between a non-class prototype and the class logit is 0. Let denote the -th prototype for class and the embedding of the pruned subgraph. ProtGNN and ProtGNN+ have the same graph encoder .
We make the following assumptions: there exists some with ,
for the correct class, we have and ;
for the incorrect classes, , .
For one correctly classified input graph in ProtGNN, if the output logits between the top-2 classes are at least , then ProtGNN+ can classify the input graph correctly as well.
Proof: For any class , let denote the summed contributed scores for graph belonging to class in ProtGNN. According to Eq. (2) and the assumption:
Let denote the summed contributed scores in ProtGNN+:
Then, the gap between the summed contributed scores, denoted by , is:
Correct class: We first derive the lower bound of for the correct class. Firstly, we have
Then, by the triangle inequality , we have . As a result, we have:
Combining the above two inequalities, for the correct class is .
Wrong class: Now we begin to derive an upper bound of for incorrect classes. First,
According to the assumption for incorrect classes, we have:
For incorrect classes, .
Finally, suppose the summed contributed scores of the correct class is at least larger than any other classes in ProtGNN, the input graph will still be correctly classified by ProtGNN+.
Appendix C Architecture of the Conditional Subgraph Sampling Module
| Layer | Dimension |
|---|---|
| Input | 128 + 128 + 128 |
| Fully Connected + ReLU | 64 |
| Fully Connected + ReLU | 8 |
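In the conditional subgraph sampling module, we adopt deep neural networks to learn $\Theta$: $\theta_{ij} = \sigma\big(\mathrm{MLP}_{\Omega}([\mathbf{z}_i; \mathbf{z}_j; \mathbf{p}])\big)$.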
In Table 4, we show the details of the architecture. In our experiments, the node embedding size and prototype size are set to 128. To make sure the selected adjacency matrix is symmetric, we symmetrize $\Theta$ (so that $\theta_{ij} = \theta_{ji}$) in experiments.
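A minimal pure-Python sketch of Eq. (13) with the layer widths of Table 4 (an input of 128 + 128 + 128, then 64, then 8). The final single-unit output layer, the random initialization, and symmetrization by averaging $\theta_{ij}$ and $\theta_{ji}$ are assumptions, and all function names are hypothetical.

```python
import math
import random

def init_mlp(sizes, seed=0):
    """Random (weights, biases) per layer; weights[i][o] maps input i to output o."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_out)]
                   for _ in range(n_in)]
        layers.append((weights, [0.0] * n_out))
    return layers

def mlp_forward(layers, x):
    """Fully connected layers with ReLU on every layer except the last."""
    for idx, (weights, biases) in enumerate(layers):
        x = [sum(x[i] * weights[i][o] for i in range(len(x))) + biases[o]
             for o in range(len(biases))]
        if idx < len(layers) - 1:
            x = [max(0.0, v) for v in x]
    return x

def edge_probability(layers, z_i, z_j, p):
    """Eq. (13): theta_ij = sigmoid(MLP([z_i; z_j; p])), with list concatenation as [;;]."""
    logit = mlp_forward(layers, z_i + z_j + p)[0]
    return 1.0 / (1.0 + math.exp(-logit))

def symmetric_edge_probability(layers, z_i, z_j, p):
    """Assumed symmetrization: average the two node orderings so theta_ij == theta_ji."""
    return 0.5 * (edge_probability(layers, z_i, z_j, p) + edge_probability(layers, z_j, z_i, p))

# Layer widths from Table 4 (128 + 128 + 128 -> 64 -> 8), plus an assumed 1-unit output
layers = init_mlp([384, 64, 8, 1])
```

Because the same network scores every edge, its parameter count is independent of the input graph, and scoring all edges is linear in their number.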
Appendix D More Case Studies
In Figure 5 and Figure 6, we show more case studies on BBBP and Graph-Twitter. Note that Graph-Twitter is a 3-class dataset and we show the prototypes for negativeness, neutrality, and positiveness. The input graph in Figure 6 is positive. Our method can effectively capture the key phrase/subgraph leading to positiveness, “amazing lady gaga I love”.
Appendix E Hyper-parameters Analysis
In this section, we provide some analysis on hyper-parameters in ProtGNN/ProtGNN+.
Choosing the Number of Prototypes per Class
We first investigate how the number of prototypes per class $m$ influences the performance of ProtGNN on BBBP and Graph-Twitter. With the default setting of the other hyper-parameters, we train ProtGNN with varying $m$. In Figure 7, we observe that the accuracy of ProtGNN first increases sharply as $m$ increases, and the slope flattens after $m$ exceeds 5.
There is a trade-off between accuracy and interpretability when choosing $m$. The accuracy increases as $m$ increases; however, a large number of prototypes makes the model difficult to train and comprehend. In experiments, we choose $m = 5$ since increasing $m$ further brings only marginal improvement.
Influence of Diversity Loss
We further show the effectiveness of the diversity loss (Eq. (6)). In Figure 8, we plot the cosine similarity matrices of the learned prototypes on BBBP. The first row shows the similarity matrices with the diversity loss and the second row those without it. We can observe that the cosine similarities among prototypes without diversity regularization are much larger than those with the diversity loss. In some extreme cases, the similarities are close to 1, which means the learned prototypes are nearly identical. Therefore, the diversity loss helps ProtGNN learn more diverse and evenly distributed prototypes.
Influence of the Cluster and Separation Loss
Here we show the influence of $\lambda_1$ and $\lambda_2$, which control the weights of the cluster loss and separation loss, respectively. In Figure 9, we can observe that the cluster and separation constraints play important roles in ProtGNN. When and , ProtGNN achieves the best performance on the BBBP dataset.