Causal language modeling is flexible across most natural language tasks owing to its unsupervised and generative nature. Large-scale pre-training of Transformer architectures such as GPT2 has resulted in powerful models capable of capturing general knowledge of natural language. However, unlike bidirectional language models such as BERT and RoBERTa, a causal language model can only look at the word history to predict the next word. While this is mathematically sound, relying solely on left-hand context may hinder the language model from capturing semantic knowledge to the fullest. On the other hand, while BERT provides satisfactory performance for sequence encoding thanks to its bi-directional nature, it is designed for masked language modeling, which predicts the word identity at a masked position in a sentence. BERT is non-causal and is thus not suitable for sequence generation.
Recent studies have shown that retrieving prefix contextual information from an external data store can further improve the performance of a causal language model without increasing the number of model parameters. However, the retrieved information is still uni-directional. In this paper, we propose a novel language model, SUffix REtrieval-Augmented LM (SUREALM), that employs an embedding retriever for suffix retrieval from a data store. During sequence generation, the current word history, referred to as the prefix in the rest of the paper, is submitted to an embedding retriever to search for training sentences that share similar prefixes. The corresponding suffixes of these training sentences are then viewed as “future” context to guide sequence generation. The intuition is that sentences sharing a similar prefix are likely to have strongly correlated suffixes. For example, “how may i” and “how can i” are similar prefixes. If the model also knows the complete reference sentence “how can i help you”, then the model would tend to predict “help you” given a novel prefix “how may i”. To exploit this assumption, we split each training sentence in all possible ways into a triple: a prefix, a word, and a suffix. We employ pre-trained sentence transformers to encode the prefix and suffix of each training sentence to create an embedding data store. Then an embedding retriever such as FAISS is employed for prefix-suffix embedding retrieval given an encoded prefix. The retrieved prefix-suffix embeddings are added to the word embedding inputs during sequence generation, achieving causal language modeling with a simulated bi-directional effect. SUREALM is causal because it only uses the word history for predicting the next word. SUREALM is simulated bi-directional because it exploits “future” context from other similar sentences.
Our contributions are twofold. First, we propose SUREALM, a new causal language model enhanced by prefix-suffix embedding retrieval to simulate a bi-directional effect for sequence generation. Second, we perform extensive experiments and show the effectiveness of our model on the DSTC9 dialogue corpus.
2 Related Work
Improving language models using retrieval techniques is not new. [2] employs information retrieval to perform language model adaptation for statistical machine translation. Once the language models are adapted, they are kept fixed during sequence generation.
Most recent developments in language modeling are based on transformers. Masked language models that are BERT-based [1, 8] exploit bi-directional information of a sentence to predict the word identity of the masked tokens. While BERT is effective in encoding sequences, it is not suitable for sequence generation due to its non-causal nature. Causal language modeling such as GPT2 is uni-directional. Our proposed model attempts to retain the best of both worlds, being autoregressive and simulated bi-directional via augmentation of suffix embeddings during sequence generation.
One notable work on language modeling using embedding retrieval is the nearest neighbor language model (KNN-LM). This approach stores dynamically changing information in an external knowledge base. During sequence generation, KNN-LM uses the current prefix to retrieve similar prefixes in the data store using embedding retrieval. The output word probability distribution is then estimated by looking at the next words that follow the retrieved prefixes. This word probability distribution is linearly interpolated with the output word distribution from the causal transformer LM. While this approach has been shown to reduce word perplexity, it is uni-directional in terms of the information used for word prediction. Our proposed model enjoys the simulated bi-directional effect of utilizing “future” contextual information to guide sequence generation.
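The linear interpolation used by KNN-LM can be illustrated with a small sketch. The function name `interpolate`, the weight `lam`, and the toy distributions are illustrative assumptions, not values from the KNN-LM work.

```python
def interpolate(p_lm, p_knn, lam):
    """Linearly interpolate two next-word distributions given as dicts.

    lam is a hypothetical interpolation weight on the retrieval-based
    distribution; (1 - lam) is the weight on the transformer LM distribution.
    """
    vocab = set(p_lm) | set(p_knn)
    return {
        w: lam * p_knn.get(w, 0.0) + (1.0 - lam) * p_lm.get(w, 0.0)
        for w in vocab
    }


# Toy distributions over the next word (made up for illustration).
p_lm = {"help": 0.5, "book": 0.3, "find": 0.2}
p_knn = {"help": 0.9, "find": 0.1}
mixed = interpolate(p_lm, p_knn, lam=0.25)
```

Because both inputs sum to one, the interpolated distribution also sums to one for any weight between 0 and 1.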
Another line of work is retrieval-augmented generation for question answering. This approach employs embedding retrieval over encoded document embeddings. The top-K retrieved document embeddings are then viewed as latent variables for answer generation. These latent variables are marginalized in the generator within a sequence-to-sequence generation framework. Related work on using retrieval techniques for language model pre-training and question answering also includes REALM. Our proposed model differs from these approaches in that we do not marginalize over the top-K retrieved results. Instead, our model relies on the attention mechanism to attend to all previously retrieved suffix embeddings such that the cross-entropy loss is minimized.
3 Proposed approach
Our proposed approach extends causal language models with suffix retrieval. Denote a sentence $W = w_1 w_2 \dots w_T$. Then our model defines the negative log likelihood of $W$ as follows:

$$-\log P(W) = -\sum_{t=1}^{T} \log P\left(w_t \mid w_{<t}, R_\phi(w_{<t})\right)$$

where $w_{<t}$ denotes the word history (or prefix) of the word token $w_t$. $R_\phi(\cdot)$ denotes a retrieval function parameterized by $\phi$ that searches for sentences with similar prefixes in a data store. The suffixes of the retrieved sentences are then fed into the language model via suffix embeddings. Although the true future context is unseen in causal language models, we hypothesize that such future context may be estimated by leveraging sentences that share similar prefixes. Thus, our model, SUffix REtrieval-Augmented LM (SUREALM), achieves simulated bi-directional modeling in the spirit of BERT while still being able to generate sentences in an autoregressive manner as in GPT. In summary, our proposed approach has three steps: (1) data preprocessing and indexing; (2) SUREALM training; (3) SUREALM decoding. We describe these steps in Sections 3.1–3.3.
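As a worked illustration of the negative log likelihood, and of the word perplexity derived from it, the following sketch sums the per-token log probabilities of one sentence. The per-token probabilities are made-up numbers, and the function names are hypothetical.

```python
import math


def negative_log_likelihood(token_probs):
    """Sum of -log p over the per-token conditional probabilities."""
    return -sum(math.log(p) for p in token_probs)


def perplexity(token_probs):
    """Exponential of the average per-token negative log likelihood."""
    return math.exp(negative_log_likelihood(token_probs) / len(token_probs))


# Hypothetical conditional probabilities a model assigns to each word
# of a four-word sentence, given its prefix and retrieved suffixes.
probs = [0.25, 0.5, 0.125, 0.5]
nll = negative_log_likelihood(probs)
ppl = perplexity(probs)
```

With these numbers the sentence costs 7 bits in total, so the perplexity is 2 raised to the power 7/4.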
3.1 Data pre-processing and indexing
Given a training corpus $\mathcal{D}$ containing a set of unique sentences, each sentence $W = w_1 w_2 \dots w_T \in \mathcal{D}$ generates all possible partitions of $W$ into 3 parts: prefix, current word, and suffix, denoted as $(w_{<t}, w_t, w_{>t})$, where $w_{<t} = w_1 \dots w_{t-1}$ and $w_{>t} = w_{t+1} \dots w_T$, with valid word position $1 \le t \le T$. Motivated by masked language modeling, we exclude the current word $w_t$ so that each data entry for indexing into a data store is a prefix-suffix pair $(w_{<t}, w_{>t})$. This formulation forces our model to use information from the prefix (left context) and the retrieved suffixes (“right context” from other training sentences). Thanks to recent developments in sentence embedding retrieval, we employ a pre-trained sentence transformer to encode prefixes and suffixes such that retrieval is based on the similarity of prefix embeddings. The data store then returns the prefix and suffix embeddings. The rationale for encoding variable-length prefixes and suffixes into fixed-dimensional vectors is to make the prefix and suffix representations smoother, mitigating word-level noise that is irrelevant to the current prefix under consideration. Intuitively, the prefix and suffix embeddings, each in $\mathbb{R}^d$, can be viewed as a key-value pair, where $d$ is the embedding dimension. To preserve positional information in the representation, we use absolute positions in the original sentence when computing the prefix and suffix embeddings. We employ FAISS for embedding search.
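The partitioning step can be sketched as follows. This is a minimal illustration assuming whitespace tokenization and allowing empty prefixes and suffixes at the sentence boundaries. The helper names `split_triples` and `prefix_suffix_pairs` are hypothetical, and the sentence-transformer encoding step is omitted.

```python
def split_triples(tokens):
    """Return every (prefix, current word, suffix) partition of a token list."""
    return [(tokens[:t], tokens[t], tokens[t + 1:]) for t in range(len(tokens))]


def prefix_suffix_pairs(sentence_id, tokens):
    """Drop the current word and keep (sentence id, position, prefix, suffix).

    The sentence id is kept so that entries from the same sentence can be
    identified later during retrieval.
    """
    return [
        (sentence_id, t, prefix, suffix)
        for t, (prefix, _word, suffix) in enumerate(split_triples(tokens))
    ]


tokens = "how can i help you".split()
triples = split_triples(tokens)
pairs = prefix_suffix_pairs(0, tokens)
```

For the five-word example sentence this yields five triples. The third one has prefix `["how", "can"]`, current word `"i"`, and suffix `["help", "you"]`.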
The final number of prefix-suffix pairs to index is $O(N \times L)$, where $N$ is the number of unique training sentences and $L$ is the maximum sequence length of a sentence. In principle, our model needs to perform embedding retrieval at every word position. Therefore, we introduce a hyper-parameter $\delta$ to control the frequency of embedding retrieval. For example, if embedding retrieval occurs at time $t$, then the next retrieval occurs at time $t + \delta$. This implies that during the time interval $(t, t + \delta)$, all the previously retrieved suffix embeddings are reused to save computation. This allows us to explore the tradeoff between computation and accuracy.
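The effect of the retrieval-frequency hyper-parameter can be sketched as below. The function names and the zero-based position convention are assumptions made for illustration.

```python
def retrieval_positions(sentence_length, interval):
    """Word positions (0-based) at which a fresh embedding retrieval happens."""
    return list(range(0, sentence_length, interval))


def source_position(t, interval):
    """Most recent retrieval position whose results are reused at position t."""
    return (t // interval) * interval


every_step = retrieval_positions(4, 1)
every_third = retrieval_positions(7, 3)
reused_from = source_position(5, 3)
```

With an interval of 1, retrieval happens at every position. With an interval of 3 on a seven-word sentence, retrieval happens at positions 0, 3, and 6, and position 5 reuses the results retrieved at position 3.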
Regarding the suffix representation, we consider applying suffix truncation, assuming that word tokens in a suffix that are closer to the prediction time may be more helpful. We introduce a hyper-parameter $m$ for suffix truncation so that, at word position $t$, a truncated suffix $w_{t+1} \dots w_{t+m}$ with at most $m$ tokens is fed into a sentence transformer for encoding. When the number of tokens in a suffix is large, we conjecture that a suitable $m$ may avoid an overly smoothed suffix embedding representation caused by the pooling mechanism in a sentence transformer. Table 1 shows sample retrieval results using some prefixes as input queries.
Table 1: Sample retrieval results using prefixes as input queries.
|Input Query||Retrieved Word||Retrieved Suffix|
|’i also want’||’free’||’wifi’|
|’i’d like to’||’book’||’this hotel’|
|’is the hotel equipped with an’||’elevator’||’for convenience?’|
|’i have several’||’options’||’for you. would you like a particular area…’|
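Suffix truncation, as described before the table, amounts to keeping only the suffix tokens nearest to the prediction position before encoding. A minimal sketch follows, with a hypothetical function name and an arbitrary truncation length chosen only for illustration.

```python
def truncate_suffix(suffix_tokens, max_tokens):
    """Keep at most max_tokens tokens, starting right after the current word."""
    return suffix_tokens[:max_tokens]


# Illustrative suffix and truncation length (not values from the experiments).
suffix = "for you would you like a particular area".split()
short_suffix = truncate_suffix(suffix, max_tokens=3)
```

A suffix that is already shorter than the limit is returned unchanged, so the same call can be applied uniformly to every data entry.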
3.2 SUREALM training
3.2.1 Offline retrieval
One complication in SUREALM is the embedding retrieval required for each prefix at each word position. It would be computationally expensive to perform on-the-fly embedding retrieval during training. Since we freeze the sentence transformer used for encoding, the top-K similar prefix and suffix embeddings can be precomputed offline using FAISS to speed up training. K is a hyper-parameter that determines the number of retrieved suffix embeddings to be included in SUREALM training. To avoid cheating, we exclude all retrieved embeddings that belong to the same training sentence ID as the query. Empirically, we found this step to be crucial so that SUREALM learns from additional suffix information provided by other similar training sentences.
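The offline retrieval step with same-sentence exclusion can be sketched as follows. Brute-force cosine similarity in pure Python stands in for FAISS here, and the two-dimensional vectors and function names are toy assumptions rather than the actual pipeline.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_excluding_self(query_vec, query_sentence_id, store, k):
    """Return the k most similar entries that come from other sentences.

    store is a list of (sentence_id, prefix_embedding, suffix_embedding).
    Each result is (similarity, sentence_id, suffix_embedding).
    """
    candidates = [
        (cosine(query_vec, prefix_vec), sid, suffix_vec)
        for sid, prefix_vec, suffix_vec in store
        if sid != query_sentence_id  # skip entries from the query's own sentence
    ]
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:k]


# Toy data store with made-up 2-dimensional embeddings.
store = [
    (0, [1.0, 0.0], [0.1, 0.9]),
    (1, [0.9, 0.1], [0.2, 0.8]),
    (2, [0.0, 1.0], [0.7, 0.3]),
]
hits = top_k_excluding_self([1.0, 0.0], 0, store, k=2)
```

Although the entry from sentence 0 matches the query exactly, it is filtered out, so the model cannot simply copy its own suffix during training.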
3.2.2 Mini-batch training
Another challenge is to make SUREALM fit mini-batch training, where a batch of padded word sequences is fed into the model and different suffix embeddings should be applied at different time steps. To enable mini-batch training, we first append all offline-retrieved suffix embeddings to the input word embeddings. Then we construct a suitable attention mask so that each word position attends only to the allowed suffix embeddings and to the previous word positions, as in a causal LM. Denote $S_t$ as the concatenation of all top-K prefix-suffix embedding pairs retrieved up to word position $t$. The probability of generating a word sequence $W = w_1 \dots w_T$ becomes:

$$P(W) = \prod_{t=1}^{T} P\left(w_t \mid w_{<t}, S_t\right)$$

where $S_t = [R_\phi(w_{<1}); \dots; R_\phi(w_{<t})]$ and $R_\phi(\cdot)$ denotes the retrieval function. SUREALM employs a transformer architecture whose Query-Key-Value inputs are defined as follows:

$$Q = X W^Q, \quad K = [X; S] W^K, \quad V = [X; S] W^V$$

Here $X$ denotes the input word embeddings, $S$ denotes the retrieved suffix embeddings, and $M$ is an attention mask which masks future positions in keys and values regarding the input word embeddings and the retrieved suffix embeddings. Finally, we obtain the output of the masked attention block as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}} + M\right) V$$

where $d$ denotes the embedding dimension of the keys and values. Figure 1 illustrates the SUREALM architecture. Notice that SUREALM can be initialized with any pre-trained transformer model weights from the BERT or GPT model families. We then fine-tune the model on the training text to minimize the cross-entropy loss.
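One way to build the attention mask described in this subsection is sketched below. The key layout (word positions first, followed by k retrieved embeddings per word position) and the function name are assumptions made for illustration, and padding is ignored.

```python
def build_attention_mask(seq_len, k):
    """Boolean mask where True means the query may attend to the key.

    Rows are query (word) positions. Columns are seq_len word keys followed
    by seq_len * k retrieved-embedding keys, where the k embeddings in block r
    were retrieved using the prefix available at word position r.
    """
    mask = []
    for t in range(seq_len):
        # Causal part: attend to word positions up to and including t.
        word_part = [j <= t for j in range(seq_len)]
        # Retrieval part: attend only to embeddings retrieved at or before t.
        retrieved_part = [r <= t for r in range(seq_len) for _ in range(k)]
        mask.append(word_part + retrieved_part)
    return mask


mask = build_attention_mask(seq_len=3, k=2)
```

In this layout the first word position sees only itself and its own two retrieved embeddings, while the last position sees every word and every retrieved embedding, which keeps the model causal with respect to both inputs.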
3.3 SUREALM Decoding
SUREALM decoding is similar to any decoding algorithm for sequence generation, except that suffix embedding retrieval is performed whenever the prefix is updated during generation. We start from the start symbol as the initial prefix. Suffix embedding retrieval then takes place using the current prefix as a query. The top-K suffix embeddings are added progressively to the set of retrieved embeddings, which serve as extra inputs to the transformer. The next word is generated from the output word distribution. The generated word is then appended to the prefix, giving a new prefix. The generation process is repeated until the end-of-sentence symbol is encountered. We follow the decoding algorithm implementation in the Huggingface transformers library and add an embedding retriever to our implementation.
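The decoding loop can be sketched with toy stand-ins. Here `toy_retrieve` and `toy_next_word` are hypothetical placeholders for the embedding retriever and the transformer, token lists stand in for suffix embeddings, and greedy selection is assumed. None of this reflects the Huggingface API.

```python
def decode(next_word_fn, retrieve_fn, k, max_len=20, bos="<s>", eos="</s>"):
    """Generate words one at a time, retrieving suffixes after each prefix update."""
    prefix = [bos]
    retrieved = []  # grows progressively; earlier retrievals are kept
    while len(prefix) < max_len:
        retrieved.extend(retrieve_fn(prefix, k))
        word = next_word_fn(prefix, retrieved)
        if word == eos:
            break
        prefix.append(word)
    return prefix[1:]


# Toy data store keyed by the last prefix token (illustration only).
TOY_STORE = {
    "<s>": "how can i",
    "how": "can i help",
    "can": "i help you",
    "i": "help you",
    "help": "you",
}


def toy_retrieve(prefix, k):
    """Return up to k retrieved suffixes (token lists) for the current prefix."""
    suffix = TOY_STORE.get(prefix[-1], "")
    return [suffix.split()][:k]


def toy_next_word(prefix, retrieved):
    """Pick the first token of the most recently retrieved suffix, else stop."""
    latest = retrieved[-1]
    return latest[0] if latest else "</s>"


generated = decode(toy_next_word, toy_retrieve, k=1)
```

The loop structure mirrors the description: retrieve with the current prefix, predict the next word, extend the prefix, and stop at the end-of-sentence symbol.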
4 Experiments
In this section, we compare SUREALM under different configurations against baseline LMs for sequence generation. We report word perplexity for all experiments. We also present some details on the choice of hyper-parameters.
We used the dataset from the Dialogue System Technology Challenge 9 (DSTC9). The original dataset was designed for evaluating spoken dialogues that involve accessing an external knowledge base containing a list of question-answer pairs about each named entity, such as “Can I bring my pet to A and B Guest House?” and “No, pets are not allowed at this property.”. For language modeling purposes, we treated each dialogue turn as an independent sentence and kept only unique sentences in our training, validation and test sets. Each sentence was then assigned a unique sentence ID so that it could be uniquely identified by the embedding retriever. Our resulting training dataset contained 126,877 unique dialogue turns mentioning 145 named entities covering four domains: hotel, restaurant, taxi and train. Our validation dataset contained 18,037 unique dialogue turns. The test dataset had 18,021 unique dialogue turns covering 523 unseen named entities, including a new domain on attraction. Due to the introduction of a new domain, we further split the test dataset into in-domain and out-of-domain portions and evaluated only on the in-domain portion. Since the test turns did not have manual domain labels, we applied named entity recognition to all dialogue turns and assigned each turn the domain label of its detected named entity. The question-answering knowledge base was not added to our data store. The data store contained only the prefix-suffix embeddings from the training sentences.
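Keeping only unique dialogue turns and assigning sentence IDs can be sketched as below. Stripping surrounding whitespace and dropping empty turns are assumptions of this sketch, and the function name is hypothetical.

```python
def assign_sentence_ids(turns):
    """Deduplicate dialogue turns (order-preserving) and map IDs to sentences."""
    cleaned = (turn.strip() for turn in turns)
    unique = list(dict.fromkeys(text for text in cleaned if text))
    return {sentence_id: text for sentence_id, text in enumerate(unique)}


turns = ["i also want free wifi", "book a taxi", "i also want free wifi ", ""]
sentences = assign_sentence_ids(turns)
```

The resulting IDs are what a retriever would use to recognize, and skip, entries that originate from the same sentence as the query.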
We followed the data preprocessing procedure in Section 3.1 to generate the prefix-suffix pairs of the training dataset, pre-computing the prefix and suffix embeddings, which were indexed and stored using FAISS. The prefix-suffix embeddings were computed using pre-trained sentence transformers. We also precomputed the prefix embeddings for the validation and test sets to speed up retrieval during evaluation. However, these embeddings were not indexed in FAISS, to avoid cheating.
4.2 Training details
In SUREALM, there are two modeling components to consider: (1) the encoding model; (2) the language model. For the encoding model, we used pre-trained sentence transformers to encode prefixes and suffixes. We tried small-scale and standard-scale models, with 6 layers and 384 dimensions and with 12 layers and 768 dimensions respectively, in our experiments. For the language model, we employed a transformer-based language model with various weight initialization strategies. Inspired by prior work on leveraging pre-trained checkpoints for sequence generation, we explored different sentence transformer checkpoints to initialize the language model weights. For small-scale model training, we used a batch size of 128, the AdamW optimizer with a learning rate of 2e-5, and a linear learning rate scheduler with 500 warmup steps. For standard-scale model training, we used a batch size of 64 and a learning rate of 1e-5, and otherwise kept the same settings as in the small-scale model training. Since our dataset was relatively small, we trained SUREALM for a maximum of 200 epochs and chose the model with the minimum validation perplexity.
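A linear learning rate schedule with warmup can be sketched as follows, assuming the common form of linear warmup followed by linear decay to zero. The learning rate of 2e-5 and the 500 warmup steps come from the small-scale setting above, while `total_steps` is a hypothetical value chosen only for illustration.

```python
def linear_schedule_lr(step, base_lr=2e-5, warmup_steps=500, total_steps=10_000):
    """Learning rate at a given optimizer step.

    Rises linearly from 0 to base_lr over warmup_steps, then decays
    linearly to 0 at total_steps (total_steps is an assumed value).
    """
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / (total_steps - warmup_steps)


lr_mid_warmup = linear_schedule_lr(250)
lr_peak = linear_schedule_lr(500)
lr_end = linear_schedule_lr(10_000)
```

Halfway through warmup the rate is half the peak value, the peak is reached exactly at the end of warmup, and the rate reaches zero at the final step.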
In our preliminary experiments, we chose the best configuration of the hyper-parameters based on the validation perplexity. Results showed that it was crucial to retrieve at every prediction time step. Moreover, we applied suffix truncation, and the number of retrieved top-K prefix-suffix embeddings was chosen to yield the best perplexity results. We fixed these hyper-parameters for further experiments below.
4.3.1 Small-scale models
For baselines, we fine-tuned the 6-layer transformer-based masked LM (MiniLM) initialized with random and pre-trained weights. We used the pre-trained sentence transformers (multi-qa-MiniLM-L6-cos-v1 and all-MiniLM-L6-v2) as our encoding models. We initialized our LM with random weights, pre-trained sentence transformer weights, and the masked LM weights. Results in Table 2 show that:
- SUREALM achieved lower perplexity than the baselines in all experiments. Our best model achieved a relative test perplexity reduction of 7.2% compared to the baseline.
- LM weights can be initialized differently from the sentence transformer (ST) weights for model training without any performance degradation. This implies flexibility for different weight combinations.
4.3.2 Standard-scale models
We then trained SUREALM using standard-scale models and compared it with popular state-of-the-art LM baselines such as BERT, RoBERTa and GPT2. However, since these models use different word tokenizers, resulting in different output vocabulary sizes, we can only compare models with similar vocabulary sizes. Table 3 shows that SUREALM achieved a relative test perplexity reduction of 7.1%. Table 4 shows perplexity results with an increased vocabulary size of 50k. SUREALM achieved relative test perplexity reductions of 3.2% and 2% compared to the GPT2 and RoBERTa baselines respectively.
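The relative perplexity reductions quoted above follow the usual definition, sketched below. The function name is hypothetical, and the two perplexity values are made-up numbers used only to show the arithmetic, not results from the experiments.

```python
def relative_reduction(baseline_ppl, model_ppl):
    """Relative perplexity reduction over a baseline, in percent."""
    return 100.0 * (baseline_ppl - model_ppl) / baseline_ppl


# Illustrative values only: a baseline at 20.0 and a model at 18.0.
reduction = relative_reduction(20.0, 18.0)
```

A positive value means the model has lower perplexity than the baseline, and a value of zero means no improvement.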
Tables 2–4: perplexity results with columns ST weight, LM weight, Val. ppl, and Test ppl.
During embedding retrieval, we investigated including the current word in the suffix of a training sentence, meaning that we split a training sentence only into prefix and suffix instead of the prefix, current word and suffix described in Section 3.1. We then followed the same procedure to encode the prefixes and suffixes and reran SUREALM training and evaluation. However, we observed no test perplexity reduction compared to the baseline. Excluding the current word from the suffix may be analogous to applying a mask token in a masked LM. After excluding the current word, SUREALM focuses on information from the word history and the retrieved suffix context for word prediction. It is possible that the embedding retrieval results contain sentences that share similar prefixes but have a suffix identical to that of the current input sentence. From this perspective, excluding the current word from the suffix is reasonable, as it prevents SUREALM from relying too heavily on the suffix embeddings and neglecting the word history in word prediction.
We have proposed a suffix retrieval-augmented language model that simulates a bi-directional contextual effect while remaining autoregressive, so that our model can be used for sequence generation. Our proposed model shows promising perplexity performance compared to state-of-the-art LM baselines. In the future, we plan to evaluate our model on large corpora. In addition, we plan to extend our model to conditional generation such as dialogue response generation. Lastly, we will investigate domain LM adaptation using our proposed model.
References
- [1] (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT, Vol. 1, pp. 4171–4186. Cited by: §2.
- [2] (2004-05) Language model adaptation for statistical machine translation based on information retrieval. In Proceedings of the Fourth International Conference on Language Resources and Evaluation (LREC’04), Lisbon, Portugal. Cited by: §2.
- [3] (2020) REALM: retrieval-augmented language model pre-training. In Proceedings of the 37th International Conference on Machine Learning, ICML 2020. Cited by: §1, §2.
- [4] (2019) Billion-scale similarity search with GPUs. IEEE Transactions on Big Data 7 (3), pp. 535–547. Cited by: §1, §3.1, §4.1.
- [5] (2020) Generalization through memorization: nearest neighbor language models. In International Conference on Learning Representations (ICLR). Cited by: §2.
- [6] (2021-02) Beyond domain APIs: task-oriented conversational modeling with unstructured knowledge access track in DSTC9. In Proceedings of DSTC9 Workshop @ AAAI. Cited by: §4.1.
- [7] (2020) Retrieval-augmented generation for knowledge-intensive NLP tasks. In 34th Conference on Neural Information Processing Systems, NeurIPS 2020. Cited by: §2.
- [8] (2019) RoBERTa: a robustly optimized BERT pretraining approach. CoRR abs/1907.11692. Cited by: §2.
- [9] (1999) Improved topic-dependent language modeling using information retrieval techniques. In IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP 1999). Cited by: §2.
- [10] (2018) Improving language understanding by generative pre-training. Technical report, OpenAI. Cited by: §2.
- [11] (2019) Sentence-BERT: sentence embeddings using siamese BERT-networks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing. Cited by: §1, §3.1, §4.1, §4.2.
- [12] (2020) Leveraging pre-trained checkpoints for sequence generation tasks. Transactions of the Association for Computational Linguistics 8, pp. 264–280. Cited by: §4.2.
- [13] (2017) Attention is all you need. In Advances in Neural Information Processing Systems, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), Vol. 30. Cited by: §2.
- [14] (2020-10) Transformers: state-of-the-art natural language processing. In Proceedings of EMNLP: System Demonstrations, Online, pp. 38–45. Cited by: §3.3.