Federated Two-stage Learning with Sign-based Voting

12/10/2021
by Zichen Ma, et al.

Federated learning is a distributed machine learning mechanism in which local devices collaboratively train a shared global model under the orchestration of a central server while keeping all private data decentralized. Because model parameters and their updates are transmitted instead of raw data, the communication bottleneck has become a key challenge. In addition, recent larger and deeper machine learning models are more difficult to deploy in a federated environment. In this paper, we design a federated two-stage learning framework that augments prototypical federated learning with a cut layer on devices and uses sign-based stochastic gradient descent with majority voting on model updates. The cut layer on devices learns informative, low-dimensional representations of raw data locally, which reduces the number of global model parameters and helps prevent data leakage. Sign-based SGD with majority voting on model updates further alleviates the communication limitation. Empirically, we show that our system is an efficient and privacy-preserving federated learning scheme that is suited to general application scenarios.

1 Introduction

Federated learning is a distributed machine learning mechanism in which local institutions or devices collaboratively train a shared global model under the orchestration of a central server while keeping all sensitive private data decentralized kairouz2019advances . Challenges in federated learning include unbalanced and non-IID (not independent and identically distributed) data allocated across an enormous number of devices moreno2012unifying and limited communication bandwidth zhang2013information .

Recent deeper and larger machine learning models devlin2018bert strain the limits of communication channels because traditional federated learning trains a shared global model by communicating parameters and their updates to each device konevcny2016federated . This calls for a new paradigm beyond the prototypical federated learning framework. In this paper, we design a federated two-stage learning framework that augments prototypical federated learning with a cut layer on devices and uses sign-based stochastic gradient descent with majority voting on model updates. Devices with a cut layer split the execution of a model on a per-layer basis, which helps learn informative and compact representations of the raw data (smashed data) locally. The global model is then trained using the SIGNSGD algorithm riedmiller1993direct on the low-dimensional smashed data from devices.

Compared with existing approaches, our proposed model has several advantages. First, the proposed federated learning scheme is highly efficient. By splitting the execution of a model between devices and the server, local devices learn informative, low-dimensional representations of the raw data. The global model then needs fewer parameters because of these compact inputs, reducing the number of communicated intermediaries. Besides, SIGNSGD with majority voting bernstein2018signsgd has been shown to alleviate the communication bottleneck by transmitting only the sign of each minibatch stochastic gradient while preserving competitive results on publicly available real-world datasets.

Second, the designed framework is suited to general application scenarios. One assumption in traditional federated learning is that all the data comes from the same source (e.g., text hard2018federated or images), which may not be realistic in real-world scenarios. Devices may contain multiple sources of data such as videos, images, or text, and a single global model may not be able to handle all of them. Moreover, the trained model is unlikely to generalize when new devices contain data modalities that were not observed during training. Our proposed model, on the contrary, can handle data according to their sources and distributions. It is capable of handling sources of data that have not been observed before and thus suits general applications.

In addition, the proposed scheme is a privacy-preserving federated learning architecture. The intermediaries communicated between devices and the server in traditional federated learning may still leak information about the raw data, violating the constraint that raw data must remain private. Our proposed federated two-stage learning mechanism, however, reduces the invertibility of the intermediate representations by minimizing the distance correlation szekely2007measuring between the smashed data and the raw data while preserving the model's prediction accuracy. Experiments show that our model yields results superior or comparable to state-of-the-art methods with less leakage of sensitive information.

2 Related Work

Figure 1: An overview of the proposed federated two-stage learning with sign-based voting system

Federated Learning

Federated learning enables multiple parties to collaboratively build a machine learning model while keeping their data private yang2019federated . Challenges in federated learning include unbalanced and non-IID data partitions, unreliable devices and limited communication bandwidth. The most popular optimization method in federated learning is the Federated Averaging algorithm mcmahan2016communication , which aggregates local models by weighted averaging. Li et al. li2018federated added a proximal term to the objective to improve the stability of the algorithm. Fully decentralized learning tries to alleviate trust concerns about the central server vanhaesebrouck2017decentralized . Caldas et al. caldas2018expanding proposed an algorithm that improves the efficiency of federated learning.

Compressed Optimization

Training large models in a distributed setting incurs high communication cost liu2010distributed , and many compressed optimization methods have been proposed to alleviate this challenge. SIGNSGD transmits the sign of each minibatch stochastic gradient instead of the gradient itself, and it has been shown to be more communication-efficient than traditional SGD bernstein2018signsgd .

3 Federated Two-stage Learning

Prototypical federated learning builds a global model on the central server and transmits the model parameters to local devices. Each local device then computes gradient updates based on its own subset of high-dimensional raw data and updates the model parameters. These updated parameters are transmitted back to the server, where all updates are aggregated by weighted averaging. Finally, the global model is updated and a new round starts until the model converges yang2019federated . Because this process uses high-dimensional raw data as the global model's local input, many intermediaries (e.g., model parameters and their updates) must be transmitted over multiple rounds, which imposes a high communication cost on the system.

Our proposed federated two-stage learning, on the contrary, divides the execution of a model on a per-layer basis between devices and the server. The designed method can be adopted for both training and inference. An overview of our system is shown in Figure 1. Devices learn informative and compact representations of the raw data locally, and the shared global model then uses these representations, rather than the raw data, as its input.

The designed system is divided into two complementary parts: local learning and global aggregation. The first process aims to extract representative and low-dimensional features from the raw data that are important for training the model. The global model then operates on these compact features and can therefore use fewer parameters.
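As a rough illustration of this split, the following PyTorch-style sketch shows a hypothetical device-side encoder that emits smashed data and a server-side model that consumes it. The class names, layer sizes and the 28x28 single-channel input are illustrative assumptions, not the paper's exact architecture.

import torch
import torch.nn as nn

# Hypothetical local model: runs on the device up to the cut layer and
# outputs a compact "smashed" representation of the raw input.
class LocalEncoder(nn.Module):
    def __init__(self, smashed_dim=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.to_smashed = nn.Linear(32 * 7 * 7, smashed_dim)  # for 28x28 inputs

    def forward(self, x):
        h = self.features(x).flatten(1)
        return self.to_smashed(h)          # low-dimensional smashed data

# Hypothetical global model: hosted by the server, consumes smashed data only.
class GlobalModel(nn.Module):
    def __init__(self, smashed_dim=64, num_classes=10):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(smashed_dim, 128), nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, z):
        return self.head(z)                # logits predicted from smashed data

Because the server-side part only ever sees the low-dimensional smashed vectors, both its parameter count and the size of the intermediaries exchanged each round shrink accordingly.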

To illustrate how the system works, we first introduce the local operations in federated two-stage learning with sign-based voting. We then demonstrate how to reduce the number of global model parameters and eliminate potential data leakage. After presenting the global model aggregation process and the pseudo-code of the optimization algorithm, we describe how to perform inference with the trained model in the system.

3.1 Local Split Learning

Each device, holding its own subset of data, learns a representative and low-dimensional feature known as the smashed data. There are many discussions on the choice of smashed data sharma2019expertmatcher , but in general the following properties should be satisfied: (i) it must be representative, capturing the important features of the raw data; (ii) it should be compact compared with the raw data; (iii) it must not be easily inverted back to the raw data, because preserving data privacy is a key concern in federated learning.

In this paper, we find the smashed data by minimizing the logarithm of the distance correlation (DCOR) between the smashed data and the raw data. We show that minimizing DCOR minimizes their Kullback-Leibler divergence, which measures the invertibility of the smashed data in information theory. For simplicity, we use distance covariance (DCOV), an unnormalized form of DCOR vepakomma2019reducing , together with the Kullback-Leibler divergence and cross entropy to build this connection.

Following Vepakomma et al. vepakomma2018supervised , the sample DCOV between the raw data and the smashed data can be expressed in terms of their covariance matrices,

(1)

According to the arithmetic-geometric mean inequality bhatia1993more , we have

(2)

Equation (2) relates the cross-covariance matrix of the smashed data and the raw data to their cross entropy $H(p, q)$, where $p$ and $q$ denote the distributions of the raw data and the smashed data. The KL divergence relates to the cross entropy as

$D_{\mathrm{KL}}(p \,\|\, q) = H(p, q) - H(p) \qquad (3)$

Therefore, combining equations (1), (2) and (3) builds the connection between DCOR and the KL divergence. Moreover, we prove that minimizing DCOR also minimizes the invertibility of the smashed data. This gives the first part of the local loss,

(4)

The first loss, however, satisfies only one of the properties proposed above. The smashed data should also be informative, extracting the important features of the raw data, so we add a supplementary local loss that measures the prediction ability of the smashed data,

(5)

Equations (4) and (5) above allow us to find the smashed data on local devices efficiently. Given suitable hyper-parameters, we train the local model using sign-based stochastic gradient descent, and the local training objective optimizes the model parameters with respect to the local loss function,

(6)

where the system contains m devices, and the two hyper-parameters balance the leakage and precision of the smashed data.
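The following is a minimal PyTorch sketch of such a local objective, assuming a sample (squared) distance-correlation estimate as the leakage term and a cross-entropy term for prediction ability. The function names and the weights alpha and beta are illustrative placeholders and do not reproduce equations (4)-(6) exactly.

import torch
import torch.nn.functional as F

def pairwise_dist(x):
    # Euclidean distance matrix between rows of x, shape (n, n).
    return torch.cdist(x, x, p=2)

def double_center(d):
    # Double centering used in the sample distance covariance.
    return d - d.mean(dim=0, keepdim=True) - d.mean(dim=1, keepdim=True) + d.mean()

def distance_correlation(x, z, eps=1e-9):
    # Squared sample distance correlation between raw data x and smashed data z
    # (both flattened to shape (batch, features)).
    A = double_center(pairwise_dist(x.flatten(1)))
    B = double_center(pairwise_dist(z.flatten(1)))
    dcov_xz = (A * B).mean()
    dcov_xx = (A * A).mean()
    dcov_zz = (B * B).mean()
    return dcov_xz / (dcov_xx * dcov_zz + eps).sqrt()

def local_loss(x, z, logits, labels, alpha=0.5, beta=0.5):
    # Leakage term: log distance correlation between raw and smashed data.
    leakage = torch.log(distance_correlation(x, z) + 1e-9)
    # Utility term: prediction ability of the smashed data.
    task = F.cross_entropy(logits, labels)
    return alpha * leakage + beta * task

Driving the distance-correlation term down makes the smashed data harder to invert, while the cross-entropy term keeps it informative enough for the downstream task.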

3.2 Global Model Aggregation

In traditional federated learning, the server directly exchanges global model parameters and their updates with local devices, which depletes communication bandwidth because the model operates on high-dimensional raw data. The key differences between prototypical federated learning and our proposed mechanism are: (i) Our global model operates on the learnt representative, low-dimensional smashed data. The global model in federated two-stage learning can therefore be much smaller than in traditional federated learning when training on the same objective, and the reduced number of communicated intermediaries makes our system more efficient than prototypical federated learning. (ii) Our sign-based optimization scheme exchanges only the sign of each minibatch stochastic gradient. This compresses client-server communication while maintaining results competitive with other methods.

Server executes:

initialize the global parameters w and an initial sign vector v
for each round t = 1, 2, … do
     n ← max(C · m, 1)
     S_t ← (random set of n clients)
     for each client k ∈ S_t do
         sign(g_k) ← ClientUpdate(k)
     end for
     v ← sign(Σ_{k ∈ S_t} sign(g_k))      (majority vote)
     return v to clients
end for

ClientUpdate(k):

split local data into batches of size B
for each batch b do
     w ← w − η · v      (apply the voted sign received from the server)
     sign(g_k) ← sign of the stochastic gradient of the local loss on b
end for
return sign(g_k) to the server
Algorithm 1 Federated Two-stage Learning with Sign-based Voting. The m devices are indexed by k; B is the local minibatch size, C is the participation level and η is the learning rate.

In this paper, the global model is trained iteratively. At the beginning of the first iteration, the server sends the global model parameters and an initial sign to each device. Each device then runs its own local model followed by the transmitted global model to predict labels. Losses and gradients are then computed between the predicted and true labels on its own subset of data. Given a proper loss function, we have the global loss,

(7)

The total loss is a combination of the local loss and the global loss,

(8)

where the three hyper-parameters need fine-tuning: the first two control the data leakage and prediction ability of the smashed data, and the third balances the updates to the local model. In the overall loss function, the global objective acts as a regularization term that synchronizes the local smashed data across devices. In addition, it penalizes the local objective in case of overfitting.
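A minimal sketch of how such a combined objective might be assembled, reusing the local_loss sketch from Section 3.1; the names alpha, beta and lam are hypothetical placeholders for the paper's hyper-parameters rather than its exact formulation.

import torch.nn.functional as F

def total_loss(x, z, local_logits, global_logits, labels,
               alpha=0.5, beta=0.5, lam=1.0):
    # Local term: leakage plus utility of the smashed data (see the earlier sketch).
    l_local = local_loss(x, z, local_logits, labels, alpha=alpha, beta=beta)
    # Global term: prediction loss of the server-side model on the smashed data,
    # acting as a regularizer that synchronizes smashed data across devices.
    l_global = F.cross_entropy(global_logits, labels)
    return l_local + lam * l_global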

Method # Inter./Device Total Inter.
Federated Two-stage
FedAvg & FedProx
Table 1: Theoretical efficiency comparison between our method and other federated learning algorithms. The proposed Federated Two-stage exchanges the sign of gradients rather than model parameters.
Dataset       Method                 Test Accuracy    # Comm. Rounds    Total Comm.
MNIST         Federated Two-stage          -               437               -
MNIST         FedAvg                       -               689               -
MNIST         FedProx                      -               651               -
CIFAR-10      Federated Two-stage          -               191               -
CIFAR-10      FedAvg                       -               279               -
CIFAR-10      FedProx                      -               253               -
Shakespeare   Federated Two-stage          -               744               -
Shakespeare   FedAvg                       -               936               -
Shakespeare   FedProx                      -               962               -
Table 2: Efficiency comparison between Federated Two-stage, FedAvg and FedProx on MNIST (non-IID), CIFAR-10 (IID) and Shakespeare (non-IID). Our method maintains competitive performance while needing roughly 10x fewer communicated intermediaries. We report the average over 50 runs.

The most common approach to optimizing federated learning is the Federated Averaging algorithm mcmahan2016communication . Each device runs some SGD steps locally and then sends the updated local model back to the server. The coordinating server aggregates the local models by weighted averaging and finally transmits the updated global model back to the devices. However, this algorithm communicates model parameters as 32-bit floating point numbers, which consumes considerable bandwidth and is not efficient. In our system, we instead use a 1-bit compressed gradient (the sign of the gradient), which reduces the communication cost substantially. In each iteration, gradients are computed from the loss between the predicted and true labels. Devices then send the signs of their local gradients back to the server. After receiving the signs from the local devices, the server aggregates them with a majority vote and pushes the 1-bit decision back to every device. Finally, each device updates its model as

$w_{t+1} = w_t - \eta \cdot \mathrm{sign}\big(\textstyle\sum_{k \in S_t} \mathrm{sign}(g_t^{k})\big) \qquad (9)$

where $w_t$ denotes the model parameters, $g_t^{k}$ the stochastic gradient on device $k$, and $\eta$ the learning rate. Complete pseudo-code is shown in Algorithm 1.

3.3 New Devices Inference

The prototypical federated learning mechanism often uses the Federated Averaging algorithm described above. Since it trains only a global model, a new device can directly make predictions with the trained model. In contrast, our proposed method divides the original global model into two parts, so the data of a new device must first be processed by the local model to learn informative representations of its raw data. The low-dimensional smashed data are then passed to the trained global model for prediction. In this paper, we average the logits (the vectors of non-normalized predictions) of all trained local models when dealing with new devices: because new devices may contain unknown data sources and distributions, we cannot simply pass their data to any single local model. The proposed ensemble approach shows competitive results in the experiments compared with the Federated Averaging algorithm.

Figure 2: Performance of the Federated Two-stage, FedProx and FedAvg algorithms on the MNIST, CIFAR-10 and Shakespeare datasets. Federated Two-stage maintains competitive results while using fewer communication rounds.
Figure 3: Amount of communicated intermediaries (GB). Federated Two-stage reduces the amount of communicated intermediaries because it exchanges only the sign of gradients (1 bit) rather than model parameters (32-bit floating point) and has a smaller global model.

4 Experiments

In this section, we first compare the communication efficiency of federated two-stage learning, Federated Averaging and the FedProx algorithm li2018federated . FedProx is inspired by FedAvg and aims to tackle heterogeneity in federated networks. It adds a proximal term to the objective that helps improve the stability of the model. The key insight behind FedProx is that it penalizes large differences between the global model and the current model on each device.

We then design three experiments to evaluate our proposed method. (i) We show that federated two-stage learning with sign-based voting substantially reduces the number of communicated intermediaries on the MNIST lecun1998gradient , CIFAR-10 krizhevsky2009learning and Shakespeare mcmahan2016communication datasets; these datasets let us compare the efficiency of our method with FedAvg and FedProx under their papers' baseline model settings. (ii) We evaluate the universality of our method using data from different distributions (rotated images in CIFAR-10). (iii) We show that the designed method blocks critical information needed to reconstruct the raw data and thus prevents data leakage.

Table 1 expresses the costs in terms of the gradient size, the number of devices, and the number of model parameters and rounds used by the proposed method and by FedAvg and FedProx, respectively. We report the number of intermediaries communicated per device as well as the total number. In each iteration, our system communicates the sign of the gradients, whereas FedAvg and FedProx communicate the model parameters.
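To make the 1-bit versus 32-bit gap concrete, a back-of-the-envelope sketch; the parameter count is illustrative, not taken from the paper.

def per_round_upload_bytes(num_params, bits_per_value):
    # Bytes uploaded by one device in one round.
    return num_params * bits_per_value / 8

num_params = 1_000_000   # illustrative model size
fedavg_bytes = per_round_upload_bytes(num_params, 32)   # full-precision parameters
sign_bytes = per_round_upload_bytes(num_params, 1)      # sign of each gradient entry

print(f"FedAvg per device/round:     {fedavg_bytes / 1e6:.1f} MB")
print(f"Sign-based per device/round: {sign_bytes / 1e6:.3f} MB "
      f"({fedavg_bytes / sign_bytes:.0f}x smaller)")

On top of this per-round saving, the smaller global model in federated two-stage learning further reduces the number of values that need to be exchanged.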

4.1 Efficiency

In the first experiment, we use the MNIST, CIFAR-10 and Shakespeare datasets. For MNIST, we partition the dataset in a non-IID manner: we first sort the data by label, then divide it into 200 shards of size 300 and give 2 shards to each device. Each device therefore has at most 2 classes, a highly non-IID setting. Our local model and hyper-parameter choices follow the Federated Averaging paper mcmahan2016communication . We use 2 convolution layers as our global model. In addition, we average the logits of all trained local models when predicting for new devices at test time. For CIFAR-10, we partition the data across 100 devices, each containing 500 training and 100 testing examples; this sub-experiment uses a balanced, IID setting. We choose a 5-layer network that contains 2 convolution layers, 2 fully connected layers and a layer producing logits. The global model again contains 2 convolution layers, and the testing procedure follows the previous one by averaging all trained local models' logits when new devices pass their data into the system. The Shakespeare dataset is substantially unbalanced because many roles have only a few lines. We filter out clients with fewer than 10k data points, leaving 132 clients in total. We then use 80% of the data for training and amalgamate the remaining 20% on a device as the test set. The model for this dataset contains 2 LSTM layers with 256 units each, and the task is next-character prediction.
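A minimal sketch of this non-IID partition (the standard label-sorted sharding; the helper name and seed are hypothetical).

import numpy as np

def partition_mnist_noniid(labels, num_devices=100, num_shards=200, shard_size=300, seed=0):
    # Sort example indices by label, cut them into contiguous shards, and
    # hand each device two shards, so each device sees at most two classes.
    rng = np.random.default_rng(seed)
    sorted_idx = np.argsort(labels)
    shards = [sorted_idx[i * shard_size:(i + 1) * shard_size] for i in range(num_shards)]
    order = rng.permutation(num_shards)
    return {d: np.concatenate([shards[order[2 * d]], shards[order[2 * d + 1]]])
            for d in range(num_devices)}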

Table 2 shows the results of the efficiency comparison. Averaged over 50 runs, our proposed federated two-stage learning with sign-based voting maintains similar performance while communicating roughly 10x fewer intermediaries, making it more efficient than traditional federated learning algorithms. The key reasons are that (i) we use representative, low-dimensional smashed data rather than raw data as the input of the global model, and (ii) sign-based SGD with majority voting further reduces the communicated intermediaries.

In Figure 2, we evaluate the performance of Federated Two-stage, FedProx and FedAvg on the MNIST, CIFAR-10 and Shakespeare datasets. Our proposed method achieves similar results to FedProx and FedAvg while using fewer communication rounds. Figure 3 shows the amount of communicated intermediaries for the Federated Two-stage, FedProx and FedAvg algorithms. Federated two-stage exchanges the smallest amount of intermediaries because it transmits only the sign of the gradients (1 bit) rather than the model parameters (32-bit floating point) and has a smaller global model than the other methods.

Figure 4: Performance comparison between Federated Two-stage, FedAvg and FedProx on the rotated CIFAR-10 dataset. When dealing with a new device with a different data distribution, Federated Two-stage performs better than traditional federated learning algorithms.
Figure 5: Reconstruction results of the smashed data in Federated Two-stage. The weight on the distance-correlation term increases (0.1, 0.3, 0.5, 0.7 and 0.9) from left to right. Reconstruction from the smashed data becomes harder as this weight increases, because a smaller distance correlation blocks more of the critical information required for reconstruction.
Figure 6: Reduction of distance correlation using the Federated Two-stage, FedProx and FedAvg algorithms. Our method attains a smaller distance correlation than the other algorithms and better protects the system from data leakage.

4.2 Universality

In this experiment, we demonstrate that the proposed method can deal with data from new sources or distributions. We split the CIFAR-10 dataset across 100 devices, each containing 500 training and 100 testing samples. In addition, we introduce a new device with rotated CIFAR-10 examples, containing 1000 training and 200 testing samples. The data distribution on this new device differs from the original training dataset. The model architecture contains two convolution layers, two fully connected layers and a transformation layer. The global model is again a two-convolution-layer network, and the other hyper-parameter settings follow the Federated Averaging paper.
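A possible way to construct such a rotated new device with torchvision is sketched below; the 90-degree angle is an illustrative assumption, since the exact rotation is not specified here.

import torchvision
import torchvision.transforms as T

# Hypothetical setup for the "new device": every image is rotated so that its
# distribution differs from the unrotated training data on the other devices.
rotated_transform = T.Compose([
    T.RandomRotation(degrees=(90, 90)),   # fixed 90-degree rotation
    T.ToTensor(),
])

new_device_data = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=rotated_transform)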

We first train the model on the 100 devices using the proposed method, FedAvg and FedProx. The new device is then introduced and trained until the model converges. Figure 4 shows the accuracy of the Federated Two-stage, FedAvg and FedProx algorithms. Our proposed method outperforms the traditional models because federated two-stage learning uses smashed data to learn representative patterns across different data distributions.

4.3 Privacy Preserving

We design this experiment to evaluate the potential data leakage risk in our federated two-stage learning with sign-based voting framework. In Section 3, we discussed the potential data leakage in communicated intermediaries and analyzed this risk theoretically, proving the connection between distance correlation and the Kullback-Leibler divergence: by minimizing the logarithm of the distance correlation between the smashed data and the raw data, we minimize the Kullback-Leibler divergence, which measures the invertibility of the smashed data. In this experiment, we show that our proposed architecture substantially reduces this data leakage risk on the MNIST dataset. The local model used in the experiment is a CNN with two convolution layers, a fully connected layer and a transformation layer; the global model again contains two convolution layers, and all hyper-parameters follow the previous experiments. Figure 6 shows the reduction in distance correlation between the smashed data and the raw data for the Federated Two-stage, FedProx and FedAvg algorithms. Because FedProx and FedAvg do not contain a cut layer, we use the representations generated by the first convolution layer of their models. Federated two-stage attains a smaller distance correlation than the other algorithms during training and thus better protects the system from data leakage. Even in the first few rounds, the distance correlation is below 0.55, which means the data leakage risk is relatively small.
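As a usage note, the kind of leakage curve reported in Figure 6 can be monitored with the distance_correlation sketch from Section 3.1; the helper below is a hypothetical illustration, where encoder stands for either the cut layer or the first convolution layer of a FedAvg/FedProx model.

import torch

def leakage(encoder, x):
    # Distance correlation between raw inputs and the representation that
    # would be communicated; lower values indicate less invertible intermediaries.
    with torch.no_grad():
        z = encoder(x)
    return distance_correlation(x, z).item()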

In the next sub-experiment, we show reconstruction results from the smashed data in our system. Smashed data is an informative, low-dimensional representation of the raw data, so a decoder can attempt to reconstruct the raw data from it. From equation (6), we can block the critical information that reconstruction needs by placing a larger weight on the distance-correlation term.

Figure 5 shows reconstruction results from the smashed data on the MNIST dataset for different values of this weight. A larger weight blocks more of the information needed for reconstruction and builds a more secure model. However, from equation (8), all three hyper-parameters control the model together, which means a larger weight on the distance-correlation term may not lead to a better model overall.

5 Conclusion

In this paper, we propose a federated two-stage learning framework with sign-based voting that augments prototypical federated learning with a cut layer on local devices and uses sign-based stochastic gradient descent with majority voting on model updates. Our method maintains competitive performance while reducing communicated intermediaries, shows good universality across different data sources, and protects the model from data leakage. Both empirical and theoretical analyses show that the mechanism offers an efficient and privacy-preserving scheme suited to general applications.

Broader Impact

While federated two-stage learning offers a number of potential benefits, providing stronger guarantees on efficiency and privacy remains an open direction for future research. We hope our insights can inspire future work on federated learning, one of the most important approaches to addressing data privacy concerns in society.

References