Causal reasoning is an integral part of natural intelligence. The capacity to reason about cause and effect has been observed in humans and other intelligent animals as a means of survival (blaisdell2006causal; taylor2008new). Such capacity plays a crucial role for young children in their interaction with the physical world. As behavioral psychology studies have shown, young children discover the underlying causal mechanisms from their play with the world (schulz2007serious), and their knowledge of causality in turn facilitates their subsequent learning of objects, concepts, languages, and physics (rehder2003categorization; corrigan1996causal).
Today, data-centric methods in artificial intelligence, such as deep networks, have achieved tremendous success in learning associations between inputs and outputs from large amounts of data, such as mapping images to class labels (he2015deep). However, empirical evidence indicates that the absence of correct causal modeling in these methods poses a major threat to generalization, causing image captioning models to generate unrealistic captions (lake2017building), deep reinforcement learning policies to fail in novel problem instances (edmonds2019human), and transfer learning models to adapt more slowly to new distributions (bengio2019meta).
In this work, we propose to endow a learning-based interactive agent with the capacity for causal reasoning so that it can complete goal-directed tasks in visual environments. Imagine that a household robot enters a new home for the first time. Without prior knowledge of the wiring configuration, it has to toggle the switches and sort out the correspondences between lights and switches before it can be commanded to turn on the kitchen light or the bathroom light. We refer to the first phase, toggling switches, as causal induction, in which the agent discovers the latent cause-and-effect relations by performing actions and observing their outcomes. We refer to the second phase, turning on specific lights, as causal inference, in which the agent uses the acquired causal relations to guide its actions toward completing a task. To build an effective computational model for causal induction and inference, we have to address generalization to novel causal relations and new task goals at test time, neither of which may have been seen during training.
Following dasgupta2019causal, we cast this as a two-phase meta-learning problem. In the first phase, we use a causal induction model to construct a causal structure, i.e., a directed acyclic graph of random variables, from the observational data produced by an agent's interventions. In the second phase, we use the causal structure to contextualize a goal-conditioned policy that performs the task specified by a goal. In contrast to dasgupta2019causal, however, we explicitly construct the causal structure instead of a latent feature encoding. This leads to substantially better generalization to new problem instances in long-horizon tasks, as opposed to simple one-step querying.
To this end, we make two technical contributions: 1) an iterative causal induction model with attention, which learns to incrementally update the predicted causal graph after each observed interaction in the environment, and 2) a goal-conditioned policy with an attention-based graph encoding, which forces it to focus on the relevant components of the causal graph at each step. We find that by factorizing the induction and inference processes through causal graphs, our approach generalizes well to unseen causal structures given as few as 50 training causal structures. We compare our approach against three alternatives: using the ground-truth causal structure (which provides oracle performance), a non-iterative architecture that directly predicts the causal structure, and encoding the observational data into the LSTM (lstm) memory of the policy, similar to prior work (dasgupta2019causal). We demonstrate that our method outperforms the baselines and achieves close to oracle performance, both in recovering the causal graph and in the success rate of completing goal-conditioned tasks. This holds across several task sizes, task types, and numbers of training causal structures.
2 Problem Statement
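We formulate the agent's interaction as a Goal-Conditioned Markov Decision Process (MDP) defined by a six-tuple $(\mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{G}, r, \gamma)$. Here $\mathcal{S}$ is the state space, $\mathcal{A}$ is the action space, $\mathcal{P}(s_{t+1} \mid s_t, a_t)$ defines the transition dynamics, and $\mathcal{G}$ is the goal space. The reward function is $r: \mathcal{S} \times \mathcal{A} \times \mathcal{G} \rightarrow \mathbb{R}$, where $r(s, a, g)$ gives the one-step immediate reward conditioned on the goal $g \in \mathcal{G}$, and $\gamma$ is the discount factor. Our goal is to learn a goal-conditioned policy $\pi(a \mid s, g)$ that maximizes the expected sum of discounted rewards $\mathbb{E}_{\pi}\left[\sum_{t} \gamma^{t} r(s_t, a_t, g)\right]$.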
In this work, we not only want $\pi$ to generalize well across goals in $\mathcal{G}$, but also pursue the more ambitious aim of making $\pi$ generalize across a set of MDPs. We consider $\mathcal{M}$, the set of all MDPs that share the state space $\mathcal{S}$ and action space $\mathcal{A}$ but differ in their transition dynamics, where each $M \in \mathcal{M}$ is defined by $(\mathcal{S}, \mathcal{A}, \mathcal{P}_M, \mathcal{G}, r, \gamma)$. The dynamics $\mathcal{P}_M$ determine the underlying causal relations between states and actions: taking the same action in the same state can lead to a different next state under different dynamics. We expect our agent to operate on its first-person vision, with no access to the latent causal relations. It receives high-dimensional RGB observations and has to induce a causal model from observational data.

As illustrated in Figure 1, the overall procedure has two stages:
1) We execute an interaction policy $\pi_I$ to collect a sequence of transitions $H = \{(s_t, a_t)\}_{t=1}^{T}$. This sequence is consumed by an induction model $\mathcal{F}$ to construct a latent causal model $\hat{C} = \mathcal{F}(H)$.
2) We use the causal model $\hat{C}$ to contextualize a goal-conditioned policy $\pi_G(a \mid s, g, \hat{C})$ so that it can perform tasks in the new MDP with novel causal relations.

We formulate this as a meta-learning problem (dasgupta2019causal; maml). We partition $\mathcal{M}$ into two disjoint sets, $\mathcal{M}_{\text{train}}$ and $\mathcal{M}_{\text{test}}$. During training, we learn the induction model $\mathcal{F}$ and the goal-conditioned policy $\pi_G$ on $\mathcal{M}_{\text{train}}$. During testing, we take a novel MDP from $\mathcal{M}_{\text{test}}$ and evaluate whether $\mathcal{F}$ can learn from the observational data collected by $\pi_I$ to construct a causal model that $\pi_G$ can use to perform tasks in this new MDP.
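As a toy illustration of the disjoint train/test partition, the sketch below makes several assumptions. It represents each causal structure as a switch-to-light adjacency matrix, with `C[i][j] = 1` meaning switch i controls light j (an assumed encoding). It enumerates only One-to-One structures, and it splits them with a seeded shuffle. The helper names are hypothetical and not taken from the authors' code.

```python
import itertools
import random


def one_to_one_structures(n):
    """Enumerate every One-to-One structure over n switches and n lights.

    C[i][j] == 1 means switch i controls light j (an assumed convention).
    """
    structures = []
    for perm in itertools.permutations(range(n)):
        structures.append([[1 if perm[i] == j else 0 for j in range(n)] for i in range(n)])
    return structures


def split_structures(structures, n_train, seed=0):
    """Shuffle the structures and partition them into disjoint train/test sets."""
    order = list(range(len(structures)))
    random.Random(seed).shuffle(order)
    train = [structures[k] for k in order[:n_train]]
    test = [structures[k] for k in order[n_train:]]
    return train, test


structures = one_to_one_structures(5)  # 5! = 120 distinct structures
train, test = split_structures(structures, n_train=50)
```

Because every structure lands in exactly one split, any structure evaluated at test time is novel to both the induction model and the policy.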
Direct modeling of causal relations in raw pixel space is intractable due to its high dimensionality. Following chalupka2014visual, we assume that causes and effects in our problems can be defined on a handful of causal macro-variables. For example, the illuminance of the kitchen (part of the agent's visual observation) is determined by the on/off state of the kitchen light (a macro-variable). That state is in turn caused by toggling the switch that controls the light (another macro-variable). This assumption enables us to construct a directed acyclic causal model that represents the causal effects of actions on these macro-variables. Given the set of macro-variables, the induction model predicts the directed edges between them from visual observations. A primary challenge here is confounding that arises from partial observability and spurious correlations in the agent's visual perception. For example, a change in illuminance in the kitchen might be due to turning the kitchen light or the living room light on or off. Hence, the agent must disentangle the correct causal relations from its visual inputs.
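The toy example below shows why overlapping illumination confounds a single region of the image, while the residual between consecutive frames still localizes which light changed. Its region footprints are made up for illustration and are not taken from the MuJoCo scene.

```python
# Hypothetical footprints: FOOTPRINT[j][r] is how much light j brightens floor region r.
# Regions: 0 = kitchen, 1 = hallway, 2 = living room.
FOOTPRINT = [
    [1.0, 0.5, 0.0],  # kitchen light, which also spills into the hallway
    [0.0, 0.5, 1.0],  # living-room light, which also spills into the hallway
]


def render(lights):
    """Map light macro-variables (0/1) to per-region floor brightness (a tiny 'image')."""
    n_regions = len(FOOTPRINT[0])
    return [sum(FOOTPRINT[j][r] * lights[j] for j in range(len(lights))) for r in range(n_regions)]


# The hallway region alone is confounded: either light produces the same brightness there.
hallway_if_kitchen_on = render([1, 0])[1]
hallway_if_living_on = render([0, 1])[1]

# The residual between consecutive observations localizes the change: most of it falls in
# the kitchen region, pointing to light 0 rather than light 1.
residual = [after - before for before, after in zip(render([0, 0]), render([1, 0]))]
```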
The goal of our method is to enable a policy to complete goal-conditioned, vision-based control tasks in environments with unseen causal structures, given only a short trajectory of observational data from the environment. Prior work (dasgupta2019causal) has shown promising results on simple one-step querying problems using an LSTM-based policy that encodes the interaction into its memory. Our hypothesis is that generalizing in complex multi-step control problems requires a more structured induction and policy scheme. To address this, we propose iterative updates with attention in the induction model and an attention bottleneck in the policy, and we show that both significantly improve generalization to unseen causal structures.
3.1 Iterative Causal Induction Network
Inducing the causal structure from raw sensory observations requires accurately capturing the unique effect of each action on the environment while accounting for the confounding effects of other actions. We hypothesize that the causal induction network that generalizes best is one that disentangles individual actions and their corresponding effects, and only updates the relevant components of the causal graph.
We implement this idea in our iterative model. We begin with an initial guess of the causal structure, $\hat{C}_0$, in which all edge weights are 0 (i.e., we assume no causal relationships). We then use an Observation Encoder to map each image $s_t$ of the observational data to an encoding $e_t$ and compute the state residual $\Delta e_t = e_{t+1} - e_t$ between subsequent time steps. This residual, which captures the change in state, is concatenated with the corresponding action $a_t$ and fed into the Edge Decoder module. The output of the Edge Decoder is an update to the edge strengths of the causal graph $\hat{C}$. This update is applied for each observed transition $(s_t, a_t, s_{t+1})$. At the final layer, the whole graph is encoded to make a final edge update before the causal graph $\hat{C}$ is predicted (see Figure 2). The Edge Decoder takes either the encoded $\Delta e_t$ and $a_t$, or the encoded edge matrix of the current graph $\hat{C}$. It outputs a soft attention vector $\alpha_t$ and a change to the edge weights $\delta_t$, whose sizes are determined by $N$, the number of actions in the environment. The attention vector weights which nodes in the causal graph the edge update is applied to. Thus, at each iteration, the Transition Encoder, a fully connected module (see Appendix A for details), maps the encoded $(\Delta e_t, a_t)$ to $(\alpha_t, \delta_t)$, and $\alpha_t$ determines which entries of $\hat{C}$ receive the update $\delta_t$. Using this attention mechanism further encourages the module to make independent updates, which we observe enables better generalization.
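The loop below is a minimal sketch of the iterative induction procedure. It assumes an additive outer-product update $\hat{C} \leftarrow \hat{C} + \delta_t \alpha_t^{\top}$, where $\delta_t$ ranges over switches (causes) and $\alpha_t$ over lights (effects). This update form is an assumption, not the paper's stated equation. A hand-written `toy_decoder` (hypothetical) stands in for the learned Transition Encoder, encodings are taken to be the light states themselves, and the final whole-graph encoding step is omitted.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def iterative_induction(encodings, actions, edge_decoder, n):
    """ICIN-style loop (a sketch under an assumed additive update rule).

    encodings: T+1 state encodings e_0..e_T (lists of numbers)
    actions: T one-hot action vectors a_0..a_{T-1}
    edge_decoder: maps the concatenated [residual, action] to (attention_logits, edge_logits)
    """
    C = [[0.0] * n for _ in range(n)]  # initial guess: no causal edges
    for t, action in enumerate(actions):
        residual = [after - before for before, after in zip(encodings[t], encodings[t + 1])]
        attn_logits, edge_logits = edge_decoder(residual + action)
        alpha = softmax(attn_logits)               # attention over effects (lights)
        delta = [sigmoid(z) for z in edge_logits]  # proposed strengths over causes (switches)
        for i in range(n):
            for j in range(n):
                C[i][j] += delta[i] * alpha[j]     # assumed outer-product update
    return C


def toy_decoder(x, n=3):
    """Hand-written stand-in for the learned Transition Encoder."""
    residual, action = x[:n], x[n:]
    attn_logits = [10.0 * abs(r) for r in residual]     # attend to the light that changed
    edge_logits = [10.0 * (2 * a - 1) for a in action]  # credit the switch that was pressed
    return attn_logits, edge_logits


# One-to-One world: switch 0 -> light 2, switch 1 -> light 0, switch 2 -> light 1.
encodings = [[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 1, 1]]
actions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
C_hat = iterative_induction(encodings, actions, toy_decoder, n=3)
```

Because the softmax concentrates each update on a single effect, each observation writes essentially one edge. This mirrors the "independent updates" argument above.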
3.2 Learning Goal-Conditioned Policies
Given an initial image $s_0$, a goal image $g$, and the predicted causal structure $\hat{C}$, the objective of the policy $\pi_G$ is to reach the goal within a fixed number of time steps. Moreover, the policy is reactive, so it can only solve the goal-conditioned task if it learns to use the predicted causal structure $\hat{C}$. That is, since the policy has no memory, it cannot learn to induce the graph internally at inference time and therefore must rely on $\hat{C}$.
We hypothesize that, as with the causal induction model, the policy that generalizes best is one that learns to focus exclusively on the edges of the causal graph that are relevant to the current step of the task. To that end, we propose an attention bottleneck in the graph encoding. It encourages the policy to select the edges pertaining to one “effect” at each step, which enables better generalization.
Specifically, the policy encodes the current image $s_t$ and the goal image $g$. Based on this encoding, it outputs an attention vector $\beta$ of size $N$ over the “effects” in the causal graph. This vector is used to compute a weighted sum over the causal graph ($N$ causes, $N$ effects, and the edges between them), resulting in a size-$N$ vector of selected edges $\hat{c}$. The selected edges $\hat{c}$ and the visual encoding are then passed through fully connected layers to output the final action (see Figure 3). The visual encoder has the same architecture as the image encoder of the induction model, except that it encodes the current and goal images jointly.
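The attention bottleneck reduces to a softmax-weighted sum over the effect axis of the graph, as the sketch below shows. It makes three simplifications. The attention logits are hand-set here, whereas in the model they come from a fully connected layer on the joint image encoding. The argmax is a stand-in for the policy's action layers. The graph convention `C[i][j]` = strength of switch i → light j is assumed.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_edges(C, attn_logits):
    """Attention bottleneck: selected[i] = sum_j beta_j * C[i][j].

    beta attends over the lights (effects), so the result has one entry per switch (cause).
    """
    beta = softmax(attn_logits)
    return [sum(b * w for b, w in zip(beta, row)) for row in C]


C_hat = [[0.0, 0.0, 1.0],   # switch 0 -> light 2
         [1.0, 0.0, 0.0],   # switch 1 -> light 0
         [0.0, 1.0, 0.0]]   # switch 2 -> light 1

# Suppose the image/goal encoding decides that light 0 is the one to change next.
selected = select_edges(C_hat, attn_logits=[8.0, 0.0, 0.0])
action = max(range(len(selected)), key=selected.__getitem__)  # stand-in action head
```

The non-attention variant would instead flatten all of `C_hat` and feed every edge to the action layers, leaving the network to discover on its own which entries matter at the current step.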
3.3 Model Training
The induction model $\mathcal{F}$ is trained with supervised learning on the limited set of training environments to minimize a reconstruction loss between the ground-truth causal graph $C$ and the predicted causal graph $\hat{C}$.
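The form of the reconstruction loss is not specified above. As one plausible choice, which is an assumption, a per-edge mean squared error between the predicted and ground-truth edge matrices can be written as follows. A per-edge binary cross-entropy would be an equally reasonable alternative given the Sigmoid outputs.

```python
def graph_reconstruction_loss(pred, target):
    """Mean squared error over all edges of a graph (an assumed loss choice)."""
    squared_errors = [
        (p - t) ** 2
        for pred_row, target_row in zip(pred, target)
        for p, t in zip(pred_row, target_row)
    ]
    return sum(squared_errors) / len(squared_errors)


ground_truth = [[1, 0], [0, 1]]
prediction = [[0.5, 0.0], [0.0, 1.0]]
loss = graph_reconstruction_loss(prediction, ground_truth)  # 0.25 / 4 = 0.0625
```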
The policy $\pi_G$ is trained with the DAgger (dagger) algorithm to imitate a planner that uses the ground-truth causal graph in the training environments. $\pi_G$ is then tested in unseen environments with only visual inputs and goal images. Specifically, in the training environments, the planner uses the ground-truth graph $C$ and privileged low-dimensional state and goal information to compute an optimal plan to the goal. At each time step, the expert's action is added to the policy's memory buffer. The policy is then trained with a standard cross-entropy loss to imitate the expert given the current image and the goal image. The policy is also injected with $\epsilon$-greedy noise during training.
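The toy sketch below illustrates the DAgger mechanics described above. It departs from the actual method in several ways. It uses a single fixed One-to-One structure, and low-dimensional light vectors stand in for images. A lookup table replaces the neural policy trained with cross-entropy, and it does not condition on the causal graph as $\pi_G$ does. The expert and dynamics are assumed toy versions, and all names are hypothetical.

```python
import random
from collections import Counter, defaultdict


def step(lights, switch, C):
    """One-to-One dynamics: pressing a switch flips the light it controls."""
    j = C[switch].index(1)
    new_lights = list(lights)
    new_lights[j] = 1 - new_lights[j]
    return tuple(new_lights)


def expert_action(lights, goal, C):
    """Privileged planner: find a light that differs from the goal and press its switch."""
    for j, (current, target) in enumerate(zip(lights, goal)):
        if current != target:
            return next(i for i in range(len(C)) if C[i][j] == 1)
    return None  # already at the goal


class TablePolicy:
    """Stand-in learner: majority vote over the expert labels collected for (state, goal)."""

    def __init__(self, n):
        self.n = n
        self.table = defaultdict(Counter)

    def fit(self, memory):
        self.table.clear()
        for state, goal, label in memory:
            self.table[(state, goal)][label] += 1

    def act(self, state, goal, rng):
        votes = self.table.get((state, goal))
        return votes.most_common(1)[0][0] if votes else rng.randrange(self.n)


def dagger(C, episodes, horizon, epsilon, seed=0):
    n = len(C)
    rng = random.Random(seed)
    policy, memory = TablePolicy(n), []
    for _ in range(episodes):
        state = tuple(rng.randint(0, 1) for _ in range(n))
        goal = tuple(rng.randint(0, 1) for _ in range(n))
        for _ in range(horizon):
            label = expert_action(state, goal, C)
            if label is None:
                break
            memory.append((state, goal, label))  # expert labels the learner's own state
            if rng.random() < epsilon:
                action = rng.randrange(n)        # epsilon-greedy exploration noise
            else:
                action = policy.act(state, goal, rng)
            state = step(state, action, C)
        policy.fit(memory)                       # retrain on the aggregated dataset
    return policy


C = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
policy = dagger(C, episodes=1000, horizon=6, epsilon=0.1)
```

The defining DAgger step is that the expert labels the states the learner actually visits, rather than only the states along the expert's own trajectories.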
Network architectures and additional training details can be found in the appendix.
Through our experiments, we investigate three complementary questions: 1) Does our iterative induction network enable better causal graph induction? 2) Does the attention bottleneck in the graph encoding of $\pi_G$ enable the policy to generalize better to unseen causal structures? 3) By combining our proposed $\mathcal{F}$ and $\pi_G$, can we outperform the current state of the art (dasgupta2019causal) on visual goal-directed tasks?
4.1 Experimental Setup
Task Definition: We examine the multi-step task of light switch control. In particular, an agent controls $N$ switches (macro-variables), which have some underlying causal structure governing how they control $N$ lights (macro-variables). However, the light macro-variables manifest themselves in noisy visual observations, whose partial observability and overlap result in confounding factors. Starting from an initial state $s_0$, the objective of the agent is to operate the switches to reach a specified lighting goal $g$, where both the state and the goal are images. As specified in the problem setup, the underlying causal structure is unknown to the agent; all that is provided is limited observational data from the environment. Thus, the agent must (1) induce the causal structure between the switches and lights from the observational data, and then (2) use it to reach the goal.
We explore four types of causal patterns between switches and lights (see Figure 3). The first type of causal structure is One-to-One, in which each switch maps to one light. The second type is Many-to-One (Common Effect (keil_explanation_understanding)), in which each switch controls one light, but multiple switches may control the same light. The third type is One-to-Many (Common Cause (keil_explanation_understanding)), in which each light is controlled by at most one switch, but a single switch may control more than one light. Lastly, we explore the Masterswitch domain, a Causal Chain (keil_explanation_understanding) applied on top of a One-to-One causal structure. In this domain, the causal effects of the other switches can be observed only once a master switch has been activated. In our method, we represent the causal structure as a graph with directed edges between switch states and light states in the environment. The strength of each edge corresponds to the likelihood that the switch controls the indicated light, and the Masterswitch setting has additional edges.
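The sketch below shows one way to sample the first three structure types as adjacency matrices, together with a Masterswitch gate. It rests on several assumptions. It uses `C[i][j] = 1` for switch i → light j. A light is on iff at least one switch controlling it is on (OR semantics). The master switch is modeled as one of the N switches gating all light output. The sampling scheme and helper names are hypothetical.

```python
import random


def sample_structure(kind, n, rng):
    """Sample an n-switch, n-light structure; C[i][j] = 1 means switch i controls light j."""
    C = [[0] * n for _ in range(n)]
    if kind == "one_to_one":
        lights = list(range(n))
        rng.shuffle(lights)
        for i in range(n):
            C[i][lights[i]] = 1
    elif kind == "many_to_one":  # common effect: each switch -> one light, lights may be shared
        for i in range(n):
            C[i][rng.randrange(n)] = 1
    elif kind == "one_to_many":  # common cause: each light <- one switch, switches may be shared
        for j in range(n):
            C[rng.randrange(n)][j] = 1
    else:
        raise ValueError(f"unknown structure type: {kind}")
    return C


def lights_from_switches(switches, C, master=None):
    """Assumed semantics: a light is on iff at least one switch controlling it is on.

    If `master` is given (Masterswitch), no light is on unless that switch is on.
    """
    n_lights = len(C[0])
    if master is not None and not switches[master]:
        return [0] * n_lights
    return [int(any(switches[i] and C[i][j] for i in range(len(C)))) for j in range(n_lights)]
```

For example, with identity wiring and `master=1`, the setting `[1, 0, 0]` lights nothing, while `[1, 1, 0]` lights the first two lights. The effect of switch 0 only becomes observable after the master is on.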
Environment Setup: The trajectory of observational data from which we induce the causal structure is in visual space. For each timestep of the trajectory, it consists of an image and a size-$N$ action vector. The visual scene is a MuJoCo (mujoco) simulated house with 3 rooms, with lights mounted in the rooms and the hallway. When a light is turned on, its effect is rendered onto the floor and captured by a top-down camera. The illumination from the lights overlaps, resulting in confounding factors that the model must disentangle in order to correctly predict the causal structure. The state space (and goal space) of the policy also consists of images of the same environment from the same camera. The action space has size $N$, with one discrete action for each switch. During policy learning, goals are sampled uniformly from the space of lighting configurations that are possible under the environment's causal structure, i.e., the set of all reachable states in the environment.
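Sampling goals uniformly over reachable lighting configurations, rather than over switch settings, can be sketched by enumerating every switch setting and deduplicating the resulting light vectors. As an assumption, the sketch uses OR semantics for lights (a light is on iff a switch controlling it is on) and ignores the Masterswitch gate. The helper names are hypothetical.

```python
import random


def reachable_configs(C):
    """All lighting configurations reachable by some switch setting (assumed OR semantics)."""
    n_switches, n_lights = len(C), len(C[0])
    configs = set()
    for bits in range(2 ** n_switches):
        switches = [(bits >> i) & 1 for i in range(n_switches)]
        configs.add(tuple(int(any(switches[i] and C[i][j] for i in range(n_switches)))
                          for j in range(n_lights)))
    return sorted(configs)


def sample_goal(C, rng):
    """Uniform over reachable lighting configurations (not over switch settings)."""
    return rng.choice(reachable_configs(C))


# Many-to-One example: switches 0 and 1 share light 0, and switch 2 drives light 1.
# Light 2 has no switch, so it can never be turned on.
C = [[1, 0, 0], [1, 0, 0], [0, 1, 0]]
goals = reachable_configs(C)
```

Here only four configurations are reachable. Drawing switch settings uniformly would over-represent configurations such as `(1, 0, 0)`, which three of the eight settings produce, relative to `(0, 0, 0)`, which only one produces.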
Collecting Observational Data: The observational data is collected by a heuristic interaction agent that executes a simple policy. In all but the Masterswitch setting, the policy simply takes each action once (horizon $N$). In the Masterswitch setting, the exploratory agent presses each switch until one has an effect on the environment, and then proceeds to press each of the other switches. Its horizon therefore depends on how soon the master switch is found.
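The heuristic interaction policy can be sketched as follows. It is a simplified reconstruction and departs from the paper's setup in several ways. The `observe` callback returns light states instead of images. "Has an effect" is tested as any change between consecutive observations. The toy `gated_identity` world (identity wiring gated by switch 1) is hypothetical.

```python
def collect_observational_data(n, observe, masterswitch=False):
    """Toy heuristic interaction policy (hypothetical helper).

    observe(switches) returns an observation for the current switch settings.
    Returns the observations o_0..o_T and the actions a_0..a_{T-1}.
    """
    switches = [0] * n
    observations, actions = [observe(switches)], []

    def press(a):
        switches[a] ^= 1
        actions.append(a)
        observations.append(observe(switches))
        return observations[-1] != observations[-2]  # did anything visibly change?

    if not masterswitch:
        for a in range(n):  # press every switch once: horizon n
            press(a)
        return observations, actions

    master = None
    for a in range(n):  # press switches in order until something visibly changes
        if press(a):
            master = a
            break
    if master is not None:
        for a in range(n):  # then press each of the other switches once
            if a != master:
                press(a)
    return observations, actions


def gated_identity(switches, master=1):
    """Toy Masterswitch world: identity wiring, gated by switch `master`."""
    return list(switches) if switches[master] else [0] * len(switches)


obs, acts = collect_observational_data(3, gated_identity, masterswitch=True)
```

In this toy run, pressing switch 0 shows nothing. Pressing switch 1 (the master) reveals both lights, and the agent then presses switches 0 and 2, giving the action sequence `[0, 1, 0, 2]`.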
4.2 Evaluation Methods
We evaluate the following methods and baselines to examine the effectiveness of our causal induction model and goal-conditioned policy.
First, to examine how much our iterative causal induction network (ICIN) improves causal graph induction, we compare it against a non-iterative induction model that uses temporal convolutions (TCIN). We also compare against an ablation of our method that uses an iterative model without the attention mechanism (ICIN (No Attn)). We compare these approaches by the F1 score of recovering the ground-truth causal graph.
Next, we evaluate the goal-conditioned policy paired with each induction variant (ICIN, ICIN (No Attn), TCIN). We compare it against the previous work of dasgupta2019causal, which induces the graph in the latent memory of the policy (Memory). Specifically, we provide the policy with LSTM memory and, before running DAgger, feed the interaction trajectory through the policy one step at a time. This end-to-end approach allows the causal graph to be encoded into the latent memory of the policy and used when performing the goal-directed task. We also compare to Memory (RL/Low Dim), a version of dasgupta2019causal that has access to privileged state information (ground-truth states). It is trained with model-free reinforcement learning using a dense reward. In this setting, the same visual interaction trajectory is encoded into the policy's LSTM memory. However, the policy itself receives binary vectors for the state and goal and is trained with the PPO algorithm (ppo). Finally, we compare to an Oracle (Oracle) that uses the ground-truth causal graph even at test time, which provides an upper bound on performance. All methods are compared by their success rates in unseen environments. Implementation details can be found in Appendix B.
Lastly, to understand how critical the attention bottleneck in the goal-conditioned policy is for generalization, we run the goal-conditioned policy on graphs induced by ICIN, with and without the attention bottleneck, and compare their success rates on unseen causal structures.
4.3 Causal Induction Evaluation
First, we examine our approach's ability to induce the causal model from the trajectory of observational data. We report the F1 score (threshold = 0.5) between the edges of the predicted causal graph and those of the ground-truth causal graph. We compute F1 scores across up to 100 unseen tasks, given 10, 50, 100, or 500 seen structures, in the 5- and 7-switch problems (see Figure 4). Across almost all settings, our iterative approach with attention (ICIN) dominates. Our method without attention generally outperforms the non-iterative baseline, but both fall significantly behind our final approach. This suggests that the modularity provided by attention plays a large role in enabling generalization. Furthermore, our iterative attention approach outperforms the others especially when there are fewer training causal structures. This suggests that forcing the network to make independent updates lets it learn a general method for induction from limited training examples. This is likely because attention forces the model to learn to update a single edge given a single observation while remaining agnostic to the overall graph, which makes it far less likely to overfit to the training structures. Qualitative examples of ICIN can be found in Appendix C.
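The edge-level F1 score can be computed by binarizing the predicted edge strengths at the threshold and counting true positives, false positives, and false negatives over all directed edges. Whether the comparison at 0.5 is strict is an assumption here; with continuous Sigmoid outputs, ties are unlikely.

```python
def edge_f1(pred, target, threshold=0.5):
    """F1 over directed edges after binarizing predicted strengths at `threshold`."""
    tp = fp = fn = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            predicted = p > threshold
            if predicted and t:
                tp += 1
            elif predicted:
                fp += 1
            elif t:
                fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


target = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
pred = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.6], [0.0, 0.3, 0.4]]
score = edge_f1(pred, target)  # TP = 2, FP = 1, FN = 1, so precision = recall = F1 = 2/3
```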
4.4 Goal-Conditioned Policy Evaluation
Figure 5 reports the success rate of the converged policy over 500 trials in unseen causal structures with randomized goals. In most settings, the Memory-based approach of dasgupta2019causal is a strong baseline, outperforming the TCIN and ICIN (No Attn) baselines. We suspect it learns that, to best imitate the expert, it must encode relevant information from the interaction trajectory into its latent memory, implicitly performing induction. While this works, it generalizes to unseen structures much worse than our proposed method, likely due to the compositional structure of our approach. The memory baseline that uses low-dimensional states and is trained via model-free RL (Memory (RL/Low Dim)) also performs well. In fact, it beats our approach in a few cases in the 7-switch, Many-to-One setting. However, its performance is generally much lower than ours, and its successes can likely be attributed to its use of privileged information (ground-truth states) instead of visual observations. In almost all cases, ICIN significantly outperforms all baselines. In the 5-switch case, our ICIN method nearly matches the Oracle, suggesting that it induces $C$ almost perfectly.
Finally, we study the importance of the proposed attention bottleneck in the graph encoding of $\pi_G$, which forces the policy to focus on only the relevant edges at each timestep. Given the current image, the goal image, and the causal graph predicted by ICIN, we compare the success rate of $\pi_G$ with the attention bottleneck against simply flattening $\hat{C}$.
We find that using the attention bottleneck in the graph encoder of $\pi_G$ yields a roughly increase in success rate in the One-to-One (1:1) and Masterswitch (MS) cases, and a roughly increase in success rate in the One-to-Many (1:K) and Many-to-One (K:1) cases. We attribute this to modularity. Encouraging the policy to pick relevant edges yields a policy that can 1) identify the change it wants to make in the environment and 2) predict the necessary action from the causal graph $\hat{C}$, enabling better generalization.
5 Related Work
Causal reasoning has been extensively studied in a broad range of scientific disciplines, such as the social sciences (yee1996causal), medical sciences (kuipers1984causal), and econometrics (zellner1979causality). In the causality literature, structural causal models (SCMs) (pearl2009causality) offer a formal framework for modeling causation from statistical data and for counterfactual reasoning. An SCM is a directed graph that represents the causal relationships between random variables. Algorithms have been developed both for causal induction, i.e., constructing SCMs from observational data (shimizu2006linear; hoyer2009nonlinear; peters2014causal; ortega2014generalized), and for causal inference, i.e., using SCMs to estimate causal effects (bareinboim2015bandits; bareinboim2016causal). Conventional methods have limited applicability in complex domains where observational data is high-dimensional and partially observable. Recent work has shown that causal learning can take advantage of the representational power of deep learning, both for inducing causal relationships from interventions (dasgupta2019causal) and for improving policy learning via counterfactual reasoning (buesing2018woulda). However, these works have focused on toy-sized problems with low-dimensional states. In contrast, our model induces and makes use of the causal structure for complex interactive tasks from raw image observations.
lake2017building are among the first to discuss the limitations of state-of-the-art deep learning models in causal reasoning. A growing body of work seeks to marry the complementary strengths of deep learning and causal reasoning. Causal modeling has been explored in several contexts, including image classification (chalupka2014visual), generative models (kocaoglu2017causalgan), robot planning (kurutach2018learning), policy learning (buesing2018woulda), and transfer learning (bengio2019meta). Pioneering work on causal discovery with deep networks has been applied to time-series data in healthcare domains (kale2015causal; nauta2019causal). In addition, adversarial learning (kalainathan2018sam), graph neural networks (yu2019dag), and gradient-based DAG learning (lachapelle2019gradient) have recently been introduced to causal discovery, but they have largely focused on small synthetic datasets. Most relevant to ours are dasgupta2019causal and edmonds2019human, which investigated causal reasoning in deep reinforcement learning agents. In contrast to them, our method directly learns an explicit causal structure from raw observations to solve multi-step, goal-conditioned tasks.
Generalization to new environments and new goals has been a central challenge for learning-based interactive agents. This problem has previously been studied in the context of domain adaptation (tzeng2015towards; peng_domain_adapt), system identification (yu_sysid; zhou2019environment), meta-learning (maml; saemundsson2018meta), and multi-task learning (her; uvfa). These works have addressed variations in dynamics, visual appearance, and task rewards, while assuming fixed causal structures. Instead, we focus on changes in the latent causal relationships that determine the preconditions and effects of actions.
We have proposed novel techniques for 1) causal induction from raw visual observations and 2) causal graph encoding for goal-conditioned policies, both of which lead to better generalization to unseen causal structures. Our key insight is that iterative predictions and attention bottlenecks help both our causal induction model and our goal-conditioned policy focus on the relevant parts of the causal graph. Using this approach, we show better generalization to novel problem instances than prior work when training causal structures are limited.
In this work, we induce the causal structure from observational data collected by a heuristic policy. We plan to explore more complex tasks in which probing the environment to discover the causal structure requires more sophisticated strategies, and to develop algorithms that jointly learn the interaction policy.
Suraj Nair is supported by an NSF GRFP award. Yuke Zhu is supported by the Tencent AI Lab PhD Fellowship.
Appendix A Architecture Details
A.1 Induction Models
A.1.1 Observation Encoder
The image encoder used in all models takes an image of the scene as input and feeds it through 3 convolutional layers, each followed by a ReLU activation and max pooling. The convolutions have 8, 16, and 32 output filters, respectively. A single fully connected layer maps the resulting tensor to a latent vector whose size depends on $N$, the number of switches/lights.
A.1.2 ICIN Transition Encoder
The transition encoder used in our iterative model takes as input the concatenation of a state residual (with the dimension of the latent encoding) and a size-$N$ action. It feeds this input through fully connected layers of size 1024 and 512, each with ReLU activation. A final layer then outputs an attention vector with SoftMax activation (whose size differs in the Masterswitch case) and an edge update with Sigmoid activation. The first two layers are trained with dropout.
A.1.3 ICIN Encoder
The causal graph encoder used in the last step of our iterative model takes as input the flattened edge weights of the current graph (with a different number of edges in the Masterswitch case). It feeds them through a single fully connected layer, which outputs an attention vector with SoftMax activation and an edge update with Sigmoid activation.
A.1.4 ICIN (No Attn)
The no-attention ablation of our method has an otherwise identical architecture. In the original model, the third fully connected layer outputs an attention vector with SoftMax activation and an edge update with Sigmoid activation. In the ablation, it instead outputs an update to the full set of edge weights with Sigmoid activation.
The temporal convolution induction network (TCIN) uses the same image encoder as our approach. However, the state encodings are concatenated with the size-$N$ action labels and passed through three temporal convolution layers with 256, 128, and 128 filters, a kernel size of 3, and stride 1. The resulting horizon-by-128 output is flattened and fed through fully connected layers of size 1024 and 512, each with dropout. Finally, the causal graph is output with Sigmoid activation.
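The horizon-by-128 output stated above implies that the temporal convolutions preserve the sequence length. With kernel size 3 and stride 1, that holds for padding 1, which is an assumption since the padding is not stated. The sketch applies the standard 1-D convolution output-length formula, the one documented for `torch.nn.Conv1d`. Without padding, a 5-step trajectory (5 switches, each pressed once) would be too short for three kernel-3 layers. The helper names are hypothetical.

```python
def conv1d_out_len(length, kernel=3, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (the formula documented for torch.nn.Conv1d)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def tcin_flat_size(horizon, filters=(256, 128, 128), kernel=3, stride=1, padding=1):
    """Size of the flattened TCIN feature map after the temporal convolutions."""
    length = horizon
    for _ in filters:
        length = conv1d_out_len(length, kernel, stride, padding)
        if length <= 0:
            raise ValueError("trajectory too short for this convolution stack")
    return length * filters[-1]


print(tcin_flat_size(5))             # length preserved: 5 * 128 = 640
print(tcin_flat_size(7, padding=0))  # 7 -> 5 -> 3 -> 1, so 1 * 128 = 128
```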
A.2 Policy Architecture
A.2.1 Attention-Based
The policy uses the same image encoder as the induction models, except that its input is the current image and the goal image concatenated channel-wise. The encoded image is flattened and fed through a fully connected layer of size 128. This layer outputs an attention vector of size $N$, which is used to compute a weighted sum over the edges of the causal graph, producing a size-$N$ vector. This vector is encoded to size 128 and concatenated with the 128-dimensional image encoding. The result is fed through 2 more fully connected layers of size 64, which ultimately output the final action prediction of size $N$.
In the non-attention version, the architecture is identical except that the full graph is flattened, encoded, and concatenated directly with the image encoding.
A.3 Memory Baseline
In this baseline, we use the image encoder described above with an additional input for the action. The image encoding and action encoding are fed into an LSTM cell with hidden dimension 256. Its output is passed through fully connected layers of size 256 and 64, which output the action.
A.4 Memory (RL/Low Dim) Baseline
The policy is an MLP-LSTM policy as implemented in Stable-Baselines, with two fully connected layers of size 64 and an LSTM layer with 256 hidden units. It is augmented with additional input heads for each step of the observational data, namely the images and the size-$N$ actions.
Appendix B Training Details
B.1 Causal Induction Model ($\mathcal{F}$) Training
A separate causal induction model is trained for each split of seen/unseen causal structures described in the experiments. $\mathcal{F}$ is trained offline on all interaction trajectories and their corresponding ground-truth causal graphs $C$ for 60000 training iterations. Training uses the Adam optimizer (adam) with learning rate 0.0001 and batch size 512. The models are implemented in PyTorch (pytorch) and trained on an NVIDIA Titan X GPU.
B.2 DAgger Policy Training
The policies trained with DAgger (dagger) are trained in the training environments with a fixed episode horizon. The policy takes actions in the environment, and at each step an expert action is appended to the policy's memory buffer. The policy is then trained to imitate the experience in the memory. The expert uses the ground-truth causal graph and ground-truth low-dimensional states to compute the difference between the goal and the current state. Based on the graph, it then determines which action needs to be taken. Each policy is trained for 100000 episodes with learning rate 0.0001 and a batch size of 32 sampled from the memory.
B.3 RL Policy Training
The Memory (RL/Low Dim) baseline is trained using the Proximal Policy Optimization algorithm (ppo) as implemented in Stable-Baselines (stable-baselines). Its hyperparameters include 128 steps per update and an entropy coefficient of 0.01, in addition to the learning rate and value function coefficient. The policy itself consists of two fully connected layers of size 64 and an LSTM layer with a 256-dimensional hidden state. The policy is trained until its performance on unseen causal structures converges, capped at a maximum of 9 million episodes. In this experiment, we set the horizon of each episode equal to the number of switches/lights $N$.
Appendix C Causal Induction Qualitative Examples
Here we demonstrate an example trajectory and how the causal induction model iteratively builds the causal graph.