
Prediction of Progression to Alzheimer's Disease with Deep InfoMax

Arguably, unsupervised learning plays a crucial role in the majority of algorithms for processing brain imaging. A recently introduced unsupervised approach, Deep InfoMax (DIM), is a promising tool for exploring brain structure in a flexible non-linear way. In this paper, we investigate the use of variants of DIM in a setting of progression to Alzheimer's disease, in comparison with supervised AlexNet- and ResNet-inspired convolutional neural networks. As a benchmark, we use a classification task between four groups: patients with stable and progressive mild cognitive impairment (MCI), patients with Alzheimer's disease, and healthy controls. Our dataset comprises 828 subjects from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database. Our experiments highlight encouraging evidence of the high potential utility of DIM in future neuroimaging studies.





I Introduction

According to [1], mental disorders have the highest economic impact in terms of economic growth, direct and indirect costs, and the statistical value of life. One essential tool for better understanding mental illness is to use noninvasive neuroimaging (e.g., structural magnetic resonance imaging (MRI) images) along with machine learning to learn brain structure.

Deep Learning has been integral to the successes of machine learning in numerous demanding real-world applications, e.g., state-of-the-art image classification [2] and self-driving cars [3]. While many of Deep Learning's successes involve supervised learning, supervised approaches can fail when data annotation (e.g., labels) is limited or unavailable. When there is sufficient data, supervised models can not only perform well on holdout sets but also provide representations that generalize well to other supervised settings [4]. However, when there is insufficient data, a supervised learner tends to discriminate on low-level (e.g., pixel-level, trivial) information, which hurts generalization performance. A model that generalizes well needs to extract meaningful high-level information (e.g., a collection of important features at the input level). To address this, many successful applications of machine learning to neuroscience rely on unsupervised learning [5, 6, 7, 8] to extract representations of brain imaging data. These representations are then used as input to an off-the-shelf classifier (i.e., semi-supervised learning).
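The semi-supervised pipeline described above can be sketched as follows. This is a minimal illustration, not the paper's setup: a nearest-centroid classifier stands in for the "off-the-shelf classifier", and the array shapes are hypothetical.

```python
import numpy as np

# Sketch: unsupervised encoder features (rows) -> off-the-shelf classifier.
# Nearest-centroid is used here only as the simplest possible classifier.

def fit_centroids(features, labels):
    """Fit one centroid per class from encoder features."""
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict(features, classes, centroids):
    """Assign each sample to the class with the nearest centroid."""
    dists = np.linalg.norm(features[:, None, :] - centroids[None, :, :], axis=-1)
    return classes[np.argmin(dists, axis=1)]
```

In practice the features would come from a frozen unsupervised encoder, and any standard classifier (logistic regression, a small MLP as in this paper) can replace the centroid rule.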

However, prior work on unsupervised learning of brain imaging data is either linear or weakly nonlinear [5, 6] or highly restrictive in parameterization [7], and does not represent a flexible methodology for learning representations.

In this work, we explore using DIM [9] to learn deep non-linear representations of neuroimaging data as the output of a convolutional neural network. DIM works by maximizing the mutual information between a high-level feature vector and low-level feature maps of a highly flexible convolutional network. It does so by training a second neural network that maximizes a lower bound on a divergence (a probabilistic measure of difference) between the joint distribution and the product of marginals of the encoder input and output. The estimates provided by this second network can be used to maximize the mutual information between the features of the encoder and the input. Unlike other popular unsupervised auto-encoding approaches such as the VAE [10], DIM does not require a decoder, which significantly reduces the memory requirements of the model for volumetric data.

We evaluate DIM by performing a downstream classification task between four groups: patients with stable and progressive MCI, patients with Alzheimer’s disease, and healthy controls, using only the resulting representation from DIM as input to the classifier. We compare DIM to two convolutional networks with AlexNet [11] and ResNet [12] inspired architectures trained with supervised learning. Under strict evaluation, we show performance comparable to these supervised methods and to previously reported [13, 14, 15, 16] classification results.

II Materials and Methods

II-A Deep InfoMax

Let $X$ and $Y$ be the input and output variables of a neural network encoder $E_\psi$ with parameters $\psi$, where $\mathcal{X}$ and $\mathcal{Y}$ are its domain and range. We wish to find the parameters that maximize the following objective:

\[(\hat\psi, \hat\omega) = \operatorname*{arg\,max}_{\psi, \omega} \hat{I}_\omega(X;\, E_\psi(X)), \tag{1}\]

where $\hat{I}_\omega$ is the mutual information estimate provided by a different network with parameters $\omega$, and $E_\psi(X)$ is the output of the encoder.

A parametric estimator for the mutual information can be found by training a statistics network to maximize a lower bound based on the Fenchel dual [17] or the Donsker-Varadhan representation [18, 19] of the Kullback–Leibler (KL) divergence. The Donsker-Varadhan-based estimator is consistent and asymptotically unbiased, has been shown to outperform nonparametric estimators, and can also be used to improve deep generative models [19]. However, the KL divergence is unbounded, which can be problematic if the above estimators are used for training deterministic neural network encoders. [9] showed that using an estimator based on the Jensen-Shannon divergence (JSD) (i.e., simple binary cross-entropy) is more stable and works well in practice, and it has been shown that this estimator also yields a good estimator for mutual information [9, 20]:

\[\hat{I}^{(\mathrm{JSD})}_\omega(X;\, E_\psi(X)) := \mathbb{E}_{\mathbb{P}}\big[-\mathrm{sp}(-T_\omega(x, E_\psi(x)))\big] - \mathbb{E}_{\mathbb{P} \times \tilde{\mathbb{P}}}\big[\mathrm{sp}(T_\omega(x', E_\psi(x)))\big], \tag{2}\]

where $T_\omega$ is a statistics network with parameters $\omega$, $\mathrm{sp}(z) = \log(1 + e^z)$ is the softplus function, and $x'$ is another input sampled from the data distribution independently from $x$. In addition, the Noise-Contrastive variant of the estimator (NCE) [21] was shown to work well in practice [9]:

\[\hat{I}^{(\mathrm{NCE})}_\omega(X;\, E_\psi(X)) := \mathbb{E}_{\mathbb{P}}\Big[T_\omega(x, E_\psi(x)) - \log \sum_{x' \in \tilde{X}} e^{T_\omega(x', E_\psi(x))}\Big], \tag{3}\]

where $\tilde{X}$ is a set of samples containing $x$ together with negative samples drawn from the data distribution, such that there is exactly one positive example in $\tilde{X}$ ($x$ occurs exactly once).

[9] showed that maximizing the mutual information between the complete input and output of an encoder is insufficient for learning good representations for downstream classification tasks, as this approach can still focus on lower-level “trivial” localized details. Instead, they show that maximizing the mutual information between the high-level representation and patches of an input image can achieve highly competitive results. The intuition is that this approach encourages the high-level representation to learn information that is shared across the input. It is suitable for many classification tasks, as we expect class-discriminative features to be evident across many spatial locations of the input. For a convolutional encoder $E_\psi$, the local DIM objective can be written in a compact form:

\[(\hat\psi, \hat\omega) = \operatorname*{arg\,max}_{\psi, \omega} \frac{1}{M} \sum_{i=1}^{M} \hat{I}_\omega\big(C^{(i)}_\psi(X);\, E_\psi(X)\big), \tag{4}\]

where $C^{(i)}_\psi(X)$ is the $i$-th of $M$ feature-map locations from encoder $E_\psi$ (each with a limited receptive field corresponding to an input patch) at some intermediate layer of the network.

Due to the stronger performance of the AlexNet architecture (Section II-B) in our experiments (see Section IV), we used it as the encoder for DIM. We replaced the last linear layer of AlexNet with a layer producing the output representation.

To estimate mutual information using eq. (4), we used the encode-and-dot-product architecture from [9]. First, patches taken from the third convolutional layer of AlexNet were mapped using the convolutional encode-and-dot architecture from [9], and their representation using the linear encode-and-dot architecture from [9]. Then the flattened encoded mappings of patches and representations were combined using the dot product to create real and fake samples efficiently. A real sample is the dot product between the mapping of a “local” patch and its own “global” representation, while a fake sample is the dot product between the mapping of some “local” patch and a global representation coming from an unrelated input. Eventually, we estimated the JSD-based loss of eq. (2) and the NCE loss of eq. (3) using these samples. Since NCE needs more negative samples to be competitive with JSD [9], all possible combinations between the patch and representation mappings were used in a similar way to create negative samples.
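The encode-and-dot-product pairing can be sketched in numpy as follows. The batch size, number of feature-map locations, and projection width are hypothetical; only the all-pairs dot-product construction of real and fake samples reflects the text above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: batch of B volumes, F local feature-map locations
# per volume, both "local" and "global" mappings projected to d units.
B, F, d = 4, 8, 16
local_maps = rng.normal(size=(B, F, d))   # encoded patch features C(x)
global_reps = rng.normal(size=(B, d))     # encoded representations E(x)

# scores[i, j, f]: dot product between the global representation of
# volume i and local feature f of volume j. i == j gives "real"
# (positive) pairs; i != j gives "fake" (negative) pairs, so all
# combinations are built in one einsum.
scores = np.einsum('id,jfd->ijf', global_reps, local_maps)

real = scores[np.arange(B), np.arange(B)]   # (B, F) positive scores
mask = ~np.eye(B, dtype=bool)
fake = scores[mask]                         # (B*(B-1), F) negative scores
```

These score tensors are exactly what the JSD and NCE losses of eqs. (2) and (3) consume.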

To evaluate the performance of the representation learned by DIM, we trained three additional neural networks using as input features the output of the last convolutional layer, the first fully connected layer, and the final fully connected representation layer, which we call Conv, FC, and Z, respectively. The classifiers are composed of one fully-connected hidden layer, dropout [22], batch normalization [23], and a ReLU [24] activation.

AlexNet:
    3D Conv - BN 3D - ReLU - MP 3D
    3D Conv - BN 3D - ReLU - MP 3D
    3D Conv - BN 3D - ReLU
    3D Conv - BN 3D - ReLU
    3D Conv - BN 3D - ReLU - MP 3D
    Linear - BN 1D - ReLU
    Linear - SoftMax - ArgMax

ResNet:
    3D Conv - BN 3D - ReLU - MP 3D
    Residual Layer 1
        BB0 - 2 x (3D Conv - BN 3D - ReLU)
        BB1 - 2 x (3D Conv - BN 3D - ReLU)
    Residual Layer 2
        BB0 - 3D Conv - BN 3D - ReLU
        BB0 - 3D Conv - BN 3D - ReLU
        BB0 downsample - 3D Conv - BN 3D
        BB1 - 2 x (3D Conv - BN 3D - ReLU)
    Residual Layer 3
        BB0 - 3D Conv - BN 3D - ReLU
        BB0 - 3D Conv - BN 3D - ReLU
        BB0 downsample - 3D Conv - BN 3D
        BB1 - 2 x (3D Conv - BN 3D - ReLU)
    MaxPool 3D
    Linear - BN 1D - ReLU
    Linear - SoftMax - ArgMax

TABLE I: AlexNet and ResNet architectures

II-B Supervised baselines

As baselines we considered two supervised convolutional networks, one based on a simplified AlexNet [11] architecture and the other on a ResNet [12] architecture. Both networks use convolutions and max pooling with volumetric kernels, batch normalization, ReLU activations, and two fully connected layers at the end (see Tab. I for details). In Tab. I, BN denotes batch normalization, BB a basic block, and MP max pooling. Cross-entropy loss is used as the training objective.

II-C Regularization

For small datasets, it is common to penalize the model parameters by driving most of them to zero using $\ell_1$ regularization. Formally, this penalty is defined as:

\[R(\theta) = \lambda \lVert \theta \rVert_1, \tag{5}\]

where $\theta$ is the parameter vector of the model and $\lambda$ is a coefficient. $\ell_1$ regularization imposes a sparse solution. This penalty is added to the JSD, NCE, and cross-entropy losses in the different settings, with a fixed $\lambda$ across our experiments.
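The penalty of eq. (5) amounts to one extra term added to the task loss, as in this minimal sketch (the parameter list and coefficient are illustrative, not the paper's values):

```python
import numpy as np

def l1_penalty(params, lam):
    """R(theta) = lam * sum_i |theta_i| over all parameter arrays.

    Adding this to the task loss drives many weights to exactly zero,
    yielding the sparse solution described in the text.
    """
    return lam * sum(np.abs(p).sum() for p in params)

# usage sketch: total_loss = task_loss + l1_penalty(model_params, lam)
```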

III Experiments

III-A Datasets and preprocessing

For the downstream classification task, the data was obtained from the ADNI database (see the ADNI website for up-to-date information). We use T1w MRI images of subjects from four different groups: patients with stable and progressive MCI, patients with Alzheimer’s disease, and healthy controls.

Structural MRI (sMRI) data was pre-processed to grey matter volume (modulated) maps using the SPM12 toolbox. To segment grey matter, the MRI images were spatially normalized and smoothed with a 6 mm full width at half maximum (FWHM) 3D Gaussian kernel. After quality control, two subjects were excluded, leaving a final dataset of 828 subjects.
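The 6 mm FWHM smoothing kernel relates to the Gaussian standard deviation via $\mathrm{FWHM} = 2\sqrt{2\ln 2}\,\sigma \approx 2.3548\,\sigma$; the conversion can be sketched as below. The 1.5 mm voxel size in the example is a hypothetical value, not one stated in the paper.

```python
import math

def fwhm_to_sigma(fwhm_mm, voxel_mm):
    """Convert a Gaussian kernel FWHM (mm) to standard deviation in voxels.

    FWHM = 2 * sqrt(2 * ln 2) * sigma  (~2.3548 * sigma)
    """
    sigma_mm = fwhm_mm / (2.0 * math.sqrt(2.0 * math.log(2.0)))
    return sigma_mm / voxel_mm

# 6 mm FWHM as used in the pre-processing; 1.5 mm voxels are illustrative
sigma_vox = fwhm_to_sigma(6.0, 1.5)
```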

III-B Experimental setup

III-B1 Data

The dataset was divided into cross-validation and hold-out test sets using a stratified split. The cross-validation subjects were then split into five stratified folds.

For the AlexNet and ResNet architectures, we used simple data augmentation of the training dataset to reduce overfitting to the small number of annotated samples available. Our augmentation consisted of zero padding and random cropping along all dimensions, along with randomly flipping the input along each axis. The whole brain was included in the crop.

For DIM, we did not use data augmentation, but we used zero padding to make the input size equal along all dimensions.
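The pad-crop-flip augmentation can be sketched for a 3D volume as follows. The output size, padding width, and the 0.5 flip probability are illustrative assumptions; the paper's exact values are not stated here.

```python
import numpy as np

def augment(volume, out_size, pad, rng):
    """Zero-pad, randomly crop, and randomly flip a 3D volume.

    volume: (D, H, W) array; out_size: target edge length (cubic crop
    assumed); pad: zero padding added on every side before cropping.
    """
    v = np.pad(volume, pad)  # zero padding on all axes
    # random crop of out_size voxels along each dimension
    starts = [rng.integers(0, s - out_size + 1) for s in v.shape]
    v = v[tuple(slice(s, s + out_size) for s in starts)]
    # flip each axis independently; 0.5 is an assumed probability
    for axis in range(3):
        if rng.random() < 0.5:
            v = np.flip(v, axis=axis)
    return v
```

With sufficient padding relative to the crop size, the whole brain remains inside the crop, as required above.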

III-B2 Training

The models were trained using the AMSGrad [25] optimizer, with separate learning rates for the CNN models and for DIM, dropping the incomplete last batch. The supervised architectures were trained for a fixed number of epochs; DIM was first pre-trained, after which the classifiers were trained on top of frozen features from the encoder.

III-B3 Evaluation

Since the dataset is not completely balanced, evaluation was performed using balanced accuracy [26], defined as the average of the recall obtained on each class (implementation in scikit-learn [27]).
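Balanced accuracy, as defined above, can be computed directly; this small sketch mirrors the scikit-learn definition without depending on it:

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall; insensitive to class imbalance."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    classes = np.unique(y_true)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(recalls))
```

On an imbalanced toy example such as `y_true = [0, 0, 0, 1]` with a classifier that always predicts 0, plain accuracy is 0.75 while balanced accuracy is 0.5, which is why the latter is used here.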

III-B4 Implementation and hardware

The implementation was written using the Deep Learning frameworks PyTorch [28] and Cortex [29]. The DIM code is based on the openly available DIM implementation [30]. The experiments were performed on NVIDIA GeForce Titan X Pascal and 1080 Ti GPUs with 8 CPU threads.

Fig. 1: Performance of the models

Model                  | Wilcoxon stat | p-value
AlexNet                | 6.5           |
AlexNet Aug            | 7.0           |
Sparse AlexNet Aug     | N/A           | N/A
ResNet                 | 10.0          | 0.034
ResNet Aug             | 14.0          | 0.039
Sparse ResNet Aug      | 10.0          | 0.033
JSD Conv               | 13.5          |
JSD Conv SS            | 13.5          |
Sparse JSD Conv        | 15.0          | 0.022
Sparse JSD Sparse Conv | 6.0           |
NCE Conv               | 15.0          | 0.022
NCE Conv SS            | 10.0          | 0.034
Sparse NCE Conv        | 15.0          | 0.022
Sparse NCE Sparse Conv | 15.0          | 0.022
JSD FC                 | 15.0          | 0.022
JSD FC SS              | 14.0          | 0.040
Sparse JSD FC          | 15.0          | 0.022
Sparse JSD Sparse FC   | 15.0          | 0.022
NCE FC                 | 15.0          | 0.022
NCE FC SS              | 15.0          | 0.020
Sparse NCE FC          | 15.0          | 0.022
Sparse NCE Sparse FC   | 15.0          | 0.022
JSD Z                  | 15.0          | 0.021
JSD Z SS               | 15.0          | 0.022
Sparse JSD Sparse Z    | 15.0          | 0.022
Sparse JSD Z           | 15.0          | 0.022
NCE Z                  | 15.0          | 0.022
NCE Z SS               | 14.0          | 0.040
Sparse NCE Sparse Z    | 15.0          | 0.022
Sparse NCE Z           | 15.0          | 0.021

TABLE II: Performance

IV Results

The final trained models used to evaluate performance were selected based on the best balanced accuracy, but from a checkpoint where the validation score was lower than the training score. We gave each model a burn-in period before applying this rule to deal with initial stochasticity. The model notations are as follows: Aug denotes augmentation of the training dataset; the first Sparse denotes a model trained with $\ell_1$ regularization; the second Sparse, a classifier trained on top of frozen features from encoders trained using $\ell_1$ regularization; and SS stands for training an unsupervised model with an additional supervised loss from the Z-classifier.

Table II reports the balanced accuracy rates, including mean and standard deviation values, and the gap between mean values on cross-validation and hold-out. Bold text distinguishes the best scores and the names of the best models. The last column shows the statistic and p-value for the one-sided Wilcoxon test. Bold p-values indicate failure to reject the null hypothesis. The test was performed to compare each method with the best model (Sparse AlexNet Aug) based on the five values of balanced accuracy on hold-out, with the alternative hypothesis that Sparse AlexNet Aug is better. Fig. 1 highlights the distributions of the performance.
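The one-sided Wilcoxon signed-rank comparison over the five hold-out scores can be sketched as follows. This is an exact-enumeration version written for illustration; the paper cites a library implementation, which may compute p-values differently.

```python
from itertools import product

def wilcoxon_one_sided(a, b):
    """Exact one-sided Wilcoxon signed-rank test for paired samples.

    Tests the alternative that a tends to be larger than b.
    Returns (W+, p-value), where W+ is the sum of the ranks of the
    positive differences a_i - b_i.
    """
    diffs = [x - y for x, y in zip(a, b) if x != y]  # drop zero differences
    n = len(diffs)
    # rank |differences|, averaging ranks across ties
    order = sorted(range(n), key=lambda i: abs(diffs[i]))
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        while j + 1 < n and abs(diffs[order[j + 1]]) == abs(diffs[order[i]]):
            j += 1
        avg = (i + j) / 2 + 1  # average of ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    w_plus = sum(r for r, d in zip(ranks, diffs) if d > 0)
    # exact null distribution: enumerate all 2^n sign assignments
    count = sum(
        1 for signs in product([0, 1], repeat=n)
        if sum(r for r, s in zip(ranks, signs) if s) >= w_plus
    )
    return w_plus, count / 2 ** n
```

With five paired scores and all differences in one direction, W+ reaches its maximum of 15, matching the test statistic reported for most models in Table II.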

With all modifications, ResNet shows lower performance on hold-out than AlexNet. This is reasonable since the capacity of the ResNet architecture is larger and the dataset is small. The Wilcoxon test also rejects H0 for ResNet, supporting its worse performance. The performance of JSD Conv, JSD Conv SS, Sparse JSD Sparse Conv, AlexNet, and AlexNet Aug is statistically indistinguishable from that of Sparse AlexNet Aug. It follows that unsupervised DIM has performance comparable to supervised methods.

Among DIM variants, JSD has higher scores than NCE. Lower scores of NCE can be explained by its requirement of a large number of negative samples during training to be competitive with JSD. Our dataset is not large enough to support the needed level of negative sampling.

The best score with convolutional features was obtained by an encoder and classifier trained with $\ell_1$ regularization, the Sparse JSD Sparse Conv model. For features from the fully-connected layer, the JSD FC SS model using the semi-supervised loss was the best; however, Sparse JSD Sparse FC has similar results and a smaller mean gap, at the cost of a lower mean cross-validation score. For the smallest representation Z, the semi-supervised model JSD Z SS gives the best performance, with similar results obtained by the Sparse NCE Z model. The semi-supervised loss and $\ell_1$ regularization improved the models’ generalization by reducing the gap between cross-validation and hold-out scores. The observed degradation in performance from Conv to FC to Z can be explained by the reduced capacity of the features; the $\ell_1$ regularization and dropout could also be adjusted. However, a more compact input representation can be of independent use, for example, for dimensionality reduction.

In previous studies, the best accuracy for the ResNet architecture in a 4-class sMRI classification task was reported in [13], while stacked autoencoders (SAE) [15] reported results for sMRI only and for sMRI+PET, and DW-SMTL [16] for sMRI and for sMRI+PET+CSF. Our values are not directly comparable since the evaluation protocols differ. The reproduced ResNet can be used as a proxy to estimate performance relative to this prior work; note, however, that it is not one of the best-performing methods in our study.

V Conclusions

This work explores the unsupervised method DIM for learning representations from structural neuroimaging data. The evaluation on the prediction of progression to Alzheimer’s disease demonstrates results comparable to supervised methods. In the future, we will scale up our experiments with increased sample size and address other diseases. Our future efforts will also focus on the multi-modal fusion of brain imaging data [31] to increase the predictive strength of the model.


This study is supported by NIH grants R01EB020407, R01EB006841, P20GM103472, P30GM122734.

Data collection and sharing for this project was funded by the Alzheimer’s Disease Neuroimaging Initiative (ADNI) (National Institutes of Health Grant U01 AG024904) and DOD ADNI (Department of Defense award number W81XWH-12-2-0012). ADNI is funded by the National Institute on Aging, the National Institute of Biomedical Imaging and Bioengineering, and through generous contributions from the following: AbbVie, Alzheimer’s Association; Alzheimer’s Drug Discovery Foundation; Araclon Biotech; BioClinica, Inc.; Biogen; Bristol-Myers Squibb Company; CereSpir, Inc.; Cogstate; Eisai Inc.; Elan Pharmaceuticals, Inc.; Eli Lilly and Company; EuroImmun; F. Hoffmann-La Roche Ltd and its affiliated company Genentech, Inc.; Fujirebio; GE Healthcare; IXICO Ltd.; Janssen Alzheimer Immunotherapy Research & Development, LLC.; Johnson & Johnson Pharmaceutical Research & Development LLC.; Lumosity; Lundbeck; Merck & Co., Inc.; Meso Scale Diagnostics, LLC.; NeuroRx Research; Neurotrack Technologies; Novartis Pharmaceuticals Corporation; Pfizer Inc.; Piramal Imaging; Servier; Takeda Pharmaceutical Company; and Transition Therapeutics. The Canadian Institutes of Health Research is providing funds to support ADNI clinical sites in Canada. Private sector contributions are facilitated by the Foundation for the National Institutes of Health. The grantee organization is the Northern California Institute for Research and Education, and the study is coordinated by the Alzheimer’s Therapeutic Research Institute at the University of Southern California. ADNI data are disseminated by the Laboratory for Neuro Imaging at the University of Southern California.