Architecture Agnostic Federated Learning for Neural Networks

by Disha Makhija et al.

With growing concerns regarding data privacy and a rapid increase in data volume, Federated Learning (FL) has become an important learning paradigm. However, jointly learning a deep neural network model in an FL setting proves to be a non-trivial task because of the complexities associated with neural networks, such as varied architectures across clients, permutation invariance of the neurons, and the presence of non-linear transformations in each layer. This work introduces a novel Federated Heterogeneous Neural Networks (FedHeNN) framework that allows each client to build a personalised model without enforcing a common architecture across clients. This allows each client to optimize with respect to its local data and compute constraints, while still benefiting from the learning of other (potentially more powerful) clients. The key idea of FedHeNN is to use the instance-level representations obtained from peer clients to guide the simultaneous training on each client. Extensive experimental results demonstrate that the FedHeNN framework is capable of learning better-performing models on clients in both the homogeneous and the heterogeneous architecture settings.



I Introduction

Distributed machine learning has long been an important field of study, and it is becoming increasingly important paka03. Federated Learning is a type of distributed machine learning setting that consists of many end devices or silo organisations (clients), which have access to data stored locally, and a global server, which orchestrates the learning without accessing the clients' data. With the rapidly growing amount of data and the concerns around data privacy, federated learning has emerged as a very promising direction: it allows a global model to be learned from the data present at each client without that data leaving the client devices, which helps preserve data privacy and also reduces the cost of centralised training and storage. FedAvg pmlr-v54-mcmahan17a is the de facto federated learning algorithm, in which each client performs SGD steps to train its local model using its own data and compute resources. The client models are then periodically sent to the server, which aggregates them into a global model that is sent back to the clients. However, the solution obtained from FedAvg has been shown to diverge in the presence of statistical heterogeneity across clients fedprox. Over the years, several modifications of the original algorithm have been proposed, addressing aspects such as data heterogeneity, the availability of clients for training, the aggregation mechanism, communication costs, and personalised client models fedprox, ditto, scaffold, semicyclic, mocha, fedma, fedrep. In all of these methods, the global model parameters are obtained by appropriately aggregating the local model parameters. Yet in practical settings where clients are heterogeneous and differ substantially in their compute resources and data distributions, these methods may face significant challenges. Several real-world FL scenarios require training over end devices with very different hardware.
In such cases, with the above methods, clients that are incapable of training large models can never take part in the training process. Conversely, if the common model architecture is kept small to accommodate all clients, some clients will be under-utilised. These negative effects may become more prominent in the cross-silo FL setting, where the total number of clients is even smaller (typically fewer than 100). The illustrative experiment in Figure 1 shows the drop in accuracy of global methods such as FedAvg and FedProx when a few clients are left out of training.

Figure 1: When a fraction of clients cannot afford to build full-blown models, FedAvg and FedProx must drop those clients altogether, resulting in poorer average performance than FedHeNN, which can accommodate clients with heterogeneous architectures.

In this work, we propose a systematic framework for architecture-agnostic federated learning called FedHeNN. The framework overcomes the aforementioned challenges of client heterogeneity by allowing each client to build a personalised model without any constraint on the model architecture, including the number and size of the hidden layers, the activation functions, etc. Learning is transferred across clients by grounding the representations learnt at each client through a proximal term. Specifically, the optimisation objective at each client comprises two terms: the task loss, and a proximal term that pulls the final learned representations of the models together. This also encourages the networks to learn representations that are robust across a wider range of data, which leads to better performance. We test our framework on a suite of federated learning datasets in both the homogeneous and the heterogeneous model-architecture settings.

Our contributions are summarized as follows:

  1. Our primary contribution is FedHeNN, a new framework for training deep neural networks in federated learning settings. We identify the shortcomings of jointly learning neural networks and circumvent them by grounding the representations learnt at each client through a proximal term. We show empirically that a proximal term on the representations delivers superior performance.

  2. To allow learning to transfer across architectures whose representations lie in different output spaces, we propose using a kernel-based distance metric called Centered Kernel Alignment (CKA). Using a kernel-based metric gives FedHeNN the flexibility to compare representations of different dimensions.

  3. Additionally, the structure of the FedHeNN framework allows it to be extended to the setting where different clients have different model architectures. We thus propose a solution that is architecture agnostic and whose performance is better than or comparable to that of existing methods that operate in homogeneous-architecture settings.

The rest of the paper is organised as follows. Section II provides background on Federated Learning and related developments. In Section III, we go over the preliminaries and then propose our framework, FedHeNN. We provide a thorough experimental evaluation of FedHeNN on different types of FL datasets in Section IV and conclude in Section V.

II Related Work

Distributed learning algorithms were extensively studied in the data mining community in the early 2000s under topics such as distributed (and privacy-preserving) data mining lipi00, megh03, megh05, paka03, ghtu00. This topic was re-framed as Federated Learning in an influential paper pmlr-v54-mcmahan17a that introduced the FedAvg algorithm, and it has rapidly gained new adherents since then. The framework proposed by FedAvg pmlr-v54-mcmahan17a has been the standard solution for Federated Learning since 2017. However, when the data across clients is non-IID, averaging the local optima of multiple clients to obtain the global solution may cause the optimisation to diverge. To address this, FedProx fedprox modifies the local training objective by adding a proximal term that penalises the distance between the current global model weights and the local model weights, thus preventing each local update from moving far away. Similarly, SCAFFOLD scaffold introduces control variates to correct the local updates. FedPD fedpd, FedSplit fedsplit, and FedDyn feddyn are other important works that study the problem of finding better fixed points. FedDF fedDF shows that simply averaging local models learnt on local distributions to obtain a global model may not be ideal, and FedBE fedbe also provides evidence that the best-performing aggregate model need not be the average of the local models. Another important aspect of Federated Learning is thus to aggregate the local models appropriately. PFNM pfnm and FedMA fedma show that neurons in each layer are permutation invariant; these works perform a layer-by-layer matching of neurons and then aggregate the matched neurons. singh2020model use an Optimal-Transport-based distance to match neurons before aggregation. Unlike the above, FedDF fedDF uses additional data samples to distill knowledge from clients’ local models.
FedNova fednova allows each client to perform a variable amount of work and computes a normalized average for the aggregate. Several works focus on creating local personalised models instead of a single global model to cater to the heterogeneity of data distributions across clients. In the literature, personalised FL has been approached in various ways, such as keeping the task-specific component of each client local fedrep, clustering the clients SattlerMS21, GhoshCYR20, meta-learning personalised_meta_learning, waffle, maml, KhodakBT19, multi-task learning fed_mtl, ditto, mocha, and transfer learning local_adaption. HaoELZLCC21, FLviaSyn, luo2021fear augment the data distribution by creating additional data samples and using them for learning. MOON li2021model uses a contrastive loss to bring the representations learned by the local and global models closer together and uses it to correct local training. The need for heterogeneous client models was recently highlighted in FedProto tan2021fedproto. Our work differs from FedProto in that our formulation works directly on representations of data instances, as opposed to the class-specific prototypes that FedProto uses for learning and inference. We believe that having only one prototype per class could be limiting for multi-modal classes.

III FedHeNN Methodology

In an effort to enhance federated learning with heterogeneous clients, we propose a new framework, Federated Heterogeneous Neural Networks (FedHeNN). FedHeNN provides a systematic way to jointly learn deep neural networks that vary in architecture and output space across distributed clients. In this section, we first cover the preliminaries and then present the proposed method.

Preliminaries We first go over the key concepts of federated learning methods. Consider a set of clients $\{1, 2, \dots, N\}$. The traditional federated learning algorithm learns a single global model $W_g$ by using the model parameters learned on the individual clients:

$$W_g = \mathcal{A}(W_1, W_2, \dots, W_N),$$

where $W_i$ are the local parameters obtained at the $i$-th client and $\mathcal{A}$ is the aggregation function. For a single-layer model, $W_i = w_i$ is the vector of weights. For an $m$-layer model, $W_i = \{w_i^{(1)}, w_i^{(2)}, \dots, w_i^{(m)}\}$ is the collection containing the weights at each layer. Without loss of generality, the bias terms are assumed to be incorporated in the weights.

iii.1 FedAvg

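The FedAvg algorithm learns each client’s local parameters by solving

$$W_i = \arg\min_{W} \ell_i(W), \qquad \ell_i(W) = \frac{1}{n_i} \sum_{(x, y) \in D_i} \mathcal{L}\big(f_W(x), y\big),$$

where $D_i$ is the local dataset of the $i$-th client, $n_i = |D_i|$, $f_W(x)$ is the model's prediction for input $x$, and $\mathcal{L}$ is some loss function. An element-wise average is then taken over the clients' weight matrices to obtain an aggregated model $W_g$ at the server. For a multi-layer network, the average is taken layer-wise, and the parameters of each client are weighted by the number of data points at that client. For the $l$-th layer, we have

$$w_g^{(l)} = \sum_{i=1}^{N} \frac{n_i}{n}\, w_i^{(l)}, \qquad n = \sum_{i=1}^{N} n_i.$$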

While this aggregation mechanism has been shown to give good empirical results, learning by element-wise averaging may give sub-optimal results for various reasons, such as the permutation invariance of neurons and different data distributions across clients. Besides, FedAvg imposes a hard constraint that all clients train models with exactly the same architecture. In practice, it is highly likely that different clients will not be able to train models of the same capacity. We identify that, apart from solving sub-problems such as the aggregation mechanism or data heterogeneity, we also need to make sure that the networks on the clients are being trained towards the same goal. To do so, we propose to modify the objective function optimized at each client.

We consider a neural network to be composed of two components: the representation learning component, which maps the input to a $p$-dimensional representation vector, and the task learning component, which learns the prediction function. The initial layers of the network are considered the representation learning component, whose output is the representation vector, denoted by $\phi_W(x)$ for a data instance $x$ under the model with parameters $W$. The last layer is considered the task-specific layer, whose output is the prediction for data instance $x$, denoted by $f_W(x)$, as depicted in Figure 2. The details of the method are explained in the next sub-sections.
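As a minimal illustration of the layer-wise weighted average used by FedAvg, the NumPy sketch below (the function name `fedavg_aggregate` is ours, not from any library) averages per-layer weight arrays with coefficients $n_i/n$. It assumes every client uses exactly the same architecture, which is precisely the constraint discussed above: `np.stack` raises an error if layer shapes differ across clients.

```python
import numpy as np

def fedavg_aggregate(client_layers, client_sizes):
    """Layer-wise FedAvg: w_g^(l) = sum_i (n_i / n) * w_i^(l).

    client_layers: one list of per-layer weight arrays per client; shapes
        must match across clients (FedAvg's same-architecture requirement).
    client_sizes:  number of training points n_i held by each client.
    """
    sizes = np.asarray(client_sizes, dtype=float)
    coeffs = sizes / sizes.sum()  # n_i / n
    global_layers = []
    for l in range(len(client_layers[0])):
        # Stack layer l from every client, then take the weighted average.
        stacked = np.stack([layers[l] for layers in client_layers])
        global_layers.append(np.tensordot(coeffs, stacked, axes=1))
    return global_layers

# Two hypothetical clients with a weight matrix and a bias each;
# client 0 holds three times as much data as client 1.
c0 = [np.ones((2, 2)), np.zeros(2)]
c1 = [np.zeros((2, 2)), np.ones(2)]
g = fedavg_aggregate([c0, c1], [300, 100])
print(g[0])  # all entries 0.75
print(g[1])  # all entries 0.25
```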

Figure 2: Depiction of a neural network consisting of two components: the representation learning part and the prediction function.

Additionally, we explore two different settings for the FedHeNN framework - homogeneous and heterogeneous setting, which we define below and elaborate on in the next sub-section.

Definition III.1.

We refer to an FL setting as homogeneous when every client in the network has exactly the same architecture. An FL setting is said to be heterogeneous when each client in the network can define its own architecture.

iii.2 FedHeNN for Homogeneous Clients

For the homogeneous federated learning setting, we propose to modify the loss function on each client to additionally incorporate a proximal term on the representations. In particular, instead of just minimising the loss function $\ell_i(W_i)$, each client minimises the following objective:

$$\min_{W_i}\; \ell_i(W_i) + \eta\, d\big(\phi_{W_i}(X), \phi_{W_g^{t-1}}(X)\big), \tag{1}$$

where $d$ is some distance metric, $\eta \ge 0$ is a hyperparameter controlling the weight of the proximal term, $\phi_{W_i}(X)$ is the representation learned on the $i$-th client for the set of data points $X$, and $\phi_{W_g^{t-1}}(X)$ is the representation learned by the global model $W_g^{t-1}$ communicated at the previous round $t-1$. The global model $W_g^{t-1}$ is learned using the FedAvg algorithm. Because the proximal term enforces similarity between the representations, we do not need to address the aggregation mechanism explicitly. The proximal term brings the representations of the local and global models together and helps correct the learning on each client.
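The following NumPy sketch evaluates the homogeneous objective in equation (1) for a toy two-layer network, treating the ReLU layer as $\phi_W$ and the softmax layer as the task head. The mean-squared distance is only a placeholder for $d$ (the distance actually used, CKA, is introduced in Section iii.4), and all names and sizes are hypothetical. Gradients and the SGD step are omitted; in practice an autodiff framework would differentiate this objective.

```python
import numpy as np

def representation(params, X):
    """phi_W(X): every layer but the last (here a single ReLU layer)."""
    return np.maximum(X @ params["A"], 0.0)

def cross_entropy(params, X, y):
    """Task loss l_i(W): softmax head (the last layer) on top of phi_W(X)."""
    logits = representation(params, X) @ params["B"]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(log_p[np.arange(len(y)), y])

def mse_distance(R1, R2):
    """Placeholder d(., .); needs equal widths, as in the homogeneous setting."""
    return float(np.mean((R1 - R2) ** 2))

def homogeneous_objective(local, global_prev, X_loc, y_loc, X_rad, eta,
                          dist=mse_distance):
    """Eq. (1): l_i(W_i) + eta * d(phi_{W_i}(X), phi_{W_g^{t-1}}(X))."""
    task = cross_entropy(local, X_loc, y_loc)
    prox = dist(representation(local, X_rad), representation(global_prev, X_rad))
    return task + eta * prox

rng = np.random.default_rng(0)
global_prev = {"A": rng.normal(size=(4, 8)), "B": rng.normal(size=(8, 3))}
local = {k: v + 0.1 * rng.normal(size=v.shape) for k, v in global_prev.items()}
X_loc, y_loc = rng.normal(size=(32, 4)), rng.integers(0, 3, size=32)
X_rad = rng.normal(size=(16, 4))  # shared data points X

obj0 = homogeneous_objective(local, global_prev, X_loc, y_loc, X_rad, eta=0.0)
obj1 = homogeneous_objective(local, global_prev, X_loc, y_loc, X_rad, eta=1.0)
```

With `eta=0.0` the objective is just the local task loss, matching the $\eta = 0$ case discussed later in this section.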

iii.3 FedHeNN for Heterogeneous Clients

In certain settings, it might not be practical for all clients to train neural network models of the same complexity, and it is non-trivial to use existing algorithms to perform joint training in such settings. Our framework allows federated learning across heterogeneous clients by collecting and aggregating the representations learnt on the local clients and pulling those representations closer together. The framework is inspired by the same intuition that there is an underlying lower-dimensional representation of the data that the clients can uncover together. To achieve this, we let each client train its own model but pull the representations learnt by different clients closer by adding a proximal term to each client’s loss function. The proximal term measures the distance between the representations learnt by different local models. Specifically, if $\phi_{W_i}(X)$ is the representation matrix obtained on the $i$-th client for data $X$, then the proximal term measuring the distances between client representations for the $j$-th client can be written as

$$\sum_{i=1}^{N} w_i\, d\big(\phi_{W_j}(X), \phi_{W_i}(X)\big).$$


The contribution $w_i$ of each client could be set to reflect the capacity or strength of the model on that client. At each iteration, the server generates a set of unlabelled data instances, called the Representation Alignment Dataset (RAD) and denoted by $X$, and uses it to align the representations across clients. As before, the clients share their local weights with the server. Instead of aggregating the weights, the server uses each client’s weights to generate representations for the RAD and aggregates the representations over clients. The server then distributes the aggregated representations and the RAD to each client. In the next iteration, each client learns a model that minimises a combination of the task loss and the distance to the aggregated representations. Specifically, at iteration $t$ the $j$-th client minimises

$$\min_{W_j}\; \ell_j(W_j) + \eta \sum_{i=1}^{N} w_i\, d\big(\phi_{W_j}(X), \phi_{W_i^{t-1}}(X)\big). \tag{2}$$
This method lets us utilise the training on the local clients without explicitly generating a global aggregated model. The privacy of the clients' data is also preserved, as the clients do not need to share their data in any form. Thus, the clients can keep using their local personalised models while benefiting from the training of their peers. When we describe the distance function in the next sub-section, we will see that our method is robust to clients' output representations lying in different spaces, i.e., the representations $\phi_{W_i}(X)$, $i = 1, \dots, N$, need not be aligned with each other.
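To make the heterogeneous proximal term concrete, the sketch below evaluates $\sum_i w_i\, d(\phi_{W_j}(X), \phi_{W_i}(X))$ for three hypothetical clients whose representations have widths 4, 16, and 32, using $d = 1 - \text{linear CKA}$ (the distance described in Section iii.4). The uniform weights are an assumption made purely for illustration.

```python
import numpy as np

def linear_cka(R1, R2):
    """Linear CKA between representation matrices of shape (s, p1) and (s, p2)."""
    R1 = R1 - R1.mean(axis=0)  # column-centre each representation
    R2 = R2 - R2.mean(axis=0)
    num = np.linalg.norm(R2.T @ R1) ** 2  # Frobenius norms throughout
    den = np.linalg.norm(R1.T @ R1) * np.linalg.norm(R2.T @ R2)
    return num / den

def hetero_proximal_term(j, reps, weights):
    """sum_i w_i * d(phi_{W_j}(X), phi_{W_i}(X)), with d = 1 - linear CKA."""
    return sum(w * (1.0 - linear_cka(reps[j], R)) for w, R in zip(weights, reps))

rng = np.random.default_rng(1)
X = rng.normal(size=(64, 10))  # a RAD with s = 64 instances
# Hypothetical clients whose representations have widths 4, 16 and 32.
reps = [np.tanh(X @ rng.normal(size=(10, p))) for p in (4, 16, 32)]
weights = np.full(3, 1 / 3)  # uniform weights, for illustration only
prox = hetero_proximal_term(0, reps, weights)
```

No projection between the 4-, 16- and 32-dimensional spaces is needed, because CKA compares how each client arranges the same RAD instances rather than the raw coordinates.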

iii.4 The Proximal Term

Here we explain in detail the proximal term, i.e., the distance function used in our objective formulations in equations (1) and (2). As mentioned earlier, we consider deep neural network training to perform two high-level steps: learning to map inputs of different modalities to a $p$-dimensional representation, and then learning a prediction function. We consider all but the last layer of the network as the representation learning module, and denote its output, the learned representation, by $\phi_W$. To match the representation learning parts of the networks across clients, we compare the outputs of their representation learning modules on the same set of instances. Specifically, we take a set of input examples encoded in a matrix $X$ (the RAD), pass them through all the local networks, and capture the outputs of their representation learning modules in matrices $\phi_{W_i}(X) \in \mathbb{R}^{s \times p_i}$, where $\phi_{W_i}(X)$ is the activation (representation) matrix of the $i$-th client, $s$ is the size of the RAD, and $p_i$ is the output dimension of the representation. Note that our method can work with representations of different dimensions on different clients, i.e., $p_i$ can vary with $i$. At each communication round, we randomly sample only a subset of instances for the alignment dataset, keeping $s$ small relative to the size of the training data. After obtaining the activation matrices $\phi_{W_i}(X)$, we need a way to compare the representations. For this, we propose using representation distance matrices (RDMs). An RDM uses the distances between instances to capture the characteristics of the representational space. Representational distance learning rdl has also been used for knowledge distillation in the past. To measure the distance, we use a metric proposed specifically for neural network representations, called Centered Kernel Alignment (CKA) cka.
CKA was originally proposed to compare the representations obtained from different neural networks, in order to determine equivalence relationships between their hidden layers. CKA takes the activation (representation) matrices as input and outputs a similarity score between 1 (identical representations) and 0 (not similar at all). Useful properties of CKA include invariance to orthogonal transformations and isotropic scaling (but, by design, not to arbitrary invertible linear transformations). CKA can also reveal consistent correspondences between layers of networks trained from different random initializations. Because of these properties, and because CKA can measure similarities between layers of different widths, it extends naturally to serve as a distance metric in our method. Let $\phi_{W_i}(X)$ and $\phi_{W_j}(X)$ be the representation matrices of clients $i$ and $j$ obtained for the RAD. The distance between $\phi_{W_i}(X)$ and $\phi_{W_j}(X)$ is obtained by first computing the kernel matrices $K_i$ and $K_j$ for any choice of kernel $k$ as follows.

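$$K_i = k\big(\phi_{W_i}(X), \phi_{W_i}(X)\big), \qquad K_j = k\big(\phi_{W_j}(X), \phi_{W_j}(X)\big),$$

i.e., $(K_i)_{ab} = k\big(\phi_{W_i}(x_a), \phi_{W_i}(x_b)\big)$ for instances $x_a, x_b$ in the RAD. With $K_i$ and $K_j$ as the representational distance matrices, the distance between $\phi_{W_i}(X)$ and $\phi_{W_j}(X)$ under CKA is given by

$$d\big(\phi_{W_i}(X), \phi_{W_j}(X)\big) = 1 - \mathrm{CKA}(K_i, K_j) = 1 - \frac{\mathrm{HSIC}(K_i, K_j)}{\sqrt{\mathrm{HSIC}(K_i, K_i)\,\mathrm{HSIC}(K_j, K_j)}}, \tag{3}$$

where the estimator of the Hilbert-Schmidt Independence Criterion (HSIC) can be written as

$$\mathrm{HSIC}(K_i, K_j) = \frac{1}{(s-1)^2}\,\mathrm{tr}(K_i H K_j H),$$

with $H$ as the centering matrix

$$H = I_s - \frac{1}{s}\mathbf{1}\mathbf{1}^\top.$$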
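A minimal NumPy sketch of equation (3) follows. It implements the biased HSIC estimator with the centering matrix $H$, supports a linear or an RBF kernel, and checks that the kernel form with a linear kernel agrees with the feature-space form of linear CKA. Setting the RBF bandwidth from the median squared pairwise distance is our assumption (a common heuristic), not a detail given in the paper; the client widths are hypothetical.

```python
import numpy as np

def centering_matrix(s):
    """H = I_s - (1/s) 1 1^T."""
    return np.eye(s) - np.ones((s, s)) / s

def linear_kernel(R):
    return R @ R.T

def rbf_kernel(R, frac=1.0):
    """RBF kernel; sigma^2 = frac^2 * median squared pairwise distance."""
    sq = np.sum(R ** 2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * R @ R.T, 0.0)
    sigma2 = frac ** 2 * np.median(d2)
    return np.exp(-d2 / (2.0 * sigma2))

def hsic(K, L):
    """Biased HSIC estimator: tr(K H L H) / (s - 1)^2."""
    s = K.shape[0]
    H = centering_matrix(s)
    return np.trace(K @ H @ L @ H) / (s - 1) ** 2

def cka_distance(K_i, K_j):
    """Eq. (3): 1 - HSIC(K_i, K_j) / sqrt(HSIC(K_i, K_i) * HSIC(K_j, K_j))."""
    return 1.0 - hsic(K_i, K_j) / np.sqrt(hsic(K_i, K_i) * hsic(K_j, K_j))

def linear_cka_features(R1, R2):
    """Feature-space form of linear CKA (inputs are column-centred here)."""
    R1, R2 = R1 - R1.mean(axis=0), R2 - R2.mean(axis=0)
    return np.linalg.norm(R2.T @ R1) ** 2 / (
        np.linalg.norm(R1.T @ R1) * np.linalg.norm(R2.T @ R2))

rng = np.random.default_rng(2)
X = rng.normal(size=(50, 6))                # hypothetical RAD, s = 50
R_i = X @ rng.normal(size=(6, 8))           # client i: width 8
R_j = np.tanh(X @ rng.normal(size=(6, 3)))  # client j: width 3
d_lin = cka_distance(linear_kernel(R_i), linear_kernel(R_j))
d_rbf = cka_distance(rbf_kernel(R_i), rbf_kernel(R_j))
```

Because $H$ is idempotent, $\mathrm{tr}(K_i H K_j H)$ with linear kernels equals $\|(HR_j)^\top (HR_i)\|_F^2$, which is why `d_lin` matches one minus `linear_cka_features`.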

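We try both a linear and an RBF kernel for computing the distances in our method. With the linear kernel $k(u, v) = u^\top v$ and column-centred representation matrices, linear CKA can be written simply as

$$\mathrm{CKA}_{\mathrm{lin}}(K_i, K_j) = \frac{\big\|\phi_{W_j}(X)^\top \phi_{W_i}(X)\big\|_F^2}{\big\|\phi_{W_i}(X)^\top \phi_{W_i}(X)\big\|_F\,\big\|\phi_{W_j}(X)^\top \phi_{W_j}(X)\big\|_F}.$$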
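The local learning objectives of our method in the homogeneous and heterogeneous settings thus become, respectively,

$$\min_{W_i}\; \ell_i(W_i) + \eta\,\big(1 - \mathrm{CKA}(K_i, K_g^{t-1})\big), \tag{4}$$

where $K_g^{t-1}$ is the representation distance matrix obtained over the instances in $X$ from the global model at the previous iteration, and

$$\min_{W_j}\; \ell_j(W_j) + \eta \sum_{i=1}^{N} w_i\,\big(1 - \mathrm{CKA}(K_j, K_i^{t-1})\big), \tag{5}$$

where $K_i^{t-1}$ is the representation distance matrix obtained over $X$ from the $i$-th client's model at the previous iteration.

For $\eta = 0$, the objective of homogeneous FedHeNN in equation (4) reduces to the FedAvg algorithm, and the objective of heterogeneous FedHeNN in equation (5) reduces to individual clients training their own local models in isolation using only local data. On the other hand, when $\eta \to \infty$, the framework tries to obtain identical representations from all the local models, without regard for the prediction task. In the homogeneous FL setting with linear CKA and linear models, we can show that after a sufficiently large number of iterations $T$, the framework makes the representations of all models identical up to an orthogonal transformation and isotropic scaling, the invariances of CKA.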
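The sketch below evaluates objectives (4) and (5) from a task-loss value and linear-kernel RDMs. It also illustrates one possible way for the server to aggregate representations, which is our assumption rather than a detail from the paper: since $\mathrm{CKA}(K, L)$ is the Frobenius inner product of the centred, normalised kernels $\hat K = HKH/\|HKH\|_F$ and $\hat L$, the weighted sum in (5) can be computed from a single aggregated matrix $\bar K = \sum_i w_i \hat K_i$. The task loss of 0.7, the weights, and the widths are placeholders.

```python
import numpy as np

def centred_normalised(K):
    """K_hat = H K H / ||H K H||_F, so that CKA(K, L) = <K_hat, L_hat>_F."""
    s = K.shape[0]
    H = np.eye(s) - np.ones((s, s)) / s
    Kc = H @ K @ H
    return Kc / np.linalg.norm(Kc)

def cka(K, L):
    return float(np.sum(centred_normalised(K) * centred_normalised(L)))

def objective_homogeneous(task_loss, K_i, K_g_prev, eta):
    """Eq. (4): l_i(W_i) + eta * (1 - CKA(K_i, K_g^{t-1}))."""
    return task_loss + eta * (1.0 - cka(K_i, K_g_prev))

def objective_heterogeneous(task_loss, K_j, kernels_prev, weights, eta):
    """Eq. (5): l_j(W_j) + eta * sum_i w_i * (1 - CKA(K_j, K_i^{t-1}))."""
    return task_loss + eta * sum(
        w * (1.0 - cka(K_j, K_i)) for w, K_i in zip(weights, kernels_prev))

def aggregate_kernels(kernels_prev, weights):
    """Hypothetical server step: one matrix sum_i w_i K_hat_i for all clients."""
    return sum(w * centred_normalised(K) for w, K in zip(weights, kernels_prev))

def objective_heterogeneous_aggregated(task_loss, K_j, K_bar, weights, eta):
    # sum_i w_i (1 - <K_hat_j, K_hat_i>) = sum(w) - <K_hat_j, K_bar>
    return task_loss + eta * (np.sum(weights) - np.sum(centred_normalised(K_j) * K_bar))

rng = np.random.default_rng(5)
X = rng.normal(size=(40, 7))  # hypothetical RAD
reps = [np.tanh(X @ rng.normal(size=(7, p))) for p in (5, 12, 20)]
kernels = [R @ R.T for R in reps]  # linear-kernel RDMs from round t-1
w = np.array([0.5, 0.3, 0.2])      # illustrative client weights
task_loss = 0.7                    # hypothetical task-loss value
```

Sending one aggregated $s \times s$ matrix instead of $N$ of them keeps the per-client download independent of the number of clients, at the cost of fixing the kernel normalisation on the server.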

Lemma III.2.

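Given a homogeneous FL setting with $N$ clients training linear models and $\eta \to \infty$, after a sufficiently large number of iterations $T$, the FedHeNN framework with linear CKA yields a linear CKA of 1 between $XW_i$ and $XW_j$ for all $i, j \in \{1, \dots, N\}$, i.e., the client models coincide up to an orthogonal transformation and isotropic scaling.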


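Proof. The optimization problem at each client is given by equation (4):

$$\min_{W_i}\; \ell_i(W_i) + \eta\,\big(1 - \mathrm{CKA}(K_i, K_g^{t-1})\big).$$

When $\eta \to \infty$, the task loss becomes negligible and the problem becomes

$$\max_{W_i}\; \mathrm{CKA}(K_i, K_g^{t-1}).$$

For linear CKA, we have

$$\mathrm{CKA}(K_i, K_g^{t-1}) = \frac{\big\|\phi_{W_g^{t-1}}(X)^\top \phi_{W_i}(X)\big\|_F^2}{\big\|\phi_{W_i}(X)^\top \phi_{W_i}(X)\big\|_F\,\big\|\phi_{W_g^{t-1}}(X)^\top \phi_{W_g^{t-1}}(X)\big\|_F}.$$

If we assume that each client has a linear network, then $\phi_{W_i}(X) = XW_i$. Then, after a sufficiently large number of iterations $T$, i.e., once each client has seen numerous RADs, optimizing equation (4) for $\eta \to \infty$ leads to

$$\mathrm{CKA}\big(XW_i X^\top\!,\; XW_g^{T}(XW_g^{T})^\top\big) = 1 \quad \text{for all } i.$$

Linear CKA equals 1 exactly when the centred representations agree up to an orthogonal transformation and isotropic scaling, and this relation carries over from each pair $(i, g)$ to every pair $(i, j)$. As a consequence, we reach the conclusion of the lemma. ∎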
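The last step of the proof relies on the invariances of linear CKA. The check below, under the lemma's assumption of linear models $\phi_W(X) = XW$, shows that a rescaled and rotated copy of a hypothetical global model attains a linear CKA of 1, whereas an unrelated linear model does not. All matrices are random placeholders.

```python
import numpy as np

def linear_cka(R1, R2):
    """Linear CKA between representation matrices (rows = RAD instances)."""
    R1 = R1 - R1.mean(axis=0)
    R2 = R2 - R2.mean(axis=0)
    return np.linalg.norm(R2.T @ R1) ** 2 / (
        np.linalg.norm(R1.T @ R1) * np.linalg.norm(R2.T @ R2))

rng = np.random.default_rng(3)
X = rng.normal(size=(100, 6))                 # a RAD with 100 instances
W_g = rng.normal(size=(6, 4))                 # hypothetical global linear model
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # random orthogonal 4x4 matrix
W_i = 2.5 * W_g @ Q                           # scaled and rotated copy of W_g
W_k = rng.normal(size=(6, 4))                 # unrelated linear model

cka_same = linear_cka(X @ W_g, X @ W_i)       # equals 1 up to rounding
cka_other = linear_cka(X @ W_g, X @ W_k)      # strictly below 1
```

So a CKA of 1 pins each client model down only up to these transformations, which is why the lemma is stated in those terms.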

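  Input: number of clients $N$, number of communication rounds $T$, number of local epochs $E$, parameter $\eta$
  Output: Final model $W_g^T$
  At Server -
  for $t = 1$ to $T$ do
     Generate RAD, $X$, by random sampling
     Select a subset $S_t$ of clients
     for each selected client $i \in S_t$ do
        $W_i^t \leftarrow$ ClientUpdate($i$, $W_g^{t-1}$, $X$)
     end for
     $W_g^t \leftarrow \sum_{i \in S_t} \frac{n_i}{\sum_{k \in S_t} n_k} W_i^t$ (FedAvg aggregation)
  end for
  ClientUpdate($i$, $W_g^{t-1}$, $X$) -
  Initialize $W_i$ with $W_g^{t-1}$
  for each local epoch do
     Update $W_i$ using SGD for loss in equation (4)
  end for
  Return $W_i$ to the server
Algorithm 1 FedHeNN Algorithm for Homogeneous clients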
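  Input: number of clients $N$, number of communication rounds $T$, number of local epochs $E$, parameter $\eta$, weight vector $w$ for clients
  Output: Final set of personalised models $W_1, \dots, W_N$
  At Server -
  for $t = 1$ to $T$ do
     Generate RAD, $X$, by random sampling
     for each client $i$ do
        Compute $K_i^{t-1}$ from $\phi_{W_i^{t-1}}(X)$
     end for
     Select a subset $S_t$ of clients
     for each selected client $j \in S_t$ do
        $W_j^t \leftarrow$ ClientUpdate($j$, $X$, $\{K_i^{t-1}\}_{i=1}^{N}$, $w$)
     end for
  end for
  ClientUpdate($j$, $X$, $\{K_i^{t-1}\}_{i=1}^{N}$, $w$) -
  Initialize $W_j$ with $W_j^{t-1}$
  for each local epoch do
     Update $W_j$ using SGD for loss in equation (5)
  end for
  Return $W_j$ to the server
Algorithm 2 FedHeNN Algorithm for Heterogeneous clients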
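The server side of Algorithm 2 can be sketched as follows. The sketch assumes a linear kernel for the RDMs and a server-side pool of unlabelled inputs (`pool`) from which the RAD is drawn; where the RAD instances come from is not specified above. `client_update` is a hypothetical callback standing in for $E$ epochs of local SGD on equation (5), and the stub in the usage example only records its calls.

```python
import numpy as np

def run_hetero_server(models, represent, client_update, pool, weights,
                      rounds, rad_size, frac, seed=0):
    """Server loop of Algorithm 2 (sketch).

    models:        per-client parameters; architectures may differ.
    represent:     represent(params, X) -> phi_W(X) of shape (rad_size, p_i).
    client_update: client_update(j, params, X, kernels, weights) -> params.
    pool:          hypothetical pool of unlabelled inputs for the RAD.
    """
    rng = np.random.default_rng(seed)
    n = len(models)
    for t in range(rounds):
        # Generate the RAD X by random sampling (without replacement).
        X = pool[rng.choice(len(pool), size=rad_size, replace=False)]
        # For each client, compute the linear-kernel RDM K_i^{t-1} on X.
        kernels = []
        for m in models:
            R = represent(m, X)
            kernels.append(R @ R.T)
        # Select a subset of clients and run their local updates.
        selected = rng.choice(n, size=max(1, int(frac * n)), replace=False)
        for j in selected:
            models[j] = client_update(j, models[j], X, kernels, weights)
    return models

calls = []

def client_update(j, params, X, kernels, weights):
    # Hypothetical stub: records the call instead of running local SGD.
    calls.append((int(j), len(kernels), kernels[0].shape))
    return params

rng = np.random.default_rng(4)
pool = rng.normal(size=(500, 8))
models = [rng.normal(size=(8, p)) for p in (2, 4, 8, 16, 32)]  # widths differ
run_hetero_server(models, lambda W, X: np.tanh(X @ W), client_update, pool,
                  weights=np.full(5, 0.2), rounds=3, rad_size=50, frac=0.4)
```

With 5 clients and `frac=0.4`, two clients are updated per round, so the stub is called six times over three rounds; every RDM is $50 \times 50$ regardless of the client's representation width.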

IV Experiments

We now demonstrate the effectiveness of the FedHeNN framework with empirical results on different datasets and models. We simulate statistically heterogeneous and system-heterogeneous FL settings by manipulating the data partitions and the model architectures across clients, respectively. We also discuss the effects of other variables, such as the size of the RAD and the choice of kernel, on the performance of FedHeNN.

iv.1 Experimental Details

We evaluate FedHeNN in different settings involving varied models, tasks, heterogeneity levels, and datasets. We describe these settings below and then present our results. Datasets We consider two high-level tasks, image classification and text classification. For image classification, we use the CIFAR-10 and CIFAR-100 datasets, which contain colour images in 10 and 100 classes, respectively. For text classification, we use Sentiment140, a binary classification dataset from the popular federated learning benchmark LEAF caldas2019leaf. We partition the entire data to generate non-IID samples on each client and then split those into training and test sets at the client site.

Baselines We compare our method against three baselines: FedAvg, FedProx, and FedRep. The FedAvg and FedProx algorithms learn a centralised global model, so for them we report the performance of the global model. In contrast, FedRep learns personalised models for each client, so its performance is reported for the personalised models. For FedHeNN, we report the performance of the local models in both the homogeneous and heterogeneous settings, and that of the global model in the homogeneous setting. For FedProto, its authors show that the performance gap between FedProto and FedAvg decreases when there are more samples per class. Because our settings do not restrict the number of samples per class, and FedHeNN beats FedAvg by a significant margin, we do not compare directly with FedProto.

Evaluation Metric and other Parameters We use the average test accuracy over the clients’ test datasets as the evaluation criterion. For global models, test accuracy is computed by evaluating the global model on the local test datasets; for local models, it is computed by testing the personalised models on the local test datasets. The hyperparameter $\eta$, which controls the contribution of representation similarity in the objective function, is kept as a function of the communication round $t$. This is because the initial representations obtained from insufficiently trained models are not yet accurate, and a high $\eta$ in the initial rounds may mislead training. The base value of $\eta$ is tuned as a hyperparameter; the best performance is obtained with one base value shared by CIFAR-10 and CIFAR-100 and a separate base value for Sentiment140. The size of the RAD is an important parameter of our method. The performance reported in the paper is obtained with this size held constant at 5000, which is much smaller than the size of the training or test datasets and does not increase the memory footprint drastically.

Implementation For the FL simulations, we keep the data distribution across clients non-IID such that each client has access to data from only certain classes; for example, with 5 classes per client, each client sees data from only 5 of the 10 CIFAR-10 classes, and different clients can see different subsets of classes. We also vary the number of classes per client to change the heterogeneity level across clients. The robustness and scalability of our method are tested by increasing the number of clients participating in training from 100 to 500 for the CIFAR-10 dataset. The total number of communication rounds is kept constant at 200 for all algorithms, and at each round only 10% of the clients are sampled and updated. We find that increasing the number of local epochs on clients does not worsen the performance of the client models for our method, so the number of local epochs is set to 20. For heterogeneous FedHeNN, each entry of the weight vector $w$ used for aggregating the representations is set to …. In each local update, we use SGD with momentum for training. For homogeneous FedHeNN on the CIFAR datasets, we use a CNN model with 2 convolutional layers, each followed by a max-pooling layer, and 3 fully connected layers at the end. For heterogeneous FedHeNN, we sample each client's model uniformly at random from a set of 5 different CNNs, obtained by varying the architecture size between the simplest one, which contains 1 convolutional and 1 fully connected layer, and the most complex one, which contains 3 convolutional and 3 fully connected layers. For the Sentiment140 dataset, we use either a 1-layer or a 2-layer LSTM followed by 2 fully connected layers.
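The exact partitioning procedure is not specified beyond each client seeing only certain classes, so the following is a hypothetical sketch of one common label-skew scheme: each client draws `classes_per_client` distinct classes, and each class's samples are split evenly among the clients that hold it. The toy labels stand in for CIFAR-10 and are not the real dataset.

```python
import numpy as np

def partition_by_classes(labels, n_clients, classes_per_client, seed=0):
    """Give each client `classes_per_client` classes and split every class's
    samples evenly among the clients holding it (label-skew sketch)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    n_classes = int(labels.max()) + 1
    client_classes = [rng.choice(n_classes, size=classes_per_client, replace=False)
                      for _ in range(n_clients)]
    holders = {c: [i for i, cs in enumerate(client_classes) if c in cs]
               for c in range(n_classes)}
    client_idx = [[] for _ in range(n_clients)]
    for c, owners in holders.items():
        if not owners:
            continue  # a class no client holds: its samples go unused
        idx = rng.permutation(np.flatnonzero(labels == c))
        for owner, chunk in zip(owners, np.array_split(idx, len(owners))):
            client_idx[owner].extend(chunk.tolist())
    return client_classes, [np.array(ix, dtype=int) for ix in client_idx]

labels = np.repeat(np.arange(10), 100)  # toy stand-in: 10 classes x 100 samples
classes, parts = partition_by_classes(labels, n_clients=100, classes_per_client=2)
```

Raising `classes_per_client` (for example from 2 to 5) makes client distributions overlap more, which is the knob used above to vary the heterogeneity level.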

CKA For the CKA distance metric, we evaluate performance using both a linear kernel and an RBF kernel. For the RBF kernel, we try various values of the bandwidth $\sigma$, but, as we show later, the performance obtained with the RBF kernel is not very different from that obtained with the linear kernel.

iv.2 Results

The performance of FedHeNN and the baselines under various settings is reported in Table 2 and Table 3. Table 2 reports the performance of the FedHeNN global model and the related global-model baselines: the FedHeNN global model outperforms FedAvg and FedProx in every setting except Sentiment140, where it matches FedProx (52.7) and is marginally above FedAvg (52.6). The comparisons of personalised models are reported in Table 3, where both homogeneous and heterogeneous FedHeNN outperform FedRep in every setting. Homogeneous FedHeNN has a larger gain over the baselines than heterogeneous FedHeNN, which is expected given the varying capacity of the local models in the heterogeneous setting. Linear vs RBF Kernel To compute the CKA-based distances, we try both the linear and the RBF kernel. In our empirical analysis of FedHeNN, the linear and RBF kernels give comparable performance, as shown in Table 1.

Dataset        Linear CKA   RBF CKA
CIFAR-10       94.47        93.03
CIFAR-100      84.37        83.03
Sentiment140   72.6         72.8
Table 1: Test accuracy (%) of homogeneous FedHeNN with a linear vs an RBF kernel for CKA on each dataset.

Dataset (setting)                          FedAvg        FedProx       FedHeNN Global
CIFAR-10 (100 clients, 2 cls/client)       44.29 ± 0.5   53.8 ± 2.3    68.8 ± 2.1
CIFAR-10 (100 clients, 5 cls/client)       58.14 ± 0.7   63.3 ± 2.0    70.19 ± 2.0
CIFAR-10 (500 clients, 2 cls/client)       42.7 ± 0.4    50.46 ± 1.4   65.4 ± 0.8
CIFAR-10 (500 clients, 5 cls/client)       56.8 ± 0.5    55.2 ± 1.2    64.7 ± 0.7
CIFAR-100 (100 clients, 20 cls/client)     28.6 ± 0.8    27.3 ± 1.1    44.2 ± 0.7
Sentiment140 (100 clients, 2 cls/client)   52.6 ± 0.4    52.7 ± 1.0    52.7 ± 0.01
Table 2: Average test accuracy of FedHeNN on various datasets computed for the common global model as compared to the baselines with global models.

Dataset (setting)                          FedRep         FedHeNN Homo   FedHeNN Hetero
CIFAR-10 (100 clients, 2 cls/client)       85.7 ± 0.4     94.7 ± 1.1     88.9 ± 0.35
CIFAR-10 (100 clients, 5 cls/client)       72.4 ± 1.2     84.37 ± 1.5    73.01 ± 0.3
CIFAR-10 (500 clients, 2 cls/client)       78.9 ± 0.6     86.5 ± 0.9     82.02 ± 0.8
CIFAR-10 (500 clients, 5 cls/client)       58.14 ± 0.21   73.32 ± 1.23   61.74 ± 0.6
CIFAR-100 (100 clients, 20 cls/client)     38.85 ± 0.9    62.89 ± 0.8    43.36 ± 0.2
Sentiment140 (100 clients, 2 cls/client)   69.8 ± 0.4     72.6 ± 0.3     71.5 ± 0.5
Table 3: Average test accuracy of FedHeNN on various datasets computed for the personalised models as compared to the baselines with personalised models.

Effect of Local epochs We also analyse the effect of varying the number of local epochs in FedHeNN. In FedAvg, increasing the number of local epochs has an adverse effect on model performance. In contrast, we observed no such effect for FedHeNN, which we attribute to the presence of the proximal term. We therefore set the number of local epochs for FedHeNN as high as 20.

Figure 3: Change in test accuracy of different algorithms when the data on a fraction of clients is reduced by 50% shown for CIFAR-10 dataset.

Sensitivity to Changing Amount of Data Our experiments show that the FedHeNN framework can effectively accommodate clients with fewer compute resources. To show that FedHeNN can also work with clients that have a smaller data footprint, we randomly select a fraction of clients and reduce the data on those clients by 50%. Figure 3 shows the results: the x-axis gives the fraction of clients whose data is shrunk, and the y-axis gives the average test accuracy over all clients when the framework is trained on the reduced dataset. Even as the data size on the clients decreases, FedHeNN's performance degrades gracefully. This could be attributed to the fact that, even though fewer instances are available to train the prediction function, the representation learning component remains robust.
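The data-reduction protocol can be sketched as below (the function name is hypothetical): a random fraction of clients is selected, and each selected client keeps a random 50% of its sample indices. Whether the reduction applies to training data only is an assumption left to the caller; the ten equal-sized clients are toy placeholders.

```python
import numpy as np

def shrink_clients(client_idx, fraction, keep=0.5, seed=0):
    """Pick round(fraction * N) clients at random and keep a random `keep`
    share of each selected client's sample indices."""
    rng = np.random.default_rng(seed)
    n = len(client_idx)
    chosen = rng.choice(n, size=int(round(fraction * n)), replace=False)
    shrunk = [np.asarray(ix) for ix in client_idx]
    for i in chosen:
        k = int(len(shrunk[i]) * keep)
        shrunk[i] = rng.choice(shrunk[i], size=k, replace=False)
    return shrunk, set(chosen.tolist())

clients = [np.arange(100 * c, 100 * (c + 1)) for c in range(10)]  # toy clients
shrunk, chosen = shrink_clients(clients, fraction=0.3)
```

Sweeping `fraction` over a grid and retraining on `shrunk` reproduces the x-axis of Figure 3 under these assumptions.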

V Discussion

We present a new method for enhancing Federated Learning by introducing a systematic framework called FedHeNN. The FedHeNN framework is unique in that it allows clients with heterogeneous architectures to participate in the joint learning process, which helps boost performance. This could be a significant practical advance, as individual client devices or organisations with varying amounts of resources can now contribute to, and learn from, each other. The empirical results indicate that FedHeNN achieves better performance while also being more inclusive. In future work, we plan to characterise the solutions and establish convergence guarantees for both FedHeNN algorithms.