TLDR: Token Loss Dynamic Reweighting for Reducing Repetitive Utterance Generation

03/26/2020 ∙ by Shaojie Jiang, et al. ∙ Hugging Face, Inc. ∙ University of Amsterdam

Natural Language Generation (NLG) models are prone to generating repetitive utterances. In this work, we study the repetition problem for encoder-decoder models, using both recurrent neural network (RNN) and transformer architectures. To this end, we consider the chit-chat task, where the problem is more prominent than in other tasks that need encoder-decoder architectures. We first study the influence of model architectures. By using pre-attention and highway connections for RNNs, we manage to achieve lower repetition rates. However, this method does not generalize to other models such as transformers. We hypothesize that the deeper reason is that the training corpora contain hard tokens that are more difficult for a generative model to learn than others, and that, when training ends, these hard tokens are still under-learned, so that repetitive generations are more likely to happen. Based on this hypothesis, we propose token loss dynamic reweighting (TLDR), which applies differentiable weights to individual token losses. By using higher weights for hard tokens and lower weights for easy tokens, NLG models are able to learn individual tokens at different paces. Experiments on chit-chat benchmark datasets show that TLDR is more effective in repetition reduction for both RNN and transformer architectures than baselines using different weighting functions.


1 Introduction

It has been widely reported that Natural Language Generation (NLG) models are prone to generating repetitive utterances; the problem cuts across tasks and architectures. For example, repetition has been observed in question answering (QA) (fan2019eli5), language modeling (LM) (holtzman2019curious; keskar2019ctrl; welleck2019neural), abstractive summarization (suzuki2017cutting; nallapati2016abstractive), machine translation (MT) (tu2016modeling; mi2016coverage; tu2017context), and image captioning (cornia2018paying). The problem arises with the most popular neural architectures, including sequence-to-sequence (Seq2Seq) modeling using recurrent neural networks (RNNs) (sutskever2014sequence) and transformers (vaswani2017attention). Previous approaches to the repetition problem have mostly been developed for decoder-only models (holtzman2019curious; keskar2019ctrl; welleck2019neural). Whenever it is possible to use them, encoder-decoder models usually outperform decoder-only models (raffel2019exploring); because of this, we study the repetition problem for encoder-decoder models in this work. We base our experiments on the chit-chat task,¹ where the repetition problem is more prominent than in other tasks that need encoder-decoder architectures. Figure 1 is an illustration of a chit-chat scenario.

¹ Throughout this paper, we use 'chit-chat,' 'open-domain dialogue,' and 'response generation' to refer to the same task interchangeably, and use 'chatbot' to refer to a model trained for this task.

User Message: Hi, Jim. How are you? I haven’t seen you for a while.
System Response: i’ve been out of town. i’ve been out of town. i’ve been out of town and got some old ideas on i’ve been out of town.
Figure 1: An illustration of the repetitive response problem.

The repetition problem on which we focus in this paper is not to be confused with the problem where responses of different turns share a great deal of similarity, which is sometimes also referred to as repetitive responses (li2016deep).

We first surmise that the repetition problem is caused by model architectures. By studying existing architecture-based approaches to repetition reduction, we identify an RNN design that helps reduce repetition. More specifically, by using pre-attention (§3) together with RNNs instead of the conventional post-attention, the model generates less repetitive utterances, and adding a highway connection (srivastava2015training) after pre-attention reduces repetition further. Though effective for RNNs, this architectural design unfortunately does not generalize to transformers. Therefore, we believe that there are deeper reasons for the repetition problem.

Inspired by the well-studied class imbalance problem (lin2017focal), we hypothesize that training corpora contain hard tokens and easy tokens; if training stops while hard tokens are still under-learned, repetitive generations are more likely to happen. We empirically validate this hypothesis using the Focal Loss (FL) method proposed by lin2017focal. Based on this finding, we propose token loss dynamic reweighting (TLDR), which is inspired by FL (§4) and which assigns differentiable weights to training losses according to the training difficulty of tokens. By dynamically using higher weights for hard tokens and lower weights for easy tokens, the model is able to learn tokens at different paces. Experiments on benchmark chit-chat datasets show that TLDR is more effective in reducing repetitions for RNN and transformer architectures than the baselines.

The main contributions of this work include:


  • We discover that RNNs with pre-attention and highway connections can help reduce repetition.

  • We hypothesize that the deeper reason for repetition lies in hard tokens, and we empirically validate this using FL.

  • We propose a more effective token loss dynamic reweighting (TLDR) method for reducing repetitive generations that is architecture insensitive.

  • We share our source code at https://github.com/ShaojieJiang/tldr.

2 Related Work

2.1 Repetition reduction

There are generally three types of approaches to the repetition problem: architecture-based methods, auxiliary-loss methods, and sampling-based methods.

Architecture-based methods try to address the repetition problem by exerting inductive biases on model architectures, such as introducing coverage vectors to attention mechanisms (tu2016modeling; mi2016coverage). However, these solutions rely on the source-target correspondence property of neural machine translation (NMT), which rarely holds for other tasks like open-domain dialogue generation and abstractive summarization. Instead, we observe that by using pre-attention (§3), RNN models can generate less repetitive responses, and adding a highway connection to attention layers reduces repetition further. However, it is not yet clear which modifications to transformers can help reduce repetitive generations, which we surmise is mainly due to their lack of temporal inner states.

Auxiliary-loss methods add an unlikelihood term for repeated tokens to the conventional cross-entropy (CE) loss (welleck2019neural; li2019don) to discourage repetitive generations. However, besides needing extra data types (li2019don) or extra processing of model generations (welleck2019neural), it is worth noting that the repetition problem seldom exists in the training data (§6); the models are therefore not trained to be repetitive in the first place. Our experiments show that the repetition problem is mainly due to trained models having different levels of proficiency on easy and hard tokens. Based on this, we propose to weight the token losses w.r.t. their training difficulties to address the repetition problem. Though hard tokens normally already produce larger gradients, this is apparently not enough to learn them well, as their effects may be diluted by the large number of easy tokens.

Sampling-based approaches discourage repetition at the sampling stage, e.g., by creating an n-gram blacklist (paulus2017deep). More sophisticated methods include discounting the scores of previously generated tokens (keskar2019ctrl), top-k random sampling (fan2018hierarchical), and top-p sampling with a dynamic nucleus (holtzman2019curious). These methods can only be used at inference time to remedy the repetition problem. In this work, we tackle the problem at a deeper level and show that by dynamically reweighting token losses at training time, the problem can be largely alleviated.

2.2 Class imbalance

Class imbalance is a common cause of poor performance of trained models. To mitigate the low efficiency of heuristic sampling approaches to balancing classes, lin2017focal balance the CE loss by weighting easy and hard examples differently for visual object detection. Similarly, FACE (jiang2019improving) balances the CE loss between frequent and rare tokens through heuristics such as token frequency, so as to improve generation diversity. In our case, however, the frequency of a token does not necessarily reflect its training difficulty, as the difficulty is also influenced by the token's context.

We hypothesize that the repetition problem is also caused by the different training difficulties of tokens. This hypothesis is in line with fadaee2018back, who improve the learning of hard tokens by using back-translation and by sampling examples containing difficult words in similar contexts. Since balancing easy and hard tokens through sampling is infeasible at the token level in the chit-chat task, loss reweighting such as Focal Loss (lin2017focal) is a preferable solution. Focal Loss down-weights both easy and hard examples (models trained without weighting can be thought of as using a uniform weight of 1), albeit at different ratios, which is likely to slow down training. Instead, we propose a different weighting function that simultaneously up-weights hard tokens and down-weights easy tokens.

3 RNN with Pre-Attention and Highway Connection

In this section, we introduce some architectural modifications that alleviate the repetition problem.

3.1 Notation and response generation

We write X = (x_1, x_2, ..., x_{|X|}) and Y = (y_1, y_2, ..., y_{|Y|}) for message and response utterances, respectively. Given X, the objective of a chatbot model parameterized by θ is to assign a higher conditional probability P(Y | X; θ) to the ground-truth response than to other responses, which usually follows a sequential decomposition:

    P(Y | X; θ) = ∏_{t=1}^{|Y|} p(y_t | y_{<t}, X; θ).    (1)

Here, y_{<t} = (y_1, ..., y_{t-1}), and p(y_t | y_{<t}, X; θ) represents the model-predicted probability for token y_t.

For the sake of clarity, we further denote p_t = p(y_t | y_{<t}, X; θ). Training the above chatbot model is usually achieved by minimizing the CE loss:

    L_CE = −∑_{t=1}^{|Y|} log p_t.    (2)

For estimating p_t using an RNN model, we largely follow the notation in (bahdanau2015neural), denoting the decoder output and the attention output at decoding step t as s_t and c_t, respectively. Readers are referred to (bahdanau2015neural) for more details.
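To make Eqs. (1)-(2) concrete, here is a minimal PyTorch sketch of the per-token CE loss for a single response; the function name and tensor shapes are our own illustration, not code from the paper.

    import torch
    import torch.nn.functional as F

    def ce_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Eq. (2). logits: (|Y|, vocab_size); targets: gold token ids y_t."""
        log_probs = F.log_softmax(logits, dim=-1)  # log p(. | y_<t, X; θ)
        token_log_p = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p_t
        return -token_log_p.sum()  # L_CE = -sum_t log p_t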

3.2 Pre-attention

Inspired by Input-feeding (IF) attention (luong2015effective), we find that pre-attention performs better w.r.t. repetition rates. Figure 2 is an illustration of the difference between pre-attention and IF attention.

Figure 2: Pre-attention vs. IF attention. [·;·] denotes vector concatenation.

Formally, the input of the RNN using IF attention at each decoding step t, i.e., x̃_t, is modified:

    x̃_t = [x_t; s̃_{t−1}],    (3)

with [·;·] denoting vector concatenation. Besides, at step t the decoder output s_t is also transformed with the attention output c_t:

    s̃_t = f([s_t; c_t]).    (4)

The function f used in IF attention is a vector concatenation followed by a linear projection, to resize [s_t; c_t] to the dimension of s_t.

We find that only transforming the input is more helpful for reducing repetition, which leads us to pre-attention:

    x̃_t = g(x_t, c_t),    (5)

with g being a transformation function and c_t having the same dimensionality as x_t. The function g can take the form of f, but in this paper we observe that by using a highway connection for g, the repetition rate can be further reduced. The highway connection is formulated as follows:

    g(x_t, c_t) = g_t ⊙ x_t + (1 − g_t) ⊙ c_t,    (6)

where ⊙ represents the element-wise product; g_t is a learnable gating vector, through which the model learns how to effectively combine the attentional information c_t and the input information x_t. Readers are referred to (srivastava2015training) for details of the highway connection.
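A minimal PyTorch sketch of Eqs. (5)-(6) follows. The paper does not spell out how the gate g_t is computed; we assume the common parameterization g_t = σ(W_g[x_t; c_t] + b_g) from srivastava2015training, and that c_t has already been resized to the dimensionality of x_t.

    import torch
    import torch.nn as nn

    class HighwayPreAttention(nn.Module):
        """Combine the step input x_t and attention context c_t via a learned gate."""
        def __init__(self, dim: int):
            super().__init__()
            self.gate = nn.Linear(2 * dim, dim)  # W_g and b_g (assumed form)

        def forward(self, x_t: torch.Tensor, c_t: torch.Tensor) -> torch.Tensor:
            g_t = torch.sigmoid(self.gate(torch.cat([x_t, c_t], dim=-1)))
            return g_t * x_t + (1.0 - g_t) * c_t  # Eq. (6); fed to the RNN as its input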

3.3 Discussion

The attention mechanism has no internal state to keep track of the attention previously paid; therefore, when using post-attention, similar attention patterns may be produced repeatedly, increasing the chance of repetition. When pre-attention is used as part of the input to the RNN, by contrast, the RNN states can keep track of the previous attention status. Furthermore, with highway connections the model learns to combine context and step-wise input more effectively than with simple concatenation and linear projection.

However, noticing that the transformer model (vaswani2017attention) already has an architecture analogous to our pre-attention, we tried applying highway connections to the transformer attention outputs, but have not yet succeeded in reducing repetitive generations this way. We suspect that this is due to the lack of temporal inner states in transformers, and more work needs to be done to mitigate this drawback in order to reduce repetition. Meanwhile, we believe there are more fundamental reasons for the repetition problem.

4 Token Loss Dynamic Reweighting

Inspired by lin2017focal, we hypothesize that the difficulty of some training examples is the main reason for repetition. Based on this hypothesis, example reweighting methods such as FL (lin2017focal) can be used to alleviate the repetition problem. In this section, we first adapt FL to our chit-chat task, and later we present a more effective weighting method than FL.

4.1 Example-level loss reweighting

Using the notation introduced in §3.1 and writing p for P(Y | X; θ), a direct application of FL to the response generation task can be formulated as follows:

    L_FL = (1 − p)^γ L(Y).    (7)

Here, (1 − p)^γ is the weighting coefficient, which is scaled by an exponential focusing parameter γ, with 1 − p reflecting the training difficulty of an example, that is, an (X, Y) pair. L(Y) denotes the loss of predicting Y, which is usually calculated as an average over the CE losses of all tokens that make up Y:

    L(Y) = −(1/|Y|) ∑_{t=1}^{|Y|} log p_t.    (8)

Since p = ∏_{t=1}^{|Y|} p_t, in practice p is usually extremely small according to the sequential decomposition in Eq. (1), turning almost all training examples into hard examples. Therefore, we propose to approximate the training difficulty of Y with the average of the token probabilities p_t, as follows:

    p̄ = (1/|Y|) ∑_{t=1}^{|Y|} p_t,    (9)

which is more suitable for our task.

However, FL has another unwanted feature. Although the weight (1 − p)^γ in Eq. (7) is higher for hard examples than for easy ones, so that the model focuses on examples that are hard to train (which is essentially what the method is named after), it is likely to slow down training, as the weight is always smaller than 1. To address this, we propose a different weighting function:

    w = cos(π p̄) + 1,    (10)

which projects all probabilities to the weight domain [0, 2]. With p̄ < 0.5 indicating a hard example, which results in w > 1, we up-weight the loss of hard examples; otherwise we down-weight easy examples with w < 1. To distinguish our method (using Eq. (10)) from the adapted FL, we refer to it as loss dynamic reweighting (LDR):

    L_LDR = (cos(π p̄) + 1) L(Y).    (11)
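To illustrate Eqs. (9)-(11), here is a hedged PyTorch sketch of the LDR loss for a single (X, Y) pair; the names are ours. The weight is deliberately kept differentiable, matching the gradient analysis in §4.3.

    import math
    import torch

    def ldr_loss(token_log_p: torch.Tensor) -> torch.Tensor:
        """token_log_p holds log p_t for each gold token of one response Y."""
        p_bar = token_log_p.exp().mean()      # Eq. (9): average token probability
        w = torch.cos(math.pi * p_bar) + 1.0  # Eq. (10): w in [0, 2]
        loss_y = -token_log_p.mean()          # Eq. (8): L(Y)
        return w * loss_y                     # Eq. (11)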

4.2 Token-level loss reweighting

Noticing from Eq. (2) that in our task CE losses are calculated for each token, we can alternatively infer the difficulty at the token level and apply the FL or loss dynamic reweighting (LDR) methods thereafter. To achieve this, we use the token probability p_t instead of the sequence probability p in Eq. (7), weighting each token's CE loss individually, resulting in token-level Focal Loss (TFL):

    L_TFL = −∑_{t=1}^{|Y|} (1 − p_t)^γ log p_t,    (12)

where p_t = p(y_t | y_{<t}, X; θ). Using p_t instead of p̄, and the token loss −log p_t instead of L(Y), in Eq. (11) results in token loss dynamic reweighting (TLDR):

    L_TLDR = −∑_{t=1}^{|Y|} (cos(π p_t) + 1) log p_t.    (13)

Although treating each token as an example is somewhat counterintuitive, it has been shown in (jiang2019improving) that such a token-level formulation can have a more direct impact on the generation of each token. We confirm in §6 that TLDR results in less repetition than LDR.
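A batched PyTorch sketch of Eq. (13) might look as follows; the padding id and the normalization over non-pad tokens are our assumptions, not details from the paper.

    import math
    import torch
    import torch.nn.functional as F

    def tldr_loss(logits: torch.Tensor, targets: torch.Tensor,
                  pad_id: int = 0) -> torch.Tensor:
        """logits: (batch, seq_len, vocab); targets: (batch, seq_len) gold ids."""
        log_p = F.log_softmax(logits, dim=-1)
        token_log_p = log_p.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # log p_t
        w = torch.cos(math.pi * token_log_p.exp()) + 1.0  # per-token weight in [0, 2]
        mask = (targets != pad_id).float()  # ignore padding positions
        return -(w * token_log_p * mask).sum() / mask.sum().clamp(min=1.0)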

4.3 Gradient analysis

We now compare the gradients of our loss function TLDR to those of CE and TFL w.r.t. p_t. The gradient function of CE is:

    ∂L_CE/∂p_t = −1/p_t.    (14)

The gradient function of TFL is:

    ∂L_TFL/∂p_t = γ(1 − p_t)^{γ−1} log p_t − (1 − p_t)^γ / p_t.    (15)

Finally, the gradient function of TLDR is:

    ∂L_TLDR/∂p_t = π sin(π p_t) log p_t − (cos(π p_t) + 1) / p_t.    (16)

Figure 3: Gradient curves of each loss function. For TFL, we use γ = 2 as recommended by (lin2017focal).

To see the differences between the three loss functions more clearly, we visualize their curves in Figure 3. Compared to CE and TFL, the gradients for easy tokens derived by TLDR are suppressed and gradients for hard tokens are amplified.
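As a quick numerical sanity check of Eqs. (14)-(16), the short script below (our own illustration) evaluates the three gradients at a hard token (p_t = 0.1) and an easy token (p_t = 0.9), with γ = 2 as in Figure 3.

    import math

    def grads(p: float, gamma: float = 2.0):
        ce = -1.0 / p  # Eq. (14)
        tfl = (gamma * (1 - p) ** (gamma - 1) * math.log(p)
               - (1 - p) ** gamma / p)  # Eq. (15)
        tldr = (math.pi * math.sin(math.pi * p) * math.log(p)
                - (math.cos(math.pi * p) + 1) / p)  # Eq. (16)
        return ce, tfl, tldr

    for p in (0.1, 0.9):
        print(p, grads(p))
    # At p = 0.1, TLDR's gradient magnitude exceeds CE's (hard token amplified);
    # at p = 0.9, it is far smaller than CE's (easy token suppressed).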

5 Experimental Setup

For our experiments, we use three multi-turn chit-chat datasets, in increasing order of size: Dailydialog (li2017dailydialog), Cornell Movie Dialogs (danescu2011chameleons), and Ubuntu Dialogue (lowe2015ubuntu). As a quantitative measure of system-generated responses, we use BLEU-4 scores (papineni2002bleu). For repetition measurement, existing work uses repeated n-gram counts (mi2016coverage) or n-gram diversity (li2016diversity; welleck2019neural), which reflect only one order of n-gram repetition (usually 4-grams). Instead, we propose a DIversity MEtric based on N-grams (DIMEN), inspired by BLEU, that weighted-averages the n-gram diversities at multiple granularities and is thus more comprehensive. To introduce DIMEN, we first recall the definition of n-gram diversity proposed by li2016diversity:

    d_n(Y) = |set(ngrams(Y, n))| / |ngrams(Y, n)|.    (17)

Here, ngrams(Y, n) returns the list of n-grams of Y for n ≥ 1; set(·) maps a list to a set, and |·| denotes the cardinality of a set or list. For Y with fewer than n tokens, we clip the denominator to 1 such that d_n(Y) = 0. With n-gram diversity defined, the DIMEN score of a given Y is calculated as

    DIMEN(Y) = ∑_{n=1}^{N} w_n d_n(Y),    (18)

where N is the highest order of n-grams and w_n is the weight coefficient corresponding to d_n(Y), with ∑_{n=1}^{N} w_n = 1.

The argument Y can be either a single response utterance, in which case we refer to the score as u-DIMEN, or a list of response utterances for a whole validation set, in which case we refer to it as l-DIMEN. The u-DIMEN and l-DIMEN scores are complementary: a high u-DIMEN score represents a non-repetitive utterance, and a high l-DIMEN score represents a diverse list of utterances.
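The sketch below (our own Python illustration) computes Eqs. (17)-(18) for a tokenized utterance; applying it to the concatenation of all responses in a validation set gives l-DIMEN. The uniform weights are an assumption for this example.

    def ngrams(tokens, n):
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    def diversity(tokens, n):
        grams = ngrams(tokens, n)
        return len(set(grams)) / max(len(grams), 1)  # clipped denominator, Eq. (17)

    def dimen(tokens, weights):
        # Eq. (18); the weights w_1, ..., w_N are assumed to sum to 1.
        return sum(w * diversity(tokens, n) for n, w in enumerate(weights, start=1))

    # u-DIMEN of the repetitive response in Figure 1 (naively tokenized):
    resp = "i've been out of town . i've been out of town .".split()
    print(dimen(resp, [0.25, 0.25, 0.25, 0.25]))  # well below 1, signaling repetition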

To further measure the repetition performance on a validation set, we could average the u-DIMEN scores of all utterances. However, since the majority of system-generated utterances are non-repetitive, the averaged u-DIMEN score can be very close to 1 and thus fails to distinguish between models. Instead, we first bin the u-DIMEN scores into buckets to create a histogram vector h, and then calculate a weighted L2-norm of the bucket counts to emphasize the overall repetition performance of a model. Formally,

    WL2 = ‖w ⊙ h‖_2,    (19)

where w is a weighting vector. When we assign higher weights to bins with lower u-DIMEN scores, the WL2 score emphasizes highly repetitive utterances.
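A short sketch of Eq. (19), again under illustrative assumptions: five equal-width bins over [0, 1] and linearly decreasing bin weights (the paper's exact bin count and weighting vector are not reproduced here).

    import numpy as np

    def wl2(u_dimen_scores, n_bins=5):
        h, _ = np.histogram(u_dimen_scores, bins=n_bins, range=(0.0, 1.0))
        w = np.linspace(n_bins, 1, n_bins)  # higher weight for low-score bins
        return float(np.linalg.norm(w * h))  # Eq. (19)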

With all the datasets and metrics introduced, we seek to answer the following research questions through our experiments:

  (RQ1) Are pre-attention and highway connections helpful for reducing the repetition of RNN Seq2Seq models?

  (RQ2) Is our hypothesis on hard tokens correct?

  (RQ3) How effectively can LDR and TLDR reduce repetition for both RNN and transformer models?

  (RQ4) Is the improvement of LDR and TLDR merely due to up-weighting?

To answer RQ1, we use an RNN Seq2Seq model with post-attention as the baseline. To answer RQ2, we obtain results of using TFL on both RNN and transformer models. To answer RQ3, we compare the results of LDR and TLDR to those of baseline models trained without weighting and with TFL weighting. For RQ4, we also add a baseline using uniform weights, with w = 2 for all training examples.

Below we introduce the model parameters. Unless stated otherwise, we use the same settings for the different models on each dataset whenever possible. For RNN Seq2Seq, we use 2-layer LSTMs (hochreiter1997long) with a hidden size of 512, and separate embedding matrices with an embedding size of 200 for the encoder and the decoder. We use the 'general' attention variant from (luong2015effective). For transformer Seq2Seq, we use 6-layer transformer blocks with 8 attention heads (vaswani2017attention) and an embedding size of 256, again with separate matrices for the encoder and the decoder. We use a feed-forward layer size of 800 with ReLU activation. We train both RNN and transformer Seq2Seq models using the Adam optimizer (kingma2014adam) with a fixed learning rate of 0.001. The gradients of the RNNs are clipped to an L2-norm (pascanu2013difficulty) of at most 5 during training, while for transformers we clip to 1. For all intermediate outputs of both models, we use a dropout rate of 0.1 (srivastava2014dropout).

We train the models on the Dailydialog and Movie Dialogs datasets for 100 epochs and check performance on the validation sets every 0.5 epochs. On the Ubuntu Dialogue dataset, we train the models for 30 epochs with a validation interval of 0.1 epochs. We save checkpoints after every validation and select the best checkpoint according to the lowest repetition rate. Since the repetition rate can sometimes be very low, this checkpoint-saving strategy may occasionally fail, in which case we use the last checkpoint from training. For all three datasets, we tokenize the utterances with the NLTK tokenizer (loper2002nltk) and keep the 30,000 most frequent tokens according to the respective training sets. We use up to 3 turns of history as the input message, truncating messages to 128 tokens and responses to 32 tokens. A batch size of 256 is used.

For FL, we use a focusing factor of γ = 2, as recommended in (lin2017focal). For the hyper-parameters introduced in this work, we fix the highest n-gram order N and the weights w_n for calculating the DIMEN scores, and we group the u-DIMEN scores into a fixed number of bins with a weighting vector w that assigns higher weights to bins with lower u-DIMEN scores.

6 Results

                      |     Dailydialog      |    Movie Dialogs     |   Ubuntu Dialogue
                      | WL2↓  BLEU↑ l-DIMEN↑ | WL2↓  BLEU↑ l-DIMEN↑ | WL2↓  BLEU↑ l-DIMEN↑
(a) RNN               | 8.04  19.63  0.35    | 36.71  0.35  0.25    | 208.7  0.08  0.10
(b) RNN w/ pre-attn   | 5.70  18.17  0.36    | 28.15  0.41  0.27    | 173.0  0.09  0.08
(c) (b) w/ highway    | 5.50  19.60  0.36    | 29.00  0.38  0.25    | 157.1  0.07  0.08
(d) (b) w/ LDR        | 5.20  11.73  0.36    | 26.90  0.10  0.22    | 230.4  0.06  0.06
(e) (b) w/ TFL        | 4.67  11.48  0.33    | 19.11  0.38  0.26    | 153.3  0.06  0.08
(f) (b) w/ uniform    | 5.32  18.47  0.37    | 31.96  0.43  0.27    | 132.0  0.08  0.08
(g) (b) w/ TLDR       | 2.71  17.61  0.36    | 22.45  0.39  0.26    | 101.4  0.05  0.08
(h) transformer       | 7.72  20.38  0.35    | 42.50  0.57  0.30    | 79.84  0.05  0.11
(i) (h) w/ LDR        | 7.92  19.19  0.36    | 42.40  0.28  0.29    | 255.0  0.05  0.12
(j) (h) w/ TFL        | 7.33  20.69  0.35    | 38.52  0.51  0.29    | 103.2  0.05  0.15
(k) (h) w/ uniform    | 6.42  20.74  0.36    | 39.13  0.56  0.29    | 150.6  0.07  0.15
(l) (h) w/ TLDR       | 6.67  20.13  0.35    | 28.61  0.53  0.29    | 75.11  0.05  0.13
Human                 | 4.78    –    0.43    | 40.24    –   0.42    | 32.01    –   0.42

Table 1: Results for all our methods and the baselines. 'w/ pre-attn' and 'w/ uniform' are short for 'with pre-attention' and 'with a uniform weight of 2'. ↑/↓ indicate that higher/lower is better; '–' marks values not reported. All results are on the test sets.

The results of all models on the three datasets are shown in Table 1. Since a model's validation WL2 score can be very low in the early stages of training, when its l-DIMEN score is still very low as well, we additionally use l-DIMEN to make sure that the reported results are for checkpoints at similar training stages.

Table 1 is divided into three main blocks row-wise. The first block ranging from rows (a)–(c) is for answering RQ1. The second (rows (d)–(g)) and third (rows (h)–(l)) blocks are for answering RQ2 and RQ3, with RNNs and transformer models, respectively. For reference, we also include the Human performance calculated using the ground truth of each dataset in the bottom row of Table 1.

Answer to RQ1: Row (a) of Table 1 shows the results for RNN models with post-attention. From row (b), we can see that pre-attention is effective in reducing repetition on all three datasets. On top of pre-attention, row (c) shows that highway connections help reduce repetition on two datasets, namely Dailydialog and Movie Dialogs. The BLEU scores show that the quality of the utterances generated by our models (b) and (c) is on a par with that of the baseline (a). These results indicate that proper architectural modifications can indeed reduce repetition, and that more work is needed to find such modifications for transformers.

Answer to RQ2: Rows (e) and (j) show that in most cases TFL helps to reduce repetition, compared to rows (b) and (h), respectively. This empirically supports our hypothesis on hard tokens. However, because TFL monotonically down-weights all training examples, the improvement it brings is limited.

Answer to RQ3: Next, we proceed to our proposed weighting methods. Given that the LDR and TLDR methods concern model learning and should thus be architecture insensitive, we use the RNN variant (b) for the following experiments, as (c) needs more computation due to its gating mechanism (§3.2). From the second and third blocks of Table 1, we can see that both the RNN (row (g)) and transformer (row (l)) models trained using TLDR are consistently better than their corresponding baselines trained without weighting or with TFL weighting. Though TLDR occasionally performs worse than TFL, e.g., row (g) vs. row (e) on the Movie Dialogs dataset, we note that this is probably due to sampling bias of the test sets. We plot the result curves for the Movie Dialogs validation set at each validation point in Figure 4, which shows that TLDR is no worse, if not consistently better, than the baselines. These comparisons indicate that dynamically weighting tokens according to their difficulty is an effective way to reduce repetition.

The results in rows (d) and (i) show that, on the smaller datasets Dailydialog and Movie Dialogs, LDR has a limited effect on reducing repetition. However, on the large Ubuntu Dialogue dataset, LDR hurts the repetition performance. Besides, LDR consistently results in lower BLEU scores, which implies that more work needs to be done to estimate the training difficulty at the example level.

Figure 4: Repetition curves for the Movie Dialogs validation set during training. TLDR reduces the repetition rate earlier than the baselines. Best viewed in color.

Answer to RQ4: To understand whether the improvement of TLDR simply comes from up-weighting, we compare the models trained with a uniform weight, rows (f) and (k), to rows (g) and (l), respectively. With respect to repetition, TLDR wins in the majority of cases, which supports our claim that TLDR does more than simply up-weight, as can also be seen from the gradient visualization in Figure 3. Multiplying the loss function by a constant factor (2 in our experiments) multiplies the gradients by the same factor, while TLDR changes the gradient function in a more complex way, amplifying the gradients resulting from hard examples while suppressing those resulting from easy examples, which is important for reducing repetition. It is also worth noting that using a uniform weight of 2 effectively doubles the base learning rate, which seems to help improve BLEU and l-DIMEN.

7 Conclusion and Discussion

We have studied the repetition problem of encoder-decoder architectures, using both RNN and transformer models. We have discovered that by using pre-attention instead of post-attention for RNNs, the repetition problem can be alleviated. Together with highway connections, repetitive generations can be further reduced.

We have hypothesized that the repetition problem is caused by hard tokens and found empirical support for this claim using FL. We have then proposed a more effective weighting function than FL, namely TLDR. With a differentiable cosine weighting function, TLDR amplifies the gradients resulting from hard tokens while suppressing those from easy tokens. Through experiments, we have shown that TLDR outperforms strong baselines such as TFL and uniform weighting.

Our hard-token hypothesis and the TLDR weighting function both operate mainly on target-side tokens, while hard source-side tokens can also have a detrimental effect on the decoder's generations. Future work can apply TLDR to source-side language representation learning, possibly by training the encoder independently on a language representation task. It may also be worth exploring the effect of TLDR on decoder-only models for tasks like LM.

We also observe in our experiments that, before the repetition rate converges during training, the repetition rate on the Ubuntu Dialogue validation set fluctuates in a regular pattern, which suggests that certain training examples may harm the repetition rate. If this is true, the method of sharchilev2018finding can be used to identify such harmful examples and hence reduce repetition. We leave this for future work.

Acknowledgments

This research was supported by the China Scholarship Council, Ahold Delhaize, the Association of Universities in the Netherlands (VSNU), and the Innovation Center for Artificial Intelligence (ICAI). All content represents the opinion of the authors, which is not necessarily shared or endorsed by their respective employers and/or sponsors.
