CapsNet for Medical Image Segmentation

03/16/2022
by   Minh Tran, et al.

Convolutional Neural Networks (CNNs) have been successful in solving computer vision tasks, including medical image segmentation, due to their ability to automatically extract features from unstructured data. However, CNNs are sensitive to rotations and affine transformations, and their success relies on huge-scale labeled datasets capturing various input variations. This paradigm poses challenges for medical segmentation at scale, because acquiring annotated medical data is expensive and constrained by strict privacy regulations. Furthermore, visual representation learning with CNNs has its own flaws: arguably, the pooling layers in traditional CNNs tend to discard positional information, and CNNs tend to fail on input images that differ in orientation and size. The capsule network (CapsNet) is a recent architecture that achieves more robust representation learning by replacing pooling layers with dynamic routing and convolutional strides, and it has shown promising results on popular tasks such as classification, recognition, segmentation, and natural language processing. Different from CNNs, which produce scalar outputs, CapsNet returns vector outputs, which aim to preserve part-whole relationships. In this work, we first introduce the limitations of CNNs and the fundamentals of CapsNet. We then review recent developments of CapsNet for the task of medical image segmentation. We finally discuss various effective network architectures that implement CapsNet for both 2D image and 3D volumetric medical image segmentation.




1 Convolutional Neural Networks: Limitations

Despite their strong performance on various computer vision tasks, CNNs ignore the geometrical relationships between objects. As a result, CNNs are sensitive to image rotations and affine transformations that are not present in the training data. Recent works have shown that even small translations or rescalings of the input image can drastically change the network's performance. To compensate for these limitations, the generalization of CNNs relies on large-scale training data that captures various input variations such as rotations and viewpoint changes. In this section, we first report some quantitative analyses of this phenomenon. We then analyze why CNN architectures do not produce invariance.

The lack of invariance of modern CNNs to small image deformations was reported in [engstrom2018rotation, azulay2018deep, zhang2019making, azulay2020deep]. Take [azulay2020deep] as an example. Azulay and Weiss selected three CNN architectures (VGG16, ResNet50, InceptionResNetV2) from the Keras package and three others (VGG16, ResNet50, DenseNet121) from the PyTorch package. They tested on 1000 images with four protocols, namely Crop, Translation-Embedding-Black, Translation-Embedding-Inpainting, and Scale-Embedding-Black, to systematically quantify the (lack of) invariance of CNNs. In the first protocol, a random square is chosen within the original image and resized to 224x224. In the second protocol, the image is downsampled so that its minimal dimension is 100 pixels while maintaining the aspect ratio, and embedded at a random location within a 224x224 image whose remaining pixels are black; the embedding location is then shifted by a single pixel, creating two images that are identical up to a one-pixel shift. In the third protocol, the embedding experiment is repeated, but rather than filling the rest of the image with black pixels, a simple inpainting algorithm is used (each black pixel is replaced by a weighted average of the non-black pixels in its neighborhood). The fourth protocol is similar to the second, but rather than shifting the embedding location, the location is kept fixed and the size of the embedded image is changed by a single pixel. Two metrics, P(Top-1 change) and mean absolute change (MAC), measure the network sensitivity. The first metric, P(Top-1 change), is invariant to any monotonic transformation of the output of the final layer of the network, while the second indicates whether changes in the top-1 prediction are due to very small differences between the most likely and the second most likely class. The quantitative analysis of the three Keras networks is shown in Fig. 1, which indicates that CNNs are not fully translation invariant.

Figure 1: Quantitative analysis of three Keras networks under four protocols on 1000 randomly chosen images from the ImageNet validation set. Image is from [azulay2020deep].
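For intuition, the one-pixel-shift experiment can be reproduced in a few lines. The sketch below follows the Translation-Embedding-Black protocol only loosely and is not the authors' code: torchvision's pretrained ResNet-50 stands in for the Keras/PyTorch models used in the study, the embedding offset is arbitrary, and the image is assumed to fit on the canvas at the chosen offset.

```python
import torch
import torchvision.transforms.functional as TF
from torchvision import models

# Downscale an image so its minimal dimension is 100 px, paste it on a black
# 224x224 canvas, shift it by one pixel, and check whether the top-1 class changes.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval()
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def embed(img_small, top, left, canvas=224):
    """Paste a (3, h, w) image tensor in [0, 1] onto a black canvas at (top, left)."""
    _, h, w = img_small.shape
    out = torch.zeros(3, canvas, canvas)
    out[:, top:top + h, left:left + w] = img_small
    return (out - MEAN) / STD                      # standard ImageNet normalization

@torch.no_grad()
def top1_changes(img, top=50, left=50):
    """True if a one-pixel horizontal shift of the embedding flips the top-1 class."""
    small = TF.resize(img, 100)                    # minimal dimension becomes 100 px
    a, b = embed(small, top, left), embed(small, top, left + 1)
    preds = model(torch.stack([a, b])).argmax(dim=1)
    return preds[0].item() != preds[1].item()

# Estimating P(Top-1 change) amounts to averaging top1_changes over many images and locations.
```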

Furthermore, in a CNN, a pooling layer acts as a messenger between two layers, transferring activation information from one layer to the next. By doing so, a pooling layer can indicate the presence of a part, but it is unable to capture the spatial relations among parts. Clearly, a pooling operation (e.g., max-pooling) does not make the model invariant to viewpoint changes. Moreover, each filter of a convolutional layer works like a feature detector over a small region of the input features. Going deeper into the network, the detected low-level features are aggregated into high-level features that can be used to distinguish between different objects. The higher-level features in a CNN are built as a weighted sum of lower-level features, so the geometric relationships among features are not taken into account.
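A toy example (not from the paper) makes this information loss concrete: two feature maps whose active units sit at different positions inside each pooling window collapse to exactly the same pooled output.

```python
import torch
import torch.nn.functional as F

# Two 4x4 feature maps whose single active units sit at different positions
# inside each 2x2 pooling window.
a = torch.tensor([[1., 0., 0., 0.],
                  [0., 0., 0., 0.],
                  [0., 0., 0., 0.],
                  [0., 0., 0., 1.]]).view(1, 1, 4, 4)
b = torch.tensor([[0., 1., 0., 0.],
                  [0., 0., 0., 0.],
                  [0., 0., 1., 0.],
                  [0., 0., 0., 0.]]).view(1, 1, 4, 4)

# Both calls print tensor([[[[1., 0.], [0., 1.]]]]): max-pooling reports that a part
# is present in each window but discards where exactly it was, so the two different
# spatial arrangements become indistinguishable.
print(F.max_pool2d(a, kernel_size=2))
print(F.max_pool2d(b, kernel_size=2))
```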

Over the years, different techniques have been developed to tackle the aforementioned limitations of CNNs. The most common solution is data augmentation, which increases the effective data size by including transformed copies of the training samples.

2 Capsule Network: Fundamental

Inspired by the idea that a visual scene is constructed from geometrical objects together with matrices representing their relative positions and orientations, Hinton [hinton2011transforming] argued that preserving the hierarchical pose relationships among object parts is essential to correctly classify and recognize an object. This view is also known as inverse graphics, which resembles how the brain is thought to recognize objects. To address the limitations of CNNs, Hinton [hinton2011transforming] proposed to replace the neuron's scalar output, which only represents the activation of replicated feature detectors, with a vector output, a.k.a. a capsule. Each capsule learns a visual entity (e.g., a part of an object). The output of each capsule is a vector that encodes both the probability that the entity is present and a set of "instantiation parameters" that capture different properties of the entity.

CapsNet [sabour2017dynamic] was proposed to address the above intrinsic limitations of CNNs. CapsNet strengthens feature learning by retaining more information at the aggregation layer for pose reasoning and by learning the part-whole relationship. Different from a CNN, which consists of a backbone network to extract features followed by several fully connected layers and an N-way softmax that produces the classification logits, a CapsNet contains five more complex components:

  • Non-shared transformation module: the primary capsules are transformed into votes by non-shared transformation matrices.

  • Dynamic routing layer: groups input capsules so as to produce output capsules with high agreement within each output capsule.

  • Squashing function: squashes the lengths of the capsule vectors to the range [0, 1).

  • Margin classification loss: works together with the squashed capsule representation.

  • Reconstruction loss: recovers the original image from the capsule representations.

The network architecture comparison between CNNs and CapsNet is shown in Fig. 2, whereas the comparison of operators, inputs, and outputs is given in Fig. 3. In a CNN, each filter acts as a feature detector over a small region of the input features; as we go deeper, the detected low-level features are aggregated into high-level features that can be used to distinguish between different objects. However, by doing so, each feature map only contains information about the presence of a feature, and the network relies on a fixed learned weight matrix to link features between layers. As a result, the model cannot generalize well to unseen changes in the input [alcorn2019strike]. In CapsNet, each layer aims to learn a part of an object together with its properties and represents them as a vector. The entities of the previous layer represent simple objects, whereas the next layer represents complex objects through the voting process.

Figure 2: Network architecture comparison between CNNs and CapsNet.
Figure 3: Operators comparison between CNNs and CapsNet.

Let $\mathbf{F}$ denote the visual feature map extracted by a few convolutional layers. It is then reshaped into primary capsules $\mathbf{u}_i \in \mathbb{R}^{d}$, $i = 1, \dots, N$, where $d$ is the dimension of the primary capsules. In this design there are $N$ capsules, each of which lies in $\mathbb{R}^{d}$. Let $\mathbf{W}_{ij}$ be a transformation matrix; each primary capsule is transformed to cast a vote $\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij}\mathbf{u}_i \in \mathbb{R}^{D}$, where $j = 1, \dots, C$ indexes the output classes and $D$ is the dimension of the output capsules. In CapsNet, the dynamic routing process takes all votes into consideration and iteratively computes a weight for each vote as follows:

$c_{ij} = \dfrac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}, \quad \mathbf{s}_j = \sum_{i} c_{ij}\,\hat{\mathbf{u}}_{j|i}, \quad \mathbf{v}_j = \mathrm{squash}(\mathbf{s}_j)$,   (1)

where $b_{ij}$ is a log prior probability and $c_{ij}$ is the coupling coefficient that models the degree to which capsule $i$ is able to predict capsule $j$. It is initialized as $b_{ij} = 0$. $\mathrm{squash}(\cdot)$ is a squashing function that maps the length of a vector to $[0, 1)$, i.e.,

$\mathrm{squash}(\mathbf{s}_j) = \dfrac{\|\mathbf{s}_j\|^2}{1 + \|\mathbf{s}_j\|^2}\,\dfrac{\mathbf{s}_j}{\|\mathbf{s}_j\|}$.   (2)

The classification loss is defined as a margin loss:

$L_k = T_k\,\max(0,\, m^{+} - \|\mathbf{v}_k\|)^2 + \lambda\,(1 - T_k)\,\max(0,\, \|\mathbf{v}_k\| - m^{-})^2$.   (3)

As suggested by [sabour2017dynamic], $m^{+}$, $m^{-}$, and $\lambda$ are set to 0.9, 0.1, and 0.5, respectively, and $T_k = 1$ if the object of class $k$ is present in the input.
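For concreteness, Eq. 2 and Eq. 3 can be written in a few lines of PyTorch. This is a minimal sketch rather than the reference implementation of [sabour2017dynamic]; the tensor layouts are assumptions noted in the comments.

```python
import torch

def squash(s, dim=-1, eps=1e-8):
    """Eq. 2: shrink vector lengths into [0, 1) while preserving direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)

def margin_loss(v, targets, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Eq. 3 with the suggested constants.

    v: output capsules of shape (batch, num_classes, capsule_dim).
    targets: one-hot labels of shape (batch, num_classes).
    """
    lengths = v.norm(dim=-1)                                        # ||v_k||
    loss = targets * torch.clamp(m_pos - lengths, min=0.0) ** 2 \
         + lam * (1.0 - targets) * torch.clamp(lengths - m_neg, min=0.0) ** 2
    return loss.sum(dim=1).mean()
```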

The reconstruction loss is used as a regularization term in the loss function. The pseudo-code of the dynamic routing algorithm is presented in Algorithm 1.

 

1: Input: the prediction vectors $\hat{\mathbf{u}}_{j|i}$ from the capsules at layer $(l)$ and the number of routing iterations $r$.
2: Output: the capsules $\mathbf{v}_j$ at layer $(l+1)$.
3: Initialize: for all capsule $i$ in layer $(l)$ and capsule $j$ in layer $(l+1)$, $b_{ij} \leftarrow 0$.
4: for each iteration do
5:     for all capsule $i$ at layer $(l)$ do $\mathbf{c}_i \leftarrow \mathrm{softmax}(\mathbf{b}_i)$
6:     end for
7:     for all capsule $j$ at layer $(l+1)$ do
8:         $\mathbf{s}_j \leftarrow \sum_i c_{ij}\,\hat{\mathbf{u}}_{j|i}$
9:         $\mathbf{v}_j \leftarrow \mathrm{squash}(\mathbf{s}_j)$
10:     end for
11:     for all capsule $i$ in layer $(l)$ and capsule $j$ in layer $(l+1)$ do
12:         $b_{ij} \leftarrow b_{ij} + \hat{\mathbf{u}}_{j|i} \cdot \mathbf{v}_j$
13:     end for
14: end for
Algorithm 1 The pseudo-code of the dynamic routing algorithm.
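Algorithm 1 maps directly onto a few tensor operations. The following is a minimal dense PyTorch sketch (routing every child to every parent, as in the original CapsNet rather than the locally-constrained variants discussed later), reusing the squash helper from the previous sketch; it is not the authors' implementation.

```python
import torch

def dynamic_routing(votes, num_iterations=3):
    """Dense routing-by-agreement over precomputed votes.

    votes: prediction vectors u_hat_{j|i} of shape
           (batch, num_in_caps, num_out_caps, out_dim).
    Returns the output capsules v_j of shape (batch, num_out_caps, out_dim).
    """
    b = torch.zeros(votes.shape[:3], device=votes.device)      # routing logits b_ij
    for _ in range(num_iterations):
        c = b.softmax(dim=2).unsqueeze(-1)                      # coupling coefficients c_ij
        s = (c * votes).sum(dim=1)                              # s_j = sum_i c_ij * u_hat_{j|i}
        v = squash(s)                                           # v_j = squash(s_j), Eq. 2
        b = b + (votes * v.unsqueeze(1)).sum(dim=-1)            # b_ij += u_hat_{j|i} . v_j
    return v
```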

Now, let us consider the backpropagation through the routing iterations. Assuming there are $r$ iterations and $C$ capsules at the output, the gradients through the routing procedure are:

(4)

The second term in Eq. 4 is the main computational burden of the expensive routing.

We further investigate the robustness of CapsNet in comparison with CNNs as follows:

  • Translation invariance: While CNNs are able to identify whether an object exists in a certain region, they are unable to identify the position of one object relative to another. Thus, CNNs cannot model spatial relationships between objects/features. As shown in Fig. 4, a CNN can tell that an image shows a dog or a face, but it cannot tell the spatial relationship between the dog and the picture, or the positional relationships among the facial components.

    Figure 4: Translation invariance comparison between CNNs and CapsNet.
  • Requires less data to generalize: CNNs require gathering massive amounts of data that represent each object in various positions and poses. The CNN is then trained on this huge dataset in the hope that the network sees enough examples of the object to generalize. Thus, to generalize better over variations of the same object, CNNs are trained on multiple copies of every sample, each slightly different; data augmentation is one of the most common techniques to make a CNN model more robust. CapsNet, in contrast, encodes invariant part-whole spatial relationships into its learned weights. Thus, CapsNet is able to encode various positions and poses of parts, together with the invariant part-whole relationships, and to generalize to unseen variations of the objects.

  • Interpretability: Although a large number of interpretation methods have been proposed to understand individual classifications of CNNs, model interpretability is still a significant challenge for CNNs. By taking part-whole relations into consideration, the higher-level capsules in CapsNet are interpretable and explainable. Thus, CapsNet is inherently more interpretable than traditional neural networks, as capsules tend to encode specific semantic concepts. In particular, the disentangled representations captured by CapsNet often correspond to human-understandable visual properties of the input objects, e.g., rotations and translations. Taking the example of reconstructing an MNIST digit from [sabour2017dynamic], different dimensions of the activity vector of a capsule control different features, including scale and thickness, localized parts, stroke thickness, and width and translation, as shown in Fig. 5.

    Figure 5: Experiments on MNIST. Different dimensions of the capsules are responsible for encoding different characteristics of the digits. Image is from [sabour2017dynamic].

3 Capsule Network: Related Work

There have been different mechanisms proposed to improve the performance of CapsNet. In general, we categorize them into two groups. The first group proposes various effective routing mechanisms to improve the dynamic routing of [sabour2017dynamic]. Dynamic routing identifies the weights of the predictions made by the lower-level capsules, called coupling coefficients, through an iterative routing-by-agreement mechanism. EM Routing [hinton2018matrix] updates the coupling coefficients iteratively using Expectation-Maximization. By utilizing attention modules augmented by differentiable binary routers, [chen2018generalized] proposes a straight-through attentive routing to reduce the high computational complexity of the dynamic routing iterations in [sabour2017dynamic]. To increase computational efficiency, Mobiny et al. [mobiny2018fast] propose a consistent dynamic routing mechanism that speeds up CapsNet. Recently, [ribeiro2020capsule] proposed a new capsule routing algorithm derived from Variational Bayes for fitting a mixture of transforming Gaussians, showing that it is possible to transform a capsule network into a Capsule-VAE. To reduce the number of parameters of CapsNet, [hinton2018matrix, rajasegaran2019deepcaps] propose to use a matrix or a tensor to represent an entity instead of a vector. The second category focuses on network architectures, such as combining convolutional layers and capsule layers [phaye2018multi], unsupervised capsule autoencoders [kosiorek2019stacked], Aff-CapsNets [gu2020improving], and memory-augmented CapsNet [mobiny2021memory]. While [gu2020improving] removes dynamic routing by sharing the transformation matrix, [gu2020interpretable] replaces dynamic routing with a multi-head attention-based graph pooling approach to achieve better interpretability. Recently, Mobiny et al. [mobiny2020decaps] proposed DECAPS, which utilizes an Inverted Dynamic Routing (IDR) mechanism to group lower-level capsules before sending them to higher-level capsules, and employs a Peekaboo training procedure to encourage the network to focus on fine-grained information through a second-level attention scheme. DECAPS has outperformed experienced, well-trained thoracic radiologists [mobiny2020radiologist].

4 CapsNets in Medical Image Segmentation

This section introduces various recent CapsNets that have been proposed in medical image segmentation.

4.1 2D-SegCaps

Figure 6: Network architecture of 2D-SegCaps [lalonde2018capsules, lalonde2021capsules] for biomedical image segmentation. The network is a UNet-based architecture with capsule blocks in both the encoder and decoder paths.

CapsNet has been mainly applied to image classification and recognition, where its performance is still limited compared to the state-of-the-art CNN-based approaches. 2D-SegCaps [lalonde2018capsules, lalonde2021capsules] was the first CapsNet proposed for semantic image segmentation and is illustrated in Fig. 6. As stated in [lalonde2018capsules, lalonde2021capsules], performing semantic image segmentation with a CapsNet is extremely challenging because of the high computational complexity of the routing process, which takes place between every parent and every possible child. 2D-SegCaps is a UNet-based architecture with capsule blocks in both the encoder and decoder paths. It contains four components: (i) visual feature extraction, which produces 16 feature maps of the same spatial dimensions as the input; (ii) convolutional capsules in the encoder path; (iii) deconvolutional capsules in the decoder path; and (iv) reconstruction regularization in the decoder path. Details of the four components are as follows:

  • Feature Extraction: The 2D-SegCaps network takes a large 2D image (e.g., a slice of an MRI scan) as its input. The image is passed through a 2D convolutional layer which produces 16 feature maps of the same spatial dimensions as the input. This becomes the input of the following convolutional capsule layers.

  • Convolutional Capsule Encoder: The process of convolutional capsules and routing at any given layer of the network is as follows (a simplified PyTorch sketch of this locally-constrained convolutional capsule layer is given after this component list):

    • At layer $l$: there exists a set of capsule types

      $T^{l} = \{t^{l}_1, t^{l}_2, \dots, t^{l}_n \mid n \in \mathbb{N}\}$.   (5)

      For every $t^{l}_i \in T^{l}$, there exists an $h^{l} \times w^{l}$ grid of $z^{l}$-dimensional child capsules

      $C = \{\mathbf{c}_{1,1}, \dots, \mathbf{c}_{1,w^{l}}, \dots, \mathbf{c}_{h^{l},1}, \dots, \mathbf{c}_{h^{l},w^{l}}\}$,   (6)

      where $h^{l} \times w^{l}$ is the spatial dimension of the output of layer $l-1$.

    • At layer $l+1$: there exists a set of capsule types

      $T^{l+1} = \{t^{l+1}_1, t^{l+1}_2, \dots, t^{l+1}_m \mid m \in \mathbb{N}\}$.   (7)

      For every $t^{l+1}_j \in T^{l+1}$, there exists an $h^{l+1} \times w^{l+1}$ grid of $z^{l+1}$-dimensional parent capsules

      $P = \{\mathbf{p}_{1,1}, \dots, \mathbf{p}_{1,w^{l+1}}, \dots, \mathbf{p}_{h^{l+1},1}, \dots, \mathbf{p}_{h^{l+1},w^{l+1}}\}$,   (8)

      where $h^{l+1} \times w^{l+1}$ is the spatial dimension of the output of layer $l$.

    • For every parent capsule type $t^{l+1}_j \in T^{l+1}$, every parent capsule $\mathbf{p}_{x,y} \in P$ receives a set of "prediction vectors", one for each child capsule type in $T^{l}$, i.e., $\{\hat{\mathbf{u}}_{x,y \mid t^{l}_1}, \hat{\mathbf{u}}_{x,y \mid t^{l}_2}, \dots, \hat{\mathbf{u}}_{x,y \mid t^{l}_n}\}$. This set of prediction vectors is defined as the matrix multiplication between a learned transformation matrix for the given parent capsule type, $M_{t^{l}_i}$, and the sub-grid of child capsule outputs, $U_{x,y \mid t^{l}_i}$. The sub-grid $U_{x,y \mid t^{l}_i}$ has shape $k_h \times k_w \times z^{l}$, where $k_h \times k_w$ are the dimensions of the user-defined kernel, for all capsule types $t^{l}_i \in T^{l}$. Each parent capsule has dimension $z^{l+1}$. The prediction vector is computed as

      $\hat{\mathbf{u}}_{x,y \mid t^{l}_i} = M_{t^{l}_i} \cdot U_{x,y \mid t^{l}_i}$.   (9)

    To reduce the total number of learned parameters, 2D-SegCaps shares the transformation matrices across the members of the grid, i.e., $M_{t^{l}_i}$ does not depend on the spatial location $(x, y)$: the transformation matrix is shared across all spatial locations within a given capsule type. This mechanism is similar to how convolutional kernels scan an input feature map, and it is the main difference between 2D-SegCaps and CapsNet. The parent capsule $\mathbf{p}_{x,y}$ for parent capsule type $t^{l+1}_j$ is then computed as

    $\mathbf{p}_{x,y} = \sum_{i} c_{t^{l}_i \mid x,y}\,\hat{\mathbf{u}}_{x,y \mid t^{l}_i}$,   (10)

    where $c_{t^{l}_i \mid x,y}$ is the coupling coefficient determined by dynamic routing, as defined in Eq. 1. The output capsule is then computed using the non-linear squashing function defined in Eq. 2:

    $\mathbf{v}_{x,y} = \mathrm{squash}(\mathbf{p}_{x,y})$.   (11)

    Lastly, the agreement is measured as the scalar product

    $a_{x,y \mid t^{l}_i} = \mathbf{v}_{x,y} \cdot \hat{\mathbf{u}}_{x,y \mid t^{l}_i}$.   (12)

    Unlike dynamic routing in CapsNet [sabour2017dynamic], 2D-SegCaps locally constrains the creation of the prediction vectors. Furthermore, 2D-SegCaps only routes the child capsules within the user-defined kernel to the parent, rather than routing every single child capsule to every single parent.

  • Deconvolutional Capsule Decoder: Deconvolutional capsules are similar to the convolutional capsules; however, the prediction vectors are now formed using the transpose of the operation previously described. In the deconvolutional capsules, the set of prediction vectors is defined as the matrix multiplication between a learned transformation matrix $M_{t^{l}_i}$ for a parent capsule type $t^{l+1}_j$ and the sub-grid of child capsule outputs $U_{x,y \mid t^{l}_i}$ for each capsule type in $T^{l}$. For each member of the grid, the prediction vectors are again formed by

    $\hat{\mathbf{u}}_{x,y \mid t^{l}_i} = M_{t^{l}_i} \cdot U_{x,y \mid t^{l}_i}$.   (13)

    Similar to the convolutional capsule encoder, $\hat{\mathbf{u}}_{x,y \mid t^{l}_i}$ is the input to the dynamic routing algorithm, which forms the parent capsules $\mathbf{p}_{x,y}$ and output capsules $\mathbf{v}_{x,y}$.

  • Reconstruction Regularization: This component models the distribution of the positive input class and treats all other pixels as background; the segmentation capsules that do not belong to the positive class are masked out. The reconstruction is performed via three convolutional layers, and the loss is a mean-squared error (MSE) computed between the reconstruction and only the positive input pixels. The supervised loss for the reconstruction regularization is

    $\mathcal{L}_{r} = \gamma\,\dfrac{1}{HW}\sum_{x,y} \| R_{x,y} - O^{r}_{x,y} \|^{2}$, with $R = X \odot Y$,   (14)

    where $X$ is the input image, $R$ is the reconstruction target, $Y$ is the ground-truth segmentation mask, and $O^{r}$ is the output of the reconstruction network. $\gamma$ is the weighting coefficient for the reconstruction loss.
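Because 2D-SegCaps shares transformation matrices across spatial locations, the prediction step can be realized with ordinary convolutions. Below is a simplified PyTorch sketch (not the authors' implementation): each child capsule type gets one convolution that emits votes for every parent type, and routing then runs over child types only; aggregating the kernel window inside the convolution is a simplification of the locally-constrained routing described above.

```python
import torch
import torch.nn as nn

def squash(s, dim, eps=1e-8):
    # Same squashing non-linearity as in Eq. 2 and the earlier sketch.
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)

class ConvCapsuleLayer2D(nn.Module):
    """Simplified convolutional capsule layer in the spirit of 2D-SegCaps."""
    def __init__(self, in_types, in_dim, out_types, out_dim, kernel_size=5, routing_iters=3):
        super().__init__()
        self.in_types, self.out_types, self.out_dim = in_types, out_types, out_dim
        self.routing_iters = routing_iters
        # One shared transformation per child capsule type (weights shared over the grid),
        # emitting votes for all parent capsule types at once.
        self.transform = nn.ModuleList([
            nn.Conv2d(in_dim, out_types * out_dim, kernel_size, padding=kernel_size // 2)
            for _ in range(in_types)
        ])

    def forward(self, x):                                   # x: (B, in_types, in_dim, H, W)
        votes = torch.stack([self.transform[i](x[:, i]) for i in range(self.in_types)], dim=1)
        B, _, _, H, W = votes.shape
        votes = votes.view(B, self.in_types, self.out_types, self.out_dim, H, W)
        b = torch.zeros(B, self.in_types, self.out_types, 1, H, W, device=x.device)
        for _ in range(self.routing_iters):                 # routing-by-agreement (Eqs. 10-12)
            c = b.softmax(dim=2)                            # coupling coefficients
            p = (c * votes).sum(dim=1, keepdim=True)        # parent capsules (Eq. 10)
            v = squash(p, dim=3)                            # squashed outputs (Eq. 11)
            b = b + (votes * v).sum(dim=3, keepdim=True)    # agreement update (Eq. 12)
        return v.squeeze(1)                                 # (B, out_types, out_dim, H, W)
```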

2D-SegCaps [lalonde2018capsules, lalonde2021capsules] is trained with a supervised objective that combines three loss terms, as follows:

  • Margin Loss: The margin loss is adopted from [sabour2017dynamic] and is defined between the predicted label $y$ and the ground-truth label $y^{*}$ as
    $\mathcal{L}_{margin} = y^{*} \times \max(0,\, 0.9 - y)^{2} + 0.5 \times (1 - y^{*}) \times \max(0,\, y - 0.1)^{2}$.
    In particular, the margin loss ($\mathcal{L}_{margin}$) is computed on the capsule encoder output against the downsampled ground-truth segmentation.

  • Weighted Cross-Entropy Loss: The weighted cross-entropy loss ($\mathcal{L}_{CE}$) is computed on the output of the convolutional decoder.

  • Reconstruction Regularization: The training is also regularized with a network branch that reconstructs the original input, using the masked mean-squared error ($\mathcal{L}_{r}$) of Eq. 14.

The total loss is the weighted sum of the three losses:

$\mathcal{L}_{total} = \lambda_{1}\,\mathcal{L}_{margin} + \lambda_{2}\,\mathcal{L}_{CE} + \lambda_{3}\,\mathcal{L}_{r}$.   (15)
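A compact way to read Eq. 15 in code is sketched below for a single-foreground-class setting. The relative weights and tensor layouts are assumptions (the text only states that the sum is weighted), and the shapes are noted in the comments.

```python
import torch
import torch.nn.functional as F

def total_loss(caps_lengths, target_down_onehot, seg_logits, target, recon, image,
               class_weights=None, w_margin=1.0, w_ce=1.0, w_recon=1.0):
    """Weighted sum of the three 2D-SegCaps loss terms (Eq. 15); weights are placeholders.

    caps_lengths / target_down_onehot: (B, K, h, w) capsule lengths and one-hot masks
        at the (downsampled) capsule-encoder resolution.
    seg_logits: (B, K, H, W) decoder logits; target: (B, H, W) integer labels.
    recon / image: (B, C, H, W) reconstruction and input.
    """
    # Margin loss on per-pixel capsule lengths vs. the downsampled ground truth.
    l_margin = (target_down_onehot * torch.clamp(0.9 - caps_lengths, min=0) ** 2
                + 0.5 * (1 - target_down_onehot) * torch.clamp(caps_lengths - 0.1, min=0) ** 2).mean()
    # Weighted cross-entropy on the convolutional decoder output.
    l_ce = F.cross_entropy(seg_logits, target, weight=class_weights)
    # Masked MSE (Eq. 14): reconstruct only the pixels belonging to the positive class.
    mask = (target > 0).float().unsqueeze(1)          # (B, 1, H, W)
    l_recon = F.mse_loss(recon * mask, image * mask)
    return w_margin * l_margin + w_ce * l_ce + w_recon * l_recon
```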

2D-SegCaps has obtained promising performance on the LUNA16 dataset [setio2017validation]; however, Survarachakan et al. [survarachakan2020capsule] have shown that its performance decreases significantly on the MSD dataset [simpson2019large] compared with UNet-based architectures. Furthermore, Survarachakan et al. [survarachakan2020capsule] extend 2D-SegCaps to Multi-SegCaps to support multi-class segmentation. Unlike 2D-SegCaps, the output capsule layer in Multi-SegCaps is modified to produce N output capsules, where N is the number of classes in the dataset including background, and the predicted class is the one represented by the capsule with the longest Euclidean length. Thus, Multi-SegCaps attempts to reconstruct the pixels belonging to all classes except the background class, instead of a single target class as in 2D-SegCaps.

Like other CapsNets, 2D-SegCaps and Multi-SegCaps suffer from the high computational complexity caused by dynamic routing. To address this concern, Survarachakan et al. [survarachakan2020capsule] make use of EM-routing [hinton2018matrix] and propose EM-routing SegCaps. The EM-routing SegCaps architecture uses matrix capsules with EM-routing and is shown in Fig. 7. The main difference between EM-routing SegCaps and 2D-SegCaps is that the convolutional capsule layers accept the poses and activations from capsules in the previous layer and output new poses and activations for the capsules in the next layer via the Expectation-Maximization routing algorithm (EM-routing). In EM-routing SegCaps, every child capsule casts an initial vote for the output of every capsule in the next layer using its own pose matrix before EM-routing is performed. To cast this vote, a transformation matrix going into the parent capsule is trained and shared by all child capsules. Predictions are first computed and then forwarded to the EM-routing algorithm, along with the activations from the previous layer. The EM-routing algorithm is run for three iterations before it returns the final poses and activations for all capsules in the current layer. Replacing the dynamic routing of 2D-SegCaps with EM-routing, however, does not improve performance much compared with 2D-SegCaps.

Figure 7: EM-routing SegCaps architecture

4.2 3D-SegCaps

2D-SegCaps has shown promising performance on 2D medical image segmentation; however, it ignores the temporal relationships when applied to 3D volumes, e.g., MRI and CT scans. Nguyen et al. [nguyen20213d] extend 2D-SegCaps to 3D-SegCaps by incorporating this temporal information into the capsules. Similar to 2D-SegCaps, 3D-SegCaps contains four components: visual feature extraction, convolutional capsule layers, deconvolutional capsule layers, and reconstruction regularization. Details of the four components are as follows:

  • Feature Extraction: The 3D-SegCaps network takes volumetric data as its input. The volume is passed through a 3D convolutional layer which produces 16 feature maps of the same spatial dimensions as the input. This becomes the input of the following convolutional capsule layers.

  • Convolutional Capsule Encoder: The process of convolutional capsules and routing at any given layer of the network mirrors the 2D case, extended to volumetric grids:

    • At layer $l$: there exists a set of capsule types

      $T^{l} = \{t^{l}_1, t^{l}_2, \dots, t^{l}_n \mid n \in \mathbb{N}\}$.   (16)

      For every $t^{l}_i \in T^{l}$, there exists an $h^{l} \times w^{l} \times d^{l}$ grid of $z^{l}$-dimensional child capsules

      $C = \{\mathbf{c}_{1,1,1}, \dots, \mathbf{c}_{h^{l}, w^{l}, d^{l}}\}$,   (17)

      where $h^{l} \times w^{l} \times d^{l}$ is the spatial dimension of the output of layer $l-1$.

    • At layer $l+1$: there exists a set of capsule types

      $T^{l+1} = \{t^{l+1}_1, t^{l+1}_2, \dots, t^{l+1}_m \mid m \in \mathbb{N}\}$.   (18)

      For every $t^{l+1}_j \in T^{l+1}$, there exists an $h^{l+1} \times w^{l+1} \times d^{l+1}$ grid of $z^{l+1}$-dimensional parent capsules

      $P = \{\mathbf{p}_{1,1,1}, \dots, \mathbf{p}_{h^{l+1}, w^{l+1}, d^{l+1}}\}$,   (19)

      where $h^{l+1} \times w^{l+1} \times d^{l+1}$ is the spatial dimension of the output of layer $l$.

    • For every parent capsule type $t^{l+1}_j \in T^{l+1}$, every parent capsule $\mathbf{p}_{x,y,s} \in P$ receives a set of "prediction vectors", one for each child capsule type in $T^{l}$, i.e., $\{\hat{\mathbf{u}}_{x,y,s \mid t^{l}_1}, \dots, \hat{\mathbf{u}}_{x,y,s \mid t^{l}_n}\}$. This set of prediction vectors is defined as the matrix multiplication between a learned transformation matrix for the given parent capsule type, $M_{t^{l}_i}$, and the sub-grid of child capsule outputs, $U_{x,y,s \mid t^{l}_i}$. The sub-grid $U_{x,y,s \mid t^{l}_i}$ has shape $k_h \times k_w \times k_d \times z^{l}$, where $k_h \times k_w \times k_d$ are the dimensions of the user-defined kernel, for all capsule types $t^{l}_i \in T^{l}$. Each parent capsule has dimension $z^{l+1}$. The prediction vector is computed as

      $\hat{\mathbf{u}}_{x,y,s \mid t^{l}_i} = M_{t^{l}_i} \cdot U_{x,y,s \mid t^{l}_i}$.   (20)

    To reduce the total number of learned parameters, 3D-SegCaps shares the transformation matrices across the members of the grid, i.e., $M_{t^{l}_i}$ does not depend on the spatial location. This transformation matrix is shared across all spatial locations within a given capsule type. This mechanism is similar to how convolutional kernels scan an input feature map, and it is the main difference between 3D-SegCaps and CapsNet. The parent capsule $\mathbf{p}_{x,y,s}$ for parent capsule type $t^{l+1}_j$ is then computed as

    $\mathbf{p}_{x,y,s} = \sum_{i} c_{t^{l}_i \mid x,y,s}\,\hat{\mathbf{u}}_{x,y,s \mid t^{l}_i}$,   (21)

    where $c_{t^{l}_i \mid x,y,s}$ is the coupling coefficient determined by dynamic routing, as defined in Eq. 1. Similar to 2D-SegCaps, the output capsule is then computed using the non-linear squashing function defined in Eq. 2:

    $\mathbf{v}_{x,y,s} = \mathrm{squash}(\mathbf{p}_{x,y,s})$.   (22)

    Finally, the agreement is measured as the scalar product

    $a_{x,y,s \mid t^{l}_i} = \mathbf{v}_{x,y,s} \cdot \hat{\mathbf{u}}_{x,y,s \mid t^{l}_i}$.   (23)
  • Deconvolutional Capsule Decoder: The deconvolutional capsules are similar to those in 2D-SegCaps, in which the set of prediction vectors is again defined as the matrix multiplication between a learned transformation matrix $M_{t^{l}_i}$ for a parent capsule type $t^{l+1}_j$ and the sub-grid of child capsule outputs $U_{x,y,s \mid t^{l}_i}$ for each capsule type in $T^{l}$. For each member of the grid, the prediction vectors are formed by

    $\hat{\mathbf{u}}_{x,y,s \mid t^{l}_i} = M_{t^{l}_i} \cdot U_{x,y,s \mid t^{l}_i}$,   (24)

    which is the input to the dynamic routing algorithm that forms the parent capsules $\mathbf{p}_{x,y,s}$ and output capsules $\mathbf{v}_{x,y,s}$.

  • Reconstruction Regularization: This component is implemented in the same manner as in 2D-SegCaps.

4.3 3D-UCaps

Figure 8: 3D-UCaps architecture with four components: visual feature extraction, convolutional capsule encoder, deconvolutional decoder, and reconstruction regularization. The numbers on the blocks indicate the number of channels in the convolution layers and the dimension of capsules in the capsule layers.

By taking temporal information into consideration, 3D-SegCaps improves segmentation performance over 2D-SegCaps. However, its accuracy is still lower than the state-of-the-art performance of 3D-UNets. The observation in [nguyen20213d] is that the capsule design is capable of extracting richer representations than a traditional neural network design; thus, convolutional capsule layers are utilized to encode the visual representation in 3D-UCaps. Furthermore, under an encoder-decoder architecture, the decoder path aims to produce a highly detailed segmentation, a task that deconvolutional layers already perform accurately. Thus, deconvolutional layers are used in the decoder path of 3D-UCaps, which is the main difference between 3D-UCaps and 3D-SegCaps. This replacement not only improves segmentation performance but also reduces the computational cost caused by dynamic routing. The entire 3D-UCaps network, which combines CapsNet and CNN components, is shown in Fig. 8. It follows a UNet-like architecture [cciccek20163d] and contains four main components as follows:

  • Visual Feature Extractor: A set of dilated convolutional layers is used to convert the input to high-dimensional features that can be further processed by capsules. It contains three convolution layers with the number of channels increasing from 16 to 32 and then 64, with a fixed kernel size and dilation rates set to 1, 3, and 3, respectively. The output of this component is a feature map of size $H \times W \times D \times 64$ (a minimal sketch of this extractor is given after this component list).

  • Convolutional Capsule Encoder: This component follows a similar mechanism to the one designed in 3D-SegCaps. The implementation details are as follows. The visual features from the previous component are cast (reshaped) into a grid of capsules, each represented as a single 64-dimensional vector. The convolutional capsule encoder is designed with more capsule types in the low-level layers and fewer capsule types in the high-level layers, since low-level layers represent simple objects while high-level layers represent complex objects, and because of the clustering nature of the routing algorithm [hinton2018matrix]; the number of capsule types therefore decreases along the encoder path. This is in contrast to the design in 2D-SegCaps and 3D-SegCaps, where the number of capsule types increases along the encoder path. The number of capsule types in the last convolutional capsule layer is equal to the number of categories in the segmentation, which can be further supervised by a margin loss [sabour2017dynamic]. The output of a convolutional capsule layer has shape $H \times W \times D \times C \times A$, where $C$ is the number of capsule types and $A$ is the dimension of each capsule.

  • Deconvolutional Decoder: The decoder of 3D UNet [cciccek20163d] is used in the expanding path. It contains deconvolution, skip connection, convolution, and BatchNorm layers [ioffe2015batch] to generate the segmentation from the features learned by the capsule layers. The capsule features are first reshaped back into a standard feature map before being passed to the next convolution layer or concatenated with the skip connections.

  • Reconstruction Regularization: This component is implemented in the same manner as in 3D-SegCaps.
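As a rough sketch of the visual feature extractor above (not the released 3D-UCaps code), the three dilated 3D convolutions might look as follows; the kernel size of 5 and the ReLU non-linearity are assumptions, since the text only specifies the channel counts and dilation rates.

```python
import torch.nn as nn

class VisualFeatureExtractor3D(nn.Module):
    """Three dilated 3D convolutions, channels 16 -> 32 -> 64, dilation rates 1, 3, 3."""
    def __init__(self, in_channels=1, kernel_size=5):   # kernel_size=5 is an assumption
        super().__init__()
        def block(cin, cout, dilation):
            pad = dilation * (kernel_size - 1) // 2      # "same" padding for odd kernels
            return nn.Sequential(
                nn.Conv3d(cin, cout, kernel_size, padding=pad, dilation=dilation),
                nn.ReLU(inplace=True),
            )
        self.layers = nn.Sequential(
            block(in_channels, 16, dilation=1),
            block(16, 32, dilation=3),
            block(32, 64, dilation=3),
        )

    def forward(self, x):          # x: (B, in_channels, H, W, D)
        return self.layers(x)      # (B, 64, H, W, D), later reshaped into a grid of 64-dim capsules
```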

4.4 SS-3DCapsNet

Despite the recent success of CapsNet-based approaches in medical image segmentation, a wide range of challenges remains: (1) most methods are based on supervised learning, which is prone to many data problems such as small-scale data, low-quality annotations, small objects, and ambiguous boundaries, to name a few. These problems are not straightforward to overcome: labeling medical data is laborious and expensive, requiring an expert's domain knowledge. (2) Capsule networks for medical segmentation do not outperform CNNs yet, even though the performance gap has narrowed significantly [nguyen20213d].

To address the aforementioned limitations, Tran et al. [tran2022ss] improve 3D-UCaps and propose SS-3DCapsNet, a self-supervised capsule network. Self-supervised learning (SSL) is a technique for learning feature representations without requiring a labeled dataset. A common workflow for applying SSL is to first train the network in an unsupervised manner with a pretext task in a pre-training stage, and then fine-tune the pre-trained network on the target downstream task. In the case of medical image segmentation, suitable pretext tasks fall into four categories: context-based, generation-based, free semantic label-based, and cross-modal-based. In SS-3DCapsNet [tran2022ss], the pretext task is based on image reconstruction.

Figure 9: SS-3DCapsNet architecture with three components: visual representation, convolutional capsule encoder, and deconvolutional decoder. The numbers on the blocks indicate the number of channels in the convolution layers and the dimension of capsules in the capsule layers.

The pretext task and downstream task in SS-3DCapsNet are detailed as follows:

  • Pretext Task: In computer vision, it is common to use pseudo-labels defined by different image transformations, e.g., rotation, random crop, adding noise, blurring, scaling, flipping, jigsaw puzzles, etc., to supervise the pretext task. While such transformations work well for classification as a downstream task, they cannot be applied directly to image segmentation. In SS-3DCapsNet [tran2022ss], image reconstruction from various transformations, i.e., noise, blurring, zeroed channels (R, G, B), and swapping, as shown in Fig. 10, is utilized as the pretext task.

    Figure 10: Examples of six transformations for self-supervised learning. (a): original image. (b) from left to right, top to bottom: zeros-green-channel, zeros-red-channel, zeros-blue-channel, swapping (4 swapped patches are shown in yellow boxes), blurring, noisy.
    Figure 11: Our pretext task with reconstruction loss.

    Let $f$ be the visual representation network. The set of transformations is defined as $\mathcal{T} = \{\tau_0, \tau_1, \dots, \tau_K\}$, where $\tau_0$ is the identity transformation and $K$ is set to 6, corresponding to the six transformations in Fig. 10. Let $V$ denote the original input volumetric data. The pretext task is performed by applying two random transformations to $V$; the transformed data are $V_1 = \tau_i(V)$ and $V_2 = \tau_j(V)$, respectively. The visual features of the transformed data after applying the network are $F_1 = f(V_1)$ and $F_2 = f(V_2)$. The network is trained with a reconstruction loss (a minimal sketch of this pretext procedure is given after this list) defined by:

    (25)

    The pretext task procedure is illustrated in Fig. 11.

  • Downstream Task: After pre-training, the SS-3DCapsNet network is trained with annotated data on the medical segmentation task. The total loss function for this downstream task is the sum of three losses, i.e., the margin loss, weighted cross-entropy loss, and reconstruction regularization loss, as defined in Eq. 15.
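The following is a minimal sketch of such a reconstruction-based pretext step, assuming (as one plausible reading of the text) that each transformed copy is mapped back to the original volume by a small reconstruction head. The transform set, the head `recon_head`, and the plain MSE objective are illustrative stand-ins rather than the exact choices of [tran2022ss].

```python
import random
import torch
import torch.nn.functional as F

# Illustrative corruptions standing in for the six transformations of Fig. 10.
def identity(v):   return v
def add_noise(v):  return v + 0.05 * torch.randn_like(v)
def blur(v):       return F.avg_pool3d(v, kernel_size=3, stride=1, padding=1)
TRANSFORMS = [identity, add_noise, blur]

def pretext_step(encoder, recon_head, volume):
    """One self-supervised step: corrupt the volume twice, encode each copy,
    and reconstruct the original volume from both (read here as a sum of two
    MSE terms; see [tran2022ss] for the exact formulation).

    volume: (B, C, H, W, D) tensor; encoder is the visual representation network f,
    recon_head is a hypothetical decoder mapping features back to volume space.
    """
    t1, t2 = random.sample(TRANSFORMS, 2)
    v1, v2 = t1(volume), t2(volume)
    f1, f2 = encoder(v1), encoder(v2)            # F1 = f(V1), F2 = f(V2)
    r1, r2 = recon_head(f1), recon_head(f2)      # reconstructions of the original volume
    return F.mse_loss(r1, volume) + F.mse_loss(r2, volume)
```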

4.5 Comparison

In this section, the comparison is conducted in terms of both network architecture and segmentation accuracy.

The network architecture comparison between various CapsNet-based approaches for medical image segmentation is shown in Table 1.

Method | Input | Initialization | Encoder | Decoder
2D-SegCaps [lalonde2018capsules, lalonde2021capsules] | 2D still image | Random | Capsule | Capsule
3D-SegCaps [nguyen20213d] | 3D volumetric | Random | Capsule | Capsule
3D-UCaps [nguyen20213d] | 3D volumetric | Random | Capsule | Deconvolution
SS-3DCapsNet [tran2022ss] | 3D volumetric | SSL | Capsule | Deconvolution
Table 1: Network architecture comparison between various CapsNet-based image segmentation approaches.

To compare the performance of the various CapsNet-based approaches, small-scale datasets such as iSeg [wang2019benchmark], Cardiac, and Hippocampus [simpson2019large] are selected for the experiments. Samples from the three datasets are visualized in Fig. 12.

Figure 12: Visualization of samples from iSeg (first row), Cardiac (second row), and Hippocampus (third row).
  • iSeg dataset [wang2019benchmark]: an infant brain dataset consisting of 10 subjects with ground-truth labels for training and 13 subjects without ground-truth labels for testing. Subjects were selected from the Baby Connectome Project [BCP] and have an average age of 5.5 to 6.5 months at the time of scanning. Each subject includes T1-weighted and T2-weighted images. The difficulty of this dataset lies in the low contrast between tissues in infant brain MRI, which can reduce the accuracy of automatic segmentation algorithms.

  • Cardiac [simpson2019large]: a mono-modal MRI dataset containing 20 training images and 10 testing images covering the entire heart, acquired during a single cardiac phase. The dataset was provided by King's College London.

  • Hippocampus [simpson2019large]: a larger mono-modal MRI dataset taken from the Psychiatric Genotype/Phenotype Project data repository at Vanderbilt University Medical Center (Nashville, TN, USA). It consists of 260 training and 130 testing samples acquired with a 3D T1-weighted MPRAGE sequence (TI/TR/TE, 860/8.0/3.7 ms; 170 sagittal slices). The task is to segment two neighbouring small structures (the posterior and anterior hippocampus) with high precision.

Method | Depth | Dice score (WM) | Dice score (GM) | Dice score (CSF) | Dice score (Average)
2D-SegCaps [lalonde2018capsules] | 16 | 82.80 | 84.19 | 90.19 | 85.73
3D-SegCaps [nguyen20213d] | 16 | 86.49 | 88.53 | 93.62 | 89.55
3D-UCaps [nguyen20213d] | 17 | 90.21 | 91.12 | 94.93 | 92.08
Our SS-3DCapsNet [tran2022ss] | 17 | 90.78 | 91.48 | 94.92 | 92.39
Table 2: Performance comparison on iSeg-2017.
Method | Dice score
SegCaps (2D) [lalonde2018capsules] | 66.96
Multi-SegCaps (2D) [survarachakan2020capsule] | 66.96
3D-UCaps [nguyen20213d] | 89.69
SS-3DCapsNet [tran2022ss] | 89.77
Table 3: Performance comparison on Cardiac with 4-fold cross-validation.
Method | Anterior Recall | Anterior Precision | Anterior Dice | Posterior Recall | Posterior Precision | Posterior Dice
Multi-SegCaps (2D) [survarachakan2020capsule] | 80.76 | 65.65 | 72.42 | 84.46 | 60.49 | 70.49
EM-SegCaps (2D) [survarachakan2020capsule] | 17.51 | 20.01 | 18.67 | 19.00 | 34.55 | 24.52
3D-UCaps [nguyen20213d] | 81.70 | 80.19 | 80.99 | 80.2 | 79.25 | 79.48
SS-3DCapsNet [tran2022ss] | 81.84 | 81.49 | 81.59 | 80.71 | 80.21 | 79.97
Table 4: Performance comparison on Hippocampus with 4-fold cross-validation.
Figure 13: Performance comparison on iSeg of 3D-UCaps and 3D-UNet with rotation equivariance on the x, y, z, and all axes.
Figure 14: Performance comparison on iSeg of various networks with rotation equivariance on a particular axis.
Figure 15: Performance comparison between 3D-UCaps and 3D-UNet on iSeg with various artifacts.

For iSeg, the experimental setup follows [bui2019skip], in which 9 subjects are used for training and subject #9 is used for testing. On Cardiac and Hippocampus [simpson2019large], the experiments are conducted with 4-fold cross-validation.

The comparison is implemented in PyTorch. Patch sizes are selected per dataset: one size for iSeg and Hippocampus and another for Cardiac. All networks were trained without any data augmentation. The Adam optimizer with an initial learning rate of 0.0001 is used. The learning rate is decayed by a factor of 0.05 if the Dice score on the validation set does not increase for 50,000 iterations. Early stopping is performed at 250,000 iterations, as in [lalonde2018capsules].
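Read literally, this schedule maps onto PyTorch's built-in ReduceLROnPlateau roughly as follows; the validation interval, the placeholder model, and the patience conversion are illustrative assumptions rather than details from the papers.

```python
import torch
import torch.nn as nn

model = nn.Conv3d(1, 1, kernel_size=3)    # placeholder for any of the networks above
val_interval = 5_000                      # illustrative: validate every 5,000 iterations

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",                           # validation Dice is monitored (higher is better)
    factor=0.05,                          # "decayed by 0.05" read as multiplying the LR by 0.05
    patience=50_000 // val_interval,      # no Dice improvement for roughly 50,000 iterations
)
# After each validation pass: scheduler.step(val_dice)
# Training is stopped early around 250,000 iterations, as in [lalonde2018capsules].
```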

The performance comparison of various CapsNet-based medical image segmentation approaches is shown in Table 2, Table 3, and Table 4, corresponding to the iSeg, Cardiac, and Hippocampus datasets, respectively.

5 Discussion

Although CNNs have achieved outstanding performance on various tasks, including medical image segmentation, they suffer from the loss of part-whole relationships and geometric information. CapsNet was proposed to address these limitations. 3D-UCaps [nguyen20213d] conducted an analysis with two experiments on the small-scale iSeg dataset, examining rotation equivariance and invariance to various artifacts, as follows:

  • Rotation Equivariance: In the first experiment, the testing subject is rotated from 0 to 90 degrees (15, 30, 45, 60, 75, 90) about the x-axis, y-axis, z-axis, and all axes. The performance comparison on rotation equivariance between 3D-UCaps and 3D-UNet is shown in Fig. 13. Furthermore, the performance comparison between various networks, i.e., 3D-UCaps, 3D-SegCaps, 2D-SegCaps, and 3D-UNet, on a particular axis (the z-axis) is shown in Fig. 14.

  • Various Artifacts: In the second experiment, MONAI [MonAI] and TorchIO [perez2021torchio] are utilized to create artifacts. The performance comparison on iSeg between 3D-UCaps and 3D-UNet is shown in Fig. 15.

CapsNets, with their capability of modelling part-whole relationships, have obtained remarkable results on various tasks, including medical image segmentation. The discussion above shows that CapsNets significantly outperform CNNs on small-scale datasets, which is a common situation in medical image segmentation due to the lack of annotated data. The experimental results also show that CapsNets are more robust to affine transformations than CNNs; however, their performance is still limited on unseen transformed inputs, and their computational complexity remains high. Exploring hybrid architectures that combine CapsNet-based and traditional neural network designs is therefore a promising approach for medical image analysis while keeping model complexity and computational cost plausible.

Acknowledgment

This material is based upon work supported by the National Science Foundation under Award No. OIA-1946391.