Invariant Causal Prediction for Block MDPs

by Amy Zhang, et al.

Generalization across environments is critical to the successful application of reinforcement learning algorithms to real-world challenges. In this paper, we consider the problem of learning abstractions that generalize in block MDPs, families of environments with a shared latent state space and dynamics structure over that latent space, but varying observations. We leverage tools from causal inference to propose a method of invariant prediction to learn model-irrelevance state abstractions (MISA) that generalize to novel observations in the multi-environment setting. We prove that for certain classes of environments, this approach outputs with high probability a state abstraction corresponding to the causal feature set with respect to the return. We further provide more general bounds on model error and generalization error in the multi-environment setting, in the process showing a connection between causal variable selection and the state abstraction framework for MDPs. We give empirical evidence that our methods work in both linear and nonlinear settings, attaining improved generalization over single- and multi-task baselines.




1 Introduction

The canonical reinforcement learning (RL) problem assumes an agent interacting with a single MDP with a fixed observation space and dynamics structure. This assumption is difficult to ensure in practice, where state spaces are often large and infeasible to explore entirely during training. However, there is often a latent structure to be leveraged to allow for good generalization. As an example, a robot’s sensors may be moved, or the lighting conditions in a room may change, but the physical dynamics of the environment are still the same. These are examples of environment-specific characteristics that current RL algorithms often overfit to. In the worst case, some training environments may contain spurious correlations that will not be present at test time, causing catastrophic failures in generalization (Zhang et al., 2018a; Song et al., 2020). To develop algorithms that will be robust to these sorts of changes, we must consider problem settings that allow for multiple environments with a shared dynamics structure.

Recent prior works (Amit and Meir, 2018; Yin et al., 2019) have developed generalization bounds for the multi-task problem, but they depend on the number of tasks seen at training time, which can be prohibitively expensive given how sample-inefficient RL is even in the single-task regime. To obtain stronger generalization results, we propose to consider a problem which we refer to as ‘multi-environment’ RL: like multi-task RL, the agent seeks to maximize return on a set of environments, but only some of these environments are available for training. We make the assumption that there exists some latent causal structure that is shared among all of the environments, and that the sources of variability between environments do not affect reward. This family of environments is called a Block MDP (Du et al., 2019), in which the observations may change, but the latent states, dynamics, and reward function are the same. A formal definition of this type of MDP is presented in Section 3.

Though the setting we consider is a subset of the multi-task RL problem, we show in this work that the added assumption of shared structure allows for much stronger generalization results than have been obtained by multi-task approaches. Naive application of generalization bounds to the multi-task reinforcement learning setting yields very loose bounds, because the learner is typically given access to only a few tasks relative to the number of samples from each task. Indeed, Cobbe et al. (2018) and Zhang et al. (2018b) find that agents trained using standard methods require many thousands of environments before succeeding at ‘generalizing’ to new environments.

The main contribution of this paper is to use tools from causal inference to address generalization in the Block MDP setting, proposing a new method based on the invariant causal prediction literature. In certain linear function approximation settings, we demonstrate that this method will, with high probability, learn an optimal state abstraction that generalizes across all environments using many fewer training environments than would be necessary for standard PAC bounds. We replace this PAC requirement with requirements from causal inference on the types of environments seen at training time. We then draw a connection between bisimulation and the minimal causal set of variables found by our algorithm, providing bounds on the model error and sample complexity of the method. We further show that using analogous invariant prediction methods for the nonlinear function approximation setting can yield improved generalization performance over multi-task and single-task baselines. We relate this method to previous work on learning representations of MDPs (Gelada et al., 2019; Luo et al., 2019) and develop multi-task generalization bounds for such representations.

2 Background

2.1 State Abstractions and Bisimulation

State abstractions have been studied as a way to distinguish relevant from irrelevant information (Li et al., 2006) in order to create a more compact representation for easier decision making and planning. Bertsekas and Castanon (1989); Roy (2006) provide bounds for approximation errors for various aggregation methods, and Li et al. (2006) discuss the merits of abstraction discovery as a way to solve related MDPs.

Bisimulation relations are a type of state abstraction that offers a mathematically precise definition of what it means for two environments to ‘share the same structure’ (Larsen and Skou, 1989; Givan et al., 2003). We say that two states are bisimilar if they share the same expected reward and equivalent distributions over the next bisimilar states. For example, if a robot is given the task of washing the dishes in a kitchen, changing the wallpaper in the kitchen doesn’t change anything relevant to the task. One then could define a bisimulation relation that equates observations based on the locations and soil levels of dishes in the room and ignores the wallpaper. These relations can be used to simplify the state space for tasks like policy transfer (Castro and Precup, 2010), and are intimately tied to state abstraction. For example, the model-irrelevance abstraction described by Li et al. (2006) is precisely characterized as a bisimulation relation.

Definition 1 (Bisimulation Relations (Givan et al., 2003)).

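Given an MDP $\mathcal{M}$, an equivalence relation $B$ between states is a bisimulation relation if for all states $s_1, s_2 \in \mathcal{S}$ that are equivalent under $B$ (i.e. $s_1 B s_2$), the following conditions hold for all actions $a \in \mathcal{A}$:

$$\mathcal{R}(s_1, a) = \mathcal{R}(s_2, a), \qquad \mathcal{P}(G \mid s_1, a) = \mathcal{P}(G \mid s_2, a) \quad \forall G \in \mathcal{S}/B,$$

where $\mathcal{S}/B$ denotes the partition of $\mathcal{S}$ under the relation $B$, the set of all groups $G$ of equivalent states, and where $\mathcal{P}(G \mid s, a) = \sum_{s' \in G} \mathcal{P}(s' \mid s, a)$.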
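To make Definition 1 concrete, the following Python sketch checks whether a given partition of a small tabular MDP satisfies both bisimulation conditions. The function name `is_bisimulation`, the array layouts, and the tolerance are illustrative assumptions rather than part of the paper's method.

```python
import numpy as np


def is_bisimulation(P, R, labels, tol=1e-9):
    """Check both conditions of Definition 1 for a finite MDP.

    P: transition probabilities, shape (S, A, S), with P[s, a, s2] = P(s2 | s, a)
    R: rewards, shape (S, A)
    labels: length-S sequence; labels[s] names the group G that contains s
    """
    labels = np.asarray(labels)
    groups = np.unique(labels)
    # P(G | s, a) = sum over s2 in G of P(s2 | s, a); result has shape (S, A, |S/B|).
    P_group = np.stack([P[:, :, labels == g].sum(axis=2) for g in groups], axis=2)
    for g in groups:
        members = np.flatnonzero(labels == g)
        ref = members[0]
        for s in members[1:]:
            # Equality is transitive, so comparing every member to one reference suffices.
            if not np.allclose(R[s], R[ref], atol=tol):
                return False
            if not np.allclose(P_group[s], P_group[ref], atol=tol):
                return False
    return True


# Usage: a 4-state, 1-action MDP in which states {0, 1} and {2, 3} are bisimilar.
P = np.zeros((4, 1, 4))
P[0, 0] = [0.0, 0.0, 0.5, 0.5]
P[1, 0] = [0.0, 0.0, 1.0, 0.0]
P[2, 0] = [0.3, 0.7, 0.0, 0.0]
P[3, 0] = [1.0, 0.0, 0.0, 0.0]
R = np.array([[1.0], [1.0], [0.0], [0.0]])
print(is_bisimulation(P, R, [0, 0, 1, 1]))  # True
print(is_bisimulation(P, R, [0, 1, 0, 1]))  # False: states 0 and 2 differ in reward
```

States 0 and 1 have different raw next-state distributions, yet they are bisimilar because they send the same total probability mass to each group. This is exactly the coarsening that lets irrelevant detail be ignored.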

Whereas this definition was originally designed for the single-MDP setting to find bisimilar states within an MDP, we are now trying to find bisimilar states across different MDPs, or different experimental conditions. One can intuitively think of this as carrying over by imagining all experimental conditions mapped to a single super-MDP with state space $\mathcal{S} = \bigcup_{e} \mathcal{S}_e$, where we give up the irreducibility assumption, i.e. we can no longer reach every state $s'$ from any other state $s$. Specifically, we say that two MDPs $\mathcal{M}_1$ and $\mathcal{M}_2$ are bisimilar if there exist bisimulation relations $B_1$ and $B_2$ such that $\mathcal{M}_1 / B_1$ is isomorphic to $\mathcal{M}_2 / B_2$. Bisimilar MDPs are therefore MDPs which are behaviourally the same.

2.2 Causal Inference Using Invariant Prediction

Peters et al. (2016) first introduced an algorithm, Invariant Causal Prediction (ICP), to find the causal feature set, the minimal set of features which are causal predictors of a target variable, by exploiting the fact that causal models have an invariance property (Pearl, 2009; Schölkopf et al., 2012). Arjovsky et al. (2019) extend this work by proposing invariant risk minimization (IRM, see Equation 1), augmenting empirical risk minimization to learn a data representation free of spurious correlations. They assume there exists some partition of the training data into experiments $e \in \mathcal{E}_{\text{tr}}$, and that the model's predictions take the form $\hat{Y}^e = w^\top \Phi(X^e)$. IRM aims to learn a representation $\Phi: \mathcal{X} \to \mathcal{H}$ for which the optimal linear classifier, $w$, is invariant across $e \in \mathcal{E}_{\text{tr}}$, where optimality is defined as minimizing the empirical risk $R^e$. We can then expect this representation and classifier to have low risk in new experiments $e$ which have the same causal structure as the training set.

$$\min_{\Phi: \mathcal{X} \to \mathcal{H},\; w: \mathcal{H} \to \mathcal{Y}} \; \sum_{e \in \mathcal{E}_{\text{tr}}} R^e(w \circ \Phi) \quad \text{s.t.} \quad w \in \arg\min_{\bar{w}: \mathcal{H} \to \mathcal{Y}} R^e(\bar{w} \circ \Phi) \;\; \forall e \in \mathcal{E}_{\text{tr}} \qquad (1)$$
The IRM objective in Equation 1 can be thought of as a constrained optimization problem, where the objective is to learn a set of features for which the optimal classifier in each environment is the same. Provided that the environments correspond to different interventions on the data-generating process, this is hypothesized to yield features that only depend on variables that bear a causal relationship to the predicted value. Because the constrained optimization problem is not generally tractable to solve, Arjovsky et al. (2019) propose a penalized optimization problem with a schedule on the penalty term as a tractable alternative.
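As a rough illustration of that penalized alternative, the sketch below computes the IRMv1 penalty of Arjovsky et al. (2019): the squared gradient of each environment's risk with respect to a scalar "dummy" classifier $w$, evaluated at $w = 1$. The closed-form gradient assumes a squared loss and scalar representation outputs; the function name and argument layout are hypothetical.

```python
import numpy as np


def irmv1_objective(preds_by_env, targets_by_env, penalty_weight):
    """Empirical risk plus the IRMv1 penalty, assuming a squared loss.

    preds_by_env[e] holds the representation outputs Phi(x) (one scalar per
    sample) for environment e; the classifier is the fixed scalar w = 1.0.
    Returns (objective, penalty).
    """
    risk, penalty = 0.0, 0.0
    for f, y in zip(preds_by_env, targets_by_env):
        f, y = np.asarray(f, dtype=float), np.asarray(y, dtype=float)
        risk += np.mean((f - y) ** 2)
        # d/dw mean((w * f - y)^2) evaluated at w = 1 equals mean(2 * (f - y) * f).
        grad_w = np.mean(2.0 * (f - y) * f)
        penalty += grad_w ** 2
    return risk + penalty_weight * penalty, penalty


# Environment whose optimal rescaling is w = 2, so w = 1 is not optimal there.
obj, pen = irmv1_objective([[1.0, 2.0]], [[2.0, 4.0]], penalty_weight=1.0)
print(obj, pen)  # 27.5 25.0
```

The penalty vanishes exactly when $w = 1$ is a stationary point of every environment's risk, which is the first-order relaxation of the constraint in Equation 1.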

3 Problem Setup

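We consider a family of environments $\mathcal{M}_{\mathcal{E}} = \{(\mathcal{X}_e, \mathcal{A}, \mathcal{R}_e, \mathcal{T}_e, \gamma) \mid e \in \mathcal{E}\}$, where $\mathcal{E}$ is some index set. For simplicity of notation, we drop the subscript $e$ when referring to the union over all environments $\mathcal{E}$. Our goal is to use a subset $\mathcal{E}_{\text{train}} \subset \mathcal{E}$ of these environments to learn a representation $\phi: \mathcal{X} \to \mathbb{R}^d$ which enables generalization of a learned policy to every environment. We denote the number of training environments as $N := |\mathcal{E}_{\text{train}}|$. We assume that the environments share some structure, and consider different degrees to which this structure may be shared.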

3.1 The Block MDP

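Block MDPs (Du et al., 2019) are described by a tuple $\langle \mathcal{S}, \mathcal{A}, \mathcal{X}, p, q, R \rangle$ with a finite, unobservable state space $\mathcal{S}$, finite action space $\mathcal{A}$, and possibly infinite, but observable, observation space $\mathcal{X}$. $p$ denotes the latent transition distribution $p(s' \mid s, a)$ for $s, s' \in \mathcal{S}$, $a \in \mathcal{A}$; $q$ is the (possibly stochastic) emission function that gives the observations from the latent state, $q(x \mid s)$ for $x \in \mathcal{X}$, $s \in \mathcal{S}$; and $R$ is the reward function. A graphical model of the interactions between the various variables can be found in Figure 1.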

Assumption 1 (Block structure (Du et al., 2019)).

Each observation $x$ uniquely determines its generating state $s$. That is, the observation space $\mathcal{X}$ can be partitioned into disjoint blocks $\mathcal{X}_s$, each containing the support of the conditional distribution $q(\cdot \mid s)$.

This assumption gives us the Markov property in $\mathcal{X}$. We translate the block MDP to our multi-environment setting as follows. If a family of environments $\mathcal{M}_{\mathcal{E}}$ satisfies the block MDP assumption, then each $e \in \mathcal{E}$ corresponds to an emission function $q_e$, with $\mathcal{S}$, $\mathcal{A}$, $p$, and $R$ shared for all $e \in \mathcal{E}$. We will move the potential randomness from $q_e$ into an auxiliary variable $\eta \in \Omega$, with $\Omega$ some probability space, and write $q_e(\eta, s)$. Further, we require that if $q_e(\eta, s) = q_{e'}(\eta', s')$, then $s = s'$. The objective is to learn a useful state abstraction to promote generalization across the different emission functions $\{q_e \mid e \in \mathcal{E}\}$, given that only a subset is provided for training. Song et al. (2020) also describe a similar POMDP setting where there is an additional observation function, but assume information can be lost. We note that this problem can be made arbitrarily difficult if each $q_e$ has a disjoint range, but will focus on settings where the $q_e$ overlap in structured ways, for example, where $q_e$ is the concatenation of the noise and state variables: $q_e(\eta, s) = [s, f_e(\eta)]$.
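The structured-overlap case can be sketched in a few lines of Python. The sketch assumes a concatenation emission $q_e(\eta, s) = [s, f_e(\eta)]$ with a hypothetical environment-specific noise map $f_e$ (here a random linear map seeded per environment), and checks that the abstraction which keeps only the first block of coordinates recovers the same latent state in every environment. Because $s$ appears verbatim in $x$, different latent states can never produce the same observation, so the block structure holds.

```python
import numpy as np


class ConcatEmission:
    """Hypothetical emission q_e(eta, s) = [s, f_e(eta)] for one environment e."""

    def __init__(self, noise_dim, seed):
        self.rng = np.random.default_rng(seed)
        # f_e: an environment-specific linear map applied to the auxiliary noise.
        self.mix = self.rng.normal(size=(noise_dim, noise_dim))

    def __call__(self, s):
        eta = self.rng.normal(size=self.mix.shape[0])  # fresh auxiliary noise per call
        return np.concatenate([s, self.mix @ eta])


def phi(x, state_dim):
    """State abstraction that keeps only the latent-state coordinates."""
    return x[:state_dim]


s = np.array([1.0, -2.0])
observations = [ConcatEmission(noise_dim=3, seed=k)(s) for k in range(3)]
for x in observations:
    print(x.round(2), "->", phi(x, state_dim=2))
```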

3.2 Relaxations

Figure 1: Graphical model of a block MDP with stochastic, correlated observations, with an IRM goal to extract the latent state $s_t$ from the sequence of observations and discard the spurious noise $\eta_t$. Red dashed ovals indicate the entire tangled latent state at each timestep. Black dashed lines and grey lines indicate two additional tiers of difficulty to consider.

Spurious correlations. Our initial presentation of the block MDP assumes that the noise variable $\eta$ is sampled randomly at every time step, which prevents multi-timestep correlations (Figure 1, black solid lines). We therefore also consider a more realistic relaxed block MDP, where spurious variables may have different transition dynamics across the different environments so long as these correlations do not affect the expected reward (Figure 1, now including black dashed lines). This is equivalent to augmenting each Block MDP in our family with a noise variable $\eta_e$, such that the observation is $x = q_e(\eta_e, s)$, the noise evolves under its own environment-specific transition dynamics $p_e$, and the expected reward depends only on the latent state $s$.

We note that this relaxed setting still satisfies Assumption 1.

Realizability. Though our analysis will require Assumption 1, we claim that this is a reasonable requirement as it makes the learning problem realizable. Relaxing Assumption 1 means that the value function learning problem may become ill-posed, as the same observation can map to entirely different states in the latent MDP with different values, making our environment partially observable (a POMDP, Figure 1 with grey lines). We provide a lower bound on the value approximation error attainable in this setting in the appendix (Proposition 2).

3.3 Assumptions on causal structure

State abstraction and causal inference both aim to eliminate spurious features in a learning algorithm’s input. However, these two approaches are applied to drastically different types of problems. Though we demonstrate that causal inference methods can be applied to reinforcement learning, this will require some assumption on how causal mechanisms are observed. Definitions of the notation used in this section are deferred to the appendix, though they are standard in the causal inference literature.

The key assumption we make is that the variables in the environment state at time $t$ can only affect the values of the state at time $t+1$, and can only affect the reward at time $t$. This assumption allows us to consider the state and action at time $t$ as the only candidates for causal parents of the state at time $t+1$ and of the reward at time $t$. This assumption is crucial to the Markov behaviour of the Markov decision process. We refer the reader to Figure 2 for an illustration of how causal graphical models can be translated to this setting.

Assumption 2 (Temporal Causal Mechanisms).

Let $x^i$ and $x^j$ be components of the observation $x$. Then when no intervention is performed on the environment, we have the following independence: $x^i_{t+1} \perp\!\!\!\perp x^j_{t+1} \mid x_t, a_t$.

Assumption 3 (Environment Interventions).

Let $x^i \in x$ and $e \in \mathcal{E}$. Each environment $e$ corresponds to a do- (Pearl, 2009) or soft (Eberhardt and Scheines, 2007) intervention on a single variable $x^i$ in the observation space.

This assumption allows us to use tools from causal inference to identify candidate model-irrelevance state abstractions that may hold across an entire family of MDPs, rather than only the ones observed, by using the state at one timestep to predict values at the next timestep. In the setting of Assumption 3, we can reconstruct the block MDP emission function $q$ by concatenating the spurious variables in $x$ to the latent state $s$. We discuss some constraints on interventions necessary to satisfy the block MDP assumption in the appendix.

4 Connecting State Abstractions to Causal Feature Sets

Invariant causal prediction aims to identify a set $S$ of causal variables such that a linear predictor with support on $S$ will attain consistent performance over all environments. In other words, ICP removes irrelevant variables from the input, just as state abstractions remove irrelevant information from the environment's observations. An attractive property of the block MDP setting is that it is easy to show that there does exist a model-irrelevance state abstraction $\phi$ for all MDPs in $\mathcal{M}_{\mathcal{E}}$: namely, the function mapping each observation $x$ to its generating latent state $s$. The formalization and proof of this statement are deferred to the appendix (see Theorem 4).

We consider whether, under Assumptions 1-3, such a state abstraction can be obtained by ICP. Intuitively, one would expect the causal variables to have nice properties as a state abstraction. The following result confirms this to be the case: a state abstraction that selects the set of causal variables from the observation space of a block MDP will be a model-irrelevance abstraction for every environment $e \in \mathcal{E}$.

Theorem 1.

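Consider a family of MDPs $\mathcal{M}_{\mathcal{E}} = \{(\mathcal{X}_e, \mathcal{A}, \mathcal{R}, \mathcal{T}_e, \gamma) \mid e \in \mathcal{E}\}$, with $\mathcal{X}_e \subseteq \mathbb{R}^k$. Let $\mathcal{M}_{\mathcal{E}}$ satisfy Assumptions 1-3. Let $S_R \subseteq \{1, \dots, k\}$ be the set of variables such that the reward $\mathcal{R}(x, a)$ is a function only of $[x]_{S_R}$ ($x$ restricted to the indices in $S_R$). Then let $S = \mathrm{AN}(S_R)$ denote the ancestors of $S_R$ (including $S_R$ itself) in the (fully observable) causal graph corresponding to the transition dynamics of $\mathcal{M}_{\mathcal{E}}$. Then the state abstraction $\phi_S(x) = [x]_S$ is a model-irrelevance abstraction for every $e \in \mathcal{E}$.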
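The abstraction in Theorem 1 keeps the reward's parent variables together with all of their ancestors in the transition graph, so computing it is a graph reachability problem. A minimal sketch, assuming the causal graph is already known and given as a dictionary that maps each variable to the variables at time $t$ that cause it at time $t+1$ (a hypothetical input format; in practice the graph must be discovered, which is what ICP is for):

```python
from collections import deque


def causal_ancestors(reward_parents, parents):
    """Reward parents plus all of their ancestors in the transition graph.

    parents[v] lists the variables at time t that cause v at time t + 1.
    """
    selected = set(reward_parents)
    queue = deque(reward_parents)
    while queue:
        v = queue.popleft()
        for p in parents.get(v, ()):
            if p not in selected:          # visit each ancestor only once
                selected.add(p)
                queue.append(p)
    return selected


# Illustrative CartPole-like graph; the edges are assumptions, not measured dynamics.
parents = {
    "x": ["x", "x_dot"],
    "theta": ["theta", "theta_dot"],
    "x_dot": ["x_dot", "theta", "theta_dot"],
    "theta_dot": ["theta", "theta_dot", "x_dot"],
    "distractor": ["distractor", "x"],    # depends on the state but never affects reward
}
print(sorted(causal_ancestors(["x", "theta"], parents)))
# ['theta', 'theta_dot', 'x', 'x_dot']
```

Starting from the reward's parents (position and angle) the search pulls in both velocities, matching the CartPole discussion, while the distractor, which is a descendant rather than an ancestor, is excluded.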

Figure 2: Graphical causal models with temporal dependence. Note that while one variable (circled in blue) is the only causal parent of the reward, its next-timestep distribution depends on a second variable, so a model-irrelevance state abstraction must include both variables. Shaded in blue: the graphical causal model of the same MDP when timesteps are ignored.

An important detail in the previous result is that the model-irrelevance state abstraction incorporates not just the parents of the reward, but also their ancestors. This is because in RL we seek to model return rather than solely rewards, which requires a state abstraction that can capture multi-timestep interactions. We provide an illustration of this in Figure 2. As a concrete example, in the popular benchmark CartPole, only the cart position $x$ and pole angle $\theta$ are necessary to predict the reward. However, predicting the return also requires $\dot{x}$ and $\dot{\theta}$, their respective velocities.

Learning a minimal state abstraction in the setting of Theorem 1 from a single training environment may not always be possible. However, applying invariant causal prediction methods in the multi-environment setting will yield the minimal causal set of variables when the training environment interventions satisfy certain conditions necessary for the identifiability of the causal variables (Peters et al., 2016).

5 Block MDP Generalization Bounds

We continue to relax the assumptions needed to learn a causal representation and look to the nonlinear setting. As a reminder, the goal of this work is to produce representations that will generalize from the training environments to a novel test environment. However, standard PAC generalization bounds require a much larger number of environments than one could expect to obtain in the reinforcement learning setting. The appeal of an invariant representation is that it may allow for theoretical guarantees on learning the right state abstraction with many fewer training environments, as discussed by Peters et al. (2016). If the learned state abstraction is close to capturing the true base MDP, then the model error in the test environment can be bounded by a function of the distance of the test environment's abstract state distribution to the training environments'. Though the requirements given in the following Theorem 2 are difficult to guarantee in practice, the result holds for any learned state abstraction.

Theorem 2 (Model error bound).

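Consider an MDP $\mathcal{M}$, with $\bar{\mathcal{M}}$ denoting a coarser bisimulation of $\mathcal{M}$. Let $\phi$ denote the mapping from states of $\mathcal{M}$ to states of $\bar{\mathcal{M}}$. Suppose that the dynamics of $\mathcal{M}$ are $L$-Lipschitz w.r.t. $\phi(X)$ and that $T$ is some approximate transition model satisfying $\mathbb{E}_{x \sim \pi_{\text{train}}} \| T(\phi(x)) - \phi(T_{\mathcal{M}}(x)) \| < \delta$ on the training state distribution $\pi_{\text{train}}$, for some $\delta > 0$. Let $W_1$ denote the 1-Wasserstein distance. Then the expected model error $\mathbb{E}_{x \sim \pi_{\text{test}}} \| T(\phi(x)) - \phi(T_{\mathcal{M}}(x)) \|$ on a test state distribution $\pi_{\text{test}}$ is bounded by $\delta$ plus a term proportional to $L \, W_1$ between the training and test distributions over abstract states $\phi(x)$.
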
Proof found in Appendix B.
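The distribution-shift term in Theorem 2 rests on the Kantorovich-Rubinstein duality: for any $L$-Lipschitz function $g$, $|\mathbb{E}_P g - \mathbb{E}_Q g| \le L \, W_1(P, Q)$. The sketch below checks this numerically in one dimension, where $W_1$ between two equal-size empirical samples is the mean absolute difference of their sorted values. The Gaussian "abstract state" samples and the per-state error function are hypothetical stand-ins for the model error in the theorem.

```python
import numpy as np


def wasserstein1_1d(u, v):
    """W_1 between two equal-size 1-D empirical distributions with uniform weights."""
    u, v = np.sort(np.asarray(u, dtype=float)), np.sort(np.asarray(v, dtype=float))
    assert u.shape == v.shape, "the sorted-matching shortcut needs equal sample sizes"
    return np.mean(np.abs(u - v))


rng = np.random.default_rng(0)
z_train = rng.normal(0.0, 1.0, 10_000)   # abstract states seen during training
z_test = rng.normal(0.7, 1.3, 10_000)    # shifted abstract states at test time

lip = 2.0


def model_error(z):
    # A hypothetical per-state model error; |sin| is 1-Lipschitz, so this is 2-Lipschitz.
    return lip * np.abs(np.sin(z))


gap = abs(model_error(z_test).mean() - model_error(z_train).mean())
bound = lip * wasserstein1_1d(z_train, z_test)
print(f"change in mean error {gap:.4f} <= Lipschitz bound {bound:.4f}")
```

For empirical distributions the inequality holds exactly, since sorting gives the optimal matching in one dimension; the gap can only grow with the distance between the training and test distributions over abstract states.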

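Instead of assuming access to a bisimilar MDP $\bar{\mathcal{M}}$, we can provide discrepancy bounds for an MDP $\hat{\mathcal{M}}$ produced by a learned state representation $\phi$, dynamics function $\hat{f}$, and reward function $\hat{R}$, using the distance in dynamics and reward of $\hat{\mathcal{M}}$ to the underlying MDP $\mathcal{M}$. We first define these distances,

$$J_R^\infty := \sup_{x \in \mathcal{X},\, a \in \mathcal{A}} \big| R(x, a) - \hat{R}(\phi(x), a) \big|, \qquad J_T^\infty := \sup_{x \in \mathcal{X},\, a \in \mathcal{A}} W_1\big(\phi P(\cdot \mid x, a),\, \hat{f}(\cdot \mid \phi(x), a)\big),$$

where $\phi P(\cdot \mid x, a)$ is the distribution of $\phi(x')$ for $x' \sim P(\cdot \mid x, a)$.

Theorem 3.

Let $\mathcal{M}$ be a block MDP and $\hat{\mathcal{M}}$ the learned invariant MDP with a mapping $\phi: \mathcal{X} \to \mathcal{Z}$. For any $K_V$-Lipschitz valued policy $\pi$, the value difference of that policy is bounded by

$$\big| Q^\pi(x, a) - \hat{Q}^\pi(\phi(x), a) \big| \le \frac{J_R^\infty + \gamma K_V J_T^\infty}{1 - \gamma},$$

where $Q^\pi$ is the value function for $\pi$ in $\mathcal{M}$ and $\hat{Q}^\pi$ is the value function for $\pi$ in $\hat{\mathcal{M}}$.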

Proof found in Appendix B. This gives us a bound on generalization performance that depends on the supremum of the dynamics and reward errors. Estimating these is a regression problem whose error depends on $n$, the number of samples we have in aggregate over all training environments, rather than on the number of training environments, $|\mathcal{E}_{\text{train}}|$. Recent generalization bounds for deep neural networks using Rademacher complexity (Bartlett et al., 2017a; Arora et al., 2018) scale with a factor of $1/\sqrt{n}$, where $n$ is the number of samples. We can use $n = \sum_{e \in \mathcal{E}_{\text{train}}} n_e$, where $n_e$ is the number of samples from environment $e$, in our setting, getting generalization bounds for the block MDP setting that scale with the number of samples in aggregate over all environments, an improvement over previous multi-task bounds that depend on $|\mathcal{E}_{\text{train}}|$.
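As a tabular sanity check of the kind of bound in Theorem 3, the sketch below compares exact Q-functions of a random MDP and a perturbed copy of it. The simplifications are assumptions of this sketch, not the paper's setting: the abstraction is the identity, states are compared with the discrete metric (so $W_1$ reduces to total variation and the Lipschitz constant of the learned value function $\hat{V}$ reduces to its span), and `eps_r`, `eps_t` are the worst-case reward and transition errors. Under these assumptions the standard simulation-lemma argument gives $\max |Q - \hat{Q}| \le (\epsilon_R + \gamma \, \mathrm{span}(\hat{V}) \, \epsilon_T)/(1 - \gamma)$, which the code checks.

```python
import numpy as np


def policy_q(P, R, pi, gamma):
    """Exact Q^pi of a tabular MDP via the Bellman linear system.

    P: (S, A, S) transitions, R: (S, A) rewards, pi: (S, A) action probabilities.
    """
    S, A = R.shape
    # Row (s, a), column (s2, a2): P(s2 | s, a) * pi(a2 | s2).
    P_pi = np.einsum("sat,tb->satb", P, pi).reshape(S * A, S * A)
    q = np.linalg.solve(np.eye(S * A) - gamma * P_pi, R.reshape(-1))
    return q.reshape(S, A)


def value_gap_and_bound(P, R, P_hat, R_hat, pi, gamma):
    q = policy_q(P, R, pi, gamma)
    q_hat = policy_q(P_hat, R_hat, pi, gamma)
    v_hat = (pi * q_hat).sum(axis=1)
    eps_r = np.abs(R - R_hat).max()                      # worst-case reward error
    eps_t = 0.5 * np.abs(P - P_hat).sum(axis=2).max()    # worst-case total variation
    k_v = v_hat.max() - v_hat.min()                      # Lipschitz const. of V_hat (discrete metric)
    bound = (eps_r + gamma * k_v * eps_t) / (1.0 - gamma)
    return np.abs(q - q_hat).max(), bound


rng = np.random.default_rng(1)
S, A, gamma = 5, 2, 0.9
P = rng.dirichlet(np.ones(S), size=(S, A))
P_hat = 0.9 * P + 0.1 * rng.dirichlet(np.ones(S), size=(S, A))   # perturbed model
R = rng.uniform(size=(S, A))
R_hat = R + rng.uniform(-0.05, 0.05, size=(S, A))
pi = rng.dirichlet(np.ones(A), size=S)
gap, bound = value_gap_and_bound(P, R, P_hat, R_hat, pi, gamma)
print(f"max |Q - Q_hat| = {gap:.4f} <= bound = {bound:.4f}")
```

The bound depends only on worst-case reward and transition errors, which is why it can be controlled by a regression problem over pooled samples.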

6 Methods

Given these theoretical results, we propose two methods to learn invariant representations in the block MDP setting. Both methods take inspiration from invariant causal prediction. The first is the direct application of linear ICP to select the causal variables in the state in the setting where variables are given. This corresponds to direct feature selection, which with high probability returns the minimal causal feature set. The second method is a gradient-based approach akin to the IRM objective, with no assumption of a linear causal relationship and a learned causal invariant representation. Like the IRM goal (Equation 1), we aim to learn an invariant state abstraction from stochastic observations across different interventions $e$, and impose an additional invariance constraint.

6.1 Variable Selection for Linear Predictors

The following algorithm (Algorithm 1) returns a model-irrelevance state abstraction. We require the presence of a replay buffer $\mathcal{D}$, in which transitions are stored and tagged with the environment from which they came. The algorithm then applies ICP iteratively to find all causal ancestors of the reward. This approach has the benefit of inheriting many nice properties from ICP: under suitable identifiability conditions, it will return the exact causal variable set to a specified degree of confidence.

It also inherits inconvenient properties: the ICP algorithm is exponential in the number of variables, so this method is not efficient for high-dimensional observation spaces. We are also restricted to considering linear relationships of the observation to the reward and next state. Further, because we take the union over iterative applications of ICP, the confidence parameter $\alpha$ used in each call must be adjusted accordingly. Given $k$ observation variables, ICP is called at most $k + 1$ times (once for the reward and at most once per variable), so a union bound gives the conservative per-call level $\alpha / (k + 1)$.

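Result: S, the causal state variables

Input: α, a confidence parameter; D, a replay buffer with observations x ∈ R^k, each transition tagged with its environment

S ← ∅; visited ← ∅
stack ← [r]
while stack is not empty do
       v ← stack.pop()
       if v ∉ visited then
              visited ← visited ∪ {v}
              S_v ← ICP(v, D, α/(k+1))
              S ← S ∪ S_v
              push every element of S_v onto stack
       end if
end while
return S
Algorithm 1 Linear MISA: Model-irrelevance State Abstractions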
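Algorithm 1 can be sketched in Python as follows. The ICP routine here is a simplified stand-in, not the exact procedure of Peters et al. (2016): for every subset of variables it fits a pooled least-squares model, tests whether the residuals have equal means (one-way ANOVA) and equal variances (Levene's test) across environments, and returns the intersection of all accepted subsets. The function names, the synthetic three-variable data, and the per-call level $\alpha/(k+1)$ are assumptions of this sketch.

```python
import itertools

import numpy as np
from scipy import stats


def simple_icp(X, y, env, alpha):
    """Simplified ICP: intersection of all subsets whose pooled regression
    residuals look identically distributed across environments."""
    k = X.shape[1]
    envs = np.unique(env)
    accepted = []
    for size in range(k + 1):
        for subset in itertools.combinations(range(k), size):
            cols = list(subset)
            # Pooled least-squares fit (with intercept) on the candidate subset.
            A = np.column_stack([np.ones(len(y)), X[:, cols]])
            coef, *_ = np.linalg.lstsq(A, y, rcond=None)
            groups = [(y - A @ coef)[env == e] for e in envs]
            # Equal residual means (ANOVA) and variances (Levene), Bonferroni-combined.
            p_mean = stats.f_oneway(*groups).pvalue
            p_var = stats.levene(*groups).pvalue
            if 2 * min(p_mean, p_var) >= alpha:
                accepted.append(set(cols))
    return set.intersection(*accepted) if accepted else set()


def linear_misa(X, X_next, reward, env, alpha):
    """Run ICP on the reward, then on the next-step value of every selected
    variable, collecting all causal ancestors of the reward."""
    k = X.shape[1]
    level = alpha / (k + 1)          # union bound over at most k + 1 ICP calls
    selected, visited, stack = set(), set(), ["r"]
    while stack:
        v = stack.pop()
        if v in visited:
            continue
        visited.add(v)
        target = reward if v == "r" else X_next[:, v]
        parents = simple_icp(X, target, env, level)
        selected |= parents
        stack.extend(parents)
    return sorted(selected)


# Hypothetical data: r depends on x0, x0' depends on x1, x2 is a spurious copy of x0.
rng = np.random.default_rng(0)
X, X_next, reward, env = [], [], [], []
for e, (scale0, mean1, c2) in enumerate([(1.0, 0.0, 1.0), (2.0, 2.0, -1.0), (0.5, -2.0, 2.0)]):
    n = 2_000
    x0 = rng.normal(0.0, scale0, n)
    x1 = rng.normal(mean1, 1.0, n)
    x2 = c2 * x0 + 0.1 * rng.normal(size=n)
    X.append(np.column_stack([x0, x1, x2]))
    X_next.append(np.column_stack([x1 + 0.1 * rng.normal(size=n),
                                   0.5 * x1 + 0.1 * rng.normal(size=n),
                                   rng.normal(size=n)]))
    reward.append(x0 + 0.1 * rng.normal(size=n))
    env.append(np.full(n, e))
X, X_next, reward, env = map(np.concatenate, (X, X_next, reward, env))
print(linear_misa(X, X_next, reward, env, alpha=0.05))
```

Each environment intervenes on the distribution of $x_0$ or $x_1$ and changes how the spurious $x_2$ tracks $x_0$. These interventions are what let the invariance tests reject every subset that leans on $x_2$.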

6.2 Learning a Model-irrelevance State Abstraction

We design an objective to learn a dynamics-preserving state abstraction $\mathcal{Z}$, or model-irrelevance abstraction (Li et al., 2006), where the similarity of the model is bounded by the model error in the environment setting shown in Figure 1. This requires disentangling the state space into a minimal representation that causes reward, $s$, and everything else, $\eta$. Our algorithm proceeds as follows.

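We assume the existence of an invariant state embedding, whose mapping function we denote by $\phi: \mathcal{X} \to \mathcal{Z}$. We also assume an invariant dynamics model $f_s$, a task-specific dynamics model $f_\eta$, and an invariant reward model $r$ in the embedding space. To incorporate a meaningful objective and ground the learned representation, we need a decoder $\phi^{-1}$ that maps latent states back to observations. We assume $N$ training environments are given. The overall dynamics objective $J_D$ measures how well the decoded predictions of $f_s$ and $f_\eta$ match the next observation, and the reward objective $J_R$ measures the error of the invariant reward model, both under data collected from behavioral policies $\pi_b^e$ for each experimental setting $e$.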

Of course, this does not guarantee that the representation learned by $\phi$ will be minimal, so we incorporate additional regularization as an incentive. We train a task classifier $C$ on the shared latent representation with a cross-entropy loss and employ an adversarial loss (Tzeng et al., 2017) on $\phi$ to maximize the entropy of the classifier output, so that task-specific information does not pass through to $\mathcal{Z}$.

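This gives us a final objective

$$J_{\text{ALL}} = J_D + \alpha_R J_R - \alpha_C H\big(C(\phi(x))\big),$$

where $C$ is the task classifier, $\alpha_R$ and $\alpha_C$ are hyperparameters, and $H$ denotes entropy (Algorithm 2).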

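Result: φ, an invariant state encoder

while forever do
       for e ∈ E_train do
              Collect transitions with behavioral policy π_b^e and add them to buffer D_e
       end for
       for e ∈ E_train do
              Sample batch B_e from D_e
              Update φ and the dynamics, reward, and decoder models to minimize J_ALL on B_e
              Update the task classifier C with the cross-entropy loss on B_e
       end for
end while
Algorithm 2 Nonlinear Model-irrelevance State Abstraction (MISA) Learning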
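The per-batch losses in Algorithm 2 can be assembled as below, assuming the final objective combines the terms as $J_D + \alpha_R J_R - \alpha_C H(C(\phi(x)))$, so that the encoder is rewarded for making the task classifier maximally uncertain. This NumPy sketch only evaluates losses for one batch: the network outputs are passed in as arrays, $J_D$ and $J_R$ are taken to be mean squared errors, and all argument names are hypothetical. In a real implementation an autodiff framework would route `j_all` gradients to the encoder and models and `ce` gradients to the classifier.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def misa_losses(x_next_pred, x_next, r_pred, r, class_logits, env_ids, alpha_r, alpha_c):
    """Return (encoder objective J_ALL, classifier cross-entropy) for one batch."""
    j_d = np.mean((x_next_pred - x_next) ** 2)       # decoded next-observation error
    j_r = np.mean((r_pred - r) ** 2)                 # reward prediction error
    probs = softmax(class_logits)
    # Mean entropy of the task classifier's prediction; the encoder maximizes it.
    entropy = -np.mean(np.sum(probs * np.log(probs + 1e-12), axis=1))
    j_all = j_d + alpha_r * j_r - alpha_c * entropy
    # The classifier itself is trained to identify the environment.
    ce = -np.mean(np.log(probs[np.arange(len(env_ids)), env_ids] + 1e-12))
    return j_all, ce


# Perfect predictions and a classifier that cannot tell the 2 environments apart.
j_all, ce = misa_losses(np.zeros((4, 3)), np.zeros((4, 3)), np.zeros(4), np.zeros(4),
                        np.zeros((4, 2)), np.array([0, 1, 0, 1]), alpha_r=1.0, alpha_c=0.5)
print(j_all, ce)   # approximately -0.5*log(2) and log(2)
```

When the latent code carries no environment information, the classifier's entropy reaches its maximum of $\log N$ for $N$ environments, which is the minimum the adversarial term can contribute to $J_{\text{ALL}}$.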

7 Results

We evaluate both linear and non-linear versions of MISA in corresponding Block MDP settings with linear and non-linear dynamics. First, we examine model error in environments with low-dimensional (Section 7.1.1) and high-dimensional (Section 7.1.2) observations and demonstrate the ability of MISA to generalize zero-shot to unseen environments. We next look to imitation learning in a rich observation setting (Section 7.2) and show that non-linear MISA generalizes to new camera angles. Finally, we explore end-to-end reinforcement learning in the low-dimensional observation setting with correlated noise (Section 7.3) and again show generalization capabilities where single-task and multi-task methods fail.

7.1 Model Learning

7.1.1 Linear Setting

We first evaluate the linear MISA algorithm in Algorithm 1. To empirically evaluate whether eliminating spurious variables from a representation is necessary to guarantee generalization, we consider a simple family of MDPs with a three-dimensional state space and linear transition dynamics, in which one of the variables is spurious. We train on 3 environments with soft interventions on each noise variable. We then run the linear MISA algorithm on batch data from these 3 environments to get a state abstraction $\phi(x)$, and then train 2 linear predictors, one on $\phi(x)$ and one on the full observation $x$. We then evaluate the generalization performance on novel environments that correspond to different hard interventions on the value of the spurious variable. We observe that the predictor trained on $\phi(x)$ attains zero generalization error because $\phi$ zeros out the spurious variable automatically. However, any nonzero weight on the spurious variable in the least-squares predictor will lead to arbitrarily large generalization error, which is precisely what we observe in Figure 3.
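The failure mode behind Figure 3 can be illustrated with a toy regression. This is a hypothetical setup of ours, not the paper's exact environment: the target depends only on a causal variable, a second variable is independent noise during training, and a hard intervention pins that noise variable at test time. Finite-sample least squares gives the noise variable a small but nonzero weight, so the test error of the full predictor grows with the intervened value, while the predictor restricted to the causal variable is unaffected.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample(n, spur_value=None):
    """Toy data: y depends only on x_causal; x_spur is independent noise,
    or is pinned by a hard intervention do(x_spur = spur_value)."""
    x_causal = rng.normal(size=n)
    y = x_causal + 0.5 * rng.normal(size=n)
    x_spur = rng.normal(size=n) if spur_value is None else np.full(n, spur_value)
    return np.column_stack([x_causal, x_spur]), y


X_train, y_train = sample(50)                                        # small training set
w_full, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)           # uses both variables
w_causal, *_ = np.linalg.lstsq(X_train[:, :1], y_train, rcond=None)  # causal variable only


def eval_mse(spur_value, n=2_000):
    """Test MSE of both predictors under the intervention do(x_spur = spur_value)."""
    X, y = sample(n, spur_value)
    mse_full = np.mean((X @ w_full - y) ** 2)
    mse_causal = np.mean((X[:, :1] @ w_causal - y) ** 2)
    return mse_full, mse_causal


print("weight on spurious variable:", w_full[1])   # small but nonzero
for v in [0.0, 10.0, 100.0]:
    print(v, eval_mse(v))
```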

Figure 3: The presence of spurious uncorrelated variables in the state can still lead to poor generalization of linear function approximation methods. Invariant Causal Prediction methods can eliminate these spurious variables altogether.

7.1.2 Rich Observation Setting

We next test the gradient-based MISA method (Algorithm 2) in a setting with nonlinear dynamics and rich observations. Instead of having access to observation variables and selecting the minimal causal feature set, we are tasked with learning the invariant causal representation. We randomly initialize the background colors of two training environments from DeepMind Control (Tassa et al., 2018), and randomly initialize another two backgrounds for evaluation. The orange line in Figure 4 shows performance on the evaluation environments in comparison to three baselines. In the first, we train our method on a single environment only and test on another (MISA - 1 env). With only a single experiment to observe at training time, there is no way to disentangle what is causal in terms of dynamics from what is not. In the second baseline, we combine data from the two environments and train a model over all data (Baseline - 1 decoder). The third is another invariance-based method, IRM (Arjovsky et al., 2019), which uses a gradient penalty. In the second case, the error is tempered by seeing variance in the two environments at training time, but it is not as effective as MISA with two environments at disentangling what is invariant, and therefore causal with respect to dynamics, from what is not. With IRM, the loss starts much higher but decreases very slowly, and we find it very brittle to tune in practice. Implementation details can be found in Section C.1.

Figure 4: Model error on evaluation environments on Cheetah Run from DeepMind Control. 10 seeds, with one standard error shaded.

7.2 Imitation Learning

In this setup, we first train an expert policy using the proprioceptive state of Cheetah Run from DeepMind Control (Tassa et al., 2018). We then use this policy to collect a dataset for imitation learning in each of two training environments. When rendering these low-dimensional images, we alter the camera angle in each environment (Figure 5). We report the generalization performance as the test error when predicting actions in Figure 6. While test error does increase with our method, MISA, it grows significantly more slowly than for the single-task and multi-task baselines.

Figure 5: The Cheetah Run environment from DeepMind Control with different camera angles. The first two images are from the training environments and the last image is from the evaluation environment.
Figure 6: Actor error on evaluation environments on Cheetah Run from DeepMind Control. 10 seeds, with one standard error shaded.

7.3 Reinforcement Learning

We return to the proprioceptive state in the cartpole_swingup environment in DeepMind Control (Tassa et al., 2018) to show that we can learn MISA while training a policy. We use Soft Actor Critic (Haarnoja et al., 2018) with an additional linear encoder, and add spurious correlated dimensions that are the original state scaled by a multiplicative factor. We also add an environment identifier to the observation. This multiplicative factor varies across environments: we train on two environments with different factors and test on an environment with a held-out factor. Like Arjovsky et al. (2019), we also add noise to the causal state to make the task harder, specifically Gaussian noise on the true state dimensions. This incentivizes the agent to attend to the spuriously correlated dimensions instead, which have no noise. Figure 7 shows that our method drastically reduces the generalization gap compared to training SAC on data from all environments in aggregate and to IRM (Arjovsky et al., 2019) implemented on the critic loss. Implementation details and more information about Soft Actor Critic can be found in Section C.2.
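The observation construction described above can be sketched as follows. The scaling factors, the noise scale, and the one-hot form of the environment identifier are hypothetical placeholders, not the values or encoding used in the experiments.

```python
import numpy as np


def make_observation(state, env_index, factors, noise_std, rng):
    """Build [noisy true state, spurious scaled copy, environment id].

    factors[env_index] is that environment's multiplicative factor and
    noise_std is the std of the Gaussian noise added to the true state.
    """
    noisy_state = state + noise_std * rng.normal(size=state.shape)  # noisy causal dims
    spurious = factors[env_index] * state                           # noise-free, env-dependent
    env_id = np.eye(len(factors))[env_index]                        # one-hot identifier
    return np.concatenate([noisy_state, spurious, env_id])


rng = np.random.default_rng(0)
factors = [1.0, 2.0, 3.0]   # two training factors plus one held-out factor (hypothetical)
obs = make_observation(np.arange(5, dtype=float), 1, factors, noise_std=0.1, rng=rng)
print(obs.round(2))
```

The spurious block is a cleaner signal than the noisy true state within any single environment, which is exactly why an agent trained without an invariance constraint is tempted to rely on it and then fails when the factor changes.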

Figure 7: Generalization gap in SAC performance with 2 training environments on cartpole_swingup from DMC. Evaluated with 10 seeds, standard error shaded.

8 Related Work

8.1 Prior Work on Generalization Bounds

Generalization bounds provide guarantees on the test set error attained by an algorithm. Most of these bounds are probabilistic and targeted at the supervised setting, falling into the PAC (Probably Approximately Correct) framework. PAC bounds give probabilistic guarantees on a model's true error as a function of its train set error and the complexity of the function class encoded by the model. Many measures of hypothesis class complexity exist: the Vapnik-Chervonenkis (VC) dimension (Vapnik and Chervonenkis, 1971), the Lipschitz constant and classification margin of a neural network (Bartlett et al., 2017b), and second-order properties of the loss landscape (Neyshabur et al., 2019) are just a few of many.

Analogous techniques can be applied to Bayesian methods, giving rise to PAC-Bayes bounds (McAllester, 1999). This family of bounds can be optimized to yield non-vacuous bounds on the test error of over-parametrized neural networks (Dziugaite and Roy, 2017), and has shown strong empirical correlation with model generalization (Jiang* et al., 2020). More recently, Amit and Meir (2018) and Yin et al. (2019) introduced PAC-Bayes bounds for the multi-task setting that depend on the number of tasks seen at training time.
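One commonly used form of the PAC-Bayes bound applies to losses in [0, 1]: McAllester's bound in Maurer's refinement. For a prior P and any posterior Q over hypotheses, it states that, with probability at least 1 - δ, L(Q) ≤ L̂(Q) + sqrt((KL(Q‖P) + ln(2√m/δ)) / (2m)). In this form the KL divergence replaces the complexity term. The sketch below evaluates the bound for diagonal Gaussian P and Q, a choice similar to the stochastic-network setting cited above. Function names and numbers are illustrative assumptions.

```python
import numpy as np


def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(Q || P) for diagonal Gaussians Q = N(mu_q, sigma_q^2), P = N(mu_p, sigma_p^2)."""
    var_q, var_p = sigma_q ** 2, sigma_p ** 2
    return 0.5 * np.sum(var_q / var_p + (mu_p - mu_q) ** 2 / var_p - 1.0 + np.log(var_p / var_q))


def pac_bayes_bound(empirical_risk, kl, num_samples, delta):
    """Bound for a loss in [0, 1]:
    L(Q) <= L_hat(Q) + sqrt((KL(Q||P) + ln(2 sqrt(m) / delta)) / (2m))."""
    m = num_samples
    return empirical_risk + np.sqrt((kl + np.log(2.0 * np.sqrt(m) / delta)) / (2.0 * m))


# Posterior shifted away from a standard-normal prior in a 3-parameter model.
kl = kl_diag_gaussians(np.array([0.5, -0.5, 0.0]), np.full(3, 0.8), np.zeros(3), np.ones(3))
bound = pac_bayes_bound(empirical_risk=0.1, kl=kl, num_samples=10_000, delta=0.05)
```

Tightening the bound in practice means trading a lower empirical risk against a posterior that stays close to the prior.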

Strehl et al. (2006) extend the PAC framework to reinforcement learning, defining a new class of bounds called PAC-MDP. An algorithm is PAC-MDP if, for any $\epsilon > 0$ and $\delta \in (0, 1)$, its sample complexity is less than some polynomial in the relevant quantities ($|S|$, $|A|$, $1/\epsilon$, $1/\delta$, $1/(1-\gamma)$) with probability at least $1 - \delta$. The authors provide a model-free PAC-MDP algorithm based on Q-learning. Lattimore and Hutter (2012) give lower and upper bounds on the sample complexity of learning near-optimal behavior in MDPs by modifying the Upper Confidence RL (UCRL) algorithm (Jaksch et al., 2010).
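The quantity bounded in the PAC-MDP definition is the sample complexity of exploration. It counts the timesteps at which the agent's current policy is more than ε worse than optimal from the current state. Below is a minimal sketch of that count. It assumes the agent's values and the optimal values along a trajectory are given; these are analysis quantities, not something the agent observes. The function name is hypothetical.

```python
def sample_complexity_of_exploration(agent_values, optimal_values, epsilon):
    """Count timesteps t with V^{A_t}(s_t) < V^*(s_t) - epsilon.

    agent_values[t]:   value of the agent's current (possibly non-stationary) policy from s_t
    optimal_values[t]: optimal value V^*(s_t) of the state visited at time t
    """
    return sum(1 for v, v_star in zip(agent_values, optimal_values) if v < v_star - epsilon)


mistakes = sample_complexity_of_exploration([0.0, 0.9, 0.5, 1.0], [1.0, 1.0, 1.0, 1.0], epsilon=0.2)
# mistakes == 2: steps 0 and 2 are more than 0.2 below optimal
```

An algorithm is PAC-MDP when, with probability at least 1 - δ, this count stays below a polynomial in the quantities listed above.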

8.2 Multi-Task Reinforcement Learning

Teh et al. (2017) and Borsa et al. (2016) approach multi-task reinforcement learning with, respectively, a shared “distilled” policy and a shared state-action representation that capture common or invariant behavior across all tasks. Neither work makes assumptions about how the tasks relate to each other beyond a shared state and action space.

D’Eramo et al. (2020) show the benefits of learning a shared representation in multi-task settings with an approximate value iteration bound, and Brunskill and Li (2013) give a PAC-MDP algorithm with improved sample-complexity bounds through transfer across similar tasks. Again, none of these works considers the multi-environment setting to explicitly leverage environment structure. Barreto et al. (2017) exploit successor features for transfer, assuming that the dynamics are shared across tasks while the reward changes. However, they do not handle the setting where states are latent and observations change.
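The shared-dynamics, changing-reward assumption of the successor-feature approach can be illustrated with a minimal sketch. Assume rewards are linear in a feature vector, r = features · w. Then each stored policy's action values factor as Q(s, a) = ψ(s, a) · w, where ψ (the successor features) depends only on the dynamics and the policy. A new task therefore only changes w, and generalized policy improvement (GPI) acts greedily with respect to the best stored policy. The function name and array layout below are our assumptions.

```python
import numpy as np


def gpi_action(psi, w):
    """Generalized policy improvement with successor features at a single state.

    psi: shape (num_policies, num_actions, d), successor features psi^{pi_i}(s, a)
    w:   shape (d,), reward weights of the new task
    Returns argmax_a max_i psi^{pi_i}(s, a) . w
    """
    q_values = psi @ w                      # (num_policies, num_actions)
    return int(np.argmax(q_values.max(axis=0)))


# Two stored policies, two actions, 2-D features; only w changes between tasks.
psi = np.array([[[1.0, 0.0], [0.0, 1.0]],
                [[0.0, 3.0], [2.0, 0.0]]])
action_task_a = gpi_action(psi, np.array([1.0, 0.0]))  # 1
action_task_b = gpi_action(psi, np.array([0.0, 1.0]))  # 0
```

Because ψ is reused across tasks, this decomposition relies on the observation and latent dynamics being fixed. That is exactly the assumption relaxed in the block MDP setting.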

9 Discussion

This work has demonstrated that, under certain assumptions, causal inference methods can be used in reinforcement learning to learn an invariant causal representation that generalizes across environments with a shared causal structure. We have provided a framework for defining families of environments, together with methods for both the low-dimensional linear value function approximation setting and the deep RL setting that leverage invariant prediction to extract a causal representation of the state. We have further provided error bounds and identifiability results for these representations. We see this paper as a first step towards the more significant problem of learning useful representations for generalization across a broader class of environments. Potential applications include third-person imitation learning, sim2real transfer, and, related to sim2real transfer, the use of privileged information in one task (the simulation) as grounding for generalization to new observation spaces (Salter et al., 2019).

10 Acknowledgements

MK has received funding from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No. 834115). The authors would also like to thank Marlos Machado for helpful feedback in the writing process.

References

  • R. Amit and R. Meir (2018) Meta-learning by adjusting priors based on extended PAC-Bayes theory. In Proceedings of the 35th International Conference on Machine Learning, J. Dy and A. Krause (Eds.), Proceedings of Machine Learning Research, Vol. 80, Stockholmsmässan, Stockholm, Sweden, pp. 205–214.
  • M. Arjovsky, L. Bottou, I. Gulrajani, and D. Lopez-Paz (2019) Invariant Risk Minimization. arXiv preprint arXiv:1907.02893.
  • S. Arora, R. Ge, B. Neyshabur, and Y. Zhang (2018) Stronger generalization bounds for deep nets via a compression approach. In 35th International Conference on Machine Learning, ICML 2018, A. Krause and J. Dy (Eds.), pp. 390–418.
  • J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint.
  • A. Barreto, W. Dabney, R. Munos, J. J. Hunt, T. Schaul, H. P. van Hasselt, and D. Silver (2017) Successor features for transfer in reinforcement learning. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), pp. 4055–4065.
  • P. L. Bartlett, D. J. Foster, and M. J. Telgarsky (2017a) Spectrally-normalized margin bounds for neural networks. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), pp. 6240–6249.
  • P. L. Bartlett, D. J. Foster, and M. J. Telgarsky (2017b) Spectrally-normalized margin bounds for neural networks. In Advances in Neural Information Processing Systems, pp. 6240–6249.
  • D. Bertsekas and D. Castanon (1989) Adaptive aggregation for infinite horizon dynamic programming. IEEE Transactions on Automatic Control 34, pp. 589–598.
  • D. Borsa, T. Graepel, and J. Shawe-Taylor (2016) Learning shared representations in multi-task reinforcement learning. CoRR abs/1603.02041.
  • E. Brunskill and L. Li (2013) Sample complexity of multi-task reinforcement learning. In Uncertainty in Artificial Intelligence: Proceedings of the 29th Conference, UAI 2013.
  • P. S. Castro and D. Precup (2010) Using bisimulation for policy transfer in MDPs. In Twenty-Fourth AAAI Conference on Artificial Intelligence.
  • K. Cobbe, O. Klimov, C. Hesse, T. Kim, and J. Schulman (2018) Quantifying generalization in reinforcement learning. CoRR abs/1812.02341.
  • C. D’Eramo, D. Tateo, A. Bonarini, M. Restelli, and J. Peters (2020) Sharing knowledge in multi-task deep reinforcement learning. In International Conference on Learning Representations.
  • S. S. Du, A. Krishnamurthy, N. Jiang, A. Agarwal, M. Dudík, and J. Langford (2019) Provably efficient RL with rich observations via latent state decoding. CoRR abs/1901.09018.
  • G. K. Dziugaite and D. M. Roy (2017) Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data.
  • F. Eberhardt and R. Scheines (2007) Interventions and causal inference. Philosophy of Science 74 (5), pp. 981–995.
  • C. Gelada, S. Kumar, J. Buckman, O. Nachum, and M. G. Bellemare (2019) DeepMDP: learning continuous latent space models for representation learning. In Proceedings of the 36th International Conference on Machine Learning, K. Chaudhuri and R. Salakhutdinov (Eds.), Proceedings of Machine Learning Research, Vol. 97, Long Beach, California, USA, pp. 2170–2179.
  • R. Givan, T. L. Dean, and M. Greig (2003) Equivalence notions and model minimization in Markov decision processes. Artif. Intell. 147, pp. 163–223.
  • T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine (2018) Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In Proceedings of the 35th International Conference on Machine Learning, J. Dy and A. Krause (Eds.), Proceedings of Machine Learning Research, Vol. 80, Stockholmsmässan, Stockholm, Sweden, pp. 1861–1870.
  • T. Jaksch, R. Ortner, and P. Auer (2010) Near-optimal regret bounds for reinforcement learning. J. Mach. Learn. Res. 11, pp. 1563–1600.
  • Y. Jiang*, B. Neyshabur*, D. Krishnan, H. Mobahi, and S. Bengio (2020) Fantastic generalization measures and where to find them. In International Conference on Learning Representations.
  • K. G. Larsen and A. Skou (1989) Bisimulation through probabilistic testing (preliminary report). In Proceedings of the 16th ACM SIGPLAN-SIGACT Symposium on Principles of Programming Languages, POPL ’89, New York, NY, USA, pp. 344–352.
  • T. Lattimore and M. Hutter (2012) PAC bounds for discounted MDPs. In International Conference on Algorithmic Learning Theory, pp. 320–334.
  • L. Li, T. Walsh, and M. Littman (2006) Towards a unified theory of state abstraction for MDPs.
  • Y. Luo, H. Xu, Y. Li, Y. Tian, T. Darrell, and T. Ma (2019) Algorithmic framework for model-based deep reinforcement learning with theoretical guarantees. In International Conference on Learning Representations.
  • D. A. McAllester (1999) PAC-Bayesian model averaging. In COLT, Vol. 99, pp. 164–170.
  • B. Neyshabur, Z. Li, S. Bhojanapalli, Y. LeCun, and N. Srebro (2019) The role of over-parametrization in generalization of neural networks. In International Conference on Learning Representations.
  • J. Pearl (2009) Causality: models, reasoning and inference. 2nd edition, Cambridge University Press, New York, NY, USA.
  • J. Peters, P. Bühlmann, and N. Meinshausen (2016) Causal inference using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society, Series B (with discussion) 78 (5), pp. 947–1012.
  • B. V. Roy (2006) Performance loss bounds for approximate value iteration with state aggregation. Math. Oper. Res. 31 (2), pp. 234–244.
  • S. Salter, D. Rao, M. Wulfmeier, R. Hadsell, and I. Posner (2019) Attention privileged reinforcement learning for domain transfer. arXiv preprint arXiv:1911.08363.
  • B. Schölkopf, D. Janzing, J. Peters, E. Sgouritsa, K. Zhang, and J. Mooij (2012) On causal and anticausal learning. In Proceedings of the 29th International Conference on Machine Learning, ICML’12, Madison, WI, USA, pp. 459–466.
  • X. Song, Y. Jiang, S. Tu, Y. Du, and B. Neyshabur (2020) Observational overfitting in reinforcement learning. In International Conference on Learning Representations.
  • A. L. Strehl, L. Li, E. Wiewiora, J. Langford, and M. L. Littman (2006) PAC model-free reinforcement learning. In Proceedings of the 23rd International Conference on Machine Learning, pp. 881–888.
  • Y. Tassa, Y. Doron, A. Muldal, T. Erez, Y. Li, D. de Las Casas, D. Budden, A. Abdolmaleki, J. Merel, A. Lefrancq, T. Lillicrap, and M. Riedmiller (2018) DeepMind control suite. Technical report Vol. abs/1504.04804, DeepMind.
  • Y. Teh, V. Bapst, W. M. Czarnecki, J. Quan, J. Kirkpatrick, R. Hadsell, N. Heess, and R. Pascanu (2017) Distral: robust multitask reinforcement learning. In Advances in Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (Eds.), pp. 4496–4506.
  • E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell (2017) Adversarial discriminative domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Los Alamitos, CA, USA, pp. 2962–2971.
  • V. N. Vapnik and A. Ya. Chervonenkis (1971) On the uniform convergence of relative frequencies of events to their probabilities. Theory of Probability and its Applications 16 (2), pp. 264–280.
  • D. Yarats and I. Kostrikov (2020) Soft actor-critic (SAC) implementation in PyTorch. GitHub repository.
  • M. Yin, G. Tucker, M. Zhou, S. Levine, and C. Finn (2019) Meta-learning without memorization. arXiv preprint arXiv:1912.03820.
  • A. Zhang, Y. Wu, and J. Pineau (2018a) Natural environment benchmarks for reinforcement learning. CoRR abs/1811.06032.
  • C. Zhang, O. Vinyals, R. Munos, and S. Bengio (2018b) A study on overfitting in deep reinforcement learning. CoRR abs/1804.06893.

Appendix A Notation

We provide a summary of key notation used throughout the paper here.

Appendix B Proofs

Technical notes and assumptions. For the block MDP assumption to be satisfied, we require that the interventions defining each environment occur only outside the causal ancestors of the reward. Otherwise, different environments would have different latent state dynamics, violating our assumption that the environments are obtained by a noisy emission function from the shared latent state space. Although ICP will still find the correct causal variables in this setting, the resulting state abstraction will no longer be a model-irrelevance state abstraction over the union of the environments.
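The two model-irrelevance conditions used throughout this appendix can be checked directly on a small tabular MDP. The conditions are equal rewards, and equal distributions over next-state equivalence classes, for any two states with the same abstraction. The sketch below is a hedged illustration: the function name, the array layout and the tolerance are our assumptions. To test an abstraction over a union of environments, one would apply the same check to each environment's reward and transition arrays.

```python
import numpy as np


def is_model_irrelevance(phi, rewards, transitions, atol=1e-8):
    """Check the model-irrelevance conditions for a finite MDP.

    phi:         int array, phi[s] = abstract state of ground state s
    rewards:     array (num_states, num_actions), R(s, a)
    transitions: array (num_actions, num_states, num_states), T(s' | s, a)
    """
    num_abstract = phi.max() + 1
    membership = np.eye(num_abstract)[phi]     # (num_states, num_abstract) one-hot classes
    class_probs = transitions @ membership     # P(phi(s') = k | s, a), (A, S, K)
    for s1 in range(len(phi)):
        for s2 in range(s1 + 1, len(phi)):
            if phi[s1] != phi[s2]:
                continue
            if not np.allclose(rewards[s1], rewards[s2], atol=atol):
                return False                   # reward condition violated
            if not np.allclose(class_probs[:, s1], class_probs[:, s2], atol=atol):
                return False                   # transition condition violated
    return True


phi = np.array([0, 0, 1])                      # states 0 and 1 share an abstract state
R = np.array([[1.0], [1.0], [0.0]])            # one action
T = np.array([[[0.0, 0.0, 1.0],
               [0.0, 0.0, 1.0],
               [0.5, 0.5, 0.0]]])
print(is_model_irrelevance(phi, R, T))         # True
```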

Theorem 1.

Consider a family of MDPs, one for each environment, that satisfies Assumptions 1-3. Let the reward be a function only of a subset of the state variables (the state restricted to those indices). Let AN(R) denote the ancestors of this subset in the (fully observable) causal graph corresponding to the transition dynamics. Then the state abstraction that restricts each state to the variables in AN(R) is a model-irrelevance abstraction for every environment in the family.


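To prove that this abstraction is a model-irrelevance abstraction, we must first show that any two states with the same abstraction receive the same reward. For this, we note that the reward depends only on the reward variables. Because these variables are by definition contained in AN(R), two such states agree on them. Therefore, their rewards are equal.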


To show that the abstraction is a MISA, we must also show the following. For any two states with the same abstraction, and for any action, the distribution over next-state equivalence classes is the same for both states.

For this, it suffices to observe two facts. First, AN(R) is closed under taking parents in the causal graph. Second, by construction, environments only contain interventions on variables outside the causal set. Specifically, the probability of seeing any particular equivalence class after a given state is a function only of that state's variables in AN(R).

This allows us to define a natural decomposition of the transition function as follows.

We further observe that since the components of are independent, . We now return to the property we want to show:

and because , we have
for which we can apply the previous chain of equalities backward to obtain

Proposition 1 (Identifiability and Uniqueness of Causal State Abstraction).

In the setting of the previous theorem, assume the transition dynamics and reward are linear functions of the current state. If the training environment set satisfies any of the conditions of Theorem 2 (Peters et al., 2016) with respect to each variable in AN(R), then the causal feature set is identifiable. Conversely, if the training environments do not contain sufficient interventions, there may exist a feature set other than AN(R) whose abstraction is a model-irrelevance abstraction over the training environments but not over the full set of environments.


The proof of the first statement follows immediately from the iterative application of the identifiability result of Peters et al. (2016) to each variable in the causal variables set.
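The identifiability argument leans on invariant causal prediction. The sketch below is a hedged, simplified illustration, not the authors' implementation. It enumerates predictor subsets, fits a pooled least-squares model on each, and accepts a subset when the residuals look identically distributed across environments. The estimate is the intersection of all accepted subsets. The function name `icp_sketch`, the use of one-way ANOVA for the residual means and Levene's test for the residual variances, and the Bonferroni split are our choices. They stand in for the residual tests of Peters et al. (2016). In the proposition, the target would be the reward or one next-state variable in AN(R).

```python
import itertools

import numpy as np
from scipy import stats


def icp_sketch(X, y, env, alpha=0.05):
    """Simplified invariant causal prediction in the spirit of Peters et al. (2016).

    X: (n, p) candidate predictors, y: (n,) target, env: (n,) environment labels.
    A subset S is accepted when the residuals of a pooled least-squares fit of
    y on X[:, S] show no significant difference in mean (one-way ANOVA) or
    variance (Levene's test) across environments. Returns the intersection of
    all accepted subsets (empty if none is accepted).
    """
    n, p = X.shape
    labels = np.unique(env)
    accepted = []
    for k in range(p + 1):
        for subset in itertools.combinations(range(p), k):
            cols = np.array(subset, dtype=int)
            design = np.column_stack([np.ones(n), X[:, cols]])  # intercept + subset
            coef, *_ = np.linalg.lstsq(design, y, rcond=None)
            resid = y - design @ coef
            groups = [resid[env == e] for e in labels]
            p_mean = stats.f_oneway(*groups).pvalue
            p_var = stats.levene(*groups).pvalue
            if min(p_mean, p_var) > alpha / 2:  # Bonferroni over the two tests
                accepted.append(set(subset))
    return set.intersection(*accepted) if accepted else set()


# Synthetic example: x0 -> y -> x1, with interventions on x0 and on the noise of x1.
rng = np.random.default_rng(0)
env = np.repeat([0, 1], 500)
x0 = rng.normal(size=1000) + 2.0 * env
y = 1.5 * x0 + rng.normal(size=1000)
x1 = y + rng.normal(size=1000) * (1.0 + 2.0 * env)
estimate = icp_sketch(np.column_stack([x0, x1]), y, env)
```

The enumeration is exponential in the number of predictors, which is one reason this style of method suits the low-dimensional linear setting. The result should be a subset of the causal parents (here {0}) with high probability. It is not guaranteed for every random draw.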

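For the converse, we consider a simple counterexample in which one variable in AN(R) takes the same constant value in every training environment. Then the abstraction that additionally drops this variable is also a model-irrelevance state abstraction over the training environments. First, we show that it assigns equal rewards to any two states it maps to the same abstract state.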

Finally, we must show that

Again starting from the result of Theorem 1 we have:

and because , we have
for which we can apply the previous chain of equalities backward to obtain

However, if one of the test environments contains an intervention on this variable, then the distribution over next states in the new environment will violate the conditions for a model-irrelevance abstraction. ∎
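A linear toy version of this counterexample shows the failure mode numerically. The dynamics are hypothetical and chosen only for illustration. The variable x2 is a causal parent of x1 but is pinned to a constant c in training. A model of the next x1 that ignores x2 can absorb c into its intercept and fit the training data exactly. Once a test environment intervenes on x2, the same model breaks.

```python
import numpy as np

rng = np.random.default_rng(0)
c = 1.0  # hypothetical constant value of x2 in every training environment


def next_x1(x1, x2):
    """Toy linear dynamics: x2 is a causal parent of x1."""
    return 0.5 * x1 + x2


# Training environments: x2 is pinned to c, so x1 alone predicts next_x1 exactly.
x1_train = rng.normal(size=200)
y_train = next_x1(x1_train, np.full(200, c))
design = np.column_stack([np.ones(200), x1_train])       # abstraction that drops x2
coef, *_ = np.linalg.lstsq(design, y_train, rcond=None)  # learns intercept c, slope 0.5
train_err = np.max(np.abs(design @ coef - y_train))

# Test environment: an intervention moves x2 away from c.
x1_test = rng.normal(size=200)
x2_test = rng.normal(loc=-2.0, size=200)
pred = np.column_stack([np.ones(200), x1_test]) @ coef
test_err = np.max(np.abs(pred - next_x1(x1_test, x2_test)))
```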

Theorem 2.

Consider an MDP , with denoting a coarser bisimulation of . Let denote the mapping from states of to states of . Suppose that the dynamics of are -Lipschitz w.r.t. and that is some approximate transition model satisfying , for some . Let denote the 1-Wasserstein distance. Then


We will use the shorthand for , the distribution of state embeddings corresponding to the behaviour policy, and for for the distribution of state embeddings given by the behaviour policy.

Let be a coupling over the distributions of and such that
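The proof works with a coupling between two distributions of state embeddings. The 1-Wasserstein distance is the infimum, over all such couplings, of the expected distance between the paired samples. As a hedged illustration restricted to one-dimensional embeddings with equal sample sizes, the optimal coupling simply matches sorted samples. This gives the short computation below; the function name is illustrative.

```python
import numpy as np


def wasserstein1_empirical(a, b):
    """1-Wasserstein distance between two equal-size 1-D empirical distributions.

    In one dimension the optimal coupling pairs order statistics, so
    W1 = mean |a_(i) - b_(i)| over the sorted samples.
    """
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(a - b)))


w = wasserstein1_empirical([3.0, 1.0, 2.0], [0.0, 1.0, 2.0])  # 1.0: every sample shifts by 1
```

For unequal sample sizes or weighted samples, `scipy.stats.wasserstein_distance` computes the same 1-D quantity. Higher-dimensional embeddings require solving an optimal transport problem instead.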

Theorem 4 (Existence of model-irrelevance state abstractions).

Let denote some family of bisimilar MDPs with joint state space . Let the mapping from states in to the underlying abstract MDP be denoted by . If the states in satisfy , then is a model-irrelevance state abstraction for .


First, note that is well-defined (because each agrees with the rest on the value of all states appearing in multiple tasks). Then will be a model-irrelevance abstraction for every MDP because it agrees with (a model-irrelevance abstraction). ∎

Theorem 3.

Let be our block MDP and the learned invariant MDP with a mapping . For any -Lipschitz valued policy the value difference is bounded by