Mask-based Latent Reconstruction for Reinforcement Learning

by Tao Yu, et al.

For deep reinforcement learning (RL) from pixels, learning effective state representations is crucial for achieving high performance. However, in practice, limited experience and high-dimensional input prevent effective representation learning. To address this, motivated by the success of masked modeling in other research fields, we introduce mask-based reconstruction to promote state representation learning in RL. Specifically, we propose a simple yet effective self-supervised method, Mask-based Latent Reconstruction (MLR), to predict the complete state representations in the latent space from observations with spatially and temporally masked pixels. MLR enables better use of context information when learning state representations, making them more informative, which facilitates RL agent training. Extensive experiments show that MLR significantly improves the sample efficiency of RL and outperforms the state-of-the-art sample-efficient RL methods on multiple continuous control benchmark environments.




1 Introduction

Learning effective state representations is crucial for reinforcement learning (RL) from vision signals (where a sequence of images is the input to an RL network), as in the DeepMind Control Suite (Tassa et al., 2018). Inspired by the success of masked pre-training in Natural Language Processing (NLP) (Devlin et al., 2018; Radford et al., 2018, 2019; Brown et al., 2020) and Computer Vision (CV) (Bao et al., 2021; He et al., 2021; Xie et al., 2021), we make the first attempt to explore the idea of mask-based reconstruction in RL.

Masked pre-training exploits the reconstruction of masked word embeddings or pixels to promote feature learning in NLP or CV. However, it is not straightforwardly applicable to RL, for the following reasons. 1) RL agents learn from interactions with environments, where the experienced states vary as the policy network is updated. Collecting additional rollouts for pre-training is often costly in real-world applications. Besides, it is challenging to learn effective state representations without awareness of the learned policy. 2) Vision signals commonly have high information density and may contain distractions and redundancies for policy learning. Thus, for RL, performing reconstruction in the original (pixel) space is not as necessary as it is in the CV or NLP domain.

Based on the analysis above, we study masked modeling tailored to vision-based RL. We present Mask-based Latent Reconstruction (MLR), a simple yet effective self-supervised method, to better learn state representations in RL. In contrast to treating masked modeling as a pre-training task as in CV and NLP, MLR is an auxiliary objective optimized together with the policy learning objectives. In this way, the coordination between representation learning and policy learning is taken into account within a joint training framework. Another key difference from vision/language research is that we reconstruct masked pixels in the latent space instead of the input space, taking the state representations (i.e., features) inferred from the original unmasked frames as the reconstruction targets. This avoids the unnecessary reconstruction of pixel-level details and further facilitates the coordination between representation learning and policy learning, because the state representations are directly optimized.

Consecutive frames are highly correlated. In MLR, we exploit this property to make the learned state representations more informative, predictive and consistent over both spatial and temporal dimensions. Specifically, we randomly mask a portion of space-time cubes in the input observation sequence (i.e., video clip) and reconstruct the missing contents in the latent space. In this way, similar to the spatial reconstruction for images in (He et al., 2021; Xie et al., 2021), MLR enhances the agent's awareness of the global context of the entire input observations and encourages the state representations to be predictive in both spatial and temporal dimensions. The predictive global information is thus encoded into each frame-level state representation, which improves representation learning and further facilitates policy learning.

Beyond proposing an effective masked modeling method, we also conduct a systematic empirical study of masking and reconstruction practices that are well suited to RL. First, we study the influence of masking strategies by comparing spatial masking, temporal masking and space-time masking. Second, we investigate the differences between masking and reconstructing in the pixel space and in the latent space. Finally, we study how to effectively add reconstruction supervision in the latent space.

Our contributions are summarized below:


  • We introduce the idea of enhancing representation learning by mask-based reconstruction to RL for improving the sample efficiency. We integrate the mask-based reconstruction into RL training with an auxiliary objective, obviating the need for collecting additional rollouts for pre-training and helping the coordination between representation learning and policy learning in RL.

  • We propose Mask-based Latent Reconstruction (MLR), a self-supervised masked modeling method to improve the state representations for RL. Tailored to RL, we propose to randomly mask space-time cubes in the pixel space and reconstruct the missing content from the unmasked pixels in the latent space. This is demonstrated to be effective for improving the sample efficiency on multiple continuous benchmark environments.

  • A systematic empirical study is conducted to investigate good practices for the masking and reconstruction operations of MLR in RL, demonstrating the effectiveness of our proposed designs.

2 Related Work

2.1 Representation Learning for RL

Reinforcement learning from vision signals is of high practical value in real-world applications such as robotics and video game AI. However, such high-dimensional observations may contain distractions or redundant information, making it considerably challenging for RL agents to learn effective representations (Shelhamer et al., 2017). Many prior works address this challenge by taking advantage of self-supervised learning to promote the representation learning of states in RL. A popular approach is to jointly optimize policy learning objectives and auxiliary objectives such as pixel reconstruction (Shelhamer et al., 2017; Yarats et al., 2019), reward prediction (Jaderberg et al., 2016; Shelhamer et al., 2017), bisimulation (Zhang et al., 2021), dynamics prediction (Shelhamer et al., 2017; Guo et al., 2020; Lee et al., 2020a, b; Schwarzer et al., 2021a; Yu et al., 2021) and contrastive learning of instance discrimination (Laskin et al., 2020a) or (spatial-)temporal discrimination (Oord et al., 2018; Anand et al., 2019; Stooke et al., 2020; Zhu et al., 2020; Mazoure et al., 2020). Another way to acquire good representations is to pre-train the state encoder on the original observations before policy learning. This requires additional offline sample collection or early access to the environments (Hansen et al., 2019; Stooke et al., 2020; Liu and Abbeel, 2021b, a; Schwarzer et al., 2021b), which is not fully consistent with the principle of sample efficiency in practice. This work aims to design a more effective auxiliary task to improve the learned state representations for sample-efficient RL.

2.2 Sample-Efficient Reinforcement Learning

Collecting rollouts from interactions with the environment is commonly costly, especially in the real world, which makes the sample efficiency of RL algorithms a widespread concern. To improve the sample efficiency of vision-based RL (i.e., RL from pixel observations), recent works design auxiliary tasks to explicitly improve the learned representations (Yarats et al., 2019; Laskin et al., 2020a; Lee et al., 2020a, b; Zhu et al., 2020; Liu et al., 2021; Schwarzer et al., 2021a; Ye et al., 2021; Yu et al., 2021), or adopt data augmentation techniques, such as random cropping or shifting, to increase the diversity of the training data derived from collected samples (Yarats et al., 2021; Laskin et al., 2020b). Besides, some model-based methods learn (world) models in the pixel space (Kaiser et al., 2020) or latent space (Hafner et al., 2019, 2020a, 2020b; Ye et al., 2021), and perform planning, imagination or policy learning based on the learned models. In this work, we focus on the auxiliary-task line of research.

2.3 Masked Language/Image Modeling

Masked Language Modeling (MLM) (Devlin et al., 2018) and its autoregressive variants (Radford et al., 2018, 2019; Brown et al., 2020) achieve significant success in NLP and have influenced other domains. MLM masks a portion of word tokens in the input sentence and trains the model to predict the masked tokens, which has been shown to be generally effective in learning language representations for various downstream tasks. In computer vision (CV), masked image modeling (MIM), analogous to MLM, learns representations for images/videos by pre-training the neural network to reconstruct masked pixels from visible ones. As an early exploration, the Context Encoder (Pathak et al., 2016) applies this idea to Convolutional Neural Networks (CNNs), training a CNN with a masked region inpainting task. With the recent popularity of Transformer-based architectures, a series of works (Chen et al., 2020; Bao et al., 2021; He et al., 2021; Xie et al., 2021; Wei et al., 2021) revisit the idea of MIM and show impressive performance in learning representations for vision tasks. Inspired by MLM and MIM, we explore masked modeling for RL to exploit the high correlation in vision data and improve the agent's awareness of global-scope dynamics when learning state representations. Most importantly, we propose to predict the masked contents in the latent space, instead of the pixel space as in the aforementioned MIM works, which better coordinates representation learning and policy learning in RL.

Figure 1: The framework of our method MLR. We perform random spatial-temporal masking (i.e., cube masking) on the sequence of consecutive observations in the pixel space. The masked observations are encoded into latent states with an online encoder. We further introduce a predictive latent decoder to decode/predict the latent states conditioned on the corresponding action sequence and a temporal positional embedding. Our method trains the networks to reconstruct the missing contents in an appropriate latent space, using a cosine-similarity-based distance between the predicted features of the reconstructed states and the target features inferred from the original observations by momentum networks.

3 Approach

3.1 Background

Vision-based RL aims to learn policies from interactions with environments whose observations are composed of pixels. The learning process corresponds to a Markov Decision Process (MDP) (Bellman, 1957; Kaelbling et al., 1998). Generally, at timestep $t$, given the observation $o_t$ from the environment, the RL agent responds by taking an action $a_t$. When the environment receives $a_t$, it transitions to the next observation $o_{t+1}$ with probability $p(o_{t+1} \mid o_t, a_t)$ and returns a reward $r_t$. Following the common practice in (Mnih et al., 2013), we encode several consecutive observations as a state $s_t$ (Bellman, 1957), which is a feature vector in the latent space. The reward function and the transition dynamics can then be written as $r(s_t, a_t)$ and $p(s_{t+1} \mid s_t, a_t)$, respectively. The objective of RL is to learn a policy $\pi(a_t \mid s_t)$ that maximizes the cumulative discounted return $\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right]$, where $\gamma \in [0, 1)$ is the discount factor.
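As a small illustration of the cumulative discounted return, the return of a finite reward sequence can be accumulated backwards via $G_t = r_t + \gamma G_{t+1}$. The sketch below is ours (the function name `discounted_return` is hypothetical) and ignores the bootstrapped value estimates that SAC uses in practice.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * r_t for a finite sequence of rewards."""
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma is expected to lie in [0, 1)")
    total = 0.0
    # Accumulate backwards: G_t = r_t + gamma * G_{t+1}, with G_T = 0.
    for r in reversed(rewards):
        total = r + gamma * total
    return total


print(discounted_return([1.0, 1.0, 1.0], 0.5))  # 1.75 = 1 + 0.5 + 0.25
```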

Our proposed method is, in principle, applicable to different RL algorithms. Following the common practice in prior sample-efficient RL studies (Yarats et al., 2019; Lee et al., 2020a; Laskin et al., 2020a; Yarats et al., 2021; Laskin et al., 2020b; Yu et al., 2021), we demonstrate its effectiveness on top of a strong model-free RL agent, Soft Actor Critic (SAC). A detailed introduction to SAC can be found in Appendix A.

3.2 Mask-based Latent Reconstruction

Mask-based Latent Reconstruction (MLR) is an auxiliary objective that promotes representation learning from pixel-based observations, towards sample-efficient RL. The core idea of MLR is to facilitate state representation learning by reconstructing spatially and temporally masked pixels in the latent space. This mechanism enables better use of context information when learning state representations, further enhancing the RL agent's understanding of vision signals. We illustrate the overall framework of MLR in Figure 1 and elaborate on it below.

Framework. In MLR, as shown in Figure 1, we mask a portion of pixels in the input observation sequence along its spatial and temporal dimensions. We encode the masked sequence and the original sequence of observations into states with an online encoder and a momentum encoder, respectively. Taking the states encoded from the original sequence as the targets, we perform predictive reconstruction from the states corresponding to the masked sequence. We add reconstruction supervision between the prediction results and the targets in the decoded latent space. The processes of masking, encoding, decoding and reconstruction are described in detail below.

(i) Masking. Given an observation sequence of $K$ timesteps $\tau_K^o = \{o_t, o_{t+1}, \dots, o_{t+K-1}\}$ with the shape of $K \times H \times W$, the frames in the observations are stacked into a cuboid. As illustrated in Figure 2, we divide the cuboid into regular non-overlapping cubes with the shape of $k \times h \times w$. We then randomly mask a portion of the cubes following a uniform distribution and obtain a masked observation sequence $\tilde{\tau}_K^o = \{\tilde{o}_t, \tilde{o}_{t+1}, \dots, \tilde{o}_{t+K-1}\}$. Following (Yu et al., 2021), we perform stochastic image augmentation (e.g., random crop and intensity) on each masked observation in $\tilde{\tau}_K^o$. The objective of MLR is to predict the state representations of the unmasked observation sequence from the masked one in the latent space.
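The masking step can be sketched in plain Python as follows. This is a minimal illustration under our own assumptions, not the authors' code: frames are single-channel $K \times H \times W$ nested lists, masked pixels are filled with zeros, the cube size divides the cuboid evenly, and exactly round(ratio × number of cubes) cubes are drawn uniformly without replacement. `cube_mask` and `apply_mask` are hypothetical names, and the subsequent random crop/intensity augmentation is omitted.

```python
import random


def cube_mask(K, H, W, k, h, w, mask_ratio, rng=random):
    """Return a K x H x W nested list of booleans (True = masked pixel).

    The cuboid is split into non-overlapping k x h x w cubes, and
    round(mask_ratio * num_cubes) cubes are sampled uniformly at random.
    """
    if K % k or H % h or W % w:
        raise ValueError("cube size must evenly divide the cuboid")
    nk, nh, nw = K // k, H // h, W // w
    num_cubes = nk * nh * nw
    # Uniformly choose which cubes (by flat index) are masked.
    masked = set(rng.sample(range(num_cubes), round(mask_ratio * num_cubes)))
    mask = [[[False] * W for _ in range(H)] for _ in range(K)]
    for t in range(K):
        for y in range(H):
            for x in range(W):
                cube_id = ((t // k) * nh + (y // h)) * nw + (x // w)
                mask[t][y][x] = cube_id in masked
    return mask


def apply_mask(frames, mask, fill=0.0):
    """Replace masked pixels of a K x H x W sequence with `fill`."""
    return [[[fill if mask[t][y][x] else frames[t][y][x]
              for x in range(len(frames[t][y]))]
             for y in range(len(frames[t]))]
            for t in range(len(frames))]
```

For example, a 4 × 4 × 4 sequence split into 2 × 2 × 2 cubes has 8 cubes; at a 50% ratio, 4 cubes (32 of the 64 pixels) are masked, and every pixel inside a cube shares the same mask value.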

Figure 2: Illustration of our cube masking. We divide the input observation sequence into non-overlapping cubes ($k \times h \times w$). The example shows an observation sequence of $K$ timesteps with a spatial size of $H \times W$.

(ii) Encoding. We adopt two encoders to learn state representations from masked observations and original observations, respectively. A regular CNN-based encoder $f$ encodes each masked observation $\tilde{o}_{t+i}$ into its corresponding state $\tilde{s}_{t+i}$. After encoding, we obtain a sequence of masked latent states $\tilde{\tau}_K^s = \{\tilde{s}_t, \tilde{s}_{t+1}, \dots, \tilde{s}_{t+K-1}\}$ for the masked observations. The parameters $\theta_f$ of this encoder are updated end-to-end by gradient back-propagation, so we call it the “online” encoder. The state representations inferred from the original observations are taken as the targets of the reconstruction described below. To make them more robust, inspired by (Laskin et al., 2020a; Schwarzer et al., 2021a; Yu et al., 2021), we use another encoder $\bar{f}$ to encode the original observations into target states $\bar{s}_{t+i}$. This encoder, called the “Momentum Encoder” in Figure 1, has the same architecture as the online encoder, and its parameters $\bar{\theta}_f$ are updated by an exponential moving average (EMA) of the online encoder weights with the momentum coefficient $m \in [0, 1)$:

$$\bar{\theta}_f \leftarrow m \, \bar{\theta}_f + (1 - m) \, \theta_f \tag{1}$$
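The momentum update is a parameter-wise exponential moving average. A minimal sketch, assuming parameters are stored in a dict mapping names to floats (a stand-in for tensors); in a deep learning framework the same rule would be applied to each parameter tensor without gradient tracking. `ema_update` is a hypothetical helper name.

```python
def ema_update(target_params, online_params, m):
    """In-place EMA: theta_bar <- m * theta_bar + (1 - m) * theta.

    target_params: momentum (target) network parameters, updated in place.
    online_params: online network parameters with the same keys.
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError("momentum coefficient m must lie in [0, 1]")
    for name, value in online_params.items():
        target_params[name] = m * target_params[name] + (1.0 - m) * value
    return target_params
```

A larger $m$ makes the target encoder change more slowly, which keeps the reconstruction targets stable while the online encoder is trained.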


(iii) Decoding. Similar to the encoder designs in (He et al., 2021; Xie et al., 2021), the online encoder in MLR implicitly reconstructs the representations of masked contents from visible ones, so the predictive global information is already implicitly contained in its outputs. To apply the reconstruction loss between the state representations inferred from masked and original observations, we need to compare them in a unified latent space. To this end, we use a Transformer-based latent decoder $\phi$ to refine the outputs of the online encoder via global message passing, where the actions and temporal information serve as context. Through this process, the implicitly predicted information is “passed” to the corresponding states so that reconstruction losses can be applied.

As shown in Figure 3, the input tokens of the latent decoder consist of both the masked state sequence $\tilde{\tau}_K^s$ (i.e., state tokens) and the corresponding action sequence $\tau_K^a = \{a_t, a_{t+1}, \dots, a_{t+K-1}\}$ (i.e., action tokens). Each action token is embedded by an embedding layer as a feature vector $\hat{a}_{t+i}$ with the same dimension as the state tokens. We integrate relative positional embeddings as in (Vaswani et al., 2017) to encode relative temporal position into both state and action tokens. Notably, the state and action tokens at the same timestep share the same positional embedding $p_{t+i}$. Thus, the inputs of the latent decoder can be written as:

$$x = [\tilde{s}_t, \hat{a}_t, \tilde{s}_{t+1}, \hat{a}_{t+1}, \dots, \tilde{s}_{t+K-1}, \hat{a}_{t+K-1}] + [p_t, p_t, p_{t+1}, p_{t+1}, \dots, p_{t+K-1}, p_{t+K-1}] \tag{2}$$
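The construction of the decoder inputs can be sketched as follows. Assumptions (ours): the action embedding layer has already mapped each action to a vector of the state-token dimension, positions are indexed from the start of the sampled sequence, and the positional embedding uses the sinusoidal form of Vaswani et al. (2017); the paper's exact embedding may differ. `sinusoidal_embedding` and `build_decoder_inputs` are hypothetical names.

```python
import math


def sinusoidal_embedding(pos, dim):
    """Sinusoidal positional embedding (Vaswani et al., 2017) for one position."""
    emb = []
    for i in range(dim):
        angle = pos / (10000 ** (2 * (i // 2) / dim))
        # Even dimensions use sine, odd dimensions use cosine.
        emb.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return emb


def build_decoder_inputs(states, action_embs):
    """Interleave [s_t, a_t, s_{t+1}, a_{t+1}, ...]; both tokens of step i share p_i."""
    if len(states) != len(action_embs):
        raise ValueError("expected one action token per state token")
    tokens = []
    for i, (s, a) in enumerate(zip(states, action_embs)):
        p = sinusoidal_embedding(i, len(s))
        tokens.append([sv + pv for sv, pv in zip(s, p)])  # state token
        tokens.append([av + pv for av, pv in zip(a, p)])  # action token
    return tokens
```

Sharing one embedding per timestep lets attention associate each state token with the action taken at the same step.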


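The input token sequence is passed through a Transformer encoder (Vaswani et al., 2017) consisting of $L$ attention layers. Each layer is composed of a Multi-Headed Self-Attention (MSA) layer (Vaswani et al., 2017), layer normalization (LN) (Ba et al., 2016) and multilayer perceptron (MLP) blocks. With $z_0 = x$, the process can be described as follows:

$$z'_l = \mathrm{MSA}(\mathrm{LN}(z_{l-1})) + z_{l-1}, \quad l = 1, \dots, L \tag{3}$$

$$z_l = \mathrm{MLP}(\mathrm{LN}(z'_l)) + z'_l, \quad l = 1, \dots, L \tag{4}$$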
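To make the residual structure of one decoder layer concrete, here is a deliberately simplified pure-Python layer in the pre-normalization form $z' = \mathrm{MSA}(\mathrm{LN}(z)) + z$, $z_{out} = \mathrm{MLP}(\mathrm{LN}(z')) + z'$. It is a sketch under stated simplifications, not the MLR decoder: a single attention head stands in for multi-head attention, there is no output projection, biases, or learnable LN affine parameters, and the MLP uses ReLU. All function and parameter names are hypothetical.

```python
import math


def layer_norm(x, eps=1e-5):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def matvec(W, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]


def softmax(scores):
    mx = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - mx) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def self_attention(tokens, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention (stand-in for MSA)."""
    q = [matvec(Wq, t) for t in tokens]
    k = [matvec(Wk, t) for t in tokens]
    v = [matvec(Wv, t) for t in tokens]
    scale = math.sqrt(len(q[0]))
    out = []
    for qi in q:
        w = softmax([sum(a * b for a, b in zip(qi, kj)) / scale for kj in k])
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return out


def mlp(x, W1, W2):
    hidden = [max(0.0, h) for h in matvec(W1, x)]  # ReLU for simplicity
    return matvec(W2, hidden)


def decoder_layer(tokens, params):
    """One pre-LN residual block; Wv and W2 must map back to the token dimension."""
    normed = [layer_norm(t) for t in tokens]
    attn = self_attention(normed, params["Wq"], params["Wk"], params["Wv"])
    z_prime = [[a + b for a, b in zip(at, t)] for at, t in zip(attn, tokens)]
    return [[m + z for m, z in zip(mlp(layer_norm(zp), params["W1"], params["W2"]), zp)]
            for zp in z_prime]
```

With all weights set to zero, both residual branches output zeros and the layer returns its input unchanged, which is a quick sanity check of the residual wiring.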


The output tokens of the latent decoder corresponding to the state tokens, denoted $\hat{\tau}_K^s = \{\hat{s}_t, \hat{s}_{t+1}, \dots, \hat{s}_{t+K-1}\}$, are the predictive reconstructions of the latent representations inferred from the original observations. We detail the reconstruction loss between these predictions and the corresponding targets below.

Figure 3: Illustration of predictive latent decoder.

(iv) Reconstruction loss. Motivated by the success of BYOL (Grill et al., 2020) in self-supervised learning, we use an asymmetric architecture to compute the distance between the predicted/reconstructed latent states and the target states, similar to (Schwarzer et al., 2021a; Yu et al., 2021). For the outputs of the latent decoder, we use a projection head $g$ and a prediction head $q$ to obtain the final prediction $\hat{y}_{t+i} = q(g(\hat{s}_{t+i}))$ corresponding to $\hat{s}_{t+i}$. For the encoded results of the original observations, we use a momentum projection head $\bar{g}$ whose weights are updated with an EMA of the weights of the online projection head $g$; the two projection heads have the same architecture. The outputs of the momentum projection head, i.e., $\bar{y}_{t+i} = \bar{g}(\bar{s}_{t+i})$, are the final reconstruction targets. Following (Grill et al., 2020), we apply a stop-gradient operation $\mathrm{sg}(\cdot)$ to the targets, as illustrated in Figure 1, to avoid model collapse.

The objective of MLR is to make each final prediction $\hat{y}_{t+i}$ as close as possible to its corresponding target $\bar{y}_{t+i}$. To this end, we define the reconstruction loss of MLR via the cosine similarity between $\hat{y}_{t+i}$ and $\bar{y}_{t+i}$:

$$\mathcal{L}_{mlr} = 1 - \frac{1}{K} \sum_{i=0}^{K-1} \frac{\hat{y}_{t+i}}{\lVert \hat{y}_{t+i} \rVert_2} \cdot \frac{\mathrm{sg}(\bar{y}_{t+i})}{\lVert \mathrm{sg}(\bar{y}_{t+i}) \rVert_2} \tag{5}$$
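The reconstruction loss averages the cosine similarity between each prediction and its target over the $K$ timesteps and subtracts it from 1, so it ranges from 0 (all pairs aligned) to 2 (all pairs opposite). A framework-free sketch follows; `mlr_loss` is a hypothetical name, and because plain Python has no autograd, the stop-gradient on the targets is only indicated in a comment (in PyTorch, for instance, the targets would be detached).

```python
import math


def normalize(v, eps=1e-12):
    """L2-normalize a vector, guarding against division by zero."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / max(n, eps) for x in v]


def mlr_loss(preds, targets):
    """L = 1 - (1/K) * sum_i cos(y_hat_i, y_bar_i).

    `targets` are treated as constants (stop-gradient); in an autodiff
    framework they would be detached before this call.
    """
    if not preds or len(preds) != len(targets):
        raise ValueError("need exactly one target per prediction")
    cos_sum = 0.0
    for p, t in zip(preds, targets):
        cos_sum += sum(a * b for a, b in zip(normalize(p), normalize(t)))
    return 1.0 - cos_sum / len(preds)
```

Because only directions are compared, the loss is invariant to the scale of the projected features.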


The loss $\mathcal{L}_{mlr}$ is used to update the parameters of the online networks, including the encoder $f$, the predictive latent decoder $\phi$, the projection head $g$ and the prediction head $q$. With this self-supervised auxiliary objective, the state representations learned by the encoder become more informative and can thus further facilitate policy learning.


The proposed MLR is an auxiliary task optimized together with policy learning. Thus, the overall loss function $\mathcal{L}_{total}$ for RL agent training is:

$$\mathcal{L}_{total} = \mathcal{L}_{rl} + \lambda \mathcal{L}_{mlr} \tag{6}$$

where $\mathcal{L}_{rl}$ and $\mathcal{L}_{mlr}$ are the loss functions of the base RL agent (e.g., SAC (Haarnoja et al., 2018)) and the proposed mask-based latent reconstruction, respectively, and $\lambda$ is a hyperparameter balancing the two terms. Notably, a vision-based RL agent commonly consists of two parts, i.e., the (state) representation network (i.e., encoder) and the policy learning network. The encoder of MLR serves as the representation network that encodes observations into state representations for RL training; the latent decoder is not used there, because RL training takes unmasked observations as inputs. More details can be found in Appendix B.

4 Experiment

4.1 Setup

We evaluate MLR on the most popular vision-based continuous control benchmark, the DeepMind Control Suite (DMControl) (Tassa et al., 2018). Following representative previous works (Laskin et al., 2020a; Yarats et al., 2021; Laskin et al., 2020b; Yu et al., 2021), we choose six commonly used environments from DMControl for evaluation, i.e., Finger, spin (Fing.spin); Cartpole, swingup (Cart.swin); Reacher, easy (Reac.easy); Cheetah, run (Chee.run); Walker, walk (Walk.walk) and Ball in cup, catch (Ball.catch). We measure the performance of RL agents by the mean and median scores over 10 episodes at 100k and 500k environment steps, referred to as the DMControl-100k and DMControl-500k benchmarks, respectively. DMControl-100k measures sample efficiency, while DMControl-500k measures asymptotic performance. The score of each environment ranges from 0 to 1000 (Tassa et al., 2018). Unless otherwise specified, Soft Actor Critic (SAC) (Haarnoja et al., 2018) is taken as the base agent, as in many previous works (Yarats et al., 2019; Lee et al., 2020a; Laskin et al., 2020a; Yarats et al., 2021; Laskin et al., 2020b; Yu et al., 2021).

In our experiments, we denote the base agent trained only with the RL loss $\mathcal{L}_{rl}$ (as in Equation 6) as Baseline, and the base agent augmented with our proposed MLR simply as MLR, for brevity. As shown in Equation 6, we set a weight $\lambda$ to balance $\mathcal{L}_{rl}$ and $\mathcal{L}_{mlr}$ so that the gradients of the two loss terms lie in a similar range, and we choose its value empirically. In MLR, by default, we set the length of a sampled trajectory to 16, the mask ratio to 50%, and the size of the masked cube ($k \times h \times w$) to a fixed default, except for Cart.swin and Reac.easy, which use a different cube size due to their large motion range. More implementation details can be found in Appendix B.

100k Step Scores      PlaNet      Dreamer     SAC+AE      SLAC        CURL        DrQ         PlayVirtual Baseline    MLR
Finger, spin          136 ± 216   341 ± 70    740 ± 64    693 ± 141   767 ± 56    901 ± 104   915 ± 49    853 ± 112   907 ± 58
Cartpole, swingup     297 ± 39    326 ± 27    311 ± 11    -           582 ± 146   759 ± 92    816 ± 36    784 ± 63    806 ± 48
Reacher, easy         20 ± 50     314 ± 155   274 ± 14    -           538 ± 233   601 ± 213   785 ± 142   593 ± 118   866 ± 103
Cheetah, run          138 ± 88    235 ± 137   267 ± 24    319 ± 56    299 ± 48    344 ± 67    474 ± 50    399 ± 80    482 ± 38
Walker, walk          224 ± 48    277 ± 12    394 ± 22    361 ± 73    403 ± 24    612 ± 164   460 ± 173   424 ± 281   643 ± 114
Ball in cup, catch    0 ± 0       246 ± 174   391 ± 82    512 ± 110   769 ± 43    913 ± 53    926 ± 31    648 ± 287   933 ± 16
Mean                  135.8       289.8       396.2       471.3       559.7       688.3       729.3       616.8       772.8
Median                137.0       295.5       351.0       436.5       560.0       685.5       800.5       620.5       836.0

500k Step Scores      PlaNet      Dreamer     SAC+AE      SLAC        CURL        DrQ         PlayVirtual Baseline    MLR
Finger, spin          561 ± 284   796 ± 183   884 ± 128   673 ± 92    926 ± 45    938 ± 103   963 ± 40    944 ± 97    973 ± 31
Cartpole, swingup     475 ± 71    762 ± 27    735 ± 63    -           841 ± 45    868 ± 10    865 ± 11    871 ± 4     872 ± 5
Reacher, easy         210 ± 390   793 ± 164   627 ± 58    -           929 ± 44    942 ± 71    942 ± 66    943 ± 52    957 ± 41
Cheetah, run          305 ± 131   570 ± 253   550 ± 34    640 ± 19    518 ± 28    660 ± 96    719 ± 51    602 ± 67    674 ± 37
Walker, walk          351 ± 58    897 ± 49    847 ± 48    842 ± 51    902 ± 43    921 ± 45    928 ± 30    818 ± 263   939 ± 10
Ball in cup, catch    460 ± 380   879 ± 87    794 ± 58    852 ± 71    959 ± 27    963 ± 9     967 ± 5     960 ± 10    964 ± 14
Mean                  393.7       782.8       739.5       751.8       845.8       882.0       897.3       856.3       896.5
Median                405.5       794.5       764.5       757.5       914.0       929.5       935.0       907.0       948.0

Table 1: Comparison with the state-of-the-art methods and Baseline on the DMControl-100k and DMControl-500k benchmarks. Scores (mean ± standard deviation) on each environment are averaged over 10 random seeds. Our method augments Baseline with the proposed MLR objective (denoted as MLR).

4.2 Comparison with State-of-the-Art Methods

In this section, we compare MLR with the state-of-the-art (SOTA) sample-efficient RL methods proposed for continuous control, including PlaNet (Hafner et al., 2019), Dreamer (Hafner et al., 2020a), SAC+AE (Yarats et al., 2019), SLAC (Lee et al., 2020a), CURL (Laskin et al., 2020a), DrQ (Yarats et al., 2021) and PlayVirtual (Yu et al., 2021).

The comparison results are shown in Table 1; all results are averaged over 10 runs with different random seeds. From Table 1, we observe that: (i) MLR significantly improves Baseline in both sample efficiency (i.e., DMControl-100k) and asymptotic performance (i.e., DMControl-500k), with consistent gains over Baseline across all environments. Notably, on DMControl-100k, our method outperforms Baseline by 25.3% and 34.7% in mean and median scores, respectively. This demonstrates the superiority of MLR in improving the sample efficiency of RL algorithms. (ii) The RL agent equipped with MLR outperforms most state-of-the-art methods on DMControl-100k and DMControl-500k. Specifically, on DMControl-100k, our method surpasses the best previous method (i.e., PlayVirtual) by 43.5 and 35.5 in mean and median scores, respectively. On DMControl-500k, our method achieves the best median score and a mean score comparable to that of the strongest SOTA method (896.5 vs. 897.3 for PlayVirtual). Note that PlayVirtual (Yu et al., 2021) generates virtual trajectories for RL training rather than designing an auxiliary task to improve the representations, which is the focus of this paper; the two approaches are therefore complementary in principle.
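The aggregate numbers quoted above can be reproduced from the per-environment DMControl-100k scores in Table 1 with Python's standard `statistics` module. This only re-derives the reported aggregates from the table; it does not recompute per-seed statistics.

```python
from statistics import mean, median

# DMControl-100k scores from Table 1, in the order
# Fing.spin, Cart.swin, Reac.easy, Chee.run, Walk.walk, Ball.catch.
baseline = [853, 784, 593, 399, 424, 648]
mlr = [907, 806, 866, 482, 643, 933]
playvirtual = [915, 816, 785, 474, 460, 926]


def relative_gain(new, old):
    """Percentage improvement of `new` over `old`."""
    return 100.0 * (new - old) / old


print(round(mean(mlr), 1), median(mlr))                        # 772.8 836.0
print(round(relative_gain(mean(mlr), mean(baseline)), 1))      # 25.3
print(round(relative_gain(median(mlr), median(baseline)), 1))  # 34.7
print(round(mean(mlr) - mean(playvirtual), 1),
      median(mlr) - median(playvirtual))                       # 43.5 35.5
```

Note that the gains over Baseline are relative (percentages), whereas the gains over PlayVirtual are absolute score differences.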

Figure 4: Test performance during the training period (500k environment steps). Lines denote the mean scores over 10 random seeds, and the shaded areas denote the corresponding standard deviations. In most DMControl environments, our results (blue lines) are consistently better than those of Baseline (orange lines).

4.3 Ablation Study

Effectiveness evaluation. Besides comparing with SOTA methods, we demonstrate the effectiveness of MLR by studying its improvements over Baseline. The numerical results are presented in Table 1, and the test performance curves during training are given in Figure 4. Both show that our method clearly outperforms Baseline across different environments, which we attribute to the more informative representations learned by MLR.

Masking strategy. We compare three design choices for the masking operation: (i) Spatial masking (denoted as MLR-S): we randomly mask patches in each frame independently. (ii) Temporal masking (denoted as MLR-T): we divide the input observation sequence into multiple segments along the temporal dimension and randomly mask out a portion of the segments. (Here, the segment length is set equal to the temporal length $k$ of a cube.) (iii) Spatial-temporal masking (also referred to as “cube masking”): as described above and illustrated in Figure 2, we rasterize the observation sequence into non-overlapping cubes and randomly mask a portion of them. Apart from these differences, all other masking configurations are the same as in our proposed spatial-temporal (i.e., cube) masking.
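The three strategies differ only in which blocks of the masking grid are sampled. The following sketch rests on our own assumptions: a hypothetical function `grid_mask`, a uniform draw of round(ratio × candidates) blocks without replacement, and dimensions divisible by the block sizes. It returns the set of masked block indices for each strategy; expanding blocks into pixel masks is omitted.

```python
import random


def grid_mask(strategy, K, H, W, k, h, w, ratio, rng=random):
    """Return a set of masked (time_block, y_block, x_block) indices.

    'spatial':  h x w patches masked independently per frame (time_block = frame).
    'temporal': whole length-k segments masked at all spatial positions.
    'cube':     k x h x w space-time cubes masked over the whole sequence.
    """
    if H % h or W % w or (strategy != "spatial" and K % k):
        raise ValueError("block sizes must evenly divide the sequence")
    nh, nw = H // h, W // w
    if strategy == "spatial":
        per_frame = round(ratio * nh * nw)
        masked = set()
        for t in range(K):  # each frame gets its own independent draw
            for idx in rng.sample(range(nh * nw), per_frame):
                masked.add((t, idx // nw, idx % nw))
        return masked
    nk = K // k
    if strategy == "temporal":
        segments = rng.sample(range(nk), round(ratio * nk))
        return {(s, y, x) for s in segments for y in range(nh) for x in range(nw)}
    if strategy == "cube":
        cells = [(s, y, x) for s in range(nk) for y in range(nh) for x in range(nw)]
        return set(rng.sample(cells, round(ratio * len(cells))))
    raise ValueError(f"unknown strategy: {strategy}")
```

At the same ratio, spatial masking keeps every frame partially visible, temporal masking hides entire segments, and cube masking mixes both, which forces the encoder to draw on context along both axes.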

The results of this ablation study on DMControl-100k are presented in Table 2, from which we observe the following: (i) All three masking strategies (i.e., MLR-S, MLR-T and MLR) improve the mean score over Baseline by 18.5%, 12.2% and 25.0%, and the median score by 23.4%, 25.0% and 35.9%, respectively. This demonstrates the effectiveness of the core idea of introducing mask-based reconstruction to improve representation learning in RL. (ii) Spatial-temporal masking is the most effective of the three design choices. Thanks to its spatial-temporal continuity, it better matches the nature of video data and encourages the state representations to be more predictive and consistent along the spatial and temporal dimensions, which facilitates policy learning in RL.

Env.        Baseline    MLR-S       MLR-T       MLR
Fing.spin   822 ± 146   919 ± 55    787 ± 139   907 ± 69
Cart.swin   782 ± 74    665 ± 118   829 ± 33    791 ± 50
Reac.easy   557 ± 137   848 ± 82    745 ± 84    875 ± 92
Chee.run    438 ± 33    449 ± 46    443 ± 43    495 ± 13
Walk.walk   414 ± 310   556 ± 189   393 ± 202   597 ± 102
Ball.catch  669 ± 310   927 ± 6     934 ± 29    939 ± 9
Mean        613.7       727.3       688.5       767.3
Median      613.0       756.5       766.0       833.0
Table 2: Ablation study of masking strategy on DMControl-100k benchmark. We compare three masking strategies: spatial masking (MLR-S), temporal masking (MLR-T) and spatial-temporal masking (MLR). The results are averaged over 5 random seeds.

Reconstruction target. In masked language/image modeling, reconstruction/prediction is commonly performed in the original signal space, such as word embeddings or RGB pixels. To study the influence of reconstruction targets in RL, we compare two reconstruction spaces: (i) Pixel-space reconstruction (denoted as MLR-Pixel): we predict the masked contents by directly reconstructing the original pixels, as in the CV and NLP domains; (ii) Latent-space reconstruction (i.e., MLR): we reconstruct the state representations (i.e., features) of the original observations from the masked observations, as proposed in MLR. Table 3 shows the comparison results: reconstruction in the latent space improves sample efficiency more than reconstruction in the pixel space (mean score 767.3 vs. 647.2). As discussed in preceding sections, vision data might contain distractions and redundancies for policy learning in RL, which makes pixel-level reconstruction unnecessary. Besides, latent-space reconstruction better supports the coordination between representation learning and policy learning in RL, because the state representations are directly optimized.

Env.        Baseline    MLR-Pixel   MLR
Fing.spin   822 ± 146   782 ± 95    907 ± 69
Cart.swin   782 ± 74    803 ± 91    791 ± 50
Reac.easy   557 ± 137   787 ± 136   875 ± 92
Chee.run    438 ± 33    346 ± 84    495 ± 13
Walk.walk   414 ± 310   490 ± 216   597 ± 102
Ball.catch  669 ± 310   675 ± 292   939 ± 9
Mean        613.7       647.2       767.3
Median      613.0       728.5       833.0
Table 3: Ablation study of the reconstruction target on the DMControl-100k benchmark. We compare two reconstruction targets: original pixels (denoted as MLR-Pixel) and momentum projections in the latent space (i.e., MLR). We run each model for 5 random seeds.

Mask ratio. Recent works on masked image modeling (MIM) (He et al., 2021; Xie et al., 2021) find the mask ratio to be crucial for the final performance. We study the influence of different mask ratios on sample efficiency in Figure 5 and find that a ratio of 50% is an appropriate choice for MLR. Too small a ratio cannot eliminate redundancy, so the objective can be reached by extrapolating from neighboring contents without capturing and understanding the semantics of the vision signals. Too large a ratio leaves too little context to achieve the reconstruction goal. As discussed in (He et al., 2021; Xie et al., 2021), the appropriate ratio varies across modalities and depends on the information density.

Figure 5: Ablation study of mask ratio. We run each model for 3 random seeds and report the average results.
DMControl-100k       Baseline     MLR w.o. ActTok   MLR-F        MLR-MoDec    MLR
Finger, spin         841 ± 150    843 ± 34          845 ± 137    905 ± 67     904 ± 70
Cartpole, swingup    785 ± 72     822 ± 8           807 ± 30     779 ± 75     808 ± 40
Reacher, easy        631 ± 55     840 ± 58          828 ± 171    808 ± 57     880 ± 82
Cheetah, run         431 ± 31     403 ± 82          469 ± 49     472 ± 178    487 ± 9
Walker, walk         436 ± 277    470 ± 217         668 ± 46     587 ± 212    616 ± 87
Ball in cup, catch   609 ± 320    818 ± 141         804 ± 150    829 ± 122    932 ± 3
Mean                 622.2        699.3             736.8        730.0        771.2
Median               620.0        820.0             805.5        793.5        844.0
Table 4: Ablation studies on action token, masking features and momentum decoder. MLR w.o. ActTok denotes removing the action tokens from the input tokens of the predictive latent decoder. MLR-F denotes performing masking on convolutional feature maps. MLR-MoDec denotes adding a momentum predictive latent decoder to the target networks. We run each model on the DMControl-100k benchmark for 3 random seeds.

Decoder depth. We analyze the influence of using Transformer-based latent decoders of different depths. As shown in Table 5, the two-layer decoder gives the best mean and median scores; deeper decoders (4 or 8 layers) lower the scores, and a single layer also underperforms. Notably, compared with the encoder (4.04M parameters), our decoder is lightweight (40.8K parameters). Similar to the designs in (He et al., 2021; Xie et al., 2021), a lightweight decoder is appropriate for MLR because we expect the prediction of the masked information to be carried out mainly by the encoder rather than the decoder. Note that it is the state representations inferred by the encoder that are used for policy learning in RL.

Layers Param. Mean Score Median Score
1 20.4K 726.8 719.0
2 40.8K 767.3 833.0
4 81.6K 766.2 789.5
8 163.2K 728.3 763.5
Table 5: Ablation study of predictive latent decoder depth. We run each model on the DMControl-100k benchmark for 5 random seeds and report the number of parameters, mean score and median score.
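The parameter column of Table 5 grows linearly with depth (20.4K per attention layer). The short calculation below, using only the counts reported in Tables 5 and 6, shows how small the decoder is relative to the 4.04M-parameter encoder at each depth:

```python
ENCODER_PARAMS = 4.04e6     # encoder size reported in Table 6
PARAMS_PER_LAYER = 20.4e3   # 1-layer decoder size reported in Table 5

for layers in (1, 2, 4, 8):
    decoder_params = layers * PARAMS_PER_LAYER
    share = 100 * decoder_params / ENCODER_PARAMS
    print(f"{layers} layer(s): {decoder_params / 1e3:.1f}K params, {share:.2f}% of the encoder")
# 1 layer(s): 20.4K params, 0.50% of the encoder
# 2 layer(s): 40.8K params, 1.01% of the encoder
# 4 layer(s): 81.6K params, 2.02% of the encoder
# 8 layer(s): 163.2K params, 4.04% of the encoder
```

Even the deepest decoder tested stays around 4% of the encoder, so the drop in scores at larger depths is not explained by the decoder overtaking the encoder in size.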

Action token. We study the contribution of the action tokens fed to the latent decoder (illustrated in Figure 1) by removing them from our framework. The results are given in Table 4: without action tokens, the mean score drops from 771.2 to 699.3 and the median from 844.0 to 820.0. Intuitively, predicting from visual signals alone is somewhat ambiguous. Action tokens help reduce this ambiguity, so that less uncertain gradients are obtained for updating the encoder.
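The sketch below shows one way the decoder input could be assembled from state tokens and action tokens. It assumes that actions are embedded by a single FC layer to the state-token dimension, that the K state tokens are followed by the K action tokens, and that the t-th state token and the t-th action token share positional embedding t. The token layout and the names `linear` and `build_decoder_input` are hypothetical illustrations, not the confirmed MLR implementation.

```python
def linear(x, weight, bias):
    """Fully connected layer; weight has shape (out_dim, in_dim)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def build_decoder_input(state_tokens, actions, act_weight, act_bias, pos_emb):
    """Hypothetical PLD input: K state tokens followed by K action tokens,
    where the t-th state and the t-th action token both get positional embedding t."""
    action_tokens = [linear(a, act_weight, act_bias) for a in actions]
    states = [[s + p for s, p in zip(tok, pos_emb[t])] for t, tok in enumerate(state_tokens)]
    acts = [[a + p for a, p in zip(tok, pos_emb[t])] for t, tok in enumerate(action_tokens)]
    return states + acts


tokens = build_decoder_input(
    state_tokens=[[0.0, 0.0], [1.0, 1.0]],  # K = 2 state tokens of dimension d = 2
    actions=[[1.0], [0.5]],                 # 1-D actions
    act_weight=[[1.0], [2.0]],              # action embedding head: 1 -> d = 2
    act_bias=[0.0, 0.0],
    pos_emb=[[0.0, 0.0], [0.0, 0.0]],       # zero positional embedding for readability
)
print(tokens)  # [[0.0, 0.0], [1.0, 1.0], [1.0, 2.0], [0.5, 1.0]]
```

Whether actions are interleaved with or appended after the states is a design choice; the ablation only establishes that giving the decoder access to actions helps.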

Masking features. We compare “masking pixels” and “masking features” in Table 4. Masking features (denoted by MLR-F) does not perform as well as masking pixels as proposed in MLR (mean score 736.8 vs. 771.2), but it still improves significantly over the Baseline (622.2).

Why not use a latent decoder for targets? We have also attempted to add a momentum-updated latent decoder for the target states, as illustrated in Figure 1. The results in Table 4 show that adding the momentum decoder lowers performance (mean score 730.0 vs. 771.2). This is because the target branch encodes the original observation sequence without masking, so no information has to be predicted when learning the target state representations. Thus, there is no implicitly predicted information to refine, unlike in the outputs of the online encoder.
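Momentum (target) networks such as the momentum encoder and momentum projection head, and the momentum decoder tested in this ablation, are typically updated as an exponential moving average (EMA) of the corresponding online parameters. A minimal sketch over flat parameter lists follows; it assumes that the "Target network EMA" value in Table 8 (0.9 or 0.95) is the weight kept on the old target parameters, and `ema_update` is a hypothetical name.

```python
def ema_update(target_params, online_params, momentum):
    """Momentum (EMA) update, elementwise:
    target <- momentum * target + (1 - momentum) * online."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target_params, online_params)]


target = [1.0, 1.0]
online = [0.0, 2.0]
for _ in range(3):
    target = ema_update(target, online, momentum=0.9)
print([round(x, 3) for x in target])  # [0.729, 1.271]
```

The target drifts slowly toward the online parameters, which is why it provides stable regression targets; no gradients flow into it.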

Cube size and sequence length. These two factors can be viewed as hyperparameters. Their experimental analysis and results are given in Appendix C.

5 Conclusion

In this work, we make the first effort to introduce the popular mask-based reconstruction to RL to facilitate policy learning by improving the learned state representations. We propose MLR, a simple yet effective self-supervised auxiliary objective that reconstructs the masked content in the latent space. In this way, the learned state representations are encouraged to include richer and more informative features. Extensive experiments show that MLR achieves state-of-the-art performance on the DeepMind Control benchmarks, demonstrating its effectiveness. We conduct a detailed ablation study of the proposed designs in MLR and analyze how they differ from those in the NLP and CV domains. We hope our method can inspire further research on vision-based RL from the perspective of improving representation learning. Moreover, the concept of masked latent reconstruction is also worth exploring and extending in the fields of computer vision and natural language processing. We look forward to seeing more mutual promotion between different research fields.


  • A. Anand, E. Racah, S. Ozair, Y. Bengio, M. Côté, and R. D. Hjelm (2019) Unsupervised state representation learning in Atari. In Advances in Neural Information Processing Systems, Cited by: §2.1.
  • J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §B.1, §3.2.
  • H. Bao, L. Dong, and F. Wei (2021) BEiT: BERT pre-training of image transformers. arXiv preprint arXiv:2106.08254. Cited by: §1, §2.3.
  • R. Bellman (1957) A Markovian decision process. Journal of Mathematics and Mechanics 6 (5), pp. 679–684. Cited by: §3.1.
  • T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. (2020) Language models are few-shot learners. arXiv preprint arXiv:2005.14165. Cited by: §1, §2.3.
  • M. Chen, A. Radford, R. Child, J. Wu, H. Jun, D. Luan, and I. Sutskever (2020) Generative pretraining from pixels. In International Conference on Machine Learning, pp. 1691–1703. Cited by: §2.3.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2018) BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805. Cited by: §1, §2.3.
  • J. Grill, F. Strub, F. Altché, C. Tallec, P. Richemond, E. Buchatskaya, C. Doersch, B. Avila Pires, Z. Guo, M. Gheshlaghi Azar, B. Piot, K. Kavukcuoglu, R. Munos, and M. Valko (2020) Bootstrap your own latent - a new approach to self-supervised learning. In Advances in Neural Information Processing Systems, Cited by: §3.2.
  • Z. D. Guo, B. A. Pires, B. Piot, J. Grill, F. Altché, R. Munos, and M. G. Azar (2020) Bootstrap latent-predictive representations for multitask reinforcement learning. In International Conference on Machine Learning, pp. 3875–3886. Cited by: §2.1.
  • T. Haarnoja, A. Zhou, K. Hartikainen, G. Tucker, S. Ha, J. Tan, V. Kumar, H. Zhu, A. Gupta, P. Abbeel, et al. (2018) Soft actor-critic algorithms and applications. arXiv preprint arXiv:1812.05905. Cited by: Appendix A, §B.1, §3.2, §4.1.
  • D. Hafner, T. Lillicrap, J. Ba, and M. Norouzi (2020a) Dream to control: learning behaviors by latent imagination. In International Conference on Learning Representations, Cited by: §2.2, §4.2.
  • D. Hafner, T. Lillicrap, I. Fischer, R. Villegas, D. Ha, H. Lee, and J. Davidson (2019) Learning latent dynamics for planning from pixels. In International Conference on Machine Learning, pp. 2555–2565. Cited by: §2.2, §4.2.
  • D. Hafner, T. Lillicrap, M. Norouzi, and J. Ba (2020b) Mastering Atari with discrete world models. arXiv preprint arXiv:2010.02193. Cited by: §2.2.
  • S. Hansen, W. Dabney, A. Barreto, T. Van de Wiele, D. Warde-Farley, and V. Mnih (2019) Fast task inference with variational intrinsic successor features. arXiv preprint arXiv:1906.05030. Cited by: §2.1.
  • K. He, X. Chen, S. Xie, Y. Li, P. Dollár, and R. Girshick (2021) Masked autoencoders are scalable vision learners. arXiv preprint arXiv:2111.06377. Cited by: §1, §1, §2.3, §3.2, §4.3, §4.3.
  • M. Jaderberg, V. Mnih, W. M. Czarnecki, T. Schaul, J. Z. Leibo, D. Silver, and K. Kavukcuoglu (2016) Reinforcement learning with unsupervised auxiliary tasks. arXiv preprint arXiv:1611.05397. Cited by: §2.1.
  • L. P. Kaelbling, M. L. Littman, and A. R. Cassandra (1998) Planning and acting in partially observable stochastic domains. Artificial Intelligence 101 (1-2), pp. 99–134. Cited by: §3.1.
  • Ł. Kaiser, M. Babaeizadeh, P. Miłos, B. Osiński, R. H. Campbell, K. Czechowski, D. Erhan, C. Finn, P. Kozakowski, S. Levine, A. Mohiuddin, R. Sepassi, G. Tucker, and H. Michalewski (2020) Model based reinforcement learning for Atari. In International Conference on Learning Representations, Cited by: §2.2.
  • D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980. Cited by: §B.2.
  • M. Laskin, A. Srinivas, and P. Abbeel (2020a) CURL: contrastive unsupervised representations for reinforcement learning. In International Conference on Machine Learning, pp. 5639–5650. Cited by: §B.1, §2.1, §2.2, §3.1, §3.2, §4.1, §4.2.
  • M. Laskin, K. Lee, A. Stooke, L. Pinto, P. Abbeel, and A. Srinivas (2020b) Reinforcement learning with augmented data. In Advances in Neural Information Processing Systems, Cited by: §B.2, §2.2, §3.1, §4.1.
  • A. X. Lee, A. Nagabandi, P. Abbeel, and S. Levine (2020a) Stochastic latent actor-critic: deep reinforcement learning with a latent variable model. In Advances in Neural Information Processing Systems, Cited by: §2.1, §2.2, §3.1, §4.1, §4.2.
  • K. Lee, I. Fischer, A. Liu, Y. Guo, H. Lee, J. Canny, and S. Guadarrama (2020b) Predictive information accelerates learning in RL. arXiv preprint arXiv:2007.12401. Cited by: §2.1, §2.2.
  • G. Liu, C. Zhang, L. Zhao, T. Qin, J. Zhu, L. Jian, N. Yu, and T. Liu (2021) Return-based contrastive representation learning for reinforcement learning. In International Conference on Learning Representations, Cited by: §2.2.
  • H. Liu and P. Abbeel (2021a) APS: active pretraining with successor features. In International Conference on Machine Learning, pp. 6736–6747. Cited by: §2.1.
  • H. Liu and P. Abbeel (2021b) Behavior from the void: unsupervised active pre-training. arXiv preprint arXiv:2103.04551. Cited by: §2.1.
  • B. Mazoure, R. Tachet des Combes, T. L. DOAN, P. Bachman, and R. D. Hjelm (2020) Deep reinforcement and infomax learning. In Advances in Neural Information Processing Systems, Cited by: §2.1.
  • V. Mnih, K. Kavukcuoglu, D. Silver, A. Graves, I. Antonoglou, D. Wierstra, and M. Riedmiller (2013) Playing Atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602. Cited by: §3.1.
  • A. v. d. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748. Cited by: §2.1.
  • D. Pathak, P. Krahenbuhl, J. Donahue, T. Darrell, and A. A. Efros (2016) Context encoders: feature learning by inpainting. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2536–2544. Cited by: §2.3.
  • A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever (2018) Improving language understanding by generative pre-training. Cited by: §1, §2.3.
  • A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, I. Sutskever, et al. (2019) Language models are unsupervised multitask learners. OpenAI blog 1 (8), pp. 9. Cited by: §1, §2.3.
  • M. Schwarzer, A. Anand, R. Goel, R. D. Hjelm, A. Courville, and P. Bachman (2021a) Data-efficient reinforcement learning with self-predictive representations. In International Conference on Learning Representations, Cited by: §B.2, §2.1, §2.2, §3.2, §3.2.
  • M. Schwarzer, N. Rajkumar, M. Noukhovitch, A. Anand, L. Charlin, D. Hjelm, P. Bachman, and A. Courville (2021b) Pretraining representations for data-efficient reinforcement learning. arXiv preprint arXiv:2106.04799. Cited by: §2.1.
  • E. Shelhamer, P. Mahmoudieh, M. Argus, and T. Darrell (2017) Loss is its own reward: self-supervision for reinforcement learning. arXiv preprint arXiv:1612.07307. Cited by: §2.1.
  • A. Stooke, K. Lee, P. Abbeel, and M. Laskin (2020) Decoupling representation learning from reinforcement learning. arXiv preprint arXiv:2009.08319. Cited by: §2.1.
  • Y. Tassa, Y. Doron, A. Muldal, T. Erez, Y. Li, D. d. L. Casas, D. Budden, A. Abdolmaleki, J. Merel, A. Lefrancq, et al. (2018) DeepMind control suite. arXiv preprint arXiv:1801.00690. Cited by: §B.3, §1, §4.1.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008. Cited by: §B.1, §B.2, §3.2.
  • C. Wei, H. Fan, S. Xie, C. Wu, A. Yuille, and C. Feichtenhofer (2021) Masked feature prediction for self-supervised visual pre-training. arXiv preprint arXiv:2112.09133. Cited by: §2.3.
  • Z. Xie, Z. Zhang, Y. Cao, Y. Lin, J. Bao, Z. Yao, Q. Dai, and H. Hu (2021) SimMIM: a simple framework for masked image modeling. arXiv preprint arXiv:2111.09886. Cited by: §1, §1, §2.3, §3.2, §4.3, §4.3.
  • D. Yarats, I. Kostrikov, and R. Fergus (2021) Image augmentation is all you need: regularizing deep reinforcement learning from pixels. In International Conference on Learning Representations, Cited by: §B.2, §2.2, §3.1, §4.1, §4.2.
  • D. Yarats, A. Zhang, I. Kostrikov, B. Amos, J. Pineau, and R. Fergus (2019) Improving sample efficiency in model-free reinforcement learning from images. arXiv preprint arXiv:1910.01741. Cited by: §2.1, §2.2, §3.1, §4.1, §4.2.
  • W. Ye, S. Liu, T. Kurutach, P. Abbeel, and Y. Gao (2021) Mastering Atari games with limited data. Advances in Neural Information Processing Systems 34. Cited by: §2.2.
  • T. Yu, C. Lan, W. Zeng, M. Feng, Z. Zhang, and Z. Chen (2021) PlayVirtual: augmenting cycle-consistent virtual trajectories for reinforcement learning. In Advances in Neural Information Processing Systems, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan (Eds.), External Links: Link Cited by: §B.2, §2.1, §2.2, §3.1, §3.2, §3.2, §3.2, §4.1, §4.2, §4.2.
  • A. Zhang, R. T. McAllister, R. Calandra, Y. Gal, and S. Levine (2021) Learning invariant representations for reinforcement learning without reconstruction. In International Conference on Learning Representations, Cited by: §2.1.
  • J. Zhu, Y. Xia, L. Wu, J. Deng, W. Zhou, T. Qin, and H. Li (2020) Masked contrastive representation learning for reinforcement learning. arXiv preprint arXiv:2010.07470. Cited by: §2.1, §2.2.
  • B. D. Ziebart, A. L. Maas, J. A. Bagnell, A. K. Dey, et al. (2008) Maximum entropy inverse reinforcement learning. In AAAI, Vol. 8, pp. 1433–1438. Cited by: Appendix A.

Appendix A Extended Background of Soft Actor-Critic

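Soft Actor-Critic (SAC) (Haarnoja et al., 2018) is an off-policy actor-critic algorithm. SAC is based on the maximum entropy RL framework, in which the standard maximum-reward RL objective is augmented with an entropy maximization term (Ziebart et al., 2008). SAC has a soft Q-function $Q_\phi$ and a policy $\pi_\psi$.

The soft Q-function is learned by minimizing the soft Bellman error:

$$J_Q(\phi) = \mathbb{E}_{tr \sim \mathcal{D}}\Big[\big(Q_\phi(s_t, a_t) - (r_t + \gamma \bar{V}(s_{t+1}))\big)^2\Big],$$

where $tr = (s_t, a_t, s_{t+1}, r_t)$ is a tuple with current state $s_t$, action $a_t$, successor state $s_{t+1}$ and reward $r_t$, $\mathcal{D}$ is the replay buffer, $\gamma$ is the discount factor and $\bar{V}$ is the target value function. $\bar{V}$ has the following expectation:

$$\bar{V}(s_t) = \mathbb{E}_{a_t \sim \pi_\psi}\big[Q_{\bar{\phi}}(s_t, a_t) - \alpha \log \pi_\psi(a_t \mid s_t)\big],$$

where $Q_{\bar{\phi}}$ is the target Q-function, whose parameters $\bar{\phi}$ are updated by an exponential moving average of the parameters $\phi$ of the Q-function $Q_\phi$, and the temperature $\alpha$ balances return maximization and entropy maximization.

The policy is represented using the reparameterization trick and optimized by minimizing the following objective:

$$J_\pi(\psi) = \mathbb{E}_{s_t \sim \mathcal{D},\, \epsilon_t \sim \mathcal{N}}\big[\alpha \log \pi_\psi(\tilde{a}_t \mid s_t) - Q_\phi(s_t, \tilde{a}_t)\big],$$

where $\epsilon_t$ is the input noise vector sampled from a Gaussian distribution $\mathcal{N}(0, I)$, and $\tilde{a}_t$ denotes actions sampled stochastically from the policy $\pi_\psi$ using $\epsilon_t$, i.e., $\tilde{a}_t \sim \pi_\psi(\cdot \mid s_t)$.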
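The SAC objectives can be made concrete with single-sample batch estimates. The following plain-Python sketch assumes scalar Q-values and log-probabilities have already been computed by the networks, and it omits details that common SAC implementations add (two critics with a clipped minimum, terminal-state masks). The function names are hypothetical; γ = 0.99 follows the discount factor in Table 8.

```python
def soft_target_value(q_target_next, log_prob_next, alpha):
    """Single-sample estimate of the target value at s_{t+1}:
    Q_target(s_{t+1}, a') - alpha * log pi(a' | s_{t+1}), with a' drawn from the policy."""
    return q_target_next - alpha * log_prob_next


def critic_loss(q_pred, rewards, q_target_next, log_prob_next, alpha, gamma=0.99):
    """Mean squared soft Bellman error over a batch of transitions."""
    errors = []
    for q, r, q_next, log_p in zip(q_pred, rewards, q_target_next, log_prob_next):
        target = r + gamma * soft_target_value(q_next, log_p, alpha)
        errors.append((q - target) ** 2)
    return sum(errors) / len(errors)


def actor_loss(log_probs, q_values, alpha):
    """Mean of alpha * log pi(a~ | s) - Q(s, a~) over reparameterized actions a~."""
    return sum(alpha * lp - q for lp, q in zip(log_probs, q_values)) / len(log_probs)


print(critic_loss([1.0], [0.5], [1.0], [-1.0], alpha=0.1))  # approximately 0.3469
print(actor_loss([-1.0], [2.0], alpha=0.1))                 # approximately -2.1
```

Raising α increases the entropy bonus inside both the target value and the actor loss, which is how the temperature trades off return against exploration.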

Appendix B Implementation Detail

B.1 Network Architecture

Our model has two parts: the basic networks and the auxiliary networks. The basic networks consist of a representation network (i.e., encoder) and the policy learning networks of SAC (Haarnoja et al., 2018).

We follow CURL (Laskin et al., 2020a) to build the architecture of the basic networks. The encoder is composed of 4 convolutional layers (each followed by a rectified linear unit (ReLU) activation), a fully connected (FC) layer and a layer normalization (LN) (Ba et al., 2016) layer. The policy learning networks are built with multilayer perceptrons (MLPs).

Network                      Param.
Encoder                      4.04M
Predictive Latent Decoder    40.8K
Projection Head              10.2K
Prediction Head              10.2K
Table 6: Number of parameters of main networks in MLR.
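The 10.2K entries for the heads can be reproduced by counting weights and biases, assuming each head is two FC layers with a hidden size of 100 (as stated in this section) and 50-dimensional inputs and outputs (the state representation dimension in Table 8). The input/output width is an assumption, and `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(layer_sizes):
    """Number of weights and biases in a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# Assumed head shape: 50-d input -> 100 hidden units (ReLU, no parameters) -> 50-d output.
print(mlp_param_count([50, 100, 50]))  # 10150
```

10,150 parameters rounds to the 10.2K reported for both the projection head and the prediction head.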

Our auxiliary networks comprise online networks and momentum (or target) networks. The online networks consist of an encoder, a predictive latent decoder (PLD), a projection head and a prediction head, each with its own parameters. Notably, the encoder is shared between the basic networks and the auxiliary networks. As shown in Figure 1, a momentum encoder and a momentum projection head compute the self-supervised targets. The momentum networks have the same architectures as the corresponding online networks. Our PLD is a Transformer encoder (Vaswani et al., 2017) with two standard attention layers (each with a single attention head). We use an FC layer as the action embedding head to transform the original action into an embedding with the same dimension as the state representation (i.e., the state token). We use sine and cosine functions to build the positional embedding following (Vaswani et al., 2017):

$$PE_{(pos,\,2i)} = \sin\!\left(pos / 10000^{2i/d}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(pos / 10000^{2i/d}\right),$$

where $pos$ is the position, $i$ is the dimension index and $d$ is the embedding size (equal to the state representation size). Both the projection head and the prediction head consist of two FC layers with a hidden size of 100, with a ReLU activation after the first FC layer. Table 6 lists the number of parameters of the main networks in our method. The encoder dominates the number of parameters in the auxiliary networks, which suggests that the encoder plays the major role in the masked prediction.
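The sinusoidal positional embedding of Vaswani et al. (2017) can be tabulated directly. This is a minimal plain-Python sketch; the table size assumes one position per timestep of a K = 16 sequence and d = 50 (the sequence length and state representation dimension in Table 8), and the function name is hypothetical.

```python
import math


def sinusoidal_positional_embedding(num_positions, dim):
    """PE[pos][2i] = sin(pos / 10000^(2i/dim)) and PE[pos][2i+1] = cos(same angle)."""
    table = [[0.0] * dim for _ in range(num_positions)]
    for pos in range(num_positions):
        for even in range(0, dim, 2):  # even plays the role of 2i
            angle = pos / (10000 ** (even / dim))
            table[pos][even] = math.sin(angle)
            if even + 1 < dim:
                table[pos][even + 1] = math.cos(angle)
    return table


pe = sinusoidal_positional_embedding(num_positions=16, dim=50)
print(pe[0][:4])  # [0.0, 1.0, 0.0, 1.0]
```

The embedding has no learnable parameters, so it adds nothing to the decoder's parameter count.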

B.2 Training Detail

Optimization. The training algorithm of our method is presented in Algorithm 1. We use the Adam optimizer (Kingma and Ba, 2014) to optimize all parameters in our model, with (β1, β2) = (0.9, 0.999) (except for (β1, β2) = (0.5, 0.999) for the SAC temperature α). Following (Vaswani et al., 2017), we warm up the learning rate of our MLR objective by

$$lr = lr_0 \cdot \min\left(step^{-0.5},\; step \cdot step_{warmup}^{-1.5}\right),$$

where $lr$ and $lr_0$ denote the current learning rate and the initial learning rate, respectively, and $step$ and $step_{warmup}$ denote the current step and the warmup step, respectively. In this work, we run each model on a single NVIDIA Tesla V100 GPU.
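The warmup can be sketched as a small function. This assumes the schedule takes the Transformer form of Vaswani et al. (2017) with lr_0 substituted for d_model^(-1/2), and that lr_0 = 0.0001 and 6000 warmup steps correspond to the "Learning rate" and "Learning rate warmup" entries in Table 8; both are assumptions, and `transformer_warmup_lr` is a hypothetical name.

```python
def transformer_warmup_lr(step, lr0, warmup_steps):
    """lr = lr0 * min(step^-0.5, step * warmup_steps^-1.5): grows linearly until
    warmup_steps, then decays proportionally to step^-0.5."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first update
    return lr0 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 3000, 6000, 24000):
    print(step, transformer_warmup_lr(step, lr0=1e-4, warmup_steps=6000))
```

Under this form the peak value, reached at the warmup step, is lr_0 / sqrt(warmup_steps), about 1.3 × 10^-6 for these numbers, so lr_0 acts as a scale factor rather than the peak rate. If a linear ramp up to lr_0 were intended instead, the min term would be min(1, step / warmup_steps).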

  Require: An online encoder, a momentum encoder, a predictive latent decoder, an online projection head, a momentum projection head, a prediction head and policy learning networks, each with its own parameters; a stochastic cube masking function; a stochastic image augmentation function.
  Determine the auxiliary loss weight, sequence length K, mask ratio, cube size and EMA coefficient.
  Initialize a replay buffer.
  Initialize with and .
  Initialize all network parameters.
  while train do
     Interact with the environment based on the policy
     Collect the transition
     Sample a trajectory of K timesteps from the replay buffer
     Randomly mask the observation sequence with the cube masking function
     Perform encoding of the masked observations with the online encoder
     Perform decoding with the predictive latent decoder
     Perform projection and prediction with the online projection head and prediction head
     Calculate targets from the original observations with the momentum encoder and momentum projection head
     Calculate MLR loss
     Calculate RL loss based on a given base RL algorithm (e.g., SAC)
  end while
Algorithm 1 Training Algorithm for MLR

Data augmentation. Modest data augmentation such as crop or shift has been shown to be effective for improving agent performance in vision-based RL (Yarats et al., 2021; Laskin et al., 2020b; Schwarzer et al., 2021a; Yu et al., 2021). Following PlayVirtual (Yu et al., 2021), we apply random crop to the observations when training the RL objective, and random crop and intensity augmentation when training the auxiliary objective.
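The two augmentations can be sketched on a single-channel image stored as a list of rows. The crop assumes the 84 × 84 input is cut from the 100 × 100 render listed in Table 8; the intensity form, scaling every pixel by 1 + 0.05 · clip(N(0, 1), −2, 2), is an assumed SPR-style choice and not confirmed by this section. Function names are hypothetical, and for stacked frames the same crop window would normally be applied to every frame.

```python
import random


def random_crop(image, out_h, out_w, rng=random):
    """Crop a random out_h x out_w window from an H x W image (list of rows)."""
    h, w = len(image), len(image[0])
    top = rng.randint(0, h - out_h)    # randint is inclusive on both ends
    left = rng.randint(0, w - out_w)
    return [row[left:left + out_w] for row in image[top:top + out_h]]


def random_intensity(image, scale=0.05, rng=random):
    """Scale all pixels by 1 + scale * clip(N(0, 1), -2, 2) (assumed form)."""
    noise = max(-2.0, min(2.0, rng.gauss(0.0, 1.0)))
    factor = 1.0 + scale * noise
    return [[pixel * factor for pixel in row] for row in image]


rng = random.Random(0)
frame = [[float(x + y) for x in range(100)] for y in range(100)]       # toy 100x100 render
rl_view = random_crop(frame, 84, 84, rng)                              # for the RL objective
aux_view = random_intensity(random_crop(frame, 84, 84, rng), rng=rng)  # for the auxiliary objective
```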

B.3 Hyperparameters.

We list the hyperparameters used for DMControl benchmarks (Tassa et al., 2018) in Table 8.

Appendix C Extended Ablation Study

Sequence length. Table 7 shows the results for observation sequence lengths K ∈ {8, 16, 24}. A larger K (e.g., 24) brings no further improvement, likely because the network can then reconstruct the missing content in a trivial way, such as copying it from other states, while a small K such as 8 may not provide enough context for learning rich information. K = 16 is a good trade-off in our experiments.

Env.                 Baseline     K=8          K=16         K=24
Finger, spin         822 ± 146    816 ± 129    907 ± 69     875 ± 63
Cartpole, swingup    782 ± 74     857 ± 3      791 ± 50     781 ± 58
Reacher, easy        557 ± 137    779 ± 116    875 ± 92     736 ± 247
Cheetah, run         438 ± 33     469 ± 51     495 ± 13     454 ± 41
Walker, walk         414 ± 310    473 ± 264    597 ± 102    533 ± 98
Ball in cup, catch   669 ± 310    910 ± 58     939 ± 9      944 ± 22
Mean                 613.7        717.3        767.3        720.5
Median               613.0        797.5        833.0        758.5
Table 7: Ablation study of sequence length K. We run each model on the DMControl-100k benchmark with 5 random seeds.

Cube size. Our space-time cube can be flexibly designed. We ablate the influence of the spatial size (height × width, 10 × 10 by default) and the temporal depth of the cube, as shown in Figure 6(a) and 6(b). In general, a proper cube size leads to good results. The spatial size has a large influence on the final performance, and a moderate spatial size works well for MLR. The performance generally trends upward as the cube depth increases. However, a cube mask that is too deep (e.g., depth 16) may mask content that is necessary for the reconstruction and hinder training.
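To show how the cube size determines which pixels are hidden, the sketch below expands a set of masked cube coordinates into a per-pixel mask over a K × H × W observation sequence. It assumes the sequence is tiled by non-overlapping cubes and that cubes at the borders are truncated when a dimension is not divisible by the cube size; what value masked pixels receive is left open. Both helpers are hypothetical.

```python
import math


def cube_grid_shape(seq_len, height, width, cube_depth, cube_h, cube_w):
    """Number of cubes along time, height and width (border cubes truncated)."""
    return (math.ceil(seq_len / cube_depth), math.ceil(height / cube_h), math.ceil(width / cube_w))


def expand_cube_mask(masked_cubes, seq_len, height, width, cube_depth, cube_h, cube_w):
    """Turn a set of masked cube coordinates (t_idx, y_idx, x_idx) into a boolean
    mask indexed as mask[t][y][x]; True means the pixel is masked."""
    return [
        [
            [(t // cube_depth, y // cube_h, x // cube_w) in masked_cubes for x in range(width)]
            for y in range(height)
        ]
        for t in range(seq_len)
    ]


# Toy example: K = 4 frames of 20 x 20 pixels with 2 x 10 x 10 cubes -> a 2 x 2 x 2 grid.
grid = cube_grid_shape(4, 20, 20, cube_depth=2, cube_h=10, cube_w=10)
mask = expand_cube_mask({(0, 0, 0), (1, 1, 1)}, 4, 20, 20, 2, 10, 10)
masked_pixels = sum(v for frame in mask for row in frame for v in row)
print(grid, masked_pixels)  # (2, 2, 2) 400
```

Under the same truncation assumption, the defaults in Table 8 (K = 16, 84 × 84 inputs, 10 × 10 spatial size, depth 8) give a 2 × 9 × 9 grid of 162 cubes, of which a 50% mask ratio would hide 81. Deeper cubes hide the same spatial region across more consecutive frames, which is consistent with very deep cubes removing content needed for reconstruction.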

Figure 6: Ablation studies of (a) cube spatial size (height × width) and (b) cube depth. We report the mean and median scores on DMControl-100k benchmark. The result of each model is averaged over 3 random seeds.
Hyperparameter                                  Value
Frame stack                                     3
Observation rendering                           (100, 100)
Observation downsampling                        (84, 84)
Augmentation for policy learning                Random crop
Augmentation for auxiliary task                 Random crop and intensity
Replay buffer size                              100000
Initial exploration steps                       1000
Action repeat                                   2 (Finger, spin and Walker, walk); 8 (Cartpole, swingup); 4 otherwise
Evaluation episodes                             10
Optimizer                                       Adam
Adam (β1, β2)                                   (0.9, 0.999)
Adam (β1, β2) (temperature in SAC)              (0.5, 0.999)
Learning rate                                   0.0002 (Cheetah, run); 0.001 otherwise
Learning rate                                   0.0001 (Cheetah, run); 0.0005 otherwise
Learning rate warmup                            6000 steps
Learning rate                                   0.0001
Batch size for policy learning                  512
Batch size for auxiliary task                   128
Q-function EMA                                  0.99
Critic target update freq                       2
Discount factor                                 0.99
Initial temperature                             0.1
Target network update period                    1
Target network EMA                              0.9 (Walker, walk); 0.95 otherwise
State representation dimension                  50
Important Hyperparameters in MLR
Weight of MLR loss                              1
Mask ratio                                      50%
Sequence length                                 16
Cube spatial size                               10 × 10
Cube depth                                      4 (Cartpole, swingup and Reacher, easy); 8 otherwise
Decoder depth (number of attention layers)      2
Table 8: Hyperparameters used for DMControl environments.