Schema Networks: Zero-shot Transfer with a Generative Causal Model of Intuitive Physics

by Ken Kansky, et al.

The recent adaptation of deep neural network-based methods to reinforcement learning and planning domains has yielded remarkable progress on individual tasks. Nonetheless, progress on task-to-task transfer remains limited. In pursuit of efficient and robust generalization, we introduce the Schema Network, an object-oriented generative physics simulator capable of disentangling multiple causes of events and reasoning backward through causes to achieve goals. The richly structured architecture of the Schema Network can learn the dynamics of an environment directly from data. We compare Schema Networks with Asynchronous Advantage Actor-Critic and Progressive Networks on a suite of Breakout variations, reporting results on training efficiency and zero-shot generalization, consistently demonstrating faster, more robust learning and better transfer. We argue that generalizing from limited data and learning causal relationships are essential abilities on the path toward generally intelligent systems.


1 Introduction

Figure 1: Variations of Breakout. From top left: standard version, middle wall, half negative bricks, offset paddle, random target, and juggling. After training on the standard version, Schema Networks are able to generalize to the other variations without any additional training.

A longstanding ambition of research in artificial intelligence is to efficiently generalize experience in one scenario to other similar scenarios. Such generalization is essential for an embodied agent working to accomplish a variety of goals in a changing world. Despite remarkable progress on individual tasks like Atari 2600 games (Mnih et al., 2015; Van Hasselt et al., 2016; Mnih et al., 2016) and Go (Silver et al., 2016a), the ability of state-of-the-art models to transfer learning from one environment to the next remains limited. For instance, consider the variations of Breakout illustrated in Fig. 1. In these environments the positions of objects are perturbed, but the object movements and sources of reward remain the same. While humans have no trouble generalizing experience from the basic Breakout to its variations, deep neural network-based models are easily fooled (Taylor & Stone, 2009; Rusu et al., 2016).

The model-free approach of deep reinforcement learning (Deep RL) such as the Deep-Q Network and its descendants is inherently hindered by the same feature that makes it desirable for single-scenario tasks: it makes no assumptions about the structure of the domain. Recent work has suggested how to overcome this deficiency by utilizing object-based representations (Diuk et al., 2008; Usunier et al., 2016). Such a representation is motivated by the well-acknowledged Gestalt principle, which states that the ability to perceive objects as a bounded figure in front of an unbounded background is fundamental to all perception (Weiten, 2012). Battaglia et al. (2016) and Chang et al. (2016) go further, learning relations to predict object interactions.

While object-based and relational representations have shown great promise alone, they stop short of modeling causality – the ability to reason about previous observations and explain away alternative causes. A causal model is essential for regression planning, in which an agent works backward from a desired future state to produce a plan (Anderson, 1990). Reasoning backward and allowing for multiple causation requires a framework like Probabilistic Graphical Models (PGMs), which can natively support explaining away (Koller & Friedman, 2009).

Here we introduce Schema Networks – a generative model for object-oriented reinforcement learning and planning.¹ Schema Networks incorporate key desiderata for the flexible and compositional transfer of learned prior knowledge to new settings. 1) Knowledge is represented with “schemas” – local cause-effect relationships involving one or more object entities; 2) In a new setting, these cause-effect relationships are traversed to guide action selection; and 3) The representation deals with uncertainty, multiple-causation, and explaining away in a principled way. We first describe the representational framework and learning algorithms and then demonstrate how action policies can be generated by treating planning as inference in a factor graph. We evaluate the end-to-end system on Breakout variations and compare against Asynchronous Advantage Actor-Critic (A3C) (Mnih et al., 2016) and Progressive Networks (PNs) (Rusu et al., 2016), the latter of which extends A3C explicitly to handle transfer. We show that the structure of the Schema Network enables efficient and robust generalization beyond these Deep RL models.

¹ We borrow the term “schema” from Drescher (1991), whose schema mechanism inspired the early development of our model.

2 Related Work

The field of reinforcement learning has witnessed significant progress with the recent adaptation of deep learning methods to traditional frameworks like Q-learning. Since the introduction of the Deep Q-network (DQN) (Mnih et al., 2015), which uses experience replay to achieve human-level performance on a set of Atari 2600 games, several innovations have enabled faster convergence and better performance with less memory. The asynchronous methods introduced by Mnih et al. (2016) exploit multiple agents acting in copies of the same environment, combining their experiences into one model. As the Asynchronous Advantage Actor-Critic (A3C) is the best among these methods, we use it as our primary comparison.

Model-free Deep RL models like A3C are unable to substantially generalize beyond their training experience (Jaderberg et al., 2016; Rusu et al., 2016). To address this limitation, recent work has attempted to introduce more structure into neural network-based models. The Interaction Network (Battaglia et al., 2016) (IN) and the Neural Physics Engine (NPE) (Chang et al., 2016) use object-level and pairwise relational representations to learn models of intuitive physics. The primary advantage of these models is their amenability to gradient-based methods, though such techniques might be applied to Schema Networks as well. Schema Networks offer two key advantages: latent physical properties and relations need not be hardcoded, and planning can make use of backward search, since the model can distinguish different causes. Furthermore, neither INs nor NPEs have been applied in RL domains. Progress in model-based Deep RL has thus far been limited, though methods like Embed to Control (Watter et al., 2015), Value Iteration Networks (Tamar et al., 2016), and the Predictron (Silver et al., 2016b) demonstrate the promise of this direction. However, these approaches do not exploit the object-relational representation of INs or NPEs, nor do they incorporate a backward model for regression planning.

Schema Networks build upon the ideas of the Object-Oriented Markov Decision Process (OO-MDP) introduced by Diuk et al. (2008) (see also Scholz et al. (2014)). Related frameworks include relational and first-order logical MDPs (Guestrin et al., 2003a). These formalisms, which harken back to classical AI’s roots in symbolic reasoning, are designed to enable robust generalization. Recent work by Garnelo et al. (2016) on “deep symbolic reinforcement learning” makes this connection explicit, marrying first-order logic with deep RL. This effort is similar in spirit to our work with Schema Networks, but like INs and NPEs, it lacks a mechanism to learn disentangled causes of the same effect and cannot perform regression planning.

Schema Networks transfer experience from one scenario to other similar scenarios that exhibit repeatable structure and sub-structure (Taylor & Stone, 2009). Rusu et al. (2016) show how A3C can be augmented to similarly exploit common structure between tasks via Progressive Networks (PNs). A PN is constructed by successively training copies of A3C on each task of interest. With each new task, the existing network is frozen, another copy of A3C is added, and lateral connections between the frozen network and the new copy are established to facilitate transfer of features learned during previous tasks. One obvious limitation of PNs is that the number of network parameters must grow quadratically with the number of tasks. However, even if this growth rate were improved, the PN would still be unable to generalize from biased training data without continuing to learn on the test environment. In contrast, Schema Networks exhibit zero-shot transfer.
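A rough count illustrates the quadratic growth mentioned above. The sketch assumes, as a simplification not spelled out in this section, that each newly added column receives one block of lateral connections from every earlier column; `lateral_blocks` is a hypothetical helper name.

```python
def lateral_blocks(num_tasks):
    """Count lateral-connection blocks after training on num_tasks tasks.

    Assumption: column k (1-indexed) is wired to each of the k - 1 earlier,
    frozen columns, so the total is 0 + 1 + ... + (num_tasks - 1).
    """
    return sum(k - 1 for k in range(1, num_tasks + 1))


# The total equals K * (K - 1) / 2, i.e. it grows quadratically in K.
counts = [lateral_blocks(k) for k in (1, 2, 4, 10)]
```

Under this assumption the count for K tasks is K(K − 1)/2, so doubling the number of tasks roughly quadruples the lateral wiring.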

Schema Networks are implemented as probabilistic graphical models (PGMs), which provide practical inference and structure learning techniques. Additionally, inference with uncertainty and explaining away are naturally supported by PGMs. We refer readers to Koller & Friedman (2009) and Jordan (1998) for a thorough overview of PGMs. In particular, early work on factored MDPs has demonstrated how PGMs can be applied in RL and planning settings (Guestrin et al., 2003b).

3 Schema Networks

3.1 MDPs and Notation

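The traditional formalism for the Reinforcement Learning problem is the Markov Decision Process (MDP). An MDP is a five-tuple $(S, A, T, R, \gamma)$, where $S$ is a set of states, $A$ is a set of actions, $T(s^{(t+1)} \mid s^{(t)}, a^{(t)})$ is the probability of transitioning from state $s^{(t)} \in S$ to $s^{(t+1)} \in S$ after action $a^{(t)} \in A$, $R(r^{(t+1)} \mid s^{(t)}, a^{(t)})$ is the probability of receiving reward $r^{(t+1)}$ after executing action $a^{(t)}$ while in state $s^{(t)}$, and $\gamma \in [0, 1]$ is the rate at which future rewards are exponentially discounted.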
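To make "exponentially discounted" concrete, the following minimal sketch computes the discounted sum of a finite reward sequence. The function name and the finite-horizon setting are illustrative assumptions.

```python
def discounted_return(rewards, gamma):
    """Sum of rewards where the reward at step t is weighted by gamma ** t."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))


# Three unit rewards with discount 0.5: 1 + 0.5 + 0.25
example = discounted_return([1, 1, 1], 0.5)
```

With a discount of 0.5, a reward received two steps in the future counts for a quarter of an immediate reward.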

Figure 2: Architecture of a Schema Network.

An ungrounded schema is a template for a factor that predicts either the value of an entity-attribute (A) or a future reward (B) based on entity states and actions taken in the present. Self-transitions (C) predict that entity-attributes remain in the same state when no schema is active to predict a change. Self-transitions allow continuous or categorical variables to be represented by a set of binary variables (depicted as smaller nodes). The grounded schema factors, instantiated from ungrounded schemas at all positions, times, and entity bindings, are combined with self-transitions to create a Schema Network (D).

3.2 Model Definition

A Schema Network is a structured generative model of an MDP. We first describe the architecture of the model informally. An image input is parsed into a list of entities, which may be thought of as instances of objects in the sense of OO-MDPs (Diuk et al., 2008). All entities share the same collection of attributes. We refer to a specific attribute of a specific entity as an entity-attribute, which is represented as a binary variable to indicate the presence of that attribute for an entity. An entity state is an assignment of states to all attributes of the entity, and the complete model state is the set of all entity states.

A grounded schema is a binary variable associated with a particular entity-attribute in the next timestep, whose value depends on the present values of a set of binary entity-attributes. The event that one of these present entity-attributes assumes the value 1 is called a precondition of the grounded schema. When all preconditions of a grounded schema are satisfied, we say that the schema is active, and it predicts the activation of its associated entity-attribute. Grounded schemas may also predict rewards and may be conditioned on actions, both of which are represented as binary variables. For instance, a grounded schema might define a distribution over Entity 1’s “position” attribute at time 5, conditioned on Entity 2’s “position” attribute at time 4 and the action “UP” at time 4. Grounded schemas are instantiated from ungrounded schemas, which behave like templates for grounded schemas to be instantiated at different times and in different combinations of entities. For example, an ungrounded schema could predict the “position” attribute of Entity $x$ at time $t+1$ conditioned on the “position” of Entity $y$ at time $t$ and the action “UP” at time $t$; this ungrounded schema could be instantiated at time $t = 4$ with $x = 1$ and $y = 2$ to create the grounded schema described above. In the case of attributes like “position” that are inherently continuous or categorical, several binary variables may be used to discretely approximate the distribution (see the smaller nodes in Figure 2). A Schema Network is a factor graph that contains all grounded instantiations of a set of ungrounded schemas over some window of time, illustrated in Figure 2.
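The template-versus-instance relationship can be sketched with two small data classes. All class, field, and function names here are hypothetical, and the encoding of preconditions as (relative entity offset, attribute) pairs is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class UngroundedSchema:
    # Preconditions as (entity offset relative to the target entity, attribute).
    preconditions: Tuple[Tuple[int, str], ...]
    action: Optional[str]
    effect_attribute: str


@dataclass(frozen=True)
class GroundedSchema:
    preconditions: Tuple[Tuple[int, str, int], ...]  # (entity, attribute, time)
    action: Optional[Tuple[str, int]]                # (action name, time)
    effect: Tuple[int, str, int]                     # (entity, attribute, time)


def ground(schema, x, t):
    """Bind the template to target entity x at time t; the effect lands at t + 1."""
    pre = tuple((x + off, attr, t) for off, attr in schema.preconditions)
    act = (schema.action, t) if schema.action is not None else None
    return GroundedSchema(pre, act, (x, schema.effect_attribute, t + 1))


# Template: "position" of the next entity plus action UP predicts "position".
template = UngroundedSchema(((1, "position"),), "UP", "position")
grounded = ground(template, x=1, t=4)
```

Grounding the template with target Entity 1 at time 4 reproduces the example from the text: Entity 1's "position" at time 5, conditioned on Entity 2's "position" and the action "UP" at time 4. The same template can be grounded at any other entity and time.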

We now formalize the Schema Network factor graph. For simplicity, suppose the number of entities and the number of attributes are fixed at $N$ and $M$ respectively. Let $E_i$ refer to the $i$-th entity and let $\alpha_{i,j}^{(t)}$ refer to the $j$-th attribute value of the $i$-th entity at time $t$. We use the notation $E_i^{(t)} = (\alpha_{i,1}^{(t)}, \ldots, \alpha_{i,M}^{(t)})$ to refer to the state of the $i$-th entity at time $t$. The complete state of the MDP modeled by the network at time $t$ is then $s^{(t)} = (E_1^{(t)}, \ldots, E_N^{(t)})$. Actions and rewards are also represented with sets of binary variables, denoted $a^{(t)}$ and $r^{(t+1)}$ respectively. A Schema Network for time $t$ will contain the variables in $s^{(t)}$, $a^{(t)}$, $s^{(t+1)}$, and $r^{(t+1)}$.

Let $\phi^k$ denote the variable for grounded schema $k$. $\phi^k$ is bound to a specific entity-attribute $\alpha_{i,j}$ and activates it when the schema is active. Multiple grounded schemas can predict the same attribute, and these predictions are combined through an OR gate. For binary variables $v_1, \ldots, v_n$, let $\mathrm{AND}(v_1, \ldots, v_n) = \prod_{i=1}^{n} P(v_i = 1)$, and $\mathrm{OR}(v_1, \ldots, v_n) = 1 - \prod_{i=1}^{n} \bigl(1 - P(v_i = 1)\bigr)$. A grounded schema is connected to its precondition entity-attributes with an AND factor, written as $\phi^k = \mathrm{AND}(\alpha_{i_1,j_1}, \ldots, \alpha_{i_H,j_H}, a)$ for $H$ entity-attribute preconditions and an optional action $a$. There is no restriction on how many entities or attributes from a single entity can be preconditions of a grounded schema.
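The AND and OR combinations can be sketched as follows, under the assumption that each input is the probability that an independent binary variable equals 1. The function names mirror the text, but the implementation is illustrative only.

```python
from math import prod


def AND(*probs):
    """Probability that all inputs are 1, treating them as independent."""
    return prod(probs)


def OR(*probs):
    """Probability that at least one input is 1, treating them as independent."""
    return 1.0 - prod(1.0 - p for p in probs)


# A schema with two satisfied preconditions and the required action taken.
schema_active = AND(1, 1, 1)
# Two schemas predict the same attribute; one of them is active.
attribute_active = OR(0, schema_active)
```

For inputs that are exactly 0 or 1 these reduce to ordinary Boolean AND and OR, which is the case used throughout the deterministic examples.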

An ungrounded schema (or template) is represented as $\Phi_l(E_x) = \mathrm{AND}(\alpha_{x+i_1, j_1}, \alpha_{x+i_2, j_2}, \ldots, \alpha_{x+i_H, j_H})$, where $i_h$ determines the relative entity index of the $h$-th precondition and $j_h$ determines which attribute variable is the precondition. The ungrounded schema is a template that can be bound to multiple specific entities and locations to generate grounded schemas.

A subset of attributes corresponds to discrete positions. These attributes are treated differently from all others, whose semantic meanings are unknown to the model. When a schema predicts a movement to a new position, we must inform the previously active position attribute to be inactive unless there is another schema that predicts it to remain active. We introduce a self-transition variable to represent the probability that a position attribute will remain active in the next time step when no schema predicts a change from that position. We compute the self-transition variable as $\Lambda_{i,j} = \mathrm{AND}(\neg\phi^{1}, \ldots, \neg\phi^{k}, \alpha_{i,j})$ for entity $i$ and position attribute $j$, where the set $\phi^{1}, \ldots, \phi^{k}$ includes all schemas that predict the future position of the same entity $i$ and include $\alpha_{i,j}$ as a precondition.

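With these terms defined, we may now compute the transition function, which can be factorized as $T(s^{(t+1)} \mid s^{(t)}, a^{(t)}) = \prod_{i=1}^{N} \prod_{j=1}^{M} T_{i,j}(\alpha_{i,j}^{(t+1)} \mid s^{(t)}, a^{(t)})$. An entity-attribute is active at the next time step if either a schema predicts it to be active or if its self-transition variable is active: $T_{i,j}(\alpha_{i,j}^{(t+1)} \mid s^{(t)}, a^{(t)}) = \mathrm{OR}(\phi^{k_1}, \ldots, \phi^{k_Q}, \Lambda_{i,j})$, where $k_1, \ldots, k_Q$ are the indices of all grounded schemas that predict $\alpha_{i,j}^{(t+1)}$.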
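A deterministic toy version of this transition rule can be sketched as follows. It assumes every attribute in the toy state is a position attribute, represents a grounded schema as a hypothetical `(preconditions, action, effect)` triple, and uses Boolean values in place of probabilities.

```python
def step(state, action, schemas):
    """One-step prediction: OR of active schemas and the self-transition.

    state maps (entity, position) -> 0/1; each schema is a triple
    (preconditions, required action or None, effect).
    """
    active = [all(state.get(p, 0) for p in pre) and act in (None, action)
              for pre, act, _ in schemas]
    nxt = {}
    for key in set(state) | {eff for _, _, eff in schemas}:
        predicted = any(on and eff == key
                        for on, (_, _, eff) in zip(active, schemas))
        # Self-transition: the position stays active unless an active schema
        # that uses it as a precondition predicts a position for this entity.
        leaving = any(on and key in pre and eff[0] == key[0]
                      for on, (pre, _, eff) in zip(active, schemas))
        stays = bool(state.get(key, 0)) and not leaving
        nxt[key] = int(predicted or stays)
    return nxt


# Entity 0 moves one cell to the right whenever RIGHT is pressed.
schemas = [(((0, "x0"),), "RIGHT", (0, "x1")),
           (((0, "x1"),), "RIGHT", (0, "x2"))]
start = {(0, "x0"): 1, (0, "x1"): 0, (0, "x2"): 0}
moved = step(start, "RIGHT", schemas)
stayed = step(start, "NOOP", schemas)
```

Pressing RIGHT activates the first schema, which turns on the new position and switches off the self-transition of the old one. With any other action no schema fires and the self-transition keeps the entity in place.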

3.3 Construction of Entities and Attributes

In practice we assume that a vision system is responsible for detecting and tracking entities in an image. It is therefore largely up to the vision system to determine what constitutes an entity. Essentially any trackable image feature could be an entity, which most typically includes objects, their boundaries, and their surfaces. Recent work has demonstrated one possible method for unsupervised entity construction using autoencoders (Garnelo et al., 2016). Depending on the task, Schema Networks could learn to reason flexibly at different levels of representation. For example, using entities from surfaces might be most relevant for predicting collisions, while using one entity per object might be most relevant for predicting whether it can be controlled by an action. The experiments in this paper utilize surface entities, described further in Section 5.

Similarly, entity attributes can be provided by the vision system, and these attributes typically include: color/appearance, surface/edge orientation, object category, or part-of an object category (e.g. front-left tire). For simplicity we here restrict the entities to have fully observable attributes, but in general they could have latent attributes such as “bounciness” or “magnetism.”

3.4 Connections to Existing Models

Schema Networks are closely related to Object-Oriented MDPs (OO-MDPs) (Diuk et al., 2008) and Relational MDPs (R-MDPs) (Guestrin et al., 2003a). However, neither OO-MDPs nor R-MDPs define a transition function with an explicit OR of possible causes, and traditionally transition functions have not been learned in these models. In contrast, Schema Networks provide an explicit OR to reason about multiple causation, which enables regression planning. Additionally, the structure of Schema Networks is amenable to efficient learning.

Schema Networks are also related to the recently proposed Interaction Network (IN) (Battaglia et al., 2016) and Neural Physics Engine (NPE) (Chang et al., 2016). At a high level, INs, NPEs, and Schema Networks are much alike – objects are to entities as relations are to schemas. However, neither INs nor NPEs are generative and hence do not support regression planning from a goal through causal chains. Because Schema Networks are generative models, they support more flexible inference and search strategies for planning. Additionally, the learned structures in Schema Networks are amenable to human interpretation, explicitly factorizing different causes, making prediction errors easier to relate to the learned model parameters.

4 Learning and Planning in Schema Networks

In this section we describe how to train Schema Networks (i.e., learn their structure) from interactions with an environment, as well as how they can be used to perform planning. Planning is not only necessary at test time to maximize reward, but can also be used to improve exploration during the training procedure.

4.1 Training Procedure

Given a series of actions, rewards and images, we represent each possible action and reward with a binary variable, and we convert each image into a set of entity states. The number of entities is allowed to vary between adjacent frames, accounting for objects appearing or moving out of view.

Given a dataset of entity states over time, we preprocess the entity states into a representation that is more convenient for learning. For $N$ entities observed over $T$ timesteps, we wish to predict $\alpha_{i,j}^{(t)}$ on the basis of the attribute values of the $i$-th entity and its spatial neighbors at time $t-1$ (for $1 \le i \le N$ and $2 \le t \le T$). The attribute values of $E_i^{(t-1)}$ and its neighbors can be represented as a row vector of length $MR$, where $M$ is the number of attributes and $R-1$ is the number of neighbor positions of each entity, determined by a fixed radius. Let $X \in \{0,1\}^{D \times D'}$ be the arrangement of all such vectors into a binary matrix, with $D = NT$ and $D' = MR$. Let $y \in \{0,1\}^{D}$ be a binary vector such that if row $r$ in $X$ refers to $E_i^{(t-1)}$, then $y_r = \alpha_{i,j}^{(t)}$. Schemas are then learned to predict $y$ from $X$ using the method described in Section 4.2.
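The preprocessing can be sketched for a simplified setting. The sketch assumes entities sit on a one-dimensional row of cells, that the number of cells is the same in every frame, and that out-of-range neighbors are padded with zeros; `build_dataset` and its parameters are hypothetical names.

```python
def build_dataset(frames, j, radius=1):
    """Build binary inputs X and targets y for attribute index j.

    frames[t][i] is the binary attribute list of the entity in cell i at
    time t. Each row of X concatenates the attributes of an entity and its
    neighbors within `radius` at time t - 1; y holds attribute j at time t.
    """
    X, y = [], []
    for t in range(1, len(frames)):
        prev, cur = frames[t - 1], frames[t]
        num_attrs = len(prev[0])
        for i in range(len(cur)):
            row = []
            for off in range(-radius, radius + 1):
                k = i + off
                row.extend(prev[k] if 0 <= k < len(prev) else [0] * num_attrs)
            X.append(row)
            y.append(cur[i][j])
    return X, y


# One attribute ("occupied") on three cells; the object moves right by one.
frames = [[[1], [0], [0]],
          [[0], [1], [0]]]
X, y = build_dataset(frames, j=0)
```

Each row has length M·R, here 1 attribute times 3 positions (the entity itself and two neighbors). The only positive target is the middle cell, whose left neighbor was occupied in the previous frame.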

While gathering data, actions are chosen by planning using the schemas that have been learned so far. This planning algorithm is described in Section 4.3. We use an $\epsilon$-greedy approach to encourage exploration, taking a random action at each timestep with small probability. We found no need to perform any additional policy learning, and after convergence predictions were accurate enough to allow for successful planning. As shown in Section 5, since learning only involves understanding the dynamics of the game, transfer learning is simplified and there is no need for policy adaptation.
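The exploration rule described above amounts to a few lines. This is a generic sketch: the function name, the action list, and the value of epsilon are illustrative, and the planner that produces `planned_action` is assumed to exist elsewhere.

```python
import random


def epsilon_greedy(planned_action, actions, epsilon, rng=random):
    """Return a uniformly random action with probability epsilon,
    otherwise the action proposed by the planner."""
    if rng.random() < epsilon:
        return rng.choice(actions)
    return planned_action


ACTIONS = ["LEFT", "RIGHT", "NOOP"]
always_planned = epsilon_greedy("LEFT", ACTIONS, epsilon=0.0)
```

Setting epsilon to zero recovers pure planning, which matches the test-time behavior described in the text.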

4.2 Schema Learning

Structure learning in graphical models is a well studied topic in machine learning (Koller & Friedman, 2009; Jordan, 1998). To learn the structure of the Schema Network, we cast the problem as a supervised learning problem over a discrete space of parameterizations (the schemas), and then apply a greedy algorithm that solves a sequence of LP relaxations. See Jaakkola et al. (2010) for further work on applying LP relaxations to structure learning.

With $X$ and $y$ defined above, the learning problem is to find a mapping

$$y = f_W(X) = \overline{\overline{X}\,W}\,\vec{1},$$

where all the involved variables are binary and operations follow Boolean logic: addition corresponds to ORing, and overlining to negation. $W \in \{0,1\}^{D' \times L}$ is a binary matrix, with each column representing one (ungrounded) schema for at most $L$ schemas. The elements set to 1 in each schema represent an existing connection between that schema and an input condition (see Fig. 2). The outputs of each individual schema are ORed to produce the final prediction.
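The mapping can be sketched in plain Python: each column of the schema matrix lists the inputs a schema requires, a schema fires on a row when all of those inputs are 1, and the schema outputs are ORed. The function name `predict` and the list-of-lists layout are assumptions for illustration.

```python
def predict(X, W):
    """Boolean schema prediction.

    X is a list of binary input rows; W[d][l] == 1 means input d is a
    precondition of schema l. A row is predicted 1 if any schema has all
    of its preconditions satisfied on that row.
    """
    n_schemas = len(W[0]) if W else 0
    return [int(any(all(x[d] for d in range(len(x)) if W[d][l])
                    for l in range(n_schemas)))
            for x in X]


X = [[1, 1, 0],
     [1, 0, 0],
     [0, 0, 1]]
# Schema 0 requires inputs 0 and 1; schema 1 requires input 2.
W = [[1, 0],
     [1, 0],
     [0, 1]]
y_hat = predict(X, W)
```

Note that a column with no preconditions fires on every row under this rule, so unused schema slots should be left out rather than filled with zeros.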

We would like to minimize the prediction error of Schema Networks while keeping them as simple as possible. A suitable objective function is

$$\min_{W \in \{0,1\}^{D' \times L}} \frac{1}{D} \left| y - f_W(X) \right|_1 + C \left| W \right|_1,$$

where the first term computes the prediction error, the second term estimates the complexity and parameter $C$ controls the trade-off between both. This is an NP-hard problem for which we cannot hope to find an exact solution, except for very small environments.
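The trade-off can be made concrete with a small sketch. It assumes the prediction error is the fraction of mispredicted rows and the complexity is the number of nonzero entries in the schema matrix, weighted by a trade-off parameter called `C` here; all names are illustrative.

```python
def predict(X, W):
    """A row is predicted 1 if any schema (column of W) has all its inputs on."""
    n_schemas = len(W[0]) if W else 0
    return [int(any(all(x[d] for d in range(len(x)) if W[d][l])
                    for l in range(n_schemas)))
            for x in X]


def objective(X, y, W, C):
    """Mean absolute prediction error plus C times the number of connections."""
    error = sum(abs(t - p) for t, p in zip(y, predict(X, W))) / len(y)
    complexity = sum(sum(row) for row in W)
    return error + C * complexity


X = [[1, 1, 0], [1, 0, 0], [0, 0, 1]]
y = [1, 0, 1]
W_two = [[1, 0], [1, 0], [0, 1]]   # two schemas, three connections, no errors
W_one = [[1], [1], [0]]            # one schema, two connections, one error
cheap_complexity = (objective(X, y, W_two, 0.1), objective(X, y, W_one, 0.1))
costly_complexity = (objective(X, y, W_two, 0.5), objective(X, y, W_one, 0.5))
```

With a small penalty the accurate two-schema model scores lower (better), while a large penalty favors the simpler model despite its error.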

We consider a greedy solution in which linear programming (LP) relaxations are used to find each new schema. Starting from the empty set, we greedily add schemas (columns to $W$) that have perfect precision and increase recall for the prediction of $y$ (see Algorithm 1 in the Supplementary). In each successive iteration, only the input-output pairs for which the current schema network is predicting an output of zero are passed. This procedure monotonically decreases the prediction error of the overall schema network, while increasing its complexity. The process stops when we hit some predefined complexity limit. In our implementation, the greedy schema selection produces very sparse schemas, and we simply set a limit to the number of schemas to add. For this algorithm to work, no contradictions can exist in the input data (such as the same input appearing twice with different labels). Such contradictions might appear in stochastic environments and would not be artifacts in real environments, so we preprocess the input data to remove them.
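The outer greedy loop can be illustrated with a simplified stand-in. This sketch does not use LP relaxations: it replaces them with a plain heuristic that seeds a schema from one unexplained positive example and drops preconditions while precision stays perfect. It is not the algorithm from the Supplementary, and all names are hypothetical.

```python
def learn_schemas(X, y, max_schemas=10):
    """Greedily add perfectly precise conjunctions until positives are covered."""

    def fires(schema, x):
        return all(x[d] for d in schema)

    negatives = [x for x, t in zip(X, y) if t == 0]
    uncovered = [x for x, t in zip(X, y) if t == 1]
    schemas = []
    while uncovered and len(schemas) < max_schemas:
        seed = uncovered[0]
        schema = {d for d, v in enumerate(seed) if v}
        if any(fires(schema, x) for x in negatives):
            uncovered.pop(0)  # no precise conjunction explains this row; skip
            continue
        for d in sorted(schema):
            trial = schema - {d}
            if not any(fires(trial, x) for x in negatives):
                schema = trial  # still perfectly precise, and more general
        schemas.append(sorted(schema))
        # Only rows the current schemas still predict as zero are passed on.
        uncovered = [x for x in uncovered if not fires(schema, x)]
    return schemas


X = [[1, 1, 0], [1, 0, 0], [0, 1, 0], [1, 1, 1]]
y = [1, 0, 0, 1]
learned = learn_schemas(X, y)
```

On this toy data neither input alone is a precise predictor, so the loop keeps the conjunction of inputs 0 and 1, which also covers the second positive row and ends the search after one schema.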

4.3 Planning as Probabilistic Inference

The full Schema Network graph (Fig. 2) provides a probabilistic model for the set of rewards that will be achieved by a sequence of actions. Finding the sequence of actions that will result in a given set of rewards then becomes a MAP inference problem. This problem can be addressed approximately using max-product belief propagation (MPBP) (Attias, 2003). Another option is variational inference. Cheng et al. (2013) use variational inference for planning but resort to MPBP to optimize the variational free energy functional. We will follow the first approach.

Without loss of generality, we will consider the present time step to be $t = 0$. The state, action and reward variables for $t \le 0$ are observed, and we will consider inference over the unobserved state, action and reward variables in a look-ahead window of size $T$.² Since the Schema Network is built exclusively of compatibility factors that can take values 0 or 1, any variable assignment is either impossible or equally probable under the joint distribution of the graph. Thus, if we want to know if there exists any global assignment that activates a binary variable (say, variable $r_{(+)}^{(t)}$ signaling positive reward at some future time $t > 0$), we should look at the max-marginal $\tilde{p}(r_{(+)}^{(t)} = 1)$. It will be 0 if no global assignment compatible with both the SN and existing observations can lead to activate the reward, or 1 if it is feasible. Similarly, we will be interested in the max-marginal $\tilde{p}(r_{(-)}^{(t)} = 0)$, i.e., whether it is feasible to find a configuration that avoids a negative reward.

² In contrast with MDPs, the reward is discounted with a rolling square window instead of an exponentially weighted one.
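Because every assignment is either impossible or equally probable, the max-marginal of a reward variable is simply a feasibility indicator. The sketch below computes it exactly by brute-force enumeration on a toy problem, which is not MPBP and only works for tiny horizons. The paddle dynamics, the function name, and the reward condition are all invented for illustration.

```python
from itertools import product

MOVES = {"LEFT": -1, "NOOP": 0, "RIGHT": 1}


def reward_feasible(start, target, horizon, width=5):
    """Return (1, plan) if some action sequence of length `horizon` ends with
    the paddle at `target` (the toy reward condition), else (0, None)."""
    for plan in product(MOVES, repeat=horizon):
        pos = start
        for action in plan:
            pos = min(width - 1, max(0, pos + MOVES[action]))
        if pos == target:
            return 1, plan
    return 0, None


too_short = reward_feasible(start=0, target=3, horizon=2)
long_enough = reward_feasible(start=0, target=3, horizon=3)
```

The witnessing plan returned alongside a feasible result plays the role of the MAP configuration: it contains the action values needed to activate the reward.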

At a high level, planning proceeds as follows: Identify feasible desirable states (activating positive rewards and deactivating negative rewards), clamp their value to our desires by adding a unary potential to the factor graph, and then find the MAP configuration of the resulting graph. The MAP configuration contains the values of the action variables that are required to reach our goal of activating/deactivating a variable. We can also look at the MAP values of the state variables $s^{(t)}$ to see how the model “imagines” the evolution of the entities until they reach their goal state. Then we perform the actions found by the planner and repeat. We now explain each of these stages in more detail.

Potential feasibility analysis

First we run a feasibility analysis. To this end, a forward pass of MPBP from time $0$ to time $T$ is performed. This provides a (coarse) approximation to the desired max-marginals for every variable. Because the SN graph is loopy, MPBP is not exact and the forward pass can be too optimistic, announcing the feasibility of states that are unfeasible.³ Actual feasibility will be verified later, at the backtracking stage.

³ To illustrate the problem, consider the case in which it is feasible for an entity to move at time $t$ to position A or position B (but obviously not both) and then some reward is conditioned on that type of entity being in both positions: A single forward pass will not handle the entanglement properly and will incorrectly report that such reward is also feasible.
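The over-optimism of a forward pass can be reproduced in a toy setting with one entity that can move to one of two positions. The per-variable bookkeeping below is a deliberately crude stand-in for a forward message pass, not an implementation of MPBP, and the names are hypothetical.

```python
# Each action moves the single entity to exactly one position.
ACTIONS = {"LEFT": "A", "RIGHT": "B"}


def optimistic_reward_feasible():
    """Track, per variable, whether SOME assignment can activate it, then
    combine the per-variable answers as if they were independent."""
    can_be_at = {pos: any(dest == pos for dest in ACTIONS.values())
                 for pos in ("A", "B")}
    return int(can_be_at["A"] and can_be_at["B"])


def exact_reward_feasible():
    """Joint check: one action must place the entity at A and B at once."""
    return int(any({dest} >= {"A", "B"} for dest in ACTIONS.values()))


optimistic = optimistic_reward_feasible()
exact = exact_reward_feasible()
```

The independent check reports the reward as reachable because each position is reachable on its own, while the joint check correctly rejects it. This gap is why candidate goals still need to be verified by backtracking.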

Choosing a positive reward goal state

We will choose the potentially feasible positive reward that happens soonest within our explored window, clamp its state to 1, and backtrack (see below) to find the set of actions that lead to it. If backtracking fails, we will repeat for the remaining potentially feasible positive rewards.

Avoiding negative rewards

Keeping the selected positive reward variable clamped to 1 (if it was found in the previous step), we now repeat the same procedure on the negative rewards. Among the negative rewards that have been found as potentially feasible to turn off, we clamp to zero as many negative rewards as we can find a jointly satisfying backtrack. If no positive reward was feasible, we backtrack from the earliest predicted negative reward.


Backtracking

This step is akin to Viterbi backtracking, a message passing backward pass that finds a satisfying configuration. Unlike the HMM for which the Viterbi algorithm was designed, our model is loopy, so a standard backward pass is not enough to find a satisfying configuration (although it can help to find good candidates). We combine the standard backward pass with a depth-first search algorithm to find a satisfying configuration.
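The depth-first part of this search can be sketched as plain regression from a goal, without the message-passing backward pass. The sketch assumes a deterministic toy model, ignores self-transitions, and represents each schema as a hypothetical `(preconditions, action, effect)` triple.

```python
def backtrack(goal, t, state0, schemas, plan):
    """Return a dict {time: action} that activates `goal` at time t,
    or None if no consistent chain of schemas reaches the observed state."""
    if t == 0:
        return plan if state0.get(goal, 0) else None
    for pre, action, effect in schemas:
        if effect != goal:
            continue
        # Reject schemas whose action clashes with one already chosen.
        if action is not None and plan.get(t - 1, action) != action:
            continue
        candidate = dict(plan)
        if action is not None:
            candidate[t - 1] = action
        for p in pre:
            candidate = backtrack(p, t - 1, state0, schemas, candidate)
            if candidate is None:
                break  # this cause is unreachable; try the next schema
        if candidate is not None:
            return candidate
    return None


schemas = [(("x2",), "LEFT", "x1"),
           (("x0",), "RIGHT", "x1"),
           (("x1",), "RIGHT", "x2")]
observed = {"x0": 1}
plan = backtrack("x2", 2, observed, schemas, {})
```

The search first tries to explain position `x1` by a leftward move from `x2`, finds that this contradicts the observed start state, and falls back to the rightward move from `x0`. This is the disentangling of alternative causes that regression planning relies on.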

5 Experiments

Figure 5: Comparison of learning rates. (a) Schema Networks and A3C were trained for 100k frames in Mini Breakout. Plot shows the average of 5 training attempts for Schema Networks and the best of 5 training attempts for A3C, which did not converge as reliably. (b) PNs and Schema Networks were pretrained on 100K frames of Standard Breakout, and then training continued on 45K additional frames of the Middle Wall variation. We show performance as a function of training frames for both models. Note that Schema Networks are ignoring all the additional training data, since all the required schemas were learned during pretraining. For Schema Networks, zero-shot transfer learning is happening.

We compared the performance of Schema Networks, A3C, and PNs (Progressive Networks) on several variations of the game Breakout. The chosen variations all share similar dynamics, but the layouts change, requiring different policies to achieve high scores. A diverse set of concepts must be learned to correctly predict object movements and rewards. For example, when predicting why rewards occur, the model must disentangle possible causes to discover that reward depends on the color of a brick but is independent of the ball’s velocity and position where it was hit. While these causal relationships are straightforward for humans to recover, we have yet to see any existing approach for learning a generative model that can recover all of these dynamics without supervision and transfer them effectively.

Schema Networks rely on an input of entity states instead of raw images, and we provided the same information to A3C and PNs by augmenting the three color channels of the image with 34 additional channels. Four of these channels indicated the shape to which each pixel belongs, including shapes for bricks, balls, and walls. Another 30 channels indicated the positions of parts of the paddle, where each part consisted of a single pixel. To reduce training time, we did not provide A3C and PN with part channels for objects other than the paddle, since these are not required to learn the dynamics or predict scores. Removing irrelevant inputs could only give A3C and PN an advantage, since the input to Schema Networks did not treat any object differently. Schema Networks were provided separate entities for each part (pixel) of each object, and each entity contained 53 attributes corresponding to the available part labels (21 for bricks, 30 for the paddle, 1 for walls, and 1 for the ball). Only one of these part attributes was active per entity. Schema Networks had to learn that some attributes, like parts of bricks, were irrelevant for prediction.
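The channel and attribute counts quoted above can be checked with a short sketch that also builds a one-hot part vector. The ordering of the attribute blocks and the names used here are assumptions for illustration; only the counts come from the text.

```python
# Counts quoted in the text: 21 brick parts, 30 paddle parts, 1 wall, 1 ball.
PART_COUNTS = {"brick": 21, "paddle": 30, "wall": 1, "ball": 1}
NUM_ATTRIBUTES = sum(PART_COUNTS.values())

# Input given to A3C and PNs: 3 color channels + 4 shape + 30 paddle parts.
INPUT_CHANNELS = 3 + 4 + 30


def one_hot_part(kind, part_index):
    """Binary attribute vector with exactly one active part label.
    The block order (brick, paddle, wall, ball) is an arbitrary choice."""
    offset = 0
    for name, count in PART_COUNTS.items():
        if name == kind:
            if not 0 <= part_index < count:
                raise ValueError("part index out of range")
            vec = [0] * NUM_ATTRIBUTES
            vec[offset + part_index] = 1
            return vec
        offset += count
    raise KeyError(kind)


ball = one_hot_part("ball", 0)
```

The per-object counts sum to the 53 attributes per entity, and the 34 extra channels plus the three color channels give 37 input channels for the baselines.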

5.1 Transfer Learning

This experiment examines how effectively Schema Networks and PNs learn a new Breakout variation after pretraining, that is, how well the two models transfer existing knowledge to a new task. Fig. 5a shows the learning rates during 100k frames of training on Mini Breakout. In a second experiment, we pretrained on Large Breakout for 100k frames and continued training on the Middle Wall variation, shown in Fig. 1b. Fig. 5b shows that PNs require significant time to learn in this new environment, while Schema Networks do not learn anything new because the dynamics are the same.

5.2 Zero-Shot Generalization

                     | Standard Breakout | Offset Paddle | Middle Wall | Random Target | Juggling
A3C Image Only       | N/A               |               |             |               |
A3C Image + Entities | N/A               |               |             |               |
Schema Networks      |                   |               |             |               |
Table 1: Zero-Shot Average Score per Episode

Average of the 2 best out of 5 training attempts for A3C, and average of 5 training attempts for Schema Networks. A3C was trained on 200k frames of Standard Breakout (hence its zero-shot scores for Standard Breakout are unknown) while Schema Networks were trained on 100k frames of Mini Breakout. Episodes were limited to 2500 frames for all variations. In every case the average Schema Network scores are better than the best A3C scores by more than one standard deviation.

Many Breakout variations can be constructed that all involve the same dynamics. If a model correctly learns the dynamics from one variation, in theory the others could be played perfectly by planning using the learned model. Rather than comparing transfer with additional training using PNs, in these variations we can compare zero-shot generalization by training A3C only on Standard Breakout. Fig. 1b-e shows some of these variations with the following modifications from the training environment:

  • Offset Paddle (Fig. 1d): The paddle is shifted upward by a few pixels.

  • Middle Wall (Fig. 1b): A wall is placed in the middle of the screen, requiring the agent to aim around it to hit the bricks.

  • Random Target (Fig. 1e): A group of bricks is destroyed when the ball hits any of them and then reappears in a new random position, requiring the agent to deliberately aim at the group.

  • Juggling (Fig. 1f, enlarged from actual environment to see the balls): Without any bricks, three balls are launched in such a way that a perfect policy could juggle them without dropping any.

Table 1 shows the average scores per episode in each Breakout variation. These results show that A3C has failed to recognize the common dynamics and adapt its policy accordingly. This comes as no surprise, as the policy it has learned for Standard Breakout is no longer applicable in these variations. Simply adding an offset to the paddle is sufficient to confuse A3C, which has not learned the causal nature of controlling the paddle with actions and controlling the ball with the paddle. The Middle Wall and Random Target variations illustrate that Schema Networks are aiming to deliberately cause positive rewards from ball-brick collisions, while A3C struggles to adapt its policy accordingly. The Juggling variation is particularly challenging, since it is not clear which ball to respond to unless the model understands that the lowest downward-moving ball is the most imminent cause of a negative reward. By learning and transferring the correct causal dynamics, Schema Networks outperform A3C in all variations.
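The Juggling observation can be made concrete with a small sketch. This is not how Schema Networks choose actions (they plan through the learned model); it only illustrates the rule that a successful agent must implicitly respect. The `Ball` type and the screen-coordinate convention are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Ball:
    x: float
    y: float   # assumed screen coordinates: larger y is closer to the bottom
    vy: float  # assumed convention: positive vy means moving downward


def most_imminent_ball(balls: List[Ball]) -> Optional[Ball]:
    """Return the lowest ball that is moving downward, or None if there is none."""
    downward = [b for b in balls if b.vy > 0]
    if not downward:
        return None
    return max(downward, key=lambda b: b.y)


balls = [Ball(10, 50, -1), Ball(20, 40, 1), Ball(30, 30, 1)]
target = most_imminent_ball(balls)  # the ball at y=40; the one at y=50 is rising
```

A policy that simply tracks the lowest ball regardless of direction would chase the rising ball at y=50 here and miss the one that is about to be dropped.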

Half Negative Bricks
A3C Image Only
A3C Image + Entities
Schema Networks
Table 2: Average Score per Episode on Half Negative Bricks

A3C and Schema Networks were trained on 200k frames of Random Negative Bricks, both to convergence. Testing episodes were limited to 1000 frames. Negative rewards are sometimes unavoidable, resulting in higher variance for all methods.

5.3 Testing for Learned Causes

To better evaluate whether these models are truly learning the causes of rewards, we designed one more zero-shot generalization experiment. We trained both Schema Networks and A3C on a Mini Breakout variation in which the color of a brick determines whether a positive or negative reward is received when it is destroyed. Six colors of bricks provide +1 reward, and two colors provide -1 reward. Negative bricks occurred in random positions 33% of the time during training. Then during testing, the bricks were arranged into two halves, with all positive colored bricks on one half and negative colored bricks on the other (see Fig. 1c). If the causes of rewards have been correctly learned, the agent should prefer to aim for the positive half whenever possible. As Table 2 shows, Schema Networks have correctly learned from random arrangements which brick colors cause which rewards, preferring to aim for the positive half during testing, while A3C demonstrates no preference for one half or the other, achieving an average score near zero.
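The logic of this test can be illustrated with a toy sketch. It is a tabular stand-in, not the schema learning procedure: it attributes reward to brick color rather than brick position, which is the property that lets knowledge from random arrangements carry over to the half-and-half layout. The color names and function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# Hypothetical color names: six positive colors and two negative ones.
POSITIVE = ["red", "orange", "yellow", "green", "blue", "purple"]
NEGATIVE = ["gray", "brown"]


def learn_color_rewards(events: Iterable[Tuple[str, int]]) -> Dict[str, float]:
    """Average observed reward per brick color, ignoring brick position."""
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for color, reward in events:
        totals[color] += reward
        counts[color] += 1
    return {color: totals[color] / counts[color] for color in totals}


def preferred_half(left: List[str], right: List[str],
                   color_reward: Dict[str, float]) -> str:
    """Pick the half whose bricks have the larger total expected reward.

    Ties go to "right"; colors never seen in training count as 0.
    """
    left_value = sum(color_reward.get(c, 0.0) for c in left)
    right_value = sum(color_reward.get(c, 0.0) for c in right)
    return "left" if left_value > right_value else "right"


# Training events: (color, reward) pairs observed when bricks are destroyed.
events = [(c, +1) for c in POSITIVE] + [(c, -1) for c in NEGATIVE]
rewards = learn_color_rewards(events)
```

Because the learned table never refers to position, swapping which half holds the positive bricks flips the preference without any retraining.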

6 Discussion and Conclusion

In this work we have demonstrated the promise of Schema Networks with strong performance on a suite of Breakout variations. Instead of learning policies to maximize rewards, the learning objective for Schema Networks is designed to understand causality within these environments. The fact that Schema Networks are able to achieve rewards more efficiently than state-of-the-art model-free methods like A3C is all the more notable, since high scores are a byproduct of learning an accurate model of the game.

The success of Schema Networks is derived in part from the entity-based representation of state. Our results suggest that providing Deep RL models like A3C with such a representation as input can improve both training efficiency and generalization. This finding corroborates recent attempts (Usunier et al., 2016; Garnelo et al., 2016; Chang et al., 2016; Battaglia et al., 2016) to incorporate object and relational structure into neural network-based models.

The environments considered in this work are conceptually diverse but also simplified in a number of ways with respect to the real world: states, actions, and rewards are all discretized as binary random variables; the dynamics of the environments are deterministic; and there is no uncertainty in the observed entity states. In future work we plan to address each of these limitations, adapting Schema Networks to continuous, stochastic domains.

Schema Networks have shown promise toward multi-task transfer where Deep RL struggles. This transfer is enabled by explicit causal structures, which in turn allow for planning in novel tasks. As progress in RL and planning continues, robust generalization from limited experience will be vital for future intelligent systems.


Special thanks to Eric Purdy and Ramki Gummadi for useful insights and discussions during the preparation of this work.


Appendix A Breakout playing visualizations

See for visualizations of Schema Networks playing different variations of Breakout after training only on basic Breakout.

Figure 8 shows typical gameplay for one variation.

(a) A3C
(b) Schema Networks
Figure 8: Screen average during the course of gameplay in the Middle Wall variation when using A3C and Schema Networks. Schema Networks are able to purposefully avoid the middle wall most of the time, whereas A3C struggles to score any points.

Appendix B LP-based Greedy Schema Learning

The details of the LP-based learning of Section 4.2 are provided here.

Input:  Input vectors for which the current schema network predicts 0, and the corresponding output scalars
1:  Find a cluster of input samples that can be solved with a single (relaxed) schema while keeping perfect precision (no false alarms). Select an input sample and put it in the set “solved”, then solve the LP.
2:  Simplify the resulting schema. Put all the input samples that the resulting schema solves in the set “solved”. Simplify the just found schema by making it as sparse as possible while keeping the same precision and recall.
3:  Binarize the schema. In practice, the found schema is binary most of the time. If it is not, repeat the previous minimization using binary programming, but optimize only over the elements of the schema that were found to be non-zero. Keep the rest clamped to zero.
Output:  New schema to add to the network
Algorithm 1 LP-based greedy schema learning
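
The overall shape of the greedy loop can be sketched in a few lines of Python. This sketch deliberately swaps the LP for a simple heuristic: it starts from the most specific conjunction covering one positive sample and then drops preconditions one at a time while no negative sample is covered. It also assumes that inputs are binary vectors, that a schema fires when all of its selected preconditions are active, and that the network output is the OR of its schemas. All function names are hypothetical.

```python
from typing import List, Optional, Sequence

Vector = Sequence[int]  # binary attribute vector


def fires(w: Vector, x: Vector) -> bool:
    """A schema fires when every precondition it selects is active in x."""
    return all(xd == 1 for wd, xd in zip(w, x) if wd == 1)


def network_predicts(schemas: List[Vector], x: Vector) -> bool:
    """Assumed network semantics: OR over all schemas."""
    return any(fires(w, x) for w in schemas)


def learn_one_schema(xs: List[Vector], ys: List[int],
                     schemas: List[Vector]) -> Optional[List[int]]:
    """Heuristic stand-in for one pass of the LP-based procedure."""
    # Input: only samples for which the current network predicts 0.
    pending = [(x, y) for x, y in zip(xs, ys)
               if not network_predicts(schemas, x)]
    positives = [x for x, y in pending if y == 1]
    negatives = [x for x, y in pending if y == 0]
    if not positives:
        return None  # nothing left to explain

    # Step 1 (replaces the LP): most specific schema covering one sample.
    w = list(positives[0])
    if any(fires(w, x) for x in negatives):
        return None  # seed cannot be separated from a negative sample

    # Step 2: sparsify. Dropping a precondition can only widen coverage,
    # so recall is kept; a drop is undone if it causes a false alarm.
    for d in range(len(w)):
        if w[d] == 1:
            w[d] = 0
            if any(fires(w, x) for x in negatives):
                w[d] = 1

    # Step 3: w is binary by construction, so no binarization is needed.
    return w


xs = [[1, 1, 0], [1, 1, 1], [0, 1, 1], [1, 0, 1]]
ys = [1, 1, 0, 0]
schema = learn_one_schema(xs, ys, [])
```

On this toy data the learned schema keeps the first two attributes as preconditions, which covers both positive samples and neither negative one. Unlike the LP, the heuristic gives no guarantee about how many positives a single schema covers, and its result depends on the order in which preconditions are dropped.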