Graph Embedding Using Infomax for ASD Classification and Brain Functional Difference Detection

08/09/2019 ∙ by Xiaoxiao Li, et al.

Significant progress has been made using fMRI to characterize the brain changes that occur in ASD, a complex neurodevelopmental disorder. However, due to the high dimensionality and low signal-to-noise ratio of fMRI, embedding informative and robust brain regional fMRI representations for both graph-level classification and region-level functional difference detection between ASD and healthy control (HC) groups is difficult. Here, we model the whole-brain fMRI as a graph, which preserves geometrical and temporal information, and use a Graph Neural Network (GNN) to learn from the graph-structured fMRI data. We investigate the potential of including a mutual information (MI) loss (Infomax), an unsupervised term encouraging large MI between each nodal representation and its corresponding graph-level summary representation, to learn a better graph embedding. Specifically, we developed a pipeline including a GNN encoder, a classifier and a discriminator, which forces the encoded nodal representations both to benefit classification and to reveal the common nodal patterns in a graph. We simultaneously optimize the graph-level classification loss and the Infomax loss. We demonstrate that the Infomax term, acting as a regularizer, improves classification performance. Furthermore, we found separable nodal representations of the ASD and HC groups in the prefrontal cortex, cingulate cortex, visual regions, and other social-, emotional- and execution-related brain regions. In contrast with a GNN trained with classification loss only, the proposed pipeline facilitates training more robust ASD classification models. Moreover, the separable nodal representations can detect the functional differences between the two groups and contribute to revealing new ASD biomarkers.




1 Introduction

Autism spectrum disorder (ASD) affects the structure and function of the brain. Functional magnetic resonance imaging (fMRI) produces 4D spatio-temporal data describing functional activation, but with very low signal-to-noise ratio (SNR). It can be used to characterize neural pathways and the brain changes that occur in ASD. However, the high dimensionality and low SNR make fMRI difficult to analyze. Here, we address the problem of embedding informative fMRI representations for identifying ASD and detecting brain functional differences between ASD and healthy control (HC) groups. To exploit the spatio-temporal information in fMRI, we represent the whole-brain fMRI as a graph: each brain region of interest (ROI) is a node, edges are weighted by the fMRI correlation matrix, and node features are predetermined, hence preserving both geometrical and temporal information. The Graph Neural Network (GNN), a deep learning architecture for analyzing graph-structured data, has been used for ASD classification [7].

In addition to improving ASD classification, one core objective of our work is to discover representations useful for detecting regional brain differences between ASD and HC. The simple idea explored here is to train a representation-learning function tied to the end-goal task: it maximizes the mutual information (MI) between nodal representations and the graph-level representation while minimizing the loss of the end-goal task. MI is notoriously difficult to compute, particularly in continuous, high-dimensional settings. Fortunately, the recently proposed MINE [1] enables effective computation of MI between high-dimensional input/output pairs of a deep neural network, by training a statistics network as a classifier that distinguishes samples drawn from the joint distribution of two random variables from samples drawn from the product of their marginals. During training of the GNN, we simultaneously optimize the classification loss and the Infomax loss [10], which maximizes the MI between local and global representations. In this way, we tune the suitability of the learned representations for both classification and the detection of group-level regional functional differences. Results show improved classification and reveal functional differences between ASD and HC via the separable embedded brain regions encoded by the GNN.

2 Methodology

2.1 Data Definition and Notations

Suppose each brain is parcellated into N ROIs based on its T1 structural MRI. We define an undirected graph G = (V, E) on the brain regions, where V = {v_1, …, v_N} and E ⊆ V × V. Node attributes X ∈ R^{N×F}, where F is the attribute dimension, concatenate handcrafted features: degree of connectivity, General Linear Model (GLM) coefficients, mean and standard deviation of the task-fMRI signal, and ROI center coordinates.

The weighted adjacency matrix A ∈ R^{N×N} is calculated from the correlation of the mean fMRI time series in each ROI. A graph convolutional encoder (Section 2.2) maps the input graph to a feature map H ∈ R^{N×D} that reflects useful local structure. Next, we summarize the node representations into a global feature vector s by pooling and readout (Section 2.3). Given a graph G, we also generate a negative graph G', whose embedded node representations are H'. The corresponding positive pairs (h_i, s) and negative pairs (h'_j, s) are encouraged to be separated by a discriminator (Section 2.4).
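As a concrete illustration of the graph construction described above, the following is a minimal NumPy sketch (not the authors' code) that derives the weighted adjacency matrix from the correlation of ROI-mean time series; the ROI count and time-series length are arbitrary toy values.

```python
import numpy as np

def build_brain_graph(roi_timeseries):
    # roi_timeseries: (N, T) array, one mean fMRI time series per ROI.
    # Edge weights are pairwise Pearson correlations; self-loops removed.
    adj = np.corrcoef(roi_timeseries)
    np.fill_diagonal(adj, 0.0)
    return adj

rng = np.random.default_rng(0)
ts = rng.standard_normal((4, 50))   # toy data: 4 ROIs, 50 time points
A = build_brain_graph(ts)
```

The resulting matrix is symmetric with entries in [-1, 1], matching the undirected weighted graph assumed in the rest of the section.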

Figure 1: The flowchart of our proposed ASD classification and graph embedding architecture. The top row of the flowchart is a Graph Neural Network architecture that classifies ASD vs. HC. The bottom row is a graph Infomax pipeline that encourages better graph embedding. Here, (a) and (b) are positive samples; (c) and (d) are negative samples. (a)(c) (or (b)(d)) form a paired graph. The inputs of the discriminator D are the summary vector s generated from the positive samples, paired with an embedded node representation (h_i or h'_j). A positive pair (h_i, s) yields a True (T) output from D, whereas a negative pair (h'_j, s) yields a False (F) output. The encoder, classifier and discriminator are trained simultaneously.

2.2 Encoder: Graph Convolutional Layer

Our encoder node embedding network is a K-layer supervised GraphSAGE architecture [3], which learns an embedding function mapping input node features X to outputs H. The embedding function is based on the mean-pooling (MP) propagation rule used in Hamilton et al. [3]. Let A + I be the adjacency matrix with inserted self-loops and D its corresponding diagonal degree matrix with D_ii = Σ_j (A + I)_ij. Our encoder can be written as:

H = σ(D^{-1}(A + I) X W),

where W is a learnable projection matrix and σ is the sigmoid function.
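The mean-pooling propagation rule above can be sketched in a few lines of NumPy. This is an illustrative single layer, not the authors' implementation: the weight matrix here is randomly initialized rather than learned, and the toy graph is arbitrary.

```python
import numpy as np

def sage_mean_pool_layer(X, A, W):
    # H = sigmoid(D^{-1} (A + I) X W): average each node's neighborhood
    # (with a self-loop inserted), project with W, squash with a sigmoid.
    A_hat = A + np.eye(A.shape[0])              # insert self-loops
    D_inv = np.diag(1.0 / A_hat.sum(axis=1))    # inverse degree matrix
    return 1.0 / (1.0 + np.exp(-(D_inv @ A_hat @ X @ W)))

rng = np.random.default_rng(0)
N, F_in, F_out = 5, 3, 2
X = rng.standard_normal((N, F_in))              # node attributes
U = (rng.random((N, N)) > 0.5).astype(float)
A = np.triu(U, 1) + np.triu(U, 1).T             # symmetric, zero diagonal
W = rng.standard_normal((F_in, F_out))          # random stand-in weights
H = sage_mean_pool_layer(X, A, W)
```

Stacking K such layers, each with its own W, gives the K-layer encoder described in the text.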

2.3 Classifier: Pooling and Readout Layer

To aggregate the information of each node for graph-level classification, we use Dense Hierarchical Pooling (DHP [13]) to cluster nodes together. After each DHP, the number of nodes in the graph decreases. At each level l, the pooling layer is performed by a learned assignment (filtering) matrix S, which produces pooled node features H^{(l+1)} = S^T H^{(l)} and adjacency matrix A^{(l+1)} = S^T A^{(l)} S; the last level generates the readout vector s. The final number of nodes is predefined. S is learned by another GraphSAGE convolutional layer optimized with a regularization loss ||A^{(l)} − S S^T||_F, where ||·||_F denotes the Frobenius norm. The readout vector s is fed into an MLP to obtain the final classification output, the probability of being an ASD subject.
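A minimal sketch of one hierarchical pooling step, loosely following DiffPool [13]. The soft assignment matrix S is produced here from random logits purely for illustration (in the paper it is learned by a GraphSAGE layer), and the mean readout at the end is an assumed choice.

```python
import numpy as np

def softmax_rows(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def dense_pool(H, A, S_logits):
    # Soft-assign N nodes to K clusters: pooled features S^T H,
    # pooled adjacency S^T A S.
    S = softmax_rows(S_logits)
    return S.T @ H, S.T @ A @ S, S

def link_reg_loss(A, S):
    # Frobenius-norm regularization ||A - S S^T||_F on the assignment.
    return np.linalg.norm(A - S @ S.T, ord="fro")

rng = np.random.default_rng(0)
N, K, D = 6, 2, 3
H = rng.standard_normal((N, D))
A = np.abs(rng.standard_normal((N, N)))
A = (A + A.T) / 2                               # symmetric toy adjacency
H_pool, A_pool, S = dense_pool(H, A, rng.standard_normal((N, K)))
s_readout = H_pool.mean(axis=0)                 # simple mean readout vector
```

Each row of S sums to one, so every node is softly distributed over the K coarse clusters.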

2.4 Discriminator: Encouraging Good Representation

Following the intuition of Deep Graph Infomax [10], a good representation should not benefit from encoding information that runs counter to the graph-level summary. To obtain a representation more suitable for classification, we maximize the average MI between the high-level representation and the locally aggregated embedding of each node, which favors encoding aspects of the data that are shared across nodes and reduces noisy encoding [4]. The graph-level summary vector s = σ((1/N) Σ_i h_i) serves as an input to the discriminator, where σ is the logistic sigmoid nonlinearity. A discriminator D is used as a proxy for maximizing the MI, representing the probability scores assigned to local-global pairs. We randomly sample an instance from the opposite class as the negative sample G'. The discriminator scores summary-node representation pairs by applying a simple bilinear scoring function [10]:

D(h_i, s) = σ(h_i^T W_D s),

where W_D is a learnable scoring matrix and σ is the logistic sigmoid nonlinearity, used to convert scores into probabilities of the pair being positive.
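The bilinear scoring function can be sketched as follows; the scoring matrix is random rather than learned, and the summary and node vectors are stand-ins for the pooled and encoded representations.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def discriminate(h, s, W):
    # Bilinear score D(h, s) = sigmoid(h^T W s): the probability that
    # node embedding h comes from the graph summarized by s.
    return sigmoid(h @ W @ s)

rng = np.random.default_rng(0)
D = 4
W = rng.standard_normal((D, D))      # learnable scoring matrix (random here)
s = sigmoid(rng.standard_normal(D))  # stand-in graph-level summary vector
h_pos = rng.standard_normal(D)       # node embedding from the same graph
h_neg = rng.standard_normal(D)       # node embedding from a negative graph
p_pos = discriminate(h_pos, s, W)
p_neg = discriminate(h_neg, s, W)
```

During training, W would be optimized so that p_pos approaches 1 and p_neg approaches 0.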

2.5 Loss Function

In order to learn useful, predictive representations, the Infomax loss L_MI encourages nodes of the same graph to have similar representations while enforcing that the representations of disparate (negative) nodes are highly distinct. To ensure the performance of downstream classification, we use binary cross-entropy as the classification loss L_c. The loss function of our model is therefore written as:

L = L_c + λ L_MI,

where λ balances the classification and Infomax terms.
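A hedged NumPy sketch of the combined objective: binary cross-entropy for the graph-level classification term, and the binary cross-entropy form of the Infomax term over positive and negative local-global scores. The balancing weight lam is an assumed hyperparameter, not a value from the paper.

```python
import numpy as np

def bce(p, y, eps=1e-12):
    # Binary cross-entropy for the classification loss L_c.
    return -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

def infomax_loss(pos_scores, neg_scores, eps=1e-12):
    # BCE form of the Infomax term: positive local-global pairs should
    # score near 1, negative pairs near 0.
    return -(np.log(pos_scores + eps).mean()
             + np.log(1.0 - neg_scores + eps).mean())

def total_loss(p_cls, y, pos_scores, neg_scores, lam=1.0):
    # lam is an assumed balancing hyperparameter.
    return bce(p_cls, y) + lam * infomax_loss(pos_scores, neg_scores)

pos = np.array([0.9, 0.8])   # toy discriminator scores, positive pairs
neg = np.array([0.2, 0.1])   # toy discriminator scores, negative pairs
loss = total_loss(0.7, 1.0, pos, neg)
```

A well-trained discriminator (high positive scores, low negative scores) yields a small Infomax term, while the reverse is heavily penalized.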
3 Experiment and Results

3.1 Data Acquisition and Preprocessing

We tested our method on a group of 75 ASD children and 43 age- and IQ-matched healthy controls collected at the Yale Child Study Center [7] under the "biopoint" task [6]. The fMRI data were preprocessed following the pipeline in Yang et al. [11]. The graph data were augmented as described in our previous work [7], resulting in 750 ASD graphs and 860 HC graphs. We split the data into 5 folds by subject; four folds were used as training data and the left-out fold was used for testing. Each node attribute vector was constructed as defined in Section 2.1. Specifically, the GLM parameters of the biopoint task are: the coefficient of the biological motion matrix, the coefficient of the scrambled motion matrix, and the coefficients of the previous two matrices' derivatives.

3.2 Experiment and Results

We tested classifier performance on the Destrieux atlas [2] (148 ROIs) using the proposed GNN trained with the combined loss L_c + L_MI and with L_c alone, to examine the advantage of including the graph Infomax loss L_MI. In our GNN setting, the pooling ratios were predefined. We used the Adam optimizer with an initial learning rate of 0.001, decayed every 20 epochs. We trained the network for 100 epochs on each of the splits and measured instance classification by F-score (Table 1). We varied the architecture between two graph convolutional layers with kernel sizes (16, 16) or (8, 8), and one graph convolutional layer with kernel size (16). The regularization parameters were adjusted correspondingly to obtain the best performance.

For notational convenience, we refer to each model by its GNN architecture and training loss. Under the simpler two-layer architecture, we could not find an obvious advantage of using L_MI. However, when we increased the encoder's complexity, the model trained with L_c alone became easily overfitted, while the model trained with L_c + L_MI kept similar performance. This may indicate that L_MI acts as a regularizer and restrains the embedding from fitting data noise. In the single-layer case, the model trained with L_c alone was underfitted, while the model with L_c + L_MI performed slightly better, probably because L_MI encourages encoding nodal signals common across the graph (hence ignoring data noise), or simply because that model had more trainable parameters.

After training, we extracted the nodal embedding vectors after the last graph convolutional layer and used t-SNE [8] to visualize the node representations in 2D space. Only with L_MI did we find linearly separable nodal representations of ASD and HC for certain regions. We visually examined all the nodal embeddings of the model trained with L_c alone and verified that they cannot be linearly separated into the two groups. We marked the regions whose Silhouette score [9] was greater than 0.1 (resulting in 31 regions for the model trained with L_c + L_MI, shown in Fig. 2 (b)) as the brain ROIs with functional differences between ASD and HC. We compared the results with GLM z-statistics analysis using FSL [5] (shown in Fig. 2 (c)). Our proposed method clearly marked the prefrontal cortex, while the GLM method did not highlight those regions. Both our method and the GLM analysis highlighted the cingulate cortex. These regions were indicated as ASD biomarkers in Yang et al. [11] and Kaiser et al. [6]. We also used Neurosynth [12] to decode the functional keywords associated with the separable regions found by our method, as shown in Fig. 2 (d). The decoded keywords suggest that these regions may carry social-, mental-, visual- and default-mode-related functional differences between the ASD and HC groups. Potentially, our proposed method can be used as a tool to identify new brain biomarkers for better understanding the underlying roots of ASD.
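The region-selection criterion above relies on the silhouette coefficient. A self-contained NumPy version for the two-group case (equivalent in spirit to scikit-learn's silhouette_score) can be sketched as:

```python
import numpy as np

def mean_silhouette(Z, labels):
    # Mean silhouette coefficient for a two-class labeling (e.g. ASD vs
    # HC) of embedded representations Z: (n samples, d dims).
    D = np.linalg.norm(Z[:, None, :] - Z[None, :, :], axis=-1)
    scores = []
    for i in range(len(Z)):
        same = labels == labels[i]
        same[i] = False                       # exclude the point itself
        a = D[i, same].mean()                 # mean intra-class distance
        b = D[i, labels != labels[i]].mean()  # mean inter-class distance
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))

# Toy check: well-separated groups score high, mixed groups do not.
rng = np.random.default_rng(0)
labels = np.array([0, 0, 0, 1, 1, 1])
separated = np.vstack([rng.normal(0, 0.1, (3, 2)),
                       rng.normal(5, 0.1, (3, 2))])
mixed = rng.normal(0, 1.0, (6, 2))
s_sep = mean_silhouette(separated, labels)
s_mix = mean_silhouette(mixed, labels)
```

A region whose embedded representations yield a score above the 0.1 threshold used in the text would be flagged as functionally different between the two groups.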

Table 1: Performance (F-score, mean ± std) of different loss functions (L_c and L_c + L_MI) across GNN architectures with convolutional kernel sizes (16,16), (8,8), and (16).
Figure 2: Analysis of functional differences between ASD and HC. (a) Embedded representations of 4 brain regions visualized by t-SNE; HC is colored green and ASD red. The top two regions are not separable, while the bottom two region representations are separable. (b) Two views of the separable regions detected by our method. (c) z-statistics of the two groups from GLM analysis. (d) Functional keyword decoding results for the regions in (b).

4 Conclusion

We applied GNN to identify ASD and designed a loss function to encourage better node representation and detect separable brain regions of ASD and HC. By incorporating mutual information of local and global representations, the proposed loss function improved classification performance in certain cases. The added Infomax loss potentially regularizes the embedding of noisy fMRI and increases model robustness. By examining the embedded node representations, we found that ASD and HC had separable representations in regions related to default mode, social function, emotion regulation and visual function, etc. The finding is consistent with prior literature [7, 6] and our approach could potentially discover new functional differences between ASD and HC. Overall, the proposed method provides an efficient and objective way of embedding ASD and HC brain graphs.


  • [1] M. I. Belghazi et al. (2018) MINE: mutual information neural estimation. ICML. Cited by: §1.
  • [2] C. Destrieux et al. (2010) Automatic parcellation of human cortical gyri and sulci using standard anatomical nomenclature. Neuroimage 53 (1), pp. 1–15. Cited by: §3.2.
  • [3] W. L. Hamilton, R. Ying, and J. Leskovec (2017) Inductive representation learning on large graphs. In NIPS, Cited by: §2.2.
  • [4] R. D. Hjelm et al. (2018) Learning deep representations by mutual information estimation and maximization. arXiv preprint arXiv:1808.06670. Cited by: §2.4.
  • [5] M. Jenkinson et al. (2012) FSL. NeuroImage. Cited by: §3.2.
  • [6] M. D. Kaiser et al. (2010) Neural signatures of autism. PNAS. Cited by: §3.1, §3.2, §4.
  • [7] X. Li et al. (2019) Graph neural network for interpreting task-fmri biomarkers. MICCAI. Cited by: §1, §3.1, §4.
  • [8] L. v. d. Maaten and G. Hinton (2008) Visualizing data using t-SNE. Journal of Machine Learning Research 9 (Nov), pp. 2579–2605. Cited by: §3.2.
  • [9] P. J. Rousseeuw (1987) Silhouettes: a graphical aid to the interpretation and validation of cluster analysis. Journal of Computational and Applied Mathematics 20, pp. 53–65. Cited by: §3.2.
  • [10] P. Veličković et al. (2019) Deep graph infomax. ICLR. Cited by: §1, §2.4.
  • [11] D. Yang et al. (2016) Brain responses to biological motion predict treatment outcome in young children with autism. Translational psychiatry 6 (11), pp. e948. Cited by: §3.1, §3.2.
  • [12] T. Yarkoni et al. (2011) Large-scale automated synthesis of human functional neuroimaging data. Nature methods 8 (8), pp. 665. Cited by: §3.2.
  • [13] Z. Ying et al. (2018) Hierarchical graph representation learning with differentiable pooling. In NeurIPS, pp. 4805–4815. Cited by: §2.3.