Federated Learning Of Out-Of-Vocabulary Words

March 26, 2019 · Mingqing Chen, et al. · Google

We demonstrate that a character-level recurrent neural network is able to learn out-of-vocabulary (OOV) words under federated learning settings, for the purpose of expanding the vocabulary of a virtual keyboard for smartphones without exporting sensitive text to servers. High-frequency words can be sampled from the trained generative model by drawing from the joint posterior directly. We study the feasibility of the approach in two settings: (1) using simulated federated learning on a publicly available non-IID per-user dataset from a popular social networking website, (2) using federated learning on data hosted on user mobile devices. The model achieves good recall and precision compared to ground-truth OOV words in setting (1). With (2) we demonstrate the practicality of this approach by showing that we can learn meaningful OOV words with good character-level prediction accuracy and cross entropy loss.


1 Introduction

Gboard — the Google keyboard — is a virtual keyboard for touch screen mobile devices with support for more than 600 language varieties and over 1 billion installs as of 2018. Gboard provides a variety of input features, including tap and word-gesture typing, auto-correction, word completion and next word prediction.

Learning frequently typed words from user-generated data is an important component in the development of a mobile keyboard. Examples include incorporating new trending words (celebrity names, pop culture words, etc.) as they arise, or simply compensating for omissions in the initial keyboard implementation, especially for low-resource languages.

The list of words is often referred to as the “vocabulary”, and may or may not be hand-curated. Words missing from the vocabulary cannot be predicted on the keyboard suggestion strip, cannot be gesture typed, and, more annoyingly, may be autocorrected even when typed correctly Ouyang et al. (2017).

Moreover, for latency and reliability reasons, mobile keyboard models run on-device. This means that the vocabulary supporting the models is intrinsically limited in size, e.g. to a couple hundred thousand words per language. It is therefore crucial to discover and include the most useful words in this rather short vocabulary list. Words not in the vocabulary are often called “out-of-vocabulary” (OOV) words. Note that the concept of vocabulary is not limited to mobile keyboards. Other natural language applications, such as neural machine translation (NMT), rely on a vocabulary to encode words during end-to-end training. Learning OOV words and their rankings is thus a fairly generic technology need.

The focus of our work is learning OOV words in an environment where sensitive text is neither transmitted to nor stored on centralized servers. Privacy is easier to ensure when words are learned on device. Our work builds upon recent advances in Federated Learning (FL) Konečnỳ et al. (2016); Yang et al. (2018); Hard et al. (2018), a decentralized learning technique that trains models such as neural networks on users’ devices, uploading only ephemeral model updates to the server for aggregation and leaving the users’ raw data on their devices. Approaches based on hash maps, count sketches Charikar et al. (2002), or tries Gunasinghe and Alahakoon (2012) require significant adaptation to run in FL settings Zhu et al. (2019). Our work instead builds upon a very general FL framework for neural networks: we train a federated character-based recurrent neural network (RNN) on device, and OOV words are Monte Carlo sampled on servers during inference (details in Section 2.1).

While Federated Learning removes the need to upload raw user material — here OOV words — to the server, the privacy risk of unintended memorization still exists (as demonstrated in Carlini et al. (2018)). Such risk can be mitigated, usually with some accuracy cost, using techniques including differential privacy McMahan et al. (2018). Exploring these trade-offs is beyond the scope of this paper.

The proposed approach relies on a learned probabilistic model, and may therefore not faithfully regenerate the words it is trained on. It may “daydream” OOV words, that is, produce character sequences never seen in the training data, and it may fail to regenerate some words that would be interesting to learn. The key to demonstrating the practicality of the approach is to answer two questions: (1) How frequently do daydreamed words occur? (2) How well does the sampled distribution represent the true word frequencies in the dataset? In response to these questions, our contributions include the following:

1. We train the proposed LSTM model on a public Reddit comments dataset in a simulated FL environment. The Reddit dataset retains a user ID for each entry (a user’s comment or post), which can be used to mimic the FL process of learning from each client’s local data. The simulated FL model achieves good precision and recall for the top unique words, based on a large number of parallel independent samplings.

2. We show the feasibility of training LSTM models from daily Gboard user data in a real, on-device FL setting. The FL model reaches 55.8% character-level top-3 prediction accuracy and 2.35 cross-entropy loss on users’ on-device data. We further show that the top sampled words are very meaningful and capture words we know to have been trending in the news at the time of the experiments.

2 Method

2.1 LSTM Modeling

LSTM models Hochreiter and Schmidhuber (1997) have been successfully used in a variety of sequence processing tasks. In this work we use a variant of LSTM with a Coupled Input and Forget Gate (CIFG) Greff et al. (2017), peephole connections Gers and Schmidhuber (2000), and a projection layer Sak et al. (2014). CIFG couples the forget and input decisions, thereby reducing the number of parameters by 25%. The projection layer, located at the output gate, reduces the dimension of the hidden state and speeds up training. Peephole connections let the gate layers look at the cell state. We use multi-layer LSTMs Sutskever et al. (2014) to increase the representation power of the model. The loss function of the LSTM model is the cross entropy (CE) between the predicted character distribution $\hat{y}_t$ at each step $t$ and a one-hot encoding of the current true character $y_t$, summed over the words in the dataset.
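To make the loss concrete, here is a minimal sketch of the per-character CE computation (our illustration, not the authors’ code), assuming the model emits a softmax distribution over the character vocabulary at each step; the function name and toy vocabulary are hypothetical.

```python
import numpy as np

def char_cross_entropy(pred_probs, target_ids):
    """CE between predicted character distributions and one-hot targets,
    summed over the characters of one word.

    pred_probs: (T, V) array; row t is the softmax over the character
        vocabulary at step t.  target_ids: length-T true char indices.
    """
    # With a one-hot target, the CE at step t reduces to -log of the
    # probability assigned to the true character.
    return -sum(np.log(pred_probs[t, c]) for t, c in enumerate(target_ids))

# Toy check: vocabulary {a, b, c}, word "ab".
probs = np.array([[0.7, 0.2, 0.1],   # step 0
                  [0.1, 0.8, 0.1]])  # step 1
loss = char_cross_entropy(probs, [0, 1])  # -(log 0.7 + log 0.8) ≈ 0.58
```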

During the inference stage, sampling is based on the chain rule of probability,

$$P(x_1, \dots, x_n) = \prod_{t=1}^{n} P(x_t \mid x_1, \dots, x_{t-1}) \qquad (1)$$

where $x = (x_1, \dots, x_n)$ is a sequence of arbitrary length $n$. RNNs estimate $P(x_t \mid x_1, \dots, x_{t-1})$ as

$$P(x_t \mid x_1, \dots, x_{t-1}) \approx P(x_t \mid h_{t-1}), \qquad h_{t-1} = f(x_1, \dots, x_{t-1}) \qquad (2)$$

where $f$ can be perceived as a function mapping the input sequence to the hidden state of the RNN. The sampling process starts with multiple threads, each beginning with the start-of-word token. At step $t$, each thread generates a random character index according to $P(x_t \mid h_{t-1})$, iterating until the end-of-word token is hit. This parallel sampling approach avoids the dependency between sampling threads that would occur in beam search or shortest-path search sampling Carlini et al. (2018).
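The following sketch illustrates the ancestral sampling loop just described; `rnn_step`, the token ids, and `max_len` are our own placeholder assumptions, not the paper’s API.

```python
import numpy as np

BOW, EOW = 0, 1  # hypothetical start-of-word / end-of-word token ids

def sample_word(rnn_step, max_len=20, rng=None):
    """Draw one word by ancestral sampling from the character model.

    rnn_step(char_id, state) -> (probs, new_state) is assumed to return
    the next-character distribution P(x_t | h_{t-1}) of Eq. (2).
    """
    rng = rng or np.random.default_rng()
    state, char, word = None, BOW, []
    for _ in range(max_len):
        probs, state = rnn_step(char, state)
        char = rng.choice(len(probs), p=probs)  # x_t ~ P(x_t | h_{t-1})
        if char == EOW:
            break
        word.append(char)
    return word

# Because draws are independent, word frequencies over many parallel
# samplings estimate the joint posterior, e.g.:
#   counts = Counter(tuple(sample_word(rnn_step)) for _ in range(10**6))
```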

Figure 1: Monte Carlo sampling of OOV words from the LSTM model.

2.2 Federated Learning

Federated learning (FL) is a new paradigm of machine learning that keeps training data localized on mobile devices and never collects it centrally. FL has been shown to be robust to unbalanced and non-IID (not independent and identically distributed) data distributions McMahan et al. (2016); Konečnỳ et al. (2016). It learns a shared model by aggregating locally-computed gradients. FL is especially useful for OOV word learning, since OOV words typically include sensitive user content, and FL avoids the need to transmit and store such content on centralized servers. FL can also be combined with privacy-preserving techniques like secure aggregation Bonawitz et al. (2016) and differential privacy McMahan et al. (2018) to offer stronger privacy guarantees to users.

We use the FederatedAveraging algorithm presented in McMahan et al. (2016) to combine client updates after each round of local training, producing the new global model weights

$$w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{N} \, w_{t+1}^{k}$$

where $w_t$ is the global model weight at round $t$, $w_{t+1}^{k}$ is the model weight of each participating device $k$ after local training, and $n_k$ and $N = \sum_k n_k$ are the number of data points on client $k$ and the total across all devices, respectively. Adaptive $L_2$-norm clipping is performed on each client’s gradient, as it is found to improve the robustness of model convergence.
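A minimal sketch of the aggregation rule above, treating each model’s weights as a single flat NumPy vector; this is our illustration of FederatedAveraging, not the production implementation.

```python
import numpy as np

def federated_averaging(client_weights, client_sizes):
    """One FedAvg round: average client models weighted by n_k / N.

    client_weights: per-client weight vectors (flat arrays) after local
    training; client_sizes: local data counts n_k.
    """
    total = float(sum(client_sizes))
    return sum((n / total) * w for w, n in zip(client_weights, client_sizes))

# Example: two clients; the client with 3x the data gets 3x the weight.
w_new = federated_averaging([np.array([1.0]), np.array([4.0])],
                            [3, 1])  # 0.75*1.0 + 0.25*4.0 = 1.75
```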

3 Experiments

In this section, we show the details of our experiments in two settings: (1) simulated FL on the public Reddit dataset Al-Rfou et al. (2016), and (2) on-device FL where user text never leaves the device. In all of our experiments, we use a filter to exclude “invalid” OOV patterns that represent things we do not want the model to learn. The filter excludes words that: (1) do not start with an ASCII alphabetic character (this excludes, e.g., emoji), (2) contain numbers (street or telephone numbers), (3) contain repetitive patterns like “hahaha” or “yesss”, or (4) are no longer than 2 characters. In both the simulated and on-device FL settings, training and evaluation are defined by separate computation tasks that sample users’ data independently.
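A rough Python rendering of these four rules, under our own assumptions about tokenization (e.g. that apostrophes may appear inside words); the repetition regex is only a proxy for the paper’s unstated pattern matcher.

```python
import re

REPEATED = re.compile(r"(.+?)\1\1")  # substring repeated 3+ times: "hahaha", "yesss"

def is_valid_oov(word: str) -> bool:
    """Heuristic version of the four filtering rules described above."""
    if len(word) <= 2:                                 # rule (4): too short
        return False
    if not (word[0].isascii() and word[0].isalpha()):  # rule (1): emoji etc.
        return False
    if any(ch.isdigit() for ch in word):               # rule (2): numbers
        return False
    if REPEATED.search(word):                          # rule (3): repetition
        return False
    return True

assert is_valid_oov("upvote") and not is_valid_oov("hahaha")
```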

3.1 Evaluation metrics

In OOV word learning, we are interested in how many words are missing from or daydreamed by the model sampling. In simulated FL, we have access to the datasets and know the ground-truth OOV words and their frequencies; the quality of the model can thus be evaluated with precision and recall (PR). For the on-device FL setting, it is not possible to compute PR, since users’ raw input data is inaccessible by design. We instead show that the model converges to a good CE loss and top-K character-level prediction accuracy. Unlike PR, CE and accuracy do not require computationally-intensive sampling and can be computed on the fly during training.
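For the simulated setting, PR can be computed as in the following sketch (our formulation; the paper does not give explicit formulas): precision asks how many of the top-k sampled words are genuine OOV words, and recall asks how many of the top-k ground-truth words the sampler recovered.

```python
def top_k(counts, k):
    """The k most frequent keys of a {word: count} map."""
    return {w for w, _ in sorted(counts.items(), key=lambda x: -x[1])[:k]}

def precision_at_k(sampled_counts, true_words, k):
    """Fraction of the top-k sampled words that are genuine OOV words."""
    return len(top_k(sampled_counts, k) & set(true_words)) / k

def recall_at_k(true_counts, sampled_words, k):
    """Fraction of the top-k ground-truth OOV words that were sampled."""
    return len(top_k(true_counts, k) & set(sampled_words)) / k
```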

3.2 Model parameters

Table 1 shows the hyper-parameters of the three federated settings used in our experiments, denoted here S1, S2, and S3. $N$, $\eta_s$, $\mu$, and $B$ refer to the number of RNN layers, the server-side learning rate, the momentum, and the batch size, respectively. S1 and S2 apply standard SGD without momentum or clipping; they differ in LSTM architecture, with S1 using 2 RNN layers and smaller dimensions than S2. S3 has the same model architecture as S2 and further applies adaptive gradient clipping, combined with Nesterov accelerated gradient Nesterov (1983) and a momentum hyper-parameter of 0.9 (a simplified sketch of this server step is given after Table 1). Unlike server-based training, FL uses a client-side learning rate $\eta_c$ for the local mini-batch updates, in addition to the server-side learning rate $\eta_s$. S3 converges with $\eta_c = 0.5$, while S1 and S2 diverge with such a high value.

                              S1      S2      S3
momentum μ                    0.0     0.0     0.9
batch size B                  64      64      64
gradient clipping             0.0     0.0     *
server learning rate η_s      1.0     1.0     1.0
client learning rate η_c      0.1     0.1     0.5
RNN layers N                  2       3       3
                              256     256     256
                              16      128     128
                              64      128     128
Table 1: Hyper-parameters for three different FL settings. *Here adaptive clipping with a 0.99 percentile ratio is used together with server-side L2 clipping.
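As one plausible realization of the S3 server step, the sketch below combines an L2-clipped aggregate update with Nesterov momentum; the fixed clip_norm stands in for the paper’s adaptive 99th-percentile clipping, and all names are our own assumptions.

```python
import numpy as np

def server_update(w, delta, velocity, lr=1.0, momentum=0.9, clip_norm=1.0):
    """Nesterov-momentum server step on the aggregated update `delta`."""
    norm = np.linalg.norm(delta)
    if norm > clip_norm:                        # L2-clip the aggregate update
        delta = delta * (clip_norm / norm)
    velocity = momentum * velocity + delta      # accumulate momentum
    w = w + lr * (momentum * velocity + delta)  # Nesterov look-ahead step
    return w, velocity
```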

3.3 Federated learning settings

For both simulated and on-device FL, a fixed number of client updates is required to close each round, and the number of local epochs per training round is held fixed.

3.3.1 Simulated FL on Reddit data

The Reddit conversation corpus is a publicly-accessible and fairly large dataset covering diverse topics from 300,000 sub-forums Al-Rfou et al. (2016). The data are organized as “Reddit posts”, where users can comment on each other’s comments indefinitely, and user IDs are retained. The data contain 133 million posts from 326 thousand different sub-forums, consisting of 2.1 billion comments. Unlike Twitter, Reddit message sizes are unlimited. As in Al-Rfou et al. (2016), Reddit posts with more than 1000 comments are excluded. The final filtered data used for FL simulation contain millions of comments from thousands of unique users, along with millions of filtered OOV word occurrences, a substantial fraction of them unique. Since user-tagged data is needed for FL experiments, Reddit posts are sharded in the FL simulations by user ID to mimic the local client cache scenario.
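In simulation, per-user sharding amounts to grouping comments by user ID, roughly as below (our sketch; the record layout is an assumption):

```python
from collections import defaultdict

def shard_by_user(comments):
    """Group (user_id, text) records into per-user caches, so each
    simulated FL client sees only one user's comments."""
    clients = defaultdict(list)
    for user_id, text in comments:
        clients[user_id].append(text)
    return clients

# Each round then samples a subset of clients and runs local training
# on clients[user_id] only, mimicking an on-device cache.
```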

Figure 2: Precision vs. number of top uniquely sampled words in simulated FL experiments.
Figure 3: Recall vs. number of top unique ground-truth words in simulated FL experiments.

3.3.2 FL on client device data

In the on-device FL setting, the original raw data content is not accessible for human inspection, since it remains stored in local caches on client devices. In order to participate in a round of FL, client devices must: (1) have at least 2 GB of memory, (2) be charging, (3) be connected to an un-metered network, and (4) be idle. In this study we experiment with FL in three languages: (1) American English (en_US), (2) Brazilian Portuguese (pt_BR), and (3) Indonesian (in_ID). A separate FL model is trained on devices located in each region for each language. Although ground-truth OOV words are not accessible in the on-device setting, the model can still be evaluated with character-level metrics such as top-3 prediction accuracy or CE (e.g. given the current context “extra” in “extraordinary”, predict “o” as the next character).
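The character-level metric can be computed as in this sketch (ours), where pred_probs holds the model’s per-step distributions over the character vocabulary:

```python
import numpy as np

def top3_accuracy(pred_probs, target_ids):
    """Fraction of steps where the true next character is among the 3
    most probable characters, e.g. context "extra" -> predict "o"."""
    hits = sum(int(t in np.argsort(-p)[:3])
               for p, t in zip(pred_probs, target_ids))
    return hits / len(target_ids)
```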

4 Results

4.1 FL simulation on Reddit data

During training, S3 converges faster and to better values than S1 and S2 in both CE loss and top-3 prediction accuracy. The larger model does not lead to significant gains. Momentum and adaptive clipping lead to faster convergence and more stable performance.

Table 2 shows the top 10 OOV words with their occurrence probabilities in the Reddit dataset (left) and from the generative model (right). The model generally learns the probability of word occurrences: both the absolute values and the relative ranks of the top words are very close to the ground truth.

ground truth            sampled
yea          0.0050     yea          0.0057
upvote       0.0033     upvote       0.0040
downvoted    0.0030     downvoted    0.0033
alot         0.0026     alot         0.0029
downvote     0.0023     downvote     0.0026
downvotes    0.0018     downvotes    0.0022
upvotes      0.0016     upvotes      0.0021
wp-content   0.0016     op’s         0.0019
op’s         0.0015     wp-content   0.0017
restrict_sr  0.0014     redditors    0.0016
Table 2: Top 10 OOV words and their probabilities: ground truth (left) vs. samples from the simulated federated model trained on Reddit data (right).

In Figures 2 and 3, precision and recall are computed using the model checkpoint of S3 after 3000 rounds. The PR rates are plotted against the number of top unique words (x-axis). The red, green, and blue curves represent increasing numbers of independent samplings. Both precision and recall improve significantly as the amount of sampling is increased. We also observe increased PR over the training rounds.

4.2 FL on local client data

Figure 4: Cross entropy loss on live client evaluation data for three different FL settings for en_US.
Figure 5: Top-3 Character-level prediction accuracy on live client evaluation data for three different FL settings for en_US.

In the on-device setting, all three models converge over the course of about 4 days. Figures 4 and 5 compare the CE loss and top-3 prediction accuracy on evaluation data for the three FL settings. As on the Reddit data, S3 converges faster and better than S1 and S2, achieving 55.8% top-3 prediction accuracy and 2.35 CE loss. Experiments in pt_BR and in_ID show a very similar pattern among the three settings.

Table 3 shows sampled OOV words in the aforementioned three languages. Here “abbr.” is short for abbreviations; “slang/typo” refers to commonly spoken slang words and purposeful misspellings; “repetitive” refers to interjections or words that people intentionally misspell in a repetitive way; “foreign” refers to words typed in a language foreign to the current language/region setting; and “names” refers to trending celebrities’ names. We also observed a lot of profanity learned by the model, which is not shown here. Our future work will focus on better filtering out such OOV words, especially unintended typos. This can be accomplished with manually curated lists of desirable and undesirable OOV words, updated over newer rounds of FL in an iterative process. As we sample words, desirable and undesirable words can be continually incorporated into these lists and made available to all users, allowing the model to save more capacity and focus entirely on new words.

             en_US       in_ID      pt_BR
abbr.        rlly        noh        pqp
             srry        yws        pfv
             lmaoo       gtw        rlx
             adip        tlpn       sdds
slang/typo   yea         gimana     nois
             tommorow    duwe       perai
             gunna       clan       fuder
             sumthin     beb        ein
repetitive   ewwww       siim       tadii
             hahah       rsrs       lahh
             youu        oww        lohh
             yeahh       diaa       kuyy
foreign      muertos     block      buenas
             quiero      contract   fake
             bangaram    cream      the
names        kavanaugh   N.A.       N.A.
             khabib
             cardi
Table 3: Sampled OOV words from FL models in three languages (en_US, in_ID, pt_BR).

5 Conclusion

In this paper, we present a method to discover OOV words through federated learning. The method relies on training a character-based model from which words can be generated via sampling. Compared with traditional server-side methods, our method learns OOV words on each device and transmits the learned knowledge by aggregating gradient updates from local SGD. We demonstrate the feasibility of this approach with simulated FL on a publicly-available corpus, where we achieve good precision and recall for the top unique words. We also perform live experiments with on-device data from 3 populations of Gboard users, and demonstrate that the method can learn OOV words effectively in a real-world setting.

Acknowledgments

The authors would like to thank colleagues in the Google AI team for providing the federated learning framework and for many helpful discussions.
