I Introduction
Distributed machine learning has long been an important field of study, and its importance continues to grow paka03. Federated Learning (FL) is a distributed machine learning setting that consists of many end devices or siloed organisations (clients), each with access to locally stored data, and a global server that orchestrates the learning without accessing the data. With the rapidly growing amount of data and rising concerns around data privacy, federated learning has emerged as a promising direction: it allows a global model to be learned from the data present at each client without explicitly sharing that data outside the client devices, which helps preserve data privacy and reduces the cost of centralised training and storage. FedAvg pmlrv54mcmahan17a is the de facto federated learning algorithm, in which each client performs SGD steps to train its local model using its own data and compute resources. The client models are periodically shared with the server, which aggregates them into a global model that is sent back to the clients. However, FedAvg has been shown to diverge in the presence of statistical heterogeneity across clients fedprox. Over the years, several modifications to the original algorithm have been proposed, addressing aspects such as data heterogeneity, availability of clients for training, the aggregation mechanism, communication costs, and personalised client models fedprox, ditto, scaffold, semicyclic, mocha, fedma, fedrep. In all of these methods, the global model parameters are obtained by appropriately aggregating the local model parameters. Yet in most practical settings, where clients are heterogeneous and differ widely in compute resources and data distributions, these methods may face significant challenges. Several real-world FL scenarios require training over end devices with very different hardware.
In such cases, clients that are incapable of training large models can never take part in the training process under the above methods. Conversely, if the common model architecture is kept small to accommodate all the clients, some clients will be underutilised. These negative effects may become more prominent in the cross-silo FL setting, where the total number of clients is even smaller (typically fewer than 100). An illustrative experiment in Figure 1 demonstrates the drop in accuracy of global methods such as FedAvg and FedProx when a few clients are left out of training.
In this work, we propose a systematic framework for architecture-agnostic federated learning called FedHeNN. The framework overcomes the aforementioned challenges of client heterogeneity by allowing each client to build a personalised model without any constraint on the model architecture, including the number and size of the hidden layers, activation functions, etc. Learning is transferred across clients by grounding the representations learnt at each client through a proximal term. Specifically, the optimisation objective at each client comprises two terms: the task loss, and a proximal term that pulls the final learned representations of the models together. This also makes the neural networks learn more robust representations on a wide range of data, leading to better performance. We test our framework on a suite of federated learning datasets in both homogeneous and heterogeneous model-architecture settings across clients.
Our contributions are summarized as follows:
Our primary contribution is FedHeNN, a new framework for training deep neural networks in federated learning settings. We identify the shortcomings in joint learning of neural networks and circumvent them by grounding the representations learnt at each client through a proximal term. We empirically show that placing the proximal term on the representations can deliver superior performance.

To allow transfer of learning across architectures with different output spaces, we suggest the use of a kernel-based distance metric called Centered Kernel Alignment (CKA). The use of a kernel-based distance metric gives FedHeNN greater flexibility.

Additionally, the structure of the FedHeNN framework allows it to be extended to the setting where different clients have different model architectures. Thus, we propose a solution that is architecture agnostic and performs better than or comparably to existing methods that operate in homogeneous architecture settings.
The rest of the paper is organised as follows. Section II provides a background on Federated Learning and related developments. In Section III, we go over the preliminaries and then propose our framework, FedHeNN. We provide a thorough experimental evaluation of FedHeNN on different types of FL datasets in Section IV and conclude in Section V.
II Related Work
Distributed learning algorithms were extensively studied in the data mining community in the early 2000s under topics such as distributed (and privacy-preserving) data mining lipi00, megh03, megh05, paka03, ghtu00. This topic was reframed as Federated Learning in an influential paper pmlrv54mcmahan17a that introduced the FedAvg algorithm, and it has rapidly gained new adherents since then. The framework proposed by FedAvg pmlrv54mcmahan17a has been the standard solution for Federated Learning since 2017. However, when the data across clients is non-IID, averaging the local optima of multiple clients to obtain the global solution may cause the optimization to diverge. To address this, FedProx fedprox modifies the local training objective by adding a proximal term that penalizes the distance between the current global model weights and the local model weights, thus preventing each local update from moving far away. Similarly, SCAFFOLD scaffold introduces control variates to correct the local updates. FedPD fedpd, FedSplit fedsplit, and FedDyn feddyn are other important works that study the problem of finding better fixed points. FedDF fedDF shows that simply averaging the local models learnt on local distributions to obtain a global model may not be ideal. FedBE fedbe also provides evidence that the best-performing aggregate model need not be the average of the local models. Another important aspect of Federated Learning is thus to aggregate the local models appropriately. PFNM pfnm and FedMA fedma show that neurons in each layer are permutation invariant; these works perform a layer-by-layer matching of neurons and then aggregate the matched neurons. singh2020model use an optimal-transport-based distance to match neurons before aggregation. In contrast to the above, FedDF fedDF uses additional data samples to distill knowledge from the clients' local models.
FedNova fednova allows each client to perform a variable amount of work and calculates a regularized average for the aggregate. Several works focus on creating local personalized models instead of a single global model to cater to the heterogeneity of the data distributions across clients. In the literature, personalised FL has been approached in various ways, such as keeping the task-specific component of the clients local fedrep, clustering the clients SattlerMS21, GhoshCYR20, using meta-learning personalised_meta_learning, waffle, maml, KhodakBT19, and multi-task learning fed_mtl, ditto, mocha, local_adaption. HaoELZLCC21, FLviaSyn, luo2021fear augment the data distribution by creating additional data samples and using them for learning. MOON li2021model uses a contrastive loss to bring the representations of objects closer and utilises it to correct the local training. The need for heterogeneous client models was recently highlighted in FedProto tan2021fedproto. Our work differs from FedProto in that our formulation works directly on representations of data instances, as opposed to the class-specific prototypes that FedProto uses for learning and inference. We think that having only one prototype per class could be limiting in the case of multimodal classes.
III FedHeNN Methodology
In an effort to enhance federated learning with heterogeneous clients, we propose a new framework, Federated Heterogeneous Neural Networks (FedHeNN). FedHeNN provides a systematic way to jointly learn deep neural networks that vary in architecture and output space across distributed clients. In this section we first cover the preliminaries and then present the proposed method.
Preliminaries. We first go over the key concepts of federated learning methods and then elaborate on our proposed method. Consider a set of $N$ clients. The traditional federated learning algorithm learns a single global model $\bar{W}$ from the model parameters learned on the individual clients:
\[ \bar{W} = f(W_1, \dots, W_N), \]
where $W_i$ are the local parameters obtained at the $i$-th client and $f$ is the aggregation function. For a single-layer model, $W$ is the vector of weights. For an $m$-layer model, $W = \{W^{(1)}, \dots, W^{(m)}\}$ is the collection containing the weights at each layer. The bias terms are assumed to be incorporated in the weights without loss of generality.
III.1 FedAvg
The FedAvg algorithm learns each client's local parameters by solving
\[ \min_{W_i} \; \ell(W_i), \]
where $\ell$ is some loss function evaluated on the local data of client $i$. An elementwise average is then performed on the weight matrices of the clients to obtain an aggregated model $\bar{W}$ at the server. For a multi-layer network, the average is taken layer-wise and the parameters of each client are weighted by the number of data points present at the client. For the $l$-th layer, we have
\[ \bar{W}^{(l)} = \sum_{i=1}^{N} \frac{n_i}{n} W_i^{(l)}, \]
where $n_i$ is the number of data points at client $i$ and $n = \sum_{i=1}^{N} n_i$.
While this aggregation mechanism has shown good empirical results, elementwise averaging might give suboptimal results for various reasons, such as the permutation invariance of neurons and differing data distributions across clients. Besides, FedAvg imposes a hard constraint on clients to train models with the exact same architecture. In practice, it is highly likely that different clients will not be able to train models of the same capacity. We identify that, apart from solving subproblems such as the aggregation mechanism or data heterogeneity, we also need to make sure that the networks on the clients are being trained towards the same goal. To do so, we suggest modifying the objective function optimized at each client. We consider a neural network to be composed of two components: the representation learning component, which maps the input to a $p$-dimensional representation vector, and the task learning component, which learns the prediction function. The initial layers of the neural network are considered the representation learning component, whose output is the representation vector, denoted by $\Phi(x; W)$ for a data instance $x$ under the model with parameters $W$, while the last layer is considered the task-specific layer, whose output is the prediction corresponding to the data instance, denoted by $\hat{y}$, as depicted in Figure 2. The details of the method are explained in the next subsections.
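As a minimal sketch of the layer-wise weighted averaging that FedAvg performs at the server, the following NumPy snippet averages per-layer weight arrays across clients, weighting each client by its number of local data points. The function name `fedavg_aggregate` and the toy two-client setup are illustrative assumptions and are not taken from any FedAvg implementation.

```python
import numpy as np


def fedavg_aggregate(client_weights, client_sizes):
    """Layer-wise weighted average of client parameters.

    client_weights: list over clients; each entry is a list of per-layer
                    NumPy arrays whose shapes match across clients.
    client_sizes:   number of local data points held by each client.
    """
    sizes = np.asarray(client_sizes, dtype=float)
    coeffs = sizes / sizes.sum()  # per-client weight n_i / n
    num_layers = len(client_weights[0])
    aggregated = []
    for layer in range(num_layers):
        # Weighted element-wise sum of this layer across all clients.
        layer_avg = sum(c * w[layer] for c, w in zip(coeffs, client_weights))
        aggregated.append(layer_avg)
    return aggregated


# Two hypothetical clients, each holding a two-layer model.
client_a = [np.ones((2, 3)), np.zeros((3, 1))]
client_b = [3 * np.ones((2, 3)), 4 * np.ones((3, 1))]
global_model = fedavg_aggregate([client_a, client_b], client_sizes=[30, 10])
```

Because the average is taken entry by entry, every client must supply arrays of identical shape for each layer, which is exactly the same-architecture constraint discussed above.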
Additionally, we explore two different settings for the FedHeNN framework, homogeneous and heterogeneous, which we define below and elaborate on in the next subsections.
Definition III.1.
We refer to an FL setting as homogeneous when every client in the network has the exact same architecture. An FL setting is said to be heterogeneous when each client in the network can define its own architecture.
III.2 FedHeNN for Homogeneous Clients
For the homogeneous federated learning setting, we propose to modify the loss function on each client to additionally incorporate a proximal term on the representations. In particular, instead of just minimising the loss function $\ell(W_i)$, each client $i$ minimises the following objective:
\[ \min_{W_i} \; \ell(W_i) + \eta \, d\big(\Phi_i(X), \bar{\Phi}^{(t-1)}(X)\big) \tag{1} \]
where $d(\cdot, \cdot)$ is some distance metric, $\eta$ weights the proximal term, $\Phi_i(X)$ is the representation learned on client $i$ for the set of data points $X$, and $\bar{\Phi}^{(t-1)}(X)$ is the representation learned by the global model communicated at the previous round $t-1$. The global model is learned using the FedAvg algorithm. Because we ensure similarity between the representations, we do not need to address the aggregation mechanism explicitly. The proximal term brings the representations of the local and global models together and helps correct the learning on each client.
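The structure of the homogeneous objective, a task loss plus a weighted distance between local and global representations on a shared set of inputs, can be sketched as follows. This is only an illustration under stated assumptions: the one-hidden-layer ReLU network, the squared-error task loss, the mean squared difference used as a stand-in distance, and all function and variable names are hypothetical choices, not the paper's implementation.

```python
import numpy as np


def representation(X, W_rep):
    """Representation component: a single ReLU layer (assumed architecture)."""
    return np.maximum(X @ W_rep, 0.0)


def predict(X, W_rep, W_head):
    """Task component: a linear head applied to the representation."""
    return representation(X, W_rep) @ W_head


def mean_squared_distance(A, B):
    """Placeholder distance between two representation matrices."""
    return float(np.mean((A - B) ** 2))


def local_objective(params, global_params, X_local, y_local, X_align, eta,
                    distance=mean_squared_distance):
    """Task loss on local data plus eta times the representation distance."""
    W_rep, W_head = params
    residual = predict(X_local, W_rep, W_head) - y_local
    task_loss = float(np.mean(residual ** 2))
    # Both models are evaluated on the same alignment inputs.
    local_rep = representation(X_align, W_rep)
    global_rep = representation(X_align, global_params[0])
    return task_loss + eta * distance(local_rep, global_rep)


rng = np.random.default_rng(0)
X_local, y_local = rng.normal(size=(20, 4)), rng.normal(size=(20, 1))
X_align = rng.normal(size=(8, 4))
local = (rng.normal(size=(4, 5)), rng.normal(size=(5, 1)))
global_model = (rng.normal(size=(4, 5)), rng.normal(size=(5, 1)))

no_prox = local_objective(local, global_model, X_local, y_local, X_align, eta=0.0)
with_prox = local_objective(local, global_model, X_local, y_local, X_align, eta=1.0)
same_model = local_objective(local, local, X_local, y_local, X_align, eta=1.0)
```

With `eta=0.0` only the task loss remains, and when the local and global representation layers coincide the proximal term vanishes, so `same_model` equals `no_prox`.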
III.3 FedHeNN for Heterogeneous Clients
In certain settings it might not be practical for all clients to train neural network models of the same complexity. We identify that it is nontrivial to use existing algorithms to perform joint training in such settings. Our framework allows one to perform federated learning across heterogeneous clients by collecting and aggregating the representations learnt on the local clients and pulling those representations closer. The framework is inspired by the same intuition: there is an underlying lower-dimensional representation of the data that the clients can uncover together. To achieve this, we let each client train its own model but pull the representations learnt by different clients closer by adding a proximal term to the client's loss function. The proximal term measures the distance between the representations learnt by different local models. Specifically, if $\Phi_j(X)$ is the representation matrix obtained on the $j$-th client for data $X$, then the proximal term measuring the distance between client representations for the $i$-th client can be written as
\[ d\big(\Phi_i(X), \bar{\Phi}(X)\big), \]
where
\[ \bar{\Phi}(X) = \sum_{j=1}^{N} w_j \, \Phi_j(X). \]
The contribution of each client $j$, $w_j$, could be set to reflect the capacity or strength of the model on that client. At each iteration, the server generates a set of unlabelled data instances called the Representation Alignment Dataset (RAD), denoted by $X$, and uses it to align the representations across clients. As before, the clients share their local weights with the server. Instead of aggregating the weights, the server uses each client's weights to generate representations for the RAD and aggregates the representations over clients. The server then distributes the aggregated representations and the RAD to each client. In the next iteration, each client learns a model that minimizes a loss combining the task loss and the distance to the aggregated representations. Specifically, at iteration $t$ each client $i$ minimizes
\[ \min_{W_i} \; \ell(W_i) + \eta \, d\big(\Phi_i(X), \bar{\Phi}^{(t-1)}(X)\big) \tag{2} \]
This method lets us utilise the training from local clients without explicitly generating a global aggregated model. The privacy of the clients' data is also preserved, as the clients do not need to share their data in any form. Thus the clients can keep using their local personalised models while benefiting from the training of peer clients. When we describe the distance function later, we will see that our method is robust to the output representations of clients lying in different spaces, i.e., the $\Phi_j(X)$'s, $j = 1, \dots, N$, need not be aligned with each other.
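One way to picture the server-side step for heterogeneous clients is sketched below. It assumes that the aggregation is carried out on Gram (kernel) matrices computed on the shared RAD, since these have the same size for every client regardless of representation width; the linear kernel, the equal client weights, and the function names are hypothetical choices made for illustration only.

```python
import numpy as np


def linear_gram(A):
    """Gram matrix of a representation matrix with one row per RAD instance."""
    return A @ A.T


def aggregate_representations(client_reps, client_weights):
    """Weighted sum of per-client Gram matrices computed on the same RAD."""
    weights = np.asarray(client_weights, dtype=float)
    grams = [linear_gram(A) for A in client_reps]
    return sum(w_j * K_j for w_j, K_j in zip(weights, grams))


rng = np.random.default_rng(0)
rad_size = 6
# Three hypothetical clients whose representation widths differ (4, 16, 32).
client_reps = [rng.normal(size=(rad_size, width)) for width in (4, 16, 32)]
client_weights = [1 / 3, 1 / 3, 1 / 3]  # assumed equal contributions
K_bar = aggregate_representations(client_reps, client_weights)
```

The raw representation matrices here have shapes (6, 4), (6, 16) and (6, 32) and cannot be summed directly, whereas every Gram matrix is 6 by 6, so the weighted sum is well defined.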
III.4 The Proximal Term
Here we give a detailed explanation of the proximal term, i.e., the distance function used in our objective formulation in equations (1) and (2). As mentioned earlier, we consider training a deep neural network model to comprise two high-level steps: learning to map inputs of different modalities to a $p$-dimensional representation, and then learning a prediction function. We consider all but the last layer of the neural network as the representation learning module, with its output as the learned representation, denoted by $\Phi(\cdot)$. To match the representation learning part of the network across all clients, we compare the outputs of the representation learning modules of the networks on the same set of instances. Specifically, we take a set of input examples encoded in a matrix $X$ (the RAD), pass them through all the local networks, and capture the outputs of the representation learning modules in matrices $A_i$, where $A_i = \Phi_i(X) \in \mathbb{R}^{M \times p_i}$. Here $A_i$ stands for the activation (representation) matrix of the $i$-th client, $M$ is the size of the RAD, and $p_i$ is the output dimension of the representation. Note that our method can work with representations of different dimensions on different clients, i.e., $p_i$ can vary with $i$. At each communication round, we randomly select only a sample of the instances to be part of the alignment dataset, keeping $M$ small. After obtaining the activation matrices $A_i$, we need a way to compare the representations. For this, we suggest using representation distance matrices (RDMs). An RDM uses the distances between instances to capture the characteristics of the representational space. Representational distance learning rdl has also been used for knowledge distillation in the past. To measure the distance, we use a metric proposed specifically for neural network representations called Centered Kernel Alignment (CKA) cka.
CKA was originally proposed to compare the representations obtained from different neural networks and to determine equivalence relationships between their hidden layers. CKA takes the activation (representation) matrices as input and outputs a similarity score between 0 (not similar at all) and 1 (identical representations). Useful properties of CKA include invariance to orthogonal transformation and isotropic scaling, though not to arbitrary invertible linear transformation. CKA is also robust to training with different random initializations. Because of these properties, and because CKA can capture similarities between layers of different widths, it lends itself naturally to use as a distance metric in our method. Let $A_i$ and $A_j$ be the representation matrices for clients $i$ and $j$ obtained for the RAD. The distance between $A_i$ and $A_j$ is obtained by first computing the kernel matrices $K_i$ and $K_j$ for any choice of kernel $k$ as follows. Let
\[ K_i = k(A_i, A_i) \quad \text{and} \quad K_j = k(A_j, A_j) \]
be the representational distance matrices; then the distance between $A_i$ and $A_j$ is given under CKA by
\[ \mathrm{CKA}(K_i, K_j) = \frac{\mathrm{HSIC}(K_i, K_j)}{\sqrt{\mathrm{HSIC}(K_i, K_i)\,\mathrm{HSIC}(K_j, K_j)}}, \]
where the estimator for the Hilbert-Schmidt Independence Criterion (HSIC) can be written as
\[ \mathrm{HSIC}(K_i, K_j) = \frac{1}{(M-1)^2} \operatorname{tr}(K_i H K_j H), \]
with $H$ as the centering matrix
\[ H = I_M - \frac{1}{M} \mathbf{1}\mathbf{1}^{\top}. \]
We try both a linear and an RBF kernel for computing the distances in our method. For column-centered $A_i$ and $A_j$, the linear CKA can be written simply as
\[ \mathrm{CKA}_{\text{linear}}(A_i, A_j) = \frac{\lVert A_j^{\top} A_i \rVert_F^2}{\lVert A_i^{\top} A_i \rVert_F \, \lVert A_j^{\top} A_j \rVert_F} \tag{3} \]
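The CKA computation described above can be sketched in NumPy as follows, using the standard HSIC estimator with a centering matrix and supporting both linear and RBF kernels. The function names and the `sigma` bandwidth parameter are illustrative assumptions; the snippet is a sketch of the standard definitions rather than the authors' code.

```python
import numpy as np


def centering_matrix(m):
    """H = I - (1/m) * 1 1^T for m instances."""
    return np.eye(m) - np.ones((m, m)) / m


def hsic(K, L):
    """Empirical HSIC estimator between two kernel matrices of equal size."""
    m = K.shape[0]
    H = centering_matrix(m)
    return np.trace(K @ H @ L @ H) / (m - 1) ** 2


def cka(K, L):
    """CKA similarity between two kernel matrices."""
    return hsic(K, L) / np.sqrt(hsic(K, K) * hsic(L, L))


def linear_kernel(A):
    return A @ A.T


def rbf_kernel(A, sigma):
    """RBF kernel matrix; sigma is an assumed bandwidth parameter."""
    sq_norms = np.sum(A ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (A @ A.T)
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma ** 2))


def linear_cka(A, B):
    """Closed-form linear CKA; columns are centered before comparison."""
    A = A - A.mean(axis=0, keepdims=True)
    B = B - B.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(B.T @ A, "fro") ** 2
    return cross / (np.linalg.norm(A.T @ A, "fro") * np.linalg.norm(B.T @ B, "fro"))


rng = np.random.default_rng(0)
A = rng.normal(size=(12, 4))   # representations of width 4
B = rng.normal(size=(12, 7))   # representations of width 7
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # random orthogonal matrix

score = linear_cka(A, B)
score_rotated = linear_cka(A @ Q, B)
score_scaled = linear_cka(3.0 * A, B)
```

Here `A` and `B` have different widths, yet the score is well defined, and rotating or rescaling `A` leaves it unchanged up to floating-point error. If a penalty to be minimised is needed, one option (an assumption here, not something the text specifies) is to use one minus the CKA similarity.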
The local learning objectives for our method in the homogeneous and heterogeneous settings thus respectively become
\[ \min_{W_i} \; \ell(W_i) + \eta \, d\big(K_i, \bar{K}^{(t-1)}\big) \tag{4} \]
where $d$ is the CKA-based distance and $\bar{K}^{(t-1)}$ is the representation distance matrix obtained over the RAD instances from the global model at the previous iteration, and
\[ \min_{W_i} \; \ell(W_i) + \eta \, d\big(K_i, \bar{K}^{(t-1)}\big) \tag{5} \]
with
\[ \bar{K}^{(t-1)} = \sum_{j=1}^{N} w_j \, k\big(A_j^{(t-1)}, A_j^{(t-1)}\big). \]
For $\eta = 0$, the objective function for homogeneous FedHeNN in equation (4) reduces to the FedAvg algorithm, and the objective function for heterogeneous FedHeNN in equation (5) reduces to individual clients training their own local models in isolation using only local data. On the other hand, when $\eta \to \infty$, the framework will try to obtain identical representations from all the local models without regard to the prediction task. In the homogeneous FL setting for FedHeNN with linear CKA, we can say that, after a sufficiently large number of iterations $t$, the framework will make all the models identical.
Lemma III.2.
Given a homogeneous FL setting with $N$ clients training linear models and $\eta \to \infty$, then after a sufficiently large number of iterations $t$, for the FedHeNN framework with linear CKA, we have $W_i = W_j$ for all $i, j \in \{1, \dots, N\}$.
Proof.
The optimization problem at each client is given by:
When , the problem becomes
For linear CKA, we have
If we assume that each client has a linear network, then . Then, after sufficiently large number of iterations , i.e., when each client has seen numerous RADs, optimizing equation (4) for will lead to
As a consequence, we reach the conclusion of the lemma. ∎
IV Experiments
We now present the effectiveness of the FedHeNN framework using empirical results on different datasets and models. We simulate statistically heterogeneous and system heterogeneous FL settings by manipulating the data partitions and model architectures across clients respectively. We also discuss the effects of other variables like the size of RAD, and the choice of kernel on the performance of FedHeNN.
IV.1 Experimental Details
We evaluate FedHeNN in different settings involving varied models, tasks, heterogeneity levels, and datasets. We describe these settings below and then present our results.
Datasets. We consider two high-level tasks, image classification and text classification, and use the corresponding datasets from the popular federated learning benchmark LEAF caldas2019leaf. For the image classification task, we use the CIFAR10 and CIFAR100 datasets, which contain colored images in 10 and 100 classes respectively. For the text classification task, we use a binary classification dataset called Sentiment140. We partition the entire data to generate non-IID samples on each client and then split those into training and test sets at the client site.
Baselines. We compare our method against three baselines: FedAvg, FedProx and FedRep. The FedAvg and FedProx algorithms learn a centralised global model, so the reported performance metric is that of the global model. In contrast, FedRep learns personalised models for each client, so its performance is reported for the personalised models. For FedHeNN, we report the performance of the local models for both the homogeneous and heterogeneous settings, and that of the global model for the homogeneous setting. For the FedProto algorithm, its paper demonstrates that the performance gap between FedProto and FedAvg decreases as the number of samples per class grows. Because our settings do not restrict the number of samples per class and we beat the FedAvg algorithm by a significant margin, we do not directly compare with FedProto.
Evaluation Metric and Other Parameters. We use the average test accuracy obtained on the clients' test datasets as the evaluation criterion. For global models, the test accuracy is computed by evaluating the global model on the local test datasets; for local models, it is computed by testing the personalised models on the local test datasets. The hyperparameter $\eta$
that controls the contribution of representation similarity in the objective function is kept as a function of $t$ (the communication round). This is because the initial representations obtained from insufficiently trained models are not accurate, and keeping a high $\eta$ in the initial rounds may mislead the training. The base value of $\eta$ is tuned as a hyperparameter, with one best-performing value for CIFAR10 and CIFAR100 and another for the Sentiment140 dataset. The size of the RAD is an important parameter for our method. The performance reported in the paper is obtained by keeping this size constant at 5000, which is much smaller than the size of the training or test datasets and does not increase the memory footprint drastically.
Implementation. For the FL simulations, we keep the data distribution across clients non-IID such that each client has access to data of only certain classes; for example, with 5 classes per client, two clients might each hold data from a different set of 5 classes. We also vary this number of classes to change the heterogeneity level across clients. The robustness and scalability of our method are tested by increasing the number of clients participating in training from 100 to 500 for the CIFAR10 dataset. The total number of communication rounds is kept constant at 200 for all algorithms, and at each round only 10% of the clients are sampled and updated. We find that increasing the number of local epochs on clients does not worsen the performance of the client models for our method, so the number of local epochs is set to 20. For the heterogeneous FedHeNN, each entry of the weight vector w for aggregating the representations is set to a fixed value. In each local update, we use SGD with momentum for training. For the homogeneous FedHeNN on the CIFAR datasets, we use a CNN model with 2 convolutional layers, each followed by a max-pooling layer, and 3 fully connected layers at the end.
For the heterogeneous FedHeNN, each client's architecture is sampled uniformly at random from a set of 5 different CNNs, obtained by varying the architecture size between the simplest one, which contains 1 convolutional and 1 fully connected layer, and the most complex one, which contains 3 convolutional and 3 fully connected layers. For the Sentiment140 dataset, we use either a 1-layer or a 2-layer LSTM followed by 2 fully connected layers.
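The simulation setup described above, in which each client sees only a few classes and a fraction of clients is sampled per round, could be sketched along the following lines. The partitioning scheme (each client draws its own class subset and each class's samples are split evenly among the clients holding it), the function names, and the toy label array are assumptions for illustration; the text does not specify the exact procedure.

```python
import numpy as np


def partition_by_class(labels, num_clients, classes_per_client, rng):
    """Assign each client a subset of classes and disjoint samples from them."""
    labels = np.asarray(labels)
    classes = np.unique(labels)
    # Each client draws its own subset of classes without replacement.
    client_classes = [
        rng.choice(classes, size=classes_per_client, replace=False)
        for _ in range(num_clients)
    ]
    client_indices = [[] for _ in range(num_clients)]
    for c in classes:
        holders = [k for k, cs in enumerate(client_classes) if c in cs]
        if not holders:
            continue  # no client drew this class; its samples stay unused
        idx = rng.permutation(np.flatnonzero(labels == c))
        # Split this class's shuffled samples evenly among its holders.
        for k, chunk in zip(holders, np.array_split(idx, len(holders))):
            client_indices[k].extend(chunk.tolist())
    return client_classes, client_indices


def sample_clients(num_clients, fraction, rng):
    """Sample a fraction of clients (at least one) for a communication round."""
    num_sampled = max(1, int(round(fraction * num_clients)))
    return rng.choice(num_clients, size=num_sampled, replace=False)


rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 100)  # toy labels: 10 classes, 100 samples each
client_classes, client_indices = partition_by_class(
    labels, num_clients=100, classes_per_client=2, rng=rng
)
sampled = sample_clients(num_clients=100, fraction=0.1, rng=rng)
```

Lowering `classes_per_client` makes the clients' label distributions more dissimilar, which is the knob the text uses to vary the heterogeneity level.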
CKA. For the CKA distance metric, we evaluate performance using a linear kernel as well as an RBF kernel. For the RBF kernel, we try various values of the kernel bandwidth, but as we show later, the performance obtained using the RBF kernel is not very different from that of the linear kernel.
IV.2 Results
The performance of FedHeNN and the baselines under various settings is reported in Table 2 and Table 3. The performance of the FedHeNN global model and the related baselines is reported in Table 2. It can be observed that the FedHeNN global model outperforms the FedAvg and FedProx algorithms on all CIFAR settings and matches them on Sentiment140. The comparisons of the personalised models are reported in Table 3, and the results demonstrate that FedHeNN performs better than the baselines. We observe that the homogeneous FedHeNN has a higher gain over the baselines than the heterogeneous FedHeNN, which is expected because of the varying capacity of the local models in the heterogeneous setting.
Linear vs RBF Kernel. For computing the CKA-based distances, we try both the linear and the RBF kernel. Empirically, both kernels give comparable performance, as shown in Table 1.
Table 1: Test accuracy (%) of FedHeNN with linear and RBF kernels for CKA.
Dataset        Linear CKA   RBF CKA
CIFAR10        94.47        93.03
CIFAR100       84.37        83.03
Sentiment140   72.6         72.8
Table 2: Average test accuracy (%) of the global models.
Dataset (setting)                             FedAvg         FedProx        FedHeNN Global
CIFAR10 (100 clients, 2 classes/client)       44.29 ± 0.5    53.8 ± 2.3     68.8 ± 2.1
CIFAR10 (100 clients, 5 classes/client)       58.14 ± 0.7    63.3 ± 2.0     70.19 ± 2.0
CIFAR10 (500 clients, 2 classes/client)       42.7 ± 0.4     50.46 ± 1.4    65.4 ± 0.8
CIFAR10 (500 clients, 5 classes/client)       56.8 ± 0.5     55.2 ± 1.2     64.7 ± 0.7
CIFAR100 (100 clients, 20 classes/client)     28.6 ± 0.8     27.3 ± 1.1     44.2 ± 0.7
Sentiment140 (100 clients, 2 classes/client)  52.6 ± 0.4     52.7 ± 1.0     52.7 ± 0.01
Table 3: Average test accuracy (%) of the personalised models.
Dataset (setting)                             FedRep         FedHeNN Homo   FedHeNN Hetero
CIFAR10 (100 clients, 2 classes/client)       85.7 ± 0.4     94.7 ± 1.1     88.9 ± 0.35
CIFAR10 (100 clients, 5 classes/client)       72.4 ± 1.2     84.37 ± 1.5    73.01 ± 0.3
CIFAR10 (500 clients, 2 classes/client)       78.9 ± 0.6     86.5 ± 0.9     82.02 ± 0.8
CIFAR10 (500 clients, 5 classes/client)       58.14 ± 0.21   73.32 ± 1.23   61.74 ± 0.6
CIFAR100 (100 clients, 20 classes/client)     38.85 ± 0.9    62.89 ± 0.8    43.36 ± 0.2
Sentiment140 (100 clients, 2 classes/client)  69.8 ± 0.4     72.6 ± 0.3     71.5 ± 0.5
Effect of Local Epochs. We also analyse the effect of varying the number of local epochs in FedHeNN. In FedAvg, increasing the number of local epochs adversely affects the performance of the model. No such effect was observed for FedHeNN, owing to the presence of the proximal term. We therefore keep the number of local epochs for FedHeNN as high as 20.
Sensitivity to Changing Amount of Data. We have shown through experiments that the FedHeNN framework can effectively accommodate clients with lower compute resources. To show that FedHeNN can also work with clients that have a smaller data footprint, we run an experiment in which we randomly select a fraction of the clients and reduce the data on those clients by 50%. The results are shown in Figure 3, where the x-axis is the fraction of clients picked for shrinking the data and the y-axis is the average test accuracy over all clients when the framework is trained on the reduced dataset. We notice that even as the data size on the clients decreases, FedHeNN's performance degrades gracefully. This could be attributed to the fact that, even though the number of instances available to train the prediction function is reduced, the representation learning component remains robust.
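The data-reduction protocol used in this experiment can be sketched as below: a random fraction of clients is selected and half of each selected client's samples are discarded. The function name `shrink_client_data`, the `keep_ratio` parameter, and the toy index arrays are hypothetical, and the subsampling is assumed to be uniform at random since the text does not say how the removed samples are chosen.

```python
import numpy as np


def shrink_client_data(client_indices, client_fraction, keep_ratio, rng):
    """Randomly pick a fraction of clients and subsample their local data."""
    num_clients = len(client_indices)
    num_shrunk = int(round(client_fraction * num_clients))
    chosen = rng.choice(num_clients, size=num_shrunk, replace=False)
    shrunk = set(chosen.tolist())
    reduced = []
    for k, idx in enumerate(client_indices):
        idx = np.asarray(idx)
        if k in shrunk:
            # Keep a uniformly random subset of this client's samples.
            keep = max(1, int(len(idx) * keep_ratio))
            idx = rng.choice(idx, size=keep, replace=False)
        reduced.append(idx)
    return reduced, sorted(shrunk)


rng = np.random.default_rng(0)
# Ten hypothetical clients with 40 sample indices each.
client_indices = [np.arange(k * 40, (k + 1) * 40) for k in range(10)]
reduced, shrunk = shrink_client_data(
    client_indices, client_fraction=0.3, keep_ratio=0.5, rng=rng
)
```

Sweeping `client_fraction` from 0 to 1 and retraining at each value would produce the kind of curve described for Figure 3.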
V Discussion
We present FedHeNN, a systematic framework for enhancing learning in Federated Learning. FedHeNN allows clients with heterogeneous architectures to participate in the joint learning process, which helps boost performance. This could be a significant practical advancement, as individual client devices or organisations with varying amounts of resources can now contribute to and learn from each other. The empirical results indicate that FedHeNN achieves better performance while also being more inclusive. For future work, we plan to characterise the solutions and establish convergence guarantees for both FedHeNN algorithms.