Towards Transparent and Explainable Attention Models

04/29/2020 ∙ by Akash Kumar Mohankumar, et al. ∙ Adobe ∙ Indian Institute of Technology Madras

Recent studies on interpretability of attention distributions have led to notions of faithful and plausible explanations for a model's predictions. Attention distributions can be considered a faithful explanation if a higher attention weight implies a greater impact on the model's prediction. They can be considered a plausible explanation if they provide a human-understandable justification for the model's predictions. In this work, we first explain why current attention mechanisms in LSTM based encoders can neither provide a faithful nor a plausible explanation of the model's predictions. We observe that in LSTM based encoders the hidden representations at different time-steps are very similar to each other (high conicity) and attention weights in these situations do not carry much meaning because even a random permutation of the attention weights does not affect the model's predictions. Based on experiments on a wide variety of tasks and datasets, we observe attention distributions often attribute the model's predictions to unimportant words such as punctuation and fail to offer a plausible explanation for the predictions. To make attention mechanisms more faithful and plausible, we propose a modified LSTM cell with a diversity-driven training objective that ensures that the hidden representations learned at different time steps are diverse. We show that the resulting attention distributions offer more transparency as they (i) provide a more precise importance ranking of the hidden states (ii) are better indicative of words important for the model's predictions (iii) correlate better with gradient-based attribution methods. Human evaluations indicate that the attention distributions learned by our model offer a plausible explanation of the model's predictions. Our code has been made publicly available at https://github.com/akashkm99/Interpretable-Attention


1 Introduction

 

Question 1: What is the best way to improve my spoken English soon?
Question 2: How can I improve my English speaking ability?
Is paraphrase (Actual & Predicted): Yes
Attention over Question 2 (Vanilla LSTM and Diversity LSTM): "How can I improve my English speaking ability ?"

Passage: Sandra went to the garden. Daniel went to the garden.
Question: Where is Sandra?
Answer (Actual & Predicted): garden
Attention over the passage (Vanilla LSTM and Diversity LSTM): "Sandra went to the garden . Daniel went to the garden"

Table 1: Samples of attention distributions from the Vanilla and Diversity LSTM models on the Quora Question Paraphrase (QQP) and bAbI 1 datasets (the attention heatmaps from the original paper are not reproducible in plain text).

Attention mechanisms (Bahdanau et al., 2014; Vaswani et al., 2017) play a very important role in neural network-based models for various Natural Language Processing (NLP) tasks. They not only improve the performance of the model but are also often used to provide insights into the working of a model. Recently, there has been a growing debate on whether attention mechanisms can offer transparency to a model or not. For example, Serrano and Smith (2019) and Jain and Wallace (2019) show that high attention weights need not necessarily correspond to a higher impact on the model's predictions and hence do not provide a faithful explanation for them. On the other hand, Wiegreffe and Pinter (2019) argue that attention distributions may still provide a plausible explanation for the predictions, i.e., a plausible reconstruction of the model's decision making that a human can understand, even if it is not faithful to how the model actually works.

In this work, we begin by analyzing why attention distributions may not faithfully explain the model's predictions. We argue that when the input representations over which an attention distribution is computed are very similar to each other, the attention weights are not very meaningful: since the representations are so alike, even random permutations of the attention weights lead to similar final context vectors, and the output predictions therefore change very little. We show that this is indeed the case for LSTM based models, where the hidden states occupy a narrow cone in the latent space (i.e., the hidden representations are very close to each other). We further observe that for a wide variety of datasets, attention distributions in these models do not even provide a good plausible explanation, as they pay significantly high attention to unimportant tokens such as punctuation. This is perhaps because the hidden states capture a summary of the entire context instead of being specific to their corresponding words.

Based on these observations, we aim to build more transparent and explainable models whose attention distributions provide faithful and plausible explanations for their predictions. One intuitive way of making the attention distribution more faithful is to ensure that the hidden representations over which it is computed are diverse, so that a random permutation of the attention weights leads to very different context vectors. To do so, we propose an orthogonalization technique which ensures that the hidden states are farther away from each other in their spatial dimensions. We then propose a more flexible model trained with an additional objective that promotes diversity in the hidden states. Through a series of experiments on 12 datasets spanning four tasks, we show that our models are more transparent while achieving performance comparable to models with vanilla LSTM based encoders. Specifically, we show that in our proposed models, attention weights (i) provide a useful importance ranking of hidden states, (ii) are better indicative of words that are important for the model's prediction, (iii) correlate better with gradient-based feature importance methods, and (iv) are sensitive to random permutations (as should indeed be the case).

We further observe that attention weights in our models, in addition to adding transparency, are also more explainable, i.e., more human-understandable. In Table 1, we show samples of attention distributions from a Vanilla LSTM and our proposed Diversity LSTM model. We observe that in our models, unimportant tokens such as punctuation marks receive very little attention, whereas important words belonging to relevant part-of-speech tags receive greater attention (for example, adjectives in the case of sentiment classification). Human evaluation shows that annotators prefer the attention weights of our Diversity LSTM over those of the Vanilla LSTM as providing better explanations in 72.3%, 62.2%, 88.4% and 99.0% of the samples in the Yelp, SNLI, Quora Question Paraphrase and bAbI 1 datasets respectively (see Table 5).

2 Tasks, Dataset and Models

Our first goal is to understand why existing attention mechanisms with LSTM based encoders fail to provide faithful or plausible explanations for the model's predictions. We experiment on a variety of datasets spanning different tasks; here, we introduce these datasets and tasks and provide a brief recap of the standard LSTM+attention model used for them. We consider the tasks of Binary Text Classification, Natural Language Inference, Paraphrase Detection, and Question Answering. We use a total of 12 datasets, most of them being the same as the ones used in Jain and Wallace (2019). For convenience, we divide Text Classification into Sentiment Analysis and Other Text Classification.

Sentiment Analysis: We use the Stanford Sentiment Treebank (SST) Socher et al. (2013), IMDB Movie Reviews Maas et al. (2011), Yelp and Amazon datasets for sentiment analysis. All of these use a binary target variable (positive/negative).

Other Text Classification: We use the Twitter ADR dataset Nikfarjam et al. (2015) with 8K tweets, where the task is to detect whether a tweet describes an adverse drug reaction. We use a subset of the 20 Newsgroups dataset Jain and Wallace (2019) to classify news articles into the baseball vs. hockey sports categories. From MIMIC ICD9 Johnson et al. (2016), we use two datasets: Anemia, where the task is to determine the type of anemia (Chronic vs. Acute) a patient is diagnosed with, and Diabetes, where the task is to predict whether a patient is diagnosed with diabetes.

Natural Language Inference: We consider the SNLI dataset Bowman et al. (2015) for recognizing textual entailment within sentence pairs. The SNLI dataset has three classification labels: entailment, contradiction and neutral.

Paraphrase Detection: We utilize the Quora Question Paraphrase (QQP) dataset (part of the GLUE benchmark Wang et al. (2018)) with pairs of questions labeled as paraphrases or not. We split the original training set into training and validation sets and use the original dev set as our test set.

Question Answering: We use the first three QA tasks from the bAbI dataset Weston et al. (2015), which consist of answering questions that require one, two or three supporting statements from the context; the answers are spans in the context. We also use the CNN News Articles dataset Hermann et al. (2015), consisting of 90k articles with an average of three questions per article along with their corresponding answers.

2.1 LSTM Model with Attention

Of the above tasks, the text classification tasks require making predictions from a single input sequence (of words), whereas the remaining tasks take a pair of sequences as input. For tasks with two input sequences $P = (w^P_1, \dots, w^P_m)$ and $Q = (w^Q_1, \dots, w^Q_n)$, we encode both sequences by passing their word embeddings through an LSTM encoder Hochreiter and Schmidhuber (1997),

$$h^P_t = \text{LSTM}(e(w^P_t), h^P_{t-1}), \qquad h^Q_t = \text{LSTM}(e(w^Q_t), h^Q_{t-1})$$

where $e(w)$ represents the word embedding for the word $w$. We attend to the intermediate representations of $P$, using the last hidden state $h^Q_n$ as the query, with the attention mechanism of Bahdanau et al. (2014),

$$u_t = v^\top \tanh(W_1 h^P_t + W_2 h^Q_n), \qquad \alpha_t = \frac{\exp(u_t)}{\sum_j \exp(u_j)}, \qquad c_\alpha = \sum_{t=1}^{m} \alpha_t h^P_t$$

where $v$, $W_1$ and $W_2$ are learnable parameters. Finally, we use the attended context vector $c_\alpha$ to make a prediction $\hat{y}$.

For tasks with a single input sequence, we use a single LSTM to encode the sequence, followed by an attention mechanism (without query) and a final output projection layer.
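As a concrete reference, the following is a minimal PyTorch sketch of an LSTM encoder with the additive attention described above; the module and variable names are illustrative and not taken from the authors' released code.

```python
# Minimal sketch of an LSTM encoder with additive (Bahdanau-style) attention.
# Names are illustrative; this is not the authors' implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentiveLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, emb_dim=300, hid_dim=256, n_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, batch_first=True)
        self.W1 = nn.Linear(hid_dim, hid_dim, bias=False)   # projects hidden states
        self.W2 = nn.Linear(hid_dim, hid_dim, bias=False)   # projects the query
        self.v = nn.Linear(hid_dim, 1, bias=False)          # scoring vector
        self.out = nn.Linear(hid_dim, n_classes)

    def forward(self, tokens, query=None):
        # tokens: (batch, seq_len) integer ids
        h, _ = self.lstm(self.embedding(tokens))             # (batch, seq_len, hid)
        if query is None:                                    # single-sequence tasks
            query = h[:, -1]                                 # last hidden state as query
        scores = self.v(torch.tanh(self.W1(h) + self.W2(query).unsqueeze(1)))
        alpha = F.softmax(scores.squeeze(-1), dim=-1)        # attention distribution
        context = torch.bmm(alpha.unsqueeze(1), h).squeeze(1)  # attended context vector
        return self.out(context), alpha
```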

3 Analyzing Attention Mechanisms

Here, we first investigate the question: why may attention distributions not provide a faithful explanation for the model's predictions? We then examine whether attention distributions can at least provide a plausible, though not necessarily faithful, explanation for the model's predictions.

3.1 Similarity Measures

We begin by defining similarity measures in a vector space for ease of analysis. We measure the similarity between a set of vectors $V = \{v_1, \dots, v_m\}$, $v_i \in \mathbb{R}^d$, using the conicity measure Chandrahas et al. (2018); Sai et al. (2019). We first compute a vector $v_i$'s 'alignment to mean' (ATM),

$$\text{ATM}(v_i, V) = \text{cosine}\!\left(v_i, \frac{1}{m}\sum_{j=1}^{m} v_j\right)$$

Conicity is defined as the mean ATM over all vectors in $V$:

$$\text{conicity}(V) = \frac{1}{m}\sum_{i=1}^{m} \text{ATM}(v_i, V)$$

A high value of conicity indicates that all the vectors are closely aligned with their mean, i.e., they lie in a narrow cone centered at the origin.
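The following is a small NumPy sketch of these definitions; the toy vectors at the end are synthetic and only meant to illustrate the two regimes.

```python
# Sketch of the conicity measure: each vector's "alignment to mean" (ATM) is
# its cosine similarity to the mean vector; conicity is the average ATM.
import numpy as np

def conicity(vectors: np.ndarray) -> float:
    """vectors: (m, d) array of m hidden states of dimension d."""
    mean = vectors.mean(axis=0)
    atm = (vectors @ mean) / (
        np.linalg.norm(vectors, axis=1) * np.linalg.norm(mean) + 1e-8
    )
    return float(atm.mean())

# Vectors lying in a narrow cone give conicity close to 1; isotropic random
# vectors in a high-dimensional space give conicity close to 0.
narrow = np.random.randn(100, 256) * 0.1 + np.ones(256)
isotropic = np.random.randn(100, 256)
print(conicity(narrow), conicity(isotropic))
```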

3.2 Attention Mechanisms

As mentioned earlier, attention mechanisms learn a weighting distribution $\alpha = (\alpha_1, \dots, \alpha_m)$ over the hidden states $h_1, \dots, h_m$ using a scoring function such as that of Bahdanau et al. (2014), and obtain an attended context vector $c_\alpha = \sum_t \alpha_t h_t$.

The attended context vector is a convex combination of the hidden states, which means it always lies within the cone spanned by them. When the hidden states are highly similar to each other (high conicity), even very different attention distributions produce similar attended context vectors, as these vectors always lie within a narrow cone. This can result in outputs with very little difference: when the conicity of the hidden states is high, the model can produce the same prediction for several diverse sets of attention weights. In such cases, one cannot reliably say that high attention weights on certain input components led the model to its prediction. Later, in section 5.3, we show that with vanilla LSTM encoders, where the conicity of the hidden states is high, the model output changes very little even when we randomly permute the attention weights.
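The following toy snippet illustrates this argument numerically; the hidden states and attention weights are synthetic, not taken from a trained model.

```python
# Permuting attention weights barely moves the context vector when hidden
# states are nearly collinear (high conicity), but moves it substantially
# when they are diverse.
import numpy as np

rng = np.random.default_rng(0)
alpha = rng.dirichlet(np.ones(10))          # an attention distribution
perm = rng.permutation(alpha)               # a random permutation of it

def context_shift(hidden):
    c1 = alpha @ hidden                     # original context vector
    c2 = perm @ hidden                      # context vector after permuting attention
    return np.linalg.norm(c1 - c2) / np.linalg.norm(c1)

similar = np.ones((10, 64)) + 0.01 * rng.standard_normal((10, 64))  # narrow cone
diverse = rng.standard_normal((10, 64))                             # spread out
print(context_shift(similar))   # ~0: the prediction is unlikely to change
print(context_shift(diverse))   # large relative shift
```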

Figure 1: Left: high conicity of hidden states results in similar attended context vectors. Right: low conicity of hidden states results in very different context vectors

3.3 Conicity of LSTM Hidden States

We now analyze whether the hidden states learned by an LSTM encoder do in fact have high conicity. In Table 2, we report the average conicity of hidden states learned by an LSTM encoder for various tasks and datasets. For reference, we also compute the average conicity obtained by vectors that are uniformly distributed with respect to direction (isotropic) in the same hidden space. We observe that across all the datasets the hidden states are consistently aligned with each other, with conicity values ranging from 0.43 to 0.77. In contrast, for the isotropic vectors the conicity values were much lower, with the vectors even being almost orthogonal to their mean in several cases (0.02 on the Diabetes and Anemia datasets). The existence of high conicity in the learned hidden states of an LSTM encoder is one potential reason why the attention weights in these models are not always faithful to its predictions, as even random permutations of the attention weights will result in similar context vectors.

3.4 Attention by POS Tags

We now examine whether attention distributions can provide a plausible explanation for the model's predictions even if it is not a faithful one. Intuitively, a plausible explanation should ignore unimportant tokens such as punctuation marks and focus on words relevant to the task. To examine this, we categorize the words in the input sentence by their universal part-of-speech (POS) tags Petrov et al. (2011) and accumulate the attention given to each POS tag over the entire test set. Surprisingly, we find that in several datasets a significant amount of attention is given to punctuation. On the Yelp, Amazon and QQP datasets, the attention mechanism pays 28.6%, 34.0% and 23.0% of its total attention to punctuation, even though punctuation constitutes only 11.0%, 10.5% and 11.6% of the total tokens in the respective datasets. In other words, the learned attention distributions pay substantially more attention to punctuation than even a uniform distribution would. This raises questions about the extent to which attention distributions provide plausible explanations, as they attribute the model's predictions to tokens that are linguistically insignificant in the context.

One potential reason why the attention distributions are misaligned is that the hidden states may capture a summary of the entire context instead of being specific to their corresponding words, as suggested by the high conicity. We later show that attention distributions in our models, which have low conicity, tend to ignore punctuation marks.
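The following sketch shows one way such a POS analysis could be computed, using NLTK's universal tagset; the `examples` iterable of (tokens, attention weights) pairs is a hypothetical stand-in for the model's outputs.

```python
# Accumulate attention mass per universal POS tag over a dataset.
# The `examples` iterable is hypothetical and would come from a trained model.
from collections import defaultdict
import nltk

nltk.download("averaged_perceptron_tagger", quiet=True)
nltk.download("universal_tagset", quiet=True)

def attention_by_pos(examples):
    totals = defaultdict(float)
    for tokens, weights in examples:
        for (_, tag), w in zip(nltk.pos_tag(tokens, tagset="universal"), weights):
            totals[tag] += float(w)                 # attention mass for this tag
    total_mass = sum(totals.values())
    return {tag: mass / total_mass for tag, mass in sorted(totals.items())}

examples = [(["great", "food", "!"], [0.6, 0.3, 0.1])]   # toy example
print(attention_by_pos(examples))   # e.g. {'.': 0.1, 'ADJ': 0.6, 'NOUN': 0.3}
```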

4 Orthogonal and Diversity LSTM

Based on our earlier argument that high conicity of hidden states affects the transparency and explainability of attention models, we propose two strategies to reduce the similarity between hidden states.

Figure 2: Orthogonal LSTM: Hidden state at a timestep is orthogonal to the mean of previous hidden states

4.1 Orthogonalization

Here, we explicitly enforce low conicity between the hidden states of an LSTM encoder by orthogonalizing the hidden state at time $t$ with respect to the mean of the previous hidden states, as illustrated in Figure 2. We use the following set of update equations:

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$
$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$
$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$
$$\hat{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)$$
$$c_t = f_t \odot c_{t-1} + i_t \odot \hat{c}_t$$
$$\tilde{h}_t = o_t \odot \tanh(c_t)$$
$$\bar{h}_t = \frac{1}{t-1}\sum_{i=1}^{t-1} h_i$$
$$h_t = \tilde{h}_t - \frac{\tilde{h}_t^\top \bar{h}_t}{\bar{h}_t^\top \bar{h}_t}\,\bar{h}_t$$

where $W_\ast \in \mathbb{R}^{d_h \times d_i}$, $U_\ast \in \mathbb{R}^{d_h \times d_h}$, $b_\ast \in \mathbb{R}^{d_h}$, and $d_i$ and $d_h$ are the input and hidden dimensions respectively. The key difference from a vanilla LSTM is in the last two equations, where we subtract from the candidate hidden state $\tilde{h}_t$ its component along the mean $\bar{h}_t$ of the previous hidden states.
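A minimal PyTorch sketch of this orthogonalization step is given below (not the authors' implementation); inside the recurrence it would replace the usual emission of the hidden state $h_t = o_t \odot \tanh(c_t)$.

```python
# Sketch of the orthogonalization step: project the candidate hidden state
# off the mean of the previously emitted hidden states before emitting it.
import torch

def orthogonalize_step(h_tilde, prev_states):
    """h_tilde: (batch, hid) candidate hidden state o_t * tanh(c_t).
    prev_states: list of previously emitted hidden states, each (batch, hid)."""
    if not prev_states:
        return h_tilde                               # nothing to orthogonalize against
    h_bar = torch.stack(prev_states).mean(dim=0)     # mean of h_1 .. h_{t-1}
    coeff = (h_tilde * h_bar).sum(-1, keepdim=True) / (
        (h_bar * h_bar).sum(-1, keepdim=True) + 1e-8
    )
    return h_tilde - coeff * h_bar                   # remove component along the mean
```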

4.2 Diversity Driven Training

The above model imposes a hard orthogonality constraint between the hidden states and the mean of the previous states. We also propose a more flexible approach in which the model is jointly trained to maximize the log-likelihood of the training data and minimize the conicity of the hidden states,

$$\mathcal{L}(\theta) = -\log p(y \mid P, Q; \theta) + \lambda\,\text{conicity}(H)$$

where $y$ is the ground truth class, $P$ and $Q$ are the input sentences, $H$ contains all the hidden states of the LSTM, $\theta$ is the collection of model parameters, and $p(y \mid P, Q; \theta)$ is the model's output probability. The hyperparameter $\lambda$ controls the weight given to diversity in the hidden states during training.
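Below is a minimal PyTorch sketch of this objective, assuming a padding mask over the sequence is available; `lambda_div` corresponds to the diversity weight and the function names are illustrative, not the authors' API.

```python
# Diversity-driven objective: cross-entropy plus a weighted conicity penalty
# on the encoder hidden states.
import torch
import torch.nn.functional as F

def conicity(hidden, mask):
    """hidden: (batch, seq_len, hid); mask: (batch, seq_len), 1 for real tokens."""
    mask = mask.unsqueeze(-1).float()
    mean = (hidden * mask).sum(1, keepdim=True) / mask.sum(1, keepdim=True)
    atm = F.cosine_similarity(hidden, mean.expand_as(hidden), dim=-1)  # alignment to mean
    return (atm * mask.squeeze(-1)).sum(1) / mask.squeeze(-1).sum(1)

def diversity_loss(logits, labels, hidden, mask, lambda_div=0.5):
    nll = F.cross_entropy(logits, labels)            # negative log-likelihood term
    return nll + lambda_div * conicity(hidden, mask).mean()
```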

Dataset | LSTM: Accuracy, Conicity | Diversity LSTM: Accuracy, Conicity | Orthogonal LSTM: Accuracy, Conicity | Random: Conicity | MLP: Accuracy
Binary Classification
SST 81.79 0.68 79.95 0.20 80.05 0.28 0.25 80.05
IMDB 89.49 0.69 88.54 0.08 88.71 0.18 0.08 88.29
Yelp 95.60 0.53 95.40 0.06 96.00 0.18 0.14 92.85
Amazon 93.73 0.50 92.90 0.05 93.04 0.16 0.13 87.88
Anemia 88.54 0.46 90.09 0.09 90.17 0.12 0.02 88.27
Diabetes 92.31 0.61 91.99 0.08 87.05 0.12 0.02 85.39
20News 93.55 0.77 91.03 0.15 92.15 0.23 0.13 87.68
Tweets 87.02 0.77 87.04 0.24 83.20 0.27 0.24 80.60
Natural Language Inference
SNLI 78.23 0.56 76.96 0.12 76.46 0.27 0.27 75.35
Paraphrase Detection
QQP 78.74 0.59 78.40 0.04 78.61 0.33 0.30 77.78
Question Answering
bAbI 1 99.10 0.56 100.00 0.07 99.90 0.22 0.19 42.00
bAbI 2 40.10 0.48 40.20 0.05 56.10 0.21 0.12 33.20
bAbI 3 47.70 0.43 50.90 0.10 51.20 0.12 0.07 31.60
CNN 63.07 0.45 58.19 0.06 54.30 0.07 0.04 37.40
Table 2: Accuracy and conicity of the Vanilla, Diversity and Orthogonal LSTMs across different datasets. The accuracy of a Multilayer Perceptron (MLP) model and the conicity of vectors uniformly distributed with respect to direction (Random) are also reported for reference.

Figure 3: Box plots of the fraction of hidden representations that must be removed for a decision flip. The dataset and model are indicated at the top and bottom of each figure. Blue and yellow indicate the attention-based and random rankings respectively.
Figure 4: Median output difference on randomly permuting the attention weights in the vanilla, Diversity and Orthogonal LSTM models. Dataset names are given at the top of each figure; colors indicate the different models as shown in the legend.

5 Analysis of the model

We now analyze the proposed models by performing experiments on the tasks and datasets described earlier. Through these experiments, we establish that (i) the proposed models perform comparably to vanilla LSTMs (Sec. 5.2), (ii) the attention distributions in the proposed models provide a faithful explanation for the model's predictions (Secs. 5.3 to 5.5), and (iii) the attention distributions are more explainable and align better with a human's interpretation of the model's prediction (Secs. 5.6, 5.7). Throughout this section, we compare the following three models:

1. Vanilla LSTM: The model described in section 2.1 which uses the vanilla LSTM.
2. Diversity LSTM: The model described in section 2.1 with the vanilla LSTM but trained with the diversity objective described in section 4.2.
3. Orthogonal LSTM: The model described in section 2.1 except that the vanilla LSTM is replaced by the orthogonal LSTM described in section 4.1.

5.1 Implementation Details

For all datasets except bAbI, we use either pre-trained GloVe Pennington et al. (2014) or fastText Mikolov et al. (2018) word embeddings with 300 dimensions. For the bAbI datasets, we learn 50-dimensional word embeddings from scratch during training. We use a 1-layer LSTM encoder with a hidden size of 128 for bAbI and 256 for the other datasets. For the diversity weight $\lambda$, we use a value of 0.1 for SNLI, 0.2 for CNN, and 0.5 for the remaining datasets. We use the Adam optimizer with a learning rate of 0.001 and select the best model based on accuracy on the validation split. All subsequent analyses are performed on the test split.

5.2 Empirical evaluation

Our main goal is to show that our proposed models provide more faithful and plausible explanations for their predictions. Before doing so, however, we need to show that the predictive performance of our models is comparable to that of a vanilla LSTM model and significantly better than that of non-contextual models; in other words, that we do not compromise on performance to gain transparency and explainability. In Table 2, we report the accuracy and conicity values of the vanilla, Diversity and Orthogonal LSTMs on the tasks and datasets described in section 2. We observe that the performance of the Diversity LSTM is comparable to that of the vanilla LSTM, with accuracy values within -7.7% to +6.7% (relative) of the vanilla model's accuracy. However, there is a substantial decrease in conicity, with a relative drop of 70.6% to 93.2% compared to the vanilla model. Similarly, for the Orthogonal LSTM, the predictive performance is mostly comparable, except for a 39.9% increase in accuracy on bAbI 2 and a 13.91% drop on CNN; as with the Diversity LSTM, the conicity values are much lower than in the vanilla model. We also report the performance of a non-contextual model, a Multilayer Perceptron (MLP) with attention, in the same table. Both the Diversity and Orthogonal LSTMs perform significantly better than the MLP model, especially on harder tasks such as Question Answering, with an average relative increase in accuracy of 73.73%. Having established that the performance of the Diversity and Orthogonal LSTMs is comparable to the vanilla LSTM and significantly better than an MLP, we now show that these two models give more faithful explanations for their predictions.

5.3 Importance of Hidden Representation

We examine whether attention weights provide a useful importance ranking of the hidden representations. We use the intermediate representation erasure of Serrano and Smith (2019) to evaluate an importance ranking over hidden representations: we erase the hidden representations in decreasing order of importance (highest to lowest) until the model's decision changes. In Figure 3, we report box plots of the fraction of hidden representations that must be erased for a decision flip when following the ranking given by the attention weights; for reference, we also show the same plots for a random ranking. In several datasets, we observe that a large fraction of the representations has to be erased to obtain a decision flip in the vanilla LSTM model, similar to the observation of Serrano and Smith (2019). This suggests that the hidden representations at the lower end of the attention ranking still play a significant role in the vanilla LSTM model's decision-making process, and hence the usefulness of the attention ranking in such models is questionable. In contrast, the decision flips much sooner in our Diversity and Orthogonal LSTM models. Thus, in our proposed models, the top elements of the attention ranking concisely describe the model's decisions, suggesting that the attention weights provide a faithful explanation of the model's predictions (as higher attention implies higher importance).
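The following sketch outlines this erasure test; the `model(tokens, hidden_mask)` interface is a hypothetical hook for zeroing out selected hidden states before attention is applied, not the authors' API.

```python
# Erase hidden states in decreasing order of attention until the predicted
# class flips, and report the fraction erased at the flip.
import torch

def fraction_for_decision_flip(model, tokens, attention):
    """model(tokens, hidden_mask) -> logits; hidden_mask zeroes out hidden states."""
    with torch.no_grad():
        base_pred = model(tokens, hidden_mask=None).argmax(-1)
        order = attention.argsort(descending=True)     # most-attended first
        mask = torch.ones_like(attention)
        for k, idx in enumerate(order, start=1):
            mask[idx] = 0.0                            # erase this hidden state
            if model(tokens, hidden_mask=mask).argmax(-1) != base_pred:
                return k / len(order)                  # fraction erased at the flip
        return 1.0                                     # no decision flip occurred
```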

In tasks such as paraphrase detection, the model naturally has to go through the entire sentence carefully to make a decision, resulting in delayed decision flips. In the QA tasks, the attention ranking in the vanilla LSTM model itself achieves a quick decision flip. On further inspection, we found that this is because these models tend to attend to the answer words, which are usually a span in the input passage; when the representations corresponding to the answer words are erased, the model can no longer predict the answer accurately, resulting in a decision flip.

Following Jain and Wallace (2019), we randomly permute the attention weights and observe the difference in the model's output. In Figure 4, we plot the median Total Variation Distance (TVD) between the output distributions before and after the permutation, for different values of the maximum attention, in the vanilla, Diversity and Orthogonal LSTM models. We observe that randomly permuting the attention weights in the Diversity and Orthogonal LSTM models results in significantly different outputs, whereas for several datasets there is little change in the vanilla LSTM model's output, suggesting that its attention weights are not very meaningful. The sensitivity of our attention weights to random permutations again suggests that they provide a more faithful explanation of the model's predictions, whereas the near-identical outputs raise serious questions about the reliability of attention weights in the vanilla LSTM model.
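A sketch of this permutation test is given below; `predict_from_attention` is a hypothetical hook that recomputes the output distribution from supplied attention weights.

```python
# Median total variation distance between the output with the original
# attention and with randomly permuted attention.
import torch

def tvd(p, q):
    return 0.5 * (p - q).abs().sum(-1)

def permutation_sensitivity(predict_from_attention, hidden, attention, n_perm=100):
    """predict_from_attention(hidden, attention) -> output probability vector."""
    base = predict_from_attention(hidden, attention)
    diffs = []
    for _ in range(n_perm):
        permuted = attention[torch.randperm(attention.size(-1))]
        diffs.append(tvd(base, predict_from_attention(hidden, permuted)))
    return torch.stack(diffs).median()        # median output TVD, as in Figure 4
```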

5.4 Comparison with Rationales

Dataset | Vanilla LSTM: Rationale Attention, Rationale Length | Diversity LSTM: Rationale Attention, Rationale Length
SST 0.348 0.240 0.624 0.175
IMDB 0.472 0.217 0.761 0.169
Yelp 0.438 0.173 0.574 0.160
Amazon 0.346 0.162 0.396 0.240
Anemia 0.611 0.192 0.739 0.237
Diabetes 0.742 0.458 0.825 0.354
20News 0.627 0.215 0.884 0.173
Tweets 0.284 0.225 0.764 0.306
Table 3: Mean attention given to the generated rationales, along with their mean lengths (as a fraction of the sentence length)
Dataset | Pearson Correlation with Gradients (Mean, Std.): Vanilla, Diversity | Pearson Correlation with Integrated Gradients (Mean, Std.): Vanilla, Diversity | JS Divergence with Gradients (Mean, Std.): Vanilla, Diversity | JS Divergence with Integrated Gradients (Mean, Std.): Vanilla, Diversity
Text Classification
SST 0.71 0.21 0.83 0.19 0.62 0.24 0.79 0.22 0.10 0.04 0.08 0.05 0.12 0.05 0.09 0.05
IMDB 0.80 0.07 0.89 0.04 0.68 0.09 0.78 0.07 0.09 0.02 0.09 0.01 0.13 0.02 0.13 0.02
Yelp 0.55 0.16 0.79 0.12 0.40 0.19 0.79 0.14 0.15 0.04 0.13 0.04 0.19 0.05 0.19 0.05
Amazon 0.43 0.19 0.77 0.14 0.43 0.19 0.77 0.14 0.17 0.04 0.12 0.04 0.21 0.06 0.12 0.04
Anemia 0.63 0.12 0.72 0.10 0.43 0.15 0.66 0.11 0.20 0.04 0.19 0.03 0.34 0.05 0.23 0.04
Diabetes 0.65 0.15 0.76 0.13 0.55 0.14 0.69 0.18 0.26 0.05 0.20 0.04 0.36 0.04 0.24 0.06
20News 0.72 0.28 0.96 0.08 0.65 0.32 0.67 0.11 0.15 0.07 0.06 0.04 0.21 0.06 0.07 0.05
Tweets 0.65 0.24 0.80 0.21 0.56 0.25 0.74 0.22 0.08 0.03 0.12 0.07 0.08 0.04 0.15 0.06
Natural Language Inference
SNLI 0.58 0.33 0.51 0.35 0.38 0.40 0.26 0.39 0.11 0.07 0.10 0.06 0.16 0.09 0.13 0.06
Paraphrase Detection
QQP 0.19 0.34 0.58 0.31 -0.06 0.34 0.21 0.36 0.15 0.08 0.10 0.05 0.19 0.10 0.15 0.06
Question Answering
Babi 1 0.56 0.34 0.91 0.10 0.33 0.37 0.91 0.10 0.33 0.12 0.21 0.08 0.43 0.13 0.24 0.08
Babi 2 0.16 0.23 0.70 0.13 0.05 0.22 0.75 0.10 0.53 0.09 0.23 0.06 0.58 0.09 0.19 0.05
Babi 3 0.39 0.24 0.67 0.19 -0.01 0.08 0.47 0.25 0.46 0.08 0.37 0.07 0.64 0.05 0.41 0.08
CNN 0.58 0.25 0.75 0.20 0.45 0.28 0.66 0.23 0.22 0.07 0.17 0.08 0.30 0.10 0.21 0.10
Table 4: Mean and standard deviation of the Pearson correlation and Jensen-Shannon divergence between attention weights and Gradients/Integrated Gradients in the Vanilla and Diversity LSTM models

For tasks with a single input sentence, we analyze how much attention is given to words in the sentence that are important for the prediction. Specifically, we select a minimal subset of words in the input sentence with which the model can still make accurate predictions, and compute the total attention paid to these words. This set of words, known as a rationale, is obtained from an extractive rationale generator Lei et al. (2016) trained with the REINFORCE algorithm Sutton et al. (1999) to maximize the following reward:

$$R = \log p(y \mid r; \theta) - \lambda_r\,|r|$$

where $y$ is the ground truth class, $r$ is the extracted rationale, $|r|$ is the length of the rationale, $p(y \mid r; \theta)$ is the classification model's output probability, and $\lambda_r$ is a hyperparameter that penalizes long rationales. With a fixed $\lambda_r$, we trained generators to extract rationales from the vanilla and Diversity LSTM models; the accuracy of predictions made from the extracted rationales was within 5% of the accuracy obtained from the entire sentences. In Table 3, we report the mean length (as a fraction) of the rationales and the mean attention given to them in the vanilla and Diversity LSTM models. In general, the Diversity LSTM gives much higher attention to rationales, which are often even shorter than the vanilla LSTM's rationales. On average, the Diversity LSTM gives 53.52% (relative) more attention to rationales than the vanilla LSTM across the 8 text classification datasets. Thus, the attention weights in the Diversity LSTM better indicate the words that are important for making predictions.
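As an illustration, a toy sketch of such a reward might look as follows, assuming the length penalty is applied to the fraction of words kept in the rationale; the exact functional form used by the generator may differ.

```python
# Toy sketch of a rationale reward: classifier confidence on the true class
# given only the rationale, minus a length penalty weighted by lambda_len.
import torch

def rationale_reward(log_prob_true_class, rationale_mask, lambda_len=0.1):
    """log_prob_true_class: log p(y | r) from the classifier on the rationale.
    rationale_mask: (seq_len,) binary mask marking words kept in the rationale."""
    length_fraction = rationale_mask.float().mean()   # rationale length as a fraction
    return log_prob_true_class - lambda_len * length_fraction
```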

5.5 Comparison with attribution methods

We now examine how well our attention weights agree with attribution methods such as gradients and Integrated Gradients Sundararajan et al. (2017). For every input word, we compute these attributions and normalize them to obtain a distribution over the input words. We then compute the Pearson correlation and the Jensen-Shannon (JS) divergence between the attribution distribution and the attention distribution. We note that the Kendall τ correlation used by Jain and Wallace (2019) often results in misleading correlations because the rankings at the tail ends of the distributions contribute significant noise. In Table 4, we report the mean and standard deviation of these Pearson correlations and JS divergences for the vanilla and Diversity LSTMs across different datasets. We observe that attention weights in the Diversity LSTM agree better with gradients, with an average (relative) 64.84% increase in Pearson correlation and an average (relative) 17.18% decrease in JS divergence over the vanilla LSTM across the datasets. Similar trends hold for Integrated Gradients.
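A small sketch of these agreement metrics, using SciPy, is shown below; the example arrays are synthetic.

```python
# Pearson correlation and Jensen-Shannon divergence between an attention
# distribution and a normalized gradient-based attribution over the same tokens.
import numpy as np
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon

def agreement(attention, attribution):
    attribution = np.abs(attribution)
    attribution = attribution / attribution.sum()            # normalize to a distribution
    corr, _ = pearsonr(attention, attribution)
    jsd = jensenshannon(attention, attribution, base=2) ** 2  # squared distance = divergence
    return corr, jsd

print(agreement(np.array([0.7, 0.2, 0.1]), np.array([0.5, 0.3, 0.2])))
```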

5.6 Analysis by POS tags

Figure 5 shows the distribution of attention given to different POS tags across datasets. We observe that the attention given to punctuation is significantly reduced, from 28.6%, 34.0% and 23.0% in the vanilla LSTM to 3.1%, 13.8% and 3.4% in the Diversity LSTM on the Yelp, Amazon and QQP datasets respectively. In the sentiment classification task, the Diversity LSTM pays greater attention to adjectives, which usually play a crucial role in deciding the polarity of a sentence; across the four sentiment analysis datasets, it gives an average of 49.27% (relative) more attention to adjectives than the vanilla LSTM. Similarly, for the other text classification tasks, where nouns play an important role, we observe higher attention to nouns.

Figure 5: Distribution of cumulative attention given to different part-of-speech tags in the test dataset. Blue and Orange indicate the vanilla and Diversity LSTMs.
Dataset | Overall (Vanilla / Diversity) | Completeness (Vanilla / Diversity) | Correctness (Vanilla / Diversity)
Yelp 27.7% / 72.3% 35.1% / 64.9% 10.5% / 89.5%
SNLI 37.8% / 62.2% 32.3% / 67.7% 38.9% / 61.1%
QQP 11.6% / 88.4% 11.8% / 88.2% 7.9% / 92.1%
bAbI 1 1.0% / 99.0% 4.2% / 95.8% 1.0% / 99.0%
Table 5: Percentage preference given to Vanilla vs Diversity model by human annotators based on 3 criteria

5.7 Human Evaluations

We conducted human evaluations to compare the extent to which attention distributions from the vanilla and Diversity LSTMs provide plausible explanations. We randomly sampled 200 data points each from the test sets of Yelp, SNLI, QQP, and bAbI 1. Annotators were shown the input sentence, the attention heatmaps, and the predictions made by the vanilla and Diversity LSTMs, and were asked to choose the attention heatmap that better explained the model's prediction on three criteria: 1) Overall - which heatmap better explains the prediction overall; 2) Completeness - which heatmap highlights all the words necessary for the prediction; 3) Correctness - which heatmap highlights only the important words and not unnecessary ones. Annotators could skip a sample if they were unable to make a clear decision. A total of 15 in-house annotators, all Computer Science graduates competent in English, participated in the study. Each sample was judged by multiple annotators and the final decision was taken by majority voting. In Table 5, we report the percentage preference given to the vanilla and Diversity LSTM models on the Yelp, SNLI, QQP, and bAbI 1 datasets; the attention distributions from the Diversity LSTM are preferred significantly more often than those from the vanilla LSTM across all datasets and criteria.

6 Related work

Our work can in many ways be seen as a continuation of the recent studies Serrano and Smith (2019); Jain and Wallace (2019); Wiegreffe and Pinter (2019) on the interpretability of attention. Several other works Shao et al. (2019); Martins and Astudillo (2016); Malaviya et al. (2018); Niculae and Blondel (2017); Maruf et al. (2019); Peters et al. (2018) focus on improving the interpretability of attention distributions by inducing sparsity, but the extent to which sparse attention distributions actually offer faithful and plausible explanations has not been studied in detail. A few works, such as Bao et al. (2018), map attention distributions to human-annotated rationales; our work, on the other hand, does not require any additional supervision. Guo et al. (2019) focus on developing interpretable LSTMs specifically for multivariate time series analysis. Several other works Clark et al. (2019); Vig and Belinkov (2019); Tenney et al. (2019); Michel et al. (2019); Jawahar et al. (2019); Tsai et al. (2019) analyze the attention distributions and attention heads learned by transformer language models. The idea of orthogonalizing representations in an LSTM has been used by Nema et al. (2017), but they use a different diversity model in the context of improving the performance of Natural Language Generation models.

7 Conclusion & Future work

In this work, we have analyzed why existing attention distributions can neither provide a faithful nor a plausible explanation for the model’s predictions. We showed that hidden representations learned by LSTM encoders tend to be highly similar across different timesteps, thereby affecting the interpretability of attention weights. We proposed two techniques to effectively overcome this shortcoming and showed that attention distributions in the resulting models provide more faithful and plausible explanations. As future work, we would like to extend our analysis and proposed techniques to more complex models and downstream tasks.

Acknowledgements

We would like to thank Department of Computer Science and Engineering, IIT Madras and Robert Bosch Center for Data Sciences and Artificial Intelligence, IIT Madras (RBC-DSAI) for providing us sufficient resources. We acknowledge Google for supporting Preksha Nema’s contribution through their Google India Ph.D. fellowship program. We also express our gratitude to the annotators who participated in human evaluations.

References

  • D. Bahdanau, K. Cho, and Y. Bengio (2014) Neural machine translation by jointly learning to align and translate. CoRR abs/1409.0473. Cited by: §1, §2.1, §3.2.
  • Y. Bao, S. Chang, M. Yu, and R. Barzilay (2018) Deriving machine attention from human rationales. In EMNLP, Cited by: §6.
  • S. R. Bowman, G. Angeli, C. Potts, and C. D. Manning (2015) A large annotated corpus for learning natural language inference. In EMNLP, Cited by: §2.
  • Chandrahas, A. Sharma, and P. P. Talukdar (2018) Towards understanding the geometry of knowledge graph embeddings. In ACL, Cited by: §3.1.
  • K. Clark, U. Khandelwal, O. Levy, and C. D. Manning (2019) What does bert look at? an analysis of bert’s attention. ArXiv abs/1906.04341. Cited by: §6.
  • T. Guo, T. Lin, and N. Antulov-Fantulin (2019) Exploring interpretable lstm neural networks over multi-variable data. In ICML, Cited by: §6.
  • K. M. Hermann, T. Kociský, E. Grefenstette, L. Espeholt, W. Kay, M. Suleyman, and P. Blunsom (2015) Teaching machines to read and comprehend. In NIPS, Cited by: §2.
  • S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural Computation 9, pp. 1735–1780. Cited by: §2.1.
  • S. Jain and B. C. Wallace (2019) Attention is not explanation. In NAACL-HLT, Cited by: §1, §2, §2, §5.3, §5.5, §6.
  • G. Jawahar, B. Sagot, and D. Seddah (2019) What does bert learn about the structure of language?. In ACL, Cited by: §6.
  • A. E. W. Johnson, T. J. Pollard, L. Shen, L. H. Lehman, M. Feng, M. M. Ghassemi, B. Moody, P. Szolovits, L. A. Celi, and R. G. Mark (2016) MIMIC-iii, a freely accessible critical care database. In Scientific data, Cited by: §2.
  • T. Lei, R. Barzilay, and T. S. Jaakkola (2016) Rationalizing neural predictions. In EMNLP, Cited by: §5.4.
  • A. L. Maas, R. E. Daly, P. T. Pham, D. Huang, A. Y. Ng, and C. Potts (2011) Learning word vectors for sentiment analysis. In ACL, Cited by: §2.
  • C. Malaviya, P. Ferreira, and A. F. T. Martins (2018) Sparse and constrained attention for neural machine translation. In ACL, Cited by: §6.
  • A. F. T. Martins and R. F. Astudillo (2016) From softmax to sparsemax: a sparse model of attention and multi-label classification. ArXiv abs/1602.02068. Cited by: §6.
  • S. Maruf, A. F. T. Martins, and G. Haffari (2019) Selective attention for context-aware neural machine translation. In NAACL-HLT, Cited by: §6.
  • P. Michel, O. Levy, and G. Neubig (2019) Are sixteen heads really better than one?. ArXiv abs/1905.10650. Cited by: §6.
  • T. Mikolov, E. Grave, P. Bojanowski, C. Puhrsch, and A. Joulin (2018) Advances in pre-training distributed word representations. In Proceedings of the International Conference on Language Resources and Evaluation (LREC 2018), Cited by: §5.1.
  • P. Nema, M. M. Khapra, A. Laha, and B. Ravindran (2017) Diversity driven attention model for query-based abstractive summarization. In ACL, Cited by: §6.
  • V. Niculae and M. Blondel (2017) A regularized framework for sparse and structured neural attention. In NIPS, Cited by: §6.
  • A. Nikfarjam, A. Sarker, K. O’Connor, R. E. Ginn, and G. Gonzalez-Hernandez (2015) Pharmacovigilance from social media: mining adverse drug reaction mentions using sequence labeling with word embedding cluster features. In JAMIA, Cited by: §2.
  • J. Pennington, R. Socher, and C. D. Manning (2014) GloVe: global vectors for word representation. In Empirical Methods in Natural Language Processing (EMNLP), pp. 1532–1543. Cited by: §5.1.
  • B. Peters, V. Niculae, and A. F. T. Martins (2018) Interpretable structure induction via sparse attention. In BlackboxNLP@EMNLP, Cited by: §6.
  • S. Petrov, D. Das, and R. T. McDonald (2011) A universal part-of-speech tagset. In LREC, Cited by: §3.4.
  • A. Sai, M. D. Gupta, M. M. Khapra, and M. Srinivasan (2019) Re-evaluating adem: a deeper look at scoring dialogue responses. In AAAI, Cited by: §3.1.
  • S. Serrano and N. A. Smith (2019) Is attention interpretable?. In ACL, Cited by: §1, §5.3, §6.
  • W. Shao, T. Meng, J. Li, R. Zhang, Y. Li, X. Wang, and P. Luo (2019) SSN: learning sparse switchable normalization via sparsestmax. In CVPR, Cited by: §6.
  • R. Socher, A. Perelygin, J. Wu, J. Chuang, C. D. Manning, A. Y. Ng, and C. Potts (2013) Recursive deep models for semantic compositionality over a sentiment treebank. In EMNLP, Cited by: §2.
  • M. Sundararajan, A. Taly, and Q. Yan (2017) Axiomatic attribution for deep networks. In ICML, Cited by: §5.5.
  • R. S. Sutton, D. A. McAllester, S. P. Singh, and Y. Mansour (1999) Policy gradient methods for reinforcement learning with function approximation. In NIPS, Cited by: §5.4.
  • I. Tenney, D. Das, and E. Pavlick (2019) BERT rediscovers the classical nlp pipeline. In ACL, Cited by: §6.
  • Y. Tsai, S. Bai, M. Yamada, L. Morency, and R. Salakhutdinov (2019) Empirical study of transformer’s attention mechanism via the lens of kernel. In IJCNLP 2019, Cited by: §6.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS, Cited by: §1.
  • J. Vig and Y. Belinkov (2019) Analyzing the structure of attention in a transformer language model. ArXiv abs/1906.04284. Cited by: §6.
  • A. Wang, A. Singh, J. Michael, F. Hill, O. Levy, and S. R. Bowman (2018) GLUE: a multi-task benchmark and analysis platform for natural language understanding. In BlackboxNLP@EMNLP, Cited by: §2.
  • J. Weston, A. Bordes, S. Chopra, and T. Mikolov (2015) Towards ai-complete question answering: a set of prerequisite toy tasks. CoRR abs/1502.05698. Cited by: §2.
  • S. Wiegreffe and Y. Pinter (2019) Attention is not not explanation. ArXiv abs/1908.04626. Cited by: §1, §6.