Unsupervised Multi-modal Style Transfer for Cardiac MR Segmentation

08/20/2019
by   Chen Chen, et al.
Imperial College London

In this work, we present a fully automatic method to segment cardiac structures from late-gadolinium enhanced (LGE) images without using labelled LGE data for training, but instead by transferring the anatomical knowledge and features learned on annotated balanced steady-state free precession (bSSFP) images, which are easier to acquire. Our framework mainly consists of two neural networks: a multi-modal image translation network for style transfer and a cascaded segmentation network for image segmentation. The multi-modal image translation network generates realistic and diverse synthetic LGE images conditioned on a single annotated bSSFP image, forming a synthetic LGE training set. This set is then utilized to fine-tune the segmentation network pre-trained on labelled bSSFP images, achieving the goal of unsupervised LGE image segmentation. In particular, the proposed cascaded segmentation network is able to produce accurate segmentation by taking both shape prior and image appearance into account, achieving an average Dice score of 0.92 for the left ventricle, 0.83 for the myocardium, and 0.88 for the right ventricle on the test set.


1 Introduction

Cardiac segmentation from late-gadolinium enhanced (LGE) cardiac magnetic resonance (CMR), which highlights myocardial infarcted tissue, is of great clinical importance, enabling quantitative measurements useful for treatment planning and patient management. To this end, the segmentation of the myocardium is an important first step for myocardial infarction analysis.

Since manual segmentation is tedious and likely to suffer from inter-observer variability, it is of great interest to develop an accurate automated segmentation method. However, this is a challenging task due to the fact that 1) the infarcted myocardium presents an enhanced and heterogeneous intensity distribution different from the normal myocardium region and 2) the border between infarcted myocardium and blood pool appears blurry and ambiguous [12]. While the borders of the myocardium can be difficult to delineate on LGE images, they are clear and easy to identify on the balanced steady-state free precession (bSSFP) CMR images, which have high signal-to-noise ratio and whose contrast is less sensitive to pathology (see red arrows in Fig. 1 (a)). Conventional methods [6, 9] use the segmentation result from the bSSFP CMR of the same patient as prior knowledge to assist the segmentation on LGE CMR images. These methods generally require accurate registration between the bSSFP and LGE images, which can be challenging as the imaging field-of-view (FOV), image contrast and resolution between the two acquisitions can vary significantly [12, 13]. Fig. 1 (b) visualizes the discrepancy between the intensity distributions of the two imaging modalities in the cardiac structures (specifically, left ventricle (LV), myocardium (MYO), and right ventricle (RV)).

Figure 1: The differences of image appearance (a) and intensity distributions (b) in the cardiac region (the union of LV, MYO, RV) between LGE images and bSSFP images

Most recently, a deep neural network-based method has been proposed to segment the three cardiac structures directly from LGE images [11], reporting superior performance. However, this supervised segmentation method requires a large amount of labelled LGE data. Because of the heterogeneous intensity distribution of the myocardium in LGE images and the scarcity of experienced image analysts, it is difficult to perform accurate manual segmentations on LGE images and collect a large training set, compared to that on bSSFP images.

In this paper, we present a fully automatic framework that addresses the above mentioned issues by training a segmentation model without using manual annotations on LGE images. This is achieved by transferring the anatomical knowledge and features learned on annotated bSSFP images, which are easier to acquire. Our framework mainly consists of two neural networks:

  • A multi-modal image translation network: this network is used for translating annotated bSSFP images into LGE images through style transfer. Of note, the network is trained in an unsupervised fashion where the training bSSFP images and LGE images are unpaired. In addition, unlike common one-to-one translation networks, this network allows the generation of multiple synthetic LGE images conditioned on a single bSSFP image;

  • A cascaded segmentation network for LGE images consisting of two U-net [7] models (Cascaded U-net): this network is first trained using the labelled bSSFP images and then fine-tuned using the synthetic LGE data generated by the image translation network.

The main contributions of our work are the following: 1) we employ a translation network that can generate realistic and diverse synthetic LGE images given a single bSSFP image. This network enables generative model-based data augmentation for unsupervised domain adaptation, which not only closes the domain gap between the two modalities, but also improves the generalization properties of the following segmentation network by increasing data variety; 2) we demonstrate that the proposed two-stage cascaded network, which takes both anatomical shape information and image appearance information into account, produces accurate segmentation on LGE images, greatly outperforming baseline methods; 3) the proposed framework can be easily extended to other unsupervised cross-modality domain adaptation applications where labels of one modality are not available.

2 Methodology

The proposed method aims at learning an LGE image segmentation model using labelled bSSFP and unlabelled LGE images only. Specifically, the proposed method is a two-stage framework. In the first stage, an unsupervised image translation network is trained to translate each bSSFP image into multiple instances of LGE-like images. In the second stage, these LGE-stylized bSSFP images are used together with their original labels to adapt an image segmentation network pre-trained on labelled bSSFP images to segment LGE images.

2.1 Image Translation

We employ the state-of-the-art multi-modal unsupervised image-to-image translation network (MUNIT) [4] as our multi-modal image translator. Let $x_{\text{LGE}}$ and $x_{\text{bSSFP}}$ denote unpaired images from the two imaging modalities (domains), LGE and bSSFP. Given an image drawn from one domain as input, the network is able to change the appearance (i.e. image style) of the image to that of the other domain while preserving the underlying anatomical structure. This is achieved by learning disentangled image representations.

Figure 2: Overview of the multi-modal image translation network. The network employs the structure of MUNIT [4], which consists of two encoder-decoder pairs for the two domains: bSSFP and LGE, respectively.

As shown in Fig. 2, each image $x_i$, $i \in \{\text{LGE}, \text{bSSFP}\}$, is disentangled into (a) a domain-invariant content code $c_i = E^c_i(x_i)$ and (b) a domain-specific style code $s_i = E^s_i(x_i)$, using the content encoder $E^c_i$ and the style encoder $E^s_i$ of its domain. The content code captures the anatomical structure, and the style code carries the information for rendering the structure, which is determined by the imaging modality. The image-to-image translation from one domain to the other is achieved by swapping latent codes between the two domains. For example, translating a bSSFP image to be stylized as LGE is achieved by feeding the content code $c_{\text{bSSFP}}$ of the bSSFP image and a style code $s_{\text{LGE}}$ into the LGE decoder $G_{\text{LGE}}$: $x_{\text{bSSFP} \to \text{LGE}} = G_{\text{LGE}}(c_{\text{bSSFP}}, s_{\text{LGE}})$.

Of note, during training, each style encoder is trained to embed images into a latent space that matches the standard Gaussian distribution $\mathcal{N}(0, I)$, minimizing the Kullback-Leibler (KL) divergence between the two. This allows us to generate an arbitrary number of synthetic LGE images given a single bSSFP image during inference, by repeatedly sampling the style code from the prior distribution $\mathcal{N}(0, I)$. Although this prior distribution is unimodal, the distribution of translated images in the output space is multi-modal thanks to the nonlinearity of the decoder [4]. We apply this translation network to translate annotated bSSFP images, resulting in a synthetic labelled LGE dataset, which is then used to fine-tune a segmentation network. For more details about training the translation network, readers are referred to the original work by Huang et al. [4].
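The sampling step described above can be illustrated with a minimal NumPy sketch. The decoder here is a hypothetical stand-in (a small nonlinear map of the style code), not the MUNIT decoder, and the style dimension of 8 as well as all function names are assumptions made for illustration only.

```python
import numpy as np

STYLE_DIM = 8  # assumed style-code length, for illustration only


def toy_decoder(content, style, weights):
    """Hypothetical stand-in for the LGE decoder: modulates the content
    map with a gain and a bias that depend nonlinearly on the style code."""
    gain = 1.0 + np.tanh(weights["gain"] @ style)
    bias = np.tanh(weights["bias"] @ style)
    return gain * content + bias


def sample_translations(content, n_samples, rng):
    """Decode one content code with n_samples style codes drawn from N(0, I)."""
    weights = {"gain": rng.normal(size=STYLE_DIM),
               "bias": rng.normal(size=STYLE_DIM)}
    styles = rng.standard_normal((n_samples, STYLE_DIM))
    return np.stack([toy_decoder(content, s, weights) for s in styles])


rng = np.random.default_rng(0)
content = rng.random((16, 16))   # placeholder for the content code of one bSSFP slice
synthetic = sample_translations(content, n_samples=5, rng=rng)
print(synthetic.shape)           # (5, 16, 16)
```

Every output shares the same content map, so the anatomy is held fixed while the appearance varies with the sampled style code.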

2.2 Image Segmentation

Figure 3: Overview of the two-stage cascaded segmentation network. The architecture of each U-net is the same as that of the vanilla U-net, except for two main differences: (1) batch normalization is applied after each convolutional layer; (2) a dropout layer (dropout rate=0.1) is applied after each concatenation operation in the network’s expanding path to encourage model generalizability. Of note, in this diagram, we simplify the training procedure by omitting the pre-training procedure using labelled bSSFP images.

Let $x_{\text{LGE}}$ be an observed LGE image. The aim of the segmentation task is to estimate the label map $y_{\text{LGE}}$ having observed $x_{\text{LGE}}$, by modelling the posterior $p(y_{\text{LGE}} \mid x_{\text{LGE}})$. Inspired by curriculum learning [1] and transfer learning, we first train a segmentation network using annotated bSSFP images (source domain; easy examples) and then fine-tune it to segment LGE images (target domain; hard examples). Since labelled LGE images are not available for fine-tuning, we use a synthetic dataset generated by the aforementioned multi-modal image translator. Ideally, the posterior modelled by the network on the synthetic data matches $p(y_{\text{LGE}} \mid x_{\text{LGE}})$ when image space and label space are shared. For simplicity, we use $x$ and $y$ to denote an image and its corresponding label map from the synthetic dataset in the following paragraphs.

The segmentation network is a two-stage cascaded network which consists of two U-nets [7], see Fig. 3. Specifically, given an image $x$ as input, the first U-net (U-net 1) aims at predicting four-class pixel-wise probabilistic maps $p_1 = f_1(x)$ for the three cardiac structures (i.e. LV, MYO, RV) and the background class (BG). Inspired by the auto-context architecture [10], we combine these learned probabilistic maps from the first network with the raw image to form a 5-channel input to train the second U-net (U-net 2) for fine-grained segmentation: $p_2 = f_2([x, p_1])$, where $f_1$ and $f_2$ denote U-net 1 and U-net 2 and $[\cdot, \cdot]$ denotes channel-wise concatenation. By combining the appearance information from the image with the shape prior information from the initial segmentation as input, the cascaded network has the potential to produce more precise and robust segmentations even in the presence of unclear boundaries for the different cardiac structures.
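How the 5-channel input of the second network is assembled can be sketched as follows. This is a NumPy illustration under assumptions: the logits are random placeholders standing in for the output of U-net 1, and the channel order (image first, then BG, LV, MYO, RV) is an assumption, since the text does not state it.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def build_second_stage_input(image, first_stage_logits):
    """Stack the raw image (1 channel) with the 4-class probability maps
    (BG, LV, MYO, RV) from the first network into a 5-channel array."""
    probs = softmax(first_stage_logits, axis=0)          # (4, H, W)
    return np.concatenate([image[None], probs], axis=0)  # (5, H, W)


rng = np.random.default_rng(0)
image = rng.random((64, 64))             # placeholder LGE slice
logits = rng.normal(size=(4, 64, 64))    # placeholder output of U-net 1
x2 = build_second_stage_input(image, logits)
print(x2.shape)                          # (5, 64, 64)
```

The second network therefore sees both the raw intensities and a soft shape estimate at every pixel.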

To train the network, we use a composite segmentation loss function which consists of two loss terms, balanced by a weight $\lambda$. The first term is a weighted cross entropy loss, $\mathcal{L}_{wCE} = -\sum_{c} w_c \, y^c \log p^c$, where $w_c$ denotes the weight for class $c$, $y^c$ is the one-hot ground truth map and $p^c$ is the corresponding predicted probability map. We set the weight for the myocardium to be higher than the weights for the other three classes to address the class imbalance problem, since a lower percentage of pixels corresponds to the myocardium class in CMR images. The second term is an edge-based loss which penalizes the disagreement on the contours of the cardiac structures. Specifically, we apply two 2D Sobel filters [8] $f_k$ (k=1,2) to the soft prediction maps as well as the one-hot heatmaps of the ground truth to extract edge information along horizontal and vertical directions. The edge loss is then computed by calculating the distance between the predicted edge maps and the ground truth edge maps: $\mathcal{L}_{edge} = \sum_{c} \sum_{k=1,2} \lVert f_k(p^c) - f_k(y^c) \rVert$, where $f_k(p^c)$ is the edge map extracted by applying the Sobel filter $f_k$ to the predicted probabilistic map for foreground class $c$.

By using the edge loss together with the weighted cross entropy for optimization, the network is encouraged to focus more on the contours of the three structures and the myocardium, which are usually more difficult to delineate. In our experiments, we set the weight balancing the two losses to $\lambda = 0.5$.
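The two loss terms can be sketched in NumPy as below. Several choices are assumptions rather than details taken from the text: the composite form is taken to be L = L_wCE + lam * L_edge, the edge distance is a mean squared difference, the class weights are hypothetical values with MYO set highest, and the experiments themselves were run in PyTorch rather than NumPy.

```python
import numpy as np

SOBEL_X = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)
SOBEL_Y = SOBEL_X.T


def filter2d(img, kernel):
    """3x3 cross-correlation with edge padding (output has the input's shape)."""
    padded = np.pad(img, 1, mode="edge")
    h, w = img.shape
    out = np.zeros((h, w))
    for i in range(3):
        for j in range(3):
            out += kernel[i, j] * padded[i:i + h, j:j + w]
    return out


def weighted_cross_entropy(probs, onehot, weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_c * log(p_c); arrays are (C, H, W)."""
    w = np.asarray(weights, dtype=float)[:, None, None]
    log_p = np.log(np.clip(probs, eps, 1.0))
    return float(-(w * onehot * log_p).sum(axis=0).mean())


def edge_loss(probs, onehot, foreground=(1, 2, 3)):
    """Mean squared difference between Sobel edge maps of prediction and label,
    summed over the foreground classes and the two filter directions."""
    total = 0.0
    for c in foreground:
        for kernel in (SOBEL_X, SOBEL_Y):
            diff = filter2d(probs[c], kernel) - filter2d(onehot[c], kernel)
            total += float((diff ** 2).mean())
    return total


def composite_loss(probs, onehot, weights, lam=0.5):
    """Assumed form: L = L_wCE + lam * L_edge."""
    return weighted_cross_entropy(probs, onehot, weights) + lam * edge_loss(probs, onehot)


labels = np.zeros((8, 8), dtype=int)
labels[2:6, 2:6] = 2                           # a square "MYO" region
onehot = np.eye(4)[labels].transpose(2, 0, 1)  # (4, 8, 8)
uniform = np.full((4, 8, 8), 0.25)             # uninformative prediction
w = (0.2, 0.25, 0.3, 0.25)                     # hypothetical weights (BG, LV, MYO, RV)
print(composite_loss(onehot, onehot, w) < composite_loss(uniform, onehot, w))  # True
```

A perfect prediction gives zero for both terms, whereas the uniform prediction is penalized by the cross entropy and by the missing contour of the labelled region.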

2.3 Post-processing

At inference time, each slice from a previously unseen LGE stack is fed to the cascaded network to get the probabilistic maps for the four classes. Dense conditional random field (CRF) [5] is then applied to refine the 2D predicted segmentation mask slice by slice. After that, 3D morphological dilation and erosion operations are applied to the whole segmentation stack to further improve the global smoothness. In particular, we perform the operations in a hierarchical order: first we apply them to the binary map covering all the three structures, then to the MYO and the LV labels, separately.

3 Experiments and Results

3.1 Data

The framework was trained and evaluated on the Multi-sequence Cardiac MR Segmentation Challenge (MS-CMRSeg 2019) dataset (https://zmiclab.github.io/mscmrseg19/). We used a subset of 40 bSSFP and 40 LGE images to train the image translation network. Then, we created a synthetic dataset by applying the learned translation network to 20 labelled bSSFP images. Specifically, for each bSSFP image, we randomly sampled the style code from $\mathcal{N}(0, I)$ five times, resulting in a set of 100 synthetic LGE images in total. This synthetic dataset and the original 20 bSSFP images with corresponding labels formed the training set for the segmentation network. Exemplar results of these synthetic LGE images are provided in the supplemental material. For validation, we used a subset of 5 annotated LGE images provided by the challenge organizers.
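The bookkeeping behind the synthetic dataset can be sketched as follows. The identifiers, the dictionary layout and the use of integer style seeds are hypothetical; only the counts (20 labelled bSSFP images, five style samples each) come from the description above.

```python
import random


def build_synthetic_index(image_ids, samples_per_image, seed=0):
    """Pair every labelled bSSFP image with several style seeds.
    Each entry reuses the label of its source image, because style
    transfer is meant to leave the anatomy unchanged."""
    rng = random.Random(seed)
    index = []
    for image_id in image_ids:
        for _ in range(samples_per_image):
            index.append({"source": image_id,
                          "label": image_id,
                          "style_seed": rng.randrange(2**31)})
    return index


bssfp_ids = [f"bssfp_{i:02d}" for i in range(20)]   # hypothetical identifiers
synthetic = build_synthetic_index(bssfp_ids, samples_per_image=5)
print(len(synthetic))                    # 100
print(len(synthetic) + len(bssfp_ids))   # 120
```

The second number is the size of the full training set, i.e. the synthetic images plus the original labelled bSSFP images.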

3.2 Implementation Details

Image preprocessing. To deal with the different image sizes and heterogeneous pixel spacings of the two imaging modalities, all images were resampled to a common voxel spacing and then cropped to a fixed size in pixels, with the heart roughly at the center of each image. This spatial normalization reduces the computational cost and task complexity in the following training procedure of image translation and segmentation, making the networks focus on the relevant regions. To identify the heart, we trained a localization network based on U-net using the 20 annotated bSSFP images in the training set to produce rough segmentations for the three structures. The localization network employs instance normalization layers which perform style normalization [3], encouraging the network to be invariant to image style changes (e.g. image contrast). As a result, the network is able to produce coarse masks localizing the heart on all bSSFP images and most LGE images even though it was trained on bSSFP images only. To guard against the network failing to locate the heart on certain LGE slices, we summed the segmentation masks across slices in each volume and then cropped the images according to the center of the aggregated mask. After cropping, each image was intensity normalized.
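The volume-level cropping strategy can be sketched as follows. This is an illustration under assumptions: `crop_size` is a hypothetical parameter, the center is taken as the center of mass of the aggregated mask, zero padding is used at the borders, and the fallback to the image center for an empty mask is an added safeguard that the text does not describe.

```python
import numpy as np


def crop_volume_around_heart(volume, masks, crop_size):
    """Crop every slice of a (S, H, W) volume to crop_size x crop_size pixels,
    centred on the centre of mass of the masks summed over all slices."""
    aggregated = masks.sum(axis=0).astype(float)              # (H, W)
    if aggregated.sum() == 0:
        cy, cx = volume.shape[1] // 2, volume.shape[2] // 2   # assumed fallback
    else:
        ys, xs = np.indices(aggregated.shape)
        cy = int(round((ys * aggregated).sum() / aggregated.sum()))
        cx = int(round((xs * aggregated).sum() / aggregated.sum()))
    half = crop_size // 2
    # Pad so that crops near the border still have the requested size.
    padded = np.pad(volume, ((0, 0), (half, half), (half, half)), mode="constant")
    return padded[:, cy:cy + crop_size, cx:cx + crop_size]


volume = np.zeros((3, 32, 32))
volume[1, 12, 22] = 7.0                  # marker at the true heart centre
masks = np.zeros((3, 32, 32), dtype=int)
masks[1, 10:15, 20:25] = 1               # heart found on the middle slice only
cropped = crop_volume_around_heart(volume, masks, crop_size=16)
print(cropped.shape)                     # (3, 16, 16)
```

Because the center is computed once per volume, slices on which the localization failed (the first and last slice in this toy case) are still cropped at the same position as their neighbours.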

Network training. (1) For the image translation network, we used the official implementation (https://github.com/NVlabs/MUNIT) of [4]. Network configuration and hyper-parameters were kept the same as in [4], except that the input and output images are 2D and single-channel. It was trained for 20k iterations with a batch size of 1. (2) For the segmentation network, we first trained the first U-net with the labelled bSSFP images and then fine-tuned it with synthetic LGE images. This procedure was replicated to train the second U-net with the parameters of the first U-net being fixed. Both networks were optimized with the composite loss using the Adam optimizer. The learning rate was initially set to 0.001 and was then decreased for fine-tuning. The weights for BG, LV, MYO, and RV in the weighted cross entropy loss were set empirically. During training, we applied data augmentation on the fly. Specifically, elastic deformations, random scaling and random rotations as well as gamma augmentation [2] were used. The algorithm was implemented using Python and PyTorch and was trained for 1000 epochs in total on an NVIDIA Tesla P40 GPU.

3.3 Results

To evaluate the accuracy of segmentation results, the Dice metric and the average surface distance (ASD) between the automatic segmentation and the corresponding manual segmentation for each volume were calculated.
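For reference, the Dice metric for one structure can be computed as in the sketch below. The integer label convention and the choice to return NaN when a structure is absent from both segmentations are assumptions; the ASD additionally requires surface extraction and is not covered here.

```python
import numpy as np


def dice_score(pred, target, label):
    """Dice overlap for one structure between two integer label volumes."""
    p = (pred == label)
    t = (target == label)
    denom = p.sum() + t.sum()
    if denom == 0:
        return float("nan")   # structure absent from both segmentations
    return 2.0 * np.logical_and(p, t).sum() / denom


pred = np.array([[1, 1, 0], [0, 2, 2]])
target = np.array([[1, 0, 0], [0, 2, 2]])
print(dice_score(pred, target, label=2))   # 1.0
```

For label 1 in this toy case the prediction covers two pixels, the reference one, and they share one pixel, so the score is 2 * 1 / (2 + 1) = 2/3.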

We compare the proposed method with two baseline methods: (1) a registration-based method and (2) a single U-net. Specifically, for the registration-based method, each LGE segmentation result was obtained by directly registering the corresponding bSSFP labels to the LGE image using the MIRTK toolkit (https://mirtk.github.io/) for ease of comparison. The transformation was estimated by applying mutual information-based registration (Rigid+Affine+FFD) between the two images. For U-net, we trained it with two settings: a) U-net: trained on labelled bSSFP images only; b) U-net with fine-tuning (FT): trained on labelled bSSFP images and then fine-tuned using the synthetic LGE data, which is the same training procedure as that of the proposed method. Quantitative and qualitative results are shown in Table 1 and Fig. 4.

While the registration-based method (MIRTK) outperforms the U-net (see row 1 and row 2 in Table 1), it still fails to produce accurate segmentation of the myocardium (see the red number in row 1), indicating the limitation of this registration-based method. By contrast, the neural network-based methods fine-tuned using the synthetic LGE dataset (rows 3-5) significantly improve the segmentation accuracy, raising the Dice score for MYO from 0.665 to at least 0.781. This improvement demonstrates that the learned translation network is capable of generating realistic LGE images while preserving the domain-invariant structural information that is informative for optimizing the segmentation network. In particular, compared to U-net (FT), the proposed Cascaded U-net (FT) achieves more accurate segmentation, with improvements in terms of both Dice and ASD (see blue numbers). The model even produces robust segmentation results on the challenging apical and basal slices (see the last column in Fig. 4). This demonstrates the benefit of integrating high-level shape knowledge and low-level image appearance to guide the segmentation procedure. In addition, the proposed post-processing further refines the segmentation results through smoothing, reducing the average ASD from 1.37 to 1.26 (see the last row in Table 1).

                            Dice                           ASD (mm)
Method                      LV     MYO    RV     AVG*      LV     MYO    RV     AVG*
MIRTK                       0.819  0.665  0.831  0.772     2.56   1.65   2.11   2.11
U-net                       0.624  0.441  0.577  0.547     10.03  6.07   N/A    N/A
U-net (FT)                  0.874  0.781  0.896  0.850     1.78   1.50   1.28   1.52
Cascaded U-net (FT)         0.895  0.812  0.898  0.868     1.41   1.46   1.23   1.37
Cascaded U-net (FT) + PP    0.897  0.816  0.895  0.869     1.17   1.42   1.18   1.26

* For ease of comparison, we calculate the average (AVG) Dice score and the average ASD score over the three structures for each method.

Table 1: Dice scores and ASD (mm) of the proposed segmentation method (Cascaded U-net) and baseline methods on the validation set. Blue numbers indicate the best scores among the results obtained by those methods before post-processing (PP) whereas red numbers are those mean Dice scores under 0.700. FT: fine-tuning using the synthetic LGE dataset. N/A means that the ASD value cannot be calculated due to missing predictions for that cardiac structure.
Figure 4: Segmentation results for the proposed Cascaded U-net and the baseline approaches. Our proposed method (the right-most column) produces more anatomically plausible segmentation results on the images, greatly outperforming the baseline methods, especially in the challenging cases: the apical (the top row) and the basal slices (the bottom row).
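As a quick check on the AVG columns of Table 1, the averages over the three structures can be recomputed from the per-structure scores. The values below are copied from the table; the plain U-net row is left out of the ASD check because its RV value is N/A.

```python
# Per-structure scores (LV, MYO, RV) copied from Table 1.
rows = {
    "MIRTK":                    {"dice": (0.819, 0.665, 0.831), "asd": (2.56, 1.65, 2.11)},
    "U-net (FT)":               {"dice": (0.874, 0.781, 0.896), "asd": (1.78, 1.50, 1.28)},
    "Cascaded U-net (FT)":      {"dice": (0.895, 0.812, 0.898), "asd": (1.41, 1.46, 1.23)},
    "Cascaded U-net (FT) + PP": {"dice": (0.897, 0.816, 0.895), "asd": (1.17, 1.42, 1.18)},
}


def mean(values):
    """Arithmetic mean of a sequence of numbers."""
    return sum(values) / len(values)


averages = {name: (mean(r["dice"]), mean(r["asd"])) for name, r in rows.items()}
for name, (avg_dice, avg_asd) in averages.items():
    print(f"{name}: AVG Dice = {avg_dice:.3f}, AVG ASD = {avg_asd:.2f} mm")
```

The recomputed values agree with the AVG columns to the precision reported in the table.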

Finally, we applied ensemble learning to improve our model’s performance in the test phase. Specifically, we trained the proposed segmentation network multiple times, each time regenerating a new synthetic LGE dataset for fine-tuning. We trained four models in total. Our final submission result for each test image was obtained by averaging the probabilistic maps from these models and then assigning to each pixel the class with the highest score. In the testing stage of the competition, the method achieves very promising segmentation performance on a relatively large test set (40 subjects), with an average Dice score of 0.92 for LV, 0.83 for MYO, and 0.88 for RV, and an ASD of 1.66 mm for LV, 1.76 mm for MYO, and 2.16 mm for RV.
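The ensembling rule, averaging the probabilistic maps and taking the per-pixel argmax, can be sketched as follows. The array layout (models, classes, height, width) is an assumption, and the toy example uses two models and two classes rather than the four models and four classes of the actual setup.

```python
import numpy as np


def ensemble_predict(prob_maps):
    """Average per-model probability maps of shape (M, C, H, W) and
    return the per-pixel argmax label map of shape (H, W)."""
    mean_probs = np.mean(prob_maps, axis=0)
    return mean_probs.argmax(axis=0)


probs = np.array([
    [[[0.4, 0.2]], [[0.6, 0.8]]],   # model A: class-0 map, class-1 map
    [[[0.9, 0.3]], [[0.1, 0.7]]],   # model B: class-0 map, class-1 map
])
print(ensemble_predict(probs))      # [[0 1]]
```

At the first pixel the two models disagree, and the more confident model B decides the outcome because the averaged probability of class 0 is 0.65.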

4 Conclusion

In this paper, we showed that synthesizing multi-modal LGE images from labelled bSSFP images to fine-tune a pre-trained segmentation network yields impressive segmentation performance on LGE images even though the network has not seen real labelled LGE images before. We also demonstrated that the proposed segmentation network (Cascaded U-net) outperformed the baseline methods by a significant margin, suggesting the benefit of integrating high-level shape knowledge and low-level image appearance to guide the segmentation procedure. More importantly, our cascaded segmentation network is independent of the particular architecture of the underlying convolutional neural networks. In other words, the basic neural network (U-net) in our work can be replaced with any state-of-the-art segmentation network to potentially improve the prediction accuracy and robustness. Moreover, the proposed solution based on unsupervised multi-modal style transfer is not limited to cardiac image segmentation but can be extended to other multi-modal image analysis tasks where the manual annotations of one modality are not available. Future work will focus on the application of the method to problems such as domain adaptation for multi-modality brain segmentation.

References

  • [1] Y. Bengio, J. Louradour, R. Collobert, and J. Weston (2009) Curriculum learning. In Proceedings of the 26th Annual International Conference on Machine Learning, ICML ’09, New York, NY, USA, pp. 41–48. ISBN 9781605585161.
  • [2] C. Chen, W. Bai, and D. Rueckert (2019) Multi-task learning for left atrial segmentation on GE-MRI. In Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. ISBN 978-3-030-12029-0.
  • [3] X. Huang et al. (2017) Arbitrary style transfer in real-time with adaptive instance normalization. In ICCV.
  • [4] X. Huang et al. (2018) Multimodal unsupervised image-to-image translation. In ECCV.
  • [5] P. Krähenbühl et al. (2011) Efficient inference in fully connected CRFs with Gaussian edge potentials. In NeurIPS.
  • [6] Y. Lu et al. (2013) Automatic myocardium segmentation of LGE MRI by deformable models with prior shape data. JCMR 15 (1), pp. P14.
  • [7] O. Ronneberger et al. (2015) U-net: Convolutional networks for biomedical image segmentation. In MICCAI.
  • [8] I. Sobel and G. Feldman (1973) A 3x3 isotropic gradient operator for image processing. Pattern Classification and Scene Analysis, pp. 271–272.
  • [9] Q. Tao et al. (2015) Automated left ventricle segmentation in late gadolinium-enhanced MRI for objective myocardial scar assessment. JMRI 42 (2), pp. 390–399.
  • [10] Z. Tu and X. Bai (2010) Auto-context and its application to high-level vision tasks and 3D brain image segmentation. PAMI. ISSN 01628828.
  • [11] Q. Yue et al. (2019) Cardiac segmentation from LGE MRI using deep neural network incorporating shape and spatial priors. In MICCAI.
  • [12] X. Zhuang (2016) Multivariate mixture model for cardiac segmentation from multi-sequence MRI. In MICCAI, pp. 581–588.
  • [13] X. Zhuang (2018) Multivariate mixture model for myocardium segmentation combining multi-source images. PAMI.

Supplemental Material

Figure 5: Exemplar synthetic LGE images generated from bSSFP images using the multi-modal image translation network. Given one bSSFP image (column 1), the translation network translates the image into multi-modal LGE-like images (columns 2 to 4). These translated images differ in image brightness and contrast as well as in the intensity distribution in the cardiac region, while preserving the same cardiac anatomy. These synthetic images, together with the annotations on the original bSSFP images (the last column), form the synthetic dataset which is used to fine-tune the proposed segmentation network.