Functional connectivity, often computed as the cross-correlation among brain regions of interest (ROIs), is a powerful approach that has been shown to be informative for classifying brain disorders and revealing putative biomarkers relevant to the underlying disorder 10.1093/brain/awn018; Lynall9477; https://doi.org/10.1002/hbm.23524; 10.1007/978-3-030-59728-3_52. Inferring and using functional connectivity from spatio-temporal data, e.g. functional magnetic resonance imaging (fMRI), has been an especially important area of research in recent times. Functional connectivity can improve our understanding of brain dynamics and improve classification accuracy for brain disorders such as schizophrenia. Recent work (yan2017discriminating) uses functional network connectivity (FNC) as features to predict schizophrenia-related changes. PARISOT2018117, in contrast, computes functional connectivity with a fixed formula, taking phenotypic and imaging data as inputs, and extracts graph features for the classification of Alzheimer's disease (AD) and autism. BrainNetCNN also uses connection strength between brain regions as edges, typically defined as the number of white-matter tracts connecting the regions. 10.1007/978-3-319-66182-7_54 employs spectral graph theory to learn similarity metrics among functional connectivity networks.
These papers, as well as many others, have shown the efficacy of functional connectivity and of feature extraction based on neural network models. However, the results of existing studies, whether in classification accuracy, feature extraction, or learning brain dynamics, often depend heavily on the underlying method of functional connectivity estimation. Studies like RASHID2016645; Saha2020.06.24.161745; SALMAN2019101747 depend on hand-crafted features based on methods like independent component analysis (ICA). These studies perform very well on classification but do not learn a sparse graph and are not helpful for identifying biomarkers in the brain.
Many functional connectivity studies on brain disorders 10.3389/fnins.2018.00525 utilize ROIs predefined from anatomical or functional atlases, which are either fixed for all subjects or based on group differences.
These approaches ignore the possibility of inter-subject variations of ROIs, especially the variations due to the underlying disease conditions. They also rely on the complete set of these ROIs discounting the possibility that only a small subset may be important at a time. A disorder can have varying symptoms for different people, hence making it crucial to determine disorder and subject specific ROIs.
In this work, we address the problems of using a fixed method of learning functional connectivity and of using it as a fixed graph to represent brain structure (the standard practices) by utilizing a novel attention-based Graph Neural Network (GNN) li2016gated, which we call BrainGNN. We apply it to fMRI data and 1) achieve classification accuracy comparable to existing algorithms, 2) learn dynamic graph functional connectivity, and 3) increase model interpretability by learning which regions from the set of ROIs are relevant for the classification, enabling additional insights into the healthy and disordered brain.
2 Materials and Methods
In this study, we tested our model on data from the Function Biomedical Informatics Research Network (FBIRN) dataset keator2016function, which includes schizophrenia (SZ) patients and healthy controls (HC). Details of the dataset are in the following section.
Resting fMRI data from the phase III FBIRN were analyzed for this project. The dataset has total subjects out of which were selected based on the preprocessing method explained in 2.1.2.
The fMRI data were preprocessed using statistical parametric mapping (SPM12, http://www.fil.ion.ucl.ac.uk/spm/) under the MATLAB 2019 environment. A rigid body motion correction was performed to correct subject head motion, followed by slice-timing correction to account for timing differences in slice acquisition. The fMRI data were subsequently warped into the standard Montreal Neurological Institute (MNI) space using an echo planar imaging (EPI) template and were slightly resampled to mm isotropic voxels. The resampled fMRI images were then smoothed using a Gaussian kernel with a full width at half maximum (FWHM) = mm. After the smoothing, the functional images were temporally filtered by a finite impulse response (FIR) bandpass filter (0.01 Hz-0.15 Hz). Then, for each voxel, six rigid body head motion parameters, white matter (WM) signals, and cerebrospinal fluid (CSF) signals were regressed out using linear regression.
We selected subjects for further analysis FU2021117385 if the subjects have head motion and mm, and with functional data providing near full brain successful normalization fu2019altered.
This resulted in a total of subjects with healthy controls and subjects with schizophrenia. Each subject is represented by , where represent the number of voxels in each dimension and is the number of time points which are . To reduce the effect of noise, we z-score the time series of each voxel independently, i.e. the time series of every voxel is replaced by its z-score. This does not affect the data dimensions.
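The voxel-wise z-scoring step can be sketched as follows. This is a minimal illustration, assuming the data for one subject is held in a NumPy array whose last axis is time; the function name `zscore_voxels`, the array sizes, and the small `eps` guard for constant voxels are hypothetical choices and not taken from the original pipeline.

```python
import numpy as np

def zscore_voxels(data, eps=1e-8):
    """Z-score every voxel's time series independently.

    data: array of shape (X, Y, Z, T), time on the last axis (assumed layout).
    eps:  hypothetical guard so constant voxels map to 0 instead of NaN.
    """
    mean = data.mean(axis=-1, keepdims=True)
    std = data.std(axis=-1, keepdims=True)
    return (data - mean) / (std + eps)

# Toy volume with illustrative sizes only.
rng = np.random.default_rng(0)
fmri = rng.normal(loc=100.0, scale=15.0, size=(4, 5, 3, 50))
z = zscore_voxels(fmri)
```

The output has the same shape as the input, which matches the statement that the step leaves the data dimensions unchanged.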
To partition the data into regions, we use automated anatomical labeling (AAL) TZOURIOMAZOYER2002273, which contains brain regions. Summing the voxels inside a region is an easy and common method, but it gives an unfair advantage to bigger regions. Instead, we take the weighted average of the voxel intensities inside a region, where the weight is the value of a voxel's membership in the region, as these values are not binary. Averaging helps to negate the bias towards bigger regions. This results in a dataset where , , , .
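The weighted averaging over regions can be sketched as below. This is an illustration under stated assumptions: the atlas is taken to be an array of non-negative, non-binary membership weights with one volume per region, and the name `region_time_series`, the array layout, and the `eps` guard are hypothetical.

```python
import numpy as np

def region_time_series(data, atlas_weights, eps=1e-8):
    """Weighted-average time series per region.

    data:          (X, Y, Z, T) voxel time series for one subject.
    atlas_weights: (R, X, Y, Z) non-negative membership weights per region
                   (assumed layout; values need not be binary).
    Returns an (R, T) array.
    """
    n_regions = atlas_weights.shape[0]
    n_time = data.shape[-1]
    w = atlas_weights.reshape(n_regions, -1)   # (R, V)
    v = data.reshape(-1, n_time)               # (V, T)
    weighted_sum = w @ v                       # (R, T)
    total_weight = w.sum(axis=1, keepdims=True)
    # Dividing by the total weight removes the bias towards large regions.
    return weighted_sum / (total_weight + eps)

rng = np.random.default_rng(0)
data = rng.normal(size=(4, 4, 4, 20))
atlas = rng.uniform(size=(3, 4, 4, 4))
ts = region_time_series(data, atlas)
```

With a plain sum, a region covering twice as many voxels would produce roughly twice the signal magnitude; normalizing by the total weight makes regions of different size comparable.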
Our novel attention-based GNN architecture has three distinct parts: 1) a Convolutional Neural Network (CNN) 726791 that creates embeddings for each region, 2) a self-attention mechanism 10.5555/3295222.3295349 that assigns weights between regions for functional connectivity, and 3) a GNN that uses regions (nodes) and edges for graph classification. In this section we explain the purpose and details of each part separately. Refer to Figure 1 for the complete architecture diagram of BrainGNN.
2.2.1 CNN Encoder
We use a CNN KIRANYAZ2021107398 encoder to obtain the representation of the individual regions created in the preprocessing step outlined in 2.1.2. Each region vector of dimension is passed through multiple layers of one-dimensional convolution and a fully connected layer to get the final embedding. The one-dimensional CNN encoder used in our architecture consists of convolution layers with filter size , stride and output channels . This is followed by a fully connected layer resulting in a final embedding of size . We use the rectified linear unit (ReLU) as the activation between convolution layers. Each region is encoded individually so that we can later create connections between regions and interpret which regions are more important or informative for classification. Our one-dimensional CNN layers embed the temporal features of the regions; the spatial connections are handled in the attention and GNN parts of the architecture.
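A per-region encoder of this kind can be sketched in plain NumPy as follows. All sizes here (two convolution layers, filter size 4, stride 2, 8 and 16 channels, embedding size 32, 100 time points) are hypothetical placeholders and not the values used in the paper; convolution biases are omitted for brevity, and the weights are random rather than learned.

```python
import numpy as np

def conv1d(x, kernels, stride):
    """Valid 1-D cross-correlation. x: (C_in, T); kernels: (C_out, C_in, K)."""
    c_out, _, k = kernels.shape
    n_out = (x.shape[1] - k) // stride + 1
    out = np.empty((c_out, n_out))
    for i in range(n_out):
        window = x[:, i * stride:i * stride + k]            # (C_in, K)
        out[:, i] = np.tensordot(kernels, window, axes=([1, 2], [0, 1]))
    return out

def encode_region(series, conv_layers, fc_weight, fc_bias):
    """Encode one region's time series (T,) into a fixed-size embedding."""
    x = series[None, :]                                     # one input channel
    for i, (kernels, stride) in enumerate(conv_layers):
        x = conv1d(x, kernels, stride)
        if i < len(conv_layers) - 1:
            x = np.maximum(x, 0.0)                          # ReLU between conv layers
    return x.reshape(-1) @ fc_weight + fc_bias              # fully connected layer

rng = np.random.default_rng(0)
T, EMB = 100, 32                                            # illustrative sizes
conv_layers = [(rng.normal(scale=0.1, size=(8, 1, 4)), 2),
               (rng.normal(scale=0.1, size=(16, 8, 4)), 2)]
flat_dim = 16 * 23                                          # 100 -> 49 -> 23 time steps
fc_weight = rng.normal(scale=0.1, size=(flat_dim, EMB))
fc_bias = np.zeros(EMB)

regions = rng.normal(size=(5, T))                           # 5 toy regions
embeddings = np.stack([encode_region(r, conv_layers, fc_weight, fc_bias)
                       for r in regions])
```

The same encoder weights are applied to every region independently, so the encoder only sees the temporal axis; relations between regions are left to the later stages.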
2.2.2 Self Attention
Using the embeddings created by the CNN encoder, we estimate the connectivity between the regions of the brain using multi-head self-attention following 10.5555/3295222.3295349. The self-attention model creates three embeddings, namely (key, query, value), for each region, which in our architecture are created using three simple linear layers. Each linear layer is of size . , , . To create weights between a region and every other region, the model takes the dot product of a region's query with every other region's key embedding to get scores between them. Hence, . The scores are then converted to weights using softmax. where is a vector of scores between region and every other region. The weights are then multiplied with the embedding of each region and summed together to create a new representation for a . The following equations show how to get the new region embedding and weight values.
This process is carried out for all the regions, producing a new representation of every region and the weights between regions. These weights are then used as the functional connectivity between different regions of the brain for every subject. The self-attention layer encodes the spatial axis for each subject and provides the connections between regions. The weights are learned via end-to-end training of our model on the classification task. This frees us from using predefined models or functions to estimate the connectivity.
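The attention computation described above can be sketched as follows. This is a single-head simplification of the multi-head mechanism, with randomly initialized projection matrices standing in for the learned linear layers. The names `E`, `Wq`, `Wk`, `Wv` and all sizes are hypothetical, and the division by the square root of the key dimension follows the cited scaled dot-product formulation rather than anything stated explicitly in the text. Each region also attends to itself here, as in standard self-attention.

```python
import numpy as np

def self_attention(E, Wq, Wk, Wv):
    """Single-head self-attention over region embeddings.

    E: (R, d) region embeddings; Wq, Wk, Wv: (d, d_k) projection matrices.
    Returns the new embeddings (R, d_k) and the attention weights (R, R).
    """
    Q, K, V = E @ Wq, E @ Wk, E @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])        # query-key dot products
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True) # row-wise softmax
    return weights @ V, weights

rng = np.random.default_rng(0)
R, D = 6, 16                                      # illustrative sizes
E = rng.normal(size=(R, D))
Wq, Wk, Wv = (rng.normal(scale=0.3, size=(D, D)) for _ in range(3))
new_E, W = self_attention(E, Wq, Wk, Wv)
```

Row `i` of `W` holds the weights that region `i` assigns to all regions. Because the score for the pair (i, j) uses the query of `i` and the key of `j`, `W` is generally not symmetric, which is what makes the resulting connectivity graph directed.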
Our graph network is based on a previously published model li2016gated. Each subject is represented by a graph having where is the matrix of vertices, where each vertex is represented by an embedding acquired by self-attention. are the adjacency and edge weight matrices. Since we do not use any existing method of computing edges, we construct a complete directed graph with backward edges, meaning every pair of vertices is joined by two directed edges with weights and . For each GNN layer, at every step , each node, which is a region in our model, sums the feature vectors of every other region weighted by the edge between the nodes, and passes the result together with its own feature vector through a gated recurrent unit (GRU) network cho-etal-2014-properties to obtain a new embedding for itself.
where can be explained by the following set of equations, with representing the result of the sum in Equation 2:
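A propagation step of this kind can be sketched as below. This is an assumption-laden illustration rather than the exact layer used in the paper: the GRU follows the common update/reset-gate formulation (gate conventions differ between references), the per-message linear transform of the gated graph network is omitted, and all names and sizes are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_gru(d, rng):
    """Random GRU parameters for inputs and states of size d (toy values)."""
    params = {}
    for gate in "zrh":
        params["W" + gate] = rng.normal(scale=0.1, size=(d, d))
        params["U" + gate] = rng.normal(scale=0.1, size=(d, d))
        params["b" + gate] = np.zeros(d)
    return params

def gru_cell(m, h, p):
    """m: aggregated messages (N, d); h: current node states (N, d)."""
    z = sigmoid(m @ p["Wz"] + h @ p["Uz"] + p["bz"])           # update gate
    r = sigmoid(m @ p["Wr"] + h @ p["Ur"] + p["br"])           # reset gate
    h_new = np.tanh(m @ p["Wh"] + (r * h) @ p["Uh"] + p["bh"]) # candidate state
    return (1.0 - z) * h + z * h_new

def propagate(H, A, params, steps):
    """A[i, j] is the weight of the edge from region j into region i."""
    A_off = A * (1.0 - np.eye(A.shape[0]))   # sum over the *other* regions only
    for _ in range(steps):
        messages = A_off @ H                 # weighted sum of neighbour features
        H = gru_cell(messages, H, params)    # combine with the node's own state
    return H

rng = np.random.default_rng(0)
N, D = 6, 8                                  # illustrative sizes
H0 = np.tanh(rng.normal(size=(N, D)))
A = rng.uniform(size=(N, N))
H1 = propagate(H0, A, init_gru(D, rng), steps=2)
```

Because the adjacency matrix is dense and asymmetric, every region receives a message from every other region at each step, with the learned attention weights deciding how much each one contributes.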
The number of steps is a hyper-parameter, which we set to based on our experiments. The graph neural network helps nodes to create new embeddings based on the embeddings of other regions in the graph, weighted by the edge weights between them. In our architecture, we use GNN layers, as the experiments of bresson2017residual show that this provides the highest accuracy, with the first followed by a top-k pooling layer gao2019graph; knyazev2019understanding. On the input feature vectors, which are the embeddings of the regions, the pooling operator learns a parameter () that assigns a weight to the features. Based on this parameter, the top (k) regions are chosen in each pooling layer and the rest of the regions are discarded from further layers. The pooling method can be explained by the following equations.
and are the new features and adjacency matrix we get after selecting the top (k) regions. Pooling is performed to help the model focus on the important regions/nodes that are responsible for classification. The ratio of nodes to keep in the pooling layer is a hyper-parameter, and we have used as the ratios. Since we represent each subject as a graph , in the end we do graph classification by pooling all the feature vectors of the remaining regions/nodes. To get one feature vector for the entire graph, we pass the complete graph into three separate pooling layers, each of which gives us one feature vector, and concatenate the three vectors into one final embedding that represents a subject. In our model we use graph max pool, graph average pool and attention-based pool vinyals2016order. As the names suggest, graph max pool and graph average pool take the element-wise max and average of the node vectors in the graph, whereas attention-based pooling multiplies each vector with a learned attention value before summing all the vectors. The dimension of the resulting vector is . The feature vector is then passed through two linear layers of size and .
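The node selection and graph readout can be sketched as follows. This is a hedged illustration: the scoring and sigmoid gating follow the general top-k pooling recipe of the cited work, the attention readout is a simple softmax-weighted sum rather than the full method of vinyals2016order, and the projection vector `p`, the attention vector, the keep ratio 0.5, and all sizes are hypothetical.

```python
import numpy as np

def top_k_pool(X, A, p, ratio):
    """Keep the highest-scoring fraction of nodes.

    X: (N, d) node features; A: (N, N) edge weights; p: (d,) learned projection.
    Returns pooled features, pooled adjacency, and the kept node indices.
    """
    scores = X @ p / np.linalg.norm(p)
    k = int(np.ceil(ratio * X.shape[0]))
    idx = np.argsort(-scores)[:k]                 # indices of the top-k nodes
    gate = 1.0 / (1.0 + np.exp(-scores[idx]))     # lets gradients reach p
    return X[idx] * gate[:, None], A[np.ix_(idx, idx)], idx

def readout(X, att_vector):
    """Concatenate max, mean, and attention-weighted sum of node features."""
    logits = X @ att_vector
    att = np.exp(logits - logits.max())
    att /= att.sum()
    return np.concatenate([X.max(axis=0), X.mean(axis=0), att @ X])

rng = np.random.default_rng(0)
N, D = 10, 8                                      # illustrative sizes
X = rng.normal(size=(N, D))
A = rng.uniform(size=(N, N))
p = rng.normal(size=D)
X_pool, A_pool, kept = top_k_pool(X, A, p, ratio=0.5)
graph_vector = readout(X_pool, rng.normal(size=D))
```

The returned `kept` indices are what make the model inspectable: they identify which regions survived pooling and therefore contributed to the final graph embedding, whose length is three times the node feature size.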
2.2.4 Training and Testing
To train, validate and test our model, we divide the total subjects into three groups of size , and , for training, validation and testing respectively. To conduct a fair experiment we use fold cross-validation, and for each fold we perform trials, resulting in a total of trials, selecting subjects per class for each trial. We calculate the area under the receiver operating characteristic (ROC) curve (AUC) for each trial. We train the entire architecture end to end, using cross-entropy loss with the true labels as targets and Adam as the optimizer, and we reduce the learning rate on plateau with a patience of 10. We stop training early based on the validation loss, with a patience of 15. Let represent the parameters of the entire architecture.
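Two pieces of this procedure, the AUC metric and patience-based early stopping, can be sketched as below. The AUC is computed through its pairwise-ranking interpretation (the probability that a randomly chosen positive subject scores higher than a randomly chosen negative one, counting ties as one half). The helper names and the toy numbers are hypothetical, and the exact stopping rule used in the paper may differ in detail.

```python
import numpy as np

def roc_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC needs at least one sample from each class")
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

def early_stop_epoch(val_losses, patience):
    """Return the epoch at which training stops, or None if it never does.

    Training stops once the validation loss has failed to improve on its
    best value for `patience` consecutive epochs.
    """
    best, waited = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None

auc = roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
stop = early_stop_epoch([1.0, 0.9, 0.95, 0.96, 0.97], patience=2)
```

In the toy call, three of the four positive-negative pairs are ordered correctly, so `auc` is 0.75; the validation loss stops improving after epoch 1, so with a patience of 2 training halts at epoch 3.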
We show three different groups of results in our study: 1) the classification results, 2) the connectivity between regions and 3) the selection of key regions. We discuss these in the following sections. We test and compare our model against classical machine learning algorithms and 10.1007/978-3-030-59728-3_40 on the same data used for BrainGNN. The input for the machine learning models is sFNC matrices produced using Pearson product-moment correlation coefficients (PCC).
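For reference, a correlation-based connectivity matrix of this kind can be computed as in the sketch below, assuming each subject is given as a region-by-time array. The function name `sfnc_features` and the use of the upper triangle as the feature vector for a classical classifier are illustrative assumptions.

```python
import numpy as np

def sfnc_features(region_ts):
    """Pearson correlation between all pairs of region time series.

    region_ts: (R, T) array, one row per region.
    Returns the (R, R) correlation matrix and its upper triangle as a
    flat feature vector of length R * (R - 1) / 2.
    """
    fnc = np.corrcoef(region_ts)              # rows are treated as variables
    upper = np.triu_indices_from(fnc, k=1)    # skip the diagonal of ones
    return fnc, fnc[upper]

rng = np.random.default_rng(0)
region_ts = rng.normal(size=(6, 40))          # 6 toy regions, 40 time points
fnc, features = sfnc_features(region_ts)
```

The correlation matrix is symmetric by construction, so the graph it defines is undirected and only the upper triangle carries independent information.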
As mentioned, we use the AUC metric to quantify the classification results of our model. AUC is more informative than simple accuracy for binary classification, as in our case. Figure 2 shows the results for our model. Figure 3 shows the ROC curves of the models for each fold. The performance is comparable to state-of-the-art classical machine learning algorithms using hand-crafted features and to existing deep learning approaches such as 10.1007/978-3-030-59728-3_40, which was tested on independent component analysis (ICA) components with a hold-out dataset. The comparison with other machine learning and deep learning approaches is shown in Figure 4 and supports our claim that BrainGNN provides state-of-the-art results. BrainGNN gives almost the same mean AUC as the best-performing model, i.e. the support vector machine (SVM). To the best of our knowledge, these results are currently among the best on the unmodified FBIRN fMRI dataset RASHID2016645; Saha2020.06.24.161745; SALMAN2019101747. Table 1 shows the mean AUC for each cross-validation fold used in the BrainGNN experiments. As the table shows, the AUC has high variance across the different test sets of cross-validation. To make more sense of the functional connectivity and region selection, both results are based on the second test fold, which gives the highest () AUC score.
3.2 Functional Connectivity
The functional connectivity between regions of the brain is crucial for understanding how different parts of the brain interact with each other. We use the weights assigned by the self-attention module of our architecture as the connections between regions. Figure 5 shows weight matrices for the second test set in cross-validation. Weight matrices of subjects belonging to the SZ class turn out to be much sparser than those of healthy control subjects. This shows that functional connectivity differs across classes and that, for SZ subjects, connectivity is limited to fewer regions, with fewer regions receiving high weights. We also perform statistical testing to confirm that the weight matrices of HC differ from those of SZ subjects. We create two sets, each the concatenation of the weights of the test subjects belonging to one class, and apply the tests shown in Table 2. The p-value of shows that we can reject the null hypothesis, indicating that the weights of HC and SZ subjects differ. FNC matrices produced using the PCC method do not provide this level of information, and almost all regions get unit weight with other regions. Figure 5 shows the usefulness of learning connectivity between regions in an end-to-end manner while training the model for classification.
3.3 Region Selection
The pooling layer added in our GNN module allows us to reduce the number of regions. Functionality differs significantly across brain regions, and not all regions are affected by a disorder or have a noticeable effect on classification. This makes it very important to know which regions are more informative of the underlying disorder and to study how they affect, or are affected by, the disorder. Figure (a) shows the final regions selected after the last pooling layer in the GNN model, which is just percent of the total brain regions used. The relevance of these regions is underlined by the fact that the graph model has no residual connections, so the final feature vector created after the last GNN layer is built only from the feature vectors of these regions. Figure (b) shows the location of the selected regions in the MNI brain space, where regions are distinguished by color. Each region is assigned one unit from the color bar, which is used to represent signal variation in the fMRI data.
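Table 2: Statistical test on the attention weights of HC and SZ test subjects.
| Test | p-value |
| Mann-Whitney U Test | 0.0 |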
[Figure panels: BrainGNN (Directed); sFNC (Undirected)]
The richness of results in the three presented categories highlights the benefits of the proposed method. High classification performance shows that the model can accurately classify the subjects and hence can be trusted with the other two interpretative results of the paper. The functional connectivity between regions shown in the paper is of paramount importance, as it highlights how brain regions are connected to each other and how the connections vary between classes. Learning functional connectivity end-to-end through classification training frees the model from depending on an external method. The sparse weight matrices of subjects with SZ show that connectivity remains significant between considerably fewer regions than for healthy controls. Notably, the attention-based functional connectivity cannot be interpreted as the conventional correlation-based symmetric connectivity. Due to the inherent asymmetry between queries and keys, the obtained graph is directed, and it is also prediction based rather than simply correlation based. We expect that further investigation into the obtained graph structure will bring more results and deeper interpretations. The sparsity is to be further explored and seen in the context of the regions selected, shown in the last section of the results. The final regions selected by the model strengthen our hypothesis that not all regions are equally important for identifying a particular brain disorder. Reducing the brain regions by almost helps in identifying the important regions for classification of SZ. The regions selected by our model, such as the cerebellum, temporal lobe, caudate and SMA, have been linked to the disease by multiple previous studies, which supports the correctness of our model article; https://doi.org/10.1002/hbm.25205; article2; article3. We see an immediate benefit of using GNNs, and our BrainGNN model specifically, to study functional connectivity.
The data-driven model almost eliminates manual decisions, moving graph construction and region selection into the data-driven realm. With this, BrainGNN opens up a new direction for existing studies of connectivity, and we expect further model introspection to yield insight into the spatio-temporal biomarkers of schizophrenia. Further reducing the selected regions, and studying how they differ across subjects belonging to different classes, is left for future work. We envision great benefits to interpretability and the elimination of manual processing and decisions in a future extension of the model that would enable it to work directly from the voxel level, not only connecting and selecting ROIs but also constructing them.
Conceptualization, U.M., S.P.; methodology, U.M.; software, U.M.; validation, U.M.; formal analysis, U.M.; investigation, U.M.; resources, V.C.; data curation, Z.F., U.M.; writing–original draft preparation, U.M.; writing–review and editing, U.M., S.P., Z.F.; visualization, U.M.; supervision, S.P.; project administration, S.P., V.C.; funding acquisition, S.P., V.C. All authors have read and agreed to the published version of the manuscript.
This study was in part supported by NIH grants 2RF1MH121885, 1R01AG063153, and 2R01EB006841.