1 Introduction
Automated segmentation of the left atrium (LA) in magnetic resonance (MR) images is of great importance for promoting the treatment of atrial fibrillation. With a large amount of labeled data, deep learning has greatly advanced LA segmentation [15]. In the medical imaging domain, however, it is expensive and tedious for experienced experts to delineate reliable annotations of 3D medical images slice by slice. Since unlabeled data are generally abundant, we study a semi-supervised approach to LA segmentation that leverages both limited labeled data and abundant unlabeled data.
Considerable effort has been devoted to utilizing unlabeled data to improve segmentation performance in the medical imaging community [1, 2, 3, 7, 19]. For example, Bai et al. [1] introduced a self-training-based method for cardiac MR image segmentation, in which the network parameters and the segmentation of the unlabeled data were updated alternately. Adversarial learning has also been used for semi-supervised learning [6, 12, 18]. Zhang et al. [18] designed a deep adversarial network that uses unannotated images by encouraging their segmentations to be similar to those of the annotated ones. Another approach [12] used an adversarial network to select trustworthy regions of the unlabeled data for training the segmentation network. Following the promising results of self-ensembling methods [9, 14] on semi-supervised natural image classification, Li et al. [10] extended the Π-model [9] with transformation consistency for semi-supervised skin lesion segmentation. Other approaches [5, 13] used weight-averaged consistency targets for semi-supervised MR segmentation. Although promising progress has been achieved, these methods do not consider the reliability of the targets, which may lead to meaningless guidance.
In this paper, we present a novel uncertainty-aware semi-supervised learning framework for left atrium segmentation from 3D MR images that additionally leverages unlabeled data. Following the spirit of the mean teacher [14], our method encourages the segmentation predictions for the same input to be consistent under different perturbations. Specifically, we build a teacher model and a student model, where the student model learns from the teacher model by minimizing the segmentation loss on the labeled data and the consistency loss with respect to the teacher's targets on all input data. Since no ground truth is available for the unlabeled inputs, the targets predicted by the teacher model may be unreliable and noisy. We therefore design the uncertainty-aware mean teacher (UA-MT) framework, in which the student model gradually learns from meaningful and reliable targets by exploiting the uncertainty information of the teacher model. Concretely, besides generating the target outputs, the teacher model also estimates the uncertainty of each target prediction with Monte Carlo sampling. Guided by the estimated uncertainty, we filter out the unreliable (high-uncertainty) predictions and keep only the reliable (low-uncertainty) ones when calculating the consistency loss. Hence, the student model is optimized with more reliable supervision and, in return, encourages the teacher model to generate higher-quality targets. Our method was extensively evaluated on the dataset of the MICCAI 2018 Atrial Segmentation Challenge. The results demonstrate that our semi-supervised method achieves large improvements in LA segmentation by utilizing the unlabeled data, and also outperforms other state-of-the-art semi-supervised segmentation methods.
2 Method
Fig. 1 illustrates our uncertainty-aware self-ensembling mean teacher framework (UA-MT) for semi-supervised LA segmentation. The teacher model generates targets for the student model to learn from and also estimates the uncertainty of each target. The uncertainty-guided consistency loss improves the student model and the robustness of the framework.
2.1 Semi-supervised Segmentation
We study the task of semi-supervised segmentation for 3D data, where the training set consists of labeled data and unlabeled data. We denote the labeled set as $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{N}$ and the unlabeled set as $\mathcal{D}_U = \{x_i\}_{i=N+1}^{N+M}$, where $x_i$ is the input volume and $y_i$ is its ground-truth annotation. The goal of our semi-supervised segmentation framework is to minimize the following combined objective function:
$$\min_{\theta} \sum_{i=1}^{N} \mathcal{L}_s\big(f(x_i; \theta), y_i\big) + \lambda \sum_{i=1}^{N+M} \mathcal{L}_c\big(f(x_i; \theta', \xi'), f(x_i; \theta, \xi)\big), \qquad (1)$$
where $\mathcal{L}_s$ denotes the supervised loss (e.g., cross-entropy loss) that evaluates the quality of the network output on labeled inputs, and $\mathcal{L}_c$ represents the unsupervised consistency loss that measures the consistency between the predictions of the teacher model and the student model for the same input under different perturbations. Here, $f(\cdot)$ denotes the segmentation neural network; $(\theta', \xi')$ and $(\theta, \xi)$ represent the weights and the perturbation operations (e.g., adding noise to the input and network dropout) of the teacher and student models, respectively. $\lambda$ is a ramp-up weighting coefficient that controls the trade-off between the supervised and unsupervised losses.
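A minimal sketch of Eq. (1) for one training batch, written in plain Python for readability rather than PyTorch. It assumes binary segmentation with each volume flattened to a list of foreground probabilities, uses binary cross-entropy as $\mathcal{L}_s$ and plain MSE as $\mathcal{L}_c$ (the uncertainty-aware variant is Eq. (3)), and takes the teacher's perturbed predictions as given. The weight $\lambda$ is scheduled with the Gaussian warming-up $e^{-5(1 - t/t_{\max})^2}$ of [9], scaled by a hypothetical maximum `lambda_max` whose value is not given in the text.

```python
import math
from typing import List, Sequence


def gaussian_rampup(t: int, t_max: int) -> float:
    """Gaussian warming-up exp(-5 * (1 - t/t_max)^2); t/t_max is clipped to [0, 1]."""
    if t_max <= 0:
        raise ValueError("t_max must be positive")
    phase = 1.0 - min(max(t / t_max, 0.0), 1.0)
    return math.exp(-5.0 * phase * phase)


def consistency_weight(t: int, t_max: int, lambda_max: float) -> float:
    """lambda(t) for Eq. (1); lambda_max is a hypothetical hyperparameter."""
    return lambda_max * gaussian_rampup(t, t_max)


def binary_cross_entropy(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-7) -> float:
    """Voxel-averaged binary cross-entropy (stand-in for the supervised loss L_s)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(probs)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Voxel-averaged squared difference (stand-in for the consistency loss L_c)."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def combined_objective(student: List[List[float]], teacher: List[List[float]],
                       labels: List[List[int]], lam: float) -> float:
    """Eq. (1) on one batch: the first len(labels) volumes are labeled, while the
    consistency term is summed over all volumes (labeled and unlabeled)."""
    n_labeled = len(labels)
    supervised = sum(binary_cross_entropy(s, y) for s, y in zip(student[:n_labeled], labels))
    consistency = sum(mse(t, s) for t, s in zip(teacher, student))
    return supervised + lam * consistency


# Usage: each inner list holds the foreground probabilities of a 2-voxel toy volume;
# volume 0 is labeled, volume 1 is unlabeled, and we are halfway through training.
student = [[0.9, 0.1], [0.6, 0.4]]
teacher = [[0.9, 0.1], [0.5, 0.5]]
labels = [[1, 0]]
loss = combined_objective(student, teacher, labels, lam=consistency_weight(50, 100, lambda_max=1.0))
```

Because the ramp starts near $e^{-5} \approx 0.0067$, the supervised term dominates early in training, which matches the motivation given in Sect. 2.3.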
Recent studies [9, 14] show that ensembling the network's predictions from different stages of training can improve prediction quality, and that using these ensembles as teacher predictions can improve the results. Therefore, following [14], we update the teacher's weights as an exponential moving average (EMA) of the student's weights, which ensembles the information from different training steps; see Fig. 1. Specifically, at training step $t$ we update the teacher's weights $\theta'_t$ as
$$\theta'_t = \alpha\, \theta'_{t-1} + (1 - \alpha)\, \theta_t,$$
where $\alpha$ is the EMA decay that controls the updating rate.
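The EMA update can be sketched in plain Python with each model's weights flattened into a list of floats (an assumption for readability; in PyTorch the same update is typically applied in place to every parameter tensor under `torch.no_grad()`). The loop is illustrative only: with the student held fixed, the teacher approaches it geometrically at rate $1 - \alpha$.

```python
from typing import List


def ema_update(teacher: List[float], student: List[float], alpha: float) -> List[float]:
    """theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t, element-wise."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("EMA decay alpha must lie in [0, 1]")
    if len(teacher) != len(student):
        raise ValueError("teacher and student must have the same number of weights")
    return [alpha * w_t + (1.0 - alpha) * w_s for w_t, w_s in zip(teacher, student)]


# Illustration with a frozen student: after k updates each teacher weight equals
# alpha**k * initial + (1 - alpha**k) * student.
teacher = [0.0, 2.0]
student = [1.0, 1.0]
for _ in range(10):
    teacher = ema_update(teacher, student, alpha=0.9)
```

A larger $\alpha$ makes the teacher a longer-horizon average of past students, which is what makes its targets more stable than the student's own predictions.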
2.2 Uncertainty-Aware Mean Teacher Framework
Without annotations for the unlabeled inputs, the targets predicted by the teacher model may be unreliable and noisy. Therefore, we design an uncertainty-aware scheme that enables the student model to gradually learn from the more reliable targets. Given a batch of training images, the teacher model not only generates the target predictions but also estimates the uncertainty of each target. The student model is then optimized with a consistency loss that, guided by the estimated uncertainty, focuses only on the confident targets.
2.2.1 Uncertainty Estimation.
Motivated by uncertainty estimation in Bayesian networks, we estimate the uncertainty with Monte Carlo dropout [8]. In detail, we perform $T$ stochastic forward passes of the teacher model under random dropout and input Gaussian noise for each input volume. Therefore, for each voxel in the input, we obtain a set of softmax probability vectors $\{\mathbf{p}_t\}_{t=1}^{T}$. We choose the predictive entropy as the metric to approximate the uncertainty, since it has a fixed range [8]. Formally, the predictive entropy can be summarized as
$$\mu_c = \frac{1}{T} \sum_{t=1}^{T} p_t^c, \qquad u = -\sum_{c} \mu_c \log \mu_c, \qquad (2)$$
where $p_t^c$ is the probability of the $c$-th class in the $t$-th prediction. Note that the uncertainty is estimated at the voxel level, and the uncertainty of the whole volume is $U = \{u\}$, the collection of these voxel-wise values.
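Eq. (2) can be sketched per voxel as follows (plain Python, natural logarithm, with $0 \log 0$ taken as 0). The stochastic teacher pass is abstracted as a callable; `toy_teacher_pass`, which perturbs a single logit with Gaussian noise to mimic dropout and input noise, is a hypothetical stand-in for a real network. The two-pass example shows why the entropy of the *mean* prediction is used: passes that are individually confident but disagree push the uncertainty to its maximum.

```python
import math
import random
from typing import Callable, List, Sequence


def predictive_entropy(prob_samples: Sequence[Sequence[float]]) -> float:
    """Eq. (2): u = -sum_c mu_c log mu_c, with mu_c the mean over T softmax vectors."""
    T = len(prob_samples)
    num_classes = len(prob_samples[0])
    mu = [sum(p[c] for p in prob_samples) / T for c in range(num_classes)]
    return -sum(m * math.log(m) for m in mu if m > 0.0)  # skip 0 * log 0 terms


def mc_dropout_uncertainty(stochastic_pass: Callable[[], List[float]], T: int) -> float:
    """Run T stochastic forward passes for one voxel and return its predictive entropy."""
    return predictive_entropy([stochastic_pass() for _ in range(T)])


def toy_teacher_pass(logit: float, noise_std: float, rng: random.Random) -> List[float]:
    """Hypothetical stand-in for one dropout-enabled teacher pass at a single voxel:
    a noisy logit mapped to a two-class softmax [background, foreground]."""
    z = logit + rng.gauss(0.0, noise_std)
    p_fg = 1.0 / (1.0 + math.exp(-z))
    return [1.0 - p_fg, p_fg]


# Two confident passes that disagree give the maximum two-class entropy ln 2.
disagreeing = predictive_entropy([[0.9, 0.1], [0.1, 0.9]])
# A voxel far from the decision boundary stays nearly certain under the noise.
rng = random.Random(0)
u_easy = mc_dropout_uncertainty(lambda: toy_teacher_pass(6.0, 0.5, rng), T=8)
```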
2.2.2 Uncertainty-Aware Consistency Loss.
Guided by the estimated uncertainty $U$, we filter out the relatively unreliable (high-uncertainty) predictions and select only the certain predictions as targets for the student model to learn from. In particular, for our semi-supervised segmentation task, we design the uncertainty-aware consistency loss $\mathcal{L}_c$ as the voxel-level mean squared error (MSE) between the teacher and student predictions, computed only over the most certain predictions:
$$\mathcal{L}_c(f', f) = \frac{\sum_{v} \mathbb{I}(u_v < H)\, \lVert f'_v - f_v \rVert^2}{\sum_{v} \mathbb{I}(u_v < H)}, \qquad (3)$$
where $\mathbb{I}(\cdot)$ is the indicator function; $f'_v$ and $f_v$ are the predictions of the teacher model and the student model at the $v$-th voxel, respectively; $u_v$ is the estimated uncertainty at the $v$-th voxel; and $H$ is a threshold for selecting the most certain targets. With the uncertainty-aware consistency loss in the training procedure, both the student and the teacher can learn more reliable knowledge, which in turn can reduce the overall uncertainty of the model.
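A plain-Python sketch of Eq. (3), with each voxel's prediction given as a class-probability vector. Two details are assumptions not stated in the text: the loss returns 0 when no voxel falls below the threshold (Eq. (3) would otherwise divide by zero), and `ramped_threshold` is a hypothetical schedule for $H$ that applies the Gaussian warming-up of Sect. 2.3 between `start_fraction * u_max` and `u_max`. Here `u_max` defaults to $\ln 2$, the maximum of Eq. (2) for two classes, and `start_fraction` is a free parameter.

```python
import math
from typing import Sequence


def uncertainty_aware_mse(teacher: Sequence[Sequence[float]],
                          student: Sequence[Sequence[float]],
                          uncertainty: Sequence[float],
                          threshold: float) -> float:
    """Eq. (3): squared L2 distance between teacher and student class-probability
    vectors, averaged only over voxels whose uncertainty is below the threshold H."""
    masked_sum = 0.0
    kept = 0
    for f_t, f_s, u in zip(teacher, student, uncertainty):
        if u < threshold:  # indicator I(u_v < H)
            masked_sum += sum((a - b) ** 2 for a, b in zip(f_t, f_s))
            kept += 1
    # Assumption: with no confident voxel, the loss is defined as 0.
    return masked_sum / kept if kept else 0.0


def ramped_threshold(t: int, t_max: int, start_fraction: float,
                     u_max: float = math.log(2)) -> float:
    """Hypothetical schedule: H grows from about start_fraction * u_max to u_max
    following the Gaussian ramp exp(-5 * (1 - t/t_max)^2)."""
    if t_max <= 0:
        raise ValueError("t_max must be positive")
    phase = 1.0 - min(max(t / t_max, 0.0), 1.0)
    ramp = math.exp(-5.0 * phase * phase)
    return (start_fraction + (1.0 - start_fraction) * ramp) * u_max


# Usage: voxel 0 is confident and contributes; voxel 1 is too uncertain and is masked out.
teacher = [[0.9, 0.1], [0.2, 0.8]]
student = [[0.7, 0.3], [0.2, 0.8]]
loss = uncertainty_aware_mse(teacher, student, uncertainty=[0.1, 0.9], threshold=0.5)
```

Raising $H$ over training lets progressively less certain voxels into the average, which is the "certain to uncertain" curriculum described in Sect. 2.3.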
2.3 Technical Details
We employ V-Net [11] as our network backbone. We remove the short residual connection in each convolution block and use a joint cross-entropy and Dice loss [16]. To adapt V-Net as a Bayesian network that can estimate uncertainty, two dropout layers with a dropout rate of 0.5 are added after the L-Stage 5 layer and the R-Stage 1 layer of V-Net. We turn dropout on during network training and uncertainty estimation, and turn it off in the testing phase, since uncertainty estimation is not needed there. We set the EMA decay $\alpha$ empirically, following the previous work [14]. Following [9, 14], we use a time-dependent Gaussian warming-up function $\lambda(t)$ to control the balance between the supervised loss and the unsupervised consistency loss, where $t$ denotes the current training step and $t_{\max}$ is the maximum training step. This design ensures that, at the beginning of training, the objective is dominated by the supervised loss term, which keeps the network from getting stuck in a degenerate solution where no meaningful target prediction for the unlabeled data is obtained [9]. For the uncertainty estimation, we choose the number of stochastic forward passes $T$ to balance estimation quality and training efficiency. We also use the same Gaussian ramp-up paradigm to raise the uncertainty threshold $H$ in Eq. (3) from a lower initial value up to $U_{\max}$, where $U_{\max}$ is the maximum uncertainty value (i.e., $\ln 2$ for our two-class segmentation). As training continues, our method filters out less and less data, enabling the student to gradually learn from relatively certain to uncertain cases.
3 Experiments and Results
3.0.1 Dataset and Preprocessing.
We evaluated our method on the Atrial Segmentation Challenge dataset (http://atriaseg2018.cardiacatlas.org/). It provides 100 3D gadolinium-enhanced MR imaging scans (GE-MRIs) with LA segmentation masks for training and validation. These scans have isotropic resolution. We split the 100 scans into 80 scans for training and 20 scans for evaluation. For a better comparison of the segmentation performance of different methods, all scans were cropped around the center of the heart region and normalized to zero mean and unit variance.
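The intensity normalization can be sketched as a per-scan z-score, assuming each cropped volume is flattened to a list of voxel intensities and that the statistics are computed per scan (the text does not say whether they are per scan or over the whole dataset). Mapping a constant volume to zeros, to avoid dividing by zero, is also an assumption.

```python
import math
from typing import List


def zscore_normalize(volume: List[float]) -> List[float]:
    """Shift and scale intensities to zero mean and unit (population) variance."""
    n = len(volume)
    mean = sum(volume) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in volume) / n)
    if std == 0.0:
        return [0.0] * n  # constant volume: nothing to scale
    return [(v - mean) / std for v in volume]


normalized = zscore_normalize([120.0, 80.0, 200.0, 40.0])
```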
3.0.2 Implementation.
The framework was implemented in PyTorch and trained on a TITAN Xp GPU. We used the SGD optimizer with momentum 0.9 and weight decay to update the network parameters. The initial learning rate was set to 0.01 and divided by 10 every 2500 iterations. We trained for 6000 iterations in total, by which point the network had converged. The batch size was 4, consisting of 2 annotated and 2 unannotated images. We randomly cropped sub-volumes as the network input, and the final segmentation results were obtained with a sliding-window strategy. Following [17], we applied standard on-the-fly data augmentation to avoid overfitting, including random flipping and rotation by 90, 180, and 270 degrees in the axial plane.
Table 1. Segmentation results on the test set. All semi-supervised methods use the same Bayesian V-Net backbone.

| Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] |
|---|---|---|---|---|---|---|
| Vanilla V-Net | 16 | 0 | 84.13 | 73.26 | 4.75 | 17.93 |
| Bayesian V-Net | 16 | 0 | 86.03 | 76.06 | 3.51 | 14.26 |
| Vanilla V-Net | 80 | 0 | 90.25 | 82.40 | 1.91 | 8.29 |
| Bayesian V-Net | 80 | 0 | 91.14 | 83.82 | 1.52 | 5.75 |
| Self-training [1] | 16 | 64 | 86.92 | 77.28 | 2.21 | 9.19 |
| DAN [18] | 16 | 64 | 87.52 | 78.29 | 2.42 | 9.01 |
| ASDNet [12] | 16 | 64 | 87.90 | 78.85 | 2.08 | 9.24 |
| TCSE [10] | 16 | 64 | 88.15 | 79.20 | 2.44 | 9.57 |
| UA-MT-UN (ours) | 16 | 64 | 88.83 | 80.13 | 3.12 | 10.04 |
| UA-MT (ours) | 16 | 64 | 88.88 | 80.21 | 2.26 | 7.32 |
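The step learning-rate schedule and the sliding-window inference from the implementation details can be sketched as follows. The schedule reproduces the stated values (0.01, divided by 10 every 2500 iterations); `window_starts` is a hypothetical helper that computes patch start indices along one axis, with a stride that the text does not specify. Shifting the last window back keeps every voxel covered without padding; averaging the overlapping predictions voxel-wise is a common choice but an assumption about the inference details.

```python
from typing import List


def learning_rate(iteration: int, base_lr: float = 0.01, step: int = 2500, gamma: float = 0.1) -> float:
    """Step schedule: start at base_lr and multiply by gamma every `step` iterations."""
    return base_lr * gamma ** (iteration // step)


def window_starts(size: int, patch: int, stride: int) -> List[int]:
    """Start indices of sliding windows along one axis; the last window is shifted
    back so that it ends exactly at the volume border."""
    if patch > size:
        raise ValueError("patch larger than the volume; pad the volume first")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:
        starts.append(size - patch)
    return starts
```

Applying `window_starts` independently to each of the three axes and taking the Cartesian product of the start indices gives the 3D sub-volume positions.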
3.0.3 Evaluation of Our Semi-supervised Segmentation.
We use four metrics to quantitatively evaluate our method: Dice, Jaccard, the average surface distance (ASD), and the 95% Hausdorff distance (95HD). Of the 80 training scans, we use 20% (i.e., 16 scans) as labeled data and the remaining 64 scans as unlabeled data. Table 1 presents the segmentation performance on the test set of V-Net trained with only the labeled data (the first two rows) and of our semi-supervised method (UA-MT). Compared with the vanilla V-Net, adding dropout (Bayesian V-Net) improves the segmentation performance, achieving an average Dice of 86.03% and Jaccard of 76.06% with only the labeled training data. By utilizing the unlabeled data, our semi-supervised framework further improves the segmentation by 4.15% Jaccard and 2.85% Dice.
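For binary masks, Dice is $2|A \cap B| / (|A| + |B|)$ and Jaccard is $|A \cap B| / |A \cup B|$. The sketch below computes both on flattened 0/1 masks; ASD and 95HD require surface extraction and distance computations and are omitted. Treating two empty masks as perfect agreement is a convention assumed here. Per case the two metrics satisfy $J = D/(2 - D)$, so they rank individual cases identically; averages over cases, as reported in Table 1, need not satisfy this identity exactly.

```python
from typing import Sequence, Tuple


def dice_and_jaccard(pred: Sequence[int], gt: Sequence[int]) -> Tuple[float, float]:
    """Overlap metrics between two flattened binary masks."""
    if len(pred) != len(gt):
        raise ValueError("masks must have the same number of voxels")
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    size_pred = sum(1 for p in pred if p)
    size_gt = sum(1 for g in gt if g)
    union = size_pred + size_gt - inter
    if union == 0:
        return 1.0, 1.0  # both masks empty (assumed convention)
    return 2.0 * inter / (size_pred + size_gt), inter / union


dice, jaccard = dice_and_jaccard([1, 1, 0, 0], [1, 0, 1, 0])
```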
To analyze the importance of the consistency loss on labeled and unlabeled data, we conducted another experiment (UA-MT-UN) in which the consistency loss is applied only to the unlabeled data. Its Dice and Jaccard are very close to those of UA-MT, indicating that the improvement of our method comes mainly from the unlabeled data. We also trained the fully supervised V-Net with all 80 labeled scans, which can be regarded as the upper-bound performance. As Table 1 shows, our semi-supervised method approaches this fully supervised performance. To validate our network backbone design, we refer to the state-of-the-art method of the challenge [4], which used a multi-task U-Net for LA segmentation and reported 90.10% Dice on 20 test scans with 80 training scans. Since our fully supervised vanilla V-Net achieves a comparable Dice (90.25%), we regard it as a standard baseline model.
3.0.4 Comparison with Other Semi-supervised Methods.
We implemented several state-of-the-art semi-supervised segmentation methods for comparison, including the self-training-based method [1], the deep adversarial network (DAN) [18], the adversarial-learning-based semi-supervised method (ASDNet) [12], and the Π-model-based method (TCSE) [10]. For a fair comparison, all these methods use the same network backbone (Bayesian V-Net). As shown in Table 1, compared with the self-training method, DAN and ASDNet improve Dice by 0.60% and 0.98%, respectively, showing the benefit of adversarial learning for semi-supervised learning. ASDNet outperforms DAN, since it selects the trustworthy regions of the unlabeled data for training the segmentation network. The self-ensembling-based method TCSE achieves slightly better Dice and Jaccard than ASDNet, demonstrating that a perturbation-based consistency loss is helpful for the semi-supervised segmentation problem. Notably, our method (UA-MT) achieves the best Dice, Jaccard, and 95HD among the compared semi-supervised methods; its ASD is comparable, with ASDNet (2.08) and self-training (2.21) slightly ahead of UA-MT (2.26). This indicates that our uncertainty-aware mean teacher framework effectively draws out the rich information in the unlabeled data.
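Table 2. Comparison with mean teacher variants, and results with different numbers of labeled scans.

| Method | Labeled scans | Unlabeled scans | Dice [%] | Jaccard [%] | ASD [voxel] | 95HD [voxel] |
|---|---|---|---|---|---|---|
| MT | 16 | 64 | 88.23 | 79.29 | 2.73 | 10.64 |
| MT-Dice [5] | 16 | 64 | 88.32 | 79.37 | 2.76 | 10.50 |
| Our UA-MT | 16 | 64 | 88.88 | 80.21 | 2.26 | 7.32 |
| Bayesian V-Net | 8 | 0 | 79.99 | 68.12 | 5.48 | 21.11 |
| Our UA-MT | 8 | 72 | 84.25 | 73.48 | 3.36 | 13.84 |
| Bayesian V-Net | 24 | 0 | 88.52 | 79.70 | 2.60 | 10.45 |
| Our UA-MT | 24 | 56 | 90.16 | 82.18 | 2.73 | 8.90 |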
3.0.5 Analysis of Our Method.
To validate the effectiveness of our uncertainty-aware scheme, we evaluate the original mean teacher method (MT) and an adapted mean teacher method (MT-Dice) with a Dice-loss-like consistency loss [5]. As shown in Table 2, our uncertainty-aware method outperforms both the MT and MT-Dice models on all four metrics. We also investigate the impact of using different numbers of labeled scans in our semi-supervised method. As shown in Table 2, with both 10% (i.e., 8) and 30% (i.e., 24) labeled scans, our semi-supervised method improves on the supervised-only V-Net (Bayesian V-Net) in Dice, Jaccard, and 95HD by utilizing the unlabeled data (with 24 labeled scans its ASD is slightly higher, 2.73 vs. 2.60), demonstrating that our method effectively exploits the unlabeled data. In Fig. 2, we show some segmentation examples of the supervised method and our semi-supervised method, together with the estimated uncertainty. Compared with the supervised method, our results have a higher overlap ratio with the ground truth (the second row) and produce fewer false positives (the first row). As shown in Fig. 2(d), the network estimates high uncertainty near the boundary and in ambiguous regions of the great vessels.
4 Conclusion
We present a novel uncertainty-aware semi-supervised learning method for left atrium segmentation from 3D MR images.
Our method exploits the unlabeled data by encouraging the segmentation of the same input to be consistent under different perturbations.
More importantly, we exploit the model uncertainty to improve the quality of the targets.
The comparison with other semi-supervised methods confirms the effectiveness of our method.
Future work includes investigating the effect of different uncertainty estimation approaches and applying our framework to other semi-supervised medical image segmentation problems.
Acknowledgments.
The work was partially supported by HK RGC TRS project T42409/18R and
in part by the CUHK T Stone Robotics Institute, The Chinese University of Hong Kong.
References
[1] Bai, W., Oktay, O., Sinclair, M., et al.: Semi-supervised learning for network-based cardiac MR image segmentation. In: MICCAI. pp. 253–260 (2017)
[2] Baur, C., Albarqouni, S., Navab, N.: Semi-supervised deep learning for fully convolutional networks. In: MICCAI. pp. 311–319 (2017)
[3] Chartsias, A., Joyce, T., Papanastasiou, G., Semple, S., Williams, M., Newby, D., Dharmakumar, R., Tsaftaris, S.A.: Factorised spatial representation learning: application in semi-supervised myocardial segmentation. In: MICCAI. pp. 490–498 (2018)
[4] Chen, C., Bai, W., Rueckert, D.: Multi-task learning for left atrial segmentation on GE-MRI. arXiv preprint arXiv:1810.13205 (2018)
[5] Cui, W., Liu, Y., Li, Y., Guo, M., Li, Y., Li, X., Wang, T., Zeng, X., Ye, C.: Semi-supervised brain lesion segmentation with an adapted mean teacher model. In: IPMI. pp. 554–565 (2019)
[6] Dong, N., Kampffmeyer, M., Liang, X., Wang, Z., Dai, W., Xing, E.: Unsupervised domain adaptation for automatic estimation of cardiothoracic ratio. In: MICCAI. pp. 544–552 (2018)
[7] Ganaye, P.A., Sdika, M., Benoit-Cattin, H.: Semi-supervised learning for segmentation under semantic constraint. In: MICCAI. pp. 595–602 (2018)
[8] Kendall, A., Gal, Y.: What uncertainties do we need in Bayesian deep learning for computer vision? In: NIPS. pp. 5574–5584 (2017)
[9] Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. arXiv preprint (2016)
[10] Li, X., Yu, L., Chen, H., Fu, C.W., Heng, P.A.: Semi-supervised skin lesion segmentation via transformation consistent self-ensembling model. In: BMVC (2018)
[11] Milletari, F., Navab, N., Ahmadi, S.A.: V-Net: Fully convolutional neural networks for volumetric medical image segmentation. In: 3DV. pp. 565–571 (2016)
[12] Nie, D., Gao, Y., Wang, L., Shen, D.: ASDNet: Attention based semi-supervised deep networks for medical image segmentation. In: MICCAI. pp. 370–378 (2018)
[13] Perone, C.S., Cohen-Adad, J.: Deep semi-supervised segmentation with weight-averaged consistency targets. In: DLMIA workshop (2018)
[14] Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In: NIPS (2017)
[15] Xiong, Z., Fedorov, V.V., Fu, X., Cheng, E., Macleod, R., Zhao, J.: Fully automatic left atrium segmentation from late gadolinium enhanced magnetic resonance imaging using a dual fully convolutional neural network. TMI 38(2), 515–524 (2019)
[16] Yang, X., Bian, C., Yu, L., Ni, D., Heng, P.A.: Hybrid loss guided convolutional networks for whole heart parsing. In: International Workshop on STACOM (2017)
[17] Yu, L., Cheng, J.Z., Dou, Q., Yang, X., Chen, H., Qin, J., Heng, P.A.: Automatic 3D cardiovascular MR segmentation with densely-connected volumetric convnets. In: MICCAI. pp. 287–295. Springer (2017)
[18] Zhang, Y., Yang, L., Chen, J., Fredericksen, M., Hughes, D.P., Chen, D.Z.: Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In: MICCAI. pp. 408–416 (2017)
[19] Zhou, Y., Wang, Y., Tang, P., Bai, S., Shen, W., Fishman, E.K., Yuille, A.L.: Semi-supervised multi-organ segmentation via multi-planar co-training. arXiv preprint arXiv:1804.02586 (2018)