FedSLD: Federated Learning with Shared Label Distribution for Medical Image Classification

10/15/2021
by Jun Luo, et al.
University of Pittsburgh
UPMC

Machine learning in medical research, by nature, requires careful compliance with data privacy regulations, which makes it difficult to train a machine learning model on data gathered from different medical centers. Failing to leverage data of the same kind may result in poor generalizability of the trained model. Federated learning (FL) enables multiple medical centers to collaboratively train a joint model while keeping their data decentralized. However, federated optimization often suffers from the heterogeneity of the data distribution across medical centers. In this work, we propose Federated Learning with Shared Label Distribution (FedSLD) for classification tasks, a method that assumes knowledge of the label distributions of all participating clients in the federation. Given this knowledge, FedSLD adjusts the contribution of each data sample to the local objective during optimization, mitigating the instability brought by data heterogeneity across clients. We conduct extensive experiments on four publicly available image datasets with different types of non-IID data distributions. Our results show that FedSLD achieves better convergence performance than the compared leading FL optimization algorithms, increasing the test accuracy by up to 5.50 percentage points.


1 Introduction

Deep learning (DL) is well known to require a large amount of data to train robust, generalizable models. For DL in medical research [wang2016deep, oh2020deep, lee2017deep, rajpurkar2017chexnet], large datasets can be difficult to obtain because the data collected by medical centers and hospitals are often privacy-sensitive. Sharing raw data between institutions is therefore usually constrained by regulations such as the Health Insurance Portability and Accountability Act (HIPAA) in the United States and the General Data Protection Regulation (GDPR) in Europe.

The recent emergence of federated learning (FL) [mcmahan2017communication, kairouz2019advances, li2020federated] offers a feasible solution to this issue. FL is a distributed machine learning paradigm in which only model weights are shared among the clients participating in the federation, while the data stay decentralized. In medical research, by bringing different hospitals and medical centers into a federation, researchers can collaboratively train a model on datasets from siloed institutions besides their own [andreux2020siloed, roth2020federated, wang2020automated, sarhan2020fairness, qayyum2021collaborative].

However, the federated setting introduces a major new challenge: statistical data heterogeneity across the participating clients [li2020federated, sahu2018convergence, li2018federated, karimireddy2020scaffold, ghosh2020efficient]. Data heterogeneity means that the data collected by different clients are not identically distributed (non-IID), which is common for medical datasets from different sites for various reasons, including different data acquisition protocols and different local demographics [rieke2020future]. Data heterogeneity may significantly increase the number of communication rounds in federated training and degrade the performance of the federated model on certain clients (e.g., medical institutes) [sahu2018convergence], which can further reduce their incentive to participate in the federation.

In this work, we propose a federated learning algorithm for classification tasks, Federated Learning with Shared Label Distribution (FedSLD), which uses information about the clients’ label distributions to estimate a general prior label distribution for the entire federation. We claim that FedSLD can mitigate the training instability caused by statistical heterogeneity in cross-silo FL, such as in medical research. Although the algorithm does not access the clients’ data, we assume it is legitimate for clients to share the number of samples in each class, which is often the case in cross-silo FL such as medical applications. More specifically, our contribution in this work is two-fold:

  1. We propose a new FL algorithm for medical image classification, Federated Learning with Shared Label Distribution (FedSLD), for robust training on non-IID data.

  2. We demonstrate that FedSLD achieves better performance than leading FL algorithms through extensive experiments on four publicly available datasets (including two benchmark datasets) under pathological non-IID and practical non-IID data partitions.

The rest of the paper is organized as follows: Section 2 provides details of the proposed FedSLD; experiments and results are shown in Section 3; Section 4 discusses our findings and concludes the study.

2 Method

Laws and regulations on data privacy constrain direct access to the raw data. Yet other information about a dataset can still be shared in federated learning. For instance, FedAvg assumes knowledge of the number of samples on each client: in its aggregation step, FedAvg computes a weighted average of the updated local models for the next round, and by default the averaging weights are the normalized numbers of samples on the clients.

In this work, we focus on classification tasks and assume that it is legitimate to obtain the label distribution of each client, namely the number of samples from every class. Using these label distributions, we compute an estimate of the prior label distribution for the entire federation. For FL in medical applications, the label distributions of different medical silos can differ drastically due to regional demographics. Knowledge of the clients’ label distributions therefore helps characterize the non-IID data in the federation.

To formulate the process, consider a federation with non-IID data. For a given data sample $(x, y)$, where $x$ stands for the data and $y$ represents the label, the probability that it appears in the dataset of client $i$, $P_i(x, y)$, does not necessarily equal the probability that it appears in the dataset of client $j$, $P_j(x, y)$. Factorizing the joint probability, we have $P_i(x, y) = P_i(x \mid y)\,P_i(y)$. More often than not, especially in the medical imaging domain, non-IID data implies that both the label-conditioned probabilities, $P_i(x \mid y)$, and the marginal label distributions, $P_i(y)$, differ across clients. In this work, we focus on acquiring the information that reflects the marginal label distribution of each client, $P_i(y)$, to compute an estimate of the prior label distribution for the entire federation.

We define the estimate of the prior of class $c$ for the federation as the sum of the numbers of samples of class $c$ over all clients divided by the sum of the total numbers of samples over all clients. This is shown in Equation (1), where $\hat{p}_c$ is the estimated prior of class $c$, $n_{k,c}$ is the number of samples from class $c$ on client $k$, $n_k$ is the total number of samples on client $k$, and $K$ is the number of clients.

$$\hat{p}_c = \frac{\sum_{k=1}^{K} n_{k,c}}{\sum_{k=1}^{K} n_k} \tag{1}$$
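The pooled estimate of the federation prior (per-class sample counts summed over clients, divided by the total sample count) can be sketched in a few lines of Python. This is a minimal illustration, not the authors’ implementation; the function name `estimate_prior` and the dictionary-of-counts input format are assumptions made for the example.

```python
from typing import Dict, List


def estimate_prior(class_counts: List[Dict[int, int]]) -> Dict[int, float]:
    """Estimate the federation-wide label prior of Equation (1).

    class_counts[k][c] holds n_{k,c}, the number of class-c samples on client k
    (hypothetical input format). Only these counts are shared, never the data.
    """
    total = sum(sum(counts.values()) for counts in class_counts)  # sum_k n_k
    classes = sorted({c for counts in class_counts for c in counts})
    return {
        # sum_k n_{k,c} / sum_k n_k; clients missing class c contribute 0
        c: sum(counts.get(c, 0) for counts in class_counts) / total
        for c in classes
    }


# Two hypothetical clients with opposite majority classes.
counts = [{0: 90, 1: 10}, {0: 20, 1: 80}]
prior = estimate_prior(counts)
print(prior)  # {0: 0.55, 1: 0.45}
```

Because the counts are pooled before normalizing, a client with more samples influences the prior proportionally more, mirroring how FedAvg weights clients by their sample counts.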

During the local update of the current model on a client, given a batch of data $\mathcal{B} = \{(x_j, y_j)\}_{j=1}^{B}$, where $B$ is the batch size, we first compute the label distribution of this batch as in Equation (2), where $q_c$ is the proportion of class $c$ in the batch and $\mathbb{1}(\cdot)$ is the indicator function, equal to 1 if its argument is true and 0 otherwise. In essence, Equation (2) normalizes the number of class-$c$ samples in the batch by the batch size.

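$$q_c = \frac{1}{B} \sum_{j=1}^{B} \mathbb{1}(y_j = c) \tag{2}$$

$$\ell(\theta_k; \mathcal{B}) = -\frac{1}{B} \sum_{j=1}^{B} \frac{\hat{p}_{y_j}}{q_{y_j}} \log f(x_j; \theta_k)_{y_j} \tag{3}$$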

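Then, we define the batch loss as a weighted cross-entropy loss, shown in Equation (3), where $\ell(\theta_k; \mathcal{B})$ denotes the batch loss, $\theta_k$ represents the copy of the model on client $k$, and $f(x_j; \theta_k)_{y_j}$ is the probability that this model assigns to the true label $y_j$. By doing this, we enforce that each class contributes to the local objective in proportion to its share of the underlying label distribution across the federation, as estimated by $\hat{p}_c$, rather than in proportion to its frequency in the local batch.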
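To make the reweighting concrete, here is a minimal Python sketch of the batch label distribution and the weighted cross-entropy. It assumes that each sample of class $c$ is weighted by the ratio of the estimated federation prior to the batch proportion of $c$, averaged over the batch; the exact normalization used by the authors is an assumption, and the function names `batch_label_distribution` and `fedsld_batch_loss` are hypothetical. Predictions are passed as plain probability lists to keep the example dependency-free.

```python
import math
from collections import Counter
from typing import Dict, Sequence


def batch_label_distribution(labels: Sequence[int]) -> Dict[int, float]:
    """Equation (2): q_c = (1/B) * sum_j 1(y_j == c), for classes present in the batch."""
    batch_size = len(labels)
    return {c: n / batch_size for c, n in Counter(labels).items()}


def fedsld_batch_loss(
    probs: Sequence[Sequence[float]],  # probs[j][c]: predicted probability of class c for sample j
    labels: Sequence[int],
    prior: Dict[int, float],           # estimated federation prior p_hat from Equation (1)
) -> float:
    """Assumed form of Equation (3): cross-entropy with per-sample weight p_hat_y / q_y."""
    q = batch_label_distribution(labels)
    total = 0.0
    for p, y in zip(probs, labels):
        weight = prior[y] / q[y]           # > 1 if class y is rarer in the batch than in the federation
        total += -weight * math.log(p[y])  # weighted cross-entropy term of this sample
    return total / len(labels)


# A batch dominated by class 0, while the (hypothetical) federation prior is balanced.
probs = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.4, 0.6]]
labels = [0, 0, 0, 1]
prior = {0: 0.5, 1: 0.5}
loss = fedsld_batch_loss(probs, labels, prior)
```

Grouping the sum by class shows what this weighting achieves: over the classes present in the batch, the loss equals $\sum_c \hat{p}_c$ times the mean cross-entropy of the class-$c$ samples, so in this example each class contributes half of the loss even though class 0 makes up three quarters of the batch. In PyTorch, the same weighting could be applied by multiplying the per-sample output of `torch.nn.functional.cross_entropy(logits, targets, reduction="none")` by the weights before averaging.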

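Input: Initialized model parameters $\theta^0$, number of clients $K$, number of local epochs $E$, batch size $B$, learning rate $\eta$, number of rounds $T$.

1: For $k = 1, \dots, K$, acquire $n_{k,1}, \dots, n_{k,C}$, client $k$’s numbers of samples of each of the $C$ classes.
2: For $c = 1, \dots, C$, compute $\hat{p}_c$ by Equation (1). // compute estimated prior label distribution.
3: for $t = 0, \dots, T-1$ do
4:      $\theta_k \leftarrow \theta^t$ for all $k$ // broadcast the model parameters.
5:     for $k = 1, \dots, K$ in parallel do
6:         for $\mathcal{B}$ in all minibatches of the $E$ local epochs do
7:              compute $q$ on $\mathcal{B}$ by Equation (2)
8:              compute $\ell(\theta_k; \mathcal{B})$ by Equation (3)
9:              $\theta_k \leftarrow \theta_k - \eta \nabla_{\theta_k} \ell(\theta_k; \mathcal{B})$
10:         end for
11:     end for
12:      $\theta^{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} \theta_k$, where $n = \sum_{k=1}^{K} n_k$ // aggregate the model updates
13: end for
14: return $\theta^T$
Algorithm 1 FedSLD.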

We follow the aggregation step of a typical FL algorithm such as FedAvg: we compute a weighted average of the updated models from all clients, with weights proportional to the number of samples on each client. The complete procedure is summarized in Algorithm 1.
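A minimal sketch of this sample-weighted aggregation, assuming model parameters are flattened into plain lists of floats (a simplification; a PyTorch model would instead average each tensor of its `state_dict`). The function name `aggregate` is hypothetical; FedSLD changes only the local loss, so this step is the same as in FedAvg.

```python
from typing import List, Sequence


def aggregate(client_params: Sequence[Sequence[float]], num_samples: Sequence[int]) -> List[float]:
    """Weighted average theta = sum_k (n_k / n) * theta_k over all clients."""
    n = sum(num_samples)
    averaged = [0.0] * len(client_params[0])
    for params, n_k in zip(client_params, num_samples):
        weight = n_k / n  # each client's share of all training samples
        for i, value in enumerate(params):
            averaged[i] += weight * value
    return averaged


# Two hypothetical clients holding 100 and 300 samples.
print(aggregate([[1.0, 2.0], [3.0, 6.0]], [100, 300]))  # [2.5, 5.0]
```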

3 Experiments and Results

In this section, we evaluate the performance of the proposed FedSLD through experiments on four publicly available datasets (including two benchmark imaging datasets) and compare it with two leading FL algorithms: FedAvg [mcmahan2017communication], which averages the local updates of the global model, and FedProx [li2018federated], which adds a proximal term to the local objective to make performance more robust on non-IID data. To evaluate the general performance of the algorithms, we compute test accuracies and assess empirical convergence by plotting the training loss and test accuracy curves. In addition, we examine the fairness of the methods following recent work [li2019fair]. More details on the metrics are given in Section 3.1.

3.1 Experimental setup

Datasets.

We conduct experiments on two benchmark image datasets: MNIST [lecun-mnisthandwrittendigit-2010], a 10-class handwritten digit classification dataset, and CIFAR10 [krizhevsky2009learning], a 10-class dataset of animal and vehicle images. We further evaluate the methods on two medical image datasets from the MedMNIST collection [yang2021medmnist]: OrganMNIST(axial), an 11-class dataset of abdominal CT images derived from a liver tumor segmentation benchmark [bilic2019liver], and PathMNIST, a 9-class dataset of colorectal cancer histology images [kather2019predicting]. We partition each dataset into a training set and a test set and ensure that they share the same label distribution.

Figure 1: Pathological non-IID distribution. The first row depicts the locations of the portions taken from each class. The second row depicts the dataset composition of each client.
Figure 2: Practical non-IID distribution. The first row depicts the locations of the portions taken from each class. The second row depicts the dataset composition of each client.

Two non-IID settings. We partition each dataset according to two different non-IID settings. 1) A pathological non-IID setting. Following [mcmahan2017communication], we assign two random classes to each client, and a random number of images from these two classes is assigned to that client. We set the number of clients to 12 to mimic a cross-silo FL setting. Figure 1 shows the pathological non-IID setting in more detail. 2) A practical non-IID setting. We randomly partition each class in the training set into 12 shards (corresponding to the 12 clients in the federation): ten shards of equal size and two further shards of different sizes. For each client, we randomly assign one shard from each class, so that the client possesses images from all classes, with more images from some classes and fewer from others. This setting is closer to real-world medical applications: datasets held at medical centers often contain a variety of classes, but centers in different regions may observe a different prevalence of each class due to local demographics. Consequently, the datasets at medical centers are often imbalanced, with different majority classes. Figure 2 shows the data distribution of all four datasets in more detail.
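Both partitioning schemes can be sketched with the Python standard library. This is an illustrative reading of the description above, not the authors’ code: the function names, the `max_per_class` sampling rule in the pathological setting, and the shard fractions in `PRACTICAL_FRACTIONS` are all hypothetical, since the exact values are not given here. Both functions return, for each client, a list of training-set indices.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence

# Hypothetical shard fractions for the practical setting: ten equal shards plus a
# smaller and a larger one (the exact fractions used in the paper are not given here).
PRACTICAL_FRACTIONS = [0.08] * 10 + [0.05, 0.15]


def _indices_by_class(labels: Sequence[int], rng: random.Random) -> Dict[int, List[int]]:
    """Group sample indices by label and shuffle each group."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    for indices in by_class.values():
        rng.shuffle(indices)
    return by_class


def pathological_partition(labels: Sequence[int], num_clients: int = 12,
                           max_per_class: int = 200, seed: int = 0) -> List[List[int]]:
    """Each client gets two random classes and a random number of images from each."""
    rng = random.Random(seed)
    pool = _indices_by_class(labels, rng)
    classes = sorted(pool)
    clients: List[List[int]] = []
    for _ in range(num_clients):
        client: List[int] = []
        for c in rng.sample(classes, 2):
            take = min(len(pool[c]), rng.randint(1, max_per_class))  # hypothetical sampling rule
            client.extend(pool[c][:take])
            del pool[c][:take]  # images are assigned without replacement
        clients.append(client)
    return clients


def practical_partition(labels: Sequence[int], fractions: Sequence[float] = PRACTICAL_FRACTIONS,
                        seed: int = 0) -> List[List[int]]:
    """Split every class into len(fractions) shards; each client gets one shard per class."""
    rng = random.Random(seed)
    clients: List[List[int]] = [[] for _ in fractions]
    for _, indices in sorted(_indices_by_class(labels, rng).items()):
        cuts, acc = [0], 0.0
        for f in fractions[:-1]:  # cut points at cumulative fractions of this class
            acc += f
            cuts.append(round(acc * len(indices)))
        cuts.append(len(indices))
        shards = [indices[a:b] for a, b in zip(cuts, cuts[1:])]
        rng.shuffle(shards)  # randomly decide which client receives which shard of this class
        for client, shard in zip(clients, shards):
            client.extend(shard)
    return clients


labels = [i % 10 for i in range(1000)]  # toy 10-class label vector, 100 images per class
patho = pathological_partition(labels, max_per_class=20, seed=1)
practical = practical_partition(labels, seed=1)
```

In the pathological sketch a client sees at most two classes, while in the practical sketch every client sees all classes but with class-dependent shard sizes, which is what produces the different majority classes across clients.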

Implementation details.

We use the classic four-layer CNN with two 5×5 convolutional layers and two fully connected layers (the hidden layer has 500 units). We use a batch size of 256, 5 local epochs, and a learning rate of 0.01. We train the model for 80 rounds under the practical non-IID partition and for 160 rounds under the pathological non-IID partition. All experiments are run on an NVIDIA Tesla V100 GPU and implemented in PyTorch [NEURIPS2019_9015] and PySyft [ryffel2018generic].

Metrics. We compute two types of test accuracy for each setting. 1) The Best Mean Client Test Accuracy (BMCTA), referred to as BMTA in [huang2021personalized]: in each round we average the test accuracies of all clients, and BMCTA is the highest such mean over all rounds. 2) The Best Test Accuracy (BTA): we treat the test sets of all clients as one combined test set and report the highest test accuracy on it over all rounds. We also investigate the convergence of the methods by plotting the training loss and test accuracy curves. In addition, following [li2019fair], we examine the fairness of the methods by applying Gaussian kernel density estimation to the client test accuracies. Higher density at higher accuracy reflects a better result.
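The three measurements can be sketched in plain Python as follows. The function names and input layouts (per-round lists of client accuracies or correct-prediction counts) are assumptions for illustration. Note that BMCTA weights every client equally, whereas BTA pools predictions and therefore weights clients by their test-set size. The density estimate uses a Gaussian kernel with Scott’s rule bandwidth, which is intended to match the default of `scipy.stats.gaussian_kde` for 1-D data; the bandwidth choice in the original experiments is not stated.

```python
import math
from typing import List, Optional, Sequence


def bmcta(client_acc_per_round: Sequence[Sequence[float]]) -> float:
    """Best Mean Client Test Accuracy: highest per-round mean of the client accuracies."""
    return max(sum(accs) / len(accs) for accs in client_acc_per_round)


def bta(correct_per_round: Sequence[Sequence[int]], test_sizes: Sequence[int]) -> float:
    """Best Test Accuracy on the union of all client test sets, maximized over rounds."""
    total = sum(test_sizes)
    return max(sum(correct) / total for correct in correct_per_round)


def gaussian_kde(samples: Sequence[float], grid: Sequence[float],
                 bandwidth: Optional[float] = None) -> List[float]:
    """1-D Gaussian kernel density estimate of client accuracies (needs >= 2 samples)."""
    n = len(samples)
    mean = sum(samples) / n
    std = math.sqrt(sum((s - mean) ** 2 for s in samples) / (n - 1))
    h = bandwidth if bandwidth is not None else std * n ** (-1 / 5)  # Scott's rule in 1-D
    if h <= 0:
        raise ValueError("bandwidth must be positive (are all accuracies identical?)")
    norm = 1.0 / (n * h * math.sqrt(2 * math.pi))
    return [norm * sum(math.exp(-0.5 * ((x - s) / h) ** 2) for s in samples) for x in grid]


# Hypothetical results for two clients over three rounds.
print(bmcta([[0.5, 0.75], [1.0, 0.5], [0.25, 0.25]]))  # 0.75
print(bta([[40, 30], [45, 40]], test_sizes=[50, 50]))   # 0.85
```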

| Dataset | BMCTA (%) FedAvg [mcmahan2017communication] | BMCTA (%) FedProx [li2018federated] | BMCTA (%) FedSLD (Ours) | BTA (%) FedAvg [mcmahan2017communication] | BTA (%) FedProx [li2018federated] | BTA (%) FedSLD (Ours) |
|---|---|---|---|---|---|---|
| MNIST | 95.60 | 95.71 | 95.74 | 95.92 | 95.98 | 96.03 |
| CIFAR10 | 51.50 | 51.39 | 50.81 | 51.39 | 51.24 | 50.71 |
| OrganMNIST(axial) | 59.52 | 59.44 | 59.70 | 64.99 | 65.10 | 66.13 |
| PathMNIST | 56.44 | 56.62 | 57.94 | 57.54 | 57.56 | 59.11 |
Table 1: The Best Mean Client Test Accuracy (BMCTA) and Best Test Accuracy (BTA) for the pathological non-IID setting
Figure 3: The convergence and fairness performance under the pathological non-IID setting. We measure the fairness using Gaussian kernel density estimation. Higher density concentrated at a higher accuracy reflects a better result.

3.2 Results

Under the pathological non-IID setting, Table 1 and Fig. 3 show that on MNIST and the two medical datasets, the proposed FedSLD performs better, improving the test accuracy by up to 1.57 percentage points, and the kernel density estimates show that FedSLD’s client test accuracies have slightly higher density, concentrated at higher test accuracy. On CIFAR10, FedSLD remains competitive with FedAvg and FedProx, trailing them by at most 0.69 percentage points.

| Dataset | BMCTA (%) FedAvg [mcmahan2017communication] | BMCTA (%) FedProx [li2018federated] | BMCTA (%) FedSLD (Ours) | BTA (%) FedAvg [mcmahan2017communication] | BTA (%) FedProx [li2018federated] | BTA (%) FedSLD (Ours) |
|---|---|---|---|---|---|---|
| MNIST | 93.41 | 93.45 | 95.56 | 94.15 | 94.20 | 95.85 |
| CIFAR10 | 32.07 | 31.98 | 37.48 | 35.46 | 35.38 | 37.79 |
| OrganMNIST(axial) | 82.32 | 81.53 | 84.75 | 85.69 | 85.54 | 87.37 |
| PathMNIST | 52.70 | 52.77 | 53.87 | 57.38 | 57.72 | 57.90 |
Table 2: BMCTA and BTA for the practical non-IID setting

Under the practical non-IID setting, Table 2 and Figure 4 show that the proposed FedSLD outperforms FedAvg and FedProx on every dataset, improving BMCTA by 1.10 to 5.50 percentage points and BTA by 0.18 to 2.41 percentage points. In addition, FedSLD achieves better convergence behavior on the MNIST and OrganMNIST(axial) datasets. The fairness plots reveal that FedSLD not only increases the overall performance of the entire federation but also reduces the variance of the client test accuracies on the MNIST and PathMNIST datasets, which implies fairer training. On the CIFAR10 and OrganMNIST(axial) datasets, there is a clear decrease in density at low accuracy and an increase in density at high accuracy, which explains the improvement in BMCTA.
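The improvement ranges quoted for both settings are pairwise gains of FedSLD over each baseline. As a reader-side arithmetic check (not part of the method), the short script below recomputes them from the values in Tables 1 and 2; the dictionary names and the `gains` helper are illustrative. It also shows that the 5.50-point upper bound is the gain over FedProx on CIFAR10 (the gain over FedAvg there is 5.41 points).

```python
# Accuracies (%) transcribed from Tables 1 and 2, as (FedAvg, FedProx, FedSLD).
TABLE1 = {  # pathological non-IID; CIFAR10 omitted, matching the claim for that setting
    "BMCTA": {"MNIST": (95.60, 95.71, 95.74), "OrganMNIST(axial)": (59.52, 59.44, 59.70),
              "PathMNIST": (56.44, 56.62, 57.94)},
    "BTA": {"MNIST": (95.92, 95.98, 96.03), "OrganMNIST(axial)": (64.99, 65.10, 66.13),
            "PathMNIST": (57.54, 57.56, 59.11)},
}
TABLE2 = {  # practical non-IID
    "BMCTA": {"MNIST": (93.41, 93.45, 95.56), "CIFAR10": (32.07, 31.98, 37.48),
              "OrganMNIST(axial)": (82.32, 81.53, 84.75), "PathMNIST": (52.70, 52.77, 53.87)},
    "BTA": {"MNIST": (94.15, 94.20, 95.85), "CIFAR10": (35.46, 35.38, 37.79),
            "OrganMNIST(axial)": (85.69, 85.54, 87.37), "PathMNIST": (57.38, 57.72, 57.90)},
}


def gains(rows):
    """FedSLD minus each baseline, in percentage points (rounded to 2 decimals)."""
    return [round(ours - base, 2) for avg, prox, ours in rows.values() for base in (avg, prox)]


print(max(gains(TABLE1["BMCTA"]) + gains(TABLE1["BTA"])))        # 1.57
print(min(gains(TABLE2["BMCTA"])), max(gains(TABLE2["BMCTA"])))  # 1.1 5.5
print(min(gains(TABLE2["BTA"])), max(gains(TABLE2["BTA"])))      # 0.18 2.41
```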

Figure 4: The convergence and fairness performance under the practical non-IID setting.

4 Conclusion

In this work, we proposed a new FL algorithm for medical image classification, Federated Learning with Shared Label Distribution (FedSLD), which mitigates the effect of non-IID data by leveraging the clients’ label distributions. We conducted extensive experiments on four publicly available datasets under two types of non-IID settings and demonstrated that FedSLD outperforms the compared leading FL algorithms in most settings and encourages fairer performance across the participating clients.

5 Compliance with ethical standards

Ethical approval was not required, as this study used previously collected and deidentified data (including medical imaging data) available in public repositories.

Acknowledgements.

This work was supported in part by National Institutes of Health (NIH) (1R01CA193603 and 1R01CA218405), National Science Foundation (NSF) (CICI: SIVD: 2115082), and the grant 1R01EB032896 as part of the NSF/NIH Smart Health and Biomedical Research in the Era of Artificial Intelligence and Advanced Data Science Program. This work used the Extreme Science and Engineering Discovery Environment (XSEDE), which is supported by NSF grant number ACI-1548562. Specifically, it used the Bridges-2 system, which is supported by NSF award number ACI-1928147, at the Pittsburgh Supercomputing Center.

References