predrnn-pytorch
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.
The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems. This paper models these structures by presenting PredRNN, a new recurrent network, in which a pair of memory cells are explicitly decoupled, operate in nearly independent transition manners, and finally form unified representations of the complex environment. Concretely, besides the original memory cell of LSTM, this network is featured by a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the learned visual dynamics at different levels of RNNs to communicate. It also leverages a memory decoupling loss to keep the memory cells from learning redundant features. We further improve PredRNN with a new curriculum learning strategy, which can be generalized to most sequence-to-sequence RNNs in predictive learning scenarios. We provide detailed ablation studies, gradient analyses, and visualizations to verify the effectiveness of each component. We show that our approach obtains highly competitive results on three standard datasets: the synthetic Moving MNIST dataset, the KTH human action dataset, and a radar echo dataset for precipitation forecasting.
As a key application of predictive learning, generating future frames from historical consecutive frames has received growing interest in the machine learning and computer vision communities. It benefits many practical applications and downstream tasks, such as precipitation forecasting [shi2015convolutional, wang2017predrnn], traffic flow prediction [xu2018predcnn, wang2019memory], physical scene understanding [wu2017learning, van2018relational, kipf2018neural, xu2019unsupervised], early activity recognition [wang2019eidetic], deep reinforcement learning [ha2018recurrent, hafner2018learning], and vision-based model predictive control [finn2017deep, ebert2017self]. Many of these existing approaches leverage Recurrent Neural Networks (RNNs) [rumelhart1988learning, werbos1990backpropagation] with stacked Long Short-Term Memory (LSTM) units [hochreiter1997long] to capture the temporal dependencies of spatiotemporal data. Such an architecture is mainly inspired by similar ideas from other well-explored tasks on sequential data, such as neural machine translation [Sutskever2011Generating, Cho2014On], speech recognition [Graves2014Towards], video action recognition [Ng15, donahue2015long], and video captioning [donahue2015long].
For stacked LSTMs, a network structure named the memory cell plays an important role in alleviating the vanishing gradient problem of RNNs. Strong theoretical and empirical evidence has shown that it can latch the gradients of hidden states inside each LSTM unit during training and thereby preserve valuable information about the underlying temporal dynamics [hochreiter1997long]. However, the state transition pathway of LSTM memory cells may not be optimal for spatiotemporal predictive learning, as this task requires the learned representations to focus on different aspects than other sequential tasks do. First, most predictive networks for language or speech modeling [Sutskever2011Generating, Cho2014On, Graves2014Towards] focus on capturing the long-term, non-Markovian properties of sequential data rather than spatial deformations of visual appearance. For future frame prediction, however, both the spatial and the temporal structures of the data are crucial and need to be carefully considered. Second, in other supervised tasks on video data, such as action recognition, high-level semantic features can be informative enough, and low-level features are less important to the final outputs. Because these supervision signals lack complex structure, stacked LSTMs do not need to preserve fine-grained representations from the bottom up. Although the existing recurrent architecture, based on inner-layer memory transitions, can be sufficient to capture temporal variations at each level of the network, it may not be the best choice for predictive learning, where low-level details and high-level semantics of spatiotemporal data are both significant for generating future frames.
To jointly model the spatial correlations and temporal dynamics at different levels of RNNs, we propose a new memory-prediction framework named Predictive Recurrent Neural Network (PredRNN), which extends the inner-layer transition function of memory states in LSTMs to a spatiotemporal memory flow.
The spatiotemporal memory flow traverses all nodes of PredRNN in a zigzag path of bi-directional hierarchies: At each timestamp, the low-level information is delivered vertically from the input to the output via a newly-designed memory cell, while at the top layer, the spatiotemporal memory flow brings the high-level memory state to the bottom layer at the next timestamp. By this means, the top-down expectations interact with the bottom-up signals to both analyse those inputs and generate predictions of subsequent expected inputs, which is different from stacked LSTMs, where the memory state is latched inside each individual recurrent unit.
Accordingly, we define the central building block of PredRNN as the Spatiotemporal LSTM (ST-LSTM), in which the proposed spatiotemporal memory flow interacts with the original, unidirectional memory state of LSTMs. The intuition is that if we expect a vivid imagination of multiple future images, we need a unified memory mechanism to cope with both short-term deformations of spatial details and long-term dynamics: On one hand, the new design of the spatiotemporal memory cell enables the network to learn complex transition functions within short neighborhoods of consecutive frames by increasing the depth of non-linear neurons between time-adjacent RNN states. It thus significantly improves the modeling capability of ST-LSTM for short-term dynamics. On the other hand, ST-LSTM still retains the temporal memory cell of LSTMs and closely combines it with the newly proposed spatiotemporal memory cell, in pursuit of both long-term coherence of hidden states and their quick response to short-term dynamics.
This journal paper extends our previous work [wang2017predrnn] in a number of aspects. First, we demonstrate that decoupling the interlayer spatiotemporal memory and the inner-layer temporal memory in latent space can greatly improve the predictive ability of PredRNN. This is largely inspired by the theoretical argument that distributed representations can bring a potentially exponential advantage [pascanu2013number] if this matches properties of the underlying data distribution. In other words, we explicitly train the two memory cells to focus on different aspects of spatiotemporal variations. Second, we introduce a new curriculum learning strategy inspired by scheduled sampling [bengio2015scheduled] to improve the consistency of sequence-to-sequence modeling. The motivation is that, for sequence prediction, the current approach to training RNNs consists of maximizing the likelihood of each frame in the sequence given the previous ground truth and the previous generation. With scheduled sampling, the training of the forecasting part of RNNs gently shifts from a fully guided scheme that uses the true previous frame towards a less guided scheme that mostly uses the generated one instead. However, scheduled sampling only considers the discrepancy between training and inference for the forecasting part of RNNs, while the encoding part always takes the true frames of the input sequence as the prediction context. Such a training approach prevents the encoder from learning more complex, long-term non-Markovian dependencies. To solve this problem, we propose a “reverse” curriculum learning strategy for sequence prediction. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground truth.
Our approach achieves state-of-the-art performance on three datasets for spatiotemporal prediction: a synthetic dataset with moving digits, the KTH dataset with human motions, and a precipitation forecasting dataset with natural radar echo sequences. In addition, we perform more ablation studies to understand the effectiveness of each component of PredRNN. We release the code of our approach to facilitate future research (available at https://github.com/thuml/predrnn-pytorch).
For spatiotemporal predictive learning, different inductive biases are encoded into deep networks through different network architectures. In general, the mainstream models can be divided into three groups: convolutional models [Oh2015Action, Mathieu2015Deep, tulyakov2018mocogan], recurrent models [srivastava2015unsupervised, shi2015convolutional, babaeizadeh2017stochastic], and others, including Transformer-like autoregressive models and flow-based methods [weissenborn2019scaling, kumar2019videoflow].
The use of convolutional layers has introduced the inductive bias of group invariance over space to spatiotemporal predictive learning. Oh et al. [Oh2015Action] defined an action-conditioned autoencoder with convolutions for next-frame prediction in Atari games. Later on, in particular video prediction scenarios, Generative Adversarial Networks (GANs) [Goodfellow2014Generative, Denton2015Deep] were used jointly with convolutional architectures [Mathieu2015Deep, bhattacharjee2017temporal, liang2017dual, tulyakov2018mocogan, wu2020future], which managed to reduce the uncertainty of the learning process and thus improve the sharpness of the generated frames. By stacking convolutional layers, these models learn complicated state transition functions as compositions of simpler ones, and are thus able to capture very short-term variations between adjacent frame pairs.
Recent advances in RNNs provide some useful insights on how to predict future visual sequences based on historical observations. Ranzato et al. [Ranzato2014Video] defined an RNN architecture inspired by language modeling, predicting future frames in a discrete space of patch clusters. Srivastava et al. [srivastava2015unsupervised] adopted a sequence-to-sequence LSTM model from neural machine translation [Sutskever2011Generating] as a solution to video prediction. Denton et al. [denton2017unsupervised] introduced a new adversarial loss in the recurrent predictive learning framework to learn representations that can factorize each frame into a stationary part and a temporally varying component. Babaeizadeh et al. [babaeizadeh2017stochastic] combined RNNs and Variational AutoEncoders (VAEs) to cope with the uncertainty of spatiotemporal sequences, i.e., the multi-modal mappings from historical observations to future frames. Many other RNN-based approaches have been proposed to learn disentangled representations from videos in an unsupervised predictive manner [Villegas2017Decomposing, hsieh2018learning, guen2020disentangling], or to provide probabilistic solutions for modeling past-future uncertainties [denton2018stochastic, lee2018stochastic, villegas2019high, castrejon2019improved, franceschi2020stochastic]. The main contribution of this paper is to improve the basic architecture and the training approach of RNN-based predictive models.
To combine the advantages of convolutions and recurrent architectures, Shi et al. [shi2015convolutional] replaced matrix multiplication with convolutions in both the input-to-state and state-to-state transitions of the original LSTM unit [hochreiter1997long]. The Convolutional LSTM (ConvLSTM) layer unifies the learning processes of visual appearances and temporal dynamics, and provides a foundation for follow-up research [patraucean2015spatio, Finn2016Unsupervised, Lotter2016Deep, Kalchbrenner2016Video, shi2017deep, wang2018predrnn++, byeon2018contextvp, oliu2018folded, xu2018structure, wang2019memory, wang2019eidetic, su2020convolutional]. Finn et al. [Finn2016Unsupervised] designed a network based on ConvLSTMs for visual planning and control. Lotter et al. [Lotter2016Deep] presented a deep network based on ConvLSTMs for next-frame prediction (instead of sequence prediction). In addition to the generated image, it also learns to produce an error term that is then fed back into the network for future prediction. Kalchbrenner et al. [Kalchbrenner2016Video] proposed the Video Pixel Network (VPN), which uses ConvLSTMs and PixelCNNs [van2016conditional] to unfold an image into a sequence of pixels and capture the correlations between pixel values. This method achieves good performance on synthetic datasets such as Moving MNIST, but it also suffers from high computational complexity. Shi et al. [shi2017deep] improved their previous work with TrajGRU, which allows non-local receptive fields in the convolutional state-to-state transitions. This approach was shown to be particularly effective for precipitation nowcasting. More recently, Su et al. [su2020convolutional] further extended ConvLSTM to long-term prediction with a mechanism that combines convolutional features across time efficiently in terms of computation and memory.
To sum up, due to different inductive biases in the architecture design, the above deep networks have different preferences for specific properties of spatiotemporal variations. Deep convolutional networks implicitly assume complex changes in spatial appearance [Oh2015Action, Mathieu2015Deep, tulyakov2018mocogan], as well as the Markov property over time (also characterized as “memorylessness” [gagniuc2017markov]), and may therefore fall short in learning long-term dependencies in dynamic systems. Recurrent networks learn to model spatiotemporal non-Markovian properties with LSTMs [srivastava2015unsupervised, villegas2017learning, denton2018stochastic, wichers2018hierarchical], ConvLSTMs [shi2015convolutional, babaeizadeh2017stochastic], or temporal skip connections [ebert2017self]. However, the quality of the generated images is largely limited by a suboptimal combination of convolutions and recurrent state transitions. In this paper, we improve upon prior work and mainly discuss how to leverage a new set of recurrent memory states (i.e., the spatiotemporal memory flow) to unify the inductive biases of convolutions, deep networks, and recurrent networks in short- and long-term dynamics modeling.
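Suppose we are monitoring a dynamical system (e.g., a video clip) of $P$ measurements over time, where each measurement (e.g., an RGB channel) is recorded at all locations in a spatial region represented by an $M \times N$ grid (e.g., video frames). From the spatial view, the observation of these $P$ measurements at any time can be represented by a tensor $\mathcal{X} \in \mathbb{R}^{P \times M \times N}$. To get a better picture of the 3D tensors, we may imagine them as vectors standing on a spatial grid. From the temporal view, the observations over $T$ timestamps form a sequence of tensors, which can be denoted as $\mathcal{X}_{\text{in}} = \{\mathcal{X}_1, \dots, \mathcal{X}_T\}$. The spatiotemporal predictive learning problem is to predict the most probable length-$K$ sequence in the future, $\widehat{\mathcal{X}}_{\text{out}} = \{\widehat{\mathcal{X}}_{T+1}, \dots, \widehat{\mathcal{X}}_{T+K}\}$, given the observation sequence $\mathcal{X}_{\text{in}}$. In this paper, we train neural networks parametrized by $\theta$ to solve such tasks. Concretely, we use stochastic gradient descent to find a set of parameters $\theta^*$ that maximizes the log likelihood of producing the true target sequence $\mathcal{X}_{\text{out}}$ given the input data $\mathcal{X}_{\text{in}}$ for all training pairs $\{(\mathcal{X}_{\text{in}}^n, \mathcal{X}_{\text{out}}^n)\}_n$:
$$\theta^* = \arg\max_{\theta} \sum_{n} \log P\left(\mathcal{X}_{\text{out}}^n \mid \mathcal{X}_{\text{in}}^n; \theta\right) \tag{1}$$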
In this paper, we take video prediction as a typical predictive learning domain, where the observed data at each timestamp is an RGB video frame, so that the number of measured channels is $P = 3$. Another experimental domain is precipitation nowcasting, where the observed data at each timestamp is a gray-scale radar echo map in a certain area. (We usually use a color transform method based on radar echo intensity to visualize the observed/forecasted data as a color image.)
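As a minimal sketch of this tensor layout, a training clip of $T + K$ frames can be split into the observed sequence $\mathcal{X}_{\text{in}}$ and the target sequence $\mathcal{X}_{\text{out}}$. It assumes NumPy, a channels-first `(time, P, M, N)` array, a hypothetical helper name `split_clip`, and illustrative sizes ($T = K = 10$, $64 \times 64$ RGB frames) that are not taken from this section:

```python
import numpy as np


def split_clip(clip, t_in):
    """Split a (T + K, P, M, N) clip into X_in (first T frames) and X_out (last K frames)."""
    if clip.ndim != 4:
        raise ValueError("expected a (time, P, M, N) array")
    if not 0 < t_in < clip.shape[0]:
        raise ValueError("need at least one input frame and one target frame")
    return clip[:t_in], clip[t_in:]


# Hypothetical sizes: T = 10 observed and K = 10 future RGB frames (P = 3) of 64 x 64 pixels.
T, K, P, M, N = 10, 10, 3, 64, 64
clip = np.random.default_rng(0).random((T + K, P, M, N), dtype=np.float32)
x_in, x_out = split_clip(clip, T)
print(x_in.shape, x_out.shape)  # (10, 3, 64, 64) (10, 3, 64, 64)
```

For the gray-scale radar maps the same split applies; only the channel dimension changes.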
Compared with the standard LSTM, which consists of fully-connected layers, the Convolutional LSTM (ConvLSTM) [shi2015convolutional] can model the spatial and temporal data structures simultaneously by explicitly encoding the spatial information into tensors and applying convolution operators to both the state-to-state and input-to-state recurrent transitions. It overcomes the limitation of vector-variate representations in standard LSTM, where the spatial correlation is not explicitly modeled. In a ConvLSTM layer, the input state $\mathcal{X}_t$, memory state $\mathcal{C}_t$, hidden state $\mathcal{H}_t$, and gated signals $i_t$, $f_t$, $g_t$, $o_t$ at each timestamp are all 3D tensors in $\mathbb{R}^{P \times M \times N}$. The first dimension is either the number of measurements (for input states) or the number of feature maps (otherwise), and the last two dimensions are the numbers of spatial rows $M$ and columns $N$. ConvLSTM determines the future state of a certain cell in the grid based on the input frame and the past states of its local neighbors:
$$\begin{aligned}
g_t &= \tanh(W_{xg} * \mathcal{X}_t + W_{hg} * \mathcal{H}_{t-1} + b_g)\\
i_t &= \sigma(W_{xi} * \mathcal{X}_t + W_{hi} * \mathcal{H}_{t-1} + W_{ci} \odot \mathcal{C}_{t-1} + b_i)\\
f_t &= \sigma(W_{xf} * \mathcal{X}_t + W_{hf} * \mathcal{H}_{t-1} + W_{cf} \odot \mathcal{C}_{t-1} + b_f)\\
\mathcal{C}_t &= f_t \odot \mathcal{C}_{t-1} + i_t \odot g_t\\
o_t &= \sigma(W_{xo} * \mathcal{X}_t + W_{ho} * \mathcal{H}_{t-1} + W_{co} \odot \mathcal{C}_t + b_o)\\
\mathcal{H}_t &= o_t \odot \tanh(\mathcal{C}_t)
\end{aligned} \tag{2}$$
where $\sigma$ is the sigmoid activation function, and $*$ and $\odot$ denote the convolution operator and the Hadamard product, respectively. As in the standard LSTM, the input gate $i_t$, forget gate $f_t$, output gate $o_t$, and input-modulation gate $g_t$ control the information flow across the memory state $\mathcal{C}_t$. In this way, the gradient is kept from quickly vanishing by being trapped in the memory cell. There are three points in the design of the ConvLSTM network that could be further improved.
Challenge 1: For a stacked ConvLSTM network, the input sequence is fed into the bottom layer, and the output sequence is generated at the top one. With hidden states being delivered from the bottom up, spatial representations are encoded layer by layer. However, the memory states are merely updated along the arrow of time within each ConvLSTM layer, being less dependent on the hierarchical visual features at other layers. Thus, the first layer at the current timestamp may largely ignore what had been memorized by the top layer at the previous timestamp. In this paper, we solve this problem by introducing an interlayer memory cell and show that it can significantly benefit the predictive learning of spatiotemporal data.
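To make Eq. (2) concrete, here is a minimal NumPy sketch of a single ConvLSTM step. It is an illustration under stated assumptions, not the reference implementation: the helper names (`conv2d_same`, `convlstm_step`, `init_params`), the zero "same" padding, the 3 × 3 kernels, and the random initialization are hypothetical choices. As in deep-learning libraries, "convolution" is computed as cross-correlation. The peephole weights $W_{ci}$, $W_{cf}$, $W_{co}$ enter through the Hadamard product, so they must match the spatial size of $\mathcal{C}_t$.

```python
import numpy as np


def conv2d_same(x, w):
    """Zero-padded 'same' 2-D cross-correlation, as used by deep-learning conv layers.

    x: (C_in, M, N) feature maps; w: (C_out, C_in, k, k) kernel with odd k.
    """
    c_out, _, k, _ = w.shape
    pad = k // 2
    m, n = x.shape[1:]
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((c_out, m, n))
    for di in range(k):
        for dj in range(k):
            # Accumulate the contribution of kernel tap (di, dj) at every output pixel.
            out += np.einsum("oc,cmn->omn", w[:, :, di, dj], xp[:, di:di + m, dj:dj + n])
    return out


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def convlstm_step(x, h_prev, c_prev, p):
    """One step of Eq. (2); returns (H_t, C_t)."""
    g = np.tanh(conv2d_same(x, p["W_xg"]) + conv2d_same(h_prev, p["W_hg"]) + p["b_g"])
    i = sigmoid(conv2d_same(x, p["W_xi"]) + conv2d_same(h_prev, p["W_hi"]) + p["W_ci"] * c_prev + p["b_i"])
    f = sigmoid(conv2d_same(x, p["W_xf"]) + conv2d_same(h_prev, p["W_hf"]) + p["W_cf"] * c_prev + p["b_f"])
    c = f * c_prev + i * g
    o = sigmoid(conv2d_same(x, p["W_xo"]) + conv2d_same(h_prev, p["W_ho"]) + p["W_co"] * c + p["b_o"])
    return o * np.tanh(c), c


def init_params(c_in, hid, m, n, k=3, seed=0):
    """Hypothetical small random initialization for all weights of Eq. (2)."""
    rng = np.random.default_rng(seed)
    p = {}
    for gate in "gifo":
        p[f"W_x{gate}"] = 0.1 * rng.standard_normal((hid, c_in, k, k))
        p[f"W_h{gate}"] = 0.1 * rng.standard_normal((hid, hid, k, k))
        p[f"b_{gate}"] = np.zeros((hid, 1, 1))
    for gate in "ifo":
        # Peephole weights act via the Hadamard product, so they share C_t's shape.
        p[f"W_c{gate}"] = 0.1 * rng.standard_normal((hid, m, n))
    return p


P, HID, M, N = 1, 8, 16, 16
params = init_params(P, HID, M, N)
h = np.zeros((HID, M, N))
c = np.zeros((HID, M, N))
for x in np.random.default_rng(1).random((5, P, M, N)):  # five input frames
    h, c = convlstm_step(x, h, c, params)
print(h.shape)  # (8, 16, 16)
```

Because the peephole weights have a fixed spatial size, this cell only runs at the resolution it was initialized for; replacing the Hadamard terms with convolutions removes that restriction.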
Challenge 2: In each ConvLSTM layer, the output hidden state $\mathcal{H}_t$ depends on the memory state $\mathcal{C}_t$ and the output gate $o_t$, which means that the memory cell is forced to cope with long-term and short-term dynamics simultaneously. Therefore, the limited capability of the memory state transition function to model complex spatiotemporal variations may greatly constrain the overall performance of the predictive model.
Challenge 3: The ConvLSTM adopts the sequence-to-sequence RNN architecture [Sutskever2014Sequence], which might be harmed by the training discrepancy between the encoder and the forecaster (w.r.t. input/output sequences) in the predictive learning context.
To solve the above problems, we propose the Predictive Recurrent Neural Network (PredRNN), which consists of a new model and a new training scheme for generic sequence-to-sequence architectures. In this section, we first tackle Challenge 1 by presenting the spatiotemporal memory flow, which improves the state-to-state transition functions of the original ConvLSTM network and facilitates the modeling of short-term dynamics. Next, we give detailed descriptions about the Spatiotemporal LSTM (ST-LSTM), which serves as the building block of PredRNN. It provides a solution to Challenge 2 with a pair of memory cells that are jointly learned and explicitly decoupled to cover different aspects (e.g., long- and short-term dynamics) of spatiotemporal variations. Last but not least, we present the improved training method of PredRNN, which partly solves Challenge 3 by reducing the training discrepancy of the sequence-to-sequence model and greatly encourages the predictive network to learn non-Markovian dynamics from longer input contexts.
PredRNN employs a stack of convolutional recurrent units to learn unified representations of the spatial correlations and temporal dynamics from the input sequence, and then transforms these features back to the data space to predict future spatiotemporal frames. We initially adopt the original ConvLSTM layer as the basic building block of PredRNN and apply a novel memory state transition method between the recurrent nodes. In the original ConvLSTM network [shi2015convolutional] shown in Fig. 1 (right), the memory states are constrained inside individual recurrent layers and updated along the arrow of time. Only the hidden states $\mathcal{H}_t^l$ can be delivered from the observation to the final prediction. This temporal memory flow has been widely used in supervised learning tasks such as video classification, where the hidden representations become increasingly abstract and class-specific from the bottom up. However, in predictive learning, the input frames and the expected outputs share the same data space, i.e., they may have very close data distributions in the spatial domain and closely related underlying dynamics in the temporal domain. Therefore, it becomes important to make the predictions more effectively dependent on the memory representations learned at different levels of recurrent layers. If we want to frame the future vividly, we need both a high-level understanding of global motions and more detailed memories of the subtle variations in the input sequence.
Considering that the memory cell of ConvLSTM can latch the gradients (like that of the standard LSTM, it was originally designed to alleviate the vanishing gradient problem during training) and thereby store valuable information across recurrent nodes, we improve the above horizontal memory flow by updating the memory state in a zigzag direction, so that it can better deliver knowledge from input to output. We name this new memory state transition method the spatiotemporal memory flow, and show its transition direction by the orange arrows in Fig. 1 (left). In this way, the memory states in different ConvLSTM layers are no longer independent, and all nodes in the entire recurrent network jointly maintain a memory bank denoted by $\mathcal{M}$. We show here the key equations of the spatiotemporal memory flow using ConvLSTM as the building block:
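$$\begin{aligned}
g_t &= \tanh(W_{xg} * \mathcal{X}_t \mathbb{1}_{\{l=1\}} + W_{hg} * \mathcal{H}_t^{l-1} + b_g)\\
i_t &= \sigma(W_{xi} * \mathcal{X}_t \mathbb{1}_{\{l=1\}} + W_{hi} * \mathcal{H}_t^{l-1} + W_{mi} * \mathcal{M}_t^{l-1} + b_i)\\
f_t &= \sigma(W_{xf} * \mathcal{X}_t \mathbb{1}_{\{l=1\}} + W_{hf} * \mathcal{H}_t^{l-1} + W_{mf} * \mathcal{M}_t^{l-1} + b_f)\\
\mathcal{M}_t^l &= f_t \odot \mathcal{M}_t^{l-1} + i_t \odot g_t\\
o_t &= \sigma(W_{xo} * \mathcal{X}_t \mathbb{1}_{\{l=1\}} + W_{ho} * \mathcal{H}_t^{l-1} + W_{mo} * \mathcal{M}_t^l + b_o)\\
\mathcal{H}_t^l &= o_t \odot \tanh(\mathcal{M}_t^l)
\end{aligned} \tag{3}$$
where the indicator $\mathbb{1}_{\{l=1\}}$ means that the input frame $\mathcal{X}_t$ is fed only into the bottom layer. The input gate, input modulation gate, forget gate, and output gate are no longer dependent on the hidden state and the temporal memory state from the previous timestamp at the same layer. Instead, they rely on the hidden state $\mathcal{H}_t^{l-1}$ and the spatiotemporal memory state $\mathcal{M}_t^{l-1}$ supplied by the previous layer at the current timestamp (see Fig. 1 (left)). In particular, the bottom recurrent unit ($l = 1$) receives state values from the top layer ($l = L$) at the previous timestamp: $\mathcal{H}_t^0 = \mathcal{H}_{t-1}^L$, $\mathcal{M}_t^0 = \mathcal{M}_{t-1}^L$. The four layers in this figure have different sets of convolutional parameters for the input-to-state and state-to-state transitions. They thereby read and update the values of the memory state based on their individual understanding of the spatiotemporal dynamics as the information flows through the current node. Note that we replace the notation for the memory state from $\mathcal{C}$ to $\mathcal{M}$ to emphasize that it flows in the zigzag direction in PredRNN, instead of the horizontal direction in standard recurrent networks. Unlike ConvLSTM, which uses the Hadamard product for the computation of the gates, we adopt convolution operators for finer-grained memory transitions. An additional benefit of this change is that the learned PredRNN can be deployed directly on input sequences of different spatial resolutions.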
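The zigzag update order of Eq. (3) can be sketched in NumPy as follows. This is a minimal illustration under assumptions, not the released code: the helper names, 3 × 3 kernels, zero initial states, and random weights are all hypothetical. The `visits` list records the order in which nodes write to $\mathcal{M}$, which makes the "first upwards across layers, then forwards in time" path explicit.

```python
import numpy as np


def conv2d_same(x, w):
    """Zero-padded 'same' 2-D cross-correlation; x: (C_in, M, N), w: (C_out, C_in, k, k), odd k."""
    c_out, _, k, _ = w.shape
    pad = k // 2
    m, n = x.shape[1:]
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((c_out, m, n))
    for di in range(k):
        for dj in range(k):
            out += np.einsum("oc,cmn->omn", w[:, :, di, dj], xp[:, di:di + m, dj:dj + n])
    return out


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def zigzag_unit(x, h_below, m_below, p):
    """One node of Eq. (3). `x` is X_t at the bottom layer and None above it."""
    def from_input(name):
        return conv2d_same(x, p[name]) if x is not None else 0.0

    g = np.tanh(from_input("W_xg") + conv2d_same(h_below, p["W_hg"]) + p["b_g"])
    i = sigmoid(from_input("W_xi") + conv2d_same(h_below, p["W_hi"])
                + conv2d_same(m_below, p["W_mi"]) + p["b_i"])
    f = sigmoid(from_input("W_xf") + conv2d_same(h_below, p["W_hf"])
                + conv2d_same(m_below, p["W_mf"]) + p["b_f"])
    m = f * m_below + i * g
    o = sigmoid(from_input("W_xo") + conv2d_same(h_below, p["W_ho"])
                + conv2d_same(m, p["W_mo"]) + p["b_o"])
    return o * np.tanh(m), m


def init_layer(rng, hid, c_in=None, k=3):
    """Hypothetical random weights for one layer; only the bottom layer gets W_x*."""
    def w(c_src):
        return 0.1 * rng.standard_normal((hid, c_src, k, k))

    p = {f"W_h{gate}": w(hid) for gate in "gifo"}
    p.update({f"W_m{gate}": w(hid) for gate in "ifo"})
    p.update({f"b_{gate}": np.zeros((hid, 1, 1)) for gate in "gifo"})
    if c_in is not None:
        p.update({f"W_x{gate}": w(c_in) for gate in "gifo"})
    return p


def run_zigzag(frames, layers):
    """Unroll Eq. (3); returns top-layer hidden states and the order in which M is written."""
    hid = layers[0]["W_hg"].shape[0]
    h_top = np.zeros((hid,) + frames.shape[2:])  # H_0^L, zero-initialized
    m_top = np.zeros_like(h_top)                  # M_0^L, zero-initialized
    outputs, visits = [], []
    for t, x in enumerate(frames, start=1):
        h, m = h_top, m_top  # H_t^0 = H_{t-1}^L and M_t^0 = M_{t-1}^L
        for layer_idx, p in enumerate(layers, start=1):
            h, m = zigzag_unit(x if layer_idx == 1 else None, h, m, p)
            visits.append((t, layer_idx))
        h_top, m_top = h, m
        outputs.append(h_top)
    return np.stack(outputs), visits


rng = np.random.default_rng(0)
P, HID, M, N, L, T = 1, 4, 8, 8, 2, 3
layers = [init_layer(rng, HID, c_in=P)] + [init_layer(rng, HID) for _ in range(L - 1)]
out, visits = run_zigzag(rng.random((T, P, M, N)), layers)
print(out.shape)  # (3, 4, 8, 8)
print(visits)     # [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)]
```

Since every weight here is a convolution kernel, nothing in the sketch depends on $M$ or $N$, which mirrors the resolution-independence point made above.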
The spatiotemporal memory flow provides a recurrent highway for the hierarchical visual representations that can reduce the loss of information from the bottom layers to the top of the network. Besides, by introducing more nonlinear gated neurons within temporal neighborhoods, it expands the deep transition path of hidden states and enhances the modeling capability of the network for short-term dynamics. In contrast, the original ConvLSTM network requires larger convolution kernels for input-to-state and state-to-state transitions in order to capture faster motion, resulting in an unnecessary increase in model parameters.
Alternatively, we can understand the spatiotemporal memory flow from the perspective of memory networks [graves2014neural, sukhbaatar2015end, graves2016hybrid]. Here, the proposed spatiotemporal memory state can be viewed as a simple version of the so-called external memory with continuous memory representations, and the stacked recurrent layers can be viewed as multiple computational steps. The layer-wise forget gates, input gates, and input modulation gates respectively determine the read and write functions, as well as the content to be written. One advantage of the classic memory networks is to capture long-term structure (even with multiple temporal hops) within sequences. Our spatiotemporal memory flow is analogous to their mechanism, as it enables our model to consider different levels of video representations before making a prediction.
As described previously, the spatiotemporal memory state is updated first upwards across layers and then forwards to the next timestamp. It stretches the state transition path across time and adds extra neurons between horizontally adjacent nodes at the same level of the network. It thus enables the network to learn complex non-linear transition functions of short-term motions. However, this deep-in-time architecture may also bring back the vanishing gradient problem: the roundabout memory transition path may make it difficult to capture long-term temporal dependencies. In pursuit of both short-term recurrence depth and long-term coherence, we introduce a double-flow memory transition mechanism that combines the original memory flow of $\mathcal{C}$ and the new spatiotemporal memory flow of $\mathcal{M}$. This yields a novel recurrent unit named Spatiotemporal LSTM (ST-LSTM).
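In Fig. 2, we present the final PredRNN model, which takes ST-LSTM as the building block in place of ConvLSTM. There are two memory states: $\mathcal{C}_t^l$ is the temporal memory that transits within each ST-LSTM layer from the previous node at $t-1$ to the current timestamp; $\mathcal{M}_t^l$ is the spatiotemporal memory, which transits vertically to the current node from the lower ST-LSTM layer at the same timestamp. In particular, we assign $\mathcal{M}_{t-1}^L$ to $\mathcal{M}_t^0$ for the bottom ST-LSTM, where $l = 1$. We adopt the original gated structures for $\mathcal{C}_t^l$ from the standard LSTM, and construct another set of input gate $i'_t$, forget gate $f'_t$, and input modulation gate $g'_t$ for $\mathcal{M}_t^l$, because the memory transition functions in distinct directions are supposed to be controlled by different signals. Specifically, in each ST-LSTM, we compute:
$$\begin{aligned}
g_t &= \tanh(W_{xg} * \mathcal{X}_t + W_{hg} * \mathcal{H}_{t-1}^l + b_g)\\
i_t &= \sigma(W_{xi} * \mathcal{X}_t + W_{hi} * \mathcal{H}_{t-1}^l + b_i)\\
f_t &= \sigma(W_{xf} * \mathcal{X}_t + W_{hf} * \mathcal{H}_{t-1}^l + b_f)\\
\mathcal{C}_t^l &= f_t \odot \mathcal{C}_{t-1}^l + i_t \odot g_t\\
g'_t &= \tanh(W'_{xg} * \mathcal{X}_t + W_{mg} * \mathcal{M}_t^{l-1} + b'_g)\\
i'_t &= \sigma(W'_{xi} * \mathcal{X}_t + W_{mi} * \mathcal{M}_t^{l-1} + b'_i)\\
f'_t &= \sigma(W'_{xf} * \mathcal{X}_t + W_{mf} * \mathcal{M}_t^{l-1} + b'_f)\\
\mathcal{M}_t^l &= f'_t \odot \mathcal{M}_t^{l-1} + i'_t \odot g'_t\\
o_t &= \sigma(W_{xo} * \mathcal{X}_t + W_{ho} * \mathcal{H}_{t-1}^l + W_{co} * \mathcal{C}_t^l + W_{mo} * \mathcal{M}_t^l + b_o)\\
\mathcal{H}_t^l &= o_t \odot \tanh(W_{1\times 1} * [\mathcal{C}_t^l, \mathcal{M}_t^l])
\end{aligned} \tag{4}$$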
The final hidden states of each node depend on a combination of the horizontal and zigzag memory states: we concatenate $\mathcal{C}_t^l$ and $\mathcal{M}_t^l$, and then apply a $1 \times 1$ convolutional layer for dimensionality reduction, which gives the hidden state the same dimensions as the memory states. In addition to simple concatenation, pairs of memory states are finally twisted and unified by output gates with bidirectional control signals (horizontal and vertical), resulting in comprehensive modeling of long-term and short-term dynamics. This dual-memory mechanism benefits from the compositional advantage of distributed representations [hinton1984distributed, bengio2000taking]. Intuitively, due to different mechanisms of recurrent state transitions, the pair of memory cells in ST-LSTM are expected to deal with different aspects of spatiotemporal variations:
$\mathcal{M}_t^l$ introduces a deeper transition path that zigzags across ST-LSTM layers. With the forget gate $f'_t$ and the input-related modules $i'_t \odot g'_t$, it improves the ability to model complex short-term dynamics from one timestamp to the next, and allows $\mathcal{M}_t^l$ to transit adaptively at different rates.
$\mathcal{C}_t^l$ operates on a slower timescale. It provides a shorter gradient path between distant hidden states, thus facilitating the learning of long-term dependencies.
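A single ST-LSTM update following Eq. (4) can be sketched in NumPy as below. Several details are assumptions rather than facts from this section: the helper names and weight keys are hypothetical, 3 × 3 kernels with zero padding are used, and the layer input `x` is $\mathcal{X}_t$ at the bottom layer (for higher layers we assume, as in typical stacked recurrent networks, that it is $\mathcal{H}_t^{l-1}$). The step also returns the raw increments $i_t \odot g_t$ and $i'_t \odot g'_t$, the quantities on which the memory decoupling loss operates.

```python
import numpy as np


def conv2d_same(x, w):
    """Zero-padded 'same' 2-D cross-correlation; x: (C_in, M, N), w: (C_out, C_in, k, k), odd k."""
    c_out, _, k, _ = w.shape
    pad = k // 2
    m, n = x.shape[1:]
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((c_out, m, n))
    for di in range(k):
        for dj in range(k):
            out += np.einsum("oc,cmn->omn", w[:, :, di, dj], xp[:, di:di + m, dj:dj + n])
    return out


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def st_lstm_step(x, h_prev, c_prev, m_below, p):
    """Eq. (4): returns H_t^l, C_t^l, M_t^l and the increments i*g and i'*g'."""
    conv = conv2d_same
    # Temporal memory C: gates read the layer input and H_{t-1}^l.
    g = np.tanh(conv(x, p["W_xg"]) + conv(h_prev, p["W_hg"]) + p["b_g"])
    i = sigmoid(conv(x, p["W_xi"]) + conv(h_prev, p["W_hi"]) + p["b_i"])
    f = sigmoid(conv(x, p["W_xf"]) + conv(h_prev, p["W_hf"]) + p["b_f"])
    delta_c = i * g
    c = f * c_prev + delta_c
    # Spatiotemporal memory M: a separate gate set reads the layer input and M_t^{l-1}.
    g2 = np.tanh(conv(x, p["W_xg2"]) + conv(m_below, p["W_mg"]) + p["b_g2"])
    i2 = sigmoid(conv(x, p["W_xi2"]) + conv(m_below, p["W_mi"]) + p["b_i2"])
    f2 = sigmoid(conv(x, p["W_xf2"]) + conv(m_below, p["W_mf"]) + p["b_f2"])
    delta_m = i2 * g2
    m = f2 * m_below + delta_m
    # The output gate sees both memories; a 1x1 conv fuses [C, M] back to `hid` channels.
    o = sigmoid(conv(x, p["W_xo"]) + conv(h_prev, p["W_ho"]) + conv(c, p["W_co"])
                + conv(m, p["W_mo"]) + p["b_o"])
    h = o * np.tanh(conv(np.concatenate([c, m]), p["W_11"]))
    return h, c, m, delta_c, delta_m


def init_params(c_in, hid, k=3, seed=0):
    """Hypothetical random initialization for the weights of Eq. (4)."""
    rng = np.random.default_rng(seed)

    def w(c_src, size=k):
        return 0.1 * rng.standard_normal((hid, c_src, size, size))

    p = {name: w(c_in) for name in ["W_xg", "W_xi", "W_xf", "W_xg2", "W_xi2", "W_xf2", "W_xo"]}
    p.update({name: w(hid) for name in
              ["W_hg", "W_hi", "W_hf", "W_ho", "W_mg", "W_mi", "W_mf", "W_co", "W_mo"]})
    p.update({name: np.zeros((hid, 1, 1)) for name in
              ["b_g", "b_i", "b_f", "b_g2", "b_i2", "b_f2", "b_o"]})
    p["W_11"] = w(2 * hid, size=1)  # 1x1 kernel mapping 2*hid -> hid channels
    return p


P, HID, M, N = 1, 4, 8, 8
params = init_params(P, HID)
zeros = np.zeros((HID, M, N))
x = np.random.default_rng(1).random((P, M, N))
h, c, m, dc, dm = st_lstm_step(x, zeros, zeros, zeros, params)
print(h.shape, c.shape, m.shape)  # (4, 8, 8) (4, 8, 8) (4, 8, 8)
```

With zero initial memories, $\mathcal{C}_t^l$ and $\mathcal{M}_t^l$ equal their increments, which is a convenient sanity check for an implementation.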
However, as shown in Fig. 3, when we visualized the increments of the memory states at each timestamp (i.e., $\Delta\mathcal{C}_t^l$ and $\Delta\mathcal{M}_t^l$) using t-SNE [maaten2008visualizing], we found that they were not automatically separated as expected. In fact, the two memory states are often so intertwined that they are difficult to decouple spontaneously through their respective network architectures. To some extent, this results in inefficient utilization of network parameters.
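Building upon the first version of ST-LSTM [wang2017predrnn], we present a new decoupling loss, as shown in Fig. 2 (right), that keeps $\mathcal{C}_t^l$ and $\mathcal{M}_t^l$ from learning redundant features. In each ST-LSTM unit, we first add a convolutional layer on top of the increments of $\mathcal{C}_t^l$ and $\mathcal{M}_t^l$ at every timestamp, and leverage a new decoupling loss to explicitly extend the distance between them in latent space. By this means, different memory states are trained to focus on different aspects of spatiotemporal variations. The overall memory decoupling method can be formulated as follows:
$$\begin{aligned}
\Delta\mathcal{C}_t^l &= W_{\text{decouple}} * (i_t \odot g_t)\\
\Delta\mathcal{M}_t^l &= W_{\text{decouple}} * (i'_t \odot g'_t)\\
\mathcal{L}_{\text{decouple}} &= \sum_t \sum_l \sum_c \frac{\left|\langle \Delta\mathcal{C}_t^l, \Delta\mathcal{M}_t^l \rangle_c\right|}{\left\|\Delta\mathcal{C}_t^l\right\|_c \cdot \left\|\Delta\mathcal{M}_t^l\right\|_c}
\end{aligned} \tag{5}$$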
where $W_{\text{decouple}}$ denotes $1 \times 1$ convolutions shared by all ST-LSTM layers, and $\langle \cdot, \cdot \rangle_c$ and $\|\cdot\|_c$ are respectively the dot product and the $\ell_2$ norm of flattened feature maps, calculated channel by channel. At training time, the increments of the memory states, $i_t \odot g_t$ and $i'_t \odot g'_t$, are derived from Eq. (4). Notably, the new parameters $W_{\text{decouple}}$ are only used at training time and are removed from the entire model at inference. That is, there is no increase in model size compared to the previous version of ST-LSTM [wang2017predrnn]. By defining the decoupling loss with the cosine similarity, our approach encourages the increments of the two memory states to be orthogonal at any timestamp. This unleashes the respective strengths of $\mathcal{C}_t^l$ and $\mathcal{M}_t^l$ for long-term and short-term dynamics modeling. As shown in Fig. 3, at test time, $\Delta\mathcal{C}_t^l$ and $\Delta\mathcal{M}_t^l$ can be easily separated by t-SNE.
The decoupling loss is largely inspired by the theoretical evidence that using a reasonable inductive bias to construct distributed representations can bring great performance improvements if it matches properties of the underlying data distribution [pascanu2013number]. There is a similar idea in ensemble learning: generally, to get a good ensemble, the base learners should be as accurate as possible and as diverse as possible [krogh1995neural]. The diversity of base learners can be enhanced in different ways, such as sub-sampling the training examples, manipulating the input attributes, and injecting randomness into training procedures [zhou2009ensemble]. Similarly, in adversarial learning, experimental results showed that diversifying the features of multiple discriminators can effectively improve the discriminability of the ensemble model, making it easier to discover the discrepancy between generated data and real data from diverse perspectives [jin2020multi]. It is worth noting that the proposed memory decoupling method has not been used by existing ensemble learning algorithms, though it is inspired by the general idea of enhancing the diversity of base learners. We use it to increase the distance between pairs of memory states of recurrent networks, in pursuit of a more disentangled representation of long-term and short-term dynamics in predictive learning.
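A minimal NumPy sketch of Eq. (5) follows. Assumptions: the shared $1 \times 1$ convolution $W_{\text{decouple}}$ is represented as a plain `(channels, channels)` matrix (a $1 \times 1$ convolution is exactly a channel-mixing matrix applied at every pixel), the increments arrive as lists of `(channels, M, N)` arrays with one entry per timestamp and layer, and a small `eps` (our addition, not from the text) guards against division by zero. As in Eq. (5), the cosine similarity is computed per channel over the flattened spatial map, and its absolute value is summed.

```python
import numpy as np


def decouple_loss(delta_cs, delta_ms, w_decouple, eps=1e-8):
    """Sum over (t, l) pairs and channels of |cos(W * dC, W * dM)|, as in Eq. (5)."""
    total = 0.0
    for dc, dm in zip(delta_cs, delta_ms):
        # Apply the shared 1x1 convolution as a channel-mixing matrix.
        a = np.einsum("oc,cmn->omn", w_decouple, dc)
        b = np.einsum("oc,cmn->omn", w_decouple, dm)
        a = a.reshape(a.shape[0], -1)  # flatten each channel's feature map
        b = b.reshape(b.shape[0], -1)
        cos = (a * b).sum(axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps)
        total += float(np.abs(cos).sum())
    return total


HID, M, N = 4, 4, 4
w = np.eye(HID)  # identity 1x1 conv keeps the example easy to check
dc = np.random.default_rng(0).random((HID, M, N))
top = np.zeros((HID, M, N))
top[:, :2] = 1.0     # non-zero only in rows 0-1
bottom = np.zeros((HID, M, N))
bottom[:, 2:] = 1.0  # non-zero only in rows 2-3
print(decouple_loss([dc], [dc], w))       # close to 4.0: identical increments in 4 channels
print(decouple_loss([top], [bottom], w))  # 0.0: orthogonal increments
```

Minimizing this term pushes the two increments toward orthogonality channel by channel; because $W_{\text{decouple}}$ only serves the loss, it can be dropped at inference.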
The final PredRNN model is trained end-to-end in an unsupervised manner. The overall training objective is:
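$$\mathcal{L} = \sum_{t=2}^{T+K} \left\| \widehat{\mathcal{X}}_t - \mathcal{X}_t \right\|_2^2 + \beta \, \mathcal{L}_{\text{decouple}} \tag{6}$$
where the first term is the frame reconstruction loss that works on the network output at each timestamp, the second term is the memory decoupling regularization from Eq. (5), and $\beta$ is a hyper-parameter that balances the two terms.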
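Eq. (6) can be sketched as a small NumPy function. We assume here that the reconstruction term is the summed squared error over the predicted frames $\widehat{\mathcal{X}}_2, \dots, \widehat{\mathcal{X}}_{T+K}$; practical implementations often use a mean instead, which only rescales this term relative to $\beta$. The function name is hypothetical, and the decoupling value is passed in as a precomputed number (for example, evaluated per Eq. (5)).

```python
import numpy as np


def predrnn_objective(pred, target, decouple, beta):
    """Eq. (6): sum_t ||X_hat_t - X_t||_2^2 + beta * L_decouple.

    pred, target: arrays of shape (T + K - 1, P, M, N) holding frames 2..T+K.
    """
    if pred.shape != target.shape:
        raise ValueError("prediction and target shapes differ")
    reconstruction = float(np.sum((pred - target) ** 2))
    return reconstruction + beta * decouple


pred = np.zeros((2, 1, 2, 2))
target = np.ones((2, 1, 2, 2))
print(predrnn_objective(pred, target, decouple=3.0, beta=0.5))  # 9.5
```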
We propose an improved scheduled sampling approach to facilitate the training of PredRNN. It consists of two components, used in the sequence encoding part and the forecasting part of PredRNN, respectively. For the encoding phase, we propose a new curriculum learning strategy named Reverse Scheduled Sampling, which has the following benefits to the training procedure:
It forces the encoder to learn temporal dynamics from longer periods of the input sequence;
It reduces the discrepancy between the training processes of the encoder and the forecaster.
As for the forecasting phase, we follow the common practice of scheduled sampling [bengio2015scheduled] to bridge the gap between training and inference of the forecaster.
Encoder-forecaster training discrepancy. As shown in Fig. 4, we follow the common practice [Finn2016Unsupervised, denton2018stochastic] of training a unified recurrent network, denoted by $\mathcal{F}$, for both sequence encoding and forecasting. The intuition is that, in pursuit of a perfect predictive model, the input frames and the output frames are supposed to be modeled in the same space. During inference, the model generates the full sequence one frame at a time, advancing time by one step after each frame. The temporal dynamics are estimated from a combination of previous states, denoted by $\mathbb{H}_{t-1} = \{\mathcal{H}_{t-1}, \mathcal{C}_{t-1}, \mathcal{M}_{t-1}\}$. Here, we leave out the layer index for simplicity, and these tensors are initialized with zeros. At each timestamp $t$, the next frame $\widehat{\mathcal{X}}_{t+1}$ is computed by $\mathcal{F}$ as follows:
$$
\mathbb{H}_t, \widehat{\mathcal{X}}_{t+1} =
\begin{cases}
\mathcal{F}(\mathcal{X}_t, \mathbb{H}_{t-1}), & t \le T, \\
\mathcal{F}(\widehat{\mathcal{X}}_t, \mathbb{H}_{t-1}), & t > T,
\end{cases}
\tag{7}
$$
where $T$ is the number of input frames, and the main difference between the encoder ($t \le T$) and the forecaster ($t > T$) is whether to use the previous true frame $\mathcal{X}_t$ or an estimate $\widehat{\mathcal{X}}_t$ coming from the model itself. We argue that, although widely employed in previous literature, using the same set of parameters may result in an encoder-forecaster training discrepancy. In the encoding phase, the model mainly learns to make one-step predictions, because $\mathcal{X}_t$ tends to be a more informative signal than $\mathbb{H}_{t-1}$. The availability of the true current frame may largely spare PredRNN the effort of digging deeper into the complex, non-Markovian properties of previous observations. In the forecasting phase, however, the model has to depend more on $\mathbb{H}_{t-1}$, because new observations are no longer accessible and only the memory states can store trustworthy long-term dynamics of spatiotemporal variations. The forecasting part is therefore trained on a more challenging task than the encoding part. If the model is well pre-trained, in the sense that the distribution of generated data is close to that of the true data, it may be fine to train $\mathcal{F}$ by implementing forward propagation as in Eq. (7). If it is not, especially in the early stage of training, the encoder-forecaster gap would lead to ineffective optimization of $\mathcal{F}$, hampering the model from learning long-term dynamics from $\mathbb{H}_{t-1}$.
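The encoder/forecaster switch in Eq. (7) can be sketched as a rollout loop around a single shared step function. The `step` interface (input frame and states in, new states and next-frame prediction out) and the toy scalar "frames" are hypothetical stand-ins used only to show which input is fed at each timestamp.

```python
from typing import Callable, List, Tuple, TypeVar

S = TypeVar("S")  # recurrent states (hidden state and memories), hypothetical
X = TypeVar("X")  # a frame, hypothetical


def unroll(step: Callable[[X, S], Tuple[S, X]], inputs: List[X],
           num_outputs: int, init_state: S) -> List[X]:
    """Roll out one shared network over encoder and forecaster steps, as in Eq. (7).

    inputs: the T true context frames. Returns predictions for frames 2..T+num_outputs.
    """
    if not inputs:
        raise ValueError("need at least one context frame")
    num_inputs = len(inputs)
    state = init_state
    preds: List[X] = []
    for t in range(num_inputs + num_outputs - 1):  # t is the 0-based index of the input frame
        # Encoder (t < T): feed the true frame. Forecaster: feed back the last prediction.
        x_in = inputs[t] if t < num_inputs else preds[-1]
        state, x_next = step(x_in, state)
        preds.append(x_next)
    return preds


def toy_step(x: int, s: int) -> Tuple[int, int]:
    """Toy model: accumulates inputs in its state and predicts 'current frame + 1'."""
    return s + x, x + 1


print(unroll(toy_step, inputs=[0, 1, 2], num_outputs=3, init_state=0))  # [1, 2, 3, 4, 5]
```

The same `step` parameters serve both phases; only the source of `x_in` changes, which is exactly where the encoder-forecaster discrepancy discussed above comes from.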
Scheduled sampling. The original scheduled sampling approach [bengio2015scheduled] can partly address the discrepancy between the training of the encoder and the forecaster and thereby improve recurrent models. A typical training procedure with scheduled sampling in the spatiotemporal prediction context [Finn2016Unsupervised] computes:
$$
\mathbb{H}_t, \widehat{\mathcal{X}}_{t+1} =
\begin{cases}
\mathcal{F}(\mathcal{X}_t, \mathbb{H}_{t-1}), & t \le T, \\
\mathcal{F}(\mathcal{X}_t \rightsquigarrow \widehat{\mathcal{X}}_t, \mathbb{H}_{t-1}), & t > T,
\end{cases}
\tag{8}
$$
where $\mathcal{X}_t \rightsquigarrow \widehat{\mathcal{X}}_t$ denotes a gradual change over training from taking the true frame as the input of the forecaster to taking the previous estimation instead. Whether $\mathcal{X}_t$ or $\widehat{\mathcal{X}}_t$ is used is decided randomly: we flip a coin to use $\mathcal{X}_t$ with probability $\epsilon_k$, or $\widehat{\mathcal{X}}_t$ with probability $1 - \epsilon_k$, where $k$ is the index of training iterations. $\epsilon_k$ is a function of $k$ that decreases linearly or exponentially, in a manner similar to the learning-rate decay used in modern stochastic gradient descent algorithms. When $\epsilon_k = 1$, at the very beginning of the training procedure, the forecaster is trained exactly as the encoder, while when $\epsilon_k$ drops to $0$, the entire model is trained in the same setting as the inference stage. This approach has two benefits. First, compared with the training method of Eq. (7), it exposes the model to the distribution of $\widehat{\mathcal{X}}_t$ during training. Second, it can partly bridge the encoder-forecaster gap from the forecaster side: intuitively, at the beginning of training, sampling from true frames enables the parameters to be optimized consistently over the two parts of the sequence. Still, an alternative way to bridge the encoder-forecaster gap is to modify the training approach on the encoding side of the model. Moreover, the original scheduled sampling method does not remedy the ineffective training of the encoder in learning long-term dynamics.
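The coin flip of the original scheduled sampling can be sketched as below, assuming a linearly decaying $\epsilon_k$ clipped at zero and one independent draw per forecaster step. The decay rate `1e-4` is a hypothetical value chosen for illustration, not a setting reported in this section.

```python
import random
from typing import List


def linear_decay(k: int, rate: float, start: float = 1.0, end: float = 0.0) -> float:
    """Probability eps_k of feeding the TRUE frame at iteration k, decreasing linearly."""
    return max(start - rate * k, end)


def forecaster_mask(k: int, num_steps: int, rate: float, rng: random.Random) -> List[bool]:
    """One independent coin flip per forecaster step: True = feed X_t, False = feed X_hat_t."""
    eps_k = linear_decay(k, rate)
    return [rng.random() < eps_k for _ in range(num_steps)]


rng = random.Random(0)
print(forecaster_mask(k=0, num_steps=4, rate=1e-4, rng=rng))      # [True, True, True, True]
print(forecaster_mask(k=20000, num_steps=4, rate=1e-4, rng=rng))  # [False, False, False, False]
```

At $k = 0$ the forecaster always sees true frames (as the encoder does); once $\epsilon_k$ reaches zero it always sees its own estimates, matching inference.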
Reverse scheduled sampling. To force the recurrent model to learn more temporal dynamics from consecutive input frames, we propose an alternative scheduled sampling method, which mainly improves the training strategy in the encoding phase:
$$
\mathbb{H}_t, \widehat{\mathcal{X}}_{t+1} =
\begin{cases}
\mathcal{F}(\widehat{\mathcal{X}}_t \rightsquigarrow \mathcal{X}_t, \mathbb{H}_{t-1}), & t \le T, \\
\mathcal{F}(\widehat{\mathcal{X}}_t, \mathbb{H}_{t-1}), & t > T,
\end{cases}
\tag{9}
$$
where $\widehat{\mathcal{X}}_t \rightsquigarrow \mathcal{X}_t$ indicates a curriculum learning schedule that gradually shifts from the previous estimation to the true frame, and $\mathbb{H}_{t-1}$ is the combination of previous states $\{\mathcal{H}_{t-1}, \mathcal{C}_{t-1}, \mathcal{M}_{t-1}\}$. In the encoding stage, there is a probability $\epsilon_k$ of sampling the true frame $\mathcal{X}_t$ and a corresponding probability $1 - \epsilon_k$ of sampling the previous estimation $\widehat{\mathcal{X}}_t$. For the entire encoder, the sequence of sampling outcomes can be seen as a Bernoulli process with independent trials. Different from Eq. (8), $\epsilon_k$ is now an increasing function of the training iteration $k$, starting from $\epsilon_s$ and increasing to $\epsilon_e$, in one of the following forms:
Linear increase: $\epsilon_k = \min(\epsilon_s + \alpha_l \times k,\ \epsilon_e)$;
Exponential increase: $\epsilon_k = \epsilon_e - (\epsilon_e - \epsilon_s) \times \exp(-k / \alpha_e)$;
Sigmoid increase: $\epsilon_k = \epsilon_s + (\epsilon_e - \epsilon_s) \times \dfrac{1}{1 + \exp\left((\beta_s - k) / \alpha_s\right)}$,
where $\alpha_l$, $\alpha_e$, and $\alpha_s$ denote the increasing factors and $\beta_s$ denotes the starting point of the sigmoid function. These hyper-parameters jointly decide the increasing curve of $\epsilon_k$. Examples of such schemes are shown by the red curves in Fig. 5. We name this approach Reverse Scheduled Sampling, as the encoder is trained with a progressively simplified curriculum: it gradually moves from generating multi-step future frames, which is challenging due to the absence of some historical observations, to making one-step predictions, just as the encoder does at test time. Intuitively, this method encourages the model to extract long-term, non-Markovian dynamics from the input sequence. Besides, it provides an alternative solution to the encoder-forecaster training discrepancy at the encoder end: at the early stage of training, both parts of the sequence-to-sequence model are optimized consistently under similar settings (i.e., with a higher probability of using the previous estimation). Fig. 5 also shows two feasible strategies to integrate reverse scheduled sampling with the original scheduled sampling. For simplicity, we make the sampling probability of the forecasting part decay linearly, although other scheduled sampling schemes (such as exponential decay) could be employed. The main difference between the two strategies lies in whether the sampling probabilities of the encoder and the forecaster are close to each other at the early stage of training. Empirically, the second strategy performs slightly better (Table (IV) in Section 5.3). This fits our expectation: using similar sampling probabilities for the encoding and forecasting phases at the beginning of training does indeed reduce the encoder-forecaster training discrepancy discussed around Eq. (7).
In this section, we evaluate the ability of PredRNN for multi-step future prediction on three spatiotemporal datasets: a synthetic dataset of flying digits, a human motion dataset, and a radar echo dataset for precipitation forecasting. The code is available at https://github.com/thuml/predrnn-pytorch.
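Returning briefly to the reverse scheduled sampling described before this section, the three increasing schedules for the encoder's probability $\epsilon_k$ of feeding the true frame (rising from a start value $\epsilon_s$ to an end value $\epsilon_e$) can be sketched as plain functions, together with the Bernoulli draw over encoder steps. This is a sketch assuming the linear, exponential, and sigmoid forms given above; all numeric settings in the usage lines are hypothetical.

```python
import math
import random
from typing import List


def linear_increase(k: int, eps_s: float, eps_e: float, alpha_l: float) -> float:
    # eps_k = min(eps_s + alpha_l * k, eps_e)
    return min(eps_s + alpha_l * k, eps_e)


def exponential_increase(k: int, eps_s: float, eps_e: float, alpha_e: float) -> float:
    # eps_k = eps_e - (eps_e - eps_s) * exp(-k / alpha_e)
    return eps_e - (eps_e - eps_s) * math.exp(-k / alpha_e)


def sigmoid_increase(k: int, eps_s: float, eps_e: float, alpha_s: float, beta_s: float) -> float:
    # eps_k = eps_s + (eps_e - eps_s) / (1 + exp((beta_s - k) / alpha_s)),
    # evaluated in a numerically stable way so a large |beta_s - k| does not overflow.
    z = (k - beta_s) / alpha_s
    if z >= 0:
        s = 1.0 / (1.0 + math.exp(-z))
    else:
        e = math.exp(z)
        s = e / (1.0 + e)
    return eps_s + (eps_e - eps_s) * s


def encoder_mask(eps_k: float, num_steps: int, rng: random.Random) -> List[bool]:
    """Bernoulli process over encoder steps: True = feed the true frame X_t."""
    return [rng.random() < eps_k for _ in range(num_steps)]


# Second strategy (hypothetical numbers): encoder rises from 0.5, forecaster decays linearly from 0.5.
rng = random.Random(0)
for k in (0, 10000, 50000):
    eps_enc = exponential_increase(k, eps_s=0.5, eps_e=1.0, alpha_e=5000.0)
    eps_dec = max(0.5 - 1e-5 * k, 0.0)
    print(k, round(eps_enc, 3), round(eps_dec, 3), encoder_mask(eps_enc, 4, rng))
```

Starting both probabilities near the same value keeps the two parts of the model in similar training conditions early on, which is the intuition behind the second strategy.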
We use the ConvLSTM network [shi2015convolutional] as the primary baseline model of PredRNN, and compare with more advanced video prediction models, including Conv-TT-LSTM [su2020convolutional], VPN [Kalchbrenner2016Video], MCnet [Villegas2017Decomposing], TrajGRU [shi2017deep], and MIM [wang2019memory]. For simplicity, we refer to the different versions of PredRNN as follows:
PredRNN-M-Only. This model improves the ConvLSTM network with the spatiotemporal memory flow ($\mathcal{M}$), whose architecture is shown in Fig. 1 (left).
PredRNN [wang2017predrnn]. This model uses ST-LSTMs as building blocks, but compared with PredRNN-V2, it is not trained with memory decoupling or reverse scheduled sampling.
PredRNN-V2. This is the final proposed model that improves the training process of the original PredRNN with memory decoupling and reverse scheduled sampling.
We use the ADAM optimizer [Kingma2014Adam] to train the models and use a mini-batch of sequences at each training iteration. Unless otherwise specified, we set the learning rate to and stop the training process after iterations. We observe that the number of channels of hidden states has a strong impact on the final performance. We typically use four ST-LSTM layers in PredRNN with channels for each hidden state and memory state to strike a balance between the prediction quality and training efficiency. In each ST-LSTM layer, we set the size of convolutional kernels to . Notably, we use a similar number of model parameters for the compared ConvLSTM network. The entire training time of PredRNN is around hours for Moving MNIST with a TITANX GPU. For other datasets, the learning processes are similar to that of Moving MNIST, except that the maximum numbers of iterations are adjusted according to the size of the training sets.
Model | MSE (↓) | SSIM (↑) | LPIPS (↓) |
---|---|---|---|
FC-LSTM [srivastava2015unsupervised] | 118.3 | 0.690 | - |
ConvLSTM [shi2015convolutional] | 103.3 | 0.707 | 0.156 |
CDNA [Finn2016Unsupervised] | 97.4 | 0.721 | - |
Conv-TT-LSTM [su2020convolutional] | 64.3 | 0.846 | 0.133 |
VPN Baseline [Kalchbrenner2016Video] | 64.1 | 0.870 | - |
MIM [wang2019memory] | 52.0 | 0.874 | 0.079 |
PredRNN-M-Only | 74.0 | 0.851 | 0.109 |
PredRNN | 56.8 | 0.867 | 0.107 |
PredRNN-V2 | 48.4 | 0.891 | 0.071 |
Dataset. We generate the Moving MNIST sequences in the same way as previous work [srivastava2015unsupervised, shi2015convolutional]. Each sequence consists of consecutive frames, including historical observations and future predictions. At the beginning of each sequence, we randomly sample handwritten digits from the static MNIST training set and place them at random initial locations within the image. We randomly initialize the velocities of the digits and make them move at a constant speed throughout the sequence. The digits bounce off the edges of the image at a fixed, predictable angle. We construct a training set with sequences and a test set with sequences, in which the digits are respectively sampled from the training/test set of the static MNIST dataset, such that the spatial information of the test data would not be overly exposed during the training process. In this task, although future trajectories are deterministic and depend only on historical observations, it is still challenging to infer the underlying dynamics and the initial velocities. Moreover, frequent occlusions of moving digits lead to complex, short-term changes of local spatial information, which makes spatiotemporal prediction even more difficult.
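The constant-velocity, bouncing motion described above can be sketched for a single digit as follows. The 64-pixel canvas is an assumption for illustration (MNIST digits themselves are 28×28); the sketch tracks only the digit's top-left corner and omits pasting digit images onto frames.

```python
from typing import List, Tuple


def bounce_1d(pos: float, vel: float, limit: float) -> Tuple[float, float]:
    """Advance one coordinate by one step, reflecting off the walls at 0 and `limit`.

    Assumes |vel| < limit, so at most one reflection happens per step.
    """
    pos += vel
    if pos < 0:
        pos, vel = -pos, -vel
    elif pos > limit:
        pos, vel = 2 * limit - pos, -vel
    return pos, vel


def trajectory(x: float, y: float, vx: float, vy: float,
               num_frames: int, canvas: int = 64, digit: int = 28) -> List[Tuple[float, float]]:
    """Top-left corner of one digit over time, moving at constant speed and bouncing."""
    limit = canvas - digit  # largest top-left coordinate that keeps the digit inside
    positions = []
    for _ in range(num_frames):
        positions.append((x, y))
        x, vx = bounce_1d(x, vx, limit)
        y, vy = bounce_1d(y, vy, limit)
    return positions


print(trajectory(x=30, y=0, vx=4, vy=0, num_frames=4))
# [(30, 0), (34, 0), (34, 0), (30, 0)]
```

Reflecting each velocity component separately keeps the speed constant and makes the bounce angle deterministic, so the trajectory is fully determined by the initial position and velocity.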
Main results.
We adopt evaluation metrics that were widely used by previous methods: the Mean Squared Error (MSE), the Structural Similarity Index Measure (SSIM) [Wang2004Image], and the Learned Perceptual Image Patch Similarity (LPIPS) [zhang2018unreasonable]. MSE estimates the absolute pixel-wise errors, SSIM measures the similarity of structural information within spatial neighborhoods, and LPIPS is based on deep features and aligns better with human perception. Table (I) shows the results of all compared models averaged per frame, and Fig. 6 provides the corresponding frame-wise comparisons. The final model, PredRNN-V2, significantly outperforms all previous approaches. With the proposed spatiotemporal memory flow, the PredRNN-M-Only model reduces the per-frame MSE of the ConvLSTM baseline from 103.3 down to 74.0. By using the ST-LSTM in place of the ConvLSTM unit, our model further reduces the MSE down to 56.8. Finally, the memory decoupling and reverse scheduled sampling techniques bring another improvement in MSE (from 56.8 to 48.4).
Model | MSE (↓) |
---|---|
PredRNN | 56.8 |
w/o the bottom-up memory flow | 57.6 |
w/o the top-down memory flow | 59.7 |
Qualitative comparisons. Fig. 7 shows two examples randomly sampled from the test set, where most of the frames produced by the compared models are severely blurred, especially for long-term predictions. In contrast, PredRNN produces clearer images, suggesting that it is more certain about future variations owing to its stronger long-term modeling capability. On closer inspection, the most challenging issues on this dataset are to accurately predict the future trajectories of the moving digits and to maintain the correct shape of each digit after occlusions (in the second example, the compared VPN model predicts the wrong digit after two digits overlap). In both cases, the original PredRNN and the newly proposed PredRNN-V2 progressively improve the quality of the prediction results.
Ablation studies on the spatiotemporal memory flow. For a detailed analysis of the spatiotemporal memory flow, we compare the performance of (a) removing its bottom-up transition path, which passes the memory state from lower to higher ST-LSTM layers within a timestamp, with (b) removing its top-down transition path, which passes the memory state from the top layer at the previous timestamp to the bottom layer at the current timestamp. The results in Table (II) show that the top-down transition path contributes more to the final performance. In Fig. 8 (left), we compare the normalized gradient of the hidden states at the first ST-LSTM layer with respect to the loss function at the last timestamp, computed for each input timestamp. It shows that the full spatiotemporal memory flow helps to alleviate the vanishing gradient problem and is especially good at learning long-term information from the very beginning of the sequence (note that the first frame usually contains clearer appearance information, while digits are often occluded in the middle of the input sequence).
Contributions of memory decoupling. To show that memory decoupling facilitates modeling both long-term and short-term dependencies, we make the pixel intensity of the images change regularly or irregularly over time, as shown in Fig. 9. Thanks to the decoupled memory cells of ST-LSTMs, our approach can respond to sudden changes more rapidly and adapt to video dynamics at different timescales. In Table (III), we use the MIM model [wang2019memory], which is also based on the ST-LSTM unit, to show that the memory decoupling loss generalizes across network backbones.
Base model | MSE w/o memory decoupling (↓) | MSE w/ memory decoupling (↓) |
---|---|---|
MIM [wang2019memory] | 52.0 | 47.9 |
PredRNN | 56.8 | 51.1 |
Contributions of the reverse scheduled sampling. We compare different combinations of the original and the reverse scheduled sampling techniques. As shown in Table (IV), the second strategy in Fig. 5 with an exponentially increasing encoder sampling probability performs best, which suggests that the encoder-forecaster discrepancy can be effectively reduced by keeping their probabilities of sampling the true context frames close to each other in the early stage of training. To further demonstrate that reverse scheduled sampling contributes to learning long-term dynamics, we perform another gradient analysis in Fig. 8 (right). We evaluate the gradients of the encoder's hidden states with respect to the loss function at each output timestamp, and average the resulting gradient norms over the entire input sequence. The normalized gradient curves show that the context information can be encoded more effectively by using reverse scheduled sampling.
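Method | Start ($\epsilon_s$) | End ($\epsilon_e$) | Mode | MSE (↓) |
---|---|---|---|---|
PredRNN | - | - | - | 57.3 |
+ Scheduled sampling [bengio2015scheduled] | - | - | - | 56.8 |
+ 1st strategy in Fig. 5 | 0.0 | 1.0 | Linear | 51.8 |
+ 1st strategy in Fig. 5 | 0.0 | 1.0 | Sigmoid | 53.9 |
+ 1st strategy in Fig. 5 | 0.0 | 1.0 | Exponential | 51.8 |
+ 2nd strategy in Fig. 5 | 0.5 | 1.0 | Linear | 51.9 |
+ 2nd strategy in Fig. 5 | 0.5 | 1.0 | Sigmoid | 50.9 |
+ 2nd strategy in Fig. 5 | 0.5 | 1.0 | Exponential | 50.6 |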
Dataset. The KTH action dataset [Sch2004Recognizing] contains six types of human actions, i.e., walking, jogging, running, boxing, hand-waving, and hand-clapping, performed by different persons in different scenarios. The videos last a few seconds on average and were taken against fairly uniform backgrounds with a static camera at a fixed frame rate. To make the results comparable, we adopt the training/testing protocol from [Villegas2017Decomposing], which uses disjoint groups of persons for training and testing, and we resize the video frames to a fixed resolution. All compared models are trained across the action categories by predicting future frames from context frames. At test time, we extend the prediction horizon to more timestamps into the future. We manually select reasonable video sequences to ensure that there is always someone in the first frame, and build the training and test sets from the selected sequences.
Model | PSNR (↑) | SSIM (↑) | LPIPS (↓) |
---|---|---|---|
ConvLSTM [shi2015convolutional] | 23.58 | 0.712 | 0.231 |
MCnet + Residual [Villegas2017Decomposing] | 26.29 | 0.806 | - |
TrajGRU [shi2017deep] | 26.97 | 0.790 | - |
DFN [de2016dynamic] | 27.26 | 0.794 | - |
Conv-TT-LSTM [su2020convolutional] | 27.62 | 0.815 | 0.196 |
PredRNN | 27.55 | 0.839 | 0.204 |
PredRNN-V2 | 28.37 | 0.838 | 0.139 |
Results.
We adopt the Peak Signal-to-Noise Ratio (PSNR) from previous literature as the third evaluation metric, in addition to SSIM and LPIPS. Like MSE, PSNR estimates the pixel-level similarity of two images, but higher values are better. The evaluation results of different methods are shown in Table (V), and the corresponding frame-wise comparisons are shown in Fig. 10, from which we make two observations. First, our models show significant improvements in both short-term and long-term predictions over the ConvLSTM network. Second, with the newly proposed memory decoupling and reverse scheduled sampling, PredRNN-V2 improves on the conference version by a large margin in LPIPS (from 0.204 to 0.139). As mentioned above, LPIPS is more consistent with human perceptual judgments, indicating that PredRNN-V2 has a stronger ability to generate high-fidelity images. In accordance with these results, the visual examples in Fig. 11 show that our approaches (especially PredRNN-V2) obtain more accurate predictions of future movement and body details. The increased image sharpness suggests that PredRNN-V2 is more certain about the future: by decoupling memory states, it learns to capture the complex spatiotemporal variations at different timescales.
The accurate prediction of the movement of radar echoes over the next few hours is the foundation of precipitation forecasting. It is a challenging task because the echoes tend to have non-rigid shapes and may move, accumulate, or dissipate rapidly due to complex atmospheric physics, which makes it important to learn the dynamics in a unified spatiotemporal feature space.
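For reference, PSNR as used in the KTH comparison above can be computed from the per-frame MSE. This sketch assumes pixel values scaled to [0, 1] (so the peak value is 1.0) and uses the pixel-wise mean; normalization conventions for the per-frame MSE reported in tables differ between papers (some sum over pixels instead), so the numbers here are not directly comparable to the tables.

```python
import math
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over the pixels of one frame."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB; higher is better, infinite for identical frames."""
    err = mse(pred, target)
    if err == 0:
        return math.inf
    return 10.0 * math.log10(max_val ** 2 / err)


print(psnr([0.0, 0.0], [0.1, 0.1]))  # approximately 20.0 dB
```

Because PSNR is a logarithmic transform of MSE, it ranks frames the same way MSE does, only with the direction flipped.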
Model | MSE (↓) | CSI-30 (↑) | CSI-40 (↑) | CSI-50 (↑) |
---|---|---|---|---|
TrajGRU [shi2017deep] | 68.3 | 0.309 | 0.266 | 0.211 |
ConvLSTM [shi2015convolutional] | 63.7 | 0.381 | 0.340 | 0.286 |
MIM [wang2019memory] | 39.3 | 0.451 | 0.418 | 0.372 |
PredRNN | 39.1 | 0.455 | 0.417 | 0.358 |
PredRNN-V2 | 36.4 | 0.462 | 0.425 | 0.378 |
Dataset. We collect the radar echo dataset by following the data handling method described in the work from Shi et al. [shi2015convolutional]. Our dataset consists of consecutive radar observations recorded every minutes at Guangzhou, China. For data preprocessing, we first map the radar observations to pixel values and represent them as gray-scale images (slightly different from the conference version), and then slice the sequential radar maps with a sliding window and obtain a total number of sequences. Each sequence contains input frames and output frames, covering the historical data for the past hour and that for the future hour. We use sequences for training and leave the rest of them for model evaluation.
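The sliding-window slicing of the continuous radar record can be sketched as below. The window length is the number of input frames plus the number of output frames; the `stride` parameter and the toy integer "frames" are assumptions for illustration, since the section does not state the step between windows.

```python
from typing import List, Sequence, TypeVar

Frame = TypeVar("Frame")


def sliding_windows(frames: Sequence[Frame], num_in: int, num_out: int,
                    stride: int = 1) -> List[List[Frame]]:
    """Cut a long, time-ordered frame sequence into samples of num_in inputs + num_out targets."""
    length = num_in + num_out
    return [list(frames[i:i + length]) for i in range(0, len(frames) - length + 1, stride)]


windows = sliding_windows(list(range(6)), num_in=2, num_out=2, stride=1)
print(windows)  # [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
inputs, targets = windows[0][:2], windows[0][2:]  # split one sample into context and future
```

With a stride of 1, consecutive samples overlap heavily; a larger stride yields fewer, less correlated samples.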
Results. In Table (VI), we compare PredRNN with three existing methods that have been shown effective for precipitation forecasting. In addition to MSE, following common practice, we also evaluate the predicted radar maps with the Critical Success Index (CSI). Concretely, we first transform the pixel values back to echo intensities in dBZ, and then take 30, 40, and 50 dBZ as thresholds to compute $\text{CSI} = \frac{\text{hits}}{\text{hits} + \text{misses} + \text{false\_alarms}}$, where “hits” indicates the true positives, “misses” the false negatives, and “false_alarms” the false positives. As Table (VI) shows, PredRNN-V2 consistently achieves the best performance over all CSI thresholds. Further, we visualize the predicted radar frames by mapping them into RGB space; the results are shown in Fig. 12. Note that areas with high echo intensities tend to have severe weather phenomena and should be considered carefully.
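The CSI at a single threshold can be sketched on flattened maps of echo intensities in dBZ. Treating a pixel as an event when its value is greater than or equal to the threshold is an assumption, as is returning NaN when no event occurs in either map; the pixel-to-dBZ conversion is not shown because its formula is not given in this section.

```python
from typing import Sequence


def csi(pred_dbz: Sequence[float], true_dbz: Sequence[float], threshold: float) -> float:
    """Critical Success Index = hits / (hits + misses + false_alarms) at one dBZ threshold."""
    hits = misses = false_alarms = 0
    for p, t in zip(pred_dbz, true_dbz):
        pred_event, true_event = p >= threshold, t >= threshold
        if pred_event and true_event:
            hits += 1          # true positive
        elif true_event:
            misses += 1        # false negative: observed but not predicted
        elif pred_event:
            false_alarms += 1  # false positive: predicted but not observed
    denom = hits + misses + false_alarms
    return hits / denom if denom else float("nan")  # undefined when no event occurs


pred = [35.0, 45.0, 10.0, 55.0]
truth = [40.0, 20.0, 45.0, 60.0]
print(csi(pred, truth, threshold=30.0))  # 0.5
```

Unlike MSE, CSI ignores correct negatives, so it focuses the evaluation on the high-intensity regions that matter for severe weather.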
In this paper, we propose a recurrent network named PredRNN for spatiotemporal predictive learning. With a new Spatiotemporal LSTM unit, PredRNN models the short-term deformations in spatial appearance and the long-term dynamics over multiple frames simultaneously. The core of the Spatiotemporal LSTM is a zigzag memory flow that propagates across stacked recurrent layers vertically and through all temporal states horizontally, which enables the interactions of the hierarchical memory representations at different levels of PredRNN. Building upon the conference version of this paper, we introduce a new method to decouple the twisted memory states along the horizontal and the zigzag pathway of recurrent state transitions. It enables the model to benefit from learning distributed representations that could cover different aspects of spatiotemporal variations. Furthermore, we also propose a new curriculum learning strategy named reverse scheduled sampling, which enforces the encoding part of PredRNN to learn temporal dynamics from longer periods of the input sequence. Another benefit of reverse scheduled sampling is to reduce the training discrepancy between the encoding part and the forecasting part. Our approach achieves state-of-the-art performance on multiple datasets, including both synthetic and natural spatiotemporal sequences.
This work was supported by the National Key R&D Program of China (2020AAA0109201), NSFC grants 62022050, 61772299, 62021002, 71690231, Beijing Nova Program (Z201100006820041), and CAAI-Huawei MindSpore Open Fund. The work was in part done when Y. Wang was a student at Tsinghua University. Y. Wang and H. Wu contributed equally to this work.