1 Introduction
In reinforcement learning (RL), model-based agents are traditionally characterized by their ability to predict future states and rewards based on previous states and actions
[Sutton and Barto, 1998, Ha and Schmidhuber, 2018, Hafner et al., 2019a]. Model-based methods can be seen through the representation learning [Goodfellow et al., 2017] lens as endowing the agent with internal representations that are predictive of the future conditioned on its actions. This ultimately gives the agent the ability to plan – e.g., by considering a distribution of possible future trajectories and picking the best course of action. Model-free methods, on the other hand, do not learn an explicit model of the environment, and instead focus on learning a policy that maximizes reward or a function that estimates the optimal values of states and actions
[Mnih et al., 2013, Schulman et al., 2017, Pong et al., 2018]. They can use large amounts of training data and excel in high-dimensional state and action spaces. However, this is mostly true for fixed reward functions. Despite success on many benchmarks, model-free agents typically generalize poorly when the environment or reward function changes [Farebrother et al., 2018, Tachet des Combes et al., 2018] and can have high sample complexity. Viewing model-based agents from a representation learning perspective, a desired outcome is an agent able to understand the underlying generative factors of the environment that determine the observed state/action sequences, which would help generalization to environments built from the same generative factors. In addition, learning a predictive model often involves much richer learning signals than those provided by reward alone, which could reduce sample complexity compared to model-free methods.
Our work is based on the hypothesis that a model-free agent whose representations are predictive of properties of future states (beyond expected rewards) will be more capable of solving and adapting to new RL problems, and in a way, incorporate aspects of model-based learning. To learn representations with model-like properties, we consider a self-supervised objective derived from variants of Deep InfoMax [DIM, Hjelm et al., 2018, Bachman et al., 2019, Anand et al., 2019]. We expect this type of contrastive estimation [Hyvarinen and Morioka, 2016] will give the agent a better understanding of the underlying factors of the environment and how they relate to its actions, eventually leading to better performance in transfer and lifelong learning problems. We examine the properties of the learnt representations in simple domains such as disjoint and glued Markov chains, and more complex environments such as a 2D Ising model, a sequential variant of Ms. PacMan from the Arcade Learning Environment [ALE, Bellemare et al., 2013], and all 16 games from the Procgen suite [Cobbe et al., 2019]. Our contributions are as follows:
- We propose a simple auxiliary objective that maximizes concordance between representations of successive states, given the action.
- We present a series of experiments showing how our objective can be used as a measure of similarity and predictability, and how it behaves in partially deterministic systems.
- Finally, we show that augmenting a standard C51 agent [Bellemare et al., 2017] with our contrastive objective can i) lead to faster adaptation in a continual learning setting, and ii) improve overall performance on the Procgen suite.
2 Background
Just as humans are able to retain old skills when taught new ones [Wixted, 2004], we strive for RL agents that are able to adapt quickly and reuse knowledge when dealing with a sequence of different tasks with variable reward functions. The reason for this is that real-world applications or downstream tasks can be difficult to predict before deployment, particularly with complex environments involving other intelligent agents such as humans. Unfortunately, this proves to be very challenging even for state-of-the-art systems [Atkinson et al., 2018], leading to complex deployment scenarios.
Continual Learning (CL) is a learning framework meant to benchmark an agent’s ability to adapt to new tasks by using auxiliary information about the relatedness across tasks and timescales [Kaplanis et al., 2018, Mankowitz et al., 2018]. Meta-learning [Thrun and Pratt, 1998, Finn et al., 2017] and multi-task learning [Hessel et al., 2019, D’Eramo et al., 2019] have shown good performance in CL by explicitly training the agent to transfer well between tasks.
In this study, we focus on the following inductive bias: while the reward function may change or vary, the underlying environment dynamics typically do not change as much. (This is not true in all generalization settings; generalization still has a variety of specifications within RL. In our work, we focus on the setting where the rewards change more rapidly than the environment dynamics.)
To test whether this inductive bias is useful, we use auxiliary loss functions based on contrastive representation learning, which we briefly review next.
Contrastive representation learning methods are based on training an encoder to capture information that is shared across different views of the data in the features it produces for each input. The similar (i.e. positive) examples are typically either taken from different “locations” of the data [e.g., spatial patches or temporal locations, see Hjelm et al., 2018, Oord et al., 2018, Anand et al., 2019, Hénaff et al., 2019] or obtained through data augmentation [Wu et al., 2018, He et al., 2019, Bachman et al., 2019, Tian et al., 2019, Chen et al., 2020].
Contrastive models rely on a variety of objectives to encourage similarity between features. Typically, a scoring function [e.g., a dot product or cosine similarity between pairs of features, see Wu et al., 2018, Oord et al., 2018] is trained to assign higher scores to positive pairs than to negative ones.
3 Preliminaries
3.1 Markov Chains
Given a discrete state space $\mathcal{S}$ with probability measure $\mu$, a discrete-time homogeneous Markov chain (MC) is a collection of random variables $\{X_t\}_{t \ge 0}$ with the following property on its transition matrix $P$: $\Pr[X_{t+1} = s' \mid X_t = s, \ldots, X_0 = s_0] = \Pr[X_{t+1} = s' \mid X_t = s] = P(s, s')$. Assuming the Markov chain is ergodic, its invariant distribution $\rho$ (whose existence and uniqueness are direct results of the Perron-Frobenius theorem) is the principal eigenvector of $P^\top$, which verifies $\rho^\top P = \rho^\top$ and summarizes the long-term behaviour of the chain. We define the marginal distribution of $X_t$ as $\rho_t$, and the initial distribution of $X_0$ as $\rho_0$.

Theorem 3.1 (Levin and Peres [2017]). Let $\epsilon > 0$ and let $t_{\mathrm{mix}}(\epsilon)$ be the mixing time of the chain induced by $P$. If $1 = \lambda_1 > \lambda_2 \ge \ldots \ge \lambda_n > -1$ are the eigenvalues of $P$ and $\lambda^\star = \max\{|\lambda_2|, |\lambda_n|\}$, then

$$t_{\mathrm{mix}}(\epsilon) \;\le\; \log\!\left(\frac{1}{\epsilon\, \rho_{\min}}\right)\frac{1}{1 - \lambda^\star} \qquad (1)$$

where $\rho_{\min} = \min_{s} \rho(s)$.
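For small chains, these spectral quantities can be computed directly. The snippet below is an illustration we add here (not code from the paper), assuming numpy; it recovers the invariant distribution, the absolute spectral gap and the mixing-time bound of Eq. 1 for a row-stochastic transition matrix.

```python
import numpy as np

def chain_summary(P, eps=0.01):
    """Invariant distribution, absolute spectral gap and the mixing-time
    bound of Eq. 1 for a row-stochastic transition matrix P."""
    evals, evecs = np.linalg.eig(P.T)
    idx = np.argmin(np.abs(evals - 1.0))       # principal eigenvalue of P^T is 1
    rho = np.abs(np.real(evecs[:, idx]))
    rho = rho / rho.sum()                      # invariant distribution
    lam = np.sort(np.abs(evals))[::-1]
    lam_star = lam[1]                          # largest non-principal |eigenvalue|
    t_mix_bound = np.log(1.0 / (eps * rho.min())) / (1.0 - lam_star)
    return rho, 1.0 - lam_star, t_mix_bound

# Example: a lazy 3-state chain.
P = np.array([[0.8, 0.2, 0.0],
              [0.1, 0.8, 0.1],
              [0.0, 0.2, 0.8]])
print(chain_summary(P))
```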
3.2 Markov Decision Processes
A discrete-time, finite-horizon Markov Decision Process (MDP) [Bellman, 1957, Puterman, 2014] comprises a state space $\mathcal{S}$, an action space $\mathcal{A}$ (we consider discrete state and action spaces), a transition kernel $\mathcal{P}$, a reward function $r$ and a discount factor $\gamma \in [0, 1)$. At every timestep $t$, an agent interacting with this MDP observes the current state $s_t \in \mathcal{S}$, selects an action $a_t \in \mathcal{A}$, and observes a reward $r(s_t, a_t)$ upon transitioning to a new state $s_{t+1} \sim \mathcal{P}(\cdot \mid s_t, a_t)$. The goal of an agent in a discounted MDP is to learn a policy $\pi : \mathcal{S} \to \Delta(\mathcal{A})$ such that taking actions $a_t \sim \pi(\cdot \mid s_t)$ maximizes the expected sum of discounted returns $\mathbb{E}_\pi\big[\sum_t \gamma^t r(s_t, a_t)\big]$.

To convert an MDP into an MC, one can marginalize the transition kernel over the policy, $P_\pi(s' \mid s) = \sum_{a \in \mathcal{A}} \pi(a \mid s)\, \mathcal{P}(s' \mid s, a)$, an operation which can be easily tensorized for computational efficiency in small state spaces [see Mazoure et al., 2020].

For this paper, we use the C51 algorithm [Bellemare et al., 2017] for training the agent due to its simplicity and good performance on control tasks from pixels. C51 minimizes the following loss:

$$\mathcal{L}_{\mathrm{C51}} \;=\; \mathbb{E}_{(s, a, r, s')}\Big[ D_{\mathrm{KL}}\big( \Pi\, \mathcal{T} Z_{\theta'}(s, a) \,\big\|\, Z_\theta(s, a) \big) \Big] \qquad (2)$$
where $D_{\mathrm{KL}}$ is the Kullback-Leibler divergence, $Z_\theta(s, a)$ is the distribution of discounted returns (s.t. $Q(s, a) = \mathbb{E}[Z_\theta(s, a)]$), $\mathcal{T}$ is the distributional Bellman operator [Bellemare et al., 2019] and $\Pi$ is an operator which projects onto a fixed support of 51 atoms.
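Returning to the policy-induced chain $P_\pi$ defined above: in the tabular case, the conversion is a single tensor contraction. The sketch below is our illustration (not the paper's code), assuming numpy and a kernel stored as `P[s, a, s']`.

```python
import numpy as np

def mdp_to_mc(P_sas, pi_sa):
    """Collapse an MDP kernel P[s, a, s'] under a policy pi[s, a]
    into a Markov chain transition matrix P_pi[s, s']."""
    # P_pi(s' | s) = sum_a pi(a | s) * P(s' | s, a)
    return np.einsum('sa,sap->sp', pi_sa, P_sas)

# Tiny example: 2 states, 2 actions, uniform policy.
P_sas = np.array([[[1.0, 0.0], [0.0, 1.0]],
                  [[0.5, 0.5], [1.0, 0.0]]])
pi_sa = np.full((2, 2), 0.5)
print(mdp_to_mc(P_sas, pi_sa))   # rows sum to 1
```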
3.3 State-action mutual information maximization

Mutual information (MI) measures the amount of information shared between a pair of random variables and can be estimated using neural networks [Belghazi et al., 2018]. Recent representation learning algorithms [Oord et al., 2018, Hjelm et al., 2018, Hénaff et al., 2019, Tian et al., 2019, He et al., 2019] train encoders to maximize MI between features taken from different views of the input – e.g., different patches in an image or different versions of an image produced by applying data augmentation. These algorithms commonly optimize lower bounds on the MI based on noise-contrastive estimation [NCE, Gutmann and Hyvärinen, 2010, Oord et al., 2018].
Let $k \ge 1$ be some fixed temporal offset. Running a policy $\pi$ in the MDP generates a distribution over tuples $(s_t, a_t, s_{t+k})$, where $s_t$ corresponds to the state of the MDP at some timestep $t$, $a_t$ to the action selected by $\pi$ in state $s_t$, and $s_{t+k}$ to the state reached at timestep $t+k$ by following $\pi$. We let $S_t$, $A_t$ and $S_{t+k}$ denote the corresponding random variables, and denote their joint distribution and associated marginals by $p$. We consider maximizing the MI between state-action pairs $(S_t, A_t)$ and their future states $S_{t+k}$, which can be written as follows:

$$\mathcal{I}(S_t, A_t; S_{t+k}) \;=\; \mathbb{E}_{p(s_t, a_t, s_{t+k})}\left[\log \frac{p(s_t, a_t, s_{t+k})}{p(s_t, a_t)\, p(s_{t+k})}\right] \qquad (3)$$

Estimating the MI amounts to training a classifier that discriminates between a sample $(s_t, a_t, s_{t+k})$ drawn from the joint distribution and a sample from the product of marginals in the denominator of Eq. 3. A sample from the product of marginals can be obtained by replacing $s_{t+k}$ (we call this the positive sample) with a state picked at random from another trajectory (we call this a negative sample). Letting $\tilde{S}$ denote a set of such negative samples, the infoNCE loss function [Oord et al., 2018] that we use to maximize a lower bound on the MI in Eq. 3 takes the following form:

$$\mathcal{L}_{\mathrm{NCE}} \;=\; -\,\mathbb{E}\left[\log \frac{\exp\!\big(\phi(\psi_1(s_t, a_t),\, \psi_2(s_{t+k}))\big)}{\sum_{s' \in \tilde{S} \cup \{s_{t+k}\}} \exp\!\big(\phi(\psi_1(s_t, a_t),\, \psi_2(s'))\big)}\right] \qquad (4)$$

where $\psi_1$ and $\psi_2$ produce features that depend on state-action pairs and on states, respectively, and $\phi$ is a function that outputs a scalar-valued score. Minimizing $\mathcal{L}_{\mathrm{NCE}}$ with respect to $\psi_1$, $\psi_2$ and $\phi$ maximizes the MI between these features. In practice, we construct $\tilde{S}$ by including all states $s_{t+k}$ from the other tuples in the same minibatch, i.e., for a batch of $B$ tuples $(s_t, a_t, s_{t+k})$, each $\tilde{S}$ contains $B - 1$ negative samples.
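For concreteness, a minimal PyTorch sketch of such an in-batch infoNCE estimator is shown below; the dot-product score and the feature dimensions are illustrative assumptions, not the exact networks used in the paper.

```python
import torch
import torch.nn.functional as F

def info_nce(z_sa, z_next, temperature=1.0):
    """
    z_sa:   B x d  features of (s_t, a_t) pairs  (psi_1 in Eq. 4)
    z_next: B x d  features of s_{t+k}           (psi_2 in Eq. 4)
    Every other element of the batch acts as a negative sample.
    """
    scores = z_sa @ z_next.t() / temperature          # B x B score matrix (phi = dot product)
    labels = torch.arange(z_sa.shape[0], device=z_sa.device)
    # Row i: the positive is the matching future state i, negatives are the B-1 others.
    return F.cross_entropy(scores, labels)

# Usage with random placeholder features:
loss = info_nce(torch.randn(32, 128), torch.randn(32, 128))
```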
4 Architecture and Algorithm
We now specify forms for the functions $\psi_1$, $\psi_2$ and $\phi$. We consider a deep neural network $f$ which maps input states onto a sequence of progressively more “global” (or less “local”) feature spaces. In practice, $f$ is a deep residual CNN composed of functions $f_1, \ldots, f_L$ that sequentially map inputs to features (lower to upper “levels” of the network).
The features of the last layer correspond to the standard C51 value heads (i.e., they span a space of 51 atoms per action); $f$ can equivalently be seen as the network used in a standard C51. For the auxiliary objective, we follow a variant of Deep InfoMax [DIM, Hjelm et al., 2018, Anand et al., 2019, Bachman et al., 2019], and train the encoder to maximize the mutual information (MI) between local and global “views” of tuples $(s_t, a_t, s_{t+k})$. The local and global views are realized by selecting the features of an intermediate convolutional layer of $f$ (local) and of a top layer of $f$ (global), respectively. In order to simultaneously estimate and maximize the MI, we embed the action (represented as a one-hot vector) using a function $e$. We then map the local features of $s_t$ together with the embedded action using a function $g_{\mathrm{loc}}$, and do the same with the global features using $g_{\mathrm{glob}}$. In addition, we have two more functions, $h_{\mathrm{loc}}$ and $h_{\mathrm{glob}}$, that map features without the actions and are applied to features from “future” timesteps. Note that the local feature space can be thought of as a product of local spaces (corresponding to different patches in the input, or equivalently different receptive fields), each with the same dimensionality as the global space. We use the outputs of these functions to produce a scalar-valued score between any combination of local and global representations of states $s_t$ and $s_{t+k}$, conditioned on action $a_t$:

$$\phi_{m,n}(s_t, a_t, s_{t+k}) \;=\; g_m\big(f_m(s_t), e(a_t)\big)^{\top} h_n\big(f_n(s_{t+k})\big), \qquad m, n \in \{\mathrm{loc}, \mathrm{glob}\} \qquad (5)$$

i.e., $\psi_1$ is realized by $g_m(f_m(\cdot), e(\cdot))$, $\psi_2$ by $h_n(f_n(\cdot))$, and $\phi$ by a dot product. In practice, for the functions that take features and actions as input, we simply concatenate the local feature at a given spatial position (local) or the global feature vector (global) with the embedded action $e(a_t)$, and feed the resulting tensor into the appropriate function $g_{\mathrm{loc}}$ or $g_{\mathrm{glob}}$. All functions that process global and local features are computed using convolutions. See Figure 1 for a visual representation of our model.
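The sketch below illustrates one way such an action-conditioned head can be realized (a simplified stand-in for the global branch of Figure 1; the layer sizes and the dot-product score are our assumptions, not the exact architecture of the paper).

```python
import torch
import torch.nn as nn

class ActionConditionedHead(nn.Module):
    """Maps a global state feature concatenated with an embedded action
    into the space in which scores are taken (a g_glob-style head for Eq. 5)."""
    def __init__(self, feat_dim, n_actions, emb_dim=64, out_dim=128):
        super().__init__()
        self.embed_action = nn.Sequential(nn.Linear(n_actions, emb_dim), nn.ReLU())
        self.head = nn.Sequential(nn.Linear(feat_dim + emb_dim, 512), nn.ReLU(),
                                  nn.Linear(512, out_dim))

    def forward(self, state_feat, action_onehot):
        e_a = self.embed_action(action_onehot)         # e(a_t)
        return self.head(torch.cat([state_feat, e_a], dim=-1))

# Eq. 5 (global-global variant) then reduces to a dot product:
# score = (g_glob(f_glob(s_t), a_t) * h_glob(f_glob(s_tk))).sum(-1)
```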
We use the scores from Eq. 5 when computing the infoNCE loss [Oord et al., 2018] for our objective, using tuples $(s_t, a_t, s_{t+k})$ sampled from trajectories stored in an experience replay buffer:

$$\mathcal{L}_{\mathrm{NCE}}^{m,n} \;=\; -\,\mathbb{E}\left[\log \frac{\exp\!\big(\phi_{m,n}(s_t, a_t, s_{t+k})\big)}{\sum_{s' \in \tilde{S} \cup \{s_{t+k}\}} \exp\!\big(\phi_{m,n}(s_t, a_t, s')\big)}\right] \qquad (6)$$

Combining Eq. 6 with the C51 update in Eq. 2 yields our full training objective, which we call DRIML (Deep Reinforcement and InfoMax Learning). We optimize the encoder $f$, the action embedding $e$ and the score heads $g$, $h$ jointly using a single loss function:

$$\mathcal{L}_{\mathrm{DRIML}} \;=\; \mathcal{L}_{\mathrm{C51}} \;+\; \sum_{m, n \in \{\mathrm{loc}, \mathrm{glob}\}} \alpha_{m,n}\, \mathcal{L}_{\mathrm{NCE}}^{m,n} \qquad (7)$$

where the coefficients $\alpha_{m,n} \ge 0$ select which local/global terms are active (see App. 8.5).
Note that, in practice, the compute cost which Eq. 7 adds to the core C51 RL algorithm is minimal, since it only requires additional passes through the (small) state/action embedding functions followed by an outer product.
The proposed Algorithm 1 introduces an auxiliary loss which improves the predictive capabilities of value-based agents by boosting the similarity of representations that are close in time.
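Schematically, one training step combines the two losses as in the sketch below (our illustration under the notation of Eq. 7; `agent.c51_loss`, `agent.nce_loss` and the batch fields are placeholders, not the paper's exact interface).

```python
import torch

def driml_update(batch, agent, optimizer, alpha=1.0):
    """One DRIML gradient step: C51 loss plus the weighted auxiliary NCE term (Eq. 7).
    `agent.c51_loss` and `agent.nce_loss` stand in for Eqs. 2 and 6."""
    c51 = agent.c51_loss(batch)                                 # distributional RL term
    nce = agent.nce_loss(batch.s_t, batch.a_t, batch.s_tk)      # contrastive term, positives k steps ahead
    loss = c51 + alpha * nce
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```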
5 Predictability and Contrastive Learning
Information maximization has long been considered one of the standard principles for measuring correlation and performing feature selection [Song et al., 2012]. In the MDP context, high values of $\mathcal{I}(S_t, A_t; S_{t+k})$ indicate that $(S_t, A_t)$ and $S_{t+k}$ have some form of dependence, while low values suggest independence. The fact that predictability (or more precisely determinism) in Markov systems is linked to the MI suggests a deeper connection to the spectrum of the transition kernel $P_\pi$. For instance, the set of eigenvalues of $P_\pi$ for a Markov decision process contains important information about the connectivity of said process, such as mixing time or number of densely connected clusters [Von Luxburg, 2007, Levin and Peres, 2017].

Consider the setting in which $\pi$ is fixed at some iteration in the optimization process. In the rest of this section, we let $P_\pi$ denote the expected transition model (it is a Markov chain). We let $\Phi_t(s, a, s') = \mathcal{P}(s' \mid s, a) / \rho_t(s')$ be the ratio targeted when optimizing the infoNCE loss on samples drawn from the random variables $(S_t, A_t)$ and $S_{t+k}$ (for a fixed $k$) [Oord et al., 2018]. We also let $\Phi_\infty$ be that ratio when the Markov chain has reached its stationary distribution $\rho_\pi$ (see Section 3.1), and $\hat{\Phi}_t$ be the scoring function learnt using InfoNCE (which converges to $\Phi_t$ in the limit of infinite samples drawn from $(S_t, A_t, S_{t+k})$).
Proposition 1
Let $\epsilon > 0$. Assume that at time step $t$, training of $\hat{\Phi}$ has close to converged on a pair $(s, a)$, i.e. $|\hat{\Phi}_t(s, a, s') - \Phi_t(s, a, s')| \le \epsilon$ for all $s'$. Then the following holds:

$$\big|\hat{\Phi}_t(s, a, s') - \Phi_\infty(s, a, s')\big| \;\le\; \epsilon + \mathcal{P}(s' \mid s, a)\, \frac{|\rho_t(s') - \rho_\pi(s')|}{\rho_t(s')\, \rho_\pi(s')} \qquad (8)$$
The proof can be found in App. 8.4. Proposition 1, combined with Theorem 3.1, suggests that faster convergence of $\hat{\Phi}_t$ to $\Phi_\infty$ happens when the spectral gap of $P_\pi$ is large, or equivalently when $\lambda^\star$ is small. It follows that, on one hand, mutual information is a natural measure of concordance of $(s_t, a_t, s_{t+k})$ tuples and can be maximized using data-efficient, batched gradient methods. On the other hand, the rate at which the InfoNCE loss converges to its stationary value (i.e. maximizes the lower bound on the MI) depends on the spectral gap of $P_\pi$, which is closely linked to predictability (see Appendix 8.2).
6 Experiments
In this section, we first show how our proposed objective can be used to estimate state similarity in single Markov chains. We then show that DRIML can capture dynamics in locally deterministic systems (Ising model), which is useful in domains with partially deterministic transitions. We then provide results on a continual version of the Ms.PacMan game where the DIM loss is shown to converge faster for more deterministic tasks, and to help in a continual learning setting. Finally, we provide results on Procgen [Cobbe et al., 2019], which show that DRIML performs well when trained on 500 fixed levels. All experimental details can be found in App. 8.5.
6.1 DIM learns a transition ratio model
We first study the behaviour of contrastive losses on a simple Markov chain describing a biased random walk on $\{0, 1, \ldots, N\}$. The bias is specified by a single parameter $p$. The agent at state $i$ transitions to $i+1$ with probability $p$ and to $i-1$ otherwise; it stays in the boundary states $0$ and $N$ with probability $1-p$ and $p$, respectively. We encode the current and next states (represented as one-hot vectors) using a 1-hidden-layer MLP, corresponding to $\psi_1$ and $\psi_2$ in Eq. 4 (the action is simply ignored in this setting), and then optimize the NCE loss (Eq. 4) to maximize the MI between representations of successive states; the scoring function $\phi$ is also a 1-hidden-layer MLP. Results are shown in Fig. 2b; they align well with the true transition matrix (Fig. 2c).
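The chain and the ratio targeted by the contrastive loss can also be written down explicitly. The sketch below is our illustration (numpy; `N` and `p` are arbitrary example values), building the transition matrix and the ratio $P(s' \mid s)/\rho(s')$ that Fig. 2b is compared against.

```python
import numpy as np

def biased_walk(N=20, p=0.7):
    """Transition matrix of the biased random walk on {0, ..., N}."""
    P = np.zeros((N + 1, N + 1))
    for i in range(N + 1):
        if i < N:
            P[i, i + 1] = p          # move up with probability p
        else:
            P[i, i] += p             # stay at the upper boundary
        if i > 0:
            P[i, i - 1] = 1 - p      # move down otherwise
        else:
            P[i, i] += 1 - p         # stay at the lower boundary
    return P

P = biased_walk()
rho = np.linalg.matrix_power(P, 10_000)[0]   # approximate stationary distribution
true_ratio = P / rho[None, :]                # P(s'|s) / rho(s')
```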
6.2 DIM can capture complex partially deterministic dynamics
The goal of this experiment is to highlight the predictive capabilities of our DIM objective in a partially deterministic system. We consider a dynamical system composed of pixels with values in $\{-1, +1\}$. At the beginning of each episode, a patch covering a quarter of the pixels is chosen at random in the grid. Pixels that do not belong to that patch evolve fully independently of each other and of the past. Pixels from the patch obey a local dependence law, in the form of a standard Ising model (https://en.wikipedia.org/wiki/Ising_model): the value of a pixel at time $t+1$ only depends on the values of its neighbours at time $t$. This local dependence is obtained through a local update function (see Appx. 8.6.1 for details). Figure 3 shows the system at a fixed early timestep during three different episodes (black pixels correspond to one value, white to the other). The patches are very distinct from the noise. We then train a convolutional encoder using our DIM objective on local “views” only (see Section 4).
Figure 3 shows the similarity scores between the local features of states at one timestep and the same features at a later timestep (a local feature corresponds to a specific location in the convolutional maps). We chose early timesteps to make sure that the model does not simply detect large patches, but truly measures predictability. The heatmap regions containing the Ising model (larger-scale patterns) have higher scores than the noisy portions of the lattice. Local DIM is able to correctly encode regions of high temporal predictability.
6.3 A Continual Learning experiment on Ms.PacMan
We further complicate the prediction task from the Ising model by building on top of the Ms. PacMan game and introducing non-trivial dynamics. The environment is shown in Figure 4a.
In order to assess how well our auxiliary objective captures predictability in this MDP, we define its dynamics such that each ghost takes a uniformly random action with probability $p$ instead of following its scripted behaviour. Intuitively, as $p \to 1$, the enemies’ actions become less predictable, which in turn hinders the convergence rate of the contrastive loss. The four runs in Figure 4b correspond to various values of $p$. We trained the agent using our objective. We can see that the convergence of the auxiliary loss becomes slower with growing $p$, as the model struggles to predict $s_{t+k}$ given $(s_t, a_t)$. After 100k frames, the NCE objective allows us to separate the four MDPs according to their principal source of randomness (red and blue curves). When $p$ is close to 1, the auxiliary loss has a harder time finding features that predict the next state, and eventually ignores the random movements of enemies.
The second and more interesting setup we consider consists of making only one out of the four enemies lethal, and changing which one every 5k episodes. Figure 4c shows that, as training progresses, C51 always reaches roughly the same performance at the end of each 5k-episode block, while DRIML’s performance steadily increases. C51 learns to ignore the harmless ghosts (they have no connection to the reward signal) and has to learn the task from scratch every time the lethal ghost changes. On the other hand, DRIML is incentivized to encode information about all the predictable objects on the screen (including the harmless ghosts), and as such adapts faster and faster to changes.
6.4 Performance on Procgen Benchmark
Finally, we demonstrate the beneficial impact of adding a DIM-like objective to C51 (DRIML) on the 16 Procgen tasks [Cobbe et al., 2019]. All algorithms are trained for 50M environment frames with the DQN [Mnih et al., 2015]
architecture. The mean and standard deviation of the scores (over 3 seeds) are shown in Table 1; bold values indicate best performance.

| Game | DRIML (Ours) | C51 | CURL | No Action | Random |
|---|---|---|---|---|---|
| bigfish | 2.023 ± 0.18 | 1.33 ± 0.12 | 2.697 ± 1.3 | 1.192 ± 0.04 | 0.333 ± 0.47 |
| bossfight | 0.672 ± 0.02 | 0.573 ± 0.05 | 0.595 ± 0.06 | 0.472 ± 0.01 | 0 |
| caveflyer | 10.18 ± 0.41 | 9.187 ± 0.29 | 6.938 ± 0.25 | 8.259 ± 0.26 | 0 |
| chaser | 0.286 ± 0.02 | 0.222 ± 0.04 | 0.353 ± 0.04 | 0.229 ± 0.02 | 0.04 ± 0 |
| climber | 2.256 ± 0.05 | 1.678 ± 0.1 | 1.751 ± 0.09 | 1.574 ± 0.01 | 0 |
| coinrun | 27.238 ± 1.92 | 29.701 ± 5.44 | 21.166 ± 1.94 | 13.146 ± 1.21 | 0 |
| dodgeball | 1.28 ± 0.02 | 1.198 ± 0.08 | 1.093 ± 0.04 | 1.221 ± 0.04 | 0.667 ± 0.94 |
| fruitbot | 5.399 ± 1.02 | 3.856 ± 0.96 | 4.887 ± 0.71 | 5.425 ± 1.33 | 0.333 ± 0.47 |
| heist | 1.296 ± 0.05 | 1.537 ± 0.1 | 1.056 ± 0.05 | 1.042 ± 0.02 | 0 |
| jumper | 12.639 ± 0.64 | 13.225 ± 0.83 | 10.273 ± 0.61 | 4.314 ± 0.64 | 0 |
| leaper | 6.168 ± 0.29 | 5.034 ± 0.14 | 3.943 ± 0.46 | 5.403 ± 0.09 | 0 |
| maze | 1.381 ± 0.08 | 2.355 ± 0.09 | 0.823 ± 0.2 | 1.438 ± 0.26 | 0 |
| miner | 0.14 ± 0.01 | 0.126 ± 0.01 | 0.096 ± 0 | 0.116 ± 0.01 | 0 |
| ninja | 9.209 ± 0.25 | 9.36 ± 0.01 | 5.839 ± 1.21 | 6.438 ± 0.22 | 0 |
| plunder | 3.366 ± 0.17 | 2.994 ± 0.07 | 2.771 ± 0.14 | 3.2 ± 0.05 | 0.667 ± 0.47 |
| starpilot | 4.562 ± 0.21 | 2.445 ± 0.12 | 2.683 ± 0.09 | 3.699 ± 0.3 | 0.333 ± 0.47 |
Similarly to CURL, we used data augmentation on inputs to DRIML to improve the model’s predictive capabilities in fast-paced environments (see App. 8.6.3). While we used the global-global loss in DRIML’s objective, we have found that the local-local loss also had a beneficial effect on performance on a smaller set of games (e.g. starpilot, which has few moving entities on a dark background).
7 Discussion
In this paper, we introduced an auxiliary objective called Deep Reinforcement and InfoMax Learning (DRIML), which is based on maximizing concordance of state-action pairs with future states (at the representation level). Our objective has a close connection with the spectrum of the Markov transition matrix and predictability of Markovian systems, which dictate its long-term behaviour. We presented results showing that 1) DRIML implicitly learns a transition model by boosting state similarity, 2) it improves C51 in a continual learning setting and 3) it boosts training performance in complex domains such as Procgen.
Acknowledgements
We thank Harm van Seijen, Ankesh Anand, Mehdi Fatemi, Romain Laroche and Jayakumar Subramanian for useful feedback and helpful discussions.
Broader Impact
This work proposes an auxiliary objective for model-free reinforcement learning agents. The objective shows improvements in a continual learning setting, as well as on average training rewards for a suite of complex video games. While the objective is developed in a visual setting, maximizing mutual information between features is a method that can be transported to other domains, such as text. Potential applications of deep reinforcement learning are (among others) healthcare, dialog systems, crop management, robotics, etc. Developing methods that are more robust to changes in the environment, and/or perform better in a continual learning setting can lead to improvements in those various applications. At the same time, our method fundamentally relies on deep learning tools and architectures, which are hard to interpret and prone to failures yet to be perfectly understood. Additionally, deep reinforcement learning also lacks formal performance guarantees, and so do deep reinforcement learning agents. Overall, it is essential to design failsafes when deploying such agents (including ours) in the real world.
References
- Sutton and Barto [1998] R. S. Sutton and A. G. Barto. Reinforcement Learning: An Introduction. MIT Press, Cambridge, MA, 1998.
- Ha and Schmidhuber [2018] David Ha and Jürgen Schmidhuber. World models. 2018. doi: 10.5281/zenodo.1207631. URL http://arxiv.org/abs/1803.10122. cite arxiv:1803.10122.
- Hafner et al. [2019a] Danijar Hafner, Timothy Lillicrap, Jimmy Ba, and Mohammad Norouzi. Dream to control: Learning behaviors by latent imagination. arXiv preprint arXiv:1912.01603, 2019a.
- Goodfellow et al. [2017] Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep learning. 2017. ISBN 9780262035613 0262035618. URL https://www.worldcat.org/title/deep-learning/oclc/985397543&referer=brief_results.
- Mnih et al. [2013] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, and Martin Riedmiller. Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602, 2013.
- Schulman et al. [2017] John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. Proximal policy optimization algorithms. CoRR, abs/1707.06347, 2017. URL http://dblp.uni-trier.de/db/journals/corr/corr1707.html#SchulmanWDRK17.
- Pong et al. [2018] Vitchyr Pong, Shixiang Gu, Murtaza Dalal, and Sergey Levine. Temporal difference models: Model-free deep rl for model-based control. arXiv preprint arXiv:1802.09081, 2018.
- Farebrother et al. [2018] Jesse Farebrother, Marlos C. Machado, and Michael Bowling. Generalization and regularization in dqn, 2018.
- Tachet des Combes et al. [2018] Remi Tachet des Combes, Philip Bachman, and Harm van Seijen. Learning invariances for policy generalization. CoRR, abs/1809.02591, 2018. URL http://dblp.uni-trier.de/db/journals/corr/corr1809.html#abs-1809-02591.
- Hjelm et al. [2018] R Devon Hjelm, Alex Fedorov, Samuel Lavoie-Marchildon, Karan Grewal, Phil Bachman, Adam Trischler, and Yoshua Bengio. Learning deep representations by mutual information estimation and maximization. arXiv preprint arXiv:1808.06670, 2018.
- Bachman et al. [2019] Philip Bachman, R Devon Hjelm, and William Buchwalter. Learning representations by maximizing mutual information across views. In Advances in Neural Information Processing Systems, pages 15509–15519, 2019.
- Anand et al. [2019] Ankesh Anand, Evan Racah, Sherjil Ozair, Yoshua Bengio, Marc-Alexandre Côté, and R Devon Hjelm. Unsupervised state representation learning in atari. In Advances in Neural Information Processing Systems, pages 8766–8779, 2019.
- Hyvarinen and Morioka [2016] Aapo Hyvarinen and Hiroshi Morioka. Unsupervised feature extraction by time-contrastive learning and nonlinear ica. In Advances in Neural Information Processing Systems, pages 3765–3773, 2016.
- Bellemare et al. [2013] Marc G Bellemare, Yavar Naddaf, Joel Veness, and Michael Bowling. The arcade learning environment: An evaluation platform for general agents. Journal of Artificial Intelligence Research, 47:253–279, 2013.
- Cobbe et al. [2019] Karl Cobbe, Christopher Hesse, Jacob Hilton, and John Schulman. Leveraging procedural generation to benchmark reinforcement learning, 2019.
- Bellemare et al. [2017] Marc G Bellemare, Will Dabney, and Rémi Munos. A distributional perspective on reinforcement learning. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 449–458. JMLR.org, 2017.
- Wixted [2004] John T Wixted. The psychology and neuroscience of forgetting. Annu. Rev. Psychol., 55:235–269, 2004.
- Atkinson et al. [2018] Craig Atkinson, Brendan McCane, Lech Szymanski, and Anthony Robins. Pseudo-rehearsal: Achieving deep reinforcement learning without catastrophic forgetting. arXiv preprint arXiv:1812.02464, 2018.
- Kaplanis et al. [2018] Christos Kaplanis, Murray Shanahan, and Claudia Clopath. Continual reinforcement learning with complex synapses. arXiv preprint arXiv:1802.07239, 2018.
- Mankowitz et al. [2018] Daniel J Mankowitz, Augustin Žídek, André Barreto, Dan Horgan, Matteo Hessel, John Quan, Junhyuk Oh, Hado van Hasselt, David Silver, and Tom Schaul. Unicorn: Continual learning with a universal, off-policy agent. arXiv preprint arXiv:1802.08294, 2018.
- Thrun and Pratt [1998] Sebastian Thrun and Lorien Pratt. Learning to learn: Introduction and overview. In Learning to learn, pages 3–17. Springer, 1998.
- Finn et al. [2017] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 1126–1135. JMLR. org, 2017.
- Hessel et al. [2019] Matteo Hessel, Hubert Soyer, Lasse Espeholt, Wojciech Czarnecki, Simon Schmitt, and Hado van Hasselt. Multi-task deep reinforcement learning with popart. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pages 3796–3803, 2019.
- D’Eramo et al. [2019] Carlo D’Eramo, Davide Tateo, Andrea Bonarini, Marcello Restelli, and Jan Peters. Sharing knowledge in multi-task deep reinforcement learning. In International Conference on Learning Representations, 2019.
- Jaderberg et al. [2016] Max Jaderberg, Volodymyr Mnih, Wojciech Marian Czarnecki, Tom Schaul, Joel Z Leibo, David Silver, and Koray Kavukcuoglu. Reinforcement learning with unsupervised auxiliary tasks. arXiv preprint arXiv:1611.05397, 2016.
- Mohamed and Rezende [2015] Shakir Mohamed and Danilo Jimenez Rezende. Variational information maximisation for intrinsically motivated reinforcement learning. In Advances in neural information processing systems, pages 2125–2133, 2015.
- Gelada et al. [2019] Carles Gelada, Saurabh Kumar, Jacob Buckman, Ofir Nachum, and Marc G Bellemare. Deepmdp: Learning continuous latent space models for representation learning. In International Conference on Machine Learning, pages 2170–2179, 2019.
- Hafner et al. [2019b] Danijar Hafner, Timothy Lillicrap, Ian Fischer, Ruben Villegas, David Ha, Honglak Lee, and James Davidson. Learning latent dynamics for planning from pixels. In International Conference on Machine Learning, pages 2555–2565, 2019b.
- Schrittwieser et al. [2019] Julian Schrittwieser, Ioannis Antonoglou, Thomas Hubert, Karen Simonyan, Laurent Sifre, Simon Schmitt, Arthur Guez, Edward Lockhart, Demis Hassabis, Thore Graepel, Timothy Lillicrap, and David Silver. Mastering atari, go, chess and shogi by planning with a learned model, 2019.
- Oord et al. [2018] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
- Hénaff et al. [2019] Olivier J Hénaff, Ali Razavi, Carl Doersch, SM Eslami, and Aaron van den Oord. Data-efficient image recognition with contrastive predictive coding. arXiv preprint arXiv:1905.09272, 2019.
- Wu et al. [2018] Zhirong Wu, Yuanjun Xiong, X Yu Stella, and Dahua Lin. Unsupervised feature learning via non-parametric instance discrimination. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 3733–3742. IEEE, 2018.
- He et al. [2019] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722, 2019.
- Tian et al. [2019] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive multiview coding. arXiv preprint arXiv:1906.05849, 2019.
- Chen et al. [2020] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709, 2020.
- Belghazi et al. [2018] Mohamed Ishmael Belghazi, Aristide Baratin, Sai Rajeswar, Sherjil Ozair, Yoshua Bengio, Aaron Courville, and R Devon Hjelm. Mine: mutual information neural estimation. arXiv preprint arXiv:1801.04062, 2018.
- Poole et al. [2019] Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A Alemi, and George Tucker. On variational bounds of mutual information. arXiv preprint arXiv:1905.06922, 2019.
- Beattie et al. [2016] Charles Beattie, Joel Z. Leibo, Denis Teplyashin, Tom Ward, Marcus Wainwright, Heinrich Küttler, Andrew Lefrancq, Simon Green, Víctor Valdés, Amir Sadik, Julian Schrittwieser, Keith Anderson, Sarah York, Max Cant, Adam Cain, Adrian Bolton, Stephen Gaffney, Helen King, Demis Hassabis, Shane Legg, and Stig Petersen. Deepmind lab, 2016.
- Srinivas et al. [2020] Aravind Srinivas, Michael Laskin, and Pieter Abbeel. Curl: Contrastive unsupervised representations for reinforcement learning. arXiv preprint arXiv:2004.04136, 2020.
- Kim et al. [2019] Hyoungseok Kim, Jaekyeom Kim, Yeonwoo Jeong, Sergey Levine, and Hyun Oh Song. Emi: Exploration with mutual information. In International Conference on Machine Learning, pages 3360–3369, 2019.
- Levin and Peres [2017] David A Levin and Yuval Peres. Markov chains and mixing times, volume 107. American Mathematical Soc., 2017.
- Bellman [1957] Richard Bellman. A markovian decision process. Journal of Mathematics and Mechanics, pages 679–684, 1957.
- Puterman [2014] Martin L Puterman. Markov decision processes: discrete stochastic dynamic programming. John Wiley & Sons, 2014.
- Mazoure et al. [2020] Bogdan Mazoure, Thang Doan, Tianyu Li, Vladimir Makarenkov, Joelle Pineau, Doina Precup, and Guillaume Rabusseau. Provably efficient reconstruction of policy networks. arXiv preprint arXiv:2002.02863, 2020.
- Bellemare et al. [2019] Marc G Bellemare, Nicolas Le Roux, Pablo Samuel Castro, and Subhodeep Moitra. Distributional reinforcement learning with linear function approximation. arXiv preprint arXiv:1902.03149, 2019.
- Gutmann and Hyvärinen [2010] Michael Gutmann and Aapo Hyvärinen. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pages 297–304, 2010.
- Song et al. [2012] Lin Song, Peter Langfelder, and Steve Horvath. Comparison of co-expression measures: mutual information, correlation, and model based indices. BMC bioinformatics, 13(1):328, 2012.
- Von Luxburg [2007] Ulrike Von Luxburg. A tutorial on spectral clustering. Statistics and computing, 17(4):395–416, 2007.
- Mnih et al. [2015] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A Rusu, Joel Veness, Marc G Bellemare, Alex Graves, Martin Riedmiller, Andreas K Fidjeland, Georg Ostrovski, et al. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015.
- Wen et al. [2020] Junfeng Wen, Bo Dai, Lihong Li, and Dale Schuurmans. Batch stationary distribution estimation. arXiv preprint arXiv:2003.00722, 2020.
8 Appendix
8.1 Link to invariant distribution
For a discrete-state ergodic Markov chain specified by a transition matrix $P$ and an initial occupancy vector $\rho_0$, its marginal state distribution at time $t$ is given by the Chapman-Kolmogorov form:

$$\rho_t^\top = \rho_0^\top P^t \qquad (9)$$

and its limiting distribution is the infinite-time marginal

$$\rho_\infty^\top = \lim_{t \to \infty} \rho_0^\top P^t \qquad (10)$$

which, if it exists, is exactly equal to the invariant distribution $\rho$.
For the very restricted family of ergodic MDPs under a fixed policy $\pi$, we can assume that $\rho_t$ converges to a time-invariant distribution $\rho_\pi$. Therefore,

$$\lim_{t \to \infty} \rho_0^\top P_\pi^t = \rho_\pi^\top \qquad (11)$$
Now, observe that $\Phi_t$ is closely linked to $\Phi_\infty$ when samples come from timesteps close to the mixing time. That is, swapping $\rho_t$ and $\rho_\pi$ at any state would yield at most $\max_s |\rho_t(s) - \rho_\pi(s)|$ error. Moreover, existing results [Levin and Peres, 2017] from Markov chain theory provide bounds on this quantity depending on the structure of the transition matrix.
If $P_\pi$ has a limiting distribution, then the dominated convergence theorem allows us to replace the matrix powers $P_\pi^t$ by their limit, which is in turn replaced by the invariant distribution $\rho_\pi$:

$$\lim_{t \to \infty} \Phi_t(s, a, s') = \frac{\mathcal{P}(s' \mid s, a)}{\rho_\pi(s')} = \Phi_\infty(s, a, s') \qquad (12)$$
Of course, most real-life Markov decision processes do not actually have an invariant distribution, since they have absorbing (or terminal) states. In this case, as the agent interacts with the environment, the DIM estimate of the MI yields a rate of convergence which can be estimated based on the spectrum of $P_\pi$.
Moreover, one could argue that since, in practice, we use off-policy algorithms for this sort of task, the gradient signal comes from various timesteps within the experience replay, which drives the model to learn features that are consistently predictive through time.
8.2 Predictability and spectrum of an MDP
In the simplest case of a discrete MDP, the transition kernel $\mathcal{P}$ fully describes the dynamics of the environment. The row $\mathcal{P}(\cdot \mid s, a)$ corresponds to the probability of landing in any state by taking action $a$ from state $s$. Since a perfectly deterministic environment would transition to exactly one state with probability 1, an agent equipped with $\mathcal{P}$ will know exactly what $s_{t+1}$ will look like and adjust its planning accordingly. Simply speaking, in a predictable Markov environment, the values of the random variable $S_{t+1}$ should be maximally correlated with those of $(S_t, A_t)$.
As seen in the main paper, the dependence of $S_{t+k}$ on $(S_t, A_t)$ can be explicitly encouraged by boosting the mutual information of the coupling. How effective the model is at solving this optimization problem depends mostly on the topology of the system: estimating the AMI for a time-homogeneous, ergodic MDP turns out to be quite easy, as shown in the experiments section. The rate of convergence of the InfoNCE bound to the true AMI depends mostly on the inverse spectral gap of the transition kernel. More generally, if we think of $P_\pi$ as inducing a transition graph $G = (V, E)$ with the following rule:

$$V = \mathcal{S}, \qquad (s, s') \in E \iff P_\pi(s' \mid s) > 0, \qquad (13)$$
then the spectral gap is related to the Cheeger constant of $G$. For instance, if $P_\pi$ is allowed to transition from any state to any other state, the corresponding graph would be fully connected, while having deterministic transitions would make the graph a maximal directed pseudoforest. In summary, if one has access to the spectral gap of an MC, they can roughly estimate the connectedness of the system: the sparser the graph, the easier the search problem.
The spectrum of $P_\pi$ is trivial to estimate in the tabular case; for the large-dimensional case, one can use the variational power method [VPM, Wen et al., 2020] together with Hotelling's deflation:

- Run VPM on the collected dataset to estimate the principal eigenfunction; the corresponding eigenvalue should be close to 1.
- Use rejection sampling on the same dataset to form a new dataset in which the contribution of the estimated principal eigenfunction is removed (deflation).
- Run VPM on the new dataset to find the next eigenvalue, and repeat for smaller eigenvalues.
In our experiments, we managed to exactly estimate the spectral gap for tabular Markov chains, a task that is more complicated for large-dimensional inputs (e.g. pixels) since the transition model is typically parameterized by a neural network.
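In the tabular case, a rough sketch of power iteration with Hotelling/Wielandt deflation is given below (our illustration, not code from the paper; it assumes the leading eigenvalues are real and well separated).

```python
import numpy as np

def power_iteration(M, iters=2000):
    """Dominant eigenvalue (by modulus) and right eigenvector of M."""
    v = np.random.rand(M.shape[0])
    for _ in range(iters):
        v = M @ v
        v /= np.linalg.norm(v)
    lam = v @ M @ v / (v @ v)        # Rayleigh quotient, exact once v is an eigenvector
    return lam, v

def spectral_gap(P, iters=2000):
    """1 - |lambda_2| of a tabular transition matrix P via deflation."""
    lam1, right = power_iteration(P, iters)    # lam1 ~ 1, right ~ all-ones direction
    _, left = power_iteration(P.T, iters)      # left eigenvector ~ stationary distribution
    deflated = P - lam1 * np.outer(right, left) / (left @ right)
    lam2, _ = power_iteration(deflated, iters)
    return 1.0 - abs(lam2)
```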
8.3 Code snippet for DIM objective scores
The following snippet yields pointwise (i.e. not contracted) scores given a batch of data.

```python
import torch
import torch.nn.functional as F

def temporal_DIM_scores(reference, positive, clip_val=20):
    """
    reference: n_batch x n_rkhs x n_locs
    positive:  n_batch x n_rkhs x n_locs
    """
    reference = reference.permute(2, 0, 1)   # reference: n_locs x n_batch x n_rkhs
    positive = positive.permute(2, 1, 0)     # positive:  n_locs x n_rkhs x n_batch
    pairs = torch.matmul(reference, positive)            # pairs: n_locs x n_batch x n_batch
    pairs = pairs / reference.shape[2] ** 0.5             # scale by sqrt of the feature dimension
    pairs = clip_val * torch.tanh((1. / clip_val) * pairs)  # soft-clip the raw scores
    shape = pairs.shape
    scores = F.log_softmax(pairs, 2)                      # scores: n_locs x n_batch x n_batch
    mask = torch.eye(shape[2]).unsqueeze(0).repeat(shape[0], 1, 1)  # keep only the positive (diagonal) pairs
    scores = scores * mask
    return scores
```
To obtain a scalar out of this batch, sum over the third dimension and then average over the first two.
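For instance (a usage sketch with placeholder shapes, relying on the snippet above):

```python
ref = torch.randn(32, 128, 49)             # n_batch x n_rkhs x n_locs
pos = torch.randn(32, 128, 49)
scores = temporal_DIM_scores(ref, pos)     # n_locs x n_batch x n_batch
objective = scores.sum(dim=2).mean()       # sum over the third dim, average over the first two
loss = -objective                          # negate when used as a loss to minimize
```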
8.4 Proof of Proposition 1
For a given timestep $t$ and a pair $(s, a)$, let $\Phi_t(s, a, s') = \mathcal{P}(s' \mid s, a)/\rho_t(s')$ be the true ratio at time $t$ and let $\Phi_\infty(s, a, s') = \mathcal{P}(s' \mid s, a)/\rho_\pi(s')$ be the true ratio when the Markov chain has reached its stationary state.
Let $\hat{\Phi}_t$ be a function we train using the infoNCE loss on samples drawn from the random variables $S_t$, $A_t$ and $S_{t+k}$ (for a fixed $k$).
The functions $\Phi_t$, $\Phi_\infty$ and $\hat{\Phi}_t$ all follow the form $\mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}$.
Proof 1
Let us consider fixed $s$, $a$ and $s'$. First, since $\Phi_t(s, a, s') = \mathcal{P}(s' \mid s, a)/\rho_t(s')$ and $\Phi_\infty(s, a, s') = \mathcal{P}(s' \mid s, a)/\rho_\pi(s')$, we have
$$|\Phi_t(s, a, s') - \Phi_\infty(s, a, s')| = \mathcal{P}(s' \mid s, a)\, \frac{|\rho_\pi(s') - \rho_t(s')|}{\rho_t(s')\, \rho_\pi(s')}.$$
Or in other terms: $\Phi_t$ converges to $\Phi_\infty$ at the rate at which $\rho_t$ converges to $\rho_\pi$. Now, we have:
$$|\hat{\Phi}_t(s, a, s') - \Phi_\infty(s, a, s')| \;\le\; |\hat{\Phi}_t(s, a, s') - \Phi_t(s, a, s')| + |\Phi_t(s, a, s') - \Phi_\infty(s, a, s')|.$$
By assumption on $\hat{\Phi}_t$, we know that $|\hat{\Phi}_t(s, a, s') - \Phi_t(s, a, s')| \le \epsilon$, which concludes the proof.
8.5 Experiment details
All experiments involving RGB inputs (Ising, Ms.PacMan and Procgen) were ran with the settings shown in Table 2.
| Name | Description | Value |
|---|---|---|
| | Exploration at start | 0.1 |
| | Exploration at end | 0.01 |
| | Exploration decay | |
| LR | Learning rate | |
| | Discount factor | |
| Clip grad | Gradient clip norm | |
| N-step return | N-step return | |
| Frame stack | Number of stacked frames | (Ising and Procgen); (Ms.PacMan) |
| Grayscale | Grayscale or RGB | RGB |
| Input size | State input size | (Ising and Ms.PacMan); (Procgen) |
| Warmup steps | | 1000 |
| Replay size | Size of replay buffer | |
| | Target soft update coeff | |
| Clip reward | Reward clipping | False |
| | Global-global DIM | 1 (Ms.PacMan and Procgen); 0 (Ising) |
| | Local-local DIM | 0 (Ms.PacMan and Procgen); 1 (Ising) |
| | Local-global DIM | 0 |
| | Global-local DIM | 0 |
| | DIM lookahead constant | 1 (Ising and Ms.PacMan); 1 and 5 (Procgen) |
The global DIM heads consist of a standard fully-connected network with a single hidden layer of 512 units, ReLU activations and a skip-connection from the input to the output layer. The action is transformed into a one-hot vector and then encoded using a 64-unit layer, after which it is concatenated with the state features and passed to the global DIM head.
The local DIM heads consist of a single-hidden-layer convolutional network. The action is tiled to match the spatial shape of the convolutional maps, encoded using a convolution and concatenated along the feature dimension with the state features, after which it is passed to the local DIM head.
In the case of the Ising model, there is no decision component and hence no concatenation is required.
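A sketch of such a local head is given below (our simplified stand-in; kernel sizes and channel counts are assumptions, not the exact values used in the paper).

```python
import torch
import torch.nn as nn

class LocalDIMHead(nn.Module):
    """Local head: tiles the embedded action over the convolutional map,
    concatenates along the feature dimension, then applies a small conv net."""
    def __init__(self, in_channels, n_actions, act_channels=64, out_channels=128):
        super().__init__()
        self.embed_action = nn.Conv2d(n_actions, act_channels, kernel_size=1)
        self.head = nn.Sequential(
            nn.Conv2d(in_channels + act_channels, 256, kernel_size=1), nn.ReLU(),
            nn.Conv2d(256, out_channels, kernel_size=1))

    def forward(self, conv_map, action_onehot):
        b, _, h, w = conv_map.shape
        a = action_onehot.view(b, -1, 1, 1).expand(b, action_onehot.shape[1], h, w)
        a = self.embed_action(a)                           # embed the tiled one-hot action
        return self.head(torch.cat([conv_map, a], dim=1))  # concatenate along the feature dimension
```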
8.5.1 AMI of a biased random walk
We see from the formulation of the mutual information objective that it inherently depends on the ratio of the transition probability to the stationary distribution, $P(s' \mid s)/\rho(s')$. Recall that, for a 1-d random walk on the integers $\{0, \ldots, N\}$, the stationary distribution $\rho$ is a function of $p$ and can be found using the detailed-balance recursion $\rho(i)\, p = \rho(i+1)\,(1-p)$. It has the form

$$\rho(i) \propto \left(\frac{p}{1-p}\right)^{i} \qquad (14)$$

for $i = 0, \ldots, N$.
The pointwise mutual information between states $s$ and $s'$ is therefore the random variable

$$\mathrm{PMI}(s, s') = \log \frac{P(s' \mid s)}{\rho(s')} \qquad (15)$$

with expectation equal to the average mutual information, which we can find by maximizing, among others, the InfoNCE bound. We can then compute the AMI as a function of $p$:

$$\mathrm{AMI}(p) = \sum_{s, s'} \rho(s)\, P(s' \mid s) \log \frac{P(s' \mid s)}{\rho(s')} \qquad (16)$$

which is shown in Figure 2c.
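Numerically, the curve of Figure 2c can be reproduced directly from Eqs. 14–16; below is a small self-contained sketch (our illustration; `N` is an arbitrary example value).

```python
import numpy as np

def ami_biased_walk(p, N=20):
    """Average mutual information of Eq. 16 for the biased walk with parameter p."""
    ratio = p / (1 - p)
    rho = ratio ** np.arange(N + 1)
    rho /= rho.sum()                                       # stationary distribution, Eq. 14
    ami = 0.0
    for i in range(N + 1):
        moves = {min(i + 1, N): p, max(i - 1, 0): 1 - p}   # boundary states keep the "stay" mass
        for j, prob in moves.items():
            ami += rho[i] * prob * np.log(prob / rho[j])   # Eq. 15 weighted by the joint (Eq. 16)
    return ami

print([round(ami_biased_walk(p), 3) for p in (0.55, 0.7, 0.9)])
```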
The figures were obtained by training the global DIM objective on samples from the chain for 1,000 epochs with a fixed learning rate.

8.6 Experiments with convolutional networks
8.6.1 Ising model
We start by generating a rectangular lattice which is filled with Rademacher random variables; that is, each site takes the value $+1$ or $-1$ with equal probability. At initialization, the joint distribution over any pair of sites factors into the product of marginals.
At every timestep, we uniformly sample a random index tuple and evolve the corresponding patch of nodes according to an Ising model with a fixed temperature, while the remaining nodes continue to independently take the values $\pm 1$ with equal probability. If one examines any subset of nodes outside of the patch, the information conserved across timesteps is close to 0, due to observations being independent in time.
However, examining a subset of the patch at timestep $t$ allows models based on mutual information maximization to predict the configuration of the system at $t+1$, since this region has high mutual information across time, the transition ratio being directly governed by the temperature parameter.
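Below is a minimal sketch of this generative process (our illustration; the Glauber-style update, lattice size and temperature are assumptions about the exact implementation, not code from the paper).

```python
import numpy as np

def step(grid, patch_mask, beta=1.0, rng=np.random):
    """One timestep: Ising (Glauber-style) updates inside the patch, fresh Rademacher noise outside."""
    new = rng.choice([-1, 1], size=grid.shape)            # independent +/-1 outside the patch
    # Sum of the four neighbours (toroidal boundary for simplicity).
    nbrs = (np.roll(grid, 1, 0) + np.roll(grid, -1, 0) +
            np.roll(grid, 1, 1) + np.roll(grid, -1, 1))
    p_up = 1.0 / (1.0 + np.exp(-2.0 * beta * nbrs))       # P(spin = +1 | neighbours)
    ising = np.where(rng.random(grid.shape) < p_up, 1, -1)
    new[patch_mask] = ising[patch_mask]
    return new

n = 32
grid = np.random.choice([-1, 1], size=(n, n))
mask = np.zeros((n, n), dtype=bool)
mask[:n // 2, :n // 2] = True                             # a quarter of the lattice
grid = step(grid, mask)
```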
To obtain the figure, we trained local DIM on sample snapshots of the Ising model (as grayscale images) for 10 epochs. The local DIM scores were obtained by feeding a snapshot of the Ising model at an early timestep; showing it snapshots from later timesteps would have made the task much easier, since there would be a clear difference in granularity between the random pattern and the Ising patches.
8.6.2 Ms.PacMan
In Ms.PacMan, the agent, represented by a yellow square, must collect food pellets while avoiding four harmful ghosts. When the agent collects one of the boosts, it becomes invincible for 10 steps, allowing it to destroy the enemies without dying. The ghosts, in turn, alternate between three behaviours: 1) when the agent is not within line-of-sight, wander randomly; 2) when the agent is visible and does not have a boost, follow it; and 3) when the agent is visible and has a boost, avoid it. The switch between these three modes happens stochastically and quasi-independently for all four ghosts. Since the food and boost pellets are fixed at the beginning of each episode, randomness in the MDP comes from the ghosts as well as the agent’s actions.
The setup for our first experiment in the domain is as follows: with a fixed probability $p$, each of the four enemies takes a random action instead of following one of the three movement patterns.
The setup for our second experiment in the domain consists of four levels: in each level, only one of the four ghosts is lethal - the remaining three behave the same but do not cause damage. The model trains for 5,000 episodes on level 1, then switches to level 2, then level 3, and so forth. This specific environment tests the ability of DIM to quickly figure out which of the four enemies is the lethal one and to ignore the remaining three based on colour.
For our study, the state space consisted of RGB images. The inputs to the model were re-scaled states with 4 consecutive frames stacked, which were then concatenated with actions using an embedding layer.
8.6.3 Procgen
The data augmentation steps performed on the states fed to the DIM loss consisted of a random crop (a fraction of the original size) combined with color jitter. Although the data augmentation is helpful on some tasks (typically fast-paced ones requiring a lot of camera movement), it showed detrimental effects on others. Below is the list of games on which data augmentation was beneficial: bigfish, bossfight, chaser, coinrun, jumper, leaper and ninja.
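Such a pipeline can be sketched with torchvision as below (our illustration; the crop fraction and jitter strengths shown are placeholder values, not the ones used in the paper, and tensor-input transforms require a recent torchvision).

```python
import torch
from torchvision import transforms

# Placeholder parameters: crop to ~80% of the frame, mild colour jitter.
augment = transforms.Compose([
    transforms.RandomResizedCrop(size=64, scale=(0.8, 0.8), ratio=(1.0, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
])

frames = torch.rand(8, 3, 64, 64)                      # batch of RGB frames in [0, 1]
augmented = torch.stack([augment(f) for f in frames])  # apply independently per frame
```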
The $k$ parameter, which specifies how far into the future the model should make its predictions, worked best when set to 5 on the following games: bigfish, chaser, climber, fruitbot, jumper, miner, maze and plunder. For the remaining games, setting $k=1$ yielded better performance.
| Game | DRIML (no data aug) | DRIML (data aug) |
|---|---|---|
| bigfish | 1.344 ± 0.21 | 1.786 ± 0.52 |
| bossfight | 0.482 ± 0.04 | 0.671 ± 0.03 |
| caveflyer | 8.774 ± 0.41 | 10.131 ± 0.3 |
| chaser | 0.218 ± 0.02 | 0.277 ± 0.02 |
| climber | 2.232 ± 0.1 | 1.872 ± 0.15 |
| coinrun | 18.092 ± 2.63 | 25.69 ± 2.34 |
| dodgeball | 1.27 ± 0.07 | 1.129 ± 0.1 |
| fruitbot | 5.302 ± 0.8 | 5.26 ± 0.46 |
| heist | 1.295 ± 0.03 | 1.298 ± 0.07 |
| jumper | 4.239 ± 0.59 | 10.903 ± 2.94 |
| leaper | 5.089 ± 0.12 | 6.089 ± 0.31 |
| maze | 1.283 ± 0.23 | 1.183 ± 0.44 |
| miner | 0.137 ± 0.01 | 0.116 ± 0.02 |
| ninja | 7.678 ± 0.58 | 9.291 ± 0.35 |
| plunder | 3.264 ± 0.23 | 2.819 ± 0.1 |
| starpilot | 4.56 ± 0.2 | 3.062 ± 0.22 |
| Game | DRIML (k=1) | DRIML (k=5) |
|---|---|---|
| bigfish | 1.768 ± 1.43 | 1.786 ± 0.52 |
| bossfight | 0.636 ± 0.08 | 0.671 ± 0.03 |
| caveflyer | 10.131 ± 0.3 | 7.596 ± 0.23 |
| chaser | 0.218 ± 0.02 | 0.277 ± 0.02 |
| climber | 1.89 ± 0.18 | 2.232 ± 0.1 |
| coinrun | 25.69 ± 2.34 | 22.403 ± 1.89 |
| dodgeball | 1.207 ± 0.01 | 1.27 ± 0.07 |
| fruitbot | 3.182 ± 1.34 | 5.302 ± 0.8 |
| heist | 1.298 ± 0.07 | 1.165 ± 0.08 |
| jumper | 9.729 ± 0.65 | 10.903 ± 2.94 |
| leaper | 6.089 ± 0.31 | 4.219 ± 0.4 |
| maze | 0.996 ± 0.05 | 1.283 ± 0.23 |
| miner | 0.118 ± 0.01 | 0.137 ± 0.01 |
| ninja | 9.291 ± 0.35 | 7.339 ± 0.37 |
| plunder | 3.034 ± 0.17 | 3.264 ± 0.23 |
| starpilot | 4.56 ± 0.2 | 3.33 ± 0.42 |
Tables 3 and 4 provide ablations with respect to the two parameters of DRIML: whether or not to perform data augmentation on all input states, and which temporal offset $k$ to use for positive samples (we only experimented with $k \in \{1, 5\}$).
Baselines
The baselines were implemented on top of our existing architecture and, for models which use contrastive objectives, used exactly the same networks for measuring similarity (i.e. one residual block for CURL and CPC). CURL was implemented based on the authors’ code included in their paper and that of MoCo, with EMA on the target network as well as data augmentation (random crops and color jittering) on the input states for randomly sampled timesteps.
The No Action baseline was tuned on the same budget as DRIML, over $k \in \{1, 5\}$ and with/without data augmentation. The best results are reported in the main paper.
| Game | No action (k=1) | No action (k=5) |
|---|---|---|
| bigfish | 1.193 ± 0.04 | 1.33 ± 0.12 |
| bossfight | 0.466 ± 0.07 | 0.472 ± 0.01 |
| caveflyer | 8.263 ± 0.26 | 5.925 ± 0.18 |
| chaser | 0.224 ± 0.01 | 0.229 ± 0.02 |
| climber | 1.359 ± 0.13 | 1.574 ± 0.01 |
| coinrun | 13.146 ± 1.21 | 9.632 ± 2.8 |
| dodgeball | 1.221 ± 0.04 | 1.213 ± 0.09 |
| fruitbot | 0.714 ± 0.31 | 5.425 ± 1.33 |
| heist | 1.042 ± 0.02 | 0.861 ± 0.07 |
| jumper | 2.966 ± 0.1 | 4.314 ± 0.64 |
| leaper | 5.403 ± 0.09 | 3.521 ± 0.3 |
| maze | 0.984 ± 0.13 | 1.438 ± 0.26 |
| miner | 0.11 ± 0.01 | 0.116 ± 0.01 |
| ninja | 6.437 ± 0.22 | 5.9 ± 0.38 |
| plunder | 2.67 ± 0.08 | 3.2 ± 0.05 |
| starpilot | 3.699 ± 0.3 | 2.951 ± 0.31 |