The widespread adoption of mobile phones and Internet-of-Things devices has led to a huge volume of data being generated by end-users on their devices. Typically, service providers collect users’ data on a central server to train machine learning models, including deep neural networks. This centralized approach causes severe practical issues, such as high communication costs, heavy consumption of device battery, and risks to the privacy and security of user data.
Federated learning is a decentralized machine learning framework that collaboratively learns a model from training data distributed across remote devices, thereby improving communication efficiency. Essentially, it learns a shared model by aggregating the locally-computed updates derived from the training data on participating devices. An aggregation algorithm averages the local models’ parameters, weighted by the size of the training data on each device. Compared with conventional distributed machine learning, federated learning is robust to unbalanced and non-IID (not Independent and Identically Distributed) data distributions, which are a defining characteristic of modern AI applications on mobile devices.
The vanilla federated learning formulation targets a very practical scenario in which 1) the central model does not need to access the distributed data, which protects users’ private data, and 2) the distributed data are non-IID, which is a natural assumption for real-world applications. However, early federated learning approaches [23, 28] use only one global model as a single center to aggregate the information of all users. The stochastic gradient descent (SGD) underlying single-center aggregation is designed for IID data and therefore conflicts with the non-IID setting of federated learning.
Recently, the non-IID or heterogeneity challenge of federated learning has been studied with the goal of improving the robustness of the global model against outlier/adversarial users or devices [9, 17, 18]. Moreover, Sattler et al. proposed clustered federated learning (FedCluster), which addresses the non-IID issue by dividing the users into multiple clusters. However, the hierarchical clustering in FedCluster is achieved by multiple rounds of bipartite separation, each of which requires running the federated SGD algorithm to convergence. Hence, its computation and communication costs become bottlenecks when it is applied to a large-scale federated learning system.
In this work, we propose a novel multi-center federated learning framework that updates multiple global models to aggregate information from multiple user groups. In particular, the datasets of the users in the same group are likely to be generated from the same or similar distributions. We formulate the problem of multi-center federated learning as jointly optimizing the clustering of users and the global model of each cluster such that 1) each user’s local model is assigned to its closest global model, and 2) each global model leads to the smallest loss over all the users in the associated cluster. The optimization algorithm we use to solve this problem can be described as an Expectation-Maximization (EM)-style algorithm. The proposed multi-center federated learning not only inherits the communication efficiency of federated SGD but also retains the capability of handling non-IID data on heterogeneous datasets.
We summarize our main contributions as follows:
We propose a novel multi-center aggregation approach (Section 4.1) to address the non-IID challenge of federated learning.
We design an objective function, namely multi-center federated loss (Section 4.2), for user clustering in our problem.
We propose Federated Stochastic Expectation Maximization (FeSEM) (Section 4.3) to solve the optimization of the proposed objective function.
2 Related Work
Federated learning enables users to benefit from machine learning models trained on rich, distributed data without compromising their data privacy. This emerging concept has attracted considerable research interest since 2017, with many papers investigating federated learning from several aspects, e.g., system design, personalized models, scalability, communication efficiency, and privacy. Most related works address a particular concern: improving security and privacy [25, 20, 5].
Federated learning is designed for specific scenarios and can be further expanded into a standard framework for preserving data privacy in large-scale machine learning systems. For example, Yang et al. expanded federated learning into a comprehensive, secure framework that includes horizontal federated learning, vertical federated learning, and federated transfer learning. Li et al. surveyed federated learning systems in relation to their functions on data and privacy protection. Kairouz et al. discussed the advances and open problems in federated learning. Caldas et al. proposed LEAF, a benchmark for federated settings with multiple datasets. Luo et al. proposed an object-detection-based dataset for federated learning.
The heterogeneity challenge in the federated setting has been widely studied from various perspectives. Haddadpour and Mahdavi conducted a theoretical convergence analysis for federated learning with heterogeneous data. Hsu et al. measured the effects of non-IID data on federated visual classification. Liang et al. argued that local representations allow the data from new devices to be processed in different ways depending on their source modalities, instead of relying on a single global model, which might not generalize to unseen modalities and distributions. Li and Wang proposed a new federated setting composed of a shared global dataset and many heterogeneous datasets from devices.
To address the problems caused by non-IID or heterogeneous data in the federated setting, Sattler et al. proposed clustered federated learning (FedCluster) by integrating federated learning and bipartitioning-based clustering into an overall framework. Ghosh et al. proposed robust federated learning composed of three steps: 1) learning a local model on each device, 2) clustering the model parameters into multiple groups, each corresponding to a homogeneous dataset, and 3) running a robust distributed optimization in each cluster. Li et al. proposed FedDANE, a federated Newton-type optimization method, by adapting DANE to the federated setting. Li et al. also proposed FedProx, a generalization and re-parameterization of FedAvg, which adds a proximal term to the objective function of each device’s supervised learning task; the proximal term measures the parameter-based distance between the server model and the local model. Arivazhagan et al. added a personalized layer to each local model, i.e., FedPer, to tackle heterogeneous data.
Federated learning is an approach to learn a centralized model by collecting information from distributed devices (we assume that each user is associated with one device that stores the user’s data and maintains/updates a user-specific model). As illustrated in Fig. 1, it consists of four steps in a loop: 1) fine-tuning a pre-trained model with the user’s private data on the device, 2) uploading the local model to a central server, 3) aggregating the collected local models into a global model on the server, and 4) sending the aggregated global model back to replace each device’s local model.
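The four-step loop above can be sketched in a few lines of Python. This is a toy sketch under simplifying assumptions: models are plain NumPy parameter vectors, and `local_train` is a hypothetical stand-in for on-device fine-tuning, not the actual training procedure.

```python
import numpy as np

def local_train(global_params, data):
    # Step 1: fine-tune the received model on private data.
    # Toy stand-in: one nudge toward the local data mean.
    return 0.5 * global_params + 0.5 * np.mean(data, axis=0)

def federated_round(global_params, device_datasets):
    # Steps 2-3: devices upload local models; the server aggregates
    # them, weighting each by its local data size.
    locals_, sizes = [], []
    for data in device_datasets:
        locals_.append(local_train(global_params, data))
        sizes.append(len(data))
    weights = np.array(sizes) / sum(sizes)
    new_global = sum(w * p for w, p in zip(weights, locals_))
    # Step 4: the aggregated model is sent back to every device.
    return new_global

datasets = [np.ones((4, 2)), np.zeros((8, 2))]
updated = federated_round(np.zeros(2), datasets)
```

Running several such rounds in sequence yields the iterative loop described above.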
In particular, the $i$-th device has a private dataset $\mathcal{D}_i = \{X_i, Y_i\}$, where $X_i$ and $Y_i$ denote the input features and corresponding gold labels respectively. Each dataset $\mathcal{D}_i$ is used to train a local supervised learning model $f(W_i)$, where $f$ denotes a deep neural model parameterized by weights $W_i$. It is built to solve a specific task, and all devices share the same model architecture. For the $i$-th device, given the private training set $\mathcal{D}_i$, the training procedure is briefly represented as
$$\tilde{W}_i = \arg\min_{W_i} \mathcal{L}_s(f, \mathcal{D}_i, W_i), \quad (1)$$
where $\mathcal{L}_s$ is a general definition of the loss function for any supervised task; its arguments are the model structure, the training data, and the learnable parameters respectively, and $\tilde{W}_i$ denotes the parameters after training.
In general, the data on a single device is insufficient to train a data-driven neural network with satisfactory performance. A federated learning framework is thus designed to minimize the total loss over all devices, namely the supervised learning-based federated loss. Typically, vanilla federated learning uses a weighted average loss that accounts for the data size of each device. The supervised learning-based federated loss is therefore defined as
$$\mathcal{L} = \sum_{i=1}^{m} \frac{|\mathcal{D}_i|}{N} \, \mathcal{L}_s(f, \mathcal{D}_i, W_i), \quad (2)$$
where $m$ denotes the number of devices and $N = \sum_{i=1}^{m} |\mathcal{D}_i|$ is the total number of training samples.
For model aggregation, vanilla federated learning uses a central server to aggregate all distributed models into a global one parameterized by $\tilde{W}$. In particular, the aggregation mechanism is a weighted average of the local model parameters collected from the devices, defined as
$$\tilde{W} = \sum_{i=1}^{m} \frac{|\mathcal{D}_i|}{N} \, \tilde{W}_i. \quad (3)$$
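The weighted-average aggregation can be illustrated with a short sketch, assuming (for illustration only) that each local model is already flattened into a NumPy vector; `aggregate` is a hypothetical helper, not part of any library.

```python
import numpy as np

# Minimal sketch of weighted-average aggregation: the global model is
# the data-size-weighted mean of the local parameter vectors.
def aggregate(local_params, data_sizes):
    weights = np.asarray(data_sizes, dtype=float)
    weights /= weights.sum()                 # |D_i| / N
    stacked = np.stack(local_params)         # one row per device
    return weights @ stacked                 # weighted average

w_global = aggregate([np.array([1.0, 3.0]), np.array([3.0, 1.0])], [1, 3])
```

With data sizes [1, 3], the second device contributes three times the weight of the first.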
Given the diverse nature of user behaviors, we argue that one global model is not enough to capture the shared knowledge effectively. To overcome this challenge, we propose a novel model aggregation method with multiple centers, each representing a cluster of users with IID data.
4.1 Multi-center Aggregation
Vanilla federated learning uses one central model in the server to store the aggregation result, which is known as single-center aggregation. Users’ behavior data are usually diverse, which means the optimal local models will vary. In our proposed method, all local models are grouped into $K$ clusters, denoted as $\{C_1, \ldots, C_K\}$. Each cluster $C_k$ consists of a set of local model parameters $\{\tilde{W}_i \mid i \in C_k\}$ and a corresponding center model $\tilde{W}^{(k)}$.
An intuitive comparison between vanilla federated learning and our proposed method is illustrated in Fig. 2. As shown on the left, there is only one center model in vanilla federated learning. In contrast, the multi-center federated learning shown on the right has two centers, $\tilde{W}^{(1)}$ and $\tilde{W}^{(2)}$, and each center represents a cluster of devices with similar data distributions and models. Obviously, the right one has a smaller intra-cluster distance than the left one. As discussed in Section 4.2, the intra-cluster distance directly reflects the loss of the federated learning algorithm; hence, a much smaller intra-cluster distance indicates that our proposed approach can potentially reduce this loss.
4.2 Objective Function
In the general setting of federated learning, the learning objective is to minimize the total loss of all supervised learning tasks on the devices, as described in Eq. 2. The model aggregation mechanism is a stochastic gradient descent (SGD) procedure that adjusts the central model’s parameters to approach the parameters of the local models. However, the SGD training process is based on the assumption that all the distributed datasets are drawn from one IID source, while training on non-IID data is the most attractive characteristic of federated learning. To tackle the non-IID challenge in federated learning, we propose 1) the distance-based federated loss – a new objective function using a distance between the parameters of the global and local models, and 2) the multi-center federated loss – the total distance-based loss for aggregating local models into multiple centers.
Distance-based Federated Loss (DF-Loss)
The DF-Loss uses a distance-based loss to replace the supervised learning-based federated loss (Eq. 2) in the federated setting. In particular, the new loss is based on an empirical assumption: a model with better initialization is more likely to converge to the global optimum. This assumption is illustrated in Fig. 3. Moreover, considering the limited computation power and insufficient training data on each device, a “good” initialization is vitally important for training a supervised learning model on the device.
According to the above assumption, the overall goal of a federated learning task is to find a shared initialization close to the model parameters of all distributed devices. This is illustrated on the left of Fig. 1. If the given initialization is close to the optimal parameters of all local models, then a model trained from this initialization is more likely to converge efficiently to the global optimum, which in turn decreases the supervised learning-based loss. Therefore, we argue that the supervised learning-based federated loss is consistent with the proposed distance-based federated loss when optimizing the federated learning model.
The learning objective of federated learning is thus replaced with minimizing the total distance between the global model and the local models. Formally, the new objective is to minimize the total loss
$$\mathcal{L} = \frac{1}{m} \sum_{i=1}^{m} \mathrm{Dist}(\tilde{W}_i, \tilde{W}), \quad (4)$$
where $\mathrm{Dist}(\cdot,\cdot)$ denotes a function measuring the dissimilarity between the server model parameters $\tilde{W}$ and a local model $\tilde{W}_i$. Note that a direct macro average is used here regardless of the weight of each device, which treats every device equally; the weights used in Eq. 2 can easily be incorporated for a micro average. In addition, any distance metric can be integrated into this framework, and in this paper we explore only the simplest one, the L2 distance between two models:
$$\mathrm{Dist}(\tilde{W}_i, \tilde{W}) = \left\| \tilde{W}_i - \tilde{W} \right\|^2. \quad (5)$$
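A minimal sketch of this L2 distance between two models, under the assumption that each model is a list of NumPy weight arrays (the helper name is ours):

```python
import numpy as np

# Sketch of the L2 parameter distance Dist(W_i, W~): flatten all layers
# of each model into one vector, then take the squared Euclidean norm
# of the difference.
def l2_distance(params_a, params_b):
    va = np.concatenate([p.ravel() for p in params_a])
    vb = np.concatenate([p.ravel() for p in params_b])
    return float(np.sum((va - vb) ** 2))

d = l2_distance([np.array([[1.0, 0.0]]), np.array([2.0])],
                [np.array([[0.0, 0.0]]), np.array([0.0])])
```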
As elaborated above, our new learning objective for federated learning is to minimize the distance between the global and local models. Furthermore, according to the non-IID assumption, the datasets on different devices can be grouped into multiple clusters such that the on-device datasets in the same cluster are likely to be generated from one distribution. As illustrated on the right of Fig. 2, we can use the intra-cluster distance to measure the distance-based loss; the new loss, namely the multi-center distance-based loss (MD-Loss), is defined as
$$\mathcal{L} = \frac{1}{m} \sum_{k=1}^{K} \sum_{i=1}^{m} r_i^{(k)} \, \mathrm{Dist}(\tilde{W}_i, \tilde{W}^{(k)}), \quad (6)$$
where the cluster assignment $r_i^{(k)} \in \{0, 1\}$, as defined in Eq. 7, indicates whether the $i$-th device belongs to cluster $k$, and $\tilde{W}^{(k)}$ is the parameters of the aggregated model in cluster $k$.
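The MD-Loss can be computed with a few lines, assuming hard assignments are stored as a list of cluster indices rather than an indicator matrix (an equivalent encoding; all names are illustrative):

```python
import numpy as np

# Sketch of the multi-center distance-based loss (MD-Loss): sum the
# squared L2 distance from each local model to its assigned cluster
# center, averaged over all devices.
def md_loss(local_models, centers, assignment):
    total = 0.0
    for i, w_i in enumerate(local_models):
        k = assignment[i]                    # the k with r_i^(k) = 1
        total += np.sum((w_i - centers[k]) ** 2)
    return total / len(local_models)

models = [np.array([0.0]), np.array([2.0]), np.array([10.0])]
centers = [np.array([1.0]), np.array([10.0])]
loss = md_loss(models, centers, [0, 0, 1])
```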
4.3 Optimization Method
In general, Expectation-Maximization (EM) can be used to solve a distance-based clustering objective such as that of K-Means. However, in contrast to the general clustering objective, our proposed objective in Eq. 6 has dynamically changing local models $\tilde{W}_i$ during optimization. Therefore, we adapt the Stochastic Expectation Maximization (SEM) optimization framework by adding one step, i.e., updating $\tilde{W}_i$. In the modified SEM framework, named federated SEM (FeSEM), we sequentially conduct: 1) E-step – updating the cluster assignments $r_i^{(k)}$ with fixed $\tilde{W}_i$, 2) M-step – updating the cluster centers $\tilde{W}^{(k)}$, and 3) updating the local models $\tilde{W}_i$ by providing the new initialization $\tilde{W}^{(k)}$.
Firstly, for the E-step, we calculate the distance between each cluster center and the nodes – each node being a local model’s parameters $\tilde{W}_i$ – and then update the cluster assignment by
$$r_i^{(k)} = 1 \ \text{if}\ k = \arg\min_j \mathrm{Dist}(\tilde{W}_i, \tilde{W}^{(j)}), \ \text{and}\ r_i^{(k)} = 0\ \text{otherwise}. \quad (7)$$
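Concretely, the E-step reduces to a nearest-center assignment. The sketch below treats each local model as a flat NumPy vector; `e_step` is an illustrative name.

```python
import numpy as np

# E-step sketch: assign each device to the cluster whose center is
# closest to its local model parameters (squared L2 distance).
def e_step(local_models, centers):
    assignment = []
    for w_i in local_models:
        dists = [np.sum((w_i - c) ** 2) for c in centers]
        assignment.append(int(np.argmin(dists)))
    return assignment

a = e_step([np.array([0.1]), np.array([4.9])],
           [np.array([0.0]), np.array([5.0])])
```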
Secondly, for the M-step, we update each cluster center $\tilde{W}^{(k)}$ according to the assignments $r_i^{(k)}$ and the local models $\tilde{W}_i$, i.e.,
$$\tilde{W}^{(k)} = \frac{1}{\sum_{i=1}^{m} r_i^{(k)}} \sum_{i=1}^{m} r_i^{(k)} \, \tilde{W}_i. \quad (8)$$
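The M-step is then a per-cluster mean of the assigned local models, again sketched on flat parameter vectors with hard assignments:

```python
import numpy as np

# M-step sketch: each cluster center becomes the mean of the local
# models currently assigned to that cluster.
def m_step(local_models, assignment, n_clusters):
    centers = []
    for k in range(n_clusters):
        members = [w for w, a in zip(local_models, assignment) if a == k]
        centers.append(np.mean(members, axis=0))
    return centers

c = m_step([np.array([0.0]), np.array([2.0]), np.array([10.0])],
           [0, 0, 1], 2)
```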
Thirdly, to update the local models, the global model’s parameters $\tilde{W}^{(k)}$ are sent to each device in cluster $k$ to replace its local model, and we then fine-tune each local model’s parameters with a supervised learning algorithm on the device’s own private training data. The local training procedure is the same as Eq. 1, except that the corresponding $\tilde{W}^{(k)}$ is used as the initialization instead of a random initialization.
Lastly, we repeat the three stochastic update steps above until convergence. The sequential execution of the three updates composes one iteration of FeSEM’s optimization procedure. In particular, we sequentially update the three variables $r_i^{(k)}$, $\tilde{W}^{(k)}$, and $\tilde{W}_i$ while fixing the other factors. These three variables are jointly used to compute the objective of our proposed multi-center federated learning in Eq. 6.
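Putting the three steps together, the whole FeSEM loop can be sketched on toy one-dimensional “models”. Here the local update is replaced by a hypothetical nudge of each model from its cluster center toward its own local optimum; real training would instead run SGD on the device’s private data.

```python
import numpy as np

# End-to-end FeSEM sketch on toy 1-D models. `optima` stands in for the
# per-device optimal parameters that local training would move toward.
def fesem(optima, centers, n_iters=20):
    models = list(optima)
    for _ in range(n_iters):
        # E-step: assign each device to its nearest center.
        assign = [int(np.argmin([abs(m - c) for c in centers]))
                  for m in models]
        # M-step: recompute each center as the mean of its members.
        for k in range(len(centers)):
            members = [m for m, a in zip(models, assign) if a == k]
            if members:
                centers[k] = float(np.mean(members))
        # Local update: re-initialize from the center, then fine-tune
        # toward the device's own optimum (toy stand-in for SGD).
        models = [0.5 * centers[a] + 0.5 * o
                  for a, o in zip(assign, optima)]
    return centers, assign

centers, assign = fesem([0.0, 1.0, 10.0, 11.0], [0.0, 5.0])
```

On this toy input the two well-separated groups of devices end up in separate clusters, each with a center near the group mean.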
We implement FeSEM as Algorithm 1, an iterative procedure. As elaborated in Section 4.3, each iteration comprises three steps that update the cluster assignments, the cluster centers, and the local models respectively. In particular, the third step, which updates the local models, fine-tunes each local model by invoking Algorithm 2.
As a proof-of-concept scenario to demonstrate the effectiveness of the proposed method, we experimentally evaluate and analyze the proposed FeSEM on federated benchmarks.
5.1 Training Setups
We employed two publicly available federated benchmark datasets introduced in LEAF, a benchmarking framework for learning in federated settings. The datasets are Federated Extended MNIST (FEMNIST, http://www.nist.gov/itl/products-and-services/emnist-dataset) and Federated CelebA (FedCelebA, http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). In FEMNIST, we split the handwritten images according to the writers. For FedCelebA, we extracted the face images of each person and developed an on-device classifier to recognize whether the person is smiling or not. A statistical description of the datasets is given in Table 1.
|                     | FEMNIST | FedCelebA |
|---------------------|---------|-----------|
| # of instances      | 805,263 | 202,599   |
| # of classes        | 62      | 2         |
| # of devices        | 3,400   | 9,343     |
| # of inst. per dev. | 227     | 21        |
For FEMNIST, we used a three-layer CNN as the local model – a standard architecture introduced in a TensorFlow tutorial (https://www.katacoda.com/basiafusinska/courses/tensorflow-getting-started/tensorflow-mnist-expert). For FedCelebA, we used a CNN with the same architecture as in prior work.
We compared the proposed federated learning method with the baseline methods below.
NonFed: the supervised learning task is conducted on each device separately, without any federated learning framework.
FedSGD: uses federated SGD to optimize a single global model.
FedAvg: SGD-based federated learning that weights each device’s contribution by its data size during averaging.
Clustered: encloses FedAvg in a hierarchical clustering framework.
Robust: a modular framework consisting of a robust clustering model and a communication-efficient, distributed, robust optimization run over each cluster separately.
FedDANE: a federated learning framework with a Newton-type optimization method.
FedProx: adds a proximal term to the objective function of the learning task on each device.
FedDist: adapts the distance-based objective function from Reptile meta-learning to the federated setting.
FedDWS: a variation of FedDist that changes the aggregation to weighted averaging, where the weight depends on the data size of each device.
FeSEM($K$): our proposed multi-center federated learning method solved via federated SEM (FeSEM) with different numbers of clusters $K$.
For each device’s data, we used 80% for training and 20% for testing. The federated learning algorithm trains an optimal global model using all devices’ training data. The learned global model is sent to each device for further fine-tuning with the local training data, and the fine-tuned local model is then tested on each device. For the initialization of the cluster centers in FeSEM, we ran pure clustering 20 times with different random initializations and selected the “best” initialization – the one with the minimal intra-cluster distance – as the initial centers for FeSEM. For the local update procedure of FeSEM, we set the number of local training epochs to 1, meaning each local model $\tilde{W}_i$ is updated only once per local update.
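The best-of-20 center initialization can be sketched as follows. This is a toy version that picks candidate centers from the local models themselves and keeps the restart with the smallest intra-cluster distance; the restart count and selection rule follow the setup above, everything else (names, seeding) is illustrative.

```python
import numpy as np

# Total intra-cluster (squared L2) distance of points to nearest center.
def intra_cluster_distance(points, centers):
    return sum(min(np.sum((p - c) ** 2) for c in centers) for p in points)

# Run several random restarts and keep the initialization with the
# smallest intra-cluster distance.
def best_init(points, k, n_restarts=20, seed=0):
    rng = np.random.default_rng(seed)
    best, best_cost = None, np.inf
    for _ in range(n_restarts):
        idx = rng.choice(len(points), size=k, replace=False)
        centers = [points[i] for i in idx]
        cost = intra_cluster_distance(points, centers)
        if cost < best_cost:
            best, best_cost = centers, cost
    return best

pts = [np.array([0.0]), np.array([0.1]), np.array([9.0]), np.array([9.1])]
centers = best_init(pts, k=2)
```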
Given the numerous devices, we evaluated the overall performance of the federated learning methods. We used classification accuracy and F1 score as the metrics for the two benchmarks. In addition, because multiple devices are involved, we explored two ways of computing the metrics, micro and macro. The only difference is that when computing an overall metric, “micro” calculates a weighted average of the per-device metrics, with weights proportional to each device’s data size, while “macro” directly averages the per-device metrics.
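The micro/macro distinction can be made precise with a tiny example (illustrative helpers): a device with many samples dominates the micro average but counts the same as any other device in the macro average.

```python
# Macro: average devices equally. Micro: weight each device's metric by
# its data size.
def macro_average(metrics):
    return sum(metrics) / len(metrics)

def micro_average(metrics, sizes):
    total = sum(sizes)
    return sum(m * s for m, s in zip(metrics, sizes)) / total

accs, sizes = [0.9, 0.5], [90, 10]
macro = macro_average(accs)          # equal weight per device
micro = micro_average(accs, sizes)   # dominated by the 90-sample device
```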
5.2 Experimental Study
As shown in Table 2, we compared our proposed FeSEM with the baselines and found that our federated learning framework achieves the best performance in most cases. However, the proposed model achieves inferior performance on the micro F1 score on the FedCelebA dataset; a possible reason is that our objective function defined in Eq. 6 does not take the device weights into account, so our model delivers significant improvements mainly in terms of the “macro” metrics. Furthermore, from the last three columns of Table 2, we found that FeSEM with a larger number of clusters empirically achieves better performance, which supports the non-IID assumption on the data distribution.
To verify the convergence of the proposed approach, we conducted a convergence analysis by running FeSEM with different cluster numbers (from 2 to 4) for 100 iterations. As shown in Fig. 4, FeSEM converges efficiently on both datasets, and it achieves the best performance with an appropriately chosen cluster number.
To check the effectiveness of our proposed optimization method and whether the devices grouped into one cluster have similar models, we conducted a clustering analysis. We used two-dimensional figures to display the clustering results of the local models derived from FeSEM(4) on the two benchmark datasets. In particular, we randomly chose 100 devices from each benchmark dataset and plotted each device’s local model as one point in 2D space after PCA-based dimensionality reduction. As shown in Fig. 5, both datasets are well described by four clusters, and the clusters are clearly distinguishable from each other.
Case Study on Clustering
To intuitively judge whether the devices grouped into one cluster have similar data distributions, we conducted case studies on two clusters extracted from a trained FeSEM(2) model. For FEMNIST, as shown at the top of Fig. 6, cluster 1 mostly recognizes handwriting in a regular font, while cluster 2 handles freestyle handwritten letters. For FedCelebA, as shown at the bottom of Fig. 6, the face recognition task in cluster 1 mostly handles smiling faces with relatively simple backgrounds.
In this work, we propose a novel multi-center federated learning framework to tackle the non-IID challenge of the federated setting. The proposed method efficiently captures the multiple hidden distributions of numerous devices or users. An optimization approach, federated SEM (FeSEM), is then proposed to solve the multi-center federated learning problem effectively. The experimental results show the effectiveness of our algorithm, and several analyses provide further insight into the proposed approach.
-  Arivazhagan, M.G., Aggarwal, V., Singh, A.K., Choudhary, S.: Federated learning with personalization layers. arXiv preprint arXiv:1912.00818 (2019)
-  Bishop, C.M.: Pattern recognition and machine learning. Springer (2006)
-  Bonawitz, K., Eichner, H., et al.: Towards federated learning at scale: System design. ArXiv abs/1902.01046 (2019)
-  Caldas, S., Wu, P., Li, T., Konečnỳ, J., McMahan, H.B., Smith, V., Talwalkar, A.: Leaf: A benchmark for federated settings. arXiv preprint arXiv:1812.01097 (2018)
-  Cao, T.D., Truong-Huu, T., Tran, H., Tran, K.: A federated learning framework for privacy-preserving and parallel training (2020)
-  Cappé, O., Moulines, E.: On-line expectation–maximization algorithm for latent data models. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 71(3), 593–613 (2009)
-  Cohen, G., Afshar, S., Tapson, J., van Schaik, A.: Emnist: an extension of mnist to handwritten letters. arXiv preprint arXiv:1702.05373 (2017)
-  Geyer, R.C., Klein, T., Nabi, M.: Differentially private federated learning: A client level perspective. arXiv preprint arXiv:1712.07557 (2017)
-  Ghosh, A., Hong, J., Yin, D., Ramchandran, K.: Robust federated learning in a heterogeneous environment. arXiv preprint arXiv:1906.06629 (2019)
-  Haddadpour, F., Mahdavi, M.: On the convergence of local descent methods in federated learning. arXiv preprint arXiv:1910.14425 (2019)
-  Hsu, T.M.H., Qi, H., Brown, M.: Measuring the effects of non-identical data distribution for federated visual classification. arXiv preprint arXiv:1909.06335 (2019)
-  Kairouz, P., McMahan, H.B., et al.: Advances and open problems in federated learning. arXiv preprint arXiv:1912.04977 (2019)
-  Konecný, J., McMahan, H.B., Yu, F.X., Richtárik, P., Suresh, A.T., Bacon, D.: Federated learning: Strategies for improving communication efficiency. CoRR abs/1610.05492 (2018)
-  Li, D., Wang, J.: Fedmd: Heterogenous federated learning via model distillation. arXiv preprint arXiv:1910.03581 (2019)
-  Li, L., Xu, W., Chen, T., Giannakis, G.B., Ling, Q.: Rsa: Byzantine-robust stochastic aggregation methods for distributed learning from heterogeneous datasets. In: AAAI. vol. 33, pp. 1544–1551 (2019)
-  Li, Q., Wen, Z., He, B.: Federated learning systems: Vision, hype and reality for data privacy and protection. arXiv preprint arXiv:1907.09693 (2019)
-  Li, T., Sahu, A.K., Zaheer, M., Sanjabi, M., Talwalkar, A., Smith, V.: Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127 (2018)
-  Li, T., Sahu, A.K., Zaheer, M., Sanjabi, M., Talwalkar, A., Smith, V.: Feddane: A federated newton-type method. arXiv preprint arXiv:2001.01920 (2020)
-  Liang, P.P., Liu, T., Ziyin, L., Salakhutdinov, R., Morency, L.P.: Think locally, act globally: Federated learning with local and global representations. arXiv preprint arXiv:2001.01523 (2020)
-  Liu, Y., Ma, Z., Liu, X., Wang, Z., Ma, S., Ren, K.: Revocable federated learning: A benchmark of federated forest. arXiv preprint arXiv:1911.03242 (2019)
-  Liu, Z., Luo, P., Wang, X., Tang, X.: Deep learning face attributes in the wild. In: Proceedings of the IEEE ICCV. pp. 3730–3738 (2015)
-  Luo, J., Wu, X., Luo, Y., Huang, A., Huang, Y., Liu, Y., Yang, Q.: Real-world image datasets for federated learning. arXiv preprint arXiv:1910.11089 (2019)
-  McMahan, H.B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.: Communication-efficient learning of deep networks from decentralized data (2016)
-  Nichol, A., Schulman, J.: Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999 2 (2018)
-  Rouhani, B.D., Riazi, M.S., Koushanfar, F.: Deepsecure: Scalable provably-secure deep learning. In: The 55th Annual Design Automation Conference. p. 2. ACM (2018)
-  Sattler, F., Müller, K.R., Samek, W.: Clustered federated learning: Model-agnostic distributed multi-task optimization under privacy constraints. arXiv preprint arXiv:1910.01991 (2019)
-  Shamir, O., Srebro, N., Zhang, T.: Communication-efficient distributed optimization using an approximate newton-type method. In: ICML. pp. 1000–1008 (2014)
-  Yang, Q., Liu, Y., Chen, T., Tong, Y.: Federated machine learning: Concept and applications. ACM Transactions on Intelligent Systems and Technology (TIST) 10(2), 12 (2019)