Dynamical Isometry and a Mean Field Theory of LSTMs and GRUs

01/25/2019 ∙ by Dar Gilboa, et al.

Training recurrent neural networks (RNNs) on long sequence tasks is plagued with difficulties arising from the exponential explosion or vanishing of signals as they propagate forward or backward through the network. Many techniques have been proposed to ameliorate these issues, including various algorithmic and architectural modifications. Two of the most successful RNN architectures, the LSTM and the GRU, do exhibit modest improvements over vanilla RNN cells, but they still suffer from instabilities when trained on very long sequences. In this work, we develop a mean field theory of signal propagation in LSTMs and GRUs that enables us to calculate the time scales for signal propagation as well as the spectral properties of the state-to-state Jacobians. By optimizing these quantities in terms of the initialization hyperparameters, we derive a novel initialization scheme that eliminates or reduces training instabilities. We demonstrate the efficacy of our initialization scheme on multiple sequence tasks, on which it enables successful training while a standard initialization either fails completely or is orders of magnitude slower. We also observe a beneficial effect on generalization performance using this new initialization.




1 Introduction

A common paradigm for research and development in deep learning involves the introduction of novel network architectures followed by experimental validation on a selection of tasks. While this methodology has undoubtedly generated significant advances in the field, it is hampered by the fact that the full capabilities of a candidate model may be obscured by difficulties in the training procedure. It is often possible to overcome such difficulties by carefully selecting the optimizer, batch size, learning rate schedule, initialization scheme, or other hyperparameters. However, the standard strategies for searching for good values of these hyperparameters are not guaranteed to succeed, especially if the trainable configurations are constrained to a low-dimensional subspace of hyperparameter space, which can render random search, grid search, and even Bayesian hyperparameter selection methods unsuccessful.

Figure 1: Critical initialization improves trainability of recurrent networks.

Test accuracy for a peephole LSTM trained to classify sequences of MNIST digits after 8000 iterations. As the sequence length increases, the network is no longer trainable with standard initialization, but remains trainable with critical initialization.

In this work, we argue that, for long sequence tasks, the trainable configurations of initialization hyperparameters for LSTMs and GRUs lie in just such a subspace, which we characterize theoretically. In particular, we identify precise conditions on the hyperparameters governing the initial weight and bias distributions that are necessary to ensure trainability. These conditions derive from the observation that in order for a network to be trainable, (a) signals from the relevant parts of the input sequence must be able to propagate all the way to the loss function and (b) the gradients must be stable (i.e. they must not explode or vanish exponentially).

As shown in Figure 1, training of recurrent networks with standard initialization on certain tasks begins to fail as the sequence length increases and signal propagation becomes harder to achieve. However, as we shall show, a suitably-chosen initialization scheme can dramatically improve trainability on such tasks.

We study the effect of the initialization hyperparameters on signal propagation for a very broad class of recurrent architectures, which contains as special cases many state-of-the-art RNN cells, such as the GRU (Cho et al., 2014), the LSTM (Hochreiter & Schmidhuber, 1997), and the peephole LSTM (Gers et al., 2002). The analysis is based on the mean field theory of signal propagation developed in a line of prior work (Schoenholz et al., 2016; Xiao et al., 2018; Chen et al., 2018; Yang et al., 2019), as well as the concept of dynamical isometry (Saxe et al., 2013; Pennington et al., 2017, 2018) that is necessary for stable gradient backpropagation and which was shown to be crucial for training simpler RNN architectures (Chen et al., 2018). We perform a number of experiments to corroborate the results of the calculations and use them to motivate initialization schemes that outperform standard initialization approaches on a number of long sequence tasks.

2 Background and related work

2.1 Mean field analysis of neural networks

Signal propagation at initialization can be controlled by varying the hyperparameters of fully-connected (Schoenholz et al., 2016; Yang & Schoenholz, 2017) and convolutional (Xiao et al., 2018) feed-forward networks, as well as of simple gated recurrent architectures (Chen et al., 2018). In all these cases, such control was used to obtain initialization schemes that outperformed standard initializations on benchmark tasks. In the feed-forward case, this enabled the training of very deep architectures without the use of batch normalization or skip connections.

By forward signal propagation, we specifically refer to the persistence of correlations between the hidden states of networks with different inputs as a function of time (or depth in the feed-forward case), as will be made precise in Section 4.2. Backward signal propagation depends not only on the norm of the gradient, but also on its stability, which is governed by the state-to-state Jacobian matrix, as discussed by Bengio et al. (1994). In our context, the goal of the backward analysis is to enhance the conditioning of the Jacobian by controlling the first two moments of its squared singular values. Forward signal propagation and the spectral properties of the Jacobian at initialization can be studied using mean field theory and random matrix theory (Poole et al., 2016; Schoenholz et al., 2016; Yang & Schoenholz, 2017, 2018; Xiao et al., 2018; Pennington et al., 2017, 2018; Chen et al., 2018; Yang et al., 2019). More generally, mean field analysis is also emerging as a promising tool for studying the dynamics of learning in neural networks and even obtaining generalization bounds in some settings (Mei et al., 2018).

While extending an analysis at initialization to a trained network might appear hopeless due to the complexity of the training process, intriguingly, it was recently shown that in the infinite width, continuous time limit neural networks exhibit certain invariances during training (Jacot et al., 2018), further motivating the study of networks at initialization. In fact, the general strategy of proving the existence of some property beneficial for training at initialization and controlling it during the training process is the core idea behind a number of recent global convergence results for over-parametrized networks (Du et al., 2018; Allen-Zhu et al., 2018), some of which (Allen-Zhu et al., 2018) also rely explicitly on control of forward and backward signal propagation (albeit not in exactly the sense defined in this work).

As neural network training is a nonconvex problem, using a modified initialization scheme could lead to convergence to different points in parameter space in a way that adversely affects the generalization error. We provide some empirical evidence that this does not occur, and in fact, the use of initialization schemes satisfying these conditions has a beneficial effect on the generalization error.

2.2 The exploding/vanishing gradient problem and signal propagation in recurrent networks

The exploding/vanishing gradient problem is a well-known phenomenon that hampers training on long time sequence tasks (Bengio et al., 1994; Pascanu et al., 2013). Apart from the gating mechanism, there have been numerous proposals to alleviate the vanishing gradient problem by constraining the weight matrices to be exactly or approximately orthogonal (Pascanu et al., 2013; Wisdom et al., 2016; Vorontsov et al., 2017; Jose et al., 2017), or more recently by modifying some terms in the gradient (Arpit et al., 2018), while exploding gradients can be handled by clipping (Pascanu et al., 2013). Another recently proposed approach to ensuring signal propagation in long sequence tasks introduces auxiliary loss functions (Trinh et al., 2018). This modification of the loss can be seen as a form of regularization. Chang et al. (2019) study the connections between recurrent networks and certain ordinary differential equations and propose the AntisymmetricRNN, which can capture long-term dependencies in the inputs. While many of these approaches have been quite successful, they typically require modifying the training algorithm, the loss function, or the architecture, and as such exist as complementary methods to the one we investigate here. We postpone the investigation of a combination of techniques to future work.

3 Notation

Table 1: A number of recurrent architectures written in the form (1): the vanilla RNN, the minimal RNN (Chen et al., 2018), the GRU (Cho et al., 2014), the peephole LSTM (Gers et al., 2002) and the LSTM (Hochreiter & Schmidhuber, 1997). The LSTM cell state is unrolled in order to emphasize that it can be written as a function of variables that are Gaussian at the large N limit.

We denote matrices by bold upper case Latin characters and vectors by bold lower case Latin characters.

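$\mathcal{D}z$ denotes a standard Gaussian measure. The normalized trace of a random $N \times N$ matrix $\mathbf{A}$, $N^{-1}\mathbb{E}\,\mathrm{tr}(\mathbf{A})$, is denoted by $\tau(\mathbf{A})$. $\circ$ is the Hadamard product, $\sigma(\cdot)$ is a sigmoid function and both $\sigma(\cdot)$ and $\tanh(\cdot)$ act element-wise. We denote by $\mathbf{D}_{\mathbf{a}}$ a diagonal matrix with $\mathbf{a}$ on the diagonal.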

4 Mean field analysis of signal propagation and dynamical isometry

4.1 Model description and important assumptions

We present a general setup in this section and subsequently specialize to the GRU, LSTM and peephole LSTM. This section closely follows the development in Chen et al. (2018). We denote the state of a recurrent network at time $t$ by $\mathbf{h}_t \in \mathbb{R}^N$, and a sequence of inputs to the network by $\{\mathbf{x}_t\}$ (we assume for convenience of exposition that the inputs have the same dimension $N$ as the state, but this does not limit the generality of the argument; we could, for instance, work with inputs of some other dimension and define a feature map into $\mathbb{R}^N$). We also define a set of subscripts $K$ and pre-activations given by

$$\mathbf{u}^k_t = \mathbf{W}_k \mathbf{h}_{t-1} + \mathbf{U}_k \mathbf{x}_t + \mathbf{b}_k, \qquad k \in K. \tag{1a}$$

We define additional variables of the form

$$\mathbf{u}^{k}_t = \mathbf{W}_{k} \big( f(\mathbf{u}^{k'}_t) \circ \mathbf{h}_{t-1} \big) + \mathbf{U}_{k} \mathbf{x}_t + \mathbf{b}_{k}, \tag{1b}$$

where $f$ is an element-wise function and $\mathbf{u}^{k'}_t$ is defined as in eqn. (1a) (for the specific instances we consider in detail, variables of the form (1b) will be present only in the GRU; see Table 1). In cases where there is no need to distinguish between variables of the form (1a) and (1b), we will refer to both as $\mathbf{u}^k_t$. The state evolution of the network is given by

$$\mathbf{h}_t = \mathcal{K}\big(\mathbf{h}_{t-1}, \{\mathbf{u}^k_t\}_k\big), \tag{1c}$$

where $\mathcal{K}$ acts element-wise and is an affine function of $\mathbf{h}_{t-1}$. The output of the network at every time $t$ is a function of $\mathbf{h}_t$. These dynamics describe all the architectures studied in this paper, as well as many others, as detailed in Table 1. In the case of the peephole LSTM and GRU, eqn. (1c) will be greatly simplified.

In order to understand the properties of a neural network at initialization, we take the weights of the network to be random variables. Specifically, we assume that the entries of the weight matrices and biases are drawn i.i.d. from Gaussian distributions (zero-mean for the weight matrices), whose means and variances are the initialization hyperparameters. As in Chen et al. (2018), we make the untied weights assumption, namely that independent weights are drawn at every time step. Tied weights would increase the autocorrelation of states across time, but this increase might be dominated by the one due to correlations between input tokens when the latter are strong. Indeed, we provide empirical evidence that calculations performed under this assumption still have considerable predictive power in cases where it is violated.
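To make the general form (1) concrete, the following NumPy sketch writes one GRU step in that form under the untied weights assumption: every call draws fresh weights. The function name `untied_gru_step`, the per-gate hyperparameter tuple `(sigma_w2, sigma_u2, mu_b, sigma_b2)` and the scaling of the weight variances by the fan-in are illustrative assumptions, not the paper's implementation; the GRU update follows the convention of Cho et al. (2014).

```python
import numpy as np

GATES = ("z", "r", "c")  # update gate, reset gate, candidate state


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def untied_gru_step(h, x, hp, rng):
    """One GRU step in the form (1), with weights resampled on every call.

    hp[k] = (sigma_w2, sigma_u2, mu_b, sigma_b2) for k in GATES (hypothetical
    parametrization): W_k ~ N(0, sigma_w2 / N), U_k ~ N(0, sigma_u2 / n_in),
    b_k ~ N(mu_b, sigma_b2), all entries i.i.d.
    """
    N, n_in = h.size, x.size
    W, U, b = {}, {}, {}
    for k in GATES:
        sw2, su2, mu_b, sb2 = hp[k]
        W[k] = rng.normal(0.0, np.sqrt(sw2 / N), size=(N, N))
        U[k] = rng.normal(0.0, np.sqrt(su2 / n_in), size=(N, n_in))
        b[k] = rng.normal(mu_b, np.sqrt(sb2), size=N)
    u_z = W["z"] @ h + U["z"] @ x + b["z"]                   # form (1a)
    u_r = W["r"] @ h + U["r"] @ x + b["r"]                   # form (1a)
    u_c = W["c"] @ (sigmoid(u_r) * h) + U["c"] @ x + b["c"]  # form (1b): reset gate acts before W
    z = sigmoid(u_z)
    # State update (role of eqn. (1c)): element-wise and affine in h.
    return z * h + (1.0 - z) * np.tanh(u_c)


# Usage: a few untied steps on random inputs.
rng = np.random.default_rng(0)
hp = {k: (1.0, 1.0, 0.0, 0.1) for k in GATES}
h = np.zeros(64)
for _ in range(10):
    h = untied_gru_step(h, rng.normal(size=16), hp, rng)
```

Because the update is a convex combination of the previous state and a tanh output, the state stays inside (-1, 1) when it starts there.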

4.2 Forward signal propagation

We now consider two sequences of normalized inputs with zero mean and covariance fed into two copies of a network with identical weights, resulting in two sequences of states. We are interested in the moments and correlations


where we define (and analogously).

At the large N limit we can invoke the central limit theorem (CLT) to find


where the second moment is given by

and the correlations by and hence

The variables given by 1b are also asymptotically Gaussian, with their covariance detailed in Appendix A for brevity of exposition. We conclude that


and are distributed analogously with respect to . We will subsequently drop the vector subscript, since all elements are identically distributed and the nonlinearities act element-wise, as well as the input sequence index in expressions that involve only one sequence.
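As a sanity check of the large N Gaussian description, the sketch below feeds two state/input pairs through the same random weights and compares the empirical second moments and cross-moment of the pre-activations u = Wh + Ux + b with their average over the weights. It assumes zero-mean biases and weight variances scaled by the fan-in; the helper `preactivation_moments` is hypothetical.

```python
import numpy as np


def preactivation_moments(h1, h2, x1, x2, sw2, su2, sb2, rng):
    """Empirical vs. weight-averaged moments of u = W h + U x + b for two
    state/input pairs sharing the SAME weights (zero-mean bias assumed)."""
    N, n_in = h1.size, x1.size
    W = rng.normal(0.0, np.sqrt(sw2 / N), size=(N, N))
    U = rng.normal(0.0, np.sqrt(su2 / n_in), size=(N, n_in))
    b = rng.normal(0.0, np.sqrt(sb2), size=N)
    u1, u2 = W @ h1 + U @ x1 + b, W @ h2 + U @ x2 + b
    empirical = np.array([np.mean(u1**2), np.mean(u2**2), np.mean(u1 * u2)])
    # Averaging over the weights: E[u_a u_b] = sw2 <h_a h_b> + su2 <x_a x_b> + sb2,
    # where <.> denotes the average over vector entries.
    predicted = np.array([
        sw2 * np.mean(h1**2) + su2 * np.mean(x1**2) + sb2,
        sw2 * np.mean(h2**2) + su2 * np.mean(x2**2) + sb2,
        sw2 * np.mean(h1 * h2) + su2 * np.mean(x1 * x2) + sb2,
    ])
    return empirical, predicted


rng = np.random.default_rng(1)
N = 2000
h1 = np.tanh(rng.normal(size=N))
h2 = 0.8 * h1 + 0.6 * np.tanh(rng.normal(size=N))   # correlated second state
x1 = rng.normal(size=N)
x2 = 0.5 * x1 + np.sqrt(0.75) * rng.normal(size=N)  # inputs with correlation 0.5
emp, pred = preactivation_moments(h1, h2, x1, x2, 1.5, 0.5, 0.1, rng)
```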

For any , the are independent of at the large N limit. Combining this with the fact that their distribution is determined completely by and that is affine, one can rewrite eqn. (2) using eqn. (1) as the following deterministic dynamical system


where the dependence on and the data distribution has been suppressed. In the peephole LSTM and GRU, the form will be greatly simplified to . In Appendix C we compare the predicted dynamics defined by eqn. (6) to simulations, showing good agreement.
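The deterministic recursion can be compared with a direct simulation. The sketch below does this for a vanilla tanh RNN, used here only as a stand-in because its second-moment map has a one-line form, q_t = E_z[tanh(sqrt(Q_t) z)^2] with Q_t = sigma_w^2 q_{t-1} + sigma_u^2 q_x + sigma_b^2; it is not eqn. (6) for the LSTM or GRU. Gaussian expectations use Gauss-Hermite quadrature, the inputs are normalized so that q_x = 1, and weights are resampled at every step (untied). Function names are hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def mean_field_q(T, sw2, su2, sb2, n_quad=80):
    """Iterate the mean field second-moment map of a vanilla tanh RNN (q_x = 1)."""
    z, w = hermegauss(n_quad)
    w = w / np.sqrt(2.0 * np.pi)   # weights now average over a standard normal
    q, qs = 0.0, []
    for _ in range(T):
        Q = sw2 * q + su2 + sb2    # pre-activation variance
        q = float(np.sum(w * np.tanh(np.sqrt(Q) * z) ** 2))
        qs.append(q)
    return np.array(qs)


def simulated_q(T, N, n_in, sw2, su2, sb2, seed=0):
    """Empirical mean of h_t^2 in a wide vanilla tanh RNN with untied weights."""
    rng = np.random.default_rng(seed)
    h, qs = np.zeros(N), []
    for _ in range(T):
        W = rng.normal(0.0, np.sqrt(sw2 / N), size=(N, N))       # fresh weights each step
        U = rng.normal(0.0, np.sqrt(su2 / n_in), size=(N, n_in))
        b = rng.normal(0.0, np.sqrt(sb2), size=N)
        x = rng.normal(size=n_in)
        x *= np.sqrt(n_in) / np.linalg.norm(x)                   # normalized input: mean(x^2) = 1
        h = np.tanh(W @ h + U @ x + b)
        qs.append(np.mean(h ** 2))
    return np.array(qs)


theory = mean_field_q(15, 1.2, 0.5, 0.1)
sim = simulated_q(15, 1000, 100, 1.2, 0.5, 0.1)
```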

One can now study the fixed points of eqn. (6) and the rate of convergence to these fixed points by linearizing around them. The fixed points are pathological, in the sense that any information that distinguishes two input sequences is lost upon convergence. Therefore, delaying the convergence to the fixed point should allow for signals to propagate across longer time horizons. Quantitatively, the rate of convergence to the fixed point gives an effective time scale for forward signal propagation in the network.

While the dynamical system is multidimensional and analysis of convergence rates should be performed by linearizing the full system and studying the smallest eigenvalue of the resulting matrix, in practice, as in Chen et al. (2018), this eigenvalue always appears to correspond to the direction (this observation is explained in the case of fully-connected feed-forward networks by Schoenholz et al. (2016)). Hence, if we assume convergence of we need only linearize


where also depends on expectations of functions of that do not depend on . While this dependence is in principle on an infinite number of Gaussian variables as the dynamics approach the fixed point, can still be reasonably approximated in the case of the LSTM as detailed in Section 4.5, while there is no such dependence for the peephole LSTM and GRU (see Table 1).

We study the dynamics approaching the fixed point by setting and writing

and since ,


In Appendix B we show that in the case of the peephole LSTM this map is convex; the same can be shown for the GRU by a similar argument. It follows directly that in these cases the map has a single stable fixed point.

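The time scale of convergence to the fixed point is given by

$$\xi = -\frac{1}{\log \chi_{C^*}}, \tag{9}$$

where $\chi_{C^*}$, given by (8), is the slope of the linearized correlation map at the fixed point. This time scale diverges as $\chi_{C^*}$ approaches 1 from below. Due to the detrimental effect of convergence to the fixed point described above, it stands to reason that a choice of hyperparameters such that $\chi_{C^*} = 1 - \epsilon$ for some small $\epsilon > 0$ would enable signals to propagate from the initial inputs to the final hidden state when training on long sequences.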
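A minimal numerical sketch of this procedure for a vanilla tanh RNN (again a stand-in for the gated maps) is given below: it iterates the second-moment map to its fixed point, iterates the correlation map to c*, estimates the slope chi of the map at c* by a central finite difference, and returns xi = -1/log(chi). The time scale formula is assumed to take the standard form used by Schoenholz et al. (2016) and Chen et al. (2018); r denotes the correlation between the two input sequences, the inputs are normalized to q_x = 1, and all names are hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

Z, WZ = hermegauss(60)
WZ = WZ / np.sqrt(2.0 * np.pi)  # quadrature weights for a standard normal


def tanh_pair_expectation(q_u, rho):
    """E[tanh(u1) tanh(u2)] for centered Gaussians with variance q_u and correlation rho."""
    rho = float(np.clip(rho, -1.0, 1.0))
    z1, z2 = Z[:, None], Z[None, :]
    u1 = np.sqrt(q_u) * z1
    u2 = np.sqrt(q_u) * (rho * z1 + np.sqrt(1.0 - rho**2) * z2)
    return float(np.sum(WZ[:, None] * WZ[None, :] * np.tanh(u1) * np.tanh(u2)))


def q_fixed_point(sw2, su2, sb2, iters=500):
    q = 0.5
    for _ in range(iters):
        q = tanh_pair_expectation(sw2 * q + su2 + sb2, 1.0)  # E[tanh(u)^2]
    return q


def correlation_map(c, q, sw2, su2, sb2, r):
    """One step of the state-correlation recursion at fixed second moment q."""
    q_u = sw2 * q + su2 + sb2
    rho_u = (sw2 * q * c + su2 * r + sb2) / q_u
    return tanh_pair_expectation(q_u, rho_u) / q


def signal_propagation_time_scale(sw2, su2, sb2, r, iters=3000, eps=1e-5):
    q = q_fixed_point(sw2, su2, sb2)
    c = 0.0
    for _ in range(iters):  # converge to the fixed point c*
        c = correlation_map(c, q, sw2, su2, sb2, r)
    # chi: slope of the correlation map at c* (central finite difference)
    chi = (correlation_map(c + eps, q, sw2, su2, sb2, r)
           - correlation_map(c - eps, q, sw2, su2, sb2, r)) / (2.0 * eps)
    return c, chi, -1.0 / np.log(chi)


c_star, chi, xi = signal_propagation_time_scale(sw2=1.0, su2=0.5, sb2=0.05, r=0.3)
```

For this toy map chi < 1 is guaranteed when sw2 <= 1, since the slope equals sw2 times an average of products of tanh derivatives, each smaller than 1.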

4.3 Backwards signal propagation - the state-to-state Jacobian

We now turn to controlling the gradients of the network. A useful object to consider in this case is the asymptotic state-to-state transition Jacobian

$$\mathbf{J} = \lim_{t \to \infty} \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t}.$$

This matrix and powers of it will appear in the gradients of the output with respect to the weights as the dynamics approach the fixed point (specifically, the gradient of a network trained on a sequence of length T will depend on a matrix polynomial of order T in $\mathbf{J}$); hence we desire to control the squared singular value distribution of this matrix. The moments of the squared singular value distribution are given by the normalized traces

$$m_{JJ^T,n} = N^{-1}\,\mathbb{E}\,\mathrm{tr}\big[(\mathbf{J}\mathbf{J}^T)^n\big].$$

Since is independent of , if we index by the variables defined by eqn. (1a) and (1b) respectively we obtain


Under the untied assumption are independent of at the large N limit, and are also independent of each other and their elements have mean zero. Using this and the fact that acts element-wise, we have




and . The values of for the architectures considered in the paper are detailed in Appendix A. Forward and backward signal propagation are in fact intimately related, as the following lemma shows:

Lemma 1.

For a recurrent neural network defined by (1), the mean squared singular value of the state-to-state Jacobian defined in (11) and the quantity that determines the time scale of forward signal propagation (given by (8)) are related by


Proof. See Appendix B. ∎
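For intuition, the analogous relation is easy to check numerically for a vanilla tanh RNN, where the slope of the correlation map at c = 1 equals sigma_w^2 E[phi'(u)^2], which is also the mean squared singular value of J = D_{phi'(u)} W. The sketch below only illustrates this kind of identity under assumed values of sigma_w^2 and of the pre-activation variance; it is not a restatement or proof of Lemma 1 for the gated architectures.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

Z, WZ = hermegauss(80)
WZ = WZ / np.sqrt(2.0 * np.pi)  # weights for a standard normal


def gauss_mean(fn, var):
    """E[fn(u)] for u ~ N(0, var)."""
    return float(np.sum(WZ * fn(np.sqrt(var) * Z)))


def pair_mean(q_u, rho):
    """E[tanh(u1) tanh(u2)] for centered Gaussians with variance q_u, correlation rho."""
    z1, z2 = Z[:, None], Z[None, :]
    u1 = np.sqrt(q_u) * z1
    u2 = np.sqrt(q_u) * (rho * z1 + np.sqrt(1.0 - rho**2) * z2)
    return float(np.sum(WZ[:, None] * WZ[None, :] * np.tanh(u1) * np.tanh(u2)))


sw2, q_u = 1.3, 0.9                                # illustrative values
q_h = gauss_mean(lambda u: np.tanh(u) ** 2, q_u)   # E[h^2]

# Slope of c -> E[tanh(u1) tanh(u2)] / q_h at c = 1, where the pre-activation
# covariance depends on c through sw2 * q_h * c (backward difference).
delta = 1e-4
rho = 1.0 - sw2 * q_h * delta / q_u
chi_1 = (1.0 - pair_mean(q_u, rho) / q_h) / delta

# Mean squared singular value of J = D_{tanh'(u)} W with W_ij ~ N(0, sw2 / N).
m1 = sw2 * gauss_mean(lambda u: (1.0 - np.tanh(u) ** 2) ** 2, q_u)
```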

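Controlling the first moment of the squared singular value distribution is not sufficient to ensure that the gradients do not explode or vanish, since the variance of the singular values may still be large. This variance is a function of the first two moments and is given by

$$\sigma^2_{JJ^T} = m_{JJ^T,2} - m_{JJ^T,1}^2,$$

where $m_{JJ^T,n}$ denotes the $n$-th moment of the squared singular values of $\mathbf{J}$.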

The second moment can be calculated from (10), and is given by


where the are defined in (12). One could compute higher moments as well either explicitly or using tools from non-Hermitian random matrix theory.
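The point that the first moment alone is not enough can be seen with two matrices that share m_1 = 1: a random orthogonal matrix, whose squared singular values are all exactly 1, and a Gaussian matrix with entries of variance 1/N, whose squared singular values follow the Marchenko-Pastur law with m_1 = 1 and m_2 = 2 at large N, hence variance 1. The moments here are averaged over the spectrum of a single matrix; the helper name is hypothetical.

```python
import numpy as np


def squared_singular_value_moments(J):
    """(m1, m2, variance) of the squared singular values of J, averaged over its spectrum."""
    s2 = np.linalg.svd(J, compute_uv=False) ** 2
    m1, m2 = np.mean(s2), np.mean(s2 ** 2)
    return m1, m2, m2 - m1 ** 2


rng = np.random.default_rng(0)
N = 500
gaussian = rng.normal(0.0, np.sqrt(1.0 / N), size=(N, N))  # entries N(0, 1/N)
orthogonal, _ = np.linalg.qr(rng.normal(size=(N, N)))      # Q factor is orthogonal

orth_m1, orth_m2, orth_var = squared_singular_value_moments(orthogonal)
gauss_m1, gauss_m2, gauss_var = squared_singular_value_moments(gaussian)
# Large-N theory: orthogonal gives m1 = m2 = 1 (variance 0);
# Gaussian gives m1 = 1, m2 = 2 (variance 1).
```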

4.4 Dynamical Isometry

The objective of the analysis is to ensure, at least approximately, that the following equations are satisfied simultaneously


We refer to these as dynamical isometry conditions. We demand that these equations are only satisfied approximately since for a given architecture, there may not be a value of that satisfies all the conditions. Additionally, even if such a value exists, the optimal value of for a given task may not be 1. There is some empirical evidence that if the characteristic time scale defined by is much larger than that required for a certain task, performance is degraded. In feed-forward networks as well, there is evidence that the optimal performance is achieved when the dynamical isometry conditions are only approximately satisfied (Pennington et al., 2017). Accounting for this observation is an open problem at present.

Convexity of the map (7) in the case of the peephole LSTM and GRU implies that for , conditions (15a) and (15b) can only be satisfied simultaneously if (since (15b) implies ).

For all the architectures considered in this paper, we find that while are finite as . Combining this with (11), (13) and (14), we find that the dynamical isometry conditions are satisfied if , which can be achieved by setting and taking . This motivates the general form of the initializations used in Section 5.2 (in the case of the LSTM, we also want to prevent the output gate from taking very small values, as explained in Appendix A.2), although there are many other possible choices of hyperparameters such that the vanish.
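The following toy sketch illustrates why pushing a gate-like diagonal term towards 1 while shrinking the weight variances drives the spectrum towards isometry. The Jacobian D_f + G used here, with f a sigmoid of Gaussian noise and G Gaussian with entries of variance weight_std^2 / N, is a hypothetical stand-in and not the Jacobian of any architecture in Table 1; the specific hyperparameter values are arbitrary.

```python
import numpy as np


def squared_sv_stats(J):
    """Mean and variance of the squared singular values of J."""
    s2 = np.linalg.svd(J, compute_uv=False) ** 2
    return np.mean(s2), np.var(s2)


def toy_jacobian(N, gate_bias, gate_std, weight_std, rng):
    """Hypothetical toy Jacobian D_f + G: a forget-gate-like diagonal plus a
    Gaussian term with entries N(0, weight_std**2 / N)."""
    f = 1.0 / (1.0 + np.exp(-(gate_bias + gate_std * rng.normal(size=N))))
    G = rng.normal(0.0, weight_std / np.sqrt(N), size=(N, N))
    return np.diag(f) + G


rng = np.random.default_rng(0)
near_isometric = squared_sv_stats(toy_jacobian(400, 5.0, 0.1, 0.01, rng))  # f close to 1, tiny weights
generic = squared_sv_stats(toy_jacobian(400, 0.0, 1.0, 1.0, rng))          # f around 0.5, O(1) weights
```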

4.5 The LSTM cell state distribution

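Algorithm 1: LSTM hidden state moment fixed point iteration using cell state sampling
   compute the pre-activation moments {Using (4)}
   for i = 1 to n_iter do
      sample the pre-activations {Using (3)}
      update the cell state samples {Using (16)}
   end for
   update the hidden state second moment {Using (6)}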

In the case of the peephole LSTM, since the depend on the second moments of and is an affine function of , one can write a closed-form expression for the dynamical system (6) in terms of first and second moments. In the standard LSTM, however, the relevant state depends on the cell state, which has non-trivial dynamics.

Figure 2: Training accuracy on the padded MNIST classification task described in Section 5.1.1 at different sequence lengths and hyperparameter values for networks with untied weights, with different hyperparameters scanned for each architecture. The green curves are multiples of the theoretical signal propagation time scale. As can be seen, this time scale predicts the transition between the regions of high and low accuracy across the different architectures.

The cell state differs substantially from other random variables that appear in this analysis since it cannot be expressed as a function of a finite number of variables that are Gaussian at the large and limit (see Table 1). Since at this limit the are independent, by examining the cell state update equation


we find that the asymptotic cell state distribution is that of a perpetuity, which is a random variable $X$ that obeys $X \overset{d}{=} AX + B$, where $A, B$ are random variables independent of $X$ and $\overset{d}{=}$ denotes equality in distribution. The stationary distribution of a perpetuity, if it exists, is known to exhibit heavy tails (Goldie, 1991). Aside from the tails, the bulk of the distribution can take a variety of different forms and can be highly multimodal, depending on the choice of hyperparameters, which in turn determines the distributions of $A$ and $B$.

In practice, one can overcome this difficulty by sampling from the stationary cell state distribution, despite the fact that the density and even the likelihood have no closed form. For a given value of , the variables appearing in (16) can be sampled since their distribution is given by (5) at the large N limit. The update equation (16) can then be iterated, and the resulting samples approximate the stationary cell state distribution well for a range of different choices of , which result in a variety of stationary distribution profiles (see Appendix C3). The fixed points of (6) can then be calculated numerically as in the deterministic cases, yet care must be taken since the sampling introduces stochasticity into the process. An example of the fixed point iteration equation (6a) implemented using sampling is presented in Algorithm 1. The correlations between the hidden states can be calculated in a similar fashion. In practice, once the number of samples and sampling iterations is of order reasonably accurate values for the moment evolution and the convergence rates to the fixed point are obtained (see for instance the right panel of Figure 2). The computational cost of the sampling is linear in both the number of samples and the number of iterations (as opposed to, say, simulating a neural network directly, in which case the cost is quadratic in N).
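A minimal NumPy sketch of the sampling idea: draw the forget gate, input gate and candidate pre-activations independently from Gaussians at every iteration and iterate c <- f c + i g, which is a perpetuity with A = f and B = i g. The Gaussian means and variances passed in are assumed inputs (in Algorithm 1 they would come from the current moment estimates); the names, sample sizes and the check against the stationary mean E[c] = E[B] / (1 - E[A]) are illustrative.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def sample_cell_state(mean, var, n_samples=50_000, n_iter=200, seed=0):
    """Approximate samples from the stationary cell state distribution.

    mean/var: Gaussian pre-activation statistics of the forget gate 'f',
    input gate 'i' and candidate 'g' (assumed given, e.g. from mean field moments).
    """
    rng = np.random.default_rng(seed)
    c = np.zeros(n_samples)
    for _ in range(n_iter):
        f = sigmoid(rng.normal(mean["f"], np.sqrt(var["f"]), n_samples))
        i = sigmoid(rng.normal(mean["i"], np.sqrt(var["i"]), n_samples))
        g = np.tanh(rng.normal(mean["g"], np.sqrt(var["g"]), n_samples))
        c = f * c + i * g  # perpetuity update: A = f, B = i * g
    return c


def gaussian_mean(fn, mu, var, n=80):
    """E[fn(u)] for u ~ N(mu, var) via Gauss-Hermite quadrature."""
    z, w = hermegauss(n)
    return float(np.sum(w * fn(mu + np.sqrt(var) * z)) / np.sqrt(2.0 * np.pi))


mean = {"f": 1.0, "i": 0.0, "g": 0.5}
var = {"f": 1.0, "i": 1.0, "g": 1.0}
c = sample_cell_state(mean, var)
# Stationary perpetuity with A, B independent of c: E[c] = E[B] / (1 - E[A]).
predicted_mean = (gaussian_mean(sigmoid, 0.0, 1.0) * gaussian_mean(np.tanh, 0.5, 1.0)
                  / (1.0 - gaussian_mean(sigmoid, 1.0, 1.0)))
```

The cost is linear in both the number of samples and the number of iterations, since each iteration only applies element-wise operations.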

5 Experiments

5.1 Corroboration of calculations

5.1.1 Padded MNIST Classification

The calculations presented above predict a characteristic time scale ξ (defined in (9)) for forward signal propagation in a recurrent network. It follows that on a task where success depends on propagation of information from the first time step to the final T-th time step, the network will not be trainable for T ≫ ξ. In order to test this prediction, we consider a classification task where the inputs are sequences consisting of a single MNIST digit followed by T steps of i.i.d. Gaussian noise and the targets are the digit labels. By scanning across certain directions in hyperparameter space, the predicted value of ξ changes. We plot the training accuracy of a network trained with untied weights after 1000 iterations for the GRU and 2000 for the LSTM, as a function of T and the hyperparameter values, and overlay this with multiples of ξ. As seen in Figure 2, we observe good agreement between the predicted time scale of signal propagation and the success of training. As expected, there are some deviations when training without enforcing untied weights, and we present the corresponding plots in the supplementary materials.
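One way to build a single input sequence for this task is sketched below. It assumes that the flattened 28×28 digit is fed as the first time step and that the T noise steps have the same dimension (784) and unit variance; the text above does not fix these details, and the function name is hypothetical. A random array stands in for an MNIST image.

```python
import numpy as np


def padded_digit_sequence(digit, T, rng, noise_std=1.0):
    """Hypothetical padded-MNIST input: the flattened 28x28 digit as the first
    time step, followed by T steps of i.i.d. Gaussian noise."""
    first = np.asarray(digit, dtype=np.float64).reshape(1, -1)    # (1, 784)
    noise = rng.normal(0.0, noise_std, size=(T, first.shape[1]))  # (T, 784)
    return np.concatenate([first, noise], axis=0)                 # (T + 1, 784)


rng = np.random.default_rng(0)
digit = rng.random((28, 28))  # stand-in for an MNIST image
seq = padded_digit_sequence(digit, T=50, rng=rng)
```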

5.1.2 Squared Jacobian spectrum histograms

To verify the results of the calculation of the moments of the squared singular value distribution of the state-to-state Jacobian presented in Section 4.3 we run an untied peephole LSTM for 100 iterations with i.i.d. Gaussian inputs. We then compute the state-to-state Jacobian and calculate its spectrum. This can be used to compare the first two moments of the spectrum to the result of the calculation, as well as to observe the difference between a standard initialization and one close to satisfying the dynamical isometry conditions. The results are shown in Figure 3. The validity of this experiment rests on making an ergodicity assumption, since the calculated spectral properties require taking averages over realizations of random matrices, while in the experiment we instead calculate the moments by averaging over the eigenvalues of a single realization. The good agreement between the prediction and the empirical average suggests that the assumption is valid.
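An analogous experiment for a vanilla tanh RNN Jacobian J = D_{phi'(u)} W can be sketched in a few lines: compute the squared singular values of a single realization and compare their first two moments with the large N prediction m_1 = sigma_w^2 E[phi'^2] and m_2 = sigma_w^4 (E[phi'^4] + E[phi'^2]^2), which follows from the freeness of D^2 and W W^T. Pre-activations are drawn independently of W, mimicking the untied assumption; this is not the peephole LSTM Jacobian used in Figure 3, and the parameter values are illustrative.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gaussian_mean(fn, var, n=80):
    """E[fn(u)] for u ~ N(0, var) via Gauss-Hermite quadrature."""
    z, w = hermegauss(n)
    return float(np.sum(w * fn(np.sqrt(var) * z)) / np.sqrt(2.0 * np.pi))


def dphi(u):
    return 1.0 - np.tanh(u) ** 2  # derivative of tanh


rng = np.random.default_rng(0)
N, sw2, q_u = 1000, 1.5, 0.8
u = rng.normal(0.0, np.sqrt(q_u), size=N)           # pre-activations, independent of W
W = rng.normal(0.0, np.sqrt(sw2 / N), size=(N, N))
J = dphi(u)[:, None] * W                             # J = D_{phi'(u)} W
s2 = np.linalg.svd(J, compute_uv=False) ** 2         # a single realization

empirical = np.array([np.mean(s2), np.mean(s2 ** 2)])
d2 = gaussian_mean(lambda x: dphi(x) ** 2, q_u)
d4 = gaussian_mean(lambda x: dphi(x) ** 4, q_u)
predicted = np.array([sw2 * d2, sw2 ** 2 * (d4 + d2 ** 2)])
```

Averaging over the spectrum of one matrix rather than over many realizations is exactly the ergodicity assumption discussed above.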

Figure 3: Squared singular values of the state-to-state Jacobian in eqn. (10) for two choices of hyperparameter settings. The red lines denote the empirical mean and standard deviations, while the dotted lines denote the theoretical prediction based on the calculation described in Section 4.3. Note the dramatic difference in the spectrum caused by choosing an initialization that approximately satisfies the dynamical isometry conditions.
Figure 4: Training accuracy for unrolled, concatenated MNIST digits (top) and unrolled MNIST digits with replicated pixels (bottom) for different sequence lengths. Left: For shorter sequences the standard and critical initialization perform equivalently. Middle: As the sequence length is increased, training with a critical initialization is faster by orders of magnitude. Right: For very long sequence lengths, training with a standard initialization fails completely.

5.2 Long sequence tasks

A significant outcome of the calculation is that it motivates critical initializations that dramatically improve the performance of recurrent networks on standard long sequence benchmarks, despite the fact that the calculation is performed under the untied assumption, at the large N limit, and makes some rather unrealistic assumptions about the data distribution. The details of the initializations used are presented in Appendix C.

5.2.1 Unrolled MNIST and CIFAR-10

We unroll an MNIST digit into a sequence of length 784 and train a critically initialized peephole LSTM with 600 hidden units. We also train a critically initialized LSTM with hard sigmoid nonlinearities on unrolled CIFAR-10 images, feeding in 3 pixels at every time step, resulting in sequences of length 1024. We also apply standard data augmentation for this task. We present accuracy on the test set in Table 2. Interestingly, in the case of CIFAR-10 the best performance is achieved by an initialization with a forward propagation time scale that is much smaller than the sequence length, suggesting that information sufficient for successful classification may be obtained from a subset of the sequence.
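The unrolling step can be sketched as follows. Row-major pixel order and feeding the three color channels of one CIFAR-10 pixel per time step are assumptions consistent with the stated sequence lengths (28 × 28 = 784 and 32 × 32 × 3 / 3 = 1024); the function names are hypothetical.

```python
import numpy as np


def unroll_mnist(image):
    """(28, 28) image -> (784, 1) sequence, one pixel per step, row-major order."""
    return np.asarray(image).reshape(-1, 1)


def unroll_cifar(image):
    """(32, 32, 3) image -> (1024, 3) sequence: the three channel values of one
    pixel per step (assumed interpretation), row-major pixel order."""
    return np.asarray(image).reshape(-1, 3)


img = np.arange(32 * 32 * 3).reshape(32, 32, 3)
cifar_seq = unroll_cifar(img)
mnist_seq = unroll_mnist(np.zeros((28, 28)))
```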

Model                                  Unrolled MNIST (%)   Unrolled CIFAR-10 (%)
standard LSTM                          98.6⁵                58.8⁶
h-detach LSTM (Arpit et al., 2018)     98.8                 -
critical LSTM                          98.9                 61.8
Table 2: Test accuracy on unrolled MNIST and CIFAR-10.
⁵ Reproduced from Arpit et al. (2018). ⁶ Reproduced from Trinh et al. (2018).

5.2.2 Repeated pixel MNIST and multiple digit MNIST

In order to generate longer sequence tasks, we modify the unrolled MNIST task by repeating every pixel a certain number of times and set the input dimension to 7. To create a more challenging task, we also combine this pixel repetition with the concatenation of multiple MNIST digits (either 0 or 1), and label such sequences by a product of the original labels. In this case, we set the input dimension to 112 and repeat each pixel 10 times. We train a peephole LSTM with both a critical initialization and a standard initialization on both of these tasks using SGD with momentum. In the former task, the dimension of the label space is constant (rather than exponential in the number of digits, as in the latter). In both tasks, we observe three distinct phases. If the sequence length is relatively short, the critical and standard initializations perform equivalently. As the sequence length is increased, training with a critical initialization is faster by orders of magnitude compared to the standard initialization. As the sequence length is increased further, training with a standard initialization fails, while training with a critical initialization still succeeds. The results are shown in Figure 4.
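A sketch of how such sequences could be assembled: digits are flattened and concatenated, every pixel is repeated, and the stream is chunked into steps of the chosen input dimension. Encoding the tuple of binary digit labels as a single class index in [0, 2^n) is one reading of "a product of the original labels" that matches the exponential size of the label space; the helper names are hypothetical.

```python
import numpy as np


def repeated_pixel_sequence(digits, repeat, input_dim):
    """Concatenate flattened 28x28 digits, repeat every pixel `repeat` times and
    chunk the stream into time steps of `input_dim` values."""
    flat = np.concatenate([np.asarray(d, dtype=np.float64).reshape(-1) for d in digits])
    stream = np.repeat(flat, repeat)  # p0, p0, ..., p1, p1, ...
    if stream.size % input_dim:
        raise ValueError("sequence length must be divisible by input_dim")
    return stream.reshape(-1, input_dim)


def product_label(binary_labels):
    """Encode a tuple of 0/1 digit labels as one class index in [0, 2**n)."""
    return sum(int(b) << k for k, b in enumerate(binary_labels))


img = np.arange(784).reshape(28, 28)
single = repeated_pixel_sequence([img], repeat=3, input_dim=7)            # (336, 7)
multi = repeated_pixel_sequence([img, img], repeat=10, input_dim=112)     # (140, 112)
```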

6 Discussion

In this work, we calculate time scales of signal propagation and moments of the state-to-state Jacobian at initialization for a number of important recurrent architectures. The calculation then motivates initialization schemes that dramatically improve the performance of these networks on long sequence tasks by guaranteeing long time scales for forward signal propagation and stable gradients. The subspace of initialization hyperparameters that satisfy the dynamical isometry conditions is multidimensional, and there is no clear principled way to choose a preferred initialization within it. It would be of interest to study this subspace and perhaps identify preferred initializations based on additional constraints. One could also use the satisfaction of the dynamical isometry conditions as a guiding principle in simplifying these architectures. A direct consequence of the analysis is that the forget gate, for instance, is critical, while some of the other gates or weight matrices can be removed while still satisfying the conditions. A related question is the optimal choice of the forward propagation time scale for a given task. As mentioned in Section 5.2.1, this scale can be much shorter than the sequence length. It would also be valuable to better understand the extent to which the untied weights assumption is violated, since it appears that the violation is non-uniform in , and to relax the constant assumption by introducing a time dependence.

A natural direction of future work is to attempt an analogous calculation for multi-layer LSTM and GRU networks, or for more complex architectures that combine LSTMs with convolutional layers.

Another compelling issue is the persistence of the dynamical isometry conditions during training and their effect on the solution, for both feed-forward and recurrent architectures. Intriguingly, it has been recently shown that in the case of certain infinitely-wide MLPs, objects that are closely related to the correlations and moments of the Jacobian studied in this work are constant during training with full-batch gradient descent in the continuous time limit, and as a result the dynamics of learning take a simple form (Jacot et al., 2018). Understanding the finite width and learning rate corrections to such calculations could help extend the analysis of signal propagation at initialization to trained networks. This has the potential to improve the understanding of neural network training dynamics, convergence and ultimately perhaps generalization as well.


Appendix A Details of Results

A.1 Covariances of

The variables given by (1b) are asymptotically Gaussian at the large N limit, with


where is a Gaussian measure on corresponding to the distribution in eqn. (3).

A.2 Dynamical isometry conditions for selected architectures

We specify the form of and for the architectures considered in this paper:

A.2.1 GRU

A.2.2 Peephole LSTM

A.2.3 LSTM

When evaluating (8) in this case, we write the cell state as , and assume for large . The stability of the first equation and the accuracy of the second approximation are improved if is not concentrated around 0.

Appendix B Auxiliary Lemmas and Proofs

Proof of Lemma 1.

Despite the fact that each as defined in (1) depends in principle upon the entire state vector , at the large N limit due to the isotropy of the input distribution we find that these random variables are i.i.d. and independent of the state. Combining this with the fact that is an element-wise function, it suffices to analyse a single entry of , which at the large N limit gives

where and is defined similarly (i.e. we assume the first two moments have converged but the correlations between the sequences have not, and in cases where depends on a sequence of we assume the constituent variables have all converged in this way). We represent via a Cholesky decomposition as


where i.i.d. We thus have . Combining this with the fact that for any , integration by parts gives for any


Denoting and defining similarly, we have


where in the last equality we used (19). Using (19) again gives