The rapid growth of machine learning model complexity and the demand for large training data have stimulated interest in distributing the learning task across multiple machines. As an emerging distributed learning paradigm, federated learning (FL) [mcmahan2017communication] allows multiple clients (e.g., mobile devices) to collaboratively learn a shared model in a privacy-preserving way. In contrast to conventional machine learning methods that require all training data to be exposed to a central server, FL allows privacy-sensitive data to be retained on each client. In particular, each client computes an update of the model on its local dataset, and a central server (e.g., a service provider) coordinates the learning process by aggregating the clients’ updates to maintain a global model. This strong privacy guarantee of FL has spurred a broad spectrum of real-world applications in areas like mobile computing [hard2018federated] and telemedicine [sheller2020federated].
Despite its favorable characteristics, FL still faces challenges from non-malicious failures (e.g., noisy data) as well as adversarial attacks (e.g., Byzantine attacks [blanchard2017machine] and backdoor attacks [bagdasaryan2020backdoor, sun2019can]). Moreover, the strong emphasis on clients’ privacy prevents the server from accessing and inspecting clients’ data directly, which makes detecting these failures and attacks a challenging task [kairouz2019advances]. The aggregation rule adopted by the central server is the most crucial component in ensuring the robustness of FL systems. By default, the server aggregates the local model updates by taking the average value as the global model update [mcmahan2017communication]. However, it has been shown that a single faulty/malicious client can impede the convergence of the jointly learned model under this setting [blanchard2017machine].
Recently, several theoretical approaches based on gradient similarity [blanchard2017machine] or robust statistics [chen2017distributed, yin2018byzantine] have been proposed to achieve Byzantine-resilient learning. Although offering provable guarantees, in practice these methods only provide a weak level of tolerance to attacks, and the resulting model can still be significantly influenced by malicious clients. To address this, Bulyan [mhamdi2018hidden] proposes to execute another robust aggregation rule for multiple iterations to provide a stricter convergence guarantee, at the cost of a high computational burden. Other methods attempt to detect and remove malicious clients by estimating each client’s reliability, either by calculating the descendant of the loss function [xie2019zeno] or by projecting clients’ updates into a latent space using a variational autoencoder [li2020learning]. However, these methods require prior knowledge of clients’ data distributions for loss-descendant estimation or autoencoder training, which is hard to satisfy in practice, especially in cross-device FL where clients’ data are private and extremely heterogeneous. Another method [munoz2019byzantine] proposes to adaptively estimate the quality of client updates using a hidden Markov model. However, in order to keep track of and update the reliability score of each client, the server needs to know the mapping between the submitted updates and the clients’ identities, which may lead to serious risks of privacy breaches (e.g., data inference attacks [wang2019beyond]).
In this work, we seek to relax these constraints by proposing a new aggregation strategy that can resist strong adversarial attacks to achieve Byzantine-resilient federated learning. Different from existing studies, we propose to examine the local model updates from a spatial-temporal perspective. From the spatial perspective, we show that at each round of communication, the updates from faulty/malicious clients exhibit certain distinguishable geometric patterns in the parameter space. Leveraging this observation, we can assess the integrity of a client’s model update by inspecting its cosine similarity with all updates and utilize a clustering-based approach to detect and filter out malicious updates. Moreover, to handle malicious clients with time-varying behavior, we propose to adaptively adjust the learning rate at each communication round by comparing the received updates with the update speculated from historical data, i.e., from a temporal perspective. This enables our method to tolerate abrupt and uncertain adversarial activities in the cross-device FL setting with highly unreliable clients. In addition, our method does not rely on prior knowledge of the clients’ data distributions or identities and can therefore be applied along with techniques such as secure shuffling [bittau2017prochlo] and differential privacy [geyer2017differentially] to ensure users’ privacy.
We conduct extensive experiments to evaluate the proposed method under both cross-silo and cross-device settings on public datasets. The results demonstrate that our method achieves better robustness in the presence of noisy, faulty, or malicious clients compared to current state-of-the-art aggregation methods such as Krum [blanchard2017machine], Median, and Trimmed Mean [yin2018byzantine].
II Background and Related Work
II-A Federated Learning
Federated Learning (FL) (or Collaborative Learning) is a distributed learning framework that allows multiple clients to collaboratively train a machine learning model under the coordination of a central server, while keeping their private training data locally on the device, without it being shared or revealed to the server or other clients. Federated learning can be conducted among a small set of reliable clients (cross-silo) or among a large number of mobile and edge devices (cross-device). Let $\mathcal{K} = \{1, \dots, K\}$ denote the set of participating clients, each of which holds a local dataset $D_k$ of $n_k$ data samples. $D = \bigcup_{k \in \mathcal{K}} D_k$ is the joint training dataset and $n = \sum_{k \in \mathcal{K}} n_k$ is the total number of data samples. $F_k(w)$ represents the empirical loss over a model $w$ and dataset $D_k$. The objective of federated learning can be formulated as: $\min_w F(w) = \sum_{k \in \mathcal{K}} \frac{n_k}{n} F_k(w)$.
Initially, the central server randomly initializes a global model $w^0$. Then at each communication round $t$, the following steps are performed to achieve the learning objective, as shown in Figure 1:
Step I: Broadcast Latest Model. The central server broadcasts the latest global model $w^t$ to all the clients (usually in cross-silo FL) or to a subset of clients $\mathcal{S}^t \subseteq \mathcal{K}$ that are selected to participate in this round of training (usually in cross-device FL).
Step II: Clients Compute Local Updates. Each client $k$ computes an update of the model on its local dataset by performing several iterations of gradient descent, $w_k \leftarrow w_k - \eta \nabla F_k(w_k)$ starting from $w_k = w^t$, with $\eta$ being the learning rate; it then reports the local update $\Delta_k^t = w_k - w^t$.
Step III: Aggregate Client Updates. The server updates the global model by aggregating the local updates according to a certain aggregation rule $\mathcal{A}$: $w^{t+1} = w^t + \mathcal{A}(\{\Delta_k^t\})$.
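The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: the function names, the toy gradient function, and the dataset-size weights are our own choices.

```python
import numpy as np

def local_update(w, grad_fn, lr=0.1, steps=5):
    # Step II: a client runs several iterations of gradient descent on
    # its local loss and reports the resulting model delta.
    w_local = w.copy()
    for _ in range(steps):
        w_local = w_local - lr * grad_fn(w_local)
    return w_local - w

def aggregate_round(w, deltas, sizes):
    # Step III: the server combines the reported deltas; here the
    # aggregation rule is a dataset-size-weighted average (FedAvg).
    weights = np.asarray(sizes, dtype=float)
    weights = weights / weights.sum()
    global_delta = sum(wk * d for wk, d in zip(weights, deltas))
    return w + global_delta
```

For example, with two equally sized clients reporting deltas [1, 1] and [3, 3] from the initial model [0, 0], the new global model is [2, 2].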
II-B Byzantine-resilient Aggregation Rules
The most widely-used aggregation rule for communication-efficient FL is Federated Averaging (FedAvg) [mcmahan2017communication], which aggregates the client updates by computing a weighted average: $\Delta^t = \sum_{k} \frac{n_k}{n} \Delta_k^t$. However, FedAvg is not fault-tolerant, and even a single faulty/malicious client can prevent the global model from converging [blanchard2017machine, yin2018byzantine]. To address this, several robust aggregation techniques have been proposed:
Krum [blanchard2017machine]. At each communication round, Krum selects one of the local model updates as the global model update by comparing the similarity between the provided local updates. Suppose $f$ out of $K$ clients are malicious. Krum assigns a score to each local model update $\Delta_i^t$ by computing the sum of Euclidean distances between $\Delta_i^t$ and the $K - f - 2$ neighboring local updates that are closest to $\Delta_i^t$. The $m$ local model updates with the smallest scores are selected and their average is computed as the global model update.
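Krum's scoring rule can be sketched as follows; the neighborhood size K - f - 2 follows [blanchard2017machine], while the function name and the parameter m (the number of selected updates, with m = 1 recovering plain Krum) are our own:

```python
import numpy as np

def krum_select(updates, f, m=1):
    # updates: (K, d) array of local model updates; f: assumed number
    # of malicious clients. Each update is scored by the sum of squared
    # Euclidean distances to its K - f - 2 closest neighbors.
    K = updates.shape[0]
    sq_dists = np.linalg.norm(updates[:, None, :] - updates[None, :, :], axis=-1) ** 2
    scores = []
    for i in range(K):
        neighbors = np.sort(np.delete(sq_dists[i], i))  # exclude self-distance
        scores.append(neighbors[:K - f - 2].sum())
    # Average the m updates with the smallest scores.
    chosen = np.argsort(scores)[:m]
    return updates[chosen].mean(axis=0)
```

With four benign updates near [1, 1] and one large outlier, the outlier accumulates huge neighbor distances and is never selected.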
Median [yin2018byzantine]. Median is a coordinate-wise aggregation rule that considers each model parameter independently. Specifically, to decide the $j$-th parameter of the global model update, the server sorts the $j$-th parameters of the submitted local model updates and takes the median value. Median aggregation can achieve an order-optimal statistical error rate if the loss function is strongly convex.
Trimmed Mean [yin2018byzantine]. Trimmed Mean is another coordinate-wise aggregation rule. At each round of communication, given a trim rate $\beta$ ($0 \le \beta < \frac{1}{2}$), the server first sorts the $j$-th parameters of the submitted local model updates, removes the $\beta K$ smallest and $\beta K$ largest values, and then computes the mean of the remaining values as the $j$-th parameter of the global model update. It is proven that Trimmed Mean can achieve an order-optimal error rate for strongly convex losses if $\alpha \le \beta < \frac{1}{2}$, where $\alpha$ is the ratio of the number of Byzantine clients to the total number of clients.
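Both coordinate-wise rules are easy to state in code; this sketch uses our own function names, with beta denoting the trim rate from [yin2018byzantine]:

```python
import numpy as np

def coordinate_median(updates):
    # For each parameter, take the median of the K submitted values.
    return np.median(updates, axis=0)

def trimmed_mean(updates, beta):
    # For each parameter, drop the beta*K smallest and beta*K largest
    # values, then average the rest (requires 0 <= beta < 1/2).
    K = updates.shape[0]
    k = int(beta * K)
    sorted_vals = np.sort(updates, axis=0)
    return sorted_vals[k:K - k].mean(axis=0)
```

On five one-parameter updates [-50, 1, 2, 3, 100], both rules return 2: the median directly, and the trimmed mean (beta = 0.2) after dropping -50 and 100.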
Other methods. Bulyan [mhamdi2018hidden] iteratively executes another Byzantine-resilient aggregation rule (e.g., Krum) multiple times to achieve enhanced robustness, but it is not scalable due to its high computational cost. Zeno [xie2019zeno] computes a descendant score for each update and only aggregates the top $K - b$ updates with the highest scores, where $K$ is the total number of clients and $b$ is a hyperparameter that needs to be specified in advance and should be no less than the number of malicious clients. A more recent study [li2020learning] proposes to use a variational autoencoder to project client updates into a latent space where malicious updates can be detected. However, this method is based on the assumption that the server has access to data drawn from the same distribution as the clients’ private data to train the autoencoder, which is hard to satisfy in practice. Other studies aim to achieve robust federated learning by identifying and blocking the malicious clients through adaptive model quality estimation [munoz2019byzantine] or clustered federated learning [sattler2020byzantine]. However, these methods require the server to keep track of the identity of each client to maintain a trustworthiness score or to establish the cluster structure, and therefore cannot be applied in scenarios where privacy-preserving techniques (e.g., secure shuffling [bittau2017prochlo]) are used.
Differently, in this work, we aim to design an aggregation scheme that can tolerate attacks or failures in a more dynamic FL scenario while achieving privacy preservation, i.e., without requiring prior knowledge of the number of faulty/malicious clients, the distribution of the clients’ data, or the mapping between the submitted model updates and the clients’ identities. Moreover, different from existing Byzantine-resilient aggregators (e.g., Krum, Median, and Trimmed Mean), our method can tolerate stronger attacks that have a large negative impact on the joint model with few malicious clients, such as targeted data poisoning attacks [tolpegin2020data].
Our algorithm inspects the client updates from two critical perspectives. (1) Spatial perspective: We leverage geometric patterns to filter out malicious updates within each round of communication; and (2) Temporal perspective: We utilize historical data from previous communication rounds to detect temporal outliers.
III-A Spatial Perspective
III-A1 Geometric Property of Malicious Updates
We first perform a preliminary study to compare the distributions of the model updates computed by benign clients and the updates from faulty/malicious clients. We simulate a simple federated learning task with clients, of which are either faulty clients that contain noisy data or malicious clients that perform a Byzantine or label-flipping attack (detailed settings are described in Section IV-A). The learning objective is to jointly train a simple multi-layer perceptron model with one hidden layer of neurons on the MNIST dataset [lecun1998mnist]. We let each client perform iterations of gradient descent with a learning rate of on its local dataset and report the model update. Figure 2 shows the visualization of the clients’ updates in a 2-dimensional space using t-SNE [maaten2008visualizing]. From the plots, we can observe that the malicious updates diverge from the benign updates, causing the aggregated global update to be biased and deviate from the direction of the true gradient, which in turn results in degraded performance of the learned model. On the other hand, however, these divergent model updates produce identifiable patterns that can potentially be utilized to detect and remove anomalous model updates and thereby improve the robustness of the aggregation rule.
III-A2 Clustering-based Anomalous Update Detection
Motivated by this geometric property of the malicious updates, we propose to adopt a clustering-based method to achieve unsupervised anomalous model update detection. Since it has been shown that the different underlying data distributions of clients can be distinguished by inspecting the cosine similarity between their model updates [sattler2020clustered], we use cosine similarity as the metric for computing the affinity matrix. Different from the conventional clustered federated learning framework [sattler2020clustered, sattler2020byzantine], we construct clusters at each communication round, and the cluster structure is not carried over to subsequent rounds after each partition. This disentangles the mapping between a model update and the client’s identity to prevent data inference attacks [wang2019beyond], and ensures that our method is scalable to the cross-device scenario with a large crowd of clients. Specifically, at each communication round $t$, we first construct the affinity matrix $A^t$ prior to the aggregation by computing the pairwise cosine similarities between the different clients’ updates: $A_{ij}^t = \frac{\langle \Delta_i^t, \Delta_j^t \rangle}{\|\Delta_i^t\| \, \|\Delta_j^t\|}$,
where $i, j \in \{1, \dots, K\}$. We then apply agglomerative clustering with complete linkage [mullner2011modern], which starts from a partition of the clients’ updates into singleton clusters and iteratively merges the closest pair of clusters into a new cluster, until only two candidate clusters are left.
Then we compute the largest similarity between the two candidate clusters $C_1$ and $C_2$ as the criterion for partitioning: $s_{\max} = \max_{i \in C_1, j \in C_2} A_{ij}^t$.
The partition is accepted if $s_{\max}$ is less than a preset threshold $\tau$. Based on the assumption that the majority of clients are not faulty/malicious, we consider the larger of the two clusters as the benign cluster $C_b$. If $s_{\max} \ge \tau$, we consider all client updates in this round to be benign. We aggregate the updates that are determined to be benign according to a certain aggregation rule $\mathcal{A}$: $\Delta^t = \mathcal{A}(\{\Delta_k^t\}_{k \in C_b})$.
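The per-round spatial filtering described above can be sketched with SciPy's hierarchical clustering. The cosine affinity, complete linkage, threshold tau, and majority rule follow the text; the helper name and the use of SciPy are our own choices:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

def benign_indices(updates, tau):
    # Affinity matrix: pairwise cosine similarities of client updates.
    norms = np.linalg.norm(updates, axis=1, keepdims=True)
    A = (updates @ updates.T) / (norms * norms.T)
    # Complete-linkage agglomerative clustering on distance 1 - A,
    # stopped when only two candidate clusters remain.
    D = 1.0 - A
    np.fill_diagonal(D, 0.0)
    labels = fcluster(linkage(squareform(D, checks=False), method="complete"),
                      t=2, criterion="maxclust")
    c1 = np.where(labels == 1)[0]
    c2 = np.where(labels == 2)[0]
    # Largest cross-cluster similarity decides whether to split at all.
    s_max = A[np.ix_(c1, c2)].max()
    if s_max >= tau:
        return np.arange(updates.shape[0])   # treat every update as benign
    return c1 if c1.size >= c2.size else c2  # majority cluster is benign
```

With four updates pointing in roughly the same direction and one pointing the opposite way, the cross-cluster similarity is strongly negative, so the split is kept and the majority cluster is returned.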
In our experiments, we choose Median as the default aggregation rule for the proposed algorithm, as it does not require prior knowledge of the number of malicious clients. The subsequent operations are only performed on the aggregated benign update until the next communication round, when a new clustering structure is formed.
III-B Temporal Perspective
Different from the cross-silo FL where the clients are almost always available, in the cross-device FL scenario, the participating clients are usually a large number of mobile or edge devices that are highly unreliable due to their varying battery, usage, or network conditions. To ensure training speed and avoid impacting the user of the device, the server usually only selects a fraction of clients that are available for computing the global update at each communication round. As a result, the number of faulty clients selected in each communication round is dynamic and highly variable. In addition, a client may continue to send genuine updates until some point in the learning process when it is compromised by an adversary. Thus solely relying on spatial patterns is insufficient, especially when facing a sudden violent perturbation.
III-B1 Adaptive Learning Rate Adjustment via Momentum-based Update Speculation
To cope with these time-varying behaviors and achieve temporal robustness, we propose to assess the quality of the aggregated update by comparing it with a speculated value of the update that is predicted from historical statistics. The intuition is that if the current update significantly deviates from previous results, this can indicate an abrupt change in the state of the participating clients (e.g., in the extreme case, all clients involved in the current round are malicious).
To make a speculation of the update using historical data, we take inspiration from momentum, which utilizes past gradients to smooth out the current update and achieve fast and stable convergence. Specifically, we first estimate the gradient from the aggregated update: $g^t = -\Delta^t / \eta$. Then we compute an exponential moving average of the gradient according to $m^t = \gamma \, m^{t-1} + (1 - \gamma) \, g^t$, where $\gamma$ is the decay factor and $m^{t-1}$ can be seen as a speculated value of the gradient based on past updates. The cosine similarity $s^t$ between the estimated gradient $g^t$ and the speculated value $m^{t-1}$ can then be obtained. If $s^t$ falls below a preset threshold, all updates in the current round will be discarded. Otherwise, we update the global model according to $w^{t+1} = w^t - \eta^t g^t$, where $\eta^t$ is the learning rate, which is adaptively adjusted in proportion to $s^t$ based on the initial learning rate $\eta_0$. This indicates that our algorithm takes a small step if $g^t$ and $m^{t-1}$ disagree. The complete procedure of our algorithm is described in Algorithm 1.
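One temporal step can be sketched as below. The gradient estimate g = -delta/eta0, the proportional rule eta = eta0 * s, and a discard threshold of 0 are our assumptions for the quantities whose exact values are elided in the text above:

```python
import numpy as np

def temporal_step(w, delta_agg, m_prev, eta0=0.1, gamma=0.9):
    # Estimate the gradient from the aggregated (benign) update.
    g = -delta_agg / eta0
    # Cosine similarity between the estimate and the speculated
    # gradient m_prev (the exponential moving average so far).
    s = float(g @ m_prev) / (np.linalg.norm(g) * np.linalg.norm(m_prev) + 1e-12)
    if s < 0.0:           # assumed discard threshold
        return w, m_prev  # drop this round's updates entirely
    m = gamma * m_prev + (1.0 - gamma) * g  # update the moving average
    eta = eta0 * s  # small step when g and the speculation disagree
    return w - eta * g, m
```

An update that agrees with the historical direction is applied at nearly the full rate, while one pointing the opposite way is discarded and leaves both the model and the moving average untouched.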
IV-A Experimental Setup
IV-A1 Federated Learning Scenarios
Real-world FL systems in production are usually optimized to keep the number of faulty/malicious clients at a low level using a variety of system-level protections. In this paper, in order to better demonstrate the superiority of the proposed algorithm, unless mentioned otherwise, we consider an extreme malicious case where a large fraction of the clients are malicious. More specifically, we consider the following two federated learning scenarios:
Cross-silo FL: There are clients that continuously participate in every round of communication. We assume that of them are faulty/malicious. This simulates the federated learning scenario that involves a small number of reliable clients such as different organizations.
Cross-device FL: We assume that a total number of clients are involved, of which are faulty/malicious. At each round of communication, only clients are selected randomly to compute the model update. This simulates the federated learning scenario that involves a large number of mobile and edge devices that are unreliable due to varying battery or network conditions.
IV-A2 Baseline Aggregations and Parameter Selection
In each FL scenario, we compare our proposed spatial-temporal pattern analysis (STPA) algorithm with representative baseline methods: Krum [blanchard2017machine], Median [yin2018byzantine], and Trimmed Mean [yin2018byzantine]. For a fair comparison, we carefully choose the parameters of the baseline methods: for Krum, we assume the number of Byzantine updates is known to the server and set it within the range required to be Byzantine-resilient; for Trimmed Mean, we set the trim ratio to be within -, which is the percentage of the simulated Byzantine clients over the total clients. For our STPA algorithm, we set the to , to , and to for the label-flipping setting and for other settings.
IV-A3 Datasets and Models
We conduct our experiments on public datasets: MNIST [lecun1998mnist], Fashion-MNIST (Fashion) [xiao2017fashion], Spambase [hopkins1999spambase], and Cifar-10 [krizhevsky2009learning]. The MNIST and Fashion-MNIST datasets both contain gray-scale images from classes, of which are used for training and the rest are used for testing. The Spambase dataset is a binary classification problem with instances for deciding whether an email is spam or not. We keep the first attributes, which indicate whether a particular word occurs frequently in the e-mail. The dataset is randomly split into training and test sets with a ratio of to . The Cifar-10 dataset contains colour images, with of them being used for training and the rest for testing. For the MNIST and Fashion datasets, we train a convolutional neural network (CNN) with convolutional layers and fully-connected (FC) layers. For Spambase, we train a simple logistic regression (LR) model. For Cifar-10, we train a CNN with convolutional, max-pooling, and FC layers. A summary of the dataset and model configurations is presented in Table I.
Table I: Dataset | # Train | # Test | # Feature | # Class | Model
IV-A4 Adversary Model
For each scenario, we consider the following settings in the experiments:
Normal. In each communication round, all selected clients perform steps of gradient descent on their local datasets at a learning rate of , and report the genuine local update to the central server.
Byzantine. Byzantine clients send model updates that are significantly different from those of genuine clients. In our experiment, instead of performing gradient descent on their local datasets, the faulty/malicious clients compute model updates drawn from a Gaussian distribution with a fixed mean and an isotropic covariance matrix with a fixed standard deviation.
Noisy. For the MNIST, Fashion, and Cifar-10 datasets, prior to the training procedure, we normalize the image data to . When computing updates, uniform noise is added to the data of the selected noisy clients, and the result is then clipped to the interval. For the Spambase dataset, uniform noise is added to the noisy clients’ data, and the values are then clipped to the interval.
Label-flipping. All the training labels of the malicious clients are set to zero. This simulates a strong targeted data poisoning attack, where the adversary’s goal is to bias the global model towards a specific class.
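The three faulty/malicious behaviors can be simulated with a few lines of NumPy. The noise magnitudes, the Gaussian parameters, and the clipping interval [0, 1] are illustrative placeholders, since the exact values are not given above:

```python
import numpy as np

rng = np.random.default_rng(0)

def byzantine_update(shape, sigma=1.0):
    # Byzantine: report Gaussian noise instead of a genuine update
    # (zero mean and sigma are placeholder parameters).
    return rng.normal(0.0, sigma, size=shape)

def noisy_images(x, eps=0.5):
    # Noisy: add uniform noise to images normalized to [0, 1], clip back.
    return np.clip(x + rng.uniform(-eps, eps, size=x.shape), 0.0, 1.0)

def flip_labels(y, target=0):
    # Label-flipping: set every training label to the target class.
    return np.full_like(y, target)
```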
Table II: Scenario | Method | 5% Faulty/Malicious Clients | 10% Faulty/Malicious Clients | 20% Faulty/Malicious Clients | 33% Faulty/Malicious Clients
IV-B Experimental Results
Cross-silo FL Scenario. Figure 3 illustrates the experimental results in the cross-silo scenario. From the results, we can observe that the proposed STPA not only achieves comparable convergence speed in the normal setting, but also remains robust in all faulty/attack settings. Although Krum, Median, and Trimmed Mean are able to achieve satisfactory performance in the Byzantine and noisy settings, they fail to resist the stronger label-flipping attack. Krum has the highest test error under the label-flipping attack on the MNIST and Fashion datasets in the cross-silo scenario. This is because Krum selects the most reliable updates by calculating and comparing Euclidean distances. In the case of a label-flipping attack, all malicious updates are biased towards the same class, which increases their probability of being falsely selected by Krum.
Cross-device FL Scenario. The experimental results in the cross-device scenario are shown in Figure 4. As we can see, similar to the cross-silo scenario, Median, Krum, and Trimmed Mean are still susceptible to the label-flipping attack in the cross-device scenario. Due to the varying number of selected malicious clients, the test errors of these methods are highly variable between communication rounds. Our proposed STPA, however, is able to cope with this dynamic scenario by utilizing the temporal patterns to obtain a stabilized gradient and provide guaranteed convergence. Most noticeably, in the case of the Cifar-10 dataset, we achieve test errors of and against the label-flipping attack in the cross-silo and cross-device scenarios, respectively, which is comparable to the performance when there is no attack, while the other methods cannot converge at all.
Fraction of Faulty/Malicious Clients. To study the impact of the number of faulty/Byzantine clients, we vary the fraction of faulty/Byzantine clients from to for both FL scenarios on the MNIST dataset. For each adversary model per scenario, we record the average test errors and their standard deviations over the last communication rounds. The results are summarized in Table II, with the best results marked in bold. As we can see, our proposed method achieves the lowest test error in out of cases when compared to the baseline methods. In the few cases where other methods perform better, we observe that our test errors are almost at the same level (i.e., the difference is typically within ). It is also worth noting that our method can resist the strong label-flipping attack even with a large number of malicious clients, whereas all other methods fail. These results further confirm the effectiveness, robustness, and generalization of our proposed aggregation method.
Non-IID setting. We conduct experiments in the non-IID federated learning setting, where each client is assigned shards, each containing image samples from a single class. Figure 5 shows the results on the MNIST dataset with clients ( of which are malicious). We observe that all methods show worse performance compared to the IID setting: FedAvg cannot converge in the Byzantine scenario, Median converges slowly in the noisy scenario, and Krum and Median both cannot converge in the label-flipping scenario. Although the proposed algorithm also converges slowly in some scenarios, it still achieves the most stable convergence curve compared to the other methods.
In this work, we propose a new method to achieve Byzantine-resilient FL by analyzing the spatial-temporal patterns of the clients’ updates. By utilizing a clustering-based method, we can detect and exclude incorrect updates in each round of communication. Moreover, to further handle malicious clients with time-varying behaviors, we perform momentum-based update speculation and adaptive learning rate adjustment. Different from existing methods, our method does not rely on prior knowledge of the clients’ data distributions or identities, thereby preserving users’ privacy. We conducted extensive experiments on public datasets with one normal setting and three faulty/attack settings under both cross-silo and cross-device scenarios. The results show that our method achieves enhanced robustness across all settings compared to the baseline methods.