A popular application of language models is virtual keyboards, where the goal is to predict the next word given the previous words (hard2018federated). For example, given “I live in the state of”, the model should ideally guess the state the user intended to type. If we train a single model on all the user data and deploy it, the model would predict the same state for all users and would not be a good model for most users. Similarly, in many practical applications, the distribution of data across clients is highly non-i.i.d., and training a single global model for all clients may not be optimal.
Thus, we study the problem of learning personalized models, where the goal is to train a model for each client, based on the client’s own dataset and the datasets of other clients. Such an approach would be useful in applications with the natural infrastructure to deploy a personalized model for each client, which is the case with large-scale learning scenarios such as federated learning (FL) (mcmahan2017communication).
Before we proceed further, we highlight one of our use cases in FL. In FL, a centralized global model is typically trained on data from a large number of clients, which may be mobile phones, other mobile devices, or sensors (mcmahan2017communication; konevcny2016federated; konecny2016federated2), using a variant of stochastic gradient descent called FedAvg. This global model benefits from having access to client data and can often perform better on several learning problems, including next word prediction (hard2018federated; yang2018applied) and predictive models in health (brisimi2018federated). We refer to Appendix A for more details on FL.
Personalization in FL was first studied by wang2019federated. They showed that federated models can be fine-tuned based on local data. They proposed methods to find the best hyper-parameters for fine-tuning and showed that fine-tuning improves the next word prediction of language models in virtual keyboard applications. Recently, jiang2019improving drew interesting connections between FedAvg and first-order model-agnostic meta-learning (finn2017model) and showed that FedAvg is already a meta-learning algorithm. Apart from these works, personalization in the context of FL has not been well studied theoretically.
We provide a learning-theoretic framework, generalization guarantees, and computationally efficient algorithms for personalization. Since FL is one of the main frameworks where personalized models can be used, we propose efficient algorithms that take into account computation and communication bottlenecks.
Before describing the mathematical details of personalization, we highlight two baseline models. The first is the global model trained on data from all the clients. This can be trained using either standard empirical risk minimization (vapnik1992principles) or other methods such as agnostic risk minimization (mohri2019agnostic). The second baseline is the purely local model trained only on the client’s data.
The global model is trained on large amounts of data and generalizes well; however, it does not perform well for clients whose data distributions are very different from the global training data distribution. On the other hand, the training data distributions of local models match the ones at inference time, but local models do not generalize well due to the scarcity of data.
Personalized models can be viewed as intermediate models between pure-local and global models. Thus, the hope is that they incorporate the generalization properties of the global model and the distribution matching property of the local model. Before we proceed further, we first introduce the notation used in the rest of the paper.
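We start with some general notation and definitions used throughout the paper. Let $\mathcal{X}$ denote the input space and $\mathcal{Y}$ the output space. We will primarily discuss a multi-class classification problem where $\mathcal{Y}$ is a finite set of classes, but many of our results extend straightforwardly to regression and other problems. The hypotheses we consider are of the form $h \colon \mathcal{X} \to \Delta_{\mathcal{Y}}$, where $\Delta_{\mathcal{Y}}$ stands for the simplex over $\mathcal{Y}$. Thus, $h(x)$ is a probability distribution over the classes that can be assigned to $x \in \mathcal{X}$. We will denote by $\mathcal{H}$ a family of such hypotheses $h$. We also denote by $\ell \colon \Delta_{\mathcal{Y}} \times \mathcal{Y} \to \mathbb{R}_+$ a loss function defined over $\Delta_{\mathcal{Y}} \times \mathcal{Y}$ and taking non-negative values. The loss of $h \in \mathcal{H}$ for a labeled sample $(x, y) \in \mathcal{X} \times \mathcal{Y}$ is given by $\ell(h(x), y)$. Without loss of generality, we assume that the loss is bounded by one. We will denote by $\mathcal{L}_{\mathcal{D}}(h)$ the expected loss of a hypothesis $h$ with respect to a distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$:
$$\mathcal{L}_{\mathcal{D}}(h) = \mathbb{E}_{(x, y) \sim \mathcal{D}}\big[\ell(h(x), y)\big],$$
and by $h_{\mathcal{D}}$ its minimizer: $h_{\mathcal{D}} = \operatorname{argmin}_{h \in \mathcal{H}} \mathcal{L}_{\mathcal{D}}(h)$. Let $\mathfrak{R}_{\mathcal{D}, m}(\mathcal{H})$ denote the Rademacher complexity of class $\mathcal{H}$ over the distribution $\mathcal{D}$ with $m$ samples.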
Let $p$ be the number of clients. The distribution of samples of client $k$ is denoted by $\mathcal{D}_k$. Clients do not know the true distribution, but instead have access to $m_k$ samples drawn i.i.d. from the distribution $\mathcal{D}_k$. We will denote by $\hat{\mathcal{D}}_k$ the corresponding empirical distribution of samples and by $m = \sum_{k=1}^{p} m_k$ the total number of samples.
2.2 Local model
We first ask when it is beneficial for a client to participate in global model training. Consider a canonical user with distribution $\mathcal{D}_1$. Suppose we train a purely local model based on the client’s data and obtain a model $h_{\hat{\mathcal{D}}_1}$. By standard learning-theoretic tools (MohriRostamizadehTalwalkar2012), the performance of this model can be bounded as follows: with probability at least $1 - \delta$, the minimizer of empirical risk satisfies
$$\mathcal{L}_{\mathcal{D}_1}(h_{\hat{\mathcal{D}}_1}) \le \mathcal{L}_{\mathcal{D}_1}(h_{\mathcal{D}_1}) + \mathcal{O}\left(\sqrt{\frac{d + \log(1/\delta)}{m_1}}\right), \tag{1}$$
where $d$ is the VC-dimension of the hypothesis class $\mathcal{H}$. From (1), it is clear that local models perform well when the number of samples $m_1$ is large. However, this is often not the case. In many realistic settings, such as virtual keyboard models, the average number of samples per user is on the order of hundreds, whereas the VC-dimension of the hypothesis class is on the order of millions (hard2018federated). In such cases, the above bound becomes vacuous.
2.3 Uniform global model
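The global model is trained by minimizing the empirical risk on the concatenation of all the samples. For $\lambda \in \Delta_p$ (the simplex over the $p$ clients), the weighted average distribution is given by $\mathcal{D}_\lambda = \sum_{k=1}^{p} \lambda_k \mathcal{D}_k$. Since the global model is trained on the concatenated samples from all the users, training it is equivalent to minimizing the loss on the distribution $\hat{\mathcal{D}}_{\bar m}$, where $\bar m = (m_1/m, \ldots, m_p/m)$; we refer to $\mathcal{U} = \mathcal{D}_{\bar m}$ as the uniform distribution, since it weights every sample equally. Because the global model is trained on data from all the clients, it may not match the actual underlying client distribution and thus may perform worse.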
The divergence between distributions is often measured by a Bregman divergence such as the KL-divergence or the unnormalized relative entropy. However, such divergences do not take the task at hand into account. To obtain better bounds, we use the notion of label-discrepancy between distributions (mansour2009domain). For two distributions $\mathcal{D}_1$ and $\mathcal{D}_2$ over features and labels and a class of hypotheses $\mathcal{H}$, the discrepancy is given by
$$\operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{D}_2) = \max_{h \in \mathcal{H}} \big|\mathcal{L}_{\mathcal{D}_1}(h) - \mathcal{L}_{\mathcal{D}_2}(h)\big|.$$
If the loss of every hypothesis in the class is the same under both $\mathcal{D}_1$ and $\mathcal{D}_2$, then the discrepancy is zero, and models trained on $\mathcal{D}_1$ generalize well on $\mathcal{D}_2$ and vice versa. Thus, discrepancy takes into account both the hypothesis set and the loss function, and hence the structure of the learning problem.
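To make the definition concrete, the following sketch estimates the discrepancy between two labeled samples for a finite list of hypotheses under the zero-one loss. The threshold classifiers, the grid, and all function names are illustrative assumptions of ours; for an infinite class $\mathcal{H}$, maximizing over a finite subset only gives a lower bound on the empirical discrepancy.

```python
import numpy as np


def zero_one_loss(h, x, y):
    """Empirical zero-one loss of classifier h on the labeled sample (x, y)."""
    return float(np.mean(h(x) != y))


def empirical_discrepancy(hypotheses, sample_a, sample_b):
    """max over the given hypotheses of |L_a(h) - L_b(h)| under the zero-one loss."""
    (xa, ya), (xb, yb) = sample_a, sample_b
    return max(abs(zero_one_loss(h, xa, ya) - zero_one_loss(h, xb, yb)) for h in hypotheses)


def threshold(t, s):
    """Hypothetical threshold classifier x -> 1[s * (x - t) > 0]."""
    return lambda x: (s * (x - t) > 0).astype(int)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    grid = [threshold(t, s) for t in np.linspace(0, 1, 21) for s in (-1, 1)]
    x = rng.uniform(0, 1, size=10_000)
    a = (x, (x > 0.5).astype(int))
    b = (x, (x > 0.6).astype(int))  # same features, slightly shifted labeling rule
    print(empirical_discrepancy(grid, a, a))  # a sample has zero discrepancy with itself
    print(empirical_discrepancy(grid, a, b))  # roughly the mass of (0.5, 0.6]
```

The two labeling rules disagree only on $(0.5, 0.6]$, so the discrepancy is the fraction of points in that interval: small, even though a Bregman divergence between the two conditional label distributions could be very large.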
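With the above definitions, it can be shown that the uniform global model generalizes as follows: with probability at least $1 - \delta$, the minimizer of empirical risk on the uniform distribution $\mathcal{U}$ satisfies
$$\mathcal{L}_{\mathcal{D}_1}(h_{\hat{\mathcal{U}}}) \le \mathcal{L}_{\mathcal{D}_1}(h_{\mathcal{D}_1}) + 2\operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U}) + \mathcal{O}\left(\sqrt{\frac{d + \log(1/\delta)}{m}}\right). \tag{2}$$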
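One way to see where the factor of two on the discrepancy in (2) comes from is the chain below, a sketch of a standard argument (our reconstruction, not necessarily the exact proof behind the bound). Write $\epsilon = \max_{h \in \mathcal{H}} |\mathcal{L}_{\mathcal{U}}(h) - \mathcal{L}_{\hat{\mathcal{U}}}(h)|$ for the uniform deviation, which is $\mathcal{O}(\sqrt{(d + \log(1/\delta))/m})$ with probability at least $1 - \delta$:

```latex
\begin{align*}
\mathcal{L}_{\mathcal{D}_1}(h_{\hat{\mathcal{U}}})
  &\le \mathcal{L}_{\mathcal{U}}(h_{\hat{\mathcal{U}}}) + \operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U})
     && \text{definition of disc} \\
  &\le \mathcal{L}_{\hat{\mathcal{U}}}(h_{\hat{\mathcal{U}}}) + \epsilon + \operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U})
     && \text{uniform deviation} \\
  &\le \mathcal{L}_{\hat{\mathcal{U}}}(h_{\mathcal{D}_1}) + \epsilon + \operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U})
     && h_{\hat{\mathcal{U}}} \text{ minimizes } \mathcal{L}_{\hat{\mathcal{U}}} \\
  &\le \mathcal{L}_{\mathcal{U}}(h_{\mathcal{D}_1}) + 2\epsilon + \operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U})
     && \text{uniform deviation} \\
  &\le \mathcal{L}_{\mathcal{D}_1}(h_{\mathcal{D}_1}) + 2\epsilon + 2\operatorname{disc}_{\mathcal{H}}(\mathcal{D}_1, \mathcal{U})
     && \text{definition of disc}
\end{align*}
```

The discrepancy is paid twice because the hypothesis is moved from $\mathcal{D}_1$ to $\mathcal{U}$ once for $h_{\hat{\mathcal{U}}}$ and once for the comparator $h_{\mathcal{D}_1}$.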
Since the global model is trained on the concatenation of all users’ data, it generalizes well. However, due to the distribution mismatch, the model may not perform well for a specific user. If $m_1 \ll m$, the difference between the local and global models depends on the discrepancy between $\mathcal{D}_1$ and $\mathcal{U}$, the number of samples $m_1$ from the domain $\mathcal{D}_1$, and the total number of samples $m$. While in most practical applications $m_1$ is small and hence a global model usually performs better, this is not guaranteed. The following simple example shows that a global model can be worse than the local model by a constant.
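Let $\mathcal{X} = [0, 1]$ and $\mathcal{Y} = \{0, 1\}$. Suppose there are two clients with distributions $\mathcal{D}_1$ and $\mathcal{D}_2$ defined as follows: both marginals $\mathcal{D}_1(x)$ and $\mathcal{D}_2(x)$ are uniform on $[0, 1]$, and $\mathcal{D}_1(y \mid x) = 1$ if $y = 1_{x > 1/2}$ and zero otherwise. Similarly, $\mathcal{D}_2(y \mid x) = 1$ only if $y = 1_{x < 1/2}$ and zero otherwise. Let $\mathcal{H}$ be the class of threshold classifiers indexed by a threshold $t \in [0, 1]$ and a sign $s \in \{-1, +1\}$ such that $h_{t,s}(x) = 1_{s \cdot (x - t) > 0}$. Further, suppose we are interested in the zero-one loss and that the numbers of samples from the two domains are very large and equal.
The optimal classifier for $\mathcal{D}_1$ is $h_{1/2, +1}$ and the optimal classifier for $\mathcal{D}_2$ is $h_{1/2, -1}$, and each achieves zero error on its respective client. Since the number of samples is the same for both clients, $\mathcal{U}$ is the uniform mixture of the two domains, $\mathcal{U} = \frac{1}{2}\mathcal{D}_1 + \frac{1}{2}\mathcal{D}_2$. Note that $\mathcal{L}_{\mathcal{U}}(h) = 1/2$ for all $h \in \mathcal{H}$, and hence the global objective cannot differentiate between any of the hypotheses in $\mathcal{H}$. In fact, $\mathcal{L}_{\mathcal{D}_1}(h) + \mathcal{L}_{\mathcal{D}_2}(h) = 1$ for every $h \in \mathcal{H}$, so any globally trained model incurs a loss of $1/2$ averaged over the two clients, whereas the local models achieve zero loss.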
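A quick Monte Carlo check of this example, under our own simplifications (a finite sample and a finite grid of thresholds): every threshold classifier $h_{t,s}(x) = 1_{s(x-t)>0}$ has loss $1/2$ on the equal mixture of the two clients, while $h_{1/2,+1}$ and $h_{1/2,-1}$ are perfect on their own clients. Because the two clients share the uniform marginal on $[0, 1]$, the sketch reuses one draw of features for both.

```python
import numpy as np


def h(t, s, x):
    """Threshold classifier h_{t,s}(x) = 1[s * (x - t) > 0]."""
    return (s * (x - t) > 0).astype(int)


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=100_000)  # shared marginal of both clients
y1 = (x > 0.5).astype(int)               # client 1 labels
y2 = (x < 0.5).astype(int)               # client 2 labels


def losses(t, s):
    """Zero-one losses of h_{t,s} on client 1 and client 2."""
    pred = h(t, s, x)
    return float(np.mean(pred != y1)), float(np.mean(pred != y2))


for t in np.linspace(0.0, 1.0, 11):
    for s in (-1, 1):
        l1, l2 = losses(t, s)
        # The loss on the mixture U = (D1 + D2) / 2 is the average of the two client losses.
        assert np.isclose(0.5 * (l1 + l2), 0.5)

print(losses(0.5, +1))  # (0.0, 1.0): optimal for client 1, worst for client 2
print(losses(0.5, -1))  # (1.0, 0.0): optimal for client 2, worst for client 1
```

Since each point carries opposite labels under the two clients, every prediction is wrong for exactly one of them, which is why the mixture loss is flat across $\mathcal{H}$.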
Since the uniform global model assigns weight $m_k/m$ to client $k$, clients with larger numbers of samples receive higher importance. This can adversely affect clients with small amounts of data. Furthermore, by (2), the model may not generalize well for clients whose distribution is different from the uniform distribution $\mathcal{U}$. Thus, (1) and (2) give guidelines on when it is beneficial for clients to participate in global model training.
Instead of weighting samples uniformly, one can use the agnostic risk, which is more risk-averse. We refer to Appendix B for details on agnostic risk minimization.
2.4 Learning theory for personalization
We ask if personalization can be achieved by an intermediate model between the local and global models. This gives rise to three natural algorithms, which are orthogonal and can be used separately or together.
Train a model for subsets of users: we can cluster users into groups and train a model for each group. We refer to this as user clustering, or more precisely hypothesis-based clustering.
Train a model on interpolated data: we can combine the local and global data and train a model on their combination. We refer to this as data interpolation.
Combine local and global models: we can train a local and a global model and use their combination. We refer to this as model interpolation.
In the rest of the paper, we study each of the above methods.
3 User clustering
Instead of training a single global model, a natural approach is to cluster clients into groups and train a model for each group. This is an intermediate model between a purely local and global model and provides a trade-off between generalization and distribution mismatch. If we have a clustering of users, then we can naturally find a model for each user using standard optimization techniques. In this section, we ask how to define clusters. Clustering is a classical problem with a broad literature and known algorithms (jain2010data). We argue that, since the subsequent application of our clustering is known, incorporating it into the clustering algorithm will be beneficial.
If we have meta-features about the data samples and clients, such as location or type of device, we can use them to find clusters. This can be achieved by algorithms such as $k$-means or its variants. This approach depends on knowledge of the meta-features and their relationship to the set of hypotheses under consideration. While this may be reasonable in many circumstances, it may not always be feasible.
If there are no meta-features, a natural approach is to cluster using a Bregman divergence defined over the distributions $\mathcal{D}_k$ (banerjee2005clustering). However, we would likely overfit, as the generalization of density estimation depends on the VC-dimension of the class of distributions, which in general can be much larger than that of the class of hypotheses $\mathcal{H}$. To overcome this, we propose an approach based on the hypotheses under consideration, which we discuss next.
3.2 Hypothesis-based clustering
Consider the scenario where we are interested in finding clusters of images for a facial recognition task. Suppose we want to find clusters of users for each skin color and a good model for each cluster. If we naively use Bregman divergence clustering, it may cluster based on the image background, e.g., outdoor or indoor, instead of skin color.
To overcome this, we propose to incorporate the task at hand to obtain better clusters. We refer to this approach as hypothesis-based clustering and show that it admits better generalization bounds than the Bregman divergence approach. We partition users into $q$ clusters and find the best hypothesis for each cluster. In particular, we use the following optimization:
$$\min_{h_1, \ldots, h_q \in \mathcal{H}} \sum_{k=1}^{p} \lambda_k \min_{i \in [q]} \mathcal{L}_{\mathcal{D}_k}(h_i), \tag{3}$$
where $\lambda_k$ is the importance of client $k$. The above loss function trains the $q$ best hypotheses and naturally divides $[p]$ into $q$ partitions, where each partition is associated with a particular hypothesis $h_i$.
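In practice, we only have access to the empirical distributions $\hat{\mathcal{D}}_k$. To simplify the analysis, we use the fraction of samples from each user, $m_k/m$, as $\lambda_k$. An alternative approach is to use $\lambda_k = 1/p$ for all users, which assigns equal weight to all clients; the analysis is similar and we omit it for brevity. In particular, we propose to solve
$$\min_{h_1, \ldots, h_q \in \mathcal{H}} \sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\hat{\mathcal{D}}_k}(h_i). \tag{4}$$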
3.3 Generalization bounds
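We now analyze the generalization properties of this technique. We first state a lemma. Lemma 3.3 [Appendix C.1]. Let $h^*_1, \ldots, h^*_q$ be the models obtained by solving (3) with $\lambda_k = m_k/m$, and let $\hat h_1, \ldots, \hat h_q$ be the models obtained by solving (4). Then,
$$\sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\mathcal{D}_k}(\hat h_i) - \sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\mathcal{D}_k}(h^*_i) \le 2 \max_{h_1, \ldots, h_q \in \mathcal{H}} \left| \sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\mathcal{D}_k}(h_i) - \sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\hat{\mathcal{D}}_k}(h_i) \right|.$$
Thus, it suffices to bound this last term. Since we are bounding the maximum difference between the true cluster-based loss and the empirical cluster-based loss over all hypotheses, this bound holds for any clustering algorithm. Let $C_1, \ldots, C_q$ be the clusters and let $m_{C_i}$ be the number of samples from cluster $C_i$. Let $\hat{\mathcal{C}}_i$ and $\mathcal{C}_i$ be the empirical and true distributions of cluster $C_i$. With these definitions, we now bound the generalization error of this technique.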
Theorem 3.3 [Appendix C.2]. With probability at least $1 - \delta$,
Corollary 3.3 [Appendix C.3]. Let $d$ be the VC-dimension of $\mathcal{H}$. Then, with probability at least $1 - \delta$, the following holds:
The above learning bound can be understood as follows. For good generalization, the average number of samples per user should be larger than the logarithm of the number of clusters, and the average number of samples per cluster should be larger than the VC-dimension of the overall model. Somewhat surprisingly, these results do not depend on the minimum number of samples per client and instead depend only on average statistics.
where $f$ is the mapping from users to clusters. Thus, the generalization bound lies between those of the local and global models. For $q = 1$, it yields the global model, and for $q = p$, it yields the local model. As we increase $q$, the generalization term grows while the discrepancy term gets smaller. Allowing a general $q$ lets us choose the best clustering scheme and provides a smooth trade-off between generalization and distribution matching. In practice, we choose small values of $q$. We further note that we are not restricted to using the same value of $q$ for all clients: we can find clusters for several values of $q$ and use the best one for each client separately, using a hold-out set of samples.
3.4 Algorithm : HypCluster
We provide an expectation-maximization (EM)-type algorithm for minimizing (4). A naive EM modification may require heavy computation and communication resources. To overcome this, we propose a stochastic EM algorithm, HypCluster. In the algorithm, we denote clusters via a mapping $f \colon [p] \to [q]$, where $f(k)$ denotes the cluster of client $k$. Similar to $k$-means, HypCluster is not guaranteed to converge to the true optimum, but, as stated at the beginning of the previous section, the generalization guarantee of Theorem 3.3 still holds here.
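A minimal sketch of a stochastic EM loop in the spirit of HypCluster, under assumptions that are ours rather than the paper's: each client solves a linear regression with squared loss, the E-step assigns a sampled client to the hypothesis with the lowest empirical loss (computing $f(k)$), the M-step runs a few local gradient steps and averages them within each cluster weighted by $m_k$ (FedAvg-style), and hypotheses are initialized from spread-out local least-squares fits so that the toy run is deterministic. All function names are hypothetical.

```python
import numpy as np


def mse(w, X, y):
    """Empirical squared loss of the linear hypothesis w on one client's data."""
    return float(np.mean((X @ w - y) ** 2))


def local_steps(w, X, y, lr=0.05, steps=5):
    """A few full-batch gradient steps on one client's squared loss (stand-in for local SGD)."""
    w = w.copy()
    for _ in range(steps):
        w -= lr * 2.0 * X.T @ (X @ w - y) / len(y)
    return w


def farthest_point_init(clients, q):
    """Initialize hypotheses from spread-out local least-squares fits (our assumption)."""
    fits = [np.linalg.lstsq(X, y, rcond=None)[0] for X, y in clients]
    hyps = [fits[0]]
    while len(hyps) < q:
        gaps = [min(np.linalg.norm(w - h) for h in hyps) for w in fits]
        hyps.append(fits[int(np.argmax(gaps))])
    return [h.copy() for h in hyps]


def assign(hyps, X, y):
    """E-step for one client: index of the hypothesis with the smallest empirical loss."""
    return int(np.argmin([mse(h, X, y) for h in hyps]))


def hyp_cluster(clients, q, rounds=200, clients_per_round=10, seed=0):
    """Stochastic EM sketch of hypothesis-based clustering; returns hypotheses and mapping f."""
    rng = np.random.default_rng(seed)
    hyps = farthest_point_init(clients, q)
    for _ in range(rounds):
        chosen = rng.choice(len(clients), size=clients_per_round, replace=False)
        updates = [[] for _ in range(q)]  # (m_k, locally updated model) per cluster
        for k in chosen:
            X, y = clients[k]
            i = assign(hyps, X, y)  # f(k)
            updates[i].append((len(y), local_steps(hyps[i], X, y)))
        for i, ups in enumerate(updates):  # M-step: sample-weighted average per cluster
            if ups:
                weights = [m_k for m_k, _ in ups]
                hyps[i] = np.average([w for _, w in ups], axis=0, weights=weights)
    f = [assign(hyps, X, y) for X, y in clients]
    return hyps, f
```

Only the sampled clients do work in a round, and each sends back one updated model, which is what keeps the per-round computation and communication close to that of ordinary FedAvg.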
4 Data interpolation
From the point of view of client $k$, there is a small amount of data with distribution $\mathcal{D}_k$ and a large amount of data from the global or clustered distribution $\mathcal{C}$. How can we use auxiliary data from $\mathcal{C}$ to improve the model accuracy on $\mathcal{D}_k$? This relates the problem of personalization to domain adaptation. In domain adaptation, there is a single source distribution, which here is the global data or the cluster data, and a single target distribution, which here is the local client data. As in domain adaptation with target labels (blitzer2008learning), we have at our disposal a large amount of labeled data from the source (global data) and a small amount of labeled data from the target (personal data). We propose to minimize the loss on the concatenated data,
$$\mathcal{L}_{\lambda \hat{\mathcal{D}}_k + (1 - \lambda)\hat{\mathcal{C}}}(h) = \lambda\, \mathcal{L}_{\hat{\mathcal{D}}_k}(h) + (1 - \lambda)\, \mathcal{L}_{\hat{\mathcal{C}}}(h),$$
where $\lambda \in [0, 1]$ is a hyper-parameter that can be obtained either by cross-validation or by using the generalization bounds of blitzer2008learning. $\mathcal{C}$ can be either the uniform distribution or one of the distributions obtained via clustering.
Personalization differs from most work on domain adaptation, which assumes access only to unlabeled target data, whereas in personalization we have access to labeled target data. Second, we have one target domain per client, which makes our problem computationally expensive, as we discuss next. Given the known learning-theoretic bounds, a natural question is whether we can efficiently estimate the best hypothesis for a given $\lambda$. However, naive approaches suffer from the following drawbacks. If we optimize for each client separately, the time complexity of learning per client is $\mathcal{O}(m)$ and the overall time complexity is $\mathcal{O}(mp)$.
In addition to the computation time, this approach also incurs a high communication cost in FL. This is because training the model on a $\lambda$-weighted mixture requires the client to have access to the entire dataset $\hat{\mathcal{C}}$, which incurs a communication cost of $\mathcal{O}(m)$ per client. One empirically popular approach to overcome this is fine-tuning, where the central model is fine-tuned on the local data (wang2019federated). However, to the best of our knowledge, fine-tuning has no theoretical guarantees, and it may be prone to catastrophic forgetting (goodfellow2013empirical).
We propose Dapper, a theoretically motivated and efficient algorithm to overcome the above issues. The algorithm first trains a central model on the overall empirical distribution. Then, for each client, it subsamples $\hat{\mathcal{C}}$ to create a smaller dataset of size $r \cdot m_k$, where $r$ is a constant. It then minimizes the weighted combination of the two datasets as in (7) for several values of $\lambda$. Finally, it chooses the best $\lambda$ using cross-validation. The algorithm is efficient both in terms of its communication complexity, which is $\mathcal{O}(r \cdot m_k)$, and its computation time, which is at most $\mathcal{O}(r \cdot m_k)$. Hence, the overall communication and computation time is $\mathcal{O}(r \cdot m)$.
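The following sketch instantiates the Dapper steps for least-squares regression, where each $\lambda$-weighted problem has a closed-form (ridge-regularized) solution. The linear model, the $\lambda$ grid, the single hold-out split standing in for cross-validation, and all names are our assumptions. Because this objective is convex and solved exactly, the central model would only matter as the initializer of an iterative solver, so the sketch does not use it.

```python
import numpy as np


def weighted_fit(X_loc, y_loc, X_glb, y_glb, lam, reg=1e-3):
    """Minimize lam * MSE(local) + (1 - lam) * MSE(central subsample) + reg * ||w||^2."""
    d = X_loc.shape[1]
    A = lam * X_loc.T @ X_loc / len(y_loc) + (1 - lam) * X_glb.T @ X_glb / len(y_glb)
    b = lam * X_loc.T @ y_loc / len(y_loc) + (1 - lam) * X_glb.T @ y_glb / len(y_glb)
    return np.linalg.solve(A + reg * np.eye(d), b)


def dapper(X_loc, y_loc, X_c, y_c, r=5, lambdas=(0.0, 0.25, 0.5, 0.75, 1.0),
           val_frac=0.3, seed=0):
    """Subsample r * m_k central points, fit each lambda, pick lambda on held-out local data."""
    rng = np.random.default_rng(seed)
    m_k = len(y_loc)
    idx = rng.choice(len(y_c), size=min(len(y_c), r * m_k), replace=False)
    X_sub, y_sub = X_c[idx], y_c[idx]  # only r * m_k central samples reach the client
    perm = rng.permutation(m_k)
    n_val = max(1, int(val_frac * m_k))
    val, tr = perm[:n_val], perm[n_val:]
    best_lam, best_loss = None, np.inf
    for lam in lambdas:
        w = weighted_fit(X_loc[tr], y_loc[tr], X_sub, y_sub, lam)
        loss = float(np.mean((X_loc[val] @ w - y_loc[val]) ** 2))
        if loss < best_loss:
            best_lam, best_loss = lam, loss
    # Refit on all local data with the selected mixture weight.
    return weighted_fit(X_loc, y_loc, X_sub, y_sub, best_lam), best_lam
```

Setting $\lambda = 1$ recovers the purely local model and $\lambda = 0$ a model fit only to the central subsample, so the selected $\lambda$ interpolates between the two baselines.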
4.1 Convergence analysis
We analyze Dapper when the loss function is strongly convex in the hypothesis parameters and show that the model minimizes the intended loss to the desired accuracy. To the best of our knowledge, this is the first fine-tuning algorithm with provable guarantees that ensures that there is no catastrophic forgetting.
To prove convergence guarantees, we first need to decide what accuracy is required. Models are usually required to converge up to their generalization guarantee, and we use the same criterion. To this end, we first state a known generalization theorem. Theorem (blitzer2008learning). If the VC-dimension of the hypothesis class is $d$, then with probability at least $1 - \delta$,
Since the generalization bound scales as , the same accuracy in convergence is desired. Let denote the desired convergence guarantee. For strongly convex functions, we show that one can achieve this desired accuracy using Dapper, furthermore the amount of additional data is a constant multiple of , independent of and . [Appendix D] Assume that the loss function is -strongly convex and assume that the gradients are -smooth. Let admit diameter at most . Let a constant independent of . Let the learning rate . Then after steps of SGD, the output satisfies,
4.2 Practical considerations
While the above algorithm reduces the amount of data transfer and is computationally efficient, it may raise privacy concerns in applications such as FL. To address them, we propose several alternatives:
Sufficient statistics: in many scenarios, instead of the actual data, we only need some sufficient statistics. For example, in regression with the $\ell_2$ loss, we only need the covariance matrix of the dataset from $\hat{\mathcal{C}}$.
Generative models: for problems such as density estimation and language modelling, we can use the centralized model to generate synthetic samples and use them as an approximation to $\hat{\mathcal{C}}$. For other applications, one can train a GAN, send it to the clients, and let the clients sample from it to create the dataset (augenstein2019generative).
Proxy public data: if it is not feasible to send the actual user data, one could send proxy public data instead. While this may not be theoretically optimal, it will still avoid overfitting to the local data.
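For the sufficient-statistics alternative in the $\ell_2$ regression case, the sketch below shows what the server could send, assuming a linear model: the uncentered second-moment matrices $X^\top X / n$ and $X^\top y / n$ of the central data (covariance-type statistics of the features and of features with labels). They suffice to solve the $\lambda$-weighted problem exactly, without any raw central samples reaching the client. Function names are illustrative.

```python
import numpy as np


def central_statistics(X_c, y_c):
    """Server side: second-moment statistics of the central data."""
    n = len(y_c)
    return X_c.T @ X_c / n, X_c.T @ y_c / n


def fit_from_statistics(X_loc, y_loc, S_xx, S_xy, lam):
    """Client side: minimize lam * MSE(local) + (1 - lam) * MSE(central) using only statistics."""
    m = len(y_loc)
    A = lam * X_loc.T @ X_loc / m + (1 - lam) * S_xx
    b = lam * X_loc.T @ y_loc / m + (1 - lam) * S_xy
    return np.linalg.solve(A, b)
```

The message size is $\mathcal{O}(d^2)$ in the feature dimension $d$ regardless of how many central samples there are; whether this leaks less than raw data depends on the application.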
5 Model interpolation
The above approaches assume that the final inference model belongs to class $\mathcal{H}$. In practice, this may not be the case. One can learn a central model $h_c$ from a class $\mathcal{H}_c$, learn a local model $h_l$ from $\mathcal{H}_l$, and use their interpolated model $\alpha h_l + (1 - \alpha) h_c$ during inference. Such interpolated models are routinely used in applications such as virtual keyboards.
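More formally, let $h_c$ be the central or cluster model and let $h_L = (h_1, \ldots, h_p)$, where $h_k$ is the local model for client $k$. Let $\alpha_k \in [0, 1]$ be the interpolation weight for client $k$ and let $\alpha = (\alpha_1, \ldots, \alpha_p)$. If one had access to the true distributions, learning the best interpolated models could be formulated as the following optimization:
$$\min_{h_c \in \mathcal{H}_c,\; h_L \in \mathcal{H}_l^p,\; \alpha \in [0, 1]^p} \sum_{k=1}^{p} \frac{m_k}{m} \mathcal{L}_{\mathcal{D}_k}\big(\alpha_k h_k + (1 - \alpha_k) h_c\big).$$
Since the learner does not have access to the true distributions, we propose the following optimization:
$$\min_{h_c \in \mathcal{H}_c,\; h_L \in \mathcal{H}_l^p,\; \alpha \in [0, 1]^p} \sum_{k=1}^{p} \frac{m_k}{m} \mathcal{L}_{\hat{\mathcal{D}}_k}\big(\alpha_k h_k + (1 - \alpha_k) h_c\big).$$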
5.1 Generalization bounds
We now show a generalization bound for the above optimization. [Appendix E] Assume that the loss $\ell$ is Lipschitz. Let $\mathcal{H}_c$ be the hypothesis class for the central model and $\mathcal{H}_l$ the hypothesis class for the local models. Let $h^*_c, h^*_L, \alpha^*$ be the optimal values and $\hat h_c, \hat h_L, \hat \alpha$ the optimal values for the empirical estimates. Then, with probability at least $1 - \delta$,
Standard upper bounds on the Rademacher complexity in terms of the VC-dimension, combined with Jensen’s inequality, yield the following result. Corollary. Assume that $\ell$ is $L$-Lipschitz. Let $h^*_c, h^*_L, \alpha^*$ be the optimal values and $\hat h_c, \hat h_L, \hat \alpha$ the optimal values for the empirical estimates. Then, with probability at least $1 - \delta$, the following holds:
where $d_c$ is the VC-dimension of $\mathcal{H}_c$ and $d_l$ is the VC-dimension of $\mathcal{H}_l$. Hence, for the models to generalize well, it is desirable to have $m \gg d_c$ and the average number of samples to be much greater than $d_l$, i.e., $m/p \gg d_l$. Similar to Corollary 3.3, this bound depends only on the average number of samples and not on the minimum number of samples.
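A common approach to model interpolation in practice is to first train the central model, then train each local model separately, and finally find the best interpolation coefficients, i.e.,
$$\hat h_c = \operatorname{argmin}_{h \in \mathcal{H}_c} \mathcal{L}_{\hat{\mathcal{U}}}(h), \qquad \hat h_k = \operatorname{argmin}_{h \in \mathcal{H}_l} \mathcal{L}_{\hat{\mathcal{D}}_k}(h), \qquad \hat \alpha_k = \operatorname{argmin}_{\alpha \in [0, 1]} \mathcal{L}_{\hat{\mathcal{D}}_k}\big(\alpha \hat h_k + (1 - \alpha) \hat h_c\big).$$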
We first show that this method of independently finding the local models is sub-optimal with an example. Consider the following discrete distribution estimation problem. Let be the set of distributions over values and let be the set of distributions with support size . For even , let
and for odd, let for all . Let the number of clients be very large and the number of samples per client a constant, say ten. Suppose we consider the log-loss.
The intuition behind this example is that since we have only a few examples per domain, we can only derive good estimates for the local model for even and we need to estimate the global model jointly from the odd clients. With this approach, the optimal solution is as follows. For even , and . For odd , and the optimal is given by, . If we learn the models separately, observe that, for each client be the empirical estimate and would be . Thus, for any , the algorithm would incur at least a constant loss more than optimal for any for odd clients.
Since training models independently is sub-optimal in certain cases, we propose a joint-optimization algorithm. First, observe that the optimization can be rewritten as
$$\min_{h_c \in \mathcal{H}_c} \sum_{k=1}^{p} \frac{m_k}{m} \min_{h_k \in \mathcal{H}_l,\; \alpha_k \in [0, 1]} \mathcal{L}_{\hat{\mathcal{D}}_k}\big(\alpha_k h_k + (1 - \alpha_k) h_c\big).$$
Notice that for a fixed $\alpha$ and a convex loss, the function is convex in both $h_c$ and $h_L$. But with the minimization over $\alpha$, the function is no longer convex. We propose the algorithm Mapper for learning the interpolated models. At each round, the algorithm randomly selects a client. It then finds the best local model and interpolation weight for that client using the current value of the global model. It then updates the global model using the local model and the interpolation weight found in the previous step. In practice, dividing each client’s data into three parts and using each part separately for (9), (10), and (11) leads to better performance.
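A sketch of the Mapper loop for discrete distribution estimation with the log-loss, under our own assumptions: models are softmax-parameterized distributions over $n$ symbols, $\alpha_k = \sigma(a_k)$ is kept in $(0, 1)$ by a sigmoid, “finding the best local model and interpolation weight” is approximated by a few gradient steps with the global model fixed, and the global update is a single gradient step on the selected client’s interpolated loss. Names and step sizes are hypothetical.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def interp_grads(theta_c, theta_k, a_k, w):
    """Gradients of -sum_x w[x] log p[x], p = alpha softmax(theta_k) + (1 - alpha) softmax(theta_c).

    w is the client's empirical distribution over the n symbols; alpha = sigmoid(a_k).
    """
    s_c, s_k, alpha = softmax(theta_c), softmax(theta_k), sigmoid(a_k)
    p = alpha * s_k + (1 - alpha) * s_c
    r_k = w * alpha * s_k / p        # responsibility of the local component
    r_c = w * (1 - alpha) * s_c / p  # responsibility of the global component
    g_k = -(r_k - r_k.sum() * s_k)   # chain rule through softmax(theta_k)
    g_c = -(r_c - r_c.sum() * s_c)   # chain rule through softmax(theta_c)
    g_a = -float(np.sum(w * alpha * (1 - alpha) * (s_k - s_c) / p))
    return g_c, g_k, g_a


def mapper(counts, rounds=400, inner=20, lr=0.5, seed=0):
    """counts[k, x] = number of times client k observed symbol x."""
    rng = np.random.default_rng(seed)
    p, n = counts.shape
    W = counts / counts.sum(axis=1, keepdims=True)  # empirical client distributions
    theta_c, theta, a = np.zeros(n), np.zeros((p, n)), np.zeros(p)
    for _ in range(rounds):
        k = rng.integers(p)  # randomly selected client
        for _ in range(inner):  # local model and weight, global model held fixed
            _, g_k, g_a = interp_grads(theta_c, theta[k], a[k], W[k])
            theta[k] -= lr * g_k
            a[k] -= lr * g_a
        g_c, _, _ = interp_grads(theta_c, theta[k], a[k], W[k])
        theta_c -= lr * g_c  # global update using this client's local model and weight
    return theta_c, theta, sigmoid(a)
```

Updating the global model through each client’s interpolated loss is what distinguishes this joint scheme from the independent three-step recipe: the global component is pushed to explain what the local components cannot.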
6.1 Synthetic dataset
We first demonstrate the proposed algorithms on a synthetic dataset for density estimation. Let , , and . Let be cross entropy loss and the number of users . We create client distributions as a mixture with a uniform component, a cluster component, and an individual component. The details of the distributions are in Appendix F.1.
We evaluate the algorithms as we vary the number of samples per user. The results are in Table 2. Note that Mapper performs uniformly well across all sample sizes. However, the performance difference between HypCluster and Dapper depends on the number of samples per user.
To understand the effect of clustering, we evaluate various clustering algorithms as a function of the number of clusters $q$; the results are in Table 1. Since the clients are naturally divided into four clusters, as we increase $q$, the test loss steadily decreases until the number of clusters reaches four and then remains constant.
6.2 EMNIST dataset
We evaluate the proposed algorithms on the federated EMNIST-62 dataset (caldas2018leaf) provided by TensorFlow Federated (TFF). The dataset consists of examples from 3400 users, each labeled as one of 62 classes (lower and upper case letters and digits). The original TFF dataset is split only into train and test, so we further split train into train and eval such that all users have at least one example in each split. Additionally, within each split, we use the first 2500 users to train the global or clustered models and leave the remaining 900 as new unseen clients. The unseen clients do not participate in central model training and are reserved for evaluation only. The reported metrics are uniformly averaged across all clients, similar to previous works (jiang2019improving).
For the model architecture, we use a two-layer convolutional neural net with hyper-parameters tuned using the eval dataset. We refer to Appendix F.2 for more details. We train the model for 1000 communication rounds with 20 clients per round and use server-side momentum. One can use different optimizers as proposed by jiang2019improving. Evaluating the combined effect of our approach and adaptive optimizers remains an interesting open direction. For Finetune and Dapper, we use the best baseline model as the pretrained starting global model, since its training is independent of client fine-tuning. For Mapper, we use the same model architecture for both local and global models, and for ease of training, at each optimization step, we initialize the local model using the parameters of the global model at that step.
We first split the seen and unseen clients using the original ordering. For this case, Table 3 reports the accuracy of each algorithm on the seen and unseen client test data, averaged over 20 trials. There is a distinct difference between seen and unseen accuracy for Baseline, which possibly indicates a natural ordering of clients. HypCluster further supports this, as the best trials have two clusters and the unseen clients all map to only one of them. This experiment models the scenario where there is a distribution shift over the clients.
We then randomly shuffle the clients before splitting them into seen and unseen. The results for this case are in Table 4. After shuffling the clients, the Baseline seen and unseen accuracies are much closer, and the client cluster distribution in HypCluster is much more balanced.
We observe that in both the shuffled and unshuffled cases, Mapper performs best in seen and unseen accuracy, followed by Dapper, Finetune, and HypCluster, in that order. Additionally, all three proposed approaches provide better and more balanced performance than Baseline. This is especially pronounced for the unseen clients in the unshuffled scenario.
We presented a systematic learning-theoretic study of personalization in learning and proposed and analyzed three algorithms: user clustering, data interpolation, and model interpolation. For all three approaches, we provided learning-theoretic guarantees and efficient algorithms. Finally, we empirically demonstrated the usefulness of the proposed approaches on synthetic and EMNIST datasets.
Authors thank Rajiv Mathews, Brendan McMahan, and Ke Wu for helpful comments and discussions.
Appendix A Federated learning
FL was introduced by mcmahan2017communication as an efficient method for training models in a distributed way. They proposed a new communication-efficient optimization algorithm called FedAvg. They also showed that the training procedure provides additional privacy benefits. The introduction of FL has given rise to several interesting research problems, including the design of more efficient communication strategies (konevcny2016federated; konecny2016federated2; suresh2017distributed; stich2018local; karimireddy2019scaffold), the study of lower bounds for parallel stochastic optimization with a dependency graph (woodworth2018graph), devising efficient distributed optimization methods benefiting from differential privacy guarantees (agarwal2018cpsgd), stochastic optimization solutions for the agnostic formulation (mohri2019agnostic), and incorporating cryptographic techniques (bonawitz2017practical), see (li2019federatedsurvey; kairouz2019advances) for an in-depth survey of recent work in FL.
Federated learning often results in improved performance, as reported in several learning problems, including next word prediction (hard2018federated; yang2018applied), vocabulary estimation (chen2019federated), emoji prediction (ramaswamy2019federated), decoder models (chen2019federatedb), low latency vehicle-to-vehicle communication (samarakoon2018federated), and predictive models in health (brisimi2018federated).
Appendix B Agnostic global model
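Instead of assigning weights proportional to the number of samples as in the uniform global model, we can weight clients according to any $\lambda \in \Delta_p$. For example, instead of uniform sample weights, we can weight clients uniformly, corresponding to $\lambda_k = 1/p$ for all $k$. Let $\hat{\mathcal{D}}_\lambda$ denote the $\lambda$-weighted empirical distribution and let $h_{\hat{\mathcal{D}}_\lambda}$ be the minimizer of the loss over $\hat{\mathcal{D}}_\lambda$. Instead of the uniform global model described in Section 2.3, we can use the agnostic loss, where we minimize the maximum loss over a set of distributions. Let $\Lambda \subseteq \Delta_p$. The agnostic loss is given by
$$\mathcal{L}_{\mathcal{D}_\Lambda}(h) = \max_{\lambda \in \Lambda} \mathcal{L}_{\mathcal{D}_\lambda}(h).$$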
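Let $\hat h_\Lambda$ be the minimizer of the empirical agnostic loss. Let $\Lambda_\epsilon$ be an $\epsilon$-cover of $\Lambda$. Let $\bar m = (m_1/m, \ldots, m_p/m)$ denote the empirical distribution of samples across clients. The skewness between the distributions $\lambda$ and $\bar m$ is defined as $\mathfrak{s}(\lambda \,\|\, \bar m) = 1 + \chi^2(\lambda \,\|\, \bar m)$, where $\chi^2(\lambda \,\|\, \bar m) = \sum_{k=1}^{p} (\lambda_k - \bar m_k)^2 / \bar m_k$ is the chi-squared divergence. With these definitions, the generalization guarantee of (mohri2019agnostic, Theorem 2) for client one can be expressed as follows: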
where is the mixture weight where the trained model has the highest loss. Hence, this approach would personalize well for hard distributions and can be considered as a step towards ensuring that models work for all distributions. In this work, we show that training a different model for each client would significantly improve the model performance.
Appendix C Proofs for user clustering
c.1 Proof of Lemma 3.3
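Let $F(h_1, \ldots, h_q) = \sum_{k=1}^{p} \frac{m_k}{m} \min_{i \in [q]} \mathcal{L}_{\mathcal{D}_k}(h_i)$, and let $\hat F$ be defined analogously with $\hat{\mathcal{D}}_k$ in place of $\mathcal{D}_k$. Writing $\hat h = (\hat h_1, \ldots, \hat h_q)$ and $h^* = (h^*_1, \ldots, h^*_q)$,
$$F(\hat h) - F(h^*) = \big[F(\hat h) - \hat F(\hat h)\big] + \big[\hat F(\hat h) - \hat F(h^*)\big] + \big[\hat F(h^*) - F(h^*)\big] \le 2 \max_{h_1, \ldots, h_q \in \mathcal{H}} \big|F(h_1, \ldots, h_q) - \hat F(h_1, \ldots, h_q)\big|,$$
where the inequality follows by observing that $\hat F(\hat h) \le \hat F(h^*)$, by the definition of $\hat h$.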
c.2 Proof of Theorem 3.3
For any set of real numbers $a_1, \ldots, a_q$ and $b_1, \ldots, b_q$, observe that
$$\min_{i \in [q]} a_i - \min_{i \in [q]} b_i \le \max_{i \in [q]} (a_i - b_i).$$
We first prove the theorem for one side. Let $f$ be a mapping from clients to clusters. Applying the above result yields,
Since changing one sample changes the above function by at most $1/m$, for a given $f$, by McDiarmid’s inequality, with probability at least $1 - \delta$, the following holds:
The number of possible functions $f$ is $q^p$. Hence, by the union bound, for all $f$, with probability at least $1 - \delta$, the following holds: