Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel

05/31/2019 · Soufiane Hayou et al. · University of Oxford

Stochastic Gradient Descent (SGD) is widely used to train deep neural networks. However, few theoretical results on the training dynamics of SGD are available. Recent work by Jacot et al. (2018) has shown that training a neural network of any kind with full batch gradient descent in parameter space is equivalent to kernel gradient descent in function space with respect to the Neural Tangent Kernel (NTK). Lee et al. (2019) built on this result to show that the output of a neural network trained using full batch gradient descent can be approximated by a linear model for wide neural networks. We show here how these results can be extended to SGD. In this case, the resulting training dynamics are given by a stochastic differential equation dependent on the NTK, which becomes a simple mean-reverting process for the squared loss. When the network depth is also large, we provide a comprehensive analysis of the impact of the initialization and the activation function on the NTK, and thus on the corresponding training dynamics under SGD. We provide experiments illustrating our theoretical results.


1 Introduction

Deep neural networks have achieved state-of-the-art results on numerous tasks; see, e.g., Nguyen and Hein (2018), Du et al. (2018b), Zhang et al. (2017). Although the loss function is not convex, Stochastic Gradient Descent (SGD) is often used successfully to learn these models. It has recently been shown that, for certain overparameterized deep ReLU networks, SGD converges to an optimum (Zou et al., 2018). Similar results have also been obtained for standard batch Gradient Descent (GD) (Du et al., 2018a).

The aim of this article is to provide an analysis of the training dynamics of SGD for wide and deep neural networks which will help us to better understand the impact of the initialization and activation function. The training dynamics of full batch GD are better understood, and we will build upon recent work by Jacot et al. (2018), who showed that training a neural network with full batch GD in parameter space is equivalent to a functional GD, i.e. a GD in function space with respect to a kernel called the Neural Tangent Kernel (NTK). Du et al. (2019) used a similar approach to prove that full batch GD converges to global minima for shallow neural networks, and Karakida et al. (2018) linked the Fisher Information Matrix to the NTK and studied its spectral distribution for infinite-width networks. The infinite-width limit for different architectures was studied by Yang (2019), who introduced a tensor formalism that can express most of the computations in neural networks.

Lee et al. (2019) studied a linear approximation of the full batch GD dynamics based on the NTK and gave a method to approximate the NTK for different architectures. Finally, Arora et al. (2019) give an efficient algorithm to compute the NTK exactly for convolutional architectures (Convolutional NTK or CNTK). In all of these papers, the authors used full batch GD to derive their results for different neural network architectures. However, this algorithm is far too expensive for most applications, and one often uses SGD instead.

In parallel, the impact of the initialization and activation function on the performance of wide deep neural networks has been studied in Hayou et al. (2019), Lee et al. (2018), Schoenholz et al. (2017), Yang and Schoenholz (2017). These works analyze the forward/backward propagation of some quantities through the network at the initial step to select the initial parameters and the activation function so as to ensure a deep propagation of the information at initialization. While experimental results in these papers suggest that such selection also leads to overall better training procedures (i.e. beyond the initialization step), it remains unexplained why this is the case.

We extend here the results of Jacot et al. (2018) and show that the NTK also plays a major role in the training dynamics when SGD is used instead of full batch GD. Moreover, we provide a comprehensive study of the impact of the initialization and the activation function on the NTK and therefore on the resulting training dynamics for wide and deep networks. In particular, we show that an initialization known as the Edge of Chaos (Yang and Schoenholz, 2017) leads to better training dynamics and that a class of smooth activation functions discussed in (Hayou et al., 2019) also improves the training dynamics compared to ReLU-like activation functions. We illustrate these theoretical results through simulations. All the proofs are detailed in the Supplementary Material which also includes additional theoretical and experimental results.

2 Neural Networks and Neural Tangent Kernel

2.1 Setup and notations

Consider a neural network model consisting of layers , with , and let $\theta$ be the flattened vector of weights and biases indexed by the layer's index, and let be the dimension of $\theta$. Recall that has dimension . The output of the neural network is given by some transformation of the last layer , being the dimension of the output (e.g. the number of classes for a classification problem). For any input , we thus have . As we train the model, $\theta$ changes with time, and we denote by $\theta_t$ the value of $\theta$ at time $t$ and . Let be the data set, and let , be the matrices of inputs and outputs respectively, with dimensions and . For any function , , we denote by the matrix of dimension .

Jacot et al. (2018) studied the behaviour of the output of the neural network as a function of the training time when the network is trained using a gradient descent algorithm. Lee et al. (2019) built on this result to linearize the training dynamics. We recall hereafter some of these results.

For a given , the empirical loss $\mathcal{L}(\theta)$ is given by . The full batch GD algorithm is given by

$\theta_{k+1} = \theta_k - \eta\, \nabla_\theta \mathcal{L}(\theta_k),$   (1)

where $\eta$ is the learning rate.
Let be the training time and be the number of steps of the discrete GD (1). The continuous time system equivalent to (1) with step is given by

(2)

This differs from the result by Lee et al. (2019) since we use a discretization step of . Consider times for . The following lemma, proved in the supplementary material, controls the difference between the time-continuous system and its discretization.

Lemma 1 (Discretization Error for Full Batch Gradient Descent).

Assume is -Lipschitz and let . Then there exists that depends only on and such that
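
As a toy illustration of the kind of gap controlled by Lemma 1 (not taken from the paper), consider a quadratic loss: both the discrete GD iterates and the continuous-time gradient flow are then available in closed form, so the discretization error can be measured directly. The problem size, the total time T = 1 and the step T/N below are arbitrary choices.

```python
import numpy as np

# Toy check of the GD-vs-gradient-flow discretization error (illustrative only;
# the paper's exact time scaling and constants are not reproduced here).
# Loss: L(theta) = 0.5 * ||A theta - y||^2, so grad L(theta) = A^T (A theta - y).
rng = np.random.default_rng(0)
d = 5
A = rng.normal(size=(20, d)) / np.sqrt(20)
y = rng.normal(size=20)
H = A.T @ A                               # Hessian of the quadratic loss
theta_star = np.linalg.solve(H, A.T @ y)  # minimizer
theta0 = rng.normal(size=d)
T = 1.0                                   # total training time
eigval, eigvec = np.linalg.eigh(H)

def flow(t):
    """Exact gradient-flow solution theta(t) = theta* + exp(-H t) (theta0 - theta*)."""
    decay = eigvec @ np.diag(np.exp(-eigval * t)) @ eigvec.T
    return theta_star + decay @ (theta0 - theta_star)

for N in [10, 100, 1000]:
    eta = T / N                           # discretization step
    theta = theta0.copy()
    for _ in range(N):                    # N discrete GD steps
        theta = theta - eta * A.T @ (A @ theta - y)
    err = np.linalg.norm(theta - flow(T))
    print(f"N = {N:5d}   |theta_N - theta(T)| = {err:.2e}")   # shrinks roughly like 1/N
```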

As in Lee et al. (2019), Equation (2) can be re-written as

where is a matrix of dimension and is the flattened vector of dimension constructed from the concatenation of the vectors . As a result, the output function satisfies the following ordinary differential equation

(3)

The Neural Tangent Kernel (NTK) is defined as the matrix-valued kernel satisfying, for all inputs $x, x'$,

$K_{\theta_t}(x, x') = \nabla_\theta f_{\theta_t}(x)\, \nabla_\theta f_{\theta_t}(x')^\top.$   (4)

We also define as the matrix defined blockwise by

By applying equation (3) to the vector , one obtains

(5)

meaning that for all . Zou et al. (2018) proved that, while training a wide ReLU neural network, we do not move far from the initialization point . Following these results, Lee et al. (2019) suggested using a first-order linear approximation of the dynamics of as an approximation to the real training dynamics. This linearized version is given by

Using this linearized version, the dynamics of and are given by

When tested empirically on different models and datasets, Lee et al. (2019) showed that this approximation captures the training dynamics remarkably well. However, in practice, one almost never computes the exact gradient of the empirical loss, due to its high computational cost, but only an unbiased estimate of it. We address this issue hereafter by proving that, even with SGD, the training dynamics follow a simple Stochastic Differential Equation (SDE) that can be solved explicitly in some scenarios.
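
To make the kernel in (4) concrete, here is a small self-contained sketch (not the paper's code) that evaluates the empirical NTK of a one-hidden-layer ReLU network with scalar output by forming the per-parameter gradients explicitly and taking their inner product; the width, input dimension and weight scalings are arbitrary choices.

```python
import numpy as np

# Empirical NTK of a one-hidden-layer network f(x) = w2 . phi(W1 x + b1) + b2
# (scalar output), computed as the inner product of parameter gradients.
rng = np.random.default_rng(1)
d, m = 10, 512                                     # input dimension, hidden width
W1 = rng.normal(size=(m, d)) * np.sqrt(2.0 / d)    # He-style scaling (an arbitrary choice here)
b1 = np.zeros(m)
w2 = rng.normal(size=m) * np.sqrt(1.0 / m)
phi = lambda z: np.maximum(z, 0.0)                 # ReLU
dphi = lambda z: (z > 0.0).astype(float)

def grads(x):
    """Gradient of f(x) with respect to all parameters, as one flat vector."""
    pre = W1 @ x + b1                              # hidden pre-activations
    h = phi(pre)
    back = w2 * dphi(pre)                          # df/dpre
    g_W1 = np.outer(back, x)                       # df/dW1
    g_b1 = back                                    # df/db1
    g_w2 = h                                       # df/dw2
    g_b2 = np.array([1.0])                         # df/db2
    return np.concatenate([g_W1.ravel(), g_b1, g_w2, g_b2])

def ntk(x, xp):
    """Empirical NTK: K(x, x') = <grad_theta f(x), grad_theta f(x')>."""
    return grads(x) @ grads(xp)

x, xp = rng.normal(size=d), rng.normal(size=d)
print("K(x, x)  =", ntk(x, x))
print("K(x, x') =", ntk(x, xp))
```

In the linearized regime of Lee et al. (2019), inner products of this form are exactly what drives the dynamics in (5).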

2.2 Training with Stochastic Gradient Descent

In this section, we approximate the SGD dynamics by a diffusion process. We implicitly assume the existence of the triplet , where is the probability space, is a probability measure on , and is the natural filtration of the Brownian motion. Under boundedness conditions (see the supplementary material), the SGD update can be seen as a GD update with Gaussian noise (Hu et al., 2018; Li et al., 2017). More precisely, let be the batch size. The SGD update is given by

$\theta_{k+1} = \theta_k - \eta\, \widehat{\nabla}_\theta \mathcal{L}(\theta_k),$   (6)

where $\widehat{\nabla}_\theta \mathcal{L}(\theta_k)$ denotes the gradient of the empirical loss computed on a randomly selected batch of the given size.

Combining Hu et al. (2018) and Li et al. (2017), in the time-continuous limit, the previous dynamics can be seen as a discretization of the following SDE

(7)

with time step , and where is the square-root matrix of and a standard Brownian motion.

Since the dynamics of are described by an SDE, the dynamics of can also be described by an SDE, which can be obtained from Itô's lemma; see Section 2.1 of the supplementary material.

Proposition 1.

Under the dynamics of the SDE (7), the vector is the solution of the following SDE

(8)

where is the concatenated vector of and is the Hessian of ( component of ) with respect to .

With the quadratic loss , the SDE (8) is equivalent to

(9)

This is an Ornstein–Uhlenbeck process (a mean-reverting process) with time-dependent parameters. The additional term is due to the randomness of the mini-batch; it can be seen as a regularization term and could partly explain why SGD yields better generalization errors than GD (Kubo et al., 2019; Lei et al., 2018).
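
To visualize this mean-reverting behaviour, the following Euler–Maruyama simulation (not from the paper) integrates an SDE with the same structure as (9). The kernel K, the learning rate, the constant noise level sigma and the targets are placeholders; in the paper the drift involves the NTK on the training set and the diffusion term depends on the mini-batch covariance.

```python
import numpy as np

# Euler-Maruyama simulation of a mean-reverting (OU-type) dynamic
#   d f_t = -eta * K (f_t - Y) dt + sigma dB_t,
# mimicking the structure of the SDE obtained for the squared loss.
# K, eta and sigma are stand-ins, not the paper's exact quantities.
rng = np.random.default_rng(2)
N = 8                                        # number of training points
Y = rng.normal(size=N)                       # targets
B = rng.normal(size=(N, N))
K = B @ B.T / N + 0.5 * np.eye(N)            # a fixed positive-definite stand-in for the NTK
eta, sigma, dt, steps = 0.5, 0.05, 1e-2, 5000

f = rng.normal(size=N)                       # network outputs at initialization
for _ in range(steps):
    drift = -eta * K @ (f - Y)
    noise = sigma * np.sqrt(dt) * rng.normal(size=N)
    f = f + drift * dt + noise

# The outputs settle near the targets, up to fluctuations of order sigma.
print("max |f_T - Y| =", np.abs(f - Y).max())
```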

Dynamics of for wide feedforward neural networks:
In the case of a fully connected feedforward neural network (FFNN hereafter) of depth and widths , Jacot et al. (2018) proved that, with GD, the kernel converges to a kernel that depends only on (the number of layers) for all when , where is an upper bound on the training time, under the technical assumption almost surely with respect to the initialization. For SGD, we assume that this convergence result for the NTK holds as well; this is illustrated empirically in Section 4, but we leave the theoretical proof for future work. With this approximation, the dynamics of for wide networks are given by

where and . This is an Ornstein–Uhlenbeck process whose closed-form expression is given by

(10)

where ; see supplementary material for the proof. So for any (test) input , we have

(11)

where and .

The infinite-width approximation with the squared loss shows precisely how the kernel controls the training speed and the generalization of the model through equations (10) and (11). This also holds for other loss functions (e.g. cross-entropy), as the NTK is also involved in the training dynamics (8). For deep neural networks, understanding the behaviour of as goes to infinity is thus crucial to understanding the training dynamics. More precisely, two concepts are crucial for good training: invertibility of , since only an invertible kernel makes training possible (equations (8) and (10)), and expressiveness of , since it is directly involved in the generalization function (equation (11)).
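
Dropping the diffusion term, the mean behaviour of (10)-(11) takes the kernel-regression-like form familiar from Lee et al. (2019); the sketch below evaluates it with a synthetic RBF kernel standing in for the limiting NTK and zero outputs at initialization, purely to show how the kernel enters both the training speed and the test prediction.

```python
import numpy as np
from scipy.linalg import expm

# Mean part of the training dynamics under a fixed kernel K (noise dropped):
#   f_t(X) = Y + exp(-eta K t) (f_0(X) - Y)                             [training points]
#   f_t(x) = f_0(x) + K(x, X) K^{-1} (I - exp(-eta K t)) (Y - f_0(X))   [test point]
# K is a synthetic RBF stand-in for the limiting NTK, for illustration only.
rng = np.random.default_rng(3)
N, d = 20, 3
X = rng.normal(size=(N, d))
Y = np.sin(X[:, 0])
x_test = rng.normal(size=(1, d))

def rbf(A, B, ell=1.0):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * ell ** 2))

K = rbf(X, X) + 1e-6 * np.eye(N)        # jitter for numerical stability
k_star = rbf(x_test, X)                 # 1 x N
f0_train, f0_test = np.zeros(N), 0.0    # outputs at initialization (taken to be 0 here)
eta = 1.0

for t in [0.1, 1.0, 10.0, 100.0]:
    decay = expm(-eta * K * t)
    f_train = Y + decay @ (f0_train - Y)
    f_test = f0_test + k_star @ np.linalg.solve(K, (np.eye(N) - decay) @ (Y - f0_train))
    print(f"t = {t:6.1f}   train RMSE = {np.linalg.norm(f_train - Y) / np.sqrt(N):.3f}"
          f"   test prediction = {f_test[0]:+.3f}")
# As t grows, the training error vanishes and the test prediction approaches
# the kernel-regression value k(x, X) K^{-1} Y.
```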

3 Impact of the Initialization and the Activation function on the Neural Tangent Kernel

In this section, we study the impact of the initialization and the activation function on the limiting NTK for fully connected feedforward neural networks (FFNN). More precisely, we prove that only an initialization on the Edge of Chaos (EOC) leads to an invertible NTK for deep neural networks; all other initializations lead to a trivial, non-invertible NTK. We also show that the smoothness of the activation function plays a major role in the behaviour of the NTK. To simplify notation, we restrict ourselves to the case and , since the generalization to any function and any is straightforward.

Consider an FFNN of depth $L$, widths , weights and biases . For some input $x$, the forward propagation is given by

$y^1(x) = W^1 x + b^1, \qquad y^l(x) = W^l\, \phi\big(y^{l-1}(x)\big) + b^l \quad \text{for } 2 \le l \le L,$   (12)

where $\phi$ is the activation function.

We initialize the model with and , where denotes the normal distribution with mean and variance . For some , we denote by the variance of . The convergence of as increases is studied in Lee et al. (2018), Schoenholz et al. (2017) and Hayou et al. (2019). In particular, under weak regularity conditions, they prove that converges to a point independent of as . The asymptotic behaviour of the correlations between and for any two inputs and is also driven by ; the authors define the EOC as the set of parameters such that , where . Similarly, the Ordered, resp. Chaotic, phase is defined by , resp. ; more details are recalled in Section 2 of the supplementary material. It turns out that the EOC also plays a crucial role for the NTK. Let us first define two classes of activation functions.
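
Before introducing these classes, here is a short numerical sketch (not from the paper) of how the EOC can be located for phi = tanh using the mean-field recursions of Schoenholz et al. (2017) and Hayou et al. (2019); the Gaussian expectations are approximated by Gauss–Hermite quadrature, and the value of sigma_b is an arbitrary choice.

```python
import numpy as np

# Locate the Edge of Chaos for phi = tanh: iterate the variance map
#   q_{l+1} = sigma_b^2 + sigma_w^2 * E[phi(sqrt(q_l) Z)^2],   Z ~ N(0, 1),
# to its fixed point q*, then compute chi_1 = sigma_w^2 * E[phi'(sqrt(q*) Z)^2].
# The EOC is the set of (sigma_b, sigma_w) with chi_1 = 1.
z, w = np.polynomial.hermite_e.hermegauss(61)    # quadrature for E[g(Z)], Z ~ N(0, 1)
w = w / w.sum()

def gauss_mean(g, q):
    return float(np.sum(w * g(np.sqrt(q) * z)))

def chi1(sigma_w, sigma_b, n_iter=500):
    q = 1.0
    for _ in range(n_iter):                      # variance map iterated to its fixed point
        q = sigma_b**2 + sigma_w**2 * gauss_mean(lambda u: np.tanh(u)**2, q)
    return sigma_w**2 * gauss_mean(lambda u: (1.0 - np.tanh(u)**2)**2, q)

sigma_b = 0.3
lo, hi = 0.5, 3.0                                # bisection on sigma_w for chi_1 = 1
for _ in range(60):
    mid = 0.5 * (lo + hi)
    lo, hi = (mid, hi) if chi1(mid, sigma_b) < 1.0 else (lo, mid)
print(f"EOC for tanh with sigma_b = {sigma_b}: sigma_w ~ {0.5 * (lo + hi):.4f}")
```

Here chi_1 < 1 corresponds to the Ordered phase and chi_1 > 1 to the Chaotic phase, which is why a simple bisection on sigma_w suffices.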

Definition 1.

Let be a measurable function. Then

  1. is said to be ReLU-like if there exist such that for and for .

  2. is said to be in if , is twice differentiable, and there exist , a partition of and infinitely differentiable functions such that , where is the second derivative of .

The class of ReLU-like activations includes ReLU and Leaky-ReLU, whereas the class includes, among others, Tanh, ELU and SiLU (Swish). The following proposition establishes that any initialization in the Ordered or Chaotic phase leads to a trivial limiting NTK as the number of layers becomes large.

Proposition 2 (Limiting Neural Tangent Kernel with Ordered/Chaotic Initialization).

Let be either in the ordered or in the chaotic phase. Then, there exist such that

As a result, as goes to infinity, converges to a constant kernel for all , and training becomes impossible. Indeed, we have , where is the matrix with all elements equal to one, i.e. , where is an orthogonal matrix and . Hence, is at best degenerate and asymptotically (in ) non-invertible. Also, , where , so that does not converge to as grows, rendering the training impossible. We illustrate this result empirically in Section 4.

Recall that the (matrix) NTK for input data is given by

As shown in Schoenholz et al. (2017) and Hayou et al. (2019), an initialization on the EOC preserves the norm of the gradient as it back-propagates through the network. This means that the terms are of the same order. Hence, it is more convenient to study the average NTK (ANTK hereafter) with respect to the number of layers . The next proposition shows that, on the EOC, the ANTK converges to an invertible kernel as at a polynomial rate. Moreover, by choosing an activation function in the class , we can slow down the convergence of the ANTK with respect to , and therefore train deeper models. This confirms the findings of Hayou et al. (2019).

Proposition 3 (Neural Tangent Kernel on the Edge of Chaos).

Let be a non-linear activation function and .

  1. If is ReLU-like, then for all . Moreover, there exist such that

  2. If is in , then, there exists such that . Moreover, there exist such that

Since , on the EOC there exists an invertible matrix such that as . Hence, although the NTK grows linearly with , it remains asymptotically invertible. This makes training possible for deep neural networks initialized on the EOC, contrary to an initialization in the Ordered/Chaotic phase (see Proposition 2). However, the limiting kernels carry (almost) no information on and therefore have little expressive power. Interestingly, the convergence rate of the ANTK to is slow in ( for ReLU-like activation functions and for activation functions of type ). This means that, as grows, the NTK remains expressive compared to the Ordered/Chaotic phase case (exponential convergence rate). This is particularly important for the generalization part (see equation (11)). The gain obtained with smooth activation functions of type means that we can train deeper neural networks with such activation functions than with ReLU-like ones, which could explain why ELU and Tanh tend to perform better than ReLU and Leaky-ReLU (see Section 4).
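
The slow flattening of the ANTK on the EOC can be reproduced with the standard infinite-width NTK recursion for fully connected ReLU networks (Jacot et al., 2018; Arora et al., 2019), using the closed-form arc-cosine maps at the ReLU EOC initialization (sigma_w^2 = 2, sigma_b = 0). The first-layer convention and the exact ANTK normalization used in the paper may differ, so the sketch below is only meant to show the polynomial, rather than exponential, loss of off-diagonal structure with depth.

```python
import numpy as np

# Infinite-width NTK recursion for a ReLU network on the EOC (sigma_w^2 = 2, sigma_b = 0),
# for unit-norm inputs with correlation c:
#   Sigma(c)     = (sqrt(1 - c^2) + c * (pi - arccos c)) / pi
#   Sigma_dot(c) = (pi - arccos c) / pi
#   Theta^{l+1}  = Sigma^{l+1} + Theta^l * Sigma_dot^{l+1}
def relu_maps(c):
    c = float(np.clip(c, -1.0, 1.0))
    sigma = (np.sqrt(1.0 - c**2) + c * (np.pi - np.arccos(c))) / np.pi
    sigma_dot = (np.pi - np.arccos(c)) / np.pi
    return sigma, sigma_dot

def antk(c0, depth):
    """Theta^L(x, x') / L for unit-norm inputs with initial correlation c0."""
    c, theta = c0, c0            # first-layer kernel and NTK (a simplified convention)
    for _ in range(depth - 1):
        sigma, sigma_dot = relu_maps(c)
        theta = sigma + theta * sigma_dot
        c = sigma                # the variance stays at 1 on the EOC, so c is the correlation
    return theta / depth

for L in [10, 100, 1000, 10000]:
    print(f"L = {L:6d}   ANTK(c0 = 0.2) = {antk(0.2, L):.4f}   ANTK(c0 = 0.9) = {antk(0.9, L):.4f}")
# Off-diagonal entries approach a common limit only polynomially in L,
# in contrast with the exponential collapse of the Ordered/Chaotic phase.
```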

Another important and highly influential feature of deep neural networks is their architecture. The next proposition shows that adding residual connections to a ReLU network causes the NTK to explode exponentially.

Proposition 4.

Consider the following network architecture (FFNN with residual connections)

(13)

with initialization parameters and . Let be the corresponding NTK. Then for all and there exists such that

where and are given by

  • if , then and

  • if , then and

  • if , then and

Proposition 4 shows that the NTK of a ReLU FFNN with residual connections explodes exponentially with respect to . However, the normalised kernel , where , converges to a limiting kernel similar to with a rate for all . This could potentially explain why residual networks perform better than ReLU FFNNs in many tasks when the initialization is not on the EOC. We illustrate this result in Section 4.
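
As a finite-width sanity check of this contrast (not the paper's experiment), one can estimate the NTK diagonal Theta^L(x, x) = ||grad_theta f(x)||^2 by explicit backpropagation for a ReLU network with and without skip connections. The residual block used below, y^{l+1} = y^l + W^{l+1} phi(y^l) + b^{l+1}, is an assumed stand-in for (13), and the width, depths and sigma_w^2 = 2 initialization are illustrative.

```python
import numpy as np

# Estimate the NTK diagonal ||grad_theta f(x)||^2 by manual backprop for ReLU networks,
# with and without skip connections (biases are zero at initialization, sigma_b = 0,
# but their gradient contributions are still counted).
rng = np.random.default_rng(4)
n, d = 256, 32                                   # width, input dimension
phi = lambda z: np.maximum(z, 0.0)
dphi = lambda z: (z > 0.0).astype(float)

def ntk_diag(x, depth, residual, sw2=2.0):
    Ws = [rng.normal(size=(n, d)) * np.sqrt(sw2 / d)]
    Ws += [rng.normal(size=(n, n)) * np.sqrt(sw2 / n) for _ in range(depth - 1)]
    v = rng.normal(size=n) / np.sqrt(n)          # readout weights
    ys = [Ws[0] @ x]                             # forward pass, keeping pre-activations y^l
    for W in Ws[1:]:
        new = W @ phi(ys[-1])
        ys.append(ys[-1] + new if residual else new)
    total = phi(ys[-1]) @ phi(ys[-1])            # gradient contribution of the readout weights
    delta = dphi(ys[-1]) * v                     # delta^L = df/dy^L
    for l in range(depth - 1, 0, -1):
        # contributions of W^{l+1} and b^{l+1}
        total += (delta @ delta) * (phi(ys[l - 1]) @ phi(ys[l - 1]) + 1.0)
        back = dphi(ys[l - 1]) * (Ws[l].T @ delta)
        delta = delta + back if residual else back
    total += (delta @ delta) * (x @ x + 1.0)     # first layer W^1, b^1
    return total

x = rng.normal(size=d) / np.sqrt(d)
for L in [5, 10, 20, 40]:
    plain = np.mean([ntk_diag(x, L, residual=False) for _ in range(5)])
    res = np.mean([ntk_diag(x, L, residual=True) for _ in range(5)])
    print(f"L = {L:3d}   plain NTK(x, x) ~ {plain:10.2f}   residual NTK(x, x) ~ {res:12.2e}")
# With this initialization the residual variant grows exponentially with depth,
# while the plain EOC network grows only linearly.
```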

4 Experiments

In this section, we empirically illustrate the theoretical results obtained in the previous sections, in terms of both the convergence of to as and its rate. Lastly, we confirm the impact of the NTK on the overall performance of the model (FFNN) on the MNIST and CIFAR10 datasets.

Figure 1: Ratio for three randomly selected pairs from the MNIST dataset as a function of the width, for three training times t = 0, t = 100 and t = 1000 (training time is measured in SGD updates).

4.1 Convergence of to with SGD

For NTK calculations, we use the Neural Tangent Kernel library (Schoenholz et al., 2019), which is based on JAX. We consider FFNNs with equal widths (i.e. ) and randomly choose three inputs from the MNIST dataset, tracking the value of the ratio for different widths , depths and training times . The training time is measured by the number of SGD updates. In Figure 1, we observe that, for both depths and all three training times, the ratio converges to 1 as the width increases. However, as and grow, the convergence of slows down.
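
The sketch below shows the kind of computation behind Figure 1, comparing the empirical NTK of a finite-width FFNN at initialization with its infinite-width counterpart using the neural-tangents library. The API calls (stax.serial, empirical_ntk_fn, kernel_fn) follow one version of the library and may need adjusting, and the depth, widths and W_std/b_std values are illustrative rather than those of the paper.

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Compare the empirical NTK of a finite-width ReLU FFNN at initialization with the
# infinite-width NTK, for increasing widths (illustrative settings only).
def ffnn(width, depth):
    layers = []
    for _ in range(depth):
        layers += [stax.Dense(width, W_std=2.0 ** 0.5, b_std=0.05), stax.Relu()]
    layers += [stax.Dense(1, W_std=1.0, b_std=0.05)]
    return stax.serial(*layers)

x1 = jax.random.normal(jax.random.PRNGKey(0), (1, 784))   # stand-ins for two MNIST inputs
x2 = jax.random.normal(jax.random.PRNGKey(1), (1, 784))

depth = 5
for width in [64, 256, 1024]:
    init_fn, apply_fn, kernel_fn = ffnn(width, depth)
    _, params = init_fn(jax.random.PRNGKey(2), x1.shape)
    ntk_fn = nt.empirical_ntk_fn(apply_fn)
    finite = ntk_fn(x1, x2, params)               # empirical NTK at initialization
    infinite = kernel_fn(x1, x2, 'ntk')           # infinite-width NTK
    print(f"width = {width:5d}   ratio = {float(finite.squeeze() / infinite.squeeze()):.3f}")
# The ratio should approach 1 as the width grows, as in Figure 1.
```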

Figure 2: Convergence rates for different initializations and architectures. (a) Edge of Chaos. (b) Ordered phase. (c) Adding residual connections.

4.2 Convergence rate as goes to infinity

Propositions 2, 3 and 4 give theoretical convergence rates for quantities of the form . We illustrate these results in Figure 2. Figure 2(a) shows a convergence rate approximately equal to for ReLU and ELU. Recall that for ELU the exact rate is , but the logarithmic factor cannot be observed experimentally. However, ELU does perform better than ReLU (see Table 1), which might be explained by this factor. Figure 2(b) shows that the convergence occurs at an exponential rate in the Ordered phase for both ReLU and ELU, and Figure 2(c) shows the convergence rate for a FFNN with residual connections. As predicted by Proposition 4, the convergence rate is independent of the parameter .

4.3 Impact of the initialization and smoothness of the activation on the overall performance

We train FFNNs of width 300 and depths with SGD and the categorical cross-entropy loss. We use a batch size of 64 and a learning rate of for and for (found by a grid search with exponential step size 10). For each activation function, we use an initialization on the EOC when it exists, and we append the symbol (EOC) to the activation name when this is the case. We use for ReLU, for ELU and for Tanh; these values are all on the EOC (see Hayou et al. (2019) for more details). Table 1 displays the test accuracy for different activation functions on MNIST and CIFAR10 after 10 and 100 training epochs for depth 300 and width 300 (numerical results for other depths are provided in the supplementary material). Functions in class (ELU and Tanh) perform much better than ReLU-like activation functions (ReLU, Leaky-ReLU with ). Even with Parametric ReLU (PReLU), where the parameter of the Leaky-ReLU is also learned by backpropagation, we obtain only a small improvement over ReLU. For activation functions that do not have an EOC, such as Softplus and Sigmoid, we use the He initialization for MNIST and the Glorot initialization for CIFAR10 (see He et al. (2015) and Glorot and Bengio (2010)). For Softplus and Sigmoid, the training algorithm is stuck at a low test accuracy , which is the test accuracy of a uniform random classifier with 10 classes.

Activation   | MNIST, epoch 10 | MNIST, epoch 100 | CIFAR10, epoch 10 | CIFAR10, epoch 100
ReLU (EOC)   |                 |                  |                   |
LReLU (EOC)  |                 |                  |                   |
LReLU (EOC)  |                 |                  |                   |
LReLU (EOC)  |                 |                  |                   |
PReLU        |                 |                  |                   |
ELU (EOC)    | 91.63 ± 2.21    | 96.07 ± 0.13     | 33.81 ± 1.55      | 46.14 ± 1.49
Tanh (EOC)   |                 |                  |                   |
Softplus     |                 |                  |                   |
Sigmoid      |                 |                  |                   |
Table 1: Test accuracy for a FFNN with width 300 and depth 300 for different activation functions on MNIST and CIFAR10, after 10 and 100 training epochs.

5 Conclusion

We have shown here that the training dynamics of SGD for deep neural networks can be approximated by an SDE dependent on the NTK. This approximation sheds light on how the NTK impacts the training dynamics: it controls the training rate and the generalization function. Additionally, as the number of layers becomes very large, the NTK (resp. the ANTK on the EOC) 'forgets' the data by converging to some limiting data-independent kernel . More precisely, for an initialization in the Ordered/Chaotic phase, the NTK converges exponentially fast to a non-invertible kernel as the number of layers goes to infinity, making training impossible. An initialization on the EOC leads to an invertible ANTK (and NTK) even for an infinite number of layers: the convergence rate is for ReLU-like activation functions and for a class of smooth activation functions. We believe that the NTK is a useful tool to (partially) understand wide deep neural networks, even though we are aware of the limitations of such an approach; see, e.g., Chizat and Bach (2018) and Ghorbani et al. (2019).

References

Appendix A Proofs of Section 2: Neural Networks and Neural Tangent Kernel

A.1 Proofs of Subsection 2.1

Lemma 1 (Discretization Error for Full-Batch Gradient Descent).

Assume is -Lipschitz. Then there exists that depends only on and such that

Proof.

For , we define the piecewise-constant system . Let ; we have

Therefore,

Moreover, for any , we have

where we have used . Using this result, there exists a constant depending on and such that

Now we have

so we can conclude using Gronwall’s lemma. ∎

A.2 Proofs of Subsection 2.2

Recall that

(14)

where is a randomly selected batch of size . Then, for all