1 Introduction
Semantic object segmentation is a fundamental task in medical image analysis and has been widely used for the automatic delineation of regions of interest in 3D medical images, such as cells, tissues or organs. Recently, tremendous progress has been made in medical semantic segmentation [15] thanks to modern deep convolutional networks, which achieve state-of-the-art performance in many real-world tasks. However, training deep neural networks often requires a large amount of annotated data, which is particularly expensive to obtain for medical segmentation problems. To reduce labeling cost, a promising approach is to adopt a semi-supervised learning [1, 2] framework, which typically uses a small labeled dataset together with many unlabeled images for effective model training.
Recent efforts in semi-supervised segmentation have focused on incorporating unlabeled data into convolutional network training, and can be broadly categorized into two groups. The first group mainly considers the generic setting of semi-supervised segmentation [19, 7, 11, 20, 8, 16, 18, 3, 9]. Most of these methods adopt adversarial learning or a consistency loss as regularization in order to leverage unlabeled data for model learning. The adversarial learning methods [19, 7, 11, 20] enforce the distributions of segmentations of unlabeled and labeled images to be close, while the consistency loss approaches [8, 16, 18, 3, 9] utilize a teacher-student network design and require the outputs to be consistent under random perturbations or transformations of the input images. To cope with difficult regions, Nie et al. [11] utilize adversarial learning to select regions of unlabeled images with high confidence to train the segmentation network. Yu et al. [18] introduce an uncertainty map based on the mean-teacher framework [16] to guide student network learning. Despite their promising results, these methods lack explicit modeling of the geometric prior of semantic objects, which often leads to poor object coverage and/or boundary prediction.
The second group of semi-supervised methods attempts to address this drawback by incorporating a strong anatomical prior on the object of interest into model learning [20, 6]. For instance, Zheng et al. [20] introduce the Deep Atlas Prior (DAP) model, which encodes a probabilistic shape prior in its loss design. He et al. [6] propose an auto-encoder that learns prior anatomical features from the unlabeled dataset. However, such priors typically assume properly aligned input images, which is difficult to achieve in practice for objects with large variations in pose or shape.
In this work, we propose a novel shape-aware semi-supervised segmentation strategy to address the aforementioned limitations. Our main idea is to incorporate a more flexible geometric representation into the network, so that we can enforce a global shape constraint on the segmentation output while still handling objects with varying poses or shapes. Such a “shape-aware” representation enables us to capture the global shape of each object class more effectively. Moreover, by exploiting the consistency of geometric representations between labeled and unlabeled images, we aim to design a simple yet effective semi-supervised learning strategy for deep segmentation networks.
To achieve this, we develop a multi-task deep network that jointly predicts a semantic segmentation and a signed distance map (SDM) [13, 4, 12, 17] with a shared backbone network module. The SDM assigns each voxel a value indicating its signed distance to the nearest boundary of the target object, which provides a shape-aware representation that encodes richer features of object shape and surface. To utilize the unlabeled data, we then introduce an adversarial loss between the predicted SDMs of labeled and unlabeled data for semi-supervised learning. This allows the model to learn shape-aware features more effectively by enforcing similar distance map distributions over the entire dataset. In addition, the SDM naturally places more weight on the interior region of each semantic class, which can be viewed as a proxy for a confidence measure. In essence, we introduce an implicit shape prior and its regularization, based on an adversarial loss, for semi-supervised volumetric segmentation.
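To make the SDM concrete, the following Python sketch computes a normalized SDM for a small 2D binary mask by brute force (the 3D case only adds an axis). It assumes a sign convention of negative inside the object, zero on the boundary and positive outside, with interior and exterior distances scaled separately into [-1, 0] and [0, 1], which is our reading of the normalization in [17]. The boundary definition (foreground pixels with a 4-connected background neighbour) and the function names are illustrative choices, not the authors' code.

```python
import numpy as np


def boundary_mask(mask: np.ndarray) -> np.ndarray:
    """Foreground pixels that have at least one 4-connected background neighbour."""
    fg = mask.astype(bool)
    p = np.pad(fg, 1, constant_values=False)
    # A pixel is interior only if its up, down, left and right neighbours are all foreground.
    interior = p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    return fg & ~interior


def normalized_sdm(mask: np.ndarray) -> np.ndarray:
    """Brute-force normalized signed distance map of a 2D binary mask (sketch).

    Assumed convention: negative inside, 0 on the boundary, positive outside.
    """
    fg = mask.astype(bool)
    bnd = boundary_mask(fg)
    sdm = np.zeros(fg.shape, dtype=float)
    if not bnd.any():                                  # empty mask: no surface to measure
        return sdm
    surface = np.argwhere(bnd)                         # (K, 2) boundary coordinates
    grid = np.indices(fg.shape).reshape(2, -1).T       # (P, 2) all pixel coordinates
    diff = grid[:, None, :] - surface[None, :, :]      # (P, K, 2) pairwise offsets
    dist = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1).reshape(fg.shape)
    inside, outside = fg & ~bnd, ~fg
    if inside.any():                                   # interior scaled to [-1, 0)
        sdm[inside] = -dist[inside] / dist[inside].max()
    if outside.any():                                  # exterior scaled to (0, 1]
        sdm[outside] = dist[outside] / dist[outside].max()
    return sdm                                         # boundary pixels stay exactly 0


mask = np.zeros((7, 7), dtype=np.uint8)
mask[2:5, 2:5] = 1                                     # a 3x3 square object
sdm = normalized_sdm(mask)
print(np.round(sdm, 2))
```

The deepest interior pixel receives the most negative value, which is the sense in which the SDM "weights" the interior more heavily than the boundary. For real volumes, a Euclidean distance transform would replace the quadratic brute-force search.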
We evaluate our approach on the Atrial Segmentation Challenge dataset with extensive comparisons to prior art. The results demonstrate that our segmentation network outperforms state-of-the-art methods and generates object segmentations with high-quality global shapes.
Our main contributions are threefold: (1) we propose a novel shape-aware semi-supervised segmentation approach that enforces geometric constraints on labeled and unlabeled data; (2) we develop a multi-task loss on segmentation and SDM predictions, and impose global consistency in object shapes through adversarial learning; (3) our method achieves strong performance on the Atrial Segmentation Challenge dataset with only a small number of labeled scans.
2 Method
2.1 Overview
We aim to build a deep neural network for medical image segmentation in a semi-supervised setting in order to reduce annotation cost. Due to the lack of annotated images, our key challenge is to effectively regularize network learning with a set of unlabeled images. In this paper, we tackle this problem by utilizing the regularity in geometric shapes of the target object class, which provides an effective constraint for both segment prediction and network learning.
Specifically, we propose to incorporate a shape-aware representation of object segments into the deep network prediction. In particular, we develop a multi-task segmentation network that takes a 3D image as input and jointly predicts a segmentation map and an SDM of the target object. Based on this SDM representation, we then design a semi-supervised learning loss for training the segmentation network. Our loss mainly consists of two components: one supervising the network predictions on the labeled set, and the other enforcing consistency between the SDM predictions on the labeled and unlabeled sets. To achieve an effective consistency constraint, we adopt an adversarial loss that encourages the segmentation network to produce SDM predictions with similar distributions on both sets. Figure 1 illustrates the overall pipeline of our semi-supervised segmentation network. Below we introduce the detailed model design in Section 2.2, followed by the learning loss and network training in Section 2.3.
2.2 Segmentation Network
In order to encode the geometric shape of a target semantic class, we propose a multi-task segmentation network that jointly predicts a 3D object mask and its SDM for the input 3D volume. Our network has a V-Net [10] structure that consists of an encoder module and a decoder module with two output branches, one for the segmentation map and the other for the SDM. For notational clarity, we mainly focus on the single-class setting below (it is straightforward to generalize our formulation to the multi-class setting by treating each semantic class separately for SDMs).
Specifically, we employ a V-Net backbone as in [18] and add a lightweight SDM head in parallel with the original segmentation head. Our SDM head consists of a 3D convolution block followed by a tanh activation. Given an input image $X \in \mathbb{R}^{H\times W\times D}$, the segmentation head generates a confidence score map $S \in [0,1]^{H\times W\times D}$ and the SDM head predicts an SDM $T \in [-1,1]^{H\times W\times D}$ as follows:
$$S = f_{\mathrm{seg}}(X;\theta), \qquad T = f_{\mathrm{sdm}}(X;\theta), \tag{1}$$
where $\theta$ denotes the parameters of our segmentation network, and each element of $T$ indicates the signed distance of the corresponding voxel to its closest surface point after normalization [17].
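The two-head design in Eq. (1) can be illustrated with a toy NumPy sketch. A random shared feature volume stands in for the V-Net decoder output, and each head is reduced to a 1×1×1 convolution (a per-voxel linear map over channels) followed by its activation: a sigmoid for the single-class segmentation confidence map (our assumption; a two-channel softmax would be equivalent) and a tanh for the SDM. The shapes, channel count and random weights are hypothetical; the actual heads are 3D convolution blocks.

```python
import numpy as np

rng = np.random.default_rng(0)

C, H, W, D = 16, 8, 8, 4                      # hypothetical channel count and volume size
features = rng.normal(size=(C, H, W, D))      # stand-in for shared backbone features


def conv1x1x1(x, weight, bias):
    """Per-voxel linear map over channels, i.e. a 1x1x1 3D convolution."""
    return np.einsum("c,chwd->hwd", weight, x) + bias


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Two lightweight heads reading the same backbone features (hypothetical weights).
w_seg, b_seg = rng.normal(size=C), 0.0
w_sdm, b_sdm = rng.normal(size=C), 0.0

S = sigmoid(conv1x1x1(features, w_seg, b_seg))   # confidence map, values in [0, 1]
T = np.tanh(conv1x1x1(features, w_sdm, b_sdm))   # SDM prediction, values in [-1, 1]
print(S.shape, T.shape)
```

The point of the sketch is the output ranges: the tanh keeps the SDM head in the same [-1, 1] range as the normalized ground-truth SDMs, so the two can be compared directly by a regression loss.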
2.3 Shape-aware Semi-supervised Learning
We now introduce our semi-supervised learning strategy for the segmentation network. While prior methods typically rely on the segmentation output $S$, we instead utilize the shape-aware representation $T$ to regularize network training. To this end, we develop a multi-task loss consisting of a supervised loss on the labeled set and an adversarial loss on the entire set that enforces consistency of the model predictions.
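Formally, we assume a standard semi-supervised learning setting, in which the training set contains $N$ labeled and $M$ unlabeled volumes, where typically $N \ll M$. We denote the labeled set as $\mathcal{D}_l = \{(X_n, Y_n, T_n)\}_{n=1}^{N}$ and the unlabeled set as $\mathcal{D}_u = \{X_m\}_{m=N+1}^{N+M}$, where $X \in \mathbb{R}^{H\times W\times D}$ are the input volumes, $Y_n \in \{0,1\}^{H\times W\times D}$ are the segmentation annotations and $T_n$ are the ground-truth SDMs derived from $Y_n$. Below we first describe the supervised loss $\mathcal{L}_s$ on $\mathcal{D}_l$, followed by the adversarial loss $\mathcal{L}_a$ that utilizes the unlabeled set $\mathcal{D}_u$.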
2.3.1 Supervised Loss
On the labeled set, we employ a Dice loss and a mean squared error (MSE) loss for the segmentation and SDM outputs of the multi-task segmentation network, respectively:
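$$\mathcal{L}_s = \mathcal{L}_{\mathrm{seg}} + \alpha\,\mathcal{L}_{\mathrm{sdm}}, \tag{2}$$
$$\mathcal{L}_{\mathrm{seg}} = \sum_{n=1}^{N} \mathcal{L}_{\mathrm{dice}}\big(f_{\mathrm{seg}}(X_n;\theta), Y_n\big), \qquad \mathcal{L}_{\mathrm{sdm}} = \sum_{n=1}^{N} \big\lVert f_{\mathrm{sdm}}(X_n;\theta) - T_n \big\rVert_2^2, \tag{3}$$
where $\mathcal{L}_{\mathrm{seg}}$ denotes the segmentation loss, $\mathcal{L}_{\mathrm{sdm}}$ is the SDM loss, and $\alpha$ is a weighting coefficient balancing the two loss terms.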
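A minimal NumPy sketch of the supervised loss in Eqs. (2)–(3) for a single labeled volume. The soft Dice form with squared terms in the denominator follows the V-Net formulation [10]; the smoothing constant `eps` is our assumption, and the SDM term is averaged over voxels instead of summed, which only rescales the weighting coefficient α. `alpha=0.3` matches the value reported in Sec. 3; the toy mask and toy SDM are hypothetical.

```python
import numpy as np


def soft_dice_loss(score, target, eps=1e-5):
    """1 - soft Dice between a confidence map in [0, 1] and a binary mask.

    eps is an assumed smoothing constant that avoids division by zero.
    """
    score = score.astype(float).ravel()
    target = target.astype(float).ravel()
    intersect = np.sum(score * target)
    denom = np.sum(score * score) + np.sum(target * target)
    return 1.0 - (2.0 * intersect + eps) / (denom + eps)


def sdm_mse_loss(sdm_pred, sdm_gt):
    """Mean squared error between predicted and ground-truth SDMs."""
    return float(np.mean((sdm_pred - sdm_gt) ** 2))


def supervised_loss(score, target, sdm_pred, sdm_gt, alpha=0.3):
    """L_s = L_seg + alpha * L_sdm for one labeled volume."""
    return soft_dice_loss(score, target) + alpha * sdm_mse_loss(sdm_pred, sdm_gt)


y = np.zeros((4, 4, 4))
y[1:3, 1:3, 1:3] = 1.0                      # toy ground-truth mask
t_gt = np.where(y > 0, -0.5, 0.5)           # toy stand-in for a ground-truth SDM
perfect = supervised_loss(y, y, t_gt, t_gt)
shifted = supervised_loss(y, y, t_gt + 0.1, t_gt)
print(perfect, shifted)                     # shifted is about 0.3 * 0.1**2 = 0.003
```

With a perfect segmentation the Dice term vanishes, so any remaining loss comes from the SDM regression; this is how the SDM branch keeps supplying a shape-dependent gradient even when the mask is already correct.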
2.3.2 Adversarial Loss
To regularize model learning with the unlabeled data, we introduce an adversarial loss that enforces the consistency of SDM predictions on the labeled and unlabeled sets. To this end, we propose a discriminator network that tells apart the predicted SDMs of the labeled set, which should be of high quality due to the supervision, from those of the unlabeled set. Minimizing the adversarial loss induced by this discriminator enables us to learn effective shape-aware features that generalize well to the unlabeled dataset.
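Specifically, we adopt a discriminator network similar to that of [14], which consists of 5 convolutional layers followed by an MLP. The network takes an SDM and the input volume as input, fuses them through the convolutional layers, and predicts the probability that the pair comes from the labeled set. Given the discriminator $D$ with parameters $\psi$, we define the adversarial loss as follows:
$$\mathcal{L}_a = \mathbb{E}_{X_n \in \mathcal{D}_l}\big[\log D(X_n, \hat{T}_n;\psi)\big] + \mathbb{E}_{X_m \in \mathcal{D}_u}\big[\log\big(1 - D(X_m, \hat{T}_m;\psi)\big)\big], \tag{4}$$
where $\hat{T}_n = f_{\mathrm{sdm}}(X_n;\theta)$ and $\hat{T}_m = f_{\mathrm{sdm}}(X_m;\theta)$ are the predicted SDMs.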
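To illustrate Eq. (4) numerically, the sketch below takes hypothetical discriminator outputs (probabilities of "labeled") for a small batch of labeled and unlabeled SDM predictions, evaluates the adversarial loss with sample means in place of the expectations, and checks that the discriminator's usual binary cross-entropy, with target 1 for labeled and 0 for unlabeled pairs, equals the negated adversarial loss. The probability values are made up for illustration.

```python
import numpy as np


def adversarial_loss(d_lab, d_unl):
    """Eq. (4) with sample means: E_l[log D] + E_u[log(1 - D)]."""
    return float(np.mean(np.log(d_lab)) + np.mean(np.log(1.0 - d_unl)))


def bce(p, y):
    """Binary cross-entropy of probabilities p against a constant target y."""
    return float(-np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)))


d_lab = np.array([0.9, 0.8])   # D(X_n, T_hat_n) on labeled pairs (hypothetical values)
d_unl = np.array([0.2, 0.3])   # D(X_m, T_hat_m) on unlabeled pairs (hypothetical values)

l_a = adversarial_loss(d_lab, d_unl)
disc_loss = bce(d_lab, 1.0) + bce(d_unl, 0.0)  # discriminator targets: labeled -> 1, unlabeled -> 0
print(l_a, disc_loss)                          # disc_loss equals -l_a
```

Because the adversarial loss is a sum of log-probabilities it is always non-positive, and it approaches 0 only when the discriminator separates the two sets perfectly; the segmentation network is trained to prevent exactly that.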
2.3.3 Overall Training Pipeline
Our overall training objective combines the supervised and adversarial losses defined above, and the learning task can be written as
$$\min_{\theta}\max_{\psi}\; \mathcal{L}_s(\theta) + \beta\,\mathcal{L}_a(\theta,\psi), \tag{5}$$
where $\beta$ is a weight coefficient that balances the two loss terms. We adopt a standard alternating procedure to train the entire network, which alternates between the following two subproblems.
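Given a fixed discriminator $D$, we minimize the overall loss w.r.t. the segmentation network parameters $\theta$. To speed up model learning, we simplify the loss in two steps. First, we ignore the first loss term in Eqn (4), because the SDM predictions on the labeled set are already of high quality under direct supervision. Second, we adopt a surrogate loss for the generator similar to that in [5], maximizing $\log D$ on the unlabeled predictions instead of minimizing $\log(1 - D)$. Hence the learning problem for the segmentation network can be written as
$$\min_{\theta}\; \mathcal{L}_s(\theta) - \beta\,\mathbb{E}_{X_m \in \mathcal{D}_u}\big[\log D\big(X_m, f_{\mathrm{sdm}}(X_m;\theta);\psi\big)\big]. \tag{6}$$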
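The surrogate in Eq. (6) follows the standard argument of [5]. Writing the discriminator output on an unlabeled pair as D = σ(z) for a logit z, the original generator term log(1 − D) has gradient −σ(z) with respect to z, which vanishes when the discriminator confidently rejects the prediction (D ≈ 0, typical early in training), whereas −log D has gradient −(1 − σ(z)), which stays close to −1 in that regime. The short check below compares both analytic gradients against central finite differences; the logit values are arbitrary.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def saturating(z):        # generator term taken from Eq. (4): log(1 - D)
    return np.log(1.0 - sigmoid(z))


def non_saturating(z):    # surrogate term used in Eq. (6): -log D
    return -np.log(sigmoid(z))


def num_grad(f, z, h=1e-6):
    """Central finite-difference derivative of f at z."""
    return (f(z + h) - f(z - h)) / (2.0 * h)


for z in (-6.0, 0.0, 6.0):    # D strongly rejects, is unsure, strongly accepts
    g_sat = num_grad(saturating, z)
    g_ns = num_grad(non_saturating, z)
    # Analytic gradients: -sigmoid(z) and -(1 - sigmoid(z)).
    assert abs(g_sat + sigmoid(z)) < 1e-6
    assert abs(g_ns + (1.0 - sigmoid(z))) < 1e-6
    print(f"z={z:+.0f}  d/dz log(1-D)={g_sat:+.4f}  d/dz -log D={g_ns:+.4f}")
```

Both objectives share the same fixed point (the generator wants D to accept unlabeled predictions), so the swap changes the gradient magnitude where it matters without changing what is being optimized for.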
On the other hand, given a fixed segmentation network, we simply minimize the binary cross-entropy loss induced by Eqn (5) to train the discriminator, i.e., $\max_{\psi}\mathcal{L}_a(\theta,\psi)$, or equivalently $\min_{\psi} -\mathcal{L}_a(\theta,\psi)$. To stabilize the overall training, we use an annealing strategy based on a time-dependent Gaussian warm-up function to slowly increase the loss weight $\beta$ (see Sec. 3 for details).
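A minimal sketch of such a warm-up for the adversarial weight β, assuming the common Gaussian ramp-up used in temporal ensembling [8], β(t) = β_max · exp(−5 (1 − t/t_max)²) for t < t_max and β_max afterwards. `beta_max` and `t_max` are hypothetical parameters chosen for illustration, not the paper's values.

```python
import math


def gaussian_warmup(t, t_max, beta_max):
    """Assumed Gaussian ramp-up: beta_max * exp(-5 * (1 - t / t_max)**2), clipped at t_max.

    t_max and beta_max are hypothetical; they are not taken from the paper.
    """
    if t_max <= 0:
        return beta_max
    phase = 1.0 - min(max(t, 0), t_max) / t_max   # goes from 1 down to 0
    return beta_max * math.exp(-5.0 * phase * phase)


# Hypothetical settings: ramp up over the first 4000 of 6000 iterations.
T_MAX, BETA_MAX = 4000, 1.0
for t in (0, 1000, 2000, 3000, 4000, 6000):
    print(t, round(gaussian_warmup(t, T_MAX, BETA_MAX), 4))
```

Starting at roughly e⁻⁵ of its final value, the adversarial term barely influences the early iterations, when the discriminator's judgments are still uninformative, and reaches full strength only once the supervised branches have settled.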
3 Experiments and Results
Table 1. Quantitative results on the LA dataset.
| Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] |
|---|---|---|---|---|---|---|
| V-Net | 80 | 0 | 91.14 | 83.82 | 1.52 | 5.75 |
| V-Net | 16 | 0 | 86.03 | 76.06 | 3.51 | 14.26 |
| DAP [20] | 16 | 64 | 87.89 | 78.72 | 2.74 | 9.29 |
| ASDNet [11] | 16 | 64 | 87.90 | 78.85 | 2.08 | 9.24 |
| TCSE [9] | 16 | 64 | 88.15 | 79.20 | 2.44 | 9.57 |
| UA-MT [18] | 16 | 64 | 88.88 | 80.21 | 2.26 | 7.32 |
| UA-MT (+NMS) | 16 | 64 | 89.11 | 80.62 | 2.21 | 7.30 |
| SASSNet (ours) | 16 | 64 | 89.27 | 80.82 | 3.13 | 8.83 |
| SASSNet (+NMS) | 16 | 64 | 89.54 | 81.24 | 2.20 | 8.24 |
| V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11 |
| DAP [20] | 8 | 72 | 81.89 | 71.23 | 3.80 | 15.81 |
| UA-MT [18] | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84 |
| UA-MT (+NMS) | 8 | 72 | 84.57 | 73.96 | 2.90 | 12.51 |
| SASSNet (ours) | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54 |
| SASSNet (+NMS) | 8 | 72 | 87.32 | 77.72 | 2.55 | 9.62 |
We validate our method on the Left Atrium (LA) dataset from the Atrial Segmentation Challenge (http://atriaseg2018.cardiacatlas.org/) with detailed comparisons to prior art. The dataset contains 100 3D gadolinium-enhanced MR imaging scans (GE-MRIs) with LA segmentation masks, acquired at isotropic resolution. Following [18], we split them into 80 scans for training and 20 scans for validation, and apply the same pre-processing methods.
3.0.1 Implementation Details and Metrics.
The segmentation network is trained with an SGD optimizer for 6000 iterations, with an initial learning rate (lr) of 0.01 multiplied by 0.1 every 2500 iterations. The convolutional layers of the discriminator use stride 2, and the discriminator is trained with an Adam optimizer at a constant lr of 0.0001. We use a batch size of 4 images and train on a single GPU with 12 GB of memory. In all our experiments, we set $\alpha$ to 0.3 and $\beta$ to a time-dependent Gaussian warm-up function $\beta(t)$, where $t$ indicates the number of iterations. During testing, we take the segmentation map output $S$ for evaluation. In addition, a non-maximum suppression (NMS) step is applied as post-processing to remove isolated extraneous regions. We use the standard evaluation metrics: the Dice coefficient (Dice), Jaccard index (Jaccard), 95% Hausdorff distance (95HD) and average symmetric surface distance (ASD).
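The evaluation protocol can be sketched in NumPy on small 3D binary masks. The post-processing step is modeled here as keeping only the largest 6-connected foreground component, which is our assumption about what the NMS step does. Surface distances are computed by brute force between boundary voxels, in voxel units; 95HD is taken as the larger of the two directed 95th percentiles and ASD as the mean over both directions, which are common conventions but assumptions on our part. The sketch is meant for toy volumes only.

```python
from collections import deque

import numpy as np

OFFSETS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]


def largest_component(mask):
    """Assumed NMS stand-in: keep the largest 6-connected component (BFS labelling)."""
    fg = mask.astype(bool)
    labels = np.zeros(fg.shape, dtype=int)
    sizes, current = {}, 0
    for start in zip(*np.nonzero(fg)):
        if labels[start]:
            continue
        current += 1
        labels[start] = current
        queue, count = deque([start]), 0
        while queue:
            v = queue.popleft()
            count += 1
            for off in OFFSETS:
                n = tuple(a + b for a, b in zip(v, off))
                if all(0 <= n[i] < fg.shape[i] for i in range(3)) and fg[n] and not labels[n]:
                    labels[n] = current
                    queue.append(n)
        sizes[current] = count
    if not sizes:
        return fg
    return labels == max(sizes, key=sizes.get)


def dice(a, b):
    a, b = a.astype(bool), b.astype(bool)
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())


def jaccard(a, b):
    a, b = a.astype(bool), b.astype(bool)
    return np.logical_and(a, b).sum() / np.logical_or(a, b).sum()


def surface(mask):
    """Coordinates of foreground voxels with a 6-connected background neighbour."""
    p = np.pad(mask.astype(bool), 1, constant_values=False)
    z, y, x = mask.shape
    core = p[1:-1, 1:-1, 1:-1]
    interior = core.copy()
    for dz, dy, dx in OFFSETS:
        interior &= p[1 + dz:z + 1 + dz, 1 + dy:y + 1 + dy, 1 + dx:x + 1 + dx]
    return np.argwhere(core & ~interior)


def surface_distances(a, b):
    """Distance from each surface voxel of a to the nearest surface voxel of b."""
    sa, sb = surface(a), surface(b)
    return np.sqrt(((sa[:, None, :] - sb[None, :, :]) ** 2).sum(-1)).min(axis=1)


def asd(a, b):
    """Average symmetric surface distance (mean over both directions, assumed convention)."""
    d_ab, d_ba = surface_distances(a, b), surface_distances(b, a)
    return (d_ab.sum() + d_ba.sum()) / (len(d_ab) + len(d_ba))


def hd95(a, b):
    """95% Hausdorff distance: larger of the two directed 95th percentiles (assumed)."""
    return max(np.percentile(surface_distances(a, b), 95),
               np.percentile(surface_distances(b, a), 95))


gt = np.zeros((12, 12, 12), dtype=bool)
gt[3:9, 3:9, 3:9] = True              # toy cube-shaped target
pred = gt.copy()
pred[0:2, 0:2, 0:2] = True            # an isolated false-positive blob
clean = largest_component(pred)
print(dice(pred, gt), dice(clean, gt), asd(pred, gt), asd(clean, gt))
```

In this toy case the stray blob barely moves Dice but inflates the surface-distance metrics, which mirrors why removing isolated regions helps ASD and 95HD most.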
3.0.2 Quantitative Evaluation and Comparison.
We evaluate our method in two different settings and compare it with several recent semi-supervised segmentation approaches, including DAP [20], ASDNet [11], TCSE [9] and UA-MT [18]. Table 1 summarizes the quantitative results: we first show the upper-bound performance achieved by a fully supervised network, followed by the two individual settings.
The first setting follows [18], taking 20% of the training data as labeled data (16 scans) and the rest as unlabeled data for semi-supervised training. This setting is relatively easy, as the model trained with only the 20% labeled data already achieves good performance (86.03% in Dice). Among the semi-supervised methods, DAP performs worst, indicating the limitation of an atlas-based prior, while UA-MT achieves the best performance among previous methods. Our method outperforms all the other semi-supervised networks in both Dice (89.54%) and Jaccard (81.24%), and achieves competitive results on the other metrics. In particular, our SASSNet surpasses UA-MT in Dice without resorting to a complex multi-network architecture.
To validate the robustness of our method, we also consider a more challenging setting in which only 8 labeled images are available for training. The second half of Table 1 shows the comparison results, where SASSNet outperforms UA-MT by a large margin (Dice: +2.56% without NMS; with NMS, +3.07% over the original UA-MT and +2.75% over UA-MT (+NMS)). Without NMS, our SASSNet tends to generate more foreground regions, which leads to a worse ASD in both settings and a worse 95HD in the 16-label setting. However, it also produces better segmentations that preserve the original object shape. By contrast, UA-MT often misses inner regions of target objects and generates irregular shapes. Figure 2 provides several qualitative results for visual comparison.
Table 2. Ablation study on the LA dataset.
| Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] | Params [M] |
|---|---|---|---|---|---|---|---|
| V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11 | 187.7 |
| V-Net +SDM | 8 | 0 | 81.12 | 69.75 | 6.93 | 25.58 | 187.9 |
| V-Net +SDM +GAN | 8 | 72 | 86.81 | 76.92 | 3.94 | 12.54 | 249.7 |
| UA-MT [18] | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84 | 375.5 |
| V-Net +SDM +MT | 8 | 72 | 84.97 | 74.14 | 6.12 | 22.20 | 375.8 |
3.0.3 Ablative Study.
We conduct several ablation experiments to examine the effectiveness of our proposed SDM head and the adversarial loss (GAN). Table 2 shows the quantitative results of different model settings. The first row is a V-Net trained with only the labeled data, which is our base model. We first add an SDM head, denoted as V-Net+SDM; as shown in the second row, such joint learning improves the segmentation result by 1.1% in Dice. We then add the unlabeled data and our adversarial loss, denoted as V-Net+SDM+GAN, which significantly improves the performance (a further 5.7% in Dice).
We also compare our semi-supervised learning strategy with two methods in the mean-teacher (MT) framework (last two rows). One is the original UA-MT and the other is our segmentation network trained with the MT consistency loss. Our SASSNet outperforms both methods with higher Dice and Jaccard scores, which indicates the advantage of our representation and loss design. Moreover, our network has a much simpler architecture and fewer parameters than these two networks (249.7M vs. 375.5M and 375.8M).
4 Conclusion
In this paper, we proposed a shape-aware semi-supervised segmentation approach for 3D medical scans. In contrast to previous methods, our method exploits the regularity in geometric shapes of the target object class for effective segment prediction and network learning. We developed a multi-task segmentation network that jointly predicts a semantic segmentation and an SDM of object surfaces, together with a semi-supervised learning loss that enforces consistency between the predicted SDMs of labeled and unlabeled data. Experiments on the Atrial Segmentation Challenge dataset demonstrate that our segmentation network outperforms state-of-the-art methods and generates object segmentations with high-quality global shapes.
References
- [1] (2017) Semi-supervised learning for network-based cardiac MR image segmentation. In MICCAI, pp. 253–260.
- [2] (2017) Semi-supervised deep learning for fully convolutional networks. In MICCAI, pp. 311–319.
- [3] (2019) Semi-supervised medical image segmentation via learning consistency under transformations. In MICCAI, pp. 810–818.
- [4] (2019) A distance map regularized CNN for cardiac cine MR image segmentation. Medical Physics 46 (12), pp. 5637–5651.
- [5] (2014) Generative adversarial nets. In NIPS, pp. 2672–2680.
- [6] (2019) DPA-DenseBiasNet: semi-supervised 3D fine renal artery segmentation with dense biased network and deep priori anatomy. In MICCAI, pp. 139–147.
- [7] (2018) Adversarial learning for semi-supervised semantic segmentation. In BMVC.
- [8] (2016) Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242.
- [9] (2018) Semi-supervised skin lesion segmentation via transformation consistent self-ensembling model. arXiv preprint arXiv:1808.03887.
- [10] (2016) V-Net: fully convolutional neural networks for volumetric medical image segmentation. In 3DV, pp. 565–571.
- [11] (2018) ASDNet: attention based semi-supervised deep networks for medical image segmentation. In MICCAI, pp. 370–378.
- [12] (2019) DeepSDF: learning continuous signed distance functions for shape representation. In CVPR, pp. 165–174.
- [13] (2015) Motion segmentation of truncated signed distance function based volumetric surfaces. In WACV, pp. 1046–1053.
- [14] (2015) Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
- [15] (2019) Deep semantic segmentation of natural and medical images: a review. arXiv preprint arXiv:1910.07655.
- [16] (2017) Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. In NIPS, pp. 1195–1204.
- [17] (2019) Shape-aware organ segmentation by predicting signed distance maps. arXiv preprint arXiv:1912.03849.
- [18] (2019) Uncertainty-aware self-ensembling model for semi-supervised 3D left atrium segmentation. In MICCAI, pp. 605–613.
- [19] (2017) Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In MICCAI, pp. 408–416.
- [20] (2019) Semi-supervised segmentation of liver using adversarial learning with deep atlas prior. In MICCAI, pp. 148–156.