Multimodal Self-Supervised Learning for Medical Image Analysis

by Aiham Taleb et al.
Hasso Plattner Institute

In this paper, we propose a self-supervised learning approach that leverages multiple imaging modalities to increase data efficiency for medical image analysis. To this end, we introduce multimodal puzzle-solving proxy tasks, which facilitate neural network representation learning from multiple image modalities. These representations allow for subsequent fine-tuning on different downstream tasks. To solve the puzzles efficiently, we employ the Sinkhorn operator to predict permutations of puzzle pieces in conjunction with a modality-agnostic feature embedding. Together, they allow for a lean network architecture and increased computational efficiency. Under this framework, we propose different strategies for puzzle construction that integrate multiple medical imaging modalities with varying levels of puzzle complexity. We benchmark these strategies in a range of experiments to assess the gains of our method in downstream performance and data efficiency on different target tasks. Our experiments show that solving puzzles interleaved with multimodal content yields more powerful semantic representations. This allows us to solve downstream tasks more accurately and efficiently than when each modality is treated independently. We demonstrate the effectiveness of the proposed approach on two multimodal medical imaging benchmarks, the BraTS and the Prostate semantic segmentation datasets, on which we achieve results competitive with state-of-the-art solutions at a fraction of the computational expense. We also outperform many previous solutions on the chosen benchmarks.


1 Introduction

Modern medical diagnostics heavily rely on the analysis of multiple imaging modalities, particularly for differential diagnosis [16]. However, leveraging such data with conventional supervised machine learning approaches requires annotating large numbers of training examples. Generating expert annotations of patient data at scale is non-trivial, expensive, and time-consuming, especially for 3D scans. In fact, with the growing size of imaging datasets, expert annotation becomes nearly impossible without computerized assistance [7]. Even current semi-automatic software tools fail to sufficiently reduce the time and effort required for annotation and measurement of these large datasets. Consequently, the scarcity of data, and of annotations in particular, is one of the main constraints for machine learning applications in medical imaging. At the same time, modern deep learning pipelines are drastically increasing in depth, complexity, and memory requirements, yielding an additional computational bottleneck.

Self-supervised learning provides a viable solution when labeled training data is scarce. In these approaches, the supervisory signals are derived from the data itself, typically by unsupervised learning of a proxy task. Models obtained through self-supervision then enable data-efficient supervised fine-tuning on the target task, significantly reducing the burden of manual annotation. Recently proposed self-supervised learning methods use spatial context as a supervisory signal to learn effective data representations. The earliest of these works is [5], in which a visual representation is learned by predicting the position of an image patch relative to another. The task is posed such that the model has to understand the concepts in a training image in order to solve it.

The work of Noroozi and Favaro [19] extended this patch-based approach to solving “Jigsaw Puzzles” on natural (non-medical) images as a proxy task. We view this work as the self-supervised work most closely related to ours. The intuition behind this idea is that, in order to solve the puzzle at sufficient complexity, the model must understand the objects that appear in the images as well as their parts. In contrast to our approach, this method relies on a single imaging modality only. In a medical context, however, including other modalities in puzzle solving, e.g., by mixing MRI and CT scans, should yield more informative data representations that can be leveraged in downstream tasks. This is because different physical properties of organ tissues are expressed in a complementary fashion in the different modalities. For instance, soft tissues are better captured by MRI, whereas CT scans capture bone structures better. Such complementary information is necessary for solving downstream tasks, e.g., semantic segmentation. By solving our multimodal puzzles, the model is forced to integrate this information. In other words, we encourage the model to learn modality-agnostic representations by mixing the modalities at the data level.

Nevertheless, integrating multiple imaging modalities into our generated puzzles requires a more efficient learning method. Noroozi and Favaro’s method [19] requires massive memory and compute resources, even for small 3-by-3 puzzles, as it integrates 9 replicas of AlexNet [11]. To achieve computational tractability, our approach builds upon the Sinkhorn networks proposed by Mena et al. [17], which use the Sinkhorn operator [23, 1] as an analog of the Softmax operator for permutation-related tasks. We extend the method of Mena et al. [17] to work efficiently with modern architectures [21]. Consequently, as opposed to Noroozi and Favaro’s method [19], our approach can solve puzzles at higher levels of complexity.

Contributions. Our contributions are two-fold. First, we propose a novel self-supervised method that performs multimodal puzzle solving by explicitly confusing modalities at the data level. This allows the model to combine the complementary information about the concepts in the data that is encoded in the different modalities. Second, we exploit the more efficient Sinkhorn operator for pretraining on our self-supervised task of multimodal puzzle solving. This efficient puzzle solver allows us to solve our multimodal puzzles at varying levels of complexity. We show that such a multimodal self-supervised task results in rich data representations that can be utilized in solving multiple downstream tasks. Our results suggest that inexpensive solutions similar to ours can provide performance gains in medical image analysis tasks, particularly in low-data regimes.

2 Related Work

Solving jigsaw puzzles as a self-supervised task has recently been used for domain adaptation across different datasets [3]. In a multi-task training scheme, Carlucci et al. [3] exploited jigsaw puzzle solving as a secondary task for object recognition, where it acts mainly as a regularizer. Similarly, other types of self-supervised tasks were used to improve domain generalization in [24]. Both of these works seek to align different source and target domains in the feature space by learning to perform self-supervised tasks on both domains. Their aim is to learn domain-agnostic representations of different concepts. In our work, however, we use self-supervision to learn cross-modal representations. Also, unlike our approach, both works use self-supervision in a multi-task fashion, jointly solving the same tasks on multiple domains. We instead create a multimodal task by fusing the data of multiple modalities and then solving that task. Their approach is likely to fail when the domain difference is large, as is the case for the modality differences we target. Our approach, on the other hand, can handle this issue because the integration of modalities occurs at the data level.

In the medical context, self-supervised learning has found use cases in diverse applications such as depth estimation in monocular endoscopy [15], medical image registration [13], body part recognition [26], and body part regression for slice ordering [25]. As opposed to our approach, none of these methods proposes an auxiliary task as a self-supervised learning stage for other downstream tasks. In addition, many of these works make multiple assumptions about the input data, resulting in engineered solutions that hardly generalize to other target tasks. Our proposed approach does not make any assumptions about the input data.

3 Method

Our method processes input samples from datasets that contain multiple imaging modalities, as is the case in the majority of medical imaging datasets. The types of medical imaging modalities are numerous [6], and they vary in their characteristics and use cases. We assume no prior knowledge about which modalities are used in our models, i.e., the modalities can vary from one downstream task to another. In other words, our multimodal puzzles can stem from any combination of available imaging modalities.

3.1 Multimodal Puzzle Construction

Solving a jigsaw puzzle entails two main steps. First, the image is cut into puzzle pieces (patches or tiles), which are shuffled randomly according to a certain permutation. Second, these shuffled pieces are assembled such that the original image is restored. If $N$ is the number of puzzle pieces, then there exist $N!$ possible arrangements of the pieces. It should be noted that, as the puzzle complexity increases, the placement of individual tiles might become ambiguous. For instance, tiles that originate from uniformly colored backgrounds can be tricky to place correctly. Nevertheless, the placements of different tiles are mutually exclusive. Thus, when all tiles are observed at the same time, the positional ambiguities are alleviated. In a conventional jigsaw puzzle, the pieces originate from only one image at a time, i.e., the computational complexity of solving such a puzzle is $\mathcal{O}(N!)$.

On the other hand, we propose a multimodal jigsaw puzzle, in which tiles can come from $M$ different modalities, as described in Algorithm 1. Solving this multimodal puzzle requires the model to learn an in-depth representation of how the organs are composed, along with the spatial relationships across modalities. As a result, the complexity of solving multimodal puzzles increases to $\mathcal{O}(N! \cdot M^N)$. Consequently, this quickly becomes prohibitively expensive due to two growth factors in the solution space: i) factorial growth in the number of permutations, $N!$, and ii) exponential growth, $M^N$, in the number of ways to assign the $M$ modalities to the $N$ tiles. To reduce the computational burden, we use two tricks. First, we employ the Sinkhorn operator, which allows the factorial factor to be handled efficiently, largely following [17]. Second, we employ a feed-forward network that learns a cross-modal representation, which allows for canceling out the exponential factor $M^N$, while simultaneously learning a semantically rich representation for downstream tasks.
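To make the two growth factors concrete, the following sketch counts the solution space for square puzzles. It assumes, as in Algorithm 1, that each of the N tiles independently takes one of M modalities, which gives N! · M^N configurations; the function names are illustrative only.

```python
import math


def conventional_arrangements(n_pieces: int) -> int:
    """Orderings of a single-image puzzle with N pieces: N!."""
    return math.factorial(n_pieces)


def multimodal_arrangements(n_pieces: int, n_modalities: int) -> int:
    """Orderings times per-tile modality choices: N! * M**N.

    Assumes every tile independently takes one of M modalities.
    """
    return math.factorial(n_pieces) * n_modalities ** n_pieces


for grid in (3, 5):
    n = grid * grid
    print(f"{grid}x{grid} puzzle: N! = {conventional_arrangements(n)}, "
          f"N! * 4**N = {multimodal_arrangements(n, 4)}")
```

For a 3-by-3 puzzle, N! = 362,880; with M = 4 modalities (as in BraTS), the count grows by a further factor of 4^9 = 262,144.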

Algorithm 1: Multimodal jigsaw puzzle creation
  Procedure CreatePuzzles
  Input:  I_1, ..., I_M: modality lists, each with S slices
          N: number of patches in a puzzle
          Π: list of possible permutations
          J: number of puzzles to generate per slice
  Output: Z: list of multimodal puzzles
  for s = 1 to S do
      for i = 1 to N do
          choose a random modality m ∈ {1, ..., M}
          fill patch position i of puzzle z with patch i of slice s from modality list I_m
      end for
      for j = 1 to J do
          shuffle z using a random permutation from Π; append the result to Z
      end for
  end for
  return Z
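Algorithm 1 can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the authors' implementation: slices are square 2D lists whose side is divisible by the grid size, all modality lists are spatially aligned and share the same number of slices, and the returned (shuffled tiles, permutation) pairs are a hypothetical data layout.

```python
import random
from typing import List, Sequence, Tuple

Image = List[List[float]]  # one 2D slice (or patch) as rows of pixel values


def cut_patch(image: Image, index: int, grid: int) -> Image:
    """Return tile `index` (row-major) of a grid x grid partition of a square image."""
    size = len(image) // grid  # assumes the side length is divisible by grid
    r, c = divmod(index, grid)
    return [row[c * size:(c + 1) * size] for row in image[r * size:(r + 1) * size]]


def create_puzzles(modalities: Sequence[Sequence[Image]], grid: int,
                   permutations: Sequence[Sequence[int]], puzzles_per_slice: int,
                   seed: int = 0) -> List[Tuple[List[Image], Tuple[int, ...]]]:
    """Build multimodal puzzles; returns (shuffled tiles, permutation) pairs."""
    rng = random.Random(seed)
    n_patches = grid * grid
    n_slices = len(modalities[0])  # modality lists are assumed aligned slice by slice
    puzzles = []
    for k in range(n_slices):
        # Fill every puzzle position with the tile from a randomly chosen modality.
        tiles = []
        for i in range(n_patches):
            m = rng.randrange(len(modalities))
            tiles.append(cut_patch(modalities[m][k], i, grid))
        # Derive J shuffled puzzles; position i receives the tile from position perm[i].
        for _ in range(puzzles_per_slice):
            perm = tuple(rng.choice(permutations))
            puzzles.append(([tiles[p] for p in perm], perm))
    return puzzles
```

Because the modality mix is drawn once per slice and then reused for all J shuffles, this mirrors the loop structure of Algorithm 1.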

3.2 Puzzle-Solving with Sinkhorn Networks

To efficiently solve the self-supervised jigsaw puzzle task, we train a network that can learn a permutation. A permutation matrix of size $N \times N$ corresponds to some permutation of the numbers $1$ to $N$. Every row and column, therefore, contains precisely a single $1$ with $0$s everywhere else, and every permutation corresponds to a unique permutation matrix. This permutation matrix is non-differentiable. However, as shown in [17], the non-differentiable parameterization of a permutation can be approximated by a differentiable relaxation, the so-called Sinkhorn operator. The Sinkhorn operator iteratively normalizes the rows and columns of any real-valued matrix to obtain a “soft” permutation matrix, which is doubly stochastic. Formally, for an arbitrary input $X$, which is an $N$-dimensional square matrix, the Sinkhorn operator $S(X)$ is defined as:

$$S^0(X) = \exp(X), \qquad S^l(X) = T_c\big(T_r(S^{l-1}(X))\big), \qquad S(X) = \lim_{l \to \infty} S^l(X),$$
$$T_r(X) = X \oslash \big(X \mathbf{1}_N \mathbf{1}_N^{\top}\big), \qquad T_c(X) = X \oslash \big(\mathbf{1}_N \mathbf{1}_N^{\top} X\big),$$

where $T_r$ and $T_c$ are the row and column normalization operators, respectively. The element-wise division is denoted by $\oslash$, and $\mathbf{1}_N$ is an $N$-dimensional vector of ones.
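The operator can be sketched in a few lines of plain Python. This is an illustrative, finite-iteration approximation of $S(X)$: the limit is truncated after `n_iters` rounds of $T_r$ followed by $T_c$. The temperature `tau`, which divides $X$ before exponentiation, is an assumption borrowed from the Gumbel-Sinkhorn formulation of Mena et al. [17] and is not part of the definition above; `tau = 1` recovers it.

```python
import math
from typing import List

Matrix = List[List[float]]


def sinkhorn(x: Matrix, n_iters: int = 20, tau: float = 1.0) -> Matrix:
    """Approximate S(X / tau) by n_iters rounds of row then column normalization."""
    n = len(x)
    # S^0 = exp(X / tau). Subtracting the global maximum only rescales every
    # entry by the same constant, which the first row normalization removes,
    # and it keeps exp() from overflowing.
    peak = max(max(row) for row in x)
    s = [[math.exp((v - peak) / tau) for v in row] for row in x]
    for _ in range(n_iters):
        # T_r: divide each entry by the sum of its row.
        row_sums = [sum(row) for row in s]
        s = [[v / rs for v in row] for row, rs in zip(s, row_sums)]
        # T_c: divide each entry by the sum of its column.
        col_sums = [sum(s[i][j] for i in range(n)) for j in range(n)]
        s = [[s[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return s
```

With enough iterations, rows and columns of the result both sum to approximately one; lowering `tau` pushes the output toward a hard permutation matrix. For very small temperatures, whole rows can underflow to zero, so a log-domain implementation is preferable in practice.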

Assume an input set of patches $\tilde{X} = \{\tilde{x}_1, \dots, \tilde{x}_N\}$, where $\tilde{X} \in \mathbb{R}^{N \times l \times l}$ represents a puzzle that consists of $N$ square patches, and $l$ is the patch length. We pass each element of $\tilde{X}$ through a network $g_\theta$, which processes every patch independently and produces a single output feature vector of length $N$. By stacking the feature vectors obtained for all $N$ patches as rows, we obtain an $N \times N$ matrix, which is then passed to the Sinkhorn operator $S$ to obtain the soft permutation matrix $P_\theta(\tilde{X}) = S\big(g_\theta(\tilde{X})\big)$. Formally, the network learns the mapping $\tilde{X} \mapsto P_\theta(\tilde{X})$; the soft permutation matrix is then applied to the scrambled input to reconstruct the image as $X_r = P_\theta(\tilde{X})^{\top} \tilde{X}$, where each patch is flattened into a row vector of length $l^2$.

The network is then trained by minimizing the mean squared error (MSE) between the ordered ground truth and the reconstructed version of the scrambled input, as in the puzzle-solving loss below:

$$\mathcal{L}(\theta) = \frac{1}{K} \sum_{k=1}^{K} \big\lVert X_k - P_\theta(\tilde{X}_k)^{\top} \tilde{X}_k \big\rVert_F^2,$$

where $\theta$ corresponds to the network parameters, $K$ is the total number of training puzzles, and $X_k$ is the ordered ground truth of puzzle $\tilde{X}_k$. After training, the learned representations capture different tissue structures across the given modalities as a consequence of multimodal puzzle solving. Therefore, they can be employed in downstream tasks by simply fine-tuning them on the target domains.
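The reconstruction and the puzzle-solving loss can be illustrated with the sketch below. It assumes that row $i$ of the (soft) permutation matrix scores scrambled patch $i$ against every target position, so the reconstruction is $P^{\top}\tilde{X}$; here `perm[i]` is the original position of scrambled patch `i`, patches are flattened into plain Python lists, and all helper names are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def permutation_matrix(perm: Sequence[int]) -> Matrix:
    """Hard permutation matrix with P[i][perm[i]] = 1.

    perm[i] is the original position of scrambled patch i.
    """
    n = len(perm)
    return [[1.0 if perm[i] == j else 0.0 for j in range(n)] for i in range(n)]


def reconstruct(p: Matrix, scrambled: Matrix) -> Matrix:
    """Compute P^T X~: row j is the P[:, j]-weighted sum of the scrambled patches."""
    n, d = len(scrambled), len(scrambled[0])
    return [[sum(p[i][j] * scrambled[i][t] for i in range(n)) for t in range(d)]
            for j in range(n)]


def puzzle_loss(targets: Sequence[Matrix], perms: Sequence[Matrix],
                scrambled_batch: Sequence[Matrix]) -> float:
    """(1/K) * sum_k ||X_k - P_k^T X~_k||_F^2 over K puzzles."""
    total = 0.0
    for x, p, xs in zip(targets, perms, scrambled_batch):
        rec = reconstruct(p, xs)
        total += sum((a - b) ** 2 for row_x, row_r in zip(x, rec)
                     for a, b in zip(row_x, row_r))
    return total / len(targets)
```

With the exact permutation matrix the loss is zero, while a uniform soft matrix blends all patches into their mean and is penalized accordingly; during training, the soft matrix would come from the Sinkhorn operator applied to the network's scores.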

3.3 Cross-Modal Generation

Multimodal imaging data, in the form of pairs of aligned scans, exist in a multitude of medical imaging datasets. In many real-world scenarios, however, obtaining such multimodal data in large quantities can be challenging, and most multimodal medical imaging datasets are small. Because we often have small amounts of multimodal data but large amounts of single-modal data, we add an explicit cross-modal generation step, which allows us to leverage the richness of multiple modalities in our puzzle-solving task even in this setting. To achieve this, we use the image-to-image translation framework Pix2Pix, proposed by Isola et al. [9].

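Given an input set of samples from two imaging modalities $A$ and $B$, Pix2Pix learns the mapping $G: A \rightarrow B$ using a conditional generative adversarial network (cGAN). In other words, the generator network $G$ is trained to transform images $x$ of modality $A$ into images $y$ of modality $B$, using the following objective:

$$G^{*} = \arg\min_{G} \max_{D} \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G),$$
$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big] + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big], \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\big[\lVert y - G(x) \rVert_1\big],$$

where $\mathcal{L}_{cGAN}$ is the adversarial loss between the generator $G$, which tries to minimize this objective, and an adversarial discriminator $D$, which tries to maximize it. The second term encourages the generator’s outputs to stay as close as possible to the ground-truth outputs in terms of the $L_1$ norm. The importance of this term is controlled by the hyper-parameter $\lambda$.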
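To make the two terms concrete, the sketch below evaluates an empirical version of the objective for a small batch. The inputs `d_real` and `d_fake` are hypothetical discriminator probabilities for real pairs $(x, y)$ and generated pairs $(x, G(x))$, and images are flattened into lists. The default λ = 100 is the value used by Isola et al. [9]; it is an assumption here, since the value used in this work is not stated.

```python
import math
from typing import Sequence


def cgan_term(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Empirical L_cGAN(G, D) = mean log D(x, y) + mean log(1 - D(x, G(x)))."""
    real = sum(math.log(p) for p in d_real) / len(d_real)
    fake = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real + fake


def l1_term(target: Sequence[float], generated: Sequence[float]) -> float:
    """Empirical L_L1(G): mean absolute difference between y and G(x)."""
    return sum(abs(a - b) for a, b in zip(target, generated)) / len(target)


def pix2pix_objective(d_real: Sequence[float], d_fake: Sequence[float],
                      target: Sequence[float], generated: Sequence[float],
                      lam: float = 100.0) -> float:
    """L_cGAN + lambda * L_L1.

    G minimizes the total; D maximizes the adversarial term
    (the L1 term does not depend on D).
    """
    return cgan_term(d_real, d_fake) + lam * l1_term(target, generated)
```

With a large λ, the L1 term dominates the value early in training, which keeps generated scans anchored to the paired target modality.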

In our scenario, after generating data samples of the smaller (in number of samples) modality $B$ from samples of the larger modality $A$, we construct our multimodal puzzles using a mix of real and generated multimodal data. As we show in our experiments, this yields better representations than creating puzzles from a single modality only. Our full framework is illustrated in Figure 2.

Figure 2: Schematic illustration of the steps of the proposed framework, assuming four available modalities. (a) We generate samples of the modalities that exist in smaller amounts from a reference modality that exists in larger quantities. (b) Synthetic and real images are then used to construct multimodal jigsaw puzzles, whose patches are drawn randomly from all of these modalities.

4 Experimental Results

In the following sections, we investigate the performance of our proposed pretraining method on the multimodal medical imaging datasets detailed in Section 4.1. We transfer (and fine-tune) the representations learned by our model to different downstream tasks and measure their impact in Section 4.2. Then, we study the effect of integrating generated data into the construction of our multimodal puzzles in Section 4.3. Next, we assess how our self-supervised task affects the downstream tasks’ data efficiency, i.e., when operating in a low-data regime, in Section 4.4. Finally, we analyze the effect of puzzle complexity on downstream performance in an ablation study in Section 4.5.

4.1 Datasets

In our experiments, we consider two multimodal medical imaging datasets. The first is the Multimodal Brain Tumor Image Segmentation Benchmark (BraTS) dataset [18, 2], which is widely used to benchmark semantic segmentation algorithms in the medical imaging domain. It contains multimodal MRI scans for 285 training cases and 66 validation cases. All BraTS scans include four MRI modalities per case: a) native (T1), b) post-contrast T1-weighted (T1Gd), c) T2-weighted (T2), and d) T2 Fluid Attenuated Inversion Recovery (T2-FLAIR) volumes. The BraTS challenge involves two different tasks: i) brain tumor segmentation and ii) prediction of the number of survival days.

The second benchmark we consider is the Prostate segmentation task from the Medical Segmentation Decathlon [22]. The prostate dataset consists of 48 multimodal MRI cases, of which 32 are used for training and 16 for testing. Manual segmentation of the whole prostate was produced from T2-weighted scans and apparent diffusion coefficient (ADC) maps. The target task is to segment two adjoining prostate regions (the central gland and the peripheral zone).

In both of the above benchmarks, we fine-tune our pretrained models on the training sets and report evaluation metrics on the corresponding validation sets, to allow for a fair comparison with state-of-the-art results from the literature.

4.2 Transfer Learning Results

We evaluate the quality of the learned representations from our auxiliary task of multimodal puzzle-solving by transferring them into other downstream tasks. Then, we assess their impact on downstream performance. We do not use any synthetic data in this section.

4.2.1 Brain Tumor Segmentation

The first downstream task is brain tumor segmentation, using the BraTS benchmark. The goal of this task is to segment three different brain tumor regions: a) the whole tumor (WT), b) the tumor core (TC), and c) the enhanced tumor (ET). Each of these regions has different characteristics, and each may appear more clearly on some MRI modalities than on others, which justifies the need for multiple modalities.


In order to better assess the quality of our representations, we establish the following set of baselines:

From Scratch: The first sensible baseline for all self-supervised methods is the same model trained on the downstream task from scratch. This baseline provides insight into the benefits of self-supervised pretraining, as opposed to learning the target task directly.

Single-modal: We study the impact of our pretraining method on this task when processing only a single modality as input. This experiment simulates the realistic situation in which human experts examine brain scans, as some modalities highlight certain aspects of the tumor more than others. For instance, Flair is typically used to examine the whole tumor area, T2 is used for the tumor core, and T1ce highlights the enhanced tumor region. When comparing to these results, we pick the best modality for each tumor region.

Isensee et al. [8]: This work ranked among the top entries in the BraTS 2018 challenge. The authors used other datasets in addition to the challenge training data and applied multiple augmentation techniques. Their model is based on a 3D U-Net [21] architecture. We only fine-tune the representations learned by our self-supervised task, thus requiring much less data and fewer augmentation methods.

Chang et al. [4]: The authors trained multiple versions of 2D U-Net models and used them as an ensemble to predict segmentation masks. This requires significantly more computing time and resources than training a single model such as ours, which in many cases performs the task better.

Li [14]: This work implemented a three-stage cascaded segmentation network that combines whole-tumor, tumor-core, and then enhanced-tumor masks. For the whole-tumor stage, it uses a modified multi-view 2D U-Net architecture, which processes three slices at a time from the input 3D scans: axial, sagittal, and coronal views. We achieve better results with fewer computations and a smaller network.

JiGen [3]: We compare our method to the multi-tasking approach JiGen, proposed by Carlucci et al. [3]. JiGen solves jigsaw puzzles as a secondary task for domain generalization in a multi-task setup. We implemented their model and treated the multiple modalities as different domains. This baseline aims to analyze the benefits of performing modality confusion at the data level (our approach), as opposed to in the feature space (their approach).

Evaluation Metrics:

The reported metrics are the average dice scores for the Whole Tumor (WT), the Tumor Core (TC), and the Enhanced Tumor (ET). When evaluating on the BraTS validation set, we follow the standard post-processing techniques used by Isensee et al. [8].
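For reference, the Dice score of a predicted mask A and a ground-truth mask B is 2|A ∩ B| / (|A| + |B|). The sketch below computes it per region on the flattened label maps of one case. The region definitions assume the standard BraTS labels (1: necrotic and non-enhancing tumor core, 2: peritumoral edema, 4: enhancing tumor), and returning 1.0 for two empty masks is a convention adopted here. The reported numbers additionally average over cases and apply the post-processing of Isensee et al. [8], which this sketch omits.

```python
from typing import Dict, Sequence

# Assumed BraTS label convention: 1 = necrotic / non-enhancing tumor core,
# 2 = peritumoral edema, 4 = enhancing tumor.
REGIONS = {"WT": {1, 2, 4}, "TC": {1, 4}, "ET": {4}}


def dice(pred: Sequence[bool], truth: Sequence[bool]) -> float:
    """Dice = 2 * |A intersect B| / (|A| + |B|); 1.0 when both masks are empty."""
    inter = sum(1 for p, t in zip(pred, truth) if p and t)
    total = sum(pred) + sum(truth)
    return 1.0 if total == 0 else 2.0 * inter / total


def region_dice(pred_labels: Sequence[int], true_labels: Sequence[int]) -> Dict[str, float]:
    """Dice score for each BraTS region on flattened label maps of one case."""
    scores = {}
    for name, labels in REGIONS.items():
        pred_mask = [v in labels for v in pred_labels]
        true_mask = [v in labels for v in true_labels]
        scores[name] = dice(pred_mask, true_mask)
    return scores
```

Because the regions are nested (ET within TC within WT), a single label error can lower several of the three scores at once.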


The results of our multimodal method compared to the above baselines are shown in Table 1. Our proposed method outperforms both the “from scratch” and “single-modal” baselines, confirming the benefits of pretraining with our approach. In addition, our method achieves results comparable to the other methods from the literature and outperforms them in most cases; for instance, it surpasses the methods of Chang et al. [4] and Li [14] in all reported dice scores. We also report the result of Isensee et al. [8], which ranks among the best results on the BraTS 2018 benchmark. Even though their method uses a 3D U-Net architecture co-trained with additional datasets, we outperform their results in the TC and ET dice scores, and we achieve comparable results in the WT dice score, even though their method uses the full 3D spatial context. Our results also outperform the JiGen [3] baseline, supporting our hypothesis that performing modality confusion at the data level is more effective than performing it at the feature level.

Model                  ET      WT      TC
From scratch           67.77   80.76   77.07
Li [14]                75.10   87.20   76.00
Chang et al. [4]       76.60   89.00   82.41
Isensee et al. [8]     79.59   90.80   84.32
JiGen [3]              77.54   87.57   81.23
Ours (Single-modal)    78.98   86.85   83.63
Ours (Multi-modal)     79.64   89.31   84.53
Table 1: Average dice scores (%) on the BraTS segmentation task (ET: enhanced tumor, WT: whole tumor, TC: tumor core)

4.2.2 Prostate Segmentation

The second downstream task we address is prostate segmentation, using the Prostate benchmark [22] from the Medical Segmentation Decathlon. The goal of this task is to segment two regions of the prostate: the central gland and the peripheral zone. This task uses the two available MRI modalities.


In order to assess the quality of our representations on this task, we establish the following baselines:

From Scratch: Similar to the first downstream task, we compare our model with the same architecture when training on the prostate segmentation task from scratch.

Single-modal: We also study the impact of our pretraining method when only a single modality is used to create the puzzles, assuming that this modality exists in large quantities.

JiGen [3]: Similar to the first downstream task, we compare our method to the multi-tasking approach JiGen.

Evaluation Metrics:

We report two evaluation metrics for this task, the average dice score (Dice) and the normalized surface distance (NSD), which are the official metrics of the challenge. Both metrics are computed for the two prostate regions (central gland and peripheral zone).
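NSD measures which share of the two segmentation boundaries lies within a tolerance τ of the other boundary. The brute-force 2D sketch below illustrates one common definition of this metric; the boundary extraction (4-neighbourhood), unit pixel spacing, the tolerance value, and the handling of empty masks are assumptions, whereas the official challenge evaluation operates on 3D surfaces in physical units.

```python
import math
from typing import List, Set, Tuple

Mask = List[List[int]]
Point = Tuple[int, int]


def boundary(mask: Mask) -> Set[Point]:
    """Foreground pixels with a background 4-neighbour or lying on the image edge."""
    h, w = len(mask), len(mask[0])
    points = set()
    for r in range(h):
        for c in range(w):
            if not mask[r][c]:
                continue
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                if not (0 <= rr < h and 0 <= cc < w) or not mask[rr][cc]:
                    points.add((r, c))
                    break
    return points


def nsd(pred: Mask, truth: Mask, tau: float) -> float:
    """Share of both boundaries lying within distance tau of the other boundary."""
    bp, bt = boundary(pred), boundary(truth)
    if not bp and not bt:
        return 1.0  # both masks empty: treated as a perfect match (our convention)
    if not bp or not bt:
        return 0.0

    def within(p: Point, others: Set[Point]) -> bool:
        return any(math.dist(p, q) <= tau for q in others)

    hits = sum(within(p, bt) for p in bp) + sum(within(q, bp) for q in bt)
    return hits / (len(bp) + len(bt))
```

Unlike Dice, NSD forgives small boundary deviations up to τ, so a prediction shifted by one pixel still scores 1.0 at τ = 1 but only 0.5 at τ = 0.5 for a small square.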


The results of our multimodal method compared to the above baselines are shown in Table 2. In this task, too, our proposed method outperforms both the “from scratch” and “single-modal” baselines on all reported metrics except the first NSD column, where the “from scratch” baseline is slightly higher (95.17 vs. 94.64), supporting the advantages of pretraining the segmentation model with our multimodal approach. Our method also outperforms the multi-tasking method JiGen [3] on this task. We notice a larger performance gap between our approach and JiGen in this task than in the first downstream task of brain tumor segmentation. We posit that this can be attributed to the larger difference between the imaging modalities used in this prostate segmentation task, compared to those in the brain tumor segmentation task. Figure 3 shows this difference more clearly: the imaging modalities of the prostate dataset, i.e., T2 and ADC, differ more in appearance than those of the brain tumor dataset, i.e., T1, T2, T1ce, and Flair.

This difference in appearance can be explained by the physics underlying these MRI modalities. All of the brain MRI sequences in the BraTS dataset are variants of T1- and T2-weighted scans; they differ only in the configuration of the MRI scanner. These different configurations cause the contrast and brightness of some brain areas to vary among the sequences. The ADC map, on the other hand, measures the magnitude of diffusion (of water molecules) within the organ tissue, which requires a specific type of MRI acquisition called Diffusion Weighted Imaging (DWI). In general, highly cellular tissues, or those with cellular swelling, exhibit lower diffusion coefficients, e.g., a tumor, a stroke, or, in our case, the prostate.

Model                  Dice            NSD
From scratch           69.45  86.82    95.17  98.55
JiGen [3]              67.12  79.67    92.91  96.54
Ours (Single-modal)    69.87  87.65    93.32  97.93
Ours (Multi-modal)     73.93  88.55    94.64  98.82
Table 2: Results on the Prostate segmentation task; each metric is reported for the two prostate regions (central gland and peripheral zone)

4.2.3 Survival Days Prediction (Regression)

The BraTS challenge involves a second downstream task: the prediction of survival days. The training set contains 60 cases, and the validation set contains 28 cases. As for the other downstream tasks, we transfer the learned weights of our multimodal puzzle solver model. The downstream task here is regression; hence, the output of our model is a single scalar that represents the expected number of survival days. We reuse the convolutional features and add a fully connected layer with five units, followed by a single output unit. We also include each subject’s age as a feature right before the output layer. The size of the fully connected layer was determined by hyperparameter tuning.
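The regression head described above can be sketched as a plain forward pass. The pooled feature vector, the ReLU activation, and all weights are hypothetical; the text only fixes the five-unit fully connected layer, the age feature appended before the output layer, and the scalar output.

```python
from typing import Sequence


def relu(v: float) -> float:
    return max(0.0, v)


def survival_head(features: Sequence[float], age: float,
                  w_hidden: Sequence[Sequence[float]], b_hidden: Sequence[float],
                  w_out: Sequence[float], b_out: float) -> float:
    """Pooled conv features -> FC layer (5 units in the text) -> [hidden, age] -> days."""
    hidden = [relu(sum(w * f for w, f in zip(row, features)) + b)
              for row, b in zip(w_hidden, b_hidden)]
    z = hidden + [age]  # age joins the features right before the output layer
    if len(z) != len(w_out):
        raise ValueError("w_out needs one weight per hidden unit plus one for age")
    return sum(w * v for w, v in zip(w_out, z)) + b_out
```

Injecting age only at the last layer keeps it from being mixed into the learned image features, so its contribution to the prediction stays a single, directly inspectable weight.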

In Table 3, we compare our results to the baselines of Suter et al. [20], who compared the performance of deep learning-based methods with several classical machine learning methods on the task of survival prediction. The first method they report (CNN + age) uses a 3D CNN. The second is a random forest regressor, the third is a multi-layer perceptron (MLP) network (FeatNet in Table 3) that uses a set of hand-crafted features, and the last is a linear regression model with a set of 16 engineered features. When fine-tuned on this task, our puzzle solver model outperforms all of these results. The reported evaluation metric is the mean squared error (MSE), for which lower is better.

Model                               MSE
From scratch                        112.841
CNN + age [20]                      137.912
Random Forest Regression [20]       152.130
FeatNet + all features [20]         103.878
Lin. Reg. + top 16 features [20]    99.370
Ours (Multi-modal)                  97.291
Table 3: BraTS survival prediction (regression); mean squared error, lower is better

4.3 Cross-Modal Generation Results

As noted earlier, obtaining large multimodal medical imaging datasets can be challenging. Therefore, in this set of experiments, we investigate the effect of extending our approach with an explicit cross-modal generation step. This extension allows us to leverage multimodal puzzle solving even when only a few multimodal samples are available. In practice, it is common for some imaging modalities to exist in larger quantities than others, e.g., T2-weighted MRI scans are more commonly acquired than ADC maps from diffusion-weighted scans in prostate datasets. Hence, in this set of experiments, we train the Pix2Pix [9] model on subsets of the two benchmarks. After that, we train our proposed puzzle solver on a mixture of synthetic and real multimodal data. Finally, the model weights obtained from solving these multimodal puzzles are fine-tuned on the downstream tasks, as usual.

We perform cross-modal generation on both of our chosen benchmarks, i.e., BraTS and Prostate. This generation process is performed in a semi-supervised fashion, assuming small multimodal subsets of data and large single-modal data. Hence, we study the effect of the multimodal subset’s size on the quality of generated data as well as on the performance on downstream tasks. We evaluate the generation process at data subset sizes of 1%, 10%, 20%, 50%, and 100%, of the total number of patients in each benchmark.

Cross-modal Generation on Prostate:

In this dataset, we assume large quantities of the T2 modality and small quantities of the ADC modality, as is usually the case in practice. We train Pix2Pix [9] on subsets of aligned T2 and ADC pairs and then use the model to generate synthetic ADC scans that correspond to T2 scans.

Cross-modal Generation on BraTS:

The generation process on the BraTS dataset is performed similarly. However, as this dataset contains four different MRI modalities, multiple generators must be trained to convert from one modality to the others. We choose a reference modality from which we translate to the others; this choice is also motivated by the quantities in which the modality is available. In practice, the most commonly used MRI modalities for brain data are T1- and T2-weighted scans. However, brain tumors mostly appear brighter in T2-weighted scans. As a result, we use the T2 modality as the reference modality and translate it to the others, training three Pix2Pix models, as shown in Figure 2.
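How real and generated scans are combined per case can be sketched as follows. The generators are stand-ins (plain callables mapping a T2 image to another modality, hypothetical here) for the three trained Pix2Pix models; real scans are always kept, and only missing modalities are synthesized from the T2 reference before puzzles are built.

```python
from typing import Callable, Dict, List

Image = List[List[float]]
Generator = Callable[[Image], Image]


def complete_case(case: Dict[str, Image], generators: Dict[str, Generator],
                  reference: str = "T2") -> Dict[str, Image]:
    """Return a copy of `case` in which missing modalities are synthesized from `reference`."""
    if reference not in case:
        raise KeyError(f"case has no {reference} scan to translate from")
    completed = dict(case)  # real scans are kept unchanged
    for modality, generate in generators.items():
        if modality not in completed:
            completed[modality] = generate(case[reference])
    return completed
```

Keeping the input case untouched and returning a new dictionary makes it easy to track which modalities of a puzzle came from real acquisitions.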


The aim of this cross-modal generation step is to allow our multimodal puzzle-solving method to be adopted even when most of the data comes from a single modality. However, this step is only justified if it provides a performance boost over the single-modal puzzle-solving baseline, i.e., training our model on puzzles that originate from one modality. We measure the performance on the two downstream tasks by fine-tuning these models and then evaluating them on segmentation. Table 4 reports the dice scores of our pretrained models on the BraTS and Prostate segmentation benchmarks. The results show an improvement on both benchmarks when our puzzle solver is trained on synthetic multimodal data. Even when only 1% of the total dataset size is used to train the Pix2Pix [9] model, the generator appears to capture the important characteristics of the generated modality. Note that the last row in Table 4 represents an upper bound for our semi-supervised generation setup. The qualitative results in Figure 3 illustrate the quality of the generated images.

Figure 3: Qualitative results of our trained Pix2Pix generative model at different rates of multimodal data.
Model          BraTS (Dice, %)         Prostate (Dice, %)
Single-modal   72.12   82.72   79.61   61.42   79.65
Ours (1%)      74.76   85.21   82.86   62.52   81.19
Ours (2%)      74.48   86.02   82.18   64.38   81.81
Ours (20%)     75.22   86.98   82.77   70.11   84.02
Ours (50%)     77.09   87.11   83.08   72.94   85.89
Ours (100%)    79.64   89.31   84.53   73.93   88.55
Table 4: Segmentation results (Dice scores). The percentage next to each of our models denotes the fraction of multimodal data used to train Pix2Pix [9].
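To make the gains in Table 4 explicit, the short Python sketch below recomputes the per-column Dice difference between each of our models and the single-modal baseline. The scores are copied from the table, assuming (as the header suggests) that the first three columns belong to BraTS and the last two to Prostate; the function name is our own and purely illustrative.

```python
# Dice scores (%) copied from Table 4: three BraTS columns, then two Prostate columns.
TABLE_4 = {
    "Single-modal": [72.12, 82.72, 79.61, 61.42, 79.65],
    "Ours (1%)":    [74.76, 85.21, 82.86, 62.52, 81.19],
    "Ours (2%)":    [74.48, 86.02, 82.18, 64.38, 81.81],
    "Ours (20%)":   [75.22, 86.98, 82.77, 70.11, 84.02],
    "Ours (50%)":   [77.09, 87.11, 83.08, 72.94, 85.89],
    "Ours (100%)":  [79.64, 89.31, 84.53, 73.93, 88.55],
}


def gains_over_baseline(table, baseline="Single-modal"):
    """Per-column Dice difference (percentage points) of every row to the baseline row."""
    base = table[baseline]
    return {
        name: [round(score - ref, 2) for score, ref in zip(scores, base)]
        for name, scores in table.items()
        if name != baseline
    }


if __name__ == "__main__":
    for name, deltas in gains_over_baseline(TABLE_4).items():
        print(f"{name:12s}", " ".join(f"{d:+6.2f}" for d in deltas))
```

Every entry is positive, but the gains are not strictly monotonic in the amount of multimodal data: for instance, the first and third BraTS columns are slightly lower at 2% than at 1%.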

4.4 Low-Shot Learning Results

In this set of experiments, we assess how our self-supervised task benefits performance on both downstream segmentation tasks at different labeling rates, by fine-tuning our pretrained model on training sets of corresponding sizes. We randomly select subsets of patients at 1%, 10%, 50%, and 100% of the total segmentation training set size. We then fine-tune our model on each subset for a fixed number of epochs (50 epochs each). Finally, for each subset, we compare the performance of our fine-tuned model to the baseline trained from scratch.

As shown in figure 4, our method outperforms the baseline by a clear margin when few training samples are available, and this margin is largest in the lowest-data regime, with only 1% of the overall dataset. This case in particular suggests the potential of generic unsupervised features for related medical imaging tasks. Note that these low-shot results are obtained on real (non-synthetic) multimodal data.
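The subset selection described above can be sketched as follows. This is a minimal illustration, assuming that sampling is done at the patient level (so that all slices of a patient fall into the same subset), that each labeling rate is drawn independently with a fixed seed, and that every subset keeps at least one patient; `sample_patient_subsets` and the patient identifiers are hypothetical names.

```python
import random


def sample_patient_subsets(patient_ids, fractions=(0.01, 0.10, 0.50, 1.00), seed=0):
    """Draw one random subset of patients per labeling rate (hypothetical helper).

    Sampling is done over patients, not slices, so no patient is split across
    subsets. Each subset is drawn independently and keeps at least one patient.
    """
    rng = random.Random(seed)
    ids = sorted(patient_ids)
    subsets = {}
    for frac in fractions:
        k = max(1, round(frac * len(ids)))
        subsets[frac] = sorted(rng.sample(ids, k))
    return subsets


patients = [f"case_{i:03d}" for i in range(200)]  # hypothetical patient IDs
subsets = sample_patient_subsets(patients)
print({frac: len(ids) for frac, ids in subsets.items()})
```

Drawing nested subsets (each smaller subset contained in the next larger one) would be an equally reasonable design choice; it makes the comparison across labeling rates less sensitive to which patients happen to be drawn.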

Figure 4: Results on low-shot data regime.

4.5 Ablation Study on Puzzle Complexity

In this set of experiments, we analyze how the complexity of the jigsaw puzzles used in the pretraining stage affects downstream performance. The aim is to evaluate whether added complexity in our self-supervised tasks yields more informative data representations, as the model has to work harder to solve more complex tasks. Our results, shown in figure 5, support this intuition. We also use non-synthetic data in this set of experiments. Note that all results reported in the previous experiment sections use 5-by-5 puzzle configurations.
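One simple way to quantify puzzle complexity is the number of distinct orderings of the pieces: an n-by-n puzzle has n² pieces and therefore (n²)! possible arrangements. The short Python snippet below computes this count for the grid sizes around our 5-by-5 setting; the function name is illustrative.

```python
import math


def num_piece_orderings(grid_size):
    """Number of distinct orderings of a grid_size x grid_size jigsaw puzzle."""
    return math.factorial(grid_size * grid_size)


for n in (2, 3, 4, 5):
    print(f"{n}x{n}: {n * n} pieces, {num_piece_orderings(n):.3e} orderings")
```

The count grows from 24 orderings for 2-by-2 puzzles to roughly 1.55e25 for 5-by-5 puzzles. Enumerating orderings of that size as classes is infeasible, which is one reason to predict an n²-by-n² permutation matrix, as a Sinkhorn layer does, instead.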

Figure 5: Puzzle complexity vs downstream performance. The trendlines suggest that a more complex jigsaw puzzle produces better downstream task performance.

5 Conclusion & Future Work

We demonstrated that self-supervised puzzle solving in a multimodal context allows for learning powerful semantic representations that facilitate downstream tasks in medical imaging. In this regard, we showed results competitive with the state of the art on two multimodal medical imaging benchmarks. Moreover, our method achieves this with a comparatively inexpensive training procedure. Our approach leverages unlabelled multimodal medical scans and thereby reduces the manual annotation required for downstream tasks. Our experimental results support this, especially in the low-data regimes. We also evaluated a cross-modal translation method as an extension to our approach, motivated by the real-world scenario in which most of the data comes from a single modality. We showed that this step makes our multimodal puzzle solving applicable in such cases, with performance gains even when as few as 1% of the data samples are used to train the generative model.

Our self-supervised approach provides performance gains on the evaluated downstream tasks. However, to further reduce the performance gap between 2D and 3D models, we plan to extend this work to 3D multimodal puzzles, making full use of the spatial context. In addition, our current approach assumes that the modalities are paired (aligned). While this holds for many multimodal medical imaging datasets, it does not hold in general in real-world situations. Hence, we aim to improve our method by relaxing this multimodal alignment constraint. Finally, in this work we proposed a multimodal jigsaw puzzle proxy task, which improved the model’s performance on downstream tasks. We aim to generalize this idea to other types of proxy tasks, which would broaden our search for the most suitable proxy task for a given downstream task with multimodal data.


  • [1] R. P. Adams and R. S. Zemel (2011) Ranking via Sinkhorn propagation. arXiv preprint arXiv:1106.1925.
  • [2] S. Bakas, H. Akbari, A. Sotiras, M. Bilello, M. Rozycki, J. S. Kirby, J. B. Freymann, K. Farahani, and C. Davatzikos (2017) Advancing the Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features. Scientific Data 4, pp. 170117.
  • [3] F. M. Carlucci, A. D’Innocente, S. Bucci, B. Caputo, and T. Tommasi (2019) Domain generalization by solving jigsaw puzzles. CoRR abs/1903.06864.
  • [4] Y. Chang (2018) Automatic segmentation of brain tumor from 3D MR images using a 2D convolutional neural networks. In Pre-Conference Proceedings of the 7th MICCAI BraTS Challenge.
  • [5] C. Doersch, A. Gupta, and A. A. Efros (2015) Unsupervised visual representation learning by context prediction. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1422–1430.
  • [6] R. Eisenberg and A. Margulis (2011) A patient’s guide to medical imaging. New York: Oxford University Press.
  • [7] K. Grünberg, O. Jimenez-del-Toro, A. Jakab, G. Langs, T. Salas Fernandez, M. Winterstein, M. Weber, and M. Krenn (2017) Annotating medical image data. In Cloud-Based Benchmarking of Medical Image Analysis, pp. 45–67.
  • [8] F. Isensee, P. Kickingereder, W. Wick, M. Bendszus, and K. H. Maier-Hein (2018) No new-net. In International MICCAI Brainlesion Workshop, pp. 234–244.
  • [9] P. Isola, J. Zhu, T. Zhou, and A. A. Efros (2016) Image-to-image translation with conditional adversarial networks. CoRR abs/1611.07004.
  • [10] D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv:1412.6980. Published as a conference paper at the 3rd International Conference for Learning Representations, San Diego, 2015.
  • [11] A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pp. 1097–1105.
  • [12] C. Li and M. Wand (2016) Precomputed real-time texture synthesis with Markovian generative adversarial networks. CoRR abs/1604.04382.
  • [13] H. Li and Y. Fan (2018) Non-rigid image registration using self-supervised fully convolutional networks without training data. In 2018 IEEE 15th International Symposium on Biomedical Imaging (ISBI 2018), pp. 1075–1078.
  • [14] X. Li (2018) Fused U-Net for brain tumor segmentation based on multimodal MR images. In Pre-Conference Proceedings of the 7th MICCAI BraTS Challenge.
  • [15] X. Liu, A. Sinha, M. Unberath, M. Ishii, G. D. Hager, R. H. Taylor, and A. Reiter (2018) Self-supervised learning for dense depth estimation in monocular endoscopy. CoRR abs/1806.09521.
  • [16] D. Long, J. Wang, M. Xuan, Q. Gu, X. Xu, D. Kong, and M. Zhang (2012) Automatic classification of early Parkinson’s disease with multi-modal MR imaging. PLOS ONE 7, pp. 1–9.
  • [17] G. Mena, D. Belanger, S. Linderman, and J. Snoek (2018) Learning latent permutations with Gumbel-Sinkhorn networks. arXiv preprint arXiv:1802.08665.
  • [18] B. H. Menze, A. Jakab, S. Bauer, J. Kalpathy-Cramer, K. Farahani, et al. (2015) The multimodal brain tumor image segmentation benchmark (BRATS). IEEE Transactions on Medical Imaging 34 (10), pp. 1993–2024.
  • [19] M. Noroozi and P. Favaro (2016) Unsupervised learning of visual representations by solving jigsaw puzzles. CoRR abs/1603.09246.
  • [20] M. Reyes (2018) End-to-end deep learning versus classical regression for brain tumor patient survival prediction. In Pre-Conference Proceedings of the 7th MICCAI BraTS Challenge.
  • [21] O. Ronneberger, P. Fischer, and T. Brox (2015) U-Net: convolutional networks for biomedical image segmentation. In MICCAI.
  • [22] A. L. Simpson, M. Antonelli, S. Bakas, M. Bilello, K. Farahani, B. van Ginneken, A. Kopp-Schneider, B. A. Landman, G. J. S. Litjens, B. H. Menze, O. Ronneberger, R. M. Summers, P. Bilic, P. F. Christ, R. K. G. Do, M. Gollub, J. Golia-Pernicka, S. Heckers, W. R. Jarnagin, M. McHugo, S. Napel, E. Vorontsov, L. Maier-Hein, and M. J. Cardoso (2019) A large annotated medical image dataset for the development and evaluation of segmentation algorithms. CoRR abs/1902.09063.
  • [23] R. Sinkhorn (1964) A relationship between arbitrary positive matrices and doubly stochastic matrices. The Annals of Mathematical Statistics 35 (2), pp. 876–879.
  • [24] Y. Sun, E. Tzeng, T. Darrell, and A. A. Efros (2019) Unsupervised domain adaptation through self-supervision. arXiv preprint arXiv:1909.11825.
  • [25] K. Yan, X. Wang, L. Lu, L. Zhang, A. P. Harrison, M. Bagheri, and R. M. Summers (2018) Deep lesion graphs in the wild: relationship learning and organization of significant radiology image findings in a diverse large-scale lesion database. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9261–9270.
  • [26] P. Zhang, F. Wang, and Y. Zheng (2017) Self supervised deep representation learning for fine-grained body part recognition. In 2017 IEEE 14th International Symposium on Biomedical Imaging (ISBI 2017), pp. 578–582.

Appendix A Model Training for all tasks

Input preprocessing.

For all input scans, we perform the following pre-processing steps:

  • We resize each scan to a resolution of for data samples from BraTS, and for Prostate data samples.

  • Then, the intensity values of each scan are normalized by scaling them to the range .

  • Finally, we create 2D slices by slicing the scans along the axial axis (-axis).
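The normalization and slicing steps can be sketched in NumPy as below. This is an illustration under stated assumptions, not the authors' pipeline: the target intensity range is assumed to be [0, 1] (min-max scaling), the axial axis is assumed to be the last array axis, and resizing is omitted because the target resolutions are not given in this text. The helper names are hypothetical.

```python
import numpy as np


def normalize_intensities(volume, eps=1e-8):
    """Min-max scale a scan to the assumed range [0, 1]."""
    volume = volume.astype(np.float32)
    vmin, vmax = volume.min(), volume.max()
    return (volume - vmin) / (vmax - vmin + eps)


def axial_slices(volume, axial_axis=-1):
    """Split a 3D scan into 2D slices along the (assumed) axial axis."""
    return [np.take(volume, i, axis=axial_axis) for i in range(volume.shape[axial_axis])]


scan = np.random.default_rng(0).normal(size=(64, 64, 20))  # hypothetical H x W x D scan
slices = axial_slices(normalize_intensities(scan))
print(len(slices), slices[0].shape)  # 20 slices of shape (64, 64)
```

The small `eps` in the denominator only guards against division by zero for constant scans.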

Training details.

For all tasks, we use the Adam [10] optimizer to train our models. The initial learning rate is in puzzle-solving tasks, in cross-modal generation tasks, and for segmentation and regression tasks. The network weights are initialized from a Gaussian distribution of in puzzle-solving and segmentation tasks, and from the distribution in the cross-modal generation task. An regularizer with a regularization constant is imposed on the network weights in puzzle-solving and downstream tasks. In terms of training epochs, we train all puzzle-solving tasks for 500 epochs, the cross-modal generators for 200 epochs, and all fine-tuning on downstream tasks for 50 epochs.

Network architectures.

All of our network architectures are convolutional, and they vary slightly per task:

  • For jigsaw puzzle solving tasks: we use 5 convolutional layers, followed by one fully-connected layer and one Sinkhorn layer.

  • For downstream segmentation tasks: we use a U-Net [21] based architecture with 5 layers in the encoder and 5 layers in the decoder. When fine-tuning, the weights of the encoder layers are copied from a pretrained model, whereas the decoder layers are randomly initialized. As the training loss, we use a combination of two losses: i) weighted cross-entropy and ii) Dice loss, weighted equally in the total loss.

  • For cross-modal generation tasks: as mentioned earlier, we largely follow the architecture of the Pix2Pix [9] model. For the generator, we use a U-Net [21] based encoder-decoder network that consists of:
    CD512-CD512-CD512-C512-C256-C128-C64
    where Ck denotes a Convolution-BatchNorm-ReLU layer with k filters, and CDk denotes a Convolution-BatchNorm-Dropout-ReLU layer with a dropout rate of 50%. For the discriminator, we use a PatchGAN [12] discriminator, which processes input patches, with the architecture C64-C128-C256-C512.
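The segmentation loss described above (weighted cross-entropy plus Dice loss, weighted equally) can be sketched in NumPy as follows. This is a minimal illustration, assuming the class weights are supplied by the caller (their values are not given here), a soft Dice computed over the whole batch and averaged over classes, and NumPy in place of whatever deep learning framework was actually used; all function names are hypothetical.

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax over the class axis."""
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def weighted_cross_entropy(probs, onehot, class_weights, eps=1e-7):
    """probs, onehot: (N, C, H, W); class_weights: (C,). Mean over all pixels."""
    w = np.asarray(class_weights, dtype=np.float64).reshape(1, -1, 1, 1)
    per_pixel = -(w * onehot * np.log(np.clip(probs, eps, 1.0))).sum(axis=1)
    return per_pixel.mean()


def soft_dice_loss(probs, onehot, eps=1e-7):
    """1 minus the soft Dice coefficient, averaged over classes."""
    axes = (0, 2, 3)  # sum over batch and spatial dimensions
    intersection = (probs * onehot).sum(axis=axes)
    denom = probs.sum(axis=axes) + onehot.sum(axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - dice.mean()


def segmentation_loss(logits, labels, class_weights):
    """Equal-weight sum of weighted cross-entropy and Dice loss.

    logits: (N, C, H, W) raw scores; labels: (N, H, W) integer class map.
    """
    num_classes = logits.shape[1]
    onehot = np.eye(num_classes)[labels].transpose(0, 3, 1, 2)  # -> (N, C, H, W)
    probs = softmax(logits, axis=1)
    return weighted_cross_entropy(probs, onehot, class_weights) + soft_dice_loss(probs, onehot)
```

The cross-entropy term acts per pixel, while the Dice term acts on whole regions, which is a common reason for combining the two when target structures are small relative to the background.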

Processing multi-modal inputs.

In downstream segmentation tasks, the reported methods from the literature (e.g., those in table 1 of our paper) use all available modalities when performing segmentation. They typically stack these modalities as image channels, similar to RGB channels. However, a network pretrained with our puzzle-solving method expects a single-channel input, i.e., one slice composed of multimodal patches. This difference only affects the input layer of the pretrained network, whose weights cannot be fine-tuned directly on a different number of input channels. We resolve this issue by duplicating (copying) the weights of the pretrained input layer only, once per input channel. This minor modification adds only a few parameters to the input layer of the fine-tuned model, but allows us to reuse its pretrained weights. An alternative would be to discard the weights of this input layer and initialize only the remaining layers from the pretrained model. However, our solution takes advantage of any useful information encoded in these weights, while allowing the model to fuse data from all channels. The exact number of channels in each downstream task is as follows:

  • BraTS Brain Tumor Segmentation: in each input slice, the 4 MRI modalities are stacked as channels.

  • BraTS Number of Survival Days Prediction: for each input slice, we also stack the 4 MRI modalities together with the predicted tumor segmentation mask, giving 5 channels per input slice. The predicted masks are produced by our best segmentation model.

  • Prostate segmentation: we stack the 2 available MRI modalities in each input slice.
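The weight duplication for the input layer can be sketched as follows. This NumPy sketch assumes the pretrained input kernel has shape (out_channels, 1, kh, kw), as for a convolution over single-channel slices; the function name and the optional `rescale` flag are hypothetical. Dividing the copies by the number of channels is a common variant that we include only as an option, since the text describes plain duplication.

```python
import numpy as np


def inflate_input_weights(weight, num_channels, rescale=False):
    """Replicate a single-channel conv kernel across `num_channels` input channels.

    weight: pretrained input-layer kernel of shape (out_channels, 1, kh, kw).
    Returns an array of shape (out_channels, num_channels, kh, kw).
    With rescale=True the copies are divided by num_channels (optional variant),
    so feeding the same slice into every channel reproduces the original response.
    """
    if weight.ndim != 4 or weight.shape[1] != 1:
        raise ValueError("expected a kernel of shape (out_channels, 1, kh, kw)")
    inflated = np.repeat(weight, num_channels, axis=1)
    if rescale:
        inflated = inflated / num_channels
    return inflated


pretrained = np.random.default_rng(0).normal(size=(64, 1, 3, 3))  # hypothetical layer
for task, channels in [("BraTS segmentation", 4), ("survival prediction", 5), ("Prostate", 2)]:
    print(task, inflate_input_weights(pretrained, channels).shape)
```

For a hypothetical 64-filter 3x3 input layer, going from 1 to 4 channels adds 3 x 64 x 9 = 1728 weights, which illustrates why this modification adds only a few parameters.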

Training the multimodal puzzle solver

Note that after we sample patches from the input slices, we add a random jitter of 5 pixels on each side before using them to construct puzzles. This mechanism helps prevent the model from relying on shortcuts, e.g., matching low-level edge continuity between neighboring patches, and thus forces it to work harder and learn better representations.
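A possible reading of this patch-sampling step is sketched below. It assumes that each grid cell is cropped to a patch that is 2 x 5 pixels smaller in each dimension, at a random offset inside the cell, and that each cell draws its modality at random from the aligned slices. That modality-mixing rule is an illustrative assumption, not necessarily the paper's exact puzzle construction strategy, and `sample_multimodal_puzzle` is a hypothetical name.

```python
import numpy as np


def sample_multimodal_puzzle(modalities, grid=5, jitter=5, rng=None):
    """Cut aligned multimodal slices into a grid x grid puzzle with jittered patches.

    modalities: array (M, H, W) of co-registered slices, one per modality.
    Each cell is (H // grid) x (W // grid); a patch 2 * jitter pixels smaller in
    each dimension is cropped at a random offset inside its cell, from a randomly
    chosen modality (assumed strategy). Returns (patches, modality_ids) in
    row-major cell order.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_mod, height, width = modalities.shape
    cell_h, cell_w = height // grid, width // grid
    patch_h, patch_w = cell_h - 2 * jitter, cell_w - 2 * jitter
    if patch_h <= 0 or patch_w <= 0:
        raise ValueError("slice too small for this grid and jitter")
    patches, modality_ids = [], []
    for row in range(grid):
        for col in range(grid):
            m = int(rng.integers(num_mod))               # modality for this cell
            dy = int(rng.integers(0, 2 * jitter + 1))    # random shift in [0, 2 * jitter]
            dx = int(rng.integers(0, 2 * jitter + 1))
            y0, x0 = row * cell_h + dy, col * cell_w + dx
            patches.append(modalities[m, y0:y0 + patch_h, x0:x0 + patch_w])
            modality_ids.append(m)
    return np.stack(patches), np.array(modality_ids)


rng = np.random.default_rng(0)
slices = rng.normal(size=(4, 200, 200))  # hypothetical 4-modality slice
patches, modality_ids = sample_multimodal_puzzle(slices, rng=rng)
shuffled = patches[rng.permutation(len(patches))]  # scrambled puzzle fed to the solver
print(patches.shape, modality_ids[:5])
```

Because neighboring patches are separated by a random gap of up to 10 pixels, their borders no longer line up exactly, which is what removes the edge-matching shortcut.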

Algorithm 2 provides the detailed steps for training our proposed multimodal puzzle solver. Once the network parameters are obtained, the learned representations capture different tissue structures across the given modalities, as a consequence of solving multimodal puzzles. They can therefore be used in downstream tasks by simply fine-tuning them on the target domains.

Algorithm 2: One epoch of training the multimodal puzzle solver (Train Puzzle Solver)
    Input: list of multimodal puzzles
    Output: trained model
    initialize model weights
    foreach puzzle in the input list do            // each puzzle contains several patches
        foreach patch in the puzzle do
            compute the patch's feature vector      // one feature vector per patch
        end foreach
        concatenate the feature vectors            // stack into a matrix, one row per patch
        apply the Sinkhorn layer                   // permutation matrix
        reorder the patches with this matrix       // reconstructed version
    end foreach
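The Sinkhorn and reconstruction steps of Algorithm 2 can be sketched in NumPy as below. This sketch assumes the network has already produced an N-by-N score matrix (here replaced by a stand-in built from a known shuffle), applies the Sinkhorn operator [23] by alternating row and column normalization in log space, and uses a mean squared reconstruction error as the training objective; that objective, the helper names, and the omission of the Gumbel noise used in [17] are our assumptions.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis, keeping dims."""
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))


def sinkhorn(log_alpha, num_iters=20, temperature=1.0):
    """Map an N x N score matrix to an (approximately) doubly stochastic matrix."""
    log_alpha = np.asarray(log_alpha, dtype=np.float64) / temperature
    for _ in range(num_iters):
        log_alpha = log_alpha - logsumexp(log_alpha, axis=1)  # rows sum to 1
        log_alpha = log_alpha - logsumexp(log_alpha, axis=0)  # columns sum to 1
    return np.exp(log_alpha)


def reconstruct(shuffled_patches, soft_perm):
    """Slot i receives sum_j soft_perm[i, j] * shuffled_patches[j]."""
    return np.einsum("ij,jhw->ihw", soft_perm, shuffled_patches)


def reconstruction_loss(reconstructed, original):
    """Mean squared error between reassembled and original puzzle (assumed objective)."""
    return float(np.mean((reconstructed - original) ** 2))


# Toy check with a 5 x 5 puzzle (25 patches of 8 x 8 pixels).
rng = np.random.default_rng(0)
original = rng.normal(size=(25, 8, 8))
perm = rng.permutation(25)
shuffled = original[perm]                # shuffled[j] = original[perm[j]]
true_matrix = np.zeros((25, 25))
true_matrix[perm, np.arange(25)] = 1.0   # shuffled patch j belongs at slot perm[j]
soft = sinkhorn(20.0 * true_matrix)      # stand-in for the network's scores
print(reconstruction_loss(reconstruct(shuffled, soft), original))  # close to 0
```

The soft permutation keeps the whole pipeline differentiable during training. At inference, a hard permutation can be read off the score matrix with an assignment solver such as SciPy's `scipy.optimize.linear_sum_assignment`.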

Appendix B Results for Cross-Modal Generation

In section 4.3 of our paper, we summarized the main qualitative and quantitative results on both of our chosen datasets. In this section, we present the complete set of results for the cross-modal generation part, for which we used a Pix2Pix [9] model. First, we present the full qualitative results of our model on the BraTS dataset. The summarized version of these results showed synthetic Flair MRI images generated from the corresponding T2 images. Figure 6 depicts the cross-modal generation results for all three generated BraTS modalities (i.e., T1 from T2, T1CE from T2, and Flair from T2). As expected, the quality of the generated modalities improves when a larger multimodal dataset is used to train the Pix2Pix generator. However, even with as few training examples as of the total dataset size, the quality of the generated samples is acceptable, and the resulting quantitative performance gains are substantial, as shown in table 4 of our paper.

Figure 6: Full qualitative results on the BraTS dataset.
Figure 7: Qualitative results with heatmaps of the differences between synthetic and ground-truth images. They show that cross-modal generation is easier on BraTS than on the Prostate dataset. They also highlight that, as expected, our models produce fewer errors when trained with more multimodal data. However, even with less data, the errors do not appear to be large, especially in regions of interest, e.g., brain tumor or prostate regions.