
SS-3DCapsNet: Self-supervised 3D Capsule Networks for Medical Segmentation on Less Labeled Data

Capsule networks are a recent deep network architecture that has been applied successfully to medical image segmentation tasks. This work extends capsule networks to volumetric medical image segmentation with self-supervised learning. To improve weight initialization over previous capsule networks, we leverage self-supervised learning to pre-train the capsule network, where our pretext task is optimized by self-reconstruction. Our capsule network, SS-3DCapsNet, has a UNet-based architecture with a 3D capsule encoder and a 3D CNN decoder. Our experiments on multiple datasets, including iSeg-2017, Hippocampus, and Cardiac, demonstrate that our 3D capsule network with self-supervised pre-training considerably outperforms previous capsule networks and 3D-UNets.


1 Introduction

Since the introduction of UNet [ronneberger2015u, cciccek20163d], UNet-based neural networks have achieved impressive performance in various modalities of medical image segmentation (MIS), e.g., brain tumor [brain_le_2018, le2021multi, ho2021point], infant brain [hoang2021dam, le2021offset], liver tumor [bilic2019liver], optic disc [ramani2020improved], retina [le2021narrow], lung [nguyen20213ducaps], and cell [moshkov2020test]. Recently, capsule networks [sabour2017dynamic] have also been applied successfully to MIS [lalonde2018capsules, lalonde2021capsules, nguyen20213ducaps]. Despite this progress, a wide range of challenges remains: (1) Most methods are based on supervised learning, which is prone to data problems such as small-scale datasets, low-quality annotation, small objects, and ambiguous boundaries. These problems are not straightforward to overcome: labeling medical data is laborious and expensive, requiring an expert's domain knowledge. (2) Capsule networks for medical segmentation do not yet outperform CNNs, even though the performance gap has narrowed significantly [nguyen20213ducaps].

To address such limitations and inspired by the recent success of capsule networks, in this work we develop SS-3DCapsNet, a self-supervised capsule network for volumetric MIS. Our SS-3DCapsNet is built upon a state-of-the-art 3D capsule network that uses 3D capsule blocks in the encoder and CNN blocks in the decoder, which accounts for temporal relations across volumetric slices when learning contextual visual representations. We introduce self-supervised learning (SSL) to our 3D capsule network, which results in a UNet-like architecture with three pathways: visual representation, encoder, and decoder. The first pathway consists of dilated convolutional layers, which are pre-trained with SSL. The encoder pathway is built upon 3D capsule blocks, whereas the decoder pathway is built upon 3D CNN blocks. Compared to 2D-SegCaps [lalonde2021capsules], which is highly dependent on random factors such as sampling order and weight initialization, our SS-3DCapsNet learns better visual representations and has a more robust weight initialization thanks to self-supervised learning. Compared to 3D-UCaps [nguyen20213ducaps], we show that self-supervised learning yields an additional gain in segmentation accuracy while keeping the same network complexity at test time.

Our contributions are: (1) an effective self-supervised 3D capsule network for volumetric image segmentation, whose architecture inherits the merits of 3D capsule blocks, 3D CNN blocks, and self-supervised learning for better visual representation learning; and (2) a suite of experiments with ablation studies that empirically demonstrates the effectiveness of self-supervised 3D capsule networks for MIS.

2 Related works

Medical Segmentation. Among various DL architectures [ZHOU2021102193, siddique2021u], encoder-decoder networks such as UNet [ronneberger2015u] and its extensions have achieved impressive performance in semantic segmentation. Since the seminal work of 3D UNet [cciccek20163d] for MIS, there have been numerous subsequent works on this task. As shown in a recent survey [lei2020medical], DL-based MIS can be divided into two main groups: supervised learning and weakly supervised learning techniques.

The first group includes CNN-based supervised learning methods such as FCN [long2015fully], UNet [3DUNet], CC-3D-FCN [nie20183], RLS [le2018deep], ACRes [le2021multi], DenseVoxNet [jegou2017one], Flow-based [bui2020flow], VoxResNet [chen2018voxresnet], 3D DR-UNet [vesal2018dilated], Recurrent Level Set [le2018deep], Atrous-Net [le2021multi], Offset Curves Loss [le2021offset, le2021narrow], and Point-Unet [nguyen20213ducaps]. The second group includes weakly supervised learning methods such as transfer learning [kalinin2020medical], domain adaptation [chen2019synergistic], and interactive segmentation [wang2018deepigeos]. To address the issue of limited training data, Generative Adversarial Networks (GANs) [mirza2014conditional] have been incorporated into CNNs [chang2020synthetic, le2021enhance, le2021pairflow, bui2020flow]. Training with imperfect datasets with scarce or weak annotations has also been considered recently [tajbakhsh2020embracing].

Capsule Networks. The capsule network (CapsNet) [hinton2011transforming] is a network architecture concept that strengthens feature learning by retaining more information at the aggregation layer for pose reasoning and for learning part-whole relationships, which makes it a potential solution for semantic segmentation and object detection tasks. In CapsNet, a capsule aims to represent an entity: the capsule norm indicates the probability that the entity is present, and the capsule direction indicates the configuration the entity is in. CapsNet was recently made practical [sabour2017dynamic] in a CNN that incorporates two layers of capsules with dynamic routing.

While most CapsNet variants have been proposed for image classification, SegCaps [lalonde2018capsules, lalonde2021capsules, nguyen20213ducaps] extended CapsNet to object segmentation. This method treats an MRI volume as a collection of slices, each of which is encoded and decoded by capsules to output the segmentation. However, SegCaps is mainly designed for 2D still images, and it performs poorly when applied to volumetric data because it misses temporal information. 3D-UCaps [nguyen20213ducaps] is a hybrid network architecture that uses capsules for feature learning and deconvolutions for producing the segmentation output, and shows that such a combination significantly outperforms the SegCaps design in the segmentation task while retaining the merits of capsules. Our method further improves upon 3D-UCaps by integrating an efficient pre-training stage.

Self-supervised Learning. Self-supervised learning (SSL) is a technique for learning feature representations in a network without requiring a labeled dataset. A common workflow is to train the network in an unsupervised manner on a pretext task during a pre-training stage, and then fine-tune the pre-trained network on a target downstream task. In the case of MIS, suitable pretext tasks fall into four categories: context-based, generation-based, free semantic label-based, and cross-modal-based. The first category utilizes context features of images or videos such as context similarity [caron2018deep], spatial structure [ahsan2019video], and temporal structure [wei2018learning]. The second category has been used in image generation [zhang2016colorful] and video generation [srivastava2015unsupervised]. The third category aims to automatically generate semantic labels and has been applied to segmentation and contour detection [pathak2017learning]. The fourth category is applied to multi-modal data such as visual-audio [korbar2018cooperative] and RGB-Flow [sayed2018cross]. In this work, our pretext task is based on image reconstruction.

3 SS-3DCapsNet: Self-supervised 3D Capsule Networks

We draw on the ideas of SegCaps [lalonde2018capsules] and 3D-UCaps [nguyen20213ducaps] to build our 3D capsule network for the medical segmentation task. Particularly, our network has three stages: (i) Visual representation, (ii) Capsule encoder, and (iii) Convolutional decoder as follows.

Figure 1: Our proposed SS-3DCapsNet architecture with three components: visual representation, capsule encoder, and convolutional decoder. Numbers on the blocks indicate the number of channels in convolution layers and the dimension of capsules in capsule layers.

(i) Visual Representation: This stage converts the input into a feature volume that the capsule encoder can consume. Following the concurrent work, we use three dilated convolution layers with 16, 32, and 64 channels and dilation rates of 1, 3, and 3, respectively; the kernel size and the size of the resulting visual features are kept as in that work.
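For concreteness, the following is a minimal PyTorch sketch of this visual-representation stage. Only the channel widths (16, 32, 64) and dilation rates (1, 3, 3) come from the description above; the kernel size (k=5), 'same' padding, and ReLU activations are assumptions.

```python
import torch.nn as nn

class VisualRepresentation(nn.Module):
    """Three dilated 3D convolutions with 16, 32, and 64 channels and dilation
    rates 1, 3, 3. Kernel size k=5 and 'same' padding are assumptions; they keep
    the spatial size of the feature volume equal to the input."""
    def __init__(self, in_channels=1, k=5):
        super().__init__()
        def block(c_in, c_out, dilation):
            pad = dilation * (k - 1) // 2                    # preserve H, W, D
            return nn.Sequential(
                nn.Conv3d(c_in, c_out, kernel_size=k, dilation=dilation, padding=pad),
                nn.ReLU(inplace=True),
            )
        self.features = nn.Sequential(
            block(in_channels, 16, dilation=1),
            block(16, 32, dilation=3),
            block(32, 64, dilation=3),
        )

    def forward(self, x):                                    # x: (B, C_in, H, W, D)
        return self.features(x)                              # (B, 64, H, W, D)
```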

(ii) Capsule Encoder: In this stage, we reshape the feature volume into a grid of capsules, where each capsule is represented by a 64-dimensional vector. Here we consider both spatial and temporal data by using 3D convolutional capsules to learn a richer representation. The output of a convolutional capsule layer has shape $H \times W \times D \times C \times A$, where $C$ is the number of capsule types and $A$ is the dimension of each capsule. We follow the concurrent work in setting the number of capsule types for each layer of the capsule encoder. Note that since the number of capsule types in the last convolutional capsule layer equals the number of class labels, we can further supervise this particular layer with a margin loss [sabour2017dynamic].
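A small sketch of the reshaping step described above is given below, assuming a capsule-grid layout of (batch, H, W, D, capsule types C, capsule dimension A); the exact tensor layout is our assumption, not specified in the text.

```python
import torch

def to_capsule_grid(feat: torch.Tensor) -> torch.Tensor:
    """Reshape the visual feature volume (B, 64, H, W, D) into a capsule grid
    (B, H, W, D, C=1, A=64): a single capsule type whose 64-dimensional vector
    is the feature at each voxel."""
    B, A, H, W, D = feat.shape
    return feat.permute(0, 2, 3, 4, 1).reshape(B, H, W, D, 1, A)

def capsule_lengths(caps: torch.Tensor) -> torch.Tensor:
    """Capsule norms, i.e. presence probabilities. For the last capsule layer,
    where C equals the number of class labels, these lengths are what the
    margin loss supervises against the downsampled ground truth."""
    return caps.norm(dim=-1)                                 # (B, H, W, D, C)
```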

(iii) Convolutional Decoder: This is the final stage of our network. Here we use the decoder of 3D UNet [cciccek20163d], which includes deconvolution, skip connection, convolution, and BatchNorm layers [ioffe2015batch], to generate the segmentation from the features learned by the capsule layers. Particularly, we reshape the capsules back to feature tensors and pass them to the decoder. The overall architecture can be seen in Fig. 1.
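Below is a hedged sketch of flattening the capsule grid back into a feature tensor and of one decoder step; the channel sizes and exact layer ordering inside a block are placeholders, while the ingredients (deconvolution, skip connection, convolution, BatchNorm) follow the description above.

```python
import torch
import torch.nn as nn

def capsules_to_feature(caps: torch.Tensor) -> torch.Tensor:
    """Flatten a capsule grid (B, H, W, D, C, A) back into a feature tensor
    (B, C*A, H, W, D) that the convolutional decoder can consume."""
    B, H, W, D, C, A = caps.shape
    return caps.reshape(B, H, W, D, C * A).permute(0, 4, 1, 2, 3)

class DecoderBlock(nn.Module):
    """One 3D UNet-style decoder step: transposed convolution for upsampling,
    concatenation with the encoder skip feature, then Conv3d + BatchNorm + ReLU.
    Channel sizes are placeholders."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv3d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                            # upsample by a factor of 2
        x = torch.cat([x, skip], dim=1)           # skip connection from the encoder
        return self.conv(x)
```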

3.1 Pretext Task

Our pretext task is self-supervised and based on medical image reconstruction. In computer vision, it is common to supervise the pretext task with pseudo-labels defined by image transformations, e.g., rotation, random cropping, adding noise, blurring, scaling, flipping, or solving a jigsaw puzzle. While such transformations work well when the downstream task is classification, our downstream task is segmentation, so we design a pretext task that reconstructs the original image. As medical images are captured with low contrast and the objects-of-interest usually follow specific patterns, we select contrast transformations and optimize the pretext task with a reconstruction loss.

Figure 2: Examples of the six transformations for self-supervised learning. (a): original image. (b), from left to right, top to bottom: zeroed green channel, zeroed red channel, zeroed blue channel, swapping (the 4 swapped patches are shown in yellow boxes), blurring, and noise.
Figure 3: Our pretext task with reconstruction loss.

The details of our pre-training are as follows. Our pretext task is based on reconstruction from various transformations, i.e., noise, blurring, zeroed channels (R, G, B), and swapping, as shown in Fig. 2. Let $f$ denote the visual representation network and let $\{T_0, T_1, \dots, T_K\}$ denote the set of transformations, where $T_0$ is the identity transformation and $K$ is set to 6, corresponding to the six transformations in Fig. 2. Let $x$ denote the original input volumetric data. Our pretext task is performed by applying two random transformations to $x$; the transformed data are then $T_i(x)$ and $T_j(x)$, respectively. The visual features of the transformed data after applying the network are $f(T_i(x))$ and $f(T_j(x))$, where $i, j \in \{0, \dots, K\}$ and $i \neq j$. The network is trained with a reconstruction loss defined by:

$\mathcal{L}_{rec} = \big\| f(T_i(x)) - x \big\|_2^2 + \big\| f(T_j(x)) - x \big\|_2^2 \qquad (1)$

The pretext task procedure is illustrated in Fig. 3.
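To make the procedure concrete, here is a minimal PyTorch sketch of one pre-training step. The transformation parameters (noise scale, patch size, blur kernel) are assumptions, the zero-channel transformations assume a 3-channel input as in Fig. 2, and the hypothetical recon_head mapping the 64-channel visual feature back to image channels is our addition; the text only states that the original volume is reconstructed from the transformed input.

```python
import random
import torch
import torch.nn.functional as F

def zero_channel(x, c):
    """Zero out one input channel (the R, G, or B channel in Fig. 2)."""
    x = x.clone()
    x[:, c] = 0.0
    return x

def swap_patches(x, patch=16, n_swaps=4):
    """Randomly swap n_swaps pairs of cubic patches inside the volume."""
    x = x.clone()
    _, _, H, W, D = x.shape
    for _ in range(n_swaps):
        h1, w1, d1 = [random.randrange(s - patch) for s in (H, W, D)]
        h2, w2, d2 = [random.randrange(s - patch) for s in (H, W, D)]
        a = x[:, :, h1:h1+patch, w1:w1+patch, d1:d1+patch].clone()
        b = x[:, :, h2:h2+patch, w2:w2+patch, d2:d2+patch].clone()
        x[:, :, h1:h1+patch, w1:w1+patch, d1:d1+patch] = b
        x[:, :, h2:h2+patch, w2:w2+patch, d2:d2+patch] = a
    return x

def pretext_step(f, recon_head, x, optimizer):
    """One self-supervised pre-training step: draw two random transformations
    T_i, T_j, encode the transformed volumes with the visual-representation
    network f, and minimize the reconstruction loss of Eq. (1) against the
    original x. recon_head maps the 64-channel feature back to image channels."""
    T = [
        lambda v: v,                                           # T_0: identity
        lambda v: zero_channel(v, 0),                          # zeroed R channel
        lambda v: zero_channel(v, 1),                          # zeroed G channel
        lambda v: zero_channel(v, 2),                          # zeroed B channel
        swap_patches,                                          # patch swapping
        lambda v: F.avg_pool3d(v, 3, stride=1, padding=1),     # blurring
        lambda v: v + 0.1 * torch.randn_like(v),               # additive noise
    ]
    Ti, Tj = random.sample(T, 2)
    loss = sum(F.mse_loss(recon_head(f(t(x))), x) for t in (Ti, Tj))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In such a sketch, recon_head could be as simple as a 1x1x1 Conv3d from 64 channels back to the input channels and would be discarded after pre-training; this is an assumption rather than the paper's stated design.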

3.2 Downstream Task

After pre-training, we train our SS-3DCapsNet network with annotated data on the medical segmentation task. The total loss function to train this downstream task is a sum of three losses:

$\mathcal{L}_{total} = \mathcal{L}_{margin} + \mathcal{L}_{CE} + \mathcal{L}_{rec} \qquad (2)$

The margin loss is adopted from  [sabour2017dynamic] and it is defined between the predicted label and the ground truth label as follows:

$\mathcal{L}_{margin} = \sum_{k} \Big[ T_k \max\big(0,\, m^{+} - \|\mathbf{v}_k\|\big)^2 + \lambda \,(1 - T_k) \max\big(0,\, \|\mathbf{v}_k\| - m^{-}\big)^2 \Big] \qquad (3)$

where $\mathbf{v}_k$ is the output capsule for class $k$, $T_k = 1$ if and only if class $k$ is present in the ground truth, and $m^{+}$, $m^{-}$, and $\lambda$ are the margin parameters from [sabour2017dynamic].

Particularly, we compute the margin loss ($\mathcal{L}_{margin}$) on the capsule encoder output against the downsampled ground-truth segmentation, and the weighted cross-entropy loss ($\mathcal{L}_{CE}$) on the convolutional decoder output. We also regularize the training with a network branch that reconstructs the original input, supervised by a masked mean-squared error ($\mathcal{L}_{rec}$) [sabour2017dynamic, lalonde2018capsules].
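A sketch of Eq. (2) in PyTorch follows. The margin-loss constants (m+ = 0.9, m- = 0.1, lambda = 0.5) are the defaults from [sabour2017dynamic], and masking the reconstruction to foreground voxels is one reading of the masked MSE; both should be treated as assumptions rather than the paper's exact settings.

```python
import torch
import torch.nn.functional as F

def margin_loss(lengths, target_onehot, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss of Eq. (3), applied per voxel to the capsule lengths
    (B, K, H, W, D), where K is the number of classes. The constants are the
    defaults from the dynamic-routing paper."""
    pos = target_onehot * F.relu(m_pos - lengths) ** 2
    neg = lam * (1.0 - target_onehot) * F.relu(lengths - m_neg) ** 2
    return (pos + neg).sum(dim=1).mean()

def total_loss(lengths, logits, recon, x, y_down, y, class_weights):
    """Eq. (2): margin loss on the (downsampled) encoder output, weighted
    cross-entropy on the decoder logits, and a masked MSE that reconstructs
    the input only inside the ground-truth foreground (our interpretation)."""
    num_classes = lengths.size(1)
    y_down_1h = F.one_hot(y_down, num_classes).permute(0, 4, 1, 2, 3).float()
    l_margin = margin_loss(lengths, y_down_1h)
    l_ce = F.cross_entropy(logits, y, weight=class_weights)
    mask = (y > 0).unsqueeze(1).float()           # foreground voxels only
    l_rec = F.mse_loss(recon * mask, x * mask)
    return l_margin + l_ce + l_rec
```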

4 Experimental Results

4.1 Implementation Details

We conduct our experiments and comparisons on the iSeg [wang2019benchmark], Hippocampus, and Cardiac [simpson2019large] datasets. For iSeg, we follow 3D-SkipDenseSeg [bui2019skip] and use a training set of 9 subjects and a testing set of subject #9. On Hippocampus and Cardiac [simpson2019large], the experiments are conducted with 4-fold cross-validation.

Method  Depth  Dice Score (WM / GM / CSF / Average)
3D CNN-based networks:
Qamar et al. [qamar2020variant]  82  90.50  92.05  95.80  92.77
3D-SkipDenseSeg [bui2019skip]  47  91.02  91.64  94.88  92.51
VoxResNet [chen2018voxresnet]  25  89.87  90.64  94.28  91.60
3D-UNet [cciccek20163d]  18  89.83  90.55  94.39  91.59
CC-3D-FCN [nie20183]  34  89.19  90.74  92.40  90.79
DenseVoxNet [jegou2017one]  32  85.46  88.51  91.26  89.24
Capsule-based networks:
2D-SegCaps [lalonde2018capsules]  16  82.80  84.19  90.19  85.73
3D-SegCaps [nguyen20213ducaps]  16  86.49  88.53  93.62  89.55
3D-UCaps [nguyen20213ducaps]  17  90.21  91.12  94.93  92.08
Our SS-3DCapsNet  17  90.78  91.48  94.92  92.39
Table 1: Comparison on iSeg-2017. Top group: 3D CNN-based networks; bottom group: Capsule-based networks.

We implemented our method in PyTorch. We used the same patch size for iSeg and Hippocampus and a different patch size for LUNA16 and Cardiac. Our SS-3DCapsNet was trained without any data augmentation. We used the Adam optimizer with an initial learning rate of 0.0001. The learning rate is decayed by 0.05 if the Dice score on the validation set does not increase for 50,000 iterations. Early stopping is set at 250,000 iterations as in [lalonde2018capsules].
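The schedule above can be expressed roughly as follows; train_one_iteration and validate are assumed helpers, and interpreting "decayed by 0.05" as a multiplicative factor of 0.05 via ReduceLROnPlateau is an assumption rather than the paper's exact mechanism.

```python
import torch

def fit(model, train_one_iteration, validate, max_iters=250_000, val_every=1_000):
    """Training schedule sketch: Adam with lr 1e-4; LR reduced by a factor of
    0.05 when the validation Dice has not improved for ~50,000 iterations;
    hard stop at 250,000 iterations. train_one_iteration(model, optimizer) and
    validate(model) -> Dice are caller-supplied helpers."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.05,
        patience=50_000 // val_every)             # patience counted in validation checks
    for it in range(max_iters):
        train_one_iteration(model, optimizer)
        if (it + 1) % val_every == 0:
            dice = validate(model)                # Dice score on the validation set
            scheduler.step(dice)
```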

3D CNN-based networks (Dice)            Capsule-based networks (Dice)
3D UNet [cciccek20163d]  84.30          SegCaps (2D) [lalonde2018capsules]  66.96
3D Vnet [VNet]  84.20                   Multi-SegCaps (2D) [survarachakan2020capsule]  66.96
3D DR-UNet [vesal2018dilated]  87.40    3D-UCaps [nguyen20213ducaps]  89.69
                                        Our SS-3DCapsNet  89.77
Table 2: Comparison on Cardiac with 4-fold cross-validation (Dice score).
Method  Anterior (Recall / Precision / Dice)  Posterior (Recall / Precision / Dice)
Multi-SegCaps (2D) [lalonde2018capsules]  80.76  65.65  72.42  84.46  60.49  70.49
EM-SegCaps (2D) [survarachakan2020capsule]  17.51  20.01  18.67  19.00  34.55  24.52
3D-UCaps [nguyen20213ducaps]  81.70  80.19  80.99  80.2  79.25  79.48
Our SS-3DCapsNet  81.84  81.49  81.59  80.71  80.21  79.97
Table 3: Comparison on Hippocampus with 4-fold cross-validation.
Method  Dice Score (WM / GM / CSF / Average)
Number of capsules reduced to 4  89.02  89.78  89.95  89.58
w/o visual representation  89.15  89.66  90.82  89.88
w/o margin loss  87.62  88.85  92.06  89.51
w/o reconstruction loss  88.50  88.96  90.18  89.22
w/o pretext task  90.21  91.12  94.93  92.08
SS-3DCapsNet (full)  90.78  91.48  94.92  92.39
Table 4: Performance of SS-3DCapsNet on iSeg with different network configurations.
         iSeg                   Hippocampus            Cardiac
         Pre    Rec    DSC      Pre    Rec    DSC      Pre    Rec    DSC
w/o SSL  92.28  91.29  92.08    79.72  80.95  80.24    84.60  95.06  89.69
w/ SSL   92.54  92.37  92.39    80.85  81.27  80.78    86.24  94.21  89.77
Table 5: Precision (Pre), Recall (Rec), and Dice score (DSC) of SS-3DCapsNet with and without the pretext task on various datasets.

4.2 Performance and Comparison

We compare our SS-3DCapsNet with both SOTA 3D CNN-based and capsule-based segmentation methods. 3D-UCaps [nguyen20213ducaps] has two versions, with and without MONAI [MonAI]. For a fair comparison, we report the version without MONAI.

The comparison between our proposed SS-3DCapsNet and SOTA segmentation approaches on the iSeg dataset [wang2019benchmark] is given in Table 1. As can be seen, 3D capsule networks (3D-SegCaps, SS-3DCapsNet) outperform 2D-SegCaps by a wide margin. This performance gap can be explained by the combination of pre-training, the capsule encoder, and the convolutional decoder in SS-3DCapsNet. Our SS-3DCapsNet also outperforms 3D-SegCaps, which contains only a capsule-based encoder and decoder. Our SS-3DCapsNet further performs comparably to SOTA 3D CNNs while being significantly shallower (17 layers vs. 82 layers in [qamar2020variant]). It also has fewer parameters and a better Dice score than SOTA 3D CNNs with a similar number of layers, e.g., 3D-UNet [cciccek20163d] (18 layers). In addition to iSeg, we also evaluate our SS-3DCapsNet on Hippocampus and Cardiac, with results shown in Table 3 and Table 2, respectively.

4.3 Ablation Study

We analyze the performance of our method as follows.

i. Network Configuration: We trained SS-3DCapsNet under various settings, as shown in Table 4. Following the concurrent work on 3D capsule networks, we include a baseline in which the number of capsules in the first layer is reduced to 4 (similar to SegCaps). As can be seen, each component, including the visual representation, margin loss, reconstruction loss, and pre-training, contributes to the final performance; removing any of them results in a performance drop.

ii. SSL Contribution: We perform experiments on various datasets with the self-supervision step turned on and off. The results in Table 5 clearly show that pre-training plays an important role in our method, improving the Dice score considerably on iSeg and slightly on the other datasets.

5 Conclusion

In this work, we proposed a capsule network for MIS powered by self-supervised pre-training. Our SS-3DCapsNet utilizes both self-supervised learning and 3D capsules for feature learning while retaining the advantage of traditional convolutions for decoding the segmentation results. Even though we use capsules with dynamic routing only in the encoder of a simple UNet-like architecture, we achieve competitive results with SOTA models on the iSeg-2017 challenge while outperforming SegCaps [lalonde2018capsules] on different complex datasets with less labeled data. Future work includes exploring different self-supervised learning methods, such as SimCLR [chen2020simple], for better feature learning and representation.

Acknowledgment This material is based upon work supported by the National Science Foundation under Award No. OIA-1946391.

References