Federated learning (FL) is an innovative privacy-preserving machine learning paradigm that distributes collaborative model training across participating user devices without users sharing their raw training samples. In the widely used federated learning algorithm FedAvg fedavg, clients jointly train a shared machine learning model by iteratively running local updates and synchronizing their intermediate local models with a central server. In spite of its success in applications such as next word prediction next_word and learning on electronic health records medical_fl, FL is known to suffer from slow model training when clients’ local data distributions are heterogeneous, or non-IID (non-independently and identically distributed). In response to this challenge, some recent works propose to bypass data heterogeneity by performing local model personalization. Instead of pursuing one universally applicable model shared by all clients, these algorithms aim to create one model for each client that fits its local data. Personalization methods include local fine tuning fine_tune, model interpolation three_approaches, and multi-task learning fmtl. In this paper, we focus on an alternative approach: clustered federated learning, which we generalize to train both cluster and personalized models on realistic distributions of client data.
Clustered FL relaxes the assumption of FL that each client has a unique data distribution; instead, it allows different clients to share one data distribution, with fewer source data distributions than clients. The objective of clustered FL is to train one model for every distribution. In traditional clustered FL, a client can only be associated with one data distribution. We thus call this method hard clustered federated learning. Under the hard association assumption, the non-IID problem can be easily resolved: simply group clients with the same data distribution into one cluster, then conduct conventional FL on each cluster, within which the data distribution is now IID among clients. Unlike other personalization methods, clustered FL thus produces centrally available models that can be selectively migrated to new users that are unwilling, or unable, to engage in the subsequent local adaptation process (e.g., fine tuning) due to privacy concerns or resource limitations. This convenience in model adoption is particularly valuable for the current training-testing-deployment lifecycle of FL where deployment, rather than the training itself, is the end goal advances.
However, hard clustered FL faces two fundamental problems in practice. First, it may be unlikely that multiple clients possess identical data distributions. In fact, real-world user data is more likely to follow a mixture of multiple distributions fmtl_mixture. For example, if each client is a mobile phone and we wish to model its user’s content preferences, we might expect the clients to be clustered into adults and children. However, adult users may occasionally view children’s content, and devices owned by teenagers (or shared by parents and children) may possess large fractions of data from both distributions. Similarly, content can be naturally grouped by users’ interests (e.g., genres of movies), each of which may have a distinct distribution. Data from users with multiple interests then reflects a mixture of these distributions. Since the mixture ratios can vary for different clients, they may have different overall distributions even though the source distributions are identical. Clustering algorithms like the Gaussian mixture model gaussian_mixture use a similar rationale. Clients may then require models personalized to their distributions to make accurate predictions on their data, in addition to the cluster models used for new users.
Hard clustered FL’s second challenge is that it cannot effectively exploit similarities between different clusters. Though FL clients may have non-IID data distributions, two different distributions may still exhibit some similarity, as commonly assumed in personalization works fmtl. For example, young people may have more online slang terms in their chatting data, but all users (generally) follow the same basic grammar rules. Thus, the knowledge distilled through the training on one distribution could be transferred to accelerate the training of others. However, in most hard clustered FL algorithms, different cluster models are trained independently, making it difficult to leverage the potential structural likeness among distributions. Note that unlike in other personalization methods where the discussion of similarity is restricted to similarities between individual clients, here we focus on the broader similarities between source cluster distributions. Thus, we can gain better insight into the general data relationship rather than just the relationships between participating clients.
To overcome clustered FL’s first challenge of the hard association assumption, in this paper, we utilize soft clustered federated learning. In soft clustered FL, we suppose that the data of each client follows a mixture of multiple distributions. However, training cluster models using clients with mixed data raises two new challenges. First, the workload of clients can explode. When all the data of a client comes from the same distribution, as in hard clustered FL, it ideally only needs to contribute towards one training task: training that distribution’s model. However, in soft clustered FL, a client has multiple data sources. A natural extension of hard clustered FL is for the client to help train all cluster models whose distributions are included in its mixture fmtl_mixture. However, the workload of participating clients then grows linearly with the number of clusters, which can be large (though typically much smaller than the number of clients) for some applications. This multiplication of client workload can make soft clustered FL infeasible, considering the resource restrictions on typical FL user devices and the long convergence time of many FL models fedavg. Second, the training of cluster models and the local personalization are distinct. In hard clustered FL, client models are the same as cluster models since a client is uniquely bound to one cluster. In soft clustered FL, local distributions differ from individual cluster distributions, and thus training cluster models does not directly help the local personalization. Complicating things further, these local distributions and their exact relationships to the cluster models are unknown a priori. Combining the training of cluster and personalized models is then challenging.
To solve these two challenges, and to handle the second disadvantage of hard clustered FL discussed above, we utilize the proximal local updating trick, which was originally developed in FedProx fedprox to allow clients to use different local solvers in FL. During proximal local updating, instead of fitting the local model to the local dataset alone, each client optimizes a proximal local objective function that both carries local information and encodes knowledge from all cluster models. We name the proposed algorithm FedSoft.
In FedSoft, since the fingerprints of all clusters are integrated into one optimization objective, clients only need to solve one single optimization problem, for which the workload is almost the same as in conventional FL. In addition, by combining local data with cluster models in the local objective, clients can perform local personalization on the fly. Eventually, the server obtains collaboratively trained cluster models that can be readily applied to new users, and each participating client gets one personalized model as a byproduct. Proximal local updating allows a cluster to utilize the knowledge of similar distributions, overcoming the second disadvantage of the hard clustered FL. Intuitively, with all clusters present in the proximal objective, a client can take as reference training targets any cluster models whose distributions take up non-trivial fractions of its data. These component distributions, co-existing in the same dataset, are similar by nature. Thus, a personalized local model integrating all its component distributions can in turn be utilized by the component clusters to exploit their similarities.
Our contributions are: We design the FedSoft algorithm for efficient soft clustered FL. We establish a convergence guarantee that relates the algorithm’s performance to the divergence of different distributions, and validate the effectiveness of the learned cluster and personalized models in experiments under various mixture patterns. Our results show the training of cluster models converges linearly to a remaining error determined by the cluster heterogeneity, and that FedSoft can outperform existing FL implementations in both global cluster models for future users and personalized local models for participating clients.
2 Related Works
The training objective of hard clustered FL is to simultaneously identify the cluster partitions and train a model for each cluster. Existing works generally adopt an Expectation-Maximization (EM) like algorithm, which iteratively alternates between the cluster identification and model training. Based on how the partition structure is discovered, these algorithms can be classified into four types:
The first type leverages the distance between model parameters, e.g., multi-center propose to determine client association based on the distances between client models and server models. Similarly, flhc suggest applying a distance-based hierarchical clustering algorithm directly to client models. The second type determines the partition structure based on gradient information, e.g., the CFL cfl algorithm splits clients into bi-partitions based on the cosine similarity of the client gradients, and then checks whether a partition is congruent (i.e., contains IID data) by examining the norm of gradients on its clients. Likewise, the FedGroup fedgroup algorithm quantifies the similarity among client gradients with the so-called Euclidean distance of decomposed cosine similarity metric, which decomposes the gradient into multiple directions using singular value decomposition. The third type utilizes the training loss, e.g., in HyperCluster three_approaches, each client is greedily assigned to the cluster whose model yields the lowest loss on its local data. A generalization guarantee is provided for this algorithm. hard_cluster_ucb propose a similar algorithm named IFCA, for which a convergence bound is established under the assumption of good initialization and all clients having the same amount of data. The fourth type uses exogenous information about the data, e.g., patient_cluster and covid group patients into clusters based on their electronic medical records and imaging modality, respectively. This information usually entails direct access to the user data and thus cannot be applied in the general case.
Recently, fmtl_mixture propose a multi-task learning framework similar to soft clustered FL that allows client data to follow a mixture of distributions. Their proposed FedEM algorithm adopts an EM algorithm and estimates the mixture coefficients based on the training loss. However, FedEM requires every client to perform a local update for each cluster in each round, which entails significantly more training time than conventional FedAvg. Their analysis moreover assumes a special form of the loss function with all distributions having the same marginal distribution, which is unrealistic. In contrast, FedSoft requires only a subset of clients to return gradients for only one optimization task in each round. Moreover, we show its convergence for generic data distributions and loss functions.
The proximal local updating procedure that we adopt incorporates a regularization term in the local objective, which is also used for model personalization outside clustered settings. Typical algorithms include FedAMP fedamp, which adds an attention-inducing function to the local objective, and pFedMe pfedme, which formulates the regularization as Moreau envelopes.
3 Formulation and Algorithm
Mixture of distributions. Assume that each data point at each client is drawn from one of the $S$ distinct data distributions $P_1, \dots, P_S$. Similar to general clustering problems, we take $S$ as a hyperparameter determined a priori. Data points from all clients that follow the same distribution form a cluster. In soft clustered FL, a client may possess data from multiple clusters. Given a loss function $\ell(w; \xi)$, the (real) cluster risk $F_s(w)$ is the expected loss for data following $P_s$:
$$F_s(w) = \mathbb{E}_{\xi \sim P_s}\left[\ell(w; \xi)\right]. \tag{1}$$
We then wish to find $S$ cluster models $c^*_1, \dots, c^*_S$ such that all cluster objectives are minimized simultaneously. These cluster models will be co-trained by all clients through coordination at the central server:
$$c^*_s = \arg\min_{w} F_s(w), \quad s = 1, \dots, S. \tag{2}$$
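Suppose a client $k$ with local dataset $\mathcal{D}_k$ has $n_k = |\mathcal{D}_k|$ data points, among which $n_{ks}$ data points are sampled from distribution $P_s$. The real risk $f_k$ of client $k$ can thus be written as an average of the cluster risks:
$$f_k(w) = \sum_{s=1}^{S} \frac{n_{ks}}{n_k} F_s(w) = \sum_{s=1}^{S} u_{ks} F_s(w). \tag{3}$$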
Here we define $u_{ks} = n_{ks}/n_k$ as the importance weight of cluster $s$ to client $k$. In general, the $u_{ks}$ are unknown in advance, and the learning algorithm attempts to estimate their values during the learning iterations. It is worth noting that while we directly work on real risks, our formulation and analysis can be easily extended to empirical risks by introducing local-global divergences as in fedprox.
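As a small illustration of how a client's real risk decomposes into cluster risks, the sketch below computes one client's importance weights from per-cluster counts and evaluates its risk as the $u_{ks}$-weighted average of the cluster risks. The one-dimensional quadratic cluster risks $F_s(w) = \frac{1}{2}(w - c^*_s)^2$, the counts, and the function names are hypothetical choices for readability, not part of FedSoft.

```python
from typing import Callable, List

def importance_weights(counts: List[int]) -> List[float]:
    """u_ks = n_ks / n_k for one client k, given its per-cluster counts n_ks."""
    n_k = sum(counts)
    return [n_ks / n_k for n_ks in counts]

def client_risk(w: float, u: List[float],
                cluster_risks: List[Callable[[float], float]]) -> float:
    """f_k(w) = sum_s u_ks * F_s(w): a weighted average of cluster risks."""
    return sum(u_s * F_s(w) for u_s, F_s in zip(u, cluster_risks))

# Hypothetical 1-D cluster risks F_s(w) = 0.5 * (w - c*_s)^2 with c*_1 = 0 and c*_2 = 4.
optima = [0.0, 4.0]
cluster_risks = [lambda w, c=c: 0.5 * (w - c) ** 2 for c in optima]

u_k = importance_weights([30, 120])        # 30 points from P_1, 120 points from P_2
risk_at_zero = client_risk(0.0, u_k, cluster_risks)
print(u_k)            # [0.2, 0.8]
print(risk_at_zero)   # 0.2 * 0 + 0.8 * 8 = 6.4
```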
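Proximal local updating. In each round $t$, every client $k$ minimizes the proximal objective
$$h_k(w_k; c^t, u^t_k) = f_k(w_k) + \frac{\lambda}{2} \sum_{s=1}^{S} u^t_{ks} \left\| w_k - c^t_s \right\|^2. \tag{4}$$
Here $\lambda > 0$ is a hyperparameter and $u^t_{ks}$ denotes the estimate of $u_{ks}$ at time $t$. In the local updating step, every client searches for the local model $w^t_k$ that minimizes $h_k$ given the current global estimates of the cluster models $c^t = (c^t_1, \dots, c^t_S)$. As in fedprox, clients may use any local solver to optimize $h_k$. This design of the proximal objective requires the cluster models $c^t_s$ to be shared among all clients through the server, as is usual in clustered FL hard_cluster_ucb. We thus alternatively call the $c^t_s$ centers.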
The regularization term in the proximal objective serves as a reference point for local model training. It allows clients to work on their own datasets while taking advantage of, and being guided by, the globally shared knowledge of the centers. The regularization is weighted by the importance weights $u^t_{ks}$, so that a client pays more attention to distributions that have higher shares in its data. To see why this compounded regularization helps identify individual centers, assume we have a perfect guess $u^t_{ks} = u_{ks}$. The minimization of (4) can then be decoupled into a series of sub-optimization problems $h_{ks}(w_k; c^t_s) = F_s(w_k) + \frac{\lambda}{2}\|w_k - c^t_s\|^2$, since $h_k = \sum_s u_{ks} h_{ks}$. Thus, after $h_k$ is minimized, the sub-problems corresponding to large $u_{ks}$ will also be approximately solved. We can hence use the output local model $w^t_k$ to update the centers with large $u_{ks}$. Moreover, $w^t_k$ trained in this manner blends all of its component distributions $P_s$, which may share some common knowledge. Thus, the training of these clusters is coupled through the training of their common clients, exploiting similarities between the clusters.
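The decoupling argument can be checked numerically. The sketch below assumes exact importance weights and hypothetical one-dimensional quadratic cluster risks, and compares the proximal objective $f_k(w) + \frac{\lambda}{2}\sum_s u_{ks}(w - c_s)^2$ with the weighted sum of sub-problems $\sum_s u_{ks}\big(F_s(w) + \frac{\lambda}{2}(w - c_s)^2\big)$ at random points. All numbers and function names are illustrative.

```python
import random

def F(w: float, opt: float) -> float:
    """Hypothetical quadratic cluster risk F_s(w) = 0.5 * (w - c*_s)^2."""
    return 0.5 * (w - opt) ** 2

def proximal_objective(w, u, centers, optima, lam):
    """h_k(w) = f_k(w) + (lam / 2) * sum_s u_ks * (w - c_s)^2, with exact weights u."""
    f_k = sum(u_s * F(w, o) for u_s, o in zip(u, optima))
    reg = 0.5 * lam * sum(u_s * (w - c) ** 2 for u_s, c in zip(u, centers))
    return f_k + reg

def sub_problem(w, center, opt, lam):
    """h_ks(w) = F_s(w) + (lam / 2) * (w - c_s)^2."""
    return F(w, opt) + 0.5 * lam * (w - center) ** 2

u = [0.3, 0.7]
optima = [0.0, 4.0]     # hypothetical c*_s
centers = [0.5, 3.0]    # hypothetical current centers c_s
lam = 0.8

rng = random.Random(0)
max_gap = 0.0
for _ in range(100):
    w = rng.uniform(-5.0, 10.0)
    lhs = proximal_objective(w, u, centers, optima, lam)
    rhs = sum(u_s * sub_problem(w, c, o, lam) for u_s, c, o in zip(u, centers, optima))
    max_gap = max(max_gap, abs(lhs - rhs))
print(max_gap)  # zero up to floating-point rounding
```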
The output model $w^t_k$ is itself a well-personalized model that leverages both local client knowledge and global cluster information. fmtl_mixture show that, under certain conditions, the optimal client model for soft clustered FL is a mixture of the optimal cluster models. The same implication can also be captured by our proximal updating formulation. When $u^t_{ks} = u_{ks}$ (so that $\sum_s u^t_{ks} = 1$), the gradient is
$$\nabla h_k(w_k) = \nabla f_k(w_k) + \lambda \Big( w_k - \sum_{s=1}^{S} u_{ks} c^t_s \Big), \tag{5}$$
which implies that $w_k$ should be centered on $\sum_s u_{ks} c^t_s$. As a result, through the optimization of all $h_k$, not only will the server obtain the trained cluster models, but each client will also obtain a sufficiently personalized local model.
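To see the centering effect concretely, the sketch below minimizes the proximal objective with plain gradient descent (one possible local solver) under hypothetical quadratic cluster risks $F_s(w) = \frac{1}{2}(w - c^*_s)^2$ and exact weights summing to one. In this special case the minimizer has the closed form $w_k = \big(\sum_s u_{ks} c^*_s + \lambda \sum_s u_{ks} c_s\big)/(1 + \lambda)$, so increasing $\lambda$ pulls $w_k$ toward $\sum_s u_{ks} c_s$. The values and helper names are illustrative assumptions.

```python
def grad_h(w, u, centers, optima, lam):
    """Gradient of h_k for quadratic F_s(w) = 0.5 * (w - c*_s)^2 and exact weights u."""
    grad_f = sum(u_s * (w - o) for u_s, o in zip(u, optima))
    return grad_f + lam * sum(u_s * (w - c) for u_s, c in zip(u, centers))

def local_solver(u, centers, optima, lam, w0=0.0, lr=0.1, steps=500):
    """Plain gradient descent on h_k, standing in for an arbitrary local solver."""
    w = w0
    for _ in range(steps):
        w -= lr * grad_h(w, u, centers, optima, lam)
    return w

u = [0.3, 0.7]
optima = [0.0, 4.0]     # hypothetical c*_s
centers = [0.5, 3.0]    # hypothetical current centers c_s
mix_center = sum(u_s * c for u_s, c in zip(u, centers))   # sum_s u_ks c_s  = 2.25
mix_optimum = sum(u_s * o for u_s, o in zip(u, optima))   # sum_s u_ks c*_s = 2.8

results = {}
for lam in (0.1, 1.0, 100.0):
    # The curvature of h_k is 1 + lam here, so this step size contracts the error by 1/2 per step.
    w_k = local_solver(u, centers, optima, lam, lr=0.5 / (1.0 + lam))
    closed_form = (mix_optimum + lam * mix_center) / (1.0 + lam)
    results[lam] = (w_k, closed_form)
```

With $\lambda = 100$ the solution lands within about 0.01 of $\sum_s u_{ks} c_s$, while small $\lambda$ keeps it near the client's own optimum.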
Algorithm design. We formally present FedSoft in Algorithm 1. The first step of the algorithm is to estimate the importance weights $u^t_{ks}$ for each client $k$ (lines 3-14). The algorithm obtains them by finding, for every data point belonging to the client, the center that yields the smallest loss value, and counting the number of points $n^t_{ks}$ matched to every cluster $s$, so that $u^t_{ks} = n^t_{ks}/n_k$. If a client has no samples matched to cluster $s$ ($n^t_{ks} = 0$), the algorithm sets $u^t_{ks} = \sigma$, where $0 < \sigma \ll 1$ is a pre-defined smoothing parameter.
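The estimation step can be sketched as follows: each local point is matched to the center with the lowest loss, points are counted per cluster, and clusters that receive no points get the smoother $\sigma$. The scalar "centers", the squared loss, the sample values, and the function names are hypothetical stand-ins for real models and data.

```python
from typing import Callable, List, Sequence

def estimate_importance_weights(samples: Sequence[float], centers: List[float],
                                loss: Callable[[float, float], float],
                                sigma: float = 1e-4) -> List[float]:
    """Match each local point to its lowest-loss center, count n_ks, and return u_ks.

    Clusters that receive no points get the smoothing value sigma instead of 0.
    """
    counts = [0] * len(centers)
    for x in samples:
        losses = [loss(c, x) for c in centers]
        counts[losses.index(min(losses))] += 1   # ties go to the lowest cluster index
    n_k = len(samples)
    return [cnt / n_k if cnt > 0 else sigma for cnt in counts]

def squared_loss(center: float, x: float) -> float:
    """Hypothetical per-sample loss of a scalar 'model' on a scalar data point."""
    return (x - center) ** 2

samples = [0.1, -0.2, 3.9, 4.2, 4.0]
u_two = estimate_importance_weights(samples, [0.0, 4.0], squared_loss)
u_three = estimate_importance_weights(samples, [0.0, 4.0, 100.0], squared_loss)
print(u_two)    # [0.4, 0.6]
print(u_three)  # [0.4, 0.6, 0.0001]
```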
Once the server receives the importance weights $u^t_{ks}$, it computes the aggregation weights $v^t_{sk}$ as follows (line 15):
i.e., a client that has a higher importance weight on cluster $s$ will be given a higher aggregation weight, and vice versa. The smoother $\sigma$ avoids the situation where $\sum_k u^t_{ks} = 0$ for some cluster $s$, which could happen at the very beginning of training when a center does not yet exhibit strength on any distribution. In that case, $u^t_{ks} = \sigma$ for every client $k$, i.e., the cluster will be updated in a manner that treats all clients equally. Otherwise, since $\sigma$ is very small, a client with $n^t_{ks} = 0$ will be assigned $v^t_{sk} \approx 0$, and the aggregation weights of the other clients will be essentially unaffected.
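The effect of the smoother on aggregation can be illustrated with a short sketch. It assumes the aggregation weights are obtained by normalizing the smoothed importance weights over clients, $v_{sk} = u_{ks} / \sum_{k'} u_{k's}$; this normalized form is an assumption for illustration, and the paper's equation (6) may include additional factors such as the client data size $n_k$. The function name and numbers are hypothetical.

```python
from typing import List

def aggregation_weights(u: List[List[float]], s: int) -> List[float]:
    """v_sk for cluster s, assuming v_sk = u_ks / sum_k' u_k's (normalization over clients)."""
    total = sum(u_k[s] for u_k in u)
    return [u_k[s] / total for u_k in u]

sigma = 1e-4
# Early in training: no client has points matched to cluster 0, so every u_k0 = sigma.
u_early = [[sigma, 1.0] for _ in range(4)]
v_early = aggregation_weights(u_early, s=0)    # every client gets weight 1/4

# Later: only client 2 has no points matched to cluster 0.
u_late = [[0.9, 0.1], [0.7, 0.3], [sigma, 1.0], [0.2, 0.8]]
v_late = aggregation_weights(u_late, s=0)      # client 2 gets a weight close to 0
```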
Although calculating and reporting $u^t_{ks}$ is computationally trivial compared to the actual training procedure, sending all $S$ centers to every client may introduce large communication costs. FedSoft thus allows the estimates $u^t_{ks}$ to be reused for up to $\tau$ iterations (line 3). In practice, a client can start evaluating its data on one center before it receives all the other centers, so the transmission delay is tolerable.
Next, relevant clients run proximal local updates to find the minimizer $w^t_k$ of the proximal objective $h_k$, which entails solving only one optimization problem (line 17). When all clients participate, the cluster models are produced by aggregating all client models: $c^{t+1}_s = \sum_k v^t_{sk} w^t_k$. However, requiring full client participation is impractical in the federated setting. We thus use the client selection trick of fedavg to reduce the training cost (line 16). For each cluster $s$, the algorithm randomly selects a small subset of clients $\mathcal{S}^t_s$ to participate in the local updating at time $t$, where $|\mathcal{S}^t_s| = K$.
Clustered FL generally requires more clients to be selected than conventional FL, to ensure the convergence of all cluster models. Since FedSoft clients can contribute to multiple centers, however, we select only $|\bigcup_s \mathcal{S}^t_s|$ clients instead of the $SK$ clients of the usual clustered FL. For example, if each distribution has the same share in every client, then in expectation only $N\left(1 - (1 - K/N)^S\right)$ clients will be selected, where $N$ is the number of clients. This number equals $2K - K^2/N$ when $S = 2$, i.e., $K^2/N$ clients are selected by both clusters.
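The expected number of distinct selected clients can be checked by simulation. The sketch assumes that, when every client has the same share of each distribution, each cluster independently draws $K$ distinct clients uniformly at random. Under that assumption a given client is missed by all $S$ clusters with probability $(1 - K/N)^S$, which gives $N\left(1 - (1 - K/N)^S\right)$ selected clients in expectation. The function names and parameter values are illustrative.

```python
import random

def expected_selected(N: int, K: int, S: int) -> float:
    """Expected size of the union of S independent uniform K-subsets of N clients."""
    return N * (1.0 - (1.0 - K / N) ** S)

def simulate_selected(N: int, K: int, S: int, trials: int = 20000, seed: int = 0) -> float:
    """Monte Carlo estimate: each cluster picks K distinct clients uniformly at random."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        chosen = set()
        for _ in range(S):
            chosen.update(rng.sample(range(N), K))
        total += len(chosen)
    return total / trials

N, K = 100, 10
formula_s2 = expected_selected(N, K, 2)   # 2K - K^2/N = 19, versus S*K = 20
simulated_s2 = simulate_selected(N, K, 2)
formula_s4 = expected_selected(N, K, 4)   # 100 * (1 - 0.9^4), about 34.4 versus S*K = 40
```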
Once the server receives the local models $w^t_k$ of the selected clients, it produces the next centers $c^{t+1}_s$ by simply averaging them (line 18). After completion, the algorithm yields $S$ trained cluster models as outputs, and each client obtains a personalized local model as a byproduct.
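Putting the steps together, the sketch below runs a toy version of the loop: re-estimate the importance weights every $\tau$ rounds, compute aggregation weights, sample $K$ clients per cluster, solve each selected client's proximal problem, and average the results. Everything about the setting is a hypothetical assumption: scalar models with squared loss, two one-dimensional Gaussian clusters, 20 clients with 90:10 or 10:90 mixtures, a closed-form local solver, normalized aggregation weights $u_{ks}/\sum_{k'} u_{k's}$, and sampling with replacement proportional to those weights (so the simple average is an unbiased estimate of the weighted aggregate). Because a single scalar cannot fit a two-component mixture, the centers separate and each tracks one cluster but stay biased toward each other; the toy only illustrates the data flow, not FedSoft's accuracy.

```python
import random
from typing import List, Tuple

def fedsoft_toy(T: int = 30, tau: int = 2, K: int = 10, lam: float = 1.0,
                sigma: float = 1e-4, seed: int = 0
                ) -> Tuple[List[float], List[List[float]], List[float]]:
    """Toy FedSoft-style loop with scalar models and squared loss (all settings hypothetical)."""
    rng = random.Random(seed)
    means = [0.0, 5.0]                        # hypothetical 1-D cluster distributions
    ratios = [0.9] * 10 + [0.1] * 10          # share of cluster 1 at each of N = 20 clients
    data = []
    for r in ratios:
        n_k = rng.randint(100, 200)
        n_1 = round(r * n_k)
        data.append([rng.gauss(means[0], 0.5) for _ in range(n_1)]
                    + [rng.gauss(means[1], 0.5) for _ in range(n_k - n_1)])
    N, S = len(data), len(means)
    centers = [1.0, 3.0]                      # hypothetical initialization
    u = [[1.0 / S] * S for _ in range(N)]
    for t in range(T):
        if t % tau == 0:                      # re-estimate importance weights every tau rounds
            for k, xs in enumerate(data):
                counts = [0] * S
                for x in xs:
                    losses = [(x - c) ** 2 for c in centers]
                    counts[losses.index(min(losses))] += 1
                u[k] = [cnt / len(xs) if cnt > 0 else sigma for cnt in counts]
        new_centers = []
        for s in range(S):
            total = sum(u_k[s] for u_k in u)
            v_s = [u_k[s] / total for u_k in u]                  # assumed normalized weights
            selected = rng.choices(range(N), weights=v_s, k=K)   # K draws with replacement
            local_models = []
            for k in selected:
                xs = data[k]
                mean_k = sum(xs) / len(xs)
                # Exact minimizer of h_k(w) = mean_i 0.5 (x_i - w)^2 + (lam/2) sum_s u_ks (w - c_s)^2
                pull = sum(u_ks * c for u_ks, c in zip(u[k], centers))
                local_models.append((mean_k + lam * pull) / (1.0 + lam * sum(u[k])))
            new_centers.append(sum(local_models) / K)            # simple average over the draws
        centers = new_centers
    return centers, u, ratios

centers, u_hat, ratios = fedsoft_toy()
```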
4 Convergence Analysis
In this section, we provide a convergence guarantee for FedSoft. First, note that we can rewrite (4) as follows:
Here $h_{ks}$ is only defined for $u_{ks} \neq 0$, and we call optimizing each $h_{ks}$ a sub-problem for client $k$.
Our analysis relies on the following assumptions:
($\gamma$-inexact solution) Each client produces a $\gamma$-inexact solution $w^t_k$ for the local minimization of (4):
(-similarity of sub-problems) The sub-problems of each client have similar optimal points:
for some , where .
(Strong convexity and smoothness) Cluster risks are strongly convex and smooth.
(Bounded initial error) At a certain time of the training, all centers have bounded distance from their optimal points. We begin our analysis at that point:
Assumption 1 assumes that significant progress is made on the proximal minimization of $h_k$, which is a natural extension of the assumptions in FedProx fedprox. Assumption 2 ensures the effectiveness of the joint optimization of $h_k$, i.e., solving one sub-problem can help identify the optimal points of the others. Intuitively, if the sub-problems are highly divergent, we would not expect that solving them together would yield a universally good solution. This assumption quantifies our previous reasoning that different distributions co-existing in one local dataset have some similarities, which is the prerequisite for local models to converge and for cluster models to be able to learn from each other. Assumption 3 is standard hard_cluster_ucb, and Assumption 4 is introduced by hard_cluster_ucb in order to bound the estimation error of $u^t_{ks}$ (Lemma 1). Note that with Assumption 3, each sub-problem $h_{ks}$ is also $\mu_\lambda$-strongly convex and $L_\lambda$-smooth, where the subscript indicates that these constants increase with $\lambda$.
To measure the distance between different clusters, we quantify
$$\Delta_{\min} = \min_{s \neq s'} \|c^*_s - c^*_{s'}\|, \qquad \Delta_{\max} = \max_{s \neq s'} \|c^*_s - c^*_{s'}\|.$$
As we will see later, soft clustered FL performs best when $\Delta_{\min}$ and $\Delta_{\max}$ are close. Intuitively, a very small $\Delta_{\min}$ indicates that two clusters are almost identical, and thus might be better combined into one distribution. On the other hand, a very large $\Delta_{\max}$ implies that two clusters are too divergent, making it hard for one model to acquire useful knowledge from the other.
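A direct way to compute the minimum and maximum cluster distances $\Delta_{\min}$ and $\Delta_{\max}$ is sketched below, assuming the cluster distance is the Euclidean distance between optimal cluster models. The two sets of hypothetical optima contrast a homogeneous configuration (ratio $\Delta_{\max}/\Delta_{\min} = 1$) with a skewed one (ratio 50) in which two clusters are nearly identical.

```python
import itertools
import math
from typing import Sequence, Tuple

def center_divergence(optima: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Return (min, max) pairwise Euclidean distance between cluster optima c*_s."""
    dists = [math.dist(a, b) for a, b in itertools.combinations(optima, 2)]
    return min(dists), max(dists)

homogeneous = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]   # equilateral triangle
skewed = [(0.0, 0.0), (0.1, 0.0), (5.0, 0.0)]                      # two nearly identical clusters
d_min_h, d_max_h = center_divergence(homogeneous)
d_min_s, d_max_s = center_divergence(skewed)
```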
Next, we bound the estimate $u^t_{ks}$ with respect to the true $u_{ks}$, for which we rely on the following lemma hard_cluster_ucb:
Based on Lemma 1, we can bound as follows
(Bounded estimation errors) The expectation of is bounded as
Here , and the expectation is taken over the randomness of samples.
Next, we seek to characterize each sub-problem at the -inexact solution that approximately minimizes . Intuitively, should perform better for sub-problems with larger . On the other hand, if , we generally cannot expect that will be close to . We summarize this intuition in Theorems 2 and 3.
Theorem 2 indicates that if the is solved with high quality (small ), and the sub-problems are sufficiently similar (small ), then sub-problems with can also be well solved by . It also justifies using as aggregation weights in (6). In the case , according to Theorem 3 (which holds for any ), approaching with will introduce an error of at most .
Finally, we show the convergence of . The following analysis does not depend on ; we will show how affects the convergence in the appendix.
Suppose . After iterations,
The choices of to make is discussed in fedprox. From Corollary 1, the gradient norm converges to a remaining error controlled by . Intuitively, when , further updating with misclassified models will inevitably move away from . This bias cannot be removed unless we have a perfect guess of . Recall that , and thus the remaining term is , which decreases as approaches . Thus, FedSoft performs better when the divergences between clusters are more homogeneous. Note that Corollary 1 seems to imply the remaining error will explode if , but Lemma 1 is only valid when . Thus when is very small, i.e., there exist two distributions that are extremely similar, the remaining error is determined by the maximum divergence of the other distributions. Furthermore, the divergence determines the degree of non-IID of a local dataset (not among clients), which also implicitly affects the accuracy of local solutions . Intuitively, a larger implies it is more difficult to exactly solve a local problem involving multiple distributions, resulting in a greater .
To see the role of cluster heterogeneity, suppose the center of cluster 2 is closer to the center of cluster 1 than all other centers are. Then the misclassified samples of cluster 1 are more likely to be matched to cluster 2. Thus, cluster 2 gets more updates from data that it does not own, producing a greater remaining training error that drags its center towards cluster 1. On the other hand, if the cluster divergence is homogeneous, the effect of misclassification is amortized among all clusters, resulting in a universally smaller remaining error.
Theorem 4 shows the convergence of the cluster models in terms of the cluster risks $F_s$. For the local models $w^t_k$, we focus on how clients integrate global knowledge into their local personalization, which cannot be captured by the original client risk functions $f_k$ alone. Thus, we are interested in the convergence of $w^t_k$ with respect to the proximal objective $h_k$. Note that under Assumption 3, FedSoft is effectively a cyclic block coordinate descent algorithm on a jointly convex objective function of the local models $\{w_k\}$ and the centers $\{c_s\}$, for which convergence is guaranteed:
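The block coordinate view can be illustrated on a toy instance. The sketch assumes hypothetical quadratic client risks $f_k(w) = \frac{1}{2}(w - m_k)^2$, fixed importance weights, full participation, and a joint objective $J(w, c) = \sum_k h_k(w_k; c)$. It alternates exact minimization over all local models and over all centers; for this $J$ the center step is the $u_{ks}$-weighted average of the local models. Each block step cannot increase $J$, so the recorded objective is non-increasing.

```python
from typing import List, Sequence

def joint_objective(w: Sequence[float], c: Sequence[float],
                    u: Sequence[Sequence[float]], local_means: Sequence[float],
                    lam: float) -> float:
    """J(w, c) = sum_k [0.5 (w_k - m_k)^2 + (lam/2) sum_s u_ks (w_k - c_s)^2]."""
    total = 0.0
    for w_k, u_k, m_k in zip(w, u, local_means):
        total += 0.5 * (w_k - m_k) ** 2
        total += 0.5 * lam * sum(u_ks * (w_k - c_s) ** 2 for u_ks, c_s in zip(u_k, c))
    return total

def block_coordinate_descent(u, local_means, lam, c0, rounds=50):
    """Alternate exact minimization over all local models w and over all centers c."""
    c = list(c0)
    history: List[float] = []
    for _ in range(rounds):
        # Block 1: each client minimizes h_k(w_k; c) in closed form.
        w = [(m_k + lam * sum(a * b for a, b in zip(u_k, c))) / (1.0 + lam * sum(u_k))
             for u_k, m_k in zip(u, local_means)]
        # Block 2: each center minimizes sum_k u_ks (w_k - c_s)^2, i.e. a u-weighted average.
        c = [sum(u_k[s] * w_k for u_k, w_k in zip(u, w)) / sum(u_k[s] for u_k in u)
             for s in range(len(c))]
        history.append(joint_objective(w, c, u, local_means, lam))
    return w, c, history

u = [[0.9, 0.1], [0.7, 0.3], [0.2, 0.8], [0.1, 0.9]]
local_means = [0.5, 1.5, 4.0, 4.5]            # hypothetical client optima m_k
w, c, history = block_coordinate_descent(u, local_means, lam=1.0, c0=[0.0, 0.0])
```

In this instance, with $\lambda = 1$ and weights summing to one, the center update is a linear map whose contraction factor is one half per round, so 50 rounds are far more than enough for the objective to settle.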
(Joint convergence of cluster and client models) For fixed importance weights , let , and be the outputs of FedSoft. Then linearly with .
In this section, we verify the effectiveness of FedSoft with two base datasets under various mixture patterns. For all experiments, we use $N = 100$ clients, and the number of samples in each client is chosen uniformly at random from 100 to 200. For ease of demonstration, for every base dataset we first investigate the mixture of $S = 2$ distributions and then increase $S$. In the case with two distributions, suppose the cluster distributions are named $P_1$ and $P_2$. We evaluate the following partition patterns:
10:90 partition: 50 clients have a mixture of 10% $P_1$ and 90% $P_2$, and the other 50 have 10% $P_2$ and 90% $P_1$.
30:70 partition: Same as above except the ratio is 30:70.
Linear partition: Client $k$ has $(0.5 + k)$% data from $P_1$ and $(99.5 - k)$% data from $P_2$, $k = 0, \dots, 99$.
We further introduce the random partition, where each client has a random mixture vector generated by dividing the $[0, 1]$ range into $S$ segments with $S - 1$ points drawn from $\mathrm{Uniform}(0, 1)$. We use all four partitions for $S = 2$, and for simplicity only use the random partition when $S > 2$. Each partition produces non-IID local distributions, i.e., clients have different local data distributions. Specifically, the 10:90 and 30:70 partitions yield 2 local distributions, while the linear and random partitions yield 100. Unless otherwise noted, we choose FedSoft’s estimation interval $\tau$, client selection size $K$, and counter smoother $\sigma = 10^{-4}$, and all experiments are run until both cluster and client models have fully converged. All models are randomly initialized with the Xavier normal xavier initializer without pre-training, so that the association among clients, centers, and cluster distributions is built automatically during the training process.
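The mixture patterns can be generated with a few helper functions. This is a sketch under stated assumptions: 100 clients, the linear partition giving client $k$ a $(0.5 + k)$% share of $P_1$, random cut points drawn uniformly from $[0, 1]$, and per-client sample sizes drawn uniformly from 100 to 200. The function names are hypothetical and not taken from the authors' code.

```python
import random
from typing import List

def two_cluster_partition(minor: float, n_clients: int = 100) -> List[List[float]]:
    """10:90 (minor=0.1) or 30:70 (minor=0.3): two mirrored halves of clients."""
    half = n_clients // 2
    return ([[minor, 1.0 - minor] for _ in range(half)]
            + [[1.0 - minor, minor] for _ in range(n_clients - half)])

def linear_partition(n_clients: int = 100) -> List[List[float]]:
    """Client k holds (0.5 + k)% of P_1 and (99.5 - k)% of P_2 (assumes 100 clients)."""
    return [[(0.5 + k) / 100.0, (99.5 - k) / 100.0] for k in range(n_clients)]

def random_partition(S: int, n_clients: int = 100, seed: int = 0) -> List[List[float]]:
    """Cut [0, 1] at S - 1 uniform points; the segment lengths are the mixture weights."""
    rng = random.Random(seed)
    mixtures = []
    for _ in range(n_clients):
        edges = [0.0] + sorted(rng.random() for _ in range(S - 1)) + [1.0]
        mixtures.append([b - a for a, b in zip(edges, edges[1:])])
    return mixtures

def sample_sizes(n_clients: int = 100, seed: int = 0) -> List[int]:
    """Per-client sample counts drawn uniformly from {100, ..., 200}."""
    rng = random.Random(seed)
    return [rng.randint(100, 200) for _ in range(n_clients)]
```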
We compare FedSoft with two baselines: IFCA hard_cluster_ucb and FedEM fmtl_mixture. Both baseline algorithms produce one center for each cluster, but they do not explicitly generate local models as FedSoft does. Nevertheless, they also estimate the importance weights for each client, so we use the center corresponding to the largest importance weight as a client’s local model. Since we expect cluster models to be deployed to new users, we evaluate their test accuracy/error on holdout datasets sampled from the corresponding cluster distributions. Local models are expected to fit the local data of participating clients, so we evaluate their accuracy/error on the local training datasets. Throughout this section, the reported cluster and client metrics are the average accuracy/error of the individual cluster and client models, not the accuracy/error of the averaged models.
We use two base datasets to generate the various cluster distributions. A different model is adopted for each dataset. Due to the space limit, we put the details of other experiment parameters in the supplementary technical report.
Synthetic Data. We generate synthetic datasets according to where , , hard_cluster_ucb. Unless otherwise noted, we use
. We use the conventional linear regression model for this dataset.
EMNIST Letters. We use the handwritten images of English letters in the EMNIST dataset to create 2 distributions for the lower and uppercase letters oneshot, each with 26 classes. Then we rotate these images counterclockwise by mnist_rotation, resulting in 4 total distributions. In the setting we compare the two distributions. A rotation variant CNN model is used for this dataset.
In general, the letter distributions share more similarities with each other, while the synthetic distributions are more divergent, e.g., letters like “O” have very similar upper and lowercase shapes and are invariant to rotations. On the other hand, data generated from and can be easily distinguished. We thus expect the mixture of synthetic data to benefit more from the personalization ability of FedSoft.
The typical convergence process of FedSoft is shown in Figure 1. In this example of the synthetic data, FedSoft is able to automatically distinguish the two cluster distributions. After around 5 global epochs, center 1 starts to exhibit strength on the first cluster distribution, and center 0 concentrates on the other, which implies a correct association between centers and cluster distributions. Similarly, the importance weight estimations , which are initially around 50:50, soon converge to the real mixture ratio 10:90.
Table 1. Output cluster models: mean squared error on synthetic data and accuracy (%) on EMNIST letters.
Table 1 lists the mean squared error (MSE) or accuracy of the output cluster models. FedSoft produces high-quality centers under all mixture patterns. In particular, each center exhibits strength on one distribution, which indicates that FedSoft builds correct associations between the centers and cluster distributions. The performance gap between the two distributions under the same center is larger for the synthetic data. This is because the letter distributions have smaller divergence than the synthetic distributions. Thus, letter models can more easily transfer the knowledge of one distribution to another, and a center focusing on one distribution can perform well on the other. Notably, the 30:70 mixture has the worst performance for both datasets, which is due to the degradation of local solvers when neither distribution dominates. The local problems under this partition are thus solved less accurately, resulting in poor local models and a large value of $\gamma$ in Theorem 2, which then produces high training loss on cluster models according to Theorem 4.
Table 2 compares FedSoft with the baselines. Not only does FedSoft produce more accurate cluster and local models, but it also achieves a better balance between the two trained centers. Similarly, Figure 2 shows the importance estimation of clients for the first cluster. FedSoft and IFCA are able to build the correct association (though the latter is a hard partition), while FedEM appears to be biased towards the other center by putting less weight on the first one.
Next, we evaluate the algorithm with the random partition for the mixture of more distributions. Tables 3 and 4 show the MSE or accuracy of cluster models for the mixture of 8 and 4 distributions on synthetic and letters data, where we still observe high-quality outcomes and good association between centers and cluster distributions.
This paper proposes FedSoft, an efficient algorithm generalizing traditional clustered federated learning approaches to allow clients to sample data from a mixture of distributions. By incorporating proximal local updating, FedSoft enables simultaneous training of cluster models for future users, and personalized local models for participating clients, which is achieved without increasing the workload of clients. Theoretical analysis shows the convergence of FedSoft for both cluster and client models, and the algorithm exhibits good performance in experiments with various mixture patterns.
This research was partially supported by NSF CNS-1909306 and CNS-2106891.
Appendix A Proof of Theorems
Proof of Theorem 1
Let be some virtual group of client ’s data points that follow the same distribution. Thus,
Proof of Theorem 2
For simplification we drop the dependency of in and . Take , we have
Reorganizing, we have
Proof of Theorem 3
Let , we have , thus
Proof of Theorem 4
We first introduce the following lemma:
, where the expectation is taken over .
Let , and note that since each client estimates its independently. Thus,
We then formally prove Theorem 4:
For , define
using Theorem 2, we have
Next we bound for ,