1 Introduction
Federated learning is a distributed machine learning paradigm in which local institutions or devices collaboratively train a shared global model under the orchestration of a central server, while keeping all sensitive private data decentralized kairouz2019advances . Challenges in federated learning include unbalanced and non-IID (not independently and identically distributed) data allocated across an enormous number of devices moreno2012unifying and limited communication bandwidth zhang2013information .
Recent deeper and larger machine learning models devlin2018bert strain the limited communication channels, because traditional federated learning trains a shared global model by communicating parameters and their updates to every device konevcny2016federated . This calls for a new paradigm beyond the prototypical federated learning framework. In this paper, we design a federated two-stage learning framework that augments prototypical federated learning with a cut layer on devices and applies sign-based stochastic gradient descent with majority voting to model updates. Devices with a cut layer split the execution of a model on a per-layer basis, which allows them to learn informative and compact representations of the raw data (smashed data) locally. The global model is then trained with the SIGNSGD algorithm riedmiller1993direct on the low-dimension smashed data from devices.
Compared with existing approaches, our proposed model has several advantages. First, the proposed federated learning scheme is highly efficient. By splitting the execution of a model between devices and the server, local devices learn informative, low-dimension representations of the raw data. The global model therefore needs fewer parameters, which reduces the number of intermediaries that must be communicated. Moreover, SIGNSGD with majority vote bernstein2018signsgd has been shown to alleviate the communication bottleneck by transmitting only the sign of each minibatch stochastic gradient while preserving competitive results on publicly available real-world datasets.
Second, our framework is suited to general application scenarios. One assumption in traditional federated learning is that all data comes from the same source (e.g., text hard2018federated or images), which may not hold in real-world scenarios. Devices may hold multiple sources of data such as videos, images, or text, and a single global model may not be able to handle all of them. Moreover, inference with the trained model is difficult when new devices contain data modalities that were not observed during training. The proposed model, on the contrary, can handle data according to their sources and distributions. It is capable of handling sources of data that have not been observed before and therefore suits general applications.
In addition, the proposed scheme is a privacy-preserving federated learning architecture. Intermediaries communicated between devices and the server in traditional federated learning may still leak information about the raw data, which violates the constraint that raw data must remain private. Our federated two-stage learning mechanism reduces the invertibility of intermediate representations by minimizing the distance correlation szekely2007measuring between the smashed data and the raw data while preserving the model's prediction accuracy. Experiments show that our model yields results superior or comparable to state-of-the-art methods with less leakage of sensitive information.
2 Related Work

Federated Learning
Federated learning enables multiple parties to collaboratively build a machine learning model while keeping their data private yang2019federated . Challenges in federated learning include unbalanced and non-IID data partitions, unreliable devices, and limited communication bandwidth. The most popular optimization method in federated learning is the Federated Averaging algorithm mcmahan2016communication , which aggregates local models as a weighted average. Li et al. li2018federated added a proximal term to the objective to improve the stability of the algorithm. Fully decentralized learning tries to alleviate trust concerns about the server vanhaesebrouck2017decentralized . Caldas et al. caldas2018expanding proposed an algorithm that improves the efficiency of federated learning.
Compressed Optimization
Training large models in a distributed learning setting incurs high communication cost liu2010distributed , so many compressed optimization methods have been proposed to alleviate this challenge. SIGNSGD transmits the sign of each minibatch stochastic gradient instead of the gradient itself, and has been shown to be more communication-efficient than traditional SGD bernstein2018signsgd .
3 Federated Two-stage Learning
Prototypical federated learning builds a global model on the central server and transmits the model parameters to local devices. Each local device computes gradient updates based on its own subset of the high-dimension raw data and updates the model parameters accordingly. The updated parameters are then transmitted back to the server, where all updates are aggregated as a weighted average. Finally, the global model is updated and a new round starts, until the model converges yang2019federated . This process uses the high-dimension raw data as the local input of the global model, so many intermediaries (e.g., model parameters and their updates) must be transmitted over multiple rounds, which imposes high communication cost on the system.
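To make the weighted-average aggregation step above concrete, the following is a minimal NumPy sketch of the server-side averaging; it is only an illustration under simplified assumptions (flattened parameter vectors, all devices participating), and the names `client_weights` and `num_samples` are illustrative rather than our implementation.

```python
import numpy as np

def fedavg_aggregate(client_weights, num_samples):
    """Weighted average of client parameter vectors.

    client_weights: list of 1-D numpy arrays, one flattened parameter
                    vector per participating device.
    num_samples:    list of ints, the number of local examples each
                    device trained on (used as aggregation weights).
    """
    total = float(sum(num_samples))
    stacked = np.stack(client_weights)                # shape: (K, P)
    weights = np.array(num_samples, dtype=float) / total
    return (weights[:, None] * stacked).sum(axis=0)   # weighted average

# Example: three devices with different amounts of local data.
updates = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
counts = [100, 300, 600]
global_params = fedavg_aggregate(updates, counts)     # -> array([4.0, 5.0])
```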
Our proposed federated two-stage learning, on the contrary, divides the execution of a model on a per-layer basis between devices and the server. The designed method can be used for both training and inference. An overview of our system is shown in Figure 1. Devices learn informative and compact representations of the raw data locally, and the shared global model then uses these representations, rather than the raw data, as its input.
The designed system is divided into two complementary parts: local learning and global aggregation. The first stage extracts representative, low-dimension features from the raw data that are important for training the model. The global model then operates on these compact features, which reduces its number of parameters.
To illustrate how the system works, we first introduce the local operations in federated two-stage learning with sign-based voting. We then show how to reduce the number of global model parameters and eliminate potential data leakage. After describing the global model aggregation process and providing the pseudo-code of the optimization algorithm, we explain how to run inference with the trained model.
3.1 Local Split Learning
For each device $k$ with a subset of data $X_k$, the device learns a representative, low-dimension feature $Z_k$, also known as smashed data. There are many discussions on the choice of smashed data sharma2019expertmatcher , but in general the following properties should be satisfied: (i) it must be representative, capturing the important features of the raw data $X_k$; (ii) it should be compact compared with $X_k$; (iii) it must not be easily invertible back to $X_k$, because preserving data privacy is a key concern in federated learning.
In this paper, we find the smashed data by minimizing the logarithm of the distance correlation (DCOR) between the smashed data and the raw data. We show that minimizing DCOR also minimizes their Kullback-Leibler divergence, which is an information-theoretic measure of the invertibility of the smashed data. For simplicity, we use the distance covariance (DCOV), an unnormalized form of DCOR vepakomma2019reducing , together with the Kullback-Leibler divergence and cross entropy to build this connection. Following Vepakomma et al. vepakomma2018supervised , the sample DCOV between $X_k$ and $Z_k$ can be computed from their double-centered pairwise distance matrices $A$ and $B$,

$$\mathrm{DCOV}^2(X_k, Z_k) = \frac{1}{n^2} \sum_{j=1}^{n} \sum_{l=1}^{n} A_{jl} B_{jl}. \qquad (1)$$
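For illustration, a minimal NumPy sketch of the sample distance covariance in (1) and its normalized counterpart (DCOR); the function names and shapes are illustrative rather than our exact implementation.

```python
import numpy as np

def _double_centered_distances(samples):
    """Pairwise Euclidean distance matrix, double-centered (Szekely et al.)."""
    diff = samples[:, None, :] - samples[None, :, :]
    d = np.sqrt((diff ** 2).sum(-1))
    return d - d.mean(0, keepdims=True) - d.mean(1, keepdims=True) + d.mean()

def distance_covariance_sq(x, z):
    """Sample squared distance covariance between X and Z, as in equation (1)."""
    a, b = _double_centered_distances(x), _double_centered_distances(z)
    return (a * b).mean()

def distance_correlation(x, z):
    """Normalized distance correlation in [0, 1]."""
    dcov_xz = distance_covariance_sq(x, z)
    denom = np.sqrt(distance_covariance_sq(x, x) * distance_covariance_sq(z, z))
    return 0.0 if denom == 0 else np.sqrt(dcov_xz / denom)

# Toy check: a strongly dependent "smashed" projection has DCOR close to 1.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8))        # raw data batch
z = x @ rng.normal(size=(8, 2))     # smashed linear projection
print(distance_correlation(x, z))
```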
Applying the arithmetic-geometric mean inequality bhatia1993more to (1) yields a bound, (2), that relates the DCOV to the cross-covariance matrix of $X_k$ and $Z_k$ and to their cross entropy $H(X_k, Z_k)$. The KL divergence relates to the cross entropy as

$$D_{\mathrm{KL}}(p \,\|\, q) = H(p, q) - H(p). \qquad (3)$$
Combining (1), (2), and (3) therefore connects DCOR to the KL divergence, and we prove that minimizing DCOR also minimizes the invertibility of the smashed data. This gives the first part of the local loss,

$$\mathcal{L}_{\mathrm{DCOR}}^{(k)} = \log \mathrm{DCOV}^2(X_k, Z_k). \qquad (4)$$
The first loss $\mathcal{L}_{\mathrm{DCOR}}^{(k)}$, however, satisfies only one of the properties proposed above. The smashed data should also be informative, i.e., it must capture the important features of $X_k$. We therefore add a local supplementary loss $\mathcal{L}_{\mathrm{sup}}^{(k)}$, (5), which measures the prediction ability of the smashed data (for example, a cross-entropy loss between the labels predicted from $Z_k$ and the true labels $Y_k$).
The losses (4) and (5) above allow us to find the smashed data on local devices efficiently. Given suitable hyper-parameters, we train the local model using sign-based stochastic gradient descent, and the local training objective optimizes the local model parameters $\theta_k$ with respect to the local loss function,

$$\min_{\theta_k} \;\; \alpha_1 \log \mathrm{DCOV}^2(X_k, Z_k) + \alpha_2\, \mathcal{L}_{\mathrm{sup}}^{(k)}, \qquad k = 1, \dots, K, \qquad (6)$$

where the system contains $K$ devices, and $\alpha_1$ and $\alpha_2$ are hyper-parameters that balance the leakage and the precision of the smashed data.
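A minimal PyTorch-style sketch of the local objective in (6), assuming a simple linear encoder and classifier; the module names (`encoder`, `classifier`) and the hyper-parameter values are illustrative, not our exact architecture.

```python
import torch
import torch.nn.functional as F

def dcov_sq(x, z):
    """Squared sample distance covariance for a batch (tensor version of eq. (1))."""
    def centered(d):
        return d - d.mean(0, keepdim=True) - d.mean(1, keepdim=True) + d.mean()
    a = centered(torch.cdist(x, x))
    b = centered(torch.cdist(z, z))
    return (a * b).mean()

def local_loss(encoder, classifier, x, y, alpha1=0.1, alpha2=1.0):
    """Local objective of equation (6): leakage term plus supervised term."""
    x_flat = x.flatten(1)                              # raw data, one row per example
    z = encoder(x_flat)                                # smashed data Z_k
    leakage = torch.log(dcov_sq(x_flat, z) + 1e-8)     # alpha1-weighted DCOV term, eq. (4)
    supervised = F.cross_entropy(classifier(z), y)     # alpha2-weighted prediction term, eq. (5)
    return alpha1 * leakage + alpha2 * supervised

# Toy usage with illustrative module sizes (MNIST-like shapes).
encoder = torch.nn.Linear(784, 32)
classifier = torch.nn.Linear(32, 10)
x, y = torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,))
loss = local_loss(encoder, classifier, x, y)
loss.backward()
```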
3.2 Global Model Aggregation
The traditional federated learning server directly exchanges global model parameters and their updates with local devices, which consumes substantial communication bandwidth because the global model operates on high-dimension raw data. Key differences between prototypical federated learning and our proposed mechanism are: (i) our global model operates on the learnt representative, low-dimension smashed data, so the global model in federated two-stage learning can be much smaller than in traditional federated learning when training on the same objective, and the number of communicated intermediaries shrinks accordingly, making our system more efficient; (ii) our sign-based optimization scheme exchanges only the sign of each minibatch stochastic gradient, which compresses client-server communication while maintaining results competitive with other methods.
Algorithm 1: Federated two-stage learning with sign-based voting (server execution and ClientUpdate(k) procedures; see Section 3.2).
In this paper, the global model is trained iteratively. The server sends the global model parameters and an initial sign vector to each device at the beginning of the first iteration. Each device then runs its own local model together with the transmitted global model to predict labels. Losses and gradients are then computed between the predicted labels and the true labels on the device's own subset of data. Given a proper loss function, we have the global loss,
$$\mathcal{L}_{\mathrm{global}}^{(k)} = \ell\big(g(Z_k),\, Y_k\big), \qquad (7)$$

where $g$ denotes the shared global model and $\ell$ the chosen loss function.
The total loss is a combination of the local loss and the global loss,

$$\mathcal{L}^{(k)} = \alpha_1 \log \mathrm{DCOV}^2(X_k, Z_k) + \alpha_2\, \mathcal{L}_{\mathrm{sup}}^{(k)} + \lambda\, \mathcal{L}_{\mathrm{global}}^{(k)}, \qquad (8)$$

where $\alpha_1$, $\alpha_2$, and $\lambda$ are hyper-parameters that require fine tuning. $\alpha_1$ and $\alpha_2$ control the data leakage and the prediction ability of the smashed data, and $\lambda$ balances the updates to the local model. In the overall loss function, the global objective acts as a regularization term that synchronizes the local smashed data across devices, and it also penalizes the local objective to prevent overfitting.
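For illustration, a short sketch of the combined objective in (8), reusing `dcov_sq`, `encoder`, and `classifier` from the sketch in Section 3.1; `global_model` and `lam` are illustrative names, not our exact implementation.

```python
import torch
import torch.nn.functional as F

def total_loss(encoder, classifier, global_model, x, y,
               alpha1=0.1, alpha2=1.0, lam=1.0):
    """Combined objective of equation (8): local terms plus the global term."""
    x_flat = x.flatten(1)
    z = encoder(x_flat)                                    # smashed data Z_k
    leakage = torch.log(dcov_sq(x_flat, z) + 1e-8)         # equation (4)
    local_sup = F.cross_entropy(classifier(z), y)          # equation (5)
    global_sup = F.cross_entropy(global_model(z), y)       # equation (7)
    return alpha1 * leakage + alpha2 * local_sup + lam * global_sup
```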
Table 1: Number of communicated intermediaries per device and in total (notation defined in Section 4).

Method | # Inter. / Device | Total Inter.
---|---|---
Federated Two-stage | |
FedAvg & FedProx | |
Table 2: Efficiency comparison of Federated Two-stage, FedAvg, and FedProx on MNIST, CIFAR-10, and Shakespeare.

Dataset | Method | Test Accuracy | # Comm. Rounds | Total Comm.
---|---|---|---|---
MNIST | Federated Two-stage | | 437 |
MNIST | FedAvg | | 689 |
MNIST | FedProx | | 651 |
CIFAR-10 | Federated Two-stage | | 191 |
CIFAR-10 | FedAvg | | 279 |
CIFAR-10 | FedProx | | 253 |
Shakespeare | Federated Two-stage | | 744 |
Shakespeare | FedAvg | | 936 |
Shakespeare | FedProx | | 962 |
The most common approach to optimizing federated learning is the Federated Averaging algorithm mcmahan2016communication . Each device runs several SGD steps locally and then sends the updated local model back to the server. The coordinating server aggregates the local models as a weighted average and transmits the updated global model back to the devices. However, this algorithm communicates model parameters as 32-bit floating point numbers, which consumes considerable bandwidth and is not efficient. In our system, we instead use 1-bit compressed gradients (the sign of each gradient), which substantially reduces communication cost. In each iteration, gradients are computed from the loss between the predicted labels and the true labels. Each device then sends the sign of its local gradients back to the server. After receiving the signs from the local devices, the server aggregates them with a majority vote and pushes the 1-bit decision back to every device. Finally, each device updates its model as
$$\theta_{t+1} = \theta_t - \eta \cdot \mathrm{sign}\left(\sum_{k=1}^{K} \mathrm{sign}\big(g_t^{(k)}\big)\right), \qquad (9)$$

where $\theta$ denotes the model parameters, $g_t^{(k)}$ is the gradient of the $k$-th device at iteration $t$, and $\eta$ is the learning rate. The complete pseudo-code is shown in Algorithm 1.
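As a concrete illustration of the round described above, a minimal NumPy sketch of one sign-based voting iteration under simplified assumptions (every device participates and local gradients are already computed); it is not the exact Algorithm 1, and the function names are illustrative.

```python
import numpy as np

def client_update(local_gradient):
    """Device side: transmit only the sign of the local gradient (1 bit per entry)."""
    return np.sign(local_gradient)

def server_vote(sign_list):
    """Server side: element-wise majority vote over the received signs."""
    return np.sign(np.sum(sign_list, axis=0))

def apply_update(params, vote, lr=0.01):
    """Device side: update rule of equation (9)."""
    return params - lr * vote

# One toy round with K = 3 devices and a 4-dimensional model.
params = np.zeros(4)
local_grads = [np.array([0.2, -0.5, 0.1, -0.3]),
               np.array([0.4, 0.2, -0.1, -0.2]),
               np.array([0.1, -0.3, 0.3, -0.1])]
vote = server_vote([client_update(g) for g in local_grads])
params = apply_update(params, vote)   # each coordinate moves by +/- lr (or 0 on ties)
```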
3.3 New Devices Inference
The prototypical federated learning mechanism typically uses the Federated Averaging algorithm described above. Since it trains only a single global model, a new device can directly make predictions with the trained model. Our proposed method, in contrast, divides the original model into two parts, so the data of a new device must first be processed by a local model to learn informative representations of its raw data. The resulting low-dimension smashed data are then passed to the trained global model for prediction. Because new devices may contain unknown data sources and distributions, we cannot simply assign them to any single local model. In this paper, we therefore average the logits (the vector of non-normalized predictions) of all trained local models when dealing with new devices. This ensemble approach shows results competitive with the Federated Averaging algorithm in our experiments.
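One way to realize the ensemble described above is sketched below in PyTorch-style code: the new device's data is passed through every trained local encoder, the resulting smashed data is scored by the global model, and the logits are averaged. The module interfaces and shapes are illustrative assumptions, not our exact implementation.

```python
import torch

def new_device_predict(x_new, local_encoders, global_model):
    """Ensemble inference for a new device: average the logits obtained from
    the smashed data produced by every trained local encoder."""
    logits = [global_model(enc(x_new.flatten(1))) for enc in local_encoders]
    return torch.stack(logits).mean(dim=0).argmax(dim=1)

# Toy usage with illustrative shapes (MNIST-like input, 10 classes).
local_encoders = [torch.nn.Linear(784, 32) for _ in range(5)]
global_model = torch.nn.Linear(32, 10)
x_new = torch.randn(8, 1, 28, 28)
preds = new_device_predict(x_new, local_encoders, global_model)
```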
4 Experiments
In this section, we first compare the communication efficiency of federated two-stage learning with the Federated Averaging and FedProx algorithms li2018federated . FedProx is inspired by FedAvg and aims to tackle heterogeneity in federated networks. It adds a proximal term to the objective that improves the stability of the model; the key idea is to penalize large deviations between the current local model and the global model on each device.
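The snippet below is a standard formulation of FedProx's proximal term, shown only to illustrate this comparison baseline (not our method); the strength `mu` and the toy layers are illustrative.

```python
import torch

def fedprox_penalty(local_params, global_params, mu=0.01):
    """Proximal term (mu / 2) * ||w - w_global||^2 added to each local objective."""
    return 0.5 * mu * sum(
        torch.sum((w - wg.detach()) ** 2)
        for w, wg in zip(local_params, global_params)
    )

# Toy usage: penalize a local linear layer for drifting from the global one.
local_layer, global_layer = torch.nn.Linear(32, 10), torch.nn.Linear(32, 10)
penalty = fedprox_penalty(local_layer.parameters(), global_layer.parameters())
```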
We then design three experiments to evaluate the proposed method: (i) we show that federated two-stage learning with sign-based voting substantially reduces the communicated intermediaries on the MNIST lecun1998gradient , CIFAR-10 krizhevsky2009learning , and Shakespeare mcmahan2016communication datasets, comparing the efficiency of our method against FedAvg and FedProx using the baseline model settings from their papers; (ii) we evaluate the universality of our method using data from different distributions (rotated CIFAR-10); and (iii) we show that the designed method blocks the information critical for reconstructing the raw data and thus prevents data leakage.
In Table 1, $G$ is the gradient size, $N$ is the number of devices, $P_1$ and $T_1$ stand for the number of model parameters and the number of rounds used in the proposed method, and $P_2$ and $T_2$ refer to the number of model parameters and rounds used in the FedAvg and FedProx algorithms. We report the number of intermediaries communicated per device as well as the total number. In each iteration, our system communicates only the sign of the gradients, whereas FedAvg and FedProx communicate the model parameters.
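To make the 1-bit versus 32-bit saving concrete, a small back-of-the-envelope computation follows; the parameter counts and gradient sizes are purely illustrative and are not the values measured in our experiments.

```python
# Illustrative, not the paper's measured numbers: per-round payload per device.
P2 = 1_000_000   # parameters exchanged by FedAvg/FedProx (32-bit floats each)
G = 50_000       # gradient entries whose signs our scheme exchanges (1 bit each)

fedavg_bits_per_round = 32 * P2
twostage_bits_per_round = 1 * G
print(fedavg_bits_per_round / twostage_bits_per_round)  # 640x smaller per round here
```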
4.1 Efficiency
In the first experiment, we use the MNIST, CIFAR-10, and Shakespeare datasets. For MNIST, we partition the dataset in a non-IID manner: we first sort the data by label, divide it into 200 shards of size 300, and give 2 shards to each device. Each device therefore holds at most 2 classes, a highly non-IID setting. Our local model and hyper-parameter choices follow the Federated Averaging paper mcmahan2016communication , and we use 2 convolution layers as the global model. As in Section 3.3, we average the logits of all trained local models when predicting for new devices at test time. For CIFAR-10, we partition the data across 100 devices, each containing 500 training and 100 testing examples; this sub-experiment uses a balanced and IID setting. We choose a 5-layer local network with 2 convolution layers, 2 fully connected layers, and a layer that produces logits; the global model again contains 2 convolution layers, and the testing procedure follows the previous one by averaging the logits of all trained local models when new devices pass their data into the system. The Shakespeare dataset is substantially unbalanced because many roles have only a few lines. We filter out the clients with fewer than 10k data points, obtaining 132 clients in total, then assign 80% of the data to the training set and amalgamate the remaining 20% on one device as the test set. The model for this dataset contains 2 LSTM layers with 256 nodes each, and the task is to predict the next character.
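A minimal NumPy sketch of the sort-by-label shard partition described above (200 shards of 300 examples, 2 shards per device); the function and variable names are illustrative, not our exact preprocessing code.

```python
import numpy as np

def noniid_shard_partition(labels, num_shards=200, shards_per_device=2, seed=0):
    """Sort examples by label, split into shards, and hand each device a few
    shards, so most devices see only a couple of classes. Returns index arrays."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels)                   # sort examples by label
    shards = np.array_split(order, num_shards)   # 200 shards of ~300 examples
    shard_ids = rng.permutation(num_shards)
    num_devices = num_shards // shards_per_device
    return [np.concatenate([shards[s] for s in
                            shard_ids[d * shards_per_device:(d + 1) * shards_per_device]])
            for d in range(num_devices)]

# Toy usage on fake MNIST-sized labels (60,000 examples, 10 classes).
labels = np.random.default_rng(0).integers(0, 10, size=60_000)
device_indices = noniid_shard_partition(labels)
assert len(device_indices) == 100
```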
Table 2 shows the results of the efficiency comparison. Averaged over 50 runs, our proposed federated two-stage learning with sign-based voting maintains similar performance while communicating far fewer intermediaries, making it more efficient than the traditional federated learning algorithms. The key reasons are: (i) we use the representative, low-dimension smashed data rather than the raw data as the input of the global model, and (ii) sign-based SGD with majority vote further reduces the communicated intermediaries.
In Figure 2, we evaluate the performance of Federated Two-stage, FedProx, and FedAvg on the MNIST, CIFAR-10, and Shakespeare datasets. Our proposed method achieves results similar to FedProx and FedAvg while using fewer communication rounds. Figure 3 shows the amount of communicated intermediaries for Federated Two-stage, FedProx, and FedAvg. The federated two-stage algorithm exchanges the smallest amount of intermediaries because it transmits only the sign of gradients (1 bit) rather than the model parameters (32-bit floating point numbers) and uses a smaller global model than the other methods.
4.2 Universality
In this experiment, we demonstrate that the proposed method can handle data from new sources or distributions. We split the CIFAR-10 dataset across 100 devices, each containing 500 training and 100 testing samples. In addition, we introduce a new device with rotated CIFAR-10 examples, containing 1000 training and 200 testing samples, whose data distribution differs from the original training dataset. The local model architecture contains two convolution layers, two fully connected layers, and a transformation layer, and the global model is again a two-convolution-layer network. Other hyper-parameter settings follow the Federated Averaging paper.
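A minimal sketch of how such a rotated-distribution device could be constructed with torchvision; the fixed 90-degree rotation and the subset size are illustrative assumptions, since the exact rotation angle is not central to the setup above.

```python
import torch
from torchvision import datasets, transforms

# Rotated CIFAR-10 data for the new device; the 90-degree angle is illustrative.
rotate = transforms.Compose([
    transforms.RandomRotation(degrees=(90, 90)),   # deterministic 90-degree rotation
    transforms.ToTensor(),
])
rotated_cifar = datasets.CIFAR10(root="./data", train=True, download=True,
                                 transform=rotate)

# Keep only 1000 training examples for the new device, as in the setup above.
new_device_train = torch.utils.data.Subset(rotated_cifar, range(1000))
loader = torch.utils.data.DataLoader(new_device_train, batch_size=32, shuffle=True)
```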
We first train the model on the 100 devices using the proposed method, FedAvg, and FedProx. The new device is then introduced and training continues until the model converges. Figure 4 shows the accuracy of Federated Two-stage, FedAvg, and FedProx. Our proposed method outperforms the traditional models because federated two-stage learning uses the smashed data to capture representative patterns across different data distributions.
4.3 Privacy Preserving
We design this experiment to evaluate the potential data leakage risk in our federated two-stage learning with sign-based voting framework. In Section 3 we discussed the potential data leakage in communicated intermediaries and analyzed this risk theoretically: we established the connection between distance correlation and Kullback-Leibler divergence, so that minimizing the logarithm of the distance correlation between the smashed data and the raw data also minimizes the KL divergence, a measure of the invertibility of the smashed data. In this experiment, we show that the proposed architecture substantially reduces this leakage risk on the MNIST dataset. The local model contains a CNN with two convolution layers, a fully connected layer, and a transformation layer; the global model again contains two convolution layers, and all hyper-parameters follow the previous experiments. Figure 6 shows the reduction in distance correlation between the smashed data and the raw data for Federated Two-stage, FedProx, and FedAvg. Because FedProx and FedAvg do not contain a cut layer, we use the representations generated by the first convolution layer of their models. Federated two-stage attains a smaller distance correlation than the other algorithms during training and thus protects the system from data leakage better. Even in the first few rounds, the distance correlation is below 0.55, which means the data leakage risk is relatively small.
In the next sub-experiment, we show reconstruction results from the smashed data in our system. The smashed data is an informative, low-dimension representation of the raw data, so a decoder can attempt to reconstruct the raw data from it. From equation (6), we can block the critical information needed for reconstruction by using a larger $\alpha_1$.
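For illustration, a minimal sketch of the kind of decoder an adversary might train to invert intercepted smashed data; the architecture, optimizer, and training loop are illustrative assumptions rather than the reconstruction setup used in our experiments.

```python
import torch
import torch.nn.functional as F

# Illustrative attacker: train a decoder to map smashed data back to raw inputs.
decoder = torch.nn.Sequential(
    torch.nn.Linear(32, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 784),           # reconstruct a flattened 28x28 image
)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)

def attack_step(encoder, x):
    """One reconstruction step: the lower the MSE, the more invertible Z_k is."""
    with torch.no_grad():
        z = encoder(x.flatten(1))         # intercepted smashed data
    loss = F.mse_loss(decoder(z), x.flatten(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Toy usage with an illustrative linear encoder standing in for the cut layer.
encoder = torch.nn.Linear(784, 32)
x = torch.randn(16, 1, 28, 28)
print(attack_step(encoder, x))
```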
Figure 5 shows reconstruction results from the smashed data on MNIST for different values of $\alpha_1$. A larger $\alpha_1$ blocks more of the information needed for reconstruction and yields a more secure model. However, from equation (8), all three hyper-parameters $\alpha_1$, $\alpha_2$, and $\lambda$ jointly control the model, so a larger $\alpha_1$ does not necessarily lead to a better model.
5 Conclusion
In this paper, we propose a federated two-stage learning framework with sign-based voting that augments prototypical federated learning with a cut layer on local devices and applies sign-based stochastic gradient descent with majority voting to model updates. Our method maintains competitive performance while reducing the communicated intermediaries, generalizes across different data sources, and protects the model from data leakage. Both empirical and theoretical analyses show that the proposed mechanism offers an efficient, privacy-preserving scheme suited to general applications.
Broader Impact
While federated two-stage learning offers a number of potential benefits, providing stronger guarantees on efficiency and privacy remains an emerging direction for future research. We hope our insights can inspire future work on federated learning, which is one of the most important approaches to addressing data privacy concerns in society.
References
- (1) Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli, and Anima Anandkumar. signsgd: Compressed optimisation for non-convex problems. arXiv preprint arXiv:1802.04434, 2018.
- (2) Rajendra Bhatia and Chandler Davis. More matrix forms of the arithmetic-geometric mean inequality. SIAM Journal on Matrix Analysis and Applications, 14(1):132–136, 1993.
- (3) Sebastian Caldas, Jakub Konečny, H Brendan McMahan, and Ameet Talwalkar. Expanding the reach of federated learning by reducing client resource requirements. arXiv preprint arXiv:1812.07210, 2018.
- (4) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
- (5) Andrew Hard, Kanishka Rao, Rajiv Mathews, Swaroop Ramaswamy, Françoise Beaufays, Sean Augenstein, Hubert Eichner, Chloé Kiddon, and Daniel Ramage. Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604, 2018.
- (6) P. Kairouz, H. McMahan, B. Avent, A. Bellet, et al. Advances and open problems in federated learning. arXiv preprint arXiv:1912.04977, 2019.
- (7) Jakub Konečnỳ, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016.
- (8) Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
- (9) Yann LeCun, Léon Bottou, Yoshua Bengio, Patrick Haffner, et al. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
- (10) Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127, 2018.
- (11) Keqin Liu and Qing Zhao. Distributed learning in multi-armed bandit with multiple players. IEEE Transactions on Signal Processing, 58(11):5667–5681, 2010.
- (12) H Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, et al. Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629, 2016.
- (13) Jose G Moreno-Torres, Troy Raeder, RocíO Alaiz-RodríGuez, Nitesh V Chawla, and Francisco Herrera. A unifying view on dataset shift in classification. Pattern Recognition, 45(1):521–530, 2012.
- (14) Martin Riedmiller and Heinrich Braun. A direct adaptive method for faster backpropagation learning: The RPROP algorithm. In Proceedings of the IEEE International Conference on Neural Networks, volume 1993, pages 586–591. San Francisco, 1993.
- (15) Vivek Sharma, Praneeth Vepakomma, Tristan Swedish, Ken Chang, Jayashree Kalpathy-Cramer, and Ramesh Raskar. Expertmatcher: Automating ml model selection for clients using hidden representations. arXiv preprint arXiv:1910.03731, 2019.
- (16) G. Székely, M. Rizzo, N. Bakirov, et al. Measuring and testing dependence by correlation of distances. The Annals of Statistics, 35(6):2769–2794, 2007.
- (17) Paul Vanhaesebrouck, Aurélien Bellet, and Marc Tommasi. Decentralized collaborative learning of personalized models over networks. 2017.
- (18) Praneeth Vepakomma, Otkrist Gupta, Abhimanyu Dubey, and Ramesh Raskar. Reducing leakage in distributed deep learning for sensitive health data. arXiv preprint arXiv:1812.00564, 2019.
- (19) Praneeth Vepakomma, Chetan Tonde, Ahmed Elgammal, et al. Supervised dimensionality reduction via distance correlation maximization. Electronic Journal of Statistics, 12(1):960–984, 2018.
- (20) Qiang Yang, Yang Liu, Tianjian Chen, and Yongxin Tong. Federated machine learning: Concept and applications. ACM Transactions on Intelligent Systems and Technology (TIST), 10(2):12, 2019.
- (21) Yuchen Zhang, John Duchi, Michael I Jordan, and Martin J Wainwright. Information-theoretic lower bounds for distributed statistical estimation with communication constraints. In Advances in Neural Information Processing Systems, pages 2328–2336, 2013.