Never Go Full Batch (in Stochastic Convex Optimization)

06/29/2021 ∙ by Idan Amir, et al. ∙ Tel Aviv University

We study the generalization performance of full-batch optimization algorithms for stochastic convex optimization: these are first-order methods that only access the exact gradient of the empirical risk (rather than gradients with respect to individual data points); this class includes a wide range of algorithms such as gradient descent, mirror descent, and their regularized and/or accelerated variants. We provide a new separation result showing that, while algorithms such as stochastic gradient descent can generalize and optimize the population risk to within ϵ after O(1/ϵ^2) iterations, full-batch methods either need at least Ω(1/ϵ^4) iterations or exhibit a dimension-dependent sample complexity.


1 Introduction

Stochastic Convex Optimization (SCO) is a fundamental problem that has received considerable attention from the machine learning community in recent years [26, 14, 3, 10, 1]. In this problem, we assume a learner that is provided with a finite sample of convex functions drawn i.i.d. from an unknown distribution. The learner's goal is to minimize the expected function. Owing to its simplicity, SCO serves as an almost ideal theoretical model for studying the generalization properties of optimization algorithms ubiquitous in practice, particularly first-order methods, which utilize only first derivatives of the loss rather than higher-order ones.

One prominent approach for SCO—and learning more broadly—is to consider the empirical risk (the average objective over the sample) and apply a first-order optimization algorithm to minimize it. The problem of learning is then decoupled into controlling the optimization error over the empirical risk (training error) and bounding the difference between the empirical error and the expected error (generalization error).

In convex optimization, the convergence of different first-order methods has been researched extensively for many years (e.g., [25, 23, 4]), and we currently have a very good understanding of this setting in terms of both upper and lower bounds on worst-case complexity. However, in SCO, where the generalization error must also be taken into account, our understanding is still lacking. In fact, this is one of the few theoretical learning models where the optimization method affects not only the optimization error but also the generalization error (in contrast with models such as PAC learning and generalized linear models). In particular, it has been shown [26, 14] that some minima of the empirical risk may obtain large generalization error, while other minima have a vanishingly small generalization error. Put differently, learning in SCO is not only a question of minimizing the empirical risk, but also of how one minimizes it. However, the results of [26, 14] leave open the question of whether concrete optimization algorithms also have different generalization properties.

Towards a better understanding, Amir et al. [1] recently studied the generalization properties of full-batch gradient descent (GD), where each step is taken with respect to the gradient of the empirical risk. For GD (and a regularized variant thereof), they gave a lower bound on the generalization error as a function of the iteration number, which is strictly larger than the well-known optimal rate obtained by stochastic gradient descent (SGD), where each step is taken with respect to the gradient at a single sampled example. Notably, the lower bound of [1] precisely matches the dimension-independent stability-based upper bound recently shown for full-batch GD by Bassily et al. [3]. The separation between full-batch GD and SGD is the first evidence that not only abstract Empirical Risk Minimizers may fail to generalize in SCO, but that basic methods such as GD could also be prone to such overfitting. A natural question is, then, whether overfitting is inherent to full-batch algorithms that minimize the objective only through access to the exact empirical risk, or whether this suboptimality can be remedied by adding regularization, noise, smoothing, or any other mechanism for improving the generalization of GD.

In this work we present and analyze a model of full-batch optimization algorithms for SCO. Namely, we focus on algorithms that access the empirical risk only via a first-order oracle that computes the exact (full-batch) gradient of the empirical loss, rather than directly accessing gradients with respect to individual samples. Our main result provides a negative answer to the question above by significantly generalizing and extending the result of Amir et al. [1]: we show that any optimization method that uses full-batch gradients needs at least Ω(1/ϵ^4) iterations to minimize the expected loss to within ϵ error. This is in contrast with the empirical loss, which can be minimized with only O(1/ϵ^2) steps.

Comparing SGD and GD in terms of the sample size n, we see that SGD converges to an optimal generalization error of O(1/√n) after n iterations, whereas a full-batch method must perform Ω(n^2) iterations to achieve the same test error. We emphasize that we account here for the oracle complexity, which coincides with the iteration complexity in the case of gradient methods. In terms of individual gradient calculations, while SGD uses at most n gradient calculations (one sample per iteration), a full-batch method will perform Ω(n^3) calculations (n samples per iteration).
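To make this accounting concrete, here is a minimal sketch (ours, not from the paper) contrasting the two access models; subgrad_f stands for a hypothetical per-example sub-gradient routine, and both methods project onto the unit ball:

    import numpy as np

    def project_unit_ball(w):
        # Euclidean projection onto W = {w : ||w|| <= 1}
        return w / max(1.0, np.linalg.norm(w))

    def sgd(subgrad_f, z_sample, w0, lr, T):
        # One fresh example per step: T steps cost T individual gradient evaluations.
        w = w0.copy()
        for t in range(T):
            z = z_sample[t % len(z_sample)]  # a single sample per iteration
            w = project_unit_ball(w - lr * subgrad_f(w, z))
        return w

    def full_batch_gd(subgrad_f, z_sample, w0, lr, T):
        # Each step averages gradients over all n samples: T steps cost T * n evaluations.
        w = w0.copy()
        n = len(z_sample)
        for _ in range(T):
            g = sum(subgrad_f(w, z) for z in z_sample) / n  # one full-batch oracle call
            w = project_unit_ball(w - lr * g)
        return w

With a sample of size n, running sgd for T = n steps touches n individual gradients, while matching its test error with full_batch_gd takes T = Ω(n^2) steps and hence Ω(n^3) individual gradient evaluations, as discussed above.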

The above result is applicable to a wide family of full-batch learning algorithms: regularized GD (with any data-independent regularization function), noisy GD, GD with line-search or adaptive step sizes, GD with momentum, proximal methods, coordinate methods, and many more. Taken together with the upper bound of Bassily et al. [3], we obtain a sharp rate of Θ(1/ϵ^4) for the generalization-complexity of full-batch methods. Surprisingly, this rate is achieved by standard GD (with an unusual step-size choice of order η ≈ 1/T^{3/4}), and it cannot be improved by adding regularization of any sort, nor by adding noise or any other form of implicit/explicit bias.

1.1 Related work

This work extends and generalizes the results of Amir et al. [1], who proved generalization lower bounds for GD (and a specific instance of regularized GD). Our work shows that in fact any full-batch method will suffer from similar lower bounds. Our construction builds upon the one used in [1], which in turn builds upon previous constructions [3, 26]. However, our arguments and proofs here are more challenging, as we need to reason about a general family of algorithms, and not about a specific algorithm whose trajectory can be analyzed directly. Our developments also build on ideas from the literature on oracle complexity lower bounds in optimization [23, 25, 28, 7, 11, 8]. In particular, we first prove our result in the simplified setting of algorithms constrained to the span of observed gradients [23, 25] and subsequently lift it to general algorithms using a random high-dimensional embedding technique proposed by Woodworth and Srebro [28] and later refined in [7, 11]. However, while these works lower bound what we call the empirical risk, we lower bound the generalization error. This requires us to develop a somewhat different argument for how the span of the gradients evolves during the optimization: in prior work, the algorithm learns the components of the solution coordinate by coordinate, whereas in our work the true (generalizing) solution is present in the observed gradients from the first query, but spurious sampling artifacts drown it out.

Empirical studies (outside the scope of SCO) support the claim that generalization capabilities degrade as the batch size increases. Specifically, Zhu et al. [31] indicate that SGD outperforms GD in terms of generalization. The works of Keskar et al. [20] and Hoffer et al. [18] exhibit a similar phenomenon, in which small-batch SGD generalizes better than large-batch SGD with the same iteration budget. We provide the first theoretical evidence for this phenomenon for convex losses. Several theoretical studies explore the convergence of stochastic methods that use mini-batches [9, 21, 29]. Note that this setting differs from ours, as they assume access to mini-batches sampled without replacement, whereas full-batch means we reuse the same (full) batch with each gradient step. The work of Wu et al. [30] also explores the separation between GD and SGD, and interprets mini-batch SGD as a noisy version of GD. They propose a modified algorithm with noise injected into the full-batch gradients. Interestingly, the noise production requires access to the individual sample points. Our work shows that (in SCO) this is unavoidable: namely, no data-independent noise can be used to improve generalization.

Several other works study the generalization performance of GD [27, 15, 19, 22]. The work of Soudry et al. [27], for example, examines GD on unregularized logistic regression problems. They show that, in the limit, GD converges to a well-generalizing solution by arguing about the bias of the algorithm. Interestingly, both our and their results require slow training, beyond what is required for empirical error optimization. Another work that highlights the slow convergence of GD is that of Bassily et al. [3]. They were the first to address uniform stability of (non-smooth) GD and SGD, and provided tight bounds. Stability entails generalization, hence our results lead to stability lower bounds for any full-batch method. Consequently, we extend the lower bounds for GD in the work of Bassily et al. [3] to a wider class. It might be thought that the instability argument of Bassily et al. [3] can be used to obtain similar generalization lower bounds—however, we note that their techniques also prove instability of SGD (which does generalize). Hence, instability does not immediately imply, in this setting, lack of generalization.

Finally, we note that under smoothness and strong convexity, it is well known that improved rates can be obtained. Specifically, using the stability bound of Bousquet and Elisseeff [5], one can show that we can achieve a generalization error of O(1/(λn)) (after sufficiently many iterations) if the population risk is λ-strongly convex. The arguments of Hardt et al. [17] imply generalization bounds for instances where every sample risk is smooth. Our result implies that, even though these special families of functions enjoy appealing learning rates, in general it is impossible to obtain better rates by strong-convexifying or smoothing problem instances via first-order full-batch oracle queries.

2 Problem Setup and Main Results

We study the standard setting of stochastic convex optimization. In this setting, a learning problem is specified by a fixed domain W in d-dimensional Euclidean space, and a loss function f : W × Z → R, which is both convex and L-Lipschitz with respect to its first argument (that is, for any z ∈ Z the function f(·; z) is L-Lipschitz and convex with respect to w). In particular, throughout the paper our construction consists of 1-Lipschitz functions, and we will focus on a fixed domain W defined to be the unit Euclidean ball in R^d, namely W = {w ∈ R^d : ‖w‖ ≤ 1}.

We also assume that there exists an unknown distribution D over parameters z ∈ Z, and the goal of the learner is to minimize the true risk (or true loss, or population risk) defined as follows:

F(w) = E_{z∼D}[ f(w; z) ].   (1)

We assume that a sample S = (z_1, …, z_n) is drawn i.i.d. from the distribution D, and the learner has to output a solution w_S (the exact access the learner has to the sample, and how w_S may depend on S, is discussed below). We require the solution w_S to be ϵ-optimal in expectation for some parameter ϵ > 0, i.e.,

E[F(w_S)] − min_{w∈W} F(w) ≤ ϵ.

As discussed, the standard setting assumes that the learner has direct access to the i.i.d. sample, as well as to the gradients of the loss function (i.e., a first-order oracle). In this work, though, we focus on a specific family of full-batch methods. Hence, the optimization process is described as follows: first, an i.i.d. sample S is drawn from D; then, the learner is provided with access only to the empirical risk, via a full-batch first-order oracle, which we define next.

Full-batch first-order oracle.

Consider a fixed sample S = (z_1, …, z_n) of size n, drawn i.i.d. from D. The empirical risk over the sample S is

F_S(w) = (1/n) ∑_{i=1}^n f(w; z_i).

Then, a full-batch first-order oracle is a procedure that, given input w ∈ W, outputs

( F_S(w), ∇F_S(w) ),

where ∇F_S(w) is an empirical risk sub-gradient of the form

∇F_S(w) = (1/n) ∑_{i=1}^n ∇f(w; z_i),   (2)

and each sub-gradient ∇f(w; z_i) ∈ ∂f(w; z_i) is computed by the oracle as a function of w and z_i (that is, independently of z_j for j ≠ i).

We emphasize that the sample S is fixed throughout the optimization, so that the oracle computes the gradient of the same empirical risk function F_S at every call, hence the name full-batch. Note that the sub-gradient with respect to a single data point, i.e., ∇f(w; z_i), is not accessible through this oracle, which only returns the average gradient over the sample S.
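The oracle interface can be sketched as follows (a minimal illustration of Eq. 2, not an implementation from the paper; loss and subgrad are hypothetical per-example helpers, where subgrad(w, z) depends only on (w, z), matching the independence requirement above):

    import numpy as np

    class FullBatchOracle:
        # The sample S is fixed at construction; every call returns the empirical
        # risk and its averaged sub-gradient, never exposing per-example gradients.
        def __init__(self, loss, subgrad, z_sample):
            self.loss, self.subgrad, self.S = loss, subgrad, list(z_sample)

        def __call__(self, w):
            n = len(self.S)
            value = sum(self.loss(w, z) for z in self.S) / n     # F_S(w)
            grad = sum(self.subgrad(w, z) for z in self.S) / n   # Eq. (2)
            return value, grad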

Notice that our definition above is slightly narrower than a general sub-gradient oracle for the empirical risk, due to the requirement that the sub-gradients ∇f(w; z_i) are chosen independently of z_j for j ≠ i. Since we prove a lower bound here, this restriction only strengthens our result. We make this restriction to avoid certain degenerate constructions (which, in fact, can even be used to fail SGD if the sub-gradient at z_i may depend on the whole sample), and which have no practical implications.

Full-batch first-order algorithm.

A full-batch (first-order) method is naturally defined as any algorithm that has access to the optimization objective—namely the empirical risk F_S—only via the full-batch first-order oracle. In particular, if w_t is the t'th query of the algorithm to the full-batch oracle, then w_t has to be of the form

w_t = Q_t( F_S(w_1), ∇F_S(w_1), …, F_S(w_{t−1}), ∇F_S(w_{t−1}) ),   (3)

where Q_t is a fixed (possibly randomized) mapping. At the end of the process the algorithm outputs w_S. We study the algorithm's oracle complexity, which is the number of iterations T the algorithm performs before halting. Therefore, we assume without loss of generality that w_S = w_{T+1}, i.e., the algorithm's output is its (T+1)'th query.
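The interaction protocol of Eq. 3 then amounts to the following loop (a schematic sketch; query_maps is a hypothetical list of the mappings Q_t, with any internal randomness folded into them):

    def run_full_batch_algorithm(oracle, query_maps, w1, T):
        # The t'th query may depend only on past oracle replies: one oracle
        # call per iteration, T calls total; the output is the final query.
        history, w = [], w1
        for t in range(T):
            history.append(oracle(w))   # (value, sub-gradient) of the empirical risk
            w = query_maps[t](history)  # next query, as in Eq. (3)
        return w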

2.1 Main result

In this section we establish our main result, which provides a generalization lower bound for full-batch first-order algorithms. The complete proof is provided in Section 5.

Theorem 1.

Let n, T ≥ 1; there exists d = d(n, T) such that the following holds. For any full-batch first-order algorithm with oracle complexity at most T, there exists a 1-Lipschitz convex function f in W, the unit ball in R^d, and a distribution D over Z such that, for some universal constant c > 0:

E[F(w_S)] − min_{w∈W} F(w) ≥ c · min{ 1/T^{1/4}, 1/√n }.   (4)

An immediate consequence of Theorem 1 is that in order to obtain less than ϵ excess true risk, we need at least Ω(1/ϵ^4) iterations.

For simplicity, we state and prove the lower bound in Theorem 1 for the class of first-order full-batch algorithms defined above. However, our constructions readily generalize to local full-batch oracles that provide a complete description of F_S in an arbitrarily small neighborhood of the query point [23, 16]. Such oracles subsume second-order oracles, and consequently our generalization lower bounds hold also for second-order full-batch algorithms.

2.2 Discussion

Theorem 1 suggests that full-batch first-order algorithms are inferior to other types of first-order algorithms that operate with access to individual examples, such as SGD. Importantly, this separation is achieved not in terms of the optimization performance but in terms of the generalization performance. In light of this result, we next discuss and revisit the role of the optimization algorithm in the context of SCO. In particular, we wish to discuss the implications for what are perhaps the two most prominent full-batch optimization methods, GD and regularized GD, and in turn compare them.

Gradient descent.

Perhaps the simplest example of a full-batch method is (projected) GD: an iterative algorithm that at each iteration performs the update step

w_{t+1} = Π_W[ w_t − η ∇F_S(w_t) ],

where Π_W is the Euclidean projection onto the convex set W. The output of GD is normally taken to be the average iterate (or a randomly chosen iterate w_t). Notice that each step requires one call to a full-batch oracle and a single projection operation. The convergence of GD to the optimal solution of the empirical risk has been widely studied. Specifically, it is known that with step size η = Θ(1/√T), after T iterations GD converges to an O(1/√T)-suboptimal minimizer of the empirical risk. For the exact variant of GD depicted above, the generalization performance was analyzed in the work of Amir et al. [1], which showed that after T steps (and for any choice of step size), GD may suffer a generalization error of Ω(1/T^{1/4}). Theorem 1 extends the above result to any variant of GD (dynamic learning rate, noisy GD, normalized GD, etc.).
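As a concrete instance of the oracle model, here is a sketch of projected GD (assuming the FullBatchOracle interface sketched above; averaging the iterates is one standard output choice):

    import numpy as np

    def projected_gd(oracle, w1, lr, T):
        # One full-batch oracle call and one projection per step.
        w = w1.copy()
        iterates = [w.copy()]
        for _ in range(T):
            _, g = oracle(w)                     # full-batch sub-gradient
            w = w - lr * g
            w = w / max(1.0, np.linalg.norm(w))  # projection onto the unit ball
            iterates.append(w.copy())
        return np.mean(iterates, axis=0)         # average iterate

Setting lr on the order of 1/√T targets the empirical risk, while the slower tuning lr ≈ 1/T^{3/4} discussed above targets the generalization error.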

Regularized gradient descent.

We would also like to discuss the implications of Theorem 1 for regularized variants of GD that operate on the regularized empirical risk

F_λ(w) = F_S(w) + λ R(w).

The main motivation for introducing the regularization term R is to avoid overfitting, and a popular choice for R is the squared Euclidean norm R(w) = ‖w‖^2. This choice leads to the following update rule for GD:

w_{t+1} = Π_W[ w_t − η( ∇F_S(w_t) + 2λ w_t ) ].

Again, this update can be implemented using a single first-order full-batch oracle call that computes the quantity ∇F_S(w_t). More generally, for any data-independent R, GD on F_λ is a full-batch algorithm (note that we are not concerned with the computational cost of computing ∇R, since it does not factor into oracle complexity). When R is the squared Euclidean norm, the minimizer of F_λ is known to enjoy (with the choice λ = Θ(1/√n)) an optimal generalization error of O(1/√n) [5, 26]. This demonstrates the power of regularization and how it can provably induce generalization. Nevertheless, Theorem 1 still applies to any optimization method over F_λ. Since optimization of F_λ (the regularized empirical risk) to ϵ-precision can be done via a full-batch method with far fewer than Ω(1/ϵ^4) oracle calls, we observe that there are methods that minimize the regularized empirical risk but, due to Theorem 1, do not reach the optimal generalization error.
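In code, the regularized update differs from plain GD only by a data-independent term, so it fits the same oracle interface (a sketch under the same assumptions as above, with R(w) = ‖w‖^2):

    import numpy as np

    def regularized_gd(oracle, w1, lr, lam, T):
        # Still a full-batch algorithm: the data enters only through the oracle
        # call; the regularizer's gradient 2 * lam * w is data-independent.
        w = w1.copy()
        for _ in range(T):
            _, g = oracle(w)                     # single full-batch oracle call
            w = w - lr * (g + 2.0 * lam * w)     # sub-gradient of the regularized risk
            w = w / max(1.0, np.linalg.norm(w))  # projection onto the unit ball
        return w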

The role of regularization.

Finally, in light of Theorem 1, let us compare the different variants of GD and regularized GD that do generalize well, in order to sharpen our understanding of the role of regularization in generalization. The conclusion of Theorem 1 is that any full-batch method that generalizes well performs at least Ω(1/ϵ^4) steps. For regularized GD, with Euclidean-norm regularization, O(1/ϵ^4) iterations are indeed sufficient: with this many iterations we can find a solution with sufficiently small empirical (regularized) error, and any such solution enjoys a generalization error of O(ϵ) [26]. For GD, Bassily et al. [3] showed that O(1/ϵ^4) iterations also suffice to achieve O(ϵ) error. This is achieved by tuning the learning rate to η = Θ(1/T^{3/4}). Notice that this improvement does not require any added regularization.

To summarize, both GD and regularized GD with optimal parameters require Θ(1/ϵ^4) iterations to attain the optimal generalization error. Overall, then, explicitly adding regularization is neither necessary nor does it improve the convergence rate. One might be tempted to believe that tuning the learning rate in GD implicitly induces some sort of regularization. For example, one might imagine that GD is biased towards the minimal-norm solution, which would explain why regularizing by this norm is redundant. However, this too turns out to be false: Dauber et al. [10] showed how GD (with any reasonable choice of learning rate) can diverge from the minimal-norm solution. In fact, for any regularization term R, one can find instances where GD does not converge to the regularized solution. Thus, even though GD and regularized GD are comparable algorithms in terms of generalization and oracle complexity, they are distinct in terms of the solutions they select.

3 Technical Overview

In this section we give an overview of our construction and approach towards proving Theorem 1. For the sake of exposition, we describe here a slightly simpler construction which proves the main result only for algorithms that remain in the span of the observed gradients. In more detail, let us examine the family of iterative algorithms whose queries satisfy

w_{t+1} ∈ W ∩ span{ ∇F_S(w_1), …, ∇F_S(w_t) },   (5)

where W is the unit ball and ∇F_S(w_t) is the full-batch oracle response to the query w_t as defined in (2) above. Well-studied algorithms such as GD and GD with standard norm regularization fall into this category of algorithms.

To extend the lower bound to algorithms not restricted to the gradient span, we refine the simpler construction and apply well-established techniques of random embedding in a high-dimensional space. We discuss these modifications briefly at the end of this section and provide the full details in Sections 4 and 5 below.

3.1 A simpler construction

Let us fix the dimension d and parameters ϵ, γ > 0 satisfying the conditions required by the analysis below. Define the hard instance f(w; α) as follows:

(6)

where f(w; α) consists of a sample-dependent first term, a Nemirovski-style maximum term, and a small perturbation term, and e_i denotes the i'th standard basis vector. The distribution D we will consider is uniform over α ∈ {0,1}^d; that is, we draw α = (α_1, …, α_d) uniformly at random and pick the function f(·; α).

The parameters γ and ϵ of the construction should be thought of as arbitrarily small. In particular, the perturbation term in Eq. 6 should be thought of as negligible, so that the first term dominates the risk.

Another useful property of the construction is that the population risk is minimized at a point aligned with a distinguished coordinate direction. However, as we will see, the choice of the perturbation vector and the negligible term hinder the learner from observing this coordinate; the first T queries are constrained to a linear subspace in which all points have a high generalization error, due to the expectation of the first term.

3.2 Analysis

We next state the main lemmas we use, with proofs deferred to Appendix A. Given a sample S = (α_1, …, α_n), we introduce notation for the empirical objective and its relevant subspaces. Additionally, given a fixed sample S, we write I(S) for the set of coordinates i such that α_i = 0 for every α in the sample S, plus the distinguished coordinate.

Lemma 2.

Let ϵ, γ and T be as above, and suppose that the sample S satisfies |I(S)| ≥ T. Then there exists a first-order full-batch oracle such that for any algorithm that adheres to

(7)

with respect to f defined in Eq. 6, we have

where V_T(S) is the set of the T largest coordinates in I(S).

We next observe that in any span of the above form, we cannot find a solution with risk better than that of the trivial solution, while the hidden distinguished direction attains near-optimal risk. In other words, our lower bound stems from the following result:

Lemma 3.

For sufficiently small ϵ and γ, any output w_S that lies in the span guaranteed by Lemma 2 satisfies

(8)
Lower bound proof sketch for span-restricted algorithms of the form (5).

First, observe that the probability of an arbitrary index i to satisfy α_i = 0 for all α in the sample S is 2^{−n}. Therefore |I(S)|, the number of indexes among the d possible ones for which this holds, is distributed binomially with d trials and success probability 2^{−n}. Using elementary probability arguments, one can show that for sufficiently large d we have |I(S)| ≥ T with high probability; see Claim 1 in the appendix. This implies that the conditions of Lemmas 2 and 3 hold w.h.p. To conclude, we relate the LHS of Eq. 8 to the expected risk

As f is convex w.r.t. w, we can apply Jensen's inequality to obtain:

Applying the Cauchy–Schwarz inequality to the second term, while also using the facts that the perturbation is small and that w lies in the unit ball, we get:

For sufficiently small γ this term is negligible, and we get that the expected risk is approximately the LHS term in Eq. 8. Lastly, recalling the parameter choices, we get that

The same lower bound (up to a constant) also holds in expectation by the law of total expectation. Our distribution is supported on O(1)-Lipschitz convex functions, so that re-parametrizing ϵ and γ appropriately yields the claimed lower bound (4) for the case of span-restricted algorithms. ∎
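The counting step at the heart of the sketch is easy to simulate (hypothetical parameter values, for illustration only):

    import numpy as np

    rng = np.random.default_rng(0)
    n, T = 8, 100
    d = 10 * T * 2**n  # large enough that E|I(S)| = d * 2**-n = 10T >> T

    # Sample S: n i.i.d. uniform patterns alpha in {0, 1}^d.
    alphas = rng.integers(0, 2, size=(n, d))
    # I(S): coordinates with alpha_i = 0 in every sampled alpha.
    I_S = np.flatnonzero(alphas.sum(axis=0) == 0)
    print(len(I_S), d * 2.0**-n)  # |I(S)| ~ Binomial(d, 2^-n), concentrated near 10T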

3.3 Handling general full-batch algorithms

The above construction establishes an oracle complexity lower bound for any algorithm whose iterates lie in the span of the previous gradients. While this covers a large class of algorithms, techniques like preconditioning [13], coordinate methods [24] and randomized smoothing [12] do not satisfy this assumption. In fact, a trivial algorithm that always outputs the (fixed) population-risk minimizer of the hard instance (6) would succeed in a single iteration.

To address general algorithms, we employ a well-established technique in optimization lower bounds [28, 7, 11] wherein we embed a hard instance f for span-constrained algorithms in a random high-dimensional space. More concretely, we draw a random orthogonal matrix U ∈ R^{d′×d} (with U^⊤U = I_d) and consider the d′-dimensional instance f(U^⊤x; α), along with its corresponding empirical objective. Roughly speaking, we show that for a general algorithm operating with the appropriate subgradient oracle, each iterate x_t is approximately in the span of the previously observed gradients, in the sense that the component of x_t outside that span is nearly orthogonal to the columns of U. Consequently, the response of the oracle to the query at iteration t is, with high probability, identical to the information it would return if queried with the projection of x_t onto the span of the previously observed gradients. This reduces, in a sense, the problem back to the span-restricted setting described above.

For the embedding technique to work, we must robustify the hard instance construction so that small perturbations around points in the span of previous gradients do not "leak" additional information about the embedding U. To do so we make a fairly standard modification to the maximum component in (6) (known as Nemirovski's function [11, 6]), replacing it with a biased version whose terms are offset by small coefficients that go to zero as the embedding dimension tends to infinity. We provide the full construction and the proof of Theorem 1 in Sections 4 and 5.
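A sketch of the embedding step (our illustration, not the paper's code): a random column-orthogonal U can be drawn via the QR factorization of a Gaussian matrix, and the embedded instance is obtained by composing the hard instance with U^⊤:

    import numpy as np

    def random_embedding(d, d_prime, rng):
        # U in R^{d' x d} with orthonormal columns (U^T U = I_d).
        G = rng.standard_normal((d_prime, d))
        Q, R = np.linalg.qr(G)
        return Q * np.sign(np.diag(R))  # sign fix makes the frame Haar-distributed

    rng = np.random.default_rng(1)
    U = random_embedding(d=50, d_prime=5000, rng=rng)
    # Embedded instance: f_U(x; alpha) = f(U.T @ x; alpha), with f as in Eq. (9).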

4 The Full Construction

As explained above, the key difference between the simplified construction and the full construction with which we prove Theorem 1 is that we modify the Nemirovski function term in order to make it robust to queries that are nearly within a certain linear subspace. In particular, we bias the different terms in the maximization defining it so as to control the index of the coordinate attaining the maximum. For ease of reference, we now provide a self-contained definition of our full construction with the modified Nemirovski function.

Fix d and parameters ϵ, γ > 0 as before. Define the hard instance f(w; α) as follows:

(9)

where the terms are as in the simplified construction of Eq. 6, except that the Nemirovski function component is replaced by its biased variant, and e_i is the i'th standard basis vector in R^d. We consider a distribution D over α that is uniform over {0,1}^d; that is, we draw α uniformly at random and pick the function f(·; α). The rest of the parameters are set throughout the proof as follows:

(10)

With this choice of distribution, as well as our choice of parameters (using Jensen's inequality and the settings in Eq. 10), we obtain:

(11)

Notice that we also have, again by the choice of parameters in Eq. 10:

(12)

Our development makes frequent use of the following notation from Section 3:

We begin with the following lemma, which is a robust version of Lemma 2 in Section 3. The proof is provided in Section B.1.

Lemma 4.

Suppose that the dimension d and the parameters are as in Eq. 10, and consider f with these parameters. Suppose S is a sample such that |I(S)| ≥ T. Assume that w is such that

where

(13)

Then,

for some i, where V_T(S) is the set of the T largest coordinates in I(S).

The following corollary states that the gradient oracle's answers are resilient to small perturbations of the query (as long as they are in the vicinity of the "right" subspace). The proof is provided in Section B.2.

Corollary 5.

Assume that x is such that

where

(14)

Then,

where Π is the Euclidean projection onto the subspace defined above.

5 Proof of Theorem 1

To prove Theorem 1 we embed the construction of Section 4 into a random, higher-dimensional space. More formally, let f be as in Eq. 9, and, for d′ ≥ d, let U ∈ R^{d′×d} be an orthogonal matrix, i.e., such that U^⊤U = I_d. We consider then the objective function over R^{d′}:

Given a sample S we use the notation

(15)

for the empirical error and

(16)

for the expected error, where, as before, D is such that the coordinates of α are i.i.d. uniform on {0,1}, and the parameters are fixed as in Eq. 10. We start with the following claim:

Lemma 6.

Fix a deterministic full-batch first-order algorithm, and a sample S such that |I(S)| ≥ T. Let U be a random orthogonal matrix; then, for a sufficiently large embedding dimension d′, we have with high probability (over the draw of U):

(17)

where F is as in Eq. 16.

Before we proceed with the proof of Lemma 6, we explain how Theorem 1 follows from it. Fix a full-batch algorithm and let D be a distribution as in Lemma 6. Let E be the indicator of the event that Eq. 17 holds, where U is the random orthogonal matrix, S is the sample, and the random seed of the algorithm is independent of both U and S. Then by Lemma 6 we have that

(18)

The next claim follows from a standard concentration inequality and shows that the event above is indeed probable; the proof is provided in Section 5.2.

Claim 1.

Suppose d is as in Eq. 10. Then with high probability it holds that |I(S)| ≥ T.

From Eq. 18 and Claim 1, we can conclude that:

By changing the order of expectation, we conclude that there exists a matrix U such that, with constant probability over the sample S as well as the random bits of the algorithm, the lower bound (17) holds, and Theorem 1 follows.

We are left with proving Lemma 6 and Claim 1, which we do in Sections 5.1 and 5.2, respectively.

5.1 Proof of Lemma 6

We start by defining, inductively, a chain of algorithms (A_k) that act as intermediaries between the outer algorithm and a full-batch first-order oracle for the embedded empirical risk. Each A_k should be thought of as an arbitrator between the outer algorithm and the oracle: at each iteration t, it receives a query from the outer algorithm, submits some query to the oracle, and returns some answer to the outer algorithm (not necessarily the oracle's answer). We build the chain in such a way that at one extreme the outer algorithm interacts with the true oracle, while at the other extreme the intermediary forces queries to stay in the span of the observed gradients. We then relate the error of one extreme to that of the other by bounding the probabilities that they observe different information from the oracle.

We formally define A_k as follows.

  • For each t ≤ k, algorithm A_k receives the query point x_t, then defines the projected query x̃_t = Π(x_t), where Π is the Euclidean projection onto the span of the previously observed gradients. The algorithm then inputs the query x̃_t to a full-batch first-order oracle for the embedded objective, receives the oracle's answer, and provides the outer algorithm with it.

  • For t > k, algorithm A_k behaves like a standard full-batch oracle. Namely, it receives a query x_t, defines x̃_t = x_t, queries the oracle with it, receives the answer, and provides the outer algorithm with it.

Notice that at one extreme of the chain, the outer algorithm interacts with a valid full-batch first-order oracle for the objective defined in Eq. 15; in particular, at each iteration it provides, as required from a full-batch first-order oracle:

At the other extreme, the chain yields an algorithm that only queries points in the span of the previously observed gradients.

We obtain then, by Lemma 3 as well as Eqs. 11 and 12, that for every t:

(by Eq. 11, then Lemma 3, then Eq. 12)

Denote by p_k the probability that algorithm A_k outputs a sequence of queries such that for some t:

(19)

In particular, we have argued so far that the span-restricted extreme of the chain satisfies the desired risk lower bound.

Next, for two vectors x and y, let us write x ≅ y if the oracle returns the same response at both:

Now, suppose we run an intermediary A_k and observe, at the step where it differs from its neighbor in the chain, a vector x such that x ≅ Π(x). Notice that in that case the outputs of the two adjacent intermediaries are completely identical. Indeed, up to that step the two algorithms are identical, at that step they provide the same response, and after it they again behave identically. Thus,

Rearranging terms and iteratively applying the formula above, we obtain: