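Consider a multilayer perceptron (MLP) with $L$ fully connected layers and $n$ data pairs $\{(x_i, y_i)\}_{i=1}^{n}$, where $y_i$ is the label of $x_i$. Given input data point $x_i$, the output of the MLP is computed via the forward pass:
$$z_i^{(l)} = W_l\, x_i^{(l-1)}, \qquad x_i^{(l)} = f\big(z_i^{(l)}\big), \qquad l = 1, 2, \ldots, L, \qquad (1)$$
where $x_i^{(0)} = x_i$, $W_l \in \mathbb{R}^{d_l \times d_{l-1}}$, and $f$ is a nonlinear activation function. Without loss of generality, Eq. (1) does not have bias parameters; otherwise, the bias can be included in the weight matrix $W_l$, and the vector $x_i^{(l-1)}$ is appended with an additional homogeneous coordinate of value one. For ease of presentation, we assume constant layer size, i.e., $d_l = d$ for $l = 0, 1, \ldots, L$. Define the weight vector $w \in \mathbb{R}^{N}$ consisting of all weight parameters concatenated together as $w = [w_1; w_2; \ldots; w_L]$, where $w_l = \mathrm{vec}(W_l)$ and vec is the operator vectorizing matrices (column-major), so that $N = L d^2$.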
Given a loss function $\ell(x_i^{(L)}, y_i)$, which measures the misfit between the network output and the true label, we define $\mathcal{L}(w) = \sum_{i=1}^{n} \ell(x_i^{(L)}, y_i)$ as the total loss of the MLP with respect to the weight vector $w$. Note $x_i^{(L)}$ is a function of the weights $w$.
Definition 1 ((Generalized) Gauss-Newton Hessian)
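Let $Q_i = \partial^2 \ell(x_i^{(L)}, y_i) / \partial (x_i^{(L)})^2 \in \mathbb{R}^{d \times d}$ be the Hessian of the loss function $\ell$ with respect to the network output, and $Q \in \mathbb{R}^{nd \times nd}$ be a block diagonal matrix with $Q_i$ being the $i$-th block, $i = 1, 2, \ldots, n$. Let $J_i = \partial x_i^{(L)} / \partial w \in \mathbb{R}^{d \times N}$ be the Jacobian of $x_i^{(L)}$ with respect to the weights $w$, and $J \in \mathbb{R}^{nd \times N}$ be the vertical concatenation of all $J_i$. The (generalized) Gauss-Newton Hessian (GNH) matrix associated with the total loss $\mathcal{L}$ with respect to the weights $w$ is defined as
$$G = J^T Q J = \sum_{i=1}^{n} J_i^T Q_i J_i. \qquad (2)$$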
The GNH matrix is closely related to the Hessian matrix and, importantly, it is always (symmetric) positive semi-definite when the loss function $\ell$ is convex in $x_i^{(L)}$ ($Q_i$ is positive semi-definite), a useful property in many applications. For several standard choices of the loss function, the GNH matrix is mathematically equivalent to the Fisher matrix as used in the natural gradient method. The GNH matrix is also equivalent to the Hessian matrix of a particular approximation of $\mathcal{L}$ constructed by replacing $x_i^{(L)}$ with its first-order approximation with respect to the weights [martens16].
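A minimal NumPy sketch of Definition 1 follows. It assumes $f = \tanh$, a tiny network with random weights, and a softmax cross-entropy loss applied to the output $x_i^{(L)}$, so that $Q_i = \mathrm{diag}(p_i) - p_i p_i^T$ with $p_i$ the softmax of $x_i^{(L)}$; these choices and the helper names `forward`, `jacobian`, and `softmax_ce_hessian` are hypothetical illustrations. Each $J_i$ is assembled column by column with a linearized forward pass, and the resulting $G = \sum_i J_i^T Q_i J_i$ should be symmetric positive semi-definite.

```python
import numpy as np

def forward(Ws, x):
    # Eq. (1) with f = tanh (assumed); keeps z^{(l)} and x^{(l)} of every layer.
    xs, zs = [x], []
    for W in Ws:
        zs.append(W @ xs[-1])
        xs.append(np.tanh(zs[-1]))
    return zs, xs

def jacobian(Ws, x):
    # J_i = d x^{(L)} / d w, one column per weight, w_l = vec(W_l) in column-major order.
    d, L = len(x), len(Ws)
    zs, xs = forward(Ws, x)
    cols = []
    for l in range(L):
        for idx in range(d * d):
            b, a = divmod(idx, d)                 # column-major: idx = b * d + a
            dx = np.zeros(d)
            dx[a] = xs[l][b]                      # S_l x^{(l-1)} with S_l = e_a e_b^T
            dx *= 1.0 - np.tanh(zs[l]) ** 2       # apply D^{(l)}
            for m in range(l + 1, L):             # propagate through the later layers
                dx = (1.0 - np.tanh(zs[m]) ** 2) * (Ws[m] @ dx)
            cols.append(dx)
    return np.stack(cols, axis=1)                 # shape (d, N) with N = L d^2

def softmax_ce_hessian(out):
    # Hessian of softmax cross-entropy w.r.t. its input: diag(p) - p p^T (PSD).
    p = np.exp(out - out.max())
    p /= p.sum()
    return np.diag(p) - np.outer(p, p)

rng = np.random.default_rng(0)
d, L, n = 3, 2, 5
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(L)]
X = rng.standard_normal((n, d))

N = L * d * d
G = np.zeros((N, N))
for x in X:
    J_i = jacobian(Ws, x)
    Q_i = softmax_ce_hessian(forward(Ws, x)[1][-1])
    G += J_i.T @ Q_i @ J_i                        # G = J^T Q J with block-diagonal Q
```

Explicit assembly needs $O(N^2)$ memory, so it is only practical as a reference for very small networks.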
This paper is concerned with fast entry-wise evaluation of the GNH matrix. Such an algorithmic primitive can be used in constructing approximations of the GNH matrix for solving linear systems and eigenvalue problems, which are useful for training and analyzing neural networks [byrd-e11, martens16, bottou-nocedal18, o2019inexact], for selecting training data to minimize the inference variance [cohn94], for estimating learning rates [lecun-bottou98], for network pruning [hassibi-stork93], for robust training [yao2018hessian], for probabilistic inference [hennequin-e14], for designing fast solvers [carmon-duchi18, triburaneni-jordan18, gower-roux-bach17], and so on.
1.1 Previous work
We classify related work into two groups. One group avoids entry-wise evaluation of the GNH matrix and relies on the matrix-free matrix-vector multiplication (matvec) with the Hessian or the GNH [martens2010deep, martens2011learning, martens16]. For example, the matrix-free matvec can be used to construct low-rank approximations of the GNH matrix through the randomized singular value decomposition (RSVD) [halko-martinsson-tropp11], but the numerical rank may not be small [keutzer-e17, dinh-bengio-e17]. Other examples are the following: [dauphin-bengio-e14] introduces a low-rank approximation using the Lanczos algorithm to tackle saddle points; [leroux-e08] maintains a low-rank approximation of the inverse of the Hessian based on rank-one updates at each optimization step; [gower-roux-bach17] uses a quasi-Newton-like construction of the low-rank approximation; [ye-zhang-luo17, mahoney16] study the convergence of stochastic Newton methods combined with a randomized low-rank approximation; [yao2018hessian] uses a matrix-free method with only the layers near the output layer.
The other group of methods is based on evaluating or approximating entries on or close to the diagonal of the GNH matrix [lafond-bottou17]. For example, [zhang-socher-e17] introduces a recursive fast algorithm to construct block-diagonal approximations. As another example, [grosse-martens15, martens16] introduce the K-FAC algorithm, which is based on an entry-wise approximation of the Fisher matrix (mathematically equivalent to the GNH for several standard choices of the loss function). The Fisher matrix is given by $F = \mathbb{E}[g\, g^T]$, where $g = \nabla_w \ell(x_i^{(L)}, y)$ is the gradient evaluated for the training point $x_i$, and $y$ is sampled from the network’s predictive distribution $P_{y|x_i}(w)$. In practice, an extra step of block-diagonal or block-tridiagonal approximation is used for fast inversion. The method has been tested within optimization frameworks on modern supercomputers and has been shown to perform well [osawa-yokota-satoshi-e18]. However, the sampling in the K-FAC algorithm converges slowly, and block-diagonal approximations do not account for off-diagonal information.
In this paper, we introduce a fast algorithm for entry-wise evaluation of the GNH matrix $G$, i.e., computing
$$G_{jk} = e_j^T G\, e_k, \qquad (3)$$
where $e_j$ and $e_k$ are two canonical basis vectors in $\mathbb{R}^{N}$. With the fast evaluation, we propose the hierarchical-matrix ($\mathcal{H}$-matrix) approximation [bebendorf08, hackbusch15] of the GNH matrix for the MLP network, which has applications in autoencoders and long short-term memory networks, and is often used to study the potential of second-order training methods. Notice that if the matrix-free matvec is used to evaluate $G_{jk}$, the computational cost would be $O(nN)$.
Our fast algorithm includes a precomputation step and a sampling step, which reduces the cost to $O(n + d/\epsilon^2)$ work (independent of $N$), where $d$ is the output dimension of the network and $\epsilon$ is a prescribed accuracy. Specifically, suppose the network employs the mean squared loss ($Q$ is the identity), and therefore, the GNH matrix is $G = J^T J$, where $J$ is the Jacobian of the network output with respect to the weights. Then
$$G_{jk} = (J e_j)^T (J e_k),$$
and only two columns in the Jacobian are required to be computed. Our precomputation algorithm exploits the structure of a feed-forward neural network, where the gradient is back-propagated layer by layer, so the intermediate results effectively form a compressed format of the Jacobian with $O(nN)$ memory. As a result, every column can be retrieved in only $O(nd)$ time (note every column has $nd$ entries).
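To accelerate the computation of $(J e_j)^T (J e_k)$, we introduce a fast Monte Carlo sampling algorithm. Let $u_i$ ($v_i$) denote the sub-vector in the Jacobian’s $j$-th ($k$-th) column corresponding to the $i$-th data point, and therefore, $G_{jk} = \sum_{i=1}^{n} u_i^T v_i$. In the sampling, we draw $c$ (independent of $N$) independent samples $i_1, \ldots, i_c$ from $\{1, 2, \ldots, n\}$ with a carefully designed probability distribution $\{p_i\}_{i=1}^{n}$ and compute an estimator
$$\tilde G_{jk} = \frac{1}{c} \sum_{t=1}^{c} \frac{u_{i_t}^T v_{i_t}}{p_{i_t}}.$$
We prove that $|\tilde G_{jk} - G_{jk}| \le \epsilon\, \|u\|\, \|v\|$ with high probability for $c = O(1/\epsilon^2)$, where $u = [u_1; \ldots; u_n]$ and $v = [v_1; \ldots; v_n]$ are the $j$-th and $k$-th columns of $J$. Note it requires only $O(n + cd) = O(n + d/\epsilon^2)$ work to compute $\tilde G_{jk}$ as an approximation, where $d$ is the output dimension of the network.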
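The estimator can be sketched in a few lines, assuming the per-data-point vectors are stored as the rows of two arrays `U` and `V` and assuming the "carefully designed" distribution is $p_i \propto \|u_i\|\,\|v_i\|$; the function name and the random test data are hypothetical.

```python
import numpy as np

def sampled_inner_product(U, V, c, rng):
    """Estimate sum_i U[i] @ V[i] from c i.i.d. indices drawn with p_i ∝ ||u_i|| ||v_i||."""
    w = np.linalg.norm(U, axis=1) * np.linalg.norm(V, axis=1)
    if w.sum() == 0.0:
        return 0.0                              # every u_i^T v_i is zero
    p = w / w.sum()
    idx = rng.choice(len(p), size=c, p=p)       # sampling with replacement
    terms = np.einsum("ij,ij->i", U[idx], V[idx]) / p[idx]
    return terms.mean()

rng = np.random.default_rng(1)
n, d, c = 10_000, 4, 2_000
U = rng.standard_normal((n, d))
V = U + 0.1 * rng.standard_normal((n, d))
exact = np.einsum("ij,ij->", U, V)
approx = sampled_inner_product(U, V, c, rng)
```

Only the $c$ sampled rows enter the inner products; the $O(n)$ part of the cost is forming the probabilities.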
With the fast evaluation algorithm, we are able to take advantage of the existing GOFMM method [chenhan-biros-e17, yu-reiz-biros18, gofmm-home-page] to construct the $\mathcal{H}$-matrix approximation of the GNH matrix through evaluating entries in the matrix. The $\mathcal{H}$-matrix approximation is a multi-level scheme that stores diagonal blocks and employs low-rank approximations for off-diagonal blocks in the input matrix. Thus, previous work on the (global) low-rank approximation and the block-diagonal approximation can be viewed as the two extremes in the spectrum of our $\mathcal{H}$-matrix approximation, which effectively works for a broader range of problems. $\mathcal{H}$-matrices are algebraic generalizations of the well-known fast $N$-body calculation algorithms [barnes-hut-86, greengard94] in computational physics, and they have been applied to kernel methods in machine learning [lee-gray08, march-xiao-yu-biros-sisc16]. An $\mathcal{H}$-matrix can be formulated as
$$\tilde G = D + S + U V^T,$$
where $U$ and $V$ are tall-and-skinny matrices, $S$ is a block-sparse matrix, and $D$ is a block-diagonal matrix with the blocks being either smaller $\mathcal{H}$-matrices at the next level or dense blocks at the last level. Figure 1 shows the structure of a low-rank matrix and the hierarchically low-rank structure of $\mathcal{H}$-matrices.
Given an $\mathcal{H}$-matrix approximation, the memory footprint is $O(Nr)$, where $N$ is the matrix size or the number of weights in a network and $r$ is the maximum off-diagonal rank. (Generally speaking, there may be a $\log N$ or $\log^2 N$ prefactor, as for other complexity results related to $\mathcal{H}$-matrix approximations, but here we focus on the case without such prefactors.) Compared to the $O(N^2)$ storage for the entire matrix, an $\mathcal{H}$-matrix approximation leads to significant memory savings. Once constructed, an $\mathcal{H}$-matrix can be factorized with only $O(Nr^2)$ work, and there exists an entire class of well-established numerical techniques [martinsson2005fast, xia2010fast, ho2013hierarchical, ghysels-li-e16, aminfar2016fast, chen2018distributed, takahashi2019parallelization]. The factorization can be applied to a vector with $O(Nr)$ work and be used as either a fast direct solver or a preconditioner depending on the approximation accuracy.
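To make the storage argument concrete, here is a minimal sketch of a simpler relative of the hierarchical formats discussed in this paper: a recursive 2-by-2 partition with dense leaves and rank-$r$ truncated-SVD factors for the off-diagonal blocks (weak admissibility, no nested bases, so its storage carries a logarithmic prefactor). It is not GOFMM. The test matrix $K_{ij} = e^{-|t_i - t_j|}$ on sorted points is an assumption chosen because its off-diagonal blocks have rank one; the function names are hypothetical.

```python
import numpy as np

def compress(A, r, leaf=32):
    """Recursively keep dense leaves and rank-r factors of off-diagonal blocks."""
    m = A.shape[0]
    if m <= leaf:
        return {"dense": A.copy()}
    h = m // 2

    def lowrank(B):
        U, s, Vt = np.linalg.svd(B, full_matrices=False)
        return U[:, :r] * s[:r], Vt[:r]           # B ≈ (U_r diag(s_r)) Vt_r

    return {"h": h,
            "A11": compress(A[:h, :h], r, leaf), "A22": compress(A[h:, h:], r, leaf),
            "A12": lowrank(A[:h, h:]), "A21": lowrank(A[h:, :h])}

def storage(node):
    """Number of stored floating-point values."""
    if "dense" in node:
        return node["dense"].size
    factors = node["A12"] + node["A21"]           # tuple of four factor arrays
    return storage(node["A11"]) + storage(node["A22"]) + sum(F.size for F in factors)

def matvec(node, x):
    """Apply the compressed matrix to a vector."""
    if "dense" in node:
        return node["dense"] @ x
    h = node["h"]
    (U12, V12), (U21, V21) = node["A12"], node["A21"]
    top = matvec(node["A11"], x[:h]) + U12 @ (V12 @ x[h:])
    bottom = matvec(node["A22"], x[h:]) + U21 @ (V21 @ x[:h])
    return np.concatenate([top, bottom])

N, r = 512, 4
t = np.linspace(0.0, 10.0, N)
K = np.exp(-np.abs(t[:, None] - t[None, :]))
H = compress(K, r)
x = np.random.default_rng(2).standard_normal(N)
rel_err = np.linalg.norm(matvec(H, x) - K @ x) / np.linalg.norm(K @ x)
print(storage(H), N * N, rel_err)
```

For this example the compressed form stores $32768$ values instead of $N^2 = 262144$, one eighth of the dense matrix.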
To summarize, our work makes the following two major contributions:
a fast algorithm that requires $O(nN)$ storage and $O(n + d/\epsilon^2)$ work to evaluate an arbitrary entry in the GNH matrix, where $N$ and $d$ are the number of parameters and the output dimension of the MLP, respectively, $n$ is the data size, and $\epsilon$ is a prescribed accuracy.
a framework to construct the $\mathcal{H}$-matrix approximation of the GNH matrix, analysis and demonstration of the corresponding accuracy and the cost, as well as comparison with the RSVD and the K-FAC methods.
In this section, we review the importance of the GNH matrix and the associated computational challenge.
2.1 Neural network training
The GNH matrix is useful in training and analyzing neural networks, selecting training data, estimating learning rates, and so on. Here we focus on its use in second-order optimization to show the challenge that is common in other applications. In the MLP, the weight vector $w$ is obtained by solving the following optimization problem (regularization on $w$ could be added):
$$\min_{w \in \mathbb{R}^{N}} \mathcal{L}(w). \qquad (4)$$
Recall that $\mathcal{L}(w) = \sum_{i=1}^{n} \ell(x_i^{(L)}, y_i)$, where $\ell$ is the loss function and $x_i^{(L)}$ is the network output corresponding to input $x_i$, which has label $y_i$.
To solve for $w$ in problem (4), a second-order optimization method solves a sequence of local quadratic approximations of $\mathcal{L}$, which requires solving the following linear systems repeatedly:
$$B\, \Delta w = -g, \qquad (5)$$
where $B$ is the curvature matrix (the Hessian of $\mathcal{L}$ in the standard Newton’s method), $g = \nabla_w \mathcal{L}$ is the gradient, and $\Delta w$ is the update direction. Generally speaking, second-order optimization methods are highly concurrent and could require far fewer iterations to converge than first-order methods, which implies potentially significant speedup on modern distributed computing platforms.
In the Gauss-Newton method, a popular second-order method, the GNH matrix is employed (with a small regularization) as the curvature matrix in Eq. (5), and the resulting system can be solved using the conjugate gradient method. Since the GNH is mathematically equivalent to the Fisher matrix for several standard choices of the loss, the solution of Eq. (5) then becomes the natural gradient, an efficient steepest descent direction in the space of probability distributions with an appropriately defined distance measure [martens2014new].
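A Gauss-Newton step can be sketched with a textbook conjugate gradient loop that touches the curvature matrix only through a matvec closure. In this sketch the Jacobian `J` is a random stand-in (not an MLP Jacobian), $Q = I$ (mean squared loss), and the regularization `lam` is an assumed value; in practice the closure would call matrix-free primitives such as those in Table 1.

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-10, maxiter=1000):
    """Solve A x = b for symmetric positive definite A given only x -> A x."""
    x = np.zeros_like(b)
    r = b.copy()                       # residual b - A x for x = 0
    p = r.copy()
    rs = r @ r
    bnorm = np.linalg.norm(b)
    if bnorm == 0.0:
        return x, 0
    for k in range(1, maxiter + 1):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * bnorm:
            return x, k
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, maxiter

rng = np.random.default_rng(4)
nd, N, lam = 40, 20, 1e-3
J = rng.standard_normal((nd, N))       # stand-in Jacobian
residual = rng.standard_normal(nd)     # network output minus labels
g = J.T @ residual                     # gradient of 0.5 * ||residual||^2
step, iters = conjugate_gradient(lambda s: J.T @ (J @ s) + lam * s, -g)
```

The iteration count returned in `iters` grows with the condition number of $J^T J + \lambda I$, which is why preconditioners matter for ill-conditioned problems.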
2.2 Back-propagation & matrix-free matvec
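Table 1: Back-propagation for evaluating the gradient $g$ and the matrix-free matvec with the GNH matrix for a direction $s = [s_1; \ldots; s_L]$ with $s_l = \mathrm{vec}(S_l)$. Every step is applied to all data points $i = 1, \ldots, n$; $D_i^{(l)} = \mathrm{diag}(f'(z_i^{(l)}))$, and $\nabla \ell_i$ is the gradient of $\ell(x_i^{(L)}, y_i)$ with respect to $x_i^{(L)}$.
| Evaluate gradient $g$ | Matvec with GNH: $G s$ |
|---|---|
| (a) Forward pass: $z_i^{(l)} = W_l x_i^{(l-1)}$, $x_i^{(l)} = f(z_i^{(l)})$, $l = 1, \ldots, L$ | (a) Linearized forward pass: $\dot x_i^{(0)} = 0$, $\dot x_i^{(l)} = D_i^{(l)} (S_l x_i^{(l-1)} + W_l \dot x_i^{(l-1)})$, $l = 1, \ldots, L$, so that $J_i s = \dot x_i^{(L)}$ |
| (b) Backward pass: $\delta_i^{(L)} = D_i^{(L)} \nabla \ell_i$; for $l = L, \ldots, 1$: $g_l = \mathrm{vec}\big(\sum_i \delta_i^{(l)} (x_i^{(l-1)})^T\big)$ and, for $l > 1$, $\delta_i^{(l-1)} = D_i^{(l-1)} W_l^T \delta_i^{(l)}$ | (b) Output-space product: $r_i = Q_i \dot x_i^{(L)}$ |
| | (c) Backward pass of the left column with $\nabla \ell_i$ replaced by $r_i$, which returns $G s = J^T Q J s$ |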
Table 1 shows the back-propagation algorithm for evaluating the gradient in Eq. (5) and the matrix-free matvec with the GNH matrix, both of which have complexity $O(nN)$. Note a direct matvec with the full GNH matrix would require $O(N^2)$ work.
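The two primitives of Table 1 can be sketched in NumPy for the special case of the mean squared loss $\ell = \tfrac12 \|x_i^{(L)} - y_i\|^2$ (so $Q_i = I$) and $f = \tanh$; both assumptions, as well as the function names, are ours. Data points are stored as the columns of `X`, so every step processes all $n$ points at once and each primitive costs $O(nN)$.

```python
import numpy as np

def dtanh(z):
    return 1.0 - np.tanh(z) ** 2

def forward(Ws, X):
    # Step (a) of the gradient: columns of X are data points.
    xs, zs = [X], []
    for W in Ws:
        zs.append(W @ xs[-1])
        xs.append(np.tanh(zs[-1]))
    return zs, xs

def backward(Ws, zs, xs, R):
    # Back-propagate output-space vectors R (one column per data point);
    # returns the per-layer matrices sum_i delta_i^{(l)} x_i^{(l-1)T}.
    out = [None] * len(Ws)
    delta = dtanh(zs[-1]) * R
    for l in reversed(range(len(Ws))):
        out[l] = delta @ xs[l].T
        if l > 0:
            delta = dtanh(zs[l - 1]) * (Ws[l].T @ delta)
    return out

def gradient(Ws, X, Y):
    zs, xs = forward(Ws, X)
    return backward(Ws, zs, xs, xs[-1] - Y)     # MSE: d loss / d output = residual

def gnh_matvec(Ws, X, Ss):
    # (a) linearized forward pass gives J s, (b) Q = I, (c) backward pass gives J^T J s.
    zs, xs = forward(Ws, X)
    dx = np.zeros_like(X)
    for l, (W, S) in enumerate(zip(Ws, Ss)):
        dx = dtanh(zs[l]) * (S @ xs[l] + W @ dx)
    return backward(Ws, zs, xs, dx)

def loss(Ws, X, Y):
    return 0.5 * np.sum((forward(Ws, X)[1][-1] - Y) ** 2)

rng = np.random.default_rng(5)
d, L, n = 4, 3, 6
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(L)]
X, Y = rng.standard_normal((d, n)), rng.standard_normal((d, n))
grads = gradient(Ws, X, Y)
Ss = [rng.standard_normal((d, d)) for _ in range(L)]
GS = gnh_matvec(Ws, X, Ss)
```

With `Ss` set to a single one in position $(a, b)$ of layer $l$, the matvec returns the column $G e_k$, which is the naive entry evaluation discussed below.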
Based on the two basic ingredients, iterative solvers such as Krylov methods can be used to solve Eq. (5) as in Hessian-free methods [martens2010deep, martens2011learning]. However, the iteration count for convergence can grow rapidly in the presence of ill-conditioning, in which case fast solvers or preconditioners for Eq. (5) are necessary [axelsson-94, knoll-keyes-04, martens16].
3 Fast computation of entries in GNH
This section presents a precomputation algorithm and a fast Monte Carlo algorithm for fast computation of arbitrary entries in the GNH matrix of an MLP network.
A naive method
Consider a GNH matrix $G = J^T Q J$, where an entry can be written as
$$G_{jk} = e_j^T J^T Q J\, e_k, \qquad (6)$$
where $e_j$ and $e_k$ are the $j$-th and the $k$-th columns in the $N$-dimensional identity matrix. We can take advantage of the matrix-free matvec with the GNH matrix in Table 1 to compute $G e_k$, which costs the same as one pass of forward propagation plus one pass of backward propagation, i.e., $O(nN)$ work.
In the following, we introduce a precomputation algorithm that reduces the cost of evaluating an entry in the GNH to $O(nd)$ work with $O(nN)$ memory, and a fast Monte Carlo algorithm that further reduces the cost to $O(n + d/\epsilon^2)$ work.
3.1 Precomputation algorithm
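The motivation of our precomputation algorithm is to exploit the sparsity of $e_j$ and $e_k$ plus the symmetry of $Q$ in Eq. (6). Recall the definition of $Q$ in Eq. (2), and let $Q_i = R_i R_i^T$ be a symmetric factorization, which can be computed via, e.g., the eigendecomposition or the $LDL^T$ factorization with pivoting. We have
$$G_{jk} = \sum_{i=1}^{n} u_i^T v_i, \qquad (7)$$
where $u_i$ and $v_i$ are two $d$-dimensional vectors:
$$u_i = R_i^T J_i\, e_j, \qquad v_i = R_i^T J_i\, e_k,$$
for $i = 1, 2, \ldots, n$ and $j, k = 1, 2, \ldots, N$.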
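Eq. (7) can be checked numerically with a short sketch. Here the factor $R_i$ comes from an eigendecomposition with negative round-off eigenvalues clipped to zero, $Q_i$ is taken to be the softmax cross-entropy Hessian $\mathrm{diag}(p) - pp^T$, and the Jacobians $J_i$ are random stand-ins; all of these are assumptions for illustration.

```python
import numpy as np

def symmetric_factor(Q):
    """Return R with Q ≈ R R^T for a symmetric positive semi-definite Q."""
    lam, V = np.linalg.eigh(Q)
    lam = np.clip(lam, 0.0, None)          # drop tiny negative round-off
    return V * np.sqrt(lam)                # scales column m of V by sqrt(lam_m)

rng = np.random.default_rng(6)
n, d, N = 5, 3, 8
Js, Rs = [], []
G = np.zeros((N, N))
for _ in range(n):
    p = rng.random(d)
    p /= p.sum()
    Q = np.diag(p) - np.outer(p, p)        # softmax cross-entropy Hessian (PSD, singular)
    J_i = rng.standard_normal((d, N))      # stand-in for the Jacobian J_i
    Js.append(J_i)
    Rs.append(symmetric_factor(Q))
    G += J_i.T @ Q @ J_i

def entry(j, k):
    # Eq. (7): G_jk = sum_i u_i^T v_i with u_i = R_i^T J_i e_j and v_i = R_i^T J_i e_k.
    return sum((R.T @ J_i[:, j]) @ (R.T @ J_i[:, k]) for R, J_i in zip(Rs, Js))
```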
For an MLP network that has $L$ fully connected layers with constant layer size $d$ ($d$-by-$d$ weight matrices), every entry in the GNH matrix can be computed in $O(nd)$ time with a precomputation that requires $O(nN)$ storage and $O(nNd)$ work.
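We precompute and store
$$P_i^{(l)} = R_i^T D_i^{(L)} W_L D_i^{(L-1)} W_{L-1} \cdots W_{l+1} D_i^{(l)} \in \mathbb{R}^{d \times d}, \qquad l = 1, 2, \ldots, L, \quad i = 1, 2, \ldots, n,$$
where $D_i^{(l)} = \mathrm{diag}\big(f'(z_i^{(l)})\big)$ and $f'$ is the derivative of the activation function. Notice that computing the symmetric factorizations for $\{Q_i\}_{i=1}^{n}$ costs $O(nd^3)$, which is negligible compared to other parts of the computation. Moreover, a forward pass of the network (step (a) of gradient evaluation in Table 1) computes $\{x_i^{(l)}\}$ and $\{z_i^{(l)}\}$ with $O(nN)$ work.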
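Since every $P_i^{(l)}$ is a $d \times d$ matrix, the total storage cost is $O(nLd^2) = O(nN)$, where $N = Ld^2$ is the total number of weights. In addition, notice that $P_i^{(l-1)} = P_i^{(l)} W_l D_i^{(l-1)}$, so they can be computed from $l = L$ to $1$ iteratively, which requires $O(nLd^3) = O(nNd)$ work in total.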
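Let $e_j = [s_1; s_2; \ldots; s_L]$ with $s_m = \mathrm{vec}(S_m)$ and $S_m \in \mathbb{R}^{d \times d}$. Since $e_j$ has only one nonzero entry, $S_m = 0$ for $m \neq l$ because the $s_m$ are all zeros except for $s_l$. The matrix $S_l$ has only one nonzero at position $(a, b)$, where $j = (l-1)d^2 + (b-1)d + a$ (column-major ordering), as follows:
$$S_l = \hat e_a \hat e_b^T,$$
where $\hat e_a$ and $\hat e_b$ denote the $a$-th and $b$-th canonical basis vectors in $\mathbb{R}^d$.
Denote by $\dot x_i^{(m)}$ the layer-$m$ output of the linearized forward pass with direction $e_j$. Following step (a) of the matvec in Table 1, we have $\dot x_i^{(m)} = 0$ for $m < l$ and $\dot x_i^{(l)} = D_i^{(l)} S_l x_i^{(l-1)}$ at layer $l$. Denote $h_i = S_l x_i^{(l-1)}$; propagating through layers $l+1, \ldots, L$ and applying $R_i^T$, we have
$$u_i = R_i^T J_i e_j = R_i^T \dot x_i^{(L)} = P_i^{(l)} h_i.$$
Notice that the only nonzero entry in $h_i$ is the $a$-th element, which equals the $b$-th element in $x_i^{(l-1)}$. Therefore,
$$u_i = \big(x_i^{(l-1)}\big)_b\, P_i^{(l)}(:, a),$$
where $(x_i^{(l-1)})_b\, P_i^{(l)}(:, a)$ should be interpreted as a scaling of the $a$-th column in $P_i^{(l)}$ by the $b$-th element in $x_i^{(l-1)}$, which costs $O(d)$ work. The vector $v_i$ is retrieved in the same way, so each inner product $u_i^T v_i$ costs $O(d)$, and Eq. (7) gives $G_{jk}$ in $O(nd)$ work.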
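The precomputation and the $O(d)$ retrieval of $u_i$ can be sketched as below, assuming $f = \tanh$, zero-based layer indices in code (`Ps[l]` holds $P_i^{(l+1)}$), random weights, and random stand-in factors $R_i$; the helper names are hypothetical. A brute-force Jacobian built with the linearized forward pass is included only to check the result on a tiny network.

```python
import numpy as np

def dtanh(z):
    return 1.0 - np.tanh(z) ** 2

def precompute(Ws, x, R):
    """Return [P^{(1)}, ..., P^{(L)}] and the layer inputs for one data point."""
    zs, xs = [], [x]
    for W in Ws:
        zs.append(W @ xs[-1])
        xs.append(np.tanh(zs[-1]))
    L = len(Ws)
    Ps = [None] * L
    Ps[L - 1] = R.T * dtanh(zs[L - 1])                   # R^T D^{(L)} (column scaling)
    for l in range(L - 1, 0, -1):
        Ps[l - 1] = (Ps[l] @ Ws[l]) * dtanh(zs[l - 1])   # P^{(l)} = P^{(l+1)} W_{l+1} D^{(l)}
    return Ps, xs

def u_vector(Ps, xs, j, d):
    """u_i = R_i^T J_i e_j in O(d): scale column a of P^{(l)} by entry b of x^{(l-1)}."""
    l, rem = divmod(j, d * d)          # zero-based layer index
    b, a = divmod(rem, d)              # column-major: rem = b * d + a
    return xs[l][b] * Ps[l][:, a]

def jacobian(Ws, x):
    """Brute-force J_i via the linearized forward pass (for checking only)."""
    d, L = len(x), len(Ws)
    zs, xs = [], [x]
    for W in Ws:
        zs.append(W @ xs[-1])
        xs.append(np.tanh(zs[-1]))
    cols = []
    for l in range(L):
        for rem in range(d * d):
            b, a = divmod(rem, d)
            dx = np.zeros(d)
            dx[a] = xs[l][b] * dtanh(zs[l][a])
            for m in range(l + 1, L):
                dx = dtanh(zs[m]) * (Ws[m] @ dx)
            cols.append(dx)
    return np.stack(cols, axis=1)

rng = np.random.default_rng(7)
d, L, n = 3, 3, 4
Ws = [rng.standard_normal((d, d)) for _ in range(L)]
X = rng.standard_normal((n, d))
Rs = [rng.standard_normal((d, d)) for _ in range(n)]     # stand-ins, Q_i = R_i R_i^T
pre = [precompute(Ws, x, R) for x, R in zip(X, Rs)]

def gnh_entry(j, k):
    # Eq. (7): O(nd) work after the precomputation.
    return sum(u_vector(Ps, xs, j, d) @ u_vector(Ps, xs, k, d) for Ps, xs in pre)

G = sum(jacobian(Ws, x).T @ R @ R.T @ jacobian(Ws, x) for x, R in zip(X, Rs))
```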
3.2 Fast Monte Carlo algorithm
Recall that Eq. (7) sums over a large number of data points. The idea is to sample a subset with a judiciously chosen probability distribution and scale the (partial) sum appropriately to approximate $G_{jk}$. It is important to note that the probabilities can be computed quickly using the previous precomputation. The fast sampling algorithm is given in Algorithm 1.
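Define $u = [u_1; u_2; \ldots; u_n]$ and $v = [v_1; v_2; \ldots; v_n]$ as two vectors in $\mathbb{R}^{nd}$; then Eq. (7) can be written as the inner product of the two vectors:
$$G_{jk} = u^T v.$$
Algorithm 1 samples index $i$ with probability $p_i = \|u_i\|\,\|v_i\| \big/ \sum_{m=1}^{n} \|u_m\|\,\|v_m\|$.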
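A sketch of how the sampling probabilities can be formed in $O(n)$ work: since $u_i = (x_i^{(l-1)})_b\, P_i^{(l)}(:, a)$, we have $\|u_i\| = |(x_i^{(l-1)})_b|\, \|P_i^{(l)}(:, a)\|$, so storing the column norms of every $P_i^{(l)}$ (an extra $O(nLd)$ memory, which is our assumption) avoids forming $u_i$ for unsampled points. This is not a transcription of Algorithm 1; the array layouts, the function names, and the random stand-in data are assumptions.

```python
import numpy as np

def locate(j, d):
    """Map a weight index j (column-major within each layer) to (layer, a, b), zero-based."""
    l, rem = divmod(j, d * d)
    b, a = divmod(rem, d)
    return l, a, b

def sample_entry(P, Xin, colnorm, j, k, c, rng):
    """Estimate G_jk. P[i, l] is P_i^{(l+1)}, Xin[i, l] is the input of layer l + 1,
    and colnorm[i, l, a] is the norm of column a of P[i, l]."""
    n, _, d, _ = P.shape
    lj, aj, bj = locate(j, d)
    lk, ak, bk = locate(k, d)
    # O(n): ||u_i|| ||v_i|| without forming u_i or v_i.
    w = (np.abs(Xin[:, lj, bj]) * colnorm[:, lj, aj]
         * np.abs(Xin[:, lk, bk]) * colnorm[:, lk, ak])
    if w.sum() == 0.0:
        return 0.0
    p = w / w.sum()
    idx = rng.choice(n, size=c, p=p)
    # O(cd): form only the sampled u_i and v_i.
    u = Xin[idx, lj, bj][:, None] * P[idx, lj, :, aj]
    v = Xin[idx, lk, bk][:, None] * P[idx, lk, :, ak]
    return np.mean(np.einsum("td,td->t", u, v) / p[idx])

rng = np.random.default_rng(8)
n, L, d = 5_000, 2, 3
P = rng.standard_normal((n, L, d, d))       # stand-ins for precomputed P_i^{(l)}
Xin = rng.standard_normal((n, L, d))        # stand-ins for layer inputs x_i^{(l-1)}
colnorm = np.linalg.norm(P, axis=2)         # computed once, shape (n, L, d)

j, k = 4, 13
est = sample_entry(P, Xin, colnorm, j, k, c=4_000, rng=rng)
```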
The following theorem shows that our sampling algorithm returns a good estimator $\tilde G_{jk}$ of $G_{jk}$, where the error is measured using $\|u\|\,\|v\|$, an upper bound on $|G_{jk}|$ (by the Cauchy-Schwarz inequality).
Theorem 2 (Sampling error)
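Consider an MLP network that has $L$ fully connected layers with constant layer size $d$ ($d$-by-$d$ weight matrices). For every entry $G_{jk}$ in the GNH matrix, Algorithm 1 returns an estimator $\tilde G_{jk}$ that
is an unbiased estimator of $G_{jk}$, i.e., $\mathbb{E}[\tilde G_{jk}] = G_{jk}$.
its variance or mean squared error (MSE) satisfies
$$\mathbb{E}\big[(\tilde G_{jk} - G_{jk})^2\big] \le \frac{1}{c}\, \|u\|^2 \|v\|^2,$$
where $c$ is the number of random samples.
with probability at least $1 - \delta$, where $\delta \in (0, 1)$, its absolute error satisfies
$$|\tilde G_{jk} - G_{jk}| \le \frac{\eta}{\sqrt{c}}\, \|u\|\, \|v\|, \qquad (14)$$
where $\eta = 1 + \sqrt{2 \log(1/\delta)}$ and $c$ is the number of random samples.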
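Our proof consists of the following three parts.
First, since each sample $i_t$ is drawn independently with probability $p_{i_t}$,
$$\mathbb{E}[\tilde G_{jk}] = \frac{1}{c} \sum_{t=1}^{c} \sum_{i:\, p_i > 0} p_i\, \frac{u_i^T v_i}{p_i} = \sum_{i=1}^{n} u_i^T v_i = G_{jk},$$
where the terms with $p_i = 0$ have $u_i^T v_i = 0$.
Second, the variance or MSE of the estimator is the following:
$$\mathbb{E}\big[(\tilde G_{jk} - G_{jk})^2\big] = \frac{1}{c} \Big( \sum_{i:\, p_i > 0} \frac{(u_i^T v_i)^2}{p_i} - G_{jk}^2 \Big) \le \frac{1}{c} \sum_{i:\, p_i > 0} \frac{(u_i^T v_i)^2}{p_i} \;\; \text{(drop the last term)} \le \frac{1}{c} \sum_{i:\, p_i > 0} \frac{\|u_i\|^2 \|v_i\|^2}{p_i} = \frac{1}{c} \Big( \sum_{i=1}^{n} \|u_i\|\, \|v_i\| \Big)^2 \le \frac{1}{c}\, \|u\|^2 \|v\|^2.$$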
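Notice that with Jensen’s inequality, we also obtain a bound on the absolute error in expectation:
$$\mathbb{E}\big[|\tilde G_{jk} - G_{jk}|\big] \le \sqrt{\mathbb{E}\big[(\tilde G_{jk} - G_{jk})^2\big]} \le \frac{1}{\sqrt{c}}\, \|u\|\, \|v\|.$$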
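We use McDiarmid’s (a.k.a. Hoeffding-Azuma or bounded differences) inequality to obtain Eq. (14). Define the function $F(i_1, \ldots, i_c) = |\tilde G_{jk} - G_{jk}|$, where $i_1, \ldots, i_c$ are the random samples, and we show that changing one sample at a time does not affect $F$ too much. Consider changing a sample $i_t$ to $i_t'$ while keeping the others the same. The new estimator $\tilde G_{jk}'$ differs from $\tilde G_{jk}$ by only one term. Thus,
$$|\tilde G_{jk} - \tilde G_{jk}'| = \frac{1}{c} \Big| \frac{u_{i_t}^T v_{i_t}}{p_{i_t}} - \frac{u_{i_t'}^T v_{i_t'}}{p_{i_t'}} \Big| \le \frac{2}{c} \max_{i:\, p_i > 0} \frac{\|u_i\|\, \|v_i\|}{p_i} = \frac{2}{c} \sum_{m=1}^{n} \|u_m\|\, \|v_m\| \le \frac{2}{c}\, \|u\|\, \|v\|,$$
where we have used the Cauchy-Schwarz inequality twice. Then, define $\beta = \frac{2}{c} \|u\|\, \|v\|$; using the triangle inequality we see
$$|F(i_1, \ldots, i_t, \ldots, i_c) - F(i_1, \ldots, i_t', \ldots, i_c)| \le |\tilde G_{jk} - \tilde G_{jk}'| \le \beta.$$
Finally, let $\gamma = \|u\|\, \|v\| \sqrt{2 \log(1/\delta) / c}$, and we use McDiarmid’s inequality to obtain Eq. (14) as follows:
$$\Pr\big(F - \mathbb{E}[F] \ge \gamma\big) \le \exp\Big( -\frac{2\gamma^2}{c\, \beta^2} \Big) = \delta,$$
so, combined with $\mathbb{E}[F] \le \|u\|\, \|v\| / \sqrt{c}$, we have $|\tilde G_{jk} - G_{jk}| \le \frac{1 + \sqrt{2 \log(1/\delta)}}{\sqrt{c}}\, \|u\|\, \|v\| = \frac{\eta}{\sqrt{c}}\, \|u\|\, \|v\|$ with probability at least $1 - \delta$.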
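The exact variance from the second part of the proof can be evaluated directly and compared with the intermediate bound $(\sum_i \|u_i\|\,\|v_i\|)^2/c$ and the final bound $\|u\|^2\|v\|^2/c$. The sketch assumes $p_i \propto \|u_i\|\,\|v_i\|$, uses random stand-in vectors, and adds a brute-force simulation as a sanity check; the function name is hypothetical.

```python
import numpy as np

def estimator_variance(U, V, c):
    """Exact variance of the c-sample estimator with p_i ∝ ||u_i|| ||v_i||."""
    w = np.linalg.norm(U, axis=1) * np.linalg.norm(V, axis=1)
    keep = w > 0                                   # indices that can be sampled
    p = w[keep] / w.sum()
    t = np.einsum("ij,ij->i", U[keep], V[keep])    # u_i^T v_i
    G = t.sum()                                    # terms with w_i = 0 are zero
    return ((t ** 2 / p).sum() - G ** 2) / c

rng = np.random.default_rng(9)
n, d, c = 200, 3, 10
U, V = rng.standard_normal((n, d)), rng.standard_normal((n, d))
var = estimator_variance(U, V, c)
middle = np.sum(np.linalg.norm(U, axis=1) * np.linalg.norm(V, axis=1)) ** 2 / c
final = np.linalg.norm(U) ** 2 * np.linalg.norm(V) ** 2 / c

# Monte Carlo check: 20,000 independent c-sample estimates.
w = np.linalg.norm(U, axis=1) * np.linalg.norm(V, axis=1)
p = w / w.sum()
idx = rng.choice(n, size=(20_000, c), p=p)
draws = (np.einsum("ij,ij->i", U, V)[idx] / p[idx]).mean(axis=1)
```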
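The error in the approximation of $G_{jk}$ depends only on the number of random samples $c$ (but not on $N$) and can be made arbitrarily small as needed. In particular, if $c \ge 1/\epsilon^2$, we have
$$\mathbb{E}\big[|\tilde G_{jk} - G_{jk}|\big] \le \epsilon\, \|u\|\, \|v\|,$$
and if $c \ge \eta^2/\epsilon^2$, then $|\tilde G_{jk} - G_{jk}| \le \epsilon\, \|u\|\, \|v\|$ with probability at least $1 - \delta$, where $\eta = 1 + \sqrt{2 \log(1/\delta)}$.
Furthermore, the error of the entire matrix in the Frobenius norm satisfies
$$\mathbb{E}\big[\|\tilde G - G\|_F^2\big] = \sum_{j,k=1}^{N} \mathbb{E}\big[(\tilde G_{jk} - G_{jk})^2\big] \le \frac{1}{c} \sum_{j,k=1}^{N} \|R^T J e_j\|^2\, \|R^T J e_k\|^2 = \frac{1}{c}\, \|R^T J\|_F^4,$$
where $R$ is the block-diagonal matrix with blocks $R_1, \ldots, R_n$, so that $u = R^T J e_j$ and $v = R^T J e_k$.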
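The two sample-size conditions above translate into a small helper; the function name is hypothetical, and the natural logarithm is assumed in $\eta$.

```python
import math

def samples_needed(eps, delta=None):
    """Samples c so that the error is at most eps * ||u|| * ||v||.

    delta=None: the bound holds in expectation (c >= 1 / eps^2).
    otherwise:  the bound holds with probability >= 1 - delta (c >= eta^2 / eps^2).
    """
    if delta is None:
        return math.ceil(1.0 / eps ** 2)
    eta = 1.0 + math.sqrt(2.0 * math.log(1.0 / delta))
    return math.ceil(eta ** 2 / eps ** 2)

print(samples_needed(0.1), samples_needed(0.1, delta=1e-3))
```

Neither count depends on $n$ or $N$, which is what makes the sampling cost $O(n + d/\epsilon^2)$.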
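The estimator is exact using at most one sample when $j = k$. The (trivial) case $G_{jj} = 0$ is implied by the situation that $u_i = 0$ for all $i$; otherwise, we have $\|u\| > 0$, and the sampling probability becomes
$$p_i = \frac{\|u_i\|^2}{\|u\|^2}.$$
Therefore, $\tilde G_{jj} = u_{i_t}^T u_{i_t} / p_{i_t} = \|u\|^2 = G_{jj}$ with any random sample $i_t$.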
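A quick numerical illustration of the diagonal case follows, with random stand-in vectors (some set to zero) playing the role of $u_i$: every index that can be sampled returns exactly $\|u\|^2$.

```python
import numpy as np

rng = np.random.default_rng(10)
U = rng.standard_normal((300, 4))
U[::7] = 0.0                                  # data points with u_i = 0 are never sampled
p = np.sum(U ** 2, axis=1) / np.sum(U ** 2)   # p_i = ||u_i||^2 / ||u||^2
exact = np.sum(U ** 2)                        # G_jj = ||u||^2
single_sample = [U[i] @ U[i] / p[i] for i in np.flatnonzero(p)]
```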
Theorem 3 (Computational cost of sampling)
4 $\mathcal{H}$-matrix approximation
This section introduces the $\mathcal{H}$-matrix approximation of the GNH matrix for the MLP. While the low-rank and the block-diagonal approximations focus on the global and the local structure of the problem, respectively, the $\mathcal{H}$-matrix approximation handles both, as they may be equally important.
4.1 Overall algorithm
Here we take advantage of the existing GOFMM method [chenhan-biros-e17, yu-reiz-biros18, gofmm-home-page], which evaluates entries in a symmetric positive definite (SPD) matrix $K$ to construct the $\mathcal{H}$-matrix approximation $\tilde K$ such that
$$\|K - \tilde K\| \le \tau\, \|K\|,$$
where $\tau$ is a prescribed tolerance.
Since GOFMM requires only entry-wise evaluation of the input matrix, we apply it with our fast evaluation algorithm to the regularized GNH matrix (the GNH matrix is only symmetric positive semi-definite, so we always add a small regularization $\mu I$, whose magnitude $\mu$ is determined by the unit roundoff). The overall algorithm that computes the $\mathcal{H}$-matrix approximation (and approximate factorization) of the GNH matrix using the GOFMM method is shown in Algorithm 2.
The error analysis of Algorithm 2 is the following. Let $\tilde G$ be computed by Algorithm 1, $\mu I$ be the regularization, and $\tilde G_{\mathcal{H}}$ be the approximation of $\tilde G + \mu I$ computed by GOFMM. Then the error between the output of Algorithm 2 and the (regularized) GNH matrix satisfies
$$\|\tilde G_{\mathcal{H}} - (G + \mu I)\| \le \|\tilde G - G\| + \|\tilde G_{\mathcal{H}} - (\tilde G + \mu I)\|,$$
where the first term is the sampling error from Algorithm 1 and the second term is the GOFMM approximation error. For simplicity, we drop the regularization parameter $\mu$ for the rest of this paper.
4.2 GOFMM overview
Given an SPD matrix $K \in \mathbb{R}^{N \times N}$, the GOFMM takes two steps to construct the $\mathcal{H}$-matrix approximation as follows. First, a permutation matrix $P$ is computed to reorder the original matrix, which often corresponds to a hierarchical domain decomposition for applications in two- or three-dimensional physical spaces. The recursive domain partitioning is often associated with a tree data structure $\mathcal{T}$. Unlike methods targeting applications in physical spaces, the GOFMM does not require the use of geometric information (thus its name “geometry-oblivious fast multipole method”), which does not exist for neural networks. Instead of relying on geometric information, the GOFMM exploits the algebraic distance measure that is implicitly defined by the input matrix $K$. As a matter of fact, any SPD matrix is the Gram matrix of unknown Gram vectors $\phi_1, \ldots, \phi_N$, i.e., $K_{ij} = \phi_i^T \phi_j$ [hofmann2008kernel]. Therefore, the distance between two row/column indices $i$ and $j$ can be defined through either the Euclidean distance or the angle between the Gram vectors:
$$d_{ij} = \|\phi_i - \phi_j\|^2 = K_{ii} + K_{jj} - 2K_{ij} \qquad \text{or} \qquad d_{ij} = 1 - \frac{K_{ij}^2}{K_{ii} K_{jj}}.$$
With either definition, the GOFMM is able to construct the permutation $P$ and a balanced binary tree $\mathcal{T}$.
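Both distances can be evaluated from three entries of $K$ without access to the Gram vectors. The sketch below builds $K = \Phi^T \Phi$ from random stand-in vectors only so that the result can be compared with the explicit Gram vectors; the function name is hypothetical, and this is not GOFMM's code.

```python
import numpy as np

def gram_distances(K, i, j):
    """Distances between indices i and j computed from entries of an SPD matrix K."""
    l2 = K[i, i] + K[j, j] - 2.0 * K[i, j]              # ||phi_i - phi_j||^2
    angle = 1.0 - K[i, j] ** 2 / (K[i, i] * K[j, j])    # sin^2 of the angle between phi_i, phi_j
    return l2, angle

rng = np.random.default_rng(11)
Phi = rng.standard_normal((30, 10))     # columns are the (normally unknown) Gram vectors
K = Phi.T @ Phi                         # SPD since Phi has full column rank
l2, angle = gram_distances(K, 2, 7)
```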
The second step is to approximate the reordered matrix $P^T K P$ by
$$P^T K P \approx \begin{bmatrix} \tilde K_{\ell\ell} & S_{\ell r} + U_\ell V_r^T \\ S_{r\ell} + U_r V_\ell^T & \tilde K_{rr} \end{bmatrix},$$
where $\tilde K_{\ell\ell}$ and $\tilde K_{rr}$ are two diagonal blocks that have the same structure as $P^T K P$ unless their sizes are small enough to be treated as dense blocks, which occurs at the leaf level of the tree $\mathcal{T}$; $S_{\ell r}$ and $S_{r\ell}$ are block-sparse matrices, and $U_\ell V_r^T$ and $U_r V_\ell^T$ are low-rank approximations of the remaining off-diagonal blocks in $P^T K P$. These bases are computed recursively with a post-order traversal of $\mathcal{T}$ using the interpolative decomposition [halko-martinsson-tropp11] and a nearest neighbor-based fast sampling scheme. There is a trade-off here: while the so-called weak-admissibility criterion sets $S_{\ell r}$ and $S_{r\ell}$ to zero and obtains relatively large ranks, the so-called strong-admissibility criterion selects $S_{\ell r}$ and $S_{r\ell}$ to be certain subblocks in $P^T K P$ corresponding to a few nearest neighbors/indices of every leaf node in $\mathcal{T}$ and achieves smaller (usually constant) ranks.
Here we focus on the hierarchical semi-separable (HSS) format among other types of hierarchical matrices. Technically speaking, the HSS format means $S_{\ell r}$ and $S_{r\ell}$ are both zero and the bases $U_\ell$/$V_\ell$ and $U_r$/$V_r$ of a node in $\mathcal{T}$ are recursively defined through the bases of the node’s children, i.e., the so-called nested bases.
4.3 Summary & comparison of complexity
We summarize the storage and computational complexity of our $\mathcal{H}$-matrix approximation method (HM), and compare HM with three reference methods, namely, the Hessian-free method (HF) [martens2010deep, martens2011learning], the Kronecker-factorization (K-FAC) [martens16, grosse-martens15] and the randomized singular value decomposition (RSVD, Algorithm 5.1 in [halko-martinsson-tropp11]). As before, we assume the MLP network has $L$ layers of constant layer size $d$, so the number of weights is $N = Ld^2$. We also assume the number of data points is $n$. The four algorithms of interest are as follows.
HM
The algorithm is given in Algorithm 2. Suppose the rank is $r$ in the $\mathcal{H}$-matrix approximation; the GOFMM needs to call Algorithm 1 $O(Nr)$ times ($O(Nr(n + d/\epsilon^2))$ work), and it requires $O(Nr^2)$ work to compute the approximation and its factorization, which has $O(Nr)$ storage and can be applied to a vector with $O(Nr)$ work. These are standard results in the HSS literature [martinsson2005fast, xia2010fast, ho2013hierarchical].
HF [martens2010deep, martens2011learning]
The (iterative) Hessian-free methods [martens2010deep, martens2011learning] combine the conjugate gradient (CG) method and the matrix-free matvec for solving linear systems and eigenvalue problems. The method is based on the two primitives in Table 1, where every iteration costs $O(nN)$ work and $O(N + nLd)$ storage. The number of CG iterations is generally upper bounded by $O(\sqrt{\kappa} \log(1/\epsilon))$, where $\epsilon$ is a prescribed accuracy and $\kappa$ is the condition number of the (regularized) GNH matrix.
RSVD [halko-martinsson-tropp11]
Recall the definition in Eq. (2). Assume $Q$ is the identity to simplify the presentation. The algorithm computes an approximate SVD of $J$, which leads to an approximate eigenvalue decomposition of $G = J^T J$. The algorithm is the following. First, we apply the back-propagation in Table 1 with a random Gaussian matrix $\Omega \in \mathbb{R}^{nd \times r}$ as input, which computes $J^T \Omega$. Second, the QR decomposition of the result is used to estimate the row space of $J$. Third, the linearized forward pass is applied to project $J$ onto the approximate row space, and finally, the SVD is computed on the projection. Overall, the storage is $O((N + nd) r)$, and the work required is $O(nNr + (N + nd) r^2)$, where $r$ is the numerical rank from the QR decomposition.
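The four RSVD steps can be sketched with products by $J$ and $J^T$ only, which stand in for the linearized forward pass and the back-propagation of Table 1. The oversampling parameter, the explicit low-rank test matrix, and the function name are assumptions; the sketch follows the basic randomized range-finder idea rather than reproducing a specific implementation.

```python
import numpy as np

def rsvd_gnh(apply_J, apply_Jt, m, r, rng, oversample=5):
    """Approximate top-r eigenpairs of G = J^T J from products with J and J^T.

    apply_J:  (N, k) -> (m, k), e.g. linearized forward passes
    apply_Jt: (m, k) -> (N, k), e.g. back-propagation
    """
    Omega = rng.standard_normal((m, r + oversample))   # random Gaussian input
    Qb, _ = np.linalg.qr(apply_Jt(Omega))              # basis for the row space of J
    B = apply_J(Qb)                                    # project J onto that space
    _, s, Vt = np.linalg.svd(B, full_matrices=False)   # SVD of the projection
    V = Qb @ Vt.T                                      # approximate right singular vectors
    return s[:r] ** 2, V[:, :r]                        # eigenvalues/vectors of J^T J

rng = np.random.default_rng(12)
m, N, rank = 80, 50, 6
J = rng.standard_normal((m, rank)) @ rng.standard_normal((rank, N))   # exactly rank 6
evals, evecs = rsvd_gnh(lambda X: J @ X, lambda Y: J.T @ Y, m, rank, rng)
```

Because the test matrix has exact rank 6, the recovered eigenvalues match a dense eigendecomposition; for the GNH of a real network the quality depends on how quickly the spectrum decays.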
K-FAC [grosse-martens15, martens16]
The algorithm computes an approximation of the Fisher matrix $F = \mathbb{E}[g\, g^T]$. Let a column vector $g = [g_1; g_2; \ldots; g_L]$ ($g_l$ defined in Table 1) be the network gradient, and $F$ be an $L$-by-$L$ block matrix with block size $d^2$-by-$d^2$. Note the expectation here is taken with respect to both the empirical input data distribution and the network’s predictive distribution $P_{y|x}(w)$. In particular, the $(l, m)$-th block $F_{lm}$ ($1 \le l, m \le L$) is given by