When deploying AI models on edge devices (e.g., smartphones and self-driving cars), training the models in a distributed manner while maintaining user data privacy is a major challenge. Because direct access to user data is undesirable, conventional training techniques cannot satisfy the model’s target performance. As such, federated learning (FL) was proposed in [mcmahan2017fedavg] to enable decentralized model training while maintaining local data privacy (i.e., restricting the cloud server’s direct access to private user data).
However, FL faces three key challenges. First, imbalanced, non-independent and identically distributed (Non-IID) local data often causes training to fail in the decentralized environment. Second, frequent sharing of model weights between edge devices and the server incurs an excessive communication overhead. Lastly, the increasing demand of AI models (e.g., deep neural networks, DNNs) for computing, memory, and storage makes it hard to deploy them on resource-limited edge devices. This suggests that designing efficient FL models and deploying them effectively will be critical to achieving higher performance on future computing systems. Recent research in FL and its variants [karimireddy2020scaffold, wang2020fednova, li2020fedprox] mainly focuses on learning efficiency, i.e., improving training stability and reaching the target accuracy in the minimum number of training rounds. However, these solutions still incur extra communication costs. As such, no existing solution addresses the three key issues jointly. Furthermore, the above methods aim to learn a uniform shared model for all the heterogeneous clients, and there is no guarantee for the model’s performance on each client’s Non-IID local data.
Deep learning models are generally over-parameterized and easily overfit during local FL updates. Since only a subset of salient parameters decides the final prediction outputs of a neural network, it is unnecessary to aggregate all the parameters. Additionally, we observed that typical deep learning models (e.g., CNNs) usually consist of an encoder (which embeds the input instance) and a predictor head (which predicts based on the embedding). Transfer learning [torrey2010translearn] shows that a well-trained encoder can be easily transferred to non-IID datasets. We therefore argue that federated learning does not need to aggregate the whole model: it is sufficient to train only the encoder in a federated manner and to transfer its knowledge to heterogeneous clients by deploying a local predictor on each client. This not only reduces the communication overhead, but also improves the model’s local performance on the heterogeneous clients.
Based on these observations, we propose an efficient FL method through Salient Parameter Aggregation and Transfer Learning (SPATL). Specifically, we train the model’s encoder in a distributed manner through federated learning, and transfer its knowledge to each heterogeneous client via locally deployed predictor heads. Additionally, we implemented a local salient parameter selection agent using deep reinforcement learning to select the encoder’s salient parameters. We reduce communication overhead by only uploading the selected salient parameters to the aggregating server. Finally, we leverage the gradient control mechanism to correct the encoder’s gradient heterogeneity and guide the gradient towards a generic global direction among all clients. This further stabilizes the training process and speeds up the convergence.
Overall, the major contributions of this paper are as follows:
Reducing communication overhead in federated learning by introducing salient parameter selection and aggregation for over-parameterized models.
Addressing data heterogeneity in FL via knowledge transfer of the trained model to heterogeneous clients.
Accelerating the model’s local inference by salient parameter selection.
Continual online learning of the local salient parameter selection agent through the federated learning process and deep reinforcement learning.
2 Related work
2.1 Federated Learning
With increasing concerns over user data privacy, federated learning was proposed in [mcmahan2017fedavg] to train a shared model in a distributed manner without direct access to private data. The FedAvg algorithm [mcmahan2017fedavg] is simple and quite robust in many practical settings. However, the local updates may lead to divergence due to heterogeneity in the network, as demonstrated in previous works [hsu2019measuring, karimireddy2020scaffold, li2020convergence_noniid]. To tackle these issues, numerous variants have been proposed [li2020fedprox, wang2020fednova, karimireddy2020scaffold]. For example, FedProx [li2020fedprox] adds a proximal term to the local loss, which helps to restrict deviations between the current local model and the global model. FedNova [wang2020fednova] introduces weight modification to avoid gradient biases by normalizing and scaling the local updates. SCAFFOLD [karimireddy2020scaffold] corrects the update direction by maintaining drift variates, which are used to estimate the overall update direction of the server model. Nevertheless, these variants incur extra communication overhead to maintain stable training. In particular, in FedNova and SCAFFOLD, the average communication cost in each communication round is approximately double that of FedAvg.
Numerous research papers have addressed data heterogeneity (i.e., non-IID data among local clients) in FL [zhao2018federated, hsieh2020quagmire, lim2020federated, zhang_fedufo, gong_ensemblefl, caldarola_graphfl, sun_soteria], for example by improving client sampling fairness [nishio2019clientselection], applying adaptive optimization [zhang2021adaptive_noniid, han2020adaptive_FLquantization, reddi2021fedopt, yu2021adaptive], and correcting the local updates [karimireddy2020scaffold, li_model-contrastive, wang2020fednova]. Federated learning has also been extended to real-life applications [liu_feddg, guo_multi-institutional]. One promising solution is personalized federated learning [dinh2021personalized_moreau, huang2021personalized_crosssilo, zhang2021personalized_modelopt, fallah2020personalized_meta, hanzely2020lowerbound_personazlied], which tries to learn personalized local models among clients to address the data heterogeneity. These works, however, fail to address the high communication overhead.
2.2 Neural Network Pruning
As AI models are typically over-parameterized, only a subset of parameters decides the output. Network pruning is among the most widely used model compression techniques to slim and accelerate AI models. Pruning has achieved outstanding results and is a proven technique to drastically shrink the model size [gao_network_2021, liu_learnable_2021, wang_convolutional_2021]. Recently, AutoML pruning algorithms [li2020eagleeye, chin2020legr] have offered better results with higher versatility, particularly the reinforcement learning (RL)-based methods [yu2021gnnrl, he2018amc, yu2021agmc], which model the neural network as a graph and use a GNN-based reinforcement learning agent to search for a pruning policy. Inspired by RL-based pruning methods, in this paper we propose a salient parameter selection method to reduce the communication cost of federated learning.
SPATL consists of three main components: knowledge transfer learning, a salient parameter selection agent, and gradient control federated learning. Figure 1 shows the overview of SPATL. Unlike mainstream FL solutions, which attempt to train complete deep learning models distributively, SPATL only trains the encoder part of the model distributively and transfers the knowledge to heterogeneous clients. In each round of federated learning, the client first downloads the encoder from the cloud aggregator and transfers its knowledge using a local predictor through local updates. Then, before uploading the encoder back, the salient parameter selection agent evaluates the training results of the current model based on the model performance (such as accuracy) and uses deep reinforcement learning to select the encoder’s salient parameters to communicate. Additionally, both clients and server maintain a gradient control variate to correct the heterogeneous gradient, which helps stabilize and smooth out the training process.
3.1 Heterogeneous Knowledge Transfer Learning Through Client Local Updates
Inspired by transfer learning [torrey2010translearn], SPATL aims to train an encoder through federated learning (FL) and addresses the FL heterogeneity issue by transferring the encoder’s knowledge to heterogeneous clients. Formally, we formulate our deep learning model as an encoder $E(w_e, x)$ and a predictor $P(w_p, e)$, where $w_e$ and $w_p$ are the encoder and predictor parameters respectively, and $x$ and $e$ are the input instances to the encoder and the predictor.
SPATL shares the encoder with the cloud aggregator, while the predictor for the client is kept private on the client. The forward propagation of the model in the local client is formulated as follows:
During local updates of a communication round, the selected client first downloads the shared encoder parameters, $w_e$, from the cloud server, and optimizes them together with the local predictor head parameters, $w_p$, through backpropagation. Equation 3 shows the optimization function.
Here, $\ell$ refers to the loss when fitting the label $y$ for data $x$, and $\eta$ is the constant coefficient.
In federated learning, not all clients are involved in communication during each round. In fact, there is a possibility a client might never be selected for any communication round. Before deploying the trained encoder on such a client, the client will download the encoder from the aggregator and apply local updates to its local predictor. Equation 4 shows the optimization function.
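The two kinds of local update described above (joint training of encoder and predictor on selected clients, and predictor-only training on clients that were never selected) can be illustrated with a toy sketch. It assumes a scalar linear encoder and predictor with a squared-error loss, which is far simpler than the CNNs used in the paper; the names `forward`, `local_update`, `mse`, and the flag `train_encoder` are hypothetical and do not come from the paper.

```python
def forward(w_e, w_p, x):
    """Encoder embeds the input; predictor maps the embedding to a prediction."""
    e = w_e * x          # toy encoder: scalar linear embedding
    return e, w_p * e    # toy predictor: scalar linear head


def local_update(w_e, w_p, data, lr=0.01, epochs=10, train_encoder=True):
    """Plain SGD on a squared-error loss 0.5 * (y_hat - y)^2.

    train_encoder=True  : selected client, updates encoder and predictor.
    train_encoder=False : never-selected client, keeps the downloaded
                          encoder fixed and only fits its private predictor.
    """
    for _ in range(epochs):
        for x, y in data:
            e, y_hat = forward(w_e, w_p, x)
            err = y_hat - y
            grad_p = err * e           # d loss / d w_p
            grad_e = err * w_p * x     # d loss / d w_e
            w_p -= lr * grad_p
            if train_encoder:
                w_e -= lr * grad_e
    return w_e, w_p


def mse(w_e, w_p, data):
    """Mean squared error of the encoder/predictor pair on a dataset."""
    return sum((forward(w_e, w_p, x)[1] - y) ** 2 for x, y in data) / len(data)
```

In this sketch only `w_e` would ever be sent to the server; `w_p` stays on the client in both modes.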
3.2 GNN-based Salient Parameter Selection using Reinforcement Learning
One key issue of FL is the high communication overhead caused by frequently sharing parameters between clients and the cloud aggregator server. We observed that deep learning models (e.g., VGG [simonyan2015vgg] and ResNet [he2016resnet]) are usually over-parameterized. As such, only a subset of salient parameters decides the final output. Therefore, in order to reduce the communication cost, we implemented a local salient parameter selection agent that selects the salient parameters for communication. Figure 2 shows the idea of using reinforcement learning (RL) to search for the salient parameter selection policy. Specifically, inspired by the topology-aware network pruning task [yu2021agmc, yu2021gnnrl], we model the neural network (NN) as a simplified computational graph and use it to represent the NN’s states. Since NNs are essentially computational graphs, their parameters and operations correspond to nodes and edges of the computational graph. We then introduce a GNN-based RL agent, which takes the graph as input (the RL environment state) and produces a parameter selection policy from the topology through GNN embedding. Additionally, the RL agent uses the selected sub-model’s accuracy as the reward to guide its search for the optimal policy.
3.2.1 Reinforcement Learning Task Definition
Defining the environment states, action space, reward function, and RL policy is essential for specifying an RL task. In this section, we discuss these components in more detail. Algorithm 1 shows the RL search process. For each search step, we first initialize the target encoder with the input encoder and convert it to a graph. If the size of the target encoder does not satisfy the constraints, the proximal policy optimization (PPO) [schulman2017ppo] RL agent produces a parameter selection policy (i.e., the action of the RL agent) to update the target encoder. If the target encoder satisfies the size constraint, the RL agent uses its accuracy as the reward to update the policy. Finally, the parameters and corresponding parameter indices of the target encoder with the best reward are uploaded to the cloud server.
Environment States. We use a simplified computational graph to represent the NN model [yu2021agmc]. In a computational graph, nodes represent hidden features (feature maps), and edges represent primitive operations (such as ‘add’, ‘minus’, and ‘product’). Since an NN model involves billions of operations, it is unrealistic to represent every primitive operation. Instead, we simplify the computational graph by replacing the primitive operations with machine learning operations (e.g., conv 3x3, ReLU, etc.).
Action Space. The actions are the sparsity ratios for the encoder’s hidden layers. The action is therefore a vector with one sparsity ratio per hidden layer of the encoder. The actor network in the RL agent projects the NN’s computational graph to the action vector, as shown in Equations 5 and 6. There, the input to the graph encoder is the environment state, its output is the graph representation, and MLP is a multi-layer perceptron neural network. The graph encoder learns the topology embedding, and the MLP projects the embedding into the hidden layers’ sparsity ratios.
Reward Function. The reward function is the accuracy of the selected sub-network on the validation dataset.
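To make the action concrete, the sketch below applies a vector of per-layer ratios to a list of flattened layer weights and returns, for each layer, the selected parameters together with their indices. Two assumptions are made loudly here: the ratio is read as the fraction of parameters kept (consistent with the later description of the sparsity ratio as the share of salient parameters), and parameters are ranked by absolute magnitude, which is a common pruning heuristic but is not stated in the text. The function names `select_salient` and `apply_action` are hypothetical.

```python
def select_salient(weights, keep_ratio):
    """Return (indices, values) of the largest-magnitude entries of one layer.

    ASSUMPTION: magnitude ranking stands in for whatever saliency
    criterion the actual agent uses.
    """
    k = max(1, round(keep_ratio * len(weights)))
    ranked = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    indices = sorted(ranked[:k])  # keep original order for easy indexing
    return indices, [weights[i] for i in indices]


def apply_action(layers, action):
    """Apply one ratio per hidden layer, mirroring the action vector."""
    if len(layers) != len(action):
        raise ValueError("action must contain one ratio per layer")
    return [select_salient(w, r) for w, r in zip(layers, action)]


# Example: two tiny layers, keep half of each.
selected = apply_action([[0.1, -0.9, 0.3, 0.05], [2.0, -0.2]], [0.5, 0.5])
```

The reward for such an action would then be the validation accuracy of the model restricted to the selected parameters; that evaluation step is omitted here.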
3.2.2 Update Policy
The RL agent is updated end-to-end through the Proximal Policy Optimization (PPO) algorithm. The RL agent trains on the local clients through continual online learning over each FL round. Equation 8 shows the objective function we used for the PPO update policy.
$$L(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right] \tag{8}$$
Here, $\theta$ is the policy parameter (the actor-critic network’s parameter), $\hat{\mathbb{E}}_t$ denotes the empirical expectation over time steps, $r_t(\theta)$ is the ratio of the action probabilities under the new and old policies, $\hat{A}_t$ is the estimated advantage at time $t$, and $\epsilon$ is a clip hyper-parameter, usually 0.1 or 0.2.
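The clipped PPO surrogate described above can be computed in a few lines. The sketch below is a plain-Python illustration of the standard clipped objective from [schulman2017ppo], not the authors’ training code; the function name `ppo_clip_objective` and its argument names are hypothetical, and log-probabilities are assumed as inputs so that the probability ratio is `exp(new - old)`.

```python
import math


def ppo_clip_objective(new_logp, old_logp, advantages, eps=0.2):
    """Empirical mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A).

    new_logp / old_logp: log-probabilities of the taken actions under the
    new and old policies; advantages: estimated advantages per time step.
    """
    total = 0.0
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped = max(1.0 - eps, min(1.0 + eps, ratio))
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

The objective is maximized; clipping removes the incentive to push the ratio beyond `1 ± eps` in the direction that would otherwise increase the objective, which keeps each policy update conservative.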
3.3 Generic Parameter Gradient Controlled Federated Learning
Inspired by stochastic controlled averaging federated learning [karimireddy2020scaffold], we propose generic parameter gradient controlled federated learning to correct the heterogeneous gradient. Due to client heterogeneity, local gradient update directions will move towards local optima and may diverge across clients. To correct the overall gradient divergence by estimating gradient update directions, we maintain control variates both on the clients and on the cloud aggregator. However, controlling the entire model’s gradients will hurt the local model’s performance on Non-IID data. To compensate for this performance loss, SPATL only corrects the generic parameters’ gradients (i.e., the encoder’s gradients) while maintaining a heterogeneous predictor. Specifically, in Equation 9, during local updates of the encoder, we correct gradient drift by adding the estimated gradient difference $(c - c_i)$.
Here, the control variate $c$ is the estimate of the global gradient direction maintained on the server side, and $c_i$ is the estimate of the update direction for local heterogeneous data maintained on each client $i$. In each round of communication, $c_i$ is updated as in Equation 10.
Here, $E$ is the number of local epochs and $\eta$ is the local learning rate. The server control variate $c$ is updated by Equation 11.
$\Delta c_i$ is the difference between the new and old local control variates of client $i$, $N$ is the set of clients, and $S$ is the set of selected clients. Algorithm 2 shows SPATL with gradient controlled FL. In each update round, the client downloads the global encoder’s parameters and the update direction $c$ from the server, and performs local updates. When updating the local encoder parameters, the difference $(c - c_i)$ is applied to correct the gradient drift. The predictor head’s gradient remains heterogeneous. Before uploading, the local control variate $c_i$ is updated by estimating the gradient drift.
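As a rough illustration of the control-variate bookkeeping described above, the sketch below follows the SCAFFOLD update rules [karimireddy2020scaffold] that this section builds on. Whether Equations 9 to 11 match these rules term by term is an assumption, as is the use of a step count `steps` where the text speaks of local epochs. All function and argument names are hypothetical; parameters are plain Python lists, and only the encoder is assumed to pass through these functions.

```python
def controlled_local_update(w_global, c, c_i, grad_fn, lr, steps):
    """Local encoder update with drift correction (SCAFFOLD-style).

    w_global : encoder parameters downloaded from the server
    c, c_i   : server and client control variates (same shape as w_global)
    grad_fn  : callable returning the local gradient at the given parameters
    """
    w = list(w_global)
    for _ in range(steps):
        g = grad_fn(w)
        # Corrected step: local gradient plus the difference (c - c_i).
        w = [wj - lr * (gj + cj - cij) for wj, gj, cj, cij in zip(w, g, c, c_i)]
    # New client variate: c_i - c + (w_global - w_local) / (steps * lr).
    c_i_new = [cij - cj + (wg - wl) / (steps * lr)
               for cij, cj, wg, wl in zip(c_i, c, w_global, w)]
    delta_c = [new - old for new, old in zip(c_i_new, c_i)]
    return w, c_i_new, delta_c


def update_server_control(c, deltas, num_clients):
    """Server variate: add the selected clients' deltas, scaled by 1 / |N|."""
    return [cj + sum(d[j] for d in deltas) / num_clients
            for j, cj in enumerate(c)]
```

With zero-initialized variates and a single local step, `c_i_new` reduces to the local gradient at the downloaded parameters, which is a quick sanity check for the bookkeeping.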
3.3.1 Aggregation with Salient Parameters
Due to the Non-IID local training data in heterogeneous clients, salient parameter selection policy varies among the heterogeneous clients after local updates. Since the selected salient parameters have different matrix sizes and/or dimensions, directly aggregating them will cause a matrix dimension mismatch. To prevent this, as Figure 3 shows, we only aggregate partial parameters according to the current client’s salient parameter index on the server side. Equation 12 shows the mathematical representation of this process.
Here, the terms of Equation 12 are the global parameter, the client’s salient parameter, the index of each salient parameter in the original weights, and the update step size. By communicating only the salient parameters and their corresponding indices (the indices add a negligible burden), we can significantly reduce the communication overhead and avoid matrix dimension mismatches.
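The index-based aggregation described above can be sketched as follows. This is only one plausible reading: each client uploads `(indices, values)` for its salient encoder parameters, and the server moves each touched global entry toward the average of the values received for that index, scaled by a step size. The exact form of Equation 12 is not reproduced here, and the name `aggregate_salient`, the flattened parameter layout, and the per-index averaging are assumptions.

```python
def aggregate_salient(global_w, uploads, step=1.0):
    """Update only the global entries that some client reported as salient.

    global_w : flat list of global encoder parameters
    uploads  : list of (indices, values) pairs, one per selected client
    step     : update step size
    """
    diff_sum = {}
    count = {}
    for indices, values in uploads:
        for idx, val in zip(indices, values):
            diff_sum[idx] = diff_sum.get(idx, 0.0) + (val - global_w[idx])
            count[idx] = count.get(idx, 0) + 1
    new_w = list(global_w)
    for idx, total in diff_sum.items():
        new_w[idx] += step * total / count[idx]
    return new_w


# Two clients with different salient index sets; index 1 is never touched.
merged = aggregate_salient([0.0, 0.0, 0.0, 0.0],
                           [([0, 2], [1.0, 2.0]), ([2, 3], [4.0, 8.0])])
```

Because each upload carries its own index list, clients with different selection policies never need tensors of matching shape.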
We conducted several experiments to examine SPATL’s performance. Overall, we divided our experiments into three categories: learning efficiency, communication cost, and inference acceleration. We also performed an ablation study and compared SPATL with state-of-the-art FL algorithms.
4.1 Implementation and Hyper-parameter Setting
Datasets and Models. The experiments were conducted with FEMNIST [caldas2019leaf] and CIFAR-10 [krizhevsky2009cifar]. In FEMNIST, we follow the LEAF benchmark Non-IID settings [li2021federated]. In CIFAR-10, each client is allocated a proportion of the samples of each label according to a Dirichlet distribution (with concentration ). Specifically, we sample and allocate a proportion of the instances to client . Here we choose the . The deep learning models we use in the experiment are VGG-11 [simonyan2015vgg] and ResNet-20/32 [he2016resnet].
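A label-wise Dirichlet split of this kind can be sketched with the standard library alone. The sketch draws, for every label, a proportion vector over clients from a symmetric Dirichlet distribution (built from normalized gamma draws) and slices that label’s shuffled indices accordingly. The function names, the `seed` argument, and the rounding of slice boundaries are assumptions of this sketch; the concentration value is left as a parameter because it is not recoverable from the text here.

```python
import random


def dirichlet(alpha, n, rng):
    """Sample from a symmetric Dirichlet(alpha) via normalized gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n)]
    total = sum(draws)
    if total == 0.0:  # guard against numerical underflow for tiny alpha
        return [1.0 / n] * n
    return [d / total for d in draws]


def partition_by_label(labels, num_clients, alpha, seed=0):
    """Return one list of sample indices per client (label-wise Dirichlet split)."""
    rng = random.Random(seed)
    by_label = {}
    for idx, lab in enumerate(labels):
        by_label.setdefault(lab, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for lab in sorted(by_label):
        idxs = by_label[lab]
        rng.shuffle(idxs)
        props = dirichlet(alpha, num_clients, rng)
        start, cum = 0, 0.0
        for j in range(num_clients):
            cum += props[j]
            if j == num_clients - 1:
                end = len(idxs)  # last client takes the remainder
            else:
                end = min(len(idxs), max(start, round(cum * len(idxs))))
            clients[j].extend(idxs[start:end])
            start = end
    return clients
```

For example, `partition_by_label(labels, 10, alpha=0.5)` splits a label list across 10 clients; the value 0.5 is an arbitrary demonstration value, not the paper’s setting. Smaller concentration values give more skewed per-client label distributions.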
Federated Learning Setting. Unlike common FL experiment settings that evaluate the model on the server side after each round of communication, in SPATL the models in each client are different. Thus, we evaluate the average performance of the models in heterogeneous clients. We set a total of 10 clients and a sample ratio of 1 in each communication round. During local updates, each client updates for 10 rounds locally. The learning rate is 0.01 for VGG-11 and 0.005 for ResNet-20/32. The server and client control variates are initialized as zero matrices. When training on CIFAR-10, each client was allocated 6K images (5K images for the local training set and 1K images for the local test set).
RL Agent Settings. The RL agent updates the policy every search steps and runs 20 epochs in each updating round. The discount factor is , the clip parameter is , and the standard deviation of actions is . The Adam optimizer is applied to update the RL agent, where the learning rate is and the .
Baseline. We compare SPATL with the state-of-the-art FL algorithms, such as FedNova [wang2020fednova], FedAvg [mcmahan2017fedavg], FedProx [li2020fedprox], and SCAFFOLD [karimireddy2020scaffold].
4.2 Learning Efficiency
In this section, we evaluate the learning efficiency of SPATL by investigating the relationship between communication rounds and the average Top-1 accuracy of the model. Since SPATL learns a shared encoder, each local client has a heterogeneous predictor, and the model’s performance differs among clients. Instead of evaluating a global test accuracy on the server side, we allocate each client a local Non-IID training dataset and a validation dataset and evaluate the average Top-1 accuracy of the model among heterogeneous clients. We train VGG-11 [simonyan2015vgg] and ResNet-20/32 [he2016resnet] on CIFAR-10 [krizhevsky2009cifar], and a 2-layer CNN on FEMNIST [caldas2019leaf], separately with 100 rounds of communication, and compare the results with the state of the art (SoTA), i.e., FedNova [wang2020fednova], FedAvg [mcmahan2017fedavg], FedProx [li2020fedprox], and SCAFFOLD [karimireddy2020scaffold].
Figure 4 shows the results. On CIFAR-10, SPATL outperformed the SoTA FL methods for ResNet-20, ResNet-32, and VGG-11, achieving accuracies of 84.27%, 83.67%, and 84.07%, respectively. Compared to SCAFFOLD, SPATL did not produce the desired results in the warm-up phase of training (the first 30 rounds), but it catches up after 60 rounds of communication. The learning curves of SPATL are stable and smooth, and SPATL reaches a higher accuracy than SCAFFOLD after training is complete. SPATL outperforms FedAvg, FedNova, and FedProx by a large margin, with a significantly higher final accuracy and a much more stable training process. However, the result on the 2-layer CNN is an exception: the model trained by SPATL has a slightly lower accuracy than the SoTA. We suspect this is due to the size of the 2-layer CNN. The assumption of an over-parameterized model no longer holds, so salient parameter selection makes it hard to fit the training data.
We further evaluate the performance of knowledge transfer learning on heterogeneous clients by evaluating each local client’s model accuracy. Figure 5 shows the accuracy of ResNet-20 on each client after training is complete (10 clients in total, trained by SPATL and SCAFFOLD for 100 rounds). SPATL produces better performance on all heterogeneous clients. Since SPATL uses heterogeneous predictors to transfer the encoder’s knowledge, the model is much more robust when dealing with Non-IID data. In contrast, our baseline methods (such as SCAFFOLD) share the entire model when training among Non-IID clients, so the model’s performance may vary across clients, causing poor performance on some of them.
4.3 Communication Cost
| Method | Model | R | Avg. cost / (R*C) | Total |
Although FedNova [wang2020fednova] and SCAFFOLD [karimireddy2020scaffold] have a stable training process thanks to gradient control or gradient normalization variates, their average communication cost per round is double that of FedAvg. One key contribution of SPATL is that it can significantly reduce the communication cost through salient parameter aggregation. In each round of communication, SPATL only communicates the selected salient parameters, while baseline models communicate the entire model. To compare communication overheads, we trained all models to a target accuracy of 80% and calculated the communication bandwidth consumed throughout the training process. We trained ResNet-20/32 [he2016resnet] and VGG-11 [simonyan2015vgg] on CIFAR-10 [krizhevsky2009cifar] with 10 clients. Table 1 shows the detailed information. The communication cost is calculated as:
Results show SPATL outperforms baseline models. On ResNet-20, FedNova costs 8.32 GB while SPATL costs only 1.1 GB, a reduction of roughly 7.6× (8.32/1.1). Moreover, SPATL reduced communication by up to 108 GB when training VGG-11 compared to FedNova.
4.4 Inference Acceleration
| Model | Avg. sparsity ratio | Avg. FLOPs | Highest FLOPs |
We further evaluate the inference acceleration of SPATL. The salient parameter selection agent selects the salient parameters of the client’s model. When doing inference on the local clients, we forward the input instance through the selected salient sub-model, which shortens the inference time and reduces computational consumption. Table 2 shows the inference acceleration after training is finished. SPATL notably reduced the FLOPs (floating point operations) in all the evaluated models. For instance, in ResNet-32, the average FLOPs reduction among 10 clients is , and the client with the highest FLOPs reduction achieves fewer FLOPs than the original model, while the client models have a relatively low sparsity ratio (the sparsity ratio is the ratio of salient parameters to all model parameters).
4.5 Transferability of Learned Model
To verify whether SPATL can successfully train a transferable encoder, we train an encoder through SPATL and transfer it to new data. We split CIFAR-10 [krizhevsky2009cifar] into two separate datasets, one with 50K images and another with 10K images. We train VGG-11 for 100 rounds on the 50K images with 10 clients (where each client has 4K local training images and a 1K validation set). The SPATL-trained encoder achieves average validation accuracy among clients. After transferring this encoder to the 10K images (9K training data and 1K validation data) and training it with a predictor for 50 epochs with batch size 64, it achieves validation accuracy. This shows SPATL can successfully learn an encoder that can be transferred to another heterogeneous dataset.
4.6 Ablation Study
4.6.1 Impact of Gradient Control
This section investigates the impact of the gradient control mechanism. We maintain control variates on the local clients and in the cloud, which help correct the local update directions and guide the encoder’s local gradient towards the global optimum. Figure 6 shows the results of SPATL with and without gradient control. We train VGG-11 [simonyan2015vgg] on CIFAR-10 [krizhevsky2009cifar] with 10 clients. The results show that gradient control remarkably improves the training performance of SPATL. Since SPATL only aggregates salient parameters, it is hard to optimize the target model by simply averaging the salient parameters without gradient correction.
4.6.2 Burdens of Reinforcement Learning
One concern about SPATL could be that its reinforcement learning process adds computational burden on local clients. In this section, we investigate the feasibility and applicability of the salient parameter selection reinforcement learning agent. We perform the salient parameter selection task on pre-trained ResNet-18 and ResNet-56 separately and record the average reward of the RL agent. As shown in Figure 7, the salient parameter selection agent begins converging around round 40 of RL policy updating, and the agent can yield a parameter selection policy with the desired reward. The 40 rounds of RL policy updating are within acceptable resource and computational budgets on local clients (this process requires 45 minutes on a single GPU). Furthermore, the RL task has a limited search space, which makes it far less resource-hungry than popular RL tasks (such as gaming and self-driving). Moreover, in the FL task, the RL agent is not trained from scratch in every round of communication. Rather, it is trained continuously via online learning. Since the RL agent inherits the policy from previous rounds, the desired policy is produced much more quickly.
The proposed approach may perform poorly on simple models. As Figure 4 shows, our approach works well on over-parameterized neural networks, such as ResNet [he2016resnet] and VGG [simonyan2015vgg]. However, for a less-parameterized model, such as the 2-layer CNN, salient parameter selection may degrade performance, making the model converge more slowly than the baselines. In general, such less-parameterized models are rarely used in real-world applications.
In this paper, we proposed SPATL, an efficient federated learning method based on salient parameter aggregation and transfer learning. To address data heterogeneity in federated learning, we introduced a knowledge transfer local predictor that transfers the shared encoder to each client. We further introduced a salient parameter selection agent that selects the salient parameters of the over-parameterized model and communicates them to the server. The proposed method notably decreases the communication overhead. We also leveraged the gradient control mechanism to stabilize the training process and make it robust. Our experiments show that SPATL has a stable training process and achieves promising results. Moreover, SPATL significantly reduces the communication cost and accelerates inference.