Pooling Regularized Graph Neural Network for fMRI Biomarker Analysis

by Xiaoxiao Li, et al.

Understanding how certain brain regions relate to a specific neurological disorder has been an important area of neuroimaging research. A promising approach to identifying the salient regions is to use Graph Neural Networks (GNNs), which can analyze graph-structured data, e.g., brain networks constructed from functional magnetic resonance imaging (fMRI). We propose an interpretable GNN framework with a novel salient region selection mechanism to determine neurological brain biomarkers associated with disorders. Specifically, we design novel regularized pooling layers that highlight salient regions of interest (ROIs), so that we can infer which ROIs are important for identifying a certain disease from the node pooling scores calculated by the pooling layers. Our proposed framework, Pooling Regularized-GNN (PR-GNN), encourages reasonable ROI selection and provides the flexibility to preserve either individual- or group-level patterns. We apply the PR-GNN framework to a Biopoint Autism Spectrum Disorder (ASD) fMRI dataset. We investigate different choices of the hyperparameters and show that PR-GNN outperforms baseline methods in terms of classification accuracy. The salient ROI detection results show high correspondence with previous neuroimaging-derived biomarkers for ASD.






Code Repositories


PyTorch implementation of pooling-regularized GNN (PRGNN) for fMRI analysis. https://arxiv.org/pdf/2007.14589.pdf


1 Introduction

Figure 1: Overview of the pipeline. fMRI images are parcellated using an atlas and converted into graphs. The graphs are then fed to our proposed PR-GNN, which makes predictions for the specific task and jointly selects the salient brain regions that are informative for the prediction.
Figure 2: PR-GNN for brain graph classification and the details of its key component, the Graph Convolutional Block. Each Graph Convolutional Block contains a node convolutional layer followed by a node pooling layer.

Explaining the underlying roots of neurological disorders (i.e., which brain regions are associated with the disorder) has been a main goal in the fields of neuroscience and medicine [12, 7, 1, 19]. Functional magnetic resonance imaging (fMRI), a non-invasive neuroimaging technique that measures neural activation, has been paramount in advancing our understanding of the functional organization of the brain [26, 20, 25]. The functional network of the brain can be modeled as a graph in which each node is a brain region and the edges represent the strength of the connections between those regions.

The past few years have seen a growing prevalence of graph neural networks (GNNs) for graph classification [10]. Like pooling layers in convolutional neural networks (CNNs) [23, 18], the pooling layer in GNNs is an important design for compressing a large graph into a smaller one for lower-dimensional feature extraction. Many node pooling strategies have been studied, and they can be divided into two categories: 1) clustering-based pooling, which clusters nodes into a super node based on graph topology [3, 5, 29], and 2) ranking-based pooling, which assigns each node a score and keeps the top-ranked nodes [6, 15]. Clustering-based pooling methods do not preserve the node assignment mapping in the input graph domain, hence they are not inherently interpretable at the node level. For our purpose of interpreting node importance, we focus on ranking-based pooling methods. Existing methods of this type [6, 15] have two key limitations when applied to salient brain ROI analysis: 1) the ranking scores of the discarded nodes and the remaining nodes may not be significantly distinguishable, which is unsuitable for identifying salient and representative regional biomarkers; and 2) the nodes in different graphs from the same group may be ranked very differently (usually a symptom of overfitting), which is problematic when the objective is to find group-level biomarkers. To reach group-level analysis, such approaches typically require an additional step to summarize statistics (such as averaging). In these two-stage methods, if the results of the first stage are unreliable, significant errors can propagate into the second stage.

To use GNNs for fMRI learning and meet the need for group-level biomarker discovery, we propose a pooling regularized GNN framework (PR-GNN) that jointly classifies neurological disorder patients vs. healthy control subjects and discovers disorder-related biomarkers. An overview of our method is depicted in Fig. 1. Our key contributions are:
  • We formulate an end-to-end framework for fMRI prediction and biomarker (salient brain ROI) interpretation.
  • We propose novel regularization terms for ranking-based pooling methods that encourage more reasonable node selection and provide flexibility between individual-level and group-level interpretation in GNNs.

2 Graph Neural Network for Brain Network Analysis

The architecture of our PR-GNN is shown in Fig. 2. Below, we introduce the notation and the layers in PR-GNN. For simplicity, we focus on Graph Attention Convolution (GATConv) [24, 28] as the node convolutional layer. For the node pooling layers, we test two existing ranking-based pooling methods: TopK pooling [6] and SAGE pooling [15].

2.1 Notation and Problem Definition

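We first parcellate the brain into $N$ ROIs based on its T1 structural MRI. We define the ROIs as graph nodes $\mathcal{V} = \{v_1, \dots, v_N\}$. We define an undirected weighted graph as $G = (\mathcal{V}, \mathcal{E})$, where $\mathcal{E}$ is the edge set, i.e., a collection of pairs $(v_i, v_j)$ linking vertices $v_i$ and $v_j$. $G$ has an associated node feature matrix $H = [\mathbf{h}_1, \dots, \mathbf{h}_N]^\top$, where $\mathbf{h}_i$ is the feature vector associated with node $v_i$. For every edge connecting two nodes, $(v_i, v_j) \in \mathcal{E}$, we have its strength $e_{ij} \in \mathbb{R}$, $e_{ij} > 0$. We also define $e_{ij} = 0$ for $(v_i, v_j) \notin \mathcal{E}$, and therefore the adjacency matrix $E = [e_{ij}] \in \mathbb{R}^{N \times N}$ is well defined.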

2.2 Graph Convolutional Block

2.2.1 Node Convolutional Layer

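To improve GATConv [10], we incorporate edge features in the brain graph, as suggested by Gong and Cheng [8] and Yang et al. [28]. We define $\mathbf{h}_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ as the feature of the $i$-th node in the $l$-th layer and $H^{(l)} \in \mathbb{R}^{N^{(l)} \times d^{(l)}}$, where $N^{(l)}$ is the number of nodes at the $l$-th layer (the same holds for $E^{(l)} \in \mathbb{R}^{N^{(l)} \times N^{(l)}}$). The propagation model for the forward-pass update of the node representation is:

$$\tilde{\mathbf{h}}_i^{(l)} = \mathrm{ReLU}\Big( W^{(l)} \mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}^{(l)}(i)} \alpha_{ij}^{(l)} e_{ij}^{(l)} W^{(l)} \mathbf{h}_j^{(l)} \Big), \tag{1}$$

where the attention coefficients $\alpha_{ij}^{(l)}$ are computed as

$$\alpha_{ij}^{(l)} = \frac{\exp\Big(\mathrm{LeakyReLU}\big(\mathbf{a}^\top [W^{(l)} \mathbf{h}_i^{(l)} \,\|\, W^{(l)} \mathbf{h}_j^{(l)}]\big)\Big)}{\sum_{k \in \mathcal{N}^{(l)}(i) \cup \{i\}} \exp\Big(\mathrm{LeakyReLU}\big(\mathbf{a}^\top [W^{(l)} \mathbf{h}_i^{(l)} \,\|\, W^{(l)} \mathbf{h}_k^{(l)}]\big)\Big)}, \tag{2}$$

where $\mathcal{N}^{(l)}(i)$ denotes the set of indices of the neighboring nodes of $v_i^{(l)}$, $\|$ denotes concatenation, and $W^{(l)}$ and $\mathbf{a}$ are model parameters. Stacking the vectors $\tilde{\mathbf{h}}_i^{(l)}$ row-wise gives $\tilde{H}^{(l)}$.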
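To make Eqs. (1)-(2) concrete, here is a minimal pure-Python sketch of one edge-weighted attention convolution on a single graph. It assumes a single attention head, a dense edge-weight matrix with zeros for non-edges, and a LeakyReLU slope of 0.2 (an assumption, not stated in the paper); the function names are hypothetical and this is not the authors' PyTorch implementation.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def leaky_relu(x, slope=0.2):
    """LeakyReLU; the 0.2 slope is an assumed default, not taken from the paper."""
    return x if x > 0 else slope * x


def edge_weighted_gat_layer(H, E, W, a):
    """Sketch of Eqs. (1)-(2): single-head attention convolution scaled by edge weights.

    H: list of N node feature vectors h_i (each of length d_in)
    E: N x N edge-weight matrix, E[i][j] == 0 when v_i and v_j are not connected
    W: d_out x d_in weight matrix; a: attention vector of length 2 * d_out
    Returns the N output vectors (tilde h_i), each of length d_out.
    """
    n = len(H)
    WH = [matvec(W, h) for h in H]  # W h_j for every node j
    out = []
    for i in range(n):
        nbrs = [j for j in range(n) if j != i and E[i][j] != 0]
        cand = nbrs + [i]  # the softmax runs over N(i) plus the node itself
        # a^T [W h_i || W h_j]; list + list concatenates the two projections.
        logits = {j: leaky_relu(sum(ak * zk for ak, zk in zip(a, WH[i] + WH[j]))) for j in cand}
        top = max(logits.values())  # subtract the max for numerical stability
        denom = sum(math.exp(v - top) for v in logits.values())
        alpha = {j: math.exp(logits[j] - top) / denom for j in cand}
        # Self term plus neighbour messages weighted by alpha_ij * e_ij, then ReLU.
        agg = list(WH[i])
        for j in nbrs:
            coef = alpha[j] * E[i][j]
            agg = [g + coef * v for g, v in zip(agg, WH[j])]
        out.append([max(0.0, g) for g in agg])
    return out
```

For example, with `W` the identity and `a` all zeros, every attention logit is 0, so a node joined to one neighbour by an edge of weight 0.5 gets $\alpha = 0.5$ and outputs $\mathbf{h}_i + 0.25\,\mathbf{h}_j$; an isolated node simply passes $\mathrm{ReLU}(W\mathbf{h}_i)$ through.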

2.2.2 Node Pooling Layer

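Which nodes TopK pooling and SAGE pooling keep is determined by the node importance score $\mathbf{s}^{(l)} \in \mathbb{R}^{N^{(l)}}$, which is calculated in one of two ways:

$$\mathbf{s}^{(l)} = \begin{cases} \tilde{H}^{(l)} \mathbf{w}^{(l)} / \|\mathbf{w}^{(l)}\|_2, & \text{TopK pooling,} \\ \mathrm{GATConv}\big(\tilde{H}^{(l)}, W_{\mathrm{att}}^{(l)}\big), & \text{SAGE pooling,} \end{cases} \tag{3}$$

where $\tilde{H}^{(l)}$ is calculated in Eq. (1), and $\mathbf{w}^{(l)}$ and $W_{\mathrm{att}}^{(l)}$ are model parameters. Note that $W_{\mathrm{att}}^{(l)}$ differs from $W^{(l)}$ in Eq. (1) in that it maps each node to a scalar, so the output of the GATConv is one score per node.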
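The TopK branch of Eq. (3) is a normalized projection of each node's features. A minimal sketch, with a hypothetical function name; the SAGE branch, which needs a full GATConv, is not shown:

```python
import math


def topk_scores(H_tilde, w):
    """TopK-style node importance (sketch): project each node's features onto w / ||w||_2."""
    norm = math.sqrt(sum(v * v for v in w))
    return [sum(h_k * w_k for h_k, w_k in zip(h, w)) / norm for h in H_tilde]
```

Dividing by $\|\mathbf{w}\|_2$ keeps the score scale independent of the magnitude of the learned vector, so only its direction decides the ranking.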

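Given $\mathbf{s}^{(l)}$, the following equations roughly describe the pooling procedure:

$$\tilde{\mathbf{s}}^{(l)} = \mathrm{sigmoid}\big(\mathbf{s}^{(l)}\big), \quad \mathbf{i} = \mathrm{top}_k\big(\tilde{\mathbf{s}}^{(l)}\big), \quad H^{(l+1)} = \big(\tilde{H}^{(l)} \odot \tilde{\mathbf{s}}^{(l)}\big)_{\mathbf{i},:}, \quad E^{(l+1)} = E^{(l)}_{\mathbf{i},\mathbf{i}}, \tag{4}$$

where $\odot$ multiplies each row (node) of $\tilde{H}^{(l)}$ by the corresponding entry of $\tilde{\mathbf{s}}^{(l)}$. The notation $\mathrm{top}_k$ finds the indices corresponding to the $k$ largest elements of the score vector, and $(\cdot)_{\mathbf{i},\mathbf{j}}$ is an indexing operation that takes the elements at the row indices specified by $\mathbf{i}$ and the column indices specified by $\mathbf{j}$ (a colon denotes all indices). Because features are scaled by their scores, nodes receiving lower scores retain less of their features.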
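A minimal sketch of the pooling step in Eq. (4) on plain Python lists, assuming the raw scores `s` are already computed; the function name is hypothetical, and breaking ties by lower node index is an assumption of this sketch.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def topk_pool(H_tilde, E, s, k):
    """Sketch of Eq. (4): gate features by sigmoid(score), keep the k best nodes,
    and take the induced sub-adjacency among the kept nodes."""
    s_tilde = [sigmoid(v) for v in s]
    # Indices of the k largest scores; sorted() is stable, so ties keep the lower index first.
    idx = sorted(range(len(s_tilde)), key=lambda i: -s_tilde[i])[:k]
    H_next = [[v * s_tilde[i] for v in H_tilde[i]] for i in idx]  # rows i, all columns
    E_next = [[E[i][j] for j in idx] for i in idx]                # rows i, columns i
    return H_next, E_next, idx, s_tilde
```

The full vector `s_tilde` (not only the kept entries) is returned because the regularizers in Section 3 operate on the scores of both the selected and the discarded nodes.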

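Lastly, we use a “flattening” (readout) operation to translate the graph information into a vector. Suppose the last GNN block is $l = L$; we use $\mathbf{z} = \mathrm{mean}\big(H^{(L+1)}\big)$, where the mean is taken over the remaining nodes for each feature dimension. Then $\mathbf{z}$ is sent to a multilayer perceptron (MLP) to give the final prediction.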
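A minimal sketch of the mean readout, assuming the last pooled feature matrix is a list of node rows (hypothetical function name):

```python
def mean_readout(H_last):
    """Average node embeddings feature-by-feature to obtain one graph-level vector z."""
    n = len(H_last)
    d = len(H_last[0])
    return [sum(h[k] for h in H_last) / n for k in range(d)]
```

Because the mean is permutation-invariant, $\mathbf{z}$ does not depend on the order in which the pooling layers return the kept nodes.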

3 Proposed Regularizations

3.1 Distance Loss

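To overcome the limitation of existing methods that the ranking scores of the discarded nodes and the remaining nodes may not be distinguishable, we propose two distance losses that encourage a difference between them. Before introducing them, we rank the elements of the score vector $\tilde{\mathbf{s}}_m^{(l)}$ of instance $m$ in descending order and denote the sorted vector as $\hat{\mathbf{s}}_m^{(l)}$; we denote its top $k$ elements as $\mathbf{a}_m = \hat{\mathbf{s}}_{m,1:k}^{(l)}$ and the remaining elements as $\mathbf{b}_m = \hat{\mathbf{s}}_{m,k+1:N^{(l)}}^{(l)}$ (below, $N$ denotes $N^{(l)}$). We apply either type of constraint to all $M$ training instances.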

3.1.1 MMD Loss

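Maximum mean discrepancy (MMD) [9] is a kernel-based measure of the difference between two distributions; MMD GAN [16] uses it in generative adversarial nets (GANs) to quantify the difference between real and generated samples. In our application, we define the MMD loss for the $l$-th pooling layer as:

$$L_{\mathrm{MMD}}^{(l)} = -\frac{1}{M}\sum_{m=1}^{M}\Bigg[\frac{1}{k^2}\sum_{i,j=1}^{k} K_\gamma(a_{m,i}, a_{m,j}) + \frac{1}{(N-k)^2}\sum_{i,j=1}^{N-k} K_\gamma(b_{m,i}, b_{m,j}) - \frac{2}{k(N-k)}\sum_{i=1}^{k}\sum_{j=1}^{N-k} K_\gamma(a_{m,i}, b_{m,j})\Bigg],$$

where $K_\gamma(x, y) = \exp\big(-\gamma \|x - y\|^2\big)$ is a Gaussian kernel and $\gamma$ is a scaling factor. The negative sign means that minimizing the loss maximizes the discrepancy between the selected and unselected scores.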
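A pure-Python sketch of the MMD distance loss, assuming each instance's sigmoid scores are given as a list; `gamma` is a hypothetical placeholder argument, since the value used in the experiments is not given here, and the function names are illustrative.

```python
import math


def gaussian_kernel(x, y, gamma):
    return math.exp(-gamma * (x - y) ** 2)


def mmd_pool_loss(score_batch, k, gamma):
    """Negative squared MMD between top-k and remaining scores, averaged over instances."""
    total = 0.0
    for s in score_batch:
        ranked = sorted(s, reverse=True)
        a, b = ranked[:k], ranked[k:]  # selected (a_m) and unselected (b_m) scores
        n_a, n_b = len(a), len(b)
        k_aa = sum(gaussian_kernel(x, y, gamma) for x in a for y in a) / n_a ** 2
        k_bb = sum(gaussian_kernel(x, y, gamma) for x in b for y in b) / n_b ** 2
        k_ab = sum(gaussian_kernel(x, y, gamma) for x in a for y in b) / (n_a * n_b)
        total += k_aa + k_bb - 2.0 * k_ab  # squared MMD for this instance
    return -total / len(score_batch)  # minimizing this maximizes the discrepancy
```

Identical scores give a loss of 0, while perfectly separated scores (1s vs. 0s) with $\gamma = 1$ give $-(2 - 2e^{-1})$, so gradient descent pushes the two groups apart.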

3.1.2 BCE Loss

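Ideally, the scores of the selected nodes should be close to 1 and the scores of the unselected nodes should be close to 0. The binary cross-entropy (BCE) loss is calculated as:

$$L_{\mathrm{BCE}}^{(l)} = -\frac{1}{M}\sum_{m=1}^{M}\frac{1}{N}\Bigg[\sum_{i=1}^{k}\log(a_{m,i}) + \sum_{j=1}^{N-k}\log(1 - b_{m,j})\Bigg].$$

The effect of this constraint will be shown in Section 4.3.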
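A matching sketch of the BCE distance loss, with the same assumed input layout; the small `eps` clamp is an addition of this sketch to avoid `log(0)`, not part of the formula.

```python
import math


def bce_pool_loss(score_batch, k, eps=1e-12):
    """Push the top-k sigmoid scores toward 1 and the rest toward 0 (sketch)."""
    total = 0.0
    for s in score_batch:
        ranked = sorted(s, reverse=True)
        a, b = ranked[:k], ranked[k:]
        per_node = [math.log(max(x, eps)) for x in a] + [math.log(max(1.0 - x, eps)) for x in b]
        total += -sum(per_node) / len(s)  # average over the N nodes of this instance
    return total / len(score_batch)      # average over the M instances
```

Scores that all sit at 0.5 give $\log 2 \approx 0.693$, while scores already near 1 and 0 give a much smaller loss.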

3.2 Group-level Consistency Loss

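Note that $\tilde{\mathbf{s}}^{(l)}$ in Eq. (4) is computed from the input $H^{(l)}$. Therefore, for $H^{(l)}$ from different instances, the rankings of the entries of $\tilde{\mathbf{s}}^{(l)}$ can be very different. For our application, we want to find the common patterns/biomarkers for a given neuro-prediction task. Thus, we add a regularization that forces the $\tilde{\mathbf{s}}^{(1)}$ vectors to be similar across input instances in the first pooling layer, where the group-level biomarkers are extracted. We call this novel regularization group-level consistency (GLC) and apply it only to the first pooling layer, because the nodes retained in the following layers may differ between instances. Suppose there are $M_c$ instances of class $c$ in a batch, where $c \in \{1, \dots, C\}$ and $C$ is the number of classes. We form the scoring matrix $S_c = [\tilde{\mathbf{s}}_{c,1}^{(1)}, \dots, \tilde{\mathbf{s}}_{c,M_c}^{(1)}]^\top \in \mathbb{R}^{M_c \times N}$. The GLC loss can be expressed as:

$$L_{\mathrm{GLC}} = \sum_{c=1}^{C}\sum_{m,n=1}^{M_c}\big\|\tilde{\mathbf{s}}_{c,m}^{(1)} - \tilde{\mathbf{s}}_{c,n}^{(1)}\big\|_2^2 = 2\sum_{c=1}^{C}\mathrm{tr}\big(S_c^\top L_c S_c\big),$$

where $L_c = D_c - W_c$, $W_c \in \mathbb{R}^{M_c \times M_c}$ is a matrix of all 1s, and $D_c$ is a diagonal matrix with $M_c$ as its diagonal elements. We propose to use the Euclidean distance between $\tilde{\mathbf{s}}_{c,m}^{(1)}$ and $\tilde{\mathbf{s}}_{c,n}^{(1)}$ because of its convexity and computational efficiency.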
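A small pure-Python check, written for illustration, that the pairwise form and the Laplacian form of the GLC loss agree. It assumes the first-layer scores of a batch are grouped by class as lists of equal-length vectors; both function names are hypothetical.

```python
def glc_loss(score_by_class):
    """Sum of squared Euclidean distances between the first-layer score vectors
    of every ordered pair of instances within the same class."""
    total = 0.0
    for S_c in score_by_class:  # S_c: list of M_c score vectors, each of length N
        for s_m in S_c:
            for s_n in S_c:
                total += sum((x - y) ** 2 for x, y in zip(s_m, s_n))
    return total


def glc_loss_laplacian(score_by_class):
    """Same quantity as 2 * tr(S_c^T L_c S_c), with L_c = M_c * I - (all-ones matrix)."""
    total = 0.0
    for S_c in score_by_class:
        m_c = len(S_c)
        for col in range(len(S_c[0])):
            x = [s[col] for s in S_c]  # one column of S_c (one ROI across instances)
            # x^T (M_c I - 1 1^T) x = M_c * sum(x_i^2) - (sum x_i)^2
            total += 2.0 * (m_c * sum(v * v for v in x) - sum(x) ** 2)
    return total
```

The Laplacian form costs $O(M_c N)$ per class instead of $O(M_c^2 N)$ for the explicit pairwise sum.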

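The cross-entropy loss $L_{\mathrm{ce}}$ is used for the final prediction. The final loss function is then formed as:

$$L_{\mathrm{total}} = L_{\mathrm{ce}} + \lambda_1\sum_{l=1}^{L} L_{\mathrm{dist}}^{(l)} + \lambda_2 L_{\mathrm{GLC}}, \tag{7}$$

where the $\lambda$'s are tunable hyperparameters, $l$ indexes the GNN blocks, $L$ is the total number of GNN blocks, and $L_{\mathrm{dist}}$ is either $L_{\mathrm{MMD}}$ or $L_{\mathrm{BCE}}$.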

4 Experiments and Results

4.1 Data and Preprocessing

We collected fMRI data from 75 children with ASD and 43 age- and IQ-matched healthy controls (HC), acquired under the “biopoint” task [13]. The fMRI data were preprocessed following the pipeline in Yang et al. [27]. The Desikan-Killiany atlas [4] was used to parcellate the brain images into 84 ROIs. The mean time series of each node was extracted from a random subset of the voxels in the ROI by bootstrapping; in this way, we augmented the data 10 times. Edges were defined by the top positive partial correlations to achieve sparse connections. If this led to isolated nodes, we added back the largest edge of each of them. For the node attributes, we used Pearson correlation coefficients: the feature vector of node $i$ contains its Pearson correlation to each node $j = 1, \dots, N$. Pearson correlation and partial correlation are different measures of fMRI connectivity; we combine them by using one to build the edge connections and the other to build the node features.
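A rough pure-Python sketch of the graph construction described above. It assumes `ts` holds one mean time series per ROI, interprets "top positive partial correlations" as a global top fraction (`keep_fraction` is a hypothetical parameter, since the exact threshold is not given here), and obtains partial correlations from the inverse of the Pearson matrix. Bootstrapped voxel sampling is omitted, and all names are illustrative.

```python
import math


def pearson_matrix(ts):
    """ts: list of N ROI time series of equal length. Returns the N x N Pearson matrix."""
    unit = []
    for x in ts:
        mu = sum(x) / len(x)
        c = [v - mu for v in x]
        norm = math.sqrt(sum(v * v for v in c))  # assumes no constant series
        unit.append([v / norm for v in c])
    return [[sum(p * q for p, q in zip(u, v)) for v in unit] for u in unit]


def invert(M):
    """Gauss-Jordan inverse with partial pivoting (assumes M is non-singular)."""
    n = len(M)
    A = [list(row) + [1.0 if i == j else 0.0 for j in range(n)] for i, row in enumerate(M)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(A[r][col]))
        A[col], A[piv] = A[piv], A[col]
        p = A[col][col]
        A[col] = [v / p for v in A[col]]
        for r in range(n):
            if r != col:
                f = A[r][col]
                A[r] = [vr - f * vc for vr, vc in zip(A[r], A[col])]
    return [row[n:] for row in A]


def partial_corr(C):
    """Partial correlation from the precision matrix P: -P_ij / sqrt(P_ii * P_jj)."""
    P = invert(C)
    n = len(C)
    return [[0.0 if i == j else -P[i][j] / math.sqrt(P[i][i] * P[j][j])
             for j in range(n)] for i in range(n)]


def build_brain_graph(ts, keep_fraction=0.1):
    """Edges: top positive partial correlations; node features: rows of the Pearson matrix."""
    C = pearson_matrix(ts)
    R = partial_corr(C)
    n = len(ts)
    pairs = sorted(((R[i][j], i, j) for i in range(n) for j in range(i + 1, n) if R[i][j] > 0),
                   reverse=True)
    n_keep = max(1, int(keep_fraction * n * (n - 1) / 2))
    E = [[0.0] * n for _ in range(n)]
    for w, i, j in pairs[:n_keep]:
        E[i][j] = E[j][i] = w
    # Nodes left isolated by the sparsification get their largest partial-correlation edge back.
    isolated = [i for i in range(n) if all(E[i][j] == 0.0 for j in range(n))]
    for i in isolated:
        j = max((j for j in range(n) if j != i), key=lambda j: R[i][j])
        E[i][j] = E[j][i] = R[i][j]
    return E, C  # E: edge weights, C[i]: feature vector h_i of node i
```

Computing the isolated set before adding edges back follows the text: each node isolated by the thresholding receives its own largest edge, regardless of the order in which nodes are visited.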

4.2 Implementation Details

The model architecture was implemented with 2 conv layers and 2 pooling layers, as shown in Fig. 2. We designed a 3-layer MLP (with 16, 8 and 2 neurons in each layer) that takes the flattened graph $\mathbf{z}$ as input and predicts ASD vs. HC. Each pooling layer kept the top 50% most important nodes (pooling ratio 0.5). We discuss the variation of $\lambda_1$ and $\lambda_2$ in Section 4.3. We randomly split the data into five folds based on subjects, which means that the graphs from a single subject can only appear in either the training set or the test set. Four folds were used as training data, and the left-out fold was used for testing. Adam was used as the optimizer. We trained the model for 100 epochs with an initial learning rate of 0.001, halved every 20 epochs. We set the scaling factor $\gamma$ in the MMD loss so that the MMD loss matches the scale of the BCE loss.
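A minimal sketch of the subject-level five-fold split and the step learning-rate schedule, assuming each augmented graph is tagged with its subject ID; the function names and the shuffling seed are hypothetical, and no class stratification is attempted because none is described.

```python
import random


def subject_folds(subject_ids, n_folds=5, seed=0):
    """Assign whole subjects to folds, so all augmented graphs of one subject share a fold.

    subject_ids[g] is the subject of graph g; returns a list of graph-index lists.
    """
    subjects = sorted(set(subject_ids))
    random.Random(seed).shuffle(subjects)
    fold_of = {s: i % n_folds for i, s in enumerate(subjects)}
    return [[g for g, s in enumerate(subject_ids) if fold_of[s] == f] for f in range(n_folds)]


def learning_rate(epoch, base_lr=0.001, step=20, factor=0.5):
    """Step decay: start at 0.001 and halve every 20 epochs."""
    return base_lr * factor ** (epoch // step)
```

With 118 subjects augmented 10 times each, splitting by graph index instead would let bootstrapped copies of a test subject leak into training; grouping by subject avoids that.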

4.3 Hyperparameter Discussion and Ablation Study

We tuned the parameters $\lambda_1$ and $\lambda_2$ in the loss function Eq. (7) and show the results in Table 1. $\lambda_1$ encouraged more separable node importance scores for the selected and unselected nodes after pooling. $\lambda_2$ controlled the similarity of the selected nodes for instances within the same class; a larger $\lambda_2$ moves toward group-level interpretation of biomarkers. We first performed an ablation study comparing the settings ($\lambda_1$-$\lambda_2$) = (0-0) and (0.1-0). With either the MMD or the BCE loss, mean accuracies increased by at least 0.029 for TopK and 0.012 for SAGE. To demonstrate the effectiveness of $\lambda_1$, we show in Fig. 3 the distribution of the node pooling scores of the two pooling layers over epochs for different combinations of pooling functions and distance losses. In the early epochs, the scores centered around 0.5. Then the scores of the top $k$ important nodes moved toward 1 and the scores of the unimportant nodes moved toward 0 (less obviously for the second pooling layer using SAGE, which may explain why SAGE achieved lower accuracies than TopK). Hence, significantly higher scores were attributed to the selected important nodes in the pooling layer. Next, we investigated the effect of $\lambda_2$ on accuracy by varying it from 0 to 1, with $\lambda_1$ fixed at 0.1. Without $\lambda_2$, the model overfit the training set more easily, while a larger $\lambda_2$ may result in underfitting. As Table 1 shows, accuracy increased when $\lambda_2$ increased from 0 to 0.1 (except for TopK+MMD) and dropped when $\lambda_2$ was increased to 1. For the following baseline comparison experiments, we fixed $\lambda_1$ and $\lambda_2$ based on these results.

Loss  Pool  0-0           0.1-0         0.1-0.1       0.1-0.5       0.1-1
MMD   TopK  0.753(0.042)  0.784(0.062)  0.781(0.038)  0.780(0.059)  0.744(0.060)
MMD   SAGE  0.751(0.022)  0.770(0.039)  0.771(0.051)  0.773(0.047)  0.751(0.050)
BCE   TopK  0.750(0.046)  0.779(0.053)  0.797(0.051)  0.789(0.066)  0.762(0.044)
BCE   SAGE  0.755(0.041)  0.767(0.033)  0.773(0.047)  0.764(0.050)  0.755(0.041)
Table 1: Model variations and hyperparameter ($\lambda_1$-$\lambda_2$) discussion. Column headers give the $\lambda_1$-$\lambda_2$ setting; entries are classification accuracies.
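The ablation gain from adding the distance loss, (0-0) vs. (0.1-0), can be recomputed directly from the mean accuracies in Table 1; the short script below does only that arithmetic.

```python
# Mean accuracies copied from Table 1 at (lambda1-lambda2) = 0-0 and 0.1-0.
table1 = {
    ("MMD", "TopK"): (0.753, 0.784),
    ("MMD", "SAGE"): (0.751, 0.770),
    ("BCE", "TopK"): (0.750, 0.779),
    ("BCE", "SAGE"): (0.755, 0.767),
}

# Improvement per (loss, pooling) combination, rounded to the table's precision.
gains = {key: round(after - before, 3) for key, (before, after) in table1.items()}
# Smallest improvement over both distance losses, per pooling type.
min_gain = {pool: min(g for (_loss, p), g in gains.items() if p == pool) for pool in ("TopK", "SAGE")}
print(gains)
print(min_gain)
```

This gives a minimum gain of 0.029 for TopK (BCE) and 0.012 for SAGE (BCE).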
Metric / Model  SVM  Random Forest  MLP  BrainNetCNN [14]  Li et al. [17]  PR-GNN
Acc 0.686(0.111) 0.723(0.020) 0.727(0.047) 0.781(0.044) 0.753(0.033)
Par 3k 3k 137k 1438k 16k

Acc: Accuracy; Par: The number of trainable parameters; PR-GNN: TopK+BCE.

Table 2: Comparison with different baseline models.
Figure 3: Distributions of node pooling scores over epochs (offset from far to near).

4.4 Comparison with Existing Models

We compared our method with several brain connectome-based methods, including Random Forest (1000 trees), SVM (RBF kernel), and an MLP (one hidden layer with 20 nodes), a state-of-the-art CNN-based method, BrainNetCNN [14], and a recent GNN method for fMRI [17], in terms of accuracy and number of parameters. For BrainNetCNN, we used the parameter settings indicated in the original paper [14]. The inputs and the architecture parameter settings (node conv, pooling and MLP layers) of the alternative GNN method were the same as for PR-GNN. The inputs of BrainNetCNN were Pearson correlation matrices, and the inputs of the other alternative methods were the flattened upper triangles of the Pearson correlation matrices. Note that the inputs of the GNN models contained both Pearson and partial correlations. For a fair comparison with the non-GNN models, we used the Pearson correlations (node features) as their inputs, because the Pearson correlations were the embedded features, while the partial correlations (edge weights) served only as message-passing filters in the GNN models. The results are shown in Table 2. Our PR-GNN outperformed the alternative models. With the regularization terms on the pooling function, PR-GNN achieved better accuracy than the recent GNN [17]. PR-GNN also needs far fewer trainable parameters than the MLP and BrainNetCNN.

4.5 Biomarker Interpretation

Figure 4: Selected salient ROIs (importance score indicated by yellow-red color) of four randomly selected ASD individuals with different weights on GLC. The commonly detected salient ROIs across different individuals are circled in green.

Without loss of generality, we investigated the selected salient ROIs using the TopK+BCE model with different levels of interpretation by tuning $\lambda_2$. As discussed in Section 3.2, a large $\lambda_2$ leads to group-level interpretation and a small $\lambda_2$ to individual-level interpretation. We varied $\lambda_2$ from 0 to 0.5. Fig. 4 shows the salient ROI detection results of four randomly selected ASD instances: the ROIs remaining after the 2nd pooling layer (with a pooling ratio of 0.5, 21 of the 84 nodes are left) and their node pooling scores. As shown in Fig. 4(a), with the smallest $\lambda_2$ we could rarely find any overlapping areas among the instances. In Fig. 4(b-c), we circled the large overlapping areas across the instances. By visually examining the salient ROIs, we found two overlapping areas in Fig. 4(b) and four in Fig. 4(c). By averaging the node importance scores (1st pooling layer) over all the instances, we found that the dorsal striatum, thalamus and frontal gyrus were the most salient ROIs for identifying ASD. These ROIs are related to the neurological functions of social communication, perception and execution [22, 11, 2, 21], which are known to be affected in ASD.
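A minimal sketch of the group-level ranking step (averaging first-pooling-layer scores over instances and sorting ROIs). The ROI names and scores in the usage below are hypothetical toy values, not the paper's results.

```python
def rank_rois_by_mean_score(scores, roi_names, top=3):
    """Average first-layer pooling scores over instances and return the top ROIs.

    scores: list of per-instance score vectors, each aligned with roi_names.
    """
    n_inst = len(scores)
    mean = [sum(s[r] for s in scores) / n_inst for r in range(len(roi_names))]
    order = sorted(range(len(roi_names)), key=lambda r: -mean[r])
    return [(roi_names[r], mean[r]) for r in order[:top]]


# Toy usage with hypothetical ROI labels "A"-"D" and two instances.
print(rank_rois_by_mean_score([[0.9, 0.1, 0.5, 0.2], [0.7, 0.3, 0.6, 0.1]], ["A", "B", "C", "D"], top=2))
```

Averaging is meaningful only for the first pooling layer, where every instance still has all 84 nodes in the same order; deeper layers keep different node subsets per instance.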

5 Conclusion

In this paper, we propose PR-GNN, an interpretable graph neural network for fMRI analysis. PR-GNN takes graphs built from fMRI as inputs and outputs prediction results together with interpretation results. With its built-in interpretability, PR-GNN not only outperforms alternative methods on classification but also detects salient brain regions for classification. The novel GLC loss term gives us the flexibility to use the same method for individual-level biomarker analysis (small $\lambda_2$) and group-level biomarker analysis (large $\lambda_2$). We believe that this is the first work in fMRI studies that uses a single model to fill the critical interpretation gap between individual- and group-level analysis. Our interpretation results reveal the salient ROIs for distinguishing autistic disorders from healthy controls. Our method has the potential to improve the understanding of neurological disorders and ultimately benefit neuroimaging research. We will extend and validate our methods on larger benchmark datasets in future work.


This research was supported in part by NIH grants [R01NS035193, R01MH100028].


  • [1] J. T. Baker, A. J. Holmes, G. A. Masters, B. T. Yeo, F. Krienen, R. L. Buckner, and D. Öngür (2014) Disruption of cortical association networks in schizophrenia and psychotic bipolar disorder. JAMA Psychiatry 71 (2), pp. 109–118.
  • [2] J. P. Bhanji and M. R. Delgado (2014) The social brain and reward: social information processing in the human striatum. Wiley Interdisciplinary Reviews: Cognitive Science 5 (1), pp. 61–73.
  • [3] M. Defferrard, X. Bresson, and P. Vandergheynst (2016) Convolutional neural networks on graphs with fast localized spectral filtering. In Advances in Neural Information Processing Systems, pp. 3844–3852.
  • [4] R. S. Desikan, F. Ségonne, B. Fischl, B. T. Quinn, B. C. Dickerson, D. Blacker, R. L. Buckner, A. M. Dale, R. P. Maguire, B. T. Hyman, et al. (2006) An automated labeling system for subdividing the human cerebral cortex on MRI scans into gyral based regions of interest. NeuroImage 31 (3), pp. 968–980.
  • [5] I. S. Dhillon, Y. Guan, and B. Kulis (2007) Weighted graph cuts without eigenvectors: a multilevel approach. IEEE Transactions on Pattern Analysis and Machine Intelligence 29 (11), pp. 1944–1957.
  • [6] H. Gao and S. Ji (2019) Graph U-Nets. arXiv preprint arXiv:1905.05178.
  • [7] A. A. Goldani, S. R. Downs, F. Widjaja, B. Lawton, and R. L. Hendren (2014) Biomarkers in autism. Frontiers in Psychiatry 5, pp. 100.
  • [8] L. Gong and Q. Cheng (2019) Exploiting edge features for graph neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9211–9219.
  • [9] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. Smola (2012) A kernel two-sample test. Journal of Machine Learning Research 13 (Mar), pp. 723–773.
  • [10] W. Hamilton, Z. Ying, and J. Leskovec (2017) Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pp. 1024–1034.
  • [11] A. Y. Hardan, R. R. Girgis, J. Adams, A. R. Gilbert, M. S. Keshavan, and N. J. Minshew (2006) Abnormal brain size effect on the thalamus in autism. Psychiatry Research: Neuroimaging 147 (2-3), pp. 145–151.
  • [12] M. D. Kaiser, C. M. Hudac, S. Shultz, S. M. Lee, C. Cheung, A. M. Berken, B. Deen, N. B. Pitskel, D. R. Sugrue, A. C. Voos, et al. (2010) Neural signatures of autism. Proceedings of the National Academy of Sciences 107 (49), pp. 21223–21228.
  • [13] M. D. Kaiser et al. (2010) Neural signatures of autism. Proceedings of the National Academy of Sciences.
  • [14] J. Kawahara, C. J. Brown, S. P. Miller, B. G. Booth, V. Chau, R. E. Grunau, J. G. Zwicker, and G. Hamarneh (2017) BrainNetCNN: convolutional neural networks for brain networks; towards predicting neurodevelopment. NeuroImage 146, pp. 1038–1049.
  • [15] J. Lee, I. Lee, and J. Kang (2019) Self-attention graph pooling. arXiv preprint arXiv:1904.08082.
  • [16] C. Li, W. Chang, Y. Cheng, Y. Yang, and B. Póczos (2017) MMD GAN: towards deeper understanding of moment matching network. In Advances in Neural Information Processing Systems, pp. 2203–2213.
  • [17] X. Li, N. C. Dvornek, Y. Zhou, J. Zhuang, P. Ventola, and J. S. Duncan (2019) Graph neural network for interpreting task-fMRI biomarkers. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 485–493.
  • [18] M. Long, H. Zhu, J. Wang, and M. I. Jordan (2016) Unsupervised domain adaptation with residual transfer networks. In Advances in Neural Information Processing Systems, pp. 136–144.
  • [19] E. McDade, G. Wang, B. A. Gordon, J. Hassenstab, T. L. Benzinger, V. Buckles, A. M. Fagan, D. M. Holtzman, N. J. Cairns, A. M. Goate, et al. (2018) Longitudinal cognitive and biomarker changes in dominantly inherited Alzheimer disease. Neurology 91 (14), pp. e1295–e1306.
  • [20] R. A. Poldrack, Y. O. Halchenko, and S. J. Hanson (2009) Decoding the large-scale structure of brain function by classifying mental states across individuals. Psychological Science 20 (11), pp. 1364–1372.
  • [21] C. Press, N. Weiskopf, and J. M. Kilner (2012) Dissociable roles of human inferior frontal gyrus during action execution and observation. NeuroImage 60 (3), pp. 1671–1677.
  • [22] M. Schuetze, M. T. M. Park, I. Y. Cho, F. P. MacMaster, M. M. Chakravarty, and S. L. Bray (2016) Morphological alterations in the thalamus, striatum, and pallidum in autism spectrum disorder. Neuropsychopharmacology 41 (11), pp. 2627–2637.
  • [23] K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
  • [24] P. Veličković et al. (2018) Graph attention networks. In International Conference on Learning Representations (ICLR).
  • [25] X. Wang, X. Liang, Z. Jiang, B. A. Nguchu, Y. Zhou, Y. Wang, H. Wang, Y. Li, Y. Zhu, F. Wu, et al. (2019) Decoding and mapping task states of the human brain via deep learning. Human Brain Mapping.
  • [26] K. J. Worsley, C. H. Liao, J. Aston, V. Petre, G. Duncan, F. Morales, and A. Evans (2002) A general statistical analysis for fMRI data. NeuroImage 15 (1), pp. 1–15.
  • [27] D. Yang et al. (2016) Brain responses to biological motion predict treatment outcome in young children with autism. Translational Psychiatry 6 (11), pp. e948.
  • [28] H. Yang, X. Li, Y. Wu, S. Li, S. Lu, J. S. Duncan, J. C. Gee, and S. Gu (2019) Interpretable multimodality embedding of cerebral cortex using attention graph network for identifying bipolar disorder. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 799–807.
  • [29] Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec (2018) Hierarchical graph representation learning with differentiable pooling. In Advances in Neural Information Processing Systems, pp. 4800–4810.