Deep learning (DL) is well known for requiring a large amount of data to train robust, generalizable models. For DL in medical research [wang2016deep, oh2020deep, lee2017deep, rajpurkar2017chexnet], large datasets can be difficult to obtain because the data collected by medical centers and hospitals are often privacy-sensitive. Sharing raw data between institutions is therefore usually constrained by regulations such as the Health Insurance Portability and Accountability Act (HIPAA) in the United States and the General Data Protection Regulation (GDPR) in Europe.
The recent emergence of federated learning (FL) [mcmahan2017communication, kairouz2019advances, li2020federated] offers a feasible solution to this issue. FL is a distributed machine learning setting in which only the model weights are shared among the participating clients in the federation, while the data remain decentralized. In medical research, by bringing different hospitals and medical centers into the federation, researchers can collaboratively train a model using datasets from siloed institutions in addition to their own [andreux2020siloed, roth2020federated, wang2020automated, sarhan2020fairness, qayyum2021collaborative].
However, the federated setting introduces a major new challenge: statistical data heterogeneity across the participating clients [li2020federated, sahu2018convergence, li2018federated, karimireddy2020scaffold, ghosh2020efficient]. Data heterogeneity means that the data collected by different clients are not independent and identically distributed (non-IID). This often occurs in medical datasets from different sites, for reasons including different data acquisition protocols and different local demographics [rieke2020future]. Data heterogeneity may lead to a significant increase in the number of communication rounds needed for federated training and to inferior performance of the federated model on certain clients (e.g., medical institutes) [sahu2018convergence], which can in turn reduce their incentive to participate in the federation.
In this work, we propose a federated learning algorithm for classification tasks, Federated Learning with Shared Label Distribution (FedSLD), which uses information about the clients’ label distributions to estimate a general prior label distribution for the entire federation. We claim that FedSLD can mitigate the training instability caused by statistical heterogeneity in cross-silo FL, such as in medical research. While the algorithm does not access the clients’ data, we assume it is legitimate for the clients to share the number of samples in each class, which is often the case in cross-silo FL such as in medical applications. More specifically, our contribution in this work is two-fold:
We propose a new FL algorithm for medical image classification: Federated Learning with Shared Label Distribution (FedSLD), for robust training with non-IID data.
We demonstrate that the proposed FedSLD achieves better performance than the leading FL algorithms by conducting extensive experiments on four publicly available datasets (including two benchmark datasets) under pathological non-IID and practical non-IID data partitions.
Laws and regulations on data privacy restrict direct access to the raw data. Yet other information about the datasets can be shared in federated learning. For instance, FedAvg assumes knowledge of the number of samples in each client: in the aggregation step, the algorithm computes a weighted average of the updated model copies for the next round, and the weights used for the averaging are, by default, the normalized numbers of samples in each client.
In this work, we focus on classification tasks and assume it is legitimate to obtain the label distribution of each client, namely the number of samples from every class. We use this knowledge of the label distributions to compute an estimate of the prior label distribution for the entire federation. For FL in medical applications, the label distributions of different medical silos can often be drastically different due to regional demographics. Knowledge of the clients’ label distributions helps us better understand the non-IID data in the federation.
To formulate the process, let us consider a federation with non-IID data. For a given data sample $(x, y)$, where $x$ stands for the data and $y$ represents the label, the probability that it appears in the dataset of client $i$, $P_i(x, y)$, does not necessarily equal the probability that it appears in the dataset of client $j$, $P_j(x, y)$. By Bayes’ theorem, we have $P_i(x, y) = P_i(x \mid y) P_i(y)$. More often than not, especially in the medical imaging domain, non-IID data implies that both the label-conditioned probabilities, $P_i(x \mid y)$, and the marginal label distributions, $P_i(y)$, differ across clients. In this work, we focus on acquiring the information reflecting the marginal label distribution of each client ($P_i(y)$), to compute the estimate of the prior label distribution for the entire federation.
We define the estimate of the prior of class $c$ for the federation as the sum of the numbers of samples of class $c$ over all clients divided by the sum of the total numbers of samples over all clients. This is shown in Equation (1), where $\hat{p}_c$ is the estimated prior of class $c$, $n_i^c$ is the number of samples from class $c$ on client $i$, $n_i$ is the total number of samples on client $i$, and $N$ is the number of clients.

$$\hat{p}_c = \frac{\sum_{i=1}^{N} n_i^c}{\sum_{i=1}^{N} n_i} \tag{1}$$
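The pooled estimate described above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the function name `estimate_label_prior`, the dictionary-based count format, and the three-client toy federation are all hypothetical.

```python
from typing import Dict, List


def estimate_label_prior(client_counts: List[Dict[int, int]],
                         num_classes: int) -> List[float]:
    """Pool per-client class counts into one federation-wide label prior."""
    totals = [0] * num_classes
    for counts in client_counts:
        for label, n in counts.items():
            totals[label] += n
    grand_total = sum(totals)
    if grand_total == 0:
        raise ValueError("no samples reported by any client")
    # Prior of class c = (samples of class c over all clients) / (all samples).
    return [t / grand_total for t in totals]


# Hypothetical federation: three clients, three classes.
client_counts = [
    {0: 80, 1: 20},   # client 0 holds classes 0 and 1
    {1: 50, 2: 50},   # client 1 holds classes 1 and 2
    {0: 10, 2: 90},   # client 2 holds classes 0 and 2
]
prior = estimate_label_prior(client_counts, num_classes=3)
```

Only the per-class counts leave each client in this sketch; the raw samples are never touched, which matches the sharing assumption stated in the text.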
During the local update of the current model on a client, given a batch of data $\{(x_b, y_b)\}_{b=1}^{B}$, where $B$ is the batch size, we first compute the label distribution in this batch as in Equation (2), where $q_c$ represents the batch label distribution and $\mathbb{1}(\cdot)$ is the indicator function, equal to 1 if its argument is true and 0 otherwise. In essence, Equation (2) computes the proportion of class $c$ samples in the batch by normalizing the number of class $c$ samples in this batch.

$$q_c = \frac{1}{B} \sum_{b=1}^{B} \mathbb{1}(y_b = c) \tag{2}$$
Then, we define the batch loss as a weighted cross-entropy loss, shown in Equation (3), where $\mathcal{L}(w_i)$ denotes the batch loss and $w_i$ represents the copy of the model on client $i$. By doing this, we enforce a contribution of each class to the local objective that is proportional to its share of the true underlying distribution across the federation.
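The batch-level weighting can be illustrated with a short Python sketch. It rests on an assumption rather than on the exact form of Equation (3): each sample of class c is weighted by the ratio of the federation-wide prior estimate to the proportion of class c in the batch, so that the total weight of each class in the batch equals its prior. The function names, the list-of-lists representation of predicted probabilities, and the toy values are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence


def batch_label_distribution(labels: Sequence[int],
                             num_classes: int) -> List[float]:
    """Proportion of each class among the labels of one non-empty batch."""
    counts = Counter(labels)
    return [counts.get(c, 0) / len(labels) for c in range(num_classes)]


def weighted_batch_loss(probs: Sequence[Sequence[float]],
                        labels: Sequence[int],
                        prior: Sequence[float]) -> float:
    """Cross-entropy with per-sample weight prior[y] / q[y] (assumed form).

    probs[b][c] is the predicted probability of class c for sample b.
    """
    q = batch_label_distribution(labels, len(prior))
    total = 0.0
    for p_row, y in zip(probs, labels):
        # q[y] > 0 here because label y occurs in the batch at least once.
        weight = prior[y] / q[y]
        total += -weight * math.log(p_row[y])
    return total / len(labels)


# Toy batch: class 0 is over-represented (3 of 4 samples), prior is uniform.
labels = [0, 0, 0, 1]
prior = [0.5, 0.5]
probs = [[0.5, 0.5]] * 4          # an uninformative model
q = batch_label_distribution(labels, num_classes=2)
loss = weighted_batch_loss(probs, labels, prior)
```

In this toy batch the three class-0 samples each receive weight 2/3 and the single class-1 sample receives weight 2, so both classes contribute equally to the loss, in line with the uniform prior.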
We follow the aggregation step in a typical FL algorithm such as FedAvg, where we compute the weighted average of the updated models from all clients, with the weights being the number of samples in each client. A detailed algorithm is shown in Algorithm 1.
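The sample-size-weighted aggregation can be sketched as follows. This is a simplified illustration in which each client model is a flat list of floats; the function name `fedavg_aggregate` and the toy numbers are hypothetical, and a real implementation would average each parameter tensor of the network in the same way.

```python
from typing import List, Sequence


def fedavg_aggregate(client_weights: Sequence[Sequence[float]],
                     client_sizes: Sequence[int]) -> List[float]:
    """Average client models, weighting each by its share of the samples."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    aggregated = [0.0] * dim
    for weights, n in zip(client_weights, client_sizes):
        share = n / total  # normalized number of samples on this client
        for j in range(dim):
            aggregated[j] += share * weights[j]
    return aggregated


# Two hypothetical clients with 100 and 300 samples.
global_weights = fedavg_aggregate(
    client_weights=[[1.0, 2.0], [3.0, 4.0]],
    client_sizes=[100, 300],
)
```

The larger client pulls the average toward its own parameters, which is the default FedAvg behavior that FedSLD keeps unchanged.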
3 Experiments and Results
In this section, we evaluate the performance of the proposed FedSLD through experiments on four publicly available datasets (including two benchmark imaging datasets), and compare it with two leading FL algorithms: FedAvg [mcmahan2017communication], an algorithm that averages the local updates of the global model, and FedProx [li2018federated], an algorithm that adds a proximal term to the local objective to improve robustness on non-IID data. To evaluate the general performance of the algorithms, we compute the test accuracies and demonstrate the empirical convergence behavior by plotting the training loss and test accuracy curves. In addition, we examine the fairness of the method following recent work [li2019fair]. More details on the metrics are given in Section 3.1.
3.1 Experiments setup
We conduct experiments on two benchmark image datasets: MNIST [lecun-mnisthandwrittendigit-2010], a 10-class handwritten digit classification dataset, and CIFAR10 [krizhevsky2009learning], a 10-class dataset of animal and vehicle images. We further evaluate the methods on two medical image datasets from the MedMNIST dataset collection [yang2021medmnist]: the OrganMNIST (axial) dataset, an 11-class dataset of abdominal CT organ images derived from a liver tumor segmentation benchmark [bilic2019liver], and the PathMNIST dataset, a 9-class dataset of colorectal cancer images [kather2019predicting]. We partition each dataset into a training set and a test set and ensure that they share the same label distribution.
Two non-IID settings. We partition each dataset according to two different non-IID settings: 1) a pathological non-IID setting. In this setting, we follow [mcmahan2017communication] and assign two random classes to each client. A random number of images from these two classes are assigned to this client. We set the number of clients to 12 to mimic a cross-silo FL setting. Figure 1 shows the pathological non-IID setting in more detail; 2) a practical non-IID setting. In this setting, we randomly partition each class in the training set into 12 shards (corresponding to a total of 12 clients in the federation): shards of , one shard of and one shard of . For each client, we randomly assign one shard from each class, so that the client possesses images from all classes, with more images from some classes and fewer from others. This non-IID setting is closer to real-world medical applications: datasets held at medical centers often contain a variety of classes, but medical centers in different regions may, due to local demographics, see different classes at different rates. Consequently, the datasets at the medical centers are often imbalanced, with different majority classes. Figure 2 shows the data distribution of all four datasets in more detail.
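The pathological setting can be sketched as below. The source only states that each client receives a random number of images from two random classes, so the sampling rule used here (a uniform draw capped at one twelfth of each class) is an assumption made for illustration, as are the function name `pathological_partition`, its parameters, and the synthetic labels.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def pathological_partition(labels: Sequence[int],
                           num_clients: int = 12,
                           classes_per_client: int = 2,
                           seed: int = 0) -> Dict[int, List[int]]:
    """Give each client samples from a small random subset of classes."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    for pool in by_class.values():
        rng.shuffle(pool)
    # Assumed cap: a class can be picked by at most num_clients clients,
    # so capping each draw at size // num_clients never exhausts a pool
    # that holds at least num_clients samples.
    caps = {c: max(1, len(pool) // num_clients) for c, pool in by_class.items()}
    classes = sorted(by_class)
    partition: Dict[int, List[int]] = {}
    for client in range(num_clients):
        indices: List[int] = []
        for c in rng.sample(classes, classes_per_client):
            take = rng.randint(1, caps[c])   # random number of images
            indices.extend(by_class[c][:take])
            del by_class[c][:take]           # no sample is assigned twice
        partition[client] = indices
    return partition


# Synthetic labels: 10 classes with 100 samples each.
labels = [i % 10 for i in range(1000)]
partition = pathological_partition(labels)
```

Each client ends up with exactly two classes, and no sample index is shared between clients. The practical setting is not sketched here because its shard sizes are not recoverable from the text.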
We use the classic four-layer CNN model with two 5x5 convolutional layers and two fully connected layers (the hidden layer has 500 units). We use a batch size of 256, 5 local epochs, and a learning rate of 0.01. We train the model for 80 rounds under the practical non-IID partition and for 160 rounds under the pathological non-IID partition. All experiments are run on an NVIDIA Tesla V100 GPU and implemented in PyTorch [NEURIPS2019_9015] and PySyft [ryffel2018generic].
Metrics. We compute two types of test accuracy for each setting. 1) The Best Mean Client Test Accuracy (BMCTA), referred to as BMTA in [huang2021personalized]. In each round we average the test accuracies of all clients, and BMCTA is the highest such mean over all rounds. 2) The Best Test Accuracy (BTA). We treat the test sets from different clients as a combined test set, and compute the highest test accuracy over all rounds. We also investigate the methods’ convergence behavior by plotting the training loss and test accuracy curves. In addition, we follow [li2019fair] and examine the fairness of the methods by applying Gaussian kernel density estimation to the client test accuracies. Higher density at higher accuracy reflects a better result.
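The two accuracy metrics and the fairness density can be sketched as follows. This is a minimal illustration assuming that per-round, per-client counts of correctly classified test samples are available; the function names, the fixed kernel bandwidth of 0.05, and the toy counts are hypothetical and not taken from the experiments.

```python
import math
from typing import Callable, Sequence


def best_mean_client_test_accuracy(correct: Sequence[Sequence[int]],
                                   test_sizes: Sequence[int]) -> float:
    """BMCTA: highest per-round mean of the client test accuracies."""
    best = 0.0
    for round_correct in correct:
        accs = [c / n for c, n in zip(round_correct, test_sizes)]
        best = max(best, sum(accs) / len(accs))
    return best


def best_test_accuracy(correct: Sequence[Sequence[int]],
                       test_sizes: Sequence[int]) -> float:
    """BTA: highest per-round accuracy on the pooled client test sets."""
    total = sum(test_sizes)
    return max(sum(round_correct) / total for round_correct in correct)


def gaussian_kde(samples: Sequence[float],
                 bandwidth: float) -> Callable[[float], float]:
    """Gaussian kernel density estimate with a fixed, user-chosen bandwidth."""
    norm = 1.0 / (len(samples) * bandwidth * math.sqrt(2.0 * math.pi))

    def density(x: float) -> float:
        return norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2)
                          for s in samples)

    return density


# Toy run: two clients, two rounds; correct[r][k] = correct predictions
# of client k at round r.
test_sizes = [100, 300]
correct = [[90, 150], [100, 210]]
bmcta = best_mean_client_test_accuracy(correct, test_sizes)
bta = best_test_accuracy(correct, test_sizes)
final_accs = [c / n for c, n in zip(correct[-1], test_sizes)]
density = gaussian_kde(final_accs, bandwidth=0.05)
```

The two metrics differ whenever client test sets have unequal sizes: BMCTA gives every client the same weight, while BTA gives every test sample the same weight.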
Under the pathological non-IID setting, Table 1 and Figure 3 show that on MNIST and the two medical datasets, the proposed FedSLD performs better, improving the test accuracy by up to 1.57%, and the kernel density estimates show that FedSLD has a slightly higher density that is more concentrated at higher test accuracy. On CIFAR10, FedSLD performs competitively with FedAvg and FedProx.
Under the practical non-IID setting, Table 2 and Figure 4 show that the proposed FedSLD outperforms FedAvg and FedProx on every dataset, with BMCTA improvements ranging from 1.10% to 5.50% and BTA improvements ranging from 0.18% to 2.41%. In addition, FedSLD achieves better convergence behavior on the MNIST and OrganMNIST (axial) datasets. The fairness plots reveal that FedSLD not only increases the overall performance across the entire federation but also reduces the variance of the client test accuracies on the MNIST and PathMNIST datasets, which implies fairer training. On the CIFAR10 and OrganMNIST (axial) datasets, we see a clear decrease in density at low accuracy and an increase in density at high accuracy, which explains the improvement in BMCTA.
In this work, we proposed a new FL algorithm for medical image classification: Federated Learning with Shared Label Distribution (FedSLD). FedSLD aims to mitigate the effect of non-IID data by leveraging the clients’ label distributions. We conducted extensive experiments on four publicly available datasets under two types of non-IID setting, and demonstrated that FedSLD outperforms the compared leading FL algorithms and encourages fairer performance across the participating clients.
5 Compliance with ethical standards
Ethical approval was not required, as this study used previously collected and deidentified data (including medical imaging data) available in public repositories.
This work was supported in part by National Institutes of Health (NIH) (1R01CA193603 and 1R01CA218405), National Science Foundation (NSF) (CICI: SIVD: 2115082), and the grant 1R01EB032896 as part of the NSF/NIH Smart Health and Biomedical Research in the Era of Artificial Intelligence and Advanced Data Science Program. This work used the Extreme Science and Engineering Discovery Environment (XSEDE), which is supported by NSF grant number ACI-1548562. Specifically, it used the Bridges-2 system, which is supported by NSF award number ACI-1928147, at the Pittsburgh Supercomputing Center.