Critical Learning Periods in Federated Learning

09/12/2021 ∙ by Gang Yan, et al. ∙ Binghamton University ∙ Louisiana State University

Federated learning (FL) is a popular technique for training machine learning (ML) models with decentralized data. Extensive works have studied the performance of the global model; however, it is still unclear how the training process affects the final test accuracy. Exacerbating this problem is the fact that FL executions differ significantly from traditional ML: data characteristics are heterogeneous across clients, and more hyperparameters are involved. In this work, we show that the final test accuracy of FL is dramatically affected by the early phase of the training process, i.e., FL exhibits critical learning periods, in which small gradient errors can have an irrecoverable impact on the final test accuracy. To further explain this phenomenon, we generalize the trace of the Fisher Information Matrix (FIM) to FL and define a new notion called FedFIM, a quantity reflecting the local curvature of each client from the beginning of training in FL. Our findings suggest that the initial learning phase plays a critical role in understanding FL performance. This is in contrast to many existing works, which generally do not connect the final accuracy of FL to the early phase of training. Finally, seizing critical learning periods in FL is of independent interest and could be useful for other problems, such as choosing hyperparameters (e.g., the number of clients selected per round and the batch size), so as to improve the performance of FL training and testing.


1 Introduction

The ever-growing attention to data privacy and the popularity of mobile computing have impelled the rise of federated learning (FL) [23, 14, 10], a new distributed machine learning paradigm on decentralized data. A typical FL system consists of a central server and multiple decentralized clients (e.g., smartphones and IoT devices). The central server initiates federated learning by sending a global model to the clients. The clients then use their local data samples to train the received model with common deep learning algorithms and upload their local models to the central server. The central server updates the global model by aggregating the received local models and sends it back to the clients for further training. By repeating local training and global aggregation, the central server obtains a global model jointly trained by decentralized clients without accessing any raw data. This unique distributed nature enables an extensive deployment of FL for training deep learning models on sensitive private data, as in Google Keyboard [36].

The distributed nature of FL raises a series of new challenges in terms of system performance and data statistics. In FL systems, clients are typically loosely connected mobile devices with limited communication bandwidth, computation power, and battery life. Besides, unlike traditional centralized machine learning (ML), the data samples of each client in FL are not independent and identically distributed (non-IID), introducing bias that slows down or even derails training. A few recent studies have addressed these challenges via model compression [16, 29, 5, 35], communication frequency optimization [32, 33, 15], and client selection [18, 30, 34].

However, existing studies have not yet explored the significance of critical learning periods. Recent works have revealed that the first few training epochs—known as critical learning periods—determine the final quality of a deep neural network (DNN) model in traditional centralized ML [1, 13, 7, 11]. During a critical period, deficits such as low quality or quantity of training data cause irreversible model degradation, no matter how much additional training is performed after the period. The existence of critical periods in FL remains an open question due to the unique distributed nature of FL.

In this paper, we investigate critical learning periods in FL with systematic experiments and theoretical analysis, and we emphasize the necessity of seizing the critical learning periods to improve FL training efficiency. Specifically, through a range of carefully designed experiments on different ML models and datasets, we observe the consistent existence of critical learning periods in the FL training process. We further propose a new metric named Federated Fisher Information Matrix (FedFIM) to describe and explain this phenomenon. FedFIM is built on the classical statistical notion of the Fisher Information Matrix (FIM) [3] and efficiently approximates the local curvature of the loss surface in FL. We show that the phenomenon of critical learning periods in FL can be explained using the trace of FedFIM, a quantity reflecting the local curvature of each client from the beginning of training in FL. Our findings suggest that the initial learning phase plays a critical role in understanding FL performance, complementing many existing studies that generally ignore the connection between the final model accuracy and the early training phase. To the best of our knowledge, this is the first work towards seizing critical learning periods in the FL framework for training efficiency. Our main contributions are as follows:

  1. Through carefully designed experiments, we discover that critical learning periods consistently exist in FL across representative models and datasets.

  2. We systematically explore the impacts of critical learning periods for FL under a wide range of FL hyperparameters, including client availability, learning rates, batch size, and weight decay.

  3. We propose a new notion dubbed Federated Fisher Information Matrix (FedFIM) and analyze the phenomenon of critical learning periods in FL through the trace of FedFIM. We show that model quality during critical periods correlates strongly with the trace of FedFIM.

2 Background

2.1 Federated Learning

The goal of FL is to solve the joint optimization problem

$$\min_{\mathbf{w}} f(\mathbf{w}) = \sum_{k \in \mathcal{K}} \frac{|\mathcal{D}_k|}{|\mathcal{D}|} f_k(\mathbf{w}), \qquad (1)$$

where $\mathbf{w}$ denotes the model parameters, $\mathcal{K}$ denotes the set of clients, $\mathcal{D}_k$ is the local dataset of client $k$, the entire training dataset is $\mathcal{D} = \bigcup_{k \in \mathcal{K}} \mathcal{D}_k$, and $f_k(\mathbf{w})$ is the local loss function of client $k$. A typical solution to this optimization problem is the federated averaging (FedAvg) algorithm [23]. Specifically, FedAvg initializes a random global model $\mathbf{w}^0$ and iterates the following two steps within each communication round $t$:

  • Local training. The central server sends the global model $\mathbf{w}^t$ to a randomly selected subset of clients $\mathcal{S}^t \subseteq \mathcal{K}$. Each selected client $k \in \mathcal{S}^t$ performs local training using its own dataset $\mathcal{D}_k$:

    $$\mathbf{w}_k^{t,i+1} = \mathbf{w}_k^{t,i} - \eta \nabla f_k(\mathbf{w}_k^{t,i}), \qquad (2)$$

    where $\eta$ is the learning rate and $i$ is the index of local iterations.

  • Global aggregation. The central server obtains a new global model $\mathbf{w}^{t+1}$ by weighted-averaging the local models collected from the selected clients in round $t$:

    $$\mathbf{w}^{t+1} = \sum_{k \in \mathcal{S}^t} \frac{|\mathcal{D}_k|}{\sum_{j \in \mathcal{S}^t} |\mathcal{D}_j|} \, \mathbf{w}_k^{t}. \qquad (3)$$

There are a few variants of federated learning algorithms, such as SCAFFOLD [15], FedProx [19], and FetchSGD [26]. We base our observations and analysis on FedAvg because its simplicity and generality reduce the uncertainty in identifying critical periods.
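To make the training loop in (2)-(3) concrete, the following is a minimal sketch of FedAvg in PyTorch. The toy linear model, random client data, and hyperparameter values (learning rate, local iterations, number of rounds) are illustrative placeholders, not the settings used in our experiments.

```python
import copy
import torch
import torch.nn as nn

def local_training(global_model, dataset, lr=0.01, local_iters=5):
    """One client's local SGD updates, as in Eq. (2), starting from the global model."""
    model = copy.deepcopy(global_model)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    batches = iter(loader)
    for _ in range(local_iters):
        try:
            x, y = next(batches)
        except StopIteration:          # restart the loader if the client runs out of batches
            batches = iter(loader)
            x, y = next(batches)
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    return model.state_dict()

def fedavg_aggregate(local_states, local_sizes):
    """Global aggregation, as in Eq. (3): dataset-size-weighted average of local models."""
    total = sum(local_sizes)
    averaged = copy.deepcopy(local_states[0])
    for key in averaged:
        averaged[key] = sum(state[key] * (size / total)
                            for state, size in zip(local_states, local_sizes))
    return averaged

# Toy run: 4 clients with random data, a linear classifier, 3 communication rounds.
torch.manual_seed(0)
clients = [torch.utils.data.TensorDataset(torch.randn(64, 10),
                                          torch.randint(0, 2, (64,)))
           for _ in range(4)]
global_model = nn.Linear(10, 2)
for rnd in range(3):
    states = [local_training(global_model, data) for data in clients]  # all clients selected
    global_model.load_state_dict(fedavg_aggregate(states, [len(data) for data in clients]))
```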

2.2 Critical Learning Periods

Critical learning periods were originally observed in the early post-natal development of humans and animals, where sensory deficits cause lifelong, irreversible skill impairment. Recently, researchers have observed similar phenomena in centralized deep learning: training a model with defective data, such as blurred images, in the early epochs decreases its final accuracy, no matter how many additional training epochs are performed [2, 1, 13, 7, 11].

However, observing and justifying critical learning periods in FL is hindered by a few obstacles: (i) FL involves multiple deep learning processes across randomly selected clients, each with its own data; (ii) the global model aggregated from local models at the central server has no direct information about the training data decentralized across clients; and (iii) FL has far more hyperparameters (e.g., the number of selected clients and the data distribution) than centralized training, which makes it complicated to identify critical learning periods.

Figure 1: FL exhibits critical learning periods. (top) The final accuracy achieved by ResNet-18 on both IID and Non-IID CIFAR-10 with FedAvg when partial local datasets (with varying ratios of local data) are used for training, as a function of the communication round at which the partial training dataset is recovered to the entire training dataset. The test accuracy of FL is permanently impaired if the training dataset is not recovered to the entire training dataset early enough, no matter how many additional training rounds are performed. (bottom) Communication rounds vs. recover round (RC#). The total number of communication rounds required to achieve the corresponding final accuracy increases significantly as a function of the recover round.

3 Critical Learning Periods in FL

We hypothesize that the final accuracy of FL is significantly affected by the initial learning phase, which we term the critical learning periods in FL. Consider a model $\mathbf{w}$ with loss function $f(\mathbf{w})$ that reaches a minimum loss $f^*$ with test accuracy $a^*$ when optimized with FedAvg across decentralized clients on the entire training dataset $\mathcal{D}$. In addition, consider optimizing with FedAvg across all clients using only a subset $\mathcal{D}_k^\prime \subset \mathcal{D}_k$ of each local training dataset in the first $t_0$ communication rounds and then using the entire training dataset afterwards, so that the model reaches a minimum loss $f^\prime$ with test accuracy $a^\prime$. The critical learning periods articulate that there exist a round $t_0$ and a gap $\epsilon > 0$ such that $a^\prime \leq a^* - \epsilon$ whenever the dataset is not recovered until after round $t_0$, i.e., the initial learning phase is critical in determining the final performance of FL, and the effect of insufficient training (i.e., only using part of the entire training dataset) during the critical learning periods cannot be overcome, no matter how much additional training is performed.

In this section, we address two key questions pertaining to the phenomenon of critical learning periods in FL. We first show, via an extensive set of experiments, that critical learning periods can be observed across different popular ML models and datasets. We then reveal that the critical learning periods in FL remain robust under various training schemes.

Figure 2: The existence of critical learning periods in FL: FedAvg trained on ResNet-18 using both IID and Non-IID CIFAR-10 with constant learning rates (lr).

3.1 FL Exhibits Critical Learning Periods

We perform extensive simulations using two representative ML models, ResNet-18 [8] and a CNN, on the popular CIFAR-10 and CIFAR-100 datasets [17]. To demonstrate the existence of critical learning periods in FL, we run the standard FedAvg [23], which uses the entire training dataset throughout the training process, and compare it against runs in which only a subset of the training dataset on each client is involved during the first communication rounds, after which the training dataset is recovered to the entire training dataset. We call the round at which the dataset is recovered the "Recover Round" and vary the ratio of local datasets involved in early training. We consider a system with clients, and FedAvg randomly selects a subset of clients in each round. The batch size is ; the initial learning rate is set to with a decay of per round; and the SGD solver is adopted using an exponential annealing schedule for the learning rate with a weight decay of .
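This data-recovery protocol can be summarized with a short sketch: each client trains on a fixed random subset of its local data until the recover round and on its full local dataset afterwards. The 30% ratio and round-50 recover round below are illustrative placeholders, not the exact values used in our experiments.

```python
import torch
from torch.utils.data import Subset, TensorDataset

def client_dataset_for_round(full_dataset, subset_indices, rnd, recover_round):
    """Return the partial dataset before the recover round, the full dataset afterwards."""
    if rnd < recover_round:
        return Subset(full_dataset, subset_indices)
    return full_dataset

# Example: a client keeps 30% of its data until round 50, then recovers the rest.
torch.manual_seed(0)
full = TensorDataset(torch.randn(1000, 10), torch.randint(0, 10, (1000,)))
ratio, recover_round = 0.3, 50                      # illustrative values
keep = torch.randperm(len(full))[: int(ratio * len(full))].tolist()
for rnd in (0, 49, 50, 100):
    data = client_dataset_for_round(full, keep, rnd, recover_round)
    print(rnd, len(data))                           # 300, 300, 1000, 1000
```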

Figure 1 (top) reports how the final performance of FL is affected by partial training datasets with different ratios, as a function of the recover round. All results consistently show that critical learning periods exist across all settings with different ratios of local datasets involved in the early learning phase: if the training dataset is not recovered to the entire dataset early enough in the communication rounds, the final test accuracy of FL is severely degraded compared to the standard FedAvg. Comparing different ratios of local datasets involved in the early training phase, it is not too surprising that a lower ratio of local data in the early training phase makes the critical learning periods more pronounced.

We further measure the total number of communication rounds required to achieve the corresponding final accuracy as a function of the recover round, as illustrated in Figure 1 (bottom). The number of communication rounds increases significantly, while the final test accuracy decreases, as the recover round grows. This further indicates the importance of the initial learning phase in determining the final performance of FL.

Figure 3: The existence of critical learning periods in FL: FedAvg trained on ResNet-18 using both IID and Non-IID CIFAR-10 with different batch sizes (BS).

3.2 Learning Rate Annealing and Batch Size

We conduct the same experiments as in Figure 1 but using a constant learning rate rather than an annealing scheme. In particular, we set the constant learning rates to be and , respectively. From Figure 2, we still observe the existence of critical learning periods in FL even with constant learning rates. Therefore, the phenomenon of critical learning periods in FL is not a result of an annealed learning rate in later rounds, and cannot be explained solely in terms of the loss landscape of the optimization in (1). Analogous results illustrating the impact of batch size are presented in Figure 3. Again, the critical learning periods consistently exist regardless of the choice of batch size. This further suggests that the phenomenon of critical learning periods in FL cannot be simply explained by differences in batch sizes.

Figure 4: The existence of critical learning periods in FL: FedAvg trained on ResNet-18 using both IID and Non-IID CIFAR-10 with different weight decays (WD).

3.3 Weight Decay

Similarly, the results of the same experiments as in Figure 1 but with different weight decays are presented in Figure 4. We still observe critical learning periods as in Figure 1; surprisingly, the shapes of the critical learning periods are robust to the value of the weight decay, i.e., changing the weight decay does not alter the shape of the critical learning periods.

4 Federated Fisher Information

Through extensive experiments, we have shown that the initial learning phase plays a critical role in the final test accuracy of FL. Our main contribution in this section is to show that this phenomenon can be explained by the trace of the Federated Fisher Information Matrix (FedFIM), a quantity reflecting the local curvature of each client from the beginning of training in FL. We begin with the definition of the FIM for centralized training.

4.1 Fisher Information Matrix

Consider a probabilistic classification model $p_{\mathbf{w}}(y|x)$, where $\mathbf{w}$ is the model parameter. Let $\ell(\mathbf{w}; x, y) = -\log p_{\mathbf{w}}(y|x)$ be the cross-entropy loss function calculated for input $x$ and label $y$, and denote the corresponding gradient of the loss computed for an example as $\nabla_{\mathbf{w}} \ell(\mathbf{w}; x, y)$. Then the Fisher Information Matrix (FIM) for centralized training is defined as

$$F(\mathbf{w}) = \mathbb{E}_{x \sim p(x),\, y \sim p_{\mathbf{w}}(y|x)} \left[ \nabla_{\mathbf{w}} \log p_{\mathbf{w}}(y|x) \, \nabla_{\mathbf{w}} \log p_{\mathbf{w}}(y|x)^{\top} \right], \qquad (4)$$

where the expectation over inputs is often approximated using the empirical distribution induced by the centralized training dataset. Note that the FIM can be viewed as a local metric of how much perturbations of the weights affect the network output [3]. The FIM can also be seen as an approximation to the Hessian of the loss function [22], and hence to the curvature of the loss landscape at a particular point $\mathbf{w}$ during training. This provides a natural connection between the FIM and the optimization procedure [3].
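As a concrete illustration (not the exact estimator used in our experiments), the following sketch estimates the trace of the FIM in (4) for a classifier by sampling labels from the model's own predictive distribution and averaging the squared norm of the per-example gradient of the log-likelihood. The toy linear model and random inputs are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fim_trace(model, inputs, n_samples=1):
    """Monte Carlo estimate of tr(FIM): average of ||grad_w log p_w(y|x)||^2 with y ~ p_w(y|x)."""
    trace = 0.0
    for x in inputs:
        logits = model(x.unsqueeze(0))
        probs = F.softmax(logits, dim=-1).detach()
        for _ in range(n_samples):
            y = int(torch.multinomial(probs, 1))            # sample label from the model itself
            log_lik = F.log_softmax(logits, dim=-1)[0, y]
            grads = torch.autograd.grad(log_lik, model.parameters(), retain_graph=True)
            trace += sum(g.pow(2).sum().item() for g in grads)
    return trace / (len(inputs) * n_samples)

# Toy usage: a linear classifier on random inputs.
torch.manual_seed(0)
model = nn.Linear(10, 3)
print(fim_trace(model, torch.randn(16, 10)))
```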

However, the computation of FIM in (4) requires the availability of the entire training dataset for the global model at the server. Unfortunately, this is infeasible for FL since training data is decentralized across clients. Hence we cannot compute FIM for FL as in (4). We now introduce a new notion to overcome this challenge.

Figure 5: Connections between critical learning periods in FL and the Federated Fisher Information achieved by ResNet-18 on IID CIFAR-10 with FedAvg, using 30% of the local datasets for training initially and recovering to the entire dataset at the recover round. (a) Test accuracy vs. recover round: the final test accuracy is permanently impaired if the training dataset is not fully recovered at as early as the 20-th round. (b) Trace of FedFIM vs. recover round. There is a sharp increase in the trace of FedFIM in the early training phases. (c) Weighted cumulative sum of the trace of FedFIM vs. recover round.
Figure 6: Connections between critical learning periods in FL and the Federated Fisher Information achieved by ResNet-18 on Non-IID CIFAR-10 with FedAvg, using 30% of the local datasets for training initially and recovering to the entire dataset at the recover round.

4.2 Federated Fisher Information Matrix

Given that training data resides at each client, and given the FL training process in (2) and (3), we first introduce the local FIM $F_k(\mathbf{w})$ on client $k$:

$$F_k(\mathbf{w}) = \mathbb{E}_{x \sim \hat{p}_k(x),\, y \sim p_{\mathbf{w}}(y|x)} \left[ \nabla_{\mathbf{w}} \log p_{\mathbf{w}}(y|x) \, \nabla_{\mathbf{w}} \log p_{\mathbf{w}}(y|x)^{\top} \right], \qquad (5)$$

where $\hat{p}_k(x)$ is the empirical distribution induced by the local dataset $\mathcal{D}_k$ of client $k$. Note that $F_k(\mathbf{w})$ is computed using the global model $\mathbf{w}$ on the local dataset $\mathcal{D}_k$, and can be considered a local metric measuring how perturbations of the global model affect the FL training performance from the perspective of client $k$. As a result, the overall impact of a perturbation of the global model on the final output, which we define as the Federated Fisher Information Matrix (FedFIM) for FL, can be computed as the weighted average of the local FIMs across all clients:

$$\hat{F}(\mathbf{w}) = \sum_{k \in \mathcal{K}} \frac{|\mathcal{D}_k|}{|\mathcal{D}|} F_k(\mathbf{w}), \qquad (6)$$

where the weight of client $k$ is proportional to the size of its dataset. The rationale is that a client with a lower local FIM often has little effect on the final performance. We denote the trace of $\hat{F}(\mathbf{w})$ by $\mathrm{tr}(\hat{F}(\mathbf{w}))$.
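A minimal sketch of the aggregation in (6): each client reports the trace of its local FIM, computed with the current global model on its own data (e.g., with an estimator like the one sketched in Section 4.1), and the server forms the dataset-size-weighted average. The trace values and dataset sizes below are made-up placeholders.

```python
def fedfim_trace(local_traces, local_sizes):
    """tr(FedFIM), Eq. (6): dataset-size-weighted average of the local FIM traces from Eq. (5)."""
    total = sum(local_sizes)
    return sum(size / total * trace for size, trace in zip(local_sizes, local_traces))

# Toy usage: two clients report local FIM traces computed with the current global model.
print(fedfim_trace(local_traces=[4.2, 7.9], local_sizes=[320, 640]))
```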

4.3 Experimental Results

We conduct experiments similar to those in Figure 1, with partial local datasets involved in the initial learning phases and the training datasets recovered to the entire datasets at the "recover rounds" (RC#). The test accuracy and the trace of FedFIM for different recover rounds on IID and Non-IID CIFAR-10 are presented in Figures 5 and 6.

First of all, we again observe the existence of critical learning periods: if the training dataset is not recovered to the entire dataset at as early as the 20-th communication round, the final test accuracy of FL is permanently impaired. Second, this information is fully reflected in the trace of FedFIM, as shown in Figure 5(b) for the IID case and Figure 6(b) for the Non-IID case. We observe a sharp increase in the trace of FedFIM in the early phases of the FL training process, which coincides with the dramatic increase of the test accuracy in the early training phase. The information starts to decrease when the test accuracy starts to plateau. Since the training datasets are recovered from 30% of the local datasets to the entire datasets at the recover rounds, the additional data further boosts the test accuracy, as shown in Figure 5(a). However, this test accuracy boost decreases significantly as the recover round increases. This further suggests that the initial learning phases play a critical role in FL performance, and the permanent model degradation is irreversible no matter how much additional training is performed after the critical learning periods. Correspondingly, the accuracy boost results in a slight increase in the trace of FedFIM, and the information decreases again when the test accuracy starts to plateau.

In general, the measurements of test accuracy and of the trace of FedFIM are noisy, especially with the Non-IID dataset, as shown in Figure 6. This is because, for instance, the learning rate has to be adjusted in order to compensate for possible generalization issues of the training process [12, 27]. To this end, we further consider a weighted cumulative sum of the trace of FedFIM:

$$S_t = \sum_{r=1}^{t} \eta_r \, \mathrm{tr}\big(\hat{F}(\mathbf{w}^r)\big), \qquad (7)$$

where $\eta_r$ is the learning rate at the $r$-th round and $\hat{F}(\mathbf{w}^r)$ is the Federated Fisher Information Matrix at the $r$-th round. The trace of FedFIM reflects whether the local data is informative enough to improve the model; larger values of the weighted cumulative sum correspond to less model information. This is exactly what we observe in Figure 5(c) and Figure 6(c), where a late recovery results in a larger weighted cumulative trace.
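The weighted cumulative sum in (7) is a simple running statistic that the server can maintain across rounds; a sketch with made-up per-round values follows.

```python
def weighted_cumulative_fedfim(learning_rates, fedfim_traces):
    """Eq. (7): running sum over rounds of (learning rate) x tr(FedFIM)."""
    sums, acc = [], 0.0
    for lr, trace in zip(learning_rates, fedfim_traces):
        acc += lr * trace
        sums.append(acc)
    return sums

# Made-up example: a decaying learning rate and an early spike in tr(FedFIM).
lrs = [0.1 * 0.99 ** t for t in range(5)]
traces = [2.0, 9.0, 6.0, 3.0, 2.5]
print(weighted_cumulative_fedfim(lrs, traces))
```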

5 Seizing Critical Learning Periods

We use carefully crafted experiments to evaluate the idea of seizing critical learning periods to improve FL training efficiency, which the existing literature largely ignores. The experiments run in PyTorch on Python 3 with an NVIDIA RTX 3060 GPU. The total number of clients is , and a subset of clients is randomly selected in each round.

Specifically, we train ResNet-18 on IID and Non-IID CIFAR-10 with FedAvg under different settings as shown in Figures 7 and 8:

  • All Clients: All clients participate in federated learning.

  • Partial Clients: Only a subset of the clients (e.g., ) participate in federated learning.

  • All Clients in critical periods else Partial Clients: All clients participate in training during the critical learning periods. After that, only a subset of clients (e.g., ) remain in training.

  • All Data: Each client processes all data in local training.

  • Partial Data: Each client processes only partial local datasets (e.g., ) in local training.

  • All Data in critical periods else Partial Data: Each client uses its entire local dataset for training during the critical learning periods, and only a portion of its local dataset afterwards.

By seizing the critical periods in FL, we obtain the following counter-intuitive experimental results:

No need to involve all clients in training all along.

The conventional FedAvg requires the entire training datasets across all clients throughout the training process. However, some clients may not be available for training, e.g., due to unreliable network connections. To illustrate the impact of critical learning periods, we consider a heuristic in which all clients are involved in training during the critical learning periods and only a subset of clients (e.g., ) are involved afterwards.

Figure 7(a) and Figure 8(a) show the test accuracy vs. wall-clock time. There exists a requirement on the number of clients involved in training that provides test accuracy similar to using all clients (FedAvg) throughout. For example, with all clients participating in FL training during the critical learning periods and then only 60% of clients afterwards, the final test accuracy is similar to that obtained using all clients throughout the training process. Hence there is no need to involve all clients throughout the training process. Figure 7(b) and Figure 8(b) show the training loss vs. wall-clock time. Reducing the number of participating clients after the critical periods lowers the training time compared with using all clients (FedAvg) throughout. It is clear that leveraging critical learning periods for FL training, even in a heuristic manner, can significantly improve training efficiency with a reduced training time while maintaining the final test accuracy.
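A minimal sketch of this participation schedule, assuming a fixed critical-period window: all clients are selected while the round index is inside the window, and a fixed fraction (here 60%, as in the example above) afterwards. The 20-round window is an illustrative assumption, not a value prescribed by our method.

```python
import random

def select_clients(all_clients, rnd, critical_rounds=20, late_fraction=0.6, seed=None):
    """All clients during the assumed critical learning period, a random subset afterwards."""
    if rnd < critical_rounds:
        return list(all_clients)
    k = max(1, int(late_fraction * len(all_clients)))
    return random.Random(seed).sample(list(all_clients), k)

clients = list(range(100))
print(len(select_clients(clients, rnd=5)))    # 100 clients inside the critical period
print(len(select_clients(clients, rnd=50)))   # 60 clients afterwards
```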

No need to train a model with all local data for each client. We consider the challenge that FL clients have heterogeneous system capabilities, e.g., some can only process part of their local data for training. We use a heuristic in which the entire local datasets are used for training during the critical learning periods and only partial local datasets are involved afterwards.

Figure 7(c) and Figure 8(c) show the test accuracy vs. wall-clock time. There exists a training dataset requirement that provides test accuracy similar to using the entire dataset (FedAvg) throughout. For example, with the entire training datasets used in FL training during the critical learning periods and then only 25% of the local datasets afterwards, the final test accuracy is similar to that obtained using the entire datasets throughout the training process. Hence there is no need to use the entire training datasets throughout the training process. Figure 7(d) and Figure 8(d) present the training loss vs. wall-clock time; the heuristic reduces the training time compared with using the entire dataset (FedAvg) throughout. Again, we observe that the early learning phase plays a critical role in FL performance, and leveraging it can significantly improve the training efficiency of FL.

Overall, we can save 40%-50% of the training time and 50%-65% of the total clients while achieving a comparable final model accuracy when training ResNet-18 on the IID and Non-IID CIFAR-10 datasets.

Figure 7: Seizing the critical learning periods in FL training with ResNet-18 on IID CIFAR-10.
Figure 8: Seizing the critical learning periods in FL training with ResNet-18 on Non-IID CIFAR-10.

6 Related Work

Since the term federated learning was introduced in the seminal work [23], there has been explosive growth in federated learning research. For example, one line of work focuses on designing algorithms that achieve higher learning accuracy and on analyzing their convergence, e.g., [28, 20, 21, 31]. Another line of work aims to improve the communication efficiency between the central server and clients through compression or sparsification [16, 29, 5, 35], communication frequency optimization [32, 33, 15], client selection [18, 30, 34], etc. Additionally, a lot of effort has been put into exploring the privacy and fairness of federated learning [4, 6, 9, 24, 37, 25, 31]. These studies often carry the implicit assumption that all learning phases during the training process are equally important. Our work focuses on showing that the initial learning phase plays a critical role in federated learning performance, which is orthogonal to the aforementioned studies.

7 Conclusion

In this paper, we identified and seized critical learning periods in federated learning so as to improve federated learning training efficiency. Through a range of carefully designed experiments on different ML models and datasets, we showed that critical learning periods consistently exist in the training process of FL. To explain this phenomenon, we proposed a new metric called the Federated Fisher Information Matrix. Our findings suggest that the initial learning phase plays a critical role in the final performance of FL.

References

  • [1] A. Achille, M. Rovere, and S. Soatto (2019) Critical Learning Periods in Deep Networks. In Proc. of ICLR, Cited by: §1, §2.2.
  • [2] S. Agarwal, H. Wang, K. Lee, S. Venkataraman, and D. Papailiopoulos (2021) ACCORDION: Adaptive Gradient Communication via Critical Learning Regime Identification. In Proc. of MLSys, Cited by: §2.2.
  • [3] S. Amari and H. Nagaoka (2000) Methods of Information Geometry. Vol. 191, American Mathematical Soc.. Cited by: §1, §4.1.
  • [4] K. Bonawitz, V. Ivanov, B. Kreuter, A. Marcedone, H. B. McMahan, S. Patel, D. Ramage, A. Segal, and K. Seth (2017) Practical Secure Aggregation for Privacy-Preserving Machine Learning. In Proc. of ACM CCS, Cited by: §6.
  • [5] S. Caldas, J. Konečny, H. B. McMahan, and A. Talwalkar (2018) Expanding the Reach of Federated Learning by Reducing Client Resource Requirements. arXiv preprint arXiv:1812.07210. Cited by: §1, §6.
  • [6] R. C. Geyer, T. Klein, and M. Nabi (2017) Differentially Private Federated Learning: A Client Level Perspective. arXiv preprint arXiv:1712.07557. Cited by: §6.
  • [7] A. S. Golatkar, A. Achille, and S. Soatto (2019) Time Matters in Regularizing Deep Networks: Weight Decay and Data Augmentation Affect Early Learning Dynamics, Matter Little Near Convergence. Proc. of NeurIPS. Cited by: §1, §2.2.
  • [8] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep Residual Learning for Image Recognition. In Proc. of IEEE CVPR, Cited by: §A.1, §3.1.
  • [9] B. Hitaj, G. Ateniese, and F. Perez-Cruz (2017) Deep Models under the GAN: Information Leakage from Collaborative Deep Learning. In Proc. of ACM CCS, Cited by: §6.
  • [10] A. Imteaj, U. Thakker, S. Wang, J. Li, and M. H. Amini (2021) A Survey on Federated Learning for Resource-Constrained IoT Devices. IEEE Internet of Things Journal. Cited by: §1.
  • [11] S. Jastrzebski, D. Arpit, O. Astrand, G. B. Kerg, H. Wang, C. Xiong, R. Socher, K. Cho, and K. J. Geras (2021) Catastrophic Fisher Explosion: Early Phase Fisher Matrix Impacts Generalization. In Proc. of ICML, Cited by: §1, §2.2.
  • [12] S. Jastrzębski, Z. Kenton, D. Arpit, N. Ballas, A. Fischer, Y. Bengio, and A. Storkey (2017) Three Factors Influencing Minima in SGD. arXiv preprint arXiv:1711.04623. Cited by: §4.3.
  • [13] S. Jastrzebski, Z. Kenton, N. Ballas, A. Fischer, Y. Bengio, and A. J. Storkey (2019) On the Relation Between the Sharpest Directions of DNN Loss and the SGD Step Length. In Proc. of ICLR, Cited by: §1, §2.2.
  • [14] P. Kairouz, H. B. McMahan, B. Avent, A. Bellet, M. Bennis, A. N. Bhagoji, K. Bonawitz, Z. Charles, G. Cormode, R. Cummings, et al. (2019) Advances and Open Problems in Federated Learning. arXiv preprint arXiv:1912.04977. Cited by: §1.
  • [15] S. P. Karimireddy, S. Kale, M. Mohri, S. Reddi, S. Stich, and A. T. Suresh (2020) SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. In Proc. of ICML, Cited by: §1, §2.1, §6.
  • [16] J. Konečnỳ, H. B. McMahan, F. X. Yu, P. Richtárik, A. T. Suresh, and D. Bacon (2016) Federated Learning: Strategies for Improving Communication Efficiency. arXiv preprint arXiv:1610.05492. Cited by: §1, §6.
  • [17] A. Krizhevsky, G. Hinton, et al. (2009) Learning Multiple Layers of Features from Tiny Images. Cited by: §A.1, §3.1.
  • [18] F. Lai, X. Zhu, H. V. Madhyastha, and M. Chowdhury (2021) Oort: Efficient Federated Learning via Guided Participant Selection. In Proc. of USENIX OSDI, Cited by: §1, §6.
  • [19] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith (2020) Federated Optimization in Heterogeneous Networks. In Proc. of MLSys, Cited by: §2.1.
  • [20] X. Li, K. Huang, W. Yang, S. Wang, and Z. Zhang (2020) On the Convergence of FedAvg on Non-IID Data. In Proc. of ICLR, Cited by: §6.
  • [21] F. Liu, X. Wu, S. Ge, W. Fan, and Y. Zou (2020) Federated Learning for Vision-and-Language Grounding Problems. In Proc. of AAAI, Cited by: §6.
  • [22] J. Martens (2014) New Insights and Perspectives on the Natural Gradient Method. arXiv preprint arXiv:1412.1193. Cited by: §4.1.
  • [23] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-Efficient Learning of Deep Networks from Decentralized Data. In Proc. of AISTATS, Cited by: §A.1, §1, §2.1, §3.1, §6.
  • [24] L. Melis, C. Song, E. De Cristofaro, and V. Shmatikov (2019) Exploiting Unintended Feature Leakage in Collaborative Learning. In Proc. of IEEE S&P, Cited by: §6.
  • [25] M. Mohri, G. Sivek, and A. T. Suresh (2019) Agnostic Federated Learning. In Proc. of ICML, Cited by: §6.
  • [26] D. Rothchild, A. Panda, E. Ullah, N. Ivkin, I. Stoica, V. Braverman, J. Gonzalez, and R. Arora (2020) FetchSGD: Communication-Efficient Federated Learning with Sketching. In Proc. of ICML, Cited by: §2.1.
  • [27] S. L. Smith, P. Kindermans, C. Ying, and Q. V. Le (2018) Don’t Decay the Learning Rate, Increase the Batch Size. In Proc. of ICLR, Cited by: §4.3.
  • [28] V. Smith, C. Chiang, M. Sanjabi, and A. Talwalkar (2017) Federated Multi-Task Learning. In Proc. of NIPS, Cited by: §6.
  • [29] A. T. Suresh, X. Y. Felix, S. Kumar, and H. B. McMahan (2017) Distributed Mean Estimation with Limited Communication. In Proc. of ICML, Cited by: §1, §6.
  • [30] H. Wang, Z. Kaplan, D. Niu, and B. Li (2020) Optimizing Federated Learning on Non-IID Data With Reinforcement Learning. In Proc. of IEEE INFOCOM, Cited by: §1, §6.
  • [31] H. Wang, M. Yurochkin, Y. Sun, D. Papailiopoulos, and Y. Khazaeni (2020) Federated Learning with Matched Averaging. In Proc. of ICLR, Cited by: §6.
  • [32] J. Wang and G. Joshi (2019) Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-update SGD. In Proc. of SysML, Cited by: §1, §6.
  • [33] S. Wang, T. Tuor, T. Salonidis, K. K. Leung, C. Makaya, T. He, and K. Chan (2019) Adaptive Federated Learning in Resource Constrained Edge Computing Systems. IEEE Journal on Selected Areas in Communications 37 (6), pp. 1205–1221. Cited by: §1, §6.
  • [34] G. Xiong, G. Yan, and J. Li (2021) Straggler-Resilient Distributed Machine Learning with Dynamic Backup Workers. arXiv preprint arXiv:2102.06280. Cited by: §1, §6.
  • [35] Z. Xu, Z. Yang, J. Xiong, J. Yang, and X. Chen (2019) ELFISH: Resource-Aware Federated Learning on Heterogeneous Edge Devices. arXiv preprint arXiv:1912.01684. Cited by: §1, §6.
  • [36] T. Yang, G. Andrew, H. Eichner, H. Sun, W. Li, N. Kong, D. Ramage, and F. Beaufays (2018) Applied Federated Learning: Improving Google Keyboard Query Suggestions. arXiv preprint arXiv:1812.02903. Cited by: §1.
  • [37] L. Zhu, Z. Liu, and S. Han (2019) Deep Leakage from Gradients. Proc. of NeurIPS. Cited by: §6.

Appendix A Appendix

In the appendix, we provide our experiment details, parameter settings, and additional experimental results.

A.1 Dataset and Model

To further observe critical learning periods in FL, we conduct extensive experiments with the CIFAR-10 and CIFAR-100 datasets [17] under both IID and Non-IID settings. The CIFAR-10 dataset consists of 60,000 color images in 10 classes, of which 50,000 samples are for training and the other 10,000 samples are for testing. Unlike CIFAR-10, CIFAR-100 has 100 classes. We use the same strategy to distribute the data over different workers as suggested by [23]. For the Non-IID setting, we first divide each class of training data into ten parts and randomly assign three parts from different classes to each worker. For the IID setting, we evenly partition all training data among all workers. We consider two representative models: the ResNet-18 model [8] and a five-layer CNN model, as shown in Table 1 and Table 2. We run the experiments in PyTorch on Python 3 with an NVIDIA RTX 3060 GPU.
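A sketch of this Non-IID partition: each class is split into ten shards, and every worker receives three shards drawn from distinct classes. The number of workers and the synthetic labels are placeholders, and details such as drawing shards without replacement across workers are assumptions rather than the exact assignment used in our experiments.

```python
import random
from collections import defaultdict

def non_iid_partition(labels, num_workers, shards_per_class=10, shards_per_worker=3, seed=0):
    """Split each class into shards; give each worker shards from distinct classes."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    shards = []                                       # (class_label, list_of_indices)
    for y, idxs in by_class.items():
        rng.shuffle(idxs)
        size = len(idxs) // shards_per_class
        shards += [(y, idxs[i * size:(i + 1) * size]) for i in range(shards_per_class)]
    workers = []
    for _ in range(num_workers):
        chosen_classes = rng.sample(sorted(by_class), shards_per_worker)   # distinct classes
        picked = []
        for c in chosen_classes:
            candidates = [s for s in shards if s[0] == c]
            shard = rng.choice(candidates)
            shards.remove(shard)                      # assume shards are used without replacement
            picked += shard[1]
        workers.append(picked)
    return workers

# Toy usage: 100 synthetic labels over 10 classes, 5 workers.
labels = [i % 10 for i in range(100)]
parts = non_iid_partition(labels, num_workers=5)
print([len(p) for p in parts])
```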

Layer Size kernel_size
conv1 64 3
conv2.x 64 3
conv3.x 128 3
conv4.x 256 3
conv5.x 512 3
avg_pool2d 4 -
Linear 10 -
Table 1: The architecture of the ResNet-18 model.
Layer Size kernel_size
conv2d 64 3
conv2d 192 3
conv2d 384 3
conv2d 256 3
conv2d 256 3
Dropout default default
Linear 128 -
ReLU - -
Dropout default default
Linear 128 -
ReLU - -
Linear 10 -
Table 2: The architecture of the five-layer CNN model.
Figure 9: FL exhibits critical learning periods: ResNet-18 on both IID and Non-IID CIFAR-100 with FedAvg.
Figure 10: FL exhibits critical learning periods: CNN on both IID and Non-IID CIFAR-10 with FedAvg.
Figure 11: FL exhibits critical learning periods: CNN on both IID and Non-IID CIFAR-100 with FedAvg.
Figure 12: Connections between critical learning periods in FL and the Federated Fisher Information achieved by CNN on IID CIFAR-10 with FedAvg, using 50% of the local datasets for training initially and recovering to the entire dataset at the recover round. (a) Test accuracy vs. recover round: the final test accuracy is permanently impaired if the training dataset is not fully recovered at as early as the 20-th round. (b) Trace of FedFIM vs. recover round. There is a sharp increase in the trace of FedFIM in the early training phases. (c) Weighted cumulative sum of the trace of FedFIM vs. recover round.

A.2 Additional Results on Critical Learning Periods in FL

Complementary to Section 3 "Critical Learning Periods in FL," which runs the ResNet-18 model on CIFAR-10, we provide additional experimental results by running ResNet-18 on CIFAR-100 and the CNN model on both CIFAR-10 and CIFAR-100. We consider an FL system running FedAvg with clients that randomly selects a subset of clients in each round. The batch size is , the initial learning rate is set to for ResNet-18 and for CNN with a decay of per round, and the SGD solver is adopted using an exponential annealing schedule for the learning rate with a weight decay of .

We reproduce the critical learning periods in FL when running ResNet-18 on CIFAR-100 under both IID and Non-IID settings, and the CNN on CIFAR-10 and CIFAR-100 under both IID and Non-IID settings with FedAvg, as shown in Figures 9, 10, and 11, respectively. We observe that the critical learning periods consistently exist across all settings with different ratios of local datasets involved in the early learning phase. For example, if the CIFAR-10 training dataset is not recovered to the entire dataset early enough in the communication rounds, the final test accuracy is severely degraded compared to the standard FedAvg with the entire dataset. Comparing different ratios of local datasets involved in the early training phase, it is not too surprising that a lower ratio of local data in the early training phase makes the critical learning periods more pronounced. Similarly, it is clear that the number of communication rounds increases significantly, with a lower final test accuracy, as a function of the recover round. This further indicates the importance of the initial learning phase in determining the final FL performance.

Figure 13: Connections between critical learning periods in FL and the Federated Fisher Information achieved by CNN on Non-IID CIFAR-10 with FedAvg, using 50% of the local datasets for training initially and recovering to the entire dataset at the recover round.

Appendix B Additional Results on Federated Fisher Information

Complementary to Section 4 "Federated Fisher Information," which runs the ResNet-18 model on CIFAR-10, we provide additional experimental results by running the five-layer CNN model on CIFAR-10 and CIFAR-100.

CNN on CIFAR-10. We conduct experiments similar to those in Figure 10, with partial local datasets involved in the initial learning phases and the training datasets recovered to the entire datasets at the "recover rounds" (RC#). Figures 12 and 13 present the test accuracy and the trace of FedFIM with different recover rounds when 50% of the local datasets are used initially, on IID and Non-IID CIFAR-10, respectively. Again, we observe the existence of critical learning periods. Furthermore, this information is properly reflected in the trace of FedFIM, as shown in Figure 12(b) for the IID case and Figure 13(b) for the Non-IID case. There is a sharp increase in the trace of FedFIM in the early phases of the FL training process, coinciding with a dramatic increase of the test accuracy in the early training phase. The information starts to decrease when the test accuracy starts to plateau. Since the training datasets are recovered from 50% of the local datasets to the entire datasets at the recover rounds, the additional data further boosts the test accuracy, as shown in Figure 12(a) and Figure 13(a). However, this test accuracy boost decreases significantly as the recover rounds increase. This further suggests that the initial learning phases play a critical role in FL performance and that the permanent model degradation is irreversible no matter how much additional training is performed after the critical learning periods.

We further consider the weighted cumulative sum of the trace of FedFIM. The trace of FedFIM reflects whether the local data is informative enough to improve the model; larger values of the weighted cumulative sum correspond to less model information. This is exactly what is shown in Figure 12(c) and Figure 13(c), where a late recovery results in a larger weighted cumulative trace.

CNN on CIFAR-100. Similar observations and analysis apply to running CNN on CIFAR-100 as shown in Figures 14 and 15, and hence are omitted here.

Figure 14: Connections between critical learning periods in FL and the Federated Fisher Information achieved by CNN on IID CIFAR-100 with FedAvg, using 50% of the local datasets for training initially and recovering to the entire dataset at the recover round.
Figure 15: Connections between critical learning periods in FL and the Federated Fisher Information achieved by CNN on Non-IID CIFAR-100 with FedAvg, using 50% of the local datasets for training initially and recovering to the entire dataset at the recover round.
Figure 16: Seizing the critical learning periods in FL training with CNN on IID CIFAR-10.
Figure 17: Seizing the critical learning periods in FL training with CNN on Non-IID CIFAR-10.

Appendix C Additional Results on Seizing Critical Learning Periods

Complementary to Section 5 "Seizing Critical Learning Periods," we further evaluate the idea of seizing critical learning periods to improve FL training efficiency, which the existing literature largely ignores. In particular, we consider running the five-layer CNN on CIFAR-10. The total number of clients is , and a subset of clients is randomly selected in each round. We consider the same settings as for ResNet-18 on CIFAR-10.

Figure 16(a) and Figure 17(a) show the test accuracy vs. wall-clock time. Again, we observe that there exists a requirement on the number of clients involved in training that provides test accuracy similar to using all clients (FedAvg) throughout. For example, with all clients participating in FL training during the critical learning periods and then only 60% of clients afterwards, the final test accuracy is similar to that obtained using all clients throughout the training process. Hence there is no need to involve all clients throughout the training process.

Figure 16(b) and Figure 17(b) show the training loss vs. wall-clock time. Reducing the number of participating clients after the critical periods lowers the training time compared with using all clients (FedAvg) throughout. It is clear that leveraging critical learning periods for FL training, even in a heuristic manner, can significantly improve training efficiency with a reduced training time while maintaining the final test accuracy.

Figure 16(c) and Figure 17(c) show the test accuracy vs. wall-clock time. Again, we observe that there exists a training dataset requirement that provides test accuracy similar to using the entire dataset (FedAvg) throughout. For example, with the entire training datasets used in FL training during the critical learning periods and then only 25% of the local datasets afterwards, the final test accuracy is similar to that obtained using the entire datasets throughout the training process. Hence, there is no need to use the entire training datasets throughout the training process. Figure 16(d) and Figure 17(d) present the training loss vs. wall-clock time; the heuristic reduces the training time compared with using the entire dataset (FedAvg) throughout. Again, we observe that the early learning phase plays a critical role in FL performance, and leveraging it can significantly improve the training efficiency of FL.