FedMD: Heterogenous Federated Learning via Model Distillation

10/08/2019 ∙ by Daliang Li, et al. ∙ Harvard University

Federated learning enables the creation of a powerful centralized model without compromising the data privacy of multiple participants. While successful, it does not address the case where each participant independently designs its own model. Due to intellectual property concerns and the heterogeneous nature of tasks and data, this is a widespread requirement in applications of federated learning to areas such as health care and AI as a service. In this work, we use transfer learning and knowledge distillation to develop a universal framework that enables federated learning when each agent owns not only its private data, but also a uniquely designed model. We test our framework on the MNIST/FEMNIST and CIFAR10/CIFAR100 datasets and observe fast improvement across all participating models. With 10 distinct participants, the final test accuracy of each model receives on average a 20% gain on top of what is possible without collaboration, and is only a few percent lower than the performance each model would have obtained if all private datasets were pooled and made directly available to all participants.




1 Introduction

Deep learning has provided a potentially powerful framework to automate perception and inference. However, large datasets are required to fully realize this potential. In areas like health care, it is often difficult and costly to curate large datasets. For instance, a typical hospital in the US may have only dozens of MRI images of a particular disease, which need to be annotated by human experts and must be protected from potential privacy breaches. Federated learning and similar ideas mcmahan2016FedAvg ; Shokri:2015 rise to this challenge by training a centralized model while keeping users’ sensitive data on device. In particular, federated learning mcmahan2016FedAvg ; bonawitz2019towards ; fedai is optimized for fast communication and is uniquely capable of handling a large number of users.

Federated learning faces many challenges 1908.07873 . Among them, heterogeneity, which appears in all aspects of the learning process, is of particular importance. There is system heterogeneity when participants have different amounts of bandwidth and computational power; this was partly resolved by the native asynchronous scheme of federated learning, which was further refined, e.g., to enable active sampling nishio2018client ; kang2019incentive and to improve fault tolerance sahu2018convergence . There is also statistical heterogeneity (the non-i.i.d. problem), where clients have varying amounts of data drawn from distinct distributions chen2018federated ; fed_multitask_smith_2017 ; varia_mtl ; khodak2019adaptive ; eichner2019semi ; zhao2018federated .

In this work, we focus on a different type of heterogeneity: differences among local models. In the original federated framework, all users have to agree on a particular architecture for the centralized model. This is a reasonable assumption when the participants are millions of low-capacity devices such as cell phones. Here, we instead explore extensions of the federated framework that are realistic in a business-facing setting, where each participant has the capacity and the desire to design its own unique model. This arises in areas like health care, finance, supply chain and AI services. For example, when several medical institutions collaborate without sharing private data, they may need to craft their own models to meet distinct specifications. They may also be unwilling to share details of their models due to privacy and intellectual property concerns. Another example is AI as a service. A typical AI vendor of, e.g., customer service chat bots may have dozens of client companies. Each client’s model is distinct and solves a different task. The standard practice is to train a client’s model with only its own data. It would be immensely beneficial if data from other clients could be utilized without compromising privacy or independence. How can one perform federated learning when each participant has a different model that is a black box to the others? This is the central question that we will answer in this work.

This question is intimately related to the non-i.i.d. challenge of federated learning, because a natural way to tackle statistical heterogeneity is to have individualized models for each user. Indeed, existing frameworks already result in slightly different models. For example, fed_multitask_smith_2017 provides a framework for multi-task learning when the problem is convex. Approaches based on Bayesian frameworks varia_mtl , meta-learning khodak2019adaptive and transfer learning zhao2018federated also achieve good performance on non-i.i.d. data while allowing a certain amount of model customization. However, to our knowledge, all existing frameworks require centralized control over the design of local models. Full model independence, while related to the non-i.i.d. problem, is an important new research direction in its own right.

The key to full model heterogeneity is communication. In particular, there must be a translation protocol enabling a deep network to understand the knowledge of others without sharing data or model architecture. This question touches on fundamental issues in deep learning, such as interpretability and emergent communication protocols. In principle, machines should be able to learn the best communication protocol that is adaptive to any specific use case. As a first step in this direction, we employ a more transparent framework based on knowledge distillation that solves the problem.

Transfer learning is another major framework for addressing the scarcity of private data. In this work, our private datasets can be as small as a few samples per class, so using transfer learning from a large public dataset is imperative in addition to federated learning. We leverage transfer learning in two ways. First, before entering the collaboration, each model is fully trained, first on the public data and then on its own private data. Second, and more importantly, the black-box models communicate through their output class scores on samples from the public dataset. This is realized through knowledge distillation Hinton44873 , which can transmit learned information in a model-agnostic way.


The primary contribution of this work is FedMD, a new federated learning framework that enables participants to independently design their models. Our central server does not control the architecture of these models and requires only limited black-box access. We identify the key element of this framework to be the communication module that translates knowledge between participants. We implement such a communication protocol by leveraging transfer learning and knowledge distillation. We test this framework using a subset of the FEMNIST dataset caldas2018leaf and the CIFAR10/CIFAR100 datasets CIFAR . We find significant gains in the performance of local models compared to what is possible without collaboration.

2 Methods

We propose the following challenge:

2.1 Problem definition

There are $m$ participants in the federated learning process. Each participant $k$ owns a very small labeled dataset $D_k$ that may or may not be drawn from the same distribution. There is also a large public dataset $D_0$ that everyone can access. Each participant independently designs its own model $f_k$ to perform a classification task. The models $f_k$ can have different architectures. Furthermore, hyper-parameters need not be shared among participants. The goal is to establish a framework of collaboration that improves the performance of each $f_k$ beyond what individual effort achieves with the locally accessible data $D_0$ and $D_k$.

Figure 1: A general framework for heterogeneous federated learning. Each agent owns a private dataset and a uniquely designed model. To communicate and collaborate without data leakage, the agents need to translate their learned knowledge into a standard format. A central server collects this knowledge and computes a consensus, which is distributed across the network. In this work, the translator is implemented using knowledge distillation.

2.2 The framework for heterogeneous federated learning

We propose FedMD (Algorithm 1), which solves the problem stated in Section 2.1. We comment on the key components of this framework.

Input: Public dataset $D_0$, private datasets $D_k$, independently designed models $f_k$, $k = 1, \dots, m$
Output: Trained models $f_k$
Transfer learning: Each party trains $f_k$ to convergence on the public dataset $D_0$ and then on its private dataset $D_k$.
for j=1,2…P do
       Communicate: Each party computes the class scores $f_k(x_i^0)$ on the public dataset and transmits the result to a central server.
       Aggregate: The server computes an updated consensus, which is an average $\tilde{f}(x_i^0) = \frac{1}{m}\sum_{k=1}^{m} f_k(x_i^0)$.
       Distribute: Each party downloads the updated consensus $\tilde{f}(x_i^0)$.
       Digest: Each party trains its model $f_k$ to approach the consensus $\tilde{f}$ on the public dataset $D_0$.
       Revisit: Each party trains its model $f_k$ on its own private data $D_k$ for a few epochs.
end for
Algorithm 1 The FedMD framework enabling federated learning for heterogeneous models.
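Algorithm 1 can be sketched end to end in a few dozen lines. What follows is a minimal, hypothetical sketch in Python/NumPy, not the authors' implementation, and it rests on several assumptions. A hypothetical `ToyModel` stands in for each participant's architecture: a frozen random tanh feature map of participant-specific width, followed by a trainable linear layer. The Digest step regresses logits onto the consensus with a squared-error loss, which is an assumed choice of distillation loss. For simplicity, public and private data share one label space, unlike the MNIST/FEMNIST and CIFAR10/CIFAR100 setups. The names `ToyModel`, `fedmd`, `rounds` and `subset` are illustrative.

```python
import numpy as np


class ToyModel:
    """Hypothetical stand-in for one participant's private model: a frozen random
    tanh feature map of width `hidden` followed by a trainable linear layer.
    Different widths play the role of different architectures."""

    def __init__(self, n_in, n_classes, hidden, seed):
        rng = np.random.default_rng(seed)
        self.P = rng.normal(size=(n_in, hidden)) / np.sqrt(n_in)  # frozen projection
        self.W = np.zeros((hidden, n_classes))                      # trainable weights

    def features(self, X):
        return np.tanh(X @ self.P)

    def logits(self, X):
        return self.features(X) @ self.W

    def fit_labels(self, X, y, epochs=50):
        """Gradient descent on softmax cross-entropy with hard labels."""
        H = self.features(X)
        Y = np.eye(self.W.shape[1])[y]
        lr = 1.0 / H.shape[1]  # conservative step: tanh features satisfy |h| <= 1
        for _ in range(epochs):
            Z = H @ self.W
            Z -= Z.max(axis=1, keepdims=True)  # numerically stable softmax
            p = np.exp(Z)
            p /= p.sum(axis=1, keepdims=True)
            self.W -= lr * H.T @ (p - Y) / len(X)

    def fit_logits(self, X, target, epochs=50):
        """Digest (assumed loss): regress this model's logits onto the consensus."""
        H = self.features(X)
        lr = 1.0 / H.shape[1]  # below 2/L for this quadratic loss, so it cannot diverge
        for _ in range(epochs):
            self.W -= lr * H.T @ (H @ self.W - target) / len(X)


def fedmd(models, X_pub, y_pub, private, rounds=5, subset=200, seed=0):
    rng = np.random.default_rng(seed)
    # Transfer learning: train on the public data, then on each private set.
    for model, (X_k, y_k) in zip(models, private):
        model.fit_labels(X_pub, y_pub)
        model.fit_labels(X_k, y_k)
    for _ in range(rounds):
        idx = rng.choice(len(X_pub), size=subset, replace=False)
        X_s = X_pub[idx]                                   # public subset for this round
        scores = [model.logits(X_s) for model in models]   # Communicate
        consensus = np.mean(scores, axis=0)                # Aggregate
        for model, (X_k, y_k) in zip(models, private):     # Distribute
            model.fit_logits(X_s, consensus)               # Digest
            model.fit_labels(X_k, y_k, epochs=5)           # Revisit
    return models


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    centers = rng.normal(scale=3.0, size=(4, 10))

    def sample(n):
        y = rng.integers(0, 4, size=n)
        return centers[y] + rng.normal(size=(n, 10)), y

    X_pub, y_pub = sample(1000)
    private = [sample(12) for _ in range(3)]
    X_test, y_test = sample(500)
    models = [ToyModel(10, 4, hidden=h, seed=k) for k, h in enumerate([8, 16, 32])]
    for model in fedmd(models, X_pub, y_pub, private):
        print((model.logits(X_test).argmax(axis=1) == y_test).mean())
```

The demo uses synthetic Gaussian clusters. It illustrates the control flow of Algorithm 1 and does not reproduce any reported number. Only class scores on the shared public subset ever leave a participant. That is what lets `ToyModel` instances of different widths collaborate.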

Transfer learning:

Before a participant starts the collaboration phase, its model must first undergo the entire transfer learning process: it is trained fully on the public dataset and then on its own private data. Any subsequent improvement is therefore measured against this baseline.

Communication:

We re-purpose the public dataset as the basis of communication between models, which is realized using knowledge distillation. Each learner $f_k$ expresses its knowledge by sharing its class scores $f_k(x_i^0)$ computed on the public dataset $D_0$. The central server collects these class scores and computes an average $\tilde{f}(x_i^0)$. Each party then trains $f_k$ to approach the consensus $\tilde{f}(x_i^0)$. In this way, the knowledge of one participant can be understood by others without explicitly sharing its private data or model architecture. Using the entire large public dataset can impose a heavy communication burden, so in practice the server may randomly select a much smaller subset at each round as the basis of communication. This keeps the cost under control: it depends on the subset size and the number of classes, not on the complexity of the participating models.
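The communication step and its cost can be made concrete with a small, hypothetical NumPy sketch. It rests on three assumptions. First, the server broadcasts the indices of the randomly selected public subset so that all parties score the same samples. Second, scores are sent as float32. Third, each model outputs 10 scores per sample, as for MNIST's 10 digit classes. The function names are illustrative and do not come from the source.

```python
import numpy as np


def sample_public_subset(n_public, subset_size, rng):
    """Indices of the public samples used as this round's basis of communication.
    Assumption: the server broadcasts these indices so every party scores the same samples."""
    return rng.choice(n_public, size=subset_size, replace=False)


def aggregate(class_scores):
    """Consensus = element-wise average of the parties' class scores.

    class_scores: list of m arrays, each of shape (subset_size, n_classes).
    """
    return np.stack(class_scores).mean(axis=0)


def upload_bytes(subset_size, n_classes, bytes_per_score=4):
    """Per-party upload per round, assuming float32 scores.
    Depends only on the subset size and the number of classes, not on the model."""
    return subset_size * n_classes * bytes_per_score


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    idx = sample_public_subset(60_000, 5_000, rng)
    # Stand-in scores from 10 hypothetical parties; real parties would evaluate f_k on the subset.
    scores = [rng.normal(size=(len(idx), 10)) for _ in range(10)]
    print(aggregate(scores).shape)   # (5000, 10)
    print(upload_bytes(5_000, 10))   # 200000
    print(upload_bytes(60_000, 10))  # 2400000
```

With a subset of 5000 samples, the upload is 200 kB per party per round. If all 60,000 MNIST training images served as the public set, it would be 2.4 MB. Neither figure depends on how many parameters a participant's model has.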

3 Results

We test this framework in two different environments. In the first environment, the public dataset is MNIST and the private data is a subset of FEMNIST. We consider two cases. In the i.i.d. case, each private dataset is drawn randomly from FEMNIST. In the non-i.i.d. case, each participant is given only letters written by a single writer during training but is asked to classify letters by all writers at test time.

In the second environment, the public dataset is CIFAR10 and the private dataset is a subset of CIFAR100. CIFAR100 has 100 subclasses that fall under 20 superclasses; e.g., bear, leopard, lion, tiger and wolf belong to large carnivores. In the i.i.d. case, the task is for each participant to classify test images into the correct subclasses. The non-i.i.d. case is more challenging: during training, each participant has data from one subclass per superclass; at test time, participants need to classify generic test data into the correct superclasses. For example, a participant who has only seen wolves during training is expected to classify lions correctly as large carnivores. It therefore has to rely on information communicated by other participants.
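The CIFAR non-i.i.d. split can be sketched as follows. This is a hypothetical construction, not the authors' code. Here `fine_to_coarse` maps each subclass id to its superclass id. Subclasses are assigned to parties round-robin; this is an assumption, since the source does not say how subclasses were distributed among the 10 participants. Each party receives `n_per_class` samples of one subclass from every listed superclass, labeled by superclass. Test labels are mapped to superclasses, so every subclass counts at test time.

```python
import numpy as np


def non_iid_split(fine_labels, fine_to_coarse, coarse_classes, n_parties, n_per_class, seed=0):
    """Return, for each party, (sample indices, superclass labels).

    Each party sees ONE subclass from every superclass in `coarse_classes`
    and is trained to predict the superclass id.
    """
    rng = np.random.default_rng(seed)
    fine_labels = np.asarray(fine_labels)
    splits = []
    for party in range(n_parties):
        idx_parts, y_parts = [], []
        for c in coarse_classes:
            subclasses = sorted(f for f, cc in fine_to_coarse.items() if cc == c)
            f = subclasses[party % len(subclasses)]  # hypothetical round-robin assignment
            candidates = np.flatnonzero(fine_labels == f)
            idx_parts.append(rng.choice(candidates, size=n_per_class, replace=False))
            y_parts.append(np.full(n_per_class, c))
        splits.append((np.concatenate(idx_parts), np.concatenate(y_parts)))
    return splits


def to_coarse(fine_labels, fine_to_coarse):
    """Relabel test samples by superclass, so every subclass counts at test time."""
    return np.array([fine_to_coarse[int(f)] for f in fine_labels])


if __name__ == "__main__":
    # Toy label space: 12 subclasses in 3 superclasses of 4, 30 samples each.
    fine = np.repeat(np.arange(12), 30)
    mapping = {f: f // 4 for f in range(12)}
    splits = non_iid_split(fine, mapping, coarse_classes=[0, 1, 2], n_parties=4, n_per_class=5)
    print(sorted(set(fine[splits[0][0]].tolist())))  # [0, 4, 8]
    print(sorted(set(fine[splits[1][0]].tolist())))  # [1, 5, 9]
```

Table 3 gives CIFAR100 superclasses [0-5] and 20 samples per class per party. Under those settings, each party would hold 6 × 20 = 120 private images.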

In each environment, 10 participants design unique convolutional networks that can differ in the number of channels and the number of layers; see Tables 1 and 2 for details. First, they are trained on the public dataset until convergence. These models reach test accuracies between 98.2% and 99.2% on MNIST and between 71.5% and 78.8% on CIFAR10 (Tables 1 and 2). Second, each participant trains its model on its own small private dataset. The models then enter the collaborative training phase. They improve quickly across the board and soon outperform the transfer-learning baseline. We use the Adam optimizer KinBa17 with an initial learning rate of …; in each round of collaborative training we randomly select a public subset of size 5000 as the basis for communication. More details are given in the supplementary material. The code will be made publicly available after the workshop.

Figure 2: FedMD improves the test accuracy of participating models beyond their baselines. A dashed line (on the left) represents the test accuracy of a model after full transfer learning with the public dataset and its own small private dataset. This baseline is our starting point and overlaps with the beginning of the corresponding learning curve. A dash-dot line (on the right) represents the would-be performance of a model if private datasets from all participants were declassified and made available to every participant of the group.

4 Discussion and conclusion

In this work we proposed FedMD, a framework that enables federated learning for independently designed models. Our framework is based on knowledge distillation and was tested on several tasks and datasets. In future work we will explore more sophisticated communication modules, such as feature transformations and emergent communication protocols, that will further improve the performance of our framework. Our framework can also be applied to tasks involving NLP and reinforcement learning. We will extend our framework to extreme cases of heterogeneity involving large discrepancies in the amount of data, in model capacity, and in local tasks. We believe that heterogeneous federated learning will be an essential tool in a broad spectrum of business-facing applications of deep learning.


We would like to thank Ethan Dyer, Jared Kaplan, Jaehoon Lee, Patrick (Langechuan) Liu, Sam McCandlish, Wenbo Shi, Gennady Voronov, Yunlong Wang, Sho Yaida, Xi Yin and Yao Zhao for discussions and comments on the manuscript. DL was supported by the Simons Collaboration Grant on the Non-Perturbative Bootstrap.


Supplementary Material

We provide more details about the models, the datasets, the algorithm and the results in this supplementary material.


We list the architectures of the models used by each participant in the MNIST/FEMNIST environment in Table 1 and those in the CIFAR environment in Table 2.

Model | 1st conv layer filters | 2nd conv layer filters | 3rd conv layer filters | Dropout rate | Pre-trained test accuracy on MNIST
0 | 128 | 256 | None | 0.2 | 98.6%
1 | 128 | 384 | None | 0.2 | 98.8%
2 | 128 | 512 | None | 0.2 | 98.4%
3 | 256 | 256 | None | 0.3 | 98.3%
4 | 256 | 512 | None | 0.4 | 98.2%
5 | 64 | 128 | 256 | 0.2 | 98.9%
6 | 64 | 128 | 192 | 0.2 | 99.0%
7 | 128 | 192 | 256 | 0.2 | 99.1%
8 | 128 | 128 | 128 | 0.3 | 99.2%
9 | 128 | 128 | 192 | 0.3 | 98.9%
Table 1: Models for MNIST/FEMNIST
Model | 1st conv layer filters | 2nd conv layer filters | 3rd conv layer filters | 4th conv layer filters | Dropout rate | Pre-trained test accuracy on CIFAR10
0 | 128 | 256 | None | None | 0.2 | 71.5%
1 | 128 | 128 | 192 | None | 0.2 | 78.8%
2 | 64 | 64 | 64 | None | 0.2 | 75.5%
3 | 128 | 64 | 64 | None | 0.3 | 74.5%
4 | 64 | 64 | 128 | None | 0.4 | 74.9%
5 | 64 | 128 | 256 | None | 0.2 | 74.8%
6 | 64 | 128 | 192 | None | 0.2 | 77.5%
7 | 128 | 192 | 256 | None | 0.2 | 77.7%
8 | 128 | 128 | 128 | None | 0.3 | 78.8%
9 | 64 | 64 | 64 | 64 | 0.2 | 75.4%
Table 2: Models for CIFAR10/CIFAR100


We provide a summary of our public and private datasets in Table 3.

Collaborative task | Public dataset | Private classes | Private samples per class per party
FEMNIST/MNIST, i.i.d. | MNIST | letters [a-f] | 3
FEMNIST/MNIST, non-i.i.d. | MNIST | letters from one writer | around 20 (varies)
CIFAR, i.i.d. | CIFAR10 | CIFAR100 subclasses [0,2,20,63,71,82] | 3
CIFAR, non-i.i.d. | CIFAR10 | CIFAR100 superclasses [0-5] | 20
Table 3: Summary of datasets


We clarify important details about our implementation of Algorithm 1:

  1. In the communication phase, the models communicate and align their logits computed on public data, without applying the softmax activation layer. We could also use softmax scores with a particular temperature Hinton44873 , and we do not expect large effects from this distinction.

  2. In the communication phase, instead of using the entire public dataset, we use a subset of size 5000 that is randomly selected at each round. This speeds up the process without sacrificing the performance.

  3. The number of rounds and the batch size in the Digest and Revisit phases control the stability of the learning process. A model may undergo a transient drop in test performance that is quickly recovered in the next couple of rounds. This issue can be resolved by choosing a smaller number of epochs in the Revisit phase and a larger batch size in the Digest phase.

  4. In principle, the consensus can be computed as a weighted average $\tilde{f} = \sum_{k=1}^{m} c_k f_k$. In this work we almost always choose equal weights, $c_k = 1/m$. One exception is the CIFAR case, where we slightly suppress the contribution of the two weaker models (0 and 9). These weights may become more important when we have extremely different models or data.
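Items 1 and 4 can be illustrated with two small NumPy helpers, sketched under stated assumptions. The helper `weighted_consensus` computes $\tilde{f} = \sum_k c_k f_k$ and renormalizes the weights so they sum to one; renormalizing is our assumption, made so the result remains an average. The helper `softmax_with_temperature` is the temperature-scaled softmax of Hinton44873, which could replace raw logits. The down-weighting of models 0 and 9 uses a hypothetical weight of 0.5, because the source does not give the actual values.

```python
import numpy as np


def weighted_consensus(logits_per_party, weights=None):
    """Consensus tilde_f = sum_k c_k f_k over the parties' scores.

    logits_per_party: list of m arrays, each of shape (n_samples, n_classes).
    weights: optional length-m sequence; defaults to equal weights c_k = 1/m.
    """
    stacked = np.stack(logits_per_party)  # (m, n_samples, n_classes)
    m = stacked.shape[0]
    c = np.full(m, 1.0 / m) if weights is None else np.asarray(weights, dtype=float)
    c = c / c.sum()  # assumption: renormalize so the result stays an average
    return np.tensordot(c, stacked, axes=1)  # contracts the party axis -> (n_samples, n_classes)


def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax; larger T gives softer (flatter) scores."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = [rng.normal(size=(5000, 10)) for _ in range(10)]  # 10 parties, subset of 5000
    weights = [0.5] + [1.0] * 8 + [0.5]  # hypothetical: suppress models 0 and 9
    print(weighted_consensus(logits, weights).shape)  # (5000, 10)
```

With `weights=None` the consensus reduces to the plain average used in Algorithm 1. Raising `T` flattens the softmax scores.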


We discuss several interesting aspects of our results.

  1. We measure our results against the test accuracy that a model could have achieved if the private data of all participants were pooled and made directly available to the whole group. See Table 4. Usually our framework boosts the performance of all participants to a level only a few percent lower than this pooled data performance.

  2. There are isolated cases where a model trained in our framework consistently outperforms the same model trained with pooled private data, in particular model 0 in the CIFAR non-i.i.d. case. Moreover, its performance is mostly near the top of the group, even though this model has the simplest architecture and usually lags behind its more sophisticated peers. It would be interesting to understand the mechanism behind this success and use it to improve our framework.

  3. Our framework can accommodate extreme cases of model heterogeneity. We have experimented with several models of much lower performance, such as two-layer fully connected networks. If they contribute to the consensus with the same weight as the more advanced models, they tend to hinder the accuracy of the group. Our framework works better if we suppress their contribution with a lower weight.

Collaborative Task Each model’s performance trained with pooled private data
Table 4: Performance of models trained with pooled private data.