Marginal Replay vs Conditional Replay for Continual Learning

by Timothée Lesort, et al.

We present a new replay-based method for continual classification learning that we term "conditional replay", which generates samples and labels together by sampling from a distribution conditioned on the class. We compare conditional replay to another replay-based continual learning paradigm (which we term "marginal replay") that generates samples independently of their class and assigns labels in a separate step. The main improvement in conditional replay is that labels for generated samples need not be inferred, which reduces the margin for error in complex continual classification tasks. We demonstrate the effectiveness of this approach using novel and standard benchmarks constructed from MNIST and FashionMNIST data, and compare it to the regularization-based EWC method.








1 Introduction

This contribution is in the context of incremental and lifelong learning, a subject that is gaining increasing attention (parisi2018continual; gepperth2016incremental). Briefly put, the problem consists of repeatedly re-training a DNN model on new tasks (e.g., visual classes) over long time periods, while avoiding the abrupt degradation of previously learned abilities known as "catastrophic interference" or "catastrophic forgetting" (gepperthICANN; french; gepperth2016incremental). It has long been known that catastrophic forgetting is a problem for connectionist models (french), of which modern DNNs are a specialized instance, but only recently have there been efforts to propose workable solutions to this problem for deep learning models (lee2017overcoming; Kirkpatrick2016; selfless; DBLP:journals/corr/abs-1805-10784). One aspect of the problem is that gradient-based DNN training is greedy: it optimizes all weights in the network to solve the current task only. Previous tasks, which are not represented in the current training data, are naturally disregarded in this process. While approaches such as (Kirkpatrick2016; lee2017overcoming) aim at "protecting" weights that were important for previous tasks, one can approach the problem from the other end and simply include samples from previous tasks in the training process each time a new task is introduced. This is the conditional replay approach we propose in this article, which is similar in spirit to (shin2017continual) but presents important conceptual improvements. The major reason for performing conditional replay, rather than simply using data stored from previous tasks, is that storing several GB of data may not be feasible for, e.g., embodied agents or embedded devices performing object recognition, whereas the "essence" of previous tasks, in the form of DNN weights, usually requires far less space. A downside of this and similar approaches is that the time complexity of adapting to a new task is not constant but depends on the number of preceding tasks that should be "remembered". Conversely, if continual learning is to be performed at constant time complexity, only a fixed number of samples can be generated, and there will be forgetting, although it will not be catastrophic.

1.1 Contribution

The original contributions of this article can be summarized as follows: first of all, we propose conditional replay as a method for continual classification learning, and compare conditional and marginal replay models on a common set of benchmarks. We furthermore propose an improvement of marginal replay as proposed in (shin2017continual) by using GANs. To measure the merit of these proposals, we use two experimental settings that have not been previously considered for benchmarking generative replay: rotations and permutations. Finally, we show the principled advantage that generative replay techniques have with respect to regularization methods like EWC in a "one class per task" setting, which is after all a very common setting in practice and in which discriminatively trained models strongly tend to assign the same class label to every sample regardless of content.

1.2 Related work

The field of incremental learning is growing and has recently been reviewed in, e.g., (parisi2018continual; gepperth2016incremental). In the context of neural networks, principal recent approaches include ensemble methods (ren2017life; fernando2017pathnet), regularization approaches (Kirkpatrick2016; lee2017overcoming; selfless; DBLP:journals/corr/abs-1805-10784; Srivastava2013; Hinton2012), dual-memory systems (kemker2017fearnet; rebuffi2017icarl; gepperth2015bio) and generative replay methods. In the context of "pure" DNN methods, regularization approaches are predominant: whereas (Goodfellow2013) proposed that the popular Dropout regularization can alleviate catastrophic forgetting, the EWC method (Kirkpatrick2016) adds a term to the DNN energy function that protects weights deemed important for the previous sub-task(s). Whether a weight is important or not is determined by approximating and analyzing the Fisher information matrix of the DNN. A somewhat related approach is pursued by the incremental moment matching (IMM, see (lee2017overcoming)) technique, where weights are transferred between DNNs trained on successive sub-tasks by regularization techniques, and the Fisher information matrix is used to "merge" weights for current and past sub-tasks. Other regularization-oriented approaches are proposed in (selfless; Srivastava2013), which enforce sparsity of neural activities via lateral interactions within a layer, and in (DBLP:journals/corr/abs-1805-10784). Concerning recent advances in generative replay improving upon (shin2017continual): (wu2018memory; 2018arXiv181209111L) propose a conditional replay mechanism similar to the one investigated here, but their goal is the sequential learning of data generation rather than of classification tasks.

2 Methods

A basic notion in this article is that of a continual learning task (CLT), denoting a classification problem that is composed of two or more sub-tasks which are presented sequentially to the model in question. Here, the CLTs are constructed from two standard visual classification benchmarks: MNIST and Fashion MNIST, either by dividing available classes into several sub-tasks, or by performing per-sample image processing operations that are identical within, and different between, sub-tasks. All continual learning models are then trained and evaluated in an identical fashion on all CLTs, and performances are compared by a simple visual inspection of classification accuracy plots.

2.1 Benchmarks

MNIST (LeCun1998) is a common benchmark for computer vision systems and classification problems. It consists of grayscale images of handwritten digits (0-9).

Fashion MNIST (Xiao2017) consists of images of clothes and is structured like the standard MNIST dataset. We choose this dataset because it claims to be a "more challenging classification task than the simple MNIST digits data" (Xiao2017) while having the same data dimensions, the same number of classes, and roughly the same number of samples.

2.2 Continual learning tasks (CLTs)

Rotations New sub-tasks are generated by choosing a random rotation angle and then performing a 2D in-plane rotation on all samples of the original benchmark. As both benchmarks we use contain samples of 28x28 pixels, no information loss is introduced by this procedure. We limit the range of rotation angles, because larger rotations could mix MNIST classes like 0 and 9.

Permutations New sub-tasks are generated by defining a random pixel permutation scheme, and then applying it to each data sample of the original benchmark.

Disjoint classes For each benchmark, this CLT has as many sub-tasks as there are classes in the benchmark (10 in this article). Each sub-task contains all the samples of a single class. As the classes are balanced for both benchmarks we use, this does not unduly favor certain classes.
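The three CLT constructions above can be sketched in a few lines of NumPy; this is a minimal illustration (function names and the rotation bound are our own assumptions, not taken from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def rotation_subtasks(x, n_tasks, max_angle=45.0):
    """Each sub-task applies one random in-plane rotation to all images.
    max_angle is a hypothetical bound; the paper limits angles so that
    classes such as 0 and 9 are not confused."""
    from scipy.ndimage import rotate
    angles = rng.uniform(-max_angle, max_angle, size=n_tasks)
    return [rotate(x, a, axes=(1, 2), reshape=False) for a in angles]

def permutation_subtasks(x, n_tasks):
    """Each sub-task applies one fixed random pixel permutation."""
    n_pix = x.shape[1] * x.shape[2]
    flat = x.reshape(len(x), n_pix)
    perms = [rng.permutation(n_pix) for _ in range(n_tasks)]
    return [flat[:, p].reshape(x.shape) for p in perms]

def disjoint_subtasks(x, y):
    """One sub-task per class, containing all samples of that class."""
    return [x[y == c] for c in np.unique(y)]
```

Each function returns a list of sub-task datasets that can then be presented to the model sequentially.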

2.3 Models

Fully-connected network As a reference implementation, we use a fully-connected network (2 hidden layers with 200 neurons each) with ReLU activation functions. No batch normalization or dropout is performed. All other training parameters are described in Appendix B.


EWC We re-implemented the algorithm described in (Kirkpatrick2016), choosing two hidden layers with 200 neurons each.
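The EWC regularizer can be stated compactly; here is a minimal NumPy sketch, assuming diagonal Fisher estimates and the previous sub-task's weights have already been stored (the parameter dictionaries and the value of `lam` are hypothetical):

```python
import numpy as np

def ewc_penalty(params, old_params, fisher, lam=1000.0):
    """EWC regularizer: (lam / 2) * sum_i F_i * (theta_i - theta_i*)^2,
    summed over all weight tensors. `fisher` holds diagonal Fisher
    information estimates computed after the previous sub-task; `lam`
    controls how strongly important weights are protected."""
    return 0.5 * lam * sum(
        float((fisher[k] * (params[k] - old_params[k]) ** 2).sum())
        for k in params
    )
```

During training on a new sub-task, this penalty is added to the classification loss, so weights with high Fisher information are pulled back toward their previous values.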

Marginal replay In the context of classification, the marginal replay method (2018arXiv181209111L; shin2017continual; wu2018memory) works as follows: for each sub-task T_i, there is a dataset D_i, a classifier C_i, a generator G_i, and a memory of past samples composed of a frozen generator G_{i-1} and a frozen classifier C_{i-1}. The latter two allow the generation of labeled artificial samples from previous sub-tasks. By training C_i and G_i on D_i together with these generated samples, the model can learn the new sub-task without forgetting old ones. At the end of the sub-task, C_i and G_i are frozen and replace C_{i-1} and G_{i-1}. In our experiments, we always train in a way that balances samples between the current sub-task and past sub-tasks. We evaluate two different generator models: WGAN_GP as used in (shin2017continual), and the original GAN model (NIPS2014_5423), since it is a competitive baseline (lesort2018training).
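As a concrete illustration, here is a runnable toy sketch of the marginal replay loop. The models are deliberate stand-ins (a per-feature Gaussian "generator" and a nearest-centroid "classifier", both hypothetical simplifications of the GAN/WGAN-GP and DNN classifier used in the paper); the point is the loop structure, in which replayed samples must have their labels inferred by the frozen classifier:

```python
import numpy as np

class ToyGenerator:
    """Stand-in generative model: fits a diagonal Gaussian to the data.
    A real implementation would use a GAN or WGAN-GP as in the paper."""
    def fit(self, x):
        self.mu, self.sd = x.mean(0), x.std(0) + 1e-6
        return self
    def sample(self, n, rng):
        return rng.normal(self.mu, self.sd, size=(n, len(self.mu)))

class ToyClassifier:
    """Stand-in classifier: nearest class centroid."""
    def fit(self, x, y):
        self.classes = np.unique(y)
        self.centroids = np.stack([x[y == c].mean(0) for c in self.classes])
        return self
    def predict(self, x):
        d = ((x[:, None] - self.centroids[None]) ** 2).sum(-1)
        return self.classes[d.argmin(1)]

def train_marginal_replay(subtasks, n_replay=200, seed=0):
    rng = np.random.default_rng(seed)
    C_old = G_old = None
    for i, (x, y) in enumerate(subtasks):
        if i > 0:
            # replay samples from all past sub-tasks (ratio is a
            # hypothetical choice; the paper balances old vs. new data)
            x_gen = G_old.sample(n_replay * i, rng)  # marginal samples
            y_gen = C_old.predict(x_gen)             # labels must be inferred
            x, y = np.vstack([x, x_gen]), np.concatenate([y, y_gen])
        C = ToyClassifier().fit(x, y)
        G = ToyGenerator().fit(x)
        C_old, G_old = C, G                          # freeze for next sub-task
    return C_old
```

The inferred-label step (`C_old.predict`) is exactly where marginal replay can accumulate errors, which motivates the conditional variant below.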

Conditional replay The conditional replay method is derived from marginal replay: instead of saving a classifier and a generator, the algorithm saves only a generator that can generate conditionally (i.e., for a given class). Hence, for each sub-task T_i, there is a dataset D_i, a classifier C_i, and two generators G_i and G_{i-1}. The role of G_{i-1} is to generate data from all previous sub-tasks while training on the new one. Since data is generated conditionally, samples automatically come with a label and thus do not require a frozen classifier. C_i and G_i learn from the generated data and D_i. At the end of sub-task T_i, C_i is able to classify data from the current and previous sub-tasks, and G_i is able to sample from them as well. We use two different popular conditional models: CGAN, described in (mirza2014conditional), and CVAE (NIPS2015_5775).
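A matching toy sketch of conditional replay (again with a hypothetical class-conditional Gaussian standing in for the CGAN/CVAE) highlights the key difference: replayed samples arrive with their labels attached, so no frozen classifier is needed:

```python
import numpy as np

class ToyConditionalGenerator:
    """Stand-in for a CGAN/CVAE (hypothetical toy model): one diagonal
    Gaussian per class, so every generated sample carries its label and
    no frozen classifier is needed to annotate replayed data."""
    def __init__(self):
        self.per_class = {}
    def fit(self, x, y):
        for c in np.unique(y):
            xc = x[y == c]
            self.per_class[int(c)] = (xc.mean(0), xc.std(0) + 1e-6)
        return self
    def sample(self, n_per_class, rng):
        xs, ys = [], []
        for c, (mu, sd) in self.per_class.items():
            xs.append(rng.normal(mu, sd, size=(n_per_class, len(mu))))
            ys.append(np.full(n_per_class, c))
        return np.vstack(xs), np.concatenate(ys)

def train_conditional_replay(subtasks, n_per_class=100, seed=0):
    """Sequentially fit G_i on real data from sub-task T_i plus labeled
    samples replayed from the frozen G_{i-1}; the classifier C_i would be
    trained on the same mixture."""
    rng = np.random.default_rng(seed)
    G_old = None
    for x, y in subtasks:
        if G_old is not None:
            x_gen, y_gen = G_old.sample(n_per_class, rng)  # labels come free
            x, y = np.vstack([x, x_gen]), np.concatenate([y, y_gen])
        G_old = ToyConditionalGenerator().fit(x, y)  # freeze for next task
    return G_old
```

Because the labels are never inferred, there is no label-noise feedback loop across sub-tasks, which is the conceptual advantage the paper argues for.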

3 Results and discussion

As can be observed from Fig. 1, the methods we propose (marginal replay with GAN, conditional replay with CGAN, and conditional replay with CVAE) outperform all others, on all CLTs, by quite some margin. In particular, in our experiments, the clear advantage of these methods over marginal replay with WGAN-GP is the higher stability of the generative models. This is observable not only in Fig. 1, but also when measuring performance on the first sub-task alone over the course of continual learning (see Fig. 8), as well as when computing the Fréchet Inception Distance (FID, see Fig. 9). Fig. 7 additionally confirms what Fig. 1 suggests implicitly: the generator learns incrementally over time, balancing stability and plasticity in a smooth way. For completeness, we verified that the manner of balancing the number of generated and new samples has a strong impact, so care needs to be taken here (see Fig. 2). Conditional replay is less sensitive to this balance than marginal replay, since it can make a clear distinction between data from different classes. Balancing tasks requires generating many samples from past tasks, which gradually increases training time; conditional replay is therefore a good way to learn from less generated data and thus reduce training time.

Particular attention should be given to the performance of EWC: while generally acceptable for rotation and permutation CLTs, it completely fails for the disjoint CLT. This is due to the fact that there is only one class in each sub-task, making EWC try to map all samples to the currently presented class label regardless of input, since no replay is available to include samples from previous sub-tasks.

Figure 1: Test set accuracies during training on different CLTs, shown for all sub-tasks (indicated by dotted lines). Panels: (a) MNIST disjoint, (b) Fashion MNIST disjoint, (c) MNIST permutation, (d) Fashion MNIST permutation, (e) MNIST rotation, (f) Fashion MNIST rotation.

4 Conclusion

We have proposed two novel ways of performing continual learning with replay-based models and empirically demonstrated (on novel benchmarks) their merit w.r.t. the state of the art. Clearly, the focus of future research will lie on understanding and removing the current shortcomings of conditional replay methods, thus providing a replay-based method of greater simplicity, elegance and learning capacity. Furthermore, we will investigate the advantages of conditional replay with less generated data.


Appendix A Task Balance Influence

Figure 2: Comparison of accuracy on the first task when the ratio between the size of old and new tasks is 1 (balanced) or 1/5 (unbalanced; the factor was chosen empirically). Panels: (a) unbalanced MNIST disjoint, (b) unbalanced Fashion disjoint, (c) balanced MNIST disjoint, (d) balanced Fashion disjoint.

Appendix B Hyper-parameters

Method             | Epochs | LR Classifier | LR Generator | beta1 | beta2 | Batch Size
Marginal Replay    | 25     | 0.01          | 2e-4         | 0.5   | 0.999 | 64
Conditional Replay | 25     | 0.01          | 2e-4         | 0.5   | 0.999 | 64
EWC                | 25     | 0.01          | -            | 0.5   | 0.999 | 64
Finetuning         | 25     | 0.01          | -            | 0.5   | 0.999 | 64
Expert             | 50     | 0.01          | -            | 0.5   | 0.999 | 64
Table 1: Hyperparameters for all models on MNIST and Fashion MNIST (all CL settings use the same training hyperparameters, with Adam).

Appendix C Data

C.1 CLT: disjoint

Figure 3: MNIST training data for the disjoint CLT (panels a-j: sub-tasks 0-9).

C.2 CLT: rotations

Figure 4: MNIST training data for rotation sub-tasks (panels a-e: sub-tasks 0-4).
Figure 5: MNIST training data for permutation-type CLTs.

C.3 CLT: permutations

Figure 6: Visualization of training data for the MNIST permutation CLT (panels a-e: sub-tasks 0-4). The number of rows and columns corresponds to the number of sub-tasks, here 5. In each row, from left to right, we display the same selected set of samples from each sub-task in a rectangular area. In each row, the inverse transformation from the corresponding sub-task to the original data is applied, thus making data from a different sub-task interpretable in each row. This figure should be viewed together with Fig. 7.

Appendix D Generated data

Figure 7: Visualization of data generated during training of marginal replay + GAN on the MNIST permutation CLT (panels a-e: sub-tasks 0-4). The number of rows and columns corresponds to the number of sub-tasks, here 5. In each row, from left to right, we display a selected set of samples generated after training on each sub-task in a rectangular area. In each row, the inverse transformation from the corresponding sub-task to the original data is applied, thus making generated data from a different sub-task interpretable in each row. We observe stable retention behavior as the number of sub-tasks increases, while data from the new sub-task is learned successfully as well.

Appendix E Accuracy on first sub-task

Figure 8: Comparison of the accuracy of each approach on the first sub-task; another, very intuitive measure of how much is forgotten during continual learning. Means and standard deviations computed over 8 seeds. Panels: (a) MNIST disjoint, (b) Fashion MNIST disjoint, (c) MNIST permutation, (d) Fashion MNIST permutation, (e) MNIST rotation, (f) Fashion MNIST rotation.

Appendix F FID

Figure 9: Comparison of the Fréchet Inception Distance (FID) between generated samples and test-set samples for all replay-based methods (lower is better). Abrupt growth of the FID is symptomatic of a divergence of the generator and the generation of samples completely alien to the dataset. Panels: (a) MNIST disjoint, (b) Fashion MNIST disjoint, (c) MNIST permutation, (d) Fashion MNIST permutation, (e) MNIST rotation, (f) Fashion MNIST rotation.
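For reference, the FID used here is the Fréchet distance between Gaussians fitted to feature embeddings of real and generated samples: ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)). A minimal sketch, assuming the Inception-style features have already been extracted into plain arrays (the feature arrays are hypothetical inputs):

```python
import numpy as np
from scipy import linalg

def fid(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two feature sets:
    ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 * sqrtm(S_a @ S_b)).
    Each input is an (n_samples, n_features) array of embeddings."""
    mu_a, mu_b = feats_a.mean(0), feats_b.mean(0)
    s_a = np.cov(feats_a, rowvar=False)
    s_b = np.cov(feats_b, rowvar=False)
    covmean = linalg.sqrtm(s_a @ s_b)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical imaginary residue
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(s_a + s_b - 2.0 * covmean))
```

Identical distributions give an FID near zero; the abrupt FID growth mentioned in the caption corresponds to the generated feature distribution drifting far from the real one.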