A Theoretical Study on Solving Continual Learning

Continual learning (CL) learns a sequence of tasks incrementally. There are two popular CL settings, class incremental learning (CIL) and task incremental learning (TIL). A major challenge of CL is catastrophic forgetting (CF). While a number of techniques are already available to effectively overcome CF for TIL, CIL remains highly challenging. So far, little theoretical study has been done to provide principled guidance on how to solve the CIL problem. This paper performs such a study. It first shows that, probabilistically, the CIL problem can be decomposed into two sub-problems: Within-task Prediction (WP) and Task-id Prediction (TP). It further proves that TP is correlated with out-of-distribution (OOD) detection, which connects CIL and OOD detection. The key conclusion of this study is that regardless of whether WP and TP or OOD detection are defined explicitly or implicitly by a CIL algorithm, good WP and good TP or OOD detection are necessary and sufficient for good CIL performance. Additionally, TIL is simply WP. Based on the theoretical result, new CIL methods are also designed, which outperform strong baselines in both CIL and TIL settings by a large margin.


1 Introduction

Continual learning aims to incrementally learn a sequence of tasks (chen2018lifelong). Each task consists of a set of classes to be learned together. A major challenge of CL is catastrophic forgetting (CF). Although a large number of CL techniques have been proposed, they are mainly empirical. Limited theoretical research has been done on how to solve CL. This paper performs such a theoretical study of the necessary and sufficient conditions for effective CL. There are two main CL settings that have been extensively studied: class incremental learning (CIL) and task incremental learning (TIL) (van2019three).

In CIL, the learning process builds a single classifier for all tasks/classes learned so far. In testing, a test instance from any class may be presented for the model to classify. No prior task information (e.g., task-id) of the test instance is provided. Formally, CIL is defined as follows.

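Class incremental learning (CIL). CIL learns a sequence of tasks, $1, 2, \dots, T$. Each task $t$ has a training dataset $\mathcal{D}_t = \{(x_t^i, y_t^i)_{i=1}^{n_t}\}$, where $n_t$ is the number of data samples in task $t$, and $x_t^i \in \mathbf{X}$ is an input sample and $y_t^i \in \mathbf{Y}_t$ (the set of all classes of task $t$) is its class label. All $\mathbf{Y}_t$'s are disjoint ($\mathbf{Y}_t \cap \mathbf{Y}_{t'} = \emptyset$ for all $t \neq t'$) and $\bigcup_{t=1}^{T} \mathbf{Y}_t = \mathbf{Y}$. The goal of CIL is to construct a single predictive function or classifier $f: \mathbf{X} \rightarrow \mathbf{Y}$ that can identify the class label $y$ of each given test instance $x$.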

In the TIL setup, each task is a separate classification problem. For example, one task could be to classify different breeds of dogs and another task could be to classify different types of animals (the tasks may not be disjoint). One model is built for each task in a shared network. In testing, the task-id of each test instance is provided and the system uses only the specific model for the task (dog or animal classification) to classify the test instance. Formally, TIL is defined as follows.

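Task incremental learning (TIL). TIL learns a sequence of tasks, $1, 2, \dots, T$. Each task $t$ has a training dataset $\mathcal{D}_t = \{((x_t^i, t), y_t^i)_{i=1}^{n_t}\}$, where $n_t$ is the number of data samples in task $t \in \mathbf{T} = \{1, 2, \dots, T\}$, and $x_t^i \in \mathbf{X}$ is an input sample and $y_t^i \in \mathbf{Y}_t \subset \mathbf{Y}$ is its class label. The goal of TIL is to construct a predictor $f: \mathbf{X} \times \mathbf{T} \rightarrow \mathbf{Y}$ to identify the class label $y \in \mathbf{Y}_t$ for $(x, t)$ (the given test instance $x$ from task $t$).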

Several techniques are available to effectively overcome CF for TIL (with almost no CF) Serra2018overcoming; supsup2020. However, CIL remains highly challenging due to the additional problem of Inter-task Class Separation (ICS), i.e., establishing decision boundaries between the classes of the new task and the classes of the previous tasks, because the previous task data are not accessible. Before discussing the proposed work, we recall the closed-world assumption made by traditional machine learning, i.e., the classes seen in testing must have been seen in training chen2018lifelong; liu2021self. However, in many applications, there are unknowns in testing, which is called the open world setting chen2018lifelong; liu2021self. In open world learning, the training (or known) classes are called in-distribution (IND) classes. A classifier built for the open world can (1) classify test instances of training/IND classes to their respective classes, which is called IND prediction, and (2) detect test instances that do not belong to any of the IND/known classes but to some unknown or out-of-distribution (OOD) classes, which is called OOD detection. Many OOD detection algorithms can perform both IND prediction and OOD detection tack2020csi; liang2018enhancing; esmaeilpour2022zero; wang2022omg. What OOD detection and CL have in common is that both need to consider future unknowns.

This paper conducts a theoretical study of CIL, which is applicable to any CIL classification model. Instead of focusing on the traditional PAC generalization bound pentina2014pac; karakida2022learning, we focus on how to solve the CIL problem. We first decompose the CIL problem into two sub-problems in a probabilistic framework: Within-task Prediction (WP) and Task-id Prediction (TP). WP means that the prediction for a test instance is only done within the classes of the task to which the test instance belongs, which is basically the TIL problem. TP predicts the task-id. TP is needed because in CIL, the task-id is not provided in testing. Using the popular cross-entropy loss, this paper then proves that (1) the CIL performance is bounded by the WP and TP performances, and (2) TP and task OOD detection performance bound each other (which connects CL and OOD detection). The key result is that regardless of whether WP and TP or OOD detection are defined explicitly or implicitly by a CIL algorithm, good WP and good TP or OOD detection are necessary and sufficient conditions for good CIL performance. This result is applicable to both batch/offline and online CIL and to CIL problems with blurry task boundaries. The intuition is also quite simple: if a CIL model is perfect at detecting OOD samples for each task (which solves the ICS problem), then CIL is reduced to WP.

This theoretical result provides principled guidance for solving the CIL problem, i.e., it helps design better CIL algorithms that can achieve strong WP and TP performances. Since WP is basically IND prediction for each task and most OOD techniques perform both IND prediction and OOD detection, a strong OOD performance for each task is necessary to achieve good CIL accuracy.

Based on the theoretical guidance, several new CIL methods are designed, including techniques based on the integration of a TIL method and an OOD detection method for CIL, which outperform strong baselines in both the CIL and TIL settings by a large margin. This combination is particularly attractive because TIL has achieved no forgetting, and we only need a strong OOD technique that can perform both IND prediction and OOD detection to learn each task to achieve strong CIL results.

2 Related Work

Although numerous CL techniques have been proposed, little study has been done to provide theoretical guidance on how to solve the problem. Existing approaches mainly belong to several categories. Using regularization kirkpatrick2017overcoming; Li2016LwF to minimize changes to model parameters learned from previous tasks is a popular approach Jung2016less; Camoriano2017incremental; zenke2017continual; ritter2018online; schwarz2018progress; xu2018reinforced; castro2018end; hu2019overcoming; Dhar2019CVPR; lee2019overcoming; ahn2019neurIPS; Liu2020; Zhu_2021_CVPR_pass. Memorizing some old examples and using them to jointly train the new task is another popular approach (called replay) Rusu2016; Lopez2017gradient; Rebuffi2017; Chaudhry2019ICLR; hou2019learning; wu2019large; rolnick2019neurIPS; NEURIPS2020_b704ea2c_derpp; rajasegaran2020adaptive; Liu2020AANets; Cha_2021_ICCV_co2l; yan2021dynamically; wang2022memory; guo2022online. Some systems learn to generate pseudo training data of old tasks and use them to jointly train the new task, called pseudo-replay Gepperth2016bio; Kamra2017deep; Shin2017continual; wu2018memory; Seff2017continual; Kemker2018fearnet; hu2019overcoming; Rostami2019ijcai; ostapenko2019learning. Orthogonal projection learns each task in a space orthogonal to those of other tasks zeng2019continuous; guo2022adaptive; chaudhry2020continual. Our theoretical study is applicable to any continually trained classification model.

Parameter isolation is yet another popular approach, which makes different subsets (which may overlap) of the network parameters dedicated to different tasks using masks Serra2018overcoming; ke2020continual; Mallya2017packnet; supsup2020; NEURIPS2019_3b220b43compact. This approach is particularly suited for TIL. Several methods have almost completely overcome forgetting. HAT Serra2018overcoming and CAT ke2020continual protect previous tasks by masking the important parameters to those tasks. PackNet Mallya2017packnet, CPG NEURIPS2019_3b220b43compact and SupSup supsup2020 find an isolated sub-network for each task. HyperNet von2019continual initializes task-specific parameters conditioned on task-id. ADP Yoon2020Scalable decomposes parameters into shared and adaptive parts to construct an order robust TIL system. CCLL singh2020calibrating uses task-adaptive calibration in convolution layers. Our methods designed based on the proposed theory make use of two parameter isolation-based TIL methods and two OOD detection methods. A strong OOD detection method CSI in tack2020csi helps produce very strong CIL results. CSI is based on data augmentation he2020momentum and contrastive learning chen2020simple. Excellent surveys of OOD detection include bulusu2020anomalous; geng2020recent.

Some methods have used a TIL method for CIL with an additional task-id prediction technique. iTAML rajasegaran2020itaml requires each test batch to be from a single task. This is not practical, as test samples usually come one by one. CCG abati2020conditional builds a separate network to predict the task-id. Expert Gate Aljundi2016expert constructs a separate autoencoder for each task. HyperNet von2019continual and PR-Ent henning2021posterior use entropy to predict the task-id. Since none of these papers is a theoretical study, they did not know that strong OOD detection is the key. Our methods based on OOD detection perform dramatically better.

Several theoretical studies have been conducted on lifelong/continual learning. However, they focus on traditional generalization bounds. pentina2014pac proposes a PAC-Bayesian framework that bounds the expected error on future tasks by the average loss on the observed tasks. The work in lee2021continual studies the generalization error in terms of task similarity, and karakida2022learning studies the dependence of the generalization error on sample size or number of tasks, including forward and backward transfer. bennani2020generalisation shows that orthogonal gradient descent gives a tighter generalization bound than SGD. Our work is very different, as we focus on how to solve the CIL problem, which is orthogonal to the existing theoretical analyses.

3 CIL by Within-Task Prediction and Task-ID Prediction

This section presents our theory. It first shows that the CIL performance improves if the within-task prediction (WP) performance and/or the task-id prediction (TP) performance improve, and then shows that TP and OOD detection bound each other, which indicates that CIL performance is controlled by WP and OOD detection. This connects CL and OOD detection. Finally, we study the necessary conditions for a good CIL model, which include a good WP and a good TP (or OOD detection).

3.1 CIL Problem Decomposition

This sub-section first presents the assumptions made by CIL based on its definition and then proposes a decomposition of the CIL problem into two sub-problems. A CL system learns a sequence of tasks $\{(\mathbf{X}_k, \mathbf{Y}_k)\}_{k=1,\dots,T}$, where $\mathbf{X}_k$ is the domain of task $k$ and $\mathbf{Y}_k$ are the classes of task $k$, written $\mathbf{Y}_k = \{\mathbf{Y}_{k,j}\}$, where $j$ indicates the $j$th class in task $k$. Let $\mathbf{X}_{k,j}$ be the domain of the $j$th class of task $k$, where $\mathbf{X}_k = \bigcup_j \mathbf{X}_{k,j}$. For accuracy, we will use $x \in \mathbf{X}_{k,j}$ instead of $\mathbf{Y}_{k,j}$ in probabilistic analysis. Based on the definition of class incremental learning (CIL) (Sec. 1), the following assumptions are implied:

Assumption 1.

The domains of classes of the same task are disjoint, i.e., $\mathbf{X}_{k,j} \cap \mathbf{X}_{k,j'} = \emptyset$ for all $j \neq j'$.

Assumption 2.

The domains of tasks are disjoint, i.e., $\mathbf{X}_k \cap \mathbf{X}_{k'} = \emptyset$ for all $k \neq k'$.

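For any ground event $D$, the goal of a CIL problem is to learn $\mathbf{P}(x \in \mathbf{X}_{k,j} \mid D)$. This can be decomposed into two probabilities, the within-task IND prediction (WP) probability and the task-id prediction (TP) probability. The WP probability is $\mathbf{P}(x \in \mathbf{X}_{k,j} \mid x \in \mathbf{X}_k, D)$ and the TP probability is $\mathbf{P}(x \in \mathbf{X}_k \mid D)$. We can rewrite the CIL problem using WP and TP based on the two assumptions:

$$\mathbf{P}(x \in \mathbf{X}_{k_0,j_0} \mid D) = \sum_{k=1}^{T} \mathbf{P}(x \in \mathbf{X}_{k_0,j_0} \mid x \in \mathbf{X}_k, D)\,\mathbf{P}(x \in \mathbf{X}_k \mid D) \tag{1}$$

$$\mathbf{P}(x \in \mathbf{X}_{k_0,j_0} \mid D) = \mathbf{P}(x \in \mathbf{X}_{k_0,j_0} \mid x \in \mathbf{X}_{k_0}, D)\,\mathbf{P}(x \in \mathbf{X}_{k_0} \mid D) \tag{2}$$

where $k_0$ means a particular task and $j_0$ a particular class in the task.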
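A small numerical sketch of Eq. 2 follows. It assumes the WP and TP probabilities are already available as plain lists (here produced by softmax over hypothetical logits); it only illustrates how the two factors combine into a CIL distribution and is not the paper's implementation.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]

def cil_from_wp_tp(wp, tp):
    """Eq. 2: P(x in X_{k,j} | D) = wp[k][j] * tp[k], where
    wp[k][j] = P(x in X_{k,j} | x in X_k, D) and tp[k] = P(x in X_k | D)."""
    return {(k, j): wp[k][j] * tp[k]
            for k in range(len(tp)) for j in range(len(wp[k]))}

# Hypothetical outputs for T = 3 tasks with 2, 3 and 2 classes.
wp = [softmax([2.0, 0.5]), softmax([0.1, 1.2, -0.3]), softmax([0.0, 0.0])]
tp = softmax([0.3, 2.1, -1.0])
cil = cil_from_wp_tp(wp, tp)

print(sum(cil.values()))      # 1 up to rounding: a valid distribution over all classes
print(max(cil, key=cil.get))  # predicted (task, class) pair
```

Because each WP row and the TP vector each sum to 1, their products form a proper distribution over all classes seen so far, which is what lets a CIL prediction be read off from WP and TP.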

Some remarks are in order about Eq. 2 and our subsequent analysis to set the stage.

Remark 1.

Eq. 2 shows that if we can improve either the WP or TP performance, or both, we can improve the CIL performance.

Remark 2.

It is important to note that our theory is not concerned with the learning algorithm or the training process, but we will propose some concrete learning algorithms based on the theoretical result in the experiment section.

Remark 3.

Note that the CIL definition and the subsequent analysis are applicable to tasks with any number of classes (including only one class per task) and to online CIL, where the training data for each task or class comes gradually in a data stream and may also cross task boundaries (blurry tasks bang2021rainbow), because our analysis is based on an already-built CIL model after training. Regarding blurry task boundaries, suppose dataset 1 has classes {dog, cat, tiger} and dataset 2 has classes {dog, computer, car}. We can define task 1 as {dog, cat, tiger} and task 2 as {computer, car}. The shared class dog in dataset 2 is simply additional training data for dog that appears after task 1.
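The task assignment in the blurry-boundary example can be sketched as a first-appearance rule. The helper name `assign_tasks` is hypothetical; it simply assigns each class to the first dataset in which it appears, so later occurrences become extra training data for the existing task.

```python
def assign_tasks(datasets):
    """Map each class to the first task (dataset index) in which it appears.

    A class that reappears in a later dataset is not added to the later task;
    its new samples are treated as extra training data for its original task."""
    tasks, seen = [], set()
    for classes in datasets:
        new = [c for c in classes if c not in seen]
        seen.update(new)
        tasks.append(new)
    return tasks

datasets = [["dog", "cat", "tiger"], ["dog", "computer", "car"]]
print(assign_tasks(datasets))  # [['dog', 'cat', 'tiger'], ['computer', 'car']]
```

The resulting tasks have disjoint class sets, so Assumptions 1 and 2 still hold even though the data stream itself has blurry boundaries.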

Remark 4.

Furthermore, CIL = WP * TP in Eq. 2 means that when we have WP and TP (defined either explicitly or implicitly by implementation), we can find a corresponding CIL model defined by WP * TP. Similarly, when we have a CIL model, we can find the corresponding underlying WP and TP defined by their probabilistic definitions.

In the following sub-sections, we develop this further concretely to derive the sufficient and necessary conditions for solving the CIL problem in the context of cross-entropy loss as it is used in almost all supervised CL systems.

3.2 CIL Improves as WP and/or TP Improve

As stated in Remark 2 above, the study here is based on a trained CIL model and is not concerned with the algorithm used to train the model. We use cross-entropy as the performance measure of a trained model, as it is the most popular loss function used in supervised CL. For experimental evaluation, we use accuracy, following CL papers. Denote the cross-entropy of two probability distributions $p$ and $q$ as

$$H(p, q) \overset{\text{def}}{=} -\mathbb{E}_p[\log q] = -\sum_i p_i \log q_i \tag{3}$$

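For any $x \in \mathbf{X}$, let $y$ be the CIL ground-truth label of $x$, where $y_{k_0,j_0} = 1$ if $x \in \mathbf{X}_{k_0,j_0}$, otherwise $y_{k,j} = 0$ for all $(k,j) \neq (k_0,j_0)$. Let $\tilde{y}$ be the WP ground-truth label of $x$, where $\tilde{y}_{k_0,j_0} = 1$ if $x \in \mathbf{X}_{k_0,j_0}$, otherwise $\tilde{y}_{k_0,j} = 0$ for all $j \neq j_0$. Let $\bar{y}$ be the TP ground-truth label of $x$, where $\bar{y}_{k_0} = 1$ if $x \in \mathbf{X}_{k_0}$, otherwise $\bar{y}_k = 0$ for all $k \neq k_0$. Denote

$$H_{WP}(x) = H\big(\tilde{y}, \{\mathbf{P}(x \in \mathbf{X}_{k_0,j} \mid x \in \mathbf{X}_{k_0}, D)\}_j\big) \tag{4}$$

$$H_{CIL}(x) = H\big(y, \{\mathbf{P}(x \in \mathbf{X}_{k,j} \mid D)\}_{k,j}\big) \tag{5}$$

$$H_{TP}(x) = H\big(\bar{y}, \{\mathbf{P}(x \in \mathbf{X}_k \mid D)\}_k\big) \tag{6}$$

where $H_{WP}$, $H_{CIL}$ and $H_{TP}$ are the cross-entropy values of WP, CIL and TP, respectively. We now present our first theorem. The theorem connects CIL to WP and TP, and suggests that with a good WP or TP, the CIL performance improves as the upper bound of the CIL loss decreases.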

Theorem 1.

If $H_{TP}(x) \le \delta$ and $H_{WP}(x) \le \epsilon$, we have $H_{CIL}(x) \le \epsilon + \delta$.

The detailed proof is given in Appendix A.1. This theorem holds regardless of whether WP and TP are trained together or separately. When they are trained separately, if WP is fixed and we let $\delta \to 0$, the bound becomes $H_{CIL}(x) \le \epsilon$, which means that if TP is better, CIL is better. Similarly, if TP is fixed and we let $\epsilon \to 0$, the bound becomes $H_{CIL}(x) \le \delta$. When they are trained concurrently, there exists a functional relationship between $\epsilon$ and $\delta$ that depends on the implementation. But no matter what it is, when $\epsilon + \delta$ decreases, CIL gets better.
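The following sketch checks Theorem 1 numerically on a hypothetical two-task model. With one-hot labels, the cross-entropy of Eq. 3 reduces to $-\log$ of the probability of the true label, so when the CIL probabilities are formed by Eq. 2 the CIL loss splits exactly into the WP and TP losses, which is why the bound $\epsilon + \delta$ holds. The probabilities below are made up for illustration.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i log q_i (Eq. 3); terms with p_i = 0 are skipped."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical trained model on T = 2 tasks with 2 classes each;
# the test instance x belongs to task k0 = 1, class j0 = 0.
wp = [[0.7, 0.3], [0.9, 0.1]]  # wp[k][j] = P(x in X_{k,j} | x in X_k, D)
tp = [0.2, 0.8]                # tp[k]    = P(x in X_k | D)
k0, j0 = 1, 0

h_wp = cross_entropy([1.0 if j == j0 else 0.0 for j in range(2)], wp[k0])
h_tp = cross_entropy([1.0 if k == k0 else 0.0 for k in range(2)], tp)
cil = [wp[k][j] * tp[k] for k in range(2) for j in range(2)]  # Eq. 2, flattened
y = [1.0 if (k, j) == (k0, j0) else 0.0 for k in range(2) for j in range(2)]
h_cil = cross_entropy(y, cil)

print(h_cil, h_wp + h_tp)  # equal up to floating-point rounding
```

Lowering either `h_wp` (better WP) or `h_tp` (better TP) lowers `h_cil`, mirroring the discussion above.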

Theorem 1 holds for any $x \in \mathbf{X}$ that satisfies $H_{TP}(x) \le \delta$ and $H_{WP}(x) \le \epsilon$. To measure the overall performance in expectation, we present the following corollary.

Corollary 1.

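Let $U(\mathbf{X})$ represent the uniform distribution on $\mathbf{X}$. i) If $\mathbb{E}_{x \sim U(\mathbf{X})}[H_{TP}(x)] \le \delta$, then $\mathbb{E}_{x \sim U(\mathbf{X})}[H_{CIL}(x)] \le \mathbb{E}_{x \sim U(\mathbf{X})}[H_{WP}(x)] + \delta$. Similarly, ii) if $\mathbb{E}_{x \sim U(\mathbf{X})}[H_{WP}(x)] \le \epsilon$, then $\mathbb{E}_{x \sim U(\mathbf{X})}[H_{CIL}(x)] \le \epsilon + \mathbb{E}_{x \sim U(\mathbf{X})}[H_{TP}(x)]$.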

The proof is given in Appendix A.2. The corollary is a direct extension of Theorem 1 in expectation. The implication is that given TP performance, CIL is positively related to WP. The better the WP is, the better the CIL is as the upper bound of the CIL loss decreases. Similarly, given WP performance, a better TP performance results in a better CIL performance. Due to the positive relation, we can improve CIL by improving either WP or TP using their respective methods developed in each area.

3.3 Task Prediction (TP) to OOD Detection

Building on Eq. 2, we have studied the relationship of CIL, WP and TP in Theorem 1. We now connect TP and OOD detection. They are shown to dominate each other up to a constant factor.

We again use cross-entropy to measure the performance of TP and OOD detection of a trained network, as in Sec. 3.2. To build the connection between $H_{TP}(x)$ and the OOD detection of each task, we first define the notation for OOD detection. We use $\mathbf{P}'_k(x \in \mathbf{X}_k \mid D)$ to represent the probability distribution predicted by the $k$th task's OOD detector. Notice that the task prediction (TP) probability distribution $\mathbf{P}(x \in \mathbf{X}_k \mid D)$ is a categorical distribution over the $T$ tasks, while the OOD detection probability distribution $\mathbf{P}'_k(x \in \mathbf{X}_k \mid D)$ is a Bernoulli distribution. For any $x \in \mathbf{X}$, define

$$H_{OOD,k}(x) = \begin{cases} -\log \mathbf{P}'_k(x \in \mathbf{X}_k \mid D), & x \in \mathbf{X}_k, \\ -\log \mathbf{P}'_k(x \notin \mathbf{X}_k \mid D), & x \notin \mathbf{X}_k. \end{cases} \tag{7}$$

In CIL, the OOD detection probability for a task can be defined using the output values corresponding to the classes of the task. Examples of such a function are the sigmoid of the maximum logit value and the maximum softmax probability after rescaling it to [0, 1]. It is also possible to define the OOD detector directly as a function of the task instead of a function of the output values of the task's classes, e.g., using the Mahalanobis distance. The following theorem shows that TP and OOD detection bound each other.
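Before the theorem, the two example functions mentioned above can be sketched as follows. The function names are hypothetical, and the linear rescaling of the maximum softmax probability from $[1/C, 1]$ to $[0, 1]$ (with $C$ the number of classes in the task) is our assumption about what "re-scaling to 0 to 1" means.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def ood_prob_max_logit(task_logits):
    """P'_k(x in X_k | D) as the sigmoid of task k's maximum logit."""
    return sigmoid(max(task_logits))

def ood_prob_msp(task_logits):
    """P'_k(x in X_k | D) as task k's maximum softmax probability (MSP),
    linearly rescaled from [1/C, 1] to [0, 1] (assumed rescaling)."""
    c = len(task_logits)
    if c < 2:
        raise ValueError("MSP is uninformative for a single-class task")
    m = max(task_logits)
    exps = [math.exp(v - m) for v in task_logits]
    msp = max(exps) / sum(exps)
    return (msp - 1.0 / c) / (1.0 - 1.0 / c)

task_logits = [3.2, -0.5, 0.1]  # hypothetical outputs of one task's classes
print(ood_prob_max_logit(task_logits), ood_prob_msp(task_logits))
```

Both functions only look at the outputs of task $k$'s own classes, so each task's OOD detector can be evaluated without access to the other tasks' data.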

Theorem 2.

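i) If $H_{TP}(x) \le \delta$, let $\mathbf{P}'_k(x \in \mathbf{X}_k \mid D) = \mathbf{P}(x \in \mathbf{X}_k \mid D)$; then $H_{OOD,k}(x) \le \delta$ for all $k = 1, \dots, T$. ii) If $H_{OOD,k}(x) \le \delta_k$ for $k = 1, \dots, T$, let $\mathbf{P}(x \in \mathbf{X}_k \mid D) = \frac{\mathbf{P}'_k(x \in \mathbf{X}_k \mid D)}{\sum_{k'} \mathbf{P}'_{k'}(x \in \mathbf{X}_{k'} \mid D)}$; then $H_{TP}(x) \le \big(\sum_k \mathbf{1}_{x \in \mathbf{X}_k} e^{\delta_k}\big)\big(\sum_k (1 - e^{-\delta_k})\big)$, where $\mathbf{1}_{x \in \mathbf{X}_k}$ is an indicator function.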

See Appendix A.3 for the proof. As we use cross-entropy, the lower the bound, the better the performance. The first statement (i) says that the OOD detection performance improves if the TP performance gets better (i.e., lower $\delta$). Similarly, the second statement (ii) says that the TP performance improves if the OOD detection performance on each task improves (i.e., lower $\delta_k$). Besides, since $\big(\sum_k \mathbf{1}_{x \in \mathbf{X}_k} e^{\delta_k}\big)\big(\sum_k (1 - e^{-\delta_k})\big)$ converges to $\sum_k \delta_k$ (to first order) as the $\delta_k$'s converge to 0, we further know that $H_{TP}$ and $\sum_k H_{OOD,k}$ are equivalent in quantity up to a constant factor.
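Both parts of Theorem 2 can be checked numerically. The sketch below draws random TP distributions and random per-task OOD probabilities for a hypothetical $T = 4$ and asserts the two bounds; in part (ii) each $\delta_k$ is set to the detector's actual loss $H_{OOD,k}(x)$, the tightest value the hypothesis allows. This is a sanity check, not a proof.

```python
import math
import random

def h_tp(tp, k0):
    """Eq. 6 for a one-hot TP label: -log P(x in X_{k0} | D)."""
    return -math.log(tp[k0])

def h_ood(p_in, k, k0):
    """Eq. 7: -log P'_k(x in X_k) if x is in task k, else -log(1 - P'_k(x in X_k))."""
    return -math.log(p_in[k]) if k == k0 else -math.log(1.0 - p_in[k])

random.seed(0)
T = 4
for _ in range(1000):
    k0 = random.randrange(T)  # true task of the test instance
    # Part (i): use the TP probabilities themselves as the OOD detectors.
    raw = [random.random() + 1e-3 for _ in range(T)]
    tp = [r / sum(raw) for r in raw]
    delta = h_tp(tp, k0)
    assert all(h_ood(tp, k, k0) <= delta + 1e-12 for k in range(T))
    # Part (ii): normalize independent per-task OOD probabilities into TP.
    p_in = [random.uniform(0.01, 0.99) for _ in range(T)]
    deltas = [h_ood(p_in, k, k0) for k in range(T)]
    tp2 = [p / sum(p_in) for p in p_in]
    bound = math.exp(deltas[k0]) * sum(1.0 - math.exp(-d) for d in deltas)
    assert h_tp(tp2, k0) <= bound + 1e-12
print("Theorem 2 bounds held on all random trials")
```

In part (ii), the indicator sum picks out $e^{\delta_{k_0}}$ for the true task $k_0$, which is why the code multiplies by `math.exp(deltas[k0])`.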

In Theorem 1, we studied how CIL is related to WP and TP. In Theorem 2, we showed that TP and OOD bound each other. Now we explicitly give the upper bound of CIL in relation to WP and OOD detection of each task. The detailed proof can be found in Appendix A.4.

Theorem 3.

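If $H_{OOD,k}(x) \le \delta_k$ for $k = 1, \dots, T$ and $H_{WP}(x) \le \epsilon$, we have

$$H_{CIL}(x) \le \epsilon + \Big(\sum_k \mathbf{1}_{x \in \mathbf{X}_k} e^{\delta_k}\Big)\Big(\sum_k \big(1 - e^{-\delta_k}\big)\Big),$$

where $\mathbf{1}_{x \in \mathbf{X}_k}$ is an indicator function.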
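Theorem 3 can be read as chaining the two earlier results. The derivation below is our sketch of that chaining, not a reproduction of the proof in Appendix A.4; it assumes TP is obtained by normalizing the per-task OOD probabilities as in Theorem 2 (ii).

```latex
\begin{aligned}
H_{TP}(x) &\le \Big(\sum_k \mathbf{1}_{x\in\mathbf{X}_k} e^{\delta_k}\Big)\Big(\sum_k \big(1-e^{-\delta_k}\big)\Big) =: \delta
  && \text{Theorem 2 (ii), given } H_{OOD,k}(x) \le \delta_k \\
H_{CIL}(x) &\le \epsilon + \delta
  && \text{Theorem 1, given } H_{WP}(x) \le \epsilon \\
&\approx \epsilon + \sum_k \delta_k
  && \text{to first order, as all } \delta_k \to 0
\end{aligned}
```

The last line shows that, for small OOD losses, the CIL loss is controlled by the WP loss plus the sum of the per-task OOD losses.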

3.4 Necessary Conditions for Improving CIL

In Theorem 1, we showed that good WP and TP performances are sufficient to guarantee good CIL performance. In Theorem 3, we showed that good WP and OOD detection performances are sufficient to guarantee good CIL performance. For completeness, we study the necessary conditions for a well-performing CIL model in this sub-section.

Theorem 4.

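If $H_{CIL}(x) \le \eta$, then there exist i) a WP, s.t. $H_{WP}(x) \le \eta$, ii) a TP, s.t. $H_{TP}(x) \le \eta$, and iii) an OOD detector for each task, s.t. $H_{OOD,k}(x) \le \eta$ for $k = 1, \dots, T$.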

The detailed proof is given in Appendix A.5. This theorem says that if a good CIL model is trained, then a good WP, a good TP and a good OOD detector for each task are always implied. More importantly, by transforming Theorem 4 into its contrapositive, we have the following statements: If for any WP, $H_{WP}(x) > \eta$, then $H_{CIL}(x) > \eta$. If for any TP, $H_{TP}(x) > \eta$, then $H_{CIL}(x) > \eta$. If for any choice of OOD detectors, $H_{OOD,k}(x) > \eta$ for some $k$, then $H_{CIL}(x) > \eta$. Regardless of whether WP and TP (or OOD detection) are defined explicitly or implicitly by a CIL algorithm, the existence of a good WP and the existence of a good TP or OOD detection are necessary conditions for good CIL performance.
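One natural way to obtain the WP, TP and OOD detectors of Theorem 4 from a trained CIL model is to marginalize and condition its output: TP sums the CIL probabilities over each task's classes, WP divides by that sum, and the $k$th OOD detector reuses TP's probability for task $k$. This construction is our assumption for illustration (the proof in Appendix A.5 may use a different one), and the CIL output below is hypothetical.

```python
import math

def decompose(cil):
    """From CIL probabilities cil[k][j] = P(x in X_{k,j} | D), build
    TP (marginal over each task's classes), WP (conditional within a task),
    and per-task OOD probabilities P'_k(x in X_k | D) = TP[k].
    Assumes every task has nonzero total probability."""
    tp = [sum(row) for row in cil]
    wp = [[p / tp[k] for p in row] for k, row in enumerate(cil)]
    ood = list(tp)
    return wp, tp, ood

# Hypothetical CIL output for T = 3 tasks; x truly belongs to (k0, j0) = (2, 1).
cil = [[0.02, 0.03], [0.05, 0.05], [0.15, 0.70]]
k0, j0 = 2, 1
wp, tp, ood = decompose(cil)

eta = -math.log(cil[k0][j0])  # H_CIL(x)
h_wp = -math.log(wp[k0][j0])
h_tp = -math.log(tp[k0])
h_ood = [-math.log(ood[k]) if k == k0 else -math.log(1.0 - ood[k]) for k in range(3)]
assert h_wp <= eta and h_tp <= eta and all(h <= eta for h in h_ood)
print(eta, h_wp, h_tp, h_ood)
```

Each inequality holds because `wp[k0][j0]` and `tp[k0]` are at least `cil[k0][j0]`, and for $k \neq k_0$ the quantity `1 - ood[k]` is at least `tp[k0]`.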

Remark 5.

It is important to note again that our study in this section is based on a CIL model that has already been built. In other words, our study tells the CIL designers what should be achieved in the final model. Clearly, one would also like to know how to design a strong CIL model based on the theoretical results, which also considers catastrophic forgetting (CF). One effective method is to make use of a strong existing TIL algorithm, which can already achieve no or little forgetting (CF), and combine it with a strong OOD detection algorithm (as mentioned earlier, most OOD detection methods can also perform WP). Thus, any improved method from the OOD detection community can be applied to CIL to produce improved CIL systems (see Sections 4.3 and 4.4).

Recall that in Section 2 we reviewed prior works that have tried to use a TIL method for CIL with a task-id prediction method von2019continual; Aljundi2016expert; rajasegaran2020itaml; abati2020conditional; henning2021posterior. However, since they did not know that the key to the success of this approach is a strong OOD detection algorithm, these methods are quite weak (see Section 4).

4 New CIL Techniques and Experiments

Based on Theorem 3, we have designed several new CIL methods, each of which integrates an existing CL algorithm and an OOD detection algorithm. The OOD detection algorithm that we use can perform both within-task IND prediction (WP) and OOD detection. Our experiments have two goals: (1) to show that a good OOD detection method can help improve the accuracy of an existing CIL algorithm, and (2) to fully compare two of these methods (see some others in Sec. 4.5) with strong baselines to show that they outperform the existing strong baselines considerably.

4.1 Datasets, CL Baselines and OOD Detection Methods

Datasets and CIL Tasks. Four popular benchmark image classification datasets are used, from which six CIL problems are created following recent papers Liu2020; NEURIPS2020_b704ea2c_derpp; Zhu_2021_CVPR_pass. (1) MNIST consists of handwritten images of 10 digits with 60,000/10,000 training/testing samples. We create a CIL problem (M-5T) of 5 tasks with 2 consecutive classes/digits as a task. (2) CIFAR-10 consists of 32x32 color images of 10 classes with 50,000/10,000 training/testing samples. We create a CIL problem (C10-5T) of 5 tasks with 2 consecutive classes as a task. (3) CIFAR-100 consists of 60,000 32x32 color images of 100 classes with 50,000/10,000 training/testing samples. We create two CIL problems by splitting 100 classes into 10 tasks (C100-10T) and 20 tasks (C100-20T), where each task has 10 and 5 classes, respectively. (4) Tiny-ImageNet has 120,000 64x64 color images of 200 classes with 500/50 images per class for training/testing. We create two CIL problems by splitting 200 classes into 5 tasks (T-5T) and 10 tasks (T-10T), where each task has 40 and 20 classes, respectively.
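The six task splits can be sketched as follows, assuming classes are taken in their id order (the text states "consecutive classes" for MNIST and CIFAR-10; the class ordering for CIFAR-100 and Tiny-ImageNet is an assumption here). The helper `split_classes` is hypothetical.

```python
def split_classes(num_classes, num_tasks):
    """Split class ids 0..num_classes-1 into num_tasks tasks of consecutive classes."""
    if num_classes % num_tasks != 0:
        raise ValueError("num_classes must be divisible by num_tasks")
    size = num_classes // num_tasks
    return [list(range(t * size, (t + 1) * size)) for t in range(num_tasks)]

benchmarks = {"M-5T": (10, 5), "C10-5T": (10, 5), "C100-10T": (100, 10),
              "C100-20T": (100, 20), "T-5T": (200, 5), "T-10T": (200, 10)}
for name, (n_cls, n_tasks) in benchmarks.items():
    tasks = split_classes(n_cls, n_tasks)
    print(name, len(tasks), "tasks,", len(tasks[0]), "classes per task")
```

Every split yields disjoint class sets whose union is the full label set, matching the CIL definition in Sec. 1.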

Baseline CL Methods. We include different families of CL methods: regularization, replay, orthogonal projection, and parameter isolation. MUC Liu2020 and PASS Zhu_2021_CVPR_pass are regularization-based methods. For replay methods, we use LwF Li2016LwF, iCaRL Rebuffi2017, Mnemonics Liu_2020_CVPR, BiC wu2019large, DER++ NEURIPS2020_b704ea2c_derpp, and Co²L Cha_2021_ICCV_co2l. For orthogonal projection, we use OWM zeng2019continuous. Finally, for parameter isolation, we use CCG abati2020conditional, HyperNet von2019continual, HAT Serra2018overcoming, SupSup supsup2020 (Sup), and PR henning2021posterior.² [Footnote 2: iTAML rajasegaran2020itaml is not included as it requires a batch of test data from the same task to predict the task-id. When each batch has only one test sample, which is our setting, it is very weak. For example, its CIL accuracy is only 33.5% on C100-10T. Expert Gate (EG) Aljundi2016expert is also very weak. Its CIL accuracy is only 43.2 on M-5T. They are much weaker than many baselines. DER yan2021dynamically is not included as it expands the network after each task, which is somewhat unfair to the other systems, none of which expands the network. DER can generate a large number of parameters after the last task, e.g., 117.6 million (M) for C100-20T, while our proposed methods require 44.6M (HAT+CSI) and 11.6M (Sup+CSI) (refer to Appendix H). The average accuracy of DER over the 6 CL experiments is 61.4, while our methods achieve 67.9 (HAT+CSI+c) and 64.9 (Sup+CSI+c) (refer to Tab. 3).] We use the official code for the baselines except for Co²L, CCG, and PR. For these three systems, we copy the results from their papers, as the code for CCG is not released and we are unable to run Co²L and PR on our machines.

OOD Detection Methods. Two OOD detection methods are used. We combine them with the above existing CL algorithms. Both these methods can also perform within-task IND prediction (WP).

(1). ODIN: Researchers have proposed several methods to improve the OOD detection performance of a trained network by post-processing liang2018enhancing; liu2020energy; lee2018simple_md. ODIN liang2018enhancing is a representative method. It adds a small perturbation to the input and applies temperature scaling to the softmax output of a trained network.

(2). CSI: It is a recently proposed and highly effective OOD detection technique tack2020csi. It is based on data and class label augmentation and supervised contrastive learning khosla2020supervised. Its rotation data augmentations create distributionally shifted samples that act as negative data for the original samples in contrastive learning. The details of CSI are given in Appendix D.
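The rotation-based shifting used by CSI can be sketched with NumPy. This covers only the augmentation and label bookkeeping, not CSI's contrastive training or ensembling; the joint label encoding `class * 4 + rotation` is an assumed way to give each (class, rotation) pair its own label, and `rotation_shift` is a hypothetical helper.

```python
import numpy as np

def rotation_shift(images, labels):
    """Rotate each image by 0, 90, 180 and 270 degrees.

    images: (N, H, W, C) array with H == W; labels: (N,) integer class labels.
    Returns rotated images (4N, H, W, C), rotation ids (4N,), and joint labels
    (4N,) = class * 4 + rotation (assumed encoding of each (class, rotation) pair)."""
    rotated = np.concatenate(
        [np.rot90(images, k=r, axes=(1, 2)) for r in range(4)], axis=0)
    rot_ids = np.repeat(np.arange(4), len(images))
    joint = np.tile(labels, 4) * 4 + rot_ids
    return rotated, rot_ids, joint

images = np.arange(2 * 3 * 3 * 1, dtype=np.float32).reshape(2, 3, 3, 1)
labels = np.array([0, 1])
rotated, rot_ids, joint = rotation_shift(images, labels)
print(rotated.shape, rot_ids.tolist(), joint.tolist())
```

Treating rotated copies as separate labels is what lets them act as shifted negatives for the unrotated originals.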

4.2 Training Details and Evaluation Metrics

Training Details. For the backbone structure, we follow supsup2020; Zhu_2021_CVPR_pass; NEURIPS2020_b704ea2c_derpp. An AlexNet-like architecture NIPS2012_c399862d_alexnet is used for MNIST and ResNet-18 he2016deep is used for CIFAR-10. For CIFAR-100 and Tiny-ImageNet, ResNet-18 is also used, as for CIFAR-10, but the number of channels is doubled to fit more classes. All the methods use the same backbone architecture except for OWM and HyperNet, for which we use their original architectures. OWM uses AlexNet, and it is not obvious how to apply the technique to the ResNet structure. HyperNet uses a fully-connected network for MNIST and ResNet-32 for the other datasets. We are unable to change the structure due to model initialization arguments that are unexplained in the original paper. For the replay methods, we use a memory buffer of 200 for MNIST and CIFAR-10 and 2000 for CIFAR-100 and Tiny-ImageNet, as in Rebuffi2017; NEURIPS2020_b704ea2c_derpp. We use the hyper-parameters suggested by the authors. If we could not reproduce a result, we use 10% of the training data as a validation set to grid-search for good hyper-parameters. For our proposed methods, we report the hyper-parameters in Appendix G. All the results are averages over 5 runs with random seeds.

Evaluation Metrics. (1). Average classification accuracy over all classes after learning the last task. The final class prediction depends on the prediction method (see below). We also report the forgetting rate in Appendix J.

(2). Average AUC (Area Under the ROC Curve) over all task models for the evaluation of OOD detection. AUC is the main measure used in OOD detection papers. Using this measure, we show that a better OOD detection method will result in better CIL performance. Let $AUC_t$ be the AUC score of task $t$. It is computed by using only the model (or classes) of task $t$ to score the test data of task $t$ as the in-distribution (IND) data and the test data from other tasks as the out-of-distribution (OOD) data. The average AUC score is $AUC = \frac{1}{n}\sum_{t=1}^{n} AUC_t$, where $n$ is the number of tasks.
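The average AUC computation described above can be sketched in plain Python. `auc` uses the rank (Mann-Whitney) form of the area under the ROC curve, and `average_auc` treats the test data of task $t$ as IND and all other test data as OOD when scored by task $t$'s model. Both helper names and the toy scores are hypothetical; the paper reports AUC as a percentage, i.e., these values times 100.

```python
def auc(ind_scores, ood_scores):
    """AUC as the probability that a random IND score exceeds a random OOD score
    (ties count 1/2); this equals the area under the ROC curve."""
    wins = 0.0
    for s_in in ind_scores:
        for s_out in ood_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(ind_scores) * len(ood_scores))

def average_auc(scores_by_task, task_of_sample):
    """scores_by_task[t][i]: score of test sample i under task t's model;
    task_of_sample[i]: true task of sample i. Returns (1/n) * sum_t AUC_t."""
    n = len(scores_by_task)
    total = 0.0
    for t in range(n):
        ind = [s for s, k in zip(scores_by_task[t], task_of_sample) if k == t]
        ood = [s for s, k in zip(scores_by_task[t], task_of_sample) if k != t]
        total += auc(ind, ood)
    return total / n

scores_by_task = [[0.9, 0.8, 0.3, 0.85],  # task 0 model: AUC_0 = 3/4
                  [0.1, 0.2, 0.7, 0.6]]   # task 1 model: AUC_1 = 1
task_of_sample = [0, 0, 1, 1]
print(average_auc(scores_by_task, task_of_sample))  # 0.875
```

For large test sets a sorting-based implementation is preferable to this quadratic loop; the pairwise form is used here only because it states the definition directly.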

It is not straightforward to change existing CL algorithms to include a new OOD detection method that needs training, e.g., CSI, except for TIL (task incremental learning) methods, e.g., HAT and Sup. For HAT and Sup, we can simply replace the method they use to learn each task with CSI (see Sec. 4.4).

Prediction Methods. Following the theoretical result in Sec. 3, we use Eq. 2 to perform the final prediction. The first probability (WP) in Eq. 2 is easy to obtain, as we can simply use the softmax values of the classes in each task. However, the second probability (TP) in Eq. 2 is tricky, as each task is learned without the data of other tasks. There are many options. We take the following approaches for prediction (which are special cases of Eq. 2, see below):

(1). For approaches that use a single classification head to include all classes learned so far, we predict as follows (which is also the approach taken by existing papers):

$$\hat{y} = \arg\max f(x) \tag{8}$$

where $f(x)$ is the logit output of the network.

(2). For multi-head methods (e.g., HAT, HyperNet, and Sup), which use one head for each task, we use the concatenated output as

$$\hat{y} = \arg\max \bigoplus_{1 \le k \le t} f(x)_k \tag{9}$$

where $\bigoplus$ indicates concatenation and $f(x)_k$ is the output of task $k$.³ [Footnote 3: The Sup paper proposed a one-shot task-id prediction that assumes the test instances come in a batch and all belong to the same task, like iTAML. We assume a single test instance per batch. Its task-id prediction results in an accuracy of 50.2 on C10-5T, which is much lower than the 62.6 obtained using Eq. 9. The task-id prediction of HyperNet also works poorly. The accuracy with its task-id prediction is 49.34 on C10-5T, while it is 53.4 using Eq. 9. PR uses entropy to find the task-id. Among the many variations of PR, we use the variations that perform best for each dataset with exemplar-free training and a single sample per batch at testing (i.e., no PR-BW).]

These methods (in fact, they are the same method used in two different settings) are a special case of Eq. 2 if we define as , where is the sigmoid. Hence, the theoretical results in Sec. 3 are still applicable. We present a detailed explanation of this prediction method and some other options in Appendix C. These two approaches work quite well.
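The two prediction rules (Eqs. 8 and 9) can be sketched with NumPy as below. The helper names and example outputs are hypothetical, and the sketch assumes the task heads produce directly comparable values (e.g., logits); it also maps the global argmax of Eq. 9 back to a (task, within-task class) pair.

```python
import numpy as np

def predict_single_head(logits):
    """Eq. 8: argmax over one head that covers all classes learned so far."""
    return int(np.argmax(logits))

def predict_multi_head(task_outputs):
    """Eq. 9: concatenate the per-task head outputs and take the argmax.
    Returns (task id, class index within the task, global index)."""
    concat = np.concatenate(task_outputs)
    g = int(np.argmax(concat))
    offsets = np.cumsum([0] + [len(o) for o in task_outputs])  # start of each head
    k = int(np.searchsorted(offsets, g, side="right") - 1)
    return k, g - int(offsets[k]), g

task_outputs = [np.array([1.0, 2.0]), np.array([0.5, 3.0, -1.0]), np.array([2.5])]
print(predict_multi_head(task_outputs))  # (1, 1, 3)
```

No task-id is needed at test time: the task is implied by which head produced the largest output, which is how this rule implicitly performs TP for a single test instance.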

4.3 Better OOD Detection Produces Better CIL Performance

The key theoretical result in Sec. 3 is that better OOD detection will produce better CIL performance. Recall that the two OOD detection methods we consider, ODIN and CSI, can perform both WP and OOD detection.

Applying ODIN. We first train the baseline models using their original algorithms, and then apply the temperature scaling and input noise of ODIN at test time for each task (no training data is needed). More precisely, the output of class $j$ in task $k$ changes with the temperature scaling factor $\tau_k$ of task $k$ as

$$s(x; \tau_k)_{kj} = \frac{\exp\big(f(x)_{kj} / \tau_k\big)}{\sum_{j'} \exp\big(f(x)_{kj'} / \tau_k\big)} \tag{10}$$

and the input changes with the noise factor $\epsilon_k$ as

$$\tilde{x} = x - \epsilon_k \,\mathrm{sign}\big(-\nabla_x \log s(x; \tau_k)_{k\hat{y}}\big) \tag{11}$$

where $\hat{y}$ is the class with the maximum output value in task $k$. This is a positive adversarial example inspired by goodfellow2015explaining. The values $\tau_k$ and $\epsilon_k$ are hyper-parameters, and we use the same values for all tasks except for PASS, for which we had to use a validation set to tune them (see Appendix B).
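A minimal sketch of Eqs. 10 and 11 for a single task head follows. To stay self-contained it assumes a hypothetical linear head $f(x) = Wx + b$, so the input gradient has a closed form; a real network would need automatic differentiation. The values of `tau` and `eps` are illustrative, not the tuned hyper-parameters.

```python
import numpy as np

def log_softmax(z):
    z = z - np.max(z)
    return z - np.log(np.sum(np.exp(z)))

def odin_task_score(x, W, b, tau, eps):
    """ODIN-style score of one task head, assuming linear logits f(x) = W @ x + b.

    Returns (score, x_tilde): x_tilde is the perturbed input of Eq. 11, and score
    is the maximum temperature-scaled softmax (Eq. 10) evaluated at x_tilde."""
    logits = W @ x + b
    y_hat = int(np.argmax(logits))
    p = np.exp(log_softmax(logits / tau))
    # Gradient of log s(x; tau)_{y_hat} w.r.t. x; closed form because the head is linear.
    grad = (W[y_hat] - p @ W) / tau
    x_tilde = x - eps * np.sign(-grad)
    score = float(np.max(np.exp(log_softmax((W @ x_tilde + b) / tau))))
    return score, x_tilde

rng = np.random.default_rng(0)
W, b = rng.normal(size=(3, 5)), np.zeros(3)  # hypothetical 3-class task head
x = rng.normal(size=5)
score, x_tilde = odin_task_score(x, W, b, tau=2.0, eps=1e-3)
base = float(np.max(np.exp(log_softmax((W @ x + b) / 2.0))))
print(base, score)  # the perturbation raises the top-class score
```

The perturbation moves the input in the direction that increases the predicted class's scaled softmax score, which tends to help IND samples more than OOD samples; that gap is what improves the per-task OOD detection.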

Method     Output     OOD AUC   CIL
OWM        Original   71.31     28.91
OWM        ODIN       70.06     28.88
MUC        Original   72.69     30.42
MUC        ODIN       72.53     29.79
PASS       Original   69.89     33.00
PASS       ODIN       69.60     31.00
LwF        Original   88.30     45.26
LwF        ODIN       87.11     51.82
BiC        Original   87.89     52.92
BiC        ODIN       86.73     48.65
DER++      Original   85.99     53.71
DER++      ODIN       88.21     55.29
HAT        Original   77.72     41.06
HAT        ODIN       77.80     41.21
HyperNet   Original   71.82     30.23
HyperNet   ODIN       72.32     30.83
Sup        Original   79.16     44.58
Sup        ODIN       80.58     46.74
Table 1: Performance comparison on C100-10T between the original output and the output post-processed with the OOD detection technique ODIN. Note that ODIN is not applicable to iCaRL and Mnemonics, as they are based on distance functions rather than softmax. The results for other datasets are reported in Appendix B.

Tab. 1 gives the results for C100-10T. The results show that CIL performance increases when ODIN increases the AUC. For instance, the CIL of DER++ and Sup improves from 53.71 to 55.29 and from 44.58 to 46.74, respectively, as their AUC increases from 85.99 to 88.21 and from 79.16 to 80.58. This shows that incorporating ODIN into each task model of an already trained CIL network improves the CIL performance of the original method. We note that ODIN does not always improve the average AUC. For the methods whose AUC decreases, the CIL performance also decreases, except for LwF. The inconsistency of LwF is due to its severe classification bias towards later tasks, as discussed in BiC wu2019large. The temperature scaling in ODIN has an effect similar to the bias correction in BiC, and after the correction the CIL of LwF becomes close to that of BiC. Regardless of whether ODIN improves AUC, the positive correlation between AUC and CIL (except for LwF) supports Theorem 3, indicating that better OOD detection results in better CIL performance.
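The statement that AUC and CIL move in the same direction under ODIN for every method except LwF can be checked mechanically from Tab. 1. The short script below copies the C100-10T numbers and lists the methods for which the two changes have opposite signs; the helper name is hypothetical.

```python
from typing import List, Tuple

# (method, AUC original, AUC with ODIN, CIL original, CIL with ODIN), from Tab. 1 (C100-10T).
TABLE1: List[Tuple[str, float, float, float, float]] = [
    ("OWM", 71.31, 70.06, 28.91, 28.88),
    ("MUC", 72.69, 72.53, 30.42, 29.79),
    ("PASS", 69.89, 69.60, 33.00, 31.00),
    ("LwF", 88.30, 87.11, 45.26, 51.82),
    ("BiC", 87.89, 86.73, 52.92, 48.65),
    ("DER++", 85.99, 88.21, 53.71, 55.29),
    ("HAT", 77.72, 77.80, 41.06, 41.21),
    ("HyperNet", 71.82, 72.32, 30.23, 30.83),
    ("Sup", 79.16, 80.58, 44.58, 46.74),
]

def disagreeing_methods(rows) -> List[str]:
    """Methods whose AUC change and CIL change under ODIN have opposite signs."""
    return [name for name, auc0, auc1, cil0, cil1 in rows
            if (auc1 - auc0) * (cil1 - cil0) < 0]

print(disagreeing_methods(TABLE1))  # ['LwF']
```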

Applying CSI. We now apply the OOD detection method CSI. Because of its sophisticated data augmentation, supervised contrastive learning, and result ensembling, it is hard to apply CSI to the other baselines without fundamentally changing them. The exceptions are HAT and Sup (SupSup), which are parameter isolation-based TIL methods: we can simply replace the model used to train each task with CSI wholesale (full details are given in Appendix D). As mentioned earlier, both HAT and SupSup, as TIL methods, have almost no forgetting.

Tab. 2 reports the results of using CSI and ODIN. ODIN is a weaker OOD detection method than CSI. Both HAT and Sup improve greatly when equipped with the better OOD detection method, CSI. These results empirically support Theorem 3: CIL performance can be improved by using a better OOD detection method.
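The OOD AUC in Tab. 1 and Tab. 2 is the area under the ROC curve for separating in-distribution (IND) from OOD test instances by their OOD scores, reported as a percentage. A minimal sketch of the metric itself, assuming higher scores mean "more in-distribution", uses the equivalent Mann-Whitney form: the probability that a randomly chosen IND instance scores higher than a randomly chosen OOD instance, with ties counted as one half. How the IND and OOD test sets are formed for each task is not shown here, and the function name is hypothetical.

```python
from typing import Sequence

def auroc(ind_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """AUROC = P(random IND score > random OOD score), ties counted as 1/2.

    Quadratic-time pairwise version; adequate for a small illustration.
    """
    if not ind_scores or not ood_scores:
        raise ValueError("need at least one IND and one OOD score")
    wins = 0.0
    for a in ind_scores:
        for b in ood_scores:
            if a > b:
                wins += 1.0
            elif a == b:
                wins += 0.5
    return wins / (len(ind_scores) * len(ood_scores))

print(100 * auroc([0.9, 0.3], [0.5, 0.1]))  # 75.0
```

A perfect detector gives 100 and a random one about 50 on this percentage scale.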

                    C10-5T      C100-10T    C100-20T    T-5T        T-10T
CL        OOD       AUC   CIL   AUC   CIL   AUC   CIL   AUC   CIL   AUC   CIL
HAT       ODIN      82.5  62.6  77.8  41.2  75.4  25.8  72.3  38.6  71.8  30.0
HAT       CSI       91.2  87.8  84.5  63.3  86.5  54.6  76.5  45.7  78.5  47.1
Sup       ODIN      82.4  62.6  80.6  46.7  81.6  36.4  74.0  41.1  74.6  36.5
Sup       CSI       91.6  86.0  86.8  65.1  88.3  60.2  77.1  48.9  79.4  45.7
Table 2: Average CIL and AUC of HAT and Sup with the OOD detection methods ODIN and CSI. ODIN is a traditional OOD detection method, while CSI is a recent OOD detection method known to be better than ODIN. As the CL methods achieve better OOD detection with CSI, their CIL performance is better than that of their ODIN counterparts.

Method     M-5T       C10-5T     C100-10T   C100-20T   T-5T       T-10T
OWM        95.8±0.13  51.8±0.05  28.9±0.60  24.1±0.26  10.0±0.55  8.6±0.42
MUC        74.9±0.46  52.9±1.03  30.4±1.18  14.2±0.30  33.6±0.19  17.4±0.17
PASS       76.6±1.67  47.3±0.98  33.0±0.58  25.0±0.69  28.4±0.51  19.1±0.46
LwF        85.5±3.11  54.7±1.18  45.3±0.75  44.3±0.46  32.2±0.50  24.3±0.26
iCaRL      96.0±0.43  63.4±1.11  51.4±0.99  47.8±0.48  37.0±0.41  28.3±0.18
Mnemonics  96.3±0.36  64.1±1.47  51.0±0.34  47.6±0.74  37.1±0.46  28.5±0.72
BiC        94.1±0.65  61.4±1.74  52.9±0.64  48.9±0.54  41.7±0.74  33.8±0.40
DER++      95.3±0.69  66.0±1.20  53.7±1.30  46.6±1.44  35.8±0.77  30.5±0.47
CoL 65.6
CCG 97.3 70.1
HAT        81.9±3.74  62.7±1.45  41.1±0.93  25.6±0.51  38.5±1.85  29.8±0.65
HyperNet   56.6±4.85  53.4±2.19  30.2±1.54  18.7±1.10  7.9±0.69   5.3±0.50
Sup        70.1±1.51  62.4±1.45  44.6±0.44  34.7±0.30  41.8±1.50  36.5±0.36
PR-Ent 74.1 61.9 45.2
HAT+CSI    94.4±0.26  87.8±0.71  63.3±1.00  54.6±0.92  45.7±0.26  47.1±0.18
Sup+CSI    80.7±2.71  86.0±0.41  65.1±0.39  60.2±0.51  48.9±0.25  45.7±0.76
HAT+CSI+c  96.9±0.30  88.0±0.48  65.2±0.71  58.0±0.45  51.7±0.37  47.6±0.32
Sup+CSI+c  81.0±2.30  87.3±0.37  65.2±0.37  60.5±0.64  49.2±0.28  46.2±0.53
Table 3: Average accuracy after all tasks are learned. Exemplar-free methods are italicized. In their original papers, PASS and Mnemonics are pre-trained with the first half of the classes; their results with pre-training are 50.1 and 53.5 on C100-10T, respectively, which are still much lower than those of the proposed HAT+CSI and Sup+CSI without pre-training. For fairness, we do not use pre-training in our experiments. iCaRL and Mnemonics report average incremental accuracy in their original papers; we report the average accuracy over all classes after all tasks are learned.

4.4 Full Comparison of HAT+CSI and Sup+CSI with Baselines

We now make a full comparison of the two strong systems (HAT+CSI and Sup+CSI) designed based on the theoretical results. These combinations are particularly attractive because both HAT and Sup are TIL systems with little or no CF. Combining them with a strong OOD detection method that can also perform WP (within-task/IND prediction) therefore yields a strong CIL method. Since HAT and Sup are exemplar-free CL methods, HAT+CSI and Sup+CSI also do not need to save any data from previous tasks. Tab. 3 shows that HAT and Sup equipped with CSI outperform the baselines by large margins. DER++, the best replay method, achieves 66.0 and 53.7 on C10-5T and C100-10T, respectively, while HAT+CSI achieves 87.8 and 63.3 and Sup+CSI achieves 86.0 and 65.1. The large performance gap remains consistent on the more challenging T-5T and T-10T. We note that Sup works very poorly on M-5T (70.1); Sup+CSI improves it drastically (80.7), although it remains much weaker than HAT+CSI (94.4).
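To make the WP/TP view behind these combinations concrete, here is a minimal sketch of the Eq. 2 decomposition, P(class j of task k | x) = P(j | x, task k) · P(task k | x). It is not the exact prediction rule used in the experiments (that is Eq. 9, discussed in Appendix C). Two choices here are assumptions for illustration only: WP is a softmax over each task's outputs, and TP is obtained by normalizing hypothetical non-negative per-task in-distribution scores, such as an OOD detector might produce, so that they sum to one.

```python
import math
from typing import List, Sequence, Tuple

def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def cil_predict(task_outputs: Sequence[Sequence[float]],
                ood_scores: Sequence[float]) -> Tuple[int, int, float]:
    """Eq. 2 sketch: P(k, j | x) = P(j | x, task k) * P(task k | x).

    WP: softmax of task k's outputs (assumption).
    TP: per-task IND-ness scores normalized to sum to 1 (assumption).
    Returns (task, class within task, joint probability) of the best pair.
    """
    total = sum(ood_scores)
    if total <= 0 or any(s < 0 for s in ood_scores):
        raise ValueError("OOD scores must be non-negative with a positive sum")
    tp = [s / total for s in ood_scores]
    best = (-1, -1, -1.0)
    for k, outputs in enumerate(task_outputs):
        for j, p_wp in enumerate(softmax(outputs)):
            joint = p_wp * tp[k]
            if joint > best[2]:
                best = (k, j, joint)
    return best
```

With the same task outputs, shifting the TP mass from one task to another changes which task wins, even though every within-task prediction stays the same. This is why a strong OOD detector matters once WP is already good.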

Because of how OOD is defined in the prediction method, and because each task is trained separately in HAT and Sup, the outputs of different tasks can be on different scales, which can lead to incorrect predictions. To deal with this problem, we can calibrate the output of each task $k$ as $\alpha_k f(x)_k + \beta_k$ and use the calibrated outputs for prediction. The optimal $\alpha_k$ and $\beta_k$ for each task can be found by optimization with a memory buffer that saves a very small number of training examples from previous tasks, as in replay-based methods. We refer to the calibrated methods as HAT+CSI+c and Sup+CSI+c. They use a memory buffer of the same size as the replay methods (see Sec. 4.2). Tab. 3 shows that calibration improves over the memory-free versions (i.e., without calibration). We provide details on how to train the calibration parameters $\alpha_k$ and $\beta_k$ in Appendix E.
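A minimal sketch of such a calibration, assuming each task's output vector is rescaled affinely as $\alpha_k f(x)_k + \beta_k$ before the argmax over the concatenated outputs. Purely for illustration, the parameters are chosen by an exhaustive grid search that maximizes accuracy on the memory buffer; the actual training procedure is the one described in Appendix E, and the grid, function names, and toy data are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple

Outputs = List[List[float]]  # outputs[k][j] = f(x)_{k,j} for one instance
Example = Tuple[Outputs, Tuple[int, int]]  # (task outputs, true (task, class))

def calibrated_predict(outputs: Outputs, alphas: Sequence[float],
                       betas: Sequence[float]) -> Tuple[int, int]:
    """Argmax over the concatenated calibrated outputs alpha_k * f(x)_k + beta_k."""
    best = None  # (calibrated value, task, class)
    for k, out in enumerate(outputs):
        for j, v in enumerate(out):
            c = alphas[k] * v + betas[k]
            if best is None or c > best[0]:
                best = (c, k, j)
    return best[1], best[2]

def fit_calibration(buffer: List[Example], alpha_grid: Sequence[float],
                    beta_grid: Sequence[float]):
    """Hypothetical grid search for per-task (alpha_k, beta_k) maximizing buffer accuracy.

    Task 0 is fixed at (1, 0) as a reference: scaling every task by the same positive
    factor, or shifting every task by the same constant, leaves the argmax unchanged.
    """
    if not buffer:
        raise ValueError("the memory buffer is empty")
    num_tasks = len(buffer[0][0])
    best_params, best_acc = None, -1.0
    for combo in product(product(alpha_grid, beta_grid), repeat=num_tasks - 1):
        alphas = [1.0] + [a for a, _ in combo]
        betas = [0.0] + [b for _, b in combo]
        correct = sum(calibrated_predict(o, alphas, betas) == y for o, y in buffer)
        acc = correct / len(buffer)
        if acc > best_acc:
            best_params, best_acc = (alphas, betas), acc
    return best_params, best_acc
```

For example, if the second task's outputs are inflated about tenfold relative to the first, the uncalibrated argmax sends first-task instances to the second task, and a grid containing $\alpha_1 = 0.1$ can undo this. The grid search grows exponentially with the number of tasks, so it only serves to show what is being optimized.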

As shown in Theorem 1, the CIL performance also depends on the TIL (WP) performance. We compare the TIL accuracies of the baselines and our methods in Tab. 4. Our systems again outperform the baselines by large margins on more challenging datasets (e.g., CIFAR100 and Tiny-ImageNet).
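Under the product decomposition of Eq. 2 (the CIL probability of the true class is the WP probability times the TP probability), the CIL cross-entropy of an instance is exactly the sum of its WP and TP cross-entropies. Since the TP term is non-negative, a weak WP bounds CIL performance from below no matter how good TP is. A small numerical check (function name hypothetical):

```python
import math
from typing import Tuple

def cross_entropies(p_wp: float, p_tp: float) -> Tuple[float, float, float]:
    """Per-instance cross-entropies (H_WP, H_TP, H_CIL) for the true label.

    p_wp: probability assigned to the true class within the true task (WP).
    p_tp: probability assigned to the true task (TP).
    The CIL probability of the true class is their product, as in Eq. 2.
    """
    if not (0.0 < p_wp <= 1.0 and 0.0 < p_tp <= 1.0):
        raise ValueError("probabilities must lie in (0, 1]")
    h_wp = -math.log(p_wp)
    h_tp = -math.log(p_tp)
    h_cil = -math.log(p_wp * p_tp)
    return h_wp, h_tp, h_cil

h_wp, h_tp, h_cil = cross_entropies(0.9, 0.5)
print(round(h_wp + h_tp, 6), round(h_cil, 6))  # 0.798508 0.798508
```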

Method     M-5T       C10-5T     C100-10T   C100-20T   T-5T       T-10T
DER++      99.7±0.08  92.0±0.54  84.0±9.43  86.6±9.44  57.4±1.31  60.0±0.74
HAT        99.9±0.02  96.7±0.18  84.0±0.23  85.0±0.98  61.2±0.72  63.8±0.41
Sup        99.6±0.01  96.6±0.21  87.9±0.27  91.6±0.15  64.3±0.24  68.4±0.22
HAT+CSI    99.9±0.00  98.7±0.06  92.0±0.37  94.3±0.06  68.4±0.16  72.4±0.21
Sup+CSI    99.0±0.08  98.7±0.07  93.0±0.13  95.3±0.20  65.9±0.25  74.1±0.28
Table 4: TIL (WP) results of the three best-performing baselines and our methods. The full results are given in Appendix F. The calibrated versions (+c) of our methods are omitted, as calibration does not affect TIL performance.

4.5 Implications for Existing CL Methods, Open-World Learning and Future Research

Implication for regularization and replay methods. Regularization-based (exemplar-free) methods try to protect the important parameters of old tasks to mitigate CF. However, since the training of each task does not consider OOD detection, TP will be weak, which makes inter-task class separation (ICS) difficult and thus leads to low CL accuracy. Replay-based methods do better because the replay data from old tasks can naturally be regarded as OOD data for the current task, so a better OOD model is built, which improves TP. However, since the replay data is small, the OOD model is sub-optimal, especially for earlier tasks, whose training cannot see any future task data. Thus, for both approaches, it will be beneficial to consider CF and OOD together when learning each task (e.g., kim2022multi).

Implication for open-world learning. Since our theory says that CL needs OOD detection, and OOD detection is also the first step in open-world learning (OWL), CL and OWL naturally work together to achieve self-motivated open-world continual learning liu2021self for autonomous learning or AI autonomy. That is, the AI agent can continually discover new tasks (OOD detection) and incrementally learn the tasks (CL) all on its own, with no involvement of human engineers. Further afield, this work is also related to curiosity-driven self-supervised learning pathak2017curiosity in reinforcement learning and 3D navigation.

Limitation and future work. The proposed theory provides principled guidance on what needs to be done to achieve good CIL results, but it gives no guidance on how to do it. Although two example techniques are presented and evaluated, they are empirical. There are many options for defining WP and TP (or OOD). An idea in (guo2022online) may be helpful in this regard. (guo2022online) argues that a continual learner should learn holistic feature representations of the input data, i.e., learn as many features as possible from the input data. The rationale is that if the system can learn all possible features from each task, then a future task does not have to learn those shared/intersecting features by modifying the parameters, which results in less CF and also better ICS. A full representation of the IND data also improves OOD detection, because the OOD score of a data point is basically the distance between the data point and the IND distribution. Capturing only a subset of features (e.g., by cross entropy) results in poor OOD detection hu2020hrn, because the missing features may be necessary to separate the IND data from some OOD data. In future work, we will study how to optimize WP and TP/OOD and find the necessary conditions for them to do well.
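As a simple illustration of the statement that the OOD score of a data point is basically its distance to the IND distribution, the sketch below scores a feature vector by its Euclidean distance to the nearest IND class mean. This particular score is an assumption chosen for illustration; it is not CSI's score, nor the method of guo2022online or hu2020hrn. In practice the features would come from a trained feature extractor, and the 2-D toy features here are made up. `math.dist` requires Python 3.8 or later.

```python
import math
from typing import Dict, List, Sequence

def class_means(features: Sequence[Sequence[float]],
                labels: Sequence[int]) -> Dict[int, List[float]]:
    """Mean feature vector of each IND class."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for f, y in zip(features, labels):
        if y not in sums:
            sums[y] = [0.0] * len(f)
            counts[y] = 0
        sums[y] = [s + v for s, v in zip(sums[y], f)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}

def distance_ood_score(feature: Sequence[float],
                       means: Dict[int, List[float]]) -> float:
    """Euclidean distance to the nearest IND class mean (larger = more likely OOD)."""
    return min(math.dist(feature, m) for m in means.values())

means = class_means([[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]], [0, 0, 1, 1])
print(distance_ood_score([0.0, 1.0], means), distance_ood_score([5.0, 5.0], means))
```

If the features omit a direction along which some OOD data differs from the IND data, that OOD data can land close to a class mean and receive a low score. This is the failure mode the paragraph attributes to capturing only a subset of features.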

5 Conclusion

This paper presented a theoretical study on how to solve the highly challenging class incremental learning (CIL) setting of continual learning (CL) (the other popular CL setting is task incremental learning (TIL)). The theoretical result provides principled guidance for designing better CIL algorithms. The paper first decomposed the CIL prediction into within-task prediction (WP) and task-id prediction (TP). WP is basically TIL. The paper further showed theoretically that TP is correlated with out-of-distribution (OOD) detection. It then proved that good performance on both WP and TP is necessary and sufficient for good CIL performance. Based on the theoretical result, several new CIL methods were designed. They outperform strong baselines in CIL and also in TIL by a large margin. Finally, we also discussed the implications for existing CL techniques and open-world learning.

Acknowledgments

The work of Gyuhak Kim, Zixuan Ke and Bing Liu was supported in part by a research contract from KDDI, two NSF grants (IIS-1910424 and IIS-1838770), and a DARPA contract HR001120C0023.

References

Appendix A Proof of Theorems and Corollaries

A.1 Proof of Theorem 1

Proof.

Since

and

we have

A.2 Proof of Corollary 1

Proof.

By proof of Theorem 1, we have

Taking expectations on both sides, we have i)

and ii)

A.3 Proof of Theorem 2

Proof.

i) Assume .

For , we have

For , we have

ii) Assume .

For , by , we have

which means

For , by , we have

which means

Therefore, we have

Hence,

A.4 Proof of Theorem 3

Proof.

Using Theorem 1 and 2,

A.5 Proof of Theorem 4

Proof.

i) Assume .

Define .

According to proof of Theorem 1,

Hence, we have

ii) Assume .

Define .

According to proof of Theorem 1,

Hence, we have

iii) Assume .

Define .

According to proof of Theorem 4 ii), we have

According to proof of Theorem 2 i), we have

Therefore,

Appendix B Additional Results and Explanation Regarding Table 1 in the Main Paper

In Sec. 4.3, we showed that better OOD detection improves CIL performance. For the post-processing method ODIN, we reported only the results on C100-10T due to space limitations. Tab. 5 shows the results on the other datasets.

                    M-5T          C10-5T        C100-20T      T-5T          T-10T
Method    OOD       AUC    CIL    AUC    CIL    AUC    CIL    AUC    CIL    AUC    CIL
OWM       Original  99.13  95.81  81.33  51.79  71.90  24.15  58.49  10.00  59.48  8.57
OWM       ODIN      98.86  95.16  71.72  40.65  68.52  23.05  58.46  10.77  59.38  9.52
MUC       Original  92.27  74.90  79.49  52.85  66.20  14.19  68.42  33.57  62.63  17.39
MUC       ODIN      92.67  75.71  79.54  53.22  65.72  14.11  68.32  33.45  62.17  17.27
PASS      Original  98.74  76.58  66.51  47.34  70.26  24.99  65.18  28.40  63.27  19.07
PASS      ODIN      90.40  74.33  63.08  35.20  69.81  21.83  65.93  29.03  62.73  17.78
LwF       Original  99.19  85.46  89.39  54.67  89.84  44.33  78.20  32.17  79.43  24.28
LwF       ODIN      98.52  90.39  88.94  63.04  88.68  47.56  76.83  36.20  77.02  28.29
BiC       Original  99.40  94.11  90.89  61.41  89.46  48.92  80.17  41.75  80.37  33.77
BiC       ODIN      98.57  95.14  91.86  64.29  87.89  47.40  74.54  37.40  76.27  29.06
DER++     Original  99.78  95.29  90.16  66.04  85.44  46.59  71.80  35.80  72.41  30.49
DER++     ODIN      99.09  94.96  87.08  63.07  87.72  49.26  73.92  37.87  72.91  32.52
HAT       Original  94.46  81.86  82.47  62.67  75.35  25.64  72.28  38.46  71.82  29.78
HAT       ODIN      94.56  82.06  82.45  62.60  75.36  25.84  72.31  38.61  71.83  30.01
HyperNet  Original  85.83  56.55  78.54  53.40  72.04  18.67  54.58  7.91   55.37  5.32
HyperNet  ODIN      86.89  64.31  79.39  56.72  73.89  23.8   54.60  8.64   55.53  6.91
Sup       Original  90.70  70.06  79.16  62.37  81.14  34.70  74.13  41.82  74.59  36.46
Sup       ODIN      90.68  69.70  82.38  62.63  81.48  36.35  73.96  41.10  74.61  36.46
Table 5: Performance comparison between the original output and the output post-processed with the OOD detection technique ODIN. Note that ODIN is not applicable to iCaRL and Mnemonics, as they are based on distance functions rather than softmax. The results for C100-10T are reported in the main paper (Tab. 1).

In general, a continual learning method with a better AUC shows better CIL performance than methods with a lower AUC. For instance, on C10-5T the original HAT achieves an AUC of 82.47 while HyperNet achieves 78.54, and the CIL is 62.67 for HAT and 53.40 for HyperNet. However, there are exceptions. An example is LwF, whose AUC and CIL on C10-5T are 89.39 and 54.67: although its AUC is better than that of HAT, its CIL is lower. This is because, according to Theorem 1, CIL performance depends on both WP and TP. The contrapositive of Theorem 4 also says that if the cross-entropy of TIL is large, that of CIL is also large. Indeed, the average within-task prediction (WP) accuracy of LwF on C10-5T is 95.2, while that of HAT is 96.7. Improving WP is therefore also important for achieving good CIL performance.

For PASS, we had to tune the per-task temperature $\tau_k$ of Eq. 10 using a validation set. This is because the softmax in Eq. 10 improves AUC by making the IND (in-distribution) and OOD scores more separable within a task, but it deteriorates the final scores across tasks. Specifically, after the softmax, test instances are predicted as one of the classes of the first task, because in PASS the relative values between the classes of task 1 are larger than those of the other tasks. Therefore, a larger $\tau_1$ and smaller $\tau_k$ for $k > 1$ are chosen to compensate for the relative values.
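The effect can be reproduced with made-up numbers (not taken from PASS). When one task's outputs have much larger gaps between classes, its maximum softmax score dominates every other task's under a shared temperature, and raising that task's temperature removes the advantage. The helper names are hypothetical, and the toy instance is assumed to come from the second task.

```python
import math
from typing import List, Sequence

def softmax_t(z: Sequence[float], tau: float) -> List[float]:
    """Temperature-scaled softmax (shifting by the max does not change the result)."""
    m = max(z)
    e = [math.exp((v - m) / tau) for v in z]
    s = sum(e)
    return [v / s for v in e]

def winning_task(task_outputs: Sequence[Sequence[float]], taus: Sequence[float]) -> int:
    """Index of the task whose max temperature-scaled softmax score is largest."""
    scores = [max(softmax_t(z, t)) for z, t in zip(task_outputs, taus)]
    return max(range(len(scores)), key=scores.__getitem__)

# Index 0 is "task 1": its outputs have a much larger gap between classes.
outputs = [[8.0, 0.0], [2.0, 0.0]]
print(winning_task(outputs, [1.0, 1.0]))  # 0: task 1 wins with a shared temperature
print(winning_task(outputs, [8.0, 1.0]))  # 1: a larger tau for task 1 flattens its scores
```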

Appendix C Definitions of TP

As noted in the main paper, the class prediction in Eq. 2 depends on how WP and TP are defined, and their precise definitions depend on the implementation. Because of this subjectivity, we follow the prediction method of existing continual learning methods, which is the argmax over the output. In this section, we show that the argmax over the output is a special case of Eq. 2. We also provide CIL results using different definitions of TP.

We first establish another theorem. This is an extension of Theorem 2 and connects the standard prediction method to our analysis.

Theorem 5 (Extension of Theorem 2).

i) If , let , , then .

ii) If , let , , then , where is an indicator function.

In Theorem 5 (proof appears later), we can observe that decreases with the increase of , while increases. Hence, when TP is given, let , we can find the optimal to define OOD by solving . Similarly, given OOD, let , we can find the optimal to define TP by finding the global minima of . The optimal can be found using a memory buffer to save a small number of previous data like that in a replay-based continual learning method.

In Theorem 5 (ii), let