Continual learning, also known as lifelong learning, is the crucial human ability to continually acquire and transfer new knowledge across the lifespan while retaining previously learnt experiences hassabis2017neuroscience . This ability is also critical for artificial intelligence (AI) systems that interact with the real world and process continuous streams of information thrun1995lifelong . However, the continual acquisition of incrementally available data from non-stationary data distributions generally leads to catastrophic forgetting in the system mccloskey1989catastrophic ; ratcliff1990connectionist ; french1999catastrophic . Continual learning remains a long-standing challenge for deep neural network models, since these models typically learn representations from stationary batches of training data and tend to fail to retain good performance on previous tasks when data become incrementally available over tasks kemker2018measuring ; maltoni2019continuous .
Numerous methods for alleviating catastrophic forgetting have been proposed. The most pragmatic way is to jointly train deep neural network models on both old and new tasks, which however demands a large amount of resources to store previous training data and hinders the learning of novel data in real time. Another option is to complement the training data for each new task with “pseudo-data” of the previous tasks shin2017continual ; robins1995catastrophic . Besides the main model for task performance, a separate generative model is trained to generate fake historical data used for pseudo-rehearsal. Deep Generative Replay (DGR) shin2017continual replaces the storage of the previous training data with a Generative Adversarial Network that synthesizes training data for all previously learnt tasks. These generative approaches have succeeded on very simple and artificial inputs but they cannot tackle more complicated inputs atkinson2018pseudo . Moreover, to synthesize the historical data reasonably well, the generative model is usually huge, which costs substantial memory wen2018few . An alternative is to store the weights of the model trained on previous tasks and constrain weight updates on new tasks he2018overcoming ; kirkpatrick2017overcoming ; zenke2017continual ; lee2017overcoming ; lopez2017gradient . For example, Elastic Weight Consolidation (EWC) kirkpatrick2017overcoming and Learning Without Forgetting (LwF) li2018learning store all the model parameters from previously learnt tasks, estimate their importance on previous tasks and penalize future changes to those weights on new tasks. However, selecting the “important” parameters for previous tasks complicates the implementation through exhaustive hyper-parameter tuning. In addition, state-of-the-art neural network models often involve millions of parameters, and storing all network parameters from previous tasks does not necessarily reduce the memory cost wen2018few . In contrast with these methods, storing a small subset of examples from previous tasks and replaying this “exact subset” substantially boosts performance kemker2017fearnet ; rebuffi2017icarl ; nguyen2017variational . To achieve the desired network behavior on previous tasks, incremental Classifier and Representation Learner (iCARL) rebuffi2017icarl and Few-shot Self-Reminder (FSR) wen2018few follow the idea of logit matching or knowledge distillation from model compression ba2014deep ; bucilua2006model ; hinton2015distilling .
In this paper, we propose the method, Prototype Reminding, for continual learning in classification tasks. Similar to snell2017prototypical , we use a neural network to learn class-representative prototypes in an embedding space and classify embedded test data by finding their nearest class prototype. To tackle the problem of catastrophic forgetting, we impose additional constraints on the network by classifying the embedded test data based on prototypes from previous tasks, which promotes the preservation of the initial embedding function. For example (Figure 3), in the first task (Figure 3a), the network learns color prototypes to classify blue and yellow circles, and in the second task (Figure 3b), the network learns shape prototypes to classify green circles and triangles. If it catastrophically forgets color features, the network extracts only shape features on the first task and fails to distinguish blue from yellow circles. To alleviate catastrophic forgetting, our method replays the embedded previous samples (blue and yellow circles) and matches them with the previous color prototypes (blue and yellow), which reminds the network to extract both color and shape features in both classification tasks.
We evaluate our method under two typical experimental protocols for continual learning, incremental domain and incremental class, across three benchmark datasets: MNIST deng2012mnist , CIFAR10 krizhevsky2009learning and miniImageNet deng2009imagenet . Compared with state-of-the-art methods, our method significantly boosts the performance of continual learning in terms of memory retention capability while being able to adapt to new tasks. Unlike parameter regularization methods, iCARL or FSR, our approach further reduces memory storage by replacing the logits of each example, or the network parameters, with one prototype per class in the episodic memory. While the last layer in traditional classification networks often structurally depends on the number of classes, our method leverages metric learning, maintains the same network architecture and does not require adding new neurons or layers for new object classes. Additionally, without sacrificing classification accuracy on initial tasks, our method can generalize to learn new concepts given a few training examples in new tasks, owing to the advantage of metric learning, commonly used in few-shot settings snell2017prototypical ; hoffer2015deep .
2 Proposed Method
We propose the method, Prototype Reminding, for continual learning. For a sequence of datasets $\mathcal{D}_1, \dots, \mathcal{D}_T$, where $\mathcal{D}_t$ is given in task $t$, the goal for the model is to retain good classification performance on all datasets after being sequentially trained over $T$ tasks. The value of $T$ is not pre-determined. The model with learnable parameters $\theta$ is only allowed to carry over a limited amount of information from the previous tasks. This constraint eliminates the naive solution of combining all previous datasets to form one big training set for fine-tuning the model at task $t$. Each dataset $\mathcal{D}_t$ consists of labeled examples $(x_i, y_i)$, where each $x_i$ is the $D$-dimensional feature vector of an example and $y_i \in \{1, \dots, K\}$ is the corresponding class label. $S_k$ denotes the set of examples labeled with class $k$.
At task $t$, if we simply train a model by only minimizing the classification loss on dataset $\mathcal{D}_t$, the model will forget how to perform classification on previous datasets, which is known as the catastrophic forgetting problem mccloskey1989catastrophic ; ratcliff1990connectionist ; french1999catastrophic . Here we show how the model trained with our method retains good performance on all previous tasks while adaptively learning new tasks. The loss over all the previous datasets is denoted by $L_{1:t-1}$. Our objective is to learn $\theta_t$ defined as follows:

$$\theta_t = \arg\min_{\theta} \left[ L_t(\theta; \mathcal{D}_t) + \lambda L_{1:t-1}(\theta; \theta_{t-1}) \right]$$
where $L_{1:t-1}$ measures the differences in the network behaviors in the embedding spaces learnt by $\theta_t$ and $\theta_{t-1}$ on the previous datasets. Given that $\theta_{t-1}$ is learnt from the previous tasks, at task $t$, learning $\theta_t$ requires minimizing both terms $L_t$ and $L_{1:t-1}$. In the subsections below and Figure 3, we describe how to optimize these two terms.
2.1 Prototype-based Classification
To perform classification on dataset $\mathcal{D}_t$, our method learns an embedding space in which points cluster around a single prototype representation for each class, and classification is performed by finding the nearest class prototype snell2017prototypical (Figure 3a). Compared to traditional classification networks with a specific classification layer attached at the end, our method keeps the network architecture unchanged while finding the nearest neighbour in the embedding space, which leads to more efficient memory usage. For example, in one of the continual learning protocols snell2017prototypical where models are asked to classify incremental classes (also see Section 3.1), traditional classification networks have to expand their architectures by accommodating more output units in the last classification layer based on the number of incremental classes; consequently, additional network parameters have to be added into memory. Without loss of generality, here we show how our method performs classification on $\mathcal{D}_t$. First, the model learns an embedding function $f_\theta$ and computes an $M$-dimensional prototype $c_k$, which is the mean of the embeddings of the examples in $S_k$:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i)$$
The distance between an embedding and the prototype of its own class should be smaller than its distances to the prototypes of other classes. Our method introduces a distance function $d(\cdot, \cdot)$. For each example $x$, it estimates a distance distribution based on a softmax over distances to the prototypes of all classes in the embedding space:

$$p_\theta(y = k \mid x) = \frac{\exp\big(-d(f_\theta(x), c_k)\big)}{\sum_{k'} \exp\big(-d(f_\theta(x), c_{k'})\big)}$$
The objective is to minimize the negative log-probability $J(\theta) = -\log p_\theta(y = k \mid x)$ of the ground-truth class label $k$ bottou2010large . In practice, when $\mathcal{D}_t$ is large, computing prototypes over the entire dataset is costly and memory-inefficient during training. Thus, at each training iteration, we randomly sample two complementary subsets from $\mathcal{D}_t$ over all classes: one for computing prototypes and the other for estimating the distance distribution. Our primary choice of the distance function $d$ is the squared Euclidean distance, which has been verified to be effective in snell2017prototypical . In addition, we include a temperature hyperparameter $T$ in $p_\theta$, as introduced in the network distillation literature hinton2015distilling . A higher value of $T$ produces a softer probability distribution over classes.
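The prototype computation and distance-based classification above can be sketched as follows. This is a minimal NumPy illustration with toy 2-D embeddings and hypothetical function names; in the paper, embeddings come from a deep network $f_\theta$ and the temperature $T$ is a tuned hyperparameter.

```python
import numpy as np

def prototypes(embeddings, labels, num_classes):
    """Class prototype c_k = mean of the embeddings belonging to class k."""
    return np.stack([embeddings[labels == k].mean(axis=0)
                     for k in range(num_classes)])

def class_probs(query_emb, protos, temperature=1.0):
    """Softmax over negative squared Euclidean distances to all prototypes."""
    dists = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    logits = -dists / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

# Toy 2-D embedding space with two well-separated classes.
emb = np.array([[0., 0.], [0., 1.], [5., 5.], [5., 6.]])
lab = np.array([0, 0, 1, 1])
protos = prototypes(emb, lab, 2)                 # class means
probs = class_probs(np.array([[0., 0.4]]), protos)
pred = probs.argmax(axis=1)                      # nearest-prototype decision
```

A larger `temperature` flattens `probs` toward uniform while leaving the argmax (the nearest prototype) unchanged.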
2.2 Prototype Reminding
Regardless of the changes of the network parameters from $\theta_{t-1}$ to $\theta_t$ at tasks $t-1$ and $t$ respectively, the primary goal of $\theta_t$ is to learn an embedding function that results in a metric space similar to the one learnt by $\theta_{t-1}$ on dataset $\mathcal{D}_{t-1}$ in task $t-1$ (Figure 3b). Given a limited amount of memory, a direct approach is to randomly sample a small subset $\tilde{\mathcal{D}}_{t-1}$ from $\mathcal{D}_{t-1}$ and replay these examples in task $t$. There have been some attempts chen2012super ; koh2017understanding ; brahma2018subset at selecting representative examples for $\tilde{\mathcal{D}}_{t-1}$ based on different scoring functions. However, recent work wen2018few has shown that random sampling uniformly across classes already yields outstanding performance in continual learning tasks. Hence, we adopt the same random sampling strategy to form $\tilde{\mathcal{D}}_{t-1}$.
iCARL rebuffi2017icarl and FSR wen2018few follow the idea of logit matching for regularizing the mapping function of a network. However, such approaches ignore the topological relations among clusters in the embedding space and rely too heavily on a small amount of individual data, which may result in overfitting, as shown in our experiments (Section 4.2). In contrast, our method compares the feature similarities represented by class prototypes in the embedding space, which improves generalization, as has also been verified in hoffer2015deep ; snell2017prototypical .
Intuitively, if the number of data samples in $\tilde{\mathcal{D}}_{t-1}$ were very large, the network could reproduce the metric space of task $t-1$ by replaying $\tilde{\mathcal{D}}_{t-1}$, which is our desired goal. However, this does not hold in practice given limited memory capacity. With the simple inductive bias that the metric space at task $t-1$ can be summarized by class-representative prototypes, we introduce another loss requiring that the embedded data samples in $\tilde{\mathcal{D}}_{t-1}$ still be closest to their corresponding class prototypes among all prototypes at task $t$. This ensures that the metric space represented by the set of prototypes learnt from $\mathcal{D}_{t-1}$ by $f_{\theta_{t-1}}$ provides a good approximation to the one in task $t-1$.
Formally, we regularize the network behaviors in the metric space of task $t-1$ by satisfying two criteria: first, $\theta_t$ learns a metric space to classify $\mathcal{D}_t$ by minimizing $L_t$; second, to preserve a similar topological structure among clusters on dataset $\mathcal{D}_{t-1}$, the embeddings predicted by $f_{\theta_t}$ on $\tilde{\mathcal{D}}_{t-1}$ should produce a similar distance distribution based on a softmax over the distances to the prototypes computed using $f_{\theta_{t-1}}$ on dataset $\mathcal{D}_{t-1}$:

$$p_{\theta_t}(y = k \mid x) = \frac{\exp\big(-d(f_{\theta_t}(x), c_k^{t-1})\big)}{\sum_{k'} \exp\big(-d(f_{\theta_t}(x), c_{k'}^{t-1})\big)}, \quad x \in \tilde{\mathcal{D}}_{t-1},$$

where $c_k^{t-1}$ denotes the prototype of class $k$ computed using $f_{\theta_{t-1}}$ on $\mathcal{D}_{t-1}$.
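A minimal sketch of this reminding term, assuming NumPy arrays and illustrative names: replayed examples are embedded by the *current* network and scored against prototypes *stored* from the previous task; the negative log-probability of their original labels grows as the embedding drifts.

```python
import numpy as np

def reminding_loss(replay_emb_new, replay_labels, old_protos, temperature=1.0):
    """Negative log-probability of replayed examples, embedded by the current
    network, matched against prototypes stored from the previous task."""
    dists = ((replay_emb_new[:, None, :] - old_protos[None, :, :]) ** 2).sum(-1)
    logits = -dists / temperature
    logits -= logits.max(axis=1, keepdims=True)           # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(replay_labels)), replay_labels].mean()

old_protos = np.array([[0., 0.], [4., 4.]])               # stored at task t-1
# Current embeddings of replayed samples still sit near their old prototypes:
good = reminding_loss(np.array([[0.1, 0.], [3.9, 4.]]), np.array([0, 1]), old_protos)
# Drifted embeddings (clusters swapped): the loss grows, signalling forgetting.
bad = reminding_loss(np.array([[4., 4.], [0., 0.]]), np.array([0, 1]), old_protos)
```

In training, this term would be added to the current-task classification loss, so gradients pull replayed embeddings back toward their old prototypes.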
2.3 Dynamic Episodic Memory Allocation
Given a limited amount of memory with capacity $N$, our proposed method has to store, for each previous task $j < t$, a small subset $\tilde{\mathcal{D}}_j$ with $n_j$ examples randomly sampled from $\mathcal{D}_j$ and the prototypes computed using the embedding function $f_{\theta_j}$ on $\mathcal{D}_j$. The following constraint has to be satisfied:

$$\sum_{j=1}^{t-1} \left( n_j \cdot D + K_j \cdot M \right) \le N,$$

where $K_j$ is the number of classes in task $j$, $D$ the input dimension and $M$ the embedding dimension.
When the number of tasks is small, $n_j$ can be large and the episodic memory stores more examples in $\tilde{\mathcal{D}}_j$. This dynamic memory allocation, which enables more example replays in earlier tasks, puts more emphasis on reviewing earlier tasks, which are easier to forget, and introduces more variety in data distributions when matching with prototypes. Pseudocode for our proposed continual learning algorithm over one training episode is provided in Algorithm 1.
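One simple way to realize this capacity-constrained allocation can be sketched as below. The function and parameter names are illustrative, not from the paper; the point is that the per-task replay budget shrinks automatically as more tasks (and their prototypes) accumulate in memory.

```python
def examples_per_task(capacity, num_tasks_seen, proto_cost_per_task):
    """Split a fixed memory budget: reserve room for the stored prototypes of
    every task seen so far, then share the remaining capacity evenly across
    those tasks as replay examples."""
    remaining = capacity - num_tasks_seen * proto_cost_per_task
    return max(remaining // max(num_tasks_seen, 1), 0)

# Early on, each task gets a large replay budget; later, far fewer examples.
early = examples_per_task(capacity=1000, num_tasks_seen=1, proto_cost_per_task=10)
late = examples_per_task(capacity=1000, num_tasks_seen=5, proto_cost_per_task=10)
```

Because earlier tasks are sampled while the budget is still large, they are replayed with more examples overall, matching the emphasis on reviewing tasks that are easiest to forget.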
3 Experimental Details
We introduce two task protocols for evaluating continual learning algorithms with different memory usage over three benchmark datasets. Source code will be made publicly available upon acceptance.
3.1 Task Protocols
Permuted MNIST in the incremental domain task is a benchmark task protocol in continual learning lee2017overcoming ; lopez2017gradient ; zenke2017continual . In each task, a fixed permutation sequence is randomly generated and applied to the input images in MNIST deng2012mnist . Though the input distribution changes across tasks, models are trained to classify 10 digits in each task and the model structure remains the same. There are 20 tasks in total. During testing, the task identity is not available to the models; they have to classify input images into 1 out of 10 digits.
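Generating a permuted MNIST task can be sketched in a few lines of NumPy (helper name is ours): one fixed pixel permutation is drawn per task and applied to every flattened image in that task, so the label semantics stay the same while the input distribution shifts.

```python
import numpy as np

def make_permuted_task(images, seed):
    """Apply one fixed random pixel permutation, shared by all images in a
    task, to flattened images; each task uses a different seed."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(images.shape[1])   # fixed for the whole task
    return images[:, perm]

# Stand-in for flattened 28x28 MNIST images (3 images, 4 "pixels" each).
batch = np.arange(12, dtype=float).reshape(3, 4)
task1 = make_permuted_task(batch, seed=1)     # task 1's input distribution
task2 = make_permuted_task(batch, seed=2)     # task 2's input distribution
```

Each seed defines a task deterministically, so the same permutation can be regenerated at test time without storing it.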
Split CIFAR10 and split miniImageNet in the incremental class task form a more challenging task protocol where models need to infer the task identity and, at the same time, solve each task. The input data are also more complex: natural images in CIFAR10 krizhevsky2009learning and miniImageNet deng2009imagenet . The former contains 10 classes and the latter 100 classes. In CIFAR10, the model is first trained with 2 classes and then with 1 more class in each subsequent task. There are 9 tasks in total and 5,000 images per class in the training set. In miniImageNet, models are trained with 10 classes in each task. There are 10 tasks in total and 480 images per class in the training set.
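The class schedules above (2 classes then +1 per task for CIFAR10; 10 classes per task for miniImageNet) can be generated with a small helper; the function name and signature are ours, for illustration.

```python
def class_schedule(num_classes, first, step):
    """Return the list of class labels introduced in each task: `first`
    classes in task 1, then `step` new classes in every subsequent task."""
    tasks, start = [], 0
    sizes = [first] + [step] * ((num_classes - first) // step)
    for s in sizes:
        tasks.append(list(range(start, start + s)))
        start += s
    return tasks

cifar = class_schedule(10, first=2, step=1)     # 9 tasks for split CIFAR10
mini = class_schedule(100, first=10, step=10)   # 10 tasks for split miniImageNet
```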
Few-shot Continual Learning Humans can learn novel concepts from a few examples without sacrificing classification accuracy on initial tasks gidaris2018dynamic . However, typical continual learning schemes assume that a large amount of training data over all tasks is always available for fine-tuning networks to adapt to new data distributions, which does not always hold in practice. We therefore revise the task protocols into more challenging ones: networks are trained with only a few examples per class in every task after the first. For example, on CIFAR10/miniImageNet, we train the models with 5,000/480 example images per class in the first task and 50/100 images per class in subsequent tasks.
3.2 Compared Methods
We include the following categories of continual learning methods for comparison with our method. Parameter regularization methods: Elastic Weight Consolidation (EWC) kirkpatrick2017overcoming , Synaptic Intelligence (SI) zenke2017continual and Memory Aware Synapses (MAS) aljundi2018memory , where regularization terms are added to the loss function; online EWC schwarz2018progress , an extension of EWC that scales to a large number of tasks; L2, where an L2 distance penalizing parameter changes between tasks is added to the loss kirkpatrick2017overcoming ; and SGD, a naive baseline without any regularization terms, optimized with Stochastic Gradient Descent bottou2010large sequentially over all tasks. Logit matching methods: iCARL rebuffi2017icarl and FSR wen2018few regularize network behaviors by exact pseudo replay. Specifically, FSR has two variants: FSR-KLD for logit matching via a Kullback–Leibler divergence loss, and FSR-MSE for logit distillation via an L2 distance loss.
3.3 Memory Comparison
For fair comparison, we use the same architecture for all methods and allocate a memory budget comparable to that of EWC kirkpatrick2017overcoming and other parameter regularization methods, for storing example images per class and their prototypes. In EWC, the model often allocates a memory size twice the number of network parameters for computing the Fisher information matrix, which is used for regularizing changes of network parameters kirkpatrick2017overcoming . In more challenging classification tasks, the network size tends to be larger and hence these methods require much more memory. In Table 16, we show an example of memory allocation on split CIFAR10 in incremental class tasks with full memory and little memory respectively. Weight regularization methods require a memory allocation twice the size of the feed-forward classification network. Via Eqn. (5), our method can allocate episodic memory with full capacity, which is equivalent to storing a fixed number of example images per class. In the experiments with little training data described in Section 3.1, we further reduce the number of stored example images per class.
4 Experimental Results
4.1 Alleviating Forgetting
Figure 11 reports the results of continual learning methods with full memory under the two task protocols. All compared continual learning methods outperform SGD (cyan) which is a baseline without preventing catastrophic forgetting. Our method (red) achieves the highest average classification accuracy among all the compared methods with minimum forgetting.
A good continual learning method should not only show good memory retention but also be able to adapt to new tasks. In Figure 11a, although our method (red) performs on par with EWC (brown) in retaining the classification accuracy on the first task's dataset along the 20 sequential tasks, the average classification accuracy of our method is far higher than EWC (brown), as shown in Figure 11b, indicating that EWC retains memory well but fails to learn new tasks. After the 13th task, the average classification performance of EWC is even worse than SGD. Similar reasoning applies to the comparison with SI (green): although our method performs comparably to SI (green) in terms of average classification accuracy, SI (green) fails to retain the classification accuracy on the first task, which is 6% lower than ours at the 20th task.
Figures 11c and 11d show the average task classification accuracy over sequential tasks in the incremental class protocol. The incremental class protocol is more challenging than the incremental domain protocol, since the models have to infer both the task identity and the class labels within the task. Our method (red) achieves the highest average classification accuracy in continual learning, 15% and 5% higher than the second-best method L2 (yellow) on CIFAR10 and miniImageNet respectively. Note that most weight regularization methods, such as EWC (brown), perform as badly as SGD. One possible reason is that EWC computes a Fisher matrix to maintain local information and does not consider scenarios where the data distributions across tasks are far apart. On the contrary, our method maintains remarkably better performance than EWC, because ours focuses primarily on the behaviors of the network outputs, which indirectly relaxes the constraint on the change of network parameters.
4.2 Few-shot Continual Learning
We evaluate continual learning methods with little memory under the two task protocols, with little training data in every task after the first. Figure 14 reports their performance. Our method (red) achieves the highest average classification accuracy over all sequential tasks among state-of-the-art methods, with 27% and 11% versus 19% and 4% for FSR-KLD (yellow), the second best, at the 9th and 10th tasks on CIFAR10 and miniImageNet respectively. Weight regularization methods, such as online EWC (blue) and MAS (brown), perform as badly as SGD (cyan), worse than logit matching methods such as FSR (green and yellow) and iCARL (purple). The observations are similar to those in Figure 11 with full training data.
Compared with logit matching methods, our method has the highest average task classification accuracy. This reveals that our method performs classification via metric learning in an effective few-shot manner. It is also because our network architecture does not depend on the number of output classes, so knowledge from previous tasks can be well preserved and transferred to new tasks. This is superior to traditional networks, where new parameters added to the last classification layer easily lead to overfitting. As a side benefit, given the same number of example inputs in the episodic memory, our method is more efficient in memory usage since it stores one prototype per class instead of the logits of each example input, as verified in Table 16.
Table 16: Memory sizes (in orders of magnitude) under full training with full memory and little training with little memory.
4.3 Network Analysis
We also study the effects of the following three factors on performance. Figure 16 reports the average classification accuracy of these ablated methods. (1) Intuitively, limited memory capacity restricts the number of example inputs to replay and leads to a performance drop. On permuted MNIST in the incremental domain, with the full memory capacity reduced by 2.5 times (from 5,000 example inputs to 2,000), our method shows a moderate decrease in average classification accuracy of 1% at the 20th task. (2) We also compare our method with memory replay optimized by a cross-entropy loss under full memory conditions. A performance drop of around 1.5% is observed, which validates that classifying example inputs based on the initial prototypes results in better memory retention. (3) Given a fixed memory capacity, our method adopts the strategy of decreasing the number of example inputs in the episodic memory as the number of tasks increases. The performance drop of 1.5% under uniform memory allocation demonstrates the usefulness of dynamic memory allocation, which enforces more examples to be replayed in earlier tasks and therefore promotes better memory retention.
In Figure 17, we provide visualizations of class embeddings by projecting these latent representations into 2D space. It can be seen that our method is capable of clustering latent representations belonging to the same class while accommodating new class embeddings across sequential tasks. Interestingly, the clusters are topologically organized based on feature similarities among classes, and the topological structure among the same classes is preserved across tasks. For example, the cluster of “bird” (black) is close to that of “plane” (orange) in Task 3, and the same two clusters are still close in Task 9. This again validates that classifying example inputs from previous tasks based on initial prototypes promotes the preservation of the topological structure of the initial metric space.
5 Conclusion
We address the problem of catastrophic forgetting by proposing Prototype Reminding. In addition to significantly alleviating catastrophic forgetting on benchmark datasets, our method is superior to others in making memory usage efficient and in generalizing to novel concepts given only a few training examples in new tasks.
-  R. Aljundi, F. Babiloni, M. Elhoseiny, M. Rohrbach, and T. Tuytelaars. Memory aware synapses: Learning what (not) to forget. In Proceedings of the European Conference on Computer Vision (ECCV), pages 139–154, 2018.
-  C. Atkinson, B. McCane, L. Szymanski, and A. Robins. Pseudo-recursal: Solving the catastrophic forgetting problem in deep neural networks. arXiv preprint arXiv:1802.03875, 2018.
-  J. Ba and R. Caruana. Do deep nets really need to be deep? In Advances in neural information processing systems, pages 2654–2662, 2014.
-  L. Bottou. Large-scale machine learning with stochastic gradient descent. In Proceedings of COMPSTAT’2010, pages 177–186. Springer, 2010.
-  P. P. Brahma and A. Othon. Subset replay based continual learning for scalable improvement of autonomous driving perception. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), pages 1179–11798. IEEE, 2018.
-  C. Bucilua, R. Caruana, and A. Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, pages 535–541. ACM, 2006.
-  Y. Chen, M. Welling, and A. Smola. Super-samples from kernel herding. arXiv preprint arXiv:1203.3472, 2012.
-  J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009.
-  L. Deng. The MNIST database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Processing Magazine, 29(6):141–142, 2012.
-  R. M. French. Catastrophic forgetting in connectionist networks. Trends in cognitive sciences, 3(4):128–135, 1999.
-  S. Gidaris and N. Komodakis. Dynamic few-shot visual learning without forgetting. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4367–4375, 2018.
-  D. Hassabis, D. Kumaran, C. Summerfield, and M. Botvinick. Neuroscience-inspired artificial intelligence. Neuron, 95(2):245–258, 2017.
-  X. He and H. Jaeger. Overcoming catastrophic interference using conceptor-aided backpropagation. In International Conference on Learning Representations, 2018.
-  G. Hinton, O. Vinyals, and J. Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
-  E. Hoffer and N. Ailon. Deep metric learning using triplet network. In International Workshop on Similarity-Based Pattern Recognition, pages 84–92. Springer, 2015.
-  R. Kemker and C. Kanan. FearNet: Brain-inspired model for incremental learning. arXiv preprint arXiv:1711.10563, 2017.
-  R. Kemker, M. McClure, A. Abitino, T. L. Hayes, and C. Kanan. Measuring catastrophic forgetting in neural networks. In Thirty-second AAAI conference on artificial intelligence, 2018.
-  J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, 114(13):3521–3526, 2017.
-  P. W. Koh and P. Liang. Understanding black-box predictions via influence functions. In Proceedings of the 34th International Conference on Machine Learning, pages 1885–1894. JMLR.org, 2017.
-  A. Krizhevsky and G. Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
-  S.-W. Lee, J.-H. Kim, J. Jun, J.-W. Ha, and B.-T. Zhang. Overcoming catastrophic forgetting by incremental moment matching. In Advances in Neural Information Processing Systems, pages 4652–4662, 2017.
-  Z. Li and D. Hoiem. Learning without forgetting. IEEE transactions on pattern analysis and machine intelligence, 40(12):2935–2947, 2018.
-  D. Lopez-Paz and M. Ranzato. Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pages 6467–6476, 2017.
-  D. Maltoni and V. Lomonaco. Continuous learning in single-incremental-task scenarios. Neural Networks, 2019.
-  M. McCloskey and N. J. Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of learning and motivation, volume 24, pages 109–165. Elsevier, 1989.
-  C. V. Nguyen, Y. Li, T. D. Bui, and R. E. Turner. Variational continual learning. arXiv preprint arXiv:1710.10628, 2017.
-  R. Ratcliff. Connectionist models of recognition memory: constraints imposed by learning and forgetting functions. Psychological review, 97(2):285, 1990.
-  S.-A. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert. iCaRL: Incremental classifier and representation learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2001–2010, 2017.
-  A. Robins. Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science, 7(2):123–146, 1995.
-  J. Schwarz, J. Luketina, W. M. Czarnecki, A. Grabska-Barwinska, Y. W. Teh, R. Pascanu, and R. Hadsell. Progress & compress: A scalable framework for continual learning. arXiv preprint arXiv:1805.06370, 2018.
-  H. Shin, J. K. Lee, J. Kim, and J. Kim. Continual learning with deep generative replay. In Advances in Neural Information Processing Systems, pages 2990–2999, 2017.
-  J. Snell, K. Swersky, and R. Zemel. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pages 4077–4087, 2017.
-  S. Thrun and T. M. Mitchell. Lifelong robot learning. Robotics and autonomous systems, 15(1-2):25–46, 1995.
-  G. M. van de Ven and A. S. Tolias. Generative replay with feedback connections as a general strategy for continual learning. arXiv preprint arXiv:1809.10635, 2018.
-  L. Van Der Maaten. Accelerating t-sne using tree-based algorithms. The Journal of Machine Learning Research, 15(1):3221–3245, 2014.
-  J. Wen, Y. Cao, and R. Huang. Few-shot self reminder to overcome catastrophic forgetting. arXiv preprint arXiv:1812.00543, 2018.
-  F. Zenke, B. Poole, and S. Ganguli. Continual learning through synaptic intelligence. In Proceedings of the 34th International Conference on Machine Learning, pages 3987–3995. JMLR.org, 2017.