1.1 Virtual keyboard applications
Virtual keyboards for mobile devices provide a host of functionalities from decoding noisy spatial signals from tap and glide typing inputs to providing auto-corrections, word completions, and next-word predictions. These features must fit within tight RAM and CPU budgets, and operate under strict latency constraints. A key press should result in visible feedback within about 20 milliseconds (Ouyang et al., 2017; Alsharif et al., 2015). Weighted finite-state transducers have been used successfully to decode keyboard spatial signals using a combination of spatial and language models (Ouyang et al., 2017; Hellsten et al., 2017). Figure 1 shows the glide trails of two spatially-similar words. Because of the similarity of the two trails, the decoder must rely on the language model to discriminate between viable candidates.
For memory and latency reasons, especially on low-end devices, the language models are typically based on n-grams and do not exceed ten megabytes. A language model (LM) is a probabilistic model on words. Given previous words $w_1, \ldots, w_{i-1}$, an LM assigns a probability to the next word, i.e. $p(w_i \mid w_1, \ldots, w_{i-1})$. An n-gram LM is a Markovian distribution of order $n-1$, defined by
$$p(w_i \mid w_1, \ldots, w_{i-1}) = p(w_i \mid w_{i-n+1}, \ldots, w_{i-1}),$$
where $n$ is the order of the n-gram. For computation and memory efficiency, keyboard LMs typically have higher-order n-grams over a subset of the vocabulary, e.g. the most frequent K words, and the rest of the vocabulary only has unigrams. We consider n-gram LMs that do not exceed M n-grams and include fewer than K unigrams.
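To make this mixed-order structure concrete, the following sketch shows how such a model answers queries: frequent words get higher-order n-grams, while tail words are covered only by unigrams reached through backoff. All words, probabilities, and the single backoff weight are illustrative, not values from this work, and a production model would use a proper smoothing scheme rather than one constant.

```python
# Hypothetical mixed-order model: higher-order n-grams for frequent
# words, unigram-only coverage for the long tail.
NGRAMS = {
    ("new", "york"): 0.4,   # bigram over frequent words
    ("york",): 0.008,       # unigram for a frequent word
}
UNIGRAMS = {"xylophone": 1e-6}  # tail word: unigram only
BACKOFF = 0.1                   # illustrative backoff weight

def prob(word, context):
    """P(word | context), backing off to shorter contexts."""
    key = tuple(context) + (word,)
    if key in NGRAMS:
        return NGRAMS[key]
    if context:  # drop the oldest context word and back off
        return BACKOFF * prob(word, context[1:])
    return UNIGRAMS.get(word, 1e-9)  # floor for unseen words
```

A query for a frequent bigram is answered directly, while a tail word like "xylophone" is scored by backing off to its unigram.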
N-gram models are traditionally trained by applying a smoothing method to n-gram counts from a training corpus Chen and Goodman (1999). The highest quality n-gram models are trained over data that are well-matched to the desired output Moore and Lewis (2010). For virtual keyboards, training over users’ typed text would lead to the best results. Of course, such data are very personal and need to be handled with care.
1.2 Federated learning
Federated learning (FL) is a technique where machine learning models are trained in a decentralized manner on end-users' devices, so that raw data never leaves these devices. Only targeted and ephemeral parameter updates are aggregated on a centralized server. Figure 2 provides an illustration of the process.
Federated learning for keyboard input was previously explored in Hard et al. (2018), in which a federated recurrent neural network (RNN) was trained for next-word prediction. However, latency constraints prevent the direct use of an RNN for decoding. To overcome this problem, we propose to derive an n-gram LM from a federated RNN LM and use that n-gram LM for decoding. Specifically, the approximation algorithm is based on SampleApprox, which was recently proposed in Suresh et al. (2019a, b). The proposed approach has several advantages:
Improved model quality: Since the RNN LM is trained directly on domain-matched user data, its predictions are more likely to match actual user behavior. In addition, as shown in Suresh et al. (2019a), an n-gram LM approximated from such an RNN LM is of higher quality than an n-gram LM trained on user data directly.
Minimum information transmission: In FL, only the minimal information necessary for model training (the model parameter deltas) is transmitted to centralized servers. The model updates contain much less information than the complete training data.
Additional privacy-preserving techniques: FL can be further combined with privacy-preserving techniques such as secure multi-party computation Bonawitz et al. (2017) and differential privacy McMahan et al. (2018); Agarwal et al. (2018); Abadi et al. (2016). By the post-processing theorem, if we train a single differentially private recurrent model and use it to approximate n-gram models, all the distilled models will also be differentially private with the same parameters Dwork et al. (2014).
For the above reasons, we do not propose to learn n-gram models directly by applying FederatedAveraging to n-gram counts for all orders.
The paper is organized along the lines of challenges associated with converting RNN LMs to n-gram LMs for virtual keyboards: the feasibility of training neural models with a large vocabulary, inconsistent capitalization in the training data, and data sparsity in morphologically rich languages. We elaborate on each of these challenges below.
Large vocabulary: Keyboard n-gram models are typically based on a carefully hand-curated vocabulary to eliminate misspellings, erroneous capitalizations, and other artifacts. The vocabulary size often numbers in the hundreds of thousands. However, training a neural model directly over the vocabulary is memory intensive, as the embedding and softmax layers require $O(V \times d)$ space, where $V$ is the vocabulary size and $d$ is the embedding dimension. We propose a way to handle large vocabularies for federated models in Section 3.
Incorrect capitalization: In virtual keyboards, users often type with incorrect casing (e.g. “She lives in new york” instead of “She lives in New York”). It would be desirable to decode with the correct capitalization even though the user-typed data may be incorrect. Before the discussion of capitalization, the SampleApprox algorithm is reviewed in Section 4. We then modify SampleApprox to infer capitalization in Section 5.
Language morphology: Many words are composed of root words and various morpheme components, e.g. “crazy”, “crazily”, and “craziness”. These linguistic features are prominent in morphologically rich languages such as Russian. The presence of a large number of morphological variants increases the vocabulary size and data sparsity ultimately making it more difficult to train neural models. Algorithms to convert between word and word-piece models are discussed in Section 6.
Finally, we compare the performance of word and word-piece models and present the results of A/B experiments on real users of a virtual keyboard in Section 7.
3 Unigram distributions
Among the K words in the vocabulary, our virtual keyboard models only use the top K words in the higher-order n-grams. We train the neural models only on these most frequent words and train a separate unigram model over the entire vocabulary. We interpolate the two resulting models to obtain the final model for decoding.
Unigrams are collected via a modified version of the FederatedAveraging algorithm. No models are sent to client devices. Instead of returning gradients to the server, counting statistics are compiled on each device and returned. In our experiments, we aggregate over groups of approximately 500 devices per training round. We count a unigram distribution from a whitelist vocabulary by
$$C = \sum_i w_i \, c_i,$$
where $i$ is the index over devices, $c_i$ are the raw unigram counts collected from a single device $i$, and $w_i$ is a weight applied to device $i$.
To prevent users with large amounts of data from dominating the unigram distribution, we apply a form of L1-clipping:
$$w_i = \frac{\min(\gamma, \|c_i\|_1)}{\|c_i\|_1},$$
where $\gamma$ is a threshold that caps each device's contribution. When $\gamma = 1$, L1-clipping is equivalent to equal weighting of all devices. The limit $\gamma \to \infty$ is equivalent to collecting the true counts, since $w_i = 1$ for all $i$.
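Assuming the clipping form $w_i = \min(\gamma, \|c_i\|_1)/\|c_i\|_1$ described above, one aggregation round can be sketched as follows; the function names and toy counts are our own illustration:

```python
def l1_clip_weight(counts, gamma):
    """Device weight w = min(gamma, ||c||_1) / ||c||_1."""
    total = sum(counts.values())  # L1 norm of the count vector
    return min(gamma, total) / total

def aggregate_round(device_counts, gamma):
    """Weighted sum of per-device unigram counts for one round."""
    agg = {}
    for counts in device_counts:
        w = l1_clip_weight(counts, gamma)
        for word, c in counts.items():
            agg[word] = agg.get(word, 0.0) + w * c
    return agg
```

With $\gamma = 1$ each device contributes a normalized distribution regardless of how much it typed; with a very large $\gamma$ the raw counts are recovered.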
Convergence of the unigram distribution is measured using the unbiased chi-squared statistic (for simplicity, referred to as the $\chi^2$-statistic) defined in Bhattacharya and Valiant (2015), the number of unique unigrams seen, and a moving average of the number of rounds needed to observe new unigrams.
Figure 3 shows the overall distributional convergence based on the $\chi^2$-statistic. At round $t$, unigram counts after $t/2$ and $t$ rounds are compared.
Figure 4 plots the number of whitelist vocabulary words seen and a moving average of the number of rounds containing new unigrams. New unigrams are determined by comparing round $t$ with all rounds through $t-1$ and noting whether any new words are seen. The shaded bands range from the LM's unigram capacity to the size of the whitelist vocabulary.
Since the whitelist vocabulary is uncased, capitalization normalization is applied using an approach similar to that of Section 5. We then replace the unigram part of an n-gram model with this distribution to produce the final LM.
In A/B experiments, unigram models with different L1-clipping thresholds are compared against a baseline unigram model gathered from centralized log data. Results are presented in Table 1. Accuracy is unchanged and the OOV rate is improved at the tested clipping thresholds.
Model | acc@1 [%] | OOV rate [%]
Table 1: Relative change with L1-clipped unigrams on live traffic of en_US users of the virtual keyboard. Quoted 95% confidence intervals are derived using the jackknife method with user buckets.
Before we discuss methods to address inconsistent capitalization and data sparsity in morphologically rich languages, we review SampleApprox.
4 Review of SampleApprox
SampleApprox, proposed in Suresh et al. (2019a, b), can be used to approximate an RNN as a weighted finite automaton such as an n-gram model. A weighted finite automaton (WFA) over probabilities is given by a finite alphabet $\Sigma$ (vocabulary words), a finite set of states $Q$ (n-gram contexts), an initial state (sentence start state), a set of final states (sentence end states), and a set of labeled transitions and associated weights that represent the conditional probability of labels (from $\Sigma$) given the state (the list of n-grams and their probabilities). WFA models allow a special backoff label $\varphi$ for succinct representation as follows. Let $L(q)$ be the set of labels on transitions from state $q$. For $x \in L(q)$, let $w(x, q)$ be the weight of the transition of $x$ at state $q$ and $n(x, q)$ be the destination state. For a label $x$ and a state $q$,
$$\bar{w}(x, q) = \begin{cases} w(x, q) & \text{if } x \in L(q), \\ w(\varphi, q)\,\bar{w}(x, n(\varphi, q)) & \text{otherwise.} \end{cases}$$
In other words, the backoff label $\varphi$ is followed if $x \notin L(q)$. The definition above is consistent with that of backoff n-gram models Chen and Goodman (1999). Let $B(q)$ denote the set of states from which $q$ can be reached by a path of backoff labels and let $\pi(q, x)$ be the first state at which label $x$ can be read by following a backoff path from $q$.
Given an unweighted finite automaton (the n-gram topology) and a neural model, SampleApprox finds the probability model on that topology that minimizes the Kullback-Leibler (KL) divergence between the neural model and the WFA. The algorithm has two steps: a counting step and a KL minimization step. For the counting step, let $s(1), \ldots, s(m)$ be independent samples from the neural model. For a sequence $s$, let $s_j$ denote the $j$-th label and $s^j$ denote the first $j$ labels. For every state $q$ and label $x$, the algorithm computes the count $C(x, q)$ given by
$$C(x, q) = \sum_{i=1}^{m} \sum_{j} \mathbb{1}\big[\pi\big(q(s^j(i)), x\big) = q\big]\, p\big(x \mid s^j(i)\big),$$
where $q(s^j(i))$ denotes the WFA state reached after reading $s^j(i)$ and $p(\cdot \mid \cdot)$ is the conditional probability under the neural model.
We illustrate this counting with an example. Suppose we are interested in the count of the bi-gram New York. Given a bi-gram LM, SampleApprox generates sentences $s(1), \ldots, s(m)$ and computes
$$C(\text{York}, \text{New}) = \sum_{i=1}^{m} \sum_{j} \mathbb{1}\big[s_j(i) = \text{New}\big]\; p\big(\text{York} \mid s^j(i)\big).$$
In other words, it finds all sentences that contain the word New, observes how frequently York appears subsequently, and computes the conditional probability. After counting, it uses a difference of convex (DC) programming based algorithm to find the KL minimum solution. If $\bar{k}$ is the average number of words per sentence, the computational complexity of counting is $O(m \bar{k} |\Sigma|)$, and the computational complexity of the KL minimization is $O(|E|)$ per iteration of DC programming, where $E$ is the set of transitions of the WFA.
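The New York example can be sketched as follows, with a hard-coded stand-in for the neural model's conditional probability; the sample sentences and probability values are illustrative only:

```python
samples = [
    ["i", "love", "new", "york"],
    ["new", "things"],
    ["old", "york"],
]

def p_model(word, prefix):
    """Stand-in for the neural LM's conditional probability.

    A real model would condition on `prefix`; here the value is fixed
    purely for illustration.
    """
    return {"york": 0.6, "things": 0.3}.get(word, 0.01)

def bigram_count(prev, word):
    """Accumulate model probability of `word` wherever `prev` occurs."""
    total = 0.0
    for s in samples:
        for j, label in enumerate(s):
            if label == prev:
                total += p_model(word, s[: j + 1])
    return total
```

Here "new" occurs in two sampled sentences, so the count of the bigram is the sum of the model's continuation probabilities at those two positions.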
5 Capitalization
As mentioned in Section 2, users often type with incorrect capitalization. One way of handling incorrect capitalization is to store an on-device capitalization normalizer Beaufays and Strope (2013) to correctly capitalize sentences before using them to train the neural model. However, capitalization normalizers have large memory footprints and are not suitable for on-device applications. To overcome this, the neural model is first trained on uncased user data. SampleApprox is then modified to approximate cased n-gram models from uncased neural models.
As before, let $s(1), \ldots, s(m)$ be independent (uncased) samples from the neural model. We capitalize them correctly at the server using Beaufays and Strope (2013). Let $t(1), \ldots, t(m)$ represent the corresponding correctly capitalized samples. Let $\tilde{p}$ be another probability model on non-user data that approximates the ratio of uncased to cased probabilities given a context. Given a label $x$, let $u(x)$ be the uncased symbol. For example, if $x$ is York, then $u(x)$ is york. With the above definitions, we modify the counting step of SampleApprox to accumulate, for each capitalized sample prefix $t^j(i)$, the uncased neural probability $p\big(u(x) \mid u(t^j(i))\big)$ reweighted by the factor
$$r\big(x \mid t^j(i)\big) = \frac{\tilde{p}\big(x \mid t^j(i)\big)}{\tilde{p}\big(u(x) \mid u(t^j(i))\big)},$$
which redistributes probability mass from uncased symbols to their cased variants.
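A simplified sketch of the effect: counts gathered for an uncased symbol are redistributed over its cased variants according to a server-side ratio model. The names and ratios below are hypothetical, and the actual algorithm applies the reweighting inside the sampling-based counting step rather than as a post-hoc pass over finished counts.

```python
# Hypothetical cased/uncased ratio model: p(cased form | uncased form).
RATIO = {"york": {"York": 0.95, "york": 0.05}}

def redistribute(uncased_counts):
    """Split each uncased count across its cased variants."""
    cased = {}
    for word, c in uncased_counts.items():
        # Words without an entry keep their form with full weight.
        for variant, r in RATIO.get(word, {word: 1.0}).items():
            cased[variant] = cased.get(variant, 0.0) + c * r
    return cased
```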
We refer to this modified algorithm as CapSampleApprox. We note that word-piece to word approximation incurs an additional computation cost that grows with the number of words, the sizes of the arc and state sets of the word n-gram model, and the maximum number of word-pieces per word.
6 Morphologically rich languages
To train neural models on morphologically rich languages, subword segments such as byte-pair encodings or word-pieces Shibata et al. (1999); Schuster and Nakajima (2012); Kudo (2018) are typically used. This approach assigns conditional probabilities to subword segments, conditioned on prior subword segments. It has proved successful in the context of speech recognition Chiu et al. (2018) and machine translation Wu et al. (2016). Following these successes, we propose to train RNN LMs with word-pieces for morphologically rich languages.
We apply the word-piece approach of Kudo (2018), which computes a word-piece unigram LM over a word-piece inventory $\mathcal{P}$. Each word-piece $x \in \mathcal{P}$ is associated with a unigram probability $p(x)$. For a given word and its possible segmentation candidates, the word is encoded with the segmentation that assigns the highest probability.
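The max-probability segmentation can be found with a simple Viterbi search over a word's substrings. The inventory and probabilities below are toy values of our own choosing, not the trained inventory described in this section:

```python
import math

# Toy word-piece unigram probabilities (hypothetical values).
PIECES = {"craz": 0.05, "y": 0.2, "ily": 0.02, "iness": 0.01, "crazy": 0.04}

def best_segmentation(word):
    """Viterbi search for the highest-probability segmentation."""
    n = len(word)
    # best[i] = (best log-prob of word[:i], split point that achieves it)
    best = [(-math.inf, None)] * (n + 1)
    best[0] = (0.0, None)
    for i in range(1, n + 1):
        for j in range(i):
            piece = word[j:i]
            if piece in PIECES and best[j][0] > -math.inf:
                score = best[j][0] + math.log(PIECES[piece])
                if score > best[i][0]:
                    best[i] = (score, j)
    # Backtrack through the recorded split points.
    pieces, i = [], n
    while i > 0:
        j = best[i][1]
        pieces.append(word[j:i])
        i = j
    return list(reversed(pieces))
```

Note how "crazy" is kept whole because the single piece outscores "craz" + "y", while "crazily" must be split.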
Throughout this paper we use 4K, 16K, and 30K as the word-piece inventory sizes. These values lie within a range that provides a good trade-off between the LSTM embedding size and the richness of the language morphology. We set the character coverage to include all the symbols that appeared in the unigram distribution (Section 3), including the common English letters, accented letters such as é and ô, and digits. Accented letters are important for languages like Portuguese. For fast decoding, the n-gram models still need to be at the word level, since word-piece n-gram models increase the depth of the beam search during decoding. We convert the word n-gram topology to an equivalent word-piece WFA topology, use SampleApprox to approximate the neural word-piece model on the word-piece WFA topology, and then convert the resulting word-piece WFA LM to the equivalent word n-gram LM. The remainder of this section outlines efficient algorithms for converting between word and word-piece WFA models.
A natural way to represent the transduction from word-piece sequences to word sequences is with a finite-state transducer. Given the properties of our word-piece representation, that transducer can be made sequential (i.e., input deterministic).
A sequential weighted finite-state transducer (WFST) is a deterministic WFA where each transition has an output label in addition to its (input) label and weight. We will denote by $o(x, q) \in \Delta \cup \{\epsilon\}$ the output label of the transition at state $q$ with input label $x$, where $\Delta$ denotes the output alphabet of the transducer and $\epsilon$ the empty string/sequence.
Let $T$ be the minimal sequential (unweighted) finite-state transducer (FST) lexicon from word-piece sequences in $\mathcal{P}^*$ into word sequences in $\Sigma^*$, where $\mathcal{P}$ denotes our word-piece inventory, $\Sigma$ denotes our vocabulary, and $*$ is Kleene closure.
A word-piece topology $A'$ equivalent to the word topology $A$ can be obtained by composing the word-piece-to-word transducer $T$ with $A$:
$$A' = T \circ A.$$
Since $A$ has backoff transitions, the generic composition algorithm of Allauzen et al. (2011) is used with a custom composition filter that ensures the result, $A'$, is deterministic with a well-formed backoff structure, and hence is suitable for the counting step of SampleApprox. We give an explicit description of the construction of $A'$, from which readers familiar with Allauzen et al. (2011) can infer the form of the custom composition filter.
The states in $A'$ are pairs $(q_t, q_a)$, with $q_t$ a state of $T$ and $q_a$ a state of $A$, initial state $(i_t, i_a)$, and final states $(i_t, f_a)$ for $f_a$ final in $A$. Given a state $(q_t, q_a)$, the outgoing transitions and their destination states are defined as follows. If $x \neq \varphi$, then an $x$-labeled transition is created if one of two conditions holds:
if $o(x, q_t) = \epsilon$, then the destination state is $(n_T(x, q_t), q_a)$;
if $o(x, q_t) = y \neq \epsilon$ and $y$ can be read from $q_a$ along a backoff path (i.e., $\pi(q_a, y)$ is defined), then the destination state is $(n_T(x, q_t), n_A(y, \pi(q_a, y)))$,
where $\Lambda(q_t)$ denotes the set of output non-$\epsilon$ labels that can be emitted after following an output-$\epsilon$ path from $q_t$. Finally, if $\varphi \in L(q_a)$ and $\Lambda(q_t) \not\subseteq L(q_a)$, a backoff transition is created:
$$(q_t, q_a) \xrightarrow{\;\varphi\;} (q_t, n_A(\varphi, q_a)).$$
The counting step of SampleApprox is applied to $A'$, and the computed counts are transferred from $A'$ to $A$ by relying on the following key property of $T$. For every word $y$ in $\Sigma$, there exists a unique state $q_t$ in $T$ and a unique word-piece $x$ in $\mathcal{P}$ such that $o(x, q_t) = y$. This allows us to transfer the counts from $A'$ to $A$ as follows:
$$C_A(y, q_a) = C_{A'}\big(x, (q_t, q_a)\big).$$
The KL minimization step of SampleApprox is subsequently applied to $A$.
As an alternative, the unweighted word automaton $A$ could be used to perform the counting step directly. Each sample $s(i)$ could be mapped to a corresponding word sequence, with out-of-vocabulary word-piece sequences mapped to an unknown token. However, the counting step would become much more computationally expensive, since the conditional probability of every word in the vocabulary would have to be evaluated at each sampled position, and each such word probability now requires chaining several word-piece predictions from the word-piece RNN.
7.1 Neural language model
LSTM models Hochreiter and Schmidhuber (1997) have been successfully used in a variety of sequence processing tasks. LSTM models usually have a large number of parameters and are not suitable for on-device learning. In this work, we use various techniques to reduce the memory footprint and to improve model performance.
We use a variant of LSTM with a Coupled Input and Forget Gate (CIFG) Greff et al. (2017) for the federated neural language model. CIFG couples the forget and input decisions together, which reduces the number of LSTM parameters by 25%. We also use the group LSTM (GLSTM) of Kuchaiev and Ginsburg (2017) to reduce the number of trainable variables of an LSTM matrix by the number of feature groups $k$; a fixed value of $k$ is used in the experiments.
Table 2 lists the parameter settings of the word (W) and word-piece (P) models used in this study. Due to the memory limitations of on-device training, all models are kept under a fixed parameter budget. For each vocabulary size, we first start with a base architecture consisting of one LSTM layer with a small embedding dimension and hidden state size. We then attempt to increase the representational power of the LSTM cell by increasing the number of hidden units and using multi-layer LSTM cells Sutskever et al. (2014). Residual LSTM Kim et al. (2017) and layer normalization Lei Ba et al. (2016) are used throughout the experiments, as these techniques were observed to improve convergence. To avoid the restriction that the embedding and hidden dimensions match in the output, we apply a projection step at the output gate of the LSTM Sak et al. (2014). This step reduces the dimension of the LSTM hidden state to the embedding dimension. We also share the embedding matrix between the input embedding and the output softmax layer, which reduces the memory requirement by the size of one embedding matrix. We note that other recurrent neural models such as gated recurrent units Chung et al. (2014) can also be used instead of CIFG LSTMs.
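As a back-of-envelope illustration of where the memory goes under these choices (the function and its arguments are our own sketch, ignoring the projection step, layer normalization, and exact bias layout):

```python
def approx_params(vocab, embed, hidden, cifg=True, tied=True):
    """Rough parameter count for a one-layer (CIFG-)LSTM LM."""
    gates = 3 if cifg else 4  # CIFG merges the input and forget gates
    lstm = gates * (hidden * (embed + hidden) + hidden)  # weights + biases
    # Tying shares one vocab x embed matrix between input and softmax.
    embeddings = vocab * embed if tied else 2 * vocab * embed
    return embeddings + lstm
```

For realistic vocabulary sizes the `vocab * embed` term dominates, which is why embedding tying and vocabulary truncation matter more than trimming the recurrent cell.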
The federated RNN LMs are trained on two language settings of the virtual keyboard: American English (en_US) and Brazilian Portuguese (pt_BR). Following McMahan et al. (2017), 500 reporting clients are used to compute the gradient updates for each round. A server-side learning rate of 1.0, a client-side learning rate of 0.5, and Nesterov momentum of 0.9 are used. Both the word and word-piece models are trained over the same time range and with the same hyperparameters. Prior to federated training of the RNN LM, the word-piece inventory is constructed from the unigram distribution collected via the federated approach introduced in Section 3.
A common evaluation metric for both word and word-piece models is desirable during federated training. Such a metric can be used to monitor training status and to select models for the CapSampleApprox algorithm. Neither cross-entropy nor accuracy serves this need due to the mismatch in the vocabularies used. Word-level accuracy is hard to compute for the word-piece model, since it requires hundreds of inference calls to traverse all the ways a word can be composed from the word-piece vocabulary. In this study, we use sentence log likelihood (SLL) in the evaluation. Given a sentence $s$ composed of units $u_1, \ldots, u_n$ (either words or word-pieces), SLL is evaluated as
$$\mathrm{SLL}(s) = \sum_{j=1}^{n} \log p(u_j \mid u_1, \ldots, u_{j-1}).$$
One issue that arises is the handling of out-of-vocabulary (OOV) words. The word model assigns a single fixed probability to any OOV word. The comparable probability of an OOV word for word-piece models is the product of the corresponding word-piece conditional probabilities, which is much smaller. To mitigate this issue, we define SLL excluding OOV as
$$\mathrm{SLL}_{\setminus \mathrm{OOV}}(s) = \sum_{j:\, u_j \notin \mathrm{OOV}} \log p(u_j \mid u_1, \ldots, u_{j-1}),$$
where OOV includes word-pieces that are components of OOV words. In the following, $\mathrm{SLL}_{\setminus \mathrm{OOV}}$ is used as the model selection metric.
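SLL excluding OOV reduces to a masked sum of log-probabilities; a minimal sketch, where the unit probabilities stand in for real model outputs:

```python
import math

def sll_excluding_oov(units, probs, oov):
    """Sum log p(unit) over units that are not (parts of) OOV words."""
    return sum(math.log(p) for u, p in zip(units, probs) if u not in oov)
```

Masking removes the large negative contribution an OOV word would otherwise add, making word and word-piece scores comparable.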
7.2 Approximated n-gram model
Algorithm 1 illustrates the workflow we use to generate different n-gram models for evaluation. Recall that CapSampleApprox takes an RNN LM, an n-gram topology, and a reweighting FST for capitalization normalization. The n-gram topology is empty under self-inference mode. Suresh et al. (2019a) showed that inferring the topology from the RNN LM does not perform as well as using the true n-gram topology obtained from the training corpus. Hence, we supplement the neural-inferred topology with the topology obtained from a large external corpus. We use CapSampleApprox on four topologies and compare the resulting models: an n-gram model obtained from the external corpus's topology, an n-gram model obtained from a neural-inferred topology, an n-gram model obtained by interpolating (merging) the two models above, and an n-gram model obtained by approximating on the interpolated topology. We repeat this experiment for both word and word-piece RNN LMs, denoted by subscripts w and p, respectively. We evaluate all eight resulting n-gram models directly on the traffic of a production virtual keyboard, where prediction accuracy is evaluated over user-typed words.
Figure 5 shows the SLL-excluding-OOV metric for all the experiments listed in Table 2. In general, larger models generate better results than smaller baseline models. For the baseline architectures with the same RNN size, having a larger vocabulary leads to some gains. For the larger architectures that have similar total numbers of parameters, 4K word-piece models are shown to be superior to 16K and 30K ones. For 4K word-piece models, GLSTM is in general on par with its counterpart. The word model is better than all the word-piece models in both languages on this metric. We were surprised by this result, and hypothesize that it is due to the metric discounting word-piece models' ability to model the semantics of OOV words. The solid lines indicate the best word and word-piece models, which we pick for A/B experiment evaluation on the virtual keyboard.
Table 3 shows the A/B evaluation results on both the en_US and pt_BR populations. The baseline model is an n-gram model trained directly from centralized logs. All of the federated trained models perform better than the baseline model. We repeated the A/B evaluation with word-piece models on en_US; the results are in Table 4. The performance of word-piece models is similar to that of word models. Among the federated models for en_US, the model built on the interpolated topology has the best result. This meets our expectation that the supplemental corpus helps improve the performance of the topology inferred from the RNN LM.
We have proposed methods to train production-quality n-gram language models using federated learning, which allows training models without user-typed text ever leaving devices. The proposed methods are shown to perform better than traditional server-based algorithms in A/B experiments on real users of a virtual keyboard.
The authors would like to thank colleagues in Google Research for providing the federated learning framework and for many helpful discussions.
- Deep learning with differential privacy. In Proceedings of the 2016 ACM SIGSAC Conference on Computer and Communications Security, pp. 308–318.
- cpSGD: communication-efficient and differentially-private distributed SGD. In Neural Information Processing Systems.
- A filter-based algorithm for efficient composition of finite-state transducers. International Journal of Foundations of Computer Science 22 (8), pp. 1781–1795.
- Long short term memory neural network for keyboard gesture decoding. In 2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 2076–2080.
- Language model capitalization. In 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 6749–6752.
- Testing closeness with unequal sized samples. In Advances in Neural Information Processing Systems 28.
- Practical secure aggregation for privacy-preserving machine learning. In Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, CCS '17, New York, NY, USA, pp. 1175–1191.
- An empirical study of smoothing techniques for language modeling. Computer Speech & Language 13 (4), pp. 359–394.
- State-of-the-art speech recognition with sequence-to-sequence models. In 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4774–4778.
- Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
- The algorithmic foundations of differential privacy. Foundations and Trends in Theoretical Computer Science 9 (3–4), pp. 211–407.
- LSTM: a search space odyssey. IEEE Transactions on Neural Networks and Learning Systems 28 (10), pp. 2222–2232.
- Federated learning for mobile keyboard prediction. CoRR abs/1811.03604.
- Transliterated mobile keyboard input via weighted finite-state transducers. In Proceedings of the 13th International Conference on Finite State Methods and Natural Language Processing (FSMNLP).
- Long short-term memory. Neural Computation 9 (8), pp. 1735–1780.
- Residual LSTM: design of a deep recurrent architecture for distant speech recognition. In Interspeech 2017, pp. 1591–1595.
- Federated optimization: distributed machine learning for on-device intelligence. arXiv preprint arXiv:1610.02527.
- Federated learning: strategies for improving communication efficiency. In NIPS Workshop on Private Multi-Party Machine Learning.
- Factorization tricks for LSTM networks. In ICLR 2017 Workshop.
- Subword regularization: improving neural network translation models with multiple subword candidates. In Proceedings of ACL 2018, pp. 66–75.
- Layer normalization. arXiv preprint arXiv:1607.06450.
- Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics (AISTATS 2017), Fort Lauderdale, FL, USA, pp. 1273–1282.
- Learning differentially private recurrent language models. In International Conference on Learning Representations (ICLR).
- Federated learning: collaborative machine learning without centralized training data. https://ai.googleblog.com/2017/04/federated-learning-collaborative.html.
- Intelligent selection of language model training data. In Proceedings of the ACL 2010 Conference Short Papers, pp. 220–224.
- Mobile keyboard input decoding with finite-state transducers. CoRR abs/1704.03987.
- Long short-term memory recurrent neural network architectures for large scale acoustic modeling. In Fifteenth Annual Conference of the International Speech Communication Association.
- Japanese and Korean voice search. In 2012 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 5149–5152.
- Byte pair encoding: a text compression scheme that accelerates pattern matching. Technical Report DOI-TR-161, Department of Informatics, Kyushu University.
- Approximating probabilistic models as weighted finite automata. CoRR abs/1905.08701.
- Distilling weighted finite automata from arbitrary probabilistic models. In Proceedings of the 14th International Conference on Finite State Methods and Natural Language Processing (FSMNLP 2019).
- Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pp. 3104–3112.
- Google's neural machine translation system: bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144.