Visual Feature Attribution using Wasserstein GANs

11/24/2017 ∙ by Christian F. Baumgartner, et al.

Attributing the pixels of an input image to a certain category is an important and well-studied problem in computer vision, with applications ranging from weakly supervised localisation to understanding hidden effects in the data. In recent years, approaches based on interpreting a previously trained neural network classifier have become the de facto state-of-the-art and are commonly used on medical as well as natural image datasets. In this paper, we discuss a limitation of these approaches which may lead to only a subset of the category specific features being detected. To address this problem we develop a novel feature attribution technique based on Wasserstein Generative Adversarial Networks (WGAN), which does not suffer from this limitation. We show that our proposed method performs substantially better than the state-of-the-art for visual attribution on a synthetic dataset and on real 3D neuroimaging data from patients with mild cognitive impairment (MCI) and Alzheimer's disease (AD). For AD patients the method produces compellingly realistic disease effect maps which are very close to the observed effects.







1 Introduction

In this paper we address the problem of visual attribution, which we define as detecting and visualising evidence of a particular category in an image. Pinpointing all evidence of a class is important for a variety of tasks, such as weakly supervised localisation or segmentation of structures [43, 45, 67], and for better understanding disease effects and physiological or pathological processes in medical images [69, 18, 12, 19, 13, 28, 56, 31, 32, 65].

Currently, the most frequently used approach to address the visual attribution problem is training a neural network classifier to predict the categories of a set of images and then following one of two strategies: analysing the gradients of the prediction with respect to an input image [28, 5, 56] or analysing the activations of the feature maps for the image [67, 43, 45] to determine which part of the image was responsible for making the associated prediction.

Figure 1: Our proposed method learns a map generating function from unlabelled training data. Given a test image, this function will generate an image-specific visual attribution map which highlights the features unique to that category. The method is of particular interest for creating medical disease effect maps. We show that on neuroimaging data the method predicts effects in very good agreement with the actual observed effects.

Visual attribution based directly on neural network classifiers may, under some circumstances, produce undesired results. It is known that such classifiers base their decisions on certain salient regions rather than the whole object of interest. It was recently shown that during training neural networks minimise the mutual information between the input and the hidden layers, thereby compressing the input features [52]. These findings suggest that a classifier may ignore features with low discriminative power if stronger features with redundant information about the target are available. In other words, neural network training may be working in opposition to the goal of visual attribution. As a consequence, if there is evidence for a class at multiple locations in the image (such as multiple lesions in medical images), some locations may not influence the classification result and may thus not be detected. We demonstrate this effect on a synthetic dataset in our experiments.

It would be highly desirable if instead we could visualise evidence of a particular category in a way that captures all category-specific effects in an image. Our main contribution is a novel approach towards solving the visual attribution problem which takes a first step in this direction. In contrast to the majority of recent techniques, the method does not rely on a classifier but rather aims at finding a map that, when added to an input image of one category, will make it indistinguishable from images from a baseline category. To this end we propose a generative model in which the additive map is learned as a function of the images. The method is based on Wasserstein generative adversarial networks (WGAN) [2], which have the desirable property that they minimise an approximation of the Wasserstein distance between the distributions of the generated images and the real ones.

We note that our method does not tackle the classification problem but rather assumes that the category labels of the test images have already been determined (e.g. using a separately trained classifier or by an expert). Furthermore, the method requires a baseline category, which is not the case for many benchmark recognition datasets in vision, but is in fact the case for many practical detection applications, especially in medical image analysis.

We demonstrate the method on synthetic 2D data and on large 3D brain MR data, where we aim to predict subject-specific disease effect maps for Alzheimer’s disease (AD).

1.1 Medical motivation

Identifying disease effects at the subject-specific level is of great interest for various medical applications. In clinically oriented research, identifying subject-specific disease effects would be useful for stratification of the patient population and to help disentangle diseases such as AD [26] and schizophrenia [50] that are believed to be composed of multiple sub-types rather than a single disease. Furthermore, for clinicians, subject-specific maps could be helpful in assessing disease status and grading.

In this paper, we chose to study the disease effects of AD with respect to mild cognitive impairment (MCI), which is characterised by a slight decline in cognitive abilities. Patients with MCI are at increased risk of developing AD, but not all of them do. We evaluate our method on one of the largest publicly available neuroimaging datasets, acquired by the Alzheimer’s Disease Neuroimaging Initiative (ADNI). We used the MCI population as the baseline category and the AD population as the category of interest. Our choice to use MCI as the baseline is motivated by the fact that the ADNI dataset contains a number of MCI subjects who converted to AD and have imaging data at both stages of the disease. This allowed us to evaluate the predicted disease effects against real observed effects, defined as the differences between images at the different stages. Note that even though using normal controls as the baseline is feasible, it would have been much harder to assess the proposed method due to the small number of control-to-AD converters in the ADNI dataset.

2 Related work

2.1 Visual attribution

A commonly used approach for weakly supervised localisation or segmentation is to analyse the final feature map of a neural network classifier [43, 45]. The Class Activation Mapping (CAM) method [67] builds on those techniques by reducing the feature maps of the second-to-last layer using a global average pooling layer, followed by a dense prediction layer. This makes it possible to create class-specific activation maps as a linear combination of the weights in the last layer.

A large number of works on medical images build on the CAM technique. Examples include the work of Feng et al. [12] on pulmonary nodule localisation in CT and the work of Ge et al. [18] on skin disease recognition; other examples are [69] and [19]. It is important to note that CAM is restricted in the resolution of its visual attributions by the resolution of the last feature map. Consequently, post-processing of the predictions is often required [12, 13, 45]. In contrast, our proposed method can produce visual attributions at the resolution of the original input images.

Another class of techniques creates saliency maps by backpropagating from the prediction back to the input image. Examples include Guided Backprop [55], Excitation Backprop [64], Integrated Gradients [56], and meaningful perturbations [13].

Similar techniques have been applied in the domain of medical images. Jamaludin et al. [28] use the backprop-based saliency technique proposed by [53] to pinpoint lumbar degradations, and Baumgartner et al. [5, 6] use a variant of [55] to localise fetal anatomy. Gao and Noble [15] apply a similar approach to localise the fetal heart.

2.2 Statistical disease models

Statistical analysis of medical images for identifying disease effects has been an instrumental tool for various diseases and disorders [58, 48, 9] as well as other non-disease-related factors [17, 40, 29, 61, 44, 57]. The most common approach is to use regression analysis or machine learning tools to generate population average maps, which highlight features that are salient across the population [3, 33, 63, 16, 41, 47, 14].

Recently, constructing subject-specific maps has received attention. Maumet et al. took a one-versus-all group analysis approach [39, 38], while Konukoglu and Glocker extracted subject-specific maps with predictive models and Markov Random Field restoration [31, 32].

The common drawback of the previous approaches is the need for registration. In order to compute disease effect maps, images of different subjects need to be non-rigidly aligned on a common template where statistical analysis can be performed. The non-rigid registration process brings additional uncertainty to the subject-specific maps. Our work addresses this shortcoming and generates subject-specific disease effect maps without requiring registration.

2.3 Image generation using GANs

Generative adversarial networks conditioned on an input image have been used in diverse applications such as video frame prediction [37], image super-resolution [34], image translation across domains using paired [27] and unpaired [68] images, and pixel-level domain adaptation [7, 51].

In the context of medical images, GANs have been applied to super-resolution in retinal fundus images [36], semi-supervised cardiac segmentation [66], synthesising computed tomography images from MR images [42, 62] and intraoperative motion modelling [24]. Although some of the above models use 3D data, the examined volumes are usually relatively small [24], or the networks operate in a patch-wise fashion [42]. It is important to note that in the case of brain MR images of Alzheimer’s disease patients, the diagnostic information is only visible at a high resolution and cannot be determined by considering small local patches only. In this paper, we therefore tackle the challenge of processing large 3D volumes directly.

2.4 Contributions

  1. We demonstrate a limitation in current neural network based visual attribution methods using synthetic data.

  2. We propose a novel visual attribution technique that can detect class-specific regions more completely and at a high resolution.

  3. To our knowledge, this is the first application of generative adversarial networks to large structural 3D data.

An implementation of the proposed method is publicly available here:

3 Visual attribution using WGANs

3.1 Problem Formulation

Figure 2: Overview of VA-GAN. During training, images are sampled from the categories c = 0 and c = 1. Images from c = 1 are passed to the map generating function M(x). The map generator aims to create additive maps which produce generated images y = x + M(x) that the critic D cannot distinguish from images sampled from c = 0. The critic tries to assign different values to generated and real images. During testing, M(x) can be used directly to predict a map in a single forward pass.

Our goal is to estimate a map M(x) that highlights the areas in an image which are specific to the class the image belongs to. We formulate the problem for two classes c ∈ {0, 1}: a baseline class (c = 0) and a class of interest (c = 1). The formulation, however, easily extends to the case of multiple classes of interest. We denote an image with x, the distribution of images coming from class c = 0 with p(x|c = 0), and that of images from class c = 1 with p(x|c = 1). In the case of a medical application, p(x|c = 1) could for example denote the distribution of images from a population with a certain disease and p(x|c = 0) that of images of control subjects.

We formulate the problem as estimating a map function M(x) that, when added to an image x from category c = 1, creates an image

y = x + M(x),   (1)

which is indistinguishable from the images sampled from p(x|c = 0). Thereby, the map M(x) contains all the features which distinguish the input image from the other category. In the case of medical images, M(x) will by definition contain the effects of a disease visible in the images, i.e. a disease effect map.

We model the function M(x) using a convolutional neural network, whose parameters we find using a WGAN.

3.2 Wasserstein GANs

In the GAN paradigm a generator function and a discriminator function (both neural networks) compete with each other in a zero-sum game [20]. Given random noise as input, the generator tries to produce realistic images that fool the discriminator, while the discriminator tries to learn the difference between generated and real images.

Arjovsky and Bottou pointed out a limitation in this paradigm which precludes a guarantee that the generated images will necessarily converge to the target distribution [1] (although in practice, with appropriate training methods, many impressive results have been achieved [46]). Wasserstein GANs are a modification of the classic GAN paradigm in which the discriminator is replaced by a critic, which does not have an activation function in its final layer and which is constrained to be a 1-Lipschitz function. WGANs have better optimisation properties, and it can be shown that they minimise a meaningful distance between the generated and real distributions.

3.3 Constrained effect maps using WGANs

In this work we build on WGANs to find the optimal map generation function. In contrast to regular WGANs, we have a map generator function M(x) which, during training, takes as input randomly sampled images x from category c = 1 rather than noise. M tries to generate maps that, when added to x, create images y = x + M(x) appearing to be from category c = 0. By trying to distinguish the generated images from real images of category c = 0, the critic D ensures that the generated maps are constrained to realistic modifications (see Fig. 2 for an overview). In the context of medical images, this means enforcing anatomically realistic modifications to the images.

Building on [2] this leads to the following cost function:

L(D, M) = E_{x ~ p(x|c=0)}[D(x)] − E_{x ~ p(x|c=1)}[D(x + M(x))].   (2)
Optimising Eq. 2 directly could lead to changes in the input image that change the image identity. For instance, the brain anatomy of a subject could be changed to a degree where the map does not only capture disease-related changes but alters the subject identity. We want to encourage the smallest map that still leads to a realistic y. We thus add the following map regularisation term to the cost function:

L_reg(M) = ||M(x)||_1,   (3)

where ||·||_1 is the L1 norm [13].

The final optimisation is then given by

M* = argmin_M max_{D ∈ 𝒟} E_{x ~ p(x|c=0)}[D(x)] − E_{x ~ p(x|c=1)}[D(x + M(x))] + λ ||M(x)||_1,   (4)

where 𝒟 is the set of 1-Lipschitz functions.
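To make the structure of the combined objective concrete, here is a minimal numpy sketch that evaluates it for a toy batch. The critic is a stand-in linear scoring function, and the maps and the penalty weight lam are illustrative values, not trained quantities from the paper.

```python
import numpy as np

# Toy evaluation of the VA-GAN objective: critic score difference between
# real baseline images (c=0) and map-augmented images from the class of
# interest (c=1), plus an L1 penalty on the maps.
rng = np.random.default_rng(0)

def critic(batch):
    # Stand-in for the trained critic: one scalar score per image.
    return batch.mean(axis=(1, 2))

x_c0 = rng.normal(size=(4, 8, 8))           # images from the baseline class
x_c1 = rng.normal(loc=0.5, size=(4, 8, 8))  # images from the class of interest
maps = -0.5 * np.ones_like(x_c1)            # stand-in for the generated M(x)

lam = 1.0                                   # illustrative penalty weight
wasserstein_term = critic(x_c0).mean() - critic(x_c1 + maps).mean()
loss = wasserstein_term + lam * np.abs(maps).mean()
```

The critic is trained to maximise the score difference, while the map generator minimises it together with the L1 penalty, which favours small maps.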

In order to enforce the Lipschitz constraint we use the gradient penalty method proposed in [22]. As recommended by [22], we weight the gradient penalty with a factor of 10 throughout all experiments.
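The gradient penalty of [22] can be sketched as follows. The critic here is a toy linear function so its input gradient is available in closed form; a real implementation would obtain the gradient by automatic differentiation, and all tensors are illustrative.

```python
import numpy as np

# Sketch of the WGAN-GP gradient penalty: penalise the critic's gradient
# norm at a point interpolated between a real and a generated image. The
# toy critic is linear, so its gradient w.r.t. the input is its weight
# vector w.
rng = np.random.default_rng(1)
w = rng.normal(size=64)

def critic(x_flat):
    return float(x_flat @ w)

x_real = rng.normal(size=64)
x_fake = rng.normal(size=64)
eps = rng.uniform()
x_hat = eps * x_real + (1.0 - eps) * x_fake  # random interpolation point

grad_norm = np.linalg.norm(w)                # analytic gradient of the toy critic
penalty = 10.0 * (grad_norm - 1.0) ** 2      # factor 10 as in the paper
```

The penalty is added to the critic loss, driving the gradient norm at interpolated points towards 1 and thus softly enforcing the 1-Lipschitz constraint.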

3.4 Network architecture

As we will discuss in more detail in Section 4.3, we design our proposed method with large 3D medical imaging data in mind, which often need to be processed at high resolutions in order to retain diagnostic information. Specifically, in our experiments on neuroimaging data, an input volume size of 128x160x112 voxels is used.

With such large images the limiting factor becomes storing the activations of the networks on GPU memory. With this in mind we design the map generator and the critic networks as follows.

3.4.1 Map generator network

The map generator function should be able to form an internal representation of the visual attributes that characterise the categories. In the case of brain images affected by dementia, it should be able to “understand” the systematic changes involved in the disease. Therefore, a relatively powerful network is required to adequately model the function . To this end, we use the 3D U-Net [10] (originally proposed for segmentation), as a starting point. The 3D U-Net has an encoder-decoder structure with a bottle-neck layer in the middle, but additionally introduces skip connections at each resolution level bypassing the bottle-neck. This allows the network to combine high-level semantic information (such as the presence of a structure) with low-level information (such as edges).

In order to reduce GPU memory consumption we reduce the number of feature maps by a factor of 4 in most layers. As in the original 3D U-Net [10] we use batch normalisation for all layers except the final one. The exact architecture is shown in Fig. 2 in the supplementary material.
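As a shape-level illustration of an encoder-decoder with skip connections (a 2D toy, not the actual 3D U-Net used here), the following numpy sketch shows how feature maps are pooled down to a bottleneck, upsampled, and recombined with the encoder outputs of matching resolution. Convolutions and channel dimensions are elided.

```python
import numpy as np

# Shape-level sketch: features are pooled to a bottleneck, upsampled back,
# and combined with the encoder output at each resolution level, so that
# high-level semantic and low-level spatial information are merged.
def pool(x):
    # 2x average pooling along both spatial axes.
    return x.reshape(x.shape[0] // 2, 2, x.shape[1] // 2, 2).mean(axis=(1, 3))

def upsample(x):
    # Nearest-neighbour 2x upsampling.
    return x.repeat(2, axis=0).repeat(2, axis=1)

x = np.random.default_rng(4).normal(size=(16, 16))
e1 = x                                          # encoder, full resolution
e2 = pool(e1)                                   # encoder, half resolution
b = pool(e2)                                    # bottleneck
d2 = np.stack([upsample(b), e2]).mean(axis=0)   # decoder level 2 with skip
d1 = np.stack([upsample(d2), e1]).mean(axis=0)  # decoder level 1 with skip
out = d1                                        # stand-in for the additive map
```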

3.4.2 Critic function

In line with related literature on image generation using GANs [27, 68, 51], we model our critic as a fully convolutional network with no dense layers. We loosely base our architecture on the C3D network, which achieved impressive results on action recognition tasks in video data by processing them directly in the spatio-temporal 3D space [59]. However, in contrast to that work we only perform 4 pooling steps. After the fourth pooling layer we add another 3x3x3 convolution layer, followed by a 1x1x1 convolution layer which reduces the number of feature maps to one. The final critic prediction is given by a global average pooling operation over that feature map.

It proved important not to use batch normalisation in the critic network. Early in training, the statistics of a batch containing both generated and real images may not yield reasonable estimates, because the two types of images vary considerably from each other. We surmise that this effect prevents the critic from learning when batch normalisation is used. A similar observation was made in [22]. We also experimented with layer normalisation [4], but did not observe improvements.

The exact architecture we used is shown in Fig. 1 in the supplementary material.

3.5 Training

To optimise our networks, we follow [2, 22] and update the parameters of the critic and map generator networks in an alternating fashion. In contrast to regular GANs [20], WGANs require a critic which is kept close to optimality throughout training. We therefore perform 5 critic updates for every map generator update. Additionally, for the first 25 iterations and every hundredth iteration, we perform 100 critic updates per generator update.
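The update schedule described above can be captured in a small helper. The function below is a hypothetical reading of that schedule, not code from the released implementation.

```python
# 5 critic updates per generator update, boosted to 100 for the first 25
# generator iterations and on every hundredth iteration thereafter.
def critic_updates_for(gen_iter, n_critic=5, boost=100):
    if gen_iter < 25 or gen_iter % 100 == 0:
        return boost
    return n_critic
```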

With the above architectures, the maximum batch size that can be used for a single gradient computation on a Nvidia Titan Xp GPU with 12 GB of memory is 2+2 (real+generated). In order to obtain more reliable gradient estimates we aggregate the gradients for a total of 6 mini-batches before performing a training step.
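Gradient aggregation over several mini-batches can be sketched as follows; the parameters, gradients, and learning rate are stand-ins.

```python
import numpy as np

# Average the gradients of several small mini-batches before applying a
# single update step, emulating a larger effective batch when GPU memory
# only fits 2+2 images at a time.
rng = np.random.default_rng(2)
params = np.zeros(3)
lr = 1e-3
n_accum = 6                         # mini-batches per training step

accum = np.zeros_like(params)
for _ in range(n_accum):
    grad = rng.normal(size=3)       # stand-in for one mini-batch gradient
    accum += grad
params -= lr * accum / n_accum      # one step with the averaged gradient
```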

We used the ADAM optimiser [30] to perform the update steps for all experiments. The optimiser parameters β1 and β2 were set as recommended in [22], and we used a fixed learning rate and a fixed weight λ for the map regularisation term (see Eq. 4) throughout the paper. Training took approximately 24 hours on an Nvidia Titan Xp.

4 Experiments

We evaluated the proposed method using a synthetically generated dataset and a large number of 3D brain MRI images from the publicly available ADNI dataset.

We compared our proposed visual attribution GAN (VA-GAN) to methods from the literature which have been used for visual attribution both on natural and on medical images. Specifically, we compared against Guided Backpropagation [55], Integrated Gradients [56] and Class Activation Mapping (CAM) [67]. Furthermore, to verify that the WGAN framework is necessary, we also investigated an alternative way of estimating the additive map not based on GANs, which is described in detail in the next section.

All the methods except VA-GAN use classification networks. For simplicity, we used a very similar architecture for these networks as for the critic in VA-GAN, with two differences: (1) we replaced the last convolution and the global average pooling layer by two dense layers followed by a softmax, and (2) we used batch normalisation for all layers, which produced better classification results in the experiments on the ADNI dataset. In addition, for the CAM method we designed the last layer as described by [67] and omitted the last two max pooling layers, which allowed significantly more accurate visual attribution maps due to the higher resolution of the last feature maps.

Lastly, for the experiments on the 2D synthetic data we simply replaced all 3D operations by 2D operations, but left the architectures otherwise unchanged.

4.1 Classifier-based map estimation

In the VA-GAN approach, we generate an additive map which is constrained by the critic to generate a realistic image from the opposite class. To demonstrate that this approach is necessary we also investigated an alternative method of estimating the additive map without a term enforcing realistic maps.

The alternative approach requires training a classifier f and then optimising an additive map M that lowers the prediction for the class of interest as much as possible. That is to say, the image y = x + M should minimise the classifier output f_{c=1}(y). This formulation is almost exactly the same as for the WGAN-based approach (see Eq. 1) except that M is not a function of x.

We need to use regularisation when determining M to avoid trivial solutions, such as imperceptible changes that can fool classifiers [21]. A “well behaved” map can be found by the following minimisation problem:

M*(x) = argmin_M f_{c=1}(x + M) + λ1 Σ_i |M_i| + λ2 Σ_i |∇M_i|.   (5)

Here, i indexes the pixels or voxels of M. The L1 term weighted by λ1 encourages small maps, while the total variation term weighted by λ2 encourages smoothness.

We optimise this cost function using the ADAM optimiser with the default internal parameters given in [30] and early stopping at 1500 iterations. The learning rate and the weights λ1 and λ2 were kept fixed across all experiments.
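A minimal numpy sketch of the objective in Eq. 5, with a toy stand-in for the classifier output f_{c=1} and illustrative weights lam1 and lam2 (not the values used in the paper):

```python
import numpy as np

# Classifier-based objective: the classifier output on the perturbed image
# plus L1 and total-variation penalties on the map M. The TV term is
# implemented with finite differences along both axes.
def objective(x, M, f_c1, lam1=1.0, lam2=1.0):
    l1 = np.abs(M).sum()
    tv = np.abs(np.diff(M, axis=0)).sum() + np.abs(np.diff(M, axis=1)).sum()
    return f_c1(x + M) + lam1 * l1 + lam2 * tv

x = np.ones((4, 4))
M = np.zeros((4, 4))                  # the map being optimised
val = objective(x, M, lambda y: float(y.mean()))
```

In the actual method this objective would be minimised over M with a gradient-based optimiser while x and the classifier stay fixed.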

This approach is strongly related to the meaningful perturbation masks technique proposed by [13] in which parts of an image are locally deleted by a mask such that the prediction is minimised. In preliminary experiments we found that on the medical image problem we studied, visual attribution using destructive masks did not lead to the desired results. Deleting the diagnostic part of an image will not produce an image of the opposite class but rather an image with an undetermined diagnosis. This means such a mask may contain information about the location of diagnostic regions but not about specific disease effects, e.g. enlargement or shrinkage. In contrast, by optimising Eq. 5 we attempt to morph the image into the opposite class, such that diagnostic regions can be changed to have the characteristics of another class. Because of the similarity to [13], we refer to this method as additive perturbation maps.

4.2 Synthetic experiments

Data: In order to quantitatively evaluate the performance of the examined visual attribution methods, we generated a synthetic dataset of 10000 112x112 images with two classes, which model a healthy control group (label 0) and a patient group (label 1). The images were split evenly across the two categories. We closely followed the synthetic data generation process described in [32] where disease effects were studied in smaller cohorts of registered images.

Figure 3: Description of synthetic data. We generated noisy observations from ground-truth effect maps. The dataset contained two categories: a baseline category 0 (e.g. healthy images) and a category 1 with an effect (e.g. patient images). The images in category 1 contained one of two subtypes, A or B, which is unknown to the algorithms. A: box in the lower right, B: box in the upper left.

The control group (label 0) contained images with random iid Gaussian noise convolved with a Gaussian blurring filter. Examples are shown in Fig. 3. The patient images (label 1) also contained the noise, but additionally exhibited one of two disease effects which was generated from a ground-truth effect map: a square in the centre and a square in the lower right (subtype A), or a square in the centre and a square in the upper left (subtype B). Importantly, both disease subtypes shared the same label. The location of the off-centre squares was randomly offset in each direction by a maximum of 5 pixels. This effect was added to make the problem harder, but had no notable effect on the outcome.
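The generation process described above can be sketched in numpy as follows. The square positions, sizes, effect magnitude, and the box blur standing in for the Gaussian filter are illustrative assumptions, not the exact parameters of the dataset.

```python
import numpy as np

# Smoothed noise for the control class, plus square "disease effects" for
# the patient class: a central square shared by both subtypes and an
# off-centre square whose position is randomly jittered.
rng = np.random.default_rng(3)

def smooth_noise(size=112, k=5):
    img = rng.normal(size=(size, size))
    kernel = np.ones(k) / k
    img = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="same"), 0, img)
    img = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="same"), 1, img)
    return img

def patient_image(subtype, effect=2.0, max_offset=5):
    img = smooth_noise()
    img[46:66, 46:66] += effect       # central square, present in both subtypes
    dy, dx = rng.integers(-max_offset, max_offset + 1, size=2)
    if subtype == "A":                # subtype A: square in the lower right
        img[80 + dy:100 + dy, 80 + dx:100 + dx] += effect
    else:                             # subtype B: square in the upper left
        img[12 + dy:32 + dy, 12 + dx:32 + dx] += effect
    return img
```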

Evaluation: We split the data into an 80-20 training and testing split. Moreover, we used 20% of the training set for monitoring the training. Next, we estimated the disease effect maps for all cases from the synthetic patient class using the examined methods.

In order to assess the visual attribution accuracy quantitatively, we calculated the normalised cross correlation (NCC) between the ground-truth label maps and the predicted disease effect maps. The NCC has the advantage that it is not sensitive to the magnitude of the signals. For CAM we used only the positive values to calculate the NCC, while for the backprop-based techniques we used the absolute value, since those techniques do not necessarily predict the correct sign of the changes.
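The NCC between a ground-truth map and a predicted map can be computed as follows; after mean-centring, the score is the cosine similarity of the two maps, which makes it insensitive to the magnitude of the signals.

```python
import numpy as np

# Normalised cross correlation: mean-centre both maps and compute the
# cosine of the angle between them, so uniform rescaling of either map
# leaves the score unchanged.
def ncc(a, b):
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b)))

gt = np.zeros((8, 8))
gt[2:4, 2:4] = 1.0                  # toy ground-truth effect map
score = ncc(gt, 2.0 * gt)           # scaling the prediction leaves NCC at 1
```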

Results: A number of examples of the estimated disease effect maps are shown in Fig. 4. Guided Backpropagation produced similar results to Integrated Gradients. We therefore omitted it in the visual results due to space considerations but provide quantitative results.

Figure 4: Examples of visual attribution on synthetic data obtained using the compared methods.

For the backprop-based methods we consistently observed two behaviours: 1) They tended to focus exclusively on the central square, which was always present and was thus the most predictive set of features. This behaviour is consistent with the compression of less predictive features discussed earlier [52]. 2) They tended to focus mostly on the edges of the boxes rather than on the whole object. This may have to do with the fact that edges are more salient than other points and, again, are sufficient to predict the presence or absence of the box.

The CAM method managed to capture both squares most of the time, but by design had limited spatial resolution. Note that due to the lower number of max-pooling layers used for the CAM classifier, each pixel in the last feature map had a receptive field of only 39x39 pixels. This could mean that many pixels in that feature map could not simultaneously see both of the squares, which may have contributed to the squares being better discerned. However, we did not investigate this further.

Lastly, our proposed VA-GAN method produced the most localised disease effect maps, finding the entire boxes and following the edges closely. It also managed to consistently identify both disease effects.

Method mean std.
Guided Backprop [55] 0.14 0.04
Integrated Gradients [56] 0.36 0.11
CAM [67] 0.48 0.04
Additive Perturbation 0.06 0.03
VA-GAN 0.94 0.07
Table 1: NCC scores for experiments on synthetic data.

The quantitative NCC results shown in Table 1 are mostly consistent with our qualitative observations, with VA-GAN obtaining significantly higher NCC than the other methods. The additive perturbation technique achieved a low score due to its exclusive focus on edges.

4.3 Experiments on real neuroimaging data

In this section, we investigate the methods’ ability to detect the areas of the brain which are involved in the progression from MCI to AD at a subject-specific level. We trained on images from both categories and then generated disease effect maps only for the AD images.

Data: We selected 5778 3D T1-weighted MR images from 1288 subjects with either an MCI (label 0) or AD (label 1) diagnosis from the ADNI cohort. 2839 of the images were acquired using a 1.5T magnet, the remainder using a 3T magnet. The subjects were scanned at regular intervals as part of the ADNI study, and a number of subjects converted from MCI to AD over the years. We did not use these correspondences for training; however, we took advantage of them for evaluation, as will be described later. An overview of the data is given in the supplemental materials in Section C.

All images were processed using standard operations available in the FSL toolbox [54] in order to reorient and rigidly register the images to MNI space, crop them and correct for field inhomogeneities. We then skull-stripped the images using the ROBEX algorithm [25]. Lastly, we resampled all images to a resolution of 1.3 mm and normalised them to a range from -1 to 1. The final volumes had a size of 128x160x112 voxels.

Evaluation: We split the data on a subject level into a training, testing and validation set containing 825, 256 and 207 subjects, respectively. We then trained all of the algorithms with both AD and MCI data as described earlier, and generated disease effect maps for the AD subjects from the test set. The validation set was used to monitor the training.

In order to better understand the quality of the generated disease maps we estimated the actual deformations for a number of subjects as follows. We identified all subjects from the test set who were diagnosed with MCI during the baseline examination but progressed to AD in one of the follow-up scans. We then aligned those images rigidly and subtracted them from each other to obtain an observed disease effect map. We excluded all subjects whose scans were not acquired with the same field strength, since a large amount of the observed effects could be due to differences in image quality. This left 50 subjects which we evaluated more closely. We note that even for the same field strength there are a number of artefacts due to intensity variations and registration. Furthermore, there are likely to be effects not caused by the disease, such as ageing (which will also be captured by our method), such that the observed disease effect maps can only be regarded as an approximate ground-truth.

Nevertheless, we evaluated the NCC between the observed and the predicted disease effect maps in the same manner as for the synthetic data.

Results: Fig. 5 shows disease effect maps obtained for a selection of AD subjects (we again omitted Guided Backprop in the figure). The subjects are ordered by increasing progression of the disease as measured by the ADAS13 cognition exam [49]. It can be seen that VA-GAN’s predictions were in very good agreement with the observed effect maps. As is known from the literature [8, 11] the method indicates atrophy in the hippocampi, and general brain atrophy around the ventricles. Furthermore, it is known that in later stages of the disease other brain areas such as the temporal lobe get affected as well [60]. Those effects were also identified by VA-GAN in the last subject in Fig. 5.

Figure 5: Coronal and sagittal views of generated AD effect maps for three subjects and actual observed effects. Maps are shown as coloured overlay over the input image. The ventricular (arrow A) and hippocampal (arrow B) regions are particularly affected by the disease and are reliably captured by VA-GAN. In later stages also other brain regions such as the temporal lobe (arrow C) are affected. We also report the ADAS13 cognitive exam scores (larger means AD is further progressed) and the ADNI identifier (rid) for each subject.
Figure 6: Close-up of the hippocampus region of a subject before (left) and after developing AD (middle). The right panel shows the generated image. The red (hippocampus) and green (ventricles) contours are in the same location in all three images. It can be observed that the map “reverses” some of the atrophy.

The backprop-based methods and additive perturbations were observed to be very noisy and tended to identify only the hippocampal areas. We believe this is in agreement with the findings on the synthetic data. The hippocampus is known to be the most predictive region for AD; however, it is also known that many other regions are involved in the disease. It is likely that the classifiers learned to focus only on the most discriminative set of features, ignoring the rest. Lastly, it is hard to interpret the results produced by CAM due to the low resolution. However, the images suggest that this method focuses on similar areas as the other methods.

Quantitative results are given in Table 2. VA-GAN obtained the highest correlation scores. However, it is hard to draw strong conclusions from these figures due to the noisy nature of the observed effect maps, which are taken as “ground truth” in the experiments, and the possible non-disease-related effects they contain.

Method                     mean  std.
Guided Backprop [55]       0.05  0.03
CAM [67]                   0.09  0.07
Integrated Gradients [56]  0.13  0.05
Additive Perturbation      0.11  0.05
VA-GAN                     0.27  0.15
Table 2: NCC scores for experiments on neuroimaging data.
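The normalised cross-correlation (NCC) metric used in Table 2 can be sketched as follows. This is a minimal, pure-Python illustration; the exact normalisation used in the experiments is an assumption, not taken from the paper:

```python
import math

def ncc(map_a, map_b):
    """Normalised cross-correlation between two (nested-list) effect maps.

    Both maps are flattened, mean-centred and scaled to unit norm, so the
    score lies in [-1, 1], with 1 indicating perfect agreement.
    """
    def flatten(x):
        out = []
        for v in x:
            if isinstance(v, (list, tuple)):
                out.extend(flatten(v))
            else:
                out.append(float(v))
        return out

    a, b = flatten(map_a), flatten(map_b)
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    a = [v - mean_a for v in a]
    b = [v - mean_b for v in b]
    denom = math.sqrt(sum(v * v for v in a)) * math.sqrt(sum(v * v for v in b))
    return sum(x * y for x, y in zip(a, b)) / denom if denom else 0.0
```

Identical maps score 1.0, anti-correlated maps score -1.0, and a zero-variance map is mapped to 0.0 to avoid division by zero.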

We observed that VA-GAN generally produced very realistic deformations. In Fig. 6 a close-up of the MCI, AD, and generated image is shown for a sample subject. It can be seen that our method succeeded in making the generated image more similar to the corresponding MCI image and that the changes were realistic.

5 Limitations and discussion

We have proposed a method for visual feature attribution using Wasserstein GANs. It was shown that, in contrast to backprop-based methods, our technique can capture multiple regions affected by disease, and produces state-of-the-art results for the prediction of disease effect maps in neuroimaging data and on a synthetic dataset.

Currently, the method assumes that the category labels of the test data are known at test time. In case they are unknown, the method could easily be combined with a classifier that produces this information. We only evaluated the method for the case of two labels. More categories could be addressed by training multiple map generators, each mapping to a background class (assuming one exists).

In the future, we plan to model other effects such as ageing or the presence or absence of certain genes on the ADNI data, investigate the method on other datasets and apply it to other problems such as weakly-supervised localisation.


Acknowledgements

We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan Xp GPU.


Appendix A Network architectures

In this section we describe the exact network architectures used for the 3D VA-GAN. We present the critic and map generator functions as Python-inspired pseudo-code, which we found easier to interpret than a graphical representation. The layer parameters are specified as arguments to the layer functions. Unless otherwise specified, all convolutional layers use a stride of 1x1x1 and a rectified linear unit (ReLU) non-linearity.

The architecture of the critic function is shown in Fig. 7. The conv3D_layer function performs a regular 3D convolution without batch normalisation and the global_averagepool3D function performs an averaging over the spatial dimensions of the feature maps.

The architecture for the map generator function is shown in Fig. 8. Here, the conv3D_layer_bn is a 3D convolutional layer with batch normalisation before the nonlinearity. The deconv3D_layer_bn learns an upsampling operation as in the original U-Net and also uses batch normalisation. Lastly, the crop_and_concat_layer implements the skip connections across the bottleneck by stacking the feature maps along the dimension of the channels.
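The centre-cropping logic behind the crop_and_concat_layer can be illustrated with a short NumPy sketch. This is an illustrative stand-in, not the authors' implementation; a channel-last layout is assumed:

```python
import numpy as np

def crop_and_concat(upsampled, skip):
    """Centre-crop the skip-connection feature map to the spatial size of
    the upsampled map, then stack the two along the channel axis (last).

    Both arrays have shape (x, y, z, channels); the spatial extents of the
    skip map may be larger, and the channel counts may differ.
    """
    crops = []
    for up_dim, skip_dim in zip(upsampled.shape[:-1], skip.shape[:-1]):
        start = (skip_dim - up_dim) // 2
        crops.append(slice(start, start + up_dim))
    crops.append(slice(None))  # keep all channels of the skip map
    return np.concatenate([upsampled, skip[tuple(crops)]], axis=-1)
```

For example, concatenating a (4, 4, 4, 2) upsampled map with a (6, 6, 6, 3) skip map yields a (4, 4, 4, 5) output.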

Note that the architectures for the 2D experiments on synthetic data were identical, except all 3D operations were replaced by their 2D equivalents.

def critic(x):
    # inputs
    #    x:      an image from category c=0, or an image from category c=1
    #            plus the additive mask M(x)
    # returns
    #    logits: the critic output for x
    conv1_1 = conv3D_layer(x, num_filters=16, kernel_size=(3,3,3))
    pool1 = maxpool3D_layer(conv1_1)
    conv2_1 = conv3D_layer(pool1, num_filters=32, kernel_size=(3,3,3))
    pool2 = maxpool3D_layer(conv2_1)
    conv3_1 = conv3D_layer(pool2, num_filters=64, kernel_size=(3,3,3))
    conv3_2 = conv3D_layer(conv3_1, num_filters=64, kernel_size=(3,3,3))
    pool3 = maxpool3D_layer(conv3_2)
    conv4_1 = conv3D_layer(pool3, num_filters=128, kernel_size=(3,3,3))
    conv4_2 = conv3D_layer(conv4_1, num_filters=128, kernel_size=(3,3,3))
    pool4 = maxpool3D_layer(conv4_2)
    conv5_1 = conv3D_layer(pool4, num_filters=256, kernel_size=(3,3,3))
    conv5_2 = conv3D_layer(conv5_1, num_filters=256, kernel_size=(3,3,3))
    conv5_3 = conv3D_layer(conv5_2, num_filters=256, kernel_size=(3,3,3))
    # The final layer is truncated in the source; a single-channel
    # convolution with a linear (identity) activation is assumed, so that
    # the global average pool below yields one unbounded scalar score.
    conv5_4 = conv3D_layer(conv5_3, num_filters=1,
                           kernel_size=(3,3,3), activation=identity)
    logits = global_averagepool3D(conv5_4)
    return logits
Figure 7: VA-GAN critic architecture.
def map_generator(x):
    # inputs
    #    x: an image from category c=1
    # returns
    #    M: additive map M(x) such that y = x + M(x) appears to be from c=0
    # Encoder:
    conv1_1 = conv3D_layer_bn(x, num_filters=16, kernel_size=(3,3,3))
    conv1_2 = conv3D_layer_bn(conv1_1, num_filters=16, kernel_size=(3,3,3))
    pool1 = maxpool3D_layer(conv1_2)
    conv2_1 = conv3D_layer_bn(pool1, num_filters=32, kernel_size=(3,3,3))
    conv2_2 = conv3D_layer_bn(conv2_1, num_filters=32, kernel_size=(3,3,3))
    pool2 = maxpool3D_layer(conv2_2)
    conv3_1 = conv3D_layer_bn(pool2, num_filters=64, kernel_size=(3,3,3))
    conv3_2 = conv3D_layer_bn(conv3_1, num_filters=64, kernel_size=(3,3,3))
    pool3 = maxpool3D_layer(conv3_2)
    # Bottleneck:
    conv4_1 = conv3D_layer_bn(pool3, num_filters=128, kernel_size=(3,3,3))
    conv4_2 = conv3D_layer_bn(conv4_1, num_filters=128, kernel_size=(3,3,3))
    # Decoder:
    upconv3 = deconv3D_layer_bn(conv4_2, kernel_size=(4,4,4), strides=(2,2,2), num_filters=64)
    concat3 = crop_and_concat_layer([upconv3, conv3_2])
    conv5_1 = conv3D_layer_bn(concat3, num_filters=64, kernel_size=(3,3,3))
    conv5_2 = conv3D_layer_bn(conv5_1, num_filters=64, kernel_size=(3,3,3))
    upconv2 = deconv3D_layer_bn(conv5_2, kernel_size=(4,4,4), strides=(2,2,2), num_filters=32)
    concat2 = crop_and_concat_layer([upconv2, conv2_2])
    conv6_1 = conv3D_layer_bn(concat2, num_filters=32, kernel_size=(3,3,3))
    conv6_2 = conv3D_layer_bn(conv6_1, num_filters=32, kernel_size=(3,3,3))
    upconv1 = deconv3D_layer_bn(conv6_2, kernel_size=(4,4,4), strides=(2,2,2), num_filters=16)
    concat1 = crop_and_concat_layer([upconv1, conv1_2])
    conv8_1 = conv3D_layer_bn(concat1, num_filters=16, kernel_size=(3,3,3))
    # The final layer is truncated in the source; a single-channel
    # convolution with a linear (identity) activation is assumed, so that
    # the map M can take both positive and negative values.
    M = conv3D_layer(conv8_1, num_filters=1,
                     kernel_size=(3,3,3), activation=identity)
    return M
Figure 8: VA-GAN map generator architecture.

Appendix B Close-up analysis of VA-GAN

In Fig. 9 we present a larger view of all three orthogonal planes for an additional subject. In order to allow for an enlarged view, we only include the results obtained by VA-GAN and the actual observed changes from MCI to AD. As before, it can be seen that VA-GAN produced visual attribution maps that very closely approximate the observed deformations. In particular, we note that for this subject VA-GAN correctly predicted a smaller disease effect in the left hippocampus than in the right hippocampus.

Figure 9: Coronal, sagittal and axial views of the predicted and observed disease effect maps for an additional subject. The location of the planes is indicated by dotted white lines in the right column. In order to allow for an enlarged view, only the predictions obtained by VA-GAN are shown. The ADNI rid and the ADAS13 score for this subject are reported on the left-hand side.

Appendix C Details of MR brain data cohort

The MR brain image data used in the preparation of this article were obtained from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database. As such, the investigators within ADNI contributed to the design and implementation of ADNI and/or provided data but did not participate in the analysis or writing of this report. A complete listing of ADNI investigators can be found on the ADNI website.

Specifically, we used T1-weighted MR data from the ADNI1, ADNIGO and ADNI2 cohorts, which were acquired with a mixture of 1.5T and 3T scanners. The data consisted of 5770 images acquired from 1291 subjects. The images for each subject were acquired at separate visits, spaced at regular intervals of 6 months to one year and usually spanning multiple years. On average, each subject was scanned 4.5 times. The cohort consisted of 496 female and 795 male subjects. 2839 of the images were acquired using a 1.5T magnet, the remainder using a 3T magnet. The distribution of the ages at which the images were acquired is shown in Fig. 10. We only considered images with a diagnosis of mild cognitive impairment (MCI) or Alzheimer’s disease (AD).

Figure 10: Histogram of the subject ages of all ADNI images used in this work.

After preprocessing, we randomly divided the data into a training, testing and validation set. We performed the split on a subject basis rather than an image basis. The exact split is shown in Table 3. The table furthermore shows the distribution over the diagnoses on an image level, and the number of subjects who underwent a conversion from MCI to AD in the examined time intervals.
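A subject-level split of this kind can be sketched as follows. This is a hypothetical helper; the record format, fractions and seed are illustrative and not taken from the paper:

```python
import random

def split_by_subject(records, fractions=(0.64, 0.20, 0.16), seed=0):
    """Split (subject_id, image) records into train/test/validation sets.

    The split is performed over subject IDs rather than individual images,
    so all scans of a given subject land in the same set.
    """
    subjects = sorted({sid for sid, _ in records})
    rng = random.Random(seed)
    rng.shuffle(subjects)
    n = len(subjects)
    n_train = round(fractions[0] * n)
    n_test = round(fractions[1] * n)
    train_ids = set(subjects[:n_train])
    test_ids = set(subjects[n_train:n_train + n_test])
    val_ids = set(subjects[n_train + n_test:])

    def pick(ids):
        return [r for r in records if r[0] in ids]

    return pick(train_ids), pick(test_ids), pick(val_ids)
```

Splitting on subject IDs prevents scans of the same subject from leaking between the training and evaluation sets.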

The training data was used for learning the mask generator and critic parameters which minimise the cost function in Eq. 4 of the main article. The validation set was used for monitoring of the training based on the Wasserstein distance and visual examination of generated masks, and for hyperparameter tuning. The test set was used for the final qualitative and quantitative evaluation.

                  Train  Test  Validation  Total
Num. Imag.
   MCI             2520   755         639   3914
   AD              1199   399         266   1864
   Total           3719  1154         905   5778
Num. Subj.
   Converters       172    51          49    272
   Non-converters   653   208         158   1019
   Total            825   259         207   1291
Table 3: Detailed information on the split of the data into training, testing and validation sets.

In case of interest, a list of the exact ADNI subject IDs used in the study can be found in our public code repository, in the file data/subject_rids.txt.

Appendix D Alternative classifier architecture

It was suggested during the reviews that our classifier architecture, with two dense layers before the final output, may be responsible for the poor performance of the backpropagation-based saliency map techniques. It was recommended that we investigate the popular class of architectures in which the final convolutions are aggregated using a global average pooling step over the spatial dimensions of the activation maps, followed by a single dense layer. Examples of this type of architecture include the works of He et al. [23] and Lin et al. [35]. In our experiments, the class activation mapping (CAM) method [67] also used this general architecture. In theory, this may abstract the data less before the final output and perhaps produce maps that can more easily identify multiple regions in the image.

Figure 11: Saliency maps obtained using simple backpropagation, guided backpropagation and integrated gradients for two different network architectures: (1) the original architecture from the synthetic experiments (Section 4.2) in the main article, (2) an alternative architecture with a global average pooling layer followed by a single dense layer before the final classification output. The white arrows in the second row highlight very faint attributions of the second box.

To investigate this theory we repeated the synthetic experiment (outlined in Section 4.2 of the main article), but replaced the final two dense layers by a global average pooling layer and a single dense layer. After both the network from the main article and the alternative architecture had fully converged, we obtained the saliency maps shown in Fig. 11. In addition to the integrated gradients method [56] already shown in the main article, here we also show the results for normal backprop [53] and guided backprop [55]. It can be observed that, with the alternative architecture, normal and guided backprop indeed manage to correctly attribute some of the pixels of the peripheral box, albeit very faintly (emphasised with white arrows in Fig. 11). However, regardless of the architecture, the classifier appears to focus only on the pixels of one of the edges, which is only a subset of the features characterising this class. Note that the orientation of the attributed edges depends on the random initialisation of the network.
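The aggregation step of this alternative architecture can be sketched in a few lines of pure Python. This is a minimal illustration of global average pooling over per-channel feature maps followed by a single dense logit; the function names and shapes are illustrative, not the authors' code:

```python
def global_average_pool(feature_maps):
    """feature_maps: one 2-D map (list of rows) per channel.

    Returns one scalar per channel: the mean activation over the
    spatial dimensions of that channel's map.
    """
    pooled = []
    for fmap in feature_maps:
        values = [v for row in fmap for v in row]
        pooled.append(sum(values) / len(values))
    return pooled

def dense_logit(pooled, weights, bias):
    """Single dense layer mapping the pooled channel means to a class logit."""
    return sum(p * w for p, w in zip(pooled, weights)) + bias
```

Because each pooled scalar averages over the whole spatial extent of its channel, the final logit aggregates evidence from receptive fields covering the entire image, just as the two-dense-layer head does.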

Nevertheless, the feature attribution maps obtained using the backprop-based techniques are not of comparable quality to the maps produced by our proposed VA-GAN method. For emphasis, we show the corresponding feature attribution map produced with VA-GAN, along with two more samples, in Fig. 12.

Figure 12: Visual feature attribution maps obtained using our proposed VA-GAN method. The first sample corresponds to the input image in Fig. 11. The other two images correspond to other random input images.

To conclude, we would like to note that, from the point of view of saliency maps, (1) two dense layers and (2) average pooling followed by a dense layer are conceptually similar. In both cases the final prediction aggregates information from multiple receptive fields covering the whole image. Therefore, it is not surprising that the two networks behave similarly. As outlined in the work of Shwartz-Ziv et al. [52], the optimisation of neural network classifiers results in a trade-off between compression of input features and predictive accuracy. In both networks, the final prediction has access to all features in the image and thus has the potential to compress away features that are redundant for classification (such as one of the two boxes).