Training Machine Learning Models by Regularizing their Explanations

09/29/2018 ∙ by Andrew Slavin Ross, et al. ∙ 0

Neural networks are among the most accurate supervised learning methods in use today. However, their opacity makes them difficult to trust in critical applications, especially when conditions in training may differ from those in practice. Recent efforts to develop explanations for neural networks and machine learning models more generally have produced tools to shed light on the implicit rules behind predictions. These tools can help us identify when models are right for the wrong reasons. However, they do not always scale to explaining predictions for entire datasets, are not always at the right level of abstraction, and most importantly cannot correct the problems they reveal. In this thesis, we explore the possibility of training machine learning models (with a particular focus on neural networks) using explanations themselves. We consider approaches where models are penalized not only for making incorrect predictions but also for providing explanations that are either inconsistent with domain knowledge or overly complex. These methods let us train models which can not only provide more interpretable rationales for their predictions but also generalize better when training data is confounded or meaningfully different from test data (even adversarially so).




1 Contributions

The major contributions of this thesis are as follows:

  • It presents a framework for encoding domain knowledge about a classification problem as local penalties on the gradient of the model’s decision surface, which can be incorporated into the loss function of any differentiable model (e.g. a neural network). Applying this framework in both supervised and unsupervised formulations, it trains models that generalize to test data from different distributions, which would otherwise be unobtainable by traditional optimization methods. (Chapter Training Machine Learning Models by Regularizing their Explanations)

  • It applies a special case of this framework (where explanations are regularized to be simple) to the problem of defending against adversarial examples. It demonstrates increased robustness of regularized models to white- and black-box attacks, at a level comparable to or better than adversarial training. It also demonstrates both increased transferability and interpretability of adversarial examples created to fool regularized models, which we evaluate in a human subject experiment. (Chapter Training Machine Learning Models by Regularizing their Explanations)

  • It considers cases where we can meaningfully change what models learn by regularizing more general types of explanations. We review literature and suggest directions for explanation regularization, using sparse gradients, input Hessians, decision trees, nearest neighbors, and even abstract concepts that emerge or that we encourage to emerge in deep neural networks. It concludes by outlining an interface for interpretable machine teaching. (Chapter Training Machine Learning Models by Regularizing their Explanations)

2 Introduction

High-dimensional real-world datasets are often full of ambiguities. When we train classifiers on such data, it is frequently possible to achieve high accuracy using classifiers with qualitatively different decision boundaries. To narrow down our choices and encourage robustness, we usually employ regularization techniques (e.g. encouraging sparsity or small parameter values). We also structure our models to ensure domain-specific invariances (e.g. using convolutional neural nets when we would like the model to be invariant to spatial transformations). However, these solutions do not address situations in which our training dataset contains subtle confounds or differs qualitatively from our test dataset. In these cases, our model may fail to generalize no matter how well it is tuned.

Such generalization gaps are of particular concern for uninterpretable models such as neural networks, especially in sensitive domains. For example, Caruana et al. (2015) describe a model intended to prioritize care for patients with pneumonia. The model was trained to predict hospital readmission risk using a dataset containing attributes of patients hospitalized at least once for pneumonia. Counterintuitively, the model learned that the presence of asthma was a negative predictor of readmission, when in reality pneumonia patients with asthma are at a greater medical risk. This model would have presented a grave safety risk if used in production. This problem occurred because the outcomes in the dataset reflected not just the severity of patients’ diseases but the quality of care they initially received, which was higher for patients with asthma.

This case and others like it have motivated recent work in interpretable machine learning, where algorithms provide explanations for domain experts to inspect for correctness before trusting model predictions. However, there has been limited work in optimizing models to find not just the right prediction but also the right explanation. Toward this end, this work makes the following contributions:

  • We confirm empirically on several datasets that input gradient explanations match state-of-the-art sample-based explanations (e.g. LIME, Ribeiro (2016)).

  • Given annotations about incorrect explanations for particular inputs, we efficiently optimize the classifier to learn alternate explanations (to be right for better reasons).

  • When annotations are not available, we sequentially discover classifiers with similar accuracies but qualitatively different decision boundaries for domain experts to inspect for validity.

2.1 Related Work

We first define several important terms in interpretable machine learning. All classifiers have implicit decision rules for converting an input into a decision, though these rules may be opaque. A model is interpretable if it provides explanations for its predictions in a form humans can understand; an explanation provides reliable information about the model’s implicit decision rules for a given prediction. In contrast, we say a machine learning model is accurate if most of its predictions are correct, but only right for the right reasons if the implicit rules it has learned generalize well and conform to domain experts’ knowledge about the problem.

Explanations can take many forms (Keil, 2006) and evaluating the quality of explanations or the interpretability of a model is difficult (Lipton, 2016; Doshi-Velez and Kim, 2017). However, within the machine learning community recently there has been convergence (Lundberg and Lee, 2016) around local counterfactual explanations, where we show how perturbing an input in various ways will affect the model’s prediction. This approach to explanations can be domain- and model-specific (e.g. “annotator rationales” used to explain text classifications by Li et al. (2016); Lei et al. (2016); Zhang et al. (2016)). Alternatively, explanations can be model-agnostic and relatively domain-general, as exemplified by LIME (Local Interpretable Model-agnostic Explanations, Ribeiro et al. (2016); Singh et al. (2016)) which trains and presents local sparse models of how predictions change when inputs are perturbed.

The per-example perturbing and fitting process used in models such as LIME can be computationally prohibitive, especially if we seek to explain an entire dataset during each training iteration. If the underlying model is differentiable, one alternative is to use input gradients as local explanations (Baehrens et al. (2010) provides a particularly good introduction; see also Selvaraju et al. (2016); Simonyan et al. (2013); Li et al. (2015); Hechtlinger (2016)). The idea is simple: the gradients of the model’s output probabilities with respect to its inputs literally describe the model’s decision boundary (see Figure 1). They are similar in spirit to the local linear explanations of LIME but much faster to compute.

Input gradient explanations are not perfect for all use-cases: for points far from the decision boundary, they can be uninformatively small and do not always capture the idea of salience (see discussion and alternatives proposed by Shrikumar et al. (2016); Bach et al. (2015); Montavon et al. (2017); Sundararajan et al. (2017); Fong and Vedaldi (2017)). However, they are exactly what is required for constraining the decision boundary. In the past, Drucker and Le Cun (1992) showed that applying penalties to input gradient magnitudes can improve generalization; to our knowledge, our application of input gradients to constrain explanations and find alternate explanations is novel.

Figure 1: Input gradients lie normal to the model’s decision boundary. Examples above are for simple, 2D, two- and three-class datasets, with input gradients taken with respect to a two hidden layer multilayer perceptron with ReLU activations. Probability input gradients are sharpest near decision boundaries, while log probability input gradients are more consistent within decision regions. The sum of log probability gradients contains information about the full model.

More broadly, none of the works above on interpretable machine learning attempt to optimize explanations for correctness. For SVMs and specific text classification architectures, there exists work on incorporating human input into decision boundaries in the form of annotator rationales (Zaidan et al., 2007; Donahue and Grauman, 2011; Zhang et al., 2016). Unlike our approach, these works are either tailored to specific domains or do not fully close the loop between generating explanations and constraining them.

2.2 Background: Input Gradient Explanations

Consider a differentiable model $f$ parametrized by $\theta$ with inputs $X \in \mathbb{R}^{N \times D}$ and probability vector outputs $f(X \mid \theta) = \hat{y} \in \mathbb{R}^{N \times K}$ corresponding to one-hot labels $y \in \mathbb{R}^{N \times K}$. Its input gradient is given by $f_X(X_n \mid \theta)$ or $\nabla_X \hat{y}_n$, which is a vector normal to the model’s decision boundary at $X_n$ and thus serves as a first-order description of the model’s behavior near $X_n$. The gradient has the same shape as each vector $X_n$; large-magnitude values of the input gradient indicate elements of $X_n$ that would affect $\hat{y}$ if changed. We can visualize explanations by highlighting portions of $X_n$ in locations with high input gradient magnitudes.
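As a minimal sketch of this definition, the snippet below computes input gradients in closed form for a linear softmax classifier. The linear model, the weight names `W` and `b`, and the function names are assumptions made for illustration; the thesis itself uses a multilayer perceptron and automatic differentiation.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def predict_proba(X, W, b):
    # Hypothetical linear softmax model: X is (N, D), W is (D, K), b is (K,).
    return softmax(X @ W + b)

def input_gradients(X, W, b):
    # Gradient of every class probability with respect to every input
    # dimension, shape (N, K, D). For a linear softmax model,
    # d y_k / d x = y_k * (W[:, k] - sum_j y_j * W[:, j]).
    Y = predict_proba(X, W, b)          # (N, K)
    mean_w = Y @ W.T                    # (N, D): sum_j y_j * W[:, j]
    return Y[:, :, None] * (W.T[None, :, :] - mean_w[:, None, :])
```

Each slice `input_gradients(X, W, b)[n, k]` has the same shape as the input row `X[n]`, so large-magnitude entries can be overlaid directly on the input for visualization.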

3 Our Approach

We wish to develop a method to train models that are right for the right reasons. If explanations faithfully describe a model’s underlying behavior, then constraining its explanations to match domain knowledge should cause its underlying behavior to more closely match that knowledge too. We first describe how input gradient-based explanations lend themselves to efficient optimization for correct explanations in the presence of domain knowledge, and then describe how they can be used to efficiently search for qualitatively different decision boundaries when such knowledge is not available.

3.1 Loss Functions that Constrain Explanations

When constraining input gradient explanations, there are two basic options: we can either constrain them to be large in relevant areas or small in irrelevant areas. However, because input gradients for relevant inputs in many models should be small far from the decision boundary, and because we do not know in advance how large they should be, we opt to shrink irrelevant gradients instead.

Formally, we define an annotation matrix $A \in \{0,1\}^{N \times D}$, whose rows are binary masks indicating whether dimension $d$ should be irrelevant for predicting observation $n$. We would like the input gradient $\nabla_X \hat{y}$ to be near $0$ at these locations. To that end, we optimize a loss function $L(\theta, X, y, A)$ of the form

$$L(\theta, X, y, A) = \underbrace{\sum_{n=1}^{N} \sum_{k=1}^{K} -y_{nk} \log(\hat{y}_{nk})}_{\text{right answers}} + \underbrace{\lambda_1 \sum_{n=1}^{N} \sum_{d=1}^{D} \left( A_{nd} \frac{\partial}{\partial x_{nd}} \sum_{k=1}^{K} \log(\hat{y}_{nk}) \right)^2}_{\text{right reasons}} + \underbrace{\lambda_2 \sum_i \theta_i^2}_{\text{regular}},$$

which contains familiar cross entropy and $\theta$ regularization terms along with a new regularization term that discourages the input gradient from being large in regions marked by $A$. This term has a regularization parameter $\lambda_1$ which should be set such that the “right answers” and “right reasons” terms have similar orders of magnitude; see Appendix 6 for more details. Note that this loss penalizes the gradient of the log probability, which performed best in practice, though in many visualizations we show $f_X$, which is the gradient of the predicted probability itself. Summing across classes led to slightly more stable results than using the predicted class log probability $\max_k \log(\hat{y}_k)$, perhaps due to discontinuities near the decision boundary (though both methods were comparable). We did not explore regularizing input gradients of specific class probabilities, though this would be a natural extension.

Because this loss function is differentiable with respect to $\theta$, we can easily optimize it with gradient-based optimization methods. We do not need annotations (nonzero $A_n$) for every input in $X$, and in the case $A = 0^{N \times D}$, the explanation term has no effect on the loss. At the other extreme, when $A$ is a matrix of all 1s, it encourages the model to have small gradients with respect to its inputs; this can improve generalization on its own (Drucker and Le Cun, 1992). Between those extremes, it biases our model against particular implicit rules.
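The three-term loss described in this section can be sketched as follows for a linear softmax model, where the input gradient of the summed log probabilities has a closed form. This is an illustrative sketch under stated assumptions, not the thesis's implementation: the linear model, the names `rrr_loss`, `lam1`, and `lam2`, and the choice to penalize only the weight matrix `W` are ours.

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def rrr_loss(W, b, X, y, A, lam1, lam2):
    # X: (N, D) inputs, y: (N, K) one-hot labels, A: (N, D) binary mask
    # marking input dimensions that should be irrelevant.
    logp = log_softmax(X @ W + b)                    # (N, K)
    right_answers = -(y * logp).sum()                # cross entropy
    # Input gradient of sum_k log p_k for a linear softmax model:
    # sum_k W[:, k] - K * sum_j p_j * W[:, j], one row per example.
    K = W.shape[1]
    p = np.exp(logp)
    grad = W.sum(axis=1)[None, :] - K * (p @ W.T)    # (N, D)
    right_reasons = lam1 * ((A * grad) ** 2).sum()   # masked gradient penalty
    weight_penalty = lam2 * (W ** 2).sum()           # L2 penalty on weights only
    return right_answers + right_reasons + weight_penalty
```

With `A` set to all zeros the middle term vanishes and the loss reduces to ordinary regularized cross entropy; with `A` set to all ones it penalizes every input gradient component.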

This penalization approach enjoys several desirable properties. Alternatives that specify a single $A_d$ for all examples presuppose a coherent notion of global feature importance, but when decision boundaries are nonlinear many features are only relevant in the context of specific examples. Alternatives that simulate perturbations to entries known to be irrelevant (or to determine relevance as in Ribeiro et al. (2016)) require defining domain-specific perturbation logic; our approach does not. Alternatives that apply hard constraints or completely remove elements identified by $A_{nd}$ miss the fact that the entries in $A$ may be imprecise even if they are human-provided. Thus, we opt to preserve potentially misleading features but softly penalize their use.

3.2 Find-Another-Explanation: Discovering Many Possible Rules without Annotations

Although we can obtain the annotations $A$ via experts as in Zaidan et al. (2007), we may not always have this extra information or know the “right reasons.” In these cases, we propose an approach that iteratively adapts $A$ to discover multiple models accurate for qualitatively different reasons; a domain expert could then examine them to determine which is right for the best reasons. Specifically, we generate a “spectrum” of models with different decision boundaries by iteratively training models, explaining $X$, then training the next model to differ from previous iterations:

$$A_0 = 0, \quad \theta_0 = \arg\min_\theta L(\theta, X, y, A_0),$$
$$A_1 = M_c[f_X \mid \theta_0], \quad \theta_1 = \arg\min_\theta L(\theta, X, y, A_1),$$
$$A_2 = M_c[f_X \mid \theta_1] \cup A_1, \quad \theta_2 = \arg\min_\theta L(\theta, X, y, A_2),$$
$$\ldots$$

where the function $M_c$ returns a binary mask indicating which gradient components have a magnitude ratio (their magnitude divided by the largest component magnitude) of at least $c$ and where we abbreviated the input gradients of the entire training set $X$ at $\theta_i$ as $f_X \mid \theta_i$. In other words, we regularize input gradients where they were largest in magnitude previously. If, after repeated iterations, accuracy decreases or explanations stop changing (or only change after significantly increasing $\lambda_1$), then we may have spanned the space of possible models. (One can design simple pathological cases where we do not discover all models with this method; we explore an alternative version in Appendix 8 that addresses some of these cases.) All of the resulting models will be accurate, but for different reasons; although we do not know which reasons are best, we can present them to a domain expert for inspection and selection. We can also prioritize labeling or reviewing examples about which the ensemble disagrees. Finally, the size of the ensemble provides a rough measure of dataset redundancy.
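The iterative procedure above can be sketched as a short loop. This is a hedged illustration rather than the thesis's code: `train` and `explain` are hypothetical caller-supplied functions, the magnitude ratio is assumed to be computed per example, and the union of masks is implemented as an elementwise maximum.

```python
import numpy as np

def magnitude_mask(grads, c=0.67):
    # grads: (N, D) input gradients. Returns a binary (N, D) mask marking
    # components whose magnitude is at least c times the largest magnitude
    # in the same row (assumption: the ratio is taken per example).
    mags = np.abs(grads)
    largest = mags.max(axis=1, keepdims=True)
    largest = np.where(largest == 0, 1.0, largest)   # avoid division by zero
    return (mags / largest >= c).astype(float)

def find_another_explanation(train, explain, X, y, n_models=3, c=0.67):
    # train(X, y, A) -> params and explain(params, X) -> (N, D) gradients
    # are hypothetical callables supplied by the caller.
    A = np.zeros_like(X, dtype=float)                # first model is unconstrained
    models = []
    for _ in range(n_models):
        params = train(X, y, A)
        models.append(params)
        # Accumulate the mask so later models avoid all earlier explanations.
        A = np.maximum(A, magnitude_mask(explain(params, X), c))
    return models, A
```

The returned list of models can then be handed to a domain expert, and the final mask records every input component that some earlier model relied on.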

4 Empirical Evaluation

We demonstrate explanation generation, explanation constraints, and the find-another-explanation method on a toy color dataset and three real-world datasets. In all cases, we used a multilayer perceptron with two hidden layers of size 50 and 30, ReLU nonlinearities with a softmax output, and a $\lambda_2$-weighted penalty on $\theta$. We trained the network using Adam (Kingma and Ba, 2014) with a batch size of 256 and Autograd (Maclaurin et al., 2017). For most experiments, we set the explanation L2 penalty $\lambda_1$ to a value that gave our “right answers” and “right reasons” loss terms similar magnitudes. More details about cross-validation are included in Appendix 6. For the cutoff value $c$ described in Section 3.2 and used for display, we often chose 0.67, which tended to preserve 2-5% of gradient components (the average number of qualifying elements tended to fall exponentially with $c$). Code for all experiments is available at
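For concreteness, the architecture described above (two hidden layers of 50 and 30 units, ReLU activations, softmax output) could be written as the forward pass below. The initialization scale and the function names are assumptions for illustration, and the sketch omits training with Adam and Autograd.

```python
import numpy as np

def init_mlp(n_inputs, n_classes, rng, hidden=(50, 30), scale=0.1):
    # Returns a list of (weights, biases) pairs, one per layer.
    # The Gaussian initialization with this scale is a placeholder choice.
    sizes = (n_inputs, *hidden, n_classes)
    return [(scale * rng.normal(size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def forward(params, X):
    # Hidden layers use ReLU; the final layer uses a stable softmax.
    h = X
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)
    W, b = params[-1]
    z = h @ W + b
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)
```

For example, `forward(init_mlp(75, 2, np.random.default_rng(0)), X)` returns one probability row per example for a 75-dimensional, two-class problem.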

4.1 Toy Color Dataset

We created a toy dataset of RGB images with four possible colors. Images fell into two classes with two independent decision rules a model could implicitly learn: whether their four corner pixels were all the same color, and whether their top-middle three pixels were all different colors. Images in class 1 satisfied both conditions and images in class 2 satisfied neither. Because only corner and top-row pixels are relevant, we expect any faithful explanation of an accurate model to highlight them.
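A generator for data of this kind might look like the sketch below. Several details are assumptions rather than facts from the text: the 5×5 image size is inferred from the 75-dimensional input listed for Toy Colors in Table 1, and the palette, pixel coordinates, and sampling scheme are hypothetical.

```python
import numpy as np

# Hypothetical palette; the text only says there are four possible colors.
PALETTE = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]])
CORNERS = [(0, 0), (0, 4), (4, 0), (4, 4)]
TOP_MIDDLE = [(0, 1), (0, 2), (0, 3)]

def rules(idx):
    # idx: (5, 5) array of palette indices. Returns the two rule outcomes.
    corners_same = len({int(idx[r, c]) for r, c in CORNERS}) == 1
    top_all_diff = len({int(idx[r, c]) for r, c in TOP_MIDDLE}) == 3
    return corners_same, top_all_diff

def sample_image(label, rng):
    # Class 1 satisfies both rules; class 2 satisfies neither.
    idx = rng.integers(0, 4, size=(5, 5))
    target = (True, True) if label == 1 else (False, False)
    while True:
        if label == 1:
            corner_colors = np.full(4, rng.integers(0, 4))
            top_colors = rng.permutation(4)[:3]
        else:
            corner_colors = rng.integers(0, 4, size=4)
            top_colors = rng.integers(0, 4, size=3)
        for (r, c), v in zip(CORNERS, corner_colors):
            idx[r, c] = v
        for (r, c), v in zip(TOP_MIDDLE, top_colors):
            idx[r, c] = v
        if rules(idx) == target:        # resample class 2 until both rules fail
            return PALETTE[idx]         # (5, 5, 3) RGB image
```

Because the two rules are perfectly correlated in the training data, a classifier can reach full accuracy using either one, which is what makes the dataset useful for studying which rule a model picks.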

Figure 2: Gradient vs. LIME explanations of nine perceptron predictions on the Toy Color dataset. For gradients, we plot dots above pixels identified by the 0.67 magnitude-ratio cutoff (the top 33% largest-magnitude input gradients), and for LIME, we select the top 6 features (up to 3 can reside in the same RGB pixel). Both methods suggest that the model learns the corner rule.

In Figure 2, we see both LIME and input gradients identify the same relevant pixels, which suggests that (1) both methods are effective at explaining model predictions, and (2) the model has learned the corner rather than the top-middle rule, which it did consistently across random restarts.

Figure 3: Implicit rule transitions as we increase $\lambda_1$ and the number of nonzero rows of $A$. Pairs of points represent the fraction of large-magnitude ($c = 0.67$) gradient components in the corners and top-middle for 1000 test examples, which almost always add to 1 (indicating the model is most sensitive to these elements alone, even during transitions). Note there is a wide regime where the model learns a hybrid of both rules.

Figure 4: Rule discovery using find-another-explanation method with 0.67 cutoff and a separate $\lambda_1$ setting for each of $\theta_1$ and $\theta_2$. Note how the first two iterations produce explanations corresponding to the two rules in the dataset while the third produces very noisy explanations (with low accuracies).

However, when we trained our model with a nonzero $A$ (specifically, setting $A_{nd} = 1$ for corners $d$ across examples $n$), we were able to cause it to use the other rule. Figure 3 shows how the model transitions between rules as we vary $\lambda_1$ and the number of examples penalized by $A$. This result demonstrates that the model can be made to learn multiple rules despite only one being commonly reached via standard gradient-based optimization methods. However, it depends on knowing a good setting for $A$, which in this case would still require annotating on the order of $10^3$ examples, or 5% of our dataset (although always including examples with annotations in Adam minibatches let us consistently switch rules with only 50 examples, or 0.2% of the dataset).
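A corner annotation of the kind described here could be built as in the sketch below. The 5×5×3 image layout, row-major flattening, and the function name are assumptions for illustration; only the idea of marking corner pixels for a subset of examples comes from the text.

```python
import numpy as np

def corner_annotations(n_examples, n_annotated, side=5, channels=3):
    # Returns an annotation matrix of shape (n_examples, side*side*channels).
    # Only the first n_annotated rows are nonzero; they mark every color
    # channel of the four corner pixels as "should be irrelevant".
    mask = np.zeros((side, side, channels))
    for r in (0, side - 1):
        for c in (0, side - 1):
            mask[r, c, :] = 1.0
    A = np.zeros((n_examples, side * side * channels))
    A[:n_annotated] = mask.reshape(-1)   # row-major flattening (assumed)
    return A
```

Leaving most rows at zero mirrors the setting in the text, where only a small fraction of examples carries an annotation.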

Finally, Figure 4 shows we can use the find-another-explanation technique from Sec. 3.2 to discover the other rule without being given $A$. Because only two rules lead to high accuracy on the test set, the model performs no better than random guessing when prevented from using either one (although we have to increase the penalty high enough that this accuracy number may be misleading; the essential point is that after the first iteration, explanations stop changing). Lastly, though not directly relevant to the discussion on interpretability and explanation, we demonstrate the potential of explanations to reduce the amount of data required for training in Appendix 7.

4.2 Real-world Datasets

To demonstrate real-world, cross-domain applicability, we test our approach on variants of three familiar machine learning text, image, and tabular datasets:

  • 20 Newsgroups: As in Ribeiro et al. (2016), we test input gradients on the alt.atheism vs. soc.religion.christian subset of the 20 Newsgroups dataset (Lichman, 2013). We used the same two-hidden layer network architecture with a TF-IDF vectorizer with 5000 components, which gave us a 94% accurate model for $A = 0$.

  • Iris-Cancer: We concatenated all examples in classes 1 and 2 from the Iris dataset with the first 50 examples from each class in the Breast Cancer Wisconsin dataset (Lichman, 2013) to create a composite dataset. Despite the dataset’s small size, our network still obtains an average test accuracy of 92% across 350 random training-test splits. However, when we modify our test set to remove the 4 Iris components, average test accuracy falls to 81% with higher variance, suggesting the model learns to depend on Iris features and suffers without them. We verify that our explanations reveal this dependency and that regularizing them avoids it.

  • Decoy MNIST: On the baseline MNIST dataset (LeCun et al., 2010), our network obtains 98% train and 96% test accuracy. However, in Decoy MNIST, images have gray swatches in randomly chosen corners whose shades are functions of their digits in training but are random in test. On this dataset, our model has a higher 99.6% train accuracy but a much lower 55% test accuracy, indicating that the decoy rule misleads it. We verify that both gradient and LIME explanations let users detect this issue and that explanation regularization lets us overcome it.
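To illustrate how a confound like the Decoy MNIST swatch can be constructed, the sketch below adds a gray square to a random corner of each image. The swatch size and especially the default `shade_fn` mapping are placeholders we chose for illustration; they are not the values used in the thesis.

```python
import numpy as np

def add_decoy(images, labels, rng, train=True, size=4, shade_fn=None):
    # images: (N, 28, 28) array with values in [0, 255]; labels: (N,) digits.
    # shade_fn maps a digit to a gray level. The default below is a
    # hypothetical placeholder, not the mapping from the thesis.
    if shade_fn is None:
        shade_fn = lambda digit: 25.0 * (digit + 1)
    out = images.copy()
    n, h, w = out.shape
    for i in range(n):
        # In training the shade depends on the true digit; in test it is
        # drawn from a random digit, so the swatch carries no information.
        digit = labels[i] if train else rng.integers(0, 10)
        r0 = rng.choice([0, h - size])
        c0 = rng.choice([0, w - size])
        out[i, r0:r0 + size, c0:c0 + size] = shade_fn(int(digit))
    return out
```

A model that reads the swatch can fit the training set almost perfectly, yet the same shortcut fails on test images where `train=False` randomizes the shade.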

Figure 5: Words identified by LIME vs. gradients on an example from the atheism vs. Christianity subset of 20 Newsgroups. More examples are available at Words are blue if they support soc.religion.christian and orange if they support alt.atheism, with opacity equal to the ratio of the magnitude of the word’s weight to the largest magnitude weight. LIME generates sparser explanations but the weights and signs of terms identified by both methods match closely. Note that both methods reveal some aspects of the model that are intuitive (“church” and “service” are associated with Christianity), some aspects that are not (“13” is associated with Christianity, “edu” with atheism), and some that are debatable (“freedom” is associated with atheism, “friends” with Christianity).
Figure 6: Input gradient explanations for Decoy MNIST vs. LIME, using the LIME image library Ribeiro (2016). In this example, the model incorrectly predicts 3 rather than 7 because of the decoy swatch.
Figure 7: Iris-Cancer features identified by input gradients vs. LIME, with Iris features highlighted in red. Input gradient explanations are more faithful to the model. Note that most gradients change sign when switching between $\hat{y}_0$ and $\hat{y}_1$, and that the magnitudes of input gradients are different across examples, which provides information about examples’ proximity to the decision boundary.

Input gradients are consistent with sample-based methods such as LIME, and faster. On 20 Newsgroups (Figure 5), input gradients are less sparse but identify all of the same words in the document with similar weights. Note that input gradients also identify words outside the document that would affect the prediction if added.

On Decoy MNIST (Figure 6), both LIME and input gradients reveal that the model predicts 3 rather than 7 due to the color swatch in the corner. Because of their fine-grained resolution, input gradients sometimes better capture counterfactual behavior, where extending or adding lines outside of the digit to either reinforce it or transform it into another digit would change the predicted probability (see also Figure 10). LIME, on the other hand, better captures the fact that the main portion of the digit is salient (because its super-pixel perturbations add and remove larger chunks of the digit).

On Iris-Cancer (Figure 7), input gradients actually outperform LIME. We know from the accuracy difference that Iris features are important to the model’s prediction, but LIME only identifies a single important feature, which is from the Breast Cancer dataset (even when we vary its perturbation strategy). This example, which is tabular and contains continuously valued rather than categorical features, may represent a pathological case for LIME, which operates best when it can selectively mask a small number of meaningful chunks of its inputs to generate perturbed samples. For truly continuous inputs, it should not be surprising that explanations based on gradients perform best.

There are a few other advantages input gradients have over sample-based perturbation methods. On 20 Newsgroups, we noticed that for very long documents, explanations generated by the sample-based method LIME are often overly sparse, and there are many words identified as significant by input gradients that LIME ignores. This may be because the number of features LIME selects must be passed in as a parameter beforehand, and it may also be because LIME only samples a fixed number of times. For sufficiently long documents, it is unlikely that sample-based approaches will mask every word even once, meaning that the output becomes increasingly nondeterministic, which is an undesirable quality for explanations. To resolve this issue, one could increase the number of samples, but that would increase the computational cost since the model must be evaluated at least once per sample to fit a local surrogate. Input gradients, on the other hand, only require on the order of one model evaluation total to generate an explanation of similar quality (generating gradients is similar in complexity to predicting probabilities), and furthermore, this complexity is based on the vector length, not the document length. This issue (underscored by Table 1) highlights some inherent scalability advantages input gradients enjoy over sample-based perturbation methods.

Dataset          LIME     Gradients    Dimension of $x$
Iris-Cancer      0.03s    0.000019s    34
Toy Colors       1.03s    0.000013s    75
Decoy MNIST      1.54s    0.000045s    784
20 Newsgroups    2.59s    0.000520s    5000

Table 1: Gradient vs. LIME runtimes per explanation. Note that each dataset uses a different version of LIME; Iris-Cancer and Toy Colors use lime_tabular with continuous and quartile-discrete perturbation methods, respectively, Decoy MNIST uses lime_image, and 20 Newsgroups uses lime_text. Code was executed on a laptop and input gradient calculations were not optimized for performance, so runtimes are only meant to provide a sense of scale.

Figure 8: Overcoming confounds using explanation constraints on Iris-Cancer (over 350 random train-test splits). By default ($A = 0$), input gradients tend to be large in Iris dimensions, which results in lower accuracy when Iris is removed from the test set. Models trained with $A_{nd} = 1$ in Iris dimensions (full $A$) have almost exactly the same test accuracy with and without Iris.

Figure 9: Training with explanation constraints on Decoy MNIST. Accuracy is low ($A = 0$) on the swatch color-randomized test set unless the model is trained with $A_{nd} = 1$ in swatches (full $A$). In that case, test accuracy matches the same architecture’s performance on the standard MNIST dataset (baseline).

Given annotations, input gradient regularization finds solutions consistent with domain knowledge. Another key advantage of using an explanation method more closely related to our model is that we can then incorporate explanations into our training process, which are most useful when the model faces ambiguities in how to classify inputs. We deliberately constructed the Decoy MNIST and Iris-Cancer datasets to have this kind of ambiguity, where a rule that works in training will not generalize to test. When we train our network on these confounded datasets, their test accuracy is better than random guessing, in part because the decoy rules are not simple and the primary rules not complex, but their performance is still significantly worse than on a baseline test set with no decoy rules. By penalizing explanations we know to be incorrect using the loss function defined in Section 3.1, we are able to recover that baseline test accuracy, which we demonstrate in Figures 8 and 9.

Figure 10: Find-another-explanation results on Iris-Cancer (top; errorbars show standard deviations across 50 trials), 20 Newsgroups (middle; blue supports Christianity and orange supports atheism, word opacity set to magnitude ratio), and Decoy MNIST (bottom, for three values of $\lambda_1$ with scatter opacity set to magnitude ratio cubed). Real-world datasets are often highly redundant and allow for diverse models with similar accuracies. On Iris-Cancer and Decoy MNIST, both explanations and accuracy results indicate we overcome confounds after 1-2 iterations without any prior knowledge about them encoded in $A$.

When annotations are unavailable, our find-another-explanation method discovers diverse classifiers. As we saw with the Toy Color dataset, even if almost every row of $A$ is 0, we can still benefit from explanation regularization (meaning practitioners can gradually incorporate these penalties into their existing models without much upfront investment). However, annotation is never free, and in some cases we either do not know the right explanation or cannot easily encode it. Additionally, we may be interested in exploring the structure of our model and dataset in a less supervised fashion. On real-world datasets, which are usually overdetermined, we can use find-another-explanation to discover $\theta$s in shallower local minima that we would normally never explore. Given enough models right for different reasons, hopefully at least one is right for the right reasons.

Figure 10 shows find-another-explanation results for our three real-world datasets, with example explanations at each iteration above and model train and test accuracy below. For Iris-Cancer, we find that the initial iteration of the model heavily relies on the Iris features and has high train but low test accuracy, while subsequent iterations have lower train but higher test accuracy (with smaller gradients in Iris components). In other words, we spontaneously obtain a more generalizable model without a predefined $A$ alerting us that the first four features are misleading.

Find-another-explanation also overcomes confounds on Decoy MNIST, needing only one iteration to recover baseline accuracy. Bumping $\lambda_1$ too high (to the point where its term is a few orders of magnitude larger than the cross-entropy) results in more erratic behavior. Interestingly, in a process reminiscent of distillation (Papernot et al., 2016c), the gradients themselves become more evenly and intuitively distributed at later iterations. In many cases they indicate that the probabilities of certain digits increase when we brighten pixels along or extend their distinctive strokes, and that they decrease if we fill in unrelated dark areas, which seems desirable. However, by the last iteration, we start to revert to using decoy swatches in some cases.

On 20 Newsgroups, the words most associated with alt.atheism and soc.religion.christian change between iterations but remain mostly intuitive in their associations. Train accuracy mostly remains high while test accuracy is unstable.

For all of these examples, accuracy remains high even as decision boundaries shift significantly. This may be because real-world data tends to contain significant redundancies.

4.3 Limitations

Input gradients provide faithful information about a model’s rationale for a prediction but trade interpretability for efficiency. In particular, when input features are not individually meaningful to users (e.g. for individual pixels or word2vec components), input gradients may be difficult to interpret and $A$ may be difficult to specify. Additionally, because they can be 0 far from the decision boundary, they do not capture the idea of salience as well as other methods (Zeiler and Fergus, 2014; Sundararajan et al., 2017; Montavon et al., 2017; Bach et al., 2015; Shrikumar et al., 2016). However, they are necessarily faithful to the model and easy to incorporate into its loss function. Input gradients are first-order linear approximations of the model; we might call them first-order explanations.

5 Discussion

In this chapter, we showed that:

  • On training sets that contain confounds which would fool any model trained just to make correct predictions, we can use gradient-based explanation regularization to learn models that still generalize to test data. These results imply that gradient regularization actually changes why our model makes predictions.

  • When we lack expert annotations, we can still use our method in an unsupervised manner to discover models that make predictions for different reasons. This “find-another-explanation” technique allowed us to overcome confounds on Decoy MNIST and Iris-Cancer, and even quantify the ambiguity present in the Toy Color dataset.

  • Input gradients are consistent with sample-based methods such as LIME but faster to compute and sometimes more faithful to the model, especially for continuous inputs.

Our consistent results on several diverse datasets show that input gradients merit further investigation as building blocks for optimizable explanations; there exist many options for further advancements such as weighted annotations, different penalty norms, and more general specifications of whether features should be positively or negatively predictive of specific classes for specific inputs.

Finally, our “right for the right reasons” approach may be of use in solving related problems, e.g. in integrating causal inference with deep neural networks or maintaining robustness to adversarial examples (which we discuss in the following chapter). Building on our find-another-explanation results, another promising direction is to let humans in the loop interactively guide models towards correct explanations. Overall, we feel that developing methods of ensuring that models are right for better reasons is essential to overcoming the inherent obstacles to generalization posed by ambiguities in real-world datasets.

6 Cross-Validation

Most regularization parameters are selected to maximize accuracy on a validation set. However, when training and validation sets share the same misleading confounds, validation accuracy may not be a good proxy for test accuracy. Instead, we recommend increasing the explanation regularization strength until the cross-entropy and “right reasons” terms have roughly equal magnitudes (which corresponds to the region of highest test accuracy below). Intuitively, balancing the terms in this way should push our optimization away from cross-entropy minima that violate the explanation constraints specified in $A$ and towards ones that correspond to “better reasons.” Increasing the regularization strength too much makes the cross-entropy term negligible. In that case, our model performs no better than random guessing.

Figure 11: Cross-validating the explanation regularization strength. The regime of highest accuracy (highlighted) is also where the initial cross-entropy and “right reasons” loss terms have similar magnitudes. Exact equality is not required; being an order of magnitude off does not significantly affect accuracy.
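The balancing heuristic described in this section can be sketched in a few lines. This is only an illustration, not the thesis code: the function name `balancing_strength`, its arguments, and the optional candidate grid are hypothetical, and it assumes the initial cross-entropy and the unscaled “right reasons” penalty have already been measured on the training set.

```python
import numpy as np

def balancing_strength(cross_entropy, raw_penalty, grid=None):
    """Return a regularization strength that roughly equalizes the two loss terms.

    cross_entropy: initial cross-entropy of the model (hypothetical input).
    raw_penalty:   initial explanation penalty before scaling (hypothetical input).
    grid:          optional candidate strengths; the one closest to the
                   balancing ratio on a log scale is returned.
    """
    ratio = cross_entropy / raw_penalty
    if grid is None:
        return ratio
    grid = np.asarray(grid, dtype=float)
    # Compare on a log scale, since being off by less than an order of
    # magnitude is reported not to matter much.
    return float(grid[np.argmin(np.abs(np.log10(grid) - np.log10(ratio)))])
```

Snapping to a coarse logarithmic grid reflects the observation that exact equality of the two terms is not required.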

7 Learning with Less Data

It is natural to ask whether explanations can reduce data requirements. Here we explore that question on the Toy Color dataset using four variants of the annotation matrix $A$ (with the regularization strength chosen to match loss terms at each training set size).

Figure 12: Explanation regularization can reduce data requirements.

We find that when $A$ is set to the Pro-Rule 1 mask, which penalizes all pixels except the corners, we reach 95% accuracy with fewer than 100 examples (as compared to training without an explanation penalty, where we need almost 10000). Penalizing the top-middle pixels (Anti-Rule 2) or all pixels except the top-middle (Pro-Rule 2) also consistently improves accuracy relative to training on data alone. Penalizing the corners (Anti-Rule 1), however, reduces accuracy until we reach a threshold number of examples. This may be because the corner pixels can match in 4 ways, while the top-middle pixels can differ in more ways, suggesting that Rule 2 could be inherently harder to learn from data and positional explanations alone.

8 Simultaneous Find-Another-Explanation

In Section 3.2, we introduced a method of training classifiers to make predictions for different reasons by sequentially augmenting $A$ to penalize more features. However, as our ensemble grows, $A$ can saturate to $1$ (every feature penalized), and subsequent models will be trained with uniform gradient regularization. While these models may have desirable properties (which we explore in the following chapter), they will not be diverse.

As a simple example, consider a 2D dataset with one class confined to the first quadrant and the other confined to the third. In theory, we have a full degree of decision freedom; it should be possible to learn two perfect and fully orthogonal boundaries (one horizontal, one vertical). However, when we train our first MLP, it learns a diagonal surface; both features have large gradients everywhere, so $A$ saturates immediately. To resolve this, we propose a simultaneous training procedure:

$$\theta_1^*, \ldots, \theta_M^* = \arg\min_{\theta_1, \ldots, \theta_M} \sum_{a=1}^{M} L(\theta_a) + \sum_{a=1}^{M} \sum_{b=a+1}^{M} \mathrm{sim}\left(\nabla_X f_{\theta_a}(X), \nabla_X f_{\theta_b}(X)\right),$$

where $L$ refers to our single-model loss function, and for our similarity measure we use the squared cosine similarity

$$\mathrm{sim}(v, w) = \frac{(v^\top w)^2}{(v^\top v)(w^\top w) + \epsilon},$$

where we add $\epsilon$ to the denominator for numerical stability. Squaring the cosine similarity ensures our penalty is nonnegative, is minimized by orthogonal boundaries, and is soft for nearly orthogonal boundaries. We show in Figure 13 that this lets us obtain the two desired models.

Figure 13: Toy 2D problem with one degree of decision boundary freedom. Across random restarts (left two plots), we tend to learn a boundary in which both features are significant, which prevents sequential find-another-explanation from producing diverse models. If we jointly train two models with a penalty on the cosine similarity of their gradients (right plot), they end up with orthogonal boundaries.
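As a rough illustration of the similarity penalty used in simultaneous training, the sketch below computes a squared cosine similarity between two gradient vectors with a small constant added to the denominator. The function name and the default value of `eps` are assumptions made for this example; the source does not state the constant it used.

```python
import numpy as np

def squared_cosine_similarity(g1, g2, eps=1e-6):
    """Squared cosine similarity between two input-gradient vectors.

    eps is an assumed stabilizing constant added to the denominator.
    """
    g1 = np.ravel(np.asarray(g1, dtype=float))
    g2 = np.ravel(np.asarray(g2, dtype=float))
    numerator = np.dot(g1, g2) ** 2
    denominator = np.dot(g1, g1) * np.dot(g2, g2) + eps
    return numerator / denominator
```

The penalty is zero for orthogonal gradients, close to one for parallel or antiparallel gradients, and stays finite when one gradient is exactly zero.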

9 Introduction

In the previous chapter, we used input gradient penalties to encourage neural networks to make predictions for specific reasons. We demonstrated this on “decoy” datasets deliberately designed to deceive models making decisions for different reasons. This philosophy of testing – that we should measure generalization by testing on data from a different distribution than we trained on – can be taken to its extreme by testing models in an adversarial setting, where neural networks have known vulnerabilities (Szegedy et al., 2013). In this chapter, we consider whether a domain knowledge-agnostic application of explanation regularization (a uniform L2 penalty on input gradients, similar in spirit to Ridge regression on the model’s local linear approximations) could help defend against adversarial examples.

Adversarial examples pose serious obstacles for the adoption of neural networks in settings which are security-sensitive or have legal ramifications (Kang and Kang, 2017). Although many techniques for generating these examples (which we call “attacks”) require access to model parameters, Papernot et al. (2017) have shown that it is possible and even practical to attack black-box models in the real world, in large part because of transferability; examples generated to fool one model tend to fool all models trained on the same dataset. Particularly for images, these adversarial examples can be constructed to fool models across a variety of scales and perspectives (Athalye and Sutskever, 2017), which poses a problem for the adoption of deep learning models in systems like self-driving cars.

Although there has recently been a great deal of research in adversarial defenses, many of these methods have struggled to achieve robustness to transferred adversarial examples (Tramèr et al., 2017b). Some of the most effective defenses simply detect and reject them rather than making predictions (Xu et al., 2017). The most common, “brute force” solution is adversarial training, where we include a mixture of normal and adversarially-generated examples in the training set (Kurakin et al., 2016b). However, Tramèr et al. (2017a) show that the robustness adversarial training provides can be circumvented by randomizing or transferring perturbations from other models (though ensembling helps).

As we noted earlier, domain experts are also often concerned that DNN predictions are uninterpretable. The lack of interpretability is particularly problematic in domains where algorithmic bias is often a factor (Angwin et al., 2016) or in medical contexts where safety risks can arise when there is a mismatch between how a model is trained and used (Caruana et al., 2015). For computer vision models (the primary target of adversarial attacks), the most common class of explanation is the saliency map, either at the level of raw pixels, grid chunks, or superpixels (Ribeiro et al., 2016).

The local linear approximation provided by raw input gradients (Baehrens et al., 2010) is sometimes used for pixel-level saliency maps (Simonyan et al., 2013). However, computer vision practitioners tend not to examine raw input gradients because they are noisy and difficult to interpret. This issue has spurred the development of techniques like integrated gradients (Sundararajan et al., 2017) and SmoothGrad (Smilkov et al., 2017) that generate smoother, more interpretable saliency maps from noisy gradients. The rationale behind these techniques is that, while the local behavior of the model may be noisy, examining the gradients over larger length scales in input space provides a better intuition about the model’s behavior.

However, raw input gradients are exactly what many attacks use to generate adversarial examples. Explanation techniques which smooth out gradients in background pixels may be inappropriately hiding the fact that the model is quite sensitive to them. We consider that perhaps the need for these smoothing techniques in the first place is indicative of a problem with our models, related to their adversarial vulnerability and capacity to overfit. Perhaps it is fundamentally hard for adversarially vulnerable models to be interpretable.

On the other hand, perhaps it is hard for interpretable models to be adversarially vulnerable. Our hypothesis is that by training a model to have smooth input gradients with fewer extreme values, it will not only be more interpretable but also more resistant to adversarial examples. In the experiments that follow we confirm this hypothesis using uniform gradient regularization, which optimizes the model to have smooth input gradients with respect to its predictions during training. Using this technique, we demonstrate robustness to adversarial examples across multiple model architectures and datasets, and in particular demonstrate robustness to transferred adversarial examples: gradient-regularized models maintain significantly higher accuracy on examples generated to fool other models than baselines. Furthermore, both qualitatively and in human subject experiments, we find that adversarial examples generated to fool gradient-regularized models are, in a particular sense, more “interpretable”: they fool humans as well.

10 Background

In this section, we will (re)introduce notation, and give a brief overview of the baseline attacks and defenses against which we will test and compare our methods. The methods we will analyze again apply to all differentiable classification models $f_\theta(X)$, which are functions parameterized by $\theta$ that return predictions $\hat{y} \in \mathbb{R}^{N \times K}$ given inputs $X \in \mathbb{R}^{N \times D}$. These predictions indicate the probabilities that each of $N$ inputs in $D$ dimensions belong to each of $K$ class labels. To train these models, we try to find sets of parameters $\theta^*$ that minimize the total information distance between the predictions $\hat{y}$ and the true labels $y$ (also $\in \mathbb{R}^{N \times K}$, one-hot encoded) on a training set:

$$\theta^* = \arg\min_\theta \sum_{n=1}^{N} \sum_{k=1}^{K} -y_{nk} \log f_\theta(X_n)_k,$$

which we will sometimes write as

$$\arg\min_\theta H(y, \hat{y}),$$

with $H$ giving the sum of the cross entropies between the predictions and the labels.
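For concreteness, the summed cross-entropy between one-hot labels and predicted probabilities can be computed as in the sketch below. The function name and the small `tiny` constant guarding against `log(0)` are assumptions of this example rather than part of the original formulation.

```python
import numpy as np

def total_cross_entropy(y, y_hat, tiny=1e-12):
    """Sum over examples and classes of -y * log(y_hat).

    y:     one-hot labels, shape (N, K).
    y_hat: predicted class probabilities, shape (N, K).
    tiny:  assumed constant that avoids taking log(0).
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    return float(-np.sum(y * np.log(y_hat + tiny)))
```

Because the labels are one-hot, only the log-probability assigned to each example's true class contributes to the sum.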

10.1 Attacks

10.1.1 Fast Gradient Sign Method (FGSM)

Goodfellow et al. (2014) introduced this first method of generating adversarial examples by perturbing inputs in a manner that increases the local linear approximation of the loss function (here the cross-entropy $H(y, \hat{y})$ between labels and predictions):

$$X_{\text{FGSM}} = X + \epsilon \, \mathrm{sign}\left(\nabla_X H(y, \hat{y})\right)$$

If $\epsilon$ is small, these adversarial examples are indistinguishable from normal examples to a human, but the network performs significantly worse on them.

Kurakin et al. (2016a) noted that one can iteratively perform this attack with a small $\epsilon$ to induce misclassifications with a smaller total perturbation (by following the nonlinear loss function in a series of small linear steps rather than one large linear step).
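To make the attack concrete, here is a minimal sketch of single-step and iterated FGSM. It assumes a softmax-regression model (weights `W`, bias `b`) so that the input gradient of the cross-entropy has a closed form; the networks in the experiments are CNNs, and all names here are hypothetical. Clipping to a valid pixel range is omitted.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def input_gradient(X, y, W, b):
    """Gradient of the summed cross-entropy w.r.t. X for softmax regression."""
    return (softmax(X @ W + b) - y) @ W.T

def fgsm(X, y, W, b, eps, steps=1):
    """Step along the sign of the loss gradient; steps > 1 gives the iterated variant."""
    X_adv = X.copy()
    for _ in range(steps):
        X_adv = X_adv + eps * np.sign(input_gradient(X_adv, y, W, b))
    return X_adv
```

With `steps=1` this is the single-step attack; with several steps and a smaller `eps` the gradient is recomputed after each small move, which is the iterated variant described above.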

10.1.2 Targeted Gradient Sign Method (TGSM)

A simple modification of the Fast Gradient Sign Method is the Targeted Gradient Sign Method, introduced by Kurakin et al. (2016a). In this attack, we attempt to decrease a modified version of the loss function (the cross-entropy $H$ evaluated against target labels) that encourages the model to misclassify examples in a specific way:

$$X_{\text{TGSM}} = X - \epsilon \, \mathrm{sign}\left(\nabla_X H(y_{\text{target}}, \hat{y})\right),$$

where $y_{\text{target}}$ encodes an alternate set of labels we would like the model to predict instead. In the digit classification experiments below, we often picked targets by incrementing the labels $y$ by 1 (modulo 10), which we will refer to as $y_{+1}$. The TGSM can also be performed iteratively.
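The targeted variant differs from FGSM only in the labels used and the direction of the step. The sketch below again assumes a softmax-regression model with a closed-form input gradient; the helper names (`increment_labels`, `tgsm`) are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def increment_labels(labels, num_classes=10):
    """Targets obtained by adding 1 to each integer label, modulo num_classes."""
    return (np.asarray(labels) + 1) % num_classes

def tgsm(X, y_target, W, b, eps, steps=1):
    """Step against the sign of the gradient of the loss w.r.t. the target labels."""
    X_adv = X.copy()
    for _ in range(steps):
        grad = (softmax(X_adv @ W + b) - y_target) @ W.T
        X_adv = X_adv - eps * np.sign(grad)  # descend the target loss
    return X_adv
```

Descending the loss computed against the target labels raises the predicted probability of the target class, whereas FGSM ascends the loss computed against the true labels.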

10.1.3 Jacobian-based Saliency Map Approach (JSMA)

The final attack we consider, the Jacobian-based Saliency Map Approach (JSMA), also takes an adversarial target vector $y_{\text{target}}$. It iteratively searches for pixels or pairs of pixels in $X$ to change such that the probability of the target label is increased and the probabilities of all other labels are decreased. This method is notable for producing examples that have only been changed in several dimensions, which can be hard for humans to detect. For a full description of the attack, we refer the reader to Papernot et al. (2016b).

10.2 Defenses

As baseline defenses, we consider defensive distillation and adversarial training. To simplify comparison, we omit defenses (Xu et al., 2017; Nayebi and Ganguli, 2017) that are not fully architecture-agnostic or which work by detecting and rejecting adversarial examples.

10.2.1 Distillation

Distillation, originally introduced by Ba and Caruana (2014), was first examined as a potential defense by Papernot et al. (2016c). The main idea is that we train the model twice, initially using the one-hot ground truth labels but ultimately using the initial model’s softmax probability outputs, which contain additional information about the problem. Since the normal softmax function tends to converge very quickly to one-hot-ness, we divide all of the logit network outputs (which we will call $\hat{z}_k$ instead of the probabilities $\hat{y}_k$) by a temperature $T$ (during training but not evaluation):

$$f_{T,\theta}(X_n)_k = \frac{e^{\hat{z}_k(X_n)/T}}{\sum_{i=1}^{K} e^{\hat{z}_i(X_n)/T}},$$

where we use $f_{T,\theta}$ to denote a network ending in a softmax with temperature $T$. Note that as $T$ approaches $\infty$, the predictions converge to $\frac{1}{K}$. The full process can be expressed as

$$\theta^0 = \arg\min_\theta \sum_{n=1}^{N} \sum_{k=1}^{K} -y_{nk} \log f_{T,\theta}(X_n)_k, \qquad \theta^* = \arg\min_\theta \sum_{n=1}^{N} \sum_{k=1}^{K} -f_{T,\theta^0}(X_n)_k \log f_{T,\theta}(X_n)_k.$$


Distillation is usually used to help small networks achieve the same accuracy as larger DNNs, but in a defensive context, we use the same model twice. It has been shown to be an effective defense against white-box FGSM attacks, but Carlini and Wagner (2016) have shown that it is not robust to all kinds of attacks. We will see that the precise way it defends against certain attacks is qualitatively different than gradient regularization, and that it can actually make the models more vulnerable to attacks than an undefended model.
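The effect of the softmax temperature can be seen in a small sketch. The function name is hypothetical; the computation is the standard softmax applied to logits divided by a temperature.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Softmax over the last axis of logits / T."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)
```

At `T=1` this is the ordinary softmax. Raising `T` flattens the distribution, and for very large `T` every class receives a probability close to one over the number of classes, so the soft labels passed to the second training round carry more information about non-predicted classes.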

10.2.2 Adversarial Training

In adversarial training (Kurakin et al., 2016b), we increase robustness by injecting adversarial examples into the training procedure. We follow the method implemented in Papernot et al. (2016a), where we augment the network to run the FGSM on the training batches and compute the model’s loss function as the average of its loss on normal and adversarial examples without allowing gradients to propagate so as to weaken the FGSM attack (which would also make the method second-order). We compute FGSM perturbations with respect to predicted rather than true labels to prevent “label leaking,” where our model learns to classify adversarial examples more accurately than regular examples.
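The loss described above can be sketched as follows, again assuming a softmax-regression model so that the FGSM step has a closed form. This is not the Cleverhans implementation; the function names are hypothetical. In an automatic-differentiation framework the adversarial inputs would be wrapped in a stop-gradient so the attack is treated as a constant during the parameter update.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(X, y, W, b):
    return float(-np.sum(y * np.log(softmax(X @ W + b))))

def adversarial_training_loss(X, y, W, b, eps):
    """Average of the loss on clean inputs and on FGSM-perturbed inputs."""
    probs = softmax(X @ W + b)
    # Attack the predicted labels rather than the true ones to avoid label leaking.
    y_pred = np.eye(probs.shape[1])[probs.argmax(axis=1)]
    X_adv = X + eps * np.sign((probs - y_pred) @ W.T)  # treated as a constant
    return 0.5 * (cross_entropy(X, y, W, b) + cross_entropy(X_adv, y, W, b))
```

Both loss terms are evaluated against the true labels `y`; only the direction of the perturbation is derived from the model's own predictions.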

11 Gradient Regularization

We defined our “right for the right reasons” objective in the previous chapter using an L2 penalty on the gradient of the model’s predictions across classes with respect to input features marked irrelevant by domain experts. We encoded their domain knowledge using an annotation matrix $A$. If we set $A = 1$, however, and consider only the log-probabilities of the predicted classes, we recover what Drucker and Le Cun (1992) introduced as “double backpropagation”, which trains neural networks by minimizing not just the “energy” of the network but the rate of change of that energy with respect to the input features. In their formulation the energy is a quadratic loss, but we can reformulate it almost equivalently using the cross-entropy:

$$\theta^* = \arg\min_\theta \sum_{n=1}^{N} \sum_{k=1}^{K} -y_{nk} \log f_\theta(X_n)_k + \lambda \sum_{d=1}^{D} \sum_{n=1}^{N} \left( \frac{\partial}{\partial X_{nd}} \sum_{k=1}^{K} -y_{nk} \log f_\theta(X_n)_k \right)^2,$$

whose objective we can write a bit more concisely as

$$\arg\min_\theta H(y, \hat{y}) + \lambda \left\lVert \nabla_X H(y, \hat{y}) \right\rVert_2^2,$$

where $\lambda$ is again a hyperparameter specifying the penalty strength. The intuitive objective of this function is to ensure that if any input changes slightly, the divergence between the predictions and the labels will not change significantly (though including this term does not guarantee Lipschitz continuity everywhere). Double backpropagation was mentioned as a potential adversarial defense in the same paper which introduced defensive distillation (Papernot et al., 2016c), but at the time of publication, its effectiveness in this respect had not yet been analyzed in the literature, though Gu and Rigazio (2014) previously and Hein and Andriushchenko (2017) and Czarnecki et al. (2017) concurrently consider related objectives, and Raghunathan et al. (2018) derive and minimize an upper bound on adversarial vulnerability based on the maximum gradient norm in a ball around each training input. These works also provide stronger theoretical explanations for why input gradient regularization is effective, though they do not analyze its relationship to model interpretability. In this work, we interpret gradient regularization as a quadratic penalty on our model’s saliency map.
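A minimal sketch of the gradient-regularized objective, evaluated for a softmax-regression model whose input gradient is available in closed form. The model choice and all names are assumptions made for illustration; for a deep network the input gradient would come from automatic differentiation, and minimizing the penalty requires differentiating through that gradient (hence “double backpropagation”).

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def gradient_regularized_loss(X, y, W, b, lam):
    """Cross-entropy plus lam times the squared L2 norm of its input gradient."""
    probs = softmax(X @ W + b)
    cross_entropy = -np.sum(y * np.log(probs))
    input_grad = (probs - y) @ W.T  # d(cross_entropy)/dX for softmax regression
    return float(cross_entropy + lam * np.sum(input_grad ** 2))
```

Setting `lam=0` recovers the plain cross-entropy, while larger values penalize models whose loss changes quickly under small input perturbations.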

12 Experiments

12.0.1 Datasets and Models

We evaluated the robustness of distillation, adversarial training, and gradient regularization to the FGSM, TGSM, and JSMA on MNIST (LeCun et al., 2010), Street-View House Numbers (SVHN) (Netzer et al., 2011), and notMNIST (Butalov, 2011). On all datasets, we test a simple convolutional neural network with 5x5x32 and 5x5x64 convolutional layers followed by 2x2 max pooling and a 1024-unit fully connected layer, with batch-normalization after all convolutions and both batch-normalization and dropout on the fully-connected layer. All models were implemented in TensorFlow and trained using Adam (Kingma and Ba, 2014) with and for 15000 minibatches of size 256. For SVHN, we prepare training and validation sets as described in Sermanet et al. (2012), converting the images to grayscale following Grundland and Dodgson (2007) and applying both global and local contrast normalization.

12.0.2 Attacks and Defenses

Figure 14: Accuracy of all CNNs on FGSM examples generated to fool undefended models, defensively distilled, adversarially trained, and gradient regularized models (from left to right) on MNIST, SVHN, and notMNIST (from top to bottom). Gradient-regularized models are the most resistant to other models’ adversarial examples at high $\epsilon$, while all models are fooled by gradient-regularized model examples. On MNIST and notMNIST, distilled model examples are usually identical to non-adversarial examples (due to gradient underflow), so they fail to fool any of the other models.

Figure 15: Applying both gradient regularization and adversarial training (“both defenses”) allows us to obtain maximal robustness to white-box and normal black-box attacks on SVHN (with a very slight label-leaking effect on the FGSM, perhaps due to the inclusion of the term). However, no models are able to maintain robustness to black-box attacks using gradient regularization.

For adversarial training and JSMA example generation, we used the Cleverhans adversarial example library (Papernot et al., 2016a). For distillation, we used a softmax temperature of , and for adversarial training, we trained with FGSM perturbations at , averaging normal and adversarial losses. For gradient regularized models, we use double backpropagation, which provided the best robustness, and train over a spread of $\lambda$ values. We choose the $\lambda$ with the highest accuracy against validation black-box FGSM examples but which is still at least 97% as accurate on normal validation examples (though accuracy on normal examples tended not to be significantly different). Code for all models and experiments has been open-sourced.
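The selection rule for the penalty strength can be sketched as a simple filter followed by a maximum. The data layout, the function name, and the choice to measure “97% as accurate” relative to a supplied reference accuracy are all assumptions of this example; the text does not specify the reference, and the numbers in the usage note are made up for illustration.

```python
def select_strength(results, reference_accuracy, tolerance=0.97):
    """Pick a penalty strength from validation results.

    results: list of (strength, normal_val_accuracy, blackbox_adv_val_accuracy).
    reference_accuracy: normal validation accuracy used as the baseline
        for the tolerance check (assumed, e.g. that of an undefended model).
    """
    eligible = [r for r in results if r[1] >= tolerance * reference_accuracy]
    if not eligible:
        raise ValueError("no candidate meets the normal-accuracy requirement")
    # Among eligible candidates, prefer the best black-box adversarial accuracy.
    return max(eligible, key=lambda r: r[2])[0]
```

For example, with hypothetical results `[(1, 0.99, 0.40), (100, 0.985, 0.70), (10000, 0.90, 0.85)]` and a reference accuracy of 0.99, the strongest penalty is excluded because its normal accuracy falls below the threshold, and 100 is selected.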

12.0.3 Evaluation Metrics

For the FGSM and TGSM, we test all models against adversarial examples generated for each model and report accuracy. Testing this way allows us to simultaneously measure white- and black-box robustness.

On the JSMA and iterated TGSM, we found that measuring accuracy was no longer a good evaluation metric, since for our gradient-regularized models, the generated adversarial examples often resembled their targets more than their original labels. To investigate this, we performed a human subject experiment to evaluate the legitimacy of adversarial example misclassifications.

12.1 Accuracy Evaluations (FGSM and TGSM)

12.1.1 FGSM Robustness

Figure 14 shows the results of our defenses’ robustness to the FGSM on MNIST, SVHN, and notMNIST for our CNN at a variety of perturbation strengths $\epsilon$. Consistently across datasets, we find that gradient-regularized models exhibit strong robustness to black-box transferred FGSM attacks (examples produced by attacking other models). Although adversarial training sometimes performs slightly better at the value of $\epsilon$ we used in training, gradient regularization generally surpasses it at higher $\epsilon$ (see the green curves in the leftmost plots).

The story with white-box attacks is more interesting. Gradient-regularized models are generally more robust to them than undefended models (visually, the green curves in the rightmost plots fall more slowly than the blue curves in the leftmost plots). However, accuracy still eventually falls for them, and it does so faster than for adversarial training. Even though their robustness to white-box attacks seems lower, though, the examples produced by those white-box attacks actually fool all other models equally well. This effect is particularly pronounced on SVHN. In this respect, gradient regularization may hold promise not just as a defense but as an attack, if examples generated to fool gradient-regularized models are inherently more transferable.

Models trained with defensive distillation in general perform no better and often worse than undefended models. Remarkably, except on SVHN, attacks against distilled models actually fail to fool all models. Closer inspection of distilled model gradients and examples themselves reveals that this occurs because distilled FGSM gradients vanish – so the examples are not perturbed at all. As soon as we obtain a nonzero perturbation from a different model, distillation’s appearance of robustness vanishes as well.

Although adversarial training and gradient regularization seem comparable in terms of accuracy, they work for different reasons and can be applied in concert to increase robustness, which we show in Figure 15. In Figure 16 we also show that, on normal and adversarially trained black-box FGSM attacks, models trained with these two defenses are fooled by different sets of adversarial examples. We provide intuition for why this might be the case in Figure 17.

Figure 16: Venn diagrams showing overlap in which MNIST FGSM examples, generated for normal, adversarially trained, and gradient regularized models, fool all three. Undefended models tend to be fooled by examples from all models, while the sets of adversarially trained model FGSM examples that fool the two defended models are closer to disjoint. Gradient-regularized model FGSM examples fool all models. These results suggest that ensembling different forms of defense may be effective in defending against black box attacks (unless those black box attacks use a gradient-regularized proxy).

Figure 17: Conceptual illustration of the difference between gradient regularization and gradient masking. In (idealized) gradient masking, input gradients are completely uninformative, so following them doesn’t affect either the masked model’s predictions or those of any other model. In gradient regularization, gradients actually become more informative, so following them will ultimately fool all models. However, because gradients are also smaller, perturbations need to be larger to flip predictions. Unregularized, unmasked models are somewhere in between. We see quantitative support for this interpretation in Figure 16, as well as qualitative evidence in Figure 22.

Figure 18: CNN accuracy on TGSM examples generated to fool the four models on three datasets (see Figure 14 for more explanation). Gradient-regularized models again exhibit robustness to other models’ adversarial examples. Distilled model adversarial perturbations fool other models again since their input gradients no longer underflow.

12.1.2 TGSM Robustness

Against the TGSM attack (Figure 18), defensively distilled model gradients no longer vanish, and accordingly these models start to show the same vulnerability to adversarial attacks as others. Gradient-regularized models still exhibit the same robustness even at large perturbations $\epsilon$, and again, examples generated to fool them fool other models equally well.

Figure 19: Distributions of (L2 norm) magnitudes of FGSM input gradients (top), TGSM input gradients (middle), and predicted log probabilities across all classes (bottom) for each defense. Note the logarithmic scales. Gradient-regularized models tend to assign non-predicted classes higher probabilities, and the L2 norms of the input gradients of their FGSM and TGSM loss function terms have similar orders of magnitude. Distilled models (evaluated at ) assign extremely small probabilities to all but the predicted class, and their TGSM gradients explode while their FGSM gradients vanish (we set a minimum value of to prevent underflow). Normal and adversarially trained models lie somewhere in the middle.

One way to better understand the differences between gradient-regularized, normal, and distilled models is to examine the log probabilities they output and the norms of their loss function input gradients, whose distributions we show in Figure 19 for MNIST. We can see that the different defenses have very different statistics. Probabilities of non-predicted classes tend to be small but remain nonzero for gradient-regularized models, while they vanish on defensively distilled models evaluated at (despite distillation’s stated purpose of discouraging certainty). Perhaps because , defensively distilled models’ non-predicted log probability input gradients are the largest by many orders of magnitude, while gradient-regularized models’ remain controlled, with much smaller means and variances. The other models lie between these two extremes. While we do not have a strong theoretical argument about what input gradient magnitudes should be, we believe it makes intuitive sense that having less variable, well-behaved, and non-vanishing input gradients should be associated with robustness to attacks that consist of small perturbations in input space.

Figure 20: Results of applying the JSMA to MNIST 0 and 1 images with maximum distortion parameter for a distilled model (left) and a gradient-regularized model (right). Examples in each row start out as the highlighted digit but are modified until the model predicts the digit corresponding to their column or the maximum distortion is reached.

12.2 Human Subject Study (JSMA and Iterated TGSM)

12.2.1 Need for a Study

Accuracy scores against the JSMA can be misleading, since without a maximum distortion constraint it necessarily runs until the model predicts the target. Even with such a constraint, the perturbations it creates sometimes alter the examples so much that they no longer resemble their original labels, and in some cases bear a greater resemblance to their targets. Figure 20 shows JSMA examples on MNIST for gradient-regularized and distilled models which attempt to convert 0s and 1s into every other digit. Although all of the perturbations “succeed” in changing the model’s prediction, in the gradient-regularized case, many of the JSMA examples strongly resemble their targets.

The same issues occur for other attack methods, particularly the iterated TGSM, for which we show confusion matrices for different models and datasets in Figure 21. For the gradient-regularized models, these pseudo-adversarial examples quickly become almost prototypical examples of their targets, which is not reflected in accuracies with respect to the original labels.

Figure 21: Partial confusion matrices showing results of applying the iterated TGSM for 15 iterations at . Each row is generated from the same example but modified to make the model predict every other class. TGSM examples generated for gradient-regularized models (right) resemble their targets more than their original labels and may provide insight into what the model has learned. Animated versions of these examples can be seen at

To test these intuitions more rigorously, we ran a small pilot study with 11 subjects to measure whether they found examples generated by these methods to be more or less plausible instances of their targets.

12.2.2 Study Protocol

The pilot study consisted of a quantitative and a qualitative portion. In the quantitative portion, subjects were shown 30 images of MNIST JSMA or SVHN iterated TGSM examples. Each of the 30 images corresponded to one original digit (from 0 to 9) and one model (distilled, gradient-regularized, or undefended). Note that for this experiment, we used gradient regularization, ran the TGSM for just 10 steps, and trained models for 4 epochs at a learning rate of 0.001. This procedure was sufficient to produce examples with explanations similar to the longer training procedure used in our earlier experiments, and actually increased the robustness of the undefended models (adversarial accuracy tends to fall with training iteration). Images were chosen uniformly at random from a larger set of 45 examples that corresponded to the first 5 images of the original digit in the test set transformed using the JSMA or iterated TGSM to each of the other 9 digits (we ensured that all models misclassified all examples as their target). Subjects were not given the original label, but were asked to input what they considered the most and second-most plausible predictions for the image that they thought a reasonable classifier would make (entering N/A if they thought no label was a plausible choice). In the qualitative portion that came afterwards, users were shown three 10x10 confusion matrices for the different defenses on MNIST (Figure 20 shows the first two rows) and were asked to write comments about the differences between the examples. Afterwards, there was a short group discussion. This study was performed in compliance with the institution’s IRB.

             MNIST (JSMA)                         SVHN (TGSM)
Model        human fooled   mistake reasonable   human fooled   mistake reasonable
normal       2.0%           26.0%                40.0%          63.3%
distilled    0.0%           23.5%                1.7%           25.4%
grad. reg.   16.4%          41.8%                46.3%          81.5%

Table 2: Quantitative feedback from the human subject experiment. “human fooled” columns record what percentage of examples were classified by humans as most plausibly their adversarial targets, and “mistake reasonable” records how often humans either rated the target plausible or marked the image unrecognizable as any label (N/A).

12.2.3 Study Results

Table 2 shows quantitative results from the human subject experiment. Overall, subjects found gradient-regularized model adversarial examples most convincing. On SVHN and especially MNIST, humans were most likely to think that gradient-regularized (rather than distilled or normal) adversarial examples were best classified as their target rather than their original digit. Additionally, when they did not consider the target the most plausible label, they were most likely to consider gradient-regularized model mispredictions “reasonable” (which we define in Table 2), and more likely to consider distilled model mispredictions unreasonable. p-values for the differences between normal and gradient regularized unreasonable error rates were 0.07 for MNIST and 0.08 for SVHN.
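The two quantities reported in Table 2 can be computed from raw responses roughly as follows. The response encoding (a tuple of target, first choice, and second choice, with `None` standing for N/A) and the function name are assumptions of this sketch, not the format used in the study.

```python
def study_rates(responses):
    """Return (human_fooled_rate, mistake_reasonable_rate).

    responses: list of (target, first_choice, second_choice) tuples,
    where None encodes an N/A answer (hypothetical encoding).
    """
    n = len(responses)
    # Fooled: the subject's most plausible label is the adversarial target.
    fooled = sum(first == target for target, first, _ in responses)
    # Reasonable: the target is among the plausible labels, or the image
    # was judged unrecognizable as any label.
    reasonable = sum(
        first is None or target in (first, second)
        for target, first, second in responses
    )
    return fooled / n, reasonable / n
```

Under this encoding every “fooled” response also counts as “reasonable”, which matches the table, where the second rate is never smaller than the first.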

In the qualitative portion of the study (comparing MNIST JSMA examples), all of the written responses described significant differences between the gradient-regularized model’s JSMA examples and those of the other two methods. Many of the examples for the gradient-regularized model were described as “actually fairly convincing,” while the normal and distilled models were said to “seem to be most easily fooled by adding spurious noise.” Few commenters indicated any differences between the normal and distilled examples, with several saying that “there doesn’t seem to be [a] stark difference” or that they “couldn’t describe the difference” between them. In the group discussion one subject remarked on how the perturbations to the gradient-regularized model felt “more intentional”, and others commented on how certain transitions between digits led to very plausible fakes while others seemed inherently harder. Although the study was small, both its quantitative and qualitative results support the claim that gradient regularization, at least for the two CNNs on MNIST and SVHN, is a credible defense against the JSMA and the iterated TGSM, and that distillation is not.

12.3 Connections to Interpretability

Figure 22: Input gradients that provide a local linear approximation of normal models (top), distilled models at (second from top), adversarially trained models (middle), and models trained with and gradient regularization (bottom two). Whitening black pixels or darkening white pixels makes the model more certain of its prediction. In general, regularized model gradients appear smoother and make more intuitive sense as local linear approximations.

Finally, we present a qualitative evaluation suggesting a connection between adversarial robustness and interpretability. In the literature on explanations, input gradients are frequently used as explanations (Baehrens et al., 2010), but sometimes they are noisy and not interpretable on their own. In those cases, smoothing techniques have been developed (Smilkov et al., 2017; Shrikumar et al., 2016; Sundararajan et al., 2017) to generate more interpretable explanations, but we have already argued that these techniques may obscure information about the model’s sensitivity to background features.

We hypothesized that if the models had more interpretable input gradients without the need for smoothing, then perhaps their adversarial examples, which are generated directly from their input gradients, would be more interpretable as well. That is, the adversarial example would be more obviously transformative away from the original class label and towards another. The results of the user study show that our gradient-regularized models have this property; here we ask if the gradients are more interpretable as explanations.

In Figure 22 we visualize input gradients across models and datasets, and while we cannot make any quantitative claims, there does appear to be a qualitative difference in the interpretability of the input gradients between the gradient-regularized models (which were relatively robust to adversarial examples) and the normal and distilled models (which were vulnerable to them). Adversarially trained models seem to exhibit slightly more interpretable gradients, but not nearly to the same degree as gradient-regularized models. When we repeatedly apply input gradient-based perturbations using the iterated TGSM (Figure 21), this difference in interpretability between models is greatly magnified, and the results for gradient-regularized models seem to provide insight into what the model has learned. When gradients become interpretable, adversarial images start resembling feature visualizations (Olah et al., 2017); in other words, they become explanations.

13 Discussion

In this chapter, we showed that:

  • Gradient regularization slightly outperforms adversarial training (the SOTA) as a defense against black-box transferred FGSM examples from undefended models.

  • Gradient regularization significantly increases robustness to white-box attacks, though not quite as much as adversarial training.

  • Adversarial examples generated to fool gradient-regularized models are more “universal;” they are more effective at fooling all models than examples from unregularized models.

  • Adversarial examples generated to fool gradient-regularized models are more interpretable to humans, and examples generated from iterative attacks quickly come to legitimately resemble their targets. This is not true for distillation or adversarial training.

The conclusion that we would like to reach is that gradient-regularized models are right for better reasons. Although they are not completely robust to attacks, their correct predictions and their mistakes are both easier to understand. To fully test this assertion, we would need to run a larger and more rigorous human subject evaluation that also tests adversarial training and other attacks beyond the JSMA, FGSM, and TGSM.

Connecting what we have done back to the general idea of explanation regularization, we saw in Equation 7 that we could interpret our defense as a quadratic penalty on our CNN’s saliency map. Imposing this penalty had both quantitative and qualitative effects; our gradients became smaller but also smoother with fewer high-frequency artifacts. Since gradient saliency maps are just normals to the model’s decision surface, these changes suggest a qualitative difference in the “reasons” behind our model’s predictions. Many techniques for generating smooth, simple saliency maps for CNNs not based on raw gradients have been shown to vary under meaningless transformations of the model (Kindermans et al., 2017) or, more damningly, to remain invariant under extremely meaningful ones (Adebayo et al., 2018) – which suggests that many of these methods either oversimplify or aren’t faithful to the models they are explaining. Our approach in this chapter was, rather than simplifying our explanations of fixed models, to optimize our models to have simpler explanations. Their increased robustness can be thought of as a useful side effect.
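To make the idea of a quadratic penalty on the saliency map concrete, the following minimal sketch adds a squared L2 input-gradient penalty to the cross-entropy loss of a logistic regression model, for which the input gradient has the closed form (p - y) w. The model, the function names, and the strength `lam` are illustrative assumptions; this is not the CNN setup used in the experiments.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def input_gradient(w, b, x, y):
    """Gradient of the cross-entropy loss with respect to the input x.

    For logistic regression with p = sigmoid(w.x + b), this is (p - y) * w.
    """
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    return [(p - y) * wi for wi in w]

def regularized_loss(w, b, x, y, lam):
    """Cross-entropy plus lam times the squared L2 norm of the input gradient."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    cross_entropy = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    grad = input_gradient(w, b, x, y)
    penalty = sum(g * g for g in grad)  # quadratic penalty on the saliency map
    return cross_entropy + lam * penalty

# Hypothetical weights and a single labeled input.
w, b = [2.0, -1.0], 0.0
x, y = [0.5, 1.0], 1
print(regularized_loss(w, b, x, y, lam=0.0))  # cross-entropy only
print(regularized_loss(w, b, x, y, lam=1.0))  # cross-entropy plus gradient penalty
```

For a deep network the input gradient has no closed form, so in practice the penalty would be computed with automatic differentiation and differentiated a second time during training.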

Although the problem of adversarial robustness in deep neural networks is still very much an open one, these results may suggest a deeper connection between it and interpretability. No matter what method proves most effective in the general case, we suspect that any progress towards ensuring either interpretability or adversarial robustness in deep neural networks will likely represent progress towards both.

14 Alternative Input Gradient Penalties

Before we leave input gradients behind altogether, it is worth considering what else we can do with them besides simple L2 regularization.

14.1 L1 Regularization

In the previous chapter, we saw that penalizing the L2 norm of our model’s input gradients encouraged gradient interpretability and prediction robustness to adversarial examples, and drew an analogy to Ridge regression. One natural question to ask is how penalizing the L1 norm instead would compare, which we could understand as a form of local linear LASSO.

For a discussion of this question with application to sepsis treatment, we refer the reader to Ross et al. (2017a), which includes a case-study showing how L1 gradient regularization can help us obtain mortality risk models that are locally sparse and more consistent with clinical knowledge.

On image datasets (where input features are not individually meaningful), we do find that L1 gradient regularization is effective in defending against adversarial examples, perhaps more so than L2 regularization. To that end, in Figure 23 we present results for VGG-16 models on CIFAR-10, which bode favorably for L1 regularization against both white- and black-box attacks. However, although the gradients of these models change qualitatively compared to normal models, they are not significantly sparser than gradients of models trained with L2 gradient regularization. These results suggest that sparsity with respect to input features may not be a fully achievable or desirable objective for complex image classification tasks.
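As a small illustration of why an L1 penalty is associated with sparsity, the sketch below compares the two penalties on a dense and a sparse gradient vector that have the same squared L2 norm. The vectors and function names are hypothetical, and the L2 penalty is left unsquare-rooted to match the convention described for Figure 23.

```python
def l1_penalty(grad):
    """Sum of absolute input-gradient components."""
    return sum(abs(g) for g in grad)

def l2_penalty(grad):
    """Sum of squared input-gradient components (no square root)."""
    return sum(g * g for g in grad)

dense = [0.5, 0.5, 0.5, 0.5]    # sensitivity spread over all features
sparse = [1.0, 0.0, 0.0, 0.0]   # sensitivity concentrated in one feature

print(l2_penalty(dense), l2_penalty(sparse))  # the L2 penalty cannot tell them apart
print(l1_penalty(dense), l1_penalty(sparse))  # the L1 penalty prefers the sparse one
```

The L2 penalty treats both gradients identically, while the L1 penalty is twice as large for the dense one, which is the sense in which L1 regularization acts like a local linear LASSO.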

Figure 23: Left: Accuracy loss on CIFAR-10 FGSM examples (px) for VGG models trained with varying levels of L1 gradient regularization. Diagonals measure white-box vulnerability and off-diagonals measure transferability. Right: L1 vs. L2 gradient regularization on VGG as a defense against white-box FGSM examples, 2px perturbation. The value of the penalty strength λ is multiplied by 100 for the L1 regularized network to equalize penalty magnitudes (since we do not take the square root of the L2 penalty). Compared to L2, L1 gradient regularized models tend to be more robust to attacks like the FGSM, and their adversarial examples tend to be less transferable.

14.2 Higher-Order Derivatives

Bishop (1993) introduced the idea of limiting the curvature of the function learned by a neural network by imposing an L2 penalty on the network’s second input derivatives. They note, however, that evaluating these second derivatives increases the computational complexity of training by a factor equal to the number of input dimensions. This scaling behavior poses major practical problems for datasets like ImageNet, whose inputs are over 150,000-dimensional.

Rifai et al. (2011) develop a scalable workaround by estimating the squared Frobenius norm of the input Hessian as $\frac{1}{\sigma^2}\,\mathbb{E}_{\epsilon}\left[\lVert \nabla_x f(x) - \nabla_x f(x+\epsilon) \rVert_2^2\right]$ for $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, which converges to the true value as $\sigma \to 0$. They then train autoencoders whose exact gradient and approximate Hessian norms are both L2-penalized, and find that the unsupervised representations they learn are more useful for downstream classification tasks.
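A stochastic Hessian-norm estimate of this kind can be sketched as follows. The code assumes one particular form of the estimator: it averages the squared difference between the input gradient at x and at a Gaussian-perturbed copy of x, divided by the noise variance. The toy quadratic function, the function names, and the sample count are assumptions chosen so that the true Hessian is known.

```python
import random

def grad_f(x, A):
    """Gradient of f(x) = 0.5 * x^T A x for a symmetric matrix A, i.e. A x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]

def estimate_hessian_sq_norm(grad, x, sigma, n_samples, rng):
    """Monte Carlo estimate of the squared Frobenius norm of the Hessian.

    Averages ||grad(x) - grad(x + eps)||^2 / sigma^2 over eps ~ N(0, sigma^2 I).
    Only gradient evaluations are needed; the Hessian is never formed.
    """
    g0 = grad(x)
    total = 0.0
    for _ in range(n_samples):
        x_pert = [xi + rng.gauss(0.0, sigma) for xi in x]
        g1 = grad(x_pert)
        total += sum((a - b) ** 2 for a, b in zip(g0, g1))
    return total / (n_samples * sigma ** 2)

A = [[2.0, 1.0], [1.0, 3.0]]                       # Hessian of the toy quadratic
true_value = sum(a * a for row in A for a in row)  # squared Frobenius norm of A
rng = random.Random(0)
estimate = estimate_hessian_sq_norm(
    lambda x: grad_f(x, A), [0.3, -0.7], sigma=0.01, n_samples=20000, rng=rng
)
print(true_value, estimate)
```

Each sample costs one extra gradient evaluation, so the cost depends on the number of samples and not on the number of input dimensions.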

Czarnecki et al. (2017) also regularize using estimates of higher-order derivatives.

Hessian regularization may be desirable for adversarial robustness and interpretability as well. The results in Figure 24 suggest that exact Hessian regularization for an MLP on a simple 2D problem encourages the model to learn flatter and wider decision boundaries than gradient regularization, which could be useful for interpretability and robustness. Hessian regularization also appears to behave more sensibly even when the penalty term is much larger than the cross entropy. By contrast, in this regime, gradient regularization starts pathologically seeking areas of the input space (usually near the edges of the training distribution) where it can set gradients to 0.

Figure 24: Gradient regularization (left) vs. Hessian regularization (right). Purple line indicates the true decision boundary; other lines indicate level sets of a 10-hidden unit MLP’s predicted log-odds from -5 to 5 by increments of 2.5, with the model’s decision boundary in green. Hessian regularization can make decision boundaries wider and flatter without triggering pathological cases.

15 Heftier Surrogates

While input gradient-based methods are appealing because of their close relationship to the shape and curvature of differentiable models’ decision surfaces, they are limited by their locality and humans’ inability to express abstract desiderata in terms of input features. This second limitation in particular prevents us from optimizing for the kind of simplicity or diversity humans find intuitive. Therefore, in the next sections we explore ways of training models using more complex forms of explanation.

One common way of explaining complicated models like neural networks is by distilling them into surrogate models; decision trees are a particularly popular choice (Craven and Shavlik, 1996). However, these decision trees must sometimes be quite deep in order to accurately explain the associated networks, which defeats the purpose of making predictions interpretable. To address this problem, Wu et al. (2017) optimize the underlying neural networks to be accurately approximatable by shallow decision trees. Performing such an optimization is difficult because the process of distilling a network into a decision tree cannot be expressed analytically, much less differentiated. However, they approximate it by training a second neural network to predict the depth of the decision tree that would result from the first neural network’s parameters. They then use this learned function as a differentiable surrogate of the true approximating decision tree depth. Crucially, they find a depth regime where their networks can outperform decision trees while remaining explainable by them. Although they only try to minimize the approximating decision tree depth, in principle one could train the second network to estimate other characteristics of the decision tree related to simplicity or consistency with domain knowledge (and optimize the main network accordingly).

16 Examples and Exemplars

Another popular way of explaining predictions is with inputs themselves. k-Nearest Neighbors (kNN) algorithms are easy to understand since one can simply present the neighbors, and techniques have recently been proposed to perform kNN using distance metrics derived from pretrained neural networks (Papernot and McDaniel, 2018). More general methods involve sparse graph flows between labeled and unlabeled inputs (Rustamov and Klosowski, 2017) or optimization to find small sets of prototypical inputs that can be used for cluster characterization or classification (Kim et al., 2014), even within neural networks (Li et al., 2017). There has also been recent work on determining which points would most affect a prediction if removed from the training set (Koh and Liang, 2017). These approaches have both advantages and disadvantages. Justifying predictions based on input similarity and difference can seem quite natural, though it can also be confusing or misleading when the metric used to quantify distance between points does not correspond to human intuition. Influence functions shed light on model sensitivities that are otherwise very hard to detect, but they are also very sensitive to outliers, leading to sometimes inscrutable explanations.

However, it seems straightforward at least in principle to implement example-based explanation regularization. For example, we could train neural networks with annotations indicating that certain pairs of examples should be similar or dissimilar, and penalize the model when their intermediate representations are relatively distant or close (which might require altering minibatch sampling to keep paired examples together if annotations are sparse). Although influence functions may be too computationally expensive to incorporate into the loss functions of large networks, it seems useful in principle to specify that certain examples should be particularly representative or influential in deciding how to classify others.
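One possible form of such a pairwise penalty is sketched below. It assumes a squared-distance term for pairs annotated as similar and a squared hinge with a `margin` for pairs annotated as dissimilar; these choices, along with the example representations and names, are assumptions for illustration and are not a method evaluated in this work.

```python
import math

def distance(u, v):
    """Euclidean distance between two representation vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def pair_penalty(reps, similar_pairs, dissimilar_pairs, margin=1.0):
    """Penalty that grows as pairwise annotations are violated.

    reps: dict mapping example id -> intermediate representation
    similar_pairs: pairs whose representations should be close
    dissimilar_pairs: pairs that should be at least `margin` apart
    """
    penalty = 0.0
    for i, j in similar_pairs:
        penalty += distance(reps[i], reps[j]) ** 2
    for i, j in dissimilar_pairs:
        penalty += max(0.0, margin - distance(reps[i], reps[j])) ** 2
    return penalty

# Hypothetical intermediate representations of three examples.
reps = {"a": [0.0, 0.0], "b": [0.0, 0.5], "c": [3.0, 4.0]}

# Annotations that agree with the geometry give a small penalty ...
agree = pair_penalty(reps, similar_pairs=[("a", "b")], dissimilar_pairs=[("a", "c")])
# ... and annotations that contradict it give a large one.
conflict = pair_penalty(reps, similar_pairs=[("a", "c")], dissimilar_pairs=[("a", "b")])
print(agree, conflict)
```

Added to the prediction loss, a term like this would only receive gradient signal from minibatches that contain both members of an annotated pair, which is why paired sampling may be needed when annotations are sparse.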

17 Emergent Abstractions

Stepping back, the level of abstraction at which we communicate the reason behind a decision significantly affects its utility, as Keil (2006) notes:

Explanations… suffer if presented at the wrong level of detail. Thus, if asked why John got on the train from New Haven to New York, a good explanation might be that he had tickets for a Broadway show. An accurate but poor explanation at too low a level might say that he got on the train because he moved his right foot from the platform to the train and then followed with his left foot. An accurate but poor explanation at too high a level might say that he got on the train because he believed that the train would take him to New York from New Haven.

The explanations we have considered so far have been in terms of input features, entire inputs, or simple surrogates. However, sometimes humans seek to know the reasons behind predictions at levels of abstraction these forms cannot capture. If we really want to create interpretable interfaces for training and explaining machine learning models, humans and models will need to speak a common language that permits abstraction.

This may seem like a daunting task, but there has been important recent progress in interpreting neural networks in terms of abstractions that emerge during training. Bau et al. (2017) introduce a densely labeled image dataset. They train convolutional neural networks on a top-level classification task, but also include lower-level sublabels that indicate other features in the image. They measure the extent to which different intermediate nodes in their top-level label classifiers serve as exclusive “detectors” for particular sublabels, and compare the extent to which different networks learn different numbers of exclusive detectors. They also categorize their sublabels and look at differences in which kinds of sublabels each network learns to detect (and when these detectors emerge during training).

Kim et al. (2017) provide a method of testing networks’ sensitivity to concepts as defined by user-provided sets of examples. Concretely, they train a simple linear classifier at each layer to distinguish between examples in the concept set and a negative set. They reinterpret the weights of this linear classifier as a “concept activation vector,” and take directional derivatives of the class logits with respect to these concept activations. Repeated across the full dataset for many different concepts, this procedure outputs a set of concept sensitivity weights for each prediction, which can be used for explanation or even image retrieval.
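The two steps of this procedure (finding a concept direction in activation space, then taking a directional derivative of a class logit along it) can be sketched on toy data. To keep the example short, a normalized difference of class means stands in for the weights of a trained linear classifier, and the derivative is taken by finite differences; both are simplifying assumptions, and the activations and logit function are hypothetical.

```python
import math

def mean_vector(vectors):
    n = len(vectors)
    return [sum(v[k] for v in vectors) / n for k in range(len(vectors[0]))]

def concept_direction(concept_acts, negative_acts):
    """Unit vector pointing from the negative set toward the concept set.

    Assumption: a difference of means replaces a trained linear classifier.
    """
    mu_c, mu_n = mean_vector(concept_acts), mean_vector(negative_acts)
    diff = [c - n for c, n in zip(mu_c, mu_n)]
    norm = math.sqrt(sum(d * d for d in diff))
    return [d / norm for d in diff]

def directional_derivative(logit_fn, acts, direction, eps=1e-5):
    """Central finite-difference derivative of logit_fn along `direction`."""
    plus = [a + eps * d for a, d in zip(acts, direction)]
    minus = [a - eps * d for a, d in zip(acts, direction)]
    return (logit_fn(plus) - logit_fn(minus)) / (2 * eps)

# Hypothetical layer activations for concept and non-concept examples.
concept_acts = [[2.0, 0.1], [2.2, -0.1]]
negative_acts = [[0.0, 0.1], [0.2, -0.1]]
cav = concept_direction(concept_acts, negative_acts)

# Hypothetical class logit as a function of the layer's activations.
def logit(h):
    return 3.0 * h[0] - 1.0 * h[1]

sensitivity = directional_derivative(logit, [1.0, 1.0], cav)
print(cav, sensitivity)
```

A positive sensitivity means that moving the activations toward the concept raises the class logit for that input.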

The previous two methods require manual human selection of images corresponding to concepts, and they do not guarantee meaningful correspondence between these concepts and what the network has learned. Feature visualization (Olah et al., 2017) takes a different approach and attempts to understand what the network has learned on its own terms. In particular, it tries to explain what (groups of) neuron(s) learn by optimizing images to maximize (or minimize) their activations. It can also optimize sets of images to jointly maximize activations while encouraging diversity. This process can be useful for obtaining an intuitive sense of (some of) what the model has learned, especially if the neurons being explained are class logits. However, it also leads to an information overload, since modern networks contain millions of neurons and an effectively infinite number of ways to group them. To that end, Olah et al. (2018) use non-negative matrix factorization (NMF) to learn a small number of groups of neurons whose feature visualizations best summarize the entire set. Feature visualizations of neuron groups obtained by NMF tend to correspond more cleanly to human-interpretable concepts, though again there is no guarantee this will occur. Olah et al. (2018) also suggest that incorporating human feedback into this process could lead to a method to train models to make decisions “for the right reasons.”

The above cases either take human concepts and try to map them to network representations or take network “concepts” and try to visualize them so humans can map them to their own concepts. But they do not actually try to align network representations with human concepts. However, there has been significant recent interest in training models to learn disentangled representations (Chen et al., 2016; Higgins et al., 2016; Siddharth et al., 2017). Disentangled representations are often described as separating out latent factors that concisely characterize important aspects of the inputs but which cannot be easily expressed in terms of their component features. Generally, disentangled representations tend to be much easier to relate to human-intuitive concepts than what models learn when only trained to minimize reconstruction or prediction error.

Figure 25: Accuracies of a normal model and two models trained using find-another-explanation in a disentangled latent space (right) on a toy image dataset that confounds background color and square size in training (left) but decouples them in test. Performing find-another-explanation in a latent space allows us to learn models that make predictions for conceptually different reasons, which is reflected in their complementary accuracies on each version of the test set.

These advances in bridging human and neural representations could have major payoffs in terms of interpreting models or optimizing them to make predictions for specific reasons. Suppose we are interested in testing a classifier’s sensitivity to an abstract concept entangled with our input data. If we have an autoencoder whose representation of the input disentangles the concept into a small set of latent factors, then for a specific input, we can encode it, decode it, and pass the decoded input through the classifier, taking the gradient of the network’s output with respect to the latent factors associated with the concept. If we fix the autoencoder weights but not the classifier weights, we can use this differentiable concept sensitivity score to apply the “right for the right reasons” technique introduced earlier in this thesis to encourage the classifier to be sensitive or insensitive to the concept.
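The concept sensitivity score described here can be sketched for a toy case in which a fixed linear decoder stands in for the trained autoencoder and a logistic regression stands in for the classifier, so that the gradient with respect to the latent factors follows from the chain rule. All names, the decoder matrix `D`, and the mask convention are assumptions made for illustration.

```python
import math

def decode(z, D):
    """Fixed linear decoder x = D z (stand-in for a trained autoencoder)."""
    return [sum(d_ij * z_j for d_ij, z_j in zip(row, z)) for row in D]

def predict(x, w, b):
    """Logistic regression classifier on the decoded input."""
    return 1.0 / (1.0 + math.exp(-(sum(wi * xi for wi, xi in zip(w, x)) + b)))

def latent_sensitivity(z, D, w, b):
    """Gradient of the classifier output with respect to each latent factor.

    Chain rule: dp/dz_j = p * (1 - p) * sum_i w_i * D[i][j].
    """
    p = predict(decode(z, D), w, b)
    return [p * (1 - p) * sum(w[i] * D[i][j] for i in range(len(w)))
            for j in range(len(z))]

def concept_penalty(z, D, w, b, concept_mask):
    """Squared sensitivity to the latent factors flagged in concept_mask."""
    grads = latent_sensitivity(z, D, w, b)
    return sum(m * g * g for m, g in zip(concept_mask, grads))

D = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 2 latent factors -> 3 input features
z = [0.0, 0.0]
mask = [0, 1]                              # discourage reliance on latent factor 1

uses_concept = concept_penalty(z, D, [1.0, 2.0, 0.0], 0.0, mask)
ignores_concept = concept_penalty(z, D, [1.0, 0.0, 0.0], 0.0, mask)
print(uses_concept, ignores_concept)
```

Because the decoder is held fixed, minimizing this penalty only changes the classifier weights, steering them away from directions in input space that the masked latent factor controls.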

We present a preliminary proof of concept of this idea in Figure 25. In this experiment, we construct a toy dataset of images of white squares with four true latent factors of variation: the size of the square, its x and y position, and the background color of the image. In training, background color and square size are confounded; images either have dark backgrounds and small squares or light backgrounds and large squares (and either one can be used to predict the label). However, we create two versions of the test set where these latent factors are decoupled (and only one predicts the label). This is analogous to the parable in our introduction with squares representing tanks and background colors representing light. When we train a one-hidden layer MLP normally, it learns to implicitly use both factors, and obtains suboptimal accuracies of about 75% on each test set. To circumvent this issue, we first train a convolutional autoencoder that disentangles square size from background color (which we do with supervision here, but in principle this can be unsupervised) and then prepend the autoencoder to our MLP with fixed weights. We then simultaneously train two instantiations of this network with the find-another-explanation penalty we introduced in Section 8. These two networks learn to perform nearly perfectly on one test set and do no better than random guessing on the other, which suggests they are making predictions for different conceptual reasons. Obtaining these networks would have been very difficult using only gradient penalties in the input space.

18 Interpretability Interfaces

Olah et al. (2018) describe a space of “interpretability interfaces” and introduce a formal grammar for expressing explanations of neural networks (and a systematic way of exploring designs). They visualize this design space in a grid of relationships between different “substrates” of the design, which include groups of neurons, dataset examples, and model parameters – the latter of which presents an opportunity “to consider interfaces for taking action in neural networks.” If human-defined concepts, disentangled representations, or other forms of explanation are included as additional substrates, one can start to imagine a very general framework for expressing priors or constraints on relationships between them. These would be equivalent to optimizing models to make predictions for specific reasons.

Figure 26: Schematic diagram of an interpretability interface.

How would humans actually express these kinds of objectives? One interface worth emulating could be that introduced by recent but popular libraries for weak supervision (Ratner et al., 2017) or probabilistic soft logic (Bach et al., 2017), which is related to the well-studied topic of fuzzy logic, a method noted for its compatibility with human reasoning (Zadeh, 1997). In these frameworks, users can specify “soft” logical rules for labeling datasets or constraining relationships between atoms (or substrates) of a system. Though users can sometimes specify that certain rules are inviolable or highly-weighted, in general these systems assume that rules are not always correct and attempt to infer weights for each. While these inference problems are nontrivial, and in general there may be complex, structured interactions between rules that are difficult to capture, the interface these systems expose to users is expressive and potentially worth emulating in an interpretability interface. For example, we could imagine writing soft rules relating:

  • dataset examples to each other (e.g. these examples should be conceptually similar with respect to a task)

  • dataset examples to concepts (e.g. these are examples of a concept)

  • features to concepts (e.g. this set of features is related to this concept, this other set is not; in this specific case, these features contribute positively)

  • concepts to predictions (e.g. the presence of this concept makes this prediction more or less likely, except when this other concept is present)

These rules could be “compiled” into additional energy terms in the model’s loss function, possibly with thresholding if we expect them to be incorrect some percentage of the time (though rules defined for specific examples may be more reliable). We present a schematic diagram of how a system like this might work in Figure 26.
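As a rough sketch of what "compiling" soft rules into energy terms might look like, the code below turns each rule into a weighted hinge on its average violation, with a tolerance for rules that are expected to be wrong some fraction of the time. The rule representation, the weights, and the tolerance mechanism are all hypothetical design choices, not part of any existing library.

```python
def rule_energy(violations, weight, tolerance=0.0):
    """Energy contributed by one soft rule.

    violations: per-example violation scores in [0, 1]
    weight: how strongly the rule is enforced
    tolerance: average violation accepted before any penalty applies
    """
    mean_violation = sum(violations) / len(violations)
    return weight * max(0.0, mean_violation - tolerance)

def total_loss(prediction_loss, rules):
    """Add the energy of each (violations, weight, tolerance) rule to the loss."""
    return prediction_loss + sum(rule_energy(v, w, t) for v, w, t in rules)

rules = [
    ([0.0, 0.5, 1.0, 0.5], 2.0, 0.25),  # noisy dataset-wide rule with some slack
    ([0.0, 0.0], 10.0, 0.0),            # reliable example-specific rule, satisfied
]
loss = total_loss(0.5, rules)
print(loss)
```

In a real system the violation scores would be differentiable functions of the model's outputs or explanations, so that the added energy terms could be minimized jointly with the prediction loss.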

Such a system would strongly depend on being able to define rules in terms of abstract concepts, but such rules might not be enforceable until the model has differentiable, stable representations of them. However, one could imagine pre-learning static, disentangled concept representations that could be related back to input features. If 1:1 mappings between human concepts and latent representations do not emerge naturally, even allowing for hierarchical relationships (Esmaeili et al., 2018), steps could be taken to optimize model representations to better match human understanding (e.g. using partial supervision) or to help humans better understand model representations (e.g. using feature visualization). This process of reaching user-model intersubjectivity might require multiple stages of identification and refinement, but seems possible in principle. And perhaps arriving at a shared conceptual framework for understanding a problem is where the work of teaching and learning ought to lie, regardless of whether the teachers and learners are human.

19 Discussion

In this chapter, we discussed a number of strategies for explanation regularization beyond the methods we used in the previous chapters. We described simple extensions of gradient-based methods (imposing L1 and Hessian penalties), strategies in terms of interpretable surrogates (regularizing distilled decision trees, nearest neighbors, and exemplars), and strategies in terms of concepts (concept activation vectors, disentangled representations, and feature visualization). We then combined many of these strategies into a design for an “interpretability interface” that could be used to simultaneously improve neural network interpretability and incorporate domain knowledge.

One limitation of this discussion is that we only considered classification models and traditional ways of explaining their predictions. However, there is a much larger literature on alternative forms of explanation and prediction like intuitive theories (Gerstenberg and Tenenbaum, 2017) or causal inference (Pearl, 2010) that is highly relevant, especially if we want to apply these techniques to problems like sequential decisionmaking. We started this thesis by making a point that was “easiest to express with a story;” even with arbitrarily human-friendly compositional abstraction (Schulz et al., 2017), flat sets of concepts may never be sufficient in cases where users think in terms of narratives (Abell, 2004).

However, despite these limitations, we think the works we have outlined in this chapter have started to map a rich design space for interpreting and training machine learning models with more than just $x$es and $y$s.