A deep learning model for data-driven discovery of functional connectivity

by   Usman Mahmood, et al.
Georgia State University

Functional connectivity (FC) studies have demonstrated the overarching value of studying the brain and its disorders through the undirected weighted graph of fMRI correlation matrix. Most of the work with the FC, however, depends on the way the connectivity is computed, and further depends on the manual post-hoc analysis of the FC matrices. In this work we propose a deep learning architecture BrainGNN that learns the connectivity structure as part of learning to classify subjects. It simultaneously applies a graphical neural network to this learned graph and learns to select a sparse subset of brain regions important to the prediction task. We demonstrate the model's state-of-the-art classification performance on a schizophrenia fMRI dataset and demonstrate how introspection leads to disorder relevant findings. The graphs learned by the model exhibit strong class discrimination and the sparse subset of relevant regions are consistent with the schizophrenia literature.



There are no comments yet.


page 9

page 10


Functional connectivity patterns of autism spectrum disorder identified by deep feature learning

Autism spectrum disorder (ASD) is regarded as a brain disease with globa...

Brain Age Prediction Based on Resting-State Functional Connectivity Patterns Using Convolutional Neural Networks

Brain age prediction based on neuroimaging data could help characterize ...

Weighted Ensemble-model and Network Analysis: A method to predict fluid intelligence via naturalistic functional connectivity

Objectives: Functional connectivity triggered by naturalistic stimulus (...

Architectural configurations, atlas granularity and functional connectivity with diagnostic value in Autism Spectrum Disorder

Currently, the diagnosis of Autism Spectrum Disorder (ASD) is dependent ...

Ordinal Pattern Kernel for Brain Connectivity Network Classification

Brain connectivity networks, which characterize the functional or struct...

Encoding Multi-Resolution Brain Networks Using Unsupervised Deep Learning

The main goal of this study is to extract a set of brain networks in mul...

Improving the Level of Autism Discrimination through GraphRNN Link Prediction

Dataset is the key of deep learning in Autism disease research. However,...
This week in AI

Get the week's most popular data science and artificial intelligence research sent straight to your inbox every Saturday.

1 Introduction

Functional connectivity which is often computed using cross-correlation among brain regions of interest (ROIs) is a powerful approach which has been shown to be informative for classifying brain disorders and revealing putative bio-markers relevant to the underlying disorder 10.1093/brain/awn018; Lynall9477; https://doi.org/10.1002/hbm.23524; 10.1007/978-3-030-59728-3_52. Inferring and using functional connectivity through spatio-temporal data, e.g. functional magnetic resonance imaging (fMRI), has been an especially important area of research in recent times. Functional connectivity can improve our understanding of brain dynamics and improve classification accuracy for brain disorders such as schizophrenia. Recent work (yan2017discriminating) uses functional network connectivity (FNC) as features to predict schizophrenia related changes. Whereas, PARISOT2018117 uses functional connectivity obtained by a fixed formula with phenotypic and imaging data as inputs and to extract graphic features for the classification of AD and Autism. BrainNetCNN also uses connection strength between brain regions as edges, typically defined as the number of white-matter tracts connecting the regions. 10.1007/978-3-319-66182-7_54 employs spectral graph theory to learn similarity metrics among functional connectivity networks.

These papers, as well as many others, have shown the efficacy of functional connectivity and feature extraction based on neural network models. However, existing studies often heavily depend on the underlying method of functional connectivity estimation, in terms of classification accuracy, feature extraction, or learning brain dynamics. Studies like

RASHID2016645; Saha2020.06.24.161745; SALMAN2019101747

depend on hand-crafted features based on methods like ICA (Independent Component Analysis). These studies work very well on classification but do not learn a sparse graph and not helpful for identifying bio-markers in the brain.

Many functional connectivity studies 10.3389/fnins.2018.00525 on brain disorders utilize ROIs predefined based on anatomical or functional atlases, which are either fixed for all subjects or based are based on group differences.

These approaches ignore the possibility of inter-subject variations of ROIs, especially the variations due to the underlying disease conditions. They also rely on the complete set of these ROIs discounting the possibility that only a small subset may be important at a time. A disorder can have varying symptoms for different people, hence making it crucial to determine disorder and subject specific ROIs.

In this work, we address the problems of using a fixed method of learning functional connectivity and using it as a fixed graph to represent brain structure (the standard practices) by utilizing a novel attention based Graph Neural Network (GNN)  li2016gated, which we call BrainGNN. We apply it to fMRI data and 1) achieve comparable classification accuracy to existing algorithms, 2) learn dynamic graph functional connectivity, and 3) increase model interpretability by learning which regions from the set of ROIs are relevant for the classification, enabling additional insights into the health and disordered brain.

2 Materials and Methods

2.1 Materials

In this study, we worked with the data from Function Biomedical Informatics Research Network (FBIRN) keator2016function dataset including schizophrenia (SZ) patients and healthy controls (HC) for testing our model. Details of the dataset are in the following section.

2.1.1 Fbirn

Resting fMRI data from the phase III FBIRN were analyzed for this project. The dataset has total subjects out of which were selected based on the preprocessing method explained in 2.1.2.

2.1.2 Preprocessing

The fMRI data was preprocessed using statistical parametric mapping (SPM12, http://www.fil.ion.ucl.ac.uk/spm/) under the MATLAB 2019 environment. A rigid body motion correction was performed to correct subject head motion, followed by the slice-timing correction to account for timing difference in slice acquisition. The fMRI data were subsequently warped into the standard Montreal Neurological Institute (MNI) space using an echo planar imaging (EPI) template and were slightly resampled to mm isotropic voxels. The resampled fMRI images were then smoothed using a Gaussian kernel with a full width at half maximum (FWHM) =

mm. After the smoothing, the functional images were temporally filtered by a finite impulse response (FIR) bandpass filter (0.01 Hz-0.15 Hz). Then for each voxel, six rigid body head motion parameters, white matter (WM) signals, and cerebrospinal fluid (CSF) signals were regressed out using linear regression.

We selected subjects for further analysis FU2021117385 if the subjects have head motion and mm, and with functional data providing near full brain successful normalization fu2019altered.

This resulted in a total of subjects with healthy controls and subjects with schizophrenia. Each subject is represented by , where represent the number of voxels in each dimension and is the number of time points which are

. To reduce the affect of noise we zscore the time sequence of each voxel independently. Thus, time series of every voxel is replaced by the z-score of the time series. This does not have any affect on the data dimensions.

To partition the data into regions use automated anatomical labeling (AAL) TZOURIOMAZOYER2002273 which contains brain regions. Taking sum of the voxels inside a region is an easy and common method but this gives and unfair advantage to bigger regions. For this, we take the weighted average of the voxel intensities inside a region. Weight is the value of a voxel being inside a region, as these values are not binary. Averaging helps to negate the bias towards bigger regions. This results in a dataset where , , , .

2.2 Method

We have three distinct parts in our novel attention based GNN architecture: 1) a Convolutional Neural Network (CNN)

726791 that creates embeddings for each region, 2) a Self-Attention mechanism 10.5555/3295222.3295349 that assigns weights between regions for functional connectivity and 3) A GNN that uses regions (nodes) and edges for graph classification. In this section we explain the purpose and details of each part separately. Refer to Figure 1 for the complete architecture diagram of BrainGNN.

2.2.1 CNN Encoder

We use a CNN KIRANYAZ2021107398 encoder to obtain the representation of individual regions created in the preprocessing step outlined in 2.1.2

. Each region vector of dimension

is passed through multiple layers of one dimensional convolution, and a fully connected layer to get final embedding. The one dimensional CNN encoder used in our architecture consists of convolution layers with filter size

, stride

and output channels . This is followed by a fully connected layer resulting in a final embedding of size

. We use rectified linear unit (ReLU) as an activation layer between convolution layers. Each region is encoded individually to later on create connections between regions and interpret which regions are more important/informative for classification. Our one dimensional CNN layer embeds the temporal features of regions and the spatial connections are handled in the attention and GNN parts of the architecture.

Figure 1: BrainGNN architecture using a) Preprocessing: To preprocess the raw data with different steps (2.1.2). b) 1DCNN: To create embedding for regions (2.2.1). c) Self-attention: To create connectivity between regions (2.2.2) d) GNN: To obtain a single feature vector for the entire graph (2.2.3) and e) Linear classifier: To obtain the final classification.

2.2.2 Self Attention

Using the embeddings created by the CNN encoder, we estimate the connectivity between the regions of the brain using multi-head self-attention following 10.5555/3295222.3295349

. The self-attention model creates three embeddings namely (key, query, value) for each region, which in our architecture are created using three simple linear layers. Each linear layer

is of size . , , . To create weights between a region and every other region, the model takes dot product of a region’s query with every other region’s key embedding to get scores between them. Hence, . The scores are then converted to weights using softmax. where is a vector of scores between region and every other region. The weights are then multiplied with the embedding of each region and summed together to create new representation for a . Following equations show how to get new region embedding and weight values.


This process is carried out for all the regions, producing new representation of every region and the weights between regions. These weights are then used as the functional connectivity between different regions of brain for every subject. The self attention layer encodes the spatial axis for each subject and provides with the connection between regions. The weights are learned via end to end learning of our model performing classification. This frees us from using predefined models or functions to estimate the connectivity.

2.2.3 Gnn

Our graph network is based on a previously published model li2016gated. Each subject is represented by a graph having where is the matrix of vertices, where each vertex is represented by an embedding acquired by self-attention. are the adjacency and edge weight matrices. Since we do not use any existing method of computing edges, we construct a complete directed graph with backward edges, meaning every pair of vertices is joined by two directed edges with weights and . For each GNN layer, at every step

, each node, which is a region in our model sums feature vectors of every other region relative to the weight edge between the nodes and pass the resultant and it’s own feature vector through a gated recurrent unit (GRU) network

cho-etal-2014-properties, to obtain new embedding for itself.


where can be explained by following set of equations, with representing the result of sum in Equation 2:


The number of steps is a hyper-parameter which we have set it as based on our experiments. The graph neural network helps nodes to create new embeddings based on the embeddings of other regions in the graph weighted by the edge weights between them. In our architecture, we use GNN layers, as shown in experiments of bresson2017residual that it provides with the highest accuracy, with the first followed by a top-k pooling layer gao2019graph; knyazev2019understanding. On the input feature vectors which are the embeddings of the regions, the pooling operator learns a parameter () which is to assign weight to the features. Based on this parameter, top (k) layers are chosen in each pooling layer and the rest of the regions are discarded from further layers. The pooling method can be explained by the following equations.


and are the new features and adjacency matrix we get after selecting top (k) regions. Pooling is performed to help model focus on the important regions/nodes which are responsible for classification. The ratio of nodes to keep in the pooling layer is a hyper-parameter and we have used as the ratios. Since we represent each subject as graph , in the end we do graph classification by pooling all the feature vectors of the remaining

regions/nodes. To get one feature vector from the entire graph we concatenate the output of three different pooling layers. We pass the complete graph into three separate pooling layers. Each of the pooling layer gives us one feature factor. In the end, we concatenate the three vectors to get one final embedding for the entire graph which represents a subject. In our model we use graph max pool, graph average pool and attention based pool

vinyals2016order. The dimension of the resulting vector is . The feature vector is then passed through two linear layers of size and . As the name suggests, graph max pool and graph average pool just gets the max and average vector from the graph whereas attention based pooling multiplies each vector with a learned attention value before summing all the vectors.

2.2.4 Training and Testing

To train, validate and test our model we divide the total subjects into three groups of size , and , for training, validating and testing respectively. To conduct a fair experiment we use fold cross validation and for each fold we perform trials, resulting in a total of trials, and selecting

subjects per class for each trial. We calculate the area under the ROC (receiver operating characteristic) curve (AUC) for each trial. To optimize our model we train all of our architecture in an end to end fashion, using Cross Entropy to calculate our loss by giving true labels

as targets, Adam as our optimizer and reducing learning rate on plateau with patience 10. We early stop our model based on validation loss, with patience of 15. Let represent the parameters of the entire architecture.


3 Results

We show three different groups of results in our study. 1) The classification results, 2) Regions’ connectivity and 3) Key regions selection. We discuss these in the following sections. We test and compare our model against the classical machine learning algorithms and  


on the same data used in BrainGNN. The input for the machine learning model is sFNC matrices produced using Pearson product-moment correlation coefficients (PCC).

3.1 Classification

As mentioned, we use the AUC metric to quantify the classification results of our model. AUC is more informative than simple accuracy for binary classification as in our case. Figure 2 shows the results for our model. Figure 3 shows the ROC curves of the models for each fold. The performance is comparable to state of the art classical machine learning algorithms using hand crafted features and existing deep learning approaches such as  10.1007/978-3-030-59728-3_40, which performed test on independent component analysis (ICA) components with a hold out dataset. Comparison with other machine and deep learning approaches is shown in Figure 4

and prove our claim of BrainGNN providing state of the art results. BrainGNN gives almost the same mean AUC as the best performing model i.e. SVM (Support Vector Machine). To the best of our knowledge, these results are currently among the best on the unmodified FBIRN fMRI dataset

RASHID2016645; Saha2020.06.24.161745; SALMAN2019101747. Table 1

shows the mean AUC for each cross validation fold that was used for experimentation for BrainGNN. As it is shown in the table that AUC has high variance across the different test sets of cross validation. To make more sense out of the functional connectivity and region selection, both results are based on the second test fold which gives the highest (

) AUC score.

3.2 Functional Connectivity

The functional connectivity between regions of the brain is crucial for understanding how different parts of brain are interacting with each other. We use the weights assigned by the self-attention module of our architecture as the connection between regions. Figure 5 shows weight matrices for the second test set in cross validation. Weight matrices of subjects belonging to SZ class turn out to be much sparser than weights of healthy controls subjects. The result shows that the connectivity is limited to fewer regions, and functional connectivity differs across classes and fewer regions get higher weights in case of SZ subjects. We also perform statistical testing to confirm that the weight matrices of HC differ from those of SZ subjects. We create two sets each representing the concatenation of the weights of test subjects belonging to a class. We perform different testing, shown in Table 2. P-value of

shows that we can reject the null-hypothesis, hence making it highly likely that the difference between weights of HC and SZ subjects is not zero. FNC matrices produced using PCC method, do not provide such level of information and almost all regions get unit weight between other regions.

5 shows the usefulness of learning connectivity between regions in an end-to-end manner while training the model for classification.

Figure 2:

KDE plot of probability density of ROC-AUC score on FBIRN dataset. The 190 points on the x-axis signifies the 19 fold cross validation, 10 trials per cross validation. With average and median of ((

)), density peaks at () AUC.
CV Fold 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
AUC 0.695 0.955 0.644 0.752 0.908 0.917 0.894 0.803 0.649 0.805 0.922 0.699 0.625 0.780 0.794 0.766 0.914 0.750 0.777
Table 1: Showing mean AUC of 10 trials for each cv fold
Figure 3: Shows the ROC curves of the 19 models generated using each fold of cross validation. The graph is symmetrical and well balanced. It shows that the model did not learn one class over the other.
Figure 4: BrainGNN comparision with other popular methods. BrainGNN provides mean AUC as , which is just (

) less than the best performing model (SVM). Methods like WholeMILC (UFPT) and l1 logistic regression failed to learn on the input data. The l1 logistic regression model does perform better with a very weak regularization term.

3.3 Region Selection

The pooling layer added in our GNN module allows us to reduce the number of regions. Functionality across brain regions differ significantly and not all regions are affected by a disorder or have any noticeable affect on classification. This makes it very important to know which regions are more significantly informative of the underlying disorder and study how they get affected or affect the disorder. Figure (a)a shows the final regions selected after the last pooling layer in the GNN model which is just

percent of the total brain regions used. The relevance of these regions is further signified by the fact that the graph model has no residual connections and the final feature vector created after the last GNN layer is through the feature vectors of these regions. Figure

(b)b shows the location of the selected regions in the MNI brain space, regions are distinguished by color. Each region is assigned one unit from the color bar, used to represent signal variation in the fMRI data.

Test P Value
Mann-Whitney U Test 0.0

Welch’s t-test

Table 2: Statistical testing between weight matrices of HC and SZ. The test shows that weights of regions differ across HC and SZ subjects. Refer to Figure 4 for mean and deviation of these folds.
BrainGNN (Directed) sFNC (Undirected)
Figure 5: Connectivity between regions of subjects of both classes using BrainGNN and sFNC (PCC method). BrainGNN: The similarity of connection between a class and difference across class is compelling. Weights of SZ class are more sparse than HC, highlighting the fact that fewer regions receive higher weights for subjects with SZ. Refer to Table 2 for results of statistical testing between weights of HC and SZ subjects. sFNC: The matrices are symmetric but are less informative than those produced by BrainGNN. Most of the regions are assigend unit weight.
(a) Region frequency
(b) Region maps
Figure 8: (a)a: Histogram of regions selected after the last pooling layer of GNN. fold of the cross validation gives this figure. All regions are selection equal number of times (). It further signifies the important of these regions, showing that for all subjects across both classes, these regions are always selection. (b)b: Mapping the regions back on the brain across the three anatomical planes. time point is selected for these brain scans. X axis shows different slices of the plane.

4 Discussion

The richness of results in the three presented categories highlights the benefits of the proposed method. High classification performance shows that the model can accurately classify the subjects and hence can be trusted with the other two interpretative results of the paper. Functional connectivity between regions shown in the paper is of paramount importance as it highlights how brain regions are connected to each other and the variation between classes. Learning functional connectivity end-to-end through classification training frees the model from depending on an external method. The sparse weight matrix of subjects with SZ shows that connectivity remains significant between considerably fewer regions than for healthy controls. Notably, the attention based functional connectivity cannot be interpreted as the conventional correlation based symmetric connectivity. Due to the inherent asymmetry in keys and values the obtained graph is directed but is also prediction based rather than simply correlation. We expect that a further investigation into the obtained graph structure will bring more results and deeper interpretations. The sparsity is to be further explored and seen in context of the regions selected, shown in the last section of results. The final regions selected by the model strengthens our hypotheses that not all regions are equally important for identifying a particular brain disorder. Reducing the brain regions by almost helps in identifying the important regions for classification of SZ. The regions selected by our model such as (cerebellum, temporal lobe, caudate, SMA) etc have been linked to the disease by multiple previous studies, hence reassuring the correctness of our model article; https://doi.org/10.1002/hbm.25205; article2; article3. We see an immediate benefit of using GNNs to study functional connectivity and our BrainGNN model specifically. The data-driven model almost eliminates manual decisions transitioning graph construction and region selection into the data-driven realm. With this BrainGNN opens up a new direction to the existing studies of connectivity and we expect further model introspection to yield insight into the spatio-temporal biomarkers of schizophrenia. Further reducing the selected regions and how they different across subjects belonging to different class is also left for future work. We envision great benefits to interpretability and elimination of manual processing and decisions in a future extension of the model that would enable it to work directly from the voxel-level not only connecting and selecting ROIs, but also constructing them.

“Conceptualization, U.M., S.P.; methodology, U.M.; software, U.M.; validation, U.M.; formal analysis, U.M.; investigation, U.M.; resources, V.C.; data curation, Z.F, U.M.; writing–original draft preparation, U.M.; writing–review and editing, U.M., S.P., Z.F.; visualization, U.M.; supervision, S.P.; project administration, S.P., V.C.; funding acquisition, S.P., V.C. All authors have read and agreed to the published version of the manuscript.”, please turn to the CRediT taxonomy for the term explanation.

This study was in part supported by NIH grants 2RF1MH121885, 1R01AG063153, and 2R01EB006841.

Data for Schizophrenia classification was used in this study were downloaded from the Function BIRN Data Repository (http://bdr.birncommunity.org:8080/BDR/), supported by grants to the Function BIRN (U24-RR021992) Testbed funded by the National Center for Research Resources at the National Institutes of Health, U.S.A. and from the COllaborative Informatics and Neuroimaging Suite Data Exchange tool (COINS; http://coins.trendscenter.org) and data collection was performed at the Mind Research Network, and funded by a Center of Biomedical Research Excellence (COBRE) grant 5P20RR021938/P20GM103472 from the NIH to Dr.Vince Calhoun. The authors declare no conflict of interest. References yes