In federated learning (FL), a global model is trained on decentralized data from a large number of clients, which may be mobile phones, other edge devices, or sensors [konevcny2016federated, konecny2016federated2]. The training data remains distributed over the clients, thus providing a layer of privacy during model training. However, FL also raises several types of issues, both practical and algorithmic, that have been the topic of multiple research efforts. These include the design of more efficient communication strategies [konevcny2016federated, konecny2016federated2, suresh2017distributed], efficient distributed optimization methods with differential privacy guarantees [AgarwalSureshYuKumarMcMahan2018], and recent lower bound guarantees for parallel stochastic optimization with a dependency graph [WoodworthWangSmithMcMahanSrebro2018], to name just a few. We refer readers to [li2019federated] and [kairouz2021advances] for a detailed literature survey on FL. FL is typically studied in two scenarios: cross-silo and cross-device. In cross-silo federated learning, the number of clients is small, whereas in cross-device federated learning, the number of clients is very large and can be on the order of millions.
Fairness is a key objective in machine learning in general [Bickel398, NIPS2016_9d268236] and in federated learning in particular [abay2020mitigating, Li2020Fair], where the network of clients can be massive and heterogeneous. Standard learning objectives in FL minimize the loss with respect to the uniform distribution over all samples. [mohri2019agnostic] argued that, in many common instances, the uniform distribution is not the natural objective distribution, as the data observed during training and inference in FL can differ. This is, in part, because models are typically trained on client devices under certain conditions (e.g., the device is charging, is connected to an un-metered network, or is idle), whereas during inference, these conditions need not be met. Hence, it is risky to seek to minimize the expected loss with respect to a specific distribution. To overcome this, they proposed a new framework, agnostic federated learning, where the centralized model is optimized for any possible target distribution formed by a mixture of the client distributions. Instead of optimizing for a specific distribution, which carries a high risk of mismatch with the target, they defined an agnostic and more risk-averse objective. They further showed generalization guarantees for this new objective and proposed a stochastic mirror descent type algorithm to minimize it.
However, their approach and algorithm did not address some key scenarios in FL. First, their algorithm is feasible in the cross-silo setting, where the number of clients is small and the number of samples per client is large. In the cross-device setting, where the number of clients is very large, we argue that their model yields very loose generalization bounds. Second, their algorithm did not fully address the important communication bottleneck and decentralized data issues [McMahanMooreRamageHampsonAguera2017] inherent in the cross-device FL setting. A straightforward implementation of their approach requires running a federated algorithm for a few hundred thousand rounds, which is not feasible in the cross-device setting.
In this paper, we overcome these bottlenecks and propose a communication-efficient federated algorithm called Agnostic Federated Averaging (or AgnosticFedAvg) to minimize the agnostic learning objective in the cross-device setting. AgnosticFedAvg is not only communication-efficient, but also amenable to privacy-preserving techniques such as secure aggregation [bonawitz2017practical]. The rest of the paper is organized as follows. In Section 2, we introduce the notation and review existing results; in Section 3, we define the framework; and in Section 4, we propose our algorithm. Finally, in Section 5, we evaluate the proposed algorithm on different synthetic and live user datasets.
2 Preliminaries and Previous Work
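We start with some general notation and definitions. Let $\mathcal{X}$ denote the input space and $\mathcal{Y}$ the output space. A distribution $\mathcal{D}$ is a distribution over $\mathcal{X} \times \mathcal{Y}$.
We will primarily discuss a multi-class classification problem where $\mathcal{Y}$ is a finite set of classes, but many of our results can be extended straightforwardly to regression and other problems. The hypotheses we consider are of the form $h \colon \mathcal{X} \to \Delta_{\mathcal{Y}}$, where $\Delta_{\mathcal{Y}}$ stands for the simplex over $\mathcal{Y}$. Thus, $h(x)$ is a probability distribution over the classes or categories that can be assigned to $x \in \mathcal{X}$. We will denote by $\mathcal{H}$ a family of such hypotheses $h$. We also denote by $\ell$ a loss function defined over $\Delta_{\mathcal{Y}} \times \mathcal{Y}$ and taking non-negative values. The loss of $h \in \mathcal{H}$ for a labeled sample $(x, y) \in \mathcal{X} \times \mathcal{Y}$ is given by $\ell(h(x), y)$. One key example in applications is the cross-entropy loss, which is defined as $\ell(h(x), y) = -\log \Pr_{y' \sim h(x)}[y' = y]$. We will denote by $\mathcal{L}_{\mathcal{D}}(h) = \mathbb{E}_{(x, y) \sim \mathcal{D}}[\ell(h(x), y)]$ the expected loss of a hypothesis $h$ with respect to a distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$, and by $h_{\mathcal{D}}$ its minimizer: $h_{\mathcal{D}} = \operatorname{argmin}_{h \in \mathcal{H}} \mathcal{L}_{\mathcal{D}}(h)$. In standard learning scenarios, the distribution $\mathcal{D}$ is the test or target distribution, which typically coincides with the distribution of the training samples. However, in FL, this is often not the case.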
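As a small illustration of these definitions, the sketch below computes the cross-entropy loss of a hypothesis on one labeled sample and its average over a finite sample, which is the empirical counterpart of the expected loss. The function names, the constant hypothesis, and the three-example dataset are hypothetical and chosen only for illustration.

```python
import math


def cross_entropy(probs, label):
    """Cross-entropy loss: negative log-probability assigned to the true label."""
    return -math.log(probs[label])


def empirical_loss(hypothesis, samples):
    """Average loss of `hypothesis` over a list of (x, y) pairs."""
    return sum(cross_entropy(hypothesis(x), y) for x, y in samples) / len(samples)


def constant_hypothesis(x):
    """Hypothetical hypothesis over two classes that ignores its input."""
    return [0.25, 0.75]


# Hypothetical labeled sample: inputs are placeholders, labels are class indices.
samples = [("a", 1), ("b", 1), ("c", 0)]
risk = empirical_loss(constant_hypothesis, samples)
print(f"empirical cross-entropy loss: {risk:.4f}")
```

Replacing the sample with draws from a different distribution changes the value of the loss, which is why the choice of target distribution matters in what follows.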
In FL, the data is distributed across many heterogeneous clients and the data distribution is different for each client [kairouz2019advances]. Let $q$ be the total number of clients. Let $\mathcal{D}_k$ denote the data distribution for client $k$. The client does not have access to the true distribution and instead has access to a sample $S_k$ of $m_k$ examples drawn from $\mathcal{D}_k$. Let $\hat{\mathcal{D}}_k$ denote the empirical distribution associated to sample $S_k$ of size $m_k$. A natural goal is to minimize the empirical average risk given by
$$\mathcal{L}_{\hat{\mathcal{U}}}(h), \qquad \hat{\mathcal{U}} = \sum_{k=1}^{q} \frac{m_k}{\sum_{j=1}^{q} m_j} \hat{\mathcal{D}}_k,$$
where $\hat{\mathcal{U}}$ is the uniform distribution over all clients' data. However, as argued by [mohri2019agnostic], due to differences between the train and test distributions, minimizing this objective is risky. Hence, they proposed to minimize the loss on the worst-case distribution. More concretely, for distributions $\mathcal{D}_k$, $k \in \{1, \ldots, q\}$, let $\mathcal{D}_\lambda = \sum_{k=1}^{q} \lambda_k \mathcal{D}_k$ for some $\lambda \in \Delta_q$. For simplicity, we allow any $\lambda \in \Delta_q$, where $\Delta_q$ is the probability simplex over the $q$ clients. Thus, the learner minimizes the empirical agnostic loss (or agnostic risk) associated to a predictor $h$ as
$$\mathcal{L}_{\hat{\mathcal{D}}_{\Delta_q}}(h) = \max_{\lambda \in \Delta_q} \mathcal{L}_{\hat{\mathcal{D}}_\lambda}(h),$$
where $\hat{\mathcal{D}}_\lambda = \sum_{k=1}^{q} \lambda_k \hat{\mathcal{D}}_k$. However, their generalization bounds [mohri2019agnostic, Theorem 1] depend on $\min_k m_k$, the minimum number of samples of any client. In the cross-device setting, this yields loose bounds as each client typically only has a few hundred samples. Hence, instead of treating each client as a domain, we treat collections of clients or data pooled from clients as domains.
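To make the difference between the uniform and the agnostic objectives concrete, here is a minimal sketch. It relies on the fact that the expected loss is linear in the mixture weights, so the maximum over the probability simplex is attained at a single component. The per-component losses and sample counts are made-up numbers, and the function names are hypothetical.

```python
def uniform_loss(losses, counts):
    """Loss under the sample-weighted (uniform over all samples) mixture."""
    total = sum(counts)
    return sum(loss * n for loss, n in zip(losses, counts)) / total


def mixture_loss(losses, lam):
    """Loss under the mixture with weights `lam` (linear in `lam`)."""
    return sum(loss * w for loss, w in zip(losses, lam))


def agnostic_loss(losses):
    """Maximum of the mixture loss over the simplex, attained at a vertex."""
    return max(losses)


# Hypothetical per-component average losses and sample counts.
losses = [0.5, 2.0, 1.0]
counts = [800, 100, 100]

print("uniform loss: ", uniform_loss(losses, counts))
print("agnostic loss:", agnostic_loss(losses))
```

In this made-up example the uniform objective is dominated by the large, easy component, while the agnostic objective is determined entirely by the hardest one.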
3 Proposed Formulation
As stated before, treating each client as a separate domain yields loose generalization bounds. Hence, we treat collections of clients as domains, which naturally leads to two types of partitions. Let there be $p$ domains $\mathcal{G}_1, \ldots, \mathcal{G}_p$.
data partition: Each client has data from one or more domains and domains represent different types of data. For example, for virtual keyboard applications [hard2018federated], the domains could be the application source of client inputs, such as messaging, emails, or documents. In this case, the data distribution $\mathcal{D}_k$ for client $k$ is given by
$$\mathcal{D}_k = \sum_{i=1}^{p} \alpha_{k,i} \mathcal{G}_i,$$
where $\alpha_{k,i} \geq 0$ and $\sum_{i=1}^{p} \alpha_{k,i} = 1$ for all $k$.
client partition: Each client has data from exactly one domain and domains represent clusters of clients. For example, clustering clients based on their geographic location yields this domain type. In this case, the data distribution of client $k$ is given by
$$\mathcal{D}_k = \mathcal{G}_{i_k} \quad \text{for some } i_k \in \{1, \ldots, p\}.$$
In both of the above formulations, even though there are $q$ different clients, the number of underlying distinct domains is $p$, which we argue is considerably smaller. Hence, we have a large number of samples from each of the domains and get strong generalization bounds. Since each client distribution can be written as a linear combination of domain distributions,
However, we do not have access to the true domain distributions and instead have samples from , where is the empirical distribution obtained by pooling all the data of domain . Let be the number of samples in domain . By (2), the true agnostic loss over clients is smaller than the true agnostic loss over domains. Hence we propose to minimize the empirical agnostic loss over domains,
where . The previously known generalization bounds from [mohri2019agnostic, Lemma 3, Corollary 4] yield the following generalization bound. Let . With probability at least , for any client and any hypothesis
for some constant which depends on the maximum value of the loss and is the VC dimension of the hypothesis class . The above generalization bound scales inversely with , which is the minimum number of samples in the domain. Since the number of domains is small, as long as the domains are well-distributed, we would have a good number of samples from all domains and hence a strong generalization bound in the cross-device setting.
We now propose a communication-efficient algorithm to minimize agnostic loss (1) in the cross-device setting.
[mohri2019agnostic] showed that agnostic learning can be treated as a two-player game, where a learner tries to find the best hypothesis and the adversary tries to find the domain weights that maximize the loss. They proposed a stochastic mirror descent algorithm and showed that the objective reaches the optimum value at a rate of after steps in training. In practice, a direct implementation of their work is not communication-efficient in the cross-device setting, as the number of steps can be on the order of millions. Moreover, it requires the clients to reveal their domain to the server, which can be privacy-invasive. To overcome this, we propose an algorithm called AgnosticFedAvg that is communication-efficient and can be used with privacy-preserving techniques such as secure aggregation. We design AgnosticFedAvg with the following properties.
Each round of FL uses only a single round of transmission, which includes model download from the server to the clients and upload from the clients to the server.
The clients train with multiple local SGD steps similar to FederatedAveraging (or FedAvg) of [McMahanMooreRamageHampsonAguera2017].
The server does not have access to individual clients' data, but only to aggregated statistics, making the algorithm compatible with cryptographic techniques such as secure aggregation [bonawitz2017practical], which prevents anyone from retrieving or rebuilding privacy-sensitive information from individual client parameter updates.
Let be the set of parameters of the hypothesis class. The algorithm first initializes the weights to , domain weights to , and the number of examples per domain to , where denotes the number of samples for domain at round and denotes the number of samples for client , split by domain. We keep a sliding window of the number of examples per domain over the last training rounds. The algorithm uses learning rate for learning domain weights. In the following, let denote the loss function as a function of hypothesis parameter .
At each round of training, the algorithm computes a scaling vector by taking the ratio of the domain weights to the average number of samples per domain over the last rounds . The algorithm then selects clients randomly and sends the parameters and scaling vector to each of them. First, each selected client computes the number of samples per domain , initial loss per domain , and scaled client weight for its local dataset. Then, each client updates the parameters based on and by running epochs of SGD with batch size and learning rate . Finally, the client transmits the updated parameters , weight per client , initial loss per domain , and number of samples per domain back to the server. Since this is done using secure aggregation, the server only observes the total number of samples and loss per domain across clients. The server then computes the new parameters by averaging the client updates weighted by and does an exponentiated gradient (EG) step for the domain weights ,
If a round does not have any samples from a particular domain, we set to zero for that round. This process is repeated for rounds.
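The server-side computations described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the exponentiated gradient step is written in its standard multiplicative form (multiply each weight by the exponential of the learning rate times the domain loss, then renormalize), the quantity set to zero for an unobserved domain is taken to be that domain's average loss, and all function names and numbers are hypothetical.

```python
import math


def scaling_vector(domain_weights, avg_counts):
    """Ratio of each domain weight to its moving-average sample count."""
    return [w / n if n > 0 else 0.0 for w, n in zip(domain_weights, avg_counts)]


def domain_average_losses(loss_sums, counts):
    """Per-domain average loss; assumed to be 0 when a domain has no samples."""
    return [s / n if n > 0 else 0.0 for s, n in zip(loss_sums, counts)]


def weighted_average(client_params, client_weights):
    """Average of client parameter vectors, weighted by the client weights."""
    total = sum(client_weights)
    dim = len(client_params[0])
    return [
        sum(w * p[j] for p, w in zip(client_params, client_weights)) / total
        for j in range(dim)
    ]


def eg_step(domain_weights, domain_losses, lr):
    """Standard exponentiated-gradient ascent step, renormalized to the simplex."""
    unnorm = [w * math.exp(lr * loss) for w, loss in zip(domain_weights, domain_losses)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


# Hypothetical round with two domains and two selected clients.
weights = [0.5, 0.5]
scale = scaling_vector(weights, avg_counts=[10.0, 5.0])
avg_losses = domain_average_losses(loss_sums=[10.0, 10.0], counts=[10, 5])
new_params = weighted_average([[0.0, 2.0], [4.0, 6.0]], client_weights=[1.0, 3.0])
new_weights = eg_step(weights, avg_losses, lr=0.1)
print(scale, avg_losses, new_params, new_weights)
```

In this sketch the domain with the higher average loss receives a larger weight after the step, so later rounds emphasize it more.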
To see why the above algorithm aims to minimize the agnostic loss, consider the weighted average of all the client losses
where is the average loss for domain and is the number of samples in domain from the selected clients at round . The approximation assumes that the moving average is close to the number of samples in domain from the selected clients at round . Thus, AgnosticFedAvg aims to minimize the domain agnostic objective defined in (1). We further note that, because of secure aggregation [bonawitz2017practical], the server only observes aggregated statistics and does not learn the domains or gradients of individual clients, which provides an additional layer of privacy.
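The identity behind this argument can be checked numerically. The sketch below assumes that each example's loss is scaled by the ratio of its domain's weight to that domain's sample count, and that the moving-average count equals the actual per-round count, so the approximation is exact. Under these assumptions, the scaled sum of client losses equals the domain-weighted average loss. All names and numbers are hypothetical.

```python
def domain_weighted_loss(domain_weights, client_loss_sums, client_counts):
    """Sum over domains of weight times the average loss pooled across clients."""
    num_domains = len(domain_weights)
    total = 0.0
    for i in range(num_domains):
        loss_sum = sum(sums[i] for sums in client_loss_sums)
        count = sum(counts[i] for counts in client_counts)
        total += domain_weights[i] * loss_sum / count
    return total


def scaled_client_loss(scaling, client_loss_sums):
    """Sum over clients and domains of scaling[i] times the client's summed loss."""
    return sum(s * loss_sum for sums in client_loss_sums for s, loss_sum in zip(scaling, sums))


# Two hypothetical clients, two domains.
client_counts = [[3, 1], [1, 3]]               # samples per client, per domain
client_loss_sums = [[1.5, 2.0], [0.5, 6.0]]    # summed loss per client, per domain
domain_weights = [0.25, 0.75]

# Scaling vector: domain weight divided by the total sample count of the domain.
domain_totals = [sum(c[i] for c in client_counts) for i in range(2)]
scaling = [w / n for w, n in zip(domain_weights, domain_totals)]

lhs = scaled_client_loss(scaling, client_loss_sums)
rhs = domain_weighted_loss(domain_weights, client_loss_sums, client_counts)
print(lhs, rhs)
```

When the moving average differs from the per-round count, the two sides differ by the corresponding ratio for each domain, which is the approximation referred to in the text.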
Table 1: Perplexity and in-vocab-accuracy for the Stack Overflow test dataset, with the standard deviation over three trials in parentheses. AgnosticFedAvg attains the lowest perplexity for the harder answer domain.
We implemented all algorithms and simulation experiments using the open-source FedJAX [fedjax2020github] library. We report results for the Stack Overflow language model task and for a live experiment training a Spanish language model for millions of actual virtual keyboard user devices. For all experiments, we compare three algorithms:
FedAvg (uniform): Trained uniformly on all available data.
FedAvg (target-only): Trained only on data from the target.
AgnosticFedAvg: Trained on all available data.
We demonstrate that AgnosticFedAvg attains a lower perplexity than FedAvg (uniform) and FedAvg (target-only) on the harder domain in both experiments: the answer domain for Stack Overflow and es-AR for the Spanish language model.
To verify that AgnosticFedAvg correctly minimizes the domain agnostic objective and to showcase its effectiveness on non-language tasks, we also include experiments on a synthetic toy regression example and the EMNIST-62 image recognition task, respectively (presented in the attached supplementary material due to lack of space).
5.1 Stack Overflow Language Model
Motivated by FL uses in virtual keyboard applications [hard2018federated], we use AgnosticFedAvg to train a language model on a large text corpus. We consider the language model task for the Stack Overflow dataset provided by TFF [tff2019]. This dataset contains two domains, questions and answers, from the Stack Overflow forum grouped by client ids. This corresponds to the data partition domain type since an individual client can post both questions and answers. Table 3 summarizes the statistics per domain.
We match the model and training setup from [reddi2020adaptive] and train a single-layer LSTM language model over the top K words with an Adam server optimizer and clients participating per training round for rounds. For AgnosticFedAvg, we use the same setup with domain weight learning rate .
For the Stack Overflow experiments, we report perplexity and in-vocab-accuracy, where in-vocab-accuracy is the number of correct predictions, without UNK (out-of-vocabulary) or EOS (end-of-sentence) tokens, divided by the number of words without the EOS token. Defining in-vocab-accuracy this way allows valid comparisons for different vocabulary sizes. The results are in Table 1. For the baseline FedAvg (uniform), of the two domains, the answer domain is harder and has higher perplexity and lower accuracy. Given this, we also train an additional baseline FedAvg (answer) on answer examples only. While FedAvg (answer) does improve answer performance over FedAvg (uniform), it results in significantly worse question performance. However, AgnosticFedAvg outperforms both FedAvg (uniform) and FedAvg (answer) on the answers domain, while also significantly decreasing the performance disparity between answers and questions. This suggests that there could be important features in the questions that can augment performance on answers that are leveraged by AgnosticFedAvg but aren’t optimally weighted in FedAvg (uniform) or are completely ignored in FedAvg (answer).
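The in-vocab-accuracy definition above can be written out as a short sketch. The token strings `<UNK>` and `<EOS>`, the function name, and the example sentence are hypothetical placeholders. A prediction counts as correct only when it matches a target that is neither UNK nor EOS, and the denominator counts every target except EOS.

```python
def in_vocab_accuracy(predictions, targets, unk="<UNK>", eos="<EOS>"):
    """Correct predictions on non-UNK, non-EOS targets over the number of non-EOS targets."""
    correct = sum(
        1
        for pred, target in zip(predictions, targets)
        if pred == target and target not in (unk, eos)
    )
    num_words = sum(1 for target in targets if target != eos)
    return correct / num_words if num_words else 0.0


# Hypothetical example: the UNK and EOS matches do not count as correct.
targets = ["how", "do", "<UNK>", "work", "<EOS>"]
predictions = ["how", "to", "<UNK>", "work", "<EOS>"]
accuracy = in_vocab_accuracy(predictions, targets)
print(accuracy)
```

Because a correct UNK prediction earns no credit while UNK targets still count in the denominator, a model cannot raise this metric simply by using a smaller vocabulary.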
5.2 Spanish Virtual Keyboard Language Model
We further use AgnosticFedAvg to train a Coupled Input and Forget Gate (CIFG) [greff2017lstm] language model for Spanish on virtual keyboard client devices. Following the settings in [hard2018federated], text data for FL is stored in local caches on device. Clients for this study must meet several FL requirements, including (1) having 2 gigabytes of memory; (2) being connected to an un-metered network; (3) being charged; (4) being idle. We consider three domains based on the Spanish locales es-US, es-AR, and es-419 (defined by the UN M.49 region code), for the US, Argentina, and the Latin America / Caribbean region, respectively. Since each user device falls in a single region, this task corresponds to the client partition.
Similar to Section 5.1, we report perplexity and in-vocab-accuracy. For all algorithms, we use the Momentum server optimizer, using Nesterov accelerated gradient [nesterov], and clients participating per training round for rounds. Over the course of training, approximately 141 million sentences are processed by 1.5 million clients. The results are in Table 2. For the baseline FedAvg (uniform), of the three languages, es-AR has the worst perplexity. Similar to Section 5.1, training FedAvg (es-AR) on es-AR clients only improves es-AR performance over FedAvg (uniform) but also results in much worse performance for es-US and es-419. Again, AgnosticFedAvg improves the perplexity and accuracy on es-AR over FedAvg (uniform) while also decreasing the regression on es-US and es-419 when compared to FedAvg (es-AR).
We presented an algorithmic study of domain agnostic learning in the cross-device FL setting. We also examined the two types of naturally occurring domains in FL, data partition and client partition, and provided example learning tasks for both in large-scale language modeling. Finally, we defined AgnosticFedAvg, a communication-efficient federated algorithm that aims to minimize the domain agnostic objective proposed in [mohri2019agnostic] and that is compatible with secure aggregation for additional security. We demonstrated its practical effectiveness in simulations and live experiments. We hope that our efforts will spur further studies into improving the practical efficiency of FL algorithms.
Appendix A Toy Regression
We first evaluate AgnosticFedAvg on a toy regression task to ensure its correctness. We consider a simple regression example, where each domain is a set of random points in . Let each domain , be a set of points in . Further, let be the center of these points. We distribute these points on clients randomly. The goal is to find the point that minimizes the maximum distance to all the domain centers i.e.,
It is easy to see that
thus we maximize the latter objective by AgnosticFedAvg. We choose points such that the true answer is and plot the performance of AgnosticFedAvg for 5 domains in Figure 1. As expected, AgnosticFedAvg converges to the true solution within rounds.
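For intuition about the minimax objective, here is a one-dimensional sketch that finds the point minimizing the maximum distance to a set of domain centers by grid search and compares it with the mean of the centers. It does not reproduce the experiment above or use AgnosticFedAvg; the centers, the grid resolution, and the function names are hypothetical.

```python
def max_distance(w, centers):
    """Largest distance from the point w to any domain center."""
    return max(abs(w - c) for c in centers)


def minimax_point(centers, steps=400):
    """Grid search for the point minimizing the maximum distance to the centers."""
    lo, hi = min(centers), max(centers)
    grid = [lo + (hi - lo) * i / steps for i in range(steps + 1)]
    return min(grid, key=lambda w: max_distance(w, centers))


# Hypothetical one-dimensional domain centers.
centers = [0.0, 1.0, 4.0]
w_star = minimax_point(centers)
mean = sum(centers) / len(centers)

print("minimax point:", w_star, "worst-case distance:", max_distance(w_star, centers))
print("mean of centers:", mean, "worst-case distance:", max_distance(mean, centers))
```

In one dimension the minimax point is the midpoint of the two extreme centers, whereas the mean is pulled toward the cluster of nearby centers and has a larger worst-case distance.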
Appendix B EMNIST-62 Image Recognition
We consider the image recognition task for the EMNIST-62 dataset provided by TensorFlow Federated (TFF) [tff2019]. This dataset consists of writers and their writing samples, each of which belongs to one of 62 alphanumeric classes. According to the original NIST source documentation (https://s3.amazonaws.com/nist-srd/SD19/sd19_users_guide_edition_2.pdf), the writers come from two distinct sources: high school and census field. This corresponds to the client partition domain type since a given client can only belong to a single domain. Table 5 summarizes the statistics on the number of clients and examples per domain.
We match the model and training setup from [reddi2020adaptive] and train a convolutional neural network with an Adam server optimizer and clients participating per training round for rounds. For AgnosticFedAvg, we use the same setup with domain weight learning rate . [reddi2020adaptive] provides a comprehensive overview of different server optimizers and their respective performance. For our experiments, we use the Adam server optimizer, as it was shown to produce the highest accuracy.
The results are in Table 5. For the baseline FedAvg (uniform), of the two domains, the high school domain is harder and has lower accuracy, most likely because it has fewer clients and training examples. In light of this, we also train FedAvg(high school) only on clients from the high school domain. While FedAvg (high school) does improve high school performance over FedAvg (uniform), it results in drastically worse accuracy on the census domain. This is somewhat expected as the number of census clients far outsizes the number of high school clients. AgnosticFedAvg not only outperforms FedAvg (uniform) on the high school domain, but also significantly decreases the gap in accuracy between high school and census from and , for FedAvg (uniform) and FedAvg (high school), respectively, to just .
| high school clients | 500 | 500 |
| high school examples | 68.8K | 8.7K |
|FedAvg (high school)|