On self-supervised multi-modal representation learning: An application to Alzheimer's disease

by Alex Fedorov, et al.

Introspection of deep supervised predictive models trained on functional and structural brain imaging may uncover novel markers of Alzheimer's disease (AD). However, supervised training is prone to learning from spurious features (shortcut learning) impairing its value in the discovery process. Deep unsupervised and, recently, contrastive self-supervised approaches, not biased to classification, are better candidates for the task. Their multimodal options specifically offer additional regularization via modality interactions. In this paper, we introduce a way to exhaustively consider multimodal architectures for contrastive self-supervised fusion of fMRI and MRI of AD patients and controls. We show that this multimodal fusion results in representations that improve the results of the downstream classification for both modalities. We investigate the fused self-supervised features projected into the brain space and introduce a numerically stable way to do so.



Code Repositories


Fusion is a self-supervised framework for data with multiple sources — specifically, this framework aims to support neuroimaging applications.


1 Introduction

Diagnosing pathologies from raw medical imaging outputs is often a more complex problem than the idealized problems faced in non-medical image classification. Single modalities often do not contain enough information, requiring multi-modal fusion of multiple distinct sources of data (as is commonly the case for MRI). Moreover, multiple data sources contain a wealth of complementary information but not enough redundancy to easily align them. As a result, correctly utilizing the different sources can be key to designing robust diagnostic tools.

Because supervised approaches might be prone to shortcut learning [1] and require more data, we explore unsupervised methods [2]. Some common approaches to learning with multiple sources are inspired by Deep CCA [3], parallel ICA [4], and, more recently, by variational approaches such as MMVAE [5]. A new family of methods is emerging that relies on deep learning to achieve powerful representations using a self-supervised approach. Specifically, these algorithms are based on the maximization of mutual information, as in Deep InfoMax (DIM) [6]. However, the ML field is breeding a zoo of seemingly unique methods exploiting this approach [7, 8, 9, 10, 11], although they can be unified under one paradigm. We investigate all of the existing methods and generate some yet unpublished ones under the same framework, and apply them to AD and healthy control (HC) data with functional (f)MRI and structural (s)MRI modalities. The proposed approach shows great promise empirically.

Our contributions are as follows:

  • We compare and contrast all approaches by their effect on the downstream classification task.

  • We also show that representation similarity of the learned AD embeddings does not necessarily lead to better classification yet allows us to uncover links between modalities.

  • We report an improved and numerically stable method of investigating the resulting multimodal features via model introspection, expressed as statistical test contrasts.

  • We performed all experiments on a large AD/MCI dataset.

2 Methods

2.1 Problem definition

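Let $D = \{(x_1^{(n)}, x_2^{(n)})\}_{n=1}^{N}$ be the dataset of $N$ paired images of different modalities: $x_1$ is T1 and $x_2$ is fALFF (see Section 3.1 for details). We want to learn compressed and semantically meaningful latent representations $z_m$ of each modality $m \in \{1, 2\}$. The latent representation $z_m$ is a $d$-dimensional vector encoding the image $x_m$ through an encoder $E_m$ parametrized by a neural network with parameters $\theta_m$, as $z_m = E_{\theta_m}(x_m)$.

To learn the set of parameters $\theta = \{\theta_1, \theta_2\}$ we want to optimize the objective $\mathcal{L}$ defined as:

$$\mathcal{L}(\theta) = \sum_{i=1}^{2} \sum_{j=1}^{2} \ell_{ij}(x_i, x_j),$$

where $\ell_{ij}$ is a uni-modal ($i = j$) or multi-modal ($i \neq j$) objective. In this work we specifically explore the decoder-free objectives based on maximization of mutual information.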

Figure 1: The sample pair of T1 and fALFF from OASIS-3, and a scheme of the method. The arrows represent objectives.

2.2 Mutual Information Maximization

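Here we utilize the familiar InfoNCE [12] based estimator for a lower bound of mutual information:

$$I(X; Y) \geq \frac{1}{N} \sum_{i=1}^{N} \log \frac{e^{f(u_i, v_i)}}{\frac{1}{N} \sum_{j=1}^{N} e^{f(u_i, v_j)}},$$

where $f(u, v) = u^\top v / \sqrt{d}$ is a separable critic function and $d$ is the dimension of the latent representation. The embeddings $u$ and $v$ are computed using additional projections $\phi$ or $\psi$, parametrized by neural networks, applied to the latent representation $z$ or to the features $c_l$ from the $l$-th layer.

The idea behind this estimator is to learn representations such that $f(u_i, v_i) \gg f(u_i, v_j)$ for $i \neq j$.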
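The estimator described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' implementation: the scaled dot-product critic, the function names, and the toy data are all assumptions made for the example.

```python
import numpy as np

def separable_critic(u, v):
    """Scaled dot-product scores between every u_i and every v_j (assumed critic)."""
    d = u.shape[1]
    return u @ v.T / np.sqrt(d)

def infonce_lower_bound(u, v):
    """InfoNCE estimate (in nats) for N paired embeddings of shape (N, d)."""
    scores = separable_critic(u, v)  # (N, N); the diagonal holds the positive pairs
    row_max = scores.max(axis=1, keepdims=True)
    # Numerically stable log of the mean of exp(scores) over each row.
    log_sum_exp = row_max[:, 0] + np.log(np.exp(scores - row_max).sum(axis=1))
    log_mean_exp = log_sum_exp - np.log(scores.shape[0])
    return float(np.mean(np.diag(scores) - log_mean_exp))

# Toy data: aligned pairs should score higher than shuffled pairs.
rng = np.random.default_rng(0)
u = rng.normal(size=(64, 32))
aligned = infonce_lower_bound(u, 4.0 * u)
shuffled = infonce_lower_bound(u, 4.0 * u[rng.permutation(64)])
```

The estimate can never exceed log N for a batch of N pairs, which is one reason larger batches are often preferred with this family of objectives.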

2.3 Constructions of the objectives

Using these definitions we can construct different ways of maximizing mutual information with multi-source data, which are shown schematically in Figure 2. The edges represent pairs of features used in the objective. Writing $c_i$ for the convolutional features and $z_i$ for the latent representation of modality $i$, with $i \neq j$: L is known as local Deep InfoMax (DIM) [6], where we train on the pair $(c_i, z_i)$. The methods CL and CS are used by AMDIM [7], ST-DIM [8] and CM-DIM [9], where the pairs are $(c_i, z_j)$ and $(c_i, c_j)$, respectively. The method S was originally shown by CMC [10] and then perfected by SimCLR [11], where the pair is $(z_i, z_j)$. Each variant implies its own inductive bias on predictability between embeddings: L and CL follow the InfoMax principle, while the objectives of type CS and S maximize the similarity between convolutional features and between latent variables on the same level, respectively.
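One way to read the four edge types is as a small lookup table from objective name to the feature pairs it contrasts. The sketch below is an interpretation, not the paper's code: the pairing of convolutional ("conv") and latent features per edge type, the two directions for L and CL, and all names are assumptions chosen so that the total matches the 6 connections mentioned in the Figure 2 caption.

```python
MODALITIES = ("T1", "fALFF")
A, B = MODALITIES

# Each edge is a pair (source feature, target feature); a feature is (level, modality).
EDGE_TYPES = {
    "L": [(("conv", A), ("latent", A)), (("conv", B), ("latent", B))],
    "CL": [(("conv", A), ("latent", B)), (("conv", B), ("latent", A))],
    "CS": [(("conv", A), ("conv", B))],
    "S": [(("latent", A), ("latent", B))],
}

def edges_for(variant):
    """Expand a name such as "L-CL" into the feature pairs it contrasts."""
    return [edge for name in variant.split("-") for edge in EDGE_TYPES[name]]

def total_objective(variant, pair_loss):
    """Sum a pairwise loss (e.g. a contrastive bound) over every edge of the variant."""
    return sum(pair_loss(src, dst) for src, dst in edges_for(variant))
```

For example, `edges_for("L-CL")` returns four edges: two within-modality and two cross-modality pairs. Variants that mix in non-contrastive terms, such as S-AE or L-CCA, would need an extra loss term that this table does not cover.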

For completeness, we compare DIM-based methods to the related DCCAE [13] and to MMVAE with the looser IWAE estimator [5]. We also combine the CCA objective with the L objective (L-CCA) and AE with the Similarity objective (S-AE) to create new combinations. However, for DCCAE we do not pre-train the encoder layer-wise before fine-tuning to a CCA objective, as was done in the original work; our focus here is on pure end-to-end approaches. Additionally, we train a supervised model on the OASIS dataset to get an approximate bound on what is achievable, against which all architectures can be compared. The schemes for the baselines and additional combinations are shown in Figure 2.

3 Experimental setup

Figure 2: Hold-out downstream performance on the OASIS dataset with Logistic Regression trained on the representations, and cross-modal CKA similarity of the latent representation in different groups. We explore combinations of the 6 possible connections and more. The label letters are L: local, CL: cross-local, CS: cross-spatial, S: similarity. See main text for details.

3.1 Dataset

We chose the neuroimaging dataset OASIS-3 [14] to study Alzheimer’s disease. As modalities, we selected T1 and fractional amplitude of low-frequency fluctuation (fALFF), prepared from T1w images and computed from resting-state fMRI, respectively.

Resting-state fMRI time series were registered to the first image in the series using mcflirt in FSL [15] (v 6.0.2), using a 3-stage search level (8mm, 4mm, 4mm), 20mm field-of-view, 256 histogram bins (for matching), 6 degrees-of-freedom (dof) for transformation, a scaling factor of 6mm, and normalized correlation values across images as the cost function (smoothed to 1mm). Final transformations and outputs were interpolated using splines as opposed to the default trilinear interpolation. A fractional amplitude of low-frequency fluctuation (fALFF) map was then computed in the 0.01 to 0.1 Hz power band using REST [16]. After a visual check, 15 T1w images were removed. T1w images were brain-masked using bet from FSL. Afterward, both fALFF and T1 images were linearly (7 dof) converted to MNI space and resampled to 3mm resolution. Both modalities have the same final volume size. The data preprocessing is intentionally kept minimal to reduce its impact on training with deep neural networks, and the images are transformed to simplify analysis.
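For readers unfamiliar with fALFF, the quantity can be illustrated for a single voxel as the ratio of spectral amplitude inside the 0.01 to 0.1 Hz band to the amplitude over the whole measurable frequency range. The sketch below follows that common definition and is not the REST toolbox implementation; the repetition time `tr` and the synthetic signals are hypothetical.

```python
import numpy as np

def falff(ts, tr, low=0.01, high=0.1):
    """Fractional ALFF of one voxel time series sampled every `tr` seconds."""
    ts = np.asarray(ts, dtype=float)
    ts = ts - ts.mean()                      # remove the constant offset
    amp = np.abs(np.fft.rfft(ts))            # amplitude spectrum
    freqs = np.fft.rfftfreq(ts.size, d=tr)   # frequencies in Hz
    amp, freqs = amp[1:], freqs[1:]          # drop the zero-frequency bin
    band = (freqs >= low) & (freqs <= high)
    total = amp.sum()
    return float(amp[band].sum() / total) if total > 0 else 0.0

# Synthetic check with a hypothetical TR of 2 s and 200 volumes.
tr = 2.0
t = np.arange(200) * tr
slow = np.sin(2 * np.pi * 0.05 * t)  # inside the 0.01-0.1 Hz band
fast = np.sin(2 * np.pi * 0.20 * t)  # outside the band
falff_slow = falff(slow, tr)
falff_fast = falff(fast, tr)
```

A whole-brain map applies the same computation to every voxel of the registered time series.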

After an analysis of demographic data, we retain only Non-Hispanic Caucasians (totaling 826 subjects) since other groups are underrepresented. For the Alzheimer’s disease (AD) group, we choose all subjects with confirmed AD records; the healthy cohort (HC) consists of cognitively normal subjects. Subjects with other conditions are used as an additional group only in unsupervised pretraining. For pretraining we combine, for each individual subject, all possible pairs of multi-modal images that are closest by days (4021 pairs). During the final evaluation, we keep only 1 pair for each subject.

We split subjects into 5 stratified (70% healthy, 15% AD, 15% other) cross-validation folds (580-582 subjects (2828-2944 pairs), 144-146 subjects (653-769 pairs)) and a hold-out set (100 subjects (424 pairs)). Then we apply histogram standardization, based on each training subset, and z-normalization to the images using the TorchIO library [17]. For pretraining on OASIS-3 we use random flips and random crops as data augmentation. During optimization, we also utilize a class-balanced data sampler [18].

3.2 Architecture, Hyperparameters, and Optimization

In our experiments, we use DCGAN [19]. DCGAN is a convolutional architecture with encoder and decoder. The last layer maps input features to a -dimensional latent representation. The convolutional projection heads consist of 2-convolutional layers with kernel size , input equal to a number of features of the selected layer with feature side size in the encoder, and output — to . The latent projection heads are chosen to be identity. All the weights of projections are shared across all contrastive objectives.

We penalize each contrastive bound with squared matching scores of the critic with (except experiment with OASIS) and clip values of the critic by with . The projections are shared across different objectives. Thus, the optimization of the objective can be considered as multi-task learning.

To train the weights of the neural networks we used RAdam [20] with learning rate and OneCycleLR [21] scheduler with maximum learning rate for epochs with batch size . However, the model MMVAE we could train only with batch size due to memory constraints and 3 folds out of 5 did not converge.

4 Results

Downstream task

To evaluate the representation on a downstream task we trained Logistic Regression on top of the representation produced by the pre-trained encoder. To choose the hyperparameters of Logistic Regression we searched the space using Optuna [22] with cross-validation, using the mean ROC AUC as the score. The inverse regularization strength is sampled log-uniformly, the penalty is chosen from L1, L2, or elastic net, and the elastic net mixing parameter is sampled uniformly from the unit interval. The solver is saga.
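The shape of this search space can be illustrated with plain random sampling from the standard library. This is only a stand-in for the Optuna search used in the paper: the bounds `c_low` and `c_high` are hypothetical, and the dictionary keys simply mirror the scikit-learn `LogisticRegression` argument names.

```python
import math
import random

def sample_logreg_params(rng, c_low=1e-4, c_high=1e4):
    """Draw one configuration; c_low and c_high are hypothetical bounds."""
    # Log-uniform sampling: uniform in log space, then exponentiate.
    log_c = rng.uniform(math.log(c_low), math.log(c_high))
    penalty = rng.choice(["l1", "l2", "elasticnet"])
    params = {"C": math.exp(log_c), "penalty": penalty, "solver": "saga"}
    if penalty == "elasticnet":
        # The mixing parameter is only meaningful for the elastic net penalty.
        params["l1_ratio"] = rng.uniform(0.0, 1.0)
    return params

rng = random.Random(0)
trials = [sample_logreg_params(rng) for _ in range(200)]
```

Each sampled configuration would then be scored by the mean cross-validated ROC AUC, and the best one kept.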

The results are shown in Figure 2. The models are sorted by the average AUC across modalities. Overall, most combinations of contrastive objectives outperform the CCA-based DCCAE and the variational MMVAE. The best unsupervised results for T1 are obtained by the unimodal AE and the multimodal S, and for fALFF by the multimodal L-CL. Comparing fALFF results for the S and AE methods, we notice that the performance is lower, thus similarity might degrade the performance. Interestingly, while for T1 the supervised model is still the leader, the unsupervised method L-CL surpasses it for fALFF. We argue that the multi-modal objective has a regularizing effect. Additionally, the method S-AE might be a good candidate for future analysis, as it combines reconstruction error and maximization of mutual information from two perspectives while preserving high downstream performance and higher similarity of the representation (as we show in the next subsection using similarity analysis).

Representational Insights using Similarity

To better understand the influence of different multi-source objectives on the latent representation we employ SVCCA [23] and CKA [24]. We compute the similarity of the representation between modalities in different groups. The metrics are shown in Figure 2.

Per the SVCCA metric, models behave similarly on AD patients, while the models AE, S-AE, DCCAE, S, and S-CS differ noticeably from the other models on HC, indicating higher correlation. The AD patients also have a higher correlation between modalities compared to HC subjects. This might indicate that healthy subjects have a richer representation within each modality and thus fewer similarities between modalities. Given the performance on the downstream task, we can conclude that unsupervised learning can capture group differences even without prior knowledge about them.

The CKA metric shows significant differences in representation between models, even though most of the methods are very close in their predictive performance. This empirically supports the hypothesis that similarity does not guarantee higher downstream performance or can have a regularizing effect.
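To make the cross-modal similarity measurement concrete, the linear variant of CKA can be computed directly from two representation matrices with one row per subject. The sketch below assumes the linear kernel; whether the paper used the linear or a kernelized variant is not stated in this text, and the random matrices are placeholders for real embeddings.

```python
import numpy as np

def linear_cka(x, y):
    """Linear CKA between representations x (n, d1) and y (n, d2) of the same n subjects."""
    x = x - x.mean(axis=0, keepdims=True)  # center each feature
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, "fro")
    norm_y = np.linalg.norm(y.T @ y, "fro")
    return float(cross / (norm_x * norm_y))

rng = np.random.default_rng(0)
z_t1 = rng.normal(size=(100, 16))              # placeholder T1 embeddings
z_falff = rng.normal(size=(100, 16))           # placeholder, unrelated fALFF embeddings
q, _ = np.linalg.qr(rng.normal(size=(16, 16))) # random orthogonal rotation
cka_rotated = linear_cka(z_t1, z_t1 @ q)
cka_unrelated = linear_cka(z_t1, z_falff)
```

Linear CKA is invariant to rotations of the feature space, so the rotated copy scores 1 while the unrelated embeddings score much lower. This makes it suitable for comparing latent spaces whose axes are not aligned.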

Representational Insights using Saliency

Figure 3: Left column: the highest correlated (0.34) pair of mean saliency images for the Supervised method. Right column: the highest correlated (0.85) pair of mean saliency images for the S-AE method.

Figure 4: Group differences on T1 and fALFF are shown using the effect size RBC. The left column is for dimensions with the highest positive beta in Logistic Regression, the right column for those with the highest negative beta. Odd rows are for the Supervised model, even rows for S-AE. The coordinates are chosen by the absolute peak value of the effect size.

To gain an additional understanding of how representations behave in uni-source and multi-source, supervised and unsupervised settings with respect to brain and groups we utilize sensitivity analysis based on SmoothGrad [25] with std and 5 iterations. We compute gradients for each dimension of the latent representation, instead of computing them based on a label. After computing sensitivity maps we apply brain masking, rescale gradient values to a unit-interval, smoothing them with Gaussian filter ().
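The SmoothGrad step for one latent dimension can be sketched as averaging gradients over noisy copies of the input and then rescaling to the unit interval. The toy "encoder dimension" below, its analytic gradient, and the noise level `std=0.1` are assumptions for illustration; brain masking and Gaussian smoothing are omitted.

```python
import numpy as np

def smoothgrad(grad_fn, x, std, n_samples, rng):
    """Average the gradient of one latent dimension over noisy copies of x."""
    total = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        noisy = x + rng.normal(scale=std, size=x.shape)
        total += grad_fn(noisy)
    return total / n_samples

def rescale_unit(a):
    """Min-max rescale an array to the unit interval."""
    lo, hi = a.min(), a.max()
    return (a - lo) / (hi - lo) if hi > lo else np.zeros_like(a)

# Toy latent dimension z = tanh(mean(w * x)) with its analytic gradient w.r.t. x.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4, 4))
x = rng.normal(size=(4, 4, 4))
grad_fn = lambda inp: (1.0 - np.tanh(np.mean(w * inp)) ** 2) * w / w.size
saliency = rescale_unit(smoothgrad(grad_fn, x, std=0.1, n_samples=5, rng=rng))
```

With a real encoder, `grad_fn` would be replaced by automatic differentiation of the chosen latent dimension with respect to the input volume.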

Using the computed saliencies, we study how dimensions of the latent representation correspond to the input image in unimodal and multimodal scenarios. Given one dimension in T1 and another in fALFF, we compute correlations across subjects for each such pair of dimensions. Then we select the highest correlated pairs (0.85 for S-AE and 0.34 for Supervised) and show the thresholded mean saliency on each modality in Figure 3. The unsupervised multimodal method S-AE shows saliencies that are highly spatially related between modalities and also related to those of the supervised method, which may explain why this method has comparable performance. The supervised method, in contrast, shows no relation between modalities.

Using Logistic Regression we selected the dimensions with the largest positive and negative beta values. Then we study the group differences using a voxel-wise Mann-Whitney U test on the computed saliencies and report the rank-biserial correlation (RBC) as the effect size. The results for the S-AE and Supervised models are shown in Figure 4. The main finding is that the Supervised method “looks” at the ventricles (on both modalities) and the inferior frontal gyrus (on T1). While the supervised method captures the ventricles, this trivial marker does not benefit our understanding of brain degeneration. Unsupervised methods learn a more general representation. For example, S-AE captures non-trivial locations in the brain, which might be interesting and need to be analyzed much more closely.
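The effect size reported here can be illustrated with a direct pairwise computation of the Mann-Whitney U statistic and the rank-biserial correlation derived from it. This is a sketch under assumptions: the sign convention (positive when the first group tends to be larger), the function names, and the tiny example groups are illustrative, and p-values are left to a statistics library such as `scipy.stats.mannwhitneyu`.

```python
import numpy as np

def mann_whitney_u(a, b):
    """U statistic of group a versus b: pairs with a > b count 1, ties count 0.5."""
    a = np.asarray(a, dtype=float)[:, None]
    b = np.asarray(b, dtype=float)[None, :]
    return float((a > b).sum() + 0.5 * (a == b).sum())

def rank_biserial(a, b):
    """Rank-biserial correlation in [-1, 1]; positive when a tends to exceed b."""
    return 2.0 * mann_whitney_u(a, b) / (len(a) * len(b)) - 1.0

def voxelwise_rbc(group_a, group_b):
    """Effect size per voxel for saliency arrays of shape (subjects, voxels)."""
    group_a, group_b = np.asarray(group_a), np.asarray(group_b)
    return np.array([rank_biserial(group_a[:, v], group_b[:, v])
                     for v in range(group_a.shape[1])])

# Two voxels: the first separates the groups perfectly, the second not at all.
ad = [[3, 1], [4, 2], [5, 1]]
hc = [[0, 1], [1, 2], [2, 1]]
rbc_map = voxelwise_rbc(ad, hc)
```

The pairwise formulation costs O(n1 × n2) per voxel, which is acceptable for a sketch but would normally be replaced by a rank-based implementation for whole-brain maps.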

5 Conclusions

We investigated previous approaches and introduced new ones for multi-modal representation learning using advances in self-supervised learning. Applying our approach to the OASIS dataset, we evaluated the learned representations with multiple tools and obtained strong empirical insights for further development in data fusion. Our findings indicate the high potential of DIM-based methods for addressing the shortcut learning problem.

6 Compliance with Ethical Standards

This research study was conducted retrospectively using human subject data made available in open access by OASIS-3 [14]. Ethical approval was not required as confirmed by the license attached with the open-access data.

7 Acknowledgments

This study is supported by NIH R01 EB006841.

Data were provided in part by OASIS-3: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly-owned subsidiary of Eli Lilly.