Interpretability or lack thereof can limit the adoption of machine learning methods in decision-critical (e.g., medical or legal) domains. Ensuring interpretability would also contribute to other pertinent criteria such as fairness, privacy, or causality (Doshi-Velez2017Roadmap). Our focus in this paper is on complex self-explaining models where interpretability is built-in architecturally and enforced through regularization. Such models should satisfy three desiderata for interpretability: explicitness, faithfulness, and stability, where, for example, stability ensures that similar inputs yield similar explanations. Most post-hoc interpretability frameworks are not stable in this sense, as we show in detail in Section 5.4. High modeling capacity is often necessary for competitive performance. For this reason, recent work on interpretability has focused on producing a posteriori
explanations for performance-driven deep learning approaches. The interpretations are derived locally, around each example, on the basis of limited access to the inner workings of the model, such as gradients or reverse propagation (Bach2015Pixel-wise; selvaraju2016grad), or through oracle queries to estimate simpler models that capture the local input-output behavior (Ribeiro2016Why; AlvarezMelis2017Causal; Lundberg2017Unified). Known challenges include the definition of locality (e.g., for structured data), identifiability (Li2018Deep) and computational cost (with some of these methods requiring a full-fledged optimization subroutine (yosinski2015understanding)). However, point-wise interpretations generally do not compare explanations obtained for nearby inputs, leading to unstable and often contradictory explanations (AlvarezMelis2018Robustness). A posteriori explanations may be the only option for already-trained models. Otherwise, we would ideally design the models from the start to provide human-interpretable explanations of their predictions. In this work, we build highly complex interpretable models bottom up, maintaining the desirable characteristics of simple linear models in terms of features and coefficients, without limiting performance. For example, to ensure stability (and, therefore, interpretability), coefficients in our model vary slowly around each input, keeping it effectively a linear model, albeit locally. In other words, our model operates as a simple interpretable model locally (allowing for point-wise interpretation) but not globally (which would entail sacrificing capacity). We achieve this with a regularization scheme that ensures our model not only looks like a linear model, but (locally) behaves like one. Our main contributions in this work are:
A rich class of interpretable models where the explanations are intrinsic to the model
Three desiderata for explanations together with an optimization procedure that enforces them
Quantitative metrics to empirically evaluate whether models adhere to these three principles, and experiments showing the advantage of the proposed self-explaining models under these metrics
2 Interpretability: linear and beyond
To motivate our approach, we start with a simple linear regression model and successively generalize it towards the class of self-explaining models. For input features $x_1, \ldots, x_n \in \mathbb{R}$ and associated parameters $\theta_0, \ldots, \theta_n \in \mathbb{R}$, the linear regression model is given by $f(x) = \sum_{i=1}^{n} \theta_i x_i + \theta_0$. This model is arguably interpretable for three specific reasons: (i) input features ($x_i$'s) are clearly anchored with the available observations, e.g., arising from empirical measurements; (ii) each parameter $\theta_i$ provides a quantitative positive/negative contribution of the corresponding feature $x_i$ to the predicted value; and (iii) the aggregation of feature-specific terms $\theta_i x_i$ is additive without conflating feature-by-feature interpretation of impact. We progressively generalize the model in the following subsections and discuss how this mechanism of interpretation is preserved.
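As a small numeric illustration of reasons (ii) and (iii), each term $\theta_i x_i$ can be read off on its own, and the terms plus the offset add up to the prediction. A minimal NumPy sketch with made-up values (the function names are illustrative):

```python
import numpy as np

def linear_predict(x, theta, theta0):
    """Linear regression f(x) = sum_i theta_i * x_i + theta_0."""
    return float(np.dot(theta, x) + theta0)

def linear_contributions(x, theta):
    """Feature-specific terms theta_i * x_i; sign and size give each feature's impact."""
    return np.asarray(theta, dtype=float) * np.asarray(x, dtype=float)

x = np.array([2.0, -1.0, 0.5])
theta = np.array([1.5, 3.0, -2.0])
theta0 = 0.25
contrib = linear_contributions(x, theta)   # terms 3.0, -3.0, -1.0
pred = linear_predict(x, theta, theta0)    # 3.0 - 3.0 - 1.0 + 0.25
```

Because the aggregation is a plain sum, removing feature $i$ changes the prediction by exactly its own term, which is what makes the per-feature reading unambiguous.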
2.1 Generalized coefficients
We can substantially enrich the linear model while keeping its overall structure if we permit the coefficients themselves to depend on the input $x$. Specifically, we define (offset function omitted) $f(x) = \theta(x)^\top x$, and choose $\theta$ from a complex model class $\Theta$, realized for example via deep neural networks. Without further constraints, the model is nearly as powerful as, and surely no more interpretable than, any deep neural network. However, in order to maintain interpretability, at least locally, we must ensure that for close inputs $x$ and $x'$ in $\mathbb{R}^n$, $\theta(x)$ and $\theta(x')$ do not differ significantly. More precisely, we can, for example, regularize the model in such a manner that $\nabla_x f(x) \approx \theta(x_0)$ for all $x$ in a neighborhood of $x_0$. In other words, the model acts locally, around each $x_0$, as a linear model with a vector of stable coefficients $\theta(x_0)$. The individual values $\theta(x_0)_i$ act as, and are interpretable as, coefficients of a linear model with respect to the final prediction, but adapt dynamically to the input, albeit varying more slowly than $x$. We will discuss specific regularizers so as to keep this interpretation in Section 3.
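To make the form $f(x) = \theta(x)^\top x$ concrete, here is a minimal NumPy sketch in which $\theta$ is a small two-layer tanh network with random weights. The architecture, sizes and names are assumptions for illustration only, not the parametrizer used in the paper.

```python
import numpy as np

# Hypothetical parametrizer theta: R^n -> R^n, a two-layer tanh MLP with random weights.
# In practice theta would be trained (and regularized); here it only illustrates the form.
rng = np.random.default_rng(0)
n, hidden = 4, 16
W1 = rng.normal(scale=0.5, size=(hidden, n))
b1 = np.zeros(hidden)
W2 = rng.normal(scale=0.5, size=(n, hidden))

def theta(x):
    """Input-dependent coefficients theta(x), one per input feature."""
    return W2 @ np.tanh(W1 @ x + b1)

def f(x):
    """Generalized-coefficient model f(x) = theta(x)^T x (offset omitted)."""
    return float(theta(x) @ x)

x0 = rng.normal(size=n)
x_near = x0 + 1e-3 * rng.normal(size=n)
coef_shift = np.linalg.norm(theta(x0) - theta(x_near))
# tanh is 1-Lipschitz, so theta is Lipschitz with constant at most ||W2||_2 * ||W1||_2.
lipschitz_bound = np.linalg.norm(W2, 2) * np.linalg.norm(W1, 2) * np.linalg.norm(x_near - x0)
```

Nothing in this unregularized sketch makes $\theta$ vary slowly relative to $x$ in the sense required above; that is the role of the regularizer discussed in Section 3.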
2.2 Beyond raw features – feature basis
Typical interpretable models tend to consider each variable (one feature or one pixel) as the fundamental unit of which explanations consist. However, pixels are rarely the basic units used in human image understanding; instead, we would rely on strokes and other higher-order features. We refer to these more general features as interpretable basis concepts and use them in place of raw inputs in our models. Formally, we consider functions $h(x): \mathcal{X} \to \mathcal{Z} \subseteq \mathbb{R}^k$, where $\mathcal{Z}$ is some space of interpretable atoms. Naturally, $k$ should be small so as to keep the explanations easily digestible. Alternatives for $h(\cdot)$ include: (i) subset aggregates of the input (e.g., $h(x) = Ax$ with $A$ a boolean mask matrix), (ii) predefined, pre-grounded feature extractors designed with expert knowledge (e.g., filters for image processing), (iii) prototype-based concepts, e.g., $h_i(x) = \|x - z_i\|$ for some $z_i \in \mathcal{X}$ (Li2018Deep), or (iv) learnt representations with specific constraints to ensure grounding (alshedivat2017contextual). Naturally, we can let $h(x) = x$ to recover raw-input explanations if desired. The generalized model is now:
$$f(x) = \theta(x)^\top h(x) = \sum_{i=1}^{k} \theta(x)_i \, h(x)_i \qquad (1)$$
Since each $h(x)_i$ remains a scalar, it can still be interpreted as the degree to which a particular feature is present. In turn, with constraints similar to those discussed above, $\theta(x)_i$ remains interpretable as a local coefficient. Note that the notion of locality must now take into account how the concepts, rather than the inputs, vary, since the model is interpreted as being linear in the concepts rather than in $x$.
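The concept families listed above can be written as simple functions. The NumPy sketch below assumes a fixed 0/1 mask matrix `A` for subset aggregates and a matrix `Z` whose rows are prototypes; these names, and the constant coefficient function in the usage lines, are hypothetical. It also builds the model $f(x) = \theta(x)^\top h(x)$ on top of a chosen concept map.

```python
import numpy as np

def mask_concepts(x, A):
    """Subset aggregates h(x) = A x, with A a boolean (0/1) mask matrix of shape (k, n)."""
    return A.astype(float) @ x

def prototype_concepts(x, Z):
    """Prototype-based concepts h_i(x) = ||x - z_i|| for prototypes z_i (rows of Z)."""
    return np.linalg.norm(Z - x, axis=1)

def concept_model(x, h, theta):
    """f(x) = theta(x)^T h(x): linear in the concepts, with input-dependent coefficients."""
    return float(theta(x) @ h(x))

x = np.array([1.0, 2.0, 3.0, 4.0])
A = np.array([[1, 1, 0, 0], [0, 0, 1, 1]], dtype=bool)   # two concepts: first half, second half
Z = np.array([[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]])
h_mask = mask_concepts(x, A)
h_proto = prototype_concepts(x, Z)
value = concept_model(x, lambda u: mask_concepts(u, A), lambda u: np.array([0.5, -0.1]))
```

Setting `h` to the identity recovers the raw-feature model of the previous subsection.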
2.3 Further generalization
The final generalization we propose considers how the elements $\theta(x)_i h(x)_i$ are aggregated. We can achieve a more flexible class of functions by replacing the sum in (1) by a more general aggregation function $g(z_1, \ldots, z_k)$, where $z_i := \theta(x)_i h(x)_i$. Naturally, in order for this function to preserve the desired interpretation of $\theta(x)$ in relation to $h(x)$, it should: (i) be permutation invariant, so as to eliminate higher-order uninterpretable effects caused by the relative position of the arguments, (ii) isolate the effect of individual $h(x)_i$'s in the output (e.g., avoiding multiplicative interactions between them), and (iii) preserve the sign and relative magnitude of the impact of the relevance values $\theta(x)_i$. We formalize these intuitive desiderata in the next section. Note that we can naturally extend the framework presented in this section to multivariate functions with range in $\mathcal{Y} \subseteq \mathbb{R}^m$ by considering $\theta_i : \mathcal{X} \to \mathbb{R}^m$, so that $\theta_i(x) \in \mathbb{R}^m$ is a vector corresponding to the relevance of concept $i$ with respect to each of the $m$ output dimensions. For classification, however, we are mainly interested in the explanation for the predicted class, i.e., $\theta_{\hat{y}}(x)$ for $\hat{y} = \operatorname{argmax}_y p(y \mid x)$.
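For the multivariate case, here is a sketch assuming sum aggregation per output dimension, i.e., logits $\theta(x)^\top h(x)$ with $\theta(x)$ stored as a $k \times m$ matrix. Under that assumption, the explanation for the predicted class is simply the column of the argmax output. The function name is hypothetical.

```python
import numpy as np

def predicted_class_explanation(theta_x, h_x):
    """
    theta_x: (k, m) matrix; column j holds the relevance of each concept for output j.
    h_x: (k,) concept activations.
    Returns the predicted class y_hat and its relevance vector theta_{y_hat}(x).
    """
    logits = theta_x.T @ h_x          # sum aggregation per output dimension, shape (m,)
    y_hat = int(np.argmax(logits))    # softmax is monotone, so argmax of logits = argmax of p(y|x)
    return y_hat, theta_x[:, y_hat]

theta_x = np.array([[1.0, -1.0, 0.2],
                    [0.5,  2.0, 0.1],
                    [0.0,  0.5, 0.3]])
h_x = np.array([1.0, 0.5, 2.0])
y_hat, relevance = predicted_class_explanation(theta_x, h_x)
```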
3 Self-explaining models
We now formalize the class of models obtained through subsequent generalization of the simple linear predictor in the previous section. We begin by discussing the properties we wish to impose on $\theta$ in order for it to act as coefficients of a linear model on the basis concepts $h(x)$. The intuitive notion of robustness discussed in Section 2.2 suggests using a condition bounding $\|\theta(x) - \theta(y)\|$ with $L \|h(x) - h(y)\|$ for some constant $L$. Note that this resembles, but is not exactly equivalent to, Lipschitz continuity, since it bounds $\theta$'s variation with respect to a different (and indirect) measure of change, provided by the geometry induced implicitly by $h$ on $\mathcal{X}$. Specifically,
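Definition 3.1. We say that a function $f: \mathcal{X} \subseteq \mathbb{R}^n \to \mathbb{R}^m$ is difference-bounded by $h: \mathcal{X} \subseteq \mathbb{R}^n \to \mathbb{R}^k$ if there exists $L \in \mathbb{R}$ such that $\|f(x) - f(y)\| \leq L \|h(x) - h(y)\|$ for every $x, y \in \mathcal{X}$.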
Imposing such a global condition might be undesirable in practice. The data arising in applications often lie on low-dimensional manifolds of irregular shape, so a uniform bound might be too restrictive. Furthermore, we specifically want $\theta$ to be consistent for neighboring inputs. Thus, we seek instead a local notion of stability. Analogous to the local Lipschitz condition, we propose a pointwise, neighborhood-based version of Definition 3.1:
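Definition 3.2. $f: \mathcal{X} \subseteq \mathbb{R}^n \to \mathbb{R}^m$ is locally difference bounded by $h: \mathcal{X} \subseteq \mathbb{R}^n \to \mathbb{R}^k$ if for every $x_0$ there exist $\delta > 0$ and $L \in \mathbb{R}$ such that $\|x - x_0\| < \delta$ implies $\|f(x) - f(x_0)\| \leq L \|h(x) - h(x_0)\|$.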
Note that, in contrast to Definition 3.1, this second notion of stability allows $L$ (and $\delta$) to depend on $x_0$; that is, the “Lipschitz” constant can vary throughout the space. With this, we are ready to define the class of functions which form the basis of our approach.
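Definition 3.1 cannot be verified from finitely many points, but a sample does give a lower bound on any admissible $L$. A hypothetical NumPy sketch (the function name and tolerance are assumptions):

```python
import numpy as np
from itertools import combinations

def empirical_difference_bound(theta, h, X, tol=1e-12):
    """
    Largest observed ratio ||theta(x) - theta(y)|| / ||h(x) - h(y)|| over all pairs in X.
    Any L that satisfies Definition 3.1 on X must be at least this value, so this is
    only a lower bound on the true constant. Returns inf if theta differs on a pair
    whose concepts coincide (no finite L exists then).
    """
    best = 0.0
    for x, y in combinations(X, 2):
        d_theta = np.linalg.norm(theta(x) - theta(y))
        d_h = np.linalg.norm(h(x) - h(y))
        if d_h <= tol:
            if d_theta > tol:
                return float("inf")
            continue
        best = max(best, d_theta / d_h)
    return float(best)
```

The local variant (Definition 3.2) would restrict the pairs to those within distance $\delta$ of a reference point $x_0$, letting the bound differ from point to point.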
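Definition 3.3. Let $x \in \mathcal{X} \subseteq \mathbb{R}^n$ and $\mathcal{Y} \subseteq \mathbb{R}^m$ be the input and output spaces. We say that $f: \mathcal{X} \to \mathcal{Y}$ is a self-explaining prediction model if it has the form
$$f(x) = g\big(\theta(x)_1 h(x)_1, \ldots, \theta(x)_k h(x)_k\big) \qquad (2)$$
where:
(P1) $g$ is monotone and completely additively separable
(P2) For every $z_i := \theta(x)_i h(x)_i$, $g$ satisfies $\frac{\partial g}{\partial z_i} \geq 0$
(P3) $\theta$ is locally difference bounded by $h$
(P4) $h(x)$ is an interpretable representation of $x$
(P5) $k$ is small.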
In that case, for a given input $x$, we define the explanation of $f(x)$ to be the set $\mathcal{E}_f(x) := \{(h(x)_i, \theta(x)_i)\}_{i=1}^{k}$ of basis concepts and their influence scores.
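A small sketch of assembling the explanation set (pairs $(h(x)_i, \theta(x)_i)$) from precomputed $\theta(x)$ and $h(x)$ in the scalar-output case. Sorting by $|\theta(x)_i|$ and the default concept names are presentation choices assumed here, not part of the definition.

```python
import numpy as np

def explanation(theta_x, h_x, names=None):
    """
    Explanation {(h(x)_i, theta(x)_i)} as a list of (name, activation, relevance) tuples,
    sorted by |theta(x)_i| in decreasing order (a presentation choice only).
    """
    k = len(h_x)
    names = names or [f"C{i + 1}" for i in range(k)]
    items = [(names[i], float(h_x[i]), float(theta_x[i])) for i in range(k)]
    return sorted(items, key=lambda t: abs(t[2]), reverse=True)

expl = explanation(np.array([0.1, -2.0, 0.7]), np.array([1.0, 0.3, 0.9]))
```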
Besides the linear predictors that provided a starting point in Section 2, well-known families such as generalized linear models and nearest-neighbor classifiers are contained in this class of functions. However, the true power of the models described in Definition 3.3 comes when $\theta(\cdot)$ (and potentially $h(\cdot)$) are realized by architectures with large modeling capacity, such as deep neural networks. When $\theta(x)$ is realized with a neural network, we refer to $f$ as a self-explaining neural network (Senn). If $g$ depends on its arguments in a continuous way, $f$ can be trained end-to-end with back-propagation. Since our aim is to maintain model richness even in the case where the $h_i$ are chosen to be trivial input feature indicators, we rely predominantly on $\theta$ for modeling capacity, realizing it with larger, higher-capacity architectures. It remains to discuss how the properties (P1)-(P5) in Definition 3.3 are to be enforced. The first two depend entirely on the choice of aggregating function $g$. Besides trivial addition, other options include affine functions $g(z_1, \ldots, z_k) = \sum_i A_i z_i + b$ where the $A_i$'s are constrained to be positive. On the other hand, the last two conditions in Definition 3.3 are application-dependent: what and how many basis concepts are adequate should be informed by the problem and goal at hand. The only condition in Definition 3.3 that warrants further discussion is (P3): the stability of $\theta$ with respect to $h$. For this, let us consider what $f$ would look like if the $\theta_i$'s were indeed (constant) parameters. Looking at $f$ as a function of $h(x)$, i.e., $f(x) = g(h(x))$, let $z = h(x)$. Using the chain rule we get $\nabla_x f = \nabla_z f \cdot J_x^h$, where $J_x^h$ denotes the Jacobian of $h$ (with respect to $x$). At a given point $x_0$, we want $\theta(x_0)$ to behave as the derivative of $f$ with respect to the concept vector $h(x)$ around $x_0$, i.e., we seek $\theta(x_0) \approx \nabla_z f$. Since this is hard to enforce directly, we can instead plug this ansatz into $\nabla_x f = \nabla_z f \cdot J_x^h$ to obtain a proxy condition:
$$\mathcal{L}_\theta(f(x)) := \big\| \nabla_x f(x) - \theta(x)^\top J_x^h(x) \big\| \approx 0 \qquad (3)$$
All three terms in $\mathcal{L}_\theta(f(x))$ can be computed, and when using differentiable architectures $h(\cdot)$ and $\theta(\cdot)$, we obtain gradients with respect to (3) through automatic differentiation and can thus use it as a regularization term in the optimization objective. With this, we obtain a gradient-regularized objective of the form $\mathcal{L}_y(f(x), y) + \lambda \mathcal{L}_\theta(f(x))$, where the first term is a classification loss and $\lambda$ is a parameter that trades off performance against stability (and therefore interpretability) of $\theta(x)$.
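A minimal NumPy sketch of the proxy condition (3), assuming the sum aggregation $f(x) = \theta(x)^\top h(x)$ with a scalar output. It substitutes central finite differences for automatic differentiation, so it is only meant for checking the quantity on small examples, not for training; `numerical_jacobian` and `stability_penalty` are hypothetical helper names.

```python
import numpy as np

def numerical_jacobian(fn, x, step=1e-5):
    """Central-difference Jacobian of fn: R^n -> R^p at x, shape (p, n)."""
    x = np.asarray(x, dtype=float)
    fx = np.atleast_1d(fn(x))
    J = np.zeros((fx.size, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = step
        J[:, j] = (np.atleast_1d(fn(x + e)) - np.atleast_1d(fn(x - e))) / (2 * step)
    return J

def stability_penalty(theta, h, x):
    """
    Proxy condition (3): || grad_x f(x) - theta(x)^T J_x^h(x) || for the sum aggregation
    f(x) = theta(x)^T h(x) with a scalar output.
    """
    f = lambda u: theta(u) @ h(u)
    grad_f = numerical_jacobian(f, x)[0]   # gradient of f, shape (n,)
    J_h = numerical_jacobian(h, x)         # Jacobian of the concepts, shape (k, n)
    return float(np.linalg.norm(grad_f - theta(x) @ J_h))

A = np.array([[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]])
h = lambda u: A @ u
x = np.array([1.0, 1.0, 1.0])
pen_const = stability_penalty(lambda u: np.array([0.5, -2.0]), h, x)   # constant theta
pen_var = stability_penalty(lambda u: np.array([u[0], 1.0]), h, x)     # theta depends on x
```

When $\theta$ is constant, as in a plain linear model on the concepts, the penalty vanishes. It is generally nonzero once $\theta$ depends on $x$, because the chain rule then picks up the extra term $h(x)^\top J_x^\theta$.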
4 Learning interpretable basis concepts
Raw input features are the natural basis for interpretability when the input is low-dimensional and individual features are meaningful. For high-dimensional inputs, raw features (such as individual pixels in images) often lead to noisy explanations that are sensitive to imperceptible artifacts in the data, tend to be hard to analyze coherently, and are not robust to simple transformations such as constant shifts (kindermans2017unreliability). Furthermore, the lack of robustness of methods that rely on raw inputs is amplified for high-dimensional inputs, as shown in the next section. To avoid some of these shortcomings, we can instead operate on higher-level features. In the context of images, we might be interested in the effect of textures or shapes (rather than single pixels) on predictions. For example, in medical image processing, higher-level visual aspects such as tissue ruggedness, irregularity or elongation are strong predictors of cancerous tumors, and are among the first aspects that doctors look for when diagnosing, so they are natural “units” of explanation. Ideally, these basis concepts would be informed by expert knowledge, such as the doctor-provided features mentioned above. However, in cases where such prior knowledge is not available, the basis concepts could be learnt instead. Interpretable concept learning is a challenging task in its own right (kim2018TCAV) and, like other aspects of interpretability, remains ill-defined. We posit that a reasonable minimal set of desiderata for interpretable concepts is: (i) Fidelity: the representation of $x$ in terms of concepts should preserve relevant information, (ii) Diversity: inputs should be representable with few non-overlapping concepts, and (iii) Grounding: concepts should have an immediate human-understandable interpretation. Here, we enforce these conditions upon the concepts learnt by Senn by: (i) training $h$ as an autoencoder, (ii) enforcing diversity through sparsity and (iii) providing interpretation on the concepts by prototyping (e.g., by providing a small set of training examples that maximally activate each concept). Learning of $h$ is done end-to-end in conjunction with the rest of the model. If we denote by $h_{\mathrm{dec}}(\cdot): \mathbb{R}^k \to \mathbb{R}^n$ the decoder associated with $h$, and by $\hat{x} := h_{\mathrm{dec}}(h(x))$ the reconstruction of $x$, we use an additional penalty $\mathcal{L}_h(x, \hat{x})$ on the objective, yielding the loss:
$$\mathcal{L}_y(f(x), y) + \lambda \mathcal{L}_\theta(f(x)) + \xi \mathcal{L}_h(x, \hat{x}) \qquad (4)$$
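A minimal sketch of assembling the full objective for a single example: classification loss, plus $\lambda$ times the stability penalty, plus $\xi$ times the reconstruction penalty. Cross-entropy for $\mathcal{L}_y$ and squared error for $\mathcal{L}_h$ are assumptions here, the sparsity term used to encourage diversity is omitted, and the stability penalty is passed in precomputed. All function names are hypothetical.

```python
import numpy as np

def cross_entropy(logits, y):
    """Classification loss L_y for one example: -log softmax(logits)[y]."""
    z = logits - np.max(logits)                  # shift for numerical stability
    return float(np.log(np.sum(np.exp(z))) - z[y])

def reconstruction_loss(x, x_hat):
    """L_h(x, x_hat) as squared error (one possible choice; an assumption here)."""
    return float(np.sum((x - x_hat) ** 2))

def senn_loss(logits, y, stability_pen, x, x_hat, lam, xi):
    """L_y(f(x), y) + lam * L_theta(f(x)) + xi * L_h(x, x_hat)."""
    return cross_entropy(logits, y) + lam * stability_pen + xi * reconstruction_loss(x, x_hat)

loss = senn_loss(np.zeros(2), 0, 0.5, np.array([1.0, 2.0]), np.array([1.0, 1.0]), lam=2.0, xi=0.1)
```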
Achieving (iii), i.e., the grounding of $h(x)$, is more subjective. A simple approach consists of representing each concept by the elements in a sample of data that maximize their value, that is, we can represent concept $i$ through the set $\mathcal{X}^i = \operatorname{argmax}_{\hat{X} \subseteq X, |\hat{X}| = l} \sum_{x \in \hat{X}} h(x)_i$, where $l$ is small. Similarly, one could construct (by optimizing over $x$) synthetic inputs that maximally activate each concept (and do not activate others), i.e., $\operatorname{argmax}_{x \in \mathcal{X}} h_i(x) - \sum_{j \neq i} h_j(x)$. Alternatively, when available, one might want to represent concepts via their learnt weights, e.g., by looking at the filters associated with each concept in a CNN-based $h(\cdot)$. In our experiments, we use the first of these approaches (i.e., using maximally activated prototypes), leaving exploration of the other two for future work.
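The first grounding strategy (maximally activating prototypes) reduces to a top-$l$ selection per concept, since the sum over a fixed-size subset is maximized by its $l$ largest values. A hypothetical NumPy sketch, assuming the concept activations for a sample are stacked row-wise in an array `H`:

```python
import numpy as np

def concept_prototypes(H, l):
    """
    H: (N, k) array of concept activations h(x) for a sample of N inputs.
    Returns, for each concept i, the indices of the l inputs with the largest h(x)_i,
    i.e. the size-l subset maximizing sum_{x in subset} h(x)_i.
    """
    order = np.argsort(-H, axis=0, kind="stable")   # per-column indices, descending activation
    return [order[:l, i].tolist() for i in range(H.shape[1])]

H = np.array([[0.1, 0.9],
              [0.8, 0.2],
              [0.5, 0.7],
              [0.9, 0.0]])
protos = concept_prototypes(H, 2)
```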
The notion of interpretability is notorious for eluding easy quantification (Doshi-Velez2017Roadmap). Here, however, the motivation in Section 2 produced a set of desiderata according to which we can validate our models. Throughout this section, we base the evaluation on four main criteria. First and foremost, for all datasets we investigate whether our models perform on par with their non-modular, non-interpretable counterparts. After establishing that this is indeed the case, we focus our evaluation on the interpretability of our approach, in terms of three criteria:
Explicitness/Intelligibility: Are the explanations immediate and understandable?
Faithfulness: Are relevance scores indicative of "true" importance?
Stability: How consistent are the explanations for similar/neighboring examples?
Below, we address these criteria one at a time, proposing qualitative assessment of (i) and quantitative metrics for evaluating (ii) and (iii).
5.1 Dataset and Methods
We carry out quantitative evaluation on three classification settings: (i) Mnist digit recognition, (ii) benchmark UCI datasets (lichman2013uci) and (iii) Propublica’s Compas Recidivism Risk Score datasets (github.com/propublica/compas-analysis/). In addition, we provide some qualitative results on Cifar10 (krizhevsky2009learning) in the supplement (§A.5). The Compas data consists of demographic features labeled with criminal recidivism (“relapse”) risk scores produced by a private company’s proprietary algorithm, currently used in the Criminal Justice System to aid in bail granting decisions. Propublica’s study showing racially biased scores sparked a flurry of interest in the Compas algorithm both in the media and in the fairness in machine learning community (zafar2017parity; grgic2018beyond). Details on data pre-processing for all datasets are provided in the supplement.
We compare our approach against various interpretability frameworks: three popular “black-box” methods, namely Lime (Ribeiro2016Why), kernel Shapley values (Shap) (Lundberg2017Unified) and perturbation-based occlusion sensitivity (Occlusion) (zeiler2014visualizing); and various gradient and saliency based methods, namely gradient$\odot$input (Grad*Input) as proposed by shrikumar2017learning, saliency maps (Saliency) (simonyan2013deep), Integrated Gradients (Int.Grad) (sundararajan2017axiomatic) and $\epsilon$-Layerwise Relevance Propagation (e-Lrp) (Bach2015Pixel-wise).
5.2 Explicitness/Intelligibility: How understandable are Senn’s explanations?
When taking $h(x)$ to be the identity, the explanations provided by our method take the same surface form (i.e., heat maps on inputs) as those of common saliency and gradient-based methods, but differ substantially when using concepts as units of explanation (i.e., $h(x)$ is learnt). In Figure 2 we contrast these approaches in the context of digit classification interpretability. To highlight the difference, we use only a handful of concepts, forcing the model to encode digits into meta-types sharing higher-level information. Naturally, it is necessary to describe each concept to understand what it encodes, as we do here through a grid of the most representative prototypes (as discussed in §4), shown here in Fig. 2, right. While pixel-based methods provide more granular information, Senn’s explanation is (by construction) more parsimonious. For both of these digits, Concept 3 had a strong positive influence towards the prediction. Indeed, that concept seems to be associated with diagonal strokes (predominantly occurring in 7’s), which both of these inputs share. However, for the second prediction there is another relevant concept, C4, which is characterized largely by stylized 2’s, a concept that in contrast has negative influence towards the top row’s prediction.
5.3 Faithfulness: Are “relevant” features truly relevant?
Assessing the correctness of estimated feature relevances requires a reference “true” influence to compare against. Since this is rarely available, a common approach to measuring the faithfulness of relevance scores with respect to the model they are explaining relies on a proxy notion of importance: observing the effect of removing features on the model’s prediction. For example, for a probabilistic classification model, we can obscure or remove features, measure the drop in probability of the predicted class, and compare against the interpreter’s own prediction of relevance (Samek2017Evaluating; arras2017relevant). Here, we further compute the correlations between these probability drops and the relevance scores on various points, and show the aggregate statistics in Figure 3 (left) for Lime, Shap and Senn (without learnt concepts) on various UCI datasets. We note that this evaluation naturally extends to the case where the concepts are learnt (Fig. 3, right). The additive structure of our model allows for removal of features (regardless of their form, i.e., inputs or concepts) simply by setting their coefficients $\theta_i$ to zero. Indeed, while feature removal is not always meaningful for other prediction models (e.g., one must replace pixels with black or averaged values to simulate removal in a CNN), the definition of our model allows for targeted removal of features, rendering an evaluation based on it more reliable.
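The removal-based check can be sketched for a single input, assuming sum aggregation followed by a softmax over $m$ classes, with each concept removed by zeroing its row of coefficients. The Pearson correlation and the function names are assumptions of this sketch; the evaluation described above aggregates such correlations over many points.

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)                 # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def faithfulness(theta_x, h_x):
    """
    theta_x: (k, m) relevance matrix; h_x: (k,) concepts; logits are theta_x.T @ h_x.
    Removes each concept in turn by zeroing its coefficients, records the drop in the
    predicted-class probability, and returns the Pearson correlation between those
    drops and the relevance scores theta(x)_{i, y_hat}.
    """
    p = softmax(theta_x.T @ h_x)
    y_hat = int(np.argmax(p))
    drops = []
    for i in range(len(h_x)):
        theta_removed = theta_x.copy()
        theta_removed[i, :] = 0.0
        drops.append(p[y_hat] - softmax(theta_removed.T @ h_x)[y_hat])
    return float(np.corrcoef(theta_x[:, y_hat], drops)[0, 1])

theta_x = np.array([[2.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
corr = faithfulness(theta_x, np.ones(3))
```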
5.4 Stability: How coherent are explanations for similar inputs?
As argued throughout this work, a crucial property that interpretability methods should satisfy to generate meaningful explanations is that of robustness with respect to local perturbations of the input. Figure 4
shows that this is not the case for popular interpretability methods; even adding minimal white noise to the input introduces visible changes in the explanations. To formally quantify this phenomenon, we appeal again to Definition 3.2, as we seek a worst-case (adversarial) notion of robustness. Thus, we can quantify the stability of an explanation generation model $f_{\mathrm{expl}}(x)$ by estimating, for a given input $x_i$ and neighborhood size $\epsilon$:
$$\hat{L}(x_i) = \max_{x_j \in B_\epsilon(x_i)} \frac{\| f_{\mathrm{expl}}(x_i) - f_{\mathrm{expl}}(x_j) \|_2}{\| h(x_i) - h(x_j) \|_2} \qquad (5)$$
where for Senn we have $f_{\mathrm{expl}}(x) := \theta(x)$, and for raw-input methods we replace $h(x)$ with $x$, turning (5) into an estimate of the Lipschitz constant (in the usual sense) of $f_{\mathrm{expl}}$. We can directly estimate this quantity for Senn since the explanation generation is end-to-end differentiable with respect to concepts, and thus we can rely on direct automatic differentiation and back-propagation to optimize for the maximizing argument $x_j$, as often done for computing adversarial examples for neural networks (Goodfellow2015Explaining). Computing (5) for post-hoc explanation frameworks is, however, much more challenging, since they are not end-to-end differentiable. Thus, we need to rely on black-box optimization instead of gradient ascent. Furthermore, evaluation of $f_{\mathrm{expl}}$ for methods like Lime and Shap is expensive (as it involves model estimation for each query), so we need to do so with a restricted evaluation budget. In our experiments, we rely on Bayesian Optimization (snoek2012practical). The continuous notion of local stability (5) might not be suitable for discrete inputs or settings where adversarial perturbations are overly restrictive (e.g., when the true data manifold has regions of flatness in some dimensions). In such cases, we can instead define a (weaker) sample-based notion of stability. For any $x_i$ in a finite sample $X = \{x_i\}_{i=1}^{n}$, let its $\epsilon$-neighborhood within $X$ be $\mathcal{N}_\epsilon(x_i) = \{x_j \in X \mid \|x_i - x_j\| \leq \epsilon\}$. Then, we consider an alternative version of (5) with $\mathcal{N}_\epsilon(x_i)$ in lieu of $B_\epsilon(x_i)$. Unlike the former, its computation is trivial since it involves a finite sample.
We first use this evaluation metric to validate the usefulness of the proposed gradient regularization approach for enforcing explanation robustness. The results on the Compas and Breast-Cancer datasets (Fig. 5 A/B) show that there is a natural tradeoff between stability and prediction accuracy through the choice of regularization parameter $\lambda$. Somewhat surprisingly, we often observe a boost in performance brought by the gradient penalty, likely caused by the additional regularization it imposes on the prediction model. We observe a similar pattern on Mnist (Figure 8, in the Appendix). Next, we compare all methods in terms of robustness on various datasets (Fig. 5C), where we observe Senn to consistently and substantially outperform all other methods in this metric. It is interesting to visualize the inputs and corresponding explanations that maximize criterion (5) (or its discrete counterpart, when appropriate) for different methods and datasets, since these succinctly exhibit the lack of robustness that our work seeks to address. We provide many such “adversarial” examples in Appendix A.7. These examples show the drastic effect that minimal perturbations can have on most methods, particularly Lime and Shap. The pattern is clear: most current interpretability approaches are not robust, even when the underlying model they are trying to explain is. The class of models proposed here offers a promising avenue to remedy this shortcoming.
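Because the sample-based variant of (5) only ranges over a finite neighborhood, it can be computed by direct enumeration. A hypothetical NumPy sketch for one point $x_i$, using Euclidean norms throughout; for raw-input methods one would pass the identity as `h`:

```python
import numpy as np

def sample_stability(X, f_expl, h, i, eps):
    """
    Sample-based version of (5): the maximum of
    ||f_expl(x_i) - f_expl(x_j)|| / ||h(x_i) - h(x_j)||
    over x_j in N_eps(x_i) = {x_j in X : ||x_i - x_j|| <= eps}, excluding x_i itself.
    Returns 0.0 if x_i has no other neighbors, inf if explanations differ for equal concepts.
    """
    xi = X[i]
    best = 0.0
    for j, xj in enumerate(X):
        if j == i or np.linalg.norm(xi - xj) > eps:
            continue
        d_expl = np.linalg.norm(f_expl(xi) - f_expl(xj))
        d_h = np.linalg.norm(h(xi) - h(xj))
        if d_h == 0.0:
            if d_expl > 0.0:
                return float("inf")
            continue
        best = max(best, d_expl / d_h)
    return float(best)

X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.2], [5.0, 5.0]])
f_expl = lambda x: 2 * x if np.linalg.norm(x) < 1 else 100 * x   # toy explainer, unstable far away
ident = lambda x: x
local_small = sample_stability(X, f_expl, ident, 0, 0.5)    # only the two close points count
local_large = sample_stability(X, f_expl, ident, 0, 10.0)   # the distant point is now included
```

The choice of $\epsilon$ matters: a larger neighborhood can pull in points where the explainer behaves very differently, which is why the estimate is reported per neighborhood size.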
6 Related Work
Interpretability methods for neural networks. Beyond the gradient and perturbation-based methods mentioned here (simonyan2013deep; zeiler2014visualizing; Bach2015Pixel-wise; shrikumar2017learning; sundararajan2017axiomatic), various other methods of similar spirit exist (montavon2017methods). These methods have in common that they do not modify existing architectures, instead relying on a posteriori computations to reverse-engineer importance values or sensitivities of inputs. Our approach differs both in what it considers the units of explanation (general concepts, not necessarily raw inputs) and in how it uses them, intrinsically relying on the relevance scores it produces to make predictions, obviating the need for additional computation. More related to our approach is the work of Lei2016Rationalizing and alshedivat2017contextual. The former proposes a neural network architecture for text classification which “justifies” its predictions by selecting relevant tokens in the input text. But this interpretable representation is then operated on by a complex neural network, so the method is transparent as to what aspect of the input it uses for prediction, but not how it uses it. Contextual Explanation Networks (alshedivat2017contextual) are also inspired by the goal of designing a class of models that learns to predict and explain jointly, but differ from our approach in their formulation (through deep graphical models) and realization of the model (through variational autoencoders). Furthermore, our approach departs from that work in that we explicitly enforce robustness with respect to the units of explanation and we formulate concepts as part of the explanation, thus requiring them to be grounded and interpretable.
Explanations through concepts and prototypes. Li2018Deep propose an interpretable neural network architecture whose predictions are based on the similarity of the input to a small set of prototypes, which are learnt during training.
Our approach can be understood as generalizing this approach beyond similarities to prototypes into more general interpretable concepts, while differing in how these higher-level representations of the inputs are used. More similar in spirit to our approach of explaining by means of learnable interpretable concepts is the work of kim2018TCAV. They propose a technique for learning concept activation vectors representing human-friendly concepts of interest, by relying on a set of human-annotated examples characterizing these. By computing directional derivatives along these vectors, they gauge the sensitivity of predictors with respect to semantic changes in the direction of the concept. Their approach differs from ours in that it explains a (fixed) external classifier and uses a predefined set of concepts, while we learn both of these intrinsically.
7 Discussion and future work
Interpretability and performance currently stand in apparent conflict in machine learning. Here, we make progress towards showing this to be a false dichotomy by drawing inspiration from classic notions of interpretability to inform the design of modern complex architectures, and by explicitly enforcing basic desiderata for interpretability—explicitness, faithfulness and stability—during training of our models. We demonstrate how the fusion of these ideas leads to a class of rich, complex models that are able to produce robust explanations, a key property that we show is missing from various popular interpretability frameworks. There are various possible extensions beyond the model choices discussed here, particularly in terms of interpretable basis concepts. As for applications, the natural next step would be to evaluate interpretable models in more complex domains, such as larger image datasets, speech recognition or natural language processing tasks.
The authors would like to thank the anonymous reviewers and Been Kim for helpful comments. The work was partially supported by an MIT-IBM grant on deep rationalization and by Graduate Fellowships from Hewlett Packard and CONACYT.
Appendix A
A.1 Data Processing
We use the original Mnist and Cifar10 datasets with standard mean and variance normalization, using of the training split for validation.
For the Uci datasets, we use standard mean and variance scaling on all datasets and use train, validation and test splits.
For Compas, we preprocess the data by rescaling the ordinal variable Number_of_priors to the range . The data contains several inconsistent examples, so we filter out examples whose label differs from a strong majority of other identical examples.
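The filtering rule for inconsistent examples can be sketched as follows. This is a hypothetical pure-Python illustration, not the authors' preprocessing code: `filter_inconsistent` and its `majority` threshold are illustrative names, and the default of 0.8 is a placeholder rather than the value used in the paper.

```python
from collections import Counter, defaultdict

def filter_inconsistent(X, y, majority=0.8):
    """
    X: list of hashable feature tuples; y: list of labels.
    Drops example i if, among the *other* examples with identical features, a label
    different from y[i] accounts for at least `majority` of them (placeholder threshold).
    Returns the indices of the examples that are kept.
    """
    groups = defaultdict(list)
    for idx, feats in enumerate(X):
        groups[feats].append(idx)
    keep = []
    for idx, feats in enumerate(X):
        others = Counter(y[j] for j in groups[feats] if j != idx)
        total = sum(others.values())
        if total:
            label, count = others.most_common(1)[0]
            if label != y[idx] and count / total >= majority:
                continue  # label contradicts a strong majority of identical examples
        keep.append(idx)
    return keep

X = [(1, 0)] * 5 + [(0, 1)] * 2
y = [1, 1, 1, 1, 0, 0, 1]
kept = filter_inconsistent(X, y)
```

Note that with only two identical examples and conflicting labels, each one contradicts a unanimous "majority" of one, so both are dropped under this reading of the rule.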
The architectures used for Senn in each task are summarized below, where CL/FC stand for convolutional and fully-connected layers, respectively, and $k$ denotes the number of concepts. Note that in every case we use more complex architectures for the parametrizer $\theta$ than for the concept encoder $h$.
In all cases, we train using the Adam optimizer with initial learning rate and, whenever learning $h(\cdot)$, sparsity strength parameter .
A.3 Predictive Performance of Senn
On Mnist, we observed that any reasonable choice of parameters in our model leads to very low test prediction error. In particular, taking $\lambda = 0$ (the unregularized model) yields an unconstrained model with an architecture slightly modified from LeNet, for which we obtain a test set accuracy (slightly above typical results for a vanilla LeNet). On the other hand, for the most extreme regularization value used () we obtain an accuracy of . All other values interpolate between these two extremes. In particular, the actual model used in Figure 2 obtained accuracy, just slightly below the unregularized one.
As with previous experiments, our models are able to achieve competitive performance on all Uci datasets for most parameter configurations.
On Compas, with default parameters, our Senn model achieves an accuracy of on the test set, compared to for a baseline logistic classification model. The relatively low performance of both methods is due to the problem of inconsistent examples mentioned above.
On Cifar10, with default parameters, our Senn model achieves an accuracy of on the test set, which is on par with models of that size trained with some regularization method (our method requires no further regularization).
A.4 Implementation and dependency details
We used the implementations of Lime and Shap provided by the authors. Unless otherwise stated, we use default parameter configurations and estimation samples for these two methods. For the rest of the interpretability frameworks, we use the publicly available DeepExplain toolbox (github.com/marcoancona/DeepExplain). In our experiments, we compute $\hat{L}$ for Senn models by minimizing a Lagrangian relaxation of (5) through backpropagation. For all other methods, we rely instead on Bayesian optimization, via the skopt toolbox (scikit-optimize.github.io), using a budget of function calls for Lime (due to higher compute time) and for all other methods.
A.5 Qualitative results on Cifar10
A.6 Additional Results on Stability
A.7 Adversarial Examples For Interpretability
We now show various example inputs and their adversarial perturbations (according to (5)) on various datasets.