To lessen the need for large-scale annotated datasets, researchers have recently explored weaker forms of supervision [kervadec2020bounding, valvano2021learning], consisting of weak annotations that are easier and faster to collect. Unfortunately, weak labels provide lower quality training signals, making it necessary to introduce regularisation to prevent model overfitting. Examples of regularisation are: forcing the model to produce similar predictions for similar inputs [ouali2020overview, valvano2019temporal], or using prior knowledge about object shape [kervadec2019constrained, zhou2019prior], intensity [review2016incorporating], and position [kayhan2020translation].
Data-driven shape priors learned by Generative Adversarial Networks (GANs) are popular regularisers [yi2019generative], exploiting the availability of unpaired masks to improve training. Recently, GANs have been used in weakly supervised learning, showing that they can provide training signals for the unlabelled pixels of an image [zhang2020accl]. Moreover, multi-scale GANs also provide information on multi-scale relationships among pixels [valvano2021learning], and can be easily paired with attention mechanisms [valvano2021learning, zhang2019self] to focus on specific objects and boost performance. However, GANs can be difficult to optimise, and they require a set of compatible masks for training. These masks, annotated on images from a different data source, must contain annotations for exactly the same classes used to train the segmentor. Moreover, the structures to segment must be similar across datasets to limit the risk of covariate shift. For example, there is no guarantee that optimising a multi-scale GAN using masks from a paediatric dataset will not introduce biases in a weakly supervised segmentor meant to segment images of elderly patients.
Thus, multi-scale GANs are not always a feasible option. In these cases, it would be helpful to introduce multi-scale relationships without relying on unpaired masks. Herein, we show that it is possible to do so without performance loss. Our contributions are: i) we present a novel self-supervised method to introduce multi-scale shape consistency without relying on unpaired masks for training; ii) we train a shape-aware segmentor coupling multi-scale predictions and attention mechanisms through a mask-free self-supervised objective; and iii) we show performance gains comparable to those of GANs, without the need for unpaired masks. We summarise our idea in Fig. 1.
2 Related Work
Weakly-supervised Learning for Image Segmentation. Recent research has explored weak annotations to supervise models, including: bounding boxes [kervadec2020bounding], image-level labels [patel2021weakly], point clouds [qu2020weakly], and scribbles [lin2016scribblesup, can2018learning, dorent2020scribble, valvano2021learning]. Although it is possible to extend the proposed approach to other types of weak annotations, herein we focus on scribbles, which have been shown to be convenient to collect in medical imaging, especially when annotating nested structures [can2018learning].
A standard way to improve segmentation with scribbles is to post-process model predictions using Conditional Random Fields (CRFs) [lin2016scribblesup, can2018learning]. Recent work avoids the post-processing step and the need to tune the CRF parameters by including learning constraints during training. For example, [belharbi2020deep] uses a max-min uncertainty regulariser to limit the segmentor flexibility, while other approaches regularise training using global statistics, such as the size of the target region [zhou2019prior, kervadec2019constrained, kervadec2020bounding] or topological priors [kervadec2020bounding]. Although they increase model performance, these constraints apply only under specific assumptions about the objects and usually require prior knowledge about the structure to segment. As a result, these methods face difficulty when dealing with pathology or uncommon anatomical variants. In contrast, we do not make strong assumptions: we use a general self-supervised regularisation loss, optimising the segmentor to maintain multi-scale structural consistency in the predicted masks.
Multi-scale Consistency and Attention. Multi-scale consistency is not new to medical image segmentation. For example, deep supervision uses undersampled ground-truth segmentations to supervise a segmentor at multiple resolution levels [dou20173d]. Unlike these methods, we cannot afford to undersample the available ground-truth annotations because scribbles, which are thin structures, would risk disappearing at lower scales.
Other methods introduce a shape prior by training GAN discriminators with a set of compatible segmentation masks [zhang2020accl, valvano2021learning]. Instead, we remove the need for full masks during training, and we impose multi-scale consistent predictions through an architectural bias localised inside attention gates within the segmentor.
Attention has been widely adopted in deep learning [Jetley2018] as it suppresses irrelevant or ambiguous information in the feature maps. Recently, attention was also successfully used in image segmentation [oktay2018attention, schlemper2019attention, sinha2020multi]. While standard approaches do not explicitly constrain the learned attention maps, Valvano et al. [valvano2021learning] have recently shown that conditioning the attention maps to be semantic increases model performance. In particular, they condition the attention maps through an adversarial mask discriminator, which requires a set of unpaired masks to work. Herein, we replace the mask discriminator with a more straightforward and general self-supervised consistency objective, obtaining attention maps coherent with the segmentor predictions at multiple scales.
Self-supervised Learning for Medical Image Segmentation. Self-supervised learning studies how to create supervisory signals from data using pretext tasks: i.e. easy surrogate objectives aimed at reducing human intervention requirements. Several pretext tasks have been proposed in the literature, including image in/out-painting [zhou2019models], superpixel segmentation [ouyang2020self], coordinate prediction [bai2019self], context restoration [chen2019self] and contrastive learning [chaitanya2020contrastive]. After a self-supervised training phase, these models need a second-stage fine-tuning on the segmentation task. Unfortunately, choosing a proper pretext task is not trivial, and pre-trained features may not generalise well if unrelated to the final objective [zamir2018taskonomy]. Hence, our method is more similar to those using self-supervision to regularise training, such as using transformation consistency [xie2020pgl] and feature prediction [valvano2019temporal].
3 Proposed Approach
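Notation. We use capital Greek letters to denote functions $\Phi(\cdot)$, and italic lowercase letters for scalars $s$. Bold lowercase letters denote two-dimensional images $\mathbf{x} \in \mathbb{R}^{h \times w}$, with $h, w \in \mathbb{N}$ natural numbers denoting image height and width. Lastly, we denote tensors $T \in \mathbb{R}^{n \times m \times o}$ using uppercase letters, with $n, m, o \in \mathbb{N}$.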
Method Overview. We assume access to pairs of images $\mathbf{x}$ and their weak annotations $\mathbf{y}_s$ (in our case, $\mathbf{y}_s$ are scribbles), which we denote with the tuples $(\mathbf{x}, \mathbf{y}_s)$. We present a segmentor incorporating a multi-scale prior learned in a self-supervised manner. We introduce the shape prior through a specialised attention gate residing at several abstraction levels of the segmentor. These gates produce segmentation masks as an auxiliary task, allowing them to construct semantic attention maps used to suppress background activations in the extracted features. As our model predicts and refines the segmentation at multiple scales, we refer to these attention modules as Pyramid Attention Gates (PyAG).
Model Architecture and Training. The segmentor is a modified UNet [ronneberger2015u] with batch normalisation [ioffe2015batch]. The encoder and decoder of the UNet are interconnected through skip connections, which propagate features across convolutional blocks at multiple depth levels $d$. We leave the encoder as in the original framework, while we modify the decoder at each level, as illustrated in Fig. 2. In particular, we first process the extracted features with two convolutional layers, as in the standard UNet. Next, we refine them with the introduced PyAG module, shown on a light yellow background on the right side of Fig. 2. Each PyAG module consists of a classifier, a background extraction step, and a multiplicative gating operation. As a classifier, we use a convolutional layer with $c$ filters of size $1 \times 1 \times k$, with $c$ the number of segmentation classes including the background, and $k$ the number of input channels. Given an input feature map at depth $d$, the classifier predicts a multi-channel score map that we pass through a softmax. The resulting tensor assigns a probabilistic value between 0 and 1 to each spatial location. We make this tensor a lower-resolution version of the predicted segmentation mask using the self-supervised consistency constraint:
where $d$ is the depth level, $i$ is an index denoting the class, $\tilde{\mathbf{y}}^{(d)}$ is the prediction at depth $d$, and $\tilde{\mathbf{y}}$ is the final prediction of the model (footnote 1: here we assume that the predicted $\tilde{\mathbf{y}}$ is a mask, not a scribble). Intuitively, this hypothesis derives from the observation that unlabelled pixels in the image have an intrinsic uncertainty: thus, the segmentor will look for clues in the image (e.g. anatomical edges and colours) to solve the segmentation task. Since eq. 3 does not limit model flexibility on the unlabelled pixels, we empirically confirm our hypothesis. Notice that, unlike [valvano2021learning], we condition $\tilde{\mathbf{y}}^{(d)}$ with this self-supervised constraint rather than with a multi-scale discriminator.
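A minimal NumPy sketch of one way such a multi-scale consistency term could be computed is given below. It assumes a pixel-wise cross-entropy between each depth-level prediction and an average-pooled copy of the final prediction; the consistency measure, the pooling operator, and the names `avg_pool` and `multiscale_consistency` are illustrative assumptions rather than details stated in the text.

```python
import numpy as np

def avg_pool(p, factor):
    """Downsample a (c, h, w) probability map by average pooling (assumed operator)."""
    c, h, w = p.shape
    return p.reshape(c, h // factor, factor, w // factor, factor).mean(axis=(2, 4))

def multiscale_consistency(final_pred, depth_preds, eps=1e-8):
    """Sum over depth levels of the cross-entropy between the pooled final
    prediction (used as a constant target) and the depth-level prediction.

    final_pred:  (c, h, w) softmax output of the segmentor.
    depth_preds: list of (c, h_d, w_d) softmax outputs, one per depth level,
                 with h divisible by h_d and w divisible by w_d.
    """
    loss = 0.0
    for pred_d in depth_preds:
        factor = final_pred.shape[1] // pred_d.shape[1]
        target = avg_pool(final_pred, factor)
        # Cross-entropy over classes, averaged over spatial locations.
        loss += float(-(target * np.log(pred_d + eps)).sum(axis=0).mean())
    return loss
```

In a deep learning framework, the pooled target would typically be detached from the computational graph so that gradients flow only through the depth-level predictions.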
To prevent affecting the final prediction, we propagate the self-supervised training gradients only through the attention gates and the segmentor encoder, as we graphically show in Fig. 2, left. We further constrain the segmentor to reuse the extracted information by suppressing the activations in the spatial locations of the feature map which can be associated with the background (Fig. 2, right). This multiplicative gating operation can be formally defined as:
where $\tilde{\mathbf{y}}^{(d)}_{bg}$ is the background channel of the predicted mask at depth level $d$. The extracted features are finally upsampled to the new resolution level and processed by the next convolutional block.
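The gate described above (classifier, background extraction, multiplicative gating) can be sketched in NumPy as follows. The exact gating form is an assumption: here the features are multiplied by one minus the background probability, and channel 0 is taken to be the background. The function name `pyramid_attention_gate` and its arguments are hypothetical.

```python
import numpy as np

def softmax(scores, axis=0):
    """Numerically stable softmax along the given axis."""
    z = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def pyramid_attention_gate(features, weights, bias, background_index=0):
    """features: (k, h, w) feature map; weights: (c, k) and bias: (c,) of a 1x1 convolution.

    Returns the gated features and the (c, h, w) low-resolution class probabilities.
    """
    # A 1x1 convolution is a per-pixel linear map over the channel axis.
    scores = np.einsum("ck,khw->chw", weights, features) + bias[:, None, None]
    probs = softmax(scores, axis=0)
    # Assumed gating: keep activations where the pixel is unlikely to be background.
    attention = 1.0 - probs[background_index]
    gated = features * attention[None, :, :]
    return gated, probs
```

With this form, locations confidently predicted as background are driven towards zero, while foreground locations pass through almost unchanged.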
To supervise the model with scribbles, we use the Partial Cross-Entropy (PCE) loss [tang2018normalized] on the final prediction $\tilde{\mathbf{y}}$. By multiplying the cross-entropy with a labelled pixel identifier $\mathbb{1}(\mathbf{y}_s)$, the PCE avoids loss contributions on the unlabelled pixels. The role of the masking function is to return 1 for annotated pixels and 0 otherwise. Mathematically, we formulate the weakly-supervised loss as:
$$\mathcal{L}_{PCE} = \mathbb{1}(\mathbf{y}_s) \cdot \Big[ -\sum_{i=1}^{c} \mathbf{y}_{s,i} \log \tilde{\mathbf{y}}_i \Big],$$
with $\mathbf{y}_s$ the ground truth scribble annotation.
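As an illustration of the masking idea, the NumPy sketch below computes a partial cross-entropy in which only annotated pixels contribute. Encoding unlabelled pixels as all-zero vectors in the scribble tensor and averaging over the number of annotated pixels are assumptions of this example.

```python
import numpy as np

def partial_cross_entropy(pred, scribble, eps=1e-8):
    """pred:     (c, h, w) softmax output of the segmentor.
    scribble: (c, h, w) one-hot labels on annotated pixels, all zeros elsewhere.
    """
    # Masking function: 1 for annotated pixels, 0 otherwise.
    labelled = scribble.sum(axis=0) > 0
    # Pixel-wise cross-entropy; it is already zero where the scribble is empty,
    # and the explicit mask makes the intent clear.
    ce = -(scribble * np.log(pred + eps)).sum(axis=0)
    n_labelled = labelled.sum()
    # Average over annotated pixels only (assumed normalisation).
    return float((ce * labelled).sum() / max(n_labelled, 1))
```

Predictions on unlabelled pixels do not change the value of this loss, which is why additional regularisation is needed to shape them.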
Considering both weakly-supervised and self-supervised objectives, the overall cost function becomes: $\mathcal{L} = \mathcal{L}_{PCE} + a \cdot \mathcal{L}_{Self}$, where $a$ is a scaling factor that balances training between the two costs. Similar to [valvano2021learning], we find it beneficial to use a dynamic value for $a$, which maintains a fixed ratio between the supervised and regularisation costs. In particular, we set $a = a_0 \cdot \|\mathcal{L}_{PCE}\| / \|\mathcal{L}_{Self}\|$, where $a_0 = 0.1$ is meant to give more importance to the supervised objective. We minimise $\mathcal{L}$ using the Adam optimiser [kingma2014adam] with a learning rate of 0.0001 and a batch size of 12.
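The dynamic weighting can be sketched as below, assuming the ratio is computed from the magnitudes of the two scalar losses and that the weight is treated as a constant during backpropagation. The default `a0=0.1` reflects the factor of 10 between supervised and regularisation costs mentioned in the evaluation protocol; the function names are hypothetical.

```python
def dynamic_weight(supervised_loss, regularisation_loss, a0=0.1, eps=1e-12):
    """Scale factor that keeps the weighted regulariser at a0 times the supervised cost."""
    return a0 * abs(supervised_loss) / (abs(regularisation_loss) + eps)

def total_loss(supervised_loss, regularisation_loss, a0=0.1):
    """Supervised cost plus the dynamically rescaled regularisation cost."""
    # In a training framework, `a` would be computed without tracking gradients.
    a = dynamic_weight(supervised_loss, regularisation_loss, a0)
    return supervised_loss + a * regularisation_loss
```

For example, with a supervised cost of 2.0 and a regularisation cost of 5.0, the regulariser is rescaled to contribute 0.2 to the total.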
ACDC [bernard2018deep] has cardiac MRIs of 100 patients, with manual segmentations of the right ventricle (RV), left ventricle (LV) and myocardium (MYO) at the end-systolic and end-diastolic cardiac phases. We resample images to the average resolution of 1.51 mm² and crop/pad them to a fixed size in pixels. We normalise data by removing the patient-specific median and dividing by the interquartile range.
CHAOS [chaos] contains abdominal images from 20 different patients, with manual segmentations of liver, kidneys, and spleen. We test our method on the available T1 in-phase images. We resample images to a resolution of 1.89 mm², normalise them to the range between -1 and 1, and then crop them to a fixed size in pixels.
LVSC [suinesiaputra2014collaborative] has cardiac MRIs of 100 subjects, with manual segmentations of the left ventricular myocardium (MYO). We resample images to the average resolution of 1.45 mm² and crop/pad them to a fixed size in pixels. We normalise data by removing the patient-specific median and dividing by the interquartile range.
PPSS [luo2013pedestrian] has (non-medical) RGB images of pedestrians with occlusions. Images were obtained from 171 different surveillance videos and cameras. There are manual segmentations for six pedestrian parts: face, hair, arms, legs, upper clothes, and shoes. We resample all images to the spatial resolution of the segmentation masks, then normalise image intensities to a fixed range.
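The median and interquartile-range normalisation and the crop/pad step used for the cardiac datasets can be sketched as follows. Centring the crop and padding with zeros are assumptions of this example, and the helper names are hypothetical.

```python
import numpy as np

def robust_normalise(volume):
    """Subtract the patient-specific median and divide by the interquartile range."""
    median = np.median(volume)
    q75, q25 = np.percentile(volume, [75, 25])
    iqr = q75 - q25
    # Guard against constant volumes, where the interquartile range is zero.
    return (volume - median) / (iqr if iqr > 0 else 1.0)

def crop_or_pad(image, size):
    """Centre-crop or zero-pad a 2D image to (size, size)."""
    out = np.zeros((size, size), dtype=image.dtype)
    h, w = image.shape
    ch, cw = min(h, size), min(w, size)
    src_y, src_x = (h - ch) // 2, (w - cw) // 2
    dst_y, dst_x = (size - ch) // 2, (size - cw) // 2
    out[dst_y:dst_y + ch, dst_x:dst_x + cw] = image[src_y:src_y + ch, src_x:src_x + cw]
    return out
```

Using the median and interquartile range rather than the mean and standard deviation makes the normalisation less sensitive to a few very bright voxels.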
Scribbles. The above datasets provide fully-annotated masks. To test the advantages of our approach in weakly-supervised learning, we use the manual scribble annotations provided for ACDC in [valvano2021learning]. For the remaining datasets, we follow the guidelines provided by Valvano et al. [valvano2021learning] to generate synthetic scribbles using binary erosion operations or random walks inside the segmentation masks.
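To give an idea of how erosion can turn a full mask into a scribble-like annotation, the sketch below repeatedly erodes a binary mask with a cross-shaped structuring element and stops just before the mask would vanish. This is a simplified stand-in written for illustration, not the procedure of [valvano2021learning]; the stopping rule and function names are assumptions.

```python
import numpy as np

def binary_erosion(mask):
    """One erosion step with a 4-connected (cross-shaped) structuring element."""
    padded = np.pad(mask.astype(bool), 1, constant_values=False)
    return (padded[1:-1, 1:-1]
            & padded[:-2, 1:-1] & padded[2:, 1:-1]
            & padded[1:-1, :-2] & padded[1:-1, 2:])

def erosion_scribble(mask, min_pixels=1, max_iter=100):
    """Erode the mask until one more erosion would leave fewer than min_pixels."""
    current = mask.astype(bool)
    for _ in range(max_iter):
        eroded = binary_erosion(current)
        if eroded.sum() < min_pixels:
            break
        current = eroded
    return current
```

The resulting pixels always lie inside the original mask, so they can serve as sparse labels for the corresponding class.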
Setup. We divide ACDC, LVSC, and CHAOS data into groups of 70%, 15% and 15% of patients for the train, validation, and test sets, respectively. In PPSS, we follow the recommendations in [luo2013pedestrian], using images from the first 100 cameras to train (90%) and validate (10%) our model, and images from the remaining 71 cameras to test it.
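A patient-level split such as the one described above can be sketched as follows. Splitting by patient identifier, rather than by image, keeps all slices of a patient in the same subset; the shuffling seed and the rounding of subset sizes are assumptions of this example.

```python
import random

def split_patients(patient_ids, fractions=(0.70, 0.15, 0.15), seed=0):
    """Split patient identifiers into train, validation and test lists."""
    ids = sorted(patient_ids)
    # Deterministic shuffle so that the split is reproducible.
    random.Random(seed).shuffle(ids)
    n_train = round(fractions[0] * len(ids))
    n_val = round(fractions[1] * len(ids))
    train = ids[:n_train]
    val = ids[n_train:n_train + n_val]
    test = ids[n_train + n_val:]
    return train, val, test
```

For 100 patients this yields subsets of 70, 15 and 15 patients.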
4.2 Evaluation Protocol
We compare segmentation performance of our method, termed UNetPyAG, to:
UNet: Trained on scribbles using the PCE loss [tang2018normalized].
UNetComp.: UNet segmentor whose training is regularised with the Compactness loss proposed by [liu2020shape], which models a generic shape compactness prior and prevents the appearance of scattered false positives/negatives in the generated masks. The compactness prior is mathematically defined as: $\mathcal{L}_{Comp} = \frac{P^2}{4 \pi A}$, where $P$ is the perimeter length and $A$ is the area of the generated mask. As for our method, we dynamically rescale this regularisation term to be 10 times smaller than the supervised cost (Sec. 3).
UNetCRF: Lastly, we consider post-processing the previous UNet predictions with a CRF to better capture the object boundaries [chen2017deeplab]. (Footnote 2: the CRF models the pairwise potentials between pixels using weighted Gaussians, each with its own weight and parametrised by scale factors. For ACDC and LVSC, we use the cardiac segmentation parameters in [can2018learning]; for CHAOS and PPSS, we tuned them for each dataset.)
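To make the behaviour of the compactness prior used by the UNetComp. baseline concrete, the sketch below evaluates a perimeter-squared over area ratio on a binary mask. It assumes the common isoperimetric form P²/(4πA) and measures the perimeter as the number of pixel edges between foreground and background; a training loss would instead use a differentiable approximation on soft predictions.

```python
import numpy as np

def compactness(mask):
    """Perimeter squared over 4*pi*area for a binary 2D mask (lower is more compact)."""
    m = np.pad(mask.astype(int), 1)
    area = m.sum()
    # Count pixel edges separating foreground from background along both axes.
    perimeter = np.abs(np.diff(m, axis=0)).sum() + np.abs(np.diff(m, axis=1)).sum()
    return perimeter ** 2 / (4 * np.pi * area) if area > 0 else 0.0
```

Elongated or ring-shaped masks obtain a higher value than a square of the same area, which is consistent with the prior penalising non-compact structures such as the myocardium.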
While our method does not need a set of unpaired masks for training, we also compare with methods which learn the shape prior from masks:
UNetAAG [valvano2021learning]: The method upon which we build our model by replacing the multi-scale GAN with self-supervision. The subscript AAG stands for Adversarial Attention Gates, which couple adversarial signals and attention.
DCGAN: We consider a standard GAN, learning the shape prior from unpaired masks. This model is the same as UNetAAG, but without attention gates and multi-scale connections between segmentor and discriminator.
ACCL [zhang2020accl]: It trains with scribbles using a PatchGAN discriminator to provide adversarial signals, and with the PCE loss [tang2018normalized] on the annotated pixels.
We perform 3-fold cross-validation and measure segmentation quality using Dice and IoU scores and the Hausdorff Distance. We use the Wilcoxon test to assess whether improvements w.r.t. the second-best model are statistically significant.
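For reference, the two overlap metrics can be computed on binary masks as in the sketch below. Returning 1.0 when both masks are empty is a convention chosen for this example.

```python
import numpy as np

def dice_score(pred, target):
    """Dice coefficient between two binary masks."""
    pred, target = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    return float(2 * intersection / total) if total > 0 else 1.0

def iou_score(pred, target):
    """Intersection over Union (Jaccard index) between two binary masks."""
    pred, target = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float(intersection / union) if union > 0 else 1.0
```

The two scores are monotonically related (Dice = 2 IoU / (1 + IoU)), so they rank individual predictions identically while averaging differently across samples.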
As shown in Fig. 4 (top), our method performs best among the approaches that do not require extra masks for training. In particular, a simple UNet has unsatisfactory performance, but regularisation helps considerably. Adding the compactness loss helps more with compact shapes, such as those in ACDC, CHAOS and PPSS, while it can be harmful when dealing with non-compact masks, such as that of the doughnut-shaped myocardium in LVSC.
Post-processing the segmentor predictions with CRF can increase performance when object boundaries are well defined. In contrast, we could not obtain an improvement on CHAOS data, where using CRF made segmentation worse according to all metrics.
On LVSC, the introduced multi-scale shape consistency prior tends to make the model a bit less conservative on the most apical and basal slices of the cardiac MRI. Unfortunately, whenever there is a predicted mask but the manual segmentation is empty, the Hausdorff distance peaks. In fact, by definition, the distance assumes the maximum possible value (i.e. the image dimension) whenever one of the masks is empty, which makes the performance distribution on the test samples broader (see Hausdorff distance box plots for LVSC, Fig. 4, top).
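The effect of empty masks on the Hausdorff distance can be reproduced with the brute-force sketch below. Interpreting "the image dimension" as the largest side of the image, and returning 0 when both masks are empty, are assumptions of this example.

```python
import numpy as np

def hausdorff_distance(a, b):
    """Symmetric Hausdorff distance between the foreground pixels of two binary 2D masks."""
    pa, pb = np.argwhere(a), np.argwhere(b)
    if len(pa) == 0 and len(pb) == 0:
        return 0.0
    if len(pa) == 0 or len(pb) == 0:
        # One mask is empty: fall back to the maximum value (assumed: largest image side).
        return float(max(a.shape))
    # Pairwise Euclidean distances between all foreground pixels of a and b.
    d = np.linalg.norm(pa[:, None, :] - pb[None, :, :], axis=-1)
    return float(max(d.min(axis=1).max(), d.min(axis=0).max()))
```

A single spurious prediction on a slice with an empty manual segmentation therefore produces the largest possible distance, which broadens the distribution of this metric across test samples.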
On CHAOS, Dice and IoU are more skewed for methods not using unpaired masks for training (Fig. 4, top row). This happens because CHAOS is a small dataset, and optimising models using only scribble supervision is challenging. In contrast, the extra knowledge of unpaired masks may help (bottom row).
Finally, we compare our method with approaches using unpaired masks for training (Fig. 4, bottom). We find competitive performance on all datasets. While, in some cases, UNetAAG performs slightly better than UNetPyAG, we emphasise that our approach also works without unpaired masks.
We introduced a novel self-supervised learning strategy for semantic segmentation. Our approach consists of predicting masks at multiple resolution levels and enforcing multi-scale segmentation consistency. We use these multi-scale predictions as part of attention gating operations, restricting the model to re-use the extracted information on the object shape and position. Our method performs considerably better than other scribble-supervised approaches while having comparable performance to approaches requiring additional unpaired masks to regularise their training. Hoping to inspire future research, we release the code used for the experiments at https://vios-s.github.io/multiscale-pyag.
This work was partially supported by the Alan Turing Institute (EPSRC grant EP/N510129/1). S.A. Tsaftaris acknowledges the support of Canon Medical and the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme (grant RCSRF1819\8\25).