Organ segmentation is an important processing step in medical image analysis, e.g., for image guided interventions, radiotherapy, or improved radiological diagnostics. A plethora of single/multi-organ segmentation methods including machine/deep learning approaches has been introduced in the literature for different medical imaging modalities, e.g., magnetic resonance imaging, and positron emission tomography (PET).
More recently, deep learning based medical image segmentation approaches have gained great popularity [yuan2017automatic, litjens2017survey, baumgartner2017exploration, milletari2017hough, wang2017deepigeos, bentaieb2016topology]. Several deep convolutional segmentation models in the form of encoder-decoder networks have been proposed for both medical and non-medical images to learn features and classify/segment images simultaneously in an end-to-end manner, e.g., 2D U-Net [ronneberger2015u], 3D U-Net [cciccek20163d], 3D V-Net [milletari2016v], and 2D SegNet [badrinarayanan2017segnet]. These models, with or without modifications, have been widely applied to both binary and multi-class medical image segmentation problems.
When performing segmentation especially using deep networks, one has to cope with two types of imbalance issues:
a) Input imbalance or inter-class-imbalance during training, i.e., much fewer foreground pixels/voxels relative to the large number of background voxels in binary segmentation, and smaller objects/classes in a multi-class segmentation relative to other larger objects/classes and the background. Therefore, classes with more observations (i.e., voxels) overshadow the minority classes.
b) Output imbalance. During inference, it is unavoidable to have false positives and false negatives. False positives are the background voxels (or other objects in the case of multi-class) that are wrongly labeled as the target object. False negatives refer to the voxels of a target object that are erroneously labeled as background or, in the case of multi-organ segmentation, mislabelled as another organ. Clearly, eliminating both false positives and false negatives is the ultimate ideal. However, in practical systems, one increases as the other decreases. For certain applications, reducing the false positive (FP) rate is more important than reducing the false negative (FN) rate or vice versa. The following are example cases where FP should be penalized more: to handle missing organs and to prevent a model from segmenting normal active regions in PET image segmentation (i.e., relatively high intensity regions compared to background which are considered as neither lesion nor organ). PET organ segmentation is useful when a corresponding computed tomography (CT) image is not available to help with detecting organs. Even if a corresponding CT image is available, usually PET and CT need to be registered. In contrast, for some other applications false negatives should be penalized more, e.g., in ultrasound image segmentation where the boundaries of organs are not very clear, target regions might be under-segmented or in magnetic resonance imaging (MRI) segmentation, small spaces of unsegmented regions within a segmented area might be produced. However, conventional loss functions lack a systematic way of controlling the trade-off between false positive and false negative rates.
A key step when training deep networks on imbalanced data is to properly formulate a loss function. U-Net (both 2D and 3D) and 2D SegNet minimize cross entropy loss to mimic ground truth segmentation masks for an input image while 3D V-Net applies a Dice based loss function.
Cross entropy is commonly used as a loss function in deep learning. Although it can potentially control output imbalance, i.e., false positives and false negatives, it has sub-optimal performance when segmenting highly input class-imbalanced images [badrinarayanan2017segnet]. There are several ways of handling input imbalance in general classification tasks, e.g., random over/under sampling and the synthetic minority over-sampling technique (SMOTE) [chawla2002smote]. Similar to SMOTE, the threshold calibration method, introduced by Pozzolo et al. [dal2015calibrating], operates at the data level, i.e., it requires the data to be undersampled first. However, this cannot be used when the input is an image and we deal with classifying pixels/voxels (i.e., segmentation). To be specific, it would be meaningless to undersample an image by removing only some of its majority class (e.g., background) pixels/voxels in the case of using full volumes. Although in patch based approaches patches can be selected in a way that handles the imbalance during training, they do not encode full contextual information and the choice of patch size is not straightforward. Therefore, several different techniques such as weighted cross entropy [sudre2017generalised], median frequency balancing as used in 2D SegNet [badrinarayanan2017segnet], the Dice optimization function as used in the 3D V-Net method [milletari2016v], and a focal loss function [lin2017focal] have been proposed.
Among all methods introduced for tackling the input-imbalance problem, the Dice based loss function has shown better performance for binary-class segmentation problems [sudre2017generalised]. However, the ability of the Dice loss function to control the trade-off between false positives and false negatives (i.e., output imbalance) has not been explored in previous works. Controlling the trade-off is not a trivial issue for some types of medical images and it is not easily handled by a classical Dice optimization function.
Table I lists previous works that used different loss functions to cope with input/output imbalance. As reported in the table, none of the current loss functions is able to explicitly handle both input and output imbalance. Some other works attempted to reduce output imbalance in the segmented images using post-processing techniques, e.g., Hu et al. [hu2017automatic] applied an energy based refinement step to improve the convolutional neural network (CNN) segmentation results. Similarly, Gibson et al. [gibson2017towards] applied a threshold based refinement step to cope with false positives produced by their CNN based organ segmentation model. Yang et al. [yang2017automatic] also applied a post-processing step to reduce both false positives and false negatives in segmented images. In this paper, we leverage both the cross entropy and the Dice optimization functions to define a new loss function that handles both of the aforementioned input and output imbalance types: the Dice term provides global spatial information, while the cross entropy term explicitly and gradually enforces the trade-off between FNs and FPs.
In this paper, we make the following contributions: a) We introduce a curriculum learning based loss function to handle input and output imbalance (at the algorithmic level) in segmentation problems. b) Our proposed loss improves previous deep models, namely 3D U-Net, 3D V-Net, and our extended version of 2D SegNet, i.e., 3D SegNet, in both training and testing accuracy for single and multi-organ segmentation from different modalities. c) The proposed loss function, by controlling the trade-off between false positives and false negatives, is able to handle missing organs, i.e., by penalizing the false positives more. d) We extend 3D U-Net and 3D V-Net from binary to multi-class segmentation models. e) We introduce the first deep volumetric multi-organ semantic model which simultaneously segments and classifies multiple organs from whole body 3D PET scans.
Given a medical image volume, the goal is to predict the class of each voxel by assigning an activation value to each voxel $x$. We adopt a deep learning technique to learn a prediction model $\xi(x; \theta)$, where $\theta$ denotes the model parameters and $p_k(x)$ is the activation value for organ/class $k$.
Cross Entropy Loss Function. For multi-class problems, the cross entropy loss can be computed as

$$L_{CE} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{K} t_{ik}\,\ln p_{ik},$$

where $p_{ik}$ is the predicted probability mass function (PMF), which assigns a probability/activation value to each class for each voxel, and $t_{ik}$ is the one-hot encoded target (or ground truth) PMF, where the index $k$ iterates over the number of organs and $i$ over the number of samples (i.e., voxels). $L_{CE}$ can be computed as a sum of several binary cross entropy terms, which for some multi-class problems, as in this paper, makes it possible to have control over false positives/negatives. In the case of binary classification, $L_{CE}$ can be rewritten as $L_{BCE} = -\frac{1}{N}\sum_{i=1}^{N}\left[t_i \ln p_i + (1 - t_i)\ln(1 - p_i)\right]$. The term $(1 - t_i)\ln(1 - p_i)$ penalizes false positives as it is zero when the prediction is correct. The binary formulation can also be extended and used for multi-class problems as $L = \frac{1}{K}\sum_{k=1}^{K} L_{BCE}^{(k)}$, where $L_{BCE}^{(k)}$ is the binary cross entropy computed for class $k$. Therefore, the output is an average of multiple binary cross entropies.
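The split of binary cross entropy into a false-negative term and a false-positive term can be illustrated with a small sketch. The function names (`bce_terms`, `mean_bce`) and the clipping constant `eps` are illustrative assumptions, not part of the original method description.

```python
import math

def bce_terms(p, t, eps=1e-7):
    """Return the two per-voxel binary cross entropy terms.

    p: predicted foreground probability, t: target label (0 or 1).
    eps is an assumed clipping constant that keeps log() finite.
    """
    p = min(max(p, eps), 1.0 - eps)
    fn_term = -t * math.log(p)                # non-zero only for foreground voxels
    fp_term = -(1 - t) * math.log(1.0 - p)    # non-zero only for background voxels
    return fn_term, fp_term

def mean_bce(probs, targets):
    """Average binary cross entropy over a flat list of voxels."""
    total = sum(sum(bce_terms(p, t)) for p, t in zip(probs, targets))
    return total / len(probs)

# A background voxel (t = 0) predicted as foreground with p = 0.9:
fn, fp = bce_terms(p=0.9, t=0)
print(fn, fp)  # only the false-positive term contributes
```

Because each voxel activates exactly one of the two terms, weighting them differently changes how strongly false positives and false negatives are penalized.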
Dice Optimization Function. The Dice function is a widely used metric for evaluating image segmentation accuracy, which can be written in the forms $\frac{2|X \cap Y|}{|X| + |Y|}$ or $\frac{2TP}{2TP + FP + FN}$, where $X$ and $Y$ are the predicted and ground truth segmentations and $TP$, $FP$, and $FN$ are the numbers of true positives, false positives, and false negatives. It can also be rewritten as a weighted function to generalize to multi-class problems [sudre2017generalised]. However, when it is used as an optimization/loss function, it is not possible to control the penalization of either FPs or FNs separately or their trade-off in the above formulations. In the binary case, the generalized/weighted Dice loss (GDL) function [sudre2017generalised] is written as

$$GDL = 1 - 2\,\frac{\sum_{l=1}^{2} w_l \sum_{n} r_{ln}\, p_{ln}}{\sum_{l=1}^{2} w_l \sum_{n} \left(r_{ln} + p_{ln}\right)},$$

where $R$ and $P$ are the reference foreground segmentation with voxel values $r_n$ and the predicted segmentation with voxel values $p_n$, and $w_l = 1 / \left(\sum_{n} r_{ln}\right)^2$. However, similar to the original Dice, in this formulation it is not possible to explicitly control the trade-off between FPs and FNs. Moreover, the GDL formulation requires the whole volume to produce meaningful weights (i.e., $w_l$), but in most cases, because of limited GPU memory, the segmentation should be performed on sub-volumes. It is also possible to use a weighted version of Dice, also known as the $F_\beta$ score [hashemi2018asymmetric],

$$F_\beta = (1 + \beta^2)\,\frac{\text{precision} \cdot \text{recall}}{\beta^2\,\text{precision} + \text{recall}},$$

to control the trade-off between precision and recall. However, when Dice ($F_1$ or its weighted version $F_\beta$) or GDL is used with a sigmoid activation function in the output layer of the network to model probabilities, the derivative of the loss in back-propagation with respect to a specific weight $w_{jk}^{l}$ in layer $l$ looks like:

$$\frac{\partial\, Dice}{\partial w_{jk}^{l}} = \frac{\partial\, Dice}{\partial p}\;\sigma'(z)\;\frac{\partial z}{\partial w_{jk}^{l}},$$

where $p = \sigma(z)$ is the output of the neuron with input $z$ and $\sigma'(z)$ is the derivative of the sigmoid activation function. When a neuron has a value close to 0 or 1, the gradient of the sigmoid is very small. As a result, the gradient of the whole cost function with respect to $w_{jk}^{l}$ will become very small. Such a saturated neuron will change its weights very slowly. Note that in the equation above Dice ($F_1$) can be replaced by $F_\beta$ or GDL. However, when using cross entropy, the gradient is computed as

$$\frac{\partial\, CE}{\partial w_{j}} = \frac{1}{n}\sum_{x} x_j \left(\sigma(z) - t\right).$$

Here the gradient is not affected by $\sigma'(z)$ anymore, so it only depends on the neuron's output $\sigma(z)$, the target $t$, and the neuron's input $x_j$. This avoids learning slow-down and helps with the vanishing gradient problem from which deep neural networks suffer.
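The saturation argument can be checked numerically with a toy, single-voxel sketch. It assumes a smoothed soft Dice of the form (2pt + s)/(p + t + s) with s = 1 and a sigmoid output p = sigmoid(z); the function names are hypothetical and the setup is deliberately simplified to one voxel.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dice_loss_grad_wrt_logit(z, t, smooth=1.0):
    """d(1 - Dice)/dz for one voxel, where p = sigmoid(z).

    By the chain rule the result contains the factor sigmoid'(z) = p * (1 - p).
    """
    p = sigmoid(z)
    denom = p + t + smooth
    d_dice_dp = (2.0 * t * denom - (2.0 * p * t + smooth)) / denom ** 2
    return -d_dice_dp * p * (1.0 - p)

def ce_grad_wrt_logit(z, t):
    """d(binary cross entropy)/dz; the sigmoid'(z) factor cancels out."""
    return sigmoid(z) - t

# A saturated, confidently wrong neuron: target is foreground, logit is very negative.
z, t = -10.0, 1
g_dice = dice_loss_grad_wrt_logit(z, t)
g_ce = ce_grad_wrt_logit(z, t)
print(g_dice, g_ce)
```

In this toy setting the Dice gradient is several orders of magnitude smaller than the cross entropy gradient, which is the learning slow-down described above.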
Combo Loss. To leverage the Dice function that handles the input class-imbalance problem, i.e., segmenting a small foreground from a large context/background, while at the same time controlling the trade-off between FPs and FNs and enforcing smooth training using cross entropy as discussed above, we introduce our loss as a weighted sum of two terms, a Dice loss and a modified cross entropy to encode curriculum learning, written as:

$$L = \alpha\left(-\frac{1}{N}\sum_{i=1}^{N}\left[\beta\, t_i \ln p_i + (1 - \beta)(1 - t_i)\ln(1 - p_i)\right]\right) - (1 - \alpha)\,\frac{2\sum_{i=1}^{N} p_i t_i + S}{\sum_{i=1}^{N} p_i + \sum_{i=1}^{N} t_i + S}, \tag{1}$$

where $p_i$ and $t_i$ are the predicted and target values of voxel $i$, $S$ is a smoothing constant, $\alpha$ controls the relative contribution of the cross entropy and Dice terms to the loss function $L$, and $\beta \in [0, 1]$ controls the level of model penalization for false positives/negatives: when $\beta$ is set to a value smaller than 0.5, FPs are penalized more than FNs as the term $(1 - t_i)\ln(1 - p_i)$ is weighted more heavily, and vice versa. In our implementation, to prevent division by zero, we perform add-one smoothing (a specific instance of additive/Laplace/Lidstone smoothing) [russell2016artificial], i.e., we add a unity constant ($S = 1$) to both the denominator and numerator of the Dice term. Although the proposed loss seems to be simply combining two different loss functions, we deliberately chose the binary version of the cross entropy to enable us to explicitly enforce the intended trade-off between false positives and false negatives using the parameter $\beta$ (Equation 1) and, at the same time, keep the model parameters out of bad local minima via the global spatial information provided by the Dice term.
After sigmoid normalization over all the channels (i.e., classes) in the last layer, the Combo loss function is computed using the flattened volumes (one-hot multi-label encoding for both the predicted and ground truth volumes containing several objects) of size $W \times H \times D \times C$, where $W$, $H$, $D$, and $C$ refer to width, height, depth, and number of channels/classes. This strategy makes it simple to generalize to multi-class segmentation, hence directly controlling FPs and FNs over the entire volume.
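A minimal sketch of such a loss on flattened predictions is given below. It is an illustration under stated assumptions rather than the authors' implementation: `alpha` is assumed to weight the cross entropy term and `1 - alpha` the Dice term, `beta` weights the foreground (false-negative) log term, `smooth` is the add-one constant, and `eps` is an assumed clipping constant for numerical stability.

```python
import math

def combo_loss(probs, targets, alpha=0.5, beta=0.5, smooth=1.0, eps=1e-7):
    """Weighted cross entropy plus (negative) smoothed Dice on flat lists.

    probs: sigmoid outputs, targets: 0/1 labels, both flattened over
    width x height x depth x channels.
    """
    n = len(probs)
    ce = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)   # keep log() finite
        ce -= beta * t * math.log(p) + (1.0 - beta) * (1 - t) * math.log(1.0 - p)
    ce /= n
    intersection = sum(p * t for p, t in zip(probs, targets))
    dice = (2.0 * intersection + smooth) / (sum(probs) + sum(targets) + smooth)
    # Dice is maximized, so it enters the loss with a negative sign.
    return alpha * ce - (1.0 - alpha) * dice

targets = [1, 0]
false_positive_case = [0.9, 0.9]   # background voxel predicted as foreground
false_negative_case = [0.1, 0.1]   # foreground voxel predicted as background

fp_low_beta = combo_loss(false_positive_case, targets, beta=0.3)
fp_high_beta = combo_loss(false_positive_case, targets, beta=0.7)
fn_low_beta = combo_loss(false_negative_case, targets, beta=0.3)
fn_high_beta = combo_loss(false_negative_case, targets, beta=0.7)
perfect = combo_loss([1.0, 0.0], targets)
```

With `beta` below 0.5 the false-positive example receives the larger loss, and with `beta` above 0.5 the false-negative example does, which is the trade-off the parameter is meant to expose.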
Model Parameter Optimization. To optimize the model parameters $\theta$ to minimize the loss, we use error back propagation, which relies on the chain rule. We first calculate the gradient of $L$ with respect to the prediction $p$, i.e., $\partial L / \partial p$. Then we calculate how the changes in the model parameters in the last layer of the deep architecture affect the predicted $p$, and so on.
|Method||Skip connections||Loss function||Up-sampling||Number of parameters|
|2D SegNet [badrinarayanan2017segnet]||No||Cross entropy||Specific||16,375,169|
|3D U-Net [cciccek20163d]||Yes||Cross entropy||Regular||12,226,243|
|3D V-Net [milletari2016v]||Yes||Dice||Regular||84,938,241|
Deep Model Architecture. We use the deep architecture shown in Fig. 1. This architecture departs from existing architectures like 3D U-Net, 3D V-Net, and 2D SegNet as listed in Table II. We adopt this simple network to show that the improvement in results is not attributable to some elaborate architecture and to validate our hypothesis that, even with a simple, shallower architecture, as long as a proper loss function is used, it is possible to outperform more complex architectures, e.g., networks with skip connections [ronneberger2015u, cciccek20163d, milletari2016v] or specific up-sampling [badrinarayanan2017segnet].
Training. For multi-organ segmentation from whole body PET images, as the volumes are too large to fit into memory, we extract random sub-volumes from each whole body scan to train a model. Each sub-volume could include voxels belonging to $n$ organs, with $n = 0$ indicating a sub-volume including only background. However, for binary segmentation, i.e., the 3D ultrasound and MRI datasets, we train using the entire volumes. On test data (only for PET), we apply a volumetric sliding window (with stride), i.e., a volumetric field of view $V$ is partitioned into smaller sub-volumes $\{v_1, v_2, \ldots, v_m\}$, where the size of each $v_j$ is the same as that of the training sub-volumes. Along any of the dimensions, the stride would be at least 1 voxel and at most the size of the sub-volume in that dimension. Larger strides speed up the computation at the expense of coarser spatial predictions. Let $v$ be a sub-volume with activation $a$, $S_x$ be the set of sub-volumes that include voxel $x$, and $A_x$ be the set of corresponding activation values. $I_x$ is the set of indicator variables whose value is 1 if the activation is larger than $t$, and 0 otherwise, where $t$ is a threshold value. Then, the label assigned to voxel $x$ is given by $\max(I_x)$. In other words, a single voxel may reside within multiple overlapping sub-volumes; if the activation of any of these sub-volumes is larger than the threshold $t$, then $x$ is assigned 1, and 0 otherwise.
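The overlapping sliding-window inference with "any window above threshold" fusion can be sketched in one dimension. This is a simplified illustration: the real setting is volumetric, and `predict` (here a toy stand-in) represents a trained network applied to one sub-volume; all function names are hypothetical.

```python
def window_starts(length, window, stride):
    """Start indices of windows covering [0, length); the last is flush with the end."""
    if not 1 <= stride <= window:
        raise ValueError("stride must be between 1 and the window size")
    starts = list(range(0, max(length - window, 0) + 1, stride))
    if starts[-1] + window < length:
        starts.append(length - window)
    return starts

def sliding_window_labels(signal, predict, window, stride, threshold=0.5):
    """Label a position 1 if any window containing it yields an activation above threshold."""
    labels = [0] * len(signal)
    for s in window_starts(len(signal), window, stride):
        activations = predict(signal[s:s + window])
        for offset, a in enumerate(activations):
            if a > threshold:
                labels[s + offset] = 1   # logical OR across overlapping windows
    return labels

def toy_predict(sub):
    """Toy stand-in for a network: responds only if the window contains a strong peak."""
    peak = max(sub)
    return [v if peak >= 0.9 else 0.0 for v in sub]

signal = [0.1, 0.2, 0.6, 0.9, 0.1, 0.7, 0.2]
labels = sliding_window_labels(signal, toy_predict, window=3, stride=2)
print(labels)
```

In this example the position with value 0.6 is missed by the first window but picked up by the second, overlapping window, so it is labeled 1; a smaller stride increases such overlaps at the cost of more forward passes.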
III Implementation details
a) PET multi-organ segmentation: For training the PET multi-organ segmentation network, from each training image, we extract 100 randomly positioned -voxel sub-volumes per organ (5 organs in total: brain, heart, left kidney, right kidney, and bladder) and another 100 for negative background sub-volumes. Therefore, we train all the models with number of training volumes. In test, the striding size was set to . PET volumes size varied from to . We train and test all the models using two Titan-X GPUs in parallel each with batch-size 1.
b) Ultrasound echocardiography and prostate MRI segmentation: We train and test all the models on these datasets with whole-volume images (i.e., not sub-volumes) of size and for ultrasound and MRI datasets, respectively, using two M5000 GPUs in parallel, each with batch-size 2. As explored by Masters et al. [masters2018revisiting], small mini-batch sizes can provide more up-to-date gradient calculations, which results in more stable and reliable training while reducing over-fitting compared to larger batch sizes.
As MRI and ultrasound images cover only a part of the body, the number of slices per volume is relatively small compared to PET whole body volumes. To avoid a sliding window in both training and testing and to fit whole MRI and ultrasound volumes into memory, we slightly resampled the MRI and ultrasound images without losing much information due to resampling. However, PET volumes would have to be heavily resized in order to fit into memory, which results in a considerable accuracy drop. Therefore, we did not resample PET images.
For all datasets, we initialize our models and competing methods using the method introduced by Glorot and Bengio [glorot2010understanding] and train them with ADADELTA [zeiler2012adadelta], with learning rate of 1, , , and batch normalization [ioffe2015batch]. It also allows us to use a higher learning rate. Similar to how hyperparameter values are selected in deep models, e.g., learning rate and pooling window size, the optimal values for $\alpha$ and $\beta$ were also found by grid search to optimize results on the validation set (i.e., one round of cross-validation). We found that an equal contribution (i.e., $\alpha = 0.5$) of the Dice and cross entropy terms gives the best results. However, we found that for the PET data, models need to be penalized more for false positives (i.e., ) and for MRI and ultrasound data models need to be penalized more for false negatives (i.e., for MRI and for ultrasound images). For the last layer of the proposed method, we applied the sigmoid activation function as it allowed us to compute the loss over only foreground objects (i.e., there is no extra channel for the background class, as the softmax function requires) and then normalize the output into the range [0, 1]. To obtain the segmentation masks we use a threshold of 0.5.
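The grid search over the two loss hyperparameters can be sketched as follows. The `evaluate` callable is a hypothetical stand-in for "train with these values and return the validation Dice"; the toy scoring function and its peak location are arbitrary and do not reflect the values found in the experiments.

```python
from itertools import product

def grid_search(evaluate, alphas, betas):
    """Return (best_score, alpha, beta) maximizing evaluate(alpha, beta) over the grid."""
    best = None
    for alpha, beta in product(alphas, betas):
        score = evaluate(alpha, beta)
        if best is None or score > best[0]:
            best = (score, alpha, beta)
    return best

def toy_validation_dice(alpha, beta):
    """Arbitrary toy surrogate for a validation score (not real results)."""
    return 1.0 - (alpha - 0.5) ** 2 - (beta - 0.6) ** 2

best_score, best_alpha, best_beta = grid_search(
    toy_validation_dice, alphas=[0.3, 0.5, 0.7], betas=[0.2, 0.4, 0.6]
)
print(best_alpha, best_beta)
```

In practice each call to `evaluate` involves a full training run, so the grid is usually kept coarse.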
All the models have been trained for a fixed number of epochs and we report the results for the best epoch based on the validation set. Note that for the competing methods we set the hyper-parameters as proposed by the authors of these methods. For fairness and to elucidate the direct effect of the proposed Combo loss, when we replace the original loss functions of the competing methods with Combo loss (Tables III and V), we do not change the original network hyper-parameters.
For evaluation, we use three different datasets: a) 58 whole body PET scans of resolution . We randomly pick 10 whole body volumes for testing and train with the 48 remaining volumes. We normalize the intensity range of our training and testing volumes using the min-max method based on the min and max intensity values of the whole training set. Next, in both training and testing, each single sub-volume is also normalized to using its min and max before feeding it into the network. b) 958 MRI prostate scans of different resolutions, which were resampled to a voxel size of . We randomly picked 258 volumes for testing and train with the remaining 700 volumes. c) Ultrasound echocardiography images of resolution , used for left ventricular myocardial segmentation, were split into 430 train and 20 test. The datasets were collected internally and from The Cancer Imaging Archive (TCIA) QIN-HEADNECK and ProstateX datasets [QINheadNeck2015, QINheadNeck2016, ProstateX2017, litjens2014computer, clark2013cancer]. Samples of the three datasets are shown in Fig. 2.
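The two-stage min-max normalization described for the PET data can be sketched as below. The target range [0, 1] is an assumption (the standard min-max output range), and the function name and toy intensities are illustrative only.

```python
def min_max_normalize(values, lo=None, hi=None):
    """Scale values linearly so that lo maps to 0 and hi maps to 1.

    If lo/hi are omitted, the minimum and maximum of `values` are used.
    """
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    if hi == lo:
        return [0.0 for _ in values]   # constant input: avoid division by zero
    return [(v - lo) / (hi - lo) for v in values]

# Stage 1: normalize with statistics of the whole training set (toy numbers).
train_min, train_max = 0.0, 8.0
sub_volume = [2.0, 4.0, 6.0]
globally_scaled = min_max_normalize(sub_volume, lo=train_min, hi=train_max)

# Stage 2: normalize each sub-volume again with its own min and max.
locally_scaled = min_max_normalize(globally_scaled)
print(globally_scaled, locally_scaled)
```

The second stage stretches each sub-volume to the full range regardless of where it sits in the global intensity distribution.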
Our evaluation is divided into two parts. First, in subsection V-A, we compare, both qualitatively and quantitatively, the performance of all the competing methods to the proposed method on the test data, for multi-organ segmentation from PET scans. We test different variants of the proposed loss with the proposed architecture, i.e., cross entropy optimization, weighted cross entropy (WCE), Dice optimization, Dice + cross entropy optimization (DCE), and the proposed loss (Combo). DCE refers to simply integrating Dice and traditional cross entropy losses, whereas Combo refers to combining the weighted version of cross entropy with Dice. Second, in subsection V-B, we perform similar experiments to subsection V-A for single organ segmentation from two further modalities, i.e., MRI and ultrasound scans.
V-A Performance of the proposed vs. competing methods on multi-organ PET segmentation
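[Table III: Jaccard, Dice, FPR, and FNR for multi-organ PET segmentation, organized in sections: a) 3D U-Net [cciccek20163d], b) 3D V-Net [milletari2016v], c) 3D SegNet, d) Ahmadvand et al. [ahmadvand2016tumor], e) proposed architecture with different loss functions.]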
[Figure: qualitative multi-organ PET segmentation results; panel titles: PET (coronal), GT, Ahmadvand et al. [ahmadvand2016tumor], 3D SegNet, 3D U-Net [cciccek20163d], 3D V-Net [milletari2016v].]
We treat the multi-class case as binary, i.e., the one-hot multi-label encodings of both the predicted and ground truth volumes containing several objects are flattened and the Combo loss is computed. In this case, similar to binary segmentation, balancing the false positives and false negatives improves segmentation. As reported in Table III, the proposed architecture with proposed loss () outperforms all competing methods with , , in Jaccard, Dice and FPR, respectively. Comparing rows of section a in Table III, we note that the modified 3D U-Net improves with our proposed loss (Combo) relatively by 23.6%, 14.5%, 56%, and 30% in Jaccard, Dice, FPR, and FNR, respectively. Comparing rows of section b, we note that 3D V-Net improves with our proposed loss relatively by 5.8%, 4.5%, and 28% in Jaccard, Dice, and FPR, respectively. Section c shows that 3D SegNet improves with our proposed loss relatively by 34.1%, 23.2%, 44%, and 12.5% in Jaccard, Dice, FPR, and FNR, respectively. Comparing vs. in section e of Table III shows that WCE helps. Comparing vs. shows that the proposed Combo loss improves the results. Although the results and formulation of Dice + original cross entropy (i.e., DCE) and Combo loss are close, it is important to note that, in the Combo loss formulation, we weight the two terms of the original cross entropy so we can enforce the intended trade-off between FP and FN.
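For reference, the four reported measures can be computed from binary masks as in the sketch below. The definitions used here are the standard ones (FPR = FP / (FP + TN), FNR = FN / (FN + TP)), which is an assumption since the text does not spell them out; the function name and toy masks are illustrative.

```python
def segmentation_metrics(pred, truth):
    """Jaccard, Dice, FPR and FNR from two flat binary masks of equal length."""
    tp = sum(1 for p, t in zip(pred, truth) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, truth) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, truth) if p == 0 and t == 1)
    tn = sum(1 for p, t in zip(pred, truth) if p == 0 and t == 0)

    def ratio(num, den):
        return num / den if den else 0.0   # guard against empty denominators

    return {
        "jaccard": ratio(tp, tp + fp + fn),
        "dice": ratio(2 * tp, 2 * tp + fp + fn),
        "fpr": ratio(fp, fp + tn),
        "fnr": ratio(fn, fn + tp),
    }

pred = [1, 1, 0, 0, 1, 0]
truth = [1, 0, 0, 0, 1, 1]
metrics = segmentation_metrics(pred, truth)
print(metrics)
```

Jaccard and Dice are monotonically related (J = D / (2 - D)), so they rank methods identically, while FPR and FNR expose which kind of error dominates.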
As shown in Figure 3, although 3D U-Net, 3D V-Net, and the extended version (3D) of SegNet are able to locate the normal activities (bright areas in the image caused by radiotracer absorption, which look very similar to abnormalities) and segment them, two issues are visible: a) misclassification of organs: the competing methods were not successful in distinguishing the organs from each other, as sometimes the brain (red) has been labeled as bladder (black); b) the competing methods tend to produce false positives, i.e., wrongly labeling some background voxels as an organ (or one organ as another), or to miss an organ (false negative). As shown in the figure, still produces false positives, but no misclassification of organs. shows clearer segmentations; however, as we penalize the false positives more with the proposed loss, we obtain much clearer outputs (last columns: ). The performance of the proposed method was evaluated for each specific organ and reported in Table IV.
[Figure: panel titles: GT, 3D SegNet, 3D SegNet_Combo, 3D U-Net [cciccek20163d], 3D U-Net_Combo, 3D V-Net [milletari2016v], 3D V-Net_Combo.]
Over all the organs, Dice scores for the proposed method (proposed architecture + Combo loss) range from to . We show the worst, an in-between, and the best results in terms of Dice score in Fig. 5. Although the left case in the figure is the worst result in terms of Dice score, it is a difficult case with several missing organs. However, the proposed method has been able to handle multiple missing organs to a high extent. Note that some organs can be physically absent from a patient's body, as in renal agenesis or radical (complete) nephrectomy, but in PET scans there might be more "missing" organs (similar to the left case in Fig. 5) simply because of a lack of radiotracer uptake in these organs, so that they do not appear in PET. Although, in training, the Dice score improvement compared to 3D V-Net is small, as shown in Figure 4, in test the proposed loss helped 3D V-Net in terms of reducing organ misclassification and false positives. Looking at both Table III and Fig. 4, 3D U-Net and 3D SegNet achieved higher performance when incorporating the proposed loss.
V-B Performance of the proposed vs. competing methods on single organ segmentation from MRI and ultrasound
For MRI and ultrasound datasets, we observed that all the methods are more prone to false negatives than false positives, so we give more weight to the false negative term of the proposed loss (i.e., increase $\beta$ to 0.9). As reported in Table V, similar to the results in Section V-A, the Combo loss function improved 3D U-Net and 3D V-Net by 4.6% and 1.13% in Dice and 43.8% and 16.7% in FNR, respectively, for MRI prostate segmentation. Similarly, 3D U-Net and 3D V-Net results were improved by 8.23% and 3.4% in Dice and 33.3% and 16.7% in FNR, respectively, for ultrasound left ventricular myocardial segmentation.
[Table V: single-organ segmentation results on the MRI and ultrasound datasets for 3D U-Net [cciccek20163d] and 3D V-Net [milletari2016v] with their original losses, Focal loss, and Combo loss.]
As can also be seen in Table V, the proposed loss also helps reduce the variance of the segmentation results.
We also compared the proposed loss function with the recently introduced Focal loss function [lin2017focal]. Our integrative loss function outperformed Focal loss after both were applied to different networks (Table V). We applied Focal loss to the best performing competing method for each dataset, i.e., 3D V-Net for MRI and 3D U-Net for the ultrasound dataset. For Focal loss, we tested several different values for $\alpha$ and $\gamma$, but as suggested by the authors we obtained better results with $\alpha = 0.25$ and $\gamma = 2$. Note that there is no correspondence between the $\alpha$ used in the Focal loss paper (the weight assigned to the rare class) and the one we use in the Combo loss equation (the weight that controls the contribution of Dice and cross entropy terms). For the MRI dataset, the proposed Combo loss outperformed Focal loss by and in Dice and FNR, respectively, when both were used in 3D V-Net. For the ultrasound dataset, Combo loss outperformed Focal loss by in Dice. In Figure 6, we plot both Dice and Hausdorff distance (HD) of the Combo loss vs. competing methods. As shown in the figure, the proposed method outperforms the competing methods in terms of Dice score. Comparing both Dice and Hausdorff distance values of the competing methods in Figure 6, after applying Combo loss (i.e., U_C and V_C) the range of the values is smaller, i.e., there are fewer outliers than when they use their original loss (i.e., U and V).
Among the competing methods, U-Net applies cross entropy loss while V-Net leverages Dice loss. To show the direct contribution of the Combo loss, we replace the original loss functions in U-Net and V-Net with Combo (Table V). As reported in Table V, after replacing cross entropy loss of U-Net with Combo loss, the Dice scores improve from 0.87 to 0.91 and 0.85 to 0.92 for MRI and ultrasound datasets, respectively. Similarly, when replacing the Dice loss function of V-Net with the proposed Combo loss, the segmentation results improve from 0.88 to 0.89 and 0.84 to 0.87 for MRI and ultrasound datasets, respectively.
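The relative improvements can be recomputed from the Dice scores quoted above. This is a small arithmetic sketch with a hypothetical helper name; because the quoted scores are rounded to two decimals, the recomputed percentages can differ slightly from those reported from unrounded table values.

```python
def relative_improvement(baseline, improved):
    """Relative change from baseline to improved, in percent."""
    return 100.0 * (improved - baseline) / baseline

dice_pairs = {
    "3D U-Net, MRI": (0.87, 0.91),
    "3D U-Net, ultrasound": (0.85, 0.92),
    "3D V-Net, MRI": (0.88, 0.89),
    "3D V-Net, ultrasound": (0.84, 0.87),
}

improvements = {
    name: relative_improvement(before, after)
    for name, (before, after) in dice_pairs.items()
}
for name, value in improvements.items():
    print(f"{name}: {value:.2f}%")
```

For instance, the 3D U-Net improvement on MRI works out to roughly 4.6%, matching the figure reported for that configuration.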
Parameter $\alpha$ controls the contribution of the Dice and cross entropy terms, while parameter $\beta$ in the second term, i.e., cross entropy, controls the trade-off between false positives and false negatives. As a key contribution of the paper is providing the means to explicitly control output balance, i.e., false positives and false negatives, we tested several different values of $\beta$ to see how the final results are affected by it, and we fix the parameter $\alpha$ that controls the trade-off between Dice and cross entropy to 0.5. In Figure 7, we show the different Dice and HD results obtained from different $\beta$ values, which control false positives and false negatives. As expected, we note that the final segmentations are affected by the choice of $\beta$, and the best results in terms of higher Dice and lower Hausdorff distance were obtained for and for ultrasound and MRI datasets, respectively. As HD is sensitive to outliers, there are sometimes relatively large values in the HD results (i.e., second column in the figure).
In this paper, we proposed a curriculum learning based loss function to handle input/class imbalance and output imbalance (i.e., enforcing the trade-off between false positives and false negatives). The effect of enforcing a desired trade-off between false positives and false negatives can be seen in Tables III and V. Noting the change in FPR and FNR values of 3D U-Net, 3D V-Net, and 3D SegNet when they apply Combo loss, we see that FPR or FNR is substantially decreased when the models are penalized for FP or FN, respectively (for PET data, i.e., Table III, the Combo loss penalizes FP, and for MRI and ultrasound data, i.e., Table V, it penalizes FN). The proposed loss function resulted in improved performance in both multi- and single-organ segmentation from different modalities. The proposed loss function also improved the existing methods in terms of achieving higher Dice and lower false positive and false negative rates. In this work, we applied the proposed loss function to organ segmentation problems, but it can simply be leveraged for other segmentation tasks as well. The key advantage of the proposed Combo loss is that it enforces a desired trade-off between false positives and false negatives (which removes the need for post-processing) and avoids getting stuck in bad local minima as it leverages the Dice term. The Combo loss converges considerably faster than cross entropy loss during training. Similar to Focal loss, our Combo loss also has two parameters that need to be set. In this work, we used cross-validation to set the hyperparameters (including $\alpha$ and $\beta$ of our proposed loss). Future work can explore using Bayesian approaches [snoek2012practical, murugan2017hyperparameters].
We thank NVIDIA for GPU donation.