Overcoming Forgetting in Federated Learning on Non-IID Data

by Neta Shoham, et al.

We tackle the problem of Federated Learning in the non i.i.d. case, in which local models drift apart, inhibiting learning. Building on an analogy with Lifelong Learning, we adapt a solution for catastrophic forgetting to Federated Learning. We add a penalty term to the loss function, compelling all local models to converge to a shared optimum. We show that this can be done in a communication-efficient way (adding no further privacy risks) that scales with the number of nodes in the distributed setting. Our experiments show that this method is superior to competing ones for image recognition on the MNIST dataset.






1 Introduction

Recent years have seen the advent of smart devices and sensors that gather data at the edge and are able to act on that data. The desire to keep data private, among other considerations, has led the machine learning community to study algorithms for distributed training that do not require sending the data out of the edge devices. Edge devices most often have low networking availability and capacity, which can make training through standard distributed SGD prohibitive. The Federated Averaging (FedAvg) algorithm of McMahan et al. (2016) lets the devices train on their local data for several epochs (using local SGD) before sending the trained model to a central server. The server then aggregates the models and sends the aggregated model back to the devices. This is done iteratively until convergence is achieved.
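As a reference point for what follows, the FedAvg round structure can be sketched in a few lines. This is a toy NumPy sketch: a squared-error linear model stands in for the local network, and the function names and the quadratic loss are illustrative, not from the paper.

```python
import numpy as np

def local_sgd(theta, data, epochs, lr=0.1):
    """One device's update: plain SGD on its local data (toy squared-error model)."""
    X, y = data
    for _ in range(epochs):
        for xi, yi in zip(X, y):
            grad = 2 * (theta @ xi - yi) * xi  # gradient of (theta.x - y)^2
            theta = theta - lr * grad
    return theta

def fedavg_round(theta_global, datasets, epochs=5):
    """One FedAvg round: each device trains locally for several epochs,
    then the server averages the models, weighted by local data size."""
    updates, sizes = [], []
    for data in datasets:
        updates.append(local_sgd(theta_global.copy(), data, epochs))
        sizes.append(len(data[1]))
    weights = np.array(sizes) / sum(sizes)
    return sum(w * u for w, u in zip(weights, updates))
```

In the i.i.d. case the local optima roughly agree and averaging works well; the non i.i.d. failure mode discussed below arises because each device's local SGD pulls toward a different optimum.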

Federated Learning poses three challenges that make it different from traditional distributed learning. The first is the number of computing stations, which can be in the hundreds of millions. (To cope with this, it is common practice to select only a subset of devices at every training iteration McMahan et al. (2016); for simplicity of presentation we ignore this practice, to which our suggested algorithm can easily be adapted.) The second is much slower communication compared to the inter-cluster communication found in data centers. The third difference, on which we focus in this work, is the highly non i.i.d. manner in which the data may be distributed among the devices.

In some real-life cases, Federated Learning has shown robustness to non i.i.d. distribution Ramaswamy et al. (2019). There are also recent theoretical results proving the convergence of Federated Learning algorithms Li et al. (2019) on non i.i.d. data. It is evident, however, that even in very simple scenarios, Federated Learning on non i.i.d. distributions has trouble achieving good results (in terms of accuracy and the number of communication rounds, as compared to the i.i.d. case) McMahan et al. (2016); Wang et al. (2019).

1.1 Overcoming Forgetting in Sequential Lifelong Learning and in Federated Learning

There is a deep parallel between the Federated Learning problem and another fundamental machine learning problem called Lifelong Learning (and the related Multi-Task Learning). In Lifelong Learning, the challenge is to learn task A, and continue on to learn task B using the same model, but without "forgetting", i.e., without severely hurting the performance on, task A; or in general, learning tasks A, B, C, … in sequence without forgetting previously-learnt tasks, for which samples are not presented anymore. Besides learning tasks serially rather than in parallel, in Lifelong Learning each task is seen only once, whereas in Federated Learning there is no such limitation. These differences aside, the paradigms share a common main challenge - how to learn a task without disturbing different ones learnt on the same model.

It is not surprising, then, that similar approaches are being applied to solve the Federated Learning and the Lifelong Learning problems. One such example is data distillation, in which representative data samples are shared between tasks Hou et al. (2018); Zhao et al. (2018). However, Federated Learning is frequently used precisely in order to achieve data privacy, which would be broken by sending a piece of data from one device to another, or from one device to a central point. We therefore seek some other type of information to share between the tasks.

The answer to what kind of information to use may be found in Kirkpatrick et al. (2017). In this work, the authors present a new algorithm for Lifelong Learning - Elastic Weight Consolidation (EWC). EWC aims to prevent catastrophic forgetting when moving from learning task A to learning task B. The idea is to identify the coordinates in the network parameters that are the most informative for task A, and then, while task B is being learned, penalize the learner for changing these parameters. The basic assumption is that deep neural networks are over-parameterized enough, so that there are good chances of finding an optimal solution θ*_B to task B in the neighborhood of the previously learned θ*_A.

In order to control the stiffness of θ per coordinate while learning task B, the authors suggest to use the diagonal of the Fisher information matrix diag(Î_A), evaluated at θ*_A, to selectively penalize the parts of the parameter vector θ that are getting too far from θ*_A. This is done using the following objective:

  L̃(θ) = L_B(θ) + λ (θ − θ*_A)ᵀ diag(Î_A) (θ − θ*_A)    (1)
The formal justification they provide for (1) is Bayesian: Let D_A and D_B be independent datasets used for tasks A and B, and D = D_A ∪ D_B. We have that

  log p(θ | D) = log p(D_B | θ) + log p(θ | D_A) − log p(D_B),

where log p(D_B | θ) is just the standard likelihood maximized in the optimization of L_B, and the posterior p(θ | D_A) is approximated with Laplace's method as a Gaussian distribution with expectation θ*_A and covariance diag(Î_A)⁻¹.

It is also well known that, under some regularity conditions, the information matrix approximates the Hessian of L_A at θ*_A Pronzato and Pázman (2013). Since ∇L_A(θ*_A) ≈ 0 at the optimum, a second-order Taylor expansion then gives a non-Bayesian interpretation of (1): up to an additive constant and the scaling λ of the penalty,

  L̃(θ) ≈ L_B(θ) + L_A(θ) = L(θ),    (2)

where L(θ) = L_A(θ) + L_B(θ) is exactly the loss we want to minimize. In general, one can learn a sequence of tasks A, B, C, …. In section 3 we rely on the above interpretation as a second-order approximation in order to construct an algorithm for Federated Learning. We will further show how to implement such an algorithm in a way that preserves the privacy benefits of the standard FedAvg algorithm.
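The EWC objective can be sketched numerically. Here the diagonal Fisher is estimated as the mean of squared per-sample log-likelihood gradients, a standard empirical approximation; the helper names are ours, not from the paper.

```python
import numpy as np

def diag_fisher(grads_per_sample):
    """Empirical diagonal Fisher information: the mean of squared
    per-sample log-likelihood gradients, evaluated at the task optimum."""
    G = np.asarray(grads_per_sample)  # shape: (n_samples, n_params)
    return (G ** 2).mean(axis=0)

def ewc_loss(loss_b, theta, theta_a_star, fisher_a, lam):
    """EWC objective: task-B loss plus an anisotropic quadratic penalty
    that anchors theta to theta_A*, weighted per coordinate by the Fisher diagonal."""
    d = theta - theta_a_star
    return loss_b + lam * d @ (fisher_a * d)
```

Coordinates with large Fisher values are stiff (changing them is expensive), while coordinates with small Fisher values remain free to adapt to task B.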

2 Related Work

There are only a handful of works that directly try to cope with the challenge of Federated Learning with non i.i.d. distributions. One approach is to just give up the periodic averaging, and reduce the communication by sparsification and quantization of the updates sent to the central point after each local mini-batch Sattler et al. (2019). In Zhao et al. (2018) it was shown that by sharing only a small portion of the data between different nodes, one can achieve a great improvement in model accuracy. However, sharing data is not acceptable in many Federated Learning scenarios.

Somewhat similar to our approach, MOCHA Smith et al. (2017) links each task t with a separate parameter vector w_t, and the relation between the tasks is modeled by adding a loss term tr(WΩWᵀ), where W = (w_1, …, w_m) and Ω is a matrix modeling the relationships between the tasks. The optimization is done on both W and Ω. MOCHA uses a primal-dual formulation in order to solve the optimization problem and thus, unlike our algorithm, is not suitable for deep networks.

Perhaps the closest work to ours is Sahu et al. (2018), where the authors present their FedProx algorithm, which, like our algorithm, also uses parameter stiffness. However, unlike our algorithm, in FedProx the penalty term is isotropic: (μ/2)‖θ − θ_{t−1}‖², pulling all coordinates equally toward the previous global model. DANE Shamir et al. (2014) adds to a similar proximal term a gradient correction term to accelerate convergence, but is not robust to non i.i.d. data Sahu et al. (2018); Reddi et al. (2016). AIDE Reddi et al. (2016) improves the ability of DANE to deal with non i.i.d. data. However, it does so by using an inexact version of DANE, through a limitation on the amount of local computation.
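For contrast with the anisotropic, Fisher-weighted penalty used by our algorithm below, FedProx's isotropic proximal term can be written in a couple of lines (a sketch; the function name is ours):

```python
import numpy as np

def fedprox_penalty(theta, theta_global, mu):
    """FedProx-style isotropic proximal term, (mu/2) * ||theta - theta_global||^2.
    Every coordinate is pulled toward the global model with the same strength mu."""
    d = theta - theta_global
    return 0.5 * mu * d @ d
```

The single scalar μ trades flexibility for stiffness uniformly across all parameters, whereas a per-coordinate weighting can leave unimportant parameters free while protecting important ones.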

A recent work Li et al. (2019) proves convergence of FedAvg for the non i.i.d. case. It also provides a theoretical explanation for a phenomenon known in practice, of performance degradation when the number of local iterations is too high. This is exactly the problem that we tackle in this work.

3 Federated Curvature

In this section we present our adaptation of the EWC algorithm to the Federated Learning scenario. We call it FedCurv (for Federated Curvature, motivated by (2)). We mark by S the set of N nodes, with the tasks' local datasets {D_s}, s ∈ S. We diverge from the FedAvg algorithm in that in each round we use all the nodes in S instead of randomly selecting a subset of them. (Our algorithm can easily be extended to select a subset.) At round t, each node s ∈ S optimizes the following loss:

  L̃_{t,s}(θ) = L_s(θ) + λ Σ_{j ∈ S∖{s}} (θ − θ̂_{t−1,j})ᵀ diag(Î_{t−1,j}) (θ − θ̂_{t−1,j})    (3)

On each round t, starting from the round's initial point, the nodes optimize their local loss by running SGD for E local epochs. At the end of each round t, each node s sends to the rest of the nodes the SGD result θ̂_{t,s} and diag(Î_{t,s}) (where Î_{t,s} = I(θ̂_{t,s}), the Fisher information evaluated at θ̂_{t,s}). θ̂_{t,s} and diag(Î_{t,s}) will be used for the loss of round t+1. We switched from θ* to θ̂ to signify that local tasks are optimized for E epochs and not until they converge (as was the case for EWC). However, (2) (its generalization to N tasks) supports using large values of E, so that θ̂_{t,s} approaches an optimum θ*_{t,s} and L̃ approaches the global loss L.
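A minimal sketch of the per-node FedCurv objective of this round structure, assuming the local loss is supplied as a callable; the names are illustrative, not from a released implementation:

```python
import numpy as np

def fedcurv_penalty(theta, others):
    """Penalty of eq. (3): sum over the other nodes j of
    (theta - theta_j)^T diag(I_j) (theta - theta_j),
    where `others` is a list of (theta_j, fisher_diag_j) pairs."""
    total = 0.0
    for theta_j, fisher_j in others:
        d = theta - theta_j
        total += d @ (fisher_j * d)
    return total

def fedcurv_local_loss(local_loss, theta, others, lam):
    """The loss node s minimizes during its E local epochs."""
    return local_loss(theta) + lam * fedcurv_penalty(theta, others)
```

Each node is thus pulled toward every other node's latest model, but only strongly along the coordinates that matter to that node's task.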

3.1 Keeping Low Bandwidth and Preserving Privacy

At first glance, maintaining all the historical data required by FedCurv might look cumbersome and expensive to store and transmit. It also looks as if sensitive information is passed between nodes. However, by careful implementation we can avoid these potential drawbacks. We note that (3) can also be rearranged as

  L̃_{t,s}(θ) = L_s(θ) + λ θᵀ [Σ_{j ∈ S∖{s}} diag(Î_{t−1,j})] θ − 2λ θᵀ Σ_{j ∈ S∖{s}} diag(Î_{t−1,j}) θ̂_{t−1,j} + const    (4)

The central point needs only to maintain and transmit to the edge nodes two additional elements, besides the aggregated model, each of the same size as θ:

  u_t = Σ_{j ∈ S} diag(Î_{t,j}),   v_t = Σ_{j ∈ S} diag(Î_{t,j}) θ̂_{t,j}.

Device s can then construct the data needed for the evaluation of (4) from u_{t−1} and v_{t−1} by subtracting its own contribution. Likewise, each device at time t needs to transmit only two additional elements of the same size as θ: diag(Î_{t,s}) and diag(Î_{t,s}) θ̂_{t,s}.
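This rearrangement can be checked numerically: the gradient of the penalty computed from the two aggregated vectors (after subtracting the device's own contribution) matches the gradient of the direct sum over the other nodes. A NumPy sketch with random stand-in values:

```python
import numpy as np

rng = np.random.default_rng(1)
dim, n = 4, 5
thetas = rng.normal(size=(n, dim))               # theta_hat_{t-1,j} for each node j
fishers = rng.uniform(0.1, 1.0, size=(n, dim))   # diag Fisher for each node j

# The server keeps only two aggregates, each the size of the model:
u = fishers.sum(axis=0)                          # sum_j diag(I_j)
v = (fishers * thetas).sum(axis=0)               # sum_j diag(I_j) theta_j

s = 2                                            # an arbitrary device
theta = rng.normal(size=dim)                     # its current iterate

# Device s removes its own contribution and uses the expanded form (4);
# the penalty gradient matches the direct sum over the other nodes.
grad_agg = 2 * (u - fishers[s]) * theta - 2 * (v - fishers[s] * thetas[s])
grad_direct = sum(2 * fishers[j] * (theta - thetas[j]) for j in range(n) if j != s)
assert np.allclose(grad_agg, grad_direct)
```

The constant term of (4) does not affect the optimization, which is why the per-node vectors θ̂_{t−1,j} never need to be stored or shipped individually.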


It should be noted that we only need to send gradient-related information, aggregated over the local data samples, from the devices to the central point. In terms of privacy, this is not significantly different from the classical FedAvg algorithm. The central point itself, as in FedAvg, needs only to keep globally aggregated information from the devices. We see no reason why secure aggregation methods Bonawitz et al. (2016), which were successfully applied to FedAvg, could not be applied to FedCurv.

Further potential bandwidth reduction

The diagonal of the Fisher information has been used successfully for parameter pruning in neural networks LeCun et al. (1990). This gives us a straightforward way to save bandwidth by using sparse versions of diag(Î_{t,s}) and even θ̂_{t,s}, as the Fisher diagonal provides a natural importance measure for the parameters of θ̂_{t,s}. The sparse versions are obtained by keeping only a fraction of the indices, those related to the largest elements of the diagonal of the Fisher information matrix. We have not explored this idea in practice.
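A possible top-k sparsification along these lines (a sketch of the unexplored idea, not an implementation from the paper; names are ours):

```python
import numpy as np

def sparsify_by_fisher(theta, fisher_diag, keep_frac=0.1):
    """Keep only the fraction of coordinates with the largest Fisher values,
    so (index, value) pairs can be transmitted instead of dense vectors."""
    k = max(1, int(keep_frac * fisher_diag.size))
    idx = np.argpartition(fisher_diag, -k)[-k:]  # indices of the k largest entries
    return idx, theta[idx], fisher_diag[idx]
```

The receiver would treat all omitted coordinates as having zero stiffness, which is consistent with their low importance under the Fisher criterion.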

4 Experiments

We conducted our experiments on a group of 96 simulated devices. We divided the MNIST dataset LeCun et al. (1998) into 192 blocks of homogeneous labels (discarding a small amount of data) and randomly assigned two blocks to each device. We used the CNN architecture from the MNIST PyTorch example.
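The label-homogeneous partition described above can be sketched as follows (an illustrative reconstruction; the paper does not publish its partitioning code, and the function name is ours):

```python
import numpy as np

def non_iid_partition(labels, n_devices, blocks_per_device=2, seed=0):
    """Sort samples by label, cut them into label-homogeneous blocks, and
    assign blocks_per_device randomly chosen blocks to each device.
    Samples left over after the equal-sized cut are discarded."""
    n_blocks = n_devices * blocks_per_device
    order = np.argsort(labels, kind="stable")
    block_size = len(labels) // n_blocks
    blocks = [order[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_blocks)
    return [np.concatenate([blocks[b] for b in
                            perm[d * blocks_per_device:(d + 1) * blocks_per_device]])
            for d in range(n_devices)]
```

With two blocks per device, each device sees at most two distinct labels, which is the pathological non i.i.d. regime the experiments target.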


We explored two factors: (1) the learning method - we considered three algorithms: FedAvg, FedProx, and FedCurv (our algorithm); (2) E, the number of epochs in each round, which is of special interest in this work, as our algorithm is designed for large values of E. C, the fraction of devices that participate in each round, and B, the local batch size, were kept fixed across all experiments, as was a constant learning rate.

FedProx's μ and FedCurv's λ values were chosen in the following way: we looked for the values that reached 90% test-accuracy in the smallest number of rounds. We did this by searching on a multiplicative grid using a factor of 10, and then a factor of 2, in order to ensure a minimum. Table 1 shows the number of rounds required to achieve 95%, 90% and 85% test-accuracy with these chosen parameters. We see that for E=50, FedCurv achieved 90% test-accuracy three times as fast as the vanilla FedAvg algorithm. FedProx also reached 90% faster than FedAvg. However, while our algorithm achieved 95% twice as fast as FedAvg, FedProx achieved it roughly two times slower. For E=10, the improvement of both FedCurv and FedProx is less significant, with FedCurv still outperforming FedProx and FedAvg.

In Figure 1 and Figure 2 we can see that both FedProx and FedCurv do well at the beginning of the training process. However, while FedCurv provides enough flexibility to allow for reaching high accuracy at the end of the process, the stiffness of the parameters in FedProx comes at the expense of final accuracy. FedCurv gives more significant improvements for higher values of E (as does FedProx), as expected by the theory.

                 E=50                 E=10
  Algorithm   0.95  0.90  0.85    0.95  0.90  0.85
  FedCurv       38     9     6      99    35    27
  FedProx      140    22    16     115    46    33
  FedAvg        76    30    22     106    51    43
Table 1: Number of rounds to achieve a certain accuracy on Non-IID MNIST
Figure 1: Learning curves, E=50
Figure 2: Learning curves, E=10

5 Conclusion

This work has provided a novel approach to the problem of Federated Learning on non i.i.d. data. It builds on a solution from Lifelong Learning, which uses the diagonal of the Fisher information matrix to protect the parameters that are important to each task. The adaptation required turning that sequential solution (from Lifelong Learning) into a parallel one (for Federated Learning), which a priori involves excessive sharing of data. We showed that this can be done efficiently, without substantially increasing bandwidth usage or compromising privacy. As our experiments have demonstrated, our FedCurv algorithm guards the parameters important to each task, improving convergence.


  • [1] K. Bonawitz, V. Ivanov, B. Kreuter, A. Marcedone, H. B. McMahan, S. Patel, D. Ramage, A. Segal, and K. Seth (2016) Practical secure aggregation for federated learning on user-held data. arXiv preprint arXiv:1611.04482. Cited by: §3.1.
  • [2] S. Hou, X. Pan, C. Change Loy, Z. Wang, and D. Lin (2018) Lifelong learning via progressive distillation and retrospection. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 437–452. Cited by: §1.1.
  • [3] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al. (2017) Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences 114 (13), pp. 3521–3526. Cited by: §1.1.
  • [4] Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, et al. (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §4.
  • [5] Y. LeCun, J. S. Denker, and S. A. Solla (1990) Optimal brain damage. In Advances in neural information processing systems, pp. 598–605. Cited by: §3.1.
  • [6] X. Li, K. Huang, W. Yang, S. Wang, and Z. Zhang (2019) On the convergence of fedavg on non-iid data. arXiv preprint arXiv:1907.02189. Cited by: §1, §2.
  • [7] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, et al. (2016) Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629. Cited by: §1, §1, footnote 1.
  • [8] L. Pronzato and A. Pázman (2013) Design of experiments in nonlinear models. Lecture notes in statistics 212. Cited by: §1.1.
  • [9] PyTorch mnist example. https://github.com/pytorch/examples/tree/master/mnist Cited by: §4.
  • [10] S. Ramaswamy, R. Mathews, K. Rao, and F. Beaufays (2019) Federated learning for emoji prediction in a mobile keyboard. arXiv preprint arXiv:1906.04329. Cited by: §1.
  • [11] S. J. Reddi, J. Konečnỳ, P. Richtárik, B. Póczós, and A. Smola (2016) Aide: fast and communication efficient distributed optimization. arXiv preprint arXiv:1608.06879. Cited by: §2.
  • [12] A. K. Sahu, T. Li, M. Sanjabi, M. Zaheer, A. Talwalkar, and V. Smith (2018) Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127. Cited by: §2.
  • [13] F. Sattler, S. Wiedemann, K. Müller, and W. Samek (2019) Robust and communication-efficient federated learning from non-iid data. arXiv preprint arXiv:1903.02891. Cited by: §2.
  • [14] O. Shamir, N. Srebro, and T. Zhang (2014) Communication-efficient distributed optimization using an approximate newton-type method. In International conference on machine learning, pp. 1000–1008. Cited by: §2.
  • [15] V. Smith, C. Chiang, M. Sanjabi, and A. S. Talwalkar (2017) Federated multi-task learning. In Advances in Neural Information Processing Systems, pp. 4424–4434. Cited by: §2.
  • [16] S. Wang, T. Tuor, T. Salonidis, K. K. Leung, C. Makaya, T. He, and K. Chan (2019) Adaptive federated learning in resource constrained edge computing systems. IEEE Journal on Selected Areas in Communications 37 (6), pp. 1205–1221. Cited by: §1.
  • [17] Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra (2018) Federated learning with non-iid data. arXiv preprint arXiv:1806.00582. Cited by: §1.1, §2.