Taxonomy of multimodal self-supervised representation learning

12/25/2020 ∙ by Alex Fedorov, et al.

Sensory input from multiple sources is crucial for robust and coherent human perception. Different sources contribute complementary explanatory factors and are combined based on the factors they share. This setup has motivated the design of powerful unsupervised representation-learning algorithms. In this paper, we unify recent work on multimodal self-supervised learning under a single framework. Observing that most self-supervised methods optimize similarity metrics between a set of model components, we propose a taxonomy of all reasonable ways to organize this process. We show empirically, on two versions of multimodal MNIST and a multimodal brain imaging dataset, that (1) multimodal contrastive learning has significant benefits over its unimodal counterpart, (2) the specific composition of multiple contrastive objectives is critical to downstream performance, and (3) maximization of the similarity between representations has a regularizing effect on a neural network, which can sometimes reduce downstream performance but still reveal multimodal relations. Consequently, under the linear evaluation protocol we outperform previous unsupervised encoder-decoder methods, such as those based on CCA or the variational mixture-of-experts MMVAE, on various datasets.


1 Introduction

The idealized tasks on which machine learning models are benchmarked commonly involve a single data source and readily available labels. By contrast, real-life data is often composed of multiple sources: different MRI modalities [6] in medical imaging, LiDAR and video for self-driving cars [35], or data influenced by confounders [25]. In addition, labels can be poor or scarce, which creates the need for unsupervised or at least semi-supervised learning. In this work, we address both constraints simultaneously by analyzing unsupervised learning approaches to multi-source data and contrasting them with more classical methods. The unsupervised methods we consider rely on contrastive learning.

Classical approaches to multi-source data include canonical correlation analysis (CCA) [11], which finds maximally correlated linear projections of two data sources. More recently, CCA has been extended to allow for representations obtained from neural networks in works such as DCCA [3] and DCCAE [34].

In addition to this family of methods, there are generative variational approaches. Specifically, a variational multi-modal mixture of experts (MMVAE) [26] has resulted in large performance improvements.

Figure 1: Sample images from Two-View MNIST, MNIST-SVHN, and OASIS-3, and a general scheme of contrastive self-supervised methods. The arrows represent all possible combinations of coordinated connections between location vectors (taken along convolutional channels) or whole latent representations. The red arrow is Local (L), the pink is Cross-Spatial (CS), the burgundy is Cross-Local (CL), and the purple is Cross-Latent (S).

Contrastive objectives have in recent years become essential components of a large number of unsupervised learning methods. Mutual information estimation [5] has inspired a number of successful applications to single-view (DIM [9], CPC [17]) and multi-view (AMDIM [4], CMC [31], SimCLR [7]) image classification, reinforcement learning (ST-DIM [2]), and zero-shot learning (CM-DIM [29, 30]). Such methods have resulted in large improvements to representation learning by considering different views of a single input. These and other related methods mostly operate in an unsupervised fashion, where the goal is to encourage similarity between transformed representations of a single input. These objectives can also be readily applied to the present context, whereby different sources can be understood as different views of the same data point. In addition to the relative paucity of literature on this topic, we also note that there are no studies that consider explicit combinations of these objectives. This work aims to address both issues. Our contributions are as follows:

  • We show empirically that multimodal contrastive learning has significant benefits over its unimodal counterpart.

  • By analyzing the effect of adding different contrastive objectives, we show that correctly combining such objectives has a critical effect on performance.

  • Even in cases where the similarity metric has a detrimental effect, we propose a means by which it can be used for similarity analysis.

2 Methods

2.1 Problem setting

Let $\{D_i\}_{i=1}^{n}$ be a set of $n$ datasets with $N$ samples each. For the $i$-th dataset we define a sampled image $x_i$, a CNN encoder $E_i$, a convolutional feature map $c_i$ taken from a fixed layer of the encoder, and a latent representation $z_i = E_i(x_i)$. In the case of AE-based approaches we also need a reconstructed version of the sample, $\hat{x}_i$.

To learn the set of encoders we want to maximize an objective of the form

$$\mathcal{L} = \sum_{i=1}^{n} \sum_{j=1}^{n} \alpha_{i,j}\, \mathcal{L}_{i,j},$$

where $\mathcal{L}_{i,j}$ is a loss and $\alpha_{i,j}$ is a weight coefficient between datasets $D_i$ and $D_j$. There are multiple loss functions one can choose from; in this study, we specifically explore self-supervised contrastive objectives based on the maximization of mutual information as the choice for $\mathcal{L}_{i,j}$.
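As an illustrative sketch (not the authors' code), the weighted combination of pairwise losses could be assembled as follows; the names `encoders`, `pairwise_loss`, and `alpha` are assumptions made for this example.

```python
def total_objective(images, encoders, pairwise_loss, alpha):
    """Illustrative sketch: sum a pairwise loss L_ij over all source
    pairs (i, j), weighted by alpha[i][j].

    images: list of batches, one per source x_i
    encoders: list of CNN encoders E_i returning (conv feature c_i, latent z_i)
    pairwise_loss: callable L_ij((c_i, z_i), (c_j, z_j)) -> scalar
    alpha: matrix of weight coefficients between sources i and j
    """
    feats = [enc(x) for enc, x in zip(encoders, images)]
    loss = 0.0
    for i, fi in enumerate(feats):
        for j, fj in enumerate(feats):
            if alpha[i][j] != 0:
                loss = loss + alpha[i][j] * pairwise_loss(fi, fj)
    return loss  # the paper maximizes this; in practice, minimize its negative
```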

Figure 2: Possible variants of connected pairs used to maximize mutual information. L is a local unimodal objective. CL and CS are cross-local and cross-spatial multi-source objectives. S is a cross-latent multi-source objective. The remaining variants are combinations of two, three, or four of these basic objectives.
Figure 3: Scheme for Supervised, AE, DCCAE, L-CCA, S-AE and MMVAE.

2.2 Contrastive mutual information maximization

While there are other choices, such as the JSD and NWJ estimators [32], we utilize the most common estimator, InfoNCE [17], to maximize a lower bound of the mutual information, which we define as

$$\mathcal{L}_{i,j} = \mathbb{E}\left[\log \frac{e^{f(u,\, v^{+})}}{\sum_{v \in \{v^{+}\} \cup \mathcal{N}} e^{f(u,\, v)}}\right],$$

where $f$ is a critic function, $u$ and $v$ are embeddings, $v^{+}$ forms a positive pair with $u$, and $\mathcal{N}$ is a set of negative samples. The embeddings are obtained through additional projections of a corresponding location of the convolutional feature $c_i$ or of the latent $z_i$ (also known as projection heads [7]), parametrized by separate neural networks.

To describe the critic function, we define positive and negative pairs. A pair $(u, v)$ is called positive if it is sampled from the joint distribution $p(u, v)$, and negative if it is sampled from the product of marginals $p(u)\,p(v)$. For example, a single entity can be represented differently in datasets $D_i$ and $D_j$. More specifically, the digit "1" can be represented as an image in multiple domains: as a handwritten digit in MNIST and as a house number in SVHN. A pair constructed by sampling the digit "1" from both MNIST and SVHN is then positive, and it is negative if we instead choose a different digit from one of the datasets.

The idea behind the critic function is to assign higher values to positive pairs and lower values to negative pairs. Our choice of critic for this study is a separable critic, as in the AMDIM [4] implementation (other possible choices include bilinear and concatenated critics [32]). Such a critic is equivalent to the scaled dot-product used in transformers [33].
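A minimal sketch of InfoNCE with a separable (scaled dot-product) critic is given below; the projection dimensions and module names are illustrative assumptions, not the AMDIM implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeparableCritic(nn.Module):
    """Sketch of a separable critic f(u, v) = <g(u), h(v)> / sqrt(d),
    i.e. a scaled dot-product over separately projected embeddings."""
    def __init__(self, dim_u, dim_v, dim_emb):
        super().__init__()
        self.g = nn.Linear(dim_u, dim_emb)
        self.h = nn.Linear(dim_v, dim_emb)
        self.scale = dim_emb ** 0.5

    def forward(self, u, v):
        # u: [B, dim_u], v: [B, dim_v] -> score matrix [B, B]
        return self.g(u) @ self.h(v).t() / self.scale

def infonce(scores):
    """InfoNCE lower bound: positives are the matched pairs (the diagonal),
    negatives are all mismatched pairs within the batch."""
    targets = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, targets)

# usage sketch: loss = infonce(critic(z_mnist, z_svhn))
```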

Figure 4: Test downstream performance with linear evaluation on Two-View MNIST.

2.3 Objectives

All recent contrastive self-supervised methods incorporate in some way the estimator of mutual information, which we show schematically in Figure 1. The origin of all these methods is the idea of Local DIM [10], which we denote L. This method maximizes the mutual information between a location of the convolutional feature map $c_i$ (a vector taken along the channel dimension) and the latent $z_i$. Further, AMDIM [4], ST-DIM [2] and CM-DIM [29] incorporate Cross-Local (CL) and Cross-Spatial (CS) objectives. In the CL setting, the latent $z_i$ is paired with locations of the convolutional feature $c_j$ of another source, and in the CS setting, locations in $c_i$ are paired with locations in $c_j$. The last connection was introduced by CMC [31] and SimCLR [7]: these two methods connect $z_i$ and $z_j$, which we call cross-latent similarity (S). The first two objectives, L and CL, follow the Deep InfoMax principle, while S and CS induce a similarity that is closely connected to the CCA idea.
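The following sketch (with illustrative tensor shapes, not the authors' code) enumerates which embeddings each objective connects; projection heads and the critic are applied afterwards.

```python
def objective_pairs(c_i, z_i, c_j, z_j, mode):
    """Enumerate which embeddings each objective connects.
    c_*: convolutional feature maps [B, C, H, W]; z_*: latents [B, D]."""
    locs_i = c_i.flatten(2).permute(0, 2, 1)   # [B, H*W, C] location vectors
    locs_j = c_j.flatten(2).permute(0, 2, 1)
    if mode == "L":    # local, unimodal: locations of c_i with z_i
        return locs_i, z_i
    if mode == "CL":   # cross-local: latent z_i with locations of c_j
        return locs_j, z_i
    if mode == "CS":   # cross-spatial: locations of c_i with locations of c_j
        return locs_i, locs_j
    if mode == "S":    # cross-latent: z_i with z_j
        return z_i, z_j
    raise ValueError(mode)
```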

Given these basic contrastive objectives, one can combine them as in Figure 2. We treat each edge as a type of objective. For a full picture, one can also combine previous approaches such as AE and DCCA [3] with contrastive objectives, yielding S-AE and L-CCA. We show them schematically in Figure 3 along with the other baselines: the uni-source Supervised and AE models, and the multi-source DCCAE [34] and MMVAE [26] with the looser IWAE estimator. We trained the Supervised model specifically to obtain a discriminative bound on each multi-modal dataset.

Figure 5: Downstream test performance with linear evaluation on MNIST-SVHN. We can see that, overall, cross-source contrastive losses fare better than mono-source contrastive losses.

3 Experiments

3.1 Datasets

For our experiments, we incorporate diverse datasets with multiple sources (shown in Figure 1).

3.1.1 Multi-view dataset

Two-View MNIST is inspired by [34], where each view is a corrupted version of an original MNIST digit. First, the intensities of the original image are rescaled to the unit interval, and the image is resized to fit the DCGAN architecture. To generate the first view, we rotate the image by a random angle; for the second view, we add unit uniform noise and rescale the intensities to the unit interval again.
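A sketch of this view-generation procedure using torchvision is given below; the resize target and rotation range are placeholders, since the exact values are not restated here.

```python
import torch
from torchvision import datasets, transforms

RESIZE = 32     # assumed DCGAN-friendly size
ROTATION = 45   # assumed rotation range in degrees

view1_t = transforms.Compose([
    transforms.Resize(RESIZE),
    transforms.RandomRotation(ROTATION),   # view 1: random rotation
    transforms.ToTensor(),                 # intensities in [0, 1]
])
base_t = transforms.Compose([transforms.Resize(RESIZE), transforms.ToTensor()])

def make_views(img):
    view1 = view1_t(img)
    x = base_t(img)
    noisy = x + torch.rand_like(x)         # view 2: add unit uniform noise
    view2 = (noisy - noisy.min()) / (noisy.max() - noisy.min() + 1e-8)
    return view1, view2

mnist = datasets.MNIST("data", train=True, download=True)
v1, v2 = make_views(mnist[0][0])           # mnist[0][0] is a PIL image
```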

3.1.2 Multi-domain dataset

The multi-domain dataset MNIST-SVHN is used by the authors of [26]; the first view is a grayscale MNIST digit and the second view is an RGB street-view house number sampled from the SVHN dataset. We only modified the MNIST digits by resizing them to match the DCGAN encoder. This dataset represents the more complicated case where the same digit is represented by different underlying domains. All intensities are scaled to the unit interval.
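As an illustration of how such cross-domain pairs can be formed by matching digit labels (a sketch, not the reference pairing code of [26]):

```python
import numpy as np
from torchvision import datasets

mnist = datasets.MNIST("data", train=True, download=True)
svhn = datasets.SVHN("data", split="train", download=True)

mnist_labels = mnist.targets.numpy()
svhn_labels = np.asarray(svhn.labels)

pairs = []
rng = np.random.default_rng(0)
for digit in range(10):
    m_idx = np.flatnonzero(mnist_labels == digit)
    s_idx = np.flatnonzero(svhn_labels == digit)
    k = min(len(m_idx), len(s_idx))
    # match shuffled indices of the same digit across the two domains
    pairs += list(zip(rng.permutation(m_idx)[:k], rng.permutation(s_idx)[:k]))
```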

3.1.3 Multi-modal dataset

For experiments with a multi-modal dataset, we utilize data provided by OASIS-3 [15] to evaluate representations on Alzheimer's disease (AD) classification. We preprocessed fMRI into fALFF (in the 0.01 to 0.1 Hz power band) and T1w into structural MRI (sMRI) with brain masking, using REST [28] and FSL [12] (v6.0.2). All images are linearly registered to MNI space and resampled to 3 mm resolution. After careful selection (removing bad images and selecting the most represented, non-Hispanic Caucasian subset), we were left with 826 subjects. For each subject, we combined sMRI and fALFF into 4021 pairs. We held out 100 subjects and used the others in stratified 5-fold cross-validation for training and validation. We defined 3 groups: a healthy cohort (HC), AD, and others (subjects with other brain problems). During pretraining, we employ all groups and pairs, while during linear evaluation we take only one pair per subject and use only HC or AD subjects. As additional preprocessing, we applied histogram standardization and z-normalization. During pretraining, we also use simple data augmentation: random crops and flips. For the last two steps, we used the TorchIO library [20]. Since this dataset is highly unbalanced, we utilize a class-balanced data sampler [8].
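A sketch of such a TorchIO pipeline is shown below; the landmarks file, image path, and flip axes are assumptions rather than the exact configuration used here, and random crops would be handled separately (e.g. with TorchIO patch samplers).

```python
import torchio as tio

transform = tio.Compose([
    tio.HistogramStandardization({"mri": "landmarks.npy"}),  # precomputed landmarks (assumed file)
    tio.ZNormalization(),                                    # per-image z-normalization
    tio.RandomFlip(axes=(0, 1, 2)),                          # simple augmentation
])

subject = tio.Subject(mri=tio.ScalarImage("sub-001_T1w.nii.gz"))  # hypothetical path
augmented = transform(subject)
```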

3.2 Evaluation

3.2.1 Linear evaluation on downstream task

To evaluate representations on natural images, we employ the linear evaluation protocol common in self-supervised approaches [4, 7]. The methodology involves training a linear mapping from a latent representation to the class logits while the encoder is kept frozen. We evaluate each encoder separately. For the OASIS-3 dataset we use logistic regression instead; its hyperparameters (inverse regularization strength and penalty) [19] are optimized using Optuna [1].
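A sketch of this evaluation pipeline for OASIS-3 is given below; the encoder handle, data arrays, and the ROC-AUC metric are illustrative assumptions.

```python
import optuna
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def embed(encoder, x):
    encoder.eval()                      # the encoder stays frozen
    return encoder(x).cpu().numpy()

def tune(z_train, y_train, z_val, y_val, n_trials=100):
    def objective(trial):
        clf = LogisticRegression(
            C=trial.suggest_float("C", 1e-4, 1e4, log=True),
            penalty=trial.suggest_categorical("penalty", ["l1", "l2"]),
            solver="saga", max_iter=1000,
        )
        clf.fit(z_train, y_train)
        return roc_auc_score(y_val, clf.predict_proba(z_val)[:, 1])

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)
    return study.best_params
```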

3.2.2 Measuring similarity between representations

To better understand the inductive bias of each specific objective, we compare the similarity between the representations of the sources using SVCCA [23], which is the mean correlation of aligned directions, and CKA [14], which has been shown to reliably identify relationships between the representations of networks.
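For reference, linear CKA [14] between two representation matrices can be computed as in the following sketch (a simplified stand-in for the anatome [24] implementation):

```python
import numpy as np

def linear_cka(z1, z2):
    """Linear CKA between representation matrices of shape [n_samples, dim]."""
    z1 = z1 - z1.mean(0, keepdims=True)   # center each feature
    z2 = z2 - z2.mean(0, keepdims=True)
    cross = np.linalg.norm(z1.T @ z2, "fro") ** 2
    norm1 = np.linalg.norm(z1.T @ z1, "fro")
    norm2 = np.linalg.norm(z2.T @ z2, "fro")
    return cross / (norm1 * norm2)
```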

3.3 Implementation details

The architecture and hyperparameters of the encoders for each source are based entirely on DCGAN [22]. For decoder-based methods, we also utilize the DCGAN decoder; however, we removed one layer for the experiments with natural images to accommodate their smaller input size. The encoders project the input to a latent vector $z$. All layers are initialized with uniform Xavier initialization.

The projection head for the latent representation is the identity. For convolutional features, we use a 2-layer convolutional neural network with a fixed kernel size and a number of hidden channels equal to the dimension of the latent representation. The convolutional features are taken from the encoder layer with a fixed spatial feature size. The critic function eventually computes the score in the projected embedding space. The parameters of the projection layers are shared for each loss, but not between sources.
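As a minimal sketch of such a convolutional projection head (the kernel size here is an assumption, since it is not restated above):

```python
import torch.nn as nn

def conv_projection_head(in_channels, latent_dim, kernel_size=1):
    """2-layer convolutional projection head applied to feature-map locations;
    the hidden width equals the latent dimension, as described above."""
    return nn.Sequential(
        nn.Conv2d(in_channels, latent_dim, kernel_size),
        nn.ReLU(inplace=True),
        nn.Conv2d(latent_dim, latent_dim, kernel_size),
    )
```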

We use the RAdam optimizer [16] with the OneCycleLR scheduler [27]. Pretraining lasts a fixed number of epochs for the natural images and for OASIS-3. The linear classifier for natural images is trained for a fixed number of epochs, and for logistic regression we run Optuna for a fixed number of iterations. All experiments are performed with the same batch size. In some runs, we noticed that the CCA objective and MMVAE are not stable; MMVAE was only able to train on 2 of the 5 folds, and only with a reduced batch size due to memory constraints on the OASIS-3 dataset. For contrastive objectives, we additionally apply a penalty (except for OASIS-3) and clipping as in AMDIM [4].
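A minimal sketch of this optimization setup, using the RAdam implementation available in recent PyTorch together with OneCycleLR; the model, learning rate, epoch count, and step count below are placeholders rather than the values used in the paper.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(64, 64)            # stand-in for the encoders
optimizer = torch.optim.RAdam(model.parameters(), lr=3e-4)
scheduler = OneCycleLR(optimizer, max_lr=3e-4,
                       epochs=100, steps_per_epoch=500)

for step in range(10):                     # per-batch training step
    optimizer.zero_grad()
    loss = model(torch.randn(8, 64)).pow(2).mean()   # dummy loss for the sketch
    loss.backward()
    optimizer.step()
    scheduler.step()
```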

3.4 Implementation, Code and Computational Resources

The code was written using PyTorch [18] and the Catalyst framework [13]. For data transforms of the brain images we utilized TorchIO [20]; for CKA analysis of the representations, the anatome library [24]; for SVCCA, the code accompanying [23]; and we built on existing implementations of AMDIM [4], DCCAE [21], and MMVAE with MNIST-SVHN [26]. The experiments were performed on an NVIDIA DGX-1 with V100 GPUs.

Figure 6: Test downstream performance with Logistic Regression, SVCCA and CKA similarity on OASIS-3.

4 Results

As we can see in Figures 4 and 5, the presence of cross-modal contrastive losses has a strong positive impact on downstream test performance, across different architecture and model choices. We also note that the formulations have different performance across settings and datasets, leading to the conclusion that applying them in practice requires careful adaptation to a given problem.

While in the simple multi-view case the contrastive methods are the clear leaders, in the multi-domain experiments reconstruction-based models such as MMVAE and S-AE stand out. However, the performance of most models is within a small margin of S-AE; thus one can choose decoder-free self-supervised approaches to reduce computational cost. The uni-source AE and L, and the multi-source DCCAE and L-CCA, are clearly unable to learn SVHN.

Figure 6 shows the results on OASIS-3. Multi-source approaches exhibit strong performance, but the difference with other methods is less noticeable than for the previous tasks. As discussed in the multi-domain experiment, here the AE objective is also important for learning the modality. The similarity-based method S has the highest metric on T1 but performs significantly worse on fALFF, which might indicate that the T1 modality dominated fALFF during training. We think that method S is related to CCA but is a more stable objective; thus S can be a good candidate to substitute for CCA. Adding reconstruction to the contrastive objective (as CCA and AE are combined in DCCAE), we obtain S-AE, which helps to regularize the model. S-AE is equivalent to AE in downstream performance; however, it has higher CKA similarity between representations, which can be useful for multi-modal fusion [6] and clearly substitutes for DCCAE [34].

On the fALFF modality, the absolute leader is L-CL: it is better than the Supervised model and comparable to AE and S-AE. Thus a multi-source objective can help learn a representation of a struggling modality.

By the SVCCA metric, most self-supervised methods and the Supervised model score lower than MMVAE, DCCAE, S, S-AE, and AE. However, SVCCA does not contrast the relationship between the representations of different modalities as clearly as CKA does.

5 Conclusion

In this work, we proposed a unifying view of contrastive methods and benchmarked them alongside baselines on multi-source data. We believe that this unifying view will further our understanding of how to learn powerful representations from multiple sources. We hope that, instead of combining similarities in various ways and publishing the winning combinations as individual methods, the field moves toward taking a broader perspective on the problem.

We empirically demonstrated that multi-modal contrastive approaches improve performance over methods that rely on a single modality for contrastive learning. We also showed that downstream performance is highly dependent on how such objectives are composed. We argue that similarity alone does not guarantee higher downstream performance; in some cases, it may weaken the representation or act as a regularizer on the objective. High similarity between representations can nonetheless be important for other applications, e.g. multimodal analysis [6].

Because they do not require training a decoder, self-supervised models significantly reduce computational cost. While keeping comparably high downstream performance, they can democratize medical imaging by lowering hardware requirements.

For future work, we are interested in considering how the conclusions we draw here hold in different learning settings with scarcer data or annotations such as few-shot or zero-shot learning cases.

6 Acknowledgments

This work is supported by NIH R01 EB006841.

Data were provided in part by OASIS-3: Principal Investigators: T. Benzinger, D. Marcus, J. Morris; NIH P50 AG00561, P30 NS09857781, P01 AG026276, P01 AG003991, R01 AG043434, UL1 TR000448, R01 EB009352. AV-45 doses were provided by Avid Radiopharmaceuticals, a wholly-owned subsidiary of Eli Lilly.

References

  • [1] T. Akiba, S. Sano, T. Yanase, T. Ohta, and M. Koyama (2019) Optuna: a next-generation hyperparameter optimization framework. In KDD, Cited by: §3.2.1.
  • [2] A. Anand, E. Racah, S. Ozair, Y. Bengio, M.A. Côté, and D.R. Hjelm (2019) Unsupervised state representation learning in atari. In NeurIPS, Cited by: §1, §2.3.
  • [3] G. Andrew, R. Arora, J. Bilmes, and K. Livescu (2013) Deep canonical correlation analysis. In ICML, pp. 1247–1255. Cited by: §1, §2.3.
  • [4] P. Bachman, D.R. Hjelm, and W. Buchwalter (2019) Learning representations by maximizing mutual information across views. arXiv preprint arXiv:1906.00910. Cited by: §1, §2.2, §2.3, §3.2.1, §3.3, §3.4.
  • [5] M.I. Belghazi, A. Baratin, S. Rajeswar, S. Ozair, Y. Bengio, A. Courville, and D.R. Hjelm (2018) Mine: mutual information neural estimation. arXiv preprint arXiv:1801.04062. Cited by: §1.
  • [6] V.D. Calhoun and J. Sui (2016) Multimodal fusion of brain imaging data: a key to finding the missing link(s) in complex mental illness. Biological psychiatry: cognitive neuroscience and neuroimaging 1 (3), pp. 230–244. Cited by: §1, §4, §5.
  • [7] T. Chen, S. Kornblith, M. Norouzi, and G. Hinton (2020) A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709. Cited by: §1, §2.2, §2.3, §3.2.1.
  • [8] A. Hermans, L. Beyer, and B. Leibe (2017) In defense of the triplet loss for person re-identification. arXiv preprint arXiv:1703.07737. Cited by: §3.1.3.
  • [9] D.R. Hjelm, A. Fedorov, S. Lavoie-Marchildon, K. Grewal, P. Bachman, A. Trischler, and Y. Bengio (2018) Learning deep representations by mutual information estimation and maximization. arXiv preprint arXiv:1808.06670. Cited by: §1.
  • [10] D.R. Hjelm, A. Fedorov, S. Lavoie-Marchildon, K. Grewal, P. Bachman, A. Trischler, and Y. Bengio (2019) Learning deep representations by mutual information estimation and maximization. ICLR. Cited by: §2.3.
  • [11] H. Hotelling (1992) Relations between two sets of variates. In Breakthroughs in statistics, pp. 162–190. Cited by: §1.
  • [12] M. Jenkinson, P. Bannister, M. Brady, and S. Smith (2002) Improved optimization for the robust and accurate linear registration and motion correction of brain images. Neuroimage 17 (2), pp. 825–841. Cited by: §3.1.3.
  • [13] S. Kolesnikov (2018) Accelerated deep learning r&d. GitHub. Note: https://github.com/catalyst-team/catalyst Cited by: §3.4.
  • [14] S. Kornblith, M. Norouzi, H. Lee, and G. Hinton (2019) Similarity of neural network representations revisited. arXiv preprint arXiv:1905.00414. Cited by: §3.2.2.
  • [15] P. J. LaMontagne, T.L.S. Benzinger, J.C. Morris, S. Keefe, R. Hornbeck, C. Xiong, E. Grant, J. Hassenstab, K. Moulder, A. Vlassenko, M.E. Raichle, C. Cruchaga, and D. Marcus (2019) OASIS-3: longitudinal neuroimaging, clinical, and cognitive dataset for normal aging and alzheimer disease. medRxiv. External Links: Document, Link, https://www.medrxiv.org/content/early/2019/12/15/2019.12.13.19014902.full.pdf Cited by: §3.1.3.
  • [16] L. Liu, H. Jiang, P. He, W. Chen, X. Liu, J. Gao, and J. Han (2019) On the variance of the adaptive learning rate and beyond. arXiv preprint arXiv:1908.03265. Cited by: §3.3.
  • [17] A. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748. Cited by: §1, §2.2.
  • [18] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala (2019) PyTorch: an imperative style, high-performance deep learning library. In NeurIPS, Cited by: §3.4.
  • [19] F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay (2011) Scikit-learn: machine learning in Python. JMLR 12, pp. 2825–2830. Cited by: §3.2.1.
  • [20] F. Pérez-García, R. Sparks, and S. Ourselin TorchIO: a Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning. arXiv:2003.04696. External Links: Link Cited by: §3.1.3, §3.4.
  • [21] R. Perry, G. Mischler, R. Guo, T. Lee, A. Chang, A. Koul, C. Franz, and J.T. Vogelstein (2020) Mvlearn: multiview machine learning in python. arXiv preprint arXiv:2005.11890. Cited by: §3.4.
  • [22] A. Radford, L. Metz, and S. Chintala (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434. Cited by: §3.3.
  • [23] M. Raghu, J. Gilmer, J. Yosinski, and J. Sohl-Dickstein (2017) SVCCA: singular vector canonical correlation analysis for deep learning dynamics and interpretability. In NeurIPS, Cited by: §3.2.2, §3.4.
  • [24] Anatome, a pytorch library to analyze internal representation of neural networks External Links: Link Cited by: §3.4.
  • [25] B. Schölkopf, D.W. Hogg, D. Wang, D. Foreman-Mackey, D. Janzing, C.J. Simon-Gabriel, and J. Peters (2016) Modeling confounding by half-sibling regression. PNAS 113 (27), pp. 7391–7398. Cited by: §1.
  • [26] Y. Shi, N. Siddharth, B. Paige, and P. Torr (2019) Variational mixture-of-experts autoencoders for multi-modal deep generative models. In NeurIPS, Cited by: §1, §2.3, §3.1.2, §3.4.
  • [27] L.N. Smith and N. Topin (2019) Super-convergence: very fast training of neural networks using large learning rates. In AI/ML MDO, Cited by: §3.3.
  • [28] X.W. Song, Z.Y. Dong, X.Y. Long, S.F. Li, X.N. Zuo, C.Z. Zhu, Y. He, C.G. Yan, and Y.F. Zang (2011) REST: a toolkit for resting-state functional magnetic resonance imaging data processing. PloS one 6 (9), pp. e25031. Cited by: §3.1.3.
  • [29] T. Sylvain, L. Petrini, and D. Hjelm (2020) Locality and compositionality in zero-shot learning. In ICLR, Cited by: §1, §2.3.
  • [30] T. Sylvain, L. Petrini, and R. D. Hjelm (2020) Zero-shot learning from scratch (zfs): leveraging local compositional representations. External Links: 2010.13320 Cited by: §1.
  • [31] Y. Tian, D. Krishnan, and P. Isola (2019) Contrastive multiview coding. arXiv preprint arXiv:1906.05849. Cited by: §1, §2.3.
  • [32] M. Tschannen, J. Djolonga, P.K. Rubenstein, S. Gelly, and M. Lucic (2019) On mutual information maximization for representation learning. arXiv preprint arXiv:1907.13625. Cited by: §2.2, §2.2.
  • [33] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A.N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NeurIPS, Cited by: §2.2.
  • [34] W. Wang, R. Arora, K. Livescu, and J. Bilmes (2015) On deep multi-view representation learning. In ICML, Cited by: §1, §2.3, §3.1.1, §4.
  • [35] Y. Xiao, F. Codevilla, A. Gurram, O. Urfalioglu, and A. M. López (2020) Multimodal end-to-end autonomous driving. IEEE. Cited by: §1.