Unsupervised Predictive Memory in a Goal-Directed Agent

by Greg Wayne, et al.

Animals execute goal-directed behaviours despite the limited range and scope of their sensors. To cope, they explore environments and store memories maintaining estimates of important information that is not presently available. Recently, progress has been made with artificial intelligence (AI) agents that learn to perform tasks from sensory input, even at a human level, by merging reinforcement learning (RL) algorithms with deep neural networks, and the excitement surrounding these results has led to the pursuit of related ideas as explanations of non-human animal learning. However, we demonstrate that contemporary RL algorithms struggle to solve simple tasks when enough information is concealed from the sensors of the agent, a property called "partial observability". An obvious requirement for handling partially observed tasks is access to extensive memory, but we show memory is not enough; it is critical that the right information be stored in the right format. We develop a model, the Memory, RL, and Inference Network (MERLIN), in which memory formation is guided by a process of predictive modeling. MERLIN facilitates the solution of tasks in 3D virtual reality environments for which partial observability is severe and memories must be maintained over long durations. Our model demonstrates a single learning agent architecture that can solve canonical behavioural tasks in psychology and neurobiology without strong simplifying assumptions about the dimensionality of sensory input or the duration of experiences.


Artificial intelligence research is undergoing a renaissance as RL techniques[sutton1998reinforcement], which address the problem of optimising sequential decisions, have been combined with deep neural networks into artificial agents that can make optimal decisions by processing complex sensory data[mnih2015human]. In tandem, new deep network structures have been developed that encode important prior knowledge for learning problems. One important innovation has been the development of neural networks with external memory systems, allowing computations to be learned that synthesise information from a large number of historical events[weston2014memory, bahdanau2014neural, graves2016hybrid].

Within RL agents, neural networks with external memory systems have been optimised “end-to-end” to maximise the amount of reward acquired during interaction in the task environment. That is, the systems learn how to select relevant information from input (sensory) data, store it in memory, and read out relevant memory items purely from trial-and-error action choices that led to higher than expected reward on tasks. While this approach to artificial memory has led to successes[todd2009learning, oh2016control, graves2016hybrid, duan2017one], we show that it fails to solve simple tasks drawn from behavioural research in psychology and neuroscience, especially ones involving long delays between relevant stimuli and later decisions: these include, but are not restricted to, problems of navigation back to previously visited goals[tolman1948cognitive, tse2007schemas, o1978hippocampus], rapid reward valuation[corbit2000role], where an agent must understand the value of different objects after few exposures, and latent learning, where an agent acquires unexpressed knowledge of the environment before being probed with a specific task[blodgett1929effect, tolman1930introduction].

We propose MERLIN, an integrated AI agent architecture that acts in partially observed virtual reality environments and stores information in memory based on different principles from existing end-to-end AI systems: it learns to process high-dimensional sensory streams, compress and store them, and recall events with less dependence on task reward. We bring together ingredients from external memory systems, reinforcement learning, and state estimation (inference) models and combine them into a unified system using inspiration from three ideas originating in psychology and neuroscience: predictive sensory coding[rao1999predictive, bastos2012canonical, hindy2016linking], the hippocampal representation theory of Gluck and Myers[gluck1993hippocampal, moustafa2013trace], and the temporal context model and successor representation[howard2002distributed, dayan1993improving, stachenfeld2017hippocampus]. To test MERLIN, we expose it to a set of canonical tasks from psychology and neuroscience, showing that it is able to find solutions to problems that pose severe challenges to existing AI agents. MERLIN points a way beyond the limitations of end-to-end RL toward future studies of memory in computational agents.

RL formalises the problem of finding a policy: a mapping from sensory observations to actions. A leading approach to RL begins by considering policies that are stochastic, so that the policy describes a distribution over actions. Memory-free RL policies that directly map instantaneous sensory data to actions fail in partially observed environments where the sensory data are incomplete. Therefore, in this work we restrict our attention to memory-dependent policies, where the action distribution depends on the entire sequence of past observations.

In Fig. 1a, we see a standard memory-dependent RL policy architecture, RL-LSTM, which is a well-tuned variant of the “Advantage Actor Critic” architecture (A3C)[mnih2016asynchronous] with a deeper convolutional network visual encoder. At each time t, a sensory encoder network takes in an observation o_t and produces an embedding vector e_t. This is passed to a recurrent neural network (an LSTM)[hochreiter1997long], which has a memory state h_t that is produced as a function of the input and the previous state: h_t = rec(e_t, h_{t-1}). Finally, a probability distribution indicating the probability of an action is produced as a function of the memory state: π(a_t | h_t). The encoder, recurrent network, and action distribution are all understood to be neural networks with optimisable parameters θ. An agent with relatively unstructured recurrence like RL-LSTM can perform well in partially observed environments but can fail to train when the amount of information that must be recalled is sufficiently large.
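For concreteness, one RL-LSTM time step (encoder embedding, recurrent update, softmax action distribution) can be sketched as below. This is a minimal NumPy illustration with made-up layer sizes and no bias terms, not the paper's implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(e, h, c, W):
    """One recurrent update: (h_t, c_t) as a function of e_t and (h_{t-1}, c_{t-1})."""
    x = np.concatenate([e, h])
    i = sigmoid(W["i"] @ x)          # input gate
    f = sigmoid(W["f"] @ x)          # forget gate
    o = sigmoid(W["o"] @ x)          # output gate
    g = np.tanh(W["g"] @ x)          # candidate cell contents
    c = f * c + i * g
    return o * np.tanh(c), c

def policy(h, W_pi):
    """Action distribution pi(a_t | h_t) as a softmax over logits."""
    logits = W_pi @ h
    p = np.exp(logits - logits.max())
    return p / p.sum()

E, H, A = 16, 8, 4                   # illustrative embedding, hidden, action sizes
rng = np.random.default_rng(0)
W = {k: rng.normal(0.0, 0.1, (H, E + H)) for k in "ifog"}
W_pi = rng.normal(0.0, 0.1, (A, H))

h, c = np.zeros(H), np.zeros(H)
e = rng.normal(size=E)               # stands in for the embedding e_t of o_t
h, c = lstm_step(e, h, c, W)
p = policy(h, W_pi)                  # probabilities over the A actions
```

In the full agent these weights are trained by the policy gradient; here they are random, so the action distribution is near-uniform.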

In Fig. 1b, we see an RL agent (RL-MEM) augmented with an external memory system that stores information in a matrix M. In addition to the state of the recurrent network, the external memory stores a potentially larger amount of information that can be read from using a read key k_t, which is a linear function of h_t that is compared against the contents of memory: r_t = read(M, k_t). The recurrent network is updated based on this read information: h_t = rec(e_t, r_{t-1}, h_{t-1}). The memory is written to at each time step by inserting a vector w_t, produced as a linear function of h_t, into an empty row of memory: M ← write(M, w_t). The functions “read” and “write” additionally have parameters that are included in θ.
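A minimal single-head version of this read/write scheme can be sketched as follows; the cosine-similarity comparison and softmax weighting are illustrative choices rather than the exact functions used in the agent.

```python
import numpy as np

def memory_read(M, k, eps=1e-8):
    """Content-based read: compare key k against each row of M by cosine
    similarity, then return a softmax-weighted average of the rows."""
    sims = (M @ k) / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + eps)
    w = np.exp(sims - sims.max())
    w /= w.sum()
    return w @ M                     # read vector r_t

def memory_write(M, w_vec, t):
    """Insert write vector w_t into (empty) row t of the memory."""
    M = M.copy()
    M[t] = w_vec
    return M

rng = np.random.default_rng(1)
M = np.zeros((6, 5))                 # memory matrix, rows initially empty
for t in range(3):                   # three previously written time steps
    M = memory_write(M, rng.normal(size=5), t)

k = M[2] + 0.01 * rng.normal(size=5) # key close to the contents of row 2
r = memory_read(M, k)                # blend of rows, weighted toward row 2
M2 = memory_write(M, np.ones(5), 3)  # write at the next time step
```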

An agent with a large external memory store can perform better on a range of tasks, but training perceptual representations for storage in memory by end-to-end RL can fail if a task demands high-fidelity perceptual memory. In RL-LSTM/MEM, the entire system, including the representations formed by the encoder, the computations performed by the recurrent network, the rules for reading information from memory and writing to it (for RL-MEM), and the action distribution are optimised to make trial and error actions more or less likely based on the amount of reward received. RL thus learns to encode and retrieve information based on trial and error decisions and resultant rewards. This is indirect and inefficient: sensory data can instead be encoded and stored without trial and error in a temporally local manner.

Denoting the reward at time t as r_t, the aim is to maximise the sum of rewards that the policy receives up to a final time T, known as the return R_t = r_t + r_{t+1} + … + r_T. This is achieved by a “policy gradient” update rule[sutton2000policy] that increases the log probability of the selected actions based on the return (see Methods Sec. 4.4):

Δθ ∝ Σ_t ∇_θ log π(a_t | h_t; θ) R_t.    (1)

In practice, Eq. 1 is implemented by the truncated backpropagation-through-time algorithm[werbos1990backpropagation, sutskever2013training] over a fixed window τ defining the number of time steps over which the gradient is calculated[mnih2016asynchronous]. Intuitively, τ defines the duration over which the model can assign credit or blame to network dynamics or information storage events leading to success or failure. When τ is smaller than the typical time scale over which information needs to be stored and retrieved, RL models can struggle to learn at all. Thus, learning the representations to put in memory by end-to-end policy gradient RL only works if the minimal time delay between encoding events and actions is not too long – in particular, not larger than the window τ.
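The return computation and the REINFORCE-style surrogate appearing in Eq. 1 can be illustrated numerically; the reward sequence and action log-probabilities below are a toy example.

```python
import numpy as np

def returns(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T, via a reversed cumulative sum."""
    return np.cumsum(rewards[::-1])[::-1]

def pg_surrogate(logp, R):
    """Scalar whose gradient w.r.t. theta is the policy gradient update:
    sum_t log pi(a_t | h_t; theta) * R_t (the REINFORCE estimator)."""
    return float(np.sum(logp * R))

rewards = np.array([0.0, 0.0, 1.0, 0.0, 1.0])
R = returns(rewards)                    # [2., 2., 2., 1., 1.]
logp = np.log(np.full(5, 0.25))         # log-probs of the actions taken
surrogate = pg_surrogate(logp, R)       # increase logp where R_t is large
```

Truncated backpropagation-through-time would compute the gradient of this surrogate only within a window of τ steps, which is exactly why long encoding-to-action delays are problematic.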

Figure 1: Agent Models. a. At time t, RL-LSTM receives the environment sensory input composed of the image I_t, self-motion velocity v_t, the previous reward r_{t-1}, and the text input T_t. These are sent through an encoder, consisting of several modality-specific encoders, to make an embedding vector e_t. This is provided as input to a recurrent LSTM network with state h_t, which outputs the action probabilities through a neural network with an intermediate hidden layer. An action a_t is sampled and acts on the environment. The whole system is optimised to maximise the sum of future rewards via the policy loss. b. RL-MEM is similar except that the recurrent network reads from a memory matrix M using several read heads, each of which produces a key vector that is compared by vector similarity (normalised dot product) to the rows of the memory. The most similar rows are averaged and returned as read vectors that are all concatenated. This read is provided as input to the recurrent network at the next time step and influences the action probabilities at the current time. The recurrent network has an additional output, a write vector w_t, that is inserted into row t of the memory at time t. The RL-LSTM and RL-MEM architectures also learn a return prediction as a network connected to the policy, as in standard A3C[mnih2016asynchronous]. These are suppressed here for simplicity but discussed in Methods Sec. 5.1. c. Sensory input in MERLIN flows first through the MBP, whose recurrent network has produced a prior distribution over the state variable at the previous time step. The mean and log standard deviation of this Gaussian distribution are concatenated with the embedding e_t and passed through a network to form an intermediate variable, which is added to the prior to make a Gaussian posterior distribution, from which the state variable z_t is sampled. This is inserted into row t of the memory matrix and passed to the recurrent network of the memory-based predictor (MBP). This recurrent network has several read heads, each with a key k_t, which is used to find matching items in memory. The state variable is passed as input to the read-only policy and is passed through decoders that produce reconstructed input data (with carets) and the Gluck and Myers[gluck1993hippocampal] return prediction. The MBP is trained based on the VLB objective[kingma2013auto, rezende2014stochastic], consisting of a reconstruction loss and a KL divergence between the prior and the posterior. To emphasise the independence of the policy from the MBP, we have blocked the gradient from the policy loss into the MBP.

MERLIN (Fig. 1c) optimises its representations and learns to store information in memory based on a different principle, that of unsupervised prediction[rao1999predictive, bastos2012canonical, hindy2016linking]. MERLIN has two basic components: a memory-based predictor (MBP) and a policy. The MBP is responsible for compressing observations into low-dimensional state representations z_t, which we call state variables, and storing them in memory. The state variables in memory in turn are used by the MBP to make predictions guided by past observations. This is the key thesis driving our development: an agent’s perceptual system should produce compressed representations of the environment; predictive modeling is a good way to build those representations; and the agent’s memory should then store them directly. The policy can primarily be the downstream recipient of those state variables and memory contents.

Machine learning and neuroscience have both engaged with the idea of unsupervised and predictive modeling over several decades[kok2012less]. Recent discussions have proposed such predictive modeling is intertwined with hippocampal memory[hindy2016linking, finkelstein20163], allowing prediction of events using previously stored observations, for example, of previously visited landmarks during navigation or the reward value of previously consumed food. MERLIN is a particular and pragmatic instantiation of this idea that functions to solve challenging partially observable tasks grounded in raw sensory data.

We combine ideas from state estimation and inference[kalman1960new] with the convenient modern framework for unsupervised modeling given by variational autoencoders[kingma2013auto, rezende2014stochastic, chung2015recurrent, gemici2017generative] as the basis of the MBP (Fig. 1c). Information from multiple modalities (image I_t, egocentric velocity v_t, previous reward r_{t-1} and action a_{t-1}, and possibly a text instruction T_t) constitutes the MBP’s observation input o_t and is encoded to e_t. A probability distribution, known as the prior, predicts the next state variable conditioned on a history, maintained in memory, of the previous state variables and actions: p(z_t | z_{t-1}, a_{t-1}, …, z_1, a_1). Another probability distribution, the posterior, corrects this prior based on the new observations to form a better estimate of the state variable: q(z_t | o_t, z_{t-1}, a_{t-1}, …, z_1, a_1). The posterior samples from a Gaussian distribution a realisation of the state variable z_t, and this selected state variable is provided to the policy and stored in memory. In MERLIN, the policy, which has read-only access to the memory, is the only part of the system that is trained conventionally according to Eq. 1.
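The prior-then-posterior-correction step can be sketched with toy numbers: the prior's Gaussian parameters are adjusted by an additive correction (computed, in the real model, by a network from the new observation embedding), and the state variable is drawn by reparameterised sampling. All values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)

def gaussian_sample(mu, log_sigma, rng):
    """Reparameterised sample: z = mu + sigma * eps, with eps ~ N(0, I)."""
    return mu + np.exp(log_sigma) * rng.normal(size=mu.shape)

Z = 4
# Prior over z_t, predicted from the history (here: fixed toy values).
mu_p, logs_p = np.zeros(Z), np.zeros(Z)
# Posterior correction added to the prior after seeing the observation e_t
# (in the model this correction is the output of a network; toy values here).
d_mu = np.array([0.5, -0.2, 0.0, 0.1])
d_logs = -0.5 * np.ones(Z)
mu_q, logs_q = mu_p + d_mu, logs_p + d_logs

z = gaussian_sample(mu_q, logs_q, rng)   # realisation of the state variable z_t
```

The sampled z would then be stored in memory and passed to the policy.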

The MBP is optimised to function as a “world model”[neisser1967cognitive, barlow1987cerebral]: in particular, to produce predictions that are consistent with the probabilities of observed sensory sequences from the environment: p(o_1, o_2, …, o_t). This objective can be intractable, so the MBP is trained instead to optimise the variational lower bound (VLB) loss, which acts as a tractable surrogate. One term of the VLB is reconstruction of observed input data. To implement this term, several decoder networks take z_t as input, and each one transforms it back into the space of a sensory modality (one reconstructs the image I_t; the others reconstruct the self-motion v_t, the text input T_t, the previous action a_{t-1}, and the previous reward r_{t-1}). The difference between the decoder outputs and the ground truth data is the loss term. The VLB also has a term that penalises the difference (KL divergence) between the prior and posterior probability distributions, which ensures that the predictive prior is consistent with the posterior produced after observing new stimuli.
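The two VLB terms can be written down concretely. The closed-form KL divergence between diagonal Gaussians is standard; the squared-error reconstruction term below stands in for the modality-specific decoder likelihoods used in the actual model.

```python
import numpy as np

def kl_diag_gauss(mu_q, logs_q, mu_p, logs_p):
    """Analytic KL(q || p) between diagonal Gaussians parameterised by
    means and log standard deviations."""
    var_q, var_p = np.exp(2 * logs_q), np.exp(2 * logs_p)
    return float(np.sum(
        logs_p - logs_q + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p) - 0.5
    ))

def recon_loss(x, x_hat):
    """Squared-error reconstruction term (one simple choice of decoder likelihood)."""
    return float(np.sum((x - x_hat) ** 2))

mu = np.zeros(3)
# Negative VLB = reconstruction error + KL(posterior || prior).
neg_vlb = recon_loss(np.ones(3), np.zeros(3)) + kl_diag_gauss(mu + 1.0, mu, mu, mu)
```

Minimising this quantity pushes the decoders to reproduce the observations while keeping the posterior close to the predictive prior.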

Although it is desirable for the state variables to faithfully represent perceptual data, we still want them to emphasise, when possible, rewarding elements of the environment over and above irrelevant ones. To accomplish this, we follow the hippocampal representation theory of Gluck and Myers[gluck1993hippocampal], who proposed, as an account of diverse phenomena in animal conditioning, that hippocampal representations pass through a compressive bottleneck and then reconstruct input stimuli together with task reward. In our context, z_t is the compressive bottleneck, and we include an additional decoder that makes a prediction of the return R_t as a function of z_t. Algorithms such as A3C predict task returns and use these predictions to reduce variance in policy gradient learning[mnih2016asynchronous]. In MERLIN, return prediction also has the essential role of shaping the state representations constructed by unsupervised prediction. Including this prediction has an important effect on the performance of the system, encouraging z_t to focus on compressing sensory information while maintaining information of significance to tasks. In the 3D virtual reality environments described subsequently, the high-dimensional sensory input to the agent is compressed into a state variable of far lower dimensionality; this is achieved without losing information critical to task-related computations.

The utility of the memory system can be further increased by storing each state variable together with a representation of the events that occurred after it in time, which we call retroactive memory updating. In navigation, for example, this allows perceptual information related to a way-point landmark to be stored next to information about a subsequently experienced goal. We implement this by an edit of the memory matrix in which a filtered sum of the state variables produced after z_t is concatenated into the same row: the row storing z_t becomes [z_t, Σ_{t′>t} γ^{t′−t} z_{t′}], with discount factor 0 ≤ γ < 1.
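A sketch of this retroactive edit, assuming a simple exponential filter with discount γ (the exact filter coefficients and normalisation in the model may differ):

```python
import numpy as np

def retroactive_update(Z, gamma=0.95):
    """For each row t, concatenate z_t with the discounted sum of the
    state variables that followed it: [z_t, sum_{t'>t} gamma**(t'-t) z_{t'}]."""
    T, D = Z.shape
    retro = np.zeros_like(Z)
    acc = np.zeros(D)
    for t in range(T - 2, -1, -1):   # scan backwards through time
        acc = gamma * (Z[t + 1] + acc)
        retro[t] = acc
    return np.concatenate([Z, retro], axis=1)

Z = np.eye(3)                        # three toy one-hot state variables
M_retro = retroactive_update(Z, gamma=0.5)
```

After the update, the row written at the way-point (row 0) also carries a discounted trace of the later goal observations (rows 1 and 2).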

Further details about MERLIN are available in Methods.

Figure 2: Basic Tasks. a. On the Memory Game, MERLIN exhibited qualitatively superior learning performance compared to RL-LSTM and RL-MEM (yellow: average episode score as a function of the number of steps taken by the simulator; hashed line: the cost of the MBP – the negative of the variational lower bound). We include an additional comparison to a Differentiable Neural Computer (DNC)-based agent as well[graves2016hybrid], which we did not study further as its computational complexity scales quadratically with time steps if no approximations are made. “Number of Environment Steps” is the total number of forward integration steps computed by the simulator. (The standard error over five trained agents per architecture type was nearly invisibly small.) b. MERLIN playing the memory game: an omniglot image observation[lake2015human] (highlighted in white) was jointly coded with its location on the grid, which was given by the previous action. MERLIN focused its memory access on a row in memory storing information encoded when it had flipped the card indicated by the red square. On step 5, MERLIN retrieved information about the historical observation that was most similar to the currently viewed observation. On step 6, MERLIN chose the previously seen card and scored a point. c. A randomly generated layout for a large environment. The Large Environment Task was one of the four navigation tasks on which MERLIN was simultaneously trained. The agent sought a fixed goal (worth 10 points of reward) as many times as possible in 90 seconds, having been teleported to a random start location after each goal attainment. d. Learning curves for the Large Environment showed that MERLIN revisited the goal more than twice as often per episode as comparison agents. “Number of Environment Steps” logged the number of environment interactions that occurred across all four tasks in the set.

We first consider a very simple task that RL agents should easily solve, the children’s game “Memory”, which has even been comparatively studied in rhesus monkeys and humans[washburn2002species]. Here, cards are placed face down on a grid, and one card at a time is flipped over then flipped back. If two sequentially flipped cards show matching pictures, a point is scored, and the cards are removed. An agent starts from a state of ignorance and must explore by flipping cards and remembering the observations. Strikingly, RL-LSTM and RL-MEM were unable to solve this task, whereas MERLIN found an optimal strategy (Fig. 2a). A playout by MERLIN is shown in Fig. 2b, in which it is seen that MERLIN read from a memory row storing information about the previously observed matching card one time step before MERLIN’s policy flipped that card.

Figure 3: Analysis. a. MERLIN alone was able to relocate the goal in less time on subsequent visits. b. Each large environment was laid out on a unit spatial grid. From the MBP state, the allocentric position of the goal was decoded in Manhattan distance to subgrid accuracy per dimension. This accuracy improved with sub-episodes, implying MERLIN integrated information across multiple transits to the goal. (Ext. Fig. 1 has an egocentric decoding.) Decoding from feedforward visual convnets (e.g. MERLIN’s convnet), from the policy in RL-LSTM, or from the memory in RL-MEM did not yield equivalent localisation accuracy. c. An example trajectory while returning to the goal (orange circle): even when the goal was not visible, MERLIN’s return prediction climbed as it anticipated the goal. d. MERLIN’s return prediction error was low on approach to the goal as the MBP used memory to estimate proximity. e. Task performance varied with different coefficients of the return prediction cost, with a flat maximum that balanced sensory prediction against return prediction. f. For higher values of the coefficient, regression using a feedforward network from the return prediction to z_t explained increasing variance in z_t. Thus, the state variables devoted more effective dimensions to coding for future reward. g. Top row: observations at 20, 10, and 5 steps before goal attainment. Middle: the L2 norm of the gradient of the return prediction with respect to each image pixel (taken over colour channels) was used to intensity-mask the observation, automatically segmenting pixels containing the goal object. Bottom: a model trained without a gradient from the return prediction through z_t was not similarly sensitive to the goal. (The gradient pathway was unstopped for analysis.) h. The three memory read heads of the MBP specialised to focus on, respectively, memories formed at locations ahead of the agent, at waypoint doorways en route to the goal, and around the goal itself. i. Across a collection of trained agents, we assigned each MBP read head the index 1, 2, or 3 if it read from memories formed far from (blue), midway to (green), or close to the goal (red). The reading strategy in panel 3h, with read heads specialised to recall information formed at different distances from the goal, developed reliably.

MERLIN also excels at solving one-shot navigation problems from raw sensory input in randomly generated, partially observed 3D environments. We trained agents to be able to perform on any of four variants of a goal-finding navigation task (Ext. Video 1: https://youtu.be/YFx-D4eEs5A). These tested the ability to locate a goal in a novel environment map and quickly return to it. After reaching the goal, the agent was teleported to a random starting point. It was allowed a fixed time on each map and rewarded for each goal visit. To be successful, an agent had to rapidly build an approximate model of the map from which it could navigate back to the goal.

The base task took place in environments with 3-5 rooms. The variations included a task where doors dynamically opened and closed after reaching the goal, inspired by Tolman[tolman1948cognitive]; a task where the skyline was removed to force navigation based on proximal cues; and the same task in larger environments with twice the area and a maximum of 7 rooms (Fig. 2c). MERLIN learned faster and reached higher performance than comparison agents and professional human testers (Ext. Table 1; Fig. 2d; Ext. Fig. 2). MERLIN exhibited robust memory-dependent behaviour as it returned to the goal in less time on each repeated visit, having rapidly apprehended the layout of the environment (Fig. 3a). Within very few goal attainments in each episode in large environments, it was possible to classify the absolute position of the goal to high accuracy from the MBP state (state variable z_t, memory reads m_t, and recurrent state h_t) (Fig. 3b), demonstrating that MERLIN quickly formed allocentric representations of the goal location.

Even when a goal was out of view, MERLIN’s return predictions rose in expectation of the oncoming goal (Fig. 3c), and its return prediction error was lower than the analogous value function predictions of RL-LSTM and RL-MEM (Fig. 3d). Agent performance was robust to a range of different weights on the return prediction cost coefficient, but for very low and high values, performance was dramatically affected, as the state variables became insensitive to reward for low values and insensitive to other sensors for high values (Fig. 3e). The mean of the MBP prior distribution over the next state variable could be decoded to check visual accuracy across a spectrum of weights on the return prediction cost coefficient; lower values produced cleaner visual images, retaining more information (Ext. Fig. 3). Regressing from the return predictions to z_t showed that the return prediction explained more variance of the state variable for higher return cost coefficients. We also observed the emergent phenomenon that the region of the visual field to which the return prediction was sensitive was a segmentation around the goal (Fig. 3g). An agent trained to predict return but whose prediction errors did not send gradients through the state variable during training did not develop these receptive fields (Fig. 3g).

Remarkably, though it had not been explicitly programmed, MERLIN showed evidence of hierarchical goal-directed behaviour, which we detected from the MBP’s read operations. The three read heads of the MBP specialised to perform three functions. One would read from memories associated with previously visited locations just ahead of the agent’s movement (Fig. 3h, left); a second from memories associated with the goal location (Fig. 3h, right); intriguingly, a third alternated between memories associated with various important sub-goals – particularly doorways near the room containing the goal (Fig. 3h center). Across a group of 5 trained agents, this specialisation of heads attending to memories formed at different distances from the goal emerged robustly (Fig. 3i).

Figure 4: Task Battery. a. Arbitrary Visuomotor Mapping. The agent had to remember the target direction associated with each image. b. Rapid Reward Valuation. The agent had to explore to find good objects and thenceforth consume only those. c. Episodic Water Mazes. The agent experienced one of three water mazes identifiable by wall colour and had to relocate the platform location. d. Executing Transient Instructions. The agent was given a verbal instruction indicating the room colour and object colour to find. e-h. Learning Curves for Corresponding Tasks. The dotted yellow lines represent the negative of the VLB of the MBP. i. In comparison to a human tester, MERLIN exhibited better accuracy as a function of the number of items to recall. On a new set of synthetic images made of constellations of coloured shapes and letters (Ext. Fig. 4), MERLIN retained higher performance. j. The MBP read memories of the same object as the one about to be consumed. k. MERLIN’s return predictions became more accurate after the first experience consuming each object type. l. MERLIN developed an effective exploration strategy (red) and on subsequent visits to each of three rooms exhibited directed paths to remembered platforms (green). m. The MBP variables were sufficient to reconstruct the instruction on test data at diverse times in the episode, even when the instruction was no longer present.

To demonstrate its generality, we applied MERLIN to a battery of additional tasks. The first, “Arbitrary Visuomotor Mapping”, came from the primate visual memory literature[wise2000arbitrary] and demanded learning motor response associations to complex images[wang2016learning]. The agent needed to fixate on a screen and associate each presented image with movement to one of four directions (Fig. 4a; Ext. Video 2: https://youtu.be/IiR_NOomcpk). At first presentation, the answer was indicated with a colour cue but subsequently needed to be recalled. With correct answers, the number of images to remember was gradually increased. MERLIN solved the task essentially without error, reaching performance above human level (Fig. 4e&i). When probed on a set of multi-object synthetic images modeled on a visual working memory task[luck1997capacity] (Ext. Fig. 4), MERLIN generalised immediately, exhibiting accuracy declines at higher memory loads than a human subject (Fig. 4i). This transfer result implied that MERLIN learned the task structure largely independently of the image set. Moreover, MERLIN was able to learn, exclusively through unsupervised mechanisms, to distinguish complex images – even previously unseen ones with different statistics.

In a second task, MERLIN demonstrated the ability to perform rapid reward valuation, a facility subserved by the hippocampus[corbit2000role]. The agent was placed in a room with a selection of eight kinds of objects from a large pool with random textures that it could consume (Fig. 4b; Ext. Video 3: https://youtu.be/dQMKJtLScmk). Each object type was assigned a positive or negative reward value at the beginning of the episode; after clearing all positive reward from the environment, the agent started a second sub-episode with the same objects and had a chance to consume more. MERLIN learned to quickly probe and retain knowledge of object value. When approaching an object, it focused its reads on memories of the same object in preference to memories of others (Fig. 4j), suggesting that it had formed view- and background-invariant representations. MERLIN used these reads to make predictions of the upcoming value of previously consumed objects (Fig. 4k). Here, the retroactive memory updates were very effective: the retroactive portion of a memory row written the first time an object was exploratively approached could be used to decode the reward value of the subsequent consumption with high accuracy (5-fold cross-validation: Methods Sec. 9.7).

We next probed MERLIN’s ability to contextually load episodic memories and retain them for a long duration. In one task, over six minutes MERLIN experienced three differently coloured “water mazes”[morris1984developments] with visual cues on the walls (Fig. 4c) and each one with a randomly-positioned hidden platform as a goal location. MERLIN learned to explore the mazes and to store relevant information about them in memory and retrieve it without interference to relocate the platforms from any starting position (Fig. 4l; Ext. Fig. 6; Ext. Video 4: https://youtu.be/xrYDlTXyC6Q).

MERLIN learned to respond correctly to transiently communicated symbolic commands[pilley2011border, hermann2017grounded]. For the first five steps in an episode, an instruction to retrieve the “[colour] object from the [colour] room” was presented to the text encoding network (Fig. 4d; Ext. Video 5: https://youtu.be/04H28-qA3f8). By training feedforward network classifiers that took as input the recurrent state and memory readouts of the MBP, we found it was possible to decode the instruction text tokens on held-out data at any point in the episode (Fig. 4m), demonstrating persistent activity encoding task-related representations.

Figure 5: Latent Learning. a. Left: Phase 1: the agent, beginning at the purple starred location, had to engage in an apple collecting task in a T-Maze with open doors and a random object in each room, initially valueless. Middle: Phase 2: after collecting the apples in phase 1, the agent engaged in a distractor task to collect apples in a long corridor. Right: Phase 3: a glass box containing one of the same objects cued the agent to find that object in one of the rooms behind a closed door. The agent only had one chance to choose. b. MERLIN was able to learn the task while the end-to-end agents performed at random in phase 3. c. After training, at the beginning of phase 3, the policy’s read could be seen to maximally recall a memory stored in phase 1 while approaching the same object. d. The MBP reads also contained information about the location of the cued object: a logistic regression classifier could be trained with nearly perfect accuracy to predict the goal arm from the MBP reads on 5-fold cross-validated data.


Finally, we examined a latent learning effect[seward1949experimental, thistlethwaite1951critical]. In a T-Maze environment (Fig. 5a), the agent was first given an apple collecting task (phase 1). Random objects were placed in the left and right maze arm rooms, which were occluded from outside the rooms. After collecting the apples in phase 1, the agent engaged in an unrelated distractor task of further apple collecting in a long corridor (phase 2). In phase 3, the agent was teleported to the beginning of the T-Maze, where a glass box contained one of the two objects from the arm rooms. With the doors to the rooms now closed and only one chance to choose a room, MERLIN learned to find the object corresponding to the container cue (Ext. Video 6: https://youtu.be/3iA19h0Vvq0), demonstrating that it could predicate decisions on observational information acquired before the stimuli were associated with the probe task (Fig. 5b), even though the phase 2 delay period was longer than the backpropagation-through-time window. On observation of the container, the read-only policy could recall the observation of the object made during phase 1 (Fig. 5c). The MBP reads could also be used at the same time to decode the correct arm containing the cued object (Fig. 5d).

As we have intimated, the architecture in Fig. 1c is not the only possible instantiation of the principle of predictive modeling with memory in an agent. For example, the gradient block between the policy and the MBP is not necessary (Ext. Fig. 7); we included it only to demonstrate the independence of these modules from each other. Similarly, it is possible to attach the policy to the MBP as an extra feedforward network. This worked comparably in terms of performance (Ext. Fig. 8) but obscures how the system works conceptually, since the independence of the MBP and policy is no longer explicit. On the other hand, lesions to the architecture, in particular removing the memory, the sensory or return prediction, or the retroactive memory updating, were harmful to performance or data efficiency (Ext. Fig. 5).

Across all tasks, MERLIN performed qualitatively better than comparison models, which often entirely failed to store or recall information about the environment (Figs. 2a&c; 4e-h; Ext. Figs. 2&9). On a few tasks, by increasing the truncation window to ten times the duration used by MERLIN (for MERLIN, 1.3 s at 15 frames per second), we were able to improve the performance of RL-MEM. However, this performance was always less than MERLIN's and was achieved only after much more training (Ext. Figs. 10-11). For the episodic Water Mazes task, extending the window was unable to improve performance at all (Ext. Fig. 10), just as for the Memory Game, where the window already spanned the whole episode. Memory systems trained end-to-end for performance (Eq. 1) were less able to learn to write essential information into memory.

This problem with end-to-end learning will only become more pressing as AI approaches the frontier of long-lived agents with long-term memories. For example, consider an agent that needs to remember events that occurred 24 hours ago. Stored memories could be retrieved by a memory reading operation, but methods for optimising network dynamics or information storage over that interval, like backpropagation-through-time, require keeping an exact record of network states over 24 hours. This stipulation is practically prohibitive, and its neural implausibility suggests, at the very least, that there are better algorithmic solutions than end-to-end gradient computation for memory and other aspects of temporal credit assignment. We note that MERLIN exclusively used a window of 1.3 s to solve tasks requiring the use of memory over much longer intervals (i.e., 0.36% of the length of the longest task at 6 minutes).

While end-to-end RL with a sufficiently large network and enough experience and optimisation should theoretically learn to store relevant information in memory to call upon for later decisions, we have shown the actual requirements are prohibitive; as we have long known in domains such as object recognition and vision, architectural innovations (e.g. convolutional neural networks [lecun1990handwritten]) are critical for practical systems. Although the implementation details will likely change, we believe that the combined use of memory and predictive modeling will prove essential to future large-scale agent models in AI and even in neuroscience[eliasmith2012large].

References and Notes

Video Links


We thank David Silver, Larry Abbott, and Charles Blundell for helpful comments on the manuscript; Daan Wierstra, Neil Rabinowitz, Ari Morcos, Nicolas Heess, Alex Graves, Dharshan Kumaran, Raia Hadsell, Brian Zhang, Oriol Vinyals, and Hubert Soyer for discussions; Amir Sadik and Sarah York for environment testing; Stephen Gaffney and Helen King for organisational help.

Author Contributions

G.W. and T.L. conceived the project. G.W., C.C.H., D.A., M.M., A.A., A.S., M.G., S.M., D.R., and T.L. developed the model. C.C.H., D.A., G.W., M.M., A.A., and J.A. performed experiments. D.A., J.R., M.R., T.H., and D.Saxton built engineering infrastructure. C.C.H., G.W., D.A., M.M., A.A., A.G.B., and P.M. performed analyses. G.W., C.C.H., A.A., J.Z.L., and T.L. designed tasks. A.C. prepared figure artwork. K.K., M.B., D.Silver, and D.H. consulted for project. C.H. helped with team coordination. G.W., C.C.H., D.A., M.M., A.A., A.G.B., P.M., and T.L. wrote the manuscript.


Correspondence should be addressed to Greg Wayne, Chia-Chun Hung, or Timothy Lillicrap (email: gregwayne, aldenhung, countzero@google.com).


1 Environment Software

All 3D environment experiments were implemented in DeepMind Lab (or DM Lab) [beattie2016deepmind]. For all experiments in this framework, we used a frame rate of 60 frames per second but propagated only the first observation of each sequence of four frames to the networks. Rewards accumulated over each packet of four frames were summed together and associated with the first, undropped frame. Similarly, the agent chose one action at the beginning of this packet of four frames, and the action was repeated for all four steps. Accordingly, the number of “Environment Steps” reported in this manuscript is four times the number of “Agent Steps”.
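As a concrete illustration, the frame-packeting scheme above can be sketched as follows (a minimal Python sketch; the function name and list-based interface are our own, not DM Lab's):

```python
def repackage_frames(frames, rewards, action_repeat=4):
    """Group raw environment frames into agent steps.

    Only the first observation of each packet of `action_repeat` frames is
    kept; the rewards accumulated over the packet are summed and associated
    with that first, undropped frame.
    """
    obs, summed = [], []
    for start in range(0, len(frames), action_repeat):
        obs.append(frames[start])
        summed.append(sum(rewards[start:start + action_repeat]))
    return obs, summed

# 8 environment frames -> 2 agent steps
frames = list(range(8))
rewards = [0.0, 1.0, 0.0, 0.5, 0.0, 0.0, 2.0, 0.0]
o, r = repackage_frames(frames, rewards)
# o == [0, 4]; r == [1.5, 2.0]
```

The chosen action would likewise be applied once per packet and repeated on the intervening frames.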

We used a consistent action set for all experiments except for the Arbitrary Visuomotor Mapping task. For all other tasks, we used a set of six actions: move forward, move backward, rotate left with rotation rate of 30 (mapping to an angular acceleration parameter in DM Lab), rotate right with rotation rate of 30, move forward and turn left, move forward and turn right. For the Arbitrary Visuomotor Mapping, we did not need to move relative to the screen, but we instead needed to move the viewing angle of the agent. We thus used four actions: look up, look down, look left, look right (with rotation rate parameter of 10).

2 Model

At a high level, the model consists of the memory-based predictor and the policy. The memory-based predictor contains within it several encoders and decoders and represents two distributions over the state variable: the prior, which predicts the distribution of the next state variable, and the posterior, which represents the distribution after the next observation. The memory-based predictor contains an external memory that stores a history of state variables. The policy takes the state variable as input and also reads from the external memory.

We now describe the model in detail by defining its parts list and the loss functions used to optimise it. Parameters given per task are defined in Table 


2.1 Encoder

The encoder is composed of five sub-networks: the image encoder, the velocity encoder (in all DM Lab experiments), the action encoder, the reward encoder, and the text encoder. These act independently on the different elements contained within the input set, and their outputs are concatenated into a single flat vector.

2.1.1 Image Encoder

For the DM Lab tasks, we use an image encoder that takes in 3-channel RGB image tensors. We then apply 6 ResNet [he2016deep] blocks with rectified linear activation functions. All blocks have 64 output channels and bottleneck channel sizes of 32. The strides of the 6 blocks produce an 8-fold spatial down-sampling of the original image. We do not use batch-norm [ioffe2015batch], a pre-activation function on the inputs, or a final activation function on the outputs. Finally, the output of the ResNet is flattened and propagated through one final linear layer that reduces the size to 500 dimensions, whereupon a nonlinearity is applied.

For the task “Memory”, we use the same architecture as for the DM Lab tasks save that the input image tensors are single-channel (grey-scale), which changes the ResNet module output sizes accordingly.

2.1.2 Velocity Encoder

The velocity encoder takes a 6-dimensional vector that comprises the translational and rotational velocities of the agent. It is calculated from measured differences between the current time step and the previous step (making it egocentric) but is zero across any agent teleportation. Though there is no physical distance metric in the DM Lab environment, the actual numerical velocities produced by the simulator are large, so we scale all numbers by dividing by 1000 before passing the data through the encoder. The encoder then applies a linear layer with an output dimensionality of 10.

2.1.3 Action Encoder

In all environments, the action from the previous time step is a one-hot binary vector (6-dimensional for most DM Lab tasks). We use an identity encoder that leaves the one-hot action unchanged.

2.1.4 Reward Encoder

The reward from the previous time step is represented as a scalar magnitude and is not processed further.

2.1.5 Text Encoder

To interpret text instructions, a small, single-layer LSTM processes the word tokens (vocabulary size 1000) in the input: each token first passes through a linear embedding of size 50, which is then input to an LSTM of width 100. The terminal hidden state is treated as the output.

2.2 Decoder

The decoder is composed of six sub-networks. Five of these sub-networks are duals (every layer has dimensions transposed from the corresponding encoder layer) of the encoder networks. The additional sub-network decodes a prediction of the return.

2.2.1 Image Decoder

The image decoder has the same architecture as the encoder except the operations are reversed. In particular, all 2D convolutional layers are replaced with transposed convolutions [dumoulin2016guide]. Additionally, the last layer produces a number of output channels that is formatted to the likelihood function used for the image reconstruction loss, described in more detail in Eq. 5.

2.2.2 Return Prediction Decoder and Value Function

The return prediction decoder is the most complicated non-visual decoder. It is composed of two networks. The first is a value-function MLP that takes in the concatenation of the latent variable with the policy distribution's multinomial logits. This vector is propagated through a single hidden layer with a nonlinear activation function and then projects via another linear layer to a 1-dimensional output. This function can be considered a state-value function.

A second network acts as a state-action advantage function, taking in the concatenation of the latent variable, the policy logits, and the action, and propagating it through an MLP with two hidden layers, each of size 50, with nonlinear activations. This quantity is then added to the value function to produce a return prediction. When we calculate loss gradients, we do not allow gradients to flow back through the value function from the return-prediction cost, as this would change the effective weight of the gradient on the value-function term, which has its own loss function. The return prediction can also be considered a state-action value function (a Q-function) for the current policy.
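The two-headed return decoder can be sketched numerically as below (a minimal numpy sketch; the `mlp` helper, the layer and input sizes, and the random weights are illustrative assumptions, and the gradient stop at the value head is only indicated in a comment, since numpy has no autograd):

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(x, sizes, rng):
    """Tiny tanh MLP with randomly initialised weights (illustration only)."""
    for fan_out in sizes[:-1]:
        W = rng.normal(0, 0.1, (x.shape[-1], fan_out))
        x = np.tanh(x @ W)
    W = rng.normal(0, 0.1, (x.shape[-1], sizes[-1]))
    return x @ W

# Hypothetical sizes: latent of dim 20, 6 action logits, one-hot action.
z = rng.normal(size=20)
logits = rng.normal(size=6)
action_onehot = np.eye(6)[2]

# State-value head: V(latent, policy logits) -> scalar.
value = mlp(np.concatenate([z, logits]), [200, 1], rng)[0]

# Advantage head: A(latent, logits, action), two hidden layers of 50 -> scalar.
adv = mlp(np.concatenate([z, logits, action_onehot]), [50, 50, 1], rng)[0]

# Return prediction (a Q-value for the current policy). In training, the
# gradient of the return-prediction loss is stopped at `value`, so only the
# advantage head is shaped by this loss.
q = value + adv
```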

2.2.3 Text Decoder

We decode the multi-word instruction sequentially with a single-layer LSTM of the same size as the corresponding encoder, whose output is a 1,000-way softmax over each word token in the sequence.

2.2.4 Other Decoders

The reward, velocity, and action decoders are all linear layers mapping to, respectively, the reward dimension, the velocity dimensions, and the action cardinality.

2.3 Memory

Our memory system is primarily based on a simplification of the Differentiable Neural Computer (DNC) [graves2016hybrid]. The memory itself is a two-dimensional matrix whose number of rows and row width are set per task, the row width being determined by the dimensionality of the latent state vector. The memory at the beginning of each episode is initialised blank: $M_0 = 0$.

2.3.1 Deep LSTMs

In both the policy and memory-based predictor, we use a deep LSTM [graves2013speech] of two hidden layers. Although the deep LSTM model has been described before, we describe it here for completeness. Denote the input to the network at time step $t$ as $x_t$. Within a layer $l$, there is a recurrent state $h^l_t$ and a “cell” state $c^l_t$, which are updated based on the following recursion (with $h^0_t = x_t$):
$$i^l_t = \sigma\big(W^l_i [h^{l-1}_t; h^l_{t-1}] + b^l_i\big),$$
$$f^l_t = \sigma\big(W^l_f [h^{l-1}_t; h^l_{t-1}] + b^l_f\big),$$
$$c^l_t = f^l_t \odot c^l_{t-1} + i^l_t \odot \tanh\big(W^l_c [h^{l-1}_t; h^l_{t-1}] + b^l_c\big),$$
$$o^l_t = \sigma\big(W^l_o [h^{l-1}_t; h^l_{t-1}] + b^l_o\big),$$
$$h^l_t = o^l_t \odot \tanh(c^l_t).$$

In both uses of the deep LSTM (policy and MBP), to produce a complete output, we concatenate the output vectors from each layer. These are passed out for downstream processing.
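One step of such a two-layer deep LSTM, with the concatenated output, can be sketched as follows (a minimal numpy sketch; the weight shapes, gate ordering, and dimensionalities are our own conventions, not the paper's):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, b):
    """One standard LSTM step. W has shape (in+hid, 4*hid); gate order i,f,g,o."""
    z = np.concatenate([x, h_prev]) @ W + b
    i, f, g, o = np.split(z, 4)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

def deep_lstm_step(x, hs, cs, Ws, bs):
    """Two-layer deep LSTM: layer l takes the output of layer l-1 as input
    (h^0_t = x_t); the module output concatenates the outputs of all layers."""
    inp = x
    new_hs, new_cs = [], []
    for h_prev, c_prev, W, b in zip(hs, cs, Ws, bs):
        h, c = lstm_step(inp, h_prev, c_prev, W, b)
        new_hs.append(h)
        new_cs.append(c)
        inp = h
    return np.concatenate(new_hs), new_hs, new_cs

rng = np.random.default_rng(0)
x_dim, hid = 8, 16
Ws = [rng.normal(0, 0.1, (x_dim + hid, 4 * hid)),
      rng.normal(0, 0.1, (hid + hid, 4 * hid))]
bs = [np.zeros(4 * hid), np.zeros(4 * hid)]
hs = [np.zeros(hid), np.zeros(hid)]
cs = [np.zeros(hid), np.zeros(hid)]
out, hs, cs = deep_lstm_step(rng.normal(size=x_dim), hs, cs, Ws, bs)
# out concatenates both layers' outputs, so it has dimension 2 * hid
```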

2.3.2 MBP LSTM

At each time step, the recurrent network of the MBP receives a concatenated input that includes the state variable and the previous action, the latter represented by a one-hot code. The policy network receives the state variable as input. For the MBP LSTM, the input is further concatenated with the list of vectors read from the memory at the previous time step. The deep LSTM equations are applied, and the output is produced. A linear layer is applied to the output to construct a memory interface vector, whose dimension depends on the number of read keys and the dimensionality of the latent vector. The interface vector is segmented into read key vectors and one scalar per key; each scalar is passed through a nonlinearity that ensures positivity to create the read strengths.

2.3.3 Reading

Memory reading is executed before memory writing. Reading is content-based and proceeds by first computing the cosine similarity between each read key $k^i_t$ and each memory row $M[j]$: $c^i_{t,j} = \frac{k^i_t \cdot M[j]}{\|k^i_t\|\,\|M[j]\|}$. For each read key, a normalised weighting vector over the rows is then computed:
$$w^i_{t,j} = \frac{\exp(\beta^i_t c^i_{t,j})}{\sum_{j'} \exp(\beta^i_t c^i_{t,j'})}.$$
For that key, the readout from memory is $m^i_t = \sum_j w^i_{t,j} M[j]$. These readouts are concatenated together with the state of the deep LSTM and output from the module.
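Content-based reading can be sketched as follows (a minimal numpy sketch; the variable names, memory size, and strength value are illustrative):

```python
import numpy as np

def content_read(memory, keys, betas):
    """Content-based memory reading: cosine similarity between each read key
    and every memory row, sharpened by a positive strength and normalised
    with a softmax; each readout is the weighted sum of the rows."""
    eps = 1e-8
    M_norm = memory / (np.linalg.norm(memory, axis=1, keepdims=True) + eps)
    readouts = []
    for k, beta in zip(keys, betas):
        k_norm = k / (np.linalg.norm(k) + eps)
        cos = M_norm @ k_norm                  # cosine similarity per row
        w = np.exp(beta * cos)
        w /= w.sum()                           # normalised weighting vector
        readouts.append(w @ memory)            # weighted readout
    return np.concatenate(readouts)

rng = np.random.default_rng(1)
memory = rng.normal(size=(10, 6))              # 10 rows of width 6
keys = [memory[3] + 0.01 * rng.normal(size=6)]  # key close to row 3
r = content_read(memory, keys, betas=[50.0])
# with a sharp strength, the readout closely approximates row 3
```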

2.3.4 Writing

After reading, writing to memory occurs, which we also define using weighting vectors. The write weighting $v^{wr}_t$ has length equal to the number of memory rows and always appends information to the $t$-th row of the memory matrix at time $t$, i.e., $v^{wr}_t[i] = \delta_{it}$ (using the Kronecker delta). A second weighting for retroactive memory updates forms a filter of the write weighting:
$$v^{ret}_t = \gamma v^{ret}_{t-1} + (1 - \gamma) v^{wr}_{t-1},$$
where $\gamma$ is the same as the discount factor for the task. Given these two weightings, the memory update can be written as an online update
$$M_t = M_{t-1} + v^{wr}_t [z_t, 0]^\top + v^{ret}_t [0, z_t]^\top,$$
where $0$ is the zero-vector of length $|z|$. Each of these weightings is initialised so that $v^{wr}_0 = v^{ret}_0 = 0$.
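The write and retroactive updates can be sketched as an online loop (a minimal numpy sketch under the assumption, consistent with the update above, that each row stores the latent in its first half and the retroactive trace in its second half; sizes are illustrative):

```python
import numpy as np

def write_step(memory, v_ret, z, t, gamma):
    """Append z to row t and retroactively add a discounted trace of z to the
    second halves of previously written rows (sketch of Methods Eq. 2)."""
    n_rows, width = memory.shape
    half = width // 2                 # rows store [z ; retroactive part]
    v_wr = np.zeros(n_rows)
    v_wr[t] = 1.0                     # Kronecker-delta write weighting
    memory = memory + np.outer(v_wr, np.concatenate([z, np.zeros(half)]))
    memory = memory + np.outer(v_ret, np.concatenate([np.zeros(half), z]))
    # Retroactive weighting for the next step: filtered write weighting.
    v_ret_next = gamma * v_ret + (1.0 - gamma) * v_wr
    return memory, v_ret_next

d, n_rows, gamma = 4, 8, 0.9
memory = np.zeros((n_rows, 2 * d))
v_ret = np.zeros(n_rows)              # initialised to zero
rng = np.random.default_rng(2)
for t in range(3):
    z = rng.normal(size=d)
    memory, v_ret = write_step(memory, v_ret, z, t, gamma)
# row 0's second half now holds a discounted mixture of later latents
```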

In case the number of memory rows is less than the episode length, overwriting of rows is necessary. To implement this, each row maintains a usage indicator, initialised to zero until the row is first written to. Subsequently, the row's usage is increased whenever the row is read from by any of the read heads. When allocating a new row for writing, the row with the smallest usage is chosen.

2.4 Prior Distribution

The prior distribution is produced by an MLP that takes in the output from the MBP LSTM at the previous time step and passes it through two hidden layers with nonlinear activation functions. A final linear layer produces the mean and log-standard deviation of a diagonal Gaussian distribution for the current time step.

2.5 Posterior Distribution

The posterior distribution is produced in two stages. First, the encoded features, the outputs of the MBP LSTM, and the prior distribution parameters are concatenated into one large vector. This concatenated vector is then propagated through an MLP with two hidden layers and nonlinear activation functions, followed by a single linear layer. The output of this MLP is added to the prior distribution's parameters to determine the posterior distribution.

2.6 State Variable Generation

After the posterior distribution is computed, the state variable is sampled as $z_t = \mu_t + \sigma_t \odot \epsilon_t$, where $\epsilon_t \sim \mathcal{N}(0, I)$ and ‘$\odot$’ represents element-wise multiplication.
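Putting the prior, the additive posterior correction, and the reparameterised sample together (a minimal numpy sketch in which random draws stand in for the network outputs; the dimensionality is illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
dim = 16

# Prior network output: mean and log-std of a diagonal Gaussian
# (random stand-ins for the MLP outputs described above).
prior_mu = rng.normal(size=dim)
prior_logstd = rng.normal(size=dim) * 0.1

# The posterior MLP's output is *added* to the prior's parameters.
delta_mu = rng.normal(size=dim) * 0.1
delta_logstd = rng.normal(size=dim) * 0.1
post_mu = prior_mu + delta_mu
post_logstd = prior_logstd + delta_logstd

# Reparameterised sample of the state variable: z = mu + sigma * eps.
eps = rng.standard_normal(dim)
z = post_mu + np.exp(post_logstd) * eps
```

Sampling via the deterministic transform of noise keeps the pathwise gradient with respect to the distribution parameters available during training.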

2.7 Policy

The operation of the policy is similar to that of the MBP LSTM. At each time step, before the MBP LSTM operates, the policy receives the state variable. A deep LSTM that can also read from the memory in the same way as the MBP executes one cycle, but using only one read key. Its outputs are then concatenated again with the latent variable and passed through a single-hidden-layer MLP with 200 units. This then projects to the logits of a multinomial softmax with the dimensionality of the action space, which varies per environment (4-dimensional for the Arbitrary Visuomotor Mapping task, 6-dimensional for the rest of DM Lab, and 16-dimensional for Memory). The action is sampled and passed to the MBP LSTM as an additional input, as described above.

3 Derivation of the Variational Lower Bound

The log marginal likelihood of a probabilistic generative model can be lower-bounded via an approximating posterior as a consequence of Jensen's inequality: $\log p(x) \geq \mathbb{E}_{q(z|x)}\left[\log p(x, z) - \log q(z|x)\right]$. For convenience, we define $x_{0:t} \equiv (x_0, \dots, x_t)$ and $x_{0:-1} \equiv \emptyset$ (the empty set). Then for a temporal model that factorises as $p(x_{0:T}, z_{0:T}) = \prod_{t=0}^{T} p(z_t \mid z_{0:t-1})\, p(x_t \mid z_{0:t})$ and approximate posterior $q(z_{0:T} \mid x_{0:T}) = \prod_{t=0}^{T} q(z_t \mid z_{0:t-1}, x_{0:t})$, this becomes
$$\log p(x_{0:T}) \geq \mathbb{E}_{q}\left[\sum_{t=0}^{T} \log p(x_t \mid z_{0:t}) - \mathrm{KL}\big(q(z_t \mid z_{0:t-1}, x_{0:t}) \,\|\, p(z_t \mid z_{0:t-1})\big)\right].$$

Suppose now that we partition the target variables $x$ into two sets $o$ and $R$. In the stationary case, we can still form the inequality $\log p(o, R) \geq \mathbb{E}_{q(z|o)}\left[\log p(o, R \mid z)\right] - \mathrm{KL}\big(q(z \mid o) \,\|\, p(z)\big)$, which does not condition the approximate posterior on one of the variables. Likewise, in the temporal case, we have
$$\log p(o_{0:T}, R_{0:T}) \geq \mathbb{E}_{q}\left[\sum_{t=0}^{T} \log p(o_t, R_t \mid z_{0:t}) - \mathrm{KL}\big(q(z_t \mid z_{0:t-1}, o_{0:t}) \,\|\, p(z_t \mid z_{0:t-1})\big)\right].$$

This is the form we use to justify a loss function that combines prediction of incrementally observable information (image, reward, etc.) with information that is only known with some delay (the sum of future rewards) and therefore cannot be conditioned on in a filtering system. Finally, we can additionally condition the prior model on other variables such as the actions, giving a prior of the form $p(z_t \mid z_{0:t-1}, a_{0:t-1})$.


4 Cost Functions for the MBP and Policy

The parameters of the policy are entirely independent of the parameters of the memory-based predictor; they are not updated based on gradients from the same loss functions. This is implemented via a gradient stop between the policy and the state variable.

The memory-based predictor has a loss function based on the variational lower bound in Eq. 3 with specific architectural choices for the output model (the decoders alongside the likelihood functions for each prediction) and prior and posterior distributions:


4.1 Conditional Log-Likelihood

The conditional log-likelihood term is factorised into independent loss terms associated with each decoder and is conditioned on a sample from the approximate posterior network, thus giving a stochastic approximation to the expectation in the variational lower bound objective. We use a multinomial softmax cross-entropy loss for the action, mean-squared error (Gaussian with fixed variance of 1) losses for the velocity (if relevant), reward, and return information, a Bernoulli cross-entropy loss for each pixel channel of the image, and a multinomial cross-entropy for each word token. Thus, we have a negative conditional log-likelihood loss contribution at each time step of



In the text loss, the sum indexes the words in the sequence (up to the maximum string length), with a softmax over the vocabulary. Constructing the target return value requires some subtlety. For environments with long episodes of length $T$, we use “truncation windows” [mnih2016asynchronous] in which the time axis is subdivided into segments of length $\tau$. If the window around time index $t$ ends at time index $t'$, the return within the window is
$$R_t = \sum_{k=t}^{t'} \gamma^{k-t} r_k + \gamma^{t'-t+1} V(z_{t'+1}).$$
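The truncated, bootstrapped return target can be sketched with the usual backward recursion (a minimal numpy sketch; the argument names and the reward-indexing convention are assumptions):

```python
import numpy as np

def truncated_returns(rewards, bootstrap_value, gamma):
    """Discounted return targets within a truncation window.

    `rewards` are the rewards inside the window; `bootstrap_value` is the
    value estimate at the first step past the window (0 if the episode
    terminated). Computed by the backward recursion R_t = r_t + gamma * R_{t+1}.
    """
    R = bootstrap_value
    out = np.empty(len(rewards))
    for t in reversed(range(len(rewards))):
        R = rewards[t] + gamma * R
        out[t] = R
    return out

r = truncated_returns([1.0, 0.0, 2.0], bootstrap_value=10.0, gamma=0.5)
# R_2 = 2 + 0.5*10 = 7; R_1 = 0 + 0.5*7 = 3.5; R_0 = 1 + 0.5*3.5 = 2.75
```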


4.2 Kullback-Leibler Divergence

The Kullback-Leibler Divergence term in Eq. 4 is computed as the analytical KL Divergence between the two diagonal Gaussian distributions specified by the posterior and the prior networks:
$$\mathrm{KL}\big(q \,\|\, p\big) = \sum_i \left[\log \frac{\sigma_{p,i}}{\sigma_{q,i}} + \frac{\sigma_{q,i}^2 + (\mu_{q,i} - \mu_{p,i})^2}{2 \sigma_{p,i}^2} - \frac{1}{2}\right].$$
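The analytic KL between two diagonal Gaussians can be sketched directly (a minimal numpy sketch parameterised by means and log-standard deviations, matching the diagonal Gaussian distributions used above):

```python
import numpy as np

def kl_diag_gaussians(mu_q, logstd_q, mu_p, logstd_p):
    """Analytic KL(q || p) between diagonal Gaussians, summed over dimensions."""
    var_q, var_p = np.exp(2 * logstd_q), np.exp(2 * logstd_p)
    return np.sum(logstd_p - logstd_q
                  + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p)
                  - 0.5)

zero = np.zeros(4)
# Identical distributions have zero divergence.
assert kl_diag_gaussians(zero, zero, zero, zero) == 0.0
# Shifting the mean by 1 in each of 4 unit-variance dims gives 4 * 1/2 = 2.
kl = kl_diag_gaussians(zero + 1.0, zero, zero, zero)
# kl == 2.0
```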

4.3 Practical Details

The contribution of each time step to the loss function in Eq. 4 is


As a measure to reduce the magnitude of the gradients, the total loss applied to the memory-based predictor is divided by the number of image pixel-channels.

4.4 Policy Gradient

The policy gradient computation is slightly different from the one described in the main text (Eq. 1). Instead, we use discount and bootstrapping parameters $\gamma$ and $\lambda$, respectively, as part of the policy advantage calculation given by the Generalised Advantage Estimation algorithm [schulman2015high]. Defining $\delta_t = r_t + \gamma V(z_{t+1}) - V(z_t)$, Generalised Advantage Estimation makes an update of the form:
$$\Delta\theta \propto \sum_t \nabla_\theta \log \pi(a_t \mid h_t) \sum_{k \geq 0} (\gamma \lambda)^k \delta_{t+k}.$$


There is an additional loss term that increases the entropy of the policy’s action distribution. This and pseudocode for all of MERLIN’s updates are provided in Alg. 1.
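The Generalised Advantage Estimation recursion can be sketched as follows (a minimal numpy sketch; the argument names, value estimates, and parameter settings are illustrative):

```python
import numpy as np

def gae_advantages(rewards, values, bootstrap, gamma, lam):
    """Generalised Advantage Estimation [schulman2015high].

    delta_t = r_t + gamma * V_{t+1} - V_t
    A_t     = sum_k (gamma * lam)^k * delta_{t+k},
    computed with the backward recursion A_t = delta_t + gamma*lam*A_{t+1}.
    """
    values = np.append(values, bootstrap)   # V at the step past the window
    T = len(rewards)
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv

adv = gae_advantages(rewards=[1.0, 1.0], values=[0.5, 0.5],
                     bootstrap=0.0, gamma=0.9, lam=0.95)
# delta_1 = 0.5, so A_1 = 0.5; delta_0 = 0.95, so A_0 = 0.95 + 0.855*0.5 = 1.3775
```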

5 Comparison Models

5.1 RL-LSTM

This model shares the same encoder networks as MERLIN, acting on its input to produce the same concatenated encoding vector. This is then passed as input to a deep LSTM with the same structure as MERLIN's deep policy LSTM. The deep policy LSTM has two output “heads”, which are linear outputs from the LSTM state, as in A3C [mnih2016asynchronous]: one for the value function baseline (return prediction) and one for the action distribution. Unlike the optimisation prescription of A3C, the policy head is trained using Eq. 8, and the value head is trained by return prediction.

5.2 RL-DNC

This is the same as the RL-LSTM except that the deep LSTM is replaced by a Differentiable Neural Computer [graves2016hybrid]. This model has a component with quadratic computational complexity as the memory size is scaled, so it is only applied to the Memory Game, where the episodes are short.

5.3 RL-MEM

In this model, the deep LSTM of the policy outputs an additional write vector with the same size as the state variable in the corresponding MERLIN experiment. This is directly stored in the memory, just as the state variable for MERLIN is. There is no retroactive memory updating. Reading from memory by the policy works the same way memory reading is implemented in MERLIN’s policy.

6 Implementation and Optimisation

For optimisation, we used truncated backpropagation through time [sutskever2013training]. We ran 192 parallel worker threads that each ran an episode on an environment and calculated gradients for learning. Each gradient was calculated after one truncation window, which for the DM Lab experiments was shorter than the duration of an episode. For reinforcement learning, after every truncation window, we also “bootstrapped” the return value, as described in Eq. 6. The gradient computed by each worker was sent to a “parameter server” that asynchronously ran an optimisation step with each incoming gradient. The memory-based predictor and policy were optimised using two separate ADAM optimisers [kingma2014adam] with independent learning rates.

The pseudocode for each MERLIN worker is presented in Alg. 1. For all experiments, we used the open source package Sonnet (https://github.com/deepmind/sonnet) along with its defaults for parameter initialisation.

// Assume global shared parameter vectors for the policy network and for the memory-based predictor; global shared counter
// Assume thread-specific parameter vectors
// Assume discount factor and bootstrapping parameter
Initialize thread step counter
repeat
     Synchronize thread-specific parameters with the global parameters
     Zero the model’s memory & recurrent state if a new episode begins
     repeat
          Run one step of the policy network and the memory-based predictor
          Sample an action from the policy distribution
          Update the memory by Methods Eq. 2
          Apply the action to the environment and receive a reward and observation
     until environment termination or the end of the truncation window
     If not terminated, run an additional step to compute a bootstrap value
     // (but don’t increment counters)
     Reset performance accumulators
     for each step from the end of the truncation window down to its start do
          Accumulate the return targets and loss terms (Eq. 7)
          Accumulate the entropy loss
     end for
     Asynchronously update the global policy and MBP parameters via gradient ascent
until training ends
Algorithm 1 MERLIN Worker Pseudocode

7 Tasks

For all learning curves, each model was independently trained 5 times with different random seed initialisation. The learning curves are reported with standard errors across training runs, visualised as the shaded areas on the curves.

7.1 Memory Game

At each episode, eight pairs of cards are chosen from the Omniglot image dataset [lake2015human]. They are randomly placed on a grid. The agent is given an initial blank observation, and at each turn the agent chooses a grid location for its action. If matching cards are selected on consecutive turns, a reward of 1 is given to the agent. In total, the agent is given 24 moves to clear the board, which is the maximum number of turns that would be needed by an optimal agent. To make the problem more challenging from the perspective of perception, we apply a random affine transformation to the image on the card each time it is viewed (a rotation uniformly between -0.2 and 0.2 radians, a translation on each axis of up to 2 pixels, and a magnification up to 1.15 times the size of the image). When the entire board is cleared, an additional bonus point is awarded at each timestep. If a previously cleared location is selected, the observation is shown as a blank.

7.2 Navigation Tasks

These constitute a set of four tasks on which single agents are simultaneously trained. Each worker thread is assigned one of the four tasks to operate on, determined by its thread index.

Goal-Finding: The Goal-Finding task tests if an agent can combine proximal and distal cues in order to find a goal in an environment with several rooms. Distal cues are provided by a fixed skyline of skyscrapers, and proximal cues are provided by wall and floor patterns and paintings on walls. At each episode, from 3 to 5 rooms are randomly placed on cells of a square arena. (The agent moves on a much finer, nearly continuous scale.) Once the rooms are constructed, random corridors with distinct floor and wall patterning are built to connect some of the rooms. A single goal object is placed at random in one of the rooms. When the agent is first teleported into a random room, it has 90 seconds to find the goal, which is worth 10 points. On arriving at the goal, the agent is teleported into a random room again and must find the goal once more. Returning to the goal faster on subsequent teleportations is evidence that the agent has memorised the goal location.

No Skyline: The second task involves eliminating the distal cues provided by the skyline, demanding that the agent navigate only by means of the proximal cues. Otherwise, the task is the same as the Goal-Finding task.

Large Environment: In this task, the environment arena is procedurally generated on a larger square grid with up to 7 rooms.

Dynamic Doors: Entrances to rooms are barricaded by doors that open and close randomly at every teleportation event. This task tests if agents are able to find robust navigation strategies involving replanning trajectories around obstacles.

7.3 Arbitrary Visuomotor Mapping

The task setting emulates a human or monkey psychophysics experiment. The agent views a screen and can tilt its view continuously in any of the four cardinal directions. On screen, stimuli from a human visual memory capacity experiment [brady2008visual] are shown to the agent in an experimental procedure known as an “arbitrary visuomotor mapping” experiment [wise2000arbitrary]. At the first presentation of a stimulus, one of four targets in the cardinal directions lights up green, indicating the correct direction to move the agent's gaze. If the agent moves its gaze to the correct target, it is given a positive reward; a wrong choice yields a negative reward. When the agent is subsequently shown the same image, the targets are black before answering, and the agent must remember the appropriate choice direction. When it does answer, the black targets briefly flash green and red to indicate the correct and incorrect targets respectively. Initially, the set of possible query images for a trial is small, but whenever the agent answers correctly a criterion number of times consecutively, a new image is added to the pool of images that must be remembered. A block of experimental trials runs continuously in a self-paced manner for 90 seconds.

7.4 Rapid Reward Valuation

The task takes place in an open square arena. Each trial lasts for 90 seconds and is composed of sub-episodes in which eight distinct objects (randomly assembled from a set of 16 base objects and a large number of textures) are randomly placed in the arena. These eight objects are randomly assigned positive or negative point values, and we mandate that there be at least two positively and two negatively valued objects. Throughout the episode, the object values are fixed, but once all the positively valued objects have been collected, a new sub-episode begins with the same eight objects re-appearing in permuted locations, allowing the agent to run up a high score. To solve this task, the agent must explore to identify good and bad objects, bind values to reasonably view-invariant representations of object identity, and use its memory to inform its decisions.

7.5 Episodic Water Mazes

This is a version of the Morris water maze experiment [morris1984developments] with the additional wrinkle that the agent must remember three mazes at a time, each distinguishable by its wall colouring. In each water maze, an invisible platform is randomly placed within the arena, and the circular perimeter of the arena has fixed paintings to provide landmark cues. When the agent reaches the platform, the platform elevates automatically, and five consecutive rewards of 1 point are delivered for maintaining position on it. After this, the agent is teleported at random into one of the three water mazes, continuing its activity for 360 seconds per episode.

7.6 Executing Transient Instructions

This task primarily demonstrates the MERLIN agent's versatility in a more abstract memory task that models the kind of instruction-following demonstrated by work dogs [pilley2011border]. It is a memory-dependent variant of an instruction-following task first presented by Hermann et al. [hermann2017grounded]. The map contains two rooms of different colours, separated by two corridors. Two coloured objects are placed in each room, and the agent is provided with an instruction in the format “[object colour, room colour]” for the first five time steps of the episode. The episode ends when the agent first collects an object, with a positive reward for the correct object and a negative reward for the incorrect object.

To interpret the instruction, the word tokens are sequentially processed by the LSTM encoder network. Its terminal hidden state is concatenated to the embedding vector . Likewise, we decode the instruction sequentially with a single layer LSTM from (though