Reinforcement Learning with Augmented Data

04/30/2020 ∙ by Michael (Misha) Laskin, et al.

Learning from visual observations is a fundamental yet challenging problem in reinforcement learning (RL). Although algorithmic advancements combined with convolutional neural networks have proved to be a recipe for success, current methods are still lacking on two fronts: (a) sample efficiency of learning and (b) generalization to new environments. To this end, we present RAD: Reinforcement Learning with Augmented Data, a simple plug-and-play module that can enhance any RL algorithm. We show that data augmentations such as random crop, color jitter, patch cutout, and random convolutions can enable simple RL algorithms to match and even outperform complex state-of-the-art methods across common benchmarks in terms of data-efficiency, generalization, and wall-clock speed. We find that data diversity alone can make agents focus on meaningful information from high-dimensional observations without any changes to the reinforcement learning method. On the DeepMind Control Suite, we show that RAD is state-of-the-art in terms of data-efficiency and performance across 15 environments. We further demonstrate that RAD can significantly improve the test-time generalization on several OpenAI ProcGen benchmarks. Finally, our customized data augmentation modules enable faster wall-clock speed compared to competing RL techniques. Our RAD module and training code are available at https://www.github.com/MishaLaskin/rad.

1 Introduction

Learning from visual observations is a fundamental problem in reinforcement learning (RL). Current success stories build on two key ideas: (a) using expressive convolutional neural networks (CNNs) LeCun et al. (1998) that provide strong spatial inductive bias; (b) better credit assignment Mnih et al. (2015); Lillicrap et al. (2016); Schulman et al. (2017) techniques that are crucial for sequential decision making. This combination of CNNs with modern RL algorithms has led to impressive success with human-level performance in Atari Mnih et al. (2015), super-human Go players Silver et al. (2017), continuous control from pixels Lillicrap et al. (2016); Schulman et al. (2017) and learning policies for real-world robot grasping Kalashnikov et al. (2018).

Figure 1: Reinforcement Learning with Augmented Data (RAD) applies data augmentations to image-based observations for reinforcement learning. RAD can be combined with any reinforcement learning algorithm, on-policy or off-policy, and can be used for both discrete and continuous control tasks without any additional losses. RAD ensures that the trained policy and (or) value function neural networks are consistent across augmented views of the image-based observations. RAD makes no change to the underlying reinforcement learning method and no additional assumptions about the underlying domain other than the knowledge that the agent operates from pixel-based inputs. The simple implementation of RAD and its efficiency in terms of both wall-clock time and data make it an easy plug-and-play module for any reinforcement learning setup.

While these achievements are truly impressive, RL is notoriously plagued with poor data-efficiency and generalization capabilities Duan et al. (2016); Henderson et al. (2018). Real-world successes of reinforcement learning often require months of data-collection and (or) training Kalashnikov et al. (2018); Akkaya et al. (2019). On the other hand, biological agents have the remarkable ability to learn quickly Lake et al. (2017); Kaiser et al. (2020), while being able to generalize to a wide variety of unseen tasks Ghahramani et al. (1996). These challenges associated with RL are further exacerbated when we operate on pixels due to high-dimensional and partially-observable inputs. Bridging the gap of data-efficiency and generalization is hence pivotal to the real-world applicability of RL.

Supervised learning, in the context of computer vision, has addressed the problems of data-efficiency and generalization by injecting useful priors. One such often ignored prior is Data Augmentation. It was critical to the early successes of CNNs LeCun et al. (1998); Krizhevsky et al. (2012) and has more recently enabled better supervised Cubuk et al. (2019a, b), semi-supervised Xie et al. (2019, 2020); Berthelot et al. (2019) and self-supervised Hénaff et al. (2019); Chen et al. (2020); He et al. (2020) learning. By using multiple augmented views of the same data-point as input, CNNs are forced to learn consistencies in their internal representations. This results in a visual representation that improves generalization Xie et al. (2020); Hénaff et al. (2019); Chen et al. (2020); He et al. (2020), data-efficiency Xie et al. (2019); Hénaff et al. (2019); Chen et al. (2020) and transfer learning Hénaff et al. (2019); He et al. (2020).

Inspired by the impact of data augmentation in computer vision, we present RAD: Reinforcement Learning with Augmented Data, a technique to incorporate data-augmentations on input observations for reinforcement learning pipelines. Through RAD, we ensure that the agent is learning on multiple views (or augmentations) of the same input (see Fig. 1). This allows the agent to improve on two key capabilities: (a) data-efficiency: learning to quickly master the task at hand with drastically fewer experience rollouts; (b) generalization: improving transfer to unseen tasks or levels simply by training on more diversely augmented samples. Through RAD, we present the first extensive study of the use of data augmentation techniques for reinforcement learning with no changes to the underlying reinforcement learning algorithm and no additional assumptions about the domain other than the knowledge that the agent operates from image-based observations.

We highlight the main contributions of RAD below:

  • We show that simple end-to-end RL algorithms coupled with augmented data either match or beat every state-of-the-art baseline in terms of performance and data-efficiency across 15 DeepMind Control environments Tassa et al. (2018). A similar result has also been demonstrated in concurrent and independent work by Kostrikov et al. Kostrikov et al. (2020).

  • RAD significantly improves test-time generalization on several environments in the OpenAI ProcGen benchmark suite Cobbe et al. (2019a) widely used for generalization in reinforcement learning.

  • RAD is faster and more compute-efficient by significant margins than state-of-the-art model-based algorithms such as SLAC Lee et al. (2019), PlaNet Hafner et al. (2019), and Dreamer Hafner et al. (2020), in terms of both data and wall-clock efficiency.

  • Our custom implementations of random data augmentations enable us to apply augmentation in the RL setting, where observations consist of stacked frame inputs, without breaking the temporal information present in the stack. Our vectorized and GPU-accelerated augmentations are competitive with, and on average 2x faster than, state-of-the-art framework APIs such as PyTorch (see Table 5). This plug-and-play module is publicly released here: https://www.github.com/MishaLaskin/rad.

2 Related Work

2.1 Data Augmentation in Supervised Learning

Since our focus is on image-based observations, we cover the related work in computer vision. Data augmentation in deep learning systems for computer vision can be found as early as LeNet-5 LeCun et al. (1998), an early implementation of CNNs on MNIST digit classification. In AlexNet Krizhevsky et al. (2012), wherein the authors applied CNNs to image classification on ImageNet, data augmentations were used to increase the size of the original dataset by a factor of 2048 by randomly flipping and cropping patches from the original image. These data augmentations inject the priors of invariance to translation and reflection, playing a significant role in improving the performance of supervised computer vision systems. Recently, new augmentation techniques such as AutoAugment Cubuk et al. (2019a) and RandAugment Cubuk et al. (2019b) have been proposed to further improve the performance of these systems.

2.2 Data Augmentation for Data-Efficiency in Semi-Supervised and Self-Supervised Learning

Aside from improving supervised learning, data augmentation has also been widely utilized for unsupervised and semi-supervised learning. MixMatch Berthelot et al. (2019), FixMatch Sohn et al. (2020), and UDA Xie et al. (2019) use unsupervised data augmentation in order to maximize label agreement without access to the actual labels. Several contrastive representation learning approaches Hénaff et al. (2019); He et al. (2020); Chen et al. (2020) have recently dramatically improved the label-efficiency of downstream vision tasks like ImageNet classification. Contrastive approaches utilize data augmentations and perform patch-wise Hénaff et al. (2019) or instance discrimination (MoCo, SimCLR) He et al. (2020); Chen et al. (2020). In the instance discrimination setting, the contrastive objective aims to maximize agreement between augmentations of the same image and minimize it between all other images in the dataset Chen et al. (2020); He et al. (2020). The choice of augmentations has a significant effect on the quality of the learned representations, as demonstrated in SimCLR Chen et al. (2020).

2.3 Prior work in Reinforcement Learning related to Data Augmentation

2.3.1 Data Augmentation with Domain Knowledge

While not typically framed as data augmentation in reinforcement learning, the following ideas can be viewed as techniques to diversify the data used to train an RL agent:

Domain Randomization Sadeghi and Levine (2016); Tobin et al. (2017) is a simple data augmentation technique primarily used for transferring policies from simulation to the real world. It takes advantage of the simulator’s access to information about rendering and physics in order to train transferable policies from diverse simulated experiences.

Hindsight Experience Replay Andrychowicz et al. (2017) re-labels trajectories with terminal states as fictitious goals, improving the ability of goal-conditioned RL to learn quickly with sparse rewards. This, however, assumes that the goal space matches the state space and has had limited success with pixel-based observations.

2.3.2 Synthetic Rollouts using a Learned World Model

While usually not viewed as a data augmentation technique, the idea of generating fake or synthetic rollouts to improve the data-efficiency of RL agents has been proposed in the Dyna framework Sutton (1991). In recent times, these ideas have been used to improve the performance of systems that have explicitly learned world models of the environment and generated synthetic rollouts using them Ha and Schmidhuber (2018); Kaiser et al. (2020); Hafner et al. (2020).

2.3.3 Data Augmentation for Data-Efficient Reinforcement Learning

Data augmentation is a key component for learning contrastive representations in the RL setting as shown in the CURL framework Srinivas et al. (2020), which learns representations that improve the data-efficiency of pixel-based RL by enforcing consistencies between an image and its augmented version through instance contrastive losses. Prior to our work, CURL was the state-of-the-art model for data-efficient RL from pixel inputs. While the focus in CURL was to make use of data augmentations jointly through contrastive and reinforcement learning losses, RAD attempts to directly use data augmentations for reinforcement learning without any auxiliary loss. We refer the reader to a discussion on tradeoffs between CURL and RAD in Section 6. Concurrent and independent to our work, DrQ Kostrikov et al. (2020) uses data augmentations and weighted Q-functions in conjunction with the off-policy RL algorithm SAC Haarnoja et al. (2018) to achieve state-of-the-art data-efficiency results on the DeepMind Control Suite. On the other hand, RAD can be plugged into any reinforcement learning method (on-policy methods like PPO Schulman et al. (2017) and off-policy methods like SAC Haarnoja et al. (2018)) without making any changes to the underlying algorithm. We further demonstrate the benefits of data augmentation to generalization on the OpenAI ProcGen benchmarks in addition to data-efficiency on the DeepMind Control Suite.

2.4 Data Augmentation for Generalization in Reinforcement Learning

Cobbe et al. Cobbe et al. (2019b) and Lee et al. Lee et al. (2020) showed that simple data augmentation techniques such as cutout Cobbe et al. (2019b) and random convolution Lee et al. (2020) can be useful to improve generalization of agents on the OpenAI CoinRun and ProcGen benchmarks. In this paper, we extensively investigate more data augmentation techniques such as random crop and color jitter on a more diverse array of tasks. With our efficient implementation of these augmentations, we demonstrate their utility with the on-policy RL algorithm PPO Schulman et al. (2017) for the first time.

3 Background

RL agents act within a Markov Decision Process, defined as the tuple $(\mathcal{S}, \mathcal{A}, P, \gamma)$, with the following components: states $s \in \mathcal{S}$, actions $a \in \mathcal{A}$, either discrete ($a \in \{1, \dots, |\mathcal{A}|\}$) or continuous ($a \in \mathbb{R}^m$), and the state transition probability function, $P = \Pr(s_{t+1} \mid s_t, a_t)$, which defines the task mechanics and rewards. Without prior knowledge of $P$, the RL agent's goal is to use experience to maximize expected rewards, $R = \sum_t \gamma^t r_t$, under discount factor $\gamma \in [0, 1)$. Crucially, in RL from pixels, the agent receives image-based observations, $o_t = O(s_t)$, which are a high-dimensional, indirect representation of the state.

Soft Actor-Critic. SAC Haarnoja et al. (2018) is a state-of-the-art off-policy algorithm for continuous control problems. SAC learns a policy, $\pi_\psi(a \mid o)$, and a critic, $Q_\phi(o, a)$, and aims to maximize a weighted objective of the reward and the policy entropy, $\mathbb{E}_{a_t \sim \pi}\left[\sum_t r_t + \alpha \mathcal{H}(\pi_\psi(\cdot \mid o_t))\right]$. The critic parameters are learned by minimizing the squared Bellman error using transitions, $\tau_t = (o_t, a_t, r_t, o_{t+1})$, replayed from an experience buffer, $\mathcal{D}$:

$\mathcal{L}_Q(\phi) = \mathbb{E}_{\tau \sim \mathcal{D}}\left[\big(Q_\phi(o_t, a_t) - (r_t + \gamma V(o_{t+1}))\big)^2\right]$   (1)

The target value of the next state, $V(o_{t+1})$, can be estimated by sampling an action using the current policy:

$V(o_{t+1}) = \mathbb{E}_{a' \sim \pi_\psi}\left[Q_{\bar\phi}(o_{t+1}, a') - \alpha \log \pi_\psi(a' \mid o_{t+1})\right]$   (2)

where $Q_{\bar\phi}$ represents a more slowly updated copy of the critic. The policy is learned by minimizing the divergence from the exponential of the soft-Q function at the same states:

$\mathcal{L}_\pi(\psi) = -\mathbb{E}_{a \sim \pi_\psi}\left[Q_\phi(o_t, a) - \alpha \log \pi_\psi(a \mid o_t)\right]$   (3)

via the reparameterization trick for the newly sampled action. The temperature $\alpha$ is learned against a target entropy.

Proximal Policy Optimization. PPO Schulman et al. (2017) is a state-of-the-art on-policy algorithm for learning a continuous or discrete control policy, $\pi_\theta(a \mid o)$. PPO forms policy gradients using action-advantages, $A_t = A^\pi(a_t, s_t)$, and minimizes a clipped-ratio loss over minibatches of recent experience (collected under $\pi_{\theta_{\mathrm{old}}}$):

$\mathcal{L}_\pi(\theta) = -\mathbb{E}_{\tau \sim \pi}\left[\min\big(\rho_t(\theta) A_t,\ \mathrm{clip}(\rho_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t\big)\right], \quad \rho_t(\theta) = \frac{\pi_\theta(a_t \mid o_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid o_t)}$   (4)

Our PPO agents learn a state-value estimator, $V_\theta(o)$, which is regressed against a target of discounted returns, $V_t^{\mathrm{targ}}$, and used with Generalized Advantage Estimation Schulman et al. (2017):

$\mathcal{L}_V(\theta) = \mathbb{E}_{\tau \sim \pi}\left[\big(V_\theta(o_t) - V_t^{\mathrm{targ}}\big)^2\right]$   (5)
RAD. When applying RAD to SAC, our data augmentations are applied to the observations passed to the critic $Q_\phi$ in (1)-(2) and to the policy $\pi_\psi$ in (3). When applying RAD to PPO, every use of the observation in (4)-(5) is subject to the same augmentation. However, unlike SAC, the augmentation is applied to samples collected from recent rollouts rather than replayed from a buffer, because PPO is an on-policy RL method.
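For concreteness, the sketch below shows where augmentation enters a SAC critic update following (1)-(2). It is a minimal illustration, not the released implementation; the augment, actor, critic, and critic_target objects are hypothetical stand-ins for the corresponding components.

import torch
import torch.nn.functional as F

def rad_sac_critic_loss(batch, augment, actor, critic, critic_target, alpha, gamma):
    # batch: (obs, act, rew, next_obs, done) sampled from the replay buffer,
    # where obs and next_obs are image tensors of shape (N, C, H, W).
    obs, act, rew, next_obs, done = batch
    # RAD: augment observations before every forward pass; no auxiliary loss is added.
    obs, next_obs = augment(obs), augment(next_obs)
    with torch.no_grad():
        next_act, next_logp = actor.sample(next_obs)         # a' sampled from the current policy
        q1_t, q2_t = critic_target(next_obs, next_act)
        v_next = torch.min(q1_t, q2_t) - alpha * next_logp   # clipped double-Q soft value, cf. (2)
        target = rew + gamma * (1.0 - done) * v_next
    q1, q2 = critic(obs, act)
    return F.mse_loss(q1, target) + F.mse_loss(q2, target)   # squared Bellman error, cf. (1)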

4 Reinforcement Learning with Augmented Data

We investigate the utility of data augmentations in model-free RL for both off-policy and on-policy settings by processing image observations with stochastic augmentations before passing them to the agent for training. For the base RL agent, we use SAC Haarnoja et al. (2018) and PPO Schulman et al. (2017) as the off-policy and on-policy RL methods, respectively. During training, we sample observations from either a replay buffer or a recent trajectory and augment the images within the minibatch. In the RL setting, it is common to stack consecutive frames as observations to infer temporal information such as object velocities. Crucially, augmentations are applied randomly across the batch but consistently across the frame stack Srinivas et al. (2020), as shown in Figure 2. (For on-policy RL methods such as PPO, we likewise apply different augmentations across the batch but consistently across time.) This enables the augmentation to retain the temporal information present across the frame stack.

Figure 2: We investigate eight different types of data augmentations: crop, grayscale, cutout, cutout-color, flip, rotate, random convolution, and color jitter. During training, a minibatch is sampled from the replay buffer or a recent trajectory and randomly augmented. While augmentation across the minibatch is stochastic, it is consistent across the stacked frames.

Across our experiments, we investigate and ablate the following data augmentations, which are visualized in Figure 2:

  • Crop: Extracts a random patch from the original frame. For example, in DMControl we render 100×100 pixel frames and crop randomly to 84×84 pixels.

  • Grayscale: Converts RGB images to grayscale with some random probability p (a minimal sketch appears after this list).

  • Cutout: Randomly inserts a small black occlusion into the frame, which may be perceived as cutting out a small patch from the originally rendered frame.

  • Cutout-color: Another variant of cutout where instead of rendering black, the occlusion color is randomly generated.

  • Flip: Flips an image at random across the vertical axis.

  • Rotate: Rotates the image by an angle sampled at random from a fixed set of angles.

  • Random convolution: First introduced in Lee et al. (2020), augments the image color by passing the input observation through a random convolutional layer.

  • Color jitter: Converts RGB image to HSV and adds noise to the HSV channels, which results in explicit color jittering.
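To illustrate how such augmentations can respect the frame stack, here is a minimal sketch of random grayscale for a batch of K stacked RGB frames. The (N, 3*K, H, W) layout, the probability p=0.3, and the luminance weights are illustrative assumptions rather than the exact settings used in our experiments.

import torch

def random_grayscale(imgs, p=0.3):
    # imgs: (N, 3*K, H, W) float tensor of K stacked RGB frames with values in [0, 1].
    n, ck, h, w = imgs.shape
    k = ck // 3
    frames = imgs.reshape(n, k, 3, h, w)
    # standard luminance weights; converted frames keep 3 channels so shapes match
    lum = torch.tensor([0.299, 0.587, 0.114], device=imgs.device).view(1, 1, 3, 1, 1)
    gray = (frames * lum).sum(dim=2, keepdim=True).expand(-1, -1, 3, -1, -1)
    # each sample is converted with probability p, consistently across its frame stack
    convert = (torch.rand(n, device=imgs.device) < p).view(n, 1, 1, 1, 1)
    return torch.where(convert, gray, frames).reshape(n, ck, h, w)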

A major issue with applying data augmentations to RL is that RL algorithms are typically run in a minibatch setting with a dynamic dataset. In contrast, data augmentation for image-based supervised learning assumes that the dataset is static, which enables efficient augmentation by treating the dataset as a large, repeating and cached buffer to sample from. Since the data augmentation APIs from popular frameworks like Tensorflow Abadi et al. (2016) or PyTorch Paszke et al. (2019) are optimized for static image datasets, they are not suitable for augmenting stacked frames in a minibatch setting with a dynamic dataset. Particularly, it is infeasible to augment randomly across the batch but consistently across the frame stack without incurring significant additional wall-clock time. A key technical contribution of this work is the custom implementation of data augmentation ops that utilize vectorization Walt et al. (2011) and GPU-acceleration Chetlur et al. (2014) to add minimal overhead when processing stacked frames in the minibatch setting. We show in Table 5 that our implementations are on average 2x faster compared to using the PyTorch API. Representative code snippets are presented in appendix E, and our code is publicly available at https://www.github.com/MishaLaskin/rad.
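As a rough illustration of the kind of vectorization involved (not necessarily how the released module is written), a random crop can be expressed with two batched gather calls instead of a Python loop over the minibatch:

import torch

def random_crop_vectorized(imgs, size=84):
    # imgs: (N, C, H, W); one crop window per sample, shared across the stacked frames
    # because the frames live in the channel dimension.
    n, c, h, w = imgs.shape
    h1 = torch.randint(0, h - size + 1, (n,), device=imgs.device)
    w1 = torch.randint(0, w - size + 1, (n,), device=imgs.device)
    offs = torch.arange(size, device=imgs.device)
    rows = (h1[:, None] + offs)[:, None, :, None].expand(n, c, size, w)
    cols = (w1[:, None] + offs)[:, None, None, :].expand(n, c, size, size)
    out = imgs.gather(2, rows)   # select the `size` rows of each window
    return out.gather(3, cols)   # then the `size` columns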

5 Experimental Results

5.1 Setup

DMControl: Our goal is to investigate whether data augmentation can be broadly used to improve pixel-based RL. For this reason, we focus on studying the data-efficiency and generalization abilities of our proposed methods. To benchmark data-efficiency, we utilize the DeepMind Control Suite (DMControl) Tassa et al. (2018), which has recently become a common benchmark for comparing efficient RL agents, both model-based and model-free. DMControl presents a variety of complex tasks including bipedal balance, locomotion, contact forces, and goal-reaching with both sparse and dense reward signals.

For DMControl experiments, we evaluate data-efficiency by measuring the performance of our method at 100k (i.e., low sample regime) and 500k (i.e., asymptotically optimal regime) simulator or environment steps during training, following the setup in CURL Srinivas et al. (2020). These benchmarks are referred to as DMControl100k and DMControl500k. For comparison, we consider six powerful recent pixel-based methods: CURL Srinivas et al. (2020) learns contrastive representations, SLAC Lee et al. (2019) learns a forward model and uses it to shape encoder representations, while SAC+AE Yarats et al. (2019) minimizes a reconstruction loss as an auxiliary task. All three methods use SAC Haarnoja et al. (2018) as their base algorithm. Dreamer Hafner et al. (2020) and PlaNet Hafner et al. (2019) learn world models and use them to generate synthetic rollouts similar to Dyna Sutton (1991). Pixel SAC is a vanilla Soft Actor-Critic operating on pixel inputs, and state SAC is an oracle baseline that operates on the proprioceptive state of the simulated agent, which includes joint positions and velocities. We also provide learning curves for longer runs and examine how RAD compares to state SAC and CURL across a more diverse set of environments in Figure 3. (Environment steps refers to the number of times the underlying simulator is stepped through; this measure is independent of policy heuristics such as action repeat. For example, if action repeat is set to 4, then 100k environment steps correspond to 25k policy steps.)

ProcGen: Although DMControl is suitable for benchmarking data-efficiency and performance, it evaluates the agent in the same environment in which it was trained and is thus not suited for studying generalization. For this reason, we focus on the OpenAI ProcGen benchmarks Cobbe et al. (2019a) to investigate the generalization capabilities of RAD. ProcGen presents a suite of game-like environments where the train and test environments differ in visual appearance and structure, and is therefore a commonly used benchmark for studying the generalization abilities of RL agents Cobbe et al. (2019b). We evaluate the zero-shot performance of the trained agents on the full distribution of unseen levels. Following the setup in Cobbe et al. Cobbe et al. (2019a), we use the CNN architecture found in IMPALA Espeholt et al. (2018) as the policy network and train the agents using the Proximal Policy Optimization (PPO) Schulman et al. (2017) algorithm for 20M timesteps. For all experiments, we use the easy environment difficulty and the hyperparameters suggested in Cobbe et al. (2019a), which have been shown to be empirically effective. We used a reference implementation publicly available at https://github.com/openai/train-procgen.

5.2 Improving Data-Efficiency on DeepMind Control

500K step scores | RAD | CURL | PlaNet | Dreamer | SAC+AE | SLACv1 | Pixel SAC | State SAC
Finger, spin | 962 ± 16 | 926 ± 45 | 561 ± 284 | 796 ± 183 | 884 ± 128 | 673 ± 92 | 192 ± 166 | 923 ± 211
Cartpole, swingup | 847 ± 26 | 845 ± 45 | 475 ± 71 | 762 ± 27 | 735 ± 63 | - | 419 ± 40 | 848 ± 15
Reacher, easy | 952 ± 25 | 929 ± 44 | 210 ± 44 | 793 ± 164 | 627 ± 58 | - | 145 ± 30 | 923 ± 24
Cheetah, run | 611 ± 51 | 518 ± 28 | 305 ± 131 | 732 ± 103 | 550 ± 34 | 640 ± 19 | 197 ± 15 | 795 ± 30
Walker, walk | 918 ± 16 | 902 ± 43 | 351 ± 58 | 897 ± 49 | 847 ± 48 | 842 ± 51 | 42 ± 12 | 948 ± 54
Cup, catch | 960 ± 12 | 959 ± 27 | 460 ± 380 | 879 ± 87 | 794 ± 58 | 852 ± 71 | 312 ± 63 | 974 ± 33

100K step scores | RAD | CURL | PlaNet | Dreamer | SAC+AE | SLACv1 | Pixel SAC | State SAC
Finger, spin | 831 ± 30 | 767 ± 56 | 136 ± 216 | 341 ± 70 | 740 ± 64 | 693 ± 141 | 224 ± 101 | 811 ± 46
Cartpole, swingup | 813 ± 32 | 582 ± 146 | 297 ± 39 | 326 ± 27 | 311 ± 11 | - | 200 ± 72 | 835 ± 22
Reacher, easy | 510 ± 76 | 538 ± 233 | 20 ± 50 | 314 ± 155 | 274 ± 14 | - | 136 ± 15 | 746 ± 25
Cheetah, run | 387 ± 82 | 299 ± 48 | 138 ± 88 | 238 ± 76 | 267 ± 24 | 319 ± 56 | 130 ± 12 | 616 ± 18
Walker, walk | 429 ± 58 | 403 ± 24 | 224 ± 48 | 277 ± 12 | 394 ± 22 | 361 ± 73 | 127 ± 24 | 891 ± 82
Cup, catch | 495 ± 168 | 769 ± 43 | 0 ± 0 | 246 ± 174 | 391 ± 82 | 512 ± 110 | 97 ± 27 | 746 ± 91

Table 1: We report mean and standard deviation scores for RAD and baseline methods on DMControl100k and DMControl500k. In both settings, RAD achieves state-of-the-art performance on the majority (5 out of 6) of environments. We selected these 6 environments for benchmarking due to the availability of baseline performance data from CURL Srinivas et al. (2020), PlaNet Hafner et al. (2019), Dreamer Hafner et al. (2020), SAC+AE Yarats et al. (2019), and SLAC Lee et al. (2019). We also show performance data on 15 environments in total in Figure 3. Results are reported as averages across 5 seeds for the 6 main environments.

Figure 3: We benchmark the performance of RAD relative to the best performing pixel-based baseline (CURL) as well as SAC operating on state input on 15 environments in total. RAD matches state SAC performance on the majority (11 out of 15 environments) and performs comparably or better than CURL on all of the environments tested. Results are average values across 3 seeds.

Figure 4: We ablate six common data augmentations on the walker, walk environment by measuring performance on DMControl500k of each permutation of any two data augmentations being performed in sequence. For example, the crop row and grayscale column correspond to the score achieved after applying random crop and then random grayscale to the input images. The vanilla Pixel SAC baseline achieves a score of 42 on DMControl500k, suggesting that nearly all data augmentations improve performance. However, random crop is the most effective data augmentation by a large margin. Random crop alone improves the pixel SAC baseline performance by 22x on the walker, walk environment.

Data-efficiency: Mean scores shown in Table 1 and learning curves in Figure 3 show that data augmentation significantly improves the data-efficiency and performance across the six extensively benchmarked environments compared to existing methods. We summarize the main findings below:

  • RAD is the state-of-the-art algorithm on the majority (5 out of 6) environments on both DMControl100k and DMControl500k benchmarks.

  • RAD improves the performance of pixel SAC by 4x on both DMControl100k and DMControl500k solely through data augmentation without learning a forward model or any other auxiliary task.

  • RAD matches the performance of state-based SAC on the majority of (11 out of 15) DMControl environments tested as shown in Figure 3.

  • Random crop, stand-alone, has the highest impact on final performance relative to all other augmentations, as shown in Figure 4.

Which data augmentations contribute the most? To understand which data augmentations are the most helpful for DMControl, we run RAD with all possible permutations of two data augmentations applied in sequence (e.g. crop followed by grayscale) on the Walker Walk environment and benchmark the scores at 500k environment steps. We choose this environment because the original SAC policy fails entirely and achieves a near-zero score in this regime, which makes it easy to interpret the source of performance improvement. Results shown in Figure 4 suggest that while most data augmentations improve the performance of the base policy, random crop is the most effective augmentation by a large margin. RAD trained with random crop alone achieves the highest score out of all of the possible pair-wise permutations. For this reason, we use random crop as the main augmentation for results shown in Table 1.

Why is random crop so effective? Surprisingly, one data augmentation is powerful enough to transform an agent that hardly learns, to an agent that achieves state-of-the-art performance. To understand how random crop affects learned representations within the convolutional neural network (CNN), we visualize a spatial attention map of the encoder. We denote the activations of the encoder as $A \in \mathbb{R}^{C \times H \times W}$, where $C$ is the number of channels and $H \times W$ is the spatial dimension. Similar to Zagoruyko and Komodakis (2017), we first compute a spatial attention map by mean-pooling the absolute values of the activations across the channel dimension, i.e., $\frac{1}{C}\sum_{c=1}^{C} |A_c|$, followed by a spatial softmax. This enables us to visualize which regions in the image are attended to. Results in Figure 5 visualize the spatial attention maps across a variety of data augmentations.
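A minimal sketch of this attention-map computation, assuming a PyTorch encoder whose convolutional feature maps are available as a (N, C, H, W) tensor:

import torch
import torch.nn.functional as F

def spatial_attention_map(feats):
    # feats: (N, C, H, W) activations from the convolutional encoder
    attn = feats.abs().mean(dim=1)                 # mean of |activations| over channels
    n, h, w = attn.shape
    attn = F.softmax(attn.view(n, h * w), dim=1)   # spatial softmax over the H*W locations
    return attn.view(n, h, w)                      # higher values = more strongly attended regions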

(a) Walker
(b) Cheetah
Figure 5: Spatial attention map of an encoder that shows where the agent focuses in order to make a decision in (a) Walker Walk and (b) Cheetah Run environments. Random crop enables the agent to focus on the robot body and ignore irrelevant scene details, compared to other augmentations as well as the base agent that learns without any augmentation. In addition to the agent, the base cheetah encoder focuses on the stars in the background, which are irrelevant to the task and likely to harm performance. Random crop enables the encoder to capture the agent's state much more clearly compared to other augmentations. The quality of the attention map with random crop suggests that RAD improves the contingency-awareness of the agent (recognizing aspects of the environment that are under the agent's control), thereby improving data-efficiency.
Environment (# train levels) | Pixel PPO | RAD (gray) | RAD (flip) | RAD (rotate) | RAD (random conv) | RAD (color-jitter) | RAD (cutout) | RAD (cutout-color) | RAD (crop)
BigFish (100) | 1.9 ± 0.1 | 1.5 ± 0.3 | 2.3 ± 0.4 | 1.9 ± 0.0 | 1.0 ± 0.1 | 1.0 ± 0.1 | 2.9 ± 0.2 | 2.0 ± 0.2 | 5.4 ± 0.5
BigFish (200) | 4.3 ± 0.5 | 2.1 ± 0.3 | 3.5 ± 0.4 | 1.5 ± 0.6 | 1.2 ± 0.1 | 1.5 ± 0.2 | 3.3 ± 0.2 | 3.5 ± 0.3 | 6.7 ± 0.8
StarPilot (100) | 18.0 ± 0.7 | 10.6 ± 1.4 | 13.1 ± 0.2 | 9.7 ± 1.6 | 7.4 ± 0.7 | 15.0 ± 1.1 | 17.2 ± 2.0 | 22.4 ± 2.1 | 20.3 ± 0.7
StarPilot (200) | 20.3 ± 0.7 | 20.6 ± 1.0 | 20.7 ± 3.9 | 15.7 ± 0.7 | 11.0 ± 1.5 | 20.6 ± 1.1 | 24.5 ± 0.1 | 24.5 ± 1.6 | 24.3 ± 0.1
Jumper (100) | 5.2 ± 0.5 | 5.2 ± 0.1 | 5.2 ± 0.7 | 5.7 ± 0.6 | 5.5 ± 0.3 | 6.1 ± 0.2 | 5.6 ± 0.1 | 5.8 ± 0.6 | 5.1 ± 0.2
Jumper (200) | 6.0 ± 0.2 | 5.6 ± 0.1 | 5.4 ± 0.3 | 5.5 ± 0.1 | 5.2 ± 0.1 | 5.9 ± 0.1 | 5.4 ± 0.1 | 5.6 ± 0.4 | 5.2 ± 0.7

Table 2: We present the generalization results of RAD with different data augmentation methods on three OpenAI ProcGen environments: BigFish, StarPilot, and Jumper. We report the test performance after 20M timesteps as mean ± standard deviation over three runs. RAD benefits from data augmentations such as random crop, cutout, and color jitter, and is able to outperform the baseline PPO even when the baseline is trained on twice the number of training levels.

5.3 Improving Generalization on OpenAI ProcGen

Generalization: We evaluate the generalization ability on three environments from OpenAI Procgen: BigFish, StarPilot, and Jumper (see Figure 6(a) and Appendix B for more detailed environment descriptions) by varying the number of training environments and ablating for different data augmentation methods. We summarize our findings below:

  • As shown in Table 2, various data augmentation methods such as random crop and cutout significantly improve the generalization performance on the BigFish and StarPilot environments (Refer to Appendix A for learning curves).

  • In particular, RAD with random crop achieves 55.8% relative gain over pixel-based PPO on the BigFish environment.

  • RAD trained with 100 training levels outperforms the pixel-based PPO trained with 200 training levels on both BigFish and StarPilot environments. This shows that data augmentation can be more effective in learning generalizable representations compared to simply increasing the number of training environments.

  • In the case of Jumper (a navigation task), the gain from data augmentation is not as significant because the task involves structural generalization to different map layouts and is likely to require recurrent policies Cobbe et al. (2019a).

  • To verify the effects of data augmentations on such environments, we consider a modified version of CoinRun Cobbe et al. (2019b), which corresponds to a simpler version of Jumper. Following the setup in Lee et al. (2020), we train agents on a fixed set of 500 levels with half of the available themes (style of backgrounds, floors, agents, and moving obstacles) and then measure the test performance on 1000 different levels consisting of unseen themes to evaluate generalization across visual changes. As shown in Figure 6(b), data augmentation methods such as random convolution, color jitter, and cutout-color improve the generalization ability of the agent to a greater extent than random crop, suggesting the need for further study of data augmentations in these environments.

(a) ProcGen
(b) Test performance on modified CoinRun
Figure 6: (a) Examples of seen and unseen environments on ProcGen. (b) The test performance under the modified CoinRun. The solid/dashed lines and shaded regions represent the mean and standard deviation, respectively.

6 Discussion

6.1 CURL vs RAD

Both CURL and RAD improve the sample-efficiency of RL agents by enforcing consistencies in the input observations presented to the agent. CURL does this with an explicit instance contrastive loss between an image and its augmented version using the MoCo He et al. (2020) mechanism. On the other hand, RAD does not employ any auxiliary loss and directly trains the RL objective on multiple augmented views of the observations, thereby ensuring consistency across the augmented views implicitly. The performance of RAD matches that of CURL and surpasses it on some of the environments in the DeepMind Control Suite (refer to Figure 3). This suggests the potential conclusion that data augmentation is sufficient for sample-efficient reinforcement learning from pixels. We argue in the following subsection that this conclusion requires more nuance.

6.2 Is Data Augmentation Sufficient for RL from Pixels?

The improved performance of RAD over CURL can be attributed to the following line of thought: While both methods try to improve the sample-efficiency through augmentation consistencies (CURL explicitly, RAD implicitly); RAD outperforms CURL because it only optimizes for what we care about, which is the task reward. CURL, on the other hand, jointly optimizes the reinforcement and contrastive learning objectives. If the metric used to evaluate and compare these methods is the score attained on the task at hand, a method that purely focuses on reward optimization is expected to be better as long as it implicitly ensures similarity consistencies on the augmented views (in this case, just by training the RL objective on different augmentations directly).

However, we believe that a representation learning method like CURL is arguably a more general framework for the usage of data augmentations in reinforcement learning. CURL can be applied even without any task (or environment) reward available. The contrastive learning objective in CURL that ensures consistencies between augmented views is disentangled from the reward optimization (RL) objective and is therefore capable of learning rich semantic representations from high-dimensional observations gathered from random rollouts. Real-world applications of RL might involve performing plenty of interactions (or rollouts) with sparse reward signals, and tasks presented to the agent as image-based goals. In such scenarios, CURL and other representation learning methods are likely to be more important, even though current RL benchmarks are primarily about single or multi-task reward optimization.

Given these subtle considerations, we believe that both RAD and representation learning methods like CURL will be useful tools for an RL practitioner in future research encompassing data-efficient and generalizable RL.

7 Conclusion

In this work, we proposed RAD, a simple plug-and-play module to enhance any reinforcement learning (RL) method using data augmentations. For the first time, we show that data augmentations alone can significantly improve the data-efficiency and generalization of RL methods operating from pixels, without any changes to the underlying RL algorithm, on the DeepMind Control Suite and the OpenAI ProcGen benchmarks respectively. Our implementation is simple and efficient, and has been open-sourced. We hope that the performance gains, implementation ease, and wall-clock efficiency of RAD make it a useful module for future research in data-efficient and generalizable RL methods, and a useful tool for facilitating real-world applications of RL.

References

  • M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G. S. Corrado, A. Davis, J. Dean, M. Devin, et al. (2016) Tensorflow: large-scale machine learning on heterogeneous distributed systems. arXiv preprint arXiv:1603.04467. Cited by: §4.
  • I. Akkaya, M. Andrychowicz, M. Chociej, M. Litwin, B. McGrew, A. Petron, A. Paino, M. Plappert, G. Powell, R. Ribas, et al. (2019) Solving rubik’s cube with a robot hand. arXiv preprint arXiv:1910.07113. Cited by: §1.
  • M. Andrychowicz, F. Wolski, A. Ray, J. Schneider, R. Fong, P. Welinder, B. McGrew, J. Tobin, O. Pieter Abbeel, and W. Zaremba (2017) Hindsight experience replay. In NeurIPS, Cited by: §2.3.1.
  • D. Berthelot, N. Carlini, I. Goodfellow, N. Papernot, A. Oliver, and C. Raffel (2019) MixMatch: a holistic approach to semi-supervised learning. In NeurIPS, Cited by: §1, §2.2.
  • T. Chen, S. Kornblith, M. Norouzi, and G. Hinton (2020) A simple framework for contrastive learning of visual representations. arXiv:2002.05709. Cited by: §1, §2.2.
  • S. Chetlur, C. Woolley, P. Vandermersch, J. Cohen, J. Tran, B. Catanzaro, and E. Shelhamer (2014) Cudnn: efficient primitives for deep learning. arXiv preprint arXiv:1410.0759. Cited by: §4.
  • K. Cobbe, C. Hesse, J. Hilton, and J. Schulman (2019a) Leveraging procedural generation to benchmark reinforcement learning. arXiv preprint arXiv:1912.01588. Cited by: Appendix C, 2nd item, 4th item, §5.1.
  • K. Cobbe, O. Klimov, C. Hesse, T. Kim, and J. Schulman (2019b) Quantifying generalization in reinforcement learning. In ICML, Cited by: §2.4, 5th item, §5.1.
  • E. D. Cubuk, B. Zoph, D. Mane, V. Vasudevan, and Q. V. Le (2019a) Autoaugment: learning augmentation strategies from data. In CVPR, Cited by: §1, §2.1.
  • E. D. Cubuk, B. Zoph, J. Shlens, and Q. V. Le (2019b) RandAugment: practical automated data augmentation with a reduced search space. arXiv:1909.13719. Cited by: §1, §2.1.
  • Y. Duan, J. Schulman, X. Chen, P. L. Bartlett, I. Sutskever, and P. Abbeel (2016) RL²: fast reinforcement learning via slow reinforcement learning. arXiv:1611.02779. Cited by: §1.
  • L. Espeholt, H. Soyer, R. Munos, K. Simonyan, V. Mnih, T. Ward, Y. Doron, V. Firoiu, T. Harley, I. Dunning, et al. (2018) Impala: scalable distributed deep-rl with importance weighted actor-learner architectures. In ICML, Cited by: Appendix C, §5.1.
  • Z. Ghahramani, D. M. Wolpert, and M. I. Jordan (1996) Generalization to local remappings of the visuomotor coordinate transformation. Journal of Neuroscience 16 (21), pp. 7085–7096. Cited by: §1.
  • D. Ha and J. Schmidhuber (2018) World models. In NeurIPS, Cited by: §2.3.2.
  • T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine (2018) Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In ICML, Cited by: §2.3.3, §3, §4, §5.1.
  • D. Hafner, T. Lillicrap, J. Ba, and M. Norouzi (2020) Dream to control: learning behaviors by latent imagination. In ICLR, Cited by: 3rd item, §2.3.2, §5.1, Table 1.
  • D. Hafner, T. Lillicrap, I. Fischer, R. Villegas, D. Ha, H. Lee, and J. Davidson (2019) Learning latent dynamics for planning from pixels. In ICML, Cited by: 3rd item, §5.1, Table 1.
  • K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick (2020) Momentum contrast for unsupervised visual representation learning. In CVPR, Cited by: §1, §2.2, §6.1.
  • O. J. Hénaff, A. Srinivas, A. Razavi, C. Doersch, S. Eslami, and A. van den Oord (2019) Data-efficient image recognition with contrastive predictive coding. arXiv preprint arXiv:1905.09272. Cited by: §1, §2.2.
  • P. Henderson, R. Islam, P. Bachman, J. Pineau, D. Precup, and D. Meger (2018) Deep reinforcement learning that matters. In Thirty-Second AAAI Conference on Artificial Intelligence, Cited by: §1.
  • L. Kaiser, M. Babaeizadeh, P. Milos, B. Osinski, R. H. Campbell, K. Czechowski, D. Erhan, C. Finn, P. Kozakowski, S. Levine, et al. (2020) Model-based reinforcement learning for atari. In ICLR, Cited by: §1, §2.3.2.
  • D. Kalashnikov, A. Irpan, P. Pastor, J. Ibarz, A. Herzog, E. Jang, D. Quillen, E. Holly, M. Kalakrishnan, V. Vanhoucke, et al. (2018) Qt-opt: scalable deep reinforcement learning for vision-based robotic manipulation. arXiv preprint arXiv:1806.10293. Cited by: §1, §1.
  • I. Kostrikov, D. Yarats, and R. Fergus (2020) Image augmentation is all you need: regularizing deep reinforcement learning from pixels. arXiv preprint arXiv:2004.13649. Cited by: 1st item, §2.3.3.
  • A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) ImageNet classification with deep convolutional neural networks. In NeurIPS, Cited by: §1, §2.1.
  • B. M. Lake, T. D. Ullman, J. B. Tenenbaum, and S. J. Gershman (2017) Building machines that learn and think like people. Behavioral and brain sciences 40. Cited by: §1.
  • Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §1, §1, §2.1.
  • A. X. Lee, A. Nagabandi, P. Abbeel, and S. Levine (2019) Stochastic latent actor-critic: deep reinforcement learning with a latent variable model. arXiv preprint arXiv:1907.00953. Cited by: 3rd item, §5.1, Table 1.
  • K. Lee, K. Lee, J. Shin, and H. Lee (2020) Network randomization: a simple technique for generalization in deep reinforcement learning. In ICLR, Cited by: Appendix B, §2.4, 7th item, 5th item.
  • T. P. Lillicrap, J. J. Hunt, A. Pritzel, N. Heess, T. Erez, Y. Tassa, D. Silver, and D. Wierstra (2016) Continuous control with deep reinforcement learning. In ICLR, Cited by: §1.
  • V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, et al. (2015) Human-level control through deep reinforcement learning. Nature 518 (7540), pp. 529–533. Cited by: §1.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, et al. (2019) PyTorch: an imperative style, high-performance deep learning library. In NeurIPS, Cited by: §4.
  • F. Sadeghi and S. Levine (2016) Cad2rl: real single-image flight without a single real image. arXiv preprint arXiv:1611.04201. Cited by: §2.3.1.
  • J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov (2017) Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347. Cited by: §1, §2.3.3, §2.4, §3, §4, §5.1.
  • D. Silver, J. Schrittwieser, K. Simonyan, I. Antonoglou, A. Huang, A. Guez, T. Hubert, L. Baker, M. Lai, A. Bolton, et al. (2017) Mastering the game of go without human knowledge. Nature 550 (7676), pp. 354–359. Cited by: §1.
  • K. Sohn, D. Berthelot, C. Li, Z. Zhang, N. Carlini, E. D. Cubuk, A. Kurakin, H. Zhang, and C. Raffel (2020) FixMatch: simplifying semi-supervised learning with consistency and confidence. arXiv:2001.07685. Cited by: §2.2.
  • A. Srinivas, M. Laskin, and P. Abbeel (2020) CURL: contrastive unsupervised representations for reinforcement learning. arXiv:2004.04136. Cited by: Appendix D, §2.3.3, §4, §5.1, Table 1.
  • R. Sutton (1991) Dyna, an integrated architecture for learning, planning, and reacting. In ACM SIGART Bulletin, Cited by: §2.3.2, §5.1.
  • Y. Tassa, Y. Doron, A. Muldal, T. Erez, Y. Li, D. d. L. Casas, D. Budden, A. Abdolmaleki, J. Merel, A. Lefrancq, et al. (2018) Deepmind control suite. arXiv preprint arXiv:1801.00690. Cited by: 1st item, §5.1.
  • J. Tobin, R. Fong, A. Ray, J. Schneider, W. Zaremba, and P. Abbeel (2017) Domain randomization for transferring deep neural networks from simulation to the real world. In IROS, Cited by: §2.3.1.
  • S. v. d. Walt, S. C. Colbert, and G. Varoquaux (2011) The numpy array: a structure for efficient numerical computation. Computing in Science & Engineering 13 (2), pp. 22–30. Cited by: §4.
  • Q. Xie, Z. Dai, E. Hovy, M. Luong, and Q. V. Le (2019) Unsupervised data augmentation for consistency training. arXiv:1904.12848. Cited by: §1, §2.2.
  • Q. Xie, E. Hovy, M. Luong, and Q. V. Le (2020) Self-training with noisy student improves imagenet classification. In CVPR, Cited by: §1.
  • D. Yarats, A. Zhang, I. Kostrikov, B. Amos, J. Pineau, and R. Fergus (2019) Improving sample efficiency in model-free reinforcement learning from images. arXiv preprint arXiv:1910.01741. Cited by: Appendix D, §5.1, Table 1.
  • S. Zagoruyko and N. Komodakis (2017) Paying more attention to attention: improving the performance of convolutional neural networks via attention transfer. In ICLR, Cited by: §5.2.

Appendix A Additional Results in ProcGen

(a) Train (100)
(b) Test (100)
(c) Train (200)
(d) Test (200)
Figure 7: Learning curves of PPO and RAD agents trained with (a/b) 100 and (c/d) 200 training levels on StarPilot. The solid line and shaded regions represent the mean and standard deviation, respectively, across three runs.
(a) Train (100)
(b) Test (100)
(c) Train (200)
(d) Test (200)
Figure 8: Learning curves of PPO and RAD agents trained with (a/b) 100 and (c/d) 200 training levels on Bigfish. The solid line and shaded regions represent the mean and standard deviation, respectively, across three runs.
(a) Train (100)
(b) Test (100)
(c) Train (200)
(d) Test (200)
Figure 9: Learning curves of PPO and RAD agents trained with (a/b) 100 and (c/d) 200 training levels on Jumper. The solid line and shaded regions represent the mean and standard deviation, respectively, across three runs.

Figure 10: Learning curves of PPO and RAD in the modified CoinRun. The solid line and shaded regions represent the mean and standard deviation, respectively, across three runs.

Appendix B Environment Descriptions in ProcGen

BigFish. In this environment, the agent starts as a small fish and the goal is to eat fish smaller than itself. The agent receives a small reward for eating a smaller fish and a large reward when it becomes bigger than all other fish. The spawn timing and positions of the fish and the style of the background vary across levels.

StarPilot. A simple side-scrolling shooter game, where the agent receives reward by avoiding enemies. The spawn timing of all enemies and obstacles, along with their corresponding types, varies across levels.

Jumper. An open-world environment where the goal is to find a carrot that is randomly located on the map. The style of background, locations of enemies, and map structure vary across levels.

Modified CoinRun. In this task, an agent is located at the leftmost side of the map and the goal is to collect the coin located at the rightmost side of the map within 1,000 timesteps. The agent observes its surrounding environment from a third-person point of view, where the agent is always located at the center of the observation. Similar to [28], half of the available themes (i.e., styles of backgrounds, floors, agents, and moving obstacles) are utilized for training.

Appendix C Implementation Details for ProcGen

For ProcGen experiments, we follow the hyperparameters proposed in [7], which are empirically shown to be effective. Specifically, we use the CNN architecture found in IMPALA [12] as the policy network, and train the agents using Proximal Policy Optimization (PPO) with the following hyperparameters:


Hyperparameter | Value
Observation rendering | (64, 64)
Discount |
GAE parameter λ | 0.95
# of timesteps per rollout | 256
# of minibatches per rollout | 8
Entropy bonus | 0.1
PPO clip range | 0.2
Reward normalization | Yes
# of workers | 1
# of environments per worker | 64
Total timesteps | 20M
LSTM | No
Frame stack | No
Optimizer | Adam
Learning rate |
Table 3: Hyperparameters used for ProcGen experiments.

Appendix D Implementation Details for DMControl

For DMControl experiments, we utilize the same encoder architecture as in [36] which is similar to the architecture in [43]. We show a full list of hyperparameters for DMControl experiments in Table 4.


Hyperparameter | Value
Random crop | True
Observation rendering | (100, 100)
Observation downsampling | (84, 84)
Replay buffer size |
Initial steps |
Stacked frames |
Action repeat | 2 (finger, spin; walker, walk); 8 (cartpole, swingup); 4 (otherwise)
Hidden units (MLP) |
Evaluation episodes |
Optimizer | Adam
Learning rate | (cheetah, run); (otherwise)
Learning rate (α) |
Batch size | (cheetah); 128 (rest)
Q-function EMA |
Critic target update freq |
Convolutional layers |
Number of filters |
Non-linearity | ReLU
Encoder EMA |
Latent dimension |
Discount | 0.99
Initial temperature |
Table 4: Hyperparameters used for DMControl experiments. Most hyperparameter values are unchanged across environments, with the exception of action repeat, learning rate, and batch size.

Appendix E Code for Select Augmentations

import torch


def random_crop(imgs, size=84):
    # imgs: (N, C, H, W); frames are stacked along the channel axis, so a single
    # crop window per sample is automatically consistent across the frame stack.
    n, c, h, w = imgs.shape
    w1 = torch.randint(0, w - size + 1, (n,))
    h1 = torch.randint(0, h - size + 1, (n,))
    cropped = torch.empty((n, c, size, size),
        dtype=imgs.dtype, device=imgs.device)
    for i, (img, w11, h11) in enumerate(zip(imgs, w1.tolist(), h1.tolist())):
        cropped[i][:] = img[:, h11:h11 + size, w11:w11 + size]
    return cropped


def random_cutout(imgs, min_cut=4, max_cut=24):
    # In place: inserts one randomly sized, randomly colored rectangle per sample.
    n, c, h, w = imgs.shape
    w_cut = torch.randint(min_cut, max_cut + 1, (n,))  # random rectangle widths
    h_cut = torch.randint(min_cut, max_cut + 1, (n,))  # and heights
    fills = torch.randint(0, 255, (n, c, 1, 1), device=imgs.device)  # assume uint8
    for img, wc, hc, fill in zip(imgs, w_cut.tolist(), h_cut.tolist(), fills):
        w1 = int(torch.randint(0, w - wc + 1, (1,)))  # uniform over the interior
        h1 = int(torch.randint(0, h - hc + 1, (1,)))
        img[:, h1:h1 + hc, w1:w1 + wc] = fill
    return imgs


def random_flip(imgs, p=0.5):
    # In place: flips each sample (all of its stacked frames) left-right with probability p.
    n, _, _, _ = imgs.shape
    flip_mask = torch.rand(n, device=imgs.device) < p
    imgs[flip_mask] = imgs[flip_mask].flip([3])  # flip along the width axis
    return imgs
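
As a quick sanity check, the snippets above can be exercised on a dummy batch; the shapes below assume 3 stacked RGB frames rendered at 100x100 pixels, as in our DMControl setup.

import torch

obs = torch.randint(0, 256, (8, 9, 100, 100), dtype=torch.uint8)  # 8 samples, 3 stacked RGB frames
cropped = random_crop(obs, size=84)       # -> (8, 9, 84, 84)
occluded = random_cutout(obs.clone())     # cutout and flip modify their input in place,
flipped = random_flip(obs.clone())        # so pass copies if the batch will be reused
print(cropped.shape, occluded.shape, flipped.shape)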

Appendix F Time-efficiency of Data Augmentation

The primary gain of our data augmentation modules is enabling efficient augmentation of stacked frame inputs in the minibatch setting. Since the augmentations must be applied randomly across the batch but consistently across the frame stack, traditional frameworks like TensorFlow and PyTorch, which focus on augmenting single-frame static datasets, are unsuitable for this task. We further compare the wall-clock cost of our augmentation modules against the corresponding PyTorch operations in Table 5.

Augmentation | Ours | PyTorch
Crop | 31.8 | 33.5
Grayscale | 15.6 | 51.2
Cutout | 36.6 | -
Cutout color | 45.2 | -
Flip | 4.9 | 37.0
Rotate | 46.5 | 62.4
Random Conv. | 45.8 | -

Table 5: We compare the data augmentation speed between the RAD augmentation modules and performing the same augmentations in PyTorch. We calculate the number of additional minutes required to perform 100k training steps. On average, the RAD augmentations are nearly 2x faster than augmentations accessed through the native PyTorch API. Additionally, since the PyTorch API is meant for processing single-frame images, it is not designed to apply augmentations consistently across the frame stack but randomly across the batch.