Federated Learning with Matched Averaging

02/15/2020 ∙ by Hongyi Wang, et al. ∙ University of Michigan ∙ University of Wisconsin-Madison ∙ IBM

Federated learning allows edge devices to collaboratively learn a shared model while keeping the training data on device, decoupling the ability to do model training from the need to store the data in the cloud. We propose the Federated Matched Averaging (FedMA) algorithm, designed for federated learning of modern neural network architectures, e.g. convolutional neural networks (CNNs) and LSTMs. FedMA constructs the shared global model in a layer-wise manner by matching and averaging hidden elements (i.e. channels for convolution layers; hidden states for LSTMs; neurons for fully connected layers) with similar feature extraction signatures. Our experiments indicate that FedMA not only outperforms popular state-of-the-art federated learning algorithms on deep CNN and LSTM architectures trained on real world datasets, but also reduces the overall communication burden.




1 Introduction

Edge devices such as mobile phones, sensors in a sensor network, or vehicles have access to a wealth of data. However, due to data privacy concerns, network bandwidth limitation, and device availability, it’s impractical to gather all the data from the edge devices at the data center and conduct centralized training. To address these concerns, federated learning is emerging (McMahan et al., 2017; Li et al., 2019; Smith et al., 2017; Caldas et al., 2018; Bonawitz et al., 2019; Kairouz et al., 2019) as a paradigm that allows local clients to collaboratively train a shared global model.

The typical federated learning paradigm involves two stages: (i) clients train models with their local datasets independently, and (ii) the data center gathers the locally trained models and aggregates them to obtain a shared global model. One of the standard aggregation methods is FedAvg (McMahan et al., 2017) where parameters of local models are averaged element-wise with weights proportional to sizes of the client datasets. FedProx (Sahu et al., 2018) adds a proximal term to the client cost functions, thereby limiting the impact of local updates by keeping them close to the global model. Agnostic Federated Learning (AFL) (Mohri et al., 2019), as another variant of FedAvg, optimizes a centralized distribution that is a mixture of the client distributions.
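To make the element-wise aggregation step concrete, here is a minimal sketch of FedAvg-style averaging over flattened parameter lists. The function name `fedavg` and the toy numbers are illustrative assumptions, not taken from any reference implementation.

```python
def fedavg(client_weights, client_sizes):
    """Element-wise average of flat weight lists, weighted by dataset size."""
    total = sum(client_sizes)
    n_params = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(n_params)
    ]


# Two hypothetical clients holding 100 and 300 data points respectively.
global_weights = fedavg([[1.0, 2.0], [3.0, 6.0]], client_sizes=[100, 300])
```

Each coordinate is averaged with the coordinate at the same position in every other client model, which is exactly the assumption that the permutation argument in the following paragraphs calls into question.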

One shortcoming of FedAvg is that coordinate-wise averaging of weights may have drastic detrimental effects on the performance of the averaged model and add significantly to the communication burden. This issue arises due to the permutation invariance of neural network (NN) parameters, i.e. for any given NN, there are many variants of it that only differ in the ordering of parameters. Probabilistic Federated Neural Matching (PFNM) (Yurochkin et al., 2019a) addresses this problem by matching the neurons of client NNs before averaging them. PFNM further utilizes Bayesian nonparametric methods to adapt to global model size and to heterogeneity in the data. As a result, PFNM has better performance and communication efficiency than FedAvg. Unfortunately, the method only works with simple architectures (e.g. fully connected feedforward networks).

Our contribution

In this work, we demonstrate how PFNM can be applied to CNNs and LSTMs, but we find that it gives only very minor improvements over weight averaging. To address this issue, we propose Federated Matched Averaging (FedMA), a new layer-wise federated learning algorithm for modern CNNs and LSTMs that appeals to Bayesian nonparametric methods to adapt to heterogeneity in the data. We show empirically that FedMA not only reduces the communication burden, but also outperforms state-of-the-art federated learning algorithms.

(a) Homogeneous
(b) Heterogeneous
Figure 1: Comparison among various federated learning methods with limited number of communications on LeNet trained on MNIST; VGG-9 trained on CIFAR-10 dataset; LSTM trained on Shakespeare dataset over: (a) homogeneous data partition (b) heterogeneous data partition.

2 Federated Matched Averaging of neural networks

In this section we will discuss permutation invariance classes of prominent neural network architectures and establish the appropriate notion of averaging in the parameter space of NNs. We will begin with the simplest case of a single hidden layer fully connected network, moving on to deep architectures and, finally, convolutional and recurrent architectures.

Permutation invariance of fully connected architectures

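A basic fully connected (FC) NN can be formulated as $\hat{y} = \sigma(x W_1) W_2$ (without loss of generality, biases are omitted to simplify notation), where $\sigma$ is the non-linearity (applied entry-wise). Expanding the preceding expression gives $\hat{y} = \sum_{i=1}^{L} W_{2,i\cdot}\, \sigma(\langle x, W_{1,\cdot i} \rangle)$, where $i\cdot$ and $\cdot i$ denote the $i$th row and column correspondingly and $L$ is the number of hidden units. Summation is a permutation invariant operation, hence for any $\{W_1, W_2\}$ there are $L!$ practically equivalent parametrizations of this basic NN. It is then more appropriate to write

$$\hat{y} = \sigma(x W_1 \Pi)\, \Pi^T W_2, \quad \text{where } \Pi \text{ is any } L \times L \text{ permutation matrix.} \tag{1}$$

Recall that a permutation matrix is an orthogonal matrix that acts on rows when applied on the left and on columns when applied on the right. Suppose $\{W_1, W_2\}$ are optimal weights; then the weights obtained from training on two homogeneous datasets $X_j$, $X_{j'}$ are $\{W_1 \Pi_j, \Pi_j^T W_2\}$ and $\{W_1 \Pi_{j'}, \Pi_{j'}^T W_2\}$. It is now easy to see why naive averaging in the parameter space is not appropriate: with high probability $\Pi_j \neq \Pi_{j'}$ and $(W_1 \Pi_j + W_1 \Pi_{j'})/2 \neq W_1 \Pi$ for any $\Pi$. To meaningfully average neural networks in the weight space we should first undo the permutation: $(W_1 \Pi_j \Pi_j^T + W_1 \Pi_{j'} \Pi_{j'}^T)/2 = W_1$.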
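The argument above can be checked numerically. The following sketch, with made-up weights and a ReLU non-linearity chosen purely for illustration, shows that reordering the hidden units leaves the output of a single hidden layer network unchanged, while naively averaging two such reorderings does not.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def vecmat(x, W):
    """Row vector x (length m) times matrix W (m x n, given as a list of rows)."""
    return [sum(x[i] * W[i][k] for i in range(len(x))) for k in range(len(W[0]))]


def forward(x, W1, W2):
    """Single hidden layer network without biases."""
    return vecmat(relu(vecmat(x, W1)), W2)


def permute_hidden(W1, W2, perm):
    """Reorder hidden units: columns of W1 and rows of W2 by the same perm."""
    W1p = [[row[p] for p in perm] for row in W1]
    W2p = [W2[p] for p in perm]
    return W1p, W2p


W1 = [[1.0, -2.0, 0.5], [0.0, 1.0, 2.0]]  # 2 inputs -> 3 hidden units
W2 = [[1.0], [2.0], [3.0]]                # 3 hidden units -> 1 output
x = [1.0, 1.0]

W1p, W2p = permute_hidden(W1, W2, perm=[2, 0, 1])
y = forward(x, W1, W2)
y_perm = forward(x, W1p, W2p)  # same function, different parameter ordering

# Coordinate-wise average of the two equivalent parametrizations.
W1_avg = [[(a + b) / 2 for a, b in zip(r, rp)] for r, rp in zip(W1, W1p)]
W2_avg = [[(a + b) / 2 for a, b in zip(r, rp)] for r, rp in zip(W2, W2p)]
y_avg = forward(x, W1_avg, W2_avg)  # no longer the same function
```

Here `y` and `y_perm` agree, whereas `y_avg` differs even though both averaged models compute exactly the same function.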

2.1 Matched averaging formulation

In this section we formulate a practical notion of parameter averaging under permutation invariance. Let $w_{jl}$ be the $l$th neuron learned on dataset $j$ (i.e. the $l$th column of $W_1 \Pi_j$ in the previous example), let $\theta_i$ denote the $i$th neuron in the global model, and let $c(\cdot, \cdot)$ be an appropriate similarity function between a pair of neurons. The solution to the following optimization problem gives the required permutations:

$$\min_{\{\pi^j_{li}\}} \sum_{i=1}^{L} \sum_{j,l} \min_{\theta_i} \pi^j_{li}\, c(w_{jl}, \theta_i) \quad \text{s.t.} \quad \sum_i \pi^j_{li} = 1 \ \forall\, j, l; \quad \sum_l \pi^j_{li} = 1 \ \forall\, i, j. \tag{2}$$

Then $\Pi^T_{j,li} = \pi^j_{li}$ and, given weights $\{W_{j,1}, W_{j,2}\}_{j=1}^{J}$ provided by $J$ clients, we compute the federated neural network weights $W_1 = \frac{1}{J} \sum_j W_{j,1} \Pi_j^T$ and $W_2 = \frac{1}{J} \sum_j \Pi_j W_{j,2}$. We refer to this approach as matched averaging due to the relation of equation 2 to the maximum bipartite matching problem. We note that if $c(\cdot, \cdot)$ is the squared Euclidean distance, we recover an objective function similar to k-means clustering; however, it has additional constraints on the “cluster assignments” $\pi^j_{li}$ necessary to ensure that they form permutation matrices. In a special case where all local neural networks and the global model are assumed to have the same number of hidden neurons, solving equation 2 is equivalent to finding a Wasserstein barycenter (Agueh and Carlier, 2011) of the empirical distributions over the weights of local neural networks. Concurrent work of Singh and Jaggi (2019) explores the Wasserstein barycenter variant of equation 2.
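As a rough illustration of matched averaging with a squared Euclidean cost, the sketch below matches the neurons of one client to those of another and averages the matched pairs. It is a deliberate simplification: it uses one client as the reference instead of a jointly estimated global model, assumes equal model sizes, and replaces the Hungarian algorithm with a brute-force search over permutations (in practice a routine such as SciPy's `linear_sum_assignment` could take its place). All names and numbers are hypothetical.

```python
from itertools import permutations


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def best_permutation(reference, neurons):
    """Return perm such that neurons[perm[i]] is matched to reference[i].

    Brute force over all L! assignments; a stand-in for the Hungarian
    algorithm, which solves the same assignment problem in polynomial time.
    """
    L = len(reference)
    return min(
        permutations(range(L)),
        key=lambda p: sum(sq_dist(reference[i], neurons[p[i]]) for i in range(L)),
    )


def matched_average(client_a, client_b):
    """Average each neuron of client_a with its matched neuron of client_b."""
    perm = best_permutation(client_a, client_b)
    return [
        [(a + b) / 2 for a, b in zip(client_a[i], client_b[perm[i]])]
        for i in range(len(client_a))
    ]


# Each inner list is one neuron (a column of the first-layer weight matrix).
client_a = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
client_b = [[5.2, 4.8], [1.2, 0.0], [0.0, 0.8]]  # similar features, shuffled
perm = best_permutation(client_a, client_b)
theta = matched_average(client_a, client_b)
```

With these toy values the matching recovers the shuffle, so each global neuron in `theta` is the mean of two neurons that extract a similar feature.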

Solving matched averaging

The objective function in equation 2 can be optimized using an iterative procedure: applying the Hungarian matching algorithm (Kuhn, 1955) to find the permutation $\{\pi^{j'}_{li}\}_{l,i}$ corresponding to dataset $j'$, holding the other permutations $\{\pi^{j}_{li}\}_{l,i,j \neq j'}$ fixed, and iterating over the datasets. An important aspect of federated learning that we should consider here is data heterogeneity. Every client will learn a collection of feature extractors, i.e. neural network weights, representing their individual data modality. As a consequence, feature extractors learned across clients may overlap only partially. To account for this we allow the size of the global model $L$ to be an unknown variable satisfying $\max_j L_j \leq L \leq \sum_j L_j$, where $L_j$ is the number of neurons learned from dataset $j$. That is, the global model is at least as big as the largest of the local models and at most as big as the concatenation of all the local models. Next we show that matched averaging with adaptive global model size remains amenable to the iterative Hungarian algorithm with a special cost.

At each iteration, given current estimates of $\{\pi^{j}_{li}\}_{l,i,j \neq j'}$, we find a corresponding global model $\{\theta_i = \arg\min_{\theta_i} \sum_{j \neq j', l} \pi^j_{li}\, c(w_{jl}, \theta_i)\}_{i=1}^{L}$ (this is typically a closed-form expression or a simple optimization sub-problem, e.g. a mean if $c(\cdot, \cdot)$ is Euclidean) and then use the Hungarian algorithm to match this global model to the neurons $\{w_{j'l}\}_{l=1}^{L_{j'}}$ of dataset $j'$ to obtain a new global model with $L \leq L' \leq L + L_{j'}$ neurons. Due to data heterogeneity, a local model may have neurons not present in the global model built from the other local models. We therefore want to avoid “poor” matches: if the optimal match has cost larger than some threshold value $\epsilon$, instead of matching we create a new global neuron from the corresponding local one. We also want a modest size global model and therefore penalize its size with some increasing function $f(L')$. This intuition is formalized in the following extended maximum bipartite matching formulation:

$$\min_{\{\pi^{j'}_{li}\}_{l,i}} \sum_{i=1}^{L + L_{j'}} \sum_{l=1}^{L_{j'}} \pi^{j'}_{li}\, C^{j'}_{li} \quad \text{s.t.} \quad \sum_i \pi^{j'}_{li} = 1 \ \forall\, l; \quad \sum_l \pi^{j'}_{li} \in \{0, 1\} \ \forall\, i, \quad \text{where } C^{j'}_{li} = \begin{cases} c(w_{j'l}, \theta_i), & i \leq L \\ \epsilon + f(i), & L < i \leq L + L_{j'}. \end{cases} \tag{3}$$

The size of the new global model is then $L' = \max\{i : \pi^{j'}_{li} = 1,\ l = 1, \ldots, L_{j'}\}$. We note some technical details: after the optimization is done, each corresponding $\Pi_j^T$ is of size $L_j \times L$ and is not a permutation matrix in the classical sense when $L_j \neq L$. Its functionality is however similar: taking the matrix product with a weight matrix, $W_{j,1} \Pi_j^T$, implies permuting the weights to align with the weights learned on the other datasets and padding with “dummy” neurons having zero weights (alternatively we can pad the weights $W_{j,1}$ first and complete $\Pi_j^T$ with missing rows to recover a proper permutation matrix). These “dummy” neurons should also be discounted when taking the average. Without loss of generality, in the subsequent presentation we will ignore these technicalities to simplify the notation.
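One step of the extended matching can be sketched as follows. The cost matrix has one column per existing global neuron plus one candidate "new neuron" column per local neuron, priced at the threshold plus a size penalty. The threshold `eps`, the penalty `f`, and the toy weights are hypothetical values chosen for illustration, and the brute-force search again stands in for the Hungarian algorithm.

```python
from itertools import permutations


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def extended_match(theta, local, eps, f):
    """Assign each local neuron to an existing or a newly created global neuron.

    Returns (assignment, new_size), where assignment[l] is the 0-based global
    index chosen for local neuron l. Columns 0..L-1 are existing global
    neurons; columns L..L+Lj-1 are candidate new neurons.
    """
    L, Lj = len(theta), len(local)
    cost = [
        # f is evaluated at the 1-based global index i + 1.
        [sq_dist(w, theta[i]) if i < L else eps + f(i + 1) for i in range(L + Lj)]
        for w in local
    ]
    # Every local neuron gets a distinct column; each column is used at most once.
    best = min(
        permutations(range(L + Lj), Lj),
        key=lambda cols: sum(cost[l][cols[l]] for l in range(Lj)),
    )
    new_size = max(L, max(best) + 1)
    return list(best), new_size


theta = [[1.0, 0.0], [0.0, 1.0]]   # current global neurons
local = [[1.1, 0.0], [4.0, 4.0]]   # second local neuron has no good match
assignment, new_size = extended_match(theta, local, eps=1.0, f=lambda i: 0.1 * i)
```

In this toy case the first local neuron is matched to global neuron 0, while the second is too far from every global neuron and opens a new one, so the global model grows from 2 to 3 neurons.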

To complete the matched averaging optimization procedure it remains to specify the similarity $c(\cdot, \cdot)$, threshold $\epsilon$ and model size penalty $f(\cdot)$. Yurochkin et al. (2019b, a, c) studied fusion, i.e. aggregation, of model parameters in a range of applications. The most relevant to our setting is Probabilistic Federated Neural Matching (PFNM) (Yurochkin et al., 2019a). They arrived at a special case of equation 3 to compute the maximum a posteriori (MAP) estimate of their Bayesian nonparametric model based on the Beta-Bernoulli process (BBP) (Thibaux and Jordan, 2007), where the similarity $c(w_{jl}, \theta_i)$ is the corresponding posterior probability of the $j$th client's neuron $l$ being generated from a Gaussian with mean $\theta_i$, and $\epsilon$ and $f(\cdot)$ are guided by the Indian Buffet Process prior (Ghahramani and Griffiths, 2005). Instead of making heuristic choices, this formulation provides a model-based specification of equation 3. We refer to a procedure for solving equation 2 with the setup from Yurochkin et al. (2019a) as BBP-MAP. We note that their PFNM is only applicable to fully connected architectures, limiting its practicality. Our matched averaging perspective allows us to formulate averaging of widely used architectures such as CNNs and LSTMs as instances of equation 2 and utilize BBP-MAP as a solver.

2.2 Permutation invariance of key architectures

Before moving onto the convolutional and recurrent architectures, we discuss permutation invariance in deep fully connected networks and corresponding matched averaging approach. We will utilize this as a building block for handling LSTMs and CNN architectures such as VGG (Simonyan and Zisserman, 2014) widely used in practice.

Permutation invariance of deep FCs

We extend equation 1 to recursively define a deep FC network:

$$x_n = \sigma(x_{n-1} \Pi_{n-1}^T W_n \Pi_n), \tag{4}$$

where $n = 1, \ldots, N$ is the layer index, $\Pi_0$ is identity, indicating non-ambiguity in the ordering of the input features $x = x_0$, and $\Pi_N$ is identity for the same in the output classes. Conventionally $\sigma(\cdot)$ is any non-linearity except for $\hat{y} = x_N$, where it is the identity function (or softmax if we want probabilities instead of logits). When $N = 2$, we recover the single hidden layer variant from equation 1. To perform matched averaging of deep FCs obtained from $J$ clients we need to find permutations for every layer of every client. Unfortunately, permutations within any consecutive pair of intermediate layers are coupled, leading to an NP-hard combinatorial optimization problem. Instead we consider a recursive (in layers) matched averaging formulation. Suppose we have $\{\Pi_{j,n-1}\}$; then plugging $\{\Pi_{j,n-1} W_{j,n}\}$ into equation 2 we find $\{\Pi_{j,n}\}$ and move on to the next layer. The recursion base for this procedure is $\{\Pi_{j,0}\}$, which we know is an identity permutation for any $j$.

Permutation invariance of CNNs

The key observation in understanding permutation invariance of CNNs is that instead of neurons, channels define the invariance. To be more concrete, let $\mathrm{Conv}(x, W)$ define the convolutional operation on input $x$ with weights $W \in \mathbb{R}^{C^{in} \times w \times h \times C^{out}}$, where $C^{in}$, $C^{out}$ are the numbers of input/output channels and $w$, $h$ are the width and height of the filters. Applying any permutation to the output dimension of the weights and then the same permutation to the input channel dimension of the subsequent layer will not change the corresponding CNN’s forward pass. Analogous to equation 4 we can write:

$$x_n = \sigma(\mathrm{Conv}(x_{n-1}, \Pi_{n-1}^T W_n \Pi_n)). \tag{5}$$

Note that this formulation permits pooling operations as those act within channels. To apply matched averaging for the $n$th CNN layer we form inputs to equation 2 as $\{w_{jl} \in \mathbb{R}^{D}\}_{l=1}^{C^{out}_n}$, $j = 1, \ldots, J$, where $D$ is the flattened $C^{in}_n \times w \times h$ dimension of $\Pi_{j,n-1} W_{j,n}$. This result can be alternatively derived taking the im2col perspective. Similar to FCs, we can recursively perform matched averaging on deep CNNs. The immediate consequence of our result is the extension of PFNM (Yurochkin et al., 2019a) to CNNs. Empirically (see Figure 1), we found that this extension performs well on MNIST with a simpler CNN architecture such as LeNet (LeCun et al., 1998) (4 layers) and significantly outperforms coordinate-wise weight averaging (1 round FedAvg). However, it breaks down for more complex architectures, e.g. VGG-9 (Simonyan and Zisserman, 2014) (9 layers), needed to obtain good quality prediction on the more challenging CIFAR-10.
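The way convolutional weights are turned into inputs for the matching problem can be sketched with nested lists. The sketch assumes weights indexed as `W[c_out][c_in][row][col]`, which mirrors the layout PyTorch uses for `Conv2d` weights and differs in axis order from the notation above; the helper names are hypothetical.

```python
def channels_as_neurons(W):
    """Flatten W[c_out][c_in][row][col] into C_out vectors of length C_in*h*w."""
    return [[v for in_ch in out_ch for row in in_ch for v in row] for out_ch in W]


def permute_out_channels(W, perm):
    """Reorder the output channels (filters) of a layer."""
    return [W[p] for p in perm]


def permute_in_channels(W, perm):
    """Apply the same reordering to the input channels of the next layer."""
    return [[out_ch[p] for p in perm] for out_ch in W]


# Layer n: 2 output channels, 1 input channel, 2x2 filters.
W_n = [[[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]]]
# Layer n+1: 1 output channel, 2 input channels, 1x1 filters.
W_next = [[[[10]], [[20]]]]

neurons = channels_as_neurons(W_n)         # one flattened vector per filter
perm = [1, 0]                              # e.g. the result of a matching step
W_n_aligned = permute_out_channels(W_n, perm)
W_next_aligned = permute_in_channels(W_next, perm)
```

Each flattened filter plays the role of a neuron $w_{jl}$, and any reordering of the filters in one layer has to be mirrored on the input channels of the next layer to keep the forward pass unchanged.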

Permutation invariance of LSTMs

Permutation invariance in recurrent architectures is associated with the ordering of the hidden states. At first glance it appears similar to the fully connected architecture; however, the important difference is associated with the permutation invariance of the hidden-to-hidden weights $H \in \mathbb{R}^{L \times L}$, where $L$ is the number of hidden states. In particular, permutation of the hidden states affects both rows and columns of $H$. Consider a basic RNN $h_t = \sigma(h_{t-1} H + x_t W)$, where $W$ are the input-to-hidden weights. To account for the permutation invariance of the hidden states, we notice that the dimensions of $h_t$ should be permuted in the same way for any $t$, hence

$$h_t = \sigma(h_{t-1} \Pi^T H \Pi + x_t W \Pi). \tag{6}$$

To match RNNs, the basic sub-problem is to align the hidden-to-hidden weights of two clients with Euclidean similarity, which requires minimizing $\|\Pi^T H_j \Pi - H_{j'}\|_2^2$ over permutations $\Pi$. This is a quadratic assignment problem known to be NP-hard (Loiola et al., 2007). Fortunately, the same permutation appears in the already familiar context of input-to-hidden matching of $W \Pi$. Our matched averaging RNN solution is to utilize equation 2, plugging in the input-to-hidden weights $\{W_j\}$ to find $\{\Pi_j\}$. Then the federated hidden-to-hidden weights are computed as $H = \frac{1}{J} \sum_j \Pi_j H_j \Pi_j^T$ and the input-to-hidden weights are computed as before. We note that the Gromov-Wasserstein distance (Gromov, 2007) from the optimal transport literature corresponds to a similar quadratic assignment problem. It may be possible to incorporate hidden-to-hidden weights into the matching algorithm by exploring connections to approximate algorithms for computing the Gromov-Wasserstein barycenter (Peyré et al., 2016). We leave this possibility for future work.
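The claim that a reordering of hidden states must act on both rows and columns of the hidden-to-hidden weights can be verified on a toy RNN. The weights, inputs and the choice of tanh below are illustrative assumptions.

```python
import math


def step(h, x, H, W):
    """One step of a basic RNN: h_t = tanh(h_{t-1} H + x_t W)."""
    L = len(h)
    return [
        math.tanh(
            sum(h[a] * H[a][b] for a in range(L))
            + sum(x[i] * W[i][b] for i in range(len(x)))
        )
        for b in range(L)
    ]


def permute_rnn(H, W, perm):
    """Reorder hidden states: rows and columns of H, columns of W."""
    Hp = [[H[a][b] for b in perm] for a in perm]
    Wp = [[row[b] for b in perm] for row in W]
    return Hp, Wp


H = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1], [0.2, 0.1, 0.4]]  # hidden-to-hidden
W = [[1.0, 0.5, -1.0], [0.3, -0.7, 0.2]]                    # input-to-hidden
perm = [2, 0, 1]
Hp, Wp = permute_rnn(H, W, perm)

h, hp = [0.0] * 3, [0.0] * 3
for x in [[1.0, 0.0], [0.5, -1.0]]:
    h = step(h, x, H, W)
    hp = step(hp, x, Hp, Wp)

# The permuted RNN produces the same hidden states, only reordered.
max_gap = max(abs(hp[k] - h[perm[k]]) for k in range(3))
```

Because the same permutation appears on the columns of the input-to-hidden weights, it can be estimated from those weights alone and then applied to the hidden-to-hidden weights, which is the shortcut described above.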

To finalize matched averaging of LSTMs, we discuss several specifics of the architecture. LSTMs have multiple cell states, each having its individual hidden-to-hidden and input-to-hidden weights. In our matched averaging we stack the input-to-hidden weights $\{W_s \in \mathbb{R}^{D \times L}\}_{s=1}^{S}$ into an $SD \times L$ weight matrix ($S$ is the number of cell states; $D$ is the input dimension and $L$ is the number of hidden states) when computing the permutation matrices and then average all weights as described previously. LSTMs also often have an embedding layer, which we handle like a fully connected layer. Finally, we process deep LSTMs in a recursive manner similar to deep FCs.

2.3 Federated Matched Averaging (FedMA) algorithm

Defining the permutation invariance classes of CNNs and LSTMs allows us to extend PFNM (Yurochkin et al., 2019a) to these architectures; however, our empirical study in Figure 1 demonstrates that such an extension fails on the deep architectures necessary to solve more complex tasks. Our results suggest that recursive handling of layers with matched averaging may entail a poor overall solution. To alleviate this problem and utilize the strength of matched averaging on “shallow” architectures, we propose the following layer-wise matching scheme. First, the data center gathers only the weights of the first layers from the clients and performs the one-layer matching described previously to obtain the first layer weights of the federated model. The data center then broadcasts these weights to the clients, which proceed to train all consecutive layers on their datasets, keeping the matched federated layers frozen. This procedure is then repeated up to the last layer, for which we conduct a weighted averaging based on the class proportions of data points per client. We summarize our Federated Matched Averaging (FedMA) in Algorithm 1. The FedMA approach requires a number of communication rounds equal to the number of layers in a network. In Figure 1 we show that with layer-wise matching FedMA performs well on the deeper VGG-9 CNN as well as LSTMs. In the more challenging heterogeneous setting, FedMA outperforms FedAvg and FedProx trained with the same number of communication rounds (4 for LeNet and LSTM and 9 for VGG-9) and other baselines, i.e. client individual CNNs and their ensemble.

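Input : local weights of $N$-layer architectures $\{W_{j,1}, \ldots, W_{j,N}\}_{j=1}^{J}$ from $J$ clients
Output : global weights $\{W_1, \ldots, W_N\}$
$n = 1$;
while $n \leq N$ do
       if $n < N$ then
             $\{\Pi_j\}_{j=1}^{J}$ = BBP-MAP($\{W_{j,n}\}_{j=1}^{J}$) ;  // call BBP-MAP to solve Eq. 2
             $W_n = \frac{1}{J} \sum_j W_{j,n} \Pi_j^T$ ;
       else
             $W_n$ = weighted average of $\{W_{j,n}\}_{j=1}^{J}$ with weights $p_{jk}$, where $p_{jk}$ is fraction of data points with label $k$ on worker $j$ ;
       end if
       for $j \in \{1, \ldots, J\}$ do
             $W_{j,n+1} \leftarrow \Pi_j W_{j,n+1}$ ;  // permutate the next-layer weights
             Train $\{W_{j,n+1}, \ldots, W_{j,N}\}$ with $W_n$ frozen;
       end for
       $n = n + 1$;
end while
Algorithm 1 Federated Matched Averaging (FedMA)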
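The layer-wise control flow of Algorithm 1 can be sketched for small fully connected networks as below. This is a heavily simplified illustration under several assumptions that are not part of FedMA itself: brute-force Euclidean matching against client 0 replaces BBP-MAP, all models have equal size (no global model growth), the last layer is averaged uniformly rather than by class proportions, and local training is a caller-supplied hook (`retrain`, hypothetical) that does nothing by default.

```python
from itertools import permutations


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def columns(W):
    """Neurons of a layer are the columns of its weight matrix."""
    return [list(col) for col in zip(*W)]


def match_to_reference(W_ref, W):
    """perm such that column perm[i] of W is matched to column i of W_ref."""
    ref, cols = columns(W_ref), columns(W)
    L = len(ref)
    return min(
        permutations(range(L)),
        key=lambda p: sum(sq_dist(ref[i], cols[p[i]]) for i in range(L)),
    )


def average(mats):
    J = len(mats)
    return [
        [sum(m[r][c] for m in mats) / J for c in range(len(mats[0][0]))]
        for r in range(len(mats[0]))
    ]


def fedma_pass(clients, retrain=lambda j, layers, n_frozen: layers):
    """clients: one list of layer matrices per client; returns global layers."""
    clients = [[[row[:] for row in W] for W in layers] for layers in clients]
    N = len(clients[0])
    global_layers = []
    for n in range(N):
        if n < N - 1:
            for layers in clients:
                perm = match_to_reference(clients[0][n], layers[n])
                # Align this layer's columns and the next layer's rows.
                layers[n] = [[row[p] for p in perm] for row in layers[n]]
                layers[n + 1] = [layers[n + 1][p] for p in perm]
        W_n = average([layers[n] for layers in clients])
        global_layers.append(W_n)
        for j, layers in enumerate(clients):
            layers[n] = [row[:] for row in W_n]     # broadcast the matched layer
            clients[j] = retrain(j, layers, n + 1)  # train layers after the frozen ones
    return global_layers


W1 = [[1.0, -2.0, 0.5], [0.0, 1.0, 2.0]]
W2 = [[1.0], [2.0], [3.0]]
# Client 1 holds the same network with its hidden units reordered.
W1p = [[0.5, 1.0, -2.0], [2.0, 0.0, 1.0]]
W2p = [[3.0], [1.0], [2.0]]
global_model = fedma_pass([[W1, W2], [W1p, W2p]])
```

When the two clients differ only by a reordering of hidden units, the pass recovers the original weights, whereas coordinate-wise averaging of the same inputs would not.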

FedMA with communication

We’ve shown that in the heterogeneous data scenario FedMA outperforms other federated learning approaches; however, it still lags in performance behind the entire data training. Of course the entire data training is not possible under the federated learning constraints, but it serves as a performance upper bound we should strive to achieve. To further improve the performance of our method, we propose FedMA with communication, where local clients receive the matched global model at the beginning of a new round and reconstruct their local models with the size equal to the original local models (e.g. size of a VGG-9) based on the matching results of the previous round. This procedure allows us to keep the size of the global model small, in contrast to a naive strategy of utilizing the full matched global model as a starting point across clients on every round.

3 Experiments

We present an empirical study of FedMA with communication and compare it with state-of-the-art methods i.e. FedAvg (McMahan et al., 2017) and FedProx (Sahu et al., 2018); analyze the performance under the growing number of clients and visualize the matching behavior of FedMA to study its interpretability. Our experimental studies are conducted over three real world datasets. Summary information about the datasets and associated models can be found in supplement Table 3.

Experimental Setup

We implemented FedMA and the considered baseline methods in PyTorch (Paszke et al., 2017). We deploy our empirical study under a simulated federated learning environment where we treat one centralized node in the distributed cluster as the data center and the other nodes as local clients. All nodes in our experiments are deployed on p3.2xlarge instances on Amazon EC2. We assume the data center samples all the clients to join the training process for every communication round for simplicity.

For the CIFAR-10 dataset, we use data augmentation (random crops and flips) and normalize each individual image (details provided in the Supplement). We note that we ignore all batch normalization (Ioffe and Szegedy, 2015) layers in the VGG architecture and leave them for future work.

For CIFAR-10, we considered two data partition strategies to simulate the federated learning scenario: (i) a homogeneous partition where each local client has an approximately equal proportion of each of the classes; (ii) a heterogeneous partition for which the number of data points and the class proportions are unbalanced. We simulated a heterogeneous partition into $J$ clients by sampling $p_k$ and allocating a $p_{k,j}$ proportion of the training instances of class $k$ to local client $j$. We use the original test set in CIFAR-10 as our global test set for comparing performance of all methods. For the Shakespeare dataset, we treat each speaking role as a client (Caldas et al., 2018), resulting in a natural heterogeneous partition. We preprocess the Shakespeare dataset by filtering out the clients with less than k datapoints and sampling a random subset of clients. We allocate 80% of the data for training and amalgamate the remaining data into a global test set.
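A heterogeneous partition of this kind can be simulated as in the sketch below, which splits the indices of each class across clients according to randomly drawn proportions. The sketch assumes the proportions come from a symmetric Dirichlet distribution; both that choice and the concentration value `alpha` are assumptions of this example, as are the function names.

```python
import random


def dirichlet(alpha, J, rng):
    """Sample from a symmetric Dirichlet via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(J)]
    total = sum(draws)
    return [d / total for d in draws]


def heterogeneous_partition(labels, J, alpha=0.5, seed=0):
    """Return J lists of example indices with unbalanced class proportions."""
    rng = random.Random(seed)
    clients = [[] for _ in range(J)]
    for k in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == k]
        rng.shuffle(idx)
        p = dirichlet(alpha, J, rng)  # proportion of class k given to each client
        start, cum = 0, 0.0
        for j in range(J):
            cum += p[j]
            # The last client takes whatever remains, so no index is dropped.
            end = len(idx) if j == J - 1 else round(cum * len(idx))
            clients[j].extend(idx[start:end])
            start = end
    return clients


labels = [i % 10 for i in range(200)]  # toy labels for 10 classes
parts = heterogeneous_partition(labels, J=4)
```

Every example is assigned to exactly one client, while both the client dataset sizes and the class mix per client vary from client to client.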

Communication Efficiency and Convergence Rate

In this experiment we study the performance of FedMA with communication. Our goal is to compare our method to FedAvg and FedProx in terms of the total message size exchanged between the data center and clients (in Gigabytes) and the number of communication rounds (recall that completing one FedMA pass requires a number of rounds equal to the number of layers in the local models) needed for the global model to achieve good performance on the test data. We also compare to the performance of an ensemble method. We evaluate all methods under the heterogeneous federated learning scenario on CIFAR-10 with $J$ clients with VGG-9 local models and on the Shakespeare dataset with $J$ clients with a 1-layer LSTM network. We fix the total rounds of communication allowed for FedMA, FedAvg, and FedProx, i.e. 11 rounds for FedMA and 99/33 rounds for FedAvg and FedProx for the VGG-9/LSTM experiments respectively. The number of local training epochs is a common parameter shared by the three considered methods; we thus tune this parameter (denoted $E$; a comprehensive analysis will be presented in the next experiment) and report the convergence rate under the $E$ that yields the best final model accuracy over the global test set. FedProx has another hyper-parameter, the coefficient $\mu$ associated with the proximal term; we tune it using grid search and report the best value we found (0.001) for both the VGG-9 and LSTM experiments. FedMA outperforms FedAvg and FedProx in all scenarios (Figure 2), with its advantage especially pronounced when we evaluate convergence as a function of the message size in Figures 2(a) and 2(c). Final performance of all trained models is summarized in Tables 1 and 2.

(a) LSTM, Shakespeare; message size
(b) LSTM, Shakespeare; rounds
(c) VGG-9, CIFAR-10; message size
(d) VGG-9, CIFAR-10; rounds
Figure 2: Convergence rates of various methods in two federated learning scenarios: training VGG-9 on CIFAR-10 with clients and training LSTM on Shakespeare dataset with clients.
Figure 3: The effect of number of local training epochs on various methods.

Effect of local training epochs

As studied in previous work (McMahan et al., 2017; Caldas et al., 2018; Sahu et al., 2018), the number of local training epochs can affect the performance of FedAvg and sometimes lead to divergence. We conduct an experimental study on the effect of $E$ on FedAvg, FedProx, and FedMA with VGG-9 trained on CIFAR-10 under the heterogeneous setup. The candidate local epochs we consider are . For each candidate $E$, we run FedMA for 6 rounds and FedAvg and FedProx for 54 rounds, and report the final accuracy that each method achieves. The result is shown in Figure 3. We observe that training longer benefits FedMA, supporting our assumption that FedMA performs best on local models with higher quality. For FedAvg, longer local training leads to deterioration of the final accuracy, which matches the observations made in the previous literature (McMahan et al., 2017; Caldas et al., 2018; Sahu et al., 2018). FedProx only partially alleviates this problem. The result of this experiment suggests that FedMA is the only one of the three methods that lets local clients train their models for as long as they want.

Method FedAvg FedProx Ensemble FedMA
Final Accuracy(%) 87.53
Best local epoch ($E$) 20 20 N/A 150
Model growth rate
Table 1: Trained models summary for VGG-9 trained on CIFAR-10 as shown in Figure 2
Method FedAvg FedProx Ensemble FedMA
Final Accuracy(%) 49.07
Best local epoch ($E$) 2 5 N/A 5
Model growth rate
Table 2: Trained models summary for LSTM trained on Shakespeare as shown in Figure 2

Handling data bias

Real world data often exhibit multimodality within each class, e.g. geo-diversity. It has been shown that an observable amerocentric and eurocentric bias is present in the widely used ImageNet dataset (Shankar et al., 2017; Russakovsky et al., 2015). Classifiers trained on such data “learn” these biases and perform poorly on the under-represented domains (modalities), since correlation between the corresponding dominating domain and class can prevent the classifier from learning meaningful relations between features and classes. For example, a classifier trained on amerocentric and eurocentric data may learn to associate a white dress with the “bride” class, therefore underperforming on wedding images taken in countries where wedding traditions are different (Doshi, 2018).

Figure 4: Performance on skewed CIFAR-10 dataset.

The data bias scenario is an important aspect of federated learning, however it received little to no attention in the prior federated learning works. In this study we argue that FedMA can handle this type of problem. If we view each domain, e.g. geographic region, as one client, local models will not be affected by the aggregate data biases and learn meaningful relations between features and classes. FedMA can then be used to learn a good global model without biases. We have already demonstrated strong performance of FedMA on federated learning problems with heterogeneous data across clients and this scenario is very similar. To verify this conjecture we conduct the following experiment.

We simulate the skewed domain problem with the CIFAR-10 dataset by randomly selecting 5 classes and making 95% of the training images in those classes grayscale. For the remaining 5 classes we turn only 5% of the corresponding images into grayscale. By doing so, we create 5 grayscale-dominated classes and 5 color-dominated classes. In the test set, half of the images in each class are grayscale and half are colored. We anticipate entire data training to pick up the uninformative correlations between grayscale and certain classes, leading to poor performance on test data without these correlations. In Figure 4 we see that entire data training performs poorly in comparison to the regular (i.e. No Bias) training and testing on the CIFAR-10 dataset without any grayscaling. This experiment was motivated by Olga Russakovsky’s talk at ICML 2019.
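The construction of the skewed training set amounts to choosing, per class, which images to convert. A minimal sketch of that selection step follows; the function name, the seed, and the toy label list are hypothetical, and the actual grayscale conversion of the selected images is left out.

```python
import random


def choose_grayscale_indices(labels, gray_classes, hi=0.95, lo=0.05, seed=0):
    """Pick indices to convert: 95% of gray-dominated classes, 5% of the rest."""
    rng = random.Random(seed)
    chosen = set()
    for k in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == k]
        frac = hi if k in gray_classes else lo
        chosen.update(rng.sample(idx, round(frac * len(idx))))
    return chosen


# Toy stand-in for a 10-class training set with 100 images per class.
labels = [k for k in range(10) for _ in range(100)]
gray_classes = set(random.Random(1).sample(range(10), 5))  # 5 random classes
gray_idx = choose_grayscale_indices(labels, gray_classes)

gray_per_class = [
    sum(1 for i in gray_idx if labels[i] == k) for k in range(10)
]
```

With 100 images per class, each grayscale-dominated class ends up with 95 grayscale images and each color-dominated class with 5.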

Next we compare the federated learning based approaches. We split the images from color dominated classes and grayscale dominated classes into 2 clients. We then conduct FedMA with communication, FedAvg, and FedProx with these 2 clients. FedMA noticeably outperforms the entire data training and the other federated learning approaches, as shown in Figure 4. This result suggests that FedMA may be of interest not only for learning under the federated learning constraints, where entire data training is the performance upper bound, but also for eliminating data biases and outperforming entire data training.

We consider two additional approaches to eliminate data bias without the federated learning constraints. One way to alleviate data bias is to selectively collect more data to debias the dataset. In the context of our experiment, this means getting more colored images for grayscale dominated classes and more grayscale images for color dominated classes. We simulate this scenario by simply doing a full data training where each class in both train and test images has equal amount of grayscale and color images. This procedure, Color Balanced, performs well, but selective collection of new data in practice may be expensive or even not possible. Instead of collecting new data, one may consider oversampling from the available data to debias. In Oversampling, we sample the underrepresented domain (via sampling with replacement) to make the proportion of color and grayscale images to be equal for each class (oversampled images are also passed through the data augmentation pipeline, e.g. random flipping and cropping, to further enforce the data diversity). Such procedure may be prone to overfitting the oversampled images and we see that this approach only provides marginal improvement of the model accuracy compared to centralized training over the skewed dataset and performs noticeably worse than FedMA.
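The Oversampling baseline for a single class can be sketched as follows: the underrepresented domain is resampled with replacement until both domains have the same number of images. The function name and index lists are hypothetical, the sketch assumes both domains are non-empty, and the data augmentation applied to the oversampled images is omitted.

```python
import random


def oversample_to_balance(color_idx, gray_idx, seed=0):
    """Pad the smaller domain by sampling with replacement until sizes match."""
    rng = random.Random(seed)
    target = max(len(color_idx), len(gray_idx))

    def pad(idx):
        return list(idx) + rng.choices(idx, k=target - len(idx))

    return pad(color_idx), pad(gray_idx)


# One grayscale-dominated class: 5 color images and 95 grayscale images.
color_idx = list(range(5))
gray_idx = list(range(5, 100))
color_bal, gray_bal = oversample_to_balance(color_idx, gray_idx)
```

The balanced color list contains many repeats of the same 5 images, which illustrates why this procedure may be prone to overfitting the oversampled images.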

(a) Raw input 1
(b) FedMA, filter 0
(c) Client 1, filter 0
(d) Client 2, filter 23
(e) FedAvg, filter 0
(f) Raw input 2
(g) FedMA, filter 0
(h) Client 1, filter 0
(i) Client 2, filter 23
(j) FedAvg, filter 0
Figure 5: Representations generated by the first convolution layers of locally trained models, FedMA global model and the FedAvg global model.
Figure 6: Data efficiency under the increasing number of clients.

Data efficiency

It is known that deep learning models perform better when more training data is available. However, under the federated learning constraints, data efficiency has not been studied to the best of our knowledge. The challenge here is that when new clients join the federated system, they each bring their own version of the data distribution, which, if not handled properly, may deteriorate the performance despite the growing data size across the clients. To simulate this scenario we first partition the entire training CIFAR-10 dataset into 5 homogeneous pieces. We then partition each homogeneous data piece further into 5 sub-pieces heterogeneously. Using this strategy, we partition the CIFAR-10 training set into 25 heterogeneous small sub-datasets containing approximately 2k points each. We conduct a 5-step experimental study: starting from a randomly selected homogeneous piece consisting of 5 associated heterogeneous sub-pieces, we simulate a 5-client federated learning heterogeneous problem. For each consecutive step, we add one of the remaining homogeneous data pieces consisting of 5 new clients with heterogeneous sub-datasets. Results are presented in Figure 6. Performance of FedMA (with a single pass) improves when new clients are added to the federated learning system, while FedAvg with 9 communication rounds deteriorates.


One of the strengths of FedMA is that it utilizes communication rounds more efficiently than FedAvg. Instead of directly averaging weights element-wise, FedMA identifies matching groups of convolutional filters and then averages them into the global convolutional filters. It’s natural to ask “What do the matched filters look like?”. In Figure 5 we visualize the representations generated by a pair of matched local filters, the aggregated global filter, and the filter returned by the FedAvg method over the same input image. Matched filters and the global filter found with FedMA are extracting the same feature of the input image, i.e. filter 0 of client 1 and filter 23 of client 2 are extracting the position of the legs of the horse, and the corresponding matched global filter 0 does the same. For FedAvg, global filter 0 is the average of filter 0 of client 1 and filter 0 of client 2, which clearly tampers with the leg extraction functionality of filter 0 of client 1.

4 Conclusion

We presented Federated Matched Averaging (FedMA), a layer-wise federated learning algorithm designed for modern CNN and LSTM architectures that accounts for permutation invariance of the neurons and permits global model size adaptation. Our method significantly outperforms prior federated learning algorithms in terms of its convergence when measured by the size of messages exchanged between the server and the clients during training. We demonstrated that FedMA can efficiently utilize well-trained local models, a property desired in many federated learning applications, but lacking in the prior approaches. We have also presented an example where FedMA can help to resolve some of the data biases and outperform aggregate data training.

In future work we want to extend FedMA to improve federated learning of LSTMs using approximate quadratic assignment solutions from the optimal transport literature, and to support additional deep learning building blocks, e.g. residual connections and batch normalization layers. We also believe it is important to explore the fault tolerance of FedMA and to study its performance on larger datasets, particularly ones with biases that prevent efficient training even when the data can be aggregated, e.g. Inclusive Images (Doshi, 2018).


  • M. Agueh and G. Carlier (2011) Barycenters in the wasserstein space. SIAM Journal on Mathematical Analysis 43 (2), pp. 904–924. Cited by: §2.1.
  • K. Bonawitz, H. Eichner, W. Grieskamp, D. Huba, A. Ingerman, V. Ivanov, C. Kiddon, J. Konecny, S. Mazzocchi, H. B. McMahan, et al. (2019) Towards federated learning at scale: system design. arXiv preprint arXiv:1902.01046. Cited by: §1.
  • S. Caldas, P. Wu, T. Li, J. Konečnỳ, H. B. McMahan, V. Smith, and A. Talwalkar (2018) Leaf: a benchmark for federated settings. arXiv preprint arXiv:1812.01097. Cited by: §1, §3, §3.
  • T. Doshi (2018) Introducing the inclusive images competition. Cited by: §3, §4.
  • Z. Ghahramani and T. L. Griffiths (2005) Infinite latent feature models and the Indian buffet process. In Advances in Neural Information Processing Systems, pp. 475–482. Cited by: §2.1.
  • M. Gromov (2007) Metric structures for riemannian and non-riemannian spaces. Springer Science & Business Media. Cited by: §2.2.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167. Cited by: §3.
  • P. Kairouz, H. B. McMahan, B. Avent, A. Bellet, M. Bennis, A. N. Bhagoji, K. Bonawitz, Z. Charles, G. Cormode, R. Cummings, et al. (2019) Advances and open problems in federated learning. arXiv preprint arXiv:1912.04977. Cited by: §1.
  • H. W. Kuhn (1955) The Hungarian method for the assignment problem. Naval Research Logistics (NRL) 2 (1-2), pp. 83–97. Cited by: §2.1.
  • Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, et al. (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §2.2.
  • T. Li, A. K. Sahu, A. Talwalkar, and V. Smith (2019) Federated learning: challenges, methods, and future directions. arXiv preprint arXiv:1908.07873. Cited by: §1.
  • E. M. Loiola, N. M. M. de Abreu, P. O. Boaventura-Netto, P. Hahn, and T. Querido (2007) A survey for the quadratic assignment problem. European journal of operational research 176 (2), pp. 657–690. Cited by: §2.2.
  • B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pp. 1273–1282. Cited by: Table 3, §1, §1, §3, §3.
  • M. Mohri, G. Sivek, and A. T. Suresh (2019) Agnostic federated learning. In International Conference on Machine Learning, pp. 4615–4625. Cited by: §1.
  • A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer (2017) Automatic differentiation in pytorch. In NIPS-W. Cited by: §3.
  • G. Peyré, M. Cuturi, and J. Solomon (2016) Gromov-wasserstein averaging of kernel and distance matrices. In International Conference on Machine Learning, pp. 2664–2672. Cited by: §2.2.
  • S. J. Reddi, S. Kale, and S. Kumar (2018) On the convergence of adam and beyond. In International Conference on Learning Representations. Cited by: Appendix F.
  • O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, et al. (2015) Imagenet large scale visual recognition challenge. International journal of computer vision 115 (3), pp. 211–252. Cited by: §3.
  • O. Russakovsky (2019) Strategies for mitigating social bias in deep learning systems. Invited talk at the Identifying and Understanding Deep Learning Phenomena workshop, ICML 2019. Cited by: §3.
  • A. K. Sahu, T. Li, M. Sanjabi, M. Zaheer, A. Talwalkar, and V. Smith (2018) On the convergence of federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127. Cited by: §1, §3, §3.
  • S. Shankar, Y. Halpern, E. Breck, J. Atwood, J. Wilson, and D. Sculley (2017) No classification without representation: assessing geodiversity issues in open data sets for the developing world. arXiv preprint arXiv:1711.08536. Cited by: §3.
  • K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. Cited by: §2.2, §2.2.
  • S. P. Singh and M. Jaggi (2019) Model fusion via optimal transport. arXiv preprint arXiv:1910.05653. Cited by: §2.1.
  • V. Smith, C. Chiang, M. Sanjabi, and A. S. Talwalkar (2017) Federated multi-task learning. In Advances in Neural Information Processing Systems, pp. 4424–4434. Cited by: §1.
  • R. Thibaux and M. I. Jordan (2007) Hierarchical Beta processes and the Indian buffet process. In Artificial Intelligence and Statistics, pp. 564–571. Cited by: §2.1.
  • M. Yurochkin, M. Agarwal, S. Ghosh, K. Greenewald, N. Hoang, and Y. Khazaeni (2019a) Bayesian nonparametric federated learning of neural networks. In International Conference on Machine Learning, pp. 7252–7261. Cited by: §D.2, §1, §2.1, §2.2, §2.3.
  • M. Yurochkin, M. Agarwal, S. Ghosh, K. Greenewald, and N. Hoang (2019b) Statistical model aggregation via parameter matching. In Advances in Neural Information Processing Systems, pp. 10954–10964. Cited by: §2.1.
  • M. Yurochkin, Z. Fan, A. Guha, P. Koutris, and X. Nguyen (2019c) Scalable inference of topic evolution via models for latent geometric structures. In Advances in Neural Information Processing Systems, pp. 5949–5959. Cited by: §2.1.

Appendix A Summary of the datasets used in the experiments

The details of the datasets and hyper-parameters used in our experiments are summarized in Table 3. In conducting the “freezing and retraining” process of FedMA, we notice that when retraining the last FC layer while keeping all previous layers frozen, the initial learning rate we use for SGD does not lead to good convergence (this is only the case for the VGG-9 architecture). To fix this issue, we use a reduced initial learning rate for the last FC layer retraining and allow the clients to retrain for 3 times more epochs. We also switch off the weight decay during the “freezing and retraining” process of FedMA, except for the last FC layer, where weight decay remains enabled. For the language task, we observe that SGD with a constant learning rate works well for all considered methods.

In our experiments, we use the FedAvg and FedProx variants without shared initialization, since these are likely more realistic when trying to aggregate locally pre-trained models. FedMA still performs well in such practical scenarios, where local clients are not able to share the random initialization.

Method MNIST CIFAR-10 Shakespeare (McMahan et al., 2017)
# Data points
Model LeNet VGG-9 LSTM
# Classes
# Parameters k k k
Optimizer SGD SGD
Hyper-params. Init lr: , (last layer) lr: (const)
momentum: 0.9, weight decay:
Table 3: The datasets used and their associated learning models and hyper-parameters.

Appendix B Details of Model Architectures and Hyper-parameters

The details of the model architectures we used in the experiments are summarized in this section. Specifically, details of the VGG-9 model architecture we used can be found in Table 4 and details of the 1-layer LSTM model used in our experimental study can be found in Table 5.

Parameter Shape Layer hyper-parameter
layer1.conv1.weight stride:;padding:
layer1.conv1.bias 32 N/A
layer2.conv2.weight stride:;padding:
layer2.conv2.bias 64 N/A
pooling.max N/A kernel size:;stride:
layer3.conv3.weight stride:;padding:
layer3.conv3.bias N/A
layer4.conv4.weight stride:;padding:
layer4.conv4.bias N/A
pooling.max N/A kernel size:;stride:
dropout N/A %
layer5.conv5.weight stride:;padding:
layer5.conv5.bias N/A
layer6.conv6.weight stride:;padding:
layer6.conv6.bias N/A
pooling.max N/A kernel size:;stride:
dropout N/A %
layer7.fc7.weight N/A
layer7.fc7.bias N/A
layer8.fc8.weight N/A
layer8.fc8.bias N/A
dropout N/A %
layer9.fc9.weight N/A
layer9.fc9.bias N/A
Table 4: Detailed information of the VGG-9 architecture used in our experiments. All non-linear activation functions in this architecture are ReLU; the shapes for convolution layers follows

Parameter Shape
Table 5: Detailed information of the LSTM architecture in our experiment

Appendix C Data augmentation and normalization details

In preprocessing the images in the CIFAR-10 dataset, we follow the standard data augmentation and normalization process. For data augmentation, random cropping and random horizontal flipping are used. Each color channel is normalized with its mean and standard deviation: each pixel has the mean value of its color channel subtracted and is then divided by the standard deviation of that color channel.
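The per-channel normalization described above can be written compactly as follows. The symbols are introduced here for illustration only: x_{c,i,j} denotes the pixel at row i and column j of color channel c, and μ_c and σ_c denote the mean and standard deviation of that channel.

```latex
\hat{x}_{c,i,j} = \frac{x_{c,i,j} - \mu_c}{\sigma_c},
\qquad c \in \{\mathrm{R}, \mathrm{G}, \mathrm{B}\}
```

After this transformation, each channel has approximately zero mean and unit variance over the data from which μ_c and σ_c were computed.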

Appendix D Extra Experimental Details

D.1 Shapes of Final Global Model

Here we report the shapes of final global VGG and LSTM models returned by FedMA with communication.

Parameter Shape Growth rate (#global / #original params)
Total Number of Parameters ()
Table 6: Detailed information of the LSTM architecture in our experiment
Parameter Shape Growth rate (#global / #original params)
Total Number of Parameters ()
Table 7: Detailed information of the final global VGG-9 model returned by FRB; the shapes for convolution layers follows

D.2 Hyper-parameters for BBP-MAP

We follow PFNM (Yurochkin et al., 2019a) to choose the hyper-parameters of BBP-MAP, which control the matching quantities discussed in Section 2. More specifically, there are three hyper-parameters to choose: 1) the prior variance of the weights of the global neural network; 2) a parameter that controls the discovery of new hidden states, where increasing it leads to a larger final global model; and 3) the variance of the local neural network weights around the corresponding global network weights. We empirically analyze different choices of the three hyper-parameters and select, separately for VGG-9 on the CIFAR-10 dataset and for LSTM on the Shakespeare dataset, settings that lead to good performance in our experimental studies.

Appendix E Practical Considerations

Following the discussion in PFNM, here we briefly discuss the time complexity of FedMA. For simplicity, we focus on single-layer matching and assume all participating clients train the same model architecture. The complexity of matching the entire model follows trivially from this discussion. The worst-case complexity is reached when no hidden states are matched; it consists of the cost of building the cost matrix plus the cost of running the Hungarian algorithm, with notation following the discussion in Section 2. The best-case complexity per layer is reached when all hidden states are matched. Practically, when the number of participating clients is large and each client trains a big model, the speed of our algorithm can be relatively slow.

Speeding up the Hungarian algorithm. Although there is no algorithm that achieves lower complexity, a better implementation improves the constant significantly. In our experiments, we used an implementation based on shortest path augmentation, i.e. lapsolver (https://github.com/cheind/py-lapsolver). Empirically, we observed that this implementation of the Hungarian algorithm leads to orders-of-magnitude speedups over the vanilla implementation.

Appendix F Hyper-parameters for the Handling Data Bias Experiments

In conducting the “handling data bias” experiments, we re-tune the number of local epochs for both FedAvg and FedProx over several candidate values. We observe that a relatively large number of local epochs can easily lead to poor convergence of FedAvg. While FedProx tolerates larger choices better, a smaller number of local epochs always leads to good convergence. We therefore use the same setting for both FedAvg and FedProx in our experiments. For FedMA, we choose the number of local epochs that leads to good convergence. For the “oversampling” baseline, we found that using SGD to train VGG-9 over the oversampled dataset does not lead to good convergence. Moreover, when using a constant learning rate, SGD can lead to model divergence. Thus we use the AMSGrad (Reddi et al., 2018) method for the “oversampling” baseline. To make the comparison fair, we use AMSGrad for all other centralized baselines to obtain the reported results in our experiments, and most of them converge within the training budget. We also test the performance of the “Entire Data Training”, “Color Balanced”, and “No Bias” baselines with SGD, using a fixed learning rate and weight decay for those three baselines. The “Entire Data Training” and “No Bias” baselines appear to converge to slightly better accuracy with SGD than with AMSGrad, but the “Color Balanced” baseline does not appear to converge to better accuracy with SGD.