Explaining Models by Propagating Shapley Values of Local Components

11/27/2019 ∙ by Hugh Chen, et al.

In healthcare, making the best possible predictions with complex models (e.g., neural networks, ensembles/stacks of different models) can impact patient welfare. To make these complex models explainable, we present DeepSHAP for mixed model types, a framework for layer-wise propagation of Shapley values that builds upon DeepLIFT (an existing approach for explaining neural networks). We show that, in addition to explaining neural networks, this new framework naturally enables attributions for stacks of mixed models (e.g., a neural network feature extractor fed into a tree model) as well as attributions of the loss. Finally, we theoretically justify a method for obtaining attributions with respect to a background distribution (under a Shapley value framework).


Introduction

Neural networks and ensembles of models are currently used across many domains. For these complex models, explanations of how features relate to predictions are often desirable and at times mandatory [goodman2017european]. In medicine, explainable AI (XAI) is important for scientific discovery, transparency, and much more [holzinger2017we]. One popular class of XAI methods is per-sample feature attributions (i.e., a value for each feature for a given prediction).

In this paper, we focus on SHAP values [lundberg2017unified]: Shapley values [shapley1953value] with a conditional expectation of the model prediction as the set function. Shapley values are the only additive feature attribution method that satisfies the desirable properties of local accuracy, missingness, and consistency. To approximate SHAP values for neural networks, we fix a problem in the original formulation of DeepSHAP [lundberg2017unified], which used $\mathbb{E}[x]$ as the reference, and we theoretically justify a new method to create explanations relative to background distributions. Furthermore, we extend DeepSHAP to explain stacks of mixed model types as well as loss functions rather than margin outputs.

Popular model-agnostic explanation methods that also aim to obtain SHAP values are KernelSHAP [lundberg2017unified] and IME [vstrumbelj2014explaining]. The downside of most model-agnostic methods is that they are sampling-based and consequently either high-variance or slow.

Alternatively, local feature attributions targeted to deep networks have been addressed in numerous works: Occlusion [zeiler2014visualizing], Saliency Maps [simonyan2013deep], Layer-Wise Relevance Propagation [bach2015pixel], DeepLIFT [deeplift], Integrated Gradients (IG) [sundararajan2017axiomatic], and Generalized Integrated Gradients (GIG) [gig].

Of these methods, IG and GIG are the ones connected to Shapley values. IG integrates gradients along a path between a baseline and the sample being explained, and the resulting explanation approaches the Aumann-Shapley value. GIG generalizes IG to explain losses and mixed model types, a capability DeepSHAP also aims to provide. IG and GIG have two downsides: (1) integrating along a path can be expensive or imprecise, and (2) Aumann-Shapley values fundamentally differ from the SHAP values we aim to approximate. Finally, DASP [ancona2019explaining] is an approach that approximates SHAP values for deep networks. It works by replacing point activations at all layers with probability distributions and requires many more model evaluations than DeepSHAP. Because DASP aims to obtain the same SHAP values as DeepSHAP, it could be used as part of the DeepSHAP framework.

Approach

Propagating SHAP values

Figure 1: Visualization of models for understanding DeepLIFT’s connection to SHAP values. In the figure, $g$ is a non-linear function and $T$ is a non-differentiable tree model.

DeepSHAP builds upon DeepLIFT; in this section we aim to better understand how DeepLIFT’s rules connect to SHAP values. This has been briefly touched upon in [deeplift] and [lundberg2017unified], but here we explicitly define the relationship.

DeepSHAP explains a sample (the foreground sample) by setting features to be “missing”. Missing features are set to the corresponding values in a baseline sample (the background sample). DeepSHAP generally uses a background distribution, but focusing on a single background sample is sufficient because we can rewrite the SHAP values as an average over attributions with respect to one background sample at a time (see the next section for details). In this section, we denote the foreground sample's features by $x^f$ and its neuron values (obtained by a forward pass) by $h^f$, and the background sample's features and neuron values by $x^b$ and $h^b$. Finally, we write $\phi(\cdot)$ for attribution values.

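If our model is fully linear, as in Figure 1a, we can get the exact SHAP value for an input $x_i$ by summing the attributions along all possible paths between that input and the model's output $y$, so it suffices to focus on a single path (in blue). The path's contribution to $\phi(x_i)$ is exactly the product of the weights along the path and the difference $x^f_i - x^b_i$, because the layers of linear equations in 1a can be rewritten as a single linear equation. For instance, in a network with a single hidden layer, writing $w^{(1)}_{j1}$ for the weight from $x_1$ to hidden node $h_j$ and $w^{(2)}_j$ for the weight from $h_j$ to $y$, the path $x_1 \rightarrow h_1 \rightarrow y$ contributes $w^{(2)}_1 w^{(1)}_{11} (x^f_1 - x^b_1)$. We can also derive the attribution for $x_1$ in terms of the attributions of intermediary nodes (as in the chain rule):

$$\phi(x_1) = \sum_j \frac{\phi(h_j)}{h^f_j - h^b_j}\, w^{(1)}_{j1} \left(x^f_1 - x^b_1\right) \tag{1}$$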
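As a minimal numeric check, the sketch below assumes a hypothetical one-hidden-layer linear network with made-up weights (not the network drawn in Figure 1a). It compares summing weight products along every path with applying Equation (1) through the hidden nodes; both should give the same attributions and sum to $y^f - y^b$. The helper names are illustrative only.

```python
import math

# Hypothetical linear network: h_j = sum_i W1[j][i] * x_i, y = sum_j w2[j] * h_j
W1 = [[1.0, -2.0], [0.5, 3.0]]  # W1[j][i] = weight from input x_i to hidden node h_j
w2 = [2.0, -1.0]                # w2[j] = weight from hidden node h_j to output y

def forward(x):
    h = [sum(W1[j][i] * x[i] for i in range(len(x))) for j in range(len(W1))]
    return h, sum(w2[j] * h[j] for j in range(len(h)))

def path_attributions(xf, xb):
    """Sum over paths x_i -> h_j -> y of w2[j] * W1[j][i] * (xf_i - xb_i)."""
    return [sum(w2[j] * W1[j][i] for j in range(len(W1))) * (xf[i] - xb[i])
            for i in range(len(xf))]

def chain_rule_attributions(xf, xb):
    """Equation (1): rescale phi(h_j) by the share of h_j's change due to x_i."""
    hf, _ = forward(xf)
    hb, _ = forward(xb)
    phi_h = [w2[j] * (hf[j] - hb[j]) for j in range(len(hf))]  # exact for a linear output
    phi_x = [0.0] * len(xf)
    for j in range(len(hf)):
        dh = hf[j] - hb[j]
        if dh == 0:
            continue  # skipped for simplicity; the example below avoids this case
        for i in range(len(xf)):
            phi_x[i] += phi_h[j] / dh * W1[j][i] * (xf[i] - xb[i])
    return phi_x

xf, xb = [1.0, 2.0], [0.0, 0.0]
print(path_attributions(xf, xb))        # [1.5, -14.0]
print(chain_rule_attributions(xf, xb))  # [1.5, -14.0]
```

For a linear model with a single reference, these path sums are the exact SHAP values, which is why the two routes agree.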

Next, we reinterpret the two variants of DeepLIFT: the Rescale rule and the RevealCancel rule. A gradient-based interpretation of the Rescale rule has been discussed in [ancona2018towards]; here, we explicitly tie this interpretation to the SHAP values we hope to obtain.

For clarity, consider the example in Figure 1b, where $y = g(h)$ for a non-linear function $g$ and $h = \sum_i w_i x_i$. First, the attribution for $h$ is $\phi(h) = g(h^f) - g(h^b)$, because SHAP values maintain local accuracy (the sum of attributions equals $y^f - y^b$) and $g$ is a function with a single input. Then, under the Rescale rule, $\phi(x_i) = \frac{\phi(h)}{h^f - h^b}\, w_i (x^f_i - x^b_i)$ (note the resemblance to Equation (1)). Under this formulation, it is easy to see that the Rescale rule first computes the exact SHAP value for $h$ and then propagates it back linearly. In other words, the non-linear and linear functions are treated as separate functions. Passing non-linear attributions back linearly is clearly an approximation, but it confers two benefits: (1) fast computation, on the order of a backward pass, and (2) a guarantee of local accuracy.
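A minimal sketch of the Rescale rule for this single-neuron case, assuming $g = \tanh$ and illustrative weights. `rescale_attributions` is a hypothetical helper, not DeepLIFT's or the shap library's API.

```python
import math

def rescale_attributions(g, w, xf, xb):
    """Rescale rule for y = g(h), h = sum_i w_i x_i, with a single background xb.

    Step 1: phi(h) = g(h_f) - g(h_b), the exact SHAP value of the single-input g.
    Step 2: pass phi(h) back linearly in proportion to w_i * (xf_i - xb_i).
    """
    hf = sum(wi * xi for wi, xi in zip(w, xf))
    hb = sum(wi * xi for wi, xi in zip(w, xb))
    phi_h = g(hf) - g(hb)
    if hf == hb:
        return [0.0] * len(w)  # simplifying assumption: nothing to pass back
    return [phi_h / (hf - hb) * wi * (a - b) for wi, a, b in zip(w, xf, xb)]

w = [1.0, -0.5, 2.0]                       # illustrative weights
xf, xb = [1.0, 2.0, 0.5], [0.0, 0.0, 0.0]  # here h_f = 1 and h_b = 0
phi = rescale_attributions(math.tanh, w, xf, xb)
print(phi, sum(phi))  # attributions sum to tanh(1) - tanh(0) (local accuracy)
```

Local accuracy holds by construction: the contributions $w_i(x^f_i - x^b_i)$ sum to $h^f - h^b$, so the rescaled values sum to $\phi(h)$.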

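Next, we describe how the RevealCancel rule (originally formulated to bring DeepLIFT closer to SHAP values) connects to SHAP values in the context of Figure 1c, where $y = g(h)$ and $h = \sum_i w_i x_i$. RevealCancel partitions the inputs $x_i$ into positive and negative components based on whether $w_i(x^f_i - x^b_i) > t$ (where $t = 0$), in essence forming nodes $h^+$ and $h^-$. This rule computes the exact SHAP attributions for $h^+$ and $h^-$ and then propagates the resulting SHAP values linearly. Specifically, writing $\Delta^+$ and $\Delta^-$ for the summed contributions $w_i(x^f_i - x^b_i)$ of the inputs in each partition:

$$
\begin{aligned}
\phi(h^+) &= \tfrac{1}{2}\left[g(h^b + \Delta^+) - g(h^b)\right] + \tfrac{1}{2}\left[g(h^b + \Delta^- + \Delta^+) - g(h^b + \Delta^-)\right] \\
\phi(h^-) &= \tfrac{1}{2}\left[g(h^b + \Delta^-) - g(h^b)\right] + \tfrac{1}{2}\left[g(h^b + \Delta^+ + \Delta^-) - g(h^b + \Delta^+)\right] \\
\phi(x_i) &= \frac{\phi(h^{\pm})}{\Delta^{\pm}}\, w_i \left(x^f_i - x^b_i\right),
\end{aligned}
$$

where $\pm$ matches the partition that contains $x_i$.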

Under this formulation, we can see that the Rescale rule explains a linearity followed by a non-linearity by exactly explaining the non-linearity and then backpropagating, whereas the RevealCancel rule exactly explains the non-linearity together with a partition of the inputs to the linearity, treated as a single function, before backpropagating. The RevealCancel rule incurs a higher computational cost in order to obtain an estimate of $\phi(x_i)$ that is ideally closer to the SHAP values.
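The sketch below implements the RevealCancel propagation for $y = g(h)$ with a configurable threshold `t` (DeepLIFT's rule uses `t = 0`). The helper name and the ReLU example are assumptions for illustration. Because $\phi(h^+) + \phi(h^-) = g(h^f) - g(h^b)$ and each partition passes its value back in proportion to its inputs' contributions, local accuracy holds for any threshold.

```python
def reveal_cancel(g, w, xf, xb, t=0.0):
    """RevealCancel-style attributions for y = g(h), h = sum_i w_i x_i.

    Inputs whose contribution c_i = w_i * (xf_i - xb_i) exceeds t form h+;
    the rest form h-. The exact two-player Shapley values of h+ and h- are
    computed and each is passed back linearly to its own inputs.
    """
    c = [wi * (a - b) for wi, a, b in zip(w, xf, xb)]
    hb = sum(wi * b for wi, b in zip(w, xb))
    pos = [ci > t for ci in c]
    d_pos = sum(ci for ci, p in zip(c, pos) if p)
    d_neg = sum(ci for ci, p in zip(c, pos) if not p)
    # Two-player Shapley values: average marginal contribution over both orders
    phi_pos = 0.5 * (g(hb + d_pos) - g(hb)) + 0.5 * (g(hb + d_neg + d_pos) - g(hb + d_neg))
    phi_neg = 0.5 * (g(hb + d_neg) - g(hb)) + 0.5 * (g(hb + d_pos + d_neg) - g(hb + d_pos))
    out = []
    for ci, p in zip(c, pos):
        d, phi = (d_pos, phi_pos) if p else (d_neg, phi_neg)
        out.append(0.0 if d == 0 else phi / d * ci)  # if d == 0, phi is also 0
    return out

def relu(z):
    return max(0.0, z)

# Illustrative example: y = relu(x1 + x2 - x3) with a zero background
phi = reveal_cancel(relu, [1.0, 1.0, -1.0], [2.0, 1.0, 2.0], [0.0, 0.0, 0.0])
print(phi)  # approximately [1.333, 0.667, -1.0]; sums to relu(1) - relu(0) = 1
```

On the same input, the Rescale rule would return the raw contributions $[2, 1, -2]$ scaled by $\phi(h)/(h^f - h^b) = 1$; RevealCancel instead lets the positive and negative groups "cancel" through the non-linearity before propagating.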

This reframing naturally motivates explanations for stacks of mixed model types. In particular, for Figure 1d, we can take advantage of fast, exact methods for computing SHAP values of tree models to obtain the attributions $\phi(h)$ of the tree's inputs using Independent Tree SHAP [treeshap]. We can then propagate these attributions to obtain $\phi(x)$ using either the Rescale or the RevealCancel rule. The same argument extends to explaining losses rather than output margins.

Although we consider specific examples here, the linear propagation described above generalizes to arbitrary networks, provided SHAP values can be computed or approximated for the individual components.
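A minimal sketch of propagating through a stack, assuming a hypothetical linear feature extractor feeding a hypothetical depth-2 tree. A brute-force single-reference Shapley computation stands in for Independent Tree SHAP (in practice one would use a tree SHAP implementation such as the shap library's `TreeExplainer`); the Rescale step then passes each hidden unit's attribution back to the inputs.

```python
import math
from itertools import combinations

def tree(h):
    """Hypothetical depth-2 decision tree over two hidden units (illustrative)."""
    if h[0] > 0.5:
        return 3.0 if h[1] > 0.0 else 1.0
    return -1.0

def single_reference_shap(f, zf, zb):
    """Exact Shapley values of f at zf with the single reference zb (brute force)."""
    n = len(zf)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # k = |S|
            weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for S in combinations(others, k):
                with_i = [zf[j] if (j in S or j == i) else zb[j] for j in range(n)]
                without_i = [zf[j] if j in S else zb[j] for j in range(n)]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi

W = [[1.0, 0.5], [-1.0, 2.0]]  # hypothetical extractor weights, W[j][i]: input i -> unit j

def extractor(x):
    return [sum(W[j][i] * x[i] for i in range(len(x))) for j in range(len(W))]

def stacked_attributions(xf, xb):
    hf, hb = extractor(xf), extractor(xb)
    phi_h = single_reference_shap(tree, hf, hb)  # exact SHAP for the tree component
    phi_x = [0.0] * len(xf)
    for j, (a, b) in enumerate(zip(hf, hb)):
        if a == b:
            continue  # an unchanged hidden unit has zero attribution
        for i in range(len(xf)):
            phi_x[i] += phi_h[j] / (a - b) * W[j][i] * (xf[i] - xb[i])  # Rescale step
    return phi_x

print(stacked_attributions([1.0, 1.0], [0.0, 0.0]))  # [1.0, 3.0]
```

The attributions sum to $T(h^f) - T(h^b) = 3 - (-1) = 4$, so local accuracy survives the change of model type between components.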

SHAP values with a background distribution

Note that many methods (e.g., Integrated Gradients, Occlusion) recommend using a single background (reference) sample. In fact, DeepSHAP as previously described in [lundberg2017unified] created attributions with respect to a single reference equal to the expected value of the inputs. However, we prove that the correct way to obtain SHAP values for a given background distribution is to compute SHAP values with respect to each baseline in the background distribution and average the resulting attributions. Although similar methodologies have been used heuristically [deeplift, erion2019learning], we provide a theoretical justification in Theorem 1 in the context of SHAP values.
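The difference between the two approaches shows up even for a two-feature model. The sketch assumes a hypothetical model $f(x) = x_1 x_2$ and a two-sample background; averaged single-reference attributions sum to $f(x) - \mathbb{E}[f(x')]$, while a single reference at the mean input sums to $f(x) - f(\mathbb{E}[x'])$, which differ whenever $f$ is non-linear.

```python
def f(x):
    return x[0] * x[1]  # hypothetical model with an interaction

def shap_two_features(f, xf, xb):
    """Exact single-reference Shapley values for a 2-feature model."""
    phi0 = 0.5 * (f([xf[0], xb[1]]) - f(xb)) + 0.5 * (f(xf) - f([xb[0], xf[1]]))
    phi1 = 0.5 * (f([xb[0], xf[1]]) - f(xb)) + 0.5 * (f(xf) - f([xf[0], xb[1]]))
    return [phi0, phi1]

background = [[0.0, 2.0], [2.0, 0.0]]  # hypothetical background samples
xf = [1.0, 1.0]

# Average of single-reference SHAP values over the background (sums to 1 - 0 = 1)
per_ref = [shap_two_features(f, xf, xb) for xb in background]
phi_avg = [sum(p[i] for p in per_ref) / len(per_ref) for i in range(2)]

# One reference at the mean input [1, 1] (sums to f(xf) - f([1, 1]) = 0)
mean_ref = [sum(b[i] for b in background) / len(background) for i in range(2)]
phi_mean = shap_two_features(f, xf, mean_ref)

print(phi_avg, phi_mean)  # [0.5, 0.5] [0.0, 0.0]
```

Here the prediction differs from the average background prediction, yet the mean-reference explanation attributes nothing to either feature.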

Theorem 1.

The average over single reference SHAP values approaches the true SHAP values for a given distribution.

Proof.

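Define $D$ to be the data distribution, $N$ to be the set of all features, and $f$ to be the model being explained. Additionally, define $h_x(x', S)$ to return a sample where the features in $S$ are taken from $x$ and the remaining features from $x'$. Define $C$ to be all combinations of the set $N \setminus \{i\}$ and $P$ to be all permutations of $N$. Starting with the definition of SHAP values for a single feature $i$:

$$
\begin{aligned}
\phi_i(x) &= \sum_{S \in C} \frac{|S|!\,(|N| - |S| - 1)!}{|P|} \Big( \mathbb{E}_{D}\big[f(X) \mid X_{S \cup \{i\}} = x_{S \cup \{i\}}\big] - \mathbb{E}_{D}\big[f(X) \mid X_S = x_S\big] \Big) \\
&= \sum_{S \in C} \frac{|S|!\,(|N| - |S| - 1)!}{|P|} \Big( \mathbb{E}_{x' \sim D}\big[f(h_x(x', S \cup \{i\}))\big] - \mathbb{E}_{x' \sim D}\big[f(h_x(x', S))\big] \Big) \\
&= \mathbb{E}_{x' \sim D}\Big[ \sum_{S \in C} \frac{|S|!\,(|N| - |S| - 1)!}{|P|} \Big( f(h_x(x', S \cup \{i\})) - f(h_x(x', S)) \Big) \Big] \\
&= \mathbb{E}_{x' \sim D}\big[\phi_i(x; x')\big],
\end{aligned}
$$

where $\phi_i(x; x')$ is the SHAP value of feature $i$ with the single reference $x'$, so averaging single-reference SHAP values over samples drawn from $D$ approaches $\phi_i(x)$. The second step depends on an interventional conditional expectation [janzing2019feature] (which is very close to Random Baseline Shapley in [sundararajan2019many]). ∎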
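Because the Shapley value is linear in its set function, the exchange of expectation and sum in the last steps can be checked numerically. The sketch below assumes a hypothetical three-feature model and treats a small list of background samples as the distribution $D$ (uniform over the list), so the two quantities should agree up to floating-point error; `mix` plays the role of $h_x(x', S)$.

```python
import math
from itertools import combinations

def f(x):
    return math.tanh(x[0] * x[1]) + x[2] ** 2  # hypothetical non-linear model

def shapley(value, n, i):
    """Shapley value of player i for a set function value(S), S a frozenset."""
    others = [j for j in range(n) if j != i]
    total = 0.0
    for k in range(n):
        weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
        for S in combinations(others, k):
            S = frozenset(S)
            total += weight * (value(S | {i}) - value(S))
    return total

def mix(x, xb, S):
    """Features in S from x, the remaining features from xb."""
    return [x[j] if j in S else xb[j] for j in range(len(x))]

def distribution_shap(f, x, background, i):
    # Set function: interventional expectation over the background distribution
    value = lambda S: sum(f(mix(x, xb, S)) for xb in background) / len(background)
    return shapley(value, len(x), i)

def averaged_single_reference_shap(f, x, background, i):
    per_ref = [shapley(lambda S, xb=xb: f(mix(x, xb, S)), len(x), i) for xb in background]
    return sum(per_ref) / len(per_ref)

x = [0.5, -1.0, 2.0]
background = [[0.0, 0.0, 0.0], [1.0, 2.0, -1.0], [-0.5, 0.3, 1.0], [2.0, -1.0, 0.5]]
for i in range(len(x)):
    print(i, distribution_shap(f, x, background, i),
          averaged_single_reference_shap(f, x, background, i))
```

With background samples drawn from a real data distribution, the average is a Monte Carlo estimate, which is why the theorem states that it approaches the true SHAP values.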

Experiments

Background distributions avoid bias

Figure 2: Using a single baseline leads to bias in explanations.

In this section, we use the popular CIFAR10 dataset [krizhevsky2009learning] to demonstrate that single references lead to bias in explanations. We train a CNN that achieves 75.56% test accuracy and explain it using either a zero baseline (as in DeepLIFT) or a random set of 1000 baselines (as in DeepSHAP).

In Figure 2, we can see that for these images drawn from the CIFAR10 training set, DeepLIFT has a clear bias that results in low attributions for darker regions of the image. For DeepSHAP, drawing multiple references from a background distribution solves this problem, and we see sensible attributions in dark regions of the image.

Explaining mortality prediction

In this section, we validate DeepSHAP’s explanations for an MLP with 82.56% test accuracy that predicts 15-year mortality. The dataset, released by [treeshap] and based on the NHANES I Epidemiologic Followup Study [cox1997plan], has 79 features for 14,407 individuals.

Figure 3: Summary plot of DeepSHAP attribution values. Each point is the local feature attribution value, colored by feature value. For brevity, we only show the top 6 features.

In Figure 3, we plot a summary of DeepSHAP attributions (with 1000 random background samples) for all NHANES training samples and notice a few trends. First, Age is predictably the most important feature, and old age contributes to a positive mortality prediction (positive SHAP values). Second, the Sex feature validates a well-known difference in mortality [gjoncca1999male]. Finally, the trends linking high systolic BP, low serum albumin, high sedimentation rate, and high hematocrit to mortality have been independently discovered [port2000systolic, goldwasser1997serumalbumin, paul2012hematocrit, go2016sedimentation].

Figure 4: Explaining an individual’s mortality prediction for different background distributions.

Next, we show the benefits of being able to specify a background distribution. In Figure 4a, we see that explaining an individual’s mortality prediction with respect to a general population emphasizes that the individual’s age and gender are driving a high mortality prediction. However, in practice, doctors are unlikely to compare a 67-year-old male to a general population that includes much younger individuals. In Figure 4b, specifying a background distribution allows us to compare this individual against a more relevant distribution of males over 60. In this case, gender and age are naturally no longer important, and the individual may actually have little cause for concern.

Interpreting a stack of mixed model types

Figure 5: Ablation test for explaining an LSTM feature extractor fed into an XGB model. All methods used a background of 20 samples obtained via k-means. [a.] Convergence of methods for a single explanation. [b.] Model performance versus # features kept for DeepSHAP (rescale), IME Explainer (4000 samples), KernelSHAP (2000 samples), and a baseline (Random) (AUC in the legend).

Stacks, and more generally ensembles, of models are increasingly popular for performant predictions [bao2009stacking, gunecs2017stacked, zhai2018development]. In this section, we evaluate the efficacy of DeepSHAP for a neural network feature extractor fed into a tree model. For this experiment, we use the Rescale rule for simplicity and Independent TreeSHAP to explain the tree model [treeshap]. The dataset is a simulated one called Corrgroups60, in which features are tightly correlated within groups and the label is generated as a linear function of the features.

We evaluate DeepSHAP with an ablation metric called keep absolute (mask) [treeshap]. The metric works as follows: (1) obtain the feature attributions for all test samples; (2) mask all features (by mean imputation); (3) introduce (unmask) one feature at a time, from largest to smallest absolute attribution value for each sample, and measure model performance. Performance should initially increase rapidly, because the “most important” features are introduced first.
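A minimal pure-Python sketch of the keep absolute (mask) procedure, under stated assumptions: the model is any callable on a list of samples, the metric is passed in as a function, and R² is an assumed metric choice for this regression setting. All names are hypothetical.

```python
def keep_absolute_mask(predict, X, y, attributions, metric):
    """Metric after keeping the k features with largest |attribution|, k = 0..M.

    Features that are not kept are replaced by their mean over X (mean imputation),
    and the ranking is computed separately for each sample.
    """
    m = len(X[0])
    means = [sum(row[j] for row in X) / len(X) for j in range(m)]
    orders = [sorted(range(m), key=lambda j: -abs(a[j])) for a in attributions]
    curve = []
    for k in range(m + 1):
        X_masked = []
        for row, order in zip(X, orders):
            keep = set(order[:k])
            X_masked.append([row[j] if j in keep else means[j] for j in range(m)])
        curve.append(metric(y, predict(X_masked)))
    return curve

def r2(y_true, y_pred):
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((a - b) ** 2 for a, b in zip(y_true, y_pred))
    ss_tot = sum((a - mean_y) ** 2 for a in y_true)
    return 1.0 - ss_res / ss_tot

# Usage with a hypothetical linear model and its exact SHAP values w_j * (x_j - mean_j)
w = [2.0, -1.0]
predict = lambda rows: [sum(wj * xj for wj, xj in zip(w, r)) for r in rows]
X = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
y = predict(X)
means = [sum(r[j] for r in X) / len(X) for j in range(2)]
attr = [[w[j] * (r[j] - means[j]) for j in range(2)] for r in X]
curve = keep_absolute_mask(predict, X, y, attr, r2)
area = sum((curve[k] + curve[k + 1]) / 2 for k in range(len(curve) - 1)) / (len(curve) - 1)
print(curve, area)
```

The `area` line is one way (trapezoidal, normalized by the number of steps) to summarize the curve as a single AUC-style number.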

We compare against two sampling-based methods that provide SHAP values in expectation (a natural alternative for explaining mixed model stacks): KernelSHAP and IME Explainer. In Figure 5a, DeepSHAP (rescale) has no variability and requires a fixed number of model evaluations, whereas IME Explainer and KernelSHAP benefit from more samples (and therefore more model evaluations). To set up the final comparison, we check the variability of the tenth largest attribution (in absolute value) of the sampling-based methods across different numbers of samples to determine “convergence”, and we use the number of samples at the point of “convergence” for the performance comparison.
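One way to operationalize this convergence check, as a sketch: a simple permutation-sampling Shapley estimator in the spirit of IME (not the IME or KernelSHAP implementations used in the experiment), run with several seeds, measuring the spread of the tenth largest absolute attribution. The model, seeds, and sample counts are hypothetical.

```python
import math
import random
import statistics

def permutation_shap(f, xf, xb, n_perm, rng):
    """Monte Carlo Shapley values with a single reference xb (IME-style sketch).

    Each sampled ordering switches features from xb to xf one at a time and
    credits each feature with the change in f, so every ordering telescopes
    to f(xf) - f(xb).
    """
    m = len(xf)
    phi = [0.0] * m
    for _ in range(n_perm):
        order = list(range(m))
        rng.shuffle(order)
        z = list(xb)
        prev = f(z)
        for j in order:
            z[j] = xf[j]
            cur = f(z)
            phi[j] += cur - prev
            prev = cur
    return [p / n_perm for p in phi]

def tenth_largest_spread(f, xf, xb, n_perm, n_runs=20):
    """Std. dev. over independent runs of the 10th largest |attribution|.

    Assumes at least 10 features; a small value suggests the estimate has
    stabilized at this number of sampled orderings.
    """
    values = []
    for seed in range(n_runs):
        phi = permutation_shap(f, xf, xb, n_perm, random.Random(seed))
        values.append(sorted((abs(p) for p in phi), reverse=True)[9])
    return statistics.stdev(values)

# Hypothetical 12-feature model with interactions
model = lambda z: math.tanh(sum(z[:6])) * sum(z[6:])
for n_perm in (10, 100, 500):
    print(n_perm, tenth_largest_spread(model, [1.0] * 12, [0.0] * 12, n_perm))
```

For a purely additive model every ordering gives the same marginal contributions, so the spread is zero at any sample count; interactions are what make the estimate noisy.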

In Figure 5b, we can see that DeepSHAP performs slightly better than the model-agnostic methods. Promisingly, all methods show an initial steep rise in performance, which indicates that the most important features received higher attribution values. We hypothesize that the lower performance of KernelSHAP and IME Explainer is due in part to noise in their estimates. This highlights an important point: model-agnostic methods often have sampling variability that makes determining convergence difficult. For a fixed background distribution, DeepSHAP does not suffer from this variability and generally requires fewer model evaluations.

Improving the RevealCancel rule

Figure 6: Comparison of new RevealCancel rule for estimating SHAP values on a toy example. The axes correspond to mean absolute difference from the SHAP values (computed exactly). Green means RevealCancel wins and red means it loses.

The connection between DeepLIFT’s RevealCancel rule and SHAP values is touched upon in [deeplift]; our SHAP value framework defines this connection explicitly. In this section, we propose a simple improvement to the RevealCancel rule. In DeepLIFT’s RevealCancel rule, the threshold $t$ for splitting the inputs into $h^+$ and $h^-$ is set to $0$. Our proposed RevealCancel rule instead sets the threshold to the mean value of $w_i(x^f_i - x^b_i)$ across the inputs $i$. Intuitively, splitting by the mean separates the nodes better, resulting in a better approximation than splitting by zero.

We experimentally validate the proposed RevealCancel rule in Figure 6 by explaining a simple toy function. We fix the background to zero ($x^b = 0$) and draw 100 foreground samples from a discrete uniform distribution.

In Figure 6a, we show that the proposed RevealCancel rule offers a large improvement in approximating SHAP values over the Rescale rule and a modest one over the original RevealCancel rule (at no additional asymptotic computational cost).
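A sketch of this kind of comparison, under assumptions that are illustrative rather than the paper's exact setup: a hypothetical toy function $y = \max(0, \sum_i x_i)$ with four inputs, a zero background, and foreground values drawn uniformly from the integers -3 to 3. It reports the mean absolute error of each rule against brute-force SHAP values and makes no claim about which rule wins on this assumed setup.

```python
import math
import random
from itertools import combinations

def relu(z):
    return max(0.0, z)

def exact_shap(f, xf, xb):
    """Brute-force single-reference Shapley values (exponential in len(xf))."""
    n = len(xf)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for S in combinations(others, k):
                with_i = [xf[j] if (j in S or j == i) else xb[j] for j in range(n)]
                without_i = [xf[j] if j in S else xb[j] for j in range(n)]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi

def rescale(g, w, xf, xb):
    c = [wi * (a - b) for wi, a, b in zip(w, xf, xb)]
    hb, dh = sum(wi * b for wi, b in zip(w, xb)), sum(c)
    if dh == 0:
        return [0.0] * len(c)  # simplifying assumption when h is unchanged
    return [(g(hb + dh) - g(hb)) / dh * ci for ci in c]

def reveal_cancel(g, w, xf, xb, t):
    c = [wi * (a - b) for wi, a, b in zip(w, xf, xb)]
    hb = sum(wi * b for wi, b in zip(w, xb))
    pos = [ci > t for ci in c]
    dp = sum(ci for ci, p in zip(c, pos) if p)
    dn = sum(ci for ci, p in zip(c, pos) if not p)
    phi_p = 0.5 * (g(hb + dp) - g(hb)) + 0.5 * (g(hb + dn + dp) - g(hb + dn))
    phi_n = 0.5 * (g(hb + dn) - g(hb)) + 0.5 * (g(hb + dp + dn) - g(hb + dp))
    return [0.0 if d == 0 else phi / d * ci
            for ci, (d, phi) in zip(c, [(dp, phi_p) if p else (dn, phi_n) for p in pos])]

def compare(n_samples=100, m=4, seed=0):
    rng = random.Random(seed)
    w, xb = [1.0] * m, [0.0] * m  # hypothetical weights; zero background
    f = lambda x: relu(sum(wi * xi for wi, xi in zip(w, x)))
    errors = {"Rescale": 0.0, "RevealCancel (t=0)": 0.0, "RevealCancel (t=mean)": 0.0}
    for _ in range(n_samples):
        xf = [float(rng.randint(-3, 3)) for _ in range(m)]  # hypothetical range
        truth = exact_shap(f, xf, xb)
        t_mean = sum(wi * (a - b) for wi, a, b in zip(w, xf, xb)) / m
        estimates = {
            "Rescale": rescale(relu, w, xf, xb),
            "RevealCancel (t=0)": reveal_cancel(relu, w, xf, xb, 0.0),
            "RevealCancel (t=mean)": reveal_cancel(relu, w, xf, xb, t_mean),
        }
        for name, phi in estimates.items():
            errors[name] += sum(abs(p - q) for p, q in zip(phi, truth)) / (m * n_samples)
    return errors

print(compare())
```

Only the threshold differs between the last two rules, so the mean threshold adds a single pass over the contributions and leaves the asymptotic cost unchanged.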

Conclusion

In this paper, we improve the original DeepSHAP formulation [lundberg2017unified] in several ways: we (1) provide a new, theoretically justified way to obtain attributions with respect to a background distribution, (2) extend DeepSHAP to explain stacks of mixed model types, and (3) present an improvement to the RevealCancel rule.

Future work includes more quantitative validation on different datasets and comparison to more interpretability methods. In addition, we primarily used the Rescale rule for these evaluations, so more empirical evaluations of RevealCancel are also important.

References