Stochastic Answer Networks for Natural Language Inference

04/21/2018 ∙ by Xiaodong Liu, et al. ∙ Microsoft, Johns Hopkins University

We propose a stochastic answer network (SAN) to explore multi-step inference strategies in Natural Language Inference. Rather than directly predicting the results given the inputs, the model maintains a state and iteratively refines its predictions. Our experiments show that SAN achieves state-of-the-art results on three benchmarks: the Stanford Natural Language Inference (SNLI) dataset, the Multi-Genre Natural Language Inference (MultiNLI) dataset and the Quora Question Pairs dataset.




1 Motivation

The natural language inference task, also known as recognizing textual entailment (RTE), is to infer the relation between a pair of sentences (e.g., a premise and a hypothesis). This task is challenging, since it requires a model to fully understand sentence meaning (i.e., lexical and compositional semantics). For instance, the following example from the MultiNLI dataset Williams et al. (2017) illustrates the need for a form of multi-step synthesis of information between the premise “If you need this book, it is probably too late unless you are about to take an SAT or GRE.” and the hypothesis “It’s never too late, unless you’re about to take a test.” To predict the correct relation between these two sentences, the model needs to first infer that “SAT or GRE” is a “test”, and then pick the correct relation, here contradiction.

This kind of iterative process can be viewed as a form of multi-step inference. To the best of our knowledge, all prior work on NLI uses single-step inference. Inspired by the recent success of multi-step inference on Machine Reading Comprehension (MRC) Hill et al. (2016); Dhingra et al. (2016); Sordoni et al. (2016); Kumar et al. (2015); Liu et al. (2018); Shen et al. (2017); Xu et al. (2018), we explore multi-step inference strategies for NLI. Rather than directly predicting the results given the inputs, the model maintains a state and iteratively refines its predictions. We show that our model outperforms single-step inference and further achieves state-of-the-art results on the SNLI, MultiNLI, SciTail, and Quora Question Pairs datasets.

2 Multi-step inference with SAN

The natural language inference task as defined here involves a premise $P = \{p_1, p_2, \dots, p_m\}$ of $m$ words and a hypothesis $H = \{h_1, h_2, \dots, h_n\}$ of $n$ words, and aims to find a logical relationship $R$ between $P$ and $H$, which is one of the labels in a closed set: entailment, neutral and contradiction. The goal is to learn a model $f(P, H) \to R$.

In a single-step inference architecture, the model directly predicts $R$ given $P$ and $H$ as input. In our multi-step inference architecture, we additionally incorporate a recurrent state $s_t$; the model processes multiple passes through $P$ and $H$, iteratively refining the state $s_t$, before finally generating the output at step $t = T$, where $T$ is an a priori chosen limit on the number of inference steps.


Figure 1: Architecture of the Stochastic Answer Network (SAN) for Natural Language Inference.

Figure 1 describes in detail the architecture of the stochastic answer network (SAN) used in this study; this model is adapted from the MRC multi-step inference literature Liu et al. (2018). Compared to the original SAN for MRC, the SAN for NLI simplifies the bottom layers and the self-attention layers, since premises and hypotheses are short. We also modify the answer module to predict an NLI classification label instead of a text span. Overall, it contains four layers: 1) the lexicon encoding layer computes word representations; 2) the contextual encoding layer modifies these representations in context; 3) the memory generation layer gathers all information from the premise and hypothesis and forms a “working memory” for the final answer module; 4) the final answer module, a type of multi-step network, predicts the relation between the premise and hypothesis.

Lexicon Encoding Layer. First, we concatenate word embeddings and character embeddings to handle out-of-vocabulary words (we omit POS tagging and named entity features for simplicity). Following Liu et al. (2018), we use two separate two-layer position-wise feedforward networks Vaswani et al. (2017) to obtain the final lexicon embeddings, $E^p \in \mathbb{R}^{d \times m}$ and $E^h \in \mathbb{R}^{d \times n}$, for the tokens in $P$ and $H$, respectively. Here, $d$ is the hidden size.
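The lexicon encoder can be sketched in dependency-free Python. This is a toy illustration, not the authors' PyTorch code: the dimensions, the random initialization, and the helper names (`PositionwiseFFN`, `lexicon_encode`) are hypothetical. It shows the two points made above: word and character vectors are concatenated per token, and one two-layer feedforward network (with a ReLU between the layers, as in Vaswani et al. (2017)) is applied identically at every position, with separate networks for the premise and the hypothesis. Tokens are stored as rows here, whereas $E^p$ stores them as columns.

```python
import random
from typing import List

Vector = List[float]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def linear(weight: List[Vector], bias: Vector, v: Vector) -> Vector:
    # Computes weight @ v + bias, with weight stored as a list of rows.
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weight, bias)]


class PositionwiseFFN:
    """Hypothetical two-layer FFN whose weights are shared across all token positions."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, seed: int = 0):
        rng = random.Random(seed)
        self.w1 = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(hidden_dim)]
        self.b1 = [0.0] * hidden_dim
        self.w2 = [[rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)] for _ in range(out_dim)]
        self.b2 = [0.0] * out_dim

    def __call__(self, tokens: List[Vector]) -> List[Vector]:
        # Each token is transformed independently with the same weights.
        return [linear(self.w2, self.b2, relu(linear(self.w1, self.b1, t))) for t in tokens]


def lexicon_encode(word_vecs: List[Vector], char_vecs: List[Vector], ffn: PositionwiseFFN) -> List[Vector]:
    # Concatenate the word and character embeddings of each token, then apply the FFN.
    return ffn([w + c for w, c in zip(word_vecs, char_vecs)])


# Toy usage: 3 premise tokens, 4-dim word vectors, 2-dim character vectors, d = 5.
premise_ffn = PositionwiseFFN(in_dim=6, hidden_dim=8, out_dim=5, seed=1)
hypothesis_ffn = PositionwiseFFN(in_dim=6, hidden_dim=8, out_dim=5, seed=2)  # separate weights
words = [[0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0], [0.1, 0.2, 0.3, 0.4]]
chars = [[1.0, -1.0], [0.5, 0.5], [1.0, -1.0]]
E_p = lexicon_encode(words, chars, premise_ffn)
```

Because the network is position-wise, tokens 0 and 2, which have identical inputs, receive identical lexicon embeddings.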

Contextual Encoding Layer. Two stacked BiLSTM layers are applied on top of the lexicon encoding layer to encode the context information for each word in both $P$ and $H$. Because each layer is bidirectional, its output is twice the hidden size, so we use a maxout layer Goodfellow et al. (2013) on each BiLSTM to shrink its output back to the original hidden size. By concatenating the outputs of the two BiLSTM layers, we obtain $C^p \in \mathbb{R}^{2d \times m}$ and $C^h \in \mathbb{R}^{2d \times n}$ as the representations of $P$ and $H$, respectively.
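The paragraph does not spell out the exact form of the maxout layer. The sketch below assumes one common variant, a pool of size 2 over consecutive units with no extra linear pieces, which halves a BiLSTM output of size $2d$ back to $d$. The helper names are hypothetical, and the BiLSTMs themselves are replaced by fixed stub vectors.

```python
from typing import List

Vector = List[float]


def maxout(vec: Vector, pool_size: int = 2) -> Vector:
    """Shrink a vector by taking the max over consecutive groups of `pool_size` units."""
    if len(vec) % pool_size != 0:
        raise ValueError("vector length must be divisible by pool_size")
    return [max(vec[i:i + pool_size]) for i in range(0, len(vec), pool_size)]


def contextual_encoding(layer1_out: List[Vector], layer2_out: List[Vector]) -> List[Vector]:
    """Per token: maxout each BiLSTM layer's 2d-dim output to d dims, then concatenate (2d dims)."""
    return [maxout(h1) + maxout(h2) for h1, h2 in zip(layer1_out, layer2_out)]


# One token, hidden size d = 2, so each BiLSTM layer emits 2d = 4 values (stub values).
layer1 = [[1.0, 5.0, -2.0, 0.0]]
layer2 = [[3.0, 2.0, 7.0, 8.0]]
C = contextual_encoding(layer1, layer2)  # [[5.0, 0.0, 3.0, 8.0]]
```

Each layer contributes $d$ values per token after maxout, which is why the concatenation of the two layers has width $2d$.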

Memory Layer. We construct our working memory via an attention mechanism. First, a dot-product attention is adopted as in Vaswani et al. (2017) to measure the similarity between the tokens in $P$ and $H$. Instead of using a scalar to normalize the scores as in Vaswani et al. (2017), we use a layer projection to transform the contextual information of both $C^p$ and $C^h$:

$$A = \mathrm{dropout}\left(f_{\mathrm{attention}}(\hat{C}^p, \hat{C}^h)\right) \in \mathbb{R}^{m \times n} \tag{1}$$

where $A$ is an attention matrix, and dropout is applied for smoothing. Note that $\hat{C}^p$ and $\hat{C}^h$ are transformed from $C^p$ and $C^h$, respectively, by a one-layer neural network $\mathrm{ReLU}(W_3 x)$. Next, we gather all the information on the premise and hypothesis by $U^p = [C^p; C^h A^\top] \in \mathbb{R}^{4d \times m}$ and $U^h = [C^h; C^p A] \in \mathbb{R}^{4d \times n}$. The semicolon ($;$) indicates vector/matrix concatenation, and $A^\top$ is the transpose of $A$. Last, the working memory of the premise and hypothesis is generated by a BiLSTM over all the information gathered: $M^p = \mathrm{BiLSTM}([U^p; C^p])$ and $M^h = \mathrm{BiLSTM}([U^h; C^h])$.
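A dependency-free sketch of the premise side of the memory layer follows. Several choices are assumptions not fixed by the text: the attention scores are normalized with a softmax over hypothesis tokens, a single projection matrix `W` (standing in for $W_3$) is shared by both sentences, dropout is omitted because it is only active during training, and the final BiLSTM is left out. Tokens are stored as rows, so the product $C^h A^\top$ becomes a per-token weighted sum of hypothesis vectors.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def project(W: Matrix, v: Vector) -> Vector:
    # One-layer projection ReLU(W v) applied before the dot product.
    return [max(0.0, sum(w * x for w, x in zip(row, v))) for row in W]


def attention_matrix(Cp: Matrix, Ch: Matrix, W: Matrix) -> Matrix:
    """A[i][j]: normalized similarity between premise token i and hypothesis token j."""
    Pp = [project(W, c) for c in Cp]
    Ph = [project(W, c) for c in Ch]
    return [softmax([sum(a * b for a, b in zip(p, h)) for h in Ph]) for p in Pp]


def premise_memory_input(Cp: Matrix, Ch: Matrix, A: Matrix) -> Matrix:
    """Row i of U^p: [C^p_i ; sum_j A[i][j] * C^h_j], twice the width of C^p_i."""
    dim = len(Ch[0])
    rows = []
    for i, c in enumerate(Cp):
        attended = [sum(A[i][j] * Ch[j][k] for j in range(len(Ch))) for k in range(dim)]
        rows.append(c + attended)
    return rows


# Toy usage: 2 premise tokens, 3 hypothesis tokens, width 2, identity projection.
Cp = [[1.0, 0.0], [0.0, 1.0]]
Ch = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = [[1.0, 0.0], [0.0, 1.0]]
A = attention_matrix(Cp, Ch, W)
Up = premise_memory_input(Cp, Ch, A)
```

The hypothesis side, $U^h = [C^h; C^p A]$, follows the same pattern using the columns of $A$ instead of its rows.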

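Answer module. Formally, our answer module computes over $T$ memory steps and outputs the relation label. At the beginning, the initial state $s_0$ is a summary of $M^h$: $s_0 = \sum_j \alpha_j M^h_j$, where $\alpha_j = \frac{\exp(\theta_2 \cdot M^h_j)}{\sum_{j'} \exp(\theta_2 \cdot M^h_{j'})}$. At time step $t$ in the range $\{1, 2, \dots, T-1\}$, the state is defined by $s_t = \mathrm{GRU}(s_{t-1}, x_t)$. Here, $x_t$ is computed from the previous state $s_{t-1}$ and the memory $M^p$: $x_t = \sum_j \beta_j M^p_j$ and $\beta = \mathrm{softmax}(s_{t-1} \theta_3 M^p)$. Following Mou et al. (2015), a one-layer classifier is used to determine the relation at each step $t \in \{0, 1, \dots, T-1\}$:

$$P^r_t = \mathrm{softmax}\left(\theta_4 [s_t; x_t; |s_t - x_t|; s_t \cdot x_t]\right) \tag{2}$$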
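The pieces of the answer module that feed Eq 2 can be sketched as follows: the self-attentive summary that gives $s_0$, the bilinear attention that gives $x_t$, and the per-step classifier. This is a toy sketch under assumptions: parameters are plain nested lists with hypothetical names, the GRU update $s_t = \mathrm{GRU}(s_{t-1}, x_t)$ is not implemented (any recurrent cell could be plugged in), and $s_t \cdot x_t$ is read as an elementwise product, which matches the four-fold classifier input size reported in the implementation details.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def softmax(scores: Vector) -> Vector:
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def matvec(W: Matrix, v: Vector) -> Vector:
    return [dot(row, v) for row in W]


def weighted_sum(weights: Vector, rows: Matrix) -> Vector:
    return [sum(w * r[k] for w, r in zip(weights, rows)) for k in range(len(rows[0]))]


def initial_state(Mh: Matrix, theta2: Vector) -> Vector:
    """s_0 = sum_j alpha_j M^h_j, alpha = softmax over j of theta2 . M^h_j."""
    return weighted_sum(softmax([dot(theta2, m) for m in Mh]), Mh)


def premise_summary(s_prev: Vector, Mp: Matrix, theta3: Matrix) -> Vector:
    """x = sum_j beta_j M^p_j, beta = softmax over j of the bilinear score s_prev theta3 M^p_j."""
    return weighted_sum(softmax([dot(s_prev, matvec(theta3, m)) for m in Mp]), Mp)


def eq2_features(s: Vector, x: Vector) -> Vector:
    """[s; x; |s - x|; s * x], four times the state width."""
    return s + x + [abs(a - b) for a, b in zip(s, x)] + [a * b for a, b in zip(s, x)]


def step_distribution(s: Vector, x: Vector, theta4: Matrix) -> Vector:
    """Eq 2: softmax(theta4 [s; x; |s - x|; s * x]), one row of theta4 per relation."""
    return softmax(matvec(theta4, eq2_features(s, x)))


# Toy usage with 2-dim memories and 3 relations.
Mh = [[1.0, 0.0], [3.0, 2.0]]
Mp = [[0.0, 1.0], [2.0, 2.0]]
s0 = initial_state(Mh, theta2=[0.0, 0.0])  # zero theta2 -> uniform weights -> [2.0, 1.0]
x = premise_summary(s0, Mp, theta3=[[0.0, 0.0], [0.0, 0.0]])  # uniform weights -> [1.0, 1.5]
P = step_distribution(s0, x, theta4=[[0.0] * 8 for _ in range(3)])  # uniform over 3 relations
```

With zero parameters every attention is uniform, which makes the intermediate values easy to check by hand; trained parameters would of course produce peaked weights.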



At last, we utilize all of the $T$ outputs by averaging the scores:

$$P^r = \mathrm{avg}\left([P^r_0, P^r_1, \dots, P^r_{T-1}]\right) \tag{3}$$

Each $P^r_t$ is a probability distribution over all the relations, $\{1, \dots, |R|\}$. During training, we apply stochastic prediction dropout before the above averaging operation. During decoding, we average all outputs to improve robustness.
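Stochastic prediction dropout can be sketched as randomly dropping whole step distributions before the Eq 3 average during training, and averaging every step at decoding time. The function names are hypothetical, and the guard that keeps one randomly chosen step when every step happens to be dropped is an assumption added here so that the average is always defined.

```python
import random
from typing import List, Sequence

Distribution = List[float]


def average(dists: Sequence[Distribution]) -> Distribution:
    """Eq 3: elementwise mean of per-step probability distributions."""
    n = len(dists)
    return [sum(d[k] for d in dists) / n for k in range(len(dists[0]))]


def stochastic_prediction_average(
    step_dists: Sequence[Distribution],
    drop_rate: float,
    training: bool,
    rng: random.Random,
) -> Distribution:
    """Average per-step predictions; during training, randomly drop entire steps."""
    if not training or drop_rate == 0.0:
        return average(step_dists)  # decoding: use every step
    kept = [d for d in step_dists if rng.random() >= drop_rate]
    if not kept:
        # Assumption: fall back to one random step so the mean is defined.
        kept = [rng.choice(list(step_dists))]
    return average(kept)


# Three steps over three relations (made-up distributions for illustration).
steps = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.6, 0.2]]
eval_probs = stochastic_prediction_average(steps, drop_rate=0.2, training=False, rng=random.Random(0))
train_probs = stochastic_prediction_average(steps, drop_rate=0.5, training=True, rng=random.Random(0))
```

Since an average of probability distributions is itself a distribution, the output stays valid whichever subset of steps survives.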

This stochastic prediction dropout is similar in motivation to the dropout introduced by Srivastava et al. (2014). The difference is that theirs drops intermediate nodes, whereas ours drops entire step predictions at the final layer. Node-level dropout prevents co-adaptation between features. Dropout at the final layer, which introduces randomness into the averaging of predictions, prevents our model from relying exclusively on a particular step to generate the correct output.

3 Experiments

3.1 Dataset

Here, we evaluate our model in terms of accuracy on four benchmark datasets. SNLI Bowman et al. (2015) contains 570k human-annotated sentence pairs, in which the premises are drawn from the captions of the Flickr30k corpus and the hypotheses are written by annotators. MultiNLI Williams et al. (2017) contains 433k sentence pairs collected similarly to SNLI; however, the premises are drawn from a broad range of genres of American English. Its development and test sets are further divided into in-domain (matched) and cross-domain (mismatched) sets. The Quora Question Pairs dataset Wang et al. (2017) was proposed for paraphrase identification. It contains 400k question pairs, each annotated with a binary value indicating whether the two questions are paraphrases of each other. The SciTail dataset Khot et al. (2018) is created from a science question answering (SciQ) dataset. It contains 1,834 questions with 10,101 entailment examples and 16,925 neutral examples. Since it contains only two types of labels, it is a binary classification task.

3.2 Implementation details

The spaCy tool is used to tokenize all the datasets, and PyTorch is used to implement our models. We fix the word embeddings to 300-dimensional GloVe word vectors Pennington et al. (2014). For the character encoding, we use a concatenation of multi-filter Convolutional Neural Nets with several window sizes (we limit the maximum length of a word to 20 characters, and the character embedding size is set to 20). So the lexicon embeddings are $d = 600$-dimensional. The embedding for out-of-vocabulary words is zeroed. The hidden size of the LSTMs in the contextual encoding layer and the memory generation layer is set to 128; thus the input size of the output layer is 1024 (128 * 2 * 4, i.e., doubled for bidirectionality and quadrupled for the four concatenated terms), as in Eq 2. The projection size in the attention layer is set to 256. To speed up training, we use weight normalization Salimans and Kingma (2016). The dropout rate is 0.2, and the dropout mask is fixed through time steps in the LSTMs Gal and Ghahramani (2016). The mini-batch size is set to 32. Our optimizer is Adamax Kingma and Ba (2014); its learning rate is initialized at 0.002 and decayed by a factor of 0.5 every 10 epochs.
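The dimension arithmetic and the learning-rate schedule above can be checked with a few lines of Python. The character-embedding width is inferred as the difference between the 600-dimensional lexicon embedding and the 300-dimensional GloVe vectors, and reading the decay as “multiply by 0.5 once per completed block of 10 epochs, counting from epoch 0” is an assumption; the constant and function names are hypothetical.

```python
GLOVE_DIM = 300
LEXICON_DIM = 600  # d in the text
CHAR_DIM = LEXICON_DIM - GLOVE_DIM  # inferred: 300 dims from the character CNNs
LSTM_HIDDEN = 128
STATE_DIM = 2 * LSTM_HIDDEN  # BiLSTM output: forward + backward
CLASSIFIER_INPUT = 4 * STATE_DIM  # [s; x; |s - x|; s * x] in Eq 2


def learning_rate(epoch: int, base_lr: float = 0.002, factor: float = 0.5, every: int = 10) -> float:
    """Step decay: multiply the base rate by `factor` once per completed block of `every` epochs."""
    return base_lr * factor ** (epoch // every)


schedule = [learning_rate(e) for e in (0, 9, 10, 19, 20)]
```

In PyTorch the same schedule would typically be expressed with a step-based learning-rate scheduler wrapped around the Adamax optimizer.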

 Dataset               Single-step   SAN
 MultiNLI matched      78.69         79.88
 MultiNLI mismatched   78.83         79.91
 SNLI                  88.32         88.73
 Quora                 89.67         90.70
 SciTail               85.46         89.35
Table 1: Comparison of single-step and multi-step inference strategies on the MultiNLI, SNLI, Quora Question and SciTail dev sets (accuracy, %).

3.3 Results

One main question we would like to address is whether multi-step inference helps on NLI. We fix the lower layers and compare only different architectures for the output layer:

  1. Single-step: Predict the relation using Eq 2 based on $s_0$ and $x_0$. Here, $x_0 = \sum_j \beta_j M^p_j$, where $\beta = \mathrm{softmax}(s_0 \theta_3 M^p)$. For direct comparison, this model has the same three lower layers as Fig. 1 and changes only the answer module.

  2. SAN: The multi-step inference model. We use 5 steps and apply stochastic prediction dropout in all experiments.

Table 1 shows that our multi-step model consistently outperforms the single-step model on the dev sets of all four datasets in terms of accuracy. For example, on the SciTail dataset, SAN outperforms the single-step model by +3.89 (89.35 vs. 85.46).
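The per-dataset gains behind this claim can be recomputed directly from Table 1. The dictionary below only copies the table's dev-set accuracies; the variable names are illustrative.

```python
# Dev-set accuracy (%) from Table 1 as (single-step, SAN).
TABLE1 = {
    "MultiNLI matched": (78.69, 79.88),
    "MultiNLI mismatched": (78.83, 79.91),
    "SNLI": (88.32, 88.73),
    "Quora": (89.67, 90.70),
    "SciTail": (85.46, 89.35),
}

# Absolute gain of SAN over the single-step model, rounded to two decimals.
gains = {name: round(san - single, 2) for name, (single, san) in TABLE1.items()}
largest = max(gains, key=gains.get)
smallest = min(gains, key=gains.get)
```

Every gain is positive, ranging from 0.41 points on SNLI to 3.89 points on SciTail.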

We compare our results with the state of the art in Table 2. Our model achieves the best performance on the SciTail and Quora Question tasks: SAN obtains 89.4 (vs. 89.1) and 88.4 (vs. 88.3) on the Quora Question and SciTail test sets, respectively, setting a new state of the art. On the SNLI and MultiNLI datasets, ESIM+ELMo Peters et al. (2018), GPT Radford et al. (2018) and BERT Devlin et al. (2018) use a large amount of external knowledge or large-scale pretrained contextual embeddings; nevertheless, SAN remains competitive with these models. On the SciTail dataset, SAN even outperforms GPT. Due to space limitations, we list only the top two models (see the respective leaderboards for more information).

We further use BERT as a feature extractor and place the SAN answer module on top of it. (For a fair comparison, we run the BERT base model to extract embeddings of both the premise and the hypothesis, and then feed them to the answer modules.) Compared with the single-step baseline, the proposed model obtains a +2.8 improvement on the SciTail test set (94.0 vs. 91.2) and a +2.2 improvement on the SciTail dev set (96.1 vs. 93.9). This suggests that the proposed answer module generalizes and can easily be adapted to other models. (Because of the high computational cost and space limitations, we omit the results using BERT on the SNLI, MultiNLI and Quora Question datasets.)

 MultiNLI Test (Accuracy%)         Matched   Mismatched
 DIIN Gong et al. (2017)           78.8      77.8
 BERT Devlin et al. (2018)         86.7      85.9
 SAN                               79.3      78.7
 SNLI Dataset (Accuracy%)
 ESIM+ELMo Peters et al. (2018)    88.7
 GPT Radford et al. (2018)         89.9
 SAN                               88.7
 Quora Question Dataset (Accuracy%)
 Tomar et al. (2017)               88.4
 Gong et al. (2017)                89.1
 SAN                               89.4
 SciTail Dataset (Accuracy%)
 Khot et al. (2018)                77.3
 GPT Radford et al. (2018)         88.3
 SAN                               88.4
Table 2: Comparison with the state of the art on the MultiNLI, SNLI, Quora Question and SciTail test sets.
                      Matched         Mismatched
 Tag                  Chen    SAN     Chen    SAN
 Conditional          100%    65%     100%    81%
 Word overlap         63%     86%     76%     92%
 Negation             75%     80%     72%     79%
 Antonym              50%     77%     58%     85%
 Long Sentence        67%     84%     67%     79%
 Tense Difference     86%     75%     89%     83%
 Active/Passive       88%     100%    91%     100%
 Paraphrase           78%     92%     89%     92%
 Quantity/Time        33%     53%     46%     51%
 Coreference          83%     73%     80%     84%
 Quantifier           74%     81%     77%     80%
 Modal                75%     79%     76%     82%
 Belief               73%     77%     74%     78%
Table 3: Error analysis on MultiNLI. See Nangia et al. (2017) for reference.

Analysis: How many steps does it need? We search the number of steps $T$ from 1 to 10. We observe that as $T$ increases, our model improves (e.g., 86.7 with fewer steps); around $T = 5$ it achieves its best result (89.4) on the SciTail dev set, after which the performance begins to degrade. Thus, we set $T = 5$ in all our experiments.

We also looked at the internals of our answer module by dumping the predictions of each step (the maximum number of steps is set to 5). Consider an example from the MultiNLI dev set (ID 144185n) with the premise “And he said, What’s going on?” and the hypothesis “I told him to mind his own business.” Our model produces one label at each of the 5 steps (contradiction, neutral, neutral, neutral and neutral) and makes the final decision by voting for neutral. Surprisingly, we found that the five human annotators gave the same labels: contradiction, neutral, neutral, neutral, neutral. This illustrates the robustness of our model, which relies on the collective judgment of all steps.
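The per-step labels of this example can be tallied with `collections.Counter`. Note that SAN's actual decision averages the step probability distributions (Eq 3); a majority vote over per-step labels, used here only to mirror the wording above, is a simplification, and the helper name is hypothetical.

```python
from collections import Counter
from typing import List


def majority_label(step_labels: List[str]) -> str:
    """Most frequent label; ties go to the label seen first (Counter keeps insertion order)."""
    return Counter(step_labels).most_common(1)[0][0]


model_steps = ["contradiction", "neutral", "neutral", "neutral", "neutral"]
annotators = ["contradiction", "neutral", "neutral", "neutral", "neutral"]
decision = majority_label(model_steps)
same_distribution = Counter(model_steps) == Counter(annotators)
```

Comparing the two `Counter` objects checks that the model's steps and the annotators agree on how often each label occurs, regardless of order.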

Finally, we analyze our model on the annotated subset of the MultiNLI development set. It contains 1,000 examples, each tagged by the categories shown in Table 3. Our model outperforms the best system in RepEval 2017, Chen et al. (2017), in most cases; the exceptions are the “Conditional” and “Tense Difference” categories, and “Coreference” in the matched setting. We also find that SAN works extremely well on the “Active/Passive” and “Paraphrase” categories. Compared with Chen’s model, the biggest improvement of SAN is on the “Antonym” category (77% vs. 50% and 85% vs. 58% in the matched and mismatched settings, respectively). On the most challenging “Long Sentence” and “Quantity/Time” categories, SAN is also substantially better than Chen’s model. This suggests that multi-step inference is robust.
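The comparisons in this paragraph can be checked against Table 3 with a short script. The dictionary copies the table (accuracy in percent per category), and the hypothetical helper `deltas` gives SAN minus Chen et al. (2017) in the matched and mismatched settings.

```python
# Table 3 as (Chen matched, SAN matched, Chen mismatched, SAN mismatched), in percent.
TABLE3 = {
    "Conditional": (100, 65, 100, 81),
    "Word overlap": (63, 86, 76, 92),
    "Negation": (75, 80, 72, 79),
    "Antonym": (50, 77, 58, 85),
    "Long Sentence": (67, 84, 67, 79),
    "Tense Difference": (86, 75, 89, 83),
    "Active/Passive": (88, 100, 91, 100),
    "Paraphrase": (78, 92, 89, 92),
    "Quantity/Time": (33, 53, 46, 51),
    "Coreference": (83, 73, 80, 84),
    "Quantifier": (74, 81, 77, 80),
    "Modal": (75, 79, 76, 82),
    "Belief": (73, 77, 74, 78),
}


def deltas(table):
    """SAN minus Chen per category, as (matched, mismatched)."""
    return {tag: (sm - cm, smm - cmm) for tag, (cm, sm, cmm, smm) in table.items()}


def worse_for_san(table):
    """Categories where SAN trails Chen in at least one setting."""
    return sorted(tag for tag, (dm, dmm) in deltas(table).items() if dm < 0 or dmm < 0)


d = deltas(TABLE3)
biggest_gain = max(d, key=lambda tag: sum(d[tag]))
```

Running it confirms that “Antonym” has the largest gain (+27 points in both settings) and that, besides “Conditional” and “Tense Difference”, SAN also trails on “Coreference” in the matched setting (73% vs. 83%).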

4 Conclusion

We explored the use of multi-step inference in natural language inference by proposing a stochastic answer network (SAN). Rather than directly predicting the result (e.g., a relation such as entailment or not) given the input premise $P$ and hypothesis $H$, SAN maintains a state $s_t$, which it iteratively refines over multiple passes on $P$ and $H$ in order to make a prediction. Our state-of-the-art results on four benchmarks (SNLI, MultiNLI, SciTail, Quora Question Pairs) show the effectiveness of this multi-step inference architecture. In future work, we would like to incorporate pretrained contextual embeddings, e.g., ELMo Peters et al. (2018) and GPT Radford et al. (2018), into our model, and to explore multi-task learning Liu et al. (2019).