LAG: Lazily Aggregated Gradient for Communication-Efficient Distributed Learning

05/25/2018 ∙ by Tianyi Chen, et al.

This paper presents a new class of gradient methods for distributed machine learning that adaptively skip the gradient calculations to learn with reduced communication and computation. Simple rules are designed to detect slowly-varying gradients and, therefore, trigger the reuse of outdated gradients. The resultant gradient-based algorithms are termed Lazily Aggregated Gradient --- justifying our acronym LAG used henceforth. Theoretically, the merits of this contribution are: i) the convergence rate is the same as batch gradient descent in strongly-convex, convex, and nonconvex smooth cases; and, ii) if the distributed datasets are heterogeneous (quantified by certain measurable constants), the communication rounds needed to achieve a targeted accuracy are reduced thanks to the adaptive reuse of lagged gradients. Numerical experiments on both synthetic and real data corroborate a significant communication reduction compared to alternatives.

1 Introduction

In this paper, we develop communication-efficient algorithms to solve the following problem

(1)    \min_{\theta \in \mathbb{R}^d} \; \mathcal{L}(\theta) := \sum_{m=1}^{M} L_m(\theta)

where θ ∈ R^d is the unknown parameter vector, and L and {L_m}_{m=1}^M are smooth (but not necessarily convex) functions. Problem (1) naturally arises in a number of areas, such as multi-agent optimization (Nedic and Ozdaglar, 2009), distributed signal processing (Giannakis et al., 2016; Schizas et al., 2008), and distributed machine learning (Dean et al., 2012). Considering the distributed machine learning paradigm, each L_m is itself a sum of functions, e.g., L_m(θ) = Σ_{n ∈ N_m} ℓ(θ; x_n), where ℓ(θ; x_n) is the loss (e.g., the square or logistic loss) of the model vector θ evaluated at the training sample x_n. While machine learning tasks are traditionally carried out at a single server, for datasets with massive numbers of samples, running gradient-based iterative algorithms at a single server can be prohibitively slow; e.g., the server needs to sequentially compute gradient components given limited processors. A simple yet popular solution in recent years is to parallelize the training across multiple computing units (a.k.a. workers) (Dean et al., 2012). Specifically, assuming the batch samples are distributedly stored across a total of M workers, with worker m holding the samples indexed by N_m, a globally shared model θ is updated at the central server by aggregating the gradients computed by the workers. Due to bandwidth and privacy concerns, each worker will not upload its data to the server, so the learning task needs to be performed by iteratively communicating with the server.
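To make the distributed setup concrete, the following minimal Python sketch partitions a toy dataset across M workers and computes the per-worker gradients that the server would aggregate; the least-squares loss and all variable names are our own illustrative choices, not part of the paper.

```python
import numpy as np

# Toy instance of problem (1): L(theta) = sum_m L_m(theta), where worker m holds a
# local least-squares loss L_m(theta) = 0.5 * ||A_m theta - b_m||^2 (illustrative only).
rng = np.random.default_rng(0)
M, d, n_per_worker = 9, 5, 50            # workers, model dimension, samples per worker

# Local data (A_m, b_m) stored at each worker; never uploaded to the server.
workers = [(rng.standard_normal((n_per_worker, d)),
            rng.standard_normal(n_per_worker)) for _ in range(M)]

def local_grad(theta, A, b):
    """Gradient of the local loss L_m(theta) = 0.5 * ||A theta - b||^2."""
    return A.T @ (A @ theta - b)

theta = np.zeros(d)
grads = [local_grad(theta, A, b) for A, b in workers]    # one gradient per worker
full_grad = sum(grads)                                   # what the server aggregates
print(np.linalg.norm(full_grad))
```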

We are particularly interested in scenarios where communication between the central server and the local workers is costly, as is the case with the Federated Learning paradigm (McMahan et al., 2017; Smith et al., 2017) and cloud-edge AI systems (Stoica et al., 2017). In those cases, communication latency is the bottleneck of overall performance. More precisely, the communication latency is a result of initiating communication links, queueing, and propagating the message. For sending small messages, e.g., the d-dimensional model or aggregated gradient, this latency dominates the message-size-dependent transmission latency. Therefore, it is important to reduce the number of communication rounds, even more so than the bits per round. In short, our goal is to find θ that minimizes (1) with as little communication overhead as possible.

1.1 Prior art

To put our work in context, we review prior contributions that we group in two categories.

Large-scale machine learning.

Solving (1) at a single server has been extensively studied for large-scale learning tasks, where the "workhorse approach" is the simple yet efficient stochastic gradient descent (SGD) (Robbins and Monro, 1951; Bottou, 2010; Bottou et al., 2016). For learning beyond a single server, distributed parallel machine learning is an attractive solution to tackle large-scale learning tasks, where the parameter server architecture is the most commonly used one (Dean et al., 2012; Li et al., 2014). Different from the single-server case, a parallel implementation of batch gradient descent (GD) is a popular choice, since SGD, which has low per-iteration complexity, requires a large number of iterations and thus communication rounds (McMahan and Ramage, 2017). For traditional parallel learning algorithms, however, latency, bandwidth limits, and unexpected drains on resources that delay the update of even a single worker will slow down the entire system operation. Recent research efforts in this line have been centered on understanding asynchronous-parallel algorithms to speed up machine learning by eliminating costly synchronization; e.g., (Cannelli et al., 2016; Sun et al., 2017; Peng et al., 2016; Recht et al., 2011; Liu et al., 2015).

Communication-efficient learning.

Going beyond single-server learning, the high communication overhead becomes the bottleneck of overall system performance (McMahan and Ramage, 2017), and communication-efficient learning algorithms have gained popularity (Jordan et al., 2018; Zhang et al., 2013). Distributed learning approaches based on quantized (gradient) information have been developed, e.g., (Suresh et al., 2017), but they only reduce the required bandwidth per communication round, not the number of rounds. For machine learning tasks where the loss function is convex and its conjugate dual is expressible, dual coordinate ascent-based approaches have been shown to yield impressive empirical performance (Smith et al., 2017; Jaggi et al., 2014; Ma et al., 2017). However, these algorithms run in a double-loop manner, and their communication reduction has not been formally quantified. To reduce communication by accelerating convergence, approaches leveraging (inexact) second-order information have been studied in (Shamir et al., 2014; Zhang and Lin, 2015). Roughly speaking, the algorithms in (Smith et al., 2017; Jaggi et al., 2014; Ma et al., 2017; Shamir et al., 2014; Zhang and Lin, 2015) reduce communication by increasing local computation (relative to GD), while our method does not increase local computation. In settings different from the one considered in this paper, communication-efficient approaches have recently been studied with triggered communication protocols (Liu et al., 2017; Lan et al., 2017). Beyond convergence guarantees, however, no theoretical justification for communication reduction has been established in (Liu et al., 2017). While a sublinear convergence rate can be achieved by the algorithms in (Lan et al., 2017), the proposed gradient selection rule is nonadaptive and requires double-loop iterations.

1.2 Our contributions

Before introducing our approach, we revisit the popular GD method for (1) in the setting of one parameter server and M workers: At iteration k, the server broadcasts the current model θ^k to all the workers; every worker m computes ∇L_m(θ^k) and uploads it to the server; and once gradients from all M workers have been received, the server updates the model parameters via

(2)    \theta^{k+1} = \theta^k - \alpha \nabla^k, \quad \text{with} \quad \nabla^k := \sum_{m=1}^{M} \nabla L_m(\theta^k)

where α is a stepsize, and ∇^k is the aggregated gradient that summarizes the model change. To implement (2), the server has to communicate with all M workers to obtain the fresh gradients {∇L_m(θ^k)}_{m=1}^M.

In this context, the present paper puts forward a new batch gradient method (as simple as GD) that can skip communication at certain rounds, which justifies the term Lazily Aggregated Gradient (LAG). With its derivation deferred to Section 2, LAG resembles (2) and is given by

(3)    \theta^{k+1} = \theta^k - \alpha \nabla^k, \quad \text{with} \quad \nabla^k := \sum_{m=1}^{M} \nabla L_m(\hat{\theta}_m^k)

where each θ̂_m^k is either the current iterate θ^k, when m ∈ M^k, or an outdated copy at which worker m's gradient was last computed, when m ∉ M^k. Instead of requesting a fresh gradient from every worker as in (2), the twist is to obtain ∇^k by refining the previous aggregated gradient ∇^{k-1}; that is, using only the new gradients from the selected workers in M^k, while reusing the outdated gradients from the rest of the workers. Therefore, with δ∇_m^k := ∇L_m(θ^k) − ∇L_m(θ̂_m^{k-1}), LAG in (3) is equivalent to

(4)    \theta^{k+1} = \theta^k - \alpha \nabla^{k-1} - \alpha \sum_{m \in \mathcal{M}^k} \delta\nabla_m^k

where δ∇_m^k is the difference between two evaluations of ∇L_m at the current iterate θ^k and the old copy θ̂_m^{k-1}. If ∇^{k-1} is stored at the server, this simple modification scales down the number of communication rounds per iteration from GD's M to LAG's |M^k|.
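As a concrete sanity check on the relation between the GD update (2) and the LAG recursion (4), the following self-contained sketch (our own toy least-squares setup, not the paper's code) refines a stored aggregated gradient with the corrections δ∇_m^k from a chosen subset of workers and compares the result against a full GD step.

```python
import numpy as np

# Toy least-squares setup (illustrative only): M workers, each with local data (A_m, b_m).
rng = np.random.default_rng(1)
M, d = 4, 3
data = [(rng.standard_normal((20, d)), rng.standard_normal(20)) for _ in range(M)]
grad = lambda th, A, b: A.T @ (A @ th - b)        # gradient of 0.5*||A th - b||^2

alpha = 1e-2
theta_old = rng.standard_normal(d)                # old copies held by all workers here
theta = rng.standard_normal(d)                    # current iterate theta^k
stale = [grad(theta_old, A, b) for A, b in data]  # outdated gradients known to the server
agg = sum(stale)                                  # previous aggregate nabla^{k-1}

# GD update (2): every worker uploads a fresh gradient at theta^k.
theta_gd = theta - alpha * sum(grad(theta, A, b) for A, b in data)

# LAG update (4): only the workers in M_k upload delta = fresh - stale,
# and the server refines the stored aggregate instead of recomputing it.
M_k = {0, 2}                                      # subset of workers that communicate
for m in M_k:
    fresh = grad(theta, *data[m])
    agg += fresh - stale[m]
    stale[m] = fresh                              # worker m's copy is now up to date
theta_lag = theta - alpha * agg

# With M_k = {0, 1, 2, 3} the two updates coincide; with a strict subset, LAG is a
# cheaper approximation that reuses the stale gradients of the silent workers.
print(np.linalg.norm(theta_gd - theta_lag))
```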

We develop two different rules to select M^k. The first rule is executed by the parameter server (PS), and the second one by every worker (WK). At iteration k,

LAG-PS: the server determines M^k and sends θ^k to the workers in M^k; each worker m ∈ M^k computes ∇L_m(θ^k) and uploads δ∇_m^k; the remaining workers do nothing; the server updates via (4);

LAG-WK: the server broadcasts θ^k to all workers; every worker computes ∇L_m(θ^k) and checks whether it belongs to M^k; only the workers in M^k upload δ∇_m^k; the server updates via (4).

See a comparison of two LAG variants with GD in Table 1.

Figure 1: LAG in a parameter server setup.

Naively reusing outdated gradients, while saving communication per iteration, can increase the total number of iterations. To keep this number in check, we judiciously design simple trigger rules so that LAG can: i) achieve the same order of convergence rates (and thus iteration complexities) as batch GD in the strongly convex, convex, and nonconvex smooth cases; and ii) require reduced communication to achieve a targeted learning accuracy when the distributed datasets are heterogeneous (measured by a quantity specified later). In certain learning settings, LAG requires only a small fraction of the communication of GD. Empirically, we found that LAG can reduce the communication required by GD and other distributed parallel learning methods by several orders of magnitude.

Notation. Bold lowercase letters denote column vectors; (·)^T denotes transposition; ‖θ‖ denotes the ℓ2-norm of θ; and inequalities between vectors are defined entrywise.

Table 1: A comparison of the per-iteration communication, computation, and memory requirements of GD (2), LAG-PS, and LAG-WK. PS denotes the parameter server, WK m denotes worker m, PS→WK m is the communication link from the server to worker m, and WK m→PS is the communication link from worker m to the server.

2 LAG: Lazily Aggregated Gradient Approach

In this section, we formally develop our LAG method, and present the intuition and basic principles behind its design. The original idea of LAG comes from a simple rewriting of the GD iteration (2) as

(5)    \theta^{k+1} = \theta^k - \alpha \sum_{m=1}^{M} \nabla L_m(\theta^{k-1}) - \alpha \sum_{m=1}^{M} \big(\nabla L_m(\theta^k) - \nabla L_m(\theta^{k-1})\big)

Let us view ∇L_m(θ^k) − ∇L_m(θ^{k-1}) as a refinement to ∇L_m(θ^{k-1}), and recall that obtaining this refinement requires a round of communication between the server and worker m. Therefore, to save communication, we can skip the server's communication with worker m if this refinement is small compared to the old gradient; that is, ‖∇L_m(θ^k) − ∇L_m(θ^{k-1})‖ ≪ ‖∇L_m(θ^{k-1})‖.

Generalizing on this intuition, given the generic outdated gradient components ∇L_m(θ̂_m^{k-1}), with θ̂_m^{k-1} being an old copy of the iterate held by worker m, if communicating with some workers will bring only small gradient refinements, we skip those communications (the corresponding workers are collected in the set M_c^k) and end up with

(6a)    \nabla^k = \sum_{m \in \mathcal{M}^k} \nabla L_m(\theta^k) + \sum_{m \in \mathcal{M}_c^k} \nabla L_m(\hat{\theta}_m^{k-1})
(6b)    \theta^{k+1} = \theta^k - \alpha \nabla^k

where M^k and M_c^k are the sets of workers that do and do not communicate with the server, respectively. It is easy to verify that (6) is identical to (3) and (4). Comparing (2) with (6b), when M_c^k includes more workers, more communication is saved, but θ^k is updated by a coarser aggregated gradient.

Key to addressing this communication versus accuracy tradeoff is a principled criterion to select a subset of workers that do not communicate with the server at each round. To achieve this “sweet spot,” we will rely on the fundamental descent lemma. For GD, it is given as follows (Nesterov, 2013).

Lemma 1 (GD descent in objective). Suppose L is L-smooth, and θ^{k+1} is generated by running the one-step GD iteration (2) given θ^k and stepsize α. Then the objective values satisfy

(7)    \mathcal{L}(\theta^{k+1}) - \mathcal{L}(\theta^k) \le -\Big(\alpha - \frac{L\alpha^2}{2}\Big)\,\big\|\nabla \mathcal{L}(\theta^k)\big\|^2.

Likewise, for the LAG iteration in (6), the following holds; its proof is given in the supplementary material.

Lemma 2 (LAG descent in objective). Suppose L is L-smooth, and θ^{k+1} is generated by running the one-step LAG iteration (4) given θ^k. Then the objective values satisfy (cf. δ∇_m^k in (4))

(8)    \mathcal{L}(\theta^{k+1}) - \mathcal{L}(\theta^k) \le -\frac{\alpha}{2}\big\|\nabla \mathcal{L}(\theta^k)\big\|^2 + \frac{\alpha}{2}\Big\|\sum_{m \in \mathcal{M}_c^k} \delta\nabla_m^k\Big\|^2 - \Big(\frac{1}{2\alpha} - \frac{L}{2}\Big)\big\|\theta^{k+1} - \theta^k\big\|^2.

Lemmas 1 and 2 estimate the objective value descent obtained by performing one iteration of the GD and LAG methods, respectively, conditioned on a common iterate θ^k. GD finds θ^{k+1} by performing M rounds of communication with all the workers, while LAG yields θ^{k+1} by performing only |M^k| rounds of communication with a selected subset of workers. Our pursuit is to select M_c^k to ensure that LAG enjoys a larger per-communication descent than GD; that is,

(9)    \frac{\mathcal{L}(\theta_{\rm LAG}^{k+1}) - \mathcal{L}(\theta^k)}{|\mathcal{M}^k|} \le \frac{\mathcal{L}(\theta_{\rm GD}^{k+1}) - \mathcal{L}(\theta^k)}{M}.

If we choose the standard stepsize α = 1/L in Lemmas 1 and 2, it follows that

(10a)    \mathcal{L}(\theta_{\rm GD}^{k+1}) - \mathcal{L}(\theta^k) \le -\frac{1}{2L}\big\|\nabla \mathcal{L}(\theta^k)\big\|^2
(10b)    \mathcal{L}(\theta_{\rm LAG}^{k+1}) - \mathcal{L}(\theta^k) \le -\frac{1}{2L}\big\|\nabla \mathcal{L}(\theta^k)\big\|^2 + \frac{1}{2L}\Big\|\sum_{m \in \mathcal{M}_c^k} \delta\nabla_m^k\Big\|^2.

Plugging (10) into (9), and rearranging terms, (9) is equivalent to

(11)    \Big\|\sum_{m \in \mathcal{M}_c^k} \delta\nabla_m^k\Big\|^2 \le \frac{|\mathcal{M}_c^k|}{M}\big\|\nabla \mathcal{L}(\theta^k)\big\|^2.

Note that, by the Cauchy-Schwarz inequality, we have

(12)    \Big\|\sum_{m \in \mathcal{M}_c^k} \delta\nabla_m^k\Big\|^2 \le |\mathcal{M}_c^k| \sum_{m \in \mathcal{M}_c^k} \big\|\delta\nabla_m^k\big\|^2;

hence, if we can further show that

(13)    \big\|\delta\nabla_m^k\big\|^2 \le \frac{1}{M^2}\big\|\nabla \mathcal{L}(\theta^k)\big\|^2, \qquad \forall m \in \mathcal{M}_c^k,

then, since |M_c^k| ≤ M, combining (12) and (13) proves that (11) holds and thus (9) also holds.

However, directly checking (13) at each worker is expensive, since i) obtaining the full gradient ∇L(θ^k) requires information from all the workers; and ii) each worker only knows its own local gradient. Instead, we approximate ‖∇L(θ^k)‖² in (13) by

(14)    \big\|\nabla \mathcal{L}(\theta^k)\big\|^2 \approx \frac{1}{\alpha^2} \sum_{d=1}^{D} \xi_d \big\|\theta^{k+1-d} - \theta^{k-d}\big\|^2

where {ξ_d}_{d=1}^D are constant weights. The rationale here is that, as L is smooth, ∇L(θ^k) cannot be very different from the recent gradients, which are in turn reflected by the recent iterate differences (cf. θ^{k+1-d} − θ^{k-d} = −α∇^{k-d}).

Building upon (13) and (14), we will include worker m in M_c^k of (6) if it satisfies

(15a)    \big\|\nabla L_m(\hat{\theta}_m^{k-1}) - \nabla L_m(\theta^k)\big\|^2 \le \frac{1}{\alpha^2 M^2} \sum_{d=1}^{D} \xi_d \big\|\theta^{k+1-d} - \theta^{k-d}\big\|^2.
Condition (15a) is checked at the worker side after each worker receives θ^k from the server and computes its ∇L_m(θ^k). If broadcasting is also costly, we can resort to the following server-side rule:
(15b)    L_m^2 \big\|\hat{\theta}_m^{k-1} - \theta^k\big\|^2 \le \frac{1}{\alpha^2 M^2} \sum_{d=1}^{D} \xi_d \big\|\theta^{k+1-d} - \theta^{k-d}\big\|^2.

The values of D and {ξ_d} admit simple choices, e.g., the equal weights used in the simulations in Section 4.
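A minimal sketch of the worker-side check follows, under the assumption that condition (15a) has the form reconstructed above (squared change in the local gradient versus a weighted sum of recent iterate differences); the constants and variable names are illustrative.

```python
import numpy as np

def lag_wk_trigger(grad_fresh, grad_stale, recent_iterates, alpha, M, xi):
    """Worker-side check in the spirit of (15a): skip the upload when the change in the
    local gradient is small relative to a weighted sum of recent iterate moves.
    The exact constants follow our reconstruction of (15a), so treat them as an
    assumption rather than the paper's verbatim rule."""
    lhs = np.sum((grad_fresh - grad_stale) ** 2)
    diffs = [np.sum((recent_iterates[i] - recent_iterates[i + 1]) ** 2)
             for i in range(len(recent_iterates) - 1)]
    rhs = sum(x * s for x, s in zip(xi, diffs)) / (alpha ** 2 * M ** 2)
    return lhs <= rhs            # True -> worker stays silent this round

# Tiny usage example with made-up numbers.
rng = np.random.default_rng(2)
theta_hist = [rng.standard_normal(3) for _ in range(4)]   # theta^k, theta^{k-1}, ...
skip = lag_wk_trigger(grad_fresh=rng.standard_normal(3),
                      grad_stale=rng.standard_normal(3),
                      recent_iterates=theta_hist, alpha=0.1, M=9, xi=[0.5, 0.3, 0.2])
print("skip upload:", skip)
```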

Algorithm 1 LAG-WK
1:  Input: stepsize α, D, and {ξ_d}.
2:  Initialize: θ^1.
3:  for k = 1, 2, ..., K do
4:      Server broadcasts θ^k to all workers.
5:      for worker m = 1, ..., M do
6:          Worker m computes ∇L_m(θ^k).
7:          Worker m checks condition (15a).
8:          if worker m violates (15a) then
9:              Worker m uploads δ∇_m^k.    ▷ Worker m saves θ̂_m^k = θ^k.
10:         else
11:             Worker m uploads nothing.
12:         end if
13:     end for
14:     Server updates θ^{k+1} via (4).
15: end for

Algorithm 2 LAG-PS
1:  Input: stepsize α, D, {ξ_d}, and the smoothness constants {L_m}.
2:  Initialize: θ^1.
3:  for k = 1, 2, ..., K do
4:      for worker m = 1, ..., M do
5:          Server checks condition (15b).
6:          if worker m violates (15b) then
7:              Server sends θ^k to worker m.    ▷ Server saves θ̂_m^k = θ^k.
8:              Worker m computes ∇L_m(θ^k).
9:              Worker m uploads δ∇_m^k.
10:         else
11:             No actions at server and worker m.
12:         end if
13:     end for
14:     Server updates θ^{k+1} via (4).
15: end for
Table 2: A comparison of LAG-WK and LAG-PS.

LAG-WK vs LAG-PS. To perform (15a), the server needs to broadcast the current model θ^k, and all the workers need to compute their local gradients; to perform (15b), the server needs the estimated smoothness constants L_m of all the local functions. On the other hand, as will be shown in Section 3, (15a) and (15b) lead to the same worst-case convergence guarantees. In practice, however, the server-side condition is more conservative than the worker-side one at communication reduction, because the smoothness of L_m readily implies that satisfying (15b) necessarily satisfies (15a), but not vice versa. Empirically, (15a) will lead to a larger M_c^k than that of (15b), and thus extra communication overhead will be saved. Hence, (15a) and (15b) can be chosen according to the user's preferences. LAG-WK and LAG-PS are summarized as Algorithms 1 and 2.
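For completeness, here is a self-contained simulation of the LAG-WK loop of Algorithm 1 on a synthetic least-squares problem, counting the total number of uploads against GD's M uploads per iteration. The data, trigger weights, and stepsize rule (α = 1/ΣL_m) are our own assumptions; only the control flow mirrors the algorithm.

```python
import numpy as np

# Sketch of Algorithm 1 (LAG-WK) on a toy least-squares problem with heterogeneous
# per-worker smoothness constants. All numerical choices here are assumptions.
rng = np.random.default_rng(3)
M, d, D = 9, 5, 3
xi = [0.3] * D                                          # assumed trigger weights {xi_d}
data = [(rng.standard_normal((30, d)) * (0.2 + 0.2 * m), rng.standard_normal(30))
        for m in range(M)]
grad = lambda th, A, b: A.T @ (A @ th - b)
L_m = [np.linalg.norm(A, 2) ** 2 for A, _ in data]      # local smoothness constants
alpha = 1.0 / sum(L_m)                                  # safe stepsize (L <= sum_m L_m)

theta = np.zeros(d)
stale = [grad(theta, A, b) for A, b in data]            # gradients at the old copies
agg = sum(stale)                                        # stored aggregate nabla^{k-1}
history = [theta.copy()]                                # recent iterates for (15a)
uploads = 0

for k in range(200):
    # Right-hand side of the (reconstructed) trigger condition (15a).
    diffs = [np.sum((history[i] - history[i + 1]) ** 2) for i in range(len(history) - 1)]
    rhs = sum(x * s for x, s in zip(xi, diffs)) / (alpha ** 2 * M ** 2)
    for m, (A, b) in enumerate(data):
        fresh = grad(theta, A, b)                       # worker m computes locally
        if np.sum((fresh - stale[m]) ** 2) > rhs:       # violates (15a) -> upload
            agg += fresh - stale[m]                     # server refines via (4)
            stale[m] = fresh
            uploads += 1
    theta = theta - alpha * agg                         # server update
    history = [theta.copy()] + history[:D]              # keep the last D+1 iterates

print("LAG-WK uploads:", uploads, " vs. GD uploads:", 200 * M)
```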

Regarding our proposed LAG method, two remarks are in order.

R1) With recursive update of the lagged gradients in (4) and the lagged iterates in (15), implementing LAG is as simple as GD; see Table 1. Both empirically and theoretically, we will further demonstrate that using lagged gradients even reduces the overall delay by cutting down costly communication.

R2) Compared with existing efforts for communication-efficient learning, such as quantized gradients, Nesterov's acceleration, dual coordinate ascent, and second-order methods, LAG is complementary to all of them: LAG can be combined with these methods to develop even more powerful learning schemes. An extension to a proximal version of LAG is also possible, to cover nonsmooth regularizers.

3 Iteration and communication complexity

In this section, we establish the convergence of LAG, under the following standard conditions.

Assumption 1: Each loss function L_m is L_m-smooth, and L is L-smooth.

Assumption 2: L is convex and coercive.

Assumption 3: L is μ-strongly convex, or more generally, satisfies the Polyak-Łojasiewicz (PL) condition with constant μ; that is, 2μ(L(θ) − L(θ*)) ≤ ‖∇L(θ)‖².

Note that the PL condition in Assumption 3 is strictly weaker than strong convexity (or even convexity), and it is satisfied by a wider range of machine learning problems, such as least squares for underdetermined linear systems and logistic regression; see the details in (Karimi et al., 2016). While the PL condition is sufficient for the subsequent linear convergence analysis, we will still refer to strong convexity for ease of understanding by a wide audience.
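As a concrete instance of this point (a standard result from Karimi et al. (2016), restated here), least squares satisfies the PL condition even when it is not strongly convex:

```latex
% For the (possibly underdetermined) least-squares objective
%   \mathcal{L}(\theta) = \tfrac{1}{2}\|A\theta - b\|^2,
% the PL condition of Assumption 3 holds with \mu = \lambda_{\min}^{+}(A^\top A),
% the smallest nonzero eigenvalue of A^\top A, even if A is rank deficient:
\[
  \tfrac{1}{2}\,\|\nabla \mathcal{L}(\theta)\|^2
  \;=\; \tfrac{1}{2}\,\|A^\top(A\theta - b)\|^2
  \;\ge\; \lambda_{\min}^{+}(A^\top A)\,\big(\mathcal{L}(\theta) - \mathcal{L}(\theta^\star)\big).
\]
```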

The subsequent analysis critically builds on the following Lyapunov function:

(16)    \mathbb{V}^k := \mathcal{L}(\theta^k) - \mathcal{L}(\theta^\star) + \sum_{d=1}^{D} \beta_d \big\|\theta^{k+1-d} - \theta^{k-d}\big\|^2

where θ* is the minimizer of (1), and {β_d}_{d=1}^D are constants that will be determined later.

We will start with the sufficient descent of our Lyapunov function in (16).

Lemma 3 (descent lemma). Under Assumption 1, if α and {ξ_d} are chosen properly, then there exist constants c_0, c_1, ..., c_D ≥ 0 such that the Lyapunov function in (16) satisfies

(17)    \mathbb{V}^{k+1} - \mathbb{V}^k \le -c_0 \big\|\nabla \mathcal{L}(\theta^k)\big\|^2 - \sum_{d=1}^{D} c_d \big\|\theta^{k+1-d} - \theta^{k-d}\big\|^2

which implies descent in our Lyapunov function, that is, V^{k+1} ≤ V^k. Lemma 3 is a generalization of GD's descent lemma. As specified in the supplementary material, under properly chosen {ξ_d}, any stepsize up to and including α = 1/L guarantees (17), matching the stepsize region of GD. With β_d = 0 for all d in (16), Lemma 3 reduces to Lemma 2.

3.1 Convergence in strongly convex case

We first present the convergence under the smooth and strongly convex condition.

Theorem 1 (strongly convex case). Under Assumptions 1 and 3, the iterates {θ^k} generated by LAG-WK or LAG-PS satisfy

(18)    \mathcal{L}(\theta^k) - \mathcal{L}(\theta^\star) \le \sigma^k\, \mathbb{V}^0

where θ* is the minimizer of L in (1), and σ ∈ (0, 1) is a constant that depends on α and {ξ_d} as well as the condition number L/μ; it is specified in the supplementary material.

Iteration complexity. The iteration complexity in its generic form is complicated, since σ depends on the choice of several parameters. Specifically, if we choose the parameters α, D, and {ξ_d} as follows

(19)

then, following Theorem 1, the iteration complexity of LAG in this case is

(20)

The iteration complexity in (20) is of the same order as GD's iteration complexity O(κ log(1/ε)), where κ := L/μ is the condition number, but has a worse constant. This is the consequence of using a smaller stepsize in (19) (relative to α = 1/L in GD) to simplify the choice of other parameters. Empirically, LAG with α = 1/L can achieve almost the same iteration complexity as GD; see Section 4. Building on the iteration complexity, we next study the communication complexity of LAG. In the setting of our interest, we define the communication complexity as the total number of uploads over all the workers needed to achieve accuracy ε. While the accuracy refers to the objective optimality error L(θ^k) − L(θ*) in the strongly convex case, it refers to the gradient norm ‖∇L(θ^k)‖² in the general (non)convex cases.

Figure 2: Communication events of the workers over the iterations; each stick is an upload.

The power of LAG is best illustrated by numerical examples; see an example of LAG-WK in Figure 2. Clearly, workers with a small smoothness constant communicate with the server less frequently. This intuition is formally treated in the next lemma.

Lemma 4 (lazy communication). Define the importance factor of worker m as γ_m := L_m / L. If the stepsize α and the constants {ξ_d} in (15) are chosen properly, and worker m satisfies

(21)

for some integer d ≥ 1, then, until iteration k, worker m communicates with the server at most ⌈k/(d+1)⌉ rounds.

Lemma 4 asserts that if worker m has a small γ_m (a close-to-linear loss function) such that (21) holds, then under LAG it communicates with the server at most ⌈k/(d+1)⌉ rounds. This is in contrast to the total of k communication rounds per worker under GD. Ideally, we want as many workers satisfying (21) as possible, especially when the number of iterations k is large.

To quantify the overall communication reduction, we will rely on what we term the heterogeneity score function, given by

(22)    H(\gamma) := \frac{1}{M} \sum_{m=1}^{M} \mathbb{1}(\gamma_m \le \gamma)

where the indicator 1(γ_m ≤ γ) equals 1 when γ_m ≤ γ holds, and 0 otherwise. Clearly, H(γ) is a nondecreasing function of γ that depends on the distribution of the smoothness constants L_1, ..., L_M. It is also instructive to view it as the cumulative distribution function of the deterministic quantity γ_m, implying H(γ) ∈ [0, 1]. Putting it in our context, H evaluated at the threshold in (21) lower bounds the fraction of workers that communicate with the server at most ⌈k/(d+1)⌉ rounds until the k-th iteration.
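A small sketch of how H(γ) could be evaluated in practice follows, assuming the importance factor γ_m = L_m / L as in our reading of the text above; using L = ΣL_m as a stand-in for the smoothness of the sum is our own simplification.

```python
import numpy as np

def heterogeneity_score(L_m, L, gamma):
    """Fraction of workers whose importance factor gamma_m = L_m / L is at most gamma,
    i.e., an empirical-CDF reading of the heterogeneity score function (22)."""
    gamma_m = np.asarray(L_m, dtype=float) / L
    return float(np.mean(gamma_m <= gamma))

# Example: eight "easy" workers and one "hard" worker with a large smoothness constant.
L_m = [0.1] * 8 + [10.0]
L = sum(L_m)                      # conservative upper bound on the smoothness of the sum
for g in (0.01, 0.1, 1.0):
    print(f"H({g}) = {heterogeneity_score(L_m, L, g):.2f}")
```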

We are now ready to present the communication complexity.

Proposition 1 (communication complexity). Under the same conditions as those in Theorem 1, with the threshold defined in (21) and the function H defined in (22), the communication complexity of LAG, denoted C_LAG(ε), is bounded by

(23)

where the constant in (23) captures what we call the fraction of reduced communication per iteration (see the discussion below).

The communication complexity in (23) crucially depends on the iteration complexity as well as on what we call the fraction of reduced communication per iteration. Simply choosing the parameters as in (19), it follows from (20) and (23) that (cf. κ := L/μ)

(24)

where GD's communication complexity is C_GD(ε) = O(Mκ log(1/ε)). In (24), due to the nondecreasing property of H(γ), increasing the trigger constants in (15) yields a smaller fraction of workers that communicate per iteration, yet a larger number of iterations (cf. (20)). The key enabler of LAG's communication reduction is a heterogeneous environment with a favorable H(γ), ensuring that the benefit of increasing the trigger constants outweighs their effect on the iteration complexity. More precisely, if the heterogeneity function H guarantees that a sufficiently large fraction of workers satisfy (21), then we have C_LAG(ε) < C_GD(ε). Intuitively speaking, if there is a large fraction of workers with small γ_m, LAG has lower communication complexity than GD. An example follows to illustrate this reduction.

Example. Consider the case in which all but one of the workers have small smoothness constants, so that these M − 1 workers satisfy (21). Choosing D and {ξ_d} in (19) accordingly, we have (cf. (24))

(25)

that is, LAG's communication complexity is only a small fraction of GD's.

Due to technical issues in the convergence analysis, the current condition on H(γ) that ensures LAG's communication reduction is relatively restrictive. Establishing communication reduction in a broader learning setting that matches LAG's intriguing empirical performance is part of our research agenda.

3.2 Convergence in (non)convex case

LAG's convergence and communication reduction guarantees go beyond the strongly convex case. We next establish the convergence of LAG for general convex functions.

Theorem 2 (convex case). Under Assumptions 1 and 2, if α and {ξ_d} are chosen properly, then the iterates {θ^k} generated by LAG-WK or LAG-PS satisfy

(26)    \mathcal{L}(\theta^k) - \mathcal{L}(\theta^\star) = \mathcal{O}(1/k).

For nonconvex objective functions, LAG guarantees the following convergence result.

Theorem 3 (nonconvex case). Under Assumption 1, if α and {ξ_d} are chosen properly, then the iterates {θ^k} generated by LAG-WK or LAG-PS satisfy

(27)    \min_{k' \le k}\ \big\|\nabla \mathcal{L}(\theta^{k'})\big\|^2 = \mathcal{O}(1/k).

Theorems 2 and 3 assert that, with the judiciously designed lazy gradient aggregation rules, LAG achieves the same order of convergence rate as GD for general convex and nonconvex smooth objective functions. Furthermore, we next show that in these general cases, LAG still requires fewer communication rounds than GD under certain conditions on the heterogeneity score function H(γ).

In the general smooth (possibly nonconvex) case, however, we define the communication complexity in terms of achieving an ε-gradient error, e.g., min_{k'≤k} ‖∇L(θ^{k'})‖² ≤ ε. Similar to Proposition 1, we present the communication complexity as follows.

Proposition 2 (communication complexity). Under Assumption 1, with the quantities defined as in Proposition 1, the communication complexity of LAG, denoted C_LAG(ε), is bounded by

(28)

where C_GD(ε) is the communication complexity of GD. Choosing the parameters as in (19), if the heterogeneity score function H is such that a sufficiently large fraction of workers satisfy (21), then we have that

(29)    C_{\rm LAG}(\epsilon) < C_{\rm GD}(\epsilon).

Together with Proposition 1, we have shown that for strongly convex, convex, and nonconvex smooth objective functions, LAG enjoys provably lower communication overhead than GD in certain heterogeneous learning settings. In fact, LAG's empirical performance gain over GD goes far beyond this worst-case theoretical analysis and extends to a much broader range of distributed learning settings, as confirmed by the numerical tests that follow.

4 Numerical tests

To validate the theoretical results, this section evaluates the empirical performance of LAG on linear and logistic regression tasks. All experiments were performed using MATLAB on a desktop with an Intel CPU @ 3.4 GHz and 32 GB of RAM. By default, we consider one server and nine workers. Throughout the tests, we use the optimality error in the objective as the figure of merit of our solution. To benchmark LAG, we consider the following approaches.

Cyc-IAG is the cyclic version of the incremental aggregated gradient (IAG) method (Blatt et al., 2007; Gurbuzbalaban et al., 2017) that resembles the recursion (4), but communicates with one worker per iteration in a cyclic fashion.

Num-IAG also resembles the recursion (4), but it randomly selects one worker per iteration to obtain a fresh gradient, with the probability of choosing worker m proportional to its smoothness constant L_m (see the selection sketch after this list).

Batch-GD is the GD iteration (2) that communicates with all the workers per iteration.
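For clarity, the worker-selection rules of the two IAG baselines can be sketched as follows; the proportional-to-L_m sampling for Num-IAG reflects our reconstruction of the elided probability.

```python
import numpy as np

# Worker-selection rules for the IAG baselines above (illustrative sketch).
rng = np.random.default_rng(4)
M = 9
L_m = np.array([0.1] * 8 + [10.0])           # assumed per-worker smoothness constants

def cyc_iag_worker(k):
    """Cyc-IAG: cycle through the workers, one fresh gradient per iteration."""
    return k % M

def num_iag_worker():
    """Num-IAG: sample one worker with probability proportional to L_m."""
    return rng.choice(M, p=L_m / L_m.sum())

print([cyc_iag_worker(k) for k in range(5)], [num_iag_worker() for _ in range(5)])
```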

For LAG-WK, we choose the trigger constants D and {ξ_d} in (15a) conservatively, and for LAG-PS we choose a more aggressive setting in (15b). The same stepsize α is used for LAG-WK, LAG-PS, and GD; to optimize performance and guarantee stability, smaller stepsizes are used for Cyc-IAG and Num-IAG. For the linear regression task, no regularization is added; for the logistic regression task, the ℓ2-regularization parameter is set to a fixed small value.

Figure 3: Iteration and communication complexity in synthetic datasets with increasing .
Figure 4: Iteration and communication complexity in synthetic datasets with uniform .
Figure 5: Iteration and communication complexity for linear regression in real datasets.
Figure 6: Iteration and communication complexity for logistic regression in real datasets.

We consider two synthetic data tests: a) linear regression with smoothness constants that increase across workers; and b) logistic regression with uniform smoothness constants across workers. For each worker, we generate 50 samples from the standard Gaussian distribution, and rescale the data to mimic the increasing and uniform smoothness constants (see the data-generation sketch below). For the case of increasing L_m, it is not surprising that both LAG variants need fewer communication rounds; see Figure 3. Interestingly enough, for uniform L_m, LAG-WK still achieves marked improvements in communication, thanks to its ability to exploit the hidden smoothness of the loss functions; that is, the local curvature of L_m may not be as steep as the global smoothness constant suggests; see Figure 4.
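The rescaling step can be realized, for example, as in the sketch below, where each worker's design matrix is scaled so that its local least-squares smoothness constant hits a prescribed value; the exact constants used in the paper's experiments are not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(5)

def make_worker_data(target_L, n=50, d=10):
    """Generate n Gaussian samples for one worker and rescale them so that the local
    least-squares loss 0.5*||A theta - b||^2 has a prescribed smoothness constant,
    i.e., the largest eigenvalue of A^T A equals target_L (our own illustration of
    the 'rescale the data' step described above)."""
    A = rng.standard_normal((n, d))
    A *= np.sqrt(target_L) / np.linalg.norm(A, 2)     # now ||A||_2^2 == target_L
    b = rng.standard_normal(n)
    return A, b

# Increasing smoothness constants across 9 workers (e.g., 1, 2, ..., 9) vs. uniform ones.
increasing = [make_worker_data(m + 1.0) for m in range(9)]
uniform = [make_worker_data(1.0) for _ in range(9)]
print([round(np.linalg.norm(A, 2) ** 2, 2) for A, _ in increasing])
```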

Table 3: A summary of the real datasets used in the linear regression tests: Housing (workers 1-3), Body fat (workers 4-6), and Abalone (workers 7-9).
Table 4: A summary of the real datasets used in the logistic regression tests: Ionosphere (workers 1-3), Adult (workers 4-6), and Derm (workers 7-9).

Performance is also tested on real datasets (Lichman, 2013): a) linear regression using the Housing, Body fat, and Abalone datasets; and b) logistic regression using the Ionosphere, Adult, and Derm datasets; see Figures 5-6. Each dataset is evenly split across three workers, with the number of features used in the test equal to the minimal number of features among all datasets; see the summaries of the datasets in Tables 3 and 4, with further details deferred to Appendix I. In all tests, LAG-WK outperforms the alternatives in terms of both metrics, reducing the needed communication rounds by several orders of magnitude. Its number of communication rounds can even be smaller than the number of iterations, since at certain iterations no worker violates the trigger condition (15). Additional tests on real datasets under different numbers of workers are listed in Table 5. Under all the tested settings, LAG-WK consistently achieves the lowest communication complexity, which corroborates the effectiveness of LAG when it comes to communication reduction.

Table 5: Communication complexity of Cyc-IAG, Num-IAG, LAG-PS, LAG-WK, and batch GD needed to achieve a given accuracy, for linear and logistic regression under different numbers of workers.
Figure 7: Iteration and communication complexity in Gisette dataset.

A similar performance gain has also been observed in a test on the larger Gisette dataset, which was constructed from the MNIST data (LeCun et al., 1998). After randomly selecting a subset of samples and eliminating all-zero features, we randomly split the resulting dataset across nine workers. The performance of all the algorithms is reported in Figure 7 in terms of iteration and communication complexity. Clearly, LAG-WK and LAG-PS achieve the same iteration complexity as GD, and outperform Cyc-IAG and Num-IAG. Regarding communication complexity, the two LAG variants reduce the needed communication rounds by several orders of magnitude compared with the alternatives.

5 Conclusions

This paper developed a promising communication-cognizant method for distributed machine learning that we term the Lazily Aggregated Gradient (LAG) approach, whose merits are confirmed by impressive empirical performance on both synthetic and real datasets. LAG achieves the same convergence rates as batch gradient descent (GD) in the smooth strongly convex, convex, and nonconvex cases, and requires fewer communication rounds than GD when the datasets at different workers are heterogeneous. To overcome the limitations of LAG, our future work consists of incorporating smoothing techniques to handle nonsmooth loss functions, and robustifying our aggregation rules to deal with cyber attacks.

Appendix A Proof of Lemma 2

Using the smoothness of L in Assumption 1, we have that

(30)    \mathcal{L}(\theta^{k+1}) \le \mathcal{L}(\theta^k) + \big\langle \nabla \mathcal{L}(\theta^k),\, \theta^{k+1} - \theta^k \big\rangle + \frac{L}{2}\big\|\theta^{k+1} - \theta^k\big\|^2.

Plugging (6b) into the inner product in (30) leads to (cf. δ∇_m^k in (4))

(31)

Using θ^{k+1} − θ^k = −α∇^k, we can rewrite the inner product in (31) as

(32)

where (a) follows from the LAG update (6).

Combining (31) and (32), and plugging them into (30), the claim of Lemma 2 follows.

Appendix B Proof of Lemma 3

Using the definition of V^k in (16), it follows that

(33)

where (a) uses (8) in Lemma 2.

Decomposing the square distance as