Improving Calibration and Out-of-Distribution Detection in Medical Image Segmentation with Convolutional Neural Networks

04/12/2020 · by Davood Karimi, et al.

Convolutional Neural Networks (CNNs) are powerful medical image segmentation models. In this study, we address some of the main unresolved issues regarding these models. Specifically, training of these models on small medical image datasets is still challenging, with many studies promoting techniques such as transfer learning. Moreover, these models are infamous for producing over-confident predictions and for failing silently when presented with out-of-distribution (OOD) data at test time. In this paper, we advocate for training on heterogeneous data, i.e., training a single model on several different datasets, spanning several different organs of interest and different imaging modalities. We show not only that a single CNN learns to automatically recognize the context and accurately segment the organ of interest in each context, but also that such a joint model often has more accurate and better-calibrated predictions than dedicated models trained separately on each dataset. We also show that training on heterogeneous data can outperform transfer learning. For detecting OOD data, we propose a method based on spectral analysis of CNN feature maps. We show that different datasets, representing different imaging modalities and/or different organs of interest, have distinct spectral signatures, which can be used to identify whether or not a test image is similar to the images used to train a model. We show that this approach is far more accurate than OOD detection based on prediction uncertainty. The methods proposed in this paper contribute significantly to improving the accuracy and reliability of CNN-based medical image segmentation models.


I Introduction

Medical image segmentation is an essential component of many medical image analysis and image-guided intervention pipelines. Compared with manual segmentation by an expert, computerized automatic segmentation methods have the potential to improve the speed and reproducibility of the segmentations. Classical automatic medical image segmentation methods include region growing, level-sets, and atlas-based techniques. More recently, deep learning models, and in particular convolutional neural networks (CNNs), have been shown to be excellent tools for this task.

Many recent studies have shown that CNN-based methods outperform the classical methods on various medical image segmentation tasks, often by significant margins. Because of the success of CNN-based models, various aspects of their design and training have been investigated in the past few years. Many of these studies have focused on aspects such as network architecture and loss function. However, it has been shown that factors such as more elaborate network architectures often only marginally improve the performance of standard CNN-based medical image segmentation methods [1].

There are two main unresolved issues with regard to the application of CNNs for medical image segmentation. The first issue has to do with the training procedures and training data. Specifically, the number of manually-labeled images that are available for training is typically very small compared with many non-medical applications. This is because the number of images is small to begin with, and accurate manual annotation is costly because it depends on domain expertise. In recent years, this challenge has led to a surge of interest in techniques such as transfer learning [2], unsupervised learning [3], and learning from inaccurate and computer-generated annotations [4].

The second outstanding issue is a lack of understanding of the reliability and failure modes of these models. Deep learning models, in general, are known to produce over-confident predictions, even when the predictions are completely wrong [5]. In other words, there is little correlation between the confidence of a deep learning model in its predictions and how accurate the predictions actually are. Deep learning models also produce confident predictions on out-of-distribution (OOD) data, i.e., when the test data is from an entirely different distribution than the training data [6, 7]. Needless to say, there is no performance guarantee on OOD data. In fact, in general the model predictions on OOD data are not expected to be better than random assignment.

In order to improve the accuracy and reliability of CNN-based medical image segmentation models for real-world clinical usage, effective solutions are needed for the above-mentioned challenges. In particular, we need methods that can train accurate and well-calibrated medical image segmentation models from limited data. Furthermore, we need methods to inform us when these models fail. The goal of this paper is to make significant contributions in addressing these challenges.

II Related works

II-A Training procedures for CNN-based medical image segmentation models

Large labeled datasets are considered an essential requirement for training of modern deep learning models [8]. Since such datasets are difficult or impossible to come by in medical image segmentation, a range of strategies have been proposed to tackle this limitation. Here, we briefly review the most important classes of these methods.

One strategy is transfer learning [9], whereby the model is first trained on a larger corpus of data from other domains/tasks and then fine-tuned for the intended task. Transfer learning has been reported to improve the performance of CNN-based models on many medical image segmentation tasks [2, 10]. A limitation of transfer learning is that most of the large public image datasets include only 2D images, whereas most medical images are 3D.

Semi-supervised and weakly-supervised methods constitute a large and diverse body of techniques [11, 3]. In brief, these methods aim to utilize a mix of labeled and unlabeled data, or data that have not been labeled in detail. These methods have also been used in deep learning-based medical image segmentation with relative success, as in [12, 13]. One possibility is to use other, less accurate, automatic methods to generate approximate segmentations on large corpora of medical images and use those to train a more accurate CNN-based segmentation model [14, 15]. In general, the applicability and success of semi-supervised methods on a specific task are not certain. It has recently been argued, and experimentally demonstrated, that the gains reported by many semi-supervised methods may need to be reassessed [16].

In general, it is much cheaper and faster to obtain rough segmentations, either manual or computer-generated, on large training datasets. However, rather than treating such approximate segmentation labels as ground truth, as done in [14], one can use more intelligent methods. A comprehensive recent review of deep learning with noisy labels with a focus on medical image analysis can be found in [4]. Several recent studies have reported successful applications of such methods for medical image segmentation [17, 18, 4].

II-B Model calibration and uncertainty estimation

All machine learning models are bound to make wrong predictions on a fraction of test data. Nonetheless, one would like the confidence of predictions to be proportional to the probability of being correct. Consider a test set $\{(x_i, y_i)\}_{i=1}^{N}$ and suppose that for sample $x_i$ the model predicts the class $\hat{y}_i$ with a probability $\hat{p}_i$. In the ideal scenario with perfect confidence calibration, $P(\hat{y}_i = y_i \mid \hat{p}_i = p) = p$ for all $p \in [0, 1]$ [19].

Standard deep learning models have been shown to be poorly calibrated [5]. This is concerning for safety-critical applications such as medicine. A range of methods have been proposed for improving the calibration of deep learning models. For example, it has been shown that calibration can be improved by using a proper scoring rule as the loss function [20, 5], and by using weight decay and avoiding batch normalization [5]. Training on adversarial examples [21] has also been shown to improve model calibration [20]. Some studies have used Platt scaling to improve model calibration [22, 5]. In [23], for instance, after the deep learning network is trained, a model with scalar parameters $a$ and $b$ is trained on the logit vector $z$ of the trained network to obtain a more calibrated prediction $\hat{q} = \sigma(az + b)$. Another study proposed to train a separate model to map the uncalibrated output of a CNN to calibrated probabilities [24]. For this purpose, they used a Bayesian neural network, which they trained after training the main deep learning model.
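To make the post-hoc Platt-scaling idea above concrete, the following is a minimal sketch (not the implementation used in [23]) that fits two scalar parameters on held-out validation logits for a binary foreground/background problem; the $\sigma(az+b)$ parameterization and the optimizer choice are assumptions.

```python
import numpy as np
from scipy.optimize import minimize

def platt_scale(logits_val, labels_val):
    """Fit scalar Platt-scaling parameters (a, b) on held-out logits.

    Minimal sketch of post-hoc calibration for a binary case; the cited
    works may use a different parameterization (e.g., temperature scaling).
    """
    def nll(params):
        a, b = params
        p = 1.0 / (1.0 + np.exp(-(a * logits_val + b)))   # sigmoid(a*z + b)
        p = np.clip(p, 1e-7, 1 - 1e-7)
        # negative log-likelihood of the held-out labels
        return -np.mean(labels_val * np.log(p) + (1 - labels_val) * np.log(1 - p))

    res = minimize(nll, x0=np.array([1.0, 0.0]), method="Nelder-Mead")
    return res.x  # (a, b), applied to test-time logits

# a, b = platt_scale(val_logits, val_labels)
# calibrated_prob = 1.0 / (1.0 + np.exp(-(a * test_logits + b)))
```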

Prediction uncertainty has also received some attention in medical image segmentation studies. Some studies have proposed methods to estimate the uncertainty [25] or to use the prediction uncertainty to improve the segmentation accuracy [26]. However, little attention has been paid to methods for improving the calibration of CNN-based segmentation models. An example of the latter is the work of [27], where the authors use model ensembles to arrive at better-calibrated models. That study trained an ensemble of CNNs with random initialization of the network weights and random shuffling of the training data, and showed that the ensemble average is better calibrated than the prediction of a single model. However, their proposed method requires training and maintaining as many as 50 separate models, which is quite inefficient for many clinical applications. Incidentally, the same study found that batch normalization improved model calibration, which is the opposite of the observation reported in [5]. This discrepancy suggests that methods proposed and tested on large-scale natural image datasets should be reevaluated for medical image segmentation.

II-C Detecting out-of-distribution data and model failure

Another important problem in deep learning is detection of OOD data at test time. Suppose that the training data come from a distribution $p_{\mathrm{train}}$. A central assumption of every machine learning method is that the test data come from the same distribution. When a data sample comes from an entirely different distribution than $p_{\mathrm{train}}$, there is no performance guarantee. Ideally, the model should include a mechanism to detect OOD data samples and issue a warning. However, this has proven to be challenging with deep learning models because of the black-box nature of these models and the highly complex mapping between their input and output.

It has been shown that advancements in network architecture design have not improved the robustness of deep learning models to OOD data [28]. Some studies have proposed methods that increase the robustness of deep learning models to OOD data. As an example, one study showed that some simple techniques such as histogram equalization and Adversarial Logit Pairing [29] may improve robustness to perturbed and corrupted data. However, they noted that methods that work well on specific datasets may fail on other datasets. More importantly, these methods usually focus on in-distribution data that have been slightly perturbed, on which the model performance can be sub-optimal, and do not address the OOD samples, on which the model fails completely. Robustness against true OOD data has no meaning, and such data should be detected and reported/rejected.

Several studies have proposed methods for detecting OOD data in deep learning. For image classification, one study proposed training Gaussian discriminant models on the penultimate layer of the network and using the Mahalanobis distance to detect OOD data [6]. Another work suggested using the distribution of features in different layers of a deep learning model for OOD detection [30]. The intuition behind that method is that if a test example is in-distribution, the training examples that are most similar to it, in terms of feature similarity, are consistent across layers. Such methods may be effective for natural image classification, where the feature vectors have at most a few hundred to roughly a thousand dimensions and the number of training images can be in the millions. However, they cannot be used for 3D medical image segmentation, where feature maps are much larger and typically only tens of training images are available. A number of studies have proposed to detect OOD data based on measures of prediction uncertainty, usually quantified as a function of the entropy of the predicted class probabilities [31, 7]. However, such methods have been shown to have low accuracy in semantic segmentation applications [32].

Compared with image classification, semantic segmentation has received much less attention in terms of OOD detection. A recent study found that methods proposed for OOD detection in image classification do not translate well to image segmentation tasks [33]. For semantic segmentation of street-view images, one study proposed a dedicated neural network to detect OOD data [32]. Their approach aims to classify an image as in-distribution or OOD using a very large “background dataset” to represent the distribution of the variety of visual scenes outside of the training data distribution; the authors use the ILSVRC dataset as the background dataset. However, it is difficult to obtain or even define such a background set, especially in medical imaging. One study used prediction uncertainty measures to identify OOD data in medical image segmentation [27]. However, they evaluated this method on data that were hard to segment, not on true OOD data. As we show in Section IV of this paper, methods based on prediction uncertainty cannot accurately detect OOD data.

We should mention in passing that a related topic to OOD detection is the topic of adversarial examples [21]. These are examples that are intentionally crafted to fool a model into making wrong predictions. Adversarial examples may be important in some medical applications, but they are beyond the scope of this paper, which focuses on natural OOD data.

II-D Contributions of this work

In this paper, we address the critical problems discussed above and make the following significant contributions.

  • We propose “training on heterogeneous data” as a new approach for training CNNs for medical image segmentation. Specifically, rather than training a CNN to segment a single organ in a single imaging modality (e.g., prostate in MRI), we propose training a model that segments several different organs in several different imaging modalities.

  • We report, for the first time, the unexpected observation that training on heterogeneous data does not need any changes to the network architecture or training procedures. Surprisingly, the network can learn to automatically recognize the context (i.e., the imaging modality and organ) and accurately segment the organ of interest without any extra input or supervision.

  • We show, for the first time, that training on heterogeneous data can lead to segmentation accuracy on par with or even better than competing methods such as transfer learning. We further show that training on heterogeneous data also improves the model’s confidence calibration.

  • For detecting OOD test data, we devise a novel method based on spectral analysis of the CNN feature maps. We show that this method can detect OOD test data much more accurately than recently proposed methods that are based on prediction uncertainty.

III Materials and Methods

III-A Data

A large number of datasets were used in this study. We provide a summary of the information about these datasets in Table I. Unless otherwise stated, we used 70% of each dataset for training and validation and 30% for test. All Computed Tomography (CT) images were normalized with a simple linear mapping that mapped a fixed range of Hounsfield Unit values to a fixed intensity range. All Magnetic Resonance (MR) images were normalized by dividing the image by the standard deviation of its voxel intensities.
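As a minimal illustration of these normalization steps, the sketch below implements the linear CT mapping and the MR standard-deviation scaling; the Hounsfield window bounds (hu_min, hu_max) and the [0, 1] target range are placeholders, since the exact values are not given in this excerpt.

```python
import numpy as np

def normalize_ct(volume_hu, hu_min, hu_max):
    """Linearly map a fixed Hounsfield-unit window to a fixed intensity range.

    The exact window used in the paper is not stated in this excerpt;
    hu_min / hu_max are placeholders to be set by the user.
    """
    vol = np.clip(volume_hu.astype(np.float32), hu_min, hu_max)
    return (vol - hu_min) / (hu_max - hu_min)  # maps to [0, 1] (assumed target range)

def normalize_mri(volume):
    """Divide the MR image by the standard deviation of its voxel intensities."""
    return volume.astype(np.float32) / np.std(volume)
```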

Name | Modality | Organ | Data size | Source
CP- younger fetus | T2 MRI | brain cortical plate | 27 | In-house (Boston Children’s Hospital)
CP- older fetus | T2 MRI | brain cortical plate | 15 | In-house (Boston Children’s Hospital)
CP- newborn | T2 MRI | brain cortical plate | 400 | [34]
Liver-CT | CT | liver | 19 | [35]
Liver-MRI-SPIR | MRI | liver | 20 | [36]
Liver-MRI-DUAL-in | MRI | liver | 20 | [36]
Liver-MRI-DUAL-out | MRI | liver | 20 | [36]
Heart | MRI | left atrium | 20 | https://decathlon-10.grand-challenge.org/
Prostate | MRI | prostate | 32 | https://decathlon-10.grand-challenge.org/
Pancreas | CT | pancreas | 281 | https://decathlon-10.grand-challenge.org/
Hippocampus | MRI | hippocampus | 260 | https://decathlon-10.grand-challenge.org/
Spleen | CT | spleen | 41 | https://decathlon-10.grand-challenge.org/
TABLE I: Summary of the information on the datasets used in this study. The first column shows the names that we use to refer to each dataset throughout this paper.

III-B Network architecture and training details

We used a network similar to the 3D U-Net [37], which we modified by adding residual connections with short and long skip connections. We set the number of features in the first stage of the encoder to 14, which was the largest possible given our GPU memory. The model worked on fixed-size 3D image blocks. During training, we sampled blocks from random locations in the training images. On a test image, a sliding-window approach with a 24-voxel overlap between adjacent blocks was used to process the image. We used the negative of the Dice Similarity Coefficient (DSC) between the predicted and target probability maps as the loss function and Adam [38] as the optimization method. We used a fixed initial learning rate, which was multiplied by 0.90 after every 2000 training iterations if the loss did not decrease. If the loss did not decrease for two consecutive evaluations, we stopped the training and declared convergence. This typically occurred after 100-150 training epochs through all training images.
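The loss described above, the negative of the soft DSC between predicted and target probability maps, can be sketched as follows; the tensor layout, the epsilon smoothing term, and the use of PyTorch are assumptions made only for illustration.

```python
import torch

def soft_dice_loss(pred_probs, target, eps=1e-6):
    """Negative soft Dice (DSC) between predicted and target probability maps.

    Minimal sketch of the loss described above; pred_probs and target are
    assumed to be tensors of shape (batch, 1, D, H, W) with values in [0, 1].
    """
    dims = (1, 2, 3, 4)
    intersection = torch.sum(pred_probs * target, dim=dims)
    denom = torch.sum(pred_probs, dim=dims) + torch.sum(target, dim=dims)
    dsc = (2.0 * intersection + eps) / (denom + eps)
    return -dsc.mean()  # negative DSC, minimized during training

# optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
# (the initial learning rate value is not stated in this excerpt)
```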

Since the focus of the study is on the training data, model calibration, and OOD detection, we used the same settings mentioned above in all experiments. Admittedly, this may reduce the model accuracy slightly, because one could always tune the model size, architecture, or learning rate with cross-validation to achieve slightly better results on a specific dataset. Nonetheless, using the same settings allowed us to remove the effect of these confounding factors and focus on the questions at the center of our study.

III-C Training on heterogeneous data

The common practice in training CNNs for medical image segmentation is to train a CNN to segment a single organ in a single imaging modality. As we explained above, to cope with the small size of medical image datasets, methods such as transfer learning have also become common. Here, on the other hand, we advocate for a training paradigm that we call training on heterogeneous data. Simply, we train a single model on a mix of training datasets that can come from different imaging modalities with different organs of interest to be segmented, such as the datasets shown in Table I.

We do not change the network architecture or use additional inputs to inform the model of the image modality or the organ that needs to be segmented. Furthermore, we use the same loss function and optimization procedure. In other words, nothing changes compared with training on a single dataset. The only point worth mentioning is the frequency of sampling from different training datasets when their sizes are very different. We sample from each dataset with a probability proportional to the square root of the dataset size, $p_j \propto \sqrt{n_j}$. This way, if for example we train on two datasets with 10 and 100 images each, the probability of sampling an image from these two datasets will be 0.24 and 0.76, respectively. In our experience, this strategy strikes a good balance in terms of the test performance of the model on different datasets when the training dataset sizes are very different.
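A minimal sketch of this sampling rule follows; the function name and the use of NumPy are illustrative, but the resulting probabilities reproduce the 0.24/0.76 example above.

```python
import numpy as np

def dataset_sampling_probs(dataset_sizes):
    """Probability of drawing the next training block from each dataset.

    Probabilities are proportional to the square root of the dataset size,
    so smaller datasets are sampled more often than their share of images
    (sketch of the sampling rule described above).
    """
    w = np.sqrt(np.asarray(dataset_sizes, dtype=np.float64))
    return w / w.sum()

probs = dataset_sampling_probs([10, 100])
print(probs)  # approx [0.24, 0.76], matching the example in the text

rng = np.random.default_rng(0)
dataset_idx = rng.choice(len(probs), p=probs)  # pick a dataset, then sample a block from it
```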

III-D OOD detection based on the spectral signature of feature maps

We propose a novel method for detecting OOD data samples that are input to a CNN-based medical image segmentation model. As we mentioned above, such models produce over-confident predictions even when a test sample is entirely different from the training data. For example, a network trained on the Liver-CT dataset produces confident (but obviously completely wrong) segmentations on the brain cortical plate. As we show in Section IV, even on such seemingly simple cases, previously-proposed methods based on prediction uncertainty are unable to accurately detect model failure.

Due to the large size of 3D medical images and their computed features, a method based on analyzing the feature maps or the predicted segmentation map in their native space is unwieldy and likely to be ineffective. Instead, we propose computing the spectrum of the feature maps, which we define as the vector of singular values computed using a singular value decomposition (SVD). Consider a test image $X$ and denote the feature map computed for this image at a certain stage (i.e., layer) of the network with $F \in \mathbb{R}^{h \times w \times d \times c}$, where $h$, $w$, and $d$ denote the spatial dimensions of the feature map and $c$ is the number of features. We reshape $F$ into a matrix of size $hwd \times c$ and compute its SVD, $F = U S V^{T}$, where $U$ and $V$ are orthonormal matrices and the diagonal matrix $S$ contains the singular values of $F$, which are referred to as its spectrum [39]. The values of the singular values depend on the magnitude of the feature values, which in turn depend on the image voxel intensities. Moreover, the spectrum has a very large dynamic range. To eliminate these effects, we take the logarithm of the spectrum and then normalize it to have unit norm. We refer to the normalized logarithmic spectrum computed as explained above as “the spectral signature” of the feature maps, and we denote it by $s$ in the following.
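The following sketch shows how such a spectral signature could be computed for one feature map; the choice of the Euclidean norm for the final normalization and the small constant added before the logarithm are assumptions.

```python
import numpy as np

def spectral_signature(feature_map):
    """Compute the normalized logarithmic spectrum of a CNN feature map.

    feature_map: array of shape (h, w, d, c) from one layer of the network.
    Sketch of the signature described above; the normalization is assumed
    to make the log-spectrum have unit Euclidean norm.
    """
    h, w, d, c = feature_map.shape
    F = feature_map.reshape(h * w * d, c)        # reshape to an (hwd x c) matrix
    s = np.linalg.svd(F, compute_uv=False)       # singular values (the spectrum)
    log_s = np.log(s + 1e-12)                    # reduce the large dynamic range
    return log_s / np.linalg.norm(log_s)         # unit-norm spectral signature
```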

In Figure 2(a), we show examples of what these signatures look like. This figure is for a model trained on several of the datasets in Table I, including the CP- younger fetus and Liver-MRI-SPIR datasets but not the Pancreas and Hippocampus datasets. The figure shows example spectral signatures of training images from these four datasets. Clearly, each dataset has a distinct spectral signature. Note that this model segments CP- younger fetus and Liver-MRI-SPIR accurately, but fails completely on Pancreas and Hippocampus, which were not seen during training. Nonetheless, as we show in Section IV, methods based on uncertainty measures cannot detect these as OOD.

Fig. 2: A demonstration of our proposed OOD detection method in action. These figures were generated from a model that was trained on eight datasets (see Table I): CP- younger fetus, CP- older fetus, Prostate, Heart, Liver-CT, Liver-MRI-SPIR, Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out. TOP: Spectral signatures of feature maps for four different datasets. Two of these datasets (i.e., CP- older fetus and Liver-MRI-DUAL-In) are from the distribution of the training images, while the other two (i.e., Hippocampus and Pancreas) are OOD. We have shown the spectra for only four datasets in order to avoid clutter. BOTTOM: Histograms of OODM values (computed using Eq. (1)) for training images, in-distribution test images (i.e., test images from the same eight datasets as the training images), and OOD test images. OOD test images are from the Pancreas, Hippocampus, and Spleen datasets. The value of the threshold $\tau$ has been marked with the vertical black line.

We suggest detecting OOD data based on the dissimilarity of the spectral signatures. For all $n$ images in the training data, we compute their spectral signatures and save them as the rows of a matrix $S_{\mathrm{train}} \in \mathbb{R}^{n \times c}$. Given a test image $X$, we compute its spectral signature $s(X)$. We then compare $s(X)$ with the spectral signatures of the training data by computing an Out-Of-Distribution Measure (OODM), which we define as:

$\mathrm{OODM}(X) = \min_{i} \, \lVert s(X) - S_{\mathrm{train},i} \rVert_2 \qquad (1)$

In other words, $\mathrm{OODM}(X)$ is the Euclidean distance of the spectral signature of $X$ to its nearest neighbor in the training set.

We anticipate that for test images coming from the distribution of the training data, OODM will be smaller than for images coming from other distributions. We declare a test image to be OOD if $\mathrm{OODM}(X) > \tau$. The threshold $\tau$ is determined using the training data. Specifically, on the training data we compute the vector of OODM values using Eq. (1) on a leave-one-out basis, i.e., by comparing the spectrum of each training image with the spectra of all other training images. We then compute $\tau$ from the mean $\mu_{\mathrm{OODM}}$ and standard deviation $\sigma_{\mathrm{OODM}}$ of these values as:

$\tau = \mu_{\mathrm{OODM}} + k\,\sigma_{\mathrm{OODM}} \qquad (2)$

where $k$ is a constant that we fix for computing the detection accuracy.
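A minimal sketch of the OODM computation and the leave-one-out threshold follows; as noted above, the exact form of Eq. (2) and the value of $k$ are not given in this excerpt, so the mean-plus-$k$-standard-deviations form and the default $k$ below are assumptions.

```python
import numpy as np

def oodm(test_signature, train_signatures):
    """Euclidean distance of a test spectral signature to its nearest
    neighbor among the training signatures (Eq. (1))."""
    d = np.linalg.norm(train_signatures - test_signature[None, :], axis=1)
    return d.min()

def oodm_threshold(train_signatures, k=3.0):
    """Leave-one-out OODM values on the training set, and a threshold of the
    assumed form mean + k * std (Eq. (2)); the value of k used in the paper
    is not stated in this excerpt."""
    n = train_signatures.shape[0]
    vals = []
    for i in range(n):
        others = np.delete(train_signatures, i, axis=0)
        vals.append(oodm(train_signatures[i], others))
    vals = np.asarray(vals)
    return vals.mean() + k * vals.std()

# is_ood = oodm(sig_test, train_sigs) > oodm_threshold(train_sigs)
```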

Deep learning models compute a large number of feature maps from an input image. In practice, one can compute the spectral signature on any or all feature maps. However, we found that using the deepest feature maps leads to better results for the purpose of OOD detection. This is in agreement with the known fact that deeper layers provide more disentangled manifolds [40, 41]. In this study, we only worked with the very last (i.e., deepest) feature maps. In our network, the number of channels in this feature map was 14, which was therefore the length of the spectral signatures in this work.

Figure 2(b) shows an example of histograms of OODM values for the training data, in-distribution test data, and OOD test data. The histograms show that the proposed OODM easily separates in-distribution from OOD data in this experiment.

In this study, we compared our proposed OOD detection method with a common strategy based on prediction uncertainty [31]. Specifically, we trained our models using dropout (with a rate of 10%) after all convolutional layers. At test time, we drew random dropout masks, computed a segmentation probability map for each, and averaged these maps. We used the entropy of this mean probability map as an estimated voxel-wise map of prediction uncertainty. To estimate an image-wise uncertainty, as suggested in [27], we used the average of the voxel-wise uncertainty over the predicted foreground. Similar to our approach with OODM explained above, we computed a threshold analogous to Eq. (2) on the training set, which was used to determine whether a test image was OOD.
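The uncertainty baseline can be sketched as follows; the number of dropout samples, the sigmoid output, and the 0.5 foreground threshold are assumptions not stated in the text.

```python
import torch

def mc_dropout_image_uncertainty(model, image, n_samples=20):
    """Image-level uncertainty from Monte Carlo dropout, as in the baseline
    described above. n_samples and the 0.5 foreground threshold are assumptions;
    the excerpt does not state the number of dropout samples used."""
    model.train()  # keep dropout active at test time
    with torch.no_grad():
        # model(image) is assumed to return foreground logits
        probs = torch.stack([torch.sigmoid(model(image)) for _ in range(n_samples)])
    p_mean = probs.mean(dim=0)                                    # mean foreground probability map
    entropy = -(p_mean * torch.log(p_mean + 1e-12)
                + (1 - p_mean) * torch.log(1 - p_mean + 1e-12))   # voxel-wise binary entropy
    foreground = p_mean > 0.5                                     # predicted foreground mask
    return entropy[foreground].mean()                             # average uncertainty over foreground
```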

III-E Evaluation metrics

We quantify segmentation accuracy using DSC, the 95th percentile of the Hausdorff Distance (HD95), and the Average Symmetric Surface Distance (ASSD). To assess model calibration, we compute the Expected Calibration Error (ECE) and Maximum Calibration Error (MCE), as proposed in [42]. For the OOD detection experiments, we report accuracy, sensitivity, and specificity. We also compute the area under the receiver-operating characteristic curve (AUC) by varying the detection threshold $\tau$.
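For reference, a minimal sketch of the ECE and MCE computation in the spirit of [42] is given below; the number of bins is an assumption, as it is not stated in this excerpt.

```python
import numpy as np

def calibration_errors(confidences, correct, n_bins=10):
    """Expected and Maximum Calibration Error over equally spaced confidence bins.

    `confidences` are predicted probabilities of the predicted class and
    `correct` is a 0/1 array of the same shape; n_bins=10 is an assumption.
    """
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece, mce = 0.0, 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap   # weighted by the fraction of samples in the bin
            mce = max(mce, gap)          # Maximum Calibration Error
    return ece, mce
```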

IV Results and Discussion

IV-A Feasibility and benefits of training on heterogeneous data

As we mentioned above, unlike previous studies, we propose training a single model to segment different organs in different imaging modalities. To show that this is a viable approach, we trained a model on seven datasets spanning six different organs in MRI and CT images. We then trained seven separate models, one on each of these seven datasets. We show a comparison of the test performance of these two training strategies in Table II.

Training method | Data | DSC | HD95 (mm) | ASSD (mm) | ECE | MCE

Training a separate model for each dataset:
CP- younger fetus
CP- older fetus
Heart
Hippocampus
Prostate
Liver-CT
Liver-MRI-SPIR

Training a single model for all datasets:
CP- younger fetus
CP- older fetus
Heart
Hippocampus
Prostate
Liver-CT
Liver-MRI-SPIR
TABLE II: Results of an experiment to compare training a single model on several datasets with training dedicated models separately for each dataset. This experiment included seven different datasets representing six different organs in MRI and CT.

The results are very interesting. They show that training a single model on several different datasets can achieve results that are as good as or better than those of dedicated models trained separately on each dataset. In terms of segmentation accuracy, the joint model trained on heterogeneous data was overall better than the models dedicated to a single dataset. When a dedicated model was better than the joint model, the difference was small, typically within 10%. On the other hand, on some datasets the joint model improved the segmentation accuracy by large margins. For example, on the Prostate and Liver-MRI-SPIR datasets, the joint model reduced HD95 and ASSD by factors of 1.55-2.64, which is quite substantial. The joint model was also better calibrated than the dedicated models on 5 out of 7 datasets. Only on the Hippocampus dataset was the dedicated model noticeably better than the joint model. It is interesting to note that the Hippocampus dataset included 260 images, compared with 15-32 images in each of the other six datasets used in this experiment. This indicates the influence of dataset size on the potential benefits of training on heterogeneous data.

In retrospect, the fact that a single model can automatically recognize the context and accurately segment the organ of interest is surprising. To the best of our knowledge, this capability of CNN-based medical image segmentation methods has not been reported in any previous study. Figure 3 shows a slice of one test image from each of the seven datasets used in this experiment and the segmentation produced by the joint model trained on all seven datasets. The model accurately segments all seven datasets. Note that we did not use any additional modules or extra inputs to help the model recognize the imaging modality or the organ of interest in training or test stages. In other experiments, we increased the number of datasets to 12, and we observed the same patterns as those shown for the experiment with seven datasets in Table II.

Fig. 3: A slice of a test image from each of the seven datasets used in the experiment reported in Table II and the output segmentation of the joint model that was trained on all seven datasets. Incredibly, this single joint model trained on all seven datasets was able to accurately segment different organs in different modalities. Moreover, it performed as well as or even better than seven dedicated models trained to segment each dataset separately.

In order to show an important potential benefit of training on heterogeneous data, we compare it with transfer learning in an experiment involving the three cortical plate datasets (see Table I). As shown in the example images and segmentations in Figure 4, the shape and complexity of the cortical plate evolve dramatically before and right after birth. In addition, the sizes of the three datasets are highly unequal: CP- younger fetus includes 27 images, CP- older fetus includes 15 images, and CP- newborn includes 400 images, with the postmenstrual ages of the three groups spanning the younger fetal, older fetal, and newborn periods, respectively. The question is, given the complexity of this segmentation task and the small size of two of the datasets, what is the best training strategy to achieve high segmentation accuracy on all three datasets?

Fig. 4: Example axial slices of the images and segmentations from the three cortical plate segmentation datasets used in this study. From left to right, the images come from CP- younger fetus, CP- older fetus, and CP- newborn. Postmenstrual age of each subject is displayed above the image.

Given the much smaller sizes of two of the datasets, transfer learning is the strategy recommended by previous studies [2, 3]. In Table III, we compare the results obtained with different transfer learning trials to the results obtained by training on heterogeneous data, i.e., training a single model on all three datasets. In each of the transfer learning trials, we first trained the model to convergence on one of the datasets, then fine-tuned it to convergence on a second dataset, and finally fine-tuned the resulting model on the remaining dataset. Our definition of convergence is the same as in Section III-B. Our fine-tuning strategy was “deep fine-tuning” [2]: we reduced the initial learning rate by half and fine-tuned all model layers. We also experimented with shallow fine-tuning as well as keeping the initial learning rate, but the results were inferior.

Training/fine-tuning data | Test data | DSC | HD95 | ASSD | ECE | MCE

Train on CP- younger fetus | CP- younger fetus
      Fine-tune on CP- older fetus | CP- older fetus
          Fine-tune on CP- newborn | CP- newborn

Train on CP- older fetus | CP- older fetus
      Fine-tune on CP- younger fetus | CP- younger fetus
          Fine-tune on CP- newborn | CP- newborn

Train on CP- newborn | CP- newborn
      Fine-tune on CP- younger fetus | CP- younger fetus
          Fine-tune on CP- older fetus | CP- older fetus

Train a single model for all datasets | CP- younger fetus
      | CP- older fetus
      | CP- newborn
TABLE III: Results of experiments on cortical plate segmentation. We compare three different transfer learning approaches with our proposed method of training on heterogeneous data, i.e., training a single model to segment all three datasets. For each of the three datasets, we have highlighted the best results using bold type.

The results are very interesting. Transfer learning improved the segmentation accuracy in some cases, but in most cases the improvement was very small. Training a joint model on all three datasets, on the other hand, achieved segmentation accuracy on par with or better than any of the transfer learning trials. For the smallest dataset, i.e., CP- older fetus, the joint model achieved the best results in terms of DSC, HD95, and ASSD. Furthermore, the joint model had better-calibrated predictions than all three transfer learning approaches on all three datasets.

An additional appeal of a joint model that accurately segments all three datasets is its universality: we need to maintain only one set of model weights. A model that has been trained on any single one of these datasets, on the other hand, will perform poorly on the other datasets. Therefore, we would need to maintain three separate trained models, one for each dataset. Moreover, for a test image, we would need to know which of the three datasets the image belongs to in order to use the right model on that image.

IV-B Detecting OOD test data

In this section, we present the results of our proposed OOD detection method in three different experiments and compare it with the method based on prediction uncertainty.

In the first experiment, we used a mixture of eight different datasets for training: CP- younger fetus, CP- older fetus, Prostate, Heart, Liver-CT, Liver-MRI-SPIR, Liver-MRI-DUAL-In, and Liver-MRI-DUAL-Out. Then, we applied the proposed OOD detection method to the trained model. We used test images from the same eight datasets as in-distribution data. As OOD data, we used the Pancreas, Hippocampus, and Spleen datasets. The histogram of the proposed OODM for this experiment is shown in Figure 2(b). Table IV compares our method with the method based on prediction uncertainty. Our method perfectly detected the OOD images, whereas the method based on prediction uncertainty failed.

Method accuracy sensitivity specificity AUC
Proposed method
Uncertainty-based
TABLE IV: Comparison of the proposed OOD detection method with the method based on prediction uncertainty. In this experiment, the in-distribution data came from CP- younger fetus, CP- older fetus, Prostate, Heart, Liver-CT, Liver-MRI-SPIR, Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out datasets. The OOD data came from Pancreas, Hippocampus, and Spleen datasets.

In the second experiment, we trained a model on the CP- newborn dataset. We then applied it to the test data from the same dataset and to the other two cortical plate datasets. The histograms of OODM values for this experiment are shown in Figure 5. The OODM values for both CP- younger fetus and CP- older fetus fall outside the distribution of the OODM values for CP- newborn. This model, trained only on CP- newborn, achieved much lower DSC values on the CP- younger fetus and CP- older fetus datasets than those shown for these datasets in Table III. Therefore, for this model, images from both the CP- younger fetus and CP- older fetus datasets should be considered OOD. Our proposed method easily distinguished OOD data from in-distribution data. It is interesting to note that the OODM values for the CP- younger fetus dataset are distributed farther away than those of the CP- older fetus dataset. This makes sense because, as shown in Figure 4, CP- younger fetus is less similar to CP- newborn than CP- older fetus is. Table V compares our method with OOD detection based on prediction uncertainty. Whereas our method perfectly separated in-distribution and OOD data, the method based on prediction uncertainty showed a very low accuracy.

Fig. 5: Histograms of OODM values (computed using Eq. (1)) for an experiment on cortical plate segmentation. This model was trained on the CP- newborn dataset. The value of the threshold $\tau$ has been marked with the vertical black line.
Method accuracy sensitivity specificity AUC
Proposed method
Uncertainty-based
TABLE V: Comparison of the proposed OOD detection method with the method based on prediction uncertainty in an experiment on cortical plate segmentation. The model is trained on CP- newborn data. The data from CP- younger fetus and CP- older fetus datasets are used as OOD data.

As the final experiment in OOD detection, we report the results of an experiment with the three liver MRI datasets (see Table I). A slice of one sample image from each of these datasets is shown in Figure 6. This is a very interesting example because it demonstrates that OOD data are often not easy to distinguish visually. We experimented extensively with these three datasets. We observed that when we trained a model on Liver-MRI-SPIR and Liver-MRI-DUAL-In, it segmented images from Liver-MRI-DUAL-Out with good accuracy (mean DSC = 0.89). Similarly, a model trained on Liver-MRI-SPIR and Liver-MRI-DUAL-Out achieved a mean DSC of 0.86 on images from Liver-MRI-DUAL-In. Even a model trained on Liver-MRI-SPIR alone could segment Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out images accurately. On the other hand, a model trained on Liver-MRI-DUAL-In and/or Liver-MRI-DUAL-Out failed on images from Liver-MRI-SPIR (mean DSC 0.40).

These observations are not intuitive, and they are not at all easy to foretell by visually inspecting these images. This example further highlights the importance of OOD detection in CNN-based medical image segmentation.

Fig. 6: TOP: An axial slice of one image from each of the three liver MRI datasets. BOTTOM: A model trained on Liver-MRI-SPIR performed well on images from the other two datasets. However, a model trained on Liver-MRI-DUAL-In and/or Liver-MRI-DUAL-Out completely failed on images from Liver-MRI-SPIR. Green ✓ and red ✗ symbols, respectively, denote success and failure on a dataset at test time.

Figure 7(a) shows the histograms of the proposed OODM values for an experiment with these datasets. In this experiment, Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out were used to train a model. The OODM values were then computed on the test data from the same two datasets as well as on the data from Liver-MRI-SPIR, which are OOD for this model. The figure shows that the proposed OODM easily separates in-distribution from OOD data in this experiment. In Table VI, we compare the proposed method with the method based on prediction uncertainty in this experiment. Similar to the two experiments presented above, the uncertainty-based method has a very low accuracy, whereas our proposed method achieves perfect detection accuracy. For completeness, Figure 7(b) shows the OODM histograms for an experiment in which the Liver-MRI-SPIR and Liver-MRI-DUAL-In datasets were used for training. The trained model works well on the Liver-MRI-DUAL-Out dataset as well; therefore, all three datasets are in-distribution for this model. As expected, the OODM values for most images from Liver-MRI-DUAL-Out fall below the threshold $\tau$ and are hence correctly classified as in-distribution.

Fig. 7: (a) Histograms of OODM values for an experiment on liver segmentation in MRI. The in-distribution data in this experiment included the Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out datasets, which were used to train the model. The OOD data included the Liver-MRI-SPIR dataset, on which the model failed at test time. The value of the threshold $\tau$ has been marked with the vertical black line. The proposed OODM perfectly separated the OOD data from the in-distribution data. (b) In this experiment, Liver-MRI-SPIR and Liver-MRI-DUAL-In were used for training. At test time, in addition to these two datasets, the model accurately segmented the Liver-MRI-DUAL-Out dataset (DSC = 0.886). As can be seen, the OODM values for Liver-MRI-DUAL-Out are distributed very similarly to the OODM values for the training data.
Method accuracy sensitivity specificity AUC
Proposed method
Uncertainty-based
TABLE VI: Comparison of the proposed OOD detection method with the method based on prediction uncertainty in an experiment on liver MRI datasets. In this experiment the model was trained on Liver-MRI-DUAL-In and Liver-MRI-DUAL-Out datasets. The data from Liver-MRI-SPIR dataset are used as OOD.

V Conclusion

The methods proposed in this study represent significant progress towards improving the confidence calibration and OOD detection for CNN-based medical image segmentation models. These are important contributions because they improve the reliability of these models and facilitate their wider adoption in medical and clinical settings.

We showed, for the first time, that standard CNN-based segmentation models can automatically recognize the context and segment the organ of interest in a large pool of heterogeneous datasets. We showed experimentally that such a joint model achieved segmentation accuracy on par with or better than dedicated models trained separately on each dataset. Our experiments also showed that models trained on heterogeneous data usually have much better-calibrated predictions. These are very encouraging results. For example, as we showed in our experiment on cortical plate segmentation, this means one could train a single model to cover images from a wide range of age groups. Not only can such a model have more accurate and better-calibrated predictions, but one would also need to maintain only a single model that works on all age groups, without the need to know the age of the subject at test time. Such situations are quite common in medical applications, where the available training data may show large variability in terms of subject age, body size, imaging modality, image quality, etc. While an investigation of all these factors is beyond the scope of a single study, our results show that CNN-based medical image segmentation models have the potential to handle such sources of data heterogeneity easily and effectively.

Our proposed OOD detection method can also be very valuable in practice. Whereas most previous studies have used measures of prediction uncertainty for this purpose, our experiments show that such methods can be inaccurate. To the best of our knowledge, this is the first study to propose a method for OOD detection in medical image segmentation by analyzing CNN features. In three different experiments, our proposed method based on spectral analysis of CNN feature maps accurately detected OOD images. As we showed in our experiment on liver segmentation in MRI, visually identifying OOD data can be quite non-trivial. Therefore, reliable deployment of CNN-based segmentation methods for medical applications requires accurate OOD detection methods to alert the user to model failure. While this has been a challenging problem because of the massive size and complexity of deep learning models, our proposed method offers an effective solution.

References