Multi-Center Federated Learning

05/03/2020 ∙ by Ming Xie, et al. ∙ University of Technology Sydney, University of Washington

Federated learning has received great attention for its capability to train a large-scale model in a decentralized manner without needing to access user data directly. This helps protect users' private data from centralized collection. Unlike distributed machine learning, federated learning aims to tackle non-IID data from heterogeneous sources in various real-world applications, such as those on smartphones. Existing federated learning approaches usually adopt a single global model to capture the shared knowledge of all users by aggregating their gradients, regardless of the discrepancy between their data distributions. However, due to the diverse nature of user behaviors, assigning users' gradients to different global models (i.e., centers) can better capture the heterogeneity of data distributions across users. Our paper proposes a novel multi-center aggregation mechanism for federated learning, which learns multiple global models from the non-IID user data and simultaneously derives the optimal matching between users and centers. We formulate the problem as a joint optimization that can be efficiently solved by a stochastic expectation maximization (EM) algorithm. Our experimental results on benchmark datasets show that our method outperforms several popular federated learning methods.




1 Introduction

The widespread use of mobile phones and the Internet of Things has led to a huge volume of data generated by end users on mobile devices. Generally, service providers collect users' data on a central server to train machine learning models, including deep neural networks. This centralized approach causes severe practical issues, such as high communication costs, high consumption of device battery, and risks of violating the privacy and security of user data.

Federated learning [23] is a decentralized machine learning framework that learns models collaboratively from training data distributed on remote devices, which improves communication efficiency. Basically, it learns a shared pre-trained model by aggregating the locally computed updates derived from the training data on the participating devices. An aggregation algorithm averages the parameters of the many local models, weighted by the size of the training data on each device. Compared with conventional distributed machine learning, federated learning is robust to unbalanced and non-IID (not independent and identically distributed) data distributions, which are defining characteristics of modern AI products for mobile devices.

Vanilla federated learning introduced a very practical scenario in which 1) the central machine learning model does not need to access the distributed data, which protects users' private data, and 2) the distributed data are non-IID, which is a natural assumption for real-world applications. However, early federated learning approaches [23, 28] use only one global model as a single center to aggregate the information of all users. The stochastic gradient descent (SGD) used for single-center aggregation is designed for IID data and therefore conflicts with the non-IID setting of federated learning.

Recently, the non-IID or heterogeneity challenge of federated learning has been studied with the aim of improving the robustness of the global model against outlier or adversarial users and devices [9, 17, 18]. Moreover, [26] proposed clustered federated learning (FedCluster), which addresses the non-IID issue by dividing the users into multiple clusters. However, the hierarchical clustering in FedCluster is achieved by multiple rounds of bipartite separation, each of which requires running the federated SGD algorithm until convergence. Hence, its computation and communication costs become bottlenecks when it is applied to a large-scale federated learning system.

In this work, we propose a novel multi-center federated learning framework that updates multiple global models to aggregate information from multiple user groups, where the datasets of users in the same group are likely to be generated or derived from the same or similar distributions. We formulate multi-center federated learning as jointly optimizing the clustering of users and the global model of each cluster such that 1) each user's local model is assigned to its closest global model, and 2) each global model yields the smallest loss over all users in the associated cluster. We solve this problem with an optimization algorithm that can be described as an EM algorithm. The proposed multi-center federated learning not only inherits the communication efficiency of federated SGD but also retains the ability to handle non-IID data on heterogeneous datasets.

We summarise our main contributions as follows:

  • We propose a novel multi-center aggregation approach (Section 4.1) to address the non-IID challenge of federated learning.

  • We design an objective function, namely multi-center federated loss (Section 4.2), for user clustering in our problem.

  • We propose Federated Stochastic Expectation Maximization (FeSEM) (Section 4.3) to solve the optimization of the proposed objective function.

  • We present the algorithm (Section 4.4) as an easy-to-implement and strong baseline for federated learning¹, and evaluate its effectiveness on benchmark datasets (Section 5).

  ¹ The code is anonymously released at

2 Related Work

Federated learning enables users to benefit from machine learning models trained on rich data without compromising their own data. This emerging concept has attracted a great deal of research interest since 2017, with many papers investigating federated learning from several aspects, e.g., the system perspective, personalized models, scalability [3], communication efficiency [13], and privacy [8]. Most related works address a particular concern: improving security and privacy [25, 20, 5].

Federated learning was designed for specific scenarios, but it can be further expanded into a standard framework for preserving data privacy in large-scale machine learning systems. For example, Yang et al. [28] expanded federated learning into a comprehensive, secure federated learning framework that includes horizontal federated learning, vertical federated learning, and federated transfer learning. The work in [16] surveyed federated learning systems in terms of their data and privacy protection functions. [12] discussed advances and open problems in federated learning. [4] proposed LEAF, a benchmark for federated settings with multiple datasets. [22] proposed an object detection-based dataset for federated learning.

The heterogeneity challenge in the federated setting has been widely studied from various perspectives. [10] conducted a theoretical convergence analysis of federated learning with heterogeneous data. [11] measured the effects of non-IID data on federated visual classification. [19] argued that local representations make it possible to process data from new devices in different ways depending on their source modalities, instead of using a single global model, which might not generalize to unseen modalities and distributions. Li and Wang [14] proposed a new federated setting composed of a shared global dataset and many heterogeneous datasets from devices.

To solve the problems caused by non-IID or heterogeneous data in the federated setting [4], [26] proposed clustered federated learning (FedCluster), which integrates federated learning and bi-partitioning-based clustering into one framework. [9] proposed robust federated learning composed of three steps: 1) learning a local model on each device, 2) clustering the model parameters into multiple groups, each being a homogeneous dataset, and 3) running a robust distributed optimization [15] in each cluster. [18] proposed FedDANE by adapting DANE [27] to the federated setting; FedDANE is a federated Newton-type optimization method. [17] proposed FedProx as a generalization and re-parameterization of FedAvg [23]. It adds a proximal term to the objective function of each device's supervised learning task, where the proximal term measures the parameter-based distance between the server model and the local model. [1] added a personalized layer to each local model (FedPer) to tackle heterogeneous data.
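The proximal term used by FedProx [17] can be illustrated with a few lines of code. This is a minimal sketch, not FedProx's reference implementation: it assumes parameters are flattened into Python lists of floats and uses the usual penalty $\frac{\mu}{2}\|W - \tilde{W}\|_2^2$ between the local weights $W$ and the server weights $\tilde{W}$. The names `proximal_term`, `fedprox_local_objective`, and `toy_loss`, as well as the value of `mu`, are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def proximal_term(local: Vector, server: Vector, mu: float) -> float:
    """(mu / 2) * ||local - server||_2^2, pulling local weights toward the server."""
    return 0.5 * mu * sum((a - b) ** 2 for a, b in zip(local, server))


def fedprox_local_objective(task_loss: Callable[[Vector], float], local: Vector,
                            server: Vector, mu: float = 0.01) -> float:
    """Device-side supervised loss plus the proximal penalty (illustrative mu)."""
    return task_loss(local) + proximal_term(local, server, mu)


def toy_loss(w: Vector) -> float:
    """Hypothetical quadratic task loss whose optimum is at [1.0, 1.0]."""
    return sum((x - 1.0) ** 2 for x in w)


# At the task optimum the objective equals the proximal penalty alone.
print(fedprox_local_objective(toy_loss, [1.0, 1.0], [0.0, 0.0], mu=0.5))
```

Larger values of `mu` keep each device closer to the server model, which is how FedProx limits the drift caused by heterogeneous local data.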

Figure 1: Framework of federated learning.

3 Background

Federated learning is an approach to learning a centralized model by collecting information from distributed devices². As illustrated in Fig. 1, it consists of four steps in a loop: 1) fine-tuning a pre-trained model on each device with the user's private data, 2) uploading the local model to a central server, 3) aggregating the collected local models into the global model on the server, and 4) sending the aggregated global model back to replace each device's local model.

² We assume that each user is associated with one device that stores the user's data and maintains and updates a user-specific model.

In particular, the $i$-th device has a private dataset $D_i = \{X_i, Y_i\}$, where $X_i$ and $Y_i$ denote the input features and the corresponding gold labels, respectively. Each dataset $D_i$ is used to train a local supervised learning model $M_i$, where $M$ denotes a deep neural model parameterized by weights $W$. The model is built to solve a specific task, and all devices share the same model architecture. For the $i$-th device, given a private training set $D_i$, the training procedure of $M_i$ is briefly represented as

$$W_i^{*} = \arg\min_{W_i} \mathcal{L}(M_i, D_i, W_i), \tag{1}$$

where $\mathcal{L}(\cdot, \cdot, \cdot)$ is a general loss function for any supervised task whose arguments are the model structure, the training data, and the learnable parameters, respectively, and $W_i^{*}$ denotes the parameters after training.

In general, the data on a single device are insufficient to train a data-driven neural network with satisfactory performance. A federated learning framework is therefore designed to minimize the total loss over all devices, namely the supervised learning-based federated loss. Typically, vanilla federated learning uses a weighted average loss that accounts for the data size of each device, so the supervised learning-based federated loss is defined as

$$\mathcal{L}_{\text{FL}} = \sum_{i=1}^{m} \frac{|D_i|}{\sum_{j=1}^{m} |D_j|} \, \mathcal{L}(M_i, D_i, W_i), \tag{2}$$

where $m$ denotes the number of devices and $|D_i|$ is the size of the $i$-th device's dataset.

For model aggregation, vanilla federated learning uses a central server to aggregate all distributed models into a global model parameterized by $\tilde{W}$. In particular, the aggregation mechanism is a weighted average of the local model parameters collected from the devices, which is defined as

$$\tilde{W} = \sum_{i=1}^{m} \frac{|D_i|}{\sum_{j=1}^{m} |D_j|} \, W_i. \tag{3}$$
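The data-size weights $|D_i| / \sum_j |D_j|$ shared by the federated loss and the aggregation rule can be computed as follows. This is a minimal sketch, assuming each local model is flattened into a Python list of floats; `fedavg_aggregate`, `federated_loss`, and the toy numbers are hypothetical and only illustrate the arithmetic.

```python
from typing import List

Vector = List[float]


def fedavg_aggregate(local_weights: List[Vector], data_sizes: List[int]) -> Vector:
    """Weighted average of local parameters with weights |D_i| / sum_j |D_j|."""
    total = sum(data_sizes)
    dim = len(local_weights[0])
    global_w = [0.0] * dim
    for w_i, n_i in zip(local_weights, data_sizes):
        for d in range(dim):
            global_w[d] += (n_i / total) * w_i[d]
    return global_w


def federated_loss(local_losses: List[float], data_sizes: List[int]) -> float:
    """Data-size-weighted average of the per-device supervised losses."""
    total = sum(data_sizes)
    return sum((n_i / total) * loss for loss, n_i in zip(local_losses, data_sizes))


# Three hypothetical devices holding 100, 300, and 600 training examples.
print(fedavg_aggregate([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [100, 300, 600]))
print(federated_loss([1.0, 2.0], [1, 3]))
```

The device with 600 examples contributes 60% of every aggregated parameter, so large devices dominate the single global model.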


4 Methodology

Given the diverse nature of user behaviors, we argue that a single global model is not enough to capture the shared knowledge effectively. To overcome this challenge, we propose a novel model aggregation method with multiple centers, each representing a cluster of users with IID data.

4.1 Multi-center Aggregation

Vanilla federated learning stores the aggregation result in a single central model on the server, which is known as single-center aggregation. Users' behavior data are usually diverse, so their optimal local models vary as well. In our proposed method, all local models are grouped into $K$ clusters, denoted as $\mathcal{C} = \{C_1, \dots, C_K\}$. Each cluster $C_k$ consists of a set of local model parameters $\{W_i\}_{i \in C_k}$ and a corresponding center model $\tilde{W}^{(k)}$.

Figure 2: Comparison between single-center aggregation in vanilla federated learning (left) and multi-center aggregation in the proposed method (right). Each $W_i$ represents the local model's parameters collected from the $i$-th device and is drawn as a node in the parameter space. $\tilde{W}$ represents the aggregation result of multiple local models.

Fig. 2 gives an intuitive comparison between vanilla federated learning and our proposed method. As shown on the left, vanilla federated learning has only one center model. In contrast, the multi-center federated learning shown on the right has two centers, $\tilde{W}^{(1)}$ and $\tilde{W}^{(2)}$, each representing a cluster of devices with similar data distributions and models. The right configuration clearly has a smaller intra-cluster distance than the left one. As discussed in Section 4.2, the intra-cluster distance directly reflects the possible loss of the federated learning algorithm. Hence, a much smaller intra-cluster distance indicates that our approach can potentially reduce the federated learning loss.
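The intuition behind Fig. 2 can be checked numerically. The snippet below is a toy illustration with invented two-dimensional "local models" (real models have many more parameters); it compares the total squared distance to a single center with the total intra-cluster distance to two centers.

```python
from typing import List

Point = List[float]


def mean(points: List[Point]) -> Point:
    """Coordinate-wise mean, i.e., the center that minimizes squared distance."""
    return [sum(coords) / len(points) for coords in zip(*points)]


def total_sq_distance(points: List[Point], center: Point) -> float:
    return sum(sum((p - c) ** 2 for p, c in zip(pt, center)) for pt in points)


# Two hypothetical groups of local models, far apart in parameter space.
group_a = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
group_b = [[10.0, 10.0], [11.0, 10.0], [10.0, 11.0]]

single = total_sq_distance(group_a + group_b, mean(group_a + group_b))
multi = (total_sq_distance(group_a, mean(group_a))
         + total_sq_distance(group_b, mean(group_b)))
print(f"single-center: {single:.2f}, multi-center: {multi:.2f}")
```

With one center, most of the distance comes from the gap between the two groups; giving each group its own center removes that term and leaves only the small within-group spread.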

4.2 Objective Function

In the general setting of federated learning, the learning objective is to minimize the total loss of all supervised learning tasks on the devices, as described in Eq. 2. The model aggregation mechanism is a stochastic gradient descent (SGD) procedure that moves the central model's parameters toward the parameters of the local models. However, SGD training assumes that all distributed datasets are drawn from one IID source, whereas training on non-IID data is the most attractive characteristic of federated learning. To tackle the non-IID challenge in federated learning, we propose 1) the distance-based federated loss, a new objective function based on the distance between the parameters of the global and local models, and 2) the multi-center federated loss, the total distance-based loss when aggregating local models to multiple centers.

Figure 3: Optimization of a supervised learning task. There are two initialization points, one of which is closer to the global optimum than the other. Therefore, the model initialized at the closer point is more likely to converge at the global optimum.

Distance-based Federated Loss (DF-Loss)

The DF-Loss uses a distance-based loss to replace the supervised learning-based federated loss (Eq. 2) in the federated setting. In particular, the new loss is based on an empirical assumption: the model with better initialization is more likely to converge at the global optimum. This assumption is illustrated in Fig. 3. Moreover, considering the limited computation power and insufficient training data on each device, a “good” initialization is vitally important to train a supervised learning model on the device.

According to the above assumption, the overall goal of a federated learning task is to find a shared initialization that is close to the parameters of all models on the distributed devices, as illustrated on the left of Fig. 2. If the given initialization is close to the optimal parameters of all local models, a model trained from this initialization is more likely to converge efficiently at the global optimum, which decreases the supervised learning-based loss. Therefore, we argue that the supervised learning-based federated loss is consistent with the proposed distance-based federated loss when optimizing the federated learning model.

We therefore replace the learning objective of federated learning with minimizing the total distance between the global model and the local models. Formally, the new objective is to minimize the total loss

$$\mathcal{L}_{\text{DF}} = \frac{1}{m} \sum_{i=1}^{m} \mathrm{Dist}(W_i, \tilde{W}), \tag{4}$$

where $\mathrm{Dist}(\cdot, \cdot)$ is a function that measures the dissimilarity between the server model parameters $\tilde{W}$ and the local model parameters $W_i$. Note that a direct macro average is used here regardless of the weight of each device, which treats every device equally; the weights used in Eq. 2 can easily be incorporated for a micro average. In addition, any distance metric can be integrated into this framework. In this paper we explore only the simplest one, i.e., the (squared) L2 distance between two models:

$$\mathrm{Dist}(W_i, \tilde{W}) = \|W_i - \tilde{W}\|_2^2. \tag{5}$$
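The DF-Loss translates directly into code. This is a minimal sketch, assuming flattened parameter vectors and the squared form of the L2 distance; the optional `data_sizes` argument shows the micro-average variant mentioned above. The names `l2_sq` and `df_loss` are hypothetical.

```python
from typing import List, Optional

Vector = List[float]


def l2_sq(a: Vector, b: Vector) -> float:
    """Squared L2 distance between two flattened parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def df_loss(local_models: List[Vector], server: Vector,
            data_sizes: Optional[List[int]] = None) -> float:
    """Macro average of Dist(W_i, W~) by default; micro average if sizes are given."""
    dists = [l2_sq(w, server) for w in local_models]
    if data_sizes is None:
        return sum(dists) / len(dists)
    total = sum(data_sizes)
    return sum((n / total) * d for n, d in zip(data_sizes, dists))


local_models = [[0.0, 0.0], [2.0, 0.0]]
print(df_loss(local_models, [0.5, 0.0]))          # macro → 1.25
print(df_loss(local_models, [0.5, 0.0], [1, 3]))  # micro → 1.75
```

The micro variant is larger here because the device farther from the server model also holds more data.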


Multi-center DF-Loss

As elaborated above, our new learning objective for federated learning is to minimize the distance between the global and local models. Furthermore, according to the non-IID assumption, the datasets on different devices can be grouped into multiple clusters, where the on-device datasets in the same cluster are likely to be generated from one distribution. As illustrated on the right of Fig. 2, we can use the intra-cluster distance to measure their distance-based loss, and the new loss, namely the multi-center distance-based loss (MD-Loss), is defined as

$$\mathcal{L}_{\text{MD}} = \frac{1}{m} \sum_{k=1}^{K} \sum_{i=1}^{m} r_i^{(k)} \, \mathrm{Dist}(W_i, \tilde{W}^{(k)}), \tag{6}$$

where the cluster assignment $r_i^{(k)} \in \{0, 1\}$, as defined in Eq. 7, indicates whether device $i$ belongs to cluster $k$, and $\tilde{W}^{(k)}$ denotes the parameters of the aggregated model in cluster $k$.
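The MD-Loss with hard assignments can be sketched as below. As an assumption of this sketch, the binary indicators $r_i^{(k)}$ are stored compactly as one cluster index per device, and the distance is the squared L2 distance; the names and the four toy models are hypothetical. The example contrasts two well-placed centers with a single center for the same models.

```python
from typing import List

Vector = List[float]


def l2_sq(a: Vector, b: Vector) -> float:
    """Squared L2 distance between two flattened parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def md_loss(local_models: List[Vector], centers: List[Vector],
            assignments: List[int]) -> float:
    """(1/m) * sum over devices of Dist(W_i, W~^(k)) for the k with r_i^(k) = 1."""
    m = len(local_models)
    return sum(l2_sq(w, centers[k]) for w, k in zip(local_models, assignments)) / m


pts = [[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]]
print(md_loss(pts, [[0.0, 1.0], [10.0, 1.0]], [0, 0, 1, 1]))  # → 1.0
print(md_loss(pts, [[5.0, 1.0]], [0, 0, 0, 0]))               # → 26.0
```

With a single cluster every assignment is 0, and the MD-Loss reduces to the macro-averaged DF-Loss.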

4.3 Optimization Method

In general, Expectation-Maximization (EM) [2] can be used to solve distance-based clustering objectives, e.g., K-Means. However, in contrast to the general clustering objective, our proposed objective in Eq. 6 contains local model parameters $W_i$ that change dynamically during optimization. Therefore, we adapt the Stochastic Expectation Maximization (SEM) [6] optimization framework by adding one step, i.e., updating $W_i$. In the modified SEM framework, named federated SEM (FeSEM), we sequentially conduct: 1) the E-step, updating the cluster assignments $r_i^{(k)}$ with fixed $W_i$; 2) the M-step, updating the cluster centers $\tilde{W}^{(k)}$; and 3) updating the local models $W_i$ by providing the new initialization $\tilde{W}^{(k)}$.

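Initialize $K$, $\{\tilde{W}^{(k)}\}_{k=1}^{K}$, and $\{W_i\}_{i=1}^{m}$
while stop condition is not satisfied do
      E-Step:
      Calculate distance $\mathrm{Dist}(W_i, \tilde{W}^{(k)})$ for every device $i$ and cluster $k$
      Update $r_i^{(k)}$ using the distances (Eq. 7)
      M-Step:
      Group devices into $K$ clusters using $r_i^{(k)}$
      Update $\tilde{W}^{(k)}$ using $r_i^{(k)}$ and $W_i$ (Eq. 8)
      for each cluster $k = 1, \dots, K$ do
            for each device $i$ in cluster $k$ do
                  Send $\tilde{W}^{(k)}$ to device $i$
                  $W_i \leftarrow$ Local_update$(i, \tilde{W}^{(k)})$ (Algorithm 2)
            end for
      end for
end while
Algorithm 1: FeSEM (Federated Stochastic EM)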

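First, in the E-step, we calculate the distance between each cluster center $\tilde{W}^{(k)}$ and each node, where a node is a local model's parameters $W_i$, and then update the cluster assignment by

$$r_i^{(k)} = \begin{cases} 1, & \text{if } k = \arg\min_{j} \mathrm{Dist}(W_i, \tilde{W}^{(j)}), \\ 0, & \text{otherwise}. \end{cases} \tag{7}$$

Second, in the M-step, we update each cluster center $\tilde{W}^{(k)}$ according to $W_i$ and $r_i^{(k)}$, i.e.,

$$\tilde{W}^{(k)} = \frac{\sum_{i=1}^{m} r_i^{(k)} W_i}{\sum_{i=1}^{m} r_i^{(k)}}. \tag{8}$$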


Third, to update the local models, the global model's parameters $\tilde{W}^{(k)}$ are sent to each device in cluster $k$, and each device then fine-tunes its local model's parameters with a supervised learning algorithm on its own private training data. The local training procedure is similar to Eq. 1, except that the corresponding $\tilde{W}^{(k)}$ is used as the initialization instead of a random initialization.

Lastly, we repeat the three stochastic update steps above until convergence. The sequential execution of the three updates constitutes one iteration of FeSEM's optimization procedure. In particular, we sequentially update the three variables $r_i^{(k)}$, $\tilde{W}^{(k)}$, and $W_i$ while fixing the others. These three variables are jointly used to calculate the objective of our proposed multi-center federated learning in Eq. 6.
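The E-step and M-step of FeSEM are the two K-Means-style updates on flattened local-model parameters. The following is a minimal sketch under that assumption, using the squared L2 distance and one cluster index per device in place of the indicators $r_i^{(k)}$. Keeping the previous center for a cluster that receives no devices is a choice made for this sketch, since the M-step average is undefined when $\sum_i r_i^{(k)} = 0$. The function names are hypothetical.

```python
from typing import List

Vector = List[float]


def l2_sq(a: Vector, b: Vector) -> float:
    """Squared L2 distance between two flattened parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def e_step(local_models: List[Vector], centers: List[Vector]) -> List[int]:
    """Hard-assign each local model W_i to its nearest center (Eq. 7)."""
    return [min(range(len(centers)), key=lambda k: l2_sq(w, centers[k]))
            for w in local_models]


def m_step(local_models: List[Vector], assignments: List[int],
           old_centers: List[Vector]) -> List[Vector]:
    """Move each center to the mean of the local models assigned to it (Eq. 8)."""
    new_centers = []
    for k, old in enumerate(old_centers):
        members = [w for w, a in zip(local_models, assignments) if a == k]
        if members:
            new_centers.append([sum(col) / len(members) for col in zip(*members)])
        else:  # empty cluster: keep the previous center (sketch-specific choice)
            new_centers.append(list(old))
    return new_centers


pts = [[0.0, 0.0], [1.0, 0.0], [9.0, 9.0], [10.0, 9.0]]
assign = e_step(pts, [[0.0, 1.0], [8.0, 8.0]])
print(assign)                                         # → [0, 0, 1, 1]
print(m_step(pts, assign, [[0.0, 1.0], [8.0, 8.0]]))  # → [[0.5, 0.0], [9.5, 9.0]]
```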

4.4 Algorithm

Input: $i$ -- device index
    $\tilde{W}^{(k)}$ -- the model parameters from the server
Output: $W_i$ -- updated local model
Initialization: $W_i \leftarrow \tilde{W}^{(k)}$
for each local training step do
      Update $W_i$ with training data $D_i$ (Eq. 1)
end for
Return $W_i$ to the server
Algorithm 2: Local_update

We implement FeSEM as the iterative procedure in Algorithm 1. As elaborated in Section 4.3, each iteration comprises three steps that update the cluster assignments, the cluster centers, and the local models, respectively. In the third step, each local model is fine-tuned by running Algorithm 2.
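To show how Algorithms 1 and 2 fit together, here is a toy end-to-end sketch. It rests on several assumptions that are not part of the method description: each device's supervised loss is replaced by a quadratic $\|W - \theta_i\|_2^2$ around a hypothetical device optimum $\theta_i$, local training is plain gradient descent, the stop condition is a fixed number of iterations, and the centers are initialized by a deterministic farthest-point rule (Section 5.1 instead selects the best of 20 random clusterings). The names `fesem`, `local_update`, and `thetas` are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]


def l2_sq(a: Vector, b: Vector) -> float:
    """Squared L2 distance between two flattened parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def local_update(theta: Vector, init: Vector, steps: int = 1, lr: float = 0.25) -> Vector:
    """Algorithm 2 stand-in: start from `init` and run gradient descent on the
    toy device loss ||W - theta||^2, whose gradient is 2 * (W - theta)."""
    w = list(init)
    for _ in range(steps):
        w = [wd - lr * 2.0 * (wd - td) for wd, td in zip(w, theta)]
    return w


def fesem(thetas: List[Vector], k: int, iters: int = 20) -> Tuple[List[Vector], List[int]]:
    dim = len(thetas[0])
    # Initial local models: every device trains from a shared zero vector.
    local = [local_update(t, [0.0] * dim, steps=5) for t in thetas]
    # Farthest-point initialization of the K centers (a choice for this sketch).
    centers = [list(local[0])]
    while len(centers) < k:
        far = max(range(len(local)),
                  key=lambda i: min(l2_sq(local[i], c) for c in centers))
        centers.append(list(local[far]))
    assign: List[int] = []
    for _ in range(iters):  # stop condition: a fixed number of iterations
        # E-step (Eq. 7): assign each local model to its nearest center.
        assign = [min(range(k), key=lambda c: l2_sq(w, centers[c])) for w in local]
        # M-step (Eq. 8): average the local models in each cluster.
        for c in range(k):
            members = [w for w, a in zip(local, assign) if a == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
        # Local update: send each center to its devices and run one local step.
        local = [local_update(t, centers[a]) for t, a in zip(thetas, assign)]
    return centers, assign


# Two hypothetical groups of devices with nearby optima.
thetas = [[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [5.0, 5.0], [5.2, 5.0], [5.0, 5.2]]
centers, assign = fesem(thetas, k=2)
print(assign)  # → [0, 0, 0, 1, 1, 1]
```

In this toy setting each center settles at the mean of its group's optima, because every round moves the center halfway toward that mean.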

5 Experiments

As a proof-of-concept scenario to demonstrate the effectiveness of the proposed method, we experimentally evaluate and analyze the proposed FeSEM on federated benchmarks.

5.1 Training Setups


We employed two publicly available federated benchmark datasets introduced in LEAF [4], a benchmarking framework for learning in federated settings. The datasets are Federated Extended MNIST (FEMNIST) [7] and Federated CelebA (FedCelebA) [21]. In FEMNIST, we split the handwritten images according to their writers. For FedCelebA, we extracted the face images of each person and developed an on-device classifier to recognize whether the person is smiling. Table 1 gives a statistical description of the datasets.


| Dataset | FEMNIST | FedCelebA |
|---|---|---|
| Content type | Handwritten | Face |
| # of instances | 805,263 | 202,599 |
| # of classes | 62 | 2 |
| # of devices | 3,400 | 9,343 |
| # of inst. per dev. | 227 | 21 |

Table 1: Statistics of the datasets. "# of inst. per dev." is the average number of instances per device.

Local Model

For FEMNIST, we used a three-layer CNN as the local model, a standard architecture introduced in a TensorFlow tutorial. For FedCelebA, we used a CNN with the same architecture as in [21].


We compared the proposed federated learning method with the following nine baseline methods.

  1. NoFed: we conduct the supervised learning task on each device without the federated learning framework.

  2. FedSGD: uses SGD to optimise the global model.

  3. FedAvg: SGD-based federated learning that weights the average by the data size of each device [23].

  4. Clustered: embeds FedAvg in a hierarchical clustering framework [26].

  5. Robust: a modular framework consisting of a robust clustering model and a communication-efficient, distributed, robust optimization run separately over each cluster [9].

  6. FedDANE: a federated learning framework with a Newton-type optimization method [18].

  7. FedProx: adds a proximal term to the objective function of the learning task on each device [17].

  8. FedDist: we adapt the distance-based objective function of Reptile meta-learning [24] to the federated setting.

  9. FedDWS: a variant of FedDist that changes the aggregation to a weighted average, where the weights depend on the data size of each device.

  10. FeSEM($K$): our proposed multi-center federated learning method, solved via federated SEM (FeSEM) with different numbers of clusters $K$.

Training Settings

For each device's data, we used 80% for training and 20% for testing. The federated learning algorithm was used to train an optimal global model on all devices' training data. The learned global model was then sent to each device for further fine-tuning with local training data, and the fine-tuned local model was tested on each device. To initialize the cluster centers in FeSEM, we ran pure clustering 20 times with different random initializations and selected the "best" result, i.e., the one with the minimal intra-cluster distance, as the initial centers for FeSEM. For the local update procedure of FeSEM (Algorithm 2), we set the number of local training steps to 1, meaning that each local model $W_i$ was updated only once per local update.
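The best-of-20 center initialization can be sketched as K-Means with random restarts. This is a hedged sketch: the clustering routine is not specified beyond "pure clustering", so plain Lloyd's K-Means on flattened local-model parameters is assumed here, and `kmeans`, `best_initial_centers`, and their defaults are hypothetical.

```python
import random
from typing import List, Tuple

Vector = List[float]


def l2_sq(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], k: int, rng: random.Random,
           iters: int = 50) -> Tuple[List[Vector], float]:
    """One run of Lloyd's K-Means started from k randomly chosen points."""
    centers = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        assign = [min(range(k), key=lambda c: l2_sq(p, centers[c])) for p in points]
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    # Total intra-cluster distance: each point to its nearest final center.
    cost = sum(min(l2_sq(p, c) for c in centers) for p in points)
    return centers, cost


def best_initial_centers(points: List[Vector], k: int,
                         restarts: int = 20, seed: int = 0) -> List[Vector]:
    """Run K-Means `restarts` times and keep the centers with the lowest cost."""
    rng = random.Random(seed)
    runs = [kmeans(points, k, rng) for _ in range(restarts)]
    return min(runs, key=lambda run: run[1])[0]


local_models = [[0.0, 0.0], [0.2, 0.0], [0.0, 0.2], [5.0, 5.0], [5.2, 5.0], [5.0, 5.2]]
print(best_initial_centers(local_models, k=2))
```

Restarts matter because a single K-Means run can stall in a poor local optimum, for example when both initial centers land in the same group; keeping the lowest-cost run makes that outcome unlikely.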

Evaluation Metrics

Given numerous devices, we evaluated the overall performance of the federated learning methods. We used classification accuracy and F1 score as the metrics for the two benchmarks. In addition, due to the multiple devices involved, we explored two ways to calculate the metrics, i.e., micro and macro. The only difference is that when computing an overall metric, “micro” calculates a weighted average of the metrics from devices where the weight is proportional to the data amount, while “macro” directly calculates an average over the metrics from devices.
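The difference between the two averaging schemes is easiest to see on a small example. In this sketch, `per_device` holds any per-device metric (accuracy or F1) and `data_sizes` holds the data amounts used as micro weights; both names and the numbers are hypothetical.

```python
from typing import List, Tuple


def micro_macro(per_device: List[float], data_sizes: List[int]) -> Tuple[float, float]:
    """Micro: data-size-weighted average of device metrics; macro: plain average."""
    total = sum(data_sizes)
    micro = sum(v * n / total for v, n in zip(per_device, data_sizes))
    macro = sum(per_device) / len(per_device)
    return micro, macro


# One large device with accuracy 0.9 and one small device with accuracy 0.1.
print(micro_macro([0.9, 0.1], [100, 10]))
```

Here the micro average stays high (about 0.83) while the macro average drops to 0.5, which is why the two views can rank methods differently.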

| Method | FEMNIST Micro-Acc | FEMNIST Micro-F1 | FEMNIST Macro-Acc | FEMNIST Macro-F1 | FedCelebA Micro-Acc | FedCelebA Micro-F1 | FedCelebA Macro-Acc | FedCelebA Macro-F1 |
|---|---|---|---|---|---|---|---|---|
| NoFed | 79.4 | 67.6 | 81.3 | 51.0 | 83.8 | 66.0 | 83.9 | 67.2 |
| FedSGD | 70.1 | 61.2 | 71.5 | 46.7 | 75.7 | 60.7 | 75.6 | 55.6 |
| FedAvg | 84.9 | 67.9 | 84.9 | 45.4 | 86.1 | 78.0 | 86.1 | 54.2 |
| FedDist | 79.3 | 67.5 | 79.8 | 50.5 | 71.8 | 61.0 | 71.6 | 61.1 |
| FedDWS | 80.4 | 67.2 | 80.6 | 51.7 | 73.4 | 59.3 | 73.4 | 50.3 |
| Robust(TKM) | 78.4 | 53.1 | 77.6 | 53.6 | 90.1 | 68.0 | 90.1 | 68.3 |
| Clustered | 84.1 | 64.3 | 84.2 | 64.4 | 86.7 | 67.8 | 87.0 | 67.8 |
| FedDANE | 40.0 | 31.8 | 41.7 | 31.7 | 76.6 | 61.8 | 75.9 | 62.1 |
| FedProx | 51.8 | 34.2 | 52.3 | 34.4 | 83.4 | 60.9 | 84.3 | 65.2 |
| FeSEM(2) | 84.8 | 65.5 | 84.8 | 52.0 | 89.1 | 64.6 | 89.0 | 56.0 |
| FeSEM(3) | 87.0 | 68.5 | 86.9 | 41.7 | 88.1 | 64.3 | 87.5 | 55.9 |
| FeSEM(4) | 90.3 | 70.6 | 91.0 | 53.4 | 93.6 | 74.8 | 94.1 | 69.5 |

Table 2: Comparison of our proposed FeSEM($K$) algorithm with the baselines on the FEMNIST and FedCelebA datasets (all metrics in %). The number in parentheses following "FeSEM" denotes the number of clusters $K$.

5.2 Experimental Study

Comparison Study

As shown in Table 2, we compared our proposed FeSEM with the baselines and found that our federated learning framework achieves the best performance in most cases. However, the proposed model performs worse than FedAvg in terms of Micro-F1 on the FedCelebA dataset; a possible reason is that our objective function in Eq. 6 does not take the device weights into account. For the same reason, our model tends to deliver larger improvements in terms of the "macro" metrics, which also treat all devices equally. Furthermore, the last three rows of Table 2 show that FeSEM with a larger number of clusters generally achieves better performance, with FeSEM(4) performing best, which supports the non-IID assumption about the data distribution.

Figure 4: Convergence analysis of the proposed FeSEM with different numbers of clusters (in parentheses) in terms of micro-accuracy.

Convergence Analysis

To verify the convergence of the proposed approach, we ran FeSEM with different numbers of clusters (from 2 to 4) for 100 iterations. As shown in Fig. 4, FeSEM converges efficiently on both datasets and achieves the best performance when the number of clusters is $K = 4$.

Clustering Analysis

To check the effectiveness of our optimization method and whether the devices grouped into one cluster have similar models, we conducted a clustering analysis via visualization. We used two-dimensional plots to display the clustering results of the local models derived from FeSEM(4) on the two benchmark datasets. In particular, we randomly chose 100 devices from each benchmark dataset and plotted each device's local model as one point in the 2D space after PCA-based dimension reduction. As shown in Fig. 5, both datasets fit well into four clusters, and the clusters are distinguishable from each other.
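The PCA projection used for this kind of plot can be sketched in pure Python, assuming each local model is flattened into a parameter vector. This sketch finds the top two principal components by power iteration with deflation, which is adequate for toy dimensions; for real CNN parameter vectors a library routine such as scikit-learn's `sklearn.decomposition.PCA` is the more practical choice. The function names and example data are hypothetical.

```python
from typing import List

Vector = List[float]


def _matvec(m: List[Vector], v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def _normalize(v: Vector) -> Vector:
    norm = sum(x * x for x in v) ** 0.5
    return [x / norm for x in v]


def pca_2d(points: List[Vector], iters: int = 200) -> List[Vector]:
    """Project flattened model parameters onto their top-2 principal components."""
    n, d = len(points), len(points[0])
    mean = [sum(col) / n for col in zip(*points)]
    centered = [[p[j] - mean[j] for j in range(d)] for p in points]
    cov = [[sum(r[a] * r[b] for r in centered) / n for b in range(d)] for a in range(d)]
    components = []
    for _ in range(2):
        # Power iteration from a fixed, non-degenerate start vector.
        v = _normalize([1.0 + 0.1 * j for j in range(d)])
        for _ in range(iters):
            v = _normalize(_matvec(cov, v))
        eigval = sum(a * b for a, b in zip(v, _matvec(cov, v)))
        components.append(v)
        # Deflation: remove the found component so the next pass finds the second.
        cov = [[cov[a][b] - eigval * v[a] * v[b] for b in range(d)] for a in range(d)]
    return [[sum(r[j] * c[j] for j in range(d)) for c in components] for r in centered]


# Six hypothetical 3-parameter "local models" forming two groups.
models = [[0.0, 0.0, 0.0], [0.5, 0.0, 1.0], [0.0, 0.5, -1.0],
          [10.0, 10.0, 0.0], [10.5, 10.0, 1.0], [10.0, 10.5, -1.0]]
for point in pca_2d(models):
    print([round(x, 2) for x in point])
```

The first coordinate separates the two groups, which is the kind of structure the 2D scatter plots are meant to reveal.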

Figure 5: Clustering analysis of the models on different devices using t-SNE dimension reduction. The models are derived from FeSEM(4) on FEMNIST and FedCelebA, respectively.
Figure 6: Case study for clustering from FeSEM(2) on FEMNIST (top) and FedCelebA (bottom).

Case Study on Clustering

To intuitively judge whether the devices grouped into one cluster have similar data distributions, we conducted a case study on the two clusters extracted from a trained FeSEM(2) model. For FEMNIST, as shown at the top of Fig. 6, cluster 1 tends to contain handwriting in a regular font, while cluster 2 handles free-style handwritten letters. For FedCelebA, as shown at the bottom of Fig. 6, the face recognition task in cluster 1 tends to handle smiling faces with relatively simple backgrounds.

6 Conclusion

In this work, we propose a novel multi-center federated learning framework to tackle the non-IID challenge of the federated setting. The proposed method can efficiently capture the multiple hidden distributions of numerous devices or users. We then propose an optimization approach, federated SEM (FeSEM), to solve the multi-center federated learning problem effectively. The experimental results show the effectiveness of our algorithm, and several analyses provide further insight into the proposed approach.