1 Introduction
Explaining the underlying roots of neurological disorders (i.e., what brain regions are associated with the disorder) has been a main goal in the field of neuroscience and medicine [12, 7, 1, 19]. Functional magnetic resonance imaging (fMRI), a noninvasive neuroimaging technique that measures neural activation, has been paramount in advancing our understanding of the functional organization of the brain [26, 20, 25]. The functional network of the brain can be modeled as a graph in which each node is a brain region and the edges represent the strength of the connection between those regions.
The past few years have seen a growing use of graph neural networks (GNNs) for graph classification [10]. Like pooling layers in convolutional neural networks (CNNs) [23, 18], the pooling layer in GNNs is an important design component that compresses a large graph into a smaller one for lower-dimensional feature extraction. Many node pooling strategies have been studied, and they can be divided into two categories: 1) clustering-based pooling, which clusters nodes into a super node based on graph topology [3, 5, 29], and 2) ranking-based pooling, which assigns each node a score and keeps the top-ranked nodes [6, 15]. Clustering-based pooling methods do not preserve the node assignment mapping in the input graph domain, hence they are not inherently interpretable at the node level. For our purpose of interpreting node importance, we focus on ranking-based pooling methods. Existing methods of this type [6, 15] have the following key limitations when applied to salient brain ROI analysis: 1) the ranking scores of the discarded nodes and the remaining nodes may not be clearly distinguishable, which makes them unsuitable for identifying salient and representative regional biomarkers, and 2) the nodes in different graphs of the same group may be ranked entirely differently (usually as a result of overfitting), which is problematic when the objective is to find group-level biomarkers. To reach group-level analysis, such approaches typically require an additional step to summarize statistics (e.g., averaging). In these two-stage methods, if the results of the first stage are unreliable, significant errors can be induced in the second stage.
To utilize GNNs for fMRI learning and meet the need for group-level biomarker finding, we propose a pooling regularized GNN framework (PR-GNN) that jointly classifies neurological disorder patients vs. healthy control subjects and discovers disorder-related biomarkers. An overview of our method is depicted in Fig. 1. Our key contributions are:
1) We formulate an end-to-end framework for fMRI prediction and biomarker (salient brain ROI) interpretation.
2) We propose novel regularization terms for ranking-based pooling methods that encourage more reasonable node selection and provide flexibility between individual-level and group-level interpretation in GNNs.
2 Graph Neural Network for Brain Network Analysis
The architecture of our PR-GNN is shown in Fig. 2. Below, we introduce the notation and the layers in PR-GNN. For simplicity, we focus on Graph Attention Convolution (GATConv) [24, 28] as the node convolutional layer. For node pooling layers, we test two existing ranking-based pooling methods: TopK pooling [6] and SAGE pooling [15].
2.1 Notation and Problem Definition
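We first parcellate the brain into $N$ ROIs based on its T1 structural MRI. We define the ROIs as graph nodes $V = \{v_1, \dots, v_N\}$. We define an undirected weighted graph as $G = (V, E)$, where $E$ is the edge set, i.e., a collection of $(v_i, v_j)$ linking vertices $v_i$ and $v_j$. $G$ has an associated node feature matrix $H = [h_1, \dots, h_N]^\top$, where $h_i$ is the feature vector associated with node $v_i$. For every edge connecting two nodes, $(v_i, v_j) \in E$, we have its strength $e_{ij} \in \mathbb{R}$ and $e_{ij} > 0$. We also define $e_{ij} = 0$ for $(v_i, v_j) \notin E$, and therefore the adjacency matrix $E = [e_{ij}] \in \mathbb{R}^{N \times N}$ is well defined.
2.2 Graph Convolutional Block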
2.2.1 Node Convolutional Layer
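To improve GATConv [10], we incorporate edge features in the brain graph as suggested by Gong and Cheng [8] and Yang et al. [28]. We define $h_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ as the feature of the $i$-th node in the $l$-th layer and $H^{(l)} \in \mathbb{R}^{N^{(l)} \times d^{(l)}}$, where $N^{(l)}$ is the number of nodes at the $l$-th layer (the same holds for $E^{(l)} \in \mathbb{R}^{N^{(l)} \times N^{(l)}}$). The propagation model for the forward-pass update of the node representation is:
$$h_i^{(l+1)} = \mathrm{ReLU}\Big(W^{(l)} h_i^{(l)} + \sum_{j \in \mathcal{N}^{(l)}(i)} \alpha_{ij}^{(l)} W^{(l)} h_j^{(l)}\Big) \qquad (1)$$
where the attention coefficients $\alpha_{ij}^{(l)}$ are computed as
$$\alpha_{ij}^{(l)} = \frac{\exp\Big(e_{ij}^{(l)}\,\mathrm{LeakyReLU}\big(a^{(l)\top}[W^{(l)} h_i^{(l)} \,\|\, W^{(l)} h_j^{(l)}]\big)\Big)}{\sum_{k \in \mathcal{N}^{(l)}(i)} \exp\Big(e_{ik}^{(l)}\,\mathrm{LeakyReLU}\big(a^{(l)\top}[W^{(l)} h_i^{(l)} \,\|\, W^{(l)} h_k^{(l)}]\big)\Big)} \qquad (2)$$
where $\mathcal{N}^{(l)}(i)$ denotes the set of indices of the neighboring nodes of $v_i$, $\|$ denotes concatenation, and $W^{(l)}$ and $a^{(l)}$ are model parameters.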
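To make the edge-weighted attention concrete, here is a minimal NumPy sketch of one such layer: each node combines its own transformed feature with its neighbors' transformed features, weighted by a softmax over attention logits that are scaled by the edge strength $e_{ij}$. It is an illustration, not the authors' implementation. The function names are hypothetical, matrices are assumed dense, the softmax is taken over neighbors only, and the LeakyReLU negative slope of 0.2 is an assumed default that the text does not state.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    # slope=0.2 is an assumed default (common in GAT implementations)
    return np.where(x > 0, x, slope * x)


def edge_gat_layer(H, E, W, a):
    """Hypothetical sketch of one edge-weighted graph attention layer.

    H: (N, d_in) node features; E: (N, N) edge strengths, 0 meaning "no edge";
    W: (d_out, d_in) shared weight matrix; a: (2 * d_out,) attention vector.
    Returns the updated features (N, d_out) and the attention matrix (N, N).
    """
    Z = H @ W.T                                    # row i holds W h_i
    d_out = Z.shape[1]
    # a^T [W h_i || W h_j] splits into a term for i plus a term for j
    raw = (Z @ a[:d_out])[:, None] + (Z @ a[d_out:])[None, :]
    logits = E * leaky_relu(raw)                   # scale logits by e_ij
    mask = E > 0                                   # neighbor sets N(i)
    logits = np.where(mask, logits, -np.inf)       # non-neighbors get zero weight
    # numerically stable row-wise softmax over neighbors only
    row_max = logits.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)  # isolated nodes
    weights = np.where(mask, np.exp(logits - row_max), 0.0)
    denom = weights.sum(axis=1, keepdims=True)
    alpha = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    # self term plus attention-weighted neighbor messages, then ReLU
    return np.maximum(Z + alpha @ Z, 0.0), alpha
```

A node with a single neighbor puts all of its attention on that neighbor, and an isolated node simply keeps its own transformed feature.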
2.2.2 Node Pooling Layer
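Which nodes TopK pooling and SAGE pooling keep is determined by the node importance score $s^{(l)} \in \mathbb{R}^{N^{(l)}}$, which is calculated in one of two ways:
$$s^{(l)} = \begin{cases} H^{(l+1)} w^{(l)} / \|w^{(l)}\|_2, & \text{TopK} \\ \mathrm{GATConv}_s^{(l)}\big(H^{(l+1)}, E^{(l)}\big), & \text{SAGE} \end{cases} \qquad (3)$$
where $H^{(l+1)}$ is calculated in Eq. (1), and $w^{(l)}$ and the weights of $\mathrm{GATConv}_s^{(l)}$ are model parameters. Note that the weight matrix of $\mathrm{GATConv}_s^{(l)}$ is different from $W^{(l)}$ in Eq. (1), such that the output of $\mathrm{GATConv}_s^{(l)}$ is a scalar for each node.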
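Given the pooling ratio $k \in (0, 1]$, the pooling procedure can be roughly described as:
$$\tilde{s}^{(l)} = \frac{s^{(l)} - \mu(s^{(l)})}{\sigma(s^{(l)})}, \quad \hat{s}^{(l)} = \mathrm{sigmoid}\big(\tilde{s}^{(l)}\big), \quad i = \mathrm{topk}\big(\hat{s}^{(l)}, k\big), \quad H^{(l+1)} \leftarrow \big(H^{(l+1)} \odot \hat{s}^{(l)}\big)_{i,:}, \quad E^{(l+1)} = E^{(l)}_{i,i} \qquad (4)$$
Here $\mu(\cdot)$ and $\sigma(\cdot)$ are the mean and standard deviation over nodes, and $\odot$ scales each row of $H^{(l+1)}$ by the corresponding score. The operation $\mathrm{topk}(\hat{s}^{(l)}, k)$ finds the indices of the largest $\lceil kN^{(l)} \rceil$ elements of the score vector $\hat{s}^{(l)}$, and $(\cdot)_{i,j}$ is an indexing operation that takes elements at the row indices specified by $i$ and the column indices specified by $j$ (":" selects all columns). Because features are multiplied by their scores, nodes receiving lower scores retain less of their features.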
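The TopK branch of the score computation and the pooling step can be sketched in a few lines of NumPy. This is a hypothetical illustration, not the authors' code: it assumes the scores are standardized over nodes and passed through a sigmoid before gating, keeps $\lceil kN \rceil$ nodes, and adds a small constant (1e-8, an assumption) to the standard deviation to avoid division by zero.

```python
import math

import numpy as np


def topk_pool(H, E, w, k=0.5):
    """Hypothetical sketch of TopK-style pooling with sigmoid score gating.

    H: (N, d) node features after the conv layer; E: (N, N) edge weights;
    w: (d,) projection vector; k: pooling ratio in (0, 1].
    """
    s = H @ w / np.linalg.norm(w)                  # one scalar score per node
    s_tilde = (s - s.mean()) / (s.std() + 1e-8)    # standardize over nodes
    s_hat = 1.0 / (1.0 + np.exp(-s_tilde))         # sigmoid: scores in (0, 1)
    n_keep = max(1, math.ceil(k * len(s_hat)))
    idx = np.argsort(-s_hat, kind="stable")[:n_keep]   # indices of top scores
    H_new = (H * s_hat[:, None])[idx, :]           # gate features, keep rows idx
    E_new = E[np.ix_(idx, idx)]                    # induced subgraph on kept nodes
    return H_new, E_new, idx, s_hat
```

Because the sigmoid is monotone, ranking by the gated scores selects the same nodes as ranking by the raw scores; the gate only controls how much of each kept node's feature survives.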
Lastly, we need a “flattening” operation to translate graph information into a vector. Suppose the last layer is $l = L$; we use $z = \mathrm{mean}\,H^{(L)} \in \mathbb{R}^{d^{(L)}}$, where the mean is taken over the nodes separately for each feature dimension. Then $z$ is sent to a multilayer perceptron (MLP) to give the final prediction.
3 Proposed Regularizations
3.1 Distance Loss
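To overcome the limitation of existing methods that the ranking scores of the discarded nodes and the remaining nodes may not be distinguishable, we propose two distance losses that encourage a difference between them. Before introducing them, for each instance $m$ we rank the elements of its node scores $\hat{s}_m^{(l)}$ in descending order, and denote the top $T = \lceil kN^{(l)} \rceil$ elements as $s_m^{+} = [s_{m,1}^{+}, \dots, s_{m,T}^{+}]$ and the remaining $N^{(l)} - T$ elements as $s_m^{-} = [s_{m,1}^{-}, \dots, s_{m,N^{(l)}-T}^{-}]$ (we omit the layer index on $s_m^{+}$ and $s_m^{-}$ for brevity). We apply the two types of constraint to all the training instances.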
3.1.1 MMD Loss
3.1.2 BCE Loss
Ideally, the scores for the selected nodes should be close to 1 and the scores for the unselected nodes should be close to 0. Binary cross entropy (BCE) loss is calculated as:
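$$\mathcal{L}_{BCE}^{(l)} = -\frac{1}{M}\sum_{m=1}^{M}\frac{1}{N^{(l)}}\Big(\sum_{i=1}^{T}\log s_{m,i}^{+} + \sum_{i=1}^{N^{(l)}-T}\log\big(1 - s_{m,i}^{-}\big)\Big) \qquad (5)$$
where $M$ is the number of training instances, and $s_m^{+}$ and $s_m^{-}$ are the top $T$ and the remaining sorted node scores of instance $m$ at layer $l$.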
The effect of this constraint will be shown in Section 4.3.
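A minimal NumPy sketch of the BCE distance loss for one pooling layer, assuming the node scores are already in $(0, 1)$ and that $T = \lceil kN \rceil$ nodes are kept. The function name is hypothetical, and the clipping constant `eps` is an addition for numerical safety that is not part of the loss as stated.

```python
import math

import numpy as np


def bce_distance_loss(scores, k=0.5, eps=1e-7):
    """Hypothetical sketch: push top-ranked scores toward 1 and the rest toward 0.

    scores: (M, N) node scores in (0, 1) for M instances at one pooling layer.
    """
    _, N = scores.shape
    T = math.ceil(k * N)                           # number of kept nodes
    ranked = -np.sort(-scores, axis=1)             # descending order per instance
    top = np.clip(ranked[:, :T], eps, 1 - eps)     # selected nodes
    rest = np.clip(ranked[:, T:], eps, 1 - eps)    # discarded nodes
    per_instance = (np.log(top).sum(axis=1) + np.log(1 - rest).sum(axis=1)) / N
    return -per_instance.mean()
```

When every score sits at 0.5 (as in early training) the loss equals $\log 2 \approx 0.693$, and it shrinks toward 0 as the selected and discarded scores separate toward 1 and 0.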
3.2 Grouplevel Consistency Loss
Note that $\hat{s}^{(l)}$ in Eq. (4) is computed from the input $H^{(l)}$. Therefore, for $H^{(l)}$ from different instances, the rankings of the entries of $\hat{s}^{(l)}$ can be very different. For our application, we want to find the common patterns/biomarkers for a given neuroprediction task. Thus, we add a regularization that forces the $\hat{s}^{(1)}$ vectors of different input instances to be similar in the first pooling layer, where the group-level biomarkers are extracted. We call this novel regularization group-level consistency (GLC) and apply it only to the first pooling layer, as the nodes kept in the subsequent layers may differ across instances. Suppose there are $I_c$ instances of class $c$ in a batch, where $c \in \{1, \dots, C\}$ and $C$ is the number of classes. We form the scoring matrix $S_c = [\hat{s}_1^{(1)}, \dots, \hat{s}_{I_c}^{(1)}]^\top \in \mathbb{R}^{I_c \times N}$. The GLC loss can be expressed as:
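$$\mathcal{L}_{GLC} = \sum_{c=1}^{C} \sum_{m=1}^{I_c} \sum_{n=1}^{I_c} \big\|\hat{s}_m^{(1)} - \hat{s}_n^{(1)}\big\|_2^2 = 2 \sum_{c=1}^{C} \mathrm{Tr}\big(S_c^\top L_c S_c\big) \qquad (6)$$
where $L_c = D_c - W_c$, $W_c \in \mathbb{R}^{I_c \times I_c}$ is a matrix with all 1s, and $D_c$ is a diagonal matrix with $I_c$ as diagonal elements. We propose to use the Euclidean distance between $\hat{s}_m^{(1)}$ and $\hat{s}_n^{(1)}$ because of its convexity and computational efficiency.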
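The two forms of the GLC loss can be checked numerically. The NumPy sketch below (function names hypothetical) computes the sum of squared Euclidean distances over all ordered pairs of same-class instances directly, and again through the trace form with $L_c = D_c - W_c$; the two agree because $\sum_{m,n}\|s_m - s_n\|^2 = 2I_c\sum_m\|s_m\|^2 - 2\|\sum_m s_m\|^2 = 2\,\mathrm{Tr}(S_c^\top L_c S_c)$.

```python
import numpy as np


def glc_loss(scores, labels):
    """Hypothetical sketch: squared distances between first-layer score vectors
    over every ordered pair of instances that share a class label.

    scores: (B, N) first-pooling-layer scores; labels: (B,) integer class labels.
    """
    total = 0.0
    for c in np.unique(labels):
        S = scores[labels == c]                   # S_c, shape (I_c, N)
        diff = S[:, None, :] - S[None, :, :]      # all ordered pairs (m, n)
        total += np.sum(diff ** 2)
    return total


def glc_loss_trace(scores, labels):
    """Equivalent trace form: 2 * sum_c Tr(S_c^T L_c S_c), with L_c = D_c - W_c."""
    total = 0.0
    for c in np.unique(labels):
        S = scores[labels == c]
        I_c = S.shape[0]
        W_c = np.ones((I_c, I_c))                 # all-ones matrix
        D_c = I_c * np.eye(I_c)                   # diagonal matrix with I_c on the diagonal
        total += 2.0 * np.trace(S.T @ (D_c - W_c) @ S)
    return total
```

The trace form avoids materializing all pairwise differences, which is one way to read the computational-efficiency argument above.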
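The cross-entropy loss $\mathcal{L}_{CE}$ is used for the final prediction. The final loss function is then formed as:
$$\mathcal{L}_{total} = \mathcal{L}_{CE} + \lambda_1 \sum_{l=1}^{L} \mathcal{L}_{dist}^{(l)} + \lambda_2 \mathcal{L}_{GLC} \qquad (7)$$
where the $\lambda$'s are tunable hyperparameters, $l$ indicates the $l$-th GNN block, $L$ is the total number of GNN blocks, and $\mathcal{L}_{dist}$ is either $\mathcal{L}_{MMD}$ or $\mathcal{L}_{BCE}$.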
4 Experiments and Results
4.1 Data and Preprocessing
We collected fMRI data from a group of 75 ASD children and 43 age- and IQ-matched healthy controls (HC), acquired under the "biopoint" task [13]. The fMRI data were preprocessed following the pipeline in Yang et al. [27]. The Desikan-Killiany [4] atlas was used to parcellate the brain images into 84 ROIs. The mean time series of each node was extracted from a random subset of the voxels in the ROI by bootstrapping; in this way, we augmented the data 10 times. Edges were defined by the top positive partial correlations to achieve sparse connections. If this led to isolated nodes, we added back the largest edge of each of them. For node attributes, we used the Pearson correlation coefficients between node $i$ and each node $j = 1, \dots, 84$, so that $h_i \in \mathbb{R}^{84}$. Pearson correlation and partial correlation are different measures of fMRI connectivity; we combine them by using one to build edge connections and the other to build node features.
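The graph construction can be sketched as follows, assuming the ROI mean time series are already available as a (time, ROI) array. This is a hypothetical illustration, not the preprocessing code used in the study: partial correlations are obtained from the (pseudo-)inverse of the covariance matrix, and `keep_ratio`, the fraction of positive partial correlations kept as edges, is an assumed parameter because the exact threshold is not given here.

```python
import numpy as np


def build_brain_graph(ts, keep_ratio=0.1):
    """Hypothetical sketch of the graph construction described above.

    ts: (T, N) mean time series, one column per ROI.
    keep_ratio: assumed fraction of positive partial correlations kept as edges.
    Returns node features H (row i = Pearson correlations of ROI i with all ROIs)
    and a symmetric edge-weight matrix E built from partial correlations.
    """
    N = ts.shape[1]
    H = np.corrcoef(ts, rowvar=False)              # (N, N) Pearson correlations
    P = np.linalg.pinv(np.cov(ts, rowvar=False))   # precision matrix
    P = (P + P.T) / 2                              # enforce exact symmetry
    d = np.sqrt(np.diag(P))
    partial = -P / np.outer(d, d)                  # partial correlations
    np.fill_diagonal(partial, 0.0)
    upper = partial[np.triu_indices(N, k=1)]
    pos = upper[upper > 0]
    thr = np.quantile(pos, 1 - keep_ratio) if pos.size else np.inf
    E = np.where(partial >= thr, partial, 0.0)     # keep top positive edges only
    # Reconnect isolated nodes through their largest partial correlation
    for i in np.where(~(E > 0).any(axis=1))[0]:
        row = partial[i].copy()
        row[i] = -np.inf                           # ignore the self-connection
        j = int(np.argmax(row))
        E[i, j] = E[j, i] = partial[i, j]
    return H, E
```

Note that in this sketch a reconnected edge can carry a negative weight if a node has no positive partial correlation at all; how such cases were handled in the study is not stated.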
4.2 Implementation Details
The model consisted of 2 conv layers and 2 pooling layers, as shown in Fig. 2. We designed a 3-layer MLP (with 16, 8 and 2 neurons in each layer) that takes the flattened graph vector $z$ as input and predicts ASD vs. HC. Each pooling layer kept the top 50% most important nodes ($k = 0.5$). We will discuss the variation of $\lambda_1$ and $\lambda_2$ in Section 4.3. We randomly split the data into five folds based on subjects, so that the graphs from a single subject appear in either the training set or the test set, but not both. Four folds were used as training data, and the left-out fold was used for testing. Adam was used as the optimizer. We trained the model for 100 epochs with an initial learning rate of 0.001, halved every 20 epochs. We scaled the MMD loss to match the scale of the BCE loss.
4.3 Hyperparameter Discussion and Ablation Study
We tuned the parameters $\lambda_1$ and $\lambda_2$ in the loss function Eq. (7) and show the results in Table 1. $\lambda_1$ encouraged more separable node importance scores for the selected and unselected nodes after pooling. $\lambda_2$ controlled the similarity of the selected nodes for instances within the same class; a larger $\lambda_2$ moves toward group-level interpretation of biomarkers. We first performed an ablation study by comparing the settings $(\lambda_1, \lambda_2) = (0, 0)$ and $(0.1, 0)$. With either the MMD or the BCE loss, mean accuracies increased for both TopK and SAGE (by at least 1% for SAGE). To demonstrate the effectiveness of $\lambda_1$, Fig. 3 shows the distribution of the node pooling scores of the two pooling layers over epochs for different combinations of pooling functions and distance losses, with $\lambda_1$ and $\lambda_2$ held fixed. In the early epochs, the scores centered around 0.5. Then the scores of the top important nodes moved toward 1 and the scores of the unimportant nodes moved toward 0 (less obviously for the second pooling layer using SAGE, which may explain why SAGE achieved lower accuracies than TopK). Hence, the selected important nodes in the pooling layer received distinctly higher scores than the discarded ones. Next, we investigated the effect of $\lambda_2$ on accuracy by varying it from 0 to 1, with $\lambda_1$ fixed at 0.1. Without $\lambda_2$, the model overfitted the training set more easily, while a larger $\lambda_2$ may result in underfitting. As Table 1 shows, the accuracy increased when $\lambda_2$ increased from 0 to 0.1 (except for TopK+MMD) and dropped when $\lambda_2$ was increased to 1. For the following baseline comparison experiments, we fixed $\lambda_1$ and $\lambda_2$ based on these results.
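Table 1: Accuracy, mean (std), under different $(\lambda_1, \lambda_2)$ settings.
Loss | Pool | (0, 0) | (0.1, 0) | (0.1, 0.1) | (0.1, 0.5) | (0.1, 1)
MMD | TopK | 0.753 (0.042) | 0.784 (0.062) | 0.781 (0.038) | 0.780 (0.059) | 0.744 (0.060)
MMD | SAGE | 0.751 (0.022) | 0.770 (0.039) | 0.771 (0.051) | 0.773 (0.047) | 0.751 (0.050)
BCE | TopK | 0.750 (0.046) | 0.779 (0.053) | 0.797 (0.051) | 0.789 (0.066) | 0.762 (0.044)
BCE | SAGE | 0.755 (0.041) | 0.767 (0.033) | 0.773 (0.047) | 0.764 (0.050) | 0.755 (0.041)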
Table 2: Comparison with alternative models.
Metric / Model | SVM | Random Forest | MLP | BrainNetCNN [14] | Li et al. [17] | PR-GNN
Acc | 0.686 (0.111) | 0.723 (0.020) | 0.727 (0.047) | 0.781 (0.044) | 0.753 (0.033) |
Par | 3k | 3k | 137k | 1438k | 16k |
Acc: Accuracy; Par: The number of trainable parameters; PR-GNN: TopK+BCE.
4.4 Comparison with Existing Models
We compared our method with several brain connectome-based methods, including Random Forest (1000 trees), SVM (RBF kernel), MLP (one hidden layer with 20 nodes), a state-of-the-art CNN-based method, BrainNetCNN [14], and a recent GNN method for fMRI [17], in terms of accuracy and number of parameters. For BrainNetCNN, we used the parameter settings indicated in the original paper [14]. The inputs and the architecture parameter settings (node conv, pooling and MLP layers) of the alternative GNN method were the same as those of PR-GNN. The inputs of BrainNetCNN were Pearson correlation matrices. The inputs of the other alternative methods were the flattened upper triangles of the Pearson correlation matrices. Note that the inputs of the GNN models contained both Pearson and partial correlations. For a fair comparison with the non-GNN models, we used Pearson correlations (node features) as their inputs, because Pearson correlations were the embedded features, while partial correlations (edge weights) only served as message-passing filters in the GNN models. The results are shown in Table 2. PR-GNN outperformed the alternative models. With regularization terms on the pooling function, PR-GNN achieved better accuracy than the recent GNN [17]. Also, PR-GNN needs only a fraction of the parameters of the MLP and of BrainNetCNN.
4.5 Biomarker Interpretation
Without loss of generality, we investigated the selected salient ROIs using the TopK+BCE model with different levels of interpretation by tuning $\lambda_2$. As discussed in Section 3.2, a large $\lambda_2$ leads to group-level interpretation and a small $\lambda_2$ leads to individual-level interpretation. We varied $\lambda_2$ from 0 to 0.5. Fig. 4 shows the salient ROI detection results of four randomly selected ASD instances: the 21 ROIs remaining after the 2nd pooling layer (with a pooling ratio of 0.5 at each of the two layers, 84 × 0.5 × 0.5 = 21 nodes remain) and their node pooling scores. As shown in Fig. 4(a), when $\lambda_2 = 0$, we could rarely find any overlapping area among the instances. In Fig. 4(b) and 4(c), we circled the large overlapping areas across the instances. By visually examining the salient ROIs, we found two overlapping areas in Fig. 4(b) and four in Fig. 4(c). By averaging the node importance scores (1st pooling layer) over all the instances, we found that the dorsal striatum, thalamus and frontal gyrus were the most salient ROIs for identifying ASD. These ROIs are related to the neurological functions of social communication, perception and execution [22, 11, 2, 21], which are known to be impaired in ASD.
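The group-level summary described above (averaging first-layer node scores over instances and ranking ROIs) is a few lines of NumPy. The sketch below is hypothetical: the function names and ROI labels are illustrative. It also includes a simple intersection of the ROI sets kept by each instance as a numeric stand-in for the visual overlap inspection in Fig. 4, which the text performs by eye rather than with this function.

```python
import numpy as np


def rank_group_salient_rois(scores, roi_names, top_n=3):
    """Average node scores over instances and return the top_n ROIs.

    scores: (M, N) first-pooling-layer scores; roi_names: N labels.
    """
    mean_scores = scores.mean(axis=0)
    order = np.argsort(-mean_scores, kind="stable")[:top_n]
    return [(roi_names[i], float(mean_scores[i])) for i in order]


def common_kept_rois(kept_sets):
    """ROIs retained by every instance: a simple numeric proxy for overlap."""
    common = set(kept_sets[0])
    for kept in kept_sets[1:]:
        common &= set(kept)
    return sorted(common)
```

With a larger $\lambda_2$, one would expect the intersection of kept ROI sets to grow, since the GLC term pulls the instances' first-layer scores together.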
5 Conclusion
In this paper, we propose PR-GNN, an interpretable graph neural network for fMRI analysis. PR-GNN takes graphs built from fMRI as inputs and outputs prediction results together with interpretation results. With its built-in interpretability, PR-GNN not only performs better on classification than alternative methods, but also detects salient brain regions for classification. The novel loss term gives us the flexibility to use the same method for individual-level biomarker analysis (small $\lambda_2$) and group-level biomarker analysis (large $\lambda_2$). We believe that this is the first work to use a single model in an fMRI study to fill the critical interpretation gap between individual- and group-level analysis. Our interpretation results reveal the salient ROIs for distinguishing autistic disorders from healthy controls. Our method has the potential to advance the understanding of neurological disorders and ultimately benefit neuroimaging research. We will extend and validate our method on larger benchmark datasets in future work.
Acknowledgements
This research was supported in part by NIH grants [R01NS035193, R01MH100028].
References
[1] (2014) Disruption of cortical association networks in schizophrenia and psychotic bipolar disorder. JAMA Psychiatry 71 (2), pp. 109–118.
[2] (2014) The social brain and reward: social information processing in the human striatum. Wiley Interdisciplinary Reviews: Cognitive Science 5 (1), pp. 61–73.
[3] (2016) Convolutional neural networks on graphs with fast localized spectral filtering. In Advances in Neural Information Processing Systems, pp. 3844–3852.
[4] (2006) An automated labeling system for subdividing the human cerebral cortex on MRI scans into gyral based regions of interest. NeuroImage 31 (3), pp. 968–980.
[5] (2007) Weighted graph cuts without eigenvectors: a multilevel approach. IEEE Transactions on Pattern Analysis and Machine Intelligence 29 (11), pp. 1944–1957.
[6] (2019) Graph U-Nets. arXiv preprint arXiv:1905.05178.
[7] (2014) Biomarkers in autism. Frontiers in Psychiatry 5, pp. 100.
[8] (2019) Exploiting edge features for graph neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9211–9219.
[9] (2012) A kernel two-sample test. Journal of Machine Learning Research 13 (Mar), pp. 723–773.
[10] (2017) Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pp. 1024–1034.
[11] (2006) Abnormal brain size effect on the thalamus in autism. Psychiatry Research: Neuroimaging 147 (2-3), pp. 145–151.
[12] (2010) Neural signatures of autism. Proceedings of the National Academy of Sciences 107 (49), pp. 21223–21228.
[13] (2010) Neural signatures of autism. PNAS.
[14] (2017) BrainNetCNN: convolutional neural networks for brain networks; towards predicting neurodevelopment. NeuroImage 146, pp. 1038–1049.
[15] (2019) Self-attention graph pooling. arXiv preprint arXiv:1904.08082.
[16] (2017) MMD GAN: towards deeper understanding of moment matching network. In Advances in Neural Information Processing Systems, pp. 2203–2213.
[17] (2019) Graph neural network for interpreting task-fMRI biomarkers. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 485–493.
[18] (2016) Unsupervised domain adaptation with residual transfer networks. In Advances in Neural Information Processing Systems, pp. 136–144.
[19] (2018) Longitudinal cognitive and biomarker changes in dominantly inherited Alzheimer disease. Neurology 91 (14), pp. e1295–e1306.
[20] (2009) Decoding the large-scale structure of brain function by classifying mental states across individuals. Psychological Science 20 (11), pp. 1364–1372.
[21] (2012) Dissociable roles of human inferior frontal gyrus during action execution and observation. NeuroImage 60 (3), pp. 1671–1677.
[22] (2016) Morphological alterations in the thalamus, striatum, and pallidum in autism spectrum disorder. Neuropsychopharmacology 41 (11), pp. 2627–2637.
[23] (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
[24] (2018) Graph attention networks. In ICLR.
[25] (2019) Decoding and mapping task states of the human brain via deep learning. Human Brain Mapping.
[26] (2002) A general statistical analysis for fMRI data. NeuroImage 15 (1), pp. 1–15.
[27] (2016) Brain responses to biological motion predict treatment outcome in young children with autism. Translational Psychiatry 6 (11), pp. e948.
[28] (2019) Interpretable multimodality embedding of cerebral cortex using attention graph network for identifying bipolar disorder. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 799–807.
[29] (2018) Hierarchical graph representation learning with differentiable pooling. In Advances in Neural Information Processing Systems, pp. 4800–4810.