Communication-Efficient Agnostic Federated Averaging

by Jae Ro, et al.

In distributed learning settings such as federated learning, the training algorithm can be biased towards different clients. To overcome this bias, Mohri et al. (2019) proposed a domain-agnostic learning algorithm, where the model is optimized for any target distribution formed by a mixture of the client distributions. They further proposed an algorithm for the cross-silo federated learning setting, where the number of clients is small. We consider this problem in the cross-device setting, where the number of clients is much larger. We propose a communication-efficient distributed algorithm called Agnostic Federated Averaging (or AgnosticFedAvg) to minimize the domain-agnostic objective proposed in Mohri et al. (2019), which is amenable to privacy mechanisms such as secure aggregation. We highlight two types of naturally occurring domains in federated learning and argue that AgnosticFedAvg performs well on both. To demonstrate the practical effectiveness of AgnosticFedAvg, we report positive results for large-scale language modeling tasks in both simulation and live experiments, where the latter involves training language models for a Spanish virtual keyboard on millions of user devices.



1 Introduction

In federated learning (FL), a global model is trained on decentralized data from a large number of clients, which may be mobile phones, other edge devices, or sensors [konevcny2016federated, konecny2016federated2]. The training data remains distributed over the clients, thus providing a layer of privacy during model training. However, FL also raises several types of issues, both practical and algorithmic, that have been the topic of multiple research efforts. The design of more efficient communication strategies [konevcny2016federated, konecny2016federated2, suresh2017distributed], the development of efficient distributed optimization methods with differential privacy guarantees [AgarwalSureshYuKumarMcMahan2018], and recent lower bound guarantees for parallel stochastic optimization with a dependency graph [WoodworthWangSmithMcMahanSrebro2018] are just a few of these efforts. We refer readers to [li2019federated] and [kairouz2021advances] for a detailed literature survey on FL. FL is typically studied in two scenarios: cross-silo and cross-device. In cross-silo federated learning, the number of clients is small, whereas in cross-device federated learning, the number of clients is very large and can be on the order of millions.

Fairness is a key objective in general machine learning [Bickel398, NIPS2016_9d268236] and especially in federated learning [abay2020mitigating, Li2020Fair], where the network of clients can be massive and heterogeneous. Standard learning objectives in FL minimize the loss with respect to the uniform distribution over all samples.

[mohri2019agnostic] argued that, in many common instances, the uniform distribution is not the natural objective distribution, as the data observed during training and inference in FL can differ. This is, in part, because models are typically trained on client devices only under certain conditions (e.g., the device is charging, is connected to an un-metered network, is idle, etc.), whereas during inference, these conditions need not be met. Hence, it is risky to minimize the expected loss with respect to a specific distribution. To overcome this, they proposed a new framework, agnostic federated learning, where the centralized model is optimized for any possible target distribution formed by a mixture of the client distributions. Instead of optimizing for a specific distribution, which carries a high risk of mismatch with the target, they defined an agnostic and more risk-averse objective. They further showed generalization guarantees for this new objective and proposed a stochastic mirror descent type algorithm to minimize it.

However, their approach and algorithm did not address some key scenarios in FL. First, their algorithm is feasible in the cross-silo setting, where the number of clients is small and the number of samples per client is large. However, in the cross-device setting, where the number of clients is very large, we argue that treating each client as a separate domain yields very loose generalization bounds. Second, their algorithm did not fully address the important communication bottleneck and decentralized data issues [McMahanMooreRamageHampsonAguera2017] inherent in the cross-device FL setting. A straightforward implementation of their approach requires running a federated algorithm for a few hundred thousand rounds, which is not feasible in the cross-device setting.

In this paper, we overcome these bottlenecks and propose a communication-efficient federated algorithm called Agnostic Federated Averaging (or AgnosticFedAvg) to minimize the agnostic learning objective in the cross-device setting. AgnosticFedAvg is not only communication-efficient, but also amenable to privacy-preserving techniques such as secure aggregation [bonawitz2017practical]. The rest of the paper is organized as follows. In Section 2, we state the notation and review existing results; in Section 3, we define the framework; and in Section 4, we propose our algorithm. Finally, in Section 5, we evaluate the proposed algorithm on different synthetic and live user datasets.

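procedure Server
    initialize the model parameters, the domain weights, and the number of examples per domain
    for round t = 1 to T do
        compute the scaling vector: domain weights divided by the sliding-window average of examples per domain
        select a random set of clients
        for each selected client j do
            receive updated parameters, client weight, initial loss per domain, and examples per domain from Client(j)
        end for
        aggregate the examples per domain across clients and update the sliding window
        aggregate the initial loss per domain across clients
        update the parameters: average of the client parameters weighted by the client weights
        update the domain weights with an exponentiated gradient (EG) step
    end for
end procedure

procedure Client(j)    (run on client j)
    for domain k = 1 to p do
        count the local examples from domain k
        compute the initial loss on the local examples from domain k
        add the scaled count of domain k to the client weight
    end for
    split the local data into batches of size B
    for epoch e = 1 to E do
        for each batch do
            take an SGD step on the scaled loss of the batch
        end for
    end for
    return updated parameters, client weight, initial loss per domain, examples per domain
end procedure
Algorithm 1 AgnosticFedAvg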

2 Preliminaries and Previous Work

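We start with some general notation and definitions. Let $\mathcal{X}$ denote the input space and $\mathcal{Y}$ the output space. A distribution $\mathcal{D}$ is a distribution over $\mathcal{X} \times \mathcal{Y}$.

We will primarily discuss a multi-class classification problem where $\mathcal{Y}$ is a finite set of classes, but many of our results extend straightforwardly to regression and other problems. The hypotheses we consider are of the form $h\colon \mathcal{X} \to \Delta_{\mathcal{Y}}$, where $\Delta_{\mathcal{Y}}$ stands for the simplex over $\mathcal{Y}$. Thus, $h(x)$ is a probability distribution over the classes or categories that can be assigned to $x \in \mathcal{X}$. We will denote by $\mathcal{H}$ a family of such hypotheses $h$. We also denote by $\ell$ a loss function defined over $\Delta_{\mathcal{Y}} \times \mathcal{Y}$ taking non-negative values. The loss of $h \in \mathcal{H}$ for a labeled sample $(x, y) \in \mathcal{X} \times \mathcal{Y}$ is given by $\ell(h(x), y)$. One key example in applications is the cross-entropy loss, which is defined as $\ell(h(x), y) = -\log h(x)_y$, where $h(x)_y$ is the probability that $h(x)$ assigns to class $y$. We will denote by $\mathcal{L}_{\mathcal{D}}(h) = \mathbb{E}_{(x, y) \sim \mathcal{D}}[\ell(h(x), y)]$ the expected loss of a hypothesis $h$ with respect to a distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$, and by $h_{\mathcal{D}}$ its minimizer: $h_{\mathcal{D}} = \operatorname{argmin}_{h \in \mathcal{H}} \mathcal{L}_{\mathcal{D}}(h)$. In standard learning scenarios, the distribution $\mathcal{D}$ is the test or target distribution, which typically coincides with the distribution of the training samples. However, in FL, this is often not the case.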
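To make these definitions concrete, here is a minimal Python sketch that evaluates the cross-entropy loss of a hypothesis and averages it over a labeled sample, which is the empirical counterpart of the expected loss defined above. The function names, the constant two-class hypothesis, and the sample are hypothetical illustrations, not part of the paper.

```python
import math

def cross_entropy(probs, y):
    """Cross-entropy loss -log h(x)_y, where probs is the predicted distribution h(x)."""
    return -math.log(probs[y])

def empirical_loss(h, sample, loss=cross_entropy):
    """Average loss of hypothesis h over a list of labeled (x, y) pairs."""
    return sum(loss(h(x), y) for x, y in sample) / len(sample)

# Hypothetical two-class hypothesis that ignores x and always predicts (0.75, 0.25).
def h(x):
    return [0.75, 0.25]

sample = [("a", 0), ("b", 0), ("c", 1)]   # hypothetical labeled sample
print(empirical_loss(h, sample))
```

Swapping in a different `loss` argument (for example a squared loss) leaves the averaging unchanged, which mirrors how the expected loss is defined for any non-negative loss.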

In FL, the data is distributed across many heterogeneous clients, and the data distribution differs from client to client [kairouz2019advances]. Let $N$ be the total number of clients, and let $\mathcal{C}_i$ denote the data distribution of client $i \in \{1, \ldots, N\}$. Client $i$ does not have access to the true distribution $\mathcal{C}_i$; instead, it has access to a sample $S_i$ of size $m_i$ drawn from $\mathcal{C}_i$. Let $\hat{\mathcal{C}}_i$ denote the empirical distribution associated to $S_i$, and let $m = \sum_{i=1}^N m_i$. A natural goal is to minimize the empirical average risk given by

$$\mathcal{L}_{\hat{\mathcal{U}}}(h) = \sum_{i=1}^N \frac{m_i}{m}\, \mathcal{L}_{\hat{\mathcal{C}}_i}(h),$$

where $\hat{\mathcal{U}}$ is the uniform distribution over the data of all clients. However, as argued by [mohri2019agnostic], due to differences between the train and test distributions, minimizing this objective is risky. Hence, they proposed to minimize the loss on the worst-case distribution. More concretely, for distributions $\mathcal{C}_1, \ldots, \mathcal{C}_N$, let $\mathcal{C}_\lambda = \sum_{i=1}^N \lambda_i \mathcal{C}_i$ for some $\lambda \in \Lambda$. For simplicity, we allow any $\lambda \in \Lambda = \Delta_N$, where $\Delta_N$ is the probability simplex over the $N$ clients. Thus, the learner minimizes the empirical agnostic loss (or agnostic risk) associated to a predictor $h$:

$$\mathcal{L}_{\hat{\mathcal{C}}_\Lambda}(h) = \max_{\lambda \in \Lambda} \mathcal{L}_{\hat{\mathcal{C}}_\lambda}(h),$$

where $\hat{\mathcal{C}}_\lambda = \sum_{i=1}^N \lambda_i \hat{\mathcal{C}}_i$. However, their generalization bounds [mohri2019agnostic, Theorem 1] depend on $\min_i m_i$, the minimum number of samples of any client. In the cross-device setting, this yields loose bounds, as each client typically has only a few hundred samples. Hence, instead of treating each client as a domain, we treat collections of clients, or data pooled from clients, as domains.
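Because the loss on a mixture of distributions is the same mixture of the individual losses, the agnostic loss is a maximum of a linear function of the mixture weights; over the full simplex it therefore equals the largest per-client loss. The short sketch below checks this numerically by sampling random mixture weights for hypothetical per-client losses; the names and numbers are illustrative assumptions.

```python
import random

def mixture_loss(lam, losses):
    """Loss on a mixture of client distributions: sum_i lam_i * loss_i (linearity of expectation)."""
    return sum(l * loss for l, loss in zip(lam, losses))

def agnostic_loss(losses):
    """Max over the simplex of a linear function of lam is attained at a vertex."""
    return max(losses)

def random_simplex_point(n, rng):
    """Normalized i.i.d. exponential draws give a uniformly random point of the simplex."""
    e = [rng.expovariate(1.0) for _ in range(n)]
    total = sum(e)
    return [x / total for x in e]

client_losses = [0.42, 1.37, 0.88]   # hypothetical empirical losses of one h on three clients
rng = random.Random(0)
worst_sampled = max(mixture_loss(random_simplex_point(3, rng), client_losses)
                    for _ in range(10_000))
print(agnostic_loss(client_losses), worst_sampled)   # sampled mixtures stay below the max, up to rounding
```

Restricting the weights to a smaller set than the full simplex would make the agnostic loss smaller than `max(losses)`, which is why the simplification to the full simplex matters.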

3 Proposed Formulation

As stated before, treating each client as a separate domain yields loose generalization bounds. Hence, we treat collections of clients as domains, which naturally leads to two types of partitions. Let there be $p$ domains $\mathcal{D}_1, \ldots, \mathcal{D}_p$.

  1. data partition: Each client has data from one or more domains, and domains represent different types of data. For example, for virtual keyboard applications [hard2018federated], the domains could be the application source of client inputs, such as messaging, emails, or documents. In this case, the data distribution $\mathcal{C}_i$ of client $i$ is a mixture of the domain distributions:

    $$\mathcal{C}_i = \sum_{k=1}^p \alpha_{i,k}\, \mathcal{D}_k,$$

    where $\alpha_{i,k} \ge 0$ and $\sum_{k=1}^p \alpha_{i,k} = 1$ for all $i$.

  2. client partition: Each client has data from exactly one domain, and domains represent clusters of clients. For example, clustering clients based on their geographic location yields this domain type. In this case, the data distribution $\mathcal{C}_i$ of client $i$ is given by

    $$\mathcal{C}_i = \mathcal{D}_{k(i)},$$

    where $k(i) \in \{1, \ldots, p\}$ is the domain of client $i$.

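In both of the above formulations, even though there are $N$ different clients, the number of underlying distinct domains is $p$, which we argue is considerably smaller ($p \ll N$). Hence, we have a large number of samples from each of the domains and obtain strong generalization bounds. Since each client distribution can be written as a linear combination of domain distributions, any mixture of client distributions is also a mixture of domain distributions, and therefore, for any hypothesis $h$,

$$\max_{\lambda \in \Delta_N} \mathcal{L}_{\mathcal{C}_\lambda}(h) \le \max_{\lambda \in \Delta_p} \mathcal{L}_{\mathcal{D}_\lambda}(h), \qquad (2)$$

where $\mathcal{C}_\lambda = \sum_{i=1}^N \lambda_i \mathcal{C}_i$ and $\mathcal{D}_\lambda = \sum_{k=1}^p \lambda_k \mathcal{D}_k$.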
However, we do not have access to the true domain distributions $\mathcal{D}_k$; instead, we have samples from them, and we denote by $\hat{\mathcal{D}}_k$ the empirical distribution obtained by pooling all the data of domain $k$. Let $n_k$ be the number of samples in domain $k$. By (2), the true agnostic loss over clients is at most the true agnostic loss over domains. Hence, we propose to minimize the empirical agnostic loss over domains,

$$\min_{h \in \mathcal{H}} \max_{\lambda \in \Delta_p} \mathcal{L}_{\hat{\mathcal{D}}_\lambda}(h), \qquad (1)$$

where $\hat{\mathcal{D}}_\lambda = \sum_{k=1}^p \lambda_k \hat{\mathcal{D}}_k$. The previously known generalization bounds of [mohri2019agnostic, Lemma 3, Corollary 4] yield the following guarantee: for any $\delta > 0$, with probability at least $1 - \delta$, for any client $i$ and any hypothesis $h \in \mathcal{H}$, the expected loss of $h$ on client $i$ is bounded by the empirical agnostic loss over domains plus a complexity term involving a constant $c$, which depends on the maximum value of the loss, and the VC dimension $d$ of the hypothesis class $\mathcal{H}$. This bound scales inversely with $\min_k n_k$, the minimum number of samples in any domain. Since the number of domains is small, as long as the domains are well distributed, we have a good number of samples from every domain and hence a strong generalization bound in the cross-device setting.
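Under the data partition, each client's loss is a convex combination of domain losses (under the client partition, the combination puts all its weight on a single domain), so no mixture of clients can have a larger loss than the worst domain. The sketch below illustrates this with hypothetical domain losses and randomly drawn client mixing weights; it is a numerical sanity check under these assumptions, not part of the algorithm.

```python
import random

def client_loss(alpha_i, domain_losses):
    """Loss of h on a client whose distribution mixes the domains with weights alpha_i."""
    return sum(a * loss for a, loss in zip(alpha_i, domain_losses))

rng = random.Random(1)
p, num_clients = 3, 1000
domain_losses = [0.9, 2.1, 1.4]   # hypothetical true losses of a fixed h on each domain

alphas = []
for _ in range(num_clients):
    raw = [rng.random() for _ in range(p)]
    total = sum(raw)
    alphas.append([r / total for r in raw])   # data partition: mixing weights on the simplex

# Each agnostic loss is a max of a linear function over a simplex, i.e. a max over vertices:
# the worst client for the client-level loss, the worst domain for the domain-level loss.
client_agnostic = max(client_loss(a, domain_losses) for a in alphas)
domain_agnostic = max(domain_losses)
print(client_agnostic, domain_agnostic)   # client_agnostic never exceeds domain_agnostic (up to rounding)
```

Replacing the random `alphas` with one-hot vectors models the client partition; in that case the two maxima coincide whenever every domain has at least one client.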

We now propose a communication-efficient algorithm to minimize the agnostic loss (1) in the cross-device setting.

4 AgnosticFedAvg

[mohri2019agnostic] showed that agnostic learning can be treated as a two-player game, where a learner tries to find the best hypothesis and an adversary tries to find the domain weights that maximize the loss. They proposed a stochastic mirror descent algorithm and showed that the objective reaches the optimum value at a rate of $O(1/\sqrt{T})$ after $T$ steps of training. In practice, a direct implementation of their work is not communication-efficient in the cross-device setting, as the number of steps can be on the order of millions. Moreover, it requires the clients to reveal their domain to the server, which can be privacy-invasive. To overcome these issues, we propose an algorithm called AgnosticFedAvg that is communication-efficient and can be used with privacy-preserving techniques such as secure aggregation. We design AgnosticFedAvg with the following properties.

  • Each round of FL uses only a single round of transmission, which includes model download from the server to the clients and upload from the clients to the server.

  • The clients train with multiple local SGD steps similar to FederatedAveraging (or FedAvg) of [McMahanMooreRamageHampsonAguera2017].

  • The server does not have access to individual clients' data, but only to aggregated statistics, making it compatible with cryptographic techniques such as secure aggregation [bonawitz2017practical], which prevents anyone from retrieving or rebuilding privacy-sensitive information from individual client parameter updates.

Let $\Theta$ be the set of parameters of the hypothesis class. The algorithm first initializes the model parameters to $\theta_0 \in \Theta$, the domain weights to $\lambda_0$ in the probability simplex over the $p$ domains, and the number of examples per domain to $m_0$, where $m_{t,k}$ denotes the number of samples for domain $k$ at round $t$ and $m^j_k$ denotes the number of samples of client $j$ from domain $k$. We keep a sliding window of the number of examples per domain over the last $K$ training rounds and denote its average by $\bar{m}_{t,k}$. The algorithm uses learning rate $\gamma$ for learning the domain weights. In the following, let $\ell(\theta; x, y)$ denote the loss function as a function of the hypothesis parameter $\theta$.

At each round of training $t$, the algorithm computes a scaling vector $v_t$ by taking the ratio of the domain weights and the average number of samples per domain over the last $K$ rounds, $v_{t,k} = \lambda_{t,k} / \bar{m}_{t,k}$. The algorithm then selects a random set $J_t$ of clients and sends the parameters $\theta_t$ and the scaling vector $v_t$ to each of them. First, each selected client $j$ computes, on its local dataset, the number of samples per domain $m^j_k$, the initial loss per domain $L^j_k$ (the sum of the losses at $\theta_t$ over its samples from domain $k$), and the scaled client weight $w^j = \sum_{k=1}^p v_{t,k}\, m^j_k$. Then, each client updates the parameters on its $v_t$-weighted local loss by running $E$ epochs of SGD with batch size $B$ and learning rate $\eta$. Finally, the client transmits the updated parameters $\theta^j$, the client weight $w^j$, the initial loss per domain $L^j_k$, and the number of samples per domain $m^j_k$ back to the server. Since this is done using secure aggregation, the server only observes the total number of samples and the total loss per domain across clients. The server then computes the new parameters $\theta_{t+1}$ by averaging the client updates weighted by $w^j$, and does an exponentiated gradient (EG) step for the domain weights:

$$\lambda_{t+1,k} = \frac{\lambda_{t,k} \exp(\gamma \bar{L}_{t,k})}{\sum_{k'=1}^p \lambda_{t,k'} \exp(\gamma \bar{L}_{t,k'})}, \qquad \bar{L}_{t,k} = \frac{1}{m_{t,k}} \sum_{j \in J_t} L^j_k, \qquad m_{t,k} = \sum_{j \in J_t} m^j_k.$$

If a round does not have any samples from a particular domain $k$, we set $\bar{L}_{t,k}$ to zero for that round. This process is repeated for $T$ rounds.
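The following Python sketch simulates the round structure described above on a one-dimensional toy problem, to show how the scaling vector (domain weights divided by the windowed average number of examples per domain), the scaled client weights, the weighted averaging, and the EG step fit together. It is a minimal sketch under several assumptions that are not taken from the paper: the model is a single scalar with squared loss; the scaled client weight is the sum over domains of the scaling factor times the client's domain count; each client runs minibatch SGD on its scaled loss normalized by that weight; the sliding window starts from a count of one per domain; the server applies the weighted average directly instead of a server optimizer; and all function names and hyperparameters are hypothetical.

```python
import math
import random
from collections import deque

def client_update(data, theta, v, epochs, batch_size, lr, rng):
    """Client side: per-domain counts and initial losses, scaled weight, local SGD."""
    p = len(v)
    m = [0] * p                  # samples of each domain on this client
    L = [0.0] * p                # summed loss at the received theta, per domain
    for x, k in data:
        m[k] += 1
        L[k] += (theta - x) ** 2   # squared loss of the 1-D toy model
    w = sum(v[k] * m[k] for k in range(p))   # scaled client weight
    if w == 0:
        return theta, 0.0, L, m
    data = list(data)            # copy so shuffling leaves the client's data untouched
    for _ in range(epochs):
        rng.shuffle(data)
        for s in range(0, len(data), batch_size):
            batch = data[s:s + batch_size]
            # Minibatch estimate of the gradient of (1/w) * sum_k v_k * sum_{x in domain k} (theta - x)^2.
            scale = len(data) / (len(batch) * w)
            theta -= lr * scale * sum(v[k] * 2.0 * (theta - x) for x, k in batch)
    return theta, w, L, m

def agnostic_fedavg(clients, p, rounds, clients_per_round, gamma,
                    window=10, epochs=1, batch_size=4, lr=0.05, seed=0):
    rng = random.Random(seed)
    theta = 0.0
    lam = [1.0 / p] * p                            # domain weights on the simplex
    counts = deque([[1.0] * p], maxlen=window)     # sliding window of per-domain counts
    for _ in range(rounds):
        m_bar = [sum(c[k] for c in counts) / len(counts) for k in range(p)]
        v = [lam[k] / m_bar[k] if m_bar[k] > 0 else 0.0 for k in range(p)]
        selected = rng.sample(clients, clients_per_round)
        results = [client_update(c, theta, v, epochs, batch_size, lr, rng) for c in selected]
        # The server only uses sums over clients, which secure aggregation could provide.
        w_sum = sum(r[1] for r in results)
        m_t = [sum(r[3][k] for r in results) for k in range(p)]
        L_t = [sum(r[2][k] for r in results) for k in range(p)]
        if w_sum > 0:
            theta = sum(r[1] * r[0] for r in results) / w_sum   # weighted by client weights
        counts.append(m_t)
        # Average loss per domain, set to zero when the round has no samples of that domain.
        L_bar = [L_t[k] / m_t[k] if m_t[k] > 0 else 0.0 for k in range(p)]
        # Exponentiated-gradient ascent step on the domain weights.
        unnorm = [lam[k] * math.exp(gamma * L_bar[k]) for k in range(p)]
        total = sum(unnorm)
        lam = [u / total for u in unnorm]
    return theta, lam

# Hypothetical data partition: domain 0 points sit at 0.0, domain 1 points at 4.0,
# and domain 0 holds roughly ten times more data. The min-max solution is theta = 2.
data_rng = random.Random(42)
clients = [[(0.0, 0)] * data_rng.randint(5, 15) + [(4.0, 1)] * data_rng.randint(0, 2)
           for _ in range(200)]
theta, lam = agnostic_fedavg(clients, p=2, rounds=300, clients_per_round=20, gamma=0.01)
print(theta, lam)
```

Because domain 0 holds most of the data, weighting every example equally would keep the model close to the majority domain; in this sketch we expect the EG step to raise the weight of the higher-loss domain until the two domain losses roughly balance near the min-max solution of 2.0.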

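To see why the above algorithm aims to minimize the agnostic loss, let $J_t$ be the set of clients selected at round $t$, let $S_{j,k}$ be the samples of client $j$ from domain $k$, and let $v_{t,k} = \lambda_{t,k} / \bar{m}_{t,k}$ be the scaling factor of domain $k$. Consider the weighted sum of all the client losses:

$$\sum_{j \in J_t} \sum_{k=1}^p v_{t,k} \sum_{(x, y) \in S_{j,k}} \ell(\theta; x, y) = \sum_{k=1}^p \frac{\lambda_{t,k}}{\bar{m}_{t,k}}\, m_{t,k}\, \bar{L}_{t,k}(\theta) \approx \sum_{k=1}^p \lambda_{t,k}\, \bar{L}_{t,k}(\theta),$$

where $\bar{L}_{t,k}(\theta)$ is the average loss for domain $k$ and $m_{t,k}$ is the number of samples in domain $k$ from the selected clients at round $t$. The approximation assumes that the moving average $\bar{m}_{t,k}$ is close to $m_{t,k}$. The clients therefore decrease the $\lambda_t$-weighted mixture of domain losses, while the EG step shifts $\lambda_t$ toward the domains with higher loss. Thus, AgnosticFedAvg aims to minimize the domain agnostic objective defined in (1). We further note that by using secure aggregation [bonawitz2017practical], the server only observes aggregated statistics rather than the domains or gradients of individual clients, which provides an additional layer of privacy.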
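The identity behind this approximation can be checked with a few numbers: when the windowed average count equals the current round's count for every domain, the scaled sum of per-example losses equals the mixture of per-domain average losses exactly, and a stale average only perturbs it. The values and function names below are hypothetical.

```python
def v_weighted_loss_sum(lam, m_bar, losses_by_domain):
    """sum_k v_k * (sum of per-example losses in domain k), where v_k = lam_k / m_bar_k."""
    return sum((l / mb) * sum(losses)
               for l, mb, losses in zip(lam, m_bar, losses_by_domain))

def domain_mixture_loss(lam, losses_by_domain):
    """sum_k lam_k * (average per-example loss in domain k)."""
    return sum(l * sum(losses) / len(losses)
               for l, losses in zip(lam, losses_by_domain))

lam = [0.3, 0.7]                                        # hypothetical domain weights
losses_by_domain = [[1.0, 2.0, 3.0, 2.0], [4.0, 6.0]]   # hypothetical losses pooled over selected clients
target = domain_mixture_loss(lam, losses_by_domain)          # 0.3 * 2.0 + 0.7 * 5.0
exact = v_weighted_loss_sum(lam, [4, 2], losses_by_domain)   # window average equals this round's counts
stale = v_weighted_loss_sum(lam, [5, 2], losses_by_domain)   # window average overestimates domain 0
print(target, exact, stale)
```

Pooling the losses over clients is what makes the left-hand side computable from per-client contributions: each client only needs its own per-domain sums, and the server only needs their total.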

5 Experiments

algorithm answer question difference
perp. acc. perp. acc. perp. acc.
FedAvg (uniform)
FedAvg (answer)
Table 1: Perplexity and in-vocab-accuracy on the Stack Overflow test dataset, with the standard deviation over three trials in parentheses. AgnosticFedAvg attains the lowest perplexity on the harder domain, answer.
algorithm es-AR es-419 es-US
perp. acc. perp. acc. perp. acc.
FedAvg (uniform)
FedAvg (es-AR)
Table 2: Perplexity and in-vocab-accuracy for the Spanish virtual keyboard. AgnosticFedAvg attains the lowest perplexity on the harder domain, es-AR.
train held-out test
clients 342K 38.8K 204K
sentences 136M 16.5M 16.6M
answers 78.0M 9.33M 9.07M
questions 57.8M 7.17M 7.52M
Table 3: Statistics per domain in the Stack Overflow dataset.

We implemented all algorithms and simulation experiments using the open-source FedJAX [fedjax2020github] library, and we report results for the Stack Overflow language modeling task and for a live experiment training a Spanish language model on millions of virtual keyboard user devices. For all experiments, we compare three algorithms:

  • FedAvg (uniform): Trained uniformly on all available data.

  • FedAvg (target-only): Trained only on data from the target domain.

  • AgnosticFedAvg: Trained on all available data.

We demonstrate that, in both experiments, AgnosticFedAvg attains a lower perplexity on the harder domain than FedAvg (uniform) and FedAvg (target-only): the answer domain for Stack Overflow and es-AR for the Spanish language model.

To verify that AgnosticFedAvg correctly minimizes the domain agnostic objective and to showcase its effectiveness on non-language tasks, we also include experiments on a synthetic toy regression example and the EMNIST-62 image recognition task, respectively (presented in the attached supplementary material due to lack of space).

5.1 Stack Overflow Language Model

Motivated by FL uses in virtual keyboard applications [hard2018federated], we use AgnosticFedAvg to train a language model on a large text corpus. We consider the language model task for the Stack Overflow dataset provided by TFF [tff2019]. This dataset contains two domains, questions and answers, from the Stack Overflow forum grouped by client ids. This corresponds to the data partition domain type since an individual client can post both questions and answers. Table 3 summarizes the statistics per domain.

We match the model and training setup from [reddi2020adaptive] and train a single-layer LSTM language model over the top K words with an Adam server optimizer and clients participating per training round for rounds. For AgnosticFedAvg, we use the same setup with domain weight learning rate .

For the Stack Overflow experiments, we report perplexity and in-vocab-accuracy, where in-vocab-accuracy is the number of correct predictions, excluding UNK (out-of-vocabulary) and EOS (end-of-sentence) tokens, divided by the number of words excluding the EOS token. Defining in-vocab-accuracy this way allows valid comparisons across different vocabulary sizes. The results are in Table 1. For the baseline FedAvg (uniform), the answer domain is the harder of the two domains, with higher perplexity and lower accuracy. Given this, we also train an additional baseline, FedAvg (answer), on answer examples only. While FedAvg (answer) does improve answer performance over FedAvg (uniform), it results in significantly worse question performance. In contrast, AgnosticFedAvg outperforms both FedAvg (uniform) and FedAvg (answer) on the answer domain, while also significantly decreasing the performance disparity between answers and questions. This suggests that questions may contain features that improve performance on answers: AgnosticFedAvg leverages them, whereas FedAvg (uniform) does not weight them optimally and FedAvg (answer) ignores them completely.
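A minimal Python sketch of in-vocab-accuracy, under our reading of the definition above: a prediction counts as correct only if it matches a target that is neither UNK nor EOS, and the denominator counts every target word except EOS (so UNK targets can never be predicted correctly). The token strings and function name are hypothetical, and the metric implementation used in the experiments may differ.

```python
UNK, EOS = "<unk>", "<eos>"   # hypothetical special-token strings

def in_vocab_accuracy(predictions, targets):
    """Correct predictions on non-UNK, non-EOS targets, over all target words except EOS."""
    correct = sum(1 for pred, tgt in zip(predictions, targets)
                  if pred == tgt and tgt not in (UNK, EOS))
    total = sum(1 for tgt in targets if tgt != EOS)
    return correct / total if total else 0.0

targets = ["how", "do", "i", "<unk>", "this", "<eos>"]
predictions = ["how", "to", "i", "<unk>", "this", "<eos>"]
print(in_vocab_accuracy(predictions, targets))   # 3 correct in-vocab words out of 5 non-EOS words
```

Keeping UNK targets in the denominator means a model with a larger vocabulary is not rewarded for predicting UNK, which is what makes the metric comparable across vocabulary sizes.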

5.2 Spanish Virtual Keyboard Language Model

We further use AgnosticFedAvg to train a Coupled Input and Forget Gate (CIFG) [greff2017lstm] language model for Spanish on virtual keyboard client devices. Following the settings in [hard2018federated], text data for FL is stored in local caches on device. Clients for this study must meet the following FL requirements: (1) having 2 gigabytes of memory; (2) being connected to an un-metered network; (3) being charged; and (4) being idle. We consider three domains based on the Spanish locales es-US, es-AR, and es-419, which correspond to the US, Argentina, and Latin America and the Caribbean, respectively (region codes as defined by UN M.49). Since each user device falls in a single region, this task corresponds to the client partition domain type.

Similar to Section 5.1, we report perplexity and in-vocab-accuracy. For all algorithms, we use the Momentum server optimizer with Nesterov accelerated gradient [nesterov], and clients participating per training round for rounds. Over the course of training, approximately 141 million sentences are processed by 1.5 million clients. The results are in Table 2. For the baseline FedAvg (uniform), es-AR has the worst perplexity of the three locales. As in Section 5.1, training FedAvg (es-AR) only on es-AR clients improves es-AR performance over FedAvg (uniform), but it results in much worse performance on es-US and es-419. Again, AgnosticFedAvg improves the perplexity and accuracy on es-AR over FedAvg (uniform), while also reducing the regression on es-US and es-419 compared with FedAvg (es-AR).

6 Conclusion

We presented an algorithmic study of domain agnostic learning in the cross-device FL setting. We also examined two types of naturally occurring domains in FL, data partition and client partition, and provided example learning tasks for both in large-scale language modeling. Finally, we defined AgnosticFedAvg, a communication-efficient federated algorithm that aims to minimize the domain agnostic objective proposed in [mohri2019agnostic] and can provide additional security through secure aggregation, and we demonstrated its practical effectiveness in simulations and live experiments. We hope that our efforts will spur further studies into improving the practical efficiency of FL algorithms.


Appendix A Toy Regression

We first evaluate AgnosticFedAvg on a toy regression task to verify its correctness. We consider a simple regression example in which each domain is a set of random points. Let each domain $k \in \{1, \ldots, p\}$ be a set of points, and let $c_k$ be the center of the points in domain $k$. We distribute these points across clients randomly. The goal is to find the point $w$ that minimizes the maximum distance to all the domain centers, i.e.,

$$\min_{w} \max_{k} \| w - c_k \|.$$

It is easy to see that

$$\max_{k} \| w - c_k \| = \max_{\lambda \in \Delta_p} \sum_{k=1}^p \lambda_k \| w - c_k \|,$$

since a linear function of $\lambda$ attains its maximum over the simplex at a vertex; thus, we optimize the latter min-max objective with AgnosticFedAvg. We choose the points such that the true answer is known and plot the performance of AgnosticFedAvg for 5 domains in Figure 1. As expected, AgnosticFedAvg converges to the true solution.
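For intuition, the sketch below works through a one-dimensional instance of this objective; the dimension is an assumption made for brevity, since the section does not fix it. It locates the minimizer of the maximum distance by grid search, which in one dimension is the midpoint of the two extreme centers, and checks that no mixture of the per-domain distances exceeds the maximum. The centers and names are hypothetical, and this is a direct evaluation of the objective, not AgnosticFedAvg itself.

```python
import random

def max_distance(w, centers):
    """Toy objective: the largest distance from w to any domain center."""
    return max(abs(w - c) for c in centers)

def mixture_distance(w, lam, centers):
    """Inner objective of the min-max form: sum_k lam_k * |w - c_k|."""
    return sum(l * abs(w - c) for l, c in zip(lam, centers))

centers = [0.5, -1.0, 2.0, 0.3, 1.2]   # hypothetical 1-D centers of 5 domains
grid = [i / 1000 for i in range(-3000, 3001)]
w_star = min(grid, key=lambda w: max_distance(w, centers))

rng = random.Random(0)
gaps = []
for _ in range(1000):
    w = rng.uniform(-3.0, 3.0)
    e = [rng.expovariate(1.0) for _ in centers]
    total = sum(e)
    lam = [x / total for x in e]            # random point of the simplex
    gaps.append(max_distance(w, centers) - mixture_distance(w, lam, centers))
print(w_star, min(gaps))   # w_star is the midpoint of -1.0 and 2.0; gaps are non-negative up to rounding
```

At the optimum, an adversary that puts equal weight on the two extreme centers already attains the maximum, which is the behavior the domain weights in Figure 1 are expected to approach.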

Appendix B EMNIST-62 Image Recognition

We consider the image recognition task for the EMNIST-62 dataset provided by TensorFlow Federated (TFF) [tff2019]. This dataset consists of 3,400 writers and their writing samples, each of which belongs to one of 62 alphanumeric classes. According to the original NIST source documentation, the writers come from two distinct sources: high school and census field. This corresponds to the client partition domain type, since a given client can belong to only a single domain. Table 4 summarizes the number of clients and examples per domain.

We match the model and training setup from [reddi2020adaptive] and train a convolutional neural network with an Adam server optimizer and clients participating per training round for rounds. For AgnosticFedAvg, we use the same setup with domain weight learning rate . [reddi2020adaptive] provides a comprehensive overview of different server optimizers and their respective performances. For our experiments, we use the Adam server optimizer, as it was shown to produce the highest accuracy.

The results are in Table 5. For the baseline FedAvg (uniform), the high school domain is the harder of the two domains, with lower accuracy, most likely because it has fewer clients and training examples. In light of this, we also train FedAvg (high school) only on clients from the high school domain. While FedAvg (high school) does improve high school performance over FedAvg (uniform), it results in drastically worse accuracy on the census domain. This is somewhat expected, as the number of census clients far exceeds the number of high school clients. AgnosticFedAvg not only outperforms FedAvg (uniform) on the high school domain, but also significantly decreases the accuracy gap between high school and census relative to both FedAvg (uniform) and FedAvg (high school).

Figure 1: Left: Learned value over federated training rounds. Right: Domain weights over federated training rounds.
train test
high school clients 500 500
census clients 2900 2900
high school examples 68.8K 8.7K
census examples 597K 74.4K
Table 4: Statistics per domain in the EMNIST-62 dataset.
algorithm high school census difference
FedAvg (uniform)
FedAvg (high school)
Table 5: Accuracy on the EMNIST-62 test dataset, with the standard deviation over three trials in parentheses.