nnFormer: Interleaved Transformer for Volumetric Segmentation

09/07/2021 ∙ by Hong-Yu Zhou, et al. ∙ Xiamen University ∙ Association for Computing Machinery ∙ The University of Hong Kong

Transformers, the default model of choice in natural language processing, have drawn scant attention from the medical imaging community. Given their ability to exploit long-term dependencies, transformers are promising for helping typical convolutional neural networks (convnets) overcome their inherent shortcomings of spatial inductive bias. However, most recently proposed transformer-based segmentation approaches simply treat transformers as auxiliary modules that encode global context into convolutional representations, without investigating how to optimally combine self-attention (i.e., the core of transformers) with convolution. To address this issue, we introduce nnFormer (i.e., Not-aNother transFormer), a powerful segmentation model with an interleaved architecture based on an empirical combination of self-attention and convolution. In practice, nnFormer learns volumetric representations from 3D local volumes. Compared to the naive voxel-level self-attention implementation, such volume-based operations reduce the computational complexity by approximately 98% and 99.5% on the Synapse and ACDC datasets, respectively. In comparison to prior-art network configurations, nnFormer achieves substantial improvements over previous transformer-based methods on two commonly used datasets, Synapse and ACDC. For instance, nnFormer outperforms Swin-UNet by over 7 percentage points in average Dice on Synapse. Even when compared to nnUNet, currently the best performing fully-convolutional medical segmentation network, nnFormer still provides slightly better performance on Synapse and ACDC.


1 Introduction

Transformers [34], which are the de-facto choice for natural language processing (NLP) problems, have recently been widely exploited in vision applications [8, 21]. The core idea behind them is to apply the self-attention mechanism to capture long-range dependencies. Compared to convnets (i.e., convolutional neural networks [15]), transformers relax the inductive bias of locality, making them more capable of modeling non-local interactions [44]. It has also been shown that the prediction errors of transformers are more consistent with those of humans than those of convnets are [32].

Given these advantages over convnets, a number of approaches have tried to apply transformers to medical image analysis. Chen et al. [6] first proposed TransUNet to explore the potential of transformers for medical image segmentation. The overall architecture of TransUNet is similar to that of U-Net [28], where convnets act as feature extractors and transformers help encode the global context. In fact, a major feature of TransUNet and most of its followers [43, 33, 4, 5] is that they treat convnets as the main body, on top of which transformers are applied to capture long-term dependencies. However, this design may cause a problem: the advantages of transformers are not fully exploited. In other words, we believe one- or two-layer transformers are not enough to entangle long-term dependencies with convolutional representations, which often contain precise spatial information and provide hierarchical concepts.

To address the above issue, some researchers [13, 3, 19] began to use transformers as the main stem of segmentation models. Karimi et al. [13] first introduced a convolution-free segmentation model by forwarding flattened image representations to transformers, whose outputs are then reorganized into 3D tensors to align with segmentation masks. Recently, Swin Transformer [21] showed that, by borrowing the feature pyramids used in convnets and applying appropriate down-sampling to feature maps, transformers can learn hierarchical object concepts at different scales. Inspired by this idea, Swin-UNet [3] used hierarchical transformer blocks to construct the encoder and decoder of a U-Net-like architecture, on top of which DS-TransUNet [19] added one more encoder to accept inputs of different sizes. Both Swin-UNet and DS-TransUNet achieved consistent improvements over TransUNet. Nonetheless, they did not explore how to appropriately combine convolution and self-attention to build an optimal medical segmentation network.

The main contribution of nnFormer (i.e., Not-aNother transFormer) is its hybrid stem, in which convolution and self-attention are interleaved to give full play to their respective strengths. Figure 1 presents the effects of different components used in the encoder of nnFormer. First, we put a lightweight convolutional embedding layer ahead of the transformer blocks. Compared to directly flattening raw pixels and applying 1D pre-processing as in [13], the convolutional embedding layer encodes precise (i.e., pixel-level) spatial information and provides low-level yet high-resolution 3D features. After the embedding block, transformer and convolutional down-sampling blocks are interleaved to fully entangle long-term dependencies with high-level, hierarchical object concepts at various scales, which helps improve the generalization ability and robustness of the learned representations.

Figure 1: Overview of the interleaved stem used in the encoder of nnFormer.

The other contribution of nnFormer is a computationally efficient way to capture inter-slice dependencies. Specifically, nnFormer introduces volume-based multi-head self-attention (V-MSA) to learn representations on 3D local volumes, which are then aggregated to produce whole-volume predictions. Compared to naive multi-head self-attention (MSA) [34], V-MSA reduces the computational complexity of the transformer blocks by about 98% and 99.5% on the Synapse and ACDC datasets, respectively.

In the experiment section, we compare nnFormer with a wide range of baseline segmentation approaches. nnFormer surpasses Swin-UNet by over 7 percentage points in average Dice on multi-organ segmentation on Synapse. For automated cardiac diagnosis on the ACDC dataset, nnFormer outperforms Swin-UNet by nearly 2 percentage points on average. Considering that the average Dice score on ACDC is already over 90%, we believe the 2-point improvement on ACDC is as impressive as the 7-point improvement on Synapse.

2 Related work

In this section, we mainly review methodologies that resort to transformers to improve segmentation results of medical images. Since most of them employ hybrid architectures of convolution and self-attention [34], we divide them into two categories based on whether the majority of the stem is convolutional or transformer-based.

Convolution-based stem. TransUNet [6] was the first to apply transformers to improve the segmentation of medical images. TransUNet treats the convnet as a feature extractor that generates a feature map for the input slice. Patch embedding is then applied to patches of the bottleneck feature maps instead of raw images as in ViT [8]. Concurrently, similar to TransUNet, Li et al. [17] proposed to use a squeezed attention block to regularize the self-attention modules of transformers and an expansion block to learn diversified representations for fundus images, both implemented in the bottleneck of convnets. TransFuse [43] introduced a BiFusion module to fuse features from a shallow convnet-based encoder and a transformer-based segmentation network to make final predictions on 2D images. Compared to TransUNet, TransFuse mainly applied the self-attention mechanism to the input embedding layer to improve segmentation models on 2D images. Yun et al. [41] employed transformers to incorporate spectral information, which is entangled with the spatial information encoded by convolutional features to address the problem of hyperspectral pathology. Xu et al. [40] extensively studied the trade-off between transformers and convnets and proposed a more efficient encoder named LeViT-UNet. Li et al. [18] presented a new up-sampling approach and incorporated it into the decoder of UNet to model long-term dependencies and global information for better reconstruction results. TransClaw U-Net [4] utilized transformers in UNet with more convolutional feature pyramids. TransAttUNet [5] explored the feasibility of combining transformer self-attention with convolutional global spatial attention. Xie et al. [39] adopted transformers to capture long-term dependencies of multi-scale convolutional features from different layers of convnets. TransBTS [35] first utilized 3D convnets to extract volumetric spatial features and down-sample the input 3D images to produce hierarchical representations. The outputs of the encoder in TransBTS are then reshaped into a vector (i.e., token) and fed into transformers for global feature modeling, after which an ordinary convolutional decoder is appended to up-sample the feature maps for reconstruction. Different from these approaches that directly employ convnets as feature extractors, nnFormer relies functionally on both convolutional and transformer-based blocks, which are interleaved to take advantage of each other.


Transformer-based stem. Valanarasu et al. [33] proposed a gated axial-attention model (i.e., MedT), which extends existing convnet architectures by introducing an additional control mechanism in the self-attention. Karimi et al. [13] removed the convolutional operations and built a 3D segmentation model based on transformers. The main idea is to first split the local volume block into 3D patches, which are then flattened, embedded into 1D sequences and passed to a ViT-like backbone to extract representations. Swin-UNet [3] built a U-shaped transformer-based segmentation model on top of the transformer blocks in [21] and achieved noticeable improvements. DS-TransUNet [19] further extended Swin-UNet by adding one more encoder to handle multi-scale inputs and introduced a fusion module to establish global dependencies between features of different scales through the self-attention mechanism. Compared to these transformer-based stems, nnFormer inherits the strength of convolution in encoding precise spatial information and producing hierarchical representations that help model object concepts at various scales.

3 Method

3.1 Overview

Figure 2: Architecture of nnFormer. In (a), we show the overall architecture of nnFormer. In (b), we present more details of the embedding, down-sampling, up-sampling and the last expanding blocks. $K$ denotes the number of segmentation classes. Note that the displayed architecture corresponds to the configuration used for the Synapse dataset. In practice, the architecture may vary slightly depending on the input patch size.

The overall architecture of nnFormer is presented in Figure 2. It maintains a U shape similar to that of U-Net [28] and mainly consists of two branches, i.e., the encoder and the decoder. Concretely, the encoder involves one embedding block, seven transformer blocks and three down-sampling blocks. Symmetrically, the decoder branch includes seven transformer blocks, three up-sampling blocks and one patch expanding block for making final predictions. Inspired by U-Net [28], we add long residual connections between corresponding feature pyramids of the encoder and decoder in a symmetrical manner, which helps recover fine-grained details in the prediction.

3.2 Encoder

The input of nnFormer is a 3D patch $\mathcal{X} \in \mathbb{R}^{H \times W \times D}$ (usually randomly cropped from the original image), where $H$, $W$ and $D$ denote the height, width and depth of each input patch, respectively.

Embedding block. The embedding block transforms each input scan into a high-dimensional tensor $\mathcal{X}_e \in \mathbb{R}^{H_e \times W_e \times D_e \times C}$, where $H_e \times W_e \times D_e$ is the number of patch tokens (determined by the convolution strides) and $C$ is the embedding dimension, i.e., the number of channels of each token. In practice, we set $C$ to 192 and 96 on the Synapse and ACDC datasets, respectively. Unlike ViT [8] and Swin Transformer [21], which use a convolution with a large kernel in the embedding block, we found that applying successive convolutional layers with small kernels brings more benefits in the initial stage. This can be explained from two perspectives: i) why successive convolutional layers are used and ii) why small kernels are used. For i), convolutional layers encode pixel-level spatial information more precisely than the patch-wise positional encoding used in transformers. For ii), compared to a large kernel, a stack of small kernels reduces computational complexity while providing a receptive field of the same size. As shown in Figure 2b, the embedding block consists of four convolutional layers with kernel size 3. Each convolutional layer except the last one is followed by one GELU [11] layer and one layer normalization (LayerNorm) [1] layer. Depending on the size of the input patch, the strides of the convolutions in the embedding block may vary accordingly.
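The kernel-size argument above can be checked with a little arithmetic. The sketch below is a back-of-the-envelope illustration, not the nnFormer code: it assumes stride-1 convolutions with equal input and output channels, ignores biases, and compares the per-axis receptive field and weight count of stacked small 3D kernels against a single large kernel. The function names are hypothetical.

```python
def receptive_field(kernel_size: int, num_layers: int) -> int:
    """Per-axis receptive field of `num_layers` stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)


def conv3d_weight_count(kernel_size: int, num_layers: int, channels: int) -> int:
    """Weights (biases ignored) of `num_layers` cubic 3D convs mapping channels -> channels."""
    return num_layers * kernel_size ** 3 * channels ** 2


if __name__ == "__main__":
    channels = 96  # the embedding dimension the text uses on ACDC
    for kernel, layers in [(3, 2), (5, 1), (3, 3), (7, 1)]:
        print(
            f"{layers} x {kernel}^3 conv: receptive field "
            f"{receptive_field(kernel, layers)}, weights "
            f"{conv3d_weight_count(kernel, layers, channels):,}"
        )
```

Two stacked 3×3×3 layers reach the same 5-voxel receptive field as one 5×5×5 layer with 54C² instead of 125C² weights, and with stride 1 the multiply-adds per output voxel scale the same way. With non-unit strides, as in the actual embedding block, later layers see a larger region than this stride-1 formula gives.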


Transformer block. After the embedding block, we pass the high-dimensional tensor to interleaved transformer blocks. The idea is to fully entangle the captured long-term dependencies with the hierarchical object concepts at various scales provided by the subsequent down-sampling convolutions, and with the high-resolution spatial information encoded by the initial embedding block. Like Swin Transformer [21], we compute self-attention hierarchically; unlike Swin Transformer, however, we compute it within 3D local volumes (i.e., V-MSA, volume-based multi-head self-attention) instead of 2D local windows.

Suppose $\mathcal{X}_t \in \mathbb{R}^{H_t \times W_t \times D_t \times C}$ is the input of a transformer block. It is first reshaped to $\hat{\mathcal{X}}_t \in \mathbb{R}^{N_V \times N_T \times C}$, where $N_V$ is the number of 3D local volumes and $N_T = S_H \times S_W \times S_D$ denotes the number of patch tokens in each volume; $\{S_H, S_W, S_D\}$ stand for the size of a local volume. To adapt to the various shapes of MRI/CT scans, we design $\{S_H, S_W, S_D\}$ so that a single local volume covers all patch tokens output by the last transformer block in the encoder. The intuition is that brute-force padding of the data to satisfy a fixed $\{S_H, S_W, S_D\}$ may be undesirable. Thus, the size of the input cropped patch needs to be adjusted adaptively to accord with the size of the local volumes. In practice, we set $\{S_H, S_W, S_D\}$ on Synapse and ACDC to … and …, respectively.
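The reshaping step above amounts to grouping the token grid into non-overlapping local volumes. Here is a pure-Python sketch under stated assumptions: the function name and the toy grid and volume sizes are hypothetical, and it works on token coordinates rather than feature tensors. It refuses grids that are not divisible by the volume size, mirroring the choice of crop sizes instead of padding.

```python
from collections import defaultdict
from itertools import product


def partition_into_volumes(grid_shape, volume_size):
    """Group token coordinates of an H x W x D token grid into local volumes.

    Returns {volume_index: [token coordinates]}. Each grid axis must be divisible
    by the matching volume axis (no padding is applied).
    """
    for axis, (g, s) in enumerate(zip(grid_shape, volume_size)):
        if g % s != 0:
            raise ValueError(f"axis {axis}: grid size {g} not divisible by volume size {s}")
    volumes = defaultdict(list)
    for coord in product(*(range(g) for g in grid_shape)):
        key = tuple(c // s for c, s in zip(coord, volume_size))
        volumes[key].append(coord)
    return dict(volumes)


if __name__ == "__main__":
    # Hypothetical token grid and volume size, not the paper's settings.
    volumes = partition_into_volumes((8, 8, 4), (4, 4, 2))
    num_volumes = len(volumes)
    tokens_per_volume = len(next(iter(volumes.values())))
    print(num_volumes, tokens_per_volume)  # 8 32
```

When the token grid equals the volume size, as intended for the last encoder stage, the partition yields a single volume and V-MSA at that stage attends over all of its tokens.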

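We follow [21] and stack two successive transformer blocks; the main difference is that our computation is built on 3D volumes instead of 2D windows. The computational procedure can be summarized as follows:

$$
\begin{aligned}
\hat{\mathcal{X}}^{l} &= \text{V-MSA}\big(\mathrm{LN}(\mathcal{X}^{l-1})\big) + \mathcal{X}^{l-1},\\
\mathcal{X}^{l} &= \mathrm{MLP}\big(\mathrm{LN}(\hat{\mathcal{X}}^{l})\big) + \hat{\mathcal{X}}^{l},\\
\hat{\mathcal{X}}^{l+1} &= \text{SV-MSA}\big(\mathrm{LN}(\mathcal{X}^{l})\big) + \mathcal{X}^{l},\\
\mathcal{X}^{l+1} &= \mathrm{MLP}\big(\mathrm{LN}(\hat{\mathcal{X}}^{l+1})\big) + \hat{\mathcal{X}}^{l+1}.
\end{aligned}
\tag{1}
$$

Here, $l$ stands for the layer index and LN denotes layer normalization. MLP is an abbreviation for multi-layer perceptron. V-MSA and SV-MSA denote the volume-based multi-head self-attention and its shifted version. For a feature map of $h \times w \times d$ patches with $C$ channels, partitioned into local volumes of $S_H \times S_W \times S_D$ patches, the computational complexity of V-MSA is:

$$\Omega(\text{V-MSA}) = 4hwdC^2 + 2S_H S_W S_D\,hwdC. \tag{2}$$

Compared to the complexity of the naive multi-head self-attention (MSA) [34] used in ViT [8], i.e.,

$$\Omega(\text{MSA}) = 4hwdC^2 + 2(hwd)^2C, \tag{3}$$

V-MSA reduces the computational complexity by approximately 98% and 99.5% on the Synapse and ACDC datasets, respectively.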
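Equations 2 and 3 can be compared numerically. The sketch below plugs in a hypothetical 32×32×32 token grid with C = 192 and 4×4×4 local volumes; these are illustrative values, not the configurations behind the reported 98% and 99.5%, and the function names are our own.

```python
def msa_cost(h: int, w: int, d: int, c: int) -> int:
    """Complexity of global MSA over an h x w x d token grid with c channels (Eq. 3)."""
    n = h * w * d
    return 4 * n * c ** 2 + 2 * n ** 2 * c


def vmsa_cost(h: int, w: int, d: int, c: int, volume: tuple) -> int:
    """Complexity of V-MSA with local volumes of s_h x s_w x s_d tokens (Eq. 2)."""
    n = h * w * d
    s_h, s_w, s_d = volume
    return 4 * n * c ** 2 + 2 * s_h * s_w * s_d * n * c


def relative_reduction(h, w, d, c, volume) -> float:
    """Fraction of the MSA cost saved by V-MSA on the same feature map."""
    return 1 - vmsa_cost(h, w, d, c, volume) / msa_cost(h, w, d, c)


if __name__ == "__main__":
    # Hypothetical feature map; not the paper's exact configuration.
    print(f"{relative_reduction(32, 32, 32, 192, (4, 4, 4)):.2%}")
```

For this hypothetical setting the saving is exactly 73/74, roughly 98.6%. The linear projection term $4hwdC^2$ is identical in both equations, so the whole reduction comes from replacing the quadratic $(hwd)^2$ attention term with one that is linear in the number of tokens.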

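SV-MSA displaces the 3D local volumes used in V-MSA by half a volume along each axis, i.e., by $(\lfloor S_H/2 \rfloor, \lfloor S_W/2 \rfloor, \lfloor S_D/2 \rfloor)$ patches for local volumes of $S_H \times S_W \times S_D$ patches, to introduce more interactions between different local volumes.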
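The effect of the shift can be illustrated by asking which local volume a token lands in. The sketch assumes, as in Swin Transformer, a cyclic shift of half a volume per axis followed by the regular partition; the grid and volume sizes and the function names are hypothetical.

```python
def volume_index(coord, volume_size):
    """Local volume containing a token under the regular (unshifted) partition."""
    return tuple(p // s for p, s in zip(coord, volume_size))


def shifted_volume_index(coord, grid_shape, volume_size):
    """Local volume containing a token after a cyclic shift by half a volume.

    Rolling the token grid by -floor(s / 2) on each axis (as Swin Transformer does
    with torch.roll) moves the token at position p to (p - floor(s / 2)) mod g;
    the rolled grid is then partitioned exactly as in V-MSA.
    """
    return tuple(
        ((p - s // 2) % g) // s for p, g, s in zip(coord, grid_shape, volume_size)
    )


if __name__ == "__main__":
    grid, vol = (8, 8, 8), (4, 4, 4)  # hypothetical sizes
    a, b = (3, 3, 3), (4, 4, 4)  # neighbors straddling a volume boundary
    print(volume_index(a, vol) == volume_index(b, vol))  # False
    print(shifted_volume_index(a, grid, vol) == shifted_volume_index(b, grid, vol))  # True
```

Tokens that wrap around from one edge of the grid to the opposite edge end up in the same shifted volume without being spatial neighbors; Swin Transformer masks such pairs in the attention, a step this sketch omits.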

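The query-key-value (QKV) attention [34] in each 3D local volume can be computed as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\left(\frac{QK^{\top}}{\sqrt{d_k}} + B\right)V, \tag{4}$$

where $Q, K, V \in \mathbb{R}^{N_T \times d_k}$ denote the query, key and value matrices, $N_T$ is the number of patch tokens in a local volume of $S_H \times S_W \times S_D$ patches, and $B \in \mathbb{R}^{N_T \times N_T}$ is the relative position encoding. In practice, we first initialize a smaller-sized position matrix $\hat{B} \in \mathbb{R}^{(2S_H-1) \times (2S_W-1) \times (2S_D-1)}$ and take corresponding values from $\hat{B}$ to build the larger position matrix $B$.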
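To make Equation 4 and the bias-table lookup concrete, here is a pure-Python sketch for a single attention head inside one local volume. It follows the Swin Transformer-style construction, a learnable table with one entry per relative offset, indexed for every (query, key) pair; the mixed-radix indexing scheme and all function names are our own illustrative choices, and a real implementation would use batched tensor operations and learned parameters.

```python
import math
from itertools import product


def relative_position_index(volume_size):
    """Index into a flattened (2S_H-1)(2S_W-1)(2S_D-1) bias table for every
    (query, key) token pair inside one S_H x S_W x S_D local volume."""
    s_h, s_w, s_d = volume_size
    coords = list(product(range(s_h), range(s_w), range(s_d)))
    index = []
    for qh, qw, qd in coords:
        row = []
        for kh, kw, kd in coords:
            # Relative offsets shifted to start at 0 on each axis.
            dh = qh - kh + s_h - 1
            dw = qw - kw + s_w - 1
            dd = qd - kd + s_d - 1
            # Mixed-radix flattening of (dh, dw, dd).
            row.append((dh * (2 * s_w - 1) + dw) * (2 * s_d - 1) + dd)
        index.append(row)
    return index


def attention_with_bias(q, k, v, bias_table, rel_index):
    """Single-head SoftMax(Q K^T / sqrt(d_k) + B) V, with B gathered from bias_table."""
    d_k = len(q[0])
    out = []
    for i, q_i in enumerate(q):
        logits = [
            sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d_k)
            + bias_table[rel_index[i][j]]
            for j, k_j in enumerate(k)
        ]
        peak = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - peak) for x in logits]
        total = sum(weights)
        out.append(
            [sum(w * v_j[c] for w, v_j in zip(weights, v)) / total for c in range(len(v[0]))]
        )
    return out


if __name__ == "__main__":
    vol = (2, 2, 2)  # hypothetical volume size
    idx = relative_position_index(vol)
    table_size = (2 * vol[0] - 1) * (2 * vol[1] - 1) * (2 * vol[2] - 1)  # 27 entries
    n_tokens = len(idx)  # 8 tokens share those 27 biases
    q = k = [[0.0] * 4 for _ in range(n_tokens)]
    v = [[float(t)] for t in range(n_tokens)]
    print(attention_with_bias(q, k, v, [0.0] * table_size, idx)[0])  # [3.5]
```

Because the bias depends only on the relative offset, every pair of tokens separated by the same offset shares one learnable value, which is why the table only needs $(2S_H-1)(2S_W-1)(2S_D-1)$ entries rather than $N_T^2$.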

It is worth noting that the last transformer block of the encoder (the bottom block in Figure 2) employs only V-MSA, because we found that introducing SV-MSA degraded the overall segmentation results.

Down-sampling block. We found that replacing the neighboring concatenation operation in [21] with a direct strided convolution improves nnFormer's volumetric segmentation. The intuition is that convolutional down-sampling produces hierarchical representations that help model object concepts at multiple scales. As displayed in Figure 2b, in most cases the down-sampling block involves a strided convolution whose stride is set to 2 in all dimensions. However, in practice, the stride along a specific dimension (refer to Table 1b) is set to 1, as the number of slices along this dimension is limited and over-down-sampling (i.e., a large stride) might be harmful.
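How per-axis strides shape the encoder can be traced with a short helper. The sketch assumes each down-sampling block maps a size n to n // stride on every axis (true, for example, when the kernel equals the stride with no padding); the starting token grid and the stride schedule are hypothetical, not the values in Table 1.

```python
def feature_shapes(token_grid, stage_strides):
    """Token-grid shape at each encoder stage for per-axis down-sampling strides.

    Assumes every down-sampling block maps size n to n // stride along each axis
    and rejects sizes that do not divide evenly.
    """
    shapes = [tuple(token_grid)]
    for strides in stage_strides:
        current = shapes[-1]
        for size, stride in zip(current, strides):
            if size % stride != 0:
                raise ValueError(f"size {size} is not divisible by stride {stride}")
        shapes.append(tuple(size // stride for size, stride in zip(current, strides)))
    return shapes


if __name__ == "__main__":
    # Hypothetical volume with few slices: one block keeps stride 1 on the slice axis.
    strides = [(2, 2, 2), (2, 2, 2), (2, 2, 1)]
    for stage, shape in enumerate(feature_shapes((64, 64, 8), strides)):
        print(stage, shape)
```

With stride 2 everywhere, the 8-slice axis in this toy example would shrink to a single slice at the bottom stage; setting one of its strides to 1 keeps two. The last shape is also the size a single local volume would need to cover the final encoder stage.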

3.3 Decoder

The transformer blocks of the decoder are highly symmetrical to those of the encoder. In contrast to the down-sampling blocks, we employ strided deconvolution to up-sample low-resolution feature maps to high-resolution ones, which are in turn merged with representations from the encoder via long residual connections to capture both semantic and fine-grained information. Similar to the up-sampling blocks, the last patch expanding block also uses a deconvolution to produce the final predictions.

4 Experiments

To fairly compare nnFormer with previous transformer-based architectures, we conduct experiments on the Synapse [14] and Automatic Cardiac Diagnosis Challenge (ACDC) [2] datasets. We repeat each experiment three times and report the average results.

Synapse for multi-organ CT segmentation. This dataset includes 30 cases of abdominal CT scans. Following the split used in [6], 18 cases are used to build the training set while the remaining 12 cases are used for testing. We report model performance as the average Dice Similarity Coefficient (DSC) over 8 abdominal organs: aorta, gallbladder, spleen, left kidney, right kidney, liver, pancreas and stomach.

ACDC for automated cardiac diagnosis. ACDC involves 100 patients, for whom the cavity of the right ventricle, the myocardium of the left ventricle and the cavity of the left ventricle are to be segmented. Each case's labels thus cover the left ventricle (LV), right ventricle (RV) and myocardium (MYO). The dataset is split into 70 training samples, 10 validation samples and 20 testing samples.
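Since all results below are reported as Dice scores, a minimal sketch of the metric may help. It operates on flattened integer label maps and averages per-class scores over the foreground labels; the function names and the convention for a label absent from both masks are our assumptions, not the official evaluation code.

```python
def dice_score(pred, target, label):
    """Dice similarity coefficient 2|P ∩ T| / (|P| + |T|) for one label."""
    p = [x == label for x in pred]
    t = [y == label for y in target]
    intersection = sum(a and b for a, b in zip(p, t))
    denominator = sum(p) + sum(t)
    # Assumed convention: a label absent from both masks counts as a perfect match.
    return 2 * intersection / denominator if denominator else 1.0


def mean_dice(pred, target, labels):
    """Average DSC over the given foreground labels (e.g. the 8 Synapse organs)."""
    return sum(dice_score(pred, target, lab) for lab in labels) / len(labels)


if __name__ == "__main__":
    # Toy flattened label maps: 0 = background, 1 and 2 = two organs.
    pred = [0, 1, 1, 2, 2]
    target = [0, 1, 2, 2, 2]
    print(round(mean_dice(pred, target, labels=[1, 2]), 4))  # 0.7333
```

Background is excluded from the average, so a model cannot inflate its score simply by getting the large background region right.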

4.1 Implementation details

We run all experiments with Python 3.6, PyTorch 1.8.1 and Ubuntu 18.04. All training procedures were performed on a single NVIDIA 2080 GPU with 11GB memory. The initial learning rate is set to 0.01, and we employ a "poly" decay strategy as described in Equation 5. The default optimizer is SGD with the momentum set to 0.99. The weight decay is set to 3e-5. We use both cross-entropy loss and Dice loss, simply summing them. The number of training epochs (i.e., max_epoch in Equation 5) is 1000, and one epoch contains 250 iterations. The numbers of heads of the multi-head self-attention used in the different encoder stages are [6, 12, 24, 48] and [3, 6, 12, 24] on Synapse and ACDC, respectively.

$$\mathrm{lr} = \mathrm{initial\_lr} \times \left(1 - \frac{\mathrm{epoch\_id}}{\mathrm{max\_epoch}}\right)^{0.9} \tag{5}$$
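The "poly" schedule can be sketched as a one-line function. The defaults mirror the text (initial learning rate 0.01, max_epoch = 1000); the exponent of 0.9 is an assumption borrowed from nnU-Net's usual poly schedule, and the function name is hypothetical.

```python
def poly_lr(epoch: int, max_epoch: int = 1000, initial_lr: float = 0.01,
            exponent: float = 0.9) -> float:
    """'Poly' decay: initial_lr * (1 - epoch / max_epoch) ** exponent."""
    return initial_lr * (1 - epoch / max_epoch) ** exponent


if __name__ == "__main__":
    for epoch in (0, 250, 500, 750, 999):
        print(epoch, f"{poly_lr(epoch):.5f}")
```

In a PyTorch training loop one would typically evaluate this once per epoch and write the result into each optimizer parameter group's `lr` entry.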

nnFormer nnUNet
Target spacing … …
Median image shape … …
Crop size … …
Batch size 2 2
Str. of Down-samp. … …
(a)
nnFormer nnUNet
Target spacing … …
Median image shape … …
Crop size … …
Batch size 4 4
Str. of Down-samp. … …
(b)
Table 1: Network configurations of our nnFormer and nnUNet on Synapse and ACDC. We only report the stride of down-sampling (i.e., str. of down-samp.) as the stride of up-sampling can be easily inferred from the symmetrical down-sampling operations. Note that the network configuration of nnUNet is automatically determined based on pre-defined hand-crafted rules (for self-adaptation).

Pre-processing and augmentation strategies. All images are first resampled to the same target spacing. During training, augmentations including rotation, scaling, Gaussian noise, Gaussian blur, brightness and contrast adjustment, simulation of low resolution, gamma augmentation and mirroring are applied in the given order.

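Deep supervision. We also add deep supervision during training. Specifically, the output of each stage in the decoder is passed to the final expanding block, where cross-entropy loss and Dice loss are applied. In practice, given the prediction of a particular stage, we down-sample the ground-truth segmentation mask to match the prediction's resolution. Thus, the final training objective is the sum of the losses at all three resolutions:

$$\mathcal{L} = \alpha_1 \mathcal{L}_1 + \alpha_2 \mathcal{L}_2 + \alpha_3 \mathcal{L}_3, \tag{6}$$

where $\mathcal{L}_i$ is the sum of the cross-entropy and Dice losses at the $i$-th resolution, ordered from highest to lowest. Here, $\alpha_1$, $\alpha_2$ and $\alpha_3$ denote the magnitude factors for the losses at the different resolutions. In practice, the factors halve with each decrease in resolution, leading to $\alpha_2 = \frac{\alpha_1}{2}$ and $\alpha_3 = \frac{\alpha_1}{4}$. Finally, all weight factors are normalized to sum to 1.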
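The weighting rule, halving per resolution and then normalizing, can be computed exactly with fractions. This is a minimal sketch; the function names and the example loss values are hypothetical.

```python
from fractions import Fraction


def deep_supervision_weights(num_resolutions: int = 3):
    """Loss weights that halve with each lower resolution, normalized to sum to 1."""
    raw = [Fraction(1, 2 ** i) for i in range(num_resolutions)]  # 1, 1/2, 1/4, ...
    total = sum(raw)
    return [w / total for w in raw]


def combined_loss(losses_per_resolution, weights):
    """Weighted sum of per-resolution losses, highest resolution first."""
    return sum(float(w) * loss for w, loss in zip(weights, losses_per_resolution))


if __name__ == "__main__":
    weights = deep_supervision_weights()
    print([str(w) for w in weights])  # ['4/7', '2/7', '1/7']
    # Hypothetical per-resolution (cross-entropy + Dice) loss values.
    print(combined_loss([0.7, 0.9, 1.4], weights))
```

With three resolutions the normalized factors are 4/7, 2/7 and 1/7, so the full-resolution prediction still carries more than half of the total training signal.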

Pre-trained model weights. Pre-training can be vitally important for providing generalized and transferable representations for downstream tasks. Given that most operations in nnFormer operate on 1D sequences, we explore the possibility of transferring weights pre-trained on natural images to the medical imaging field. More concretely, we aim to reap the benefits of the weights of the MLP layers and QKV attention pre-trained on ImageNet. To this end, we align the channel numbers of the transformer blocks with those of the pre-trained models so that the weights of the MLP layers and QKV attention can be loaded. Besides, since the architectures of the encoder and decoder are highly symmetrical, we propose symmetrical initialization to reuse the pre-trained weights of the encoder in the decoder. Specifically, transformer blocks with the same input and output resolution are initialized using the same set of model weights (i.e., symmetrical transformer blocks of the encoder and decoder in Figure 2).
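Symmetrical initialization is essentially a copy keyed by resolution. In the sketch below, both the encoder and the decoder are represented as dicts mapping a resolution level to a list of per-block parameter dicts; this layout and the `"qkv"`/`"mlp"` name prefixes are hypothetical stand-ins, not nnFormer's checkpoint format. In PyTorch, the same idea would be applied to renamed `state_dict` entries loaded with `load_state_dict(..., strict=False)`.

```python
import copy


def symmetric_init(encoder_blocks, decoder_blocks, shared_keys=("qkv", "mlp")):
    """Copy encoder transformer-block weights into decoder blocks at the same resolution.

    Both arguments map a resolution level to a list of per-block parameter dicts.
    Only parameters whose name starts with one of `shared_keys` are copied, and
    decoder levels without an encoder counterpart are left untouched.
    """
    for level, dec_blocks in decoder_blocks.items():
        enc_blocks = encoder_blocks.get(level, [])
        for enc, dec in zip(enc_blocks, dec_blocks):
            for name, value in enc.items():
                if name.startswith(shared_keys):
                    # Independent copy so encoder and decoder are fine-tuned separately.
                    dec[name] = copy.deepcopy(value)
    return decoder_blocks


if __name__ == "__main__":
    # Toy "pre-trained" encoder with one block at level 0 (full resolution).
    encoder = {0: [{"qkv.weight": [[0.1, 0.2]], "mlp.fc1.weight": [[0.3]], "norm.weight": [1.0]}]}
    decoder = {0: [{"qkv.weight": None, "mlp.fc1.weight": None, "norm.weight": None}]}
    print(symmetric_init(encoder, decoder)[0][0])
```

Copying rather than sharing the tensors means the two blocks start from the same ImageNet weights but are free to diverge during training.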

Network configurations. In Table 1, we display network configurations of experiments on Synapse and ACDC. Compared to nnUNet, in nnFormer, better segmentation results can be achieved with smaller-sized input patches.

Methods Average Aorta Gallbladder Kidney(L) Kidney(R) Liver Pancreas Spleen Stomach
VNet [23] 68.81 75.34 51.87 77.10 80.75 87.84 40.04 80.56 56.98
DARR [9] 69.77 74.74 53.77 72.31 73.24 94.08 54.18 89.90 45.96
R50 U-Net [28] 74.68 87.74 63.66 80.60 78.19 93.74 56.90 85.87 74.16
U-Net [28] 76.85 89.07 69.72 77.77 68.60 93.43 53.98 86.67 75.58
R50 Att-UNet [29] 75.57 55.92 63.91 79.20 72.71 93.56 49.37 87.19 74.95
Att-UNet [24] 77.77 89.55 68.88 77.98 71.11 93.57 58.04 87.30 75.75
VIT None [8] 61.50 44.38 39.59 67.46 62.94 89.21 43.14 75.45 68.78
VIT CUP [8] 67.86 70.19 45.10 74.70 67.40 91.32 42.00 81.75 70.44
R50 VIT CUP [8] 71.29 73.73 55.13 75.80 72.20 91.51 45.99 81.99 73.95
R50-Deeplabv3+ [7] 75.73 86.18 60.42 81.18 75.27 92.86 51.06 88.69 70.19
DualNorm-UNet [38] 80.37 86.52 55.51 88.64 86.29 95.64 55.91 94.62 79.8
CGNET [37] 75.08 83.48 65.32 77.91 72.04 91.92 57.37 85.47 67.15
ContextNet [26] 71.17 79.92 51.17 77.58 72.04 91.74 43.78 86.65 66.51
DABNet [16] 74.91 85.01 56.89 77.84 72.45 93.05 54.39 88.23 71.45
EDANet [22] 75.43 84.35 62.31 76.16 71.65 93.20 53.19 85.47 77.12
ENet [25] 77.63 85.13 64.91 81.10 77.26 93.37 57.83 87.03 74.41
FPENet [20] 68.67 78.98 56.35 74.54 64.36 90.86 40.60 78.30 65.35
FSSNet [42] 74.59 82.87 64.06 78.03 69.63 92.52 53.10 85.65 70.86
SQNet [31] 73.76 83.55 61.17 76.87 69.40 91.53 56.55 85.82 65.24
FastSCNN [27] 70.53 77.79 55.96 73.61 67.38 91.68 44.54 84.51 68.76
TransUNet [6] 77.48 87.23 63.16 81.87 77.02 94.08 55.86 85.08 75.62
SwinUNet [3] 79.13 85.47 66.53 83.28 79.61 94.29 56.58 90.66 76.6
TransClaw U-Net [4] 78.09 85.87 61.38 84.83 79.36 94.28 57.65 87.74 73.55
LeVit-UNet-384s [40] 78.53 87.33 62.23 84.61 80.25 93.11 59.07 88.86 72.76
WAD [18] 80.30 87.73 69.93 83.95 79.78 93.95 61.02 88.86 77.16
nnUNet (3D) [12] 86.99 93.01 71.77 85.57 88.18 97.23 83.01 91.86 85.25
nnFormer 87.40 92.04 71.09 87.64 87.34 96.53 82.49 92.91 89.17
Table 2: Experiments on Synapse (dice score in %). Best results are bolded.

4.2 Experiments on Synapse

As shown in Table 2, we conduct experiments on Synapse to compare nnFormer against a variety of both transformer- and convnet-based baselines. The major evaluation metric is the Dice score.

Apart from nnUNet, the best performing convnet-based method is DualNorm-UNet [38], which achieves an average Dice score of 80.37. In comparison, WAD reports the best transformer-based result, with an average of 80.30, slightly lower than DualNorm-UNet. nnFormer outperforms both WAD and DualNorm-UNet by over 7 percentage points on average (7.10 and 7.03 points, respectively), which is a substantial improvement on Synapse. Besides, we found that the performance of nnUNet has been severely underestimated. When carefully tuned, nnUNet reaches an average Dice score of 86.99, which is much better than DualNorm-UNet and WAD but still lower than that of nnFormer (87.40).

Methods Average RV Myo LV
R50-U-Net [28] 87.55 87.10 80.63 94.92
R50-Attn UNet [29] 86.75 87.58 79.20 93.47
VIT-CUP [8] 81.45 81.46 70.71 92.18
R50-VIT-CUP [8] 87.57 86.07 81.88 94.75
CBAM[36] 87.30 87.70 82.10 92.20
ResUNet[10] 86.90 86.20 82.50 92.20
Dual-Attn[10] 87.00 86.40 82.30 92.40
UTNET[10] 88.30 88.20 83.50 93.10
TransUNet [6] 89.71 88.86 84.54 95.73
SwinUNet [3] 90.00 88.55 85.62 95.83
LeViT-UNet-384s [40] 90.32 89.55 87.64 93.76
nnUNet (3D) [12] 91.59 90.25 89.10 95.41
nnFormer 91.78 90.22 89.53 95.59
Table 3: Experiments on ACDC (dice score in %). Best results are bolded.

4.3 Experiments on ACDC

Table 3 presents the experimental results on ACDC, where transformer-based baselines generally perform better than convnet-based ones. We attribute this to the fact that ACDC images have far fewer slices along the depth axis (i.e., the spacing along this axis is quite large; see Table 1b), which is exactly the case where transformers have more advantages, as they are designed to handle 2D inputs with little interaction along the depth axis. From Table 3, we can see that the best transformer-based baseline is LeViT-UNet-384s, whose average Dice is slightly higher than that of SwinUNet and much higher than that of the convnet-based Dual-Attn. nnFormer in turn surpasses LeViT-UNet-384s by nearly 1.5 percentage points on average, again displaying its advantages over transformer-based baselines.

4.4 Ablation study

Methods Average Aorta Gallbladder Kidney(L) Kidney(R) Liver Pancreas Spleen Stomach
Patch-wise convolution 84.63 88.84 65.33 86.97 85.98 95.58 77.30 91.83 85.15
Ours 87.40 92.04 71.09 87.64 87.34 96.53 82.49 92.91 89.17
Table 4: Investigation of the embedding block on Synapse. Patch-wise convolution consists of only one convolutional layer with large kernel size and stride.
Methods Average Aorta Gallbladder Kidney(L) Kidney(R) Liver Pancreas Spleen Stomach
Neighboring concatenation 84.30 88.00 67.60 87.52 87.38 95.31 80.63 85.29 82.69
Ours 87.40 92.04 71.09 87.64 87.34 96.53 82.49 92.91 89.17
Table 5: Investigation of the convolutional down-sampling blocks.
Methods Average Aorta Gallbladder Kidney(L) Kidney(R) Liver Pancreas Spleen Stomach
More transformer blocks 85.98 89.02 71.74 86.76 87.06 96.37 82.30 89.04 85.51
Ours 87.40 92.04 71.09 87.64 87.34 96.53 82.49 92.91 89.17
Table 6: Investigation of adding more transformer blocks.
Methods Average Aorta Gallbladder Kidney(L) Kidney(R) Liver Pancreas Spleen Stomach
No pre-training 84.34 90.15 69.00 86.34 87.48 95.93 80.97 85.23 79.67
Ours 87.40 92.04 71.09 87.64 87.34 96.53 82.49 92.91 89.17
Table 7: Benefits of using pre-trained weights on natural images.

In this section, we examine the significance of the embedding block and of the convolutional down-sampling blocks. We also investigate the influence of adding more transformer blocks to the encoder and of pre-training on natural images.

Significance of the embedding block. To investigate the influence of our embedding block, we replace it with patch-wise convolution, as shown in Table 4. Patch-wise convolution [8, 21] contains only one convolutional layer with a large kernel size and stride. For instance, on Synapse, both the kernel size and the convolutional stride are set to [4,4,2]. We can see that the successive convolutional layers of the embedding block surpass the typical patch-wise convolution by nearly 3 percentage points on average. Similar observations were first reported in [30], where small kernels were found to be more effective than large ones.

Figure 3: Segmentation results of some hard samples on Synapse.
Figure 4: Segmentation results of some hard samples on ACDC.

Influence of convolutional down-sampling blocks. In Table 5, we report the results of using neighboring concatenation [21] in place of the convolutional down-sampling blocks in nnFormer. Neighboring concatenation concatenates neighboring patches of the feature maps along the channel dimension, after which a fully-connected layer reduces the number of channels. From Table 5, we can see that convolutional down-sampling blocks improve over neighboring concatenation by over 3 percentage points on average, demonstrating that interleaved convolutional down-sampling blocks are more helpful for building hierarchical object concepts at various scales.
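For comparison with the strided convolution, the neighboring-concatenation baseline can be sketched as follows. This is our own 3D extension of Swin Transformer's patch merging on a toy coordinate-to-features dict; the 2×2×2 neighborhood, the channel ordering and the width of the fully-connected layer are assumptions, since the text does not specify them.

```python
from itertools import product


def neighboring_concatenation(features, grid_shape):
    """Swin-style patch merging extended to 3D (without the final linear layer).

    `features` maps each token coordinate (h, w, d) to a list of C channel values.
    Every non-overlapping 2 x 2 x 2 neighborhood is concatenated along the channel
    axis, giving an (H/2) x (W/2) x (D/2) grid of 8C-channel tokens.
    """
    if any(g % 2 for g in grid_shape):
        raise ValueError("every grid axis must be even")
    merged = {}
    for h, w, d in product(*(range(0, g, 2) for g in grid_shape)):
        channels = []
        for dh, dw, dd in product((0, 1), repeat=3):
            channels.extend(features[(h + dh, w + dw, d + dd)])
        merged[(h // 2, w // 2, d // 2)] = channels
    return merged


def linear(x, weight):
    """Fully-connected layer without bias; `weight` has shape [out][in]."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in weight]


if __name__ == "__main__":
    # Toy 2 x 2 x 2 grid with one channel per token.
    feats = {(h, w, d): [h * 4 + w * 2 + d] for h, w, d in product(range(2), repeat=3)}
    merged = neighboring_concatenation(feats, (2, 2, 2))
    print(merged[(0, 0, 0)])                      # [0, 1, 2, 3, 4, 5, 6, 7]
    print(linear(merged[(0, 0, 0)], [[0.125] * 8]))  # [3.5]
```

The contrast with a strided convolution is that the merging step only rearranges values and mixes them through a per-token linear map, whereas a convolutional kernel can weight each neighbor's position differently and, with kernels larger than the stride, overlap adjacent neighborhoods.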

Investigation of adding more transformer blocks. We add one more SV-MSA block to the last encoder stage and report the results in Table 6. Generally speaking, more transformer blocks help nnFormer encode more long-term dependencies into its representations. Somewhat interestingly, Table 6 shows that capturing more long-term dependencies may not be the optimal choice for nnFormer. For instance, although introducing more transformer blocks achieves a higher Dice score on the gallbladder (the most difficult organ on Synapse), it deteriorates the segmentation performance on the other organs. We leave exploring the underlying reason to future work.

Influence of using models pre-trained on natural images. Table 7 shows that it is crucial to make use of weights pre-trained on natural images: removing pre-training deteriorates the overall segmentation performance by over 3 percentage points. We believe this is because Synapse does not have enough labeled 3D scans to fully realize the potential of nnFormer.

4.5 Visualization

In Figures 3 and 4, we compare the segmentation results of nnUNet and nnFormer on some hard samples. On Synapse, nnFormer appears to have clear advantages on the stomach, where nnUNet often fails to generate a complete delineation mask. Meanwhile, compared to nnUNet, nnFormer reduces false positive predictions of the spleen, which is consistent with the performance reported in Table 2.

Similar phenomena can be observed in Figure 4, where nnFormer greatly reduces false positive predictions of the right ventricle (RV) and myocardium (MYO), especially the myocardium. These segmentation results suggest that nnFormer produces more robust and discriminative representations than nnUNet.

5 Conclusion

In this paper, we present a new medical image segmentation network named nnFormer. nnFormer is constructed on top of an interleaved stem of convolution and self-attention, where convolution helps encode precise spatial information into high-resolution low-level features and build hierarchical object concepts at multiple scales. Self-attention in the transformer blocks, on the other hand, entangles long-term dependencies with convolutional representations to capture global context. Based on this hybrid architecture, nnFormer achieves substantial progress over previous transformer-based segmentation methodologies. Even when compared to nnUNet, currently the best performing segmentation network, nnFormer still provides consistent, if modest, improvements. In the future, we hope nnFormer will draw more attention from the medical imaging community to developing more efficient segmentation models.

References

  • [1] J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §3.2.
  • [2] O. Bernard, A. Lalande, C. Zotti, F. Cervenansky, X. Yang, P. Heng, I. Cetin, K. Lekadir, O. Camara, M. A. G. Ballester, et al. (2018) Deep learning techniques for automatic MRI cardiac multi-structures segmentation and diagnosis: is the problem solved? IEEE Transactions on Medical Imaging 37 (11), pp. 2514–2525. Cited by: §4.
  • [3] H. Cao, Y. Wang, J. Chen, D. Jiang, X. Zhang, Q. Tian, and M. Wang (2021) Swin-Unet: Unet-like pure transformer for medical image segmentation. arXiv preprint arXiv:2105.05537. Cited by: §1, §2, Table 2, Table 3.
  • [4] Y. Chang, H. Menghan, Z. Guangtao, and Z. Xiao-Ping (2021) TransClaw U-Net: Claw U-Net with transformers for medical image segmentation. arXiv preprint arXiv:2107.05188. Cited by: §1, §2, Table 2.
  • [5] B. Chen, Y. Liu, Z. Zhang, G. Lu, and D. Zhang (2021) TransAttUnet: multi-level attention-guided U-Net with transformer for medical image segmentation. arXiv preprint arXiv:2107.05274. Cited by: §1, §2.
  • [6] J. Chen, Y. Lu, Q. Yu, X. Luo, E. Adeli, Y. Wang, et al. (2021) TransUNet: transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306. Cited by: §1, §2, Table 2, Table 3, §4.
  • [7] L. C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille (2018) DeepLab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. IEEE Transactions on Pattern Analysis and Machine Intelligence 40 (4), pp. 834–848. Cited by: Table 2.
  • [8] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, et al. (2020) An image is worth 16x16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929. Cited by: §1, §2, §3.2, §3.2, §4.4, Table 2, Table 3.
  • [9] S. Fu, Y. Lu, Y. Wang, Y. Zhou, W. Shen, E. Fishman, and A. Yuille (2020) Domain adaptive relational reasoning for 3D multi-organ segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 656–666. Cited by: Table 2.
  • [10] Y. Gao, M. Zhou, and D. Metaxas (2021) UTNet: a hybrid transformer architecture for medical image segmentation. arXiv preprint arXiv:2107.00781. Cited by: Table 3.
  • [11] D. Hendrycks and K. Gimpel (2016) Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415. Cited by: §3.2.
  • [12] F. Isensee, P. F. Jäger, S. A. Kohl, J. Petersen, and K. H. Maier-Hein (2019) Automated design of deep learning methods for biomedical image segmentation. arXiv preprint arXiv:1904.08128. Cited by: Table 2, Table 3.
  • [13] D. Karimi, S. Vasylechko, and A. Gholipour (2021) Convolution-free medical image segmentation using transformers. arXiv preprint arXiv:2102.13645. Cited by: §1, §1, §2.
  • [14] B. Landman, Z. Xu, J. E. Iglesias, M. Styner, T. Langerak, and A. Klein (2015) MICCAI multi-atlas labeling beyond the cranial vault: workshop and challenge. In Proc. MICCAI: Multi-Atlas Labeling Beyond Cranial Vault-Workshop Challenge, Cited by: §4.
  • [15] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §1.
  • [16] G. Li, I. Yun, J. Kim, and J. Kim (2019) DABNet: depth-wise asymmetric bottleneck for real-time semantic segmentation. arXiv preprint arXiv:1907.11357. Cited by: Table 2.
  • [17] S. Li, X. Sui, X. Luo, X. Xu, Y. Liu, and R. S. M. Goh (2021) Medical image segmentation using squeeze-and-expansion transformers. arXiv preprint arXiv:2105.09511. Cited by: §2.
  • [18] Y. Li, W. Cai, Y. Gao, and X. Hu (2021) More than encoder: introducing transformer decoder to upsample. arXiv preprint arXiv:2106.10637. Cited by: §2, Table 2.
  • [19] A. Lin, B. Chen, J. Xu, Z. Zhang, and G. Lu (2021) DS-TransUNet: dual Swin transformer U-Net for medical image segmentation. arXiv preprint arXiv:2106.06716. Cited by: §1, §2.
  • [20] M. Liu and H. Yin (2019) Feature pyramid encoding network for real-time semantic segmentation. arXiv preprint arXiv:1909.08599. Cited by: Table 2.
  • [21] Z. Liu, Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo (2021) Swin transformer: hierarchical vision transformer using shifted windows. arXiv preprint arXiv:2103.14030. Cited by: §1, §1, §2, §3.2, §3.2, §3.2, §3.2, §4.4, §4.4.
  • [22] S. Y. Lo, H. M. Hang, S. W. Chan, and J. J. Lin (2019) Efficient dense modules of asymmetric convolution for real-time semantic segmentation. In MMAsia ’19: ACM Multimedia Asia, Cited by: Table 2.
  • [23] F. Milletari, N. Navab, and S. Ahmadi (2016) V-Net: fully convolutional neural networks for volumetric medical image segmentation. In 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. Cited by: Table 2.
  • [24] O. Oktay, J. Schlemper, L. L. Folgoc, M. Lee, M. Heinrich, K. Misawa, K. Mori, S. McDonagh, N. Y. Hammerla, B. Kainz, et al. (2018) Attention U-Net: learning where to look for the pancreas. arXiv preprint arXiv:1804.03999. Cited by: Table 2.
  • [25] A. Paszke, A. Chaurasia, S. Kim, and E. Culurciello (2016) ENet: a deep neural network architecture for real-time semantic segmentation. arXiv preprint arXiv:1606.02147. Cited by: Table 2.
  • [26] R. P. Poudel, U. Bonde, S. Liwicki, and C. Zach (2018) ContextNet: exploring context and detail for semantic segmentation in real-time. arXiv preprint arXiv:1805.04554. Cited by: Table 2.
  • [27] R. P. Poudel, S. Liwicki, and R. Cipolla (2019) Fast-SCNN: fast semantic segmentation network. arXiv preprint arXiv:1902.04502. Cited by: Table 2.
  • [28] O. Ronneberger, P. Fischer, and T. Brox (2015) U-Net: convolutional networks for biomedical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 234–241. Cited by: §1, §3.1, Table 2, Table 3.
  • [29] J. Schlemper, O. Oktay, M. Schaap, M. Heinrich, B. Kainz, B. Glocker, and D. Rueckert (2019) Attention gated networks: learning to leverage salient regions in medical images. Medical Image Analysis 53, pp. 197–207. Cited by: Table 2, Table 3.
  • [30] K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. Cited by: §4.4.
  • [31] M. Treml, J. Arjona-Medina, T. Unterthiner, R. Durgesh, and S. Hochreiter (2016) Speeding up semantic segmentation for autonomous driving. In NIPS 2016 Workshop - MLITS, Cited by: Table 2.
  • [32] S. Tuli, I. Dasgupta, E. Grant, and T. L. Griffiths (2021) Are convolutional neural networks or transformers more like human vision?. arXiv preprint arXiv:2105.07197. Cited by: §1.
  • [33] J. M. J. Valanarasu, P. Oza, I. Hacihaliloglu, and V. M. Patel (2021) Medical Transformer: gated axial-attention for medical image segmentation. arXiv preprint arXiv:2102.10662. Cited by: §1, §2.
  • [34] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, et al. (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008. Cited by: §1, §1, §2, §3.2, §3.2.
  • [35] W. Wang, C. Chen, M. Ding, J. Li, H. Yu, and S. Zha (2021) TransBTS: multimodal brain tumor segmentation using transformer. arXiv preprint arXiv:2103.04430. Cited by: §2.
  • [36] S. Woo, J. Park, J. Lee, and I. S. Kweon (2018) CBAM: convolutional block attention module. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 3–19. Cited by: Table 3.
  • [37] T. Wu, S. Tang, R. Zhang, J. Cao, and Y. Zhang (2021) CGNet: a light-weight context guided network for semantic segmentation. IEEE Transactions on Image Processing 30, pp. 1169–1179. Cited by: Table 2.
  • [38] J. Xiao, L. Yu, L. Xing, A. Yuille, and Y. Zhou (2021) DualNorm-UNet: incorporating global and local statistics for robust medical image segmentation. arXiv preprint arXiv:2103.15858. Cited by: §4.2, Table 2.
  • [39] Y. Xie, J. Zhang, C. Shen, and Y. Xia (2021) CoTr: efficiently bridging CNN and transformer for 3D medical image segmentation. arXiv preprint arXiv:2103.03024. Cited by: §2.
  • [40] G. Xu, X. Wu, X. Zhang, and X. He (2021) LeViT-UNet: make faster encoders with transformer for medical image segmentation. arXiv preprint arXiv:2107.08623. Cited by: §2, Table 2, Table 3.
  • [41] B. Yun, Y. Wang, J. Chen, H. Wang, W. Shen, and Q. Li (2021) SpecTr: spectral transformer for hyperspectral pathology image segmentation. arXiv preprint arXiv:2103.03604. Cited by: §2.
  • [42] X. Zhang, Z. Chen, Q. Wu, L. Cai, D. Lu, and X. Li (2018) Fast semantic segmentation for scene perception. IEEE Transactions on Industrial Informatics. Cited by: Table 2.
  • [43] Y. Zhang, H. Liu, and Q. Hu (2021) TransFuse: fusing transformers and CNNs for medical image segmentation. arXiv preprint arXiv:2102.08005. Cited by: §1, §2.
  • [44] H. Zhou, C. Lu, S. Yang, and Y. Yu (2021) ConvNets vs. Transformers: whose visual representations are more transferable?. arXiv preprint arXiv:2108.05305. Cited by: §1.