An Embedded Deep Learning based Word Prediction

by Seunghak Yu, et al.

Recent developments in deep learning with application to language modeling have led to success in text processing, summarization, and machine translation. However, deploying huge language models on mobile devices, for example in on-device keyboards, makes computation a bottleneck because of the devices' limited computation capacity. In this work we propose an embedded deep learning based word prediction method that optimizes run-time memory and also provides a real-time prediction environment. Our model size is 7.40MB and its average prediction time is 6.47ms. We improve over existing methods for word prediction in terms of key stroke savings and word prediction rate.




1 Introduction

Recurrent neural networks (RNNs) have delivered state-of-the-art performance on language modeling (RNN-LM) (Mikolov et al., 2010; Kim et al., 2015; Miyamoto and Cho, 2016). A major advantage of RNN-LMs is that these models inherit the property of storing and accessing information over arbitrary context lengths from RNNs (Karpathy et al., 2015). The model takes as input a textual context and generates a probability distribution over the words in the vocabulary for the next word in the text.

However, the state-of-the-art RNN-LM requires over 50MB of memory (the model of Zoph and Le (2016) contains 25M parameters, quantized to 2 bytes each). This has prevented the deployment of RNN-LMs on mobile devices for word prediction, word completion, and error correction tasks. Even on high-end devices, keyboards have constraints on memory (10MB) and response time (10ms), hence we cannot apply an RNN-LM directly without compression.
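The memory figure above follows from simple arithmetic, sketched below. The helper name `model_size_mb` and the convention 1 MB = 10^6 bytes are assumptions made for illustration; only the parameter count, the 2-byte quantization, and the 10MB keyboard budget come from the text.

```python
def model_size_mb(num_params: int, bytes_per_param: int) -> float:
    """Raw parameter storage in megabytes, assuming 1 MB = 10**6 bytes."""
    return num_params * bytes_per_param / 1e6


# 25M parameters stored as 2-byte values, as quoted in the text.
large_lm_mb = model_size_mb(25_000_000, 2)

# Keyboard memory budget quoted in the text.
keyboard_budget_mb = 10.0

# How many times the large model exceeds the budget.
over_budget_factor = large_lm_mb / keyboard_budget_mb

print(f"model: {large_lm_mb} MB, budget exceeded by {over_budget_factor}x")
```

Under these assumptions the uncompressed model is five times larger than the keyboard budget, which is why compression is required before deployment.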

Various deep model compression methods have been developed. Compression through matrix factorization (Sainath et al., 2013; Xue et al., 2013; Nakkiran et al., 2015; Prabhavalkar et al., 2016; Lu et al., 2016) has shown promising results in model compression but has so far been applied to automatic speech recognition tasks. Network pruning (LeCun et al., 1989; Han et al., 2015a, b) keeps the most relevant parameters while removing the rest. Weight sharing (Gong et al., 2014; Chen et al., 2015; Ullrich et al., 2017) attempts to quantize the parameters into clusters. Network pruning and weight sharing methods consider only memory constraints while compressing the models. They achieve a high compression rate but do not optimize test-time computation, and hence none of them is suitable for our application.

Figure 1: Overview of the proposed method. The figure marks the logits of each model and the softened output of the ensemble. The shared-matrix factors substitute for the embedding and softmax weight matrices in the proposed model.

To address the constraints on both memory size and computation, we propose a word prediction method that optimizes for run-time and memory to deliver smooth performance on embedded devices. We propose shared matrix factorization to compress the model, along with knowledge distillation to compensate for the loss in accuracy caused by compression. The resulting model is approximately 8× compressed with negligible loss in accuracy and has a response time of 6.47ms per prediction. To the best of our knowledge, this is the first approach to use RNN-LMs for word prediction on mobile devices, whereas previous approaches used n-gram based statistical language models (Klarlund and Riley, 2003; Tanaka-Ishii, 2007) or are unpublished. We achieve better performance than existing approaches in terms of Key Stroke Savings (KSS) (Fowler et al., 2015) and Word Prediction Rate (WPR). The proposed method has been successfully commercialized.

2 Proposed Method

2.1 Overview

Figure 1 shows an overview of our approach. We propose a pipeline to compress an RNN-LM for on-device word prediction with negligible loss of accuracy. The following sections describe each step of our method. In Section 2.2, we describe the basic architecture of the language model, which is used as the elementary model in our pipeline. In Section 2.3, we describe how to build a distilled model by knowledge distillation and compensate for the loss in accuracy due to compression. Section 2.4 then describes model compression strategies to reduce memory usage and run-time.

2.2 Baseline Language Model

All language models in our pipeline mimic the conventional RNN-LM architecture shown in Figure 2. Each model consists of three parts: word embedding, recurrent hidden layers, and softmax layer. Word embedding (Mikolov et al., 2013) takes the input word $w_t$ at time $t$ as a one-hot vector and maps it to $x_t$ in a continuous vector space $\mathbb{R}^{D}$. This process is parametrized by the embedding matrix $W_e$ as $x_t = W_e^{\top} w_t$, where $W_e \in \mathbb{R}^{|V| \times D}$, $V$ is the vocabulary, and $D$ is the dimension of the embedding space.

Figure 2: Conventional RNN-LM.

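The embedded word is input to LSTM based hidden layers. We use an architecture similar to the non-regularized LSTM model of Zaremba et al. (2014). The hidden state $h_t$ of the LSTM unit is affine-transformed and passed through the softmax function, which gives a probability distribution over all the words in the vocabulary $V$, as in Eq. 1, where $W_s^{(j)}$ and $b_s^{(j)}$ denote the row of the softmax weight matrix $W_s$ and the bias entry for word $j$.

$$P(w_{t+1} = j \mid w_{\le t}) = \frac{\exp\left(W_s^{(j)} h_t + b_s^{(j)}\right)}{\sum_{k \in V} \exp\left(W_s^{(k)} h_t + b_s^{(k)}\right)} \qquad (1)$$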
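The embedding lookup and the softmax output layer described in this section can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: the toy sizes, the random weights, the variable names (`W_e`, `W_s`, `b_s`), and the `tanh` stand-in for the LSTM are all assumptions.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D array of logits."""
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


rng = np.random.default_rng(0)
vocab_size, dim = 1000, 8  # toy sizes; the paper reports 15K words and 600 dims

W_e = rng.normal(scale=0.1, size=(vocab_size, dim))  # embedding matrix
W_s = rng.normal(scale=0.1, size=(vocab_size, dim))  # softmax weights
b_s = np.zeros(vocab_size)                           # softmax bias

word_id = 42
x_t = W_e[word_id]  # row lookup, same result as a one-hot product

# Explicit one-hot product, shown only to confirm the equivalence.
one_hot = np.zeros(vocab_size)
one_hot[word_id] = 1.0
x_t_onehot = W_e.T @ one_hot

h_t = np.tanh(x_t)                # placeholder for the LSTM hidden state
probs = softmax(W_s @ h_t + b_s)  # distribution over the next word
```

In practice the embedding is implemented as a row lookup rather than a matrix product, since the one-hot product touches every row of the matrix for no benefit.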


We train the model with the cross-entropy loss function using the Adam optimizer (Kingma and Ba, 2014). The initial learning rate is set to 0.001, and it decays, with roll-back, after every epoch that yields no decrease in perplexity on the validation dataset.
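One way to read the learning-rate schedule is sketched below. The decay factor of 0.5, the function name `run_schedule`, and the interpretation of roll-back as returning to the best checkpoint so far are assumptions; the text states only the initial rate of 0.001 and the trigger condition.

```python
def run_schedule(val_perplexities, initial_lr=0.001, decay=0.5):
    """Simulate decay-with-roll-back over a list of per-epoch validation perplexities.

    Returns the epoch of the best checkpoint, the final learning rate,
    and a per-epoch log of (epoch, action, learning rate).
    """
    lr = initial_lr
    best_pp = float("inf")
    best_epoch = None
    history = []
    for epoch, pp in enumerate(val_perplexities):
        if pp < best_pp:
            # Perplexity improved: keep this epoch as the new checkpoint.
            best_pp, best_epoch = pp, epoch
            action = "keep"
        else:
            # No improvement: decay the rate and resume from the best checkpoint.
            lr *= decay
            action = "rollback"
        history.append((epoch, action, lr))
    return best_epoch, lr, history


best_epoch, final_lr, log = run_schedule([80.0, 70.0, 72.0, 65.0, 66.0])
```

In a real training loop the "rollback" branch would also reload the saved parameters of the best checkpoint before continuing.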

2.3 Distilling Language Model

Knowledge Distillation (KD) (Hinton et al., 2015) uses an ensemble of pre-trained teacher models (typically deep and large) to train a distilled model (typically shallower). KD helps provide global information to the distilled model, and hence regularizes it and gives faster updates for the parameters.

We refer to the true labels from the data as ‘hard targets’. In contrast to the baseline model, which uses only ‘hard targets’, we adapt KD to learn a combined cost function from ‘hard targets’ and ‘soft targets’. ‘Soft targets’ are generated by applying a temperature $T$ (Eq. 2) to the averaged logits $z$ of the teachers, and are used to train the distilled model.

$$p_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)} \qquad (2)$$
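A minimal NumPy sketch of soft-target generation and a combined cost is given below. The mixing weight `alpha`, the temperature value, and the function names are assumptions, since the text does not state how the two cost terms are weighted.

```python
import numpy as np


def softmax(z, T=1.0):
    """Temperature-scaled softmax; larger T gives a flatter distribution."""
    z = np.asarray(z, dtype=float) / T
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def combined_loss(student_logits, teacher_logits_list, hard_label, T=2.0, alpha=0.5):
    """Weighted sum of hard-target and soft-target cross-entropy (alpha is assumed)."""
    avg_teacher = np.mean(teacher_logits_list, axis=0)  # average the teachers' logits
    soft_targets = softmax(avg_teacher, T)              # softened ensemble output
    student_soft = softmax(student_logits, T)
    student_hard = softmax(student_logits, 1.0)
    hard_loss = -np.log(student_hard[hard_label])
    soft_loss = -np.sum(soft_targets * np.log(student_soft))
    return alpha * hard_loss + (1.0 - alpha) * soft_loss


teachers = [[2.0, 1.0, 0.0], [1.0, 2.0, 0.0]]
loss = combined_loss([1.0, 0.5, -0.2], teachers, hard_label=0)
sharp = softmax([2.0, 1.0, 0.0], T=1.0)
flat = softmax([2.0, 1.0, 0.0], T=4.0)
```

Raising the temperature spreads probability mass over more words, so the distilled model also learns which alternatives the teachers consider plausible.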


Experiments in Table 2 show an improvement in perplexity compared to the models trained only with ‘hard targets’. We also use the combined cost function to retrain the model after compression. Retraining with the combined cost function compensates for the loss in performance due to the compression proposed in Section 2.4.

Language Source Words Sentences
EN Reddit 1.1B 71.2M
EN Twitter 0.9B 55.2M
Table 1: Collected data for language modeling.

2.4 Shared Matrix Factorization

We present a compression method using shared matrix factorization for the embedding and softmax layers in an RNN-LM. In the language model, word embedding is trained to map words with similar contexts close to each other in a solution space, while the softmax layer maps contexts to similar words. Therefore, we assume we can find sharable parameters that have characteristics similar to both embedding and softmax. Recently, there have been preprints (Press and Wolf, 2016; Inan et al., 2016) suggesting an overlap of characteristics between embedding and softmax weights.

We facilitate sharing by using a shared matrix $W_{shared}$ across the softmax and embedding layers, allowing for more efficient parameterization of weight matrices. This reduces the total parameters in the embedding and softmax layers by half. We introduce two trainable matrices $P_e$ and $P_s$, called the projection matrices, that adapt $W_{shared}$ for the individual tasks of embedding and softmax as in Eq. 3.

$$W_e = W_{shared} P_e, \qquad W_s = W_{shared} P_s \qquad (3)$$
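The parameter saving from sharing can be checked with a short count. The sketch assumes the shared matrix has shape vocabulary × dimension and that each projection matrix is square (dimension × dimension); these shapes are not stated in the text, and the counts are not meant to reproduce the sizes in Table 2.

```python
def embedding_softmax_params(vocab_size: int, dim: int, shared: bool) -> int:
    """Count weights in the embedding and softmax layers (biases ignored)."""
    if not shared:
        # Two independent vocab_size x dim matrices.
        return 2 * vocab_size * dim
    # One shared vocab_size x dim matrix plus two dim x dim projections (assumed shapes).
    return vocab_size * dim + 2 * dim * dim


separate = embedding_softmax_params(15_000, 600, shared=False)
with_sharing = embedding_softmax_params(15_000, 600, shared=True)
ratio = with_sharing / separate
print(separate, with_sharing, round(ratio, 2))
```

With a 15K vocabulary and 600 dimensions, the two projection matrices add little, so the count falls to roughly half, in line with the statement above.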


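Furthermore, in the layers parametrized by the shared matrix $W_{shared}$, only a few outputs are active for a given input; we suspect that they are probably correlated and that the underlying weight matrix has low rank $r$. For such a weight matrix $W$, there exists a factorization $W_{m \times n} = W^a_{m \times r} W^b_{r \times n}$ where $W^a$ and $W^b$ are full rank (Strang et al., 1993). In our low-rank compression strategy, we expect the rank of $W$ to be $\hat{r}$, which leads to the factorization in Eq. 4.

$$W_{m \times n} \approx W^a_{m \times \hat{r}} W^b_{\hat{r} \times n} \qquad (4)$$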


Moreover, we compress $W$ by applying Singular Value Decomposition (SVD) to initialize the decomposed matrices $W^a$ and $W^b$ of Eq. 4. SVD has been proposed as a promising method to perform factorization for low rank matrices (Nakkiran et al., 2015; Prabhavalkar et al., 2016). We apply SVD on $W_{m \times n}$ to decompose it as $U \Sigma V^{\top}$. $U$, $\Sigma$, and $V$ are used to initialize $W^a$ and $W^b$ for the retraining process. We use the top $\hat{r}$ singular values from $\Sigma$, where $\hat{r}$ is the expected rank, and the corresponding $\hat{r}$ rows from $V^{\top}$. Therefore, with $W^a = U \Sigma$ and $W^b = V^{\top}$ (both truncated to rank $\hat{r}$), we replace all the linear transformations using $W$ with $W^a W^b$. The approximation in Eq. 4 during factorization degrades model performance, but subsequent fine-tuning through retraining restores accuracy. This compression scheme is applied, without loss of generality, to the shared matrix $W_{shared}$.
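The SVD-based initialization can be sketched with NumPy as follows. The function name `low_rank_init`, the toy matrix, and the chosen rank are illustrative assumptions; retraining after initialization is not shown.

```python
import numpy as np


def low_rank_init(W: np.ndarray, rank: int):
    """Initialize two low-rank factors of W from its truncated SVD."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    W_a = U[:, :rank] * s[:rank]  # left singular vectors scaled by singular values
    W_b = Vt[:rank, :]            # top `rank` rows of V transposed
    return W_a, W_b


rng = np.random.default_rng(0)
# Toy weight matrix built to have rank exactly 4.
W = rng.normal(size=(50, 4)) @ rng.normal(size=(4, 20))

W_a, W_b = low_rank_init(W, rank=4)
error = np.linalg.norm(W - W_a @ W_b)

params_before = W.size
params_after = W_a.size + W_b.size
```

Because the toy matrix has rank exactly 4, the rank-4 factors reproduce it up to floating-point error while storing 280 values instead of 1000. For a real weight matrix the truncation is lossy, which is why the retraining step is needed.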

3 Experiment Results

3.1 Evaluation of proposed approach

Table 1 describes the source of the dataset and the number of words and sentences (footnote 1: the dataset is available online). This data is extracted from resources on the Internet, in a raw form with 8 billion words. We uniformly sample 10% (196 million) from the dataset. The sample is split into 60% for training, 10% for validation and 30% for test.

We preprocess the raw data to remove noise and filter phrases. We also replace numbers in the dataset with a special symbol, and out-of-vocabulary (OOV) words with an unknown-word token. We append a start-of-sentence token and an end-of-sentence token to every sentence. We convert our dataset to lower-case to increase vocabulary coverage and use the top 15K words as the vocabulary.

Model PP Size CR
Baseline 56.55 56.76 -
+ KD 55.76 56.76 -
+ Shared Matrix 55.07 33.87 1.68
+ Low-Rank, Retrain 59.78 14.80 3.84
+ Quantization 59.78 7.40 7.68
Table 2: Evaluation of each model in our pipeline. Baseline uses ‘hard targets’ and Knowledge Distillation (KD) uses ‘soft targets’. Size is in MB and 16 bit quantization is applied to the final model. PP: Word Perplexity, CR: Compression Rate.

Table 2 shows the evaluation result of each step in our pipeline. We empirically select an embedding dimension of 600 and a single hidden layer with 600 LSTM hidden units for the baseline model. Word perplexity is used to evaluate and compare our models. Perplexity over the test set is computed as $\exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(w_i \mid w_{<i})\right)$, where $N$ is the number of words in the test set. Our final model is roughly 8× smaller than the baseline with 5% (3.16) loss in perplexity.
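Word perplexity can be computed from the probabilities a model assigns to each test word, as in the short sketch below. The function name and the toy probabilities are assumptions for illustration.

```python
import math


def perplexity(word_probs):
    """Exponential of the average negative log-probability per word."""
    n = len(word_probs)
    return math.exp(-sum(math.log(p) for p in word_probs) / n)


# A model that assigns probability 0.25 to every test word behaves like
# a uniform choice among four words.
uniform_pp = perplexity([0.25] * 8)

# Assigning higher probability to the observed words lowers perplexity.
better_pp = perplexity([0.5, 0.5, 0.25, 0.5])
```

Perplexity can be read as the effective number of equally likely words the model is choosing among at each step, so lower is better.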

3.2 Performance Comparison

We compare our performance with existing word prediction methods using a manually curated dataset, which covers general keyboard scenarios (footnote 2: the dataset consists of 102 sentences (926 words, 3,746 characters), a collection of formal and informal utterances from various sources; it is available online). Due to lack of access to the language modeling engines used in other solutions, we are unable to compare word perplexity. To the best of our efforts, we minimize all the personalization these solutions offer in their prediction engines. We performed human evaluation on the manually curated dataset, employing three evaluators from the inspection group to cross-validate all the tests in Table 3 to eliminate human errors.

Developer KSS(%) WPR(%)
Proposed 65.11 34.38
Apple 64.35 33.73
Swiftkey 62.39 31.14
Samsung 59.81 28.84
Google 58.89 28.02
Table 3: Performance comparison of proposed method and other commercialized keyboard solutions by various developers.

We achieve the best performance compared to other solutions in terms of Key Stroke Savings (KSS) and Word Prediction Rate (WPR), as shown in Table 3. KSS is the percentage of key strokes not pressed compared to a keyboard without any prediction or completion capabilities. Every character the user enters by using the predictions of the language model counts as a key stroke saving. WPR is the percentage of correct word predictions in the dataset.
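The two metrics can be illustrated with a small calculation. The sketch assumes each word is described by how many characters the user had to type before the word appeared among the suggestions (0 meaning it was predicted before any typing), and it ignores spaces and the tap that selects a suggestion; the exact counting rules used in the evaluation are not given in the text.

```python
def kss(words, typed_counts):
    """Key Stroke Savings: percentage of characters the user did not have to type."""
    total_chars = sum(len(w) for w in words)
    pressed = sum(typed_counts)
    return 100.0 * (total_chars - pressed) / total_chars


def wpr(typed_counts):
    """Word Prediction Rate: percentage of words predicted before any typing."""
    predicted = sum(1 for c in typed_counts if c == 0)
    return 100.0 * predicted / len(typed_counts)


# Hypothetical typing session: characters typed before each word was suggested.
words = ["last", "year", "i", "published"]
typed = [2, 0, 1, 3]

session_kss = kss(words, typed)
session_wpr = wpr(typed)
```

In this hypothetical session the user types 6 of 18 characters, and one word out of four is predicted outright.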

While evaluating KSS and WPR, the number of predictions for the next word is the same for all the solutions. The proposed method achieves 65.11% in KSS and 34.38% in WPR, which are the best scores among the compared solutions. For example, if the user intends to type “published” and types only 34.89% of the characters (“pub”), one of the top two predictions is “published”. Furthermore, 34.38% of the words the user intends to type are among the top three predictions. Figure 3 shows an example of word prediction across different solutions. In this example, we can spot some grammatical errors in the predictions from other solutions.

Figure 3: Example of comparison with other commercialized solutions. Predicted words for the contexts “Last year I” and “Next year I”.

4 Conclusions and Future Work

We have proposed a practical method for training and deploying an RNN-LM on mobile devices that satisfies memory and run-time constraints. Our method utilizes the averaged output of teachers to train a distilled model and compresses its weight matrices by applying shared matrix factorization. We achieve 7.40MB in memory size and satisfy the run-time constraint of 10ms with an average prediction time of 6.47ms. We have also compared the proposed method to existing commercialized keyboards in terms of key stroke savings and word prediction rate. In our benchmark tests, our method outperformed the others.

RNN-LM does not support personalization independently. However, our model, which is currently commercialized, uses RNN-LM along with n-gram statistics to learn the user’s input pattern and a uni-gram model to cover OOV words. Future work is required on directly personalizing the RNN-LM to the user’s preferences, rather than interpolating it with n-gram statistics, to take full advantage of the model.