Mental disorders manifest in behavior that is driven by disruptions in brain dynamics [goldberg1992common, calhoun2014chronnectome]. Functional MRI captures the nuances of spatio-temporal dynamics that could potentially provide clues to the causes of mental disorders and enable early diagnosis. However, the data obtained for a single subject is of high dimensionality, and to be useful for learning and statistical analysis, datasets with a large number of subjects must be collected. Yet, for any given disorder, demographic, or other condition, a single study is rarely able to amass a dataset large enough. Traditionally, the small-data problem is approached by handcrafting features of much smaller dimension [Khazaee2016], effectively relying on dimensionality reduction. Often, the dynamics of brain function in these representations vanishes into proxy features such as correlation matrices of functional network connectivity (FNC) [yan2017discriminating].
Our goal is to enable the direct study of brain dynamics in the small-data setting. For brain data, this can in turn enable analysis of brain function via model introspection. In this paper, we show how one can achieve significant improvement in classification directly from dynamical data on small datasets by taking advantage of publicly available large but unrelated datasets. We demonstrate that it is possible to train a model in a self-supervised manner on the dynamics of healthy control subjects from the Human Connectome Project (HCP) [van2013wu] and apply the pre-trained model to completely different data collected across multiple sites from healthy controls and patients. We show that pre-training on dynamics allows the encoder to generalize across a number of datasets and a wide range of disorders: schizophrenia, autism, and Alzheimer's disease. Importantly, we show that the learnt dynamics generalize across different data distributions, as our model pre-trained on healthy adults shows improvements on children and elderly subjects.
2 Related Work
Unsupervised pre-training is a well-known technique to give a deep neural network a head start [erhan2010does]. It finds wide use across a number of fields such as computer vision [henaff2019data], natural language processing (NLP) [devlin2018bert], and automatic speech recognition (ASR) [lugosch2019speech]. However, outside NLP, unsupervised pre-training remains less popular than supervised training.
Recent advances in self-supervised methods with mutual information objectives are approaching the performance of supervised training [infonce, hjelm2018learning, bachman2019learning] and can scale pre-training to very deep convolutional networks (e.g., a 50-layer ResNet). They were shown to benefit structural MRI analysis [fedorov2019prediction], and to learn useful representations from frames in Atari games [anand2019unsupervised] and for speaker identification [ravanelli2018learning]. Pre-trained models can outperform supervised methods by a large margin in the case of small data [henaff2019data].
Earlier work in brain imaging [khosla2019machine, frontiers2014] has been based on unsupervised methods to learn the dynamics and structure of the brain using approaches such as ICA [calhoun2001method] and HMMs [eavani2013unsupervised]. Deep learning for capturing brain dynamics has also been previously proposed [hjelm2014restricted, hjelm2018spatio, khosla2019detecting]. For some very small datasets, transfer learning has been proposed for use in neuroimaging applications [mensch2017learning, 10.3389/fnins.2018.00491, thomas2019deep]. Yet another idea is the data-generating approach [ulloa2018improving]. ST-DIM [anand2019unsupervised] has been used for pre-training on unrelated data with subsequent use for classification [mahmood2019transfer].
3 MILC
We present MILC as an unsupervised pre-training method. We use MILC to pre-train on large unrelated and unlabelled data to better learn a data representation. The learnt representations are then used for classification on downstream tasks by adding a simple linear network on top of the pre-training architecture. The fundamental idea of MILC is to establish a relationship between windows (time slices from the entire sequence) and their respective sequences through learning useful signal dynamics. In all of our experiments we use encoded rsfMRI ICA time courses as our sequences and consecutive chunks of time points as windows. The model uses this idea to distinguish among sequences (subjects), which proves to be extremely useful in downstream tasks, e.g., classification of healthy control (HC) vs. schizophrenia (SZ) subjects. To realize the concept, we maximize the mutual information between the latent space of a window and that of the corresponding sequence as a whole.
Let $\mathcal{D} = \{(x_t^i, y^i)\}$ be a dataset of pairs computed from ICA time courses, where $x_t^i$ is the local embedding of the $t$-th window taken from sequence $i$ and $y^i$ is the global embedding of the entire sequence $i$; $T$ is the number of windows in a sequence, and $N$ is the total number of sequences. Then $\mathcal{D}^+ = \{(x_t^i, y^i)\}$ is called the dataset of positive pairs and $\mathcal{D}^- = \{(x_t^i, y^j) : j \neq i\}$ the dataset of negative pairs. The dataset $\mathcal{D}^+$ refers to the joint distribution, and $\mathcal{D}^-$ to the marginal distributions, of the whole sequence and the window in the latent space. Eventually, the lower bound with the InfoNCE estimator [infonce] is defined as
$$ I(x; y) \geq \frac{1}{NT}\sum_{i=1}^{N}\sum_{t=1}^{T} \log \frac{\exp f(x_t^i, y^i)}{\sum_{j=1}^{N} \exp f(x_t^i, y^j)}, $$
where $f$ is a critic function. Specifically, we use a separable critic $f(x, y) = \phi(x)^\top \psi(y)$, where $\phi$ and $\psi$ are embedding functions parameterized by neural networks. These embedding functions map the two inputs, which may have different dimensions, into a common space in which the critic value is computed. The critic learns embeddings such that it assigns higher values to positive pairs than to negative pairs: $f(\mathcal{D}^+) \gg f(\mathcal{D}^-)$.
Our critic function takes the latent representations of a window and a sequence as input. We define the latent state $x_t^i$ of a window as the output produced by the CNN part of MILC given input from the $t$-th window of sequence $i$, and the latent state $y^i$ of a sequence as the global embedding obtained from the MILC architecture. Thus the critic function for an input pair (a window and a sequence) is $f(x_t^i, y^j) = \phi(x_t^i)^\top \psi(y^j)$. The loss is InfoNCE with this critic. The scheme of MILC is shown in Figure 1.
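The objective above can be sketched in a few lines of NumPy. This is an illustrative toy, not the paper's implementation: the embedding dimension, batch size, and the identification of one window per sequence are all assumptions made for brevity, and the embeddings are drawn at random rather than produced by the CNN/biLSTM encoder.

```python
import numpy as np

rng = np.random.default_rng(0)

def separable_critic(phi_x, psi_y):
    # f(x, y) = phi(x)^T psi(y): score every window embedding against
    # every sequence embedding -> an (N, N) matrix of critic values
    return phi_x @ psi_y.T

def info_nce_loss(phi_x, psi_y):
    # Window i and sequence i form the positive pair; the other N-1
    # sequences in the batch act as negatives for window i.
    scores = separable_critic(phi_x, psi_y)
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))                   # negative InfoNCE bound

# toy batch: N window embeddings paired with N sequence embeddings
N, d = 16, 8
phi_x = rng.normal(size=(N, d))
loss_pos = info_nce_loss(phi_x, phi_x + 0.01 * rng.normal(size=(N, d)))
loss_rand = info_nce_loss(phi_x, rng.normal(size=(N, d)))
```

When the window and sequence embeddings are aligned (the positive-pair case), the loss is much lower than for unrelated embeddings, which is exactly the signal that drives pre-training.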
3.1 Transfer and Supervised Learning
In the downstream task, we use the representation (output) of the attention model pre-trained using MILC as input to a simple binary classifier on top. Refer to section 4.1 for further details.
4 Experiments
In this section we study the performance of our model on both synthetic and real data. To compare and show the advantage of pre-training on a large unrelated dataset, we use three different kinds of models: 1) FPT (Frozen Pre-Trained): the pre-trained model is not further trained on the dataset of the downstream task; 2) UFPT (Unfrozen Pre-Trained): the pre-trained model is further trained on the dataset of the downstream task; and 3) NPT (Not Pre-Trained): the model is not pre-trained at all and is only trained on the dataset of the downstream task. The models are shown in Figure 1. In each experiment, we compare all three models to demonstrate the effectiveness of unsupervised pre-training.
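The three regimes differ only in whether the encoder weights are kept and whether they receive gradients. A minimal PyTorch sketch, assuming the encoder is an `nn.Module` (the `prepare_encoder` helper and the tiny stand-in encoder are illustrative, not the paper's code):

```python
import torch.nn as nn

def prepare_encoder(encoder, regime):
    """Configure an encoder for a downstream run.

    'FPT'  -> keep pre-trained weights and freeze them
    'UFPT' -> keep pre-trained weights and fine-tune them
    'NPT'  -> reinitialize the weights and train from scratch
    """
    if regime == "NPT":
        for m in encoder.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()          # discard pre-trained weights
    for p in encoder.parameters():
        p.requires_grad = regime != "FPT"     # gradients flow except in FPT
    return encoder

# stand-in for a pre-trained encoder
encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8))
frozen = prepare_encoder(encoder, "FPT")
```

In the FPT case only the downstream classifier on top is optimized, which is why FPT needs the fewest labeled subjects.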
The CNN encoder of MILC for the simulation experiments consists of 1D convolutional layers, each followed by a ReLU, followed by a linear layer. For the real-data experiments, we likewise use 1D convolutional layers, each followed by a ReLU, followed by a linear layer, with a different choice of output features and kernel sizes. We use the same stride for all of the convolutional layers. We also test against autoencoder-based pre-training for the simulation experiment, for which we use the same CNN encoder as for MILC in the reduction phase. For the decoder, we use the reverse architecture of the encoder, which reconstructs the windows at the output.
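The overall encoder structure can be sketched as follows. The layer widths, kernel sizes, output dimension, and the mean-pool over time are all hypothetical placeholders; the paper specifies its own configuration.

```python
import torch
import torch.nn as nn

class CNNEncoder(nn.Module):
    # Stacked 1D convolutions + ReLU over the time axis of a window,
    # followed by a linear layer producing the window embedding.
    # All hyperparameters below are illustrative assumptions.
    def __init__(self, in_channels=1, out_dim=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=4, stride=1), nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=4, stride=1), nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3, stride=1), nn.ReLU(),
        )
        self.linear = nn.Linear(128, out_dim)

    def forward(self, x):              # x: (batch, channels, window_length)
        h = self.conv(x)               # (batch, 128, reduced_length)
        h = h.mean(dim=-1)             # pool over time (an assumption here)
        return self.linear(h)          # (batch, out_dim) window embedding

enc = CNNEncoder()
emb = enc(torch.zeros(5, 1, 20))       # 5 windows of 20 time points each
```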
In MILC-based pre-training, for all possible pairs in the batch, we take features from the output layer of the CNN encoder. The latent representation of the entire time series is then passed through a biLSTM. The output of the biLSTM is used as input to the attention model to get a single vector, which represents the entire time series. Scores are calculated using the critic function as explained in Section 3. Using these scores, we compute the loss. The neural networks are trained using the Adam optimizer.
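The attention step that collapses the per-time-step biLSTM outputs into one sequence vector can be sketched as a simple additive-attention pool. The parameter shapes and scoring function here are illustrative assumptions; the actual MILC attention module may differ.

```python
import numpy as np

rng = np.random.default_rng(1)

def softmax(z):
    z = z - z.max()                    # numerical stability
    e = np.exp(z)
    return e / e.sum()

def attention_pool(h, w, b, v):
    # h: (T, d) per-time-step outputs (e.g., from a biLSTM);
    # score each step, normalize, and return the weighted sum
    scores = np.tanh(h @ w + b) @ v    # (T,) unnormalized step scores
    alpha = softmax(scores)            # attention weights, sum to 1
    return alpha @ h, alpha            # pooled (d,) vector and weights

T, d = 10, 6                           # toy sequence: 10 steps, 6 features
h = rng.normal(size=(T, d))
w, b, v = rng.normal(size=(d, d)), np.zeros(d), rng.normal(size=d)
pooled, alpha = attention_pool(h, w, b, v)
```

The pooled vector plays the role of the global sequence embedding fed to the critic.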
In downstream tasks we are interested in classifying subjects: for each subject, the output of the attention model is used as input to a feed-forward network of two linear layers to perform binary classification. For the experiments, a hold-out set is selected for testing and is never used during the training/validation phase. For each experiment, 10 trials are performed to ensure random selection of training subjects and, in each case, the performance is evaluated on the hold-out (test) data. The code is available at: https://github.com/UsmanMahmood27/MILC
To generate synthetic data, we create multiple graphs with stable transition matrices. Using these, we generate multivariate time series with vector autoregressive (VAR) and structural vector autoregressive (SVAR) models [lutkepohl2005new].
The VAR time series are split into three slices for training, validation, and testing, respectively. Using these samples, we pre-train MILC to assign windows to their respective time series.
In the final downstream task, we classify whole time series into VAR or SVAR (obtained by randomly dropping VAR samples) groups. We generate samples and split them into training, validation, and hold-out test sets. For both pre-training and the downstream task, we follow the same setup as described in section 4.1.
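A minimal way to simulate such a stable VAR process is sketched below. The node count, series length, noise level, and the 0.95 spectral-radius target are hypothetical choices, not the paper's settings.

```python
import numpy as np

rng = np.random.default_rng(2)

def stable_transition_matrix(n, rng):
    # shrink a random matrix until its spectral radius is below 1,
    # which keeps the VAR(1) process from diverging
    a = rng.normal(size=(n, n))
    return a * (0.95 / np.max(np.abs(np.linalg.eigvals(a))))

def simulate_var(a, t_steps, rng, noise=0.1):
    # x[t] = A x[t-1] + Gaussian innovation
    n = a.shape[0]
    x = np.zeros((t_steps, n))
    for t in range(1, t_steps):
        x[t] = a @ x[t - 1] + noise * rng.normal(size=n)
    return x

A = stable_transition_matrix(10, rng)
series = simulate_var(A, 200, rng)     # one synthetic multivariate series
```

Repeating this with many transition matrices yields the pool of sequences used for pre-training.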
We compare the effectiveness of MILC with the model used in [mahmood2019transfer] and with two variations of autoencoder-based pre-training. The two autoencoder variations are obtained by replacing the CNN encoder of [mahmood2019transfer] and of MILC with the pre-trained or randomly initialized autoencoder during downstream classification, depending on the model as explained in section 4. We refer to these two variations as AE_STDIM and AE_STDIM+attention. Note that the difference between the two is the attention layer added in the latter during downstream classification.
We observe that the MILC-based pre-trained models can easily be fine-tuned with only a small amount of downstream data. Note that, with very few samples, models based on pre-trained MILC (FPT and UFPT) outperform the non-pre-trained models (NPT), the ST-DIM models, and the autoencoder-based models. The ST-DIM-based pre-training model [mahmood2019transfer] performs reasonably well compared to the autoencoder and NPT models; however, MILC steadily outperforms ST-DIM. The results show that autoencoder-based self-supervised pre-training does not assist in VAR vs. SVAR classification. Refer to Figure 1 Left for the results of the simulation experiments.
4.3 Brain Imaging
Next, we apply MILC to brain imaging data. We use rsfMRI data for all brain-data experiments. Refer to Figure 1 for the details of the datasets used. We compare MILC with the ST-DIM-based pre-training shown in [mahmood2019transfer].
The four datasets used in this study were collected from the FBIRN (Function Biomedical Informatics Research Network) project [keator2016function] (these data were downloaded from the Function BIRN Data Repository, Project Accession Number 2007-BDR-6UHZ1), the COBRE (Center of Biomedical Research Excellence) project [ccetin2014thalamus], release 1.0 of ABIDE (Autism Brain Imaging Data Exchange, http://fcon_1000.projects.nitrc.org/indi/abide/) [di2014autism], and release 3.0 of OASIS (Open Access Series of Imaging Studies, https://www.oasis-brains.org/) [rubin1998prospective].
We preprocess the fMRI data using statistical parametric mapping (SPM12, http://www.fil.ion.ucl.ac.uk/spm/) in the MATLAB 2016 environment. After preprocessing, subjects were included in the analysis if their head motion (rotation in degrees and translation in mm) was within the accepted thresholds and their functional data provided successful normalization of nearly the whole brain [fu2019altered].
For each dataset, ICA components are acquired using the same procedure described in [fu2019altered]; however, only the non-noise components are used in all experiments. For all experiments, the fMRI sequence is divided into overlapping windows of consecutive time points along the time dimension.
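The windowing can be sketched as follows. The window length and step are placeholders (the actual values depend on the experiment), and the toy input stands in for the ICA time courses of one subject.

```python
import numpy as np

def sliding_windows(x, window, step):
    # x: (components, time) ICA time courses for one subject; returns
    # overlapping windows stacked as (n_windows, components, window)
    starts = range(0, x.shape[1] - window + 1, step)
    return np.stack([x[:, s:s + window] for s in starts])

x = np.arange(200, dtype=float).reshape(2, 100)   # 2 components, 100 steps
wins = sliding_windows(x, window=20, step=10)     # 50% overlap between windows
```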
For schizophrenia classification, we conduct experiments on two different datasets, FBIRN [keator2016function] and COBRE [ccetin2014thalamus]. The datasets contain labeled Schizophrenia (SZ) and Healthy Control (HC) subjects.
The dataset is a collection of HC subjects and subjects affected with SZ. We use two hold-out sets, one each for validation and test; the remaining data is used for supervised training. The results in Figure 3 confirm the effectiveness of MILC: even with few training subjects, FPT and UFPT perform significantly better than NPT, with a clear difference in their median AUC scores.
The dataset contains HC subjects and subjects affected with autism. We use separate subsets for validation and test; the remaining data is used for downstream training, i.e., autism vs. HC classification. Figure 3 shows that MILC pre-trained models perform reasonably better than NPT, reinforcing our hypothesis that unsupervised pre-training learns signal dynamics useful for downstream tasks. We suspect that the reason why pre-trained models do not work as well at the smallest training sizes is that the dataset differs substantially from HCP: the large age gap between subjects of HCP and ABIDE is a major difference, and very few training subjects are not enough even for pre-trained models. Refer to Figure 1 for the demographic information of all the datasets.
4.3.5 Alzheimer’s disease
The OASIS dataset [rubin1998prospective] contains an equal number of HC subjects and Alzheimer's (AZ) patients. We use two hold-out sets, one each for validation and test; the remaining subjects are used for supervised training. Refer to Figure 3 for the results. The AUC scores of the pre-trained models are higher than those of NPT starting from small training sizes, and even with more training subjects NPT does not perform equally well.
Our experiments demonstrate that with whole-MILC pre-training we are able to achieve reasonable prediction performance from complete dynamics even on small data. Importantly, we are now able to investigate which parts of the dynamics were the most discriminative (see Figure 4).
5 Conclusions and Future Work
As we have demonstrated, self-supervised pre-training of a spatio-temporal encoder gives significant improvement on downstream tasks in brain imaging datasets. Learning the dynamics of fMRI helps to improve classification results for all three diseases and speeds up the convergence of the algorithm on small datasets that otherwise do not provide reliable generalization. Although the utility of these results is highly promising by itself, we conjecture that direct application to spatio-temporal data will yield benefits beyond improved classification accuracy in future work. Working with ICA components gives a smaller and thus easier-to-handle space that still exhibits all the dynamics of the signal; in the future we will move beyond ICA pre-processing and work with fMRI data directly. We expect further model introspection to yield insight into the spatio-temporal biomarkers of schizophrenia. The model may indeed be learning crucial information about dynamics that contains important clues into the nature of mental disorders.
Acknowledgments
This study was in part supported by NIH grants 1R01AG063153 and 2R01EB006841. We would like to thank and acknowledge the open access data platforms and data sources that were used for this work, including: the Human Connectome Project (HCP), the Open Access Series of Imaging Studies (OASIS), the Autism Brain Imaging Data Exchange (ABIDE I), the Function Biomedical Informatics Research Network (FBIRN), and the Centers of Biomedical Research Excellence (COBRE).