Federated Learning (FL) fed-learning:google is a distributed and decentralized protocol for training machine learning models. A set of participating agents can jointly train a model without sharing their local data with each other or with any third party. In this regard, FL differs from the traditional distributed learning setting, in which data is first centralized and then distributed to the agents dean2012large; li2014scaling. Consequently, FL is expected to have prominent applications in settings where data privacy is a concern.
However, some recent works have observed that the accuracy of models trained via FL drops significantly as the local data distributions of the agents differ zhao2018federated:noniid; hsieh2019non:noniid. In this work, we study this phenomenon and identify a few techniques that improve the performance of FL when data is distributed in a non-i.i.d. fashion among the participating agents. In contrast to some recent works, the techniques we identify incur no extra communication overhead in FL: the server does not have to transmit any data to the agents as in zhao2018federated:noniid, and the trained model does not have to be shuffled across participating agents as in hsieh2019non:noniid. In brief, the techniques we identify improve the performance of FL at the cost of only some computation overhead, either on the client side or the server side. Further, they can easily be incorporated into FL without any structural changes.
We organize the rest of the paper as follows. In Section 2, we provide the necessary background on FL. In Section 3, we discuss and explain our techniques to improve FL in the non-i.i.d. setting. In Section 4, we provide the experimental results and demonstrate the performance improvement in FL due to our techniques. Finally, in Section 5, we provide a few concluding remarks.
2.1 Federated Learning (FL)
At a high level, FL is a multi-round protocol between an aggregation server and a set of agents in which the agents jointly train a machine learning model. Formally, the participating agents try to minimize the average of their loss functions,

$$\min_{w} f(w) = \frac{1}{K} \sum_{k=1}^{K} F_k(w),$$

where $F_k$ is the loss function of the $k$th agent. For example, for neural networks, $F_k$ is typically the empirical risk under a loss function $\ell$, i.e.,

$$F_k(w) = \frac{1}{n_k} \sum_{i=1}^{n_k} \ell(w; x_i),$$

with $n_k$ being the total number of samples in the $k$th agent's dataset and $x_i$ being the $i$th sample.
A round in FL starts with the server picking a subset $S_t$ of agents and sending them the current weights of the model, $w_t$. After receiving the weights, each agent initializes its local model with $w_t$ and trains for some number of iterations on its local dataset, for example via stochastic gradient descent (SGD). At the end of its local training, the $i$th agent computes its update for the current round as $\Delta_t^i = w_t^i - w_t$, where $w_t^i$ denotes its locally trained weights, and sends the update back to the server. Once the server receives updates from all the agents in $S_t$, it computes the weights for the next round via weighted averaging (technically, the server can use an arbitrary aggregation function, but we stick to weighted averaging in this work as it is the most common one) as,

$$w_{t+1} = w_t + \sum_{i \in S_t} \frac{n_i}{n} \Delta_t^i, \tag{1}$$

where $n_i$ is the number of samples in the dataset of agent $i$ and $n = \sum_{i \in S_t} n_i$. Aggregating the updates in this way has been referred to as FedAvg fed-learning:google.
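As a concrete illustration, the weighted-averaging step of equation 1 can be sketched as follows. This is a minimal NumPy sketch with our own function and variable names, not code from any particular FL library:

```python
import numpy as np

def fedavg(w_t, updates, num_samples):
    """One FedAvg aggregation step (equation 1).

    w_t         : current global weights, a flat parameter vector
    updates     : list of agent updates, Delta_t^i = w_t^i - w_t
    num_samples : list of local dataset sizes n_i, one per agent
    """
    n = sum(num_samples)
    # Weighted average of the agent updates, each weighted by n_i / n.
    aggregate = sum((n_i / n) * delta for n_i, delta in zip(num_samples, updates))
    return w_t + aggregate
```

In practice each update would be a collection of layer tensors rather than a single flat vector, but the weighting logic is identical.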
It has been shown that models trained via FL can perform better than locally trained models in various settings fed-learning:google; keyboard. In contrast, as noted before, it has also been observed that the performance of FL drops drastically when the local data distributions of the agents differ significantly, i.e., when data is distributed in a non-i.i.d. fashion among the agents zhao2018federated:noniid; hsieh2019non:noniid.
In this section, we discuss a few possible causes for performance degradation of FL in non-i.i.d. setting, and propose techniques to alleviate these causes.
3.1 Hypothesis Conflicts and Server-Side Training
We hypothesize that one cause behind the performance degradation in the non-i.i.d. setting is hypothesis conflicts between agents' local models. For example, for a classification task, we say two models have conflicting hypotheses if there exists an input to which the models assign different classes. Concretely, let $w_t^i$ and $w_t^j$ be the local models of agents $i$ and $j$ at the end of round $t$, respectively. Then, we say there is a hypothesis conflict between these models at round $t$ if there exists an input $x$ such that $f(x; w_t^i) \neq f(x; w_t^j)$. We illustrate how hypothesis conflicts can be problematic on a toy example in Figure 1.
An illustration of how hypothesis conflicts can lead to performance degradation. In (a), we have a simple dataset of four points with two features. As can be seen, a decision tree can easily learn the rules to classify this dataset perfectly. In (b), the dataset is partitioned such that the first agent gets the points satisfying one feature condition (left), and the second agent gets the remaining points (right). As can be seen, the agents' local models learn conflicting hypotheses: the first agent predicts (-) under its condition, and the second agent predicts (+) under the complementary condition. If we were to do ensemble prediction with these models, their votes would simply cancel each other.
Given the discussion above, we argue that having a small, balanced training dataset on the server side can improve performance in the non-i.i.d. setting significantly. That is, the central server can aggregate the updates as usual and then fine-tune the resulting model on such data by training it for a short period. If the training data on the server side has a distribution similar to that of the union of the agents' datasets, then the parameters learned on this data will likely perform well on the agents' local data, and the agents' subsequent updates should not overwrite the mappings learned during server-side training. We believe this can reduce hypothesis conflicts and consequently improve the performance of the trained models.
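The aggregate-then-fine-tune step can be sketched as follows. This is a minimal sketch under our own naming: `grad_fn` is a hypothetical function assumed to return the gradient of the loss at the given weights on a batch, and the hyperparameter values are placeholders:

```python
import numpy as np

def aggregate_then_finetune(w_t, updates, num_samples, server_batches,
                            grad_fn, lr=0.5, steps=5):
    """Aggregate agent updates as usual, then fine-tune the result on a
    small, balanced dataset held at the server.

    server_batches    : batches of the small server-side dataset
    grad_fn(w, batch) : assumed to return the loss gradient at w on batch
    """
    n = sum(num_samples)
    # Standard FedAvg aggregation (equation 1).
    w = w_t + sum((n_i / n) * d for n_i, d in zip(num_samples, updates))
    # Short server-side fine-tuning pass via plain gradient descent.
    for _ in range(steps):
        for batch in server_batches:
            w = w - lr * grad_fn(w, batch)
    return w
```

The fine-tuning budget (`lr`, `steps`) should be kept small so the server-side pass refines, rather than replaces, what the agents learned.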
3.2 Projected Gradient Descent at Agents’ Side
Another reason for the performance degradation might stem from the fact that the local loss surfaces of the agents are different. Each local model might reach a local minimum on its own surface; however, the aggregated model might not be close to any local minimum on the loss surface defined by the union of the local datasets. Indeed, in a recent work zhao2018federated:noniid, the authors hypothesize that FL performs worse in the non-i.i.d. setting because the local models' parameters differ significantly from each other (see Figure 2).
We believe that this divergence can be alleviated to some extent by having each agent run projected gradient descent with some common parameters, e.g., the norm of each local model can be bounded by a common value $C$. Since each local model is initialized with the same parameters, and each agent uses the same projection, we believe this might prevent the local models from diverging too much from each other. Consequently, this can improve the performance of the trained models.
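As a sketch, the projection each agent would apply during local training could look as follows, assuming an L2-norm bound (the function name is ours):

```python
import numpy as np

def project_to_ball(w, radius):
    """Project the parameter vector w back onto the L2 ball of the given
    radius; w is returned unchanged if it is already inside the ball."""
    norm = np.linalg.norm(w)
    if norm > radius:
        return w * (radius / norm)
    return w
```

Each agent would call this after its SGD iterations with the same common radius, so that all local models remain within the same ball.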
3.3 Server-side Momentum
It is well known that momentum techniques can significantly improve convergence by making gradient descent robust to certain elements of the loss surface, e.g., they can dampen oscillations in valleys. Given the differences between the local loss surfaces, we believe momentum might have a similar effect in the non-i.i.d. case. To adapt momentum to FL, we keep a running average of the aggregated updates. That is, with momentum, the update rule for FedAvg (equation 1) can be written as,

$$v_{t+1} = \beta\, v_t + \sum_{i \in S_t} \frac{n_i}{n} \Delta_t^i, \qquad w_{t+1} = w_t + v_{t+1},$$

where $\beta$ is the server's momentum constant and $v_0 = 0$.
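The update rule above can be sketched as follows (a minimal sketch with our own names; `velocity` plays the role of $v_t$ and should be initialized to a zero vector):

```python
import numpy as np

def fedavg_with_momentum(w_t, velocity, updates, num_samples, beta=0.5):
    """One FedAvg step with server-side momentum.

    Returns the new weights and the new velocity, i.e. the running
    average of the aggregated updates.
    """
    n = sum(num_samples)
    aggregate = sum((n_i / n) * d for n_i, d in zip(num_samples, updates))
    velocity = beta * velocity + aggregate
    return w_t + velocity, velocity
```

The server keeps `velocity` across rounds, so consistent update directions accumulate while oscillating ones are dampened.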
3.4 Preventing Parameter Cancellation
Following the argument in Figure 2, we predict that as the local models diverge, they could cancel each other's parameters when they are aggregated. To prevent this, we can simply look at the sign information of the updates and adjust the learning rate at the server. That is, we can set a threshold $\theta$, and for every dimension where the absolute sum of the signs of the updates is less than this threshold, set the learning rate to 0. With this, the server's learning rate for the $i$th dimension is given by,

$$\eta_i = \begin{cases} \eta & \text{if } \left|\sum_{k \in S_t} \mathrm{sgn}\!\left(\Delta_t^{k}[i]\right)\right| \geq \theta, \\ 0 & \text{otherwise.} \end{cases}$$
With such a learning rate, FedAvg can be written as,

$$w_{t+1} = w_t + \boldsymbol{\eta} \odot \sum_{i \in S_t} \frac{n_i}{n} \Delta_t^i,$$

where $\boldsymbol{\eta}$ is the learning-rate vector over all dimensions and $\odot$ is the element-wise product operation.
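The per-dimension learning rate and the resulting aggregation can be sketched as follows (a minimal sketch assuming an absolute sign-sum threshold; all names are ours):

```python
import numpy as np

def fedavg_adjustable_lr(w_t, updates, num_samples, eta=1.0, theta=2):
    """FedAvg where the server's per-dimension learning rate is zeroed
    wherever the agents' update signs disagree too much."""
    n = sum(num_samples)
    updates = np.stack(updates)                     # shape (K, d)
    # Absolute sum of update signs per dimension; small values indicate
    # that agents disagree on the direction for that parameter.
    sign_agreement = np.abs(np.sign(updates).sum(axis=0))
    lr = np.where(sign_agreement >= theta, eta, 0.0)
    aggregate = (np.array(num_samples)[:, None] / n * updates).sum(axis=0)
    return w_t + lr * aggregate
```

Dimensions on which the agents' updates would largely cancel each other are thus simply frozen for that round.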
We note that the same technique has been proposed as a defense against backdoor attacks in the FL setting safa2020defending. Although the authors note that their technique can potentially improve the performance of models in the non-i.i.d. setting as well, they provide no empirical evaluation.
In this section, we illustrate the performance of our techniques from Section 3 via experiments. We first look at each technique in isolation, and then in combination with each other.
In our experiments, we simulate FL among a set of agents for multiple rounds. At each round, agents locally train for a fixed number of epochs with a fixed batch size, learning rate, and weight decay before sending their updates. Upon receiving and aggregating the updates, we measure accuracy on a held-out validation set. We use the FedAvg aggregation (equation 1). For the dataset, we use CIFAR10 cifar10, and for the model, we use ResNet20 with fixup initialization zhang2019fixup. Each experiment is repeated 3 times, and we report the mean and standard deviation of the results.
We first provide the results for the baseline in Table 1.
| Setting | Val. Acc. Mean | Val. Acc. STD |
| --- | --- | --- |
| FL - IID | 90.0% | 0.08% |
| FL - NIID(5) | 73.0% | 2.6% |
Table 2 shows the effect of server-side training with a small amount of server-side data.

| Setting | Server-side data | Val. Acc. Mean | Val. Acc. STD |
| --- | --- | --- | --- |
| FL - NIID(5) | 0% | 73.0% | 2.6% |
| FL - NIID(5) | 5% | 83.7% | 0.1% |
We now show the results for projected gradient descent. In this setting, we use a norm bound on the local models: if at any point the norm of its model exceeds the bound, the agent projects the model back onto the ball with the corresponding radius. We have also combined this with noise addition, as some works argue that noise addition improves regularization noh2017regularizing; i.e., at each iteration, the agent samples Gaussian noise with zero mean and some standard deviation, and adds it to its gradients.
| Setting | Norm Threshold | Noise STD | Val. Acc. Mean | Val. Acc. STD |
| --- | --- | --- | --- | --- |
| FL - NIID(5) | 0 | 0 | 73.0% | 2.6% |
| FL - NIID(5) | 3 | 0 | 77.5% | 2.3% |
| FL - NIID(5) | 3 |  | 79.6% | 1.3% |
We now illustrate the effect of server-side momentum in Table 4.
| Setting | Server's Momentum Constant | Val. Acc. Mean | Val. Acc. STD |
| --- | --- | --- | --- |
| FL - NIID(5) | 0 | 73.0% | 2.6% |
| FL - NIID(5) | 0.5 | 80.9% | 2.4% |
| FL - NIID(5) | 0.9 | 75.9% | 1.1% |
Finally, Table 5 shows the results for the adjustable learning rate at the server.

| Setting | Using Adjustable LR | Val. Acc. Mean | Val. Acc. STD |
| --- | --- | --- | --- |
| FL - NIID(5) | No | 73.0% | 2.6% |
| FL - NIID(5) | Yes | 78.5% | 1.72% |
After examining each method individually, we perform a grid search to see whether a combination of them yields better performance. The best combination found by the grid search uses server-side training with 5% of the data, server-side momentum, and the adjustable learning rate at the server, and it improves the mean accuracy by more than 12% over our baseline setting.
In this paper, we studied FL with a particular focus on improving the accuracy of trained models in non-i.i.d. settings. Through our study, we have identified a few simple techniques that improve over our baseline accuracy by more than 12%. The techniques we identify are rather simple and can easily be incorporated into FL without any structural changes. Further, they incur no extra communication overhead, only some light computation overhead on either the client side or the server side. In future work, we hope to expand our experimental setting by testing different models, datasets, and hyperparameters.