Improving Accuracy of Federated Learning in Non-IID Settings

10/14/2020 · Mustafa Safa Ozdayi, et al. · The University of Texas at Dallas

Federated Learning (FL) is a decentralized machine learning protocol that allows a set of participating agents to collaboratively train a model without sharing their data. This makes FL particularly suitable for settings where data privacy is desired. However, it has been observed that the performance of FL is closely tied to the local data distributions of agents. In particular, in settings where local data distributions vastly differ among agents, FL performs rather poorly compared to centralized training. To address this problem, we hypothesize the reasons behind the performance degradation, and develop techniques to address these reasons accordingly. In this work, we identify four simple techniques that can improve the performance of trained models without incurring any additional communication overhead to FL, but rather some light computation overhead either on the client-side or the server-side. In our experimental analysis, a combination of our techniques improved the validation accuracy of a model trained via FL by more than 12% with respect to our baseline. This is about 5% below the accuracy of the same model trained on centralized data.


1 Introduction

Federated Learning (FL) fed-learning:google is a distributed and decentralized protocol for training machine learning models. A set of participating agents can jointly train a model without sharing their local data with each other, or with any other third party. In that regard, FL differs from the traditional distributed learning setting in which data is first centralized and then distributed to the agents dean2012large; li2014scaling. Due to this, FL is expected to have prominent applications in settings where data privacy is of concern.

However, some recent works have observed that the accuracy of models drops significantly in FL as the local data distributions of agents differ zhao2018federated:noniid; hsieh2019non:noniid. In this work, we study this phenomenon, and identify a few techniques that improve the performance of FL when data is distributed in a non-i.i.d. fashion among the participating agents. In contrast to some of the recent works, the techniques we identify incur no extra communication overhead to FL; e.g., the server does not have to transmit any data to agents as in zhao2018federated:noniid, and the trained model does not have to be shuffled across participating agents as in hsieh2019non:noniid. In brief, the techniques we identify improve the performance of FL only by incurring some computation overhead either on the client-side or the server-side. Further, they can easily be incorporated into FL without any structural changes.

We organize the rest of the paper as follows. In Section 2, we provide the necessary background on FL. In Section 3, we discuss and explain our techniques to improve FL in the non-i.i.d. setting. In Section 4, we provide the experimental results, and demonstrate the performance improvement in FL due to our techniques. Finally, in Section 5, we provide a few concluding remarks.

2 Background

2.1 Federated Learning (FL)

At a high level, FL is a multi-round protocol between an aggregation server and a set of agents in which the agents jointly train a machine learning model. Formally, the participating agents try to minimize the average of their loss functions,

$$\min_w F(w) = \frac{1}{K} \sum_{k=1}^{K} F_k(w),$$

where $F_k$ is the loss function of the $k$th agent and $K$ is the number of agents. For example, for neural networks, $F_k$ is typically the empirical risk under a loss function $\ell$, i.e.,

$$F_k(w) = \frac{1}{n_k} \sum_{i=1}^{n_k} \ell(w; x_i),$$

with $n_k$ being the total number of samples in the $k$th agent's dataset and $x_i$ being the $i$th sample.

A round in FL starts by the server picking a subset of agents $S_t$ and sending them the current weights of the model, $w_t$. After receiving the weights, an agent initializes his model with $w_t$, and trains for some number of iterations on his local dataset, for example via stochastic gradient descent (SGD). At the end of his local training, the $i$th agent computes his update for the current round as $\Delta_t^i = w_{t+1}^i - w_t$, and sends the update back to the server. Once the server receives updates from all the agents in $S_t$, it computes the weights for the next round via weighted averaging (technically, the server can use an arbitrary aggregation function, but we stick to weighted averaging in this work as it is the most common one) as,

$$w_{t+1} = w_t + \frac{\sum_{i \in S_t} n_i \Delta_t^i}{\sum_{i \in S_t} n_i}, \tag{1}$$

where $n_i$ is the number of samples in the dataset of agent $i$. Aggregating the updates this way has been referred to as FedAvg fed-learning:google.
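To make the aggregation concrete, below is a minimal sketch of equation (1) in PyTorch, assuming each agent's update $\Delta_t^i$ arrives as a dictionary of tensors keyed by parameter name; the function and variable names are illustrative rather than taken from an existing implementation.

```python
import torch

def aggregate_fedavg(global_weights, updates, num_samples):
    """FedAvg (equation 1): weighted-average the agents' updates and apply them to w_t."""
    total = float(sum(num_samples))
    new_weights = {}
    for name, w in global_weights.items():
        # sum_i n_i * delta_i / sum_i n_i, computed per parameter tensor
        avg_update = sum(n * upd[name] for n, upd in zip(num_samples, updates)) / total
        new_weights[name] = w + avg_update
    return new_weights
```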

It has been shown that models trained via FL can perform better than locally trained models at the agents' side in various settings fed-learning:google; keyboard. In contrast, as noted before, it has also been observed that the performance of FL drops drastically when the local data distributions of agents differ significantly, i.e., when data is distributed in a non-i.i.d. fashion among agents zhao2018federated:noniid; hsieh2019non:noniid.

3 Methods

In this section, we discuss a few possible causes for the performance degradation of FL in the non-i.i.d. setting, and propose techniques to alleviate these causes.

3.1 Hypothesis Conflicts and Server-Side Training

Hypothesis conflicts

We hypothesize that one cause behind the performance degradation in the non-i.i.d. setting is hypothesis conflicts between agents' local models. For example, for a classification task, we say two models have conflicting hypotheses if there exists an input to which the models assign different classes. Concretely, let $f_i$ and $f_j$ be the local models of agents $i$ and $j$ at the end of round $t$, respectively. Then, we say there is a hypothesis conflict between these models at round $t$ if there exists an input $x$ such that $f_i(x) \neq f_j(x)$. We illustrate how hypothesis conflicts can be problematic on a toy example in Figure 1.

Figure 1: An illustration of how hypothesis conflicts can lead to performance degradation. (a) Prediction in the centralized setting: a simple dataset of four points with two features, which a decision tree can easily learn to classify perfectly. (b) Hypothesis conflict between two sites in the distributed setting: the dataset is partitioned so that each agent receives only the points on its side of the feature space (left and right), and the agents' local models learn conflicting hypotheses; the first agent's model predicts (-) on a region where the second agent's model predicts (+). If we were to do ensemble prediction with these models, their votes would just cancel each other.

Server-side training

Given the discussion above, we argue that having a small, balanced training dataset on the server-side can improve the performance in the non-i.i.d. setting significantly. That is, the central server can aggregate the updates as usual, and then fine-tune the resulting model on such data by training it for a short amount of time. If the training data on the server-side has a distribution similar to the union of the agents' datasets, then the parameters learned on this data will likely perform well on the local data of agents, and consequently the agents are unlikely to overwrite the mappings learned during server-side training. We believe this can reduce hypothesis conflicts, and consequently improve the performance of the trained models.
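As a rough sketch of what such server-side fine-tuning could look like, assuming PyTorch and a small, balanced dataset held by the server in a DataLoader called server_loader (all names, and the optimizer settings, are illustrative assumptions, not the exact configuration used here):

```python
import torch
import torch.nn.functional as F

def server_finetune(model, server_loader, lr=0.01, epochs=1, device="cpu"):
    """Fine-tune the aggregated model for a short amount on the server-side data."""
    model.to(device).train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in server_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
    return model
```

The server would call this right after computing $w_{t+1}$ via equation (1), and then send the fine-tuned weights to the agents in the next round.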

3.2 Projected Gradient Descent at Agents’ Side

Another reason for the performance degradation might stem from the fact that the local loss surfaces of agents are different. Each local model might reach some local minimum on its own surface; however, the aggregated model might not be close to any local minimum on the loss surface defined by the union of the local datasets. Indeed, in a recent work zhao2018federated:noniid, the authors hypothesize that FL performs worse in the non-i.i.d. setting because local models' parameters differ significantly from each other (see Figure 2).

Figure 2: Parameter divergence in FL zhao2018federated:noniid. SGD refers to the direction of the gradient in the centralized setting. As local loss surfaces differ, local agents' parameters diverge from each other. This could be one of the reasons why FL performs significantly worse in the non-i.i.d. setting, and it can perhaps be alleviated to some extent by using projected gradient descent.

We believe that this divergence can be alleviated to some extent by having each agent run projected gradient descent with some common parameters; e.g., the norm of each local model can be bounded by a common value. Since each local model is initialized to the same parameters, and each agent uses the same projection, we believe this might prevent local models from diverging too much from each other. Consequently, this can improve the performance of the trained models.
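A minimal sketch of such a projection step is given below, assuming the common constraint is a bound on the L2 norm of the flattened model parameters (the exact norm is an assumption for illustration); the agent would call this after each local SGD step.

```python
import torch

def project_to_l2_ball(model, radius):
    """If the model's parameter norm exceeds the bound, scale it back onto the ball."""
    with torch.no_grad():
        params = list(model.parameters())
        total_norm = torch.sqrt(sum(p.pow(2).sum() for p in params))
        if total_norm > radius:
            scale = radius / total_norm
            for p in params:
                p.mul_(scale)
```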

3.3 Server-side Momentum

It is well known that momentum techniques can significantly improve convergence by making gradient descent robust to certain elements of loss surfaces; e.g., they can dampen oscillations in valleys. Given the differences in local loss surfaces, we believe momentum might have a similar effect in the non-i.i.d. case. To adapt momentum to FL, we keep a running average of the aggregated updates. That is, with momentum, the update rule for FedAvg (equation 1) can be written as,

$$v_{t+1} = \beta v_t + \frac{\sum_{i \in S_t} n_i \Delta_t^i}{\sum_{i \in S_t} n_i}, \tag{2}$$
$$w_{t+1} = w_t + v_{t+1}, \tag{3}$$

where $\beta$ is the server's momentum constant and $v_0 = 0$.
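A small sketch of this server-side momentum step follows, assuming the weighted-average update from equation (1) has already been computed (e.g., by the aggregate_fedavg sketch above) and that the velocity is stored as a dictionary of tensors; the names are illustrative.

```python
import torch

def momentum_step(global_weights, avg_update, velocity, beta=0.5):
    """Apply FedAvg with a server-side running average of the aggregated updates."""
    new_velocity, new_weights = {}, {}
    for name, w in global_weights.items():
        v = velocity.get(name, torch.zeros_like(w))        # v_0 = 0
        new_velocity[name] = beta * v + avg_update[name]   # equation (2)
        new_weights[name] = w + new_velocity[name]         # equation (3)
    return new_weights, new_velocity
```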

3.4 Preventing Parameter Cancellation

Following the argument in Figure 2, we predict that as local models diverge, they could cancel each other's parameters when they are aggregated. To prevent this, we can simply look at the sign information of the updates, and adjust the learning rate at the server. That is, we can set a threshold $\theta$, and for every dimension where the absolute value of the sum of the signs of the updates is less than this threshold, we can set the server's learning rate to 0. With this, the server's learning rate for the $i$th dimension is given by,

$$\eta_{\theta,i} = \begin{cases} \eta & \text{if } \left|\sum_{k \in S_t} \mathrm{sgn}(\Delta_{t,i}^k)\right| \geq \theta, \\ 0 & \text{otherwise,} \end{cases}$$

where $\eta$ is the server's learning rate, which is implicitly 1 in equation (1).

With such a learning rate, FedAvg can be written as,

$$w_{t+1} = w_t + \eta_\theta \odot \frac{\sum_{i \in S_t} n_i \Delta_t^i}{\sum_{i \in S_t} n_i}, \tag{4}$$

where $\eta_\theta$ is simply the learning rate vector over all dimensions, and $\odot$ is the element-wise product operation.
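The following sketch illustrates this adjustment, assuming a server learning rate of 1 as in equation (1): for each parameter tensor, dimensions where the absolute sum of the update signs falls below $\theta$ receive a learning rate of 0. Names are illustrative.

```python
import torch

def adjusted_lr_mask(updates, name, theta):
    """0/1 learning-rate tensor for one parameter, based on sign agreement across updates."""
    sign_sum = torch.stack([torch.sign(upd[name]) for upd in updates]).sum(dim=0)
    return (sign_sum.abs() >= theta).float()

def aggregate_with_adjusted_lr(global_weights, updates, num_samples, theta):
    total = float(sum(num_samples))
    new_weights = {}
    for name, w in global_weights.items():
        avg_update = sum(n * upd[name] for n, upd in zip(num_samples, updates)) / total
        lr = adjusted_lr_mask(updates, name, theta)   # eta_theta for this tensor
        new_weights[name] = w + lr * avg_update       # element-wise product as in equation (4)
    return new_weights
```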

We note that the same technique has been proposed as a defense against backdoor attacks in the FL setting safa2020defending. Although the authors note that their technique can potentially improve the performance of models in the non-i.i.d. setting too, they provide no empirical evaluation.

4 Experiments

In this section, we illustrate the performance of our techniques from Section 3 via experiments. We first look at each technique in isolation, and then in combination with each other.

4.1 Setting

In our experiments, we simulate FL between two agents over multiple rounds. At each round, agents locally train for a fixed number of epochs with a fixed batch size, learning rate, and weight decay before sending their updates. Upon receiving and aggregating the updates, we measure accuracy on a validation dataset. We use the FedAvg aggregation (equation 1). As the dataset, we use CIFAR10 cifar10, and as the model, we use ResNet20 with Fixup initialization zhang2019fixup. Each experiment is repeated 3 times, and we report the mean and standard deviation of the results.

4.2 Results

We first provide the results for baseline in Table 1.

Setting Val. Acc. Mean Val. Acc. STD
Centralized 90.4% 0.77%
FL - IID 90.0% 0.08%
FL - NIID(5) 73.0% 2.6%
Table 1: Results for the baseline setting. As can be seen, when data is distributed in an i.i.d. fashion among agents (FL - IID), i.e., each agent gets the same number of instances from all 10 classes, the accuracy of FL is more or less the same as in the centralized setting. However, when data is distributed such that the first agent gets all samples of the first 5 classes, and the other agent gets the remaining samples (FL - NIID(5)), accuracy drops by more than 17% with respect to the centralized case.
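For reference, the NIID(5) split described in the caption can be reproduced roughly as follows, assuming the torchvision CIFAR10 dataset; this is an illustrative sketch, not the exact code used in our experiments.

```python
from torchvision import datasets, transforms
from torch.utils.data import Subset

train_set = datasets.CIFAR10(root="./data", train=True, download=True,
                             transform=transforms.ToTensor())
# First agent gets all samples of the first 5 classes; the second agent gets the rest.
first_idx  = [i for i, y in enumerate(train_set.targets) if y < 5]
second_idx = [i for i, y in enumerate(train_set.targets) if y >= 5]
agent_datasets = [Subset(train_set, first_idx), Subset(train_set, second_idx)]
```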

We then illustrate how server-side training, as described in Section 3.1, improves upon the baseline in Table 2.

Setting Server-side data Val. Acc. Mean Val. Acc. STD
FL - NIID(5) 0% 73.0% 2.6%
FL - NIID(5) 5% 83.7% 0.1%
Table 2: Effect of server-side training. We give 5% of the training data to the server and distribute the remaining training data to the agents in a non-i.i.d. fashion as explained in the baseline setting. The server fine-tunes the aggregated model by training for a single epoch on this data at each round after aggregating the updates. We see that this alone improves the accuracy over the baseline by more than 10%.

We now show the results for projected gradient descent. In this setting, we use a norm bound on the local models. That is, if at any point the norm of his model exceeds the bound, the agent projects the model back onto the ball with that radius. We have also combined this with noise addition, as some works argue that noise addition improves regularization noh2017regularizing; i.e., at each iteration, the agent samples Gaussian noise with a fixed mean and standard deviation, and adds it to his gradients.

Setting Norm Threshold Noise STD Val. Acc. Mean Val. Acc. STD
FL - NIID(5) 0 0 73.0% 2.6%
FL - NIID(5) 3 0 77.5% 2.3%
FL - NIID(5) 3 79.6% 1.3%
Table 3: Improvement due to projected gradient descent. As can be seen, using projected gradient descent alone improves over the baseline by about 4%. Combining it with noise addition improves accuracy by more than 6% with respect to the baseline.

We now illustrate the effect of server-side momentum in Table 4.

Setting Server’s Momentum Constant Val. Acc. Mean Val. Acc. STD
FL - NIID(5) 0 73.0% 2.6%
FL - NIID(5) 0.5 80.9% 2.4%
FL - NIID(5) 0.9 75.9% 1.1%
Table 4: Effect of using momentum on the server-side. As can be seen, using momentum with a momentum constant of 0.5 alone improves accuracy over the baseline by more than 7%.

Finally, we illustrate the effect of adjusting the server's learning rate, as described in Section 3.4, in Table 5. Since we have only two agents, we use a correspondingly small threshold $\theta$.

Setting Using Adjustable LR Val. Acc. Mean Val. Acc. STD
FL - NIID(5) No 73.0% 2.6%
FL - NIID(5) Yes 78.5% 1.72%
Table 5: Effect of adjusting the learning rate at the server-side. The validation accuracy is improved by more than 5% with respect to the baseline.

After looking at each method individually, we do a grid search to see if a combination of them can yield better performance. The best configuration found by the grid search combines server-side training with 5% of the data, server-side momentum, and the adjustable learning rate at the server-side, and yields a mean accuracy more than 12% above our baseline setting.

5 Conclusion

In this paper, we studied FL with a particular focus on improving the accuracy of trained models in non-i.i.d. settings. Through our study, we have identified a few simple techniques that improve over our baseline accuracy by more than 12%. The techniques we identify are rather simple, and can easily be incorporated into FL without any structural changes. Further, they incur no extra communication overhead, but only some light computation overhead either on the client-side or the server-side. In future work, we hope to expand our experimental setting by testing different models, datasets, and hyperparameters.

References