PRGNN_fMRI
Pytorch implementation of pooling-regularized GNN (PRGNN) for fMRI analysis. https://arxiv.org/pdf/2007.14589.pdf
Understanding how certain brain regions relate to a specific neurological disorder has been an important area of neuroimaging research. A promising approach to identifying the salient regions is to use Graph Neural Networks (GNNs), which can analyze graph-structured data, e.g., brain networks constructed from functional magnetic resonance imaging (fMRI). We propose an interpretable GNN framework with a novel salient region selection mechanism to determine neurological brain biomarkers associated with disorders. Specifically, we design novel regularized pooling layers that highlight salient regions of interest (ROIs), so that we can infer which ROIs are important for identifying a certain disease from the node pooling scores calculated by the pooling layers. Our proposed framework, Pooling Regularized GNN (PR-GNN), encourages reasonable ROI selection and provides the flexibility to preserve either individual- or group-level patterns. We apply the PR-GNN framework to a Biopoint Autism Spectrum Disorder (ASD) fMRI dataset. We investigate different choices of the hyperparameters and show that PR-GNN outperforms baseline methods in terms of classification accuracy. The salient ROI detection results show high correspondence with previous neuroimaging-derived biomarkers for ASD.
Explaining the underlying roots of neurological disorders (i.e., what brain regions are associated with the disorder) has been a main goal in the field of neuroscience and medicine [12, 7, 1, 19]. Functional magnetic resonance imaging (fMRI), a non-invasive neuroimaging technique that measures neural activation, has been paramount in advancing our understanding of the functional organization of the brain [26, 20, 25]. The functional network of the brain can be modeled as a graph in which each node is a brain region and the edges represent the strength of the connection between those regions.
The past few years have seen the growing prevalence of using graph neural networks (GNNs) for graph classification [10]. Like pooling layers in convolutional neural networks (CNNs) [23, 18], the pooling layer in GNNs is an important design for compressing a large graph into a smaller one for lower-dimensional feature extraction. Many node pooling strategies have been studied, and they can be divided into the following categories: 1) clustering-based pooling, which clusters nodes into a super node based on graph topology [3, 5, 29], and 2) ranking-based pooling, which assigns each node a score and keeps the top-ranked nodes [6, 15]. Clustering-based pooling methods do not preserve the node assignment mapping in the input graph domain, hence they are not inherently interpretable at the node level. For our purpose of interpreting node importance, we focus on ranking-based pooling methods. Existing methods of this type [6, 15] have the following key limitations when applied to salient brain ROI analysis: 1) the ranking scores of the discarded nodes and the remaining nodes may not be significantly distinguishable, which is not suitable for identifying salient and representative regional biomarkers, and 2) the nodes in different graphs of the same group may be ranked totally differently (usually caused by overfitting), which is problematic when our objective is to find group-level biomarkers. To reach group-level analysis, such approaches typically require additional steps to summarize statistics (such as averaging). For these two-stage methods, if the results from the first stage are not reliable, significant errors can be induced in the second stage.
To utilize GNNs for fMRI learning and meet the need for group-level biomarker finding, we propose a pooling regularized GNN framework (PR-GNN) that jointly classifies neurodisorder patients vs. healthy control subjects and discovers disorder-related biomarkers. The overview of our methods is depicted in Fig. 1. Our key contributions are:
The architecture of our PR-GNN is shown in Fig. 2. Below, we introduce the notation and the layers in PR-GNN. For simplicity, we focus on Graph Attention Convolution (GATConv) [24, 28] as the node convolutional layer. For node pooling layers, we test two existing ranking-based pooling methods: TopK pooling [6] and SAGE pooling [15].
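We first parcellate the brain into $N$ ROIs based on its T1 structural MRI. We define the ROIs as graph nodes $\mathcal{V} = \{v_1, \dots, v_N\}$. We define an undirected weighted graph as $G = (\mathcal{V}, \mathcal{E})$, where $\mathcal{E}$ is the edge set, i.e., a collection of pairs $(v_i, v_j)$ linking vertices $v_i$ and $v_j$. $G$ has an associated node feature matrix $\mathbf{H} = [\mathbf{h}_1, \dots, \mathbf{h}_N]^\top$, where $\mathbf{h}_i$ is the feature vector associated with node $v_i$. For every edge connecting two nodes, $(v_i, v_j) \in \mathcal{E}$, we have its strength $e_{ij} > 0$. We also define $e_{ij} = 0$ for $(v_i, v_j) \notin \mathcal{E}$, and therefore the adjacency matrix $\mathbf{E} = [e_{ij}] \in \mathbb{R}^{N \times N}$ is well defined.
To improve GATConv [10], we incorporate edge features in the brain graph as suggested by Gong and Cheng [8] and Yang et al. [28]. We define $\mathbf{h}_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ as the feature of the $i$-th node in the $l$-th layer and $\mathbf{H}^{(l)} \in \mathbb{R}^{N^{(l)} \times d^{(l)}}$, where $N^{(l)}$ is the number of nodes at the $l$-th layer (the same holds for $\mathbf{E}^{(l)}$). The propagation model for the forward-pass update of the node representation is calculated as:
$$\tilde{\mathbf{h}}_i^{(l+1)} = \operatorname{relu}\Big(\mathbf{W}^{(l)}\mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}^{(l)}(i)} \alpha_{ij}^{(l)} e_{ij}^{(l)} \mathbf{W}^{(l)}\mathbf{h}_j^{(l)}\Big) \tag{1}$$
where the attention coefficients $\alpha_{ij}^{(l)}$ are computed as
$$\alpha_{ij}^{(l)} = \frac{\exp\big(\operatorname{LeakyReLU}(\mathbf{a}^{(l)\top}[\mathbf{W}^{(l)}\mathbf{h}_i^{(l)} \,\|\, \mathbf{W}^{(l)}\mathbf{h}_j^{(l)}])\big)}{\sum_{k \in \mathcal{N}^{(l)}(i)} \exp\big(\operatorname{LeakyReLU}(\mathbf{a}^{(l)\top}[\mathbf{W}^{(l)}\mathbf{h}_i^{(l)} \,\|\, \mathbf{W}^{(l)}\mathbf{h}_k^{(l)}])\big)} \tag{2}$$
where $\mathcal{N}^{(l)}(i)$ denotes the set of indices of neighboring nodes of $v_i$, $\|$ denotes concatenation, and $\mathbf{W}^{(l)}$ and $\mathbf{a}^{(l)}$ are model parameters.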
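A minimal NumPy sketch of one edge-weighted attention layer may help make the message passing concrete. It assumes the form in which each neighbour message $\mathbf{W}\mathbf{h}_j$ is scaled by both the attention coefficient $\alpha_{ij}$ (a softmax over the neighbours of $v_i$) and the edge strength $e_{ij}$, plus a self term $\mathbf{W}\mathbf{h}_i$. The function name `edge_weighted_gat`, the LeakyReLU slope of 0.2 and the random toy graph are illustrative assumptions, not the released PR-GNN code; node features are stored as rows, so $\mathbf{W}\mathbf{h}_i$ becomes `H @ W`.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)


def edge_weighted_gat(H, E, W, a):
    """One edge-weighted graph attention layer (illustrative sketch).

    H: (N, d_in) node features, E: (N, N) non-negative edge strengths (0 = no edge),
    W: (d_in, d_out) shared projection, a: (2 * d_out,) attention vector.
    Returns the updated features (N, d_out) and the attention matrix (N, N).
    """
    Wh = H @ W                                   # row i holds W h_i
    d_out = Wh.shape[1]
    # a^T [W h_i || W h_j] splits into a "source" part and a "target" part.
    logits = leaky_relu((Wh @ a[:d_out])[:, None] + (Wh @ a[d_out:])[None, :])
    logits = np.where(E > 0, logits, -np.inf)    # attend only over neighbours N(i)
    row_max = logits.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)  # rows with no neighbours
    alpha = np.exp(logits - row_max)             # softmax numerator; 0 outside N(i)
    denom = alpha.sum(axis=1, keepdims=True)
    alpha = np.divide(alpha, denom, out=np.zeros_like(alpha), where=denom > 0)
    messages = (alpha * E) @ Wh                  # sum_j alpha_ij * e_ij * W h_j
    return np.maximum(Wh + messages, 0.0), alpha  # relu(self term + messages)


# Toy example: 5 nodes, fully connected with random positive edge strengths.
rng = np.random.default_rng(0)
N, d_in, d_out = 5, 5, 4
H = rng.normal(size=(N, d_in))
A = np.abs(rng.normal(size=(N, N)))
E = (A + A.T) / 2
np.fill_diagonal(E, 0.0)                         # no self edges, e_ii = 0
out, alpha = edge_weighted_gat(H, E, rng.normal(size=(d_in, d_out)), rng.normal(size=2 * d_out))
```

Masking the logits with `-inf` before the softmax is what restricts the attention to the neighbourhood, and multiplying `alpha` by `E` is what lets the connection strength modulate each message.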
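The choice of which nodes to keep in TopK pooling and SAGE pooling is determined by the node importance score $\mathbf{s}$, which is calculated in one of two ways:
$$\mathbf{s}^{(l+1)} = \begin{cases} \tilde{\mathbf{H}}^{(l+1)}\mathbf{w}^{(l+1)} / \|\mathbf{w}^{(l+1)}\|_2, & \text{TopK} \\ \operatorname{GATConv}\big(\tilde{\mathbf{H}}^{(l+1)}, \mathbf{W}_{att}^{(l+1)}\big), & \text{SAGE} \end{cases} \tag{3}$$
where $\tilde{\mathbf{H}}^{(l+1)}$ stacks the $\tilde{\mathbf{h}}_i^{(l+1)}$ calculated in Eq. (1), and $\mathbf{w}^{(l+1)}$ and $\mathbf{W}_{att}^{(l+1)}$ are model parameters. Note that $\mathbf{W}_{att}^{(l+1)}$ is different from $\mathbf{W}^{(l)}$ in Eq. (1), such that the output of $\operatorname{GATConv}$ is a scalar for each node.
Given $\mathbf{s}^{(l+1)}$, the following equation roughly describes the pooling procedure:
$$\mathbf{i} = \operatorname{topk}(\mathbf{s}^{(l+1)}, k), \quad \mathbf{H}^{(l+1)} = \big(\tilde{\mathbf{H}}^{(l+1)} \odot \operatorname{sigmoid}(\mathbf{s}^{(l+1)})\big)_{\mathbf{i},:}, \quad \mathbf{E}^{(l+1)} = \mathbf{E}^{(l)}_{\mathbf{i},\mathbf{i}} \tag{4}$$
The notation $\operatorname{topk}(\mathbf{s}, k)$ finds the indices corresponding to the $k$ largest elements in the score vector $\mathbf{s}$, and $(\cdot)_{\mathbf{i},\mathbf{j}}$ is an indexing operation that takes elements at row indices specified by $\mathbf{i}$ and column indices specified by $\mathbf{j}$. Because the features are gated by the sigmoid of the scores, nodes receiving lower scores retain less of their features.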
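The TopK branch of the scoring and pooling steps can be sketched in a few lines of NumPy. This is a hedged illustration, assuming the score is the projection $\tilde{\mathbf{H}}\mathbf{w}/\|\mathbf{w}\|_2$ and that the kept node features are gated by the sigmoid of the score; the SAGE variant, which obtains its scores from a GATConv layer, is not shown. The 84-node toy graph and 16-dimensional features are hypothetical, and the GNN layer that would normally sit between the two pooling steps is skipped for brevity.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def topk_pool(H_tilde, E, w, ratio=0.5):
    """TopK-style pooling: score nodes, keep the top ones, gate and re-index."""
    s = H_tilde @ w / np.linalg.norm(w)               # one scalar score per node
    k = int(np.ceil(ratio * H_tilde.shape[0]))        # number of nodes to keep
    idx = np.argsort(-s, kind="stable")[:k]           # indices of the k largest scores
    H_next = (H_tilde * sigmoid(s)[:, None])[idx, :]  # gate features, keep rows idx
    E_next = E[np.ix_(idx, idx)]                      # induced sub-graph on kept nodes
    return H_next, E_next, s, idx


rng = np.random.default_rng(0)
N, d = 84, 16
H0 = rng.normal(size=(N, d))
A = np.abs(rng.normal(size=(N, N)))
E0 = (A + A.T) / 2
np.fill_diagonal(E0, 0.0)
# Two pooling steps with ratio 0.5: 84 -> 42 -> 21 nodes.
H1, E1, s1, idx1 = topk_pool(H0, E0, rng.normal(size=d))
H2, E2, s2, idx2 = topk_pool(H1, E1, rng.normal(size=d))
```

Because `idx` is kept, each surviving row can be traced back to an input ROI, which is what makes ranking-based pooling interpretable at the node level.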
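Lastly, we seek a “flattening” operation to translate graph information into a vector. Suppose the last layer is $l = L$; we use $\mathbf{z} = \operatorname{mean}\{\mathbf{H}^{(L)}\}$, where the mean is taken over the nodes, elementwise for each feature dimension. Then $\mathbf{z}$ is sent to a multilayer perceptron (MLP) to give the final prediction.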
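A sketch of the readout and classifier follows, assuming a mean over nodes, ReLU hidden activations and a softmax output (the activations are assumptions). The 16-8-2 layer widths follow the experimental setup described for this model, while the input dimension `d = 32` and the random weights are hypothetical.

```python
import numpy as np


def readout(H_last):
    """Flatten a pooled graph by averaging node features (mean over the node axis)."""
    return H_last.mean(axis=0)


def mlp_predict(z, layers):
    """Apply dense layers with ReLU in between; softmax over the 2 output classes."""
    for i, (W, b) in enumerate(layers):
        z = z @ W + b
        if i < len(layers) - 1:
            z = np.maximum(z, 0.0)
    e = np.exp(z - z.max())
    return e / e.sum()


rng = np.random.default_rng(0)
d = 32                                   # hypothetical feature size after the last GNN block
sizes = [d, 16, 8, 2]                    # 16, 8 and 2 neurons per layer
layers = [(rng.normal(scale=0.1, size=(m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]
H_last = rng.normal(size=(21, d))        # e.g. 21 nodes left after two pooling layers
probs = mlp_predict(readout(H_last), layers)   # class probabilities (label order arbitrary here)
```

Averaging over nodes makes `z` independent of the order in which the pooled nodes are stored, and its size does not depend on how many nodes survive pooling.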
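To overcome the limitation of existing methods that the ranking scores of the discarded nodes and the remaining nodes may not be distinguishable, we propose two distance losses to encourage the difference. Before introducing them, we first rank the elements of the sigmoid-normalized scores of the $m$-th instance, $\tilde{\mathbf{s}}_m = \operatorname{sigmoid}(\mathbf{s}_m^{(l)})$, in descending order, denote the sorted vector as $\hat{\mathbf{s}}_m$, and denote its top $k$ elements as $\hat{\mathbf{s}}_{m,1:k}$ and the remaining elements as $\hat{\mathbf{s}}_{m,k+1:N}$, where $N$ is the number of nodes at layer $l$. We apply two types of constraint to all the training instances.
Ideally, the scores for the selected nodes should be close to 1 and the scores for the unselected nodes should be close to 0. Over $M$ training instances, the binary cross entropy (BCE) loss is calculated as:
$$L_{BCE}^{(l)} = -\frac{1}{M}\sum_{m=1}^{M}\frac{1}{N}\Big(\sum_{i=1}^{k}\log\hat{s}_{m,i} + \sum_{i=k+1}^{N}\log\big(1 - \hat{s}_{m,i}\big)\Big) \tag{5}$$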
The effect of this constraint will be shown in Section 4.3.
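The BCE distance constraint can be sketched as follows, assuming `scores` holds sigmoid-normalized node scores (one row per training instance) and `k` is the number of nodes kept by the pooling layer. The clipping constant `eps` is an added numerical safeguard, not part of the original formulation.

```python
import numpy as np


def bce_distance_loss(scores, k, eps=1e-7):
    """BCE distance loss between the top-k and the remaining sorted scores.

    scores: (M, N) sigmoid-normalized node scores, one row per training instance.
    k: number of nodes kept by the pooling layer.
    """
    s_hat = -np.sort(-scores, axis=1)                 # sort each row in descending order
    s_hat = np.clip(s_hat, eps, 1.0 - eps)            # guard against log(0)
    kept = np.log(s_hat[:, :k]).sum(axis=1)           # top-k scores pushed toward 1
    dropped = np.log(1.0 - s_hat[:, k:]).sum(axis=1)  # remaining scores pushed toward 0
    n_nodes = scores.shape[1]
    return -np.mean((kept + dropped) / n_nodes)


separated = np.array([[0.99, 0.02, 0.98, 0.01]])
ambiguous = np.array([[0.55, 0.50, 0.52, 0.48]])
loss_sep = bce_distance_loss(separated, k=2)   # small: kept and dropped scores are far apart
loss_amb = bce_distance_loss(ambiguous, k=2)   # large: all scores hover around 0.5
```

When every score equals 0.5 the loss is exactly $\log 2$, the value the text describes for early epochs, and it falls toward 0 as the selected and unselected scores separate.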
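Note that $\mathbf{s}$ in Eq. (4) is computed from the input $\mathbf{H}$. Therefore, for $\mathbf{s}$ from different instances, the ranking of the entries of $\mathbf{s}$ can be very different. For our application, we want to find the common patterns/biomarkers for a certain neuro-prediction task. Thus, we add regularization to force the $\tilde{\mathbf{s}}$ vectors to be similar for different input instances in the first pooling layer, where the group-level biomarkers are extracted. We call this novel regularization group-level consistency (GLC) and apply it only to the first pooling layer, as the nodes kept in the following layers may differ across instances. Suppose there are $M_c$ instances of class $c$ in a batch, with index set $\mathcal{I}_c$, where $c \in \{1, \dots, C\}$ and $C$ is the number of classes. We form the scoring matrix $\mathbf{S}^{c} \in \mathbb{R}^{M_c \times N}$ whose rows are the vectors $\tilde{\mathbf{s}}_m^{(1)}$, $m \in \mathcal{I}_c$. The GLC loss can be expressed as:
$$L_{GLC} = \sum_{c=1}^{C}\sum_{m,n \in \mathcal{I}_c}\big\|\tilde{\mathbf{s}}_m^{(1)} - \tilde{\mathbf{s}}_n^{(1)}\big\|_2^2 = 2\sum_{c=1}^{C}\operatorname{Tr}\big((\mathbf{S}^{c})^\top \mathbf{L}^{c}\mathbf{S}^{c}\big) \tag{6}$$
where $\mathbf{L}^{c} = \mathbf{D}^{c} - \mathbf{W}^{c}$, $\mathbf{W}^{c}$ is an $M_c \times M_c$ matrix with all 1s, and $\mathbf{D}^{c}$ is an $M_c \times M_c$ diagonal matrix with $M_c$ as diagonal elements. We propose to use the Euclidean distance between score vectors due to the benefits of convexity and computational efficiency.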
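To see why a pairwise form and a trace form of a group-consistency loss agree, the sketch below computes both for random score matrices (one matrix per class, one row per instance). It assumes squared Euclidean distances summed over ordered pairs $(m, n)$, for which $\sum_{m,n}\|\mathbf{s}_m - \mathbf{s}_n\|^2 = 2\,\mathrm{Tr}\big(\mathbf{S}^\top(M_c\mathbf{I} - \mathbf{1}\mathbf{1}^\top)\mathbf{S}\big)$; the function names and the toy class sizes are hypothetical.

```python
import numpy as np


def glc_loss_trace(S_by_class):
    """GLC loss in trace form: 2 * sum_c Tr(S_c^T L_c S_c), with L_c = D_c - W_c."""
    total = 0.0
    for S in S_by_class:                              # S: (M_c, N), one row per instance
        M_c = S.shape[0]
        L = M_c * np.eye(M_c) - np.ones((M_c, M_c))   # D_c = M_c * I, W_c = all-ones
        total += 2.0 * np.trace(S.T @ L @ S)
    return total


def glc_loss_pairwise(S_by_class):
    """Same quantity as a sum of squared distances over ordered pairs (m, n) in each class."""
    total = 0.0
    for S in S_by_class:
        diff = S[:, None, :] - S[None, :, :]          # (M_c, M_c, N) pairwise differences
        total += np.sum(diff ** 2)
    return total


rng = np.random.default_rng(0)
scores_by_class = [rng.random((5, 84)), rng.random((3, 84))]   # e.g. 5 and 3 instances
consistent = [np.tile(rng.random(84), (4, 1))]                 # identical scores in one class
```

The trace form needs only one matrix product per class instead of materializing all pairwise differences, and it is zero exactly when every instance of a class has the same score vector.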
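Cross entropy loss $L_{CE}$ is used for the final prediction. Then, the final loss function is formed as:
$$L_{total} = L_{CE} + \lambda_1\sum_{l=1}^{L} L_{dist}^{(l)} + \lambda_2 L_{GLC} \tag{7}$$
where the $\lambda$'s are tunable hyper-parameters, $l$ indicates the $l$-th GNN block, $L$ is the total number of GNN blocks, and $L_{dist}$ is either MMD or BCE.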
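As a small worked example of how the terms combine, the sketch below assumes a single $\lambda_1$ shared by all GNN blocks and reads a setting label such as "0.1-0.5" as $\lambda_1 = 0.1$, $\lambda_2 = 0.5$ (the convention used for the ablation settings). The loss values are made-up numbers for illustration.

```python
def parse_setting(label):
    """Read a setting label such as "0.1-0.5" as (lambda1, lambda2)."""
    lambda1, lambda2 = (float(x) for x in label.split("-"))
    return lambda1, lambda2


def total_loss(ce, dist_per_block, glc, lambda1, lambda2):
    """Cross entropy + lambda1 * distance losses of all GNN blocks + lambda2 * GLC."""
    return ce + lambda1 * sum(dist_per_block) + lambda2 * glc


# Made-up loss values for a model with two GNN blocks.
ce, dist, glc = 0.7, [0.2, 0.3], 1.5
baseline = total_loss(ce, dist, glc, *parse_setting("0-0"))        # plain cross entropy
regularized = total_loss(ce, dist, glc, *parse_setting("0.1-0.1"))  # 0.7 + 0.1*0.5 + 0.1*1.5
```

The "0-0" setting reduces the objective to plain cross entropy, which is why it serves as the no-regularization baseline in the ablation.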
We collected fMRI data from a group of 75 ASD children and 43 age- and IQ-matched healthy controls (HC), acquired under the “biopoint” task [13]. The fMRI data were preprocessed following the pipeline in Yang et al. [27]. The Desikan-Killiany [4] atlas was used to parcellate brain images into 84 ROIs. The mean time series for each node was extracted from a random subset of voxels in the ROI by bootstrapping. In this way, we augmented the data 10 times. Edges were defined by the top positive partial correlations to achieve sparse connections. If this led to isolated nodes, we added back the largest edge to each of them. For the attributes of node $i$, we used its Pearson correlation coefficients to every node $j \in \{1, \dots, 84\}$. Pearson correlation and partial correlation are different measures of fMRI connectivity; we combine them by using one to build edge connections and the other to build node features.
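The graph construction described above can be sketched as follows. This is an illustrative NumPy version under stated assumptions: partial correlations are taken from the (pseudo-)inverse covariance, `edge_density=0.1` (keep the top 10% of positive partial correlations) and `frac=0.1` (fraction of voxels per bootstrap draw) are hypothetical values, and an isolated node is reconnected through its largest partial correlation, using its magnitude so the edge weight stays positive. The synthetic voxel data stand in for real fMRI.

```python
import numpy as np


def bootstrap_roi_series(voxel_ts, labels, frac, rng):
    """One bootstrap draw: mean time series of a random subset of voxels in each ROI.

    voxel_ts: (T, V) voxel time series; labels: (V,) ROI index (0..N-1) of each voxel.
    """
    n_rois = int(labels.max()) + 1
    out = np.empty((voxel_ts.shape[0], n_rois))
    for r in range(n_rois):
        vox = np.flatnonzero(labels == r)
        n_pick = max(1, int(round(frac * vox.size)))
        pick = rng.choice(vox, size=n_pick, replace=False)
        out[:, r] = voxel_ts[:, pick].mean(axis=1)
    return out


def build_brain_graph(roi_ts, edge_density=0.1):
    """Node features from Pearson correlations, sparse edges from positive partial correlations."""
    X = np.corrcoef(roi_ts, rowvar=False)              # (N, N); row i = features of node i
    P = np.linalg.pinv(np.cov(roi_ts, rowvar=False))   # precision (inverse covariance) matrix
    d = np.sqrt(np.diag(P))
    partial = -P / np.outer(d, d)                      # partial correlation matrix
    partial = (partial + partial.T) / 2                # remove tiny numerical asymmetry
    np.fill_diagonal(partial, 0.0)
    iu = np.triu_indices_from(partial, k=1)
    pos = partial[iu][partial[iu] > 0]
    thresh = np.quantile(pos, 1.0 - edge_density) if pos.size else np.inf
    E = np.where(partial >= thresh, partial, 0.0)      # keep the strongest positive edges
    for i in range(E.shape[0]):
        if (E[i] > 0).any():
            continue                                   # node i is already connected
        row = partial[i].copy()
        row[i] = -np.inf
        j = int(np.argmax(row))                        # largest edge of the isolated node
        E[i, j] = E[j, i] = abs(partial[i, j])         # magnitude keeps e_ij > 0
    return X, E


rng = np.random.default_rng(0)
voxel_ts = rng.normal(size=(120, 200))   # synthetic: 120 time points, 200 voxels
labels = np.arange(200) % 10             # synthetic parcellation into 10 ROIs
graphs = [build_brain_graph(bootstrap_roi_series(voxel_ts, labels, 0.1, rng)) for _ in range(10)]
X, E = graphs[0]                         # first of 10 augmented graphs from one subject
```

Each bootstrap draw averages a different voxel subset, so the 10 graphs of one subject differ slightly while sharing the same label, which is the augmentation the text describes.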
The model architecture was implemented with 2 conv layers and 2 pooling layers as shown in Fig. 2. We designed a 3-layer MLP (with 16, 8 and 2 neurons in each layer) that takes the flattened graph $\mathbf{z}$ as input and predicts ASD vs. HC. Each pooling layer kept the top 50% most important nodes ($k = 0.5N^{(l)}$). We will discuss the variation of $\lambda_1$ and $\lambda_2$ in Section 4.3. We randomly split the data into five folds based on subjects, which means that the graphs from a single subject can appear in either the training or the test set, but not both. Four folds were used as training data, and the left-out fold was used for testing. Adam was used as the optimizer. We trained the model for 100 epochs with an initial learning rate of 0.001, halved every 20 epochs. We set the parameter of the MMD loss so that it matches the scale of the BCE loss.
We tuned the parameters $\lambda_1$ and $\lambda_2$ in the loss function Eq. (7) and show the results in Table 1. $\lambda_1$ encouraged more separable node importance scores for selected and unselected nodes after pooling. $\lambda_2$ controlled the similarity of the selected nodes for instances within the same class; a larger $\lambda_2$ moves toward group-level interpretation of biomarkers. We first performed an ablation study comparing the settings (0-0) and (0.1-0). Mean accuracies increased by at least 2.9 percentage points in TopK (1.2 in SAGE) with either the MMD or the BCE loss. To demonstrate the effectiveness of $\lambda_1$, we show the distribution of the node pooling scores of the two pooling layers over epochs in Fig. 3, for different combinations of pooling functions and distance losses. In the early epochs, the scores centered around 0.5. Then the scores of the top important nodes moved to 1 and the scores of unimportant nodes moved to 0 (less obvious for the second pooling layer using SAGE, which may explain why SAGE got lower accuracies than TopK). Hence, significantly higher scores were attributed to the selected important nodes in the pooling layer. Then, we investigated the effect of $\lambda_2$ on the accuracy by varying it from 0 to 1, with $\lambda_1$ fixed at 0.1. Without $\lambda_2$, the model overfit the training set more easily, while a larger $\lambda_2$ may result in underfitting. As the results in Table 1 show, the accuracy increased when $\lambda_2$ increased from 0 to 0.1 (except for TopK+MMD), and it dropped for every combination when we increased $\lambda_2$ to 1. For the following baseline comparison experiments, we used a fixed $\lambda_1$-$\lambda_2$ setting.
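The subject-level five-fold split and the step learning-rate schedule can be sketched as below. Splitting by subject keeps all 10 augmented graphs of a subject in the same fold, so no subject leaks between training and test data. The helper names are hypothetical; in practice, scikit-learn's `GroupKFold` and PyTorch's `torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)` provide equivalent functionality.

```python
import numpy as np


def subject_level_folds(subject_ids, n_folds=5, seed=0):
    """Yield (train_idx, test_idx) so that all graphs of a subject share one fold."""
    subjects = np.unique(subject_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(subjects)                                # random subject-to-fold assignment
    fold_of = {s: i % n_folds for i, s in enumerate(subjects)}
    folds = np.array([fold_of[s] for s in subject_ids])
    for k in range(n_folds):
        yield np.flatnonzero(folds != k), np.flatnonzero(folds == k)


def learning_rate(epoch, base_lr=1e-3, step=20):
    """Initial rate 0.001, halved every 20 epochs."""
    return base_lr * 0.5 ** (epoch // step)


subject_ids = np.repeat(np.arange(75 + 43), 10)     # 118 subjects x 10 augmented graphs
folds = list(subject_level_folds(subject_ids))
```

Splitting at the graph level instead would place bootstrap copies of the same subject on both sides of the split and inflate the test accuracy.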
Loss | Pool | λ1-λ2 = 0-0 | 0.1-0 | 0.1-0.1 | 0.1-0.5 | 0.1-1 |
---|---|---|---|---|---|---|
MMD | TopK | 0.753(0.042) | 0.784(0.062) | 0.781(0.038) | 0.780(0.059) | 0.744(0.060) |
MMD | SAGE | 0.751(0.022) | 0.770(0.039) | 0.771(0.051) | 0.773(0.047) | 0.751(0.050) |
BCE | TopK | 0.750(0.046) | 0.779(0.053) | 0.797(0.051) | 0.789(0.066) | 0.762(0.044) |
BCE | SAGE | 0.755(0.041) | 0.767(0.033) | 0.773(0.047) | 0.764(0.050) | 0.755(0.041) |
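The ablation statements can be checked directly against Table 1. The snippet below copies the mean accuracies from the table and computes the change from (0-0) to (0.1-0), the effect of raising $\lambda_2$ with $\lambda_1 = 0.1$, and the best cell; it adds no new results.

```python
# Mean accuracies copied from Table 1; columns are the lambda1-lambda2 settings.
settings = ["0-0", "0.1-0", "0.1-0.1", "0.1-0.5", "0.1-1"]
acc = {
    ("MMD", "TopK"): [0.753, 0.784, 0.781, 0.780, 0.744],
    ("MMD", "SAGE"): [0.751, 0.770, 0.771, 0.773, 0.751],
    ("BCE", "TopK"): [0.750, 0.779, 0.797, 0.789, 0.762],
    ("BCE", "SAGE"): [0.755, 0.767, 0.773, 0.764, 0.755],
}
col = {s: i for i, s in enumerate(settings)}

# Ablation: gain from adding the distance loss, (0-0) -> (0.1-0).
gain = {key: round(v[col["0.1-0"]] - v[col["0-0"]], 3) for key, v in acc.items()}
min_gain_topk = min(g for (loss, pool), g in gain.items() if pool == "TopK")
min_gain_sage = min(g for (loss, pool), g in gain.items() if pool == "SAGE")

# Effect of lambda2 with lambda1 fixed at 0.1.
not_improved = [key for key, v in acc.items() if v[col["0.1-0.1"]] <= v[col["0.1-0"]]]
dropped_at_1 = [key for key, v in acc.items() if v[col["0.1-1"]] < v[col["0.1-0.1"]]]

# Best (accuracy, configuration, setting) over the whole table.
best = max((v[i], key, settings[i]) for key, v in acc.items() for i in range(len(settings)))
```

The minimum gains are 0.029 for TopK and 0.012 for SAGE, TopK+MMD is the only combination that does not improve when $\lambda_2$ goes from 0 to 0.1, and all four combinations drop at $\lambda_2 = 1$.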
Metric | SVM | Random Forest | MLP | BrainNetCNN [14] | Li et al. [17] | PR-GNN |
---|---|---|---|---|---|---|
Acc | 0.686(0.111) | 0.723(0.020) | 0.727(0.047) | 0.781(0.044) | 0.753(0.033) | |
Par | 3k | 3k | 137k | 1438k | 16k | |
Acc: Accuracy; Par: The number of trainable parameters; PR-GNN: TopK+BCE.
We compared our method with several brain connectome-based methods, namely Random Forest (1000 trees), SVM (RBF kernel) and an MLP (one hidden layer with 20 nodes); a state-of-the-art CNN-based method, BrainNetCNN [14]; and a recent GNN method on fMRI [17], in terms of accuracy and number of parameters. For BrainNetCNN, we used the parameter settings indicated in the original paper [14]. The inputs and the architecture parameter settings (node conv, pooling and MLP layers) of the alternative GNN method were the same as for PR-GNN. The inputs of BrainNetCNN were Pearson correlation matrices. The inputs of the other alternative methods were the flattened upper triangles of the Pearson correlation matrices. Note that the inputs of the GNN models contained both Pearson and partial correlations. For a fair comparison with the non-GNN models, we used Pearson correlations (node features) as their inputs, because Pearson correlations were the embedded features, while partial correlations (edge weights) only served as message-passing filters in the GNN models. The results are shown in Table 2. Our PR-GNN outperformed the alternative models. With regularization terms on the pooling function, PR-GNN achieved better accuracy than the recent GNN [17]. Also, PR-GNN needs only a small fraction of the parameters of the MLP and of BrainNetCNN.
Without loss of generality, we investigated the selected salient ROIs using the TopK+BCE model with different levels of interpretation by tuning $\lambda_2$. As discussed in Section 3.2, a large $\lambda_2$ led to group-level interpretation and a small $\lambda_2$ led to individual-level interpretation. We varied $\lambda_2$ from 0 to 0.5. We show the salient ROI detection results of four randomly selected ASD instances in Fig. 4: the ROIs remaining after the 2nd pooling layer (with a pooling ratio of 0.5, 21 of the 84 nodes are left) and the corresponding node pooling scores. As shown in Fig. 4(a), with the smallest $\lambda_2$ we could rarely find any overlapping area among the instances. In Fig. 4(b-c), we circled the large overlapping areas across the instances. By visually examining the salient ROIs, we found two overlapping areas in Fig. 4(b) and four in Fig. 4(c). By averaging the node importance scores of the 1st pooling layer over all the instances, we found that the dorsal striatum, thalamus and frontal gyrus were the most salient ROIs associated with identifying ASD. These ROIs are related to the neurological functions of social communication, perception and execution [22, 11, 2, 21], which are known to be impaired in ASD.
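The group-level ranking step (averaging first-pooling-layer scores over instances) and the overlap of per-instance top-$k$ sets can be sketched as follows. The score matrix and the ROI names `roi_a` to `roi_f` are toy placeholders, not values from Fig. 4, and the helper names are hypothetical.

```python
import numpy as np


def rank_rois(scores, roi_names, top):
    """Average first-pooling-layer scores over instances and return the top ROIs."""
    mean_scores = scores.mean(axis=0)
    order = np.argsort(-mean_scores, kind="stable")[:top]
    return [roi_names[i] for i in order]


def shared_topk(scores, k):
    """Node indices that appear in the top-k set of every instance."""
    sets = [set(np.argsort(-s, kind="stable")[:k].tolist()) for s in scores]
    return set.intersection(*sets)


# Toy scores for 4 instances x 6 ROIs (placeholders, not values from Fig. 4).
roi_names = ["roi_a", "roi_b", "roi_c", "roi_d", "roi_e", "roi_f"]
scores = np.array([
    [0.10, 0.20, 0.90, 0.30, 0.80, 0.40],
    [0.20, 0.10, 0.95, 0.50, 0.70, 0.30],
    [0.30, 0.60, 0.85, 0.20, 0.90, 0.10],
    [0.10, 0.40, 0.80, 0.30, 0.75, 0.20],
])
```

A larger shared top-$k$ set is the quantitative counterpart of the overlapping areas circled in Fig. 4; with a strong GLC term, the averaged ranking and the per-instance rankings should agree more closely.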
In this paper, we propose PR-GNN, an interpretable graph neural network for fMRI analysis. PR-GNN takes graphs built from fMRI as inputs, then outputs prediction results together with interpretation results. With its built-in interpretability, PR-GNN not only performs better on classification than alternative methods, but also detects salient brain regions for classification. The novel GLC loss term gives us the flexibility to use the same method for individual-level biomarker analysis (small $\lambda_2$) and group-level biomarker analysis (large $\lambda_2$). We believe that this is the first work using a single model in an fMRI study that fills the critical interpretation gap between individual- and group-level analysis. Our interpretation results reveal the salient ROIs that distinguish autistic disorders from healthy controls. Our method has the potential to advance the understanding of neurological disorders and, ultimately, to benefit neuroimaging research. We will extend and validate our methods on larger benchmark datasets in future work.
This research was supported in part by NIH grants [R01NS035193, R01MH100028].
Weighted graph cuts without eigenvectors: a multilevel approach. IEEE Transactions on Pattern Analysis and Machine Intelligence 29 (11), pp. 1944–1957.
In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9211–9219.
Journal of Machine Learning Research 13 (Mar), pp. 723–773.
MMD GAN: towards deeper understanding of moment matching network. In Advances in Neural Information Processing Systems, pp. 2203–2213.
Decoding and mapping task states of the human brain via deep learning. Human Brain Mapping.