FedCostWAvg: A new averaging for better Federated Learning

11/16/2021
by Leon Mächler, et al.
École Normale Supérieure

We propose a simple new aggregation strategy for federated learning that won the MICCAI Federated Tumor Segmentation Challenge 2021 (FETS), the first ever challenge on Federated Learning in the Machine Learning community. Our method addresses the problem of how to aggregate multiple models that were trained on different data sets. Conceptually, we propose a new way to choose the weights when averaging the different models, thereby extending the current state of the art (FedAvg). Empirical validation demonstrates that our approach reaches a notable improvement in segmentation performance compared to FedAvg.


1 Introduction

1.1 Motivation

Preserving data privacy is of paramount importance in confidentiality-critical fields such as medicine. Today it is not uncommon for large volumes of private medical records to be illegally released on the dark web [4]. Large amounts of resources are often allocated to prevent such incidents, but they cannot guarantee full security. Among many precautions, reducing human exposure to the data (including that of IT specialists) is highly desirable, because it lowers the chance of compromising data protection through human error.

Figure 1: Schematic illustration of the federated learning concept. Within multiple data centers, a model is trained for the task at hand. Next, the parameters are sent to the central server, where they are aggregated. The aggregated global parameters are then broadcast back to the centers. The procedure repeats until convergence or until some other limit is reached.

1.2 The typical training scenario

In machine learning, a common scenario today looks like this: one or more institutions (companies, research institutes, governments, etc.) gather data and share it with data scientists, who in turn train a model on it. For example, a group of hospitals shares MRI scans of tumors with the medical community to help with the development of an automatic tumor segmentation model. One problem with this approach is that the data, once shared, might be leaked, misused, or stolen from the developers. Another hurdle is that legal constraints might make it impossible for the hospitals to share and pool the data in the first place.

1.3 Federated Learning

Conventional machine learning requires exposing training data to a learning algorithm and its developers. When several data sources are involved, the data must also be pooled into a single data set. Newer approaches such as federated learning (FL) [15] make it possible to separate model training from developer access, without requiring any pooling of data. FL was introduced in a series of seminal works starting in 2015 [7, 6, 10]. FL is a protocol consisting of two alternating steps: a) independent training of models on local entities, each with its own unique corpus of data, and b) sending only the weights of the trained models back to a central entity, where the weights are aggregated and a new model is redistributed. The type of model or network used in step (a) is dictated by the task (e.g., classification or segmentation) and can be chosen based on the state of the art for that task. The new FL scenario looks like this: a developer sends a model to all the institutions that own training data; the institutions train the model locally and send the newly trained models back. In this way, the developer can train the model without ever gaining access to the data. In this setting, however, a new problem arises.
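The two alternating steps can be illustrated with a toy example. The following is a minimal sketch, not the challenge code: it assumes a one-parameter model y = w * x, hypothetical function names (`local_train`, `aggregate`, `federated_training`), made-up data, and a plain size-weighted mean as the aggregation rule.

```python
from typing import List, Tuple

Dataset = List[Tuple[float, float]]  # (x, y) pairs held privately by one center


def local_train(w: float, data: Dataset, lr: float = 0.01, epochs: int = 5) -> float:
    """Step (a): a center refines the received weight on its own data only."""
    for _ in range(epochs):
        # Gradient of the mean squared error of the toy model y = w * x.
        grad = sum(2.0 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def aggregate(local_ws: List[float], sizes: List[int]) -> float:
    """Step (b): the server combines the returned weights (size-weighted mean)."""
    total = sum(sizes)
    return sum(w * s for w, s in zip(local_ws, sizes)) / total


def federated_training(centers: List[Dataset], rounds: int = 50) -> float:
    """Alternate local training and server-side aggregation for a fixed number of rounds."""
    w_global = 0.0
    for _ in range(rounds):
        # Only weights travel; the raw (x, y) pairs never leave their center.
        local_ws = [local_train(w_global, data) for data in centers]
        w_global = aggregate(local_ws, [len(data) for data in centers])
    return w_global


# Hypothetical centers whose data all follow y = 3x.
centers = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5), (1.5, 4.5), (3.0, 9.0)],
]
w_final = federated_training(centers)
```

In this sketch the server only ever sees the scalar weights returned by `local_train`, which mirrors the separation between model training and data access described above.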

1.4 The aggregation problem

How to aggregate the different models that come back? A naive approach to solve the problem would be:

  1. Send an initial model to the first data center

  2. Get back a newly trained model and send it to the second data center

  3. Repeat until all data centers have trained the model once

Approaches like this are called sequential learning and fail due to a phenomenon called "catastrophic forgetting" [9]. In effect, the final model would only be trained on the data of the last center and would not generalize to the entire corpus of data. It would simply forget what was learned at the previous center as soon as it is trained by the next. The state-of-the-art approach tries to avoid this phenomenon by including feedback from every center in each update.

1.5 State of the art

The seminal work on learning deep networks from decentralized data [10] proposed a plain coordinate-wise averaging (FedAvg) of the model weights coming separately from multiple centers. Recently, [20] proposed a valuable extension to FedAvg that takes the invariance of network weights to permutations into account. In [16] (FedProx), the authors adjust the training loss of a local model to enforce closeness of local and global model updates. Despite these methodological advances, there is neither theoretical nor practical evidence for the right recipe when choosing an aggregation strategy. In this paper, we propose a new idea for how to do aggregation. Similar to other initiatives [17, 13, 11, 3], the FETS challenge (https://fets-ai.github.io/Challenge/) [12] is organized to benchmark different weight aggregation strategies on the clinically important glioma segmentation problem [1, 12, 14, 18, 5]. We contribute to the initiative by proposing an effective extension to the FedAvg strategy. Our model significantly outperformed all other submissions and won the challenge. In addition, we tested the model locally on a smaller corpus of data to compare it with FedAvg. It notably improves performance compared with FedAvg at no additional compute time.

2 Methodology

2.1 Segmentation network

The segmentation network is a 3D U-Net. It was provided by the challenge organizers and remained unchanged during all experiments. The architecture is composed of an encoder with residual branches followed by a decoder. We use the LeakyReLU activation function [8] along with instance normalization [19] to mitigate covariate shift. Dice serves as the loss function. Fig. 2 illustrates the schematic of the network.
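As an illustration of a Dice-based loss, the sketch below uses the common soft Dice formulation on flattened voxel lists. The exact loss implementation used in the challenge is not given here, so the function names and the smoothing constant `eps` are assumptions.

```python
from typing import Sequence


def soft_dice(pred: Sequence[float], target: Sequence[float], eps: float = 1e-6) -> float:
    """Soft Dice overlap between predicted probabilities and a binary mask."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    # eps keeps the ratio defined when both masks are empty.
    return (2.0 * intersection + eps) / (denominator + eps)


def dice_loss(pred: Sequence[float], target: Sequence[float], eps: float = 1e-6) -> float:
    """Loss to minimize: 0 for perfect overlap, close to 1 for no overlap."""
    return 1.0 - soft_dice(pred, target, eps)


# Hypothetical four-voxel example: one of two predicted voxels is correct.
example_loss = dice_loss([1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0])
```

For a multi-class segmentation task such a loss would typically be evaluated per label and averaged; that step is omitted here.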

Figure 2: 3D U-net architecture as provided by the FETS challenge.

2.2 Federated Cost Weighted Averaging (FedCostWAvg)

The gold-standard federated averaging (FedAvg) approach updates the global model as an average of all local models, weighted by the respective sizes of the training data sets. The new model is calculated as follows:

$$M_{i+1} = \frac{1}{S} \sum_{j=1}^{n} s_j M_i^j \qquad (1)$$

where $M_i^j$ is the local model $j$ in round $i$, $n$ is the number of local models, $s_j$ is the number of samples that model $j$ was trained on in round $i$, and $S = \sum_{j=1}^{n} s_j$. We propose a new weighting strategy that includes the amount by which the cost function decreased during the last step. Using FedCostWAvg, the new model is calculated as follows:

$$M_{i+1} = \sum_{j=1}^{n} \left( \alpha \frac{s_j}{S} + (1 - \alpha) \frac{k_j}{K} \right) M_i^j \qquad (2)$$

with:

$$k_j = \frac{c(M_{i-1}^j)}{c(M_i^j)}, \qquad K = \sum_{j=1}^{n} k_j \qquad (3)$$

where $c(M_i^j)$ is the cost of model $j$ at timestep $i$, calculated from the cost function that is used to train the models locally. $\alpha$ is a parameter ranging between $0$ and $1$ that can be chosen to determine the balance between data size and cost improvement. In our experiments, a value of $\alpha = 0.5$ performed best. Intuitively, this weighting strategy adjusts not only for the training data set size but also for the size of the local improvements that were made during the last training round. Local updates that only marginally improved the local cost influence the global update to a lesser extent than those that had a bigger impact.
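The weighting described above can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the challenge implementation: models are flattened to lists of floats, all function names are hypothetical, the cost-improvement term of a center is assumed to be the ratio of its previous local cost to its current local cost (normalized over centers), and the default `alpha=0.5` is chosen only to weight both terms equally.

```python
from typing import List, Sequence

Model = List[float]  # flattened parameter vector of one local model


def fedavg_weights(sizes: Sequence[int]) -> List[float]:
    """FedAvg: each center is weighted by its share of the training samples."""
    total = sum(sizes)
    return [s / total for s in sizes]


def fedcostwavg_weights(
    sizes: Sequence[int],
    prev_costs: Sequence[float],
    curr_costs: Sequence[float],
    alpha: float = 0.5,
) -> List[float]:
    """Blend the data-size share with the share of local cost improvement."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    total_size = sum(sizes)
    # ASSUMPTION: improvement is measured as previous cost / current cost,
    # so a center whose cost dropped more receives a larger ratio.
    ratios = [prev / curr for prev, curr in zip(prev_costs, curr_costs)]
    total_ratio = sum(ratios)
    return [
        alpha * s / total_size + (1.0 - alpha) * k / total_ratio
        for s, k in zip(sizes, ratios)
    ]


def aggregate(models: Sequence[Model], weights: Sequence[float]) -> Model:
    """Coordinate-wise weighted average of the local models."""
    n_params = len(models[0])
    return [sum(w * m[p] for w, m in zip(weights, models)) for p in range(n_params)]


# Hypothetical round with two centers: the small center halved its cost,
# while the large center did not improve at all.
sizes = [100, 300]
weights = fedcostwavg_weights(sizes, prev_costs=[0.8, 0.5], curr_costs=[0.4, 0.5])
new_global = aggregate([[1.0, 0.0], [0.0, 1.0]], weights)
```

In this example FedAvg would weight the two centers 0.25 and 0.75, whereas the cost-aware weights shift influence toward the smaller center that improved more. Because both shares sum to one, the blended weights also sum to one for any `alpha` in [0, 1], and `alpha=1` recovers FedAvg.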

3 Results

The method won the challenge and significantly outperformed all other submitted methods; Tables 1 and 2 summarize the performance upon convergence.

In addition, we used the provided data (a smaller subset of the challenge data) to test the performance of FedCostWAvg against FedAvg in order to visualize the convergence behaviour. We trained and validated the model on 369 samples, which were unevenly distributed over 17 data centers. The training-validation split was […], the learning rate was […], and we ran […] epochs per federated round. Please note that computational resources were limited, so no exhaustive grid search for optimal hyperparameters was feasible, and training could not run long enough to achieve maximal performance. Figures 3, 4 and 5 depict the performance over communication rounds. Also note that the most informative comparison between methods was made in the challenge itself, with more data and many different initialisations. The comparison here serves only to visualize what the different convergence behaviours look like for one initialisation. We observe an improvement for almost all classes and metrics when using our proposed method. The exception is the DICE Enhanced Tumor metric. Note, though, that the difference is not significant and the methods have not yet converged.

3.1 Discussion

While these results already show a clear improvement over FedAvg, it is unclear whether other hyperparameters would have achieved an even better result. Due to limitations in training resources, a proper grid search was not feasible.

The simple and straightforward interpretation of the mechanism of FedCostWAvg is that it amplifies more informative updates relative to less informative ones. It can be seen as a method that acknowledges diminishing returns. A deeper insight might come from interpreting it as resembling a PID controller [2] (credit for this observation goes to David Naccache). When the federated learning problem is reframed as a control problem, the central server that does the averaging is equivalent to a control unit in a feedback loop. Extending this logic to the averaging approach, FedCostWAvg can be viewed as an approximation of a PID controller, where the newly added term corresponding to the drop in cost effectively functions as the derivative part and the data size term as the proportional one. Future research could try to include the integral term as well.

Label          DICE WT  DICE ET  DICE TC  Sens. WT  Sens. ET  Sens. TC
Mean           0.8248   0.7476   0.7932   0.8957    0.8246    0.8269
StdDev         0.1849   0.2444   0.2643   0.1738    0.2598    0.2721
Median         0.8936   0.8259   0.9014   0.948     0.9258    0.9422
25th quantile  0.8116   0.7086   0.8046   0.9027    0.7975    0.8258
75th quantile  0.9222   0.8909   0.942    0.9787    0.9772    0.9785

Table 1: Final performance of FedCostWAvg in the FETS Challenge, DICE and Sensitivity

Label          Spec WT  Spec ET  Spec TC  H95 WT   H95 ET   H95 TC   Comm. Cost
Mean           0.9981   0.9994   0.9994   11.618   27.2745  28.4825  0.723
StdDev         0.0024   0.0011   0.0014   31.758   88.566   88.2921  0.723
Median         0.9986   0.9996   0.9998   5        2.2361   3.0811   0.723
25th quantile  0.9977   0.9993   0.9995   2.8284   1.4142   1.7856   0.723
75th quantile  0.9994   0.9999   0.9999   8.6023   3.5628   7.0533   0.723

Table 2: Final performance of FedCostWAvg in the FETS Challenge, Specificity, Hausdorff95 Distance and Communication Cost

Figure 3: Comparison of the DICE Whole Tumor metric per federated round for FedCostWAvg vs. FedAvg. A higher DICE score is better; a smaller Hausdorff95 distance is better.
Figure 4: Comparison of the DICE Enhanced Tumor metric per federated round for FedCostWAvg vs. FedAvg. A higher DICE score is better; a smaller Hausdorff95 distance is better.
Figure 5: Comparison of the DICE Tumor Core metric per federated round for FedCostWAvg vs. FedAvg. A higher DICE score is better; a smaller Hausdorff95 distance is better.

4 Conclusion

In this paper, we describe a method for model aggregation developed for the MICCAI Federated Tumor Segmentation Challenge (FETS). The novelty of the method lies in including local cost improvements when calculating the weights for averaging models that are trained at different centers. The approach is validated on a brain tumor segmentation task and achieves the best performance among all participating teams.

Acknowledgements

Bjoern Menze, Benedikt Wiestler, and Florian Kofler are supported through the SFB 824, subproject B12. Supported by Deutsche Forschungsgemeinschaft (DFG) through TUM International Graduate School of Science and Engineering (IGSSE), GSC 81. Suprosanna Shit and Ivan Ezhov are supported by the Translational Brain Imaging Training Network (TRABIT) under the European Union’s ‘Horizon 2020’ research & innovation program (Grant agreement ID: 765148). With the support of the Technical University of Munich – Institute for Advanced Study, funded by the German Excellence Initiative. Ivan Ezhov is also supported by the International Graduate School of Science and Engineering (IGSSE). Johannes C. Paetzold and Suprosanna Shit are supported by the Graduate School of Bioengineering, Technical University of Munich.

References

  • [1] S. Bakas, H. Akbari, A. Sotiras, M. Bilello, M. Rozycki, J. S. Kirby, J. B. Freymann, K. Farahani, and C. Davatzikos (2017) Advancing the cancer genome atlas glioma mri collections with expert segmentation labels and radiomic features. Scientific data 4 (1), pp. 1–13.
  • [2] R. E. Bellman (2015) Adaptive control processes. Princeton university press.
  • [3] P. Bilic, P. Christ, E. Vorontsov, et al. (2019) The liver tumor segmentation benchmark (lits). arXiv preprint arXiv:1901.04056.
  • [4] Healthcareitnews.com Tens of thousands of patient records posted to dark web. Note: https://www.healthcareitnews.com/news/tens-thousands-patient-records-posted-dark-web, accessed: 2021-07-16.
  • [5] F. Kofler, C. Berger, D. Waldmannstetter, J. Lipkova, I. Ezhov, G. Tetteh, J. Kirschke, C. Zimmer, B. Wiestler, and B. H. Menze (2020) BraTS toolkit: translating brats brain tumor segmentation algorithms into clinical and scientific practice. Frontiers in neuroscience 14, pp. 125.
  • [6] J. Konečný, B. McMahan, and D. Ramage (2015) Federated optimization: distributed optimization beyond the datacenter. CoRR abs/1511.03575.
  • [7] J. Konečný, H. B. McMahan, F. X. Yu, P. Richtárik, A. T. Suresh, and D. Bacon (2016) Federated learning: strategies for improving communication efficiency. CoRR abs/1610.05492.
  • [8] A. L. Maas, A. Y. Hannun, A. Y. Ng, et al. (2013) Rectifier nonlinearities improve neural network acoustic models. In Proc. icml, Vol. 30, pp. 3.
  • [9] M. McCloskey and N. J. Cohen (1989) Catastrophic interference in connectionist networks: the sequential learning problem. In Psychology of Learning and Motivation, G. H. Bower (Ed.), Vol. 24, pp. 109–165. ISSN 0079-7421.
  • [10] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics, pp. 1273–1282.
  • [11] J. C. Paetzold, J. McGinnis, S. Shit, I. Ezhov, P. Büschl, C. Prabhakar, M. I. Todorov, A. Sekuboyina, G. Kaissis, A. Ertürk, et al. (2021) Whole brain vessel graphs: a dataset and benchmark for graph learning and neuroscience (vesselgraph). arXiv preprint arXiv:2108.13233.
  • [12] S. Pati, U. Baid, M. Zenk, B. Edwards, M. Sheller, G. A. Reina, P. Foley, A. Gruzdev, J. Martin, S. Albarqouni, et al. (2021) The federated tumor segmentation (fets) challenge. arXiv preprint arXiv:2105.05874.
  • [13] K. Payette, P. de Dumast, H. Kebiri, I. Ezhov, J. C. Paetzold, S. Shit, A. Iqbal, R. Khan, R. Kottke, P. Grehten, et al. (2020) A comparison of automatic multi-tissue segmentation methods of the human fetal brain using the feta dataset. arXiv e-prints, pp. arXiv–2010.
  • [14] G. A. Reina, A. Gruzdev, P. Foley, O. Perepelkina, M. Sharma, I. Davidyuk, I. Trushkin, M. Radionov, A. Mokrov, D. Agapov, et al. (2021) OpenFL: an open-source framework for federated learning. arXiv preprint arXiv:2105.06413.
  • [15] N. Rieke, J. Hancox, W. Li, F. Milletari, H. R. Roth, S. Albarqouni, S. Bakas, M. N. Galtier, B. A. Landman, K. Maier-Hein, et al. (2020) The future of digital health with federated learning. NPJ digital medicine 3 (1), pp. 1–7.
  • [16] A. K. Sahu, T. Li, M. Sanjabi, M. Zaheer, A. Talwalkar, and V. Smith (2018) On the convergence of federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127 3, pp. 3.
  • [17] A. Sekuboyina et al. (2020) Verse: a vertebrae labelling and segmentation benchmark. arXiv preprint arXiv:2001.09193.
  • [18] M. J. Sheller, B. Edwards, G. A. Reina, J. Martin, S. Pati, A. Kotrotsou, M. Milchenko, W. Xu, D. Marcus, R. R. Colen, et al. (2020) Federated learning in medicine: facilitating multi-institutional collaborations without sharing patient data. Scientific reports 10 (1), pp. 1–12.
  • [19] D. Ulyanov, A. Vedaldi, and V. Lempitsky (2016) Instance normalization: the missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022.
  • [20] M. Yurochkin, M. Agarwal, S. Ghosh, K. Greenewald, N. Hoang, and Y. Khazaeni (2019) Bayesian nonparametric federated learning of neural networks. In International Conference on Machine Learning, pp. 7252–7261.