MHATC: Autism Spectrum Disorder identification utilizing multi-head attention encoder along with temporal consolidation modules

by   Ranjeet Ranjan Jha, et al.

Resting-state fMRI is commonly used for diagnosing Autism Spectrum Disorder (ASD) by using network-based functional connectivity. It has been shown that ASD is associated with brain regions and their inter-connections. However, discriminating based on connectivity patterns among imaging data of the control population and that of ASD patients' brains is a non-trivial task. In order to tackle said classification task, we propose a novel deep learning architecture (MHATC) consisting of multi-head attention and temporal consolidation modules for classifying an individual as a patient of ASD. The devised architecture results from an in-depth analysis of the limitations of current deep neural network solutions for similar applications. Our approach is not only robust but computationally efficient, which can allow its adoption in a variety of other research and clinical settings.




1 Introduction

Autism spectrum disorder (ASD) is considered as one of the most complex neuro-developmental disorders, in which patients suffer from several difficulties, including repetitive behavior, deficits in social communication, and restricted interest. Over the years, functional Magnetic Resonance Imaging (fMRI) has been established as a well-known sensing technique and proved to be useful in assessing different mental disorders. Utilizing fMRI, it has been observed that the functional connectivity (FC) among the different regions of the brain can yield discriminative representation between control and a patient of ASD. In the literature, several works have been reported, which primarily operate on the FC matrix.

Early detection of ASD is crucial for better treatment of patients. However, diagnosis is a tedious and time-consuming task for medical practitioners, as several clinical observations need to be made about a patient’s behaviour to diagnose this mental disorder. In this regard, fMRI, along with sophisticated signal processing techniques, can aid in the early identification of ASD. In addition, several machine learning-based techniques have been reported in the literature to perform the underlying classification task (Control vs ASD); however, those techniques depend on hand-crafted feature extraction methods, which limit their performance.

Recently, encouraged by the popularity of deep learning methods in computer vision and its sub-domains, a few researchers have also deployed these techniques in classification tasks involving fMRI data, including ASD identification. Because deep learning techniques are end-to-end, feature extraction is not performed separately, unlike in traditional machine learning, where it is the backbone. In one such work, the authors utilized a Siamese graph convolutional neural network (GCNN) to perform ASD classification (as opposed to the vanilla 3D-CNN approach taken by Khosla et al. in [khosla20183d]). They worked with graph convolution in the spectral domain and achieved results comparable to state-of-the-art techniques. In [parisot2018disease], the authors considered a GCNN for the identification of ASD and Alzheimer’s disease. Heinsfeld et al. [heinsfeld2018identification] used an encoder-decoder based architecture to obtain robust and discriminative low-dimensional features from the FC matrix, which consequently help in better classification. Taking a temporal approach, Dvornek et al. [dvornek2017identifying] classified individuals with ASD and typical controls by directly employing LSTM modules on resting-state fMRI time series data. This led to a significant increase of 9% in the classification accuracy at the time, as reported by them.

Figure 1: Proposed Architecture

The authors in [zhao2018diagnosis] employed multi-level high-order functional connectivity to obtain a more comprehensive characterization of the complex interactions between the different regions of the brain. Feature selection to identify discriminative FC features in their approach led to improved ASD identification results. In [niu2020multichannel], an attention-based deep learning model was proposed, which achieved good accuracy. In [jiao2020improving] and [huang2020identifying], a deep learning-based capsule network and a Deep Belief Network, respectively, were presented for performing the ASD classification task. The combination of sMRI and fMRI has also been explored for autism spectrum disorder diagnosis, where the decisions made by both modalities are fused using a deep autoencoder network. In [ingalhalikar2021functional], ASD prediction was performed on a site-harmonized autism dataset. The authors of [jung2021inter] tried to learn inter-regional high-level relations from functional connectivity for the classification of ASD.

However, the problem is still challenging, and even after many such works there is considerable scope to improve the overall classification accuracy. Most of the state-of-the-art deep learning techniques proposed for ASD classification directly use the FC matrix as input, which approximates the data at an initial stage with second-order statistics. In contrast, our model works on the original input signal and learns robust features by employing multi-head self-attention and a temporal consolidation mechanism.
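To make the "second-order statistics" remark concrete, the following minimal sketch computes an FC matrix as the Pearson correlation between region time series. It assumes a (regions, time) layout and uses random data in place of real fMRI signals; the function name `functional_connectivity` is hypothetical and is not taken from any of the cited works.

```python
import numpy as np

def functional_connectivity(ts: np.ndarray) -> np.ndarray:
    """Pearson correlation between region time series.

    ts has shape (n_regions, n_timepoints); this layout is an assumption.
    """
    return np.corrcoef(ts)

# Random stand-in for one subject: 200 regions, 200 time points.
rng = np.random.default_rng(0)
ts = rng.standard_normal((200, 200))
fc = functional_connectivity(ts)
print(fc.shape)  # one correlation value per pair of regions
```

The time axis is summarised away in `fc`, which is the information a model operating on the raw signal can still access.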

Our contributions in this paper are as follows:

(1) Classification of ASD utilizing a novel architecture (as shown in Fig. 1), consisting of a multi-head attention encoder (MHAE) and a temporal consolidation module (TCM). (2) The MHAE (Fig. 1(b)), which includes a positional encoding component and leverages several blocks of multi-head attention and feed-forward networks with residual connections for flexibility. (3) The TCM (Fig. 1(a)), which includes several context consolidation modules that aggregate the features across regions at any given time step, resulting in a better overall feature representation. (4) State-of-the-art performance on the publicly available ABIDE dataset.

2 Method

We have proposed a cascaded network (as shown in Fig. 1), which consists of two major modules - the Multi-Head Attention Encoder (MHAE) module and the Temporal Consolidation Module (TCM). As the name suggests, the MHAE module is employed to perform attention based encoding of the features required for classification (inspired by [vaswani2017attention]). On the other hand, the TCM is utilised to step-by-step consolidate the concerned features across the time points and get a condensed feature map at the end. Finally, a global pooling layer, followed by the classification layer, provides the final classification output.

2.1 Multi-Head Attention Encoding Module

In general, an MHAE learns a position-invariant representation of the input data, as discussed in [vaswani2017attention]. However, for temporal data, positional information within the input sequence is important, so it needs to be embedded in the input. For this purpose, we have employed sinusoids as in [vaswani2017attention] to associate positional information with the input data, where the phase of the sinusoids varies for different dimensions of the encoding. The input to this module is a matrix of size (200,200): the brain is divided into 200 regions, and each region has data for 200 time points. The module first calculates the positional embedding of the input block, which attaches a temporal meaning to the input fMRI data. The embedding is then added to the input and passed to an encoder, which employs multi-head self-attention followed by a feed-forward network to capture the relative importance of features in the input. Multiple such layers, each including multi-head attention and feed-forward components, have been used, which has been shown to improve classification performance. Note that there is no reduction in size as the input passes through this module.
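As an illustration of the positional embedding step, the sketch below builds the sinusoidal encoding of [vaswani2017attention] and adds it to a (200,200) input. It assumes that rows index time points and columns index regions, and it uses the base constant 10000 from [vaswani2017attention]; neither choice is confirmed by the text above.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding: sin on even columns, cos on odd columns."""
    positions = np.arange(n_positions)[:, None]   # (n_positions, 1)
    dims = np.arange(d_model)[None, :]            # (1, d_model)
    # Each pair of columns (2i, 2i+1) shares one frequency.
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])
    pe[:, 1::2] = np.cos(angles[:, 1::2])
    return pe

# Assumed layout: 200 time points (rows) x 200 regions (columns).
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 200))
pe = sinusoidal_positional_encoding(200, 200)
encoder_input = x + pe   # same shape as the input
```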

(a) A block in the Temporal Consolidation Module.
(b) A layer in the Multi-Head Attention Encoder module. Multiple such layers have been employed in a cascaded manner for improved encoding of inter-region dependence.
Figure 2: Note that in (a), a TCM block causes a (200,200) input to be consolidated to a (128,200) output owing to 1D convolutions with a filter size of 3. Multiple such blocks employed in a cascaded manner make the consolidation progressively denser, with the size reducing to (64,200) and finally to (32,200). In (b), the MHAE module is shown.

Self-attention: This module’s core idea lies in multi-head self-attention, which corresponds to the ability to attend to different positions of the input sequence to compute a representation of that sequence. We have used a stack of multi-head self-attention layers along with feed-forward neural networks. For a self-attention layer, there are three inputs, viz. query ($Q$), key ($K$), and value ($V$) [vaswani2017attention], which indicate the role that each element of the input sequence plays. Here, the three representations $Q$, $K$, and $V$ have been obtained by transforming the input (after positional embedding). The degree of correlation between two elements in a piece of sequential data determines the extent to which these elements influence the output at a given time step. The attention ($A$) (similar to [vaswani2017attention]) is given by the following relations:

$$S = \frac{Q K^{\top}}{\sqrt{m}} \qquad (1)$$

$$A = \mathrm{softmax}(S) \qquad (2)$$

$$Z = A V \qquad (3)$$

In our case, the query and key have the same dimension, i.e., $m$, which facilitates the dot product in the above equations. This dot product is followed by scaling, i.e., division by the square root of $m$, resulting in the scaled score $S$ (eq. 1). This has a normalizing effect and prevents explosion of values. Subsequently, the softmax function is applied to $S$ to get the attention $A$ (eq. 2), which holds different weightage for different entries in the value vector. Finally, this attention is multiplied with the value vector to get the feature vector $Z$ (eq. 3).
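The scaled dot-product attention described above can be sketched in a few lines of NumPy. This is an illustrative sketch and not the authors' implementation: the projection matrices `w_q`, `w_k`, `w_v` and the dimension `m = 64` are hypothetical, and random values stand in for learned weights.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """Return (features, attention weights) for query q, key k, value v."""
    m = q.shape[-1]                                  # shared query/key dimension
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(m) # dot product, then scaling
    weights = softmax(scores, axis=-1)               # attention weights
    return weights @ v, weights                      # weighted sum of values

rng = np.random.default_rng(0)
x = rng.standard_normal((200, 200))      # input after positional embedding
m = 64                                   # hypothetical projection size
w_q, w_k, w_v = (rng.standard_normal((200, m)) / np.sqrt(200) for _ in range(3))
features, weights = scaled_dot_product_attention(x @ w_q, x @ w_k, x @ w_v)
```

Each row of `weights` sums to one, so every output row is a convex combination of the rows of the value matrix.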

Multi-head attention: To jointly attend to information at different positions from different representational spaces, multi-head attention has been employed. Multi-head attention, to put it simply, involves the concatenation of outputs by multiple simple scaled dot-product attention units as shown in Fig. 1(b). This concatenated output is finally passed through a dense layer. We have used five such ‘layers’ in the module, where the multi-head attention component of each layer consists of two attention heads.
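A minimal sketch of one multi-head self-attention component with two heads follows. The even split of the model dimension across heads, the weight names, and the random weights are assumptions made for illustration; the feed-forward network and residual connections of a full encoder layer are omitted.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, n_heads=2):
    """x: (t, d). Runs n_heads attention units, concatenates them, applies a dense layer."""
    t, d = x.shape
    d_head = d // n_heads   # assumes d is divisible by n_heads

    def split(a):  # (t, d) -> (n_heads, t, d_head)
        return a.reshape(t, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, t, t)
    heads = softmax(scores) @ v                           # (n_heads, t, d_head)
    concat = heads.transpose(1, 0, 2).reshape(t, d)       # concatenate head outputs
    return concat @ w_o                                   # final dense layer

rng = np.random.default_rng(0)
d = 200
x = rng.standard_normal((200, d))
w_q, w_k, w_v, w_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out = multi_head_self_attention(x, w_q, w_k, w_v, w_o, n_heads=2)
print(out.shape)
```

The output keeps the input shape, which is consistent with the statement that the encoder does not reduce the size of its input.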

2.2 Temporal Consolidation Module

This module aims to extract more discriminative features by aggregating the brain’s regions and capturing multi-context information. One can observe in Fig. 1(a) that we have concatenated three types of feature maps taken at three contexts by applying one, two, and three sequential 1D convolutions, respectively, each with a filter size of 3 (as noted in the caption of the TCM block figure). The module’s input dimension is reduced from 200 to 128 by including 128 filters in the first 1D convolution layer; the number of filters remains the same for the other layers. In this way, the output of the multi-context layer contains the inherent multi-context representation. Further, there is an additional residual connection branch that contains only one 1D convolution layer. This layer’s task is to forward the input directly to the multi-context output; its output dimension is set to match that of the multi-context output so that the two can be added. This technique is used as it is known to improve gradient flow during backpropagation. An important property exploited here is the ability of the network to decide which layers it requires for an adequate representation of the crucial features. This adds to the flexibility of the network: it is not rigid about the types of filters to be used, and instead allows for automatic selection due to its end-to-end nature. Subsequently, two more TCM blocks have been included in sequential order, with the same block structure but a different number of filters (64 and 32). Thus, the aggregated feature map contains only the most crucial relationship information among the different regions, and hence helps in better classification.
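Two aspects of the TCM can be checked numerically: stacking size-3 convolutions widens the temporal context (one, two, and three sequential size-3 filters cover 3, 5, and 7 time points), and the number of filters sets the first output dimension. The sketch below is a simplification under stated assumptions: regions are treated as input channels, convolution runs along time with zero padding, each block is reduced to a single convolution, and the concatenation and residual branches are left out. The helper names are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv1d_same(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """x: (in_channels, length); w: (out_channels, in_channels, k) with odd k."""
    k = w.shape[-1]
    pad = k // 2
    x_pad = np.pad(x, ((0, 0), (pad, pad)))          # zero-pad the time axis only
    windows = sliding_window_view(x_pad, k, axis=1)  # (in_channels, length, k)
    return np.einsum("ock,clk->ol", w, windows)

def receptive_field(n_stacked: int, k: int = 3, length: int = 15) -> int:
    """Count time points reached by a unit impulse after n stacked size-k filters."""
    signal = np.zeros(length)
    signal[length // 2] = 1.0
    for _ in range(n_stacked):
        signal = np.convolve(signal, np.ones(k), mode="same")
    return int(np.count_nonzero(signal))

contexts = [receptive_field(n) for n in (1, 2, 3)]

# Shape progression across three simplified blocks with 128, 64, and 32 filters.
rng = np.random.default_rng(0)
x = rng.standard_normal((200, 200))
shapes = []
for n_filters in (128, 64, 32):
    w = rng.standard_normal((n_filters, x.shape[0], 3)) * 0.01
    x = conv1d_same(x, w)
    shapes.append(x.shape)
print(contexts, shapes)
```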

3 Experimental Setup

3.1 Dataset description

For the evaluation of our proposed network, the complete ABIDE I [nielsen2013multisite] fMRI dataset has been utilized. This dataset has contributions from multiple international brain imaging laboratories. It consists of 530 control (normal) individuals and 505 ASD patients. We have performed 10-fold cross-validation, as in [huang2020identifying], to check the robustness of the model.
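A minimal sketch of a 10-fold split over the 1035 subjects (530 controls plus 505 patients) is given below. The shuffling, the seed, and the lack of stratification by class or site are assumptions; the text does not specify how the folds were formed.

```python
import random

def k_fold_indices(n_samples: int, k: int = 10, seed: int = 0):
    """Yield (train_indices, validation_indices) for each of k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]   # k nearly equal, disjoint folds
    for i in range(k):
        validation = folds[i]
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, validation

n_subjects = 530 + 505
splits = list(k_fold_indices(n_subjects, k=10))
```

Every subject appears in exactly one validation fold, so the averaged metrics cover the whole dataset once.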

3.2 Evaluation metrics and implementation details

The accuracy, sensitivity, and specificity metrics, as in [jiao2020improving], have been used for evaluation. In addition, the ROC curve has been plotted, with the area under the curve (AUC) serving as another important measure.
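For reference, the three metrics follow directly from the confusion-matrix counts. The sketch below uses class 1 for the patient and class 0 for control, following the convention in the ROC figure caption; the toy labels are illustrative and are not results from this work.

```python
def classification_metrics(y_true, y_pred):
    """Accuracy, sensitivity (patient recall) and specificity (control recall)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return {
        "accuracy": (tp + tn) / len(pairs),
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
    }

# Toy example: 4 patients (1) and 4 controls (0).
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]
metrics = classification_metrics(y_true, y_pred)
print(metrics)  # accuracy 0.625, sensitivity 0.75, specificity 0.5
```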

4 Results and Discussion

We provide a comparison of the accuracy of our approach with the present state-of-the-art methods. Table 1(a) provides the average performance over cross-validation. It can be observed that our proposed model achieves a good mean accuracy of 77.4%. The mean sensitivity is also high (78.6%), which enables correct identification of patients to a great extent. However, the mean specificity is comparatively lower (74%). A possible reason is that rs-fMRI signals for ASD individuals have a more discriminative representation than those of controls. In addition, we have shown the best ROC curve among all the validation data (considered for each fold) for ASD and control classification in Fig. 3, and the AUC value reported there also indicates high-quality performance of our method. Beyond the overall accuracy, an ROC curve provides more insight into the performance of the model: it helps determine the trade-off between false positives and false negatives according to the use case of the system under scrutiny.
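The AUC has a simple interpretation that can be computed directly: it is the probability that a randomly chosen patient receives a higher score than a randomly chosen control, with ties counted as one half. The sketch below uses this pairwise definition on toy scores; the function name and the data are illustrative only.

```python
def roc_auc(y_true, scores):
    """AUC as the fraction of (patient, control) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))

# Toy example: one patient (score 0.35) is ranked below one control (score 0.4).
auc = roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc)  # 0.75
```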

Figure 3: Receiver Operating Characteristic (ROC) curves (class 0 for control and class 1 for the patient)

| Model | Accuracy | Sensitivity | Specificity |
|---|---|---|---|
| Ghiassian et al. [ghiassian2016using] | 59.2% | 72.2% | 45.4% |
| Chen et al. [chen2015diagnostic] | 66.0% | 60% | 72% |
| Heinsfeld et al. [heinsfeld2018identification] | 70.0% | 74% | 63% |
| Jiao et al. [jiao2020improving] | 71% | 73% | 66% |
| Huang et al. [huang2020identifying] | 76.4% | 77.8% | 75.0% |
| Ingalhalikar et al. [ingalhalikar2021functional] | 71.35% | 59.5% | 80.6% |
| Jung et al. [jung2021inter] | 69.19% | 64.79% | 73.46% |
| Proposed method | 77.4% | 78.6% | 74% |

(a) Performance comparison between our proposed network and state-of-the-art works

| # Encoders | # Attention heads | Accuracy |
|---|---|---|
| 1 | 2 | 73.56% |
| 2 | 2 | 75.48% |
| 3 | 2 | 76.62% |
| 4 | 2 | 77.00% |
| 5 | 2 | 77.40% |
| 6 | 2 | 76.80% |

(b) Ablation study: comparison of multi-head attention encoder configurations with different combinations of attention heads and encoder layers

Table 1: Comparative Analysis and Ablation Study

For a visual assessment, we have shown the Craddock atlas [craddock2012whole] in Fig. 4, where the 200 regions are color-coded in coronal, sagittal, and axial views. Since the MHAE module is used to obtain refined features for each region, we have considered the mean features of this module for the ASD and control classes separately. Further, only the edges whose value exceeds a threshold (this value being an indicator of the extent to which two regions in the brain are connected functionally) have been plotted in Fig. 4. One can observe that the regions which are ‘connected’, i.e., engaged in some functionality together, differ between ASD and control. This shows the discriminability of the features. The result depicts the involvement of several regions, including the left fusiform gyrus, right precentral gyrus, left superior frontal gyrus, and left precuneus cortex, in distinguishing ASD from control. Thus, we can say that the module aids in extracting discriminative features.

Figure 4: Atlas visualization with 200 regions and a selected few highly correlated edges in both ASD and control (in coronal, sagittal, and axial views)

Considering the Multi-Head Attention Encoding module, we have analyzed the effect of varying the number of encoders on the classification between control and ASD. During the experiments, we observed that using more than two attention heads gives no improvement in accuracy. Moreover, using more than five encoders reduces the performance. Hence, five encoders, each containing two attention heads, gave the best performance among all the combinations.

5 Conclusion

We have proposed an end to end deep learning architecture that leverages several modules, including multi-head attention and temporal consolidation, to perform the classification task between control and autism spectrum disorder fMRI scans. Our approach yields state-of-the-art accuracy and is also quite robust and computationally efficient. Thus, one can explore its application to various other research and clinical settings. In the future, we would explore variations of the proposed network to improve the performance further.

6 Compliance with Ethical Standards

The above work has been conducted using human subject data made available in open access by the Autism Brain Imaging Data Exchange. Ethical approval was not required as confirmed by the license attached with the open access data.