Data-Efficient Reinforcement Learning with Momentum Predictive Representations

While deep reinforcement learning excels at solving tasks where large amounts of data can be collected through virtually unlimited interaction with the environment, learning from limited interaction remains a key challenge. We posit that an agent can learn more efficiently if we augment reward maximization with self-supervised objectives based on structure in its visual input and sequential interaction with the environment. Our method, Momentum Predictive Representations (MPR), trains an agent to predict its own latent state representations multiple steps into the future. We compute target representations for future states using an encoder which is an exponential moving average of the agent's parameters, and we make predictions using a learned transition model. On its own, this future prediction objective outperforms prior methods for sample-efficient deep RL from pixels. We further improve performance by adding data augmentation to the future prediction loss, which forces the agent's representations to be consistent across multiple views of an observation. Our full self-supervised objective, which combines future prediction and data augmentation, achieves a median human-normalized score of 0.444 on Atari in a setting limited to 100K steps of environment interaction, which is a 66% relative improvement over the previous state of the art. Moreover, even in this limited data regime, MPR exceeds expert human scores on 6 out of 26 games.



1 Introduction

Deep Reinforcement Learning (deep RL, François-Lavet et al., 2018) has proven to be an indispensable tool for training successful agents on difficult sequential decision-making problems (Bellemare et al., 2013; Tassa et al., 2018). The success of deep RL is particularly noteworthy in highly complex strategic games such as StarCraft (Vinyals et al., 2019) and Dota 2 (OpenAI et al., 2019), where deep RL agents now surpass expert human performance in some scenarios.

Deep RL involves training agents based on large neural networks using large amounts of data (Sutton, 2019), a trend evident across both model-based (Schrittwieser et al., 2019) and model-free (Badia et al., 2020) learning. The sample complexity of such state-of-the-art agents is often extremely high: MuZero (Schrittwieser et al., 2019) and Agent57 (Badia et al., 2020) use 10-50 years of experience per Atari game, and OpenAI Five (OpenAI et al., 2019) uses 45,000 years of experience to accomplish its remarkable performance. This is impractical for many applications: unlike easily simulated environments such as video games, collecting interaction data for many real-world tasks is costly, making improved data efficiency a prerequisite for successful use of deep RL in these settings (Dulac-Arnold et al., 2019).

Figure 1: Median and mean human-normalized scores of different methods across 26 games in the Atari 100k benchmark (Kaiser et al., 2019), averaged over 5 random seeds. Each method is allowed access to only 100k environment steps, or 400k frames, per game. (*) indicates that the method uses data augmentation. MPR achieves a state-of-the-art result on both mean and median human-normalized scores. Note that even without data augmentation, MPR still outperforms prior methods that use data augmentation.

Meanwhile, new self-supervised representation learning methods have significantly improved data efficiency when learning new vision and language tasks, particularly in low-data regimes or semi-supervised learning (Xie et al., 2019; Hénaff et al., 2019; Chen et al., 2020b). Self-supervised methods improve data efficiency by leveraging a nearly limitless supply of training signal from tasks generated on the fly, based on "views" drawn from the natural structure of the data (e.g., image patches, data augmentation or temporal proximity; see Doersch et al., 2015; Oord et al., 2018; Hjelm et al., 2019; Tian et al., 2019; Bachman et al., 2019; He et al., 2019; Chen et al., 2020a).

Motivated by successes in semi-supervised and self-supervised learning (Tarvainen and Valpola, 2017; Xie et al., 2019; Grill et al., 2020), we train better state representations for RL by forcing representations to be temporally predictive and consistent when subject to data augmentation. Specifically, we extend a strong model-free agent by adding a dynamics model which predicts future latent representations provided by a parameter-wise exponential moving average of the agent itself. We also add data augmentation to the future prediction task, which enforces consistency across different views of each observation. Unlike some methods (Kaiser et al., 2019; Hafner et al., 2019), our dynamics model operates entirely in the latent space and does not rely on reconstructing raw states.

We evaluate our method, which we call Momentum Predictive Representations (MPR), on 26 games in the Atari 100k benchmark (Kaiser et al., 2019), where agents are allowed only 100k steps of environment interaction (producing 400k frames of input) per game, which roughly corresponds to two hours of real-time experience. Notably, the human experts in Mnih et al. (2015) and Van Hasselt et al. (2016) were given the same amount of time to learn these games, so a budget of 100k steps permits a reasonable comparison in terms of data efficiency.

In our experiments, we augment a modified version of Data-Efficient Rainbow (DER) (van Hasselt et al., 2019) with the MPR loss, and evaluate versions of MPR with and without data augmentation. We find that each version is superior to controlled baselines. When coupled with data augmentation, MPR achieves a median human-normalized score of 0.444, a state-of-the-art result on this benchmark that outperforms prior methods by a significant margin. Notably, MPR also outperforms human expert scores on 6 out of 26 games while using roughly the same amount of in-game experience.

2 Method

We consider reinforcement learning (RL) in the standard Markov Decision Process (MDP) setting where an agent interacts with its environment in episodes, each consisting of sequences of observations, actions and rewards. We use $s_t$, $a_t$ and $r_t$ to denote the observation, the action taken by the agent and the reward received at timestep $t$. We seek to train an agent whose expected cumulative reward in each episode is maximized. To do this, we combine a strong model-free RL algorithm, Rainbow (Hessel et al., 2018), with Momentum Predictive Representations as an auxiliary loss to improve sample efficiency. We now describe our overall approach in detail.

2.1 Deep Q-Learning

We focus on the Arcade Learning Environment (Bellemare et al., 2013), a challenging setting where the agent takes discrete actions while receiving purely visual, pixel-based observations. A prominent method for solving Atari, Deep Q Networks (Mnih et al., 2015), trains a neural network $Q_\theta$ to approximate the agent's current Q-function (policy evaluation) while updating the agent's policy greedily with respect to this Q-function (policy improvement). This involves minimizing the error between predictions from $Q_\theta$ and a target value estimated by $Q_\xi$, an earlier version of the network, where $\gamma$ is the discount factor:

$$\mathcal{L}^{\text{DQN}}_\theta = \left(Q_\theta(s_t, a_t) - \left(r_t + \gamma \max_{a} Q_\xi(s_{t+1}, a)\right)\right)^2$$


Various improvements have been made over the original DQN: Distributional RL (Bellemare et al., 2017) models the full distribution of future reward rather than just the mean, Dueling DQN (Wang et al., 2016) decouples the value of a state from the advantage of taking a given action in that state, and Double DQN (Van Hasselt et al., 2016) modifies the Q-learning update to avoid overestimation due to the max operation, among many others. Rainbow (Hessel et al., 2018) consolidates these improvements into a single combined algorithm and has been adapted to work well in data-limited regimes (van Hasselt et al., 2019).
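To make the Q-learning target concrete, the sketch below computes an n-step bootstrapped target and the squared TD error for a single transition. It is a simplified sketch under stated assumptions: plain Python floats stand in for network outputs, `n_step_target` and `squared_td_error` are hypothetical helper names, and Rainbow itself uses a distributional loss rather than the squared error shown here. With a single reward it reduces to the one-step DQN target; the implementation described in section 2.5 uses 10-step returns.

```python
def n_step_target(rewards, dones, bootstrap_q, gamma=0.99):
    """n-step TD target: r_t + g*r_{t+1} + ... + g^n * max_a Q_xi(s_{t+n}, a).

    rewards: the n rewards r_t, ..., r_{t+n-1}
    dones: n booleans; dones[i] is True if the episode ended after rewards[i]
    bootstrap_q: target-network values Q_xi(s_{t+n}, a) for every action a
    """
    target = 0.0
    discount = 1.0
    for reward, done in zip(rewards, dones):
        target += discount * reward
        discount *= gamma
        if done:
            # The episode ended inside the window, so there is no bootstrap term.
            return target
    return target + discount * max(bootstrap_q)


def squared_td_error(q_online, target):
    """Plain DQN loss for one transition: (Q_theta(s_t, a_t) - target)^2."""
    return (q_online - target) ** 2


# Usage: a 3-step target with gamma = 0.5 and no terminal state in the window.
target = n_step_target([1.0, 0.0, 1.0], [False, False, False], [0.5, 2.0], gamma=0.5)
loss = squared_td_error(2.0, target)
```

Cutting the sum at the first terminal flag is the usual way to avoid bootstrapping across an episode boundary.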

2.2 Momentum Predictive Representations

Figure 2: An illustration of the full MPR method. Representations from the online encoder are used in the reinforcement learning task and for prediction of future representations from the momentum encoder via the transition model. The momentum encoder and projection head are defined as an exponential moving average of their online counterparts and are not updated via gradient descent. For brevity, we illustrate only the $k$-th step of future prediction, but in practice we compute the loss over all steps from $1$ to $K$. Note: our implementation for this paper includes $g_o$ in the Q-learning head.

For our auxiliary loss, we start with the intuition that encouraging state representations to be predictive of future observations given future actions should improve the data efficiency of RL algorithms. Let $(s_{t:t+K}, a_{t:t+K})$ denote a sequence of $K+1$ previously experienced states and actions sampled from a replay buffer, where $K$ is the maximum number of steps into the future which we want to predict. Our method has four main components, which we describe below:


  • Online and Target networks: We use an online encoder $f_o$ to transform observed states $s_t$ into representations $z_t \triangleq f_o(s_t)$. We use these representations in an objective that encourages them to be predictive of future observations up to some fixed temporal offset $K$, given a sequence of $K$ actions to perform. When using data augmentation, we augment each observation independently. Rather than predicting representations produced by the online encoder, we follow prior work (Tarvainen and Valpola, 2017; Grill et al., 2020) by computing target representations for future states using a momentum encoder $f_m$, whose parameters are an exponential moving average (EMA) of the online encoder parameters. Denoting the parameters of $f_o$ as $\theta_o$, those of $f_m$ as $\theta_m$, and the momentum coefficient as $\tau$, the update rule for $\theta_m$ is:

    $$\theta_m \leftarrow \tau \theta_m + (1 - \tau)\, \theta_o$$

    Note that this means the momentum encoder is not updated via gradient descent.

  • Transition Model: For the prediction objective, we generate a sequence of $K$ predictions $\hat{z}_{t+1:t+K}$ of future state representations $\tilde{z}_{t+1:t+K}$ using an action-conditioned transition model $h$. We compute $\hat{z}_{t+k}$ iteratively: $\hat{z}_{t+k+1} \triangleq h(\hat{z}_{t+k}, a_{t+k})$, starting from $\hat{z}_t \triangleq z_t$. We compute $\tilde{z}_{t+k}$ by applying the momentum encoder to the observed future states $s_{t+k}$: $\tilde{z}_{t+k} \triangleq f_m(s_{t+k})$. The transition model and prediction loss operate in the latent space, thus avoiding pixel-based reconstruction objectives. We describe the architecture of $h$ in section 2.3.

  • Projection Heads: We use online and momentum projection heads $g_o$ and $g_m$ (Chen et al., 2020a) to project online and target representations to a smaller latent space, and apply an additional prediction head $q$ (Grill et al., 2020) to the online projections to predict the target projections:

    $$\hat{y}_{t+k} \triangleq q(g_o(\hat{z}_{t+k})), \qquad \tilde{y}_{t+k} \triangleq g_m(\tilde{z}_{t+k})$$

    The momentum projection head parameters are given by an EMA of the online projection head parameters, using the same update rule as for the momentum encoder.

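  • Prediction Loss: We compute the future prediction loss for MPR by summing over cosine similarities (footnote 1) between the predicted and observed representations at timesteps $t+k$ for $1 \le k \le K$:

    $$\mathcal{L}^{\text{MPR}}_\theta(s_{t:t+K}, a_{t:t+K}) = -\sum_{k=1}^{K} \left(\frac{\tilde{y}_{t+k}}{\|\tilde{y}_{t+k}\|_2}\right)^{\top} \left(\frac{\hat{y}_{t+k}}{\|\hat{y}_{t+k}\|_2}\right)$$

    where $\tilde{y}_{t+k}$ and $\hat{y}_{t+k}$ are computed from $(s_{t:t+K}, a_{t:t+K})$ as we just described.

    Footnote 1: Minimizing the negative cosine similarity is equivalent to minimizing the "normalized L2" loss used in BYOL (Grill et al., 2020), since for unit vectors $u$ and $v$, $\|u - v\|_2^2 = 2 - 2u^{\top}v$.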
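The latent rollout and prediction loss described in the list above fit in a few lines. This is a minimal, framework-free sketch, assuming networks are arbitrary Python callables and representations are lists of floats; `mpr_loss`, `cosine_similarity` and `normalized_l2` are hypothetical names, and gradient handling (the momentum branch receives no gradients) is outside its scope. The last helper illustrates the relationship between cosine similarity and BYOL's normalized L2 loss mentioned in the footnote: for nonzero vectors the latter equals $2 - 2\cos$.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two nonzero vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def normalized_l2(u, v):
    """BYOL-style squared L2 distance between L2-normalised vectors."""
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return sum((a / nu - b / nv) ** 2 for a, b in zip(u, v))


def mpr_loss(states, actions, f_o, f_m, h, g_o, g_m, q):
    """Negative sum over k = 1..K of cos(y_hat_{t+k}, y_tilde_{t+k}).

    states: s_t, ..., s_{t+K} (K + 1 entries); actions: a_t, ..., a_{t+K-1}.
    """
    K = len(states) - 1
    z_hat = f_o(states[0])  # z_hat_t = z_t = f_o(s_t)
    loss = 0.0
    for k in range(1, K + 1):
        z_hat = h(z_hat, actions[k - 1])  # z_hat_{t+k} = h(z_hat_{t+k-1}, a_{t+k-1})
        y_hat = q(g_o(z_hat))  # online prediction of the target projection
        y_tilde = g_m(f_m(states[k]))  # target projection of the observed s_{t+k}
        loss -= cosine_similarity(y_hat, y_tilde)
    return loss


# Toy usage: identity encoders/heads and a transition model that adds the action
# to the first coordinate, so every prediction matches its target exactly.
def identity(x):
    return x


def shift_first(z, a):
    return [z[0] + a, z[1]]


toy_states = [[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]
loss = mpr_loss(toy_states, [1.0, 1.0], identity, identity, shift_first,
                identity, identity, identity)  # -2.0 up to rounding
```

Because each term is bounded in $[-1, 1]$, the loss for perfect predictions is $-K$, which gives an easy sanity check when debugging a real implementation.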

We call our method Momentum Predictive Representations (MPR), reflecting the predictive nature of the objective and the use of an exponential moving average target network similar to Tarvainen and Valpola (2017) and He et al. (2019). During training, we combine the MPR loss with the Q-learning loss for Rainbow. The MPR loss affects $f_o$, $g_o$, $q$ and $h$. The Q-learning loss affects $f_o$ and the Q-learning head, which contains additional layers specific to Rainbow. Denoting the Q-learning loss from Rainbow as $\mathcal{L}^{\text{RL}}_\theta$ and a weighting coefficient as $\lambda$, our full optimization objective is $\mathcal{L}^{\text{total}}_\theta = \mathcal{L}^{\text{RL}}_\theta + \lambda \mathcal{L}^{\text{MPR}}_\theta$.

Compared to prior work (Kostrikov et al., 2020; Laskin et al., 2020), our method can leverage data augmentations more effectively by encouraging consistency between representations of different augmented views. We empirically verify this via a controlled comparison to DrQ (see section 5). MPR can nonetheless be used in contexts where data augmentation is unavailable or counterproductive. Compared to related work on contrastive representation learning, MPR does not use negative samples, which may require careful design of contrastive tasks, large batch sizes (Chen et al., 2020a), or the use of a buffer to emulate large batch sizes (He et al., 2019).

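Algorithm 1 Momentum Predictive Representations

Denote parameters of online encoder $f_o$ and projection $g_o$ as $\theta_o$
Denote parameters of target encoder $f_m$ and projection $g_m$ as $\theta_m$
Denote parameters of transition model $h$, predictor $q$ and Q-learning head as $\phi$
Denote the maximum prediction depth as $K$, the MPR loss weight as $\lambda$ and the momentum coefficient as $\tau$
initialize replay buffer $B$
while Training do
       collect experience $(s, a, r, s')$ with $(\theta_o, \phi)$ and add to buffer $B$
       sample minibatch $(s_{t:t+K}, a_{t:t+K}, r_{t:t+K}) \sim B$
       if augmentation then
              $s_{t:t+K} \leftarrow \text{augment}(s_{t:t+K})$
       end if
       $z_t \leftarrow f_o(s_t)$; $\hat{z}_t \leftarrow z_t$; $\ell \leftarrow 0$   // online representations
       for k in (1, …, K) do
              $\hat{z}_{t+k} \leftarrow h(\hat{z}_{t+k-1}, a_{t+k-1})$   // latent states via transition model
              $\tilde{z}_{t+k} \leftarrow f_m(s_{t+k})$   // target representations
              $\hat{y}_{t+k} \leftarrow q(g_o(\hat{z}_{t+k}))$, $\tilde{y}_{t+k} \leftarrow g_m(\tilde{z}_{t+k})$   // projections
              $\ell \leftarrow \ell - \frac{\tilde{y}_{t+k}^{\top} \hat{y}_{t+k}}{\|\tilde{y}_{t+k}\|_2 \, \|\hat{y}_{t+k}\|_2}$   // MPR loss at step $k$
       end for
       $\ell \leftarrow \lambda \ell + \mathcal{L}^{\text{RL}}$   // add RL loss for the batch, computed with $z_t$
       $\theta_o, \phi \leftarrow \text{optimize}((\theta_o, \phi), \ell)$   // update online parameters
       $\theta_m \leftarrow \tau \theta_m + (1 - \tau)\, \theta_o$   // update target parameters
end while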
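The final line of Algorithm 1, the momentum (EMA) update of the target parameters, is sketched below. This is an illustrative sketch, assuming parameters are stored as a dictionary of float lists; `ema_update` is a hypothetical helper, and in a PyTorch implementation the same arithmetic would typically run under `torch.no_grad()` so that the target never receives gradients.

```python
def ema_update(target_params, online_params, tau):
    """Apply theta_m <- tau * theta_m + (1 - tau) * theta_o in place.

    Both arguments map parameter names to lists of floats. The target
    parameters are changed only here, never by the optimizer.
    """
    for name, online_values in online_params.items():
        target_values = target_params[name]
        for i, value in enumerate(online_values):
            target_values[i] = tau * target_values[i] + (1.0 - tau) * value
    return target_params


# Usage: repeated updates pull the target towards a fixed online network;
# after n steps the remaining gap is a fraction tau**n of the initial gap.
online = {"w": [1.0, -2.0]}
target = {"w": [0.0, 0.0]}
for _ in range(100):
    ema_update(target, online, tau=0.9)
```

A larger $\tau$ makes the target change more slowly, which is what gives the prediction task a stable, slowly moving objective.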

2.3 Transition Model Architecture

For the transition model $h$, we apply a convolutional network directly to the spatial output of the convolutional encoder $f_o$. The network comprises two 64-channel convolutional layers, with batch normalization (Ioffe and Szegedy, 2015) after the first convolution and ReLU nonlinearities after each convolution. We append a one-hot vector representing the action taken to each location in the input to the first convolutional layer, similar to Schrittwieser et al. (2019). We use a maximum prediction depth of $K = 5$, and we truncate calculation of the MPR loss at episode boundaries to avoid encoding environment reset dynamics into the model.
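Two details of the transition model, feeding the action to every spatial location and truncating the loss at episode boundaries, are sketched below on plain nested lists. Both helpers (`append_action_planes`, `prediction_mask`) are hypothetical; the convolutions themselves are omitted, and the sketch assumes the replay buffer can flag when $s_{t+j+1}$ is the first observation of a new episode.

```python
def append_action_planes(features, action, num_actions):
    """Concatenate a one-hot action to every spatial location of a feature map.

    features: nested lists of shape [C][H][W]; returns shape [C + num_actions][H][W],
    where the plane for `action` is all ones and the other action planes are zeros.
    """
    height, width = len(features[0]), len(features[0][0])
    action_planes = [
        [[1.0 if a == action else 0.0 for _ in range(width)] for _ in range(height)]
        for a in range(num_actions)
    ]
    return list(features) + action_planes


def prediction_mask(crosses_reset):
    """Per-step weights for k = 1..K that truncate the MPR loss at episode boundaries.

    crosses_reset[j] is True if s_{t+j+1} is the first observation of a new
    episode. Once a reset is crossed, that step and all later steps get weight 0.
    """
    mask, alive = [], 1.0
    for reset in crosses_reset:
        if reset:
            alive = 0.0
        mask.append(alive)
    return mask


# Usage: with K = 5 and a reset before s_{t+2}, only the k = 1 term is kept.
weights = prediction_mask([False, True, False, False, False])
```

Multiplying each per-step cosine term by these weights keeps predictions from spanning a reset, so the model is never asked to predict the first frame of an unrelated episode.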

2.4 Data Augmentation

When using augmentation, we use the same set of image augmentations as in DrQ (Kostrikov et al., 2020), consisting of small random shifts and color jitter. We found it important to normalize activations to lie in $[0, 1]$ at the output of the convolutional encoder and transition model when using augmentation, as in Schrittwieser et al. (2019). We use Kornia (Riba et al., 2020) for efficient GPU-based data augmentations.
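The random-shift augmentation and the $[0, 1]$ normalization mentioned above are sketched below for a single-channel image stored as nested lists. This is a hedged illustration, not the Kornia-based code: `random_shift` and `min_max_normalize` are hypothetical names, the 4-pixel replicate padding is an assumption modelled on DrQ's Atari setting, color jitter is not shown, and the normalization is applied per sample in the MuZero style.

```python
import random


def random_shift(image, pad=4, rng=random):
    """DrQ-style random shift: replicate-pad an [H][W] image by `pad` pixels
    on every side, then crop a random H x W window."""
    height, width = len(image), len(image[0])
    padded = []
    for r in range(-pad, height + pad):
        row = image[min(max(r, 0), height - 1)]  # clamp = replicate padding
        padded.append([row[min(max(c, 0), width - 1)] for c in range(-pad, width + pad)])
    dy = rng.randint(0, 2 * pad)  # randint is inclusive on both ends
    dx = rng.randint(0, 2 * pad)
    return [row[dx:dx + width] for row in padded[dy:dy + height]]


def min_max_normalize(values):
    """Rescale one sample's flattened activations to lie in [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]  # constant input: avoid division by zero
    return [(v - lo) / (hi - lo) for v in values]


# Usage: shift a tiny image with a seeded generator, then normalize activations.
shifted = random_shift([[1, 2, 3], [4, 5, 6], [7, 8, 9]], pad=1, rng=random.Random(0))
scaled = min_max_normalize([2.0, 4.0, 3.0])
```

Replicate padding keeps border pixels plausible, so a shift moves the content by a few pixels without introducing artificial black borders.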

When not using augmentation, we find that MPR performs better when dropout is applied at each layer in the online and momentum encoders. This is consistent with Laine and Aila (2016) and Tarvainen and Valpola (2017), who find that adding noise inside the network is important when not using image-specific augmentation, as proposed by Bachman et al. (2014). We found that applying dropout in this way was not helpful when using image-specific augmentation.

2.5 Implementation Details

For our Atari experiments, we largely follow van Hasselt et al. (2019) for DQN hyperparameters, with four exceptions. We follow DrQ (Kostrikov et al., 2020) by using the 3-layer convolutional encoder from Mnih et al. (2015), using 10-step returns instead of 20-step returns for Q-learning, and not using a separate DQN target network when using augmentation (note that this makes Double DQN updates (Van Hasselt et al., 2016) identical to standard DQN updates). We also perform two gradient steps per environment step instead of one. We show results for this configuration with and without augmentation in Table 4, and confirm that these changes are not themselves responsible for our performance. We reuse the first layer of the DQN MLP head as the MPR projection head $g_o$. When using dueling DQN (Wang et al., 2016), $g_o$ concatenates the outputs of the first layers of the value and advantage heads. When these layers are noisy (Fortunato et al., 2018), $g_o$ does not use the noisy parameters. Finally, we parameterize the predictor $q$ as a linear layer. We list the complete hyperparameters in Table 3, and we set the MPR loss weight $\lambda$ based on early experiments.

Our implementation is based on rlpyt (Stooke and Abbeel, 2019) and PyTorch (Paszke et al., 2019).

3 Results

We test MPR on the sample-efficient Atari setting introduced by Kaiser et al. (2019) and van Hasselt et al. (2019). In this task, only 100,000 environment steps of training data are available – equivalent to 400,000 frames, or just under two hours – compared to the typical standard of 50,000,000 environment steps, or roughly 39 days of experience. When used without image augmentation, MPR demonstrates scores comparable to the previous best result from Kostrikov et al. (2020). When combined with image augmentation, MPR achieves a median human-normalized score of 0.444, which is a new state-of-the-art result on this task. MPR achieves super-human performance on six games in this data-limited setting: Boxing, Krull, Kangaroo, Road Runner, James Bond and Crazy Climber, compared to a maximum of two for any previous methods, and achieves scores higher than DrQ (the previous state-of-the-art method) on 21 out of 26 games. See Table 1 for a full list of scores. For consistency with previous works, we report human and random scores from Wang et al. (2016).
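The time conversions in the paragraph above follow from two facts: each agent step spans 4 emulator frames (100k steps correspond to 400k frames) and the Atari emulator runs at 60 frames per second. The short sketch below reproduces the arithmetic; `steps_to_hours` is a hypothetical helper name.

```python
FRAME_SKIP = 4  # emulator frames per agent step (action repeat)
FRAMES_PER_SECOND = 60  # Atari emulator frame rate


def steps_to_hours(steps):
    """Convert agent steps to hours of real-time gameplay."""
    return steps * FRAME_SKIP / FRAMES_PER_SECOND / 3600


budget_hours = steps_to_hours(100_000)  # about 1.85 hours
standard_days = steps_to_hours(50_000_000) / 24  # about 38.6 days
```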

Game | Random | Human | SimPLe | DER | OTRainbow | CURL | DrQ | MPR (no Aug) | MPR
Alien | 227.8 | 7127.7 | 616.9 | 739.9 | 824.7 | 558.2 | 771.2 | 801.9 | 919.6
Amidar | 5.8 | 1719.5 | 88.0 | 188.6 | 82.8 | 142.1 | 102.8 | 177.3 | 159.6
Assault | 222.4 | 742.0 | 527.2 | 431.2 | 351.9 | 600.6 | 452.4 | 661.1 | 699.5
Asterix | 210.0 | 8503.3 | 1128.3 | 470.8 | 628.5 | 734.5 | 603.5 | 619.2 | 983.5
Bank Heist | 14.2 | 753.1 | 34.2 | 51.0 | 182.1 | 131.6 | 168.9 | 313.1 | 370.1
Battle Zone | 2360.0 | 37187.5 | 5184.4 | 10124.6 | 4060.6 | 14870.0 | 12954.0 | 11510.0 | 14472.0
Boxing | 0.1 | 12.1 | 9.1 | 0.2 | 2.5 | 1.2 | 6.0 | 16.1 | 30.5
Breakout | 1.7 | 30.5 | 16.4 | 1.9 | 9.8 | 4.9 | 16.1 | 14.2 | 15.6
Chopper Command | 811.0 | 7387.8 | 1246.9 | 861.8 | 1033.3 | 1058.5 | 780.3 | 739.8 | 1130.0
Crazy Climber | 10780.5 | 35829.4 | 62583.6 | 16185.3 | 21327.8 | 12146.5 | 20516.5 | 73294.8 | 36659.8
Demon Attack | 152.1 | 1971.0 | 208.1 | 508.0 | 711.8 | 817.6 | 1113.4 | 464.9 | 636.4
Freeway | 0.0 | 29.6 | 20.3 | 27.9 | 25.0 | 26.7 | 9.8 | 23.1 | 24.6
Frostbite | 65.2 | 4334.7 | 254.7 | 866.8 | 231.6 | 1181.3 | 331.1 | 752.8 | 1811.0
Gopher | 257.6 | 2412.5 | 771.0 | 349.5 | 778.0 | 669.3 | 636.3 | 341.5 | 593.4
Hero | 1027.0 | 30826.4 | 2656.6 | 6857.0 | 6458.8 | 6279.3 | 3736.3 | 4235.4 | 5602.8
Jamesbond | 29.0 | 302.8 | 125.3 | 301.6 | 112.3 | 471.0 | 236.0 | 371.5 | 378.7
Kangaroo | 52.0 | 3035.0 | 323.1 | 779.3 | 605.4 | 872.5 | 940.6 | 1452.0 | 3876.0
Krull | 1598.0 | 2665.5 | 4539.9 | 2851.5 | 3277.9 | 4229.6 | 4018.1 | 3067.6 | 3810.3
Kung Fu Master | 258.5 | 22736.3 | 17257.2 | 14346.1 | 5722.2 | 14307.8 | 9111.0 | 14535.2 | 14135.8
Ms Pacman | 307.3 | 6951.6 | 1480.0 | 1204.1 | 941.9 | 1465.5 | 960.5 | 1339.8 | 1205.3
Pong | -20.7 | 14.6 | 12.8 | -19.3 | 1.3 | -16.5 | -8.5 | -10.5 | -3.8
Private Eye | 24.9 | 69571.3 | 58.3 | 97.8 | 100.0 | 218.4 | -13.6 | 40.0 | 20.2
Qbert | 163.9 | 13455.0 | 1288.8 | 1152.9 | 509.3 | 1042.4 | 854.4 | 594.2 | 791.8
Road Runner | 11.5 | 7845.0 | 5640.6 | 9600.0 | 2696.7 | 5661.0 | 8895.1 | 10805.2 | 13062.4
Seaquest | 68.4 | 42054.7 | 683.3 | 354.1 | 286.9 | 384.5 | 301.2 | 361.8 | 603.8
Up N Down | 533.4 | 11693.2 | 3350.3 | 2877.4 | 2847.6 | 2955.2 | 3180.8 | 5167.9 | 7307.8
Median Human-Norm'd | 0.000 | 1.000 | 0.144 | 0.161 | 0.204 | 0.175 | 0.268 | 0.276 | 0.444
Mean Human-Norm'd | 0.000 | 1.000 | 0.443 | 0.285 | 0.264 | 0.381 | 0.357 | 0.510 | 0.621
Median DQN-Norm'd | 0.000 | 0.994 | 0.118 | 0.142 | 0.103 | 0.142 | 0.131 | 0.205 | 0.324
Mean DQN-Norm'd | 0.000 | 23.382 | 0.232 | 0.239 | 0.197 | 0.325 | 0.171 | 0.290 | 0.391
# Superhuman | 0 | N/A | 2 | 2 | 1 | 2 | 2 | 5 | 6
Table 1: Mean episodic returns on the 26 Atari games considered by Kaiser et al. (2019) after 100k environment steps. The results are recorded at the end of training and averaged over 5 random seeds. MPR outperforms prior methods on all aggregate metrics, and exceeds expert human performance on 6 out of 26 games while using a similar amount of experience.

3.1 Evaluation

We evaluate the performance of different methods by computing the average episodic return at the end of training. It is common to normalize scores with respect to expert human scores to account for the different scales of scores in each game. The human-normalized score (HNS) of an agent on a game is calculated as

$$\text{HNS} = \frac{\text{agent score} - \text{random score}}{\text{human score} - \text{random score}}$$

and then aggregated across the 26 games by taking their mean or median. It is common to report the median human-normalized performance, as the median is less susceptible to outliers. However, we find that in some games human scores are so high that differences between methods are washed out when normalizing scores. This makes it difficult for scores in these games, such as Alien, Asterix, and Seaquest, to influence aggregate metrics. To account for this, we also report scores normalized by a DQN (Mnih et al., 2015) agent trained over 50 million steps, using DQN scores reported in Wang et al. (2016).
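The human-normalized score and its median aggregate can be computed directly from Table 1. The sketch below uses three MPR rows as an example; `human_normalized` is a hypothetical helper name, and the scores are copied from the table.

```python
from statistics import median


def human_normalized(score, random_score, human_score):
    """(agent - random) / (human - random); 1.0 means human-level performance."""
    return (score - random_score) / (human_score - random_score)


# (MPR, Random, Human) scores for three games, copied from Table 1.
scores = {
    "Boxing": (30.5, 0.1, 12.1),
    "Kangaroo": (3876.0, 52.0, 3035.0),
    "Pong": (-3.8, -20.7, 14.6),
}
hns = {game: human_normalized(*row) for game, row in scores.items()}
median_hns = median(hns.values())  # the Kangaroo value, about 1.28
```

Applied to all 26 MPR rows of Table 1, the same computation gives a median of about 0.444 (the mean of the 13th and 14th values, Frostbite and Pong), matching the reported aggregate.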

Additionally, we note that the standard evaluation protocol of evaluating over only 500,000 frames per game is problematic, as the quantity we are trying to measure is expected return over episodes. Because some episodes are very long (up to 108,000 frames), this protocol may collect as few as four complete episodes. Moreover, better policies tend to have longer episodes in many games, so stronger algorithms suffer even greater variance in this estimate of expected episodic return. As variance is already a concern in deep RL (see Henderson et al., 2018), we propose evaluating over 100 episodes irrespective of their length, but for comparability with prior work we report results using the standard metric.
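The "as few as four complete episodes" figure follows from running maximum-length episodes back to back within the 500,000-frame budget. The sketch below counts complete episodes for a given list of episode lengths; `complete_episodes` is a hypothetical helper.

```python
def complete_episodes(frame_budget, episode_lengths):
    """Count episodes that finish within frame_budget when run back to back."""
    used = 0
    count = 0
    for length in episode_lengths:
        if used + length > frame_budget:
            break  # this episode would be cut off by the budget
        used += length
        count += 1
    return count


# Usage: maximum-length episodes of 108,000 frames under a 500,000-frame budget.
worst_case = complete_episodes(500_000, [108_000] * 10)  # 4 episodes (432,000 frames)
```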

4 Related Work

4.1 Data-Efficient RL:

A number of works have sought to improve sample efficiency in deep RL. SimPLe (Kaiser et al., 2019) learns an explicit pixel-level transition model for Atari to generate simulated training data, achieving strong results on several games in the 100k-step setting. However, both van Hasselt et al. (2019) and Kielak (2020) demonstrate that variants of Rainbow (Hessel et al., 2018) tuned for sample efficiency can achieve comparable or superior performance.

In the context of continuous control, several works propose to leverage a latent-space model trained on reconstruction loss to improve sample efficiency (Hafner et al., 2019; Lee et al., 2019; Hafner et al., 2020). Most recently, DrQ (Kostrikov et al., 2020) and RAD (Laskin et al., 2020) have found that applying modest image augmentation can substantially improve sample efficiency in reinforcement learning, yielding better results than prior model-based methods. Data augmentation has also been found to improve generalization of reinforcement learning methods (Combes et al., 2018; Laskin et al., 2020) in multi-task and transfer settings. We show that data augmentation can be more effectively leveraged in reinforcement learning by forcing representations to be consistent between different augmented views of an observation while also predicting future latent states.

4.2 Representation Learning in RL:

Representation learning has a long history of use in RL – see Lesort et al. (2018). For example, CURL (Srinivas et al., 2020) recently proposed a combination of image augmentation and a contrastive loss to perform representation learning for RL. However, follow-up results from RAD (Laskin et al., 2020) suggest that most of the benefits of CURL come from its use of image augmentation rather than its contrastive loss.

CPC (Oord et al., 2018), CPC|Action (Guo et al., 2018), ST-DIM (Anand et al., 2019) and DRIML (Mazoure et al., 2020) propose to optimize various temporal contrastive losses in reinforcement learning environments. We perform an ablation comparing such temporal contrastive losses to our method in section 5. Kipf et al. (2019) propose to learn object-oriented contrastive representations by training a structured transition model based on a graph neural network.

MPR bears some resemblance to DeepMDP (Gelada et al., 2019), which trains a transition model with an unnormalized L2 loss to predict representations of future states, along with a reward prediction objective. However, DeepMDP also uses its online encoder to produce prediction targets rather than employing a target encoder, and is thus prone to representational collapse (sec. C.5 in Gelada et al. (2019)). To mitigate this issue, DeepMDP relies on an additional observation reconstruction objective. In contrast, our model is self-supervised, trained entirely in the latent space, and uses a normalized loss. Our ablations (see section 5) demonstrate that using a momentum target encoder has a large impact on our method, making it another key difference between MPR and DeepMDP.

MPR is also similar to PBL (Guo et al., 2020), which also directly predicts representations of future states. However, PBL uses a separate target network trained via gradient descent, whereas MPR uses a momentum target encoder, which we find to be vital. Moreover, PBL studies multi-task generalization in the asymptotic limits of data, whereas MPR is concerned with single-task performance in low-data regimes, using only a small fraction of the data that PBL uses. Unlike PBL, MPR additionally enforces consistency across augmentations, which empirically provides a large boost in performance.

5 Discussion

We now present several ablation studies to measure the contribution of components in our method.

The importance of the momentum encoder

To illustrate the importance of using a momentum target encoder, we test two variants of our method that omit it. In each experiment, we use the online network to encode both the inputs and targets. In the first ablation, we allow gradients to flow from the target representations into the online encoder, and in the second we place a stop-gradient operation on the target representations to prevent the encoder from learning to make the representations more predictable. We find that both ablations degrade the median human-normalized score relative to our full method. These results are consistent with the findings of Tarvainen and Valpola (2017), who observe that a momentum target encoder produces a more effective learning signal than the online encoder.

Making better use of data augmentation

We examine an ablated version of MPR in which all temporal elements are removed. Similarly to Mean Teachers (Tarvainen and Valpola, 2017) and BYOL (Grill et al., 2020), we provide the online and momentum encoders with different augmented views of each observation, and calculate the MPR loss between their representations of these views. We find that this variant alone outperforms both DrQ and CURL (see Table 2, line 3), suggesting that MPR makes better use of data augmentation than prior methods in reinforcement learning.

Dynamics modeling is key

A key distinction between MPR and other recent approaches leveraging representation learning for reinforcement learning, such as CURL (Srinivas et al., 2020) and DRIML (Mazoure et al., 2020), is our use of an explicit multi-step dynamics model. We test two ablated versions of MPR, one with no dynamics modeling and one that models only a single step of dynamics. Each of these variants has degraded performance compared to five-step MPR, with extended dynamics modeling consistently improving performance (see Table 2).

Comparison with contrastive losses

Although many recent works in representation learning have employed contrastive learning, we find that MPR consistently outperforms both temporal and non-temporal variants of contrastive losses (see Table 5), including CURL (Srinivas et al., 2020).

Method | Augmentation | MPR Loss | Transition Model Steps (K) | Median HNS
MPR (no model) | | | 0 |
MPR (no aug.) | | | 5 |
DER (controlled) | | | N/A |
Table 2: Median human-normalized scores for variants of MPR and previous methods.

6 Future Work

Recent work in both visual (Chen et al., 2020b) and language representation learning (Brown et al., 2020) has suggested that self-supervised models trained on large datasets perform exceedingly well on downstream problems with limited data, often outperforming methods trained using only task-specific data. Future work could similarly exploit large corpora of unlabelled data, perhaps from multiple MDPs or raw videos, to further improve the performance of RL methods in low-data regimes. As the MPR objective is self-supervised, it could be applied directly in such settings.

Another interesting direction is to use the transition model learned by MPR for planning. MuZero (Schrittwieser et al., 2019) has demonstrated that planning with a model supervised via reward and value prediction can work extremely well given sufficient (massive) amounts of data. It remains unclear whether such models can work well in low-data regimes, and whether augmenting such models with self-supervised objectives such as MPR can improve their data efficiency.

7 Conclusion

In this paper we introduced Momentum Predictive Representations (MPR), a self-supervised representation learning algorithm designed to improve the data efficiency of deep reinforcement learning agents. MPR learns representations that are both temporally predictive and consistent across different views of environment observations, by directly predicting representations of future states produced by a momentum target encoder. MPR achieves state-of-the-art performance on the Atari 100k benchmark, demonstrating significant improvements over prior work. Our experiments show that MPR is highly robust, and is able to outperform the previous state of the art when either data augmentation or temporal prediction is disabled. We identify important directions for future work, and hope continued research at the intersection of self-supervised learning and reinforcement learning leads to algorithms which rival the efficiency and robustness of humans.


We are grateful for the collaborative research environment provided by Mila and Microsoft Research. We would also like to acknowledge Hitachi for providing funding support for this project. We thank Nitarshan Rajkumar and Evan Racah for providing feedback on an earlier draft; Denis Yarats and Aravind Srinivas for answering questions about DrQ and CURL; Michal Valko and Sherjil Ozair for discussions about BYOL; and Phong Nguyen for helpful discussions. Finally, we thank Compute Canada and Microsoft Research for providing computational resources used in this project.


  • A. Anand, E. Racah, S. Ozair, Y. Bengio, M. Côté, and R. D. Hjelm (2019) Unsupervised state representation learning in atari. In NeurIPS, Cited by: §4.2.
  • P. Bachman, O. al Sharif, and D. Precup (2014) Learning with pseudo-ensembles. Advances in Neural Information Processing Systems (NIPS). Cited by: §2.4.
  • P. Bachman, R. D. Hjelm, and W. Buchwalter (2019) Learning representations by maximizing mutual information across views. In NeurIPS, Cited by: §1.
  • A. P. Badia, B. Piot, S. Kapturowski, P. Sprechmann, A. Vitvitskyi, D. Guo, and C. Blundell (2020) Agent57: outperforming the atari human benchmark. arXiv preprint arXiv:2003.13350. Cited by: §1.
  • M. G. Bellemare, W. Dabney, and R. Munos (2017) A distributional perspective on reinforcement learning. ICML. Cited by: §2.1.
  • M. G. Bellemare, Y. Naddaf, J. Veness, and M. Bowling (2013) The arcade learning environment: an evaluation platform for general agents.

    Journal of Artificial Intelligence Research

    Cited by: §1, §2.1.
  • T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. (2020) Language models are few-shot learners. arXiv preprint arXiv:2005.14165.
  • T. Chen, S. Kornblith, M. Norouzi, and G. Hinton (2020a) A simple framework for contrastive learning of visual representations. ICML.
  • T. Chen, S. Kornblith, K. Swersky, M. Norouzi, and G. Hinton (2020b) Big self-supervised models are strong semi-supervised learners. arXiv preprint arXiv:2006.10029.
  • R. T. d. Combes, P. Bachman, and H. van Seijen (2018) Learning invariances for policy generalization. arXiv preprint arXiv:1809.02591.
  • C. Doersch, A. Gupta, and A. A. Efros (2015) Unsupervised visual representation learning by context prediction. In ICCV.
  • G. Dulac-Arnold, D. Mankowitz, and T. Hester (2019) Challenges of real-world reinforcement learning. arXiv preprint arXiv:1904.12901.
  • M. Fortunato, M. G. Azar, B. Piot, J. Menick, M. Hessel, I. Osband, A. Graves, V. Mnih, R. Munos, D. Hassabis, O. Pietquin, C. Blundell, and S. Legg (2018) Noisy networks for exploration. In ICLR.
  • V. François-Lavet, P. Henderson, R. Islam, M. G. Bellemare, and J. Pineau (2018) An introduction to deep reinforcement learning. arXiv preprint arXiv:1811.12560.
  • C. Gelada, S. Kumar, J. Buckman, O. Nachum, and M. G. Bellemare (2019) DeepMDP: learning continuous latent space models for representation learning. ICML.
  • J. Grill, F. Strub, F. Altché, C. Tallec, P. H. Richemond, E. Buchatskaya, C. Doersch, B. A. Pires, Z. D. Guo, M. G. Azar, et al. (2020) Bootstrap your own latent: a new approach to self-supervised learning. arXiv preprint arXiv:2006.07733.
  • D. Guo, B. A. Pires, B. Piot, J. Grill, F. Altché, R. Munos, and M. G. Azar (2020) Bootstrap latent-predictive representations for multitask reinforcement learning. arXiv preprint arXiv:2004.14646.
  • Z. D. Guo, M. G. Azar, B. Piot, B. A. Pires, and R. Munos (2018) Neural predictive belief representations. ICML.
  • D. Hafner, T. Lillicrap, J. Ba, and M. Norouzi (2020) Dream to control: learning behaviors by latent imagination. ICLR.
  • D. Hafner, T. Lillicrap, I. Fischer, R. Villegas, D. Ha, H. Lee, and J. Davidson (2019) Learning latent dynamics for planning from pixels. In ICML.
  • K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick (2019) Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722.
  • O. J. Hénaff, A. Srinivas, J. De Fauw, A. Razavi, C. Doersch, S. Eslami, and A. v. d. Oord (2019) Data-efficient image recognition with contrastive predictive coding. arXiv preprint arXiv:1905.09272.
  • P. Henderson, R. Islam, P. Bachman, J. Pineau, D. Precup, and D. Meger (2018) Deep reinforcement learning that matters. In Thirty-Second AAAI Conference on Artificial Intelligence.
  • M. Hessel, J. Modayil, H. van Hasselt, T. Schaul, G. Ostrovski, W. Dabney, D. Horgan, B. Piot, M. G. Azar, and D. Silver (2018) Rainbow: combining improvements in deep reinforcement learning. In AAAI.
  • R. D. Hjelm, A. Fedorov, S. Lavoie-Marchildon, K. Grewal, P. Bachman, A. Trischler, and Y. Bengio (2019) Learning deep representations by mutual information estimation and maximization. ICLR.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In ICML.
  • Ł. Kaiser, M. Babaeizadeh, P. Miłos, B. Osiński, R. H. Campbell, K. Czechowski, D. Erhan, C. Finn, P. Kozakowski, S. Levine, et al. (2019) Model-based reinforcement learning for Atari. In ICLR.
  • K. P. Kielak (2020) Do recent advancements in model-based deep reinforcement learning really improve data efficiency?
  • T. Kipf, E. van der Pol, and M. Welling (2019) Contrastive learning of structured world models. arXiv preprint arXiv:1911.12247.
  • I. Kostrikov, D. Yarats, and R. Fergus (2020) Image augmentation is all you need: regularizing deep reinforcement learning from pixels. arXiv preprint arXiv:2004.13649.
  • S. Laine and T. Aila (2016) Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242.
  • M. Laskin, K. Lee, A. Stooke, L. Pinto, P. Abbeel, and A. Srinivas (2020) Reinforcement learning with augmented data. arXiv preprint arXiv:2004.14990.
  • A. X. Lee, A. Nagabandi, P. Abbeel, and S. Levine (2019) Stochastic latent actor-critic: deep reinforcement learning with a latent variable model. arXiv preprint arXiv:1907.00953.
  • T. Lesort, N. Díaz-Rodríguez, J. Goudou, and D. Filliat (2018) State representation learning for control: an overview. Neural Networks 108.
  • B. Mazoure, R. T. d. Combes, T. Doan, P. Bachman, and R. D. Hjelm (2020) Deep reinforcement and infomax learning. arXiv preprint arXiv:2006.07217.
  • V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, et al. (2015) Human-level control through deep reinforcement learning. Nature 518 (7540).
  • A. v. d. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748.
  • OpenAI, C. Berner, G. Brockman, B. Chan, V. Cheung, P. Dębiak, C. Dennison, D. Farhi, Q. Fischer, S. Hashme, C. Hesse, R. Józefowicz, S. Gray, C. Olsson, J. Pachocki, M. Petrov, H. P. de Oliveira Pinto, J. Raiman, T. Salimans, J. Schlatter, J. Schneider, S. Sidor, I. Sutskever, J. Tang, F. Wolski, and S. Zhang (2019) Dota 2 with large scale deep reinforcement learning. arXiv preprint arXiv:1912.06680.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, et al. (2019) PyTorch: an imperative style, high-performance deep learning library. In NeurIPS.
  • E. Riba, D. Mishkin, D. Ponsa, E. Rublee, and G. Bradski (2020) Kornia: an open source differentiable computer vision library for PyTorch. In The IEEE Winter Conference on Applications of Computer Vision.
  • J. Schrittwieser, I. Antonoglou, T. Hubert, K. Simonyan, L. Sifre, S. Schmitt, A. Guez, E. Lockhart, D. Hassabis, T. Graepel, et al. (2019) Mastering Atari, Go, chess and shogi by planning with a learned model. arXiv preprint arXiv:1911.08265.
  • A. Srinivas, M. Laskin, and P. Abbeel (2020) CURL: contrastive unsupervised representations for reinforcement learning. arXiv preprint arXiv:2004.04136.
  • A. Stooke and P. Abbeel (2019) rlpyt: a research code base for deep reinforcement learning in PyTorch. arXiv preprint arXiv:1909.01500.
  • R. Sutton (2019) The bitter lesson. Incomplete Ideas (blog).
  • A. Tarvainen and H. Valpola (2017) Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. In NeurIPS.
  • Y. Tassa, Y. Doron, A. Muldal, T. Erez, Y. Li, D. d. L. Casas, D. Budden, A. Abdolmaleki, J. Merel, A. Lefrancq, et al. (2018) DeepMind Control Suite. arXiv preprint arXiv:1801.00690.
  • Y. Tian, D. Krishnan, and P. Isola (2019) Contrastive multiview coding. arXiv preprint arXiv:1906.05849.
  • H. van Hasselt, A. Guez, and D. Silver (2016) Deep reinforcement learning with double Q-learning. In Thirtieth AAAI Conference on Artificial Intelligence.
  • H. P. van Hasselt, M. Hessel, and J. Aslanides (2019) When to use parametric models in reinforcement learning? In NeurIPS.
  • O. Vinyals, I. Babuschkin, W. M. Czarnecki, M. Mathieu, A. Dudzik, J. Chung, D. H. Choi, R. Powell, T. Ewalds, P. Georgiev, et al. (2019) Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature.
  • Z. Wang, T. Schaul, M. Hessel, H. van Hasselt, M. Lanctot, and N. de Freitas (2016) Dueling network architectures for deep reinforcement learning. In ICML.
  • Q. Xie, Z. Dai, E. Hovy, M. Luong, and Q. V. Le (2019) Unsupervised data augmentation for consistency training. arXiv preprint arXiv:1904.12848.

Appendix A Atari Details

Parameter | Setting (for both variations)
Gray-scaling | True
Observation down-sampling | 84x84
Frames stacked | 4
Action repetitions | 4
Reward clipping | [-1, 1]
Terminal on loss of life | True
Max frames per episode | 108K
Update | Distributional Q
Dueling | True
Support of Q-distribution | 51
Discount factor | 0.99
Minibatch size | 32
Optimizer | Adam
Optimizer: learning rate | 0.0001
Optimizer: β1 | 0.9
Optimizer: β2 | 0.999
Optimizer: ε | 0.00015
Max gradient norm | 10
Priority exponent | 0.5
Priority correction | 0.4 → 1
Exploration | Noisy nets
Noisy nets parameter | 0.1
Training steps | 100K
Evaluation trajectories | 100
Min replay size for sampling | 2000
Replay period | every 1 step
Updates per step | 2
Multi-step return length | 10
Q network: channels | 32, 64, 64
Q network: filter size | 8x8, 4x4, 3x3
Q network: stride | 4, 2, 1
Q network: hidden units | 256
Non-linearity | ReLU
MPR loss coefficient | 2
Momentum coefficient | 0.99
Prediction depth | 5

Parameter | With augmentation | Without augmentation
Data augmentation | Random shifts & intensity | None
Target network: update period | 1* | 1000
Dropout | 0 | 0.5
  * Similar to Kostrikov et al. (2020), we found target networks to be unnecessary when using data augmentation.

Table 3: Hyperparameters for MPR with and without augmentation
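The random-shift augmentation listed in Table 3 can be sketched as pad-then-crop, in the spirit of Kostrikov et al. (2020). This is a minimal NumPy sketch, not the paper's code: edge-replication padding, one shift shared by all stacked frames of a sample, and the default pad width of 4 pixels are assumptions for illustration, and `random_shift` is a hypothetical name. The intensity augmentation is omitted.

```python
import numpy as np


def random_shift(frames, pad=4, rng=None):
    """Randomly shift each sample in a batch of stacked frames by up to `pad` pixels.

    frames: (B, C, H, W) array, e.g. (B, 4, 84, 84) for four stacked 84x84 frames.
    Each sample is edge-padded by `pad` pixels on every side and a random H x W
    window is cropped back out, so the output has the same shape and dtype as the
    input. All C stacked frames of a sample receive the same shift.
    """
    rng = np.random.default_rng() if rng is None else rng
    B, C, H, W = frames.shape
    padded = np.pad(frames, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    out = np.empty_like(frames)
    for i in range(B):
        dy, dx = rng.integers(0, 2 * pad + 1, size=2)  # crop offset in [0, 2 * pad]
        out[i] = padded[i, :, dy:dy + H, dx:dx + W]
    return out


rng = np.random.default_rng(0)
# Minibatch of 32 as in Table 3; pixel values are random placeholders.
batch = rng.integers(0, 256, size=(32, 4, 84, 84)).astype(np.uint8)
augmented = random_shift(batch, pad=4, rng=rng)
```

Because the crop offset ranges over [0, 2 * pad], each image moves by between -pad and +pad pixels along each axis, and the replicated border avoids introducing black bars at the edges.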

A.1 Controlled baselines

To ensure that the minor hyperparameter changes we make to the DER baseline are not solely responsible for our improved performance, we run controlled experiments in which the baselines use the same hyperparameters and the same random seeds as MPR. Our DQN without augmentation is slightly stronger than Data-Efficient Rainbow (DER) in median human-normalized score (0.191 vs. 0.161), while with augmentation enabled our results are comparable to those of DrQ (a median of 0.268 for both). None of these methods, however, comes close to the performance of MPR.

Variant | Human-normalized median | Human-normalized mean | DQN-normalized median | DQN-normalized mean
Base DQN | 0.191 | 0.275 | 0.166 | 0.307
DER | 0.161 | 0.285 | 0.142 | 0.239
Base w/ augmentation | 0.268 | 0.448 | 0.220 | 0.273
DrQ | 0.268 | 0.357 | 0.131 | 0.171
Table 4: Scores on the 26 Atari games under consideration for our base DQN without MPR.
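The normalized scores in Table 4 presumably follow the usual convention of rescaling each game's raw score so that random play maps to 0 and the reference score (human or DQN) maps to 1, then taking the median and mean across games. A minimal standard-library sketch follows; the function names, game names, and scores are hypothetical and are not results from this paper.

```python
from statistics import mean, median


def normalized_score(agent, rand, reference):
    """Rescale a raw score so random play maps to 0 and the reference score to 1."""
    return (agent - rand) / (reference - rand)


def aggregate(agent_scores, random_scores, reference_scores):
    """Median and mean normalized score across games (dicts keyed by game name)."""
    per_game = [
        normalized_score(agent_scores[g], random_scores[g], reference_scores[g])
        for g in agent_scores
    ]
    return median(per_game), mean(per_game)


# Hypothetical scores for three made-up games (not real Atari results).
agent = {"game_a": 60.0, "game_b": 5.0, "game_c": 37.5}
random_play = {"game_a": 10.0, "game_b": 0.0, "game_c": 0.0}
human = {"game_a": 110.0, "game_b": 20.0, "game_c": 50.0}

hns_median, hns_mean = aggregate(agent, random_play, human)  # → (0.5, 0.5)
```

Passing DQN scores as `reference_scores` instead of human scores gives the DQN-normalized columns. The median is reported alongside the mean because a few games with very large normalized scores can dominate the mean.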

Appendix B Comparison with a contrastive loss

To compare MPR with alternative methods drawn from contrastive learning, we examine several contrastive losses based on InfoNCE (Oord et al., 2018):

  • A contrastive loss based solely on different views of the same state, similar to CURL (Srinivas et al., 2020).

  • A temporal contrastive loss with augmentation in which targets are drawn one step in the future, equivalent to single-step CPC (Oord et al., 2018).

  • A temporal contrastive loss with an explicit dynamics model, similar to CPC|Action (Guo et al., 2018).

In each case we use a momentum target encoder with the same update constant as in MPR, 0.01 (i.e., a momentum coefficient of 0.99). To make the comparison as fair as possible, we use the same augmentation (random shifts and intensity) and the same DQN hyperparameters as in MPR. As in MPR, we calculate contrastive losses using the output of the first layer of the Q-head MLP, with a bilinear classifier (as in Oord et al., 2018). Following Chen et al. (2020a), we normalize representations in the contrastive loss and use a temperature of 0.1. We present results in Table 5.
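The InfoNCE variants described above can be sketched as a cross-entropy over a batch of bilinear scores. This is a minimal NumPy sketch under stated assumptions: treating the other minibatch elements as negatives and the name `info_nce_bilinear` are choices made for this illustration, while the normalization, the bilinear classifier, and the temperature of 0.1 follow the text.

```python
import numpy as np


def l2_normalize(x, eps=1e-8):
    """Scale each row of x to (approximately) unit L2 norm."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def info_nce_bilinear(queries, keys, W, temperature=0.1):
    """InfoNCE loss with a bilinear critic score(q, k) = q^T W k / temperature.

    queries: (B, H) online representations
    keys:    (B, H) momentum-target representations; keys[i] is the positive for
             queries[i], and the other B - 1 keys serve as negatives (assumed).
    W:       (H, H) bilinear classifier weights
    """
    q = l2_normalize(queries)
    k = l2_normalize(keys)
    logits = (q @ W @ k.T) / temperature  # (B, B) scores
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy where the correct "class" for row i is column i.
    return -np.mean(np.diag(log_probs))


rng = np.random.default_rng(0)
B, H = 32, 64
queries = rng.normal(size=(B, H))
W = np.eye(H)  # identity critic reduces to cosine similarity
aligned = info_nce_bilinear(queries, queries + 0.01 * rng.normal(size=(B, H)), W)
unrelated = info_nce_bilinear(queries, rng.normal(size=(B, H)), W)
```

Keys that closely match their queries give a loss near zero, while unrelated keys give a loss near log(B); with W set to zero, every score is equal and the loss is exactly log(B).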

Although all of these variants outperform CURL, the previous contrastive result on this task, none of them substantially improves performance over the DQN it uses as a baseline: the two temporal InfoNCE variants each boost performance on 15/26 games, while the non-temporal version actually harms performance on 14/26 games. We consider these results broadly consistent with those of CURL, which reports a relatively small performance boost over its baseline, Data-Efficient Rainbow (van Hasselt et al., 2019).

Method | Augmentation | InfoNCE | Temporal | Transition model | Median HNS
InfoNCE w/ dynamics model
1-step temporal InfoNCE
Non-temporal InfoNCE
Base DQN w/ augmentation
DER (controlled)
Table 5: Scores for four contrastive approaches comparable to MPR across the 26 Atari games under consideration. Results are averaged over five random seeds per game.