
Learning Bayesian Sparse Networks with Full Experience Replay for Continual Learning

by   Dong Gong, et al.

Continual Learning (CL) methods aim to enable machine learning models to learn new tasks without catastrophic forgetting of those that have been previously mastered. Existing CL approaches often keep a buffer of previously-seen samples, perform knowledge distillation, or use regularization techniques towards this goal. Despite their performance, they still suffer from interference across tasks, which leads to catastrophic forgetting. To ameliorate this problem, we propose to activate and select only sparse neurons for learning current and past tasks at any stage. More parameter space and model capacity can thus be reserved for future tasks. This minimizes the interference between parameters for different tasks. To do so, we propose a Sparse neural Network for Continual Learning (SNCL), which employs variational Bayesian sparsity priors on the activations of the neurons in all layers. Full Experience Replay (FER) provides effective supervision in learning the sparse activations of the neurons in different layers. A loss-aware reservoir-sampling strategy is developed to maintain the memory buffer. The proposed method is agnostic to the network structures and the task boundaries. Experiments on different datasets show that our approach achieves state-of-the-art performance for mitigating forgetting.



1 Introduction

Humans continually acquire new skills and knowledge while maintaining the ability to perform tasks they learned in childhood. Continual Learning (CL) [ring1998child, rebuffi2017icarl] aims to mimic this process, enabling learning of new tasks sequentially without storing all of the associated data. However, even with the help of CL, deep neural networks often suffer from serious performance degradation on previously learned tasks when learning new ones. This phenomenon is known as catastrophic forgetting [mccloskey1989catastrophic].

Figure 1: Without continual learning, training on a new task interferes with the knowledge learned in mastering previous tasks, thus introducing the risk of catastrophic forgetting. Simple experience replay approaches can help preserve the learned knowledge. However, the regularizations work on the network parameters as a whole, invariably leading to forgetting. The proposed SNCL learns sparse networks at all stages, enabling more intrinsic and common representations with less model capacity. It reserves more capacity/parameter space for future tasks, resulting in less interference and forgetting.

The above issues motivate more effective CL algorithms [chaudhry2020using, chaudhry2019continual, xu2018reinforced, oren2021defense] to mitigate catastrophic forgetting as the number of tasks and the associated knowledge increase over time. For this purpose, existing approaches modify the network learning strategies in different ways, such as regularization methods [kirkpatrick2017overcoming, zenke2017continual, li2017learning], memory-based experience replay [lopez2017gradient, chaudhry2018efficient, chaudhry2019continual, buzzega2020dark], and dynamic modular approaches [rusu2016progressive, yoon2017lifelong, abati2020conditional]. Specifically, Experience Replay (ER) methods [lopez2017gradient, chaudhry2019tiny, buzzega2020dark] mitigate catastrophic forgetting by replaying a selection of previously seen data, maintained in a small episodic memory, along with the new task data. Knowledge Distillation (KD) [hinton2015distilling, gou2021knowledge] based methods [li2017learning, rebuffi2017icarl, buzzega2020dark] alleviate forgetting and encourage knowledge transfer by using old models as teachers. All these approaches learn the sequential tasks within the same whole parameter space. The learner potentially uses all neurons in the over-complete network to decrease the loss on the current data. These characteristics limit their effectiveness in reducing interference and invariably lead to forgetting of past tasks (see Fig. 1).

In this work, we propose to learn Sparse neural Networks for Continual Learning (SNCL). We enforce sparsity of the network neurons in the learning stages to reserve more parameters for future tasks. This prevents learning a new task from interfering with the parameters critical to previously learned tasks (see Fig. 1). Specifically, we introduce variational Bayesian sparsity (VBS) priors on the activations to induce a sparse network in a principled manner. The sparse prior helps the model learn a given task with the most compact set of neurons. This allows far greater control over the process of learning, and forgetting. In particular, it preserves resources for future learning and alleviates potential interference when new tasks are learned. A small replay buffer allows the commonality between new and previous tasks to be exploited, thus increasing the network's total learning capacity and reducing the risk of forgetting.

The variational Bayesian framework flexibly allows learning the sparse networks by relying only on the backpropagation of the supervision signal (e.g., the classification loss). As a result, instead of performing experience replay with only data labels, we propose a more effective experience replay strategy that enforces the stability of the intermediate representations of old samples. We label this approach Full Experience Replay (FER). To achieve this, we store not only the past data in the memory but also the corresponding logits and intermediate-layer features. FER thus helps activate the appropriate neurons in each layer more effectively, which helps avoid representation drift. Note that, in contrast to many similar knowledge distillation-based methods [rebuffi2017icarl, li2017learning], the proposed approach is not limited to the case where the task boundaries are known. We also introduce an effective Loss-aware Reservoir Sampling (LRS) strategy to maintain the memory, which selects samples on the basis of their ability to improve performance. In summary, our contributions are as follows:


  • We propose to learn sparse networks for continual learning with variational Bayesian sparsity priors on the neurons (SNCL). The proposed method alleviates catastrophic forgetting and interference by enforcing sparsity on the networks to reserve parameters for future tasks. The memory-based experience replay allows the commonality between new and previous tasks to be exploited, thus increasing the network's total learning capacity and reducing the risk of forgetting.

  • We propose Full Experience Replay (FER), which stores and replays old samples' intermediate-layer features, generated when they were observed in the data stream, unlike previous methods that store only labels. It provides more effective supervision for learning the sparse activations of the neurons at different layers. A loss-aware reservoir sampling strategy is proposed to maintain the memory.

  • The proposed approaches can generally improve the performance of ER-based methods, achieving state-of-the-art results under the same settings. The proposed method is agnostic to the network structure and the task boundaries.

2 Related Work

2.1 Continual Learning Settings

Early work in CL focused on Task Incremental Learning (Task-IL) [delange2021continual, kirkpatrick2017overcoming, zenke2017continual, shin2017continual], whereby models are trained on a series of distinct tasks with clear boundaries and no overlap. Recent work has focused on a more realistic setting in which tasks may change gradually, named Class Incremental Learning (Class-IL) [rebuffi2017icarl, zeno2018task, rajasegaran2020itaml, hou2019learning]. In Class-IL, models are not provided with task boundaries during testing and must infer the task identity from the data. Although substantial progress has been made, the majority of Task-IL and Class-IL approaches are unsuited to real-world applications, where task boundaries are unclear and overlapping. Recently, General Continual Learning (GCL) [delange2021continual, buzzega2020dark] has been introduced, whereby: (i) the training and testing phases do not rely on boundaries between tasks; (ii) the testing phase does not use task identifiers; (iii) memory size is limited. These challenging guidelines guarantee that GCL can be applied in practical scenarios.

2.2 Continual Learning Approaches

Regularization Based Methods.

Regularization-based methods use regularization terms in the loss function to encourage stability on previous tasks and prevent forgetting. Kirkpatrick et al. proposed Elastic Weight Consolidation (EWC) [kirkpatrick2017overcoming], which evaluates the importance of parameters using the Fisher information matrix. Li et al. [li2017learning] regularized the network using knowledge distillation [hinton2015distilling]. To alleviate forgetting, Tao et al. [tao2020topology] propose maintaining the topology of the network's feature space. Yu et al. [yu2020semantic], in contrast, approximate the drift in previous-task performance based on the drift experienced by current-task data.

Parameter Isolation Based Methods. Parameter isolation based methods are employed only in Task-IL. To alleviate forgetting, related approaches allocate different subnetwork parameters to each task. Rusu et al. [rusu2016progressive] proposed progressive networks, which leverage prior knowledge via lateral connections to previously learned features. Xu et al. [xu2018reinforced] searched for the best neural network architecture for each incoming task by leveraging reinforcement learning strategies. Saha et al. [saha2020structured] enabled a network to learn continually and efficiently by partitioning the learned space into a core space and a residual space. Golkar et al. [golkar2019continual] isolate the parameters for different tasks via neural pruning and regain model capacity, with graceful forgetting, for performance. In [aljundi2018selfless], neural inhibition is used to impose sparsity in neural networks for CL.

Replay Based Methods. To reduce forgetting, replay-based methods store a subset of previously seen data, or important gradient spaces from previous tasks, in a memory buffer, or produce synthetic data using generative models for pseudo-rehearsal. Chaudhry et al. [chaudhry2019continual] proposed Experience Replay (ER), which jointly trains the model on the data from the new tasks and the memory. Buzzega et al. [buzzega2020dark] improved ER with Dark Experience Replay (DER), mixing rehearsal with knowledge distillation and regularization. Gradient Episodic Memory (GEM) [lopez2017gradient] and Averaged GEM (A-GEM) [chaudhry2018efficient] use samples from the memory buffer to estimate gradient constraints so that the loss on previous tasks does not increase.

Figure 2: The framework of the proposed SNCL. It uses Variational Bayesian Sparsity (VBS) and Full Experience Replay (FER) to learn new tasks and alleviate forgetting. The classification loss on the current task drives learning of the new data. The VBS regularizer alleviates catastrophic forgetting and interference by enforcing sparsity on the networks. The full experience replay loss provides more effective supervision for learning the sparse activations of the neurons at different layers. The memory is updated by the proposed loss-aware reservoir sampling strategy.

3 Learning Sparse Networks for Continual Learning

A CL problem can be formulated as learning from an ordered sequence of datasets $\{\mathcal{D}_1, \dots, \mathcal{D}_T\}$, each from one of $T$ tasks, where $\mathcal{D}_t = \{(x_i, y_i)\}$ represents the data examples of the $t$-th task. During each task $t$, the samples in $\mathcal{D}_t$ are drawn from an i.i.d. distribution associated with the task. The goal is to carry out sequential training while maintaining the performance on the previous tasks. The task boundaries can be available or unavailable (e.g., in the GCL setting) during both training and testing stages under different settings. CL aims to learn a function $f_\theta$, with parameters $\theta$, for all the tasks by being optimized on one task at a time in a sequential manner.

3.1 Overview

We focus on the CL classification problem in this paper. Given any sample $x$, we let $h_\theta(x)$ represent the output logits of the network. For classification, the prediction for $x$ is made using $f_\theta(x) = \mathrm{softmax}(h_\theta(x))$. Formally, the objective of CL can be formulated as learning an $f_\theta$, i.e., the parameters $\theta$, that minimize the classification loss on all the tasks:

$$\mathcal{L}_{ce} = \sum_{t=1}^{T} \mathbb{E}_{(x,y)\sim \mathcal{D}_t}\big[\ell_{ce}(f_\theta(x), y)\big], \tag{1}$$

where $\ell_{ce}$ represents the cross-entropy loss. However, this is challenging since the datasets $\mathcal{D}_t$ are observed in a sequential manner. When we train the model on any task $t$, the data samples of all previous tasks are unavailable. The knowledge/representations learned on past tasks are easily flushed by the new task data under distribution shift, leading to catastrophic forgetting. To encourage the parameters to adapt to new tasks while maintaining the knowledge from past tasks, like many CL approaches [chaudhry2019tiny, chaudhry2019continual, lopez2017gradient], we employ a replay buffer $\mathcal{M}$ as an episodic memory to preserve a small subset of old experiences, i.e., samples from old tasks. During training, we use the newly arriving data of the current task and examples sampled from the replay buffer to optimize the parameters. In our work, the memory is updated via a modified reservoir sampling method [chaudhry2019tiny], i.e., LRS, which is agnostic to the task boundaries. The cross-entropy classification loss on the memory can be formulated as

$$\mathcal{L}_{\mathcal{M}} = \mathbb{E}_{(x,y)\sim \mathcal{M}}\big[\ell_{ce}(f_\theta(x), y)\big]. \tag{2}$$

In a sense, the replay memory can make the model aware of the past tasks and provide hints about the old knowledge. By applying only the loss functions in Eq. (1) and (2), we arrive at the basic experience replay (ER) based CL methods [chaudhry2019continual, lopez2017gradient].
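As a concrete illustration, the joint objective of Eq. (1) and (2) can be sketched in a few lines. This is a minimal NumPy sketch, not the paper's implementation; `model`, `current_batch`, and `memory_batch` are hypothetical names.

```python
import numpy as np

def cross_entropy(logits, labels):
    # Softmax cross-entropy, averaged over the batch.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def er_loss(model, current_batch, memory_batch=None):
    # Loss on the current task's data (Eq. (1)-style term for one task).
    x, y = current_batch
    loss = cross_entropy(model(x), y)
    if memory_batch is not None:
        # Replay loss on samples drawn from the buffer (Eq. (2)-style term).
        xm, ym = memory_batch
        loss += cross_entropy(model(xm), ym)
    return loss
```

In practice the two terms are computed on one mixed mini-batch; they are kept separate here only for readability.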

Although simply replaying the old data can help the model preserve its performance on old tasks, these methods [buzzega2020dark, lopez2017gradient] still learn all the sequential data within the same whole parameter space. To decrease the loss on current data, the learner potentially uses all neurons and all capacity in the over-complete network when there is no restriction. Although the memory provides some gradients from the past tasks, interference on the parameters, and thus forgetting, can still easily happen, as shown in Fig. 1. This motivates us to study more effective capacity/parameter management mechanisms for CL. We thus propose to enforce sparsity on the network, as shown in Fig. 2. In this way, the learner learns the current (and past) data in the most economical way and thus reserves more capacity for future tasks, alleviating possible future interference. As shown in Sec. 3.2, we learn the sparse network by applying variational Bayesian sparsity priors on the neurons. Relying on the supervision from the current data and the past data in the memory, the proposed method automatically activates only a sparse set of neurons to maintain current and learned knowledge jointly. It also leads to more intrinsic and common knowledge for the tasks seen so far. To further enhance the supervision for learning the sparse neurons across different layers, we propose Full Experience Replay (FER), which directly enforces the stability of the intermediate representations of the old samples (see Sec. 3.3).

3.2 Bayesian Sparse Networks

To learn sparse networks for CL, we introduce a variational Bayesian sparse (VBS) prior on the activations, i.e., the neuron responses, in the neural network $f_\theta$. For a network with $L$ layers, we let $a_l^k$ denote the $k$-th neuron at the $l$-th layer, where $l$ and $k$ index the layer and the neuron at the $l$-th layer, respectively. For a convolution layer, $a_l^k$ represents the $k$-th channel of the produced feature map (with $K_l$ channels at the $l$-th layer). For a fully connected linear layer, $a_l^k$ represents the $k$-th column of the produced feature matrix.

To apply the VBS priors on the neurons, we first introduce a scalar indicator $\lambda_l^k$ for the sparsity of the neuron at the $l$-th layer and $k$-th channel. We define $\tilde{a}_l^k$ to denote the non-sparse version of $a_l^k$ produced in the network before applying the sparsity-inducing operation. For simplicity, and without losing generality, we induce the sparsity of the neurons by applying a scalar product on the activation representations [liu2019variational] as follows:

$$a_l^k = \lambda_l^k \odot \tilde{a}_l^k, \tag{3}$$

where $\odot$ denotes the pointwise product between the scalar $\lambda_l^k$ and the activations $\tilde{a}_l^k$. By enforcing sparsity on the $\lambda_l^k$'s, the $a_l^k$'s are sparse activations, as shown in Fig. 2. $\lambda_l^k$ plays the key role of activating the nonzero $a_l^k$. To enforce sparsity on $\lambda_l^k$, we impose a hierarchical prior: a zero-mean Gaussian distribution with the variance $s_l^k$ sampled from a uniform hyper-prior (i.e., $s_l^k \sim \mathcal{U}(0, \infty)$) [liu2019variational]:

$$p(\lambda_l^k \mid s_l^k) = \mathcal{N}(\lambda_l^k; 0, s_l^k), \quad s_l^k \sim \mathcal{U}(0, \infty). \tag{4}$$

Together with the following variational posterior with mean $\mu_l^k$ and variance $(\sigma_l^k)^2$:

$$q(\lambda_l^k) = \mathcal{N}(\lambda_l^k; \mu_l^k, (\sigma_l^k)^2), \tag{5}$$

we arrive at the following regularization term to encourage sparsity on $\lambda_l^k$ (i.e., on $\mu_l^k$):

$$\mathcal{L}_{vbs} = \sum_{l,k} \frac{1}{2}\log\Big(1 + \frac{(\mu_l^k)^2}{(\sigma_l^k)^2}\Big). \tag{6}$$

As the posterior mean of $\lambda_l^k$, $\mu_l^k$ can be regarded as the actual sparse weight on the activations. The two parameters, $\mu_l^k$ and $\sigma_l^k$, can be adaptively optimized via backpropagation from the classification loss, the FER loss, and the sparse regularizer in Eq. (6). $\mu_l^k$ directly represents the sparsity level of each neuron. Specifically, when $\mu_l^k = 0$, the activation $a_l^k$ and the corresponding neuron are set to zero and inactivated. The term in Eq. (6) is a concave, non-decreasing function on the domain $[0, \infty)$ with respect to $(\mu_l^k)^2$. Since the proposed layer is implemented in a Bayesian framework, the sparsity of the network is obtained so as to faithfully and adaptively fit the data in a principled way.
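To make the gating mechanism concrete, below is a minimal NumPy sketch of a per-channel variational gate. The closed-form penalty `0.5 * log(1 + mu**2 / sigma**2)` used here is what one obtains by minimizing the Gaussian KL divergence over the uniform variance hyper-prior; the paper's exact regularizer may differ, so treat this as an illustrative assumption. `VBSGate` and its method names are hypothetical.

```python
import numpy as np

class VBSGate:
    """Per-channel variational gate: a = lambda * a_tilde, q(lambda) = N(mu, sigma^2)."""

    def __init__(self, channels, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mu = np.ones(channels)            # posterior mean: the effective gate weight
        self.log_sigma = np.full(channels, -3.0)

    def forward(self, a, train=True):
        # a: activations of shape (batch, channels, ...); reparameterized sample of lambda
        sigma = np.exp(self.log_sigma)
        if train:
            s = self.mu + sigma * self.rng.standard_normal(self.mu.shape)
        else:
            s = self.mu                        # use the posterior mean at test time
        shape = (1, -1) + (1,) * (a.ndim - 2)  # broadcast over batch and spatial dims
        return a * s.reshape(shape)

    def kl(self):
        # Sparsity regularizer: 0.5 * log(1 + mu^2 / sigma^2), summed over channels
        sigma2 = np.exp(2.0 * self.log_sigma)
        return 0.5 * np.log1p(self.mu ** 2 / sigma2).sum()

    def active_fraction(self, tol=1e-2):
        # Fraction of channels whose gate mean is effectively nonzero
        return float((np.abs(self.mu) > tol).mean())
```

Channels whose `mu` is driven to zero by the regularizer are inactivated, freeing that capacity for later tasks.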

3.3 Full Experience Replay

In ER based methods, a memory buffer is used to store past samples. It stores the inputs and labels of the samples and is maintained by sampling from the past data.

We learn the adaptive sparsity on the neurons relying only on the backpropagation of the loss. This is flexible, but it also renders the supervision of the neurons at different layers inefficient, since the losses are applied to the final output of the network. Instead of ER with only the original labels, we propose a Full Experience Replay (FER) strategy which retains the data, labels, logits, and features in the replay buffer. We formulate the replay buffer as $\mathcal{M} = \{(x_i, y_i, z_i, f_i)\}_{i=1}^{B}$, where $z_i = h_{\theta_i}(x_i)$ and $f_i$ denote the logits and intermediate-layer features of past sample $x_i$, $\theta_i$ is the (old) model parameter at the moment $x_i$ was observed during training, and $B$ denotes the buffer size. $\theta_i$ for different samples can therefore be different (with a slight notation abuse here). In this way, the memory can maintain the knowledge from diverse model states [buzzega2020dark].

Relying on the memory with logits and features, we define the loss function of FER with the following components. $\mathcal{L}_{logit}$ preserves the logits of the past samples under the responses of the current model $\theta$ and the old models $\theta_i$:

$$\mathcal{L}_{logit} = \mathbb{E}_{(x, z)\sim \mathcal{M}}\big[\| h_\theta(x) - z \|_2^2\big]. \tag{7}$$

To provide more effective supervision for learning the sparse activations of the neurons at different layers, $\mathcal{L}_{feat}$ is employed on all intermediate features:

$$\mathcal{L}_{feat} = \mathbb{E}_{(x, f)\sim \mathcal{M}}\Big[\sum_{l} \| f_\theta^l(x) - f^l \|_2^2\Big], \tag{8}$$

where $f_\theta^l(x)$ denotes the $l$-th layer features of the current model. Together with the ER loss on labels in Eq. (2), the total loss of FER can be written as:

$$\mathcal{L}_{fer} = \mathcal{L}_{\mathcal{M}} + \alpha \mathcal{L}_{logit} + \beta \mathcal{L}_{feat}, \tag{9}$$

where $\alpha$ and $\beta$ are trade-off weights.
Although the formulation of the FER loss is similar to KD techniques, e.g., LwF [li2017learning] and iCaRL [rebuffi2017icarl], they are intrinsically different. The KD-based LwF does not use a replay buffer and cannot use previous samples. Unlike iCaRL, which employs the outputs of the network at the end of each task, we store both features and logits from the model at the moment each sample is observed during training, helping to preserve diverse model states and remain free of task boundaries. Unlike DER [buzzega2020dark], which relies only on the logits to replay past knowledge, our FER stores both the logits and features for a full replay of the whole experience, which is tailored to promote the estimation of sparsity across all layers. Moreover, FER is also effective for directly relieving the representation drift of the past samples.
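The FER replay terms reduce to mean-squared errors between the current model's logits/features and the stored ones. Below is a minimal NumPy sketch under that assumption; `fer_loss` and the `alpha`/`beta` weights are illustrative names, not the authors' code.

```python
import numpy as np

def fer_loss(cur_logits, cur_feats, buf_logits, buf_feats, alpha=1.0, beta=1.0):
    """Full Experience Replay terms (sketch): match stored logits and
    per-layer features of buffered samples.
    cur_feats / buf_feats: lists of arrays, one per intermediate layer."""
    # Logit-matching term (MSE against logits recorded when the sample was seen)
    logit_term = np.mean((cur_logits - buf_logits) ** 2)
    # Feature-matching term, summed over the stored intermediate layers
    feat_term = sum(np.mean((c - b) ** 2) for c, b in zip(cur_feats, buf_feats))
    return alpha * logit_term + beta * feat_term
```

The label-based replay cross-entropy on the same buffered batch would be added on top of this to form the total FER loss.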

3.4 The Overall Loss

In summary, the cross-entropy loss $\mathcal{L}_{ce}$ focuses on learning the current data; $\mathcal{L}_{vbs}$ provides an adaptive sparsity regularization on the network neurons; $\mathcal{L}_{fer}$ helps to replay the full knowledge of the past samples. By integrating them together, the overall loss of SNCL can be written as:

$$\mathcal{L}_{sncl} = \mathcal{L}_{ce} + \gamma \mathcal{L}_{vbs} + \mathcal{L}_{fer}, \tag{10}$$

where $\gamma$ weights the sparsity regularization.
3.5 Loss-aware Reservoir Sampling

Like all CL methods with ER, we maintain a small memory buffer for past data. To update the memory, we rely on reservoir-type sampling (RS) [vitter1985random, buzzega2020dark, buzzega2021rethinking] to avoid using task-boundary information. Under the FER setting, we use the memory buffer to store the full knowledge of the sampled set of data seen up to the current learning step. When the memory is full, RS randomly overwrites memory slots with randomly selected new samples [chaudhry2019tiny]. We modify the standard RS into Loss-aware Reservoir Sampling (LRS) to maintain a memory with more balanced distributions over different classes and training-loss values (i.e., learning difficulties).

Given a new data batch, LRS randomly samples a set of data to store in the memory. Algorithm 1 summarizes the loss-aware memory updating process. If the memory is full, LRS sorts the data points w.r.t. the loss values, specifically for each class seen so far (step 7), and then samples the data with diverse loss values (step 8). As a result, the updated memory can cover balanced classes and learning difficulties. We abuse the notation in step 6 for simplicity.

Input: Memory from the previous stage $\mathcal{M}$, the sample set to store $\mathcal{S}$, memory size $B$
Output: Updated memory buffer $\mathcal{M}$
1 if $|\mathcal{M}| < B$ then
2      $\mathcal{M} \leftarrow \mathcal{M} \cup \mathcal{S}$;
3 else
4      Obtain the number of classes in $\mathcal{M} \cup \mathcal{S}$ as $C$ and initialize $\mathcal{M}' \leftarrow \emptyset$;
5      for $c = 1, \dots, C$ do
6            $\mathcal{M}_c \leftarrow \{(x, y) \in \mathcal{M} \cup \mathcal{S} : y = c\}$;
7            Sort $\mathcal{M}_c$ w.r.t. the training loss values;
8            Obtain $\mathcal{M}'_c$ containing samples with diverse loss values by sampling in the sorted $\mathcal{M}_c$ with a fixed interval;
9            Update $\mathcal{M}' \leftarrow \mathcal{M}' \cup \mathcal{M}'_c$;
       $\mathcal{M} \leftarrow \mathcal{M}'$
Algorithm 1: Loss-aware Memory Updating
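A plain-Python sketch of the loss-aware update in Algorithm 1. The sample representation and the helper name `lrs_update` are hypothetical; the evenly spaced selection over the sorted list mirrors the sort-then-sample steps.

```python
def lrs_update(memory, new_samples, capacity):
    """Loss-aware reservoir update (illustrative sketch of Algorithm 1).
    Each sample is a dict with keys 'x', 'y' (class label) and 'loss'
    (its current training-loss value)."""
    pool = list(memory) + list(new_samples)
    if len(pool) <= capacity:              # memory not yet full: just append
        return pool
    classes = sorted({s['y'] for s in pool})
    per_class = max(1, capacity // len(classes))
    updated = []
    for c in classes:
        group = sorted((s for s in pool if s['y'] == c),
                       key=lambda s: s['loss'])
        if len(group) <= per_class:
            updated.extend(group)
        else:
            # Pick samples at evenly spaced ranks in the sorted list,
            # covering easy, medium, and hard examples for this class.
            step = len(group) / per_class
            updated.extend(group[int(i * step)] for i in range(per_class))
    return updated[:capacity]
```

Because the selection is rank-based per class, the resulting buffer is balanced across classes and across learning difficulties.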
Buffer | Method | S-CIFAR-10 (Class-IL / Task-IL) | S-Tiny-ImageNet (Class-IL / Task-IL) | P-MNIST (Domain-IL) | R-MNIST (Domain-IL)
- JOINT 92.20±0.15 98.31±0.12 59.99±0.19 82.04±0.10 94.33±0.17 95.76±0.04
SGD 19.62±0.05 61.02±3.33 7.92±0.26 18.31±0.68 40.70±2.33 67.66±8.53
- oEWC [schwarz2018progress] 19.49±0.12 68.29±3.92 7.58±0.10 19.20±0.31 75.79±2.25 77.35±5.77
SI [zenke2017continual] 19.48±0.17 68.05±5.91 6.58±0.31 36.32±0.13 65.86±1.57 71.91±0.83
LwF [li2017learning] 19.61±0.05 63.29±2.35 8.46±0.22 15.85±0.58 - -
PNN [rusu2016progressive] - 95.13±0.72 - 67.84±0.29 - -
200  ER [riemer2018learning] 44.79±1.86 91.19±0.94 8.49±0.16 38.17±2.00 72.37±0.87 85.01±1.90
GEM [lopez2017gradient] 25.54±0.76 90.44±0.94 - - 66.93±1.25 80.80±1.15
A-GEM [chaudhry2018efficient] 20.04±0.34 83.88±1.49 8.07±0.08 22.77±0.03 66.42±4.00 81.91±0.76
iCaRL [rebuffi2017icarl] 49.02±3.20 88.99±2.13 7.53±0.79 28.19±1.47 - -
HAL [chaudhry2020using] 32.36±2.70 82.51±3.20 - - 74.15±1.65 84.02±0.98
DER [buzzega2020dark] 61.93±1.79 91.40±0.92 11.57±0.78 40.22±0.67 81.74±1.07 90.04±2.61
SNCL (Ours) 66.16±1.48 92.91±0.81 12.85±0.69 43.01±1.67 86.23±0.20 91.54±2.58
500 ER[riemer2018learning] 57.74±0.27 93.61±0.27 9.99±0.29 48.64±0.46 80.60±0.86 88.91±1.44
GEM [lopez2017gradient] 26.20±1.26 92.16±0.69 - - 76.88±0.52 81.15±1.98
A-GEM [chaudhry2018efficient] 22.67±0.57 89.48±1.45 8.06±0.04 25.33±0.49 67.56±1.28 80.31±6.29
iCaRL [rebuffi2017icarl] 47.55±3.95 88.22±2.62 9.38±1.53 31.55±3.27 - -
HAL [chaudhry2020using] 41.79±4.46 84.54±2.36 - - 80.13±0.49 85.00±0.96
DER [buzzega2020dark] 70.51±1.67 93.40±0.39 17.75±1.14 51.78±0.88 87.29±0.46 92.24±1.12
SNCL (Ours) 76.35±1.21 94.02±0.43 20.27±0.76 52.85±0.67 88.53±0.41 93.05±1.02
5120 ER [riemer2018learning] 82.47±0.52 96.98±0.17 27.40±0.31 67.29±0.23 89.90±0.13 93.45±0.56
GEM [lopez2017gradient] 25.26±3.46 95.55±0.02 - - 87.42±0.95 88.57±0.40
A-GEM [chaudhry2018efficient] 21.99±2.29 90.10±2.09 7.96±0.13 26.22±0.65 73.32±1.12 80.18±5.52
iCaRL [rebuffi2017icarl] 55.07±1.55 92.23±0.84 14.08±1.92 40.83±3.11 - -
HAL [chaudhry2020using] 59.12±4.41 88.51±3.32 - - 89.20±0.14 91.17±0.31
DER [buzzega2020dark] 83.81±0.33 95.43±0.33 36.73±0.64 69.50±0.26 91.66±0.11 94.14±0.31
SNCL (Ours) 90.41±0.46 97.11±0.19 39.83±0.52 70.52±0.37 92.93±0.11 94.83±0.34
Table 1: Classification results for standard CL benchmarks. Average accuracy across 10 runs.

4 Experiments

Datasets. We test the proposed method on several public datasets generally used in continual learning: sequential CIFAR-10 [lopez2017gradient]

, sequential Tiny ImageNet

[chaudhry2019continual], Permuted MNIST [kirkpatrick2017overcoming], Rotated MNIST [lopez2017gradient] and MNIST-360 [buzzega2020dark].

CIFAR-10 [krizhevsky2009learning] includes 10 classes; each class has 5000 training examples and 1000 test examples. Tiny ImageNet [le2015tiny] is a subset of ImageNet, which contains 200 classes, each with 500 training examples. Sequential CIFAR-10 [lopez2017gradient] consists of 5 sequential tasks, each of which has 2 classes. Sequential Tiny ImageNet is built by splitting 100 classes of Tiny ImageNet into 20 tasks, where each task has 5 classes. Permuted MNIST [kirkpatrick2017overcoming] and Rotated MNIST [lopez2017gradient] both have 20 subsequent tasks: Permuted MNIST applies a random permutation to the pixels, while Rotated MNIST rotates the inputs by a random angle. For MNIST-360 [buzzega2020dark], Buzzega et al. randomly choose two MNIST digits from 0 to 8 (e.g., (0, 1), (1, 2), etc.) and rotate those digits by an increasing angle; MNIST-360 then presents a stream of data with batches of two consecutive digits at a time.
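For illustration, a Permuted-MNIST-style task stream can be built by fixing one random pixel permutation per task. This is a sketch, with `make_permuted_tasks` a hypothetical helper:

```python
import numpy as np

def make_permuted_tasks(x, num_tasks, seed=0):
    """Build Permuted-MNIST-style tasks: each task applies one fixed
    random pixel permutation to the flattened images x of shape (n, d)."""
    rng = np.random.default_rng(seed)
    perms = [rng.permutation(x.shape[1]) for _ in range(num_tasks)]
    return [x[:, p] for p in perms]
```

Each task keeps the same labels; only the input distribution changes, which is what makes the benchmark a Domain-IL setting.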

Experimental Protocol. We evaluate our method under different settings. We use Sequential CIFAR-10 and Sequential Tiny ImageNet to validate the proposed method in the Task-IL setting. Following [delange2021continual, buzzega2020dark], Sequential CIFAR-10 and Sequential Tiny ImageNet are also employed as the datasets for verifying Class-IL. For Domain-IL, we compare with other methods on the Permuted MNIST and Rotated MNIST datasets. We use the MNIST-360 dataset to test performance in the GCL setting.

Architectures. For CIFAR and Tiny ImageNet, we leverage a standard ResNet18 [he2016deep] without pretraining as the baseline architecture, similar to [rebuffi2017icarl]. Following the exact settings in [lopez2017gradient, riemer2018learning, buzzega2020dark]

, we deploy a fully connected network with two hidden layers for variants of the MNIST dataset. All networks employ ReLU in the hidden layers and softmax with cross-entropy loss in the final layer.

Augmentation. Following [buzzega2020dark], we apply random crops and horizontal flips to examples from both the current task and the replay buffer. Note that data augmentation provides an implicit constraint on the network.

Configuration and Hyperparameters.

All models are trained using the Stochastic Gradient Descent (SGD) optimizer. The number of epochs for the MNIST-based settings is 1. On Sequential CIFAR-10 and Sequential Tiny ImageNet, we train the model for 50 and 100 epochs per phase, respectively. Each batch consists of half new-task samples and half samples from the replay buffer. The replay buffer is updated with the proposed LRS at the end of each batch.

JOINT: 82.98±3.24    SGD: 19.09±0.69

Buffer | ER [riemer2018learning] | MER [riemer2018learning] | A-GEM-R [chaudhry2018efficient] | GSS [aljundi2019gradient] | DER [buzzega2020dark] | SNCL (Ours)
200 | 49.27±2.25 | 48.58±1.07 | 28.34±2.24 | 43.92±2.43 | 55.22±1.67 | 62.37±1.52
500 | 65.04±1.53 | 62.21±1.36 | 28.13±2.62 | 54.45±3.14 | 69.11±1.66 | 74.92±1.73
1000 | 75.18±1.50 | 70.91±0.76 | 29.21±2.62 | 63.84±2.09 | 75.97±2.08 | 78.86±1.25
Table 2: Accuracy on the test set for MNIST-360.
Variants | S-CIFAR-10 (Class-IL / Task-IL) | S-Tiny-ImageNet (Class-IL / Task-IL) | P-MNIST (Domain-IL) | R-MNIST (Domain-IL)
Baseline 61.93 91.40 11.57 40.22 81.74 90.04
+VBS 64.87 91.65 11.97 40.93 84.52 90.38
+FER 65.85 92.13 12.24 42.23 85.53 90.68
+LRS 63.07 91.49 11.66 40.64 83.63 90.12
+VBS+FER 66.02 92.48 12.61 42.81 86.02 91.32
+VBS+LRS 64.93 91.73 12.05 41.33 84.94 90.40
+FER+LRS 65.91 92.19 12.54 42.63 85.61 90.76
SNCL 66.16 92.91 12.85 43.01 86.23 91.54
Table 3: Ablation studies of different components in SNCL.

4.1 Experimental Results

We compare the proposed SNCL with multiple baseline methods in the CL setup. Regularization-based methods: oEWC [schwarz2018progress] and SI [zenke2017continual]; knowledge distillation-based methods: iCaRL [rebuffi2017icarl] and LwF [li2017learning]; dynamic modular approach: PNN [rusu2016progressive]; rehearsal-based methods: ER [riemer2018learning], GEM [lopez2017gradient], A-GEM [chaudhry2018efficient], GSS [aljundi2019gradient], HAL [chaudhry2020using], and DER [buzzega2020dark]. We include two non-continual learning baselines: SGD (lower bound) and JOINT (upper bound). All experiments are averaged over 10 runs using different initializations.

Three CL settings with task boundaries. Table 1 summarizes the average accuracy over all tasks after training. The proposed method significantly outperforms the other methods across the different settings. The regularization-based methods, oEWC and SI, perform poorly, which shows that regularization towards old sets of parameters is not an effective way to alleviate forgetting. The reason is that the weight importance is calculated on previous tasks, and this importance changes as subsequent tasks arrive.

Compared with the replay-based methods, DER and the proposed SNCL significantly outperform the others, especially in the Domain-IL setting. In Domain-IL, a shift occurs within the input domain but not within the classes; hence, the relations among classes likely persist. Other methods struggle to learn these similarity relations through the data stream, which degrades their performance. The proposed method can extract shareable and generalized knowledge from the data stream and achieves the best performance. In the Class-IL setting, we observe that the proposed method obtains a significant improvement in average accuracy. On the S-Tiny-ImageNet dataset, the proposed SNCL outperforms all the other methods significantly even with the small buffer size.

GCL setting. In [buzzega2020dark], the MNIST-360 dataset is proposed to validate the general continual learning setting, where task boundaries are not available. We compare our method with ER, MER, GSS, and A-GEM-R; A-GEM-R is a variant of A-GEM with a reservoir replay buffer. The results are reported in Table 2, where the proposed SNCL significantly outperforms the other methods for all buffer sizes. With the small buffer size, the performance of the proposed method improves by 13% over DER. This result demonstrates that the proposed method can effectively prevent catastrophic forgetting.

Variant | P-MNIST (Domain-IL) | R-MNIST (Domain-IL)
Baseline 81.74 90.04
FER-on-layer1 82.98 90.17
VBS + FER-on-layer1 83.41 90.61
FER-on-layer2 85.27 90.34
VBS + FER-on-layer2 85.73 90.93
FER-on-layer1&2 85.53 90.68
VBS + FER-on-layer1&2 (SNCL) 86.23 91.54
Table 4: Ablation studies w.r.t. different layers.

4.2 Ablation Study

Bayesian sparse networks. To verify the impact of variational Bayesian sparsity in SNCL, we conduct several experiments. As reported in Table 3, DER is used as the baseline. From Table 3, the variant baseline+VBS significantly outperforms the baseline. VBS can also further improve the performance of baseline+FER, which shows the effectiveness of the proposed network sparsity with VBS. We further study the effectiveness of VBS on different layers by applying it to each layer of the network individually. We report the results in Table 4, where the variants with VBS achieve better results; for example, compared with FER-on-layer1 (using FER on layer 1 of the baseline), its combination with VBS improves performance. From Table 4, we find that employing VBS on all layers yields the best performance. We thus conclude that VBS restricts the sparsity of the network at an early stage and reserves capacity for later tasks. More results can be found in the supplementary material.

Full experience replay. To verify the effectiveness of full experience replay (FER), we organize a series of experiments in Tables 3 and 4. As shown in Table 3, the variant baseline+FER clearly improves over the baseline. Unlike previous methods that replay only samples and logits/labels, FER additionally replays intermediate features, so aligned knowledge can be learned at every layer. Table 4 reports ablations on the features of individual layers: applying FER to each layer improves performance, demonstrating that the stored intermediate features contribute to the results. We also observe that the impact of high-level features (FER-on-layer2) is larger than that of low-level features (FER-on-layer1), and that employing FER on all layers improves the results further. We attribute this to the full experience replay strategy providing more effective supervision for learning knowledge at different layers.
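The FER objective described above can be sketched as a logit-replay term plus per-layer feature-replay terms; this is an illustrative mean-squared-error formulation under assumed weights `alpha` and `betas`, not necessarily the exact losses used in the paper.

```python
import numpy as np

def fer_loss(cur_feats, mem_feats, cur_logits, mem_logits, betas, alpha=0.5):
    """Full experience replay sketch: match the current network's logits to the
    stored logits (as in logit-replay methods such as DER), and additionally
    match each layer's current features to the features stored when the
    sample was first observed in the data stream."""
    loss = alpha * np.mean((cur_logits - mem_logits) ** 2)      # logit replay
    for beta, f_cur, f_mem in zip(betas, cur_feats, mem_feats):
        loss += beta * np.mean((f_cur - f_mem) ** 2)            # feature replay per layer
    return float(loss)
```

Replaying features at every layer, rather than only the final logits, is what supplies the layer-wise supervision the ablation in Table 4 measures.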

Loss-aware reservoir sampling. To evaluate the impact of LRS, we conduct further ablation studies. As shown in Table 3, adding LRS to the baseline improves performance, and the combination of FER and LRS achieves better results than baseline+FER. The reason is that LRS maintains a memory whose distribution is more balanced across classes and training-loss values.
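A minimal sketch of what such a loss-aware reservoir could look like: it keeps the standard reservoir acceptance test, but when a replacement is triggered it evicts a low-loss sample from the currently over-represented class. The eviction heuristic here is an assumption for illustration, not the paper's exact LRS rule.

```python
import random
from collections import Counter

class LossAwareReservoir:
    """Reservoir buffer that, on replacement, evicts the lowest-loss sample
    of the most-represented class (illustrative heuristic)."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.buffer = []          # list of (x, y, loss) tuples
        self.n_seen = 0
        self.rng = random.Random(seed)

    def add(self, x, y, loss):
        self.n_seen += 1
        if len(self.buffer) < self.capacity:
            self.buffer.append((x, y, loss))
            return
        # standard reservoir acceptance: keep with prob capacity / n_seen
        if self.rng.randrange(self.n_seen) < self.capacity:
            counts = Counter(item[1] for item in self.buffer)
            major = max(counts, key=counts.get)               # over-represented class
            idx = min((i for i, it in enumerate(self.buffer) if it[1] == major),
                      key=lambda i: self.buffer[i][2])        # its lowest-loss sample
            self.buffer[idx] = (x, y, loss)
```

Because eviction targets over-represented classes and easy (low-loss) samples, the buffer stays balanced in both class counts and loss values, matching the motivation stated above.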

Generalization, robustness, and flat minima. Generalization and robustness are desirable properties for CL approaches. As discussed in [keskar2016large, kirkpatrick2017overcoming, buzzega2020dark], these properties are linked to the flatness of the attained minimum: if a model converges to a flat minimum w.r.t. each task, it can more easily find a joint minimum of all tasks. This coincides with the motivation of our SNCL from the viewpoint of optimization geometry. We first measure the sensitivity of the model to local perturbations [buzzega2020dark] of the learned parameters. As shown in Fig. 3 (a), the proposed model has a high tolerance to perturbations, indicating that it converges to flat and robust minima. Following [buzzega2020dark], we also compute the empirical Fisher Information Matrix on all training data of S-CIFAR-10. Fig. 3 (b) shows that our method produces the lowest sum of the eigenvalues of the Fisher Information Matrix, implying flatter minima are achieved by SNCL.
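As a computational aside, the sum of the eigenvalues of the empirical Fisher Information Matrix equals its trace, which is just the mean squared gradient norm over the data, so the flatness statistic can be computed without any eigendecomposition. A small numpy sketch (per-sample gradients stacked row-wise, an assumed setup):

```python
import numpy as np

def fisher_eigen_sum(grads):
    """Sum of eigenvalues of the empirical Fisher F = (1/N) * sum_i g_i g_i^T.
    Since sum(eigvals(F)) = trace(F) = (1/N) * sum_i ||g_i||^2, the mean
    squared gradient norm suffices; no eigendecomposition is needed."""
    return float(np.mean(np.sum(grads ** 2, axis=1)))
```

A lower value indicates smaller curvature of the loss around the attained parameters, i.e. a flatter minimum, which is the quantity compared across methods in Fig. 3 (b).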

Figure 3: Analyses of the robust and flat minima. (a) Perturbation sensitivity; (b) Fisher eigenvalues.
Figure 4: Visualizing t-SNE embeddings on S-CIFAR-10. (a) Forgetting of DER; (b) forgetting of SNCL; (c) learned representation of DER; (d) learned representation of SNCL. In (a) and (b), blue points are from a task after learning other tasks in CL, and red points are obtained from single-task no-forgetting training as a reference. (c) and (d) show the representations of all the classes after CL.

4.3 Visualization Analysis

This section presents visualization analyses of the learned representations to show intuitively how the proposed SNCL performs. Specifically, we visualize the final-layer representations of the learned networks on different tasks via t-SNE [van2008visualizing] and compare them with the state-of-the-art ER-based method.

Fig. 4 (a) and (b) show that the proposed method suffers less catastrophic forgetting. We visualize, in blue, the learned representations of task #3 after learning all 10 tasks of S-CIFAR-10 under the CL setting; as a no-forgetting reference, we also train the network on task #3 alone. Against this reference, Fig. 4 (a) shows that the representations of DER drift severely after the CL process. In contrast, in Fig. 4 (b), the representations of SNCL on task #3 (in blue) remain similar to the no-forgetting reference after learning the other tasks, implying that SNCL suffers less forgetting.

Fig. 4 (c) and (d) visualize the representations of all the classes (on the testing data) after CL. The representations of SNCL (Fig. 4 (d)) separate the different classes better than those of DER (Fig. 4 (c)), indicating less interference and forgetting in SNCL. DER forgets more severely because its past tasks are more easily interfered with by new tasks, as shown by the green and red points in Fig. 4 (c).

5 Conclusion

In this paper, we proposed to learn sparse networks for continual learning with variational Bayesian sparsity priors on the neurons (SNCL). The proposed method alleviates catastrophic forgetting and interference by enforcing sparsity on the network to reserve parameters for future tasks. We proposed full experience replay (FER), which stores and replays the intermediate-layer features that old samples generate when they are first observed in the data stream, and a loss-aware reservoir sampling strategy to maintain the memory. Extensive experimental analysis shows that the proposed SNCL framework outperforms state-of-the-art methods on several datasets under different settings. A main limitation is that SNCL does not model finer-grained, element-wise sparsity, nor does it consider the correlation of the sparsity structure across neurons. Like other existing CL methods, SNCL cannot guarantee the absence of forgetting, which motivates further studies toward real-world applications.