Positional Contrastive Learning for Volumetric Medical Image Segmentation

by Dewen Zeng, et al.

The success of deep learning heavily depends on the availability of large labeled training sets. However, large labeled datasets are hard to obtain in the medical imaging domain because of strict privacy concerns and costly labeling efforts. Contrastive learning, an unsupervised learning technique, has proven powerful in learning image-level representations from unlabeled data. The learned encoder can then be transferred or fine-tuned to improve the performance of downstream tasks with limited labels. A critical step in contrastive learning is the generation of contrastive data pairs, which is relatively simple for natural image classification but quite challenging for medical image segmentation because the same tissue or organ appears across the dataset. As a result, when applied to medical image segmentation, most state-of-the-art contrastive learning frameworks inevitably introduce many false-negative pairs, which degrades segmentation quality. To address this issue, we propose a novel positional contrastive learning (PCL) framework that generates contrastive data pairs by leveraging the position information in volumetric medical images. Experimental results on CT and MRI datasets demonstrate that the proposed PCL method substantially improves segmentation performance compared to existing methods in both the semi-supervised setting and the transfer learning setting.




1 Introduction

Deep neural networks (DNNs) play an important role in today’s medical image segmentation [ronneberger2015u, xu2019whole, wang2020ica, wang2019msu, isensee2018nnu]. To achieve state-of-the-art accuracy, most existing methods rely on supervised learning, which requires large labeled datasets for training. However, due to the extensive annotation effort and the requirement of expertise in the medical domain, acquiring such large labeled datasets is usually prohibitive. Meanwhile, a large amount of unlabeled image data from modalities such as Computed Tomography (CT) and Magnetic Resonance Imaging (MRI) is generated every day around the world. It is therefore desirable for DNNs to leverage this abundant unlabeled data to achieve higher performance with limited annotations. Contrastive learning [chen2020simple, chen2020big, he2020momentum, chen2020improved, misra2020self], a self-supervised learning (SSL) method, has shown great success in learning image-level features from large-scale unlabeled data without using any human-annotated labels. The main idea of contrastive learning is to contrast the similarities of sample pairs in the representation space through a contrastive loss, pulling the representations of similar pairs (a.k.a. positive pairs) together and pushing the representations of dissimilar pairs (a.k.a. negative pairs) apart. In the SSL setting, an encoder is trained using a contrastive loss with unlabeled data. After that, the trained encoder can be used as the initialization for training a supervised downstream task such as object detection or image segmentation. Many works have shown that the encoder learned by contrastive learning performs better than the encoder trained with supervised learning [he2020momentum, chen2020simple].

Most existing contrastive learning frameworks are designed for image classification, where the instances in two different images have dissimilar features. When they are applied directly to medical image segmentation, where different images can contain similar structures or organs, a large number of false negative pairs are induced, leading to degraded performance. Recently, [chaitanya2020contrastive] attempted to address this issue through a global contrastive learning approach for 3D medical image segmentation. It divides each volume into several partitions and considers the slices of corresponding partitions in different volumes as positive pairs and those of different partitions as negative pairs. However, the last few slices of a partition can be very similar to the first few slices of the next partition because they are naturally adjacent, which may still result in many false negative pairs.

To alleviate the problem, we propose a novel positional contrastive learning (PCL) framework, which generates contrastive data pairs based on the position of a slice in volumetric medical images. Slices that are close are considered positive pairs while those that are far apart are considered negative. Such a strategy can better leverage the domain-specific cue of medical images as adjacent slices typically contain similar anatomical structures, thus reducing false negatives. We evaluate the proposed PCL framework on two CT datasets and two MRI datasets. The experimental results show that our method can achieve better performance compared with state-of-the-art baselines in both semi-supervised and transfer learning settings.
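The boundary effect described above can be made concrete with a small sketch. The Python snippet below is illustrative only and is not the authors' code: it assumes slice positions are normalized to [0, 1], that the partition-based scheme uses four equal-width partitions, and that the position-based scheme uses a threshold of 0.1. The function names `partition_label`, `same_partition`, and `close_in_position` are hypothetical.

```python
def partition_label(position: float, n_partitions: int = 4) -> int:
    """Index of the equal-width partition that contains a normalized position in [0, 1]."""
    # Clamp so that position == 1.0 falls into the last partition.
    return min(int(position * n_partitions), n_partitions - 1)


def same_partition(p1: float, p2: float, n_partitions: int = 4) -> bool:
    """Partition-based pairing: positive only if both slices share a partition."""
    return partition_label(p1, n_partitions) == partition_label(p2, n_partitions)


def close_in_position(p1: float, p2: float, threshold: float = 0.1) -> bool:
    """Position-based pairing: positive if the normalized positions are close."""
    return abs(p1 - p2) < threshold


# Two nearly adjacent slices that straddle the boundary at 0.25.
a, b = 0.24, 0.26
print(same_partition(a, b))      # False: treated as a negative pair by partitioning
print(close_in_position(a, b))   # True: treated as a positive pair by position
```

Under these assumptions, the two slices are almost adjacent yet land in different partitions, which is exactly the kind of false negative that a position-difference criterion avoids.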

2 Related Work

Recent years have seen powerful self-supervised visual feature learning approaches with DNNs. By exploiting the information in large unlabeled datasets, a network can learn hierarchical features that help the training of other downstream tasks, especially when the training labels of these tasks are limited. Early SSL methods are mostly based on the design of pretext tasks, in which pseudo labels are automatically generated for network training. As these methods rely on ad-hoc heuristics, the learned representations lack generality [chen2020simple]. Contrastive learning has recently become a prevailing SSL method because of its superior performance. In contrastive learning, a contrastive loss [hadsell2006dimensionality] is used to enforce representations of positive pairs to be similar and those of negative pairs to be dissimilar [he2020momentum, chen2020simple, misra2020self, tian2019contrastive, jiao2020self, li2020contrastive]. MoCo [he2020momentum] and SimCLR [chen2020simple] are two different contrastive learning frameworks that yield state-of-the-art results. MoCo maintains a dictionary as a queue to store negative samples for training, while SimCLR uses in-batch samples as negatives. Most of these works are based on image classification tasks, assuming that the instances in two different images have dissimilar features. This is not the case, however, for medical images, because the same target organ or structure usually exists in all the images across the dataset. For example, in the ACDC MICCAI 2017 dataset [bernard2018deep], target structures such as the left ventricle and the right ventricle appear in almost every slice of the volumetric image for all patients. As such, if we follow the method used in image classification tasks and treat the augmented images from different slices as negatives, many of them will actually be false negatives.

The state-of-the-art contrastive learning method for medical image segmentation [chaitanya2020contrastive] attempted to address this issue by partitioning 3D medical images. However, it still induces false negatives, as discussed in Section 1. In contrast, the PCL method we propose uses the relative position of the slices in the volumes to decide whether they are positive pairs, which alleviates the false negative issue. In addition, the method in [chaitanya2020contrastive] is only evaluated in the semi-supervised setting, where contrastive learning and downstream tasks are done on the same dataset. We extend the evaluation to transfer learning to test whether the features learned by PCL on one dataset are transferable to another, and show that PCL outperforms [chaitanya2020contrastive] in both settings.

3 Method

Figure 1: Overview of the proposed PCL framework. In the pre-training stage, 2D slices in the x-y plane are extracted from volumetric medical images and fed into a U-Net encoder for representation learning. The learned encoder is then used as initialization in the fine-tuning stage. Each slice is assigned a position that denotes its relative location along the z axis in a volume. Data pairs with a small position difference (e.g., less than 0.1) are considered positive pairs and those with a large difference are considered negative pairs. Similar slices are marked with the same color.

3.1 Framework Overview

In this work, for a fair comparison, we follow [chaitanya2020contrastive] and use a 2D U-Net [ronneberger2015u] to perform segmentation on 2D slices of 3D images, an approach that has shown remarkable success in many 3D image segmentation tasks [isensee2018nnu, nemoto2020efficacy, wang2019msu, ushinsky20213d, isensee2017automatic]. The proposed method can also be readily generalized to patch-based 3D U-Net and 3D-2D hybrid U-Net approaches. Our PCL framework is shown in Fig. 1. In the pre-training stage, the input of the framework is a set of 2D slices in the x-y plane sampled randomly from unlabeled volumetric medical images. These slices are then passed to a U-Net encoder $e(\cdot)$ (also known as the feature extractor) followed by a shallow multilayer perceptron (MLP) projection head $g(\cdot)$. Let $x$ denote an input 2D slice. Then $e(x)$ is the representation learned by the encoder and $z = g(e(x))$ is the embedding vector. A contrastive loss is applied to all the embeddings learned from the data in a mini-batch to perform contrastive learning. After contrastive learning finishes, $g(\cdot)$ is discarded and $e(\cdot)$ is used as the initialization in the standard U-Net architecture to train the network on the limited labeled dataset by supervised learning in the fine-tuning stage.

3.2 Leveraging Structural Information in Medical Image

In medical images, similar anatomical structures often exist in all volumes of different patients across the dataset. In addition, we note the following two observations for volumetric medical images: 1) they have high spatial resolution along the z axis, so adjacent 2D slices inside a volume (see Fig. 1) usually have similar content; 2) if the volumes of different patients are perfectly aligned, the corresponding 2D slices in different volumes (see Fig. 1) often contain similar anatomical information. In this paper, we utilize these two distinctive features of volumetric medical images to generate data pairs for contrastive learning.

To be specific, each 2D slice extracted from a volume is associated with a position variable. The position, which is between 0 and 1, represents the relative (normalized) position of the slice along the z axis in the volume. Suppose $t$ is the index of the 2D slice along the z axis and $n$ is the total number of slices along the z axis (see Fig. 1). Then $\text{position} = t/n$. This allows proper alignment between different volumes. Once each 2D slice in a mini-batch is assigned its position, we can use the position difference to decide whether each data pair is similar or not. If the position difference of two slices is less than a threshold (e.g., 0.1 in Fig. 1), they are likely to contain similar anatomical content and can be considered a positive pair. Otherwise, they are a negative pair. The threshold is a hyper-parameter that differs across medical datasets. Note that this approach allows the positive and negative pairs to be formed on the fly instead of being predefined as in [chaitanya2020contrastive]. It is possible that $(a, b)$ and $(b, c)$ are positive pairs but $(a, c)$ is a negative pair. We believe this can encourage the feature representations to be uniformly distributed on the representation hypersphere, which may boost contrastive learning performance.


As in [he2020momentum, chen2020simple], a pair of random transformations is applied to each sample in the mini-batch to help the encoder learn spatially invariant features of the target. The augmentations do not change the position value of the original sample, so the contrastive data pair generation strategy discussed above still applies.
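The pair generation strategy of this section can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the slice index is taken to be 0-based, so the normalized position lies in [0, 1), and the names `normalized_positions` and `positive_mask` are hypothetical.

```python
import numpy as np


def normalized_positions(n_slices: int) -> np.ndarray:
    """Relative position of every slice along the z axis of one volume.

    Assumption: 0-based slice index t, so position = t / n lies in [0, 1).
    """
    return np.arange(n_slices) / n_slices


def positive_mask(positions, threshold: float = 0.1) -> np.ndarray:
    """Boolean matrix whose entry (i, j) is True when samples i and j form a positive pair."""
    positions = np.asarray(positions, dtype=float)
    diff = np.abs(positions[:, None] - positions[None, :])
    mask = diff < threshold
    np.fill_diagonal(mask, False)  # a sample is not paired with itself
    return mask


# Pairs are formed on the fly from positions alone and need not be transitive.
positions = np.array([0.00, 0.08, 0.16])
mask = positive_mask(positions, threshold=0.1)
print(mask[0, 1], mask[1, 2], mask[0, 2])  # True True False
```

Because positions are normalized by the number of slices, slice 5 of a 10-slice volume and slice 10 of a 20-slice volume both map to 0.5, which is what lets corresponding slices from different patients be paired.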

3.3 Contrastive Loss Function

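Our contrastive loss function is based on [khosla2020supervised]. For a set of $N$ randomly sampled slices, $\{x_k\}_{k=1 \ldots N}$, the corresponding mini-batch consists of $2N$ samples after data augmentation, $\{\tilde{x}_\ell\}_{\ell=1 \ldots 2N}$, in which $\tilde{x}_{2k}$ and $\tilde{x}_{2k-1}$ are two random augmentations of $x_k$. $z_\ell$ represents the learned embedding of $\tilde{x}_\ell$. Then the loss function can be defined as:

$$\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_i \quad (1)$$

$$\ell_i = -\frac{1}{|P(i)|} \sum_{j \in P(i)} \log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)} \quad (2)$$

where $P(i)$ is the set of indices of positive samples to $\tilde{x}_i$, $\mathrm{sim}(\cdot, \cdot)$ is the cosine similarity function that computes the similarity between two vectors in the representation space, and $\tau$ is a temperature scaling parameter. Compared with the standard contrastive loss [chen2020simple], which has only one positive pair in the numerator for any sample $\tilde{x}_i$, in Eq. 2 all positive pairs in a mini-batch (e.g., the augmented one and any of the remaining samples whose position is close to that of $\tilde{x}_i$) contribute to the numerator, allowing better utilization of the proposed strategy.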
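As a rough illustration of a loss of this kind, the NumPy sketch below computes a supervised-contrastive-style loss in which the positives of a sample are all other samples in the mini-batch whose position differs by less than a threshold. It is a sketch under assumptions rather than the authors' implementation: the function name `positional_contrastive_loss`, the averaging over samples, and the skipping of samples without positives are choices made here for illustration.

```python
import numpy as np


def positional_contrastive_loss(z, positions, threshold=0.1, tau=0.1):
    """Contrastive loss with position-defined positives.

    z:         (B, D) embeddings of the augmented mini-batch
    positions: (B,) normalized slice positions in [0, 1]
    """
    z = np.asarray(z, dtype=float)
    positions = np.asarray(positions, dtype=float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors: dot product = cosine similarity
    batch = z.shape[0]
    not_self = ~np.eye(batch, dtype=bool)

    # Positives: other samples whose position differs by less than the threshold.
    pos = (np.abs(positions[:, None] - positions[None, :]) < threshold) & not_self

    # Log-softmax over all other samples (the sample itself is excluded from the denominator).
    logits = np.where(not_self, z @ z.T / tau, -np.inf)
    row_max = logits.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(logits - row_max).sum(axis=1, keepdims=True))
    log_prob = logits - log_denom

    # Average the log-probabilities over each sample's positives.
    n_pos = pos.sum(axis=1)
    per_sample = -np.where(pos, log_prob, 0.0).sum(axis=1) / np.maximum(n_pos, 1)
    # Assumption: samples without any positive are skipped.
    return float(per_sample[n_pos > 0].mean())
```

When the mini-batch contains two augmentations of every slice, each sample has at least one positive (its other augmentation), so no sample is skipped. Embeddings that cluster by position give a loss near zero, while embeddings that place distant slices together give a large loss.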

4 Experiments and Results

Datasets: We evaluate the performance of the proposed PCL on four publicly available medical image datasets. (1) The CHD dataset is a CT dataset that consists of 68 3D cardiac images captured by a Siemens Biograph 64 machine [xu2019whole]. The dataset covers 14 types of congenital heart disease, and the segmentation labels include seven substructures: left ventricle (LV), right ventricle (RV), left atrium (LA), right atrium (RA), myocardium (Myo), aorta (Ao) and pulmonary artery (PA). (2) The MMWHS dataset was hosted at STACOM and MICCAI 2017 [zhuang2016multi, zhuang2013challenges]. It consists of 20 cardiac CT and 20 MRI images, and the annotations include the same seven substructures as the CHD dataset. (3) The ACDC dataset was hosted in the MICCAI 2017 challenge [bernard2018deep]. The dataset has 100 patients with 3D cardiac MRI images. Each patient has around 15 volumes covering a full cardiac cycle, but only the volumes for the end-diastolic and end-systolic phases are labeled by an expert. The segmentation labels include three substructures: LV, RV, and Myo. (4) The HVSMR dataset was hosted in the MICCAI 2016 challenge [pace2015interactive]. It has 10 3D cardiac MRI images captured in an axial view on a 1.5T scanner. Manual annotations of the blood pool and Myo are provided.

Preprocessing: Following [chaitanya2020contrastive], we first normalize the intensity of each 3D volume to a range bounded by two intensity percentiles of that volume. Then all 2D slices and the corresponding annotations are resampled to a fixed spatial resolution and zero-padded to a fixed image size. We do not apply cropping because it may remove important structural information in the original slice. The spatial resolution and image size are set separately for each of the four datasets (CHD, MMWHS, ACDC, and HVSMR).
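The normalization and padding steps can be sketched as follows. This is only an illustration: the percentile bounds (1st and 99th), the min-max form of the normalization, the centered padding, and the function names `normalize_volume` and `pad_to_size` are all assumptions made here, and the resampling step is omitted.

```python
import numpy as np


def normalize_volume(volume, low_pct=1.0, high_pct=99.0):
    """Min-max normalize a 3D volume using two of its own intensity percentiles.

    Assumption: the 1st and 99th percentiles are used as bounds, and the volume
    is not constant (otherwise the denominator would be zero).
    """
    volume = np.asarray(volume, dtype=float)
    lo, hi = np.percentile(volume, [low_pct, high_pct])
    return (volume - lo) / (hi - lo)


def pad_to_size(image, size):
    """Zero-pad a 2D slice to (size, size) without cropping.

    Assumption: padding is split evenly between both sides of each axis.
    """
    h, w = image.shape
    pad_h, pad_w = max(size - h, 0), max(size - w, 0)
    return np.pad(
        image,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode="constant",
        constant_values=0,
    )
```

A typical use would normalize each volume once and then pad every slice, for example `pad_to_size(normalize_volume(vol)[k], size)` for slice `k`, where `size` is the dataset-specific image size.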

4.1 Semi-supervised Learning

In this section, we test whether the proposed PCL can improve performance in semi-supervised learning, where contrastive learning and downstream supervised learning (with limited annotation) are done on the same dataset.

Setup: We employ PCL to pre-train a U-Net encoder on the whole CHD and ACDC datasets, respectively, without using any human labels. Note that for ACDC, each patient has more than 10 volumes covering a full cardiac cycle, only two of which have annotations. Since pre-training needs no labels, we use all the volumes from the 100 patients. The pre-trained model is then used as the initialization to fine-tune a U-Net segmentation network with a small number of labeled samples on the same dataset. 5-fold cross-validation is used to evaluate the segmentation performance. Specifically, for each cross-validation fold on CHD, we randomly sample $M$ patients from the 51 patients for fine-tuning, as if we only had labels for these patients, and evaluate the results on the remaining 17 patients. We experiment with different values of $M$ (e.g., 2, 6 and 10) to assess the influence of the fine-tuning training set size on contrastive learning performance. The same training strategy is also used for ACDC. We choose the position threshold to be 0.1 and 0.35 for CHD and ACDC, respectively, because these values gave the best performance in our experiments. The influence of the threshold on accuracy is discussed in the supplementary. Data augmentations, including translation, rotation, and scaling, are used in both the pre-training and fine-tuning stages. Pre-training runs for 200 epochs on two NVIDIA GeForce GTX 1080 GPUs. SGD is used as the optimizer with a learning rate of 0.1, a cosine learning rate scheduler, and a batch size of 32. The temperature $\tau$ is set to 0.1 as in [he2020momentum, chen2020simple]. In the fine-tuning stage, we train the U-Net with cross-entropy loss for 100 epochs with a batch size of 5, using the Adam optimizer and a cosine scheduler.
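For readers unfamiliar with cosine scheduling, the sketch below shows one common form of a cosine learning rate schedule using the pre-training values stated above (base learning rate 0.1, 200 epochs). The exact variant used by the authors is not specified, so the decay to a minimum of zero, the per-epoch update, and the function name `cosine_lr` are assumptions.

```python
import math


def cosine_lr(epoch: int, total_epochs: int = 200, base_lr: float = 0.1, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate for a given epoch.

    Assumption: the rate decays from base_lr at epoch 0 to min_lr at total_epochs.
    """
    cosine = 1.0 + math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * cosine


for epoch in (0, 50, 100, 150, 200):
    print(epoch, round(cosine_lr(epoch), 4))
```

The schedule starts at the base rate, passes through half of it at the midpoint, and approaches the minimum at the end of training.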

Baselines: We compare the performance of PCL with a random approach that does not use any pre-training, as well as the following state-of-the-art baselines, all of which use the same unlabeled dataset for pre-training and the same labeled dataset for fine-tuning as PCL: (1) Rotation [gidaris2018unsupervised]: a pretext-based method that uses image rotation prediction to pre-train the encoder; (2) PIRL [misra2020self]: adapted from a contrastive learning scheme for natural image classification, which uses a contrastive loss to learn pretext-invariant representations; (3) SimCLR [chen2020simple]: adapted from another contrastive learning scheme for natural image classification, which constructs the positive pair for each sample using only two random augmentations; (4) GCL [chaitanya2020contrastive]: a contrastive learning scheme for 3D medical image segmentation, which divides each volume into four partitions so that slices belonging to the same partition in different volumes are considered positive pairs.

CHD (68 patients in total)

| Method | M=2 | M=6 | M=10 | M=15 | M=20 | M=30 | M=51 |
|---|---|---|---|---|---|---|---|
| Random | 0.184(.06) | 0.508(.06) | 0.584(.05) | 0.627(.05) | 0.658(.04) | 0.693(.04) | 0.754(.02) |
| Rotation [gidaris2018unsupervised] | 0.171(.06) | 0.488(.07) | 0.575(.04) | 0.625(.04) | 0.651(.04) | 0.691(.04) | 0.749(.03) |
| PIRL [misra2020self] | 0.196(.07) | 0.504(.08) | 0.617(.05) | 0.658(.03) | 0.674(.04) | 0.714(.04) | 0.761(.03) |
| SimCLR [chen2020simple] | 0.192(.06) | 0.515(.06) | 0.599(.06) | 0.631(.05) | 0.666(.05) | 0.699(.05) | 0.756(.03) |
| GCL [chaitanya2020contrastive] | 0.255(.10) | 0.564(.04) | 0.646(.03) | 0.669(.04) | 0.697(.04) | 0.725(.04) | 0.766(.03) |
| PCL | 0.356(.08) | 0.600(.06) | 0.661(.05) | 0.686(.05) | 0.716(.04) | 0.735(.05) | 0.774(.03) |

ACDC (100 patients in total)

| Method | M=2 | M=6 | M=10 | M=15 | M=20 | M=30 | M=80 |
|---|---|---|---|---|---|---|---|
| Random | 0.588(.07) | 0.782(.03) | 0.840(.03) | 0.876(.01) | 0.894(.01) | 0.909(.01) | 0.928(.00) |
| Rotation [gidaris2018unsupervised] | 0.572(.08) | 0.809(.03) | 0.868(.02) | 0.886(.01) | 0.898(.01) | 0.910(.01) | 0.925(.00) |
| PIRL [misra2020self] | 0.492(.03) | 0.823(.04) | 0.865(.01) | 0.880(.02) | 0.896(.02) | 0.912(.01) | 0.927(.00) |
| SimCLR [chen2020simple] | 0.352(.06) | 0.725(.08) | 0.824(.04) | 0.869(.02) | 0.894(.01) | 0.913(.01) | 0.927(.00) |
| GCL [chaitanya2020contrastive] | 0.636(.05) | 0.803(.04) | 0.872(.01) | 0.891(.01) | 0.902(.01) | 0.913(.01) | 0.927(.01) |
| PCL | 0.671(.06) | 0.850(.01) | 0.885(.01) | 0.904(.01) | 0.909(.01) | 0.919(.00) | 0.929(.00) |

Table 1: Comparison of the proposed PCL method with baseline methods on CHD and ACDC. M is the number of patients used in the fine-tuning process. Results are reported in the form of mean(standard deviation) over 5-fold cross-validation. PCL provides better results than the baselines for all values of M.

Results and Analysis: The results of the comparative study on both CHD and ACDC are shown in Table 1. We report the average Dice over the 5-fold cross-validation. From the table, we make the following observations. (1) Comparing PCL and GCL with the other baselines, performance improves significantly in many settings for both CHD and ACDC, suggesting that by leveraging domain-specific structural information in volumetric medical images, the encoder can learn better task-related representations for segmentation. (2) The performance improvement of PCL and GCL is especially high when a very small number of training samples is used (e.g., 2 and 4). The gains shrink as the number of training samples increases, because with more training samples the information difference between the training set for fine-tuning and the training set for contrastive learning becomes small and the fine-tuning performance saturates. (3) SimCLR performs worse than Random on ACDC when 15 or fewer patients are used for fine-tuning. This suggests that using only data augmentations to generate contrastive data pairs may lead to a large false negative rate for datasets like ACDC, where each volume contains few slices (around 10). (4) PCL performs better than GCL in all settings. The improvement in Dice is as large as 0.101 (CHD with two labeled patients, 0.356 versus 0.255). This shows that using the relative position difference instead of a hard partition strategy can better utilize the structural information in medical images and reduce false negatives, improving contrastive learning performance.
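For reference, the Dice score used as the evaluation metric can be computed as in the sketch below. The averaging over foreground labels and the convention of returning 1.0 when a label is absent from both masks are assumptions, since the text does not spell out these details; the function names `dice_score` and `mean_dice` are hypothetical.

```python
import numpy as np


def dice_score(pred, target, label) -> float:
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) for one label in two label maps."""
    p = np.asarray(pred) == label
    t = np.asarray(target) == label
    denom = p.sum() + t.sum()
    if denom == 0:
        return 1.0  # assumption: a label absent from both masks counts as a perfect match
    return float(2.0 * np.logical_and(p, t).sum() / denom)


def mean_dice(pred, target, labels) -> float:
    """Average Dice over the given foreground labels."""
    return float(np.mean([dice_score(pred, target, label) for label in labels]))


pred = np.array([[1, 1], [0, 2]])
target = np.array([[1, 0], [0, 2]])
print(mean_dice(pred, target, labels=[1, 2]))
```

In this toy example, label 1 overlaps in one of three marked pixels (Dice 2/3) and label 2 matches exactly (Dice 1), so the mean is 5/6.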

4.2 Transfer Learning

CHD transferring to MMWHS (20 patients in total)

| Method | M=2 | M=4 | M=6 | M=8 | M=10 | M=16 |
|---|---|---|---|---|---|---|
| Random | 0.232(.14) | 0.661(.10) | 0.732(.07) | 0.769(.06) | 0.808(.05) | 0.834(.05) |
| Rotation [gidaris2018unsupervised] | 0.247(.16) | 0.659(.13) | 0.751(.07) | 0.768(.07) | 0.803(.06) | 0.850(.04) |
| PIRL [misra2020self] | 0.251(.10) | 0.670(.11) | 0.755(.07) | 0.774(.06) | 0.821(.05) | 0.851(.04) |
| SimCLR [chen2020simple] | 0.269(.17) | 0.683(.10) | 0.751(.07) | 0.783(.06) | 0.818(.05) | 0.850(.04) |
| GCL [chaitanya2020contrastive] | 0.262(.11) | 0.703(.07) | 0.768(.05) | 0.805(.04) | 0.820(.04) | 0.851(.03) |
| PCL | 0.339(.15) | 0.748(.08) | 0.792(.05) | 0.820(.05) | 0.840(.04) | 0.869(.03) |

ACDC transferring to HVSMR (10 patients in total)

| Method | M=2 | M=4 | M=6 | M=8 |
|---|---|---|---|---|
| Random | 0.742(.06) | 0.813(.05) | 0.842(.03) | 0.842(.04) |
| Rotation [gidaris2018unsupervised] | 0.737(.07) | 0.816(.06) | 0.845(.03) | 0.844(.03) |
| PIRL [misra2020self] | 0.740(.05) | 0.826(.04) | 0.849(.03) | 0.846(.03) |
| SimCLR [chen2020simple] | 0.700(.07) | 0.779(.05) | 0.808(.04) | 0.815(.04) |
| GCL [chaitanya2020contrastive] | 0.770(.05) | 0.818(.05) | 0.842(.03) | 0.843(.03) |
| PCL | 0.781(.05) | 0.832(.05) | 0.857(.03) | 0.857(.03) |

Table 2: Transfer learning comparison of the proposed PCL method with the baselines. M is the number of patients used in the fine-tuning process. Except for Random, all the methods are pre-trained on CHD and ACDC without labels and fine-tuned on MMWHS and HVSMR, respectively.

To assess whether the representations learned by PCL are transferable, we use the encoders pre-trained on CHD and ACDC without labels as the initialization of a U-Net and fine-tune it on the MMWHS and HVSMR datasets, respectively. The experiment setup and baselines are the same as in Section 4.1. Table 2 shows the comparison results. The proposed PCL framework outperforms all baselines on both datasets. The overall improvement on HVSMR is smaller than that on MMWHS. This is because MMWHS is very similar to CHD, which makes the features learned on CHD more helpful on MMWHS. On the other hand, ACDC and HVSMR differ in acquisition view and image resolution, which limits the transfer learning performance. Visualization of the segmentation results on all datasets is shown in the supplementary.

5 Conclusion

In this paper, we propose a novel PCL framework for representation learning in volumetric medical images. The framework effectively reduces the false negative pairs that arise in existing contrastive learning methods for medical image segmentation. Experimental results on four 3D medical image datasets show that PCL significantly improves segmentation performance in both the semi-supervised setting and the transfer learning setting.