FedMAX
Source code for ECML-PKDD (2020) paper: FedMAX: Mitigating Activation Divergence for Accurate and Communication-Efficient Federated Learning
In this paper, we identify a new phenomenon called activation-divergence which occurs in Federated Learning (FL) due to data heterogeneity (i.e., data being non-IID) across multiple users. Specifically, we argue that the activation vectors in FL can diverge, even if subsets of users share a few common classes with data residing on different devices. To address the activation-divergence issue, we introduce a prior based on the principle of maximum entropy; this prior assumes minimal information about the per-device activation vectors and aims at making the activation vectors of same classes as similar as possible across multiple devices. Our results show that, for both IID and non-IID settings, our proposed approach results in better accuracy (due to the significantly more similar activation vectors across multiple devices), and is more communication-efficient than state-of-the-art approaches in FL. Finally, we illustrate the effectiveness of our approach on a few common benchmarks and two large medical datasets.
Large amounts of data are increasingly generated nowadays on edge devices, such as phones, tablets, and wearable devices. If properly used, machine learning (ML) models trained using this data can significantly improve the intelligence of such devices [5]. However, since data on such personal devices is highly sensitive, training ML models by sending the users’ local data to a centralized server clearly involves significant privacy risks. Other examples of private datasets include personal medical records which must not be shared with third parties. Hence, in order to enable intelligence for these privacy-critical applications, Federated Learning (FL) has become the de facto paradigm for training ML models on local devices without sending data to the cloud [7, 8].
As the state-of-the-art approach for FL, Federated Averaging (FedAvg) [1] simply runs several local training epochs on a randomly selected subset of devices; these training epochs utilize only local data available on any user’s device. After local training, the models (not the local data!) are sent over to a server via a communication round; the server then averages all the parameters of these local models to update a global model. Unfortunately, FedAvg is not designed to handle the statistical heterogeneity in federated settings, i.e., when data is not independent and identically distributed (non-IID) across the different devices. Not surprisingly, it has been recently reported that FedAvg can incur significant loss of accuracy when data is non-IID [2, 3].
To deal with such non-IID settings, one approach called “data-sharing strategy” distributes global data across the local devices, such that the test accuracy can increase by making data look more IID [2, 9]. However, obtaining this common global data is usually problematic in practice. Another approach called FedProx [4] targets the weight-divergence problem, i.e., the local weights diverge from the global model due to non-IID data at local devices (hence, the updates can go in different directions at different local devices).
In this paper, we first identify a new phenomenon called activation-divergence and argue that the activation vectors in FL can diverge even if a subset of users share a few common classes of data. Since the activation vectors directly contribute to the model’s accuracy, making them as similar as possible across all devices should become an important objective in FL. To this end, we propose FedMAX, a new FL approach that introduces a new prior for local training. Specifically, our prior maximizes the entropy of local activation vectors across all devices. We show that our new prior:
Makes activation vectors across multiple devices more similar (for the same classes); in turn, this improves the classification accuracy of our approach;
Significantly reduces the number of total communication rounds needed (as one can perform more local training without losing accuracy). This is particularly important to save energy when training on edge devices.
Extensive experiments on five non-IID FL datasets demonstrate that our approach significantly outperforms both FedAvg [1] and FedProx [4] (e.g., with better accuracy on the CIFAR-10 dataset). We also observe a reduction in the number of communication rounds compared to FedAvg and FedProx.
The remainder of this paper is organized as follows. In Section 2, we provide some background information on FL and discuss the maximum entropy principle. We then present our proposed approach FedMAX in Section 3. In Section 4, we provide a thorough evaluation of FedMAX, under both IID and non-IID settings, using three digit/object recognition and two medical datasets. Our results demonstrate the applicability of our idea and the practical benefits of FedMAX over other approaches.
In FedAvg, after training on each device’s own data, the updated local models are averaged at a central server in order to get a new global model. For non-IID data, the performance of FedAvg degrades significantly as the weights of different models often diverge [2, 3]. To address this non-IID issue, several approaches propose to use some globally shared data to improve the accuracy by making the local data look more IID [2, 9]. However, in practice, collecting this global data may be problematic (or even infeasible) due to privacy concerns; additionally, dealing with this global data can use up critical resources like the local storage space or network bandwidth. Consequently, another approach called FedProx [4] has been proposed to solve the weight-divergence problem by introducing a new loss function which constrains the local models to stay close to the global model.
In contrast to these prior approaches, we aim at constraining the activation-divergence across multiple devices. More precisely, our approach is based on the principle of maximum entropy which states that when there is no a priori information about a problem, the prior distribution should be chosen to maximize entropy [11]. The core idea behind maximizing entropy is to obtain a prior which assumes the least amount of information about a given problem. (Footnote 1: Making needless or unfounded prior assumptions about a problem can reduce the accuracy of the model, hence it is better to make minimal assumptions. For more information on maximum entropy, please refer to [11, 10].) We note that, while this principle has been exploited to solve traditional natural language processing problems [12, 13], it has never been used in the context of FL. In the next section, we explain the intuition behind using this principle when dealing with non-IID data in FL and describe our newly proposed approach in detail.
We note that other studies exploit ML models [17] and aim at addressing differential privacy [15] of medical datasets. In practice, such samples of medical datasets are usually unbalanced and non-IID. Therefore, evaluating FL with medical datasets is necessary, especially when privacy issues are at stake [15]. To this end, we perform multiple experiments on two different medical datasets. The Chest X-ray dataset [17] is one of the accessible medical image datasets for developing automated methods to identify and classify pneumonia. The APTOS dataset [18] is also a well-known dataset for detecting blindness with retina images taken using fundus photography. Our results show the effectiveness of our approach on these non-IID datasets.
FL aims to solve the learning task without explicitly sharing local data. More precisely, a central server coordinates the global learning across a network where each node is a device collecting data and performing a local learning task (as shown in Fig. 1(a)). The objective of FL is to minimize:
$$\min_{w} f(w) = \sum_{k=1}^{m} p_k F_k(w) \tag{1}$$
where $F_k(w)$ is the local objective, which is typically the loss function of the prediction made with model parameters $w$; $m = C \times N$ is the number of devices selected at any given communication round, where $C$ is the proportion of selected devices and $N$ is the number of total devices; $p_k = n_k / n$, where $n_k$ is the number of samples available at device $k$ and $n$ is the total number of samples.
In FedAvg, any local model is updated with its own data as $w_k \leftarrow w_k - \eta \nabla F_k(w_k)$, where $\eta$ is the learning rate and $\nabla F_k(w_k)$ represents the gradient of $F_k$; the global model is then formed by averaging the parameters of all these local models, i.e., $w = \sum_{k=1}^{m} p_k w_k$. For non-IID datasets, different local models will have different data. Although optimized with the same learning rate and the same number of local training epochs, the weights of these local models will likely diverge. Consequently, the accuracy of the global model decreases when its parameters are weight-averaged across these different local models. One possible solution to this problem is to constrain the local updates within a reasonable range, as FedProx proposed [4].
Since activation vectors directly contribute to model accuracy, our new idea is to reduce the activation-divergence for the same classes across multiple devices. To this end, we propose a new prior for the local training that can help us achieve the above goal. More precisely, we use a Convolutional Neural Network (CNN) model consisting of three convolutional layers and two fully-connected layers (see Fig. 1(b)) for the three digit/object recognition datasets and ResNet50 [19] for the two medical datasets, i.e., APTOS and Chest X-ray. We refer to the inputs of the last fully-connected layer as the activation-vector; for the 5-layer CNN, this activation-vector is a 512-dimensional tensor which passes through the final fully-connected layer to yield logits (the unnormalized class probabilities). Hence, we propose a prior distribution that achieves similar activation vectors across all different devices.
We initially consider the $L^2$ norm to constrain the activation vectors and argue that, by preventing the activation vectors from taking large values, the $L^2$ norm should reduce the activation-divergence across different devices. We formulate the $L^2$ norm regularization as follows:
$$\min_{w} \mathcal{L}_k(w) = F_k(w) + \beta \, \| a_i^k \|_2 \tag{2}$$
where $F_k(w)$ is the cross-entropy loss on local data (same as the cost function of FedAvg [1], which tries to distinguish the various labels from each other), $k$ denotes any local device in Fig. 1(a), $\|\cdot\|_2$ is the $L^2$ norm, and $a_i^k$ refers to the activation vector at the input of the last fully-connected layer (as shown in Fig. 1(b)) for sample $i$ on device $k$. Further, $\beta$ is a hyper-parameter used to control the scale of the $L^2$ norm regularization.
Intuitively, this norm regularization constrains the activation vectors and indirectly affects the parameters of other layers except the last fully-connected layer. However, reducing the activation to zero can lead to model underfitting, which results in poor performance. Therefore, we further propose another form of regularization to ensure more similar activation vectors across different devices.
The activation-divergence problem is more complex in the non-IID settings where different users deal with data from different classes. As such, we do not have any prior information about which users have data from which classes. Hence, in non-IID settings, we do not have any prior information about how the activation vectors at different users (for the given classes) should be distributed. Consequently, we propose to use the principle of maximum entropy [11] and select a distribution for activation vectors that maximizes their entropy. Using such a prior, the local loss function for our FL problem is given by:
$$\min_{w} \mathcal{L}_k(w) = F_k(w) - \beta \, \frac{1}{M} \sum_{i=1}^{M} H(a_i^k) \tag{3}$$
where $M$ is the mini-batch size of the local training data, and $H(\cdot)$ denotes the entropy of an activation vector. Also, $\beta$ is a hyper-parameter that is used to control the scale of the entropy loss. Compared with (2), equation (3) maximizes the entropy (hence it minimizes the negative entropy) of the activation vectors $a_i^k$ instead of minimizing their norm $\| a_i^k \|_2$; therefore, we call this approach FedMAX.
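As a rough illustration of the local objective in (3), the following sketch computes the cross-entropy loss minus the scaled mean entropy for one mini-batch in plain Python. It assumes that each activation vector is first turned into a probability distribution with a softmax, which is one common choice and not necessarily the normalization used in the original implementation; the function names (`softmax`, `entropy`, `cross_entropy`, `fedmax_local_loss`) are hypothetical.

```python
import math


def softmax(v):
    """Numerically stable softmax of a list of floats."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy (natural log) of a probability vector."""
    return -sum(q * math.log(q) for q in p if q > 0.0)


def cross_entropy(logits, label):
    """Cross-entropy of one sample given its logits and integer label."""
    return -math.log(softmax(logits)[label])


def fedmax_local_loss(logits_batch, labels, activations_batch, beta):
    """Mean cross-entropy minus beta times the mean activation entropy.

    Assumption: activations are normalized with a softmax before the
    entropy is taken; this is an illustrative choice only.
    """
    batch_size = len(labels)
    ce = sum(cross_entropy(l, y) for l, y in zip(logits_batch, labels)) / batch_size
    ent = sum(entropy(softmax(a)) for a in activations_batch) / batch_size
    return ce - beta * ent


# Flat activations have maximal entropy, so they are penalized least.
flat = fedmax_local_loss([[0.0, 0.0]], [0], [[1.0, 1.0, 1.0, 1.0]], beta=0.5)
peaked = fedmax_local_loss([[0.0, 0.0]], [0], [[10.0, 0.0, 0.0, 0.0]], beta=0.5)
print(flat < peaked)
```

Setting `beta=0` in this sketch leaves only the cross-entropy term, which matches the plain local loss used without any regularization.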
Further, (3) can be written using the Kullback-Leibler (KL) divergence as:
$$\min_{w} \mathcal{L}_k(w) = F_k(w) + \beta \, \frac{1}{M} \sum_{i=1}^{M} \mathrm{KL}(a_i^k \,\|\, U) \tag{4}$$
where $\mathrm{KL}(\cdot \,\|\, \cdot)$ denotes the KL divergence, and $U$ is the uniform distribution over the activation vectors. Since equation (4) is equivalent to equation (3) up to a constant term, the new formulation does not affect the optimization process, and thus also results in maximum entropy. As we shall see shortly, FedMAX is more stable than the norm-based regularization.
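The equivalence between (3) and (4) follows from a standard identity. As a short worked derivation, assume each activation vector has been normalized into a probability vector $p$ of dimension $d$ (both $p$ and $d$ are notation introduced here for illustration only), and let $U$ be the uniform distribution with $U_j = 1/d$:

```latex
\begin{align*}
\mathrm{KL}(p \,\|\, U)
  &= \sum_{j=1}^{d} p_j \log \frac{p_j}{1/d} \\
  &= \sum_{j=1}^{d} p_j \log p_j + \log d \sum_{j=1}^{d} p_j \\
  &= -H(p) + \log d .
\end{align*}
```

Since $\log d$ does not depend on the model parameters, adding the scaled KL term to the loss differs from subtracting the equally scaled entropy only by a constant, so both objectives have the same minimizers.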
The training process of FedMAX is similar to FedAvg (see Algorithm 1). The initial model and weights are generated on a remote server. After selecting a subset of devices (a proportion $C$ of all devices, as shown in Fig. 1(a)), the server sends the model (and the corresponding weights) only to these devices. The devices train the model for a fixed number of local epochs using their local data and then send the trained model back to the server. After averaging the models on the server, sending back the updated model to the newly selected devices finishes one communication round (see Algorithm 1, which is parameterized by the number of devices, the local training batch size, and the total number of communication rounds). (Footnote 2: We note that this approach reduces to FedAvg if the hyper-parameter controlling the entropy loss is set to zero.) This completes the newly proposed FedMAX; we next show its effectiveness on multiple datasets.
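The communication round described above can be sketched on a toy problem. The example below is a minimal simulation, not the authors' Algorithm 1: the "model" is a single scalar, the local loss is a squared error, and the averaging weights are normalized over the selected devices only (an assumption). All names (`local_update`, `run_round`) are hypothetical.

```python
import random


def local_update(w, data, lr, epochs):
    """Gradient descent on the local loss mean(0.5 * (w - x)^2)."""
    for _ in range(epochs):
        grad = sum(w - x for x in data) / len(data)
        w = w - lr * grad
    return w


def run_round(global_w, device_data, fraction, lr, epochs, rng):
    """One communication round: select devices, train locally, average."""
    n_selected = max(1, round(fraction * len(device_data)))
    selected = rng.sample(range(len(device_data)), n_selected)
    # Assumption: weights n_k / n are normalized over the selected devices.
    total = sum(len(device_data[k]) for k in selected)
    new_w = 0.0
    for k in selected:
        w_k = local_update(global_w, device_data[k], lr, epochs)
        new_w += len(device_data[k]) / total * w_k
    return new_w


rng = random.Random(0)
# 100 devices whose data means differ, a crude stand-in for non-IID data.
device_data = [[rng.gauss(k % 5, 1.0) for _ in range(20)] for k in range(100)]
w = 0.0
for _ in range(50):
    w = run_round(w, device_data, fraction=0.1, lr=0.5, epochs=1, rng=rng)
print(w)
```

Increasing `epochs` while decreasing the number of rounds mimics the "more local training, fewer communication rounds" trade-off discussed in the text.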
We perform multiple experiments on five different datasets: FEMNIST* [6], CIFAR-10, CIFAR-100 [16], APTOS [15] and Chest X-ray [14]. The first three datasets are trained with the five-layer CNN in Fig. 1(b), while the last two medical datasets are fine-tuned with ResNet50 [19]. We consider an FL setting where we have a central server and a total of 100 local devices, each device containing only a subset of the entire dataset. At each communication round, only a fraction of these devices is randomly selected by the server for local training. Depending on how the data is split across the local devices, we obtain either an IID or a non-IID version of each dataset. In what follows, we show results for both IID and non-IID datasets.
We first use synthetic data generated as in [4] to verify that the maximum entropy regularization leads to similar activations at different local devices. Samples for the $k$-th device are drawn from a normal distribution $\mathcal{N}(\mu_k, \Sigma)$, which has two parameters: the mean vector $\mu_k$ and the covariance matrix $\Sigma$. Each element in the mean vector is itself generated from a parameterized distribution; a larger value of its parameter leads to more varied mean vectors of the data distribution at each device, and thus to more non-IID data. The covariance matrix is a diagonal matrix (similar to that used in [20]).
Following the data-generation strategy presented in [4], we use a two-layer perceptron to generate the labels w.r.t. the input samples. (Footnote 3: Once initialized, these models remain fixed.) Each element of the weights and biases of its two layers is drawn from a normal distribution, whose parameter controls the differences among the local models and thus indirectly influences the generated labels.
We use three different parameter sets to generate the non-IID synthetic data. We train both FedAvg and FedMAX on the synthetic data with a two-layer perceptron which has the same structure as the model used to generate the labels. The training process lasts 200 communication rounds, with one local training epoch. For each communication round, the average activation of each local model is collected and the similarity between the local activation and the global activation is calculated with the KL-divergence. The global activation is calculated as the average of all local activations, $a^{g} = \frac{1}{N} \sum_{k=1}^{N} a^{k}$, where $N$ is the total number of devices. The overall similarity per communication round is represented by the mean of the local similarities.
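The similarity measure described above can be sketched as follows. This is only an illustration under stated assumptions: activations are normalized with a softmax before the KL-divergence is taken, and the divergence is computed from each local activation to the global one (the direction is an assumption). The function names are hypothetical.

```python
import math


def softmax(v):
    """Numerically stable softmax of a list of floats."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two probability vectors with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def activation_similarity(local_activations):
    """Mean KL-divergence between each local and the global activation.

    local_activations: one average activation vector per device.
    The global activation is the plain average over all devices.
    """
    n_devices = len(local_activations)
    dim = len(local_activations[0])
    global_act = [
        sum(a[j] for a in local_activations) / n_devices for j in range(dim)
    ]
    q = softmax(global_act)
    kls = [kl_divergence(softmax(a), q) for a in local_activations]
    return sum(kls) / n_devices


close = activation_similarity([[1.0, 0.0], [0.0, 1.0]])
far = activation_similarity([[3.0, 0.0], [0.0, 3.0]])
print(close < far)  # more divergent activations give a larger value
```

A lower value of this quantity corresponds to local activations that agree more closely with the global average.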
As we can see from Fig. 2, the maximum entropy regularization (FedMAX) results in a relatively lower KL-divergence between global and local activations, which means the activations from the model with maximum entropy regularization are more similar to each other. Moreover, the larger parameter values for the synthetic data lead to a higher KL-divergence for both FedAvg and FedMAX during the first few epochs, which means that more heterogeneous data distributions can make the activations very dissimilar from each other. Thus, constraining the activations within a reasonable range, or making the activations more similar to each other, can benefit FL, especially for the non-IID case.
We first compare our proposed FedMAX against the norm-based regularization on a non-IID CIFAR-10 dataset. For each regularization, we train a CNN as in Fig. 1(b) consisting of about 0.6 million parameters. The hyper-parameter of each regularization is varied over a range of values. Since the maximum entropy regularization is averaged over the activation, it uses larger hyper-parameter values than the norm regularization.
The results are shown in Fig. 3. As we can see, both the norm and the maximum entropy regularization outperform FedAvg, because both methods enable more similar activation vectors across the devices. However, when compared against the norm, the accuracy of the maximum entropy regularization is more robust to hyper-parameter variation. Specifically, we found that for certain values, the norm results in extremely low accuracies (see Fig. 3(b)); this, in turn, can result in a much more time-consuming hyper-parameter search for different datasets. Since FedMAX results in a significantly more stable behavior (see Fig. 3(a)), in the remainder of this paper, the experimental results are reported only for FedMAX using the maximum entropy regularization.
We first verify our approach on three different datasets: FEMNIST* [6], CIFAR-10 and CIFAR-100. For each dataset, we train a CNN as in Fig. 1(b) consisting of about 0.6 million parameters.
The training process lasts 3000 communication rounds with a single local training epoch; the mini-batch size at each selected device is 100. The learning rate is initialized to 0.1 and decays by 0.9992 at each round. For reference, the decay rate in [2] is 0.992. (Footnote 4: For our setup, since this decay rate results in an extremely small learning rate after thousands of epochs, we increase our learning rate decay to 0.9992.) We also test the communication efficiency by setting the number of global communication rounds to 600, using a learning rate decay of 0.996 and five local training epochs, and keeping all other parameters the same; this way, the experimental settings remain consistent with the 3000 communication rounds setup.
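The choice of decay rates can be checked with a few lines of arithmetic. The sketch below assumes the decay is applied multiplicatively once per communication round (the text says the rate "decays by" the factor each round); `final_lr` is a hypothetical helper name.

```python
def final_lr(initial_lr, decay, rounds):
    """Learning rate after `rounds` multiplicative decays."""
    return initial_lr * decay ** rounds


lr_3000 = final_lr(0.1, 0.9992, 3000)  # 3000-round setup
lr_600 = final_lr(0.1, 0.996, 600)     # 600-round setup
lr_ref = final_lr(0.1, 0.992, 3000)    # reference decay rate over 3000 rounds

# Both setups end at roughly 0.009, while the reference decay rate
# shrinks the learning rate by many more orders of magnitude.
print(lr_3000, lr_600, lr_ref)
```

Under this assumption, the two setups finish with nearly the same learning rate, which is one way to read the statement that the 600-round settings remain consistent with the 3000-round setup.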
For FedProx, the results are reported for a single value of its hyper-parameter [4]. We did try other values, but found that the results are very similar. Also, for our approach, we set the entropy-regularization hyper-parameter $\beta$ = 1500. To split the datasets into the non-IID parts, we randomly assign 2 out of 10 classes (20 out of 100 classes) for CIFAR-10 (CIFAR-100) to each device. For FEMNIST*, we follow the same setting as in [4], where data from 20 out of 26 classes are given to each device. For the IID case of all three datasets, labels are distributed uniformly across all users. In what follows, we present two sets of results: (i) Accuracy improvements and (ii) Communication-efficiency of FedMAX.
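The class-based non-IID split can be illustrated with a small sketch. It is a plausible reading of "randomly assign 2 out of 10 classes to each device", not the authors' partitioning script: each device draws its classes at random, and the samples of a class are dealt round-robin to the devices holding that class. The function names are hypothetical.

```python
import random


def assign_classes(num_devices, num_classes, classes_per_device, rng):
    """Pick a random subset of classes for every device."""
    return [
        sorted(rng.sample(range(num_classes), classes_per_device))
        for _ in range(num_devices)
    ]


def partition_non_iid(labels, device_classes, rng):
    """Return one list of sample indices per device.

    Samples of class c are shuffled and dealt round-robin to the
    devices that were assigned class c; classes held by no device
    are left unassigned.
    """
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    holders = {}
    for dev, classes in enumerate(device_classes):
        for c in classes:
            holders.setdefault(c, []).append(dev)
    shards = [[] for _ in device_classes]
    for c, indices in by_class.items():
        rng.shuffle(indices)
        devs = holders.get(c, [])
        for pos, idx in enumerate(indices):
            if devs:
                shards[devs[pos % len(devs)]].append(idx)
    return shards


rng = random.Random(0)
labels = [i % 10 for i in range(1000)]           # toy 10-class label list
device_classes = assign_classes(100, 10, 2, rng)  # 2 of 10 classes per device
shards = partition_non_iid(labels, device_classes, rng)
print(device_classes[0], len(shards[0]))
```

For the IID case, the same sketch could be replaced by a single shuffle of all indices followed by an even split across devices.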
The test accuracy of the 3000 communication round experiment is shown in Fig. 4. As evident, our approach outperforms the other approaches for all three datasets. The test accuracy decreases accordingly as the datasets change from FEMNIST* to CIFAR-100, where our CNN models become relatively smaller for the dataset. Since each device for CIFAR-10 has only two out of ten labels, this is an extreme non-IID case; this is why the test accuracy on CIFAR-10 varies much more rapidly (for all three approaches) compared to the other datasets. For the CIFAR-10 dataset, our model also converges significantly faster than the other approaches. The final accuracies across five runs for all experiments are shown in Table 1. As shown, our approach outperforms existing techniques for both IID and non-IID cases; the best results are highlighted with bold.
non-IID | 3000 communication rounds | | | 600 communication rounds | | |
---|---|---|---|---|---|---|
Approach | FEMNIST* | CIFAR-10 | CIFAR-100 | FEMNIST* | CIFAR-10 | CIFAR-100 |
FedAvg [1] | 92.24±0.08% | 67.26±1.50% | 42.17±0.49% | 92.09±0.14% | 58.91±3.55% | 34.29±0.52% |
FedProx [4] | 92.14±0.16% | 67.46±1.78% | 41.99±0.58% | 92.09±0.08% | 58.63±2.98% | 34.42±0.33% |
FedMAX | 94.05±0.13% | 73.10±1.20% | 47.15±0.75% | 93.78±0.10% | 65.64±1.49% | 43.15±0.99% |
Improvement | 1.81% | 5.64% | 4.98% | 1.69% | 6.73% | 8.73% |
IID | 3000 communication rounds | | | 600 communication rounds | | |
Approach | FEMNIST* | CIFAR-10 | CIFAR-100 | FEMNIST* | CIFAR-10 | CIFAR-100 |
FedAvg | 92.24±0.08% | 81.14±0.49% | 43.56±0.26% | 92.09±0.14% | 75.94±0.96% | 32.67±0.39% |
FedProx | 92.14±0.16% | 81.16±0.29% | 43.22±0.30% | 92.09±0.08% | 75.91±1.09% | 32.67±0.44% |
FedMAX | 94.05±0.13% | 83.66±0.38% | 53.13±0.58% | 93.78±0.10% | 82.39±0.26% | 47.38±0.47% |
Improvement | 1.81% | 2.50% | 9.57% | 1.69% | 6.45% | 14.71% |
The test accuracy of the 600 communication rounds experiment is shown in Fig. 5. With more local training, the weights of the models on different devices are expected to diverge more from the global model, which explains the loss of accuracy. However, FedMAX significantly outperforms FedAvg [1] and FedProx [4] in test accuracy (see the Improvement rows of Table 1, where the better results are highlighted in bold).
Another observation worth noting from Table 1 is that for all three datasets, FedMAX with 600 communication rounds achieves comparable or even better accuracy than FedAvg and FedProx with 3000 communication rounds. This shows that, by relying on more local training, FedMAX significantly reduces the number of communication rounds (here from 3000 to 600, i.e., by 5×) compared to prior techniques, without losing accuracy. This is particularly important for edge computing, where reducing communication costs is crucial for energy savings.
The APTOS dataset includes 38,788 samples with five labels describing the severity of blindness; each class contains a different number of retina images taken using fundus photography. The Chest X-ray dataset has 5,856 samples and two image categories (Pneumonia/Normal) graded by expert physicians. Each dataset is randomly split into 85% training data and 15% test data. Since these are unbalanced datasets, we use the macro F1 score to measure the performance of the model.
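To make the choice of metric concrete, the sketch below computes the macro F1 score from scratch and contrasts it with accuracy on a small, made-up imbalanced example (the labels are illustrative only, not taken from either dataset). The helper name `macro_f1` is hypothetical.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# A classifier that always predicts the majority class.
y_true = [0] * 9 + [1]
y_pred = [0] * 10
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy, macro_f1(y_true, y_pred))
```

In this toy case accuracy is 0.9 although the minority class is never detected, whereas the macro F1 score is below 0.5 because every class contributes equally to the average.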
The experiment setting is the same, but instead of training a five-layer CNN, we fine-tune a ResNet50 which is pre-trained on the ImageNet dataset. The activation of ResNet50 is the output of the final average-pool layer, where the activation-vector is a 2048-dimensional tensor.
The training process lasts 300 communication rounds with a single local training epoch; the mini-batch size at each selected device is 32. The learning rate is initialized to 0.001 and decays by 0.992 at each round. To split the datasets into non-IID parts, we randomly assign different proportions of 5 classes (2 classes) for APTOS (Chest X-ray) to each device. For our approach, we set the entropy-regularization hyper-parameter $\beta$ = 10,000 for the APTOS dataset and 1,000 for the Chest X-ray dataset.
Dataset | APTOS | | Chest X-ray | |
---|---|---|---|---|
Approach | IID | non-IID | IID | non-IID |
FedAvg | 0.3362±0.0040 | 0.2707±0.0135 | 0.8243±0.0296 | 0.7094±0.0338 |
FedMAX | 0.3451±0.0062 | 0.2706±0.0121 | 0.8147±0.0286 | 0.7183±0.0383 |
Improvement | 0.0089 | -0.0001 | -0.0096 | 0.0089 |
The test accuracy of the IID and non-IID cases for the 300 communication-round experiment is shown in Fig. 6. As evident, our approach FedMAX outperforms FedAvg on the APTOS IID case. On the non-IID case, our method yields similar results as FedAvg. The F1 score of the non-IID case varies more rapidly than the IID case. This is because the medical datasets are highly imbalanced; the non-IID partition by randomly separating the samples can lead to devices with only one class, which exacerbates the impact of the training process.
Compared with other datasets, the results of FedMAX on the Chest X-ray dataset are close to FedAvg. One possible reason is that since the Chest X-ray dataset has only two classes, it cannot really make the activations more similar among different labels across different devices. Besides, with fewer samples in the Chest X-ray dataset, after partitioning, each device contains only a small amount of data; this leads to a short local training process and comparatively frequent communication. As a result, the activation divergence may already be constrained, so that FedAvg has a similar performance when compared against FedMAX. Final accuracy comparisons across the five runs for all our experiments are shown in Table 2. The better results are highlighted in bold.
We now analyze the impact of our proposed FedMAX on the activation-divergence that can happen in non-IID FL. We show 2-dimensional (2D) t-SNE plots of our 512-dimensional (512D) activation vectors for different devices (each device has two random classes from the CIFAR-10 dataset). Specifically, the t-SNE plots embed each 512D activation vector with a 2D point in such a way that similar objects are modeled by nearby points and dissimilar objects are modeled by distant points. We expect the activation vectors of the same class (even from different devices) to share more similarities, thus, their corresponding 2D points should be closer to each other and form a cluster on the t-SNE plots. To keep it simple, we perform the experiment on a total of 10 local devices, with all the devices training at every communication round.
In Fig. 7, Fig. 8, and Fig. 9, the plots on the left show the activation vectors for FedAvg, and the ones on the right show those for FedMAX. Various colors represent the activation vectors for different classes, while the letters denote the device IDs. As the number of local epochs increases, we observe that: (i) FedMAX starts to gain accuracy, (ii) Activation vectors for FedMAX start to cluster together - see highlighted portions in Fig. 8 and Fig. 9 where the activation vectors from same classes (i.e., the same color) come closer to each other across different devices (i.e., letters A-J). In contrast, for FedAvg, clustering happens much more slowly and, hence, its accuracy is significantly lower than FedMAX.
In this paper, we have identified the activation-divergence phenomenon in FL and proposed FedMAX, a new approach for accurate and communication-aware FL in non-IID and IID settings. By exploiting the norm regularization and the principle of maximum entropy, we have introduced a new prior which assumes minimal information about the activation vectors at different devices.
With extensive experiments, we have shown that FedMAX improves the test accuracy and is significantly more communication-efficient than the state-of-the-art approaches running on FEMNIST*, CIFAR-10, and CIFAR-100 for both non-IID and IID settings. Besides, we have presented experiments on two medical datasets, APTOS and Chest X-ray, and have shown the improvement of FedMAX on the APTOS IID case. We attribute the better performance of FedMAX to improving the similarity across the devices while regularizing the activation vectors. Finally, we note that FedAvg and FedMAX perform similarly on the Chest X-ray dataset due to the smaller number of samples which may hardly lead to activation divergence.
In future work, we plan to evaluate the FedMAX approach using different datasets which contain more classes and samples. Moreover, with the increasing need for multi-task learning, we also plan to implement FedMAX for different learning tasks such as language modeling.
Kermany, D. S., Goldbaum, M., Cai, W., Valentim, C. C. S., Liang, H., Baxter, S. L., McKeown, A., Yang, G., Wu, X., Yan, F., et al.: Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell, vol. 172, pp. 1122–1131. Elsevier (2018)
Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., Summers, R. M.: ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2097–2106 (2017)