Learning ReLUs via Gradient Descent

05/10/2017 ∙ by Mahdi Soltanolkotabi, et al.

In this paper we study the problem of learning Rectified Linear Units (ReLUs), which are functions of the form $\max(0,\langle\mathbf{w},\mathbf{x}\rangle)$ with $\mathbf{w}$ denoting the weight vector. We study this problem in the high-dimensional regime where the number of observations is smaller than the dimension of the weight vector. We assume that the weight vector belongs to some closed set (convex or nonconvex) which captures known side information about its structure. We focus on the realizable model where the inputs are chosen i.i.d. from a Gaussian distribution and the labels are generated according to a planted weight vector. We show that projected gradient descent, when initialized at 0, converges at a linear rate to the planted model with a number of samples that is optimal up to numerical constants. Our results on the dynamics of convergence of these very shallow neural nets may provide some insights towards understanding the dynamics of deeper architectures.






1 Introduction

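Nonlinear data-fitting problems are fundamental to many supervised learning tasks in signal processing and machine learning. Given training data consisting of $n$ pairs of input features $\mathbf{x}_i\in\mathbb{R}^d$ and desired outputs $y_i\in\mathbb{R}$, we wish to infer a function that best explains the training data. In this paper we focus on fitting Rectified Linear Units (ReLUs) to the data, which are functions $\phi_{\mathbf{w}}:\mathbb{R}^d\rightarrow\mathbb{R}$ of the form

$$\phi_{\mathbf{w}}(\mathbf{x})=\max\left(0,\langle\mathbf{w},\mathbf{x}\rangle\right).$$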

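A natural approach to fitting ReLUs to data is to minimize the least-squares misfit aggregated over the data. This optimization problem takes the form

$$\min_{\mathbf{w}\in\mathbb{R}^d}\ \mathcal{L}(\mathbf{w}):=\frac{1}{n}\sum_{i=1}^n\left(\max\left(0,\langle\mathbf{w},\mathbf{x}_i\rangle\right)-y_i\right)^2\quad\text{subject to}\quad\mathcal{R}(\mathbf{w})\le R,\tag{1.1}$$

with $\mathcal{R}:\mathbb{R}^d\rightarrow\mathbb{R}$ denoting a regularization function that encodes prior information on the weight vector.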
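The following is a minimal sketch (not the paper's code) of how to evaluate the least-squares ReLU misfit $\mathcal{L}(\mathbf{w})=\frac{1}{n}\sum_i(\max(0,\langle\mathbf{w},\mathbf{x}_i\rangle)-y_i)^2$ with NumPy. The function names `relu` and `relu_loss` are hypothetical, and the feature vectors are assumed to be stacked as the rows of a matrix `X`. The constraint $\mathcal{R}(\mathbf{w})\le R$ is deliberately left out, because projected gradient descent handles it through the projection step.

```python
import numpy as np

def relu(z):
    """Elementwise max(0, z)."""
    return np.maximum(0.0, z)

def relu_loss(w, X, y):
    """Least-squares misfit L(w) = (1/n) * sum_i (max(0, <w, x_i>) - y_i)^2.

    X has shape (n, d) with the feature vectors x_i as rows; y has shape (n,).
    """
    residual = relu(X @ w) - y
    return np.mean(residual ** 2)

# Tiny usage example with hand-picked numbers.
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]])
w_star = np.array([1.0, -2.0])
y = relu(X @ w_star)  # labels from a planted weight vector: [1, 0, 3]
print(relu_loss(w_star, X, y))        # 0.0 at the planted model
print(relu_loss(np.zeros(2), X, y))   # 10/3 ≈ 3.333
```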

Fitting nonlinear models such as ReLUs has a rich history in statistics and learning theory [11], with interesting new developments emerging [6] (we discuss these results in greater detail in Section 4). Most recently, nonlinear data fitting problems in the form of neural networks (a.k.a. deep learning) have emerged as powerful tools for automatically extracting interpretable and actionable information from raw forms of data, leading to striking breakthroughs in a multitude of applications [12, 13, 4]. In these and many other empirical domains it is common to use local search heuristics such as gradient or stochastic gradient descent for nonlinear data fitting. These local search heuristics are surprisingly effective on real or randomly generated data. However, despite their empirical success, the reasons for their effectiveness remain mysterious.

Focusing on fitting ReLUs, a priori it is completely unclear why local search heuristics such as gradient descent should converge for problems of the form (1.1), since not only the regularization function but also the loss function may be nonconvex! Efficient fitting of ReLUs in this high-dimensional setting poses new challenges: When are the iterates able to escape local optima and saddle points and converge to global optima? How many samples do we need? How does the number of samples depend on the prior knowledge available about the weights? What regularizer is best suited to utilizing a particular form of prior knowledge? How many passes (or iterations) of the algorithm are required to get to an accurate solution? At the heart of answering these questions is the ability to predict the convergence behavior and rate of (non)convex constrained optimization algorithms. In this paper we build upon a new framework developed by the author in [17] for analyzing nonconvex optimization problems to address such challenges.

2 Precise measures for statistical resources

We wish to characterize the rates of convergence for the projected gradient updates (3.2) as a function of the number of samples, the available prior knowledge and the choice of the regularizer. To make these connections precise and quantitative we need a few definitions. Naturally, the required number of samples for reliable data fitting depends on how well the regularization function $\mathcal{R}$ can capture the properties of the weight vector $\mathbf{w}^*$. For example, if we know that the weight vector is approximately sparse, using an $\ell_1$ norm for the regularizer is superior to using an $\ell_2$ regularizer. To quantify this capability we first need a couple of standard definitions, which we adapt from [14, 15, 17].

Definition 2.1 (Descent set and cone)

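The set of descent of a function $\mathcal{R}$ at a point $\mathbf{w}^*$ is defined as

$$\mathcal{D}_{\mathcal{R}}(\mathbf{w}^*)=\left\{\mathbf{h}:\ \mathcal{R}(\mathbf{w}^*+\mathbf{h})\le\mathcal{R}(\mathbf{w}^*)\right\}.$$

The cone of descent is defined as a closed cone $\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$ that contains the descent set, i.e. $\mathcal{D}_{\mathcal{R}}(\mathbf{w}^*)\subseteq\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$. The tangent cone is the conic hull of the descent set, that is, the smallest closed cone $\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$ obeying $\mathcal{D}_{\mathcal{R}}(\mathbf{w}^*)\subseteq\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$.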
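To make the definition concrete, here is a small membership test for the descent set $\mathcal{D}_{\mathcal{R}}(\mathbf{w}^*)=\{\mathbf{h}:\mathcal{R}(\mathbf{w}^*+\mathbf{h})\le\mathcal{R}(\mathbf{w}^*)\}$. It uses the $\ell_1$ norm as an illustrative (assumed) regularizer and a 1-sparse $\mathbf{w}^*$. The last check shows that the descent set itself need not be a cone: a direction that lies in $\mathcal{D}_{\mathcal{R}}(\mathbf{w}^*)$ leaves it once scaled up, even though all of its positive multiples belong to any cone of descent. The helper names `in_descent_set` and `l1_norm` are hypothetical.

```python
import numpy as np

def in_descent_set(R, w_star, h):
    """Check whether h lies in D_R(w*) = {h : R(w* + h) <= R(w*)}."""
    return R(w_star + h) <= R(w_star)

def l1_norm(v):
    """Illustrative regularizer R(w) = ||w||_1."""
    return np.abs(v).sum()

w_star = np.array([1.0, 0.0, 0.0])  # a 1-sparse weight vector, R(w*) = 1
h = np.array([-0.5, 0.2, 0.0])

print(in_descent_set(l1_norm, w_star, np.array([-0.5, 0.0, 0.0])))  # True:  ||w*+h||_1 = 0.5
print(in_descent_set(l1_norm, w_star, np.array([0.0, 1.0, 0.0])))   # False: ||w*+h||_1 = 2
print(in_descent_set(l1_norm, w_star, h))                           # True:  ||w*+h||_1 = 0.7
print(in_descent_set(l1_norm, w_star, 10 * h))                      # False: ||w*+h||_1 = 6
```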

We note that the capability of the regularizer $\mathcal{R}$ to capture the properties of the unknown weight vector $\mathbf{w}^*$ depends on the size of the descent cone $\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$. The smaller this cone is, the better suited the function $\mathcal{R}$ is to capturing the properties of $\mathbf{w}^*$. To quantify the size of this set we use the notion of Gaussian (mean) width.

Definition 2.2 (Gaussian width)

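The Gaussian width of a set $\mathcal{C}\subset\mathbb{R}^d$ is defined as

$$\omega(\mathcal{C}):=\mathbb{E}_{\mathbf{g}}\left[\sup_{\mathbf{z}\in\mathcal{C}}\langle\mathbf{g},\mathbf{z}\rangle\right],$$

where the expectation is taken over $\mathbf{g}\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d)$. Throughout we use $\mathcal{B}^d/\mathcal{S}^{d-1}$ to denote the unit ball/sphere of $\mathbb{R}^d$.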
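The Gaussian width is easy to estimate by Monte Carlo when the supremum has a closed form. As a sketch (our example, not from the text), take the set of vectors in $\mathcal{B}^d$ with at most $k$ nonzero entries. For a fixed $\mathbf{g}$, the supremum of $\langle\mathbf{g},\mathbf{z}\rangle$ over this set is the $\ell_2$ norm of the $k$ largest-magnitude entries of $\mathbf{g}$. With $k=d$ this recovers $\omega(\mathcal{B}^d)=\mathbb{E}\|\mathbf{g}\|_{\ell_2}\approx\sqrt{d}$, while sparse sets have a much smaller width. The function name `width_sparse_unit_vectors` and the sample sizes are hypothetical choices.

```python
import numpy as np

def width_sparse_unit_vectors(d, k, num_samples=2000, seed=0):
    """Monte Carlo estimate of omega({z in B^d : ||z||_0 <= k}).

    For a fixed g the supremum of <g, z> over this set is attained by placing z
    on the k entries of g with largest magnitude, giving sqrt(sum of k largest g_j^2).
    """
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((num_samples, d))
    top_k_sq = np.sort(g ** 2, axis=1)[:, -k:]  # k largest squared entries per sample
    return np.sqrt(top_k_sq.sum(axis=1)).mean()

d = 1000
print(width_sparse_unit_vectors(d, k=d))   # whole unit ball: close to sqrt(1000) ≈ 31.6
print(width_sparse_unit_vectors(d, k=10))  # 10-sparse vectors: much smaller
```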

We now have all the definitions in place to quantify the capability of the function $\mathcal{R}$ to capture the properties of the unknown parameter $\mathbf{w}^*$. This naturally leads us to the definition of the minimum required number of samples.

Definition 2.3 (minimal number of samples)

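Let $\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$ be a cone of descent of $\mathcal{R}$ at $\mathbf{w}^*$. We define the minimal sample function as

$$\mathcal{M}(\mathcal{R},\mathbf{w}^*)=\omega^2\left(\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)\cap\mathcal{B}^d\right).$$

We shall often use the shorthand $n_0=\mathcal{M}(\mathcal{R},\mathbf{w}^*)$, with the dependence on $\mathcal{R}$ and $\mathbf{w}^*$ implied.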
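As a worked illustration (our example, not taken from the text), consider the nonconvex regularizer $\mathcal{R}(\mathbf{w})=\|\mathbf{w}\|_0$ with a $k$-sparse $\mathbf{w}^*$. Any $\mathbf{h}$ in the descent set satisfies $\|\mathbf{w}^*+\mathbf{h}\|_0\le k$, so $\mathbf{h}=(\mathbf{w}^*+\mathbf{h})-\mathbf{w}^*$ has at most $2k$ nonzeros. The set $\Sigma_{2k}$ of $2k$-sparse vectors is a closed cone containing the descent set, and is therefore a valid cone of descent. Standard estimates on the Gaussian width of sparse vectors then give, up to a numerical constant hidden in $\lesssim$, the bound below, which we contrast with the case of no prior knowledge ($\mathcal{C}=\mathbb{R}^d$):

$$\mathcal{D}_{\|\cdot\|_0}(\mathbf{w}^*)\subseteq\Sigma_{2k}:=\left\{\mathbf{h}\in\mathbb{R}^d:\|\mathbf{h}\|_0\le 2k\right\},\qquad n_0=\omega^2\left(\Sigma_{2k}\cap\mathcal{B}^d\right)\lesssim k\log\left(\frac{ed}{k}\right),$$

$$\mathcal{C}=\mathbb{R}^d:\qquad n_0=\omega^2\left(\mathcal{B}^d\right)=\left(\mathbb{E}\|\mathbf{g}\|_{\ell_2}\right)^2\approx d.$$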

We note that $n_0$ is exactly the minimum number of samples required for structured signal recovery from linear measurements when using convex regularizers [3, 1]. Specifically, the optimization problem

$$\min_{\mathbf{w}\in\mathbb{R}^d}\ \sum_{i=1}^n\left(y_i-\langle\mathbf{x}_i,\mathbf{w}\rangle\right)^2\quad\text{subject to}\quad\mathcal{R}(\mathbf{w})\le\mathcal{R}(\mathbf{w}^*)\tag{2.1}$$

succeeds at recovering an unknown weight vector $\mathbf{w}^*$ with high probability from $n$ observations of the form $y_i=\langle\mathbf{x}_i,\mathbf{w}^*\rangle$ if and only if $n\ge n_0$.¹ While this result is only known to be true for convex regularization functions, we believe that $n_0$ also characterizes the minimal number of samples even for nonconvex regularizers in (2.1). See [14] for some results in the nonconvex case, as well as for the role this quantity plays in the computational complexity of projected gradient schemes for linear inverse problems. Nonlinear samples carry less information than linear observations, so we cannot hope to recover the weight vector from $n<n_0$ samples when using (1.1). Therefore, we can use $n_0$ as a lower bound on the minimum number of observations required for the projected gradient descent iterations (3.2) to succeed at finding the right model.

¹ We would like to note that $n_0$ only approximately characterizes the minimum number of samples required. A more precise characterization is $\phi^{-1}\left(\omega\left(\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)\cap\mathcal{B}^d\right)\right)\approx\omega^2\left(\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)\cap\mathcal{B}^d\right)$, where $\phi(t)=\sqrt{2}\,\Gamma\left(\frac{t+1}{2}\right)/\Gamma\left(\frac{t}{2}\right)\approx\sqrt{t}$. However, since our results have unspecified constants, we avoid this more accurate characterization.

3 Theoretical results for learning ReLUs

A simple heuristic for optimizing (1.1) is to use gradient descent. One challenging aspect of the above loss function is that it is not differentiable, so it is not immediately clear how to run projected gradient descent. However, this does not pose a fundamental challenge. The loss function is differentiable except at isolated points, and we can use the notion of generalized gradients to define the gradient at a non-differentiable point as one of the limit points of the gradient in a local neighborhood of that point. For the loss in (1.1), writing $\text{ReLU}(z)=\max(0,z)$, the generalized gradient takes the form

$$\nabla\mathcal{L}(\mathbf{w}):=\frac{1}{n}\sum_{i=1}^n\left(\text{ReLU}\left(\langle\mathbf{w},\mathbf{x}_i\rangle\right)-y_i\right)\left(1+\text{sgn}\left(\langle\mathbf{w},\mathbf{x}_i\rangle\right)\right)\mathbf{x}_i.\tag{3.1}$$
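Therefore, projected gradient descent takes the form

$$\mathbf{w}_{\tau+1}=\mathcal{P}_{\mathcal{K}}\left(\mathbf{w}_\tau-\mu_\tau\nabla\mathcal{L}(\mathbf{w}_\tau)\right),\tag{3.2}$$

where $\mu_\tau$ is the step size and $\mathcal{K}=\{\mathbf{w}\in\mathbb{R}^d:\mathcal{R}(\mathbf{w})\le R\}$ is the constraint set, with $\mathcal{P}_{\mathcal{K}}$ denoting the Euclidean projection onto this set.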
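The generalized gradient $\nabla\mathcal{L}(\mathbf{w})=\frac{1}{n}\sum_i(\text{ReLU}(\langle\mathbf{w},\mathbf{x}_i\rangle)-y_i)(1+\text{sgn}(\langle\mathbf{w},\mathbf{x}_i\rangle))\mathbf{x}_i$ and one projected update can be sketched in NumPy as follows. Features are assumed to be stacked as rows of `X`. The names `relu_loss`, `relu_grad`, `project_sparse` and `pgd_step` are hypothetical. The constraint set $\mathcal{K}=\{\mathbf{w}:\|\mathbf{w}\|_0\le k\}$ is our illustrative nonconvex choice, not one fixed by the text; its Euclidean projection keeps the $k$ largest-magnitude entries. The final lines check the gradient against a central finite difference of the loss at a random point, where no $\langle\mathbf{w},\mathbf{x}_i\rangle$ is expected to sit at a kink.

```python
import numpy as np

def relu_loss(w, X, y):
    """L(w) = (1/n) * sum_i (max(0, <w, x_i>) - y_i)^2."""
    return np.mean((np.maximum(0.0, X @ w) - y) ** 2)

def relu_grad(w, X, y):
    """Generalized gradient (1/n) * sum_i (ReLU(<w,x_i>) - y_i)(1 + sgn(<w,x_i>)) x_i.

    np.sign(0) == 0, so at a kink <w, x_i> = 0 the factor (1 + sgn) equals 1.
    """
    z = X @ w
    return X.T @ ((np.maximum(0.0, z) - y) * (1.0 + np.sign(z))) / X.shape[0]

def project_sparse(v, k):
    """Projection onto the illustrative set K = {w : ||w||_0 <= k}."""
    out = np.zeros_like(v)
    keep = np.argpartition(np.abs(v), -k)[-k:]  # indices of the k largest |v_j|
    out[keep] = v[keep]
    return out

def pgd_step(w, X, y, mu, k):
    """One projected gradient update w <- P_K(w - mu * grad L(w))."""
    return project_sparse(w - mu * relu_grad(w, X, y), k)

# Sanity check: away from kinks the generalized gradient matches finite differences.
rng = np.random.default_rng(1)
X = rng.standard_normal((50, 8))
y = np.maximum(0.0, X @ rng.standard_normal(8))
w = rng.standard_normal(8)
eps = 1e-6
fd = np.array([(relu_loss(w + eps * e, X, y) - relu_loss(w - eps * e, X, y)) / (2 * eps)
               for e in np.eye(8)])
print(np.max(np.abs(fd - relu_grad(w, X, y))))  # small: finite-difference error only
```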

Theorem 3.1

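Let $\mathbf{w}^*\in\mathbb{R}^d$ be an arbitrary weight vector and $\mathcal{R}:\mathbb{R}^d\rightarrow\mathbb{R}$ be a proper function (convex or nonconvex). Suppose the feature vectors $\mathbf{x}_i\in\mathbb{R}^d$ are i.i.d. Gaussian random vectors distributed as $\mathcal{N}(\mathbf{0},\mathbf{I})$ with the corresponding labels given by

$$y_i=\max\left(0,\langle\mathbf{x}_i,\mathbf{w}^*\rangle\right).$$

To estimate $\mathbf{w}^*$, we start from the initial point $\mathbf{w}_0=\mathbf{0}$ and apply the Projected Gradient Descent (PGD) updates of the form

$$\mathbf{w}_{\tau+1}=\mathcal{P}_{\mathcal{K}}\left(\mathbf{w}_\tau-\mu_\tau\nabla\mathcal{L}(\mathbf{w}_\tau)\right),\tag{3.3}$$

with $\mathcal{K}:=\{\mathbf{w}\in\mathbb{R}^d:\mathcal{R}(\mathbf{w})\le\mathcal{R}(\mathbf{w}^*)\}$ and $\nabla\mathcal{L}$ defined via (3.1). Also set the learning parameter sequence $\mu_0=2$ and $\mu_\tau=1$ for all $\tau=1,2,\ldots$, and let $n_0=\mathcal{M}(\mathcal{R},\mathbf{w}^*)$, defined in Definition 2.3, be our lower bound on the number of measurements. Also assume

$$n>c\,n_0\tag{3.4}$$

holds for a fixed numerical constant $c$. Then there is an event of probability at least $1-9e^{-\gamma n}$ such that on this event the updates (3.3) obey

$$\|\mathbf{w}_\tau-\mathbf{w}^*\|_{\ell_2}\le\left(\frac{1}{2}\right)^\tau\|\mathbf{w}^*\|_{\ell_2}.\tag{3.5}$$

Here $\gamma$ is a fixed numerical constant.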
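The following simulation is a rough numerical illustration of the theorem under our own assumptions. It uses a $k$-sparse planted vector with entries $\pm1/\sqrt{k}$, Gaussian features with $n=600<d=1000$, and $\mathcal{R}=\|\cdot\|_0$, so that $\mathcal{K}=\{\mathbf{w}:\|\mathbf{w}\|_0\le k\}$. The step sizes are $\mu_0=2$ and $\mu_\tau=1$, as in the statement. The sizes, seed and function name `relu_pgd` are arbitrary, hypothetical choices. The printed relative errors should shrink geometrically, but their exact values depend on the random draw and are not claimed to reproduce any result in the paper.

```python
import numpy as np

def relu_pgd(X, y, k, num_iters=50):
    """PGD from w_0 = 0 with mu_0 = 2 and mu_tau = 1 for tau >= 1.

    The constraint set K = {w : ||w||_0 <= k} is an illustrative nonconvex choice;
    its projection keeps the k largest-magnitude entries.
    """
    n, d = X.shape
    w = np.zeros(d)
    iterates = []
    for tau in range(num_iters):
        mu = 2.0 if tau == 0 else 1.0
        z = X @ w
        grad = X.T @ ((np.maximum(0.0, z) - y) * (1.0 + np.sign(z))) / n  # generalized gradient
        v = w - mu * grad
        w = np.zeros(d)
        keep = np.argpartition(np.abs(v), -k)[-k:]  # project onto k-sparse vectors
        w[keep] = v[keep]
        iterates.append(w)  # w_{tau+1}; a fresh array each iteration
    return iterates

# Planted k-sparse model in the regime n < d with Gaussian features.
rng = np.random.default_rng(0)
n, d, k = 600, 1000, 5
w_star = np.zeros(d)
w_star[rng.choice(d, size=k, replace=False)] = rng.choice([-1.0, 1.0], size=k) / np.sqrt(k)
X = rng.standard_normal((n, d))
y = np.maximum(0.0, X @ w_star)

iterates = relu_pgd(X, y, k)
rel_err = [np.linalg.norm(w - w_star) / np.linalg.norm(w_star) for w in iterates]
for tau in (1, 2, 5, 10, 20, 50):
    print(f"tau={tau:2d}  relative error={rel_err[tau - 1]:.2e}")
```

Starting at $\mathbf{w}_0=\mathbf{0}$ matters here. Since $\text{sgn}(0)=0$, the first update reduces to $\mathbf{w}_1=\mathcal{P}_{\mathcal{K}}\left(\frac{2}{n}\sum_i y_i\mathbf{x}_i\right)$. For Gaussian features the argument of the projection has expectation $\mathbf{w}^*$, which is why $\mu_0=2$ rather than $1$.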

The first interesting and perhaps surprising aspect of this result is its generality: it applies not only to convex regularization functions but also nonconvex ones! As we mentioned earlier the optimization problem in (1.1) is not known to be tractable even for convex regularizers. Despite the nonconvexity of both the objective and regularizer, the theorem above shows that with a near minimal number of data samples, projected gradient descent provably learns the original weight vector without getting trapped in any local optima.

Another interesting aspect of the above result is that the convergence rate is linear. Therefore, to achieve a relative error of $\epsilon$, the total number of iterations is on the order of $\mathcal{O}(\log(1/\epsilon))$. Thus the overall computational complexity is on the order of $\mathcal{O}(nd\log(1/\epsilon))$ (in general the cost is the total number of iterations multiplied by the cost of applying the $n\times d$ feature matrix and its transpose). As a result, the computational complexity is also optimal in terms of its dependence on the matrix dimensions. Indeed, for a dense matrix even verifying that a good solution has been achieved requires one matrix-vector multiplication, which takes $\mathcal{O}(nd)$ time.
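The following is a back-of-the-envelope version of this count. It assumes the contraction factor $1/2$ from (3.5) and charges roughly $2nd$ multiply-adds per iteration, one product with the feature matrix and one with its transpose. Ignoring the projection cost is our simplification. The names `pgd_iterations` and `pgd_flops` are hypothetical helpers.

```python
import math

def pgd_iterations(eps, rate=0.5):
    """Smallest tau with rate**tau <= eps, i.e. ceil(log(1/eps) / log(1/rate))."""
    return math.ceil(math.log(1.0 / eps) / math.log(1.0 / rate))

def pgd_flops(n, d, eps):
    """Rough cost: about 2*n*d multiply-adds per iteration (X @ w and X.T @ r)."""
    return pgd_iterations(eps) * 2 * n * d

print(pgd_iterations(1e-6))        # 20
print(pgd_flops(600, 1000, 1e-6))  # 24000000
```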

4 Discussions and prior art

There is a large body of work on learning nonlinear models. A particular class of such problems that has been studied is the so-called idealized Single Index Models (SIMs) [8, 9]. In these problems the inputs are labeled examples $(\mathbf{x}_i,y_i)\in\mathbb{R}^d\times[0,1]$ which are guaranteed to satisfy $y_i=f(\langle\mathbf{w},\mathbf{x}_i\rangle)$ for some $\mathbf{w}\in\mathbb{R}^d$ and a nondecreasing (Lipschitz continuous) $f:\mathbb{R}\rightarrow[0,1]$. The goal in this problem is to find a (nearly) accurate such $f$ and $\mathbf{w}$. An interesting polynomial-time algorithm called the Isotron exists for this problem [11, 10]. In principle, this approach can also be used to fit ReLUs. However, these results differ from ours in terms of both assumptions and results. On the one hand, the assumptions are slightly more restrictive, as they require bounded features $\mathbf{x}_i$, outputs $y_i$ and weights. On the other hand, these results hold for much more general distributions and more general models than the realizable model studied in this paper. These results also do not apply in the high-dimensional regime where the number of observations is significantly smaller than the number of parameters (see [5] for some results in this direction). In the realizable case, the Isotron result requires $\mathcal{O}(1/\epsilon)$ iterations to achieve $\epsilon$ error in objective value. In comparison, our results guarantee convergence to a solution with relative error $\epsilon$ ($\|\mathbf{w}_\tau-\mathbf{w}^*\|_{\ell_2}/\|\mathbf{w}^*\|_{\ell_2}\le\epsilon$) after $\mathcal{O}(\log(1/\epsilon))$ iterations. Focusing on the specific case of ReLU functions, an interesting recent result [6] shows that reliable learning of ReLUs is possible under very general but bounded distributional assumptions. To achieve an accuracy of $\epsilon$, the algorithm runs in time polynomial in the dimension but exponential in $1/\epsilon$. In comparison, as mentioned earlier, our result requires $\mathcal{O}(\log(1/\epsilon))$ iterations for reliable parameter estimation. We note, however, that the two results concern different settings, so a direct comparison between them is not possible.

5 Proofs

5.1 Preliminaries

In this section we gather some useful results on the concentration of stochastic processes which will be crucial in our proofs. These results are mostly adapted from [17, 2, 16]. We begin with a lemma which is a direct consequence of Gordon's escape through the mesh lemma [7].

Lemma 5.1

Assume $\mathcal{C}\subset\mathbb{R}^d$ is a cone and $\mathcal{S}^{d-1}$ is the unit sphere of $\mathbb{R}^d$. Also assume that

for a fixed numerical constant . Then for all

holds with probability at least .

We also need a generalization of the above lemma stated below.

Lemma 5.2 ([17])

Assume $\mathcal{C}\subset\mathbb{R}^d$ is a cone (not necessarily convex) and $\mathcal{S}^{d-1}$ is the unit sphere of $\mathbb{R}^d$. Also assume that

for a fixed numerical constant . Then for all

holds with probability at least .

We next state a generalization of Gordon’s escape through the mesh lemma also from [17].

Lemma 5.3 ([17])

Let be a fixed vector with nonzero entries and construct the diagonal matrix . Also, let have i.i.d. $\mathcal{N}(0,1)$ entries. Furthermore, assume and define

where is distributed as . Define

then for all

holds with probability at least

The previous lemma leads to the following Corollary.

Corollary 5.4

Let be a fixed vector with nonzero entries and assume . Furthermore, assume

Then for all ,

holds with probability at least .

5.2 Convergence proof (Proof of Theorem 3.1)

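In this section we prove Theorem 3.1. Throughout, we use the shorthand $\mathcal{C}$ to denote the descent cone of $\mathcal{R}$ at $\mathbf{w}^*$, i.e. $\mathcal{C}=\mathcal{C}_{\mathcal{R}}(\mathbf{w}^*)$. We begin by analyzing the first iteration. Using $\mathbf{w}_0=\mathbf{0}$ we have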

We use the argument of [17][Page 25, inequality (7.34)] which shows that


Using $\text{ReLU}(z)=\frac{z+|z|}{2}$ we have


We proceed by bounding the first term in the above equality. To this aim we decompose in the direction parallel/perpendicular to that of and arrive at


where and are independent Gaussian random vectors distributed as and

. By concentration of Chi-squared random variables


holds with probability at least . Also,


holds with probability at least . Plugging (5.4) with and (5.5) with into (5.2), as long as



holds with probability at least .

We now focus on bounding the second term in (5.2). To this aim we decompose in the direction parallel/perpendicular to that of and arrive at


with . Now note that is sub-exponential with norm bounded by

with a fixed numerical constant. Thus, by a Bernstein-type inequality ([18, Proposition 5.16]),


holds with probability at least with a fixed numerical constant. Also note that


holds with probability at least and

holds with probability at least . Combining the last two inequalities we conclude that


holds with probability at least . Plugging (5.8) and (5.9) with , , and into (5.2)


holds with probability at least as long as

Thus, plugging (5.6) and (5.10) into (5.1), we conclude that for

holds with probability at least as long as

for a fixed numerical constant .

To introduce our general convergence analysis we begin by defining

To prove Theorem 3.1 we use the argument of [17][Page 25, inequality (7.34)] which shows that if we apply the projected gradient descent update

the error obeys


To complete the convergence analysis it is then sufficient to prove


We will instead prove that the following stronger result holds for all and


Equation (5.13) implies (5.12), which, combined with (5.11), proves the convergence result of Theorem 3.1 (specifically, equation (3.5)).

The rest of this section is dedicated to proving (5.13). To this aim, note that $y_i=\text{ReLU}(\langle\mathbf{x}_i,\mathbf{w}^*\rangle)$. Therefore, the loss function can alternatively be written as


Now defining we conclude that

Now define . Using this we can rewrite the previous expression in the form