Statistical Inference for Model Parameters in Stochastic Gradient Descent

10/27/2016 · by Xi Chen, et al. · New York University · National University of Singapore

The stochastic gradient descent (SGD) algorithm has been widely used in statistical estimation for large-scale data due to its computational and memory efficiency. While most existing work focuses on the convergence of the objective function or the error of the obtained solution, we investigate the problem of statistical inference of the true model parameters based on SGD. To this end, we propose two consistent estimators of the asymptotic covariance of the average iterate from SGD: (1) an intuitive plug-in estimator and (2) a computationally more efficient batch-means estimator, which uses only the iterates from SGD. As the SGD process forms a time-inhomogeneous Markov chain, our batch-means estimator with carefully chosen increasing batch sizes generalizes the classical batch-means estimator designed for time-homogeneous Markov chains. The proposed batch-means estimator is of independent interest and can potentially be used for estimating the covariance of other time-inhomogeneous Markov chains. Both proposed estimators allow us to construct asymptotically exact confidence intervals and hypothesis tests. We further discuss an extension to conducting inference based on SGD for high-dimensional linear regression. Using a variant of the SGD algorithm, we construct a debiased estimator of each regression coefficient that is asymptotically normal. This gives a one-pass algorithm for computing both the sparse regression coefficient estimator and confidence intervals, which is computationally attractive and applicable to online data.



1 Introduction

Estimation of model parameters by minimizing an objective function is a fundamental idea in statistics. Let $\theta^* \in \mathbb{R}^d$ be the true $d$-dimensional model parameters. (We follow the standard notation in the optimization literature, which usually denotes the global optimum of an optimization problem, corresponding to the true model parameters, by $\theta^*$.) In many common models, $\theta^*$ is the minimizer of a convex objective function $F$ from $\mathbb{R}^d$ to $\mathbb{R}$ defined as follows,

$F(\theta) = \mathbb{E}_{\zeta \sim \Pi}\, f(\theta, \zeta),$  (1)

where $\zeta$ denotes the random sample from the probability distribution $\Pi$ and $f(\theta, \zeta)$ is the loss function. In Section 2.2, we provide two classical estimation problems (i.e., linear regression and logistic regression) as motivating examples.

One fundamental problem of stochastic optimization is to compute the minimizer of the population function,

$\theta^* = \operatorname*{arg\,min}_{\theta \in \mathbb{R}^d} F(\theta).$  (2)

A widely used optimization method for minimizing $F$ is stochastic gradient descent (SGD), which has a long history in optimization (see, e.g., Robbins and Monro (1951); Polyak and Juditsky (1992); Nemirovski et al. (2009)). In particular, let $\theta_0$ denote the starting point. SGD is an iterative algorithm, where the $n$-th iterate takes the following form,

$\theta_n = \theta_{n-1} - \eta_n \nabla f(\theta_{n-1}, \zeta_n).$  (3)

The step size $\eta_n$ is a decreasing sequence in $n$, $\zeta_n$ is the $n$-th sample randomly drawn from the distribution $\Pi$, and $\nabla f(\theta, \zeta)$ denotes the gradient of $f(\cdot, \zeta)$ with respect to $\theta$. The algorithm outputs either the last iterate $\theta_n$ or the average iterate

$\bar\theta_n = \frac{1}{n} \sum_{i=1}^{n} \theta_i$  (4)

as the solution to the optimization problem in (2). When $\bar\theta_n$ is adopted as the solution, the algorithm is referred to as averaged SGD (ASGD), and such an averaging step is known as Polyak-Ruppert averaging (Ruppert, 1988; Polyak and Juditsky, 1992). SGD has many computational and storage advantages over traditional deterministic optimization methods. For example, SGD uses only one pass over the data, and the per-iteration time complexity of SGD is $O(d)$, which is independent of the sample size. There is no need for SGD to store the dataset, and thus SGD naturally fits the online setting, where each sample arrives sequentially (e.g., search queries or transactional data). In addition, ASGD is known to achieve the optimal convergence rate in $\|\bar\theta_n - \theta^*\|$ when $F$ is strongly convex (Rakhlin et al., 2012). (When $F$ is non-smooth, Rakhlin et al. (2012) proposed a slight modification of the averaging step.) It has become the prevailing optimization method for many machine learning tasks (Srebro and Tewari, 2010), e.g., training deep neural networks.
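As a concrete illustration of the iteration in (3) and the averaging step in (4), the following sketch runs ASGD on a synthetic linear regression problem. All names and hyperparameters (dimensions, the step-size constants `eta0` and `alpha`) are illustrative choices for this sketch, not values taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
theta_star = rng.normal(size=d)          # true parameters theta*

def sample():
    """Draw one sample zeta = (x, y) from the linear model y = x^T theta* + eps."""
    x = rng.normal(size=d)
    return x, x @ theta_star + rng.normal()

def asgd(n_iter, eta0=0.1, alpha=0.7):
    """SGD with step size eta_t = eta0 * t^(-alpha) and Polyak-Ruppert averaging."""
    theta = np.zeros(d)                  # starting point theta_0
    theta_bar = np.zeros(d)
    for t in range(1, n_iter + 1):
        x, y = sample()
        grad = (x @ theta - y) * x       # stochastic gradient of 0.5*(y - x^T theta)^2
        theta = theta - eta0 * t ** (-alpha) * grad
        theta_bar += (theta - theta_bar) / t   # running average of the iterates
    return theta, theta_bar

theta_last, theta_avg = asgd(50_000)
print(np.linalg.norm(theta_avg - theta_star))   # the averaged iterate is close to theta*
```

Note that the loop touches each sample exactly once and keeps only $O(d)$ state, which is the one-pass, online character of SGD described above.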

Based on the simple SGD template in (3), a large number of variants have been developed in the optimization and statistical learning literature. Most existing works focus only on convergence in terms of the objective function or the distance between the obtained solution and the true minimizer of (2). However, statistical inference of $\theta^*$ based on SGD has largely remained unexplored. In this paper, we propose computationally efficient methods to conduct statistical inference of $\theta^*$. The proposed methods are based on a classical result for ASGD, which characterizes the limiting distribution of $\bar\theta_n$ in (4). In particular, let $A$ be the Hessian matrix of $F$ at $\theta^*$ and $S$ be the covariance matrix of $\nabla f(\theta^*, \zeta)$, i.e.,

$A = \nabla^2 F(\theta^*), \qquad S = \mathbb{E}\big[\nabla f(\theta^*, \zeta)\, \nabla f(\theta^*, \zeta)^{\top}\big].$  (5)

Note that $\nabla F(\theta^*) = 0$. Ruppert (1988) and Polyak and Juditsky (1992) showed that when $F$ is strongly convex with a Lipschitz gradient, by choosing appropriate step sizes, $\sqrt{n}\,(\bar\theta_n - \theta^*)$ converges in distribution to a multivariate normal random vector, i.e.,

$\sqrt{n}\,(\bar\theta_n - \theta^*) \Rightarrow \mathcal{N}\big(0,\; A^{-1} S A^{-1}\big).$  (6)

This result shows that the solution $\bar\theta_n$ from ASGD is asymptotically efficient and attains the Cramér-Rao lower bound. To construct an asymptotically valid confidence interval (or, equivalently, an asymptotically valid hypothesis test that controls the type I error at the nominal level), we only need a consistent estimator of the asymptotic covariance of $\sqrt{n}\,(\bar\theta_n - \theta^*)$, i.e., $A^{-1} S A^{-1}$.

We propose two approaches to estimate the asymptotic covariance of $\bar\theta_n$ based on SGD. An intuitive approach is the plug-in estimator. In particular, we propose a thresholding estimator $\tilde A_n$ of $A$ based on the sample estimate $A_n$. Together with the sample estimate $S_n$ of $S$, the asymptotic covariance is estimated by $\tilde A_n^{-1} S_n \tilde A_n^{-1}$, which is proven to be a consistent estimator. However, the plug-in estimator has several drawbacks:

  1. The plug-in estimator requires the computation of the Hessian matrix of the loss function $f$, which could be computationally expensive for some nonlinear models. Further, it requires the inverse of the matrix $\tilde A_n$, which incurs a large computational cost of $O(d^3)$.

  2. The plug-in estimator requires prior knowledge of a lower bound on the minimum eigenvalue of $A$ to construct the thresholding estimator $\tilde A_n$. Further, to establish the consistency result, it requires an additional Lipschitz condition on the Hessian matrix of the loss function (see Assumption 3.1).

  3. To form confidence intervals for coordinates of $\theta^*$, we are only interested in estimating the diagonal elements of $A^{-1} S A^{-1}$. However, the plug-in estimator directly computes the entire matrix $\tilde A_n^{-1} S_n \tilde A_n^{-1}$, which is computationally wasteful. Furthermore, when $d$ is large, storing $A_n$ and $S_n$ requires $O(d^2)$ memory, which is prohibitive and more than necessary, since the output only requires the $d$ diagonal entries.

Given these drawbacks of the intuitive plug-in estimator, a natural question arises: can we estimate the asymptotic covariance using only the iterates from SGD, without requiring any additional information (e.g., the Hessian matrix of $f$)? We provide an affirmative answer to this question by proposing a computationally and memory efficient batch-means estimator. Basically, we split the sequence of iterates into $M + 1$ batches with sizes $n_0, n_1, \ldots, n_M$:

Here $s_k$ and $e_k$ are the starting and ending indices of the $k$-th batch, and $n_k = e_k - s_k + 1$ is the batch size. The $0$-th batch is discarded since the iterates in the $0$-th batch are far from the optimum. For each batch $k = 1, \ldots, M$, we compute the mean of the iterates in the $k$-th batch and denote it by $\bar\theta^{(k)}$. The batch-means estimator is a “weighted” sample covariance matrix, which treats each batch mean as a sample:

$\widehat{\Sigma}_{\mathrm{BM}} = \frac{1}{M}\sum_{k=1}^{M} n_k \big(\bar\theta^{(k)} - \bar\theta\big)\big(\bar\theta^{(k)} - \bar\theta\big)^{\top},$  (7)

where $\bar\theta$ is the mean of all the iterates except for those from the $0$-th batch. It is clear that the batch-means estimator requires neither the Hessian matrix of $f$ nor any matrix inversion, and thus it is computationally more attractive than the plug-in estimator. In addition, to form confidence intervals for coordinates of $\theta^*$, one only needs to compute and store the diagonal elements of the batch-means estimator in (7).

The idea of the batch-means estimator can be traced to Markov chain Monte Carlo (MCMC), where the batch-means method with equal batch sizes (see, e.g., Glynn and Iglehart (1990); Damerdji (1991); Geyer (1992); Fishman (1996); Jones et al. (2006); Flegal and Jones (2010)) is widely used for variance estimation in a time-homogeneous Markov chain. The SGD iterates in (3) indeed form a Markov chain, as $\theta_n$ depends only on $\theta_{n-1}$. However, since the step sizes $\eta_n$ form a diminishing sequence, it is a time-inhomogeneous Markov chain. Moreover, the asymptotic behaviors of SGD and MCMC are fundamentally different: while the former converges to the optimum, the latter travels ergodically inside the state space. As a consequence of these two differences, the previous literature on batch-means methods is not applicable to our analysis. To address this challenge, our new batch-means method constructs batches of increasing size. The batch sizes are chosen to ensure that the correlation decays appropriately between far-apart batches, so that far-apart batch means can be roughly treated as independent. In Theorem 3.2 below, the proposed batch-means method is shown to be a consistent estimator of the asymptotic covariance $A^{-1} S A^{-1}$. Further, we believe this new batch-means algorithm with increasing batch sizes is of independent interest, since it can be used to estimate the covariance structure of other time-inhomogeneous Markov chains.

We further study the problem of confidence interval construction for $\theta^*$ in high-dimensional linear regression based on SGD, where the dimensionality $d$ can be much larger than the sample size (or the number of iterates) $n$. It is noteworthy that the sample size equals the number of iterates in SGD, since each iteration of SGD draws a new data point. In a high-dimensional setup, it is natural to solve an $\ell_1$-regularized problem,

$\min_{\theta \in \mathbb{R}^d}\; F(\theta) + \lambda \|\theta\|_1,$  (8)

where $F$ is defined in (1). A popular approach to solving (8) is the proximal stochastic gradient method (see, e.g., Ghadimi and Lan (2012) and references therein). However, due to the proximal operator (i.e., the soft-thresholding operator for the $\ell_1$-regularized problem), the distribution of the average iterate will no longer converge to a multivariate normal distribution. To address this challenge, we use the recently proposed RADAR algorithm (Agarwal et al., 2012), which is a variant of SGD, together with the debiasing approach (Zhang and Zhang, 2014; van de Geer et al., 2014; Javanmard and Montanari, 2014). The standard debiasing method relies on solving convex optimization problems (e.g., node-wise Lasso in van de Geer et al. (2014)) to construct an approximation of the inverse of the covariance matrix of the design. Each optimization problem incurs a computational cost proportional to the number of iterations of the adopted deterministic optimization algorithm, which is generally in the hundreds to thousands. In contrast, we adopt the stochastic RADAR algorithm to solve these convex optimization problems, where each problem requires only one pass over the data. Moreover, since the resulting approximate inverse covariance matrix from the stochastic RADAR is not an exact solution of the corresponding optimization problem, the analysis of van de Geer et al. (2014), which relies heavily on the KKT condition, is no longer applicable. We provide a new analysis to establish the asymptotic normality of the estimator of $\theta^*$ obtained from the stochastic optimization algorithm.

1.1 Some related works on SGD

There is a large body of literature on stochastic gradient approaches and their applications to statistical learning problems (see, e.g., Zhang (2004); Nesterov and Vial (2008); Xiao (2010); Ghadimi and Lan (2012); Roux et al. (2012); Agarwal et al. (2012); Xiao and Zhang (2014) and references therein). Most works on SGD focus on the convergence rate of the objective function instead of the asymptotic distribution of the obtained solution. Thus, we only review a few closely related works with results on distributions.

Back in the 1960s, Fabian (1968) studied the distribution of SGD iterates. However, without averaging, the asymptotic variance is inflated, and thus the resulting statistical inference would have reduced power even if the asymptotic variance were known. Ruppert (1988); Polyak and Juditsky (1992); Bach and Moulines (2011); Toulis and Airoldi (2016) studied averaged SGD (ASGD) and its variants (e.g., implicit SGD in Toulis and Airoldi (2016)) and established the asymptotic normality and efficiency of the estimators. However, these works do not discuss the estimation of the asymptotic covariance.

A few works in the SGD literature (e.g., Nesterov and Vial (2008); Nemirovski et al. (2009)) show large deviation results for $\bar\theta_n$ by combining the Markov inequality with the expected deviation of $\bar\theta_n$ from $\theta^*$. However, such $\ell_2$-norm deviation bounds are not useful for constructing confidence intervals: they scale with the variance of the norm of the entire stochastic gradient, whereas a confidence interval for any single coordinate should scale with that coordinate's asymptotic variance. Therefore, using an $\ell_2$-norm error bound results in an extremely conservative confidence interval. Instead, we will use the central limit theorem in (6) together with the $(1-q/2)$-quantile of the standard normal distribution. This allows us to construct an asymptotically exact confidence interval that is also asymptotically efficient, in the sense that it attains the Cramér-Rao lower bound.

We also note that Toulis and Airoldi (2016) established the asymptotic normality of the averaged implicit SGD procedure, which has the same limiting distribution as ASGD. Therefore, as long as our key Lemma 2.6 also holds for averaged implicit SGD, our estimators (i.e., plug-in and batch-means) would also provide a consistent covariance estimate for averaged implicit SGD. We leave the verification of Lemma 2.6 for averaged implicit SGD as a direction for future work.

1.2 Notations and organization of the paper

As a summary of notations, throughout the paper we use $\|v\|_2$ to denote the vector $\ell_2$-norm of a vector $v$, $\|W\|_2$ the matrix operator norm of a matrix $W$, and $\|W\|_{\max}$ the element-wise max-norm of $W$ (i.e., $\|W\|_{\max} = \max_{i,j}|W_{ij}|$). For a square matrix $W$, we denote its trace by $\operatorname{tr}(W)$. For a positive semi-definite (PSD) matrix $W$, let $\lambda_{\max}(W)$ and $\lambda_{\min}(W)$ be the maximum and minimum eigenvalues of $W$; when $W$ is strictly positive definite, we use $\underline{\lambda}$ to denote any lower bound on $\lambda_{\min}(W)$. For a vector $v$ of length $d$ and any index subset $J$, we denote by $v_J$ the sub-vector of $v$ with the elements indexed by $J$, and by $v_{J^c}$ the sub-vector with the elements indexed by the complement of $J$. Similarly, for a matrix $W$ and two index subsets $I$ and $J$, we denote by $W_{I,J}$ the sub-matrix of $W$ with elements in rows in $I$ and columns in $J$. When $I$ or $J$ is the full index set, we write $W_{\cdot,J}$ or $W_{I,\cdot}$, respectively. We use $I_d$ to denote the $d \times d$ identity matrix. The function $\Phi$ denotes the CDF of the standard normal distribution.

Let $c, C$ denote universal constants which could change from line to line. For any sequences $\{a_n\}$ and $\{b_n\}$ of positive numbers, we write $a_n \lesssim b_n$ if $a_n \le C b_n$ holds for all $n$ and some absolute constant $C$, $a_n \gtrsim b_n$ if $b_n \lesssim a_n$ holds, and $a_n \asymp b_n$ if both $a_n \lesssim b_n$ and $a_n \gtrsim b_n$ hold.

The rest of the paper is organized as follows. In Section 2, we discuss the problem setup and necessary assumptions, and also provide the convergence rate of iterates in Lemma 2.6. In Section 3, we propose the plug-in estimator and batch-means estimator for estimating the asymptotic covariance of from ASGD. In Section 4, we discuss how to conduct inference for high-dimensional linear regression. In Section 5, we demonstrate the proposed methods by simulated experiments. All the proofs are given in the Appendix.

2 Problem Setup

Let us first introduce the problem setup and the assumptions made in this work.

2.1 Assumptions

In the classical work of Polyak and Juditsky (1992), the SGD method was introduced in a formulation different from (3). In particular, the iteration was given by

$\theta_n = \theta_{n-1} - \eta_n \big(\nabla F(\theta_{n-1}) + \xi_n\big).$  (9)

This formulation decomposes the descent direction into two parts: $\nabla F(\theta_{n-1})$ represents the population gradient, which is the major driving force behind the convergence of SGD, and $\{\xi_n\}$ is a sequence of martingale differences. In fact, the iteration in (9) is a more general formulation that includes (3) as a special case. To see this, we write (3) as

$\theta_n = \theta_{n-1} - \eta_n \big(\nabla F(\theta_{n-1}) + \xi_n\big), \qquad \xi_n = \nabla f(\theta_{n-1}, \zeta_n) - \nabla F(\theta_{n-1}),$  (10)

where $\{\xi_n\}$ is a martingale difference sequence due to Fubini's theorem:

$\mathbb{E}_{n-1}[\xi_n] = \mathbb{E}_{n-1}\big[\nabla f(\theta_{n-1}, \zeta_n)\big] - \nabla F(\theta_{n-1}) = 0.$

Here and in the sequel, $\mathbb{E}_{n-1}[\cdot]$ denotes the conditional expectation $\mathbb{E}[\cdot \mid \mathcal{F}_{n-1}]$, where $\mathcal{F}_{n-1}$ is the sigma-algebra generated by $\zeta_1, \ldots, \zeta_{n-1}$ (note: $\zeta_i$ is the $i$-th sample). Let $\Delta_n = \theta_n - \theta^*$ be the error of the $n$-th iterate. It is noteworthy that by subtracting $\theta^*$ from both sides of (10), the recursion (10) is equivalent to

$\Delta_n = \Delta_{n-1} - \eta_n \big(\nabla F(\theta_{n-1}) + \xi_n\big),$  (11)

which will be extensively used throughout the paper.

We first make the following standard assumptions on the population loss function in (1).

Assumption 2.1 (Strong Convexity and Lipschitz continuity of the gradient).

Assume that the objective function $F$ is continuously differentiable and strongly convex with parameter $\mu > 0$, that is, for any $\theta_1$ and $\theta_2$,

$F(\theta_2) \ge F(\theta_1) + \langle \nabla F(\theta_1),\, \theta_2 - \theta_1 \rangle + \frac{\mu}{2}\|\theta_2 - \theta_1\|_2^2.$

Further, assume that $\nabla F$ is Lipschitz continuous with constant $L$, i.e., for any $\theta_1$ and $\theta_2$,

$\|\nabla F(\theta_1) - \nabla F(\theta_2)\|_2 \le L \|\theta_1 - \theta_2\|_2.$  (12)

We also assume $F$ has a Hessian matrix at $\theta^*$:

$A = \nabla^2 F(\theta^*).$  (13)

The strong convexity of $F$ guarantees the uniqueness of the minimizer $\theta^*$. Moreover, due to the strong convexity, we immediately have $\lambda_{\min}(A) \ge \mu > 0$.

Let $\Delta_n = \theta_n - \theta^*$ be the error of the $n$-th iterate. We further assume that the martingale difference sequence $\{\xi_n\}$ satisfies the following conditions.

Assumption 2.2.

The following conditions hold for the sequence $\{\xi_n\}$:

  1. It is a martingale difference sequence: $\mathbb{E}_{n-1}[\xi_n] = 0$ for every $n$.

  2. The conditional covariance of $\xi_n$ has an expansion around $S$ whose remainder is of polynomial order in $\|\Delta_{n-1}\|_2$:

    (14)

    and there exists a constant $C$ such that for any $n$,

    (15)

    Note that $S$ is the covariance matrix of $\nabla f(\theta^*, \zeta)$ defined in (5).

  3. There exists a constant $C$ such that the fourth conditional moment of $\xi_n$ is bounded by $C\big(1 + \|\Delta_{n-1}\|_2^4\big)$.

The constant in (15) plays an important role in the convergence rate of the covariance estimation. For example, when the loss is quadratic (e.g., the loss of linear regression, see Example 2.4 in Section 2.2), Eq. (15) holds for , which produces a faster convergence rate in both theory and numerical simulation.

Assumption 2.2 is a mild condition on the regularity and boundedness of the loss function. In fact, one can easily verify Assumption 2.2 using the following lemma.

Lemma 2.3.

If there is a function $H(\zeta)$ with bounded fourth moment such that the Hessian of $f(\cdot, \zeta)$ is bounded by

$\big\|\nabla^2 f(\theta, \zeta)\big\|_2 \le H(\zeta)$  (16)

for all $\theta$, and $\nabla f(\theta^*, \zeta)$ has a bounded fourth moment, then Assumption 2.2 holds.

In the sequel, we will impose Assumptions 2.1 and 2.2. We note that under these two assumptions, the classical works by Ruppert (1988) and Polyak and Juditsky (1992) established the asymptotic normality and efficiency of the averaged iterate $\bar\theta_n$.

2.2 Two motivating examples

In this section, we illustrate Assumptions 2.1 and 2.2 on applications of SGD to two popular statistical estimation problems.

Example 2.4 (Linear Regression).

Under the classical linear regression setup, let the $i$-th sample be $\zeta_i = (x_i, y_i)$, where the inputs $x_i$ are random vectors independently drawn from the same multivariate distribution and the response $y_i$ follows the linear model,

$y_i = x_i^{\top} \theta^* + \varepsilon_i.$  (17)

Here $\theta^*$ represents the true parameters of the linear model, and the $\varepsilon_i$ are independently and identically distributed (i.i.d.) centered random variables, which are uncorrelated with $x_i$. For simplicity, we assume $x_i$ and $\varepsilon_i$ have all moments finite. Given $\zeta = (x, y)$, the loss function at $\theta$ is a quadratic one:

$f(\theta, \zeta) = \tfrac{1}{2}\big(y - x^{\top}\theta\big)^2,$

and the true parameters satisfy $\theta^* = \operatorname*{arg\,min}_{\theta} F(\theta)$. Given the loss function, the SGD iterates in (3) become,

$\theta_n = \theta_{n-1} + \eta_n \big(y_n - x_n^{\top}\theta_{n-1}\big)\, x_n.$

This can as well be written in the form of (9) as

$\theta_n = \theta_{n-1} - \eta_n \big(A\,(\theta_{n-1} - \theta^*) + \xi_n\big),$  (18)

where $A = \mathbb{E}[x x^{\top}]$ is the population Gram matrix of $x$.

It is easy to find that $\nabla^2 F(\theta) = A$ for all $\theta$, which implies that the strong convexity parameter and the gradient Lipschitz constant are $\mu = \lambda_{\min}(A)$ and $L = \lambda_{\max}(A)$, respectively. Therefore, Assumption 2.1 holds as long as the eigenvalues of $A$ are bounded above and below by positive constants. This is commonly assumed in linear regression models.

Lastly, we turn to the martingale difference sequence $\{\xi_n\}$. First, we notice it is indeed a martingale difference sequence, since $x_n$ and $\varepsilon_n$ are independent of $\mathcal{F}_{n-1}$ and $\mathbb{E}_{n-1}[\xi_n] = 0$. Second, the conditional covariance $\mathbb{E}_{n-1}[\xi_n \xi_n^{\top}]$ takes the form of an expansion around

$S = \mathbb{E}\big[\varepsilon^2\, x x^{\top}\big],$  (19)

and Assumption 2.2 holds because the quadratic loss makes the expansion exact to the required order and the moments involved are finite by Hölder's inequality.
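In this linear regression example, the sandwich covariance has a closed form: with homoscedastic noise of variance $\sigma^2$ that is independent of $x$, the per-sample gradient at $\theta^*$ is $-\varepsilon x$, so $S = \sigma^2 A$ and $A^{-1} S A^{-1} = \sigma^2 A^{-1}$. The following sketch checks this numerically; the design covariance (via the mixing matrix `L`) and all constants are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, sigma = 3, 200_000, 0.5
theta_star = np.ones(d)

# Correlated Gaussian design (the mixing matrix L is an illustrative choice).
L = np.array([[1.0, 0.0, 0.0], [0.3, 1.0, 0.0], [0.1, 0.2, 1.0]])
X = rng.normal(size=(n, d)) @ L.T
y = X @ theta_star + sigma * rng.normal(size=n)

A_hat = X.T @ X / n                          # Monte Carlo estimate of A = E[x x^T]
grads = (X @ theta_star - y)[:, None] * X    # per-sample gradients at theta* (= -eps * x)
S_hat = grads.T @ grads / n                  # Monte Carlo estimate of S

sandwich = np.linalg.inv(A_hat) @ S_hat @ np.linalg.inv(A_hat)
target = sigma ** 2 * np.linalg.inv(A_hat)   # sigma^2 * A^{-1} for homoscedastic noise
print(np.max(np.abs(sandwich - target)))     # small Monte Carlo error
```

Under heteroscedastic noise the simplification to $\sigma^2 A^{-1}$ fails and the full sandwich form must be kept, which is exactly why the paper estimates $A$ and $S$ separately.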

Example 2.5 (Logistic Regression).

One of the most popular applications of the general loss framework in statistics is logistic regression for binary classification problems. In particular, the logistic model assumes that the binary response $y \in \{-1, +1\}$ is generated by the probabilistic model

$\mathbb{P}(y \mid x) = \frac{1}{1 + \exp(-y\, x^{\top}\theta^*)},$

while $\{x_i\}$ is an i.i.d. sequence generated by a fixed distribution. The population objective function is given by

$F(\theta) = \mathbb{E}\big[\log\big(1 + \exp(-y\, x^{\top}\theta)\big)\big].$

Let $\sigma(t) = 1/(1 + e^{-t})$ denote the sigmoid function; we have

$\nabla F(\theta) = -\,\mathbb{E}\big[\big(1 - \sigma(y\, x^{\top}\theta)\big)\, y\, x\big].$

Using the fact that $0 < \sigma'(t) \le 1/4$ for any $t$, $\nabla F$ is Lipschitz continuous with Lipschitz constant $L = \lambda_{\max}(\mathbb{E}[x x^{\top}])/4$. Moreover,

$\nabla^2 F(\theta) = \mathbb{E}\big[\sigma(x^{\top}\theta)\big(1 - \sigma(x^{\top}\theta)\big)\, x x^{\top}\big]$  (20)

is clearly positive semi-definite. An additional non-degeneracy condition on $x$ (e.g., assuming $x$ has a strictly positive density with respect to the Lebesgue measure) will guarantee that $\nabla^2 F$ has a strictly positive minimum eigenvalue; see Lemma A.2 in the Appendix.

For the martingale difference assumption in Assumption 2.2, note that

(21)

which implies that $\mathbb{E}_{n-1}[\xi_n] = 0$. As long as $x$ has a bounded eighth moment, Lemma 2.3 applies, which establishes Assumption 2.2.

2.3 Inferences on SGD iterates

One of the fundamental theoretical results for SGD is that under Assumptions 2.1 and 2.2, the iterates $\theta_n$ converge to the true minimizer $\theta^*$. Namely, we have the following bounds on the moments of the error $\Delta_n = \theta_n - \theta^*$:

Lemma 2.6.

Under Assumptions 2.1 and 2.2, the errors $\Delta_n = \theta_n - \theta^*$ of the iterates satisfy the following:

  1. There exist universal constants $c_1$ and $c_2$ such that for all $n$, the following hold:

  2. Moreover, when the step size is chosen to be $\eta_n = \eta\, n^{-\alpha}$ with $\alpha \in (1/2, 1)$, the following hold:

The proof of Lemma 2.6 is provided in the Appendix. A result similar to claim 1 of Lemma 2.6 has been shown in Theorem 3 of Bach and Moulines (2011) (see Theorem 3 and Appendix C therein). Here, we provide much simpler bounds on the conditional moments of $\Delta_n$. Claim 2 tells us how the error decorrelates with the number of iterations. In light of Lemma 2.6, we set the sequence of step sizes to be

$\eta_n = \eta\, n^{-\alpha} \quad \text{with } \alpha \in (1/2, 1)$  (22)

throughout the paper.

The main objective of this paper is to construct a confidence interval for $\theta^*$ based on the estimator $\bar\theta_n$ from ASGD. To this end, recall Theorem 2 of Polyak and Juditsky (1992), which shows that

$\sqrt{n}\,(\bar\theta_n - \theta^*) \Rightarrow \mathcal{N}(0, \Sigma),$

where $\Sigma = A^{-1} S A^{-1}$. It is worth noting that this result does not require the model to be well-specified. In a mis-specified case, the asymptotic distribution of $\bar\theta_n$ is centered at the unique point $\theta^*$ such that $\nabla F(\theta^*) = 0$. Due to possible model mis-specification, the asymptotic covariance is of the sandwich form (e.g., see Buja et al. (2013)); however, the sandwich covariance simplifies to the inverse Fisher information matrix if the model is indeed well-specified.

In order to construct a confidence interval for each coordinate $j$ of $\theta^*$, it suffices to estimate the asymptotic covariance matrix $\Sigma = A^{-1} S A^{-1}$ and form the interval

$\Big[\bar\theta_{n,j} - z_{1-q/2}\sqrt{\widehat\Sigma_{jj}/n},\;\; \bar\theta_{n,j} + z_{1-q/2}\sqrt{\widehat\Sigma_{jj}/n}\Big],$

where $\widehat\Sigma$ is an estimator of $\Sigma$ and $z_{1-q/2}$ is the $(1-q/2)$-quantile of the standard normal distribution (i.e., $\Phi(z_{1-q/2}) = 1 - q/2$, where $\Phi$ is the CDF of the standard normal distribution).
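Given the averaged iterate and any consistent estimate of the diagonal of $\Sigma = A^{-1} S A^{-1}$ (from either estimator proposed in Section 3), the interval construction itself is mechanical. A minimal sketch; the function name and the example inputs are illustrative, not from the paper.

```python
import numpy as np
from statistics import NormalDist

def confidence_intervals(theta_bar, sigma_diag, n, q=0.05):
    """Per-coordinate (1 - q) confidence intervals for theta*.

    theta_bar:  averaged SGD iterate (length-d array)
    sigma_diag: estimated diagonal of the asymptotic covariance A^{-1} S A^{-1}
    n:          number of SGD iterations (= sample size)
    """
    z = NormalDist().inv_cdf(1 - q / 2)       # z_{1-q/2}, e.g. ~1.96 for q = 0.05
    half = z * np.sqrt(np.asarray(sigma_diag) / n)
    theta_bar = np.asarray(theta_bar)
    return theta_bar - half, theta_bar + half

lo, hi = confidence_intervals([0.98, 2.03], [1.0, 4.0], n=10_000)
print(lo, hi)   # interval widths scale like sqrt(Sigma_jj / n)
```

Note that only the diagonal of $\widehat\Sigma$ enters, which is why the paper emphasizes estimators whose diagonal can be computed and stored cheaply.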

3 Estimators for Asymptotic Covariance

Following the inference procedures illustrated above, it is essential to estimate the asymptotic covariance matrix . In this section, we will propose two estimators, the plug-in estimator and the batch-means estimator.

3.1 Plug-in estimator

The idea of the plug-in estimator is to separately estimate $A$ and $S$ by some $A_n$ and $S_n$ and use $A_n^{-1} S_n A_n^{-1}$ as an estimator of $A^{-1} S A^{-1}$. Since $\theta_n$ converges to $\theta^*$ by Lemma 2.6, according to the definitions of $A$ and $S$ in (13) and (5), an intuitive way to construct $A_n$ and $S_n$ is to use the sample estimates

$A_n = \frac{1}{n}\sum_{i=1}^{n} \nabla^2 f(\theta_{i-1}, \zeta_i), \qquad S_n = \frac{1}{n}\sum_{i=1}^{n} \nabla f(\theta_{i-1}, \zeta_i)\, \nabla f(\theta_{i-1}, \zeta_i)^{\top},$

as long as the information of $\nabla^2 f$ is available. Note that according to (3), $\nabla f(\theta_{i-1}, \zeta_i)$ is the $i$-th stochastic gradient. Since we are interested in estimating $A^{-1}$, it is necessary to avoid the possible singularity of $A_n$ arising from statistical randomness. Therefore, we propose to use a thresholding estimator $\tilde A_n$, which has a strictly positive minimum eigenvalue. In particular, let $\underline{\lambda}$ be any lower bound on $\lambda_{\min}(A)$ so that $0 < \underline{\lambda} \le \lambda_{\min}(A)$. Further, let $A_n = Q_n D_n Q_n^{\top}$ be the eigenvalue decomposition of $A_n$, where $D_n$ is a non-negative diagonal matrix. We construct the thresholding estimator $\tilde A_n$:

$\tilde A_n = Q_n \max\big(D_n,\, \underline{\lambda}\, I_d\big)\, Q_n^{\top},$  (23)

where the maximum is taken entry-wise on the diagonal. The threshold level is set to $\underline{\lambda}$ for ease of theoretical analysis. By construction, it is guaranteed that $\lambda_{\min}(\tilde A_n) \ge \underline{\lambda}$, so $\|\tilde A_n^{-1}\|_2$ is bounded from above by $\underline{\lambda}^{-1}$, while clearly $\tilde A_n = A_n$ if $\lambda_{\min}(A_n) \ge \underline{\lambda}$.
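The eigenvalue thresholding step can be sketched as follows. The floor value and the sample matrix below are illustrative, and the exact threshold level used in the paper's theoretical analysis is not reproduced here.

```python
import numpy as np

def threshold_psd(A_n, lam_floor):
    """Floor the eigenvalues of a symmetric matrix at lam_floor.

    Mirrors the thresholded Hessian estimate: the result is guaranteed to be
    invertible with smallest eigenvalue >= lam_floor, and it equals A_n
    whenever lambda_min(A_n) >= lam_floor already.
    """
    w, Q = np.linalg.eigh(A_n)           # eigendecomposition A_n = Q diag(w) Q^T
    return Q @ np.diag(np.maximum(w, lam_floor)) @ Q.T

# A nearly singular sample Hessian (illustrative values).
A_n = np.array([[1.0, 0.999], [0.999, 1.0]])
A_tilde = threshold_psd(A_n, lam_floor=0.1)
print(np.linalg.eigvalsh(A_tilde))       # smallest eigenvalue raised to the floor
```

The cost is dominated by the eigendecomposition, consistent with the $O(d^3)$ expense attributed to the plug-in approach in the introduction.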

With the constructions of $\tilde A_n$ and $S_n$ in place, we propose the plug-in estimator $\tilde A_n^{-1} S_n \tilde A_n^{-1}$. Our goal is to establish the consistency of the plug-in estimator, i.e., its convergence to $A^{-1} S A^{-1}$. Since this estimator relies on the Hessian matrix of the loss function, we need an additional assumption to establish consistency:

Assumption 3.1.

There are constants and such that for all ,

We note that it is easy to verify that Assumption 3.1 holds for the two motivating examples in Section 2.2, because for the quadratic loss, the Hessian matrix of the loss at any $\theta$ is $x x^{\top}$ itself (constant in $\theta$), and (21) gives the Hessian for the logistic loss, which is Lipschitz in $\theta$ and also bounded.

With this additional assumption, we first establish the consistency of the sample estimates $A_n$ and $S_n$ in the following lemma (its proof is provided in the Appendix).

Lemma 3.2.

Under Assumptions 2.1, 2.2 and 3.1, the following hold,

where the step size $\eta_n$ is given in (22).

Using Lemma 3.2 and a matrix perturbation inequality for the inverse of a matrix (see Lemma B.1 in Appendix), we obtain the consistency result of the proposed plug-in estimator :

Theorem 3.1 (Error rate of the plug-in estimator).

Under Assumptions 2.1, 2.2 and 3.1, the plug-in estimator $\tilde A_n^{-1} S_n \tilde A_n^{-1}$, initialized from any bounded starting point $\theta_0$, converges to the asymptotic covariance matrix,

where the step size $\eta_n$ is given in (22).

We also note that since the element-wise max-norm is bounded from above by the matrix operator norm, the diagonal entries of $\tilde A_n^{-1} S_n \tilde A_n^{-1}$ converge to those of $\Sigma = A^{-1} S A^{-1}$ as well, and thus $\Sigma_{jj}$ can be estimated by $\big(\tilde A_n^{-1} S_n \tilde A_n^{-1}\big)_{jj}$ for the construction of confidence intervals.

3.2 Batch-means estimator

Although the plug-in estimator is intuitive, as we discussed in the introduction, it has several drawbacks from both computational and theoretical perspectives, such as requiring the computation of the Hessian matrix, the inverse of $\tilde A_n$, prior knowledge of a lower bound on $\lambda_{\min}(A)$, and the additional Assumption 3.1.

In this section, we develop the batch-means estimator, which only uses the iterates from SGD without requiring computation of any additional quantity. Intuitively, if all iterates are independent and share the same distribution, the asymptotic covariance can be directly estimated by the sample covariance,

Unfortunately, the SGD iterates are far from independent. To understand the correlation between two consecutive iterates, note that for sufficiently large $n$ such that $\theta_n$ is close to $\theta^*$, the Taylor expansion of $\nabla F$ at $\theta^*$ gives $\nabla F(\theta_{n-1}) \approx A \Delta_{n-1}$. Combining this with the recursion in (11), we have for sufficiently large $n$,

$\Delta_n \approx (I_d - \eta_n A)\, \Delta_{n-1} - \eta_n \xi_n.$  (24)

Based on (24), the strength of the correlation between $\Delta_{n-1}$ and $\Delta_n$ can be estimated by $\|I_d - \eta_n A\|_2$, which is very close to $1$ as $\eta_n \to 0$. To address the challenge of strong correlation among neighboring iterates, we split the entire sequence of iterates into batches with carefully chosen sizes and construct the batch-means estimator.

Batch-means estimators have been widely used for estimating the variance in time-homogeneous Markov chains and MCMC simulations (e.g., Damerdji (1991); Glynn and Iglehart (1990); Geyer (1992); Fishman (1996); Flegal and Jones (2010); Jones et al. (2006)). However, the traditional batch-means estimator (see the analysis in Jones et al. (2006)) fails for the SGD process because the Markov chain formed by SGD iterates is not geometrically ergodic and its invariant distribution is a degenerate point mass. Moreover, due to the diminishing step sizes, the SGD iterates in (3) form a time-inhomogeneous Markov chain. We extend the classical batch-means estimator to the case of time-inhomogeneous Markov chains by using increasing batch sizes to ensure that the batches are sufficiently decorrelated. In particular, we split the $n$ iterates of SGD into $M + 1$ batches with sizes $n_0, n_1, \ldots, n_M$:

Here $s_k$ and $e_k$ are the starting and ending indices of the $k$-th batch, with $e_M = n$ and $n_k = e_k - s_k + 1$. We treat the $0$-th batch as the “burn-in stage.” More precisely, the iterates in the $0$-th batch will not be used for constructing the batch-means estimator, since the step sizes there are not small enough and the corresponding iterates are far away from the optimum. Let us denote the mean of the iterates in the $k$-th batch by $\bar\theta^{(k)}$ and the mean of all the iterates except for those in the $0$-th batch by $\bar\theta$, i.e.,

$\bar\theta^{(k)} = \frac{1}{n_k}\sum_{i=s_k}^{e_k} \theta_i, \qquad \bar\theta = \frac{1}{n - n_0}\sum_{i=e_0+1}^{n} \theta_i.$  (25)

The batch-means estimator is given by the following:

$\widehat{\Sigma}_{\mathrm{BM}} = \frac{1}{M}\sum_{k=1}^{M} n_k \big(\bar\theta^{(k)} - \bar\theta\big)\big(\bar\theta^{(k)} - \bar\theta\big)^{\top}.$  (26)

The batch-means estimator has several advantages over the plug-in estimator. It only uses the iterates from SGD and requires neither Hessian computation nor any matrix inversion, and thus is computationally more attractive.

Intuitively, the reason our batch-means estimator with increasing batch sizes can overcome the strong dependence between iterates is as follows. Although the correlation between neighboring iterates is strong, it decays exponentially fast for far-apart iterates. Roughly speaking, from (24), for large $m$ and $n$ with $m < n$, the strength of the correlation between $\Delta_m$ and $\Delta_n$ can be estimated by

$\Big\|\prod_{i=m+1}^{n} (I_d - \eta_i A)\Big\|_2 \approx \exp\Big(-\lambda_{\min}(A)\sum_{i=m+1}^{n} \eta_i\Big).$  (27)

Therefore, the correlations between the batch means are close to zero if the batch sizes are large enough, in which case different batch means can be roughly treated as independent. As a consequence, the sample covariance gathered from the batch means will serve as a good estimator of the true asymptotic covariance.

The remaining difficulty is how to determine the batch sizes. The approximation of the correlation given by (27) provides a clear clue. If we want the correlation between two neighboring batches to be small, we need the summed step sizes $\sum_{i=s_k}^{e_k} \eta_i$ to be large enough for every batch $k$. When $\eta_t = \eta\, t^{-\alpha}$, we have $\sum_{i=s_k}^{e_k} \eta_i \approx \frac{\eta}{1-\alpha}\big(e_k^{1-\alpha} - s_k^{1-\alpha}\big)$, which leads to the following batch size setting:

$e_k \asymp k^{1/(1-\alpha)},$  (28)

i.e., the end-points $e_k$ are chosen so that the increments $e_k^{1-\alpha} - e_{k-1}^{1-\alpha}$ are roughly constant across batches.

Given the total number of iterates $n$ and noting that $\sum_{k=0}^{M} n_k = n$, the decorrelation strength factor takes the following form,

(29)

where $M$ is the number of batches.
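The batching scheme above can be sketched as follows: batch end-points grow polynomially so that the summed step sizes per batch stay roughly constant, and the estimator treats each batch mean as one sample. The constants and the rounding rule are illustrative choices, not the paper's exact schedule; the sanity check uses i.i.d. "iterates," for which any batching should recover the covariance.

```python
import numpy as np

def batch_ends(M, n, alpha=0.7):
    """End-points e_0 < ... < e_M with e_k growing like k^{1/(1-alpha)}.

    Equal increments of e_k^{1-alpha} keep sum(eta * t^(-alpha)) per batch
    roughly constant, so neighboring batch means decorrelate.
    """
    ks = np.arange(M + 1)
    e = np.floor(n * ((ks + 1) / (M + 1)) ** (1.0 / (1.0 - alpha))).astype(int)
    return np.maximum(e, ks + 1)   # guard against zero-size batches for small k

def batch_means_cov(thetas, ends):
    """Batch-means covariance estimate from iterates; batch 0 is burn-in.

    Batch k covers iterates with indices in (ends[k-1], ends[k]]; each batch
    mean is treated as one sample of a weighted sample covariance.
    """
    thetas = np.asarray(thetas)
    M = len(ends) - 1
    overall = thetas[ends[0]:ends[M]].mean(axis=0)   # mean excluding batch 0
    cov = np.zeros((thetas.shape[1],) * 2)
    for k in range(1, M + 1):
        b_k = thetas[ends[k - 1]:ends[k]].mean(axis=0)
        cov += (ends[k] - ends[k - 1]) * np.outer(b_k - overall, b_k - overall)
    return cov / M

# Sanity check on i.i.d. "iterates" with covariance diag(1, 4).
rng = np.random.default_rng(2)
thetas = rng.normal(size=(100_000, 2)) * np.array([1.0, 2.0])
est = batch_means_cov(thetas, batch_ends(200, 100_000))
print(np.round(est, 2))
```

The estimator is one pass over the iterates with $O(Md)$ memory for the batch means (or $O(M + d)$ if only the diagonal is accumulated), in line with the memory advantages discussed in the introduction.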

Under this setting, the batch-means covariance estimator (26) is consistent as shown in the following theorem.

Theorem 3.2.

Under Assumptions 2.1 and 2.2, the batch-means estimator initialized at any bounded starting point $\theta_0$ is a consistent estimator of $A^{-1} S A^{-1}$. In particular, for sufficiently large $n$ and $M$, we have