
Shape-aware Semi-supervised 3D Semantic Segmentation for Medical Images

Semi-supervised learning has attracted much attention in medical image segmentation due to the challenge of acquiring pixel-wise image annotations, which is a crucial step for building high-performance deep learning methods. Most existing semi-supervised segmentation approaches either tend to neglect geometric constraints on object segments, leading to incomplete object coverage, or impose a strong shape prior that requires extra alignment. In this work, we propose a novel shape-aware semi-supervised segmentation strategy to leverage abundant unlabeled data and to enforce a geometric shape constraint on the segmentation output. To achieve this, we develop a multi-task deep network that jointly predicts semantic segmentation and the signed distance map (SDM) of object surfaces. During training, we introduce an adversarial loss between the predicted SDMs of labeled and unlabeled data so that our network is able to capture shape-aware features more effectively. Experiments on the Atrial Segmentation Challenge dataset show that our method outperforms current state-of-the-art approaches with improved shape estimation, which validates its efficacy. Code is available at https://github.com/kleinzcy/SASSnet.


1 Introduction

Semantic object segmentation is a fundamental task in medical image analysis and has been widely used in the automatic delineation of regions of interest in 3D medical images, such as cells, tissues or organs. Recently, tremendous progress has been made in medical semantic segmentation [15] thanks to modern deep convolutional networks, which achieve state-of-the-art performance in many real-world tasks. However, training deep neural networks often requires a large amount of annotated data, which is particularly expensive to obtain in medical segmentation problems. In order to reduce labeling cost, a promising approach is to adopt a semi-supervised learning framework [1, 2] that typically utilizes a small labeled dataset together with many unlabeled images for effective model training.

Recent efforts in semi-supervised segmentation have focused on incorporating unlabeled data into convolutional network training, and can be largely categorized into two groups. The first group mainly considers the generic setting of semi-supervised segmentation [19, 7, 11, 20, 8, 16, 18, 3, 9]. Most of these methods adopt adversarial learning or a consistency loss as regularization in order to leverage unlabeled data for model learning. The adversarial learning methods [19, 7, 11, 20] enforce the distributions of segmentations of unlabeled and labeled images to be close, while the consistency loss approaches [8, 16, 18, 3, 9] utilize a teacher-student network design and require the two outputs to be consistent under random perturbations or transformations of the input images. To cope with difficult regions, Nie et al. [11] utilize adversarial learning to select high-confidence regions of unlabeled images to train the segmentation network. Yu et al. [18] introduce an uncertainty map based on the mean-teacher framework [16] to guide student network learning. Despite their promising results, these methods lack explicit modeling of the geometric prior of semantic objects, often leading to poor object coverage and/or boundary prediction.

The second group of semi-supervised methods attempts to address the above drawback by incorporating a strong anatomical prior on the object of interest into model learning [20, 6]. For instance, Zheng et al. [20] introduce the Deep Atlas Prior (DAP) model, which encodes a probabilistic shape prior in its loss design. He et al. [6] propose an auto-encoder to learn prior anatomical features from an unlabeled dataset. However, such priors typically assume properly aligned input images, which is difficult to achieve in practice for objects with large variations in pose or shape.

In this work, we propose a novel shape-aware semi-supervised segmentation strategy to address the aforementioned limitations. Our main idea is to incorporate a more flexible geometric representation into the network so that we can enforce a global shape constraint on the segmentation output while handling objects with varying poses or shapes. Such a "shape-aware" representation enables us to capture the global shape of each object class more effectively. Moreover, by exploiting the consistency of the geometric representations between labeled and unlabeled images, we aim to design a simple yet effective semi-supervised learning strategy for deep segmentation networks.

To achieve this, we develop a multi-task deep network that jointly predicts semantic segmentation and a signed distance map (SDM) [13, 4, 12, 17] with a shared backbone network module. The SDM assigns each voxel a value indicating its signed distance to the nearest boundary of the target object, providing a shape-aware representation that encodes richer features of object shape and surface. To utilize the unlabeled data, we then introduce an adversarial loss between the predicted SDMs of labeled and unlabeled data for semi-supervised learning. This allows the model to learn shape-aware features more effectively by enforcing similar distance map distributions on the entire dataset. In addition, the SDM naturally imposes more weight on the interior region of each semantic class, and can thus be viewed as a proxy for a confidence measure. In essence, we introduce an implicit shape prior and its regularization via an adversarial loss for semi-supervised volumetric segmentation.

We evaluate our approach on the Atrial Segmentation Challenge dataset with extensive comparisons to prior arts. The results demonstrate that our segmentation network outperforms the state-of-the-art methods and generates object segmentation with high-quality global shapes.

Our main contributions are threefold: (1) We propose a novel shape-aware semi-supervised segmentation approach that enforces geometric constraints on labeled and unlabeled data. (2) We develop a multi-task loss on segmentation and SDM predictions, and impose global consistency in object shapes through adversarial learning. (3) Our method achieves strong performance on the Atrial Segmentation Challenge dataset with only a small amount of labeled data.

2 Method

Figure 1: Overview of our method. Our network takes as input a 3D volume, and predicts a 3D SDM and a segmentation map. Our learning loss consists of a multi-task supervised term and an adversarial loss on the SDM predictions.

2.1 Overview

We aim to build a deep neural network for medical image segmentation in a semi-supervised setting in order to reduce annotation cost. Given the scarcity of annotated images, our key challenge is to effectively regularize network learning using a set of unlabeled images. In this paper, we tackle this problem by utilizing the regularity in geometric shapes of the target object class, which provides an effective constraint for both segment prediction and network learning.

Specifically, we propose to incorporate a shape-aware representation of object segments into the deep network prediction. In particular, we develop a multi-task segmentation network that takes a 3D image as input and jointly predicts a segmentation map and an SDM of the object segmentation. Based on this SDM representation, we then design a semi-supervised learning loss for training the segmentation network. Our loss consists of two main components: one supervises the network predictions on the labeled set, while the other enforces consistency between the SDM predictions on the labeled and unlabeled sets. To achieve an effective consistency constraint, we adopt an adversarial loss that encourages the segmentation network to produce predictions with similar distributions on both datasets. Figure 1 illustrates the overall pipeline of our semi-supervised segmentation network. Below we introduce the detailed model design in Section 2.2, followed by the learning loss and network training in Section 2.3.

2.2 Segmentation Network

In order to encode the geometric shape of a target semantic class, we propose a multi-task segmentation network that jointly predicts a 3D object mask and its SDM for the input 3D volume. Our network has a V-Net [10] structure that consists of an encoder module and a decoder module with two output branches, one for the segmentation map and the other for the SDM. For notational clarity, we focus on the single-class setting below. (It is straightforward to generalize our formulation to the multi-class setting by treating each semantic class separately for SDMs.)

Specifically, we employ a V-Net backbone as in [18], and then add a light-weight SDM head in parallel with the original segmentation head. Our SDM head is composed of a 3D convolution block followed by a tanh activation. Given an input volume $X$, the segmentation head generates a confidence score map $M$ and the SDM head predicts an SDM $S$ as follows:

$$M = f_{seg}(X; \theta), \qquad S = f_{sdm}(X; \theta), \qquad (1)$$

where $\theta$ denotes the parameters of our segmentation network, and each element of $S$ indicates the signed distance of the corresponding voxel to its closest surface point after normalization [17].
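As a concrete illustration, the following is a minimal PyTorch sketch of the two output heads sharing a decoder feature map. The class name, channel sizes, and block layout are hypothetical stand-ins rather than the released SASSNet implementation; the tanh activation keeps the SDM output in [-1, 1], matching the normalized ground-truth SDMs.

```python
import torch
import torch.nn as nn

class DualHead(nn.Module):
    """Hypothetical sketch of the two output branches on a shared V-Net decoder.
    `in_ch` is the channel width of the final decoder feature map."""
    def __init__(self, in_ch=16, num_classes=1):
        super().__init__()
        # original segmentation head: per-voxel class scores
        self.seg_head = nn.Conv3d(in_ch, num_classes, kernel_size=1)
        # light-weight SDM head: a 3D convolution block followed by tanh,
        # so predictions lie in [-1, 1] like the normalized ground-truth SDM
        self.sdm_head = nn.Sequential(
            nn.Conv3d(in_ch, num_classes, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, feat):
        m = torch.sigmoid(self.seg_head(feat))  # confidence score map M
        s = self.sdm_head(feat)                 # signed distance map S
        return m, s
```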

2.3 Shape-aware Semi-supervised Learning

We now introduce our semi-supervised learning strategy for the segmentation network. While prior methods typically rely on the segmentation output $M$, we instead utilize the shape-aware representation $S$ to regularize the network training. To this end, we develop a multi-task loss consisting of a supervised loss on the labeled set and an adversarial loss on the entire dataset that enforces consistency of the model predictions.

Formally, we assume a standard semi-supervised learning setting in which the training set contains $N$ labeled and $U$ unlabeled examples, where $N \ll U$. We denote the labeled set as $\mathcal{D}_L = \{(X_i, Y_i, Z_i)\}_{i=1}^{N}$ and the unlabeled set as $\mathcal{D}_U = \{X_i\}_{i=N+1}^{N+U}$, where $X_i$ are the input volumes, $Y_i$ are the segmentation annotations, and $Z_i$ are the ground-truth SDMs derived from $Y_i$ (a sketch of this derivation is given below). We first describe the supervised loss on $\mathcal{D}_L$, followed by the adversarial loss that utilizes the unlabeled set $\mathcal{D}_U$.
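For reference, the ground-truth SDM of a binary mask is commonly computed with Euclidean distance transforms and then normalized, in the spirit of [17]. The sketch below is one plausible implementation under assumed conventions (negative inside, positive outside, per-side max normalization); the authors' exact procedure may differ.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_sdm(mask: np.ndarray) -> np.ndarray:
    """Normalized signed distance map of a binary 3D mask.
    Assumed convention: negative inside the object, positive outside,
    each side scaled into [-1, 1] by its maximum distance."""
    mask = mask.astype(bool)
    if not mask.any():  # no foreground: treat everything as far outside
        return np.ones(mask.shape, dtype=np.float32)
    dist_out = distance_transform_edt(~mask)  # distance of background voxels to the object
    dist_in = distance_transform_edt(mask)    # distance of foreground voxels to the background
    sdm = dist_out / (dist_out.max() + 1e-8) - dist_in / (dist_in.max() + 1e-8)
    return sdm.astype(np.float32)
```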

2.3.1 Supervised Loss

On the labeled set, we employ a dice loss and a mean squared error loss for the segmentation and SDM outputs of the multi-task segmentation network, respectively:

$$\mathcal{L}_{seg}(\theta) = \sum_{i=1}^{N} \ell_{dice}\big(f_{seg}(X_i; \theta),\, Y_i\big), \qquad (2)$$

$$\mathcal{L}_{sdm}(\theta) = \sum_{i=1}^{N} \big\| f_{sdm}(X_i; \theta) - Z_i \big\|_2^2, \qquad (3)$$

where $\mathcal{L}_{seg}$ denotes the segmentation loss and $\mathcal{L}_{sdm}$ is the SDM loss. The total supervised loss is $\mathcal{L}_s = \mathcal{L}_{seg} + \alpha \mathcal{L}_{sdm}$, where $\alpha$ is a weighting coefficient balancing the two loss terms.
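A minimal PyTorch sketch of this supervised loss is given below; the soft dice formulation and variable names are illustrative, and alpha = 0.3 follows the setting reported in Sec. 3.

```python
import torch

def dice_loss(pred, target, eps=1e-5):
    """Soft dice loss between a [0, 1] prediction map and a binary target."""
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def supervised_loss(m, s, y, z, alpha=0.3):
    """L_s = l_dice(M, Y) + alpha * MSE(S, Z) on a labeled batch,
    where m/s are the predicted map and SDM, y/z the ground truths."""
    return dice_loss(m, y) + alpha * torch.mean((s - z) ** 2)
```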

2.3.2 Adversarial Loss

To regularize model learning with the unlabeled data, we introduce an adversarial loss that enforces consistency between the SDM predictions on the labeled and unlabeled sets. To this end, we propose a discriminator network to tell apart the predicted SDMs of the labeled set, which should be of high quality thanks to the supervision, from those of the unlabeled set. Minimizing the adversarial loss induced by this discriminator enables us to learn effective shape-aware features that generalize well to the unlabeled dataset.

Specifically, we adopt a discriminator network similar to [14], which consists of 5 convolution layers followed by an MLP. The network takes an SDM and the corresponding input volume as input, fuses them through its convolution layers, and predicts the probability of the pair belonging to the labeled data. Denoting the discriminator as $D$ with parameters $\zeta$, we define the adversarial loss as follows:

$$\mathcal{L}_{adv}(\theta, \zeta) = \sum_{i=1}^{N} \log D(X_i, S_i; \zeta) + \sum_{i=N+1}^{N+U} \log\big(1 - D(X_i, S_i; \zeta)\big), \qquad (4)$$

where $S_i = f_{sdm}(X_i; \theta)$ are the predicted SDMs on the labeled and unlabeled sets, respectively.
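The sketch below is a rough stand-in for such a discriminator: five strided 3D convolutions fuse the volume and SDM, and an MLP outputs the probability of the pair coming from the labeled set. Channel widths, the global pooling, and layer sizes are assumptions made for illustration, not the exact DCGAN-style network of [14].

```python
import torch
import torch.nn as nn

class SDMDiscriminator(nn.Module):
    """Illustrative stand-in: 5 strided 3D conv layers fusing the input
    volume X and predicted SDM S, followed by an MLP that outputs the
    probability of the pair coming from the labeled set."""
    def __init__(self, in_ch=2, base=16):
        super().__init__()
        layers, ch = [], in_ch
        for i in range(5):  # 5 convolution layers as described
            out = base * 2 ** i
            layers += [nn.Conv3d(ch, out, kernel_size=4, stride=2, padding=1),
                       nn.LeakyReLU(0.2, inplace=True)]
            ch = out
        self.conv = nn.Sequential(*layers)
        # global pooling keeps the MLP input size independent of crop size
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.mlp = nn.Sequential(nn.Flatten(), nn.Linear(ch, 64),
                                 nn.LeakyReLU(0.2), nn.Linear(64, 1))

    def forward(self, x, s):
        h = self.conv(torch.cat([x, s], dim=1))  # fuse X and S channel-wise
        return torch.sigmoid(self.mlp(self.pool(h)))
```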

2.3.3 Overall Training Pipeline

Our overall training objective combines the supervised loss and the adversarial loss defined above, and the learning task can be written as

$$\min_{\theta} \max_{\zeta} \; \mathcal{L}_s(\theta) + \beta \mathcal{L}_{adv}(\theta, \zeta), \qquad (5)$$

where $\beta$ is a weight coefficient that balances the two loss terms. We adopt a standard alternating procedure to train the entire network, which consists of the following two subproblems.

Given a fixed discriminator $D$, we minimize the overall loss w.r.t. the segmentation network parameters $\theta$. To speed up model learning, we simplify the loss in two steps. First, we ignore the first loss term in Eqn (4), since the SDM predictions on the labeled set are of high quality, i.e., $D(X_i, S_i; \zeta) \approx 1$ for $i \le N$. Second, we adopt a surrogate loss for the generator as in [5], maximizing $\log D$ on the unlabeled predictions instead of minimizing $\log(1 - D)$. Hence the learning problem for the segmentation network can be written as

$$\min_{\theta} \; \mathcal{L}_s(\theta) - \beta \sum_{i=N+1}^{N+U} \log D\big(X_i, f_{sdm}(X_i; \theta); \zeta\big). \qquad (6)$$

On the other hand, given a fixed segmentation network, we simply minimize the binary cross-entropy loss induced by Eqn (5) to train the discriminator, i.e., $\max_{\zeta} \mathcal{L}_{adv}$, or equivalently $\min_{\zeta} -\mathcal{L}_{adv}$. To stabilize the overall training, we use an annealing strategy based on a time-dependent Gaussian warm-up function to slowly increase the loss weight $\beta$ (see Sec. 3 for details).
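A hedged sketch of one alternating training step follows. The alpha value mirrors Sec. 3; the warm-up constant beta_max, the function and variable names, and the batch handling are assumptions made for illustration.

```python
import math
import torch
import torch.nn.functional as F

def beta_schedule(t, t_max, beta_max=0.1):
    """Time-dependent Gaussian warm-up for the adversarial weight beta.
    beta_max is a hypothetical choice, not a value stated in the paper."""
    return beta_max * math.exp(-5.0 * (1.0 - t / t_max) ** 2)

def train_step(seg_net, disc, opt_seg, opt_disc, x_l, y_l, z_l, x_u, t, t_max):
    beta = beta_schedule(t, t_max)

    # Step 1: update the segmentation network with the discriminator fixed
    # (Eqn 6): supervised loss + surrogate generator loss -log D(X_u, S_u).
    m_l, s_l = seg_net(x_l)  # predictions on the labeled batch
    _, s_u = seg_net(x_u)    # SDM predictions on the unlabeled batch
    inter = (m_l * y_l).sum()
    loss_dice = 1.0 - (2.0 * inter + 1e-5) / (m_l.sum() + y_l.sum() + 1e-5)
    loss_s = loss_dice + 0.3 * F.mse_loss(s_l, z_l)  # alpha = 0.3 (Sec. 3)
    loss_g = -torch.log(disc(x_u, s_u) + 1e-8).mean()
    opt_seg.zero_grad()
    (loss_s + beta * loss_g).backward()
    opt_seg.step()

    # Step 2: update the discriminator with the segmentation network fixed:
    # labeled SDMs are "real" (label 1), unlabeled SDMs are "fake" (label 0).
    with torch.no_grad():
        _, s_l = seg_net(x_l)
        _, s_u = seg_net(x_u)
    p_l, p_u = disc(x_l, s_l), disc(x_u, s_u)
    loss_d = F.binary_cross_entropy(p_l, torch.ones_like(p_l)) \
           + F.binary_cross_entropy(p_u, torch.zeros_like(p_u))
    opt_disc.zero_grad()
    loss_d.backward()
    opt_disc.step()
```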

3 Experiments and Results

Method | Labeled | Unlabeled | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel]
V-Net | 80 | 0 | 91.14 | 83.82 | 1.52 | 5.75
V-Net | 16 | 0 | 86.03 | 76.06 | 3.51 | 14.26
DAP [20] | 16 | 64 | 87.89 | 78.72 | 2.74 | 9.29
ASDNet [11] | 16 | 64 | 87.90 | 78.85 | 2.08 | 9.24
TCSE [9] | 16 | 64 | 88.15 | 79.20 | 2.44 | 9.57
UA-MT [18] | 16 | 64 | 88.88 | 80.21 | 2.26 | 7.32
UA-MT (+NMS) | 16 | 64 | 89.11 | 80.62 | 2.21 | 7.30
SASSNet (ours) | 16 | 64 | 89.27 | 80.82 | 3.13 | 8.83
SASSNet (+NMS) | 16 | 64 | 89.54 | 81.24 | 2.20 | 8.24
V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11
DAP [20] | 8 | 72 | 81.89 | 71.23 | 3.80 | 15.81
UA-MT [18] | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84
UA-MT (+NMS) | 8 | 72 | 84.57 | 73.96 | 2.90 | 12.51
SASSNet (ours) | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54
SASSNet (+NMS) | 8 | 72 | 87.32 | 77.72 | 2.55 | 9.62
Table 1: Quantitative comparison of semi-supervised segmentation models on the LA dataset. "Labeled" and "Unlabeled" give the number of scans used. All models use V-Net as the backbone network. Results on two different data partition settings show that our SASSNet consistently outperforms the state-of-the-art methods.

We validate our method on the Left Atrium (LA) dataset from the Atrial Segmentation Challenge (http://atriaseg2018.cardiacatlas.org/) with detailed comparisons to prior arts. The dataset contains 100 3D gadolinium-enhanced MR imaging scans (GE-MRIs) with LA segmentation masks, at an isotropic resolution of 0.625 × 0.625 × 0.625 mm³. Following [18], we split them into 80 scans for training and 20 scans for validation, and apply the same pre-processing methods.

3.0.1 Implementation Details and Metrics.

The segmentation network is trained with an SGD optimizer for 6000 iterations, with an initial learning rate (lr) of 0.01 decayed by 0.1 every 2500 iterations. The discriminator uses 4×4×4 kernels with stride 2 in its convolutional layers and is trained with an Adam optimizer with a constant lr of 0.0001. We use a batch size of 4 images and a single GPU with 12GB RAM for model training. In all our experiments, we set α to 0.3 and β to a time-dependent Gaussian warm-up function β(t) = β_max · exp(−5(1 − t/t_max)²), where t indicates the number of iterations.

During testing, we take the segmentation map output M for evaluation. In addition, non-maximum suppression (NMS) is applied as a post-processing step to remove isolated extraneous regions. We use standard evaluation metrics, including the Dice coefficient (Dice), Jaccard index (Jaccard), 95% Hausdorff distance (95HD), and average symmetric surface distance (ASD).
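The NMS step here amounts to discarding isolated false-positive blobs; one common realization keeps only the largest connected foreground component, as in the sketch below (the authors' exact post-processing may differ).

```python
import numpy as np
from scipy.ndimage import label

def keep_largest_component(mask: np.ndarray) -> np.ndarray:
    """Remove isolated extraneous regions by keeping only the largest
    connected foreground component of a binary 3D mask."""
    labeled, num = label(mask)
    if num == 0:
        return mask
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # ignore the background label
    return (labeled == sizes.argmax()).astype(mask.dtype)
```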

Figure 2: 2D (a) and 3D (b) visualizations of the segmentations produced by UA-MT [18] and our method, where GT denotes the ground-truth segmentation.

3.0.2 Quantitative Evaluation and Comparison.

We evaluate our method in two different settings with comparisons to several recent semi-supervised segmentation approaches, including DAP [20], ASDNet [11], TCSE [9] and UA-MT [18]. Table 1 summarizes the quantitative results, in which we first show the upper-bound performance achieved by a fully-supervised network, followed by the two semi-supervised settings.

The first setting follows [18] and takes 20% of the training data (16 scans) as labeled data, using the rest as unlabeled data for semi-supervised training. This setting is relatively easy, as a model trained with only the labeled 20% already achieves good performance (86.03% in Dice). Among the semi-supervised methods, DAP performs worst, indicating the limitation of an atlas-based prior, while UA-MT achieves the top performance among previous methods. Our method outperforms all other semi-supervised networks in both Dice (89.54%) and Jaccard (81.24%), and achieves competitive results on the other metrics. In particular, our SASSNet surpasses UA-MT in Dice without resorting to a complex multi-network architecture.

To validate the robustness of our method, we also consider a more challenging setting in which only 8 labeled images are available for training. The second half of Table 1 shows the comparison results, where SASSNet outperforms UA-MT by a large margin (Dice: +2.56% without NMS and +3.07% with NMS). Without NMS, our SASSNet tends to generate more foreground regions, which leads to slightly worse ASD and 95HD scores; however, it also produces better segmentations that preserve the original object shape. By contrast, UA-MT often misses inner regions of target objects and generates irregular shapes. Figure 2 provides several qualitative results for visual comparison.

Method | Labeled | Unlabeled | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] | Params [M]
V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11 | 187.7
V-Net +SDM | 8 | 0 | 81.12 | 69.75 | 6.93 | 25.58 | 187.9
V-Net +SDM +GAN | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54 | 249.7
UA-MT [18] | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84 | 375.5
V-Net +SDM +MT | 8 | 72 | 84.97 | 74.14 | 6.12 | 22.20 | 375.8
Table 2: Effectiveness of our proposed modules on the LA dataset. "Labeled" and "Unlabeled" give the number of scans used. All models use the same V-Net backbone; we conduct an ablation study to show the contribution of each component module.

3.0.3 Ablative Study.

We conduct several detailed experimental studies to examine the effectiveness of our proposed SDM head and adversarial loss (GAN). Table 2 shows the quantitative results of the different model settings. The first row is a V-Net trained with only the labeled data, which serves as our base model. We first add an SDM head, denoted V-Net +SDM; as shown in the second row, this joint learning improves segmentation by 1.1% in Dice. We then add the unlabeled data and our adversarial loss, denoted V-Net +SDM +GAN, which significantly improves performance (by 5.7% in Dice).

We also compare our semi-supervised learning strategy with two methods based on the mean-teacher (MT) framework (last two rows): the original UA-MT, and our segmentation network trained with the MT consistency loss. Our SASSNet outperforms both methods with higher Dice and Jaccard scores, which indicates the advantage of our representation and loss design. Moreover, our network has a much simpler architecture than either of them.

4 Conclusion

In this paper, we proposed a shape-aware semi-supervised segmentation approach for 3D medical scans. In contrast to previous methods, our method exploits the regularity in geometric shapes of the target object class for effective segment prediction and network learning. We developed a multi-task segmentation network that jointly predicts the semantic segmentation and the SDM of object surfaces, together with a semi-supervised learning loss that enforces consistency between the predicted SDMs of labeled and unlabeled data. We validated our approach on the Atrial Segmentation Challenge dataset; the results demonstrate that our segmentation network outperforms state-of-the-art methods and generates object segmentations with high-quality global shapes.

References

  • [1] W. Bai, O. Oktay, M. Sinclair, H. Suzuki, M. Rajchl, G. Tarroni, B. Glocker, A. King, P. M. Matthews, and D. Rueckert (2017) Semi-supervised learning for network-based cardiac MR image segmentation. In MICCAI, pp. 253–260.
  • [2] C. Baur, S. Albarqouni, and N. Navab (2017) Semi-supervised deep learning for fully convolutional networks. In MICCAI, pp. 311–319.
  • [3] G. Bortsova, F. Dubost, L. Hogeweg, I. Katramados, and M. de Bruijne (2019) Semi-supervised medical image segmentation via learning consistency under transformations. In MICCAI, pp. 810–818.
  • [4] S. Dangi, C. A. Linte, and Z. Yaniv (2019) A distance map regularized CNN for cardiac cine MR image segmentation. Medical Physics 46(12), pp. 5637–5651.
  • [5] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial nets. In NIPS, pp. 2672–2680.
  • [6] Y. He, G. Yang, Y. Chen, Y. Kong, J. Wu, L. Tang, X. Zhu, J. Dillenseger, P. Shao, S. Zhang, et al. (2019) DPA-DenseBiasNet: semi-supervised 3D fine renal artery segmentation with dense biased network and deep priori anatomy. In MICCAI, pp. 139–147.
  • [7] W. C. Hung, Y. H. Tsai, Y. T. Liou, Y. Y. Lin, and M. H. Yang (2018) Adversarial learning for semi-supervised semantic segmentation. In BMVC.
  • [8] S. Laine and T. Aila (2016) Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242.
  • [9] X. Li, L. Yu, H. Chen, C. Fu, and P. Heng (2018) Semi-supervised skin lesion segmentation via transformation consistent self-ensembling model. arXiv preprint arXiv:1808.03887.
  • [10] F. Milletari, N. Navab, and S. Ahmadi (2016) V-Net: fully convolutional neural networks for volumetric medical image segmentation. In 3DV, pp. 565–571.
  • [11] D. Nie, Y. Gao, L. Wang, and D. Shen (2018) ASDNet: attention based semi-supervised deep networks for medical image segmentation. In MICCAI, pp. 370–378.
  • [12] J. J. Park, P. Florence, J. Straub, R. Newcombe, and S. Lovegrove (2019) DeepSDF: learning continuous signed distance functions for shape representation. In CVPR, pp. 165–174.
  • [13] S. Perera, N. Barnes, X. He, S. Izadi, P. Kohli, and B. Glocker (2015) Motion segmentation of truncated signed distance function based volumetric surfaces. In WACV, pp. 1046–1053.
  • [14] A. Radford, L. Metz, and S. Chintala (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
  • [15] S. A. Taghanaki, K. Abhishek, J. P. Cohen, J. Cohen-Adad, and G. Hamarneh (2019) Deep semantic segmentation of natural and medical images: a review. arXiv preprint arXiv:1910.07655.
  • [16] A. Tarvainen and H. Valpola (2017) Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. In NIPS, pp. 1195–1204.
  • [17] Y. Xue, H. Tang, Z. Qiao, G. Gong, Y. Yin, Z. Qian, C. Huang, W. Fan, and X. Huang (2019) Shape-aware organ segmentation by predicting signed distance maps. arXiv preprint arXiv:1912.03849.
  • [18] L. Yu, S. Wang, X. Li, C. Fu, and P. Heng (2019) Uncertainty-aware self-ensembling model for semi-supervised 3D left atrium segmentation. In MICCAI, pp. 605–613.
  • [19] Y. Zhang, L. Yang, J. Chen, M. Fredericksen, D. P. Hughes, and D. Z. Chen (2017) Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In MICCAI, pp. 408–416.
  • [20] H. Zheng, L. Lin, H. Hu, Q. Zhang, Q. Chen, Y. Iwamoto, X. Han, Y. Chen, R. Tong, and J. Wu (2019) Semi-supervised segmentation of liver using adversarial learning with deep atlas prior. In MICCAI, pp. 148–156.