Developing agents that can perform complex control tasks from high-dimensional observations such as pixels has become possible by combining the expressive power of deep neural networks with the long-term credit assignment abilities of reinforcement learning algorithms. Notable successes include learning to play a diverse set of video games from raw pixels (Mnih et al., 2015); continuous control tasks such as controlling a simulated car from a dashboard camera (Lillicrap et al., 2015) and subsequent algorithmic developments and applications to agents that successfully navigate mazes and solve complex tasks from first-person camera observations (Jaderberg et al., 2016; Espeholt et al., 2018; Jaderberg et al., 2019); and robots that successfully grasp objects in the real world (Kalashnikov et al., 2018).
However, it has been empirically observed that reinforcement learning from high-dimensional observations such as raw pixels is sample-inefficient (Lake et al., 2017; Kaiser et al., 2019). Moreover, it is widely accepted that learning policies from physical state-based features is significantly more sample-efficient than learning from pixels (Tassa et al., 2018). In principle, however, if the state information is present in the pixel data, then we should be able to learn representations that extract the relevant state information. It may therefore be possible to learn from pixels as fast as from state, given the right representation.
From a practical standpoint, although high rendering speeds in simulated environments enable RL agents to solve complex tasks within reasonable wall-clock time, learning in the real world means that agents are bound by the limitations of physics. Kalashnikov et al. (2018) needed an arm farm of robots collecting large-scale interaction data over several months to develop their robot grasp value functions and policies. The data-efficiency of the whole pipeline thus has significant room for improvement. Similarly, in simulated worlds that are limited by rendering speeds in the absence of GPU accelerators, data-efficiency is crucial for fast experimental turnaround and iteration. Improving the sample-efficiency of reinforcement learning (RL) methods that operate from high-dimensional observations is therefore of paramount importance to RL research both in simulation and in the real world, and allows for faster progress towards the broader goal of developing intelligent autonomous agents.
A number of approaches have been proposed in the literature to address the sample inefficiency of deep RL algorithms. Broadly, they can be classified into two streams of research, though not mutually exclusive: (i) auxiliary tasks on the agent’s sensory observations; (ii) world models that predict the future. While the former class of methods uses auxiliary self-supervision tasks to accelerate the learning progress of model-free RL methods (Jaderberg et al., 2016; Mirowski et al., 2016), the latter class builds explicit predictive models of the world and uses those models either to plan through or to collect fictitious rollouts for model-free methods to learn from (Sutton, 1990; Ha and Schmidhuber, 2018; Kaiser et al., 2019; Schrittwieser et al., 2019).
Our work falls into the first class of methods, which use auxiliary tasks to improve sample efficiency. Our hypothesis is simple: if an agent learns a useful semantic representation from high-dimensional observations, control algorithms built on top of those representations should be significantly more data-efficient. Self-supervised representation learning has seen dramatic progress in the last couple of years, with huge advances in masked language modeling (Devlin et al., 2018) for language and contrastive learning (Hénaff et al., 2019; He et al., 2019a; Chen et al., 2020) for vision. The representations uncovered by these objectives improve the performance of any supervised learning system, especially in scenarios where the amount of labeled data available for the downstream task is low.
We take inspiration from the contrastive pre-training successes in computer vision. However, there are a couple of key differences: (i) there is no giant unlabeled dataset of millions of images available beforehand; instead, the dataset is collected online from the agent’s interactions and changes dynamically with the agent’s experience; (ii) the agent has to perform unsupervised and reinforcement learning simultaneously, as opposed to fine-tuning a pre-trained network for a specific downstream task. These two differences pose a distinct challenge: how can we use contrastive learning to improve agents that learn to control effectively and efficiently from online interactions?
To address this challenge, we propose CURL - Contrastive Unsupervised Representations for Reinforcement Learning. CURL uses a form of contrastive learning that maximizes agreement between augmented versions of the same observation, where each observation is a stack of temporally sequential frames. We show that CURL significantly improves sample-efficiency over prior pixel-based methods by performing contrastive learning simultaneously with an off-policy RL algorithm. CURL coupled with Soft Actor-Critic (SAC) (Haarnoja et al., 2018) achieves 2.8x higher mean performance than the prior state-of-the-art on DMControl environments benchmarked at 100k interaction steps, and matches the performance of state-based SAC on the majority of the 16 environments tested, a first for pixel-based methods. In the Atari setting, also benchmarked at 100k interaction steps, CURL coupled with a data-efficient version of Rainbow DQN (van Hasselt et al., 2019) achieves a 1.6x performance gain over prior methods and matches human efficiency on a few games.
While contrastive learning in aid of model-free RL has been studied in the past by van den Oord et al. (2018) using Contrastive Predictive Coding (CPC), the results were mixed with marginal gains in a few DMLab (Espeholt et al., 2018) environments. CURL is the first model to show substantial data-efficiency gains from using a contrastive self-supervised learning objective for model-free RL agents across a multitude of pixel based continuous and discrete control tasks in DMControl and Atari.
We prioritize designing a simple and easily reproducible pipeline. While the promise of auxiliary tasks and learning world models for RL agents has been demonstrated in prior work, there’s an added layer of complexity when introducing components like modeling the future in a latent space (van den Oord et al., 2018; Ha and Schmidhuber, 2018). CURL is designed to add minimal overhead in terms of architecture and model learning. The contrastive learning objective in CURL operates with the same latent space and architecture typically used for model-free RL and seamlessly integrates with the training pipeline without the need to introduce multiple additional hyperparameters.
Our paper makes the following key contributions: We present CURL, a simple framework that integrates contrastive learning with model-free RL with minimal changes to the architecture and training pipeline. Using 16 complex control tasks from the DeepMind Control (DMControl) suite and 26 Atari games, we empirically show that contrastive learning combined with model-free RL outperforms the leading prior pixel-based methods by 2.8x on DMControl and 1.6x on Atari. CURL is also the first algorithm, across both model-based and model-free methods, that operates purely from pixels and nearly matches the performance and sample-efficiency of a SAC algorithm trained from state-based features on the DMControl suite. Finally, our design is simple and does not require any custom architectural choices or hyperparameters, which is crucial for reproducible end-to-end training. Through these strong empirical results, we demonstrate that a contrastive objective is the preferred self-supervised auxiliary task for achieving sample-efficiency compared to reconstruction-based methods, and that it enables model-free methods to outperform state-of-the-art model-based methods in terms of data-efficiency.
2 Related Work
Self-Supervised Learning: Self-supervised learning aims to learn rich representations of high-dimensional unlabeled data that are useful for a wide variety of tasks. The fields of natural language processing and computer vision have seen dramatic advances in self-supervised methods such as BERT (Devlin et al., 2018) and CPC, MoCo, and SimCLR (Hénaff et al., 2019; He et al., 2019a; Chen et al., 2020).
Contrastive Learning: Contrastive learning is a framework to learn representations that obey similarity constraints in a dataset typically organized into similar and dissimilar pairs. It is often best understood as performing a dictionary lookup task wherein the positives and negatives represent a set of keys with respect to a query (or an anchor). A simple instantiation of contrastive learning is instance discrimination (Wu et al., 2018), wherein a query and key form a positive pair if they are data augmentations of the same instance (for example, an image) and a negative pair otherwise. A key challenge in contrastive learning is the choice of negatives, which can decide the quality of the learned representations. Commonly used contrastive loss functions include InfoNCE (van den Oord et al., 2018), Triplet (Wang and Gupta, 2015), and Siamese (Chopra et al., 2005) losses.
Self-Supervised Learning for RL: Auxiliary tasks such as predicting the future conditioned on past observation(s) and action(s) (Jaderberg et al., 2016; Shelhamer et al., 2016; van den Oord et al., 2018) and predicting the depth image for maze navigation (Mirowski et al., 2016) are a few representative examples of using auxiliary tasks to improve the sample-efficiency of model-free RL algorithms. The future prediction is done either in pixel space (Jaderberg et al., 2016) or in latent space (van den Oord et al., 2018). The sample-efficiency gains from reconstruction-based auxiliary losses have been benchmarked in Jaderberg et al. (2016); Higgins et al. (2017); Yarats et al. (2019). Contrastive learning has also been used to extract reward signals, characterized as distance metrics in latent space, by Sermanet et al. (2018) and Warde-Farley et al. (2018).
World Models for sample-efficiency: While jointly learning an auxiliary unsupervised task with model-free RL is one way to improve the sample-efficiency of agents, another line of research tries to learn world models of the environment and use them to sample rollouts and plan. An early instantiation of this principle was put forth by Sutton (1990) in Dyna, where fictitious samples rolled out from a learned world model are used in addition to the agent’s experience for sample-efficient learning. Planning through a learned world model is another way to improve sample-efficiency. While Jaderberg et al. (2016); van den Oord et al. (2018); Lee et al. (2019) also learn pixel- and latent-space forward models, those models are learned to shape the latent representations; there is no explicit Dyna-style rollout generation or planning. Planning through learned world models has been successfully demonstrated in Ha and Schmidhuber (2018); Hafner et al. (2018, 2019). Kaiser et al. (2019) introduce SimPLe, which implements Dyna with expressive deep neural networks for the world model, and show impressive sample-efficiency on Atari games.
Sample-efficient RL for image-based control: CURL encompasses the areas of self-supervision, contrastive learning, and auxiliary tasks for sample-efficient RL. We benchmark for sample-efficiency on the DMControl suite (Tassa et al., 2018) and Atari Games benchmarks (Bellemare et al., 2013). The DMControl suite has been used widely by Yarats et al. (2019), Hafner et al. (2018), Hafner et al. (2019), and Lee et al. (2019) for benchmarking sample-efficiency of image-based continuous control methods. As for Atari, Kaiser et al. (2019) propose to use the 100k interaction steps benchmark for sample-efficiency, which has been adopted in Kielak (2020); van Hasselt et al. (2019). The Rainbow DQN (Hessel et al., 2017) was originally proposed for maximum sample-efficiency on the Atari benchmark and has recently been adapted into a version known as Data-Efficient Rainbow (van Hasselt et al., 2019) with competitive performance to SimPLe without learning world models. We benchmark extensively against both model-based and model-free algorithms in our experiments. For the DMControl experiments, we compare our method to Dreamer, PlaNet, SLAC, and SAC+AE, whereas for the Atari experiments we compare to SimPLe, Rainbow, OverTrained Rainbow (OTRainbow), and Efficient Rainbow (Eff. Rainbow).
3 Background

CURL is a general framework for combining contrastive learning with RL. In principle, one could use any RL algorithm in the CURL pipeline, be it on-policy or off-policy. We use the widely adopted Soft Actor-Critic (SAC) (Haarnoja et al., 2018) for continuous control benchmarks (DMControl) and Rainbow DQN (Hessel et al., 2017) for discrete control benchmarks (Atari). Below, we review SAC, Rainbow DQN, and contrastive learning.
3.1 Soft Actor Critic
SAC is an off-policy RL algorithm that optimizes a stochastic policy for maximizing the expected trajectory returns. Like other state-of-the-art end-to-end RL algorithms, SAC is effective when solving tasks from state observations but fails to learn efficient policies from pixels. SAC is an actor-critic method that learns a policy $\pi_\psi$ and critics $Q_{\phi_1}$ and $Q_{\phi_2}$. The parameters $\phi_i$ are learned by minimizing the Bellman error:

$$\mathcal{L}(\phi_i, \mathcal{B}) = \mathbb{E}_{t \sim \mathcal{B}} \left[ \left( Q_{\phi_i}(o, a) - \left( r + \gamma (1 - d)\, \mathcal{T} \right) \right)^2 \right] \quad (1)$$

where $t = (o, a, o', r, d)$ is a tuple with observation $o$, action $a$, reward $r$ and done signal $d$, $\mathcal{B}$ is the replay buffer, and $\mathcal{T}$ is the target, defined as:

$$\mathcal{T} = \min_{i=1,2} Q_{\phi_i'}(o', a') - \alpha \log \pi_\psi(a' \mid o') \quad (2)$$

In the target equation (2), $\phi_i'$ denotes the exponential moving average (EMA) of the parameters $\phi_i$. Using the EMA has empirically been shown to improve training stability in off-policy RL algorithms. The parameter $\alpha$ is a positive entropy coefficient that determines the priority of entropy maximization over value function optimization.

While the critic is given by $Q_{\phi_i}$, the actor samples actions from policy $\pi_\psi(\cdot \mid o)$ and is trained by maximizing the expected return of its actions as in:

$$\mathcal{L}(\psi) = \mathbb{E}_{a \sim \pi} \left[ Q^{\pi}(o, a) - \alpha \log \pi_\psi(a \mid o) \right] \quad (3)$$

where actions are sampled stochastically from the policy $a_\psi(o, \xi) \sim \tanh\left( \mu_\psi(o) + \sigma_\psi(o) \odot \xi \right)$, and $\xi \sim \mathcal{N}(0, I)$ is a standard normalized noise vector.
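The target computation and the reparameterized squashed-Gaussian sampling above can be illustrated numerically. The following is a minimal NumPy sketch with toy linear critics and a placeholder policy head; all shapes, the `q_value`/`sample_action` helpers, and the log-probability placeholder are our illustrative assumptions, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "networks": linear critics standing in for Q_phi1, Q_phi2 (shapes illustrative)
obs_dim, act_dim = 4, 2
W_q1 = rng.normal(size=(obs_dim + act_dim,))
W_q2 = rng.normal(size=(obs_dim + act_dim,))

def q_value(W, o, a):
    # Toy critic: linear in the concatenated observation-action vector
    return np.concatenate([o, a]) @ W

def sample_action(o, xi):
    # Reparameterized squashed Gaussian: a = tanh(mu(o) + sigma(o) * xi)
    mu, log_std = o[:act_dim], -1.0 * np.ones(act_dim)  # toy policy heads
    return np.tanh(mu + np.exp(log_std) * xi)

# One transition (o, a, r, o', d) and the SAC target T from eq. (2)
o, a = rng.normal(size=obs_dim), rng.normal(size=act_dim)
r, d, gamma, alpha = 1.0, 0.0, 0.99, 0.1
o_next = rng.normal(size=obs_dim)
xi = rng.normal(size=act_dim)                 # xi ~ N(0, I)
a_next = sample_action(o_next, xi)
log_pi = -0.5 * np.sum(a_next ** 2)           # placeholder for log pi(a'|o')

target = min(q_value(W_q1, o_next, a_next),
             q_value(W_q2, o_next, a_next)) - alpha * log_pi
# Bellman error from eq. (1) for one critic
bellman_error = (q_value(W_q1, o, a) - (r + gamma * (1 - d) * target)) ** 2
print(bellman_error >= 0.0)  # True: a squared error is non-negative
```

Note how the tanh squashing keeps sampled actions bounded in (-1, 1), which is why SAC is a natural fit for the bounded action spaces of DMControl.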
3.2 Rainbow DQN

Rainbow DQN (Hessel et al., 2017) is best summarized as multiple improvements on top of the original Nature DQN (Mnih et al., 2015) applied together. Specifically, the Deep Q-Network (DQN) (Mnih et al., 2015) combines the off-policy algorithm Q-Learning with a convolutional neural network as the function approximator to map raw pixels to action-value functions. Since then, multiple improvements have been proposed, such as Double Q-Learning (Van Hasselt et al., 2016), Dueling Network Architectures (Wang et al., 2015), Prioritized Experience Replay (Schaul et al., 2015), and Noisy Networks (Fortunato et al., 2017). Additionally, distributional reinforcement learning (Bellemare et al., 2017) proposed predicting a distribution over possible value-function bins through the C51 algorithm. Rainbow DQN combines all of the above techniques into a single off-policy algorithm for state-of-the-art sample-efficiency on Atari benchmarks. Rainbow also makes use of multi-step returns (Sutton and Barto, 1998).
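The multi-step return mentioned above bootstraps after n rewards instead of one. A minimal sketch (the helper name and values are ours, for illustration only):

```python
def n_step_return(rewards, bootstrap_value, gamma, n):
    """n-step return: r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * V."""
    ret = bootstrap_value * gamma ** n
    for i in range(n):
        ret += (gamma ** i) * rewards[i]
    return ret

# With gamma=0.9 and three unit rewards before bootstrapping on a value of 10:
# 1 + 0.9 + 0.81 + 0.729 * 10, which is (up to float rounding) 10.0
print(n_step_return([1.0, 1.0, 1.0], 10.0, 0.9, 3))
```

Larger n propagates reward information backwards faster, at the cost of higher variance; Rainbow tunes n as a hyperparameter.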
3.3 Contrastive Learning
Given a query $q$, a positive key $k_+$, and a set of negative keys $\{k_0, k_1, \ldots, k_{K-1}\}$, the InfoNCE objective (van den Oord et al., 2018) with a bilinear similarity scores the query against the keys:

$$\mathcal{L}_q = \log \frac{\exp(q^\top W k_+)}{\exp(q^\top W k_+) + \sum_{i=0}^{K-1} \exp(q^\top W k_i)} \quad (4)$$

The loss (4) can be interpreted as the log-loss of a $K$-way softmax classifier whose label is $k_+$.
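Concretely, optimizing (4) reduces to cross-entropy over bilinear similarity logits, with the positive for each query sitting on the diagonal of the logit matrix. A minimal NumPy sketch with random latents (batch and latent sizes are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
B, C = 8, 16                        # batch size, latent dimension (illustrative)
q = rng.normal(size=(B, C))         # query latents
k = rng.normal(size=(B, C))         # key latents; k[i] is the positive for q[i]
W = rng.normal(size=(C, C))         # learned bilinear parameter matrix

logits = q @ W @ k.T                # [B, B]; entry (i, j) scores q_i against k_j
logits -= logits.max(axis=1, keepdims=True)  # subtract row max for stability

# InfoNCE as cross-entropy: the "label" of row i is column i (its positive key)
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
loss = -np.mean(np.diag(log_probs))
print(loss > 0)  # True: the log-loss of a softmax classifier is positive here
```

Within a batch, every other key serves as a negative for a given query, so no separate negative-mining step is needed.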
4 CURL Implementation
CURL minimally modifies a base RL algorithm by training the contrastive objective as an auxiliary loss during the batch update. In our experiments, we train CURL alongside two model-free RL algorithms — SAC for DMControl experiments and Rainbow DQN (data-efficient version) for Atari experiments. To specify a contrastive learning objective, we need to define (i) the discrimination objective, (ii) the transformation for generating query-key observations, (iii) the embedding procedure for transforming observations into queries and keys, and (iv) the inner product used as a similarity measure between the query-key pairs in the contrastive loss. The exact specification of these aspects largely determines the quality of the learned representations.
We first summarize the CURL architecture, and then cover each architectural choice in detail.
4.1 Architectural Overview
CURL uses instance discrimination with similarities to SimCLR (Chen et al., 2020), MoCo (He et al., 2019a) and CPC (Hénaff et al., 2019). Most deep RL architectures operate with a stack of temporally consecutive frames as input (Hessel et al., 2017). Therefore, instance discrimination is performed across the frame stacks as opposed to single image instances. We use a momentum encoding procedure for targets similar to MoCo (He et al., 2019b), which we found to be better performing for RL. Finally, for the InfoNCE score function, we use a bi-linear inner product similar to CPC (van den Oord et al., 2018), which we found to work better than the unit-norm vector products used in MoCo and SimCLR. Ablations for both the encoder and the similarity measure choices are shown in Figure 5. The contrastive representation is trained jointly with the RL algorithm, and the latent code receives gradients from both the contrastive objective and the Q-function. An overview of the architecture is shown in Figure 2.
4.2 Discrimination Objective
A key component of contrastive representation learning is the choice of positives and negative samples relative to an anchor (Bachman et al., 2019; Hénaff et al., 2019; He et al., 2019a; Chen et al., 2020). Contrastive Predictive Coding (CPC) based pipelines (Hénaff et al., 2019; van den Oord et al., 2018) use groups of image patches separated by a carefully chosen spatial offset for anchors and positives while the negatives come from other patches within the image and from other images.
While patches are a powerful way to incorporate spatial and instance discrimination together, they introduce extra hyperparameters and architectural design choices which may be hard to adapt for a new problem. SimCLR (Chen et al., 2020) and MoCo (He et al., 2019a) opt for a simpler design where there is no patch extraction.
Discriminating transformed image instances as opposed to image-patches within the same image optimizes a simpler instance discrimination objective (Wu et al., 2018) with the InfoNCE loss and requires minimal architectural adjustments (He et al., 2019b; Chen et al., 2020). It is preferable to pick a simpler discrimination objective in the RL setting for two reasons. First, considering the brittleness of reinforcement learning algorithms (Henderson et al., 2018), complex discrimination may destabilize the RL objective. Second, since RL algorithms are trained on dynamically generated datasets, a complex discrimination objective may significantly increase the wall-clock training time. CURL therefore uses instance discrimination rather than patch discrimination. One could view contrastive instance discrimination setups like SimCLR and MoCo as maximizing mutual information between an image and its augmented version. The reader is encouraged to refer to van den Oord et al. (2018); Hjelm et al. (2018); Tschannen et al. (2019) for connections between contrastive learning and mutual information.
4.3 Query-Key Pair Generation
Similar to instance discrimination in the image setting (He et al., 2019b; Chen et al., 2020), the anchor and positive observations are two different augmentations of the same image while negatives come from other images. CURL primarily relies on the random crop data augmentation, where a random square patch is cropped from the original rendering.
A significant difference between RL and computer vision settings is that an instance ingested by a model-free RL algorithm that operates from pixels is not just a single image but a stack of frames (Mnih et al., 2015). For example, one typically feeds in a stack of 4 frames in Atari experiments and a stack of 3 frames in DMControl. This way, performing instance discrimination on frame stacks allows CURL to learn both spatial and temporal discriminative features. For details regarding the extent to which CURL captures temporal features, see Appendix E.
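The random crop is applied to the whole frame stack at once, so every frame in one observation receives the same crop. A minimal sketch (the 100-to-84 crop sizes and the helper name are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_crop(frame_stack, out_size):
    """Crop the same random square patch from every frame in the stack.

    frame_stack: array of shape [num_frames, H, W]
    """
    _, h, w = frame_stack.shape
    top = rng.integers(0, h - out_size + 1)
    left = rng.integers(0, w - out_size + 1)
    return frame_stack[:, top:top + out_size, left:left + out_size]

obs = rng.normal(size=(3, 100, 100))   # e.g. a 3-frame DMControl observation
anchor = random_crop(obs, 84)          # query augmentation
positive = random_crop(obs, 84)        # an independent crop of the same stack
print(anchor.shape, positive.shape)    # (3, 84, 84) (3, 84, 84)
```

Two independent crops of the same stack form the query and positive key; crops of other stacks in the batch serve as negatives.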
4.4 Similarity Measure
Another determining factor in the discrimination objective is the inner product used to measure agreement between query-key pairs. CURL employs the bi-linear inner product $\mathrm{sim}(q, k) = q^\top W k$, where $W$ is a learned parameter matrix. We found this similarity measure to outperform the normalized dot product (see Figure 5 in Appendix A) used in recent state-of-the-art contrastive learning methods in computer vision such as MoCo and SimCLR.
4.5 Target Encoding with Momentum
The motivation for using contrastive learning in CURL is to train encoders that map from high-dimensional pixels to more semantic latents. InfoNCE is an unsupervised loss that learns encoders $f_q$ and $f_k$ mapping the raw anchors (query) $x_q$ and targets (keys) $x_k$ into latents $q = f_q(x_q)$ and $k = f_k(x_k)$, on which we apply the similarity dot products. It is common to share the same encoder between the anchor and target mappings, that is, to have $f_q = f_k$ (van den Oord et al., 2018; Hénaff et al., 2019).
From the perspective of viewing contrastive learning as building differentiable dictionary lookups over high-dimensional entities, increasing the size of the dictionary and enriching the set of negatives is helpful in learning rich representations. He et al. (2019a) propose momentum contrast (MoCo), which uses the exponentially moving average (momentum averaged) version of the query encoder $f_q$ for encoding the keys. Given $f_q$ parametrized by $\theta_q$ and $f_k$ parametrized by $\theta_k$, MoCo performs the update $\theta_k \leftarrow m \theta_k + (1 - m) \theta_q$ and encodes any target $x_k$ using $\mathrm{SG}(f_{\theta_k}(x_k))$ [SG: stop gradient].
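The momentum update is a one-line EMA over parameters. A sketch with a flat toy parameter vector (the momentum value 0.999 is a typical MoCo-style choice used here for illustration, not the paper's reported setting):

```python
import numpy as np

def momentum_update(theta_k, theta_q, m):
    # theta_k <- m * theta_k + (1 - m) * theta_q; gradients never flow into theta_k
    return m * theta_k + (1.0 - m) * theta_q

theta_q = np.ones(4)    # query-encoder parameters (toy)
theta_k = np.zeros(4)   # key-encoder parameters (toy)
theta_k = momentum_update(theta_k, theta_q, m=0.999)
print(theta_k)          # [0.001 0.001 0.001 0.001]
```

Because the key encoder changes slowly, the keys it produces drift smoothly, which stabilizes the dictionary the queries are contrasted against.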
CURL couples frame-stack instance discrimination with momentum encoding for the targets during contrastive learning, and RL is performed on top of the encoder features.
4.6 Differences Between CURL and Prior Contrastive Methods in RL
van den Oord et al. (2018) use Contrastive Predictive Coding (CPC) as an auxiliary task wherein an LSTM operates on the latent space of a convolutional encoder, and both the CPC and A2C (Mnih et al., 2016) objectives are jointly optimized. CURL avoids pipelines that predict the future in a latent space, such as those of van den Oord et al. (2018); Hafner et al. (2019). In CURL, we opt for a simple instance discrimination style contrastive auxiliary task.
4.7 CURL Contrastive Learning Pseudocode (PyTorch-like)
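The contrastive update described in Sections 4.2-4.5 can be summarized in PyTorch-like pseudocode. The following is our reconstruction along the lines described above; the function and variable names are ours, and `aug` denotes the random crop:

```python
# Pseudocode (PyTorch-like); names are a reconstruction, not verbatim from the paper.
# f_q, f_k : query and key encoders (f_k is the momentum/EMA copy of f_q)
# W        : learned bilinear parameter matrix
# loader   : minibatch sampler over the replay buffer
# m        : momentum coefficient for the key-encoder update

for x in loader:                       # x: batch of frame-stack observations [B, ...]
    x_q, x_k = aug(x), aug(x)          # two independent random crops
    z_q = f_q(x_q)                     # query latents; receive gradients
    z_k = f_k(x_k).detach()            # key latents; stop-gradient
    logits = z_q @ (W @ z_k.T)         # bilinear similarities, [B, B]
    logits = logits - logits.max(dim=1, keepdim=True).values  # stability
    labels = arange(B)                 # positives lie on the diagonal
    loss = cross_entropy(logits, labels)
    loss.backward()                    # trained jointly with the RL loss
    update(f_q.params); update(W)
    f_k.params = m * f_k.params + (1 - m) * f_q.params  # momentum update
```

Note that only `f_q` is trained by backpropagation; `f_k` is updated purely through the EMA, matching the MoCo-style target encoding of Section 4.5.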
5 Experiments

We measure the data-efficiency and performance of our method and baselines at 100k interaction steps on both DMControl and Atari, which we henceforth refer to as DMControl100k and Atari100k for clarity. Benchmarking at 100k steps makes for a fixed experimental setup that is easy to reproduce and has been common practice when investigating data-efficiency on Atari (Kaiser et al., 2019; Kielak, 2020). A broader motivation is that while RL algorithms can achieve super-human performance on many Atari games, they are still far from the data-efficiency of a human learner. Training for 100k steps is within the order of magnitude that we would expect humans to need to learn similar tasks. In our experiments, 100k steps corresponds to 300-400k frames (due to using a frame-skip of 3 or 4), which equates to roughly 2-4 hours of human game play.
We evaluate (i) sample-efficiency by measuring how many interaction steps it takes the best performing baselines to match CURL performance at 100k interaction steps and (ii) performance by measuring the ratio of the episode returns achieved by CURL versus the best performing baseline at 100k steps. To be explicit, when we say data or sample-efficiency we’re referring to (i) and when we say performance we’re referring to (ii).
| 500K step scores | CURL | PlaNet | Dreamer | SAC+AE | SLAC | Pixel SAC | State SAC |
|---|---|---|---|---|---|---|---|
| Finger, spin | 971 ± 18 | 693 ± 27 | 796 ± 183 | 884 ± 128 | 892 ± 130 | 509 ± 148 | 932 ± 32 |
| Cartpole, swingup | 853 ± 10 | 794 ± 14 | 762 ± 27 | 735 ± 63 | - | 382 ± 79 | 870 ± 11 |
| Reacher, easy | 945 ± 27 | 833 ± 101 | 793 ± 164 | 627 ± 58 | - | 201 ± 94 | 944 ± 30 |
| Cheetah, run | 694 ± 42 | 608 ± 20 | 542 ± 37 | 550 ± 34 | 617 ± 14 | 292 ± 31 | 826 ± 22 |
| Walker, walk | 925 ± 21 | 912 ± 35 | 909 ± 11 | 847 ± 48 | 877 ± 54 | 226 ± 15 | 935 ± 31 |
| Ball in cup, catch | 956 ± 18 | 725 ± 309 | 879 ± 87 | 794 ± 58 | 900 ± 181 | 118 ± 92 | 984 ± 16 |

| 100K step scores | CURL | PlaNet | Dreamer | SAC+AE | SLAC | Pixel SAC | State SAC |
|---|---|---|---|---|---|---|---|
| Finger, spin | 845 ± 42 | 632 ± 112 | 341 ± 70 | 740 ± 64 | 693 ± 141 | 403 ± 67 | 741 ± 54 |
| Cheetah, run | 495 ± 31 | 294 ± 98 | 238 ± 76 | 267 ± 24 | 296 ± 34 | 49 ± 14 | 617 ± 9 |
| Ball in cup, catch | 942 ± 27 | 405 ± 375 | 246 ± 174 | 391 ± 82 | 834 ± 128 | 131 ± 52 | 979 ± 14 |
Our primary goal for CURL is sample-efficient control from pixels that is broadly applicable across a range of environments. We benchmark the performance of CURL for both discrete and continuous control environments. Specifically, we focus on DMControl suite for continuous control tasks and the Atari Games benchmark for discrete control tasks with inputs being raw pixels rendered by the environments.
DeepMind Control: Recently, a number of papers have benchmarked sample-efficiency on challenging visual continuous control tasks belonging to the DMControl suite (Tassa et al., 2018), where the agent operates purely from pixels. The reasons for operating in these environments are manifold: (i) they present a reasonably challenging and diverse set of tasks; (ii) the sample-efficiency of pure model-free RL algorithms operating from pixels on these benchmarks is poor; (iii) there have been multiple recent efforts to improve the sample-efficiency of both model-free and model-based methods on these benchmarks, giving us sufficient baselines to compare against; (iv) performance on the DMControl suite is relevant to robot learning in real-world benchmarks.
We run experiments on sixteen environments from DMControl to examine the performance of CURL on pixels relative to SAC with access to the ground-truth state, shown in Figure 7. For more extensive benchmarking, we compare CURL to five leading pixel-based methods across the six environments presented in Yarats et al. (2019): ball-in-cup, finger-spin, reacher-easy, cheetah-run, walker-walk, and cartpole-swingup.
Atari: Similar to the DMControl sample-efficiency benchmarks, a number of recent papers have benchmarked sample-efficiency on the Atari 2600 games. Kaiser et al. (2019) proposed comparing various algorithms in terms of performance achieved within 100K timesteps (400K frames, frame skip of 4) of interaction with the environments (games). The method proposed by Kaiser et al. (2019), called SimPLe, is a model-based RL algorithm. SimPLe is compared to a random agent, model-free Rainbow DQN (Hessel et al., 2017), and human performance for the same amount of interaction time. Recently, van Hasselt et al. (2019) and Kielak (2020) proposed data-efficient versions of Rainbow DQN that are competitive with SimPLe on the same benchmark. Given that the same benchmark has been established in multiple recent papers and that there is a human baseline to compare to, we benchmark CURL on all 26 Atari games (Table 2).
5.3 Baselines for benchmarking sample efficiency
DMControl baselines: We present a number of baselines for continuous control within the DMControl suite: (i) SAC+AE (Yarats et al., 2019), where the authors use a β-VAE (Higgins et al., 2017), a VAE (Kingma and Welling, 2013), and a regularized autoencoder (Vincent et al., 2008; Ghosh et al., 2019) jointly with SAC; (ii) SLAC (Lee et al., 2019), which learns a latent-space world model on top of VAE features (Ha and Schmidhuber, 2018) and builds value functions on top of it; (iii) PlaNet and (iv) Dreamer (Hafner et al., 2018, 2019), both of which learn a latent-space world model and explicitly plan through it; (v) Pixel SAC: vanilla SAC operating purely from pixels (Haarnoja et al., 2018). These baselines are competitive methods for benchmarking control from pixels. In addition, we also present the baseline State SAC, where the agent has access to low-level state-based features and does not operate from pixels. This baseline acts as an oracle in that it approximates an upper bound on how sample-efficient a pixel-based agent could be in these environments.
Atari baselines: For benchmarking performance on Atari, we compare CURL to (i) SimPLe (Kaiser et al., 2019), the top-performing model-based method in terms of data-efficiency on Atari; (ii) Rainbow DQN (Hessel et al., 2017), a top-performing model-free baseline for Atari; (iii) OTRainbow (Kielak, 2020), an OverTrained version of Rainbow for data-efficiency; (iv) Efficient Rainbow (van Hasselt et al., 2019), a modification of Rainbow hyperparameters for data-efficiency; (v) a random agent (Kaiser et al., 2019); and (vi) human performance (Kaiser et al., 2019; van Hasselt et al., 2019). All the baselines and our method are evaluated for performance after 100K interaction steps (400K frames with a frame skip of 4), which corresponds to roughly two hours of gameplay. These benchmarks help us understand how state-of-the-art pixel-based RL algorithms compare in terms of sample-efficiency, and also how they compare to human efficiency. Note: scores for the SimPLe and human baselines have been reported differently in prior work (Kielak, 2020; van Hasselt et al., 2019). To be rigorous, we take the best reported score for each individual game across the numbers used in prior work.
6 Results

Results for the DMControl experiments are shown in Table 1 and Figure 7. Below are the key findings:

(i) CURL is the state-of-the-art image-based RL algorithm on every single DMControl environment that we benchmark on for sample-efficiency against existing pixel-based baselines. On DMControl100k, CURL achieves 2.8x higher performance than Dreamer (Hafner et al., 2019), a leading model-based method, and is 9.9x more data-efficient.
(ii) CURL operating purely from pixels nearly matches (and sometimes surpasses) the sample-efficiency of SAC operating from state on the majority of the 16 DMControl environments tested, as shown in Figure 7. This is a first for any image-based RL algorithm, be it model-based or model-free, with or without auxiliary tasks.
(iii) CURL solves (converges close to the optimal score of 1000 on) the majority of the 16 DMControl experiments within 500k steps. It is also competitive with state-of-the-art asymptotic performance within just 100k steps, and significantly outperforms other methods in this regime.
Results for Atari100k are shown in Table 2. Below are the key findings:
(i) CURL is the state-of-the-art pixel-based RL algorithm in terms of data-efficiency on the majority of twenty-six Atari100k experiments. On average, CURL outperforms SimPLe by 1.6x and Efficient Rainbow DQN by 2.5x on Atari100k.
(ii) CURL achieves a median human-normalized score (HNS) of 24% while SimPLe and Efficient Rainbow DQN achieve 13.5% and 14.7% respectively. The mean HNS is 37.3%, 39%, and 23.8% for CURL, SimPLe, and Efficient Rainbow DQN respectively.
(iii) CURL nearly matches human efficiency on three games: JamesBond (98.4% HNS), Freeway (94.2% HNS), and Road Runner (86.5% HNS), a first for any pixel-based RL algorithm.
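For reference, the human-normalized score used above is conventionally computed as (agent − random) / (human − random). A quick check with hypothetical scores (not the actual per-game numbers):

```python
def human_normalized_score(agent, random, human):
    # HNS = (agent_score - random_score) / (human_score - random_score)
    return (agent - random) / (human - random)

# Hypothetical illustration: an agent scoring 50 where a random policy scores 10
# and a human scores 110 reaches 40% of human performance.
print(human_normalized_score(50.0, 10.0, 110.0))  # 0.4
```

An HNS of 100% means the agent matches the human score; above 100% means super-human performance on that game.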
7 Ablation Studies
In Appendix E, we present the results of ablation studies carried out to answer the following questions: (i) Does CURL learn only visual features or does it also capture temporal dynamics of the environment? (ii) How well does the RL policy perform if CURL representations are learned solely with the contrastive objective and no signal from RL? (iii) Why does CURL match state-based RL performance on some DMControl environments but not on others?
8 Conclusion

In this work, we proposed CURL, a contrastive unsupervised representation learning method for RL, that achieves state-of-the-art data-efficiency on pixel-based RL tasks across a diverse set of benchmark environments. CURL is the first model-free RL pipeline accelerated by contrastive learning with minimal architectural changes to demonstrate state-of-the-art performance on complex tasks so far dominated by approaches that have relied on learning world models and (or) decoder-based objectives. We hope that progress like CURL in combining contrastive learning with RL enables avenues for real-world deployment of RL in areas like robotics where data-efficiency is paramount.
Acknowledgements
This research is supported in part by DARPA through the Learning with Less Labels (LwLL) Program and by ONR through PECASE N000141612723. We also thank Zak Stone and Google TFRC for cloud credits. Thanks to Danijar Hafner, Alex Lee, and Denis Yarats for sharing performance data for the Dreamer, SLAC, and SAC+AE baselines; and Lerrel Pinto, Adam Stooke and Will Whitney for useful discussions and preliminary feedback.
References
- Learning representations by maximizing mutual information across views. In Advances in Neural Information Processing Systems, pp. 15509–15519.
- A distributional perspective on reinforcement learning. In Proceedings of the 34th International Conference on Machine Learning, Vol. 70, pp. 449–458.
- The arcade learning environment: an evaluation platform for general agents. Journal of Artificial Intelligence Research, 47, pp. 253–279.
- Dopamine: a research framework for deep reinforcement learning.
- A simple framework for contrastive learning of visual representations.
- Learning a similarity metric discriminatively, with application to face verification. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05), Vol. 1, pp. 539–546.
- RandAugment: practical automated data augmentation with a reduced search space.
- BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- IMPALA: scalable distributed deep-RL with importance weighted actor-learner architectures. arXiv preprint arXiv:1802.01561.
- Noisy networks for exploration. arXiv preprint arXiv:1706.10295.
- From variational to deterministic autoencoders.
- World models. arXiv preprint arXiv:1803.10122.
- Soft actor-critic algorithms and applications. arXiv preprint arXiv:1812.05905.
- Dimensionality reduction by learning an invariant mapping. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06), Vol. 2, pp. 1735–1742.
- Dream to control: learning behaviors by latent imagination. arXiv preprint arXiv:1912.01603.
- Learning latent dynamics for planning from pixels. arXiv preprint arXiv:1811.04551.
- Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722.
- Data-efficient image recognition with contrastive predictive coding. arXiv preprint arXiv:1905.09272.
- Deep reinforcement learning that matters. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Rainbow: combining improvements in deep reinforcement learning.
- DARLA: improving zero-shot transfer in reinforcement learning. In Proceedings of the 34th International Conference on Machine Learning, Vol. 70, pp. 1480–1490.
- Learning deep representations by mutual information estimation and maximization. arXiv preprint arXiv:1808.06670.
- Human-level performance in 3D multiplayer games with population-based reinforcement learning. Science, 364 (6443), pp. 859–865.
- Reinforcement learning with unsupervised auxiliary tasks. arXiv preprint arXiv:1611.05397.
- Model-based reinforcement learning for Atari. arXiv preprint arXiv:1903.00374.
- QT-Opt: scalable deep reinforcement learning for vision-based robotic manipulation. arXiv preprint arXiv:1806.10293.
- Do recent advancements in model-based deep reinforcement learning really improve data efficiency?
- Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114.
- ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems 25, pp. 1097–1105.
- Building machines that learn and think like people. Behavioral and Brain Sciences, 40.
- A tutorial on energy-based learning.
- Stochastic latent actor-critic: deep reinforcement learning with a latent variable model. arXiv preprint arXiv:1907.00953.
- Continuous control with deep reinforcement learning. arXiv preprint arXiv:1509.02971.
- Learning to navigate in complex environments. arXiv preprint arXiv:1611.03673.
- Human-level control through deep reinforcement learning. Nature, 518 (7540), pp. 529–533.
- Prioritized experience replay. arXiv preprint arXiv:1511.05952.
- Mastering Atari, Go, chess and shogi by planning with a learned model. arXiv preprint arXiv:1911.08265.
- FaceNet: a unified embedding for face recognition and clustering. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 815–823.
- Time-contrastive networks: self-supervised learning from video. In 2018 IEEE International Conference on Robotics and Automation (ICRA), pp. 1134–1141.
- Loss is its own reward: self-supervision for reinforcement learning. arXiv preprint arXiv:1612.07307.
- Introduction to reinforcement learning. Vol. 135.
- Integrated architectures for learning, planning, and reacting based on approximating dynamic programming. In Machine Learning Proceedings 1990, pp. 216–224.
- Going deeper with convolutions. In Computer Vision and Pattern Recognition (CVPR).
- DeepMind control suite. arXiv preprint arXiv:1801.00690.
- On mutual information maximization for representation learning. arXiv preprint arXiv:1907.13625.
- Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748.
- Deep reinforcement learning with double Q-learning. In Thirtieth AAAI Conference on Artificial Intelligence.
- When to use parametric models in reinforcement learning? In Advances in Neural Information Processing Systems, pp. 14322–14333.
- Extracting and composing robust features with denoising autoencoders. In Proceedings of the 25th International Conference on Machine Learning, pp. 1096–1103.
- Unsupervised learning of visual representations using videos. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2794–2802.
- Dueling network architectures for deep reinforcement learning. arXiv preprint arXiv:1511.06581.
- Unsupervised control through non-parametric discriminative rewards. arXiv preprint arXiv:1811.11359.
- Unsupervised feature learning via non-parametric instance-level discrimination. arXiv preprint arXiv:1805.01978.
- Improving sample efficiency in model-free reinforcement learning from images. arXiv preprint arXiv:1910.01741.
Appendix A Implementation Details
Below, we explain the implementation details for CURL in the DMControl setting. Specifically, we use the SAC algorithm as the RL objective coupled with CURL and build on top of the publicly released implementation from Yarats et al. (2019). We present the hyperparameters for the architecture and optimization in detail. We do not use any extra hyperparameter to balance the contrastive and reinforcement learning losses; both objectives are weighted equally in the gradient updates.
Table 3: Hyperparameters used for the DMControl experiments.

| Hyperparameter | Value |
| --- | --- |
| Replay buffer size | |
| Action repeat | set per environment (e.g., finger, spin) |
| Hidden units (MLP) | |
| Learning rate | set per environment (e.g., cheetah, run) |
| Learning rate () | |
| Batch size | (cheetah), 128 (rest) |
| Critic target update freq | |
| Number of filters | |
Architecture: We use an encoder architecture similar to that of Yarats et al. (2019), which we sketch in PyTorch-like pseudocode below. The actor and critic both use the same encoder to embed image observations. A full list of hyperparameters is displayed in Table 3.
For contrastive learning, CURL utilizes momentum for the key encoder (He et al., 2019b) and a bi-linear inner product as the similarity measure (van den Oord et al., 2018). Performance curves ablating these two architectural choices are shown in Figure 5.
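As a concrete illustration, these two components can be sketched as follows (class and function names are ours, and the latent dimension is an arbitrary assumption):

```python
import torch
import torch.nn as nn

class ContrastiveHead(nn.Module):
    """Bilinear inner-product similarity between query and key latents."""
    def __init__(self, z_dim):
        super().__init__()
        # W parametrizes the bilinear similarity sim(q, k) = q^T W k
        self.W = nn.Parameter(torch.rand(z_dim, z_dim))

    def forward(self, z_q, z_k):
        # (B, z_dim) x (z_dim, z_dim) x (z_dim, B) -> (B, B) logits matrix;
        # diagonal entries correspond to the positive pairs
        Wz = torch.matmul(self.W, z_k.T)   # (z_dim, B)
        logits = torch.matmul(z_q, Wz)     # (B, B)
        # subtract the row-wise max for numerical stability
        return logits - logits.max(dim=1, keepdim=True).values

def ema_update(online_net, target_net, tau=0.05):
    """Momentum (EMA) update of the key encoder from the query encoder."""
    with torch.no_grad():
        for p, p_t in zip(online_net.parameters(), target_net.parameters()):
            p_t.data.mul_(1 - tau).add_(tau * p.data)
```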
Pseudo-code for our proposed architecture is provided below:
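The following is a minimal sketch in the spirit of the SAC+AE encoder the implementation builds on; the 9-channel input (3 stacked RGB frames), filter count, and 50-dimensional latent are illustrative assumptions, with the actual values listed in Table 3:

```python
import torch
import torch.nn as nn

class PixelEncoder(nn.Module):
    """Convolutional encoder shared by the actor and critic."""
    def __init__(self, obs_shape=(9, 84, 84), z_dim=50, num_filters=32):
        super().__init__()
        c, h, w = obs_shape  # 3 RGB channels x 3 stacked frames = 9
        self.convs = nn.Sequential(
            nn.Conv2d(c, num_filters, 3, stride=2), nn.ReLU(),
            nn.Conv2d(num_filters, num_filters, 3, stride=1), nn.ReLU(),
            nn.Conv2d(num_filters, num_filters, 3, stride=1), nn.ReLU(),
            nn.Conv2d(num_filters, num_filters, 3, stride=1), nn.ReLU(),
        )
        # infer the flattened conv output size once with a dummy forward pass
        with torch.no_grad():
            n_flat = self.convs(torch.zeros(1, c, h, w)).numel()
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flat, z_dim),
            nn.LayerNorm(z_dim),
            nn.Tanh(),
        )

    def forward(self, obs):
        # normalize uint8 pixel observations to [0, 1]
        return self.head(self.convs(obs / 255.0))
```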
Batch Updates: After initializing the replay buffer with observations extracted by a random agent, we sample a batch of observations, compute the CURL objectives, and step through the optimizer. Note that since queries and keys are generated by data-augmenting an observation, we can generate arbitrarily many keys to increase the contrastive batch size without sampling any additional observations.
Shared Representations: The objective of performing contrastive learning together with RL is to ensure that the shared encoder learns rich features that facilitate sample-efficient control. There is a subtle coincidental connection between MoCo and off-policy RL: both frameworks maintain an exponential moving average (EMA) of the underlying model. In MoCo, the EMA encoder is used for encoding the keys (targets), while in off-policy RL, the EMA versions of the Q-networks are used as targets in the Bellman error (Mnih et al., 2015; Haarnoja et al., 2018). Thanks to this connection, CURL shares the convolutional encoder, momentum coefficient, and EMA update between the contrastive and reinforcement learning updates for the shared parameters. The MLP part of the critic that operates on top of these convolutional features has a separate momentum coefficient and update, decoupled from the image encoder parameters.
Balancing Contrastive and RL Updates: While past work has learned hyperparameters to balance the auxiliary loss coefficient or learning rate relative to the RL objective (Jaderberg et al., 2016; Yarats et al., 2019), CURL does not need any such adjustments. We use both the contrastive and RL objectives together with equal weight and learning rate. This simplifies the training process compared to other methods, such as training a VAE jointly (Hafner et al., 2018, 2019; Lee et al., 2019), that require careful tuning of coefficients for representation learning.
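One such equally weighted update step might look like the following sketch (the function name is ours; `logits` is the B x B contrastive logits matrix whose diagonal holds the positive pairs):

```python
import torch
import torch.nn.functional as F

def curl_step(logits, rl_loss, optimizer):
    """One gradient step on the equally weighted sum of both objectives."""
    # InfoNCE: the positive key for query i sits at column i of the logits
    labels = torch.arange(logits.shape[0], device=logits.device)
    contrastive_loss = F.cross_entropy(logits, labels)
    # equal weights and a shared optimizer: no balancing coefficient to tune
    loss = rl_loss + contrastive_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```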
Differences in Data Collection between Computer Vision and RL Settings: There are two key differences between contrastive learning in the computer vision and RL settings because of their different goals. First, unsupervised feature learning methods built for downstream vision tasks like image classification assume a large static dataset of unlabeled images, whereas in RL the dataset changes over time to account for the agent's new experiences. Second, in vision-based settings the memory bank typically holds 65K encoded keys while the unlabeled datasets contain 1M (or even 1B) images; the goal there is to learn from millions of unlabeled images. The goal in CURL, on the other hand, is to develop sample-efficient RL algorithms. For example, to solve a task within 100K timesteps (approximately 2 hours in real-time), an agent can only ingest 100K image frames.
Therefore, unlike MoCo, CURL does not use a memory bank for contrastive learning. Instead, similar to SimCLR, the negatives are constructed on the fly for every minibatch sampled from the agent's replay buffer for an RL update. The exact implementation is provided as a PyTorch-like code snippet in Section 4.7.
Random crop data augmentation has been crucial to the performance of deep learning based computer vision systems in object recognition, detection, and segmentation (Krizhevsky et al., 2012; Szegedy et al., 2015; Cubuk et al., 2019; Chen et al., 2020). However, similar augmentation methods have not seen much adoption in RL, even though several benchmarks use raw pixels as inputs to the model.
CURL adopts random crop as the stochastic data augmentation applied to a frame stack. To make it easier for the model to correlate spatio-temporal patterns in the input, we apply the same random crop (in terms of box coordinates) across all four frames in the stack, as opposed to extracting a different random crop position for each frame. Further, unlike in computer vision systems where the cropped area is allowed to be as small as 0.08 of the original image, we preserve as much of the spatial information as possible and use a constant ratio of 0.84 between the cropped and the original frame. In our experiments, data-augmented samples for CURL are formed by randomly cropping the rendered input frames.
DMControl: During training, we render observations at a higher resolution and extract random crops; for evaluation, we render at the same resolution and extract center crops. We found that implementing random crop efficiently was extremely important to the success of the algorithm. We provide pseudocode below:
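A simple version consistent with the description above (the 84 x 84 output and 100 x 100 input in the example are assumptions; the released implementation may use a faster strided variant):

```python
import numpy as np

def random_crop(imgs, out=84):
    """Random crop of a batch of frame stacks.

    imgs: (B, C, H, W) array; the frames of a stack occupy the channel
    dimension, so drawing one crop position per batch element automatically
    applies the identical crop to every frame in that stack.
    """
    b, c, h, w = imgs.shape
    h1 = np.random.randint(0, h - out + 1, b)  # top edge per batch element
    w1 = np.random.randint(0, w - out + 1, b)  # left edge per batch element
    cropped = np.empty((b, c, out, out), dtype=imgs.dtype)
    for i, (img, y, x) in enumerate(zip(imgs, h1, w1)):
        cropped[i] = img[:, y:y + out, x:x + out]
    return cropped
```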
Appendix B Atari100k Implementation Details
The flexibility of CURL allows us to apply it to the discrete control setting with minimal modifications. Similar to our rationale for picking SAC as the baseline RL algorithm to couple CURL with (for continuous control), we pick the data-efficient version of Rainbow DQN (Efficient Rainbow) (van Hasselt et al., 2019) for Atari100k, which performs competitively with an older version of SimPLe (the most recent version has improved numbers). We adopt the same hyperparameters specified in that paper (including a modified convolutional encoder that uses a larger kernel size and stride of 5, faster decay for exploration (50000 timesteps), and so forth). We present the details in Table 4. As in DMControl, the contrastive objective and the RL objective are weighted equally for learning (no extra hyperparameter for balancing the auxiliary objective). We use the Google Dopamine (Castro et al., 2018) codebase for the reference Rainbow implementation. We evaluate with three random seeds and report the mean score for each game. We restrict ourselves to grayscale renderings of image observations and use random crop of the frame stack as data augmentation, as in DMControl.
Table 4: Hyperparameters used for the Atari100k experiments.

| Hyperparameter | Value |
| --- | --- |
| Data augmentation | Random crop (train) |
| Rollout preprocessing | Reshape |
| Replay buffer size | |
| Replay period every | |
| Number of updates per transition | |
| Q network: channels | |
| Q network: filter size | |
| Q network: stride | |
| Q network: hidden units | |
| Momentum (EMA for CURL) | |
| Multi-step return | |
| Minimum replay size for sampling | |
| Max frames per episode | |
| Update | Distributional Double Q |
| Target network update period | |
| Optimizer: learning rate | 0.0001 |
| Max gradient norm | |
Appendix C Benchmarking Data Efficiency
Tables 1 and 2 show the episode returns of DMControl100k, DMControl500k, and Atari100k across CURL and a number of pixel-based baselines. CURL outperforms all baseline pixel-based methods across experiments on both DMControl100k and DMControl500k. On Atari100k experiments, CURL outperforms prior methods on the majority of games tested (14 out of 26 games).
Appendix D Further Investigation of Data-Efficiency in Contrastive RL
To further benchmark CURL’s sample-efficiency, we compare it to state-based SAC on a total of 16 DMControl environments. Shown in Figure 7, CURL matches state-based data-efficiency on most of the environments, but lags behind state-based SAC on more challenging environments.
Appendix E Ablations
E.1 Learning Temporal Dynamics
To gain insight into whether CURL learns temporal dynamics across the stacked frames, we also train a variant of CURL where the discriminants are individual frames as opposed to stacked ones. This is done by sampling stacked frames from the replay buffer but using only the first frame in the contrastive update.
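A minimal sketch of this frame selection, assuming RGB frames stacked along the channel axis (3 channels per frame); the helper name is ours:

```python
import torch

def first_frame(obs_stack):
    """Keep only the first frame of each stacked observation.

    obs_stack: (B, 3 * num_frames, H, W) tensor of stacked RGB frames.
    In this ablation, the contrastive query/key pairs are built from this
    single frame, while the actor-critic update still sees the full stack.
    """
    return obs_stack[:, :3]
```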
During the actor-critic update, frames in the batch are encoded individually into latent codes, which are then concatenated before being passed to a dense network.
Encoding each frame individually ensures that the contrastive objective only has access to visual discriminants. Comparing the visual and spatiotemporal variants of CURL in Figure 8 shows that the variant trained on stacked frames outperforms the visual-only version in most environments. The only exceptions are the reacher and ball-in-cup environments. Indeed, in those environments the visual signal is strong enough to solve the task optimally, whereas in environments such as walker and cheetah, where balance or coordination is required, visual information alone is insufficient.
E.2 Decoupling Representation Learning from Reinforcement Learning
Typically, deep RL representations depend almost entirely on the reward function specific to a task, whereas hand-crafted representations such as the proprioceptive state are independent of the reward function. Learning reward-agnostic representations is much more desirable, since the same representation can then be re-used across different RL tasks. We test whether CURL can learn such representations by comparing CURL to a variant where the critic gradients are backpropagated through the dense feedforward network but stopped before reaching the encoder.
Scores displayed in Figure 9 show that for many environments, the purely contrastive representations are sufficient to learn an optimal policy. The major exception is the cheetah environment, where the detached representation significantly under-performs. Though promising, we leave further exploration of task-agnostic representations for future work.
E.3 Predicting State from Pixels
Despite improved sample-efficiency on most DMControl tasks, there is still a visible gap between the performance of SAC on state and SAC with CURL in some environments. Since CURL learns representations by performing instance discrimination across stacks of three frames, it’s possible that the reason for degraded sample-efficiency on more challenging tasks is due to partial-observability of the ground truth state.
To test this hypothesis, we perform supervised regression from pixels to the proprioceptive state, where each data point is a stack of three consecutive frames paired with the corresponding state extracted from the simulator. We find that the error in predicting the state from pixels correlates with the policy performance of pixel-based methods. Test-time error rates displayed in Figure 10 show that environments where CURL is as efficient as state-based SAC have low error rates in predicting the state from stacks of pixels. The prediction error increases for more challenging environments, such as cheetah-run and walker-walk. Finally, the error is highest for environments where current pixel-based methods, CURL included, make no progress at all (Tassa et al., 2018), such as humanoid and swimmer.
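Such a regression probe can be sketched as follows (the model, data shapes, and training budget are illustrative assumptions rather than the paper's exact setup):

```python
import torch
import torch.nn as nn

def fit_state_regressor(model, obs, states, epochs=200, lr=1e-2):
    """Supervised regression from pixel frame stacks to simulator states.

    model: any network mapping (B, C, H, W) frame stacks to (B, state_dim).
    Returns the final training mean-squared error; a held-out split would be
    used for the test-time error rates reported in Figure 10.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(obs), states)
        loss.backward()
        opt.step()
    return loss.detach()
```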
This investigation suggests that degraded policy performance on challenging tasks may result from the lack of requisite information about the underlying state in the pixel data used for learning representations. We leave further investigation for future work.