Health data can be used by medical practitioners to provide health care and by researchers to build machine learning models that improve clinical services and make health predictions. However, such data is mostly stored in a distributed manner, on mobile devices or in different hospitals, because of its large volume and high privacy sensitivity, so traditional learning approaches that rely on centralized data may not be viable. Federated learning, which avoids data collection and central storage, is therefore necessary, and significant progress has been made on it to date.
In 2005, Rehak et al. [rehak2005model] established CORDRA, a framework that provided standards for an interoperable repository infrastructure in which data repositories were clustered into community federations and their data were retrieved by a global federation using the metadata of each community federation. In 2011, Barcelos et al. [barcelos2011agent] created an agent-based federated catalog of learning objects (the AgCAT system) to facilitate access to distributed educational resources. Although little machine learning was involved in these two models, their practice of distributed data management and retrieval served as a reference for the development of federated learning algorithms.
In 2012, Balcan et al. [balcan2012distributed] implemented probably approximately correct (PAC) learning in a federated manner and reported upper and lower bounds on the amount of communication required to obtain desirable learning outcomes. In 2013, Richtárik et al. [richtarik2013distributed] proposed Hydra (HYbriD cooRdinAte descent), a distributed coordinate descent method for solving loss minimization problems with big data. Their work provided bounds on the number of communication rounds needed for convergence and presented experimental results on a LASSO problem with 3TB of data. In 2014, Fercoq et al. [fercoq2014fast] designed an efficient distributed randomized coordinate descent method for minimizing regularized non-strongly convex loss functions and demonstrated that it could be applied to a LASSO optimization problem with 50 billion variables. In 2015, Konečný et al. [konevcny2015federated] introduced a federated optimization algorithm suitable for training on massively distributed, non-independent and identically distributed (non-IID) and unbalanced datasets.
In 2016, McMahan et al. [mcmahan2016communication] developed the Federated Averaging (FedAvg) algorithm, which fitted a global model while the training data were left locally on distributed devices (known as clients). Experimental results showed that FedAvg performed satisfactorily on both IID and non-IID data and was robust across various datasets. Later, Konečný et al. [konevcny2016federated] modified the global model update of FedAvg in two ways, namely structured updates and sketched updates, to reduce the uplink communication cost, and their experiments indicated that the reduction could reach two orders of magnitude. In addition, Bonawitz et al. [bonawitz2016practical] designed the Secure Aggregation protocol to protect the privacy of each client's model gradient in federated learning without sacrificing communication efficiency. More recently, Smith et al. [smith2017federated] devised a systems-aware optimization method named MOCHA that simultaneously addressed high communication cost, stragglers and fault tolerance in multi-task learning. Zhao et al. [zhao2018federated] addressed the non-IID data challenge in federated learning and presented a data-sharing strategy whereby test accuracy could be enhanced significantly with only a small portion of globally shared data among clients. Bagdasaryan et al. [bagdasaryan2018backdoor] designed a novel model-poisoning technique that used model replacement to backdoor federated learning. Liu et al. [liu2018fadl] used a federated transfer learning strategy to balance global and local learning.
Most previously published federated learning methods focused on optimizing a single issue such as test accuracy, privacy, security or communication efficiency; none of them considered the computation load on the clients. This study took into account three issues in federated learning, namely the local client-side computation complexity, the communication cost and the test accuracy. We developed an algorithm named Loss-based Adaptive Boosting FedAvg (LoAdaBoost FedAvg), in which local models with a high cross-entropy loss were further optimized before model averaging on the server. To evaluate the predictive performance of our method, we extracted data on critical care patients' drug usage and mortality from the Medical Information Mart for Intensive Care (MIMIC-III) database [johnson2016mimic] and partitioned the data into IID and non-IID distributions. In the IID scenario, LoAdaBoost FedAvg was compared with FedAvg, while in the non-IID scenario both methods were combined with the data-sharing strategy and then compared. Our primary contributions are the application of federated learning to health data and the development of the straightforward LoAdaBoost FedAvg algorithm, which performed better than the state-of-the-art approach.
Materials and Methods
FedAvg and the data-sharing strategy
Developed by McMahan et al. [mcmahan2016communication], the FedAvg algorithm trained neural network models via local stochastic gradient descent (SGD) on each client and then averaged the weights of the client models on a server to produce a global model. This local-training-and-global-averaging process was carried out iteratively as follows. At the $t$th iteration, a random fraction $C$ of the clients was selected for computation. The server first sent the average weights from the previous iteration (denoted $w_{t-1}$) to the selected clients; at the first iteration, all clients instead started from the same random weight initialization. Each client $k$ then independently trained a neural network model initialized with $w_{t-1}$ on its local data, divided into minibatches of size $B$, for $E$ epochs, and reported the learned weights (denoted $w_t^k$) to the server for averaging (see Figure 1). The global model was updated with the average weights $w_t$ at each iteration.
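The FedAvg loop described above can be sketched in a few lines of Python. This is a toy illustration, not the authors' implementation: the "model" is a hypothetical one-dimensional linear regressor trained with minibatch SGD, the helper names (`local_sgd`, `weighted_average`, `fedavg_round`) are our own, and the server weights each client by its number of examples as in McMahan et al., which reduces to a plain mean for equal-sized clients. `fraction_c` plays the role of the client fraction C, `epochs` of the local epochs E and `batch_size` of the minibatch size B.

```python
import random
from typing import Callable, List, Sequence, Tuple

Weights = List[float]
Example = Tuple[float, float]  # (x, y) pair for the toy regressor


def local_sgd(w: Weights, data: Sequence[Example], epochs: int,
              batch_size: int, lr: float) -> Weights:
    """Toy local training: minibatch SGD on y ~ w[0] + w[1] * x (squared error)."""
    w = list(w)
    for _ in range(epochs):
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            residuals = [w[0] + w[1] * x - y for x, y in batch]
            grad0 = sum(residuals) / len(batch)
            grad1 = sum(r * x for r, (x, y) in zip(residuals, batch)) / len(batch)
            w[0] -= lr * grad0
            w[1] -= lr * grad1
    return w


def weighted_average(client_weights: Sequence[Weights],
                     client_sizes: Sequence[int]) -> Weights:
    """Server step: average client weight vectors, weighting client k by n_k / n."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(dim)]


def fedavg_round(global_w: Weights, clients: Sequence[Sequence[Example]],
                 fraction_c: float,
                 local_update: Callable[[Weights, Sequence[Example]], Weights],
                 rng: random.Random) -> Weights:
    """One FedAvg iteration: sample max(C*K, 1) clients, train each from w_{t-1}, average."""
    m = max(int(fraction_c * len(clients)), 1)
    selected = rng.sample(list(clients), m)
    updates = [local_update(list(global_w), data) for data in selected]
    sizes = [len(data) for data in selected]
    return weighted_average(updates, sizes)
```

For example, ten clients holding noiseless samples of y = 1 + 2x drive the global weights close to [1, 2] after a few dozen rounds with `fraction_c=0.5`, five local epochs and minibatches of ten.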
As demonstrated in the literature [mcmahan2016communication], FedAvg exhibited satisfactory performance with IID data. When trained on non-IID data, however, its accuracy could drop substantially because, as Zhao et al. [zhao2018federated] pointed out, with non-IID sampling the stochastic gradient can no longer be regarded as an unbiased estimate of the full gradient. To solve this statistical problem, they proposed a data-sharing strategy that complements FedAvg by globally sharing a small subset of training data among all the clients (see Figure 2). Stored on the server, the shared data was a dataset distinct from the clients' data and was assigned to the clients when FedAvg was initialized. This strategy therefore improved FedAvg with no harm to privacy and little additional communication cost. The strategy had two parameters: $\alpha$, the random fraction of the globally shared data distributed to each client, and $\beta$, the ratio of the globally shared data size to the total client data size. Raising these two parameters could improve predictive accuracy but would make federated learning less decentralized, reflecting a trade-off between non-IID accuracy and centralization. Zhao et al. also introduced an alternative initialization for their data-sharing strategy: the server could train a warm-up model on the globally shared data and then distribute that model's weights to the clients, rather than assigning them the same random initial weights. In this work, we kept the original initialization method so that all computation remained on the clients.
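The two parameters of the data-sharing strategy translate into set sizes as follows. In this minimal sketch (function names hypothetical), β fixes the size of the globally shared pool G relative to the total client data, and each client then receives its own random α-fraction of G; we assume the per-client subsets are drawn independently, so they may overlap. With the numbers used later in this study (a 20,000-example training set, 1,000 shared examples and 40 per client), this corresponds to β = 5% and α = 4%.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def build_shared_pool(holdout: Sequence[T], total_client_size: int,
                      beta: float, rng: random.Random) -> List[T]:
    """Draw the globally shared set G, with |G| = beta * (total client data size)."""
    size = round(beta * total_client_size)
    if size > len(holdout):
        raise ValueError("holdout set too small for the requested beta")
    return rng.sample(list(holdout), size)


def distribute_shared(pool: Sequence[T], num_clients: int, alpha: float,
                      rng: random.Random) -> List[List[T]]:
    """Give each client an independent random alpha-fraction of G (assumed scheme)."""
    per_client = round(alpha * len(pool))
    return [rng.sample(list(pool), per_client) for _ in range(num_clients)]
```

Each client's shared examples would then be appended to its local training data before local training starts.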
We devised a variant of FedAvg named LoAdaBoost FedAvg that was based on the median cross-entropy loss and adaptively boosted the training process on clients that appeared to be weak learners. We used the median rather than the mean loss because the mean is less robust to outliers, that is, to client models that are significantly underfitted or overfitted. Under our approach, both the model weights and the cross-entropy losses were communicated between the clients and the server. As illustrated in Figure 3, at the $t$th iteration the server delivered the average weights $w_{t-1}$ and the median loss $L_{t-1}$ obtained at the $(t-1)$th iteration to each client; each client then trained a neural network model in a loss-based adaptive boosting (LoAdaBoost) manner and reported the learned weights $w_t^k$ and the cross-entropy loss $L_t^k$ to the server. The global model was parametrized by $w_t$.
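On the server side, LoAdaBoost FedAvg adds only the median of the reported client losses to the usual weight averaging. The sketch below (hypothetical function name, plain Python lists as weights, averaging weighted by client size as in FedAvg) also shows why the median was preferred: a single badly fitted client barely moves the median but inflates the mean.

```python
import statistics
from typing import List, Sequence, Tuple


def loadaboost_server_aggregate(client_weights: Sequence[List[float]],
                                client_losses: Sequence[float],
                                client_sizes: Sequence[int]) -> Tuple[List[float], float]:
    """Return (w_t, median loss) to be broadcast to clients at the next iteration."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    # Size-weighted average of the client weights, exactly as in FedAvg.
    w_t = [sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
           for i in range(dim)]
    # Median of the clients' cross-entropy losses: robust to outlier clients.
    return w_t, statistics.median(client_losses)


losses = [0.40, 0.50, 0.60, 5.00]  # one severely underfitted client
print(statistics.median(losses), statistics.mean(losses))  # about 0.55 vs 1.625
```

Because the loss is a single float per client, the extra payload per round is negligible next to the weight vectors.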
Our method involved retraining client models and worked as follows. As in FedAvg, the model of Client $k$ was initialized with $w_{t-1}$ and trained on the local data divided into minibatches; unlike FedAvg, the learning process was first run for $E/2$ instead of $E$ epochs. For odd $E$, $E/2$ was rounded up to the nearest integer. We used $w_t^{k,0}$ and $L_t^{k,0}$ to denote, respectively, the weights and the cross-entropy loss of the trained model, where the retraining counter $r$ currently equaled 0. If $L_t^{k,0}$ was not greater than $L_{t-1}$ (that is, $L_t^{k,0} \le L_{t-1}$), computation on Client $k$ was finished, and $w_t^{k,0}$ and $L_t^{k,0}$ were sent to the server; otherwise, the model was retrained for another $E/2 - 1$ epochs. The new loss was denoted $L_t^{k,1}$, where the superscript 1 indicated the first retraining round. If $L_t^{k,1} > L_{t-1}$, the model was retrained for $E/2 - 2$ more epochs. This process was repeated, with retraining round $r$ lasting $E/2 - r$ epochs, and stopped as soon as $L_t^{k,r} \le L_{t-1}$ or the total number of epochs performed on Client $k$ reached $E$. The final weights $w_t^{k,r}$ and loss $L_t^{k,r}$ were then sent to the server.
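The client-side procedure can be sketched as a loop with a decaying epoch budget. This is a sketch under stated assumptions, not the authors' code: `train_epochs` and `evaluate_loss` are hypothetical callbacks standing in for local training and loss evaluation on the client's data, `median_loss_prev` is the median loss broadcast by the server from the previous iteration, retraining round r runs ⌈E/2⌉ − r epochs, and we assume every retraining round runs at least one epoch and is truncated so that the total never exceeds E (the text does not say how these edge cases were handled). The function returns the final weights, the final loss and the number of epochs used.

```python
import math
from typing import Callable, Tuple, TypeVar

W = TypeVar("W")  # any weight representation


def loadaboost_client_update(w_prev: W, median_loss_prev: float, E: int,
                             train_epochs: Callable[[W, int], W],
                             evaluate_loss: Callable[[W], float]) -> Tuple[W, float, int]:
    """Train for ceil(E/2) epochs, then retrain with a decaying budget while the loss
    exceeds the previous median loss and fewer than E epochs have been run."""
    first = math.ceil(E / 2)           # E/2, rounded up for odd E
    w = train_epochs(w_prev, first)
    epochs_used = first
    loss = evaluate_loss(w)            # L_t^{k,0}
    r = 0
    while loss > median_loss_prev and epochs_used < E:
        r += 1
        # Retraining round r: ceil(E/2) - r epochs (assumed floor of 1, capped at E total).
        n = min(max(first - r, 1), E - epochs_used)
        w = train_epochs(w, n)
        epochs_used += n
        loss = evaluate_loss(w)        # L_t^{k,r}
    return w, loss, epochs_used
```

With a stand-in model whose loss is 1/(1 + epochs trained), E = 10 and a median loss of 0.12, the client stops after 5 + 4 = 9 epochs; with a median loss of 0.2 it stops after the first 5.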
LoAdaBoost was adaptive in the sense that the performance of a poorly fitted client model after the first $E/2$ epochs was boosted via continued retraining for a decaying number of epochs. The quality of training was judged by comparing the model's loss $L_t^{k,r}$ with the median loss $L_{t-1}$. In this way, our method aimed to ensure that the losses of most (if not all) client models would be lower than the median loss of the prior iteration, thereby making the learning process more effective. In addition, because at each iteration only a few client models were expected to be trained for the full $E$ epochs, the average number of epochs run on each client would be less than $E$, meaning a smaller local computational load under our method than under FedAvg. Furthermore, since $L_{t-1}$ and $L_t^k$ were single values transferred between the server and Client $k$ together with $w_{t-1}$ and $w_t^k$, respectively, our method incurred little additional communication cost.
As with other stochastic optimization-based machine learning methods [zhao2018federated, bottou2010large, rakhlin2012making, ghadimi2013stochastic], an important assumption for our approach to work satisfactorily was that the stochastic gradient on the clients' local data was an unbiased estimate of the full gradient on the population data. This assumption held for IID data but broke down for non-IID data. In the latter case, a client model optimized to a low loss did not necessarily generalize well to the population, so reducing the losses by adding more epochs on the clients was less likely to enhance the global model's performance. This non-IID problem could be alleviated by combining LoAdaBoost FedAvg with the data-sharing strategy, because the local data become less non-IID when integrated with even a small portion of IID data.
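The unbiasedness argument can be checked numerically on a toy problem. Below, hypothetical scalar data are split across 20 equal-sized clients either at random (IID) or after sorting (a stand-in for skewed non-IID data), and each client's gradient of the loss ½(w − x)² is compared with the full-data gradient. Sorting makes individual client gradients strongly biased, and appending the same small random subset to every client (a crude stand-in for the data-sharing strategy; here the subset is drawn from the population itself rather than from a separate holdout) pulls them back toward the full gradient. The plain average over all non-IID clients still equals the full gradient, because the clients are equal-sized and together cover the population; the bias lies in each individual client's gradient.

```python
import random
import statistics


def split_equal(values, num_clients):
    """Split a list into num_clients contiguous, equal-sized chunks."""
    size = len(values) // num_clients
    return [values[i * size:(i + 1) * size] for i in range(num_clients)]


def client_gradients(partition, w=0.0):
    """Per-client gradient of the toy loss 0.5 * (w - x)^2, averaged over local data."""
    return [w - statistics.mean(client) for client in partition]


def max_bias(partition, full_gradient, w=0.0):
    """Largest deviation of any client's gradient from the full-data gradient."""
    return max(abs(g - full_gradient) for g in client_gradients(partition, w))


rng = random.Random(42)
population = [rng.gauss(0.0, 1.0) for _ in range(2000)]
full_gradient = 0.0 - statistics.mean(population)

iid = split_equal(rng.sample(population, len(population)), 20)
non_iid = split_equal(sorted(population), 20)  # each client sees a narrow value range
shared = rng.sample(population, 50)            # small IID subset given to every client
mixed = [client + shared for client in non_iid]

print(max_bias(iid, full_gradient), max_bias(non_iid, full_gradient),
      max_bias(mixed, full_gradient))
```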
In this work, FedAvg was used as the baseline for the IID data evaluation and compared with LoAdaBoost FedAvg, while FedAvg with the data-sharing strategy was used as the benchmark for the non-IID data evaluation and compared with LoAdaBoost FedAvg with the data-sharing strategy.
The MIMIC-III database
The performance evaluation used the MIMIC-III database [johnson2016mimic], which contains health information for critical care patients at a large tertiary care hospital in the US. MIMIC-III includes 26 tables of data covering patients' admissions, laboratory measurements, diagnostic codes, imaging reports, hospital length of stay and more. We processed three of these tables, namely ADMISSIONS, PATIENTS and PRESCRIPTIONS, to obtain two new tables as follows:
ADMISSIONS and PATIENTS were inner-joined on the patient identifier (SUBJECT_ID in MIMIC-III) to form the PERSONAL_INFORMATION table, which recorded the ID, age, gender and survival status of all patients.
Each patient's usage of DRUGS during the first 48 hours (two days) of the hospital stay was extracted from PRESCRIPTIONS to give the SUBJECT_DRUG_TABLE table.
Further joining these two tables on the patient ID gave a dataset of 30,760 examples, from which we randomly selected 30,000 examples to form the evaluation dataset, with mortality as the response variable. A summary of this dataset is provided in Table 1.
| Variable | Type and coding | Count (0/1) or size |
| Patient ID | integer: IDs ranging from 2 to 99,999 | 30,000 |
| Gender | binary: 0 for female and 1 for male | 17,284/12,716 |
| Age | binary: 0 for ages less than or equal to 65 and 1 for greater | 13,947/16,053 |
| Mortality | binary: 0 for survival and 1 for death | 20,841/9,159 |
| Drugs | binary: 0 if not prescribed to the patient and 1 if prescribed | 2,814 dimensions |
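The drug-usage features summarized in Table 1 are binary indicators over a drug vocabulary (2,814 dimensions). The following is a minimal sketch of building one patient's indicator vector from prescription records, assuming each record is reduced to a (drug name, start time) pair. The record layout, the function name and the drug names in the example are hypothetical, and only prescriptions starting within 48 hours of admission are counted, matching the first-48-hours rule above.

```python
from datetime import datetime, timedelta
from typing import Iterable, List, Sequence, Tuple


def drug_features(admit_time: datetime,
                  prescriptions: Iterable[Tuple[str, datetime]],
                  vocabulary: Sequence[str],
                  window: timedelta = timedelta(hours=48)) -> List[int]:
    """Multi-hot vector: 1 if the drug was started within `window` of admission."""
    index = {drug: i for i, drug in enumerate(vocabulary)}
    vector = [0] * len(vocabulary)
    for drug, start in prescriptions:
        if drug in index and admit_time <= start < admit_time + window:
            vector[index[drug]] = 1
    return vector


admit = datetime(2150, 1, 1, 8, 0)
rx = [("Insulin", admit + timedelta(hours=10)),
      ("Furosemide", admit + timedelta(hours=72)),  # outside the 48-hour window
      ("Heparin", admit)]
print(drug_features(admit, rx, ["Heparin", "Insulin", "Furosemide"]))  # [1, 1, 0]
```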
The evaluation dataset was shuffled and split into a training set of 20,000 examples, a test set of 8,000 examples and a holdout set of 2,000 examples for implementing the data-sharing strategy. As in the literature [mcmahan2016communication], the training set was partitioned over 100 clients in two ways: IID, in which the data was randomly divided among 100 clients of 200 examples each; and non-IID, in which the data was first sorted by two of the patient attributes in Table 1 and then split into 100 equal-sized clients. The skewed non-IID data allowed us to assess the robustness of our model in scenarios where the IID assumption cannot be made, which are more realistic in the healthcare industry. For the holdout set, we chose the size of the globally shared dataset $\beta = 5\%$ and the random distributed fraction $\alpha = 4\%$. This means that 1,000 examples of the holdout set were shared among the clients, each client receiving 40 examples, which were only 0.2% of the total non-IID data. We chose this combination of $\alpha$ and $\beta$ because, after experimenting with different values, we found that under this setting FedAvg complemented by the data-sharing strategy performed on non-IID data similarly to FedAvg on IID data. Both $\alpha$ and $\beta$ could be increased to further enhance performance, at the expense of decentralization.
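The two partitioning schemes can be sketched as below (function names hypothetical). The IID split shuffles the training set and deals 200 examples to each of 100 clients; the non-IID split sorts by a key before cutting equal-sized chunks. The sort key in the example, built from two made-up binary attributes, is illustrative only and stands in for the two patient attributes actually used.

```python
import random
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def _chunks(items: List[T], num_clients: int) -> List[List[T]]:
    """Cut a list into num_clients contiguous, equal-sized pieces."""
    size = len(items) // num_clients
    return [items[i * size:(i + 1) * size] for i in range(num_clients)]


def partition_iid(examples: Sequence[T], num_clients: int,
                  rng: random.Random) -> List[List[T]]:
    """IID: shuffle, then deal equal-sized random shards to the clients."""
    shuffled = list(examples)
    rng.shuffle(shuffled)
    return _chunks(shuffled, num_clients)


def partition_non_iid(examples: Sequence[T], num_clients: int,
                      sort_key: Callable[[T], object]) -> List[List[T]]:
    """Non-IID: sort by the given key, then cut equal-sized contiguous shards."""
    return _chunks(sorted(examples, key=sort_key), num_clients)


rng = random.Random(1)
examples = [{"id": i, "attr_a": rng.randint(0, 1), "attr_b": rng.randint(0, 1)}
            for i in range(20000)]
iid_clients = partition_iid(examples, 100, rng)
non_iid_clients = partition_non_iid(examples, 100, lambda e: (e["attr_a"], e["attr_b"]))
```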
Parameter sets and evaluation metrics
The neural network trained on each client consisted of three hidden layers. Instead of SGD, the stochastic optimizer chosen in this study was Adaptive Moment Estimation (Adam), which requires less memory and is more computationally efficient according to empirical results [kingma2014adam]. We used the default parameter set for Adam in the Keras framework: the learning rate $\eta = 0.001$ and the exponential decay rates for the moment estimates $\beta_1 = 0.9$ and $\beta_2 = 0.999$. In addition, with a fixed minibatch size $B$, we experimented with three values of the number of epochs $E$ and four values of the client fraction $C$.
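For reference, a single Adam update with the Keras default hyperparameters (learning rate 0.001, β1 = 0.9, β2 = 0.999) looks like this. It follows the update rule of Kingma and Ba in plain Python and is not the Keras implementation; `eps` is set to 1e-8 as in that paper, while Keras uses its own default for this constant. The function name and list-based weights are illustrative.

```python
import math
from typing import List, Sequence, Tuple


def adam_step(w: Sequence[float], g: Sequence[float], m: Sequence[float],
              v: Sequence[float], t: int, lr: float = 0.001, beta1: float = 0.9,
              beta2: float = 0.999, eps: float = 1e-8
              ) -> Tuple[List[float], List[float], List[float]]:
    """One Adam step at time t >= 1; returns updated (weights, first moment, second moment)."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]        # biased 1st moment
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]   # biased 2nd moment
    m_hat = [mi / (1 - beta1 ** t) for mi in m]                        # bias correction
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    w = [wi - lr * mh / (math.sqrt(vh) + eps) for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v
```

On the first step the bias-corrected ratio m̂/√v̂ is ±1, so each weight moves by roughly the learning rate in the direction opposite to its gradient, regardless of the gradient's magnitude.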
The evaluation metrics were threefold. First, the area under the ROC curve (AUC) was used to assess the predictive performance of a federated learning model. Here, ROC stands for the receiver operating characteristic curve, a plot of the true positive rate (TPR) against the false positive rate (FPR) at various thresholds. In our study, for a given threshold, TPR was the ratio of the number of mortalities correctly predicted by the global model to the total number of mortalities in the test dataset, while FPR was calculated as $1 - \mathrm{TNR}$, where the true negative rate TNR was the ratio of the number of correctly predicted survivals to the total number of survivals. Second, the communication cost was measured by the number of iterations or, equivalently, communication rounds. One complete round started with the server conveying $w_{t-1}$ under FedAvg ($w_{t-1}$ and $L_{t-1}$ under LoAdaBoost FedAvg) to the clients and ended with the clients sending $w_t^k$ under FedAvg ($w_t^k$ and $L_t^k$ under LoAdaBoost FedAvg) to the server. Third, we used the average number of epochs on the clients per communication round (denoted $\bar{E}$) to measure the local computation load. Under FedAvg, $\bar{E}$ was constant and equal to $E$, since each client learns its model for a fixed number of epochs $E$. Under our adaptive approach, the number of epochs varied from client to client, collectively giving a value of $\bar{E}$ different from that of FedAvg. In the experiments, our approach was evaluated against the baseline methods in terms of whether it reached the same AUC target within no more communication rounds and with a smaller $\bar{E}$.
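The three metrics can be made concrete with a short sketch (function names hypothetical). AUC is computed here through its rank interpretation, the probability that a randomly chosen death receives a higher score than a randomly chosen survival (ties counted as one half), which equals the area under the ROC curve; TPR and FPR = 1 − TNR are shown at a single threshold; and Ē is taken as the per-round mean of client epochs, averaged over rounds. The quadratic pairwise loop is only for clarity; scikit-learn's `roc_auc_score` computes the same quantity efficiently.

```python
from typing import Sequence, Tuple


def tpr_fpr(labels: Sequence[int], scores: Sequence[float],
            threshold: float) -> Tuple[float, float]:
    """TPR = correctly predicted deaths / all deaths; FPR = 1 - TNR."""
    positives = sum(labels)
    negatives = len(labels) - positives
    tp = sum(1 for y, s in zip(labels, scores) if y == 1 and s >= threshold)
    tn = sum(1 for y, s in zip(labels, scores) if y == 0 and s < threshold)
    return tp / positives, 1 - tn / negatives


def auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Fraction of (death, survival) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_epochs(epochs_per_round: Sequence[Sequence[int]]) -> float:
    """E-bar: mean over rounds of the mean epochs run by that round's selected clients."""
    per_round = [sum(r) / len(r) for r in epochs_per_round]
    return sum(per_round) / len(per_round)


print(auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Under FedAvg every entry passed to `average_epochs` equals E, so Ē = E; under LoAdaBoost FedAvg clients that stop early pull Ē below E.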
The experiments began with an illustration of FedAvg's performance gap between IID and non-IID local data for different values of $C$ and a fixed value of $E$ (see Figure 4). As in the work by McMahan et al. [mcmahan2016communication], each curve in the figure was made monotonically increasing by taking the highest test-set AUC achieved over all previous communication rounds. The global model trained on IID data clearly outperformed that fitted on non-IID data for all values of $C$: the former reached a higher test-set AUC than the latter, though both converged after a similar number of communication rounds. Since every curve for IID data progressed toward an AUC of 0.79 (depicted as the gray dashed line in Figure 4), we used this value as the target AUC, which each method under evaluation was required to reach before the number of communication rounds and the average epochs per client per round were compared.
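The monotone curves and the round at which a method first reaches the target can be computed from a per-round AUC history as follows. This is a sketch with hypothetical function names, using the 0.79 target; the AUC values in the example are made up.

```python
from itertools import accumulate
from typing import List, Optional, Sequence


def monotone_curve(auc_per_round: Sequence[float]) -> List[float]:
    """Best test-set AUC achieved so far at each communication round."""
    return list(accumulate(auc_per_round, max))


def rounds_to_target(auc_per_round: Sequence[float],
                     target: float = 0.79) -> Optional[int]:
    """First (1-based) round whose best-so-far AUC reaches the target, else None."""
    for round_index, best in enumerate(monotone_curve(auc_per_round), start=1):
        if best >= target:
            return round_index
    return None


history = [0.70, 0.75, 0.73, 0.80, 0.78]
print(monotone_curve(history))    # [0.7, 0.75, 0.75, 0.8, 0.8]
print(rounds_to_target(history))  # 4
```

A method that never reaches the target, as with some of the curves discussed below, yields `None`, so it cannot be compared on communication rounds.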
Figure 4 also shows that, on the same local data distribution, choosing a smaller fraction of clients for computation in each round tended to yield better performance. For example, on IID data FedAvg with $C = 0.1$ reached a marginally higher maximum AUC in fewer communication rounds than with the larger values of $C$, which was consistent with the experimental results reported in the literature [mcmahan2016communication]. Therefore, throughout the performance evaluation we set $C$ to 0.1. Lastly, to enable a fair comparison, FedAvg and our approach used the same random seeds to select the clients.
Evaluation on IID data
As the baseline method, FedAvg was compared with LoAdaBoost FedAvg with $C = 0.1$ and each of the three values of $E$. The results are shown in Figure 5, where the horizontal dashed line represents the target test-set AUC of 0.79.
Except for the FedAvg curve with the largest $E$, all curves rose across the dashed target line, and the number of communication rounds taken decreased as $E$ increased: for both FedAvg and LoAdaBoost FedAvg, a larger $E$ led to faster convergence. This was reasonable because adding more epochs on the clients sped up the optimization; however, it could also cause overfitting that reduced performance, as in the case of FedAvg with the largest $E$.
Given the same $E$, our method, although lagging slightly behind FedAvg in the initial rounds, always met the target in no more communication rounds than FedAvg. A summary of the experimental results supporting the superiority of our method is given in Table 2. With the smallest $E$, our method reached a test-set AUC of 0.79 in fewer communication rounds and with fewer average epochs than FedAvg. With the intermediate $E$, the two methods took the same number of rounds, but the clients under our method ran fewer epochs on average. With the largest $E$, FedAvg failed to reach the target, whereas our method continued to improve and achieved the AUC of 0.79. This indicates that the loss-based adaptive boosting mechanism of our method may have a regularization effect that prevents overfitting.
| Epochs $E$ | FedAvg: communication rounds | FedAvg: average epochs | LoAdaBoost FedAvg: communication rounds | LoAdaBoost FedAvg: average epochs |
Evaluation on non-IID data
For the non-IID evaluation, we regarded FedAvg complemented by the data-sharing strategy as the benchmark. As mentioned above and shown in Figure 4, the original FedAvg algorithm performed poorly on non-IID data. But when each client was given only 0.2% of the total non-IID data as shared data, its AUC increased to a level comparable to that of FedAvg with IID data. This can be seen in Figure 6, where the curves for FedAvg with shared data converged to a comparably high level for all values of $C$. In addition, as observed with IID data, a smaller $C$ led to moderately better performance. Therefore, we set $C = 0.1$ for evaluation as before.
The experimental results for FedAvg and our approach with shared data are displayed in Figure 7. The trends of the curves were broadly similar to those discussed for the IID data evaluation: a higher $E$ resulted in faster convergence, and LoAdaBoost FedAvg outperformed FedAvg for all values of $E$. With the smallest $E$, the curve for our method was always above that for FedAvg. With the two larger values of $E$, the curves for our method fell slightly behind during the first several communication rounds and then caught up to outperform those for FedAvg.
| Epochs $E$ | FedAvg with shared data: communication rounds | FedAvg with shared data: average epochs | LoAdaBoost FedAvg with shared data: communication rounds | LoAdaBoost FedAvg with shared data: average epochs |
Moreover, all curves for our method achieved a test-set AUC of 0.79, but the curve for FedAvg did so only with the smallest $E$. Table 3 summarizes the information from Figure 7. With the smallest $E$, LoAdaBoost FedAvg took the same number of communication rounds (11) as FedAvg but fewer average epochs. With the two larger values of $E$, where FedAvg underperformed, our method still reached the target AUC. These experimental results are consistent with our conjecture that LoAdaBoost imposes a beneficial regularization effect on the optimization process. This has particular practical value in learning tasks where communication cost is the primary concern and more epochs have to be added on the clients to accelerate model training.
Distributed health data in large quantity and of high privacy sensitivity can be harnessed by federated learning, in which both data and computation are kept on the clients. In this study, we proposed LoAdaBoost FedAvg, an optimized algorithm that adaptively boosted the performance of clients that were weak learners according to their cross-entropy losses. Experiments with IID and non-IID data showed that learning under LoAdaBoost FedAvg converged to slightly higher AUCs, required less client-side computation and needed no more communication rounds between the clients and the server than learning under FedAvg. Our approach can also be extended to learning tasks in other fields where data is distributed, such as image classification and speech recognition.
As a final point, federated learning with IID data does not always outperform that with non-IID data. For instance, for the language modeling task on the Shakespeare dataset [mcmahan2016communication], learning on the non-IID distribution reached the target test-set accuracy nearly six times faster than on the IID distribution. In such cases, the data-sharing strategy becomes unnecessary, and since the data is severely skewed, LoAdaBoost FedAvg may lose its competitive advantage. In the continuation of our study, we will investigate which kinds of medical datasets may result in superior modeling performance with non-IID distributions and why this occurs. Furthermore, we will try to improve the LoAdaBoost FedAvg algorithm to make learning on such datasets easier.
Author contributions statement
L.H initiated the idea, designed the algorithm, processed the data and conducted the experiments. Y.Y conducted the experiments and processed the data. Z.F conducted the experiments and helped design the algorithm. S.Z provided instruction on computational optimization and implementation. H.D is a clinical expert experienced in critical care and provided clinical guidance for this project. D.L initiated the idea, designed the algorithms, and supervised and coordinated the project. All authors reviewed the manuscript.