Catastrophic Interference in Reinforcement Learning: A Solution Based on Context Division and Knowledge Distillation

by Tiantian Zhang, et al.
Tsinghua University

The powerful learning ability of deep neural networks enables reinforcement learning (RL) agents to learn competent control policies directly from high-dimensional and continuous environments. In theory, to achieve stable performance, neural networks assume i.i.d. inputs, which unfortunately does not hold in the general RL paradigm, where the training data is temporally correlated and non-stationary. This issue may lead to the phenomenon of "catastrophic interference" and a collapse in performance, as later training is likely to overwrite and interfere with previously learned policies. In this paper, we introduce the concept of "context" into single-task RL and develop a novel scheme, termed Context Division and Knowledge Distillation (CDaKD) driven RL, to divide all states experienced during training into a series of contexts. Its motivation is to mitigate the aforementioned challenge of catastrophic interference in deep RL, thereby improving the stability and plasticity of RL models. At the heart of CDaKD is a value function, parameterized by a neural network feature extractor shared across all contexts, and a set of output heads, each specializing in an individual context. In CDaKD, we exploit online clustering to achieve context division, and interference is further alleviated by a knowledge distillation regularization term on the output layers for learned contexts. In addition, to effectively obtain the context division in high-dimensional state spaces (e.g., image inputs), we perform clustering in the lower-dimensional representation space of a randomly initialized convolutional encoder, which is fixed throughout training. Our results show that, with various replay memory capacities, CDaKD can consistently improve the performance of existing RL algorithms on classic OpenAI Gym tasks and the more complex high-dimensional Atari tasks, incurring only moderate computational overhead.





I Introduction

In recent years, the successful application of deep neural networks (DNNs) in reinforcement learning (RL) [sutton2018reinforcement] has provided a new perspective to boost its performance on high-dimensional continuous problems. With the powerful function approximation and representation learning capabilities of DNNs, deep RL is regarded as a milestone towards constructing autonomous systems with a higher level of understanding of the physical world. Currently, deep RL has demonstrated great potential on complex tasks, from learning to play video games directly from pixels [mnih2013playing, mnih2015human] to making immediate decisions on robot behavior from camera inputs [faust2018prm, chiang2019learning, francis2020long]. However, these successes remain fragile and prone to catastrophic interference due to the inherent limitation of DNNs in the face of non-stationary data distributions, and they rely heavily on a combination of various subtle strategies, such as experience replay [mnih2013playing] and fixed target networks [mnih2015human], or distributed training architectures [mnih2016asynchronous, bellemare2017distributional, espeholt2018impala].

Catastrophic interference is the primary challenge for many neural-network-based machine learning systems when learning over a non-stationary stream of data [mccloskey1989catastrophic]. It is normally investigated in multi-task continual learning (CL), mainly including supervised continual learning (SCL) for classification tasks [kirkpatrick2017overcoming, fernando2017pathnet, lopez2017gradient, rebuffi2017icarl, mallya2018packnet, delange2021continual] and continual reinforcement learning (CRL) [kirkpatrick2017overcoming, fernando2017pathnet, riemer2019learning, kessler2020unclear, khetarpal2020towards] for decision tasks. In multi-task CL, the agent continually faces new tasks, and the neural network may quickly fit to the data distribution of the current task while potentially overwriting the information related to learned tasks, leading to catastrophic forgetting of the solutions of old tasks. The underlying reason behind this phenomenon is the global generalization and overlapping representations of neural networks [ghiassian2020improving, bengio2020interference]. Neural network training normally assumes that the inputs are independently and identically distributed (i.i.d.) from a fixed data distribution and that the output targets are sampled from a fixed conditional distribution. Only when this assumption is satisfied can positive generalization be ensured among different batches of stochastic gradient descent. However, when the data distribution drifts during training, the information learned from old tasks may suffer negative interference or even be overwritten by the newly updated weights, resulting in catastrophic interference.

Deep RL is essentially a CL problem due to its learning mode of exploring while learning [khetarpal2020towards], and it is particularly vulnerable to catastrophic interference, even within many single-task settings (such as Atari 2600 games, or even simpler classic OpenAI Gym environments) [schaul2019ray, fedus2020catastrophic, lo2019overcoming, liu2019utility]. The non-stationarity of data distributions in single-task RL is mainly attributed to the following properties of RL. First, the inputs of RL are sequential observations received from the environment, which are temporally correlated. Second, as learning progresses, the agent's decision-making policy changes gradually, which makes the observations non-stationary. Third, RL methods rely heavily on bootstrapping, where the RL agent uses its own estimated value function as the target, making the target outputs also non-stationary. In addition, as noted in [fedus2020catastrophic], replay buffers with prioritized experience replay [schaul2016prioritized], which preferentially sample experiences with higher temporal-difference (TD) errors, will also exacerbate the non-stationarity of training data. Once the distribution of training data encounters notable drift, catastrophic interference and a chain reaction are likely to occur, resulting in a sudden deterioration of training performance, as shown in Fig. 1.

Currently, there are two major strategies for dealing with catastrophic interference in single-task RL training: experience replay [mnih2013playing, mnih2015human] and local optimization [liu2019utility, lo2019overcoming, ghiassian2020improving]. The former usually exhibits extreme sensitivity to key parameters (e.g., replay buffer size) and often requires maintaining a large experience storage memory. Furthermore, when faced with imbalanced state data, interference may still occur even if the memory is sufficiently large. The latter advocates local network updating for data with a specific distribution, instead of global generalization, to reduce the representation overlap among different data distributions. The major issues are that some methods are limited in their capability of model transfer among differently distributed data [liu2019utility, lo2019overcoming], cannot scale to high-dimensional complex tasks [ghiassian2020improving], or require pretraining and may not be suitable for the online and incremental setting [liu2019utility].

In this paper, we focus on the catastrophic interference problem caused by state distribution drift in single-task RL. We propose a novel scheme with low buffer-size sensitivity, called Context Division and Knowledge Distillation (CDaKD), that estimates the value function online and incrementally for each state distribution by minimizing the weighted sum of the original loss function of RL algorithms and a regularization term addressing the interference among different groups of states. The schematic architecture is shown in Fig. 3.

In order to mitigate the interference among different state distributions during model training, we introduce the concept of "context" into single-task RL, and propose a novel context division strategy based on online clustering. We show that it is essential to decouple the correlations among different state distributions with this strategy to divide the state space into a series of independent contexts (each context is a set of states distributed close to each other, conceptually similar to a "task" in multi-task CRL). To achieve efficient and adaptive partition, we employ Sequential K-Means Clustering [dias2008skm] to process the states encountered during training in real time. Then, we parameterize the value function by a neural network with multiple output heads, as commonly used in multi-task learning [zenke2017continual, golkar2019continual, kessler2020unclear], in which each output head specializes in a specific context, and the feature extractor is shared across all contexts. In addition, we apply knowledge distillation as a regularization term in the objective function for value function estimation, which can preserve the learned policies while the RL agent is trained on states guided by the current policy, to further avoid the interference caused by the shared low-level representation. Furthermore, to ease the curse of dimensionality in high-dimensional state spaces, we employ a random encoder, as its low-dimensional representation space can effectively capture the information about the similarity among states without any representation learning [seo2021state]. Clustering is then performed in the low-dimensional representation space of the randomly initialized convolutional encoder.
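The random-encoder idea can be illustrated with a minimal numpy sketch. Here a fixed random linear projection stands in for the paper's randomly initialized convolutional encoder; the function name and dimensions are illustrative, not from the paper:

```python
import numpy as np

def make_random_encoder(input_dim, latent_dim, seed=0):
    """Return a frozen random projection: the weights are drawn once at
    initialization and never trained, mirroring the fixed random encoder."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=1.0 / np.sqrt(input_dim), size=(input_dim, latent_dim))

    def encode(state):
        # Flatten a (possibly image-shaped) observation and project it
        # into the low-dimensional space where clustering is performed.
        return np.ravel(state) @ W

    return encode
```

Context division would then cluster `encode(s)` instead of the raw state `s`, so that centroid distances are computed in the small latent space rather than over raw pixels.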

Finally, to validate the efficacy of CDaKD, we conduct extensive experiments on several OpenAI Gym standard benchmark environments containing 4 basic MDP control environments [brockman2016openai] and 6 high-dimensional complex Arcade Learning Environments (ALE) [bellemare2013arcade]. Compared to existing experience-replay-based and local-optimization-based methods, CDaKD features state-of-the-art performance on classic control tasks (CartPole-v0, Pendulum-v0, CartPole-v1, Acrobot-v1) and Atari tasks (Pong, Breakout, Carnival, Freeway, Tennis, FishingDerby). The main contributions of this paper are summarized as follows:

  • A novel context division strategy is proposed for the single-task RL. It is essential as the widely studied multi-task CRL methods cannot be used directly to reduce interference in the single-task RL due to the lack of predefined task boundaries. This strategy can detect contexts adaptively online, so that each context can be regarded as a task in multi-task settings. In this way, we bridge the gap between the multi-task CRL and the single-task RL in terms of the catastrophic interference problem.

  • A novel RL training scheme called CDaKD based on multi-head neural networks is proposed following the context division strategy. By incorporating the knowledge distillation loss into the objective function, our method can better alleviate the interference suffered in the single-task RL than existing methods in an online and incremental manner.

  • A high-dimensional state representation method based on the random convolutional encoder is introduced, which further boosts the performance of CDaKD on high-dimensional complex RL tasks.

  • The CDaKD framework is highly flexible and can be incorporated into various RL models. Experiments on classic control tasks and high-dimensional complex tasks demonstrate the overall superiority of our method over baselines in terms of plasticity and stability, confirming the effectiveness of the context division strategy.

The rest of this paper is organized as follows. Section II reviews the relevant strategies for alleviating catastrophic interference as well as context detection and identification. Section III introduces the nature of RL in terms of continuous learning and gives a definition of catastrophic interference in the single-task RL. The details of CDaKD are shown in Section IV, and experimental results and analyses are presented in Section V. Finally, this paper is concluded in Section VI with some discussions and directions for future work.

II Related Work

Catastrophic interference within the single-task RL is a special case of CRL, which involves not only the strategies to mitigate interference but also the context detection and identification techniques.

II-A Multi-task Continual Reinforcement Learning

Multi-task CRL has been an active research area with the development of a variety of RL architectures [lesort2020continual]. In summary, existing methods mainly consist of three categories: experience replay methods, regularization-based methods, and parameter isolation methods.

The core idea of experience replay is to store samples of previous tasks in raw format (e.g., Selective Experience Replay (SER) [isele2018selective], Meta Experience Replay (MER) [riemer2019learning], Continual Learning with Experience And Replay (CLEAR) [rolnick2019experience]) or to generate pseudo-samples from a generative model (e.g., Reinforcement Pseudo Rehearsal (RePR) [atkinson2021pseudo]) to maintain the knowledge about the past in the model. These previous task samples are replayed while learning a new task to alleviate interference, in the form of either being reused as model inputs for rehearsal [isele2018selective, atkinson2021pseudo] or constraining the optimization of the new task loss [rolnick2019experience, riemer2019learning]. As a result, experience replay has become a very successful approach to tackling interference in CRL. However, experience replay in raw format may result in significant storage requirements for more complex CRL settings. Although generative models remove the need for a replay buffer of raw samples, it is still difficult for them to capture the overall distribution of previous tasks.

Regularization-based methods avoid storing raw inputs, and thus alleviate the memory requirements, by introducing an extra regularization term into the loss function to consolidate previous knowledge while learning on new tasks. The regularization term includes penalty computing and knowledge distillation. The former focuses on reducing the chance of weights being modified. For example, Elastic Weight Consolidation (EWC) [kirkpatrick2017overcoming] and UNcertainty guided Continual LEARning (UNCLEAR) [kessler2020unclear] use the Fisher information matrix to measure the importance of weights and protect important weights on new tasks. The latter is a form of knowledge transfer [hinton2015distilling], which expects that the model trained on a new task can still perform well on the old ones. It is often used for policy transfer from one model to another (e.g., Policy Distillation [rusu2016policy], Genetic Policy Optimization (GPO) [gangwani2018policy], Distillation for Continual Reinforcement learning (DisCoRL) [traore2019discorl]). This family of solutions is easy to implement and tends to perform well on a small number of tasks, but still faces challenges as the number of tasks increases.

Parameter isolation methods dedicate different model parameters to each task, to prevent any possible interference among tasks. Without constraints on the size of neural networks, one can grow new branches for new tasks while freezing previous task parameters (e.g., Progressive Neural Networks (PNN) [rusu2016progressive]). Alternatively, the architecture remains static, with fixed parts being allocated to each task. For instance, PathNet [fernando2017pathnet] uses a genetic algorithm to find a path from input to output for each task in the neural network and isolates the used network parts at the parameter level from new task training. These methods typically require networks with enormous capacity, especially when the number of tasks is large, and there is often unnecessary redundancy in the network structure, bringing a great challenge to model storage and efficiency.

II-B Single-task Reinforcement Learning

Compared with the multi-task CRL, catastrophic interference in the single-task RL remains an emerging research area, which has been relatively under-explored. There are two primary aspects of previous studies: one is finding supporting evidence to confirm that catastrophic interference is indeed prevalent within a specific RL task, and the other is proposing effective strategies for dealing with it.

Researchers in DeepMind studied the learning dynamics of the single-task RL and developed a hypothesis that the characteristic coupling between learning and data generation is the main cause of interference and performance plateaus in deep RL systems [schaul2019ray]. Recent studies further confirmed the above hypothesis and its universality in the single-task RL through large-scale empirical studies (called Memento experiments) in Atari 2600 games [fedus2020catastrophic]. However, none of these studies has suggested any practical solution for tackling the interference.

In order to mitigate interference, many deep RL algorithms such as DQN [mnih2015human] and its variants (e.g., Double DQN [van2016deep], Rainbow [hessel2018rainbow]) employ experience replay and fixed target networks to produce approximately i.i.d. training data, which may quickly become intractable in terms of memory requirement as task complexity increases. Furthermore, even with sufficient memory, it is still possible to suffer from catastrophic interference due to the imbalanced distribution of experiences.

In recent studies [liu2019utility, lo2019overcoming, ghiassian2020improving], researchers proposed methods based on local representation and optimization of neural networks, showing that interference can be reduced by promoting the local updating of weights while avoiding global generalization. Sparse Representation Neural Network (SRNN) [liu2019utility] induces sparse representations in neural networks by introducing a distributional regularizer, which requires a large batch of data generated by a fixed policy that covers the space for pretraining. Dynamic Self-Organizing Map (DSOM) [lo2019overcoming] with neural networks introduces a DSOM module to induce such local updates. These methods can reduce interference to some extent, but they may inevitably suffer from the lack of positive transfer in the representation layer, which is not desirable in complex tasks. Recently, discretizing (D-NN) and tile coding (TC-NN) were used to remap the input observations to a high-dimensional space to sparsify input features, reducing the activation overlap [ghiassian2020improving]. However, tile coding increases the dimension of inputs to a neural network, which can lead to scalability issues for spaces with high dimensionality.

II-C Context Detection and Identification

Context detection and identification is a fundamental step for learning task relatedness in CL. Most multi-task CL methods mentioned above rely on well-defined task boundaries, and are usually trained on a sequence of tasks with known labels or boundaries. Existing context detection approaches commonly leverage statistics or Bayesian inference to detect task boundaries.

On the one hand, some methods tend to be reactive to a changing distribution by finding change points in the pattern of state-reward tuples (e.g., Context QL [padakandla2020reinforcement]), or tracking the difference between the short-term and long-term moving average rewards (e.g., CRL-Unsup [lomonaco2020continual]), or splitting a game into contexts using the undiscounted accumulated game score as a task contextualization [jain2020algorithmic]. These methods can be agile in responding to scenarios with abrupt changes among contexts or tasks, but are insensitive to smooth transitions from one context to another.

On the other hand, some more ambitious approaches try to learn a belief of the unobserved context state directly from the history of environment interactions, such as Forget-me-not Process (FMN) [milan2016forget] for piecewise-repeating data generating sources, and Continual Unsupervised Representation Learning (CURL) [rao2019continual] for task inference without any knowledge about task identity. However, they both need to be pretrained with the complete data when applied to CL problems, and CURL itself also requires additional techniques to deal with catastrophic interference.

Furthermore, Ghosh et al. [ghosh2018divide] proposed to partition the initial state space into a finite set of contexts by performing a K-Means clustering procedure, which can decompose more complex tasks, but cannot completely decouple the correlations among different state distributions from the perspective of interference prevention.

III Preliminaries and Problem Statement

To better characterize the problem studied in this paper, some key definitions and terms of CRL problems are introduced in this section.

III-A Definitions and Glossaries

Some important definitions of RL relevant to this paper are presented as follows.

Definition 1 (RL Paradigm [sutton2018reinforcement]).

An RL problem is regarded as a Markov Decision Process (MDP), which is defined as a tuple $\langle \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R}, \gamma \rangle$, where $\mathcal{S}$ is the set of states; $\mathcal{A}$ is the set of actions; $\mathcal{P}: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \rightarrow [0, 1]$ is the environment transition probability function; $\mathcal{R}: \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}$ is the reward function; and $\gamma \in [0, 1)$ is the discount factor.

According to Definition 1, at each time step $t$, the agent moves from state $s_t$ to $s_{t+1}$ with probability $\mathcal{P}(s_{t+1} \mid s_t, a_t)$ after taking action $a_t$, and receives reward $r_t = \mathcal{R}(s_t, a_t)$. Based on this definition, the optimization objective of value-based RL models is defined as follows:

Definition 2 (RL Optimization Objective [khetarpal2020towards]).

The optimization objective of value-based RL is to learn a policy $\pi_\theta$ with internal parameter $\theta$ that maximizes the expected long-term discounted return for each state $s_t$ in time, also known as the value function:

$$V^{\pi_\theta}(s_t) = \mathbb{E}_{\pi_\theta} \Big[ \sum_{i=0}^{\infty} \gamma^{i} r_{t+i} \Big].$$

Here, the expectation is over the process that generates a history using $\mathcal{P}$ and decides actions from $\pi_\theta$ until the end of the agent's lifetime.

The optimization objective in Definition 2 does not just concern itself with the current state, but also the full expected future distribution of states. As such, it is possible to overcome catastrophic interference for RL over non-stationary data distributions. However, much of the recent work in RL has been in so-called episodic environments, which optimize the episodic RL objective:

Definition 3 (Episodic RL Optimization Objective [khetarpal2020towards]).

Given some future horizon $T$, find a policy $\pi_\theta$ optimizing the expected discounted return:

$$V^{\pi_\theta}_{T}(s_t) = \mathbb{E}_{\pi_\theta} \Big[ \sum_{i=0}^{T} \gamma^{i} r_{t+i} \Big].$$
Here, to ensure the feasibility and ease of implementation of optimization, the objective is only optimized over a future horizon until the current episode terminates.

It is clear that the episodic objective in Definition 3 is biased towards the current episode distribution while ignoring the possibly far more important future episode distributions over the agent’s lifetime. Plugging in such an objective directly into the non-stationary RL settings leads to biased optimization, which is likely to cause catastrophic interference effects.

For large-scale domains, the value function is often approximated with a member of a parametric function class, such as a neural network with parameter $\theta$, expressed as $Q(s, a; \theta)$, which is fit online using experience samples of the form $(s_t, a_t, r_t, s_{t+1})$. This experience is typically collected into a buffer $\mathcal{D}$, from which batches are later drawn at random to form a stochastic estimate of the loss:

$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim p(\mathcal{D})} \Big[ \ell \big( r + \gamma \max_{a'} Q(s', a'; \tilde{\theta}) - Q(s, a; \theta) \big) \Big],$$

where $\ell$ is the agent's loss function, and $p(\mathcal{D})$ is the distribution that defines the sampling strategy. In general, the parameter $\tilde{\theta}$ used to compute the target is a prior copy of the $\theta$ used for action selection (as in the settings of DQN [mnih2015human]).
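As a concrete illustration, the bootstrapped target and a Huber-style batch loss can be sketched in numpy. This is a hedged sketch of the standard DQN-style loss, not the authors' exact implementation; function names are ours:

```python
import numpy as np

def td_targets(rewards, next_q_values, dones, gamma=0.99):
    """Bootstrapped targets r + gamma * max_a' Q(s', a'; theta^-), where
    next_q_values come from a frozen copy of the parameters (target net).
    `dones` masks out the bootstrap term at episode boundaries."""
    return rewards + gamma * (1.0 - dones) * next_q_values.max(axis=1)

def huber_loss(td_errors, kappa=1.0):
    # Quadratic near zero, linear in the tails, averaged over the batch.
    abs_err = np.abs(td_errors)
    quad = np.minimum(abs_err, kappa)
    return np.mean(0.5 * quad**2 + kappa * (abs_err - quad))
```

In a full agent, `td_errors` would be `td_targets(...) - Q(s, a; theta)` for a sampled batch, and the gradient of `huber_loss` would drive the update of `theta`.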

In addition, it is necessary to clarify some important terms in relation to CL.

1) Non-stationary [hadsell2020embracing]: a process whose state or probability distribution changes with time.

2) Interference [bengio2020interference]: a type of influence between two gradient-based processes with objectives $J_1$, $J_2$, sharing parameter $\theta$. Interference is often characterized in the first order by the inner product of their gradients:

$$\rho_{1,2} = \nabla_\theta J_1 \cdot \nabla_\theta J_2,$$

and can be seen as being constructive ($\rho_{1,2} > 0$, transfer) or destructive ($\rho_{1,2} < 0$, interference), when applying a gradient update using $\nabla_\theta J_1$ on the value of $J_2$.
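This first-order measure is straightforward to compute for two flattened gradient vectors; a minimal sketch (the function name is ours):

```python
import numpy as np

def interference(grad_j1, grad_j2):
    """First-order interference between two objectives sharing theta:
    positive inner product -> constructive (transfer),
    negative inner product -> destructive (interference)."""
    return float(np.dot(np.ravel(grad_j1), np.ravel(grad_j2)))
```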

3) Catastrophic Interference [hadsell2020embracing]: a phenomenon observed in neural networks where learning a new task significantly degrades the performance on previous tasks.

III-B Problem Statement

The interference within single-task RL can be approximately measured by the difference in TD errors before and after a model update under the current policy, referred to as Approximate Expected Interference (AEI) [liu2020towards]:

$$\mathrm{AEI} = \mathbb{E}_{(s, a) \sim d^{\pi}} \big[ \delta_{\theta_{t+1}}^{2}(s, a) - \delta_{\theta_t}^{2}(s, a) \big], \quad (5)$$

where $d^{\pi}$ is the distribution of $(s, a)$ under the current policy and $\delta_\theta$ is the TD error.

To illustrate the interaction between interference and the agent's performance during single-task RL training, we run an experiment on CartPole using the DQN implemented in OpenAI Baselines (a set of high-quality implementations of RL algorithms released by OpenAI), setting the replay buffer size to 100, a small capacity chosen to trigger interference and highlight its effect. We trained the agent for 300K environment steps and evaluated the AEI value according to Eq. (5) after each update, approximating the expectation with a buffer of capacity 10K containing recent transitions. Fig. 2 shows two segments of the interference and performance curves during training, from which we can see that the performance started to oscillate when AEI started to increase (e.g., at the steps marked in Fig. 2(a) and Fig. 2(b)). In general, the performance of the agent tends to drop significantly in the presence of increasing interference. This result provides direct evidence that interference is closely correlated with the stability and plasticity of the single-task RL model.
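The AEI evaluation described above can be sketched as follows, assuming Eq. (5) is read as the mean change in squared TD error over a held-out buffer of recent transitions (our reading; the exact form is given in [liu2020towards]):

```python
import numpy as np

def approximate_expected_interference(td_errors_before, td_errors_after):
    """Estimate AEI on a held-out buffer: the mean change in squared TD
    error caused by one model update. Positive values indicate the update
    degraded the fit on previously well-estimated transitions."""
    before = np.asarray(td_errors_before, dtype=float)
    after = np.asarray(td_errors_after, dtype=float)
    return float(np.mean(after**2 - before**2))
```

In the experiment above, `td_errors_before` and `td_errors_after` would be computed on the same 10K-transition evaluation buffer immediately before and after each gradient update.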

From the analysis above, we state the problem investigated in this paper as: proposing a novel and effective training scheme for the single-task RL, to alleviate catastrophic interference and reduce performance oscillation during training, improving stability and plasticity simultaneously.

(a) Phase I
(b) Phase II
Fig. 2: The interference (blue) and training performance (yellow) curve segments of a DQN agent on CartPole. The interference is measured as the expectation in Eq. (5), and the performance is evaluated by the sum of discounted rewards per episode. (a) Phase I; (b) Phase II.

IV The Proposed Method

Fig. 3: An overview of the CDaKD scheme. This framework consists of three components: 1) Context division, including state assignment and centroid update; adaptive context division is achieved online using Sequential K-Means Clustering. 2) Knowledge distillation; the knowledge distillation loss is incorporated into the objective function to avoid interference among contexts due to the shared feature extractor. 3) Joint optimization with a multi-head neural network, which aims to estimate the value of each input state with the joint optimization loss. In summary, our method can improve performance by decoupling the correlations among differently distributed states and intentionally preserving the learned policies.

In this section, we give a detailed description of our CDaKD scheme whose architecture is shown in Fig. 3. CDaKD consists of three main components, which are jointly optimized to mitigate catastrophic interference in the single-task RL: context division, knowledge distillation, and the collaborative training of the multi-headed neural network. On the basis of CDaKD, we further propose CDaKD-RE with a random encoder for the efficient contextualization of high-dimensional state spaces.

As mentioned before, catastrophic interference is an undesirable byproduct of global updates to the neural network weights on data whose distribution changes over time. A rational solution to this issue is to estimate an individual value function for each distribution, instead of using a single value function for all distributions. When an agent updates its value estimation of a state, the update should only affect the states within the same distribution. With this intuition in mind, we adopt a multi-head neural network with shared representation layers to parameterize the distribution specific value functions.

The CDaKD scheme proposed in this paper can be incorporated into any existing value-based RL method to train a piecewise $Q$-function for single-task RL. The neural network is parameterized by a shared feature extractor and a set of linear output heads, one per context. As shown in Fig. 3, the set of weights of the $Q$-function is denoted by $\theta = \{\theta^{s}, \theta^{1}, \ldots, \theta^{K}\}$, where $\theta^{s}$ is the set of shared parameters, while the remaining ones are context-specific parameters: $\theta^{k}$ is for the context that corresponds to the current input state $s_t$, and the other heads are for the remaining contexts. In this section, we take the combination of CDaKD and DQN as an illustrative example.
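A minimal linear sketch of such a multi-head value network (the class name, layer sizes, and single-linear-layer extractor are illustrative, not the paper's architecture):

```python
import numpy as np

class MultiHeadQ:
    """Shared feature extractor with one linear output head per context."""
    def __init__(self, state_dim, n_actions, n_contexts, hidden=64, seed=0):
        rng = np.random.default_rng(seed)
        # Shared representation parameters (theta^s in the paper's notation).
        self.W_shared = rng.normal(scale=0.1, size=(state_dim, hidden))
        # One set of head parameters per context.
        self.heads = [rng.normal(scale=0.1, size=(hidden, n_actions))
                      for _ in range(n_contexts)]

    def q_values(self, state, context):
        # Shared representation, then the head owned by this context.
        features = np.tanh(state @ self.W_shared)
        return features @ self.heads[context]
```

At decision time, the context index comes from the context division module, so each state is evaluated only by the head that owns its context.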

IV-A Context Division

(a) CartPole-v0
(b) Pendulum-v0
Fig. 4: Measuring the interference among contexts by clustering all experienced states when the agent is trained on CartPole-v0 and Pendulum-v0 for 400K environment steps, respectively. We record the relative changes in Huber loss for all contexts when the agent is trained on a particular context. It is clear that training on a particular context generally reduces the loss on itself and increases the losses on all other contexts.

In MDPs, states (or "observations") represent the most comprehensive information regarding the environment. To better understand states of different distributions, we define a variable for a set of states that are close to each other in the state space, referred to as a "context". Formally,

$$\mathcal{C} = \{c_1, c_2, \ldots, c_K\}, \quad (6)$$

where $\mathcal{C}$ is a finite set of contexts and $K$ is the number of contexts. For an arbitrary MDP, we partition its state space $\mathcal{S}$ into $K$ contexts, where all states within each context follow approximately the same distribution, to decouple the correlations among states against distribution drift. More precisely, for a partition $\{\mathcal{S}_1, \ldots, \mathcal{S}_K\}$ of $\mathcal{S}$ in Eq. (6), we associate a context $c_k$ with each set $\mathcal{S}_k$, so that for $s \in \mathcal{S}_k$, $c(s) = c_k$, where $c(\cdot)$ can be thought of as a function of state $s$.

The inherent learning-while-exploring feature of RL agents leads to the fact that the agent generally does not experience all possible states of the environment while searching for the optimal policy. Thus, it is unnecessary to process the entire state space. Based on this fact, in CDaKD, we only perform context division on states experienced during training. In this paper, we employ Sequential K-Means Clustering [dias2008skm] to achieve context detection adaptively (See Appendix A for more details).
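Sequential K-Means processes one state at a time: assign it to the nearest centroid, then move that centroid slightly toward the state. A minimal sketch, with the learning-rate schedule simplified to a constant (per-cluster count-based schedules are also common, and the paper's exact variant may differ):

```python
import numpy as np

class SequentialKMeans:
    """Online k-means for context division: each incoming state is
    assigned to its nearest centroid (State Assignment), and that
    centroid is nudged toward the state (Centroid Update)."""
    def __init__(self, n_contexts, state_dim, lr=0.05, seed=0):
        rng = np.random.default_rng(seed)
        self.centroids = rng.normal(size=(n_contexts, state_dim))
        self.lr = lr

    def assign(self, state):
        # State Assignment: index of the closest centroid (the context id).
        dists = np.linalg.norm(self.centroids - state, axis=1)
        return int(np.argmin(dists))

    def update(self, state):
        # Centroid Update: move the winning centroid toward the state.
        k = self.assign(state)
        self.centroids[k] += self.lr * (state - self.centroids[k])
        return k
```

Each call to `update` returns the context label that CDaKD would store alongside the transition in the replay buffer.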

In Fig. 3, the K centroids are initialized at random in the state space. In each subsequent time step, we execute the State Assignment and Centroid Update steps for each incoming state received from the environment, and store its corresponding transition into the replay buffer. Accordingly, in the training phase, we randomly sample a batch of transitions from the buffer and train the shared feature extractor together with the output head corresponding to each input state, while fine-tuning the other output heads to avoid interference with learned policies. Since we store the context label of each state in the replay buffer, no additional state assignments are required at update time: in CDaKD, we only need to perform state assignment once for each state.

Note that it is also possible to conduct context division based on the initial state distribution [ghosh2018divide]. By contrast, we show that the partition of all states experienced during training can produce more accurate and effective context division results, as the trajectories starting from the initial states within different contexts have a high likelihood of overlapping in subsequent time steps (See Appendix B for more detailed experiments and analysis).

Interference Among Contexts: We investigate the interference among the contexts obtained by our context division method in detail. Specifically, we measure the Huber loss of TD errors in different contexts as the agent learns in other contexts, and record the relative changes in loss before and after the agent’s learning, as shown in Fig. 4. The results show that long-term training on any context may lead to negative generalization on all other contexts, even in such simple RL tasks as CartPole-v0 and Pendulum-v0.

Computational Complexity: Assuming a d-dimensional state space and K contexts, the time and space complexities of our proposed context division module to process T environment steps are O(dKT) and O(dK), respectively.

IV-B Knowledge Distillation

The shared low-level representation can cause learning in new contexts to interfere with previously learned results, leading to catastrophic interference. A relevant technique to address this issue is knowledge distillation [hinton2015distilling], which encourages the outputs of one network to approximate those of another. The concept of distillation was originally used to transfer knowledge from a complex ensemble of networks to a relatively simple network, reducing model complexity and facilitating deployment. In CDaKD, we use it as a regularization term in value function estimation to preserve previously learned information.

When training the model on a specific context, the loss function comprises two parts: the general loss on the current training context and the distillation loss on the other contexts. The former encourages the model to adapt to the current context to ensure plasticity, while the latter encourages the model to retain the memory of the other contexts, preventing interference.

To incorporate CDaKD into the DQN framework, we rewrite the original loss function of DQN with the context variable as:

L_general(θ) = E_{(s, a, r, s′) ∼ U(D)} [ ℓ( y, Q(s, a, c(s); θ) ) ], (7)

where y = r + γ max_{a′} Q(s′, a′, c(s′); θ⁻) is the estimated target value of Q(s, a, c(s)), U(D) is the uniform distribution over samples in the replay buffer D, and ℓ is the agent’s loss function (e.g., the Huber loss).

For each of the other contexts that the environment contains, we expect the output value for each state–action pair to stay close to the output recorded from the original network. In terms of knowledge distillation, we regard the Q-function learned before the current update step as the teacher network, expressed as Q(s, a, c′; θ̂), and the current network to be trained as the student network, expressed as Q(s, a, c′; θ), for every context c′ except the current context c. Thus, the distillation loss is defined as:

L_distill(θ) = Σ_{c′ ≠ c} E_{(s, a) ∼ U(D)} [ ℓ_{c′}( Q(s, a, c′; θ̂), Q(s, a, c′; θ) ) ], (8)

where ℓ_{c′} is the distillation loss function of the output head corresponding to context c′.

IV-C Joint Optimization Procedure

To optimize a Q-function that can guide the agent to make proper decisions in each context without being adversely affected by catastrophic interference, we combine Eqs. (7) and (8) to form a joint optimization framework. Namely, we address the catastrophic interference problem via the following optimization objective:

L(θ) = L_general(θ) + λ · L_distill(θ), (9)

where λ is a coefficient that controls the trade-off between the plasticity and stability of the neural network.
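A minimal NumPy sketch of this joint objective for a single transition follows; it assumes a Huber TD loss on the current context's head and a mean-squared distillation penalty on the other heads, and all function and variable names are ours for illustration:

```python
import numpy as np

def huber(x, delta=1.0):
    # Standard Huber loss: quadratic near zero, linear in the tails.
    a = np.abs(x)
    return np.where(a <= delta, 0.5 * x ** 2, delta * (a - 0.5 * delta))

def joint_loss(q_student, q_teacher, td_target, a, c, lam):
    """Joint objective for one state: TD (Huber) loss on the head of the
    current context c, plus lam times a distillation penalty keeping the
    other heads close to their pre-update (teacher) outputs.

    q_student, q_teacher: arrays of shape (K, num_actions) holding the
    per-head Q-values of the student and teacher networks for one state.
    """
    l_general = huber(td_target - q_student[c, a]).item()
    other = [k for k in range(q_student.shape[0]) if k != c]
    l_distill = float(np.mean((q_student[other] - q_teacher[other]) ** 2))
    return l_general + lam * l_distill
```

In an actual agent the gradient of this scalar would be taken with respect to the student parameters only, with the teacher outputs treated as constants.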

The complete procedure is described in Algorithm 1. The proposed method learns the value function and the context division in parallel. For network training, to reduce correlations with the target and ensure training stability, the target network parameters are updated from the Q-network parameters only every fixed number of steps and are held fixed between individual updates, as in DQN [mnih2015human]. Similarly, we adopt fixed target context centroids to avoid the instability of RL training caused by constantly updated context centroids. To simplify the implementation, we set the update frequency of the target context centroids to be consistent with that of the target network.
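The periodic synchronization of both the target parameters and the target centroids can be sketched as below; the function name and the dictionary layout are hypothetical, standing in for whatever parameter containers the agent uses:

```python
import numpy as np

def maybe_sync(step, period, online, target):
    """Copy the online parameters and context centroids into their target
    copies every `period` steps, holding them fixed in between, in the
    spirit of DQN's target-network trick.  Names are illustrative."""
    if step % period == 0:
        for key, value in online.items():
            target[key] = value.copy()
    return target
```

Between synchronization points the targets used for TD bootstrapping and state assignment stay frozen, which is what keeps both the regression target and the context labels stable during an update interval.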

IV-D Random Encoders for High-dimensional State Spaces

Fig. 5: Illustration of CDaKD-RE. The context division is performed in the low-dimensional representation space of a random encoder. A separate RL encoder is used to work with the MLP layers to estimate the value function.

For high-dimensional state spaces, we propose to use random encoders for efficient context division, which map high-dimensional inputs into low-dimensional representation spaces, easing the curse of dimensionality. Although the original RL model already contains an encoder module, it is constantly updated, and directly performing clustering in its representation space may introduce extra instability into context division. Therefore, on the basis of CDaKD, we exploit a dedicated random encoder module for dimension reduction. Fig. 5 gives an illustration of this updated framework, called CDaKD-RE, in which the structure of the random encoder is consistent with the underlying RL encoder, but its parameters are randomly initialized and fixed throughout training. We provide the full procedure of CDaKD-RE in Appendix C.

The main motivation for using random encoders arises from the observation that distances in the representation space of a random encoder are adequate for finding similar states, without any representation learning [seo2021state] (See Appendix C).
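The effect can be demonstrated with a toy sketch: the paper uses a fixed random convolutional encoder, but a fixed random linear projection is enough to show that distances in a random, untrained representation space still rank state similarity (all names and dimensions below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_encoder(dim_in, dim_out=50, rng=rng):
    """A fixed random linear map standing in for the random conv encoder.
    The weights are drawn once and never trained."""
    W = rng.normal(size=(dim_out, dim_in)) / np.sqrt(dim_out)
    return lambda x: W @ x

enc = random_encoder(dim_in=1000)          # e.g., a flattened observation
s0 = rng.normal(size=1000)
s_near = s0 + 0.01 * rng.normal(size=1000)  # a slightly perturbed, similar state
s_far = rng.normal(size=1000)               # an unrelated state

# Distances in the 50-dimensional random representation space
# preserve the neighbor structure of the raw states.
d_near = np.linalg.norm(enc(s0) - enc(s_near))
d_far = np.linalg.norm(enc(s0) - enc(s_far))
```

This is the Johnson–Lindenstrauss intuition behind [seo2021state]: random projections approximately preserve pairwise distances, so clustering in the projected space yields a context division similar to clustering the raw states at a fraction of the cost.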

Input: Initial replay buffer with capacity ;
          Initial -function with random weights ;
          Initial target -function with weights ;
          Initial context centroids ;
          Initial target context centroids .
Parameter: Total training steps , the number of contexts ,
                target update period , learning rate .
Output: Updated and .

1:  Initial state ;
2:  for  do
3:     Interact with environment to obtain .
4:     State assignment: assign the state to the context of its nearest centroid.
5:     Store transition in .
6:     Context centroids update: .
7:     Joint optimization:    Sample mini-batch ;    Calculate , according to Eqs. (7) and (8);    Update parameter:                .
8:     if  then
9:        , .
10:     end if
11:  end for
Algorithm 1 CDaKD: Context Division and Knowledge Distillation-Driven Reinforcement Learning

V Experiments and Evaluations

In this section, we conduct comprehensive experiments on several standard benchmarks from OpenAI Gym (a publicly released implementation repository of RL environments), containing 4 classic control tasks and 6 high-dimensional complex Atari games, to demonstrate the effectiveness of our method (for the experiment code, please refer to the supplementary material). Several state-of-the-art methods, including experience replay based methods (e.g., DQN [mnih2015human], Rainbow [hessel2018rainbow]), local optimization based methods (e.g., SRNN [liu2019utility]), and other context division techniques (e.g., Game Score (GS) [jain2020algorithmic], Initial States (IS) [ghosh2018divide]), are employed as our baselines.

V-A Datasets

Classic Control [brockman2016openai] contains 4 classic control tasks: CartPole-v0, Pendulum-v0, CartPole-v1, and Acrobot-v1, where the dimensions of the state spaces range from 3 to 6. The maximum episode lengths are 200 time steps for CartPole-v0 and Pendulum-v0, and 500 for CartPole-v1 and Acrobot-v1. Meanwhile, the reward thresholds used to determine task success are 195.0, 475.0, and -100.0 for CartPole-v0, CartPole-v1, and Acrobot-v1, respectively, while no threshold is available for Pendulum-v0. We choose these domains as they are well understood and relatively simple, suitable for highlighting the mechanism and verifying the effectiveness of our method in a straightforward manner.

Atari Games [bellemare2013arcade] contain 6 image-level complex tasks: Pong, Breakout, Carnival, Freeway, Tennis, and FishingDerby, where the observation is a screenshot represented by an RGB image of size 210 × 160. We choose these domains to further demonstrate the scalability of our method on high-dimensional complex tasks.

V-B Implementation

Network Structure. For the 4 classic control tasks, we employ a fully-connected layer as the feature extractor and a fully-connected layer as the multi-head action scorer, following the network configuration for this type of task in OpenAI Baselines. For the 6 Atari games, we employ a convolutional neural network similar to [hessel2018rainbow], [castro18dopamine] for feature extraction (i.e., the RL Encoder in Fig. 5) and two fully-connected layers as the multi-head action scorer. In addition, for random encoders, we use convolutional neural networks with the same structure as the underlying RL methods, but we do not update their randomly initialized parameters during training [seo2021state]. Details of the networks can be found in Appendix D.

Parameter Setting. In CDaKD, there are two key parameters: the trade-off coefficient λ and the number of contexts K. To simplify parameter setting, we set λ in accordance with the exploration proportion ε of the agent in all experiments, exploiting the inverse relationship between them during training. During early training, ε is close to 1 and the model is normally inaccurate with little interference, so a small λ (close to 0) promotes the plasticity of the neural network. As ε gradually approaches 0 during subsequent training, the model has accumulated more and more useful information, and interference becomes more likely; smoothly increasing λ is then needed to ensure stability and avoid catastrophic interference. Meanwhile, we set K to 3 for all classic control tasks and 4 for all Atari games; larger values are suggested for complex environments to achieve a reasonable state-decoupling effect. In CDaKD-RE, there is an extra parameter: the output dimension of the random encoder, which we set to 50 as in [seo2021state], a choice shown to be both efficient and effective. More details of the parameter setting can be found in Appendix D. For classic control tasks, we evaluate the training performance using the average episode returns every 10K time steps for CartPole-v0 and Pendulum-v0, and every 20K time steps for CartPole-v1 and Acrobot-v1. For Atari games, the evaluation interval is 200K time steps. All reported results are average episode returns over 5 independent runs.
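The coupling between λ and the exploration proportion can be written down as a pair of schedules; the linear ε decay and the specific form λ = 1 − ε below are one simple instantiation consistent with the description above, not necessarily the paper's exact schedule:

```python
def exploration_eps(step, total_steps, eps_start=1.0, eps_end=0.05):
    """Linearly decayed exploration proportion (a common DQN-style schedule).
    The endpoints eps_start/eps_end are illustrative defaults."""
    frac = min(step / total_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)

def distill_coeff(eps):
    """Trade-off coefficient moving inversely to exploration:
    close to 0 early (plasticity), approaching 1 late (stability)."""
    return 1.0 - eps
```

With this pairing, no separate tuning of λ is needed: the distillation constraint automatically tightens as the agent shifts from exploring to exploiting its learned policies.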

V-C Baselines

We compare our method with five state-of-the-art methods, including DQN [mnih2015human] and Rainbow [hessel2018rainbow] based on experience replay, SRNN [liu2019utility] based on local optimization for alleviating catastrophic interference, and two context division techniques (i.e., Game Score (GS) [jain2020algorithmic], Initial States (IS) [ghosh2018divide]) for efficient contextualization. We briefly introduce these baselines in the following.

  • DQN [mnih2015human] is a representative algorithm of Deep RL, which reduces catastrophic interference using experience replay and fixed target networks. We use the DQN agent implemented in OpenAI Baselines.

  • Rainbow [hessel2018rainbow] is the upgraded version of DQN containing six extensions, including a prioritized replay buffer [schaul2016prioritized], n-step returns [sutton2018reinforcement], and distributional learning [bellemare2017distributional] for stable RL training. The Rainbow agent is implemented in Google’s Dopamine framework, a research framework developed by Google for fast prototyping of reinforcement learning algorithms.

  • SRNN [liu2019utility] employs a distributional regularizer to induce sparse representations in neural networks to avert catastrophic interference in the single-task RL.

  • CDaKD-GS is our extension to DQN using the CDaKD scheme where the context division is based on the undiscounted accumulated game score in [jain2020algorithmic] instead of all experienced states.

  • CDaKD-IS is another extension to DQN using the CDaKD scheme where the context division is based on the initial state distribution as in [ghosh2018divide] instead of all experienced states.

V-D Evaluation Metrics

Following the convention in previous studies [mnih2016asynchronous, bellemare2017distributional, espeholt2018impala, hessel2018rainbow, fedus2020catastrophic], we employ the average training episode return to evaluate our method during training:

R̄ = (1/E) Σ_{e=1}^{E} Σ_{t=1}^{T_e} r_{e,t},

where E is the number of episodes experienced within each evaluation period, T_e is the total number of time steps in episode e, and r_{e,t} is the reward received at time step t in episode e.
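This metric, the mean undiscounted return over the episodes finished in one evaluation period, amounts to a short helper (the function name is ours):

```python
def average_episode_return(episode_rewards):
    """Average undiscounted return over an evaluation period.

    episode_rewards: list of per-episode reward sequences, one inner
    list of step rewards per finished episode.
    """
    returns = [sum(rewards) for rewards in episode_rewards]  # per-episode returns
    return sum(returns) / len(returns)                        # mean over episodes
```

Note that no discounting is applied: the metric measures raw game score per episode, matching how the learning curves and the tables of cumulative scores are reported.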

V-E Results

TABLE I: Statistics of the highest cumulative scores achieved during training on Atari games by DQN, Rainbow, DQN+CDaKD-RE, and Rainbow+CDaKD-RE
(based on the performance of five runs in Fig. LABEL:fig:Atari_games).
TABLE II: Statistics of the maximum deterioration ratios suffered during training on Atari games by the same four methods
(based on the average performance of five runs in Fig. LABEL:fig:Atari_games).

The results on the 4 classic control tasks are presented in Fig. LABEL:fig:Classic_Control, showing the learning curves during training for each task with three levels of replay buffer capacity. Note that CDaKD-ES is our proposed scheme where the context division is based on all experienced states (in this paper, all appearances of CDaKD refer to CDaKD-ES unless otherwise stated). In general, CDaKD-ES is clearly superior to all baselines in terms of plasticity and stability, especially when the replay buffer capacity is small or experience replay is disabled altogether. In most tasks, CDaKD-ES achieves near-optimal performance as well as good stability even without any experience replay. For Pendulum-v0 and Acrobot-v1, a large replay buffer can help DQN and SRNN escape from catastrophic interference. However, this is not the case for the two CartPole tasks, where the agents exhibit fast initial learning but then suffer a collapse in performance.

The learning curves on the 6 Atari games are shown in Fig. LABEL:fig:Atari_games. Moreover, the highest cumulative returns (an indicator of plasticity) achieved during the training of each task are summarized in Table I, and the maximum deterioration ratios (an indicator of stability) relative to the previous maximal episode returns are given in Table II. Overall, for high-dimensional image inputs, the training performance of the original RL algorithms can be noticeably improved with our CDaKD-RE scheme. In Fig. LABEL:fig:Atari_games, our method significantly outperforms DQN on 7 out of 12 settings and is comparable with DQN on the remaining 5; similarly, our method outperforms Rainbow on 8 out of 12 settings and is comparable with Rainbow on the remaining 4. Furthermore, as shown in Table I and Table II, CDaKD-RE achieves higher maximum cumulative scores and less performance degradation in most tasks compared to its counterparts. Among the 24 training settings, only 2 maximum cumulative scores achieved by CDaKD-RE are slightly lower than those of the baselines. In terms of the average performance degradation ratio, DQN and Rainbow incorporated with CDaKD-RE surpass the original RL methods by clear margins. Note that, even with a large memory, CDaKD-RE still shows certain advantages over the baselines.

Moreover, we observe that: 1) The DQN and SRNN agents exhibit high sensitivity to the replay buffer capacity. They generally perform well with a large buffer (except on CartPole-v1), but their performance deteriorates significantly when the buffer capacity is reduced. DQN performs worst among all baselines when there is no experience replay, as both DQN and SRNN may face severe data drift when the buffer is small, whereas approximately i.i.d. training data are required to avoid interference in training. 2) CDaKD-GS and CDaKD-IS leverage the cumulative game score and the initial state distribution to partition contexts, respectively. However, neither is a perfect determining factor for context boundaries, making it difficult to achieve the necessary decoupling of differently distributed states. Furthermore, they require prior knowledge about the scores for each level of the game or the initial state distribution of the environment.

In summary, the proposed techniques, comprising context division based on the clustering of all experienced states and knowledge distillation in multi-head neural networks, can effectively mitigate the catastrophic interference caused by data drift in single-task RL. In addition, our method leverages a fixed, randomly initialized encoder to characterize the similarity among states in a low-dimensional representation space, which enables effective context partitioning for high-dimensional environments.

V-F Analysis

1) Ablation Study: Since our method can be regarded as an extension to existing RL methods (e.g., DQN [mnih2015human]), with three novel components (i.e., adaptive context division by online clustering, knowledge distillation, and the multi-head neural network), the ablation experiments are designed as follows:

  • No clustering means using a random partition of the raw state space instead of adaptive context division by online clustering.

  • No distillation means removing the distillation loss function from the joint objective function in Eq. (9) (i.e., λ = 0).

  • No multi-head means removing the context division module and optimizing the neural network with a single-head output (i.e., K = 1). Here, the distillation term becomes the distillation of the network before each update of the single output head.

The results of ablation experiments are shown in Fig. LABEL:fig:ablation_study, using classic control tasks for the convenience of validation. From Fig. LABEL:fig:ablation_study, the following facts can be observed: 1) Across all settings, the overall performance of DQN is the worst, showing the effectiveness of the three components introduced for coping with catastrophic interference in the single-task RL, although the contribution of each component varies substantially per task; 2) Removing online clustering from the context division module is likely to damage the performance in most cases; 3) Without the multi-head component, our model is equivalent to a DQN with an extra distillation loss, which can result in better performance in general than DQN; 4) Removing knowledge distillation makes the performance deteriorate on almost all tasks, indicating that knowledge distillation is a key element in our method.

2) Parameter Analysis: There are two critical parameters in CDaKD: λ and K. By its nature, λ is related to the training process. Since we need to preserve the learned good policies during training, it is intuitive to gradually increase λ until its value reaches 1. The reason is that, in early-stage training, the model has not yet learned sufficiently useful information, so the distillation constraint can be ignored. As training progresses, the model acquires more and more valuable information and needs to guard against interference to protect the learned good policies while learning further. In our experiments, we recommend setting λ to be inversely proportional to the exploration proportion ε, and the results in Figs. LABEL:fig:Classic_Control and LABEL:fig:Atari_games have demonstrated the simplicity and effectiveness of this setting.

To investigate the effect of K, which is related to the complexity of learning tasks, we conduct experiments with different values of K, and the results are shown in Fig. LABEL:fig:parameter_study. In our experiments, K = 3 is a reasonably good choice for CartPole-v0 and CartPole-v1, while a different value works better for CDaKD on Acrobot-v1. It is worth noting that, on Pendulum-v0, our method achieves similar performance with K set to 2, 3, 4, and 5, respectively, but none of these settings is fully satisfactory. A possible explanation is that the agents failed to learn any useful information due to the limited state space explored in early training, leading to the failure of further learning.

In summary, we can make the following statements: 1) The performance of DQN combined with CDaKD is significantly better than that of the original DQN regardless of the specific value of K, confirming the effectiveness of our CDaKD scheme; 2) With a properly chosen K, better performance of CDaKD can be expected. However, large values are not always desirable, as they result in more fine-grained context divisions and more complex neural networks with a large number of output heads, making the model unlikely to converge satisfactorily within a limited number of training steps. Thus, we recommend setting the value of K by taking into consideration the state-space structure of the specific task.

Fig. 10: Training curves tracking the agent’s average loss and average predicted action-value for 400K environment steps in Pendulum-v0 (see Fig. LABEL:fig:Classic_Control_b for corresponding curves). (a) Each point is the average loss achieved per training iteration; (b) Average maximum predicted action-value of agents on a held-out set of states; (c) Average maximum predicted action-value of each output head in CDaKD.
(a) Training time
(b) Computational complexity
Fig. 11: Comparison of computational efficiency: (a) Training time of each agent to achieve its performance for 400K environment steps in Pendulum-v0 (see Fig. LABEL:fig:Classic_Control_b for corresponding learning curves); (b) Number of FLOPs used by each agent at 10M environment steps in Breakout. Here, we only take into account forward and backward passes through neural network layers (See Fig. LABEL:fig:Atari_games for corresponding learning curves).

3) Convergence Analysis: To analyze convergence, we track the agent’s average loss and average predicted action-value during training. According to Fig. 10, we can conclude that: 1) Our method has better convergence and stability in the face of interference compared with the original RL algorithms (See Fig. 10(a) and Fig. 10(b)); 2) For a held-out set of states, the average maximum predicted action-value of each output head reflects the expected differences among contexts (See Fig. 10(c)), and the final output of the CDaKD agent is synthesized from all output heads.

4) Computational Efficiency: Our method is computationally efficient in that: 1) In each time step, the extra context division module only needs to compute the distances between the current state and the K context centroids, which is negligible w.r.t. the SGD complexity of RL itself; 2) Only extra linear output heads are added to the neural network, and the increased computation is acceptable w.r.t. the overall representation complexity; 3) There are no gradient updates through the random encoder; 4) There is no extra distance computation for finding the corresponding context at every update step, as the context label for each state is stored in the replay buffer. Fig. 11 shows the training time of each agent on Pendulum-v0 and the floating point operations (FLOPs) executed by the agents on Breakout, respectively.

VI Conclusion and Future Work

In this paper, we propose a novel scheme, CDaKD, to tackle the inherent challenge of catastrophic interference in single-task RL. The core idea is to partition all states experienced during training into a set of contexts using online clustering, and to estimate the context-specific value function with a multi-head neural network equipped with a knowledge distillation loss that mitigates interference across contexts. Furthermore, we introduce a random convolutional encoder to enhance context division for high-dimensional complex tasks. Our method can effectively decouple the correlations among differently distributed states and can be easily incorporated into various value-based RL models. Experiments on several benchmark tasks show that our method can significantly outperform state-of-the-art RL methods and dramatically reduce the memory requirements of existing RL methods.

In the future, we aim to incorporate our method into policy-based RL models to reduce interference during training by applying weight or functional regularization on policies. Furthermore, we will investigate a more challenging setting, continual RL in non-stationary environments [lomonaco2020continual]. This setting is a more realistic representation of real-world scenarios and includes abrupt changes or smooth transitions in the dynamics, or even reshuffling of the dynamics itself.


The work presented in this paper was supported by the National Natural Science Foundation of China (U1713214).


Appendix A Sequential K-Means Clustering

The process of Sequential K-Means Clustering for the current state is shown in Algorithm 2. Each context centroid is the average of all of the states closest to it. To obtain a better initialization of the centroids, we can perform offline K-Means clustering on all states experienced before training starts and use the resulting centroids as the initial values. Then, Sequential K-Means Clustering is performed in subsequent time steps.

Input: Current state ;
          Initial context centroids .
Output: Updated centroids .

1:  Maintain a count n_i of the samples assigned to each context;
2:  if the current state s is closest to centroid c_i then
3:     Increment the count: n_i ← n_i + 1;
4:     Update the centroid: c_i ← c_i + (1/n_i)(s − c_i);
5:  end if
6:  return the updated centroids.
Algorithm 2 SKM: Sequential K-Means Clustering
(a) ISC
(b) ESC
Fig. 12: Two-dimensional t-SNE results of context division according to different kinds of clustering objects on CartPole-v0: (a) ISC vs (b) ESC. Here, the colors represent different contexts and the points represent the states within the corresponding context.

Appendix B Initial States vs All Experienced States

We compare the effects of two kinds of context division techniques:

  • ISC (Initial States Clustering) means performing context division using K-Means on samples drawn from the initial state distribution, following [ghosh2018divide];

  • ESC (Experienced States Clustering) is our proposed technique, which performs context division using Sequential K-Means Clustering on all states experienced during the training process.

We run the DQN agent incorporated with CDaKD under each of the above two clustering techniques on CartPole-v0, and visualize the two-dimensional t-SNE results of context division in Fig. 12. It can be clearly observed that the contexts divided by ESC are relatively independent, with almost no overlapping areas among contexts, achieving effective decoupling among states with different distributions. By contrast, there are obvious overlapping areas among the contexts divided by ISC, since trajectories starting from initial states within different contexts have a high likelihood of overlapping in subsequent time steps, which is undesirable for reducing interference among differently distributed states during neural network training.

Appendix C Random Encoder

We find the K-nearest neighbors of some specific states by measuring distances in the low-dimensional representation space produced by a randomly initialized encoder (Random Encoder) on Breakout. The results are shown in Fig. 13, from which we can observe that the raw images corresponding to the K-nearest neighbors of the source image in the representation space of the random encoder exhibit significant similarities. We provide the full procedure of CDaKD with the random encoder in Algorithm 3.

Input: Initial replay buffer with capacity ;
          Initial -function with random weights ;
          Initial target -function with weights ;
          Initial random encoder with weights ;
          Initial context centroids ;
          Initial target context centroids .
Parameter: Total training steps , the number of contexts ,
                the output dimension of random encoder , target
                update period , learning rate .
Output: Updated and .

1:  Initial state ;
2:  for  do
3:     Interact with environment to obtain .
4:     State encoding: get the fixed representation for and , , .
5:     State assignment: assign the encoded state to the context of its nearest centroid.
6:     Store transition in .
7:     Context centroids update: .
8:     Joint optimization:    Sample mini-batch ;    Calculate , according to Eqs. (7) and (8);    Update parameter: .
9:     if  then
10:        ,
11:     end if
12:  end for
Algorithm 3 CDaKD-RE: Context Division with Random Encoder and Knowledge Distillation-Driven Reinforcement Learning
Fig. 13: Two-dimensional t-SNE visualization of K-nearest neighbors of states found by measuring distances in the representation space of a Random Encoder on Breakout. We observe that the representation space of a randomly initialized encoder effectively captures information about similarity between states.
Tasks Layer Input Filter size Stride Num filters Activation Output
Classic Control Tasks FC1 Dimension of state space - - Tanh
FC2 - - Number of actions Linear Number of actions
Atari Games Conv1 ReLU
Conv2 ReLU
Conv3 ReLU
FC4 - - ReLU
FC5 - - Linear
TABLE III: The neural network architectures of the underlying RL models used in the experiments.
Hyperparameter Classic Control Tasks Atari Games
Training time step
steps for CartPole-v0 and Pendulum-v0
steps for CartPole-v1 and Acrobot-v1
decay schedule
steps for CartPole-v0 and Pendulum-v0
steps for CartPole-v1 and Acrobot-v1
Min. history to start learning steps steps
Target network update frequency steps steps
Batch size
Learning rate
for DQN and DQN+CDaKD-RE
for Rainbow and Rainbow+CDaKD-RE
TABLE IV: The common hyperparameters of each method used in the experiments.

Appendix D Implementation Details

To ensure the fairness of comparison, our results compare the agents based on the underlying RL model with the same hyperparameters and neural network architecture. We provide a full list of neural networks architecture of the underlying RL models in Table III and summarize our choices for common key hyperparameters in Table IV.

Appendix E Calculation of model complexity

Computational Complexity of Context Division. At each time step, SKM only needs to calculate the distances between the current state and the K context centroids. Given a d-dimensional state space and T-step environment interactions, the time complexity of context division is O(dKT). At the same time, since only the K context centroids need to be stored for clustering, the space complexity is O(dK).

Calculation of Floating Point Operations. We obtain the number of operations per forward pass for all layers in the encoder (denoted by F_enc) and the number of operations per forward pass for all MLP layers in each output head (denoted by F_head). Therefore, the number of FLOPs of CDaKD-RE is:

FLOPs = U·B·F + 2·U·B·F + T·F + T·F_enc,  with F = F_enc + K·F_head,

where B is the batch size, T is the number of environment steps, and U is the number of training updates. The first two terms account for the forward and backward passes required in training updates, respectively (counting a backward pass as roughly twice a forward pass); the latter two terms account for the forward passes required to compute the policy action and to obtain the low-dimensional representation from the random encoder, respectively. The concrete per-pass costs differ between DQN and Rainbow due to their different network architectures.
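The accounting described above can be sketched as a small helper; it assumes the common convention that a backward pass costs roughly twice a forward pass, and all names are ours for illustration:

```python
def cdakd_re_flops(B, T, U, F_enc, F_head, K):
    """Rough FLOPs estimate for CDaKD-RE.

    B: batch size; T: environment steps; U: training updates;
    F_enc / F_head: per-forward-pass FLOPs of the encoder and of one
    output head; K: number of contexts (output heads).
    """
    F_fwd = F_enc + K * F_head       # one full forward pass (all heads)
    train_fwd = U * B * F_fwd        # forward passes in training updates
    train_bwd = 2 * U * B * F_fwd    # backward passes (~2x a forward pass)
    act = T * F_fwd                  # forwards to compute the policy action
    rand = T * F_enc                 # forwards through the random encoder
    return train_fwd + train_bwd + act + rand
```

Because the random encoder receives no gradient updates, its T forward passes are the only cost it adds, which is why the overhead of CDaKD-RE over the underlying agent stays moderate.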