Deep Reinforcement and InfoMax Learning

by Bogdan Mazoure, et al.
McGill University

Our work is based on the hypothesis that a model-free agent whose representations are predictive of properties of future states (beyond expected rewards) will be more capable of solving and adapting to new RL problems. To test that hypothesis, we introduce an objective based on Deep InfoMax (DIM) which trains the agent to predict the future by maximizing the mutual information between its internal representation of successive timesteps. We provide an intuitive analysis of the convergence properties of our approach from the perspective of Markov chain mixing times and argue that convergence of the lower bound on mutual information is related to the inverse absolute spectral gap of the transition model. We test our approach in several synthetic settings, where it successfully learns representations that are predictive of the future. Finally, we augment C51, a strong RL baseline, with our temporal DIM objective and demonstrate improved performance on a continual learning task and on the recently introduced Procgen environment.



1 Introduction

In reinforcement learning (RL), model-based agents are traditionally characterized by their ability to predict future states and rewards based on previous states and actions [Sutton and Barto, 1998, Ha and Schmidhuber, 2018, Hafner et al., 2019a]. Model-based methods can be seen through the representation learning [Goodfellow et al., 2017] lens as endowing the agent with internal representations that are predictive of the future conditioned on its actions. This ultimately gives the agent the ability to plan, e.g., by considering a distribution of possible future trajectories and picking the best course of action.

Model-free methods, on the other hand, do not learn an explicit model of the environment, and instead focus on learning a policy that maximizes reward or a function that estimates the optimal values of states and actions [Mnih et al., 2013, Schulman et al., 2017, Pong et al., 2018]. They can use large amounts of training data and excel in high-dimensional state and action spaces. However, this is mostly true for fixed reward functions. Despite success on many benchmarks, model-free agents typically generalize poorly when the environment or reward function changes [Farebrother et al., 2018, Tachet des Combes et al., 2018] and can have high sample complexity.

Viewing model-based agents from a representation learning perspective, a desired outcome is an agent able to understand the underlying generative factors of the environment that determine the observed state/action sequences, which would help generalization to environments built from the same generative factors. In addition, learning a predictive model often involves much richer learning signals than those provided by reward alone, which could reduce sample complexity compared to model-free methods.

Our work is based on the hypothesis that a model-free agent whose representations are predictive of properties of future states (beyond expected rewards) will be more capable of solving and adapting to new RL problems, and in a way, incorporate aspects of model-based learning. To learn representations with model-like properties, we consider a self-supervised objective derived from variants of Deep InfoMax [DIM, Hjelm et al., 2018, Bachman et al., 2019, Anand et al., 2019]. We expect this type of contrastive estimation [Hyvarinen and Morioka, 2016] will give the agent a better understanding of the underlying factors of the environment and how they relate to its actions, eventually leading to better performance in transfer and lifelong learning problems. We examine the properties of the learnt representations in simple domains such as disjoint and glued Markov chains, and more complex environments such as a 2D Ising model, a sequential variant of Ms. PacMan from the Arcade Learning Environment [ALE, Bellemare et al., 2013], and all 16 games from the Procgen suite [Cobbe et al., 2019]. Our contributions are as follows:

  • We propose a simple auxiliary objective that maximizes concordance between representations of successive states, given the action.

  • We present a series of experiments showing how our objective can be used as a measure of similarity and predictability, and how it behaves in partially deterministic systems.

  • Finally, we show that augmenting a standard C51 agent [Bellemare et al., 2017] with our contrastive objective can i) lead to faster adaptation in a continual learning setting, and ii) improve overall performance on the Procgen suite.

2 Background

Just as humans are able to retain old skills when taught new ones [Wixted, 2004], we strive for RL agents that are able to adapt quickly and reuse knowledge when dealing with a sequence of different tasks with variable reward functions. The reason for this is that real-world applications or downstream tasks can be difficult to predict before deployment, particularly with complex environments involving other intelligent agents such as humans. Unfortunately, this proves to be very challenging even for state-of-the-art systems [Atkinson et al., 2018], leading to complex deployment scenarios.

Continual Learning (CL) is a learning framework meant to benchmark an agent’s ability to adapt to new tasks by using auxiliary information about the relatedness across tasks and timescales [Kaplanis et al., 2018, Mankowitz et al., 2018]. Meta-learning [Thrun and Pratt, 1998, Finn et al., 2017] and multi-task learning [Hessel et al., 2019, D’Eramo et al., 2019] have shown good performance in CL by explicitly training the agent to transfer well between tasks.

In this study, we focus on the following inductive bias: while the reward function may change or vary, the underlying environment dynamics typically do not change as much (Footnote 1: this is not true in all generalization settings. Generalization still has a variety of specifications within RL. In our work, we focus on the setting where the rewards change more rapidly than the environment dynamics). To test if that inductive bias is useful, we use auxiliary loss functions to encourage the agent to learn about the underlying generative factors and their associated dynamics in the environment, which can result in better sample efficiency and transfer capabilities (compared to learning from rewards only). Previous work has shown this idea to be useful when training RL agents: e.g., Jaderberg et al. [2016] train the agent to predict future states given the current state-action pair, while Mohamed and Rezende [2015] use empowerment to measure concordance between a sequence of future actions and the end state. Recent work such as DeepMDP [Gelada et al., 2019] uses a latent variable model to represent transition and reward functions in a high-dimensional abstract space. In model-based RL, various agents, such as PlaNet [Hafner et al., 2019b], Dreamer [Hafner et al., 2019a] or MuZero [Schrittwieser et al., 2019], have also shown strong asymptotic performance.

Contrastive representation learning methods are based on training an encoder to capture information that is shared across different views of the data in the features it produces for each input. The similar (i.e. positive) examples are typically either taken from different “locations” of the data [e.g., spatial patches or temporal locations, see Hjelm et al., 2018, Oord et al., 2018, Anand et al., 2019, Hénaff et al., 2019] or obtained through data augmentation [Wu et al., 2018, He et al., 2019, Bachman et al., 2019, Tian et al., 2019, Chen et al., 2020]. Contrastive models rely on a variety of objectives to encourage similarity between features. Typically, a scoring function [e.g., dot product or cosine similarity between pairs of features, see Wu et al., 2018] that lower-bounds mutual information is maximized [Belghazi et al., 2018, Hjelm et al., 2018, Oord et al., 2018, Poole et al., 2019]. In an RL setting, Oord et al. [2018] augment an A2C agent with a contrastive objective at different timesteps which improves performance on 5 DeepMind lab games [Beattie et al., 2016]. CURL [Srinivas et al., 2020] uses a contrastive objective between augmented versions of the current timestep. Finally, EMI [Kim et al., 2019] uses a Jensen-Shannon divergence based lower bound on mutual information to form an exploration bonus. In comparison, our work focuses on the predictability of future states based on the current state-action pair at multiple scales [Bachman et al., 2019] in the encoder.

3 Preliminaries

3.1 Markov Chains

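Given a discrete state space $\mathcal{S}$ with probability measure $\mathbb{P}$, a discrete-time homogeneous Markov chain (MC) is a collection of random variables $\{S_t\}_{t \geq 0}$ with the following property on its transition matrix $\mathbf{T} \in \mathbb{R}^{|\mathcal{S}| \times |\mathcal{S}|}$: $\mathbf{T}_{ij} = \mathbb{P}[S_{t+1} = j \mid S_t = i]$ for all $i, j \in \mathcal{S}$. Assuming the Markov chain is ergodic, its invariant distribution $\rho$ (Footnote 2: the existence and uniqueness of $\rho$ are direct results of the Perron-Frobenius theorem) is the principal eigenvector of $\mathbf{T}$, which verifies $\rho \mathbf{T} = \rho$ and summarizes the long-term behaviour of the chain. We define the marginal distribution of $S_t$ as $p_t(s) := \mathbb{P}[S_t = s]$, and the initial distribution of $S$ as $p_0(s)$.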
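The invariant distribution and the eigenvalue magnitudes used throughout the analysis can be computed directly for a small tabular chain. The following is a minimal NumPy sketch; the 3-state matrix and the helper names `stationary_distribution` and `absolute_spectral_gap` are illustrative assumptions, not part of the original work.

```python
import numpy as np

def stationary_distribution(T):
    """Left principal eigenvector of a row-stochastic matrix, normalised to sum to 1."""
    eigvals, eigvecs = np.linalg.eig(T.T)       # right eigenvectors of T.T = left eigenvectors of T
    idx = np.argmin(np.abs(eigvals - 1.0))      # pick the eigenvalue closest to 1
    rho = np.real(eigvecs[:, idx])
    return rho / rho.sum()                      # also fixes an arbitrary sign

def absolute_spectral_gap(T):
    """1 - |lambda_2|, with eigenvalues ordered by decreasing magnitude."""
    mags = np.sort(np.abs(np.linalg.eigvals(T)))[::-1]
    return 1.0 - mags[1]

# Hypothetical 3-state ergodic chain (each row sums to 1).
T = np.array([[0.9, 0.1, 0.0],
              [0.1, 0.8, 0.1],
              [0.0, 0.1, 0.9]])
rho = stationary_distribution(T)
gap = absolute_spectral_gap(T)
```

A small gap indicates a slowly mixing chain, which is the quantity the later convergence argument relies on.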

Theorem 3.1 (Levin and Peres [2017])

Let and be the mixing time of the chain induced by . If

are the eigenvalues of

ordered by decreasing magnitude, then converges to arbitrarily small with rate :


3.2 Markov Decision Processes

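A discrete-time, finite-horizon Markov Decision Process [Bellman, 1957, Puterman, 2014, MDP] comprises a state space $\mathcal{S}$, an action space $\mathcal{A}$ (Footnote 3: we consider discrete state and action spaces), a transition kernel $T(s' \mid s, a)$, a reward function $r: \mathcal{S} \times \mathcal{A} \to \mathbb{R}$ and a discount factor $\gamma \in [0, 1]$. At every timestep $t$, an agent interacting with this MDP observes the current state $s_t \in \mathcal{S}$, selects an action $a_t \in \mathcal{A}$, and observes a reward $r(s_t, a_t)$ upon transitioning to a new state $s_{t+1} \sim T(\cdot \mid s_t, a_t)$. The goal of an agent in a discounted MDP is to learn a policy $\pi(a \mid s)$ such that taking actions $a_t \sim \pi(\cdot \mid s_t)$ maximizes the expected sum of discounted returns,

$$V^{\pi}(s) = \mathbb{E}_{\pi}\Big[\sum_{t \geq 0} \gamma^{t} r(s_t, a_t) \,\Big|\, s_0 = s\Big].$$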

To convert an MDP into an MC, one can let $\mathbf{T}^{\pi}_{ss'} = \mathbb{E}_{a \sim \pi(\cdot \mid s)}[T(s' \mid s, a)]$, an operation which can be easily tensorized for computational efficiency in small state spaces [see Mazoure et al., 2020].
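As an illustration of the tensorized conversion, the sketch below averages a tabular transition kernel over a policy with a single `einsum` call. The array layout (`P[s, a, s']`, `pi[s, a]`), the helper name `induced_chain` and the random toy MDP are assumptions made for this example.

```python
import numpy as np

def induced_chain(P, pi):
    """P[s, a, s']: transition kernel; pi[s, a]: policy. Returns T_pi[s, s']."""
    # Sum over actions, weighting each action's next-state distribution by pi(a | s).
    return np.einsum("sa,sat->st", pi, P)

rng = np.random.default_rng(0)
n_states, n_actions = 4, 2
P = rng.random((n_states, n_actions, n_states))
P /= P.sum(axis=-1, keepdims=True)                       # normalise into a valid kernel
pi = np.full((n_states, n_actions), 1.0 / n_actions)     # uniform policy (assumed)
T_pi = induced_chain(P, pi)
```

The result is a row-stochastic matrix, so the Markov chain tools of Section 3.1 apply to it unchanged.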

For this paper, we use the C51 algorithm [Bellemare et al., 2017] for training the agent due to its simplicity and good performance on control tasks from pixels. C51 minimizes the following loss:

$$\mathcal{L}_{RL} = D_{KL}\big(\Pi\, \mathcal{T} Z(s, a) \,\|\, Z(s, a)\big), \tag{2}$$

where $D_{KL}$ is the Kullback-Leibler divergence, $Z(s, a)$ is the distribution of discounted returns (s.t. $\mathbb{E}[Z(s, a)] = Q(s, a)$), $\mathcal{T}$ is the distributional Bellman operator [Bellemare et al., 2019] and $\Pi$ is an operator which projects $\mathcal{T} Z$ onto a fixed support of 51 atoms.
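To make the projection step concrete, here is a minimal NumPy sketch of the categorical projection described by Bellemare et al. [2017] for a single transition. The support range of [-10, 10], the scalar reward, and the function name `project_distribution` are assumptions for illustration; the paper's actual hyperparameters are not restated here.

```python
import numpy as np

def project_distribution(probs, reward, gamma, atoms):
    """Project the Bellman-updated return distribution back onto the fixed `atoms`."""
    v_min, v_max = atoms[0], atoms[-1]
    delta = (v_max - v_min) / (len(atoms) - 1)
    projected = np.zeros_like(probs)
    for p, z in zip(probs, atoms):
        tz = np.clip(reward + gamma * z, v_min, v_max)        # shifted and scaled atom
        b = np.clip((tz - v_min) / delta, 0, len(atoms) - 1)  # fractional index on the support
        lo, hi = int(np.floor(b)), int(np.ceil(b))
        if lo == hi:                                          # lands exactly on an atom
            projected[lo] += p
        else:                                                 # split mass between neighbours
            projected[lo] += p * (hi - b)
            projected[hi] += p * (b - lo)
    return projected

atoms = np.linspace(-10.0, 10.0, 51)      # 51 atoms, assumed support
probs = np.full(51, 1.0 / 51)             # next-state return distribution (toy)
target = project_distribution(probs, reward=1.0, gamma=0.99, atoms=atoms)
```

The projected vector is again a probability distribution on the same atoms, so it can be compared to the predicted distribution with a KL divergence.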

3.3 State-action mutual information maximization

Mutual information (MI) measures the amount of information shared between a pair of random variables and can be estimated using neural networks [Belghazi et al., 2018]. Recent representation learning algorithms [Oord et al., 2018, Hjelm et al., 2018, Hénaff et al., 2019, Tian et al., 2019, He et al., 2019] train encoders to maximize MI between features taken from different views of the input (e.g., different patches in an image or different versions of an image produced by applying data augmentation). These algorithms commonly optimize lower bounds on the MI based on noise-contrastive estimation [NCE, Gutmann and Hyvärinen, 2010, Oord et al., 2018].

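Let $k$ be some fixed temporal offset. Running a policy $\pi$ in the MDP $\mathcal{M}$ generates a distribution over tuples $(s_t, a_t, s_{t+k})$, where $s_t$ corresponds to the state of $\mathcal{M}$ at some timestep $t$, $a_t$ to the action selected by $\pi$ in state $s_t$ and $s_{t+k}$ to the state of $\mathcal{M}$ at timestep $t+k$, reached by following $\pi$. We let $S_t$, $A_t$ and $S_{t+k}$ denote the corresponding random variables. We denote the joint distribution of these variables and their associated marginals using $p^{\pi}$. We consider maximizing the MI between state-action pairs $(S_t, A_t)$ and their future states $S_{t+k}$, which can be written as follows:

$$I([S_t, A_t]; S_{t+k}) = \mathbb{E}_{p^{\pi}(s_t, a_t, s_{t+k})}\left[\log \frac{p^{\pi}(s_t, a_t, s_{t+k})}{p^{\pi}(s_t, a_t)\, p^{\pi}(s_{t+k})}\right]. \tag{3}$$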


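Estimating the MI amounts to training a classifier that discriminates between a sample $(s_t, a_t, s_{t+k})$ drawn from the joint distribution and a sample from the product of marginals in the denominator of Eq. 3. A sample from the product of marginals can be obtained by replacing $s_{t+k}$ (we call this the positive sample) with a state picked at random from another trajectory (we call this a negative sample). Letting $\mathcal{S}^{-}$ denote a set of such negative samples, the infoNCE loss function [Oord et al., 2018] that we use to maximize a lower bound on the MI in Eq. 3 takes the following form:

$$\mathcal{L}_{NCE} := -\mathbb{E}_{p^{\pi}(s_t, a_t, s_{t+k})} \mathbb{E}_{\mathcal{S}^{-}} \left[ \log \frac{\exp(\phi(\Psi(s_t, a_t), \Phi(s_{t+k})))}{\sum_{s' \in \mathcal{S}^{-} \cup \{s_{t+k}\}} \exp(\phi(\Psi(s_t, a_t), \Phi(s')))} \right], \tag{4}$$

where $\Psi(s_t, a_t)$ and $\Phi(s_{t+k})$ are features that depend on state-action pairs and states, respectively, and $\phi$ is a function that outputs a scalar-valued score. Minimizing $\mathcal{L}_{NCE}$ with respect to $\Phi$, $\Psi$ and $\phi$ maximizes the MI between these features. In practice, we construct $\mathcal{S}^{-}$ by including all states $s_{t+k}$ from other tuples in the same minibatch as the relevant $(s_t, a_t, s_{t+k})$. I.e., for a batch containing $n$ tuples $(s_t, a_t, s_{t+k})$, each $\mathcal{S}^{-}$ would contain $n - 1$ negative samples.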
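The in-batch construction of negatives described above can be sketched in a few lines. The example below is a NumPy illustration that assumes a plain dot product as the scoring function and precomputed feature arrays; the name `info_nce_loss` and the toy inputs are hypothetical, and the paper's own PyTorch scoring code appears in Appendix 8.3.

```python
import numpy as np

def info_nce_loss(sa_features, next_features):
    """
    sa_features:   (batch, dim) features of state-action pairs.
    next_features: (batch, dim) features of the matching future states.
    Row i of both arrays comes from the same tuple; all other rows act as negatives.
    """
    scores = sa_features @ next_features.T                  # (batch, batch) dot-product scores
    scores = scores - scores.max(axis=1, keepdims=True)     # stabilise the softmax
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))                     # positives sit on the diagonal

rng = np.random.default_rng(0)
future = rng.normal(size=(8, 16))
aligned_loss = info_nce_loss(5.0 * future, future)          # informative features
uninformative_loss = info_nce_loss(np.zeros((8, 16)), future)
```

With uninformative features the loss equals the log of the batch size, and it decreases as the positive pairs become easier to pick out from the in-batch negatives.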

4 Architecture and Algorithm

We now specify forms for the functions $\Psi$, $\Phi$ and $\phi$. We consider a deep neural network $\Psi$ which maps input states onto a sequence of progressively more “global” (or less “local”) feature spaces. In practice, $\Psi$ is a deep residual CNN composed of functions $f_1, \dots, f_5$ that sequentially map inputs to features $f_1(s), \dots, f_5(s)$ (lower to upper “levels” of the network).

The features $f_5(s)$ are the output of the network’s last layer and correspond to the standard C51 value heads (i.e., they span a space of 51 atoms per action) (Footnote 4: $\Psi$ can equivalently be seen as the network used in a standard C51). For the auxiliary objective, we follow a variant of Deep InfoMax [DIM, Hjelm et al., 2018, Anand et al., 2019, Bachman et al., 2019], and train the encoder to maximize the mutual information (MI) between local and global “views” of tuples $(s_t, a_t, s_{t+k})$. The local and global views are realized by selecting $f_3$ and $f_4$ respectively. In order to simultaneously estimate and maximize the MI, we embed the action (represented as a one-hot vector) using a function $\Psi_a$. We then map the local states $f_3(s_t)$ and the embedded action using a function $\Phi_3$, and do the same with the global states $f_4(s_t)$, i.e., $\Phi_4$. In addition, we have two more functions, $\tilde{\Phi}_3$ and $\tilde{\Phi}_4$, that map features without the actions, which will be applied to features from “future” timesteps. Note that the output space of $\Phi_3$ can be thought of as a product of local spaces (corresponding to different patches in the input, or equivalently different receptive fields), each with the same dimensionality as the output space of $\Phi_4$.

We use the outputs of these functions to produce a scalar-valued score between any combination of local and global representations of state $s_t$ and $s_{t+k}$, conditioned on action $a_t$:

$$\phi_{N,M}(s_t, a_t, s_{t+k}) := \Phi_N(f_N(s_t), \Psi_a(a_t))^{\top} \tilde{\Phi}_M(f_M(s_{t+k})), \quad N, M \in \{3, 4\}. \tag{5}$$

In practice, for the functions that take features and actions as input, we simply concatenate the feature values at each spatial position (local) or the single feature vector (global) with the embedded action $\Psi_a(a_t)$, and feed the resulting tensor into the appropriate function $\Phi_3$ or $\Phi_4$. All functions that process global and local features are computed using convolutions. See Figure 1 for a visual representation of our model.

We use the scores from Eq. 5 when computing the infoNCE loss [Oord et al., 2018] for our objective, using tuples $(s_t, a_t, s_{t+k})$ sampled from trajectories stored in an experience replay buffer, with the negative samples $\mathcal{S}^{-}$ taken from the same minibatch:

$$\mathcal{L}_{DIM}^{N,M} := -\mathbb{E}_{p^{\pi}(s_t, a_t, s_{t+k})} \mathbb{E}_{\mathcal{S}^{-}} \left[ \log \frac{\exp(\phi_{N,M}(s_t, a_t, s_{t+k}))}{\sum_{s' \in \mathcal{S}^{-} \cup \{s_{t+k}\}} \exp(\phi_{N,M}(s_t, a_t, s'))} \right]. \tag{6}$$

Combining Eq. 6 with the C51 update in Eq. 2 yields our full training objective, which we call DRIML (Deep Reinforcement and InfoMax Learning). We optimize the encoder and the embedding functions jointly using a single loss function:

$$\mathcal{L}_{DRIML} = \mathcal{L}_{RL} + \sum_{N, M \in \{3, 4\}} \lambda_{NM}\, \mathcal{L}_{DIM}^{N,M}, \quad \lambda_{NM} \geq 0. \tag{7}$$

Note that, in practice, the compute cost which Eq. 7 adds to the core C51 RL algorithm is minimal, since it only requires additional passes through the (small) state/action embedding functions followed by an outer product.
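The "embedding functions followed by an outer product" step can be sketched as follows. This is a simplified NumPy illustration, assuming a single feature vector per state and random linear maps in place of the learned convolutional functions; every name and dimension below is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, feat_dim, n_actions, act_dim, out_dim = 8, 16, 4, 5, 32

# Hypothetical stand-ins for the learned maps (random linear layers here).
W_action = rng.normal(size=(n_actions, act_dim))        # action embedding table
W_sa = rng.normal(size=(feat_dim + act_dim, out_dim))   # maps [feature, embedded action]
W_next = rng.normal(size=(feat_dim, out_dim))           # maps future features, no action

def pairwise_scores(feat_t, actions, feat_tk):
    one_hot = np.eye(n_actions)[actions]                      # (batch, n_actions)
    embedded = one_hot @ W_action                             # (batch, act_dim)
    ref = np.concatenate([feat_t, embedded], axis=1) @ W_sa   # (batch, out_dim)
    pos = feat_tk @ W_next                                    # (batch, out_dim)
    return ref @ pos.T                                        # score of every (i, j) pairing

feat_t = rng.normal(size=(batch, feat_dim))
feat_tk = rng.normal(size=(batch, feat_dim))
actions = rng.integers(0, n_actions, size=batch)
score_matrix = pairwise_scores(feat_t, actions, feat_tk)
```

Entry `(i, i)` scores a positive pair and the remaining entries of row `i` score the in-batch negatives, so a single matrix product covers the whole batch.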

Input : Batch sampled from the replay buffer, , strictly positive integer
Update using Eq. 2;
for N in 3,4 do
       for M in 3,4 do
             if  then
                   Compute using Eq. 6 (see Appendix 8.3 for PyTorch code);

                   Update , and using gradients of ;
             end if
       end for
end for
Algorithm 1 Deep Reinforcement and InfoMax Learning
Figure 1: (a) Model architecture used for the encoder used for the RL and DIM objectives and (b) distribution of reference, positive and negative samples within training batch .

The proposed Algorithm 1 introduces an auxiliary loss which improves predictive capabilities of value-based agents by boosting similarity of representations close in time.

5 Predictability and Contrastive Learning

Information maximization has long been considered one of the standard principles for measuring correlation and performing feature selection [Song et al., 2012]. In the MDP context, high values of $I([S_t, A_t]; S_{t+k})$ indicate that $(S_t, A_t)$ and $S_{t+k}$ have some form of dependence, while low values suggest independence. The fact that predictability (or more precisely determinism) in Markov systems is linked to the MI suggests a deeper connection to the spectrum of the transition kernel $T$. For instance, the set of eigenvalues of $T$ for a Markov decision process contains important information about the connectivity of said process, such as mixing time or number of densely connected clusters [Von Luxburg, 2007, Levin and Peres, 2017].

Consider the setting in which is fixed at some iteration in the optimization process. In the rest of this section, we let denote the expected transition model (it is a Markov chain). We let be the ratio learnt when optimizing the infoNCE loss on samples drawn from the random variables and (for a fixed [Oord et al., 2018]. We also let be that ratio when the Markov chain has reached its stationary distribution (see Section 3.1), and be the scoring function learnt using InfoNCE (which converges to in the limit of infinite samples drawn from ).

Proposition 1

Let . Assume at time step , training of has close to converged on a pair , i.e. . Then the following holds:


The proof can be found in App. 8.4. Proposition 1 combined with Theorem 3.1 suggests that faster convergence of to happens when the spectral gap of is large, or equivalently when is small. It follows that, on one hand, mutual information is a natural measure of concordance of pairs and can be maximized using data-efficient, batched gradient methods. On the other hand, the rate at which the InfoNCE loss converges to its stationary value (ie maximizes the lower bound on MI) depends on the spectral gap of , which is closely linked to predictability (see Appendix 8.2).

6 Experiments

In this section, we first show how our proposed objective can be used to estimate state similarity in single Markov chains. We then show that DRIML can capture dynamics in locally deterministic systems (Ising model), which is useful in domains with partially deterministic transitions. We then provide results on a continual version of the Ms.PacMan game where the DIM loss is shown to converge faster for more deterministic tasks, and to help in a continual learning setting. Finally, we provide results on Procgen [Cobbe et al., 2019], which show that DRIML performs well when trained on 500 fixed levels. All experimental details can be found in App. 8.5.

6.1 DIM learns a transition ratio model

We first study the behaviour of contrastive losses on a simple Markov chain describing a biased random walk in . The bias is specified by a single parameter . The agent starting at state transitions to with probability and to otherwise. The agent stays in states and with probability and , respectively. We encode the current and next states (represented as one-hots) using a 1-hidden layer MLP (Footnote 6: the action is simply ignored in this setting) (corresponding to and in Eq. 4), and then optimize the NCE loss of Eq. 4 (the scoring function is also a 1-hidden layer MLP) to maximize the MI between representations of successive states. The results, shown in Fig. 2b, are well aligned with the true ratio of the transition matrix over the stationary vector (Fig. 2a).

Figure 2: (a) Ratio of transition matrix over stationary vector for the random walk with , (b) the prediction matrix of being a pair of successive states learnt by , (c) the closed-form mutual information between consecutive states in time as a function of (with simplified endpoint conditions) and (d) the true inverse spectral gap as a function of .

The spectral gap of the transition matrix can be computed in closed-form as a function of . Its lowest value is reached in the neighbourhood , corresponding to the point where the system is least predictable (as shown by the mutual information, Fig 2c). Derivations are available in Appx 8.5.1.
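The dependence of the spectral gap on the bias can be checked numerically. The sketch below assumes a 10-state walk that moves right with probability `alpha` and left otherwise, with blocked moves at the two ends turned into self-loops; the number of states, the grid of `alpha` values and the function names are assumptions, not the paper's exact setup.

```python
import numpy as np

def biased_walk(n_states, alpha):
    """Move right w.p. alpha, left otherwise; blocked moves at the ends become self-loops."""
    T = np.zeros((n_states, n_states))
    for i in range(n_states):
        T[i, min(i + 1, n_states - 1)] += alpha
        T[i, max(i - 1, 0)] += 1.0 - alpha
    return T

def spectral_gap(T):
    """1 - |lambda_2|, with eigenvalues ordered by decreasing magnitude."""
    mags = np.sort(np.abs(np.linalg.eigvals(T)))[::-1]
    return 1.0 - mags[1]

alphas = [0.1, 0.3, 0.5, 0.7, 0.9]
gaps = [spectral_gap(biased_walk(10, a)) for a in alphas]
inverse_gaps = [1.0 / g for g in gaps]
smallest_gap_alpha = alphas[int(np.argmin(gaps))]
```

Under these assumptions the gap is smallest, and the inverse gap largest, for the unbiased walk, which is the regime where the chain mixes most slowly.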

6.2 DIM can capture complex partially deterministic dynamics

The goal of this experiment is to highlight the predictive capabilities of our DIM objective in a partially deterministic system. We consider a dynamical system composed of pixels with values in , . At the beginning of each episode, a patch corresponding to a quarter of the pixels is chosen at random in the grid. Pixels that do not belong to that patch evolve fully independently (). Pixels from the patch obey a local dependence law, in the form of a standard Ising model777 the value of a pixel at time only depends on the value of its neighbors at time . This local dependence is obtained through a function : (see Appx 8.6.1 for details). Figure 3 shows the system at during three different episodes (black pixels correspond to values of , white to ). The patches are very distinct from the noise. We then train a convolutional encoder using our DIM objective on local “views” only (see Section 4).

Figure 3: Ising model with temperature overlaid onto a lattice of uniformly random spins . The grayscale plots show each of the 3 systems at ; the color plots show the DIM similarity scores between and .

Figure 3 shows the similarity scores between the local features of states at and the same features at (a local feature corresponds to a specific location in the convolutional maps) (Footnote 8: we chose early timesteps to make sure that the model does not simply detect large patches, but truly measures predictability). The heatmap regions containing the Ising model (larger-scale patterns) have higher scores than the noisy portions of the lattice. Local DIM is able to correctly encode regions of high temporal predictability.

6.3 A Continual Learning experiment on Ms.PacMan

We further complicate the task of the Ising model prediction by building on top of the Ms. PacMan game and introducing non-trivial dynamics. The environment is shown in Figure 4a.

In order to assess how well our auxiliary objective captures predictability in this MDP, we define its dynamics such that . Intuitively, as , the enemies’ actions become less predictable, which in turn hinders the convergence rate of the contrastive loss. The four runs in Figure 4b correspond to various values of . We trained the agent using our objective. We can see that the convergence of the auxiliary loss becomes slower with growing , as the model struggles to predict given . After 100k frames, the NCE objective allows to separate the four MDPs according to their principal source of randomness (red and blue curves). When is close to 1, the auxiliary loss has a harder time finding features that predict the next state, and eventually ignores the random movements of enemies.

The second and more interesting setup we consider consists in making only one out of 4 enemies lethal, and changing which one every 5k episodes. Figure 4c shows that, as training progresses, C51 always reaches the same performance at the end of the 5k episodes, while DRIML’s steadily increases. C51 learns to ignore the harmless ghosts (they have no connection to the reward signal) and has to learn the task from scratch every time the lethal ghost changes. On the other hand, DRIML is incentivized to encode information about all the predictable objects on the screen (including the harmless ghosts), and as such adapts faster and faster to changes.

Figure 4: (a) The simplified Ms.PacMan environment, (b) average training NCE loss for various values of as a function of timesteps and (c) average training reward with only one harmful enemy per level (dashed line indicates average terminal C51 performance after each task).

6.4 Performance on Procgen Benchmark

Finally, we demonstrate the beneficial impact of adding a DIM-like objective to C51 (DRIML) on the 16 Procgen tasks [Cobbe et al., 2019]. All algorithms are trained for 50M environment frames with the DQN [Mnih et al., 2015] architecture. The mean and standard deviation of the scores (over 3 seeds) are shown in Table 1; bold values indicate best performance.

Game        DRIML (Ours)     C51              CURL             No Action        Random
bigfish     2.023 ± 0.18     1.33 ± 0.12      2.697 ± 1.3      1.192 ± 0.04     0.333 ± 0.47
bossfight   0.672 ± 0.02     0.573 ± 0.05     0.595 ± 0.06     0.472 ± 0.01     0
caveflyer   10.18 ± 0.41     9.187 ± 0.29     6.938 ± 0.25     8.259 ± 0.26     0
chaser      0.286 ± 0.02     0.222 ± 0.04     0.353 ± 0.04     0.229 ± 0.02     0.04 ± 0
climber     2.256 ± 0.05     1.678 ± 0.1      1.751 ± 0.09     1.574 ± 0.01     0
coinrun     27.238 ± 1.92    29.701 ± 5.44    21.166 ± 1.94    13.146 ± 1.21    0
dodgeball   1.28 ± 0.02      1.198 ± 0.08     1.093 ± 0.04     1.221 ± 0.04     0.667 ± 0.94
fruitbot    5.399 ± 1.02     3.856 ± 0.96     4.887 ± 0.71     5.425 ± 1.33     0.333 ± 0.47
heist       1.296 ± 0.05     1.537 ± 0.1      1.056 ± 0.05     1.042 ± 0.02     0
jumper      12.639 ± 0.64    13.225 ± 0.83    10.273 ± 0.61    4.314 ± 0.64     0
leaper      6.168 ± 0.29     5.034 ± 0.14     3.943 ± 0.46     5.403 ± 0.09     0
maze        1.381 ± 0.08     2.355 ± 0.09     0.823 ± 0.2      1.438 ± 0.26     0
miner       0.14 ± 0.01      0.126 ± 0.01     0.096 ± 0        0.116 ± 0.01     0
ninja       9.209 ± 0.25     9.36 ± 0.01      5.839 ± 1.21     6.438 ± 0.22     0
plunder     3.366 ± 0.17     2.994 ± 0.07     2.771 ± 0.14     3.2 ± 0.05       0.667 ± 0.47
starpilot   4.562 ± 0.21     2.445 ± 0.12     2.683 ± 0.09     3.699 ± 0.3      0.333 ± 0.47
Table 1: Average training returns collected after 50M training frames, ± one standard deviation. No Action corresponds to DRIML without the action embedding.

Similarly to CURL, we used data augmentation on inputs to DRIML to improve the model’s predictive capabilities in fast-paced environments (see App. 8.6.3). While we used the global-global loss in DRIML’s objective, we have found that the local-local loss also had a beneficial effect on performance on a smaller set of games (e.g. starpilot, which has few moving entities on a dark background).

7 Discussion

In this paper, we introduced an auxiliary objective called Deep Reinforcement and InfoMax Learning (DRIML), which is based on maximizing concordance of state-action pairs with future states (at the representation level). Our objective has a close connection with the spectrum of the Markov transition matrix and predictability of Markovian systems, which dictate its long-term behaviour. We presented results showing that 1) DRIML implicitly learns a transition model by boosting state similarity, 2) it improves C51 in a continual learning setting and 3) it boosts training performance in complex domains such as Procgen.


We thank Harm van Seijen, Ankesh Anand, Mehdi Fatemi, Romain Laroche and Jayakumar Subramanian for useful feedback and helpful discussions.

Broader Impact

This work proposes an auxiliary objective for model-free reinforcement learning agents. The objective shows improvements in a continual learning setting, as well as on average training rewards for a suite of complex video games. While the objective is developed in a visual setting, maximizing mutual information between features is a method that can be transported to other domains, such as text. Potential applications of deep reinforcement learning are (among others) healthcare, dialog systems, crop management, robotics, etc. Developing methods that are more robust to changes in the environment, and/or perform better in a continual learning setting can lead to improvements in those various applications. At the same time, our method fundamentally relies on deep learning tools and architectures, which are hard to interpret and prone to failures yet to be perfectly understood. Additionally, deep reinforcement learning also lacks formal performance guarantees, and so do deep reinforcement learning agents. Overall, it is essential to design failsafes when deploying such agents (including ours) in the real world.


8 Appendix

8.1 Link to invariant distribution

For a discrete state ergodic Markov chain specified by and initial occupancy vector , its marginal state distribution at time is given by the Chapman-Kolmogorov form:


and its limiting distribution is the infinite-time marginal


which, if it exists, is exactly equal to the invariant distribution .
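The convergence of the marginal state distribution to the invariant distribution can be observed numerically on a toy chain. The sketch below assumes a hypothetical 3-state ergodic chain and tracks the total variation distance between the marginal after `t` steps and the invariant distribution; all values are illustrative.

```python
import numpy as np

T = np.array([[0.5, 0.5, 0.0],
              [0.25, 0.5, 0.25],
              [0.0, 0.5, 0.5]])         # hypothetical ergodic chain
rho = np.array([0.25, 0.5, 0.25])       # its invariant distribution (rho @ T equals rho)
p = np.array([1.0, 0.0, 0.0])           # initial occupancy: start in state 0

tv_distances = []
for t in range(20):
    tv_distances.append(0.5 * np.abs(p - rho).sum())   # total variation distance at step t
    p = p @ T                                          # advance the marginal by one step
```

For this chain the distance shrinks geometrically, at a rate set by the second-largest eigenvalue magnitude of the transition matrix.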

For the very restricted family of ergodic MDPs under fixed policy , we can assume that converges to a time invariant distribution .



Now, observe that is closely linked to when samples come from timesteps close to . That is, swapping and at any state would yield at most error. Moreover, existing results [Levin and Peres, 2017] from Markov chain theory provide bounds on depending on the structure of the transition matrix.

If has a limiting distribution , then using the dominated convergence theorem allows us to replace matrix powers by , which is then replaced by the invariant distribution :


Of course, most real-life Markov decision processes do not actually have an invariant distribution since they have absorbing (or terminal) states. In this case, as the agent interacts with the environment, the DIM estimate of MI yields a rate of convergence which can be estimated based on the spectrum of .

Moreover, one could argue that since, in practice, we use off-policy algorithms for this sort of task, the gradient signal comes from various timesteps within the experience replay, which drives the model to learn features that are consistently predictive through time.

8.2 Predictability and spectrum of an MDP

In the simplest case of a discrete MDP, fully describes the dynamics of the environment. The row corresponds to the probability of landing in any state by taking action from state . Since a perfectly deterministic environment would transition to exactly one state with probability 1, an agent equipped with will know exactly what would look like and adjust its planning accordingly. Simply speaking, in a predictable Markov environment, the values of the random vector should be maximally correlated with those of .

As seen in the main paper, dependence of on can be explicitly encouraged by boosting the mutual information of the coupling. How effective the model is at solving this optimization problem depends mostly on the topology of the system: estimating the AMI for a time-homogeneous, ergodic MDP turns out to be quite easy, as shown in the experiments section. The rate of convergence of the InfoNCE bound to true AMI depends mostly on the inverse spectral gap of the transition kernel. More generally, if we think of as inducing a transition graph with the following rule:


then the spectral gap is related to the Cheeger constant of . For instance, if is allowed to transition from any state to any other state, the corresponding graph would be fully connected, while having a deterministic transition would make the graph a maximal directed pseudoforest. In summary, if one has access to the spectral gap of an MC, they can roughly estimate the connectedness of the system: the sparser the graph, the easier the search problem.

The spectrum of is trivial to estimate in the tabular case; for the large-dimensional case, one can use the variational power method [Wen et al., 2020, VPM] by applying Hotelling’s deflation:

  1. Using samples from , find principal eigenfunction via VPM;

  2. The corresponding eigenvalue should be close to $1$ and found as ;

  3. Use rejection sampling on the same dataset as in step 1 to form a new dataset s.t. ;

  4. Run VPM on the dataset from step 3 to find and repeat for smaller eigenvalues.

In our experiments, we managed to exactly estimate the spectral gap for tabular Markov chains, a task that is more complicated for large-dimensional inputs (e.g. pixels) since the transition model is typically parameterized by a neural network.
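As a tabular analogue of the four steps above, the following minimal sketch replaces VPM with plain power iteration and applies Hotelling's deflation to a small transition matrix. The two-state matrix `P`, the helper names, and the symmetry assumption are illustrative choices rather than details from the paper. Symmetry makes the eigenvectors orthogonal, so subtracting `lam * v v^T` removes exactly the leading eigen-component.

```python
import math


def matvec(M, v):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def power_iteration(M, n_iter=500):
    """Return (eigenvalue, unit eigenvector) of largest modulus for a symmetric M."""
    v = [1.0 / (i + 1) for i in range(len(M))]  # generic, non-symmetric start vector
    for _ in range(n_iter):
        w = matvec(M, v)
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    lam = sum(x * y for x, y in zip(v, matvec(M, v)))  # Rayleigh quotient
    return lam, v


def deflate(M, lam, v):
    """Hotelling's deflation: subtract the rank-one component lam * v v^T."""
    n = len(M)
    return [[M[i][j] - lam * v[i] * v[j] for j in range(n)] for i in range(n)]


# Hypothetical symmetric two-state chain; its eigenvalues are 1 and 0.8.
P = [[0.9, 0.1],
     [0.1, 0.9]]

lam1, v1 = power_iteration(P)                     # principal eigenvalue, close to 1
lam2, v2 = power_iteration(deflate(P, lam1, v1))  # next eigenvalue after deflation
spectral_gap = 1.0 - abs(lam2)                    # absolute spectral gap, close to 0.2
```

For non-symmetric kernels the left and right eigenvectors differ, so the deflation step would need both; the sketch does not cover that case.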

8.3 Code snippet for DIM objective scores

The following snippet yields pointwise (i.e. not contracted) scores given a batch of data.

    import torch
    import torch.nn.functional as F

    def temporal_DIM_scores(reference, positive, clip_val=20):
        """
        reference: n_batch × n_rkhs × n_locs
        positive:  n_batch × n_rkhs × n_locs
        """
        reference = reference.permute(2, 0, 1)  # n_locs × n_batch × n_rkhs
        positive = positive.permute(2, 1, 0)    # n_locs × n_rkhs × n_batch
        # Dot products between every reference/positive pair at each location.
        pairs = torch.matmul(reference, positive)  # n_locs × n_batch × n_batch
        pairs = pairs / reference.shape[2] ** 0.5  # scale by sqrt(n_rkhs)
        pairs = clip_val * torch.tanh((1. / clip_val) * pairs)  # soft clipping
        shape = pairs.shape
        scores = F.log_softmax(pairs, 2)  # n_locs × n_batch × n_batch
        # Keep only the diagonal, i.e. the true (reference, positive) pairs.
        mask = torch.eye(shape[2], device=pairs.device)
        mask = mask.unsqueeze(0).repeat(shape[0], 1, 1)  # n_locs × n_batch × n_batch
        scores = scores * mask  # n_locs × n_batch × n_batch
        return scores

To obtain a scalar out of this batch, sum over the third dimension and then average over the first two.
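To make the contraction concrete, here is a framework-free sketch of what it computes, written with nested lists instead of tensors. The function names and the toy score matrix are hypothetical. With the PyTorch snippet above, the same contraction would presumably be `scores.sum(2).mean()`.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of scores."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def contract_scores(pairs):
    """pairs: n_locs × n_batch × n_batch nested lists of raw similarity scores.

    The diagonal mask followed by a sum over the third dimension keeps only the
    log-softmax score of the positive pair (j == i); we then average over
    locations and batch entries.
    """
    total, count = 0.0, 0
    for loc in pairs:
        for i, row in enumerate(loc):
            total += log_softmax(row)[i]
            count += 1
    return total / count


# One location, batch of two, uninformative scores: every pair looks alike.
pairs = [[[0.0, 0.0],
          [0.0, 0.0]]]
scalar = contract_scores(pairs)    # equals -log(2)
nce_bound = math.log(2) + scalar   # InfoNCE bound log(n_batch) + scalar, here 0
```

Negating `scalar` gives a loss to minimize. Adding the log of the batch size gives the usual InfoNCE lower bound on mutual information, which is zero when the scores carry no information.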

8.4 Proof of Proposition 1

For a given timestep and a pair , let be the true ratio at time and let be the true ratio when the Markov chain has reached its stationary state.

Let be a function we train using the infoNCE loss on samples drawn from the random variables and (for a fixed ).

The functions , and all follow the form .

Proof 1

Let us consider fixed , and . First, since , we have

Or in other terms: . Now, we have:

By assumption on , we know that , which concludes the proof.

8.5 Experiment details

All experiments involving RGB inputs (Ising, Ms.PacMan and Procgen) were run with the settings shown in Table 2.

Name Description Value
Exploration at 0.1
Exploration at 0.01
Exploration decay
LR Learning rate
Discount factor
Clip grad Gradient clip norm
N-step-return N-step return
Frame stack Number of stacked frames (Ising and Procgen)
Grayscale Grayscale or RGB RGB
Input size State input size (Ising and Ms.PacMan)
Warmup steps 1000
Replay size Size of replay buffer
Target soft update coeff
Clip reward Reward clipping False
Global-global DIM 1 (Ms.PacMan and Procgen)
0 (Ising)
Local-local DIM 0 (Ms.PacMan and Procgen)
1 (Ising)
Local-global DIM 0
Global-local DIM 0
DIM lookahead constant 1 (Ising and Ms.PacMan)
1 and 5 (Procgen)
Table 2: Experiments’ parameters

The global DIM heads consist of a standard fully-connected network with a single hidden layer of 512 units, ReLU activations, and a skip-connection from the input to the output layer. The action is converted to a one-hot vector and then encoded using a 64-unit layer, after which it is concatenated with the state and passed to the global DIM head.
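The description of the global head can be sketched as a small PyTorch module. This is only one possible reading: the class name, the `state_dim` and `out_dim` arguments, and the choice of a linear map for the skip-connection (needed whenever input and output widths differ) are assumptions, not details confirmed by the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalDIMHead(nn.Module):
    """Hypothetical global DIM head: one hidden layer of 512 ReLU units plus a skip path."""

    def __init__(self, state_dim, n_actions, out_dim, hidden_dim=512, action_dim=64):
        super().__init__()
        self.n_actions = n_actions
        self.action_encoder = nn.Linear(n_actions, action_dim)  # 64-unit action encoding
        in_dim = state_dim + action_dim
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        # Assumption: the skip-connection is a bias-free linear projection.
        self.skip = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, state, action):
        one_hot = F.one_hot(action, num_classes=self.n_actions).float()
        x = torch.cat([state, self.action_encoder(one_hot)], dim=-1)
        return self.mlp(x) + self.skip(x)


head = GlobalDIMHead(state_dim=32, n_actions=5, out_dim=16)
out = head(torch.randn(4, 32), torch.tensor([0, 1, 2, 4]))  # shape: 4 × 16
```

The toy sizes (32-dimensional state, 5 actions, 16 output features) are placeholders chosen only to exercise the module.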

The local DIM heads consist of a single-hidden-layer network made of convolutions. The action is tiled to match the shape of the convolutions, encoded using a convolution, and concatenated along the feature dimension with the state, after which it is passed to the local DIM head.

In the case of the Ising model, there is no decision component and hence no concatenation is required.

8.5.1 AMI of a biased random walk

We see from the formulation of the mutual information objective that it inherently depends on the ratio of . Recall that, for a 1-d random walk on integers , the stationary distribution is a function of and can be found using the recursion . It has the form


for .

The pointwise mutual information between states and is therefore the random variable


with expectation equal to the average mutual information which we can find by maximizing, among others, the InfoNCE bound. We can then compute the AMI as a function of


which is shown in Figure 2c.
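A rough numerical sketch of this computation is given below. Because the recursion and closed form are not reproduced here, the sketch makes its own assumptions: a walk on `n_states` integers that steps right with probability `p` and left with probability `1 - p`, and that stays in place when a move is blocked at either end. Under those assumptions, detailed balance gives a geometric stationary distribution, and the AMI between consecutive states is the expectation of the pointwise mutual information at stationarity. All function names are hypothetical.

```python
import math


def stationary_distribution(p, n_states):
    """pi_i ∝ (p / (1 - p))**i, from detailed balance pi_i * p = pi_{i+1} * (1 - p)."""
    r = p / (1.0 - p)
    weights = [r ** i for i in range(n_states)]
    z = sum(weights)
    return [w / z for w in weights]


def transition_matrix(p, n_states):
    """Step right w.p. p, left w.p. 1 - p; blocked moves at the ends stay put (assumption)."""
    P = [[0.0] * n_states for _ in range(n_states)]
    for i in range(n_states):
        P[i][min(i + 1, n_states - 1)] += p
        P[i][max(i - 1, 0)] += 1.0 - p
    return P


def average_mutual_information(p, n_states):
    """AMI (in nats) between consecutive states: E[log P(s'|s) / pi(s')] at stationarity."""
    pi = stationary_distribution(p, n_states)
    P = transition_matrix(p, n_states)
    return sum(
        pi[i] * P[i][j] * math.log(P[i][j] / pi[j])
        for i in range(n_states)
        for j in range(n_states)
        if P[i][j] > 0.0
    )


# AMI as a function of the bias p, for a hypothetical 10-state walk.
ami_curve = {p: average_mutual_information(p, n_states=10) for p in (0.1, 0.3, 0.5, 0.7, 0.9)}
```

Under these assumptions the curve is symmetric around `p = 0.5`, since reflecting the state space swaps `p` and `1 - p`. Other boundary conditions would change the stationary distribution and therefore the curve.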

The figures were obtained by training the global DIM objective

on samples from the chain for 1,000 epochs with learning rate


8.6 Experiments with convolutional networks

8.6.1 Ising model

We start by generating an rectangular lattice which is filled with Rademacher random variables ; that is, taking or with some probability . For any , the joint distribution factors into the product of marginals .

At every timestep, we uniformly sample a random index tuple and evolve the set of nodes according to an Ising model with temperature , while the remaining nodes continue to independently take the values with equal probability. If one examines any subset of nodes outside of , then the information conserved across timesteps would be close to 0, due to observations being independent in time.

However, examining a subset of at timestep allows models based on mutual information maximization to predict the configuration of the system at , since this region has high mutual information across time due to the ratio being directly proportional to the temperature parameter .
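The data-generating process described above can be sketched as follows. Several details are assumptions made for illustration only: the window is sampled once and kept fixed, spins inside it are updated with single-site Metropolis moves (coupling constant 1, free boundary), and the lattice size, window size and temperature are arbitrary. Nodes outside the window are redrawn as independent ±1 noise at every step.

```python
import math
import random


def rademacher_lattice(n, rng):
    """n × n lattice of independent ±1 values with equal probability."""
    return [[rng.choice((-1, 1)) for _ in range(n)] for _ in range(n)]


def step(lattice, top, left, size, temperature, rng, sweeps=1):
    """One timestep: Metropolis updates inside the window, fresh noise elsewhere."""
    n = len(lattice)
    new = rademacher_lattice(n, rng)            # outside the window: independent in time
    for i in range(top, top + size):            # inside: start from the previous state
        for j in range(left, left + size):
            new[i][j] = lattice[i][j]
    for _ in range(sweeps * size * size):
        i = rng.randrange(top, top + size)
        j = rng.randrange(left, left + size)
        # Sum of the nearest neighbours that lie inside the window.
        nb = sum(
            new[a][b]
            for a, b in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1))
            if top <= a < top + size and left <= b < left + size
        )
        delta_e = 2.0 * new[i][j] * nb          # energy change if spin (i, j) is flipped
        if delta_e <= 0 or rng.random() < math.exp(-delta_e / temperature):
            new[i][j] = -new[i][j]
    return new


rng = random.Random(0)
n, size = 16, 6                                  # hypothetical lattice and window sizes
top, left = rng.randrange(n - size + 1), rng.randrange(n - size + 1)
frames = [rademacher_lattice(n, rng)]
for _ in range(5):
    frames.append(step(frames[-1], top, left, size, temperature=1.5, rng=rng))
```

Consecutive frames are then statistically dependent only inside the window, which is the region a temporal mutual-information objective should pick out.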

To obtain the figure, we trained local DIM on sample snapshots of the Ising model as grayscale images for 10 epochs. The local DIM scores were obtained by feeding a snapshot of the Ising model at ; showing it snapshots from a later timestep would have made the task much easier, since there would be a clear difference in granularity between the random pattern and the Ising model.

8.6.2 Ms.PacMan

In PacMan, the agent, represented by a yellow square, must collect food pellets while avoiding four harmful ghosts. When the agent collects one of the boosts, it becomes invincible for 10 steps, allowing it to destroy the enemies without dying. In turn, the ghosts alternate between three behaviours: 1) when the agent is not within line-of-sight, wander randomly; 2) when the agent is visible and does not have a boost, follow it; and 3) when the agent is visible and has a boost, avoid it. The switch between these three modes happens stochastically and quasi-independently for all four ghosts. Since the food and boost pellets are fixed at the beginning of each episode, randomness in the MDP comes from the ghosts as well as the agent’s actions.

The setup for our first experiment in the domain is as follows: with a fixed probability , each of the 4 enemies takes a random action instead of following one of the three movement patterns.

The setup for our second experiment in the domain consists of four levels: in each level, only one of the four ghosts is lethal, while the remaining three behave the same but do not cause damage. The model trains for 5,000 episodes on level 1, then switches to level 2, then level 3 and so forth. This specific environment tests the ability of DIM to quickly figure out which of the four enemies is the lethal one and to ignore the remaining three based on color.

For our study, the state space consisted of RGB images. The inputs to the model were states re-scaled to by stacking 4 consecutive frames, which were then concatenated with actions using an embedding layer.

8.6.3 Procgen

The data augmentation steps performed on and fed to the DIM loss consisted of a random crop ( of the original’s size) with color jitter with parameters . Although the data augmentation is helpful on some tasks (typically fast-paced, requiring a lot of camera movements), it has shown detrimental effects on others. Below is a list of games on which data augmentation was beneficial: bigfish, bossfight, chaser, coinrun, jumper, leaper and ninja.

The parameter k, which specifies how far into the future the model should make its predictions, worked best when set to 5 on the games: bigfish, chaser, climber, fruitbot, jumper, miner, maze and plunder. For the remaining games, setting k = 1 yielded better performance.

| Game | DRIML (no data aug) | DRIML (data aug) |
|---|---|---|
| bigfish | 1.344 ± 0.21 | 1.786 ± 0.52 |
| bossfight | 0.482 ± 0.04 | 0.671 ± 0.03 |
| caveflyer | 8.774 ± 0.41 | 10.131 ± 0.3 |
| chaser | 0.218 ± 0.02 | 0.277 ± 0.02 |
| climber | 2.232 ± 0.1 | 1.872 ± 0.15 |
| coinrun | 18.092 ± 2.63 | 25.69 ± 2.34 |
| dodgeball | 1.27 ± 0.07 | 1.129 ± 0.1 |
| fruitbot | 5.302 ± 0.8 | 5.26 ± 0.46 |
| heist | 1.295 ± 0.03 | 1.298 ± 0.07 |
| jumper | 4.239 ± 0.59 | 10.903 ± 2.94 |
| leaper | 5.089 ± 0.12 | 6.089 ± 0.31 |
| maze | 1.283 ± 0.23 | 1.183 ± 0.44 |
| miner | 0.137 ± 0.01 | 0.116 ± 0.02 |
| ninja | 7.678 ± 0.58 | 9.291 ± 0.35 |
| plunder | 3.264 ± 0.23 | 2.819 ± 0.1 |
| starpilot | 4.56 ± 0.2 | 3.062 ± 0.22 |

Table 3: Ablation of the impact of data augmentation on DRIML’s training performance (50M training frames).

| Game | DRIML (k=1) | DRIML (k=5) |
|---|---|---|
| bigfish | 1.768 ± 1.43 | 1.786 ± 0.52 |
| bossfight | 0.636 ± 0.08 | 0.671 ± 0.03 |
| caveflyer | 10.131 ± 0.3 | 7.596 ± 0.23 |
| chaser | 0.218 ± 0.02 | 0.277 ± 0.02 |
| climber | 1.89 ± 0.18 | 2.232 ± 0.1 |
| coinrun | 25.69 ± 2.34 | 22.403 ± 1.89 |
| dodgeball | 1.207 ± 0.01 | 1.27 ± 0.07 |
| fruitbot | 3.182 ± 1.34 | 5.302 ± 0.8 |
| heist | 1.298 ± 0.07 | 1.165 ± 0.08 |
| jumper | 9.729 ± 0.65 | 10.903 ± 2.94 |
| leaper | 6.089 ± 0.31 | 4.219 ± 0.4 |
| maze | 0.996 ± 0.05 | 1.283 ± 0.23 |
| miner | 0.118 ± 0.01 | 0.137 ± 0.01 |
| ninja | 9.291 ± 0.35 | 7.339 ± 0.37 |
| plunder | 3.034 ± 0.17 | 3.264 ± 0.23 |
| starpilot | 4.56 ± 0.2 | 3.33 ± 0.42 |

Table 4: Ablation of the impact of predictive timestep in NCE objective (i.e. k) on DRIML’s training performance (50M training frames).

Tables 3 and 4 provide ablations with respect to the two parameters of DRIML: whether or not to perform data augmentation on all input states, and which k to use for positive samples (we only experimented with k = 1 and k = 5).


The baselines were implemented on top of our existing architecture and, for models which use contrastive objectives, used exactly the same networks for measuring similarity (i.e. one residual block for CURL and CPC). CURL was implemented based on the authors’ code included in their paper and that of MoCo, with EMA on the target network as well as data augmentation (random crops and color jittering) on for randomly sampled .

The No Action baseline was tuned on the same budget as DRIML, over k = 1 and k = 5, with and without data augmentation. Best results are reported in the main paper.

| Game | No action (k=1) | No action (k=5) |
|---|---|---|
| bigfish | 1.193 ± 0.04 | 1.33 ± 0.12 |
| bossfight | 0.466 ± 0.07 | 0.472 ± 0.01 |
| caveflyer | 8.263 ± 0.26 | 5.925 ± 0.18 |
| chaser | 0.224 ± 0.01 | 0.229 ± 0.02 |
| climber | 1.359 ± 0.13 | 1.574 ± 0.01 |
| coinrun | 13.146 ± 1.21 | 9.632 ± 2.8 |
| dodgeball | 1.221 ± 0.04 | 1.213 ± 0.09 |
| fruitbot | 0.714 ± 0.31 | 5.425 ± 1.33 |
| heist | 1.042 ± 0.02 | 0.861 ± 0.07 |
| jumper | 2.966 ± 0.1 | 4.314 ± 0.64 |
| leaper | 5.403 ± 0.09 | 3.521 ± 0.3 |
| maze | 0.984 ± 0.13 | 1.438 ± 0.26 |
| miner | 0.11 ± 0.01 | 0.116 ± 0.01 |
| ninja | 6.437 ± 0.22 | 5.9 ± 0.38 |
| plunder | 2.67 ± 0.08 | 3.2 ± 0.05 |
| starpilot | 3.699 ± 0.3 | 2.951 ± 0.31 |

Table 5: Ablation of the impact of predictive timestep in NCE objective (i.e. k) on the no action model’s training performance (50M training frames).