Working Memory Graphs

by Ricky Loynd, et al.

Transformers have increasingly outperformed gated RNNs in obtaining new state-of-the-art results on supervised tasks involving text sequences. Inspired by this trend, we study the question of how Transformer-based models can improve the performance of sequential decision-making agents. We present the Working Memory Graph (WMG), an agent that employs multi-head self-attention to reason over a dynamic set of vectors representing observed and recurrent state. We evaluate WMG in two partially observable environments, one that requires complex reasoning over past observations, and another that features factored observations. We find that WMG significantly outperforms gated RNNs on these tasks, supporting the hypothesis that WMG's inductive bias in favor of learning and leveraging factored representations can dramatically boost sample efficiency in environments featuring such structure.



1 Introduction

Because of their ability to process sequences of data, gated Recurrent Neural Networks (RNNs) have been widely applied to natural language processing (NLP) tasks such as machine translation. In the RNN-based approach of Sutskever et al. (2014), an encoder RNN maps an input sentence to a series of internal hidden state vectors. The encoder’s final hidden state is copied into a decoder RNN, which then generates another sequence of hidden states that determine the selection of output tokens in the target language. This model can be trained to translate sentences, but translation quality deteriorates on long sentences where long-term dependencies become critical. Reasoning that this drop in performance is due to the limited representational capacity of an RNN’s hidden state vector, Bahdanau et al. (2014) boosted translation quality by applying an attention mechanism to create paths serving as shortcuts from the input to the output sequences, routing information outside the linear chain of the RNN’s hidden states. Similar attention mechanisms have since gained wide usage, culminating in the Transformer model (Vaswani et al., 2017), which replaces the RNN with many short paths of self-attention. Since then, Transformers have outperformed RNNs on many NLP tasks (Devlin et al., 2018).

We seek to leverage these intuitions to improve the ability of Reinforcement Learning (RL) agents to reason over long time horizons in Partially Observable Markov Decision Processes (POMDPs) (Kaelbling et al., 1998). In a POMDP, a single observation is not sufficient to identify the latent environment state. Thus the agent must reason over the history of past observations in order to select the best action for the current step. A simple strategy employed by DQN (Mnih et al., 2015) is to condition the policy on a fixed number of the most recent observations. But in complex environments, the sufficient number may be large, highly variable, and unknown. To address this issue, gated RNNs such as LSTMs (Hochreiter and Schmidhuber, 1997) and GRUs (Chung et al., 2015) use internal, recurrent state vectors which in theory can maintain information from past observations (Hausknecht and Stone, 2015; Oh et al., 2016). However, in practice, these methods are limited by the single path of information flow defined by the linear chain of RNN hidden states. As in NLP, we hypothesize that providing alternative paths for information will be advantageous to RL agents. Building on this intuition, we introduce the Working Memory Graph (WMG), a Transformer-based agent that uses self-attention to provide a multitude of shortcut paths for information to flow between the past observations and the current action.

In addition to providing many paths for information, Transformers are also well suited for handling variably-sized inputs such as words in a sentence. Although most Reinforcement Learning environments provide fixed-sized feature spaces, certain environments have observation spaces amenable to factorization. As a motivating example, consider the BabyAI environment (Chevalier-Boisvert et al., 2018) depicted in Figure 1 (left). The native observation space is the agent’s field of view, a 7x7 grid, shown in lighter grey. This observation can be efficiently represented by a set of factors describing the types, colors, and relative X and Y coordinates of all visible objects:

([green, key, 1, 3], [grey, box, 2, 1], [green, ball, 2, 2], [red, key, 3, 0])

This factored observation is more compact than the native observation, but will vary in size depending on the number of objects in view. Motivated by prior work on factored representations (Russell and Norvig, 2009) and factored MDPs (Boutilier et al., 2000, 2001), we explore the idea of encoding factored observations as input to Transformer-based agents. In particular, we compare how factored observations affect the learning speed of Transformer and RNN-based agents.
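As an illustration of the factored observation shown above, the following minimal sketch encodes each visible object as one fixed-size vector, so that only the number of vectors varies. It is not the authors' code: the vocabularies `COLORS` and `TYPES`, the coordinate range `COORD_RANGE`, and the helper names are assumptions made for this example.

```python
# Hypothetical vocabularies and coordinate range, chosen for illustration.
COLORS = ["red", "green", "blue", "purple", "yellow", "grey"]
TYPES = ["key", "ball", "box"]
COORD_RANGE = list(range(-3, 7))  # assumed range of relative X and Y


def one_hot(value, vocabulary):
    """Return a one-hot list marking the position of value in vocabulary."""
    return [1.0 if item == value else 0.0 for item in vocabulary]


def encode_factor(color, obj_type, x, y):
    """Concatenate one-hot encodings of one object's four attributes."""
    return (one_hot(color, COLORS) + one_hot(obj_type, TYPES)
            + one_hot(x, COORD_RANGE) + one_hot(y, COORD_RANGE))


observation = [("green", "key", 1, 3), ("grey", "box", 2, 1),
               ("green", "ball", 2, 2), ("red", "key", 3, 0)]
factors = [encode_factor(*obj) for obj in observation]

# The number of vectors depends on the objects in view;
# the length of each vector does not.
print(len(factors), len(factors[0]))
```

A Transformer-style model can consume `factors` directly as a set of input vectors, whereas a model with a fixed-size input needs them concatenated and padded.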

Instruction: Pick up the green key

Figure 1: Left: BabyAI rewards the agent (red triangle) for performing the task given by the instruction. Right: Our Transformer-based model (WMG) leverages factored observations and shortcut recurrence to solve this partially observable BabyAI task (PickupLoc) in far fewer environment interactions than a GRU-based agent using the same factored observations.

Our contributions are twofold: First we introduce the Working Memory Graph (WMG), a Transformer-based agent implementing a novel form of shortcut recurrence which we demonstrate to be effective at complex reasoning over long-term dependencies. Second, we identify the synergy between Transformer-based RL architectures and factored observations, demonstrating that by virtue of its Transformer-style self-attention, WMG is able to effectively leverage factored observations to learn high-performing policies given an order of magnitude fewer environment interactions than alternative architectures. To preview our findings, Figure 1 (right) shows an example of the dramatic boosts in sample efficiency obtained through the combination of shortcut recurrence and Transformer-based processing of factored observations.

2 Working Memory Graph

Broadly, WMG incorporates an inductive bias in favor of learning and leveraging factored representations, including both observed and unobserved (latent) factors. Observed factors are represented by multiple input vectors called percepts. Latent factors are represented by multiple recurrent vectors called concepts. Instead of handling long-range dependencies over time by applying self-attention to a long history of observations, for which the quadratic computational cost could be prohibitively expensive, WMG relies on its much more limited set of concepts to represent long-range dependencies. The term Working Memory Graph is motivated by the relatively limited size of WMG’s self-attention computation graph, in loose analogy with the cognitive science term working memory, which refers to a cognitive system that holds a limited amount of information for use in mental processing (Miller, 1956).

WMG introduces shortcut recurrence, which replaces a gated RNN’s single path of information flow with a network of shorter self-attention paths. As illustrated in Figure 2 (right), WMG’s shortcut recurrence applies multi-head self-attention to a dynamic set of hidden state vectors, the aforementioned concepts, to simultaneously represent multiple latent factors or aspects of partially observable environments. Formally, each concept vector defines one row in a concept matrix $C \in \mathbb{R}^{N_c \times d_c}$, where $N_c$ is the number of concepts maintained by WMG and $d_c$ is the dimension of each concept vector. On each time step, the oldest concept is replaced by a new one.
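The replace-the-oldest rule can be sketched with a fixed-length queue. This is a minimal illustration rather than the authors' implementation: the sizes, the all-zero initial concepts, and the function names are assumptions.

```python
from collections import deque

NUM_CONCEPTS = 4   # assumed small value for illustration
CONCEPT_SIZE = 3   # assumed concept vector dimension


def make_concept_memory():
    """Start with all-zero concepts (an assumption); index 0 is the oldest."""
    return deque([[0.0] * CONCEPT_SIZE for _ in range(NUM_CONCEPTS)],
                 maxlen=NUM_CONCEPTS)


def store_concept(memory, new_concept):
    """Append the newest concept; the full deque drops its oldest entry."""
    memory.append(list(new_concept))


memory = make_concept_memory()
for step in range(1, 7):
    # Stand-in for the concept the model would generate on this step.
    store_concept(memory, [float(step)] * CONCEPT_SIZE)

# After 6 steps with 4 slots, only the concepts from steps 3 to 6 remain.
print([concept[0] for concept in memory])
```

Because a stored concept is never modified between its creation and its removal, each one persists unchanged for `NUM_CONCEPTS` time steps.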

Figure 2: Schematic diagrams of two RL agent architectures taking input from a factored observation. Left: A gated RNN constrains information to flow through a linear path of fixed-length vectors. (G denotes the internal vector operations of a GRU or LSTM.) Right: WMG allows information to flow through many self-attention paths among a network of vectors. (W denotes WMG’s embedding and multi-layer Transformer operations.) WMG replaces the RNN’s recurrent hidden state with a recurrent network of concept vectors, and replaces the RNN’s single observation input vector with a network of vectors including percepts mapped to observation factors.

WMG applies self-attention to observations by introducing multiple observation input vectors, percepts, as depicted in Figure 2 (right). In our experiments, a single percept encodes either an entire observation from a window of recent observations, or one factor (such as a green key in BabyAI) of a factored observation. On each time step, WMG receives a formatted observation consisting of a variable number ($N_p$) of percept vectors forming a percept matrix $P \in \mathbb{R}^{N_p \times d_p}$, and a core vector $s \in \mathbb{R}^{d_s}$ that contains any other observation information (such as any non-factored portions of the current observation). The core, percept and concept vectors are stacked into one matrix for input to WMG’s Transformer operation:

$$X = \begin{bmatrix} s\,W_s + b_s \\ P\,W_p + b_p \\ [C, A]\,W_c + b_c \end{bmatrix}$$

where $W_s$, $W_p$ and $W_c$ are embedding matrices with corresponding bias vectors $b_s$, $b_p$ and $b_c$ broadcast over rows, and each concept (a row of the concept matrix $C$) is concatenated with a one-hot age vector (the corresponding row of $A$). Closely following the encoder architecture of Vaswani et al. (2017), WMG’s Transformer operation takes the input matrix $X \in \mathbb{R}^{N \times d}$ and returns an output matrix $Z \in \mathbb{R}^{N \times d}$, where $N$ is the number of input (or output) nodes (one core node plus the $N_p$ percepts and the concepts), and $d$ is the size of each node vector. The oldest concept is replaced by a new concept vector generated as a non-linear function of the core node’s output vector.

Figure 3: Shortcut Recurrence. (a) Gated RNN unrolled in time. (b) WMG unrolled in time. In a gated RNN, for the first observation to affect the agent’s output on a later time step, information must pass through 8 gating operations and 7 intervening hidden states a-g. In contrast, in a WMG, many possible paths lead from the first observation to that output. The highlighted path requires only three passes through W, and the information is stored unchanged for several time steps in concept vectors a and d. This example illustrates how WMG’s concept nodes provide shorter paths for information to flow forward and gradients to flow backward.

The trainable parameters $\theta$ of WMG and its Transformer layers are trained end-to-end through backpropagation of a policy-gradient loss maximizing the cumulative expected return, using the gradient estimate

$$\nabla_\theta \log \pi(a_t \mid h_t; \theta)\, A(h_t, a_t) + \beta\, \nabla_\theta H\big(\pi(h_t; \theta)\big)$$

where $\pi$ denotes WMG’s policy head operating on hidden state $h_t$, $H$ is the entropy of the policy’s action distribution, and $\beta$ controls the strength of the entropy regularization term. To reduce the variance of gradient estimates, we use the advantage actor-critic algorithm described by Mnih et al. (2016), which estimates the advantage using a $\gamma$-discounted $k$-step return as follows:

$$A(h_t, a_t) = \sum_{i=0}^{k-1} \gamma^i r_{t+i} + \gamma^k V(h_{t+k}) - V(h_t)$$

where $V$ denotes WMG’s state-value head, which is trained to minimize the squared difference between the $k$-step return and the current value estimate: $\big\lVert \sum_{i=0}^{k-1} \gamma^i r_{t+i} + \gamma^k V(h_{t+k}) - V(h_t) \big\rVert^2$, and $k$ is upper-bounded by the number of time steps ($t_{max}$) in the actor’s current update window.
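The k-step advantage estimate described above, in the standard form given by Mnih et al. (2016), can be computed for a whole update window in one backward pass. The sketch below is illustrative only: the function name, argument names, and the sample numbers are assumptions.

```python
def n_step_advantages(rewards, values, bootstrap_value, gamma):
    """Advantage estimate for each step of one update window.

    rewards[t] and values[t] cover the window. bootstrap_value is the
    value estimate for the state that follows the window (0.0 if the
    episode ended). Each step uses the longest return available, so k
    shrinks from len(rewards) at the first step to 1 at the last.
    """
    advantages = []
    ret = bootstrap_value
    for reward, value in zip(reversed(rewards), reversed(values)):
        ret = reward + gamma * ret          # discounted k-step return
        advantages.append(ret - value)      # return minus value estimate
    advantages.reverse()
    return advantages


adv = n_step_advantages(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.6, 0.7],
                        bootstrap_value=0.0, gamma=0.5)
print(adv)
```

The same returns serve as regression targets for the value head, since the squared advantage is the squared difference between the k-step return and the current value estimate.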

To summarize WMG’s operation, Figure 3 compares the flow of information through a gated RNN and through WMG, illustrating how WMG’s concept vectors latch information unchanged for multiple time steps to create shorter paths for information flow in both the forward and backward passes.
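To make the input side of this operation concrete, the sketch below embeds a core vector, a variable number of percepts, and age-tagged concepts to a common node size and stacks them into one matrix. It is a simplified illustration in plain Python, not the authors' code: all sizes, the random initialization, and the function names are assumptions, and the Transformer layers themselves are omitted.

```python
import random


def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]


def linear(rows, weights, bias):
    """Compute rows @ weights + bias, broadcasting bias over rows."""
    columns = list(zip(*weights))
    return [[sum(x * w for x, w in zip(row, col)) + b
             for col, b in zip(columns, bias)]
            for row in rows]


def init_layer(rng, n_in, n_out):
    """Random embedding matrix (n_in x n_out) and a zero bias vector."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(n_out)]
               for _ in range(n_in)]
    return weights, [0.0] * n_out


def build_input(core, percepts, concepts, layers):
    """Stack embedded core, percept and concept rows into one matrix.

    concepts is ordered oldest first, so concepts[-1] has age 0.
    """
    n = len(concepts)
    aged = [c + one_hot(n - 1 - j, n) for j, c in enumerate(concepts)]
    return (linear([core], *layers["core"])
            + linear(percepts, *layers["percept"])
            + linear(aged, *layers["concept"]))


rng = random.Random(0)
CORE, PERCEPT, CONCEPT, N_CONCEPTS, NODE = 5, 4, 3, 2, 6  # assumed sizes
layers = {
    "core": init_layer(rng, CORE, NODE),
    "percept": init_layer(rng, PERCEPT, NODE),
    "concept": init_layer(rng, CONCEPT + N_CONCEPTS, NODE),
}
core = [1.0] * CORE
percepts = [[0.5] * PERCEPT for _ in range(3)]   # three objects in view
concepts = [[0.0] * CONCEPT for _ in range(N_CONCEPTS)]
x = build_input(core, percepts, concepts, layers)
print(len(x), len(x[0]))   # rows = 1 + 3 + 2, columns = NODE
```

Because every row is embedded to the same width, the number of percepts can change from step to step without changing any parameter shapes.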

3 Related Approaches

Having explained how WMG operates, we distinguish it from related work: Prior approaches have used attention for memory access (Graves et al., 2016; Oh et al., 2016) or self-attention to process individual observations (Zambaldi et al., 2019; Vinyals et al., 2019). These approaches all used LSTM-based recurrence over time. In contrast, WMG obviates the need for gated recurrence by applying self-attention to a network of concept vectors which are persisted through time.

Other Transformer-based models handle partial observability using state vectors analogous to WMG’s concepts, but with different state-update schedules: RMC (Santoro et al., 2018) updates all state vectors on every time step, while RIMs (Goyal et al., 2019) enforces sparsity by updating exactly half of the state vectors (called RIMs) on each step. WMG replaces only one concept on each time step in order to maximize the persistence of latched concept vectors and thereby extend the reach of the shortcut paths that they create from inputs to outputs. And unlike WMG, RMC and RIMs use gated RNNs to update their state vectors.

Unlike the other models discussed here, the Gated Transformer-XL (Anonymous, 2019) addresses partial observability by feeding hundreds of past observations at once into the Transformer. By contrast, in order to mitigate the computational cost of self-attention, WMG computes self-attention over a comparatively small number of concepts which capture and maintain the relevant aspects of past observations.

4 Experiments

In our experiments, we aim to (1) evaluate WMG’s ability to reason over long time spans in a setting of high partial observability, and (2) understand how factored representations may be effectively utilized by WMG. To address these questions we present results on two environments: a novel Pathfinding task which requires complex reasoning over past observations, and the BabyAI domain  (Chevalier-Boisvert et al., 2018) which involves changing goals, partial observability, and observations that can be readily factored. To foreshadow our results, the Pathfinding task demonstrates the effectiveness of WMG’s shortcut recurrence, and BabyAI demonstrates that WMG leverages factored observations to deliver very large gains in sample efficiency.

4.1 Pathfinding Task

The Pathfinding task is designed to evaluate WMG’s ability to perform complex reasoning over past observations. Figure 4 depicts the incremental construction of a directed graph over nodes identified by unique pattern vectors which are randomly generated on every episode. (See Appendix A for the graph-construction algorithm and other details.) On odd time steps the agent observes two pattern nodes to be linked, and on even steps the agent must indicate whether or not a directed path exists from one given pattern to another. As this cycle repeats, the graph grows larger and the agent must perform an increasing number of reasoning steps to confirm or deny the existence of a path between arbitrary nodes. Because the observation only contains incremental information and the entirety of the graph is never directly observed, the agent must leverage information from previous observations to infer connectivity between nodes.

Figure 4: Left: One episode of the Pathfinding task, which in this work consists of 12 time steps and a bounded number of patterns in the graph. The boxes with rounded corners illustrate the observations for the given time steps, where a question mark identifies the step as a quiz step rather than a construction step. The box colors represent distinct pattern vectors. Right: Results on Pathfinding. Each plotted point is the percentage of reward on quiz steps received by the agent over the previous 10k time steps, averaged over 100 independent training runs. Bands display one standard deviation. (See Table 9 for more details.)

For example, consider step 4 of Figure 4: To determine whether a path exists from green to yellow, the agent must recall and combine information from steps 1 and 3. Similarly, on step 12, if the agent were asked about the existence of a path from cyan to yellow, answering correctly without guessing would require piecing together information from three non-contiguous time steps. Since the actual quiz on step 12 asks whether a path exists from green to blue, the agent must reason over many past observations to determine that no such path exists.

Each pattern is a vector of $D$ real numbers drawn randomly from the interval -1 to 1. A binary value is added to the observation vector to indicate whether the current step is a quiz step, bringing the size of the observation vector to $2D + 1$ (two patterns plus the quiz flag). The action space consists of two actions, defined as yes or no. If the agent answers correctly on a quiz step, it receives a reward of 1; otherwise, it receives a reward of 0. The quiz questions are constructed to guarantee that each answer (yes or no) is correct half the time, so agents that act randomly or have no memory will obtain 50% of possible reward in expectation.
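The observation and reward structure just described can be sketched as follows. This is an illustrative reading, not the environment's actual code: it assumes that every observation is the concatenation of two patterns and a quiz flag, and `PATTERN_SIZE` is a hypothetical value of D chosen only for the example.

```python
import random

PATTERN_SIZE = 4  # hypothetical value of D, chosen only for this example


def random_pattern(rng):
    """Draw one pattern: PATTERN_SIZE reals uniformly from [-1, 1]."""
    return [rng.uniform(-1.0, 1.0) for _ in range(PATTERN_SIZE)]


def make_observation(pattern_a, pattern_b, is_quiz):
    """Two patterns plus a binary quiz flag: 2 * D + 1 values."""
    return pattern_a + pattern_b + [1.0 if is_quiz else 0.0]


def reward(is_quiz, answer, path_exists):
    """Return 1 for a correct answer on a quiz step, otherwise 0."""
    return 1.0 if is_quiz and answer == path_exists else 0.0


rng = random.Random(0)
a, b = random_pattern(rng), random_pattern(rng)
construction_obs = make_observation(a, b, is_quiz=False)  # "link a to b"
quiz_obs = make_observation(a, b, is_quiz=True)           # "path a to b?"
print(len(quiz_obs), reward(True, answer=True, path_exists=True))
```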

WMG is configured with concept nodes to handle the partial observability but no percept nodes, since we are not using this task to explore factored observations. The number of concept nodes is a tuned hyperparameter, equal to 16 in this experiment. (See Table B for all settings.) Each observation is passed directly to WMG’s core node, and WMG generates a new concept on each time step. We compare WMG’s performance to several baselines. Each Depth-n baseline is a hand-coded algorithm demonstrating the performance obtained using perfect memory of past observations and reasoning over paths up to n steps long. For example, Depth-2 remembers all previous construction steps, and reasons over all paths of depth 2. Finally, in order to understand the effectiveness of concept nodes at capturing past information, we evaluate a full-history, non-recurrent version of WMG (nr-WMG) by removing the concept nodes and giving it all past observations on each time step, each one passed to a separate percept node.
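A Depth-n baseline of the kind described above can be sketched as a depth-limited search over the remembered construction steps. This is a hedged illustration: the function name, the edge list, and the choice to answer "no" when no path is found within n steps are assumptions, and the example edges are not taken from Figure 4.

```python
def path_exists(edges, source, target, max_depth):
    """Return True if a directed path of at most max_depth edges exists."""
    frontier = {source}
    for _ in range(max_depth):
        # Advance one edge from every node reached so far.
        frontier = {dst for src, dst in edges if src in frontier}
        if target in frontier:
            return True
        if not frontier:
            return False
    return False


# Hypothetical construction steps, remembered perfectly by the baseline.
edges = [("green", "red"), ("red", "yellow"), ("yellow", "blue")]

print(path_exists(edges, "green", "yellow", max_depth=1))
print(path_exists(edges, "green", "yellow", max_depth=2))
print(path_exists(edges, "green", "blue", max_depth=3))
```

With these edges, a Depth-1 search misses the two-edge path from green to yellow that a Depth-2 search finds, which is the kind of gap the baselines are meant to expose.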

As shown in Figure 4, the GRU-based agent exceeds Depth-1 performance, but remains well short of Depth-2 performance after 20 million steps of training (environment interactions). In contrast, both versions of the WMG agent nearly reach Depth-3 performance, demonstrating a greater ability to perform complex reasoning over past observations. The best performance is achieved by the nr-WMG with full-history, which has no need for recurrence. But the full WMG (with concepts) is nearly as sample efficient as this perfect-memory baseline. These results indicate that shortcut recurrence enables WMG to learn to store and utilize essential information from past Pathfinding observations in a more effective manner than a GRU.

4.2 BabyAI Environment

In order to understand how factored representations may be effectively utilized by WMG, we study BabyAI, a domain whose observation space is amenable to factorization. BabyAI (Chevalier-Boisvert et al., 2018) is a partially observable, 2D grid-world containing objects that can be viewed and moved by the agent. Unlike most RL environments, BabyAI features text instructions that specify the goal the agent needs to achieve, such as “pick up the green box.”

We focus on five BabyAI levels, for which the environment consists of a single 6x6 room, as shown in Figure 5 (left). Despite the apparent simplicity of a single-room domain, learning to solve it can often take model-free RL agents hundreds of thousands of environment interaction steps. The agent’s action space consists of 7 discrete actions: Move forward, Turn left, Turn Right, Pick up, Drop, Toggle, and Done. An episode ends after 64 time steps, or when the agent achieves the goal, for which it receives a reward of 1. In Level 1 (GoToObj), the room contains only one object. The agent completes the mission by moving to an adjacent square and pointing toward the object. In Level 2, the target object is always a red ball, and seven grey boxes are present as distractors. In Level 3, the distractors may be any of the 3 object types and 6 colors. If one of the distractors happens to be a red ball, the agent is rewarded for reaching it. In Level 4, the instruction specifies the color and type of the target object. This is the first level in which the text instruction contains valuable information. (See Table 12 for instruction templates.) Level 5 increases the difficulty of Level 4 in two ways. First, the agent must not only reach the target object, but must also pick it up. Second, if multiple qualifying target objects are present, the agent is given the initial relative location of the true target, such as “behind you”.


Instruction: Go to the yellow box

| Part of observation | Variable assignments | Node |
|---|---|---|
| Factored image | color=green, type=key, X=3, Y=1 | percept |
| Factored image | color=grey, type=box, X=1, Y=2 | percept |
| Factored image | color=green, type=ball, X=2, Y=2 | percept |
| Factored image | color=red, type=key, X=0, Y=3 | percept |
| Factored image | vertical wall X=-2 | core |
| Factored image | horizontal wall Y=4 | core |
| Factored instruction | command=go to, article=the, color=yellow, type=box, loc=None | core |
| Additional info | orientation=west, last action=move forward | core |

Figure 5: One completely factored observation, where each variable assignment corresponds to a one-hot vector in the full observation vector. Since the number of objects in an observation can vary, each object’s vectors are concatenated then passed to a percept node. All other one-hot vectors from the observation are concatenated then passed to the core node. X and Y coordinates refer to a frame of reference with the agent at the origin, pointed in the positive Y direction. The agent always observes one vertical wall and one horizontal wall.

Each agent observation in BabyAI consists of a text instruction, an image, and the agent’s orientation. The image’s native format is a 7x7 array of cell descriptors (not pixels) identifying three attributes of each cell: type, color, and open/closed/locked (referring to doors, which are not found in these 5 levels). To study factored observations in BabyAI, we define a factored representation, depicted in Figure 5. In our experiments the text instruction is always factored, but the image is formatted in multiple ways: (1) 7x7x3, the native BabyAI image array; (2) flat, the native 7x7x3 array flattened to one vector; (3) factored image, as described in Figure 5. (Note that when a factored image is passed to a GRU, it must first be flattened and padded to form a fixed-length vector.)
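The flatten-and-pad step needed for a GRU can be sketched as below. The function name, the zero padding value, and the sizes are assumptions for illustration; the source does not specify how padding was performed.

```python
def flatten_and_pad(factors, max_factors, factor_size):
    """Concatenate factor vectors and zero-pad to a fixed length."""
    if len(factors) > max_factors:
        raise ValueError("more factors than the padded vector can hold")
    flat = [value for factor in factors for value in factor]
    padding = [0.0] * ((max_factors - len(factors)) * factor_size)
    return flat + padding


# Two visible objects, each described by a 3-value vector (assumed sizes).
factors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
fixed = flatten_and_pad(factors, max_factors=4, factor_size=3)
print(len(fixed))
```

Note that concatenation ties each object to a position in the fixed-length vector, which is the entanglement that separate percept nodes avoid.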

To determine whether WMG can leverage factored observations more effectively than gated RNNs in BabyAI, we evaluate the following agents: (1) WMG is the full, recurrent WMG model, with percepts mapped to observation factors, (2) nr-WMG is an ablated, non-recurrent version of WMG with no concepts, (3) GRU is a GRU model, and (4) CNN+GRU uses a CNN to process the native 7x7x3 image, followed by a GRU. This CNN is one of the two CNN models provided in the BabyAI open source code (Chevalier-Boisvert et al., 2018).

4.2.1 Results

| Level | WMG (factored) | nr-WMG (factored) | GRU (factored) | WMG (flat) | GRU (flat) | CNN+GRU (native 7x7x3) |
|---|---|---|---|---|---|---|
| 1 - GoToObj | 1.6 | 1.4 | 1.7 | 15.0 | 19.0 | 10.6 |
| 2 - GoToRedBallGrey | 6.7 | 5.2 | 24.6 | 29.0 | 31.0 | 22.3 |
| 3 - GoToRedBall | 16.0 | 23.6 | 174.4 | 92.0 | 124.6 | 204.9 |
| 4 - GoToLocal | 59.7 | 71.3 | 2,241.6 | 1,379.9 | 1,799.4 | -- |
| 5 - PickupLoc | 222.3 | 253.0 | -- | -- | -- | -- |

Table 1: BabyAI sample efficiency: the amount of training (shown here in thousands of environment interactions) required for a model to solve 99% of 10,000 episodes. Each column is labeled with the model and, in parentheses, the image format. Hyperparameters were first tuned on each model/format/level combination separately, then each reported result was computed as the median sample efficiency over 100 additional training runs. Dashes indicate that no model reliably reached a solution rate of 99% within 6 million training steps (environment interactions). Note that Chevalier-Boisvert et al. (2018) report sample efficiencies in terms of episodes rather than environment interactions. (See Table 12.)

Factored Observations: The largest performance differences in Table 1 stem from the choice of factored versus flat or native image formats. Notably, WMG with factored images can achieve sample efficiencies 10x greater (on Level 3) than CNN+GRU using the native 7x7 image format. However, factored observations alone are not sufficient for sample efficiency: WMG utilizes factored images much more effectively than a GRU on Levels 2-5. This result supports our hypothesis that Transformer-based models are particularly well suited for operating on set-based inputs like factored observations, and large gains in sample efficiency are observed as a result.
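The size of these gaps can be checked with a few lines of arithmetic on the Table 1 values. The sketch below assumes the column-to-model mapping implied by the discussion in this section (WMG, then GRU, with factored images; CNN+GRU with the native image); the dictionary layout and function name are illustrative only.

```python
# Values copied from Table 1 (thousands of environment interactions).
# The assignment of columns to models is inferred from the text.
table = {
    "WMG-factored": {2: 6.7, 3: 16.0, 4: 59.7},
    "GRU-factored": {2: 24.6, 3: 174.4, 4: 2241.6},
    "CNN+GRU-native": {2: 22.3, 3: 204.9},
}


def speedup(baseline, model, level):
    """How many times fewer interactions model needs than baseline."""
    return table[baseline][level] / table[model][level]


print(round(speedup("CNN+GRU-native", "WMG-factored", 3), 1))
for level in (2, 3, 4):
    print(level, round(speedup("GRU-factored", "WMG-factored", level), 1))
```

Under that mapping, the Level 3 ratio against CNN+GRU is a little under 13, consistent with the roughly 10x figure quoted above, and the gap between WMG and GRU on factored images widens as the levels get harder.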

Concept Nodes: Without factored observations, WMG-flat slightly outperforms GRU-flat, suggesting that shortcut recurrence implemented by the WMG’s concept nodes compares favorably to the GRU’s gated recurrence. With the benefit of factored observations, the non-recurrent ablation of WMG (nr-WMG) performs slightly better than the full WMG on the simplest two levels. But for the more challenging levels 3-5, WMG’s concept vectors prove to be of benefit for WMG with factored observations.

Early vs Late instruction fusion: Interestingly, within our training limit of 6 million environment interactions, CNN+GRU is unable to learn to solve the levels (4 & 5) where instructions carry important information. We suspect this is because the CNN processes just the image while the factored instruction is passed directly to the GRU, skipping the CNN. By contrast, the baseline BabyAI agent uses FiLM layers to integrate the processing of the image with the text instruction. Both WMG and GRU models can process the image and instruction together in all levels of processing. This early fusion appears to allow all WMG and GRU models to solve Level 4.

In summary, the two WMG models with factored images were the only agents able to solve Level 5, and they learned to do so in approximately the same number of interactions that CNN+GRU required to solve Level 3. These drastic differences in sample efficiency serve to highlight the potential gains that can be achieved by RL agents equipped to utilize factored observations.

While WMG’s sample efficiencies dramatically exceed the RL benchmarks published with the BabyAI domain (Chevalier-Boisvert et al., 2018), often by two orders of magnitude (Table 12), it’s important to note that these sets of results are not directly comparable. Our experiments all used factored text instructions, and each model’s hyperparameters were tuned for each level separately, while the BabyAI benchmark agent was trained on all levels using the single hyperparameter configuration provided in the BabyAI release. Because of these differences, our experiments should not be interpreted as a new state-of-the-art on the standard BabyAI tasks.

4.2.2 Hyperparameter Sensitivity

| Level | 1 | 2 | 3 |
|---|---|---|---|
| WMG-factored | 5.0 | 13.5 | 34.7 |
| nr-WMG-factored | 3.2 | 9.9 | 39.3 |
| GRU-factored | 8.0 | 42.6 | 313.9 |
| WMG-flat | 40.6 | 74.9 | 231.4 |
| GRU-flat | 36.9 | 55.3 | 188.9 |

Figure 6: Hyperparameter Sensitivity: Sample efficiency (in thousands of environment interactions) of various model-format combinations using hyperparameters optimized for Level 4 then subsequently applied to Levels 1, 2, and 3 (left), as well as 5 (right). All model performances degrade, but WMG with factors still outperforms GRU-based models. Although none of the models reach the 99% threshold for Level 5, WMG reaches a high level of performance before the others. (See Table 1 for more details.)

To evaluate WMG’s sensitivity to hyperparameter selection, we applied the tuned hyperparameter settings from Level 4 to new training runs on all other levels. Figure 6 shows moderate degradations in performance for all models. In particular, when the hyperparameter values tuned on Level 4 are used in Level 5 training runs, none of the models reach a 99% solution rate within 1 million training steps, but WMG with factored observations reaches higher levels of performance than the other models. Broadly, these results indicate that WMG is no more sensitive to hyperparameter settings than the baseline agents.

5 Conclusion and future work

We designed the Working Memory Graph to investigate how Transformer-based models can improve the performance of RL agents. In order to effectively leverage factored observations, WMG applies Transformer-style self-attention to arbitrary numbers of percept vectors mapped directly to observed factors. And in order to represent multiple latent aspects of partially observable environments, without incurring large quadratic computational costs of self-attention over long histories, WMG incorporates a form of recurrence that creates shortcut paths of self-attention over a dynamic set of hidden states, called concepts.

We compared WMG’s performance to that of gated RNNs in two partially observable environments, one focused on complex reasoning over long-term dependencies, and one focused on reasoning over factored observations. In these experiments, WMG outperforms gated RNNs by wide margins. In particular, our results demonstrate that when factored observations are available, sample efficiency can be dramatically boosted by passing the factors separately to WMG percepts, instead of entangling the factors through concatenation into fixed-length vectors for processing by a gated RNN.

To clarify certain limitations of this version of WMG, we outline three potential enhancements:

Flexible concept lifetimes: In the work reported here, each new concept automatically replaces the oldest. A more flexible and adaptive concept-deletion scheme may improve WMG’s ability to model latent aspects in the environment. For instance, concept vectors that receive more attention than others may be the ones most worth keeping around for longer. Deleting a concept only when its recently-received attention falls below a certain threshold would allow the number of concept vectors to fluctuate somewhat over time, depending on the needs of the situation.

Graph edge content: As in the original Transformer, WMG applies input vectors to the nodes in its computation graph, but not to the edges between them. To better represent graph-structured data, Veličković et al. (2017) contemplated incorporating edge-specific data into Graph Attention Networks as future work. By harnessing the richer representational abilities of graph structures over set structures, a similar extension of WMG may allow it to better model complex relations among observed and latent factors in the environment.

Memory vectors: Various forms of external memory have been proposed in recent years (Graves et al., 2016; Munkhdalai et al., 2019). Memory vectors retrieved from such stores could be fed to dedicated WMG memory nodes, in addition to the current concept and percept nodes, to further extend the range and flexibility of an agent’s effective time horizon.
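As one possible reading of the flexible concept lifetimes enhancement outlined above, the sketch below deletes a concept only when its recently received attention falls below a threshold. It is entirely hypothetical and was not part of the reported experiments: the function name, the attention scores, the threshold, and the `min_keep` safeguard are all assumptions.

```python
def prune_concepts(concepts, recent_attention, threshold, min_keep=1):
    """Keep concepts whose recent attention meets the threshold.

    recent_attention[i] is a hypothetical running measure of how much
    attention concept i has received lately. At least min_keep of the
    most attended concepts are always retained.
    """
    ranked = sorted(range(len(concepts)),
                    key=lambda i: recent_attention[i], reverse=True)
    keep = {i for i in ranked if recent_attention[i] >= threshold}
    keep.update(ranked[:min_keep])
    # Preserve the original (oldest-first) ordering of the survivors.
    return [c for i, c in enumerate(concepts) if i in keep]


concepts = ["a", "b", "c", "d"]          # stand-ins for concept vectors
attention = [0.40, 0.05, 0.30, 0.02]     # hypothetical recent attention
print(prune_concepts(concepts, attention, threshold=0.1))
```

Under such a scheme the number of concepts would vary from step to step, which the variable-size self-attention input already accommodates.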


The authors wish to thank Alekh Agarwal and Xiaodong Liu for many valuable discussions.


  • Anonymous (2019) Stabilizing transformers for reinforcement learning. Note: Under review, International Conference on Learning Representations, 2020 External Links: Link Cited by: §3.
  • D. Bahdanau, K. Cho, and Y. Bengio (2014) Neural machine translation by jointly learning to align and translate. Note: Accepted at ICLR 2015 as oral presentation External Links: Link Cited by: §1.
  • C. Boutilier, R. Reiter, and B. Price (2001) Symbolic dynamic programming for first-order mdps. In

    Proceedings of the 17th International Joint Conference on Artificial Intelligence - Volume 1

    IJCAI’01, San Francisco, CA, USA, pp. 690–697. External Links: ISBN 1-55860-812-5, 978-1-558-60812-2, Link Cited by: §1.
  • C. Boutilier, R. Reiter, M. Soutchanski, and S. Thrun (2000) Decision-theoretic, high-level agent programming in the situation calculus. In Proceedings of the Seventeenth National Conference on Artificial Intelligence and Twelfth Conference on Innovative Applications of Artificial Intelligence, pp. 355–362. External Links: ISBN 0-262-51112-6, Link Cited by: §1.
  • M. Chevalier-Boisvert, D. Bahdanau, S. Lahlou, L. Willems, C. Saharia, T. H. Nguyen, and Y. Bengio (2018) BabyAI: first steps towards grounded language learning with a human in the loop. CoRR abs/1810.08272. External Links: Link, 1810.08272 Cited by: Table 12, §1, §4.2.1, §4.2, §4.2, Table 1, §4.
  • J. Chung, K. Kastner, L. Dinh, K. Goel, A. Courville, and Y. Bengio (2015) A recurrent latent variable model for sequential data. In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 2, NIPS’15, Cambridge, MA, USA, pp. 2980–2988. External Links: Link Cited by: §1.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. CoRR abs/1810.04805. External Links: Link, 1810.04805 Cited by: §1.
  • A. Goyal, A. Lamb, J. Hoffmann, S. Sodhani, S. Levine, Y. Bengio, and B. Schölkopf (2019) Recurrent independent mechanisms. CoRR abs/1909.10893. External Links: Link Cited by: §3.
  • A. Graves, G. Wayne, M. Reynolds, T. Harley, I. Danihelka, A. Grabska-Barwińska, S. G. Colmenarejo, E. Grefenstette, T. Ramalho, J. Agapiou, A. P. Badia, K. M. Hermann, Y. Zwols, G. Ostrovski, A. Cain, H. King, C. Summerfield, P. Blunsom, K. Kavukcuoglu, and D. Hassabis (2016) Hybrid computing using a neural network with dynamic external memory. Nature 538 (7626), pp. 471–476. External Links: ISSN 00280836, Link Cited by: §3, §5.
  • M. J. Hausknecht and P. Stone (2015) Deep recurrent q-learning for partially observable mdps. CoRR abs/1507.06527. External Links: Link, 1507.06527 Cited by: §1.
  • K. He, X. Zhang, S. Ren, and J. Sun (2015) Delving deep into rectifiers: surpassing human-level performance on imagenet classification. CoRR abs/1502.01852. External Links: Link, 1502.01852 Cited by: Table 2.
  • S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural Comput. 9 (8), pp. 1735–1780. External Links: ISSN 0899-7667, Link, Document Cited by: §1.
  • L. P. Kaelbling, M. L. Littman, and A. R. Cassandra (1998) Planning and acting in partially observable stochastic domains. Artif. Intell. 101 (1-2), pp. 99–134. External Links: ISSN 0004-3702, Link, Document Cited by: §1.
  • D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. Note: 3rd International Conference for Learning Representations, San Diego, 2015 External Links: Link Cited by: Table 2.
  • G. A. Miller (1956) The magical number seven, plus or minus two: some limits on our capacity for processing information. The Psychological Review 63 (2), pp. 81–97. External Links: Link Cited by: §2.
  • V. Mnih, A. P. Badia, M. Mirza, A. Graves, T. P. Lillicrap, T. Harley, D. Silver, and K. Kavukcuoglu (2016) Asynchronous methods for deep reinforcement learning. In Proceedings of the 33rd International Conference on Machine Learning (ICML), pp. 1928–1937. Cited by: Table 2, §2.
  • V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, S. Petersen, C. Beattie, A. Sadik, I. Antonoglou, H. King, D. Kumaran, D. Wierstra, S. Legg, and D. Hassabis (2015) Human-level control through deep reinforcement learning. Nature 518 (7540), pp. 529–533. External Links: ISSN 00280836, Link Cited by: §1.
  • T. Munkhdalai, A. Sordoni, T. Wang, and A. Trischler (2019) Metalearned neural memory. CoRR abs/1907.09720. External Links: Link, 1907.09720 Cited by: §5.
  • J. Oh, V. Chockalingam, S. Singh, and H. Lee (2016) Control of memory, active perception, and action in minecraft. In Proceedings of the 33rd International Conference on Machine Learning - Volume 48, ICML’16, pp. 2790–2799. External Links: Link Cited by: §1, §3.
  • S. Russell and P. Norvig (2009) Artificial intelligence: a modern approach. 3rd edition, Prentice Hall Press, Upper Saddle River, NJ, USA. External Links: ISBN 0136042597, 9780136042594 Cited by: §1.
  • A. Santoro, R. Faulkner, D. Raposo, J. W. Rae, M. Chrzanowski, T. Weber, D. Wierstra, O. Vinyals, R. Pascanu, and T. P. Lillicrap (2018) Relational recurrent neural networks. CoRR abs/1806.01822. External Links: Link, 1806.01822 Cited by: §3.
  • I. Sutskever, O. Vinyals, and Q. V. Le (2014) Sequence to sequence learning with neural networks. In Proceedings of the 27th International Conference on Neural Information Processing Systems - Volume 2, NIPS’14, Cambridge, MA, USA, pp. 3104–3112. External Links: Link Cited by: §1.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. CoRR abs/1706.03762. External Links: Link, 1706.03762 Cited by: §1, §2.
  • P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio (2017) Graph attention networks. CoRR abs/1710.10903. External Links: Link, 1710.10903 Cited by: §5.
  • O. Vinyals, I. Babuschkin, W. M. Czarnecki, M. Mathieu, A. Dudzik, J. Chung, D. H. Choi, R. Powell, T. Ewalds, P. Georgiev, et al. (2019) Grandmaster level in starcraft ii using multi-agent reinforcement learning. Nature, pp. 1–5. Cited by: §3.
  • V. Zambaldi, D. Raposo, A. Santoro, V. Bapst, Y. Li, I. Babuschkin, K. Tuyls, D. Reichert, T. Lillicrap, E. Lockhart, M. Shanahan, V. Langston, R. Pascanu, M. Botvinick, O. Vinyals, and P. Battaglia (2019) Deep reinforcement learning with relational inductive biases. Cited by: §3.

Appendix A Pathfinding environment details

A.1 Graph construction algorithm

The Pathfinding graph is constrained to be a polytree (a singly connected, directed acyclic graph) at each step of the episode, up to a maximum number of patterns, each of a fixed size, connected by links. At the start of each episode, the graph contains a single random pattern. On each construction step, the environment links one new pattern to the graph through the following procedure (drawing all random numbers from uniform distributions):

  1. Create a new random pattern.

  2. Randomly choose one existing pattern in the graph.

  3. Create a new link, choosing a random direction, between the two patterns.
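The three construction steps above can be sketched in Python as follows. This is a minimal illustration rather than the authors' implementation: the function names, the parameters `max_patterns` and `pattern_size`, and the encoding of a pattern as a tuple of uniform random floats in [-1, 1] are all assumptions introduced here.

```python
import random


def random_pattern(size, rng):
    # Hypothetical pattern encoding: a tuple of uniform random floats.
    return tuple(rng.uniform(-1.0, 1.0) for _ in range(size))


def build_polytree(max_patterns, pattern_size, rng):
    """Grow a polytree one pattern at a time.

    Returns (patterns, links), where each link is a directed pair
    (source_index, target_index).
    """
    # The episode starts with a single random pattern and no links.
    patterns = [random_pattern(pattern_size, rng)]
    links = []
    while len(patterns) < max_patterns:
        new_index = len(patterns)
        # Step 1: create a new random pattern.
        patterns.append(random_pattern(pattern_size, rng))
        # Step 2: choose one existing pattern uniformly at random.
        existing = rng.randrange(new_index)
        # Step 3: link the two patterns in a uniformly random direction.
        if rng.random() < 0.5:
            links.append((existing, new_index))
        else:
            links.append((new_index, existing))
    return patterns, links
```

Because every new pattern attaches to the existing graph by exactly one link, the undirected structure stays a tree, so the directed graph remains a polytree whichever direction is drawn.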

A.2 Quiz-generation algorithm

On each quiz step, the environment first decides whether the correct answer should be 0 or 1 by sampling a binary value from a discrete uniform distribution. Then the environment draws uniform-random ordered pairs of nodes from the current graph until finding a pair that satisfies the desired answer. The observation is then constructed by concatenating the two pattern vectors and a marker value that identifies this observation as a quiz.
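A minimal Python sketch of this rejection-sampling procedure is given below. It assumes that an answer of 1 means a directed path exists from the first pattern to the second, that the two nodes of a pair are distinct, and that the graph already holds at least two patterns. The names `has_directed_path`, `generate_quiz`, and `QUIZ_FLAG`, along with the marker value itself, are hypothetical.

```python
import random

QUIZ_FLAG = 1.0  # hypothetical marker value; the text does not preserve the real one


def has_directed_path(links, start, goal):
    # Depth-first search that follows links only in their stated direction.
    successors = {}
    for src, dst in links:
        successors.setdefault(src, []).append(dst)
    stack, seen = [start], {start}
    while stack:
        node = stack.pop()
        if node == goal:
            return True
        for nxt in successors.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return False


def generate_quiz(patterns, links, rng):
    """Return (observation, answer) for one quiz step."""
    if len(patterns) < 2:
        raise ValueError("a quiz needs at least two patterns")
    # Decide the correct answer first, with equal probability for 0 and 1.
    answer = rng.randrange(2)
    while True:
        # Draw an ordered pair of distinct nodes uniformly at random.
        first, second = rng.sample(range(len(patterns)), 2)
        if has_directed_path(links, first, second) == bool(answer):
            observation = tuple(patterns[first]) + tuple(patterns[second]) + (QUIZ_FLAG,)
            return observation, answer
```

Choosing the answer before the pair keeps the two answers balanced even when most ordered pairs in the graph have no directed path between them.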

A.3 Depth-n baseline algorithm

The hand-coded baseline agent is configured with a depth parameter. As new pattern pairs are revealed on graph-construction time steps, the agent maintains a growing vector of all patterns seen, and a growing matrix of directed path lengths from every observed pattern to every other. A path length of zero indicates that no path exists from the first pattern to the second. On each quiz step, the agent looks up from the matrix the path length for the ordered pair of patterns in the observation. If that path length is nonzero and does not exceed the depth parameter, the agent chooses the yes action. Otherwise, the agent chooses the no action.
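The bookkeeping described above might look like the following Python sketch. The class and method names are hypothetical, patterns are assumed to be hashable tuples, and the decision rule (answer yes when the stored path length is positive and at most the depth) is this sketch's reading of the depth parameter, not a confirmed detail of the original agent.

```python
class DepthBaseline:
    """Hand-coded agent that tracks directed path lengths between patterns."""

    def __init__(self, depth):
        self.depth = depth
        self.index = {}    # pattern -> row/column index in the matrix
        self.lengths = []  # lengths[i][j] = path length from i to j, 0 if none

    def _lookup(self, pattern):
        # Register unseen patterns by growing the matrix by one row and column.
        if pattern not in self.index:
            self.index[pattern] = len(self.index)
            for row in self.lengths:
                row.append(0)
            self.lengths.append([0] * len(self.index))
        return self.index[pattern]

    def observe_link(self, source, target):
        """Record a new directed link revealed on a construction step."""
        u, v = self._lookup(source), self._lookup(target)
        n = len(self.lengths)
        # Every node that reaches u (u itself at distance 0) now reaches
        # every node reachable from v (v itself at distance 0).
        into_u = [(a, self.lengths[a][u]) for a in range(n)
                  if a == u or self.lengths[a][u] > 0]
        out_of_v = [(b, self.lengths[v][b]) for b in range(n)
                    if b == v or self.lengths[v][b] > 0]
        for a, dist_to_u in into_u:
            for b, dist_from_v in out_of_v:
                self.lengths[a][b] = dist_to_u + 1 + dist_from_v

    def answer(self, first, second):
        """Return 1 (yes) or 0 (no) for the ordered pair in a quiz."""
        i, j = self.index.get(first), self.index.get(second)
        if i is None or j is None:
            return 0
        return 1 if 0 < self.lengths[i][j] <= self.depth else 0
```

Because the graph is a polytree, at most one directed path connects any ordered pair, so each matrix entry is written at most once and never needs a minimum over alternatives.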

Appendix B Training details and hyperparameters

Settings and options                          Values
Dropout                                       None
Learning rate schedule                        Constant learning rate
Non-linearities                               ReLU, tanh
Parallel training workers                     1
Optimizer                                     Adam (Kingma and Ba, 2014)
Parameter initialization, biases              0
Parameter initialization, non-bias weights    Kaiming uniform (He et al., 2015)
Reward shaping                                None
Training algorithm                            Advantage actor-critic (Mnih et al., 2016)
Weight decay regularization                   None
Table 2: Fixed settings and options used for all experiments, except for the replicated baselines in Table 12.
Actor-critic hidden layer size 128 128 512
Actor-critic 16 16 16
Adam eps 1e-06 1e-08 1e-08
Discount factor 0.5 0.6 0.5
Entropy term strength 0.01 0.005 0.02
Gradient clipping threshold 16.0 16.0 4.0
GRU observation embedding size 256
GRU size 384
Learning rate 0.00016 0.00016 0.0001
Reward scale factor 2.0 1.0 0.5
WMG attention head size 12 16
WMG attention heads 6 6
WMG concept nodes 16 0
WMG concept size 128
WMG hidden layer size 12 32
WMG layers 4 4
Table 3: Tuned hyperparameter settings for Pathfinding experiments.
factored factored factored flat flat native 7x7x3
Actor-critic hidden layer size 2048 4096 4096 4096 2048 512
Actor-critic 1 1 6 16 4 6
Adam eps 0.0001 1e-08 1e-08 1e-10 0.0001 1e-10
CNN hidden channel size 1 16
CNN hidden channel size 2 40
CNN hidden channel size 3 192
Discount factor 0.98 0.9 0.7 0.6 0.9 0.8
Entropy term strength 0.002 0.05 0.01 0.005 0.02 0.02
Gradient clipping threshold 256.0 1024.0 512.0 512.0 128.0 128.0
GRU observation embed size 1024 512 512
GRU size 96 512 96
Learning rate 0.0001 4e-05 0.0004 0.0001 0.0001 0.0004
Reward scale factor 4.0 32.0 32.0 8.0 32.0 8.0
WMG attention head size 24 16 16
WMG attention heads 4 10 12
WMG concept nodes 1 0 1
WMG concept size 64 256
WMG hidden layer size 64 64 32
WMG layers 4 4 1
Table 4: Tuned hyperparameter settings for BabyAI Level 1 - GoToObj.
factored factored factored flat flat native 7x7x3
Actor-critic hidden layer size 4096 2048 4096 4096 4096 64
Actor-critic 8 6 16 1 1 1
Adam eps 1e-06 1e-08 1e-10 1e-10 1e-06 0.0001
CNN hidden channel size 1 12
CNN hidden channel size 2 24
CNN hidden channel size 3 192
Discount factor 0.8 0.9 0.8 0.9 0.9 0.95
Entropy term strength 0.01 0.02 0.01 0.005 0.005 0.02
Gradient clipping threshold 1024.0 512.0 1024.0 128.0 64.0 64.0
GRU observation embed size 4096 2048 256
GRU size 96 512 64
Learning rate 0.0001 0.00025 0.0001 2.5e-05 2.5e-05 0.0004
Reward scale factor 8.0 4.0 4.0 4.0 4.0 2.0
WMG attention head size 64 48 64
WMG attention heads 4 1 3
WMG concept nodes 1 0 8
WMG concept size 32 64
WMG hidden layer size 16 24 384
WMG layers 3 3 1
Table 5: Tuned hyperparameter settings for BabyAI Level 2 - GoToRedBallGrey.
factored factored factored flat flat native 7x7x3
Actor-critic hidden layer size 4096 2048 4096 4096 4096 4096
Actor-critic 1 2 3 1 2 3
Adam eps 1e-12 0.0001 1e-06 0.0001 1e-06 0.01
CNN hidden channel size 1 12
CNN hidden channel size 2 40
CNN hidden channel size 3 192
Discount factor 0.95 0.9 0.9 0.9 0.9 0.9
Entropy term strength 0.1 0.05 0.1 0.05 0.02 0.05
Gradient clipping threshold 128.0 128.0 128.0 128.0 32.0 32.0
GRU observation embed size 2048 4096 256
GRU size 192 512 64
Learning rate 2.5e-05 6.3e-05 6.3e-05 2.5e-05 2.5e-05 0.0004
Reward scale factor 8.0 4.0 8.0 8.0 4.0 4.0
WMG attention head size 128 32 24
WMG attention heads 2 8 12
WMG concept nodes 2 0 16
WMG concept size 128 256
WMG hidden layer size 64 32 128
WMG layers 4 4 1
Table 6: Tuned hyperparameter settings for BabyAI Level 3 - GoToRedBall.
factored factored factored flat flat
Actor-critic hidden layer size 2048 2048 1024 512 4096
Actor-critic 6 3 3 6 4
Adam eps 1e-12 0.01 1e-06 1e-08 1e-12
Discount factor 0.5 0.6 0.95 0.5 0.9
Entropy term strength 0.1 0.1 0.1 0.02 0.02
Gradient clipping threshold 512.0 512.0 256.0 256.0 512.0
GRU observation embed size 1024 512
GRU size 128 96
Learning rate 6.3e-05 0.0001 4e-05 2.5e-05 4e-05
Reward scale factor 32.0 16.0 8.0 16.0 2.0
WMG attention head size 128 64 24
WMG attention heads 2 4 16
WMG concept nodes 8 0 16
WMG concept size 32 64
WMG hidden layer size 32 48 16
WMG layers 4 3 2
Table 7: Tuned hyperparameter settings for BabyAI Level 4 - GoToLocal.
factored factored
Actor-critic hidden layer size 512 2048
Actor-critic 12 12
Adam eps 1e-10 1e-10
Discount factor 0.7 0.8
Entropy term strength 0.02 0.05
Gradient clipping threshold 512.0 512.0
Learning rate 0.0001 6.3e-05
Reward scale factor 8.0 8.0
WMG attention head size 24 48
WMG attention heads 10 6
WMG concept nodes 8 0
WMG concept size 32
WMG hidden layer size 128 96
WMG layers 2 2
Table 8: Tuned hyperparameter settings for BabyAI Level 5 - PickupLoc.

Appendix C Additional experimental results

Models & algorithms     Final performance    Trainable parameters    Training speed
Depth-(n-1) baseline    100.0% of reward
Depth-3 baseline        99.7% of reward
Depth-2 baseline        97.6% of reward
Depth-1 baseline        86.9% of reward
nr-WMG, full-history    99.6% of reward      204,963                 96 steps/sec
WMG                     99.6% of reward      132,507                 91 steps/sec
GRU                     94.7% of reward      1,139,459               291 steps/sec
Table 9: Additional details for the Pathfinding experimental results in Figure 4 (right).
BabyAI level factored factored factored flat flat native 7x7x3
1 - GoToObj 636 1,864 1,572 2,053 4,170 393
2 - GoToRedBallGrey 2,997 258 3,723 2,116 10,075 140
3 - GoToRedBall 3,418 2,217 3,749 3,229 15,126 709
4 - GoToLocal 2,235 1,960 1,137 2,022 1,479 —–
5 - PickupLoc 879 2,007 —– —– —– —–
Table 10: Number of trainable parameters, in thousands, for the BabyAI models in Table 1.
BabyAI level factored factored factored flat flat native 7x7x3
1 - GoToObj 38 28 146 111 86 149
2 - GoToRedBallGrey 58 113 147 35 18 88
3 - GoToRedBall 18 32 78 25 20 87
4 - GoToLocal 44 48 132 54 134 —–
5 - PickupLoc 81 84 —– —– —– —–
Table 11: Training steps per second on a fixed machine, for the BabyAI models in Table 1.
BabyAI level Instruction template Published (episodes) Replicated (episodes) Replicated (interactions)
1 - GoToObj GO TO color object —– 19 333
2 - GoToRedBallGrey GO TO RED BALL 16 16 282
3 - GoToRedBall GO TO RED BALL 272 283 3,674
4 - GoToLocal GO TO color object 971 1,064 16,422
5 - PickupLoc PICK UP color object loc 2,977 1,557 25,574
Table 12: BabyAI baseline agent sample efficiencies, defined as the amount of training (in either episodes or environment interaction steps) required for the agent to solve 99% of random episodes within 64 steps. The published results are the means of the minimum and maximum RL sample efficiencies reported in Table 3 of Chevalier-Boisvert et al. (2018). The replicated results are the medians over 10 training runs, using the code and default hyperparameter settings from the open-source release of the BabyAI baseline agent. All numbers are in thousands.
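The sample-efficiency definition in this caption can be made concrete with a short Python sketch. It assumes each training run is logged as a list of (training amount, success rate) checkpoints in increasing order of training amount. That logging format and the function names are hypothetical, and the BabyAI code base may measure the threshold crossing differently.

```python
from statistics import median


def sample_efficiency(checkpoints, threshold=0.99):
    """Return the first training amount whose success rate meets the threshold.

    `checkpoints` is a list of (training_amount, success_rate) pairs, where the
    success rate is the fraction of random episodes solved within the step limit.
    Returns None if the run never reaches the threshold.
    """
    for amount, success_rate in checkpoints:
        if success_rate >= threshold:
            return amount
    return None


def median_sample_efficiency(runs, threshold=0.99):
    """Median sample efficiency over several training runs."""
    values = [sample_efficiency(run, threshold) for run in runs]
    if any(value is None for value in values):
        raise ValueError("at least one run never reached the threshold")
    return median(values)
```

Reporting the median rather than the mean keeps a single slow run from dominating the summary.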