Gboard — the Google keyboard — is a virtual keyboard for touch screen mobile devices with support for more than 600 language varieties and over 1 billion installs as of 2018. Gboard provides a variety of input features, including tap and word-gesture typing, auto-correction, word completion and next word prediction.
Learning frequently typed words from user-generated data is an important component to the development of a mobile keyboard. Example usage includes incorporating new trending words (celebrity names, pop culture words, etc.) as they arise, or simply compensating for omissions in the initial keyboard implementation, especially for low-resource languages.
The list of words is often referred to as the “vocabulary”, and may or may not be hand-curated. Words missing from the vocabulary cannot be predicted on the keyboard suggestion strip, cannot be gesture typed, and, more annoyingly, may be autocorrected even when typed correctly Ouyang et al. (2017). Moreover, for latency and reliability reasons, mobile keyboard models run on-device. This means that the vocabulary supporting the models is intrinsically limited in size, e.g. to a couple hundred thousand words per language. It is therefore crucial to discover and include the most useful words in this rather short vocabulary list. Words not in the vocabulary are often called “out-of-vocabulary” (OOV) words. Note that the concept of vocabulary is not limited to mobile keyboards. Other natural language applications, such as neural machine translation (NMT), rely on a vocabulary to encode words during end-to-end training. Learning OOV words and their rankings is thus a fairly generic technology need.
The focus of our work is learning OOV words in an environment that avoids transmitting and storing sensitive user content on centralized servers. Privacy is easier to ensure when words are learned on device. Our work builds upon recent advances in Federated Learning (FL) Konečnỳ et al. (2016); Yang et al. (2018); Hard et al. (2018), a decentralized learning technique that trains models such as neural networks on users’ devices, uploading only ephemeral model updates to the server for aggregation and leaving the users’ raw data on their devices. Approaches based on hash maps, count sketches Charikar et al. (2002), or tries Gunasinghe and Alahakoon (2012) require significant adaptation to run in FL settings Zhu et al. (2019). Our work builds upon a very general FL framework for neural networks, in which we train a federated character-based recurrent neural network (RNN) on device. OOV words are Monte Carlo sampled on servers during inference (details in Section 2.1).
While Federated Learning removes the need to upload raw user material — here OOV words — to the server, the privacy risk of unintended memorization still exists (as demonstrated in Carlini et al. (2018)). Such risk can be mitigated, usually with some accuracy cost, using techniques including differential privacy McMahan et al. (2018). Exploring these trade-offs is beyond the scope of this paper.
The proposed approach relies on a learned probabilistic model, and may therefore not faithfully regenerate the words it is trained on. It may “daydream” OOV words, that is, come up with character sequences never seen in the training data. Conversely, it may fail to regenerate some words that would be interesting to learn. The key to demonstrating the practicality of the approach is to answer: (1) How frequently do daydreamed words occur? (2) How well does the sampled distribution represent the true word frequencies in the dataset? In response to these questions, our contributions include the following:
1. We train the proposed LSTM model on a public Reddit comments dataset in a simulated FL environment. The Reddit dataset contains user ID information for each entry (a user’s comments or posts), which can be used to mimic the FL process of learning from each client’s local data. The simulated FL model achieves high precision and recall for the most frequent unique words when aggregating a large number of parallel independent samplings.
2. We show the feasibility of training LSTM models from daily Gboard user data in real, on-device FL settings. The FL model reaches 55.8% character-level top-3 prediction accuracy and 2.35 cross entropy on users’ on-device data. We further show that the top sampled words are very meaningful and capture words we know to have been trending in the news at the time of the experiments.
2.1 LSTM Modeling
LSTM models Hochreiter and Schmidhuber (1997) have been successfully used in a variety of sequence processing tasks. In this work we use a variant of LSTM with a Coupled Input and Forget Gate (CIFG) Greff et al. (2017), peephole connections Gers and Schmidhuber (2000) and a projection layer Sak et al. (2014). CIFG couples the forget and input decisions together, thereby reducing the number of parameters. The projection layer, located at the output gate, generates the hidden state with a reduced dimension to speed up training. Peephole connections let the gate layers look at the cell state. We use multi-layer LSTMs Sutskever et al. (2014)
to increase the representation power of the model. The loss function of the LSTM model is computed as the cross entropy (CE) between the predicted character distribution at each step and a one-hot encoding of the current true character, summed over the words in the dataset:

$$\mathrm{CE} = -\sum_{x \in \mathcal{D}} \sum_{t=1}^{L} \log P(x_t \mid x_1, \ldots, x_{t-1}),$$

where $x = (x_1, \ldots, x_L)$ is a character sequence with arbitrary length $L$. RNNs estimate $P(x_t \mid x_1, \ldots, x_{t-1})$ as

$$P(x_t \mid x_1, \ldots, x_{t-1}) = P(x_t \mid h_{t-1}), \qquad h_{t-1} = f(x_1, \ldots, x_{t-1}),$$

where $f$ can be perceived as a function mapping the input sequence to the hidden state of the RNN. The sampling process starts with multiple threads, each beginning with the start-of-word token. At step $t$, each thread generates a random index based on $P(x_t \mid x_1, \ldots, x_{t-1})$. This is done iteratively until the end-of-word token is hit. This parallel sampling approach avoids the dependency between sampling threads that might occur in beam search or shortest-path search sampling Carlini et al. (2018).
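As a concrete illustration, the parallel sampling procedure above can be sketched as follows. The `next_char_probs` function is a hypothetical stand-in for the trained CIFG-LSTM forward pass, with a toy four-character alphabet and a fixed distribution; only the sampling loop mirrors the described procedure.

```python
import numpy as np

# Toy stand-in for the trained character model: given a prefix, return a
# probability distribution over the next character. A real implementation
# would run the CIFG-LSTM forward pass here.
ALPHABET = ["a", "b", "c", "$"]  # "$" is the end-of-word token

def next_char_probs(prefix):
    # Hypothetical fixed distribution; a trained model would condition on prefix.
    if len(prefix) >= 5:
        return np.array([0.0, 0.0, 0.0, 1.0])  # force termination
    return np.array([0.4, 0.3, 0.2, 0.1])

def sample_word(rng, max_len=20):
    """One independent Monte Carlo sampling thread: draw characters from the
    model's predictive distribution until the end-of-word token is hit."""
    word = ""
    while len(word) < max_len:
        probs = next_char_probs(word)
        ch = ALPHABET[rng.choice(len(ALPHABET), p=probs)]
        if ch == "$":
            break
        word += ch
    return word

def sample_words(n_samples, seed=0):
    """Run n_samples independent samplings and aggregate into word counts;
    the counts estimate the model's word distribution."""
    rng = np.random.default_rng(seed)
    counts = {}
    for _ in range(n_samples):
        w = sample_word(rng)
        counts[w] = counts.get(w, 0) + 1
    return counts

counts = sample_words(1000)
```

Because each sampling is independent, the threads can run in parallel on servers with no shared state, unlike beam search.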
2.2 Federated Learning
Federated learning (FL) is a new paradigm of machine learning that keeps training data localized on mobile devices and never collects them centrally. FL is shown to be robust to unbalanced or non-IID (independent and identically distributed) data distributions McMahan et al. (2016); Konečnỳ et al. (2016). It learns a shared model by aggregating locally-computed gradients. FL is especially useful for OOV word learning, since OOV words typically include sensitive user content. FL avoids the need for transmitting and storing such content on centralized servers. FL can also be combined with privacy-preserving techniques like secure aggregation Bonawitz et al. (2016) and differential privacy McMahan et al. (2018) to offer stronger privacy guarantees to users.
We use the FederatedAveraging algorithm presented in McMahan et al. (2016) to combine client updates after each round of local training and produce a new global model with weight

$$w_{t+1} = \sum_{k} \frac{n_k}{N} \, w_{t+1}^{k},$$

where $w_t$ is the global model weight at round $t$, $w_{t+1}^{k}$ is the model weight for each participating device $k$, and $n_k$ and $N = \sum_k n_k$ are the number of data points on client $k$ and the total over all devices, respectively. Adaptive $L_2$-norm clipping is performed on each client’s gradient, as it is found to improve the robustness of model convergence.
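A minimal sketch of this aggregation step, assuming each client's model arrives as a list of NumPy arrays; the fixed clipping threshold shown here simplifies the adaptive scheme described above.

```python
import numpy as np

def clip_update(update, clip_norm):
    """L2-norm clipping of a client's update (fixed threshold for
    illustration; the paper uses an adaptive threshold)."""
    norm = np.sqrt(sum(np.sum(u ** 2) for u in update))
    scale = min(1.0, clip_norm / (norm + 1e-12))
    return [u * scale for u in update]

def federated_averaging(global_weights, client_weights, client_sizes):
    """Data-weighted average of client models:
    w_{t+1} = sum_k (n_k / N) * w_{t+1}^k."""
    total = float(sum(client_sizes))
    new_weights = [np.zeros_like(w) for w in global_weights]
    for weights_k, n_k in zip(client_weights, client_sizes):
        for i, w in enumerate(weights_k):
            new_weights[i] += (n_k / total) * w
    return new_weights
```

Clients holding more data thus contribute proportionally more to the new global model.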
In this section, we show the details of our experiments in two different settings: (1) simulated FL with the public Reddit dataset Al-Rfou et al. (2016), and (2) on-device FL where user text never leaves the device. In all of our experiments, we use a filter to exclude “invalid” OOV patterns that represent things we don’t want the model to learn. The filtering excludes words that: (1) start with characters other than ASCII alphabetical letters (mostly emoji), (2) contain numbers (e.g. street or telephone numbers), (3) contain repetitive patterns like “hahaha” or “yesss”, (4) are no longer than 2 characters. In both the simulated and on-device FL settings, training and evaluation are defined by separate computation tasks that sample users’ data independently.
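The filtering rules above can be sketched as follows; the repetition regex and the exact rule boundaries are illustrative assumptions, not the production filter.

```python
import re

# A 1- or 2-character group repeated 3+ times, e.g. "yesss" or "hahaha".
REPEAT = re.compile(r"(.{1,2})\1{2,}")

def is_valid_oov(word):
    """Sketch of the OOV validity filter described in the text."""
    if len(word) <= 2:                                   # rule 4: too short
        return False
    if not (word[0].isascii() and word[0].isalpha()):    # rule 1: non-ASCII-letter start
        return False
    if any(c.isdigit() for c in word):                   # rule 2: contains numbers
        return False
    if REPEAT.search(word.lower()):                      # rule 3: repetitive pattern
        return False
    return True
```

For example, “hahaha” and “yesss” are rejected by the repetition rule, while an ordinary novel word passes.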
3.1 Evaluation metrics
In OOV word learning, we are interested in how many words are missing from or daydreamed by the model sampling. In simulated FL, we have access to the datasets and know the ground-truth OOV words and their frequencies. Thus, the quality of the model can be evaluated based on precision and recall (PR). In the on-device FL setting, it is not possible to compute PR since users’ raw input data is inaccessible by design. We instead show that the model is able to converge to a good CE loss and top-K character-level prediction accuracy. Unlike PR, CE and accuracy do not need computationally-intensive sampling and can be computed on the fly during training.
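One plausible way to compute PR over the top-k words, assuming word-count dictionaries for the ground truth and for the aggregated samples (the exact definition used in the evaluation may differ):

```python
def precision_recall_at_k(true_counts, sampled_counts, k):
    """Compare the top-k sampled words against the top-k ground-truth words.
    Precision: fraction of the sampled top-k that appear in the true top-k.
    Recall: fraction of the true top-k recovered by the sampled top-k."""
    top_true = set(sorted(true_counts, key=true_counts.get, reverse=True)[:k])
    top_sampled = set(sorted(sampled_counts, key=sampled_counts.get, reverse=True)[:k])
    hits = len(top_true & top_sampled)
    return hits / len(top_sampled), hits / len(top_true)
```

Daydreamed words lower precision (they crowd the sampled top-k), while words the model fails to regenerate lower recall.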
3.2 Model parameters
Table 1 shows three different model hyper-parameter configurations used in our federated experiments; the columns give the number of RNN layers, the server-side learning rate, the momentum, and the batch size, respectively. The configurations vary in LSTM architecture and parameter count. Two of them apply standard SGD without momentum or clipping; the third shares its architecture with one of these and further applies adaptive gradient clipping, combined with Nesterov accelerated gradient Nesterov (1983) and a momentum hyper-parameter. Unlike server-based training, FL uses a client-side learning rate for local mini-batch updates, in addition to the server-side learning rate. The momentum configuration converges with a high server-side learning rate, while the other two diverge at such a value.
3.3 Federated learning settings
For both simulated and on-device FL, a fixed number of client updates is required to close each round, and each client performs a fixed number of local training epochs per round.
3.3.1 Simulated FL on Reddit data
The Reddit conversation corpus is a publicly-accessible and fairly large dataset. It includes diverse topics from 300,000 sub-forums Al-Rfou et al. (2016). The data are organized as “Reddit posts”, where users can comment on each other’s comments indefinitely and user IDs are kept. The data contain 133 million posts from 326 thousand different sub-forums, consisting of 2.1 billion comments. Unlike Twitter, Reddit message sizes are unlimited. Similar to the work in Al-Rfou et al. (2016), Reddit posts that have more than 1000 comments are excluded. The final filtered data used for FL simulation contain millions of comments from thousands of unique users, with millions of filtered OOV word occurrences, of which a smaller number are unique. As user-tagged data is needed for FL experiments, Reddit posts are sharded in the FL simulations by user ID to mimic the local client-cache scenario.
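Sharding by user ID to mimic per-client caches can be sketched as follows, assuming the corpus is available as (user_id, text) pairs:

```python
from collections import defaultdict

def shard_by_user(comments):
    """Group (user_id, text) pairs into per-user shards, each shard playing
    the role of one FL client's local on-device cache in simulation."""
    shards = defaultdict(list)
    for user_id, text in comments:
        shards[user_id].append(text)
    return dict(shards)
```

Each simulated FL round then samples a subset of these shards as the participating clients.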
3.3.2 FL on client device data
In the on-device FL setting, the original raw data content is not accessible for human inspection since it remains stored in local caches on client devices. In order to participate in a round of FL, client devices must: (1) have at least 2 GB of memory, (2) be charging, (3) be connected to an un-metered network, and (4) be idle. In this study we experiment with FL in three languages: (1) American English (en_US), (2) Brazilian Portuguese (pt_BR), and (3) Indonesian (in_ID). A separate FL model is trained on devices located in the corresponding region for each language. Although ground-truth OOV words are not accessible in the on-device setting, the model can still be evaluated with character-level metrics such as top-3 prediction accuracy or CE (e.g. given the current context “extra” in “extraordinary”, predict “o” as the next character).
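A sketch of the character-level top-k accuracy metric, assuming per-step model logits over the character vocabulary:

```python
import numpy as np

def top_k_char_accuracy(logits, targets, k=3):
    """Character-level top-k accuracy: a step counts as correct when the true
    next character is among the k highest-scoring characters.
    logits: [num_steps, vocab_size] array; targets: length-num_steps char ids."""
    top_k = np.argsort(logits, axis=1)[:, -k:]  # indices of the k largest scores
    correct = [t in row for t, row in zip(targets, top_k)]
    return float(np.mean(correct))
```

Because it needs only the model's per-step predictions and the local ground truth, this metric can be computed on-device during training without any sampling.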
4.1 FL simulation on Reddit data
During training, the configuration with momentum and adaptive clipping converges faster and better than the other two in both CE loss and top-3 prediction accuracy. The larger model does not lead to significant gains, while momentum and adaptive clipping lead to faster convergence and more stable performance.
Table 2 shows the top 10 OOV words with their occurring probability in the Reddit dataset (left) and the generative model (right). The model generally learns the probability of word occurrences, where the absolute value and relative rank for top words are very close to the ground truth.
In the precision and recall figures, PR is computed using a model checkpoint taken after 3000 rounds. The PR rate is plotted against the number of top unique words (x-axis). Curves shown in red, green, and blue represent increasing numbers of independent samplings. Both precision and recall improve significantly when the amount of sampling is increased. We also observe increased PR over the training rounds.
4.2 FL on local client data
For on-device settings, all three models converge over the course of about 4 days. Figures 4 and 5 compare the CE loss and top-3 prediction accuracy on evaluation data for the three FL settings. Similar to the Reddit experiments, the configuration with momentum and adaptive clipping converges faster and better than the other two, achieving 55.8% top-3 prediction accuracy and 2.35 CE loss. Experiments in pt_BR and in_ID show a very similar pattern among the three settings.
Table 3 shows sampled OOV words in the aforementioned three languages. Here “abbr.” is short for abbreviations. “slang/typo” refers to commonly spoken slang words and purposeful misspellings. “repetitive” refers to interjections or words that people commonly and intentionally misspell in a repetitive way. “foreign” refers to words typed in a language foreign to the current language/region setting. “names” refers to trending celebrities’ names. We also observed a lot of profanity learned by the model that is not shown here. Our future work will focus on better filtering out undesirable OOV words, especially unintended typos. This can be accomplished by using manually curated lists of desirable and undesirable OOV words, which can be updated with newer rounds of FL in an iterative process. As we sample words, those desirable or undesirable words can be continually incorporated and made available to all users in the future. The model can thus save more capacity to focus on new words.
In this paper, we present a method to discover OOV words through federated learning. The method relies on training a character-based model from which words can be generated via sampling. Compared with traditional server-side methods, our method learns OOV words on each device and transmits the learned knowledge by aggregating gradient updates from local SGD. We demonstrate the feasibility of this approach with simulated FL on a publicly-available corpus, where we achieve high precision and recall for the most frequent unique words. We also perform live experiments with on-device data from 3 populations of Gboard users and demonstrate that this method can learn OOV words effectively in a real-world setting.
The authors would like to thank colleagues in the Google AI team for providing the federated learning framework and for many helpful discussions.
- Al-Rfou et al. (2016) Rami Al-Rfou, Marc Pickett, Javier Snaider, Yun-hsuan Sung, Brian Strope, and Ray Kurzweil. 2016. Conversational contextual cues: The case of personalization and history for response ranking. arXiv preprint arXiv:1606.00372.
- Bonawitz et al. (2016) Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, Antonio Marcedone, H. Brendan McMahan, Sarvar Patel, Daniel Ramage, Aaron Segal, and Karn Seth. 2016. Practical secure aggregation for federated learning on user-held data. CoRR, abs/1611.04482.
- Carlini et al. (2018) Nicholas Carlini, Chang Liu, Jernej Kos, Úlfar Erlingsson, and Dawn Song. 2018. The secret sharer: Measuring unintended neural network memorization & extracting secrets. arXiv preprint arXiv:1802.08232.
- Charikar et al. (2002) Moses Charikar, Kevin Chen, and Martin Farach-Colton. 2002. Finding frequent items in data streams. In International Colloquium on Automata, Languages, and Programming, pages 693–703. Springer.
- Gers and Schmidhuber (2000) Felix A Gers and Jürgen Schmidhuber. 2000. Recurrent nets that time and count. In Proceedings of the IEEE-INNS-ENNS International Joint Conference on Neural Networks. IJCNN 2000. Neural Computing: New Challenges and Perspectives for the New Millennium, volume 3, pages 189–194. IEEE.
- Greff et al. (2017) Klaus Greff, Rupesh K Srivastava, Jan Koutník, Bas R Steunebrink, and Jürgen Schmidhuber. 2017. LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems, 28(10):2222–2232.
- Gunasinghe and Alahakoon (2012) Upuli Gunasinghe and Damminda Alahakoon. 2012. Sequence learning using the adaptive suffix trie algorithm. In Neural Networks (IJCNN), the 2012 International Joint Conference on, pages 1–8. IEEE.
- Hard et al. (2018) Andrew Hard, Kanishka Rao, Rajiv Mathews, Françoise Beaufays, Sean Augenstein, Hubert Eichner, Chloé Kiddon, and Daniel Ramage. 2018. Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604.
- Hochreiter and Schmidhuber (1997) Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long short-term memory. Neural computation, 9(8):1735–1780.
- Konečnỳ et al. (2016) Jakub Konečnỳ, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon. 2016. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492.
- McMahan et al. (2016) H Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, et al. 2016. Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629.
- McMahan et al. (2018) H. Brendan McMahan, Daniel Ramage, Kunal Talwar, and Li Zhang. 2018. Learning differentially private language models without losing accuracy. International Conference on Learning Representations.
- Nesterov (1983) Yurii E Nesterov. 1983. A method for solving the convex programming problem with convergence rate O(1/k^2). In Dokl. Akad. Nauk SSSR, volume 269, pages 543–547.
- Ouyang et al. (2017) Tom Ouyang, David Rybach, Françoise Beaufays, and Michael Riley. 2017. Mobile keyboard input decoding with finite-state transducers. arXiv preprint arXiv:1704.03987.
- Sak et al. (2014) Haşim Sak, Andrew Senior, and Françoise Beaufays. 2014. Long short-term memory recurrent neural network architectures for large scale acoustic modeling. In Fifteenth annual conference of the international speech communication association.
- Sutskever et al. (2014) Ilya Sutskever, Oriol Vinyals, and Quoc V Le. 2014. Sequence to sequence learning with neural networks. In Advances in neural information processing systems, pages 3104–3112.
- Yang et al. (2018) Timothy Yang, Galen Andrew, Hubert Eichner, Haicheng Sun, Wei Li, Nicholas Kong, Daniel Ramage, and Françoise Beaufays. 2018. Applied federated learning: Improving google keyboard query suggestions. arXiv preprint arXiv:1812.02903.
- Zhu et al. (2019) Wennan Zhu, Peter Kairouz, Haicheng Sun, Brendan McMahan, and Wei Li. 2019. Federated heavy hitters discovery with differential privacy.