FedH2L: Federated Learning with Model and Statistical Heterogeneity

01/27/2021
by Yiying Li, et al.

Federated learning (FL) enables distributed participants to collectively learn a strong global model without sacrificing their individual data privacy. Mainstream FL approaches require each participant to share a common network architecture and further assume that data are sampled IID across participants. However, in real-world deployments participants may require heterogeneous network architectures, and the data distribution is almost certainly non-uniform across participants. To address these issues we introduce FedH2L, which is agnostic to the model architecture and robust to different data distributions across participants. In contrast to approaches that share parameters or gradients, FedH2L relies on mutual distillation, exchanging only posteriors on a shared seed set between participants in a decentralized manner. This makes it extremely bandwidth efficient and model agnostic, and it crucially produces models capable of performing well on the whole data distribution when learning from heterogeneous silos.

Introduction

Today, artificial intelligence (AI) is showing its strengths in almost every walk of life. But under this data-driven paradigm, organizations often face the dilemma of isolated data islands. For example, hospitals usually build AI models only on their own data, having no access to the data of other hospitals for privacy, competition, or administrative reasons. Meanwhile, collaboration and knowledge sharing are genuinely needed to improve performance for all of them. One can imagine that a patient also wants an efficient diagnosis at one hospital using medical images (e.g., cardiograms or electroencephalograms) taken at other hospitals. How to achieve healthy competition and collaboration in such an industry alliance is therefore a valuable challenge. Federated Learning (FL) Bonawitz et al. (2017); McMahan et al. (2017); Konečný et al. (2016), a new learning framework, provides mechanisms for building a centralized model while distributed participants’ data and resources stay local.

While FL promises regulatory and economic benefits, existing methods still have drawbacks, especially for “B2B” industry alliances: (1) most work focuses on gradient-sharing FL (e.g., FedAvg McMahan et al. (2017)) with a trusted centralized server that aggregates the gradients. All nodes must therefore agree on an identical model structure and ultimately receive the same parameters, leaving no room for model personalization; in fact, a participant may want its own specific model and may be unwilling to share model details. (2) Gradient sharing still carries the risk of serious privacy leakage Zhu et al. (2019); Luca et al. (2018), despite techniques such as differential privacy Shokri and Shmatikov (2015) and secret sharing Bonawitz et al. (2017). (3) Model gradients remain very large (10-100M parameters for common CNNs), so communication is still far from efficient.

In this paper, we present a novel solution under the FL framework for multi-party, multi-domain learning in the industry alliance scenario, called Model Agnostic Federated Mutual Learning (MAFML). Our primary contributions are:

  • MAFML adopts the Deep Mutual Learning Zhang et al. (2018) idea to let nodes “learn from and teach each other” collaboratively via only their “awareness” (soft knowledge, e.g., soft labels in classification) on a small fraction of public data, which is safer and more communication-efficient than traditional FL.

  • Each node can own a customized model, without being restricted to an identical structure or final parameters. That is, our method is model agnostic, which is more flexible and preserves model personalization. MAFML therefore no longer needs a centralized global model, since the distributed models are trained mutually.

  • We observe that data across nodes are typically collected in a non-IID manner due to the different devices or equipment used, leading to domain shift with different statistics Peng et al. (2020); Quiñonero-Candela et al. (2009). Under this cross-domain setting, we therefore build on ideas from continual learning Lopez-Paz and Ranzato (2017); McCloskey and Cohen (1989) so that each node avoids or alleviates forgetting on its own domain data while allowing beneficial knowledge transfer for better generalization to other domains’ data.

    We conduct extensive experiments on several cross-domain datasets: Rotated MNIST Ghifary et al. (2015), PACS Li et al. (2017), and Office-Home Venkateswara et al. (2017a). Compared to the baselines, we improve model performance across all domains, demonstrating the effectiveness of MAFML. Our approach provides a new paradigm for “competition and collaboration” in business-facing industry alliances, in which all nodes benefit from joining the federation.

Related Work

1. Federated Learning. The nascent idea of federated learning (FL) is to train statistical models directly on devices McMahan et al. (2017). The goal is to fit a single global model to data generated by and stored on distributed nodes, and this framework typically serves a global service provider. Yang et al. (2019) extend the concept of FL to collaboration among organizations. However, in both their vertical FL Hardy et al. (2017) and transfer FL Liu et al. (2018), executing a given task requires all nodes to work together, since each node is responsible for only part of the task. Considering learning complete models simultaneously for each node under the non-IID setting Zhao et al. (2018), a related FL study is MOCHA Smith et al. (2017), which is intrinsically framed as multi-task learning without the need for a global model. Yet each model focuses only on the performance of its own task, without considering generalization to other tasks; in addition, the model structures must still be homogeneous, because a matrix modeling the relationships among tasks is computed centrally from the nodes’ weights. Another related work is Peng et al. (2020), which addresses domain adaptation in a federated way, but it still needs a global model with homogeneous local models because feature representations are used for alignment. Although heterogeneous models have received attention in some FL studies, Shen et al. (2020) still requires the basic FedAvg McMahan et al. (2017) architecture with a global model, and an extra heterogeneous model on each node then distills knowledge from the distributed homogeneous model. FedMD Li and Wang (2019) is another work on heterogeneous FL, but it focuses more on the communication protocol via model distillation. Mohri et al. (2019); Jiang et al. (2019) study agnostic FL, but a centralized global model must be optimized and then used to fine-tune the local models. In our work, each node keeps its own customized (homogeneous or heterogeneous) model, with no need for a centralized model or other extra models, and our models learn in a “peer-to-peer” way.

2. Cross-Domain Learning. Domain shift in data statistics is a common problem in practice. Multi-domain learning Yang and Hospedales (2015); Rebuffi et al. (2017) aims to ultimately create a single model for multiple domains. When the training and test domains are clearly defined, domain adaptation (DA) Bousmalis et al. (2016); Long et al. (2016) and domain generalization (DG) Li et al. (2019b, a) are the main sub-areas of cross-domain learning; both focus on the model’s generalization to the target domain. Under the federated learning setting, the data on each node typically follow a distinct distribution, so the same cross-domain problems arise. Different from previous studies, we consider not only the data but also the models (structures and parameters), which themselves carry private information; a separate model per node is therefore needed in this cross-domain FL setting. In addition, unlike the objective of DA and DG, we want each node’s model to generalize well to other domains without sacrificing its ability on its own source domain. This idea is consistent with “maximizing transfer (generalization) and minimizing interference (forgetting)” in continual learning (a.k.a. lifelong learning) McCloskey and Cohen (1989) over non-stationary distributions Riemer et al. (2019); Lopez-Paz and Ranzato (2017). We extend this concept to cross-domain FL, so that each node maintains knowledge about its own task while strengthening the knowledge needed to adapt to other domains’ tasks.

3. Collaborative Learning. Different from conventional supervised learning with a “knowledgeable teacher”, collaborative learning Lee (2001) considers exchanging information among “peers”, i.e., an ensemble of “students” that learn from and teach each other collaboratively. A typical example is Dual Learning He et al. (2016), where two cross-lingual translation models are trained side by side. However, it suits only the bilingual translation setting with input-output duality, since agents assess predictions by comparing reconstructed sentences with the originals. Other studies extend dual learning to related translation tasks such as image-to-image translation Yi et al. (2017) and image captioning Zhao et al. (2017). Cooperative learning Batra and Parikh (2017) has recently been proposed to train multiple agents jointly on the same task in different domains, using semantic mid-level visual attributes for communication among agents. Codistillation is used for large-scale distributed network training in Anil et al. (2018), but the same architecture and dataset are required for all nodes. Deep mutual learning (DML) Zhang et al. (2018) provides an effective way to improve each network’s generalization by training mutually with peers, even when the cohort networks are heterogeneous. However, most of this line of work still focuses on the single-domain (IID) problem. In this paper, we extend the idea of DML to the cross-domain setting and fully exploit the mutual benefits to boost performance for all agents.

Methodology

In this section, we introduce the details of MAFML. Assume there are $N$ nodes in the federated industry alliance, holding data $D_1, \dots, D_N$ with distinct distributions, where the data on node $i$ is a set of data-label pairs $D_i = \{(x, y)\}$. We also split $D_i$ into its private data $D_i^{pri}$, which must be kept strictly local, the public data $D_i^{pub}$, validation data $D_i^{val}$ and test data $D_i^{test}$, i.e., $D_i = D_i^{pri} \cup D_i^{pub} \cup D_i^{val} \cup D_i^{test}$. We aim to achieve mutual knowledge learning among nodes without sacrificing much privacy on domain-specific data or model personalization, and activating the nodes’ “awareness” on the small amount of public data is a natural strategy for that. In this work, we focus on homogeneous cross-domain data Li et al. (2017), where all domains share the same label space of $C$ classes. One can think of medical images with the same disease label space but collected by different machines in different hospitals. We assume node $i$ uses a network $f_i$ parameterized by $\theta_i$; the network structures can be the same or different, since our method is model-agnostic. MAFML allows each node to keep its customized model locally, with no need for a centralized model. The workflow is divided into two stages: local optimization and global mutual optimization.
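For concreteness, here is a minimal sketch of the per-node data layout just described; the field names and split ratios are illustrative placeholders rather than the paper's defaults.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class NodeData:
    """Per-node data D_i = pri ∪ pub ∪ val ∪ test; only `public` is ever exposed,
    and only indirectly via posteriors computed on it."""
    private: List[Tuple]  # D_i^pri: stays strictly local
    public: List[Tuple]   # D_i^pub: small labelled seed set whose posteriors are shared
    val: List[Tuple]      # D_i^val: local model selection
    test: List[Tuple]     # D_i^test: local evaluation


def split_node_data(samples: List[Tuple], pub_frac: float = 0.1,
                    val_frac: float = 0.1, test_frac: float = 0.2) -> NodeData:
    """Illustrative split; the ratios here are placeholders (the paper varies the
    public fraction, e.g. 5-15% on Rotated MNIST)."""
    n = len(samples)
    n_pub, n_val, n_test = int(pub_frac * n), int(val_frac * n), int(test_frac * n)
    public = samples[:n_pub]
    val = samples[n_pub:n_pub + n_val]
    test = samples[n_pub + n_val:n_pub + n_val + n_test]
    private = samples[n_pub + n_val + n_test:]
    return NodeData(private=private, public=public, val=val, test=test)
```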

Local Optimization

The local optimization for a node follows the conventional gradient-based network update. We denote the network of the $i$-th node as $f_i(\cdot\,;\theta_i)$ and take the cross-entropy (CE) loss on a data batch $B_i$ for the node’s local optimization:

$\mathcal{L}_i^{local} = \frac{1}{|B_i|}\sum_{(x,y)\in B_i} \mathrm{CE}\big(f_i(x;\theta_i),\, y\big), \qquad (1)$

$g_i^{local} = \nabla_{\theta_i}\, \mathcal{L}_i^{local}. \qquad (2)$

Here $B_i$ is a batch of the $i$-th domain’s data available for training. Alternatively, other domains’ public data with labels can be reserved by the current node in advance and included in $B_i$; we use this option as the default since it performs slightly better in our experiments, which is consistent with the data-usage strategy of FL studies with public data Li and Wang (2019); Zhao et al. (2018). $g_i^{local}$ is the gradient computed in this local optimization.

Note that $f_i$ outputs probability soft labels, i.e., the output of the network’s softmax layer, and the CE loss is computed between the predicted soft labels and the ground-truth one-hot labels. The local performance is obtained through such conventional supervised learning, but MAFML is definitely more than that. Specifically, we want to tackle the “stability-plasticity” dilemma Carpenter and Grossberg (1987); Riemer et al. (2019). That is, we pay attention to “stability” (preserving performance on the local domain) as well as “plasticity” (generalizing well to other domains via mutual learning).

Preparation for mutual learning.  After the local optimization, we randomly sample a batch $B_i^{pub}$ from $D_i^{pub}$ in each domain and compute the soft labels $s_i^{(i)} = f_i(B_i^{pub};\theta_i)$. The superscript $(i)$ denotes that the data come from the $i$-th domain, and the subscript $i$ denotes that network $f_i$ is applied. At the same time, it is easy to obtain the accuracy $a_i$ of $f_i$ over its own public batch. We prepare all $s_i^{(i)}$ and $a_i$ for the subsequent federated global mutual optimization. Since these two kinds of quantities are low-dimensional, our approach is highly communication-efficient.
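To make the local phase concrete, a minimal PyTorch sketch of one local step plus the preparation of the shared “awareness” follows; the function name and argument layout are illustrative, not the authors' implementation.

```python
import torch
import torch.nn.functional as F


def local_step(model, optimizer, x_local, y_local, x_pub, y_pub):
    """One local optimization step (Eqs. 1-2) plus preparation for mutual learning.

    Returns the soft labels and self-assessed accuracy on this node's public
    batch, which are the only quantities later shared with other nodes.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(x_local)
    loss_local = F.cross_entropy(logits, y_local)  # Eq. (1)
    loss_local.backward()                          # gradient g_i^local, Eq. (2)
    optimizer.step()

    # "Awareness" on the public seed batch: softmax posteriors + accuracy.
    model.eval()
    with torch.no_grad():
        pub_logits = model(x_pub)
        soft_labels = F.softmax(pub_logits, dim=1)                         # s_i^(i)
        acc = (pub_logits.argmax(dim=1) == y_pub).float().mean().item()    # a_i
    return soft_labels, acc
```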

Global Mutual Optimization

Different from one-way knowledge transfer from an already pre-trained teacher to untrained students, as in model distillation Hinton et al. (2015), the federated nodes can be regarded as an ensemble of students who teach each other (i.e., everyone teaches and also learns from everyone else).

Node as a teacher.  A node teaches the others its own domain experience via its domain public data. Since $s_i^{(i)}$ and $a_i$ have been obtained after the local optimization, it is reasonable to regard $s_i^{(i)}$ as the “teacher benchmark” and $a_i$ as the “teaching confidence” with which node $i$ teaches others via $B_i^{pub}$.

Node as a student.  A node learns domain experience from all the others via their respective domain public data. To improve $f_i$’s generalization to other nodes’ domains, we use the computed $s_j^{(j)}$ and $a_j$ of the other peer models ($j \ne i$) as the “teachers’ experience”, so there are $N-1$ teachers for $f_i$. “Learning from teachers” can be achieved by mimicking the teachers’ “awareness” on the teachers’ domain data, so the Kullback-Leibler (KL) divergence loss can be used to quantify the match between the student’s and the teachers’ predictions. The KL loss of node $i$ is:

$\mathcal{L}_i^{KL} = \sum_{j \ne i} a_j\, D_{KL}\big(s_j^{(j)} \,\|\, s_i^{(j)}\big), \qquad (3)$

where

$D_{KL}\big(s_j^{(j)} \,\|\, s_i^{(j)}\big) = \frac{1}{|B_j^{pub}|}\sum_{x \in B_j^{pub}} \sum_{c=1}^{C} s_j^{(j)}(x, c)\, \log \frac{s_j^{(j)}(x, c)}{s_i^{(j)}(x, c)}. \qquad (4)$

In addition, besides the KL mimicry loss, we can also take advantage of the conventional supervised loss (CE loss) on the teachers’ public batches:

$\mathcal{L}_i^{CE} = \sum_{j \ne i} \frac{1}{|B_j^{pub}|}\sum_{(x,y)\in B_j^{pub}} \mathrm{CE}\big(f_i(x;\theta_i),\, y\big). \qquad (5)$

Thus we obtain the global mutual loss:

$\mathcal{L}_i^{glob} = \mathcal{L}_i^{KL} + \mathcal{L}_i^{CE}. \qquad (6)$

So the nodes in MAFML act as “student-teacher” cohorts. Based on $\mathcal{L}_i^{local}$ over locally accessible data, we obtain $g_i^{local}$ and indeed update $\theta_i$ with it during the local optimization; and based on $\mathcal{L}_i^{glob}$ over the domains’ public data, we can obtain $g_i^{glob} = \nabla_{\theta_i}\mathcal{L}_i^{glob}$. The question we need to consider now is: should we then update $\theta_i$ by directly applying $g_i^{glob}$ in the global optimization stage?
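A minimal PyTorch sketch of the student-side objective described above; weighting each teacher's KL term by its accuracy is our reading of Eq. (3), and the equal-weight sum of KL and CE terms in Eq. (6) is likewise an assumption.

```python
import torch
import torch.nn.functional as F


def global_mutual_loss(model, peer_batches):
    """Student-side loss, a sketch of Eqs. (3)-(6).

    `peer_batches` is a list of (x_pub_j, y_pub_j, teacher_soft_j, teacher_acc_j)
    tuples received from the other N-1 nodes (names are illustrative). Each
    teacher's KL term is weighted by its self-assessed accuracy a_j, used here
    as the "teaching confidence".
    """
    total = 0.0
    for x_pub, y_pub, teacher_soft, teacher_acc in peer_batches:
        logits = model(x_pub)
        student_logp = F.log_softmax(logits, dim=1)
        # KL(teacher || student), averaged over the public batch (Eqs. 3-4)
        kl = F.kl_div(student_logp, teacher_soft, reduction="batchmean")
        # Conventional supervised loss on the same public batch (Eq. 5)
        ce = F.cross_entropy(logits, y_pub)
        total = total + teacher_acc * kl + ce    # Eq. (6)
    return total
```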

Recall our learning goal across nodes: we pay attention not only to “plasticity” (improving generalization to other domains) but also to “stability” (reducing forgetting on the node’s own domain). Inspired by ideas from continual learning Lopez-Paz and Ranzato (2017), we introduce the backward transfer (BWT) and forward transfer (FWT) metrics here:

Backward transfer: $\mathrm{BWT}_i = \mathrm{acc}(f_i, D_i^{test})$. BWT is the within-domain performance of $f_i$ on node $i$’s own test data. On one hand, there exists positive backward transfer, where learning experience from other nodes benefits the performance on the node’s own domain. On the other hand, there also exists negative backward transfer; (large) negative backward transfer is also known as (catastrophic) forgetting.

Forward transfer: $\mathrm{FWT}_i = \frac{1}{N-1}\sum_{j \ne i} \mathrm{acc}(f_i, D_j^{test})$. FWT is the cross-domain performance of $f_i$ on all other nodes’ test data and reflects the model’s generalization.

Average accuracy: $\mathrm{ACC}_i = \frac{1}{N}\sum_{j=1}^{N} \mathrm{acc}(f_i, D_j^{test})$. ACC is the all-domain performance of $f_i$ on all nodes’ test data.
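A small evaluation sketch of the three metrics as defined above: per-domain accuracies are arranged so that BWT is the diagonal entry and FWT the off-diagonal mean. Loader and model handling are illustrative only.

```python
import torch


@torch.no_grad()
def evaluate_transfer(models, test_loaders):
    """Compute ACC / BWT / FWT for each node's model (sketch).

    `models` and `test_loaders` are parallel lists indexed by domain:
    acc[i][j] is the accuracy of model i on domain j's test data.
    """
    n = len(models)
    acc = [[0.0] * n for _ in range(n)]
    for i, model in enumerate(models):
        model.eval()
        for j, loader in enumerate(test_loaders):
            correct, total = 0, 0
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.numel()
            acc[i][j] = correct / total
    bwt = [acc[i][i] for i in range(n)]
    fwt = [sum(acc[i][j] for j in range(n) if j != i) / (n - 1) for i in range(n)]
    avg = [sum(acc[i]) / n for i in range(n)]
    return avg, bwt, fwt
```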

Here $\mathrm{acc}(\cdot,\cdot)$ is the conventional function for computing accuracy; the larger these metrics, the better the model. To obtain larger BWT and FWT, we reason as follows. Since $g_i^{glob}$ serves forward transfer (generalization), we hope the following constraint is satisfied on every node:

$\big\langle g_i^{glob},\, g_i^{local} \big\rangle \ge 0. \qquad (7)$

If the constraint is satisfied, then $g_i^{glob}$ is unlikely to increase the loss on the domain’s vanilla data, and thus we use $g_i^{glob}$ to directly update $\theta_i$, improving generalization without decreasing local performance. If the constraint is violated, we propose to project $g_i^{glob}$ to the closest gradient $\tilde{g}_i$ (in $\ell_2$ norm) satisfying constraint (7), so that $\tilde{g}_i$ is unlikely to increase either the local or the global loss:

$\min_{\tilde{g}_i}\ \frac{1}{2}\,\big\|\tilde{g}_i - g_i^{glob}\big\|_2^2 \quad \text{subject to} \quad \big\langle \tilde{g}_i,\, g_i^{local} \big\rangle \ge 0. \qquad (8)$

Computation of $\tilde{g}_i$.  We set $\tilde{g}_i = \mathrm{proj}(g_i^{glob}, g_i^{local})$, where the projection is obtained by optimizing the dual of a Quadratic Program (QP). To solve (8) efficiently, recall the primal of a QP Nocedal and Wright (2006) with inequality constraints:

$\min_{z}\ \frac{1}{2}\, z^{\top} Q z + p^{\top} z \quad \text{subject to} \quad A z \ge b, \qquad (9)$

where $Q \in \mathbb{R}^{P \times P}$ is a real symmetric matrix, $p \in \mathbb{R}^{P}$ is a real-valued vector, $A$ is a real matrix, $b$ is a real vector, and $P$ is the dimension of the gradient vector.

The solution to the dual problem provides a lower bound on the primal QP, and the Lagrangian dual of a QP is also a QP. Because the original problem has inequality constraints, these can be built into the Lagrangian. We write the Lagrangian function Bot et al. (2009) as:

$L(z, v) = \frac{1}{2}\, z^{\top} Q z + p^{\top} z + v^{\top} (b - A z), \quad v \ge 0. \qquad (10)$

Defining the (Lagrangian) dual function as $\Theta(v) = \inf_{z} L(z, v)$, we find the infimum of $L$, which occurs where its gradient with respect to $z$ equals zero; using $\nabla_z L = Q z + p - A^{\top} v = 0$ and the positive-definiteness of $Q$:

$z^{*} = -\,Q^{-1}\big(p - A^{\top} v\big). \qquad (11)$

So, the dual problem of (9) is:

$\max_{v}\ -\,\frac{1}{2}\,\big(p - A^{\top} v\big)^{\top} Q^{-1} \big(p - A^{\top} v\big) + v^{\top} b \quad \text{subject to} \quad v \ge 0. \qquad (12)$

With these notations, we write the primal QP of the projection (8) as:

$\min_{\tilde{g}_i}\ \frac{1}{2}\, \tilde{g}_i^{\top} \tilde{g}_i - \big(g_i^{glob}\big)^{\top} \tilde{g}_i \quad \text{subject to} \quad \big(g_i^{local}\big)^{\top} \tilde{g}_i \ge 0,$

i.e., $Q = I$, $p = -\,g_i^{glob}$, $A = \big(g_i^{local}\big)^{\top}$ and $b = 0$ (the constant term $\frac{1}{2}\|g_i^{glob}\|^2$ is dropped). According to the conversion above, we can pose the dual of the MAFML QP as:

$\min_{v}\ \frac{1}{2}\, v^2\, \big\|g_i^{local}\big\|^2 + v\, \big\langle g_i^{glob},\, g_i^{local} \big\rangle \quad \text{subject to} \quad v \ge 0. \qquad (13)$

Solving (13) gives $v^{*}$, so we reset the projected gradient as $\tilde{g}_i = g_i^{glob} + v^{*} g_i^{local}$, and then use $\tilde{g}_i$ to update $\theta_i$ in the global mutual optimization stage.
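With a single constraint per node, the dual QP (13) can be solved in closed form; the sketch below implements that special case (with several constraints a general QP solver, e.g. the quadprog package, would be used instead). This is our reading of the projection, not the authors' released code.

```python
import torch


def project_gradient(g_global, g_local, eps=1e-12):
    """Project g_global onto {g : <g, g_local> >= 0}, i.e. the single-constraint
    case of Eqs. (8)/(13), solved in closed form (sketch).

    The scalar dual optimum is v* = max(0, -<g_global, g_local>) / ||g_local||^2
    and the projected gradient is g~ = g_global + v* g_local. If constraint (7)
    already holds, g_global is returned unchanged.
    """
    g, g_ref = g_global.flatten(), g_local.flatten()
    dot = torch.dot(g, g_ref)
    if dot >= 0:
        return g_global                                   # Eq. (7) satisfied
    v_star = -dot / (torch.dot(g_ref, g_ref) + eps)       # dual optimum of (13)
    return (g + v_star * g_ref).view_as(g_global)
```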

Input: $N$ domains with data $D_1, \dots, D_N$; initialized networks $\theta_1, \dots, \theta_N$; learning rate $\eta$; interval $E$; a shared buffer.

Output: optimized networks $\theta_1, \dots, \theta_N$.

begin
       while not converged and max steps not reached do
             for $i = 1, \dots, N$ do
                   Sample a local batch $B_i$ and a public batch $B_i^{pub}$
                   Compute $g_i^{local}$ via Eq. (2) using $B_i$; update $\theta_i$
                   Compute $s_i^{(i)}$ and $a_i$ on $B_i^{pub}$
                   Share $s_i^{(i)}$ and $a_i$ to the shared buffer
             for $i = 1, \dots, N$ do
                   for $j \ne i$ do
                         Compute $s_i^{(j)}$ using $B_j^{pub}$
                   Compute $g_i^{glob}$ via Eq. (6) using $\{s_j^{(j)}, a_j\}_{j \ne i}$
                   if Eq. (7) is satisfied then
                         $\tilde{g}_i \leftarrow g_i^{glob}$
                   else
                         $\tilde{g}_i \leftarrow \mathrm{proj}(g_i^{glob}, g_i^{local})$ via Eq. (13)
                   Update $\theta_i$ using $\tilde{g}_i$

Algorithm 1: Model Agnostic Federated Mutual Learning
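As a rough illustration of how the two phases interleave, the sketch below wires together the helper functions sketched earlier (local_step, global_mutual_loss, project_gradient); the node object and its methods are hypothetical placeholders, not the paper's implementation.

```python
def mafml_round(nodes, shared_buffer, E, step):
    """One round of Algorithm 1 (sketch; node methods are hypothetical).

    Only each node's public-batch soft labels and self-assessed accuracy enter
    the shared buffer; since the public seed data are exchanged once in advance,
    per-round traffic is limited to these low-dimensional quantities.
    """
    # Local phase: every node trains on its own data and publishes its "awareness".
    for i, node in enumerate(nodes):
        x_loc, y_loc = node.sample_local_batch()
        x_pub, y_pub = node.sample_public_batch()
        soft, acc = local_step(node.model, node.optimizer, x_loc, y_loc, x_pub, y_pub)
        shared_buffer[i] = (x_pub, y_pub, soft, acc)

    # Global mutual phase, every E rounds: learn from the other nodes' posteriors.
    if step % E == 0:
        for i, node in enumerate(nodes):
            peers = [shared_buffer[j] for j in range(len(nodes)) if j != i]
            g_global = node.compute_grad(global_mutual_loss(node.model, peers))
            g_local = node.last_local_grad()
            g_final = project_gradient(g_global, g_local)  # Eq. (7) check + Eq. (13)
            node.apply_grad(g_final)
```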

Summary

Bringing all components together, we present the full procedure in Algorithm 1. To summarize, each domain first computes locally: it performs an SGD update with $g_i^{local}$ on its locally reserved data and then forms the “awareness” information on its public data as preparation for mutual learning. MAFML then completes the exchange of this public “awareness” among domains. During the global mutual optimization stage, we introduce the KL mimicry loss in addition to the conventional CE loss in the cohort learning to obtain the cross-domain public gradient $g_i^{glob}$. Then, based on the relationship between $g_i^{glob}$ and $g_i^{local}$, we compute the appropriate gradient $\tilde{g}_i$ as the final global gradient to update the networks. This serves not only positive forward transfer to other domains but also positive backward transfer, maintaining the model’s performance on its vanilla domain. To our knowledge, our work is the first to consider and analyze the “stability-plasticity” problem under the cross-domain FL setting.

Experiments

We evaluate our approach on the following cross-domain tasks: digit classification (Rotated MNIST) and image recognition (PACS, Office-Home). All of these datasets exhibit domain shift. We use the Ray framework (https://ray.io/) to implement our distributed application across different nodes. All experiments are implemented in PyTorch and run on a server with 4 GPUs. We compare MAFML to the following alternatives.

  • Independent (IND): each node uses only its own domain (pri+pub) data for conventional training (SGD based on the CE loss), completely avoiding interference from other nodes.

  • Aggregation (AGG): each node aggregates its own domain (pri+pub) data together with the public data from other domains for conventional training (SGD based on the CE loss). As in the experiments of Li et al. (2019b), AGG is usually a strong baseline to beat in the cross-domain situation.

  • FedMD Li and Wang (2019): applicable to heterogeneous FL, with a communication protocol based on model distillation.

  • FedAvg McMahan et al. (2017): the conventional method, applicable only to FL with homogeneous networks; a centralized server aggregates the nodes’ gradients and distributes the same parameters back to them.

Evaluation on Rotated MNIST

Dataset and settings.  Rotated MNIST Ghifary et al. (2015) contains several domains, each corresponding to a degree of rotation of the classic MNIST dataset, so different nodes have different data statistics. We follow this idea to build the Rotated MNIST dataset for our experiment. The basic view (M0) is formed by randomly choosing 100 images for each of the ten classes from the original MNIST dataset; we then create three rotated domains from M0 by rotating clockwise in 20° increments, denoted M20, M40 and M60, with the image dimensions unchanged. The data on each node is split by default into private, public, validation and test subsets.

We first experiment by deploying homogeneous networks (e.g., LeNet LeCun et al. (1998)) on these nodes. We train with the AMSGrad Reddi et al. (2018) optimizer (lr=1e-3, weight decay=1e-4) for 10,000 rounds with a batch size of 32. We consider that performance may depend on several factors: (1) the proportion of public data, which we set to 5%, 10% and 15%, with the validation and test proportions unchanged; note that the performance of IND and FedAvg does not depend on this value. (2) E: for MAFML, the global mutual optimization is conducted every E rounds, while the local optimization is carried out in every round, so the quantities used for the global optimization are computed over these E rounds. Note that even when the nodes share the same architecture, their network parameters differ, since each node updates its model independently without sharing parameters or gradients. When we mention network personality, we refer to both the architecture and the network parameters.
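For reference, the stated optimizer settings map onto PyTorch as follows (AMSGrad is exposed as a flag on torch.optim.Adam); this is a plausible configuration sketch, not the authors' released code.

```python
import torch


def make_optimizer(model):
    """AMSGrad with the Rotated MNIST hyper-parameters quoted above
    (lr=1e-3, weight decay=1e-4); the batch size of 32 is set on the data loader."""
    return torch.optim.Adam(model.parameters(), lr=1e-3,
                            weight_decay=1e-4, amsgrad=True)
```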

Method M0-LeNet M20-LeNet M40-LeNet M60-LeNet Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML(=5%,E=1) 86.17 88.67 85.33 86.33 93.33 85.11 87.50 93.33 85.78 87.17 96.00 84.22 86.79 92.83 85.11
MAFML(=5%,E=5) 87.17 93.33 85.11 88.00 94.67 85.78 86.00 90.00 84.67 85.50 90.00 84.00 86.67 92.00 84.89
MAFML(=5%,E=10) 86.50 91.33 84.89 87.83 94.00 85.78 87.83 96.00 85.11 85.83 95.33 82.67 87.00 94.16 84.61
AGG(=5%) 85.50 92.67 83.11 87.50 93.33 85.56 83.67 90.00 81.56 83.83 93.33 80.67 85.13 92.33 82.73
FedMD(=5%) 84.17 87.33 83.11 85.33 91.33 83.33 86.67 96.00 83.56 84.17 91.33 81.78 85.09 89.11 82.95
MAFML(=10%,E=1) 90.17 93.33 89.11 91.67 96.00 90.22 86.50 90.67 85.11 88.17 93.33 86.44 89.13 93.33 87.72
MAFML(=10%,E=5) 88.50 90.67 87.78 88.67 92.67 87.33 87.50 93.33 85.56 87.50 92.00 86.00 88.04 92.17 86.67
MAFML(=10%,E=10) 84.17 90.67 82.00 89.50 94.67 87.78 88.83 94.67 86.89 86.50 92.67 84.44 87.25 93.17 85.28
IND 66.39 91.33 58.08 78.11 94.00 72.82 72.39 93.11 65.48 56.89 91.78 45.48 68.45 92.56 60.47
AGG(=10%) 86.50 90.00 85.33 87.17 92.67 85.33 86.67 94.00 84.22 80.67 91.33 77.11 85.25 92.00 83.00
FedMD(=10%) 85.00 88.67 83.78 87.67 95.33 85.11 82.00 90.67 79.11 85.67 90.00 84.22 85.09 91.17 83.06
FedAvg 86.50 77.33 89.56 86.50 86.67 86.44 86.50 92.67 84.44 86.50 89.33 85.56 86.50 86.50 86.50
MAFML(unlabelled) 88.83 92.00 87.78 91.00 95.33 89.56 87.83 94.00 85.78 87.00 90.67 85.78 88.67 93.00 87.23
MAFML(=15%,E=1) 89.67 91.33 89.11 90.00 92.67 89.11 90.50 94.00 89.33 88.33 92.67 86.89 89.63 92.67 88.61
MAFML(=15%,E=5) 88.50 90.67 87.78 90.00 92.00 89.33 86.67 95.33 83.78 85.83 91.33 84.00 87.75 92.33 86.22
MAFML(=15%,E=10) 88.17 92.00 86.89 88.50 90.67 87.78 90.00 94.67 88.44 85.17 90.00 83.56 87.96 91.84 86.67
AGG(=15%) 87.83 92.00 86.44 89.67 92.10 88.44 87.83 94.00 85.78 86.00 91.33 84.22 87.83 92.47 86.22
FedMD(=15%) 88.67 89.33 88.44 89.00 93.33 87.56 85.00 90.00 83.33 84.33 92.67 81.56 86.75 91.33 85.22
Table 1: Test results (%) on the three metrics on Rotated MNIST. The percentage in each method’s parentheses is the proportion of public data and E is the interval (in rounds) between global mutual optimization steps.

Results.  Table 1 compares these methods as a function of the public-data proportion and E. We evaluate on the validation data every 50 rounds and keep the model with the maximal ACC for the final test on the three metrics; the best value for each metric is shown in bold. MAFML outperforms the other methods across a wide range of both settings. In general, increasing the proportion of public data brings a slight improvement to MAFML. Different values of E lead to different communication overheads among nodes and different amounts of gradient computation, and a smaller E usually achieves better performance. IND usually outperforms AGG and some other methods on BWT but has much lower FWT. AGG is strong enough to beat the purpose-built FedMD. For FedAvg, the gradient-based communication cost is more than 1000 times that of our “awareness”-based method; when we constrain its communication to the same magnitude as ours and record the results, MAFML clearly performs better on all metrics.

In addition, to visualize the results qualitatively, we show PCA projections of the features of all test data across all domains in Figure 1. Each dot denotes an image and its color denotes the label. MAFML provides improved overall separability on data from all domains.

Figure 1: PCA projections of features on all test data, using the model of domain M0 of Rotated MNIST as an example. Left: MAFML. Middle: IND. Right: AGG. Color: digit class.
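A sketch of the visualization procedure used for Figure 1; `model.features(x)` standing for the pre-classifier representation is a hypothetical attribute name, and the plotting details are illustrative.

```python
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


@torch.no_grad()
def plot_feature_pca(model, test_loader):
    """2-D PCA projection of penultimate-layer features, colored by class label."""
    feats, labels = [], []
    for x, y in test_loader:
        feats.append(model.features(x).flatten(1))   # hypothetical feature hook
        labels.append(y)
    feats = torch.cat(feats).cpu().numpy()
    labels = torch.cat(labels).cpu().numpy()
    z = PCA(n_components=2).fit_transform(feats)
    plt.scatter(z[:, 0], z[:, 1], c=labels, s=5, cmap="tab10")
    plt.show()
```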

Evaluation on PACS dataset

Dataset and settings.  PACS Li et al. (2017) is an object recognition benchmark for domain generalization. It contains 9,991 images from 4 different domains with 7 categories. The original PACS dataset has a fixed train/validation/test split, so to meet the needs of MAFML we separate out 10% of the test part as the public data and directly use the train part as our private data.

We first deploy ResNet18 as the MAFML network on all nodes. More importantly, MAFML has a natural advantage when the models are heterogeneous, so we also deploy ResNet18, ResNet34, AlexNet and VGG11 as the nodes’ respective networks. We use AMSGrad (lr=1e-4, weight decay=1e-5) to train for 10,000 rounds with a batch size of 32.

Method Photo-ResNet18 Art_painting-ResNet18 Cartoon-ResNet18 Sketch-ResNet18 Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML(E=1) 84.31 99.93 81.22 88.08 99.89 85.04 87.10 99.57 83.27 91.37 99.58 86.23 87.72 99.74 83.94
MAFML(E=5) 84.54 100.00 81.44 87.68 100.00 84.51 87.20 99.86 83.31 90.69 99.72 84.85 87.53 99.90 83.53
MAFML(E=10) 83.88 100.00 80.64 87.53 100.00 84.31 86.20 99.62 82.08 90.38 99.69 84.36 87.00 99.73 82.86
IND 51.45 100.00 41.70 70.53 99.95 62.94 73.48 99.95 65.36 62.95 99.89 39.05 64.60 99.95 52.26
AGG 84.52 99.93 81.42 86.30 99.89 82.79 85.46 100.00 81.00 89.35 99.89 82.53 86.41 99.93 81.94
FedMD 82.39 99.87 78.88 85.75 99.62 82.17 83.93 99.91 79.03 88.52 98.56 82.02 85.15 99.49 80.53
FedAvg 84.93 95.62 82.78 84.93 72.06 88.25 84.93 72.23 88.82 84.93 94.68 78.62 84.93 83.65 84.62
MAFML(unlabelled) 81.13 99.47 77.45 85.60 99.89 81.91 82.68 99.39 77.55 87.38 98.45 80.23 84.20 99.30 79.29
Photo-ResNet18 Art_painting-ResNet34 Cartoon-AlexNet Sketch-VGG11 Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML(E=1) 83.86 99.80 80.66 90.91 99.95 88.57 81.68 99.67 76.16 52.87 80.33 37.26 77.33 94.94 70.66
MAFML(E=5) 85.04 100.00 82.04 90.84 99.95 88.49 82.57 99.10 77.49 55.25 80.98 38.59 78.43 95.01 71.65
MAFML(E=10) 84.06 100.00 80.86 90.23 100.00 87.71 82.02 98.77 76.88 54.88 79.88 38.71 77.80 94.66 71.04
IND 51.08 99.57 41.29 77.72 99.30 72.15 68.52 99.39 59.05 44.79 78.75 22.83 60.53 94.25 48.83
AGG 84.90 100.00 81.90 89.50 100.00 86.85 80.80 98.77 75.28 52.81 78.01 36.51 77.00 94.20 70.14
FedMD 80.05 100.00 76.05 86.90 99.08 83.75 78.07 95.65 72.67 51.40 75.47 35.83 74.11 92.55 67.08
FedAvg - - - - - - - - - - - - - - -
MAFML(unlabelled) 78.38 99.93 74.05 84.48 97.40 81.14 72.63 99.86 64.27 44.26 63.20 32.01 69.94 90.10 62.87
Table 2: Test results (%) on the three metrics on PACS.

Results.  We can see from Table 2: (i) MAFML generally provides a consistent improvement over the other methods. The original communication bandwidth of FedAvg is about 10e6 times that of ours when using ResNet18, and even if we reduce FedAvg’s communication to 100 times ours by controlling its amount of local computation and its participating fraction McMahan et al. (2017), it still cannot catch up with MAFML. (ii) In particular, when the models are heterogeneous, FedAvg is inherently inapplicable. MAFML still beats the others, and E=5 usually brings more benefit. Moreover, even though the VGG11 network does not perform very well in the Sketch domain (see the IND/AGG BWT), it does not drag the other nodes down but still benefits them through mutual learning (MAFML outperforms the others on these domains). This also supports the role of $a_j$ in Eq. (3) as a “teaching confidence” obtained from each model’s self-assessment of its own ability.

Evaluation on Office-Home dataset

Dataset and settings.  The Office-Home Venkateswara et al. (2017b) dataset was initially proposed to evaluate domain adaptation for object recognition. It consists of 4 different domains, each containing images of 65 object categories from office and home settings. We split each domain’s data into private, public, validation and test subsets according to the default proportions. We randomly assign ResNet34, MobileNet, AlexNet and ResNet50 as the heterogeneous models and use the same hyper-parameters as in the PACS experiment.

Method Art-ResNet18 Clipart-ResNet18 Product-ResNet18 Real_world-ResNet18 Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML(E=1) 58.26 50.62 59.59 61.25 74.38 56.15 62.21 82.57 53.91 62.12 69.32 59.28 60.96 69.22 57.23
MAFML(E=5) 56.89 49.38 58.19 59.32 70.61 54.94 59.46 78.76 51.58 59.00 66.88 55.89 58.67 66.41 55.15
MAFML(E=10) 57.62 53.73 58.30 59.69 75.04 53.73 60.42 84.00 50.81 60.56 69.81 56.91 59.57 70.65 54.94
IND 35.58 48.13 33.41 44.90 76.03 32.83 47.29 80.03 33.94 51.92 68.34 45.45 44.92 68.13 36.41
AGG 54.41 50.00 55.17 57.58 72.91 51.63 60.15 82.88 50.87 58.45 68.18 54.61 57.65 68.49 53.07
FedMD 52.30 46.89 53.23 53.99 70.94 47.42 56.57 78.61 47.58 56.29 64.45 53.07 54.78 65.22 50.33
FedAvg 59.96 53.11 61.15 59.96 40.56 67.50 59.96 74.17 54.17 59.96 68.18 56.72 59.96 59.01 59.89
MAFML(unlabelled) 51.56 45.96 52.53 54.82 66.34 50.35 53.54 67.04 48.03 54.45 61.69 51.60 53.59 60.26 50.63
Art-ResNet34 Clipart-MobileNet Product-AlexNet Real_world-ResNet50 Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML(E=1) 65.52 58.70 66.70 73.55 76.52 72.40 59.64 80.82 51.00 60.97 70.29 57.30 64.92 71.58 61.85
MAFML(E=5) 63.50 58.70 64.33 72.18 79.47 69.34 59.05 81.46 49.90 59.55 72.08 54.61 63.57 72.93 59.55
MAFML(E=10) 64.65 60.87 65.70 71.63 78.16 69.09 58.59 79.71 49.97 59.14 68.99 55.25 63.50 71.93 60.00
IND 41.00 57.14 38.20 55.14 78.49 46.08 46.60 79.40 33.23 47.61 63.31 41.42 47.59 69.59 39.73
AGG 57.34 51.86 58.30 70.61 78.49 67.56 54.32 77.02 45.05 54.68 64.94 50.64 59.24 68.08 55.39
FedMD 55.46 55.59 55.44 67.49 77.50 63.61 53.17 75.59 44.02 51.74 59.42 48.72 56.97 67.03 52.95
FedAvg - - - - - - - - - - - - - - -
MAFML(unlabelled) 58.26 48.14 60.02 70.71 76.35 68.52 57.48 79.40 48.55 55.28 62.34 52.50 60.43 66.56 57.40
Table 3: Test results (%) on the three metrics on Office-Home.

Results.  Similarly, MAFML still gives a clear boost in overall accuracy, backward transfer and forward transfer in most cases, with both homogeneous and heterogeneous models.

Further Analysis

Optimization and loss analysis.  We analyze the advantages of MAFML in terms of learning accuracy and losses in Figure 2. Taking the learning of AlexNet in the Product domain of Office-Home as an example, Figure 2 (left) shows the evaluation of ACC on the validation data: MAFML converges faster to higher performance. Figure 2 (right) reflects the consistent utility of the KL loss, which acts mainly over the first 1000 rounds and yields the convergence and performance benefits seen in ACC. Figure 2 (middle) shows the loss during the local optimization: MAFML also benefits locally, thanks to its global mutual learning. We discuss the KL and CE losses during the global optimization in detail in the following ablation study.

Figure 2: Learning and loss curves on Office-Home. Left: ACC on validation data. Middle: Loss in local optimization. Right: CE and KL losses in global optimization of MAFML.
Method M0-LeNet M20-LeNet M40-LeNet M60-LeNet Ave.
ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT ACC BWT FWT
MAFML 90.17 94.67 89.11 91.67 95.33 90.22 86.50 94.00 85.11 88.17 94.00 86.44 89.13 94.50 87.72
MAFML(no KL) 85.83 91.33 84.44 85.33 89.33 82.89 88.83 94.67 85.78 87.17 91.33 84.89 86.79 91.67 84.50
MAFML(no proj) 90.00 92.00 88.44 88.17 92.67 88.67 89.17 92.67 87.78 86.50 93.33 84.89 88.46 92.67 87.45
MAFML(PCGrad) 88.83 92.67 87.56 89.17 93.33 87.78 88.67 93.33 87.11 86.67 91.33 85.11 88.34 92.67 86.89
Table 4: Ablation study on different components in global mutual optimization on Rotated MNIST.

Ablation on different design components in global mutual optimization.  In the global optimization, we introduce two primary designs: the KL mimicry loss of Eq. (3) alongside the conventional CE loss, and the projection operation that computes $\tilde{g}_i$ via Eqs. (8) and (13). We ablate each component here. Table 4 compares MAFML (public proportion 10%, E=1) on Rotated MNIST with the ablated variants.

The KL loss plays an important role not only in forward transfer but also in backward transfer. DML Zhang et al. (2018) analyzed how the KL loss improves network robustness by finding wider minima in the single-domain setting. Similarly, under our cross-domain setting, the KL loss guides the current network to match the posterior predictions of the other networks (the teachers) on the corresponding public data, which increases the model’s generalization (FWT) to other domains. Meanwhile, the soft labels used by the KL loss alleviate the domain-shift interference that would arise if only the hard true labels (used by the CE loss) of other domains guided the current model’s gradient. The KL loss, as a flexible adjustment and supplement to the CE loss, therefore helps maintain robust “stability” (BWT) in the current domain during mutual optimization.

If our projection operation is removed, then $\theta_i$ is updated directly with $g_i^{glob}$ and there is no guarantee that Eq. (7) holds, so the benefits achieved by the local optimization may be reduced. The results indeed show that performance worsens without the computation of $\tilde{g}_i$. Moreover, we compare with another gradient-projection method, PCGrad Yu et al. (2020), which handles conflicting gradients in a handcrafted way and is not QP-based. PCGrad performs unsatisfactorily, even slightly worse than using no projection at all: without the constrained optimization, some useful information in the gradients may be clipped by PCGrad.
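For comparison, the PCGrad rule referenced above amounts to the following pairwise “surgery” (a sketch of Yu et al. (2020), not the authors' code): when two gradients conflict, one is projected onto the normal plane of the other.

```python
import torch


def pcgrad_pair(g_i, g_j, eps=1e-12):
    """PCGrad-style pairwise gradient surgery (sketch).

    If g_i and g_j conflict (negative inner product), remove from g_i its
    component along g_j; unlike the QP-based projection above, this rule is
    applied per pair of gradients without a constrained optimization.
    """
    dot = torch.dot(g_i.flatten(), g_j.flatten())
    if dot < 0:
        g_i = g_i - dot / (g_j.flatten().norm() ** 2 + eps) * g_j
    return g_i
```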

Extension to unlabelled public data.  MAFML naturally extends to the situation where public data is available across nodes but without true labels (or where the true labels are not to be made public). The differences from the current case are: (i) only the domain’s own data is available for the local optimization, without the use of public data from other domains; and (ii) the CE loss is no longer available in the global optimization, so only the KL loss contributes to Eq. (6). Tables 1, 2 and 3 also report the MAFML (E=1) results with unlabelled public data. These MAFML results are still satisfactory, since the KL loss already carries much of the shared knowledge. AGG is no longer applicable, because the ground-truth labels of other domains’ public data are unavailable.

Conclusion

In this paper, we propose Model Agnostic Federated Mutual Learning (MAFML): an algorithm in which nodes “learn from and teach each other” collaboratively and communication-efficiently. MAFML does not restrict the model structure and achieves better plasticity and stability in the cross-domain situation. Our experiments show that MAFML provides a promising way for an industry alliance to handle the dilemma of “competition and collaboration”.

Ethics Impact

We introduced a new framework for federated learning in which models are agnostic and learn collaboratively through mutual learning. We address the realistic situation in which domain shift always exists among participants, and we deal with the valuable “plasticity and stability” problem in cross-domain federated learning. This study provides analysis and empirical evaluation to mitigate the concerns of participants who wish to join an industry alliance.

Specifically, our “awareness”-based method is much more communication-efficient and economical than the conventional gradient-based method, which significantly reduces environmental cost in the spirit of green AI Schwartz et al. (2019). Our “peer-to-peer” federated learning promotes fairness and equality among participants. It is more flexible and in line with real needs, since each participant can have its own customized model; even if the participants’ model abilities are unbalanced, each one can still benefit from the mutual learning to some extent. As for safety, not having to expose gradients or parameters reduces the risk of privacy leakage from them. Considering the broader impact, we also ask: what if there are malicious nodes in the alliance that provide false “teacher knowledge” to confuse others? In our work, the proposed gradient-update process in the global optimization provides a way to preserve the good nodes’ robust abilities in their vanilla domains and to prevent model collapse, but malicious nodes may still have the opportunity to learn knowledge from the good ones. We encourage more researchers to join us in studying regulation and reward-punishment mechanisms in federated learning, so that policies can be established to better guarantee the credibility of participants. As for the data, we consider the realistic situation of non-IID data with domain shift across participants, and we assume the small portion of public data contains (declassified) data with the differing statistics of all domains. We also generalize our work to other non-IID situations in the supplementary material, e.g., when the public data comes from an entirely different held-out domain, or when there is class shift instead of domain shift among participants; MAFML still performs better than the other methods. Looking further ahead, our method could be extended to handle heterogeneous cross-domain data (with different label spaces) by optimizing the model’s feature extractor and classifier separately and boosting each participant’s feature extractor through mutual learning; we leave this for future work. Beyond computer vision, our framework can also be applied to other applications, e.g., federated NLP and federated RL.

References

  • R. Anil, G. Pereyra, A. T. Passos, R. Ormandi, G. Dahl, and G. Hinton (2018) Large scale distributed neural network training through online distillation. In ICLR.
  • T. Batra and D. Parikh (2017) Cooperative learning with visual attributes. arXiv.
  • K. Bonawitz, V. Ivanov, B. Kreuter, A. Marcedone, H. B. McMahan, S. Patel, D. Ramage, A. Segal, and K. Seth (2017) Practical secure aggregation for privacy-preserving machine learning. In ACM SIGSAC Conference on Computer and Communications Security.
  • R. I. Bot, S. Grad, and G. Wanka (2009) Duality in vector optimization. Springer Science & Business Media.
  • K. Bousmalis, G. Trigeorgis, N. Silberman, D. Krishnan, and D. Erhan (2016) Domain separation networks. In NIPS.
  • G. A. Carpenter and S. Grossberg (1987) A massively parallel architecture for a self-organizing neural pattern recognition machine. Computer Vision, Graphics, and Image Processing 37 (1), pp. 54–115.
  • M. Ghifary, W. B. Kleijn, M. Zhang, and D. Balduzzi (2015) Domain generalization for object recognition with multi-task autoencoders. In CVPR.
  • S. Hardy, W. Henecka, H. Ivey-Law, R. Nock, G. Patrini, G. Smith, and B. Thorne (2017) Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption. arXiv.
  • D. He, Y. Xia, T. Qin, L. Wang, N. Yu, T. Liu, and W. Ma (2016) Dual learning for machine translation. In NIPS.
  • G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv.
  • Y. Jiang, J. Konečný, K. Rush, and S. Kannan (2019) Improving federated learning personalization via model agnostic meta learning. In NeurIPS Workshop.
  • J. Konečný, H. B. McMahan, F. X. Yu, P. Richtárik, A. T. Suresh, and D. Bacon (2016) Federated learning: strategies for improving communication efficiency. arXiv.
  • Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE.
  • W. S. Lee (2001) Collaborative learning for recommender systems. In ICML.
  • D. Li, Y. Yang, Y. Song, and T. M. Hospedales (2017) Deeper, broader and artier domain generalization. In ICCV.
  • D. Li, J. Zhang, Y. Yang, C. Liu, Y. Song, and T. M. Hospedales (2019a) Episodic training for domain generalization. In ICCV.
  • D. Li and J. Wang (2019) FedMD: heterogenous federated learning via model distillation. In NeurIPS Workshop.
  • Y. Li, Y. Yang, W. Zhou, and T. M. Hospedales (2019b) Feature-critic networks for heterogeneous domain generalisation. In ICML.
  • Y. Liu, T. Chen, and Q. Yang (2018) Secure federated transfer learning. arXiv.
  • M. Long, H. Zhu, J. Wang, and M. I. Jordan (2016) Unsupervised domain adaptation with residual transfer networks. In NIPS.
  • D. Lopez-Paz and M. Ranzato (2017) Gradient episodic memory for continual learning. In NIPS.
  • M. Luca, C. Song, E. D. Cristofaro, and V. Shmatikov (2018) Inference attacks against collaborative learning. arXiv.
  • M. McCloskey and N. J. Cohen (1989) Catastrophic interference in connectionist networks: the sequential learning problem. Psychology of Learning and Motivation 24, pp. 109–165.
  • H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-efficient learning of deep networks from decentralized data. In AISTATS.
  • M. Mohri, G. Sivek, and A. T. Suresh (2019) Agnostic federated learning. In ICML.
  • J. Nocedal and S. J. Wright (2006) Numerical optimization. Springer.
  • X. Peng, Z. Huang, Y. Zhu, and K. Saenko (2020) Federated adversarial domain adaptation. In ICLR.
  • J. Quiñonero-Candela, M. Sugiyama, A. Schwaighofer, and N. D. Lawrence (2009) Dataset shift in machine learning. The MIT Press.
  • S. Rebuffi, H. Bilen, and A. Vedaldi (2017) Learning multiple visual domains with residual adapters. In NIPS.
  • S. J. Reddi, S. Kale, and S. Kumar (2018) On the convergence of Adam and beyond. In ICLR.
  • M. Riemer, I. Cases, R. Ajemian, M. Liu, I. Rish, Y. Tu, and G. Tesauro (2019) Learning to learn without forgetting by maximizing transfer and minimizing interference. In ICLR.
  • R. Schwartz, J. Dodge, N. A. Smith, and O. Etzioni (2019) Green AI. arXiv.
  • T. Shen, J. Zhang, X. Jia, F. Zhang, G. Huang, P. Zhou, K. Kuang, F. Wu, and C. Wu (2020) Federated mutual learning. arXiv.
  • R. Shokri and V. Shmatikov (2015) Privacy-preserving deep learning. In ACM SIGSAC Conference on Computer and Communications Security.
  • V. Smith, C. Chiang, M. Sanjabi, and A. Talwalkar (2017) Federated multi-task learning. In NIPS.
  • H. Venkateswara, J. Eusebio, S. Chakraborty, and S. Panchanathan (2017a) Deep hashing network for unsupervised domain adaptation. In CVPR.
  • H. Venkateswara, J. Eusebio, S. Chakraborty, and S. Panchanathan (2017b) Deep hashing network for unsupervised domain adaptation. In CVPR.
  • Q. Yang, Y. Liu, T. Chen, and Y. Tong (2019) Federated machine learning: concept and applications. ACM Transactions on Intelligent Systems and Technology 10.
  • Y. Yang and T. M. Hospedales (2015) A unified perspective on multi-domain and multi-task learning. In ICLR.
  • Z. Yi, H. Zhang, P. Tan, and M. Gong (2017) DualGAN: unsupervised dual learning for image-to-image translation. In ICCV.
  • T. Yu, S. Kumar, A. Gupta, S. Levine, K. Hausman, and C. Finn (2020) Gradient surgery for multi-task learning. arXiv.
  • Y. Zhang, T. Xiang, T. M. Hospedales, and H. Lu (2018) Deep mutual learning. In CVPR.
  • W. Zhao, W. Xu, M. Yang, J. Ye, Z. Zhao, Y. Feng, and Y. Qiao (2017) Dual learning for cross-domain image captioning. In ACM Conference on Information and Knowledge Management.
  • Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra (2018) Federated learning with non-IID data. arXiv.
  • L. Zhu, Z. Liu, and S. Han (2019) Deep leakage from gradients. In NeurIPS.