
Voxel-wise Adversarial Semi-supervised Learning for Medical Image Segmentation

Semi-supervised learning for medical image segmentation is an important area of research for alleviating the huge cost associated with the construction of reliable large-scale annotations in the medical domain. Recent semi-supervised approaches have demonstrated promising results by employing consistency regularization, pseudo-labeling techniques, and adversarial learning. These methods primarily attempt to learn the distribution of labeled and unlabeled data by enforcing consistency in the predictions or embedding context. However, previous approaches have focused only on local discrepancy minimization or context relations across single classes. In this paper, we introduce a novel adversarial learning-based semi-supervised segmentation method that effectively embeds both local and global features from multiple hidden layers and learns context relations between multiple classes. Our voxel-wise adversarial learning method utilizes a voxel-wise feature discriminator, which considers multilayer voxel-wise features (involving both local and global features) as an input by embedding class-specific voxel-wise feature distributions. Furthermore, we improve our previous representation learning method by overcoming information loss and learning stability problems, which enables rich representations of labeled data. Our method outperforms state-of-the-art semi-supervised learning approaches on the segmentation of the left atrium (single-class) and multiorgan (multiclass) datasets. Moreover, our visual interpretation of the feature space demonstrates that our proposed method enables a well-distributed and separated feature space from both labeled and unlabeled data, which improves the overall prediction results.


I Introduction

Medical image segmentation is an essential task in several clinical applications, such as computer-aided diagnosis, radiation therapy, and virtual surgeries [788580, robotics_surgery, cad]. Automated segmentation of organs (e.g., the left atrium (LA), heart, or liver) is of significant importance in optimizing clinical workflows, such as the planning of surgeries and treatments. Convolutional neural networks (CNNs), which have demonstrated significant abilities in learning visual features in computer vision tasks, have been successfully adapted to medical segmentation problems by leveraging large amounts of annotated medical data (i.e., computed tomography (CT) scans [kudo2008diagnostic]). However, the generation of reliable large-scale annotations of three-dimensional (3D) medical images requires domain-specific expertise, which is expensive and time-consuming.

Fig. 1: Our proposed method. Existing semi-supervised segmentation models learn to map voxels from the data space to the feature space, ignoring global features or class-wise voxel relations. We enforce models to directly learn feature representations of labeled and unlabeled data using our proposed method: (a) by defining voxel-wise feature relations of labeled data in the feature space (i.e., voxel-wise representation learning) and (b) by discriminating between the voxel-level features from labeled and unlabeled data (i.e., voxel-wise adversarial learning).

Significant efforts, such as pretraining, self-supervised learning, and active learning, have been dedicated to learning from large amounts of unlabeled data. Semi-supervised learning is one of the approaches used to reduce the annotation cost; it simultaneously utilizes a large number of unlabeled samples together with a limited number of labeled samples. The semi-supervised approach assumes that labeled and unlabeled data from the same label share the same or similar underlying distribution (i.e., the manifold assumption) [4787647, vanEngelen2019ASO]. We can infer that labeled and unlabeled data usually share similar distributions (e.g., intensities or structures) in medical imaging; consequently, rich semantic information can be embedded using unlabeled data via semi-supervised learning. In practice, several studies on semi-supervised medical image segmentation have proposed effective methods for leveraging unlabeled data. Consistency regularization [tarvainen2018mean], pseudo-labeling [Leepseudo], and adversarial learning [li2019semi] are some of the most commonly used learning methods in semi-supervised learning. The teacher-student model architecture [tarvainen2018mean] has been broadly applied, and it has been demonstrated to be effective for consistency regularization- and pseudo-labeling-based methods. Furthermore, improved model performance can be expected through the synergy of representation learning methods from self-supervised and supervised learning by encoding representations from labeled data.

Consistency regularization-based methods [tarvainen2018mean] learn network outputs that are invariant to perturbations or augmentations by adding noise to the unlabeled samples. Different types of methods have been presented to enforce consistency between outputs from different passes, such as uncertainty-aware schemes for data-level consistency [ua_mt] or task-level consistency using a task-transform layer [dtc]. Pseudo-labeling-based methods [MC-Net] generate high-confidence training targets as pseudo-labels for training unlabeled samples. Similar to consistency-based methods, the generated pseudo-labels are used to encourage mutual consistency [MC-Net] to enhance generalized feature performance. However, these methods learn features by minimizing the loss function in the last layer (i.e., the decision space), which is limited to the local region, so the model learns only the local features of the data. On the other hand, adversarial learning-based methods [li2019semi, sassnet] model the data distribution of unlabeled samples in an unsupervised setting by utilizing a discriminator. To capture the global shape constraint, a shape-aware adversarial learning method [sassnet] has been proposed for unlabeled data. Although this method is effective for learning shape-aware global features, reproducing features through a separate network is ineffective for learning. Furthermore, both consistency- and adversarial learning-based methods only consider single-class cases and can be limited when extended to a multiclass dataset.

Our goal was to improve the representation power for medical image segmentation tasks by leveraging a large amount of unlabeled data. Specifically, we intended to present an effective method that could successfully learn both local and global features from labeled and unlabeled datasets. However, there were several limitations associated with increasing representation power in previous studies. First, these studies focused only on local discrepancy minimization. Most consistency-based methods [ua_mt, MC-Net] calculate output discrepancy in the last layer such that only local features are embedded throughout the training scheme. However, both local and global features should be considered to obtain a better representation space. Second, feature relations across different classes of organs are ignored. Previous studies have only discussed the effectiveness of their methods for single-class data by embedding voxel-to-voxel local relations without distinguishing among different classes. The feature relation between different classes is also important for multiclass data.

In this paper, we propose a novel adversarial learning-based, context-aware semi-supervised segmentation method that incorporates unlabeled data to efficiently learn the distributions of labeled and unlabeled datasets and improve network performance. To resolve the aforementioned problems of recent studies, we consider voxel-wise features from multiple hidden layers, which include both the local and global information of the data, as an input to our voxel-wise feature discriminator to embed the distributions of unlabeled datasets. As illustrated in Fig. 1(b), the job of this discriminator is to determine whether a voxel-wise feature belongs to labeled or unlabeled data (real for labeled data and fake for unlabeled data). This voxel-wise feature discriminator takes the form of a multitask discriminator that can learn distributions from different classes simultaneously, thereby allowing us to embed class-specific context-aware features in the embedding space. Furthermore, we propose an improved voxel-wise representation learning method (Fig. 1(a)) for labeled data. To effectively embed unlabeled data, we must first obtain well-distributed features from labeled data prior to adversarial learning. In our previous study [lee2022voxel], we presented an explicit representation learning method for a supervised segmentation task by defining voxel-level feature relations. We adapt this previous method to embed feature representations from labeled data without information loss using a multiresolution context resizing technique. Moreover, we use the Bootstrap Your Own Latent (BYOL) approach [byol], instead of SimSiam [chen2020exploring], for learning stability.

To summarize, our contributions are as follows:

  • We propose a voxel-wise adversarial learning method that learns both the local and global contexts of labeled and unlabeled data (avoiding the local discrepancy problem) by considering voxel-wise features as an input. Furthermore, our voxel-wise feature discriminator embeds feature relations across different classes by learning a class-specific voxel-wise feature distribution.

  • We improve the previous voxel-wise representation learning method by overcoming information loss and learning stability problems. This enables our adversarial learning method to effectively learn well-distributed voxel-wise feature representations.

  • Our method achieves superior results on the Atrial Segmentation Challenge dataset and an abdominal multiorgan (MO) dataset when compared with existing state-of-the-art semi-supervised segmentation methods (i.e., consistency regularization-, pseudo-labeling-, and adversarial learning-based methods).

II Related Work

II-A Semi-Supervised Medical Image Segmentation

For semi-supervised medical image segmentation, traditional methods, such as prior-based [you2011segmentation] and clustering-based models [portela2014semi], use hand-crafted features to enhance model performance. With the advanced capabilities of CNNs, deep learning-based approaches have been widely used for medical image segmentation. Recently, semi-supervised methods based on consistency regularization [ua_mt, dtc], pseudo-labeling [MC-Net], and adversarial learning [dan, sassnet, hung2018adversarial] have proven the effectiveness of incorporating a large amount of unlabeled data for medical image segmentation tasks.

Consistency Regularization. Consistency regularization is based on the assumption that the segmentation prediction of a network should be consistent under realistic perturbations. This idea was first proposed in [bachman2014learning] and further studied in [laine2017temporal, tarvainen2018mean]. The Π-model [laine2017temporal] encourages consistent training under different augmentation and dropout conditions. To address the noisy training target problem, temporal ensembling [laine2017temporal] adopts the exponential moving average (EMA) of previous evaluations to obtain an ensemble prediction. As a more time-effective method, the teacher-student model [tarvainen2018mean] introduces a pair of networks (i.e., teacher and student networks) and enforces consistency in their predictions. Time efficiency and accuracy can be achieved by averaging model weights instead of label predictions.

In medical research, the uncertainty-aware mean teacher (UA-MT) model, proposed in [ua_mt], utilizes an uncertainty-aware teacher-student framework for LA segmentation. The base model framework was extended from the teacher-student architecture [tarvainen2018mean], and uncertainty map guidance was adopted to filter out unreliable predictions. More recently, a dual-task consistency (DTC) model [dtc] simultaneously used a pixel-wise segmentation map and level set representation as dual tasks. By utilizing the level set representation, the network could learn the geometric prior. However, the aforementioned methods tend to consider only the local context from the last layer, which can limit the representation of rich global contextual features in the embedding space.

Pseudo-labeling. The concept of pseudo-labeling was proposed in [Leepseudo], and its variants have presented significant results in semi-supervised learning. For instance, NoisyStudent [noisy] employs a pair of networks, one acting as a teacher and the other as a student. The teacher network is trained first and used to infer pseudo-labels for unlabeled images; a larger student network is then trained on a combination of labeled and pseudo-labeled data, and this process is iterated by converting the student into the teacher. Moreover, the mutual consistency network (MC-Net) [MC-Net] proposed a cycled pseudo-label scheme that uses one encoder and two marginally different decoders to utilize unlabeled data. Our method also adopts pseudo-labeling based on the teacher-student architecture to infer voxel-wise features from unlabeled data in a simple yet powerful manner.

Adversarial Learning. Inspired by the concept of generative adversarial networks (GANs) [goodfellow2014generative], several methods that use adversarial learning to exploit unlabeled data have attracted attention in semi-supervised medical image segmentation. For instance, [2018GANlesion, souly2017semi] used GANs to expand the training set, increasing data diversity and avoiding overfitting. Another key idea of using GANs in semi-supervised learning is to force the statistical prior-shape distribution and the prediction distribution to be close so that the network can effectively learn the distribution of the entire dataset (both labeled and unlabeled). The shape-aware semi-supervised segmentation network (SASSNet) [sassnet] employs GANs to learn the distribution of both labeled and unlabeled data. This method utilizes the signed distance map (SDM) of images as an input to the discriminator, which plays a vital role in embedding the geometric context of unlabeled data. Although this method [sassnet] considers global features by employing the SDM and a discriminator, context relations between different classes cannot be considered.

II-B Representation Learning

Self-supervised learning methods [tian2020contrastive, he2020momentum, chen2020simple] based on contrastive loss have proven to be effective in representation learning. In contrastive learning, positive (similar) pairs are pulled close together, whereas negative (dissimilar) pairs are pushed away. Because more negative samples can prevent collapse [tian2020contrastive], several approaches, such as large batch sizes [chen2020simple] and memory banks [he2020momentum], have been proposed. Meanwhile, non-contrastive approaches [byol, chen2020exploring] have shown that collapse can be avoided without using negative samples. BYOL [byol] is based on the teacher-student model, and its momentum-encoder branch enables the network to learn representations without negative samples. Similarly, SimSiam [chen2020exploring] uses a Siamese network and a stop-gradient operation, instead of a momentum encoder, to prevent collapse.

These non-contrastive approaches can also be employed in supervised learning to learn rich representations [lee2022voxel]. Inspired by SimSiam [chen2020exploring], our previous study [lee2022voxel] presented an effective representation learning method for medical segmentation tasks by defining voxel-level relations in the embedding space. In this study, we improve our previous method by solving the information loss and learning instability problems of Siamese networks.

Fig. 2: Overview of the proposed architecture. Two backbone networks (i.e., VNet [VNet]), the teacher and student networks, take computed tomography scans as an input. The teacher network is updated passively via an exponential moving average (EMA). The features from multiple hidden layers of the student network pass through each part of our proposed network (i.e., the voxel-wise feature layer and the voxel-wise feature discriminator) so that feature representations from labeled and unlabeled data can be learned. The features of the teacher network are used for optimizing the student and our proposed network. The student network is trained using four loss functions ($\mathcal{L}_{dice}$, $\mathcal{L}_{cons}$, $\mathcal{L}_{adv}$, and $\mathcal{L}_{feat}$). The gradients are not backpropagated through the dashed lines.

III Proposed Method

We aim to learn feature representations (i.e., local and global features) from both the labeled and unlabeled datasets. To achieve this, we propose a context-aware semi-supervised segmentation method that can be incorporated into a segmentation network (i.e., VNet [VNet]). The overall architecture is illustrated in Fig. 2. A backbone network (i.e., VNet [VNet]) takes labeled and unlabeled data (i.e., CT scans) as the inputs. We assume a training set of $N$ scans containing $N_l$ labeled and $N_u$ unlabeled scans, where $N = N_l + N_u$. We denote the labeled set as $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{N_l}$ and the unlabeled set as $\mathcal{D}_U = \{x_i\}_{i=N_l+1}^{N}$, where $x_i$ represents the 3D volume, and $y_i$ denotes the ground-truth label. The proposed architecture for semi-supervised learning consists of two parts: voxel-wise representation learning (the blue box in Fig. 2) and voxel-wise adversarial learning (the red box in Fig. 2). Features from the hidden layers of the backbone network pass through each part to learn feature representations from $\mathcal{D}_L$ and $\mathcal{D}_U$. The voxel-wise adversarial learning method takes voxel-wise features from $\mathcal{D}_L$ and $\mathcal{D}_U$, after which it learns class-specific data distributions. The voxel-wise representation learning method uses voxel-wise features from $\mathcal{D}_L$ and improves the current embeddings by defining feature relations within the same class. In Sections III-A and III-B, we describe the details of these methods. In Section III-C, we explain the overall training process of our proposed method.

III-A Voxel-wise Adversarial Learning

To leverage a large amount of unlabeled data, the network must be able to learn feature representations using only CT images. Previous consistency-based methods [ua_mt, dtc] apply a consistency loss function and train the network for consistent prediction with perturbed or transformed outputs. The consistency loss is computed between the student and teacher predictions for labeled and unlabeled data. However, this loss is computed in the last layer (i.e., the decision space), which embeds only the local features of the data. Moreover, it penalizes voxel-wise consistency while ignoring class-specific information. A similar problem exists in [sassnet], whose embedded shape-aware global features are limited to a single class.

To resolve this problem, we propose a novel voxel-wise feature discriminator for embedding class-specific features of both labeled and unlabeled data. As presented in Fig. 3(a), our voxel-wise feature discriminator takes a set of multiresolution features, $\{f_i\}_{i=1}^{L}$, as an input, where $f_i$ denotes the features from the $i$-th hidden layer of the backbone encoder. These features from multiple hidden layers pass through a convolution layer to adjust the channel size, and each feature is upsampled to the same spatial size. The features are then fused into one by an addition operation and a convolution layer. Thereafter, the voxel-level features ($d$-dimensional vectors) from this fused feature pass through a voxel-level feature discriminator, which consists of two multilayer perceptron networks (MLPs) and a prediction layer (i.e., a linear branch). The number of prediction layers corresponds to the number of classes (for the LA dataset, there are two classes: foreground and background). The voxel-level features from different classes pass through different prediction layers. To specify the class of each voxel-level feature, we use the ground-truth label for labeled data and pseudo-labels for unlabeled data, which can be computed using the following equation:

$$\hat{y}_v = \underset{c}{\arg\max}\; p_T(c \mid v), \qquad \text{if } \max_{c}\, p_T(c \mid v) > t, \tag{1}$$

where $p_T(c \mid v)$ denotes the teacher network's softmax probability for class $c$ at voxel $v$, and $t$ represents the threshold parameter, which lies in the range of $(0, 1)$.
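For illustration, the thresholding in Eq. (1) could be implemented as in the following PyTorch sketch; the function and tensor names are our own, and the teacher's softmax output is assumed to be the confidence source.

```python
import torch

def pseudo_labels(teacher_logits: torch.Tensor, t: float = 0.7):
    """Derive voxel-wise pseudo-labels from teacher predictions (sketch).

    teacher_logits: (B, C, D, H, W) raw outputs of the teacher network.
    t: confidence threshold in (0, 1).
    Returns (labels, mask): the argmax class per voxel and a boolean mask
    marking voxels whose top softmax probability exceeds t.
    """
    probs = torch.softmax(teacher_logits, dim=1)  # voxel-wise class probabilities
    conf, labels = probs.max(dim=1)               # top probability and its class
    mask = conf > t                               # keep only confident voxels
    return labels, mask
```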

Fig. 3: Details of the proposed architecture. (a) Voxel-wise feature discriminator: multiresolution features are fused, and the sampled voxel-wise features pass through multilayer perceptron networks (MLPs). The voxel-level features from different classes pass through different prediction layers, which enables the model to learn the class-specific voxel-wise distribution of unlabeled data. (b) Voxel-wise feature layer: multiresolution features pass through multiresolution feature resizing and class-specific feature selection stages. Based on a previous study [byol], we can learn voxel-wise feature relations in the representation space.

These different prediction branches enable multiple simultaneous adversarial classification tasks. We define features from labeled data as real and those from unlabeled data as fake so that the encoder of the segmentation network (the generator) learns to generate voxel-level features of unlabeled data with a distribution similar to that of the voxel-level features of labeled data. This forces the class-specific voxel-level feature distributions of the labeled and unlabeled data to be close. In this manner, the segmentation network can learn class-specific context-aware features more effectively, and the encoder can embed both local and global features using the multiresolution context-fusion technique. Denoting the voxel-wise feature discriminator by $D$, we can define our proposed adversarial loss function as follows:

$$\mathcal{L}_{adv} = -\,\mathbb{E}_{v_l \sim \mathcal{D}_L}\big[\log D(v_l)\big] - \mathbb{E}_{v_u \sim \mathcal{D}_U}\big[\log\big(1 - D(v_u)\big)\big], \tag{2}$$

where $D(\cdot)$ denotes the output of the prediction branch corresponding to the class of the input voxel-level feature.
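The fusion-and-discrimination pipeline described above can be sketched in PyTorch as follows. This is a simplified illustration under our own naming: the channel sizes, MLP depth, and the routing of voxels to class-specific branches are assumptions, not the authors' exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VoxelFeatureDiscriminator(nn.Module):
    """Illustrative multitask voxel-wise discriminator.

    Multiresolution encoder features are channel-adjusted, upsampled to a
    shared spatial size, summed, and fused; every voxel's d-dimensional
    feature is then scored as real (labeled) or fake (unlabeled) by the
    prediction branch of its class.
    """

    def __init__(self, in_channels=(16, 32, 64, 128), d=64, num_classes=2):
        super().__init__()
        # 1x1x1 convs bring every hidden layer to d channels
        self.adjust = nn.ModuleList(nn.Conv3d(c, d, kernel_size=1)
                                    for c in in_channels)
        self.fuse = nn.Conv3d(d, d, kernel_size=1)
        self.mlp = nn.Sequential(nn.Linear(d, d), nn.ReLU(inplace=True),
                                 nn.Linear(d, d), nn.ReLU(inplace=True))
        # one linear prediction branch per class
        self.heads = nn.ModuleList(nn.Linear(d, 1) for _ in range(num_classes))

    def forward(self, feats, voxel_class):
        # feats: list of (B, C_i, D_i, H_i, W_i) hidden-layer features.
        # voxel_class: (B, D, H, W) class id per voxel at the fused size
        # (ground truth for labeled data, pseudo-labels from Eq. (1) otherwise).
        size = feats[0].shape[2:]
        fused = sum(F.interpolate(conv(f), size=size, mode='trilinear',
                                  align_corners=False)
                    for conv, f in zip(self.adjust, feats))
        fused = self.fuse(fused)                                 # (B, d, D, H, W)
        v = fused.permute(0, 2, 3, 4, 1).reshape(-1, fused.shape[1])
        v = self.mlp(v)                                          # voxel-wise features
        cls = voxel_class.reshape(-1)
        scores = v.new_empty(v.shape[0])
        for c, head in enumerate(self.heads):
            m = cls == c
            if m.any():
                scores[m] = head(v[m]).squeeze(-1)  # class-specific branch
        return scores  # raw real/fake logit per voxel


def adversarial_loss(scores_labeled, scores_unlabeled):
    """Standard GAN objective matching Eq. (2): labeled voxel features are
    treated as real, unlabeled ones as fake (discriminator side)."""
    real = F.binary_cross_entropy_with_logits(
        scores_labeled, torch.ones_like(scores_labeled))
    fake = F.binary_cross_entropy_with_logits(
        scores_unlabeled, torch.zeros_like(scores_unlabeled))
    return real + fake
```

In a training loop, the discriminator would minimize the loss above, while the segmentation encoder would be updated with the generator-side objective (labels flipped for unlabeled features).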

III-B Voxel-wise Representation Learning

In Section III-A, we proposed a new voxel-wise feature discriminator that learns the feature representations of unlabeled data based on the feature distribution of labeled data. In this setting, the most important task is modeling the distribution of features from labeled data beforehand: accurate modeling of the labeled data distribution is essential for effective adversarial learning. If the distribution of labeled data is modeled incorrectly, the model is unlikely to benefit from adversarial learning; conversely, if the distribution is recovered accurately from the labeled data, our model can learn rich feature representations from both labeled and unlabeled data.

In our previous work [lee2022voxel], we proposed a voxel-level Siamese representation learning method for medical image segmentation tasks. By defining voxel-wise feature relations in the representation space, the model learned feature representations that were effective for the segmentation task. We used the stop-gradient technique and Siamese network from SimSiam [chen2020exploring] to learn voxel-wise feature relations. We also proposed a multiresolution feature aggregation method for embedding both local and global features. However, our previous study had two limitations: (1) learning stability and (2) information loss.

In this study, we propose an improved voxel-wise representation learning method for embedding features from labeled data. Inspired by previous studies [byol, lillicrap2015continuous], we address the first problem (i.e., learning stability) by using the learning technique from BYOL [byol] instead of SimSiam [chen2020exploring]. Using the EMA from BYOL enables the model to produce a more stable prediction target [lillicrap2015continuous] than the stop-gradient technique from SimSiam [chen2020exploring]. As presented in Fig. 3(b), there are teacher and student models; however, the teacher model uses a slow-moving average of the student parameters instead of learning its own parameters (i.e., EMA). We update the weights of the teacher as $\theta_t \leftarrow \tau \theta_t + (1 - \tau)\,\theta_s$, where $\tau$ represents the decay parameter, and $\theta_s$ indicates the weights of the student. Furthermore, for the second problem (i.e., information loss), we propose a multiresolution context resizing method. Information loss occurs when the mask data are downsampled to match the class location of each voxel-wise feature. Thus, instead of downsampling the mask data, we upsample the multiresolution features from the encoder. Fig. 3(b) illustrates the upsampling and convolution stages that reduce this information loss.
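A minimal sketch of this EMA update for a generic PyTorch teacher-student pair (the decay value shown is illustrative, not the paper's setting):

```python
import torch

@torch.no_grad()
def ema_update(teacher, student, tau: float = 0.99):
    # theta_t <- tau * theta_t + (1 - tau) * theta_s
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(tau).add_(p_s, alpha=1.0 - tau)
```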

As explained in Section III-A, our voxel-wise feature layer (Fig. 2 and Fig. 3(b)) uses the multiresolution features from the backbone encoder as an input. These features pass through the upsampling and convolution stages, and the voxel-wise features $v_c$ are selected for each class; here, $v_c$ refers to a voxel-wise feature from class $c$ (class-specific feature selection). The sampled voxel-wise features pass through the projection and prediction layers. The projection layer of the teacher network outputs $z_t$, and the projection and prediction layers of the student network output $q(z_s)$, where $q$ denotes the prediction layer. Following previous research [byol], we use the mean squared error between the normalized $q(z_s)$ and $z_t$ as the feature loss function, which for updating the student network can be defined as follows:

$$\mathcal{L}_{feat} = \big\lVert\, \overline{q(z_s)} - \overline{z_t} \,\big\rVert_2^2, \tag{3}$$

where $\overline{u}$ refers to $\ell_2$-normalization (i.e., $\overline{u} = u / \lVert u \rVert_2$).
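Eq. (3) amounts to a BYOL-style normalized mean squared error between the student's predicted features and the teacher's projected features; a minimal sketch (the names are ours):

```python
import torch.nn.functional as F

def feature_loss(student_pred, teacher_proj):
    """BYOL-style loss between l2-normalized voxel features (sketch).

    student_pred: (N, d) outputs of the student projection + prediction layers.
    teacher_proj: (N, d) outputs of the teacher projection layer.
    For unit vectors, ||q - z||^2 = 2 - 2 * cosine_similarity(q, z).
    """
    q = F.normalize(student_pred, dim=-1)
    z = F.normalize(teacher_proj.detach(), dim=-1)  # no gradient to the teacher
    return ((q - z) ** 2).sum(dim=-1).mean()
```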

III-C Training Details

Our backbone network is based on VNet [VNet]. We first describe the basic VNet [VNet] segmentation training scheme for the labeled dataset. Two VNets [VNet] are displayed in Fig. 2: the teacher and student networks. These two networks take the 3D volume $x_i$ as an input, and they output the prediction masks $p_t$ and $p_s$, respectively. Based on [ua_mt, sassnet], we use the dice loss [dc] on the labeled dataset $\mathcal{D}_L$ to maximize the overlap between the ground truth and the prediction when training the student network. The dice loss can be defined as

$$\mathcal{L}_{dice} = 1 - \frac{2\sum_{v} p_s^{(v)}\, y^{(v)}}{\sum_{v} p_s^{(v)} + \sum_{v} y^{(v)}}, \tag{4}$$

where $p_s^{(v)}$ and $y^{(v)}$ denote the student prediction and the ground-truth label at voxel $v$, respectively.
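A minimal soft dice loss matching Eq. (4) for the binary case; the smoothing constant eps is a common implementation detail we add for numerical stability, not a value from the paper:

```python
import torch

def dice_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-5):
    """Soft dice loss over foreground probabilities and a binary mask (sketch)."""
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)
```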

For updating the teacher network, we used the EMA [lillicrap2015continuous] technique.

Following [tarvainen2018mean], we also added a consistency loss between the softmax predictions of the teacher and student networks for semi-supervised learning. The consistency loss between the outputs of the teacher and student networks can be summarized as follows:

$$\mathcal{L}_{cons} = \mathbb{E}_{x}\!\left[\big\lVert f(x;\theta_s) - f(x;\theta_t) \big\rVert^2\right], \tag{5}$$

where $f$ represents the VNet architecture [VNet], and $\theta_s$ and $\theta_t$ denote the weights of the student and teacher networks, respectively. We can stabilize the label prediction by using the teacher-student framework and penalize predictions that are inconsistent with the target (i.e., the output of the teacher network) by adding the consistency loss. In this manner, we can learn generalized local features of both the labeled and unlabeled datasets.
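In code, Eq. (5) reduces to a mean squared error between the two softmax outputs, with the teacher detached so that only the student receives gradients (a sketch under our naming):

```python
import torch
import torch.nn.functional as F

def consistency_loss(student_logits, teacher_logits):
    """MSE between student and teacher softmax predictions (sketch)."""
    p_s = torch.softmax(student_logits, dim=1)
    p_t = torch.softmax(teacher_logits, dim=1).detach()  # target is fixed
    return F.mse_loss(p_s, p_t)
```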

The final loss function for training the student network (i.e., VNet [VNet]) is summarized as follows:

$$\mathcal{L} = \mathcal{L}_{dice} + \lambda_{cons}\,\mathcal{L}_{cons} + \lambda_{adv}\,\mathcal{L}_{adv} + \lambda_{feat}\,\mathcal{L}_{feat}, \tag{6}$$

where $\lambda_{cons}$, $\lambda_{adv}$, and $\lambda_{feat}$ represent the coefficients used to balance the different loss terms.
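Combining the four objectives, a hypothetical composition of Eq. (6); the default coefficients below are placeholders (Section IV-B reports the settings actually used):

```python
def total_loss(l_dice, l_cons, l_adv, l_feat,
               lam_cons=0.1, lam_adv=0.01, lam_feat=0.1):
    # Weighted sum of Eq. (6); coefficient values are illustrative only.
    return l_dice + lam_cons * l_cons + lam_adv * l_adv + lam_feat * l_feat
```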

IV Experimental Results

Method                   Labeled    Unlabeled   Dice(%)   Jaccard(%)   95HD(voxel)   ASSD(voxel)
VNet                     8 (10%)    72          79.99     58.12        21.11         5.48
VNet                     16 (20%)   64          86.03     76.06        14.26         3.51
VNet                     80 (All)   0           91.14     83.82        5.75          1.52
DAP [dap]                8 (10%)    72          81.89     71.23        15.81         3.80
UA-MT [ua_mt]            8 (10%)    72          84.25     73.48        13.84         3.36
SASSNet [sassnet]        8 (10%)    72          87.32     77.72        9.62          2.55
LG-ER-MT [lg-er-mt]      8 (10%)    72          85.54     75.12        13.29         3.77
DUWM [duwm]              8 (10%)    72          85.91     75.75        12.67         3.31
DTC [dtc]                8 (10%)    72          86.57     76.55        14.47         3.74
CVRL [you2021momentum]   8 (10%)    72          87.72     78.29        9.34          2.23
MC-Net [MC-Net]          8 (10%)    72          87.71     78.31        9.36          2.18
Ours                     8 (10%)    72          88.42     79.38        8.74          2.52
DAP [dap]                16 (20%)   64          87.89     78.72        9.29          2.74
UA-MT [ua_mt]            16 (20%)   64          88.88     80.21        7.32          2.26
SASSNet [sassnet]        16 (20%)   64          89.54     81.24        8.24          2.20
LG-ER-MT [lg-er-mt]      16 (20%)   64          89.62     81.31        7.16          2.06
DUWM [duwm]              16 (20%)   64          89.65     81.35        7.04          2.03
DTC [dtc]                16 (20%)   64          89.42     80.98        7.32          2.10
CVRL [you2021momentum]   16 (20%)   64          89.87     81.65        6.96          1.72
MC-Net [MC-Net]          16 (20%)   64          90.34     82.48        6.00          1.77
Ours                     16 (20%)   64          90.56     82.84        5.95          1.79
TABLE I: Quantitative comparisons of the performances of semi-supervised segmentation models on the left atrium dataset. "Labeled" and "Unlabeled" denote the numbers of scans used. All models use VNet as the backbone network.
Fig. 4: Box plots of the dice score coefficient of different methods for eight different organs: (a) spleen, (b) left kidney, (c) gallbladder, (d) esophagus, (e) liver, (f) stomach, (g) pancreas, and (h) duodenum.
                     Average metrics                          Per-organ DSC(%)
Method               DSC(%)  JC(%)  HD(voxel)  ASSD(voxel)    spleen  left kidney  gallbladder  esophagus  liver  stomach  pancreas  duodenum
VNet [VNet]          66.58   54.08  5.74       1.79           87.79   81.98        64.69        44.88      91.00  66.51    53.39     42.42
UA-MT [ua_mt]        69.57   56.90  4.99       1.36           89.64   77.53        67.82        56.19      92.21  70.73    54.58     47.86
SASSNet [sassnet]    69.09   56.42  4.85       1.48           87.42   87.26        60.19        54.16      90.41  69.41    57.30     46.59
DTC [dtc]            69.39   57.00  5.78       1.79           89.05   87.03        59.64        56.11      91.23  68.45    56.63     46.99
MC-Net [MC-Net]      69.76   57.34  5.61       1.90           89.15   87.82        64.66        50.50      92.28  71.22    56.97     45.49
Ours                 71.28   59.01  4.32       1.24           89.75   87.07        66.64        56.01      93.03  71.58    59.08     47.03
TABLE II: Quantitative comparisons of the performances of semi-supervised segmentation models on the multiorgan dataset.
Fig. 5: Qualitative comparison of different semi-supervised segmentation methods using the left atrium dataset with 20% labeled data. The first and second rows present the 2D and 3D visualization results, respectively.
Fig. 6: Qualitative comparison of different semi-supervised segmentation methods based on the 2D and 3D visualization results obtained using the multiorgan dataset with 20% labeled data.

IV-A Dataset details

We evaluated our method using two datasets: the LA dataset from the Atrial Segmentation Challenge and an MO dataset.

Atrial Segmentation Challenge dataset. We used 100 3D gadolinium-enhanced magnetic resonance imaging scans and their LA segmentation masks for training and validation. The scans have an isotropic resolution of 0.625 × 0.625 × 0.625 mm³. Following the settings of previous methods [ua_mt, sassnet], the dataset was separated into two sets: 80 images for training and 20 for testing. We applied the same preprocessing method.

Abdominal multiorgan dataset. To further evaluate the effectiveness of our method in multiclass segmentation, we evaluated its performance on an MO dataset. We used 90 abdominal CT images: 47 from the Beyond the Cranial Vault dataset [btcv] and 43 from the Pancreas-CT dataset. The segmentation standard consisted of the spleen, left kidney, gallbladder, esophagus, liver, stomach, pancreas, and duodenum. The slice thickness and in-plane pixel size varied across scans. The dataset was separated into two sets: 70 images for training and 20 for testing. We resampled all abdominal CT images to a common voxel grid and preprocessed the images using a soft-tissue CT window (in Hounsfield units). After rescaling, we normalized the input images to zero mean and unit variance.

IV-B Implementation details

For training on both the LA and MO datasets, we used the VNet [VNet] architecture as the base network. We set the batch size to 4, with each batch including two labeled patches and two unlabeled patches.

For the LA dataset, we used the stochastic gradient descent optimizer (momentum = 0.9, weight decay = 0.0001) for 6000 iterations, with an initial learning rate of 0.01, divided by 10 every 2500 iterations. To train the multitask feature discriminator, we followed the method described in [kurach2019largescale]; we used an Adam optimizer ($\beta_1$ = 0.5, $\beta_2$ = 0.999) and a learning rate of 0.0002. The weighting parameter was 0.01 for $\mathcal{L}_{adv}$ and 0.1 for $\mathcal{L}_{feat}$. Following [sassnet], we used a Gaussian warming-up function for the consistency loss weight, which depends on the number of training iterations. Based on our previous study [lee2022voxel], the dimensions of all hidden layers in the voxel-wise feature layer were set to 64. Furthermore, we used a threshold $t$ of 0.7. We implemented our framework in PyTorch [paszke2019pytorch], using an NVIDIA TITAN RTX GPU and a Tesla V100 GPU. At inference time, only the VNet framework was used for segmentation.
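The Gaussian warming-up schedule can be sketched as below; the exp(-5(1 - t/t_max)^2) form follows the ramp-up commonly used in [tarvainen2018mean, sassnet], and the constants are an assumption rather than the paper's exact values:

```python
import math

def cons_rampup(iteration: int, max_iterations: int = 6000) -> float:
    """Gaussian warm-up weight for the consistency loss (sketch)."""
    t = min(iteration, max_iterations) / max_iterations
    return math.exp(-5.0 * (1.0 - t) ** 2)
```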

TABLE III: Visualization of the feature alignment progress during the training phase (at 0.1k, 0.5k, and 1k iterations) using our proposed method with ablations (VNet + $\mathcal{L}_{feat}$, VNet + $\mathcal{L}_{adv}$, and Ours, i.e., VNet + $\mathcal{L}_{feat}$ + $\mathcal{L}_{adv}$) and MC-Net [MC-Net]. We generated the visualizations using labeled (marked by triangles) and unlabeled (marked by circles) data, which are presented separately for comparison.
Method                                                             Labeled    Unlabeled   Dice(%)   Jaccard(%)   95HD(voxel)   ASSD(voxel)
VNet                                                               16 (20%)   64          86.03     76.06        14.26         3.51
VNet + $\mathcal{L}_{adv}$                                         16 (20%)   64          88.76     80.01        10.46         2.64
VNet + $\mathcal{L}_{feat}$                                        16 (20%)   64          88.67     79.85        11.52         3.31
VNet + $\mathcal{L}_{adv}$ + $\mathcal{L}_{feat}$                  16 (20%)   64          90.39     82.56        10.11         2.70
VNet + $\mathcal{L}_{adv}$ + $\mathcal{L}_{feat}$ + $\mathcal{L}_{cons}$   16 (20%)   64   90.56    82.84        5.95          1.79
TABLE IV: Ablation study of the effectiveness of our proposed method on the left atrium dataset.

For the MO dataset, we used the Adam optimizer ($\beta_1$ = 0.9, $\beta_2$ = 0.999) with an initial learning rate of 0.001, decayed by a factor of 0.1 every 2500 iterations. The weighting parameter was 0.01 for $\mathcal{L}_{adv}$ and 100 for $\mathcal{L}_{feat}$. The remaining experimental settings were the same as those employed in the LA dataset experiments.

IV-C Results

For our evaluation metrics, we used the dice score coefficient (DSC) [dc], the 95% Hausdorff distance (95HD) [hd, hd2], the average symmetric surface distance (ASSD) [assd], and the Jaccard index.

Left Atrial Segmentation Challenge dataset. We evaluated the performance of our proposed network in terms of its accuracy by comparing our results with those of state-of-the-art models, i.e., the domain-agnostic prior (DAP) [dap], UA-MT [ua_mt], SASSNet [sassnet], the local and global structure-aware entropy-regularized mean teacher (LG-ER-MT) [lg-er-mt], the double-uncertainty weighted method (DUWM) [duwm], DTC [dtc], contrastive voxel-wise representation learning (CVRL) [you2021momentum], and MC-Net [MC-Net]. We adopted the two semi-supervised settings widely used on the LA dataset from a previous study [sassnet] (i.e., using either 10% or 20% of the training data as labeled). Table I lists the quantitative results of LA segmentation. The results indicate that our proposed method achieves superior results in terms of the DSC, Jaccard index, and 95HD measurements and competitive results on ASSD under both the 10% and 20% labeled-data conditions. Qualitative results are illustrated in Fig. 5: our method has a higher overlap ratio with the ground truth in both the 2D and 3D visualizations, producing fewer false positives.

Abdominal multiorgan dataset. To prove the effectiveness of our method on a multiclass dataset, we conducted an experiment on the MO dataset. For comparison, several state-of-the-art models (i.e., UA-MT [ua_mt], SASSNet [sassnet], DTC [dtc], and MC-Net [MC-Net]) and the base network, VNet, were used for evaluation. We considered 20% of the 70 training images as labeled data (14 labeled) and the rest as unlabeled data (56 unlabeled). All models used VNet as their backbone network. Table II presents quantitative comparisons of the segmentation results. The results indicate that our method outperforms the other methods in terms of all evaluation metrics (i.e., DSC (71.28%), Jaccard index (59.01%), HD (4.32), and ASSD (1.24)). Our method achieves significant improvements in the segmentation of the spleen, liver, stomach, and pancreas and demonstrates competitive results for the other organs. A box plot for a more precise quantitative comparison is presented in Fig. 4. The qualitative results illustrated in Fig. 6 indicate that our method segments multiple organs better than the other methods.

Fig. 7: Visualization of features from the second layer using (a) VNet [VNet] and (b) our method. The features are colored based on the class labels, and we visualize them using the test dataset (labels are only used for visualization).

IV-D Ablation Study

We performed an ablation study to investigate the effectiveness of the major components of the proposed loss function. We trained VNet with 20% labeled data on the MO and LA datasets, and the results are presented in Tables III and IV, respectively.

From Table III, we can observe that each major component of our proposed method (i.e., $\mathcal{L}_{adv}$ and $\mathcal{L}_{feat}$) contributes to a more structured representation space during training. Specifically, $\mathcal{L}_{adv}$ guides unlabeled data to follow the distribution of labeled data, and $\mathcal{L}_{feat}$ plays a significant role in generating separated feature representations, as expected.

Table IV lists the comparison results of the ablations, wherein our losses ($\mathcal{L}_{adv}$, $\mathcal{L}_{feat}$, and $\mathcal{L}_{cons}$) were gradually incorporated. The results reveal a significant performance improvement when the two losses $\mathcal{L}_{adv}$ and $\mathcal{L}_{feat}$ were used together rather than separately. This demonstrates that these losses achieve synergy by learning the distribution of unlabeled features from well-distributed labeled features. Furthermore, including the consistency loss $\mathcal{L}_{cons}$ achieves further improvements by stabilizing label prediction.

V Discussion

Recent semi-supervised segmentation approaches in medical imaging have demonstrated promising results by employing various techniques, such as consistency regularization [ua_mt, dtc], pseudo-labeling [MC-Net], and adversarial learning [sassnet]. However, previous methods train the network with the outputs obtained from the final layer, which makes it difficult for the network to learn global features. The proposed method is effective for learning both local and global contexts by embedding voxel-level features with the voxel-wise feature layer and voxel-wise feature discriminator (Table I and Fig. 5). We achieved a more structured representation space (Fig. 7(b)) by defining voxel-level feature relations (including global and local context) in the representation space. Compared with a previous method [sassnet], which also includes global contextual information through a discriminator and the SDM, our method achieved superior results (Table II), particularly on the multiclass dataset. By learning class-specific voxel-level features using BYOL [byol] and a multitask discriminator, we achieved a more structured representation space (Fig. 7 and Table III) and precise segmentation results for the multiclass dataset (Table II and Fig. 6). This indicates that our method is effective for learning feature relations across different classes. Moreover, as presented in Table IV, significant performance improvements can be observed when the voxel-wise feature discriminator and voxel-wise representation learning are used simultaneously; this implies that the unlabeled data distribution follows the labeled data distribution as we intended (Table III), thereby embedding rich feature representations. In future work, the results could be improved further by designing a more efficient method for enabling unlabeled data to follow the distribution of labeled data.

VI Conclusion

In this work, we proposed a novel semi-supervised learning method for medical image segmentation tasks. Specifically, our voxel-wise representation learning method embedded feature representations (i.e., local and global features) in the representation space, and our voxel-wise feature discriminator successfully leveraged unlabeled data using the distribution of features from the labeled data. Extensive experimental results indicated that our method achieved competitive results when compared with existing state-of-the-art approaches. Furthermore, our method provides a more informative representation that embeds class-specific features and achieves superior results in multiclass segmentation. We believe that our approach can provide a useful perspective on medical imaging tasks and can be applied to various medical datasets, regardless of the number of classes.

Acknowledgements

The authors report no conflict of interest.
