Recent work has shown great promise in explaining neural network behavior. In particular, feature attribution methods explain which features were most important to a model's prediction on a given input. However, for many tasks, simply knowing which features were important to a model's prediction may not provide enough insight to understand model behavior. The interactions between features within the model may better help us understand not only the model, but also why certain features are more important than others. In this work, we present Integrated Hessians (Code available at https://github.com/suinleelab/path_explain): an extension of Integrated Gradients that explains pairwise feature interactions in neural networks. Integrated Hessians overcomes several theoretical limitations of previous methods to explain interactions, and unlike such previous methods is not limited to a specific architecture or class of neural network. We apply Integrated Hessians on a variety of neural networks trained on language data, biological data, astronomy data, and medical data and gain new insight into model behavior in each domain.
Deep neural networks are among the most popular classes of machine learning models. They have achieved state-of-the-art performance in problem domains ranging from natural language processing (Devlin et al., 2018) to image recognition (He et al., 2016). They have even outperformed other non-linear model types on structured tabular data (Shavitt and Segal, 2018). Although neural networks have traditionally been difficult to interpret compared to simpler model classes, gaining a better understanding of their predictions is desirable for a variety of reasons. To the extent that these algorithms are used in automated decisions impacting humans, explanations may be legally required (Selbst and Powles, 2017). When they are used in high-stakes applications, like diagnostic radiology, it is essential to ensure that models are making safe decisions for the right reasons (Geis et al., 2019). During model development, interpretability methods can help debug undesirable model behavior (Sundararajan et al., 2017).
Feature attribution: There have been a large number of recent approaches to interpreting deep neural networks, ranging from methods that aim to distill complex models into simpler models (Tan et al., 2018; Wu et al., 2018; Puri et al., 2017), to methods that aim to identify the most important concepts learned by a network (Kim et al., 2017; Olah et al., 2018, 2017; Fong and Vedaldi, 2018; Erhan et al., 2009; Mahendran and Vedaldi, 2015). One of the best-studied sets of approaches is known as feature attribution methods (Binder et al., 2016; Shrikumar et al., 2017; Lundberg and Lee, 2017; Ribeiro et al., 2016). These approaches explain a model's prediction by assigning credit to each input feature based on how much it influenced that prediction. Although these approaches help practitioners understand which features are important, they do not explain why certain features are important or how features interact in a model. In order to develop a richer understanding of model behavior, it is therefore desirable to develop methods to explain not only feature attributions but also feature interactions.
For example, in Figure 1, we show that word-level interactions can help us distinguish why deeper, more expressive neural networks outperform simpler ones on language tasks.
Feature interaction: There are several existing methods that explain feature interactions in neural networks. For example, Cui et al. (2019) propose a method to explain global interactions in Bayesian Neural Networks (BNN). After jointly training a linear model to learn main effects with a BNN to learn interactions, they detect learned interactions by finding pairs of features that on average (across many samples) have large-magnitude terms in the input Hessian of the BNN. Tsang et al. (2017) propose a framework called Neural Interaction Detection, which detects statistical interactions between features by examining the weight matrices of feed-forward neural networks. In the domain of deep learning for genomics, Greenside et al. (2018) propose an approach called Deep Feature Interaction Maps. This approach detects interactions between a source and a target feature by calculating the change in the first-order feature attribution of the source feature when all other features are held constant but the target feature is changed.
Limitations of Prior Approaches: Recent work has shown that attempting to quantitatively or qualitatively compare methods to explain machine learning models can lead to misleading and unreliable results (Tomsett et al., 2019; Buçinca et al., 2020). Instead, we attempt to draw contrast with previous approaches from a theoretical and practical perspective.
While previous approaches have taken important steps towards understanding feature interaction in neural networks, they all suffer from practical limitations. Each of the above methods is only applicable to models trained on certain types of data, or only to certain model architectures. Neural Interaction Detection only applies to feed-forward neural network architectures, and cannot be used on networks with convolutions, recurrent units, or self-attention. The approach suggested by Cui et al. (2019) is limited in that it requires the use of Bayesian Neural Networks; it is unclear how to apply the method to standard neural networks. Deep Feature Interaction Maps only work when the input features for a model have a small number of discrete values (such as the genomic data used as input in their paper).
In addition to architecture and data limitations, existing methods to detect interactions do not satisfy the common-sense axioms that have been proposed for feature attribution methods (Sundararajan et al., 2017; Lundberg and Lee, 2017). This failure to ground interactions in existing theory leaves these previous approaches provably unable to find interactions when interactions are present, or, more generally, prone to finding counter-intuitive interactions (see section 4).
Our Contributions:
(1) We propose an approach to quantifying pairwise feature interactions that can be applied to any neural network architecture; (2) We identify several common-sense axioms that feature-level interactions should satisfy and show that our proposed method satisfies them; (3) We provide a principled way to compute interactions in ReLU-based networks, which are piece-wise linear and have zero second derivatives; (4) We demonstrate the utility of our method by showing the insights it reveals into real production-type models and real datasets.
To derive our feature interaction values, we start by considering Integrated Gradients, proposed by Sundararajan et al. (2017). We represent our model as a function $f: \mathbb{R}^d \to \mathbb{R}$ (in the case of multi-output models, such as multi-class classification problems, we assume the function is indexed into the correct output class). For a function $f$, the Integrated Gradients attribution for the $i$th feature is defined as:
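$$\phi_i(x) = (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\, d\alpha \tag{1}$$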
where $x$ is the sample we would like to explain and $x'$ is some uninformative baseline value. Although $f$ is often a neural network, the only requirement in order to compute attribution values is that $f$ be differentiable along the path from $x'$ to $x$. Our key insight is that the Integrated Gradients value $\phi_i(x)$ for a twice-differentiable model $f$ is itself a differentiable function $\phi_i: \mathbb{R}^d \to \mathbb{R}$. This means that we can apply Integrated Gradients to $\phi_i$ itself in order to explain how much feature $j$ impacted the importance of feature $i$:
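$$\Gamma_{i,j}(x) = \phi_j(\phi_i(x)) = (x_j - x'_j) \times \int_{\beta=0}^{1} \frac{\partial \phi_i(x' + \beta(x - x'))}{\partial x_j}\, d\beta \tag{2}$$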
For $i \neq j$, we can derive that:
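$$\Gamma_{i,j}(x) = (x_i - x'_i)(x_j - x'_j) \times \int_{\beta=0}^{1} \int_{\alpha=0}^{1} \alpha\beta\, \frac{\partial^2 f(x' + \alpha\beta(x - x'))}{\partial x_i\, \partial x_j}\, d\alpha\, d\beta \tag{3}$$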
In the case of $i = j$, the formula has an additional first-order term. We leave the derivation and the full formula to the appendix. In this view, we can interpret $\Gamma_{i,j}(x)$ as the explanation of the importance of feature $i$ in terms of the input value of feature $j$.
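Equation 1 and equation 3 (together with the extra first-order term on the diagonal) can be approximated with midpoint Riemann sums over $\alpha$ and $\beta$. The sketch below is a minimal pure-Python illustration under stated assumptions, not the authors' implementation: the caller is assumed to supply the model's gradient and Hessian as plain functions (a real network would get them from automatic differentiation), and the toy model $f(x) = x_0 x_1 + x_2^2$ is a hypothetical example chosen because the midpoint rule integrates its path terms exactly.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def midpoints(steps: int) -> Vector:
    """Midpoints of `steps` equal sub-intervals of [0, 1]."""
    return [(k + 0.5) / steps for k in range(steps)]


def integrated_gradients(grad_f: Callable[[Vector], Vector], x: Vector,
                         baseline: Vector, steps: int = 50) -> Vector:
    """Equation 1: phi_i = (x_i - x'_i) * mean of df/dx_i along the straight path."""
    d = len(x)
    delta = [x[i] - baseline[i] for i in range(d)]
    phi = [0.0] * d
    for a in midpoints(steps):
        g = grad_f([baseline[i] + a * delta[i] for i in range(d)])
        for i in range(d):
            phi[i] += delta[i] * g[i] / steps
    return phi


def integrated_hessians(grad_f: Callable[[Vector], Vector],
                        hess_f: Callable[[Vector], Matrix], x: Vector,
                        baseline: Vector, steps: int = 50) -> Matrix:
    """Equation 3 for i != j, plus the additional first-order term when i == j.

    Both integrals run over alpha and beta; the integrand is evaluated at
    the point x' + alpha * beta * (x - x').
    """
    d = len(x)
    delta = [x[i] - baseline[i] for i in range(d)]
    gamma = [[0.0] * d for _ in range(d)]
    weight = 1.0 / steps ** 2
    for a in midpoints(steps):
        for b in midpoints(steps):
            t = a * b
            point = [baseline[i] + t * delta[i] for i in range(d)]
            g, h = grad_f(point), hess_f(point)
            for i in range(d):
                gamma[i][i] += weight * delta[i] * g[i]  # first-order term, diagonal only
                for j in range(d):
                    gamma[i][j] += weight * t * delta[i] * delta[j] * h[i][j]
    return gamma


# Hypothetical toy model f(x) = x0 * x1 + x2**2 with hand-written derivatives.
def f(x: Vector) -> float:
    return x[0] * x[1] + x[2] ** 2


def grad_f(x: Vector) -> Vector:
    return [x[1], x[0], 2.0 * x[2]]


def hess_f(x: Vector) -> Matrix:
    return [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 2.0]]


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = integrated_gradients(grad_f, x, baseline)           # ≈ [1.0, 1.0, 9.0]
gamma = integrated_hessians(grad_f, hess_f, x, baseline)  # ≈ [[0.5, 0.5, 0], [0.5, 0.5, 0], [0, 0, 9]]
```

In this toy case $x_2$ interacts with nothing, so its self-interaction equals its attribution (9), while $x_0$ and $x_1$ split their joint contribution between main effects (0.5 each) and a symmetric interaction (0.5 in each off-diagonal slot).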
Sundararajan et al. (2017) showed that, among other theoretical properties, Integrated Gradients satisfies the completeness axiom, which states:
$$\sum_{i} \phi_i(x) = f(x) - f(x') \tag{4}$$
Although we leave the derivation to the appendix, we can show the following two equalities, which are immediate consequences of completeness:
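$$\sum_{i} \sum_{j} \Gamma_{i,j}(x) = f(x) - f(x') \tag{5}$$
$$\Gamma_{i,i}(x) = \phi_i(x) - \sum_{j \neq i} \Gamma_{i,j}(x) \tag{6}$$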
We call equation 5 the interaction completeness axiom: the sum of the $\Gamma_{i,j}(x)$ terms adds up to the difference between the output of $f$ at $x$ and at the baseline $x'$. This axiom lends itself to another natural interpretation of $\Gamma_{i,j}(x)$: as the interaction between features $i$ and $j$. That is, it represents the contribution that the pair of features $i$ and $j$ together add to the output $f(x) - f(x')$. Satisfying interaction completeness is important because it demonstrates a relationship between model output and interaction values. Without this axiom, it is unclear how to interpret the scale of interactions.
Equation 6 provides a way to interpret the self-interaction term $\Gamma_{i,i}(x)$: it is the main effect of feature $i$ after interactions with all other features have been subtracted away. We note that equation 6 also implies the following intuitive property about the main effect: if $\Gamma_{i,j}(x) = 0$ for all $j \neq i$, or in the degenerate case where $i$ is the only feature, we have $\Gamma_{i,i}(x) = \phi_i(x)$. We call this the self completeness axiom. Satisfying self completeness is important because it provides a guarantee that the main effect of feature $i$ equals its feature attribution value if that feature interacts with no other features.
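Both identities can be checked numerically. The sketch below assumes a hypothetical one-unit model $f(x) = \mathrm{softplus}(w^\top x)$ with made-up weights and a zero baseline, and uses the fact that for this model $\partial f/\partial x_i = \sigma(z)\, w_i$ and $\partial^2 f/\partial x_i \partial x_j = \sigma(z)(1 - \sigma(z))\, w_i w_j$ with $z = w^\top x$. Because Integrated Gradients and Integrated Hessians are both approximated by midpoint sums here, equation 5 and the row-sum form of equation 6 hold only up to discretization error, so the checks use a tolerance.

```python
import math
from typing import List

Vector = List[float]
W = [1.5, -2.0, 0.5]  # hypothetical weights of f(x) = softplus(w . x)


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u)) if u >= 0 else math.exp(u) / (1.0 + math.exp(u))


def f(x: Vector) -> float:
    z = sum(w * xi for w, xi in zip(W, x))
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))  # numerically stable softplus


def path_z(x: Vector, baseline: Vector, t: float) -> float:
    """Pre-activation w . (x' + t (x - x'))."""
    return sum(w * (r + t * (xi - r)) for w, xi, r in zip(W, x, baseline))


def integrated_gradients(x: Vector, baseline: Vector, steps: int = 100) -> Vector:
    # df/dx_i = sigmoid(z) * w_i, so only the mean of sigmoid along the path is needed
    mids = [(k + 0.5) / steps for k in range(steps)]
    mean_s = sum(sigmoid(path_z(x, baseline, a)) for a in mids) / steps
    return [(xi - r) * w * mean_s for w, xi, r in zip(W, x, baseline)]


def integrated_hessians(x: Vector, baseline: Vector, steps: int = 100) -> List[Vector]:
    # d2f/dx_i dx_j = s (1 - s) w_i w_j with s = sigmoid(z)
    delta = [xi - r for xi, r in zip(x, baseline)]
    mids = [(k + 0.5) / steps for k in range(steps)]
    first = second = 0.0
    for a in mids:
        for b in mids:
            s = sigmoid(path_z(x, baseline, a * b))
            first += s                       # integrand of the first-order (diagonal) term
            second += a * b * s * (1.0 - s)  # integrand of the Hessian term
    first, second = first / steps ** 2, second / steps ** 2
    d = len(x)
    return [[delta[i] * delta[j] * W[i] * W[j] * second
             + (delta[i] * W[i] * first if i == j else 0.0)
             for j in range(d)] for i in range(d)]


x, baseline = [1.0, 0.5, 2.0], [0.0, 0.0, 0.0]
phi = integrated_gradients(x, baseline)
gamma = integrated_hessians(x, baseline)

# Equation 5: all entries of Gamma sum to f(x) - f(x').
completeness_gap = abs(sum(map(sum, gamma)) - (f(x) - f(baseline)))
# Equation 6 rearranged: each row of Gamma sums to that feature's attribution.
row_gaps = [abs(sum(gamma[i]) - phi[i]) for i in range(len(x))]
```

Self completeness is the special case of the row check in which every off-diagonal entry of a row is zero.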
Integrated Gradients satisfies several other common-sense axioms such as sensitivity and linearity. We discuss generalizations of these axioms to interaction values (interaction sensitivity and interaction linearity) and demonstrate that Integrated Hessians satisfies said axioms in the appendix. We do observe one additional property here, which we call interaction symmetry: for any $i, j$, we have $\Gamma_{i,j}(x) = \Gamma_{j,i}(x)$. It is straightforward to show that existing neural networks and their activation functions have continuous second partial derivatives,† which implies that Integrated Hessians satisfies interaction symmetry: by Schwarz's theorem the mixed partials $\partial^2 f / \partial x_i \partial x_j$ and $\partial^2 f / \partial x_j \partial x_i$ then coincide, so equation 3 is unchanged when $i$ and $j$ are swapped.
† We discuss the special case of the ReLU activation function in Section 3.
In practice we compute the Integrated Hessians values by a discrete sum approximation of the integral, similar to how Integrated Gradients is computed (Sundararajan et al., 2017). The original paper uses the completeness axiom to determine convergence of attribution values. Similarly, we can use the interaction completeness axiom to determine convergence of the discrete sum approximation. We find that anywhere from 50 to 300 discrete steps suffice to approximate the double integral in most cases. We leave the exact approximation formulas and a further discussion of convergence to the appendix, as well as Section 3.
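One way to turn the convergence check described above into a procedure is to refine the grid until the interaction-completeness gap falls below a tolerance. The sketch below is an illustration under assumptions, not the approximation formulas from the appendix: the doubling schedule, the tolerance, the per-axis step count (the double sum uses steps² evaluations), and the one-unit model $f(x) = \mathrm{softplus}_\beta(w^\top x)$ with made-up weights are all hypothetical choices. For this model the sum of all $\Gamma_{i,j}$ integrands collapses to two scalar terms, which the docstring spells out.

```python
import math
from typing import List, Tuple

Vector = List[float]
W = [1.5, -2.0, 0.5]  # hypothetical weights of f(x) = softplus_beta(w . x)


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u)) if u >= 0 else math.exp(u) / (1.0 + math.exp(u))


def softplus(z: float, beta: float) -> float:
    bz = beta * z
    return (max(bz, 0.0) + math.log1p(math.exp(-abs(bz)))) / beta


def dot(u: Vector, v: Vector) -> float:
    return sum(p * q for p, q in zip(u, v))


def interaction_sum(x: Vector, baseline: Vector, beta: float, steps: int) -> float:
    """Midpoint approximation of sum_ij Gamma_ij(x) for f(x) = softplus_beta(w . x).

    With c = w . (x - x') and s = sigmoid(beta * z) at z = w . (x' + t (x - x')),
    the gradient terms add up to s * c and the Hessian terms to beta * s (1 - s) * c**2.
    """
    z0 = dot(W, baseline)
    c = dot(W, [xi - r for xi, r in zip(x, baseline)])
    mids = [(k + 0.5) / steps for k in range(steps)]
    total = 0.0
    for a in mids:
        for b in mids:
            t = a * b
            s = sigmoid(beta * (z0 + t * c))
            total += s * c + t * beta * s * (1.0 - s) * c * c
    return total / steps ** 2


def choose_steps(x: Vector, baseline: Vector, beta: float, tol: float = 1e-4,
                 start: int = 8, max_steps: int = 256) -> Tuple[int, float]:
    """Double the per-axis step count until interaction completeness (eq. 5) holds to tol."""
    target = softplus(dot(W, x), beta) - softplus(dot(W, baseline), beta)
    steps = start
    while True:
        gap = abs(interaction_sum(x, baseline, beta, steps) - target)
        if gap < tol or steps >= max_steps:
            return steps, gap
        steps *= 2


x, baseline = [1.0, 0.5, 2.0], [0.0, 0.0, 0.0]
steps, gap = choose_steps(x, baseline, beta=1.0)
```

Calling `choose_steps` with a larger `beta` (a sharper, more ReLU-like model) is a quick way to compare how many steps the check demands, which is the question Theorem 1 addresses for the one-layer case.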
Several existing feature attribution methods have pointed out the need to explain predictions relative to a baseline value that represents a lack of information (Sundararajan et al., 2017; Shrikumar et al., 2017; Binder et al., 2016; Lundberg and Lee, 2017). However, more recent work has pointed out that choosing a single baseline value to represent lack of information can be challenging in certain domains (Kindermans et al., 2019; Kapishnikov et al., 2019; Sundararajan and Taly, 2018; Ancona et al., 2017; Fong and Vedaldi, 2017; Sturmfels et al., 2020). As an alternative, Erion et al. (2019) proposed an extension of Integrated Gradients called Expected Gradients, which samples many baseline inputs from the training set. It is defined as:
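$$\phi_i(x) = \mathbb{E}_{x' \sim D,\, \alpha \sim U(0,1)}\left[(x_i - x'_i) \times \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\right] \tag{7}$$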
where the expectation is over both $x' \sim D$ for the training distribution $D$ and $\alpha \sim U(0,1)$. We can apply Expected Gradients to itself to get Expected Hessians:
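$$\Gamma_{i,j}(x) = \mathbb{E}_{x' \sim D,\, \alpha \sim U(0,1),\, \beta \sim U(0,1)}\left[(x_i - x'_i)(x_j - x'_j) \times \alpha\beta\, \frac{\partial^2 f(x' + \alpha\beta(x - x'))}{\partial x_i\, \partial x_j}\right], \quad i \neq j \tag{8}$$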
where the expectation is over $x' \sim D$, $\alpha \sim U(0,1)$, and $\beta \sim U(0,1)$. In our experiments, we use Integrated Hessians where there is a natural baseline: for language models (Section 5.1) we use the zero embedding as a baseline, and for gene expression data (Section 5.3) we use the zero vector. On our tabular data examples (Sections 5.2 and 5.4) we use Expected Hessians. We further discuss the hyper-parameters used to compute interactions in the appendix.
One major limitation that has not been discussed in previous approaches to interaction detection in neural networks is related to the fact that many popular neural network architectures use the ReLU activation function, $\mathrm{ReLU}(x) = \max(0, x)$. Neural networks that use ReLU are piecewise linear and have second partial derivatives equal to zero wherever they are defined. Previous second-order approaches therefore fail to detect any interaction in ReLU-based networks.
Fortunately, the ReLU activation function has a smooth approximation, the SoftPlus function:
$$\mathrm{softplus}_\beta(x) = \frac{1}{\beta} \log\left(1 + e^{\beta x}\right) \tag{9}$$
SoftPlus more closely approximates the ReLU function as $\beta$ increases, and has well-defined higher-order derivatives. Furthermore, Dombrowski et al. (2019) have proved that both the outputs and first-order feature attributions of a model are minimally perturbed when ReLU activations are replaced by SoftPlus activations in a trained network. Therefore, one important insight of our approach is that Integrated Hessians for a trained ReLU network can be obtained by first replacing the ReLU activations with SoftPlus activations, then calculating the interactions. We note that no re-training is necessary for this approach.
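The replacement keeps the trained weights and swaps only the non-linearity at explanation time. The sketch below illustrates this on a hypothetical two-layer network with made-up weights. It relies on a simple fact that is separate from the result of Dombrowski et al. (2019): $0 < \mathrm{softplus}_\beta(z) - \mathrm{ReLU}(z) \le \log(2)/\beta$ for every $z$ (with equality at $z = 0$), so with one hidden layer each hidden unit moves by at most $\log(2)/\beta$ and the output by at most $\sum_k |w^{(2)}_k| \log(2)/\beta$. Unlike ReLU, the SoftPlus second derivative $\beta\,\sigma(\beta z)(1 - \sigma(\beta z))$ is non-zero, which is what Integrated Hessians needs.

```python
import math
from typing import Callable, List

Vector = List[float]


def relu(z: float) -> float:
    return max(0.0, z)


def softplus(z: float, beta: float = 1.0) -> float:
    """softplus_beta(z) = log(1 + exp(beta * z)) / beta, written to avoid overflow."""
    bz = beta * z
    return (max(bz, 0.0) + math.log1p(math.exp(-abs(bz)))) / beta


def softplus_second_derivative(z: float, beta: float = 1.0) -> float:
    """beta * s * (1 - s) with s = sigmoid(beta * z); ReLU's is 0 wherever it is defined."""
    bz = beta * z
    s = 1.0 / (1.0 + math.exp(-bz)) if bz >= 0 else math.exp(bz) / (1.0 + math.exp(bz))
    return beta * s * (1.0 - s)


# Hypothetical "trained" weights of a 2-input, 3-hidden-unit network.
W1 = [[1.0, -1.0], [0.5, 2.0], [-1.5, 0.3]]
B1 = [0.0, -0.5, 0.2]
W2 = [1.0, -0.7, 0.4]
B2 = 0.1


def network(x: Vector, act: Callable[[float], float]) -> float:
    """Same weights for every activation: only `act` changes, no retraining."""
    hidden = [act(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, B1)]
    return sum(v * h for v, h in zip(W2, hidden)) + B2


x = [0.3, 0.8]
relu_out = network(x, relu)
for beta in (1.0, 10.0, 100.0):
    smooth_out = network(x, lambda z: softplus(z, beta))
    bound = sum(abs(v) for v in W2) * math.log(2.0) / beta
    print(f"beta={beta:>5}: |softplus net - ReLU net| = {abs(smooth_out - relu_out):.5f} <= {bound:.5f}")
```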
In addition to being twice differentiable and allowing us to calculate interaction values in ReLU networks, replacing ReLU with SoftPlus leads to other desirable behavior for calculating interaction values. We show that smoothing a neural network (decreasing the value of $\beta$ in the SoftPlus activation function) lets us accurately approximate the Integrated Hessians value with fewer gradient calls.
Theorem 1. For a one-layer neural network with $\mathrm{softplus}_\beta$ non-linearity, $f_\beta(x) = \mathrm{softplus}_\beta(w^\top x)$, and $d$ input features, we can bound the number of interpolation points $k$ needed to approximate the Integrated Hessians to a given error tolerance $\epsilon$ by …
The proof of Theorem 1 is contained in the appendix. In addition to the proof for the single-layer case, in the appendix we also show empirical results that many-layered neural networks display the same property. Finally, we demonstrate the intuition behind these results. As we replace ReLU with SoftPlus, the decision surface of the network is smoothed (see Figure 2 and Dombrowski et al. (2019)). Once the network has been smoothed with SoftPlus replacement, the gradients along the path from the reference to the foreground sample tend to point in more similar directions.
To understand why feature interactions can be more informative than feature attributions, we can first consider the explanations for the XOR function. This illustration also highlights a weakness of any method for detecting interactions that uses the Hessian without integrating over a path. Our data consists of two binary features, and the function learned by our model is the exclusive OR function: it has a large-magnitude output when either one feature or the other is on, but a low-magnitude output when both features are either on or off (see Figure 3, left).
We use the point (0,0) as a baseline. The Integrated Gradients feature attributions for the states where both features are off (0,0) and both features are on (1,1) are identical: both features get 0 attribution because $f(0,0) = f(1,1)$ (see Figure 3, right upper). When we use Integrated Hessians to get all of the pairwise interactions between features for the two samples, we see that the samples now have different explanations in terms of interactions (see Figure 3, right lower). The first data point (0,0), which is identical to the baseline, has all-zero interaction values. The second point (1,1) now has a negative interaction between the two features and positive self-interactions. Therefore the interactions usefully distinguish between (0,0), which has an output of 0 because it is identical to the baseline, and (1,1), which has an output of 0 because both features are on: each feature would increase the model output on its own, but in interaction with each other they cancel out the positive effects and drive the model's output back to the baseline.
This example also illustrates a problem with methods like that of Cui et al. (2019), which identify global interactions between features $i$ and $j$ by measuring the average magnitude of the $(i,j)$th element of the Hessian over all samples. Looking at Figure 3, we can see that the input gradients and input Hessians have completely flattened (saturated) at all points on the data manifold, and consequently the average magnitude of the $(i,j)$th element of the Hessian over all samples will be 0. By integrating between the baseline and the samples, Integrated Hessians is capable of correctly detecting the negative interaction between the two features.
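The XOR argument can be reproduced with a hypothetical smooth stand-in for the network in Figure 3: $f(x_1, x_2) = g(x_1 + x_2)$ with $g(s) = \sigma(K(s - 0.5)) - \sigma(K(s - 1.5))$, which is close to XOR on binary inputs and saturates there for a large steepness $K$ ($K = 20$ is an arbitrary choice). The sketch computes a simplified global score (mean $|\partial^2 f / \partial x_1 \partial x_2|$ over the four binary points, a stand-in for the averaged-Hessian idea above, not Cui et al.'s full procedure), then Integrated Gradients and Integrated Hessians at (1, 1) with baseline (0, 0) using midpoint sums.

```python
import math
from typing import List, Tuple

K = 20.0  # hypothetical steepness; a larger K saturates the model more strongly


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u)) if u >= 0 else math.exp(u) / (1.0 + math.exp(u))


def sig_d1(u: float) -> float:
    s = sigmoid(u)
    return s * (1.0 - s)


def sig_d2(u: float) -> float:
    s = sigmoid(u)
    return s * (1.0 - s) * (1.0 - 2.0 * s)


def g(s: float) -> float:
    """f(x1, x2) = g(x1 + x2): near 0 at s = 0 and s = 2, near 1 at s = 1."""
    return sigmoid(K * (s - 0.5)) - sigmoid(K * (s - 1.5))


def dg(s: float) -> float:
    """g'(s): both first partials of f equal g'(x1 + x2)."""
    return K * (sig_d1(K * (s - 0.5)) - sig_d1(K * (s - 1.5)))


def d2g(s: float) -> float:
    """g''(s): every second partial of f, mixed or not, equals g''(x1 + x2)."""
    return K * K * (sig_d2(K * (s - 0.5)) - sig_d2(K * (s - 1.5)))


def explain(x: List[float], steps: int = 200) -> Tuple[List[float], List[List[float]]]:
    """Integrated Gradients and Integrated Hessians with baseline (0, 0), midpoint sums."""
    s_x = x[0] + x[1]  # on the path t * x, the input sum is t * s_x
    mids = [(k + 0.5) / steps for k in range(steps)]
    phi = [xi * sum(dg(a * s_x) for a in mids) / steps for xi in x]
    first = second = 0.0
    for a in mids:
        for b in mids:
            t = a * b
            first += dg(t * s_x)
            second += t * d2g(t * s_x)
    first, second = first / steps ** 2, second / steps ** 2
    gamma = [[x[i] * x[j] * second + (x[i] * first if i == j else 0.0) for j in range(2)]
             for i in range(2)]
    return phi, gamma


corners = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
outputs = [g(a + b) for a, b in corners]  # close to [0, 1, 1, 0]
# Simplified global score: mean |d2f/dx1 dx2| over the data points (small: the model is saturated).
mean_abs_hessian = sum(abs(d2g(a + b)) for a, b in corners) / len(corners)
phi_00, gamma_00 = explain([0.0, 0.0])  # the baseline itself: everything is 0
phi_11, gamma_11 = explain([1.0, 1.0])  # phi near 0, yet gamma_11[0][1] < 0 < gamma_11[0][0]
```

Increasing `K` pushes the averaged-Hessian score further toward zero, while the path integral in Integrated Hessians still picks up the negative interaction between the two features.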
In this section, we outline four different use cases of Integrated Hessians on real-world data in order to demonstrate its broad applicability.
In the past decade, neural networks have been the go-to model for language tasks, from convolutional (Kim, 2014) to recurrent (Sundermeyer et al., 2012) architectures. More recently, attention mechanisms (Vaswani et al., 2017) and large, pre-trained transformer architectures (Peters et al., 2018; Devlin et al., 2018) have achieved state-of-the-art performance on a wide variety of tasks.
Some previous work has suggested looking at the internal weights of the attention mechanisms in models that use attention (Ghaeini et al., 2018; Lee et al., 2017; Lin et al., 2017; Wang et al., 2016). However, more recent work has suggested that looking at attention weights may not be a reliable way to interpret models with attention layers (Serrano and Smith, 2019; Jain and Wallace, 2019; Brunner et al., 2019). To overcome this, feature attributions have been applied to text classification models to understand which words most impacted the classification (Liu and Avci, 2019; Lai et al., 2019). However, these methods do not explain how words interact with their surrounding context.
We download pre-trained weights for DistilBERT (Sanh et al., 2019) from the HuggingFace Transformers library (Wolf et al., 2019). We fine-tune the model on the Stanford Sentiment Treebank dataset (Socher et al., 2013), in which the task is to predict whether a movie review has positive or negative sentiment. After 3 epochs of fine-tuning, DistilBERT achieves a validation accuracy of 0.9071 (0.9054 TPR / 0.9089 TNR).‡
‡ This performance does not represent the state of the art, nor is sentiment analysis representative of the full complexity of existing language tasks. However, our focus in this paper is on explanation, and this task is easy to fine-tune without needing to extensively search over hyper-parameters. We leave further fine-tuning details to the appendix.
In Figure 4, we show interactions generated by Integrated Hessians and attributions generated by Integrated Gradients on an example drawn from the validation set. The figure demonstrates that DistilBERT has learned intuitive interactions that would not be revealed from feature attributions alone. For example, a word like "painfully," which might have a negative connotation on its own, has a large positive interaction with the word "funny" in the phrase "painfully funny." We include more examples in the appendix.
In Figure 1, we demonstrate how interactions can help us understand one reason why a fine-tuned DistilBERT model outperforms a simpler model: a convolutional neural network (CNN) that gets an accuracy of 0.82 on the validation set. DistilBERT picks up on positive interactions between negation words (“not”) and negative adjectives (“bad”) that a CNN fails to fully capture.
Finally, in Figure 5, we use interaction values to reveal saturation effects: many negative adjectives describing the same noun interact positively. Although this may seem counter-intuitive at first, it reflects the structure of language. If a phrase has only one negative adjective, it stands out as the word that makes the phrase negative. At some point, however, describing a noun with more and more negative adjectives makes any individual negative adjective less important towards classifying that phrase as negative.
Here we aggregate interactions learned from many samples in a clinical dataset and use the interactions to reveal global patterns. We examine the Cleveland heart disease dataset (Detrano et al., 1989; Das et al., 2009). After preprocessing, the dataset contains 298 patients with 13 associated features, including demographic information like age and gender and clinical measurements such as systolic blood pressure and serum cholesterol. The task is to predict whether or not a patient has coronary artery disease.
We split the data into 238 patients for training (of which 109 have coronary artery disease) and 60 for testing (of which 28 have coronary artery disease). A two-layer neural network with softplus activation achieves a test accuracy of 0.87 (0.82 TPR and 0.91 TNR) and a test AUROC of 0.92. We did not extensively search over hyper-parameters for this task, as our goal is not state of the art performance but rather model interpretation.
In Figure 6, we examine the interactions with a feature describing the number of major coronary arteries with calcium accumulation (0 to 3), as determined by cardiac cinefluoroscopy (Detrano et al., 1986). Previous research has shown that this technique is a reliable way to gauge calcium build-up in major blood vessels, and serves as a strong predictor of coronary artery disease (Detrano et al., 1986; Bartel et al., 1974; Liu et al., 2015). Our model correctly learns that more coronary arteries with evidence of calcification indicate increased risk of disease. Additionally, Integrated Hessians reveals that our model learns a negative interaction between the number of coronary arteries with calcium accumulation and female gender. This supports the well-known phenomenon of under-recognition of heart disease in women – at the same levels of cardiac risk factors, women are less likely to have clinically manifest coronary artery disease (Maas and Appelman, 2010).
In addition to including more details about training procedure and dataset preprocessing, we show additional interactions and attributions in the appendix.
In the domain of anti-cancer drug combination response prediction, plotting Integrated Hessians helps us to glean biological insights into the process we are modeling. We consider one of the largest publicly-available datasets measuring drug combination response in acute myeloid leukemia (Tyner et al., 2018). Each one of 12,362 samples consists of the measured response of a 2-drug pair tested in the cancer cells of a patient. The input features are split between features describing the drug combinations (binary labels of individual drugs and their molecular targets) and features describing the cancerous cells (RNA-seq expression levels for all genes). Following Hao et al. (2018), we first learn an embedding of biological pathways from individual gene expression levels in order to have a more interpretable model. Following Preuer et al. (2017), the drug features and pathway features are then input into a simple feed-forward neural network (see appendix for more details).
Looking at only the first-order attributions, we see that the presence or absence of the drug Venetoclax in the drug combination is the most important feature (see Figure 7, top left). We can also easily see that first-order explanations are inadequate in this case – while the presence of Venetoclax is generally predictive of a more responsive drug combination, the amount of positive response to Venetoclax is predicted to vary across samples.
Integrated Hessians gives us the insight that some of this variability can be attributed to the drug Venetoclax is combined with. We can see that the model has learned a strong negative interaction between Venetoclax and Artemisinin (see Figure 7, top right). Biological interactions are known to occur between anti-cancer drugs, and are an area of great clinical interest due to their potential therapeutic effects. Using additional data not directly available to the model, we can determine the ground truth as to which patients actually had positive and negative biological interactions between Venetoclax and Artemisinin (see appendix for details of calculation). We see that the Integrated Hessians interaction values are significantly more negative in the group with real biological negative interactions ().
Finally, we can gain insight into the variability in the interaction values between Venetoclax and Artemisinin by plotting them against the expression level of a pathway containing cancer genes (see Figure 7, bottom). We see that patients with higher expression of this pathway tend to have a more negative interaction (sub-additive response) than patients with lower expression of this pathway. Integrated Hessians helps us understand the interactions between drugs in our model, as well as what genetic factors influence this interaction.
In this section, we use a physics dataset to confirm that a model has learned a global pattern that is visible in the training data. We utilize the HTRU2 dataset, curated by Lyon et al. (2016) and originally gathered by Keith et al. (2010). The task is to predict whether a particular signal measured from a radio telescope is a pulsar star or generated from radio frequency interference (e.g. background noise). The features include statistical descriptors of measurements made from the radio telescope. The dataset contains 16,259 examples generated through radio frequency interference and 1,639 examples that are pulsars. We leave a more complete description of the data and its features to the appendix, and refer the reader to Lyon (2016) for a more complete background on pulsar star detection.
We split the data into 14,318 training examples (1,365 are pulsars) and 3,580 testing examples (274 are pulsars). We train a two-layer softplus neural network and achieve a held-out test accuracy of 0.98 (0.86 TPR and 0.99 TNR). In Figure 8, we examine the interaction between two key features in the dataset: kurtosis of the integrated profile, which we abbreviate as kurtosis (IP), and standard deviation of the dispersion-measure signal-to-noise ratio curve, which we abbreviate as standard deviation (DM-SNR).
The bottom of Figure 8 shows that kurtosis (IP) is a highly predictive feature, while standard deviation (DM-SNR) is less predictive. However, in the range where kurtosis (IP) is roughly between 0 and 2, standard deviation (DM-SNR) helps pick out a concentration of negative samples at standard deviation (DM-SNR) < 40. We can verify that the model we have trained correctly learns this interaction. By plotting the interaction values learned by the model against the value of kurtosis (IP), we can see a peak positive interaction for points in the indicated range and with high standard deviation (DM-SNR). Interaction values show us that the model has successfully learned the expected pattern: standard deviation (DM-SNR) has the most discriminative power when kurtosis (IP) is in the indicated range.
Although we covered the bulk of the related work in Section 1, we mention some additional related work here. Detecting interactions (when two or more variables have a combined effect not equal to the sum of their individual effects) has a long history in statistics (Southwood, 1978), economics (Balli and Sørensen, 2013), and game theory (Grabisch and Roubens, 1999). In the context of machine learning, many methods have been proposed to learn interactions from data, e.g. using additive models (Coull et al., 2001; Lou et al., 2013), group-lasso (Lim and Hastie, 2015), or trees (Sorokina et al., 2008). We view these approaches as tangential to our work: they aim to learn new models to detect specific, pairwise interactions, whereas we propose a method to explain interactions in deep networks that have already been trained. Doing so is especially important in domains where neural networks achieve state-of-the-art performance, like natural language.
We are unaware of any work to explain interactions in neural networks other than those works mentioned in Section 1 (Cui et al., 2019; Tsang et al., 2017; Greenside et al., 2018). Parallel to our work, Lundberg et al. (2020) propose an extension of feature attribution methods to feature interactions for tree-based models, which satisfies properties similar to the ones we propose.
In this work we propose a novel method to explain feature interactions in neural networks. The interaction values we propose have two natural interpretations: (1) as the combined effect of two features on the output of a model, and (2) as the explanation of one feature's importance in terms of another. Our method provably satisfies common-sense axioms that previous methods do not, and, unlike such previous methods, places no requirement on neural network architecture, class, or data type.
Additionally, we demonstrate how to glean interactions from neural networks trained with the ReLU activation function, whose second derivative is zero wherever it is defined. In accordance with recent work, we show why replacing the ReLU activation function with the SoftPlus activation function at explanation time is both intuitive and efficient.
Finally, we perform extensive experiments to reveal the utility of our method, from understanding performance gaps between model classes to discovering patterns a model has learned on high-dimensional data. We conclude that although feature attribution methods provide valuable insight into model behavior, such methods by no means end the discussion on interpretability. Rather, they encourage further work toward a deeper understanding of model behavior.
Detrano et al. (1989). International application of a new probability algorithm for the diagnosis of coronary artery disease. The American Journal of Cardiology 64 (5), pp. 304–310.
… Proceedings of the IEEE International Conference on Computer Vision, pp. 3429–3437.
… Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8730–8738.
Geis et al. (2019). Ethics of artificial intelligence in radiology: summary of the joint European and North American multisociety statement. Radiology 293 (2), pp. 436–440.
Lee et al. (2017). Interactive visualization and manipulation of attention-based neural machine translation. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pp. 121–126.
In this section, we expand the derivation of the Integrated Hessians formula, further discuss relevant axioms that interaction values should satisfy, and describe how we compute Integrated Hessians in practice.
Here we derive the formula for Integrated Hessians from its definition: $\Gamma_{i,j}(x) = \phi_j(\phi_i(x))$. We start by expanding $\phi_j$ using the definition of Integrated Gradients:
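$$\Gamma_{i,j}(x) = (x_j - x'_j) \times \int_{\beta=0}^{1} \frac{\partial \phi_i(x' + \beta(x - x'))}{\partial x_j}\, d\beta \tag{10}$$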
We consider the function $\partial \phi_i(x) / \partial x_j$, and we first assume that $i \neq j$:
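$$\frac{\partial \phi_i(x)}{\partial x_j} = \frac{\partial}{\partial x_j}\left[(x_i - x'_i) \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\, d\alpha\right] \tag{11}$$
$$= (x_i - x'_i)\, \frac{\partial}{\partial x_j} \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\, d\alpha \tag{12}$$
$$= (x_i - x'_i) \int_{\alpha=0}^{1} \frac{\partial}{\partial x_j}\left[\frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\right] d\alpha \tag{13}$$
$$= (x_i - x'_i) \int_{\alpha=0}^{1} \alpha\, \frac{\partial^2 f(x' + \alpha(x - x'))}{\partial x_i\, \partial x_j}\, d\alpha \tag{14}$$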
where we have assumed that the function $f$ satisfies the conditions for the Leibniz integral rule (so that integration and differentiation are interchangeable). These conditions require that the first derivative $\partial f / \partial x_i$ and the second derivative $\partial^2 f / \partial x_i \partial x_j$ are continuous in $x$ over the integration region, and that the bounds of integration are constant with respect to $x$. It is easy to see that the bounds of integration (0 and 1) are constant with respect to $x$. It is also straightforward to see that common neural network activation functions (for example, the sigmoid $\sigma(x) = 1/(1 + e^{-x})$, $\tanh(x)$, or $\mathrm{GELU}(x) = x\Phi(x)$, where $\Phi$ is the cumulative distribution function of the standard normal distribution) have continuous first and second derivatives, which implies that compositions of these functions have continuous first and second partial derivatives as well. Although this is not the case with the ReLU activation function, we discuss replacing it with SoftPlus in the main text.
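We can proceed by plugging equation 14 into the original definition of $\Gamma_{i,j}(x)$. Evaluating equation 14 at the path point $x' + \beta(x - x')$, whose $i$th coordinate differs from the baseline by $\beta(x_i - x'_i)$, gives:
\begin{align}
\Gamma_{i,j}(x) &= (x_j - x'_j)\int_{\beta=0}^{1} \beta(x_i - x'_i)\int_{\alpha=0}^{1} \alpha\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i\, \partial x_j}\, d\alpha\, d\beta \tag{15}\\
&= (x_i - x'_i)(x_j - x'_j)\int_{\beta=0}^{1} \beta\int_{\alpha=0}^{1} \alpha\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i\, \partial x_j}\, d\alpha\, d\beta \tag{16}\\
&= (x_i - x'_i)(x_j - x'_j)\int_{\beta=0}^{1}\int_{\alpha=0}^{1} \alpha\beta\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i\, \partial x_j}\, d\alpha\, d\beta \tag{17}
\end{align}
where all we’ve done is re-arrange terms.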
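Deriving $\Gamma_{i,i}(x)$ proceeds similarly:
\begin{align}
\frac{\partial \phi_i(x)}{\partial x_i} &= \frac{\partial}{\partial x_i}\left[(x_i - x'_i)\int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha(x - x')\right)}{\partial x_i}\, d\alpha\right] \tag{18}\\
&= \int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha(x - x')\right)}{\partial x_i}\, d\alpha + (x_i - x'_i)\,\frac{\partial}{\partial x_i}\int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha(x - x')\right)}{\partial x_i}\, d\alpha \tag{19}\\
&= \int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha(x - x')\right)}{\partial x_i}\, d\alpha + (x_i - x'_i)\int_{\alpha=0}^{1} \alpha\,\frac{\partial^2 f\left(x' + \alpha(x - x')\right)}{\partial x_i^2}\, d\alpha \tag{20}
\end{align}
using the product rule, followed by the same steps as in equations 12 to 14 (the Leibniz rule and the chain rule). After similar re-arrangement, we arrive at:
$$\Gamma_{i,i}(x) = (x_i - x'_i)\int_{\beta=0}^{1}\int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i}\, d\alpha\, d\beta + (x_i - x'_i)^2\int_{\beta=0}^{1}\int_{\alpha=0}^{1} \alpha\beta\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i^2}\, d\alpha\, d\beta \tag{21}$$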
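As a small worked check of equations 17 and 21 (an illustrative example of our own, not one of the paper's experiments), take $f(x) = x_1 x_2$ with the all-zeros baseline $x' = 0$. Its only nonzero second derivative is the cross term $\partial^2 f / \partial x_1 \partial x_2 = 1$, and $\partial f(z) / \partial x_1 = z_2$, so:

```latex
\begin{aligned}
\Gamma_{1,2}(x) = \Gamma_{2,1}(x) &= x_1 x_2 \int_0^1\!\!\int_0^1 \alpha\beta \cdot 1 \, d\alpha\, d\beta = \tfrac{1}{4}\, x_1 x_2 \\
\Gamma_{1,1}(x) &= x_1 \int_0^1\!\!\int_0^1 \alpha\beta\, x_2 \, d\alpha\, d\beta + x_1^2 \cdot 0 = \tfrac{1}{4}\, x_1 x_2 = \Gamma_{2,2}(x) \\
\phi_1(x) &= x_1 \int_0^1 \alpha\, x_2 \, d\alpha = \tfrac{1}{2}\, x_1 x_2 = \Gamma_{1,1}(x) + \Gamma_{1,2}(x) \\
\textstyle\sum_{i,j} \Gamma_{i,j}(x) &= x_1 x_2 = f(x) - f(x')
\end{aligned}
```

The last two lines anticipate the completeness properties proved in the next paragraph: each row of the interaction matrix sums to that feature's Integrated Gradients attribution, and the whole matrix sums to the change in output.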
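The derivation for Expected Hessians simply observes that the integrals can be viewed as integrating over the product of two uniform distributions, e.g., for $i \neq j$:
\begin{align}
\Gamma^{\mathrm{EH}}_{i,j}(x) &= \mathbb{E}_{x' \sim D}\left[(x_i - x'_i)(x_j - x'_j)\int_{\beta=0}^{1}\int_{\alpha=0}^{1} \alpha\beta\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i\, \partial x_j}\, d\alpha\, d\beta\right] \tag{22}\\
&= \mathbb{E}_{x' \sim D,\; \alpha, \beta \sim U(0,1)}\left[(x_i - x'_i)(x_j - x'_j)\,\alpha\beta\,\frac{\partial^2 f\left(x' + \alpha\beta(x - x')\right)}{\partial x_i\, \partial x_j}\right] \tag{23}
\end{align}
where $D$ represents the data distribution and $\alpha, \beta \sim U(0, 1)$.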
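The proofs that Integrated Hessians satisfies the interaction completeness and self completeness axioms are straightforward, but we include step-by-step derivations here. First, we note that $\phi_i(x') = 0$ for any $i$, because $x'_i - x'_i = 0$. Then, by completeness of Integrated Gradients applied to the function $\phi_i$ (with $d$ the number of input features), we have that:
$$\sum_{j=1}^{d} \Gamma_{i,j}(x) = \sum_{j=1}^{d} \phi_j(\phi_i(x)) = \phi_i(x) - \phi_i(x') = \phi_i(x) \tag{24}$$
Re-arrangement gives us the self completeness axiom:
$$\Gamma_{i,i}(x) = \phi_i(x) - \sum_{j \neq i} \Gamma_{i,j}(x) \tag{25}$$
Since Integrated Gradients satisfies completeness, we have:
$$\sum_{i=1}^{d} \phi_i(x) = f(x) - f(x') \tag{26}$$
Making the appropriate substitution from equation 24 shows the interaction completeness axiom:
$$\sum_{i=1}^{d}\sum_{j=1}^{d} \Gamma_{i,j}(x) = f(x) - f(x') \tag{27}$$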
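Both axioms can be checked in closed form on a quadratic model. For $f(z) = \frac{1}{2} z^\top A z$ with symmetric $A$ and an all-zeros baseline, plugging $\nabla f(z) = Az$ and the constant Hessian $A$ into the Integrated Gradients definition and equations 17 and 21 gives $\phi_i = x_i (Ax)_i / 2$, $\Gamma_{i,j} = A_{ij} x_i x_j / 4$ for $i \neq j$, and $\Gamma_{i,i} = x_i (Ax)_i / 4 + A_{ii} x_i^2 / 4$. The sketch below is our own toy example (the function name is hypothetical) and verifies equations 25 and 27 numerically.

```python
def quadratic_ig_and_ih(A, x):
    """Closed-form Integrated Gradients and Integrated Hessians for
    f(z) = 0.5 * z^T A z (A symmetric) with an all-zeros baseline."""
    d = len(x)
    Ax = [sum(A[i][k] * x[k] for k in range(d)) for i in range(d)]
    f_x = 0.5 * sum(x[i] * Ax[i] for i in range(d))  # f(x) - f(0), since f(0) = 0
    phi = [0.5 * x[i] * Ax[i] for i in range(d)]     # x_i * (Ax)_i * int_0^1 alpha
    # off-diagonal: A_ij * x_i * x_j * (int int alpha * beta) = A_ij x_i x_j / 4
    gamma = [[0.25 * A[i][j] * x[i] * x[j] for j in range(d)] for i in range(d)]
    for i in range(d):
        # diagonal: first-order term plus second-order term, both weighted by 1/4
        gamma[i][i] = 0.25 * x[i] * Ax[i] + 0.25 * A[i][i] * x[i] ** 2
    return f_x, phi, gamma

A = [[2.0, 1.0, 0.0],
     [1.0, 0.0, -3.0],
     [0.0, -3.0, 1.0]]
x = [1.0, 2.0, -1.0]
f_x, phi, gamma = quadratic_ig_and_ih(A, x)

# Self completeness (eq. 25): each row of gamma sums to phi_i.
print([sum(row) for row in gamma], phi)
# Interaction completeness (eq. 27): all entries sum to f(x) - f(0).
print(sum(map(sum, gamma)), f_x)  # both 9.5 for this example
```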
Integrated Gradients satisfies an axiom called sensitivity, which can be phrased as follows. Given an input $x$ and a baseline $x'$, if $x_i = x'_i$ for all $i$ except $j$, where $x_j \neq x'_j$, and if $f(x) \neq f(x')$, then $\phi_j(x) \neq 0$. Specifically, by completeness we know that $\phi_j(x) = f(x) - f(x')$. Intuitively, this says that if only one feature differs between the baseline and the input, and changing that feature changes the output, then the amount the output changes should equal the importance of that feature.
We can extend this axiom to the interaction case by considering the case when two features differ from the baseline. We call this axiom interaction sensitivity, and it can be described as follows. If an input $x$ and a baseline $x'$ are equal everywhere except $x_i \neq x'_i$ and $x_j \neq x'_j$, and if $f(x) \neq f(x')$, then $\Gamma_{i,i}(x) + \Gamma_{j,j}(x) + \Gamma_{i,j}(x) + \Gamma_{j,i}(x) = f(x) - f(x') \neq 0$, and $\Gamma_{k,l}(x) = 0$ whenever $k \notin \{i, j\}$ or $l \notin \{i, j\}$. Intuitively, this says that if the only features that differ from the baseline are $x_i$ and $x_j$, then the difference in the output must be solely attributable to the main effects of $x_i$ and $x_j$ plus the interaction between them. This axiom holds simply by applying interaction completeness and observing that $\Gamma_{k,l}(x) = 0$ if $x_k = x'_k$ or $x_l = x'_l$.
The implementation invariance axiom, described in the original Integrated Gradients paper (Sundararajan et al., 2017), states the following. For two models $f$ and $g$ such that $f(x) = g(x)$ for all $x$, $\phi_i(x; f) = \phi_i(x; g)$ for all features $i$ and all points $x$, regardless of how $f$ and $g$ are implemented. Although it seems trivial, this axiom does not necessarily hold for attribution methods that use the implementation or structure of the network to generate attributions. Critically, this axiom also does not hold for the interaction method proposed by Tsang et al. (2017), which looks at the first layer of a feed-forward neural network: two networks may represent exactly the same function but differ greatly in their first layer.
This axiom trivially holds for Integrated Hessians, since it holds for Integrated Gradients and Integrated Hessians is defined entirely in terms of Integrated Gradients. The axiom is nonetheless desirable: without it, attributions and interactions could encode information about unimportant aspects of model structure rather than the actual decision surface of the model.
Integrated Gradients satisfies an axiom called linearity, which can be described as follows. Given two networks $f$ and $g$, consider the output of their weighted ensemble, $af(x) + bg(x)$. Then the attribution $\phi_i(x; af + bg)$ of the weighted ensemble equals the weighted sum of attributions $a\phi_i(x; f) + b\phi_i(x; g)$ for all features $i$ and samples $x$. This axiom is desirable because it preserves any linearity present within a network, and allows easy computation of attributions for network ensembles.
We can generalize linearity to interactions using the interaction linearity axiom: $\Gamma_{i,j}(x; af + bg) = a\Gamma_{i,j}(x; f) + b\Gamma_{i,j}(x; g)$ for any $i, j$ and all points $x$. Given that $\Gamma_{i,j}$ is a composition of the linear functions $\phi_i$ and $\phi_j$ (viewed as functions of the parameterized networks $f$ and $g$), it is itself a linear function of the networks, and therefore Integrated Hessians satisfies interaction linearity.
We say that two features $x_i$ and $x_j$ are symmetric with respect to $f$ if swapping them does not change the output of $f$ anywhere; that is, $f(\ldots, x_i, \ldots, x_j, \ldots) = f(\ldots, x_j, \ldots, x_i, \ldots)$. The original paper shows that Integrated Gradients is symmetry-preserving: if $x_i$ and $x_j$ are symmetric with respect to $f$, and if $x_i = x_j$ and $x'_i = x'_j$ for some input $x$ and baseline $x'$, then $\phi_i(x) = \phi_j(x)$. We can make the appropriate generalization to interaction values: if the same conditions hold, then $\Gamma_{k,i}(x) = \Gamma_{k,j}(x)$ for any other feature $k$. This axiom holds since, again, Integrated Gradients is symmetry-preserving: $\phi_k$ is itself symmetric in $x_i$ and $x_j$, so applying $\phi_i$ and $\phi_j$ to it gives equal values. This axiom is desirable because it says that if two features are functionally equivalent to a model, then they must interact the same way with respect to that model.
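To compute Integrated Gradients in practice, Sundararajan et al. (2017) introduce the following discrete sum approximation:
$$\phi_i(x) \approx (x_i - x'_i)\sum_{\ell=1}^{m} \frac{\partial f\left(x' + \frac{\ell}{m}(x - x')\right)}{\partial x_i} \times \frac{1}{m} \tag{28}$$
where $m$ is the number of points used to approximate the integral. To compute Integrated Hessians, we introduce a similar discrete sum approximation:
$$\Gamma_{i,j}(x) \approx (x_i - x'_i)(x_j - x'_j)\sum_{\ell=1}^{k}\sum_{p=1}^{m} \frac{\ell}{k}\,\frac{p}{m}\,\frac{\partial^2 f\left(x' + \frac{\ell}{k}\frac{p}{m}(x - x')\right)}{\partial x_i\, \partial x_j} \times \frac{1}{km} \tag{29}$$
Typically, it is easiest to compute this quantity when $k = m$, so that the number of samples drawn is a perfect square. When a non-square number of samples is desired, we can generate sample points from the product distribution of two uniform distributions such that their number is the smallest perfect square at least as large as the desired number of samples, and index the sorted samples appropriately to get the desired number. The above formula omits the first-order term in $\Gamma_{i,i}(x)$, but it can be computed using the same principle.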
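To make the discrete approximation concrete, here is a minimal pure-Python sketch of equation 29 with $k = m$, using right Riemann sums and also accumulating the first-order term of the diagonal entries that the formula omits. The one-unit softplus model, its weights, and the helper names are hypothetical choices for illustration, and gradients and Hessians are written out analytically rather than computed with an autodiff library. The interaction completeness error should shrink roughly in proportion to $1/k$.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus(z):
    # numerically stable log(1 + e^z)
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

# Hypothetical one-unit model f(z) = softplus(W . z) with analytic derivatives.
W = [1.0, -2.0, 0.5]

def f(z):
    return softplus(sum(wi * zi for wi, zi in zip(W, z)))

def grad_f(z):
    s = sigmoid(sum(wi * zi for wi, zi in zip(W, z)))
    return [wi * s for wi in W]

def hess_f(z):
    s = sigmoid(sum(wi * zi for wi, zi in zip(W, z)))
    c = s * (1.0 - s)
    return [[wi * wj * c for wj in W] for wi in W]

def integrated_hessians(x, baseline, grad_fn, hess_fn, k):
    """Right-Riemann version of equation 29 with k = m, plus the
    first-order term that only appears on the diagonal."""
    d = len(x)
    delta = [xi - bi for xi, bi in zip(x, baseline)]
    gamma = [[0.0] * d for _ in range(d)]
    for l in range(1, k + 1):
        for p in range(1, k + 1):
            t = (l / k) * (p / k)  # product of the two path coordinates
            z = [bi + t * di for bi, di in zip(baseline, delta)]
            g = grad_fn(z)
            h = hess_fn(z)
            for i in range(d):
                gamma[i][i] += delta[i] * g[i] / (k * k)  # first-order term
                for j in range(d):
                    gamma[i][j] += delta[i] * delta[j] * t * h[i][j] / (k * k)
    return gamma

x = [1.0, 0.5, -1.0]
baseline = [0.0, 0.0, 0.0]
gamma = integrated_hessians(x, baseline, grad_f, hess_f, k=200)
completeness_error = abs(sum(map(sum, gamma)) - (f(x) - f(baseline)))
print(f"interaction completeness error: {completeness_error:.2e}")
```

Sharing one grid for both path variables is what makes $k = m$ convenient: every evaluation point is indexed by a pair $(\ell, p)$, so the cost is $k^2$ gradient and Hessian evaluations.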
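Expected Hessians has a similar, if slightly simpler, form:
$$\Gamma^{\mathrm{EH}}_{i,j}(x) \approx \frac{1}{m}\sum_{\ell=1}^{m} \left(x_i - x'^{(\ell)}_i\right)\left(x_j - x'^{(\ell)}_j\right) \zeta^{(\ell)}\,\frac{\partial^2 f\left(x'^{(\ell)} + \zeta^{(\ell)}(x - x'^{(\ell)})\right)}{\partial x_i\, \partial x_j} \tag{30}$$
where $x'^{(\ell)}$ is the $\ell$th baseline sampled from the data distribution and $\zeta^{(\ell)}$ is the $\ell$th sample from the product distribution of two uniform distributions. We find that, in general, fewer than 300 samples are required for any given problem to approximately satisfy interaction completeness. For most problems, far fewer than 300 samples suffice (e.g., around 50), although this is model and data dependent: larger models and higher-dimensional data generally require more samples than smaller models and lower-dimensional data.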
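A minimal Monte Carlo sketch of the Expected Hessians estimate in equation 30 follows, restricted to the off-diagonal entries (the diagonal additionally needs the first-order term, as in equation 21). The assumptions are ours: `background` is a small list of rows standing in for the data distribution, `hess_fn` is a hypothetical callback returning the model's Hessian at a point, and $\zeta$ is drawn as the product of two independent $U(0, 1)$ samples from Python's `random` module.

```python
import random

def expected_hessians_offdiag(x, background, hess_fn, n_samples, seed=0):
    """Monte Carlo estimate of equation 30 for the off-diagonal entries (i != j).

    Each sample draws a baseline x' from `background` (stand-in for D) and
    zeta = alpha * beta with alpha, beta ~ U(0, 1), then accumulates
    (x_i - x'_i)(x_j - x'_j) * zeta * d^2 f / dx_i dx_j at x' + zeta (x - x').
    """
    rng = random.Random(seed)
    d = len(x)
    gamma = [[0.0] * d for _ in range(d)]
    for _ in range(n_samples):
        xb = rng.choice(background)
        zeta = rng.random() * rng.random()
        delta = [xi - bi for xi, bi in zip(x, xb)]
        z = [bi + zeta * di for bi, di in zip(xb, delta)]
        h = hess_fn(z)
        for i in range(d):
            for j in range(d):
                if i != j:
                    gamma[i][j] += delta[i] * delta[j] * zeta * h[i][j] / n_samples
    return gamma

# Toy check: f(z) = 2 * z_0 * z_1 has constant Hessian [[0, 2], [2, 0]], so the
# exact value is 2 * E[zeta] * mean((x_0 - b_0)(x_1 - b_1)), with E[zeta] = 1/4.
def hess_f(z):
    return [[0.0, 2.0], [2.0, 0.0]]

background = [[0.0, 0.0], [1.0, -1.0], [-0.5, 0.5]]
x = [1.0, 2.0]
gamma = expected_hessians_offdiag(x, background, hess_f, n_samples=20000)
exact = 2.0 * 0.25 * sum((x[0] - b[0]) * (x[1] - b[1]) for b in background) / len(background)
print(gamma[0][1], exact)
```

Drawing $\alpha$, $\beta$, and the baseline jointly for every sample is what lets a single loop replace the nested Riemann sums of equation 29.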
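For a one-layer neural network with softplus_β non-linearity, $f_\beta(x) = \mathrm{softplus}_\beta(w^{\top} x)$ with $\mathrm{softplus}_\beta(z) = \frac{1}{\beta}\log(1 + e^{\beta z})$, and $d$ input features, we can bound the number of interpolation points $k$ needed to approximate the Integrated Hessians to a given error tolerance $\epsilon$.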
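As pointed out in Sundararajan et al. (2017) and Sturmfels et al. (2020), completeness can be used to assess the convergence of the approximation. We first show that decreasing $\beta$ improves convergence for Integrated Gradients. In order to accurately calculate the Integrated Gradients value for a feature $i$, we want to be able to bound the error between the approximate computation and the exact value. For the network $f_\beta(x) = \mathrm{softplus}_\beta(w^{\top} x)$, the exact value is given as:
$$\phi_i(x) = (x_i - x'_i)\int_{\alpha=0}^{1} \frac{\partial f_\beta\left(x' + \alpha(x - x')\right)}{\partial x_i}\, d\alpha \tag{31}$$
To simplify notation, we can define the partial derivative that we want to integrate over in the $i$th coordinate as $g_i$:
$$g_i(\alpha) = \frac{\partial f_\beta\left(x' + \alpha(x - x')\right)}{\partial x_i}, \qquad \phi_i(x) = (x_i - x'_i)\int_{0}^{1} g_i(\alpha)\, d\alpha \tag{32}$$
Since $g_i$ is monotonic along the path for the single-layer network with softplus activation (we take it to be non-decreasing; otherwise the two bounds below swap), the exact integral is lower bounded by the left Riemann sum $L_k$:
$$L_k = \frac{1}{k}\sum_{\ell=0}^{k-1} g_i\!\left(\tfrac{\ell}{k}\right) \le \int_{0}^{1} g_i(\alpha)\, d\alpha \tag{33}$$
and upper bounded by the right Riemann sum $R_k$:
$$\int_{0}^{1} g_i(\alpha)\, d\alpha \le R_k = \frac{1}{k}\sum_{\ell=1}^{k} g_i\!\left(\tfrac{\ell}{k}\right) \tag{34}$$
We can then bound the magnitude of the error between the Riemann sum and the true integral by the difference between the right and left sums, which telescopes:
$$\left|R_k - \int_{0}^{1} g_i(\alpha)\, d\alpha\right| \le R_k - L_k = \frac{1}{k}\left(g_i(1) - g_i(0)\right) \tag{35}$$
By the mean value theorem, $g_i(1) - g_i(0) = g_i'(c)$ for some $c \in (0, 1)$. Therefore:
$$\left|R_k - \int_{0}^{1} g_i(\alpha)\, d\alpha\right| \le \frac{1}{k}\left|g_i'(c)\right| \le \frac{1}{k}\max_{\alpha \in [0,1]}\left|g_i'(\alpha)\right| \tag{36}$$
Rewriting in terms of the original function, we have:
$$\left|R_k - \int_{0}^{1} g_i(\alpha)\, d\alpha\right| \le \frac{1}{k}\max_{\alpha \in [0,1]}\left|(x - x')^{\top}\,\nabla_x \frac{\partial f_\beta}{\partial x_i}\left(x' + \alpha(x - x')\right)\right| \tag{37}$$
We can then consider the gradient vector of $\partial f_\beta / \partial x_i$, where $\sigma$ is the logistic sigmoid (the derivative of $\mathrm{softplus}_\beta(z)$ is $\sigma(\beta z)$):
$$\nabla_x \frac{\partial f_\beta(x)}{\partial x_i} = \beta\, w_i\, \sigma(\beta w^{\top} x)\left(1 - \sigma(\beta w^{\top} x)\right) w \tag{38}$$
Each coordinate is maximized in magnitude at the all-zeros input vector, where it takes a maximum value of $\beta |w_i w_j| / 4$. We can therefore bound the error in convergence as:
$$\left|R_k - \int_{0}^{1} g_i(\alpha)\, d\alpha\right| \le \frac{\beta\,|w_i|}{4k}\sum_{j=1}^{d} |w_j|\,|x_j - x'_j| \tag{39}$$
Ignoring the dependency on path length and the magnitude of the weights of the neural network, we see that the number of interpolation points needed to reach an error tolerance $\epsilon$ satisfies:
$$k \in O\!\left(\frac{\beta}{\epsilon}\right) \tag{40}$$
This demonstrates that the number of interpolation points necessary to achieve a set error rate decreases as the activation function is smoothed (as $\beta$ is decreased). While this proof only bounds the error in the approximation of the integral for a single feature, we get the error in completeness by multiplying by an additional factor of $d$ features.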
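The chain of inequalities from equation 35 to equation 39 can be checked numerically for a single softplus unit. In this sketch (our own; the weights, input, and helper names are hypothetical), the baseline is all zeros, so the exact Integrated Gradients integral has the closed form $w_i(\mathrm{softplus}_\beta(w^\top x) - \mathrm{softplus}_\beta(0)) / w^\top x$, which requires $w^\top x \neq 0$. For each $\beta$ it reports the actual right-Riemann error, the right-minus-left gap, and the $\beta$-dependent bound.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus_beta(z, beta):
    bz = beta * z
    return (max(bz, 0.0) + math.log1p(math.exp(-abs(bz)))) / beta

def ig_riemann_check(w, x, i, beta, k):
    """For f(z) = softplus_beta(w . z) and an all-zeros baseline, return
    (|R_k - integral|, |R_k - L_k|, bound from equation 39) for feature i."""
    s = sum(wj * xj for wj, xj in zip(w, x))  # w . x, assumed nonzero

    def g(a):
        # g_i(alpha) = w_i * sigmoid(beta * alpha * w . x)
        return w[i] * sigmoid(beta * a * s)

    right_sum = sum(g(l / k) for l in range(1, k + 1)) / k
    exact = w[i] * (softplus_beta(s, beta) - softplus_beta(0.0, beta)) / s
    gap = abs(g(1.0) - g(0.0)) / k  # |R_k - L_k| after telescoping
    bound = beta * abs(w[i]) * sum(abs(wj * xj) for wj, xj in zip(w, x)) / (4 * k)
    return abs(right_sum - exact), gap, bound

w = [1.0, -2.0, 0.5]
x = [2.0, 0.5, 1.0]
results = [(beta, *ig_riemann_check(w, x, 0, beta, k=50)) for beta in (0.5, 1.0, 5.0, 20.0)]
for beta, err, gap, bound in results:
    print(f"beta={beta:5.1f}  error={err:.2e}  |R-L|={gap:.2e}  bound={bound:.2e}")
```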
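We can extend the same proof to the Integrated Hessians values. We first consider the error for estimating the off-diagonal terms $\Gamma_{i,j}(x)$, $i \neq j$, of the softplus network $f_\beta$. Writing $\gamma$ for the second path variable (to avoid a clash with the softplus parameter $\beta$), the true value we are trying to approximate is given as:
$$\Gamma_{i,j}(x) = (x_i - x'_i)(x_j - x'_j)\int_{\gamma=0}^{1}\int_{\alpha=0}^{1} \alpha\gamma\,\frac{\partial^2 f_\beta\left(x' + \alpha\gamma(x - x')\right)}{\partial x_i\, \partial x_j}\, d\alpha\, d\gamma \tag{41}$$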
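For the sake of notation, we write $h(\alpha, \gamma) = \alpha\gamma\, \partial^2 f_\beta\left(x' + \alpha\gamma(x - x')\right) / \partial x_i \partial x_j$ for the integrand. Assuming that we are integrating from the all-zeros baseline, as suggested in Sundararajan et al. (2017), since $h$ is monotonic on either interval from the 0 baseline, we can again bound the error in the double integral by the magnitude of the difference between the right and left Riemann sums:
$$\left|R_k - \int_{0}^{1}\!\!\int_{0}^{1} h(\alpha, \gamma)\, d\alpha\, d\gamma\right| \le \left|R_k - L_k\right|, \qquad R_k = \frac{1}{k^2}\sum_{\ell=1}^{k}\sum_{p=1}^{k} h\!\left(\tfrac{\ell}{k}, \tfrac{p}{k}\right), \qquad L_k = \frac{1}{k^2}\sum_{\ell=0}^{k-1}\sum_{p=0}^{k-1} h\!\left(\tfrac{\ell}{k}, \tfrac{p}{k}\right) \tag{42}$$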
(43) |
We can then again use monotonicity over the interval to say that , which gives us:
(44) |
By the mean value theorem, we know that for some , . Substituting gives us:
(45) |
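We can then consider the elements of the gradient vector of $\partial^2 f_\beta / \partial x_i \partial x_j$, where $\sigma$ is the logistic sigmoid and $u = w^{\top} x$:
$$\frac{\partial^3 f_\beta(x)}{\partial x_i\, \partial x_j\, \partial x_r} = \beta^2 w_i w_j w_r\, \sigma(\beta u)\left(1 - \sigma(\beta u)\right)\left(1 - 2\sigma(\beta u)\right) \tag{46}$$
For the univariate version of each coordinate, we can maximize the function by taking the derivative with respect to $u$ and setting it equal to 0:
$$\frac{d}{du}\left[\sigma(\beta u)\left(1 - \sigma(\beta u)\right)\left(1 - 2\sigma(\beta u)\right)\right] = \beta\,\sigma(\beta u)\left(1 - \sigma(\beta u)\right)\left(6\sigma(\beta u)^2 - 6\sigma(\beta u) + 1\right) = 0 \tag{47}$$
Since $\sigma(\beta u)(1 - \sigma(\beta u)) > 0$, this equation holds only when $6\sigma(\beta u)^2 - 6\sigma(\beta u) + 1 = 0$, and we can solve it by finding the roots of this quadratic equation, which occur when $\sigma(\beta u) = (3 \pm \sqrt{3})/6$. When we plug these back in, we find that the absolute value of the function in that coordinate takes a maximum value of $\beta^2 |w_i w_j w_r| / (6\sqrt{3})$. Therefore, for a given set of fixed weights of the network, the coordinate-wise maximum magnitude of the gradient of $\partial^2 f_\beta / \partial x_i \partial x_j$ scales as $\beta^2$, and the number of interpolation points necessary to reach a desired level of error in approximating the double integral decreases as $\beta$ is decreased. Again ignoring the fixed weights and path length, the number of interpolation points necessary is bounded by:
$$k \in O\!\left(\frac{\beta^2}{\epsilon}\right) \tag{48}$$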
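The two sigmoid constants behind these bounds can be confirmed numerically with $\beta = 1$ ($\beta$ only rescales the argument and contributes the $\beta$ and $\beta^2$ prefactors): $\sigma(u)(1 - \sigma(u))$ peaks at $1/4$, and $|\sigma(u)(1 - \sigma(u))(1 - 2\sigma(u))|$ peaks at $1/(6\sqrt{3}) \approx 0.0962$ where $\sigma(u) = (3 \pm \sqrt{3})/6$. A brute-force grid search in pure Python (our own sketch) suffices.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def second_derivative_factor(u):
    # sigma'(u) = sigma(u) * (1 - sigma(u))
    s = sigmoid(u)
    return s * (1.0 - s)

def third_derivative_factor(u):
    # sigma''(u) = sigma(u) * (1 - sigma(u)) * (1 - 2 * sigma(u))
    s = sigmoid(u)
    return s * (1.0 - s) * (1.0 - 2.0 * s)

grid = [i / 10000.0 for i in range(-100000, 100001)]  # u in [-10, 10], step 1e-4
max_second = max(second_derivative_factor(u) for u in grid)
max_third = max(abs(third_derivative_factor(u)) for u in grid)
u_star = max(grid, key=lambda u: abs(third_derivative_factor(u)))

print(max_second)                                   # peak of sigma'
print(max_third, 1.0 / (6.0 * math.sqrt(3.0)))      # peak of |sigma''| vs closed form
print(sigmoid(u_star), (3 - math.sqrt(3)) / 6, (3 + math.sqrt(3)) / 6)
```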
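For the $\Gamma_{i,i}$ terms (main-effect terms), the error bound has an additional additive term. This is because the main effect contains an added term equal to:
$$(x_i - x'_i)\int_{\gamma=0}^{1}\int_{\alpha=0}^{1} \frac{\partial f_\beta\left(x' + \alpha\gamma(x - x')\right)}{\partial x_i}\, d\alpha\, d\gamma \tag{49}$$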
When we bound the error in this approximate integral by the difference between the double left sum and double right sum, we get that:
(50) |
Following the exact same steps as in Equation 35 through Equation 39, we can then show that the bound on the error of the on-diagonal terms has an additional term that is $O(\beta/k)$. Due to the axiom of interaction completeness, the error bound for the entire convergence can be obtained by adding up all of the individual terms, incurring another factor of $d^2$ (the number of terms $\Gamma_{i,j}$) in the bound.
∎
In addition to theoretically analyzing the effect of smoothing the activation functions of a single-layer neural network on the convergence of the approximate calculation of Integrated Gradients and Integrated Hessians, we also wanted to empirically analyze the same phenomenon in deeper networks. We assessed this by creating two networks: one with 5 hidden layers of 50 nodes, and a second with 10 hidden layers of 50 nodes. These networks were randomly initialized using the Xavier uniform initialization scheme (Glorot and Bengio, 2010). We created 10 samples to explain, each with 100 features drawn at random from the standard normal distribution. To evaluate the convergence of our approximate Integrated Hessians values, we plot the interaction completeness error: the difference between the sum of the Integrated Hessians values and the difference between the function output at a sample and the function output at the all-zeros baseline. We plot the completeness error as a fraction of the magnitude of the function output. As we decrease the value of β, we smooth the activations, and we can see that the number of interpolations required to converge decreases (see Figure 9 and Figure 10). We note that the randomly initialized weights of each network are held constant; the only thing changed is the value of β in the activation function.
As mentioned in the main text, we download pre-trained weights for DistilBERT, a pre-trained language model introduced in Sanh et al. (2019), from the HuggingFace Transformers library (Wolf et al., 2019). We fine-tune the model on the Stanford Sentiment Treebank dataset introduced by Socher et al. (2013). We fine-tune for 3 epochs using a batch size of 32 and a learning rate of 0.00003. We use a max sequence length of 128 tokens, and the Adam algorithm for optimization (Kingma and Ba, 2014). We tokenize using the HuggingFace built-in tokenizer, which does so in an uncased fashion. We did not search for these hyper-parameters; rather, they were the defaults presented for fine-tuning in the HuggingFace repository. We find that they work adequately for our purposes, and did not attempt to search through more hyper-parameters.
The convolutional neural network that we compare to in the main text is one we train from scratch on the same dataset. We randomly initialize 32-dimensional embeddings and use a max sequence length of 52. First, we apply dropout to the embeddings with dropout rate 0.5. The network itself is composed of 1D convolutions with 32 filters of size 3 and 32 filters of size 8. Each filter size is applied separately to the embedding layer, after which max pooling with a stride of 2 is applied; the outputs of both convolutions are then concatenated and fed through a dropout layer with a dropout rate of 0.5 during training. A hidden layer of size 50 follows the dropout, and is finally followed by a linear layer that generates a scalar prediction, to which the sigmoid function is applied.
We train with a batch size of 128 for 2 epochs and use a learning rate of 0.001. We optimize using the Adam algorithm with the default hyper-parameters (Kingma and Ba, 2014). Since this model was not pre-trained on a large language corpus and lacks the expressive power of a deep transformer, it is unable to capture patterns, such as negation, that a fine-tuned DistilBERT captures.
In order to generate attributions and interactions, we use Integrated Gradients and Integrated Hessians with the zero-embedding (the embedding produced by the all-zeros vector, which normally encodes the padding token) as a baseline. Because the embedding lookup is not differentiable with respect to the discrete token inputs, we generate attributions and interactions to the word embeddings and then sum over the embedding dimension to get word-level attributions and interactions, as done in Sundararajan et al. (2017). When computing attributions and interactions, we use 256 background samples. Because DistilBERT uses the GeLU activation function (Ramachandran et al., 2017), which has continuous first and second partial derivatives, there is no need to use the softplus replacement. When we plot interactions, we omit the main-effect terms in order to better visualize the interactions between words.

Here we include additional examples of interactions learned on the sentiment analysis task. First, we expand upon the idea of saturation in natural language, displayed in Figure 11. We display interactions learned by a fine-tuned DistilBERT on the following sentences: “a bad movie” (negative with 0.9981 confidence), “a bad, terrible movie” (negative with 0.9983 confidence), “a bad, terrible, awful movie” (negative with 0.9984 confidence) and “a bad, terrible, awful, horrible movie” (negative with 0.9984 confidence). The confidence of the network saturates: the network output only gets so negative before it begins to flatten. However, the number of negative adjectives in the sentence increases. This means a sensible network would spread the same amount of credit (because the attributions sum to the saturated output) across a larger number of negative words, which is exactly what DistilBERT does. In turn, each word gets less negative attribution than it would if it were on its own. Thus, the negative words have positive interaction effects, which is exactly what we see in the figure.
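The word-level aggregation described above (summing over the embedding dimension) can be sketched as follows. This is a minimal illustration under our own assumptions: embedding-level interactions are stored as a square matrix indexed by flattened (token, embedding dimension) pairs in token-major order, and the function names are hypothetical. Because the aggregation only sums entries, totals are preserved, so completeness carries over from the embedding level to the word level.

```python
def word_level_attributions(token_attr):
    """token_attr[t][e] is the attribution to embedding dimension e of token t;
    summing over e gives one attribution per word."""
    return [sum(row) for row in token_attr]

def word_level_interactions(emb_inter, n_tokens, emb_dim):
    """emb_inter is an (n_tokens * emb_dim) x (n_tokens * emb_dim) matrix indexed by
    flattened (token, dim) pairs; sum each emb_dim x emb_dim block into one word pair."""
    out = [[0.0] * n_tokens for _ in range(n_tokens)]
    for a in range(n_tokens):
        for b in range(n_tokens):
            out[a][b] = sum(
                emb_inter[a * emb_dim + d][b * emb_dim + e]
                for d in range(emb_dim)
                for e in range(emb_dim)
            )
    return out

# Toy example: 2 tokens with 2-dimensional embeddings (entries 1..16).
emb_inter = [[float(4 * r + c + 1) for c in range(4)] for r in range(4)]
interactions = word_level_interactions(emb_inter, n_tokens=2, emb_dim=2)
print(interactions)
print(word_level_attributions([[0.5, -0.25], [1.0, 2.0]]))
```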
In Figure 12, we give another example of the full interaction matrix on a sentence from the validation set. In Figure 13, we give an example of how explaining the importance of a particular word can help us understand whether that word is important because of its main effect or because of its surrounding context. We show additional examples from the validation set in Figures 14, 15, 16, 17, and 18. We note that while some interactions make intuitive sense to humans (“better suited” being negative or “good script” being positive), many other interactions are less intuitive. These interactions may indicate that the Stanford Sentiment Treebank dataset does not fully capture the expressive power of language (e.g., it doesn’t have enough samples to fully represent all of the possible interactions in language), or that the model has learned higher-order effects that cannot be explained by pairwise interactions alone.
As mentioned in the main text, the dataset we use has 13 associated features. The list of features, which we reproduce here, is from Detrano et al. (1989), the original paper introducing the dataset:
Age of patient (mean: 54.5 years, standard deviation: 9.0)
Gender (202 male, 96 female)
Resting systolic blood pressure (131.6 ± 17.7 mm Hg)
Cholesterol (246.9 ± 51.9 mg/dl)
Whether or not a patient’s fasting blood sugar was above 120 mg/dl (44 yes)
Maximum heart rate achieved during exercise (149.5 ± 23.0 bpm)
Whether or not a patient has exercise-induced angina (98 yes)
Exercise-induced ST-segment depression (1.05 ± 1.16 mm)
Number of major vessels appearing to contain calcium as revealed by cinefluoroscopy (175 patients with 0, 65 with 1, 38 with 2, 20 with 3)
Type of pain a patient experienced, if any (49 experienced typical anginal pain, 84 experienced atypical anginal pain, 23 experienced non-anginal pain, and 142 patients experienced no chest pain)
Slope of peak exercise ST segment (21 patients had upsloping segments, 138 had flat segments, 139 had downsloping segments)
Whether or not a patient had thallium defects as revealed by scintigraphy (2 patients with no information available, 18 with fixed defects, 115 with reversible defects, and 163 with no defects)
Classification of resting electrocardiogram (146 with normal resting ECG, 148 with an ST-T wave abnormality, and 4 with probable or definite left ventricular hypertrophy)
On this task, we use a two-layer neural network with 128 and 64 hidden units, respectively, with softplus activation after each layer. We optimize using gradient descent (processing the entire training set in a single batch) with an initial learning rate of 0.1 that decays exponentially with a rate of 0.99 after each epoch. We use Nesterov momentum (Sutskever et al., 2013). After training for 200 epochs, the network achieves a held-out accuracy of 0.8667 with a 0.8214 true positive rate and a 0.9062 true negative rate. We note that the hyper-parameters chosen here were not carefully tuned on a validation set; they were simply those that seemed to converge to a reasonable performance on the training set. Our focus is not state-of-the-art prediction or comparing model performances, but rather interpreting patterns a reasonable model learns.

To generate attributions and interactions for this dataset, we use Expected Gradients and Expected Hessians with the training set forming the background distribution. We use 200 samples to compute both attributions and interactions; this number is probably larger than necessary, but it was easy to compute due to the small size of the dataset.
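For reference, a minimal Monte Carlo sketch of Expected Gradients as used here: each sample draws a baseline from the background set (the training set in this experiment) and an interpolation constant $\alpha \sim U(0, 1)$, then averages $(x_i - x'_i)\,\partial f / \partial x_i$ at the interpolated point. The function name and the `grad_fn` callback are hypothetical; with a single-row background and a linear model the estimate is exact, which makes the toy check easy to verify.

```python
import random

def expected_gradients(x, background, grad_fn, n_samples, seed=0):
    """Average of (x_i - x'_i) * df/dx_i at x' + alpha (x - x'),
    with x' drawn from `background` and alpha ~ U(0, 1)."""
    rng = random.Random(seed)
    d = len(x)
    phi = [0.0] * d
    for _ in range(n_samples):
        xb = rng.choice(background)
        alpha = rng.random()
        z = [b + alpha * (xi - b) for xi, b in zip(x, xb)]
        g = grad_fn(z)
        for i in range(d):
            phi[i] += (x[i] - xb[i]) * g[i] / n_samples
    return phi

# Toy linear model f(z) = W . z, whose gradient is the constant vector W.
W = [0.5, -1.0, 2.0]

def grad_linear(z):
    return list(W)

x = [2.0, 0.0, 3.0]
background = [[1.0, 1.0, 1.0]]
phi = expected_gradients(x, background, grad_linear, n_samples=200)
print(phi)  # close to [0.5, 1.0, 4.0]; sums to f(x) - f(x') = 5.5
```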
Figure 19 shows which features were most important for predicting heart disease, aggregated over the entire dataset, as well as the trend of importance values. Interestingly, the model learns some counterintuitive trends: if a patient doesn’t experience chest pain, they are more likely to have heart disease than if they experience anginal chest pain! This could indicate problems with the way certain features were encoded, or perhaps dataset bias. Figure 20 demonstrates an interaction learned by the network between maximum heart rate achieved and gender, and Figure 21 demonstrates an interaction between exercise-induced ST-segment depression and number of major vessels appearing to contain calcium.
As mentioned in the main text, our dataset consists of 12,362 samples (available from http://www.vizome.org/). Each sample consists of the measured response of a 2-drug pair tested in the cancer cells of a patient (Tyner et al., 2018). The 2-drug combination was described with both a drug identity indicator and a drug target indicator. For each sample, the drug identity indicator is a vector in which each element represents one of the 46 anti-cancer drugs present in the data, and takes a value of 0 if the corresponding drug is not present in the combination and a value of 1 if the corresponding drug is present in the combination. Therefore, for each sample, the drug identity indicator has 44 elements equal to 0 and 2 elements equal to 1. This is the most compact possible representation for the 2-drug combinations. The drug target indicator is a vector in which each element represents one of the 112 unique molecular targets of the anti-cancer drugs in the dataset. Each entry in this vector is equal to 0 if neither drug targets the given molecule, 1 if one of the drugs in the combination targets the given molecule, and 2 if both drugs target the molecule. The targets were compiled using the information available on DrugBank (Wishart et al., 2018). The ex vivo samples of each patient’s cancer were described using gene expression levels for each gene in the transcriptome, as measured by RNA-seq. Before training, the data were split into two parts: 80% of samples were used for model training, and the remaining 20% were used as a held-out validation set to determine when the model had been trained for a sufficient number of epochs.
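The two drug encodings can be sketched as below. This is an illustrative sketch under stated assumptions: the drug names, targets, and helper functions are hypothetical toy stand-ins for the 46 drugs and 112 DrugBank-derived targets, and each target entry counts how many drugs in the pair hit that molecule.

```python
def drug_identity_indicator(pair, all_drugs):
    """1 for each drug in the 2-drug combination, 0 elsewhere."""
    return [1 if drug in pair else 0 for drug in all_drugs]

def drug_target_indicator(pair, all_targets, drug_targets):
    """Number of drugs in the pair that hit each target (0, 1, or 2)."""
    return [sum(1 for drug in pair if target in drug_targets[drug])
            for target in all_targets]

# Hypothetical toy vocabulary (the real data has 46 drugs and 112 targets).
all_drugs = ["drug_a", "drug_b", "drug_c", "drug_d"]
all_targets = ["T1", "T2", "T3"]
drug_targets = {
    "drug_a": {"T1", "T2"},
    "drug_b": {"T2"},
    "drug_c": {"T3"},
    "drug_d": set(),
}

pair = ("drug_a", "drug_b")
identity = drug_identity_indicator(pair, all_drugs)               # [1, 1, 0, 0]
targets = drug_target_indicator(pair, all_targets, drug_targets)  # [1, 2, 0]
print(identity, targets)
```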
The cancerous cells in each sample were described using RNA-seq data: measurements of the expression level of each gene in the sample. We describe here the preprocessing steps used to remove batch effects while preserving biological signal. We first converted raw transcript counts to fragments per kilobase of exon model per million mapped reads (FPKM), a measure that is known to reflect the molar amount of each transcript in the original sample better than raw counts do. FPKM achieves this by normalizing the counts for different genes according to the length of the transcript, as well as the total number of reads included in the sample (Mortazavi et al., 2008). The equation for FPKM is given as:
$$\mathrm{FPKM} = \frac{X \times 10^9}{l \times N} \tag{51}$$
where $X$ is the vector containing the number of raw counts for a particular transcript across all samples, $l$ is the effective length of that transcript, and $N$ represents the total number of counts. After converting raw counts to FPKM, we opt to consider only the protein-coding part of the transcriptome by removing all non-protein-coding transcripts from the dataset. Protein-coding transcripts were determined according to the list provided by the HUGO Gene Nomenclature Committee (https://www.genenames.org/download/statistics-and-files/). In addition to non-protein-coding transcripts, we also removed any transcript that was not observed in of samples. Transcripts are then transformed and standardized to zero mean and unit variance. Finally, the ComBat tool (a robust empirical Bayes regression implemented as part of the sva R package) was used to correct for batch effects
(Leek and Storey, 2007).

To model the data, we combined the successful approaches of Preuer et al. (2017) and Hao et al. (2018). Our network architecture is a simple feed-forward network (Figure 22), as in Preuer et al. (2017), with two hidden layers of 500 and 250 nodes, respectively, both with Tanh activation. In order to improve performance and interpretability, we followed Hao et al. (2018) in learning a pathway-level embedding of the gene expression data. The RNA-seq data were sparsely connected to a layer of nodes, where each node corresponded to a single pathway from KEGG, BioCarta, or Reactome (Kanehisa and others, 2002; Nishimura, 2001; Croft et al., 2014). We made this embedding non-linear by following the sparse connections with a Tanh activation function. The non-linear pathway embeddings were then concatenated to the drug identity indicators and the drug target indicators, and these served as inputs to the densely connected layers.

We trained the network to optimize a mean squared error loss function, and used the Adam optimizer in PyTorch with default hyperparameters and a learning rate equal to
(Kingma and Ba, 2014). We stopped the training when mean squared error on the held-out validation set failed to improve over 10 epochs, and found that the network reached an optimum at epochs. For the sake of easier calculation and more human-intuitive attribution, we attribute the model’s output to the layer with the pathway embedding and drug inputs, rather than to the raw RNA-seq features and drug inputs (see