Born Identity Network: Multi-way Counterfactual Map Generation to Explain a Classifier's Decision

11/20/2020 ∙ by Kwanseok Oh, et al.

There exists an apparent negative correlation between the performance and the interpretability of deep learning models. In an effort to reduce this negative correlation, we propose the Born Identity Network (BIN), a post-hoc approach for producing multi-way counterfactual maps. A counterfactual map transforms an input sample so that it is classified as a target label, which is similar to how humans process knowledge through counterfactual thinking. Thus, producing a better counterfactual map may be a step towards explanation at the level of human knowledge. For example, a counterfactual map can localize hypothetical abnormalities in a normal brain image that may cause it to be diagnosed with a disease. Specifically, our proposed BIN consists of two core components: the Counterfactual Map Generator and the Target Attribution Network. The Counterfactual Map Generator is a variation of a conditional GAN that can synthesize a counterfactual map conditioned on an arbitrary target label. The Target Attribution Network works in a complementary manner to enforce target label attribution on the synthesized map. We validated our proposed BIN through qualitative and quantitative analyses on the MNIST, 3D Shapes, and ADNI datasets, and we show the comprehensibility and fidelity of our method through various ablation studies.


1 Introduction

As deep learning has shown success in various domains, there has been a growing need for interpretability and explainability of deep learning models. The black-box nature of these models limits their real-world application, especially in fields where fairness, accountability, and transparency are essential. Moreover, from the end-user's point of view, a clear understanding and an explanation at the level of human knowledge are crucial. However, achieving high performance together with interpretability remains an unsolved problem in the field of explainable AI (XAI) due to their apparent negative correlation [15] (i.e., interpretable models tend to have lower performance than black-box models).

Figure 1: We propose an approach for producing counterfactual maps of a) digit images and b) brain images as a step towards counterfactual reasoning, which is a process of producing hypothetical realities given observations. For example, a counterfactual map can localize hypothetical, yet very possible, abnormalities from a brain image of a normal subject that may cause the subject to be diagnosed with a disease.
Figure 2: Schematic overview of the Born Identity Network (BIN). There are two major components of BIN: the Counterfactual Map Generator and the Target Attribution Network. The Counterfactual Map Generator synthesizes a counterfactual map conditioned on an arbitrary target label, while the Target Attribution Network works toward enforcing target label attribution to the synthesized map.

Reducing this negative correlation between performance and interpretability has been a long-standing goal in the field of XAI. In the early era of XAI [11], researchers proposed various methods for discovering or identifying the regions that have the most influence on the outcome of a classifier [26, 29, 30, 1, 23, 27, 28, 34]. The main objective of these early XAI methods is to answer why and how a model has made its decision. However, recent XAI methods aim to answer a question that can offer a more fundamental explanation: "What would have caused the model to make a different decision?" This sort of explanation lies at the root of counterfactual reasoning. Counterfactual reasoning can provide an explanation at the level of human knowledge since it can explain a model's decision in hypothetical situations, which is essentially similar to how humans process knowledge. Thus, the motivation of this paper is to provide a higher-level visual explanation of a deep learning model, similar to how humans process knowledge, i.e., through the means of a counterfactual map.

A counterfactual map is a map that can transform an input sample that was originally classified as one label to be classified as another. For example, a counterfactual map transforms a digit image into an image of another number (Figure 1-(a)). This counterfactual map explains what kind of structural changes are required for a digit image to become another number. A real-world application is medical image analysis of the brain (Figure 1-(b)), where a counterfactual map would describe which Regions of Interest (ROIs) may cause a normal subject to be diagnosed with a disease.

To our knowledge, most works on producing counterfactual explanations are generative models. Most notably, these works utilize a Generative Adversarial Network (GAN) and its variants [12, 6, 31, 8] to synthesize a counterfactual explanation. Although these works can generate meaningful counterfactual explanations, two fundamental problems limit their application in the real world. First, generative models are a typical example of models that exhibit the aforementioned negative correlation between performance and interpretability. To resolve this issue, we propose a method that can produce counterfactual maps from an already trained model. This post-hoc nature of our work not only provides a generalized framework that can be applied to most neural networks but also reduces the negative correlation since it can be applied to a model that already has high performance (i.e., we can focus solely on interpretability, with the high performance given beforehand). Second, recent works on counterfactual explanation can only produce a single- [8, 12, 6] or dual- [13, 32] sided explanation. In other words, they only consider one or two hypothetical scenarios for counterfactual reasoning (e.g., producing maps that can only transform a digit image to be classified as one or two specific numbers). With the help of a target attribution mechanism, our work is, to the best of our knowledge, the first to propose multi-way counterfactual reasoning (e.g., producing maps that can transform a digit image to be classified as any other number).

To this end, we propose a Born Identity Network (BIN) that produces a counterfactual map using two components: the Counterfactual Map Generator and the Target Attribution Network. (We have coined this name to emphasize that the produced counterfactual map is a result of its original input, i.e., its identity; in addition, the post-hoc nature of our work fixes the weights, i.e., the identity, of the deep learning model throughout the learning process.) The Counterfactual Map Generator is a variant of a conditional GAN [22] that synthesizes a conditioned map, while the Target Attribution Network works in a complementary manner with the Counterfactual Map Generator in enforcing a target counterfactual attribute on the synthesized conditioned map.

To evaluate our proposed framework, we perform a suite of analyses in various data domains: handwritten digit, geometric, and medical images. First, to qualitatively evaluate our work, we demonstrate and analyze the counterfactual maps for the above datasets. Second, to quantitatively validate the counterfactual map, we further calculate a correlation score between the counterfactual map and its ground-truth map. Finally, we examine how each component of BIN contributes to creating a counterfactual map with a suite of exhaustive ablation studies. Thus, the main contributions of our study are as follows:

  • We propose BIN, which, to the best of our knowledge, is the first work on producing counterfactual reasoning in multiple hypothetical scenarios.

  • Our work produces a counterfactual map in a post-hoc manner, which can reduce the apparent negative correlation between performance and interpretability. This post-hoc nature makes our proposed network a generalized interpretation framework that can produce a counterfactual map from most pre-trained models.

2 Related Works

In this section, we describe various works proposed for XAI, which we briefly categorize into the general framework of attribution-based explanations and the more recent framework of counterfactual explanations.

2.1 Attribution-based Explanations

Attribution-based explanation refers to discovering or identifying the regions that have the most influence on deriving the outcome of a model [11]. These approaches can further be separated into gradient-based and reference-based explanations. First, the gradient-based explanation highlights the activation nodes that most contributed to the model's decision. For example, Class Activation Map (CAM) [35] and Grad-CAM [26] highlight activation patterns of weights in a specified layer. In a similar manner, DeepTaylor [23], DeepLIFT [27], and Layer-wise Relevance Propagation (LRP) [1] highlight the gradients with regard to the prediction score. These approaches usually suffer from vanishing gradients due to the ReLU activation, and Integrated Gradients [30] resolves this issue through sensitivity analysis. However, a crucial drawback of attribution-based approaches is that they tend to ignore features with relatively low discriminative power or highlight only those with overwhelming feature importance. Second, the reference-based explanation [34, 10, 6, 9] focuses on changes in the model output with regard to perturbations of the input samples. Various perturbation methods have been proposed, such as masking [7], heuristics [6] (e.g., blurring, random noise), using a region of a distractor image as the reference for perturbation [13], or synthesized perturbations [9, 34, 10]. One general drawback of these attribution-based explanations is that they tend to produce low-resolution, blurred saliency maps. In contrast, our work can produce a crisp saliency map with the same resolution as the input sample using a GAN.

2.2 Counterfactual Explanations

Recently, more researchers have focused on counterfactual reasoning as a form of higher-level explanation. Counterfactual explanation refers to analyzing the model output with regard to hypothetical scenarios. For example, a counterfactual explanation could highlight the regions that may (hypothetically) cause a normal subject to be diagnosed with a disease. Visual Feature Attribution GAN (VA-GAN) [4] uses a variant of GANs to synthesize a counterfactual map that transforms an input sample to be classified as another label. However, VA-GAN can only perform one-way synthesis (e.g., a map that transforms an input originally classified as A to be classified as B, but not vice versa), which limits the counterfactual explanation to being single-sided. ICAM [2] was proposed as an extension of VA-GAN to produce a dual-way counterfactual explanation. However, most real-world applications cannot be explained through single- or dual-sided explanations due to their complexity. Thus, our work proposes an approach for producing multi-way counterfactual explanations (e.g., a counterfactual map that can transform a digit image to be classified as any other number).

Figure 3: A detailed view of the Counterfactual Map Generator. A tiled one-hot target label is concatenated to the skip connection. This enables the generator to regulate the counterfactual map to be conditioned on an arbitrary target condition.

3 Multi-way Counterfactual Map

Here, we formally define the multi-way counterfactual map. First, we define a data set $\mathcal{X} = \{(x, y)\}$, where $(x, y)$ denotes a data-label pair and $(x, \tilde{y})$ denotes an unpaired data and label. We define a counterfactual map as a saliency map that is able to produce a counterfactual explanation of a classifier $F$. Formally, a single- or dual-way counterfactual map $M(x)$ is a map such that, when added to an input sample $x$, i.e., $x + M(x)$, the classifier classifies it as another label complement to its original label, e.g., $F(x) = y$ and $F(x + M(x)) \neq y$, and vice versa for the dual-way counterfactual map. Here, if we are able to condition this counterfactual map on an attribute (or label) $\tilde{y}$, i.e., the multi-way counterfactual map $M_{\tilde{y}}(x)$, we can transform an input sample into any other attribute of interest:

$F(x + M_{\tilde{y}}(x)) = \tilde{y}$,   (1)

where the data and label are unpaired, i.e., $\tilde{y} \neq y$.
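To make the property in Eq. (1) concrete, the following minimal PyTorch sketch checks whether a generated map flips a classifier's prediction to the target label. Here `classifier` and `counterfactual_map` are hypothetical callables standing in for the pre-trained model and the map generator, not the released implementation.

```python
import torch

def is_counterfactual(x, target, classifier, counterfactual_map):
    """Check whether adding the generated map flips the prediction to the target label."""
    m = counterfactual_map(x, target)          # M_target(x), same shape as x
    x_tilde = x + m                            # x_tilde = x + M_target(x)
    pred = classifier(x_tilde).argmax(dim=1)   # F(x_tilde)
    return bool((pred == target).all())        # Eq. (1) holds for the whole batch
```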

4 Born Identity Network

The goal of BIN is to induce counterfactual reasoning dependent on a target condition from a pre-trained model. To achieve this goal, we devised BIN with two core modules, i.e., the Counterfactual Map Generator and the Target Attribution Network, which work in a complementary manner in producing the multi-way counterfactual map. Specifically, the Counterfactual Map Generator synthesizes a conditioned map, while the Target Attribution Network enforces a target attribution on the synthesized map (Figure 2). In addition, we have released the code for our proposed BIN on GitHub: https://github.com/ksoh97/BIN.

4.1 Counterfactual Map Generator

The Counterfactual Map Generator is a variant of a conditional GAN [22] that can synthesize a counterfactual map conditioned on a target label $\tilde{y}$, i.e., $M_{\tilde{y}}(x)$. Specifically, it consists of an encoder $E$, a generator $G$, and a discriminator $D$. First, the network design of the encoder and the generator is a variation of U-Net [25] with the tiled target label concatenated to the skip connections (Figure 3). This design enables the generator to synthesize target-conditioned maps such that multi-way counterfactual reasoning is possible. As a result, the counterfactual map is formulated as follows:

$M_{\tilde{y}}(x) = G(E(x), \tilde{y})$,   (2)
$\tilde{x} = x + M_{\tilde{y}}(x)$.   (3)

Finally, following a general adversarial learning scheme, the discriminator $D$ discriminates real samples $x$ from synthesized samples $\tilde{x}$.
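The following PyTorch sketch illustrates the conditioning mechanism described above: a small U-Net-style encoder/decoder whose skip connection is concatenated with a tiled one-hot target label. The depth, channel widths, and kernel sizes are placeholders chosen for brevity, not the architectures listed in Supplementary Chapter 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CondMapGenerator(nn.Module):
    """Two-level encoder/decoder with a label-conditioned skip connection (illustrative sketch)."""

    def __init__(self, in_ch=1, num_classes=10, base=32):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, base, 4, 2, 1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(base, base * 2, 4, 2, 1), nn.ReLU())
        # The skip connection carries `num_classes` extra channels: the tiled one-hot label.
        self.dec2 = nn.Sequential(nn.Conv2d(base * 3 + num_classes, base, 3, 1, 1), nn.ReLU())
        self.out = nn.Conv2d(base, in_ch, 3, 1, 1)

    @staticmethod
    def tile(onehot, ref):
        # Tile a (B, C) one-hot vector to (B, C, H, W), matching the feature map's spatial size.
        return onehot[:, :, None, None].expand(-1, -1, ref.size(2), ref.size(3))

    def forward(self, x, target_onehot):
        e1 = self.enc1(x)                                 # (B, base,   H/2, W/2)
        e2 = self.enc2(e1)                                # (B, 2*base, H/4, W/4)
        d2 = F.interpolate(e2, scale_factor=2.0)          # up-sample back to H/2
        skip = torch.cat([e1, self.tile(target_onehot, e1)], dim=1)
        d2 = self.dec2(torch.cat([d2, skip], dim=1))      # target-conditioned skip connection
        d1 = F.interpolate(d2, scale_factor=2.0)          # up-sample to the input resolution
        return torch.tanh(self.out(d1))                   # counterfactual map M_target(x)

# Example usage on MNIST-sized inputs (shapes only; weights are untrained here).
gen = CondMapGenerator()
x = torch.randn(8, 1, 28, 28)
t = F.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()
cf_map = gen(x, t)                                        # (8, 1, 28, 28)
```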

Counterfactual reasoning requires a good balance between the proposed hypothetical and the given reality. In the following sections, we define the loss functions that guide the Counterfactual Map Generator to produce a well-balanced counterfactual map.

4.1.1 Adversarial Loss

For the adversarial loss, we adopted the Least Squares GAN (LSGAN) [21] objective function due to its stability during adversarial training. More specifically, the LSGAN objective contributes to stable model training by penalizing samples far from the discriminator's decision boundary. This objective is an important choice for BIN since the generated counterfactual maps should neither destroy the input sample nor ignore the target attribution, i.e., they should strike a good balance between real and fake samples. To this end, the discriminator and generator losses are defined as follows, respectively:

$\mathcal{L}_{D} = \tfrac{1}{2}\,\mathbb{E}_{x}\big[(D(x) - 1)^{2}\big] + \tfrac{1}{2}\,\mathbb{E}_{x,\tilde{y}}\big[D(\tilde{x})^{2}\big]$,   (4)
$\mathcal{L}_{G}^{\mathrm{adv}} = \tfrac{1}{2}\,\mathbb{E}_{x,\tilde{y}}\big[(D(\tilde{x}) - 1)^{2}\big]$.   (5)
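As a reference, a minimal PyTorch rendering of these least-squares objectives might look as follows; `disc` is the discriminator and `x_fake` denotes the input plus its counterfactual map (both assumed to be defined elsewhere).

```python
import torch

def lsgan_d_loss(disc, x_real, x_fake):
    # Push real samples toward a score of 1 and synthesized samples toward 0.
    return 0.5 * ((disc(x_real) - 1.0) ** 2).mean() + 0.5 * (disc(x_fake.detach()) ** 2).mean()

def lsgan_g_loss(disc, x_fake):
    # The generator tries to make synthesized samples look real (score 1).
    return 0.5 * ((disc(x_fake) - 1.0) ** 2).mean()
```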

4.1.2 Cycle Consistency Loss

The cycle consistency loss is used for producing better multi-way counterfactual maps. In the scenario of a one-way counterfactual map generator, real samples and synthesized samples are always drawn from the same specific class. That setting does not require a cycle consistency loss since the real and fake samples always have specified labels. However, since our discriminator only classifies samples as real or fake, it does not have the ability to guide the generator to produce multi-way counterfactual maps. Thus, we add a cycle consistency loss in which the forward cycle produces a map with an arbitrary condition, i.e., generated from an unpaired data and label $(x, \tilde{y})$, and the backward cycle produces a map conditioned on the input sample's label, i.e., generated from the paired data and label $(x, y)$:

$\mathcal{L}_{\mathrm{cyc}} = \mathbb{E}_{x,\tilde{y}}\big[\,\lVert x - (\tilde{x} + M_{y}(\tilde{x})) \rVert_{1}\,\big]$,   (6)

where $\tilde{x} = x + M_{\tilde{y}}(x)$.
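A compact sketch of the forward and backward cycles, under the notation assumed in this section (`gen(x, label)` returns a counterfactual map; the L1 reconstruction penalty is an illustrative choice):

```python
import torch

def cycle_consistency_loss(gen, x, y_onehot, y_tilde_onehot):
    x_tilde = x + gen(x, y_tilde_onehot)        # forward cycle: arbitrary target condition
    x_recon = x_tilde + gen(x_tilde, y_onehot)  # backward cycle: the input sample's own label
    return torch.mean(torch.abs(x - x_recon))   # reconstruct the original input
```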

4.2 Target Attribution Network

The Target Attribution Network works in a complementary manner with the Counterfactual Map Generator in enforcing attribution to a synthesized counterfactual map. Specifically, the objective of the Target Attribution Network is to guide the generator to produce counterfactual maps that transform an input sample to be classified as a target class:

$\mathcal{L}_{\mathrm{TAN}} = \mathbb{E}_{x,\tilde{y}}\big[\,\ell_{ce}\big(F(\tilde{x}), \tilde{y}\big)\,\big]$,   (7)

where $\tilde{y}$ denotes the target label, $\tilde{x} = x + M_{\tilde{y}}(x)$, and $\ell_{ce}$ denotes the cross-entropy function.
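A minimal sketch of this loss, assuming a frozen pre-trained classifier `tan` and integer target labels:

```python
import torch
import torch.nn.functional as F

def target_attribution_loss(tan, x, cf_map, target):
    logits = tan(x + cf_map)                 # classify the confounded input x_tilde
    return F.cross_entropy(logits, target)   # enforce attribution toward the target label
```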

Conceptually, the role of the Target Attribution Network is similar to that of a discriminator in GANs, but their objectives are very different. While a discriminator learns to distinguish between real and fake samples, the Target Attribution Network is pre-trained to classify the input samples. Thus, the discriminator plays a min-max game with the generator in an effort to produce more realistic samples, while the Target Attribution Network provides deterministic guidance for the generator to produce class-specific samples.

4.2.1 Counterfactual Map Loss

The counterfactual map loss constrains the magnitude of the counterfactual map:

$\mathcal{L}_{\mathrm{map}} = \lambda_{1}\,\mathbb{E}_{x,\tilde{y}}\big[\lVert M_{\tilde{y}}(x) \rVert_{1}\big] + \lambda_{2}\,\mathbb{E}_{x,\tilde{y}}\big[\lVert M_{\tilde{y}}(x) \rVert_{2}\big]$,   (8)

where $\lambda_{1}$ and $\lambda_{2}$ are weighting constants used as hyperparameters. This loss is crucial for solving two issues in counterfactual map generation. First, when left untethered, the counterfactual map will destroy the identity of the input sample. This problem is related to adversarial attacks [33], where a simple perturbation of the input sample changes the model's decision. Second, from an end-user's point of view, only the most important features should be represented in the counterfactual explanation. Thus, by placing a constraint on the magnitude of the counterfactual map, we can address both issues in one step.
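For illustration, the magnitude constraint can be written as a weighted sum of per-sample L1 and L2 norms; the weights below are placeholders rather than the values used in the experiments.

```python
import torch

def counterfactual_map_loss(cf_map, lam1=1.0, lam2=1.0):
    # Per-sample L1 and L2 norms of the map, averaged over the batch.
    l1 = cf_map.flatten(1).norm(p=1, dim=1).mean()   # encourages sparse, localized edits
    l2 = cf_map.flatten(1).norm(p=2, dim=1).mean()   # discourages large-magnitude edits
    return lam1 * l1 + lam2 * l2
```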

Figure 4: Examples of counterfactual maps for the 3D Shapes [5] dataset. The counterfactual map consists of RGB channels. Two models were trained separately for producing counterfactual maps for the orange↔red transformation and the orange↔cyan transformation.

4.3 Learning

Finally, we define the overall loss function for BIN as follows:

$\mathcal{L}_{\mathrm{total}} = \lambda_{\mathrm{adv}}\,(\mathcal{L}_{D} + \mathcal{L}_{G}^{\mathrm{adv}}) + \lambda_{\mathrm{cyc}}\,\mathcal{L}_{\mathrm{cyc}} + \lambda_{\mathrm{TAN}}\,\mathcal{L}_{\mathrm{TAN}} + \mathcal{L}_{\mathrm{map}}$,   (9)

where $\lambda_{\mathrm{adv}}$, $\lambda_{\mathrm{cyc}}$, and $\lambda_{\mathrm{TAN}}$ are weighting hyperparameters of the model.

During training, we share the weights of the encoder of the Counterfactual Map Generator with the Target Attribution Network and keep them fixed to ensure that the attribution is consistent throughout the generative process. In preliminary experiments, training the encoder from scratch resulted in slightly lower qualitative and quantitative results.
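Putting the pieces together, a single training step might be organized as sketched below. It reuses the loss helpers from the previous sketches, keeps the Target Attribution Network frozen, and treats the loss weights as placeholder hyperparameters rather than the reported configuration.

```python
import torch

def train_step(gen, disc, tan, opt_g, opt_d,
               x, y_onehot, y_tilde, y_tilde_onehot,
               w_adv=1.0, w_cyc=1.0, w_tan=1.0):
    cf_map = gen(x, y_tilde_onehot)     # M_target(x)
    x_fake = x + cf_map                 # x_tilde

    # Discriminator update (the generator graph is detached inside lsgan_d_loss).
    opt_d.zero_grad()
    d_loss = lsgan_d_loss(disc, x, x_fake)
    d_loss.backward()
    opt_d.step()

    # Generator update; `tan` is a frozen pre-trained classifier and is never stepped.
    opt_g.zero_grad()
    g_loss = (w_adv * lsgan_g_loss(disc, x_fake)
              + w_cyc * cycle_consistency_loss(gen, x, y_onehot, y_tilde_onehot)
              + w_tan * target_attribution_loss(tan, x, cf_map, y_tilde)
              + counterfactual_map_loss(cf_map))
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```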

5 Experiments

We conducted various experiments to validate the counterfactual maps generated by our proposed model. First, we conducted a qualitative analysis of dual- and multi-way counterfactual explanations. Second, we report a quantitative analysis using a correlation measure between generated counterfactual maps and ground-truth maps. Finally, we performed a suite of ablation studies to verify that each component of BIN contributes toward creating a meaningful counterfactual map.

5.1 Datasets

For a comprehensive analysis of our work, we selected three data domains to evaluate our experiments: 3D Shapes dataset [5] is used for comparing dual-way counterfactual explanations; MNIST [20] is used for multi-way counterfactual explanations; ADNI (Alzheimer’s Disease Neuroimaging Initiative) [24] dataset is used as a real-world application in the medical domain.

3D Shapes

The 3D Shapes dataset [5] consists of 480,000 RGB images of 3D geometric shapes with six latent factors (10 floor/wall/object hues, 8 scales, 4 shapes, and 15 orientations). For our experiments, we selected the orange, red, and cyan object hues as the target classes. Thus, although we used a total of 144,000 images (48,000 per object hue), we only used images from two object hues for each experiment (i.e., 96,000 images) since this dataset is intended for dual-way experiments. We randomly divided the dataset into train, validation, and test sets at a ratio of 8:1:1 and applied channel-wise min-max normalization.
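Channel-wise min-max normalization as described above can be implemented in a few lines of NumPy; the `N x H x W x C` array layout is an assumption made for this sketch.

```python
import numpy as np

def channel_minmax(images):
    # Scale each channel to [0, 1] using dataset-wide per-channel minima and maxima.
    mins = images.min(axis=(0, 1, 2), keepdims=True)
    maxs = images.max(axis=(0, 1, 2), keepdims=True)
    return (images - mins) / (maxs - mins + 1e-8)
```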

MNIST

For the MNIST dataset, we used the data split provided by [20] and applied min-max normalization.

Figure 5: Examples of counterfactual maps for the MNIST dataset. The resulting confound image is the addition of an input and its corresponding counterfactual map conditioned on a target label. The counterfactual maps maintain the style of the input samples while transforming them to the target labels. The values at the top of each confound map are the model's softmax-activated logits.

Components FID score
0 1 2 3 4 5 6 7 8 9 avg
1.340 1.082 1.034 0.935 0.945 0.933 0.999 0.850 0.961 0.850 0.993
1.338 1.241 1.039 0.955 1.064 0.973 1.022 0.956 0.993 0.935 1.052
1.283 1.156 0.915 0.810 0.959 0.822 0.987 0.888 0.896 0.834 0.955
1.318 1.056 0.906 0.804 0.951 0.824 1.001 0.841 0.904 0.841 0.945
1.338 1.242 1.039 0.956 1.065 0.973 1.022 0.956 0.993 0.935 1.052
1.338 1.242 1.039 0.956 1.065 0.973 1.022 0.956 0.993 0.935 1.052
1.211 0.902 0.892 0.802 0.958 0.812 0.970 0.835 0.912 0.833 0.912
Table 1: Fréchet Inception Distance (FID) scores reported for ablation studies for MNIST dataset.
ADNI

The ADNI dataset consists of 3D Magnetic Resonance Imaging (MRI) scans of various subject groups ranging from cognitively normal to Alzheimer's disease. For this dataset, we selected a baseline MRI of 425 Cognitively Normal (CN) subjects and 765 Alzheimer's Disease (AD) subjects from the ADNI 1/2/3/GO studies. Furthermore, for longitudinal studies, we selected 20 CN test subjects that converted to the AD group at a later time point. With the exception of the 20 longitudinal test subjects, all subjects were randomly split into train, validation, and test sets at a ratio of 8:1:1. The pre-processing procedure consisted of neck removal (FSL v6.0.1 robustfov), brain extraction (HD-BET [17]), linear registration (FSL v6.0.1 FLIRT), zero-mean unit-variance normalization, quantile normalization at 5% and 95%, and down-scaling. The resulting pre-processed MRI is a down-scaled 3D volume. We used default parameters for FSL v6.0.1 [18].
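The intensity-normalization steps (zero-mean unit-variance followed by 5%/95% quantile normalization, interpreted here as percentile clipping, which is an assumption) can be sketched as follows; the neck removal, brain extraction, and registration steps rely on the external tools named above (FSL, HD-BET) and are not reproduced here.

```python
import numpy as np

def normalize_mri(volume):
    # Zero-mean, unit-variance scaling of the whole volume.
    v = (volume - volume.mean()) / (volume.std() + 1e-8)
    # Clip intensities to the 5th and 95th percentiles.
    lo, hi = np.quantile(v, [0.05, 0.95])
    return np.clip(v, lo, hi)
```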

5.2 Implementation

We utilized the architecture of Kim et al. [19] for the 3D Shapes dataset and a minor modification of this network for MNIST. SonoNet-16 [3] is used for the ADNI dataset as the encoder $E$ and the discriminator $D$. The generator $G$ has the same network design as the encoder, with pooling layers replaced by up-sampling layers. The test accuracy of the pre-trained Target Attribution Network is 99.56% for MNIST, 99.89% for 3D Shapes, and 89.29% for the ADNI dataset. Further implementation details are given in Supplementary Chapter 1.

Figure 6: Examples of counterfactual maps for ADNI dataset. Purple box visualizes ventricular regions (top row), green box visualizes cortex regions (second row), and orange box visualizes hippocampus regions (bottom two rows).

5.3 Qualitative Analysis

An N-way counterfactual map is a map that, when added to an input sample, causes the classifier to classify it as a targeted class. The N-way term refers to the degrees of freedom of the targeted class. For example, a dual-way counterfactual map can transform a red cylinder to be classified as orange, or an orange cylinder to be classified as red (Figure 4). In the following sections, we compared and validated multi-way counterfactual maps in various settings. Specifically, we took a bottom-up analysis approach, with experiments on a toy example building up to an application on a real-world dataset. First, we analyzed dual-way counterfactual maps with a toy example from the 3D Shapes dataset. Then, we validated our novel multi-way counterfactual maps with the MNIST dataset. Finally, we applied our proposed BIN to a real-world dataset and compared it with state-of-the-art XAI frameworks on the ADNI dataset.

5.3.1 Validation of Counterfactual Map

First, we analyzed dual-way counterfactual maps using a toy example from the 3D Shapes dataset (Figure 4). The goal of this experiment is to verify that our counterfactual map is able to transform an input image with regard to the target latent factor, i.e., object hue. The first and second rows show counterfactual maps for transforming red (R: 255, G: 0, B: 0) to orange (R: 255, G: 150, B: 0), and orange to red. Transforming red to orange is a relatively simpler task since it only requires a transformation in the G channel. As a result, the R and B channels of our counterfactual maps are highlighted relatively less than the G channel. The third row transforms orange (R: 255, G: 150, B: 0) to cyan (R: 0, G: 210, B: 255). Despite the contrasting color in all channels, BIN was able to induce the counterfactual map corresponding to the target object hue. In addition, the classification accuracy on the synthesized images is 99.51%, indicating that the counterfactual maps have successfully transformed the input images to be classified as the other label.

In most counterfactual maps generated in this experiment, we have observed an interesting phenomenon in which the counterfactual maps are relatively invariant to latent factors that were not targeted for transformation. For example, the scale and shape of the object in the counterfactual map are relatively similar to those of the input image. This indicates that our counterfactual map is somewhat localized to the target latent factor.

5.3.2 Invariance of the Counterfactual Map

For this experiment, we generated multi-way counterfactual maps for the MNIST dataset (Figure 5). At a glance, we can observe that the style of an input image is maintained. For example, for confound images of the digit 1 (second row), which are among the least realistic confound images, we can observe that the transformation maintains the style of the input image. This is related to the invariance to non-targeted latent factors observed in Section 5.3.1. We hypothesize that this invariance is due to the counterfactual map loss (Eq. (8)), since it constrains the values of the counterfactual map so that it transforms an input sample with the least amount of energy. From a conceptual view, our generated counterfactual maps exhibit a good example of counterfactual reasoning since they can successfully produce hypothetical realities with respect to a given input sample.

5.3.3 Application in Medical Domain

In the first part of this section, we compared BIN with various XAI frameworks (Figure 6). Since there are no ground-truth maps for the ADNI dataset, we utilized the Normal target map from the longitudinal test subjects. First, we gathered MRIs from subjects that were originally a part of the CN group but converted to the AD group at a later time point. Then, we subtracted the baseline image from the image at the time of AD conversion to create the Normal target map. This Normal target map is a good surrogate for a ground-truth map of disease localization since we can observe which regions are possibly responsible for AD conversion. For the attribution-based approaches (Integrated Gradients [30], LRP-Z [1], DeepLIFT [27]) and the perturbation-based approach (Guided Backpropagation [26]), we used the pre-trained Target Attribution Network as the classifier.

One of the biomarkers of Alzheimer's disease is atrophy in brain regions such as the ventricular regions, cortical thickness, and the hippocampus [16]. As observed in the Normal target map (second column in Figure 6), atrophy is clearly visible in those regions. The counterfactual maps generated by the attribution-based approaches (Integrated Gradients, LRP-Z, DeepLIFT) and the perturbation-based approach (Guided Backpropagation) do not clearly show the regions responsible for AD conversion. Also, comparing our proposed BIN with VA-GAN, our work clearly shows a better representation of the Normal target map. As an extension of the above experiment, we performed an interpolation between AD and CN by conditioning the counterfactual map on interpolated target classes (see Supplementary Chapter 2).

Method                       3D Shapes            ADNI
                             NCC(+)    NCC(-)     NCC(+)    NCC(-)
LRP-Z [1]                    0.008     0.086      0.004     0.005
Integrated Gradients [30]    0.006     0.152      0.006     0.004
DeepLIFT [27]                0.007     0.183      0.005     0.003
Guided Backprop [26]         0.183     0.123      0.225     0.198
VA-GAN [4]                   0.381     N/A        0.282     N/A
BIN (ours)                   0.516     0.465      0.306     0.185
Table 2: Normalized Cross-Correlation (NCC) scores for the 3D Shapes and ADNI datasets.

5.4 Quantitative Analysis

In this section, we quantitatively evaluated our proposed BIN and compared it with the outcomes of other methods. To quantitatively assess the quality of our generated counterfactual maps, we calculated the Normalized Cross-Correlation (NCC) score between the generated maps and the ground-truth maps. The NCC score measures the similarity between two samples in a normalized setting; thus, NCC is helpful when two samples have different signal magnitudes. However, since the MNIST dataset does not have ground-truth maps, we performed a different quantitative evaluation in Section 5.5. For the ADNI dataset, we used the Normal target map described in Section 5.3.3 as the ground-truth map. NCC(+) refers to the counterfactual map for transforming a red cylinder to an orange cylinder for 3D Shapes, and AD to CN for the ADNI dataset (and vice versa for NCC(-)). Higher NCC scores denote higher similarity (i.e., better performance).
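For reference, the NCC between a generated map and a ground-truth map can be computed as a zero-mean, unit-norm correlation; the exact normalization used in the paper may differ from this sketch.

```python
import numpy as np

def ncc(map_a, map_b):
    # Flatten, center, and correlate the two maps; the result lies in [-1, 1].
    a = map_a.ravel() - map_a.mean()
    b = map_b.ravel() - map_b.mean()
    denom = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
    return float(np.dot(a, b) / denom)
```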

In Table 2, we report NCC scores for the 3D Shapes and ADNI datasets. Attribution-based methods (LRP-Z, Integrated Gradients, and DeepLIFT) tend to have lower scores since less discriminative features can be ignored in domains with high complexity, which is evident in Figure 6. Counterfactual maps from VA-GAN can only transform an input from AD to CN; thus, NCC(-) scores cannot be calculated. Our proposed BIN achieved outstanding NCC(+) scores, indicating that the addition operation is a stronger suit than the subtraction operation. A further discussion on why NCC(-) is lower is given in Section 6.

Removed components           3D Shapes            ADNI
                             NCC(+)    NCC(-)     NCC(+)    NCC(-)
                             0.271     0.259      0.211     0.088
                             0.199     0.126      0.065     0.063
                             0.482     0.336      0.253     0.125
                             0.238     0.159      0.248     0.158
                             0.089     0.068      0.088     0.082
All above                    0.033     0.064      0.076     0.053
BIN (ours)                   0.516     0.465      0.306     0.185
Table 3: Normalized Cross-Correlation (NCC) scores for the ablation studies.

5.5 Ablation Studies

In this section, we conducted a suite of ablation studies to assess each component of BIN in creating a counterfactual map. Specifically, we focused on ablation studies of the conditioned generator, the Target Attribution Network loss ($\mathcal{L}_{\mathrm{TAN}}$), the cycle consistency loss ($\mathcal{L}_{\mathrm{cyc}}$), and the counterfactual map loss ($\mathcal{L}_{\mathrm{map}}$).

Fréchet Inception Distance

For a quantitative assessment of the multi-way counterfactual maps, we used the Fréchet Inception Distance (FID) [14], which is commonly used for assessing the generative performance of GANs. Here, we selected 4,000 test samples for counterfactual map generation (i.e., fake images) and 4,000 test samples as real images. Since a counterfactual map can transform any number into any other number, a single image can generate 10 images (one per number). Thus, the total number of fake samples we generated is 40,000 (i.e., 4,000 fake images per number). To the best of our knowledge, our proposed BIN is the first to perform multi-way counterfactual map generation. Thus, we compared our work in an ablation study setting (Table 1). In most of the settings, our proposed BIN performed significantly better.
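The FID itself follows the standard closed form of [14] over Gaussian statistics of feature activations; the sketch below assumes `(N, D)` arrays of Inception-style features extracted beforehand.

```python
import numpy as np
from scipy import linalg

def fid(feats_real, feats_fake):
    # Fit a Gaussian to each feature set and compute the Frechet distance between them.
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    c1 = np.cov(feats_real, rowvar=False)
    c2 = np.cov(feats_fake, rowvar=False)
    covmean, _ = linalg.sqrtm(c1.dot(c2), disp=False)   # matrix square root of the product
    covmean = covmean.real                              # drop tiny imaginary parts
    diff = mu1 - mu2
    return float(diff.dot(diff) + np.trace(c1 + c2 - 2.0 * covmean))
```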

Normalized Cross-Correlation

For the ADNI and 3D Shapes datasets, we calculated NCC scores with the same settings as in Section 5.4 (Table 3). The reported scores indicate that every component of BIN is crucial for producing a meaningful counterfactual map. However, ablating the Target Attribution Network loss showed a significant drop in NCC scores, indicating that it is one of the most crucial components of BIN. The Target Attribution Network guides the generative process to build targeted class attribution, which works in a similar manner to a discriminator in GANs. The cycle consistency loss ensures that the counterfactual map retains information on its identity, i.e., the input sample. In a preliminary experiment, we observed that the cycle consistency loss helps generate crisper counterfactual maps, indicating that it works as a regularizer for BIN. The counterfactual map loss is another important component of BIN since it regulates which regions are most important.

6 Discussion

For most experiments, the reported NCC(-) scores for almost every compared method were lower than the NCC(+) scores. A possible explanation is that activation functions, such as ReLU, may impose a positive bias or negative skewness on the model output. For example, the mean and mode of the counterfactual maps for almost every dataset and every method compared in this paper were slightly above zero. This may impose a constraint on subtraction operations, which lowers the NCC(-) scores. Possible future work for the community is to verify whether this bias or skewness really exists and to propose a way to alleviate it.

7 Conclusion

In this work, we proposed the Born Identity Network (BIN), a post-hoc approach for producing multi-way counterfactual maps. We demonstrated that our method can be easily applied to various networks, and the fidelity of its explanations is supported by visualizing the causal relationship behind a prediction. That is, we showed that counterfactual maps can provide end-users with an intuitive explanation of classification outcomes.

8 Acknowledgments

This work was supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea Government (MSIT) (No. 2017-0-01779, A machine learning and statistical inference framework for explainable artificial intelligence). This work was also supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea Government (MSIT) (No. 2019-0-00079, Department of Artificial Intelligence (Korea University)).

References

  • [1] S. Bach, A. Binder, G. Montavon, F. Klauschen, K. Müller, and W. Samek (2015) On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. Public Library of Science ONE 10 (7), pp. e0130140. Cited by: §1, §2.1, §5.3.3, Table 2.
  • [2] C. Bass, M. da Silva, C. Sudre, P. Tudosiu, S. Smith, and E. Robinson (2020) ICAM: interpretable classification via disentangled representations and feature attribution mapping. In Advances in Neural Information Processing Systems, Vol. 33. Cited by: §2.2.
  • [3] C. F. Baumgartner, K. Kamnitsas, J. Matthew, T. P. Fletcher, S. Smith, L. M. Koch, B. Kainz, and D. Rueckert (2017) SonoNet: real-time detection and localisation of fetal standard scan planes in freehand ultrasound. IEEE Transactions on Medical Imaging 36 (11), pp. 2204–2215. Cited by: Chapter 1: Network Architecture, §5.2.
  • [4] C. F. Baumgartner, L. M. Koch, K. Can Tezcan, J. Xi Ang, and E. Konukoglu (2018) Visual feature attribution using Wasserstein GANs. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8309–8319. Cited by: §2.2, Table 2.
  • [5] C. Burgess and H. Kim (2018) 3D shapes dataset. Note: https://github.com/deepmind/3dshapes-dataset/ Cited by: Figure 4, §5.1, §5.1.
  • [6] C. Chang, E. Creager, A. Goldenberg, and D. Duvenaud (2018) Explaining image classifiers by counterfactual generation. arXiv preprint arXiv:1807.08024. Cited by: §1, §2.1.
  • [7] P. Dabkowski and Y. Gal (2017) Real time image saliency for black box classifiers. In Advances in Neural Information Processing Systems, pp. 6967–6976. Cited by: §2.1.
  • [8] S. Dash and A. Sharma (2020) Counterfactual generation and fairness evaluation using adversarially learned inference. arXiv preprint arXiv:2009.08270. Cited by: §1.
  • [9] A. Dhurandhar, P. Chen, R. Luss, C. Tu, P. Ting, K. Shanmugam, and P. Das (2018) Explanations based on the missing: towards contrastive explanations with pertinent negatives. In Advances in Neural Information Processing Systems, pp. 592–603. Cited by: §2.1.
  • [10] R. C. Fong and A. Vedaldi (2017) Interpretable explanations of black boxes by meaningful perturbation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3429–3437. Cited by: §2.1.
  • [11] L. H. Gilpin, D. Bau, B. Z. Yuan, A. Bajwa, M. Specter, and L. Kagal (2018) Explaining explanations: an overview of interpretability of machine learning. In Proceedings of the International Conference on Data Science and Advanced Analytics, pp. 80–89. Cited by: §1, §2.1.
  • [12] Y. Goyal, A. Feder, U. Shalit, and B. Kim (2019) Explaining classifiers with causal concept effect (cace). arXiv preprint arXiv:1907.07165. Cited by: §1.
  • [13] Y. Goyal, Z. Wu, J. Ernst, D. Batra, D. Parikh, and S. Lee (2019) Counterfactual visual explanations. arXiv preprint arXiv:1904.07451. Cited by: §1, §2.1.
  • [14] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, and S. Hochreiter (2017) Gans trained by a two time-scale update rule converge to a local nash equilibrium. In Advances in neural information processing systems, pp. 6626–6637. Cited by: §5.5.
  • [15] S. Hooker, D. Erhan, P. Kindermans, and B. Kim (2019) A benchmark for interpretability methods in deep neural networks. In Advances in Neural Information Processing Systems, pp. 9737–9748. Cited by: §1.
  • [16] K. Iqbal, M. Flory, S. Khatoon, H. Soininen, T. Pirttila, M. Lehtovirta, I. Alafuzoff, K. Blennow, N. Andreasen, E. Vanmechelen, et al. (2005) Subgroups of alzheimer’s disease based on cerebrospinal fluid molecular markers. Annals of Neurology: Official Journal of the American Neurological Association and the Child Neurology Society 58 (5), pp. 748–757. Cited by: §5.3.3.
  • [17] F. Isensee, M. Schell, I. Pflueger, G. Brugnara, D. Bonekamp, U. Neuberger, A. Wick, H. Schlemmer, S. Heiland, W. Wick, M. Bendszus, K. H. Maier-Hein, and P. Kickingereder (2019) Automated brain extraction of multisequence MRI using artificial neural networks. Human Brain Mapping 40 (17), pp. 4952–4964. Cited by: §5.1.
  • [18] M. Jenkinson, C. F. Beckmann, T. E. J. Behrens, M. W. Woolrich, and S. M. Smith (2012) FSL. NeuroImage 62 (2), pp. 782–790. Cited by: §5.1.
  • [19] H. Kim and A. Mnih (2018) Disentangling by factorising. arXiv preprint arXiv:1802.05983. Cited by: Chapter 1: Network Architecture, §5.2.
  • [20] Y. LeCun (1998) The mnist database of handwritten digits. http://yann. lecun. com/exdb/mnist/. Cited by: §5.1, §5.1.
  • [21] X. Mao, Q. Li, H. Xie, R. Y. Lau, Z. Wang, and S. Paul Smolley (2017) Least squares generative adversarial networks. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2794–2802. Cited by: §4.1.1.
  • [22] M. Mirza and S. Osindero (2014) Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784. Cited by: §1, §4.1.
  • [23] G. Montavon, S. Lapuschkin, A. Binder, W. Samek, and K. Müller (2017) Explaining nonlinear classification decisions with deep taylor decomposition. Pattern Recognition 65, pp. 211–222. Cited by: §1, §2.1.
  • [24] S. G. Mueller, M. W. Weiner, L. J. Thal, R. C. Petersen, C. Jack, W. Jagust, J. Q. Trojanowski, A. W. Toga, and L. Beckett (2005) The Alzheimer's Disease Neuroimaging Initiative. Neuroimaging Clinics of North America 15 (4), pp. 869–877. Cited by: §5.1.
  • [25] O. Ronneberger, P. Fischer, and T. Brox (2015) U-net: convolutional networks for biomedical image segmentation. In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 234–241. Cited by: §4.1.
  • [26] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra (2017) Grad-cam: visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626. Cited by: §1, §2.1, §5.3.3, Table 2.
  • [27] A. Shrikumar, P. Greenside, and A. Kundaje (2017) Learning important features through propagating activation differences. arXiv preprint arXiv:1704.02685. Cited by: §1, §2.1, §5.3.3, Table 2.
  • [28] K. Simonyan, A. Vedaldi, and A. Zisserman (2013) Deep inside convolutional networks: visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034. Cited by: §1.
  • [29] D. Smilkov, N. Thorat, B. Kim, F. Viégas, and M. Wattenberg (2017) Smoothgrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825. Cited by: §1.
  • [30] M. Sundararajan, A. Taly, and Q. Yan (2017) Axiomatic attribution for deep networks. arXiv preprint arXiv:1703.01365. Cited by: §1, §2.1, §5.3.3, Table 2.
  • [31] A. Van Looveren and J. Klaise (2019) Interpretable counterfactual explanations guided by prototypes. arXiv preprint arXiv:1907.02584. Cited by: §1.
  • [32] P. Wang and N. Vasconcelos (2020-06) SCOUT: self-aware discriminant counterfactual explanations. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Cited by: §1.
  • [33] K. Xu, S. Liu, P. Zhao, P. Chen, H. Zhang, Q. Fan, D. Erdogmus, Y. Wang, and X. Lin (2018) Structured adversarial attack: towards general implementation and better interpretability. arXiv preprint arXiv:1808.01664. Cited by: §4.2.1.
  • [34] M. D. Zeiler and R. Fergus (2014) Visualizing and understanding convolutional networks. In Proceedings of the European Conference on Computer Vision, pp. 818–833. Cited by: §1, §2.1.
  • [35] B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba (2016) Learning deep features for discriminative localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2921–2929. Cited by: §2.1.

Chapter 1: Network Architecture

In Tables S4, S8, and S12, each skip-connection layer is passed through a convolution layer with kernel size 3 and stride 1 after concatenation with the tiled target label. For the feature maps used in these skip connections, we used the same number of feature maps as in the counterfactual map generator.

We performed the convolutional operations in the order of convolution, batch normalization, and activation. In the 3D Shapes experiment, the encoder network structure of Kim et al. [19] was utilized, and in the ADNI experiment, the encoder network of SonoNet-16 [3] was used identically.

1-1. 3D Shapes

Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
(ENC1) 2D Conv. 32 (4 × 4) (2 × 2) ReLU
(ENC2) 2D Conv. 32 (4 × 4) (2 × 2) ReLU
(ENC3) 2D Conv. 64 (4 × 4) (2 × 2) ReLU
(ENC4) 2D Conv. 64 (4 × 4) (2 × 2) ReLU
Table S1: 3D Shapes encoder network.
Operation Feature Maps Batch Norm. Dropout Activation
Input
Fully Connected 256 ReLU
Fully Connected 2 Softmax
Table S2: 3D Shapes Classifier network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
2D Conv. 32 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 32 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 64 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 64 (4 × 4) (2 × 2) Leaky ReLU
Fully Connected 1 Linear
Table S3: 3D Shapes discriminator network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
(DEC3) Upsampling (2 × 2)
Concatenate and (DEC3) along the channel axis
2D Conv. 64 (3 × 3) (1 × 1) ReLU
(DEC2) Upsampling (2 × 2)
Concatenate and (DEC2) along the channel axis
2D Conv. 32 (3 × 3) (1 × 1) ReLU
(DEC1) Upsampling (2 × 2)
Concatenate and (DEC1) along the channel axis
2D Conv. 32 (3 × 3) (1 × 1) ReLU
2D Deconv 1 (4 × 4) (2 × 2) Tanh
Table S4: 3D Shapes counterfactual map generator network.

1-2. MNIST

Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
2D Conv. 32 (3 × 3) (1 × 1) ReLU
(ENC1) 2D Conv. 32 (4 × 4) (2 × 2) ReLU
2D Conv. 64 (3 × 3) (1 × 1) ReLU
(ENC2) 2D Conv. 64 (4 × 4) (2 × 2) ReLU
2D Conv. 128 (3 × 3) (1 × 1) ReLU
(ENC3) 2D Conv. 128 (4 × 4) (2 × 2) ReLU
2D Conv. 256 (3 × 3) (1 × 1) ReLU
(ENC4) 2D Conv. 256 (4 × 4) (2 × 2) ReLU
Table S5: MNIST encoder network.
Operation Feature Maps Batch Norm. Dropout Activation
Input
Fully Connected 128 0.5 ReLU
Fully Connected 10 0.25 Softmax
Table S6: MNIST Classifier network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
2D Conv. 32 (3 × 3) (1 × 1) Leaky ReLU
2D Conv. 32 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 64 (3 × 3) (1 × 1) Leaky ReLU
2D Conv. 64 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 128 (3 × 3) (1 × 1) Leaky ReLU
2D Conv. 128 (4 × 4) (2 × 2) Leaky ReLU
2D Conv. 256 (3 × 3) (1 × 1) Leaky ReLU
2D Conv. 256 (4 × 4) (2 × 2) Leaky ReLU
Fully Connected 1 Linear
Table S7: MNIST discriminator network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
Upsampling (2 × 2)
(DEC3) 2D Conv. 128 (3 × 3) (1 × 1) ReLU
Concatenate and (DEC3) along the channel axis
2D Conv. 128 (3 × 3) (1 × 1) ReLU
Upsampling (2 × 2)
(DEC2) 2D Conv. 64 (2 × 2) (1 × 1) ReLU
Concatenate and (DEC2) along the channel axis
2D Conv. 64 (3 × 3) (1 × 1) ReLU
Upsampling (2 × 2)
(DEC1) 2D Conv. 32 (3 × 3) (1 × 1) ReLU
Concatenate and (DEC1) along the channel axis
2D Conv. 32 (3 × 3) (1 × 1) ReLU
2D Deconv 1 (4 × 4) (2 × 2) Tanh
Table S8: MNIST counterfactual map generator network.

1-3. ADNI

Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) ReLU
(ENC1) 3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) ReLU
(ENC2) 3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
(ENC3) 3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
(ENC4) 3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
Table S9: ADNI encoder network.
Operation Feature Maps Batch Norm. Dropout Activation
Input
Fully Connected 256 0.25 ReLU
Fully Connected 32 0.25 ReLU
Fully Connected 2 0.25 Softmax
Table S10: ADNI Classifier network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
Max pooling (2 × 2 × 2) (2 × 2 × 2)
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) Leaky ReLU
Fully Connected 1 Linear
Table S11: ADNI discriminator network.
Operation Feature Maps Batch Norm. Kernels Strides Padding Activation
Input
(DEC4) Upsampling (2 × 2 × 2)
Concatenate and (DEC4) along the channel axis
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 128 (3 × 3 × 3) (1 × 1 × 1) ReLU
(DEC3) Upsampling (2 × 2 × 2)
Concatenate and (DEC3) along the channel axis
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 64 (3 × 3 × 3) (1 × 1 × 1) ReLU
(DEC2) Upsampling (2 × 2 × 2)
Concatenate and (DEC2) along the channel axis
3D Deconv 32 (1 × 2 × 1) (1 × 1 × 1) ReLU
3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 32 (3 × 3 × 3) (1 × 1 × 1) ReLU
(DEC1) Upsampling (2 × 2 × 2)
Concatenate and (DEC1) along the channel axis
3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 16 (3 × 3 × 3) (1 × 1 × 1) ReLU
3D Conv. 1 (3 × 3 × 3) (1 × 1 × 1) Tanh
Table S12: ADNI counterfactual map generator network.

1-4. Hyperparameters

Optimizer Adam
Epochs 50
Batch Size 128
Learning Rate 0.0005
Exponential Decay Rate 0.98
One-sided Label Smoothing 0.1
Weight Constants
Table S13: 3D Shapes model hyperparameters.
Optimizer Adam
Epochs 100
Batch Size 256
Learning Rate 0.001
Exponential Decay Rate 0.99
One-sided Label Smoothing 0.1
Weight Constants
Table S14: MNIST model hyperparameters.
Optimizer Adam
Epochs 150
Batch Size 3
Learning Rate 0.001
Exponential Decay Rate 0.99
One-sided Label Smoothing 0.1
Weight Constants
Table S15: ADNI model hyperparameters.

Chapter 2: Brain Interpolation

Figure S1: Example of a counterfactual map conditioned on interpolated target labels. The purple box visualizes the ventricular regions. The one-hot vectors at the bottom of each counterfactual map are the target labels. The scores at the bottom of the AD subject and the CN subject are the model's softmax-activated logits.