SCAFFOLD: Stochastic Controlled Averaging for On-Device Federated Learning

10/14/2019 ∙ by Sai Praneeth Karimireddy, et al. ∙ 7

Federated learning is a key scenario in modern large-scale machine learning. In that scenario, the training data remains distributed over a large number of clients, which may be phones, other mobile devices, or network sensors and a centralized model is learned without ever transmitting client data over the network. The standard optimization algorithm used in this scenario is Federated Averaging (FedAvg). However, when client data is heterogeneous, which is typical in applications, FedAvg does not admit a favorable convergence guarantee. This is because local updates on clients can drift apart, which also explains the slow convergence and hard-to-tune nature of FedAvg in practice. This paper presents a new Stochastic Controlled Averaging algorithm (SCAFFOLD) which uses control variates to reduce the drift between different clients. We prove that the algorithm requires significantly fewer rounds of communication and benefits from favorable convergence guarantees.

1 Introduction

A learning scenario playing a key role in modern large-scale applications is that of federated learning. Unlike standard settings where models are trained using large datasets stored in a central server (dean2012large; iandola2016firecaffe; goyal2017accurate), in federated learning, the training data remains distributed over a large number of clients, which may be phones, other mobile devices, network sensors, or alternative local information sources (konevcny2016federated; konecny2016federated2; mcmahan2017communication; mohri2019agnostic). A centralized model is thus learned without ever transmitting client data over the network, thereby ensuring a basic level of privacy and limiting network communication.

This centralized model benefits from all client data and can often achieve strong performance, as reported in several tasks, including next word prediction (hard2018federated; yang2018applied), emoji prediction (ramaswamy2019federated), decoder models (chen2019federatedb), vocabulary estimation (chen2019federated), low latency vehicle-to-vehicle communication (samarakoon2018federated), and predictive models in health (brisimi2018federated). Nevertheless, federated learning raises several types of issues and has been the topic of multiple research efforts. These include studies of the learning and generalization properties in that scenario (mohri2019agnostic), of systems, networking and communication bottleneck problems due to frequent exchanges between the central server and the clients over unreliable or relatively slow network connections (McMahanMooreRamageHampsonAguera2017), and many others.

This paper deals with the key question of the optimization task in federated learning and specifically that of designing an efficient optimization solution with convergence guarantees. The optimization task in federated learning has been the topic of multiple research efforts. These include the design of more efficient communication strategies (konevcny2016federated; konecny2016federated2; suresh2017distributed; stich2018sparsified; karimireddy2019error; basu2019qsparse), the study of lower bounds for parallel stochastic optimization with a dependency graph (WoodworthWangSmithMcMahanSrebro2018), the design of efficient distributed optimization methods benefiting from differential privacy guarantees (agarwal2018cpsgd), stochastic optimization solutions for the agnostic formulation (mohri2019agnostic), and the incorporation of cryptographic techniques (bonawitz2017practical). See (li2019federated) for an in-depth survey of recent work in federated learning.

Training in federated learning typically involves alternating rounds of communication and local updates. At each such round, a subset of clients is sampled and each sampled client receives the shared global model. Clients then perform local updates to this model, which involve only their local training data. Then, at the end of the round, the sampled clients send their updates to the server, which aggregates the updates to form the new global model. There are three key aspects which differentiate federated learning from parallel or distributed training: (1) the data, and thus the loss function, on the different clients may be very heterogeneous and thus far from being representative of the joint data; (2) only a small subset of the devices selected by a central server participate in each round; (3) the server never keeps track of any individual client information and only uses aggregates to ensure privacy.

The standard optimization algorithm for federated learning is Federated Averaging (FedAvg) (mcmahan2017communication). In this algorithm, the subset of clients participating in the current round receives the global parameters $x$. Each client performs a fixed number (say $K$) of SGD steps using its local data and outputs the resulting update. The updates are then aggregated to update the global parameters. However, FedAvg does not benefit from a favorable convergence guarantee and can be quite slow when client data is heterogeneous, which is typical in applications. Empirically, FedAvg is known to be sensitive to its hyperparameters and tends to diverge if they are not chosen carefully. This, along with its slow convergence, can make it hard to use out of the box (li2019federated, Sec 2.3). This is due to the key problem of drifting of client updates, which we now briefly discuss.

We distinguish between the server optimum $x^\star$, parameters which work well for the combined data, and client $i$’s optimum $x_i^\star$, parameters which work well on the local data of client $i$. Since client data is heterogeneous, the server optimum $x^\star$ is usually quite different from the client optima $x_i^\star$. Suppose we run FedAvg starting close to the server optimum $x^\star$. Each client $i$ updates its local model towards $x_i^\star$, since $x^\star$ is not the client optimum. This drift away from the true server optimum $x^\star$ suggests that, to ensure convergence, FedAvg requires a carefully decreasing step-size sequence.

The FedProx algorithm by li2019federated seeks to minimize the local drift by imposing additional regularization on each client. While this can slightly reduce the effect, it does not eliminate it. This argument can be formalized to prove that FedAvg (and FedProx) are necessarily significantly slower than standard SGD, even without any stochasticity and with all clients participating at every round. This is because standard SGD ensures that the local clients are always in sync through frequent communication.
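The drift argument can be checked numerically on a toy problem. The sketch below is illustrative only and does not come from the paper: it assumes two hypothetical one-dimensional quadratic clients $f_i(x) = \frac{a_i}{2}(x - b_i)^2$ with different curvatures, exact (noise-free) gradients, and full participation. The names `fedavg_round`, `run_fedavg`, `a`, and `b` are invented for this example.

```python
import numpy as np

# Hypothetical clients: f_i(x) = a_i / 2 * (x - b_i)^2 (illustrative assumption).
a = np.array([1.0, 10.0])   # client curvatures
b = np.array([0.0, 1.0])    # client optima
x_star = np.sum(a * b) / np.sum(a)  # minimizer of the average objective

def fedavg_round(x, local_steps, lr):
    """One FedAvg round: each client takes `local_steps` gradient steps, then the server averages."""
    y = np.full_like(a, x)              # every client starts from the server model
    for _ in range(local_steps):
        y = y - lr * a * (y - b)        # exact local gradient step on each client
    return y.mean()

def run_fedavg(local_steps, lr=0.05, rounds=2000):
    x = x_star                          # start exactly at the server optimum
    for _ in range(rounds):
        x = fedavg_round(x, local_steps, lr)
    return x

x_k1 = run_fedavg(local_steps=1)    # one local step: equivalent to gradient descent
x_k10 = run_fedavg(local_steps=10)  # several local steps: clients drift apart

print("distance to optimum, K=1 :", abs(x_k1 - x_star))
print("distance to optimum, K=10:", abs(x_k10 - x_star))
```

With a single local step the iterate stays at the server optimum, whereas with ten local steps and a constant step-size the averaged model settles at a different point, which is the drift effect described above.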

The main idea behind the design of our Stochastic Controlled Averaging algorithm is to use control variates to reduce client drift and ensure that the client updates are aligned with each other. Each client $i$ is assigned a control variate $c_i$ and the global control variate $c$ is defined to be their uniform average $c = \frac{1}{N} \sum_{i} c_i$. The control variate $c_i$ represents the direction of the local update we expect to see from client $i$, and $c$ represents the aggregate direction in which the server updates. Given access to $c$ and $c_i$, the client can perform the following correction to its local update direction $g_i$ to better align itself with the server update:

$$g_i - c_i + c.$$

Assume that $c_i$ is set to be equal to $g_i$. Then, the corrected update on every client is exactly the server update $c$, thereby removing all drift. While this fixes the issue, we are left with a chicken-and-egg problem: we need to know the client update direction $g_i$ in order to set $c_i$, and we need $c_i$ in order to compute $g_i$. We break the cycle by using only an estimate for $g_i$ in order to set $c_i$. After performing the actual update, this estimate can be further refined. This leads to our new optimization algorithm SCAFFOLD, which we describe more formally and analyze in detail in the next sections.

Related work. For identical clients, FedAvg coincides with parallel SGD analyzed by zinkevich2010parallelized who proved identical asymptotic convergence. stich2018local and, more recently stich2019error and patel2019communication, gave a sharper analysis of the same method, under the name of local SGD, also for identical functions. However, there still remains a gap between their upper bounds and the lower bound of woodworth2018graph.

The analysis of FedAvg for heterogeneous clients is more delicate since it faces the local client drift issue discussed earlier. Several analyses bound this drift by using a very small step-size and assuming that the local updates admit bounded magnitude (wang2019adaptive; yu2019parallel; li2019convergence). Some other analyses view the drift as a second source of stochastic noise and provide guarantees asymptotically worse than standard SGD (khaled2019first). Similarly, li2019federated prove convergence under an assumption which effectively implies that the client optima are $\epsilon$-close, and therefore that the drift is negligible. Finally, zhao2018federated propose global sharing of the clients’ data. While this does address the drifting issue, sending client data defeats the framework and the main purpose of federated learning.

The use of control variates is a classical technique to reduce variance in Monte Carlo sampling methods (cf. glasserman2013monte). In optimization, they were used for finite-sum minimization by SVRG (johnson2013accelerating; zhang2013linear) and then in SAGA (defazio2014saga) to simplify the linearly convergent method SAG (schmidt2017minimizing). Numerous variations and extensions of the technique are studied in (hanzely2019one). In a very similar vein, control variates were used to obtain linearly converging decentralized algorithms under the guise of ‘gradient-tracking’ in (shi2015extra; nedich2016geometrically) and for gradient compression as ‘compressed-differences’ in (mishchenko2019distributed). Our method can be viewed as seeking to remove the ‘variance’ in the gradients across the clients, though there still remains additional stochasticity.

The problem of drifting we described is a common phenomenon in distributed optimization. In fact, classic techniques such as ADMM mitigate this drift, though they are not applicable in federated learning. For well structured convex problems, CoCoA uses the dual variable as the control variates, enabling flexible distributed methods (smith2016cocoa). DANE by shamir2014communication obtains a closely related primal-only algorithm, which was later accelerated by reddi2016aide. Stochastic Controlled Averaging can be viewed as an improved version of DANE where, instead of solving a proximal sub-problem at every iteration, a fixed number of (stochastic) gradient steps are taken.

The rest of the paper is organized as follows. In Section 2, we describe the optimization problem we consider, describe the assumptions adopted about client functions and specify the notation used. In Section 3, we describe our Stochastic Controlled Averaging algorithm in the simpler case where there is no sampling of clients. The convergence analysis of the algorithm is presented in Section 4. In Section 5, we discuss the more general setup of our algorithm relevant to federated learning where, at each round, a subset of clients is sampled. The convergence analysis is presented in Section 6.

2 Problem setup

2.1 Optimization problem

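The problem we consider is that of minimizing a sum of stochastic functions, with access only to stochastic samples:

$$\min_{x \in \mathbb{R}^d} \Big\{ f(x) := \frac{1}{N} \sum_{i=1}^{N} f_i(x) \Big\}, \qquad f_i(x) := \mathbb{E}_{\zeta_i}\big[ f_i(x; \zeta_i) \big].$$

The functions $f_i$ are present on separate clients which can intermittently communicate amongst themselves. Our results also extend to the case when the functions are weighted with respect to the number of samples $n_i$.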

2.2 Assumptions

We will adopt the following standard assumptions:

  1. Each function is -smooth and for any and satisfies

    In particular, using and the convexity of this implies that

    (1)

    Further, by the Cauchy-Schwarz inequality, the smoothness of implies that the gradient of is Lipschitz and for and gives

    (2)
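  2. Each function $f_i$ is $\mu$-strongly convex and for any $x$ and $y$ satisfies

    $$f_i(y) \ge f_i(x) + \langle \nabla f_i(x), y - x \rangle + \frac{\mu}{2} \|y - x\|^2. \qquad (3)$$
  3. We are given an independent unbiased stochastic gradient $g_i(x)$ with $\mathbb{E}[g_i(x)] = \nabla f_i(x)$ and bounded variance

    $$\mathbb{E}\|g_i(x) - \nabla f_i(x)\|^2 \le \sigma^2. \qquad (4)$$

Note that we do not make any assumptions regarding the similarity between the functions $f_i$.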
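As a sanity check of what smoothness and strong convexity mean in practice, the following sketch (not from the paper) builds a hypothetical quadratic client function $f(x) = \frac{1}{2} x^\top A x - b^\top x$ and verifies numerically the standard quadratic upper and lower bounds associated with $\beta$-smoothness and $\mu$-strong convexity. The matrix `A`, the vector `b`, and the helper `check_pair` are illustrative assumptions. For a quadratic, $\beta$ and $\mu$ are the largest and smallest eigenvalues of `A`.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
M = rng.standard_normal((d, d))
A = M @ M.T + 0.5 * np.eye(d)      # symmetric positive definite Hessian
b = rng.standard_normal(d)

def f(x):
    return 0.5 * x @ A @ x - b @ x

def grad(x):
    return A @ x - b

eigs = np.linalg.eigvalsh(A)       # eigenvalues in ascending order
mu, beta = eigs[0], eigs[-1]       # strong-convexity and smoothness constants

def check_pair(x, y, tol=1e-9):
    """Check mu/2 ||y-x||^2 <= f(y) - f(x) - <grad f(x), y-x> <= beta/2 ||y-x||^2."""
    gap = f(y) - f(x) - grad(x) @ (y - x)
    sq = np.dot(y - x, y - x)
    return bool(gap <= 0.5 * beta * sq + tol and gap >= 0.5 * mu * sq - tol)

all_ok = all(
    check_pair(rng.standard_normal(d), rng.standard_normal(d)) for _ in range(100)
)
print("bounds hold on all sampled pairs:", all_ok)
```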

2.3 Notation

Here, we summarize the notation used throughout the paper:

  • $\|\cdot\|$ denotes the Euclidean norm, and $[N] := \{1, \dots, N\}$.

  • We have $N$ clients, $R$ rounds of communication, and $K$ (local) client update steps between two communication rounds.

  • $x^r$ represents the model parameters of the server after round $r$ of communication.

  • $y_{i,k}^r$ for $i \in [N]$ and $k \in [K]$ represents the model of client $i$ after performing $k$ local update steps in round $r$.

  • $c^r$ and $c_i^r$ represent the control variates at the server and client $i$ respectively, computed after round $r$. We always maintain the invariant that $c^r = \frac{1}{N} \sum_{i=1}^{N} c_i^r$.

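1: server input: initial parameters $x$, control variate $c$, and global step-size $\eta_g$
2: client input: for each client $i$, local control variate $c_i$, and local step-size $\eta_l$
3: for each communication round $r = 1, \dots, R$ do
4:     select a subset $\mathcal{S} \subseteq [N]$ of clients
5:     communicate $(x, c)$ to all clients $i \in \mathcal{S}$
6:     on each client $i \in \mathcal{S}$ do
7:         initialize local parameters $y_i \leftarrow x$
8:         for each local step $k = 1, \dots, K$ do
9:             compute a stochastic gradient $g_i(y_i)$ of $f_i$
10:             $y_i \leftarrow y_i - \eta_l \big( g_i(y_i) - c_i + c \big)$    (local updates with correction)
11:         end for
12:         $c_i^+ \leftarrow$ (i) $g_i(x)$, or (ii) $c_i - c + \frac{1}{K \eta_l}(x - y_i)$    (compute new control variate)
13:         communicate $(\Delta y_i, \Delta c_i) \leftarrow (y_i - x, \; c_i^+ - c_i)$
14:         $c_i \leftarrow c_i^+$    (update control variate)
15:     end client
16:     $\Delta x \leftarrow \frac{1}{|\mathcal{S}|} \sum_{i \in \mathcal{S}} \Delta y_i$ and $\Delta c \leftarrow \frac{1}{N} \sum_{i \in \mathcal{S}} \Delta c_i$    (aggregate client outputs)
17:     $x \leftarrow x + \eta_g \Delta x$ and $c \leftarrow c + \Delta c$    (update parameters and control)
18: end for
Algorithm 1 SCAFFOLD: Stochastic Controlled Averaging for federated learning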
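The algorithm box above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions rather than the authors' implementation: the function name `scaffold`, its argument names, the zero initialization of the control variates, and the toy quadratic clients $f_i(x) = \frac{a_i}{2}(x - b_i)^2$ are all invented for the example. It assumes the corrected local step subtracts $g_i(y_i) - c_i + c$ and that option (ii) takes the form $c_i - c + \frac{1}{K \eta_l}(x - y_i)$.

```python
import numpy as np

def scaffold(grads, x0, rounds, local_steps, lr_local, lr_global, n_sampled,
             option=2, seed=0):
    """Sketch of controlled averaging with client sampling.

    `grads[i](y)` is assumed to return a (possibly stochastic) gradient of f_i at y.
    """
    rng = np.random.default_rng(seed)
    n = len(grads)
    x = np.array(x0, dtype=float)
    c = np.zeros_like(x)                          # server control variate
    c_i = [np.zeros_like(x) for _ in range(n)]    # client control variates, mean equals c
    for _ in range(rounds):
        sampled = rng.choice(n, size=n_sampled, replace=False)
        dx = np.zeros_like(x)
        dc = np.zeros_like(x)
        for i in sampled:
            y = x.copy()
            for _ in range(local_steps):
                y -= lr_local * (grads[i](y) - c_i[i] + c)   # corrected local step
            if option == 1:
                c_new = grads[i](x)                          # option (i)
            else:
                c_new = c_i[i] - c + (x - y) / (local_steps * lr_local)  # option (ii)
            dx += (y - x) / n_sampled                        # average over sampled clients
            dc += (c_new - c_i[i]) / n                       # keeps c equal to the mean of c_i
            c_i[i] = c_new
        x = x + lr_global * dx
        c = c + dc
    return x, c, c_i

# Toy usage with hypothetical heterogeneous quadratic clients and exact gradients.
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([0.0, 1.0, -2.0, 4.0])
grads = [lambda y, ai=ai, bi=bi: ai * (y - bi) for ai, bi in zip(a, b)]
x_star = np.sum(a * b) / np.sum(a)                # minimizer of the average objective

x_out, c_out, c_i_out = scaffold(grads, x0=[0.0], rounds=2000, local_steps=5,
                                 lr_local=0.01, lr_global=1.0, n_sampled=2)
print("error:", abs(x_out[0] - x_star))
```

The server control variate is updated with a $1/N$ factor even though only the sampled clients report, so that it stays equal to the average of all client control variates, including the stale ones.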

3 Stochastic Controlled Averaging algorithm – without sampling of clients

Our algorithm (with client sampling) is presented in Algorithm 1. As a warm-up, we will study the case when all the clients participate in every round. We will use the following notation: $y_i$ represent the client models, $x$ is the aggregate server model, and $c_i$ and $c$ are the client and server control variates. Client $i$ in round $r$ performs the following updates:

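  • Starting from the shared global parameters $y_{i,0}^r = x^{r-1}$, we update the local parameters for $k \in [K]$

    $$y_{i,k}^r = y_{i,k-1}^r - \eta_l \big( g_i(y_{i,k-1}^r) - c_i^{r-1} + c^{r-1} \big). \qquad (5)$$
  • Update the control iterates using any of the following options:

    $$c_i^r = \begin{cases} g_i(x^{r-1}) & \text{(option I)}, \\ \frac{1}{K} \sum_{k=1}^{K} g_i(y_{i,k-1}^r) & \text{(option II)}. \end{cases} \qquad (6)$$
  • Compute the new global parameters and global control variate

    $$x^r = x^{r-1} + \frac{\eta_g}{N} \sum_{i=1}^{N} \big( y_{i,K}^r - x^{r-1} \big), \qquad c^r = \frac{1}{N} \sum_{i=1}^{N} c_i^r. \qquad (7)$$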

Note that if we remove the correction in (5), or equivalently always set $c_i \equiv 0$ in (6), we recover the standard FedAvg algorithm. As we discussed previously, the main issue with FedAvg is that the updates of two clients $i$ and $j$ may be very different from each other, leading to ‘drift’. The correction is introduced precisely to reduce this drift. For example, suppose that we could set the control variate at every step to be $c_i = g_i(y_i)$ in (5), so that $c = \frac{1}{N} \sum_{j} g_j(y_j)$. Then the update becomes identical for all clients:

$$y_i \leftarrow y_i - \frac{\eta_l}{N} \sum_{j=1}^{N} g_j(y_j).$$

Unfortunately, we cannot set $c_i = g_i(y_i)$, since computing the corresponding global control iterate $c$ requires all the clients to communicate with each other at every step. We will instead use the easily computable $g_i(x)$, with $c_i^r = g_i(x^{r-1})$ for the whole round (option I in (6)). Since the gradient of $f_i$ is Lipschitz, we can hope that $g_i(x) \approx g_i(y_i)$ as long as our local updates are not too large and $y_i \approx x$. This idea of control iterates is inspired by (and is similar to) those used for variance reduction in SVRG (johnson2013accelerating) and SAGA (defazio2014saga).

There are other choices of the control variate $c_i$ which are correlated with $g_i(y_i)$ and hence also suffice. One could use an update similar to that of SARAH (nguyen2017stochastic) and continuously perform local updates to $c_i$ instead of keeping it fixed. In another approach, option II in (6), which is known as gradient-tracking in decentralized algorithms (shi2015extra; nedich2016geometrically), uses

$$c_i^r = \frac{1}{K} \sum_{k=1}^{K} g_i(y_{i,k-1}^r).$$

By using an average of many stochastic gradients, option II has lower variance at the cost of slightly higher bias. This option results in a method similar to a very recent independent work of anon2019variance. However, in general (i) their algorithm and proof cannot be extended to support client sampling, and (ii) they do not use a global step-size and hence their rates have a worse dependence on the number of clients $N$.
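The claim that option II averages the stochastic gradients seen during the round can be verified directly. The sketch below is an illustration, not part of the paper: it assumes the implementation-friendly form of option II is $c_i - c + \frac{1}{K \eta_l}(x - y_i)$ and that each local step subtracts $\eta_l (g - c_i + c)$. The quadratic client, the noise level, and the names `noisy_grad`, `used`, and `option_two` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, K, lr = 3, 7, 0.1
x = rng.standard_normal(dim)        # server model at the start of the round
c_i = rng.standard_normal(dim)      # arbitrary client control variate
c = rng.standard_normal(dim)        # arbitrary server control variate
A = np.diag([1.0, 2.0, 3.0])
b = rng.standard_normal(dim)

def noisy_grad(y):
    """Stochastic gradient of a hypothetical quadratic client objective."""
    return A @ y - b + 0.1 * rng.standard_normal(dim)

y = x.copy()
used = []                           # gradients actually used in the local steps
for _ in range(K):
    g = noisy_grad(y)
    used.append(g)
    y = y - lr * (g - c_i + c)      # corrected local step

option_two = c_i - c + (x - y) / (K * lr)   # computed from the model difference only
average_grad = np.mean(used, axis=0)        # explicit average of the K gradients
print(np.max(np.abs(option_two - average_grad)))
```

The two quantities agree up to rounding, so a client can form the option II control variate from its start and end points without storing or recomputing any gradients.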

The final output of the algorithm is a weighted average for some positive weights for

(8)

4 Convergence analysis – without sampling

We show the following rate of convergence for strongly convex functions. Similar extensions can be derived for the general convex and non-convex settings.

Theorem 4. Suppose that each of the functions satisfies the assumptions of Section 2.2. Then, there exist weights and local step-sizes such that for any the output (8) generated using (5)–(7) for any satisfies the following bound, where $\mathcal{O}$ suppresses constant factors and $\tilde{\mathcal{O}}$ hides logarithmic terms:

Thus by setting , we get a communication complexity of . When the variance is large (which is usually the case) or if the required accuracy is small, the rate of convergence is typically dominated by . The result above shows that increasing the number of local steps as well as number of clients can decrease the number of communication rounds required.

Theorem 4 improves upon the best-known upper bounds even when all functions are identical by a factor . In comparison, stich2019error show a communication complexity of for identical functions. When our communication complexity becomes and nearly matches the lower bound of for distinct functions by arjevani2015communication. The improved square-root dependence on condition number can be achieved via acceleration (nesterov2018lectures), a direction which we do not explore here.

Finally, note that we are free to choose an which is even larger than while retaining the same rate. However, the local step-size also correspondingly becomes smaller. In the limit when , and we recover SGD with a large batch size of . Thus, we fail to show a strict improvement due to picking a small . This is not surprising since the lower-bound by arjevani2015communication rules out the possibility of improvement over SGD for general convex functions. In fact, even when all functions are identical (in which case the lower bound of arjevani2015communication does not apply), showing a strict advantage of taking local steps remains an open question.

4.1 Usefulness of control variates

Let us examine how our correction using control variates might help us. We first show that by using the control variates, the server update direction does not change. We drop the indices for the round and local step whenever obvious from context.

  1. The local update (5) of SCAFFOLD aggregated across clients is similar to that of FedAvg

While the aggregate update direction remains unchanged, the drift across the clients is reduced by our use of control variates. For simplicity, in this section we only examine (option I) choice of in (6) and delay the general case for Section 6.
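A small numerical illustration of these two points, under the assumption that the correction replaces each client direction $g_i$ by $g_i - c_i + c$ with $c$ the average of the $c_i$: the client-averaged direction is unchanged for any control variates, while the spread across clients shrinks when each $c_i$ is close to $g_i$. The arrays below are random stand-ins invented for the example, not quantities from the paper.

```python
import numpy as np

rng = np.random.default_rng(2)
n_clients, dim = 6, 4
g = rng.standard_normal((n_clients, dim))       # stand-ins for client gradients

# Arbitrary control variates: the average direction is preserved.
c_i = rng.standard_normal((n_clients, dim))
c = c_i.mean(axis=0)
corrected = g - c_i + c
print(np.allclose(corrected.mean(axis=0), g.mean(axis=0)))

# Control variates that track the client gradients closely: the spread collapses.
c_i_good = g + 0.01 * rng.standard_normal((n_clients, dim))
corrected_good = g - c_i_good + c_i_good.mean(axis=0)
spread_before = g.std(axis=0).max()
spread_after = corrected_good.std(axis=0).max()
print(spread_before, spread_after)
```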

  1. The drift from the starting point due to the local updates of SCAFFOLD is bounded as follows for any and

The proof of (4.1) can be found in Appendix D.1. Suppose we start from the optimum point and , and further . Then we can set in (4.1) to get

Thus, SCAFFOLD overcomes the problem of client drift (at least close to the optimum). Further, the above argument proves that, unlike for FedAvg, the optimum is a fixed point of SCAFFOLD (up to the noise in the stochastic gradients).
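The fixed-point property can be illustrated on a toy example. The sketch below is not from the paper: it assumes two hypothetical one-dimensional quadratic clients $f_i(x) = \frac{a_i}{2}(x - b_i)^2$ with exact gradients, control variates set to the client gradients at the optimum, and a corrected local step of the form $y \leftarrow y - \eta_l (g_i(y) - c_i + c)$. The helper `local_run` and all constants are invented for the example.

```python
import numpy as np

a = np.array([1.0, 10.0])                 # client curvatures
b = np.array([0.0, 1.0])                  # client optima
x_star = np.sum(a * b) / np.sum(a)        # optimum of the average objective

c_i = a * (x_star - b)                    # exact client gradients at the optimum
c = c_i.mean()                            # zero up to rounding, since x_star is optimal

def local_run(x, local_steps, lr, use_correction):
    """Run local gradient steps on every client, starting from the server model x."""
    y = np.full_like(a, x)
    for _ in range(local_steps):
        g = a * (y - b)
        correction = (c - c_i) if use_correction else 0.0
        y = y - lr * (g + correction)
    return y

y_controlled = local_run(x_star, local_steps=10, lr=0.05, use_correction=True)
y_fedavg = local_run(x_star, local_steps=10, lr=0.05, use_correction=False)

print("max client movement with correction:   ", np.max(np.abs(y_controlled - x_star)))
print("max client movement without correction:", np.max(np.abs(y_fedavg - x_star)))
```

With the correction every client stays at the optimum, whereas without it each client moves towards its own local optimum.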

4.2 Proof summary

We can show the following progress for our algorithm between two communication rounds.

Lemma 1

(one round progress) Suppose our updates satisfy (4.1) and the assumptions of Section 2.2. For any step-size satisfying and effective step-size ,

where is the drift caused by the local updates on the clients

Here, is the expectation over all randomness in round , and conditioned on . The lemma above is valid for any algorithm which satisfies (4.1) and is hence also valid for FedAvg. The difference between the two algorithms is the bound on the drift term, which for SCAFFOLD becomes smaller as we approach the optimum and does not depend on the heterogeneity across the functions.

Lemma 2

(bounded drift) Suppose our updates satisfy (4.1) and the assumptions of Section 2.2. For any step-size satisfying , we can bound the drift as

Note that (4.1) was shown only for (option I) choice of in (6). Option II and other variations are analyzed in Section 6.

Proof of Theorem 4

Combining Lemmas 1 and 2 gives the following recurrence

Rearranging the equation above gives the following one step progress

with the notation that , , and . Now, Lemma 5 is applicable with . Using the step-size and weights as defined in Lemma 5, the following holds for all :

Using convexity of completes the proof of the theorem.

Note that the local step-size we need to take to get optimal rates is bounded as . The reason why we need to scale by is because if the functions are completely unrelated to each other, then taking multiple local-steps may not really help the optimization of the average function. In practice, larger step-sizes can be used since the functions are typically more closely related to each other.

5 Stochastic Controlled Averaging algorithm – with sampling

Control variates can also be used even when only a small subset of devices (say ) participate each round. We will describe the algorithm using notation here which is convenient for the proofs: represent the client models, is the aggregate server model, and and are the client and server control variates. For an equivalent description which is easier to implement, we refer to Algorithm 1. The server maintains a global control variate as before and each client maintains its own control variate . In round , a subset of clients of size are sampled uniformly from . Suppose that every client performs the following updates

  • Starting from the shared global parameters , we update the local parameters for

    (9)
  • Update the control iterates using (option II):

    (10)

    We update the local control variates only for clients

    (11)
  • Compute the new global parameters and global control variate using only updates from the clients :

    (12)

Note that the clients are agnostic to the sampling and their updates are identical to when all clients are participating. Also note that the control variate choice (10) corresponds to (option II) in step 12 of Algorithm 1. Further, the updates of the clients that were not sampled are forgotten and are defined only to make the proofs easier. While actually implementing the method, only the sampled clients participate and the rest remain inactive (see Algorithm 1).

After running for rounds of communication, the final output of the algorithm is, as before, a weighted average for some positive weights for

(13)

To get some intuition about the new method, examine what happens when the number of machines () is large and . Let , so that the local and global iterates are the same, i.e. . Also suppose that implying . Then the update in round can be written for a randomly chosen to be

In this setting, SCAFFOLD turns out to be equivalent to SAGA (defazio2014saga). In the other extreme, if all the data is on a single machine () and , the method reduces to the standard SGD updates as expected. Thus, SCAFFOLD captures a wide range of algorithms and their corresponding rates as special cases, while simultaneously generalizing to many new useful settings of the parameters , , and .

6 Convergence analysis – with sampling of clients

We prove the following rate of convergence for strongly convex functions. One can extend our technique to derive rates for general convex functions and non-convex functions.

Theorem 6. Suppose that each of the functions satisfies the assumptions of Section 2.2. Then, there exist weights and local step-sizes such that for any the output (13) of Algorithm 1 using option (ii) in step 12 satisfies for all :

where . Setting gives a communication complexity of . If all clients participate and , we recover the communication complexity of given by Theorem 4. This proves Theorem 4 even when option (II) is used for the update of the control variate. If , then the additional term is necessary since we would need to communicate with every device at least once. Also note that when , Theorem 6 recovers the linear rate which matches that of SAGA. In fact, when we obtain an interesting generalization of SAGA with additional local steps.

Instead of counting each round of communication as a single unit of cost, we can count the total amount of communication received from clients. This is an important metric since it represents the amount of work done by the clients and also captures the cost to privacy. With this metric the communication cost with all devices participating is . In contrast, with sampling the algorithm has a cost . Since typically , this represents a significant reduction.

6.1 Overcoming sampling with control variates

The main question is what variant of properties (4.1) and (4.1) still hold when we sample a small number of clients. Consider the accumulated local updates of the clients starting from local models :

  1. It is easy to see that where the expectation is over the sampling

Thus, in expectation over the sampling of , our update matches that of the usual FedAvg. We now examine what we can say about the drift of the local client parameters. The challenge with bounding the drift with client sampling is that even with re-syncing every step (i.e. ) SGD may drift at the optimum due to the variance across the clients. Thus, the use of control variates is critical here. We now analyze the general case with any choice of control variate.

  1. The drift from the starting point due to the local update of SCAFFOLD in local step or rounds is bounded as below for any and

The proof of (6.1) can be found in Appendix E.1. By comparing (6.1) with (4.1) we see that the final term, which depends on the control iterates, is extra. Thus any choice of which ‘learns’ (up to the noise in the stochastic gradients) as the algorithm progresses would suffice. For example, suppose we start at the optimum . Then, by setting and assuming , we see once again that using in (6.1) proves that there is no drift at the optimum

The challenge here is then to bound the term as the algorithm progresses.

6.2 Proof summary

Just like in the full sampling case, we can prove the following progress between two communication rounds. For notational convenience, assume that for all and .

Lemma 3

(one round progress) Suppose our updates satisfy (4.1) and the assumptions of Section 2.2. Then the following holds for any step-size satisfying , effective step-size , and control variates updated using (10),

where is the error in our control variate defined as

and is the drift caused by the local updates on the clients

In addition to keeping track of the distance from the optimum as we did in Lemma 1, we also need to keep track of how far our control variate is from its value at the optimum using . This is because there is some ‘lag’ in our updates of the control variates since only a small subset of them are updated each round. We can also bound the drift term .

Lemma 4

(bounded drift) Suppose our step-sizes satisfy and satisfies the assumptions of Section 2.2. Then, for any global we can bound the drift as

Here again, the optimal step-size should not scale as but should be much larger depending on the similarity between the functions. There is an additional bound on the step-size depending on the number of clients sampled . However, typically is very small making reasonably large. Hence the condition that can be safely ignored while setting the learning rate in practice.

Proof of Theorem 6

Combining Lemmas 3 and 4 and rearranging the terms we can show that for any weights

where , , and for any . The rest proceeds exactly as in the proof of Theorem 4.

7 Conclusion

We observe that FedAvg may experience ‘drift’ due to the updates of heterogeneous local clients, leading to slow convergence and necessitating careful learning rate scheduling. We instead propose SCAFFOLD, a new method which uses control variates to overcome this issue and prove that it has excellent theoretical properties. We believe that the increased stability of SCAFFOLD to heterogeneity of the clients even with sampling would make it easy to tune in practice. This, along with the ease of implementation of SCAFFOLD, we believe facilitates easy adoption.

References

Appendix A Algorithm without sampling

Here, we outline our algorithm when all devices participate every round.

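1: input: initial parameters $x$, control variate $c$, global and local step-sizes $\eta_g$ and $\eta_l$
2: initialize: for each client $i$, control variate $c_i$
3: for each communication round $r = 1, \dots, R$ do
4:     communicate $(x, c)$ to all clients $i \in [N]$
5:     on each client $i \in [N]$ do
6:         initialize local parameters $y_i \leftarrow x$
7:         for each local step $k = 1, \dots, K$ do
8:             compute a stochastic gradient $g_i(y_i)$ of $f_i$
9:             $y_i \leftarrow y_i - \eta_l \big( g_i(y_i) - c_i + c \big)$    (local updates with correction)
10:         end for
11:         $c_i^+ \leftarrow$ (i) $g_i(x)$, or (ii) $c_i - c + \frac{1}{K \eta_l}(x - y_i)$    (compute new control variate)
12:         communicate $(\Delta y_i, \Delta c_i) \leftarrow (y_i - x, \; c_i^+ - c_i)$
13:         $c_i \leftarrow c_i^+$    (update control variate)
14:     end client
15:     $\Delta x \leftarrow \frac{1}{N} \sum_{i=1}^{N} \Delta y_i$ and $\Delta c \leftarrow \frac{1}{N} \sum_{i=1}^{N} \Delta c_i$    (aggregate client outputs)
16:     $x \leftarrow x + \eta_g \Delta x$ and $c \leftarrow c + \Delta c$    (update parameters and control)
17: end for
Algorithm 2 Stochastic Controlled Averaging (without sampling)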

Appendix B Some technical lemmas

In this section we cover some technical lemmas which are useful for computations later on.

The lemma below is useful to unroll recursions and derive convergence rates.

Lemma 5 (convergence rate)

For every non-negative sequence and any parameters , , , , there exists a constant step-size and weights such that for ,

By substituting the value of , we observe that we end up with a telescoping sum and estimate