An important aim of the medical deep learning field is to develop image processing algorithms that can be deployed in clinical settings. These tools need to be robust to the full range of potential inputs they might receive in a clinical context. We can expect this data to be much more diverse than the typically clean, sanitised datasets on which these algorithms are developed and tested. Even if attempts are made to train on messier datasets with images exhibiting artefacts and other issues, we would expect that eventually the tool will be presented with data it has not seen during training. Typically, deep learning networks perform well when operating in-distribution, but performance can degrade unpredictably and substantially when operating on data out-of-distribution (OOD).
One approach to this problem is to develop networks that provide both predictions and a measure of their uncertainty, enabling decisions to be referred to humans when the networks are presented with difficult or OOD data samples. Bayesian Neural Networks (BNNs) that learn a distribution over weights are capable of this; one popular approach is to approximate a BNN using dropout-based variational inference [gal2016dropout]. Another common approach is to deploy an ensemble of neural networks [lakshminarayanan2017simple]. A comprehensive evaluation of uncertainty methods for classification found that the quality of uncertainty measures degraded as the size of the distributional shift increased [ovadia2019can]. Some work has evaluated these methods in the context of image segmentation but is typically confined to evaluating the utility of the uncertainty measures in-distribution [jungo2018uncertainty, nair2020exploring]; work that does evaluate uncertainty in an OOD setting typically investigates only small dataset shifts, such as increased noise [haas2021uncertainty] or lower-quality scans [mcclure2019knowing].
A second approach is to filter out anomalous data before it is fed to the task-specific network, using a generative model that can quantify the probability that a data sample is drawn from the distribution the task-specific model was trained on. Of the generative approaches, autoregressive methods are attractive for two reasons: firstly, they allow the computation of exact likelihoods; secondly, transformers [vaswani2017attention], a class of architectures well suited to autoregressive prediction, are proving to be highly effective general-purpose models, achieving state-of-the-art performance across a range of tasks in language [devlin2018bert, brown2020language] and, increasingly, vision [dosovitskiy2020image]. The high dimensionality of medical images makes it computationally infeasible to use transformers to model the sequence of raw pixel values, so a recent body of work has instead used transformers to model the compressed discrete latent space of an image obtained from a vector-quantised network such as a VQ-VAE or VQ-GAN [oord2017neural, razavi2019generating, esser2021taming]. This approach has provided state-of-the-art unsupervised pathology segmentation for 2D medical images [pinaya2021unsupervised]. Prior work has also shown that this vector-quantised class of auto-encoding methods can substantially compress 3D medical images [tudosiu2020neuromorphologicaly], indicating that the VQ-GAN + transformer approach might be applied to fully 3D anomaly detection; to our knowledge, however, there is no published work attempting this.
In this work, we make two principal contributions, focusing on the problem of segmentation of haemorrhagic lesions in head CT data. Firstly, we evaluate segmentation uncertainty methods on a range of OOD inputs and demonstrate they can catastrophically fail, producing confidently wrong predictions. Secondly, we use transformers to perform image-wide OOD detection on 3D images. We find this can effectively flag OOD data that segmentation networks fail to perform well on, demonstrating their viability as a filter in clinical settings where robust and fully-automated segmentation pipelines are needed.
In this work, we focus on the challenge of segmenting Intracerebral Haemorrhages (ICH) in head CT data. The following sections summarise the datasets used, the trained segmentation networks, and the approach to training the VQ-GANs and transformers.
We use three main datasets in this work: two head CT datasets (one used for training and an independent one for model evaluation) and a non-head-CT dataset.
The CROMIS dataset is a set of 687 head CT scans used for training all the networks in this paper. The data consists of CTs containing ICH, acquired across multiple sites as part of a trial [cromistrial, wilson2018cerebral]. Ground-truth haemorrhage segmentation masks are available for 221 scans in the dataset.
The KCH dataset was used for algorithm validation. It consists of 47 clinical scans selected for the presence of ICH, all with ground-truth masks provided by an experienced neuroradiologist. This dataset was used to represent in-distribution test data; it was further used to produce a set of corrupted scans to test our algorithms in the near-OOD setting. A range of corruptions was applied to each scan in the dataset, designed to emulate scenarios such as imaging artefacts, image-header errors, and errors in the preprocessing pipeline that is typically applied before data is input to a network. The corruptions comprised: addition of Gaussian noise, inversion through each of the three image planes, skull-stripping, setting the image background to values not equal to 0, global scaling of all image intensities by a fixed factor, and deletion of a set of slices (or chunks) of the image. Any spatial manipulations of the image were also applied to the labels. This totalled 15 corruptions applied to each image, creating a corrupted dataset of 705 images. Examples of the corruptions applied are included in Appendix A.
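A corruption pipeline of this kind can be sketched as follows. This is an illustrative numpy implementation, not the study's code: the noise level, background values, scaling factor, and chunk location are placeholder parameters, and only a representative subset of the 15 corruptions is shown.

```python
import numpy as np

def corrupt(volume, mode, rng=None, value=0.3, sigma=0.1, scale=10.0):
    """Apply one synthetic corruption to a rescaled CT volume.

    Parameter defaults here are illustrative, not the paper's settings.
    """
    rng = rng or np.random.default_rng(0)
    v = volume.copy()
    if mode == "noise":
        v = v + rng.normal(0.0, sigma, size=v.shape)        # additive Gaussian noise
    elif mode == "flip_lr":
        v = v[::-1, :, :]                                   # inversion through one image plane
    elif mode == "background":
        v[volume == 0] = value                              # non-zero background value
    elif mode == "scale":
        v = v * scale                                       # global intensity scaling
    elif mode == "chunk":
        mid = v.shape[2] // 2
        v[:, :, mid:mid + 4] = 0                            # delete a chunk of slices
    return v
```

Spatial corruptions (the flips) would be applied to the label volumes with the same call, mirroring the paper's handling of the masks.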
The Medical Decathlon dataset was used to test our algorithms in the far-OOD setting. It comprises a number of medical images covering a variety of organs and imaging modalities, none of which are head CT. We selected 22 images from the test set of each of the ten classes (or as many as were available in the test set, if fewer than 22). A more detailed description of this dataset can be found in [simpson2019large].
Data processing was harmonised across all datasets as far as possible. All head CT images were registered to MNI space with an affine transformation, resampled to 1 mm isotropic resolution, tightly cropped to a fixed grid, and had their intensities clamped to a fixed window and then rescaled to a common range. Images in the Decathlon dataset were resampled to 1 mm isotropic resolution and either cropped or zero-padded, depending on size, to the same grid. All CT images had their intensities clamped and rescaled as above; all non-CT images were rescaled based on their minimum and maximum values to lie in the same range.
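The intensity and grid harmonisation steps can be sketched as below. Note that the paper's exact clamp window, target range, and grid size are not reproduced here; the [0, 100] HU window, [0, 1] range, and 224³ grid are assumptions for illustration, and the affine registration step is omitted.

```python
import numpy as np

# Placeholder values: the study's exact clamp window and crop grid are assumptions here.
CLAMP_LO, CLAMP_HI = 0.0, 100.0
GRID = (224, 224, 224)

def preprocess_ct(volume):
    """Clamp intensities to a fixed window, rescale to [0, 1], and
    centre-crop or zero-pad to a fixed grid (registration not shown)."""
    v = np.clip(volume, CLAMP_LO, CLAMP_HI)
    v = (v - CLAMP_LO) / (CLAMP_HI - CLAMP_LO)
    out = np.zeros(GRID, dtype=v.dtype)
    slices_src, slices_dst = [], []
    for s, t in zip(v.shape, GRID):
        if s >= t:                               # crop this axis
            off = (s - t) // 2
            slices_src.append(slice(off, off + t))
            slices_dst.append(slice(0, t))
        else:                                    # zero-pad this axis
            off = (t - s) // 2
            slices_src.append(slice(0, s))
            slices_dst.append(slice(off, off + s))
    out[tuple(slices_dst)] = v[tuple(slices_src)]
    return out
```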
2.2 Segmentation networks
We tested three uncertainty methods commonly employed in the literature. The first is intended as a simple baseline and uses the softmax of the network’s output as a per-voxel probability map. The second is an ensemble of neural networks, identical in architecture but each trained on a different subset of the data [lakshminarayanan2017simple]; based on recommendations from [ovadia2019can], we chose an ensemble of 5 networks. Finally, we use an approximation of a BNN obtained through dropout-based variational inference, training each dropout layer with a fixed dropout probability and using 5 passes during inference to approximate the posterior [gal2016dropout].
All networks used the same UNet backbone based on [falk2019u] and implemented in Project MONAI (https://monai.io/) as the ‘BasicUnet’ class, with (32, 32, 64, 128, 256) features in the 5 encoding layers, instance normalisation, and LeakyReLU activations. We trained the networks using a batch size of 3, augmenting with affine and elastic transformations and sampling fixed-size patches. The Dice loss was used except for the baseline network, which was trained using cross-entropy loss, as Dice is known to provide poorly calibrated, overly confident predictions [mehrtash2020confidence]. All networks were optimised using the AMSGrad variant of Adam [reddi2019convergence].
For each network we sought to assign a single uncertainty value to each predicted lesion. First, we produced a per-voxel uncertainty map for each method. For the baseline method this was the per-voxel softmax of the network output. For the remaining networks we used the entropy of the predictions at each voxel, as described in [nair2020exploring], subtracting the entropy from 1 to produce a measure where larger values reflect higher certainty. We produced per-lesion certainty by averaging the per-voxel measures across each lesion, where each separate lesion is taken to be a connected component of the majority-vote prediction of each network.
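The steps above, from stochastic passes to per-lesion certainty, can be sketched as follows. This is a minimal numpy/scipy illustration, assuming foreground probabilities from each pass (or ensemble member) are stacked in one array; function and variable names are ours, not the paper's.

```python
import numpy as np
from scipy import ndimage

def per_lesion_certainty(pass_probs, eps=1e-12):
    """pass_probs: (n_passes, *volume_shape) foreground probabilities from
    the stochastic passes / ensemble members.  Returns {lesion_id: certainty}."""
    p = pass_probs.mean(axis=0)                          # mean foreground probability per voxel
    # binary entropy of the mean prediction (bits), so values lie in [0, 1]
    h = -(p * np.log2(p + eps) + (1 - p) * np.log2(1 - p + eps))
    certainty = 1.0 - h                                  # larger values = more certain
    majority = (pass_probs > 0.5).mean(axis=0) > 0.5     # majority-vote segmentation
    lesions, n = ndimage.label(majority)                 # connected components = lesions
    return {i: float(certainty[lesions == i].mean()) for i in range(1, n + 1)}
```

For the baseline method, the same aggregation would simply average the softmax map over each connected component of the thresholded prediction.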
2.3 VQ-GAN + Transformer networks
Our approach to outlier detection uses a VQ-GAN to compress the information content of each 3D volume into a discrete latent representation, and a transformer to learn the probability density of these representations.
The VQ-GAN [esser2021taming] contains an encoder that takes an input x and produces a latent representation ẑ composed of vectors of dimension n_z. The representation is quantised by finding each vector’s nearest neighbour, as measured by the L2 norm, in a codebook of K n_z-dimensional vectors, and replacing the vector with its nearest neighbour’s codebook index. A decoder uses the quantised latent space to reconstruct the input, x̂. To encourage the network to learn a rich codebook, a discriminator is used to try to differentiate between real and reconstructed images. Our implementation’s encoder contains four levels, each consisting of a convolution with stride 2 and a residual layer, each followed by ReLU layers. This produces a latent space 16× smaller along each dimension, so a 3D input is compressed to 1/16³ of its original number of elements. The codebook comprises K vectors, each of dimension n_z. The decoder also contained four levels, each consisting of a residual layer followed by a transposed convolution with stride 2. The codebook was updated using the exponential moving average, as described in [oord2017neural]. The VQ-GAN paper combined a mean-squared error loss and a perceptual loss [zhang2018unreasonable] for the reconstruction loss; we used both of these and an additional spectral loss [dhariwal2020jukebox]. Given that state-of-the-art anomaly detection results have been reported in 2D using a simpler VQ-VAE with MSE loss [pinaya2021unsupervised], we also performed an ablation study to understand how the additional components of the VQ-GAN contributed to performance. Models were trained using Adam with a learning rate of 1.65 and a batch size of 96 on an Nvidia DGX A100.
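The codebook lookup at the heart of the quantisation step can be illustrated with a minimal sketch (numpy; the flattening of the spatial latent grid into rows of vectors is assumed, and names are ours):

```python
import numpy as np

def quantise(z, codebook):
    """Replace each latent vector with its nearest codebook entry.

    z: (n_vectors, n_z) encoder outputs, flattened over spatial positions.
    codebook: (K, n_z) learned codebook.
    Returns (indices, quantised_vectors)."""
    # squared L2 distance from every latent vector to every codebook entry: (n, K)
    d2 = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    idx = d2.argmin(axis=1)          # nearest-neighbour codebook index per vector
    return idx, codebook[idx]        # discrete representation + quantised latents
```

The `idx` array, reshaped to the 3D latent grid, is the discrete representation the transformer is later trained on.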
After training the VQ-GAN, we can estimate the probability density of the training data using a transformer. Each 3D discrete representation obtained from the trained VQ-GAN is flattened into a 1D sequence x, and the data likelihood is represented as the product of conditional probabilities, p(x) = ∏_i p(x_i | x_{<i}), with the transformer learning the distribution of each x_i by being trained to maximise the log-likelihood of the training data. In addition to estimating the whole-image likelihood p(x), we produced spatial likelihood maps by reshaping the per-element likelihoods p(x_i | x_{<i}) from the 1D sequence back into the 3D shape of the latent representation and upsampling to produce a spatial likelihood map of the same dimensions as the input image. The transformer’s attention mechanism has a quadratic memory dependence on sequence length that makes it difficult to train on long sequences, so we made use of the more efficient Performer architecture [choromanski2020rethinking], which uses a linearised approximation of the attention matrices to allow training on longer sequences. We used a 22-layer Performer with 8 attention heads and a latent representation of size 256. The model was trained with the cross-entropy loss using a batch size of 240 on an Nvidia DGX A100.
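Given the per-position distributions produced by an autoregressive model over the codebook, the whole-image log-likelihood and the spatial map can be computed as in this minimal sketch (names and shapes are illustrative; the upsampling to image resolution is omitted):

```python
import numpy as np

def sequence_log_likelihood(token_probs, tokens, latent_shape):
    """token_probs: (L, K) model distribution over K codebook indices at each
    of L sequence positions, conditioned on the preceding tokens.
    tokens: (L,) observed codebook indices for one image."""
    # log p(x_i | x_<i) for the observed token at every position
    ll_per_token = np.log(token_probs[np.arange(len(tokens)), tokens])
    total = ll_per_token.sum()                    # whole-image log p(x)
    spatial = ll_per_token.reshape(latent_shape)  # 3D map; upsample downstream
    return total, spatial
```

Thresholding `total` gives the whole-image OOD score, while `spatial` (after upsampling) localises the low-likelihood regions.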
Both models were trained on the full CROMIS dataset. Note that this dataset contains pathological images exhibiting haemorrhages: in-distribution here means not healthy scans but scans similar to the segmentation network’s training set, and the aim is to estimate whether a new input is similar enough to that training set for it to be segmented accurately.
3 Experiments and Results
3.1 Segmentation uncertainty
We first examine the performance of the segmentation algorithms in the far-OOD case, where images are of a different organ and/or modality from the intended target for segmentation. In this case, any detection can be considered a false positive (FP). We calculated per-lesion confidence scores for each detection and compared them to the per-lesion scores of every true-positive (TP) detection on the normal head CT dataset. fig:seg_uncertainty_decathlon shows that the distributions of lesion confidence scores for these two datasets overlap regardless of the segmentation uncertainty method used, meaning it is not possible to separate FP detections made on far-OOD data from TP detections on in-distribution data using any of these lesion confidence scores alone. This motivates the need for explicit OOD detection models.
Second, we examine the segmentation networks’ uncertainty performance on near-OOD data: corrupted head CT scans. As these scans contain haemorrhages, both TPs and FPs exist. We defined a TP as a predicted mask with at least 50% overlap with a ground-truth mask, and computed the AUC obtained when using the per-lesion certainty scores to classify detections; see tab:seg_uncertainty_yee. The networks are able to classify lesions relatively well for certain types of corruption, including noise, image flipping, and the removal of ‘chunks’ from the data. However, they perform poorly for images with modified background values or those that have been skull-stripped. For scaled images, all three methods had an AUC below 0.5, showing they often assigned higher confidence values to FP detections than to TP detections.
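The per-lesion evaluation can be sketched as follows. The 50% overlap rule is from the text; the rank-based AUC is a generic Mann-Whitney implementation standing in for whatever tooling the study used, and names are illustrative.

```python
import numpy as np

def is_true_positive(pred_mask, gt_mask, thresh=0.5):
    """A predicted lesion counts as a TP if >= 50% of its voxels
    overlap the ground-truth mask."""
    overlap = np.logical_and(pred_mask, gt_mask).sum() / pred_mask.sum()
    return overlap >= thresh

def auc(scores_tp, scores_fp):
    """Probability a random TP outscores a random FP (rank-based AUC);
    an AUC below 0.5 means FPs tend to receive higher certainty than TPs."""
    wins = sum((t > f) + 0.5 * (t == f) for t in scores_tp for f in scores_fp)
    return wins / (len(scores_tp) * len(scores_fp))
```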
3.2 Transformers for OOD detection
We examined the ability of transformers to filter out OOD inputs based on the whole-image log-likelihood. fig:likelihood_coarse_classes shows the distribution of log-likelihood values for the far-OOD, near-OOD and in-distribution classes (plots for each sub-class can be found in fig:likelihood_fine_classes). tab:likelihood_auc shows the ability of the log-likelihood to distinguish OOD classes from normal head CT data. Performance is perfect in the far-OOD case. For near-OOD data, the classes on which the segmentation uncertainty performed poorly are distinguished well: images with adapted backgrounds, skull-stripping, and global intensity scaling are all distinguished with an AUC of 1. Subtler corruptions, namely added noise and L-R flips, are not distinguished well. These are classes on which the segmentation networks’ uncertainty measures perform well, suggesting these corruptions are closer to in-distribution and explaining why they have been assigned likelihoods more similar to the normal head CT data. This result also suggests transformers and segmentation networks with uncertainty may be used in tandem, with highly OOD images being filtered out by the transformer and the segmentation network providing meaningful uncertainty estimates on images that are only slightly OOD. fig:cromis_outliers shows some qualitative results on real data: the CROMIS volumes assigned the lowest and highest log-likelihood values by the transformer. fig:mse_plots shows that reconstruction MSE alone is unable to separate in-distribution from OOD data, indicating that the transformer component is essential for OOD detection.
tab:ablation_study reports an ablation study for these results. The results demonstrate that the VQ-GAN outperforms the VQ-VAE, and that the perceptual and adversarial losses are needed together to improve performance. Our inclusion of an additional spectral loss provides a modest improvement over the standard VQ-GAN losses. Examination of the spatial likelihood maps in fig:spatial_likelihood_demo shows good localisation of image corruptions. These maps might be used to explain which part of the image has caused the system to flag it as low-likelihood. A common failure mode reported for generative approaches to anomaly detection is the assignment of higher log-likelihoods to very OOD samples [nalisnick2018deep]; this did not occur in our experiments, suggesting that transformer-based anomaly detection methods could be more robust to this type of failure.
| Category | Class | Log-likelihood, mean (SD) | AUC |
|---|---|---|---|
| Non Head CT | Head MR | -7288 (134) | 1.00 |
| | Colon CT | -10809 (789) | 1.00 |
| | Hepatic CT | -10712 (763) | 1.00 |
| | Hippocampal MR | -7465 (20) | 1.00 |
| | Liver CT | -11116 (658) | 1.00 |
| | Lung CT | -9957 (289) | 1.00 |
| | Pancreas CT | -10798 (791) | 1.00 |
| | Prostate MR | -9140 (134) | 1.00 |
| | Spleen CT | -10895 (382) | 1.00 |
| | Cardiac MR | -9661 (318) | 1.00 |
| Corrupted Head CT | Noise | -5796 (253) | 0.49 |
| | BG value=0.3 | -9022 (89) | 1.00 |
| | BG value=0.6 | -8803 (100) | 1.00 |
| | BG value=1.0 | -9979 (127) | 1.00 |
| | Flip L-R | -5850 (253) | 0.55 |
| | Flip A-P | -7435 (205) | 1.00 |
| | Flip I-S | -9036 (165) | 1.00 |
| | Chunk top | -6382 (214) | 0.96 |
| | Chunk middle | -7784 (179) | 1.00 |
| | Skull stripped | -7226 (125) | 1.00 |
| | Scaling 10% | -7436 (119) | 1.00 |
| | Scaling 1% | -7205 (25) | 1.00 |
| Normal Head CT | - | -5803 (256) | - |
We further sought to characterise the relationship between an image’s likelihood and the performance of the segmentation network; we show results for the best-performing (dropout) network in fig:likelihood_vs_seg_fps and include similar results for the other networks in fig:likelihood_vs_seg_fps_extra. While we would not necessarily expect a relationship between the examples the transformer assigns a low likelihood to and the examples the networks perform poorly on, our results indicate a strong one: poor segmentations could be effectively filtered out using a log-likelihood threshold of -7000. Furthermore, the results show that corrupted CT scans assigned likelihoods similar to the in-distribution data tend to be handled better by the segmentation network, supporting our claim that these are indeed the more subtle corruptions.
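As a usage sketch, the filtering step this implies is a simple threshold on the whole-image log-likelihood. Note that the -7000 value is specific to the models and data in this work and would need re-tuning for any other setup; names here are illustrative.

```python
def filter_for_segmentation(volumes, log_likelihoods, threshold=-7000.0):
    """Split inputs into those passed to the segmentation network and
    those referred for human review as likely OOD."""
    accept, refer = [], []
    for vol, ll in zip(volumes, log_likelihoods):
        # low-likelihood volumes are flagged rather than auto-segmented
        (accept if ll >= threshold else refer).append(vol)
    return accept, refer
```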
In a clinical setting, deployed systems must be robust to the full range of data they may be presented with. Our results show that commonly used methods for obtaining uncertainty from segmentation networks fail on OOD data, making confidently wrong predictions in both far- and near-OOD settings. We demonstrate that, with a VQ-GAN for image compression and a transformer for density estimation, we can successfully detect both far- and near-OOD data, and we find the images flagged as OOD by this approach are precisely the ones the segmentation networks struggle with. These results suggest our approach can be used to filter out OOD inputs on which a segmentation network is likely to fail, ensuring it is only run on in-distribution data on which it is likely to perform well. Future work will focus on further validating these findings on real-world clinical datasets.
MG, PW, WS, SO, PN and MJC are supported by a grant from the Wellcome Trust (WT213038/Z/18/Z). MJC and SO are also supported by the Wellcome/EPSRC Centre for Medical Engineering (WT203148/Z/16/Z), and the InnovateUK-funded London AI centre for Value-based Healthcare. PN is also supported by NIHR UCLH Biomedical Research Centre. YM is supported by a grant from the Medical Research Council (MR/T005351/1). The models in this work were trained on NVIDIA Cambridge-1, the UK’s largest supercomputer, aimed at accelerating digital biology.