Image segmentation is a longstanding challenge in medical image analysis. Since the introduction of U-Net, fully convolutional neural networks (CNNs) have become the predominant approach to this task [25, 28, 12, 10, 24, 23]. Despite their prevalence, CNNs still suffer from a limited receptive field and fail to capture long-range dependencies, owing to the inductive biases of locality and weight sharing. Many efforts have been devoted to enlarging a CNN's receptive field and thereby improving its ability to model context. Yu et al. proposed the atrous convolution with an adjustable dilation rate, which shows superior performance in semantic segmentation. More straightforwardly, Peng et al. designed large kernels to capture rich global context information. Zhao et al. employed pyramid pooling at multiple feature scales to aggregate multi-scale global information. Wang et al. presented the non-local operation, which is usually embedded at the end of the encoder to capture long-range dependencies. Although they improve context modeling to some extent, these models still have an inevitably limited receptive field, constrained by the CNN architecture.
Transformer, a sequence-to-sequence prediction framework, has a proven track record in machine translation and natural language processing [19, 8], owing to its strong ability to model long-range dependencies. The self-attention mechanism in Transformer can dynamically adjust the receptive field according to the input content, and hence is superior to convolutional operations in modeling long-range dependencies.
Recently, Transformer has been considered as an alternative architecture and has achieved competitive performance on many computer vision tasks, such as image recognition [9, 17], semantic/instance segmentation [27, 21], object detection [2, 29], low-level vision [14, 3], and image generation. A typical example is the vision Transformer (ViT), which outperforms a ResNet-based CNN on recognition tasks but at the cost of using 300M images for training. Since a huge training dataset is not always available, recent studies attempt to combine a CNN and a Transformer into a hybrid model. Carion et al. employed a CNN to extract image features and a Transformer to further process the extracted features. Chen et al. designed TransUNet, in which a CNN and a Transformer are combined in a cascade manner to make a strong encoder for 2D medical image segmentation. Although the design of TransUNet is interesting and its performance is good, the model is challenging to optimize because of the self-attention mechanism. First, a long training time is required before the attention, which is initially cast uniformly on every pixel, focuses on salient locations, especially in a 3D scenario. Second, due to its high computational complexity, a vanilla Transformer can hardly process multi-scale and high-resolution feature maps, which play a critical role in image segmentation.
In this paper, we propose a hybrid framework that efficiently bridges Convolutional neural network and Transformer (CoTr) for 3D medical image segmentation. CoTr has an encoder-decoder structure. In the encoder, a concise CNN structure is adopted to extract feature maps and a Transformer is used to capture long-range dependencies (see Fig. 1). Inspired by [7, 29], we introduce the deformable self-attention mechanism to the Transformer. This mechanism attends only to a small set of key sampling points, and thus dramatically reduces the computational and spatial complexity of the Transformer. As a result, the Transformer can process the multi-scale feature maps produced by the CNN and keep abundant high-resolution information for segmentation. The main contributions of this paper are three-fold: (1) we are the first to explore Transformer for 3D medical image segmentation, particularly in a computationally and spatially efficient way; (2) we introduce the deformable self-attention mechanism to reduce the complexity of the vanilla Transformer, and thus enable our CoTr to model long-range dependencies using multi-scale features; (3) our CoTr outperforms the competing CNN-based, Transformer-based, and hybrid methods on the 3D multi-organ segmentation task.
The Multi-Atlas Labeling Beyond the Cranial Vault (BCV) dataset (https://www.synapse.org/#!Synapse:syn3193805/wiki/217789) was used for this study. It contains 30 labeled CT scans for automated segmentation of 11 abdominal organs, including the spleen (Sp), kidney (Ki), gallbladder (Gb), esophagus (Es), liver (Li), stomach (St), aorta (Ao), inferior vena cava (IVC), portal vein and splenic vein (PSV), pancreas (Pa), and adrenal gland (AG).
CoTr aims to learn more effective representations for medical image segmentation by bridging CNN and Transformer. As shown in Fig. 2, it consists of a CNN-encoder for feature extraction, a deformable Transformer-encoder (DeTrans-encoder) for long-range dependency modeling, and a decoder for segmentation. We now delve into the details of each module.
The CNN-encoder contains a Conv-IN-ReLU block and three stages of 3D residual blocks. The Conv-IN-ReLU block contains a 3D convolutional layer followed by instance normalization (IN) and a Rectified Linear Unit (ReLU) activation. The numbers of 3D residual blocks in the three stages are three, three, and two, respectively.
Given an input image with a height of , a width of , and a depth (i.e., number of slices) of , the feature maps produced by the CNN-encoder can be formally expressed as
where indicates the number of feature levels, denotes the parameters of the CNN-encoder, and denotes the number of channels.
Due to the intrinsic locality of convolution operations, the CNN-encoder cannot capture the long-range dependency of pixels effectively. To this end, we propose the DeTrans-encoder that introduces the multi-scale deformable self-attention (MS-DMSA) mechanism for efficient long-range contextual modeling. The DeTrans-encoder is a composition of an input-to-sequence layer and stacked deformable Transformer (DeTrans) layers.
Input-to-sequence Transformation. Considering that Transformer processes the information in a sequence-to-sequence manner, we first flatten the feature maps produced by the CNN-encoder into a 1D sequence. Unfortunately, the operation of flattening the features leads to losing the spatial information that is critical for image segmentation. To address this issue, we supplement the 3D positional encoding sequence to the flattened . For this study, we use sine and cosine functions with different frequencies  to compute the positional coordinates of each dimension , shown as follows
where indicates each of three dimensions, . For each feature level , we concatenate , , and as the 3D positional encoding and combine it with the flattened via element-wise summation to form the input sequence of DeTrans-encoder.
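The positional encoding step can be illustrated with a small, dependency-free sketch. It assumes the standard Transformer sine/cosine recipe with a temperature of 10000 and an equal split of channels across the three axes; neither detail is stated in the text above, and the function names (`sincos_1d`, `positional_encoding_3d`, `add_positional_encoding`) are hypothetical.

```python
import math


def sincos_1d(pos, num_feats, temperature=10000.0):
    """Sine/cosine encoding of one scalar position (standard Transformer recipe)."""
    enc = []
    for i in range(num_feats):
        # Even and odd channels of a pair share one frequency (assumed convention).
        freq = temperature ** (2 * (i // 2) / num_feats)
        angle = pos / freq
        enc.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return enc


def positional_encoding_3d(depth, height, width, channels):
    """Return one encoding vector per voxel, flattened in (depth, height, width) order.

    Each axis receives channels // 3 features and the three per-axis encodings
    are concatenated. The equal split is an assumption of this sketch.
    """
    assert channels % 3 == 0, "channels must be divisible by 3 in this sketch"
    per_axis = channels // 3
    seq = []
    for d in range(depth):
        for h in range(height):
            for w in range(width):
                seq.append(
                    sincos_1d(d, per_axis)
                    + sincos_1d(h, per_axis)
                    + sincos_1d(w, per_axis)
                )
    return seq


def add_positional_encoding(flat_features, pos_seq):
    """Element-wise summation of flattened features and positional encodings."""
    return [[f + p for f, p in zip(fv, pv)] for fv, pv in zip(flat_features, pos_seq)]
```

A real implementation would build these tensors once per feature level with a tensor library; the nested loops here only make the flattening order and the concatenation explicit.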
MS-DMSA Layer. In the architecture of Transformer, the self-attention layer looks over all possible locations in the feature map. This leads to slow convergence and high computational complexity, and hence the layer can hardly process multi-scale features. To remedy this, we design the MS-DMSA layer, which focuses only on a small set of key sampling locations around a reference location instead of all locations.
Let be the feature representation of query and be the normalized 3D coordinate of the reference point. Given the multi-scale feature maps that are extracted in the last stages of CNN-encoder, the feature representation of the -th attention head can be calculated as
where is the number of sampled key points, is the attention weight, is the sampling offset of the -th sampling point in the -th feature level, and re-scales to the -th level feature. Following , both and are obtained via linear projection over the query feature . Then, the MS-DMSA layer can be formulated as
where is the number of attention heads, and is a linear projection layer that weights and aggregates the feature representation of all attention heads.
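To make the sampling idea concrete, here is a deliberately simplified sketch of one attention head for one query. It works on 1D feature grids rather than 3D volumes, takes the sampling offsets and attention logits as arguments (the text obtains both by linear projection of the query feature), and omits the value projection and the multi-head output projection. The softmax over all levels and sampling points follows the usual deformable-attention convention and is an assumption here, as are the names `sample_linear` and `deformable_head`.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a flat list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_linear(level, x):
    """Linearly interpolate a list of feature vectors at continuous index x.

    Positions outside the grid are clamped to the border.
    """
    x = min(max(x, 0.0), len(level) - 1.0)
    lo = int(math.floor(x))
    hi = min(lo + 1, len(level) - 1)
    t = x - lo
    return [(1 - t) * a + t * b for a, b in zip(level[lo], level[hi])]


def deformable_head(levels, ref, offsets, logits):
    """One simplified attention head for a single query.

    levels  : list of feature levels, each a list of feature vectors (1D grid)
    ref     : normalized reference coordinate in [0, 1]
    offsets : offsets[l][k], sampling offset in grid units at level l
    logits  : logits[l][k], unnormalized attention scores
    """
    # Attention weights are normalized jointly over all levels and points.
    weights = softmax([s for row in logits for s in row])
    out = [0.0] * len(levels[0][0])
    idx = 0
    for l, level in enumerate(levels):
        center = ref * (len(level) - 1)  # re-scale the reference to this level
        for k in range(len(offsets[l])):
            value = sample_linear(level, center + offsets[l][k])
            out = [o + weights[idx] * v for o, v in zip(out, value)]
            idx += 1
    return out
```

The cost per query depends only on the number of levels and sampled points, not on the size of the feature maps, which is what allows multi-scale, high-resolution inputs.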
DeTrans Layer. The DeTrans layer is composed of an MS-DMSA layer and a feed forward network, each followed by layer normalization (see Fig. 2). A skip connection is employed in each sub-layer to avoid vanishing gradients. The DeTrans-encoder is constructed by repeatedly stacking DeTrans layers.
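The layer structure described above (sub-layer, skip connection, then layer normalization) can be sketched as follows. This is a minimal illustration on a single feature vector, assuming a post-normalization arrangement without learnable scale and shift; `attention` and `feed_forward` are placeholder callables standing in for the MS-DMSA layer and the feed forward network.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learnable affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def detrans_layer(x, attention, feed_forward):
    """Apply attention and feed forward sub-layers, each with skip + layer norm."""
    x = layer_norm([a + b for a, b in zip(x, attention(x))])
    x = layer_norm([a + b for a, b in zip(x, feed_forward(x))])
    return x


def detrans_encoder(x, layers):
    """Stack several layers; `layers` is a list of (attention, feed_forward) pairs."""
    for attention, feed_forward in layers:
        x = detrans_layer(x, attention, feed_forward)
    return x
```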
The output sequence of the DeTrans-encoder is reshaped into feature maps according to the size at each scale. The decoder, a pure CNN architecture, progressively upsamples the feature maps to the input resolution using transposed convolutions, and then refines the upsampled feature maps using a 3D residual block. In addition, skip connections between the encoder and decoder are added to keep more low-level details for better segmentation. We also use the deep supervision strategy by adding auxiliary losses to the decoder outputs at different scales. The loss function of our model is the sum of the Dice loss and cross-entropy loss [12, 24, 28]. More details on the network architecture are given in the Appendix.
3.4 Implementation details
Following , we first truncated the HU values of each scan using the range of to filter irrelevant regions, and then normalized the truncated voxel values by subtracting 82.92 and dividing by 136.97. We randomly split the BCV dataset into two parts, 21 scans for training and 9 scans for testing, and randomly selected 6 training scans to form a validation set, which was used only to select the hyper-parameters of CoTr. The final results on the test set were obtained by the model trained on all training scans.
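The preprocessing and data split can be sketched as below. The truncation bounds are left as arguments (`hu_min`, `hu_max`) rather than guessed; the normalization constants and split sizes come from the text. The function names and the fixed random seed are assumptions for illustration only.

```python
import random


def preprocess_voxels(voxels, hu_min, hu_max, mean=82.92, std=136.97):
    """Truncate HU values to [hu_min, hu_max], then subtract mean and divide by std."""
    clipped = [min(max(v, hu_min), hu_max) for v in voxels]
    return [(v - mean) / std for v in clipped]


def split_scans(scan_ids, n_test=9, n_val=6, seed=0):
    """Randomly split scans into train/test, then draw a validation subset of train.

    The validation scans stay inside the training list, mirroring the text:
    they are used for hyper-parameter selection, and the final model is
    trained on all training scans.
    """
    rng = random.Random(seed)
    ids = list(scan_ids)
    rng.shuffle(ids)
    test, train = ids[:n_test], ids[n_test:]
    val = rng.sample(train, n_val)
    return train, val, test
```

For example, `split_scans(range(30))` yields 21 training scans, 6 of which double as validation scans, and 9 test scans.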
In the training stage, we randomly cropped sub-volumes of size from CT scans as the input. To alleviate over-fitting on the limited training data, we employed online data augmentation, including random rotation, scaling, flipping, adding white Gaussian noise, Gaussian blurring, adjusting brightness and contrast, simulation of low resolution, and Gamma transformation, to diversify the training set. Owing to the benefits of instance normalization, we adopted the micro-batch training strategy with a small batch size of 2. To balance training time against performance, CoTr was trained for 1000 epochs, each containing 250 iterations. We adopted the stochastic gradient descent algorithm with a momentum of 0.99 and an initial learning rate of 0.01 as the optimizer. We set the hidden sizes in the MS-DMSA and feed forward network to 384 and 1536, respectively, and empirically set the hyper-parameters, , and . Besides, we formed two variants of CoTr with small CNN-encoders, denoted as CoTr and CoTr. In CoTr, there is only one 3D residual block in each stage of the CNN-encoder. In CoTr, the number of 3D residual blocks in each stage of the CNN-encoder is two.
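The stated training settings can be collected in a small configuration object, which also makes the implied totals explicit. The class and field names are hypothetical, and the update rule shown is plain SGD with momentum in the common "velocity accumulates the gradient" form; whether the original training used Nesterov momentum or a learning-rate schedule is not stated, so neither is modeled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Training settings quoted in the text (field names are illustrative)."""
    batch_size: int = 2
    epochs: int = 1000
    iterations_per_epoch: int = 250
    momentum: float = 0.99
    initial_lr: float = 0.01
    attention_hidden_size: int = 384
    ffn_hidden_size: int = 1536

    @property
    def total_iterations(self) -> int:
        return self.epochs * self.iterations_per_epoch

    @property
    def sub_volumes_seen(self) -> int:
        return self.total_iterations * self.batch_size


def sgd_momentum_step(param, grad, velocity, lr, momentum):
    """One SGD-with-momentum update for a scalar parameter."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity
```

With these settings, training runs for 250,000 iterations and draws 500,000 randomly cropped sub-volumes in total.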
In the test stage, we employed the sliding window strategy, where the window size equals the training patch size. Besides, Gaussian importance weighting and test-time augmentation by flipping along all axes were also utilized to improve the robustness of segmentation. To quantitatively evaluate the segmentation results, we calculated the Dice coefficient (Dice), which measures the overlap between a prediction and its ground truth.
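Sliding-window inference with Gaussian importance weighting, together with the Dice metric, can be sketched in 1D as follows. The Gaussian width (one eighth of the window size), the stride handling, and the convention that two empty masks score 1.0 are assumptions of this sketch; `predict` is a placeholder for the trained network, and flip-based test-time augmentation is left out.

```python
import math


def gaussian_weights(size, sigma_scale=0.125):
    """Weights that favor the window center over its borders."""
    center = (size - 1) / 2.0
    sigma = size * sigma_scale
    return [math.exp(-0.5 * ((i - center) / sigma) ** 2) for i in range(size)]


def window_starts(length, window, stride):
    """Start indices that cover the whole signal, adding a final flush window."""
    assert 0 < stride <= window, "stride must be positive and not exceed the window"
    if length <= window:
        return [0]
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def sliding_window_predict(signal, window, stride, predict):
    """Blend overlapping window predictions using Gaussian importance weights."""
    weights = gaussian_weights(window)
    acc = [0.0] * len(signal)
    norm = [0.0] * len(signal)
    for s in window_starts(len(signal), window, stride):
        for i, p in enumerate(predict(signal[s:s + window])):
            acc[s + i] += weights[i] * p
            norm[s + i] += weights[i]
    return [a / n for a, n in zip(acc, norm)]


def dice_score(pred, target):
    """Dice coefficient between two binary masks given as 0/1 lists."""
    inter = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 if total == 0 else 2.0 * inter / total
```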
|Methods|Param (M)|Sp|Ki|Gb|Es|Li|St|Ao|IVC|PSV|Pa|AG|Ave|
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|SETR (ViT-B/16-rand)|100.5|95.2|92.3|55.6|71.3|96.2|80.2|89.7|83.9|68.9|68.7|60.5|78.4|
|SETR (ViT-B/16-pre)|100.5|94.8|91.7|55.2|70.9|96.2|76.9|89.3|82.4|69.6|70.7|58.7|77.8|
|CoTr w/o CNN-encoder|21.9|95.2|92.8|59.2|72.2|96.3|81.2|89.9|85.1|71.9|73.3|61.0|79.8|
|CoTr w/o DeTrans|32.6|96.0|92.6|63.8|77.9|97.0|83.6|90.8|87.8|76.7|81.2|72.6|83.6|
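As a quick consistency check on how to read these rows, the snippet below recomputes the last value of each row as the mean of the eleven values before it. It assumes that the first number in a row is the parameter count in millions, that the next eleven are per-organ Dice scores in the order the organs are listed in the dataset description, and that the last number is their average; this column reading is an assumption, not something the extracted rows state.

```python
# Each entry: (eleven per-organ Dice scores, reported average), copied from the rows above.
rows = {
    "SETR (ViT-B/16-rand)": ([95.2, 92.3, 55.6, 71.3, 96.2, 80.2, 89.7, 83.9, 68.9, 68.7, 60.5], 78.4),
    "SETR (ViT-B/16-pre)": ([94.8, 91.7, 55.2, 70.9, 96.2, 76.9, 89.3, 82.4, 69.6, 70.7, 58.7], 77.8),
    "CoTr w/o CNN-encoder": ([95.2, 92.8, 59.2, 72.2, 96.3, 81.2, 89.9, 85.1, 71.9, 73.3, 61.0], 79.8),
    "CoTr w/o DeTrans": ([96.0, 92.6, 63.8, 77.9, 97.0, 83.6, 90.8, 87.8, 76.7, 81.2, 72.6], 83.6),
}


def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


for name, (organs, reported) in rows.items():
    print(f"{name}: recomputed {mean(organs):.2f}, reported {reported}")
```

The recomputed means agree with the reported averages to within the rounding of the per-organ entries, which supports this reading of the columns.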
Comparing to models with only Transformer encoder. We first evaluated our CoTr against two variants of the state-of-the-art SEgmentation Transformer (SETR), which were formed by using randomly initialized and pre-trained ViT-B/16 as the encoder. We also compared to a variant of CoTr that removes the CNN-encoder (CoTr w/o CNN-encoder). To ensure a fair comparison, all models use the same decoder. The segmentation performance of these models is shown in Table 1, from which three conclusions can be drawn. First, although the Transformer architecture is not limited by the type of input images, the ViT-B/16 pre-trained on 2D natural images does not work well on 3D medical images. The suboptimal performance may be attributed to the domain shift between 2D natural images and 3D medical images. Second, ‘CoTr w/o CNN-encoder’ has about 22M parameters and outperforms the SETR with about 100M parameters. We believe that a lightweight Transformer may be better suited to medical image segmentation tasks, where the training dataset is usually small. Third, our CoTr with a comparable number of parameters significantly outperforms ‘CoTr w/o CNN-encoder’, improving the average Dice over 11 organs by 4%. This suggests that the hybrid CNN-Transformer encoder has distinct advantages over the pure Transformer encoder in medical image segmentation.
Comparing to models with only CNN encoder. Then, we compared CoTr against a variant of CoTr that removes the DeTrans-encoder (CoTr w/o DeTrans) and three CNN-based context modeling methods, i.e., the Atrous Spatial Pyramid Pooling (ASPP) module, pyramid parsing (PP) module, and Non-local module. For a fair comparison, we used the same CNN-encoder and decoder but replaced our DeTrans-encoder with ASPP, PP, and Non-local modules, respectively. The results in Table 1 show that our CoTr consistently improves the segmentation performance over ‘CoTr w/o DeTrans’ on all organs and improves the average Dice by 1.4%. This corroborates that a hybrid CNN-Transformer encoder has a stronger ability than a pure CNN encoder to learn effective representations for medical image segmentation. Moreover, compared with these context modeling methods, our Transformer architecture contributes to more accurate segmentation.
Comparing to models with hybrid CNN-Transformer encoder. We also compared CoTr to other hybrid CNN-Transformer architectures such as TransUNet. To process 3D images directly, we extended the original 2D TransUNet to a 3D version by using a 3D CNN-encoder and decoder, as done in CoTr. We also set the number of heads and layers of the Transformer in 3D TransUNet to be the same as in our CoTr. Table 1 shows that CoTr consistently beats TransUNet in the segmentation of all organs, particularly the gallbladder and pancreas. Even with a smaller CNN-encoder, CoTr still achieves better performance than TransUNet in the segmentation of seven organs. The superior performance is attributed to the deformable mechanism in CoTr, whose reduced computational and spatial complexities make it possible to process high-resolution and multi-scale feature maps.
The proposed CoTr was trained on a workstation with an NVIDIA GeForce RTX 2080 Ti GPU using the PyTorch software package. It took about 2 days for training, and less than 30 ms to segment a volume of size.
5 Discussion on Hyper-parameter Settings
In the DeTrans-encoder, there are three hyper-parameters: the number of sampled key points, the number of heads, and the number of stacked DeTrans layers. To investigate the impact of their settings on the segmentation, we set the number of sampled key points to 1, 2, and 4, the number of heads to 2, 4, and 6, and the number of DeTrans layers to 2, 4, and 6. In Fig. 3 (a-c), we plotted the average Dice over all organs obtained on the validation set versus the value of each hyper-parameter. It shows that increasing any of the three can improve the segmentation performance. To demonstrate the performance gain resulting from the multi-scale strategy, we also trained CoTr with single-scale feature maps from the last stage. The results in Fig. 3 (d) show that using multi-scale feature maps instead of single-scale feature maps improves the average Dice by 1.2%.
In this paper, we propose a hybrid model of CNN and Transformer, namely CoTr, for 3D medical image segmentation. In this model, we design the deformable Transformer (DeTrans) that employs the deformable self-attention mechanism to reduce the computational and spatial complexities of modeling long-range dependencies on multi-scale and high-resolution feature maps. Comparative experiments were conducted on the BCV dataset. The superior performance of our CoTr over both CNN-based and vanilla Transformer-based models suggests that, by combining the advantages of CNN and Transformer, the proposed CoTr strikes a balance between keeping the details of low-level features and modeling long-range dependencies. As a stronger baseline, our CoTr can be extended to deal with other structures (e.g., brain structure or tumor segmentation) in the future.
[1] Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
[2] End-to-end object detection with transformers. In European Conference on Computer Vision, pp. 213–229, 2020.
[3] Pre-trained image processing transformer. arXiv preprint arXiv:2012.00364, 2020.
[4] TransUNet: transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306, 2021.
[5] Encoder-decoder with atrous separable convolution for semantic image segmentation. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 801–818, 2018.
[6] Inductive bias of deep convolutional networks through pooling geometry. arXiv preprint arXiv:1605.06743, 2016.
[7] Deformable convolutional networks. In Proceedings of the IEEE International Conference on Computer Vision, pp. 764–773, 2017.
[8] BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
[9] An image is worth 16x16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
[10] Multi-organ segmentation over partially labeled datasets with multi-scale feature abstraction. IEEE Transactions on Medical Imaging 39 (11), pp. 3619–3629, 2020.
[11] Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
[12] Automated design of deep learning methods for biomedical image segmentation. arXiv preprint arXiv:1904.08128, 2019.
[13] TransGAN: two transformers can make one strong GAN. arXiv preprint arXiv:2102.07074, 2021.
[14] In International Conference on Machine Learning, pp. 4055–4064.
[15] Large kernel matters–improve semantic segmentation by global convolutional network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4353–4361, 2017.
[16] U-Net: convolutional networks for biomedical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 234–241, 2015.
[17] Training data-efficient image transformers & distillation through attention. arXiv preprint arXiv:2012.12877, 2020.
[18] Instance normalization: the missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
[19] Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems, pp. 6000–6010, 2017.
[20] Non-local neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7794–7803, 2018.
[21] End-to-end video instance segmentation with transformers. arXiv preprint arXiv:2011.14503, 2020.
[22] Multi-scale context aggregation by dilated convolutions. In International Conference on Learning Representations (ICLR), 2016.
[23] Inter-slice context residual learning for 3D medical image segmentation. IEEE Transactions on Medical Imaging, 2020.
[24] Light-weight hybrid convolutional network for liver tumor segmentation. In IJCAI, pp. 4271–4277, 2019.
[25] Block level skip connections across cascaded V-Net for multi-organ segmentation. IEEE Transactions on Medical Imaging 39 (9), pp. 2782–2793, 2020.
[26] Pyramid scene parsing network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2881–2890, 2017.
[27] Rethinking semantic segmentation from a sequence-to-sequence perspective with transformers. arXiv preprint arXiv:2012.15840, 2020.
[28] UNet++: redesigning skip connections to exploit multiscale features in image segmentation. IEEE Transactions on Medical Imaging 39 (6), pp. 1856–1867, 2019.
[29] Deformable DETR: deformable transformers for end-to-end object detection. arXiv preprint arXiv:2010.04159, 2020.
7.1 Detailed network architecture
Fig. 4 shows the architecture of the CNN-encoder, the decoder, and the feed forward network in the DeTrans-encoder. The CNN-encoder consists of a Conv-IN-ReLU block and three stages of 3D residual blocks. The numbers of 3D residual blocks in the three stages are three, three, and two, respectively. The decoder contains four upsampling modules. Each of the first three modules has a TransConv layer followed by a residual block, together with a pixel-wise summation of the TransConv output and the corresponding feature maps from the encoder. The last module comprises an Upsampling layer followed by a 1 × 1 Conv layer that maps the 64-channel feature maps to the desired number of classes. The feed forward network in the DeTrans-encoder has two linear projection layers. The first layer is followed by a layer normalization layer and a Dropout layer. The second layer is followed by a Dropout layer.
7.2 Loss function
We jointly use the Dice loss and cross-entropy loss for optimization, which is popular in many medical image segmentation applications and has achieved prominent success [16, 28, 24]. The loss function is formulated as
where the first term is the soft Dice loss, the second term is the cross-entropy loss, the prediction and ground truth are denoted by and , respectively, is the expectation operation, is a smoothing factor, and is the number of categories. To speed up convergence and alleviate the vanishing gradient problem, we also use the deep supervision strategy that adds auxiliary losses to the decoder outputs with different resolutions. The total loss function is the sum of the losses at all resolutions.
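A common way to combine the two terms, with deep supervision on top, is sketched below. It is one plausible formulation under stated assumptions, not necessarily the exact formula used here: the Dice term is averaged over classes with a small smoothing constant, the cross-entropy term is averaged over voxels, and the losses at all resolutions are summed with equal weight. All function names and the value of `eps` are hypothetical.

```python
import math


def soft_dice_loss(probs, onehot, eps=1e-5):
    """Soft Dice loss; probs and onehot are indexed as [class][voxel]."""
    total = 0.0
    for p_c, y_c in zip(probs, onehot):
        inter = sum(p * y for p, y in zip(p_c, y_c))
        denom = sum(p_c) + sum(y_c)
        total += (2.0 * inter + eps) / (denom + eps)
    return 1.0 - total / len(probs)


def cross_entropy_loss(probs, onehot, tiny=1e-12):
    """Voxel-averaged cross-entropy; `tiny` guards against log(0)."""
    num_voxels = len(probs[0])
    total = 0.0
    for p_c, y_c in zip(probs, onehot):
        total -= sum(y * math.log(p + tiny) for p, y in zip(p_c, y_c))
    return total / num_voxels


def segmentation_loss(probs, onehot):
    """Sum of the soft Dice loss and the cross-entropy loss."""
    return soft_dice_loss(probs, onehot) + cross_entropy_loss(probs, onehot)


def deep_supervision_loss(outputs):
    """Sum the loss over decoder outputs; `outputs` is a list of (probs, onehot)."""
    return sum(segmentation_loss(p, y) for p, y in outputs)
```

In practice each auxiliary output is compared against a ground truth resampled to its own resolution, which is why every pair in `outputs` carries its own target.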
The segmentation results produced by (1) SETR with pre-trained ViT-B/16, (2) replacing the DeTrans-encoder with the ASPP module, (3) 3D TransUNet, and (4) our CoTr are visually compared in Fig. 5. We can see that: 1) compared with the pure Transformer encoder method (SETR) and the pure CNN encoder method (ASPP), our CoTr with the hybrid CNN-Transformer encoder produces segmentation results that are more similar to the ground truth, and 2) our CoTr tends to produce fewer false positives than TransUNet, which supports the advantage of our 3D deformable Transformer over the vanilla Transformer.