Overcoming Catastrophic Forgetting by Generative Regularization

12/03/2019 ∙ by Patrick H. Chen, et al. ∙ 0

In this paper, we propose a new method to overcome catastrophic forgetting by adding generative regularization to the Bayesian inference framework. We can construct a generative regularization term for any given model by leveraging energy-based models and Langevin dynamics sampling. By combining the discriminative and generative losses, we show that this intuitively provides a better posterior formulation in Bayesian inference. Experimental results show that the proposed method outperforms state-of-the-art methods on a variety of tasks, avoiding catastrophic forgetting in continual learning. In particular, the proposed method outperforms previous methods by over 10% on the Fashion-MNIST dataset.




1 Introduction

Many real-world machine learning applications require systems to learn a sequence of tasks incrementally. For example, a recommendation system should establish a general recommendation for all users, and customize the recommendation if certain users behave differently later. In particular, it is not uncommon that previous customer data can no longer be accessed because of increasingly strict data regulations. Critically, the sequence of tasks may not be explicitly labelled, tasks may change over time, any individual task may not recur for long time intervals, and entirely new tasks can emerge (Sutton et al., 2014; Kirkpatrick et al., 2017). Thus, an intelligent agent must adapt to perform well on the entire set of tasks in an incremental way that avoids revisiting all previous data at each stage. In practice, however, previous research has found that deep neural networks tend to lose the information of previous tasks (e.g. task A) when a new task (e.g. task B) is introduced. This phenomenon is called catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990).

To achieve continual learning, a common strategy is to fix certain parameters in the model. When a new task arrives, each method decides, based on certain criteria, whether to reuse certain fixed parameters (Fernando et al., 2017), whether to expand or duplicate some parts of the model (Rusu et al., 2016; Yoon et al., 2018), or whether to search for the best new model architecture (Li et al., 2019). These methods work well in practice on several large-scale datasets, but at the cost that the model size grows with the number of new tasks. More importantly, they only demonstrate the capability to achieve continual learning, and do not discuss why and how catastrophic forgetting happens. Instead of adapting the model structure, another line of research fixes the model structure and studies the catastrophic forgetting phenomenon by viewing incremental training as a moving path in parameter space. By constraining the search directions and updates, these methods can partially approximate the path to the optimal parameters (Kirkpatrick et al., 2017; Nguyen et al., 2018; Zenke et al., 2017; Smola et al., 2003).

Among these methods, Variational Continual Learning (VCL) (Nguyen et al., 2018) views the problem from the canonical Bayesian perspective and proposes to use variational methods to approximate the posterior. In practice, VCL achieves good performance on various benchmarks. However, one deficit is that VCL formulates the posterior distribution by assuming every parameter to be independent, which does not hold in general. For instance, in neural networks, the convolutional layers could be learning representative features while the fully connected layers could be learning classification hyperplanes. Clearly, if the parameters of the convolutional layers change, the feature representation will change, and the fully connected layers should adjust their parameters to the new representation in order to achieve low training error. Therefore, to obtain an accurate posterior estimate, we need to consider the dependency between parameters. Estimating this dependency is difficult, as it still requires the accurate raw data distribution of all previous tasks, so it is unlikely that the dependency can be formulated explicitly. However, this motivates us to think about the importance of data generative capability in the training process. Although we cannot generate the data to model the dependency between parameters directly, we might implicitly embed this information by leveraging the data generative process as a regularization that constrains the model updates.

Based on this observation, in this work we propose to use generative modeling as a regularization in Bayesian variational learning to overcome catastrophic forgetting. For any given discriminative network structure, we can construct a generative process by formulating a corresponding energy-based model. In each step of variational estimation, we use both the discriminative loss and the generative loss as the training objective. Our contributions are summarized as follows. Firstly, we analyze the Bayesian approach in the continual learning setup and point out the deficiency of the parameter independence assumption. Secondly, we propose to use an energy-based model with Langevin dynamics sampling as an implicit dependency regularization when training the discriminative task. Empirically, we show that the proposed variational learning with generative regularization works well on all benchmark datasets.

2 Related Work

Continual learning by Regularization There is a rich body of methods that directly address the catastrophic forgetting problem. EWC (Kirkpatrick et al., 2017) minimizes the change of weights that are important to previous tasks, using an estimate of the diagonal empirical Fisher information matrix. SI (Zenke et al., 2017) proposes to alleviate catastrophic forgetting by allowing individual synapses to estimate their importance for solving learned tasks, then penalizing changes to the important weights. IMM (Lee et al., 2017) trains individual models on each task and then carries out a second stage of training to combine them. VCL (Nguyen et al., 2018) takes a Bayesian point of view to model the sequential learning procedure. This paper falls in this line of research, and we mostly compare against state-of-the-art methods from it.

Continual learning by Model Adaptation Another class of methods also targets the continual learning regime, but allows the model to expand. Moreover, this class of methods keeps the parameters used in the seen tasks unchanged to achieve continual learning, which is different from addressing the catastrophic forgetting problem, as it does not answer why forgetting happens. PathNet (Fernando et al., 2017) selects paths between predefined modules, and tuning is allowed only when an unused module is selected. Dynamically expandable networks (DEN) (Yoon et al., 2018) select whether to expand or duplicate layers based on certain criteria for an incoming new task. Progressive Networks (Rusu et al., 2016) adopt a similar expansion strategy. Following this line of research, (Li et al., 2019) proposed to solve continual learning by explicitly taking into account continual structure optimization via differentiable neural architecture search.

Generative Models Previous works also try to alleviate catastrophic forgetting by introducing memory systems which store previous data and replay the stored old examples together with the new data (Robins, 1995; Rebuffi et al., 2017; Li et al., 2018; Lopez-Paz and Ranzato, 2017). Recently, as Generative Adversarial Networks have become more popular, (Shin et al., 2017) proposed to learn a generative model that captures the data distribution of previous tasks, and to use the generated data to alleviate forgetting of previous tasks. However, this work does not embed both the discriminative and the generative loss in one model. Our work is also related to energy-based models (EBMs). We refer readers to (LeCun et al., 2006) for a more comprehensive review. The primary difficulty in training EBMs comes from the estimation of the partition function. Our work follows the derivation in (Dai et al., 2019).

3 Methods

An arbitrary classification model $M$, with parameters denoted as $\theta$, consists of parameters $\theta_S$ shared across all tasks and parameters $\theta_H^{1:T}$ dedicated to the specific tasks $1, \dots, T$. The sequential tasks are defined as $D_1, \dots, D_T$, where each $D_t$ defines a classification task. At each timestamp, only one dataset $D_t$ can be obtained, and the previous datasets cannot be completely accessed. The machine learning task is to achieve good classification accuracy on all tasks after sequential learning of all $T$ tasks.

In the literature, such a setup is called a multi-head model, as the bottom part of the model is shared across all tasks and the top of the model uses task-specific parameters. The top layer could also be a shared structure (i.e. $\theta_H^1 = \dots = \theta_H^T$) such that all tasks use the same set of parameters. This is called the single-head setup.

3.1 Revisit Bayesian Inference

Following the canonical Bayesian setting, we assume some prior knowledge of the model parameters, $p(\theta)$. According to Bayes’ rule, the posterior distribution after observing $T$ datasets can be written as:

$$p(\theta \mid D_{1:T}) \;\propto\; p(\theta) \prod_{t=1}^{T} p(D_t \mid \theta) \;\propto\; p(\theta \mid D_{1:T-1})\, p(D_T \mid \theta).$$

Therefore, if the posterior approximation works well, the Bayesian approach handles the online learning setup naturally. However, we need to point out one important deficit when applying the Bayesian approach to overcome catastrophic forgetting. In the multi-head model setup at step $t$, as shown in Figure 1, the posterior would be further decomposed into:

$$p(\theta \mid D_{1:t}) = p(\theta_S \mid D_{1:t})\, p(\theta_H \mid D_{1:t}),$$

where the shared model parameters $\theta_S$ are assumed to be independent of the individual head network parameters $\theta_H$. However, assuming independence between $\theta_S$ and $\theta_H$ is not true in general. A correct posterior should instead let the head parameters depend on the intermediate feature produced by applying the shared model. To estimate this conditional term, the quantities it is conditioned on are required, so estimating the relation precisely is not possible in the continual learning setup, since we would inevitably need all previous datasets to construct such a distribution. On the other hand, this shows the importance of data information in establishing the dependency between variables. Thus, instead of explicitly formulating this function, we propose to add a generative term to the training, which helps to regularize the estimation of the desired posterior distribution in Bayesian inference. To do so, we propose to leverage Energy-based Models (EBMs), which can be constructed on top of any classification model.
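The recursive form of the posterior can be checked on a toy conjugate model. The sketch below is only an illustration under stated assumptions: the Gaussian likelihood with known noise variance, the prior values, and the helper name `gaussian_mean_posterior` are all hypothetical and not part of the paper. It shows that updating the posterior one dataset at a time gives the same result as conditioning on all datasets jointly.

```python
import numpy as np

def gaussian_mean_posterior(prior_mean, prior_var, data, noise_var=1.0):
    """Posterior N(mean, var) over an unknown Gaussian mean with known noise variance."""
    n = len(data)
    post_var = 1.0 / (1.0 / prior_var + n / noise_var)
    post_mean = post_var * (prior_mean / prior_var + np.sum(data) / noise_var)
    return post_mean, post_var

rng = np.random.default_rng(0)
# Three hypothetical "tasks", each a dataset D_t drawn from the same Gaussian.
datasets = [rng.normal(loc=2.0, scale=1.0, size=50) for _ in range(3)]

# Sequential update: the posterior after D_{1:t-1} acts as the prior for D_t.
seq_mean, seq_var = 0.0, 10.0
for d in datasets:
    seq_mean, seq_var = gaussian_mean_posterior(seq_mean, seq_var, d)

# Batch update: condition on all datasets at once.
batch_mean, batch_var = gaussian_mean_posterior(0.0, 10.0, np.concatenate(datasets))

print(seq_mean, batch_mean)
print(seq_var, batch_var)
```

In this conjugate case the recursion is exact; for neural networks the posterior must be approximated, which is where the independence assumption discussed above enters.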

Figure 1: Illustration of the multi-head models.

3.2 Energy Based Model

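For any given discriminative model (e.g. a deep neural network for classification tasks) parameterized by $\theta$, we can define an energy model as the following probability distribution:

$$p_\theta(x, y) = \exp\big(f_\theta(x, y) - A(\theta)\big),$$

where $A(\theta) = \log \int \sum_{y} \exp\big(f_\theta(x, y)\big)\, dx$. In this work, $f_\theta$ is a neural network parameterized by $\theta$. We can train the energy model with the MLE objective:

$$\max_\theta \; \mathbb{E}_{(x, y) \sim D}\big[f_\theta(x, y)\big] - A(\theta).$$

However, directly solving the MLE of a general exponential family is intractable. To alleviate the computation, Contrastive Divergence (CD) was proposed in (Hinton, 2002). CD estimates the gradient of our energy model as:

$$\mathbb{E}_{(x, y) \sim D}\big[\nabla_\theta f_\theta(x, y)\big] - \mathbb{E}_{(\tilde{x}, \tilde{y}) \sim p_\theta}\big[\nabla_\theta f_\theta(\tilde{x}, \tilde{y})\big], \tag{5}$$

where $p_\theta$ denotes the underlying energy-based model. The second term is computed by first sampling a batch of data $(\tilde{x}, \tilde{y})$ with the Langevin dynamics sampling shown in Algorithm 1, and then evaluating $\nabla_\theta f_\theta(\tilde{x}, \tilde{y})$ on the samples to obtain a stochastic estimate of the expectation.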
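As a rough illustration of how a contrastive-divergence gradient can be estimated with Langevin samples, the sketch below uses a one-dimensional toy energy function $f_\theta(x) = -(x - \theta)^2 / 2$ without labels, so that the model distribution is a unit-variance Gaussian centred at $\theta$. The toy energy, the step size, the number of steps and all function names are assumptions made for the example, not the paper's setup.

```python
import numpy as np

def grad_x_f(x, theta):
    """Gradient w.r.t. x of the toy function f_theta(x) = -(x - theta)^2 / 2."""
    return -(x - theta)

def grad_theta_f(x, theta):
    """Gradient w.r.t. theta of the same toy function."""
    return x - theta

def langevin_sample(x0, theta, n_steps=200, step=0.1, rng=None):
    """Unadjusted Langevin dynamics: x <- x + (step/2) * grad_x f + sqrt(step) * noise."""
    rng = np.random.default_rng() if rng is None else rng
    x = x0.copy()
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + 0.5 * step * grad_x_f(x, theta) + np.sqrt(step) * noise
    return x

rng = np.random.default_rng(0)
theta = 0.0
data = rng.normal(loc=2.0, scale=1.0, size=2000)   # "true" data, centred at 2
buffer = rng.uniform(-1.0, 1.0, size=2000)         # initial points for the chains
samples = langevin_sample(buffer, theta, rng=rng)  # approximate samples from p_theta

# CD-style gradient: data expectation minus model expectation of grad_theta f.
cd_grad = grad_theta_f(data, theta).mean() - grad_theta_f(samples, theta).mean()
print(cd_grad)
```

The estimated gradient is positive here, so a gradient ascent step would move $\theta$ from 0 towards the data mean, which is the expected behaviour of maximum likelihood training.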

3.3 Bayesian Inference as Learning with Generative Regularization

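With the formulation of the generative loss, we can train a model with both a discriminative loss and a generative loss. The core training objective in the variational method is to approximate the posterior by minimizing a KL-divergence:

$$q_t(\theta) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\Big(q(\theta) \,\Big\|\, \tfrac{1}{Z_t}\, q_{t-1}(\theta)\, p(D_t \mid \theta)\Big), \tag{6}$$

where $\mathcal{Q}$ is the functional space of posterior distributions and $Z_t$ is a normalizing constant. For simplicity, we follow the literature and assume $q$ to be a mean-field distribution parameterized by $\phi$. By the ELBO derivation, training with (6) is equivalent to maximizing the following objective:

$$\mathbb{E}_{\theta \sim q_\phi}\big[\log p(D_t \mid \theta)\big] - \mathrm{KL}\big(q_\phi(\theta) \,\|\, q_{t-1}(\theta)\big),$$

where a model parameter $\theta$ is generated by sampling from $q_\phi$. The likelihood term of this objective can then be rewritten, giving eq. (7), in terms of three log-likelihoods: $\log p_\theta(y \mid x)$ can be understood as the ordinary discriminative loss, while both $\log p_\theta(x, y)$ and $\log p_\theta(x)$ can be understood as generative regularizations that match the empirical joint distribution and marginal distribution simultaneously. Contrastive Divergence provides an estimate of the gradient of $\log p_\theta(x, y)$, and we give a derivation of an unbiased gradient estimator of $\log p_\theta(x)$ in the following theorem.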

Theorem 1

Given a discriminative model , the unbiased gradient estimator of corresponding Energy-based model is given by

Proof. We postpone the derivation to the appendix.
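For a mean-field posterior, the KL regularization between the current and previous posteriors has a closed form when both are assumed to be factorized Gaussians, and parameters can be drawn by reparameterization. The sketch below illustrates this under that Gaussian assumption; the function names and the toy values are hypothetical and do not come from the paper.

```python
import numpy as np

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) between factorized Gaussians, summed over all parameters."""
    var_q, var_p = sigma_q ** 2, sigma_p ** 2
    per_param = np.log(sigma_p / sigma_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return float(np.sum(per_param))

def sample_parameters(mu, sigma, rng):
    """Reparameterized draw theta = mu + sigma * eps with eps ~ N(0, I)."""
    return mu + sigma * rng.standard_normal(mu.shape)

rng = np.random.default_rng(0)
# Previous posterior q_{t-1} and a candidate current posterior q_phi (toy values).
mu_prev, sigma_prev = np.zeros(3), np.ones(3)
mu_new, sigma_new = np.ones(3), np.ones(3)

theta = sample_parameters(mu_new, sigma_new, rng)        # one model sampled from q_phi
kl = kl_diag_gaussians(mu_new, sigma_new, mu_prev, sigma_prev)
print(theta.shape, kl)
```

With unit variances the KL reduces to half the squared distance between the means per parameter, so shifting three parameters by 1 gives 1.5.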

Therefore, we can obtain the derivative of eq. (7) by using eq. (5) and Theorem 1. In the resulting expression, $(\tilde{x}, \tilde{y})$ is data generated from our model and $x$ is a training instance sampled from the true data distribution, with $y$ sampled from a mixture of the conditional distribution and the training set. Again, to generate samples from the current model, we exploit hybrid Monte Carlo (Neal et al.), specifically the Langevin dynamics sampler listed in Algorithm 1. The first term corresponds to the common discriminative loss used in training deep neural networks. The second term is the regularization introduced by the Bayesian setup. These two terms correspond to the gradient of the forward neural network computation, and thus can be obtained by back-propagation through the underlying model $M$. The remaining two terms correspond to the generative regularization. The generative regularization enables the model to implicitly keep the information of the training data, so that the forgetting problem can be alleviated. The overall proposed method is summarized in Algorithm 2.

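  Input: Buffer $B$ of initial samples $(x_0, y_0)$
  Output: Sampled data $(\tilde{x}, \tilde{y})$
  for $k = 0, \dots, K-1$ do
      $x_{k+1} = x_k + \frac{\epsilon}{2} \nabla_x f_\theta(x_k, y_k) + \sqrt{\epsilon}\, \xi_k$, $\xi_k \sim \mathcal{N}(0, I)$, $y_{k+1} \sim p_\theta(y \mid x_{k+1})$
  end for
  Return $(\tilde{x}, \tilde{y}) = (x_K, y_K)$.
Algorithm 1 Gibbs-Langevin Dynamic Sampling

  Input: Dataset of task t $D_t$, posterior distribution of previous tasks $q_{t-1}(\theta)$, number of training epochs $E$ and learning rate $\eta$
  Output: Posterior distribution of learned model $q_t(\theta)$
  for $e = 1, \dots, E$ do
     Generate sample $(\tilde{x}, \tilde{y})$ by Algorithm 1
     Calculate gradient $g$ of the training loss as derived in Section 3.3.
     $\phi = \phi - \eta\, g$
  end for
Algorithm 2 Algorithm of Proposed Method at task t.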
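A Gibbs-Langevin sampler of the kind named in Algorithm 1 can be sketched as alternating a Langevin step on $x$ given the current label with a resampling of the label from the conditional $p(y \mid x)$. The example below is a minimal sketch, not the authors' implementation: it assumes a toy joint energy function $f(x, y) = -\lVert x - \mu_y \rVert^2 / 2$ with hypothetical class means, so that the gradient is available in closed form, and the step size, number of steps and buffer initialization are likewise assumptions.

```python
import numpy as np

def logits(x, means):
    """Toy f(x, y) = -||x - mu_y||^2 / 2 for every class y; result has shape (n, num_classes)."""
    diff = x[:, None, :] - means[None, :, :]
    return -0.5 * np.sum(diff ** 2, axis=2)

def gibbs_langevin(x, y, means, n_steps=200, step=0.05, rng=None):
    """Alternate a Langevin update of x given y with sampling y ~ p(y | x)."""
    rng = np.random.default_rng() if rng is None else rng
    num_classes = means.shape[0]
    for _ in range(n_steps):
        grad_x = -(x - means[y])                      # gradient of f(x, y) w.r.t. x
        x = x + 0.5 * step * grad_x + np.sqrt(step) * rng.standard_normal(x.shape)
        f = logits(x, means)
        p = np.exp(f - f.max(axis=1, keepdims=True))  # softmax over classes
        p /= p.sum(axis=1, keepdims=True)
        u = rng.random((len(x), 1))
        y = (u > np.cumsum(p, axis=1)).sum(axis=1)    # inverse-CDF draw of y per row
        y = np.minimum(y, num_classes - 1)            # guard against rounding in cumsum
    return x, y

rng = np.random.default_rng(0)
means = np.array([[-3.0, 0.0], [3.0, 0.0]])           # hypothetical class means
buffer_x = rng.standard_normal((256, 2))              # buffer of initial inputs
buffer_y = rng.integers(0, 2, size=256)               # buffer of initial labels
x_tilde, y_tilde = gibbs_langevin(buffer_x, buffer_y, means, rng=rng)
print(x_tilde.shape, np.bincount(y_tilde, minlength=2))
```

In a real model the gradient with respect to $x$ would come from back-propagation through the network rather than from a closed form, and the generated pairs would feed the generative terms of the training gradient.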

4 Experiments

4.1 Datasets

We evaluated the proposed method on the following three tasks.

Permuted-MNIST Permuted-MNIST is a very popular benchmark dataset in continual learning literature. The dataset received at each time step consists of labeled MNIST images whose pixels have undergone a fixed random permutation.
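The construction of permuted tasks can be sketched as follows. This is an illustrative sketch only: random arrays stand in for the MNIST images, and the function name, the number of tasks and the seed are assumptions. Each task applies one fixed random pixel permutation to every image.

```python
import numpy as np

def make_permuted_tasks(images, num_tasks, seed=0):
    """Return one permuted copy of the flattened images per task, plus the permutations."""
    rng = np.random.default_rng(seed)
    flat = images.reshape(len(images), -1)
    perms = [rng.permutation(flat.shape[1]) for _ in range(num_tasks)]
    tasks = [flat[:, perm] for perm in perms]  # the same permutation for every image in a task
    return tasks, perms

# Random arrays stand in for 28 x 28 MNIST images in this sketch.
images = np.random.default_rng(1).random((8, 28, 28))
tasks, perms = make_permuted_tasks(images, num_tasks=3)
print(len(tasks), tasks[0].shape)
```

The labels are unchanged across tasks; only the pixel order differs, so every task has the same difficulty for a fully connected network.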

Split-MNIST This experiment was used by (Zenke et al., 2017). Five binary classification tasks from the MNIST dataset arrive in sequence: 0/1, 2/3, 4/5, 6/7, and 8/9.
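A split benchmark of this kind can be built by filtering the dataset by class pair and relabelling each pair as a binary problem. The sketch below assumes random stand-in images and a hypothetical helper `make_split_tasks`; mapping the first class of each pair to 0 and the second to 1 is also an assumption.

```python
import numpy as np

def make_split_tasks(images, labels, pairs):
    """Build one binary task per class pair (a, b): class a -> 0, class b -> 1."""
    tasks = []
    for a, b in pairs:
        mask = (labels == a) | (labels == b)
        binary_labels = (labels[mask] == b).astype(np.int64)
        tasks.append((images[mask], binary_labels))
    return tasks

rng = np.random.default_rng(0)
images = rng.random((100, 28, 28))   # stand-in for MNIST images
labels = np.arange(100) % 10         # ten examples of each digit
pairs = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
tasks = make_split_tasks(images, labels, pairs)
print([x.shape[0] for x, _ in tasks])
```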

Method     Permuted  Split  Fashion-ME  Fashion-CE  Fashion-MH  Fashion-CH
Original   37        90     86.8        74.6        70.8        68.0
All-data   99.3      99.1   99.3        99.3        94.0        94.0
EWC        87.5      97.4   95.9        82.2        81.7        71.0
VCL        91.7      96.7   89.6        76.9        80.8        73.8
Ours       92.7      98.5   98.2        97.3        89.1        86.4
Table 1: Comparisons on continual learning tasks. Results shown in the table are the average classification accuracy (in %) over the tasks.
Figure 2: Detailed Continual Classification Results of Permuted-MNIST.
Figure 3: Detailed Continual Classification Results of Split-MNIST.

Split-Fashion-MNIST Fashion-MNIST (Xiao et al., 2017) is similar to the MNIST dataset, consisting of a training set of 60000 examples and a test set of 10000 examples. Each example is a 28 x 28 grayscale image, associated with a label from 10 classes. This dataset represents more realistic features of real-world images and has thus become an increasingly popular benchmark. For this task, we create two different splits. The first split partitions the classes by class label sequence, as in Split-MNIST. For the second split, we deliberately place similar objects in the same pair, resulting in the arriving sequence 0/6, 7/9, 2/4, 3/8, 1/5, which translates to T-shirt/Shirt, Sneaker/Ankle Boot, Pullover/Coat, Dress/Bag and Trouser/Sandal. We can see that the first 3 pairs compare similar objects. This split prevents the model from relying on a small number of powerful discriminative features.
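The two class sequences can be written down explicitly and mapped to class names. The sketch below uses the standard Fashion-MNIST label names; the variable and function names are hypothetical.

```python
# Standard Fashion-MNIST label names, indexed by class id.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Easy split: classes paired in label order, as in Split-MNIST.
easy_split = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
# Hard split: the arriving sequence given in the text.
hard_split = [(0, 6), (7, 9), (2, 4), (3, 8), (1, 5)]

def describe(split):
    """Translate class-id pairs into readable class-name pairs."""
    return [f"{FASHION_MNIST_CLASSES[a]} vs {FASHION_MNIST_CLASSES[b]}" for a, b in split]

for line in describe(hard_split):
    print(line)
```

Both splits cover all ten classes exactly once, so they differ only in which classes must be told apart within a task.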

(a) Split-Fashion-MNIST with Easy Setup
(b) Split-Fashion-MNIST with Hard Setup
Figure 4: Detailed Continual Classification Results of Split-Fashion-MNIST on MLP.
(a) Split-Fashion-MNIST with Easy Setup
(b) Split-Fashion-MNIST with Hard Setup
Figure 5: Detailed Continual Classification Results of Split-Fashion-MNIST on CNN.

4.2 Baseline Methods and Implementation Details

We compare our method to the following baseline methods:

  • Original: simply trains each task in an incremental setup without any regularization. It serves as the lower bound for all compared methods.

  • All-data: trains the tasks jointly assuming all datasets are available. At each step, a random dataset is sampled and then a batch of data is sampled from the dataset. It serves as the upper bound of all continual learning methods as no forgetting will happen under this scheme.

  • EWC (Kirkpatrick et al., 2017): builds the importance estimation on top of diagonal Laplace propagation by calculating the diagonal of empirical Fisher information.

  • VCL (Nguyen et al., 2018): conducts variational inference from Bayesian point of view of continual learning.

For EWC and VCL, we follow the released open source implementations. For each dataset/task, we compare these methods under the same network architecture. For Permuted-MNIST and Split-MNIST, we use a Multi-layer Perceptron model (MLP) with 2 hidden layers, each of dimension 256. ReLU is used as the activation function. For Permuted-MNIST we use a single-head model, and for Split-MNIST we use a multi-head model. We denote these 2 experimental setups Permuted and Split. For the Fashion-MNIST dataset, we evaluate the results on two models. In addition to the MLP introduced above, we also validate on a Convolutional Neural Network (CNN) with 4 convolutional layers (32,1), (64,32), (64,64), (64,64) followed by one fully connected layer. In combination with the 2 sequence splits, we denote by Fashion-ME and Fashion-MH the MLP model with the easy and hard split respectively, and by Fashion-CE and Fashion-CH the CNN model with the easy and hard split. For the Fashion-MNIST experiments, all models are used in the multi-head way.
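The multi-head MLP described above (two shared hidden layers of width 256 with ReLU, plus one output head per task) can be sketched in plain NumPy as a forward pass. The class name, the weight initialization, and the choice of five two-class heads are assumptions made for illustration; training code is omitted.

```python
import numpy as np

class MultiHeadMLP:
    """Shared two-layer ReLU trunk with one linear output head per task (forward pass only)."""

    def __init__(self, in_dim=784, hidden=256, num_tasks=5, classes_per_task=2, seed=0):
        rng = np.random.default_rng(seed)

        def init(fan_in, fan_out):
            # He-style initialization; an assumption, not specified in the text.
            weights = rng.normal(0.0, np.sqrt(2.0 / fan_in), (fan_in, fan_out))
            return weights, np.zeros(fan_out)

        self.shared = [init(in_dim, hidden), init(hidden, hidden)]
        self.heads = [init(hidden, classes_per_task) for _ in range(num_tasks)]

    def forward(self, x, task):
        h = x
        for weights, bias in self.shared:
            h = np.maximum(h @ weights + bias, 0.0)  # ReLU activation
        weights, bias = self.heads[task]             # task-specific head
        return h @ weights + bias

model = MultiHeadMLP()
batch = np.random.default_rng(1).random((4, 784))
logits_task0 = model.forward(batch, task=0)
logits_task1 = model.forward(batch, task=1)
print(logits_task0.shape)
```

A single-head variant would use one head with ten outputs for every task, which corresponds to calling `MultiHeadMLP(num_tasks=1, classes_per_task=10)` and always passing `task=0`.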

All the models are trained with an ADAM optimizer with initial learning rate 1e-3. For this work, is set to 1 for all experiments.

4.3 Results

The evaluation metric used is the average classification accuracy over all tasks. Results are summarized in Table 1. Detailed classification results after observing each task are shown in Figure 3 for Split-MNIST and Figure 2 for Permuted-MNIST. For Fashion-MNIST-MLP and Fashion-MNIST-CNN, each difficulty setup is shown in Figure 4 and Figure 5 respectively. Firstly, we notice that training in the "All-data" setup retains high accuracy on almost all datasets. Therefore, we know that each of the underlying tasks is not difficult. The only exception is the Fashion-MNIST hard split, where we can observe that the proposed split is itself a more difficult classification problem.

We can see that our proposed method outperforms baselines in all tasks. In particular, we can see that the improvement is significant in Fashion-MNIST dataset which contains more real-world objects. We can also observe that when similar objects are paired together, it makes forgetting more prominent.
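The size of the improvement can be read off Table 1 by subtracting the best continual-learning baseline (EWC or VCL) from the proposed method in each column. The short sketch below does this arithmetic on the table values; the variable names are hypothetical.

```python
columns = ["Permuted", "Split", "Fashion-ME", "Fashion-CE", "Fashion-MH", "Fashion-CH"]
# Average accuracies (in %) copied from Table 1.
table = {
    "EWC":  [87.5, 97.4, 95.9, 82.2, 81.7, 71.0],
    "VCL":  [91.7, 96.7, 89.6, 76.9, 80.8, 73.8],
    "Ours": [92.7, 98.5, 98.2, 97.3, 89.1, 86.4],
}

# Margin of the proposed method over the better of the two baselines, per column.
margins = {
    col: round(table["Ours"][i] - max(table["EWC"][i], table["VCL"][i]), 1)
    for i, col in enumerate(columns)
}

for col, margin in margins.items():
    print(f"{col}: +{margin}")
```

The margins are about 1 point on the two MNIST benchmarks and 2.3, 15.1, 7.4 and 12.6 points on Fashion-ME, Fashion-CE, Fashion-MH and Fashion-CH, so the gain exceeds 10 points in the two CNN settings.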

5 Conclusion

In this paper, we propose to use a generative loss as a regularization when training a Bayesian inference task. By applying an energy-based model and a hybrid Monte Carlo sampling process, we can evaluate the generative capability of the underlying model. Experimental results show that when the generative loss is combined with the Bayesian inference framework, it alleviates catastrophic forgetting significantly without modifying the underlying model architecture. On the Fashion-MNIST dataset, the proposed method outperforms state-of-the-art methods in overall classification accuracy.


References

  • B. Dai, Z. Liu, H. Dai, N. He, A. Gretton, L. Song, and D. Schuurmans (2019) Exponential family estimation via adversarial dynamics embedding. arXiv preprint arXiv:1904.12083.
  • C. Fernando, D. Banarse, C. Blundell, Y. Zwols, D. Ha, A. A. Rusu, A. Pritzel, and D. Wierstra (2017) PathNet: evolution channels gradient descent in super neural networks. CoRR abs/1701.08734.
  • G. E. Hinton (2002) Training products of experts by minimizing contrastive divergence. Neural Computation 14 (8), pp. 1771–1800.
  • J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al. (2017) Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences 114 (13), pp. 3521–3526.
  • Y. LeCun, S. Chopra, R. Hadsell, M. Ranzato, and F. Huang (2006) A tutorial on energy-based learning. Predicting Structured Data 1 (0).
  • S. Lee, J. Kim, J. Ha, and B. Zhang (2017) Overcoming catastrophic forgetting by incremental moment matching. CoRR abs/1703.08475.
  • X. Li, Y. Zhou, T. Wu, R. Socher, and C. Xiong (2019) Learn to grow: a continual structure learning framework for overcoming catastrophic forgetting. CoRR abs/1904.00310.
  • Y. Li, Z. Li, L. Ding, Y. Pan, C. Huang, Y. Hu, W. Chen, and X. Gao (2018) SupportNet: solving catastrophic forgetting in class incremental learning with support data. arXiv preprint arXiv:1806.02942.
  • D. Lopez-Paz and M. Ranzato (2017) Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pp. 6467–6476.
  • M. McCloskey and N. J. Cohen (1989) Catastrophic interference in connectionist networks: the sequential learning problem. In Psychology of Learning and Motivation, Vol. 24, pp. 109–165.
  • R. M. Neal et al. MCMC using Hamiltonian dynamics.
  • C. V. Nguyen, Y. Li, T. D. Bui, and R. E. Turner (2018) Variational continual learning. In International Conference on Learning Representations.
  • R. Ratcliff (1990) Connectionist models of recognition memory: constraints imposed by learning and forgetting functions. Psychological Review 97 (2), pp. 285.
  • S. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert (2017) iCaRL: incremental classifier and representation learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2001–2010.
  • A. Robins (1995) Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science 7 (2), pp. 123–146.
  • A. A. Rusu, N. C. Rabinowitz, G. Desjardins, H. Soyer, J. Kirkpatrick, K. Kavukcuoglu, R. Pascanu, and R. Hadsell (2016) Progressive neural networks. CoRR abs/1606.04671.
  • H. Shin, J. K. Lee, J. Kim, and J. Kim (2017) Continual learning with deep generative replay. In Advances in Neural Information Processing Systems, pp. 2990–2999.
  • A. J. Smola, V. Vishwanathan, and E. Eskin (2003) Laplace propagation. In NIPS, pp. 441–448.
  • R. S. Sutton, S. D. Whitehead, et al. (2014) Online learning with random representations. In Proceedings of the Tenth International Conference on Machine Learning, pp. 314–321.
  • H. Xiao, K. Rasul, and R. Vollgraf (2017) Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint cs.LG/1708.07747.
  • J. Yoon, E. Yang, J. Lee, and S. J. Hwang (2018) Lifelong learning with dynamically expandable networks. In International Conference on Learning Representations.
  • F. Zenke, B. Poole, and S. Ganguli (2017) Continual learning through synaptic intelligence. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, ICML’17, pp. 3987–3995.


Appendix A Proof of Theorem 1

Theorem 1

Given a discriminative model , the unbiased gradient estimator of corresponding Energy-based model is given by

Proof Notice that we could derive ELBO of as:


and we know the maximal would be obtained when , which implies that optimal is . Thus, we will have

(12) (13) (14) (15)

Thus, we could obtain by taking derivative of eq (15):

(16) (17) (18) (19) (20) (21) (22) (23) (24)

Notice that the outer expectation could be taken off since after inner expectation, there won’t be any randomness on .
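As a compact sketch of the kind of identity this proof arrives at, assume the joint energy-based model $p_\theta(x, y) = \exp(f_\theta(x, y) - A(\theta))$ with $A(\theta) = \log \int \sum_y \exp(f_\theta(x, y))\, dx$; this notation is an assumption chosen for the sketch, and the steps below are a direct calculation rather than the variational route of the proof above.

```latex
\begin{align*}
\log p_\theta(x)
  &= \log \sum_{y} \exp\big(f_\theta(x, y)\big) - A(\theta), \\
\nabla_\theta \log \sum_{y} \exp\big(f_\theta(x, y)\big)
  &= \sum_{y} \frac{\exp\big(f_\theta(x, y)\big)}{\sum_{y'} \exp\big(f_\theta(x, y')\big)}\,
     \nabla_\theta f_\theta(x, y)
   = \mathbb{E}_{y \sim p_\theta(y \mid x)}\big[\nabla_\theta f_\theta(x, y)\big], \\
\nabla_\theta A(\theta)
  &= \mathbb{E}_{(\tilde{x}, \tilde{y}) \sim p_\theta}\big[\nabla_\theta f_\theta(\tilde{x}, \tilde{y})\big], \\
\nabla_\theta \log p_\theta(x)
  &= \mathbb{E}_{y \sim p_\theta(y \mid x)}\big[\nabla_\theta f_\theta(x, y)\big]
   - \mathbb{E}_{(\tilde{x}, \tilde{y}) \sim p_\theta}\big[\nabla_\theta f_\theta(\tilde{x}, \tilde{y})\big].
\end{align*}
```

Under this assumption, the first expectation is over labels drawn from the model's own conditional $p_\theta(y \mid x)$, which matches the remark that the optimal variational distribution is the conditional, and the second expectation can be estimated with samples from the Langevin sampler.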