Diffusion Models for Implicit Image Segmentation Ensembles

12/06/2021
by   Julia Wolleb, et al.
Universität Basel

Diffusion models have shown impressive performance for generative modelling of images. In this paper, we present a novel semantic segmentation method based on diffusion models. By modifying the training and sampling scheme, we show that diffusion models can perform lesion segmentation of medical images. To generate an image specific segmentation, we train the model on the ground truth segmentation, and use the image as a prior during training and in every step during the sampling process. With the given stochastic sampling process, we can generate a distribution of segmentation masks. This property allows us to compute pixel-wise uncertainty maps of the segmentation, and allows an implicit ensemble of segmentations that increases the segmentation performance. We evaluate our method on the BRATS2020 dataset for brain tumor segmentation. Compared to state-of-the-art segmentation models, our approach yields good segmentation results and, additionally, meaningful uncertainty maps.


1 Introduction

Semantic segmentation is an important and well explored area in medical image analysis [biomedseg]. The automated segmentation of lesions in medical images with machine learning has shown good performance [nnunet] and is ready for clinical application to support diagnosis [trial]. In medical applications, it is of high interest to measure the uncertainty of a given prediction, especially when the prediction is used for further treatments such as radiation therapy. In this work, we focus on the BRATS2020 brain tumor segmentation challenge [brats1, brats2, brats3]. This dataset provides four different MR sequences for each patient (namely T1-weighted, T2-weighted, FLAIR, and T1-weighted with contrast enhancement), as well as the pixel-wise ground truth segmentation. An exemplary image is shown in Figure 1.


Figure 1: Exemplary image of the BRATS2020 dataset.

We propose a novel segmentation method based on a Denoising Diffusion Probabilistic Model (DDPM) [DDPM], which can provide uncertainty maps for the produced segmentation mask. An overview of the workflow is shown in Figure 2.

We train a DDPM on the segmentation masks and add the original brain MR image as an image prior to induce the anatomical information. As sampling with DDPMs has a stochastic element in each step, we can generate many different segmentation masks for the same input image with the same pretrained model. This ensemble of segmentations allows us to compute pixel-wise variance maps, which quantify the uncertainty of the generated segmentation. Moreover, averaging the segmentations into a mean map boosts the segmentation performance.



Figure 2: General idea of the implicit generation of segmentation ensembles with diffusion models.

1.1 Related Work

In medical image segmentation, a common method is the application of a U-Net [unet] to predict the segmentation mask for every input image. This approach has been successfully applied to many different tasks [heart, MS, lung]. The state of the art is given by nnU-Nets [nnunet], where the best architecture and hyperparameters are automatically chosen for every specific dataset.


Uncertainty quantification is of high interest in deep learning research [uncertainty]. Bayesian neural networks provide an important tool for uncertainty estimation [bayesianseg, validity, galgar], and were also applied to various medical tasks [medbayes1, medbayes2, medbayes3], including brain tumor segmentation [bratsbayes1, bratsbayes2, bratsbayes3].
During the last year, DDPMs have gained a lot of attention due to their astonishing performance in image generation [beatgans]. Fake images are generated by sampling from Gaussian noise. This sampling scheme follows a Markovian process, and therefore sampling from the same noisy image does not yield the same output image. A different sampling scheme was introduced by Denoising Diffusion Implicit Models (DDIM) [ddim], where sampling is deterministic and can be accelerated by skipping multiple steps. Moreover, meaningful interpolation between images can be achieved. DDPM was further improved by [improving] and [beatgans], where changes in the loss objective, architecture improvements, and classifier guidance during sampling improved the output image quality.


While some recent work applies diffusion models to tasks such as image-to-image translation [unitddpm], style transfer [ilvr], or inpainting [palette], there is so far only very little work on semantic segmentation. Recently, one approach to perform semantic segmentation with a diffusion model was proposed by [diffseg]: a DDPM is trained to reconstruct the image that should be segmented, and an MLP classifier is then applied to the features of the model, which yields a segmentation mask for the original image. In contrast to this method, we train a DDPM directly to generate the segmentation mask.
Simultaneously and independently from us, [segdiff] developed an image segmentation method similar to ours. However, they use a separate encoder for the image and the segmentation. Training such a larger model may be difficult in medical image analysis due to possibly large inputs such as 3D data. Our method uses only one encoder to encode both the image information and the segmentation mask.

2 Method

The goal is to train a DDPM to generate segmentation masks. We follow the idea and implementation proposed in [improving]. The core idea of diffusion models is that for many timesteps T, noise is added to an image x_0. This results in a series of noisy images x_1, ..., x_T, where the noise level is steadily increased from t = 0 (no noise) to t = T (maximum noise). The model follows the architecture of a U-Net and predicts x_{t-1} from x_t for any step t. During training, we know the ground truth for x_{t-1}, and the model is trained with an MSE loss. During sampling, we start from noise x_T and sample for T steps, until we obtain a fake image x_0.
The complete derivations of the formulas below can be found in [DDPM, improving]. The main components of diffusion models are the forward noising process q and the reverse denoising process p_\theta. Following [DDPM], the forward noising process for a given image x_{t-1} at step t is given by

q(x_t | x_{t-1}) := \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I),   (1)

where \beta_t is a variance schedule that is learned by the model, as proposed in [improving]. The idea is that in every step, a small amount of Gaussian noise is added to the image. Doing this for t steps, we can write

q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} \, x_0, (1 - \bar{\alpha}_t) I),   (2)

with \alpha_t := 1 - \beta_t and \bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s. With the reparametrization trick, we can directly write x_t as a function of x_0:

x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I).   (3)
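As a minimal sketch, the closed-form noising of Equation 3 can be written in numpy; the linear beta schedule below is illustrative, not necessarily the paper's exact configuration:

```python
import numpy as np

def closed_form_noising(x0, t, betas, rng):
    """Sample x_t directly from x_0 via Equation 3 (reparametrization trick)."""
    alpha_bar = np.cumprod(1.0 - betas)      # \bar{\alpha}_s for s = 1..T
    eps = rng.standard_normal(x0.shape)      # \epsilon ~ N(0, I)
    xt = np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1.0 - alpha_bar[t - 1]) * eps
    return xt, eps

# illustrative linear beta schedule
T = 1000
betas = np.linspace(1e-4, 0.02, T)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 224, 224))      # a (1, h, w) segmentation mask
xT, eps = closed_form_noising(x0, T, betas, rng)  # at t = T, almost pure noise
```

Note that no loop over timesteps is needed: Equation 3 jumps from x_0 to any x_t in one step, which is what makes training on random timesteps cheap.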

The reverse process is learned by the model parameters \theta and is given by

p_\theta(x_{t-1} | x_t) := \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)).   (4)

As shown in [DDPM], we can then predict x_{t-1} from x_t with

x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z, \quad z \sim \mathcal{N}(0, I).   (5)

We can see in Equation 5 that sampling has a random component z. The proposed sampling scheme is a Markovian process. Note that \epsilon_\theta is the U-Net we train, with input x_t. The noise \epsilon_\theta(x_t, t) that is subtracted from x_t during sampling according to Equation 5 has to be learned by the model. This U-Net is trained with the loss objectives given in [improving].
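A sketch of a single reverse step of Equation 5, with a stand-in function in place of the trained U-Net \epsilon_\theta; the fixed choice \sigma_t^2 = \beta_t used here is one of the options discussed in [DDPM], whereas [improving] learns the variances:

```python
import numpy as np

def reverse_step(xt, t, betas, eps_model, rng):
    """One reverse sampling step (Equation 5): predict x_{t-1} from x_t."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    coef = (1.0 - alphas[t - 1]) / np.sqrt(1.0 - alpha_bar[t - 1])
    mean = (xt - coef * eps_model(xt, t)) / np.sqrt(alphas[t - 1])
    z = rng.standard_normal(xt.shape) if t > 1 else 0.0  # random component z
    return mean + np.sqrt(betas[t - 1]) * z              # sigma_t^2 = beta_t

# usage with a hypothetical stand-in noise predictor
betas = np.linspace(1e-4, 0.02, 10)
zero_model = lambda x, t: np.zeros_like(x)   # placeholder for the U-Net
x = np.ones((2, 2))
x_prev = reverse_step(x, 1, betas, zero_model, np.random.default_rng(0))
```

The noise z is redrawn in every step, which is the source of the stochasticity exploited later for the segmentation ensembles.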
We now modify this idea in order to use diffusion models for semantic segmentation. A visualization of the workflow is given in Figure 3 for the task of brain tumor segmentation.


Figure 3: The training and sampling procedure of our method. In every step t, the anatomical information is induced by concatenating the brain MR image b to the noisy segmentation mask x_{b,t}.

Let b be the given brain MR image of dimension c \times h \times w, where c denotes the number of channels, and h and w denote the image height and width. The ground truth segmentation of the tumor for the input image b is denoted as x_b, and is of dimension 1 \times h \times w. We train a DDPM for the generation of segmentation masks. In the classical DDPM approach, x_b would be the only input we need for training, which results in an arbitrary segmentation mask during sampling. In contrast to that, the goal of our proposed method is not to generate just any segmentation mask: we want a meaningful segmentation mask for the given image b. To achieve this, we add additional channels to the input: we induce the anatomical information present in b by adding it as an image prior to x_b. We do this by concatenating b and x_b, and define X := b \oplus x_b. Consequently, X has dimension (c + 1) \times h \times w.
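In code, the concatenation X = b \oplus x_b is a simple channel-wise stack; the shapes below follow the dimensions stated above:

```python
import numpy as np

c, h, w = 4, 224, 224
rng = np.random.default_rng(0)
b = rng.standard_normal((c, h, w))       # brain MR image with c = 4 sequences
x_b = np.zeros((1, h, w))                # segmentation mask (1 channel)
X = np.concatenate([b, x_b], axis=0)     # X = b (+) x_b along the channel axis
print(X.shape)                           # (5, 224, 224), i.e. (c + 1, h, w)
```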
During the noising process q, we only add noise to the ground truth segmentation x_b:

x_{b,t} = \sqrt{\bar{\alpha}_t} \, x_{b,0} + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I),   (6)

and we define X_t := b \oplus x_{b,t}. Equation 5 is then altered to

x_{b,t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_{b,t} - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(X_t, t) \right) + \sigma_t z, \quad z \sim \mathcal{N}(0, I),   (7)

and results in a slightly denoised x_{b,t-1} of dimension 1 \times h \times w. During sampling, we follow the procedure presented in Algorithm 1. As this sampling scheme follows a Markovian process, sampling twice from the same noisy image x_{b,T} does not result in the same output image x_{b,0}.

Algorithm 1: Sampling procedure. Input: the original brain MRI b. Output: the predicted segmentation mask sample x_{b,0}.
1. Sample x_{b,T} \sim \mathcal{N}(0, I).
2. For t = T, ..., 1: set X_t = b \oplus x_{b,t} and sample x_{b,t-1} according to Equation 7.
3. Return x_{b,0}.
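A compact sketch of this sampling loop, with a hypothetical stand-in for the trained U-Net: only the mask channel is denoised, while the image b is freshly concatenated as a prior at every step.

```python
import numpy as np

def sample_segmentation(b, betas, eps_model, rng):
    """Sampling loop: draw x_{b,T} from N(0, I), then T reverse steps of Eq. 7."""
    T = len(betas)
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    x = rng.standard_normal((1,) + b.shape[1:])       # noisy mask x_{b,T}
    for t in range(T, 0, -1):
        X = np.concatenate([b, x], axis=0)            # X_t = b (+) x_{b,t}
        eps = eps_model(X, t)                         # stand-in for eps_theta(X_t, t)
        coef = (1.0 - alphas[t - 1]) / np.sqrt(1.0 - alpha_bar[t - 1])
        mean = (x - coef * eps) / np.sqrt(alphas[t - 1])
        z = rng.standard_normal(x.shape) if t > 1 else 0.0
        x = mean + np.sqrt(betas[t - 1]) * z
    return x                                          # x_{b,0}

# usage with a toy image and a placeholder noise model
betas = np.linspace(1e-4, 0.02, 5)
b = np.zeros((4, 8, 8))
toy_model = lambda X, t: np.zeros((1,) + X.shape[1:])  # hypothetical U-Net stand-in
mask = sample_segmentation(b, betas, toy_model, np.random.default_rng(0))
```

Calling the loop repeatedly with different random generators produces different masks for the same b, which is exactly the property used for the implicit ensembles.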

3 Dataset and Training Details

We evaluate our method on the BRATS2020 dataset. As described in Section 1, images of four different MR sequences are provided for each patient, which are stacked into 4 channels. We slice the 3D MR scans into axial slices. Since tumors rarely occur in the upper or lower part of the brain, we exclude the lowest 80 and the uppermost 110 slices. For intensity normalization, we cut the top and bottom percentile of the pixel intensities. We crop the images to a size of (4, 224, 224). The provided ground truth labels contain four classes: background, GD-enhancing tumor, peritumoral edema, and the necrotic and non-enhancing tumor core. We merge the three tumor classes into one and therefore define the segmentation problem as a pixel-wise binary classification. Our training set includes 16,298 images originating from 332 patients, and the test set comprises 1,082 images originating from 37 patients.
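A sketch of the per-slice preprocessing described above; where the text leaves a detail open (the exact percentile cut and the rescaling to [0, 1]), the values below are assumptions:

```python
import numpy as np

def preprocess_slice(volume, labels):
    """Clip intensities at the 1st/99th percentile (assumed values), rescale to
    [0, 1], and merge the three tumor classes into one binary mask."""
    lo, hi = np.percentile(volume, [1, 99])
    volume = np.clip(volume, lo, hi)
    volume = (volume - lo) / (hi - lo + 1e-8)
    mask = (labels > 0).astype(np.float32)   # label 0 = background; rest = tumor
    return volume, mask

rng = np.random.default_rng(0)
vol = rng.normal(size=(4, 224, 224))              # 4 stacked MR sequences
lab = rng.choice([0, 1, 2, 4], size=(224, 224))   # BRATS label convention
vol_n, mask = preprocess_slice(vol, lab)
```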
The hyperparameters for our DDPM model are described in the appendix of [improving]. We choose a linear noise schedule for T = 1000 steps. We train the model with the hybrid loss objective, with a learning rate of 10^{-4} for the Adam optimizer and a batch size of 16. The number of channels in the first layer is chosen as 128, and we use one attention head at resolution 16. We train the model for 40,000 iterations.

4 Results and Discussion

During evaluation, we take an image from the test set, follow Algorithm 1, and produce a segmentation mask. This mask is thresholded at 0.5 to obtain a binary segmentation. We compare our performance against nnU-Net on the test set with respect to the Dice score, the Jaccard index, and the 95th percentile Hausdorff distance (HD95). We achieve good results with respect to all metrics, as can be seen in Table 1. As baseline, we report the segmentation scores of the nnU-Net.
As described in Section 2, sampling with a DDPM follows a stochastic process. Therefore, sampling twice for the same brain MRI b does not result in the same segmentation mask prediction x_{b,0}.
Exploiting this property, we can implicitly generate an ensemble of segmentation masks without having to train a new model. Using the same model and the same input MR image b, we can implicitly generate an ensemble of arbitrarily many samples. This ensemble can then be used to boost the segmentation performance. For every image of the test set, we sample 5 different segmentation masks. This implicitly defines an ensemble obtained by averaging over the 5 masks and thresholding the result at 0.5. We report the results for this ensemble in Table 1. We see that already an ensemble of 5 increases the performance of our approach.
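The ensembling step itself amounts to a pixel-wise average followed by a threshold; a minimal sketch:

```python
import numpy as np

def ensemble_mask(samples, threshold=0.5):
    """Average n sampled binary masks and threshold the mean map at 0.5."""
    mean_map = np.mean(samples, axis=0)
    return (mean_map >= threshold).astype(np.uint8)

# toy example: 5 sampled masks over a 1x2 pixel grid
samples = np.array([
    [[1, 0]], [[1, 0]], [[1, 1]], [[0, 0]], [[1, 0]],
])  # pixel 0 segmented in 4/5 runs, pixel 1 in 1/5
fused = ensemble_mask(samples)
print(fused)  # pixel 0 kept, pixel 1 rejected
```

With a 0.5 threshold this is simply a pixel-wise majority vote over the sampled masks.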
In 55 out of 1,082 cases, our method produces an empty segmentation mask, which results in a Dice score of zero. This substantially lowers the average Dice score, and the HD95 cannot be computed. If we disregard those 55 cases, we can report the HD95 score; the corresponding average Dice and Jaccard indices are reported in square brackets in Table 1. The nnU-Net produced empty masks for 20 images of the test set.

Method                     Dice            HD95    Jaccard
Ours, for 1 sampling run   0.837 [0.880]   7.311   0.759 [0.799]
Ours, ensemble of 5 runs   0.850 [0.897]   6.253   0.781 [0.824]
nnU-Net                    0.891 [0.917]   4.741   0.831 [0.859]

Table 1: Segmentation scores of our method and nnU-Net on different metrics.

For visualization of the uncertainty maps, we select four exemplary images b_1, b_2, b_3, and b_4 from the test set. By sampling 100 runs for each of these images, we can generate uncertainty maps by computing the pixel-wise variance over the 100 samples. In Figure 4, we present one channel of the original brain MRI (the T1-weighted MR sequence with contrast enhancement), the ground truth segmentation, two different sampled segmentation masks, as well as the mean and variance maps of the 100 sampling runs. We can clearly identify the areas where the model was uncertain. Moreover, by thresholding the mean map at 0.5, we can produce the ensembled segmentation mask.
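The uncertainty map is just the pixel-wise variance over the sampled masks: it vanishes where all runs agree and peaks at 0.25 where the runs are split evenly. A minimal sketch:

```python
import numpy as np

def uncertainty_map(samples):
    """Pixel-wise variance over n sampled binary segmentation masks."""
    return np.var(np.asarray(samples, dtype=np.float64), axis=0)

# two pixels: all runs agree on the first, split 50/50 on the second
samples = [[1.0, 1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 0.0]]
var_map = uncertainty_map(samples)
# variance is 0.0 where the runs agree and 0.25 where they are split evenly
```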

In Table 2, we report the Dice score and HD95 for this ensemble mask, as well as the average Dice score and HD95 over the 100 samples. We see that the ensemble boosts the performance in most cases.


Table 2: Segmentation scores for the examples b_1, b_2, b_3, and b_4 presented in Figure 4.

Example   avg. Dice   ensemble Dice   avg. HD95   ensemble HD95
b_1       0.962       0.975           2.39        1.00
b_2       0.942       0.974           3.81        1.77
b_3       0.837       0.857           24.60       18.90
b_4       0.916       0.937           7.96        9.06

[Figure 4 panels, left to right: Original Image, Ground Truth, Sample 1, Sample 2, Mean Map, Variance Map]

Figure 4: Examples of the produced mean and variance maps for 100 sampling runs.

In Figure 5, we plot the number of samples in the ensemble against the Dice score for the four examples b_1 to b_4. We can see that already an ensemble of five samples increases the performance, and that the curve then flattens.


Figure 5: Performance of the ensemble with respect to the number of samples for the examples b_1, b_2, b_3, and b_4 presented in Figure 4.

5 Conclusion

With a modification to DDPMs, we created a model for biomedical image segmentation. Using the stochastic sampling process, our method allows implicit ensembling of different segmentation masks for the same input brain MR image, without having to train a new model. We showed that ensembling those segmentation masks increases the performance of the model with respect to different segmentation scores. Moreover, we can generate meaningful uncertainty maps by computing the variance over the different segmentation masks. This is of great interest in clinical applications, where we want to measure the uncertainty of the model's decision.
Future work will include a comparison to uncertainty maps produced by Bayesian neural networks. We will also investigate the segmentation of the different tumor classes provided by the BRATS2020 challenge.

This research was supported by the Novartis FreeNovation initiative and the Uniscientia Foundation (project # 147-2018).

References