Temporal Difference Variational Auto-Encoder

06/08/2018 ∙ by Karol Gregor, et al. ∙ Google

One motivation for learning generative models of environments is to use them as simulators for model-based reinforcement learning. Yet, it is intuitively clear that when time horizons are long, rolling out single step transitions is inefficient and often prohibitive. In this paper, we propose a generative model that learns state representations containing explicit beliefs about states several time steps in the future and that can be rolled out directly in these states without executing single step transitions. The model is trained on pairs of temporally separated time points, using an analogue of temporal difference learning used in reinforcement learning, taking the belief about possible futures at one time point as a bootstrap for training the belief at an earlier time. While we focus purely on the study of the model rather than its use in reinforcement learning, the model architecture we design respects agents' constraints as it builds the representation online.




1 Introduction

Generative models of sequential data have received a lot of attention, due to their wide applicability in domains such as speech synthesis (van den Oord et al., 2016a, 2017), neural translation (Bahdanau et al., 2014), image captioning (Xu et al., 2015), and many others. Different application domains will often have different requirements (e.g. long term coherence, sample quality, abstraction learning, etc.), which in turn will drive the choice of the architecture and training algorithm.

Of particular interest to this paper is the problem of reinforcement learning in partially observed environments, where, in order to act and explore optimally, agents need to build a representation of the uncertainty about the world, computed from the information they have gathered so far. While an agent endowed with memory could in principle learn such a representation implicitly through model-free reinforcement learning, in many situations the reinforcement signal may be too weak to quickly learn such a representation in a way which would generalize to a collection of tasks.

Furthermore, in order to plan in a model-based fashion, an agent needs to be able to imagine distant futures which are consistent with the agent’s past. In many situations however, planning step-by-step is not a cognitively or computationally realistic approach.

To successfully address an application such as the above, we argue that a model of the agent’s experience should exhibit the following properties:


  • The model should learn an abstract state representation of the data and be capable of making predictions at the state level, not just the observation level.

  • The model should learn a belief state, i.e. a deterministic, coded representation of the filtering posterior of the state given all the observations up to a given time. A belief state contains all the information an agent has about the state of the world and thus about how to act optimally.

  • The model should exhibit temporal abstraction, both by making ‘jumpy’ predictions (predictions several time steps into the future), and by being able to learn from temporally separated time points without backpropagating through the entire time interval.

To our knowledge, no model in the literature meets these requirements. In this paper, we develop a new model and associated training algorithm, called Temporal Difference Variational Auto-Encoder (TD-VAE), which meets all of the above requirements. We first develop TD-VAE in the sequential, non-jumpy case, by using a modified evidence lower bound (ELBO) for stochastic state space models (Krishnan et al., 2015; Fraccaro et al., 2016; Buesing et al., 2018) which relies on jointly training a filtering posterior and a local smoothing posterior. We demonstrate that on a simple task, this new inference network and associated lower bound lead to improved likelihood compared to methods classically used to train deep state-space models.

Following the intuition given by the sequential TD-VAE, we develop the full TD-VAE model, which learns from temporally extended data by making jumpy predictions into the future. We show it can be used to train consistent jumpy simulators of complex 3D environments. Finally, we illustrate how training a filtering posterior leads to the computation of a neural belief state with a good representation of the uncertainty on the state of the environment.

2 Model desiderata

2.1 Construction of a latent state-space

Autoregressive models. One of the simplest ways to model sequential data $(x_1, \dots, x_T)$ is to use the chain rule to decompose the joint sequence likelihood as a product of conditional probabilities, i.e. $\log p(x_1, \dots, x_T) = \sum_t \log p(x_t \mid x_1, \dots, x_{t-1})$. This formula can be used to train an autoregressive model of data, by combining an RNN which aggregates information from the past (recursively computing an internal state $h_t = f(h_{t-1}, x_t)$) with a conditional generative model which can score the data $x_{t+1}$ given the context $h_t$. This idea is used in handwriting synthesis (Graves, 2013), density estimation (Uria et al., 2016), image synthesis (van den Oord et al., 2016b), audio synthesis (van den Oord et al., 2017), video synthesis (Kalchbrenner et al., 2016), generative recall tasks (Gemici et al., 2017), and environment modeling (Oh et al., 2015; Chiappa et al., 2017).
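As a minimal illustration of the chain-rule factorization, the Python sketch below scores a binary sequence with a toy "RNN" whose state is just a pair of counts. The functions `update_state` (playing the role of $f$) and `prob_next` (the conditional model) are hypothetical stand-ins chosen so the example runs without a learning library; a real autoregressive model would use a learned RNN and a learned output distribution.

```python
import math


def update_state(h, x):
    """Hypothetical f(h_{t-1}, x_t): the state is the count of zeros and ones seen so far."""
    zeros, ones = h
    return (zeros + (x == 0), ones + (x == 1))


def prob_next(h, x):
    """Hypothetical conditional model p(x_{t+1} = x | h_t): Laplace-smoothed frequency of ones."""
    zeros, ones = h
    p_one = (ones + 1) / (zeros + ones + 2)
    return p_one if x == 1 else 1.0 - p_one


def sequence_log_likelihood(xs):
    """Chain rule: log p(x_1, ..., x_T) = sum_t log p(x_t | h_{t-1})."""
    h = (0, 0)  # h_0 summarizes the empty history
    total = 0.0
    for x in xs:
        total += math.log(prob_next(h, x))  # score x_t with the context h_{t-1}
        h = update_state(h, x)              # h_t = f(h_{t-1}, x_t)
    return total
```

Because every conditional is a proper distribution, the probabilities of all sequences of a given length sum to one; for example, `sequence_log_likelihood([1, 1])` equals $\log(1/2) + \log(2/3)$.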

While these models are conceptually simple and easy to train, one potential weakness is that they only make predictions in the original observation space, and don’t learn a compressed representation of data. As a result, these models tend to be computationally heavy (for video prediction, they constantly decode and re-encode single video frames). Furthermore, the model can be unstable at test time, since it is trained as a next-step model (the RNN encodes real data) but at test time feeds its own predictions back into the RNN. Various methods have been used to alleviate this issue (Bengio et al., 2015; Lamb et al., 2016; Goyal et al., 2017; Amos et al., 2018).

State-space models. An alternative to autoregressive models are models which operate on a higher level of abstraction, and use latent variables to model stochastic transitions between states (grounded by observation-level predictions). This makes it possible to sample state-to-state transitions only, without needing to render the observations, which can be faster and more conceptually appealing. They generally consist of decoder or prior networks, which detail the generative process of states and observations, and encoder or posterior networks, which estimate the distribution of latents given the observed data. There is a large amount of recent work on these types of models, which differ in the precise wiring of model components (Bayer & Osendorfer, 2014; Chung et al., 2015; Krishnan et al., 2015; Archer et al., 2015; Fraccaro et al., 2016; Liu et al., 2017; Serban et al., 2017; Buesing et al., 2018; Lee et al., 2018; Ha & Schmidhuber, 2018).

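Let $z = (z_1, \dots, z_T)$ be a state sequence and $x = (x_1, \dots, x_T)$ an observation sequence. We assume a general form of state-space model, where the joint state and observation likelihood can be written as $p(x, z) = \prod_t p(z_t \mid z_{t-1})\, p(x_t \mid z_t)$. (For notational simplicity, $p(z_1 \mid z_0) = p(z_1)$. Also note the conditional distributions could be very complex, using additional latent variables, flow models, or implicit models, for instance if a deterministic RNN with stochastic inputs is used in the decoder.) These models are commonly trained with a VAE-inspired bound, by computing a posterior over the states given the observations. Often, the posterior is decomposed autoregressively: $q(z \mid x) = \prod_t q(z_t \mid z_{t-1}, \phi_t(x))$, where $\phi_t$ is a function of $(x_1, \dots, x_t)$ for filtering posteriors or of the entire sequence $(x_1, \dots, x_T)$ for smoothing posteriors. This leads to the following lower bound:

$$\log p(x_1, \dots, x_T) \ge \mathbb{E}_{z_t \sim q(z_t \mid z_{t-1}, \phi_t(x))}\Big[\sum_t \log p(x_t \mid z_t) + \log p(z_t \mid z_{t-1}) - \log q(z_t \mid z_{t-1}, \phi_t(x))\Big] \tag{1}$$
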
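To make the bound concrete, here is a sketch that computes a single-sample Monte Carlo estimate of equation (1) for a hypothetical 1D linear-Gaussian state-space model. The constants `A`, `S0`, `SZ`, `SX` and the closed-form filtering posterior (which uses only $x_t$ as $\phi_t$) are illustrative assumptions, not part of the models discussed in this paper.

```python
import math
import random

# Hypothetical toy model: z_1 ~ N(0, S0^2), z_t ~ N(A * z_{t-1}, SZ^2), x_t ~ N(z_t, SX^2).
A, S0, SZ, SX = 0.9, 1.0, 0.5, 0.3


def log_normal(x, mean, std):
    """Log density of a 1D Gaussian N(mean, std^2) at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def prior(z_prev):
    """p(z_t | z_{t-1}), with p(z_1 | z_0) = p(z_1)."""
    return (0.0, S0) if z_prev is None else (A * z_prev, SZ)


def filtering_posterior(z_prev, x_t):
    """q(z_t | z_{t-1}, x_t): Gaussian prior times Gaussian likelihood, in closed form."""
    m, s = prior(z_prev)
    precision = 1 / s ** 2 + 1 / SX ** 2
    mean = (m / s ** 2 + x_t / SX ** 2) / precision
    return mean, math.sqrt(1 / precision)


def elbo_sample(xs, rng):
    """Single-sample estimate of the lower bound: sample z_t from q step by step."""
    total, z_prev = 0.0, None
    for x_t in xs:
        q_mean, q_std = filtering_posterior(z_prev, x_t)
        z_t = rng.gauss(q_mean, q_std)
        p_mean, p_std = prior(z_prev)
        total += (log_normal(x_t, z_t, SX)            # log p(x_t | z_t)
                  + log_normal(z_t, p_mean, p_std)    # log p(z_t | z_{t-1})
                  - log_normal(z_t, q_mean, q_std))   # -log q(z_t | z_{t-1}, phi_t)
        z_prev = z_t
    return total
```

Because this particular posterior is exact for a single time step, the estimate for a length-1 sequence equals $\log p(x_1)$ for every sample; for longer sequences its expectation is a lower bound on $\log p(x)$.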
2.2 Online creation of belief state

A key feature of sequential models of data is that they allow one to reason about the conditional distribution of the future given the past: $p(x_{t+1}, \dots, x_T \mid x_1, \dots, x_t)$. For reinforcement learning in partially observed environments, this distribution governs the distribution of returns given past observations, and as such, it is sufficient to derive the optimal policy. For generative sequence modeling, it enables conditional generation of data given a context sequence. For this reason, it is desirable to compute sufficient statistics $b_t = b_t(x_1, \dots, x_t)$ of the future given the past, which allow us to rewrite the conditional distribution as $p(x_{t+1}, \dots, x_T \mid x_1, \dots, x_t) \approx p(x_{t+1}, \dots, x_T \mid b_t)$. For an autoregressive model as described in section 2.1, the internal RNN state $h_t$ can immediately be identified as the desired sufficient statistics $b_t$. However, for the reasons mentioned in the previous section, we would like to identify an equivalent quantity for a state-space model.

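For a state-space model, the filtering distribution $p(z_t \mid x_1, \dots, x_t)$, also known as the belief state in reinforcement learning, is sufficient to compute the conditional future distribution, due to the Markov assumption underlying the state-space model and the following derivation:

$$p(x_{t+1}, \dots, x_T \mid x_1, \dots, x_t) = \int p(x_{t+1}, \dots, x_T \mid z_t, x_1, \dots, x_t)\, p(z_t \mid x_1, \dots, x_t)\, dz_t = \int p(x_{t+1}, \dots, x_T \mid z_t)\, p(z_t \mid x_1, \dots, x_t)\, dz_t \tag{2}$$
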
Thus, if we train a network that extracts a code $b_t$ from $(x_1, \dots, x_t)$ so that $p(z_t \mid x_1, \dots, x_t) \approx p(z_t \mid b_t)$, then $b_t$ would contain all the information about the state of the world the agent has, and would effectively form a neural belief state, i.e. a code fully characterizing the filtering distribution.
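For the special case of a linear-Gaussian state-space model, the belief state has a classical closed form: the Kalman filter's posterior mean and variance form a deterministic code $b_t$ with $p(z_t \mid x_1, \dots, x_t) = \mathcal{N}(z_t; b_t)$, computed online as $b_t = f(b_{t-1}, x_t)$. The sketch below (scalar state, hypothetical constants) shows this and uses $b_t$ alone to predict $x_{t+1}$, which is valid because, under the Markov assumption, the future depends on the past only through the filtering distribution. TD-VAE instead learns a neural $b_t$ for models where no such closed form exists.

```python
# Hypothetical scalar model: z_1 ~ N(0, S0^2), z_t ~ N(A z_{t-1}, SZ^2), x_t ~ N(z_t, SX^2).
A, S0, SZ, SX = 0.9, 1.0, 0.5, 0.3


def update_belief(b_prev, x_t):
    """b_t = f(b_{t-1}, x_t), where b = (mean, variance) of p(z_t | x_1..x_t)."""
    if b_prev is None:                      # no observations yet: prior on z_1
        pred_mean, pred_var = 0.0, S0 ** 2
    else:                                   # push the previous belief through the transition
        mean, var = b_prev
        pred_mean, pred_var = A * mean, A ** 2 * var + SZ ** 2
    gain = pred_var / (pred_var + SX ** 2)  # Kalman gain for x_t = z_t + noise
    return pred_mean + gain * (x_t - pred_mean), (1 - gain) * pred_var


def filter_sequence(xs):
    """Run the deterministic belief update over a whole observation sequence."""
    b = None
    for x_t in xs:
        b = update_belief(b, x_t)
    return b


def predict_next_observation(b_t):
    """Mean and variance of p(x_{t+1} | x_1..x_t), computed from b_t alone."""
    mean, var = b_t
    return A * mean, A ** 2 * var + SZ ** 2 + SX ** 2
```

No sampling is involved, so two agents with the same observations compute the same $b_t$.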

Classical training of state-space models does not compute a belief state: by computing a joint, autoregressive posterior $q(z \mid x) = \prod_t q(z_t \mid z_{t-1}, \phi_t(x))$, some of the uncertainty about the marginal posterior of $z_t$ may be ‘leaked’ into the sample $z_{t-1}$. Since that sample is stochastic, to obtain all information from $(x_1, \dots, x_t)$ about $z_t$, we would need to re-sample $z_{t-1}$, which would in turn require re-sampling $z_{t-2}$, all the way to $z_1$.

While the notion of a belief state itself and its connection to optimal policies in POMDPs is well known (Astrom, 1965; Kaelbling et al., 1998; Hauskrecht, 2000), it has often been restricted to the tabular case (Markov chain), and little work investigates computing belief states for learned deep models. A notable exception is (Igl et al., 2018), which uses a neural form of particle filtering, and represents the belief state more explicitly as a weighted collection of particles. Related to our definition of belief states as sufficient statistics is the notion of predictive state representations (PSRs) (Littman & Sutton, 2002); see also (Venkatraman et al., 2017) for a model that learns PSRs which, combined with a decoder, can predict future observations.

Our last requirement for the model is that of temporal abstraction. We postpone the discussion of this aspect until section 4.

3 Belief-state-based ELBO for sequential TD-VAE

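In this section, we develop a sequential model that satisfies the requirements given in the previous section, namely (a) it constructs a latent state-space, and (b) it creates an online belief state. We consider an arbitrary state-space model with joint latent and observable likelihood given by $p(x, z) = \prod_t p(z_t \mid z_{t-1})\, p(x_t \mid z_t)$, and we aim to optimize the data likelihood $\log p(x)$. We begin by autoregressively decomposing the data likelihood as $\log p(x) = \sum_t \log p(x_t \mid x_{<t})$. For a given $t$, we evaluate the conditional likelihood $p(x_t \mid x_{<t})$ by inferring over two latent states only, $z_{t-1}$ and $z_t$, as they will naturally make belief states appear for times $t-1$ and $t$:

$$\log p(x_t \mid x_{<t}) \ge \mathbb{E}_{(z_{t-1}, z_t) \sim q(z_{t-1}, z_t \mid x_{\le t})}\big[\log p(x_t \mid z_{t-1}, z_t, x_{<t}) + \log p(z_{t-1}, z_t \mid x_{<t}) - \log q(z_{t-1}, z_t \mid x_{\le t})\big]$$
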
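Because of the Markov assumptions underlying the state-space model, we can simplify $p(x_t \mid z_{t-1}, z_t, x_{<t}) = p(x_t \mid z_t)$ and decompose $p(z_{t-1}, z_t \mid x_{<t}) = p(z_{t-1} \mid x_{<t})\, p(z_t \mid z_{t-1})$. Next, we choose to decompose $q(z_{t-1}, z_t \mid x_{\le t})$ as a belief over $z_t$ and a one-step smoothing distribution over $z_{t-1}$: $q(z_{t-1}, z_t \mid x_{\le t}) = q(z_t \mid x_{\le t})\, q(z_{t-1} \mid z_t, x_{\le t})$. We obtain the following belief-based ELBO for state-space models:

$$\log p(x_t \mid x_{<t}) \ge \mathbb{E}_{z_t \sim q(z_t \mid x_{\le t}),\; z_{t-1} \sim q(z_{t-1} \mid z_t, x_{\le t})}\big[\log p(x_t \mid z_t) + \log p(z_{t-1} \mid x_{<t}) + \log p(z_t \mid z_{t-1}) - \log q(z_t \mid x_{\le t}) - \log q(z_{t-1} \mid z_t, x_{\le t})\big] \tag{3}$$
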
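Both quantities $p(z_{t-1} \mid x_{\le t-1})$ and $q(z_t \mid x_{\le t})$ represent the belief state of the model at different times, so at this stage we approximate them with the same distribution $p_B(z \mid b)$, with $b_t = f(b_{t-1}, x_t)$ representing the belief state code for $z_t$. Similarly, we represent the smoothing posterior over $z_{t-1}$ as $q(z_{t-1} \mid z_t, b_{t-1}, b_t)$. We obtain the following loss:

$$-\mathcal{L} = \mathbb{E}_{z_t \sim p_B(z_t \mid b_t),\; z_{t-1} \sim q(z_{t-1} \mid z_t, b_{t-1}, b_t)}\big[\log p(x_t \mid z_t) + \log p_B(z_{t-1} \mid b_{t-1}) + \log p(z_t \mid z_{t-1}) - \log p_B(z_t \mid b_t) - \log q(z_{t-1} \mid z_t, b_{t-1}, b_t)\big] \tag{4}$$
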
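The following Python sketch estimates the sequential TD-VAE negative loss $-\mathcal{L}$ derived in this section with a single sample, for scalar latents. The four "networks" are hypothetical fixed functions returning the mean and standard deviation of a Gaussian, and the belief code $b$ is simply a scalar; in TD-VAE these would be learned networks and $b_t$ would come from a forward RNN.

```python
import math
import random


def log_normal(x, mean, std):
    """Log density of a 1D Gaussian N(mean, std^2) at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


# Hypothetical stand-ins for the learned networks; each returns (mean, std).
def belief(b):                    # p_B(z | b), with b a scalar code
    return b, 0.5


def smoothing(z_t, b_prev, b_t):  # q(z_{t-1} | z_t, b_{t-1}, b_t)
    return 0.5 * z_t + 0.25 * (b_prev + b_t), 0.4


def transition(z_prev):           # p(z_t | z_{t-1})
    return 0.9 * z_prev, 0.3


def decoder(z_t):                 # p(x_t | z_t)
    return z_t, 0.2


def negative_loss_sample(x_t, b_prev, b_t, rng):
    """One-sample estimate of -L for a single time step t."""
    zb_mean, zb_std = belief(b_t)
    z_t = rng.gauss(zb_mean, zb_std)         # z_t ~ p_B(z_t | b_t)
    q_mean, q_std = smoothing(z_t, b_prev, b_t)
    z_prev = rng.gauss(q_mean, q_std)        # z_{t-1} ~ q(z_{t-1} | z_t, b_{t-1}, b_t)
    return (log_normal(x_t, *decoder(z_t))            # log p(x_t | z_t)
            + log_normal(z_prev, *belief(b_prev))     # log p_B(z_{t-1} | b_{t-1})
            + log_normal(z_t, *transition(z_prev))    # log p(z_t | z_{t-1})
            - log_normal(z_t, zb_mean, zb_std)        # -log p_B(z_t | b_t)
            - log_normal(z_prev, q_mean, q_std))      # -log q(z_{t-1} | z_t, b_{t-1}, b_t)
```

In training, this estimate would be averaged over time steps and maximized with respect to the parameters of all four distributions and of the belief RNN.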
We provide an intuition on the different terms of the ELBO in the next section.

4 TD-VAE and jumpy state modeling

The model derived in the previous section expresses a state model that describes how the state of the world evolves from one time step to the next. However, in many applications, the relevant timescale for planning may not be the one at which we receive observations and execute simple actions. Imagine for example planning for a trip abroad; the different steps involved (discussing travel options, choosing a destination, buying a ticket, packing a suitcase, going to the airport, and so on) all occur at vastly different time scales (potentially months in the future at the beginning of the trip, and days during the trip). Certainly, making a plan for this situation does not involve making second-by-second decisions. This suggests that we should look for models that can imagine future states directly, without going through all intermediate states.

Beyond planning, there are several other reasons that motivate modeling the future directly. First, training signal coming from the future can be stronger than small changes happening between time steps. Second, the behavior of the model should ideally be independent from the underlying temporal sub-sampling of the data, if the latter is an arbitrary choice. Third, jumpy predictions can be computationally efficient; when predicting several steps into the future, there may be some intervals where the prediction is either easy (e.g. a ball moving straight), or the prediction is complex but does not affect later time steps — which Neitz et al. (2018) call inconsequential chaos.

There are a number of research directions that consider temporal jumps. Koutnik et al. (2014) and Chung et al. (2016) consider recurrent neural networks with skip connections, making it easier to bridge distant timesteps.

Buesing et al. (2018) temporally sub-sample the data and build a jumpy model (for fixed jump size) of this data; but by doing so they also drop the information contained in the skipped observations. Neitz et al. (2018) and Jayaraman et al. (2018) predict sequences with variable time-skips, by choosing as target the most predictable future frames. They predict the observations directly without learning appropriate states, and only focus on nearly fully observed problems (and therefore do not need to learn a notion of belief state). For more general problems, this is a fundamental limitation: even if one could in principle learn a jumpy observation model $p(x_{t+\delta} \mid x_{\le t})$, it cannot be used recursively (feeding $x_{t+\delta}$ back to the RNN and predicting $x_{t+\delta+\delta'}$). This is because $x_{t+\delta}$ does not capture the full state of the system, so we would be missing the information from $t$ to $t+\delta$ needed to fully characterize what happens after time $t+\delta$. In addition, $x_{t+\delta}$ might not be appropriate even as a target, because some important information, such as the behavior of an agent, can only be extracted from a number of frames (potentially arbitrarily separated).

4.1 The TD-VAE model

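Motivated by the model derived in section 3, we extend sequential TD-VAE to exhibit time abstraction. We start from the same assumptions and architectural form: there exists a sequence of states $z_1, \dots, z_T$ from which we can predict the observations $x_1, \dots, x_T$. A forward RNN encodes a belief state $b_t$ from past observations $x_{\le t}$. The main difference is that, instead of relating information known at times $t-1$ and $t$ through the states $z_{t-1}$ and $z_t$, we relate two distant time steps $t_1$ and $t_2$ through their respective states $z_{t_1}$ and $z_{t_2}$, and we learn a jumpy, state-to-state model $p(z_{t_2} \mid z_{t_1})$ between $z_{t_1}$ and $z_{t_2}$. Following equation 3, the negative loss for the TD-VAE model is:

$$-\mathcal{L}_{t_1, t_2} = \mathbb{E}_{z_{t_2} \sim p_B(z_{t_2} \mid b_{t_2}),\; z_{t_1} \sim q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})}\big[\log p(x_{t_2} \mid z_{t_2}) + \log p_B(z_{t_1} \mid b_{t_1}) + \log p(z_{t_2} \mid z_{t_1}) - \log p_B(z_{t_2} \mid b_{t_2}) - \log q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})\big] \tag{5}$$
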
To train this model, one should choose the distribution of times $t_1, t_2$; for instance, $t_1$ can be chosen uniformly from the sequence, and $t_2 - t_1$ uniformly over some finite range; other approaches could be investigated. Figure 1 describes in detail the computation flow of the model.
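One simple way to draw the time pairs, sketched below in Python, samples the jump $t_2 - t_1$ first and then places $t_1$ uniformly among the positions that keep $t_2$ inside the sequence. The bound `max_jump` is a hypothetical parameter, and sampling the jump first is an assumption made so that every pair is valid; it is not necessarily the scheme used in the experiments.

```python
import random


def sample_time_pair(seq_len, max_jump, rng):
    """Return 0-based indices (t1, t2) with 1 <= t2 - t1 <= max_jump and t2 < seq_len."""
    if seq_len < 2:
        raise ValueError("need at least two time steps")
    jump = rng.randint(1, min(max_jump, seq_len - 1))  # t2 - t1, uniform over the allowed range
    t1 = rng.randint(0, seq_len - 1 - jump)            # t1 uniform over positions that fit
    return t1, t1 + jump
```

Each sampled pair $(t_1, t_2)$ then selects the belief codes $b_{t_1}$, $b_{t_2}$ and the observation $x_{t_2}$ that enter the loss for that pair.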

Figure 1: Diagram of TD-VAE. Follow the red panels for an explanation of the architecture. For succinctness, we use the notation $p_D$ to denote the decoder $p(x \mid z)$, $p_T$ to denote the transition distribution $p(z_{t_2} \mid z_{t_1})$, $q_S$ for the smoothing distribution and $p_B$ for the belief distribution.

Finally, it would be desirable to model the world with different hierarchies of state, the higher-level states predicting the same-level or lower-level states, and ideally representing more invariant or abstract information. For this reason, we also develop a stacked (hierarchical) version of TD-VAE, which uses several layers of latent states. Hierarchical TD-VAE is detailed in the appendix.

4.2 Intuition behind TD-VAE

In this section, we provide a more intuitive explanation behind the computation and loss of the model. Assume we want to predict a future time step $t_2$ from all the information we have up until time $t_1$. All relevant information up until time $t_1$ (respectively $t_2$) has been compressed into a code $b_{t_1}$ (respectively $b_{t_2}$). We make an observation $x_t$ of the world at every time step $t$ (in RL, this observation may include the reward and previous action), but posit the existence of a state $z_t$ which fully captures the condition of the world at time $t$.

Consider an agent at the current time $t_2$. At that time, the agent can make a guess of what the state of the world is by sampling $z_{t_2}$ from its belief model $p_B(z_{t_2} \mid b_{t_2})$. Because the state $z_{t_2}$ should entail the corresponding observation $x_{t_2}$, the agent aims to maximize $\log p(x_{t_2} \mid z_{t_2})$ (first term of the loss), with a variational bottleneck penalty $-\log p_B(z_{t_2} \mid b_{t_2})$ (second term of the loss) to prevent too much information from the current observation $x_{t_2}$ from being encoded into $z_{t_2}$. Then follows the question ‘could the state of the world at time $t_2$ have been predicted from the state of the world at time $t_1$?’. In order to ascertain this, the agent must estimate the state of the world at time $t_1$. By time $t_2$, the agent has aggregated observations between $t_1$ and $t_2$ that are informative about the state of the world at time $t_1$, which, together with the current guess of the state of the world $z_{t_2}$, can be used to form an ex post guess of the state of the world at time $t_1$. This is done by computing a smoothing distribution $q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})$ and drawing a corresponding sample $z_{t_1}$. Having guessed states of the world $z_{t_1}$ and $z_{t_2}$, the agent optimizes its predictive jumpy model of the world state $\log p(z_{t_2} \mid z_{t_1})$ (third term of the loss). Finally, it should attempt to see how predictable the revealed information was, or in other words, to assess whether the smoothing distribution $q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})$ could have been predicted from information only available at time $t_1$ (this is indirectly predicting $z_{t_2}$ from the state of knowledge $b_{t_1}$ at time $t_1$, the problem we started with). The agent can do so by minimizing the KL between the smoothing distribution and the belief distribution at time $t_1$: $\mathrm{KL}\big(q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2}) \,\|\, p_B(z_{t_1} \mid b_{t_1})\big)$ (fourth term of the loss). Summing all the losses described so far, we obtain the TD-VAE loss.
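The four terms can be sketched as follows for scalar Gaussian distributions. In this hypothetical Python version the fourth term uses the closed-form KL between Gaussians instead of a sample estimate (an assumption; the expectation can be estimated either way), and the negative loss is the sum of the first three terms minus the KL. All distribution parameters are supplied by the caller as stand-ins for learned networks.

```python
import math
import random


def log_normal(x, mean, std):
    """Log density of a 1D Gaussian N(mean, std^2) at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def kl_normal(m_q, s_q, m_p, s_p):
    """KL(N(m_q, s_q^2) || N(m_p, s_p^2)) in closed form."""
    return math.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5


def td_vae_terms(x_t2, belief_t2, belief_t1, smoothing, transition, decoder, rng):
    """Return the four loss terms for one (t1, t2) pair.

    belief_t2, belief_t1: (mean, std) of p_B(z | b_t2) and p_B(z | b_t1).
    smoothing(z_t2), transition(z_t1), decoder(z_t2) return (mean, std) of
    q(z_t1 | z_t2, b_t1, b_t2), p(z_t2 | z_t1) and p(x_t2 | z_t2) respectively.
    """
    z_t2 = rng.gauss(*belief_t2)                       # guess of the current state
    reconstruction = log_normal(x_t2, *decoder(z_t2))  # term 1: log p(x_t2 | z_t2)
    bottleneck = -log_normal(z_t2, *belief_t2)         # term 2: -log p_B(z_t2 | b_t2)
    q_mean, q_std = smoothing(z_t2)
    z_t1 = rng.gauss(q_mean, q_std)                    # ex post guess of the earlier state
    jumpy = log_normal(z_t2, *transition(z_t1))        # term 3: log p(z_t2 | z_t1)
    kl = kl_normal(q_mean, q_std, *belief_t1)          # term 4: KL(q || p_B(z_t1 | b_t1))
    return reconstruction, bottleneck, jumpy, kl


def negative_loss(terms):
    """Combine the terms: -L = term 1 + term 2 + term 3 - KL."""
    reconstruction, bottleneck, jumpy, kl = terms
    return reconstruction + bottleneck + jumpy - kl
```

For example, `td_vae_terms(0.3, (0.2, 0.5), (0.0, 0.6), lambda z: (0.5 * z, 0.4), lambda z: (0.9 * z, 0.3), lambda z: (z, 0.2), random.Random(0))` returns four floats whose last entry is non-negative.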

4.3 Connection with temporal-difference learning

In reinforcement learning, the state of an agent represents a belief about the sum of discounted rewards $R_t = \sum_{\tau \ge t} \gamma^{\tau - t} r_\tau$. In the classic setting, the agent only models the mean of this distribution, represented by the value function $V_t$ or the action-dependent Q-function $Q_t^a$ (Sutton & Barto, 1998). Recently in (Bellemare et al., 2017), a full distribution over $R_t$ has been considered. To estimate $V_{t_1}$ or $Q_{t_1}^a$ at time $t_1$, one does not usually wait to get all the rewards to compute $R_{t_1}$. Instead, one uses an estimate at some future time $t_2$ as a bootstrap to estimate $V_{t_1}$ or $Q_{t_1}^a$ (temporal difference).
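For comparison, the standard tabular TD(0) update (Sutton & Barto, 1998) bootstraps the value of the current state from the estimated value of the next state rather than waiting for the full return. The sketch below is the textbook rule, not part of TD-VAE; the step size `alpha`, the discount `gamma` and the two-state example are illustrative choices.

```python
def td0_update(values, state, reward, next_state, alpha=0.1, gamma=0.9, terminal=False):
    """One tabular TD(0) step: move V(state) toward reward + gamma * V(next_state)."""
    bootstrap = 0.0 if terminal else gamma * values[next_state]
    target = reward + bootstrap                        # bootstrapped estimate of the return
    values[state] += alpha * (target - values[state])
    return values
```

On a two-step chain `a -> b -> end` with a reward of 1 on the last transition, repeated updates drive `V(b)` toward 1 and `V(a)` toward `gamma * V(b) = 0.9`, even though `a` never observes the reward directly.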

In our case, the model expresses a belief about possible future states instead of the sum of discounted rewards. The model trains the belief $p_B(z_{t_1} \mid b_{t_1})$ at time $t_1$ using the belief $p_B(z_{t_2} \mid b_{t_2})$ at some time $t_2$ in the future. It accomplishes this by (variationally) auto-encoding a sample $z_{t_2}$ of the future state into a sample $z_{t_1}$, using the approximate posterior distribution $q(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})$ and the decoding distribution $p(z_{t_2} \mid z_{t_1})$. This auto-encoding mapping translates between states at $t_1$ and $t_2$, forcing beliefs at the two time steps to be consistent. The sample $z_{t_1}$ forms the target for training the belief $p_B(z_{t_1} \mid b_{t_1})$, which appears as a prior distribution over $z_{t_1}$.

5 Experiments

The first experiment uses sequential TD-VAE, which enables a direct comparison to related algorithms for training state-space models. Subsequent experiments use the full TD-VAE model.

5.1 Partially observed MiniPacman

We use a partially observed version of the MiniPacman environment (Racanière et al., 2017), shown in Figure 2. The agent (Pacman) navigates a maze, and tries to eat all the food while avoiding being eaten by a ghost. Pacman sees only a window around itself. To achieve a high score, the agent needs to form a belief state that captures memory of past experience (e.g. which parts of the maze have been visited) and uncertainty on the environment (e.g. where the ghost might be).

We evaluate the performance of sequential (non-jumpy) TD-VAE on the task of modeling a sequence of the agent’s observations. We compare it with two state-space models trained using the standard ELBO of equation 1:


  • A filtering model with encoder $q(z_1, \dots, z_T \mid x_1, \dots, x_T) = \prod_t q(z_t \mid z_{t-1}, b_t)$, where $b_t = f(b_{t-1}, x_t)$.

  • A mean-field model with encoder $q(z_1, \dots, z_T \mid x_1, \dots, x_T) = \prod_t q(z_t \mid b_t)$, where $b_t = f(b_{t-1}, x_t)$.
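The structural difference between the two baselines can be sketched with hypothetical scalar Gaussian encoders; the means, the initial value $z_0 = 0$ and the standard deviation 0.1 are arbitrary assumptions. Under the filtering encoder each $z_t$ depends on the sampled $z_{t-1}$, whereas under the mean-field encoder $z_t$ depends on $b_t$ alone.

```python
import random


def sample_filtering(b, rng):
    """Filtering encoder: z_t ~ q(z_t | z_{t-1}, b_t), so each step uses the previous sample."""
    zs, z_prev = [], 0.0  # z_0 = 0.0 is an arbitrary initial value
    for b_t in b:
        z_prev = rng.gauss(0.5 * b_t + 0.5 * z_prev, 0.1)
        zs.append(z_prev)
    return zs


def sample_mean_field(b, rng):
    """Mean-field encoder: z_t ~ q(z_t | b_t), independent across t given the codes."""
    return [rng.gauss(b_t, 0.1) for b_t in b]
```

With the same random seed, changing $b_1$ changes the filtering sample of $z_2$ but leaves the mean-field sample of $z_2$ untouched, which is why $b_t$ is a belief state only for the mean-field encoder.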

Figure 2 shows the ELBO and estimated negative log probability on a test set of MiniPacman sequences for each model. TD-VAE outperforms both baselines, while the mean-field model performs worst. We note that $b_t$ is a belief state for the mean-field model, but not for the filtering model; the encoder of the latter explicitly depends on the previous latent state $z_{t-1}$, hence $b_t$ is not a sufficient statistic for its posterior over $z_t$. This comparison shows that naively restricting the encoder in order to obtain a belief state hurts the performance significantly; TD-VAE overcomes this difficulty.

ELBO (est.) Filtering model 0.0962 Mean-field model 0.1678 TD-VAE
Figure 2: MiniPacman. Left: A full frame from the game. Pacman (green) is navigating the maze trying to eat all the food (blue) while being chased by a ghost (red). Top right: A sequence of observations, consisting of consecutive windows around Pacman. Bottom right: ELBO and estimated negative log probability on a test set of MiniPacman sequences. Lower is better. Log probability is estimated using importance sampling with the encoder as proposal.

5.2 Moving MNIST

In this experiment, we show that the model is able to learn the state and roll forward in jumps. We consider sequences of length 20 of images of MNIST digits. For each sequence, a random digit from the dataset is chosen, as well as the direction of movement (left or right). At each time step, the digit moves by one pixel in the chosen direction, as shown in Figure 3. We train the model with $t_1$ and $t_2$ separated by a random amount $t_2 - t_1$ drawn from a fixed interval. We would like to see whether the model at a given time $t$ can roll out a simulated experience in time steps $t_1 = t + \delta_1$, $t_2 = t_1 + \delta_2$, $\dots$, with $\delta_1, \delta_2, \dots$ drawn from that same interval, without considering the inputs in between these time points. Note that it is not sufficient to predict the future inputs $x_{t_1}, x_{t_2}, \dots$, as they do not contain information about whether the digit moves left or right. We need to sample a state that contains this information.

Figure 3: Moving MNIST. Left: Rows are example input sequences. Right: Jumpy rollouts from the model. We see that the model is able to roll forward by skipping frames, keeping the correct digit and the direction of motion.

We roll out a sequence from the model as follows: (a) $b_t$ is computed by the aggregation recurrent network from observations up to time $t$; (b) a state $z_t$ is sampled from $p_B(z_t \mid b_t)$; (c) a sequence of states is rolled out by repeatedly sampling $z' \sim p(z' \mid z)$ starting with $z = z_t$; (d) each $z$ is decoded by $p(x \mid z)$, producing a sequence of frames. The resulting sequences are shown in Figure 3. We see that indeed the model can roll forward the samples in steps of more than one elementary time step (the sampled digits move by more than one pixel) and that it preserves the direction of motion, demonstrating that it rolls forward a state.
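Steps (a) to (d) can be sketched generically in Python. The `rollout` function takes the belief code from step (a) as given, and the toy distributions below (a state made of a position and a direction, jumps of 2 to 4 pixels) are hypothetical stand-ins for the learned belief, transition and decoder networks.

```python
import random


def rollout(belief_code, sample_belief, sample_transition, decode, n_jumps, rng):
    """Jumpy rollout: z_t ~ p_B(z | b_t), then repeatedly z ~ p(z' | z), decoding each state."""
    z = sample_belief(belief_code, rng)  # (b) sample a state from the belief
    frames = []
    for _ in range(n_jumps):
        z = sample_transition(z, rng)    # (c) jumpy state-to-state transition
        frames.append(decode(z))         # (d) decode the state into an observation
    return frames


# Hypothetical toy stand-ins: a state is (position, direction); a belief code is
# (position, probability of moving right).
def toy_belief(b, rng):
    position, p_right = b
    return (position, 1 if rng.random() < p_right else -1)


def toy_transition(z, rng):
    position, direction = z
    return (position + direction * rng.randint(2, 4), direction)  # jump of 2 to 4 pixels


def toy_decode(z):
    return z[0]  # "render" only the position
```

Because the direction is part of the sampled state rather than of any single frame, every jump in a rollout keeps moving the same way, mirroring the behavior described for Figure 3.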

5.3 Noisy harmonic oscillator

We would like to demonstrate that the model can build a state even when little information is present in each observation, and that it can sample states far into the future. For this we consider a 1D sequence obtained from a noisy harmonic oscillator, as shown in Figure 4 (first and fourth rows). The frequencies, initial positions and initial velocities are chosen at random from some range. At every update, noise is added to the position and the velocity of the oscillator, but the energy is approximately preserved. The model observes a noisy version of the current position. Attempting to predict the input, which consists of one value, 100 time steps in the future would be uninformative; such a prediction wouldn’t reveal what the frequency or the magnitude of the signal is, and because the oscillator updates are noisy, the phase information would be nearly lost. Instead, we should try to predict as much as possible about the state, which consists of frequency, magnitude and position, and it is only the position that cannot be accurately predicted.

Figure 4: Skip-state prediction for a 1D signal. The input is generated by a noisy harmonic oscillator. Rollouts consist of (a) a jumpy state transition with one of two jump sizes, followed by (b) a number of one-step state transitions. The model is able to create a state and predict it into the future, correctly predicting frequency and magnitude of the signal.

The aggregation RNN is an LSTM; we use a hierarchical TD-VAE with two layers, where the latent variables in the higher layer are sampled first, and their results are passed to the lower layer. The belief, smoothing and state-transition distributions are feed-forward networks, and the decoder simply extracts the first component from the $z$ of the first layer. We also feed the time interval $t_2 - t_1$ into the smoothing and state-transition distributions. We train on fixed-length sequences, with $t_2 - t_1$ taking values chosen at random from one range with probability 0.8 and from a second range with probability 0.2.
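The mixture over jump sizes can be sketched as below. The range bounds `first_max` and `second_max` and the common lower bound of 1 are hypothetical placeholders, not the values used in the experiment; only the 0.8/0.2 mixing probabilities come from the text.

```python
import random


def sample_jump(rng, first_max, second_max, p_first=0.8):
    """Draw t2 - t1 from U{1..first_max} with probability p_first, else from U{1..second_max}."""
    upper = first_max if rng.random() < p_first else second_max
    return rng.randint(1, upper)
```

With `first_max=10` and `second_max=100`, about 82% of sampled jumps fall in `1..10` (80% from the first component plus a tenth of the remaining 20%), while the rest still exposes the model to long skips.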

We analyze what the model has learned as follows. We pick time $t$ and sample $z_t \sim p_B(z_t \mid b_t)$. Then, we choose a time interval $\delta$ to skip, and sample from the forward model $p(z_{t+\delta} \mid z_t, \delta)$ to obtain $z_{t+\delta}$ at time $t+\delta$. To see the content of this state, we roll forward several times with one-step transitions and plot the result, shown in Figure 4. We see that indeed the state is predicted correctly, containing the correct frequency and magnitude of the signal. We also see that the position (phase) is predicted well for the shorter jump and less accurately for the longer one (at which point the noisiness of the system makes it unpredictable).

Finally, we show that TD-VAE training can improve the quality of the belief state. For this experiment, the harmonic oscillator has a different frequency in each of four consecutive time intervals. The first three frequencies $f_1, f_2, f_3$ are chosen at random. The final frequency $f_4$ is set to one fixed value $f_a$ or to another fixed value $f_b$ depending on the relation between $f_1$ and $f_2$ ($f_a$ and $f_b$ are constants). In order to correctly model the signal in the final time interval, the model needs to learn the relation between $f_1$ and $f_2$, store it over the length of the third interval, and apply it over a number of time steps (due to the noise) in the final interval. To test whether the belief state contains the information about this relationship, we train a binary classifier from the belief state to the final frequency $f_4$ at points just before the final interval. We compare two models with the same recurrent architecture (an LSTM), but trained with different objectives: next-step prediction vs. the TD-VAE loss. The figure on the right shows the classification accuracy for the two methods, averaged over several runs. We found that the longer the separating time interval (containing frequency $f_3$) and the smaller the size of the LSTM, the larger the advantage of TD-VAE over the next-step predictor.

5.4 DeepMind Lab environment

In the final experiment, we analyze the model on a more visually complex domain. We use sequences of frames seen by an agent solving tasks in the DeepMind Lab environment (Beattie et al., 2016). We aim to demonstrate that the model holds explicit beliefs about various possible futures, and that it can roll out in jumps. We suggest functional forms inspired by convolutional DRAW: we use convolutional LSTMs for all the circles in Figure 8 and make the model several layers deep (except for the forward updating LSTMs, which are fully connected).

We use time skips sampled uniformly from a fixed range and analyze the content of the belief state $b$. We take three samples $z_1, z_2, z_3$ from $p_B(z \mid b)$, which should represent three instances of possible futures. Figure 5 (left) shows that they decode to roughly the same frame. To see what they represent about the future, we draw several samples $z_i^k \sim p(z' \mid z_i)$ for each $z_i$ and decode them, as shown in Figure 5 (right). We see that for a given $z_i$, the predicted samples $z_i^k$ decode to similar frames (images in the same row). However, the $z_i^k$ for different $z_i$ decode to different frames. This means $b$ represented a belief about several different possible futures, while the different $z_i$ each represent a single possible future.

Figure 5: Beliefs of the model. Left: Independent samples $z_1, z_2, z_3$ from the current belief; all decode to roughly the same frame. Right: Multiple predicted futures for each sample $z_i$. The frames are similar for each $z_i$, but different across the $z_i$'s.

Finally, we show what rollouts look like. We train on time separations chosen uniformly from a fixed range, on a task where the agent tends to move forward and rotate. Figure 6 shows rollouts from the model. We see that the motion appears to go forward and into corridors and that it skips several time steps (real single-step motion is slower).

Figure 6: Rollout from the model. The model was trained on steps uniformly distributed in a fixed range. The model is able to create forward motion that skips several time steps.

6 Conclusions

In this paper, we argued that an agent needs a model that is different from an accurate step-by-step environment simulator. We discussed the requirements for such a model, and presented TD-VAE, a sequence model that satisfies all requirements. TD-VAE builds states from observations by bridging time points separated by random intervals. This allows the states to relate to each other directly over longer time stretches and explicitly encode the future. Further, it allows rolling out in state-space and in time steps larger than, and potentially independent of, the underlying temporal environment/data step size. In the future, we aim to apply TD-VAE to more complex settings, and investigate a number of possible uses in reinforcement learning such as representation learning and planning.


Appendix A TD-VAE as a model of jumpy observations

In section 3, we derive an approximate ELBO which forms the basis of the training loss of the one-step TD-VAE. One may wonder whether a similar idea may underpin the training loss of the jumpy TD-VAE. Here we show how to modify the derivation to provide an approximate ELBO for a slightly different training regime.

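Assume a sequence $x = (x_1, \dots, x_T)$, and an arbitrary distribution $J$ over subsequences $x_s = (x_{t_1}, \dots, x_{t_n})$ of $x$. For each time index $t_i$, we suppose a state $z_{t_i}$, and model the subsequence $x_s$ with a jumpy state-space model $p(x_s, z_s) = \prod_i p(z_{t_i} \mid z_{t_{i-1}})\, p(x_{t_i} \mid z_{t_i})$; denote $z_s = (z_{t_1}, \dots, z_{t_n})$ the state subsequence. We use the exact same machinery as the next-step ELBO, except that we enrich the posterior distribution over $z_s$ by making it depend not only on the observation subsequence $x_s$, but on the entire sequence $x$. This is possible because posterior distributions can have arbitrary contexts; the observations which are part of $x$ but not of $x_s$ effectively serve as auxiliary variables for a stronger posterior. We use the full sequence $x$ to form a sequence of belief states $b_t$ at all time steps. We use in particular the ones computed at the subsampled times $t_i$. By following the same derivation as the one-step TD-VAE, we obtain:

$$\log p(x_{t_i} \mid x_{t_1}, \dots, x_{t_{i-1}}) \ge \mathbb{E}_{z_{t_i} \sim q(z_{t_i} \mid x_{\le t_i}),\; z_{t_{i-1}} \sim q(z_{t_{i-1}} \mid z_{t_i}, x_{\le t_i})}\big[\log p(x_{t_i} \mid z_{t_i}) + \log p(z_{t_{i-1}} \mid x_{t_1}, \dots, x_{t_{i-1}}) + \log p(z_{t_i} \mid z_{t_{i-1}}) - \log q(z_{t_i} \mid x_{\le t_i}) - \log q(z_{t_{i-1}} \mid z_{t_i}, x_{\le t_i})\big]$$
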
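which, using the same belief approximations as the next-step TD-VAE, becomes:

$$\log p(x_{t_i} \mid x_{t_1}, \dots, x_{t_{i-1}}) \gtrsim \mathbb{E}_{z_{t_i} \sim p_B(z_{t_i} \mid b_{t_i}),\; z_{t_{i-1}} \sim q(z_{t_{i-1}} \mid z_{t_i}, b_{t_{i-1}}, b_{t_i})}\big[\log p(x_{t_i} \mid z_{t_i}) + \log p_B(z_{t_{i-1}} \mid b_{t_{i-1}}) + \log p(z_{t_i} \mid z_{t_{i-1}}) - \log p_B(z_{t_i} \mid b_{t_i}) - \log q(z_{t_{i-1}} \mid z_{t_i}, b_{t_{i-1}}, b_{t_i})\big]$$
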
which is the same loss as the TD-VAE for a particular choice of the sampling scheme (only sampling pairs).

Appendix B Derivation of the TD-VAE model from its desired properties

In this section we start with a general recurrent variational auto-encoder and consider how the desired properties detailed in sections 1 and 2 constrain the architecture. We will find that these constraints in fact naturally lead to the TD-VAE model.

Let us first consider a relatively general form of temporal variational auto-encoder. We consider recurrent models where the same module is applied at every step, and where outputs are sampled one at a time (so that arbitrarily long sequences can be generated). A very general form of such an architecture consists of forward-backward encoder RNNs and a forward decoder RNN (Figure 7), otherwise allowing for all the connections. Several works (Chung et al., 2015; Lee et al., 2018; Archer et al., 2015; Fraccaro et al., 2016; Liu et al., 2017; Goyal et al., 2017; Buesing et al., 2018; Serban et al., 2017) fall into this framework.

Figure 7: Recurrent variational auto-encoder. General recurrent variational auto-encoder, obtained by imposing recurrent structure, forward sampling and allowing all potential connections. Note that the encoder can have several alternating layers of forward and backward RNNs. Also note that the connection 1 has to be absent if the backwards encoder is used. Possible skip connections are not shown as they can directly be implemented in the RNN weights. If connections 2 are absent, the model is capable of forward sampling in latent space without going back to observations.

Now let us consider our desired properties.

In order to sample forward in latent space, the encoder must not feed into the decoder or the prior of the latent variables, since observations are required to compute the encoded state, and we would therefore require the sampled observations to compute the distribution over future states and observations.

We next consider the constraint of computing a belief state $b_t$. The belief state represents the state of knowledge up to time $t$, and therefore cannot receive an input from the backward encoder, which processes future observations. Furthermore, $b_t$ should have unrestricted access to information: it should ideally not be disturbed by sampling (two identical agents with the same information should compute the same belief state, which will not be the case if the computation involves sampling), nor go through information bottlenecks. This suggests using the forward encoder for computing the belief state.

Given the use of a decoder RNN, the information needed to predict the future could be stored in the decoder state, which may prevent the encoder from storing the full state information (in other words, the information contained in $b_t$ about the state could be partially stored in the decoder state and in the previous sample). This presents two options. The first is to make the prior and the reconstruction depend only on the latent state, i.e. to only consider distributions $p(z_t \mid z_{t-1})$ and $p(x_t \mid z_t)$. The second is to include the decoder state in the belief state (together with the encoder state). We choose the former option, as our next constraint will invalidate the latter.

Next, we argue that smoothing, i.e. making the posterior depend on the future, is an important property that should be part of our model. As an example, imagine a box that can contain one of two items, $A$ or $B$, and two time points: $t_1$, before opening the box, when we don't know its content, and $t_2$, after opening it. We would want our latent variable $z_t$ to represent the content of the box. The perfect model of the content of the box is that the content doesn't change (the same object is in the box before and after opening it). Now imagine $A$ is in the box. Our belief at $t_2$ is high for $A$, but our belief at $t_1$ is uncertain. If we sample $z_{t_1}$ from this belief without considering $z_{t_2}$, we would sample $B$ half of the time. However, we would then be learning a wrong model of the world: that $B$ turns into $A$. To solve this problem, we should sample $z_{t_2}$ first and then, given this value, sample $z_{t_1}$.
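The box example can be checked exactly. The sketch below (hypothetical names, exact rational arithmetic) computes the joint distribution over $(z_{t_1}, z_{t_2})$ under two sampling schemes, assuming item A is in the box, a belief of (1/2, 1/2) at $t_1$, a certain belief in A at $t_2$, and a transition model under which the content never changes.

```python
from fractions import Fraction

ITEMS = ("A", "B")
# Filtering beliefs: uncertain before opening (t1), certain after opening (t2).
belief_t1 = {"A": Fraction(1, 2), "B": Fraction(1, 2)}
belief_t2 = {"A": Fraction(1), "B": Fraction(0)}


def transition(z1: str, z2: str) -> Fraction:
    """True world model: the content of the box never changes."""
    return Fraction(1) if z1 == z2 else Fraction(0)


def pairs_independent() -> dict:
    """Joint over (z_t1, z_t2) when each is sampled from its own belief."""
    return {(a, b): belief_t1[a] * belief_t2[b] for a in ITEMS for b in ITEMS}


def pairs_smoothed() -> dict:
    """Sample z_t2 from its belief first, then z_t1 from the smoothing
    posterior p(z_t1 | z_t2), proportional to belief_t1(z_t1) * transition."""
    joint = {}
    for b in ITEMS:
        norm = sum(belief_t1[a] * transition(a, b) for a in ITEMS)
        for a in ITEMS:
            post = belief_t1[a] * transition(a, b) / norm if norm else Fraction(0)
            joint[(a, b)] = belief_t2[b] * post
    return joint
```

Under independent sampling, half of the sampled pairs are (B, A), a transition that never happens in this world; sampling $z_{t_2}$ first and conditioning $z_{t_1}$ on it yields only (A, A).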

Smoothing requires the use of the backward encoder; this prevents the use of the decoder state as part of our belief state, since the decoder has access to the encoder, and the encoder depends on the future. We therefore require a latent-to-latent model $p(z_t \mid z_{t-1})$.
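Because the prior over states depends only on the previous state, the model can be rolled out entirely in latent space, and observations need to be decoded only when wanted. A minimal sketch, assuming a hypothetical transition callable `p_T` that returns the mean and log-standard deviation of a diagonal Gaussian:

```python
import numpy as np


def rollout_latent(z0, p_T, steps, rng):
    """Sample z_1, ..., z_steps with z_{k+1} ~ p_T(. | z_k).

    No observation is sampled or re-encoded along the way; p_T is a
    hypothetical callable returning (mu, log_sigma) of a diagonal Gaussian.
    """
    zs = []
    z = z0
    for _ in range(steps):
        mu, log_sigma = p_T(z)
        z = mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)
        zs.append(z)
    return zs


# Example with a hypothetical linear-Gaussian transition.
p_T = lambda z: (0.9 * z, np.full_like(z, -2.0))
trajectory = rollout_latent(np.ones(3), p_T, steps=10, rng=np.random.default_rng(0))
```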

We are therefore left with a forward encoder, which ideally computes the belief state; a backward encoder, which together with the forward encoder computes posteriors over states; and a state-to-state forward model. The backward encoder is trained through its use as a posterior in the state-space model. How, then, do we make sure that the forward encoder is in fact trained to contain the belief state? To do so, we force the belief distribution $p_B(z_t \mid b_t)$ to be close to the posterior by using a KL term between the prior belief and the posterior belief.
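If both the posterior and the belief are diagonal Gaussians (an assumption of this sketch), the KL term has the closed form $\mathrm{KL}(q \,\|\, p) = \sum_k \big[\log \sigma_{p,k} - \log \sigma_{q,k} + \tfrac{\sigma_{q,k}^2 + (\mu_{q,k} - \mu_{p,k})^2}{2 \sigma_{p,k}^2} - \tfrac{1}{2}\big]$. The helper below evaluates it from means and log-standard deviations; a one-sample estimate $\log q(z) - \log p(z)$ with $z \sim q$ is an alternative that does not need the closed form.

```python
import math


def kl_diag_gaussians(mu_q, log_sigma_q, mu_p, log_sigma_p):
    """KL(q || p) between diagonal Gaussians given means and log-std devs."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_sigma_q, mu_p, log_sigma_p):
        var_q = math.exp(2.0 * lq)
        var_p = math.exp(2.0 * lp)
        total += lp - lq + (var_q + (mq - mp) ** 2) / (2.0 * var_p) - 0.5
    return total
```

For example, shifting a unit-variance Gaussian by one unit of its mean gives a KL of 0.5, and identical distributions give 0.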

Before detailing the KL term, we need to consider how to run the backward encoder in practice. Ideally, we would like to train the model in a nearly forward fashion, on arbitrarily long sequences. This prevents running the backward inference from the end of the sequence. However, if we assume that $p_B(z_{t_2} \mid b_{t_2})$ represents our best belief about the future, we can take a sample from it as an instance of the future: $z_{t_2} \sim p_B(z_{t_2} \mid b_{t_2})$. This sample forms a type of bootstrap information. We can then go backwards and infer what the world would have looked like given this future (e.g. the object was still in the box even if we didn't see it). Using VAE training, we sample $z_{t_1}$ from its posterior $q_S(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2})$ (the conditioning variables are the ones we have available locally), using $p_B(z_{t_1} \mid b_{t_1})$ as prior. Conversely, for $z_{t_2}$, we sample from $p_B(z_{t_2} \mid b_{t_2})$ as posterior, but with $p_T(z_{t_2} \mid z_{t_1})$ as prior. We therefore obtain the VAE losses $\log q_S(z_{t_1} \mid z_{t_2}, b_{t_1}, b_{t_2}) - \log p_B(z_{t_1} \mid b_{t_1})$ at $t_1$ and $\log p_B(z_{t_2} \mid b_{t_2}) - \log p_T(z_{t_2} \mid z_{t_1})$ at $t_2$. In addition, we have the reconstruction term $-\log p_D(x_{t_2} \mid z_{t_2})$ that grounds the latent in the input. The whole algorithm is presented in Figure 1.
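The procedure above can be sketched as a one-sample estimate of the loss for a single pair $t_1 < t_2$. This is an illustration, not the paper's implementation: the networks are replaced by hypothetical callables `p_B`, `q_S`, `p_T`, `p_D` that return `(mu, log_sigma)` of diagonal Gaussians (a Gaussian decoder is also an assumption), and no gradients are computed.

```python
import numpy as np


def gaussian_log_prob(x, mu, log_sigma):
    """Log-density of a diagonal Gaussian, summed over dimensions."""
    return float(np.sum(-0.5 * np.log(2.0 * np.pi) - log_sigma
                        - 0.5 * ((x - mu) / np.exp(log_sigma)) ** 2))


def sample_gaussian(mu, log_sigma, rng):
    """Reparameterised sample mu + sigma * eps."""
    return mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)


def td_vae_pair_loss(x_t2, b_t1, b_t2, p_B, q_S, p_T, p_D, rng):
    """One-sample estimate of the negative TD-VAE objective for t1 < t2."""
    # Bootstrap: draw an instance of the future from the belief at t2.
    mu_b2, ls_b2 = p_B(b_t2)
    z_t2 = sample_gaussian(mu_b2, ls_b2, rng)
    # Go backwards: infer z_t1 given the sampled future (smoothing posterior).
    mu_s, ls_s = q_S(z_t2, b_t1, b_t2)
    z_t1 = sample_gaussian(mu_s, ls_s, rng)
    # VAE term at t1: posterior q_S against prior p_B(z_t1 | b_t1).
    term_t1 = gaussian_log_prob(z_t1, mu_s, ls_s) - gaussian_log_prob(z_t1, *p_B(b_t1))
    # VAE term at t2: posterior p_B(z_t2 | b_t2) against prior p_T(z_t2 | z_t1).
    term_t2 = gaussian_log_prob(z_t2, mu_b2, ls_b2) - gaussian_log_prob(z_t2, *p_T(z_t1))
    # Reconstruction term grounding z_t2 in the observation x_t2.
    recon = gaussian_log_prob(x_t2, *p_D(z_t2))
    return term_t1 + term_t2 - recon
```

Minimising this quantity maximises the TD-VAE objective for the pair; in a real model the callables would be neural networks and the estimate would be differentiated with an autodiff framework.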

Appendix C Hierarchical Model

In the main paper we detailed a framework for learning models by bridging two temporally separated time points. It would be desirable to model the world with different hierarchies of state, the higher-level states predicting the same-level or lower-level states, and ideally representing more invariant or abstract information. In this section we describe a stacked (hierarchical) version of the model.

The first part to extend to $L$ layers is the RNN that aggregates observations to produce the belief state $b_t$. Here we simply use a deep LSTM, but with layer $l$ also receiving inputs from layer $l+1$ at the previous time step. This is so that the higher layers can influence the lower ones (and vice versa). For $l = 1, \ldots, L$:

$$b^l_t = \mathrm{LSTM}^l\big(b^{l-1}_t,\, b^{l}_{t-1},\, b^{l+1}_{t-1}\big),$$

setting $b^0_t = x_t$ and $b^{L+1}_t = \emptyset$.
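A minimal NumPy sketch of one time step of this deep belief LSTM follows. Its assumptions: the belief $b^l_t$ is the LSTM's hidden output, all layers share one hidden size, weights are random (no training), and the missing input above the top layer is a zero vector. `LSTMCell` and `deep_belief_step` are hypothetical names.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class LSTMCell:
    """Minimal LSTM cell with random weights (illustration only)."""

    def __init__(self, input_size, hidden_size, rng):
        self.hidden_size = hidden_size
        self.W = rng.normal(0.0, 0.1, size=(4 * hidden_size, input_size + hidden_size))
        self.b = np.zeros(4 * hidden_size)

    def __call__(self, x, h, c):
        gates = self.W @ np.concatenate([x, h]) + self.b
        i, f, o, g = np.split(gates, 4)
        c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h_new = sigmoid(o) * np.tanh(c_new)
        return h_new, c_new


def deep_belief_step(cells, x_t, h_prev, c_prev):
    """Layer l reads b^{l-1}_t (b^0_t = x_t) and b^{l+1}_{t-1} from the
    previous step; the top layer gets zeros in place of a layer above."""
    num_layers = len(cells)
    h_new, c_new = [], []
    below = x_t
    for l in range(num_layers):
        if l + 1 < num_layers:
            above_prev = h_prev[l + 1]  # top-down input from time t-1
        else:
            above_prev = np.zeros(cells[l].hidden_size)
        h, c = cells[l](np.concatenate([below, above_prev]), h_prev[l], c_prev[l])
        h_new.append(h)
        c_new.append(c)
        below = h  # bottom-up input for the next layer at time t
    return h_new, c_new
```

Usage: with hidden size `H` and input size `X`, layer 0 takes inputs of size `X + H` and every higher layer `2 * H`.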

Figure 8: Deep version of the model from Figure 1. A deep version of the model is formed by creating a layer similar to the shallow model of Figure 1 and replicating it. Both sampling and inference proceed downwards through the layers. Circles have the same meaning as in Figure 1 and are implemented using neural networks, such as LSTMs.

We create a deep version of the belief part of the model by stacking the shallow one, as shown in Figure 8. In the usual spirit of deep directed models, the model samples downwards, generating higher level representations before the lower level ones (closer to pixels). The model implements deep inference, that is, the posterior distribution of one layer depends on the samples from the posterior distribution in previously sampled layers. The order of inference is a design choice, and we use the same direction as that of generation, from higher to lower layers, as done for example by Gregor et al. (2016); Kingma et al. (2016); Rasmus et al. (2015). We implement the dependence of various distributions on latent variables sampled so far using a recurrent neural network that summarizes all such variables (in a given group of distributions). We don’t share the weights between different layers. Given these choices, we can allow all connections consistent with the model. Next we describe the functional forms used in our model.
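The top-down order of sampling and inference can be sketched as ancestral sampling over layers, where each layer's Gaussian conditions on a context and on every sample drawn so far. As a simplification, the samples are concatenated instead of summarised by a recurrent network, and the per-layer heads are hypothetical random linear maps with unshared weights.

```python
import numpy as np


def make_linear_head(in_dim, out_dim, rng):
    """Hypothetical layer head: input -> (mu, log_sigma) of a diagonal Gaussian."""
    W = rng.normal(0.0, 0.1, size=(2 * out_dim, in_dim))

    def head(inp):
        out = W @ inp
        return out[:out_dim], out[out_dim:]

    return head


def sample_top_down(heads, context, rng):
    """heads are ordered from the top layer to the bottom layer.

    Each layer conditions on the context and on all samples drawn in the
    layers above it (concatenated here, standing in for the summary RNN).
    """
    samples = []
    for head in heads:
        inp = np.concatenate([context] + samples)
        mu, log_sigma = head(inp)
        samples.append(mu + np.exp(log_sigma) * rng.standard_normal(mu.shape))
    return samples  # top layer first


rng = np.random.default_rng(0)
heads = [make_linear_head(4, 3, rng), make_linear_head(4 + 3, 2, rng)]
zs = sample_top_down(heads, np.ones(4), rng)
```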

Appendix D Functional forms and parameter choices

Here we describe the functional forms used in more detail. We start with those used for the harmonic oscillator experiments. Let $x_t$, $t = 1, \ldots, T$, be the input sequence. The belief state network (for both layers) is a standard LSTM network: $b_t = \mathrm{LSTM}(b_{t-1}, x_t)$. For any arbitrary context $c$, we denote by $N(c)$ the map from $c$ to a normal distribution with mean $\mu(c)$ and log-standard deviation $\log \sigma(c)$, where $\mu(c) = W_\mu h + B_\mu$ and $\log \sigma(c) = W_\sigma h + B_\sigma$, with $h$ a hidden layer computed from $c$, $W$ weight matrices and $B$ biases. We use the letter $N$ for all such maps (even when they don't share weights); weights are shared if the contexts are identical except for the time index. Consider the update for a given pair of time points $t_1 < t_2$. We use a two-layer hierarchical TD-VAE. A variable at layer $l$ and time $t$ is denoted $z^l_t$. Beliefs at layer $l$ and time $t$ are denoted $b^l_t$. The set of equations describing the system is as follows.

The hidden layer of the maps is ; the size of each is . Belief states have size . We use the Adam optimizer with learning rate .
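The map from a context $c$ to a normal distribution can be sketched as a small module with one hidden layer and two affine output heads. The tanh nonlinearity, the initialisation and all sizes below are assumptions chosen for illustration (the sizes are not the ones used in the experiments); `GaussianMap` is a hypothetical name. Reusing one instance for contexts that differ only in time index is how weight sharing across time would look.

```python
import numpy as np


class GaussianMap:
    """Context c -> diagonal normal with mean mu(c) and log-std log_sigma(c)."""

    def __init__(self, context_dim, hidden_dim, out_dim, rng):
        self.W_h = rng.normal(0.0, 1.0 / np.sqrt(context_dim), (hidden_dim, context_dim))
        self.B_h = np.zeros(hidden_dim)
        self.W_mu = rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), (out_dim, hidden_dim))
        self.B_mu = np.zeros(out_dim)
        self.W_sigma = rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), (out_dim, hidden_dim))
        self.B_sigma = np.zeros(out_dim)

    def __call__(self, c):
        h = np.tanh(self.W_h @ c + self.B_h)  # hidden layer (tanh assumed)
        mu = self.W_mu @ h + self.B_mu
        log_sigma = self.W_sigma @ h + self.B_sigma
        return mu, log_sigma

    def sample(self, c, rng):
        mu, log_sigma = self(c)
        return mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)


rng = np.random.default_rng(0)
belief_map = GaussianMap(context_dim=16, hidden_dim=32, out_dim=8, rng=rng)
# The same instance serves b_{t1} and b_{t2}: weights shared across time.
z_t1 = belief_map.sample(np.ones(16), rng)
z_t2 = belief_map.sample(-np.ones(16), rng)
```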

The same network works for the MNIST experiment with the following modifications. Observations are pre-processed by an MLP with two hidden layers and ReLU nonlinearities. The decoder $p_D$ also has a two-layer MLP, which outputs the logits of a Bernoulli distribution.

was not passed as input to any network.
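With a Bernoulli decoder, the reconstruction term is the log-likelihood of the binary pixels under the output logits. A numerically stable way to compute it (a standard identity, sketched here with a hypothetical helper name) uses $\log \sigma(l) = l - \mathrm{softplus}(l)$ and $\log(1 - \sigma(l)) = -\mathrm{softplus}(l)$, with $\mathrm{softplus}(l)$ evaluated as `logaddexp(0, l)`:

```python
import numpy as np


def bernoulli_log_likelihood(x, logits):
    """Sum over pixels of log p(x | logits) for independent Bernoulli pixels.

    x * l - softplus(l) equals log sigmoid(l) when x = 1 and
    log(1 - sigmoid(l)) when x = 0; logaddexp avoids overflow.
    """
    x = np.asarray(x, dtype=float)
    logits = np.asarray(logits, dtype=float)
    return float(np.sum(x * logits - np.logaddexp(0.0, logits)))
```

Computing the probabilities first and taking their logarithm would overflow or return `-inf` for large logits; the form above stays finite.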

For the DeepMind Lab experiments, all the circles in Figure 8 are LSTMs. Blue circles are fully connected LSTM, the others are all convolutional LSTM. We use a fully connected LSTM of size and convolutional layers of size . All kernel sizes are . The decoder layer has an extra canvas layer, similar to DRAW.