Stop Throwing Away Discriminators! Re-using Adversaries for Test-Time Training

08/26/2021 ∙ by Gabriele Valvano, et al. ∙ 0

Thanks to their ability to learn data distributions without requiring paired data, Generative Adversarial Networks (GANs) have become an integral part of many computer vision methods, including those developed for medical image segmentation. These methods jointly train a segmentor and an adversarial mask discriminator, which provides a data-driven shape prior. At inference, the discriminator is discarded, and only the segmentor is used to predict label maps on test images. But should we discard the discriminator? Here, we argue that the life cycle of adversarial discriminators should not end after training. On the contrary, training stable GANs produces powerful shape priors that we can use to correct segmentor mistakes at inference. To achieve this, we develop stable mask discriminators that do not overfit or catastrophically forget. At test time, we fine-tune the segmentor on each individual test instance until it satisfies the learned shape prior. Our method is simple to implement and increases model performance. Moreover, it opens new directions for re-using mask discriminators at inference. We release the code used for the experiments at https://vios-s.github.io/adversarial-test-time-training.


1 Introduction

Semi- and weakly-supervised learning are emerging paradigms for image segmentation [review2019notsosup, review2020embracing], often involving adversarial training [goodfellow2014generative] when annotations are sparse or missing. Adversarial training involves two simultaneously trained networks: one focusing on an image generation task, and the other learning to tell apart generated images from real ones. In semantic segmentation, it is standard practice to condition the generator, also termed segmentor, on an input image and optimise it to output realistic and accurate segmentation masks. After training, the discriminator is discarded and the segmentor used for inference.

Unfortunately, segmentors may under-perform and make errors whenever the test data fall outside the training data distribution (e.g., because they were acquired with a different scanner or belong to a different population study). Here we propose a simple mechanism to detect and correct such errors in an end-to-end fashion, re-using components already developed during training.

We embrace an emerging paradigm [sun2020test, wang2020fully, karani2021test, he2021autoencoder] where a model is fine-tuned on individual test instances without requiring access to other data or labels. We propose strategies that permit recycling an adversarial mask discriminator during inference, thus introducing a data-driven shape prior to correct predictions. Motivated by recent findings of Asano et al. [asano2019critical], who report that the early layers of a CNN can be trained effectively with just one image, we propose to tune these layers on each test instance to minimise an adversarial loss. Lastly, contrary to standard post-processing operations, our method can potentially learn from a continuous stream of data [sun2020test]. Our contributions are: 1) to the best of our knowledge, this is the first attempt to use adversarial mask discriminators to detect and correct segmentation mistakes during inference; 2) we define specific assumptions (and show how to satisfy them) to make the discriminators useful after training; 3) we report performance increases on several medical datasets.

2 Related Work

Learning from Test Samples. In our work, we use a discriminator to tune a segmentor on the individual test images until it predicts realistic masks. The idea of fine-tuning a model on the test samples has recently been introduced by Sun et al. [sun2020test] with the name of Test-time Training (TTT). TTT optimises a model by jointly minimising a supervised and an auxiliary self-supervised loss on a training set, such as detecting the rotation angle of an input image. At inference, TTT fine-tunes the model to minimise the auxiliary loss on the individual test instances, thus adapting to potential distribution shifts. Although the model was successful for classification, the authors admit that designing a well-suited auxiliary task is non-trivial. For example, predicting a rotation angle may be less effective for medical image segmentation, where images have different acquisition geometries. Moreover, Sun et al. only test their model “simulating” domain shifts with hand-crafted image corruptions (e.g., noise and blurring) without investigating if TTT can improve segmentation performance.

Following this seminal work, Wang et al. [wang2020fully] suggested tuning an adaptor network to minimise the test prediction entropy. Unfortunately, CNNs usually make low-entropy overly-confident predictions [guo2017calibration], and entropy minimisation could be sub-optimal for segmentation. More crucially, Wang et al. rely on having access to the entire test-set to do the fine-tuning.

Karani et al. [karani2021test] recently proposed Test-time Adaptable Neural Networks to extend TTT for image segmentation using a pre-trained mask denoising autoencoder (DAE). At inference, they compute a reconstruction error between the mask generated by a segmentor and its auto-encoded version predicted by the DAE. This error constitutes a test-time loss used to fine-tune a small adaptor CNN in front of the segmentor. Once tuned, the adaptor maps the individual test images onto a normalised space which overcomes domain shift problems for the segmentor. A limitation of this approach is the need to train the mask DAE separately. On the contrary, GANs learn the shape prior and optimise the segmentor in an end-to-end fashion. Moreover, tuning the model with a convolutional encoder (the discriminator) rather than an autoencoder has advantages in terms of occupied memory and is faster at inference. Herein, we show that improving performance using a discriminator is also possible and, at the same time, we open a new research direction toward learning re-usable discriminators.

Shape Priors in Deep Learning for Medical Image Segmentation.

Incorporating prior knowledge about organ shapes is not uncommon in medical imaging [nosrati2016incorporating]. Several methods introduced shape priors to regularise the training of a segmentor using penalties [kervadec2019constrained, clough2019topological], autoencoders [oktay2017anatomically, dalca2018anatomical], atlases [dalca2019unsupervised], and adversarial learning [yi2019generative, valvano2021learning]. Others included shape priors for post-processing, fixing prediction mistakes [painchaud2019cardiac, larrazabal2020post]. GANs have become a popular way of introducing shape priors for image segmentation [yi2019generative], with the advantage of: i) learning the prior directly from data; ii) having a simple model that works well for semi- and weakly-supervised learning; and iii) learning the prior while also training the segmentor, instead of in two separate steps (as happens for autoencoders).

Re-using Adversarial Discriminators.

Re-using pre-trained discriminators was proposed to obtain feature extractors for transfer learning [radford2015unsupervised, donahue2016adversarial, mao2019discriminator], or anomaly detectors [zenati2018adversarially, ngo2019fence]. To the best of our knowledge, their (re-)use to detect segmentor mistakes during inference remains unexplored. We are also not aware of previous use of discriminators for test-time tuning of a segmentor.

Figure 1: We re-use GAN discriminators to correct segmentation predictions at inference. The key to our success is training stable and re-usable discriminators, as we detail in Section 3.1. At inference, we tune a small convolutional block on each test image, independently, until the predicted mask satisfies the adversarial shape prior. We only need a single test sample to do the fine-tuning.

3 Method

As summarised in Fig. 1, we consider two stages: i) standard adversarial training; and ii) at inference, image-specific tuning of a small adaptor CNN in front of the trained segmentor. In the first stage, we optimise a segmentor to minimise a supervised cost on the annotated data and an adversarial cost on a set of unpaired images. Meanwhile, we train the discriminator to distinguish real from predicted masks. At inference, for each test image, we tune the adaptor using the (unsupervised) adversarial loss, and improve performance. For the segmentor and the adaptor we use architectures that proved to be effective in segmentation tasks [ronneberger2015u, karani2021test], while we leave exploring alternative architectures as future work.

Obtaining discriminators re-usable at inference is not trivial and requires specific solutions to overcome crucial challenges. These solutions, with our optimisation strategy and model design, are one major contribution of this work.

In the following, we will use italic lowercase letters to denote scalars s, and bold lowercase for 2D images , where are the height and width of the image, respectively. Lastly, we adopt capital Greek letters for functions .

3.1 Re-usable Discriminators: Challenges and Proposed Solutions

Challenge 1. To obtain a re-usable discriminator, we must prevent it from overfitting and catastrophically forgetting, or its predictions on the masks generated during inference will not be reliable. Generally speaking, this is a challenging task because GANs can easily memorise data if trained for too long [nagarajan2018theoretical]. (Footnote 1: Memorisation can also happen just in the discriminator. In fact, contrary to the segmentor, we do not use any additional supervised cost to regularise the discriminator training. We show how to detect memorisation from the losses in the Supplemental.) Moreover, the discriminator may forget what unrealistic segmentation masks look like after the segmentor training has converged [shrivastava2017learning]. Although the discriminator may work well at training in these cases, it would not generalise to the test data, as we explain below.

If properly trained, a segmentor predicts realistic segmentation masks in the latest stages of training. Thus, in standard GANs, we stop training while optimising the discriminator to tell apart real from more and more real-looking masks. At convergence, this becomes similar to training the discriminator using only real images and labelling them as real half the time, and as fake the other half. At this point, gradients become uninformative, and the discriminator collapses to one of the following cases: i) it always predicts its equilibrium point (which in vanilla GANs is the number 0.5, equidistant from the labels real: 1, fake: 0) but it can still detect unrealistic images; ii) it predicts the equilibrium point independently of the input image, forgetting what fake samples look like [shrivastava2017learning, kim2018memorization]; or iii) it memorises the real masks (which, unlike the generated ones, appear unchanged since the beginning of training) and it always classifies them as real, while classifying any other input as fake. It is crucial to prevent behaviours ii) and iii) to have a re-usable discriminator. For this reason, we use:

  • Fake anchors: we expose the discriminator to unrealistic masks (labelled as fake) until the end of training. In particular, we train the discriminator using real masks, predicted masks, and corrupted masks. We obtain the corrupted masks by randomly swapping square patches within the image (we use patches having size equal to 10% of the image size) and adding binary noise to the real masks, as this proved to be a fast and effective strategy to learn robust shape priors in autoencoders [karani2021test]. While, towards the end of the training, the discriminator may not distinguish the real masks from the real-looking predicted masks, the exposure to corrupted masks will prevent it from forgetting what unrealistic masks look like, providing informative gradients until we stop training.
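The corruption used to build fake anchors can be sketched as follows. This is a minimal illustration in plain Python, not the authors' implementation: the function name `corrupt_mask`, the number of swaps, and the noise probability are hypothetical choices, and "binary noise" is interpreted here as random pixel flips. Only the 10% patch size comes from the text.

```python
import random

def corrupt_mask(mask, patch_frac=0.1, n_swaps=3, noise_prob=0.05, rng=None):
    """Return a corrupted copy of a binary mask given as a list of rows of 0/1.

    patch_frac follows the text (patches are 10% of the image size);
    n_swaps and noise_prob are assumptions made for this sketch.
    """
    rng = rng or random.Random()
    h, w = len(mask), len(mask[0])
    out = [row[:] for row in mask]  # never modify the real mask in place
    ph, pw = max(1, int(h * patch_frac)), max(1, int(w * patch_frac))

    # Swap pairs of square patches, pixel by pixel.
    for _ in range(n_swaps):
        r1, c1 = rng.randrange(h - ph + 1), rng.randrange(w - pw + 1)
        r2, c2 = rng.randrange(h - ph + 1), rng.randrange(w - pw + 1)
        for dr in range(ph):
            for dc in range(pw):
                a, b = out[r1 + dr][c1 + dc], out[r2 + dr][c2 + dc]
                out[r1 + dr][c1 + dc], out[r2 + dr][c2 + dc] = b, a

    # Binary noise: flip each pixel with probability noise_prob.
    for r in range(h):
        for c in range(w):
            if rng.random() < noise_prob:
                out[r][c] = 1 - out[r][c]
    return out
```

During training, masks produced this way would be fed to the discriminator with the fake label, alongside the segmentor predictions.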

Challenge 2. An additional challenge is to train stable discriminators, which do not change much in the latest training epochs. In other words, we want small oscillations in the discriminator loss. This is necessary because we typically stop training using early stopping criteria on the segmentor loss. Therefore, we want to promote the optimisation of Lipschitz smooth discriminators, avoiding sudden large gradient updates (thus leaving the discriminator mostly unchanged between the last few training epochs). To this end, we suggest using:

  • Smoothness Constraints: we increase discriminator smoothness [chu2020smoothness] through Spectral normalisation [miyato2018spectral] and activations.

  • Discriminator data augmentation: consisting of random roto-translations, and Instance Noise [sonderby2016amortised, muller2019does], to map similar inputs to the same prediction label. We translate images up to 10% of image pixels on both vertical and horizontal axes, and we rotate them between

    . We generate noise using a Normal distribution with zero mean and 0.1 standard deviation.
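A rough sketch of the discriminator augmentation described in the list above, limited to translation and instance noise. Rotation is omitted, the helper names are hypothetical, and zero-filling the pixels uncovered by the shift is an assumption; the 10% shift limit and the 0.1 noise standard deviation are taken from the text.

```python
import random

def translate(mask, dy, dx, fill=0.0):
    """Shift a 2D mask by (dy, dx) pixels, filling uncovered pixels with `fill`."""
    h, w = len(mask), len(mask[0])
    out = [[fill] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            rr, cc = r + dy, c + dx
            if 0 <= rr < h and 0 <= cc < w:
                out[rr][cc] = mask[r][c]
    return out

def augment_for_discriminator(mask, max_shift_frac=0.1, noise_std=0.1, rng=None):
    """Random translation (up to 10% of the image size) plus Gaussian instance noise."""
    rng = rng or random.Random()
    h, w = len(mask), len(mask[0])
    max_dy, max_dx = int(h * max_shift_frac), int(w * max_shift_frac)
    dy = rng.randint(-max_dy, max_dy)
    dx = rng.randint(-max_dx, max_dx)
    shifted = translate(mask, dy, dx)
    # Instance noise: zero-mean Gaussian added to every pixel.
    return [[v + rng.gauss(0.0, noise_std) for v in row] for row in shifted]
```

The same augmentation would be applied to every mask the discriminator sees, so that nearby inputs receive the same label.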

3.2 Architectures and Training Objectives for the Segmentor and the Discriminator

We use a UNet [ronneberger2015u] segmentor with batch normalisation [ioffe2015batch]. Given an input image $\mathbf{x}$, the segmentor predicts a multi-channel label map $\tilde{\mathbf{y}}$. For the annotated images, we minimise the supervised weighted cross-entropy loss:

$\mathcal{L}_{\mathrm{SUP}}(\tilde{\mathbf{y}}, \mathbf{y}) = -\sum_{i=1}^{c} w_i \, \mathbf{y}_i \log(\tilde{\mathbf{y}}_i)$ (1)

where $i$ is a class index, $c$ the number of classes, and $w_i$ a class scaling factor used to address the class imbalance problem. The value of $w_i$ considers both the total number of pixels and the number of pixels with label $i$.
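To make the weighted cross-entropy concrete, here is a small plain-Python sketch. The exact form of the class scaling factor is not recoverable from the text, so the inverse-frequency weight used below (one minus the fraction of pixels with that label) is an assumption, as are the function names.

```python
import math

def class_weights(labels, n_classes):
    """Assumed weighting: 1 - (pixels with label i) / (total pixels)."""
    n_tot = len(labels)
    counts = [0] * n_classes
    for label in labels:
        counts[label] += 1
    return [1.0 - counts[i] / n_tot for i in range(n_classes)]

def weighted_cross_entropy(probs, labels, weights, eps=1e-12):
    """Mean weighted cross-entropy over pixels.

    probs:  per-pixel list of class probabilities (each sums to 1)
    labels: per-pixel ground-truth class index
    """
    total = 0.0
    for p, label in zip(probs, labels):
        total -= weights[label] * math.log(p[label] + eps)
    return total / len(labels)
```

With this weighting, rare classes (for example a thin myocardium against a large background) contribute more per pixel than frequent ones.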

As discriminator, we use a convolutional encoder, processing the predicted masks with a series of 5 convolutional layers. Layers use a number of filters following the series: 32, 64, 128, 256, 512. After the first two layers, we downsample the feature maps using a stride of 2. We increase smoothness according to Section 3.1. Finally, a fully-connected layer integrates the extracted features and predicts a scalar linear output, used to compute the adversarial losses [mao2018effectiveness]:

(2)

where and are the labels for fake and real images, respectively. As in standard adversarial semi-supervised training, we alternately minimise eq. 1 on a batch of annotated images and eq. 2 on a batch of unpaired images and unpaired masks. We use Adam optimiser [kingma2014adam], learning rate: , and batch size: 12. Training proceeds until the segmentation loss stops decreasing on a validation set.
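As an illustration of the adversarial objectives, the sketch below uses the least-squares formulation that [mao2018effectiveness] refers to. The label values (-1 for fake, +1 for real), the equal weighting of the three discriminator terms, and the function names are assumptions made for this example; they are not stated in the text above.

```python
def _mse(scores, target):
    """Mean squared distance between discriminator scores and a target label."""
    return sum((s - target) ** 2 for s in scores) / len(scores)

def discriminator_loss(d_real, d_pred, d_corrupted, real_label=1.0, fake_label=-1.0):
    """Least-squares discriminator loss on real, predicted, and corrupted masks.

    Predicted masks and corrupted masks (fake anchors) both use the fake label.
    """
    return (0.5 * _mse(d_real, real_label)
            + 0.5 * _mse(d_pred, fake_label)
            + 0.5 * _mse(d_corrupted, fake_label))

def segmentor_adversarial_loss(d_pred, real_label=1.0):
    """The segmentor is rewarded when its masks are scored as real."""
    return 0.5 * _mse(d_pred, real_label)
```

At test time only the second function is needed: it requires discriminator scores for predicted masks and no ground truth.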

3.3 Adversarial Test-Time Training: Adapting

At inference, we do not fine-tune the whole segmentor but only adapt a few convolutional layers at its input. These layers are, according to [asano2019critical], the most suited for one-shot learning. By keeping the deeper layers of the segmentor unchanged, we also limit the segmentor flexibility and let it adapt only to changes at lower abstraction levels, ultimately preventing trivial solutions. Thus, we include a shallow convolutional residual block (the adaptor) in front of the segmentor, which we tune on the individual test images by minimising the adversarial loss for a fixed number of iterations. The adaptor is the same as in [karani2021test] and has 3 convolutional layers with 16 kernels and an activation with a trainable scaling parameter s, initialised as 0 and optimised at inference. After tuning the adaptor, the input to the segmentor is an augmented version of the test image which can be more easily classified. We show qualitative examples in Fig. 2 and in the Supplemental.

4 Experiments

Figure 2: Effect of Test-time Training (TTT). Re-using a discriminator at inference, we optimise a small input adaptor until the predicted mask becomes realistic. We report additional examples in the Supplemental.
Figure 3: Dice (↑), IoU (↑) and Hausdorff distance (↓) obtained before and after tuning the segmentor on the individual test instances. Arrows show metric improvement directions. Under each violin plot, we also report the average performance (standard deviation as subscript). We always improve the metrics, also in the worst-case scenarios (bottom of the distribution tails for Dice and IoU, upper tails for Hausdorff distance). Asterisks show statistical significance.
Data

We consider two cardiac MRI datasets, described below.
ACDC [bernard2018deep] has multi-scanner images from 100 patients with manual annotations for right and left ventricle, and for left myocardium. We resample data to the average resolution of 1.51, and crop or pad them to a common size in pixels. We standardise data using the patient-specific median and interquartile range.
LVSC [suinesiaputra2014collaborative] contains cardiac MRIs of 100 subjects, obtained with different scanners and imaging parameters. There are manual annotations for the left myocardium. We resample images to the average resolution of 1.45, and then crop or pad them to a common size in pixels. We normalise images as in ACDC.

Setup and Evaluation.

We divide datasets by patients, using groups of 40% for training, 20% for validation, and 40% for the test set, respectively. Out of the 40% training patients, we consider annotations for one quarter of the training subjects in ACDC and LVSC (10 patients). We treat the remaining data as unpaired and use them for adversarial training (eq. 2). Notice that the small training sets cannot fully represent the entire data distribution, leading to segmentation errors at inference. We will investigate dealing with larger distribution shifts (e.g. different scanners, etc.) in the future. We analyse performance increases obtained from adversarial Test-time Training. Inspired by [karani2021test], we also compare to using a DAE to drive the adaptation (DAEs learn the shape prior separately, whereas we learn it while training the segmentor). We do 3-fold cross-validation and measure performance by comparing the predicted segmentation masks with the ground truth labels contained in the test set. We use Dice and IoU scores, and the Hausdorff distance. We assess statistical significance with the non-parametric Wilcoxon test.
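For reference, the two overlap metrics can be computed as in the following sketch, which assumes binary masks flattened to lists of 0/1. The function names and the convention of returning 1.0 when both masks are empty are choices made for this example.

```python
def dice_score(pred, target):
    """Dice = 2 |P ∩ T| / (|P| + |T|) for flat binary masks."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    size = sum(1 for p in pred if p) + sum(1 for t in target if t)
    return 1.0 if size == 0 else 2.0 * inter / size

def iou_score(pred, target):
    """IoU = |P ∩ T| / |P ∪ T| for flat binary masks."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    union = sum(1 for p, t in zip(pred, target) if p or t)
    return 1.0 if union == 0 else inter / union
```

Both scores lie in [0, 1] and increase with overlap, whereas the Hausdorff distance (not sketched here) measures the largest boundary error and decreases as the prediction improves.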

Figure 4: Test-time Training using an adversarial shape prior vs a prior learned by a pre-trained DAE. In the plots, “Baseline” refers to standard inference of a GAN (i.e. without TTT). Bar plots show average and 95% confidence interval. Both methods lead to similar improvements.

4.1 Results and Discussion

We show a qualitative example of test-time adaptation in Fig. 2. In Fig. 3, we represent segmentation performance with violin plots before and after Test-time Training. These plots show the whole distribution of performance values for images in the test set. We observe performance improvements on all metrics and datasets. Importantly, worst-case scenarios (bottom tails of violin plots, for Dice and IoU; top tails for Hausdorff distance) considerably improve, reflecting the desired tendency to correct unrealistic segmentation masks that do not satisfy the learned adversarial shape prior. Qualitatively, we observed that the model removes scattered false positives and closes holes in the segmentation masks (see Fig. 1 in the Supplementary material).

In Fig. 4, we compare the performance of our method vs using a shape prior separately learned by a DAE, inspired by [karani2021test]. Our method achieves similar performance gains to a DAE (no statistically significant differences found), but it has the advantage of not requiring a separate pre-training step.

UNet | GAN | GAN + #1 | GAN + #1 + #2 | GAN + #1 + #2 + #3
70.1 (13) | 70.0 (12) | 70.9 (11) | 71.2 (10) | 72.4 (10)
Table 1: Ablation Study. We compare the performance of a UNet; a standard GAN; and the GAN after adding smoothness constraints (#1), the proposed fake anchors regularisation (#2), and Test-Time Training (#3). Results are average Dice scores on the ACDC test set, with standard deviation in parentheses.

Lastly, we perform an ablation study to analyse the effect of regularising the model with smoothness constraints and fake anchors. As illustrated in Table 1, these techniques improve training and make the adversarial shape prior stronger. As a result: i) the adversarial training leads to a better segmentor; and ii) the re-usable discriminator further increases model performance.

Computational Aspects

The memory required to store the weights of the model is 90 MB. At inference, our method needs a forward and a backward pass for every tuning iteration to correct a segmentation. This is slower than standard inference, where each image requires one forward pass. We find that tuning improves segmentation, but a high number of iterations overfits the segmentor, leading to worse performance. As a compromise, we use a fixed number of iterations (with small temporal overhead: 10s/patient on a TITAN Xp GPU). This fixed iteration strategy is also used by previous work [sun2020test, karani2021test, he2021autoencoder], but choosing an image-specific optimal number of iterations would be useful and could potentially increase performance. We leave automated strategies to set the number of iterations as future work.

Limitations

We find that the discriminator does not penalise wrong predictions that appear realistic but do not correspond to the input image. In fact, the discriminator only evaluates the predicted mask without considering the segmentor input. We highlight that this is also a limitation of [karani2021test] and of all methods that learn the shape prior using only unpaired masks. We expect that also including image-related information would improve Test-time Training.

5 Conclusion

We demonstrated that by satisfying simple assumptions, it is possible to re-use adversarial discriminators during inference. In particular, we re-used a mask discriminator to detect and then correct segmentation mistakes made by a segmentor. The proposed method is simple and can be potentially applied to any GAN, increasing its test-time performance on the most challenging images.

More broadly, the possibility of re-using adversarial discriminators to correct generator errors may open opportunities even outside image segmentation. Given their flexibility and the ability to learn data-driven losses, GANs have been widely adopted in medical imaging, from domain adaptation to image synthesis tasks [yi2019generative]. With improved architectures and regularisation techniques [kurach2019large, chu2020smoothness], we believe adversarial networks will be even more popular in the future. In this context, training stable and re-usable discriminators opens opportunities for an all-round use of the GAN components.

Acknowledgments

This work was partially supported by the Alan Turing Institute (EPSRC grant EP/N510129/1). S.A. Tsaftaris acknowledges the support of Canon Medical and the Royal Academy of Engineering and the Research Chairs and Senior Research Fellowships scheme (grant RCSRF1819\8\25).

References