Momentum Contrastive Voxel-wise Representation Learning for Semi-supervised Volumetric Medical Image Segmentation

05/14/2021 · Chenyu You, et al.

Automated segmentation in medical image analysis is a challenging task that requires a large amount of manually labeled data. However, manually annotating medical data is laborious, and most existing learning-based approaches fail to accurately delineate object boundaries without effective geometric constraints. Contrastive learning, a sub-area of self-supervised learning, has recently been noted as a promising direction in multiple application fields. In this work, we present a novel Contrastive Voxel-wise Representation Learning (CVRL) method with geometric constraints to learn global-local visual representations for volumetric medical image segmentation with limited annotations. Our framework can effectively learn global and local features by capturing 3D spatial context and rich anatomical information. Specifically, we introduce a voxel-to-volume contrastive algorithm to learn global information from 3D images, and propose to perform local voxel-to-voxel contrast to explicitly make use of local cues in the embedding space. Moreover, we integrate an elastic interaction-based active contour model as a geometric regularization term to enable fast and reliable object delineation in an end-to-end learning manner. Results on the Atrial Segmentation Challenge dataset demonstrate the superiority of our proposed scheme, especially in settings with a very limited number of annotated scans.


1 Introduction

Learning from just a few labeled examples while leveraging a large amount of unlabeled data is a long-standing pursuit in the machine learning community, and it is especially crucial for the medical imaging domain. Generating reliable manual annotations of 3D imaging data at scale is expensive, time-consuming, and may require domain-specific expertise. Due to privacy concerns, another challenge in medical imaging is the relatively small size of training datasets. To this end, contrastive learning applied to self-supervised learning (SSL) is a promising direction, since it has recently shown great promise in learning useful representations with limited human supervision, achieving remarkable results in vision tasks [4, 12, 1, 3].

Our work draws on existing literature in self-supervised learning, semi-supervised learning, and geometric priors. SSL [7, 21, 33] aims to learn effective visual representations from unlabeled data in an unsupervised setting. In recent years, a growing body of work in contrastive learning [8, 27, 19, 4, 25, 3] has brought significant progress in SSL. The central idea is to learn powerful representations that optimize similarity constraints to discriminate similar (positive) and dissimilar (negative) pairs within a dataset. A major stream of subsequent work focuses on the choice of dissimilar pairs, which determines the quality of the learned representations. The contrastive loss function is chosen from several options, such as InfoNCE [22], Triplet [26], and so on. However, several limitations remain. First, Chaitanya et al. [3] extended the contrastive learning framework to learn global and local cues, but this requires large numbers of negative pairs. Second, recent analysis [4] suggests that large batch sizes may achieve stronger performance gains, but such training is expensive in terms of both time and computational resources. Third, most works mainly explore 2D/2.5D context instead of 3D space.

In the medical imaging domain, substantial efforts [31, 16, 20, 14, 29, 2] have been devoted to incorporating unlabeled data to improve network performance, owing to the limited availability of 3D data and annotations. The most common training techniques use adversarial learning and consistency losses as regularization terms to encourage unsupervised mappings. Nie et al. [20] adopted adversarial training to select confident regions of unlabeled images to improve segmentation performance. Yu et al. [29] designed a deep uncertainty-aware framework based on the mean-teacher scheme [24] to guide the student network to capture better features. Despite their superior performance gains, these methods ignore the intrinsic geometric structure of the image, such as object borders, and may therefore fail to recognize object contours.

To address the aforementioned challenges, several efforts [32, 11, 15, 28] have attempted to leverage shape priors to detect more accurate boundaries. Although those methods have achieved promising results, they may not guarantee the target topology at test time. Active Contour Models (ACMs) [13] are a family of algorithms that, given an initial contour, dynamically fit an image boundary under geometric constraints. They are therefore well suited to such cases, especially in medical imaging, where the training dataset is often too small to train a deep segmentation network. Recent works [6, 30, 17, 10, 5] used the ACM energy function as part of the supervised loss in an end-to-end manner. However, these methods lack learning efficiency and suffer from a time-consuming mechanism that compares only two binary images at a time.

In this paper, we present CVRL, an end-to-end semi-supervised framework that integrates contrastive learning with geometric constraints to learn global contexts and local features for more effective representation learning in 3D imaging. The motivation comes from the fact that feature representations derived from 2D context fail to exploit useful spatial context in the 3D domain, which may result in large performance degradation. Two key aspects distinguish our approach from the recent advance [3]. First, instead of following the standard contrastive learning setting [4] and performing slice-level global contrast, CVRL performs voxel-to-volume global contrast in a 3D embedding space, which captures rich anatomical information. The second aspect concerns time and computational costs, which make standard contrastive training challenging to apply in most real-world applications.

Our contributions are three-fold: (1) we propose two contrasting strategies - global voxel-to-volume and local voxel-to-voxel contrast - to better utilize unlabeled data, capturing higher-order consistency in a well-structured global embedding space while making full use of voxel-wise similarities in the local context; (2) to address occasional inaccuracies along object boundaries, we incorporate an active contour with elastica (ACE) loss as a regularization term that penalizes high-curvature regions and encourages accurate edge placement; (3) our segmentation network outperforms state-of-the-art methods on the Atrial Segmentation Challenge dataset and generates object segmentations with high-quality global shapes.

2 Method

Figure 1: Overview of CVRL architecture. We learn rich dense voxel-wise representations by exploiting global context between voxels and volumetric regions and local correlations among voxels for semi-supervised segmentation.

2.1 Overview

We aim to construct an end-to-end voxel-wise contrastive algorithm with geometric constraints that learns useful representations in a discriminative manner for semi-supervised volumetric medical image segmentation. An overview of the architecture is illustrated in Figure 1. In the limited-annotation setting, we train semi-supervised CVRL with two components - supervised and unsupervised learning objectives. Specifically, we propose a voxel-wise contrastive algorithm to learn global-level and local-level representations from 3D unlabeled data by regularizing the embedding space and exploring the geometric and spatial context of training voxels. The ACM model further enables the segmentation network to capture richer details of object boundaries.

In our problem setting, we consider a set of 3D training images comprising labeled and unlabeled data. For simplicity of exposition, we denote the limited labeled set as $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{M}$ and the abundant unlabeled set as $\mathcal{D}_U = \{x_i\}_{i=M+1}^{N}$, where $M \ll N$, $x_i \in \mathbb{R}^{H \times W \times D}$ are volume inputs, and $y_i \in \{0, 1\}^{H \times W \times D}$ are ground-truth labels. Specifically, we adopt V-Net [18] as the network backbone $F(\cdot)$, which consists of an encoder network $F_{enc}(\cdot)$ and a decoder network $F_{dec}(\cdot)$. To maximize mutual information between latent representations, we design a projection head $g(\cdot)$ that comprises an encoder network similar to $F_{enc}(\cdot)$, followed by a 3-layer multilayer perceptron (MLP).
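To make the design concrete, the following is a minimal PyTorch sketch of the projection head just described; the channel sizes, layer widths, and class name are illustrative assumptions rather than the paper's exact configuration:

```python
import torch
import torch.nn as nn


class ProjectionHead3D(nn.Module):
    """3D projection head: a small conv encoder followed by a 3-layer MLP."""

    def __init__(self, in_ch=256, hidden=256, out_dim=128):
        super().__init__()
        # Small 3D conv stack standing in for the "encoder similar to F_enc".
        self.encoder = nn.Sequential(
            nn.Conv3d(in_ch, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),  # collapse spatial dims to one vector per volume
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, feats):  # feats: (B, C, D, H, W) backbone encoder features
        z = self.encoder(feats).flatten(1)
        return nn.functional.normalize(self.mlp(z), dim=1)  # unit-norm embeddings
```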

2.2 Unsupervised Contrastive Learning

A key component of CVRL is its ability to capture rich voxel-wise representations of high-dimensional data via contrastive learning. CVRL trains the contrastive objective as an auxiliary loss during the volume batch updates. To specify our voxel-wise contrastive algorithm, we define two discrimination terms: (i) a global contrastive objective and (ii) a local contrastive objective. The specification of these terms largely influences the quality of the learned representations.

2.2.1 Global Contrastive Loss

CVRL performs global voxel-to-volume contrast using a novel 3D projection head $g(\cdot)$. Different from the standard setting [4, 3] of using an MLP, which is designed for 2D image-level data, the 3D projection head can encode rich spatial information, giving it a stronger ability to model spatial correlations in high-dimensional 3D space. To learn embeddings that measure these similarity relations across volumes, we use the InfoNCE loss [22]:

$$\ell(z_i, z_i^+) = -\log \frac{\exp(z_i \cdot z_i^+ / \tau_g)}{\exp(z_i \cdot z_i^+ / \tau_g) + \sum_{z^- \in \mathcal{N}_i} \exp(z_i \cdot z^- / \tau_g)} \qquad (1)$$

where $z_i^+$ is the embedding of a transformed version of the same volume, i.e. $z_i^+ = g(F_{enc}(t(x_i)))$, where $t$ is a simple random 3D transformation [33, 23]. $z_i^+$ and $z^-$ are referred to as the positive and negative, respectively, in the parlance of contrastive learning. Here '$\cdot$' denotes the inner (dot) product [27], and $\tau_g$ denotes a temperature parameter for the global contrast. Recent work [27] shows that a large set of negatives can improve the quality of learned representations in unsupervised contrastive training. In our experiments, we include as negatives all feature maps in the volume batch except the feature map of the volume itself. On the unlabeled set, our global contrastive loss is defined as:

$$\mathcal{L}_{global} = \frac{1}{|\mathcal{P}|} \sum_{(z_i, z_i^+) \in \mathcal{P}} \ell(z_i, z_i^+) \qquad (2)$$

where $\mathcal{P}$ denotes the collection of embeddings of positive volume sample pairs.
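As a hedged illustration, the global objective in Eqs. (1)-(2) can be implemented in PyTorch roughly as follows, treating the diagonal of a batch similarity matrix as the positive pairs and all other entries as negatives; the temperature value and function name are assumptions:

```python
import torch
import torch.nn.functional as F


def global_contrastive_loss(z, z_pos, tau_g=0.1):
    """z, z_pos: (B, d) L2-normalized embeddings of two views of B volumes."""
    logits = z @ z_pos.t() / tau_g                      # (B, B) pairwise similarities
    targets = torch.arange(z.size(0), device=z.device)
    # Diagonal entries are positives; all off-diagonal entries act as negatives,
    # i.e. the InfoNCE form of Eq. (1) averaged over the batch as in Eq. (2).
    return F.cross_entropy(logits, targets)
```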

2.2.2 Local Contrastive Loss

On one hand, global contrastive objectives uncover distinctive volume-level representations that benefit the training of downstream tasks such as object classification when limited labeled data is available. On the other hand, dense prediction tasks such as semantic segmentation may require more discriminative local representations. As a complement to global contrastive objectives, a local strategy may therefore be vital for downstream medical imaging tasks. With this insight, we propose to perform voxel-to-voxel local contrast to explicitly explore local relationships between voxel samples:

$$\ell_{local}(f_u, f_u^+) = -\log \frac{\exp(f_u \cdot f_u^+ / \tau_l)}{\exp(f_u \cdot f_u^+ / \tau_l) + \sum_{f^- \in \mathcal{N}_u} \exp(f_u \cdot f^- / \tau_l)} \qquad (3)$$

where $f_u^+$ and $\mathcal{N}_u$ denote the positive and the set of negatives for the feature voxel $f_u$ at 3D location $u = (x, y, z)$, respectively. The positive is the feature voxel at the same location under a different transformation, whereas the negatives contain all mismatched feature voxels. $\tau_l$ denotes a temperature parameter for the local contrast. The local contrastive loss is written as follows:

$$\mathcal{L}_{local} = \frac{1}{|\Omega|} \sum_{u \in \Omega} \ell_{local}(f_u, f_u^+) \qquad (4)$$

where $\Omega$ denotes the set of voxel locations in the feature map.
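A minimal sketch of the local objective under the same assumptions is given below; it subsamples voxel locations to keep the pairwise logits tractable, and the sample count and temperature are illustrative:

```python
import torch
import torch.nn.functional as F


def local_contrastive_loss(f1, f2, tau_l=0.1, num_samples=512):
    """f1, f2: (B, C, D, H, W) spatially aligned dense features from two views."""
    B = f1.size(0)
    v1 = F.normalize(f1.flatten(2), dim=1)   # (B, C, V), V = D*H*W voxels
    v2 = F.normalize(f2.flatten(2), dim=1)
    # Subsample voxel locations so the (S x S) logit matrix stays tractable.
    idx = torch.randperm(v1.size(2), device=f1.device)[:num_samples]
    v1, v2 = v1[:, :, idx], v2[:, :, idx]    # (B, C, S)
    logits = torch.einsum('bcs,bct->bst', v1, v2) / tau_l  # (B, S, S)
    # Matching locations across views are positives; all mismatched voxels
    # serve as negatives, following Eqs. (3)-(4).
    targets = torch.arange(logits.size(1), device=f1.device).expand(B, -1)
    return F.cross_entropy(logits.reshape(-1, logits.size(2)), targets.reshape(-1))
```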

2.3 Semi-supervised CVRL Implementation

CVRL is a general semi-supervised framework for combining contrastive learning with geometric constraints. In our experiments, we train CVRL with two strategies - a supervised objective and an unsupervised objective.

2.3.1 Consistency Loss

Recent work [14, 24] shows that using an exponential moving average (EMA) over network parameters empirically improves training stability and the model's final performance. With this insight, we introduce an EMA model with parameters $\theta'$ as the moving average of the parameters $\theta$ of the original network; the architecture of the EMA model follows the original model. At training step $t$, the update rule is $\theta'_t = \alpha \theta'_{t-1} + (1 - \alpha) \theta_t$, where $\alpha$ is the momentum parameter. On the unlabeled set, we apply different perturbation operations to the unlabeled input volume samples, e.g. adding noise $\xi$. To encourage training stability and performance improvements, we define the consistency loss as:

$$\mathcal{L}_{c} = \mathcal{L}_{mse}\big(F(x + \xi'; \theta'),\, F(x + \xi; \theta)\big) \qquad (5)$$

where $\mathcal{L}_{mse}$ is the mean squared error loss.
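For concreteness, a minimal PyTorch sketch of the EMA update and consistency term follows; the momentum value, noise model, and function names are illustrative assumptions:

```python
import torch


@torch.no_grad()
def ema_update(teacher, student, alpha=0.99):
    # theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        t_p.mul_(alpha).add_(s_p, alpha=1.0 - alpha)


def consistency_loss(student, teacher, x, noise_std=0.1):
    # Differently perturbed copies of the same unlabeled volume, as in Eq. (5).
    with torch.no_grad():
        teacher_out = teacher(x + torch.randn_like(x) * noise_std)
    student_out = student(x + torch.randn_like(x) * noise_std)
    return torch.mean((student_out - teacher_out) ** 2)  # mean squared error
```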

2.3.2 Overall Training Objective

Our overall learning objective minimizes a combination of supervised and unsupervised losses. On the labeled data, we incorporate the ACE loss [5] as a geometric constraint during training. On the unlabeled data, the unsupervised training objective consists of the global contrastive loss, the local contrastive loss, and the consistency loss. The overall loss function is:

$$\mathcal{L} = \mathcal{L}_{seg} + \lambda_1 \mathcal{L}_{ace} + \lambda_2 \mathcal{L}_{global} + \lambda_3 \mathcal{L}_{local} + \lambda_4 \mathcal{L}_{c} \qquad (6)$$

where $\lambda_1, \lambda_2, \lambda_3, \lambda_4$ are hyperparameters that balance each term, $\mathcal{L}_{seg}$ denotes the segmentation loss, and $\mathcal{L}_{ace}$ is the ACE loss. More details about the ACE loss can be found in [5].
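Combining the terms, the overall objective of Eq. (6) reduces to a weighted sum; the sketch below assumes the individual loss values have already been computed, and the default weights are placeholders rather than the paper's values:

```python
def total_loss(l_seg, l_ace, l_global, l_local, l_cons,
               lam1=1.0, lam2=0.1, lam3=0.1, lam4=0.1):
    # Supervised terms (segmentation + ACE) are computed on labeled data; the
    # contrastive and consistency terms on unlabeled data, as in Eq. (6).
    return l_seg + lam1 * l_ace + lam2 * l_global + lam3 * l_local + lam4 * l_cons
```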

3 Experiments

Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel]
V-Net [18] | 80 | 0 | 91.14 | 83.82 | 1.52 | 5.75
V-Net | 16 | 0 | 86.03 | 76.06 | 3.51 | 14.26
DAP [32] | 16 | 64 | 87.89 | 78.72 | 2.74 | 9.29
UA-MT [29] | 16 | 64 | 88.88 | 80.21 | 2.26 | 7.32
LG-ER-MT [9] | 16 | 64 | 89.56 | 81.22 | 2.06 | 7.29
LG-ER-MT (+NMS) | 16 | 64 | 89.54 | 81.17 | 1.89 | 7.62
SASSNet [15] | 16 | 64 | 89.27 | 80.82 | 3.13 | 8.83
SASSNet (+NMS) | 16 | 64 | 89.54 | 81.24 | 2.20 | 8.24
CVRL (ours) | 16 | 64 | 89.74 | 81.42 | 1.97 | 7.32
CVRL (+NMS) | 16 | 64 | 89.87 | 81.65 | 1.72 | 6.96
V-Net [18] | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11
DAP [32] | 8 | 72 | 81.89 | 71.23 | 3.80 | 15.81
UA-MT [29] | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84
LG-ER-MT [9] | 8 | 72 | 85.43 | 74.95 | 3.75 | 15.01
LG-ER-MT (+NMS) | 8 | 72 | 85.95 | 75.70 | 2.30 | 10.12
SASSNet [15] | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54
SASSNet (+NMS) | 8 | 72 | 87.32 | 77.72 | 2.55 | 9.62
CVRL (ours) | 8 | 72 | 87.41 | 77.80 | 3.31 | 13.40
CVRL (+NMS) | 8 | 72 | 87.72 | 78.29 | 2.23 | 9.34
Table 1: Quantitative segmentation results on the LA dataset. The backbone network of all evaluated methods is V-Net.

3.0.1 Dataset and Pre-processing

We conduct our experiments on the Left Atrium (LA) dataset from the Atrial Segmentation Challenge (http://atriaseg2018.cardiacatlas.org/). The dataset comprises 100 3D gadolinium-enhanced MR imaging scans (GE-MRIs) with expert annotations, at an isotropic resolution of 0.625 × 0.625 × 0.625 mm³. Following the experimental setting in [29], we use 80 scans for training and 20 scans for evaluation. We employ the same pre-processing: all scans are cropped around the heart region and normalized to zero mean and unit variance.

Figure 2: Visual comparisons with other methods. As observed, our CVRL achieves superior performance with more accurate borders and shapes.

3.0.2 Implementation Details

In our framework, we use V-Net as the network backbone for both networks. All training sub-volumes are augmented by random cropping to a fixed patch size, and we apply standard data augmentation techniques [29, 23]. We empirically set the temperature parameters $\tau_g$, $\tau_l$ and the loss weights $\lambda_1, \ldots, \lambda_4$. We use the SGD optimizer with momentum and weight decay to optimize the network parameters, and the initial learning rate is decayed by a fixed factor at regular intervals. For EMA updates, we follow the experimental setting in [29]. We use a time-dependent Gaussian warm-up function $\lambda(t) = e^{-5(1 - t/t_{max})^2}$ to ramp up the unsupervised loss weights, where $t$ and $t_{max}$ denote the current and maximum training step, respectively. For fairness, all evaluated methods are implemented in PyTorch and trained on an NVIDIA 1080Ti GPU.
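As a sketch, the time-dependent Gaussian warm-up follows the mean-teacher convention [24, 29]; the maximum weight `w_max` below is an assumed placeholder:

```python
import math


def gaussian_rampup(step, max_step, w_max=0.1):
    # lambda(t) = w_max * exp(-5 * (1 - t / t_max)^2), ramping from ~0 up to w_max
    t = min(step, max_step) / max_step
    return w_max * math.exp(-5.0 * (1.0 - t) ** 2)
```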

In the testing stage, we adopt four metrics to evaluate segmentation performance: Dice coefficient (Dice), Jaccard index (Jaccard), 95% Hausdorff distance (95HD), and average symmetric surface distance (ASD). To eliminate isolated extraneous regions, we apply non-maximum suppression (NMS) as a post-processing step.
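The NMS step above removes isolated extraneous regions; one common realization (an assumption here, not necessarily the authors' exact procedure) keeps only the largest connected component of the binary prediction:

```python
import numpy as np
from scipy import ndimage


def keep_largest_component(mask):
    """mask: binary 3D numpy array (D, H, W) predicted by the network."""
    labeled, num = ndimage.label(mask)       # label connected components
    if num == 0:
        return mask                          # empty prediction: nothing to keep
    sizes = ndimage.sum(mask, labeled, range(1, num + 1))
    return (labeled == (np.argmax(sizes) + 1)).astype(mask.dtype)
```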

3.0.3 Comparison with Other Semi-supervised Methods

We evaluate our CVRL against several state-of-the-art semi-supervised segmentation methods with different amounts of labeled data, including V-Net [18], DAP [32], UA-MT [29], LG-ER-MT [9], and SASSNet [15]. Table 1 compares our segmentation results with those of other methods.

We first conduct experiments under a 20% annotation ratio (16 labeled and 64 unlabeled scans). Under this setting, most of the above approaches achieve strong segmentation performance. CVRL performs slightly better thanks to its global and local voxel-wise feature extraction. In particular, our proposed method can be further improved with NMS, outperforming other end-to-end semi-supervised methods in Dice (89.87%), Jaccard (81.65%), ASD (1.72), and 95HD (6.96).

To further evaluate the effectiveness of CVRL, we compare it with other methods under a 10% annotation ratio (8 labeled and 72 unlabeled scans), as reported in Table 1. We observe consistent improvements over the state of the art in terms of Dice (87.41%) and Jaccard (77.80%). Meanwhile, the results in Table 1 suggest that, with NMS, CVRL achieves consistent performance gains. This evidences that: i) taking voxel samples with contrastive learning yields better voxel embeddings; ii) both global voxel-to-volume and local voxel-to-voxel relations are informative cues; and iii) utilizing an active contour model helps identify more accurate boundaries. Leveraging all these aspects, CVRL obtains consistent performance gains. As shown in Fig. 2, our method generates more accurate segmentations, which is notable given that improvements in this setting are difficult to obtain. This demonstrates i) the necessity of comprehensively considering both global voxel-to-volume contrast and local voxel-to-voxel contrast, and ii) the efficacy of the elastica (curvature and length) and region constraints.

Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel]
Baseline | 8 | 72 | 83.09 | 71.75 | 5.53 | 19.65
Baseline + $\mathcal{L}_{global}$ | 8 | 72 | 85.56 | 75.05 | 4.43 | 16.81
Baseline + $\mathcal{L}_{local}$ | 8 | 72 | 86.41 | 76.35 | 4.40 | 16.93
Baseline + $\mathcal{L}_{ace}$ | 8 | 72 | 85.77 | 75.39 | 5.17 | 18.98
Baseline + $\mathcal{L}_{global}$ + $\mathcal{L}_{local}$ | 8 | 72 | 87.23 | 77.53 | 3.50 | 13.61
Baseline + $\mathcal{L}_{global}$ + $\mathcal{L}_{local}$ + $\mathcal{L}_{ace}$ | 8 | 72 | 87.41 | 77.80 | 3.31 | 13.40
Table 2: Ablation study for the key component modules of CVRL on the LA dataset with 10% annotation ratio (8 labeled and 72 unlabeled).

3.0.4 Ablation Study

We perform ablation experiments to validate the effectiveness of the major components of our proposed method: the global contrastive strategy (global projection head), the local contrastive strategy, and the ACE loss. The quantitative results are reported in Table 2. We compare CVRL with its five variants under a 10% annotation ratio (8 labeled and 72 unlabeled scans). Specifically, the Baseline model refers to UA-MT [29] without uncertainty measures. We gradually incorporate $\mathcal{L}_{global}$, $\mathcal{L}_{local}$, and $\mathcal{L}_{ace}$, yielding the variants Baseline + $\mathcal{L}_{global}$, Baseline + $\mathcal{L}_{local}$, Baseline + $\mathcal{L}_{ace}$, Baseline + $\mathcal{L}_{global}$ + $\mathcal{L}_{local}$, and Baseline + $\mathcal{L}_{global}$ + $\mathcal{L}_{local}$ + $\mathcal{L}_{ace}$ (CVRL). As shown in the table, the Baseline network achieves 83.09%, 71.75%, 5.53, and 19.65 in terms of Dice, Jaccard, ASD, and 95HD, respectively. With the progressive introduction of $\mathcal{L}_{global}$, $\mathcal{L}_{local}$, and $\mathcal{L}_{ace}$, our proposed algorithm enjoys consistent gains over the Baseline network, boosting Dice and Jaccard by 4.32% and 6.05%, respectively, while reducing ASD and 95HD by 2.22 and 6.25, respectively. This further validates the effectiveness of each key component.

4 Conclusion

In this work, we propose CVRL, a semi-supervised contrastive representation learning framework that leverages global and local cues to learn voxel-wise representations for volumetric medical image segmentation. Specifically, we use global and local contrastive learning to exploit complex relations among training voxels. To further constrain the segmentation process, we use the ACE loss to capture more geometric information. Experimental results demonstrate that our model yields state-of-the-art performance while generating more accurate boundaries with very limited annotations.

References

  • [1] Bai, W., Chen, C., Tarroni, G., Duan, J., Guitton, F., Petersen, S.E., Guo, Y., Matthews, P.M., Rueckert, D.: Self-supervised learning for cardiac mr image segmentation by anatomical position prediction. In: MICCAI. pp. 541–549. Springer (2019)
  • [2] Bortsova, G., Dubost, F., Hogeweg, L., Katramados, I., de Bruijne, M.: Semi-supervised medical image segmentation via learning consistency under transformations. In: MICCAI. pp. 810–818. Springer (2019)
  • [3] Chaitanya, K., Erdil, E., Karani, N., Konukoglu, E.: Contrastive learning of global and local features for medical image segmentation with limited annotations. In: NeurIPS (2020)
  • [4] Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: ICML. pp. 1597–1607. PMLR (2020)
  • [5] Chen, X., Luo, X., Zhao, Y., Zhang, S., Wang, G., Zheng, Y.: Learning euler’s elastica model for medical image segmentation. arXiv preprint arXiv:2011.00526 (2020)
  • [6] Chen, X., Williams, B.M., Vallabhaneni, S.R., Czanner, G., Williams, R., Zheng, Y.: Learning active contour models for medical image segmentation. In: CVPR. pp. 11632–11640 (2019)
  • [7] Doersch, C., Gupta, A., Efros, A.A.: Unsupervised visual representation learning by context prediction. In: ICCV. pp. 1422–1430 (2015)
  • [8] Hadsell, R., Chopra, S., LeCun, Y.: Dimensionality reduction by learning an invariant mapping. In: CVPR. vol. 2, pp. 1735–1742. IEEE (2006)
  • [9] Hang, W., Feng, W., Liang, S., Yu, L., Wang, Q., Choi, K.S., Qin, J.: Local and global structure-aware entropy regularized mean teacher model for 3d left atrium segmentation. In: MICCAI. pp. 562–571. Springer (2020)
  • [10] Hatamizadeh, A., Sengupta, D., Terzopoulos, D.: End-to-end trainable deep active contour models for automated image segmentation: Delineating buildings in aerial imagery. In: ECCV. pp. 730–746. Springer (2020)
  • [11] He, Y., Yang, G., Chen, Y., Kong, Y., Wu, J., Tang, L., Zhu, X., Dillenseger, J.L., Shao, P., Zhang, S., et al.: Dpa-densebiasnet: Semi-supervised 3d fine renal artery segmentation with dense biased network and deep priori anatomy. In: MICCAI. pp. 139–147. Springer (2019)
  • [12] Hjelm, R.D., Fedorov, A., Lavoie-Marchildon, S., Grewal, K., Bachman, P., Trischler, A., Bengio, Y.: Learning deep representations by mutual information estimation and maximization. In: ICLR (2019)

  • [13] Kass, M., Witkin, A., Terzopoulos, D.: Snakes: Active contour models. International Journal of Computer Vision 1(4), 321–331 (1988)
  • [14] Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242 (2016)
  • [15] Li, S., Zhang, C., He, X.: Shape-aware semi-supervised 3d semantic segmentation for medical images. In: MICCAI. pp. 552–561. Springer (2020)
  • [16] Li, X., Yu, L., Chen, H., Fu, C.W., Heng, P.A.: Semi-supervised skin lesion segmentation via transformation consistent self-ensembling model. arXiv preprint arXiv:1808.03887 (2018)
  • [17] Ma, J., He, J., Yang, X.: Learning geodesic active contours for embedding object global information in segmentation cnns. IEEE Transactions on Medical Imaging (2020)
  • [18] Milletari, F., Navab, N., Ahmadi, S.A.: V-net: Fully convolutional neural networks for volumetric medical image segmentation. In: 3DV. pp. 565–571. IEEE (2016)

  • [19] Misra, I., Maaten, L.v.d.: Self-supervised learning of pretext-invariant representations. In: CVPR. pp. 6707–6717 (2020)
  • [20] Nie, D., Gao, Y., Wang, L., Shen, D.: Asdnet: Attention based semi-supervised deep networks for medical image segmentation. In: MICCAI. pp. 370–378. Springer (2018)
  • [21] Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: ECCV. pp. 69–84. Springer (2016)
  • [22] Oord, A.v.d., Li, Y., Vinyals, O.: Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748 (2018)
  • [23] Taleb, A., Loetzsch, W., Danz, N., Severin, J., Gaertner, T., Bergner, B., Lippert, C.: 3d self-supervised methods for medical imaging. In: NeurIPS. pp. 18158–18172 (2020)
  • [24] Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In: NeurIPS. pp. 1195–1204 (2017)

  • [25] Tian, Y., Krishnan, D., Isola, P.: Contrastive multiview coding. arXiv preprint arXiv:1906.05849 (2019)
  • [26] Wang, X., Gupta, A.: Unsupervised learning of visual representations using videos. In: ICCV. pp. 2794–2802 (2015)
  • [27] Wu, Z., Xiong, Y., Stella, X.Y., Lin, D.: Unsupervised feature learning via non-parametric instance discrimination. In: CVPR (2018)
  • [28] Xue, Y., Tang, H., Qiao, Z., Gong, G., Yin, Y., Qian, Z., Huang, C., Fan, W., Huang, X.: Shape-aware organ segmentation by predicting signed distance maps. In: AAAI. vol. 34, pp. 12565–12572 (2020)
  • [29] Yu, L., Wang, S., Li, X., Fu, C.W., Heng, P.A.: Uncertainty-aware self-ensembling model for semi-supervised 3d left atrium segmentation. In: MICCAI. pp. 605–613. Springer (2019)
  • [30] Zhang, M., Dong, B., Li, Q.: Deep active contour network for medical image segmentation. In: MICCAI. pp. 321–331. Springer (2020)
  • [31] Zhang, Y., Yang, L., Chen, J., Fredericksen, M., Hughes, D.P., Chen, D.Z.: Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In: MICCAI. pp. 408–416. Springer (2017)
  • [32] Zheng, H., Lin, L., Hu, H., Zhang, Q., Chen, Q., Iwamoto, Y., Han, X., Chen, Y.W., Tong, R., Wu, J.: Semi-supervised segmentation of liver using adversarial learning with deep atlas prior. In: MICCAI. pp. 148–156. Springer (2019)
  • [33] Zhuang, X., Li, Y., Hu, Y., Ma, K., Yang, Y., Zheng, Y.: Self-supervised feature learning for 3d medical images by playing a rubik’s cube. In: MICCAI. pp. 420–428. Springer (2019)