h-detach
Code for h-detach: Modifying the LSTM Gradient Towards Better Optimization - https://arxiv.org/abs/1810.03023
Recurrent neural networks are known for their notorious exploding and vanishing gradient problem (EVGP). This problem becomes more evident in tasks where the information needed to solve them correctly exists over long time scales, because EVGP prevents important gradient components from being back-propagated adequately over a large number of steps. We introduce a simple stochastic algorithm (h-detach) that is specific to LSTM optimization and targeted towards addressing this problem. Specifically, we show that when the LSTM weights are large, the gradient components through the linear path (cell state) in the LSTM computational graph get suppressed. Based on the hypothesis that these components carry information about long term dependencies (which we show empirically), their suppression can prevent LSTMs from capturing them. Our algorithm prevents gradients flowing through this path from getting suppressed, thus allowing the LSTM to capture such dependencies better. We show significant convergence and generalization improvements using our algorithm on various benchmark datasets.
Recurrent Neural Networks (RNNs) (Rumelhart et al. (1986); Elman (1990)) are a class of neural network architectures used for modeling sequential data. Compared to feed-forward networks, the loss landscape of recurrent neural networks is much harder to optimize. Among other causes, this difficulty may be attributed to the exploding and vanishing gradient problem (Bengio et al., 1994; Pascanu et al., 2013), which is more severe for recurrent networks and arises due to the highly ill-conditioned nature of their loss surface. This problem becomes more evident in tasks where training data has dependencies that exist over long time scales.
Due to the aforementioned optimization difficulty, variants of RNN architectures have been proposed that aim at addressing these problems. The most popular among such architectures, used in a wide range of applications, include long short term memory (LSTM, Hochreiter & Schmidhuber (1997)) and gated recurrent unit (GRU, Chung et al. (2014)) networks. These architectures mitigate such difficulties by introducing a linear temporal path that allows gradients to flow more freely across time steps. Arjovsky et al. (2016), on the other hand, try to address this problem by parameterizing a recurrent neural network to have unitary transition matrices, based on the idea that unitary matrices have unit singular values, which prevents gradients from exploding or vanishing.
Among the aforementioned RNN architectures, LSTMs are arguably the most widely used, and it remains a hard problem to optimize them on tasks that involve long term dependencies. Examples of such tasks are the copying problem (Bengio et al., 1994; Pascanu et al., 2013) and sequential MNIST (Le et al., 2015), which are designed in such a way that the only way to produce the correct output is for the model to retain information over long time scales.
The goal of this paper is to introduce a simple trick that is specific to LSTM optimization and improves its training on tasks that involve long term dependencies. To achieve this goal, we write out the full back-propagation gradient equation for LSTM parameters and split the composition of this gradient into its components resulting from different paths in the unrolled network. We then show that when LSTM weights are large in magnitude, the gradients through the linear temporal path (cell state) get suppressed (recall that this path was designed to allow smooth gradient flow over many time steps). We show empirical evidence that this path carries information about long term dependencies (see section 3.5) and hence gradients from this path getting suppressed is problematic for such tasks. To fix this problem, we introduce a simple stochastic algorithm that in expectation scales the individual gradient components, which prevents the gradients through the linear temporal path from being suppressed. In essence, the algorithm stochastically prevents gradient from flowing through the $h$-state of the LSTM (see figure 1), hence we call it h-detach. Using this method, we show improvements in convergence/generalization over vanilla LSTM optimization on the copying task, transfer copying task, sequential and permuted MNIST, and image captioning.
We begin by reviewing the LSTM roll-out equations. We then derive the LSTM back-propagation equations and by studying its decomposition, identify the aforementioned problem. Based on this analysis we propose a simple stochastic algorithm to fix this problem.
LSTM is a variant of traditional RNNs that was designed with the goal of improving the flow of gradients over many time steps. The roll-out equations of an LSTM are as follows,
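$$c_t = f_t \odot c_{t-1} + i_t \odot g_t \tag{1}$$
$$h_t = o_t \odot \tanh(c_t) \tag{2}$$

where $\odot$ denotes point-wise product and the gates $f_t$, $i_t$, $o_t$ and $g_t$ are defined as,

$$g_t = \tanh(W_{gh} h_{t-1} + W_{gx} x_t + b_g) \tag{3}$$
$$f_t = \sigma(W_{fh} h_{t-1} + W_{fx} x_t + b_f) \tag{4}$$
$$i_t = \sigma(W_{ih} h_{t-1} + W_{ix} x_t + b_i) \tag{5}$$
$$o_t = \sigma(W_{oh} h_{t-1} + W_{ox} x_t + b_o) \tag{6}$$

Here $c_t$ and $h_t$ are the cell state and hidden state respectively. Usually a transformation $\phi(h_t)$ is used as the output at time step $t$ (e.g. next word prediction in a language model), based on which we can compute the loss $\ell_t := \ell(\phi(h_t))$ for that time step.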
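As a concrete reading of the roll-out equations, the following is a minimal sketch of a single LSTM step in plain Python. The function name `lstm_step`, the dictionary layout of `params`, and the key names (`W_fh`, `b_f`, and so on) are illustrative assumptions and not part of the paper's code.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def matvec(W, v):
    # W is a list of rows; returns the matrix-vector product W v.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM step. `params` (hypothetical layout) maps 'W_gh', 'W_gx',
    'b_g', ... to nested lists for the four gates g, f, i, o."""
    def pre(gate):
        rec = matvec(params["W_" + gate + "h"], h_prev)
        inp = matvec(params["W_" + gate + "x"], x)
        return [r + u + b for r, u, b in zip(rec, inp, params["b_" + gate])]

    g = [math.tanh(v) for v in pre("g")]
    f = [sigmoid(v) for v in pre("f")]
    i = [sigmoid(v) for v in pre("i")]
    o = [sigmoid(v) for v in pre("o")]
    # Linear recursion in the cell state, then the gated hidden state.
    c = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c_prev, i, g)]
    h = [oj * math.tanh(cj) for oj, cj in zip(o, c)]
    return h, c
```

The cell-state update is the only place where the previous state enters linearly; every other dependence on the past goes through `h_prev` and the weight matrices.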
An important feature of the LSTM architecture is the linear recursive relation between the cell states as shown in Eq. 1. This linear path allows gradients to flow easily over long time scales. This however is one of the components in the full composition of the LSTM gradient. As we will show next, the remaining components that are a result of the other paths in the LSTM computational graph are polynomial in the weight matrices whose order grows with the number of time steps. These terms cause an imbalance in the order of magnitude of gradients from different paths, thereby suppressing gradients from linear paths of LSTM computational graph in cases where the weight matrices are large.
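In this section we derive the back-propagation equations for the LSTM network and, by studying their composition, identify a problem in it. The back-propagation equation of an LSTM can be written in the following form.

Theorem 1. Fix $w$ to be an element of the matrix $W_{gh}$, $W_{fh}$, $W_{ih}$ or $W_{oh}$. Define,

$$A_t = \begin{bmatrix} F_t & 0_n & \mathrm{diag}(k_t) \\ \tilde{F}_t & 0_n & \mathrm{diag}(\tilde{k}_t) \\ 0_n & 0_n & \mathrm{Id}_n \end{bmatrix}, \quad B_t = \begin{bmatrix} 0_n & \psi_t & 0_n \\ 0_n & \tilde{\psi}_t & 0_n \\ 0_n & 0_n & 0_n \end{bmatrix}, \quad z_t = \begin{bmatrix} \frac{dc_t}{dw} \\ \frac{dh_t}{dw} \\ 1_n \end{bmatrix} \tag{7}$$

Then $z_t = (A_t + B_t) z_{t-1}$. In other words,

$$z_t = (A_t + B_t)(A_{t-1} + B_{t-1}) \cdots (A_2 + B_2) z_1 \tag{8}$$

where all the symbols used to define $A_t$ and $B_t$ are defined in notation 1 in appendix. Here $0_n$ and $\mathrm{Id}_n$ denote the $n \times n$ zero and identity matrices and $1_n$ the all-ones vector of length $n$, where $n$ is the hidden state size.

To avoid unnecessary details, we use compressed definitions of $A_t$ and $B_t$ in the above statement and write the detailed definitions of the symbols that constitute them in notation 1 in appendix. Nonetheless, we now provide some intuitive properties of the matrices $A_t$ and $B_t$.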
The matrix $A_t$ contains components of the parameter's full gradient that arise due to the cell state (linear temporal path) described in Eq. (1) (top most horizontal path in figure 1). Thus the terms in $A_t$ are a function of the LSTM gates and hidden and cell states. Note that all the gates and hidden states $h_t$ are bounded by definition because they are a result of sigmoid or tanh activation functions. The cell state $c_t$ on the other hand evolves through the linear recursive equation shown in Eq. (1). Thus it can grow at each time step by at most $\pm 1$ (element-wise) and its value is bounded by the number of time steps $t$. Thus given a finite number of time steps and finite initialization of $c_0$, the values in matrix $A_t$ are bounded.

The matrix $B_t$ on the other hand contains components of the parameter's full gradient that arise due to the remaining paths. The elements of $B_t$ are a linear function of the weights $W_{fh}$, $W_{gh}$, $W_{ih}$, $W_{oh}$. Thus the magnitude of elements in $B_t$ can become very large irrespective of the number of time steps if the weights are very large. This problem becomes worse when we multiply the $B_t$s in Eq. (8), because the product becomes polynomial in the weights, which can become unbounded for large weights very quickly as the number of time steps grows.
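The bound on the cell state used in this argument can be spelled out in one line. The sketch below assumes only that the gates $f_t$, $i_t$ take values in $(0, 1)$ and $g_t$ in $(-1, 1)$ element-wise, as implied by the sigmoid and tanh activations; all inequalities are element-wise.

```latex
|c_t| = |f_t \odot c_{t-1} + i_t \odot g_t|
      \le |f_t| \odot |c_{t-1}| + |i_t \odot g_t|
      \le |c_{t-1}| + 1
\quad\Longrightarrow\quad
|c_t| \le |c_0| + t .
```

The bound grows only linearly in the number of time steps and does not involve the weight matrices, which is what separates the cell-state components from the components that pass through the hidden state.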
Thus based on the above analysis, we identify the following problem with the LSTM gradient: when the LSTM weights are large, the gradient components through the cell state path ($A_t$) get suppressed compared to the gradient components through the other paths ($B_t$) due to an imbalance in gradient component magnitudes. We recall that the linear recursion in the cell state path was introduced in the LSTM architecture (Hochreiter & Schmidhuber, 1997) as an important feature to allow gradients to flow smoothly through time. As we show in our ablation studies (section 3.5), this path carries information about long term dependencies in the data. Hence it is problematic if the gradient components from this path get suppressed.
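To make the split between the two kinds of gradient paths concrete, here is a minimal sketch for an LSTM with a single hidden unit, where every block of the recursion is a scalar. It tracks $dc_t/dw$ and $dh_t/dw$ for one recurrent weight $w$ and multiplies the terms that pass through $h_{t-1}$ by an optional 0/1 mask. The function name `forward_with_grad`, the parameter keys (`w_fh`, `b_f`, ...) and the mask argument `xi` are illustrative assumptions.

```python
import math

GATES = ("g", "f", "i", "o")

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def forward_with_grad(params, xs, wrt="f", xi=None, h0=0.0, c0=0.0):
    """Scalar LSTM forward pass that also returns dh_T/dw and dc_T/dw for
    w = params['w_' + wrt + 'h']. xi[t] = 0 drops the terms that flow
    through h_{t-1} at step t; xi=None keeps all of them."""
    h, c, dh, dc = h0, c0, 0.0, 0.0
    for t, x in enumerate(xs):
        keep = 1.0 if xi is None else float(xi[t])
        pre = {s: params["w_" + s + "h"] * h + params["w_" + s + "x"] * x
                  + params["b_" + s] for s in GATES}
        g = math.tanh(pre["g"])
        f, i, o = sigmoid(pre["f"]), sigmoid(pre["i"]), sigmoid(pre["o"])
        c_new = f * c + i * g
        tc = math.tanh(c_new)
        # Local derivative factors of each gate and of tanh(c_t).
        delta = {"c": o * (1 - tc * tc), "o": o * (1 - o) * tc,
                 "f": f * (1 - f) * c, "i": i * (1 - i) * g,
                 "g": (1 - g * g) * i}
        # Terms that are linear in the recurrent weights (hidden-state paths).
        psi = sum(delta[s] * params["w_" + s + "h"] for s in ("f", "g", "i"))
        psi_tilde = delta["o"] * params["w_oh"] + delta["c"] * psi
        # Direct dependence of step t on w (h here is still h_{t-1}).
        k = delta[wrt] * h if wrt in ("f", "g", "i") else 0.0
        k_tilde = (delta["o"] * h if wrt == "o" else 0.0) + delta["c"] * k
        # Cell-state path: f and delta_c * f. Both right-hand sides use the
        # old dc, dh because the tuple is built before assignment.
        dc, dh = (f * dc + keep * psi * dh + k,
                  delta["c"] * f * dc + keep * psi_tilde * dh + k_tilde)
        h, c = o * tc, c_new
    return h, c, dh, dc
```

With `xi=None` the returned derivatives should agree with a finite-difference estimate; with a mask of zeros the forward values are unchanged and only the cell-state and direct terms contribute. Note that only `psi` and `psi_tilde` contain the weights as factors, so they are the terms that grow when the weights are large.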
We now propose a simple fix to the above problem. Our goal is to manipulate the gradient components such that the components through the cell state path ($A_t$) do not get suppressed when the components through the remaining paths ($B_t$) are very large (described in section 2.2). Thus it would be helpful to multiply $B_t$ by a positive number less than 1 to dampen its magnitude. In Algorithm 1 we propose a simple trick that achieves this goal. A diagrammatic form of algorithm 1 is shown in Figure 1. In simple words, our algorithm blocks gradients from flowing through each of the $h_t$ states independently with probability $1-p$, where $p \in [0, 1]$ is a tunable hyper-parameter. Note the subtle detail in Algorithm 1 (line 9) that the loss $\ell_t$ at any time step $t$ is a function of $h_t$, which is not detached.

We now show that the gradient of the loss function resulting from the LSTM forward pass shown in algorithm 1 has the property that the gradient components arising from $B_t$ get dampened.

Theorem 2. Let $z_t = \left[\frac{dc_t}{dw}; \frac{dh_t}{dw}; 1_n\right]$ and let $\tilde{z}_t$ be the analogue of $z_t$ when applying h-detach with probability $1-p$ during back-propagation. Then,

$$\tilde{z}_t = (A_t + \xi_t B_t)(A_{t-1} + \xi_{t-1} B_{t-1}) \cdots (A_2 + \xi_2 B_2)\, \tilde{z}_1$$

where $\xi_t, \xi_{t-1}, \ldots, \xi_2$ are i.i.d. Bernoulli random variables with probability $p$ of being 1, and $w$, $A_t$ and $B_t$ are the same as defined in theorem 1.

The above theorem shows that by stochastically blocking gradients from flowing through the $h_t$ states of an LSTM with probability $1-p$, we stochastically drop the $B_t$ term in the gradient components. The corollary below shows that in expectation, this results in dampening the $B_t$ term compared to the original LSTM gradient.

Corollary 1. $\mathbb{E}_{\xi_2, \ldots, \xi_t}[\tilde{z}_t] = (A_t + pB_t)(A_{t-1} + pB_{t-1}) \cdots (A_2 + pB_2)\, \tilde{z}_1$.
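The dampening in expectation follows from the independence of the Bernoulli variables: the expectation of a product of independent factors is the product of their expectations. As a small sanity check, the sketch below verifies this for scalar stand-ins of the per-step factors by enumerating every mask exactly. The function names and the example values are illustrative assumptions, and `p` is the probability that a mask entry equals 1.

```python
from itertools import product

def expected_product(a, b, p):
    """Exact E[prod_t (a_t + xi_t * b_t)] for independent xi_t ~ Bernoulli(p),
    computed by summing over all 2^len(a) masks."""
    total = 0.0
    for xi in product((0, 1), repeat=len(a)):
        prob, value = 1.0, 1.0
        for a_t, b_t, xi_t in zip(a, b, xi):
            prob *= p if xi_t else 1.0 - p
            value *= a_t + xi_t * b_t
        total += prob * value
    return total

def damped_product(a, b, p):
    """prod_t (a_t + p * b_t): the b terms scaled by p at every step."""
    value = 1.0
    for a_t, b_t in zip(a, b):
        value *= a_t + p * b_t
    return value
```

For example, `expected_product([0.9, 0.8, 0.7], [3.0, -2.0, 5.0], 0.25)` and `damped_product` with the same arguments agree up to floating-point rounding. In the matrix case the factors do not commute, but the same factorization of the expectation applies because the order of the product is fixed.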
Finally, we note that when training LSTMs with h-detach, we reduce the amount of computation needed. This is simply because by stochastically blocking the gradient from flowing through the hidden states of the LSTM, less computation needs to be done during back-propagation through time (BPTT).
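The forward pass described above can be sketched independently of any framework. The code below is not the paper's Algorithm 1; it is a hedged illustration in which `cell` and `stop_gradient` are hypothetical callables supplied by the caller, and `detach_prob` is the probability of blocking the gradient through the hidden state at a given step. In an automatic differentiation library, `stop_gradient` would be the library's own operation, for example `Tensor.detach` in PyTorch or `jax.lax.stop_gradient` in JAX.

```python
import random

def h_detach_forward(xs, h0, c0, cell, stop_gradient, detach_prob, rng=None):
    """Run a recurrent cell over xs, blocking the gradient through the
    hidden state fed into each step with probability detach_prob.

    cell(x, h, c) -> (h_new, c_new) and stop_gradient(h) -> h are
    caller-supplied (hypothetical) callables."""
    rng = rng or random.Random()
    h, c, outputs = h0, c0, []
    for x in xs:
        # Only the copy of h that enters the next step is detached.
        h_in = stop_gradient(h) if rng.random() < detach_prob else h
        h, c = cell(x, h_in, c)
        # The per-step loss would be computed from this h, which is not
        # detached; the cell state c is never detached.
        outputs.append(h)
    return outputs, c
```

Because the forward values are unchanged by `stop_gradient`, the outputs are identical for any `detach_prob`; only the backward pass differs.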
The copying task requires the recurrent network to memorize the network inputs provided at the first few time steps and output them in the same order after a large time delay. Thus the only way to solve this task is for the network to capture the long term dependency between inputs and targets, which requires gradient components carrying this information to flow through many time steps.
We follow the copying task setup identical to Arjovsky et al. (2016) (described in appendix). Using their data generation process, we sample 100,000 training input-target sequence pairs and 5,000 validation pairs. We use cross-entropy as our loss to train an LSTM with hidden state size 128 for a maximum of 500-600 epochs. We use the ADAM optimizer with batch size 100 and learning rate 0.001, and clip the gradient norms to 1.
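Clipping the gradient norm to 1 means rescaling the whole gradient vector whenever its Euclidean norm exceeds 1, leaving its direction unchanged. A minimal sketch, assuming the gradient is given as a flat list of floats (the function name `clip_grad_norm` is illustrative):

```python
import math

def clip_grad_norm(grads, max_norm=1.0):
    """Return (clipped_grads, original_norm); rescales only if the
    Euclidean norm of grads exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        scale = max_norm / norm
        return [g * scale for g in grads], norm
    return list(grads), norm
```

For instance, a gradient `[3.0, 4.0]` has norm 5 and is rescaled to approximately `[0.6, 0.8]`, while a gradient with norm below the threshold is returned unchanged.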
Figure 2 shows the validation accuracy plots for copying task training for $T=100$ (top row) and $T=300$ (bottom row) without h-detach (left), and with h-detach (middle and right). Each plot contains runs of the same algorithm with multiple seeds to show a representative sample of the variation across runs. For time delay $T=100$, both the vanilla LSTM and the LSTM with h-detach converge; for this delay and the training setting used, vanilla LSTM is known to converge to optimal validation performance (for instance, see Arjovsky et al. (2016)). Nonetheless, we note that h-detach converges faster in this setting. A more interesting case is when the time delay is set to 300, because it requires capturing longer term dependencies. In this case, LSTM training without h-detach falls short of the performance reached with h-detach even for its best seed, while a number of other seeds converge to much worse performance. On the other hand, h-detach with detach probabilities 0.25 and 0.5 achieves the best performance and converges quickly while being reasonably robust to the choice of seed.
Having shown the benefit of h-detach in terms of training dynamics, we now extend the challenge of the copying task by evaluating how well an LSTM trained on data with a certain time delay generalizes when a larger time delay is used during inference. This task is referred to as the transfer copying task (Hochreiter & Schmidhuber, 1997). Specifically, we train the LSTM architecture on the copying task with a fixed delay, both without h-detach and with h-detach with probability 0.25 and 0.5. We then evaluate the accuracy of the trained model for each setting for various larger values of $T$. The results are shown in table 1. We find that the function learned by the LSTM when trained with h-detach generalizes significantly better to longer time delays during inference compared with the LSTM trained without h-detach.
T | Vanilla LSTM | h-detach 0.5 | h-detach 0.25 |
---|---|---|---|
200 | 90.72 | ||
400 | 77.76 | ||
500 | 74.68 | ||
1000 | 63.19 | ||
2000 | 51.83 | ||
5000 | 42.35 |
This task is a sequential version of the MNIST classification task (LeCun & Cortes, 2010). In this task, an image is fed into the LSTM one pixel per time step and the goal is to predict the label after the last pixel is fed. We consider two versions of the task: one in which the pixels are read in order (from left to right and top to bottom), and one where all the pixels are permuted in a random but fixed order. We call the second version the permuted MNIST task, or pMNIST for short. The setup used for this experiment is as follows. We use 50000 images for training, 10000 for validation and 10000 for testing. We use the ADAM optimizer with three learning rates (0.001, 0.0005 and 0.0001) and a fixed batch size of 100. We train for 200 epochs and pick our final model based on the best validation score. We use an LSTM with 100 hidden units. For h-detach, we do a hyper-parameter search on the detach probability, and find that the same value performs best on the validation set for both pixel by pixel MNIST and pMNIST.
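The two input orderings can be sketched as follows, assuming an image is given as a list of rows of pixel values. The function names and the use of a seeded `random.Random` to draw the fixed permutation are illustrative assumptions; the source only states that the permutation is random but fixed.

```python
import random

def fixed_permutation(n, seed=0):
    """A random permutation of range(n) that is identical on every call
    with the same seed, so it can be shared by all images."""
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    return perm

def to_pixel_sequence(image, permutation=None):
    """Flatten an image row by row (left to right, top to bottom) into one
    pixel per time step; optionally reorder with a fixed permutation."""
    pixels = [p for row in image for p in row]
    if permutation is not None:
        pixels = [pixels[j] for j in permutation]
    return pixels
```

For a 28 by 28 image this yields a sequence of 784 time steps in both variants; the permuted variant applies the same `permutation` to every image in the training, validation and test sets.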
On the sequential MNIST task, vanilla LSTM and training with h-detach reach similar accuracies. Nonetheless, we note that the convergence of our method is much faster and more robust to the different learning rates of the ADAM optimizer, as seen in Figure 3. Refer to appendix (figure 6) for experiments with multiple seeds that show the robustness of our method to initialization.
In the pMNIST task, we find that training the LSTM with h-detach improves test accuracy over regular LSTM training. A detailed comparison of test performance with existing algorithms is shown in table 2.
Method | MNIST | pMNIST |
---|---|---|
Vanilla LSTM | 98.6 | 90.9 |
SAB (Ke et al., 2017) | - | 91.1 |
iRNN (Le et al., 2015) | 97.0 | 82.0 |
uRNN (Arjovsky et al., 2016) | 95.1 | 91.4 |
Zoneout (Krueger et al., 2016) | - | 93.1 |
h-detach (ours) | 98.8 | 92.9 |
We now evaluate h-detach on an image captioning task, which involves using an RNN for generating captions for images. We use the Microsoft COCO dataset (Lin et al., 2014), which contains 82,783 training images and 40,504 validation images. Since this dataset does not have a standard split for training, validation and test, we follow the setting in Karpathy & Fei-Fei (2015), which suggests a split of 80,000 training images and 5,000 images each for the validation and test sets.
We use two models to test our approach: the Show&Tell encoder-decoder model (Vinyals et al., 2015), which does not employ any attention mechanism, and the 'Show, Attend and Tell' model (Xu et al., 2015), which uses soft attention. For feature extraction, we use the 2048-dimensional last layer feature vector of a residual network (ResNet, He et al. (2015)) with 152 layers, which was pre-trained on ImageNet for image classification. We use an LSTM with 512 hidden units for caption generation. We train both the ResNet and LSTM models using the ADAM optimizer (Kingma & Ba, 2014), leaving the hyper-parameters other than the learning rate as suggested in their paper. We also perform a small hyper-parameter search to find the optimal value of the h-detach parameter: we considered a small set of candidate values and picked the optimal one based on the best validation score. Similar to Serdyuk et al. (2018), we early stop based on the validation CIDEr scores and report BLEU-1 to BLEU-4, CIDEr, and METEOR scores.

The results are presented in table 3. Training the LSTM with h-detach outperforms the baseline LSTM by a good margin on all the metrics and produces the best BLEU-1 to BLEU-3 scores among all the compared methods. Even for the other metrics, except for the results reported by Lu et al. (2017), we beat all the other methods reported. We emphasize that compared to all the other reported methods, h-detach is extremely simple to implement and does not add any computational overhead (in fact, it reduces computation).
Models | B-1 | B-2 | B-3 | B-4 | METEOR | CIDEr |
DeepVS (Karpathy & Fei-Fei, 2015) | ||||||
ATT-FCN (You et al., 2016) | — | |||||
Show & Tell (Vinyals et al., 2015) | — | — | — | |||
Soft Attention (Xu et al., 2015) | — | |||||
Hard Attention (Xu et al., 2015) | — | |||||
MSM (Yao et al., 2017) | ||||||
Adaptive Attention (Lu et al., 2017) | ||||||
TwinNet (Serdyuk et al., 2018) | ||||||
No attention, Resnet152 | ||||||
Soft Attention, Resnet152 | ||||||
No attention, Resnet152 | ||||||
Show&Tell (Our impl.) | ||||||
+ h-detach (0.25) | ||||||
Attention, Resnet152 | ||||||
Soft Attention (Our impl.) | ||||||
+ h-detach (0.4) |
In this section, we first study the effect of removing gradient clipping in the LSTM training and compare how the training of vanilla LSTM and our method are affected. Removing gradient clipping is insightful because it would confirm our claim that stochastically blocking gradients through the hidden states $h_t$ of the LSTM prevents the growth of gradient magnitude. We train both models on pixel by pixel MNIST using ADAM without any gradient clipping. The validation accuracy curves are reported in figure 4 for two different learning rates. We notice that removing gradient clipping causes the vanilla LSTM training to become extremely unstable. h-detach on the other hand seems robust to removing gradient clipping for both the learning rates used. Additional experiments with multiple seeds and learning rates can be found in figure 8 in appendix.

Second, we conduct experiments where we stochastically block gradients from flowing through the cell state $c_t$ instead of the hidden state $h_t$ and observe how the LSTM behaves in such a scenario. We refer to detaching the cell state as c-detach. The goal of this experiment is to corroborate our hypothesis that the gradients through the cell state path carry information about long term dependencies. Figure 5 shows the effect of c-detach (with probabilities shown) on the copying task and the pixel by pixel MNIST task. We notice in the copying task for $T=100$ that learning becomes very slow (figure 5 (a)) and does not converge even after 500 epochs, whereas when not detaching the cell state, even the vanilla LSTM converges in around 150 epochs in most cases for $T=100$, as shown in the experiments in section 3.1. For pixel by pixel MNIST (which involves 784 time steps), there is a much larger detrimental effect on learning, which holds across all seeds at the end of training (Figure 5 (b)). This experiment corroborates our hypothesis that gradients through the cell state contain important components of the gradient signal, as blocking them worsens the performance of these models when compared to vanilla LSTM.
Capturing long term dependencies in data using recurrent neural networks has long been known to be a hard problem (Bengio et al., 1993). Therefore, there has been a considerable amount of work on addressing this issue. Prior to the invention of the LSTM architecture (Hochreiter & Schmidhuber, 1997), another class of architectures called NARX (nonlinear autoregressive models with exogenous inputs) recurrent networks (Lin et al., 1996) was popular for tasks involving long term dependencies. More recently, gated recurrent unit (GRU) networks (Chung et al., 2014) were proposed, which adapt some favorable properties of the LSTM while requiring fewer parameters. Work has also been done towards better optimization for such tasks (Martens & Sutskever, 2011; Kingma & Ba, 2014). Since vanishing and exploding gradient problems (Bengio et al., 1994) also hinder this goal, gradient clipping methods have been proposed to alleviate this problem (Tomas, 2012; Pascanu et al., 2013). Yet another line of work focuses on making use of unitary transition matrices in order to avoid loss of information as hidden states evolve over time. Le et al. (2015) propose to initialize recurrent networks with unitary weights, while Arjovsky et al. (2016) propose a new network parameterization that ensures that the state transition matrix remains unitary. Very recently, Ke et al. (2017) propose to learn an attention mechanism over past hidden states and sparsely back-propagate through paths with high attention weights in order to capture long term dependencies. Trinh et al. (2018) propose to add an unsupervised auxiliary loss to the original objective that is designed to encourage the network to capture such dependencies. We point out that our proposal in this paper is orthogonal to a number of the aforementioned papers and may even be applied in conjunction with some of them. Further, our method is specific to LSTM optimization and reduces computation relative to vanilla LSTM optimization, which is in stark contrast to most of the aforementioned approaches, which increase the amount of computation needed for training.

In section 3.5 we showed that LSTMs trained with h-detach are stable even without gradient clipping. We caution that while this is true, in general the gradient magnitude depends on the value of the detaching probability used in h-detach. Hence for the general case, we do not recommend removing gradient clipping.
When training stacked LSTMs, there are two ways in which h-detach can be used: 1) detaching the hidden state of all LSTMs simultaneously for a given time step $t$ depending on the stochastic variable $\xi_t$, or 2) stochastically detaching the hidden state of each LSTM separately. We leave this for future work.
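The two options differ only in how the detach decisions are sampled. A minimal sketch of the sampling step, with hypothetical names (`sample_detach_masks`, `shared`) and with `True` meaning that the hidden state of that layer is detached at that time step:

```python
import random

def sample_detach_masks(num_layers, num_steps, detach_prob, shared, rng):
    """Return masks[layer][step] of booleans.

    shared=True:  one decision per time step, reused by every layer.
    shared=False: an independent decision per layer and time step."""
    if shared:
        per_step = [rng.random() < detach_prob for _ in range(num_steps)]
        return [list(per_step) for _ in range(num_layers)]
    return [[rng.random() < detach_prob for _ in range(num_steps)]
            for _ in range(num_layers)]
```

This only illustrates the two sampling schemes; which of them works better for stacked LSTMs is left open in the text above.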
h-detach stochastically blocks the gradient from flowing through the hidden states of the LSTM. In corollary 1, we showed that in expectation, this is equivalent to dampening the gradient components from paths other than the cell state path. We chose this strategy especially because of its ease of implementation in current auto-differentiation libraries. Another approach to dampen these gradient components would be to directly multiply these components with a dampening factor. This feature is currently unavailable in these libraries but may be an interesting direction to look into. A downside of that strategy, though, is that it will not reduce the amount of computation the way h-detach does (although it will not increase the amount of computation compared with vanilla LSTM either).
We proposed a simple stochastic algorithm called h-detach aimed at improving LSTM performance on tasks that involve long term dependencies. We provided a theoretical understanding of the method using a novel analysis of the back-propagation equations of the LSTM architecture. We note that our method reduces the amount of computation needed during training compared to vanilla LSTM training. Finally, we empirically showed that h-detach is robust to initialization, makes the convergence of LSTM faster, and/or improves generalization compared to vanilla LSTM (and other existing methods) on various benchmark datasets.
We thank Stanisław Jastrzębski, David Kruger and Isabela Albuquerque for helpful discussions. DA was supported by IVADO.
Arjovsky, M., Shah, A., and Bengio, Y. Unitary evolution recurrent neural networks. In International Conference on Machine Learning, pp. 1120–1128, 2016.
Karpathy, A. and Fei-Fei, L. Deep visual-semantic alignments for generating image descriptions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 3128–3137, 2015.

Copying Experiment setup - We define 10 tokens, $\{a_0, \ldots, a_9\}$. The input to the LSTM is a sequence of length $T+20$ formed using one of the ten tokens at each time step. Inputs for the first $10$ time steps are sampled i.i.d. (uniformly) from $\{a_1, \ldots, a_8\}$. The next $T-1$ entries are set to $a_0$, which constitutes a delay. The next single entry is $a_9$, which represents a delimiter and indicates to the algorithm that it is now required to reproduce the initial $10$ input tokens as output. The remaining $10$ input entries are set to $a_0$. The target sequence consists of $T+10$ repeated entries of $a_0$, followed by the first $10$ entries of the input sequence in exactly the same order.
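The data generation process can be sketched as below. The specific token roles and lengths used here (ten data positions, token 0 as the blank, token 9 as the delimiter, tokens 1 to 8 as data, total length $T+20$) are assumptions following the setup of Arjovsky et al. (2016); the function name `make_copy_example` is illustrative.

```python
import random

def make_copy_example(T, rng, n_copy=10):
    """Build one (inputs, targets) pair of integer token ids for delay T.

    Assumed token roles: 0 = blank, 9 = delimiter, 1..8 = data tokens."""
    data = [rng.randint(1, 8) for _ in range(n_copy)]
    # data, then T-1 blanks (the delay), the delimiter, then n_copy blanks.
    inputs = data + [0] * (T - 1) + [9] + [0] * n_copy
    # Blank targets until the delimiter has been seen, then the data again.
    targets = [0] * (T + n_copy) + data
    return inputs, targets
```

With `n_copy=10`, both sequences have length `T + 20`, and the model is only asked to produce the data tokens in the last ten time steps, after a gap of roughly `T` steps.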
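Let us recall the equations from an LSTM

$$o_t = \sigma(W_o [h_{t-1}; x_t] + b_o)$$
$$i_t = \sigma(W_i [h_{t-1}; x_t] + b_i)$$
$$g_t = \tanh(W_g [h_{t-1}; x_t] + b_g)$$
$$f_t = \sigma(W_f [h_{t-1}; x_t] + b_f)$$
$$h_t = o_t \odot \tanh(c_t)$$
$$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$

Here $\odot$ denotes the element-wise product, also called the Hadamard product. $\sigma$ denotes the sigmoid activation function. $W_o = [W_{oh}, W_{ox}]$. $W_i = [W_{ih}, W_{ix}]$. $W_g = [W_{gh}, W_{gx}]$. $W_f = [W_{fh}, W_{fx}]$. The vector $[h_{t-1}; x_t]$ is the vertical concatenation of $h_{t-1}$ and $x_t$.

Notation 1.

$$\Delta^c_t = \mathrm{diag}[o_t \odot (1 - \tanh^2(c_t))]$$
$$\Delta^o_t = \mathrm{diag}[o_t \odot (1 - o_t) \odot \tanh(c_t)]$$
$$\Delta^f_t = \mathrm{diag}[f_t \odot (1 - f_t) \odot c_{t-1}]$$
$$\Delta^i_t = \mathrm{diag}[i_t \odot (1 - i_t) \odot g_t]$$
$$\Delta^g_t = \mathrm{diag}[(1 - g_t^2) \odot i_t]$$

For any $\star \in \{f, g, o, i\}$, define $E^\star(w)$ to be a matrix of size dim($h_t$) $\times$ dim($[h_t; x_t]$). We set all the elements of this matrix to $0$s if $w$ is not an element of $W_\star$. Further, if $w = (W_\star)_{kl}$, then $(E^\star(w))_{kl} = 1$ and $(E^\star(w))_{k'l'} = 0$ for all $(k', l') \neq (k, l)$.

$$\psi_t = \Delta^f_t W_{fh} + \Delta^g_t W_{gh} + \Delta^i_t W_{ih}$$
$$\tilde{\psi}_t = \Delta^o_t W_{oh} + \Delta^c_t \psi_t$$
$$k_t = \left(\Delta^f_t E^f(w) + \Delta^g_t E^g(w) + \Delta^i_t E^i(w)\right) [h_{t-1}; x_t]$$
$$\tilde{k}_t = \Delta^o_t E^o(w) [h_{t-1}; x_t] + \Delta^c_t k_t$$
$$F_t = \mathrm{diag}(f_t)$$
$$\tilde{F}_t = \Delta^c_t\, \mathrm{diag}(f_t)$$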
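Lemma 1. Let us assume $w$ is an entry of the matrix $W_f$, $W_i$, $W_g$ or $W_o$, then

$$\frac{df_t}{dw} = \mathrm{diag}[f_t \odot (1 - f_t)] \left(W_{fh} \frac{dh_{t-1}}{dw} + E^f(w) [h_{t-1}; x_t]\right)$$
$$\frac{do_t}{dw} = \mathrm{diag}[o_t \odot (1 - o_t)] \left(W_{oh} \frac{dh_{t-1}}{dw} + E^o(w) [h_{t-1}; x_t]\right)$$
$$\frac{di_t}{dw} = \mathrm{diag}[i_t \odot (1 - i_t)] \left(W_{ih} \frac{dh_{t-1}}{dw} + E^i(w) [h_{t-1}; x_t]\right)$$
$$\frac{dg_t}{dw} = \mathrm{diag}[1 - g_t^2] \left(W_{gh} \frac{dh_{t-1}}{dw} + E^g(w) [h_{t-1}; x_t]\right)$$

Proof
By chain rule of total differentiation,

$$\frac{df_t}{dw} = \frac{\partial f_t}{\partial w} + \frac{\partial f_t}{\partial h_{t-1}} \frac{dh_{t-1}}{dw}$$

We note that,

$$\frac{\partial f_t}{\partial w} = \mathrm{diag}[f_t \odot (1 - f_t)]\, E^f(w) [h_{t-1}; x_t]$$

and,

$$\frac{\partial f_t}{\partial h_{t-1}} = \mathrm{diag}[f_t \odot (1 - f_t)]\, W_{fh}$$

which proves the claim for $\frac{df_t}{dw}$. The derivations for $\frac{do_t}{dw}$, $\frac{di_t}{dw}$ and $\frac{dg_t}{dw}$ are similar.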
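Now let us establish recursive formulas for $\frac{dh_t}{dw}$ and $\frac{dc_t}{dw}$, using the above formulas.

Corollary 1. Considering the above notations, we have

$$\frac{dh_t}{dw} = \Delta^o_t W_{oh} \frac{dh_{t-1}}{dw} + \Delta^c_t \frac{dc_t}{dw} + \Delta^o_t E^o(w) [h_{t-1}; x_t]$$

Proof Recall that $h_t = o_t \odot \tanh(c_t)$, and thus

$$\frac{dh_t}{dw} = \frac{do_t}{dw} \odot \tanh(c_t) + o_t \odot (1 - \tanh^2(c_t)) \odot \frac{dc_t}{dw}$$

Using the previous Lemma as well as the above notation, we get

$$\frac{dh_t}{dw} = \Delta^o_t \left(W_{oh} \frac{dh_{t-1}}{dw} + E^o(w) [h_{t-1}; x_t]\right) + \Delta^c_t \frac{dc_t}{dw}$$

Corollary 2. Considering the above notations, we have

$$\frac{dc_t}{dw} = F_t \frac{dc_{t-1}}{dw} + \psi_t \frac{dh_{t-1}}{dw} + k_t$$

Proof Recall that $c_t = f_t \odot c_{t-1} + i_t \odot g_t$, and thus

$$\frac{dc_t}{dw} = \frac{df_t}{dw} \odot c_{t-1} + f_t \odot \frac{dc_{t-1}}{dw} + \frac{dg_t}{dw} \odot i_t + g_t \odot \frac{di_t}{dw}$$

Using the previous Lemma as well as the above notation, we get

$$\frac{dc_t}{dw} = F_t \frac{dc_{t-1}}{dw} + \left(\Delta^f_t W_{fh} + \Delta^g_t W_{gh} + \Delta^i_t W_{ih}\right) \frac{dh_{t-1}}{dw} + \left(\Delta^f_t E^f(w) + \Delta^g_t E^g(w) + \Delta^i_t E^i(w)\right) [h_{t-1}; x_t] = F_t \frac{dc_{t-1}}{dw} + \psi_t \frac{dh_{t-1}}{dw} + k_t$$

Let us now combine corollary 1 and 2 to get a recursive expression of $\frac{dh_t}{dw}$ in terms of $\frac{dh_{t-1}}{dw}$ and $\frac{dc_{t-1}}{dw}$.

Corollary 3. Considering the above notations, we have

$$\frac{dh_t}{dw} = \tilde{F}_t \frac{dc_{t-1}}{dw} + \tilde{\psi}_t \frac{dh_{t-1}}{dw} + \tilde{k}_t$$

Proof From Corollary 1, we know that

$$\frac{dh_t}{dw} = \Delta^o_t W_{oh} \frac{dh_{t-1}}{dw} + \Delta^c_t \frac{dc_t}{dw} + \Delta^o_t E^o(w) [h_{t-1}; x_t]$$

Using Corollary 2, we get

$$\frac{dh_t}{dw} = \Delta^o_t W_{oh} \frac{dh_{t-1}}{dw} + \Delta^c_t \left(F_t \frac{dc_{t-1}}{dw} + \psi_t \frac{dh_{t-1}}{dw} + k_t\right) + \Delta^o_t E^o(w) [h_{t-1}; x_t] = \tilde{F}_t \frac{dc_{t-1}}{dw} + \tilde{\psi}_t \frac{dh_{t-1}}{dw} + \tilde{k}_t$$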
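Theorem 1. Fix $w$ to be an element of the matrix $W_{gh}$, $W_{fh}$, $W_{ih}$ or $W_{oh}$. Define,

$$A_t = \begin{bmatrix} F_t & 0_n & \mathrm{diag}(k_t) \\ \tilde{F}_t & 0_n & \mathrm{diag}(\tilde{k}_t) \\ 0_n & 0_n & \mathrm{Id}_n \end{bmatrix}, \quad B_t = \begin{bmatrix} 0_n & \psi_t & 0_n \\ 0_n & \tilde{\psi}_t & 0_n \\ 0_n & 0_n & 0_n \end{bmatrix}, \quad z_t = \begin{bmatrix} \frac{dc_t}{dw} \\ \frac{dh_t}{dw} \\ 1_n \end{bmatrix} \tag{9}$$

Then,

$$z_t = (A_t + B_t) z_{t-1}$$

In other words,

$$z_t = (A_t + B_t)(A_{t-1} + B_{t-1}) \cdots (A_2 + B_2) z_1$$

where all the symbols used to define $A_t$ and $B_t$ are defined in notation 1.

Proof By Corollary 2, we get

$$\frac{dc_t}{dw} = F_t \frac{dc_{t-1}}{dw} + \psi_t \frac{dh_{t-1}}{dw} + \mathrm{diag}(k_t) 1_n = \begin{bmatrix} F_t & \psi_t & \mathrm{diag}(k_t) \end{bmatrix} z_{t-1}$$

Similarly by Corollary 3, we get

$$\frac{dh_t}{dw} = \tilde{F}_t \frac{dc_{t-1}}{dw} + \tilde{\psi}_t \frac{dh_{t-1}}{dw} + \mathrm{diag}(\tilde{k}_t) 1_n = \begin{bmatrix} \tilde{F}_t & \tilde{\psi}_t & \mathrm{diag}(\tilde{k}_t) \end{bmatrix} z_{t-1}$$

Thus we have

$$z_t = \begin{bmatrix} F_t & \psi_t & \mathrm{diag}(k_t) \\ \tilde{F}_t & \tilde{\psi}_t & \mathrm{diag}(\tilde{k}_t) \\ 0_n & 0_n & \mathrm{Id}_n \end{bmatrix} z_{t-1} = (A_t + B_t) z_{t-1} \tag{10}$$

Applying this formula recursively proves the claim.

Note: Since $A_t$ has $0_n$'s in the second column of the block matrix representation, it ignores the contribution of $z_t$ coming from $h_{t-1}$, whereas $B_t$ (having non-zero block matrices only in the second column of the block matrix representation) only takes into account the contribution coming from $h_{t-1}$. Hence $A_t$ captures the contribution of the gradient coming from the cell state $c_{t-1}$.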
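Theorem 2. Let $z_t = \left[\frac{dc_t}{dw}; \frac{dh_t}{dw}; 1_n\right]$ and let $\tilde{z}_t$ be the analogue of $z_t$ when applying h-detach with probability $1-p$ during back-propagation. Then,

$$\tilde{z}_t = (A_t + \xi_t B_t)(A_{t-1} + \xi_{t-1} B_{t-1}) \cdots (A_2 + \xi_2 B_2)\, \tilde{z}_1$$

where $\xi_t, \xi_{t-1}, \ldots, \xi_2$ are i.i.d. Bernoulli random variables with probability $p$ of being 1, and $A_t$ and $B_t$ are the same as defined in theorem 1.

Proof Replacing $\frac{\partial}{\partial h_{t-1}}$ by $\xi_t \frac{\partial}{\partial h_{t-1}}$ in lemma 1 and therefore in Corollaries 2 and 3, we get the following analogous equations

$$\frac{d\tilde{c}_t}{dw} = F_t \frac{d\tilde{c}_{t-1}}{dw} + \xi_t \psi_t \frac{d\tilde{h}_{t-1}}{dw} + k_t$$

and

$$\frac{d\tilde{h}_t}{dw} = \tilde{F}_t \frac{d\tilde{c}_{t-1}}{dw} + \xi_t \tilde{\psi}_t \frac{d\tilde{h}_{t-1}}{dw} + \tilde{k}_t$$

Similarly as in the proof of the previous theorem, we can rewrite

$$\frac{d\tilde{c}_t}{dw} = \begin{bmatrix} F_t & \xi_t \psi_t & \mathrm{diag}(k_t) \end{bmatrix} \tilde{z}_{t-1}$$

and

$$\frac{d\tilde{h}_t}{dw} = \begin{bmatrix} \tilde{F}_t & \xi_t \tilde{\psi}_t & \mathrm{diag}(\tilde{k}_t) \end{bmatrix} \tilde{z}_{t-1}$$

Thus

$$\tilde{z}_t = (A_t + \xi_t B_t)\, \tilde{z}_{t-1}$$

Iterating this formula gives,

$$\tilde{z}_t = (A_t + \xi_t B_t)(A_{t-1} + \xi_{t-1} B_{t-1}) \cdots (A_2 + \xi_2 B_2)\, \tilde{z}_1$$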