Transfer Learning for Domain Adaptation in MRI: Application in Brain Lesion Segmentation

by Mohsen Ghafoorian, et al.
Harvard University

Magnetic Resonance Imaging (MRI) is widely used in routine clinical diagnosis and treatment. However, variations in MRI acquisition protocols result in different appearances of normal and diseased tissue in the images. Convolutional neural networks (CNNs), which have been shown to be successful in many medical image analysis tasks, are typically sensitive to variations in imaging protocols. Therefore, in many cases, networks trained on data acquired with one MRI protocol do not perform satisfactorily on data acquired with different protocols. This limits the use of models trained on large annotated legacy datasets for a new dataset from a different domain, which is a recurring situation in clinical settings. In this study, we aim to answer the following central questions regarding domain adaptation in medical image analysis: given a fitted legacy model, 1) how much data from the new domain is required for a decent adaptation of the original network, and 2) what portion of the pre-trained model parameters should be retrained given a certain number of training samples from the new domain? To address these questions, we conducted extensive experiments on a white matter hyperintensity segmentation task. We trained a CNN on legacy MR images of the brain and evaluated the performance of the domain-adapted network on the same task with images from a different domain. We then compared the performance of the model to surrogate scenarios in which either the same trained network is used directly or a new network is trained from scratch on the new dataset. The domain-adapted network fine-tuned on only two training examples achieved a Dice score of 0.63, substantially outperforming a similar network trained from scratch on the same set of examples.




1 Introduction

Deep neural networks have been extensively used in medical image analysis and have outperformed conventional methods on specific tasks such as segmentation, classification and detection [litjens2017]. For instance, in brain MR analysis, convolutional neural networks (CNNs) have been shown to achieve outstanding performance on various tasks, including white matter hyperintensity (WMH) segmentation [ghafoorian2016location], tumor segmentation [kamnitsas2017efficient], microbleed detection [dou2016automatic], and lacune detection [Ghafoorian_2017]. Although many studies report excellent results on specific domains and image acquisition protocols, the generalizability of these models to test data with different distributions is often not investigated or evaluated. Therefore, to ensure the usability of trained models in real-world practice, which involves imaging data from various scanners and protocols, domain adaptation remains a valuable field of study. This becomes even more important for Magnetic Resonance Imaging (MRI), which exhibits large variations in soft tissue appearance and contrast among different protocols and settings.

Mathematically, a domain $D$ can be expressed by a feature space $\mathcal{X}$ and a marginal probability distribution $P(X)$, where $X = \{x_1, \dots, x_N\}$ with each $x_i \in \mathcal{X}$ [pan2010survey]. A supervised learning task on a specific domain $D = \{\mathcal{X}, P(X)\}$ consists of a label space $\mathcal{Y}$ and an objective predictive function $f(\cdot)$, denoted by $T = \{\mathcal{Y}, f(\cdot)\}$. The objective function $f(\cdot)$ can be learned from the training data, which consists of pairs $\{x_i, y_i\}$, where $x_i \in X$ and $y_i \in \mathcal{Y}$. After the training process, the learned model, denoted by $\tilde{f}(\cdot)$, is used to predict the label of a new instance $x$. Given a source domain $D_S$ with a learning task $T_S$ and a target domain $D_T$ with a learning task $T_T$, transfer learning is defined as the process of improving the learning of the target predictive function $f_T(\cdot)$ in $D_T$ using the information in $D_S$ and $T_S$, where $D_S \neq D_T$ or $T_S \neq T_T$ [pan2010survey]. We denote by $\tilde{f}_{ST}(\cdot)$ the predictive model initially trained on the source domain $D_S$ and domain-adapted to the target domain $D_T$.

In the medical image analysis literature, transfer classifiers such as adaptive SVM and transfer AdaBoost have been shown to outperform common supervised learning approaches in segmenting brain MRI when trained only on a small set of target domain images. In another study, a machine-learning-based sample weighting strategy was shown to be capable of handling multi-center chronic obstructive pulmonary disease images. Recently, several studies have also investigated transfer learning methodologies for deep neural networks applied to medical image analysis tasks. A number of studies used networks pre-trained on natural images to extract features, followed by another classifier such as a Support Vector Machine (SVM) or a random forest [esteva2017dermatologist]. Other studies [tajbakhsh2016convolutional, Shin2016] performed layer fine-tuning on the pre-trained networks to adapt the learned features to the target domain.

Given the hierarchical way in which CNNs learn features, we expect the first few layers to learn features for general, simple visual building blocks, such as edges, corners and simple blob-like structures, while the deeper layers learn more complicated, abstract, task-dependent features. In general, the ability to learn domain-dependent high-level representations is an advantage that enables CNNs to achieve strong recognition performance. However, it is not obvious how these qualities are preserved during the transfer learning process for domain adaptation. For example, it would be practically important to determine how much target domain data is required for domain adaptation with sufficient accuracy for a given task, or how many layers from a model fitted on the source domain can be effectively transferred to the target domain. More interestingly, given a number of available samples on the target domain, which layer types, and how many of them, can we afford to fine-tune? Moreover, there is a common scenario in which a large set of annotated legacy data is available, often collected in a time-consuming and costly process. Upgrades in scanners, acquisition protocols, etc. might, as we will show, make the direct application of models trained on the legacy data unsuccessful. To what extent these legacy data can contribute to a better analysis of new datasets, or vice versa, is another question worth investigating.

In this study, we aim to answer the questions discussed above, using transfer learning for the domain adaptation of models trained on legacy MRI data for brain WMH segmentation.

2 Materials and Method

2.1 Dataset

Radboud University Nijmegen Diffusion tensor and Magnetic resonance imaging Cohort (RUN DMC) is a longitudinal study of patients diagnosed with small vessel disease. The baseline scans acquired in 2006 consisted of fluid-attenuated inversion recovery (FLAIR) images with a voxel size of 1.0×1.2×5.0 mm and an inter-slice gap of 1.0 mm, scanned with a 1.5 T Siemens scanner. The follow-up scans in 2011, however, were acquired differently, with a different voxel size (thinner slices), including a slice gap of 0.5 mm. The follow-up scans demonstrate a higher contrast, as the partial volume effect is less of an issue with thinner slices. For each subject, we also used a 3D T1 magnetization-prepared rapid gradient-echo (MPRAGE) image, whose voxel size is the same in both datasets. Reference WMH annotations on both datasets were produced semi-automatically, by manually editing, wherever needed, the segmentations provided by a WMH segmentation method [ghafoorian2016automated].

The T1 images were linearly registered to the FLAIR scans, followed by brain extraction and bias-field correction. We then normalized the image intensities to the range [0, 1].
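As a minimal sketch of the final preprocessing step, the function below rescales intensities to [0, 1] with min-max normalization. The text does not say which voxels define the minimum and maximum, so computing them over a brain mask (and the name `minmax_normalize`) is an assumption for illustration; plain Python lists stand in for image volumes.

```python
def minmax_normalize(intensities, brain_mask=None):
    """Linearly rescale voxel intensities to [0, 1] (hypothetical helper).

    ASSUMPTION: min and max are computed over brain voxels only when a mask
    is given; the original text does not specify this.
    """
    if brain_mask is None:
        brain_mask = [True] * len(intensities)
    inside = [v for v, m in zip(intensities, brain_mask) if m]
    lo, hi = min(inside), max(inside)
    scale = hi - lo
    if scale == 0:
        # Constant image: map every voxel to 0 to avoid division by zero.
        return [0.0 for _ in intensities]
    # Clip so that voxels outside the mask also land in [0, 1].
    return [min(max((v - lo) / scale, 0.0), 1.0) for v in intensities]


flair = [120.0, 340.0, 560.0, 780.0, 15.0]
mask = [True, True, True, True, False]  # last voxel is background
normalized = minmax_normalize(flair, mask)
```

Here the brain voxels span 120 to 780, so they map onto [0, 1] and the darker background voxel is clipped to 0.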

In this study, we used 280 patient acquisitions with WMH annotations from the baseline as the source domain, and 159 scans from all patients who were rescanned at follow-up as the target domain. Table 1 shows the split into training, validation and test sets. The patient-level partitioning used on the baseline was also respected on the follow-up dataset, to prevent potential label leakage.

Domain          Train   Validation   Test
Source domain   200     30           50
Target domain   100     26           33
Table 1: Number of patients for the domain adaptation experiments.
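One way to respect the baseline partitioning on the follow-up data is to assign each patient to a split once and let every follow-up scan inherit its patient's split. The sketch below assumes integer patient IDs and a random shuffle with a fixed seed; both, and the function names, are hypothetical and not taken from the study.

```python
import random


def patient_level_split(patient_ids, sizes, seed=0):
    """Assign each patient to exactly one split (hypothetical helper).

    `sizes` maps split name -> number of patients, e.g. the source-domain
    counts of Table 1. Patients beyond the requested total stay unassigned.
    """
    ids = sorted(patient_ids)
    random.Random(seed).shuffle(ids)
    assignment, start = {}, 0
    for split, size in sizes.items():
        for pid in ids[start:start + size]:
            assignment[pid] = split
        start += size
    return assignment


def follow_up_split(baseline_assignment, follow_up_ids):
    """Give each follow-up scan the split of the same patient at baseline.

    ASSUMPTION: every follow-up patient also has a baseline assignment.
    This prevents a patient from being in training in one domain and in
    testing in the other.
    """
    return {pid: baseline_assignment[pid] for pid in follow_up_ids}


baseline = patient_level_split(range(280), {"train": 200, "validation": 30, "test": 50})
follow_up = follow_up_split(baseline, range(0, 280, 2))  # hypothetical rescanned subset
```

Because the follow-up assignment is derived rather than re-sampled, its split sizes follow from which patients were rescanned instead of being chosen independently.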

2.2 Sampling

We sampled 32×32 patches from both the FLAIR and T1 images to capture local neighborhoods around WMH and normal voxels. Each patch was assigned the label of its central voxel. More precisely, we randomly selected 25% of all voxels within the WMH masks and randomly selected the same number of negative samples from the normal-appearing voxels inside the brain mask. We augmented the dataset by flipping the patches along the axis. This procedure resulted in training and validation datasets of 1.2M and 150k samples on the baseline, and 1.75M and 200k samples on the follow-up.
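The balanced sampling described above can be sketched in plain Python as follows: a fraction of WMH voxels is drawn as positive patch centres, the same number of normal-appearing voxels as negatives, and each patch inherits the label of its central voxel. The function names, the zero padding at image borders, and flipping left-to-right are assumptions for illustration (the flip axis is not recoverable from the text); 2D slices are represented as lists of rows.

```python
import random


def sample_balanced_centres(wmh_voxels, normal_voxels, fraction=0.25, seed=0):
    """Pick patch centres: `fraction` of WMH voxels plus as many normal voxels.

    Both arguments are lists of voxel coordinates. Returns (coordinate, label)
    pairs, where the label is that of the central voxel (1 = WMH, 0 = normal).
    """
    rng = random.Random(seed)
    n_pos = int(round(fraction * len(wmh_voxels)))
    positives = rng.sample(wmh_voxels, n_pos)
    negatives = rng.sample(normal_voxels, min(n_pos, len(normal_voxels)))
    return [(v, 1) for v in positives] + [(v, 0) for v in negatives]


def extract_patch(image, r, c, size=32):
    """Cut a size x size patch around (r, c); outside the image is zero-padded (assumption)."""
    half = size // 2
    rows, cols = len(image), len(image[0])
    return [[image[i][j] if 0 <= i < rows and 0 <= j < cols else 0.0
             for j in range(c - half, c + half)]
            for i in range(r - half, r + half)]


def flip_patch(patch):
    """Mirror a 2D patch left-to-right as a simple flip augmentation (axis assumed)."""
    return [row[::-1] for row in patch]


wmh = [(i, 5) for i in range(40)]       # hypothetical WMH voxel coordinates
normal = [(i, 50) for i in range(400)]  # hypothetical normal-appearing voxels
centres = sample_balanced_centres(wmh, normal)  # 10 positives and 10 negatives
```

Drawing negatives to match the number of positives is what keeps the training set balanced despite WMH occupying only a small fraction of brain voxels.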

2.3 Network Architecture and Training

We stacked the FLAIR and T1 patches as input channels and used a 15-layer architecture consisting of 12 convolutional layers with 3×3 filters and 3 dense layers of 256, 128 and 2 neurons, followed by a final softmax layer. We avoided pooling layers, as they would introduce a shift-invariance property that is undesirable in segmentation tasks, where the spatial information of the features must be preserved. The network architecture is illustrated in Figure 1.


Figure 1: Architecture of the convolutional neural network used in our experiments. The $d$ shallowest layers are frozen and the remaining $n-d$ layers are fine-tuned. $n$ is the depth of the network, which was 15 in our experiments.
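To make the layer count and the effect of having no pooling concrete, the sketch below lists the 15 weight layers and computes the spatial size of the feature map reaching the first dense layer. The text does not state the convolution padding or stride; stride 1 and "valid" (unpadded) 3×3 convolutions are assumptions here, under which each layer trims one pixel from every border of a 32×32 patch.

```python
# 12 convolutional layers with 3x3 kernels, then dense layers of 256, 128 and 2
# units (the softmax has no weights and is not counted among the 15 layers).
LAYERS = [("conv", 3)] * 12 + [("dense", 256), ("dense", 128), ("dense", 2)]


def feature_map_size(patch_size=32, n_conv=12, kernel=3, padding="valid"):
    """Spatial size after the conv stack with stride 1 and no pooling.

    ASSUMPTION: padding is not given in the text. 'valid' shrinks the map by
    kernel - 1 per layer; 'same' keeps the size unchanged.
    """
    if padding == "same":
        return patch_size
    return patch_size - n_conv * (kernel - 1)


def receptive_field(n_conv=12, kernel=3):
    """Receptive field of one output position of stacked stride-1 convolutions."""
    return 1 + n_conv * (kernel - 1)
```

Under the valid-padding assumption, the conv stack turns a 32×32 patch into an 8×8 map, and each position of that map sees a 25×25 input window; with 'same' padding the map would stay 32×32.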

To tune the weights of the network, we used the Adam update rule [kingma2014adam] with a mini-batch size of 128 and a binary cross-entropy loss function. We used the Rectified Linear Unit (ReLU) as the non-linearity and the He method [he2015delving], which randomly initializes the weights by drawing them from a $\mathcal{N}(0, \sqrt{2/m})$ distribution, where $m$ is the number of inputs to a neuron. Activations of all layers were batch-normalized to speed up convergence [ioffe2015batch]. A decaying learning rate was used for the optimization process. To avoid over-fitting, we regularized our networks with a dropout rate of 0.3 as well as weight decay with $\lambda = 0.0001$. We trained our networks for a maximum of 100 epochs with an early stopping policy. For each experiment, we picked the model with the highest area under the curve on the validation set.
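The He initialization and Adam update named above are standard; the sketch below restates them in plain Python for a flat list of parameters. The learning rate of 1e-3 and the beta and epsilon values are the usual Adam defaults, not values from the text (the starting learning rate is not recoverable here), and applying weight decay as an L2 term added to the gradient is an assumption; decoupled weight decay (AdamW) is a different variant. The class name `Adam` here is a hypothetical stand-in, not a library API.

```python
import math
import random


def he_normal(fan_in, fan_out, seed=0):
    """He initialization: each weight drawn from N(0, sqrt(2 / fan_in))."""
    rng = random.Random(seed)
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


class Adam:
    """Minimal Adam for a flat list of parameters, with coupled L2 weight decay.

    ASSUMPTIONS: lr, betas and eps are common defaults, not taken from the
    text; weight decay is added to the gradient as an L2 penalty term.
    """

    def __init__(self, n_params, lr=1e-3, beta1=0.9, beta2=0.999,
                 eps=1e-8, weight_decay=1e-4):
        self.lr, self.b1, self.b2 = lr, beta1, beta2
        self.eps, self.wd = eps, weight_decay
        self.m = [0.0] * n_params  # first-moment (mean) estimates
        self.v = [0.0] * n_params  # second-moment (uncentered variance) estimates
        self.t = 0                 # step counter for bias correction

    def step(self, params, grads):
        """Return updated parameters after one Adam step."""
        self.t += 1
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            g = g + self.wd * p  # gradient of the L2 penalty (wd / 2) * p**2
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            updated.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return updated


weights = he_normal(fan_in=1000, fan_out=50)
opt = Adam(n_params=2, weight_decay=0.0)
new_params = opt.step([1.0, -1.0], [0.5, -2.0])
```

Thanks to bias correction, the first Adam step moves each parameter by roughly the learning rate in the direction opposite to its gradient's sign, regardless of the gradient's magnitude.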

We trained our networks with a patch-based approach. At segmentation time, however, we converted the dense layers to their equivalent convolutional counterparts to form a fully convolutional network (FCN). FCNs are much more efficient, as they avoid repeated computations on neighboring patches by feeding the whole image into the network. We keep the conceptual distinction between dense and convolutional layers at training time so that the experiments also remain applicable to classification problems (e.g., testing the benefits of fine-tuning the convolutional layers in addition to the dense layers). Patch-based training also allows class-specific data augmentation to handle domains with highly imbalanced class ratios (e.g., WMH segmentation).
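The dense-to-convolutional conversion rests on a simple identity: a dense layer applied to a flattened k×k window computes the same numbers as a k×k convolution evaluated at that window, so sliding the converted kernel over a larger image applies the dense layer to every window at once. The sketch below checks this for a single-channel input and a single layer; real networks have several channels and layers, so this is a simplified illustration with hypothetical function names.

```python
def dense(flat_input, weights, bias):
    """Fully connected layer: one output per weight row."""
    return [sum(w * x for w, x in zip(row, flat_input)) + b
            for row, b in zip(weights, bias)]


def dense_to_conv(weights, k):
    """Reshape each dense weight row (length k*k, row-major) into a k x k kernel."""
    return [[row[u * k:(u + 1) * k] for u in range(k)] for row in weights]


def conv2d_valid(image, kernels, bias):
    """'Valid' 2D cross-correlation of a single-channel image with k x k kernels.

    Returns out[o][i][j]: output channel o evaluated at window position (i, j).
    """
    k = len(kernels[0])
    H, W = len(image), len(image[0])
    return [[[sum(kern[u][v] * image[i + u][j + v]
                  for u in range(k) for v in range(k)) + b
              for j in range(W - k + 1)]
             for i in range(H - k + 1)]
            for kern, b in zip(kernels, bias)]


k = 3
weights = [[0.1 * (r + 1) * (c - 4) for c in range(k * k)] for r in range(2)]
bias = [0.5, -0.25]
image = [[float((i * 7 + j * 3) % 5) for j in range(6)] for i in range(5)]
maps = conv2d_valid(image, dense_to_conv(weights, k), bias)
```

Each entry `maps[o][i][j]` equals output `o` of the dense layer applied to the flattened 3×3 window whose top-left corner is `(i, j)`; in a multi-layer FCN, the shared convolutional work between overlapping windows is what makes whole-image inference cheaper than patch-by-patch evaluation.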

2.4 Domain Adaptation

To build the model $\tilde{f}_{ST}$, we transferred the learned weights from $\tilde{f}_S$, then froze the $d$ shallowest layers and fine-tuned the remaining $n-d$ deeper layers with the training data from $D_T$, where $n$ is the depth of the trained CNN. This is illustrated in Figure 1. We used the same optimization update rule, loss function, and regularization techniques as described in Section 2.3.
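The freeze/fine-tune cut-off can be expressed as a boolean mask over the $n = 15$ layers: the first $d$ layers keep their transferred weights, and only the deepest $n-d$ receive updates. The sketch below uses plain SGD on hypothetical per-layer parameter lists purely to show that frozen layers stay untouched; in a framework such as PyTorch the same effect is commonly obtained by setting `requires_grad = False` on the frozen layers' parameters. Layer kinds follow the architecture in Section 2.3 (12 convolutional, then 3 dense).

```python
N_LAYERS = 15                              # n: network depth
LAYER_KINDS = ["conv"] * 12 + ["dense"] * 3


def trainable_mask(d, n=N_LAYERS):
    """True for layers that are fine-tuned: the n - d deepest ones."""
    if not 0 <= d <= n:
        raise ValueError("d must be between 0 and n")
    return [i >= d for i in range(n)]


def describe(d):
    """Summarize which layer types are fine-tuned when d layers are frozen."""
    tuned = [kind for kind, t in zip(LAYER_KINDS, trainable_mask(d)) if t]
    return {"frozen": d, "fine_tuned": len(tuned),
            "fine_tuned_dense": tuned.count("dense"),
            "fine_tuned_conv": tuned.count("conv")}


def masked_sgd_step(layer_params, layer_grads, mask, lr=0.01):
    """Illustrative SGD step that leaves frozen layers unchanged."""
    return [[p - lr * g for p, g in zip(ps, gs)] if train else list(ps)
            for ps, gs, train in zip(layer_params, layer_grads, mask)]


# Freezing d = 12 layers fine-tunes only the three dense layers.
summary = describe(12)
```

With $d = 12$ only the dense layers adapt, which matches the small-data regime discussed in Section 4; lowering $d$ progressively unfreezes the deepest convolutional layers.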

2.5 Experiments

On the WMH segmentation domain, we investigated and compared three different scenarios: 1) training a model on the source domain and directly applying it to the target domain; 2) training networks on the target domain data from scratch; and 3) transferring the model learned on the source domain to the target domain with fine-tuning. To identify the target domain dataset sizes for which transfer learning is most useful, the second and third scenarios were explored with training set sizes of 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 25, 50 and 100 cases. We further expanded the third scenario by investigating the best freezing/tuning cut-off for each of these target domain training set sizes. We used the same network architecture and training procedure across the different experiments. The reported metric for segmentation quality is the Dice score.
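The Dice score compares a predicted mask $P$ with a reference mask $R$ as $2|P \cap R| / (|P| + |R|)$, ranging from 0 (no overlap) to 1 (identical masks). A minimal sketch on flat binary masks follows; returning 1.0 when both masks are empty is a convention assumed here, and the text does not say whether scores are computed per scan or pooled over the test set.

```python
def dice_score(pred, ref):
    """Dice = 2|P ∩ R| / (|P| + |R|) for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(ref):
        raise ValueError("masks must have the same number of voxels")
    inter = sum(1 for p, r in zip(pred, ref) if p and r)
    total = sum(1 for p in pred if p) + sum(1 for r in ref if r)
    if total == 0:
        return 1.0  # ASSUMPTION: two empty masks count as perfect agreement
    return 2.0 * inter / total


# One shared voxel, two positives in each mask: 2 * 1 / (2 + 2) = 0.5
score = dice_score([1, 1, 0, 0], [1, 0, 1, 0])
```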

Figure 2: (a) Comparison of Dice scores on the target domain with and without transfer learning. A logarithmic scale is used on the axis. (b) Given a deep CNN with $n$ = 15 layers, transfer learning was performed by freezing the first $d$ layers and fine-tuning the last $n-d$ layers. The Dice scores on the test set are shown as a color-coded heatmap, with the number of fine-tuned layers on the horizontal axis and the target domain training set size on the vertical axis.

3 Results

The model trained on images from the source domain ($\tilde{f}_S$) achieved a Dice score of 0.76. The same model, without fine-tuning, failed on the target domain with a Dice score of 0.005. Figure 2(a) compares the Dice scores obtained with three domain-adapted models to those of a network trained from scratch, for different target training set sizes. Figure 2(b) shows the target domain test set Dice scores as a function of the target domain training set size and the number of abstract layers that were fine-tuned. Figure 3 compares qualitative WMH segmentation results of several different models on a single sample slice.

Figure 3: Examples of brain WMH segmentations on MRI. (a) Axial T1-weighted image. (b) FLAIR image. (c-f) FLAIR images with segmented WMH labels: (c) reference WMH (green); (d) WMH (red) from a domain-adapted model ($\tilde{f}_{ST}$) fine-tuned on five target training samples; (e) WMH (yellow) from a model trained from scratch ($\tilde{f}_T$) on 100 target training samples; (f) WMH (orange) from a model trained from scratch ($\tilde{f}_T$) on 5 target training samples.

4 Discussion and Conclusions

We observed that while $\tilde{f}_S$ demonstrated decent performance on $D_S$, it totally failed on $D_T$. Although the same set of learned representations is expected to be useful for both domains, as the two tasks are similar, the failure comes as no surprise, because the distributions of the responses to these features differ. The comparisons in Figure 2(a) show that, given only a small set of training examples on $D_T$, the domain-adapted model substantially outperforms the model trained from scratch on the same amount of training data. For instance, given only two training images, $\tilde{f}_{ST}$ achieved a Dice score of 0.63 on a test set of 33 target domain images, while $\tilde{f}_T$ reached a Dice score of only 0.15. As Figure 2(b) suggests, with only a few training cases available, the best results are achieved by fine-tuning only the last dense layers; otherwise, the enormous number of parameters relative to the training sample size results in over-fitting. As more training data becomes available, it makes more sense to also fine-tune the shallower representations (e.g., the last convolutional layers). It is also interesting to note that tuning the first few convolutional layers is rarely useful, considering their domain-independent characteristics.