Federated Learning for Keyword Spotting

10/09/2018 ∙ by David Leroy, et al.

We propose a practical approach based on federated learning to solve out-of-domain issues with continuously running embedded speech-based models such as wake word detectors. We conduct an extensive empirical study of the federated averaging algorithm for the "Hey Snips" wake word based on crowdsourced data on two distinct tasks: learning from scratch and language adaptation. We also reformulate the global averaging step of the federated averaging algorithm as a gradient update step, applying per-coordinate adaptive learning rate strategies such as Adam in place of standard weighted model averaging. We then empirically demonstrate that using adaptive averaging strategies greatly reduces the number of communication rounds required to reach a target performance.




1 Introduction

Wake word detection is used to start an interaction with a voice assistant. A specific case of keyword spotting (KWS), it continuously listens to an audio stream to detect a predefined keyword or set of keywords. Well-known examples of wake words include Apple’s “Hey Siri” or Google’s “OK Google”. Once the wake word is detected, voice input is activated and processed by a spoken language understanding engine, powering the perception abilities of the voice assistant [1].

Wake word detectors usually run on device in an always-on fashion, which brings two major difficulties. First, the detector should run with a minimal memory footprint and computational cost. The resource constraints for our wake word detector are 200k parameters (based on the medium-sized model proposed in [2]) and 20 MFLOPS.

Secondly, the wake word detector should behave consistently in any usage setting, and show robustness to background noise. The audio signal is highly sensitive to recording proximity (close or far field), recording hardware, but also to the room configuration. Robustness also implies a strong speaker variability coverage (genders, accents, etc.). While the use of digital signal processing front-ends can help mitigate issues related to bad recording conditions, speaker variability remains a major challenge. High accuracy is all the more important since the model can be triggered at any time: it is therefore expected to capture most of the commands (high recall, or low false rejection rate) while not triggering unintentionally (low false alarm rate).

Today, wake word detectors are typically trained on datasets collected in real usage settings, e.g. users’ homes in the case of voice assistants. Since speech data is by nature very sensitive, centralized collection raises major privacy concerns. In this work, we investigate the use of federated learning (FL) [3] in the context of an embedded wake word detector. FL is a decentralized optimization procedure that makes it possible to train a central model on the local data of many users without ever uploading this data to a central server. The training workload is moved to the users’ devices, which perform training steps on the local data. Local updates from users are then averaged by a parameter server in order to create a global model.

2 Related work

Most research around decentralized learning has historically been done in the context of a highly controlled cluster/data center setting, e.g. with a dataset evenly partitioned in an i.i.d. fashion. The multi-core and multi-GPU distributed training setting has been specifically studied in the context of speech recognition in [4]. Efforts on decentralized training with highly distributed, unbalanced and non-i.i.d. data are relatively recent, as the foundations were laid down in [3] with the introduction of the federated averaging (FedAvg) algorithm and its application to a set of computer vision (MNIST, CIFAR-10) and language modeling tasks (applied to the Shakespeare and Google Plus posts datasets). There are for now very few real-life experiments that we know of, except for Google’s keyboard Gboard for Android [5] and more recently Mozilla’s URL suggestion bar [6]. To our knowledge, the present work is the first experiment of its kind on user-specific speech data.

The federated optimization problem in the context of convex objective functions has been studied in [7]. The authors proposed a stochastic variance-reduced gradient descent optimization procedure (SVRG) with both local and global per-coordinate gradient scaling to improve convergence. Their global per-coordinate gradient averaging strategy relies on a sparsity measure of the given coordinate in users’ local datasets and is only applicable in the context of sparse linear-in-the-features models. The latter assumption does not hold in the context of neural networks for speech-based applications.

Several improvements to the initial FedAvg algorithm have been suggested with a focus on client selection [8], budget-constrained optimization [9] and upload cost reduction for clients [10]. A dynamic model averaging strategy robust to concept drift based on a local model divergence criterion was recently introduced in [11]. While these contributions present efficient strategies to reduce the communication costs inherent to federated optimization, the present work is as far as we know the first one introducing a dynamic per-coordinate gradient update in place of the global averaging step.

The next section describes the federated optimization procedure, and how its global averaging can be substituted by an adaptive averaging rule inspired by Adam. It is followed by the experiments section, where both the open-sourced crowdsourced data and the model used to train our wake word detector are introduced. Results come next, and a communication cost analysis is provided. Finally, the next steps towards training a wake word detector on truly decentralized user data are described.

3 Federated optimization

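We consider the standard supervised learning objective function $f_i(w) = l(x_i, y_i, w)$, that is, the loss function for the prediction on example $(x_i, y_i)$ when using a model described by a real-valued parameter vector $w$ of dimension $d$. In a federated setting, we assume that the datapoints $i$ are partitioned across $K$ users, each user $k$ being assigned their own partition $P_k$, $|P_k| = n_k$, with $n = \sum_{k=1}^{K} n_k$ the total number of datapoints. The optimization objective is therefore the following:

$$\min_{w \in \mathbb{R}^d} f(w) \quad \text{where} \quad f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w) \quad \text{with} \quad F_k(w) = \frac{1}{n_k} \sum_{i \in P_k} f_i(w) \qquad (1)$$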

The FedAvg algorithm introduced in [3] aims at minimizing the objective function (1) assuming a synchronous update scheme and a generic non-convex neural network loss function. The model is initialized with a given architecture on a central parameter server with weights $w_0$. Once initialized, the parameter server and the users’ devices interact synchronously with each other during communication rounds. A communication round at time $t$ is described below:

  1. The central model $w_{t-1}$ is shared with a subset of users $S_t$ that are randomly selected from the pool of $K$ users given a participation ratio $C$.

  2. Each user $k \in S_t$ performs one or several training steps on their local data based on the minimization of their local objective $F_k$ using mini-batch stochastic gradient descent (SGD) with a local learning rate $\eta_{local}$. The number of steps performed locally is $E \times \max(\lceil n_k / B \rceil, 1)$, $n_k$ being the number of datapoints available locally, $E$ the number of local epochs and $B$ the local batch size.

  3. Users from $S_t$ send back their model updates $w_{t,k}$, $k \in S_t$, to the parameter server once local training is finished.

  4. The server computes an average model $w_t$ based on the users’ individual updates $w_{t,k}$, $k \in S_t$, each user’s update being weighted by $n_k / n_r$, where $n_r = \sum_{k \in S_t} n_k$.
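The four steps of a communication round can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the training framework used for the experiments: the toy linear model, the synthetic data and the function names `local_sgd` and `fedavg_round` are all hypothetical, and local data is not reshuffled between epochs.

```python
import random

def local_sgd(weights, data, lr, epochs, batch_size):
    """Mini-batch SGD on one user's data for a toy linear model y = w . x."""
    w = list(weights)
    for _ in range(epochs):
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad = [0.0] * len(w)
            for x, y in batch:
                err = sum(wi * xi for wi, xi in zip(w, x)) - y
                for j, xj in enumerate(x):
                    grad[j] += 2.0 * err * xj / len(batch)
            w = [wi - lr * gj for wi, gj in zip(w, grad)]
    return w

def fedavg_round(global_w, user_data, C, lr, epochs, batch_size, rng):
    """Steps 1-4: select users, train locally, average weighted by n_k / n_r."""
    users = sorted(user_data)
    selected = rng.sample(users, max(1, round(C * len(users))))
    n_r = sum(len(user_data[k]) for k in selected)
    new_w = [0.0] * len(global_w)
    for k in selected:
        w_k = local_sgd(global_w, user_data[k], lr, epochs, batch_size)
        share = len(user_data[k]) / n_r
        new_w = [acc + share * wk for acc, wk in zip(new_w, w_k)]
    return new_w

rng = random.Random(0)
# Ten hypothetical users with unbalanced amounts of data drawn from y = 2 * x.
user_data = {}
for k in range(10):
    n_k = rng.randint(5, 50)
    user_data[k] = [([x], 2.0 * x) for x in (rng.uniform(-1, 1) for _ in range(n_k))]

w = [0.0]
for _ in range(100):
    w = fedavg_round(w, user_data, C=0.3, lr=0.5, epochs=1, batch_size=10, rng=rng)
print(w)  # should approach [2.0]
```

Only the locally trained weights leave each simulated user; the raw `(x, y)` pairs are never gathered in one place, which is the property that motivates the federated setting.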

When $B = \infty$ (i.e. the batch size is equal to the local dataset size) and $E = 1$, a single gradient update is performed on each user’s data. This is strictly equivalent to a single gradient computation on a batch including all of the selected users’ data points. This specific case is called FedSGD, i.e. stochastic gradient descent with each batch being the data of the federation of selected users at a given round. FedAvg (federated averaging) is the generic case where more than one update is performed locally for each user.
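As a small worked example, the number of local updates per user can be computed directly, assuming it equals the number of local epochs times the number of mini-batches per epoch (at least one). The helper name `local_steps` is hypothetical; the value 39 is the mean number of utterances per user reported in Table 1, and the batch size of 20 is an arbitrary illustration.

```python
import math

def local_steps(n_k, epochs, batch_size):
    """Local SGD steps for one user: E * max(ceil(n_k / B), 1)."""
    return epochs * max(math.ceil(n_k / batch_size), 1)

print(local_steps(39, 1, 20))        # 2 steps with a batch size of 20
print(local_steps(39, 3, 20))        # 6 steps with three local epochs
print(local_steps(39, 1, math.inf))  # 1 step: the FedSGD case
```

With an infinite batch size the quotient is zero, so the `max(..., 1)` term is what guarantees the single full-batch update of FedSGD.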

The global averaging step can be written as follows, using a global update rate $\eta_{global}$:

$$w_t \leftarrow w_{t-1} - \eta_{global} \sum_{k \in S_t} \frac{n_k}{n_r} (w_{t-1} - w_{t,k}) \qquad (2)$$

Setting the global update rate to 1 is equivalent to a weighted averaging case without moving average. Equation 2 highlights the parallel between global averaging and a gradient update $G_t = \sum_{k \in S_t} \frac{n_k}{n_r} (w_{t-1} - w_{t,k})$. This parallel motivates the use of adaptive per-coordinate updates for $w_t$, such as Adam [12], that have proven successful for centralized deep neural network optimization. Moment-based averaging smooths the averaged model by taking into account the previous rounds’ updates, which were computed on different user subsets. We conjecture that the exponentially decayed first and second order moments perform the same kind of regularization that occurs in the mini-batch gradient descent setting, where Adam has proven to be successful on a wide range of tasks with various neural-network-based architectures. In this work, we set the exponential decay rates for the moment estimates to $\beta_1 = 0.9$ and $\beta_2 = 0.999$, and $\epsilon = 10^{-8}$, as initially suggested by the authors of [12].
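The parallel between averaging and a gradient update can be made concrete with a short sketch. It assumes the pseudo-gradient is the weighted mean of the differences between the previous global weights and each client's weights, and it applies the textbook Adam rule with bias correction to that quantity; whether the experiments used exactly this variant is not stated here. The names `pseudo_gradient`, `standard_step` and `AdamAverager` are hypothetical, as is the global rate of 0.001 used in the example.

```python
import math

def pseudo_gradient(global_w, client_ws, client_sizes):
    """Per-coordinate weighted mean of (previous global weight - client weight)."""
    n_r = sum(client_sizes)
    return [
        sum(n_k / n_r * (w_prev - w_k[j]) for w_k, n_k in zip(client_ws, client_sizes))
        for j, w_prev in enumerate(global_w)
    ]

def standard_step(global_w, client_ws, client_sizes, eta_global=1.0):
    """Plain global step; eta_global = 1 reduces to weighted model averaging."""
    g = pseudo_gradient(global_w, client_ws, client_sizes)
    return [w - eta_global * gj for w, gj in zip(global_w, g)]

class AdamAverager:
    """Adam applied to the pseudo-gradient, with state kept across rounds."""

    def __init__(self, dim, eta_global, beta1=0.9, beta2=0.999, eps=1e-8):
        self.eta, self.b1, self.b2, self.eps = eta_global, beta1, beta2, eps
        self.m = [0.0] * dim
        self.v = [0.0] * dim
        self.t = 0

    def step(self, global_w, client_ws, client_sizes):
        g = pseudo_gradient(global_w, client_ws, client_sizes)
        self.t += 1
        self.m = [self.b1 * m + (1 - self.b1) * gj for m, gj in zip(self.m, g)]
        self.v = [self.b2 * v + (1 - self.b2) * gj * gj for v, gj in zip(self.v, g)]
        m_hat = [m / (1 - self.b1 ** self.t) for m in self.m]
        v_hat = [v / (1 - self.b2 ** self.t) for v in self.v]
        return [w - self.eta * mh / (math.sqrt(vh) + self.eps)
                for w, mh, vh in zip(global_w, m_hat, v_hat)]

# Two clients holding 1 and 3 datapoints: the plain step returns their weighted mean.
avg = standard_step([1.0], [[0.0], [3.0]], [1, 3])
print(avg)  # [2.25]

adam = AdamAverager(dim=2, eta_global=0.001)  # hypothetical global rate
stepped = adam.step([1.0, -1.0], [[0.0, 0.0]], [1])
print(stepped)  # each coordinate moves by about 0.001 towards the client weights
```

On the first round the bias-corrected moments make the Adam step roughly `eta_global` per coordinate regardless of the magnitude of the pseudo-gradient, which is the per-coordinate scaling that plain averaging lacks.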

4 Experiments

4.1 Dataset

Unlike generic speech recognition tasks, there is no reference dataset for wake word detection. The reference dataset for multi-class keyword spotting is the speech command dataset [13], but the speech command task is generally preceded by a wake word detector and is focused on minimizing the confusion across classes, not robustness to false alarms. We therefore built a crowdsourced dataset for the Hey Snips wake word. We are publicly releasing this dataset [14] (http://research.snips.ai/datasets/keyword-spotting) in the hope it can be useful to the keyword spotting community.

| Train set | Dev set | Test set | Total |
|---|---|---|---|
| 1,374 users | 200 users | 200 users | 1,774 users |
| 53,991 utt. | 8,337 utt. | 7,854 utt. | 69,582 utt. |

Table 1: Dataset statistics for the Hey Snips wake word. 18% of utterances are positive, with a strong per-user imbalance in the number of utterances (mean: 39, standard deviation: 32).

The data used here was collected from 1.8k contributors who recorded themselves on their device with their own microphone while saying several occurrences of the Hey Snips wake word along with randomly chosen negative sentences. Each recorded audio sample went through a validation process in which at least two of three distinct contributors confirmed that the pronounced utterance matches the transcript.

This crowdsourcing-induced data distribution mimics a real-world non-i.i.d., unbalanced and highly distributed setting, and a parallel is therefore drawn in the following work between a crowdsourcing contributor and a voice assistant user. The dataset statistics supporting this analogy are summarized in Table 1. The train, dev and test splits are purposely built using distinct users, 77% of users being used solely for training while the remaining 23% are used for parameter tuning and final evaluation, measuring the generalization power of the model to new users.

4.2 Model

Acoustic features are generated based on 40-dimensional mel-frequency cepstrum coefficients (MFCC) computed every 10 ms over a window of 25 ms. The input window consists of 32 stacked frames, symmetrically distributed in left and right contexts. The architecture is a CNN with 5 stacked dilated convolutional layers of increasing dilation rate, followed by two fully-connected layers and a softmax, inspired by [15]. The total number of parameters is 190,852. The model is trained using a cross-entropy loss on frame predictions. The neural network has 4 output labels, assigned via a custom aligner specialized on the target utterance “Hey Snips”: “Hey”, “sni”, “ps”, and “filler”, which accounts for all other cases (silence, noise and other words). A posterior handling step [16] generates a confidence score for every frame by combining the smoothed label posteriors. The model triggers if the confidence score reaches a certain threshold, which defines the operating point that maximizes recall for a given number of False Alarms per Hour (FAH). We set the number of false alarms per hour to 5 as a stopping criterion on the dev set. The dev set is a “hard” dataset when it comes to false alarms since it belongs to the same domain as the data used for training. The model recall is finally evaluated on the test set positive data, while false alarms are computed on both the test set negative data and various background negative audio sets. See section 4.3 for further details about evaluation.
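The choice of an operating point at a fixed false alarm budget can be illustrated as follows. This is a sketch under simplifying assumptions rather than the evaluation code behind the reported numbers: it supposes one peak confidence score per candidate trigger in the negative audio, uses a strict trigger rule (score above the threshold), and the function names and scores are made up for the example.

```python
def threshold_at_fah(negative_scores, negative_hours, max_fah):
    """Lowest threshold keeping false alarms per hour at or below max_fah."""
    allowed = int(max_fah * negative_hours)  # tolerated number of false alarms
    ranked = sorted(negative_scores, reverse=True)
    if allowed >= len(ranked):
        return 0.0  # every negative score may trigger without exceeding the budget
    # Only scores strictly above this value trigger, i.e. at most `allowed` of them.
    return ranked[allowed]

def recall_at(threshold, positive_scores):
    """Share of positive utterances whose confidence exceeds the threshold."""
    return sum(s > threshold for s in positive_scores) / len(positive_scores)

# Hypothetical scores: half an hour of negative audio and four positive utterances.
negative_scores = [0.1, 0.9, 0.3, 0.7, 0.2, 0.6]
positive_scores = [0.95, 0.8, 0.65, 0.5]

thr = threshold_at_fah(negative_scores, negative_hours=0.5, max_fah=5)
false_alarms = sum(s > thr for s in negative_scores)
print(thr)                                # 0.6
print(false_alarms / 0.5)                 # 4.0 false alarms per hour
print(recall_at(thr, positive_scores))    # 0.75
```

Lowering the threshold any further would let a third negative score trigger and push the rate to 6 FAH, so 0.6 is the point that maximizes recall within the 5 FAH budget for this toy data.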

4.3 Results

We conduct an extensive empirical study of the federated averaging algorithm for the Hey Snips wake word based on the crowdsourced data from Table 1. Federated optimization results are compared with a standard setting, i.e. centralized mini-batch SGD with data from train set users being randomly shuffled. Our aim is to evaluate the number of communication rounds required to reach our stopping criterion on the dev set. For the purpose of this experiment, early stopping is evaluated in a centralized fashion, and we assume that the dev set users agreed to share their data with the parameter server. In an actual product setting, early stopping estimation would be run locally on the devices of the dev users: they would download the latest version of the central model at the end of each round and evaluate the early stopping criterion based on prediction scores for their own utterances. These individual metrics would then be averaged by the parameter server to obtain the global model criterion estimation. Final evaluation on test users would be done in the same distributed fashion once training is finished.

Standard baseline: Our baseline, i.e. a standard centralized data setting with a single training server and the Adam optimizer, reaches the early stopping target in 400 steps (… epochs), which is a strong convergence speedup in comparison with standard SGD, which remains under 87% after 28 epochs despite learning rate and gradient clipping tuning on the dev set.

User parallelism: The higher the ratio of users selected at each round C, the more data is used for distributed local training, and the faster the convergence is expected to be, assuming that local training does not diverge too much. Figure 1 shows the impact of C on convergence: the gain of using half of the users is limited in comparison with using 10%, specifically in the later stages of convergence. A fraction of 10% of users per round is also more realistic in a practical setup, as selected users have to be online. With lower participation ratios, the gradients are much more sensitive and might require the use of a learning rate smoothing strategy. C is therefore set to 10%.

Figure 1: Effect of the share of users involved in each round C on the dev set recall / 5 FAH (FedSGD, Adam global averaging).

Global averaging: Global adaptive learning rates based on Adam accelerate convergence when compared with standard averaging strategies with or without moving averages. Table 2 summarizes experimental results in the FedSGD setting with optimized local learning rates. Applying standard global averaging yields poor performance even after 400 communication rounds when compared with adaptive per-parameter averaging.

| Avg. Strategy | 100 rounds | 400 rounds |
|---|---|---|
| Standard averaging | 29.9% | 67.3% |
| Adam averaging | 93.50% | 98.29% |

Table 2: Dev set recall / 5 FAH for various averaging strategies (FedSGD).

Local training: Our results show consistency across local training configurations, with limited improvements coming from increasing the load of local training. The number of communication rounds required to reach the stopping criterion on the dev set ranges between 63 and 112 across the tested values of the number of local epochs E and the local batch size B, using C = 10%, Adam global averaging, and a local learning rate of 0.01. In our experiments, the best performance is obtained with a configuration averaging 2.4 local updates per worker taking part in a round, yielding an 80% speedup in comparison with FedSGD. Nevertheless, we observed variability across experiments with regard to weight initialization and early stage behaviour. Unlike some experiments presented in [3], the speedup coming from increasing the amount of local training steps does not lead to order-of-magnitude improvements in convergence speed, while local learning rate and global averaging tuning proved to be crucial in our work. We conjecture that this difference is related to the input semantic variability across users. In the MNIST and CIFAR experiments from [3], the input semantics are the same across emulated users. For instance, images of the digit 9 that are attributed to various users are all very similar. In the wake word setting, each user has their own vocalization of the same wake word utterance, with significant differences in pitch and accent that can lead to diverging lower stage representations that might perform poorly when averaged.

Evaluation: We evaluate the false alarm rates of the best model for a fixed recall of 95% on the test set. We observe 3.2 FAH on the negative test data, 3.9 FAH on LibriSpeech [17], and 0.2 and 0.6 FAH on our internal news and collected TV datasets, respectively. Unsurprisingly, false alarms are more common on close-field continuous datasets than they are on background negative audio sets.

4.4 Communication cost analysis

Communication cost is a strong constraint when learning from decentralized data, especially when users’ devices have limited connectivity and bandwidth. Considering the asymmetrical nature of broadband speeds, the communication bottleneck for federated learning is the transfer of updated weights from clients to the parameter server once local training is completed [10]. We assume that the upstream communication cost associated with users involved in model evaluation at each communication round is marginal, as they would only be uploading a few floating point metrics per round, which is much smaller than the model size. The total client upload bandwidth requirement is provided in the equation below:

$$\text{ClientUploadCost} = \text{modelSize}_{f32} \times C \times N_{rounds} \qquad (3)$$

Based on our results, this would yield a cost of 8 MB per client when the stopping criterion is reached within 100 communication rounds. On its end, the server receives 137 updates per round when C = 10%, amounting to 110 GB over the course of the whole optimization process with 1.4k users involved during training. This cost is of course directly related to the early stopping criterion. Further experiments with later convergence stages (400 rounds) yielded 98% recall / 0.5 FAH on the test set for an upload budget of 32 MB per user.
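The per-client figures can be checked with a short calculation. It assumes that each parameter is uploaded as an uncompressed 32-bit float and that a client is selected in a fraction C of the rounds on average, so the result is an expected cost per client; the function name `client_upload_mb` is hypothetical.

```python
N_PARAMS = 190_852     # parameter count given in section 4.2
BYTES_PER_PARAM = 4    # assumption: 32-bit floats, no compression

def client_upload_mb(n_rounds, participation=0.10):
    """Expected upload per client in MB: model size x C x number of rounds."""
    model_mb = N_PARAMS * BYTES_PER_PARAM / 1e6
    return model_mb * participation * n_rounds

print(round(client_upload_mb(100), 1))  # 7.6
print(round(client_upload_mb(400), 1))  # 30.5
```

Both values are in line with the 8 MB and 32 MB budgets quoted above once rounded up.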

5 Conclusion and future work

In this work, we investigate the use of federated learning on crowdsourced speech data to learn a resource-constrained wake word detector. We show that a revisited federated averaging algorithm with per-coordinate averaging based on Adam in place of standard global averaging allows the training to reach a target stopping criterion of 95% recall per 5 FAH within 100 communication rounds on our crowdsourced dataset, for an associated upstream communication cost per client of 8 MB. We also open source the Hey Snips wake word dataset.

The next step towards a real-life implementation is to design a system for local data collection and labeling, as the wake word task requires data supervision. The frame labeling strategy used in this work relies on an aligner, which cannot be easily embedded. The use of memory-efficient end-to-end models [14] in place of the presented class-based model could ease labeling as it would only rely on voice activity detection.

6 Acknowledgments

We thank Oleksandr Olgashko for his contribution in developing the training framework.