The most successful machine learning methods generalize well to different data sources by training on large amounts of data. However, in many important applications such as healthcare, data is subject to strict privacy constraints that prevent direct access to local data. In addition, devices often have limited communication bandwidth and on-device memory.
Federated learning is an increasingly popular method that addresses these constraints. In particular, it differs from other machine learning approaches by allowing multiple edge devices to learn a shared global model without the need to reveal their data to the central server. Under the standard federated learning approach (FedAvg), each device trains a copy of the global model locally on its own data and sends a weight update to the central server, which averages all updated model weights and re-deploys them as the new model to the individual devices (McMahan et al., 2017). This allows a single global model to train on multiple devices’ sensitive data without compromising privacy, since the data never has to be moved off-device.
Unfortunately, FedAvg and other recently developed privacy- and bandwidth-conscious approaches perform poorly when data is not independent and identically distributed (IID) across devices (Kairouz et al., 2019). Non-IID data may cause different devices’ updates to conflict with each other, which could lead to significant oscillations between training rounds and slower convergence.
Devices often belong to one of many archetypes, where an archetype describes a subset of non-IID data that is itself IID. Previously proposed learning schemes such as FedAvg attempt to learn a single global model that performs well for all archetypes, yet this is often difficult or even infeasible when data is non-IID. In contrast, we propose Federated Cloning-and-Deletion (FedCD), a learning scheme that results in a specialized model for each archetype through iterative cloning of global models at specified milestones, adaptive updating of a high-scoring subset of global models, and deletion of poor-performing models. By maintaining multiple global models, devices can preferentially update models that perform well on their local data, thus self-selecting into groups with similar data. This allows for both faster convergence and higher accuracy.
1.1. Related Work
Most federated learning approaches use stochastic gradient descent, which ideally requires IID sampling of the data. In practice, federated learning rarely sees IID data across edge devices, and learning on non-IID data is an open problem (Kairouz et al., 2019). Recent work has proposed various solutions to this challenge:
1.1.1. Globally Shared Subsets
Zhao et al. found that sharing just 5% of global data improved accuracy by 30% on non-IID subsets of the CIFAR-10 dataset (Zhao et al., 2018). However, a globally shared subset of data that is representative of all devices’ individual data can be difficult to obtain or synthesize, and sharing it is infeasible in many contexts.
1.1.2. Peer-to-peer Federated Learning
Peer-to-peer learning schemes increase the number of global models and the communication cost per round, as every device participates in each round with a unique model (Shoham et al., 2019; Bellet et al., 2017). Although this approach increases accuracy, in many scenarios, such as deploying edge devices in the field where security is important, individual learners would not be connected to each other and would instead maintain a single, stable connection to a centralized server. Furthermore, it is unrealistic to expect every device to be online for every round of training.
1.1.3. Personalized Federated Learning
FedAvg generally fails as its objective is to find a single shared global model rather than specialized models for different groups of edge devices. Personalized FL methods, heavily based on Model Agnostic Meta Learning (MAML), run FedAvg followed by specialization (Jiang et al., 2019; Fallah et al., 2020). Our approach eliminates many rounds of general model training by developing specialized models early on.
Algorithm 1 describes FedCD, which addresses non-IID federated learning with minimal communication and on-device memory overheads after convergence. FedCD clones high-performing models at milestone rounds and deletes low-performing models while updating model scores for each device.
The FedCD algorithm, like FedAvg, begins with a single global model on a centralized server, which is deployed to all devices. At every milestone round, every model on the centralized server is cloned and compressed. In each training round, every participating device trains its local models for a given number of epochs, compresses them, and sends a weight update and a score (with some randomization) for each model to the global server, where each model’s score on a given device reflects how well that model performs on the device’s validation data. The server then updates each global model by taking the average of all devices’ weight updates for that model, weighted by the scores the devices assign to that model. These global models are then re-deployed to the appropriate edge devices, and low-scoring models are deleted.
Note that in Algorithm 1, $M$ denotes the total number of previously created global models (including deleted models), which doubles at every milestone. Let $c_i^m$ denote the score that device $i$ assigns model $m$, where a higher score denotes a better performing model. We modify the weight update function as follows. Let $N$ be the number of devices. Let $w_i^m$ denote the weight vector for model $m$ by device $i$. Then we have
$$w^m = \frac{\sum_{i=1}^{N} c_i^m \, w_i^m}{\sum_{i=1}^{N} c_i^m}.$$
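The score-weighted averaging described above can be sketched in a few lines of Python. This is a minimal illustration rather than the authors' implementation: the function name `aggregate_model`, the dictionary layout keyed by device name, and the flat-list weight representation are all assumptions made for the example.

```python
from typing import Dict, List


def aggregate_model(updates: Dict[str, List[float]],
                    scores: Dict[str, float]) -> List[float]:
    """Score-weighted average of one global model's weights across devices.

    updates: device name -> that device's weight vector for this model
    scores:  device name -> score the device assigns to this model
    (both layouts are hypothetical, chosen for illustration)
    """
    total = sum(scores[device] for device in updates)
    if total <= 0:
        raise ValueError("at least one device must assign a positive score")
    dim = len(next(iter(updates.values())))
    merged = [0.0] * dim
    for device, weights in updates.items():
        share = scores[device] / total  # normalized contribution of this device
        for j, w in enumerate(weights):
            merged[j] += share * w
    return merged


updates = {"a": [1.0, 0.0], "b": [3.0, 4.0]}
scores = {"a": 0.25, "b": 0.75}
new_weights = aggregate_model(updates, scores)
```

Devices that score a model highly pull the global weights toward their own update, which is what lets devices with similar data gradually take over a model.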
We experimentally investigated multiple ways of generating the score of model $m$ on device $i$ based on the accuracy that the model has on the device’s validation data in round $t$. We found that using a normalized average of the most recent rounds’ validation accuracy results in the highest performance while being robust to oscillation, so we define the score of model $m$ by device $i$ at round $t$ as this normalized average.
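One way to read "normalized average of the most recent rounds' validation accuracy" is sketched below. The window length (`window=3`), the normalization across the models held by a single device so that its scores sum to 1, and the name `model_scores` are assumptions made for illustration; the text does not fix them here.

```python
from typing import Dict, List


def model_scores(history: Dict[int, List[float]], window: int = 3) -> Dict[int, float]:
    """Per-device model scores from recent validation accuracies.

    history: model id -> validation accuracies on this device, oldest first
    window:  number of most recent rounds to average (assumed value)
    """
    # Average the last `window` accuracies of every model that has any history.
    recent = {
        m: sum(acc[-window:]) / len(acc[-window:])
        for m, acc in history.items()
        if acc
    }
    total = sum(recent.values())
    if total == 0:
        # All models at zero accuracy: fall back to equal scores.
        return {m: 1.0 / len(recent) for m in recent}
    # Normalize so the device's scores sum to 1 (assumed normalization).
    return {m: avg / total for m, avg in recent.items()}


scores = model_scores({0: [0.1, 0.6, 0.6, 0.6], 1: [0.2, 0.2, 0.2]})
```

Averaging over a short window rather than using only the latest round dampens the effect of a single noisy validation result on the score.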
When models are cloned, each clone receives a score determined by the score of its parent model, chosen to encourage differentiation between the parent models and the newly cloned models.
To avoid exploding storage requirements, each device deletes models according to a criterion defined in terms of the score that the device assigns to its highest performing model and the standard deviation over the model scores by that device. Note that using a standard deviation based deletion criterion ensures that any device will maintain at least two models if there are at least two global models. After 20 rounds of training, a device with two active models additionally deletes the lower-performing model if that model’s score meets a further condition.
For our experiments, we define the performance of a device as the accuracy of its highest-scoring model on its local testing data.
By creating copies of the global model with different model scores to encourage exploration we can learn the archetypes of the edge devices and update weights based on the device’s archetype. Then edge devices with the same archetype will preferentially update the same global model.
Each model fits its devices’ distribution without access to the devices’ data, thereby effectively addressing the problems that non-IID data pose to federated learning. Compression via quantization allows for multiple smaller models on-device, and faster convergence leads to reduced communication cost.
3. Experimental Results
Our FedCD system consists of learners that have a non-IID subset of the global data. To evaluate our approach, we compared the performance of FedCD to the performance of FedAvg on CIFAR-10, a dataset typically used for FL benchmarking, in two different setups. We also measured and compared the communication costs between the central server and devices and the time to convergence under FedCD and FedAvg.
We used data from CIFAR-10 (7), a standard benchmark dataset for federated learning, split into 40k training images, 10k validation images, and 10k test images. Each device has a non-IID sample from the larger dataset that is consistent with its archetype to comprise its training/validation/test set. Each device received from and sent to the server the weights of a 10-layer convolutional neural network. We exclusively used a device’s validation set to determine its scores for a given model. We evaluated the best performing model for each device against its test set. See https://github.com/jessijzhao/fedcd/ for code.
Our experimental setup specified two required characteristics for each edge device: an archetype (to describe the data distribution) and scores for each model (a normalized weighting of the models that the device is maintaining). In each training round, 15 devices participated and the global model was set to the weighted average of their updates.
3.2. Hierarchical Archetypes
In the real world, individual archetypes are seldom perfectly independent but rather can be grouped into “meta-archetypes” that each include several different archetypes. An example of this structure is next-word prediction on the phones of users of all ages (where the age groups are archetypes) living in a predominantly English-speaking country versus a predominantly Spanish-speaking country (where the countries are meta-archetypes). Different age groups in the same country will likely share some common vernacular, but common words across countries might be very limited due to the language barrier.
To test the applicability of FedCD in this scenario, we constructed two sets of data (meta-archetypes that have data labeled 0, 1, 2, 3, 4 and 5, 6, 7, 8, 9 respectively) with 10 archetypes represented by the labels, i.e. an edge device of meta-archetype 0 only has access to training examples with labels 0, 1, 2, 3, and 4. The experiment was run with 3 devices per archetype and a given bias, where the bias denotes the fraction of a device’s local dataset that consists of examples whose label equals the archetype; for example, for a device with archetype 3, that fraction of its training images has label 3 and the remaining images are split evenly among labels 0, 1, 2, and 4. We set the cloning milestones at rounds 5, 15, 25, and 30.
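The label bias described above can be made concrete with a short sketch. The function name `label_distribution` and the example bias of 0.6 are hypothetical and are not the value used in the experiment; the sketch only assumes that meta-archetypes are the label blocks 0 to 4 and 5 to 9 and that the non-archetype labels share the remaining mass equally.

```python
from typing import Dict


def label_distribution(archetype: int, bias: float,
                       labels_per_meta: int = 5) -> Dict[int, float]:
    """Label probabilities for a device of the given archetype.

    `bias` is the fraction of examples whose label equals the archetype;
    the rest is split evenly over the other labels of the same meta-archetype.
    """
    if not 0.0 <= bias <= 1.0:
        raise ValueError("bias must lie in [0, 1]")
    start = (archetype // labels_per_meta) * labels_per_meta  # first label of the block
    others = labels_per_meta - 1
    dist = {label: (1.0 - bias) / others
            for label in range(start, start + labels_per_meta)}
    dist[archetype] = bias
    return dist


skewed = label_distribution(archetype=3, bias=0.6)   # illustrative bias value
uniform = label_distribution(archetype=7, bias=0.2)  # uniform within labels 5..9
```

With five labels per meta-archetype, a bias of 0.2 gives every label the same probability, so the devices within a meta-archetype become IID.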
The results show that the two meta-archetypes converge to slightly different accuracies (meta-archetype 0, consisting of archetypes 0, 1, 2, 3, and 4, performs worse than meta-archetype 1, consisting of archetypes 5, 6, 7, 8, and 9). The accuracy oscillations (where archetypes from the same meta-archetype oscillate together) in FedCD stop by round 10, whereas accuracy under FedAvg continues to oscillate past round 40 (see Figure 1). Furthermore, Figure 2 shows that while FedCD converged after approximately 35 rounds, FedAvg failed to converge within 150 training rounds.
3.3. Hypergeometric Archetypes
Assuming a strict hierarchy excludes more complicated scenarios where the true distribution of data may lie closer to or further from two extremes. A real-world example of this setup is the patient histories of citizens who visited hospitals across the US. In all parts of the country, an individual could have any disease, but hospitals in different locations may see different distributions of patients with respect to, e.g., the severity of the disease, insurance quality, or socioeconomic status.
To test the applicability of FedCD in this scenario, each device sampled labeled training examples from a hypergeometric distribution over labels, with the distribution’s parameters set based on the device’s archetype (see Figure 3).
We chose the hypergeometric distribution as it becomes a discrete approximation of the normal distribution when its parameters are large. Figure 3 shows the data distribution for each archetype. The experiment was run with 5 devices per archetype.
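To illustrate how a hypergeometric distribution can induce archetype-specific label skew, the sketch below computes the probability of each label from the hypergeometric probability mass function. The parameter values (a population of 40, 9 draws so that the outcomes 0 to 9 map onto ten labels, and the two success counts) are hypothetical choices for the example and are not the settings used in the experiment.

```python
from math import comb
from typing import List


def hypergeom_label_probs(population: int, successes: int, draws: int) -> List[float]:
    """P(X = k) for k = 0..draws, with X ~ Hypergeometric(population, successes, draws).

    Outcome k is interpreted as class label k, so `draws + 1` labels are covered.
    """
    if not 0 <= successes <= population or not 0 <= draws <= population:
        raise ValueError("need 0 <= successes, draws <= population")
    denom = comb(population, draws)
    return [
        comb(successes, k) * comb(population - successes, draws - k) / denom
        for k in range(draws + 1)
    ]


# Two hypothetical archetypes with mirrored skew over ten labels.
low_skew = hypergeom_label_probs(population=40, successes=8, draws=9)
high_skew = hypergeom_label_probs(population=40, successes=32, draws=9)
```

Varying the number of successes per archetype shifts the mode of the label distribution, so every label remains possible on every device while the most frequent labels differ between archetypes.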
We see in Figure 3(a) that the FedCD algorithm converges quickly (by round 45) and that archetypes with more skewed probability distributions (archetypes whose distributions differ most from the global distribution, e.g. archetypes 0, 5) achieve higher accuracy than the central archetypes (archetypes whose distributions are most similar to the global distribution, e.g. archetypes 2, 3), since their distribution has a smaller standard deviation (see Figure 3).
In particular, while FedCD performs better on more skewed archetypes relative to other archetypes, FedAvg performs better and converges faster on the central archetypes. The increased success of FedCD on archetypes with more skewed data shows that FedCD indeed improves performance by learning specialized models that fit a given archetype’s data distribution, as desired.
3.4. Effects of Quantization
Training multiple models on each device allows devices to self-sort into groups with similar archetypes by assigning similar scores to the same models. However, as on-device memory is limited, each model must be compressed to a smaller size, ideally without losing accuracy.
Figure 6 shows that in the hierarchical archetypes experiment, different levels of quantization had no significant effect on model performance and only slightly affected the time to convergence of the resulting models. Note that FedCD results in a single model per device, which is as insensitive to quantization as the FedAvg global model. Furthermore, while the central server may in principle need to store significantly more models, relatively few models are maintained in practice.
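The text does not specify the quantization scheme, so the following is only a generic sketch of uniform (affine) quantization of a weight vector to a fixed number of bits, included to show what "levels of quantization" can mean in practice. The function names and the 8-bit default are assumptions for illustration.

```python
from typing import List, Tuple


def quantize(weights: List[float], bits: int = 8) -> Tuple[List[int], float, float]:
    """Map each weight to an integer code in [0, 2**bits - 1] (uniform quantization)."""
    levels = (1 << bits) - 1
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / levels if hi > lo else 1.0  # avoid division by zero
    codes = [round((w - lo) / scale) for w in weights]
    return codes, scale, lo


def dequantize(codes: List[int], scale: float, lo: float) -> List[float]:
    """Recover approximate weights from integer codes."""
    return [lo + c * scale for c in codes]


weights = [-1.0, -0.3, 0.0, 0.42, 1.0]
codes, scale, lo = quantize(weights, bits=8)
restored = dequantize(codes, scale, lo)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
```

Under this scheme the reconstruction error of each weight is bounded by half a quantization step, and fewer bits trade a larger step for a smaller model.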
3.5. Model Selection Behavior
Note that after $k$ rounds of cloning, there will exist at most $2^k$ global models. However, devices delete any models that have already specialized for other archetypes, as these perform poorly on the device’s data, so these models are not cloned in future cloning rounds. Note that after 4 rounds of cloning, 10 out of 16 models were deleted from all devices.
Figure 7 depicts the consensus highest-scoring model that was not deleted by all devices for each archetype in the hierarchical archetypes experiment (consisting of 3 devices each). We can see that after the first cloning milestone at round 5, the devices segregate by meta-archetype. Subsequent cloning rounds have a limited effect, as the preferred model of individual archetypes oscillates between models 0 and 1 and models 4 and 5 respectively, indicating that these models perform similarly.
3.6. Communication Costs
Although the worst case (each model is cloned at each milestone, i.e. $2^k$ models after $k$ milestones) would have an exponential communication cost overhead, in practice devices tended to favor a single model and delete other models that did not fit their data as well. Note that this supposes the existence of archetypes (as in our experiments).
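The worst-case count follows directly from the doubling at each milestone, and a small worked example makes the gap to the observed behavior explicit. The helper name `max_global_models` is hypothetical; the milestone rounds and the deletion count are the figures reported for the hierarchical archetypes experiment.

```python
def max_global_models(cloning_rounds: int) -> int:
    """Worst-case number of global models if every model is cloned at every milestone."""
    if cloning_rounds < 0:
        raise ValueError("cloning_rounds must be non-negative")
    return 2 ** cloning_rounds


milestones = [5, 15, 25, 30]                  # cloning rounds in the hierarchical experiment
worst_case = max_global_models(len(milestones))
deleted_everywhere = 10                        # models reported as deleted from all devices
surviving = worst_case - deleted_everywhere    # models still held by at least one device
```

Here the bound is 16 models, of which 6 survive, so the communication overhead stays well below the exponential worst case.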
Figure 8 shows that the number of active models initially increases during the cloning rounds (5, 15, 25, 30) and drops during the subsequent rounds as devices delete models they no longer update. In the end, each of the 30 devices updates at most two active models, and only 6 models in total were preferred by any device.
As the bias and therefore the difference between archetypes increases in the hierarchical archetypes experiment, devices of similar archetypes converge to similar models faster by scoring them higher than other models. In contrast, as the bias decreases and therefore the data of different archetypes becomes more similar (note that a bias of 0.2 represents the IID case within a meta-archetype), models become more similar as well such that devices tend to maintain multiple models for a larger number of rounds.
The goal of FedCD is for each device to have one high-performing model and delete all other models. In some scenarios, such as the low-bias situations depicted in Figure 8, the algorithm terminates with each device having two equally ranked high-performing models. This is fine as well, since each device can arbitrarily choose a model for deployment without loss of performance. In both cases, the standard deviation of the scores a device assigns to its active models is low (0 if all the scores are equal and 0 if there is a single model). Figure 9 shows that the average standard deviation over model scores approaches 0 at the end of the training rounds for all levels of bias in the hierarchical archetypes setup.
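Both terminal cases can be checked numerically. The sketch below assumes the population standard deviation (which is defined for a single value, unlike the sample standard deviation) and a hypothetical dictionary of scores keyed by model id.

```python
from statistics import pstdev
from typing import Dict


def score_spread(scores: Dict[int, float]) -> float:
    """Population standard deviation of one device's active model scores."""
    return pstdev(list(scores.values()))


single_model = score_spread({4: 1.0})            # one surviving model
tied_models = score_spread({0: 0.5, 1: 0.5})     # two equally ranked models
undecided = score_spread({0: 0.9, 1: 0.1})       # device still prefers one model
```

A spread near zero therefore signals that a device has settled, either on one model or on several interchangeable ones.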
Table 1: FedCD:FedAvg wall-clock time.
Table 1 shows the wall-clock time for a run of FedCD versus a run of FedAvg until convergence. The run-time for FedAvg was capped at 300 rounds of training, since it had not converged by then in either the Hierarchical or the Hypergeometric experiment. The wall-clock time of the experiments provides another insight into the advantage of FedCD over the baseline, which takes a significant number of rounds to train.
FedCD improves model performance on non-IID data by learning specialized models that best fit the data distribution of a group of similar devices (devices belonging to the same archetype). Previous work has taken a decentralized approach, accepting complete peer-to-peer communication costs with full device participation in each round. However, this framework is sensitive to fluctuations in a real-world environment and incurs significant communication overhead.
Our centralized framework addresses these concerns by requiring only partial device participation in each round, though it incurs the costs of storing multiple quantized models on each device and the global server and sending multiple model updates per device during training. In this work, our main contributions are:
- We propose a new framework for personalized FL.
- We empirically demonstrate that FedCD exhibits faster convergence and higher accuracy than the baseline FedAvg algorithm in several common non-IID scenarios.
- We empirically show that the number of active models (the total number of models stored on-device) does not explode when poor-performing models are aggressively deleted from local devices.
By amending the standard federated learning framework to train multiple global models simultaneously, we can improve model performance on non-IID data while incurring some limited communication and storage overhead during training.
4.1. Future Work
While we experimentally showed that FedCD converges faster and achieves higher accuracy at a reasonably low cost, future work could further analyze the dynamic nature of FedCD and attempt to find theoretical guarantees for convergence as well as bounds for communication and (server-side and on-device) storage costs.
Future work could also explore different types of bias other than label bias to determine the device archetypes, including archetypes defined by modifications to the input image. In addition, there are promising extensions of FedCD to other open problems in FL, such as using the cloning technique to address concerns regarding device bias and attack mitigation.
We express our sincere appreciation to Professor H. T. Kung and Dr. Marcus Comiter for their valuable and constructive suggestions during the planning and development of this research.
- Bellet et al. (2017). Fast and differentially private algorithms for decentralized collaborative machine learning. CoRR abs/1705.08435.
- Fallah et al. (2020). Personalized federated learning: a meta-learning approach. ArXiv abs/2002.07948.
- Jiang et al. (2019). Improving federated learning personalization via model agnostic meta learning. ArXiv abs/1909.12488.
- Kairouz et al. (2019). Advances and open problems in federated learning. arXiv.org.
- McMahan et al. (2017). Communication-efficient learning of deep networks from decentralized data. 20th International Conference on Artificial Intelligence and Statistics (AISTATS) 2017.
- Shoham et al. (2019). Overcoming forgetting in federated learning on non-iid data.
- (7) CIFAR-10 dataset (website).
- Zhao et al. (2018). Federated learning with non-iid data. CoRR abs/1806.00582.