Reinforced Self-Attention Network: a Hybrid of Hard and Soft Attention for Sequence Modeling

01/31/2018 ∙ by Tao Shen, et al. ∙ University of Technology Sydney ∙ University of Washington ∙ Griffith University

Many natural language processing tasks solely rely on sparse dependencies between a few tokens in a sentence. Soft attention mechanisms show promising performance in modeling local/global dependencies by soft probabilities between every two tokens, but they are not effective and efficient when applied to long sentences. By contrast, hard attention mechanisms directly select a subset of tokens but are difficult and inefficient to train due to their combinatorial nature. In this paper, we integrate both soft and hard attention into one context fusion model, "reinforced self-attention (ReSA)", for the mutual benefit of each other. In ReSA, a hard attention trims a sequence for a soft self-attention to process, while the soft attention feeds reward signals back to facilitate the training of the hard one. For this purpose, we develop a novel hard attention called "reinforced sequence sampling (RSS)", selecting tokens in parallel and trained via policy gradient. Using two RSS modules, ReSA efficiently extracts the sparse dependencies between each pair of selected tokens. We finally propose an RNN/CNN-free sentence-encoding model, "reinforced self-attention network (ReSAN)", solely based on ReSA. It achieves state-of-the-art performance on both Stanford Natural Language Inference (SNLI) and Sentences Involving Compositional Knowledge (SICK) datasets.




1 Introduction

Equipping deep neural networks (DNN) with attention mechanisms provides an effective and parallelizable approach for context fusion and sequence compression. It achieves compelling time efficiency and state-of-the-art performance in a broad range of natural language processing (NLP) tasks, such as neural machine translation [Bahdanau et al.2015, Luong et al.2015], dialogue generation [Shang et al.2015], machine reading/comprehension [Seo et al.2017], natural language inference [Liu et al.2016], sentiment classification [Li et al.2017b], etc. Recently, some neural nets based solely on attention, especially self-attention, outperform traditional recurrent [Bowman et al.2015] or convolutional [Dong et al.2017] neural networks on NLP tasks, such as machine translation [Vaswani et al.2017] and sentence embedding [Shen et al.2018], which further demonstrates the power of attention mechanisms in capturing contextual dependencies.

Soft and hard attention are the two main types of attention mechanisms. In soft attention [Bahdanau et al.2015], a categorical distribution is calculated over a sequence of elements. The resulting probabilities reflect the importance of each element and are used as weights to produce a context-aware encoding that is the weighted sum of all elements. Hence, soft attention requires only a small number of parameters and little computation time. Moreover, a soft attention mechanism is fully differentiable and thus can be easily trained by end-to-end back-propagation when attached to any existing neural net. However, the softmax function usually assigns small but non-zero probabilities to trivial elements, which weakens the attention given to the few truly significant elements.

Unlike the widely-studied soft attention, in hard attention [Xu et al.2015], a subset of elements is selected from an input sequence. A hard attention mechanism forces a model to concentrate solely on the important elements, entirely discarding the others. In fact, various NLP tasks rely solely on very sparse tokens from a long text input. Hard attention is well suited to these tasks, because it overcomes the weaknesses associated with soft attention in long sequences. However, a hard attention mechanism is time-inefficient with sequential sampling and non-differentiable by virtue of its combinatorial nature. Thus, it cannot be optimized through back-propagation and more typically relies on policy gradient, e.g., REINFORCE [Williams1992]. As a result, training a hard attention model is usually an inefficient process (some even find convergence difficult), and combining it with other neural nets in an end-to-end manner is problematic.

However, soft and hard attention mechanisms might be integrated into a single model to benefit each other in overcoming their inherent disadvantages, and this notion motivates our study. Specifically, a hard attention mechanism is used to encode rich structural information about the contextual dependencies and to trim a long sequence into a much shorter one for a soft attention mechanism to process. Conversely, the soft one is used to provide a stable environment and strong reward signals to help in training the hard one. Such a method would improve both the prediction quality of the soft attention mechanism and the trainability of the hard attention mechanism, while boosting the ability to model contextual dependencies. To the best of our knowledge, the idea of combining hard and soft attention within a model has not yet been studied. Existing works focus on only one of the two types.

In this paper, we first propose a novel hard attention mechanism called “reinforced sequence sampling (RSS)”, which selects tokens from an input sequence in parallel, and differs from existing ones in that it is highly parallelizable without any recurrent structure. We then develop a model, “reinforced self-attention (ReSA)”, which naturally combines the RSS with a soft self-attention. In ReSA, two parameter-untied RSS modules are respectively applied to two copies of the input sequence, where the tokens from one and the other are called dependent and head tokens, respectively. ReSA only models the sparse dependencies between the head and dependent tokens selected by the two RSS modules. Finally, we build a sentence-encoding model, “reinforced self-attention network (ReSAN)”, based on ReSA without any CNN/RNN structure.

We test ReSAN on natural language inference and semantic relatedness tasks. The results show that ReSAN achieves the best test accuracy among all sentence-encoding models on the official leaderboard of the Stanford Natural Language Inference (SNLI) dataset, and state-of-the-art performance on the Sentences Involving Compositional Knowledge (SICK) dataset. ReSAN is more efficient and achieves better prediction quality than commonly used recurrent/convolutional neural networks, self-attention networks, and even well-designed models (e.g., semantic tree or external memory based models). All the experimental code is released at


Notation: 1) lowercase denotes a vector; 2) bold lowercase denotes a sequence of vectors (stored as a matrix); and 3) uppercase denotes a matrix or a tensor.

2 Background

2.1 Attention

Given an input sequence $\boldsymbol{x} = [x_1, \dots, x_n]$ ($x_i$ denotes the embedded vector of the $i$-th element), and the vector representation of a query $q$, a vanilla attention mechanism uses a parameterized compatibility function $f(x_i, q)$ to compute an alignment score between $q$ and each token $x_i$ as the attention of $q$ to $x_i$ [Bahdanau et al.2015]. A softmax function is then applied to the alignment scores $a \in \mathbb{R}^{n}$ over all tokens to generate a categorical distribution $p(v|\boldsymbol{x}, q)$, where $v = i$ implies that token $x_i$ is selected according to its relevance to query $q$. This can be formally written as

$$a = [f(x_i, q)]_{i=1}^{n}, \qquad p(v|\boldsymbol{x}, q) = \mathrm{softmax}(a). \tag{1}$$

The output of attention, $s$, is the expectation of sampling a token according to the categorical distribution $p(v|\boldsymbol{x}, q)$, i.e.,

$$s = \sum_{i=1}^{n} p(v = i|\boldsymbol{x}, q)\, x_i = \mathbb{E}_{i \sim p(v|\boldsymbol{x}, q)}[x_i].$$

Multi-dimensional (multi-dim) attention mechanism [Shen et al.2018] extends the vanilla one [Bahdanau et al.2015] to a feature-wise level, i.e., each feature of every token has an alignment score. Hence, rather than a scalar, the output of $f(x_i, q)$ is a vector with the same dimensions as the input, and the resulting alignment scores compose a matrix $a \in \mathbb{R}^{d_e \times n}$, where $d_e$ is the embedding dimension. Such feature-level attention has been verified in terms of its ability to capture the subtle variances of different contexts.
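The vanilla attention described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the additive form of the compatibility function and the names `vanilla_attention`, `W1`, `W2`, `b`, and `w` are hypothetical choices, not taken from the paper.

```python
import numpy as np

def softmax(a):
    """Numerically stable softmax over a 1-D score vector."""
    e = np.exp(a - np.max(a))
    return e / e.sum()

def vanilla_attention(x, q, W1, W2, b, w):
    """x: (n, d_e) token embeddings, q: (d_q,) query.

    Assumed additive compatibility: f(x_i, q) = w^T tanh(W1 x_i + W2 q + b).
    Returns the categorical distribution p (n,) and the output s (d_e,).
    """
    hidden = np.tanh(x @ W1.T + q @ W2.T + b)  # (n, d_h)
    scores = hidden @ w                        # one alignment score per token
    p = softmax(scores)                        # distribution over tokens
    s = p @ x                                  # expectation of x_i under p
    return p, s

rng = np.random.default_rng(0)
n, d_e, d_q, d_h = 5, 4, 3, 6
x = rng.normal(size=(n, d_e))
q = rng.normal(size=d_q)
W1 = rng.normal(size=(d_h, d_e))
W2 = rng.normal(size=(d_h, d_q))
b = np.zeros(d_h)
w = rng.normal(size=d_h)
p, s = vanilla_attention(x, q, W1, W2, b, w)
```

The multi-dim variant would replace the scalar score per token with a vector of per-feature scores and apply the softmax separately along the token axis for each feature.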

2.2 Self-Attention

Self-attention is a special case of attention where the query stems from the input sequence itself. Hence, a self-attention mechanism can model the dependencies between tokens from the same sequence. Recently, a variety of self-attention mechanisms have been developed, each serving a distinct purpose, but most can be roughly categorized into two types, token2token self-attention and source2token self-attention.

Token2token self-attention mechanisms aim to produce a context-aware representation for each token in light of its dependencies on other tokens in the same sequence. The query $q$ is replaced with the token $x_j$, and the dependency of $x_j$ on another token $x_i$ is computed by $f(x_i, x_j)$. Two self-attention mechanisms of this type have been proposed, i.e., scaled dot-product attention, which composes the multi-head attention [Vaswani et al.2017], and masked self-attention, which leads to directional self-attention [Shen et al.2018]. Because the latter experimentally outperforms the former, we select masked self-attention as our fundamental soft self-attention module.

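Masked Self-Attention

Masked self-attention is more sophisticated than scaled dot-product attention in that it uses a multi-dim, multi-layer perceptron with an additional position mask, rather than a scaled dot-product, as the compatibility function, i.e.,

$$f(x_i, x_j) = c \cdot \tanh\left([W^{(1)} x_i + W^{(2)} x_j + b^{(1)}]/c\right) + M_{ij}\mathbf{1},$$

where $c$ is a scalar and $M$ is the mask with each entry $M_{ij} \in \{-\infty, 0\}$. When $M_{ij} = -\infty$, applying the softmax function to the alignment scores results in a zero probability, $p(v = i|\boldsymbol{x}, x_j) = 0$, which switches off the attention of $x_j$ to $x_i$. An asymmetric mask where $M_{ij} \neq M_{ji}$ enforces directional attention between $x_i$ and $x_j$, which can encode temporal order information. Two positional masks have been designed to encode the forward and backward temporal order, respectively, i.e.,

$$M^{fw}_{ij} = \begin{cases} 0, & i < j \\ -\infty, & \text{otherwise} \end{cases} \qquad M^{bw}_{ij} = \begin{cases} 0, & i > j \\ -\infty, & \text{otherwise} \end{cases}$$

In the forward and backward masks, $M_{ii} = -\infty$. Thus, the attention of a token to itself is blocked, so the output of the masked self-attention mechanism comprises the features of the context around each token rather than context-aware features.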
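A minimal sketch of the two positional masks and their effect on the softmax follows. It assumes the convention that entry `[i, j]` is 0 when token `j` may attend to token `i` (`i < j` for the forward mask, `i > j` for the backward mask) and uses a large negative constant `NEG` as a finite stand-in for negative infinity; the function names are hypothetical.

```python
import numpy as np

NEG = -1e9  # finite stand-in for -infinity (assumption, keeps softmax well defined)

def directional_masks(n):
    """Return (forward, backward) masks of shape (n, n).

    Entry [i, j] gates the attention of token j to token i.
    """
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    forward = np.where(i < j, 0.0, NEG)   # token j sees only earlier tokens
    backward = np.where(i > j, 0.0, NEG)  # token j sees only later tokens
    return forward, backward

def masked_softmax(scores, mask):
    """Softmax over index i (first axis) for every column j."""
    z = scores + mask
    z = z - z.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

fw, bw = directional_masks(4)
# Zero scores isolate the effect of the mask itself.
p_fw = masked_softmax(np.zeros((4, 4)), fw)
p_bw = masked_softmax(np.zeros((4, 4)), bw)
```

With the finite stand-in, a column whose entries are all masked (the first token under the forward mask, the last under the backward mask) falls back to a uniform distribution instead of producing NaNs.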

Directional self-attention uses a fusion gate to combine the embedding of each token with its context. Specifically, a fusion gate combines the input and output of a masked self-attention to produce context-aware representations. This idea is similar to the highway network [Srivastava et al.2015].

Source2token self-attention mechanisms [Shen et al.2018] remove the query $q$ from the compatibility function in Eq.(1) and directly compress a sequence into a vector representation calculated from the dependency between each token $x_i$ and the entire input sequence $\boldsymbol{x}$. Hence, this form of self-attention is highly data- and task-driven.

3 Proposed Models

This section begins by introducing a hard attention mechanism called RSS in Section 3.1, followed by integrating the RSS with a soft self-attention mechanism into a context fusion model called ReSA in Section 3.2. Finally, a model named ReSAN, based on ReSA, is designed for sentence encoding tasks in Section 3.3.

3.1 Reinforced Sequence Sampling (RSS)

The goal of a hard attention mechanism is to select a subset of critical tokens that provides sufficient information to complete downstream tasks, so any further computations on the trivial tokens can be saved. In the following, we introduce a hard attention mechanism called RSS. Given an input sequence $\boldsymbol{x} = [x_1, \dots, x_n]$, RSS generates an equal-length sequence of binary random variables $\hat{\boldsymbol{z}} = [\hat{z}_1, \dots, \hat{z}_n]$, where $\hat{z}_i = 1$ implies that $x_i$ is selected whereas $\hat{z}_i = 0$ indicates that $x_i$ is discarded. In RSS, the elements of $\hat{\boldsymbol{z}}$ are sampled in parallel according to probabilities computed by a learned attention mechanism. This is more efficient than using MCMC with iterative sampling. The particular aim of RSS is to learn the following product distribution:

$$p(\hat{\boldsymbol{z}}|\boldsymbol{x}; \theta_r) = \prod_{i=1}^{n} p(\hat{z}_i|\boldsymbol{x}; \theta_r), \qquad p(\hat{z}_i|\boldsymbol{x}; \theta_r) = g(f(\boldsymbol{x}; \theta_f)_i; \theta_g),$$

where $\theta_r$ collects the parameters $\theta_f$ and $\theta_g$. The function $f(\cdot; \theta_f)$ denotes a context fusion layer, e.g., Bi-LSTM, Bi-GRU, etc., producing a context-aware representation for each $x_i$. Then, $g(\cdot; \theta_g)$ maps $f(\cdot; \theta_f)$ to the probability of selecting the token. Note we can sample all $\hat{z}_i$ for different $i$ in parallel because the probability of $\hat{z}_i$ (i.e., whether $x_i$ is selected) does not depend on $\hat{z}_{i-1}$. This is because the context features given by $f(\cdot; \theta_f)$ already take the sequential information into account, so the conditionally independent sampling does not discard any useful information.

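To fully explore the high parallelizability of attention, we avoid using recurrent models in this paper. Instead we apply a more efficient $f(\cdot; \theta_f)$ inspired by source2token self-attention and intra-attention [Liu et al.2016], i.e.,

$$f(\boldsymbol{x}; \theta_f)_i = [x_i;\ \mathrm{pooling}(\boldsymbol{x});\ x_i \odot \mathrm{pooling}(\boldsymbol{x})],$$

where $\odot$ denotes the element-wise product, and $\mathrm{pooling}(\cdot)$ represents the mean-pooling operation along the sequential axis. RSS selects a subset of tokens by sampling $\hat{z}_i$ according to the probability given by $g(f(\boldsymbol{x}; \theta_f)_i; \theta_g)$ for all $i$ in parallel.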
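The parallel sampling step can be sketched as follows. The context features combine each token with the mean-pooled sequence and their element-wise product, as described above; the scoring head (one ReLU hidden layer followed by a sigmoid) and all names (`rss_probabilities`, `rss_sample`, `W`, `b`, `w`, `c`) are assumptions made for illustration only.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def rss_probabilities(x, W, b, w, c):
    """x: (n, d) token embeddings -> selection probability per token, shape (n,)."""
    n = x.shape[0]
    # Mean-pool along the sequence axis and tile the result for every token.
    pooled = np.repeat(x.mean(axis=0, keepdims=True), n, axis=0)
    # Context features: token, pooled sequence, and their element-wise product.
    h = np.concatenate([x, pooled, x * pooled], axis=1)  # (n, 3d)
    hidden = np.maximum(h @ W.T + b, 0.0)                # assumed ReLU layer
    return sigmoid(hidden @ w + c)

def rss_sample(probs, rng):
    """Draw all Bernoulli selections at once; no token waits on another."""
    return (rng.random(probs.shape) < probs).astype(np.int64)

rng = np.random.default_rng(0)
n, d, d_h = 6, 4, 8
x = rng.normal(size=(n, d))
W = rng.normal(size=(d_h, 3 * d)) * 0.1
b = np.zeros(d_h)
w = rng.normal(size=d_h)
c = 0.0
probs = rss_probabilities(x, W, b, w, c)
z_hat = rss_sample(probs, rng)  # 1 = token selected, 0 = token discarded
```

Because every probability depends only on the input sequence and not on the other sampled variables, the whole selection is a single vectorized comparison.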

For the training of RSS, there are no ground truth labels to indicate whether or not a token should be selected, and the discrete random variables in $\hat{\boldsymbol{z}}$ lead to a non-differentiable objective function. Therefore, we formulate learning the RSS parameter $\theta_r$ as a reinforcement learning problem, and apply the policy gradient method. Further details on the model training are presented in Section 4.

3.2 Reinforced Self-Attention (ReSA)

The fundamental idea behind this paper is that the hard and soft attention mechanisms can mutually benefit each other to overcome their inherent disadvantages via interaction within an integrated model. Based on this idea, we develop a novel self-attention termed ReSA. On the one hand, the proposed RSS provides a sparse mask to a self-attention module that only needs to model the dependencies for the selected token pairs. Hence, heavy memory loads and computations associated with soft self-attention can be effectively relieved. On the other hand, ReSA uses the output of the soft self-attention module for prediction, whose correctness (as compared to the ground truth) is used as reward signal to train the RSS. This alleviates the difficulty of training hard attention module.

Figure 1: Reinforced self-attention (ReSA) model. denotes the alignment score obtained from .

Figure 1 shows the detailed architecture of ReSA. Given the token embeddings of an input sequence, $\boldsymbol{x} = [x_1, \dots, x_n]$, ReSA aims to produce token-wise context-aware representations, $\boldsymbol{u} = [u_1, \dots, u_n]$. Unlike previous self-attention mechanisms, ReSA only selects a subset of head tokens, and generates their context-aware representations by only relating each head token to a small subset of dependent tokens. This notion is based on the observation that for many NLP tasks, the final prediction only relies on a small set of key words and their contexts, and each key word only depends on a small set of other words. Namely, the dependencies between tokens from the same sequence are sparse.

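In ReSA, we use two RSS modules, as outlined in Section 3.1, to generate two sequences of labels for the selections of head and dependent tokens, respectively, i.e.,

$$\hat{\boldsymbol{z}}^h = [\hat{z}^h_1, \dots, \hat{z}^h_n] \sim \mathrm{RSS}(\boldsymbol{x}; \theta_{rh}), \qquad \hat{\boldsymbol{z}}^d = [\hat{z}^d_1, \dots, \hat{z}^d_n] \sim \mathrm{RSS}(\boldsymbol{x}; \theta_{rd}).$$

We use $\hat{\boldsymbol{z}}^h$ and $\hat{\boldsymbol{z}}^d$ sampled from the two independent (parameter-untied) RSS modules to generate an $n \times n$ mask $M^{rss}$, i.e.,

$$M^{rss}_{ij} = \begin{cases} 0, & \hat{z}^d_i = \hat{z}^h_j = 1 \text{ and } i \neq j \\ -\infty, & \text{otherwise.} \end{cases}$$

The resulting mask is then applied as an extra mask to the masked self-attention mechanism introduced in Section 2.2. Specifically, we add $M^{rss}$ to the compatibility function $f(x_i, x_j)$ of the masked self-attention and use

$$f^{rss}(x_i, x_j) = f(x_i, x_j) + M^{rss}_{ij}\mathbf{1}$$

to generate the alignment scores. For each head token $x_j$, a softmax function is applied to $f^{rss}(\cdot, x_j)$, which produces a categorical distribution over all dependent tokens, i.e.,

$$P^j = \mathrm{softmax}\left([f^{rss}(x_i, x_j)]_{i=1}^{n}\right), \quad j = 1, \dots, n.$$

The context features of $x_j$ are computed by

$$s_j = \sum_{i=1}^{n} P^j_i \odot x_i,$$

where $\odot$ denotes a broadcast product in the vanilla attention or an element-wise product in the multi-dim attention.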

For a selected head token, as formulated in Eq.(10), the attention from a token to itself is disabled in $M^{rss}$, so the $s_j$ for the selected head token encodes only the context features but not the desired context-aware embedding. For an unselected head token $x_j$ with $\hat{z}^h_j = 0$, its alignment scores over all dependent tokens are equal to $-\infty$, which leads to equal probabilities in $P^j$ produced by the softmax function. Hence, $s_j$ for each unselected token can be regarded as the result of mean-pooling over all dependent tokens.
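The behaviour of selected and unselected heads can be checked on a toy example. The sketch below assumes the mask is open only where both the dependent token `i` and the head token `j` are selected and `i != j`, uses zero alignment scores as a placeholder for the soft self-attention scores, and replaces negative infinity with a large negative constant; all names are hypothetical.

```python
import numpy as np

NEG = -1e9  # finite stand-in for -infinity (assumption)

def rss_mask(z_dep, z_head):
    """Mask entry [i, j] is 0 only if dependent i and head j are both selected and i != j."""
    n = len(z_dep)
    allowed = np.outer(z_dep, z_head).astype(bool)
    allowed &= ~np.eye(n, dtype=bool)  # a token never attends to itself
    return np.where(allowed, 0.0, NEG)

def softmax_over_dependents(scores):
    """Softmax over index i (dependents) for every head j."""
    z = scores - scores.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

z_dep = np.array([1, 1, 0, 1])    # selections of dependent tokens
z_head = np.array([1, 0, 0, 1])   # selections of head tokens
scores = np.zeros((4, 4))         # placeholder soft alignment scores
P = softmax_over_dependents(scores + rss_mask(z_dep, z_head))

x = np.arange(8.0).reshape(4, 2)  # toy token embeddings
s = P.T @ x                       # s[j] = sum_i P[i, j] * x[i]
```

Head 0 is selected, so it attends only to the selected dependents 1 and 3; head 1 is unselected, so every score in its column is masked and the softmax returns a uniform distribution, which makes `s[1]` the mean of all token embeddings.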

To merge the word embedding with its context feature for the selected heads, and distinguish the representations from others for the unselected heads, a fusion gate is used to combine $\boldsymbol{s}$ with the input embedding $\boldsymbol{x}$ in parallel and generate the final context-aware representations $\boldsymbol{u}$ for all tokens, i.e.,

$$F = \mathrm{sigmoid}\left(W^{(f)}[\boldsymbol{x}; \boldsymbol{s}] + b^{(f)}\right), \qquad \boldsymbol{u} = F \odot \boldsymbol{x} + (1 - F) \odot \boldsymbol{s},$$

where $W^{(f)}$, $b^{(f)}$ are the learnable parameters. The context-aware representations, $\boldsymbol{u}$, are the final output. One primary advantage of ReSA is that it generates better predictions using less time and memory than existing self-attention mechanisms. In particular, the major computations of ReSA are 1) the inference of self-attention over a shorter subsequence, and 2) the mean-pooling over the remaining elements. This is much more time- and memory-efficient than computing the self-attention over the entire input sequence.
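A fusion gate of this kind can be sketched as a sigmoid gate that interpolates, feature by feature, between the input embedding and its context feature. The exact parameterization below (a single linear map over the concatenation of the two inputs) and the names `fusion_gate`, `W`, and `b` are assumptions for illustration.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def fusion_gate(x, s, W, b):
    """x, s: (n, d) input embeddings and context features; W: (d, 2d); b: (d,)."""
    gate = sigmoid(np.concatenate([x, s], axis=1) @ W.T + b)  # (n, d), values in (0, 1)
    # Per-feature convex combination of the embedding and its context.
    return gate * x + (1.0 - gate) * s

rng = np.random.default_rng(0)
n, d = 3, 4
x = rng.normal(size=(n, d))
s = rng.normal(size=(n, d))
W = rng.normal(size=(d, 2 * d))
b = np.zeros(d)
u = fusion_gate(x, s, W, b)

# With zero weights the gate is 0.5 everywhere, so the output is the plain average.
u_avg = fusion_gate(x, s, np.zeros((d, 2 * d)), np.zeros(d))
```

Since the gate values lie strictly between 0 and 1, every output feature stays between the corresponding features of `x` and `s`, which is the highway-style behaviour the text refers to.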

3.3 Applications of the Proposed Models

To adapt ReSA for sentence encoding tasks, we build an RNN/CNN-free network, called reinforced self-attention network (ReSAN), which is solely based on ReSA and source2token self-attention (Section 2.2). In particular, we pass the output sequence of ReSA into a source2token self-attention module to generate a compressed vector representation that encodes the semantic and syntactic knowledge of the input sentence and can be used for various downstream NLP tasks.

Further, we propose two simplified variants of ReSAN with a simpler structure or fewer parameters, i.e., 1) ReSAN w/o unselected heads, which only applies the soft self-attention to the selected head and dependent tokens, and 2) ReSAN w/o dependency restricted, which uses only one RSS to select tokens for both heads and dependents. Both variants entirely discard the information of the unselected tokens and hence are more time-efficient. However, neither can be used for context fusion, because the input and output sequences are not equal in length.

4 Model Training

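The parameters in ReSAN can be divided into two parts, $\theta_r$ for the RSS modules and $\theta_s$ for the remaining parts, which include the word embeddings, the soft self-attention module, and the classification/regression layers. Learning $\theta_s$ is straightforward and can be completed by back-propagation in an end-to-end manner. However, optimizing $\theta_r$ is more challenging because the RSS modules contain discrete variables $\hat{\boldsymbol{z}}$ and, thus, the objective function is non-differentiable w.r.t. $\theta_r$.

In supervised classification settings, we use the cross-entropy loss plus an L2 regularization penalty as the loss, i.e.,

$$J_s(\theta_s) = \mathbb{E}_{(\boldsymbol{x}^*, y^*) \sim \mathcal{D}}\left[-\log p(y = y^*|\boldsymbol{x}^*; \theta_s, \theta_r)\right] + \gamma \|\theta_s\|^2,$$

where $(\boldsymbol{x}^*, y^*)$ denotes a sample from dataset $\mathcal{D}$ and $\gamma$ is the weight decay factor. The loss above is used for learning $\theta_s$ by the back-propagation algorithm.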



| Model | $\lvert\theta\rvert$ | T(s)/epoch | Inference T(s) | Train Accuracy | Test Accuracy |
|---|---|---|---|---|---|
| 300D LSTM encoders [Bowman et al.2016] | 3.0m | | | 83.9 | 80.6 |
| 300D SPINN-PI encoders [Bowman et al.2016] | 3.7m | | | 89.2 | 83.2 |
| 600D Bi-LSTM encoders [Liu et al.2016] | 2.0m | | | 86.4 | 83.3 |
| 600D Bi-LSTM +intra-attention [Liu et al.2016] | 2.8m | | | 84.5 | 84.2 |
| 300D NSE encoders [Munkhdalai and Yu2017] | 3.0m | | | 86.2 | 84.6 |
| 600D Deep Gated Attn. [Chen et al.2017] | 11.6m | | | 90.5 | 85.5 |
| 600D Gumbel TreeLSTM encoders [Choi et al.2017b] | 10m | | | 93.1 | 86.0 |
| 600D Residual stacked encoders [Nie and Bansal2017] | 29m | | | 91.0 | 86.0 |
| Bi-LSTM [Graves et al.2013] | 2.9m | 2080 | 9.2 | 90.4 | 85.0 |
| Bi-GRU [Chung et al.2014] | 2.5m | 1728 | 9.3 | 91.9 | 84.9 |
| Multi-window CNN [Kim2014] | 1.4m | 284 | 2.4 | 89.3 | 83.2 |
| Hierarchical CNN [Gehring et al.2017] | 3.4m | 343 | 2.9 | 91.3 | 83.9 |
| Multi-head [Vaswani et al.2017] | 2.0m | 345 | 3.0 | 89.6 | 84.2 |
| DiSAN [Shen et al.2018] | 2.4m | 587 | 7.0 | 91.1 | 85.6 |
| 300D ReSAN | 3.1m | 622 | 5.5 | 92.6 | 86.3 |

Table 1: Experimental results for different methods on SNLI. $\lvert\theta\rvert$: the number of parameters (excluding word embedding part). T(s)/epoch: average training time (second) per epoch. Inference T(s): average inference time (second) for all dev data on SNLI with a batch size of .

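Optimizing $\theta_r$ is formulated as a reinforcement learning problem solved by the policy gradient method (i.e., REINFORCE algorithm). In particular, RSS acts as an agent and takes the action of whether to select a token or not. After going through the entire sequence, it receives a loss value from the classification problem, the negative of which can be regarded as the delayed reward to train the agent. Since the overall goal of RSS is to select a small subset of tokens for better efficiency and meanwhile retain useful information, a penalty limiting the number of selected tokens is included in the reward $R$, i.e.,

$$R = \log p(y = y^*|\boldsymbol{x}^*; \theta_s, \theta_r) - \lambda \sum_i \hat{z}_i / \mathrm{len}(\boldsymbol{x}^*),$$

where $\lambda$ is the penalty weight and is fine-tuned with values from in all experiments. Then, the objective of learning $\theta_r$ is to maximize the expected reward, i.e.,

$$J_r(\theta_r) = \mathbb{E}_{(\boldsymbol{x}^*, y^*) \sim \mathcal{D}}\left\{\mathbb{E}_{\hat{\boldsymbol{z}}}[R]\right\} \approx \frac{1}{N} \sum_{(\boldsymbol{x}^*, y^*)} \mathbb{E}_{\hat{\boldsymbol{z}}}[R],$$

where $\hat{\boldsymbol{z}}$ collects the selections sampled by the RSS modules and $N$ is the number of samples in the dataset. Based on REINFORCE, the policy gradient of $J_r(\theta_r)$ w.r.t. $\theta_r$ is

$$\nabla_{\theta_r} J_r(\theta_r) = \frac{1}{N} \sum_{(\boldsymbol{x}^*, y^*)} \mathbb{E}_{\hat{\boldsymbol{z}}}\left[R\, \nabla_{\theta_r} \log p(\hat{\boldsymbol{z}}|\boldsymbol{x}^*; \theta_r)\right].$$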
Although theoretically feasible, it is not practical to optimize $\theta_s$ and $\theta_r$ simultaneously from the start, since the neural nets cannot provide accurate reward feedback to the hard attention at the beginning of the training phase. Therefore, in the early stage, the RSS modules are not updated, but rather forced to select all tokens (i.e., $\hat{z}_i = 1$ for all $i$), and $\theta_s$ is optimized for the first several epochs until the loss over the development set no longer decreases significantly. The resulting ReSAN can then provide a solid environment for training the RSS modules through reinforcement learning. $\theta_r$ and $\theta_s$ are then optimized simultaneously to pursue better performance by selecting critical token pairs and exploring their dependencies.
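The policy-gradient signal for the selection probabilities can be sketched for a single sample as follows. The sketch assumes the reward is the log-likelihood of the correct label minus a penalty proportional to the fraction of selected tokens, and it uses the standard identity that the gradient of a Bernoulli log-probability with respect to its logit is `z - p`. The function names and the penalty weight value are hypothetical.

```python
import numpy as np

def rss_reward(log_likelihood, z_hat, penalty_weight):
    """Reward = log-likelihood of the correct label minus a selection-rate penalty (assumed form)."""
    return log_likelihood - penalty_weight * z_hat.sum() / z_hat.size

def reinforce_logit_gradient(probs, z_hat, reward):
    """Single-sample REINFORCE estimate of d(expected reward)/d(logits).

    For z_i ~ Bernoulli(sigmoid(logit_i)), d log p(z_i) / d logit_i = z_i - p_i.
    """
    return reward * (z_hat - probs)

z_hat = np.array([1, 0, 1, 1])            # sampled selections
probs = np.array([0.5, 0.5, 0.5, 0.5])    # selection probabilities
R = rss_reward(log_likelihood=-0.5, z_hat=z_hat, penalty_weight=0.2)
grad = reinforce_logit_gradient(probs, z_hat, R)
# A gradient-ascent step on the logits would be: logits += learning_rate * grad
```

In the warm-up stage described above, `z_hat` would simply be fixed to all ones and this gradient would not be applied; it only becomes useful once the rest of the network yields informative log-likelihoods.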

Training Setup:

All experiments are conducted in Python with TensorFlow and run on an Nvidia GTX 1080Ti. We use Adadelta as the optimizer, which performs more stably than Adam on ReSAN. All weight matrices are initialized by Glorot Initialization [Glorot and Bengio2010] and the biases are initialized as zeros. We use 300D GloVe 6B pre-trained vectors [Pennington et al.2014] to initialize the word embeddings [Liu et al.2018]. The words from the training set which do not appear in GloVe are initialized by sampling from uniform distribution between . We choose Dropout [Srivastava et al.2014] keep probability from for all models and report the best result. The weight decay factor for L2 regularization is set to . The number of hidden units is .

5 Experiments

We implement ReSAN, its variants and baselines on two NLP tasks, natural language inference in Section 5.1 and semantic relatedness in Section 5.2. A case study is then given in Section 5.3 to provide insight into the model.

The baselines are listed as follows: 1) Bi-LSTM: 600D bi-directional LSTM (300D forward LSTM + 300D backward LSTM) [Graves et al.2013]; 2) Bi-GRU: 600D bi-directional GRU [Chung et al.2014]; 3) Multi-window CNN: 600D CNN sentence embedding model (200D for each of 3, 4, 5-gram) [Kim2014]; 4) Hierarchical CNN: 3-layer 300D CNN [Gehring et al.2017] with kernel length 5. GLU [Dauphin et al.2016] and residual connection [He et al.2016b] are applied; 5) Multi-head: 600D multi-head attention (8 heads, each has 75 hidden units), where the positional encoding method is applied to the input [Vaswani et al.2017]; 6) DiSAN: 600D directional self-attention network (forward+backward masked self-attn) [Shen et al.2018].

5.1 Natural Language Inference

The goal of natural language inference is to infer the semantic relationship between a pair of sentences, i.e., a premise and the corresponding hypothesis. The possible relationships are entailment, neutral or contradiction. This experiment is conducted on the Stanford Natural Language Inference [Bowman et al.2015] (SNLI) dataset which consists of 549,367/9,842/9,824 samples for training/dev/test.

In order to apply a sentence encoding model to SNLI, we follow Bowman et al. [2016] and use two parameter-tied sentence encoding models to respectively produce the premise and the hypothesis encodings, i.e., $s^p$ and $s^h$. Their semantic relationship is represented by the concatenation of $s^p$, $s^h$, $s^p - s^h$ and $s^p \odot s^h$, which is passed to a classification module to generate a categorical distribution over the three classes.
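As a small illustration of the pair representation, the sketch below concatenates the two sentence encodings with their difference and element-wise product. This four-part form is the one commonly used with sentence-encoding models on SNLI and is assumed here; the function name `nli_features` and the toy vectors are hypothetical.

```python
import numpy as np

def nli_features(s_p, s_h):
    """Combine premise and hypothesis encodings into one feature vector.

    Assumed form: [s_p; s_h; s_p - s_h; s_p * s_h], so a d-dimensional
    encoder yields a 4d-dimensional input for the classifier.
    """
    return np.concatenate([s_p, s_h, s_p - s_h, s_p * s_h])

s_p = np.array([1.0, 2.0])  # toy premise encoding
s_h = np.array([3.0, 4.0])  # toy hypothesis encoding
features = nli_features(s_p, s_h)
```

Because both encodings come from the same parameter-tied encoder, the difference and product terms give the classifier direct access to feature-wise agreement and disagreement between the two sentences.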

The experimental results for different methods from the leaderboard and our baselines are shown in Table 1. Compared to the methods from the official leaderboard, ReSAN outperforms all the sentence encoding based methods and achieves the best test accuracy. Specifically, compared to the previous best models, i.e., 600D Gumbel TreeLSTM encoders and 600D Residual stacked encoders, ReSAN uses far fewer parameters with better performance. Moreover, in contrast to the RNN/CNN based models with attention or memory module, ReSAN uses attention-only modules with equal or fewer parameters but outperforms them by a large margin, e.g., 600D Bi-LSTM + intra-attention (+2.1%), 300D NSE encoders (+1.7%) and 600D Deep Gated Attn (+0.8%). Furthermore, ReSAN even outperforms the 300D SPINN-PI encoders by 3.1%, a recursive model that uses the result of an external semantic parsing tree as an extra input.

In addition, we compare ReSAN with recurrent, convolutional, and attention-only baseline models in terms of the number of parameters, training/inference time and test accuracy. Compared to the recurrent models (e.g., Bi-LSTM and Bi-GRU), ReSAN shows better prediction quality and more compelling efficiency due to parallelizable computations. Compared to the convolutional models (i.e., Multi-window CNN and Hierarchical CNN), ReSAN significantly outperforms them by 3.1% and 2.4% respectively due to the weakness of CNNs in modeling long-range dependencies. Compared to the attention-based models, multi-head attention and DiSAN, ReSAN uses a similar number of parameters with better test performance and less time cost.

| Model | $\lvert\theta\rvert$ | Inference T(s) | Test Accu. |
|---|---|---|---|
| ReSAN | 3.1m | 5.5 | 86.3 |
| ReSAN w/o unselected heads | 3.1m | 5.3 | 86.1 |
| ReSAN w/o dependency restricted | 2.8m | 4.6 | 85.6 |
| ReSAN w/o hard attention | 2.5m | 7.0 | 86.0 |
| ReSAN w/o soft self-attention | 1.0m | 1.6 | 83.4 |
| ReSAN w/o all attentions | 0.5m | 1.8 | 83.1 |

Table 2: An ablation study of ReSAN.

Further, we conduct an ablation study of ReSAN, as shown in Table 2, to evaluate the contribution of each component. One by one, each component is removed and the changes in test accuracy are recorded. In addition to the two variants of ReSAN introduced in Section 3.3, we also remove 1) the hard attention module, 2) soft self-attention module and 3) both hard attention and soft self-attention modules. In terms of prediction quality, the results show that 1) the unselected head tokens do contribute to the prediction, bringing 0.2% improvement; 2) using separate RSS modules to select the head and dependent tokens improves accuracy by 0.5%; and 3) hard attention and soft self-attention modules improve the accuracy by 0.3% and 2.9% respectively. In terms of inference time, it shows that 1) the two variants are more time-efficient but have poorer performance; and 2) applying the RSS modules to self-attention or attention improves not only performance but also time efficiency.

5.2 Semantic Relatedness

Semantic relatedness aims to predict the similarity degree of a given pair of sentences, which is formulated as a regression problem. We use $s^1$ and $s^2$ to denote the encodings of the two sentences, and assume the similarity degree is a scalar between $[1, K]$. Following Tai et al. [2015], the relationship between the two sentences is represented as a concatenation of $s^1 \odot s^2$ and $|s^1 - s^2|$. The representation is fed into a classification module with $K$-way categorical distribution output. We implement ReSAN and baselines on the Sentences Involving Compositional Knowledge [Marelli et al.2014] (SICK) dataset, which provides the ground truth as similarity degree between $[1, 5]$. SICK comes with a standard training/dev/test split of 4,500/500/4,927 samples.

| Model | Pearson’s r | Spearman’s ρ | MSE |
|---|---|---|---|
| Meaning Factory [Bjerva et al.2014] | .8268 | .7721 | .3224 |
| ECNU [Zhao et al.2014] | .8414 | / | / |
| DT-RNN [Socher et al.2014] | .7923 (.0070) | .7319 (.0071) | .3822 (.0137) |
| SDT-RNN [Socher et al.2014] | .7900 (.0042) | .7304 (.0042) | .3848 (.0042) |
| Cons. Tree-LSTM [Tai et al.2015] | .8582 (.0038) | .7966 (.0053) | .2734 (.0108) |
| Dep. Tree-LSTM [Tai et al.2015] | .8676 (.0030) | .8083 (.0042) | .2532 (.0052) |
| Bi-LSTM | .8473 (.0013) | .7913 (.0019) | .3276 (.0087) |
| Bi-GRU | .8572 (.0022) | .8026 (.0014) | .3079 (.0069) |
| Multi-window CNN | .8374 (.0021) | .7793 (.0028) | .3395 (.0086) |
| Hierarchical CNN | .8436 (.0014) | .7874 (.0022) | .3162 (.0058) |
| Multi-head | .8521 (.0013) | .7942 (.0050) | .3258 (.0149) |
| DiSAN | .8695 (.0012) | .8139 (.0012) | .2879 (.0036) |
| ReSAN | .8720 (.0014) | .8163 (.0018) | .2623 (.0053) |

Table 3: Experimental results for different methods on the SICK semantic relatedness dataset. The reported results are the mean of five runs (standard deviations in parentheses). Cons. and Dep. represent Constituency and Dependency, respectively.

The results in Table 3 show that ReSAN achieves state-of-the-art or competitive performance for all three metrics. Particularly, ReSAN outperforms the feature engineering methods, e.g., Meaning Factory and ECNU, by a large margin. ReSAN also significantly outperforms the recursive models, which are widely used in the semantic relatedness task, especially ones that demand external parsing results, e.g., DT/SDT-RNN and Tree-LSTMs. Further, ReSAN achieves the best results among all the recurrent, convolutional and self-attention models listed as baselines. This demonstrates the capability of ReSAN in context fusion and sentence encoding.

5.3 Case Study

To gain insight into how the hard/soft attention and fusion gate work within ReSA, we visualize their resulting values in this section. Note that only the values at token level are illustrated. If the attention probabilities and the gate values are feature-level, we average the probabilities over all features.

Two sentences from the SNLI test set serve as examples for this case study: 1) “The three men sit and talk about their lives.” and 2) “A group of adults are waiting for an event.”.

(a) Sentence 1
(b) Sentence 2
Figure 2: Attention probabilities of soft self-attention in ReSA. The tokens aligned in horizontal axis are heads, and the tokens aligned in vertical axis are dependents.

The head and dependent tokens selected by the RSS modules are shown in Figure 2 (a small white square denotes an unselected token pair, and a colored square a selected one). It shows that more dependent tokens are selected than head tokens, because all non-trivial dependents should be retained to adequately modify the corresponding heads, e.g., three, their in sentence 1 and group in sentence 2, whereas only the key heads should be kept to compose the trunk of a sentence. It also shows that most stop words (i.e., articles, conjunctions, prepositions, etc.) are selected as neither head tokens nor dependent tokens.

We also visualize the probability distributions of the soft self-attention module in Figure 2 (the depth of the blue color). From the figure, we observe that 1) the semantically important words (e.g., nouns and verbs) usually receive great attention from all the other tokens, e.g., sit, talk, lives in sentence 1 and adults, waiting, event in sentence 2; and 2) the attention score increases if the token pair can constitute a sense-group, e.g., (sit, talk) in sentence 1 and (adults, waiting), (waiting, event) in sentence 2.

6 Related Work

Applying reinforcement learning (RL) to natural language processing (NLP) tasks has recently attracted enormous interest for two main purposes, i.e., optimizing the model according to non-differentiable objectives and accelerating the model. Lei et al. [2016] propose a method to select a subset of a review passage for sentiment analysis from a specific aspect. He et al. [2016a] use an RL method to fine-tune a bilingual machine translation model with well-trained monolingual language models. Yogatama et al. [2016] use a built-in transition-based parsing module, trained by RL, to generate a semantic constituency parsing tree for downstream NLP tasks. Yu et al. [2017] propose an RL-based skim reading method, which is implemented on recurrent models, to skip the insignificant time slots and achieve higher time efficiency. Choi et al. [2017a] separately implement a hard attention or a soft attention on a question answering task to generate the document summary. Shen et al. [2017] use a dynamic episode number determined by RL, rather than a fixed one, to attend memory for efficient machine comprehension. Hu et al. [2017] employ a policy gradient method to optimize the model for non-differentiable objectives of machine comprehension, i.e., the F1 score of matching the prediction with the ground truth. Li et al. [2017a] propose a service dialog system to sell movie tickets, where the RL agent selects which piece of user information should be obtained in the next round, so as to minimize the number of dialog rounds needed to sell the ticket. Zhang and Lapata [2017] simplify a sentence with the objectives of maximum simplicity, relevance and fluency, where all three objectives are non-differentiable w.r.t. the parameters of the model.

7 Conclusions

This study presents a context fusion model, reinforced self-attention (ReSA), which naturally integrates a novel form of highly parallelizable hard attention based on reinforced sequence sampling (RSS) with a soft self-attention mechanism, so that each overcomes the intrinsic weaknesses of the other. The hard attention modules can be used to trim a long sequence into a much shorter one and to encode rich dependency information for a soft self-attention mechanism to process. Conversely, the soft self-attention mechanism can be used to provide a stable environment and strong reward signals, which improves the feasibility of training the hard attention modules. Based solely on ReSA and a source2token self-attention mechanism, we then propose an RNN/CNN-free attention model, reinforced self-attention network (ReSAN), for sentence encoding. Experiments on two NLP tasks, natural language inference and semantic relatedness, demonstrate that ReSAN delivers a new best test accuracy on the SNLI dataset among all sentence-encoding models and state-of-the-art performance on the SICK dataset. Further, these results are achieved with equal or fewer parameters and in less time.


Acknowledgments

This research was funded by the Australian Government through the Australian Research Council (ARC) under grants 1) LP160100630, in partnership with the Australian Government Department of Health, and 2) LP150100671, in partnership with the Australian Research Alliance for Children and Youth (ARACY) and Global Business College Australia (GBCA). We also acknowledge the support of NVIDIA Corporation and MakeMagic Australia with the donation of GPUs.


Appendix A Comparison to Iterative Sampling

To verify that RSS, which uses parallel discrete sampling, is sufficient to trim a long sentence and model its dependencies, we implement the iteration-based sequence sampling method of Lei et al. (2016) and integrate it with soft self-attention in the same way as ReSA.

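Given an input sequence, $x = [x_1, \dots, x_n]$, iterative sampling aims to learn the following product distribution over the selection variables $z = [z_1, \dots, z_n]$:

$$p(z \mid x) = \prod_{i=1}^{n} p(z_i \mid x, z_1, \dots, z_{i-1})$$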

An RNN is used to parameterize the conditional probability function above; a basic RNN, rather than an LSTM or GRU, is employed to reduce the number of parameters. The latent state of the RNN can be regarded as an embedding of both the contextual information and the history of selection results. The recurrence can be formally written as


In this recurrence, the selection at each step is drawn by a discrete sampling operation, and the RNN weights are the learnable parameters. Consequently, after the recurrence has run over the input sequence, a sequence of sampling results, $z = [z_1, \dots, z_n]$, is obtained, which has the same meaning as in RSS.
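
The iterative procedure described above can be illustrated with a small sketch. It is not the authors' implementation: the weight names (`W_x`, `W_h`, `w_z`, `w_out`), the tanh state update, the sigmoid output, and the helper names are all assumptions chosen for illustration, and for brevity the state at step i only sees the tokens up to position i rather than the whole sentence. What it is meant to show is the sequential dependency: each selection is fed back into the latent state before the next token is processed, which is why the steps cannot run in parallel.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def init_params(d_in, d_h, rng, scale=0.1):
    """Random weights for a basic RNN. All names here are hypothetical."""
    rand_vec = lambda d: [scale * rng.gauss(0.0, 1.0) for _ in range(d)]
    return {
        "W_x": [rand_vec(d_in) for _ in range(d_h)],  # input -> state
        "W_h": [rand_vec(d_h) for _ in range(d_h)],   # state -> state
        "w_z": rand_vec(d_h),                         # previous selection -> state
        "b": [0.0] * d_h,
        "w_out": rand_vec(d_h),                       # state -> selection logit
        "b_out": 0.0,
    }

def iterative_sampling(x, params, rng):
    """Sample one binary selection per token, strictly left to right."""
    d_h = len(params["b"])
    h = [0.0] * d_h   # latent state: context plus selection history
    z_prev = 0.0      # no selection has been made before the first token
    z, probs = [], []
    for x_i in x:
        # The new state depends on the current token, the old state,
        # and the previous sampling result.
        h = [
            math.tanh(
                dot(params["W_x"][k], x_i)
                + dot(params["W_h"][k], h)
                + params["w_z"][k] * z_prev
                + params["b"][k]
            )
            for k in range(d_h)
        ]
        p_i = sigmoid(dot(params["w_out"], h) + params["b_out"])
        z_i = 1 if rng.random() < p_i else 0  # discrete (Bernoulli) sampling
        z.append(z_i)
        probs.append(p_i)
        z_prev = float(z_i)
    return z, probs

rng = random.Random(0)
x = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(6)]  # 6 toy tokens
z, probs = iterative_sampling(x, init_params(4, 8, rng), rng)
print(z, [round(p, 3) for p in probs])
```

By contrast, a parallel sampler would compute every selection probability from the input alone and draw all selections at once, with no loop-carried dependence on earlier samples.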

We then apply two iterative sampling modules, which make selections over the dependent and head tokens, respectively. The output of these two sampling modules is formatted as a mask, which is then applied to the compatibility function of the soft self-attention mechanism. The details of the integration are described in the main paper.
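
As a rough illustration of how two selection vectors can be turned into a mask on attention scores, consider the sketch below. It rests on several assumptions that are not stated in this appendix: a pair (i, j) is kept only when dependent i and head j are both selected and i differs from j, the mask is added to the scores before a row-wise softmax, and a large finite constant stands in for negative infinity. The function names are hypothetical.

```python
import math

NEG_INF = -1e30  # finite stand-in for -infinity (assumption; avoids NaN rows)

def selection_mask(z_dep, z_head):
    """0 where dependent i and head j are both selected and i != j, else NEG_INF."""
    n = len(z_dep)
    return [
        [0.0 if (z_dep[i] == 1 and z_head[j] == 1 and i != j) else NEG_INF
         for j in range(n)]
        for i in range(n)
    ]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention(scores, z_dep, z_head):
    """Add the mask to the compatibility scores, then normalize each row."""
    mask = selection_mask(z_dep, z_head)
    return [
        softmax([s + m for s, m in zip(s_row, m_row)])
        for s_row, m_row in zip(scores, mask)
    ]

scores = [[0.0] * 3 for _ in range(3)]  # toy compatibility scores
attn = masked_attention(scores, z_dep=[1, 0, 1], z_head=[1, 1, 0])
for row in attn:
    print([round(a, 3) for a in row])
```

In this sketch, a row whose entries are all masked falls back to a uniform distribution. That is an artifact of using a finite constant rather than a claim about how the model treats unselected tokens.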

To compare RSS with iterative sampling, we also implement ReSAN with iterative sampling on the SNLI dataset, one of the largest NLP datasets designed to test sentence-encoding models. A thorough comparison in terms of parameter count, training/inference time, and training/test accuracy is shown in Table 4.

| | ReSAN w/ RSS | ReSAN w/ Iteration |
|---|---|---|
| Parameter Num (300D) | 3.1m | 4.0m |
| Time/Epoch | 622s | 2996s |
| Inference Time | 5.5s | 17.1s |
| Train Accuracy | 92.6% | 92.3% |
| Test Accuracy | 86.3% | 86.2% |

Table 4: A thorough comparison of ReSAN with RSS and ReSAN with iterative sampling on the SNLI dataset. The accuracies of the two models are expected to be equal; the small differences are experimental error due to the randomness of neural network training (e.g., initialization, batch SGD).

As shown in the table, compared with ReSAN with iterative sampling, ReSAN with RSS requires far fewer parameters and less training and inference time to achieve competitive test accuracy. This is consistent with the motivation for developing RSS.
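
To make the size of these gaps concrete, the ratios can be computed directly from the figures reported in Table 4. The snippet below does only that arithmetic; the dictionary layout and key names are our own.

```python
# Figures copied from Table 4 (parameters in millions, times in seconds).
table4 = {
    "params_millions": {"rss": 3.1, "iteration": 4.0},
    "train_seconds_per_epoch": {"rss": 622.0, "iteration": 2996.0},
    "inference_seconds": {"rss": 5.5, "iteration": 17.1},
}

# How many times larger or slower the iterative variant is than RSS.
ratios = {name: v["iteration"] / v["rss"] for name, v in table4.items()}

for name, ratio in ratios.items():
    print(f"{name}: iteration / RSS = {ratio:.2f}x")
```

On these numbers, the iterative variant takes roughly 4.8 times as long per training epoch and roughly 3.1 times as long at inference, with about 1.3 times as many parameters.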