Convolutional neural networks (CNNs) are one class of black-box machine learning algorithms. Despite their popularity in many applied fields, such as medical and legal settings, their lack of causal interpretability often spurs “trust issues” and criticism from the broader research community and from front-line users. Frameworks employed in such settings are often required to be causal and interpretable. Consider the following scenario: a CNN is deployed in a geriatric outpatient clinic to predict cardiac arrest from electrocardiograms (ECGs). A false negative means that the algorithm misses a cardiac arrest event, which could lead to catastrophic outcomes for the patient. Naturally, the first reaction in this circumstance is to ask: why does the algorithm fail to classify certain cases? Such scenarios are common across fields, yet a satisfying answer to this “why” question has rarely been provided. Instead, researchers often focus on enhancing algorithm performance with the goal of increasing prediction accuracy, so insights from misclassified results are often overlooked or understudied. Investigating the causal interpretability of misclassifications can, first, give a direct answer to the “why” question above; second, estimate the causal effects attributable to the input data and to the model architecture; and finally, provide an opportunity to modify the potential causes based on evidence and thereby increase prediction accuracy.
2 Related Work
Only in recent years has black-box model interpretability gained wide attention. Traditional interpretability frameworks can be classified into two categories: (1) inherently interpretable models that generate explanations during the training process, such as attention networks and disentangled representation learning; and (2) post-hoc methods that first generate an abstraction of the learned concepts and then map it onto a more interpretable domain, such as saliency maps and example-based explanations. Such frameworks focus only on feature importance at face value and its contribution to predictive accuracy, which are correlation-based considerations. Causal interpretation, by contrast, concerns questions such as how altering the input features or a component of the architecture changes the prediction results.
Existing causal inference frameworks can generally be classified into model-based and sample-based causal inference. The former answers questions such as “What is the impact of the j-th filter of the i-th layer?”, whereas the latter explains the causal mechanism through generated counterfactual images [3, 5]. Other efforts include treatment-effect frameworks, global explanation frameworks, and algorithmic fairness metrics [2, 7, 6].
Unlike the methods above, our causal framework takes a “backward” approach by investigating the potential causes of misclassification in six modern CNN architectures (VGG16, ResNet50, GoogLeNet, DenseNet161, MobileNet_V2, and Inception_V3), and it strives to answer three questions: (1) Is there a generalizable (invariant) pattern of misclassifications across all six architectures? (2) Can we deduce testable causal hypotheses from the observed patterns? (3) By precisely modifying the identified causes, can we reduce the number of original misclassifications without introducing new ones? The input of this study is an image and the output is the class label of that image. Details about the dataset are given in Section 4.
Obtain Misclassified Samples We first train VGG16, ResNet50, GoogLeNet, DenseNet161, MobileNet_V2, and Inception_V3 on the training images, then classify the test set using the trained networks. Here we briefly describe the distinguishing features of each selected architecture. Compared with earlier architectures, VGG16 uses smaller (3×3) convolutional filters but more layers; ResNet50 includes residual modules that enable even deeper networks without information loss; GoogLeNet uses efficient “inception” modules and largely replaces the fully-connected layers with global average pooling to save memory; DenseNet161 connects each layer to every subsequent layer in the forward path to increase training efficiency; MobileNet_V2 uses thin “bottleneck” layers for the input and output of the residual module; finally, Inception_V3 incorporates many of the previously mentioned ideas into one model. The input layers of each model are modified to be compatible with CIFAR-10 images. All training is performed using PyTorch Lightning.
A consistent set of hyperparameters is used to train each model, with the goal of achieving greater than 90% accuracy on CIFAR-10. The following hyperparameters were used for training each model:
The training and validation curves are included in Appendix Figure 7.
After testing, we record every misclassified image together with its correct class and its misclassified (predicted) class. Two descriptive statistics are calculated: the class-wise misclassification rate $r_i$ and the conditional misclassification rate $p_{j\mid i}$, where $i$ is the correct image class and $j \neq i$ is the misclassified class. Thus we have

$$r_i = \frac{m_i}{n_i}, \qquad p_{j\mid i} = \frac{m_{ij}}{m_i}, \qquad m_i = \sum_{j \neq i} m_{ij},$$

where $n_i$ is the number of images in class $i$, $m_i$ is the total number of images in class $i$ that are misclassified into other classes, and $m_{ij}$ is the number of images from class $i$ that are misclassified as class $j$. In other words, $r_i$ is the marginal probability that an image from class $i$ gets misclassified, and $p_{j\mid i}$ is the conditional probability that, given an image from class $i$ is misclassified, it is classified into class $j$.
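Both rates can be computed directly from the recorded test labels and predictions. The following is a minimal pure-Python sketch, assuming labels and predictions are integer class indices; the function name `misclassification_rates` is hypothetical and not part of the authors' code.

```python
from collections import Counter


def misclassification_rates(y_true, y_pred, num_classes=10):
    """Return (r, p) from integer labels and predictions.

    r[i]    : fraction of class-i images that are misclassified.
    p[i][j] : among misclassified class-i images, fraction predicted as j.
    """
    n = Counter(y_true)  # number of images per true class
    m = [[0] * num_classes for _ in range(num_classes)]  # m[i][j]: class i predicted as j
    for t, s in zip(y_true, y_pred):
        if t != s:
            m[t][s] += 1
    r, p = [], []
    for i in range(num_classes):
        m_i = sum(m[i])  # all misclassified images of class i
        r.append(m_i / n[i] if n[i] else 0.0)
        p.append([m[i][j] / m_i if m_i else 0.0 for j in range(num_classes)])
    return r, p
```

Classes with no misclassified images get an all-zero row in `p` rather than a division by zero.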
Extract Misclassification Patterns Using the conditional probabilities $p_{j\mid i}$ (the probability that a misclassified image of class $i$ is assigned to class $j$), a 10 by 10 confusion heatmap is constructed to capture the misclassification patterns of every CNN architecture. We first qualitatively assess the types and distributions of the misclassifications, then summarize the observed patterns within each heatmap and compare them across all architectures. We also build a misclassification network in which each node represents a class and each directed edge points from the correct class to the misclassified class. The weighted in-degree of each node $j$ is calculated as

$$d_j = \sum_{i \neq j} \mathbb{1}\{i \to j\}\, p_{j\mid i},$$

where $\mathbb{1}\{i \to j\}$ equals 1 if the network contains an edge from class $i$ to class $j$ and 0 otherwise. That is, the in-degree of class $j$ is the sum over the other classes misclassified into $j$, each weighted by its conditional probability $p_{j\mid i}$. The in-degrees of the nodes are also used to compare misclassification patterns across all network architectures.
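A minimal sketch of the misclassification network and its weighted in-degrees, assuming the conditional rates are stored as a nested list `p` with `p[i][j]` the rate at which misclassified images of class `i` are assigned to class `j`. It also assumes an edge is kept when that rate is at least a threshold; 0.3 is the value used for Figure 5, but whether the comparison is strict is not stated in the text. The function name is hypothetical.

```python
def weighted_in_degrees(p, threshold=0.3):
    """Build the thresholded misclassification network and its in-degrees.

    Returns (edges, d): directed edges (i, j) meaning "class i misclassified
    as class j", kept when p[i][j] >= threshold, and d[j], the sum of p[i][j]
    over all edges pointing into class j.
    """
    k = len(p)
    edges = [(i, j) for i in range(k) for j in range(k)
             if i != j and p[i][j] >= threshold]
    d = [0.0] * k
    for i, j in edges:
        d[j] += p[i][j]  # each incoming edge weighted by its conditional rate
    return edges, d
```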
Causal Hypothesis To generate the causal hypotheses, we gather evidence from three aspects: 1) the misclassification results of the six selected architectures; 2) the common misclassification structures across all heatmaps; 3) the distributions of the scores of misclassified classes and the summary statistics defined above (class-wise and conditional misclassification rates and weighted in-degrees). The testable causal hypotheses focus on two potential causes of misclassification: innate inter-class distances and non-essential information interference. The former concerns morphological similarities between misclassified classes that could also confuse human classifiers. The latter describes the phenomenon in which pixels outside the target object interfere with, and dominate, the classification of that object. The first cause may make up the majority of the irreducible misclassification error and could be harder to improve upon than the second. Besides differentiating the two causes qualitatively, this study also attempts to test the difference statistically across all selected architectures.
Based on the premises above, the causal questions we pose are: 1) Does innate morphological similarity consistently cause misclassifications across all selected architectures? 2) Does non-essential information interference cause misclassifications across all selected architectures? 3) Can we qualitatively and quantitatively differentiate the two causes of misclassification? 4) Can we correct non-essential information interference to reduce misclassifications?
Hypothesis Testing The first two questions concern two aspects: is there a specific cause of misclassification, and is this cause model dependent? To answer them, we employ human domain knowledge of each class and incorporate evidence from the confusion heatmaps and misclassification networks. The two causes are therefore categorized qualitatively, and we apply the same categorization across all six models to assess model dependence.
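The answers to questions 1) and 2) provide the first half of the answer to question 3). To differentiate the two causes quantitatively, we first categorize the misclassifications qualitatively into the two types, then compute the difference between the scores of the correct class and the misclassified class. Because the sample sizes are relatively small, we base inference on Student's t-distribution, whose density function is

$$f(t) = \frac{\Gamma\!\left(\frac{\nu+1}{2}\right)}{\sqrt{\nu\pi}\,\Gamma\!\left(\frac{\nu}{2}\right)} \left(1 + \frac{t^2}{\nu}\right)^{-\frac{\nu+1}{2}},$$

where $\nu$ is the number of degrees of freedom, $\Gamma$ is the gamma function, and $t$ is the random variable (here, the standardized score difference). Let $n_1$ and $n_2$ denote the sample sizes of the two types of misclassifications, and assume the two sets of scores have similar variances. We may then construct the pooled t-statistic

$$t = \frac{\bar{x}_1 - \bar{x}_2}{s_p\sqrt{\frac{1}{n_1} + \frac{1}{n_2}}}, \qquad s_p = \sqrt{\frac{(n_1 - 1)s_1^2 + (n_2 - 1)s_2^2}{n_1 + n_2 - 2}},$$

where $\bar{x}_1$ and $\bar{x}_2$ are the mean score differences of the two types, $s_1$ and $s_2$ are the standard deviations of the two score distributions, and the number of degrees of freedom is $\nu = n_1 + n_2 - 2$.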
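The density and the pooled statistic can be checked numerically. This sketch uses only the Python standard library (`math.lgamma` for the log-gamma function, which avoids overflow for large degrees of freedom); the helper names are hypothetical. In practice, `scipy.stats.ttest_ind(x1, x2, equal_var=True)` computes the same pooled statistic together with a p-value.

```python
import math


def t_density(t, nu):
    """Student t probability density with nu degrees of freedom."""
    log_c = (math.lgamma((nu + 1) / 2) - math.lgamma(nu / 2)
             - 0.5 * math.log(nu * math.pi))  # log of the normalizing constant
    return math.exp(log_c) * (1 + t * t / nu) ** (-(nu + 1) / 2)


def pooled_t_statistic(x1, x2):
    """Two-sample t-statistic assuming equal variances; returns (t, df)."""
    n1, n2 = len(x1), len(x2)
    mean1, mean2 = sum(x1) / n1, sum(x2) / n2
    var1 = sum((v - mean1) ** 2 for v in x1) / (n1 - 1)  # sample variance s_1^2
    var2 = sum((v - mean2) ** 2 for v in x2) / (n2 - 1)  # sample variance s_2^2
    df = n1 + n2 - 2
    sp = math.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / df)  # pooled SD s_p
    t = (mean1 - mean2) / (sp * math.sqrt(1 / n1 + 1 / n2))
    return t, df
```

A quick sanity check: with one degree of freedom the t-distribution is the Cauchy distribution, so `t_density(0.0, 1)` equals 1/π.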
To answer question 4), we first need to show that non-essential information interference exists in the image to be classified. We then “surgically” remove the interference, have the model reclassify the modified image, and record the classification result to observe any decrease in misclassifications. This step corresponds to the $do(\cdot)$ operation (intervention) in causal inference, where we modify the cause and observe the potential outcomes both before and after the intervention. To this end, we first extract the saliency maps of the misclassified images attributed to non-essential information interference. Under our hypothesis, the saliency maps should highlight peripheral regions outside the target objects that drive the model's classifications. We then modify the original image: we select the top 5% of pixels on the saliency map, generate bounding boxes of a chosen width and height around the selected pixels, set the pixels within the bounding boxes in the original image to 0, and reclassify the modified image. The driving goal is to make the minimum image alteration that achieves the largest reduction in misclassification. The top percentage of saliency-map pixels and the width and height of the bounding boxes are hyperparameters that could be tuned to optimize this process; this study takes only the first step of exploring the efficacy of the idea, and detailed tuning is left to future work.
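The occlusion step can be sketched on plain nested lists. This is an illustration under stated assumptions, not the authors' implementation: the saliency map is taken as given (for example from a gradient-based attribution method such as Captum's `Saliency`), each box is centered on a selected pixel and clipped at the image border, and the 3×3 default box size is a hypothetical choice because the text does not specify the width and height.

```python
import math


def occlude_top_salient(image, saliency, top_frac=0.05, box_w=3, box_h=3):
    """Return a copy of `image` with boxes around the most salient pixels set to 0.

    image        : list of C channels, each an H x W nested list of floats
    saliency     : H x W nested list of saliency values (higher = more important)
    top_frac     : fraction of pixels to select (0.05 = top 5%)
    box_w, box_h : box size in pixels (hypothetical defaults)
    """
    h, w = len(saliency), len(saliency[0])
    k = max(1, math.ceil(top_frac * h * w))  # number of selected pixels
    ranked = sorted(
        ((saliency[r][c], r, c) for r in range(h) for c in range(w)), reverse=True
    )
    out = [[row[:] for row in ch] for ch in image]  # deep copy; input untouched
    for _, r, c in ranked[:k]:
        # Box centered on (r, c); parts falling outside the image are clipped.
        r0, c0 = r - box_h // 2, c - box_w // 2
        for rr in range(max(0, r0), min(h, r0 + box_h)):
            for cc in range(max(0, c0), min(w, c0 + box_w)):
                for ch in out:
                    ch[rr][cc] = 0.0
    return out
```

For a 32×32 CIFAR-10 image, `top_frac=0.05` selects 52 pixels; the returned image would then be passed back to the trained network for reclassification.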
Comparison to Baseline Methods We compare the CIFAR-10 training and test results of all included networks with state-of-the-art results from the machine learning community. Since causal inference is a relatively new and quickly evolving topic in machine learning, we were not able to identify comparable baseline methods in this regard. However, we will continue a rolling literature review and, if comparable studies emerge, include the additional comparison in our final paper.
4 Dataset and Features
This study uses the CIFAR-10 dataset, which contains 60,000 32x32 color images from 10 classes (6,000 images per class). 10,000 of these images are held out as a test set, with 1,000 test images per class. The 10 classes are planes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. Figure 1 shows sample images from the dataset. Each image is channel-wise normalized as follows:

$$\tilde{x}^{(k)}_{c,h,w} = \frac{x^{(k)}_{c,h,w} - \mu_c}{\sigma_c}, \qquad c \in \{R, G, B\},$$

where $c$ denotes the three color channels of the image, $k$ indexes the training images, $\mu_c$ and $\sigma_c$ are the mean and standard deviation of channel $c$, and $x^{(k)}_{c,h,w}$ and $\tilde{x}^{(k)}_{c,h,w}$ are the pixel values at position $(h, w)$ of channel $c$ before and after normalization, respectively. Figure 2 shows an example image before and after normalization.
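A minimal sketch of the channel-wise normalization, assuming images are stored as C × H × W nested lists and that the per-channel statistics are computed over all training pixels with the population standard deviation (the text does not say which estimator was used). The helper names are hypothetical; `torchvision.transforms.Normalize(mean, std)` applies the same per-channel transform to tensors.

```python
import math


def channel_stats(images):
    """Per-channel mean and population standard deviation over a list of
    C x H x W images given as nested lists."""
    means, stds = [], []
    for ch in range(len(images[0])):
        vals = [v for img in images for row in img[ch] for v in row]
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        means.append(mu)
        stds.append(math.sqrt(var))
    return means, stds


def normalize(image, means, stds):
    """Apply (x - mu_c) / sigma_c to every pixel of channel c."""
    return [[[(v - means[ch]) / stds[ch] for v in row] for row in image[ch]]
            for ch in range(len(image))]
```

The statistics are computed once on the training set and then reused unchanged for the test images.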
5.1 Classification Performance Comparisons
Papers on state-of-the-art architectures do not conventionally report class-wise misclassification rates, as these are rarely the quantity of interest; hence we cannot compare our self-trained class-wise misclassification rates to baseline models. However, the overall accuracy of the architectures under investigation is comparable to state-of-the-art performance on CIFAR-10. At the time of writing, the top-performing architecture on CIFAR-10 (without extra training data) is PyramidNet, which achieved 98.5% accuracy on the test set. By comparison, the top-performing model in this study, Inception_V3, achieved 93.3% on the test set. Figure 3 summarizes the misclassification rates by class and by trained architecture. We conducted a chi-squared heterogeneity test to examine the misclassification distributions across the six models. The p-value is close to 1, so the test finds no evidence of heterogeneity: the class-wise misclassification rates are consistent with being homogeneous across all architectures.
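For reference, a Pearson chi-squared statistic for homogeneity can be computed as follows. This sketch assumes the test is run on a table of misclassification counts with one row per model and one column per class, with no empty row or column; the text does not state the exact table, and the helper name is hypothetical. The p-value can then be obtained from the chi-squared survival function, e.g. `scipy.stats.chi2.sf(stat, df)`.

```python
def chi2_homogeneity(table):
    """Pearson chi-squared statistic and degrees of freedom for an
    r x c table of counts (rows: models, columns: classes)."""
    row_tot = [sum(row) for row in table]
    col_tot = [sum(col) for col in zip(*table)]
    n = sum(row_tot)
    stat = 0.0
    for i, row in enumerate(table):
        for j, obs in enumerate(row):
            expected = row_tot[i] * col_tot[j] / n  # expected count under homogeneity
            stat += (obs - expected) ** 2 / expected
    df = (len(table) - 1) * (len(table[0]) - 1)
    return stat, df
```

For a 6 × 10 table the test has 45 degrees of freedom.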
5.2 Misclassification Patterns
As indicated by the homogeneity test, all six architectures share similar misclassification patterns. Figure 4 presents the conditional misclassification rates by class for all six networks. An obvious pattern across all six models is the cat-dog misclassification, where dogs are most likely misclassified as cats and cats are likely misclassified as dogs. A similar pattern is the car-truck misclassification. Note that these two patterns are symmetric; in other words, the two classes “trade” misclassified images. The third pattern is the plane-ship misclassification. Here the symmetry is observed only in ResNet50, GoogLeNet, and DenseNet161; in VGG16, MobileNet_V2, and Inception_V3, it does not persist. However, all three of these networks are more likely to classify ships as planes than the other way around.
To visualize the symmetry and calculate the weighted in-degrees, we constructed a misclassification network (Figure 5) for each model, using 0.3 as the threshold on the misclassification rate. From Figure 5, we observe two universal patterns across all six models: 1) the symmetric cat-dog and car-truck misclassifications; 2) the asymmetric frog-to-cat misclassification, as in all six models frogs are more likely to be categorized as cats than vice versa. The third pattern shown in the network is the inconsistent ship-plane misclassification mentioned above. The remaining non-universal patterns are all asymmetric. Hence we may cautiously conclude that symmetric patterns are not model dependent, whereas asymmetric patterns carry no such guarantee.
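Once the network is thresholded, the symmetric/asymmetric distinction can be made mechanical. A minimal sketch, assuming conditional rates are stored as a nested list with `p[i][j]` the rate at which misclassified class-`i` images go to class `j`, and an inclusive 0.3 threshold; the function name is hypothetical. An asymmetric pair is keyed by its surviving direction, so `(i, j)` means images of class `i` are misclassified as class `j`.

```python
def classify_patterns(p, threshold=0.3):
    """Label each class pair with at least one edge above the threshold as
    'symmetric' (edges in both directions) or 'asymmetric' (one direction)."""
    k = len(p)
    patterns = {}
    for i in range(k):
        for j in range(i + 1, k):
            fwd, bwd = p[i][j] >= threshold, p[j][i] >= threshold
            if fwd and bwd:
                patterns[(i, j)] = "symmetric"
            elif fwd:
                patterns[(i, j)] = "asymmetric"  # i -> j only
            elif bwd:
                patterns[(j, i)] = "asymmetric"  # j -> i only
    return patterns
```

Running this per model and intersecting the results would give the patterns shared by all six networks.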
5.3 Two Causal Hypotheses and Testing
Based on the evidence presented above, we make the following two hypotheses: 1) the two symmetric patterns, cat-dog and car-truck, are potentially caused by innate morphological similarities between the paired classes; 2) the asymmetric patterns, frog-to-cat and ship-to-plane, are potentially caused by non-essential information interference.
To explore the two causes of misclassification, we compared the distributions of the score ratios (correct class / misclassified class) of the two types of misclassifications (morphology vs. interference). The t-test results are not significant, as the score ratio distributions are similar for the two categories. Histograms of the score ratio distributions are included in Appendix Figure 8.
We therefore conclude that although the two types of causes are easily differentiated qualitatively, the metrics chosen in this study cannot differentiate them quantitatively.
5.4 Cause Modification and Reclassification
Due to time constraints, we were only able to conduct the cause-modification experiment using VGG16. Figure 6 provides two demonstration examples of this experiment. As stated previously, we hypothesize that the ship-to-plane misclassification is caused by non-essential information interference; hence, both examples are ship images initially classified as planes. To verify that interference causes these misclassifications, we first generate the saliency maps of the two demo images. The saliency maps in the middle column indicate that the classification is driven largely by the image background rather than by the target object, which is the ship in both examples. We then apply the method proposed in the Method section to construct bounding boxes around the top 5% of the saliency-map pixels; pixels in the top 5% that belong to the target object are spared from this operation. Next, we set the pixels within the bounding boxes to 0 in the original images and rerun VGG16 to reclassify them. With this operation, both images are correctly classified. The idea is similar to image segmentation and background removal, but it does not require a perfect segmentation to achieve the correct classification; moreover, segmentation is time-consuming and computationally expensive. The driving goal is to apply the minimum modification needed to obtain the correct classification.
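The before/after comparison behind question (3) in the introduction can be summarized with three counts. A minimal sketch, with hypothetical names, assuming integer labels and two prediction lists, one from the original images and one from the modified images:

```python
def intervention_effect(y_true, pred_before, pred_after):
    """Compare predictions before and after the image modification.

    fixed      : images misclassified before and correct after
    introduced : images correct before and misclassified after
    net_change : change in the total number of misclassified images
    """
    triples = list(zip(y_true, pred_before, pred_after))
    fixed = sum(1 for t, b, a in triples if b != t and a == t)
    introduced = sum(1 for t, b, a in triples if b == t and a != t)
    before = sum(1 for t, b, _ in triples if b != t)
    after = sum(1 for t, _, a in triples if a != t)
    return {"fixed": fixed, "introduced": introduced, "net_change": after - before}
```

Reporting `introduced` separately matters because a net reduction alone could hide new errors created by the occlusion.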
This study systematically categorizes the causes of misclassification into morphological similarity and non-essential information interference. Both categories of misclassification are observed across all six selected architectures. Although the two causes are intuitive to differentiate qualitatively, the metrics chosen in this study do not provide a quantitative differentiation. Interestingly, misclassifications caused by morphological similarity tend to be symmetric and not model dependent, whereas misclassifications caused by non-essential information interference tend to be inconsistent across models.
Saliency maps support the presence of non-essential information interference in the corresponding misclassifications, and in the demonstration examples, directly modifying the cause corrected the misclassification. The size of the bounding box and the top percentage of saliency-map pixels are treated as hyperparameters; the goal is to make the smallest change to the original image while maximizing the reduction in misclassification.
Future directions for this study include choosing an appropriate metric to quantitatively differentiate misclassifications due to different causes, exploring the generalizability of the proposed cause-modification method to further reduce misclassifications caused by non-essential information interference, investigating to what extent misclassifications caused by morphological similarity are reducible, and exploring the potential interaction between the two causes.
We sincerely thank the CS231N teaching staff, especially Yi Wen, who provided constructive feedback on our project. The project was conducted amid the COVID-19 pandemic and the nationwide Black Lives Matter movement; we send our condolences to the victims of these crises and thank the Stanford community for upholding tremendous solidarity in such difficult times.
The two authors contributed equally to the project. The authors declare no conflicts of interest.
-  Harris et al. Understanding and enhancing mixed sample data augmentation. arXiv:2002.12047, 2020.
-  Kim et al. Learning interpretable models with causal guarantees. arXiv:1901.08576, 2019.
-  Mothilal et al. Explaining machine learning classifiers through diverse counterfactual explanations. arXiv:1905.07697, 2019.
-  Moraffah et al. Causal interpretability for machine learning - problems, methods and evaluation. ACM SIGKDD, 22(1), 2020.
-  Narendra et al. Explaining deep learning models using causal inference. arXiv:1811.04376, 2018.
-  Ribeiro et al. Model-agnostic interpretability of machine learning. arXiv:1606.05386, 2016.
-  Zhang et al. Fairness in decision-making: the causal explanation formula. AAAI Publications, Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
-  Zhao et al. Causal interpretations of black-box models. Journal of Business and Economic Statistics, 2019.
-  W. A. Falcon. PyTorch Lightning. GitHub, https://github.com/williamFalcon/pytorch-lightning, 2019.
-  Kokhlikyan et al. PyTorch Captum, 2019.