The problem of language modelling (LM) usually involves two main challenges: first, mapping the context, i.e. the sentence prefix, to a vector representation, and second, using this representation to predict the subsequent word. Khandelwal et al. (2020) claim that the first problem is much easier to solve. Hence, given a pre-trained LM, they augment its predictions post hoc with a k-nearest neighbor (kNN) search over its context representations, and achieve significant improvements on challenging datasets such as WIKI-103.
Given that kNN improves the overall language modelling performance of a pre-trained network, we examine training strategies that make the underlying network's representations more amenable to the kNN step. Our results show improvements over applying kNN to a generic LM network.
We first explore a simple learning scheme for the language model in which, during training, we intentionally push representations that predict the same word closer together in the L2 sense, using an implementation in the style of Momentum Contrast (MOCO; He et al., 2020). We then observe that this MOCO-style learning can be replaced by simply adding L2 regularization to the activations of the layer used for kNN, which eliminates the implementation complexity. Finally, we present initial experiments toward understanding why this L2 regularization improves performance.
Our work builds upon kNN-LM (Khandelwal et al., 2020). In essence, kNN-LM tackles two problems: how to get more out of a trained LM's representations, and how to help LMs capture infrequent sentences that the model usually forgets during training.
kNN-LMs achieve significantly higher performance through a simple interpolation between the original LM predictions and the kNN predictions.
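At inference time, given a new context sentence $c_t$, kNN-LM works as follows:
- The context sentence is passed through the pre-trained network to produce a representation $f(c_t)$ as well as the LM's next-word distribution $p_{\mathrm{LM}}(w \mid c_t)$.
- The representation $f(c_t)$ is used to find the $k$ nearest neighbors $\mathcal{N}$ among the stored pairs $(f(c_i), w_i)$ of training-context representations and their next words. The kNN distribution is a weighted average of the neighbors' labels, using the exponentiated negative distance as the weight of each neighbor:
  $$p_{\mathrm{kNN}}(w \mid c_t) \propto \sum_{(f(c_i), w_i) \in \mathcal{N}} \mathbb{1}[w_i = w]\, \exp\big(-d(f(c_i), f(c_t))\big)$$
- The two distributions are interpolated to give the final prediction:
  $$p(w \mid c_t) = \lambda\, p_{\mathrm{kNN}}(w \mid c_t) + (1 - \lambda)\, p_{\mathrm{LM}}(w \mid c_t),$$
  where $\lambda$ is the interpolation parameter, which can be tuned on validation data.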
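The three steps can be made concrete with a minimal NumPy sketch for a single query. It assumes a tiny in-memory datastore searched by brute force and squared L2 distance; the function name `knn_lm_probs`, its arguments, and the default `lam` are hypothetical placeholders, and a real system would replace the brute-force search with a FAISS index.

```python
import numpy as np


def knn_lm_probs(query, keys, values, p_lm, k=4, lam=0.25):
    """Hypothetical sketch of the kNN-LM next-word distribution for one query.

    query:  (d,) context representation f(c_t)
    keys:   (N, d) datastore of training-context representations f(c_i)
    values: (N,) integer next-word ids w_i paired with the keys
    p_lm:   (V,) next-word distribution of the base LM
    k:      number of neighbors to retrieve
    lam:    interpolation weight lambda (placeholder; tune on validation data)
    """
    # Squared L2 distance from the query to every key (brute force, no index).
    dists = np.sum((keys - query) ** 2, axis=1)
    nn = np.argsort(dists)[:k]
    # exp(-distance) weights; shifting by the minimum only rescales them,
    # which the normalization below cancels, but it avoids underflow.
    weights = np.exp(-(dists[nn] - dists[nn].min()))
    # Sum the weights of neighbors that share the same next word, then normalize.
    p_knn = np.zeros(p_lm.shape[0])
    np.add.at(p_knn, values[nn], weights)
    p_knn /= p_knn.sum()
    # Interpolate the kNN and LM distributions.
    return lam * p_knn + (1.0 - lam) * p_lm


# Toy datastore: 2-d representations, vocabulary of 3 words.
keys = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
values = np.array([2, 2, 1, 0])
p_lm = np.full(3, 1.0 / 3.0)
p = knn_lm_probs(np.array([0.1, 0.0]), keys, values, p_lm, k=3, lam=0.5)
print(p.round(3))
```

`np.add.at` is used instead of `p_knn[values[nn]] += weights` because the latter would silently drop repeated word ids among the neighbors.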
This simple post-hoc procedure allows Khandelwal et al. (2020) to improve upon the state of the art in language modelling by a significant margin. Note that kNN-LM does not require retraining the LM, so the whole algorithm can be run on CPU only. Furthermore, kNN-LM uses FAISS, an efficient similarity-search library, to quickly find nearest neighbors.
One detail in Khandelwal et al. (2020) that was crucial for this work is that the authors tried both the inner product and the L2 distance as the distance metric for kNN, and concluded that L2 worked significantly better. This observation suggests that the default training recipe of LMs implicitly prefers one distance over the other. We therefore ask whether one could train the model so that it is adapted to the post-hoc kNN step with a given distance metric. In the next section, we describe how to adapt the training of the LM with kNN and the corresponding L2 distance metric in mind.
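Why the metric matters can be illustrated with made-up vectors. Since $\lVert a - b\rVert_2^2 = \lVert a\rVert_2^2 + \lVert b\rVert_2^2 - 2\,a^\top b$, inner-product search favors keys with large norms while L2 search penalizes them, so the two metrics can return different neighbors for the same query. This toy NumPy example is only an illustration and uses no data from the paper.

```python
import numpy as np

# Toy datastore: key 0 points in the query's direction but has a large norm;
# key 1 is slightly off-direction but close to the query in Euclidean terms.
keys = np.array([[4.0, 0.0],
                 [1.0, 0.3]])
query = np.array([1.0, 0.0])

# Inner-product retrieval picks the key with the largest dot product.
ip_scores = keys @ query
ip_nearest = int(np.argmax(ip_scores))

# L2 retrieval picks the key with the smallest squared distance.
l2_dists = np.sum((keys - query) ** 2, axis=1)
l2_nearest = int(np.argmin(l2_dists))

# The expansion ||k - q||^2 = ||k||^2 + ||q||^2 - 2 k.q links the two metrics.
expanded = np.sum(keys ** 2, axis=1) + query @ query - 2 * ip_scores
assert np.allclose(expanded, l2_dists)

print(ip_nearest, l2_nearest)  # 0 1
```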
3 Proposed Method
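In our initial attempt, we experimented with explicitly minimizing the L2 distance between context vectors that predict the same target word. This strategy directly mirrors how context vectors are used in the kNN step, and we hope that training the representations in the same way they are used at test time will further improve the effectiveness of kNN-LM. However, a naive implementation is infeasible, since it would require comparing each context vector with all other, continually changing, context vectors in the training data that share its target word. We therefore resorted to a MOCO-style (He et al., 2020) training scheme. Specifically, for each target word $w$, we construct a queue $Q_w$ of fixed length $K$, which stores recent context representations for $w$. During training, we optimize the following regularized objective:
$$\mathcal{L}_1 = \frac{1}{B}\sum_{i=1}^{B}\Big[-\log p_{\mathrm{LM}}(w_i \mid c_i) + \alpha\,\frac{1}{\lvert Q_{w_i}\rvert}\sum_{q \in Q_{w_i}} \big\lVert f(c_i) - \mathrm{sg}(q)\big\rVert_2^2\Big] \tag{1}$$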
where $B$ is the batch size, $f(c_i)$ is the context representation of the $i$-th word, $Q_{w_i}$ is the queue corresponding to the $i$-th target word $w_i$, $\alpha$ is the regularization parameter, and $\mathrm{sg}(\cdot)$ is the stop-gradient operator. As in MOCO (He et al., 2020), the queues are filled with representations produced by a momentum target encoder, which is initialized with the same parameters as the LM.
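The bookkeeping behind the MOCO-style objective can be sketched in NumPy as follows. This is a simplified illustration under our own assumptions, not the training code used for the experiments: one bounded FIFO queue per target word (`collections.deque` with `maxlen=K`), a penalty equal to the mean squared L2 distance to the queued vectors (treated as constants, which plays the role of the stop-gradient), and a MOCO-style exponential moving average with a hypothetical momentum `m` for the target encoder. All class and function names are hypothetical.

```python
from collections import defaultdict, deque

import numpy as np


class WordQueues:
    """Hypothetical per-target-word FIFO queues of target-encoder representations."""

    def __init__(self, K):
        # deque(maxlen=K) drops the oldest entry once K vectors are stored.
        self.queues = defaultdict(lambda: deque(maxlen=K))

    def push(self, word, rep):
        # Copy so that later in-place updates of `rep` cannot alter the queue.
        self.queues[word].append(np.array(rep, dtype=float))

    def penalty(self, word, rep):
        """Mean squared L2 distance from `rep` to the vectors queued for `word`.

        Queued vectors are plain constants here, which plays the role of the
        stop-gradient: in an autodiff framework only `rep` would get gradients.
        """
        q = self.queues.get(word)  # .get does not create an empty queue
        if not q:
            return 0.0
        return float(np.mean(np.sum((np.stack(list(q)) - rep) ** 2, axis=1)))


def regularized_loss(nll, reps, words, queues, alpha):
    """Batch loss: mean over i of nll[i] + alpha * penalty(w_i, f(c_i))."""
    pens = np.array([queues.penalty(w, r) for w, r in zip(words, reps)])
    return float(np.mean(np.asarray(nll, dtype=float) + alpha * pens))


def momentum_update(target_params, online_params, m=0.999):
    """MOCO-style EMA: theta_target <- m * theta_target + (1 - m) * theta_online."""
    return {name: m * target_params[name] + (1.0 - m) * online_params[name]
            for name in target_params}
```

In a training step one would compute the loss against the current queues first, and only then push the target encoder's representations of the batch, so that a context is never compared with its own copy.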
Empirically, we found that Equation 1 is practical and yields improved representations for the kNN-LM, as shown in Fig. 1. However, the queues and the momentum target network still add overhead to large-scale model training. We therefore tried decreasing the queue length $K$ and related hyperparameters, which, interestingly, did not decrease performance at all. To further improve efficiency, we tested an even simpler formulation in which every queue entry $q$ is replaced by the all-zero vector. This eliminates the need to explicitly construct and update the queues, and instead encourages the model to learn context representations with small L2 norms. The corresponding loss is:
$$\mathcal{L}_2 = \frac{1}{B}\sum_{i=1}^{B}\Big[-\log p_{\mathrm{LM}}(w_i \mid c_i) + \alpha\,\big\lVert f(c_i)\big\rVert_2^2\Big] \tag{2}$$
To our surprise, Equation 2 yields performance similar to Equation 1 in practice, while being much easier to implement and tune. We investigate possible explanations for this new finding in Section 4.4. Unless otherwise mentioned, we use Equation 2 as the default loss function in our experiments.
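A minimal NumPy sketch of the L2-regularized loss follows, assuming per-token negative log-likelihoods are given as an array and that the penalty is the squared L2 norm of each context representation; the function names are hypothetical. It also checks the observation that motivated the simplification: if every queue entry is the zero vector, the mean squared distance to the queue reduces to the squared norm $\lVert f(c_i)\rVert_2^2$.

```python
import numpy as np


def l2_activation_loss(nll, reps, alpha):
    """Sketch of Equation 2: mean over the batch of nll_i + alpha * ||f(c_i)||_2^2.

    nll:   (B,) per-token negative log-likelihoods -log p(w_i | c_i)
    reps:  (B, d) context representations f(c_i) from the layer used for kNN
    alpha: regularization strength
    """
    sq_norms = np.sum(reps ** 2, axis=1)
    return float(np.mean(nll + alpha * sq_norms))


def l2_activation_grad(reps, alpha):
    """Gradient of the penalty term w.r.t. reps: d/dh (alpha/B * ||h||^2) = 2*alpha*h/B."""
    return 2.0 * alpha * reps / reps.shape[0]


def queue_penalty(rep, queue):
    """Queue-style penalty: mean squared L2 distance from rep to each queued vector."""
    return float(np.mean(np.sum((queue - rep) ** 2, axis=1)))


rng = np.random.default_rng(0)
reps = rng.normal(size=(4, 8))
nll = rng.uniform(1.0, 5.0, size=4)

# With an all-zero queue, the queue penalty equals the squared norm used in Equation 2.
zero_queue = np.zeros((16, 8))
for h in reps:
    assert np.isclose(queue_penalty(h, zero_queue), np.sum(h ** 2))

print(l2_activation_loss(nll, reps, alpha=0.0) == float(np.mean(nll)))  # True
```

In an autodiff framework the gradient would come for free; `l2_activation_grad` is only there to show that the penalty pulls each representation toward the origin with a force proportional to its own magnitude.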
We tested our method on the WIKI-2 and WIKI-103 datasets, which are widely used benchmarks for language modelling. We aim to demonstrate two empirical results: improved performance of our approach over that of kNN-LM, and a possible mechanism behind this improvement.
| Metric | | | | |
|---|---|---|---|---|
| Train Ppl. LM | 19.99 | 20.05 | 20.11 | 21.37 |
| Valid Ppl. LM | 75.96 | 75.68 | 76.37 | 81.29 |
| Valid Ppl. kNN-LM | 74.11 | 73.13 | 70.63 | 80.52 |
4.1 Experimental setup
WIKI-2 is a benchmark with a 30k-word vocabulary and consists of 2M tokens. WIKI-103 is a benchmark with a 250k-word vocabulary and consists of 103M tokens.
Language Model Architecture. We use exactly the setup described in Khandelwal et al. (2020), namely the language model of Baevski and Auli (2018), which consists of 16 layers, each with 16 self-attention heads, 1024-dimensional hidden states, and 4096-dimensional feedforward layers. Following Baevski and Auli (2018), this LM uses adaptive inputs and an adaptive softmax (Grave et al., 2017) with tied weights (Press and Wolf, 2016) in all our experiments. We trained each language model on a Tesla V100 GPU, on a machine with 400GB of RAM.
In addition, we follow exactly the same training procedure as Khandelwal et al. (2020) and refer to their paper for further details on the training parameters. The only implementation differences are the MOCO-style learner and the L2 regularization added to the final layer. Lastly, when cross-validating the interpolation parameter $\lambda$, we found that a single value of $\lambda$ works best for all models, which is in accordance with the finding in Khandelwal et al. (2020).
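Tuning $\lambda$ amounts to a one-dimensional grid search on validation perplexity, i.e. the exponential of the mean negative log-likelihood of the interpolated distribution. The sketch below assumes that the probabilities assigned to each gold validation token by the LM and by the kNN component have already been computed; the function names and the grid of 21 values are hypothetical choices.

```python
import numpy as np


def interpolated_ppl(p_lm_gold, p_knn_gold, lam):
    """Perplexity of lam * p_kNN + (1 - lam) * p_LM, given per-token gold-word probabilities."""
    p = lam * np.asarray(p_knn_gold, dtype=float) + (1.0 - lam) * np.asarray(p_lm_gold, dtype=float)
    # A gold word missing from the neighbors gets p_kNN = 0; at lam = 1 that yields
    # log(0) = -inf and an infinite perplexity, which argmin below simply avoids.
    with np.errstate(divide="ignore"):
        return float(np.exp(-np.mean(np.log(p))))


def tune_lambda(p_lm_gold, p_knn_gold, grid=np.linspace(0.0, 1.0, 21)):
    """Return the lambda from `grid` with the lowest validation perplexity, and that perplexity."""
    ppls = [interpolated_ppl(p_lm_gold, p_knn_gold, lam) for lam in grid]
    best = int(np.argmin(ppls))
    return float(grid[best]), ppls[best]


# Toy validation set of three tokens.
p_lm_gold = [0.1, 0.2, 0.05]
p_knn_gold = [0.5, 0.0, 0.4]
print(tune_lambda(p_lm_gold, p_knn_gold))
```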
4.2 Experiments on Wiki-2
We first apply our proposed method to the standard WIKI-2 dataset, running each configuration 5 times and plotting the standard deviation, as seen in Fig. 1. Note that in Fig. 1, the regularization parameter $\alpha = 0$ corresponds to the standard kNN-LM, i.e. without the added term in the loss. Comparing Figure 1 and Table 1 shows that the MOCO and L2 approaches produce similar results. From these results, we note the following phenomena:
- A clear "U"-shape, demonstrating the added benefit of our loss term on the validation perplexity of the kNN-LM for moderate values of $\alpha$.
- Training performance does not degrade for moderate values of $\alpha$, showing that the extra term does not harm the training or generalization of the standard LM.
- Before applying kNN, there is no significant difference in validation perplexity between the standard LM and our version, but there is a significant difference after applying kNN. Our approach likely finds a different local minimum for the language model, one that is better suited to kNN.
| Metric | | | |
|---|---|---|---|
| Train Ppl. LM | 11.31 | 11.24 | 11.07 |
| Valid Ppl. LM | 18.00 | 17.95 | 17.71 |
| Valid Ppl. kNN-LM | 16.09 | 15.89 | 17.46 |
4.3 Experiments on Wiki-103
We illustrate our findings on the more challenging WIKI-103 dataset and demonstrate that our L2 fix significantly improves the performance of the LM. Table 2 shows that, when varying the regularization strength, we again see a significant gain in performance from adding our regularization during training of the LM. Because training these models is computationally expensive, we reuse the hyperparameters from the WIKI-2 experiments; since all variants share these hyperparameters, the comparisons between them are fair.
Note that, again, we see significant improvements in validation perplexity with the kNN-LM scheme simply by adding an L2 regularization term when training the language model.
On another note, a closer look at the validation perplexity before applying kNN shows that a model trained with a nonzero regularization strength $\alpha$ seems to attain the lowest validation perplexity. This better generalization is interesting; a similar phenomenon has recently been noted in the computer vision community, in work investigating the regularization effects of batch normalization in classification settings (Dauphin and Cubuk, 2021). It also relates to the findings of Merity et al. (2017), who applied L2 regularization to the activations of LSTM language models. In this paper, we have found initial indications that L2 regularization on the activations might also be useful for Transformer models.
Finally, we believe that results on these two standard language modelling benchmarks provide sufficient evidence of the merit of our findings.
4.4 Further investigations into the representations and possible explanations
We first examine the effect that target word frequency has on the loss. Figure 2 shows a histogram of word frequency, where the color represents the respective losses each word incurred.
First, note that there is little difference in loss for the less frequent words on the right side of the graph. For the more frequent words, however, the picture is different: for the most frequent words, our L2-regularized model seems to incur lower loss (see the brighter-colored bars at the peaks of the histograms) than the standard LM with kNN. This observation suggests that the main differences between the representations come from the frequent words rather than the rare ones.
Second, knowing that the main differences lie within the most frequent words, we investigated their representations in more detail. In particular, we took the most frequent words and divided the data into "high scores", which contribute a lot to the loss (bad predictions), and "low scores", which contribute little to the loss (good predictions).
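The selection of frequent words and the "high scores"/"low scores" split can be sketched in NumPy. The number of frequent words kept (`top_n`) and the use of the median loss as the threshold are our own illustrative assumptions; the exact criterion is not stated here, and the function name is hypothetical.

```python
import numpy as np


def split_frequent_by_loss(token_ids, token_losses, top_n=10):
    """Keep tokens whose target word is among the `top_n` most frequent words,
    then split them at the median loss into "high scores" and "low scores".

    Returns two index arrays into the original token sequence.
    """
    token_ids = np.asarray(token_ids)
    token_losses = np.asarray(token_losses, dtype=float)
    # Word frequencies over the evaluated tokens, most frequent first.
    vocab, counts = np.unique(token_ids, return_counts=True)
    frequent = vocab[np.argsort(-counts, kind="stable")[:top_n]]
    idx = np.flatnonzero(np.isin(token_ids, frequent))
    # Median loss among the frequent-word tokens (assumed threshold).
    threshold = np.median(token_losses[idx])
    high = idx[token_losses[idx] > threshold]   # large loss: bad predictions
    low = idx[token_losses[idx] <= threshold]   # small loss: good predictions
    return high, low


high, low = split_frequent_by_loss([0, 0, 0, 1, 1, 2], [1.0, 5.0, 3.0, 2.0, 4.0, 9.0], top_n=2)
print(high, low)
```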
We fitted a simple Gaussian mixture model (GMM) to the representations and used its log-likelihood as an indicator of how well the data are clustered. A GMM assigns a probability density to each representation, and because it models the data as a mixture of several Gaussian distributions, it inherently captures cluster structure. Intuitively, if the likelihood under the GMM is high, the representations are well captured by the mixture, i.e. they lie close to one of the mixture means, which indicates that they are more tightly clustered.
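To illustrate the log-likelihood as a clustering score, here is a small self-contained NumPy sketch of a diagonal-covariance GMM fitted with expectation-maximization. The number of components, the diagonal covariances, the farthest-first initialization, and the variance floor are our own assumptions; in practice one would more likely use an off-the-shelf implementation such as scikit-learn's `GaussianMixture` and its `score_samples` method. On synthetic data with the same two cluster centers, tightly clustered points receive a higher average log-likelihood than loosely spread ones.

```python
import numpy as np


def log_gauss_diag(X, means, variances):
    """Log-density of every row of X under every diagonal Gaussian, shape (N, C)."""
    diff = X[:, None, :] - means[None, :, :]                        # (N, C, d)
    return -0.5 * (np.sum(diff ** 2 / variances[None], axis=2)
                   + np.sum(np.log(2.0 * np.pi * variances), axis=1)[None])


def fit_gmm(X, n_components=2, n_iter=100, var_floor=1e-6, seed=0):
    """Fit a diagonal-covariance GMM with EM; returns (weights, means, variances)."""
    rng = np.random.default_rng(seed)
    N = X.shape[0]
    # Farthest-first initialization of the means (an illustrative choice).
    chosen = [int(rng.integers(N))]
    for _ in range(1, n_components):
        d2 = np.min(((X[:, None, :] - X[chosen][None]) ** 2).sum(axis=2), axis=1)
        chosen.append(int(np.argmax(d2)))
    means = X[chosen].copy()
    variances = np.tile(X.var(axis=0) + var_floor, (n_components, 1))
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        # E-step: responsibilities, normalized in log space for stability.
        log_p = np.log(weights)[None] + log_gauss_diag(X, means, variances)
        log_p -= log_p.max(axis=1, keepdims=True)
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: mixture weights, means and per-dimension variances.
        nk = resp.sum(axis=0) + 1e-12
        weights = nk / N
        means = (resp.T @ X) / nk[:, None]
        diff2 = (X[:, None, :] - means[None]) ** 2                  # (N, C, d)
        variances = np.einsum("nc,ncd->cd", resp, diff2) / nk[:, None] + var_floor
    return weights, means, variances


def gmm_log_likelihood(X, weights, means, variances):
    """Per-sample log-likelihood log sum_c pi_c N(x | mu_c, diag(var_c))."""
    log_p = np.log(weights)[None] + log_gauss_diag(X, means, variances)
    m = log_p.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(log_p - m).sum(axis=1, keepdims=True))).ravel()


# Synthetic "representations": the same two clusters, tight versus loose.
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
labels = rng.integers(0, 2, size=400)
noise = rng.normal(size=(400, 2))
tight = centers[labels] + 0.1 * noise
loose = centers[labels] + 1.0 * noise
print(gmm_log_likelihood(tight, *fit_gmm(tight)).mean(),
      gmm_log_likelihood(loose, *fit_gmm(loose)).mean())
```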
In Figure 3, we compare the distributions of the log-likelihoods of representations trained with the standard LM objective and with our L2 regularization. For each representation, we obtain the corresponding log-likelihood from the GMM (x-axis in Figure 3). As mentioned above, we split the words into "high scores" and "low scores" and plot their histograms in blue and orange, respectively. Fig. 3 shows that with L2 regularization (right), the representations that contribute less to the loss (orange) are better fitted by the GMM than under standard training, as seen from the increase in likelihood, i.e. the shift to the right along the x-axis. Interestingly, the gap between the likelihoods of the "high scores" and "low scores" is much larger in the L2-regularized case. Hence, one of our hypotheses is that kNN-LM improves classification accuracy mostly for infrequent words (Khandelwal et al., 2020), whereas our L2-regularized training additionally improves the classification accuracy of frequent words by clustering their representations more tightly, which allows kNN-LM to work better.
In conclusion, we propose a useful training mechanism inspired by the fact that the post-hoc application of kNN seems to significantly improve the performance of standard LMs. We have found that training an LM with L2 regularization on the final layer, i.e. the layer used for the post-hoc kNN search, improves validation performance. We have also found initial indications that L2 regularization mostly improves performance for the most frequent, lower-loss words.
These findings motivate exploring L2 regularization in different Transformer architectures and LM tasks. Furthermore, using training data at inference time is an interesting direction for improving model performance, and we hope that our approach encourages future work in this area.
- Adaptive input representations for neural language modeling. arXiv preprint arXiv:1809.10853.
- Deconstructing the regularization of batch-norm. In International Conference on Learning Representations (ICLR).
- Momentum contrast for unsupervised visual representation learning. pp. 9729–9738.
- Efficient softmax approximation for GPUs. In International Conference on Machine Learning (ICML), pp. 1302–1310.
- Generalization through memorization: nearest neighbor language models. In International Conference on Learning Representations (ICLR).
- Regularizing and optimizing LSTM language models. arXiv preprint arXiv:1708.02182.
- Using the output embedding to improve language models. arXiv preprint arXiv:1608.05859.