Variational Federated Multi-Task Learning

06/14/2019 ∙ by Luca Corinzia, et al. ∙ ETH Zurich

In classical federated learning a central server coordinates the training of a single model on a massively distributed network of devices. This setting can be naturally extended to a multi-task learning framework, to handle real-world federated datasets that typically show strongly non-IID data distributions among devices. Even though federated multi-task learning has been shown to be an effective paradigm for real-world datasets, it has been applied only to convex models. In this work we introduce VIRTUAL, an algorithm for federated multi-task learning with non-convex models. In VIRTUAL the federated network of the server and the clients is treated as a star-shaped Bayesian network, and learning is performed on the network using approximate variational inference. We show that this method is effective on real-world federated datasets, outperforming the current state-of-the-art for federated learning.




1 Introduction

Large-scale networks of remote devices like mobile phones, wearables, smart homes, self-driving cars and other IoT devices are becoming a significant source of data to train statistical models. This has generated growing interest in developing machine learning paradigms that can take into account the distributed structure of the data, despite the several challenges arising in this setting:


Data generated by remote devices is often privacy-sensitive, and its centralized collection and storage is governed by data protection regulations (e.g. GDPR voigt2017eu and the Consumer Privacy Bill of Rights house2012consumer). Learning paradigms that do not access user data directly are hence desirable.


Remote devices in these networks typically have significant storage and computational constraints, limiting the complexity and the size of the models that can be used. Moreover, communication between devices, or between devices and a central server, mostly happens over wireless networks, so communication cost can become a significant bottleneck of the learning process.


The devices of the network typically generate samples with different user-dependent probability distributions, making the setting strongly non-IID in general. While achieving high statistical accuracy in this setting is a challenge for classical federated and distributed algorithms, a multi-task learning (MTL) approach can handle heterogeneous data more naturally: every device of the network requires a task-specific model, tailored to its own data distribution, to boost the performance of each individual task.

Federated learning (FL) mcmahan2017communication has emerged as the learning paradigm for learning models on private, distributed data sources. It assumes a federation of devices, called clients, that both collect the data and carry out an optimization routine, and a server that coordinates the learning by receiving updates from and sending updates to the clients. This paradigm has been applied successfully in many real-world cases, e.g. to train smart keyboards on commercial mobile devices yang2018applied and to train privacy-preserving recommendation systems ammad2019federated. Federated Averaging (FedAvg) mcmahan2017communication; nilsson2018performance is the state-of-the-art for federated learning with non-convex models and requires all the clients to share the same model. Hence it does not address the statistical challenge of strongly skewed data distributions: while it has been shown to work well in practice on a range of (non-federated) real-world datasets, it performs poorly in heterogeneous scenarios mcmahan2017communication.

We address this problem by introducing VIRTUAL (VarIational fedeRaTed mUlti tAsk Learning), a new framework for federated MTL. In VIRTUAL, the central server and the clients form a Bayesian network, and inference is performed using variational methods. Every client has a task-specific model that benefits from the server model in a transfer learning fashion, through lateral connections. Hence part of the parameters is shared among all clients, while another part is private and tuned separately. The server maintains a posterior distribution that represents the plausibility of the shared parameters. In one step of the algorithm, the posterior is communicated to the clients before training starts; during training, the clients update the posterior given the likelihood of their local data. Finally, the posterior update is sent back to the central server.


Our main contributions are threefold: (i) we address for the first time the problem of federated MTL for generic non-convex models, and we propose VIRTUAL, an algorithm to perform federated training with strongly non-IID client data distributions; (ii) we perform an extensive experimental evaluation of VIRTUAL on real-world federated datasets, showing that it outperforms the current state-of-the-art in FL; and (iii) we frame the federated MTL problem as an inference problem in a Bayesian network, bridging the frameworks of federated and transfer/continual learning, which opens the door to a new class of application-specific federated algorithms.

2 The VIRTUAL algorithm

In FL, $K$ clients are associated with datasets $\mathcal{D}_1, \dots, \mathcal{D}_K$, where $\mathcal{D}_i$ is in general generated by a client-dependent probability distribution function (pdf) $p_i(x, y)$ and is only accessible by the respective client. It is natural to fit different models, one for each dataset, enforcing a relationship between the models through parameter sharing collobert2008unified. This approach has been investigated extensively, and it has been shown to boost the effective sample size and the performance of MTL for neural networks.


2.1 The Bayesian network

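Assume a star-shaped Bayesian network with a server with model parameters $\theta$, as well as $K$ clients with model parameters $\phi_1, \dots, \phi_K$. Assume that every client $i$ is a discriminative model, i.e. a distribution over the label given the input, $p(y \mid x, \theta, \phi_i)$ (a simple extension of this work could also consider generative models). Each dataset $\mathcal{D}_i = \{(x_n^i, y_n^i)\}_{n=1}^{N_i}$ is private to client $i$, hence it is not accessible to any other client or to the central server, and has a likelihood that factorizes as $p(\mathcal{D}_i \mid \theta, \phi_i) = \prod_{n=1}^{N_i} p(y_n^i \mid x_n^i, \theta, \phi_i)$. Following a Bayesian approach, we assume a prior distribution $p(\theta, \phi_1, \dots, \phi_K)$ over all the network parameters. The posterior distribution over all parameters, given all datasets $\mathcal{D} = \{\mathcal{D}_1, \dots, \mathcal{D}_K\}$, then reads

$$p(\theta, \phi_1, \dots, \phi_K \mid \mathcal{D}) \propto p(\theta) \prod_{i=1}^{K} p(\phi_i)\, p(\mathcal{D}_i \mid \theta, \phi_i) \qquad (1)$$

where we enforce that client data are conditionally independent given server and client parameters, $p(\mathcal{D} \mid \theta, \phi_1, \dots, \phi_K) = \prod_{i=1}^{K} p(\mathcal{D}_i \mid \theta, \phi_i)$, and that the prior factorizes as $p(\theta, \phi_1, \dots, \phi_K) = p(\theta) \prod_{i=1}^{K} p(\phi_i)$. The Bayesian network is illustrated in Figure 1(a).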
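To make the factorization in Equation 1 concrete, the following minimal sketch evaluates the unnormalized log posterior for a hypothetical toy model: scalar server and client parameters, standard normal priors, and a logistic client model $p(y = 1 \mid x) = \sigma(\theta x + \phi_i)$. The model, the priors and all function names are illustrative assumptions, not the architecture used in this work; the point is only that the log posterior splits into a server prior term plus one independent term per client.

```python
import math


def log_normal(x, mean=0.0, std=1.0):
    """Log density of a univariate Gaussian N(mean, std^2) at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def client_log_likelihood(theta, phi_i, data_i):
    """log p(D_i | theta, phi_i) = sum_n log p(y_n | x_n, theta, phi_i) for the toy
    logistic model p(y=1 | x) = sigmoid(theta * x + phi_i), with labels y in {0, 1}."""
    total = 0.0
    for x, y in data_i:
        z = theta * x + phi_i
        # log(1 - sigmoid(z)) = log(sigmoid(-z))
        total += log_sigmoid(z) if y == 1 else log_sigmoid(-z)
    return total


def log_joint(theta, phis, datasets):
    """Unnormalized log posterior of Equation 1:
    log p(theta) + sum_i [log p(phi_i) + log p(D_i | theta, phi_i)]."""
    value = log_normal(theta)  # server prior p(theta), assumed N(0, 1)
    for phi_i, data_i in zip(phis, datasets):
        # One additive term per client: only client i's data and parameters enter it.
        value += log_normal(phi_i) + client_log_likelihood(theta, phi_i, data_i)
    return value
```

Because each client contributes a separate additive term, adding a client with no data changes the value only by its prior term $\log p(\phi_i)$.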

(a) Bayesian network

(b) Approximated variational posterior model
Figure 1: Graphical models that describe the VIRTUAL framework for federated learning. The plates represent replicates. In both figures the outer plate replicates client $i$ over the total number of clients $K$, while the inner plate replicates the sample index $n$ over the total number of samples per client $N_i$. Shaded nodes represent observed variables and unshaded nodes represent latent variables. (a) Solid lines denote the discriminative model $p(y \mid x, \theta, \phi_i)$. (b) Graphical model of the approximate variational posterior. Dashed lines denote (deterministic) dependencies in the approximate variational posterior, while dotted lines denote stochastic dependencies. Here we also indicate the collections of all Gaussian parameters of the server and of client $i$.

2.2 The optimization procedure

The posterior given in Equation 1 is in general intractable, and hence we have to rely on an approximate inference scheme (e.g. variational inference, sampling, or expectation propagation bishop2006pattern). Here we propose an expectation propagation (EP)-like approximation algorithm minka2001expectation, which has been shown to be effective and to outperform other methods when applied in the continual learning (CL) setting bui2018partitioned; nguyen2017variational. Let us denote the collection of all the client parameters by $\phi = (\phi_1, \dots, \phi_K)$. Then we define a proxy posterior distribution $q(\theta, \phi)$ that factorizes into a server and a client contribution for every client as

$$q(\theta, \phi) = \prod_{i=1}^{K} s_i(\theta)\, c_i(\phi_i) \qquad (2)$$

The full factorization of both server and client parameters allows us to perform a client update that is independent of the other clients, and a server update in the form of an aggregated posterior that preserves privacy.

Given a factorization of this kind, the general EP algorithm refines one factor at each step. It first computes a refined posterior distribution in which the refining factor of the proxy is replaced with the respective factor of the true posterior distribution. It then updates the factor by minimizing the Kullback-Leibler (KL) divergence between the full proxy posterior distribution and the refined posterior. For our particular Bayesian network and factorization, the optimization to be performed is given by the following proposition.

Proposition 1.

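Assume that at step $t$ the factor $i$ is refined. Then the proxy pdfs $s_i^{(t)}$ and $c_i^{(t)}$ are found by minimizing the variational free energy

$$\mathcal{L}^{(t)}(s_i, c_i) = \mathrm{KL}\Big(q^{(t)}(\theta) \,\Big\|\, p(\theta) \prod_{j \neq i} s_j^{(t-1)}(\theta)\Big) + \mathrm{KL}\big(c_i(\phi_i) \,\|\, p(\phi_i)\big) - \mathbb{E}_{q^{(t)}(\theta)\, c_i(\phi_i)}\big[\log p(\mathcal{D}_i \mid \theta, \phi_i)\big] \qquad (3)$$

where $q^{(t)}(\theta) = s_i(\theta) \prod_{j \neq i} s_j^{(t-1)}(\theta)$ is the new posterior over server parameters.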


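Proof. At step $t$ the global posterior for the server parameters is $q^{(t-1)}(\theta) = \prod_{j=1}^{K} s_j^{(t-1)}(\theta)$ and, analogously, the client parameter distribution reads $c^{(t-1)}(\phi) = \prod_{j=1}^{K} c_j^{(t-1)}(\phi_j)$. Then the EP-like update for the model described is given by minimizing the following KL divergence w.r.t. $s_i$ and $c_i$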

where the second equality comes from the normalization of client and server pdfs and from Bayes rule . Notice also that because of the factorization in Equation 2, and hence Equation 3 is proved. ∎

The variational free energy in Equation 3 decomposes naturally into two parts. The terms that involve the client parameters correspond to the variational free energy of Bayes by Backprop blundell2015weight. Note that, except for the natural complexity cost given by the second KL term, no additional regularization is applied to the client parameters, which can hence be trained efficiently and in a network-agnostic way. The terms that involve the server posterior are instead the likelihood cost and the first KL term. The latter restricts the server to learn an overall posterior that does not drift away from the posterior obtained by replacing the current refining factor with the prior. This constraint effectively forces the server to progress in a CL fashion nguyen2017variational, learning from new federated datasets while avoiding catastrophic forgetting of the ones already seen.
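The objective of Equation 3 can be estimated as in the following sketch. It assumes diagonal Gaussian distributions given as pairs of lists (means, standard deviations), anticipating the mean-field choice described below, and a user-supplied log-likelihood; `server_prior` stands for the second argument of the first KL term. The KL terms use the closed form for Gaussians, while the expected log-likelihood is a plain Monte Carlo average over reparametrized samples. The default of 20 samples mirrors the experimental setting; all names are hypothetical.

```python
import math
import random


def kl_diag_gauss(mu_q, sd_q, mu_p, sd_p):
    """Closed-form KL( N(mu_q, diag(sd_q^2)) || N(mu_p, diag(sd_p^2)) ), summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sd_q, mu_p, sd_p)
    )


def free_energy(q_server, server_prior, c_client, client_prior, log_lik,
                n_samples=20, rng=random):
    """Monte Carlo estimate of the objective in Equation 3.

    Each distribution is a (means, stds) pair of lists; log_lik(theta, phi)
    returns log p(D_i | theta, phi_i) for one joint sample of the parameters.
    """
    # Complexity costs: server KL (first term) and client KL (second term).
    kl_terms = (kl_diag_gauss(*q_server, *server_prior)
                + kl_diag_gauss(*c_client, *client_prior))
    expected_ll = 0.0
    for _ in range(n_samples):
        # Reparametrized samples: theta = mu + sigma * eps, eps ~ N(0, 1).
        theta = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(*q_server)]
        phi = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(*c_client)]
        expected_ll += log_lik(theta, phi)
    # Likelihood cost enters with a minus sign.
    return kl_terms - expected_ll / n_samples
```

For example, with $q^{(t)}(\theta) = \mathcal{N}(1, 1)$, a standard normal server prior, a client posterior equal to its prior and a constant log-likelihood of $-2$, the estimate is $0.5 + 0 + 2 = 2.5$.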

The whole free energy in Equation 3 can be optimized using gradient descent and unbiased Monte Carlo estimates of the gradients obtained with the reparametrization trick kingma2013auto. For simplicity we use a Gaussian mean-field approximation of the posterior; hence, for server and client parameters, the factorization reads respectively $s_i(\theta) = \prod_{d=1}^{D_s} \mathcal{N}\big(\theta_d; \mu^{s}_{i,d}, (\sigma^{s}_{i,d})^2\big)$ and $c_i(\phi_i) = \prod_{d=1}^{D_c} \mathcal{N}\big(\phi_{i,d}; \mu^{c}_{i,d}, (\sigma^{c}_{i,d})^2\big)$, where $D_s$ and $D_c$ are the total number of parameters of the server and of the client networks. A depiction of the full graphical model of the approximate variational posterior is given in Figure 1(b). The pseudo-code of VIRTUAL is given in Algorithm 1. Notice that, similarly to FedAvg, privacy is preserved, since at any time the server only has access to the overall posterior distribution $q(\theta)$ and never to the individual factors $s_i(\theta)$ and $c_i(\phi_i)$, which are visible only to the respective client.
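As a minimal illustration of the reparametrization trick for a single mean-field Gaussian weight, the sketch below writes $w = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, so that the gradients of $\mathbb{E}[f(w)]$ with respect to $\mu$ and $\sigma$ become expectations of ordinary derivatives. The scalar setting and the function name are assumptions for illustration; in practice an automatic differentiation framework (TensorFlow, in this work) computes the derivatives.

```python
import random


def reparam_grads(f_prime, mu, sigma, n_samples=20, rng=None):
    """Unbiased Monte Carlo estimate of the gradient of E_{w ~ N(mu, sigma^2)}[f(w)]
    with respect to (mu, sigma), using w = mu + sigma * eps with eps ~ N(0, 1).

    By the chain rule, d/dmu = E[f'(w)] and d/dsigma = E[f'(w) * eps].
    f_prime is the derivative f'(w) of the integrand.
    """
    rng = rng or random.Random()
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        w = mu + sigma * eps          # reparametrized sample
        g_mu += f_prime(w)            # dw/dmu = 1
        g_sigma += f_prime(w) * eps   # dw/dsigma = eps
    return g_mu / n_samples, g_sigma / n_samples
```

For $f(w) = w^2$ the exact gradients are $(2\mu, 2\sigma)$ because $\mathbb{E}[w^2] = \mu^2 + \sigma^2$, so with enough samples the estimate for $\mu = 1.5$, $\sigma = 0.5$ approaches $(3, 1)$.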

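Input: datasets $\mathcal{D}_1, \dots, \mathcal{D}_K$, total number of refinements $T$, priors $p(\theta)$ and $p(\phi_1), \dots, p(\phi_K)$
initialize all the pdfs $s_i^{(0)}(\theta)$ and $c_i^{(0)}(\phi_i)$
for $t = 1, \dots, T$ do
    choose a client $i$ to be refined
    client $i$ computes the new server prior $p(\theta) \prod_{j \neq i} s_j^{(t-1)}(\theta)$
    $s_i^{(t)}, c_i^{(t)} \leftarrow$ joint optimization of the variational free energy in Equation 3 on client $i$
    client $i$ computes the new server posterior $q^{(t)}(\theta) = s_i^{(t)}(\theta) \prod_{j \neq i} s_j^{(t-1)}(\theta)$
    client $i$ sends $q^{(t)}(\theta)$ to the server
    server sends its new posterior $q^{(t)}(\theta)$ to all clients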
Algorithm 1 VIRTUAL
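Under the mean-field Gaussian assumption, the client-side bookkeeping in Algorithm 1 (computing the new server prior and the new server posterior) reduces to multiplying and dividing Gaussian factors, i.e. adding and subtracting natural parameters. The sketch below assumes this natural-parameter representation and uses hypothetical names. Following Section 2.2, the new server prior is taken as the current posterior with the client's own factor replaced by the prior, $p(\theta) \prod_{j \neq i} s_j^{(t-1)}(\theta)$. The joint optimization step itself is not shown; the new factor $s_i^{(t)}$ is simply taken as given. Client $i$ only needs the broadcast posterior $q^{(t-1)}(\theta)$ and its own private factor $s_i^{(t-1)}(\theta)$.

```python
import math
from dataclasses import dataclass


@dataclass
class DiagGaussian:
    """Mean-field Gaussian factor in natural parameters, one entry per parameter:
    precision = 1 / sigma^2 and prec_mean = mu / sigma^2."""
    precision: list
    prec_mean: list

    @classmethod
    def from_moments(cls, mu, sigma):
        return cls([1.0 / s ** 2 for s in sigma], [m / s ** 2 for m, s in zip(mu, sigma)])

    def __mul__(self, other):
        # Product of Gaussian densities: natural parameters add.
        return DiagGaussian([a + b for a, b in zip(self.precision, other.precision)],
                            [a + b for a, b in zip(self.prec_mean, other.prec_mean)])

    def __truediv__(self, other):
        # Quotient of Gaussian densities: natural parameters subtract.
        return DiagGaussian([a - b for a, b in zip(self.precision, other.precision)],
                            [a - b for a, b in zip(self.prec_mean, other.prec_mean)])

    def moments(self):
        # Only valid for normalizable results (all precisions > 0).
        mu = [pm / p for pm, p in zip(self.prec_mean, self.precision)]
        sigma = [1.0 / math.sqrt(p) for p in self.precision]
        return mu, sigma


def server_prior_for_client(q_prev, s_i_prev, prior):
    """New server prior: p(theta) * prod_{j != i} s_j = p(theta) * q^{(t-1)} / s_i^{(t-1)}."""
    return prior * (q_prev / s_i_prev)


def new_server_posterior(q_prev, s_i_prev, s_i_new):
    """q^{(t)} = s_i^{(t)} * prod_{j != i} s_j^{(t-1)}: the only quantity sent to the server."""
    return s_i_new * (q_prev / s_i_prev)
```

Dividing out $s_i^{(t-1)}$ locally is what keeps the individual factors private: the server only ever receives and broadcasts the product $q(\theta)$.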

We also note an interesting similarity between VIRTUAL and the Progress&Compress method for CL introduced in schwarz2018progress, where a similar free energy is obtained heuristically by combining CL regularization terms and distillation cost functions.


2.3 The server-client architecture

The model for $p(y \mid x, \theta, \phi_i)$ has to be chosen so as to allow the client model to reuse knowledge extracted from the server, enlarging the effective sample size available to client $i$. In this work we use a model inspired by Progressive Networks rusu2016progressive, which have proved effective in CL, with lateral connections between server and client layers. Lateral connections feed the activations of the server at a given layer as an additive input to the client layer. Formally, for a multi-layer perceptron (MLP) model, we denote by $h_s^{(l)}$ the activation of the server layer at depth $l$; the subsequent activation of client $i$ then reads

$$h_i^{(l+1)} = f\Big(W_i^{(l)} h_i^{(l)} + U_i^{(l)} f\big(g_i^{(l)} \odot h_s^{(l)} + \tilde{b}_i^{(l)}\big) + b_i^{(l)}\Big)$$

where $W_i^{(l)}$ and $U_i^{(l)}$ are weight matrices, $g_i^{(l)}$ is a weight vector that acts as a gate, $\odot$ denotes the element-wise multiplication operator, $b_i^{(l)}$ and $\tilde{b}_i^{(l)}$ are trainable biases, and $f$ is the activation function (ReLU in our experiments).








Figure 2: Depiction of a server-client multi-layer perceptron model with two hidden layers and lateral connections. The server parameters on the right plate are shared among all the clients, while the client parameters on the left plate are private. See text for full details.

The client-server connection architecture for a 2 hidden-layers MLP is depicted in Figure 2. Note as an example that for the architecture used in Figure 2, the client parameter vector reads .
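The lateral connection can be sketched as a single forward step in plain Python. The layer computes $f(W h_i + U f(g \odot h_s + \tilde b) + b)$; the placement of the inner nonlinearity, the use of ReLU for both nonlinearities and all names are assumptions made for illustration.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def matvec(M, v):
    # Dense matrix-vector product with M given as a list of rows.
    return [sum(m * a for m, a in zip(row, v)) for row in M]


def vadd(*vs):
    # Element-wise sum of equally sized vectors.
    return [sum(t) for t in zip(*vs)]


def lateral_layer(h_client, h_server, W, U, g, b, b_tilde):
    """Client activation at depth l+1 with a gated lateral connection from the server:
    f(W h_client + U f(g * h_server + b_tilde) + b), with f = ReLU and * element-wise."""
    # Gate the server activation element-wise, then pass it through the adapter nonlinearity.
    adapter = relu(vadd([gi * hi for gi, hi in zip(g, h_server)], b_tilde))
    # The adapted server signal enters the client layer as an additive input.
    return relu(vadd(matvec(W, h_client), matvec(U, adapter), b))
```

With the gate $g$ and $\tilde b$ set to zero the server contribution vanishes and the layer reduces to an ordinary client layer $f(W h_i + b)$, which makes the gate a natural way to let each client decide how much to borrow from the server.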

3 Related Work

We now provide a brief survey of work in the areas of distributed/federated learning and transfer/continual learning, in light of the problem described in Section 1 and of the tools used in deriving VIRTUAL.

Distributed and federated Learning

Distributed learning is a learning paradigm in which the optimization of a generic model is distributed in a parallel computing environment with centralized data mcdonald2010distributed. Early work on this paradigm proposed various learning strategies that require iterative averaging of locally trained models, typically using Stochastic Gradient Descent (SGD) steps in the local optimization routine mcdonald2010distributed; povey2014parallel; zhang2015deep; dean2012large. These works typically assume that distributed learning takes place in a computational cluster, hence with few computing devices, fast and reliable communication between devices, and centralized balanced datasets. FL mcmahan2017communication removes all these assumptions and is framed as a paradigm that encompasses the new challenges and desiderata listed in Section 1. FedAvg mcmahan2017communication; konevcny2016federated has been proposed as a straightforward heuristic for FL. At each step of the algorithm a subset of the online clients is selected, and these clients are then updated locally using SGD. The local models are then averaged to form the model at the next step, which is maintained in the server and transmitted back to all the clients. Despite working well in practice, the performance of FedAvg has been shown to degrade significantly for skewed non-IID data mcmahan2017communication; zhao2018federated.
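The server-side averaging step of FedAvg can be sketched as follows. Following mcmahan2017communication, the weights of each selected client $k$ are weighted by its share $n_k / n$ of the training samples held by the selected clients; the flat-list representation of the weights and the function name are assumptions.

```python
def fedavg_aggregate(client_weights, client_sizes):
    """Server step of FedAvg: average the locally trained weight vectors,
    weighting client k by its share n_k / n of the training samples."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[d] * n for w, n in zip(client_weights, client_sizes)) / total
        for d in range(dim)
    ]
```

For instance, two clients with weights `[1.0, 0.0]` and `[3.0, 2.0]` holding 1 and 3 samples give the averaged model `[2.5, 1.5]`. A single shared model of this kind is exactly what limits FedAvg under strongly non-IID data.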

Several heuristics have recently been proposed to address the statistical challenges of FL. In particular, it has been proposed to share part of the client-generated data zhao2018federated or a server-trained generative model jeong2018communication with the whole network of clients. These solutions are, however, questionable, since they require a significant communication effort and do not comply with the standard privacy requirements of FL. Another solution has been proposed in sahu2018convergence, where the authors extend FedAvg into FedProx, an algorithm that prescribes clients to optimize the local loss function further regularized with a quadratic penalty anchored on the weights of the previous step. Despite showing improvements over FedAvg in highly heterogeneous data settings, the method is strongly inspired by early work on continual and transfer learning (see e.g. Elastic Weight Consolidation (EWC) kirkpatrick2017overcoming, zenke2017continual and the literature review below) and hence can be further refined.
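A minimal sketch of the FedProx client update described above: the local loss $F_k(w)$ is augmented with the proximal term $\frac{\mu}{2}\lVert w - w^t \rVert^2$ anchored on the global weights $w^t$ of the previous round. Plain gradient descent with a fixed learning rate, weights stored as flat lists, and the function name are all assumptions for illustration.

```python
def fedprox_local_update(w_global, grad_local, mu, lr=0.1, steps=100):
    """Client step of FedProx: gradient descent on
    F_k(w) + (mu / 2) * ||w - w_global||^2,
    where grad_local(w) returns the gradient of the local loss F_k at w."""
    w = list(w_global)
    for _ in range(steps):
        g = grad_local(w)
        # Gradient of the proximal term is mu * (w - w_global).
        w = [wi - lr * (gi + mu * (wi - w0)) for wi, gi, w0 in zip(w, g, w_global)]
    return w
```

With the local loss $F_k(w) = \frac{1}{2}(w - 2)^2$ and $w^t = 0$, the minimizer of the regularized objective is $2 / (1 + \mu)$: the update lands at $1$ for $\mu = 1$ and recovers the purely local solution $2$ for $\mu = 0$. The penalty thus pulls every client towards the same anchor, in the spirit of the EWC-like regularizers mentioned above.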

The first work to highlight the possibility of naturally embedding FL in the MTL framework is MOCHA smith2017federated, which extends early work on distributed primal-dual optimization such as CoCoA and its variants smith2018cocoa; jaggi2014communication; ma2015adding. In this work a federated primal-dual optimization algorithm is derived for convex models with MTL regularization, and it is shown for the first time that the MTL framework can enhance model performance, with the MTL model outperforming both global models (trained on centralized data) and local models on real-world federated datasets.

Transfer and continual learning

The transfer of knowledge in neural networks from one task to another has been used extensively and with great success since the pioneering work in hinton2006reducing, which transferred information from a generative to a discriminative model using fine tuning. This straightforward procedure is, however, difficult to apply in scenarios where multiple tasks to transfer from are available. Indeed, good target performance can be obtained only with a priori knowledge of task similarity, which is usually not available, while learning sequential tasks causes knowledge of previous tasks to be abruptly erased from the network, a phenomenon called catastrophic forgetting french1999catastrophic.

Many methods have been introduced to overcome catastrophic forgetting, enabling models to learn multiple tasks sequentially while retaining a good overall performance and transferring effectively to new tasks. Many early works proposed regularization terms for the loss function, anchored to the previous solution, in order to obtain new solutions that generalize well on old tasks kirkpatrick2017overcoming; zenke2017continual. These methods were first introduced as heuristics, but were later found to be applications of well-known inference algorithms such as Laplace Propagation smola2003laplace and Streaming Variational Bayes broderick2013streaming, which led to further generalizations lee2017overcoming; geyer2018transfer. Newer approaches focused on other components, such as architectural innovations, introducing lateral connections that allow new models to reuse knowledge from previously trained models with layer-wise adaptors rusu2016progressive; schwarz2018progress, and memory-enhanced models with generative networks wu2018memory; yoon2019oracle. A recently introduced online Bayesian inference approach nguyen2017variational served as inspiration for our work. It frames the continual learning paradigm in the Bayesian inference framework, establishing a posterior distribution over network parameters that is updated for every new task in light of the new likelihood function. This method has been shown to outperform the previously known methods for CL.

4 Experiments

In this section we present an empirical evaluation of the performance of VIRTUAL on several real world federated datasets.

Dataset   Number of clients  Number of classes  Total samples  Samples per client (mean)  Samples per client (std)
MNIST     10                 10                 60000          6000                       0
P-MNIST   10                 10                 60000          6000                       0
FEMNIST   10                 62                 5560           556                        54
VSN       23                 2                  68532          3115                       559
HAR       30                 6                  15762          543                        56
Table 1: Statistics of the datasets used in the experiments.

4.1 Dataset description

MNIST: The classic MNIST dataset lecun1998gradient, randomly split into 10 different sections. Note that this dataset was not generated in a real federated setting, and its client samples are IID.

Permuted MNIST (P-MNIST): The MNIST dataset is randomly split into 10 sections, and a random permutation of the pixels is applied to every section. This dataset has been used, e.g., in zenke2017continual in the context of CL.
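The P-MNIST construction can be sketched as follows, under stated assumptions: images are flattened lists of pixels, the random split is a round-robin assignment of shuffled indices into equal-size sections (giving 6000 samples per client for the 60000 MNIST training images, as in Table 1), and one pixel permutation is drawn per section. The function name and seeding scheme are hypothetical.

```python
import random


def permuted_split(images, labels, n_sections=10, seed=0):
    """Randomly split (images, labels) into n_sections and apply one fixed random
    pixel permutation per section. Returns a list of (X, y, perm) triples."""
    rng = random.Random(seed)
    idx = list(range(len(images)))
    rng.shuffle(idx)                      # random assignment of samples to sections
    n_pixels = len(images[0])
    sections = []
    for s in range(n_sections):
        perm = list(range(n_pixels))
        rng.shuffle(perm)                 # one permutation shared by the whole section
        members = idx[s::n_sections]      # round-robin: equal-size sections
        X = [[images[j][p] for p in perm] for j in members]
        y = [labels[j] for j in members]
        sections.append((X, y, perm))
    return sections
```

Since all images of a section share the same permutation, each client sees a consistent but client-specific input space, which is what makes this dataset strongly non-IID across clients.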

FEMNIST: This dataset is a federated version of the EMNIST dataset cohen2017emnist, maintained by the LEAF project caldas2018leaf. Different clients correspond to different writers. We sub-sample 10 random writers from those with at least 300 samples.

Vehicle Sensors Network (VSN): A network of 23 different sensors (including seismic, acoustic and passive infra-red sensors) is placed around a road segment in order to classify vehicles driving through it duarte2004vehicle. The raw signal is featurized in the original paper into 50 acoustic and 50 seismic features. We consider every sensor as a client and perform binary classification of assault amphibious vehicles and dragon wagon vehicles.

Human Activity Recognition (HAR): Recordings of 30 subjects performing daily activities are collected using a waist-mounted smartphone with inertial sensors. The raw signal is divided into windows and featurized into a 561-length vector anguita2013public. Every individual corresponds to a different client, and we perform classification of 6 different activities (e.g. sitting, walking).

The statistics of the datasets used are summarized in Table 1.

4.2 Experiment setting

All networks employed are multilayer perceptrons (MLPs) with two hidden dense Flipout layers with 100 units and ReLU activation functions. Dropout with parameter 0.3 is used before every dense layer. The Monte Carlo estimate of the gradient is computed in all the experiments using 20 samples. In all the experiments, VIRTUAL has been evaluated by training the clients incrementally, in a fixed order, with 3 refinements per client.

Using the notation of the original paper mcmahan2017communication, FedAvg has been evaluated in all experiments with a fraction of updated clients per round $C$ and a number of epochs per round $E$ that have been shown to guarantee convergence in all scenarios mcmahan2017communication; caldas2018leaf. The total number of rounds is chosen such that every client is trained, on average, for a number of epochs equal to that of the VIRTUAL experiments.
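The matching of training budgets can be illustrated with a back-of-the-envelope sketch. It assumes that each FedAvg round selects $\max(\lfloor C K \rfloor, 1)$ of the $K$ clients, each training for $E$ epochs, and that VIRTUAL trains every client for 3 refinements of a fixed number of epochs each. The number of epochs per refinement and all argument values in the example are hypothetical, not the settings used in the experiments.

```python
import math


def fedavg_rounds(num_clients, frac_clients, epochs_per_round,
                  virtual_refinements, epochs_per_refinement):
    """Number of FedAvg rounds giving each client, on average, the same number of
    local epochs as VIRTUAL (refinements per client x epochs per refinement)."""
    clients_per_round = max(int(frac_clients * num_clients), 1)
    target_epochs = virtual_refinements * epochs_per_refinement
    # Average epochs per client per round = clients_per_round * epochs_per_round / num_clients.
    return math.ceil(target_epochs * num_clients / (clients_per_round * epochs_per_round))
```

For example, with 10 clients, $C = 0.1$, $E = 1$, and a hypothetical 5 epochs per VIRTUAL refinement, `fedavg_rounds(10, 0.1, 1, 3, 5)` gives 150 rounds: one client per round, so each client is visited 15 times on average.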

We further compare our algorithm with Local and Global baselines, obtained, respectively, by training one separate model per client and by training a single model on centralized data. Global does not comply with the generic federated setting and is reported only for comparison.

The implementation of VIRTUAL is based on the TensorFlow abadi2016tensorflow and TensorFlow Distributions dillon2017tensorflow packages; code and full details on the hyper-parameters used are available online. The implementation of FedAvg is taken from the LEAF benchmark for federated learning settings caldas2018leaf.

4.3 Results

In Table 2 we show the advantages given by the multi-task learning framework in the federated setting. The table reports the categorical accuracy averaged over all tasks of the respective dataset. Every experiment has been repeated over 5 random 25% train-test splits, and we report the mean and standard deviation over the runs.

The performance of all the algorithms strongly depends on the degree of heterogeneity of the dataset considered. In particular, the Global baseline is among the top-performing methods only on the MNIST dataset, which was generated in an IID fashion. Strongly non-IID scenarios are represented by the P-MNIST and VSN datasets, whose feature spaces differ significantly among clients (P-MNIST features are given by random permutations, while VSN encompasses a wide spectrum of different sensors). In these scenarios the performance of both Global and FedAvg degrades, while Local models achieve high performance, being tuned to the specific data distribution.

VIRTUAL maintains top performance over the whole spectrum of federated scenarios, being on par with Global and FedAvg on IID datasets and with Local models in strongly non-IID settings. It also outperforms the other methods on FEMNIST and HAR, the datasets that best represent the multi-task learning setting, as they encompass different users gathering data in very similar but distinct conditions.

Global    0.9678 ± 0.0007   0.44 ± 0.02   0.919 ± 0.004   0.926 ± 0.005   0.797 ± 0.006
Local     0.9511 ± 0.0020   0.51 ± 0.01   0.950 ± 0.003   0.960 ± 0.003   0.940 ± 0.001
FedAvg    0.9675 ± 0.0004   0.45 ± 0.05   0.905 ± 0.002   0.916 ± 0.007   0.944 ± 0.004
VIRTUAL   0.9666 ± 0.0017   0.56 ± 0.01   0.949 ± 0.001   0.960 ± 0.002   0.944 ± 0.001
Table 2: Multi-task average test accuracy. Mean and standard deviation over 5 train-test splits.

5 Conclusion

In this work we introduced VIRTUAL, an algorithm for federated learning that tackles the well-known statistical challenges of the federated learning framework using a multi-task setting. We consider the federation of central server and clients as a Bayesian network and perform training using approximate variational inference. The algorithm naturally complies with the desiderata of the federated setting, giving the central server access only to an aggregated parameter update in the form of an overall posterior distribution over the shared parameters. The algorithm is shown to outperform the state-of-the-art on non-IID real-world federated datasets, and to be on par with it in other scenarios.

One possible direction for further development is to consider synchronous updates of multiple clients, either studying empirically the effect of using outdated priors during client training or developing a new Bayesian model of synchronous updates. Another interesting direction is the exploration of other design choices. Indeed, the general method can be tuned for a particular application by modifying, e.g., the architecture of the lateral connections between devices (Block-Modular NN terekhov2015knowledge, Network in Network lin2013network), the topology of the Bayesian network (star-shaped, hierarchical, etc.), or the choice of the variational inference algorithm. Finally, it is possible to study VIRTUAL under memory constraints, for which an optimal strategy can store chunks of data for further refinements or discard them, in line with coreset theory bachem2015coresets.