Stochastic Segmentation Networks
In image segmentation, there is often more than one plausible solution for a given input. In medical imaging, for example, experts will often disagree about the exact location of object boundaries. Estimating this inherent uncertainty and predicting multiple plausible hypotheses is of great interest in many applications, yet this ability is lacking in most current deep learning methods. In this paper, we introduce stochastic segmentation networks (SSNs), an efficient probabilistic method for modelling aleatoric uncertainty with any image segmentation network architecture. In contrast to approaches that produce pixel-wise estimates, SSNs model joint distributions over entire label maps and thus can generate multiple spatially coherent hypotheses for a single image. By using a low-rank multivariate normal distribution over the logit space to model the probability of the label map given the image, we obtain a spatially consistent probability distribution that can be efficiently computed by a neural network without any changes to the underlying architecture. We tested our method on the segmentation of real-world medical data, including lung nodules in 2D CT and brain tumours in 3D multimodal MRI scans. SSNs outperform the state of the art for modelling correlated uncertainty in ambiguous images while being much simpler, more flexible, and more efficient.
The task of semantic image segmentation is a highly structured prediction problem where the output label maps should capture the spatial consistency of the objects to be segmented. While casting image segmentation as a dense pixel-wise classification task is at the heart of most machine learning approaches (criminisi2012decision; everingham2015pascal; long2015fully), this paradigm largely ignores the underlying spatial structure. Methods will often rely on inductive biases to capture structure as opposed to modelling it directly. While this approach may yield reasonable, single deterministic predictions, it is insufficient to model the underlying distribution over multiple plausible outputs. In image segmentation, there is often more than one plausible solution for a given input. The exact location of object boundaries is often ambiguous, and ideally, the model should be able to capture this inherent uncertainty.
Uncertainty can be decomposed into aleatoric uncertainty, which is inherent to the observations, and epistemic uncertainty, which relates to the ambiguity about the model’s parameters and can be explained away with more data (kendall2017uncertainties). For example, a noisy regression problem with many data points has low epistemic but high aleatoric uncertainty. In segmentation, aleatoric uncertainty is both spatially correlated and heteroscedastic, since an image can contain regions of both higher and lower uncertainty. The ideal model should represent the joint probability distribution of the labels at every pixel given the image, enabling the sampling of multiple plausible label maps.
Because aleatoric uncertainty cannot be reduced by acquiring more data, modelling it explicitly is crucial in risk-sensitive applications. In medical imaging, the images are often noisy, and the boundaries between tissue types may not be well defined, which leads to disagreement even between experts. The ability to automatically generate multiple plausible hypotheses to choose from is of high value in applications such as radiotherapy, where trade-offs have to be made about which anatomical regions to include for invasive treatment. Additionally, providing confidence intervals alongside tumour boundaries would allow uncertainty to be taken into account when making critical decisions.
Fully convolutional neural networks (FCNNs) are the state of the art for semantic segmentation (lecun1998gradient; long2015fully; ronneberger2015u; chen2018). In principle, FCNNs are probabilistic models, since their output is a set of independent categorical distributions per pixel, parameterised by a softmax layer. Because these distributions are independent given the last layer’s activations, sampling from this model would result in spatially incoherent segmentations (grainy label noise in the uncertain regions). We argue that any method that only produces independent pixel-wise uncertainty estimates is unable to generate spatially coherent label maps, and thus incapable of fully capturing the structured uncertainty.
Recent work extends FCNNs to model the joint distribution over labels given the image, allowing for multiple plausible segmentations (kohl2018probabilistic; baumgartner2019phiseg). These methods have rigid, hierarchical, memory-intensive architectures, loss functions with manually tuned hyper-parameters, and require one forward pass per sample. In this paper, we introduce stochastic segmentation networks (SSNs), a lightweight and flexible alternative that efficiently captures correlations between pixels by modelling the logit map as a low-rank multivariate normal distribution. In contrast with previous approaches, our method is less complex, achieves higher predictive performance and can generate multiple samples from a single forward pass. In addition, it can be used with any existing architecture, and its efficiency makes it applicable to high-dimensional problems such as 3D imaging.
In data-constrained scenarios, Bayesian methods are useful for quantifying epistemic uncertainty for previously unseen examples. Seminal works by mackay1992bayesian and neal1993probabilistic inspired inference methods in Bayesian deep learning such as Markov chain Monte-Carlo (welling2011; ma2015) and variational inference methods (pmlr-v37-blundell15; pmlr-v48-gal16). These methods focus on estimating the posterior over the weights of a neural network, which allows for estimating epistemic uncertainty independently of the task. Ensemble (lakshminarayanan2017) and multi-head (lee2015m; lee2016stochastic; rupprecht2017learning) methods follow a frequentist approach to modelling the weight distributions. In the case of label disagreement or noise, defined as aleatoric uncertainty, the issue is not the lack of data. Still, both uncertainties are complementary. In classification, there is work on estimating aleatoric uncertainty by predicting Dirichlet distributions (malinin2018; malinin2019; sensoy2018) as well as post-training calibration of the predicted class probabilities (guo2017calibration; kull2019). In segmentation, attempts at quantifying aleatoric uncertainty on a pixel-wise level (kendall2017uncertainties; tanno2017bayesian; WANG201934; jungo2020analyzing) ignore the joint distribution over labels.
Historically, probabilistic graphical models (PGMs) such as conditional random fields (CRFs) (blake2011markov; krahenbuhl2011efficient) have been used to explicitly model the joint probability distribution over labels. However, the inference was mostly limited to predicting the maximum a posteriori (MAP) estimate. Although there is work on obtaining the M-best diverse solutions for a given input image (batra2012; Kirillov_2015_ICCV), these models are restricted to a fixed number of solutions and have computationally expensive inference. Work on combining PGMs and FCNNs to enforce label dependencies as a post-processing step (arnab2018; chen2018; kamnitsas2017efficient) or even within a single model (zheng2015conditional) suffers from the same limitations as classic PGMs when quantifying aleatoric uncertainty.
Recently, kohl2018probabilistic and baumgartner2019phiseg have built on conditional variational auto-encoders (kingma2013auto; sohn2015) to extend FCNNs for modelling spatially correlated aleatoric uncertainty. hu2019supervised extend this framework by regressing the uncertainty maps in a supervised manner. These methods encode the image into one or more uncorrelated multivariate normal latent variables and rely on the decoder to translate the added uncorrelated stochasticity into meaningful spatial variation. Like variational auto-encoders, these models have the flexibility to transform the latent distributions into arbitrarily complex distributions with correlations between pixels. However, the placement of the latent variables within the network means that one forward pass is required for every new sample. Furthermore, this flexibility comes at the cost of having to use a cumbersome variational inference framework, which makes use of a training-only posterior network and manually tuned hyper-parameters weighing the Kullback–Leibler divergence regularisation term of the loss. These overly expressive distributions might not justify their cost; a more constrained distribution could suffice and allow the use of simpler inference methods.
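We start by analysing the independence assumptions made to obtain the cross-entropy loss typically used in image segmentation. Consider a standard segmentation problem in which an image, $x \in \mathbb{R}^{K \times S}$, with $K$ channels and $S$ pixels, maps to a one-hot label map of the same size, $y \in \{0, 1\}^{S \times C}$, with $C$ classes: $y_i \in \{0, 1\}^{C}$ and $\sum_{c=1}^{C} y_{ic} = 1$ for $i \in \{1, \dots, S\}$. In a classic CNN, the probability of one label, $y_i$, is the output of a softmax layer taking as input the logit, $\eta_i \in \mathbb{R}^{C}$. Before any independence assumptions, the MAP estimate for the negative log-likelihood can be written as:

$$-\log p(y \mid x) = -\log \int p(y \mid \eta)\, p_\phi(\eta \mid x)\, d\eta \tag{1}$$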
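where $p_\phi(\eta \mid x)$ is the probability of the logit map given the image under a model with parameters $\phi$. To obtain the standard cross-entropy loss, we assume that the logit map is given by a deterministic function, $\eta = f_\phi(x)$, which means $p_\phi(\eta \mid x)$ can be written as:

$$p_\phi(\eta \mid x) = \delta\big(\eta - f_\phi(x)\big) \tag{2}$$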
Due to this deterministic function, given the image and model, the logits, $\eta_i$ for $i \in \{1, \dots, S\}$, are trivially independent of each other, i.e., given the image and model, no new information can be gained about a single logit by observing its neighbours. Secondly, we must assume that the labels, $y_i$ for $i \in \{1, \dots, S\}$, are independent of each other when given their respective logit:

$$p(y \mid \eta) = \prod_{i=1}^{S} p(y_i \mid \eta) = \prod_{i=1}^{S} p(y_i \mid \eta_i) \tag{3}$$
This is a two-part assumption: first, it assumes that labels, $y_i$, are independent of each other when given the full logit map, $\eta$, and second, it assumes that each label, $y_i$, only depends on its respective logit, $\eta_i$, i.e., no new information can be gained about a label by observing the true values of its neighbours. Incorporating the assumptions of equations 2 and 3 into equation 1, and substituting $p(y_i \mid \eta_i)$ by a categorical distribution parameterised by the softmax transform of $\eta_i$, we arrive at the familiar form for the cross-entropy:

$$-\log p(y \mid x) = -\sum_{i=1}^{S} \sum_{c=1}^{C} y_{ic} \log \operatorname{softmax}(\eta_i)_c, \qquad \eta = f_\phi(x) \tag{4}$$
Whereas in image-level classification these independence assumptions may be valid, in segmentation the labels at each pixel are clearly correlated, which should be taken into account.
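To make equation 4 concrete, the following minimal pure-Python sketch evaluates the factorised cross-entropy for a deterministic logit map. The function names are illustrative, not taken from the paper, and labels are stored as class indices rather than one-hot vectors, which selects the single term with $y_{ic} = 1$ in the inner sum.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of C logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cross_entropy(logit_map, labels):
    """Equation 4: -sum_i log softmax(eta_i)[y_i].

    logit_map: S lists of C logits (the deterministic eta = f_phi(x)).
    labels: S integer class indices (the one-hot y written as indices).
    Each pixel contributes independently, which is exactly the
    factorisation assumed in equations 2 and 3.
    """
    return -sum(log_softmax(eta_i)[y_i] for eta_i, y_i in zip(logit_map, labels))
```

With all logits at zero and two classes, every pixel contributes $\log 2$, so a two-pixel map gives a loss of $2 \log 2$ regardless of the labels.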
In this work, we propose using weaker independence assumptions by using a more expressive distribution over logits. Specifically, we use a multivariate normal distribution whose parameters are the output of a neural network, $\eta \mid x \sim \mathcal{N}\big(\mu(x), \Sigma(x)\big)$, where $\mu(x) \in \mathbb{R}^{S \times C}$ and $\Sigma(x) \in \mathbb{R}^{(S \times C) \times (S \times C)}$. The size of the full covariance matrix scales with the square of the number of pixels times the number of classes, which makes it infeasible to compute for anything but very small images. For this reason, we use a low-rank parameterisation of the covariance matrix of the form:

$$\Sigma = P P^{\top} + D \tag{5}$$
where the covariance factor, $P$, is a matrix of size $(S \times C) \times R$, where $R$ is a hyper-parameter defining the rank of the parameterisation, and $D$ is a diagonal matrix whose diagonal has $S \times C$ elements. This low-rank parameterisation ensures that the three components describing the distribution, namely the mean, the covariance factor, and the covariance diagonal, can be efficiently computed by a neural network.
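The practical benefit of equation 5 is that sampling never requires forming the full $(S \times C) \times (S \times C)$ matrix. The sketch below assumes the standard reparameterisation $\eta = \mu + Pz + D^{1/2}\epsilon$ with $z \sim \mathcal{N}(0, I_R)$ and $\epsilon \sim \mathcal{N}(0, I_{S \times C})$, whose covariance is exactly $PP^\top + D$. The flattened list layout and the helper names are illustrative assumptions; in the paper, $\mu$, $P$ and $D$ are produced by the network.

```python
import math
import random


def sample_low_rank(mu, cov_factor, cov_diag, rng):
    """Draw one flattened logit map eta ~ N(mu, P P^T + D).

    mu: N = S*C means; cov_factor: N rows of R entries (P);
    cov_diag: N non-negative diagonal entries (D).
    Only R + N standard normals are drawn; the N x N matrix is never built.
    """
    rank = len(cov_factor[0])
    z = [rng.gauss(0.0, 1.0) for _ in range(rank)]       # shared by all entries
    eps = [rng.gauss(0.0, 1.0) for _ in range(len(mu))]  # independent per entry
    return [
        m + sum(p_r * z_r for p_r, z_r in zip(p_row, z)) + math.sqrt(d) * e
        for m, p_row, d, e in zip(mu, cov_factor, cov_diag, eps)
    ]


def covariance_entry(cov_factor, cov_diag, j, k):
    """Sigma[j, k] = (P P^T)[j, k] + D[j, k], computed on demand."""
    value = sum(a * b for a, b in zip(cov_factor[j], cov_factor[k]))
    return value + (cov_diag[j] if j == k else 0.0)
```

Because the $R$ entries of $z$ are shared by every pixel and class, a single rank-one factor with equal rows shifts all logits up or down together, which is the kind of spatially coherent variation that a diagonal covariance cannot express.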
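By plugging this distribution into equation 1, we no longer assume that logits, $\eta_i$, are independent of each other. However, the integral also becomes intractable because of the softmax transform on the normal distribution. For this reason, we approximate the integral using Monte-Carlo integration:

$$-\log p(y \mid x) \approx -\log \frac{1}{M} \sum_{m=1}^{M} p\big(y \mid \eta^{(m)}\big), \qquad \eta^{(m)} \mid x \sim \mathcal{N}\big(\mu(x), \Sigma(x)\big) \tag{6}$$

$$= -\operatorname{logsumexp}_{m=1}^{M} \Big( \sum_{i=1}^{S} \log p\big(y_i \mid \eta_i^{(m)}\big) \Big) + \log M \tag{7}$$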
where $\log p\big(y_i \mid \eta_i^{(m)}\big)$ can be computed as in equation 4. For inference, with a single forward pass, we can sample from the distribution multiple times to obtain logit maps, which can be transformed into probability or label maps. To obtain the most likely logit sample, we use the mean of the distribution.
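Equation 7 is the numerically stable way to evaluate the Monte-Carlo estimate in equation 6: the per-sample log-likelihoods are combined with log-sum-exp instead of being exponentiated directly. A minimal plain-Python sketch, with hypothetical names and index-based labels, follows; in practice the samples would come from the low-rank distribution above.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of C logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def mc_loss(logit_samples, labels):
    """Equation 7: -logsumexp_m(sum_i log p(y_i | eta_i^(m))) + log M.

    logit_samples: M sampled logit maps, each S lists of C logits.
    labels: S integer class indices.
    """
    # Log-likelihood of the whole label map under each sample (equation 4 per sample).
    per_sample = [
        sum(log_softmax(eta_i)[y_i] for eta_i, y_i in zip(eta, labels))
        for eta in logit_samples
    ]
    # Stable log-sum-exp over the M samples.
    top = max(per_sample)
    lse = top + math.log(sum(math.exp(v - top) for v in per_sample))
    return -lse + math.log(len(logit_samples))
```

The estimator averages likelihoods, not log-likelihoods. For example, with two samples where one fits the label map almost perfectly and the other assigns it a log-likelihood of about $-20$, the loss stays close to $\log 2$; averaging log-likelihoods instead would give a loss of about 10.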
Figure 1 shows the probabilistic graphical models of a classic neural network, a CRF and the proposed model. While the neural network does not model dependencies between output labels, the CRF explicitly models these dependencies at the cost of an expensive inference procedure. In contrast, by implicitly modelling label dependencies in the logit space and then making independence assumptions, we can capture label dependencies while keeping the efficient inference of a neural network.
Consider a dataset on a one-dimensional 21-pixel line with one image for which there are two equiprobable label maps. For both label maps, the first third of the line is labelled 1 (on), and the last third is labelled 0 (off). However, the middle third is off for the first label map, and on for the second label map (see visual examples on the far right of Figure 2). In this setting, the labels of the middle third are uncertain but not independent. Since there is only one input, it is a constant and hence can be disregarded for further modelling. Thus, the goal of the problem becomes to find a generative model for the distribution of the two label maps.
A deterministic model would correctly learn the mean of the distribution but would yield implausible predictions. The first and last thirds would be correct, but the middle third would be arbitrarily fixed. For example, if the label maps were not equiprobable, the model would always generate the most probable one. Next, we consider two stochastic models where the distribution over logits is a multivariate normal distribution: one with a diagonal covariance matrix and one with a low-rank covariance matrix (of rank $R$). We train these models with gradient descent and the loss function in equation 7 using 200 Monte-Carlo samples for 10000 iterations. The results are shown in Figure 2. We observe that the diagonal model is able to learn the mean of the distribution and even which pixels have higher uncertainty. However, it cannot learn the structure of the noise and thus produces samples with uncorrelated noise. In contrast, the low-rank model is able to learn the correct noise structure and produce samples matching the desired distribution, yielding a higher log-likelihood, -0.93, when compared to the diagonal model, -4.87.
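The cost of ignoring correlations in this toy problem can be quantified with a short calculation. Assuming, for illustration, a diagonal model that has learned the correct marginals (outer thirds always correct, each of the 7 middle pixels on with probability 0.5), the chance that an independent sample reproduces either valid label map is $2 \times 0.5^7 = 1/64 \approx 1.6\%$. The sketch below builds the two label maps and evaluates this probability; all names are illustrative.

```python
N_PIXELS = 21
THIRD = N_PIXELS // 3  # 7 pixels per third


def toy_label_maps():
    """The two equiprobable label maps: first third on, last third off."""
    first = [1] * THIRD
    last = [0] * THIRD
    return [first + [0] * THIRD + last,  # middle third off
            first + [1] * THIRD + last]  # middle third on


def prob_valid_under_independence(p_on=0.5):
    """Probability that independent per-pixel samples match one of the two maps,
    assuming the outer thirds are always sampled correctly and each middle pixel
    is on with probability p_on (its marginal under the data)."""
    return p_on ** THIRD + (1 - p_on) ** THIRD
```

The remaining 63/64 of independent samples contain a mixed middle third, which is the grainy, uncorrelated noise seen in the diagonal model's samples.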
Caveat: Under our model, we can deduce that the true generative model is as follows: the mean is zero for the middle third, $+\infty$ for the first third, and $-\infty$ for the last third. The covariance matrix is $+\infty$ for all entries regarding self and cross covariances of pixels in the middle third and zero elsewhere. This area of infinite covariance caused numerical stability issues, since the covariance quickly grew to infinity, producing overflow errors. Furthermore, we found that the covariance grew much faster than the mean, causing the model to get stuck in suboptimal local minima. To address these issues, we pre-train the mean first and use early stopping to obtain the last model before an overflow error occurs. In the real data used in this paper, the only area with infinite covariance is the air in the background of brain scans. We addressed the issue by masking out the background.
To compare with previous work, we evaluated our model on the LIDC-IDRI dataset (armato2011lung) using the task defined by kohl2018probabilistic. The dataset consists of 1018 3D thorax CT scans where four radiologists have annotated multiple lung nodules in each scan. In total, the dataset was annotated by 12 radiologists, and it is not possible to match an annotation to a specific expert. Thus, the four sets of annotations are not self-consistent in “style” across images. Regardless, this type of data is ideal for validating models which seek to capture the inherent uncertainty in the data, as evident from the disagreement between experts. kohl2018probabilistic preprocessed the data by extracting 2D slices centred around the annotated nodules. When at least one expert had segmented a nodule, a slice of the image and four expert segmentations were extracted. Empty segmentations were introduced when there were fewer than four annotations for a slice. This process yielded a dataset of 15096 slices, each having four segmentations.
We compared with three baseline models: a deterministic U-Net (ronneberger2015u), a probabilistic U-Net (kohl2018probabilistic) and the PHiSeg model (baumgartner2019phiseg) (the best performing variant reported). We used the pre-processed data provided by kohl2018probabilistic and the code, configurations, and hyper-parameters provided by baumgartner2019phiseg. We implemented our algorithm on top of the provided deterministic U-Net with rank $R$ for the low-rank model, and, for comparison, we tested a model with a diagonal covariance matrix. By using the same backbone, code and hyper-parameters, we ensured a fair comparison with previous work.
We measured the predictive performance using the Dice Similarity Coefficient, $\text{DSC} = \frac{2\,TP}{2\,TP + FN + FP}$, where $TP$ is true positives, $FN$ is false negatives, and $FP$ is false positives. Even if all four radiologists annotated a nodule, disagreements about its borders combined with the 3D to 2D preprocessing introduce several empty annotations (on average about 1.6 of the 4 annotations, i.e. 40.4%). A non-empty prediction on an empty annotation contributes a DSC of zero to the average, heavily penalising it. Therefore, we also report $\text{DSC}^*$, defined as the DSC computed only where the ground-truth annotations are not empty. Pixel-wise metrics for uncertainty quantification and calibration are not appropriate for spatially structured prediction such as segmentation. Hence, we used sample diversity to quantify the amount of uncertainty and the distance between the expert and predicted distributions to quantify uncertainty calibration. Given the ground-truth distribution defined by the four expert segmentations, $p_{gt}$, and the predicted distribution, $p_s$, we measure the distance between the two using the generalised energy distance (szekely2013energy; kohl2018probabilistic):

$$D^2_{\text{GED}}(p_{gt}, p_s) = 2\,\mathbb{E}_{y \sim p_{gt},\, s \sim p_s}\big[d(s, y)\big] - \mathbb{E}_{s, s' \sim p_s}\big[d(s, s')\big] - \mathbb{E}_{y, y' \sim p_{gt}}\big[d(y, y')\big]$$
where $d(x, y) = 1 - \text{IoU}(x, y)$, with $d = 0$ if both segmentations are empty. We define sample diversity as $\mathbb{E}_{s, s' \sim p_s}\big[d(s, s')\big]$. Note how both these metrics are bounded between zero and one.
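The metrics above can be sketched in plain Python for binary masks stored as flat lists. This is an illustration rather than the evaluation code used for the paper: the helper names are hypothetical, expectations are estimated over all ordered pairs (including identical pairs, which contribute zero distance), which is one common way to evaluate these empirical expectations, and the multi-class averaging of IoU over foreground classes is omitted.

```python
def dice(pred, target):
    """DSC = 2TP / (2TP + FN + FP) for binary masks (flat lists of 0/1).
    Returns None when both masks are empty (undefined, excluded from averages)."""
    tp = sum(1 for p, t in zip(pred, target) if p and t)
    fp = sum(1 for p, t in zip(pred, target) if p and not t)
    fn = sum(1 for p, t in zip(pred, target) if t and not p)
    denom = 2 * tp + fn + fp
    return None if denom == 0 else 2 * tp / denom


def distance(a, b):
    """d = 1 - IoU, with d = 0 when both masks are empty."""
    inter = sum(1 for x, y in zip(a, b) if x and y)
    union = sum(1 for x, y in zip(a, b) if x or y)
    return 0.0 if union == 0 else 1.0 - inter / union


def mean_pairwise(xs, ys):
    """Empirical E[d(x, y)] over all ordered pairs."""
    return sum(distance(x, y) for x in xs for y in ys) / (len(xs) * len(ys))


def ged_squared(samples, annotations):
    """D^2_GED = 2 E[d(s, y)] - E[d(s, s')] - E[d(y, y')]."""
    return (2 * mean_pairwise(samples, annotations)
            - mean_pairwise(samples, samples)
            - mean_pairwise(annotations, annotations))


def sample_diversity(samples):
    """E[d(s, s')] over all ordered pairs of predicted samples."""
    return mean_pairwise(samples, samples)
```

For instance, predicting the set of annotations itself gives $D^2_{\text{GED}} = 0$, while a single sample that matches only one of two disjoint annotations gives $2 \times 0.5 - 0 - 0.5 = 0.5$.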
To measure how models deal with increasing uncertainty in the labels, we trained each model using either only one or all four annotations per image. We divided the data into training, validation and test sets (60/20/20%) and trained all models for 500k iterations with the same configuration described in baumgartner2019phiseg. For the proposed loss function, we used 20 Monte-Carlo samples. We computed $D^2_{\text{GED}}$ and sample diversity using 100 random samples. The prediction for the probabilistic baselines was obtained by averaging the probability maps of these samples (baumgartner2019phiseg). For the proposed model, we used the mean of the logit map distribution. We computed the DSC between the prediction and each of the four ground-truths before averaging over sets of annotations and slices.
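| Method | Training annotations | DSC (%) | DSC* (%) | $D^2_{\text{GED}}$ | Sample diversity |
|---|---|---|---|---|---|
| deterministic U-Net | set 0 | 37.5 ± 0.4 | 50.3 ± 0.4 | 0.698 ± 0.009 | 0.000 ± 0.000 |
| probabilistic U-Net | set 0 | 38.4 ± 0.4 | 57.2 ± 0.4 | 0.516 ± 0.007 | 0.290 ± 0.004 |
| PHiSeg | set 0 | 39.1 ± 0.4 | 51.3 ± 0.5 | 0.456 ± 0.008 | 0.215 ± 0.003 |
| proposed (diagonal) | set 0 | 37.1 ± 0.4 | 51.2 ± 0.4 | 0.734 ± 0.009 | 0.001 ± 0.000 |
| proposed (low-rank) | set 0 | 40.7 ± 0.4 | 58.6 ± 0.4 | 0.365 ± 0.005 | 0.399 ± 0.004 |
| deterministic U-Net | all sets | 35.9 ± 0.4 | 43.5 ± 0.5 | 0.607 ± 0.009 | 0.000 ± 0.000 |
| probabilistic U-Net | all sets | 39.0 ± 0.4 | 50.6 ± 0.5 | 0.252 ± 0.004 | 0.469 ± 0.003 |
| PHiSeg | all sets | 33.8 ± 0.4 | 40.3 ± 0.5 | 0.224 ± 0.004 | 0.496 ± 0.003 |
| proposed (diagonal) | all sets | 37.0 ± 0.4 | 46.2 ± 0.5 | 0.622 ± 0.009 | 0.007 ± 0.001 |
| proposed (low-rank) | all sets | 43.6 ± 0.4 | 68.5 ± 0.3 | 0.225 ± 0.002 | 0.609 ± 0.002 |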
Table 1 shows the results for the five models, and Figure 3 shows qualitative results for the proposed low-rank model trained on four sets of annotations. In terms of predictive performance, the proposed low-rank model outperformed the baselines in both settings. Notably, our model is the only method that benefits from the additional annotations, yielding improved predictive performance. For uncertainty calibration, our model yielded the lowest $D^2_{\text{GED}}$, except with four annotations, where its performance was comparable to that of PHiSeg. In both settings, our model obtained the highest sample diversity. For reference, the diversity between experts is 0.399 ± 0.002. Lastly, the diagonal model nearly collapsed to a deterministic model, yielding very little sample diversity.
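| Model | WT DSC (%) | NET DSC (%) | OD DSC (%) | ET DSC (%) | $D^2_{\text{GED}}$ | Sample diversity |
|---|---|---|---|---|---|---|
| DeepMedic | 88.2 ± 1.3 | 60.5 ± 2.9 | 72.1 ± 2.3 | 67.3 ± 3.5 | 0.886 ± 0.043 | 0.000 ± 0.000 |
| low-rank 30 mm | 88.0 ± 1.3 | 59.3 ± 3.1 | 71.7 ± 2.3 | 68.7 ± 3.5 | 0.635 ± 0.029 | 0.312 ± 0.014 |
| low-rank 60 mm | 88.7 ± 1.3 | 59.6 ± 3.0 | 72.4 ± 2.2 | 69.2 ± 3.5 | 0.689 ± 0.031 | 0.217 ± 0.012 |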
We also applied our method to the BraTS 2017 dataset (menze2014multimodal; bakas2017advancing; bakas2018identifying). This dataset consists of 285 3D multimodal MRI images (four channels: T1, T1ce, T2 and FLAIR) where one radiologist has segmented four classes: background, non-enhancing/necrotic tumour core (NET), oedema (OD) and enhancing tumour core (ET). We implemented the proposed method on top of an implementation of DeepMedic (kamnitsas2017efficient; kamnitsas2016deepmedic), a network specifically developed for brain segmentation. We use rank $R$ for the low-rank model and omit the diagonal-only model, since it converged to a deterministic model. The images have an isotropic resolution of 1 mm and a size of 240 × 240 × 155 voxels, making them too large to train on whole images. We trained the baseline and proposed models on image patches with a side length of 110 mm (1 mm³ = 1 voxel), which, since no padding was used, result in label map patches with a side length of 30 mm. To test the effect of including longer-distance dependencies between voxels, we also trained the proposed model on image patches with a side length of 140 mm, which result in label map patches of 60 mm. Note that increasing the patch size of the baseline does not change its behaviour, since the model is fully convolutional and its receptive field is 81 mm (which is larger than 60 mm).
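The patch sizes above are consistent with the standard relation for an unpadded (valid) fully convolutional network, where every output voxel needs a complete receptive field: output side = input side - receptive field + 1, so $110 - 81 + 1 = 30$ and $140 - 81 + 1 = 60$. A trivial check, with a hypothetical helper name:

```python
RECEPTIVE_FIELD_MM = 81  # DeepMedic receptive field per axis (1 mm = 1 voxel)


def output_size(input_size, receptive_field=RECEPTIVE_FIELD_MM):
    """Side length of the label patch produced by an unpadded fully
    convolutional network: each output voxel needs a full receptive field."""
    return input_size - receptive_field + 1
```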
We split the data into training, validation and test sets (60/10/30%) and trained according to the procedure described in the appendix. During inference, we stitched together the patches of the mean, covariance factor and diagonal to build a distribution over the entire image from which we can sample. Due to the fully convolutional nature of the model, once it is trained, the patch size used for inference has no impact on the final result. We measured the DSC of the three lesion classes and of the whole tumour (WT), consisting of all lesion classes combined. We measured sample diversity and $D^2_{\text{GED}}$ using only 20 samples due to their quadratic dependency on the number of samples and the large image size.
Table 2 shows the quantitative results for the deterministic and stochastic models. The stochastic models had no loss in performance when compared to the deterministic model. Comparing the two stochastic models, we observe that the added spatial context did not increase performance or yield a better-calibrated distribution. Regardless, the amount of needed spatial context is application dependent. Figure 4 shows qualitative results for six cases for the stochastic 30 mm model. We observe that entire structures in the segmentation appear and disappear between samples in regions of high uncertainty (e.g., row 4). Furthermore, mistakes made by the deterministic model or the stochastic model are corrected in at least one of the samples (e.g., row 2). Lastly, the high uncertainty in lesion borders makes them shrink and expand consistently between samples (e.g., row 1).
Figure 5 shows per-case sample distributions of the average lesion-class DSC (100 samples). As expected, for most cases, the majority of samples are worse than the mean prediction. However, on average, 26.0% of samples are better than the mean prediction (dashed line). When looking at the best samples, the average (over the dataset) 95% quantile of the average class DSC was 70.3%, compared with an average class DSC of 66.8% for the deterministic model. This gain is not uniformly distributed: it tends to be higher for cases with low performance and decreases as the performance increases. In addition to being able to sample repeatedly after inference, another advantage of outputting a full distribution is the ability to manipulate samples post-inference. Since the covariance matrix has entries which are separable per class, by scaling only the part of the matrix relating to a given class, we are able to manipulate samples to increase or reduce the presence of that class. This can be used to correct possible mistakes or adjust borders, as shown in Figure 6. Similarly, we can trade sample diversity for quality by scaling the temperature of the entire distribution.
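The exact manipulation is not spelled out here, so the following is only one possible realisation, stated as an assumption: multiply the rows of $P$ (and the matching entries of $D$) that belong to a chosen class by a factor $\alpha$, and multiply the whole covariance by a temperature $\tau$. This gives $\Sigma' = \tau\, A (PP^\top + D) A$, where $A$ is the diagonal matrix of per-row factors, so the manipulated distribution stays in the same low-rank family and can be sampled exactly as before. The function and argument names are hypothetical.

```python
import math


def scale_distribution(cov_factor, cov_diag, class_of_row,
                       class_scales=None, temperature=1.0):
    """Return (P', D') such that P' P'^T + D' = temperature * A (P P^T + D) A.

    cov_factor: N rows of R entries (P); cov_diag: N diagonal entries (D);
    class_of_row: class index of each of the N flattened logit entries;
    class_scales: optional {class: alpha}; unlisted classes keep alpha = 1.
    """
    class_scales = class_scales or {}
    new_factor, new_diag = [], []
    for p_row, d, c in zip(cov_factor, cov_diag, class_of_row):
        # Per-row factor: class-specific alpha times sqrt(temperature).
        s = class_scales.get(c, 1.0) * math.sqrt(temperature)
        new_factor.append([s * p for p in p_row])
        new_diag.append(s * s * d)
    return new_factor, new_diag
```

Setting $\tau < 1$ shrinks all samples towards the mean (less diversity, typically higher quality), while $\alpha > 1$ for one class widens only that class's logit fluctuations and its correlations with the other entries.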
This paper introduces an efficient approach for modelling spatially correlated aleatoric uncertainty in segmentation. We have shown that our method outperforms the baselines while being much simpler, improves predictive performance with added uncertainty, and generates samples that can be better than the predictions of a deterministic approach. The simplicity of the method enables it to be easily implemented over any existing neural network architecture, which enabled its use in a 3D application, something which had previously not been attempted. The ability to generate multiple plausible hypotheses post-inference is of value in human-in-the-loop scenarios, such as radiology, where a human could manipulate the segmentation semi-automatically according to the model’s uncertainty. Furthermore, even in fully autonomous systems such as autonomous vehicles, being able to reason about spatially correlated uncertainty is essential. For example, uncertainty about whether a region is a pedestrian or not should be correlated over all pixels in the region.
Proper uncertainty quantification is crucial to increase trust and interpretability in deep learning systems, which is of particular importance in healthcare applications. Reliable uncertainty estimates could help inform clinical decision making, and importantly, provide clinicians with feedback on when to ignore automatically derived measurements. Moreover, uncertainty estimates could be propagated to downstream clinical tasks such as radiotherapy planning, e.g., the amount of radiation delivered to each anatomical region. In medicine, the notion of a second opinion is well established and an essential part of scrutinising the decision process. The ability to generate and manipulate multiple plausible hypotheses could be of great benefit in semi-automatic settings, such as machine aided image segmentation, and help minimise the risk of missing important modes of the target distribution. A complementary prediction might be contradictory yet still very informative.
This research has received funding from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No 757173, project MIRA, ERC-2017-STG). NP is supported by a Microsoft Research PhD Scholarship. DC and NP are also supported by the EPSRC Centre for Doctoral Training in High Performance Embedded and Distributed Systems (HiPEDS, grant ref EP/L016796/1). LLF is funded through the EPSRC (EP/P023509/1).
We split the data into training, validation and test sets (60/10/30%) and trained the models for 1200 epochs. At each epoch, we randomly sampled 50 images and extracted 20 patches from each image. We randomly sampled the patches centred around a lesion or background voxel with equal probability. We used the RMSProp optimiser (tieleman2012lecture) with momentum 0.6 and a learning rate of 1e-3, which we halved at the following epochs: 440, 640, 800, 900, 980, 1050. For augmentation, we used random elastic deformations, right-angle rotations, flips and linear intensity transformations. We used a batch size of 10, except for the 60 mm model, where we used a batch size of 4 due to GPU memory constraints.
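The step schedule described above can be written as a small function. This sketch assumes that "halved at epoch $e$" means the halved rate applies from epoch $e$ onwards, with epochs counted from zero; the constant and function names are illustrative.

```python
BASE_LR = 1e-3
HALVING_EPOCHS = (440, 640, 800, 900, 980, 1050)


def learning_rate(epoch):
    """Halve the base rate once for every milestone already reached."""
    n_halvings = sum(1 for milestone in HALVING_EPOCHS if epoch >= milestone)
    return BASE_LR * 0.5 ** n_halvings
```

After the last milestone, the rate is $10^{-3} / 2^6 \approx 1.6 \times 10^{-5}$ for the remaining epochs.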
To calculate the distance between two label maps, we used $d = 1 - \text{IoU}$. To calculate the IoU in a multi-class setting, we averaged the IoU of the individual classes, excluding the background class. If both label maps are empty, $\text{IoU} = 1$. The DSC (which is equivalent to the F1-score) reported in our work is lower than the results reported in PHiSeg (baumgartner2019phiseg). The authors used a convention where the DSC is 1.0 if both the predicted and ground-truth slices are empty. We argue that this choice skews results, since an algorithm that always predicts an empty label map would achieve an average DSC equal to the fraction of empty slices in the dataset; e.g., if 40% of the slices in the dataset are empty, the average DSC is also 40%. In contrast, we used the standard definition of DSC, where these cases are undefined and thus excluded from the calculation of the average DSC. To calculate uncertainty maps, we used the marginal entropy of the categorical distribution predicted for each voxel $i$:

$$\mathcal{H}(y_i \mid x) = -\sum_{c=1}^{C} p(y_{ic} = 1 \mid x) \log p(y_{ic} = 1 \mid x)$$
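Under the logit-space Gaussian, the marginal $p(y_i \mid x)$ has no closed form. A minimal sketch, assuming it is estimated by averaging the softmax of the sampled logits at voxel $i$ over $M$ samples (an assumption, not a detail stated above), is:

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of C logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_entropy(logit_samples_at_voxel):
    """Entropy of p(y_i | x), estimated by averaging softmax(eta_i^(m)) over
    the M sampled logit vectors at one voxel."""
    probs = [softmax(eta) for eta in logit_samples_at_voxel]
    n_classes = len(probs[0])
    mean = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    # 0 * log 0 is taken as 0.
    return -sum(p * math.log(p) for p in mean if p > 0.0)
```

Two confident but contradictory samples yield the maximal binary entropy $\log 2$, so such a map highlights regions where samples disagree even when each individual sample is sharp.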
Figure A1 compares sampling from the independent categorical distributions of a deterministic model with sampling from the proposed model. Notice the grainy label noise for the deterministic model. Figures A2 - A5 show additional random samples for the stochastic model for multiple test cases.