How do Decisions Emerge across Layers in Neural Models? Interpretation with Differentiable Masking

04/30/2020 ∙ by Nicola De Cao, et al. ∙ University of Amsterdam

Attribution methods assess the contribution of inputs (e.g., words) to the model prediction. One way to do so is erasure: a subset of inputs is considered irrelevant if it can be removed without affecting the model prediction. Despite its conceptual simplicity, erasure is not commonly used in practice. First, the objective is generally intractable, and approximate search or leave-one-out estimates are typically used instead; both approximations may be inaccurate and remain very expensive with modern deep (e.g., BERT-based) NLP models. Second, the method is susceptible to the hindsight bias: the fact that a token can be dropped does not mean that the model ‘knows’ it can be dropped. The resulting pruning is over-aggressive and does not reflect how the model arrives at the prediction. To deal with these two challenges, we introduce Differentiable Masking. DiffMask relies on learning sparse stochastic gates (i.e., masks) to completely mask-out subsets of the input while maintaining end-to-end differentiability. The decision to include or disregard an input token is made with a simple linear model based on intermediate hidden layers of the analyzed model. First, this makes the approach efficient at test time because we predict rather than search. Second, as with probing classifiers, this reveals what the network ‘knows’ at the corresponding layers. This lets us not only plot attribution heatmaps but also analyze how decisions are formed across network layers. We use DiffMask to study BERT models on sentiment classification and question answering.





1 Introduction

(a) Integrated Gradient (Sundararajan et al., 2017).
(b) Restricting the Flow (Schulz et al., 2020)
(c) NLP explainer (Guan et al., 2019).
(d) Our DiffMask.
(e) Erasure exact search optima.
(f) Our DiffMask non-amortized.
Figure 2: Question Answering: Predictions of previous perturbation-based methods, (b) and (c), are misleading as they attribute the prediction mostly to the answer span itself (underlined). Our method (d) reveals that the model pays attention to the question type (e.g., the where token) as well as named entities and predicate ‘practice’. Predictions of the path-based methods (a) are more spread-out. Exact search leads to pathological attributions (e) and the same happens for our tractable but approximate search (f) when no amortization is used.

Deep neural networks (DNNs) have become standard tools in NLP, demonstrating impressive improvements over traditional approaches on many tasks (Goldberg, 2017). Their success is due to their ability to capture complex non-linear relations and induce powerful features. Unfortunately, their power and flexibility come at the expense of interpretability. This lack of interpretability can prevent users from trusting model predictions (Kim, 2015; Ribeiro et al., 2016) and makes it hard to detect model or data deficiencies (Gururangan et al., 2018; Kaushik and Lipton, 2018) or to verify that a model is fair and does not exhibit harmful biases (Sun et al., 2019; Holstein et al., 2019).

These challenges have motivated a massive amount of work on interpretability, both in NLP and generally in machine learning; see Belinkov and Glass (2019) and Jacovi and Goldberg (2020) for reviews. In this work, we study post hoc interpretability where the goal is to explain the prediction of a trained model and to reveal how the model arrives at the decision. This goal is usually approached with attribution methods (Bach et al., 2015; Shrikumar et al., 2017; Sundararajan et al., 2017), which explain the behavior of a model by assigning relevance to inputs.

One way to perform attribution is to use erasure, where a subset of features (usually input tokens) is considered irrelevant if it can be removed without affecting the model prediction (Li et al., 2016; Feng et al., 2018). The advantage of erasure is that it is conceptually simple and optimizes a well-defined objective. This contrasts with most other attribution methods, which rely on heuristic rules to define feature salience, for example, attention-based attribution (Rocktäschel et al., 2015; Serrano and Smith, 2019; Vashishth et al., 2019) or back-propagation methods (Bach et al., 2015; Shrikumar et al., 2017; Sundararajan et al., 2017). These approaches have received much scrutiny in recent years (Nie et al., 2018; Sixt et al., 2019; Jain and Wallace, 2019), as they cannot guarantee that the network is ignoring the low-scored features. They are often motivated as approximations of erasure (Baehrens et al., 2010; Simonyan et al., 2013; Feng et al., 2018) and sometimes even evaluated using erasure (e.g., Serrano and Smith, 2019; Jain and Wallace, 2019).

(a) Erasure search.
(b) Schulz et al. (2020).
(c) Sundararajan et al. (2017).
(d) Guan et al. (2019)
(e) Our DiffMask conditioned on embedding layer (left) and hidden states (right).
Figure 3: Input attributions of several methods on a toy task: given a sequence $x$ of digits and a query $\langle n, m \rangle$ ( and in this example) of two digits, determine whether there are more $n$ than $m$ in $x$. The query and input embeddings are concatenated, fed to a feed-forward NN, and then to a single-layer GRU. Attributions are normalized to , for visualization.

Despite its conceptual simplicity, subset erasure is not commonly used in practice. First, it is generally intractable, and beam search (Feng et al., 2018) or leave-one-out estimates (Zintgraf et al., 2017) are typically used instead. These approximations may be inaccurate. For example, leave-one-out can underestimate the contribution of features due to saturation (Shrikumar et al., 2017). More importantly, even these approximations remain very expensive with modern deep (e.g., BERT-based) models, as they require multiple computation passes through the model. Second, the method is susceptible to hindsight bias: the fact that a feature can be dropped does not mean that the model ‘knows’ that it can be dropped, nor that the feature is not used by the model when processing the example. This results in over-aggressive pruning that does not reflect what information the model uses to arrive at the decision. The issue is pronounced in NLP tasks (see Figure 1(e) and Feng et al. (2018)), though it is easier to see on the following artificial example (Figure 2(a)). A model is asked to predict whether one query digit occurs more often than another in the sequence. Erasure attributes the prediction to a single digit, as this reduced example yields the same decision as the original one. However, this does not reveal what the model was actually relying on: in fact, it has counted both query digits, as otherwise it would not have achieved a perfect score on the test set.

We propose a new method, Differentiable Masking (DiffMask), which overcomes the aforementioned limitations and results in attributions that are more informative and help us understand how the model arrives at the prediction. DiffMask relies on learning sparse stochastic gates (i.e., masks), guaranteeing that the information from the masked-out inputs does not get propagated while maintaining end-to-end differentiability without having to resort to REINFORCE (Williams, 1992; Li et al., 2016). The decision to include or disregard an input token is made with a simple linear model based on intermediate hidden layers of the analyzed model (see Figure 1). First, this amortization circumvents the need for combinatorial search, making the approach efficient at test time. Second, as with probing classifiers (Adi et al., 2017; Belinkov and Glass, 2019), this reveals whether the network ‘knows’ at the corresponding layer what input tokens can be masked.

The amortization lets us not only plot attribution heatmaps, as in Figure 1(d), but also analyze how decisions are formed across network layers. In our artificial example, we can see that in the bottom embedding layer the model cannot discard any tokens, as it does not ‘know’ which digits need to be counted. In the second layer, it ‘knows’ which two digits these are, so the rest gets discarded. On a question answering task (see Figure 8(a)), where we use a -layer model, it takes layers for the model to ‘realize’ that ‘Santa Clara Marriott’ and ‘Stanford University’ are not relevant to the question and discard them. In that way, we go beyond attribution, characterizing instead how the decision is being incrementally formed by the model.

We also adapt our method to measuring the importance of intermediate states rather than inputs (see Figure 4). This, as we discuss later, lets us analyze which states in every layer store information crucial for making predictions and gives us extra insights about the information flow.


Our contributions are as follows:

  • we analyze limitations of existing attribution-based methods, especially erasure and its approximations;

  • we propose a novel approach, DiffMask, addressing the shortcomings and revealing how a decision is formed across network layers;

  • we use DiffMask to analyse BERT models fine-tuned on sentiment classification and question answering.

2 Related Work

While we motivated our approach through its relation to erasure, an alternative way of looking at our approach is considering it as a perturbation-based method. This recently introduced class of attribution methods (Ying et al., 2019; Guan et al., 2019; Schulz et al., 2020; Taghanaki et al., 2019) injects noise instead of erasing input. These methods can be regarded as continuous relaxations of erasure, though they are typically motivated from an information-theoretic perspective. Previous approaches use continuous gates, which may be problematic when the magnitude of the input changes, or require making (Gaussian) assumptions about the input distribution. This means that information about the input can still leak to the predictor. These methods are also, like subset erasure, susceptible to hindsight bias. Our method uses mixed discrete-continuous gates, which can completely block the flow, and amortization to address both these issues. We compare to perturbation-based methods in our experiments.

Besides the back-propagation and attention-based methods discussed in the introduction, another class of interpretation methods (Murdoch and Szlam, 2017; Singh et al., 2018; Jin et al., 2020) builds on prior work in cooperative game theory (e.g., the Shapley value of Shapley (1953)). These methods are not trivial to apply to new architectures, as new architecture-specific decomposition rules need to be derived. Their hierarchical versions (e.g., Singh et al. (2018); Jin et al. (2020)) also make a strong assumption about the structure of interaction (e.g., forming a tree), which may affect their faithfulness.

There is a large body of literature on analyzing BERT and other Transformer-based models. For example, Tenney et al. (2019) and van Aken et al. (2019) considered probing BERT layers for a range of linguistic tasks, while Hao et al. (2019) analyzed the optimization surface. Rogers et al. (2020) provide a comprehensive overview of recent BERT analysis papers.

3 Method

We are interested in understanding how a model processes an input (e.g., a sentence) to produce an output (e.g., a vector of class probabilities). This is a complex and mostly opaque mapping realized by a stack of parameterized transformations of the input. Inspecting the parameters of the mapping or the intermediate representations themselves is anything but straightforward. We approach the challenge by answering two questions for every layer along the mapping from input to output, namely, what does the model know and where does it store information?

Concretely, for each hidden layer, we first want to understand what parts of the input are necessary to arrive at the prediction. We think of this as probing what that hidden layer ‘knows’ (e.g., in our toy task, see Figure 2(e), the hidden layer knows that the goal is to determine whether one query digit occurs more often than the other, as everything else can be disregarded to arrive at the correct prediction). We reveal this information by mapping the states in a hidden layer to a masked version of the input (that is, with some tokens discarded as irrelevant). This masked input is such that, should we feed the model with it, the output would not change. We aim to mask as much as possible, thus revealing from each layer’s perspective the minimum the model must read to preserve the information that layer contributes to the process of composing the original output. This offers a human-readable view of how the prediction is incrementally formed (e.g., in Figure 8(a) and 8(b) our method highlights that low layers predict that some tokens such as determiners or punctuation can be completely ignored while potential answer spans have to be kept; conversely, higher layers can make a more refined prediction to mask spans that do not contain the answer).

Second, we want to know where the model stores information that is necessary for the prediction. This gives us insights into its encoding process. Revealing this kind of information requires a minor modification to our framework: we need only change the objective of the probe. Instead of probing a hidden state for unnecessary inputs, we probe it for its relevance towards the original output. If a state can be masked out without impact on the output, it most likely does not store information that is important from that layer onward. See Figure 6(b) for an example in sentiment classification: towards the top of the stack, information is eventually stored in a single state.

3.1 Masking inputs

At a conceptual level, we may think of a neural network model as a stack of transformations of an input towards an output. We use $f_\theta(x)$ to denote this entire computation, where $x$ denotes an input (i.e., a sequence of embedded tokens $x_1, \dots, x_n$), and $\theta$ denotes the model’s trainable parameters, which are given and fixed for the purpose of interpretation. DiffMask builds on the notion of erasure, whereby tokens that can be dropped without affecting the output are regarded as unnecessary. However, rather than searching through the space of alternative inputs for those that preserve the output, we probe hidden states, at different layers, for which tokens can be masked out. On the one hand, this addresses erasure’s poor scalability (i.e., the space of masked inputs grows exponentially in sequence length). On the other hand, probing hidden states for what they know overcomes erasure’s hindsight bias.

In DiffMask, for an input $x$ (also denoted $h^{(0)}$), we do a forward pass with the trained model obtaining the output $y = f_\theta(x)$ as well as all of the model’s intermediate hidden states $h^{(1)}, \dots, h^{(L)}$. We then probe the model, at different hidden layers, for unnecessary inputs using a shallow prediction model. This interpreter model takes hidden states up to a certain layer $\ell$ and outputs a binary mask $z$ indicating which input tokens are necessary and which can be disregarded. To assess whether the masked input $\hat{x}$ is sufficient, we re-feed the model with it and compute the output $\hat{y} = f_\theta(\hat{x})$. To avoid making changes to the computation graph of $f_\theta$ and to avoid changing the distribution of input length, rather than dropping tokens, we ‘mask’ them. Masking as multiplication by zero, however, makes a strong assumption about the geometry of the feature space: it assumes that the zero vector bears no information. Instead, we replace some of the inputs by a learned baseline vector $b$, i.e.,

$$\hat{x}_i = z_i \cdot x_i + (1 - z_i) \cdot b. \tag{1}$$

Concretely, the interpreter model $g_\phi$ consists of classifiers, the $\ell$th of which conditions on the stack of hidden states $h^{(0)}, \dots, h^{(\ell)}$ to predict binary ‘votes’ $v^{(\ell)}$ towards keeping or masking input tokens. For a given depth $\ell$, the interpreter decides to mask out $x_i$ as soon as $v_i^{(k)} = 0$ for some $k \le \ell$. That is, in order to deem $x_i$ unnecessary, it is sufficient to do so based on any subset of hidden states up until $h^{(\ell)}$. We realize this by aggregating binary votes via a product

$$z_i = \prod_{k=0}^{\ell} v_i^{(k)}. \tag{2}$$

The precise parameterization of $g_\phi$ is discussed in Appendix A.
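The two operations described above, aggregating per-layer votes by a product and replacing masked tokens by a baseline vector, can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: the votes are given rather than predicted by a probe, the embeddings and the baseline are made-up numbers, and the function names `aggregate_votes` and `mask_inputs` are hypothetical and not taken from the authors' code.

```python
from math import prod

def aggregate_votes(votes_per_layer):
    """Combine binary votes from layers 0..l into one mask: z_i = prod_k v_i^(k)."""
    n = len(votes_per_layer[0])
    return [prod(layer[i] for layer in votes_per_layer) for i in range(n)]

def mask_inputs(x, z, b):
    """Replace masked token vectors by the baseline: x_hat_i = z_i * x_i + (1 - z_i) * b."""
    return [
        [z_i * x_ij + (1 - z_i) * b_j for x_ij, b_j in zip(x_i, b)]
        for x_i, z_i in zip(x, z)
    ]

# Three tokens with 2-dimensional embeddings (made-up numbers).
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [0.5, -0.5]                             # baseline vector (learned in DiffMask, fixed here)
votes = [[1, 1, 1], [1, 0, 1], [1, 1, 0]]   # votes from three layers, one entry per token
z = aggregate_votes(votes)                  # a single zero vote is enough to mask a token
x_hat = mask_inputs(x, z, b)
```

Because the votes are multiplied, a token stays visible only if every layer up to the chosen depth votes to keep it, which matches the "as soon as some vote is zero" rule in the text.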

Clearly, there is no direct supervision to estimate the parameters $\phi$ of the probe and the baseline $b$, thus we borrow erasure’s conceptual objective, namely, we train the probe to mask out as many input tokens as possible constrained to keeping $f_\theta(\hat{x}) \approx f_\theta(x)$. Where $f_\theta$ parameterizes a likelihood, we measure changes to the output in terms of Kullback-Leibler divergence. (Where $f_\theta$ does not admit a probabilistic interpretation, a distance can be used.) As constrained optimization is generally intractable, in Section 3.3 we resort to Lagrangian relaxation and stochastic gradient-based optimization. See Figure 1 for an overview of DiffMask.

Finally, deterministically predicting a binary mask calls for a discontinuous activation, such as a threshold function, for which gradients are not well defined. To enable gradient-based learning, we employ stochastic masks as well as a relaxation to binary variables (Section 3.3) that admits differentiable sampling while retaining sparsity, in particular, producing true zeros.

3.2 Masking hidden states

To reveal which hidden states store information necessary for realizing the prediction, we modify the probe slightly. Again, for a given depth $\ell$, we condition on the stack of states up until and including $h^{(\ell)}$, and again we have a classifier that predicts a mask $z^{(\ell)}$. This time, however, we use the mask to replace some of the states in $h^{(\ell)}$ by a layer-specific baseline vector $b^{(\ell)}$, i.e.,

$$\hat{h}^{(\ell)}_i = z^{(\ell)}_i \cdot h^{(\ell)}_i + (1 - z^{(\ell)}_i) \cdot b^{(\ell)}. \tag{3}$$

The resulting state $\hat{h}^{(\ell)}$ is used to re-compute every subsequent hidden state, $\hat{h}^{(\ell+1)}, \dots, \hat{h}^{(L)}$, as well as the output, which we denote by $\hat{y}$. To the extent $\hat{y}$ approximates the original output $y$ well, we deem the states masked by $z^{(\ell)}$ unnecessary.

Parameter estimation for the probe and baselines is done in a similar way as before, namely, we aim to mask out as many hidden states as possible constrained to keeping $\hat{y} \approx y$ as measured by a divergence of choice. See Figure 4 for an overview of this variant of DiffMask and Appendix A for the specification of $g_\phi$.
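As a rough illustration of this variant, the sketch below masks the states of one layer and re-runs only the layers above it. It is a toy under stated assumptions: the "model" is a hand-written two-layer stack over scalar states, the mask is chosen by hand rather than predicted, the baseline is fixed to zero, and the names `forward_from` and `mask_states` are hypothetical.

```python
def forward_from(layers, states, start):
    """Run layers[start:] on `states` and return the final output."""
    for layer in layers[start:]:
        states = layer(states)
    return states

def mask_states(h, z, b):
    """h_hat_i = z_i * h_i + (1 - z_i) * b, here for scalar states."""
    return [z_i * h_i + (1 - z_i) * b for h_i, z_i in zip(h, z)]

# Toy stack: layer 0 doubles every state, layer 1 reads only the first two states.
layers = [
    lambda s: [2.0 * v for v in s],
    lambda s: [s[0] + s[1]],
]

x = [1.0, 2.0, 3.0]
y = forward_from(layers, x, 0)                  # original output
h1 = layers[0](x)                               # hidden states after layer 0
h1_hat = mask_states(h1, z=[1, 1, 0], b=0.0)    # mask the third state only
y_hat = forward_from(layers, h1_hat, 1)         # re-compute from that layer onward
unchanged = (y_hat == y)                        # the third state was unnecessary
```

Masking the first state instead changes the output, so in this toy the first two states are the ones that store information needed from that layer onward.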

Figure 4: Overview of DiffMask to inspect the importance of hidden units. We take hidden states up to layer $\ell$ from a model (top) and feed them to a classifier $g_\phi$ that predicts a mask $z^{(\ell)}$. We use this to mask states of the $\ell$th layer and re-compute the forward pass from that point on (bottom). The classifier is trained to mask as much of $h^{(\ell)}$ as possible without changing the output (minimizing a divergence $D_\star[y \,\|\, \hat{y}]$).

3.3 Parameter estimation

In this section we describe how to estimate the parameters of the probe as well as the baselines. For ease of notation and without loss of generality, we focus on the variant of DiffMask where we mask inputs (Section 3.1). Inspired by erasure, for each input $x$, we would like to mask out as much of it as possible without changing the output of the model, that is, to find a mask $z$ with as many entries set to $0$ as possible and still have $\hat{y} = y$. Unlike erasure, rather than combinatorial search, we frame this as a learning problem. That is, we treat $z$ as a prediction realised by an interpreter network $g_\phi$ which is shared (i.e., amortized) across data points. (This is reminiscent of how inference networks amortize prediction of local variational factors in a variational auto-encoder.) This rather shallow network is trained to satisfy erasure’s requirements, that is, mask as much as possible subject to keeping the prediction unchanged. We can cast this, rather naturally, in the language of constrained optimization and employ a method such as Lagrangian relaxation. In general, however, it is not possible to guarantee equality between $\hat{y}$ and $y$ (since $f_\theta$ is a smooth function, a minimal change in its input cannot produce the exact same output), thus we introduce i) a divergence $D_\star[y \,\|\, \hat{y}]$ to measure how much the two outputs differ, and ii) a tolerance level $\gamma$ within which differences are regarded as acceptable. The choice of $D_\star$ depends on the structure of the output of the original model. For instance, for a deterministic regression model, seems a convenient choice, whereas for probabilistic classification, where $f_\theta$ parameterizes a categorical distribution, cross-entropy or Kullback–Leibler divergence are more appropriate.


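A practical way to minimize the number of non-zeros predicted by $g_\phi$ is minimizing the $L_0$ ‘norm’. ($L_0$, denoted $\|z\|_0$ and defined as $\sum_i \mathbb{1}[z_i \neq 0]$, is the number of non-zero entries in a vector. Contrary to $L_1$ or $L_2$, $L_0$ is not a homogeneous function and, thus, not a proper norm. Contemporary literature, however, with some abuse of terminology, refers to it as a norm and we do so as well to avoid confusion.) Thus, our loss is defined as the number of positions that are not masked:

$$\mathcal{L}_0(\phi, b \mid x) = \sum_{i=1}^{n} \mathbb{1}[z_i \neq 0], \tag{4}$$

where $\mathbb{1}[\cdot]$ is the indicator function. We minimize $\mathcal{L}_0$ for all data-points in the dataset $\mathcal{D}$ subject to a constraint that predictions from masked inputs have to be similar to the original model predictions:

$$\min_{\phi, b} \sum_{x \in \mathcal{D}} \mathcal{L}_0(\phi, b \mid x) \quad \text{s.t.} \quad D_\star[y \,\|\, \hat{y}] \le \gamma \;\; \forall x \in \mathcal{D}, \tag{5}$$

where $y = f_\theta(x)$ and $\hat{y} = f_\theta(\hat{x})$. Since non-linear constrained optimisation is generally intractable, we employ Lagrangian relaxation (Boyd et al., 2004) optimizing instead

$$\max_{\lambda} \min_{\phi, b} \sum_{x \in \mathcal{D}} \mathcal{L}_0(\phi, b \mid x) + \lambda \left( D_\star[y \,\|\, \hat{y}] - \gamma \right), \tag{6}$$

where $\lambda$ is the positive Lagrangian multiplier. Although Lagrangian relaxation allows a practical approach to constrained optimization, this objective is not differentiable with respect to $\phi$ and $b$, and thus standard stochastic gradient descent methods cannot be employed. That is because i) $\mathcal{L}_0$ is non-differentiable, and ii) to have $g_\phi$ output binary masks we need a non-differentiable output activation such as the step function.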
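To make the relaxed objective concrete, here is a minimal numeric sketch for a single example, assuming a hard binary mask and the KL divergence as the measure of output change. The function names, the multiplier `lam`, the tolerance `gamma`, and the two output distributions are all made up for illustration; none of them come from the paper.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q_i > 0 wherever p_i > 0."""
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)

def l0(z):
    """Number of positions that are not masked (non-zero entries of z)."""
    return sum(1 for z_i in z if z_i != 0)

def lagrangian(z, y, y_hat, lam, gamma):
    """L0 loss plus the multiplier times the constraint violation D[y || y_hat] - gamma."""
    return l0(z) + lam * (kl_divergence(y, y_hat) - gamma)

y = [0.9, 0.1]        # original class probabilities (made up)
y_hat = [0.8, 0.2]    # probabilities after masking (made up)
z = [1, 0, 0, 1]      # two of four tokens kept
value = lagrangian(z, y, y_hat, lam=2.0, gamma=0.05)
```

Here the divergence is below the tolerance, so the constraint term is negative and the objective is slightly smaller than the plain count of kept tokens. When the divergence exceeds `gamma`, the same term becomes a penalty that grows with the multiplier.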

Stochastic masks

To estimate the parameters of the probe via gradient-based optimization, we cannot have $g_\phi$ predict discrete masks deterministically. Though proxy (biased) gradients do exist (e.g., the deterministic straight-through estimator (Bengio et al., 2013)), they lack theoretical support. A better understood strategy is to give binary variables a stochastic treatment and compute the objective in expectation. This requires a simple change to $g_\phi$, namely, rather than directly predicting a vote $v_i^{(\ell)}$, it parameterizes a Bernoulli distribution from which we then sample a vote in order to compose a mask $z$. Computing the objective in expectation addresses both sources of non-differentiability, but introduces a difficulty, namely, assessing the loss and the constraint for every configuration of $z$ is intractable. The intractability of the expectation forces us to resort to gradient estimation, e.g., via REINFORCE (Williams, 1992), which can suffer from high variance as we demonstrate in Section 4. Instead, we employ a relaxation to binary variables that admits sampling through a differentiable reparameterization while retaining sparsity (Louizos et al., 2018), in particular, producing true zeros.

Sparse relaxation

Instead of sampling binary votes from Bernoulli distributions, we sample votes from Hard Concrete distributions (Louizos et al., 2018), a mixed discrete-continuous distribution over the closed interval $[0, 1]$. The Hard Concrete distribution, which we review in Appendix B, assigns density to continuous outcomes in the open interval $(0, 1)$ and non-zero mass to exactly $0$ and exactly $1$. A particularly appealing property of this distribution is that sampling can be done via a differentiable reparameterization (Rezende et al., 2014; Kingma and Welling, 2014). In this way, the loss in Equation 4 becomes an expectation

$$\mathcal{L}_0(\phi, b \mid x) = \sum_{i=1}^{n} \mathbb{E}_{p_\phi(z_i \mid x)}\big[\mathbb{1}[z_i \neq 0]\big], \tag{7}$$

whose gradient can be estimated via Monte Carlo sampling without the need for REINFORCE and without introducing biases. We did modify the original Hard Concrete, though only slightly, so that it gives support to samples in the half-open interval $[0, 1)$, that is, with non-zero mass only at $0$. That is because we need only distinguish $0$ from non-zero, and the value $1$ is not particularly important. (Only a true $0$ is guaranteed to completely mask an input out, while any non-zero value, however small, may leak some amount of information.)
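For readers unfamiliar with the Hard Concrete, the following sketch samples from the original distribution of Louizos et al. (2018) by stretching a binary Concrete sample to an interval slightly wider than $[0, 1]$ and clipping it back. It is not the modified variant described above, whose exact form is not given in this section. The stretch limits `l` and `r` and the temperature `beta` are the commonly used defaults and are assumptions here, as are the function names.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def sample_hard_concrete(log_alpha, rng, beta=2 / 3, l=-0.1, r=1.1):
    """Reparameterized sample: uniform noise -> binary Concrete -> stretch -> clip."""
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)   # keep the logs finite
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / beta)
    s_bar = s * (r - l) + l                        # stretch from (0, 1) to (l, r)
    return min(1.0, max(0.0, s_bar))               # clipping yields exact 0s and 1s

def prob_nonzero(log_alpha, beta=2 / 3, l=-0.1, r=1.1):
    """Closed-form probability that a sample is non-zero (the expected L0 term)."""
    return sigmoid(log_alpha - beta * math.log(-l / r))

rng = random.Random(0)
samples = [sample_hard_concrete(0.0, rng) for _ in range(20000)]
empirical = sum(1 for s in samples if s != 0.0) / len(samples)
analytic = prob_nonzero(0.0)
```

The sample is a deterministic, differentiable function of `log_alpha` given the noise `u` (except at the clipping points), which is what allows gradients to flow to the parameters. The closed-form `prob_nonzero` is the per-position term that the expected $L_0$ loss sums over.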

Latent rationales

There is a stream of work on learning interpretable models by means of extracting latent rationales (Lei et al., 2016; Bastings et al., 2019). Some of the techniques underlying DiffMask are related to that line of work, but overall we approach very different problems. Lei et al. (2016) use REINFORCE to minimize a downstream loss computed on masked inputs, where the masks are binary and latent. They employ $L_0$ regularization to solve the task while conditioning only on small subsets of the input regarded as a rationale for the prediction. To the same end, Bastings et al. (2019) minimize downstream loss subject to constraints on expected $L_0$ using a variant of the sparse relaxation of Louizos et al. (2018). In sum, they employ stochastic masks to learn an interpretable model, which they learn by minimizing a downstream loss subject to constraints on $L_0$, whereas we employ stochastic masks to interpret an existing model, and for that we minimize $L_0$ subject to constraints on that model’s downstream performance.

4 Experiments

The goal of this work is to uncover a faithful interpretation of an existing model, i.e., revealing, as accurately as possible, the process by which the model arrives at the prediction. Human-provided labels (e.g., human rationales; Camburu et al., 2018; DeYoung et al., 2019) will not help us in demonstrating this, as humans cannot judge if an interpretation is faithful (Jacovi and Goldberg, 2020). On the contrary, one is often interested in using such attribution methods to uncover pathologies (or biases) in models or hidden biases in the data. Evaluating faithfulness is challenging, as ground-truth is not known on real tasks and with real models. Our strategy is to i) show the effectiveness of DiffMask in a controlled setting where ground-truth is available (Section 4.1); ii) test the effectiveness of our relaxation for learning discrete masks (on a real model, Section 4.2.1); iii) demonstrate that the method is stable and that accuracy does not degrade due to masking (Section 4.2.2). Once we have established that DiffMask can be trusted, we use it to analyse BERT-based models fine-tuned on sentiment classification (Section 4.2), where we also contrast its attributions with those of other methods, and on question answering (Section 4.3).

4.1 Toy task

To help establish the faithfulness of DiffMask, we use it to analyse a model for which the ground-truth attributions are known. Our toy task is defined as follows: given a sequence $x$ of digits and a query $\langle n, m \rangle$ about two digits, determine whether there are more $n$ than $m$ in $x$. We generate sequences of varying length (up to digits long) sampling each element independently: with probability, we draw uniformly $n$ or $m$ and, with probability, we draw uniformly from the remaining digits. We generate k data-points, keeping of them for validation. (The total number of possible input sequences for this task is . Thus, a model that solves the task cannot simply memorize the training set.) Intuitively, solving the task is very easy: a model has to track the occurrences of $n$ and $m$, ignoring all the other inputs.


We implement a shallow model. It consists of an embedding layer of dimensionality . Then, the embedded query and input are concatenated and fed to a single-layer feed-forward NN, followed by a single-layer unidirectional GRU (Cho et al., 2014). (We use a feed-forward NN to incorporate the query information, rather than another GRU layer, to ensure that counting cannot happen in the first layer. This helps us define the ground-truth for the method.) The classification is done by a linear classifier that acts on the last hidden state of the GRU. Unsurprisingly, the model solves the task almost perfectly (accuracy on validation is ).


We designed a model for which the ground-truth can be identified. In particular, input states are designed such that they store embeddings for the query ($n$ and $m$) and the corresponding digit $x_i$. After the feed-forward layer, hidden states need no longer store the identity of the digit (i.e., $x_i$) but simply whether $x_i$ is $n$, $m$, or other. We verify this by plotting the distribution of hidden states (which we set to dimension with the purpose of having a bottleneck and a clear visualization) in Figure 12 in Appendix D, where we observe linear separation between states of digits in the query and states of digits not in the query. This confirms that the role of the feed-forward layer is to decide which digits to keep, while the GRU must figure out which one occurred the most. In sum, we know the prediction must be attributed uniformly to all of the input positions where $x_i$ is $n$ or $m$.
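The data-generating process and the ground-truth attribution just described can be sketched as follows. Several values here are placeholders, not the settings used in the experiments: the maximum length `max_len`, the probability `p_query` of drawing a query digit, and the choice to sample two distinct query digits per example are assumptions made only so the example runs. The function names are hypothetical.

```python
import random

def make_example(rng, max_len=10, p_query=0.5):
    """Sample a query (n, m), a digit sequence x, and the label 'more n than m in x'."""
    n, m = rng.sample(range(10), 2)                   # two distinct query digits (assumption)
    others = [d for d in range(10) if d not in (n, m)]
    length = rng.randint(1, max_len)                  # max_len is a placeholder value
    x = [
        rng.choice((n, m)) if rng.random() < p_query else rng.choice(others)
        for _ in range(length)
    ]
    return x, (n, m), x.count(n) > x.count(m)

def ground_truth_attribution(x, n, m):
    """Uniform attribution over the positions holding a query digit, zero elsewhere."""
    relevant = [1.0 if d in (n, m) else 0.0 for d in x]
    total = sum(relevant)
    return [r / total for r in relevant] if total else relevant

rng = random.Random(0)
x, (n, m), label = make_example(rng)
attribution = ground_truth_attribution(x, n, m)
```

The attribution is a probability distribution whenever at least one query digit occurs, which is what makes a divergence-based comparison against other methods possible.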

Erasure – *
Sundararajan et al. (2017)
Schulz et al. (2020)
Guan et al. (2019)
Table 1: Toy task: average divergence in nats between the ground-truth attributions and those different methods assigned to hidden states in the validation set. *Erasure produces a delta distribution that does not share support with the ground-truth.

We compare DiffMask to integrated gradient (Sundararajan et al., 2017), as one of the most widely used attribution methods, as well as the perturbation methods by Schulz et al. (2020) and Guan et al. (2019). We also perform erasure by searching exhaustively for masked inputs that yield the same prediction.
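To see why exhaustive erasure tends to return uninformative subsets on this task, consider the sketch below. It replaces the trained model by a hand-written counting rule and erases by dropping tokens, both of which are simplifying assumptions, and the input sequence is made up. The names `predict` and `erasure_search` are hypothetical.

```python
from itertools import combinations

def predict(x, n, m):
    """Stand-in for the trained model: are there more n than m in x?"""
    return x.count(n) > x.count(m)

def erasure_search(x, n, m):
    """Return the smallest set of positions whose tokens alone preserve the prediction."""
    target = predict(x, n, m)
    for size in range(len(x) + 1):                    # smallest subsets first
        for keep in combinations(range(len(x)), size):
            if predict([x[i] for i in keep], n, m) == target:
                return list(keep)

x = [8, 1, 8, 3, 8]
kept = erasure_search(x, n=8, m=1)   # positions that survive erasure
```

For this input a single occurrence of the first query digit is enough to preserve the prediction, even though a model that solves the task must have counted every occurrence of both query digits. The search also enumerates subsets in order of size, so its cost grows exponentially with sequence length.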


We start with an example of input attributions, see Figure 3, which illustrates how DiffMask goes beyond input attribution as typically known. (To enable comparison across methods, the attributions in this section are re-normalized to (whereas DiffMask gates before normalization tend to be exactly $0$ or exactly $1$).) The attribution provided by erasure (Figure 2(a)) is not informative: the search in this case, and in all other examples in the test set, finds a single digit that is sufficient to maintain the original prediction and discards all the other inputs. The perturbation methods by Schulz et al. (2020) and Guan et al. (2019) (Figure 2(b) and 2(d)) are also over-aggressive in pruning. They assign low attribution to some items in the query even though those had to be considered when making the prediction. Integrated gradient (Figure 2(c)) assigns high importance to the digits that appear in the query. Unlike other methods, DiffMask reveals input attributions conditioned on different levels of depth. Figure 2(e) shows both input attributions according to the input itself (left) and according to the hidden layer (right). It reveals that at the embedding layer there is no information regarding what part of the input can be erased: attribution is uniform over the input sequence. After the model has observed the query, hidden states predict that masking input digits other than $n$ and $m$ will not affect the final prediction: attribution is uniform over digits in the query. This reveals the role of the feed-forward layer as a filter for positions relevant to the query. Other methods do not allow for this type of inspection. These observations are consistent across the entire test set.

For attribution to hidden states (i.e., the output of the feed-forward layer) we can compare all methods in terms of how much their attributions resemble the ground-truth across the test set. Table 1 shows how the different approaches deviate from the ground-truth in terms of Kullback-Leibler (KL) and Jensen–Shannon (JS) divergences. (We use $\mathrm{KL}(p \,\|\, q)$ and $\mathrm{JS}(p \,\|\, q)$, where $p$ is the ground-truth attribution distribution and $q$ the attribution distribution of a method we want to compare with. KL is an asymmetric divergence, whereas JS is a symmetric version of it. Both measure how much two distributions differ, i.e., the lower the value, the more similar the two distributions are.) Our method, unlike other methods, achieves perfect scores on the test set.
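The two divergences can be computed directly from normalized attribution vectors. The sketch below is a plain-Python illustration with made-up distributions; the convention of returning infinity when the second distribution puts zero mass where the first does not is an assumption about how unsupported cases are handled, and the function names are hypothetical.

```python
import math

def kl(p, q):
    """KL(p || q) in nats; infinite if q has zero mass where p does not."""
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i > 0:
            if q_i == 0:
                return math.inf
            total += p_i * math.log(p_i / q_i)
    return total

def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their midpoint."""
    mid = [(p_i + q_i) / 2 for p_i, q_i in zip(p, q)]
    return 0.5 * kl(p, mid) + 0.5 * kl(q, mid)

ground_truth = [0.5, 0.0, 0.5, 0.0]   # uniform over two relevant positions (made up)
delta = [1.0, 0.0, 0.0, 0.0]          # all attribution on a single position

kl_value = kl(ground_truth, delta)    # infinite: the supports do not match
js_value = js(ground_truth, delta)    # finite and symmetric in its arguments
```

A delta-like attribution of the kind erasure produces makes the KL divergence from the ground truth infinite, while the JS divergence stays finite, which is one reason to report both.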

Metrics REINFORCE+ DiffMask
Table 2: Sentiment classification: optimization with DiffMask and REINFORCE (with a moving average baseline for variance reduction) not amortised against erasure exact search. All metrics are computed at token level where optimality is measured at sentence level.
(a) Attention.
(b) Integrated gradient*.
(c) Schulz et al. (2020).
(d) Guan et al. (2019).
(e) Our DiffMask.
Figure 5: Sentiment classification: comparison between attribution methods for hidden layers w.r.t. the predicted label. All plots are normalized per-layer by the largest attribution. Attention heatmap is obtained max pooling over heads and averaging across positions. *By Sundararajan et al. (2017).
Figure 6: Average keep probability for every hidden layer of start and end-of-sentence tokens.

4.2 Sentiment Classification

We turn now to a real task and analyze a standard BERTBASE model fine-tuned for sentiment classification on the Stanford Sentiment Treebank (SST; Socher et al., 2013). It consists of a pre-trained BERTBASE followed by a pooling layer (which simply takes the hidden state of the first token) and a linear classifier. The model is trained with the cross-entropy loss to predict one of five sentiment classes: very negative, negative, neutral, positive, and very positive. We then apply our DiffMask for both input attribution (i.e., as a function of hidden states at different depths) and hidden state attribution. We train DiffMask with a KL divergence constraint . [Footnote 10: We do not update BERT parameters when optimizing DiffMask. Memory and computation overhead is negligible.] Hyperparameters are reported in Appendix C.1.

4.2.1 Erasure search as learning masks

Before diving into the analysis of the sentiment model, we would like to demonstrate that we can learn informative subsets through our differentiable relaxations. To do this, we need access to the ground truth. We do not have it for our full approach, but we can obtain it when we do not use amortization (i.e., when all Hard Concrete parameters are learned for a specific example rather than predicted from the BERT states). In that case an optimal solution (or a set of equally good optimal solutions) is provided by erasure. We compare DiffMask to REINFORCE (Williams, 1992) with a moving average baseline for variance reduction. Since erasure requires exact search, which is infeasible for long sequences, we evaluate here using only sentences up to words ( of the data). Table 2 shows that DiffMask outperforms REINFORCE: both achieve a comparable level of sparsity, but our method reaches an optimal solution much more often than REINFORCE (45% of the time vs. 16%) and is, on average, closer to an optimal solution (87% vs. 77%).
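
The exact erasure search used as ground truth here can be illustrated with a small brute-force sketch. Everything below is hypothetical: `toy_model` is a stand-in keyword classifier rather than BERT, and `erasure_search` simply enumerates subsets from smallest to largest. The sketch shows why several equally good optimal solutions can exist and why the cost grows exponentially with sentence length.

```python
from itertools import combinations

def toy_model(tokens):
    """Hypothetical stand-in classifier: 'positive' iff positive cue words outnumber negative ones."""
    pos = sum(t in {"good", "great"} for t in tokens)
    neg = sum(t in {"bad", "awful"} for t in tokens)
    return "positive" if pos > neg else "negative"

def erasure_search(tokens, model, mask_token="[MASK]"):
    """Return every smallest set of positions to keep that preserves the model's prediction."""
    target = model(tokens)
    n = len(tokens)
    for k in range(n + 1):  # try the sparsest solutions first
        solutions = []
        for keep in combinations(range(n), k):
            masked = [t if i in keep else mask_token for i, t in enumerate(tokens)]
            if model(masked) == target:
                solutions.append(set(keep))
        if solutions:  # all solutions of this size are equally optimal
            return solutions
    return []

sentence = ["a", "great", "and", "good", "film"]
optimal = erasure_search(sentence, toy_model)
print(optimal)  # [{1}, {3}]
```

In this toy case keeping either cue word alone preserves the prediction, so a learned mask would count as optimal if it matched any one of the returned sets.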

4.2.2 Maintaining prediction and stability

Now, we get back to the fully amortized DiffMask approach and verify that there is no performance degradation when applying masking. Indeed, the macro score of the model on validation moved from to while masking input tokens and to while masking hidden states. The explanations provided by DiffMask are also stable. Across independent runs with different seeds, the standard deviation of input attributions is and for hidden state attributions is (averaged across the validation set).

(a) Masking input with amortization.
(b) Masking hidden states with amortization.
(c) Masking hidden states without amortization.
(d) Masking hidden states without amortization and without baseline.
Figure 7: Sentiment classification: highlighting difference between input and hidden state attributions in (a) and (b), and ablation study on amortization in (b), (c) and (d).
(a) POS tags on input.
(b) Sentiment on input.
(c) Sentiment on hidden.
Figure 8: Sentiment classification: average number of layers that predict to keep input tokens or hidden states on the validation set. (a) shows average predictions on input aggregated by part-of-speech (POS) tag, whereas (b) and (c) aggregate by token-level sentiment annotations.

4.2.3 Comparisons

Our previous experiments were aimed primarily at showing that DiffMask can be trusted. Now, we finally turn to using DiffMask to actually analyze the sentiment model.

While previous techniques (e.g., integrated gradient) do not let us test what a model ‘knows’ in a given layer (i.e. attribution to input conditioned on a layer), they can be used to perform attribution to hidden layers. In Figure 5 we compare our method with recent techniques in that regime. In Figure 13 in Appendix D we show an additional example.

Raw attention (Figure 5(a)) does not seem to highlight any significant patterns in that example except that start and end of sentence tokens ([CLS] and [SEP], respectively) receive more attention than the rest. [Footnote 11: Voita et al. (2019b) and Michel et al. (2019) pointed out that many Transformer heads play no or minor role, so it may be possible to obtain more informative attributions if the ‘useless’ heads are disregarded.] All the other methods correctly highlight the last hidden state of the [CLS] token as important. Its importance is due to the top-level classifier using the [CLS] hidden state. The methods by Schulz et al. (2020) and Guan et al. (2019) assign slightly higher importance to hidden states corresponding to ‘highly’ and ‘enjoyable’, whereas it is hard to see any informative patterns in heatmaps provided by integrated gradient. Our method assigns much sharper attribution. In Figure 5(e) it is evident that hidden states associated with punctuation can be completely dropped, while the rest needs to be kept. Importantly, with DiffMask, the zero attribution has a very clear interpretation: when DiffMask masks a hidden state, that state is not used for prediction (i.e., in layers higher up in the model).

4.2.4 Analysis

The hidden state attribution we have done so far tells us if a state stores important information. Now, we contrast it to input attribution, which, with DiffMask, shows what the model ‘knows’ at a given layer. The two attributions are shown in Figure 7(a) and 7(b). The situation here seems relatively straightforward. From Figure 7(a) we see that the model, even in the bottom layers, knows that the punctuation and both separators can be dropped from the input. This contrasts with hidden state attribution (Figure 7(b)), which indicates that the separator states (especially [SEP]) are very important. By putting this information together, we can hypothesize that the separator is used to aggregate information from the sentence, relying on self-attention. In fact, this aggregation is still happening in layer ; at the very top layers, states corresponding to all non-separator tokens can be dropped. In Figure 6, we confirm that the separators are important across the dataset and not only on this example. In Figure 8(a) we instead aggregate input attributions according to part-of-speech tags. With this sentiment model, determiners, punctuation, and pronouns can be completely discarded from the input, while adjectives and nouns should be kept.
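
The aggregation behind these per-tag statistics ("average number of layers that predict to keep") can be sketched as follows. This is only an illustration under stated assumptions: the 0.5 threshold for counting a layer as voting to keep a token, the function name, and the toy keep probabilities are all made up and are not values or code from the paper.

```python
from collections import defaultdict

def average_layers_kept(examples, threshold=0.5):
    """Average, per tag, the number of layers whose keep probability exceeds the threshold.

    `examples` is a list of (tags, keep_probs) pairs, where keep_probs[i] holds
    one keep probability per layer for the i-th token.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for tags, keep_probs in examples:
        for tag, per_layer in zip(tags, keep_probs):
            totals[tag] += sum(p > threshold for p in per_layer)
            counts[tag] += 1
    return {tag: totals[tag] / counts[tag] for tag in totals}

# Made-up gate expectations for two short sentences and a three-layer model.
examples = [
    (["DET", "ADJ", "NOUN", "PUNCT"],
     [[0.9, 0.1, 0.0], [1.0, 1.0, 0.8], [1.0, 0.9, 0.2], [0.2, 0.0, 0.0]]),
    (["ADJ", "NOUN"],
     [[1.0, 1.0, 1.0], [1.0, 0.6, 0.4]]),
]
per_tag = average_layers_kept(examples)
print(per_tag)
```

The same routine applies unchanged when tokens are grouped by token-level sentiment labels instead of part-of-speech tags.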

4.2.5 Human labels

While we cannot use human labels to evaluate the faithfulness of our method, comparing them with DiffMask attributions tells us whether the sentiment model relies on the same cues as humans. Specifically, we use the SST token-level annotation of sentiment. In Figure 8(b), we show after how many layers on average an input token is dropped, depending on its sentiment label. Figure 8(c) shows the same for hidden states. This suggests that the model relies more heavily on strongly positive or negative words and, thus, is generally consistent with human judgments.

4.2.6 Ablation

As argued in the introduction and shown on the toy task, many popular methods (e.g., erasure and its approximations) are over-aggressive in discarding inputs and hidden units. Amortization is a fundamental component of DiffMask and is aimed at addressing this issue. In Figure 7 we show how our method behaves when ablating amortization and thus optimizing on a single example instead. Notably, our method then converges to masking out all hidden states at every layer (Figure 7(c)). This happens because it learns an ad hoc baseline just for that example. When we ablate both amortization and baseline learning (Figure 7(d)), the method struggles to uncover any meaningful patterns. This highlights how both core components of our method are needed in combination with each other.

(a) Gating the input.
(b) Gating the input.
(c) Gating hidden states.
(d) Gating hidden states.
Figure 9: Expectation predicted by DiffMask to keep the inputs in (a) and (b) or hidden states in (c) and (d) on two different questions on the same paragraph. The correct answer is highlighted in bold.
(a) POS gating inputs.
(b) POS gating hidden states.
Figure 10: Question answering: average number of layers that predict to keep input tokens (a) or hidden states (b) aggregating by part-of-speech tag (POS) on validation set.

4.3 Question Answering

We turn now to extractive question answering, where we analyze a fine-tuned BERTLARGE model trained on the Stanford Question Answering Dataset v1.1 (SQuAD; Rajpurkar et al., 2016). It consists of a pre-trained BERTLARGE encoder followed by two independent linear classifiers that predict the beginning and the end of the span that contains the answer in the document. The model is fine-tuned to minimize the cross-entropy error for span prediction, while DiffMask minimizes subject to a constraint on (see footnote 10). Hyperparameters are reported in Appendix C.2.

4.3.1 Comparison to other methods

As we do not have access to the ground-truth, we start by contrasting DiffMask qualitatively to other attribution methods on a few examples. We highlight some common pitfalls that afflict other methods (such as the hindsight bias) and how DiffMask overcomes those. This helps demonstrate our method’s faithfulness to the original model. In addition, we discuss how DiffMask explanations provide deeper insight into how predictions are formed.

Figure 2 shows input attributions by different methods on an example from the validation set. Erasure (Figure 2(e)), as expected, does not provide useful insights: it essentially singles out the answer, discarding everything else including the question. This cannot be faithful and is a simple consequence of erasure’s hindsight bias: when only the span that contains the answer is presented as input, the model predicts that very span as the answer, but this does not imply that the model ignores everything else when presented with the complete document as input. The methods of Schulz et al. (2020) and Guan et al. (2019) optimize attributions on single examples and thus also converge to assigning high importance mostly to words that support the current prediction and that indicate the question type. Integrated gradient does not seem to highlight any discernible pattern, which we speculate is largely because a zero baseline is not suitable for word embeddings. Choosing a more adequate baseline is not straightforward and remains an important open issue (Sturmfels et al., 2020). Note that without amortization, DiffMask closely approximates erasure (as demonstrated in Section 4.2.1 for SST), and indeed Figure 2(f) is another example of the hindsight bias, which prevents us from gaining insights about the model.

Differently from all other methods, our DiffMask probes the network to understand what it ‘knows’ about the input-output mapping in different layers. In Figure 2(d) we show the expectation of keeping input tokens conditioned on any one of the layers in the model to make such predictions. Our input attributions highlight that the model, in expectation across layers, wants to keep words in the question as well as all potential candidate answers, but that eventually the most important spans are in the question and the answer itself.

4.3.2 Analysis

Having provided additional evidence that DiffMask explanations are faithful, we proceed to uncovering patterns in how the model processes inputs. We start by asking ourselves, or rather DiffMask, which tokens the model keeps. In Figure 9(a) and 9(b) we visualize the expectations of keeping the input tokens with respect to each of BERT’s layers on two different questions about the same passage. For that example, the model seems to ignore almost all determiners, prepositions, and conjunctions to perform its predictions. To better investigate the role of different parts of speech (POS), we aggregate statistics over the entire validation set in Figure 10. It indeed emerges that those parts of speech are largely ignored by the model, while nouns and proper nouns are often kept. We argue that, due to the pre-training objective, BERT can readily infer missing parts of the input, especially if they are trivial to infer (e.g., prepositions or determiners in many cases). In contrast, in Figure 9(a), we can see that it takes layers for the model to ‘realize’ that ‘Santa Clara Marriott’ and ‘Stanford University’ are not relevant to the question and discard them.

Similarly to sentiment classification, the [CLS] token does not appear to be useful. Conversely, the [SEP] tokens are important (at least according to bottom layers). Notice that, in this task, the [SEP] token is also used as a separator between the question and the passage, and hence indicates where the question ends. However, at th layer, the model is already confident about what the possible answers could be, so these tokens are no longer needed.

Unsurprisingly, in both examples, all layers choose to keep the originally predicted answer spans. Across the validation set, this is happening in of cases. Incrementality in processing the data is much more evident here than on the sentiment task. For example, at least layers are needed to decide to drop any named entity. At top layers (e.g., top ), the model can drop almost everything except for the answer, indicating that the model has already converged to the decision. Higher layers also still vote to keep parts of the question (e.g., ‘Where’), presumably because it is fundamental for selecting the answer type. Key named entities in the questions are kept as well, while the question mark is always dropped: it is present in every question so does not carry any information. Our observation that higher layers are more predictive is in line with findings of Kovaleva et al. (2019). They pointed out that final layers of BERT change most and are more task specific.

Where is the information stored?

In Figure 9(c) and 9(d) we visualize the expectations of keeping hidden states across layers predicted by DiffMask on two different questions, but for the same passage. Differently from deciding which input tokens to drop (as in Figure 9(a) and 9(b)), masking hidden states sheds light on which hidden states store important information. As an example, the model seems to use the hidden states aligned with the determiner ‘the’ in proximity of the answer span in Figure 9(c). Although determiners can be completely ignored as inputs, their hidden states are actually used by the model. This is consistent with the findings of Voita et al. (2019a), which show that frequent tokens, such as determiners, accumulate contextual information. Statistics aggregated for different part-of-speech tags across the validation set (Figure 10(b)) confirm this intuition.

All layers are voting to drop every state except the ones corresponding to the answer span. These contain information needed at the top layer for classification while all the others can indeed be removed, without affecting the model prediction. In contrast, intermediate layers seem to be still considering different span options, as they are still active on all plausible spans (i.e., names of locations).

5 Conclusion

The recent developments in expressivity and efficacy of complex deep neural networks have come at the expense of interpretability. While systematically erasing inputs to determine how a model reacts leads to a neat interpretation, it comes with many issues, such as exponential computational time complexity and susceptibility to the hindsight bias: if a word can be dropped from the input, it does not necessarily imply that it is not used by the model. We have introduced a new post hoc interpretation method which learns to completely remove subsets of inputs or hidden states through masking. We circumvent an intractable search by learning an end-to-end differentiable prediction model. To circumvent the hindsight bias problem, we probe the model’s hidden states at different depths and amortize predictions over the training set.

We validate the faithfulness of DiffMask in a controlled artificial experiment pointing more clearly to some flaws of other attribution methods. DiffMask without amortization was also shown to approximate erasure well in a real task (SST) and to outperform REINFORCE in doing so both in terms of performance and stability across runs. Having established that DiffMask can be trusted, we used it to study BERT-based models on sentiment classification and question answering. Our method sheds light on what different layers ‘know’ about the input and where information about the prediction is stored in different layers.


The authors want to thank Christos Baziotis, Elena Voita, Dieuwke Hupkes, and Naomi Saphra for helpful discussions. This project is supported by SAP Innovation Center Network, ERC Starting Grant BroadSem (678254), the Dutch Organization for Scientific Research (NWO) VIDI 639.022.518, and the European Union’s Horizon 2020 research and innovation programme under grant agreement No 825299 (Gourmet).


Appendix A Parameterization

To keep the probes as simple as possible, we parameterized them as bilinear functions. When masking input tokens, ‘votes’ are computed as where are trainable parameters. See Appendix B for details about the Hard Concrete distribution including its parameterization. When masking hidden states, we use the same functional form to compute but is replaced by .

Appendix B Binary Concrete

A stretched and rectified Binary Concrete (also known as Hard Concrete) distribution is obtained by applying an affine transformation to the Binary Concrete distribution (Maddison et al., 2017; Jang et al., 2017) and rectifying its samples in the interval (see Figure 11). A Binary Concrete is defined over the open interval ( in Figure 11(a)) and is parameterized by a location parameter and a temperature parameter . The location acts as a logit and controls the probability mass, skewing the distribution towards in case of negative location and towards in case of positive location. The temperature parameter controls the concentration of the distribution. The Binary Concrete is then stretched with an affine transformation extending its support to with and ( in Figure 11(a)). Finally, we obtain a Hard Concrete distribution by rectifying samples in the interval . This corresponds to collapsing the probability mass over the interval to , and the mass over the interval to ( in Figure 11(b)). This induces a distribution over the closed interval with non-zero mass at and . Samples are obtained using is the Sigmoid function and . We refer to Appendix B of Louizos et al. (2018) for more information about the density of the resulting distribution and its cumulative distribution function.
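
The sampling procedure described above (draw a Binary Concrete sample, stretch it, then rectify it) can be sketched in a few lines. This is a minimal illustration that follows the standard reparameterization in Louizos et al. (2018), not the authors' implementation: the argument names, the default temperature, and the stretch bounds `low=-0.2` and `high=1.2` are assumptions chosen for the example.

```python
import math
import random

def sample_hard_concrete(location, temperature=0.5, low=-0.2, high=1.2, rng=random):
    """Draw one stretched-and-rectified Binary Concrete (Hard Concrete) sample in [0, 1]."""
    # Uniform noise, clamped away from 0 and 1 so the logs stay finite.
    u = min(max(rng.random(), 1e-6), 1 - 1e-6)
    logistic_noise = math.log(u) - math.log(1 - u)
    # Binary Concrete sample on the open unit interval; the location acts as a logit.
    s = 1 / (1 + math.exp(-(logistic_noise + location) / temperature))
    # Stretch the support to (low, high), then rectify back into [0, 1].
    stretched = s * (high - low) + low
    return min(1.0, max(0.0, stretched))

rng = random.Random(0)
open_gates = [sample_hard_concrete(3.0, rng=rng) for _ in range(2000)]
closed_gates = [sample_hard_concrete(-3.0, rng=rng) for _ in range(2000)]
print(sum(open_gates) / len(open_gates), sum(closed_gates) / len(closed_gates))
```

Because the stretched support extends beyond the unit interval, rectification places non-zero probability on exactly 0 and exactly 1, which is what allows a gate to mask an input completely while the sample remains a differentiable function of the location.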

Figure 11: Binary Concrete distributions: (a) a Concrete and its stretched version ; (b) a rectified and stretched (Hard) Concrete .

Appendix C Experiments

C.1 Sentiment Classification

For the sentiment classification experiment we downloaded a pre-trained model from the Huggingface implementation of Wolf et al. (2019) and fine-tuned it on the SST dataset. We report the hyperparameters used for training the model and our DiffMask in Table 3.

Model Value
Type BERTBASE (uncased)
Hidden units
Pre-trained masking standard
Optimizer Adam *
Learning rate
Train epochs
Batch size
DiffMask Value
Optimizer Lookahead RMSprop **
Learning rate
Learning rate
Train epochs
Batch size
Table 3: List of hyperparameters for the sentiment classification experiment. *is Kingma and Ba (2015), **is Tieleman and Hinton (2012); Zhang et al. (2019).

C.2 Question Answering

For the question answering experiment we downloaded an already fine-tuned model from the Huggingface implementation of Wolf et al. (2019). We report the hyperparameters used by them for training the original model and the ones used for our DiffMask in Table 4.

Model Value
Type BERTLARGE (uncased)
Hidden units
Pre-trained masking whole-word
Optimizer Adam *
Learning rate
Train epochs
Batch size
DiffMask Value
Optimizer Lookahead RMSprop **
Learning rate
Learning rate
Train epochs
Batch size
Table 4: List of hyperparameters for the question answering experiment. *is Kingma and Ba (2015), **is Tieleman and Hinton (2012); Zhang et al. (2019).

Appendix D Additional plots

Figure 12: Hidden state values for the two-neuron toy task. Clusters of whether the input digit is equal to the first or second position in the query ( or respectively) or not at all () are completely linearly separable.
(a) Attention.
(b) Integrated gradient*.
(c) Schulz et al. (2020).
(d) Guan et al. (2019).
(e) Our DiffMask.
Figure 13: Sentiment classification: comparison between attribution methods for hidden layers w.r.t. the predicted label. All plots are normalized per-layer by the largest attribution. Attention heatmap is obtained max pooling over heads and averaging across positions. *By Sundararajan et al. (2017).
(a) Gating the input.
(b) Gating the input.
(c) Gating hidden states.
(d) Gating hidden states.
Figure 14: Expectation predicted by DiffMask to keep the inputs or hidden states on two different questions on the same paragraph. The correct answer is highlighted in bold.