Low-pass Recurrent Neural Networks - A memory architecture for longer-term correlation discovery

05/13/2018 · Thomas Stepleton et al. · Google

Reinforcement learning (RL) agents performing complex tasks must be able to remember observations and actions across sizable time intervals. This is especially true during the initial learning stages, when exploratory behaviour can increase the delay between specific actions and their effects. Many new or popular approaches for learning these distant correlations employ backpropagation through time (BPTT), but this technique requires storing observation traces long enough to span the interval between cause and effect. Besides memory demands, learning dynamics like vanishing gradients and slow convergence due to infrequent weight updates can reduce BPTT's practicality; meanwhile, although online recurrent network learning is a developing topic, most approaches are not efficient enough to use as replacements. We propose a simple, effective memory strategy that can extend the window over which BPTT can learn without requiring longer traces. We explore this approach empirically on a few tasks and discuss its implications.


1 Introduction

Numerous methods have been proposed to give neural networks the ability to remember past observations. These range from LSTMs (Hochreiter & Schmidhuber, 1997; Graves, 2013) and GRUs (Cho et al., 2014a), along with extensions of these (e.g. Danihelka et al., 2016), to "external memory modules" like the DNC (Graves et al., 2016) or Memory Networks (Sukhbaatar et al., 2015). Nearly all practical applications of these techniques rely on backpropagation through time (BPTT) to compute weight updates. BPTT unrolls the network over time, then applies backpropagation (reverse accumulation of the chain rule), a computationally efficient strategy with comparatively low memory requirements.

Nevertheless, BPTT has drawbacks. The memory demand still increases linearly with the size of the interval over which the network is unrolled, as all observations and hidden activations throughout must be stored. This can be prohibitive beyond a few hundred steps when observations are large. Partial mitigation strategies and trade-offs do exist: Gruslys et al. (2016), for example, use careful bookkeeping to limit memory usage at the cost of additional computation.

Beyond memory demands, BPTT applied to long sequences can encounter vanishing and exploding gradients (Bengio et al., 1994; Pascanu et al., 2013). The algorithm recursively multiplies the gradient over future steps by the Jacobian of the current timestep, and the resulting products (which are summed) can be unstable. Contraction towards zero ("vanishing gradients") and growth to enormous values ("exploding gradients") are common problems. LSTM and GRU architectures introduce gating mechanisms to mitigate vanishing gradients, and constraining recurrent network weights can also help. Gradient clipping is widely used to address exploding gradients. Overall, though, gradient misbehaviour can still be a difficult obstacle, particularly over long intervals.

Finally, if a long trace of observations must be collected before BPTT can compute a single weight update, infrequent gradient steps can slow convergence significantly. (Indeed, we believe one of our experiments demonstrates this.) In reinforcement learning, infrequent updates can be especially problematic, since noisy credit assignment already makes learning data-inefficient to begin with.

Given these issues, nearly all BPTT applications specifically use truncated BPTT, a variation that simply limits the network unroll to a reasonably small fixed number of timesteps. Chunking the input data in this way ensures manageable memory needs, frequent weight updates, and fewer chances (i.e. matrix multiplications) for gradients to run amok. Truncated BPTT obtains good results in many settings, including language modeling (Chelba et al., 2013) and reinforcement learning (Mnih et al., 2016). However, if important correlations do not occur within the same chunk, the network’s ability to learn them may be severely impaired.

In this work we explore a simple memory strategy that increases the effective intervals over which correlations can be learnt well beyond the time windowing inherent to truncated BPTT. We consider advantages and weaknesses of our approach and highlight efficiency gains in several reinforcement learning tasks.

2 Recurrent Neural Networks with concrete

Observations in naturalistic settings have rich underlying structure at multiple timescales, from high-frequency details to slowly-changing contexts. Although novel information arrives frequently, many features recur at time scales matching their permanence, like seasonal cycles (fixed) or locomotive gaits (variable). Animals and humans have capabilities and habits that exploit this structure. Grazing animals revisit seasonal foraging grounds, where broad spatial recall aids efficient exploitation of known terrain whilst omitting distracting detail. Humans also show varying precision for memories of past events: we usually remember our distant past more coarsely than our immediate past, and overcoming this structure tends to require some kind of prosthesis (photographs, notes, etc.). This oversimplification omits important memory capabilities: declarative memory, fine episodic recall of events, and so on. Nevertheless, we suggest this "logarithmic" temporal resolution forms desirable contexts for conditioning future behaviour, a viewpoint echoing Howard (2018).

Figure 1: (a): Diagram of a Concrete memory system embedded within a simplification of the actor-critic network from the experiments. Parts in black are considered internal to the memory system by convention. An initial layer realises an input embedding (far left—here, the input is an observation processed by a convolutional network), which feeds into the chain of four low-pass-filtering pools forming the core of the memory. Chained pools have an exponentially diminishing smoothing factor with a fixed base β. Next, feedforward layers interpret pool contents: these comprise a set of per-pool layers (nicknamed viewports) followed by a single unifying layer (a summariser). Subsequent network apparatus outside of the memory system yields behaviour and value judgments (the policy π and value estimate V). In (b), a network used as an experimental control that replaces the chain with parallel filters. All components above the pools are identical to (a) and have been elided.

Reflecting this structure, "RNNs with concrete", or more formally Low-Pass Recurrent Neural Networks (LP-RNN), allocate varying representational precision to events of different frequency and recency. They accomplish this via chains of low-pass filtering pools, shown in their simplest form in Figure 1a, where each pool comes to represent a different, broadening portion of the network's past (Figure 2(a); also Appendix A). Our approach relates closely to other models that exploit low-frequency content of inputs or activations, particularly Oliva et al. (2017) and Spears et al. (2017), but our focus centres on our conceptually simple model, its application to reinforcement learning, and its ability to overcome limitations in correlation discovery imposed by truncated BPTT.

Formally, let us define the hidden state of some neural network as $h_t = f(x_t)$, where $x_t$ corresponds to the input of the model at step $t$, in this case some observation, and $f$ is some neural network such as an MLP or convolutional network.

Then, in its simplest form, we specify LP-RNN as a chain of $N$ memory pools $m^i_t$ (indexed by $i$) defined as:

$$m^i_t = (1 - \lambda_i)\, m^i_{t-1} + \lambda_i\, m^{i-1}_t,$$

where $m^0_t = h_t$ and $\lambda_i \in (0, 1]$, and where optionally $\lambda_i = \beta^{-i}$ for some base $\beta > 1$ to obtain a regular temporal tiling as in Figure 2(a). The output of the model can then be defined as some feedforward model $g$, applied to the concatenation of all memory pools, that is

$$y_t = g\!\left(\left[m^1_t; m^2_t; \dots; m^N_t\right]\right).$$
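To make the update concrete, here is a minimal NumPy sketch of the chained pool recurrence as reconstructed above; the function and variable names, the within-timestep ordering of the chain, and the particular base are illustrative assumptions rather than the authors' implementation.

```python
import numpy as np

def lp_rnn_step(pools, h_t, beta=3.0):
    """One LP-RNN step: each pool low-pass filters the pool upstream of it.

    pools : list of N arrays; pools[0] is the "fastest" pool.
    h_t   : embedded input at this timestep (same shape as each pool).
    beta  : base of the exponentially diminishing smoothing factors.
    """
    upstream = h_t
    new_pools = []
    for i, pool in enumerate(pools, start=1):
        lam = beta ** (-i)                       # smoothing factor for pool i
        pool = (1.0 - lam) * pool + lam * upstream
        new_pools.append(pool)
        upstream = pool                          # chain: the next pool filters this one
    return new_pools

# Usage: four pools of size 8, fed a random embedding at every step.
pools = [np.zeros(8) for _ in range(4)]
for _ in range(100):
    pools = lp_rnn_step(pools, np.random.randn(8))
reader_input = np.concatenate(pools)             # what the feedforward reader g sees
```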

This approach now allows a learning algorithm like BPTT to take advantage of the historical information distributed in a windowed fashion across the memory pools. In exchange for diminishing representational precision, an LP-RNN accumulates data over far longer intervals than those over which the gradients are backpropagated—even orders of magnitude longer.

For this reason, no off-the-shelf gradient-based method can alter the slowly-changing representations in pools further down the chain—indeed, the chained low-pass filters effectively extinguish gradients passing through, and in practice we reduce compute needs by blocking gradients through all but the “fastest” pools. A network must learn to read an LP-RNN—something a feedforward network can do—and because gradients need not (or cannot) pass beyond the reader, lengthy network unrolls have limited benefit. This is the essence of LP-RNN’s “workaround” for mitigating the limitations of BPTT.

(Of course, for architectures in Figure 1 and Appendix D, some gradients must pass through the LP-RNN to train the input-processing apparatus to produce useful representations. For our tasks, we find it suffices to pass gradients through the first, “fastest” pool only. Different architectures and techniques like specifying auxiliary losses Jaderberg et al. (2017a); Mirowski et al. (2017) are other avenues for ensuring that data accumulating in the memory is rich and meaningful.)
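The gradient-blocking scheme just described can be sketched as follows in PyTorch; the use of detach() as the blocking mechanism and the decision to block everything beyond the first pool are assumptions made for illustration, not the authors' code.

```python
import torch

def lp_rnn_step_blocked(pools, h_t, beta=3.0):
    """LP-RNN step in which gradients flow only through the first pool.

    Pools beyond the first are detached from the graph, so the input
    embedding and the reader are trained via the "fastest" pool alone,
    while the slower pools simply accumulate values.
    """
    upstream = h_t
    new_pools = []
    for i, pool in enumerate(pools, start=1):
        lam = beta ** (-i)
        new_val = (1.0 - lam) * pool + lam * upstream
        if i > 1:
            new_val = new_val.detach()   # slower pools carry values, not gradients
        new_pools.append(new_val)
        upstream = new_val
    return new_pools
```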

LP-RNN’s chained architecture stands in contrast to parallel multi-scale filtering of the input (e.g. Figure 1b, an architecture we explore as an experimental control) or to direct transformations of parallel filters (see Related Work). We note that the arrangement presents appealing opportunities for more intricate inter-pool links, such as random projections, nonlinearities, or gating, which we hope to pursue in future work.

For now, the “odometer-like”, gradually-decelerating flow of data through an LP-RNN, as well as its design for very long-term operation, recall to us Arthur Ganson’s kinetic sculpture Machine With Concrete (Ganson, 1992). As such, we have informally nicknamed our model an “RNN with concrete”, “concrete memory”, or just “Concrete” for short.

3 Related Work

LSTM and GRU, the workhorses of recurrent models, use gating to mitigate vanishing gradients and (often) gradient clipping to limit divergence. Most intuitive characterisations of these models present the gates as binary “valves” that either close to preserve stored data or open to replace it with new input. In this account, LSTMs can retain information indefinitely; but, as gates are actually computed as sigmoid functions applied to linear projections of the input and prior hidden state, it is likely that they are often neither fully open nor closed. Indeed, purely binary gating would disrupt gradient-based learning, and it is not usually observed in practice.

Consider the formula for the LSTM memory cell update:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t,$$

where $c_t$ represents the state of the memory cell at time step $t$, $i_t$ is the value of the input gate, $f_t$ the value of the forget gate, $\tilde{c}_t$ the proposed update for the memory, and $\odot$ element-wise multiplication. This form is reminiscent of a low-pass filter; in fact, if we assume $f_t = 1 - i_t$, then it is precisely a low-pass filter. Tying the forget and input gates in this manner is used in practice (Greff et al., 2017); indeed, this low-pass update formula is hard-coded into the GRU (Cho et al., 2014b). We also observe that the earliest LSTM lacked forget gates altogether.
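A tiny numerical illustration of this point, assuming the standard LSTM cell update above with the tied-gate constraint and an input gate held constant (all numbers are arbitrary):

```python
# With tied gates (f_t = 1 - i_t) and a constant input gate, the LSTM cell
# update reduces to an exponentially weighted moving average of the proposed
# updates c~_t -- i.e. a low-pass filter step applied repeatedly.
c = 0.0
input_gates = [0.3, 0.3, 0.3, 0.3]       # i_t (held constant here)
proposals = [1.0, 0.0, 2.0, -1.0]        # c~_t
for i_t, c_tilde in zip(input_gates, proposals):
    f_t = 1.0 - i_t                      # the tied forget gate
    c = f_t * c + i_t * c_tilde          # identical to one low-pass filter step
print(c)                                 # the smoothed value of the proposals
```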

LSTM and GRU can thus be considered to filter their activations, but in contrast to our work, the amount of filtering depends on the input and on prior hidden-unit contents. Because it is learnt, the model can lose diversity in filtering behaviour and focus on short-term information early on in training. If the forget gate operates independently of other gates, these models can also reset themselves, giving them wavelet-like finite-time support. Our simple model fixes its filtering behaviour a priori, and individual pools have infinite (albeit temporally concentrated) support.

Another family of approaches augments recurrent models with external memory storage. These approaches apply attention or nearest-neighbour lookups to large buffers of activations, searching old hidden states to find associations over long intervals. DNC and Memory Networks are canonical examples; sequence-to-sequence models with attention (Bahdanau et al., 2014) also relate. In reinforcement learning, episodic control methods (Blundell et al., 2016; Pritzel et al., 2017) do the same to search over entire histories of agent observations. MbPA (Sprechmann et al., 2018) is another related approach, used in supervised learning.

Some architectures address longer-term memory by storing information verbatim over time intervals. NARX (Lin et al., 1996) derives new outputs from a queue ("delay line") of prior network inputs and outputs. More recently, Clockwork RNN (Koutnik et al., 2014) has units that only update at set intervals and otherwise hold their state. In contrast to these fully-connected networks with non-lossy storage, LP-RNN accepts information loss within a preset arrangement of pooled filtering units specifically built to obtain the temporally-tiled impulse behaviour in Figure 2(a).

In Schmidhuber (1992), El Hihi & Bengio (1995), and Chung et al. (2017), different recurrent multi-timescale architectures related to our work are explored. These architectures are hierarchical, with higher layers operating more slowly, resulting in a different input response from ours. In Chung et al. (2017) the time-scale is learnt and is input dependent. Sordoni et al. (2015) and Ling et al. (2015) exploit domain knowledge, providing input-driven time-scale boundaries for the different layers in the hierarchy. Fernández et al. (2007) and Kong et al. (2015) have the similar objective of discovering the segmentation structure through direct supervision.

More directly related works investigate various kinds of low-pass filtering of inputs or activations. In reservoir computing, Echo State Networks benefit considerably from "leaky integration" units: individual hidden units set up to perform randomly-sampled amounts of low-pass filtering, where a carefully-crafted distribution controls the range of available filters (Jaeger et al., 2007). Holzmann & Hauser (2009) extend this idea to use band-pass instead of low-pass filters. The serendipitous occurrence of an LP-RNN structure in a randomly-constructed network is unlikely; even learning it in any fully-connected architecture like those described above may be quite difficult.

Filtering has been applied to simpler recurrent networks with some minor improvements (Bengio et al., 2013). Mikolov et al. (2014) add low-pass filtered versions of the network's input to the inputs of the network's recurrent and output layers. The same filtering is used throughout, so this approach is akin to having the first pool in a Concrete memory project everywhere, provided the pool directly filters the input and not some hidden state.

Oliva et al. (2017), by contrast, do present a recurrent model with multiple filterings of hidden state, but these project via fully-connected recurrent connections back into the model. The fact that they are computed in parallel implies that that model's impulse response would be quite different from the Concrete architecture's. The motivation of their approach is also quite different from ours. More recently, the neurologically-inspired SITH model (Spears et al., 2017) proposes a somewhat more restricted variant of this concept that limits projections of multiply-filtered network inputs to those yielding localised impulse responses similar to the ones in Figure 2(a), but this approach still employs full matrix multiplications to compute unit values.

We would be pleased for Concrete to be considered a “family member” to the above-mentioned works, but we consider the motivations, methods, applications, and analysis in this paper—particularly the focus on overcoming limitations of BPTT—to be unique.

Finally, there is recent work on online learning for recurrent models. Much of it adapts Real Time Recurrent Learning (Williams & Zipser, 1989; Pearlmutter, 1995), which applies forward (instead of reverse) accumulation of the chain rule. Learning can be purely online this way, but the method requires explicit computation and storage of the Jacobian of the recurrent model's transition function. This can be impracticable for large networks. Tallec & Ollivier (2017) explore various low-rank Jacobian approximations. Taking a different approach, Czarnecki et al. (2017) replace BPTT with a learnt parameterised model of the gradient. These approaches show promise, but none appear ready to replace truncated BPTT in practical applications.

4 Interpretive views on RNN with concrete

4.1 Signal processing perspective

(a) Normalised impulse responses of a chain of twelve Concrete pools; note the log scale. Each curve shows the response of a single pool; the y-value is the degree to which an input at time t in the past affects the contents of the pool, normalised by the maximum value for the pool. Sequential chaining of the pools causes the widening, temporally-tiled "windows".
(b) Contrasting with the above, normalised impulse responses of a collection of twelve non-chained exponential smoothing pools (that is, the input to the memory is fed to each pool in parallel). Pools in this arrangement lack temporal tiling.
Figure 2: Impulse responses for various filtering configurations. (Line styles and colours are only to make the pools more visually distinct.)

A natural first perspective views each memory pool as yielding a low-pass filtered version of the input signal. This framing allows us to employ classic signal processing methods to understand how the memory will behave in practice.

For example, Figure 2(a) shows the impulse response of the different memory pools when they are chained, versus arranged in parallel in Figure 2(b). These curves are obtained simply by feeding unit values (an impulse) into the memory at the first timestep and zero values thereafter. One observation is that the chained pools exhibit a behaviour similar to a delay line, where information from the first pool moves to the second and so on. Note that movement slows and temporal precision decreases as we move through the pools (the time axis is log-scaled).

The parallel pool case, where pools independently filter the input with different smoothing factors, is quite different. Not only is resolution lost much faster over longer time spans, but the "delay line" or "temporal tiling" effect is also absent. This is an important distinction, as an MLP analysing parallel filter pools is less likely to localise information temporally—not only due to the severe loss of resolution, but also because recent events can affect the contents of all pools.
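The impulse responses in Figure 2 can be reproduced qualitatively with a short script; the number of pools, the base, and the horizon below are assumptions chosen for illustration.

```python
import numpy as np

def impulse_responses(num_pools=12, steps=20000, beta=2.0, chained=True):
    """Feed a unit impulse at t = 0 and record every pool's response over time."""
    pools = np.zeros(num_pools)
    lams = beta ** -np.arange(1, num_pools + 1)
    responses = np.zeros((steps, num_pools))
    for t in range(steps):
        x = 1.0 if t == 0 else 0.0
        upstream = x
        for i in range(num_pools):
            pools[i] = (1 - lams[i]) * pools[i] + lams[i] * upstream
            upstream = pools[i] if chained else x   # parallel: all pools see the raw input
        responses[t] = pools
    return responses / responses.max(axis=0)        # normalise per pool, as in Figure 2

chained = impulse_responses(chained=True)    # tiled, delay-line-like windows
parallel = impulse_responses(chained=False)  # overlapping responses anchored at t = 0
```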

Another benefit of this perspective is insight into the kind of information the memory is likely to retain: in particular, we can consider which frequency bands are more or less likely to attenuate in the various pools. This allows us to reason about what kind of cues the model is likely to remember at various times (lower-frequency ones as time goes on).

4.2 Linear Memories

Consider the reinforcement learning setting, where the value of a state is given as the expected sum of future discounted rewards,

$$V^{\pi}(s_t) = \mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k}\right].$$

Observe that if the reward can be written as a linear function, $r_t = \phi(s_t)^{\top} w$, the value function decomposes as

$$V^{\pi}(s_t) = \mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^k \phi(s_{t+k})\right]^{\top} w = \psi^{\pi}(s_t)^{\top} w.$$

Here, $\psi^{\pi}(s_t)$ are the successor features at $s_t$ for the current policy. Notice that if instead of a discounted return we want the value as the exponentially weighted moving average of rewards, the same successor features can be used,

$$\tilde{V}^{\pi}(s_t) = (1-\gamma)\,\mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k}\right] = (1-\gamma)\,\psi^{\pi}(s_t)^{\top} w.$$

Successor features (Dayan, 1993) provide a structural form of transfer learning in reinforcement learning. The features $\psi^{\pi}$ learned for one policy $\pi$ can be used to evaluate that policy under any reward function linear in $\phi$, with some corresponding $w$. This form of transfer owes to linearity. The reward function is assumed linear, and the prediction of interest here is itself a linear function of reward. Thus the same reward function can be applied to many predictions about the future (successor features), and vice versa.

Now consider predictions about the past: memories.

Let $f$ be a linear function on the input, and let $M$ be some memory system that maps a sequence of values to a single value, $M(f(x_1), f(x_2), \dots, f(x_t))$. Different choices of $M$ would yield different time horizons on the memory, or different "windows" into the past.

If $M$ is linear in its inputs, we can write $M(f(x_1), \dots, f(x_t)) = f(M(x_1, \dots, x_t))$, leading to the same type of structural transfer learning seen with successor features. Specifically, any memory system $M$ can be applied to many different functions $f$, and vice versa. A simple example of how this could be useful is to imagine a convolution over many memory systems, applying the same set of functions $f$ that ask the same questions about the past summarised in each memory. If the memory system is non-linear, such transfer is no longer possible. Each combination of $(M, f)$ must be learned in isolation and, particularly important, relearned if either changes. That is, if the receptive field in time of $M$ changes, but what $f$ computes about the state does not, both must be relearned if $M$ is non-linear, but only $M$ must change if it is linear.
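A small sketch of this commuting property, using an exponential moving average as a stand-in linear memory and a random linear map as the "question" f; all names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
xs = rng.normal(size=(50, 6))            # a sequence of 6-dimensional inputs
W = rng.normal(size=(6, 3))              # a linear "question" f(x) = x @ W

def ewma(seq, lam=0.1):
    """A linear memory: exponential moving average over the sequence."""
    m = np.zeros(seq.shape[1])
    for x in seq:
        m = (1 - lam) * m + lam * x
    return m

# Because both the memory and f are linear, they commute: remembering the
# features equals featurising the memory.  Swapping either the horizon (lam)
# or the question (W) therefore needs no joint relearning.
lhs = ewma(xs @ W)          # memory of the features
rhs = ewma(xs) @ W          # features of the memory
assert np.allclose(lhs, rhs)
```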

Concrete is a linear memory system. Similarly, SITH is also a linear memory system as it combines several independent units using a linear operator. LSTMs by contrast are not linear memory systems.

5 Experiments

5.1 Sequence classification tasks

Figure 3: Visualisations of sequences of one-hot symbols for the original four-class sequence classification task (top) and our modification with low-frequency symbol subsequences in place of single-timestep markers (bottom). The networks must classify the sequences based on the identity and ordering of these features. Not shown: the original task with three significant markers instead of two.

Our evaluation of the performance of our low-pass filtering-based memory systems begins with the classic “temporal order” sequence classification tasks from Hochreiter & Schmidhuber (1997). In these tasks, inputs are sequences of one-hot symbols that start and end with unique boundary markers. In between, most elements are any of four random distractor symbols, but two (or three) of the elements can be either of two additional markers. (These elements appear in specific regions only—e.g. the first within elements 10-20.) The classification target of the sequence depends on which symbols appear in which positions; so, if there are two (or three) significant markers, there are four (or eight) possible classes. Sequences are between 100 and 110 symbols long, and the training signal is supplied only at the end of the sequence.

In addition to the original tasks with two and three symbols, we also investigate a variation of the two symbol task where the semantically-significant binary markers are replaced by specific lower-frequency subsequences of the distractor symbols. Each of the binary markers can be replaced with any of five possible subsequences, none of which are shared between the markers (hence, ten subsequences in total). Finally, the start and end markers are also extended over multiple timesteps. These sequences can be between 98 and 133 symbols long. See Figure 3 for visualisations of both types of task.
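For concreteness, a rough generator for the original two-marker task might look like the following; the marker windows, symbol vocabulary, and class encoding are assumptions based on the description above and on Hochreiter & Schmidhuber (1997), not the exact task code used in the experiments.

```python
import numpy as np

SYMBOLS = ['B', 'E', 'a', 'b', 'c', 'd', 'X', 'Y']   # start, end, distractors, markers

def sample_sequence(rng):
    length = int(rng.integers(100, 111))              # 100-110 symbols
    seq = [SYMBOLS[rng.integers(2, 6)] for _ in range(length)]   # random distractors
    seq[0], seq[-1] = 'B', 'E'                        # boundary markers
    m1 = 'X' if rng.random() < 0.5 else 'Y'
    m2 = 'X' if rng.random() < 0.5 else 'Y'
    seq[rng.integers(10, 21)] = m1                    # first significant marker
    seq[rng.integers(50, 61)] = m2                    # second significant marker
    label = ['XX', 'XY', 'YX', 'YY'].index(m1 + m2)   # one of four classes
    one_hot = np.eye(len(SYMBOLS))[[SYMBOLS.index(s) for s in seq]]
    return one_hot, label

rng = np.random.default_rng(0)
sequence, target = sample_sequence(rng)
```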

For all tasks, we compare three families of classification network with memory: an LSTM that projects into a hidden layer and then into a layer of unscaled classification logits; a Concrete-like network without the input embedding found in the actor-critic network in Figure 1a; and a network that feeds an input embedding into a parallel (non-chained) bank of low-pass filters. Detailed schematic diagrams of these network families appear in Appendix section D. Our training system presents the network with a sequence of temporally-contiguous, non-temporally-aligned minibatches, each compiled from samples drawn from multiple independent sequence samplers running in parallel. Because we vary the size of the minibatch and the amount of BPTT truncation, we take care to ensure that the networks see the same total number of input symbols in each training instance (one fixed total for the original tasks; another for the modified task).

Symbol Parameter Value set
Batch size
Learning rate
Adam optimizer
Hidden layer size
LSTM size
Pool size
Number of pools
Viewport size
Pool decay base
Table 1: Free parameters and the sets of values they are drawn from in the sequence classification tasks; notation matches the schematic diagrams in Appendix section D. Parameters above the dashed line are common to all network architectures; those below are sampled as appropriate for the network in use. The learning rate and epsilon parameters are sampled from a uniform distribution over the log of the values in the stated range. Adam optimizer refers to Kingma & Ba (2015).

For each task and network type, we perform 3,000 independent experimental runs with various network and training hyperparameters drawn randomly from the sets shown in Table 1. These sets express rather loose bounds on favourable parameters (with the parameter space for the low-pass networks being considerably larger than that for LSTM) and include configurations for all memories where the hypothetical maximum capacity is well in excess of what the tasks require. Put differently, all networks are given numerous architectural opportunities to succeed. Nevertheless, Figure 4 demonstrates that the LSTM-based network cannot solve the tasks under any configuration until BPTT truncation covers a long-enough time interval. (Low-pass filter-based memories do show a performance drop as truncation relaxes, but we attribute at least some of this to sensitivity to the volume of data and not its contents, as a similar effect occurs for minibatch size. See Appendix section C.1.)

Finally, we note better performance for the Concrete-style memory over the parallel filter arrangement, an effect that may owe to Concrete’s more temporally-localised response characteristics (Figure 2(a)).

(a) Two markers.
(b) Three markers.
(c) Two subsequence markers.
Figure 4: Performance on “sequence classification” tasks for networks with Concrete-style, parallel filter-based, and LSTM memory systems, broken down by BPTT truncation length (x axis). The y axis is the fraction of correct classifications at the end of training, as determined by averaging recent runs with an exponential smoother (smoothing factor: 0.98). For all tasks, the filter-based memory systems show better top performance than LSTM until the truncation length approaches the length of the input sequences. Each condition aggregates over a wide range of free hyperparameters (see Table 1), so some poor performance is not necessarily surprising; more important is that some instances of the LSTM-based network have the capacity to master the tasks but cannot do so under aggressive truncation. See Appendix section C.1 for further discussion of performance trends.

5.2 Reinforcement learning tasks

Reinforcement learning is a motivating application for our investigation into memory systems based on low-pass filtering. The training signal in RL settings can be sporadic and noisy, making it difficult to identify the information that may be relevant to future behaviour. Furthermore, some settings feature long time intervals between the moments when information is acquired and when it must be acted upon. We examine three tasks that exacerbate this latter challenge: in each, the network must retain relevant information for hundreds of timesteps.

For all three tasks, we compare an actor-critic network that uses a Concrete-style memory with a network that uses an LSTM. The networks are identical outside of their memory systems: both feature the same convolutional network and hidden layer to process the input; both have the same linear outputs for computing state values and logits over discrete actions. (Detailed schematics of the networks in use appear in Appendix section D.) We use the IMPALA distributed reinforcement learning architecture (Espeholt et al., 2018) with an adaptation of Population Based Training (PBT) (Jaderberg et al., 2017b). PBT is a recently proposed meta-algorithm that optimises hyperparameters and network weights through time among a population of agents. In our adaptation, agents copy the weights of better-performing agents whilst keeping their own hyperparameters fixed, effecting a parallelised exploration of weight space that exposes promising weight configurations to varied hyperparameter regimes. Further implementation details are given in Appendix E. Because IMPALA uses a multi-step value learning algorithm for the critic, rollout length affects RL efficiency in addition to BPTT truncation; to avoid conflating these effects, we fix all rollouts at 300 timesteps and artificially block gradients through the network’s hidden state at varying intervals.

All three tasks are “gridworld” environments implemented with the pycolab game engine (Stepleton, 2017). Although recent results in RL emphasise more elaborate settings (Beattie et al., 2016; Brockman et al., 2016), our own experience suggests that the intricacies of these games can give the agent opportunities to “cheat” in memory tasks by coding information in the environment itself. An agent may keep to one side of a 3-D corridor, for example, to remember which direction to turn at an upcoming junction. Our custom environments allow us to control precisely what the agent can do in the environment, thereby letting us force it to store critical information in its memory. (Implementations of these environments may be found at https://github.com/deepmind/pycolab/tree/master/pycolab/examples/research/lp-rnn.)

5.2.1 Cued Catch

    (Column headings, left to right: gradients blocked every 3 steps, every 20 steps, every 100 steps, every 300 steps.)
Figure 5: Performance on the “Cued Catch” task with 40 no-reward trials followed by 60 trials with reward (top row) and on the “T-maze” task with a 280-timestep detention in “limbo” (see Appendix sections B.1 and B.2). Orange/solid curves depict scores for a network with a Concrete memory, blue/dashed curves for a network with an LSTM. The rightmost column shows performance for ordinary 300-step network unrolls; preceding columns are for the same unrolling but with gradients through the hidden state blocked every 100, 20, and 3 steps respectively. The best possible score for Cued Catch is 60; for T-maze, just over 0.6. All curves summarise data from five separate runs of population-based training: for each run, the performances of the top three seeds in the population are averaged per-step; the resulting curve is smoothed, then all five smoothed average curves yield the confidence interval plots above. The smoothing is cosmetic; for the same plots without smoothing, see Appendix section C.2.

“Cued Catch” is a bandit-like game where the agent must interpret visual signals to “catch” one of two moving blocks. In an episode, the agent undertakes 100 catching “trials”, each lasting seven timesteps. The moving blocks take six steps to reach the agent, and at all times a visual cue indicating which block to catch shows at the bottom of the screen. There are four cues, and their associations with either of the two blocks change with each episode. Episodes start with a pre-trial “teaching phase” that shows correct cue interpretations by pairing cues one-by-one with either of two additional markers whose associations with the moving blocks are constant across episodes. Each pairing appears for ten timesteps. For more details and images of the task, see Appendix section B.1.

As described, Cued Catch allows networks using conventional memory architectures like LSTM to improve gradually in retaining cue semantics: reinforcement signals from trials just after the teaching phase can “bootstrap” the network into learning to memorise the associations for longer intervals. To disrupt this bootstrapping, we simply withhold any reward for the first 40 trials. The agent must retain the associations for at least 280 timesteps before it can receive any reinforcement for doing so. As shown in Figure 5, this intervention prevents our LSTM-based network from learning the task even when backpropagating over a generous 300 timesteps; in contrast, the Concrete-based network learns fairly robustly even with aggressively truncated BPTT.

5.2.2 T-maze

In this task, the agent must eventually navigate to one of two possible goal locations situated at opposite ends of a C-shaped corridor. To learn which goal to seek, the agent begins the episode in a box-shaped “teleporter room” where a relevant visual cue appears. After 50 timesteps, a portion of the room becomes a teleporter that transfers the agent upon traversal to the C-shaped corridor after a detention period in “limbo”—a tiny space where the agent is unable to move or change its observation in any way. This gives the agent no option but to store the goal information in memory for a long period. Appendix section B.2 presents more details and images of this task.

In our experiments, the detention period in limbo is set to 280 timesteps. As the agent need remember only one bit of information (compared with the several bits required for Cued Catch), random initialisation of large conventional memories may be more likely to start out with a useful configuration. Nevertheless, as Figure 5 shows, only the Concrete-based network masters this task consistently, again even with aggressively truncated BPTT.

5.2.3 Sequence Recall

    (Column headings, left to right: gradients blocked every 20 steps, every 300 steps.)
Figure 6: Performance on the “sequence recall” task (see Appendix section B.3); orange/solid curves depict scores for a network with a Concrete memory, blue/dashed curves show scores for a network with an LSTM, taupe/irregularly-dashed curves show scores for a parallel filter-based network like in Figure 1b. Right, for ordinary 300-step network unrolls; left, the same amount of unrolling with gradients through the hidden state blocked every twenty steps. The best possible score is just over 2.0. The same plotting procedure used in Figure 5 was applied here; for plots without smoothing, see Appendix section C.3.

In the “Sequence Recall” task, the agent first sits immobilised as it observes a sequence of four “light flashes”, each from one of four spatially-separated disk-shaped “lights”. Flashes last for 60 timesteps and are separated by 30-timestep gaps. When the sequence ends, the agent is free to walk amongst the lights. The agent receives a 1.0 reward if the n-th disk it traverses is the same as the n-th light in the sequence. Appendix section B.3 presents more details and images of this task.

The sequence recall task requires at least eight bits of information storage. Although there are 270 timesteps between the first flash and the moment when the agent can first move, the agent can apply knowledge of the last flash relatively soon after it occurs, which may support some “bootstrapping” of an effective memory strategy in traditional architectures. Whether it does or not, Figure 6 shows better performance from the Concrete-based network than from the LSTM baseline. We also applied the “parallel filter” network (Figure 1b, Appendix section D) to the task; its diminished performance suggests that Concrete’s temporal tiling is useful for tasks where temporal ordering is significant. (We evaluated fewer BPTT truncation lengths for this task owing to its longer running time and the extra agent.)

6 Discussion

In this work we introduced the LP-RNN, or “RNN with concrete”, a simple recurrent memory system that tracks history through a chain of low-pass filter pools. Unlike popular memory architectures like LSTM, this method is less sensitive to temporal chunking imposed by truncated backpropagation through time, a popular gradient method for training recurrent models.

Viewing RNNs with concrete through a signal processing lens revealed the tiled temporal representation benefits of chaining filtering pools over connecting them in parallel. From a statistical perspective, we discovered parallels to successor representations and found that the linearity of RNNs with concrete—in contrast to the non-linear LSTM—promotes stable representations of the past.

We tested our model on supervised memory tasks and reinforcement learning tasks. The supervised tasks demonstrate that, unlike LSTMs, RNNs with concrete do not require long back-propagation windows to remember distant history. The reinforcement learning tasks evaluate our model in a low signal-to-noise setting where long-term credit assignment is crucial. We showed that RNNs with concrete can retain past information so it can be associated with present context even hundreds of timesteps apart. This is highly valuable in reinforcement learning and often makes the difference between learning well and not learning at all.

References

Appendix A Visualisation of data accumulating within a Concrete memory

The following images depict the evolution over time of the contents of a six-pool Concrete memory with random orthogonal inter-pool projections. The pool closest to the input is shown on the left-hand side, and following pools in the chain extend rightward.

For each image, the first fourteen timesteps (rows) show the contents of all pools as a one-hot encoding of a 14-character string is fed letter-by-letter into the memory. Subsequent rows show how the pools change over time with no further input, or, equivalently, with zero values across all of the input units. Information can be seen to migrate from a detailed, rapidly changing representation in the “upstream” pools to more slowly changing representations in the “downstream” pools as time goes on. Note how the temporal profile of these responses matches the impulse response curves in Figure 2(a).

The Concrete implementation used to make these images is an augmented one: information transferred between pools is also transformed by a random orthonormal projection and then passed through a nonlinearity. We allude to extensions like these in Section 2. This augmentation has no particular bearing on the visualisation itself.

Below, the contents of the memory for the string machinelearning:

Next, the same visualisation for the string itsmycatmittens:

(Visualisations continue on the following page.)

Next, a visualisation of the absolute-valued difference between the two, where visible differences between both strings are apparent in all pools:

For distinguishing very similar inputs, it may not suffice to attend to a single pool. The image below visualises the absolute-valued difference between the Concrete memory storing the string machinelearning and the same memory storing an anagram of that string, migrainechannel. White horizontal bands indicate that there are stretches of time in individual downstream pools where the representations of the two strings are virtually indistinguishable. These periods appear not to occur simultaneously across all pools, however, in which case attention to multiple pools could resolve ambiguities between the representations.

Appendix B Descriptions of the RL tasks

Python implementations of the three RL tasks described below are available at https://github.com/deepmind/pycolab/tree/master/pycolab/examples.

B.1 Cued Catch

In this bandit-like game, the mostly-immobile player (single white block) must “catch” blocks that repeatedly approach it from the right side of the screen (yellow and cyan blocks in the right-hand image). The player can only move up or down between two positions, aligning itself to catch either the yellow or the cyan block. The green symbol at the bottom of the screen indicates which block to catch; catching this block earns a reward of 1.0. No other reward is awarded at any point. A fixed number of these catching trials repeat before the episode terminates.

There are four cue symbols, two associated with the cyan block and two with the yellow block. These associations are randomly sampled at the beginning of each episode. The player is told these associations during a “teaching phase” at the episode’s beginning (left-hand image), where the symbols are presented just beneath either of two wide bars (large green blocks), each of which has a fixed association with either the cyan or the yellow block. The player must learn these associations so that it can successfully interpret the teaching. The teaching phase is also distinguished by two tall rectangles to the left and right of the player and by the absence of the yellow and cyan blocks.

The memory demands of the task may be increased by temporarily disabling the 1.0 reward for correct “catches” for a specified number of trials after the “teaching phase”.

B.2 T-maze

This version of the classic T-maze memory task places additional demands on the player’s memory by preventing the player from acting on the maze cue for a configurable number of environment steps.

The episode begins with the player (central cyan block) in a “teleporter room” (left image). A cue (tall green rectangle) indicates whether the player should ultimately seek the goal at the left or right end of the maze. Eventually, a teleporter appears (blue rectangle above player). Upon traversal, the player is first teleported to a “limbo” (centre image), where it is completely immobilised. The player is prevented in this way from using its position in the environment as a means of recording the cue. After a configurable delay, the player is transported to the horizontal centre of the T-maze itself, an egocentric scrolling corridor shaped like this: ] . The left and right goals are situated at the lower ends of the two vertical parts of the corridor. Traversing the goal that matched the initial cue earns a reward of 1.0 for the player; traversing the wrong goal earns -1.0; either terminates the episode.

An additional penalty of -0.001 is also levied at each timestep (even those where the player is in “limbo”) so that waiting for the episode to time out and terminate is more costly than selecting the wrong goal.

B.3 Sequence Recall

This task partly resembles the operation of an electronic memory game toy. At the beginning of the episode (left image), the player (centre cyan block) sits immobilised within a small “pen” surrounded by four disk-shaped “lights”. A random sequence of four lights (each drawn with replacement) is shown (in the left-hand image, the blue, rightmost light is “on”). After the light sequence is presented, the player is free to move about the environment (right-hand image). The player receives a 1.0 reward if the n-th light it traverses is the same as the n-th light in the original sequence, so it must traverse the lights in the same order in which they were presented in order to achieve the highest score. If the same light appears in two adjacent positions in the sequence, the player must enter, leave, and re-enter the light disk in order to receive credit for both positions.

Similar to the T-maze task, an additional penalty of -0.005 is levied at each timestep so that waiting for the episode to time out and terminate is more costly than traversing four incorrect lights.

Appendix C Supplemental plots

C.1 Sequence classification performance by minibatch size

(a) Two markers.
(b) Three markers.
(c) Two subsequence markers.
Figure 7: As a supplement to the discussion surrounding Figure 4, sequence classification performance sliced by minibatch size. The low-pass RNN-based methods achieve better results with smaller minibatches. This effect, along with the poorer performance for longer truncation intervals observed in Figure 4, may be an instance of the infrequent update problem discussed in the introduction: for the largest batch size (128) and truncation interval (256), the entire training for the original sequence classification tasks will apply only 1,220 weight updates to the network. For the smallest batch size (4) and truncation interval (2), there will be five million updates.

C.2 Plots from Figure 5 without smoothing

    (Column headings, left to right: gradients blocked every 3 steps, every 20 steps, every 100 steps, every 300 steps.)

C.3 Plots from Figure 6 without smoothing

    (Column headings, left to right: gradients blocked every 20 steps, every 300 steps.)

Appendix D Network schematics

Figure 8: Legend for the symbols used in the neural network schematics in this Appendix.

D.1 LP-RNN-based networks used in the experiments

Figure 9: A schematic of the LP-RNN (Concrete) network family applied to the sequence classification tasks in Section 5.1. The size of the network output is determined by the number of classes in the classification task; the other parameters shown (summariser size, viewport size, pool size, and filter base) as well as the number of pools are varied as described in the text. Note that gradients through all but the first pool are blocked for computational efficiency, as the low-pass architecture will naturally attenuate gradients through most of the pools on its own. At the arrow marked *, the one-hot representation of the input symbol projects into the first pool via multiplication by a weight matrix (note: no bias), which is initialised as a “padded identity” matrix; however, all of its values may be modified by the optimiser.
Figure 10: A schematic of the LP-RNN (Concrete) network applied to the reinforcement learning tasks in Section 5.2. The input supplied to the convolutional network is a rows × columns × features binary array. The network is an actor-critic network; the five-dimensional output comprises unscaled logits for the five agent actions (movement in the four cardinal directions, plus remain in place), while the one-dimensional output is a state value estimate. The only adjustable parameter for this architecture is the number of pools, which is selected for the task as described in the text. Note that gradients through all but the first pool are blocked for computational efficiency, as the low-pass architecture will naturally attenuate gradients through most of the pools on its own. At the arrow marked *, the 128-dimensional embedding of the processed input projects into the first pool via multiplication by a weight matrix (note: no bias), which is initialised as a “padded identity” matrix; however, all of its values may be modified by the optimiser.
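One plausible reading of the “padded identity” initialisation mentioned in the two captions above is sketched below; the exact padding convention used by the authors is not recoverable from the text, so this is an assumption for illustration only.

```python
import numpy as np

def padded_identity(in_dim, out_dim):
    """An identity matrix padded with zeros to the projection's full shape,
    so the embedded input initially passes into the first pool unchanged
    (up to truncation when the dimensions differ)."""
    w = np.zeros((in_dim, out_dim))
    k = min(in_dim, out_dim)
    w[:k, :k] = np.eye(k)
    return w

W = padded_identity(8, 32)   # e.g. an 8-symbol one-hot projecting into a 32-unit pool
```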

D.2 LSTM-based networks used in the experiments

(a) A schematic of the LSTM-based network family applied to the sequence classification tasks in Section 5.1. The size of the network output is determined by the number of classes in the classification task; the other parameters shown (hidden layer size and LSTM size) are varied as described in the text.
(b) A schematic of the LSTM-based network applied to the reinforcement learning tasks in Section 5.2. This network has the same inputs and outputs as the Concrete network in Figure 10.

D.3 Parallel filter-based networks used in the experiments

Figure 11: A schematic of the parallel-pool network family applied to the sequence classification tasks in Section 5.1. This family has the same structural parameters as the Concrete network family in Figure 9 and features an identical pattern of gradient blocking. The layer at bottom is a learned embedding of the input; its projection to the first pool (the arrow marked *) is a learnable bias-free projection initialised to the identity matrix.
Figure 12: A schematic of the parallel-pool network family applied to the reinforcement learning tasks in Section 5.2. This network has the same inputs and outputs as the Concrete network in Figure 10 and features an identical pattern of gradient blocking. The layer just upstream of the pools is a 198-dimensional embedding of the processed input; its projection to the first pool (the arrow marked *) is a learnable bias-free projection initialised to the identity matrix.

Appendix E Population Based Training (PBT)

We use an adaptation of PBT to augment the optimisation of network weights for our model (and corresponding baselines) for all RL tasks.

For each run of the experiment, nine agents are launched in parallel. Agents train for 30,000 episodes each before the evaluation phase. Here, each agent randomly samples another and compares its performance across the last 3,000 episodes; this is done by performing a t-test over both sets of scores. If the sampled agent is found to be statistically better (p-value below a preset threshold), its weights are copied over.

Unlike conventional PBT (Jaderberg et al., 2017b), where agents that copy weights from better-performing agents also copy (and then perturb) various hyperparameters (e.g. gradient descent step size), our adaptation leaves agent hyperparameters unchanged. In a sense, this approach realises an implicit adaptive hyperparameter schedule by allowing well-performing sets of weights to “jump” amongst different hyperparameter configurations.
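A sketch of this evaluation step, with SciPy's independent-samples t-test standing in for the test used by the authors and a placeholder p-value threshold (the original threshold is not recoverable from the text):

```python
from scipy import stats

def maybe_copy_weights(my_scores, other_scores, my_weights, other_weights,
                       p_threshold=0.05):
    """Copy another agent's weights if its recent episode scores are
    statistically better, while keeping this agent's own hyperparameters
    unchanged (the adaptation of PBT described above)."""
    t_stat, p_value = stats.ttest_ind(other_scores, my_scores)
    if t_stat > 0 and p_value < p_threshold:
        return [w.copy() for w in other_weights]   # adopt the better weights
    return my_weights
```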

Appendix F When LSTM starts working on Cued Catch and T-maze

Even without blocking gradients through the hidden state at any point within the 300-step network unrolls described in Section 5.2, our LSTM networks showed poor performance on the Cued Catch and T-maze tasks. To demonstrate that an LSTM has the representational capacity to solve these tasks, we swept over more relaxed settings of their respective “difficulty parameters” to identify a point at which the LSTM network (with unblocked 300-step unrolls) can succeed.

For Cued Catch, the difficulty parameter is the number of reward-free trials before the player starts receiving a reward for “catching” the correct block. The original experiment used 40 reward-free trials; in our sweeping, we first observed some non-random performance at 10 reward-free trials and good performance at 5:

    (Plot: Cued Catch learning curves, 300-step unrolls.)

Note that the total number of trials remains the same in all cases, which is why these curves appear vertically offset: fewer reward-free trials means more trials in which even a random agent will receive some reward.

For T-maze, the difficulty parameter is the number of frames the player spends immobilised in “limbo”. The original experiment immobilised the agent for 280 frames; we observed gradual performance improvement as this delay was reduced. By 140 frames, we observed good performance:

    (Plot: T-maze learning curves, 300-step unrolls.)

Appendix G More on LP-RNN as a linear operator

In all of the following, we consider the input to be a sequence of scalar floating-point numbers; for vector inputs, we can imagine parallelising the analysis for each vector dimension. We imagine a memory with no input embedding and $N$ pools of size 1, fed an input sequence of length $T$. Let $\mathbf{m}_t$ (a row vector of length $N$) be the contents of the pools at timestep $t$. Similarly, if we repackage each scalar input value $x_t$ into an $N$-vector $\mathbf{x}_t$ whose entries are all 0 except for $x_t$ in the first entry, that is $\mathbf{x}_t = (x_t, 0, \dots, 0)$, then the memory pool contents at time $t$ can be expressed as the following recurrence:

$$\mathbf{m}_t = \left(\mathbf{m}_{t-1} + \mathbf{x}_t\right) A,$$

where $A$ is a matrix describing how values diffuse through the pools during a single timestep. This matrix is upper-triangular: its entry at indices $(i, j)$ is nonzero if $i \le j$ and 0 otherwise, with the nonzero values given by products of the per-pool smoothing factors, i.e. powers of some fixed base $\beta$. (Recall that the LP-RNN description in the main text allows the filter coefficients to be different for each pool; here we limit our description to common parameterisations that use powers of a common base.)

Given the linearity of the above recurrence, we can see that

$$\mathbf{m}_T = \sum_{t=1}^{T} \mathbf{x}_t\, A^{T-t+1},$$

and because each $\mathbf{x}_t$ is 0 except for the first entry, each term is just the first row of a power of $A$ scaled by the corresponding input value. Therefore, if we place the entire input sequence in reverse into some vector $\mathbf{v}$, we can express $\mathbf{m}_T$ as a single linear transformation of this $\mathbf{v}$:

$$\mathbf{m}_T = \mathbf{v}^{\top} B;$$

in other words, $B$ collects in its rows the first rows of the powers of $A$.
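As a sanity check of this claim, the following sketch recovers the one-step linear map numerically from the pool update reconstructed in Section 2 (itself an assumption) and verifies that the final pool contents equal a single linear transformation of the reversed input sequence; it sidesteps the exact entries of the diffusion matrix, which were not recoverable from the text.

```python
import numpy as np

N, T, beta = 6, 40, 2.0
lams = beta ** -np.arange(1, N + 1)

def step(m, x):
    """One (reconstructed) LP-RNN step for scalar input and pools of size 1."""
    m = m.copy()
    upstream = x
    for i in range(N):
        m[i] = (1 - lams[i]) * m[i] + lams[i] * upstream
        upstream = m[i]
    return m

# The step is linear, so it can be written m_t = M @ m_{t-1} + b * x_t; recover
# M and b by probing the step with basis vectors and a unit input.
M = np.stack([step(e, 0.0) for e in np.eye(N)], axis=1)
b = step(np.zeros(N), 1.0)

rng = np.random.default_rng(0)
xs = rng.normal(size=T)

# Run the recurrence directly...
m = np.zeros(N)
for x in xs:
    m = step(m, x)

# ...and reproduce it as one linear map applied to the reversed input sequence,
# with row k of the map given by M^k b (the analogue of the powers of the
# diffusion matrix discussed above).
B = np.stack([np.linalg.matrix_power(M, k) @ b for k in range(T)])
assert np.allclose(m, B.T @ xs[::-1])
```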