
Attentive Symmetric Autoencoder for Brain MRI Segmentation

Self-supervised learning methods based on image patch reconstruction have achieved great success in training auto-encoders, whose pre-trained weights can be transferred and fine-tuned for downstream image understanding tasks. However, when applied to 3D medical images, existing methods seldom consider the varying importance of reconstructed patches or the symmetry of anatomical structures. In this paper we propose a novel Attentive Symmetric Auto-encoder (ASA) based on Vision Transformer (ViT) for 3D brain MRI segmentation tasks. We conjecture that forcing the auto-encoder to recover informative image regions yields more discriminative representations than recovering smooth image patches. We therefore adopt a gradient-based metric to estimate the importance of each image patch. In the pre-training stage, the proposed auto-encoder pays more attention to reconstructing the informative patches according to the gradient metric. Moreover, we exploit the prior of brain structures and develop a Symmetric Position Encoding (SPE) method that better captures the correlations between long-range but spatially symmetric regions to obtain effective features. Experimental results show that our proposed attentive symmetric auto-encoder outperforms the state-of-the-art self-supervised learning methods and medical image segmentation models on three brain MRI segmentation benchmarks.


1 Introduction

Accurate segmentation of brain lesions, tumours or tissues in Magnetic Resonance Imaging (MRI) data is essential for building computer-aided diagnosis (CAD) systems, and helps medical experts improve diagnosis and treatment planning. An automatic segmentation tool for brain MRI is therefore needed.

Deep convolutional neural networks (DCNNs) have achieved success in brain MRI segmentation [cciccek20163d, ronneberger2015u, zhou2018unet++], but their local receptive fields fail to capture long-range spatial dependencies. Recently, transformer-based models [dosovitskiy2020image, li2022view] have drawn extensive attention and shown state-of-the-art results on 3D image segmentation [wang2021transbts, hatamizadeh2022unetr, zhou2021nnformer]. These methods collect dense correlations between long-range voxels for representation learning, but they require numerous voxel-level annotations, which are scarce for brain medical images. Self-supervised learning (SSL) [taleb2021multimodal, taleb20203d, tao2020revisiting] uses unlabeled data to pre-train a model that can be fine-tuned to improve the results on downstream tasks. Recently, reconstruction-based SSL methods [he2021masked, wei2022masked] have been proposed, which pre-train transformers to recover image patches from natural images. If these methods are applied to 3D medical images, they may fail to model the prior of a brain because they treat all recovered patches equally. Some recent work [Tang_2022_CVPR] pre-trains transformers for medical images, but it neglects the symmetry of brain structures and the different importance of brain regions.

Motivated by the above observations, we consider a novel transformer-based SSL framework for brain MRI segmentation. Despite individual variations, the structure of brain tissues is relatively stable, while lesions have their particular textures and appearance. During SSL, reconstructing a smooth brain region is not challenging and may cause over-fitting. In contrast, synthesizing an informative image patch is more difficult and requires mining the intrinsic representations of anatomical structures. In this work, we propose an attentive reconstruction loss that weights different image regions by their informativeness, which is measured by a handcrafted gradient-based score. Moreover, symmetry is an essential prior of brain structure. As transformers encode the coordinates of image patches for computing correlations between different positions, we use this symmetry to design a new position encoding method that returns the same code for two distant but symmetrical positions. Transformers with this encoding can enhance the visual features by emphasizing the correlations between contralateral brain regions. Finally, we integrate the proposed loss and encoding with a masked autoencoder to build our proposed SSL framework. Our contributions are summarized as follows: (1) a novel attentive reconstruction loss function; (2) a new symmetric position encoding method; (3) an SSL framework, the attentive symmetric autoencoder, for brain MRI segmentation; and (4) experimental results showing that our method outperforms the state-of-the-art SSL methods and medical image segmentation models on three public benchmarks.

2 Methodology

2.1 Attentive Symmetric Autoencoder

We propose a novel Attentive Symmetric Autoencoder (ASA) that can be trained to obtain generalizable model weights for adaptation to brain MRI segmentation tasks. As shown in Fig. 1, the proposed ASA consists of an encoder and a decoder with symmetric position encoding (SPE) and an attentive reconstruction loss. During the self-supervised training of ASA, the input 3D image is divided into regular non-overlapping image patches. A fixed percentage of these image patches (the masking ratio) is randomly masked, and only the unmasked patches are visible. After a linear projection, each visible patch is embedded into a feature vector, which is added to its Symmetric Position Encoding (SPE) to produce the encoder input. The encoder outputs the same number of vectors as its input. The mask tokens are copies of the same learnable vector, each added to a different SPE, and each mask token corresponds to one masked image patch. The encoder output is concatenated with the mask tokens to form the decoder input. The decoder reconstructs all the image patches, and only the masked ones are used to compute the proposed loss.
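The masking and token assembly described above can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: the function names `random_mask` and `build_decoder_input` are hypothetical, the encoder is replaced by random stand-in vectors, and placing the tokens back in their original patch order (rather than simply appending the mask tokens) is an assumption borrowed from common MAE practice.

```python
import numpy as np

def random_mask(num_patches, mask_ratio=0.75, rng=None):
    """Return sorted (visible_idx, masked_idx) for one image."""
    rng = np.random.default_rng() if rng is None else rng
    num_masked = int(round(num_patches * mask_ratio))
    perm = rng.permutation(num_patches)
    return np.sort(perm[num_masked:]), np.sort(perm[:num_masked])

def build_decoder_input(encoded_visible, visible_idx, masked_idx, mask_token, pos_enc):
    """Combine encoder outputs with position-encoded mask tokens (assumed layout)."""
    num_patches, dim = pos_enc.shape
    tokens = np.empty((num_patches, dim), dtype=encoded_visible.dtype)
    tokens[visible_idx] = encoded_visible
    # Every masked slot gets the same learnable vector plus its own position code.
    tokens[masked_idx] = mask_token + pos_enc[masked_idx]
    return tokens

rng = np.random.default_rng(0)
num_patches, dim = 64, 16
pos_enc = rng.normal(size=(num_patches, dim))            # stand-in for the SPE table
visible_idx, masked_idx = random_mask(num_patches, 0.75, rng)
encoded_visible = rng.normal(size=(len(visible_idx), dim))  # stand-in for encoder output
mask_token = np.zeros(dim)                                # learnable in a real model
decoder_input = build_decoder_input(
    encoded_visible, visible_idx, masked_idx, mask_token, pos_enc)
```

With a 75% ratio the encoder only processes a quarter of the patches, which is the main source of the efficiency of masked autoencoders.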

Figure 1: The architecture of Attentive Symmetric Autoencoder. SPE means symmetric position encoding.

2.1.1 Attentive Reconstruction Loss.

Learning to recover flat regions is less helpful for encouraging the model to harvest discriminative representations. We therefore develop an attentive reconstruction loss function that emphasizes the informative regions of brain MRI. To estimate the information of an image patch, we adopt a gradient-based metric for 3D images. Inspired by 3D VHOG [hog2015], we calculate the gradient vector for each voxel by applying the filter mask [-1, 0, 1]. In spherical coordinates, we use two angular scalars to represent the gradient orientation of a voxel; both are computed from the components of the gradient vector, as given in Eq. (1).

For each image patch we build a 2D histogram over the two orientation angles, with a fixed number of bins per angle. To fill the histogram, we traverse every voxel in the image patch. We first determine the bin indexes of the voxel from its orientation, and then accumulate the gradient magnitude of the voxel into the corresponding bin of the 2D histogram. After all voxels in the image patch have been processed, the histogram of the patch is normalized. We take the mean of the normalized histogram as the score of the image patch, and normalize the scores over all masked image regions to obtain the relative importance weight of each patch.
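The scoring procedure can be illustrated with a short NumPy sketch. Several details are assumptions made only for illustration, since they are not recoverable from the text: the azimuth and zenith formulas (`arctan2` and `arccos`), the 8 x 8 bin count, the use of an L2 norm on the histogram, and the sum-to-one normalization of the scores. The function names are hypothetical.

```python
import numpy as np

def central_gradient(volume):
    """Apply the [-1, 0, 1] filter along each axis (border voxels left at zero)."""
    grads = []
    for axis in range(3):
        g = np.zeros_like(volume, dtype=float)
        mid, fwd, bwd = ([slice(None)] * 3 for _ in range(3))
        mid[axis], fwd[axis], bwd[axis] = slice(1, -1), slice(2, None), slice(None, -2)
        g[tuple(mid)] = volume[tuple(fwd)] - volume[tuple(bwd)]
        grads.append(g)
    return grads

def patch_score(patch, bins=(8, 8)):
    """Mean of the normalized, magnitude-weighted 2D orientation histogram."""
    g0, g1, g2 = central_gradient(np.asarray(patch, dtype=float))
    mag = np.sqrt(g0**2 + g1**2 + g2**2)
    azimuth = np.arctan2(g1, g0) % (2 * np.pi)                        # assumed formula
    zenith = np.arccos(np.clip(g2 / np.maximum(mag, 1e-12), -1, 1))   # assumed formula
    hist, _, _ = np.histogram2d(
        azimuth.ravel(), zenith.ravel(), bins=bins,
        range=[[0, 2 * np.pi], [0, np.pi]], weights=mag.ravel())
    norm = np.linalg.norm(hist)            # L2 norm is an assumption
    if norm > 0:
        hist = hist / norm
    return float(hist.mean())

def patch_weights(scores):
    """Relative importance over the masked patches (sum-to-one is an assumption)."""
    scores = np.asarray(scores, dtype=float)
    total = scores.sum()
    if total == 0:
        return np.full(scores.shape, 1.0 / scores.size)
    return scores / total

rng = np.random.default_rng(0)
flat = np.ones((8, 8, 8))                  # constant patch, no gradient
textured = rng.normal(size=(8, 8, 8))      # noisy patch with many orientations
scores = [patch_score(flat), patch_score(textured)]
weights = patch_weights(scores)
```

A constant patch has zero gradient everywhere, so its histogram is empty and its score is zero, while a textured patch receives a positive score and therefore a larger weight.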

Our proposed loss function adopts the mean squared error (MSE) to measure the voxel-level difference between the recovered image areas and the original ones, and pays more attention to the informative brain regions through the gradient-based weight of each patch. The overall loss is formulated in Eq. (2).

Eq. (2) is defined in terms of the reconstructed and the original images, the number of masked image patches in an image, the number of voxels in an image patch, and the individual voxels of each masked patch.
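One plausible reading of this loss is a per-patch MSE combined with the importance weights. The sketch below assumes that the weights already sum to one over the masked patches, so the weighted sum acts as a weighted average; the exact normalization constants of Eq. (2) are not reproduced here, and the function name is hypothetical.

```python
import numpy as np

def attentive_reconstruction_loss(recon, target, weights):
    """Weighted MSE over masked patches.

    recon, target: arrays of shape (num_masked_patches, voxels_per_patch).
    weights: array of shape (num_masked_patches,), assumed to sum to one.
    """
    per_patch_mse = np.mean((recon - target) ** 2, axis=1)
    return float(np.sum(weights * per_patch_mse))

target = np.zeros((2, 4))
recon = np.array([[1.0, 1.0, 1.0, 1.0],    # per-patch MSE = 1
                  [2.0, 2.0, 2.0, 2.0]])   # per-patch MSE = 4
uniform = attentive_reconstruction_loss(recon, target, np.array([0.5, 0.5]))
attentive = attentive_reconstruction_loss(recon, target, np.array([0.1, 0.9]))
```

With uniform weights the loss reduces to the plain MAE-style average (2.5 here); shifting weight towards the second patch raises its contribution (3.7 here), so errors on informative patches are penalized more.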

Figure 2: Symmetric Position Encoding.

2.1.2 Symmetric Position Encoding.

We observe the left-right symmetry of brain structures and propose a Symmetric Position Encoding (SPE) method. The proposed method narrows the encoding difference between two symmetric image positions, and can encourage the model to harvest better features from these two correlated regions. For the patches in the same horizontal plane (Fig. 2(a)), the vanilla position encoding [vaswani2017attention] of the top left is largely different from that of the top right (Fig. 2(b)), even though these regions have similar contents. With our proposed SPE, in contrast, the leftmost and the rightmost positions in the same row share the same encoding. The encoding is a function of the number of patches in an image and the coordinate of an image patch, and is computed as in Eq. (3).

The dimension of the SPE vector is set to the channel number of the image patch embeddings, and Eq. (3) defines, pair by pair, the elements of the SPE vector for a patch at a given coordinate. As Fig. 1 shows, the SPE method is applied twice: once to the patch embeddings and once to the mask tokens.
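The key property of SPE, identical codes for mirrored positions, can be demonstrated with a small sketch. The construction below is an assumption chosen for illustration and is not Eq. (3): it folds the left-right patch coordinate around the mid-plane, linearizes the folded coordinate, and applies the vanilla sinusoidal encoding of [vaswani2017attention]. The function names and the choice of axis 0 as the left-right axis are hypothetical.

```python
import numpy as np

def sinusoidal(pos, dim):
    """Vanilla sin/cos encoding of scalar positions; dim must be even."""
    pos = np.asarray(pos, dtype=float)[..., None]
    i = np.arange(dim // 2)
    angle = pos / (10000.0 ** (2 * i / dim))
    enc = np.empty(pos.shape[:-1] + (dim,))
    enc[..., 0::2] = np.sin(angle)   # even-indexed elements
    enc[..., 1::2] = np.cos(angle)   # odd-indexed elements
    return enc

def symmetric_position_encoding(grid, dim):
    """grid = (X, Y, Z) patch counts; axis 0 is assumed to be the left-right axis."""
    X, Y, Z = grid
    x, y, z = np.meshgrid(np.arange(X), np.arange(Y), np.arange(Z), indexing="ij")
    x_folded = np.minimum(x, X - 1 - x)      # mirrored positions share one value
    index = (x_folded * Y + y) * Z + z       # unique per (folded x, y, z)
    return sinusoidal(index, dim)            # shape (X, Y, Z, dim)

spe = symmetric_position_encoding((6, 4, 4), 32)
```

Here `spe[0]` and `spe[5]` (the leftmost and rightmost slabs) are identical, while neighbouring positions on the same side still receive distinct codes.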

2.2 Network Architecture

The proposed ASA model provides pre-trained model weights for the downstream task, brain MRI segmentation. Here we describe the architecture of the ASA model and of the image segmentation model. The encoder and the decoder of the ASA are based on Vision Transformer (ViT) [dosovitskiy2020image]. The standard ViT uses vanilla self-attention (SA), which leads to high computational cost, especially when processing 3D images. For efficiency, we develop Linear Window-based Multi-head Self-attention (LW-MSA) and Shifted Linear Window-based Multi-head Self-attention (SLW-MSA). Inspired by SwinT [liu2021swin], we flatten the 3D patches into a sequence of patch embeddings and split the sequence into windows of fixed size. LW-MSA computes self-attention within each 1D window. SLW-MSA shifts the sequence by a fixed offset before an LW-MSA module, and shifts the sequence in the reverse direction after the LW-MSA module. LW-MSA and SLW-MSA operate at the patch level, since each image patch is converted to a feature vector by a patch embedding layer at the very beginning. LW-MSA and SLW-MSA are stacked alternately to extract cross-window features and to build a shifted-window ViT (SW-ViT) for our ASA model. For brain MRI segmentation, we build a U-Net with the ASA encoder as the backbone, as shown in Fig. 3.
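The windowing idea can be sketched in NumPy as follows. For brevity this is single-head attention with hypothetical projection matrices `wq`, `wk`, `wv`; the window size of 4 and the shift of half a window are illustrative assumptions, and no attention mask is applied to the wrap-around tokens introduced by the cyclic shift.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def window_attention(tokens, window, wq, wk, wv):
    """Self-attention restricted to consecutive 1D windows of the token sequence."""
    n, d = tokens.shape
    assert n % window == 0, "sequence length must be a multiple of the window size"
    w = tokens.reshape(n // window, window, d)
    q, k, v = w @ wq, w @ wk, w @ wv
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))
    return (attn @ v).reshape(n, -1)

def shifted_window_attention(tokens, window, shift, wq, wk, wv):
    """Cyclically shift, attend within windows, then undo the shift."""
    shifted = np.roll(tokens, -shift, axis=0)
    out = window_attention(shifted, window, wq, wk, wv)
    return np.roll(out, shift, axis=0)

rng = np.random.default_rng(0)
n, d, window = 16, 8, 4
tokens = rng.normal(size=(n, d))
wq, wk, wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
out_plain = window_attention(tokens, window, wq, wk, wv)
out_shift = shifted_window_attention(tokens, window, window // 2, wq, wk, wv)
```

Because each token attends to only `window` tokens, the cost grows linearly with the sequence length instead of quadratically. Alternating the plain and shifted variants lets information cross window boundaries: in the shifted pass, token 3 shares a window with token 4, which it never sees in the plain pass.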

Figure 3: The architecture of network in downstream tasks.

3 Experiments and Results

3.0.1 Implementation Details

To pre-train the ASA model, we use center-cropping augmentation and the Xavier uniform initializer [glorot2010understanding] for the SW-ViT blocks, and set the hyper-parameters following [he2021masked] (see Table 1(a)). Following MAE [he2021masked], we set the masking ratio to 75%. The patch size is 8. To fine-tune the image segmentation model (Fig. 3), we adopt online data augmentation [isensee2021nnu] (random rotation, scaling, flipping and Gamma transformation). Only the encoder of the ASA is used for initialization. Other settings are in Table 1(b). The experiments are run with PyTorch. For pre-training we use four 32GB GPUs (NVIDIA V100), which takes 1 day with the early-stop strategy. Fine-tuning takes 1-2 days with 1 GPU.

 config                   value
 optimizer                AdamW [loshchilov2017decoupled]
 optimizer momentum       0.9, 0.95
 weight decay             0.05
 learning rate schedule   cosine decay [loshchilov2016sgdr]
 warmup epochs
 base learning rate       1.5e-4
 batch size               96
(a) Pre-training setting.

 config                   value
 optimizer                SGD
 optimizer momentum       0.99
 weight decay             3e-05
 initial learning rate    0.01
 batch size               2
 num_epoch                1000
 loss                     Dice and CE loss
(b) Fine-tuning setting.

Table 1: The hyper-parameter settings for pre-training and fine-tuning.

Datasets. For pre-training our ASA model, we adopt T1 MRI from 2 public datasets: 9952 cases from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset [jack2008alzheimer] and 2041 cases from the Open Access Series of Imaging Studies (OASIS) dataset [lamontagne2019oasis]. We convert the data into the Brain Imaging Data Structure (BIDS) format, affinely align the T1 images to the MNI space via the Clinica platform [el2021clinica], strip the skull from these images with ROBEX [iglesias2011robust] and crop a region at their center.

For the downstream task, we adopt 3 brain MRI segmentation benchmarks. The Brain Tumor Segmentation (BraTS) 2021 dataset [menze2014multimodal] has 1251 subjects. Each subject has 4 aligned MRI modalities: T1, T1Gd, T2 and T2-FLAIR. The annotations consist of GD-enhancing tumor (ET), peritumoral edematous tissue (ED) and necrotic tumor core (NCR), which are combined into 3 nested sub-regions: Whole Tumor (WT), Tumor Core (TC) and Enhancing Tumor (ET). Following [zhou2021nnformer], we set the ratio of training/validation/test as 7:1:2.

The Internet Brain Segmentation Repository (IBSR) dataset [rohlfing2004evaluation] has 18 T1-weighted MRI volumes of 4 healthy females and 14 healthy males. The ground truth (GT) has 3 categories: Cerebrospinal Fluid (CSF), Gray Matter (GM) and White Matter (WM). We adopt 12 cases for training and 6 cases for testing.

The White Matter Hyperintensities (WMH) dataset [kuijf2019standardized] involves 60 T1 images with pixel-level labels of White Matter Hyperintensities (WMH). We process the data as in [li2018fully] and use 36 cases for training and the rest for testing.

Task                             BraTS 2021
Metric                           Dice (%)              HD95 (mm)
Method                           WT     TC     ET      WT     TC     ET
nnFormer [zhou2021nnformer]      91.46  87.42  82.22   10.15  9.59   16.78
TransBTS [wang2021transbts]      92.06  88.20  79.46   4.98   4.86   16.32
UNETR [hatamizadeh2022unetr]     92.12  88.32  79.61   4.91   4.67   16.32
3D-RPL [taleb20203d]             93.92  90.13  85.92   3.74   3.98   13.71
3D-Jig [taleb20203d]             93.87  90.14  86.01   3.85   3.94   11.79
Ours                             94.03  90.29  86.76   3.61   3.78   10.25
Table 2: Comparison on the BraTS 2021 dataset (WT: Whole Tumor, TC: Tumor Core, ET: Enhancing Tumor). The rows above ours are competing methods. Our method gives the best result in every column.

Evaluation Metric. We calculate the Dice coefficient score (Dice) and the 95% Hausdorff Distance (HD95) to evaluate the segmentation results in our experiments.
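For reference, the two metrics can be sketched as below. This is an illustrative NumPy version rather than the evaluation code used in the experiments: `hd95` takes boundary point sets directly, and it uses the maximum of the two directed 95th-percentile distances, which is one common convention among several.

```python
import numpy as np

def dice_score(pred, gt):
    """Dice coefficient between two binary masks (1.0 if both are empty)."""
    pred, gt = np.asarray(pred, dtype=bool), np.asarray(gt, dtype=bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)

def hd95(points_a, points_b):
    """95% Hausdorff distance between two point sets of shape (n, 3)."""
    a = np.asarray(points_a, dtype=float)
    b = np.asarray(points_b, dtype=float)
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    a_to_b = dists.min(axis=1)   # nearest point in b for every point in a
    b_to_a = dists.min(axis=0)   # nearest point in a for every point in b
    return float(max(np.percentile(a_to_b, 95), np.percentile(b_to_a, 95)))

pred = np.zeros((4, 4), dtype=bool)
gt = np.zeros((4, 4), dtype=bool)
pred[:2, :] = True               # 8 voxels
gt[:, :2] = True                 # 8 voxels, 4 of them shared with pred
dice = dice_score(pred, gt)

boundary_a = [[0, 0, 0], [1, 0, 0]]
boundary_b = [[0, 3, 0], [1, 3, 0]]   # same shape, displaced by 3 along one axis
distance = hd95(boundary_a, boundary_b)
```

In this toy case the Dice score is 0.5 and the HD95 is 3.0. Dice rewards volume overlap, whereas HD95 is sensitive to boundary placement, which is why the two are reported together.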

3.0.2 Comparison with the State-of-the-art.

We compare our method with existing 3D transformer-based models (nnFormer [zhou2021nnformer], TransBTS [wang2021transbts], UNETR [hatamizadeh2022unetr]) and 3D self-supervised methods (Relative 3D Patch Location (3D-RPL) [taleb20203d], 3D Jigsaw Puzzle Solving (3D-Jig) [taleb20203d]) on 3 brain MRI segmentation tasks.

Figure 4: Visualization of segmentation results on BraTS 2021 dataset.

As Table 2 shows, on the BraTS 2021 dataset our method achieves Dice scores of 94.03%, 90.29% and 86.76% and HD95 values of 3.61mm, 3.78mm and 10.25mm on WT, TC and ET, respectively. Compared to the transformer-based methods, our method achieves markedly better performance on both metrics. Specifically, our approach outperforms TransBTS [wang2021transbts] and nnFormer [zhou2021nnformer] by more than 7% and 4% Dice on ET, respectively. Besides, our method shows more competitive results than the other SSL methods using the same image segmentation network. For the ET category, our method obtains an HD95 that is 3.46mm and 1.54mm lower than that of 3D-RPL and 3D-Jig, respectively. The visual comparisons are shown in Fig. 4, where our method predicts the ET region (blue) more accurately. As Table 3 shows, on the IBSR dataset our method achieves the highest Dice on CSF and GM, and the lowest HD95 on CSF. On the WMH dataset, the proposed method performs the best on both metrics. These results show that the model weights pre-trained by our method can be transferred to a wide range of datasets and help achieve state-of-the-art performance.

3.0.3 Ablation Analysis.

Task                             IBSR                                      WMH
Metric                           Dice (%)              HD95 (mm)           Dice (%)  HD95 (mm)
Method                           CSF    GM     WM      CSF   GM    WM
nnFormer [zhou2021nnformer]      87.31  93.81  92.12   1.52  1.52  1.21    78.04     2.81
TransBTS [wang2021transbts]      81.42  93.91  92.17   7.84  1.54  1.40    78.81     2.91
UNETR [hatamizadeh2022unetr]     86.75  93.49  91.86   1.64  1.74  1.48    77.99     3.53
3D-RPL [taleb20203d]             86.63  93.85  92.50   1.83  1.54  1.29    78.63     3.06
3D-Jig [taleb20203d]             86.93  93.57  92.11   2.00  1.74  1.44    77.86     3.36
Ours                             87.63  93.91  92.44   1.46  1.54  1.33    78.99     2.73
Table 3: Comparison on the IBSR dataset and the WMH dataset.
Metric           Dice (%)              HD95 (mm)
Method           WT     TC     ET      WT    TC    ET
Baseline         93.75  89.76  84.98   3.93  4.09  13.93
w/ SSL           94.02  90.28  86.25   4.01  4.06  13.44
w/ A-SSL         93.95  90.24  86.38   3.84  3.79  11.69
w/ SPE           93.90  90.15  85.86   3.69  3.82  11.59
w/ SPE&SSL       93.85  90.04  86.83   3.64  3.84  11.59
w/ ASA (Ours)    94.03  90.29  86.76   3.61  3.78  10.25
Table 4: Ablation study on BraTS 2021. SSL denotes the 3D Masked Autoencoder (MAE) SSL method. A-SSL is the MAE method with our AR-Loss.

We verify the strength of the attentive reconstruction loss (AR-Loss), the SPE, and our overall ASA framework on BraTS 2021, as shown in Table 4. ‘Baseline’ denotes the SW-ViT based segmentation network (see Fig. 3) trained from scratch. ‘w/ SSL’ denotes training the segmentation network with the model weights pre-trained by a 3D Masked Autoencoder (MAE) SSL method [he2021masked]. ‘A-SSL’ denotes the MAE method with the proposed AR-Loss. As shown in the first half of Table 4, the A-SSL method produces more accurate segmentation results than its competitor SSL in terms of HD95. In particular, on ET the HD95 of A-SSL is nearly 2mm lower than that of the SSL method (11.69mm vs. 13.44mm). Note that HD95 measures the distance between the point sets of two boundaries. The above results therefore suggest that the proposed AR-Loss can encourage the encoder to learn better representations of boundary information. ‘w/ SPE’ denotes applying the SPE to the train-from-scratch Baseline. On ET, ‘w/ SPE’ obtains 0.9% higher Dice and 2.3mm lower HD95 than ‘Baseline’. ‘w/ ASA’ denotes using our overall method with both the loss and the SPE. Comparing ASA with A-SSL, the SPE further improves A-SSL by 1.4mm HD95 on ET. These results suggest that our proposed encoding can help the ViT-based encoder understand symmetric structures and harvest discriminative features.

4 Conclusion

In this paper, we propose a novel self-supervised learning architecture for 3D medical images. The proposed framework contains two key components, the symmetric position encoding and the attentive reconstruction loss. The encoding can benefit feature learning for symmetric structures and the attentive loss emphasizes informative image regions for reconstruction-based SSL. Both techniques can improve the generalization of trained models. Extensive experiments are conducted on three public brain MRI datasets. The results suggest that our method can achieve competitive performance with the state-of-the-art SSL methods and medical image segmentation models.